├── Data_Synthesizer ├── __init__.py ├── requirements.txt ├── pipeline │ ├── sqlite │ │ └── prompt_templates │ │ │ ├── question_synthesis_prompt.txt │ │ │ ├── sqlite_note_prompt.txt │ │ │ ├── cot_synthesis_prompt_template.txt │ │ │ ├── sql_generate_prompt_template.txt │ │ │ ├── sqlite_vec_note_prompt.txt │ │ │ └── find_semantic_rich_column.txt │ ├── clickhouse │ │ └── prompt_templates │ │ │ └── clickhouse_vec_note_prompt.txt │ ├── myscale │ │ └── prompt_templates │ │ │ └── myscale_vec_note_prompt.txt │ └── postgresql │ │ └── prompt_templates │ │ └── postgresql_vec_note_prompt.txt ├── tools │ ├── manage_db.sh │ ├── README.md │ ├── migrate_db.sh │ ├── duplicate.py │ ├── collect_input_llm.py │ ├── add_candidate_prefix.py │ ├── convert_db_id.py │ ├── batch_copy_files.py │ ├── add_prefix.py │ ├── change_embedding_model.py │ └── mix_datasets.py ├── database_synthesis │ ├── prompt_templates │ │ ├── enhance_prompt.txt │ │ ├── embedding_prompt.txt │ │ └── embedding_with_new_line_prompt.txt │ ├── generate_schema_embedding_prompts.py │ ├── generate_schema_enhancement_prompts.py │ ├── README.md │ ├── build_sqlite_databases.py │ ├── embedding_schema.py │ ├── synthesize_schema.py │ └── enhance_schema.py ├── synthesis_sql │ ├── README.md │ └── synthesize_sql.py ├── synthesis_eval │ ├── generate_eval_prompts.py │ └── generate_input.py ├── collect_input_llm.py ├── synthesis_cot │ └── generate_cot_synthesis_prompts.py ├── README.md ├── vectorization │ ├── generate_schema.py │ ├── find_semantic_rich_column.py │ ├── batch_vectorize_databases.py │ └── generate_vector_schema.py └── synthesis_nl │ ├── synthesize_question.py │ └── synthesize_candidate.py ├── Execution_Engine ├── __init__.py ├── requirements.txt ├── engine_config.yaml └── README.md ├── Evaluation_Framework ├── __init__.py ├── requirements.txt ├── script │ ├── README.md │ ├── generate_eval_prompts.py │ ├── generate_query_id.py │ ├── generate_ground_truth.py │ ├── config.yaml.example │ └── api_pipeline.py ├── prompt_templates │ ├── sql_generate_prompt_template.txt │ └── sqlite_vec_note_prompt.txt ├── README.md └── evaluation_config.yaml ├── Figures ├── fig1.png └── fig2.png ├── Embedding_Service ├── requirements.txt ├── run.sh ├── multi_client.py ├── README.md ├── multi_server.py └── server.py └── .gitignore /Data_Synthesizer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Execution_Engine/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Evaluation_Framework/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Figures/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenDCAI/Text2VectorSQL/HEAD/Figures/fig1.png -------------------------------------------------------------------------------- /Figures/fig2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenDCAI/Text2VectorSQL/HEAD/Figures/fig2.png -------------------------------------------------------------------------------- /Evaluation_Framework/requirements.txt: -------------------------------------------------------------------------------- 1 | requests 2 | PyYAML 3 | 
tqdm 4 | numpy 5 | pyparsing 6 | sqlalchemy -------------------------------------------------------------------------------- /Execution_Engine/requirements.txt: -------------------------------------------------------------------------------- 1 | psycopg2-binary 2 | requests 3 | pyyaml 4 | clickhouse-connect 5 | sqlite-vec 6 | sqlite-lembed -------------------------------------------------------------------------------- /Embedding_Service/requirements.txt: -------------------------------------------------------------------------------- 1 | fastapi 2 | uvicorn[standard] 3 | pydantic 4 | sentence-transformers 5 | torch 6 | PyYAML 7 | requests 8 | Pillow 9 | -------------------------------------------------------------------------------- /Data_Synthesizer/requirements.txt: -------------------------------------------------------------------------------- 1 | clickhouse-driver 2 | python-dotenv 3 | func-timeout 4 | json-repair 5 | openai 6 | scipy 7 | sentence-transformers 8 | scikit-learn 9 | tenacity 10 | tqdm 11 | transformers 12 | httpx 13 | ijson 14 | matplotlib 15 | numpy 16 | psycopg2-binary 17 | requests 18 | PyYAML 19 | torch 20 | torchvision 21 | sqlite-vec 22 | sqlite-lembed 23 | -------------------------------------------------------------------------------- /Embedding_Service/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # export HF_ENDPOINT=https://hf-mirror.com 3 | python server.py --config config.yaml 4 | 5 | # curl http://localhost:8000/health 6 | curl -X POST http://localhost:8000/embed \ 7 | -H "Content-Type: application/json" \ 8 | -d '{ 9 | "model": "all-MiniLM-L6-v2", 10 | "texts": [ 11 | "Hello, world!", 12 | "This is a test of the embedding service." 13 | ] 14 | }' 15 | -------------------------------------------------------------------------------- /Evaluation_Framework/script/README.md: -------------------------------------------------------------------------------- 1 | This directory is used to generate the input files for the evaluation framework. 2 | 3 | It reads the candidate_sql.sql files from the database directories under Data_Synthesizer/pipeline/sqlite/results and produces the ground_truth.json and eval_queries.json files required by the evaluation framework. 4 | 5 | First remove the ".example" suffix from config.yaml.example and add the API settings for your model to it. 6 | 7 | Then edit the DATASET_BACKEND = "sqlite" 8 | DATASET_TO_LOAD = "toy_spider" parameters in Evaluation_Framework/script/api_pipeline.py to choose the database backend and the database. 9 | 10 | Run: 11 | ```bash 12 | python api_pipeline.py 13 | ``` 14 | to obtain the final files used for evaluation. 15 | -------------------------------------------------------------------------------- /Data_Synthesizer/pipeline/sqlite/prompt_templates/question_synthesis_prompt.txt: -------------------------------------------------------------------------------- 1 | **Task Overview** 2 | Your task is to create a high-quality natural language question based on a given SQL query and other information.
3 | {using_knn} 4 | 5 | **Style** 6 | The natural language question should follow this style: 7 | {style_desc} 8 | 9 | **Database Engine** 10 | {engine} 11 | 12 | **Database Extension** 13 | {extension} 14 | 15 | **Column Information** 16 | Below are column names and their corresponding descriptions: 17 | {column_info} 18 | 19 | **SQL Query** 20 | Given SQL query: 21 | ```sql 22 | {sql} 23 | ``` 24 | 25 | **Reasoning Steps** 26 | {steps} 27 | 28 | **Guidelines** 29 | {guidelines} 30 | 31 | **Output Format** 32 | {output_format} 33 | 34 | **Instruction** 35 | {instruction} 36 | -------------------------------------------------------------------------------- /Data_Synthesizer/tools/manage_db.sh: -------------------------------------------------------------------------------- 1 | # /usr/lib/postgresql/14/bin/pg_ctl -D /mnt/b_public/data/wangzr/pgdata/ -l logfile start 2 | # clickhouse start # start ClickHouse; web UI at http://localhost:8123/play 3 | 4 | # Drop all ClickHouse user databases 5 | # clickhouse-client --query="SELECT name FROM system.databases WHERE name NOT IN ('system', 'default', 'INFORMATION_SCHEMA', 'information_schema')" | xargs -I {} clickhouse-client --query="DROP DATABASE IF EXISTS {}" 6 | 7 | # Drop all PostgreSQL user databases 8 | sudo -u postgres psql -t -c "SELECT datname FROM pg_database WHERE datistemplate = false AND datname <> 'postgres';" | while read dbname; do 9 | # If a non-empty line (a database name) was read, drop that database 10 | if [ -n "$dbname" ]; then 11 | echo "Dropping database: $dbname" 12 | sudo -u postgres dropdb "$dbname" 13 | fi 14 | done 15 | echo "All user databases have been dropped." 16 | 17 | # /etc/init.d/postgresql start # start PostgreSQL; connect with: sudo -u postgres psql 18 | -------------------------------------------------------------------------------- /Data_Synthesizer/database_synthesis/prompt_templates/enhance_prompt.txt: -------------------------------------------------------------------------------- 1 | **Task Overview:** 2 | As a senior data analyst, your task is to enhance an initial database schema to provide a more detailed and realistic structure based on a given business scenario. 3 | 4 | **Steps:** 5 | 1. **Analyze the Scenario:** Understand the provided business context. 6 | 2. **Identify Enhancements:** For each existing table, suggest new columns and explain their relevance. Be creative and thorough. 7 | 3. **Enrich the Schema:** Present the enriched schema in JSON format, ensuring proper primary and foreign key relationships. 8 | 9 | **Business Domain:** 10 | {domain} 11 | 12 | **Business Scenario:** 13 | {scenario} 14 | 15 | **Initial Database Schema:** 16 | ```json 17 | {schema} 18 | ``` 19 | 20 | **Output Format:** 21 | Your output should provide the enriched database schema in JSON format: 22 | ```json 23 | -- enriched database schema 24 | ``` 25 | 26 | Let's think step by step. -------------------------------------------------------------------------------- /Data_Synthesizer/synthesis_sql/README.md: -------------------------------------------------------------------------------- 1 | # Complexity-Aware SQL Query Generation 2 | 3 | This is the second step in our data synthesis framework, focused on generating complexity-aware SQL queries based on synthetic databases. 4 | 5 | ## Step 1: SQL Query Generation 6 | 7 | Generate SQL queries by leveraging database schemas, database values, query complexity, and SQLite-supported functions. 8 | 9 | 1. Execute `python3 generate_sql_synthesis_prompts.py` to create prompts for SQL query generation. 10 | 2. Run `python3 synthesize_sql.py` to generate SQL queries using LLMs.
(Note: Implement the `llm_inference()` function to integrate your preferred LLM.) 11 | 12 | ## Step 2: Post-Processing 13 | 14 | Refine the generated SQL queries to ensure quality and remove invalid or redundant queries: 15 | 16 | 1. Run `python3 post_process_sqls.py` to: 17 | - Discard non-SELECT queries. 18 | - Remove queries with syntax errors or execution timeouts. 19 | - Deduplicate queries based on their templates. 20 | 21 | 2. The final synthetic SQL queries will be saved in `./results/synthetic_sqls.json`. 22 | -------------------------------------------------------------------------------- /Data_Synthesizer/tools/README.md: -------------------------------------------------------------------------------- 1 | If you want to migrate the SQLite databases and SQL to another database, such as ClickHouse, PostgreSQL, or MyScale, you must start the ClickHouse service first. 2 | Run: 3 | ```bash 4 | sudo -u clickhouse /usr/bin/clickhouse-server --config-file=/etc/clickhouse-server/my-clean-config.xml --daemon 5 | ``` 6 | 7 | Check whether PostgreSQL is running: 8 | ```bash 9 | ## The path below is a file inside the PostgreSQL data directory 10 | head -n 1 /mnt/DataFlow/ydw/data/pgdata/postmaster.pid 11 | ## Pass the PID returned above to the following command 12 | ps -p -f 13 | 14 | ## If a postgres process is shown, the database is already running and does not need to be started again. 15 | ## If the process does not exist, the database crashed earlier; you must delete this lock file manually before restarting: 16 | rm /mnt/DataFlow/ydw/data/pgdata/postmaster.pid 17 | ``` 18 | 19 | To start PostgreSQL, first switch users, because PostgreSQL must never be started as root. 20 | ```bash 21 | ## Make sure permissions are correct: give ownership of the data directory to the postgres user (or whichever non-root user will run the database): 22 | chown -R postgres:postgres /mnt/DataFlow/ydw/data/pgdata/ 23 | ## Switch user 24 | su - postgres 25 | ## Start 26 | /usr/lib/postgresql/14/bin/pg_ctl -D /mnt/DataFlow/ydw/data/pgdata/ -l /tmp/pg_logfile.log start 27 | ## Or use the more generic command: 28 | ## pg_ctl -D /mnt/DataFlow/ydw/data/pgdata/ -l /tmp/pg_logfile.log start 29 | 30 | ## Verify startup 31 | # Check the log for success 32 | tail -f /tmp/pg_logfile.log 33 | 34 | # Try to connect (the default port is usually 5432) 35 | psql -h localhost -p 5432 -d postgres 36 | ``` 37 | -------------------------------------------------------------------------------- /Data_Synthesizer/tools/migrate_db.sh: -------------------------------------------------------------------------------- 1 | python migrate_db.py --source /mnt/DataFlow/ydw/Text2VectorSQL/Data_Synthesizer/pipeline/sqlite/results/arxiv/vector_databases 2 | python migrate_db.py --source /mnt/DataFlow/ydw/Text2VectorSQL/Data_Synthesizer/pipeline/sqlite/results/bird/vector_databases 3 | # python migrate_db.py --source /mnt/DataFlow/ydw/Text2VectorSQL/Data_Synthesizer/pipeline/sqlite/results/spider/vector_databases 4 | python migrate_db.py --source /mnt/DataFlow/ydw/Text2VectorSQL/Data_Synthesizer/pipeline/sqlite/results/wikipedia_multimodal/vector_databases 5 | 6 | python migrate_db_myscale.py --source /mnt/DataFlow/ydw/Text2VectorSQL/Data_Synthesizer/pipeline/sqlite/results/arxiv/vector_databases 7 | python migrate_db_myscale.py --source /mnt/DataFlow/ydw/Text2VectorSQL/Data_Synthesizer/pipeline/sqlite/results/bird/vector_databases 8 | # python migrate_db_myscale.py --source /mnt/DataFlow/ydw/Text2VectorSQL/Data_Synthesizer/pipeline/sqlite/results/spider/vector_databases 9 | python migrate_db_myscale.py --source /mnt/DataFlow/ydw/Text2VectorSQL/Data_Synthesizer/pipeline/sqlite/results/wikipedia_multimodal/vector_databases 10 | 11 | python migrate_main_sql_only.py 12 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .env 2 | .env.* 3 | !.env.example 4 | 5 | # macOS system files 6 | .DS_Store 7 | __MACOSX/ 8 | 9 | # Python cache and environment files
10 | __pycache__/ 11 | *.pyc 12 | venv/ 13 | env/ 14 | 15 | 16 | # big files 17 | train/ 18 | results/ 19 | !question_synthesis/results/ 20 | !synthesis/toy_spider/train 21 | prompts/ 22 | models/ 23 | models_cache/ 24 | results/ 25 | Data_Synthesizer/database_synthesis/*.json 26 | Evaluation_Framework/*.json 27 | Data_Synthesizer/database_synthesis/synthesis_data 28 | Data_Synthesizer/database_synthesis/tables.json 29 | Data_Synthesizer/database_synthesis/web_tables.json 30 | Data_Synthesizer/database_synthesis/results_ignore_UL 31 | Data_Synthesizer/database_synthesis/synthesis_data_ignore_UL 32 | 33 | spider_data/ 34 | spider2_data/ 35 | cache/ 36 | # prompt*/ 37 | logging/ 38 | 39 | *.gguf 40 | ./synthesis/model/ 41 | 42 | *.sql 43 | *.sqlite 44 | 45 | # *.yaml 46 | Embedding_Service/config.yaml 47 | Data_Synthesizer/pipeline/config.yaml 48 | 49 | __pycache__ 50 | .vscode 51 | *.sqlite-shm 52 | *.sqlite-wal 53 | *.egg-info 54 | results/* 55 | database/ 56 | clash-for-linux/ 57 | tmp/ 58 | *.zip 59 | *.db 60 | *.gguf 61 | *.yaml 62 | Data_Synthesizer/data 63 | 64 | LLaMA-Factory/ 65 | -------------------------------------------------------------------------------- /Data_Synthesizer/pipeline/sqlite/prompt_templates/sqlite_note_prompt.txt: -------------------------------------------------------------------------------- 1 | You are an expert SQLite generator. Your primary goal is to generate syntactically correct and efficient SQL queries that strictly adhere to the following rules. Your adherence to these rules is mandatory to prevent common errors. 2 | 3 | --- 4 | This single example illustrates the correct application of the most important rules, especially for complex queries. 5 | 6 | ```sql 7 | -- This is a perfect example of a complex query done correctly. 
8 | WITH ActiveUsers AS (\n SELECT uploaded_by\n FROM code_snippets\n GROUP BY uploaded_by\n HAVING COUNT(snippet_id) > 1\n),\nPublicSnippets AS (\n SELECT snippet_id, description\n FROM code_snippets\n INNER JOIN ActiveUsers ON ActiveUsers.uploaded_by = code_snippets.uploaded_by\n WHERE is_public = 1\n),\nSuccessfulUses AS (\n SELECT snippet_id\n FROM snippet_usage\n WHERE is_successful = 1\n),\nLowQualitySnippets AS (\n SELECT ps.snippet_id, ps.description\n FROM PublicSnippets ps\n INNER JOIN SuccessfulUses su ON ps.snippet_id = su.snippet_id\n INNER JOIN quality_scores qs ON ps.snippet_id = qs.snippet_id\n WHERE qs.explanation_quality = 0\n)\nSELECT description\nFROM LowQualitySnippets; 9 | ``` 10 | -------------------------------------------------------------------------------- /Data_Synthesizer/database_synthesis/generate_schema_embedding_prompts.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | 4 | random.seed(42) 5 | 6 | if __name__ == '__main__': 7 | prompts = [] 8 | prompt_template = open("./prompt_templates/embedding_prompt.txt", "r", encoding = "utf-8").read() 9 | schema_synthesis_results = json.load(open("./results/schema_enhancement.json")) 10 | 11 | no_res_num = 0 12 | for data in schema_synthesis_results: 13 | try: 14 | if data["enhanced_schema"] == {}: 15 | no_res_num += 1 16 | continue 17 | 18 | domain = data["domain"] 19 | scenario = data["scenario"] 20 | schema_str = data["enhanced_schema"] 21 | 22 | prompts.append( 23 | prompt_template.format(domain = domain, scenario = scenario, schema = schema_str).strip() 24 | ) 25 | 26 | except Exception as e: 27 | print(e) 28 | 29 | print("no_res_num:", no_res_num) 30 | print("len(prompts):", len(prompts)) 31 | random.shuffle(prompts) 32 | 33 | with open("./prompts/prompts_schema_embedding.json", "w", encoding="utf-8") as file: 34 | file.write(json.dumps(prompts, ensure_ascii=False, indent=2)) -------------------------------------------------------------------------------- /Data_Synthesizer/pipeline/sqlite/prompt_templates/cot_synthesis_prompt_template.txt: -------------------------------------------------------------------------------- 1 | You are a senior data analyst specializing in SQL. Your task is to translate a natural language question into an executable {database_backend} query, providing a detailed reasoning trace. 2 | 3 | You will also receive a reference solution from a colleague, which may or may not be correct. This extra information intends to help you generate your answer, but you are asked not to mention the reference solution in any form. 4 | The reference solution might include: 5 | 1. Unnecessary table and column selections. 6 | 2. Incorrect or excessive joins. 7 | 3. Misalignment with the question. 8 | 4. Opportunities for simplification. 9 | 10 | Ensure the SQL query is presented in a Markdown code block with proper syntax highlighting, like this: 11 | ```sql 12 | SELECT * FROM table; 13 | ``` 14 | 15 | [Database Schema]: 16 | {schema} 17 | 18 | [Natural Language Question]: 19 | {question} 20 | 21 | [Reference Solution]: 22 | ```sql 23 | {sql} 24 | ``` 25 | 26 | [Rules of database backend and extension] 27 | {database_note} 28 | 29 | [Function Context]: 30 | The function lembed('{embedding_model_name}', text) is available. It converts text into a vector embedding. 31 | 32 | Provide your step-by-step text-to-SQL solution here. 
33 | -------------------------------------------------------------------------------- /Data_Synthesizer/database_synthesis/generate_schema_enhancement_prompts.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | 4 | random.seed(42) 5 | 6 | if __name__ == '__main__': 7 | prompts = [] 8 | prompt_template = open("./prompt_templates/enhance_prompt.txt", "r", encoding = "utf-8").read() 9 | schema_synthesis_results = json.load(open("./results/schema_synthesis.json")) 10 | 11 | no_res_num = 0 12 | for data in schema_synthesis_results: 13 | try: 14 | if data["generated_content"] == {}: 15 | no_res_num += 1 16 | continue 17 | 18 | domain = data["generated_content"]["domain"] 19 | scenario = data["generated_content"]["scenario"] 20 | schema_str = data["generated_content"]["schema"] 21 | 22 | prompts.append( 23 | prompt_template.format(domain = domain, scenario = scenario, schema = schema_str).strip() 24 | ) 25 | 26 | except Exception as e: 27 | print(e) 28 | 29 | print("no_res_num:", no_res_num) 30 | print("len(prompts):", len(prompts)) 31 | random.shuffle(prompts) 32 | 33 | with open("./prompts/prompts_schema_enhancement.json", "w", encoding="utf-8") as file: 34 | file.write(json.dumps(prompts, ensure_ascii=False, indent=2)) -------------------------------------------------------------------------------- /Evaluation_Framework/prompt_templates/sql_generate_prompt_template.txt: -------------------------------------------------------------------------------- 1 | You are a senior SQL engineer. Your task is to generate a single, correct, and executable SQL query to answer the user's question based on the provided database context. 2 | 3 | ## INSTRUCTIONS 4 | 1. **Backend Adherence**: The query MUST be written for the `{dataset_backend}` database backend. This is a strict requirement. 5 | 2. **Follow Special Notes**: You MUST strictly follow all syntax, functions, or constraints described in the [Database Backend Notes]. Pay extremely close attention to this section, as it contains critical, non-standard rules. 6 | 3. **Schema Integrity**: The query MUST ONLY use the tables and columns provided in the [Database Schema]. Do not invent or guess table or column names. 7 | 4. **Answer the Question**: The query must directly and accurately answer the [Natural Language Question]. 8 | 5. **Output Format**: Enclose the final SQL query in a single Markdown code block formatted for SQL (` ```sql ... ``` `). 9 | 10 | ## DATABASE CONTEXT 11 | 12 | [DATABASE BACKEND]: 13 | {dataset_backend} 14 | 15 | [DATABASE SCHEMA]: 16 | {schema} 17 | 18 | [DATABASE BACKEND NOTES]: 19 | {database_note_prompt} 20 | 21 | [EMBEDDING MODEL NAME]: 22 | {embedding_model_name} 23 | 24 | ## NATURAL LANGUAGE QUESTION 25 | {question} 26 | 27 | Let's think step by step! 
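A minimal sketch of how this template might be rendered before being sent to a model, mirroring what `Evaluation_Framework/script/generate_eval_prompts.py` does; the example schema and question are invented, and the paths are assumed to be relative to the repository root:

```python
# Illustrative only: fill the evaluation prompt template's placeholders.
template = open("Evaluation_Framework/prompt_templates/sql_generate_prompt_template.txt").read()
note = open("Evaluation_Framework/prompt_templates/sqlite_vec_note_prompt.txt").read()

prompt = template.format(
    dataset_backend="sqlite",
    schema="CREATE TABLE papers (id INTEGER PRIMARY KEY, title TEXT, title_embedding BLOB);",  # invented schema
    database_note_prompt=note.format(embedding_model="all-MiniLM-L6-v2"),  # placeholder name, as in generate_eval_prompts.py
    embedding_model_name="all-MiniLM-L6-v2",
    question="Find the five papers whose titles are most similar to 'graph neural networks'.",
)
print(prompt)
```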
-------------------------------------------------------------------------------- /Execution_Engine/engine_config.yaml: -------------------------------------------------------------------------------- 1 | # engine_config.yaml 2 | 3 | # 上一步部署的Embedding Service的API地址 4 | embedding_service: 5 | url: "http://localhost:8000/embed" 6 | # 可以增加超时等配置 7 | # timeout: 10 8 | 9 | # 不同类型数据库的连接模板 10 | # 脚本将使用这些模板,并用运行时传入的数据库名或路径来完成连接 11 | database_connections: 12 | postgresql: 13 | # 'database' (数据库名) 将在运行时通过参数传入 14 | host: "localhost" 15 | port: 5432 16 | user: "postgres" # 请替换为您的用户名 17 | password: "postgres" # 请替换为您的密码 18 | 19 | clickhouse: 20 | # 'database' (数据库名) 将在运行时通过参数传入 21 | host: "localhost" 22 | port: 8123 # HTTP协议端口, 通常是8123 23 | user: "default" 24 | password: "" # 您的ClickHouse密码 25 | 26 | # --- 新增 MyScale 配置块 --- 27 | myscale: 28 | # MyScale 使用与 ClickHouse 相同的 HTTP 协议 (via clickhouse_connect) 29 | # 你可以替换成你的公网IP (例如 "8.140.37.123") 或保持 "localhost" 30 | host: "112.126.57.89" 31 | port: 8123 # HTTP 协议端口 9000 32 | user: "default" 33 | password: "myscale#EDC" # MyScale 的密码 (如果设置了) 34 | # --- 新增结束 --- 35 | 36 | # SQLite不需要配置模板,因为它的连接就是文件路径本身 37 | 38 | # 超时配置(单位:秒) 39 | timeouts: 40 | embedding_service: 10 # Embedding服务调用超时 (1秒太短了, 建议改长一点) 41 | database_connection: 10 # 数据库连接超时 (1秒太短了, 建议改长一点) 42 | sql_execution: 60 # SQL执行超时 43 | total_execution: 120 # 总执行超时 (改成了120秒) 44 | 45 | # 日志配置 46 | logging: 47 | level: "WARNING" 48 | format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s" 49 | -------------------------------------------------------------------------------- /Data_Synthesizer/database_synthesis/README.md: -------------------------------------------------------------------------------- 1 | # Web Table-Driven Database Synthesis 2 | 3 | This is the first step in our data synthesis framework, designed to generate realistic databases using web tables. 4 | 5 | ## Prepare Web Tables 6 | Unzip `web_tables.json.zip` to access 19,935 high-quality web tables from [Tablib](https://arxiv.org/pdf/2310.07875). 7 | 8 | ## Step 1: Initial Database Generation 9 | Generate an initial database from the web tables. 10 | 11 | 1. Run `python3 generate_schema_synthesis_prompts.py` to create prompts for database generation. 12 | 2. Run `python3 synthesize_schema.py` to generate initial database schemas. (Implement the `llm_inference()` function to use your preferred LLMs.) 13 | 14 | ## Step 2: Database Enhancement 15 | Enhance the initially generated databases to increase complexity and realism. 16 | 17 | 1. Run `python3 generate_schema_enhancement_prompts.py` to create prompts for database enhancement. 18 | 2. Run `python3 enhance_schema.py` to generate enhanced database schemas. (Implement the `llm_inference()` function to use your preferred LLMs.) 19 | 20 | 23 | 24 | ## Step 4: Building SQLite Databases 25 | Build SQLite databases based on the enhanced database schemas. 26 | 27 | 1. Run `python3 build_sqlite_databases.py` to construct SQLite databases, which are stored in the `synthesis_data` folder. 28 | 2. (Optional) Run `python3 generate_tables_json.py` to create the `tables.json` file, containing detailed information about the synthetic databases, aligning with previous text-to-SQL datasets. 29 | 3. Run `cp -r synthesis_data ../pipeline/sqlite/train/` to move database to right place. 
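Both synthesis steps above leave `llm_inference()` to the user. A minimal sketch of one way to implement it, assuming an OpenAI-compatible endpoint (the `openai` package is already listed in `Data_Synthesizer/requirements.txt`); the model name and environment variables are placeholders, not part of this repository:

```python
import os
from openai import OpenAI

# Hypothetical llm_inference() hook; adapt the client, model, and error handling to your setup.
client = OpenAI(
    api_key=os.environ["OPENAI_API_KEY"],        # assumed environment variable
    base_url=os.environ.get("OPENAI_BASE_URL"),  # optional: point at a self-hosted gateway
)

def llm_inference(prompts, model="gpt-4o-mini", temperature=0.0):
    """Return one completion string per input prompt."""
    outputs = []
    for prompt in prompts:
        completion = client.chat.completions.create(
            model=model,
            temperature=temperature,
            messages=[{"role": "user", "content": prompt}],
        )
        outputs.append(completion.choices[0].message.content)
    return outputs
```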
30 | -------------------------------------------------------------------------------- /Evaluation_Framework/script/generate_eval_prompts.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | import os 4 | 5 | from tqdm import tqdm 6 | 7 | def generate_sql_prompts(dataset_json_path="../sqlite/results/toy_spider/candidate_sql_query_id.json", tables_json_path="../../Data_Synthesizer/pipeline/sqlite/results/toy_spider/embedding_table_vector.json", prompt_tamplate_path="../prompt_templates/sql_generate_prompt_template.txt", output_prompt_path="../sqlite/prompts/sql_generate_prompts.json",dataset_backend="sqlite",database_note_prompt_path="../prompt_templates/sqlite_vec_note_prompt.txt",embedding_model_name="all-MiniLM-L6-v2"): 8 | dataset_json = json.load(open(dataset_json_path)) 9 | tables_json = json.load(open(tables_json_path)) 10 | print("len(question-vecsql):", len(dataset_json)) 11 | 12 | prompts = [] 13 | db_id2ddls = dict() 14 | for table in tables_json: 15 | db_id2ddls[table["db_id"]] = table["ddls"] 16 | print("len(db_id2ddls):", len(db_id2ddls)) 17 | 18 | database_note_prompt = open(database_note_prompt_path).read().format(embedding_model = embedding_model_name) 19 | prompt_tamplate = open(prompt_tamplate_path).read() 20 | for data in tqdm(dataset_json): 21 | if data["external_knowledge"] != "": 22 | question = data["external_knowledge"] + "\n" + data["question"] 23 | else: 24 | question = data["question"] 25 | 26 | data["sql_synthesis_prompt"] = prompt_tamplate.format( 27 | schema = "\n\n".join(db_id2ddls[data["db_id"]]), 28 | question = question, 29 | dataset_backend =dataset_backend, 30 | database_note_prompt = database_note_prompt 31 | ) 32 | # 创建输出目录 33 | os.makedirs("../sqlite/prompts", exist_ok=True) 34 | with open(output_prompt_path, "w", encoding="utf-8") as f: 35 | f.write(json.dumps(dataset_json, indent=2, ensure_ascii=False)) 36 | 37 | if __name__ == "__main__": 38 | generate_sql_prompts() 39 | -------------------------------------------------------------------------------- /Data_Synthesizer/tools/duplicate.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict 3 | 4 | def find_duplicate_prefix_sqlite_files(root_dir): 5 | """ 6 | 递归遍历指定目录,查找具有相同前缀的 .sqlite 文件。 7 | 8 | Args: 9 | root_dir (str): 要开始搜索的根目录的路径。 10 | 11 | Returns: 12 | dict: 一个字典,其中键是 .sqlite 文件的前缀, 13 | 值是具有该前缀的文件路径列表。 14 | 仅包含具有多个文件的条目。 15 | """ 16 | if not os.path.isdir(root_dir): 17 | print(f"错误:提供的路径 '{root_dir}' 不是一个有效的目录。") 18 | return {} 19 | 20 | sqlite_files_by_prefix = defaultdict(list) 21 | 22 | # 递归遍历目录树 23 | for dirpath, _, filenames in os.walk(root_dir): 24 | for filename in filenames: 25 | if filename.endswith('.sqlite'): 26 | # 获取文件名(不包括.sqlite扩展名)作为前缀 27 | prefix = filename[:-7] # ".sqlite" 的长度是 7 28 | full_path = os.path.join(dirpath, filename) 29 | sqlite_files_by_prefix[prefix].append(full_path) 30 | 31 | # 筛选出具有多个文件(即重复前缀)的条目 32 | duplicate_files = { 33 | prefix: paths for prefix, paths in sqlite_files_by_prefix.items() if len(paths) > 1 34 | } 35 | 36 | return duplicate_files 37 | 38 | if __name__ == '__main__': 39 | # --- 使用说明 --- 40 | # 1. 
将下面的 'your_target_directory' 替换为您要搜索的实际目录路径。 41 | # 例如: '/home/user/documents' 或 'C:\\Users\\User\\Documents' 42 | target_directory = '/mnt/b_public/data/ydw/Text2VectorSQL/Data_Synthesizer/pipeline/sqlite/results' 43 | 44 | # 检查并查找具有相同前缀的.sqlite文件 45 | duplicates = find_duplicate_prefix_sqlite_files(target_directory) 46 | 47 | if not duplicates: 48 | print(f"在目录 '{target_directory}' 及其子目录中没有找到具有相同前缀的 .sqlite 文件。") 49 | else: 50 | print(f"在 '{target_directory}' 中找到了以下具有相同前缀的 .sqlite 文件:\n") 51 | for prefix, paths in duplicates.items(): 52 | print(f"前缀: '{prefix}.sqlite'") 53 | for path in paths: 54 | print(f" - {path}") 55 | print("-" * 20) -------------------------------------------------------------------------------- /Data_Synthesizer/synthesis_eval/generate_eval_prompts.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | import os 4 | 5 | from tqdm import tqdm 6 | 7 | def generate_sql_prompts(dataset_json_path="../pipeline/sqlite/results/toy_spider/candidate_sql.json", tables_json_path="../pipeline/sqlite/results/toy_spider/embedding_table_vector.json", prompt_tamplate_path="../pipeline/sqlite/prompt_templates/sql_generate_prompt_template.txt", output_prompt_path="../pipeline/sqlite/prompts/sql_generate_prompts.json",dataset_backend="sqlite",database_note_prompt_path="../prompt_templates/sqlite_vec_note_prompt.txt",embedding_model_name="all-MiniLM-L6-v2"): 8 | dataset_json = json.load(open(dataset_json_path)) 9 | tables_json = json.load(open(tables_json_path)) 10 | print("len(question-vecsql):", len(dataset_json)) 11 | 12 | prompts = [] 13 | db_id2ddls = dict() 14 | for table in tables_json: 15 | db_id2ddls[table["db_id"]] = table["ddls"] 16 | print("len(db_id2ddls):", len(db_id2ddls)) 17 | 18 | database_note_prompt = open(database_note_prompt_path).read().format(embedding_model = embedding_model_name) 19 | prompt_tamplate = open(prompt_tamplate_path).read() 20 | for data in tqdm(dataset_json): 21 | if data["external_knowledge"] != "": 22 | question = data["external_knowledge"] + "\n" + data["question"] 23 | else: 24 | question = data["question"] 25 | 26 | data["sql_synthesis_prompt"] = prompt_tamplate.format( 27 | schema = "\n\n".join(db_id2ddls[data["db_id"]]), 28 | question = question, 29 | dataset_backend =dataset_backend, 30 | database_note_prompt = database_note_prompt 31 | ) 32 | 33 | 34 | # 创建输出目录 35 | os.makedirs("../pipeline/sqlite/prompts", exist_ok=True) 36 | with open(output_prompt_path, "w", encoding="utf-8") as f: 37 | f.write(json.dumps(dataset_json, indent=2, ensure_ascii=False)) 38 | 39 | if __name__ == "__main__": 40 | generate_sql_prompts() 41 | -------------------------------------------------------------------------------- /Data_Synthesizer/collect_input_llm.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | 4 | # 1. 定义您提供的三个JSON文件的路径列表 5 | file_paths = [ 6 | '/mnt/b_public/data/ydw/Text2VectorSQL/Data_Synthesizer/pipeline/sqlite/results/synthesis_data/input_llm.json', 7 | '/mnt/b_public/data/ydw/Text2VectorSQL/Data_Synthesizer/pipeline/postgresql/results/synthesis_data/input_llm.json', 8 | '/mnt/b_public/data/ydw/Text2VectorSQL/Data_Synthesizer/pipeline/clickhouse/results/synthesis_data/input_llm.json' 9 | ] 10 | 11 | output_file_path = 'collected_input_llm.json' 12 | 13 | 14 | # 2. 初始化一个空列表,用于存放所有合并后的数据 15 | combined_list = [] 16 | 17 | print("开始处理文件...") 18 | 19 | # 3. 
遍历文件路径列表,读取并合并数据 20 | for file_path in file_paths: 21 | try: 22 | with open(file_path, 'r', encoding='utf-8') as f: 23 | # 从文件中加载JSON数据(每个文件都是一个列表) 24 | data = json.load(f) 25 | # 使用 extend 方法将当前文件中的列表元素添加到主列表中 26 | combined_list.extend(data) 27 | print(f"成功从 {file_path} 加载了 {len(data)} 条数据。") 28 | except FileNotFoundError: 29 | print(f"错误:文件未找到 {file_path}") 30 | except json.JSONDecodeError: 31 | print(f"错误:无法解析文件中的JSON内容 {file_path}") 32 | except Exception as e: 33 | print(f"处理文件 {file_path} 时发生未知错误: {e}") 34 | 35 | # 打印合并后的总数据量 36 | print(f"\n数据合并完成。总共聚合了 {len(combined_list)} 条数据。") 37 | 38 | # 4. 对合并后的列表进行随机排序 (in-place shuffle) 39 | print("正在对聚合列表进行随机排序...") 40 | random.shuffle(combined_list) 41 | print("列表已成功打乱顺序。") 42 | 43 | # (可选) 验证一下,打印前5个元素看看效果 44 | # print("\n打乱顺序后列表的前5个元素示例:") 45 | # for item in combined_list[:5]: 46 | # print(item) 47 | 48 | # (可选) 5. 将最终的列表写入一个新的JSON文件 49 | print(f"\n准备将结果保存到文件: {output_file_path}") 50 | try: 51 | with open(output_file_path, 'w', encoding='utf-8') as f: 52 | # 使用 json.dump 将列表写入文件 53 | # ensure_ascii=False 确保中文字符等能正确显示 54 | # indent=4 让JSON文件格式化,更易读 55 | json.dump(combined_list, f, ensure_ascii=False, indent=4) 56 | print(f"数据已成功保存到 {output_file_path}") 57 | except Exception as e: 58 | print(f"保存文件时发生错误: {e}") -------------------------------------------------------------------------------- /Data_Synthesizer/database_synthesis/prompt_templates/embedding_prompt.txt: -------------------------------------------------------------------------------- 1 | **Task Overview:** 2 | As a senior data analyst, your task is to enhance an initial database schema to incorporate embedding columns for semantically rich fields based on a given business scenario while preserving all original data structure and sample values. 3 | 4 | **Strict Requirements:** 5 | 1. **Schema Preservation:** 6 | - Maintain ALL original columns unless you have valid reason 7 | - Do not delete existing sample_rows values just modify them 8 | - Keep original column order unless adding new embedding columns 9 | 10 | 2. **Embedding Column Addition:** 11 | - Only add new columns with suffix `_embedding` (type: BLOB) 12 | - Select columns based on semantic richness (text, descriptions, content) 13 | - Never modify existing column names/types/values 14 | 15 | 3. **Sample Data Handling:** 16 | - For new embedding columns in sample_rows: 17 | * Use explicit null values (`null`) 18 | * Maintain original structure 19 | * Preserve all existing key-value pairs 20 | 21 | **Steps:** 22 | 1. **Analyze Semantic Columns:** Identify existing columns with rich semantic information (text descriptions, reviews, content, names, etc.) 23 | 2. **Generate Embedding Columns:** For each identified column: 24 | - Create a new column with name `[original_col]_embedding` 25 | - Set type as `BLOB` (binary storage for vector embeddings) 26 | - Keep sample values as empty (`null` or `""`) 27 | 3. **Preserve Relationships:** Maintain existing primary/foreign key relationships 28 | 4. **Document Reasoning:** Briefly explain why each embedding column was added 29 | 30 | **Business Domain:** 31 | {domain} 32 | 33 | **Business Scenario:** 34 | {scenario} 35 | 36 | **Initial Database Schema:** 37 | ```json 38 | {schema} 39 | ``` 40 | 41 | **Output Format:** 42 | Your output should provide the enriched database schema with embedding columns in JSON format: 43 | ```json 44 | -- enriched database schema with embedding 45 | ``` 46 | 47 | Let's think step by step. 
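For illustration only, a self-contained sketch of how a `*_embedding` BLOB column added by this prompt might later be filled in SQLite; the table and its rows are invented, `all-MiniLM-L6-v2` is the embedding model used elsewhere in this repository, and vectors are packed as little-endian float32 bytes, the layout sqlite-vec works with:

```python
import sqlite3
import struct
from sentence_transformers import SentenceTransformer

def to_blob(vector):
    # Serialize a vector as little-endian float32 bytes for a BLOB column.
    return struct.pack(f"<{len(vector)}f", *vector)

model = SentenceTransformer("all-MiniLM-L6-v2")

conn = sqlite3.connect(":memory:")  # stand-in for a synthesized *.sqlite file
conn.execute("CREATE TABLE products (product_id INTEGER PRIMARY KEY, description TEXT, description_embedding BLOB)")
conn.execute("INSERT INTO products (description) VALUES ('Waterproof hiking boots'), ('Stainless steel water bottle')")

for product_id, text in conn.execute("SELECT product_id, description FROM products").fetchall():
    conn.execute(
        "UPDATE products SET description_embedding = ? WHERE product_id = ?",
        (to_blob(model.encode(text)), product_id),
    )
conn.commit()
```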
48 | -------------------------------------------------------------------------------- /Data_Synthesizer/tools/collect_input_llm.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | 4 | # 1. 定义您提供的三个JSON文件的路径列表 5 | file_paths = [ 6 | '/mnt/DataFlow/ydw/Text2VectorSQL/Data_Synthesizer/pipeline/sqlite/results/synthesis_data/input_llm.json', 7 | '/mnt/DataFlow/ydw/Text2VectorSQL/Data_Synthesizer/pipeline/postgresql/results/synthesis_data/input_llm.json', 8 | '/mnt/DataFlow/ydw/Text2VectorSQL/Data_Synthesizer/pipeline/myscale/results/synthesis_data/input_llm.json', 9 | '/mnt/DataFlow/ydw/Text2VectorSQL/Data_Synthesizer/pipeline/clickhouse/results/synthesis_data/input_llm.json' 10 | ] 11 | 12 | output_file_path = 'results/collected_input_llm.json' 13 | 14 | 15 | # 2. 初始化一个空列表,用于存放所有合并后的数据 16 | combined_list = [] 17 | 18 | print("开始处理文件...") 19 | 20 | # 3. 遍历文件路径列表,读取并合并数据 21 | for file_path in file_paths: 22 | try: 23 | with open(file_path, 'r', encoding='utf-8') as f: 24 | # 从文件中加载JSON数据(每个文件都是一个列表) 25 | data = json.load(f) 26 | # 使用 extend 方法将当前文件中的列表元素添加到主列表中 27 | combined_list.extend(data) 28 | print(f"成功从 {file_path} 加载了 {len(data)} 条数据。") 29 | except FileNotFoundError: 30 | print(f"错误:文件未找到 {file_path}") 31 | except json.JSONDecodeError: 32 | print(f"错误:无法解析文件中的JSON内容 {file_path}") 33 | except Exception as e: 34 | print(f"处理文件 {file_path} 时发生未知错误: {e}") 35 | 36 | # 打印合并后的总数据量 37 | print(f"\n数据合并完成。总共聚合了 {len(combined_list)} 条数据。") 38 | 39 | # 4. 对合并后的列表进行随机排序 (in-place shuffle) 40 | print("正在对聚合列表进行随机排序...") 41 | random.shuffle(combined_list) 42 | print("列表已成功打乱顺序。") 43 | 44 | # (可选) 验证一下,打印前5个元素看看效果 45 | # print("\n打乱顺序后列表的前5个元素示例:") 46 | # for item in combined_list[:5]: 47 | # print(item) 48 | 49 | # (可选) 5. 将最终的列表写入一个新的JSON文件 50 | print(f"\n准备将结果保存到文件: {output_file_path}") 51 | try: 52 | with open(output_file_path, 'w', encoding='utf-8') as f: 53 | # 使用 json.dump 将列表写入文件 54 | # ensure_ascii=False 确保中文字符等能正确显示 55 | # indent=4 让JSON文件格式化,更易读 56 | json.dump(combined_list, f, ensure_ascii=False, indent=4) 57 | print(f"数据已成功保存到 {output_file_path}") 58 | except Exception as e: 59 | print(f"保存文件时发生错误: {e}") 60 | -------------------------------------------------------------------------------- /Evaluation_Framework/script/generate_query_id.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | import os 4 | 5 | def add_query_ids_to_json(input_path: str, output_path: str): 6 | """ 7 | 读取一个包含字典列表的 JSON 文件,为每个字典添加一个唯一的 'query_id', 8 | 然后将结果写入新的 JSON 文件。 9 | 10 | Args: 11 | input_path (str): 输入的 JSON 文件路径。 12 | output_path (str): 输出的 JSON 文件路径。 13 | """ 14 | try: 15 | # 1. 读取输入的 JSON 文件 16 | # 使用 'r' 模式和 utf-8 编码确保能正确处理中文字符 17 | with open(input_path, 'r', encoding='utf-8') as f: 18 | data = json.load(f) 19 | 20 | # 2. 检查数据格式是否为列表 21 | if not isinstance(data, list): 22 | print(f"错误: JSON 文件的顶层结构不是一个列表/数组。文件路径: {input_path}") 23 | return 24 | 25 | # 3. 遍历列表中的每个元素(字典)并添加字段 26 | # 使用 enumerate(..., start=1) 可以同时获得索引和元素,索引从 1 开始 27 | for index, item in enumerate(data, start=1): 28 | # 确保列表中的元素是字典,以防数据格式混淆 29 | if isinstance(item, dict): 30 | # 构造 query_id, 例如 "q1", "q2", ... 31 | query_id = f"q{index}" 32 | # 为字典添加新的键值对 33 | item['query_id'] = query_id 34 | else: 35 | print(f"警告: 在索引 {index-1} 处找到一个非字典类型的元素,已跳过。") 36 | 37 | # 4. 
将修改后的数据写入新的 JSON 文件 38 | # 使用 'w' 模式写入 39 | # indent=4 让输出的 JSON 文件格式化,更易于阅读 40 | # ensure_ascii=False 确保中文字符能正常显示,而不是被转义成 \uXXXX 41 | with open(output_path, 'w', encoding='utf-8') as f: 42 | json.dump(data, f, indent=4, ensure_ascii=False) 43 | 44 | print(f"处理完成!已成功为 {len(data)} 个元素添加 query_id。") 45 | print(f"结果已保存至: {output_path}") 46 | 47 | except FileNotFoundError: 48 | print(f"错误: 找不到输入文件。请检查路径是否正确: {input_path}") 49 | except json.JSONDecodeError: 50 | print(f"错误: 文件内容不是有效的 JSON 格式。文件路径: {input_path}") 51 | except Exception as e: 52 | print(f"发生未知错误: {e}") 53 | 54 | # --- 使用示例 --- 55 | if __name__ == "__main__": 56 | # 定义输入和输出文件名 57 | input_file = "../data/candidate_sql.json" 58 | output_file = "../data/candidate_sql_query_id.json" 59 | 60 | # 调用函数进行处理 61 | add_query_ids_to_json(input_file, output_file) 62 | -------------------------------------------------------------------------------- /Data_Synthesizer/tools/add_candidate_prefix.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | 4 | def add_prefix_to_db_id(input_file, output_file, prefix="deverse_2_"): 5 | """ 6 | 读取一个JSON文件,为文件中字典数组的每个'db_id'字段添加前缀, 7 | 并保存到新的文件中。 8 | 9 | Args: 10 | input_file (str): 输入的JSON文件名。 11 | output_file (str): 输出的JSON文件名。 12 | prefix (str): 要添加的前缀。 13 | """ 14 | # --- 1. 读取并解析JSON文件 --- 15 | try: 16 | with open(input_file, 'r', encoding='utf-8') as f: 17 | data = json.load(f) 18 | print(f"✅ 成功读取文件: '{input_file}'") 19 | except FileNotFoundError: 20 | print(f"❌ 错误: 输入文件 '{input_file}' 未找到。请检查文件名和路径。") 21 | sys.exit(1) # 退出脚本 22 | except json.JSONDecodeError: 23 | print(f"❌ 错误: 文件 '{input_file}' 不是有效的JSON格式。") 24 | sys.exit(1) 25 | 26 | # 检查数据是否为列表 27 | if not isinstance(data, list): 28 | print(f"❌ 错误: JSON文件的顶层结构不是一个数组(列表)。") 29 | sys.exit(1) 30 | 31 | # --- 2. 遍历并修改数据 --- 32 | modified_count = 0 33 | for item in data: 34 | # 确保元素是字典并且包含 'db_id' 键 35 | if isinstance(item, dict) and 'db_id' in item: 36 | original_id = item['db_id'] 37 | item['db_id'] = prefix + original_id 38 | modified_count += 1 39 | # print(f" - 已修改: '{original_id}' -> '{item['db_id']}'") # 如果需要详细日志可以取消此行注释 40 | 41 | print(f"🔄 已处理 {len(data)} 个元素,其中 {modified_count} 个元素的 'db_id' 被修改。") 42 | 43 | # --- 3. 将修改后的数据写入新文件 --- 44 | try: 45 | with open(output_file, 'w', encoding='utf-8') as f: 46 | # indent=2 使输出的JSON文件格式化,更易读 47 | # ensure_ascii=False 确保中文字符等能被正确写入 48 | json.dump(data, f, indent=2, ensure_ascii=False) 49 | print(f"✅ 操作完成!结果已保存到: '{output_file}'") 50 | except IOError as e: 51 | print(f"❌ 错误: 无法写入到文件 '{output_file}'。") 52 | print(f"详细信息: {e}") 53 | sys.exit(1) 54 | 55 | 56 | if __name__ == "__main__": 57 | # --- 请在这里配置你的文件名 --- 58 | input_filename = "cot_synthesis_old.json" # <--- 你的原始JSON文件名 59 | output_filename = "cot_synthesis.json" # <--- 你希望保存的新文件名 60 | 61 | # 运行主函数 62 | add_prefix_to_db_id(input_filename, output_filename) 63 | 64 | -------------------------------------------------------------------------------- /Data_Synthesizer/pipeline/sqlite/prompt_templates/sql_generate_prompt_template.txt: -------------------------------------------------------------------------------- 1 | You are a senior SQL engineer. Your task is to generate a single, correct, and executable SQL query to answer the user's question based on the provided database context. 2 | 3 | ## INSTRUCTIONS 4 | 1. **Backend Adherence**: The query MUST be written for the `{dataset_backend}` database backend. This is a strict requirement. 5 | 2. 
**Follow Special Notes**: You MUST strictly follow all syntax, functions, or constraints described in the [Database Backend Notes]. Pay extremely close attention to this section, as it contains critical, non-standard rules. 6 | 3. **Schema Integrity**: The query MUST ONLY use the tables and columns provided in the [Database Schema]. Do not invent or guess table or column names. 7 | 4. **Answer the Question**: The query must directly and accurately answer the [Natural Language Question]. 8 | 5. **Output Format**: Enclose the final SQL query in a single Markdown code block formatted for SQL (` ```sql ... ``` `). 9 | 6. **Embedding Match**: If the [EMBEDDING_MODEL_NAME] parameter is a valid string (e.g., 'all-MiniLM-L6-v2'), you MUST generate a query that includes the WHERE [EMBEDDING_COLUMN_NAME] MATCH lembed(...) clause for vector search. Otherwise, if embedding model name below the [EMBEDDING MODEL NAME] is None, , you MUST generate a standard SQL query that OMITS the entire MATCH lembed(...) clause. The query should not perform any vector search. 10 | 7. **Embedding Name**: If a value is provided for the parameter `[EMBEDDING_MODEL_NAME]`, your generated query must contain a `lembed` function call. The first parameter to the `lembed` function MUST be the exact value of `[EMBEDDING_MODEL_NAME]`, formatted as a string literal (enclosed in single quotes). For example, if `[EMBEDDING_MODEL_NAME]` is `laion/CLIP-ViT-B-32-laion2B-s34B-b79K`, the generated SQL must include `MATCH lembed('laion/CLIP-ViT-B-32-laion2B-s34B-b79K', ...)`. 11 | 12 | ## DATABASE CONTEXT 13 | 14 | [DATABASE BACKEND]: 15 | {dataset_backend} 16 | 17 | [DATABASE SCHEMA]: 18 | {schema} 19 | 20 | [DATABASE BACKEND NOTES]: 21 | {database_note_prompt} 22 | 23 | [EMBEDDING MODEL NAME]: 24 | {embedding_model_name} 25 | 26 | ## NATURAL LANGUAGE QUESTION 27 | {question} 28 | 29 | Let's think step by step! 
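To make rules 6-7 above concrete, a hedged sketch of the query shape they describe, executed from Python; the table and column names are invented, `sqlite_vec.load()` is the documented loader for the sqlite-vec extension, and registering a model with sqlite-lembed so that `lembed()` resolves is environment-specific and assumed to have been done separately:

```python
import sqlite3
import sqlite_vec  # Python package for the sqlite-vec extension

conn = sqlite3.connect("example.sqlite")  # hypothetical synthesized database
conn.enable_load_extension(True)
sqlite_vec.load(conn)
# NOTE: sqlite-lembed must also be loaded and 'all-MiniLM-L6-v2' registered with it,
# otherwise the lembed() call below will not resolve.

query = """
SELECT p.title, v.distance
FROM vec_papers AS v
JOIN papers AS p ON p.rowid = v.rowid
WHERE v.title_embedding MATCH lembed('all-MiniLM-L6-v2', 'graph neural networks')
  AND k = 5
ORDER BY v.distance;
"""
rows = conn.execute(query).fetchall()
```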
30 | -------------------------------------------------------------------------------- /Embedding_Service/multi_client.py: -------------------------------------------------------------------------------- 1 | # client_example.py 2 | import requests 3 | import base64 4 | from PIL import Image 5 | import io 6 | 7 | API_URL = "http://127.0.0.1:8000/embed" 8 | 9 | def get_text_embedding(text: str, model: str): 10 | """Gets embedding for a single text.""" 11 | payload = { 12 | "model": model, 13 | "texts": [text] 14 | } 15 | response = requests.post(API_URL, json=payload) 16 | response.raise_for_status() 17 | return response.json()['embeddings'][0] 18 | 19 | def get_image_embedding(image_path: str, model: str): 20 | """Gets embedding for a single image file.""" 21 | # Read image and convert to Base64 22 | with open(image_path, "rb") as f: 23 | image_bytes = f.read() 24 | b64_string = base64.b64encode(image_bytes).decode("utf-8") 25 | 26 | payload = { 27 | "model": model, 28 | "images": [b64_string] 29 | } 30 | response = requests.post(API_URL, json=payload) 31 | response.raise_for_status() 32 | return response.json()['embeddings'][0] 33 | 34 | if __name__ == "__main__": 35 | # Create a dummy image for testing 36 | try: 37 | img = Image.new('RGB', (60, 30), color = 'red') 38 | # img.save('test_image.png') 39 | print("Created 'test_image.png' for testing.") 40 | 41 | # --- Test Text Embedding --- 42 | print("\n--- Testing Text Embedding ---") 43 | text_emb = get_text_embedding("A photo of a red square", model="sentence-transformers/clip-ViT-B-32") 44 | print(f"Model: clip-vit-base") 45 | print(f"Text: 'A photo of a red square'") 46 | print(f"Embedding vector (first 5 dims): {text_emb[:5]}") 47 | print(f"Embedding dimension: {len(text_emb)}") 48 | 49 | # --- Test Image Embedding --- 50 | print("\n--- Testing Image Embedding ---") 51 | image_emb = get_image_embedding("test_image.png", model="sentence-transformers/clip-ViT-B-32") 52 | print(f"Model: clip-vit-base") 53 | print(f"Image: 'test_image.png'") 54 | print(f"Embedding vector (first 5 dims): {image_emb[:5]}") 55 | print(f"Embedding dimension: {len(image_emb)}") 56 | 57 | except requests.exceptions.ConnectionError as e: 58 | print(f"\nCould not connect to the server at {API_URL}.") 59 | print("Please ensure the embedding_server.py is running.") 60 | except Exception as e: 61 | print(f"An error occurred: {e}") -------------------------------------------------------------------------------- /Data_Synthesizer/synthesis_cot/generate_cot_synthesis_prompts.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | import os 4 | from tqdm import tqdm 5 | 6 | def remove_sql_comments(sql): 7 | # Remove single-line comments 8 | sql = re.sub(r'--.*', '', sql) 9 | # Remove multi-line comments 10 | sql = re.sub(r'/\*.*?\*/', '', sql, flags=re.DOTALL) 11 | return sql.strip() 12 | 13 | def generate_cot_prompts(dataset_json_path="./results/question_and_sql_pairs.json", tables_json_path="sqlite/results/enhanced_embedding_table_vector.json", prompt_tamplate_path="sqlite/prompt_templates/cot_synthesis_prompt_template.txt", database_note_template="./sqlite/prompt_templates/sqlite_vec_note_prompt.txt", output_prompt_path="sqlite/prompts/cot_synthesis_prompts.json", database_backend="sqlite", embedding_model_name="all-MiniLM-L6-v2"): 14 | dataset_json = json.load(open(dataset_json_path)) 15 | print("len(question-sql):", len(dataset_json)) 16 | 17 | if os.path.exists(tables_json_path): 18 | tables_json = 
json.load(open(tables_json_path)) 19 | db_id2ddls = dict() 20 | for table in tables_json: 21 | db_id2ddls[table["db_id"]] = table["ddls"] 22 | print("len(db_id2ddls):", len(db_id2ddls)) 23 | else: 24 | assert "schema" in dataset_json[0], "When tables_json_path not exists, the schema should be in dataset_json" 25 | 26 | prompt_tamplate = open(prompt_tamplate_path).read() 27 | for data in tqdm(dataset_json): 28 | if data["external_knowledge"] != "": 29 | question = data["external_knowledge"] + "\n" + data["question"] 30 | else: 31 | question = data["question"] 32 | 33 | if os.path.exists(tables_json_path): 34 | schema = "\n\n".join(db_id2ddls[data["db_id"]]) 35 | else: 36 | schema = data["schema"] 37 | database_note = open(database_note_template).read() 38 | data["cot_synthesis_prompt"] = prompt_tamplate.format( 39 | database_backend = database_backend, 40 | schema = schema, 41 | question = question, 42 | sql = remove_sql_comments(data["sql"]), 43 | embedding_model_name = embedding_model_name, 44 | database_note = database_note 45 | ) 46 | # 创建输出目录 47 | os.makedirs("sqlite/prompts", exist_ok=True) 48 | with open(output_prompt_path, "w", encoding="utf-8") as f: 49 | f.write(json.dumps(dataset_json, indent=2, ensure_ascii=False)) 50 | 51 | if __name__ == "__main__": 52 | generate_cot_prompts() 53 | -------------------------------------------------------------------------------- /Embedding_Service/README.md: -------------------------------------------------------------------------------- 1 | # 嵌入服务 (Embedding Service) 2 | 3 | 嵌入服务提供一个高性能、支持多模型、多GPU的文本和图像向量化API服务。它基于FastAPI和Sentence-Transformers构建,能够自动管理模型下载与缓存。 4 | 5 | ## 主要功能 6 | 7 | - **多模型支持**: 可通过`config.yaml`配置文件同时加载和管理多个不同的向量化模型。 8 | - **高性能**: 9 | - 基于FastAPI和Uvicorn,提供异步处理能力。 10 | - 支持通过`tensor_parallel_size`配置为单个模型启动多进程池,充分利用多GPU资源进行张量并行计算。 11 | - **自动模型缓存**: 12 | - 服务启动时,会自动检查`config.yaml`中指定的本地模型路径。 13 | - 如果模型不存在,服务会从Hugging Face Hub自动下载并保存到指定路径,后续启动将直接从本地加载,避免重复下载。 14 | - **统一的API接口**: 15 | - `/embed`: 核心接口,接收文本或图像数据,返回对应的向量表示。支持文本和图像两种输入模式。 16 | - `/health`: 健康检查接口,返回服务运行状态和已加载的模型列表。 17 | - **客户端示例**: 提供`multi_client.py`作为示例,演示如何请求API来获取文本和图像的向量。 18 | 19 | ## 文件结构 20 | 21 | - `server.py`: 核心服务文件。实现了FastAPI应用,负责模型加载、多进程池管理和API请求处理。 22 | - `multi_server.py`: `server.py`的多模态版本,支持同时处理文本和图像嵌入请求(弃用)。 23 | - `multi_client.py`: 用于测试图片嵌入服务的客户端示例代码。 24 | - `run.sh`: 启动服务的便捷脚本。 25 | - `config.yaml`(需自行创建): 服务和模型的配置文件。 26 | 27 | ## 环境依赖 28 | 29 | 所有依赖项都已在`requirements.txt`中列出。 30 | 31 | ## 快速开始 32 | 33 | 1. **安装依赖**: 34 | ```bash 35 | pip install -r requirements.txt 36 | ``` 37 | 38 | 2. **创建配置文件**: 39 | 在`Embedding_Service`目录下创建一个名为`config.yaml`的文件,并参考以下示例填入内容。 40 | 41 | ```yaml 42 | # config.yaml 示例 43 | server: 44 | host: "0.0.0.0" 45 | port: 8000 46 | 47 | models: 48 | # 文本模型示例 49 | - name: "all-MiniLM-L6-v2" 50 | hf_model_path: "sentence-transformers/all-MiniLM-L6-v2" 51 | local_model_path: "./models/all-MiniLM-L6-v2" # 本地缓存路径 52 | trust_remote_code: true 53 | max_model_len: 512 54 | 55 | # 多模态模型示例 (CLIP) 56 | - name: "laion/CLIP-ViT-B-32-laion2B-s34B-b79K" 57 | hf_model_path: "laion/CLIP-ViT-B-32-laion2B-s34B-b79K" 58 | local_model_path: "./models/CLIP-ViT-B-32-laion2B-s34B-b79K" 59 | trust_remote_code: true 60 | tensor_parallel_size: 2 # 使用2个GPU 61 | ``` 62 | **注意**: 请确保`local_model_path`指向的目录存在或有权限创建。 63 | 64 | 3. **启动服务**: 65 | ```bash 66 | bash run.sh 67 | ``` 68 | 服务启动后,会首先检查并准备模型。首次运行会因下载模型而耗时较长。 69 | 70 | 4. 
**测试API**: 71 | - **健康检查**: 72 | ```bash 73 | curl http://localhost:8000/health 74 | ``` 75 | - **获取文本向量**: 76 | ```bash 77 | curl -X POST http://localhost:8000/embed \ 78 | -H "Content-Type: application/json" \ 79 | -d 80 | { 81 | "model": "all-MiniLM-L6-v2", 82 | "texts": [ 83 | "Hello World!", 84 | "Machine Learning" 85 | ] 86 | } 87 | ``` 88 | - **图片嵌入测试**: 89 | ```bash 90 | python multi_client.py 91 | ``` 92 | -------------------------------------------------------------------------------- /Data_Synthesizer/tools/convert_db_id.py: -------------------------------------------------------------------------------- 1 | import ijson 2 | import json 3 | from tqdm import tqdm 4 | 5 | def process_large_json(input_path, output_path): 6 | """ 7 | 流式读取一个大型JSON文件,修改每个对象的 'db_id' 字段, 8 | 然后将修改后的对象流式写入一个新的JSON文件。 9 | 10 | Args: 11 | input_path (str): 输入的大型JSON文件路径。 12 | output_path (str): 输出的JSON文件路径。 13 | """ 14 | # 路径的前缀 15 | prefix = "/mnt/b_public/data/ydw/Text2VectorSQL/Data_Synthesizer/pipeline/sqlite/results/synthesis_data_deversity/vector_databases/" 16 | 17 | # 以二进制模式读取输入文件,以文本模式写入输出文件 18 | with open(input_path, 'rb') as f_in, open(output_path, 'w', encoding='utf-8') as f_out: 19 | 20 | # ijson.items 会返回一个迭代器,逐个产出文件根数组('item')中的对象 21 | # 这样可以避免将整个文件加载到内存中 22 | json_objects = ijson.items(f_in, 'item') 23 | 24 | is_first_item = True 25 | 26 | # 手动开始写入JSON数组 27 | f_out.write('[') 28 | 29 | print(f"开始处理文件 {input_path}...") 30 | 31 | # 使用tqdm来显示处理进度条 32 | for item in tqdm(json_objects, desc="处理进度"): 33 | # 如果不是第一个元素,就在前面加上逗号,以符合JSON数组的格式 34 | if not is_first_item: 35 | f_out.write(',') 36 | 37 | # 获取原始的 "db_id" 38 | db_id_name = item.get("db_id") 39 | 40 | # 如果存在 "db_id" 字段 41 | if db_id_name and isinstance(db_id_name, str): 42 | # 拼接新的文件路径 43 | # 示例: prefix + db_id_name + / + db_id_name + .sqlite 44 | new_db_id_path = f"{prefix}{db_id_name}/{db_id_name}.sqlite" 45 | 46 | # 更新字典中的值 47 | item["db_id"] = new_db_id_path 48 | 49 | # 将处理过的单个Python字典转换成JSON字符串并写入输出文件 50 | # ensure_ascii=False 保证中文字符能被正确写入 51 | json.dump(item, f_out, ensure_ascii=False) 52 | 53 | is_first_item = False 54 | 55 | # 手动结束JSON数组 56 | f_out.write(']') 57 | 58 | print(f"\n处理完成!结果已保存至 {output_path}") 59 | 60 | # ============================================================================== 61 | # =================== 在这里修改您的输入和输出文件名 =================== 62 | # ============================================================================== 63 | 64 | # 1. 设置您的原始大JSON文件的名字或路径 65 | input_filename = "embedding_table_vector_old.json" 66 | 67 | # 2. 设置您希望保存结果的新文件名 68 | output_filename = "embedding_table_vector.json" 69 | 70 | # ============================================================================== 71 | 72 | # 执行主函数 73 | if __name__ == "__main__": 74 | process_large_json(input_filename, output_filename) 75 | -------------------------------------------------------------------------------- /Data_Synthesizer/tools/batch_copy_files.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | # --- 配置区域 --- 5 | 6 | # 1. 定义需要处理的数据库列表 7 | DATABASES = [ 8 | "arxiv", 9 | "bird", 10 | "spider", 11 | "synthesis_data", 12 | "wikipedia_multimodal" 13 | ] 14 | 15 | # 2. 定义后缀到目标数据库类型的映射关系 16 | # 'pg' -> 'postgresql' 17 | # 'ch' -> 'clickhouse' 18 | SUFFIX_MAP = { 19 | "pg": "postgresql", 20 | "ch": "clickhouse" 21 | } 22 | 23 | # 3. 
定义基础的源目录和目标目录 24 | BASE_SOURCE_DIR = "/mnt/b_public/data/ydw/Text2VectorSQL/Data_Synthesizer/pipeline/sqlite/results" 25 | BASE_DEST_DIR = "/mnt/b_public/data/ydw/Text2VectorSQL/Data_Synthesizer/pipeline" 26 | 27 | # --- 脚本主逻辑 --- 28 | 29 | def run_batch_copy(): 30 | """ 31 | 根据配置,批量复制文件到指定位置。 32 | """ 33 | print("🚀 开始执行批量文件复制任务...") 34 | 35 | copied_count = 0 36 | skipped_count = 0 37 | 38 | # 遍历每一个数据库 39 | for db_name in DATABASES: 40 | # 遍历每一种后缀 ('pg' 和 'ch') 41 | for suffix, db_type in SUFFIX_MAP.items(): 42 | 43 | # --- 步骤 1: 构建源文件路径 --- 44 | source_filename = f"input_llm_{suffix}.json" 45 | source_path = os.path.join(BASE_SOURCE_DIR, db_name, source_filename) 46 | 47 | # --- 步骤 2: 构建目标文件路径 --- 48 | # 目标目录,例如:.../pipeline/postgresql/results/arxiv/ 49 | dest_dir = os.path.join(BASE_DEST_DIR, db_type, "results", db_name) 50 | # 目标文件,统一命名为 input_llm.json 51 | dest_path = os.path.join(dest_dir, "input_llm.json") 52 | 53 | # --- 步骤 3: 检查源文件是否存在 --- 54 | if os.path.exists(source_path): 55 | try: 56 | # --- 步骤 4: 确保目标目录存在,如果不存在则创建 --- 57 | os.makedirs(dest_dir, exist_ok=True) 58 | 59 | # --- 步骤 5: 执行文件复制操作 --- 60 | # 使用 shutil.copy2 可以同时复制文件内容和元数据(如修改时间) 61 | shutil.copy2(source_path, dest_path) 62 | print(f"✅ 复制成功: \n - 源: {source_path}\n - 至: {dest_path}\n") 63 | copied_count += 1 64 | 65 | except (OSError, shutil.Error) as e: 66 | print(f"❌ 复制失败: 从 {source_path} 到 {dest_path}\n - 错误: {e}\n") 67 | skipped_count += 1 68 | else: 69 | # 如果源文件不存在,则打印提示并跳过 70 | print(f"⚠️ 源文件不存在,已跳过: {source_path}\n") 71 | skipped_count += 1 72 | 73 | print("--- 任务摘要 ---") 74 | print(f"总计成功复制: {copied_count} 个文件") 75 | print(f"总计跳过(或失败): {skipped_count} 个文件") 76 | print("✨ 批量任务执行完毕。") 77 | 78 | 79 | if __name__ == "__main__": 80 | run_batch_copy() -------------------------------------------------------------------------------- /Evaluation_Framework/README.md: -------------------------------------------------------------------------------- 1 | # 评估框架 (Evaluation Framework) 2 | 3 | 本模块专门用于评估模型生成的SQL(特别是向量SQL)查询的准确性。 4 | 5 | 该框架通过一个分为两个核心阶段的流水线来工作:**SQL执行** 和 **结果评估**。 6 | 7 | ## 核心功能 8 | 9 | 1. **SQL生成 (`generate.py`)** 10 | * 根据输入的问题、数据库Schema和相关元数据,调用大语言模型(LLM)生成预测的SQL查询。 11 | * 支持两种生成模式: 12 | * **vLLM离线推理**:利用vLLM库在本地进行高效的批量推理,支持多GPU张量并行。 13 | * **API在线调用**:通过HTTP API调用外部模型服务(如OpenAI、Claude等),支持多线程并发请求。 14 | * 具备强大的缓存机制,能够根据输入数据和模型自动缓存生成结果,实现断点续传。 15 | 16 | 2. **SQL执行 (`sql_executor.py`)** 17 | * 连接到指定的数据库(支持 SQLite、PostgreSQL、ClickHouse 等)并执行SQL查询。 18 | * **沙箱化执行**:在独立的进程中执行每个SQL查询,并强制实施超时,防止恶意或低效的查询卡死整个评估流程。 19 | * 同时执行模型生成的预测SQL和所有标准的(Ground Truth)SQL,并保存两者的执行结果(数据、列名、状态等)。 20 | * 集成了嵌入服务(Embedding Service)的自动管理功能,可以在执行前自动启动所需的服务。 21 | * 同样具备缓存机制,可以跳过已成功执行的查询。 22 | 23 | 3. **结果评估 (`evaluate_results.py`)** 24 | * 对比预测SQL和标准SQL的**执行结果**,而非仅仅比较SQL字符串。 25 | * 支持多种评估指标: 26 | * **精确匹配 (Exact Match)**: 预测结果与任一标准结果完全一致。 27 | * **集合指标 (Set-based Metrics)**: 28 | * **Precision, Recall, F1-Score**: 基于结果集的交集计算,不考虑行顺序。 29 | * **排序指标 (Ranking Metrics)**: 30 | * **nDCG@k**: 评估返回结果的排序质量。 31 | * **MAP (Mean Average Precision)**, **MRR (Mean Reciprocal Rank)**. 32 | * **基于LLM的评估**: 调用另一个LLM来从语义层面评估预测SQL的“SQL骨架”和“向量部分”的正确性。 33 | * 评估过程同样支持并发处理和断点续传。 34 | 35 | 4. **结果聚合 (`aggregate_results.py`)** 36 | * 一个实用工具,用于从多个模型、多个数据集的评估报告(JSON文件)中收集、汇总评估指标。 37 | * 将分散的结果聚合成一个结构化的CSV文件,便于横向对比不同模型的性能。 38 | 39 | ## 使用流程 40 | 41 | 评估过程通过 `run_eval_pipeline.py` 脚本进行统一调度,该脚本通过读取 `evaluation_config.yaml` 配置文件来驱动整个流程。 42 | 43 | 1. 
**准备配置文件 (`evaluation_config.yaml`)** 44 | * 配置数据库类型 (`db_type`)、数据库文件根目录 (`base_dir`)。 45 | * 指定包含预测SQL的输入文件 (`eval_data_file`)。 46 | * 定义中间结果和最终报告的输出路径。 47 | * 配置需要计算的评估指标 (`metrics`)。 48 | * (可选)配置嵌入服务的地址和模型。 49 | 50 | 2. **运行SQL生成 (如果需要)** 51 | * 执行 `generate.py` 脚本,为数据集生成预测SQL。 52 | * `python generate.py --config generate_config.yaml` 53 | 54 | 3. **运行评估流水线** 55 | * **完整流程** (执行 + 评估): 56 | ```bash 57 | python run_eval_pipeline.py --all --config evaluation_config.yaml 58 | ``` 59 | * **仅执行SQL**: 60 | ```bash 61 | python run_eval_pipeline.py --execute --config evaluation_config.yaml 62 | ``` 63 | * **仅评估结果** (需要已有的执行结果文件): 64 | ```bash 65 | python run_eval_pipeline.py --evaluate --config evaluation_config.yaml 66 | ``` 67 | 68 | 4. **聚合多个实验的结果** 69 | * 将所有实验的报告(JSON文件)按约定目录结构存放。 70 | * 运行 `aggregate_results.py` 生成对比表格。 71 | ```bash 72 | python aggregate_results.py --results-dir ./results --output summary.csv 73 | ``` 74 | 75 | ## 依赖安装 76 | 77 | 要运行此评估框架,请安装所需的Python包: 78 | 79 | ```bash 80 | pip install -r requirements.txt 81 | ``` -------------------------------------------------------------------------------- /Data_Synthesizer/database_synthesis/build_sqlite_databases.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm 3 | from sqlite_schema_parser import verify_schema 4 | import random 5 | 6 | if __name__ == "__main__": 7 | enhanced_results = json.load(open("./results/schema_enhancement.json")) 8 | 9 | final_schemas = [] 10 | error_case_num = 0 11 | for result in tqdm(enhanced_results): 12 | try: 13 | domain = result["domain"] 14 | schema = json.loads(result["enhanced_schema"]) 15 | assert "tables" in schema and "foreign_keys" in schema 16 | 17 | tables = [] 18 | for table in schema["tables"]: 19 | try: 20 | assert "table_name" in table and "column_names" in table and \ 21 | "column_types" in table and "column_descriptions" in table 22 | assert len(table["column_names"]) == len(table["column_types"]) == len(table["column_descriptions"]) 23 | tables.append(table) 24 | except Exception as e: 25 | pass 26 | 27 | table_names_lower = [table["table_name"].lower() for table in tables] 28 | 29 | foreign_keys = [] 30 | for foreign_key in schema["foreign_keys"]: 31 | try: 32 | assert "source_table" in foreign_key and "column_in_source_table" in foreign_key and \ 33 | "referenced_table" in foreign_key and "column_in_referenced_table" in foreign_key 34 | assert foreign_key["source_table"].lower() in table_names_lower and \ 35 | foreign_key["referenced_table"].lower() in table_names_lower 36 | foreign_keys.append(foreign_key) 37 | except Exception as e: 38 | pass 39 | 40 | final_schemas.append( 41 | { 42 | "domain": domain, 43 | "tables": tables, 44 | "foreign_keys": foreign_keys 45 | } 46 | ) 47 | except Exception as e: 48 | error_case_num += 1 49 | # print(e) 50 | print("error_case_num:", error_case_num) 51 | 52 | db_ids = [] 53 | success_labels = [] 54 | for final_schema in tqdm(final_schemas): 55 | db_id = final_schema["domain"].lower().replace("(", "_").replace(")", "_").replace("-", "_").replace(" ", "_").replace("*", "_").strip() 56 | 57 | if len(db_id) > 75: 58 | db_id = db_id[:75] 59 | 60 | # resolve db_id conflict issues 61 | while db_id in db_ids: 62 | db_id += "_" + str(random.randint(0, 1000000000000)) 63 | 64 | success_label = verify_schema(final_schema, db_id) 65 | if success_label: 66 | db_ids.append(db_id) 67 | 68 | success_labels.append(success_label) 69 | 70 | print("success rate:", 
sum(success_labels)/len(success_labels)) 71 | -------------------------------------------------------------------------------- /Data_Synthesizer/database_synthesis/prompt_templates/embedding_with_new_line_prompt.txt: -------------------------------------------------------------------------------- 1 | **Task Overview:** 2 | As a senior data analyst, your task is to enhance an initial database schema to incorporate embedding columns for semantically rich fields based on a given business scenario while preserving all original data structure and sample values. 3 | 4 | **Strict Requirements:** 5 | 1. **Schema Preservation:** 6 | - Maintain ALL original columns unless you have a valid reason 7 | - Do not delete existing sample_rows values; just modify them 8 | - Keep original column order unless adding new embedding columns 9 | 10 | 2. **Embedding Column Addition:** 11 | - Select columns based on semantic richness (text, descriptions, content, abstract, review, title, plot_summary and name), then add new columns with suffix `_embedding` (type: BLOB) 12 | - If the table lacks columns with semantically rich information (e.g., names, descriptions), and you determine that such columns would improve its utility, you may add them with appropriate, meaningful data. For each added column, ensure you also create a corresponding embedding column by appending '_embedding' to its name (e.g., adding a 'description' column requires a 'description_embedding' column) 13 | - Never modify existing column names/types/values 14 | 15 | 3. **Sample Data Handling:** 16 | - For new embedding columns in sample_rows: 17 | * Use explicit null values (`null`) 18 | * Maintain original structure 19 | * Preserve all existing key-value pairs 20 | 21 | **Steps:** 22 | 1. **Analyze Semantic Columns:** Identify existing columns with rich semantic information (text descriptions, reviews, content, names, etc.) 23 | 2. **Generate Embedding Columns:** For each identified column: 24 | - Create a new column with name `[original_col]_embedding` 25 | - Set type as `BLOB` (binary storage for vector embeddings) 26 | - Keep sample values as empty (`null` or `""`) 27 | 3. **Generate New Columns With Embedding Columns:** If no columns with rich semantic information exist, check whether new columns with rich semantic information should be generated, together with corresponding meaningful sample values and embedding columns. 28 | Key Clarifications: 29 | -- Condition: Only add columns if the table lacks meaningful semantic data and you judge them necessary 30 | -- Embedding Requirement: Each newly added column must be paired 1:1 with a companion column named with the `_embedding` suffix 31 | 4. **Preserve Relationships:** Maintain existing primary/foreign key relationships 32 | 5. **Document Reasoning:** Briefly explain why each embedding column was added 33 | 34 | **Business Domain:** 35 | {domain} 36 | 37 | **Business Scenario:** 38 | {scenario} 39 | 40 | **Initial Database Schema:** 41 | ```json 42 | {schema} 43 | ``` 44 | 45 | **Output Format:** 46 | Your output should provide the enriched database schema with embedding columns in JSON format: 47 | ```json 48 | -- enriched database schema with embedding 49 | ``` 50 | 51 | Let's think step by step. 
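The template above leaves three placeholders ({domain}, {scenario} and {schema}) to be filled in before the prompt is sent to a model. The following is a minimal sketch of how it could be consumed, assuming the template is read from disk and the model reply is available as a string; `build_enrichment_prompt` and `parse_enriched_schema` are illustrative helper names rather than functions from this repository, and the LLM call itself is omitted.

```python
import json
import re


def build_enrichment_prompt(template_path: str, domain: str, scenario: str, schema: dict) -> str:
    """Fill the {domain}/{scenario}/{schema} placeholders of the prompt template."""
    with open(template_path, encoding="utf-8") as f:
        template = f.read()
    return template.format(
        domain=domain,
        scenario=scenario,
        schema=json.dumps(schema, indent=2, ensure_ascii=False),
    )


def parse_enriched_schema(llm_reply: str) -> dict:
    """Extract the enriched schema from the fenced json block requested by the template."""
    match = re.search(r"```json\s*(.*?)```", llm_reply, re.DOTALL)
    payload = match.group(1) if match else llm_reply
    return json.loads(payload)
```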
52 | -------------------------------------------------------------------------------- /Data_Synthesizer/pipeline/clickhouse/prompt_templates/clickhouse_vec_note_prompt.txt: -------------------------------------------------------------------------------- 1 | There are a few requirements you should comply with in addition: 2 | 1. When generating SQL queries, you should prioritize utilizing K-Nearest Neighbor (KNN) searches whenever contextually appropriate. However, you must avoid unnecessary/forced KNN implementations for: 3 | -- Traditional relational data queries (especially for columns like: id, age, price). 4 | -- Cases where standard SQL operators (equality, range, or aggregation functions) are more efficient and semantically appropriate. 5 | 2. Only columns with a vector type (like: Array(Float32)) support KNN queries. The names of these vector columns often end with "_embedding". You can perform KNN searches when the column name you need to query ends with "_embedding" or is otherwise identified as a vector column. 6 | 3. In ClickHouse, vector similarity search is performed using distance functions. You must explicitly calculate the distance in the SELECT clause using a function like L2Distance and give it an alias, typically "AS distance". This distance alias will not be implicitly generated. 7 | 4. The lembed function is used to transform a string into a semantic vector. This function should be used within a WITH clause to define the reference vector. The lembed function has two parameters: the first is the name of the embedding model used (default value: '{embedding_model}'), and the second is the string content to embed. The resulting vector should be given an alias in the WITH clause. 8 | 5. You must generate plausible and semantically relevant words or sentences for the second parameter of the lembed function based on the column's name, type, and comment. For example, if a column is named product_description_embedding and its comment is "Embedding of the product's features and marketing text", you could generate text like "durable and waterproof outdoor adventure camera". 9 | 6. Every KNN search query MUST conclude with "ORDER BY distance LIMIT N" to retrieve the top-N most similar results. The LIMIT clause is mandatory for performing a KNN search and ensuring predictable performance. 10 | 7. When combining a vector search with JOIN operations, the standard WHERE clause should be used to apply filters from any of the joined tables. The ORDER BY distance LIMIT N clause is applied after all filtering and joins are resolved. 11 | 8. A SELECT statement should typically be ordered by a single distance calculation to perform one primary KNN search. However, subqueries can perform their own independent KNN searches, each with its own WITH clause, distance calculation, and ORDER BY distance LIMIT N clause. 12 | 13 | ## Example of a ClickHouse KNN Query 14 | DB Schema: Some table on articles with a column content_embedding Array(Float32). 15 | Query Task: Identify the article ID of the single most relevant article discussing innovative algorithms in graph theory. 
16 | Generated SQL: 17 | ```sql 18 | WITH\n lembed('all-MiniLM-L6-v2', 'innovative algorithms in graph theory.') AS ref_vec_0\n\nSELECT id, L2Distance(articles.abstract_embedding, ref_vec_0) AS distance\nFROM articles\nORDER BY distance\nLIMIT 1; 19 | ``` -------------------------------------------------------------------------------- /Evaluation_Framework/script/generate_ground_truth.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | import os 4 | 5 | def transform_json_data(input_path: str, output_path: str): 6 | """ 7 | 读取一个包含特定结构字典数组的 JSON 文件,并将其转换为新的字典格式。 8 | 9 | 输入格式: 10 | [ 11 | { 12 | "query_id": "q1", "db_id": "...", "sql": "...", "sql_candidate": ["..."] 13 | }, ... 14 | ] 15 | 16 | 输出格式: 17 | { 18 | "q1": { "db_name": "...", "sqls": ["...", "..."] }, ... 19 | } 20 | 21 | Args: 22 | input_path (str): 输入的 JSON 文件路径。 23 | output_path (str): 输出的 JSON 文件路径。 24 | """ 25 | # 定义输入字典必须包含的字段 26 | required_keys = {"query_id", "db_id", "sql", "sql_candidate"} 27 | 28 | try: 29 | # 1. 读取并解析输入的 JSON 文件 30 | with open(input_path, 'r', encoding='utf-8') as f: 31 | source_data = json.load(f) 32 | 33 | # 2. 检查输入数据是否为列表 34 | if not isinstance(source_data, list): 35 | print(f"错误: JSON 文件的顶层结构应为一个列表 (array)。文件: {input_path}") 36 | return 37 | 38 | # 3. 初始化一个新的空字典用于存放结果 39 | transformed_data = {} 40 | 41 | # 4. 循环处理源列表中的每一个元素 42 | for index, item in enumerate(source_data): 43 | # 确保元素是字典并且包含所有必需的字段 44 | if not isinstance(item, dict): 45 | print(f"警告: 在索引 {index} 处找到一个非字典元素,已跳过。") 46 | continue 47 | 48 | if not required_keys.issubset(item.keys()): 49 | print(f"警告: 在索引 {index} 处的字典缺少必要字段,已跳过。必需字段: {required_keys}") 50 | continue 51 | 52 | # 提取所需数据 53 | query_id = item["query_id"] 54 | db_name = item["db_id"] 55 | main_sql = item["sql"] 56 | candidate_sqls = item["sql_candidate"] 57 | 58 | # 检查 sql_candidate 是否为列表 59 | if not isinstance(candidate_sqls, list): 60 | print(f"警告: query_id '{query_id}' 的 'sql_candidate' 字段不是列表,已跳过。") 61 | continue 62 | 63 | # 构造新的 sqls 列表 64 | # 将主 sql 字符串放在列表开头,然后拼接上候选 sql 列表 65 | all_sqls = [main_sql] + candidate_sqls 66 | 67 | # 按照目标格式,构建新的字典条目 68 | transformed_data[query_id] = { 69 | "db_name": db_name, 70 | "sqls": all_sqls 71 | } 72 | 73 | # 5. 将转换后的字典写入新的 JSON 文件 74 | # indent=4 使输出文件格式优美,易于阅读 75 | # ensure_ascii=False 保证中文字符的正确显示 76 | with open(output_path, 'w', encoding='utf-8') as f: 77 | json.dump(transformed_data, f, indent=2, ensure_ascii=False) # 使用 indent=2 更接近示例 78 | 79 | print(f"处理成功!共转换 {len(transformed_data)} 条数据。") 80 | print(f"结果已保存至: {output_path}") 81 | 82 | except FileNotFoundError: 83 | print(f"错误: 找不到输入文件 '{input_path}'。") 84 | except json.JSONDecodeError: 85 | print(f"错误: 文件 '{input_path}' 的内容不是有效的 JSON 格式。") 86 | except Exception as e: 87 | print(f"发生了未知错误: {e}") 88 | 89 | # --- 使用示例 --- 90 | if __name__ == "__main__": 91 | input_file = "../data/candidate_sql_query_id.json" 92 | output_file = "../data/ground_truth.json" 93 | 94 | # 调用核心函数进行转换 95 | transform_json_data(input_file, output_file) 96 | 97 | -------------------------------------------------------------------------------- /Data_Synthesizer/pipeline/myscale/prompt_templates/myscale_vec_note_prompt.txt: -------------------------------------------------------------------------------- 1 | There are a few requirements you should comply with in addition: 2 | 1. When generating SQL queries, you should prioritize utilizing K-Nearest Neighbor (KNN) searches whenever contextually appropriate. 
However, you must avoid unnecessary/forced KNN implementations for: 3 | -- Traditional relational data queries (especially for columns like: id, age, price). 4 | -- Cases where standard SQL operators (equality, range, or aggregation functions) are more efficient and semantically appropriate. 5 | 2. Only columns with a vector type (like: Array(Float32) or FixedString) support KNN queries. The names of these vector columns often end with "_embedding". You can perform KNN searches when the column name you need to query ends with "_embedding" or is otherwise identified as a vector column. 6 | 3. In MyScale, vector similarity search is performed using the `distance()` function. You must explicitly calculate the distance in the SELECT clause and give it an alias, typically "AS distance". This distance alias will not be implicitly generated. 7 | 4. **MyScale Specific Syntax:** When providing a query vector (the "needle") for an `Array(Float32)` column, at least one number in the array *must* contain a decimal point (e.g., `[3.0, 9, 45]`). This prevents the database from misinterpreting the vector as `Array(UInt64)`, which would cause an error. 8 | 5. The `lembed` function is used to transform a string into a semantic vector. This function should be used within a WITH clause to define the reference vector. The lembed function has two parameters: the first is the name of the embedding model used (default value: '{embedding_model}'), and the second is the string content to embed. The resulting vector should be given an alias in the WITH clause. 9 | 6. You must generate plausible and semantically relevant words or sentences for the second parameter of the `lembed` function based on the column's name, type, and comment. For example, if a column is named `product_description_embedding` and its comment is "Embedding of the product's features and marketing text", you could generate text like "durable and waterproof outdoor adventure camera". 10 | 7. Every KNN search query MUST conclude with "ORDER BY distance LIMIT N" to retrieve the top-N most similar results. The LIMIT clause is mandatory for performing a KNN search and ensuring predictable performance. 11 | 8. When combining a vector search with JOIN operations, the standard `WHERE` clause should be used to apply filters from any of the joined tables. The `ORDER BY distance LIMIT N` clause is applied after all filtering and joins are resolved. 12 | 9. A SELECT statement should typically be ordered by a single distance calculation to perform one primary KNN search. However, subqueries can perform their own independent KNN searches, each with its own WITH clause, distance calculation, and `ORDER BY distance LIMIT N` clause. 13 | 14 | ## Example of a MyScale KNN Query 15 | DB Schema: Some table on `articles` with a column `abstract_embedding` `Array(Float32)`. 16 | Query Task: Identify the article ID of the single most relevant article discussing innovative algorithms in graph theory. 
17 | Generated SQL: 18 | ``` 19 | WITH 20 | lembed('all-MiniLM-L6-v2', 'innovative algorithms in graph theory.') AS ref_vec_0 21 | SELECT id, distance(articles.abstract_embedding, ref_vec_0) AS distance 22 | FROM articles 23 | ORDER BY distance 24 | LIMIT 1; 25 | ``` 26 | -------------------------------------------------------------------------------- /Data_Synthesizer/tools/add_prefix.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | def add_prefix_to_dirs_and_files(target_directory, prefix="deverse_2_"): 5 | """ 6 | 遍历指定目录,为符合条件的子目录和其中的同名 .sqlite 文件添加前缀。 7 | 8 | 条件: 9 | 1. 目标必须是 `target_directory` 下的一个子目录。 10 | 2. 该子目录中必须包含一个与子目录同名的 .sqlite 文件。 11 | 12 | 例如: 13 | - /your/path/folder1/folder1.sqlite 14 | 将被重命名为: 15 | - /your/path/deverse_2_folder1/deverse_2_folder1.sqlite 16 | 17 | Args: 18 | target_directory (str): 需要处理的根目录路径。 19 | prefix (str): 要添加的前缀。 20 | """ 21 | # 检查目标目录是否存在 22 | if not os.path.isdir(target_directory): 23 | print(f"错误:目录 '{target_directory}' 不存在或不是一个有效的目录。") 24 | sys.exit(1) # 退出脚本 25 | 26 | print(f"开始扫描目录: {target_directory}\n") 27 | 28 | # 获取目录下所有的文件和文件夹名 29 | try: 30 | entries = os.listdir(target_directory) 31 | except OSError as e: 32 | print(f"错误:无法访问目录 '{target_directory}'。请检查权限。") 33 | print(f"详细信息: {e}") 34 | sys.exit(1) 35 | 36 | renamed_count = 0 37 | # 遍历所有条目 38 | for dir_name in entries: 39 | old_dir_path = os.path.join(target_directory, dir_name) 40 | 41 | # 检查当前条目是否是一个目录 42 | if os.path.isdir(old_dir_path): 43 | # 构造原始 sqlite 文件的路径 44 | sqlite_file_name = dir_name + ".sqlite" 45 | old_sqlite_file_path = os.path.join(old_dir_path, sqlite_file_name) 46 | 47 | # 检查同名的 sqlite 文件是否存在 48 | if os.path.isfile(old_sqlite_file_path): 49 | print(f"找到匹配项: 目录 '{dir_name}' 和文件 '{sqlite_file_name}'") 50 | 51 | # 1. 重命名目录 52 | new_dir_name = prefix + dir_name 53 | new_dir_path = os.path.join(target_directory, new_dir_name) 54 | 55 | try: 56 | print(f" -> 正在重命名目录 '{dir_name}' 为 '{new_dir_name}'") 57 | os.rename(old_dir_path, new_dir_path) 58 | 59 | # 2. 重命名目录内的 sqlite 文件 60 | # 注意:此时目录已经改名,所以要用 new_dir_path 61 | file_to_rename_path = os.path.join(new_dir_path, sqlite_file_name) 62 | new_sqlite_file_name = new_dir_name + ".sqlite" 63 | new_sqlite_file_path = os.path.join(new_dir_path, new_sqlite_file_name) 64 | 65 | print(f" -> 正在重命名文件 '{sqlite_file_name}' 为 '{new_sqlite_file_name}'") 66 | os.rename(file_to_rename_path, new_sqlite_file_path) 67 | 68 | print(" -> 操作成功!\n") 69 | renamed_count += 1 70 | 71 | except OSError as e: 72 | print(f" -> 操作失败!错误: {e}\n") 73 | # 如果目录重命名成功但文件失败,尝试恢复目录名 74 | if not os.path.exists(old_dir_path): 75 | os.rename(new_dir_path, old_dir_path) 76 | print(f" -> 已将目录恢复为 '{dir_name}'") 77 | 78 | 79 | if renamed_count > 0: 80 | print(f"处理完成!总共重命名了 {renamed_count} 个目录和文件对。") 81 | else: 82 | print("处理完成!没有找到符合条件的目录和文件对。") 83 | 84 | 85 | if __name__ == "__main__": 86 | # --- 请修改这里的路径 --- 87 | # Windows 示例: "C:\\Users\\YourUser\\Desktop\\my_databases" 88 | # macOS/Linux 示例: "/Users/youruser/Documents/my_databases" 89 | target_directory = "/mnt/b_public/data/ydw/Text2VectorSQL/Data_Synthesizer/pipeline/sqlite/results/synthesis_data_deversity/vector_databases" # <--- 修改这里为你需要处理的目录路径! 
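    # Note: add_prefix_to_dirs_and_files() also accepts a custom prefix (it defaults
    # to "deverse_2_"), so a run with a different, hypothetical prefix would look like:
    # add_prefix_to_dirs_and_files(target_directory, prefix="deverse_3_")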
90 | 91 | # 运行主函数 92 | add_prefix_to_dirs_and_files(target_directory) 93 | -------------------------------------------------------------------------------- /Execution_Engine/README.md: -------------------------------------------------------------------------------- 1 | # 执行引擎 (Execution Engine) 2 | 3 | 执行引擎负责解析并执行"VectorSQL"查询. 它在用户通过VectorSQL表达的意图与各种数据库系统的原生功能之间架起了一座桥梁. 4 | 5 | ## 核心功能 6 | 7 | 该引擎专门用于处理包含特殊函数 `lembed(model, text)` 的VectorSQL查询. 其主要职责包括: 8 | 9 | 1. **SQL解析**: 解析输入的VectorSQL查询, 查找所有 `lembed(model, text)` 函数的实例. 10 | 11 | 2. **向量化**: 对于找到的每一个唯一的 `(model, text)` 组合, 它会向一个外部的**Embedding服务**发起网络请求. 该服务负责使用指定的嵌入模型将文本转换为高维向量. 12 | 13 | 3. **SQL翻译**: 收到向量后, 引擎会将原始的VectorSQL翻译成与目标数据库兼容的原生查询. 它会用相应的向量字面量替换 `lembed(...)` 调用, 并确保格式符合数据库要求. 14 | 15 | 4. **数据库执行**: 引擎连接到指定的目标数据库 (PostgreSQL, ClickHouse, 或 SQLite), 并执行翻译后的原生查询. 16 | 17 | 5. **结果处理**: 获取查询结果, 并以结构化的JSON格式返回. 18 | 19 | 6. **超时与错误管理**: 引擎为网络请求, 数据库连接和查询执行实现了健壮的超时机制, 以防止无限期挂起. 同时, 它为失败的操作提供清晰的错误信息. 20 | 21 | ## 依赖项 22 | 23 | 引擎的正常运行依赖于多个Python库. 这些库已在 `requirements.txt` 文件中列出. 24 | 25 | - `psycopg2-binary`: 用于连接到PostgreSQL数据库. 26 | - `requests`: 用于向Embedding服务发出HTTP请求. 27 | - `PyYAML`: 用于从YAML文件加载引擎配置. 28 | - `clickhouse-connect`: 用于连接到ClickHouse数据库. 29 | - `sqlite-vec`: 用于提供向量搜索能力的自定义SQLite扩展. 30 | - `sqlite-lembed`: 用于处理 `lembed` 函数的自定义SQLite扩展. 31 | 32 | ## 配置 33 | 34 | 引擎通过一个YAML文件 (例如 `engine_config.yaml`) 进行配置. 该文件必须指定Embedding服务的URL, 并且可以定义不同数据库的连接参数和各种超时设置. 35 | 36 | **`engine_config.yaml` 示例:** 37 | ```yaml 38 | embedding_service: 39 | url: "http://127.0.0.1:8000/embed" 40 | 41 | database_connections: 42 | postgresql: 43 | user: "postgres" 44 | password: "postgres" 45 | host: "localhost" 46 | port: 5432 47 | clickhouse: 48 | host: "localhost" 49 | port: 8123 50 | user: "default" 51 | password: "" 52 | 53 | timeouts: 54 | embedding_service: 1 # 秒 55 | database_connection: 1 # 秒 56 | sql_execution: 60 # 秒 57 | total_execution: 60 # 秒 58 | ``` 59 | 60 | ## 用法 61 | 62 | `ExecutionEngine` 可以作为一个命令行工具使用,也可以作为一个Python类调用。 63 | 64 | 注意:输入VectorSQL的lembed部分不包含双引号,单引号内无单引号。 65 | 66 | ### 命令行界面 67 | 68 | ```bash 69 | # 安装依赖 70 | pip install -r requirements.txt 71 | 72 | # 运行查询 73 | python execution_engine.py \ 74 | --sql "SELECT Musical_ID,Name FROM musical m order by L2Distance(Category_embedding, lembed('all-MiniLM-L6-v2','xxx')) + L2Distance(Category_embedding, lembed('all-MiniLM-L6-v2','yyy')) LIMIT 5;" \ 75 | --db-type "clickhouse" \ 76 | --db-identifier "musical" \ 77 | --config "engine_config.yaml" 78 | ``` 79 | 80 | ### 命令行参数 81 | 82 | - `--sql`: (必需) 要执行的VectorSQL查询语句. 83 | - `--db-type`: (必需) 目标数据库的类型. 可选值: `postgresql`, `clickhouse`, `sqlite`. 84 | - `--db-identifier`: (必需) 数据库标识符 (例如, PostgreSQL/ClickHouse的数据库名, 或SQLite的文件路径). 85 | - `--config`: 引擎的YAML配置文件路径 (默认为 `engine_config.yaml`). 86 | - `--...-timeout`: 可选参数, 用于覆盖配置文件中的超时设置 (例如, `--sql-execution-timeout 90`). 87 | 88 | ### Python类调用 89 | 90 | ```python 91 | from execution_engine import ExecutionEngine 92 | 93 | # 1. 初始化引擎(可以在您的应用启动时完成) 94 | try: 95 | engine = ExecutionEngine(config_path="path/to/engine_config.yaml") 96 | except Exception as e: 97 | print(f"Failed to initialize engine: {e}") 98 | # 处理初始化失败 99 | 100 | # 2. 在需要时调用执行方法 101 | my_sql_query = "SELECT name FROM products ORDER BY embedding <-> lembed('bge-base', 'high quality headphones') LIMIT 3" 102 | db_name = "e_commerce_db" 103 | db_type = "postgresql" 104 | 105 | result = engine.execute(sql=my_sql_query, db_type=db_type, db_identifier=db_name) 106 | 107 | # 3. 
处理结果 108 | if result['status'] == 'success': 109 | print("Execution successful!") 110 | print("Columns:", result['columns']) 111 | for row in result['data']: 112 | print(row) 113 | else: 114 | print("Execution failed!") 115 | print("Error:", result['message']) 116 | ``` -------------------------------------------------------------------------------- /Data_Synthesizer/synthesis_eval/generate_input.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | import os 4 | 5 | from tqdm import tqdm 6 | 7 | 8 | def _filter_success_only(dataset_json): 9 | """如果存在 execution_status 字段,仅保留成功记录。""" 10 | has_status = any('execution_status' in item for item in dataset_json) 11 | if not has_status: 12 | return dataset_json 13 | filtered = [item for item in dataset_json if item.get('execution_status') == 'success'] 14 | return filtered 15 | 16 | def generate_input_llm(dataset_json_path="../pipeline/sqlite/results/toy_spider/candidate_sql.json", tables_json_path="../pipeline/sqlite/results/toy_spider/embedding_table_vector.json", prompt_tamplate_path="../pipeline/sqlite/prompt_templates/sql_generate_prompt_template.txt", output_input_path="../pipeline/sqlite/results/toy_spider/input_llm.json",dataset_backend="sqlite",database_note_prompt_path="../pipeline/sqlite/prompt_templates/sqlite_vec_note_prompt.txt",embedding_model_name="all-MiniLM-L6-v2"): 17 | dataset_json = json.load(open(dataset_json_path)) 18 | dataset_json = _filter_success_only(dataset_json) 19 | print("len(question-vecsql):", len(dataset_json)) 20 | 21 | if os.path.exists(tables_json_path): 22 | db_id2ddls = dict() 23 | tables_json = json.load(open(tables_json_path)) 24 | for table in tables_json: 25 | db_id2ddls[table["db_id"]] = table["ddls"] 26 | print("len(db_id2ddls):", len(db_id2ddls)) 27 | else: 28 | assert "schema" in dataset_json[0], "When tables_json_path not exists, the schema should be in dataset_json" 29 | 30 | database_note_prompt = open(database_note_prompt_path).read().format(embedding_model = embedding_model_name) 31 | prompt_tamplate = open(prompt_tamplate_path).read() 32 | for data in tqdm(dataset_json): 33 | if data["external_knowledge"] != "": 34 | question = data["external_knowledge"] + "\n" + data["question"] 35 | else: 36 | question = data["question"] 37 | 38 | if os.path.exists(tables_json_path): 39 | schema = "\n\n".join(db_id2ddls[data["db_id"]]) 40 | else: 41 | schema = data["schema"] 42 | 43 | data["db_type"] = dataset_backend 44 | data["embedding_model_name"] = embedding_model_name 45 | data["database_note_prompt"] = database_note_prompt 46 | data["input"] = prompt_tamplate.format( 47 | dataset_backend =dataset_backend, 48 | schema = schema, 49 | database_note_prompt = database_note_prompt, 50 | embedding_model_name = embedding_model_name, 51 | question = question, 52 | ) 53 | 54 | 55 | # 创建输出目录 56 | # os.makedirs("../pipeline/sqlite/results", exist_ok=True) 57 | with open(output_input_path, "w", encoding="utf-8") as f: 58 | f.write(json.dumps(dataset_json, indent=2, ensure_ascii=False)) 59 | 60 | def generate_output_llm(dataset_json_path="../pipeline/sqlite/results/toy_spider/input_llm.json", output_path_input="../pipeline/sqlite/results/toy_spider/output_llm.json",dataset_backend="sqlite"): 61 | dataset_json = json.load(open(dataset_json_path)) 62 | dataset_json = _filter_success_only(dataset_json) 63 | print("len(question-vecsql):", len(dataset_json)) 64 | 65 | for data in tqdm(dataset_json): 66 | 67 | data['output'] = { 68 | "sql": data["sql"], 69 
| "cot": data["cot"], 70 | } 71 | 72 | 73 | # 创建输出目录 74 | os.makedirs("../pipeline/sqlite/results", exist_ok=True) 75 | with open(output_path_input, "w", encoding="utf-8") as f: 76 | f.write(json.dumps(dataset_json, indent=2, ensure_ascii=False)) 77 | 78 | if __name__ == "__main__": 79 | generate_input_llm() 80 | 81 | -------------------------------------------------------------------------------- /Evaluation_Framework/evaluation_config.yaml: -------------------------------------------------------------------------------- 1 | # ================================================== 2 | # Text2VectorSQL Evaluation Framework 3 | # ================================================== 4 | 5 | # --- Database Configuration --- 6 | # Base directory for database files (only for SQLite) 7 | # For SQLite: db_identifier is a relative path to the database file 8 | # For PostgreSQL/ClickHouse: db_identifier is the database name 9 | 10 | #数据库的根目录 11 | base_dir: /mnt/b_public/data/ydw/Text2VectorSQL/Data_Synthesizer/pipeline/sqlite/results/test/vector_databases 12 | 13 | evaluation_report_file: evaluation_report.json 14 | 15 | #项目的根目录 16 | project_dir: /mnt/DataFlow/ydw/Text2VectorSQL/ 17 | 18 | # --- SQL Execution Configuration --- 19 | # Path to the configuration for the ExecutionEngine 20 | engine_config_path: Execution_Engine/engine_config.yaml 21 | 22 | # The database backend type for this evaluation run. 23 | # Supported types: 'sqlite', 'postgresql', 'clickhouse' 24 | db_type: 'sqlite' 25 | 26 | #待测评的数据的文件路径 27 | eval_data_file: /mnt/b_public/data/ydw/Text2VectorSQL/Evaluation_Framework/input_output.json 28 | 29 | 30 | 31 | # Output file to save the SQL execution results (intermediate file) 32 | execution_results_file: sql_execution_results.json 33 | 34 | 35 | # --- Embedding Service Configuration --- 36 | # Embedding service will be automatically started by default 37 | embedding_service: 38 | # Whether to automatically manage the embedding service lifecycle 39 | auto_manage: false 40 | 41 | # Server configuration 42 | host: "127.0.0.1" 43 | port: 8000 44 | 45 | # Model configuration 46 | models: 47 | - name: "all-MiniLM-L6-v2" 48 | hf_model_path: "/mnt/DataFlow/ydw/Text2VectorSQL/Embedding_Service/models_cache/all-MiniLM-L6-v2" 49 | trust_remote_code: true 50 | tensor_parallel_size: 1 51 | max_model_len: 256 52 | 53 | # Additional models can be configured here 54 | # - name: "gte-large" 55 | # hf_model_path: "thenlper/gte-large" 56 | # trust_remote_code: true 57 | # tensor_parallel_size: 1 58 | # max_model_len: 512 59 | 60 | 61 | 62 | 63 | # --- Evaluation & Metrics Configuration --- 64 | # Select the metrics you want to calculate. 
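# Note: as described in the framework README, the set-based metrics below
# (exact_match, precision, recall, f1_score) compare predicted and ground-truth
# execution results as unordered sets, while the rank-aware metrics (map, mrr,
# ndcg) additionally score the order of the returned rows.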
65 | metrics: 66 | # --- Set-based Metrics (order doesn't matter) --- 67 | - name: 'exact_match' # 0/1 score for perfect set match 68 | - name: 'f1_score' # The harmonic mean of precision and recall 69 | - name: 'precision' 70 | - name: 'recall' 71 | 72 | # --- Rank-aware Metrics (order matters) --- 73 | - name: 'map' # Mean Average Precision 74 | - name: 'mrr' # Mean Reciprocal Rank 75 | - name: 'ndcg' # Normalized Discounted Cumulative Gain 76 | k: 10 # The '@k' value for NDCG (e.g., ndcg@10) 77 | 78 | llm_evaluation: 79 | # Enable or disable LLM-based evaluation 80 | enabled: false # 设置为 true 启用 LLM 评估 81 | 82 | 83 | # LLM API configuration 84 | api_url: "http://123.129.219.111:3000/v1/chat/completions" 85 | api_key: "sk-GBD9BZqsXcgZXtWFVLraCy1U5Wws2Ix7xYzwTjHvBRij2MPQ" # 替换为您的 API Key 86 | model_name: "gpt-4o" # 可选: gpt-4, gpt-3.5-turbo 等 87 | 88 | # Request timeout in seconds 89 | timeout: 60 90 | 91 | # Optional: Whether to include LLM evaluation details in the report 92 | include_details: true 93 | # --- LLM-based Metrics (requires LLM evaluation to be enabled) --- 94 | # These metrics will only be calculated if llm_evaluation.enabled is true 95 | # - name: 'llm_sql_skeleton' # LLM评估的SQL骨架正确性 96 | # - name: 'llm_vector_component' # LLM评估的向量组件正确性 97 | # - name: 'llm_overall' # LLM评估的综合得分 98 | 99 | # The final evaluation report will be saved to this file. 100 | 101 | -------------------------------------------------------------------------------- /Data_Synthesizer/pipeline/sqlite/prompt_templates/sqlite_vec_note_prompt.txt: -------------------------------------------------------------------------------- 1 | You are an expert SQLite generator which support `sqlite-vec` extension. Your primary goal is to generate syntactically correct and efficient SQL queries that strictly adhere to the following rules. Your adherence to these rules is mandatory to prevent common errors. 2 | 3 | ### **Rule 1: When to Use Vector Search** 4 | - You **MUST** use a KNN `MATCH` query when a column name ends with `_embedding` or its data type is a vector (e.g., `float[?]`). This is the primary trigger for applying all subsequent vector search rules. 5 | - You **MUST NOT** use KNN search for traditional scalar columns like `id`, `age`, `price`, `gender`, etc. For these, always use standard SQL operators (`=`, `>`, `IN`, `LIKE`). 6 | 7 | ### **Rule 2: The Mandatory Constraint Rule (To prevent "A LIMIT or 'k = ?' constraint is required")** 8 | - Every query block containing a `MATCH` operator **MUST** include a vector search constraint. 9 | - **Strongly prefer using `k=N`**. It should be added as a standard `AND` condition in the `WHERE` clause. This method is the most compatible and clear, especially in complex queries. 10 | - The `k=N` parameter for the vector search is different from a `LIMIT` clause at the end of the entire query. `k=N` finds the N-nearest neighbors first, and the final `LIMIT` then takes a subset of that result. Both can be used in the same query. 11 | - If you must use `LIMIT N` as the vector constraint (e.g., in very simple queries), it **MUST** be placed *immediately* after the `MATCH ...` expression and before any other `AND` clauses. 12 | 13 | ### **Rule 3: The Critical Alias Qualification Rule (To prevent "no such column" errors)** 14 | - This is the most common source of errors in `JOIN` queries. 15 | - When a table in the `MATCH` clause has an alias (e.g., `FROM genre AS g`), you **MUST** qualify both the `k` parameter and the `distance` column with that same alias. 
16 | - **Correct syntax:** 17 | - For the `k` parameter: `g.k = 5` 18 | - For the `distance` column: `SELECT g.distance`, `ORDER BY g.distance` 19 | - This is mandatory for any query using table aliases to ensure the SQL engine knows which table the virtual `distance` column and `k` parameter belong to. 20 | 21 | ### **Rule 4: The `ORDER BY` and `distance` Column** 22 | - The `distance` column is a virtual column generated by the `MATCH` operator that represents similarity. 23 | - When a query contains a `MATCH` operator, the results can **ONLY** be ordered by the `distance` column (e.g., `ORDER BY g.distance`). Do not add other columns to this `ORDER BY` clause. 24 | 25 | ### **Rule 5: The `lembed` Function** 26 | - Always use the `lembed('model_name', 'your_string_content')` function inside the `MATCH` clause to convert your search text into a vector. 27 | - The string content should be a meaningful phrase, sentence, or keyword that is contextually relevant to the column you are searching. 28 | 29 | --- 30 | ### **Golden Example: A `JOIN` Query Demonstrating All Critical Rules** 31 | 32 | This single example illustrates the correct application of the most important rules, especially for complex queries. 33 | 34 | ```sql 35 | -- This is a perfect example of a complex query done correctly. 36 | SELECT 37 | g.g_name, 38 | g.rating, 39 | gi.value AS genre_additional_info, 40 | g.distance -- Rule 3: 'distance' is correctly qualified with the alias 'g'. 41 | FROM 42 | genre AS g -- A table alias 'g' is defined here. 43 | JOIN 44 | genre_info AS gi ON g.g_name = gi.key 45 | WHERE 46 | g.g_name_embedding MATCH lembed('all-MiniLM-L6-v2', "Popular and highly acclaimed music genre") 47 | AND g.k = 5 -- Rule 2 & 3: The mandatory constraint 'k' is used and correctly qualified with 'g'. 48 | ORDER BY 49 | g.distance; -- Rule 4: The query is ordered ONLY by the qualified 'distance' column. 50 | ``` 51 | -------------------------------------------------------------------------------- /Data_Synthesizer/pipeline/sqlite/prompt_templates/find_semantic_rich_column.txt: -------------------------------------------------------------------------------- 1 | You are an expert database analyst. Given a database schema in JSON format, your task is to identify semantically rich columns in each table - those containing descriptive information like information(info), descriptions(desc), notes, title, review, comments, or textual content that conveys meaningful information about entities. 2 | 3 | **Input Structure:** 4 | {{ 5 | "db_id": "database_name", 6 | "table_names_original": ["table1", "table2"], 7 | "table_names": ["table1", "table2"], 8 | "column_names_original": [ 9 | [-1, "*"], 10 | [0, "column1"], 11 | [0, "column2"], 12 | [1, "column3"] 13 | ], 14 | "column_names": [ 15 | [-1, "*"], 16 | [0, "column1"], 17 | [0, "column2"], 18 | [1, "column3"] 19 | ], 20 | "column_types": ["type1", "type2", "type3"], 21 | "primary_keys": [n1], "foreign_keys": [[n1, n2]], 22 | "table_description": {{ 23 | "table1": "description of table1 columns", 24 | "table2": "description of table2 columns" 25 | }}, 26 | "table_samples": {{ 27 | "table1": [{{"column1": "value", "column2": "value"}}], 28 | "table2": [{{"column3": "value"}}] 29 | }} 30 | }} 31 | 32 | **Your Task:** 33 | 1. Your task is to identify semantically rich columns for each table in table_names_original. To do this, you must analyze both the column names and the corresponding sample data provided in table_samples. 
34 | - A column is considered semantically rich if it meets the following criteria: 35 | * Descriptive Column Name. The column's name suggests it contains descriptive text. Examples include description, note, comment, info, or desc. 36 | 37 | * Substantial Textual Content. The corresponding sample data for the column in table_samples consists of a sentence or a paragraph of string data, not just single words or short phrases. 38 | 39 | * Not Numerical or an Identifier. The column's content is primarily textual, not numerical data or identifiers. 40 | - A column is NOT to be considered semantically rich if any of the following exclusion rules apply: 41 | * Specific Names. The column name is name, address. 42 | * Non-descriptive Content. Despite being text, the data in table_samples represents non-descriptive information like codes, IDs, or simple labels. 43 | 44 | 2. Output a JSON object where: 45 | - Keys are table names (from "table_names_original") 46 | - Values are arrays of objects with: 47 | {{ 48 | "column_name": "original_column_name", 49 | "column_type": "column_data_type", 50 | "semantic_type": "detected_semantic_description" 51 | }} 52 | 53 | 3. Semantic categories should be one of: 54 | - "description" (e.g., product_description, note, summary) 55 | - "title" (e.g., title) 56 | - "information" (e.g., information, info) 57 | - "text_content" (e.g., comment, review) 58 | 59 | 4. Focus on columns that: 60 | - Have text-based data types (text, varchar, etc.) 61 | - Contain descriptive information in sample data 62 | 63 | 5. Avoid including: 64 | - Primary/foreign keys used only for relationships 65 | - Technical/identifier columns (IDs, codes, numbers) 66 | - Date/time fields unless explicitly descriptive 67 | - Numerical measurements 68 | - name, title or address 69 | 70 | **Output Example:** 71 | ```json 72 | {{ 73 | "table1": [ 74 | {{"column_name": "product_information", "column_type": "text", "semantic_type": "information"}}, 75 | {{"column_name": "description", "column_type": "text", "semantic_type": "description"}} 76 | ], 77 | "table2": [ 78 | {{"column_name": "title", "column_type": "text", "semantic_type": "title"}} 79 | ] 80 | }} 81 | ``` 82 | 83 | **Special Considerations:** 84 | - Use both column names and sample data to determine semantic richness 85 | - Prefer false positives over missing important semantic columns 86 | - For columns with multiple semantic aspects, choose the most prominent one 87 | - If no semantically rich columns exist in a table, return an empty array for that table 88 | 89 | **Now process this database schema:** 90 | {dababase_schema} 91 | 92 | Let's think step by step 93 | -------------------------------------------------------------------------------- /Data_Synthesizer/README.md: -------------------------------------------------------------------------------- 1 | # 数据合成 (Data_Synthesizer) 2 | 3 | ## 简介 4 | 5 | Data_Synthesizer是一个功能强大的数据合成流水线,专为生成高质量Text2VectorSQL数据集而设计。本模块从基础数据库出发,通过一系列自动化步骤,最终产出包含数据库、自然语言问题、VectorSQL查询以及思维链(Chain-of-Thought)的完整数据集,为训练和评测先进的Text2VectorSQL模型提供支持。 6 | 7 | 该流水线支持多种数据库后端,如 SQLite、PostgreSQL 和 ClickHouse,并集成了向量化能力,使模型能够理解和利用数据中的语义信息。 8 | 9 | ## 核心功能 10 | 11 | - **数据库合成与增强 (`database_synthesis`)**: 从零开始或基于现有Web表格,自动生成结构化的数据库,并可对数据库模式进行增强,增加其复杂性和真实性。 12 | - **数据库向量化 (`vectorization`)**: 识别数据库中的“语义丰富”列(如描述性文本),利用Sentence Transformer模型为这些列生成向量嵌入,并构建支持向量查询的新数据库。这是实现语义搜索的关键。 13 | - **VectorSQL与问题合成 (`synthesis_sql`, `synthesis_nl`)**: 基于(向量化的)数据库模式,自动生成不同复杂度的VectorSQL查询,并为每个VectorSQL查询生成对应的自然语言问题。 14 | - **思维链合成 
(`synthesis_cot`)**: 为每一个“数据库-问题-VectorSQL”三元组,生成详细的推理步骤(即思维链),解释从问题到VectorSQL的推导过程。这对于训练具有更强推理能力的模型至关重要。 15 | - **统一流水线 (`pipeline`)**: 提供一个总控脚本 `general_pipeline.py`,通过简单的配置即可运行完整的端到端数据合成流程,同时也支持对流程中每一步的独立调用和微调。 16 | 17 | ## 目录结构 18 | 19 | ``` 20 | Data_Synthesizer/ 21 | ├─── database_synthesis/ # 从Web表格合成数据库 22 | ├─── pipeline/ # 统一流水线和配置文件 23 | ├─── synthesis_cot/ # 思维链(CoT)合成 24 | ├─── synthesis_eval/ # 为模型生成训练和评估数据 25 | ├─── synthesis_nl/ # 自然语言问题(NL)合成 26 | ├─── synthesis_sql/ # VectorSQL查询合成 27 | ├─── tools/ # 数据集迁移、混合等辅助工具 28 | └─── vectorization/ # 数据库向量化 29 | ``` 30 | 31 | ## 快速开始 32 | 33 | 通过运行统一流水线,您可以最便捷地完成整个数据合成过程。 34 | 35 | 1. **安装依赖**: 36 | ```bash 37 | pip install -r requirements.txt 38 | ``` 39 | 40 | 2. **配置环境**: 41 | - 复制 `pipeline/config.yaml.example` 并重命名为 `pipeline/config.yaml`。 42 | - 在 `config.yaml` 中填入您的 LLM API-Key、Base-URL 以及其他相关配置。 43 | - 在服务器中启动embedding服务,参考文件 `../../Embedding_Service/README.md` 44 | 45 | 3. **配置流水线**: 46 | - 打开 `pipeline/general_pipeline.py` 文件。 47 | - 修改顶部的 `DATASET_BACKEND` 和 `DATASET_TO_LOAD` 变量,以选择您要使用的数据库类型和具体的数据集配置。 48 | ```python 49 | # 示例:选择 clickhouse 后端和名为 synthesis_data_deversity 的数据集 50 | DATASET_BACKEND = "clickhouse" 51 | DATASET_TO_LOAD = "synthesis_data_deversity" 52 | ``` 53 | 54 | 4. **运行流水线**: 55 | ```bash 56 | python pipeline/general_pipeline.py 57 | ``` 58 | 脚本将自动执行所有步骤,包括数据库向量化、VectorSQL生成、问题生成等。最终的合成数据集和向量数据库将保存在 `config.yaml` 中为该数据集配置的 `result_path` 路径下。 59 | 60 | ## 分步执行 61 | 62 | 如果您希望更精细地控制每一步,可以按照以下顺序手动执行各个子模块的脚本。 63 | 64 | ### 第1步: 数据库向量化 (`vectorization`) 65 | 66 | 此步骤为现有数据库添加向量信息。 67 | 68 | 1. **生成基础Schema**: 69 | ```bash 70 | python vectorization/generate_schema.py --db-dir <数据库目录> --output-file <输出的tables.json路径> 71 | ``` 72 | 2. **(可选)为Schema填充样本数据**: 73 | ```bash 74 | python vectorization/enhance_tables_json.py ... 75 | ``` 76 | 3. **寻找语义丰富的列**: 77 | ```bash 78 | python vectorization/find_semantic_rich_column.py ... 79 | ``` 80 | 4. **批量向量化**: 81 | 为语义列生成向量嵌入,并创建初始的向量数据库脚本。 82 | ```bash 83 | python vectorization/batch_vectorize_databases.py ... 84 | ``` 85 | 5. **生成最终向量数据库**: 86 | 使用上一步生成的脚本,构建最终的SQLite向量数据库。 87 | ```bash 88 | python vectorization/vector_database_generate.py ... 89 | ``` 90 | 91 | ### 第2步: SQL查询合成 (`synthesis_sql`) 92 | 93 | 1. **生成VectorSQL合成提示**: 94 | ```bash 95 | python synthesis_sql/generate_sql_synthesis_prompts.py ... 96 | ``` 97 | 2. **调用LLM合成VectorSQL**: 98 | ```bash 99 | python synthesis_sql/synthesize_sql.py ... 100 | ``` 101 | 3. **后处理与筛选**: 102 | 验证VectorSQL的正确性,去除无效和重复的查询。 103 | ```bash 104 | python synthesis_sql/post_process_sqls.py ... 105 | ``` 106 | 107 | ### 第3步: 自然语言问题合成 (`synthesis_nl`) 108 | 109 | 1. **生成问题合成提示**: 110 | ```bash 111 | python synthesis_nl/generate_question_synthesis_prompts.py ... 112 | ``` 113 | 2. **调用LLM合成问题**: 114 | ```bash 115 | python synthesis_nl/synthesize_question.py ... 116 | ``` 117 | 3. **后处理与筛选**: 118 | 通过语义一致性筛选,确保问题与VectorSQL查询高度匹配。 119 | ```bash 120 | python synthesis_nl/post_process_questions.py ... 121 | ``` 122 | 123 | ### 第4步: 思维链合成 (`synthesis_cot`) 124 | 125 | 1. **生成CoT合成提示**: 126 | ```bash 127 | python synthesis_cot/generate_cot_synthesis_prompts.py ... 128 | ``` 129 | 2. **调用LLM合成CoT**: 130 | ```bash 131 | python synthesis_cot/synthesize_cot.py ... 132 | ``` 133 | 3. **后处理与筛选**: 134 | 通过执行验证和投票机制,选出最可靠的思维链。 135 | ```bash 136 | python synthesis_cot/post_process_cot.py ... 
137 | ``` 138 | 139 | 最终,您将得到一个包含“数据库、问题、VectorSQL、思维链”的完整数据集,可用于模型训练。 140 | 141 | ## 依赖安装 142 | 143 | 在运行任何脚本之前,请确保已安装所有必需的Python库。 144 | 145 | ```bash 146 | pip install -r requirements.txt 147 | ``` 148 | -------------------------------------------------------------------------------- /Data_Synthesizer/pipeline/postgresql/prompt_templates/postgresql_vec_note_prompt.txt: -------------------------------------------------------------------------------- 1 | There are a few requirements you should comply with in addition: 2 | 1. When generating SQL queries, you should prioritize utilizing K-Nearest Neighbor (KNN) searches whenever contextually appropriate. However, you must avoid unnecessary/forced KNN implementations for: 3 | -- Traditional relational data queries (especially for columns like: id, age, price). 4 | -- Cases where standard SQL operators (equality, range, or aggregation functions) are more efficient and semantically appropriate. 5 | 2. Only columns with the vector data type support KNN queries. The names of these vector columns often end with "_embedding". You should perform KNN searches when the column name you need to query ends with "_embedding" or is otherwise identified as a vector column. 6 | 3. In PostgreSQL with the pgvector extension, vector similarity search is performed using distance operators. You must calculate the distance in the SELECT clause and give it an alias, typically "AS distance". The primary operators are: 7 | <->: L2 (Euclidean) distance. This is the most common operator and is recommended for general-purpose similarity search. 8 | <#>: Negative inner product. Can be used to find the maximum inner product similarity. 9 | 4. The lembed function is used to transform a query string into a semantic vector. This function has two parameters: the first is the name of the embedding model used (default value: '{embedding_model}'), and the second is the string content to embed. 10 | 5. You must generate plausible and semantically relevant words or sentences for the second parameter of the lembed function based on the column's name, type, and comment. For example, if a column is named review_embedding and its comment is "Embedding of the user's review text", you could generate text like "an overwhelmingly positive experience with great customer service". 11 | 6. Every KNN search query MUST conclude with "ORDER BY distance LIMIT N" to retrieve the top-N most similar results. pgvector uses this pattern to leverage specialized indexes (like HNSW) for extremely fast and efficient KNN searches. The LIMIT clause is mandatory. 12 | 7. When combining a vector search with standard JOIN or WHERE clauses, these filters are applied to pre-filter the dataset. The KNN search (ORDER BY distance LIMIT N) is then performed on the filtered results. 13 | 8. A SELECT statement should typically contain only one vector operator to perform a single KNN search. However, subqueries can perform their own independent KNN searches, each with its own distance calculation, ORDER BY, and LIMIT clause. 14 | 9. **CRITICAL RULE: Clarification on Clause Roles for Vector Search.** The roles of SQL clauses in a pgvector query are very specific and must not be confused: 15 | * **`SELECT` clause**: Its role is to **display** the calculated distance. This is where you use `AS distance`. This is optional; you don't have to show the distance. 16 | * **`WHERE` clause**: Its role is to **filter** rows *before* ranking. The expression here MUST return a boolean (`true`/`false`). 
For example, you can filter for candidates where the distance is below a certain threshold (`vector_column <-> lembed(...) < 0.5`). 17 | * **`ORDER BY` clause**: This is the **engine** of the KNN search. Its role is to **rank** the results by similarity. This is where the vector operator does its main job. 18 | 10. **CRITICAL RULE: All table and column names are case-sensitive. You MUST enclose all table and column identifiers in double quotes (") to preserve their original case as defined in the schema. For example, query Paragraphs as \"Paragraphs\". For another example, query column name of table with aliases as \"Headings\".\"heading_text_embedding\". 19 | 11. **FORBIDDEN PATTERN: Never define a distance alias in the `WHERE` clause.** The following pattern is syntactically invalid SQL and is strictly forbidden. The model must never generate it: 20 | * **WRONG**: `WHERE "embedding_column" <-> lembed(...) AS "distance"` 21 | * **WHY IT'S WRONG**: The `WHERE` clause is for filtering conditions, not for creating aliases. This will always cause a PostgreSQL syntax error. 22 | 23 | ## Example of a PostgreSQL (pgvector) KNN Query 24 | DB Schema: Some table on articles with a column content_embedding vector(384). 25 | Embedding Model: laion/CLIP-ViT-B-32-laion2B-s34B-b79K. 26 | Query Task: Identify the article ID of the single most relevant article discussing innovative algorithms in graph theory. 27 | Generated SQL: 28 | ```sql 29 | SELECT \"p\".\"paragraph_id\", \"p\".\"article_id\", \"p\".\"text\", \"p\".\"text_embedding\" <-> lembed('laion/CLIP-ViT-B-32-laion2B-s34B-b79K', 'The story unfolds in the bustling city of New York, where characters navigate complex social dynamics.') AS \"distance\"\nFROM \"Paragraphs\" \"p\"\nORDER BY \"distance\"\nLIMIT 5; 30 | ``` 31 | -------------------------------------------------------------------------------- /Evaluation_Framework/script/config.yaml.example: -------------------------------------------------------------------------------- 1 | # ==================================================================== 2 | # 通用数据库配置 (Universal Database Configuration) 3 | # ==================================================================== 4 | sqlite: 5 | # ==================================================================== 6 | # 数据集: toy_spider 7 | # ==================================================================== 8 | toy_spider: 9 | services: &common_services 10 | vllm: 11 | api_url: "http://127.0.0.1:8000/v1" 12 | API_KEY: "none" 13 | model_name: "/mnt/b_public/data/ydw/model/Qwen/Qwen2.5-72B-Instruct" 14 | 15 | openai: 16 | api_key: &common_api_key "sk-xxxx" 17 | base_url: "http://123.129.219.111:3000/v1" 18 | llm_model_name: "gpt-4o" 19 | embedding_model_name: "all-MiniLM-L6-v2" 20 | 21 | # ------------------------------------------------------------------ 22 | # 路径配置 - 在这里使用锚点 (&) 定义一个可复用的代码块 23 | # ------------------------------------------------------------------ 24 | paths: &common_paths # <-- (1) 使用 &common_paths 定义锚点 25 | input_file_to_id: "../../Data_Synthesizer/pipeline/sqlite/results/toy_spider/candidate_sql.json" 26 | dataset_json_path: "../sqlite/results/toy_spider/candidate_sql_query_id.json" 27 | ground_truth_output_path: "../ground_truth.json" 28 | tables_json_path: "../../Data_Synthesizer/pipeline/sqlite/results/toy_spider/embedding_table_vector.json" 29 | prompt_tamplate_path: "../prompt_templates/sql_generate_prompt_template.txt" 30 | output_prompt_path: "../sqlite/prompts/sql_generate_prompts.json" 31 | database_note_prompt_path: 
"../prompt_templates/sqlite_vec_note_prompt.txt" 32 | 33 | sql_prompt_file_path: "../sqlite/prompts/sql_generate_prompts.json" 34 | eval_input_path: "../eval_queries.json" 35 | cache_file_path_sql: "../cache/sqlite/synthesis_sql_cache.jsonl" 36 | 37 | parameters: &common_parameters 38 | dataset_backend: "sqlite" 39 | use_vllm: False #True则使用openai服务,否则使用VLLM服务 40 | max_workers: 32 41 | no_parallel: false 42 | num_cpus: 10 43 | sql_exec_timeout: 60 44 | num_candidates: 5 45 | 46 | # ==================================================================== 47 | # 合成数据: synthesis_data 48 | # ==================================================================== 49 | synthesis_data: 50 | services: *common_services 51 | 52 | # ------------------------------------------------------------------ 53 | # 直接使用别名 (*) 复用上面定义的 paths 代码块 54 | # ------------------------------------------------------------------ 55 | paths: *common_paths # <-- (2) 使用 *common_paths 引用锚点 56 | 57 | parameters: *common_parameters 58 | 59 | # ==================================================================== 60 | # 数据集: bird 61 | # ==================================================================== 62 | bird: 63 | # services 和 parameters 可以是 bird 数据集特有的 64 | services: *common_services 65 | 66 | # ------------------------------------------------------------------ 67 | # 直接使用别名 (*) 复用上面定义的 paths 代码块 68 | # ------------------------------------------------------------------ 69 | paths: *common_paths # <-- (2) 使用 *common_paths 引用锚点 70 | 71 | parameters: *common_parameters 72 | 73 | # ==================================================================== 74 | # 数据集: arxiv 75 | # ==================================================================== 76 | arxiv: 77 | # services 和 parameters 可以是 bird 数据集特有的 78 | services: *common_services 79 | 80 | # ------------------------------------------------------------------ 81 | # 直接使用别名 (*) 复用上面定义的 paths 代码块 82 | # ------------------------------------------------------------------ 83 | paths: *common_paths # <-- (2) 使用 *common_paths 引用锚点 84 | 85 | parameters: 86 | max_workers: 32 87 | no_parallel_find_semantic_rich: false 88 | num_cpus: 10 89 | sql_exec_timeout: 60 90 | num_candidates: 5 91 | sql_number: 75 # this number is vary big, because the semantically rich column number of this dataset is vary small 92 | 93 | # ==================================================================== 94 | # 数据集: spider 95 | # ==================================================================== 96 | spider: 97 | # services 和 parameters 可以是 bird 数据集特有的 98 | services: *common_services 99 | 100 | # ------------------------------------------------------------------ 101 | # 直接使用别名 (*) 复用上面定义的 paths 代码块 102 | # ------------------------------------------------------------------ 103 | paths: *common_paths # <-- (2) 使用 *common_paths 引用锚点 104 | 105 | parameters: 106 | max_workers: 32 107 | no_parallel_find_semantic_rich: false 108 | num_cpus: 10 109 | sql_exec_timeout: 60 110 | num_candidates: 5 111 | sql_number: 2 #这个参数特别重要!控制最终文件的大小。 112 | 113 | # ==================================================================== 114 | # 数据集: wikipedia_multimodal 115 | # ==================================================================== 116 | -------------------------------------------------------------------------------- /Data_Synthesizer/vectorization/generate_schema.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import sqlite3 4 | import argparse 5 | from tqdm 
import tqdm 6 | 7 | def get_schema_for_db(db_path): 8 | """ 9 | 连接到单个 SQLite 数据库,检查其 schema,并返回一个字典。 10 | 11 | Args: 12 | db_path (str): 数据库文件的路径。 13 | 14 | Returns: 15 | dict: 包含数据库 schema 信息的字典,如果出错则返回 None。 16 | """ 17 | try: 18 | # 从文件路径中提取 db_id 19 | # 修正:db_id 应该是数据库所在的文件夹名,与 Spider 数据集格式保持一致 20 | db_id = os.path.basename(os.path.dirname(db_path)) 21 | 22 | # 初始化 schema 字典结构 23 | schema = { 24 | "db_id": db_id, 25 | "table_names_original": [], 26 | "table_names": [], 27 | "column_names_original": [[-1, "*"]], # 包含通配符 '*' 28 | "column_names": [[-1, "*"]], 29 | "column_types": ["text"], 30 | "primary_keys": [], 31 | "foreign_keys": [] 32 | } 33 | 34 | conn = sqlite3.connect(db_path) 35 | cursor = conn.cursor() 36 | 37 | # 1. 获取所有表名 38 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';") 39 | table_names = [row[0] for row in cursor.fetchall()] 40 | schema["table_names_original"] = table_names 41 | schema["table_names"] = table_names 42 | 43 | # 用于快速查找表名对应的索引 44 | table_name_to_idx = {name: i for i, name in enumerate(table_names)} 45 | 46 | # 用于后续查找外键时,快速定位列的全局索引 47 | column_to_global_idx = {} 48 | current_col_idx = 1 # 从 1 开始,因为 0 被 '*' 占用 49 | 50 | # 2. 遍历每个表,获取列信息和主键 51 | for table_idx, table_name in enumerate(table_names): 52 | cursor.execute(f'PRAGMA table_info("{table_name}");') 53 | columns_info = cursor.fetchall() 54 | 55 | for col in columns_info: 56 | # col 格式: (cid, name, type, notnull, dflt_value, pk) 57 | col_name = col[1] 58 | col_type = col[2].upper() # 保持与 Spider 数据集一致,通常为大写 59 | is_primary_key = col[5] == 1 60 | 61 | # 添加列信息 62 | schema["column_names_original"].append([table_idx, col_name]) 63 | schema["column_names"].append([table_idx, col_name]) 64 | schema["column_types"].append(col_type) 65 | 66 | # 记录主键 67 | if is_primary_key: 68 | schema["primary_keys"].append(current_col_idx) 69 | 70 | # 建立 (表名, 列名) -> 全局索引 的映射 71 | column_to_global_idx[(table_name, col_name)] = current_col_idx 72 | current_col_idx += 1 73 | 74 | # 3. 
遍历每个表,获取外键信息 75 | for table_idx, table_name in enumerate(table_names): 76 | cursor.execute(f'PRAGMA foreign_key_list("{table_name}");') 77 | foreign_keys_info = cursor.fetchall() 78 | 79 | for fk in foreign_keys_info: 80 | # fk 格式: (id, seq, table, from, to, on_update, on_delete, match) 81 | from_column = fk[3] 82 | to_table = fk[2] 83 | to_column = fk[4] 84 | 85 | # 查找源列和目标列的全局索引 86 | from_col_idx = column_to_global_idx.get((table_name, from_column)) 87 | to_col_idx = column_to_global_idx.get((to_table, to_column)) 88 | 89 | if from_col_idx is not None and to_col_idx is not None: 90 | schema["foreign_keys"].append([from_col_idx, to_col_idx]) 91 | 92 | conn.close() 93 | return schema 94 | 95 | except sqlite3.Error as e: 96 | # 使用 os.path.basename(db_path) 使得错误信息更清晰 97 | print(f" [错误] 处理数据库 {os.path.basename(db_path)} 失败: {e}") 98 | return None 99 | 100 | 101 | def generate_schema(db_dir,output_file): 102 | """ 103 | 主函数,为指定目录中的所有数据库生成 schema。 104 | """ 105 | if not os.path.isdir(db_dir): 106 | print(f"✖ 错误:目录 '{db_dir}' 不存在。") 107 | return 108 | 109 | print(f"🚀 开始从目录 '{db_dir}' 及其子目录中递归查找数据库...") 110 | 111 | # --- 主要修改部分 --- 112 | # 使用 os.walk() 递归遍历目录以查找所有数据库文件 113 | db_files = [] 114 | for root, dirs, files in os.walk(db_dir): 115 | for file in files: 116 | if file.endswith(('.sqlite', '.db')): 117 | db_files.append(os.path.join(root, file)) 118 | # --- 修改结束 --- 119 | 120 | if not db_files: 121 | print("🟡 警告:在指定目录及其所有子目录中均未找到 .sqlite 或 .db 文件。") 122 | return 123 | 124 | # 按照字母顺序处理,保证输出结果的稳定性 125 | db_files.sort() 126 | 127 | all_schemas = [] 128 | for db_path in tqdm(db_files, desc="处理数据库中"): 129 | schema_data = get_schema_for_db(db_path) 130 | if schema_data: 131 | all_schemas.append(schema_data) 132 | 133 | # 确保输出目录存在 134 | output_dir = os.path.dirname(output_file) 135 | if output_dir: 136 | os.makedirs(output_dir, exist_ok=True) 137 | 138 | try: 139 | with open(output_file, 'w', encoding='utf-8') as f: 140 | json.dump(all_schemas, f, indent=4, ensure_ascii=False) # indent=4 格式更美观 141 | print(f"\n✔ 成功创建 schema 文件 '{output_file}',包含 {len(all_schemas)} 个数据库。") 142 | except IOError as e: 143 | print(f"\n✖ 写入输出文件失败: {e}") 144 | 145 | 146 | # if __name__ == '__main__': 147 | # main() 148 | -------------------------------------------------------------------------------- /Data_Synthesizer/tools/change_embedding_model.py: -------------------------------------------------------------------------------- 1 | import ijson 2 | import json 3 | import random 4 | import argparse 5 | from tqdm import tqdm 6 | import os 7 | 8 | def create_model_list(): 9 | """ 10 | 创建包含带前缀和不带前缀的40个模型名称的列表。 11 | """ 12 | base_models = [ 13 | # 商业闭源模型 14 | "OpenAI/text-embedding-3-large", 15 | "OpenAI/text-embedding-3-small", 16 | "OpenAI/text-embedding-ada-02", 17 | "Voyage-AI/voyage-large-2", 18 | "Voyage-AI/voyage-code-2", 19 | "Voyage-AI/voyage-2", 20 | "Google/text-embedding-004", 21 | "Google/text-embedding-gecko@003", 22 | "Cohere/embed-english-v3.0", 23 | "Cohere/embed-multilingual-v3.0", 24 | 25 | # 开源顶级性能模型 26 | "BAAI/bge-large-en-v1.5", 27 | "NVIDIA/NV-Embed-v2", 28 | "Alibaba-NLP/gte-Qwen2-7B-instruct", 29 | "intfloat/E5-Mistral-7B-Instruct", 30 | "Salesforce/SFR-Embedding-2_R", 31 | "nomic-ai/nomic-embed-text-v1.5", 32 | "intfloat/e5-large-v2", 33 | "Alibaba-NLP/gte-large", 34 | "hkunlp/instructor-xl", 35 | 36 | # 开源高效与经典模型 37 | "sentence-transformers/all-mpnet-base-v2", 38 | "sentence-transformers/all-MiniLM-L6-v2", 39 | "sentence-transformers/all-MiniLM-L12-v2", # 新增:L6 的稍大(12层)版本,性能更强 40 | 
"sentence-transformers/msmarco-distilbert-base-v4", # 新增:专为语义搜索(MS MARCO)优化的经典模型 41 | "princeton-nlp/sup-simcse-bert-base-uncased", # 新增:SimCSE,对比学习领域的经典之作 42 | "intfloat/e5-base-v2", # 新增:E5 系列的 base 版本 43 | "BAAI/bge-base-en-v1.5", 44 | "BAAI/bge-small-en-v1.5", 45 | "Alibaba-NLP/gte-base", 46 | "jina-ai/jina-embeddings-v2-base-en", 47 | "Grit-AI/g-gt-large", 48 | 49 | # 开源多语言模型 50 | "BAAI/bge-m3", 51 | "intfloat/multilingual-e5-large", 52 | "intfloat/multilingual-e5-base", # 新增:多语言 E5 的 base 版本 53 | "sentence-transformers/paraphrase-multilingual-mpnet-base-v2", 54 | "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", # 新增:高效的多语言 MiniLM 55 | "sentence-transformers/distiluse-base-multilingual-cased-v1", # 新增:SBERT 经典的多语言 DistilUSE 56 | "sentence-transformers/LaBSE", 57 | "google-bert/bert-base-multilingual-cased" 58 | ] 59 | 60 | full_list = [] 61 | for model in base_models: 62 | # 添加带前缀的完整名称 63 | full_list.append(model) 64 | # 如果有前缀,则添加不带前缀的名称 65 | if "/" in model: 66 | unprefixed_name = model.split('/')[-1] 67 | full_list.append(unprefixed_name) 68 | 69 | # 去重后返回,确保唯一性 70 | return list(set(full_list)) 71 | 72 | def process_large_json(input_path, output_path): 73 | """ 74 | 流式处理大型JSON文件,替换指定字符串。 75 | 76 | :param input_path: 输入的JSON文件路径。 77 | :param output_path: 输出的JSON文件路径。 78 | """ 79 | 80 | target_string = "all-MiniLM-L6-v2" 81 | model_replacements = create_model_list() 82 | 83 | print(f"输入文件: {input_path}") 84 | print(f"输出文件: {output_path}") 85 | print(f"将从 {len(model_replacements)} 个候选项中随机选择模型进行替换...") 86 | 87 | try: 88 | # 使用 ijson 流式读取顶层数组的每个元素 89 | with open(input_path, 'rb') as f_in, open(output_path, 'w', encoding='utf-8') as f_out: 90 | # 开始写入JSON数组 91 | f_out.write('[\n') 92 | 93 | is_first_item = True 94 | # 使用 ijson.items 解析文件,'item' 表示我们期望一个数组作为根元素 95 | parser = ijson.items(f_in, 'item') 96 | 97 | # 使用 tqdm 显示进度条 98 | for item in tqdm(parser, desc="正在处理JSON对象"): 99 | # 确保 item 是一个字典 100 | if not isinstance(item, dict): 101 | continue 102 | 103 | # 为当前这一个字典元素,只选择一个随机的替换值 104 | chosen_replacement = random.choice(model_replacements) 105 | 106 | updated_item = {} 107 | for key, value in item.items(): 108 | # 检查字段值是否为字符串 109 | if isinstance(value, str): 110 | # 替换所有出现的子字符串 111 | updated_item[key] = value.replace(target_string, chosen_replacement) 112 | else: 113 | # 如果不是字符串,保持原样 114 | updated_item[key] = value 115 | 116 | # 写入处理后的对象 117 | if not is_first_item: 118 | f_out.write(',\n') 119 | 120 | json.dump(updated_item, f_out, ensure_ascii=False, indent=2) 121 | is_first_item = False 122 | 123 | # 结束JSON数组 124 | f_out.write('\n]') 125 | 126 | print("\n处理完成!") 127 | 128 | except Exception as e: 129 | print(f"\n处理过程中发生错误: {e}") 130 | # 如果出错,可能需要清理不完整的输出文件 131 | if os.path.exists(output_path): 132 | # os.remove(output_path) # 可选:如果希望出错时删除不完整的文件 133 | print(f"注意: 输出文件 '{output_path}' 可能不完整。") 134 | 135 | 136 | if __name__ == "__main__": 137 | # 直接在代码中定义文件路径 138 | input_file_path = "/mnt/DataFlow/ydw/Text2VectorSQL/Data_Synthesizer/tools/results/collected_input_llm.json" 139 | output_file_path = "/mnt/DataFlow/ydw/Text2VectorSQL/LLaMA-Factory/data/myscale_synthesis_data.json" 140 | 141 | # 确保文件存在,给出更友好的提示 142 | if not os.path.exists(input_file_path): 143 | print(f"错误: 输入文件未找到 -> {input_file_path}") 144 | else: 145 | # 直接调用处理函数 146 | process_large_json(input_file_path, output_file_path) 147 | -------------------------------------------------------------------------------- /Data_Synthesizer/tools/mix_datasets.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import os 4 | import sys 5 | 6 | def load_json_file(filepath): 7 | """常规加载.json文件(适用于小文件)。""" 8 | try: 9 | with open(filepath, 'r', encoding='utf-8') as f: 10 | return json.load(f) 11 | except FileNotFoundError: 12 | print(f"❌ 错误: 文件未找到 -> {filepath}。") 13 | sys.exit(1) 14 | except json.JSONDecodeError: 15 | print(f"❌ 错误: 文件 '{filepath}' 不是有效的JSON格式。") 16 | sys.exit(1) 17 | 18 | def count_lines_in_file(filepath): 19 | """快速计算文件行数,无需加载到内存。""" 20 | print(f" -> 正在快速计算 '{os.path.basename(filepath)}' 的总行数...") 21 | count = 0 22 | with open(filepath, 'r', encoding='utf-8') as f: 23 | for _ in f: 24 | count += 1 25 | return count 26 | 27 | def reservoir_sample_jsonl(filepath, k): 28 | """ 29 | 使用蓄水池抽样从一个大的 .jsonl 文件中高效地随机抽取 k 个样本。 30 | 这只会在内存中保留 k 个元素。 31 | """ 32 | print(f" -> 正在从 '{os.path.basename(filepath)}' 中进行蓄水池抽样...") 33 | reservoir = [] 34 | with open(filepath, 'r', encoding='utf-8') as f: 35 | for i, line in enumerate(f): 36 | if i < k: 37 | # 1. 直接填满蓄水池 38 | reservoir.append(json.loads(line)) 39 | else: 40 | # 2. 以 k/i 的概率替换旧元素 41 | j = random.randint(0, i) 42 | if j < k: 43 | reservoir[j] = json.loads(line) 44 | return reservoir 45 | 46 | def process_and_mix_datasets(file1_path, file2_path, output_dir, ratios): 47 | """ 48 | 根据给定的比例,高效混合两个JSON文件的数据。 49 | """ 50 | # --- 1. 识别文件类型并获取大小 --- 51 | # 我们假设 .jsonl 文件是潜在的大文件,另一个是小文件 52 | if file1_path.endswith('.jsonl') and file2_path.endswith('.json'): 53 | large_file_path, small_file_path = file1_path, file2_path 54 | large_file_is_file1 = True 55 | elif file2_path.endswith('.jsonl') and file1_path.endswith('.json'): 56 | large_file_path, small_file_path = file2_path, file1_path 57 | large_file_is_file1 = False 58 | else: 59 | print("❌ 错误: 脚本需要一个 .json 文件和一个 .jsonl 文件才能进行优化。") 60 | print(f" 文件1: {file1_path}") 61 | print(f" 文件2: {file2_path}") 62 | sys.exit(1) 63 | 64 | print(f"识别到小文件 (完全加载): {os.path.basename(small_file_path)}") 65 | print(f"识别到大文件 (流式处理): {os.path.basename(large_file_path)}") 66 | 67 | # 加载小文件数据,并获取大文件的行数 68 | small_data = load_json_file(small_file_path) 69 | large_file_len = count_lines_in_file(large_file_path) 70 | 71 | print(f" -> 小文件加载了 {len(small_data)} 条数据。") 72 | print(f" -> 大文件共有 {large_file_len} 条数据。") 73 | 74 | # --- 2. 确定基准数量 --- 75 | # 逻辑保持不变:基准数量由两个文件中较小的那一个决定 76 | if len(small_data) <= large_file_len: 77 | base_size = len(small_data) 78 | base_is_from_small_file = True 79 | print(f"\n抽样基准由小文件决定,数量 = {base_size}") 80 | else: 81 | base_size = large_file_len 82 | base_is_from_small_file = False 83 | print(f"\n抽样基准由大文件决定,数量 = {base_size}") 84 | 85 | if base_size == 0: 86 | print("❌ 错误: 基准文件中没有数据,无法进行混合。") 87 | sys.exit(1) 88 | 89 | # --- 3. 创建输出目录 (您的代码是正确的) --- 90 | os.makedirs(output_dir, exist_ok=True) 91 | print(f"结果将保存到目录: '{output_dir}'") 92 | 93 | # --- 4. 
遍历比例进行处理 --- 94 | for ratio_str in ratios: 95 | # (解析比例字符串的代码与您的一样,所以这里省略了重复部分) 96 | r1_str, r2_str = ratio_str.split(':') 97 | r1, r2 = int(r1_str), int(r2_str) 98 | print(f"\n--- 开始处理比例 {r1}:{r2} ---") 99 | 100 | # 确定r_small, r_large的值 101 | if base_is_from_small_file: 102 | # 根据哪个文件是file1/file2来分配比例值 103 | r_small = r2 if large_file_is_file1 else r1 104 | r_large = r1 if large_file_is_file1 else r2 105 | else: 106 | r_small = r1 if large_file_is_file1 else r2 107 | r_large = r2 if large_file_is_file1 else r1 108 | 109 | # 计算抽样数量 110 | if r_small == 0 and r_large > 0: 111 | n_base = 0; n_other = base_size 112 | elif r_large == 0 and r_small > 0: 113 | n_base = base_size; n_other = 0 114 | elif r_small > 0 and r_large > 0: 115 | n_base = base_size 116 | n_other = int(base_size * (r_large / r_small)) 117 | else: 118 | n_base = 0; n_other = 0 119 | 120 | # 将抽样数映射回 small_data 和 large_file 121 | n_small, n_large = (n_base, n_other) if base_is_from_small_file else (n_other, n_base) 122 | 123 | # 防止抽样数超过实际数据量 124 | n_small = min(n_small, len(small_data)) 125 | n_large = min(n_large, large_file_len) 126 | 127 | print(f"计划抽样: {n_small}条 from '{os.path.basename(small_file_path)}', {n_large}条 from '{os.path.basename(large_file_path)}'") 128 | 129 | # 进行抽样 130 | sample_small = random.sample(small_data, n_small) 131 | sample_large = reservoir_sample_jsonl(large_file_path, n_large) 132 | 133 | # 合并并打乱 134 | combined_data = sample_small + sample_large 135 | random.shuffle(combined_data) 136 | print(f"混合后总数据量: {len(combined_data)}") 137 | 138 | # 确定哪个是file1的样本 139 | sample_f1, sample_f2 = (sample_large, sample_small) if large_file_is_file1 else (sample_small, sample_large) 140 | 141 | # 保存到文件 142 | output_filename = f"input_llm_{r1}_{r2}.json" 143 | output_path = os.path.join(output_dir, output_filename) 144 | try: 145 | with open(output_path, 'w', encoding='utf-8') as f: 146 | json.dump(combined_data, f, indent=2, ensure_ascii=False) 147 | print(f"✅ 成功保存到: {output_path}") 148 | except IOError as e: 149 | print(f"❌ 错误: 无法写入文件 {output_path}。错误信息: {e}") 150 | 151 | print("\n🎉 所有任务处理完毕!") 152 | 153 | 154 | if __name__ == "__main__": 155 | # --- 请在这里配置 --- 156 | 157 | # 1. 输入文件路径 158 | # 脚本会自动识别哪个是.json,哪个是.jsonl 159 | FILE_1_PATH = "/mnt/b_public/data/ydw/Text2VectorSQL/Data_Synthesizer/pipeline/sqlite/results/synthesis_data_deversity/input_llm.json" 160 | FILE_2_PATH = "/mnt/b_public/data/ydw/datasets/input_llm.jsonl" 161 | 162 | # 2. 输出目录 163 | OUTPUT_DIR = "./results/mixed_datasets" 164 | 165 | # 3. 
需要生成的混合比例列表 166 | RATIOS_TO_GENERATE = [ 167 | "1:0", "0:1", "1:1", "1:2", "2:1", "1:4", "4:1" 168 | ] 169 | 170 | # --- 配置结束,运行脚本 --- 171 | process_and_mix_datasets(FILE_1_PATH, FILE_2_PATH, OUTPUT_DIR, RATIOS_TO_GENERATE) 172 | -------------------------------------------------------------------------------- /Data_Synthesizer/database_synthesis/embedding_schema.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import concurrent.futures 3 | import json 4 | import os 5 | import re 6 | from functools import partial 7 | from json_repair import json_repair 8 | from tqdm import tqdm 9 | import hashlib 10 | from tenacity import retry, stop_after_attempt, wait_exponential 11 | import openai 12 | 13 | # 配置缓存目录 14 | CACHE_DIR = "./cache" 15 | os.makedirs(CACHE_DIR, exist_ok=True) 16 | 17 | def get_cache_key(prompt, model): 18 | """生成基于提示内容和模型名称的唯一缓存键""" 19 | key_str = f"{model}_{prompt}" 20 | return hashlib.md5(key_str.encode('utf-8')).hexdigest() 21 | 22 | def load_from_cache(cache_key): 23 | """从缓存加载数据""" 24 | cache_file = os.path.join(CACHE_DIR, f"{cache_key}.json") 25 | if os.path.exists(cache_file): 26 | with open(cache_file, 'r', encoding='utf-8') as f: 27 | return json.load(f) 28 | return None 29 | 30 | def save_to_cache(cache_key, data): 31 | """保存数据到缓存""" 32 | cache_file = os.path.join(CACHE_DIR, f"{cache_key}.json") 33 | with open(cache_file, 'w', encoding='utf-8') as f: 34 | json.dump(data, f, ensure_ascii=False, indent=2) 35 | 36 | def parse_response(response): 37 | """保持原有解析函数不变""" 38 | schema_pattern = r'```json\s*([\s\S]*?)\s*```' 39 | 40 | try: 41 | enhanced_schema_match = re.search(schema_pattern, response, re.DOTALL) 42 | enhanced_schema_str = enhanced_schema_match.group(0).strip() if enhanced_schema_match else None 43 | enhanced_schema_dict = json_repair.loads(enhanced_schema_str) 44 | return enhanced_schema_dict 45 | except Exception as e: 46 | print(response) 47 | print("Parsing Exception:", str(e)) 48 | return None 49 | 50 | def parse_prompt(prompt): 51 | """保持原有解析函数不变""" 52 | domain_pattern = r'(?<=\*\*Business Domain:\*\*)(.*?)(?=\*\*Business Scenario:\*\*)' 53 | scenario_pattern = r'(?<=\*\*Business Scenario:\*\*)(.*?)(?=\*\*Initial Database Schema:\*\*)' 54 | 55 | domain_match = re.search(domain_pattern, prompt, re.DOTALL) 56 | domain = domain_match.group(0).strip() if domain_match else None 57 | 58 | scenario_match = re.search(scenario_pattern, prompt, re.DOTALL) 59 | scenario = scenario_match.group(0).strip() if scenario_match else None 60 | 61 | return domain, scenario 62 | 63 | @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10)) 64 | def generate_response(client, model, prompt): 65 | """带重试机制的API调用函数(仅修改max_tokens)""" 66 | try: 67 | response = client.chat.completions.create( 68 | model=model, 69 | messages=[ 70 | {"role": "system", "content": "你是一个专业的数据库架构师,负责生成带有embedding列的数据库模式"}, 71 | {"role": "user", "content": prompt} 72 | ], 73 | temperature=0.7, 74 | max_tokens=10000 # 修改为10000 75 | ) 76 | return response.choices[0].message.content 77 | except Exception as e: 78 | print(f"API调用失败: {str(e)}") 79 | raise 80 | 81 | def process_prompt(prompt, model, client): 82 | """处理单个提示的核心逻辑""" 83 | cache_key = get_cache_key(prompt, model) 84 | 85 | # 检查缓存 86 | cached = load_from_cache(cache_key) 87 | if cached: 88 | return cached 89 | 90 | # 生成响应 91 | response = generate_response(client, model, prompt) 92 | 93 | # 解析响应 94 | enhanced_schema = parse_response(response) 95 | if not 
enhanced_schema: 96 | return None 97 | 98 | # 解析提示 99 | domain, scenario = parse_prompt(prompt) 100 | 101 | # 构建结果 102 | result = { 103 | "prompt": prompt, 104 | "domain": domain, 105 | "scenario": scenario, 106 | "enhanced_schema": json.dumps(enhanced_schema, indent=2, ensure_ascii=False), 107 | "raw_response": response 108 | } 109 | 110 | # 保存到缓存 111 | save_to_cache(cache_key, result) 112 | return result 113 | 114 | def llm_inference(model, prompts, api_key=None, api_url=None, max_workers=8): 115 | """ 116 | 多线程推理函数(仅修改max_tokens) 117 | """ 118 | # 初始化OpenAI客户端 119 | client = openai.OpenAI( 120 | api_key=api_key, 121 | base_url=api_url.rstrip('/') if api_url else "https://api.openai.com/v1" 122 | ) 123 | 124 | results = [] 125 | 126 | # 使用线程池并行处理 127 | with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: 128 | # 创建部分函数固定参数 129 | process_fn = partial(process_prompt, model=model, client=client) 130 | 131 | # 提交所有任务 132 | futures = {executor.submit(process_fn, prompt): prompt for prompt in prompts} 133 | 134 | # 使用tqdm显示进度 135 | for future in tqdm(concurrent.futures.as_completed(futures), 136 | total=len(prompts), 137 | desc="处理进度"): 138 | try: 139 | result = future.result() 140 | if result: 141 | results.append(result) 142 | except Exception as e: 143 | prompt = futures[future] 144 | print(f"处理失败 - 提示: {prompt[:50]}... 错误: {str(e)}") 145 | 146 | return results 147 | 148 | if __name__ == '__main__': 149 | parser = argparse.ArgumentParser() 150 | parser.add_argument("--model", type=str, required=True) 151 | parser.add_argument("--api_key", type=str) 152 | parser.add_argument("--api_url", type=str) 153 | parser.add_argument("--max_workers", type=int, default=8) 154 | parser.add_argument("--limited_num", type=int, default=0) 155 | args = parser.parse_args() 156 | 157 | # 加载提示 158 | limited_prompts = json.load(open("./prompts/prompts_schema_embedding.json", encoding='utf-8')) 159 | 160 | #限制数据数目 161 | limited_number = args.limited_num 162 | if limited_number != 0: 163 | limited_prompts = limited_prompts[:limited_number] 164 | 165 | # 执行推理 166 | results = llm_inference( 167 | model=args.model, 168 | prompts=limited_prompts, 169 | api_key=args.api_key, 170 | api_url=args.api_url, 171 | max_workers=args.max_workers 172 | ) 173 | 174 | # 保存结果 175 | output_file = "./results/schema_embedding.json" 176 | os.makedirs(os.path.dirname(output_file), exist_ok=True) 177 | with open(output_file, 'w', encoding='utf-8') as f: 178 | json.dump(results, f, indent=2, ensure_ascii=False) 179 | 180 | print(f"处理完成,结果已保存到 {output_file}") 181 | -------------------------------------------------------------------------------- /Data_Synthesizer/database_synthesis/synthesize_schema.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import time 6 | import json_repair 7 | import openai 8 | from tqdm import tqdm 9 | import random 10 | from tenacity import retry, stop_after_attempt, wait_exponential 11 | 12 | import concurrent.futures 13 | from functools import partial 14 | 15 | # 新增缓存目录常量 16 | CACHE_DIR = "./cache" 17 | os.makedirs(CACHE_DIR, exist_ok=True) 18 | 19 | def parse_response(response): 20 | domain_pattern = r'(?<=\[START_DOMAIN\])(.*?)(?=\[END_DOMAIN\])' 21 | scenario_pattern = r'(?<=\[START_SCENARIO\])(.*?)(?=\[END_SCENARIO\])' 22 | schema_pattern = r'(?<=\[START_DATABASE_SCHEMA\])(.*?)(?=\[END_DATABASE_SCHEMA\])' 23 | 24 | try: 25 | domain_match = re.search(domain_pattern, response, 
re.DOTALL) 26 | domain = domain_match.group(0).strip() if domain_match else None 27 | 28 | scenario_match = re.search(scenario_pattern, response, re.DOTALL) 29 | scenario = scenario_match.group(0).strip() if scenario_match else None 30 | 31 | schema_match = re.search(schema_pattern, response, re.DOTALL) 32 | schema = schema_match.group(0).strip() if schema_match else None 33 | schema_dict = json_repair.loads(schema) 34 | schema = json.dumps(schema_dict, indent=2, ensure_ascii=False) 35 | 36 | return domain, scenario, schema 37 | except Exception as e: 38 | print(response) 39 | print(f"length: {len(response)}") 40 | print("Parsing Exception:", str(e)) 41 | return None, None, None 42 | 43 | 44 | def get_cache_filename(prompt, model): 45 | """生成基于提示内容和模型名称的唯一缓存文件名""" 46 | import hashlib 47 | prompt_hash = hashlib.md5(prompt.encode('utf-8')).hexdigest() 48 | return f"{CACHE_DIR}/{model}_{prompt_hash}.json" 49 | 50 | def load_from_cache(prompt, model): 51 | """尝试从缓存加载结果""" 52 | cache_file = get_cache_filename(prompt, model) 53 | if os.path.exists(cache_file): 54 | with open(cache_file, 'r', encoding='utf-8') as f: 55 | return json.load(f) 56 | return None 57 | 58 | def save_to_cache(prompt, model, result): 59 | """保存结果到缓存""" 60 | cache_file = get_cache_filename(prompt, model) 61 | with open(cache_file, 'w', encoding='utf-8') as f: 62 | json.dump(result, f, ensure_ascii=False, indent=2) 63 | 64 | def llm_inference_openai(model, prompts, api_key, api_url=None, max_tokens=10240, temperature=0.7, max_workers=128): 65 | """ 66 | 改进后的推理函数,带有缓存机制和多线程并行处理 67 | 68 | Args: 69 | model: 模型名称 70 | prompts: 提示列表 71 | api_key: OpenAI API密钥 72 | api_url: API端点URL 73 | max_tokens: 最大token数 74 | temperature: 生成温度 75 | max_workers: 最大线程数 76 | """ 77 | client = openai.OpenAI( 78 | api_key=api_key, 79 | base_url=api_url.rstrip('/') if api_url else "https://api.openai.com" 80 | ) 81 | 82 | @retry(stop=stop_after_attempt(3), 83 | wait=wait_exponential(multiplier=1, min=4, max=10)) 84 | def generate_response(prompt): 85 | try: 86 | # 先检查缓存 87 | cached = load_from_cache(prompt, model) 88 | if cached: 89 | return cached["generated_content"]["response"] 90 | 91 | response = client.chat.completions.create( 92 | model=model, 93 | messages=[ 94 | {"role": "system", "content": "You are a helpful assistant that generates database schemas."}, 95 | {"role": "user", "content": prompt} 96 | ], 97 | temperature=temperature, 98 | max_tokens=max_tokens 99 | ) 100 | content = response.choices[0].message.content 101 | return content 102 | except Exception as e: 103 | print(f"生成响应失败: {str(e)}") 104 | raise 105 | 106 | def process_prompt(prompt): 107 | try: 108 | # 检查是否已有完整结果缓存 109 | cached_result = load_from_cache(prompt, model) 110 | if cached_result and all(cached_result["generated_content"].get(k) for k in ["response", "domain", "scenario", "schema"]): 111 | return cached_result 112 | 113 | response = generate_response(prompt) 114 | domain, scenario, schema = parse_response(response) 115 | 116 | if all([domain, scenario, schema]): 117 | result = { 118 | "prompt": prompt, 119 | "generated_content": { 120 | "response": response, 121 | "domain": domain, 122 | "scenario": scenario, 123 | "schema": schema 124 | } 125 | } 126 | save_to_cache(prompt, model, result) 127 | return result 128 | else: 129 | print(f"无效响应格式 - 提示: {prompt[:50]}...") 130 | return None 131 | 132 | except Exception as e: 133 | print(f"处理失败 - 提示: {prompt[:50]}... 
错误: {str(e)}") 134 | return None 135 | 136 | results = [] 137 | 138 | # 使用线程池并行处理 139 | with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: 140 | # 使用partial固定除prompt外的其他参数 141 | process_fn = partial(process_prompt) 142 | 143 | # 使用tqdm显示进度 144 | futures = {executor.submit(process_fn, prompt): prompt for prompt in prompts} 145 | 146 | for future in tqdm(concurrent.futures.as_completed(futures), total=len(prompts), desc="并行生成响应进度"): 147 | result = future.result() 148 | if result: 149 | results.append(result) 150 | 151 | return results 152 | 153 | if __name__ == '__main__': 154 | parser = argparse.ArgumentParser() 155 | parser.add_argument("--model", type=str) 156 | parser.add_argument("--api_key", type=str) 157 | parser.add_argument("--api_url", type=str, default=None) 158 | parser.add_argument("--use_cache", type=bool, default=True, help="是否使用缓存") 159 | args = parser.parse_args() 160 | 161 | # 加载并抽样提示 162 | prompts = json.load(open("./prompts/prompts_schema_synthesis.json")) 163 | sample_size = int(len(prompts) * 0.1) 164 | 165 | # 设置固定的随机种子,确保每次运行时抽样的提示列表都相同 166 | # 这是让缓存能够被成功加载的关键 167 | random.seed(42) 168 | 169 | test_prompts = random.sample(prompts, sample_size) 170 | 171 | # 执行推理 172 | results = llm_inference_openai(args.model, test_prompts, args.api_key, args.api_url) 173 | 174 | # 保存最终结果 175 | output_file = "./results/schema_synthesis.json" 176 | os.makedirs(os.path.dirname(output_file), exist_ok=True) 177 | with open(output_file, "w", encoding="utf-8") as f: 178 | json.dump(results, f, indent=2, ensure_ascii=False) 179 | -------------------------------------------------------------------------------- /Data_Synthesizer/vectorization/find_semantic_rich_column.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | from tqdm import tqdm 4 | from openai import OpenAI 5 | import httpx # 导入 httpx 6 | from functools import lru_cache 7 | from concurrent.futures import ThreadPoolExecutor 8 | import os 9 | from pathlib import Path 10 | import logging 11 | import traceback 12 | from dotenv import load_dotenv 13 | 14 | # --- 1. 配置日志 --- 15 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 16 | logger = logging.getLogger(__name__) 17 | 18 | # --- 2. 
核心函数 (保留了代理修复) --- 19 | 20 | @lru_cache(maxsize=10000) 21 | def cached_llm_call(model: str, prompt: str, api_url: str, api_key: str) -> str: 22 | """ 23 | 带有缓存的LLM调用,以避免对相同提示的重复请求。 24 | """ 25 | # --- 关键修复:创建一个不信任系统环境变量(包括代理)的HTTP客户端 --- 26 | # trust_env=False 是一个更可靠的方法来确保httpx不使用任何系统级的代理设置。 27 | http_client = httpx.Client(trust_env=False) 28 | 29 | client = OpenAI( 30 | api_key=api_key, 31 | base_url=api_url if api_url else None, 32 | http_client=http_client # 将自定义的客户端传递给OpenAI 33 | ) 34 | 35 | try: 36 | response = client.chat.completions.create( 37 | model=model, 38 | messages=[{"role": "user", "content": prompt}], 39 | temperature=0.8 40 | ) 41 | return response.choices[0].message.content 42 | except Exception as e: 43 | logger.error(f"调用LLM API时出错: {str(e)}") 44 | return "" 45 | 46 | def parse_response(response: str) -> dict: 47 | """ 48 | 解析大模型响应,提取JSON对象。 49 | """ 50 | # 尝试直接解析为JSON 51 | try: 52 | return json.loads(response) 53 | except json.JSONDecodeError: 54 | pass 55 | 56 | # 尝试从代码块中提取JSON 57 | pattern = r"```json\s*(.*?)\s*```" 58 | json_blocks = re.findall(pattern, response, re.DOTALL) 59 | 60 | if json_blocks: 61 | try: 62 | return json.loads(json_blocks[-1].strip()) 63 | except json.JSONDecodeError as e: 64 | logger.error(f"JSON解析错误: {str(e)}") 65 | 66 | # 尝试提取纯JSON内容 67 | try: 68 | start = response.find('{') 69 | end = response.rfind('}') 70 | if start != -1 and end != -1 and end > start: 71 | return json.loads(response[start:end+1]) 72 | except Exception: 73 | pass 74 | 75 | logger.error("无法从响应中提取有效的JSON") 76 | return {} 77 | 78 | def process_db_info(db_info: dict, model: str, api_url: str, api_key: str, prompt_template: str) -> dict: 79 | """ 80 | 处理单个数据库信息。 81 | """ 82 | prompt = prompt_template.format(dababase_schema=json.dumps(db_info, ensure_ascii=False)) 83 | response = cached_llm_call(model, prompt, api_url, api_key) 84 | parsed_response = parse_response(response) 85 | db_info["semantic_rich_column"] = parsed_response 86 | return db_info 87 | 88 | # --- 3. 
主执行逻辑 (已更新参数) --- 89 | 90 | def main_find_rich_semantic_column(model,api_key,api_url,input_file,output_file,no_parallel_str,prompt_template_path = "./prompt_templates/find_semantic_rich_column.txt"): 91 | """ 92 | 主函数,加载配置并执行处理流程。 93 | """ 94 | 95 | no_parallel = no_parallel_str in ['true', '1', 't'] 96 | 97 | config = { 98 | "model": model, 99 | "api_url": api_url, 100 | "api_key": "********", # 不在日志中显示密钥 101 | "no_parallel": no_parallel, 102 | "input_file": input_file, 103 | "output_file": output_file 104 | } 105 | print(f"加载的配置: {config}") 106 | 107 | # 确保输出目录存在 108 | output_dir = Path(output_file).parent 109 | output_dir.mkdir(parents=True, exist_ok=True) 110 | 111 | # 读取提示模板 112 | try: 113 | with open(prompt_template_path, 'r', encoding='utf-8') as file: 114 | prompt_template = file.read() 115 | logger.info("提示词模版文件内容读取成功!") 116 | except Exception as e: 117 | logger.error(f"读取提示词模版文件时出错: {str(e)}") 118 | exit(1) 119 | 120 | # 加载输入数据 121 | try: 122 | with open(input_file, encoding="utf-8") as f: 123 | input_dataset = json.load(f) 124 | logger.info(f"成功加载输入文件,共 {len(input_dataset)} 个数据库") 125 | except Exception as e: 126 | logger.error(f"加载输入文件失败: {str(e)}") 127 | exit(1) 128 | 129 | # 处理函数包装器 130 | def process_item(db_info): 131 | return process_db_info( 132 | db_info=db_info, 133 | model=model, 134 | api_url=api_url, 135 | api_key=api_key, 136 | prompt_template=prompt_template 137 | ) 138 | 139 | # 并行或顺序处理 140 | results = [] 141 | if not no_parallel: 142 | logger.info("使用并行处理模式...") 143 | with ThreadPoolExecutor(max_workers=os.cpu_count() * 2) as executor: 144 | futures = [executor.submit(process_item, db_info) for db_info in input_dataset] 145 | for future in tqdm(futures, total=len(input_dataset), desc="处理数据库"): 146 | try: 147 | results.append(future.result()) 148 | except Exception as e: 149 | logger.error(f"处理数据库时出错: {str(e)}") 150 | traceback.print_exc() 151 | failed_index = futures.index(future) 152 | results.append(input_dataset[failed_index]) 153 | else: 154 | logger.info("使用顺序处理模式...") 155 | for db_info in tqdm(input_dataset, desc="处理数据库"): 156 | try: 157 | results.append(process_item(db_info)) 158 | except Exception as e: 159 | logger.error(f"处理数据库 {db_info.get('db_id', '未知')} 时出错: {str(e)}") 160 | results.append(db_info) 161 | 162 | # 保存结果 163 | try: 164 | with open(output_file, "w", encoding="utf-8") as f: 165 | json.dump(results, f, indent=2, ensure_ascii=False) 166 | logger.info(f"结果成功保存到 {output_file}") 167 | print(f"结果成功保存到 {output_file}") 168 | except Exception as e: 169 | logger.error(f"保存结果失败: {str(e)}") 170 | temp_file = output_dir / "temp_results.json" 171 | with open(temp_file, "w", encoding="utf-8") as f: 172 | json.dump(results, f, indent=2, ensure_ascii=False) 173 | logger.info(f"临时结果已保存到 {temp_file}") 174 | 175 | if __name__ == '__main__': 176 | # 加载 .env 文件中的环境变量 177 | load_dotenv() 178 | logger.info("正在从 .env 文件加载配置...") 179 | 180 | # 从环境变量中读取配置 (已更新) 181 | # 必填参数 182 | model = os.getenv("LLM_MODEL_NAME") 183 | api_key = os.getenv("API_KEY") 184 | 185 | # 检查必填参数是否存在 (已更新) 186 | if not model or not api_key: 187 | missing_vars = [] 188 | if not model: missing_vars.append("LLM_MODEL_NAME") 189 | if not api_key: missing_vars.append("API_KEY") 190 | logger.error(f"错误:以下必须的环境变量未在 .env 文件中设置: {', '.join(missing_vars)}") 191 | exit(1) 192 | # 选填参数(带默认值)(已更新) 193 | api_url = os.getenv("BASE_URL", "http://123.129.219.111:3000/v1") 194 | input_file = os.getenv("INPUT_FILE_FIND_SEMANTIC_RICH", "./results/enhanced_train_tables.json") 195 | output_file = 
os.getenv("OUTPUT_FILE_FIND_SEMANTIC_RICH", "./results/find_semantic_tables.json") 196 | # 处理布尔值参数 197 | no_parallel_str = os.getenv("NO_PARALLEL_FIND_SEMANTIC_RICH", "false").lower() 198 | main_find_rich_semantic_column(model,api_key,api_url,input_file,output_file,no_parallel_str) 199 | -------------------------------------------------------------------------------- /Data_Synthesizer/synthesis_nl/synthesize_question.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from concurrent.futures import ThreadPoolExecutor 5 | # 移除了 lru_cache, 引入了 Lock 用于线程安全的文件写入 6 | from typing import List, Dict 7 | from threading import Lock 8 | 9 | import openai 10 | from dotenv import load_dotenv 11 | from tqdm import tqdm 12 | 13 | # Load environment variables from a .env file 14 | load_dotenv() 15 | 16 | # --- 新增:用于线程安全地写入缓存文件的锁 --- 17 | cache_lock = Lock() 18 | 19 | # --- 新增:持久化磁盘缓存函数 --- 20 | def load_cache(cache_file: str) -> dict: 21 | """ 22 | 从 .jsonl 文件加载缓存。 23 | 每一行都是一个独立的 JSON 对象 {"key": prompt, "value": response}。 24 | """ 25 | if not os.path.exists(cache_file): 26 | return {} 27 | 28 | cache = {} 29 | with open(cache_file, 'r', encoding='utf-8') as f: 30 | for line in f: 31 | line = line.strip() 32 | if not line: 33 | continue 34 | try: 35 | record = json.loads(line) 36 | if "key" in record and "value" in record: 37 | cache[record["key"]] = record["value"] 38 | except json.JSONDecodeError: 39 | print(f"Skipping corrupted line in cache file: {line}") 40 | return cache 41 | 42 | def save_to_cache(cache_file: str, key: str, value: str): 43 | """ 44 | 将单个键值对安全地以 .jsonl 格式追加到缓存文件中。 45 | """ 46 | with cache_lock: 47 | record = {"key": key, "value": value} 48 | with open(cache_file, 'a', encoding='utf-8') as f: 49 | f.write(json.dumps(record, ensure_ascii=False) + '\n') 50 | 51 | # --- 修改:移除了 @lru_cache 装饰器,并重命名函数 --- 52 | def make_llm_call(model: str, prompt: str, api_url: str, api_key: str) -> str: 53 | """ 54 | 实际执行 LLM API 调用的函数,并增加了对返回结果的健壮性检查。 55 | """ 56 | client = openai.OpenAI( 57 | api_key=api_key, 58 | base_url=api_url if api_url else None 59 | ) 60 | 61 | try: 62 | response = client.chat.completions.create( 63 | model=model, 64 | messages=[{"role": "user", "content": prompt}], 65 | temperature=0.8 66 | ) 67 | # --- 新增的健壮性检查 --- 68 | # 1. 检查 response 对象是否存在 69 | # 2. 检查 response.choices 列表是否存在且不为空 70 | if response and response.choices and len(response.choices) > 0: 71 | # 只有检查通过后,才安全地访问内容 72 | return response.choices[0].message.content 73 | else: 74 | # 如果响应格式不正确,打印警告信息并返回空字符串 75 | print(f"Warning: Received an invalid or empty response from API. Response: {response}") 76 | return "" 77 | 78 | except Exception as e: 79 | print(f"Error calling LLM API: {e}") 80 | return "" 81 | 82 | # --- 修改:整合了新的持久化缓存逻辑 --- 83 | def llm_inference( 84 | model: str, 85 | dataset: List[Dict], 86 | api_key: str, 87 | cache_file_path: str, # 新增缓存文件路径参数 88 | api_url: str = "", 89 | parallel_workers: int = 4 90 | ) -> List[Dict]: 91 | """ 92 | 使用持久化磁盘缓存执行 LLM 推理。 93 | """ 94 | # 1. 加载现有缓存 95 | cache = load_cache(cache_file_path) 96 | print(f"Loaded {len(cache)} items from cache file: {cache_file_path}") 97 | 98 | # 2. 
筛选出需要新处理的任务和已处理的任务 99 | items_to_process = [] 100 | final_results = [] 101 | for data in dataset: 102 | prompt = data["prompt"] 103 | if prompt in cache: 104 | # 如果在缓存中,直接添加到最终结果 105 | final_results.append({**data, "responses": [cache[prompt]]}) 106 | else: 107 | # 否则,添加到待处理列表 108 | items_to_process.append(data) 109 | 110 | print(f"Total items: {len(dataset)}, To process: {len(items_to_process)}") 111 | 112 | # 3. 定义单个任务的处理函数 113 | def process_item(data: Dict) -> Dict: 114 | prompt = data["prompt"] 115 | response = make_llm_call(model, prompt, api_url, api_key) 116 | # 成功后立刻写入缓存 117 | if response: 118 | save_to_cache(cache_file_path, prompt, response) 119 | # 返回包含响应的完整数据 120 | return {**data, "responses": [response]} 121 | 122 | # 4. 执行需要处理的任务 123 | if items_to_process: 124 | newly_processed_results = [] 125 | if parallel_workers > 1: 126 | with ThreadPoolExecutor(max_workers=parallel_workers) as executor: 127 | newly_processed_results = list(tqdm( 128 | executor.map(process_item, items_to_process), 129 | total=len(items_to_process), 130 | desc="Generating responses" 131 | )) 132 | else: 133 | newly_processed_results = [process_item(data) for data in tqdm(items_to_process, desc="Generating responses")] 134 | 135 | # 5. 将新处理的结果与已在缓存中的结果合并 136 | final_results.extend(newly_processed_results) 137 | 138 | return final_results 139 | 140 | def synthesize_questions( 141 | input_file: str, 142 | output_file: str, 143 | model_name: str, 144 | api_key: str, 145 | api_url: str, 146 | max_workers: int, 147 | cache_file_path: str # 新增缓存文件路径参数 148 | ): 149 | """ 150 | 主逻辑函数,现在包含缓存路径。 151 | """ 152 | if not api_key or not model_name: 153 | raise ValueError("Error: api_key and model_name must be provided.") 154 | 155 | print("--- Running Synthesis with Configuration ---") 156 | print(f"Model: {model_name}") 157 | print(f"API URL: {api_url}") 158 | print(f"Max Workers: {max_workers}") 159 | print(f"Cache File: {cache_file_path}") 160 | print(f"Input File: {input_file}") 161 | print(f"Output File: {output_file}") 162 | print("------------------------------------------") 163 | 164 | os.makedirs(os.path.dirname(output_file), exist_ok=True) 165 | os.makedirs(os.path.dirname(cache_file_path), exist_ok=True) 166 | 167 | with open(input_file, 'r', encoding='utf-8') as f: 168 | input_dataset = json.load(f) 169 | 170 | results = llm_inference( 171 | model=model_name, 172 | dataset=input_dataset, 173 | api_key=api_key, 174 | api_url=api_url, 175 | cache_file_path=cache_file_path, # 传递缓存路径 176 | parallel_workers=max_workers 177 | ) 178 | 179 | with open(output_file, "w", encoding="utf-8") as f: 180 | json.dump(results, f, indent=2, ensure_ascii=False) 181 | 182 | print(f"\nSynthesis complete. 
Results saved to {output_file}") 183 | 184 | 185 | if __name__ == '__main__': 186 | parser = argparse.ArgumentParser(description="Run LLM inference for question synthesis.") 187 | parser.add_argument("--input_file", type=str, default="./prompts/question_synthesis_prompts.json") 188 | parser.add_argument("--output_file", type=str, default="./results/question_synthesis.json") 189 | # 新增缓存文件路径的命令行参数 190 | parser.add_argument("--cache_file", type=str, default="./cache/question_synthesis_cache.jsonl", 191 | help="Path to the persistent cache file") 192 | 193 | opt = parser.parse_args() 194 | 195 | # 从环境变量加载配置 196 | api_key_env = os.getenv("API_KEY") 197 | api_url_env = os.getenv("BASE_URL") 198 | model_name_env = os.getenv("LLM_MODEL_NAME") 199 | max_workers_env = int(os.getenv("MAX_WORKERS", 32)) # 增加了默认的并行数 200 | 201 | synthesize_questions( 202 | input_file=opt.input_file, 203 | output_file=opt.output_file, 204 | model_name=model_name_env, 205 | api_key=api_key_env, 206 | api_url=api_url_env, 207 | max_workers=max_workers_env, 208 | cache_file_path=opt.cache_file # 传递缓存路径 209 | ) 210 | -------------------------------------------------------------------------------- /Evaluation_Framework/script/api_pipeline.py: -------------------------------------------------------------------------------- 1 | # 这个文件是通用pipeline,可以用来生成训练数据(你需要先准备好向量数据库) 2 | # main文件中放了所有算子,你可以选择性地使用它们。 3 | # 如果是多模态的wiki数据集的处理,它比其他数据集要多一个build_final_db_with_images算子 4 | # 如果是训练数据,可以把cot相关的算子也加进来 5 | import yaml 6 | import os 7 | from pprint import pprint 8 | from pathlib import Path 9 | import sys 10 | 11 | # 配置hugging face代理 12 | os.environ['HF_ENDPOINT'] = 'https://alpha.hf-mirror.com' 13 | 14 | # 只需要修改这里,就可以加载不同的数据集配置! 15 | DATASET_BACKEND = "sqlite" 16 | # DATASET_BACKEND = "clickhouse" 17 | 18 | DATASET_TO_LOAD = "toy_spider" 19 | # DATASET_TO_LOAD = "bird" # 例如,切换到bird数据集 20 | 21 | # 获取当前文件的绝对路径 22 | current_file_path = os.path.abspath(__file__) 23 | 24 | # # 获取 project 目录的路径 (当前文件的父目录的父目录) 25 | # project_root_path = os.path.dirname(os.path.dirname(current_file_path)) 26 | 27 | # # 将 project 目录添加到 sys.path 28 | # if project_root_path not in sys.path: 29 | # sys.path.append(project_root_path) 30 | 31 | 32 | from generate_query_id import add_query_ids_to_json 33 | from generate_ground_truth import transform_json_data 34 | from generate_eval_prompts import generate_sql_prompts 35 | from synthesize_sql import run_sql_synthesis 36 | 37 | # -------------------------------------------------------------------- 38 | # 安装提示 (Installation Tip) 39 | # -------------------------------------------------------------------- 40 | # 如果您的环境中没有 PyYAML 库,请先安装它。 41 | # You need to install the PyYAML library if you don't have it. 
42 | # pip install PyYAML 43 | # -------------------------------------------------------------------- 44 | from typing import Any 45 | 46 | class DynamicConfig: 47 | def __init__(self, config_dict: dict): 48 | if config_dict: 49 | for key, value in config_dict.items(): 50 | # 如果值是字典,也将其转换为DynamicConfig实例,以支持链式调用 51 | if isinstance(value, dict): 52 | setattr(self, key, DynamicConfig(value)) 53 | else: 54 | setattr(self, key, value) 55 | 56 | # 增加一个get方法以安全地获取属性,类似字典 57 | def get(self, key: str, default: Any = None) -> Any: 58 | return getattr(self, key, default) 59 | 60 | class ServicesConfig(DynamicConfig): 61 | pass 62 | 63 | class PathsConfig(DynamicConfig): 64 | pass 65 | 66 | class ParametersConfig(DynamicConfig): 67 | pass 68 | 69 | class AppConfig: 70 | def __init__(self, base_dir: str, services_dict: dict, paths_dict: dict, params_dict: dict): 71 | self.base_dir = base_dir 72 | self.services = ServicesConfig(services_dict) 73 | self.paths = PathsConfig(paths_dict) 74 | self.parameters = ParametersConfig(params_dict) 75 | 76 | # ------------------------------------------------------------------- 77 | # 在这里添加辅助函数并修改 load_config 78 | # ------------------------------------------------------------------- 79 | def _format_config_paths(config_node: Any, dataset_name: str) -> Any: 80 | """ 81 | 【新增的辅助函数】 82 | 递归地遍历配置节点(字典或列表),格式化所有字符串。 83 | """ 84 | if isinstance(config_node, dict): 85 | return {key: _format_config_paths(value, dataset_name) for key, value in config_node.items()} 86 | elif isinstance(config_node, list): 87 | return [_format_config_paths(item, dataset_name) for item in config_node] 88 | elif isinstance(config_node, str): 89 | # 核心逻辑:用实际的数据集名称替换占位符 {dataset} 90 | return config_node.format(dataset=dataset_name) 91 | else: 92 | return config_node 93 | 94 | def load_config(database: str, dataset: str, config_path: str = 'config.yaml') -> AppConfig: 95 | """ 96 | 【修改后的函数】 97 | 从 YAML 文件中加载、解析、格式化并封装配置。 98 | """ 99 | config_file = Path(config_path) 100 | if not config_file.is_file(): 101 | raise FileNotFoundError(f"配置文件未找到: {config_file.resolve()}") 102 | 103 | with open(config_file, 'r', encoding='utf-8') as f: 104 | full_config = yaml.safe_load(f) 105 | 106 | try: 107 | db_config = full_config[database] 108 | dataset_config = db_config[dataset] 109 | except KeyError as e: 110 | raise KeyError(f"在 '{config_path}' 中找不到配置路径: {database}.{dataset} - {e}") 111 | 112 | base_dir = db_config.get('base_dir') 113 | # 1. 先像原来一样获取原始的字典 114 | services_dict = dataset_config.get('services', {}) 115 | paths_dict = dataset_config.get('paths', {}) 116 | params_dict = dataset_config.get('parameters', {}) 117 | 118 | # 2. 【关键新增步骤】调用辅助函数,用 `dataset` 变量格式化路径字典 119 | formatted_paths_dict = _format_config_paths(paths_dict, dataset) 120 | 121 | # 3. 
将格式化后的字典传入 AppConfig 122 | return AppConfig(base_dir, services_dict, formatted_paths_dict, params_dict) 123 | 124 | def create_directory_with_os(directory_name: str): 125 | """ 126 | 使用 os 模块创建目录。 127 | 如果路径已作为文件存在,则发出警告;如果目录已存在,则静默处理。 128 | """ 129 | # 检查目标路径是否已经存在,并且是一个文件 130 | if os.path.exists(directory_name) and not os.path.isdir(directory_name): 131 | # 如果是文件,打印警告并直接返回,不做任何事 132 | print(f"warning: 路径 '{directory_name}' 已作为文件存在, 无法创建同名目录。") 133 | return 134 | 135 | # 尝试创建目录,exist_ok=True 会处理目录已存在的情况 136 | try: 137 | os.makedirs(directory_name, exist_ok=True) 138 | except OSError as e: 139 | # 捕获其他可能的OS错误 140 | print(f"创建目录 '{directory_name}' 时发生未知错误: {e}") 141 | 142 | def main(): 143 | try: 144 | config = load_config(database=DATASET_BACKEND, dataset=DATASET_TO_LOAD) 145 | 146 | print(f"--- 成功加载 '{DATASET_TO_LOAD}' 数据集的 '{DATASET_BACKEND}' 配置! ---") 147 | 148 | # --- 【关键修改】更准确、更稳健的目录创建逻辑 --- 149 | print("\n正在创建所有必需的目录...") 150 | 151 | for path_name, path_value in vars(config.paths).items(): 152 | if not isinstance(path_value, str) or not path_value: 153 | continue # 如果值不是字符串或为空,则跳过 154 | 155 | # 使用 os.path.splitext() 来判断一个路径是否指向文件 156 | # 如果路径有扩展名(如 .json, .txt),我们认为它是一个文件路径。 157 | _, file_extension = os.path.splitext(path_value) 158 | 159 | if file_extension: 160 | # 如果有文件扩展名, 说明这是一个文件路径。 161 | # 我们需要创建的是它的【父目录】。 162 | directory_to_create = os.path.dirname(path_value) 163 | # else: 164 | # # 如果没有文件扩展名, 我们假定它本身就是一个【目录路径】。 165 | # directory_to_create = path_value 166 | 167 | # 只有当 `directory_to_create` 非空时才创建 168 | if directory_to_create: 169 | create_directory_with_os(directory_to_create) 170 | 171 | print("所有相关目录已创建完毕。") 172 | 173 | 174 | except (FileNotFoundError, KeyError, yaml.YAMLError) as e: 175 | print(f"错误: 配置加载失败 - {e}") 176 | 177 | # #开始执行pipeline 178 | print("开始为原始数据添加query_id,作为后续数据的标识") 179 | add_query_ids_to_json(config.paths.input_file_to_id,config.paths.dataset_json_path) 180 | 181 | print("生成ground truth文件") 182 | transform_json_data(config.paths.dataset_json_path, config.paths.ground_truth_output_path) 183 | 184 | print("生成sql提示词") 185 | generate_sql_prompts(config.paths.dataset_json_path, config.paths.tables_json_path, config.paths.prompt_tamplate_path, config.paths.output_prompt_path, config.parameters.dataset_backend, config.paths.database_note_prompt_path, config.services.openai.get('embedding_model_name')) 186 | 187 | print("生成最终sql文件,作为测评框架的输入文件") 188 | run_sql_synthesis(config.paths.sql_prompt_file_path, config.paths.eval_input_path, config.services.openai.get('llm_model_name'),config.services.openai.get('api_key'), config.services.openai.get('base_url'), config.paths.cache_file_path_sql, config.parameters.dataset_backend, config.parameters.no_parallel,config.parameters.use_vllm) 189 | 190 | 191 | if __name__ == "__main__": 192 | main() 193 | -------------------------------------------------------------------------------- /Data_Synthesizer/synthesis_sql/synthesize_sql.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import re 4 | import os 5 | from dotenv import load_dotenv 6 | from tqdm import tqdm 7 | from openai import OpenAI 8 | # 移除了 lru_cache,引入了 Lock 用于线程安全的文件写入 9 | from concurrent.futures import ThreadPoolExecutor 10 | from pathlib import Path 11 | from threading import Lock 12 | 13 | # Load environment variables from a .env file 14 | load_dotenv() 15 | 16 | # --- 新增:用于线程安全地写入缓存文件的锁 --- 17 | cache_lock = Lock() 18 | 19 | # --- 新增:持久化磁盘缓存函数 --- 20 | def load_cache(cache_file: 
str) -> dict: 21 | """ 22 | 从 .jsonl 文件加载缓存。 23 | 每一行都是一个独立的 JSON 对象 {"key": prompt, "value": response}。 24 | 这种方式可以抵抗因程序中断导致的文件损坏。 25 | """ 26 | if not os.path.exists(cache_file): 27 | return {} 28 | 29 | cache = {} 30 | with open(cache_file, 'r', encoding='utf-8') as f: 31 | for line in f: 32 | line = line.strip() 33 | if not line: 34 | continue 35 | try: 36 | record = json.loads(line) 37 | if "key" in record and "value" in record: 38 | cache[record["key"]] = record["value"] 39 | except json.JSONDecodeError: 40 | print(f"Skipping corrupted line in cache file: {line}") 41 | return cache 42 | 43 | def save_to_cache(cache_file: str, key: str, value: str): 44 | """ 45 | 将单个键值对安全地以 .jsonl 格式追加到缓存文件中。 46 | 使用追加模式 ('a'),既高效又安全。 47 | """ 48 | with cache_lock: 49 | record = {"key": key, "value": value} 50 | with open(cache_file, 'a', encoding='utf-8') as f: 51 | f.write(json.dumps(record, ensure_ascii=False) + '\n') 52 | 53 | # --- 修改:移除了 @lru_cache 装饰器 --- 54 | def make_llm_call(model: str, prompt: str, api_url: str, api_key: str) -> str: 55 | """ 56 | 实际执行 LLM API 调用的函数。 57 | """ 58 | client = OpenAI( 59 | api_key=api_key, 60 | base_url=api_url if api_url else None 61 | ) 62 | 63 | try: 64 | response = client.chat.completions.create( 65 | model=model, 66 | messages=[{"role": "user", "content": prompt}], 67 | temperature=0.8 68 | ) 69 | return response.choices[0].message.content 70 | except Exception as e: 71 | print(f"Error calling LLM API: {str(e)}") 72 | return "" 73 | 74 | def parse_response(response): 75 | pattern = r"```sql\s*(.*?)\s*```" 76 | 77 | sql_blocks = re.findall(pattern, response, re.DOTALL) 78 | 79 | if sql_blocks: 80 | last_sql = sql_blocks[-1].strip() 81 | return last_sql 82 | else: 83 | print("No SQL blocks found.") 84 | return "" 85 | 86 | # --- 修改:整合了新的缓存逻辑 --- 87 | def llm_inference( 88 | model: str, 89 | items: list, 90 | api_url: str, 91 | api_key: str, 92 | cache_file_path: str, # 新增缓存文件路径参数 93 | parallel: int = 32 94 | ) -> list: 95 | """ 96 | 使用持久化磁盘缓存生成 LLM 响应。 97 | """ 98 | 99 | # 1. 加载现有缓存 100 | cache = load_cache(cache_file_path) 101 | print(f"Loaded {len(cache)} items from cache file: {cache_file_path}") 102 | 103 | # 2. 筛选出需要新处理的任务 104 | items_to_process = [] 105 | 106 | for item in items: 107 | # 假设每个 item 都包含 'sql_synthesis_prompt' 键 108 | if item.get("prompt") not in cache: 109 | items_to_process.append(item) 110 | 111 | print(f"Total items: {len(items)}, To process: {len(items_to_process)}") 112 | 113 | # 3. 定义单个任务的处理函数 114 | def process_item(item): 115 | prompt = item["prompt"] 116 | # 调用 API 117 | response = make_llm_call(model, prompt, api_url, api_key) 118 | # 成功后立刻写入缓存 119 | if response: 120 | save_to_cache(cache_file_path, prompt, response) 121 | return prompt, response 122 | 123 | # 4. 执行需要处理的任务(并行或顺序) 124 | if items_to_process: 125 | if parallel: 126 | with ThreadPoolExecutor(max_workers=parallel) as executor: 127 | # 使用 list 包装 tqdm 以立即显示进度条 128 | results_iterator = list(tqdm( 129 | executor.map(process_item, items_to_process), 130 | total=len(items_to_process), 131 | desc="Generating responses" 132 | )) 133 | # 将新结果更新到内存缓存中 134 | for prompt, response in results_iterator: 135 | cache[prompt] = response 136 | else: 137 | for prompt in tqdm(items_to_process, desc="Generating responses"): 138 | _, response = process_item(prompt) 139 | cache[prompt] = response 140 | 141 | # 5. 
组装最终结果 142 | final_results = [] 143 | for item in items: 144 | prompt = item["prompt"] 145 | final_results.append({ 146 | "prompt": prompt, 147 | "db_id": item["db_id"], 148 | "response": cache.get(prompt, "") # 从更新后的缓存中获取结果 149 | }) 150 | 151 | return final_results 152 | 153 | def run_sql_synthesis( 154 | input_file: str, 155 | output_file: str, 156 | model_name: str, 157 | api_key: str, 158 | api_url: str, 159 | cache_file_path: str, # 新增缓存文件路径参数 160 | max_workers: int = 32 161 | ): 162 | """ 163 | 主逻辑函数,现在包含缓存路径。 164 | """ 165 | if not api_key or not model_name: 166 | raise ValueError("Error: api_key and model_name must be provided.") 167 | 168 | print("--- Running Synthesis with Configuration ---") 169 | print(f"Model: {model_name}") 170 | print(f"API URL: {api_url}") 171 | print(f"Cache File: {cache_file_path}") # 打印缓存文件路径 172 | print(f"Parallel Number: {max_workers}") 173 | print(f"Input File: {input_file}") 174 | print(f"Output File: {output_file}") 175 | print("------------------------------------------") 176 | 177 | Path(output_file).parent.mkdir(parents=True, exist_ok=True) 178 | # 确保缓存目录也存在 179 | Path(cache_file_path).parent.mkdir(parents=True, exist_ok=True) 180 | 181 | input_dataset = json.load(open(input_file, encoding="utf-8")) 182 | 183 | results = llm_inference( 184 | model=model_name, 185 | items=input_dataset, 186 | api_url=api_url, 187 | api_key=api_key, 188 | cache_file_path=cache_file_path, # 传递缓存路径 189 | parallel=max_workers 190 | ) 191 | 192 | with open(output_file, "w", encoding="utf-8") as f: 193 | json.dump(results, f, indent=2, ensure_ascii=False) 194 | 195 | print(f"\nSynthesis complete. Results saved to {output_file}") 196 | 197 | 198 | if __name__ == '__main__': 199 | parser = argparse.ArgumentParser() 200 | parser.add_argument("--input_file", type=str, default="./prompts/sql_synthesis_prompts.json", 201 | help="Input JSON file with prompts") 202 | parser.add_argument("--output_file", type=str, default="./results/sql_synthesis.json", 203 | help="Output JSON file for results") 204 | # 新增缓存文件路径的命令行参数 205 | parser.add_argument("--cache_file", type=str, default="./cache/synthesis_cache.jsonl", 206 | help="Path to the persistent cache file") 207 | opt = parser.parse_args() 208 | 209 | # 从环境变量加载配置 210 | api_key_env = os.getenv("API_KEY") 211 | api_url_env = os.getenv("BASE_URL") 212 | model_name_env = os.getenv("LLM_MODEL_NAME") 213 | 214 | no_parallel_str = os.getenv("NO_PARALLEL", "false").lower() 215 | parallel_execution = not (no_parallel_str == 'true') 216 | 217 | run_sql_synthesis( 218 | input_file=opt.input_file, 219 | output_file=opt.output_file, 220 | model_name=model_name_env, 221 | api_key=api_key_env, 222 | api_url=api_url_env, 223 | cache_file_path=opt.cache_file, # 传递缓存路径 224 | parallel=parallel_execution 225 | ) 226 | -------------------------------------------------------------------------------- /Evaluation_Framework/prompt_templates/sqlite_vec_note_prompt.txt: -------------------------------------------------------------------------------- 1 | There are a few Requirements you should Comply with in addition: 2 | 1. When generating SQL queries, you should prioritize utilizing KNN searches whenever contextually appropriate. However, you have to avoid unnecessary/forced KNN implementations for: 3 | --Traditional relational data queries (especially for columns like: id, age, price) 4 | --Cases where standard SQL operators (equality, range, or aggregation functions) are more efficient and semantically appropriate 5 | 2. 
Only vector types (like: float[?]) support KNN queries, and the names of vector columns usually end with "_embedding". So you can use KNN queries when the column you need to search ends with "_embedding", or when a column name with the "_embedding" suffix also appears in the column list. 6 | 3. At any complexity level, you can choose to use KNN queries if needed. 7 | 4. When using KNN queries, you have to add a LIMIT or 'AND k = ?' constraint, but do not use both in the same statement. This rule is very important: do not forget to add the LIMIT or 'AND k = ?' constraint after the MATCH operator. 8 | 5. The lembed function is used to transform a string into a vector whose type and size match the corresponding column type in the data table. The function has two parameters: the first is the name of the embedding model used (default value: {embedding_model}), and the second is the string content you want to convert. So you should generate words or sentences with specific semantic information based on the name, type and comment of this column. For example, you can generate "The Daily Grind Coffee Shop\n 456 Oak Avenue\n Springfield, IL 62704\n USA" when the column name is Location_embedding, the column type is float[?] and the column comment is "the embedding of location". 9 | 6. The lembed function's second parameter MUST be a SPECIFIC semantic description. 10 | - For location_embedding: Generate REAL addresses (e.g. "Stadium: 123 Main St, Boston, MA. Capacity: 50,000. Home team: Patriots") 11 | - For columns containing semantically meaningful data (e.g., descriptions), generate rich, contextually appropriate information. For columns without meaningful content (e.g., placeholder names), avoid creating semantically dense output to facilitate fuzzy matching operations. 12 | - For name_embedding: You should generate variations of the original names (e.g., altered spellings, phonetic approximations, or intentionally obfuscated words/characters) to enable subsequent fuzzy matching to identify semantically similar names. Importantly, never generate redundant information. For example, you can generate "Lige", but do not generate "Ligand Lige", "Ligand example name", "Ligand similar to Aspirin" and "Ligand name variation". 13 | Examples: 14 | ✅ Correct: 15 | name_embedding MATCH lembed({embedding_model}, "Kri") 16 | ❌ Wrong: 17 | name_embedding MATCH lembed({embedding_model}, "A leading publisher based in Germany specializing in 18 | scientific journals and books.") 19 | - For text_embedding: Use ACTUAL and meaningful sentences (e.g. "Harper Lee’s To Kill a Mockingbird is a timeless exploration of racial injustice and moral growth, seen through the innocent yet perceptive eyes of Scout Finch. With its powerful themes, unforgettable characters like Atticus Finch, and Lee’s poignant prose, the novel remains a searing critique of society’s failures and a testament to the courage of standing for what’s right.") 20 | - NEVER use vague words or generic phrases like "a book review" 21 | Examples: 22 | ✅ Correct: 23 | lembed({embedding_model}, "To Kill a Mockingbird") 24 | ❌ Wrong: 25 | lembed({embedding_model}, "name of a famous book") 26 | 7. When using MATCH, please fill in a vector using the lembed function after MATCH that matches the column type (with the same dimension and type). Usage details are in the examples. 27 | 8. 
The distance column is an implicitly generated metric that appears when performing vector similarity searches (using the MATCH operator) in SQLite vector extensions like sqlite-vec. If you use a JOIN operator, you have to clarify which table the distance belongs to. 28 | 9. A SELECT statement should have no more than one MATCH operation. However, each subquery within a SELECT statement may also have no more than one MATCH operation, independent of the parent query. 29 | 10. When performing a KNN/vector similarity search (e.g., using MATCH or lembed), always specify a LIMIT or k=N constraint directly on the vector search operation, even if the outer query already has a LIMIT. The vector search requires its own result cap to avoid ambiguity in ranking and performance issues. 30 | 11. When both LIMIT and k operations are available for vector search queries, prioritize using the k operation for broader compatibility. 31 | Key Points: 32 | --Vector search needs its own LIMIT/k – the outer LIMIT applies to the final filtered results, not the initial similarity search. 33 | --The LIMIT operator should follow closely after "ORDER BY distance". 34 | ❌ Wrong Example: 35 | ```sql 36 | SELECT a.codec_name 37 | FROM audio_codecs a 38 | JOIN quality_levels q ON a.codec_id = q.quality_id 39 | WHERE a.description_embedding MATCH lembed({embedding_model}, "High efficiency audio codec with low latency and optimal bandwidth") 40 | AND q.quality_name = 'HighQuality' 41 | LIMIT 1; 42 | ``` 43 | ✅ Correct Example: 44 | ```sql 45 | SELECT a.codec_name 46 | FROM audio_codecs a 47 | JOIN quality_levels q ON a.codec_id = q.quality_id 48 | WHERE a.description_embedding MATCH lembed({embedding_model}, "High efficiency audio codec with low latency and optimal bandwidth") LIMIT 1 49 | AND q.quality_name = 'HighQuality'; 50 | ``` 51 | --When using JOIN operations, you need to ensure that k does not cause ambiguity in the query. In most cases, the k parameter logically belongs to the same table as the column used in the MATCH clause. So, when the column referenced in the MATCH clause includes a table qualifier (e.g., table1.embedding), the k parameter must be explicitly bound to the same table. 52 | ❌ Wrong Example: 53 | ```sql 54 | SELECT s.stock_id, s.symbol 55 | FROM stocks s 56 | JOIN exchanges e ON s.exchange_id = e.exchange_id 57 | WHERE s.sector_embedding MATCH lembed({embedding_model}, "Tech industry sector in the USA") 58 | AND e.country = 'USA' 59 | AND k = 5 60 | ORDER BY s.stock_id; 61 | ``` 62 | ✅ Correct Example: 63 | ```sql 64 | SELECT s.stock_id, s.symbol 65 | FROM stocks s 66 | JOIN exchanges e ON s.exchange_id = e.exchange_id 67 | WHERE s.sector_embedding MATCH lembed({embedding_model}, "Tech industry sector in the USA") 68 | AND e.country = 'USA' 69 | AND s.k = 5 70 | ORDER BY s.stock_id; 71 | ``` 72 | 12. This avoids runtime errors – many vector databases (e.g., SQLite with sqlite-vss, pgvector) enforce this requirement strictly. 73 | 13. Only a single 'ORDER BY distance' clause is allowed on vec0 KNN queries, not on other columns. 
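One more illustrative pair for rule 13 (the table and column names here are only examples): the sort must rely on the implicit distance column alone, and adding any other column to ORDER BY breaks the vec0 KNN query.
❌ Wrong Example:
```sql
SELECT doc_id, title, distance
FROM documents
WHERE content_embedding MATCH lembed({embedding_model}, "Quarterly revenue analysis for the retail sector")
AND k = 5
ORDER BY distance, title;
```
✅ Correct Example:
```sql
SELECT doc_id, title, distance
FROM documents
WHERE content_embedding MATCH lembed({embedding_model}, "Quarterly revenue analysis for the retail sector")
AND k = 5
ORDER BY distance;
```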
74 | ***Example of KNN queries of sqlite-vec*** 75 | first example(type of vector_embedding is float[384]): 76 | ```sql 77 | SELECT rowid, distance 78 | FROM vec_table 79 | WHERE vector_embedding MATCH lembed({embedding_model},"vector of sun") 80 | ORDER BY distance 81 | LIMIT 1; 82 | ``` 83 | 84 | second example(type of sentence_embedding is float[384]): 85 | ```sql 86 | select 87 | movie_id, 88 | title, 89 | genre, 90 | num_reviews, 91 | mean_rating, 92 | distance 93 | from vec_movies 94 | where sentence_embedding match lembed({embedding_model},"This is a great movie!") 95 | and genre = 'scifi' 96 | and num_reviews between 100 and 500 97 | and mean_rating > 3.5 98 | and k = 5; 99 | ``` 100 | 101 | third example(type of vector_embedding is float[384]): 102 | ```sql 103 | select rowid, name1, name2, age, vec_to_json 104 | from v 105 | where vector_embedding match lembed({embedding_model},"aaa and xxx are good friends, whose age is 18.") 106 | and k = 1 107 | and name1 in ('alex', 'brian', 'craig') 108 | and name2 in ('Rick', 'Morty') 109 | and age in (21, 18); 110 | ``` 111 | -------------------------------------------------------------------------------- /Data_Synthesizer/vectorization/batch_vectorize_databases.py: -------------------------------------------------------------------------------- 1 | # batch_vectorize_databases.py (最终版) 2 | 3 | import os, sys, logging, json 4 | from tqdm import tqdm 5 | import torch 6 | from sentence_transformers import SentenceTransformer, models 7 | from transformers import AutoConfig 8 | import torchvision 9 | from typing import Optional 10 | 11 | os.makedirs("logging", exist_ok=True) 12 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', filename='logging/out.log', filemode='w') 13 | torchvision.disable_beta_transforms_warning() 14 | 15 | try: 16 | from .vector_database_generate import generate_database_script, build_vector_database 17 | except ImportError as e: 18 | logging.critical(f"Import Error: {e}.") 19 | sys.exit(1) 20 | 21 | from dotenv import load_dotenv 22 | load_dotenv() 23 | 24 | def load_completion_status(status_file): 25 | if os.path.exists(status_file): 26 | try: 27 | with open(status_file, 'r', encoding='utf-8') as f: return json.load(f) 28 | except (json.JSONDecodeError, TypeError): return {} 29 | return {} 30 | 31 | def save_completion_status(status_file, completed_dbs_dict): 32 | with open(status_file, 'w', encoding='utf-8') as f: json.dump(completed_dbs_dict, f, indent=2) 33 | 34 | def find_database_file(base_path: str, db_id: str) -> Optional[str]: 35 | """ 36 | Finds a database file, checking for .sqlite and then .db extensions. 37 | Returns the full path if found, otherwise None. 
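    Example (hypothetical layout): with base_path="/data/vector_dbs/db_001" and
    db_id="db_001", this returns "/data/vector_dbs/db_001/db_001.sqlite" if that file
    exists, otherwise "/data/vector_dbs/db_001/db_001.db" if present, and None when
    neither file exists.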
38 | """ 39 | # Check for .sqlite first 40 | path_sqlite = os.path.join(base_path, f"{db_id}.sqlite") 41 | if os.path.exists(path_sqlite): 42 | return path_sqlite 43 | 44 | # If not found, check for .db 45 | path_db = os.path.join(base_path, f"{db_id}.db") 46 | if os.path.exists(path_db): 47 | return path_db 48 | 49 | # If neither exists 50 | return None 51 | 52 | def load_universal_sentence_transformer(model_name: str, cache_folder: str, device: str = 'cpu') -> SentenceTransformer: 53 | """ 54 | Load a SentenceTransformer model, transparently handling both standard Transformer models and CLIP models. 55 | 56 | Args: 57 | model_name (str): Name or path of the model to load (e.g., "all-MiniLM-L6-v2" or "openai/clip-vit-base-patch32"). 58 | cache_folder (str): Folder used to cache downloaded models. 59 | device (str, optional): Device to load the model on ('cpu', 'cuda', etc.). Defaults to 'cpu'. 60 | 61 | Returns: 62 | SentenceTransformer: The initialized model instance. 63 | """ 64 | try: 65 | config = AutoConfig.from_pretrained(model_name, cache_dir=cache_folder) 66 | is_clip_model = 'clip' in getattr(config, 'model_type', '').lower() 67 | except Exception as e: 68 | logging.warning(f"Could not load the config of model '{model_name}'. Falling back to a name-based check. Error: {e}") 69 | is_clip_model = 'clip' in model_name.lower() 70 | 71 | if is_clip_model: 72 | logging.info(f"Loading '{model_name}' as a CLIP model") 73 | clip_model_wrapper = models.CLIPModel(model_name) 74 | 75 | # The outer SentenceTransformer handles caching, so the cache_folder argument is required here 76 | model = SentenceTransformer(modules=[clip_model_wrapper], device=device, cache_folder=cache_folder) 77 | else: 78 | logging.info(f"Loading '{model_name}' as a standard model") 79 | model = SentenceTransformer(model_name, device=device, cache_folder=cache_folder) 80 | 81 | return model 82 | 83 | def main_batch_vectorize_databases( 84 | SOURCE_DB_ROOT, 85 | SQL_SCRIPT_DIR, 86 | VECTOR_DB_ROOT, 87 | TABLE_JSON_PATH, 88 | EMBEDDING_MODEL_NAME, 89 | model_path 90 | ): 91 | logging.info("--- Starting Batch Database Vectorization ---") 92 | if not all([SOURCE_DB_ROOT, SQL_SCRIPT_DIR, VECTOR_DB_ROOT, TABLE_JSON_PATH]): 93 | logging.critical("One or more required configurations are missing in main_batch_vectorize_databases. Please check (SOURCE_DB_ROOT, SQL_SCRIPT_DIR, VECTOR_DB_ROOT, TABLE_JSON_PATH).") 94 | sys.exit(1) 95 | 96 | if not os.path.exists(SOURCE_DB_ROOT): 97 | print(f"error: source database directory does not exist: {SOURCE_DB_ROOT}") 98 | 99 | os.makedirs(SQL_SCRIPT_DIR, exist_ok=True) 100 | os.makedirs(VECTOR_DB_ROOT, exist_ok=True) 101 | status_file_path = os.path.join(VECTOR_DB_ROOT, "processing_status.json") 102 | completed_dbs = load_completion_status(status_file_path) 103 | 104 | model, pool = None, None 105 | try: 106 | db_targets = [] 107 | if os.path.exists(SOURCE_DB_ROOT): 108 | for db_id in os.listdir(SOURCE_DB_ROOT): 109 | db_dir = os.path.join(SOURCE_DB_ROOT, db_id) 110 | if os.path.isdir(db_dir): 111 | # Call the helper function to find the database path 112 | db_path = find_database_file(db_dir, db_id) 113 | 114 | # If the helper found a file (either .sqlite or .db), its path will be returned 115 | if db_path: 116 | db_targets.append({'id': db_id, 'path': db_path}) 117 | 118 | if not db_targets: 119 | print(f"error: no database files found under {SOURCE_DB_ROOT}") 120 | return 121 | 122 | dbs_to_process = [target for target in db_targets if completed_dbs.get(target['id']) != 'db_built'] 123 | if not dbs_to_process: 124 | logging.info("All databases are already processed.") 125 | return 126 | 127 | if any(completed_dbs.get(target['id']) != 'sql_generated' for target in dbs_to_process): 128 | model = load_universal_sentence_transformer(EMBEDDING_MODEL_NAME, model_path) 129 | if torch.cuda.is_available(): pool = model.start_multi_process_pool() 130 | 131 | for target in tqdm(db_targets, desc="Overall Progress"): 132 | db_id, source_db_path = target['id'], target['path'] 133 | if completed_dbs.get(db_id) == 'db_built': continue 134 | 135 | logging.info(f"--- Processing database: {db_id} ---") 136 | sql_script_path = os.path.join(SQL_SCRIPT_DIR, f"{db_id}_vector.sql") 137 | final_db_path = os.path.join(VECTOR_DB_ROOT, db_id, f"{db_id}.sqlite") 138 | 139 | try: 140 | if completed_dbs.get(db_id) != 'sql_generated': 141 | logging.info(f"Step 1/2: Generating SQL for '{db_id}'...") 142 | generate_database_script(db_path=source_db_path, output_file=sql_script_path, embedding_model=model, pool=pool, table_json_path=TABLE_JSON_PATH) 143 | completed_dbs[db_id] = 'sql_generated' 144 | save_completion_status(status_file_path, completed_dbs) 145 | 146 | logging.info(f"Step 2/2: Building vector DB for '{db_id}'...") 147 | # Reverted to a simple, direct call 148 | build_vector_database(SQL_FILE=sql_script_path, DB_FILE=final_db_path) 149 | completed_dbs[db_id] = 'db_built' 150 | save_completion_status(status_file_path, completed_dbs) 151 | except Exception as e: 152 | logging.error(f"Error processing '{db_id}': {e}", exc_info=True) 153 | continue 154 | finally: 155 | if pool: model.stop_multi_process_pool(pool) 156 | logging.info("--- Batch Vectorization Process Completed ---") 157 | 158 | if __name__ == '__main__': 159 | SOURCE_DB_ROOT = "/mnt/b_public/data/ydw/Text2VectorSQL/Data_Synthesizer/pipeline/sqlite/train/arxiv" 160 | SQL_SCRIPT_DIR = "./vector_sql" 161 | VECTOR_DB_ROOT = "./vector_databases" 162 | TABLE_JSON_PATH = '/mnt/b_public/data/ydw/Text2VectorSQL/Data_Synthesizer/pipeline/sqlite/results/arxiv/find_semantic_tables.json' 163 | EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2" 164 | model_path = "/mnt/b_public/data/yaodongwen/model" 165 | main_batch_vectorize_databases( 166 | SOURCE_DB_ROOT, 167 | SQL_SCRIPT_DIR, 168 | VECTOR_DB_ROOT, 169 | TABLE_JSON_PATH, 170 | EMBEDDING_MODEL_NAME, 171 | model_path 172 | ) 173 | 
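The `__main__` block above wires `main_batch_vectorize_databases` to machine-specific absolute paths. Because the module already calls `load_dotenv()`, the same entry point can be driven from a `.env` file instead; a minimal sketch of such a wrapper follows. The `.env` key names and the package-qualified import are assumptions for illustration (they are not defined by the repository), and the repository root is assumed to be on `PYTHONPATH`.

```python
# Hypothetical wrapper: drive main_batch_vectorize_databases from environment variables.
# The env key names and the import path below are assumptions, not part of the shipped code.
import os

from dotenv import load_dotenv

from Data_Synthesizer.vectorization.batch_vectorize_databases import main_batch_vectorize_databases

load_dotenv()  # read .env from the current working directory

if __name__ == "__main__":
    main_batch_vectorize_databases(
        SOURCE_DB_ROOT=os.getenv("SOURCE_DB_ROOT"),
        SQL_SCRIPT_DIR=os.getenv("SQL_SCRIPT_DIR", "./vector_sql"),
        VECTOR_DB_ROOT=os.getenv("VECTOR_DB_ROOT", "./vector_databases"),
        TABLE_JSON_PATH=os.getenv("TABLE_JSON_PATH"),
        EMBEDDING_MODEL_NAME=os.getenv("EMBEDDING_MODEL_NAME", "all-MiniLM-L6-v2"),
        model_path=os.getenv("MODEL_CACHE_DIR", "./models"),
    )
```

This keeps host-specific paths such as `/mnt/...` out of version control, in line with the `.env`-based configuration used elsewhere in the Data_Synthesizer pipeline.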
-------------------------------------------------------------------------------- /Embedding_Service/multi_server.py: -------------------------------------------------------------------------------- 1 | # embedding_server.py 2 | 3 | import argparse 4 | import base64 # 【新增】导入 base64 用于解码 5 | import io # 【新增】导入 io 用于处理二进制数据 6 | import logging 7 | from contextlib import asynccontextmanager 8 | from typing import List, Optional # 【修改】导入 Optional 9 | 10 | import uvicorn 11 | import yaml 12 | from fastapi import FastAPI, HTTPException 13 | from pydantic import BaseModel, Field 14 | from PIL import Image # 【新增】导入 Pillow 用于图像处理 15 | # 导入 SentenceTransformer 保持不变 16 | from sentence_transformers import SentenceTransformer 17 | 18 | # --- Globals (保持不变) --- 19 | CONFIG = {} 20 | MODELS = {} 21 | 22 | # --- Logging Setup (保持不变) --- 23 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 24 | logger = logging.getLogger("EmbeddingService") 25 | 26 | # --- Pydantic Models for API validation --- 27 | class EmbeddingRequest(BaseModel): 28 | model: str = Field(..., description="The name of the model to use for embedding (must match a name in config.yaml).") 29 | # 【修改】将 'texts' 和 'images' 设为可选字段,以支持不同类型的输入 30 | texts: Optional[List[str]] = Field(None, description="A list of texts to embed.") 31 | images: Optional[List[str]] = Field(None, description="A list of Base64-encoded images to embed.") 32 | 33 | class EmbeddingResponse(BaseModel): 34 | model: str = Field(..., description="The name of the model used.") 35 | embeddings: List[List[float]] = Field(..., description="A list of embedding vectors.") 36 | 37 | # --- FastAPI Lifespan Management (保持不变) --- 38 | # The model loading logic does not need to change, as SentenceTransformer 39 | # handles loading of both text and multi-modal models transparently. 40 | @asynccontextmanager 41 | async def lifespan(app: FastAPI): 42 | """ 43 | Handles startup and shutdown events. 44 | Loads models on startup. 45 | """ 46 | logger.info("Starting up Embedding Service...") 47 | 48 | parser = argparse.ArgumentParser(description="Embedding Service with Sentence-Transformers and FastAPI") 49 | parser.add_argument("--config", type=str, default="config.yaml", help="Path to the configuration YAML file.") 50 | args = parser.parse_args() 51 | 52 | try: 53 | with open(args.config, 'r') as f: 54 | config_data = yaml.safe_load(f) 55 | CONFIG.update(config_data) 56 | logger.info(f"Configuration loaded from {args.config}") 57 | except FileNotFoundError: 58 | logger.error(f"Configuration file not found at {args.config}. Exiting.") 59 | exit(1) 60 | except Exception as e: 61 | logger.error(f"Error loading configuration: {e}. Exiting.") 62 | exit(1) 63 | 64 | if not CONFIG.get('models'): 65 | logger.error("No models found in the configuration file. 
Exiting.") 66 | exit(1) 67 | 68 | for model_config in CONFIG['models']: 69 | model_name = model_config.get('name') 70 | hf_path = model_config.get('hf_model_path') 71 | if not model_name or not hf_path: 72 | logger.warning(f"Skipping invalid model configuration: {model_config}") 73 | continue 74 | 75 | logger.info(f"Loading model '{model_name}' from '{hf_path}' using Sentence-Transformers...") 76 | try: 77 | model = SentenceTransformer( 78 | model_name_or_path=hf_path, 79 | trust_remote_code=model_config.get('trust_remote_code', True) 80 | ) 81 | MODELS[model_name] = model 82 | logger.info(f"Successfully loaded model '{model_name}'.") 83 | except Exception as e: 84 | logger.error(f"Failed to load model '{model_name}': {e}") 85 | 86 | if not MODELS: 87 | logger.error("No models were successfully loaded. Shutting down.") 88 | exit(1) 89 | 90 | yield 91 | 92 | logger.info("Shutting down Embedding Service...") 93 | MODELS.clear() 94 | 95 | 96 | # --- FastAPI App Initialization (保持不变) --- 97 | app = FastAPI( 98 | title="Text2VectorSQL Embedding Service (Sentence-Transformers Backend)", 99 | description="A high-performance API service for text and image embeddings, powered by Sentence-Transformers.", 100 | version="1.1.0", # Version bump 101 | lifespan=lifespan 102 | ) 103 | 104 | # --- API Endpoints --- 105 | @app.get("/health") 106 | async def health_check(): 107 | """Health check endpoint to verify service is running.""" 108 | return {"status": "ok", "loaded_models": list(MODELS.keys())} 109 | 110 | @app.post("/embed", response_model=EmbeddingResponse) 111 | async def create_embeddings(request: EmbeddingRequest): 112 | """ 113 | Takes a list of texts OR a list of images and returns their embeddings. 114 | """ 115 | if request.model not in MODELS: 116 | raise HTTPException(status_code=404, detail=f"Model '{request.model}' not found. Available models: {list(MODELS.keys())}") 117 | 118 | # 【核心修改】增加输入验证和处理逻辑 119 | # 1. 验证输入:确保只提供了 'texts' 或 'images' 中的一个 120 | if (request.texts is None and request.images is None) or \ 121 | (request.texts is not None and request.images is not None): 122 | raise HTTPException( 123 | status_code=400, 124 | detail="You must provide either 'texts' or 'images', but not both." 125 | ) 126 | 127 | model_engine = MODELS[request.model] 128 | 129 | try: 130 | inputs_to_encode = [] 131 | # 2. 根据输入类型准备数据 132 | if request.texts: 133 | inputs_to_encode = request.texts 134 | logger.info(f"Processing {len(request.texts)} texts with model '{request.model}'.") 135 | 136 | elif request.images: 137 | logger.info(f"Processing {len(request.images)} images with model '{request.model}'.") 138 | # 将 Base64 字符串解码为 Pillow 图像对象 139 | pil_images = [] 140 | for b64_string in request.images: 141 | try: 142 | # 从 Base64 字符串解码为字节 143 | image_bytes = base64.b64decode(b64_string) 144 | # 从字节数据创建 PIL.Image 对象 145 | image = Image.open(io.BytesIO(image_bytes)) 146 | pil_images.append(image) 147 | except Exception as img_e: 148 | logger.error(f"Failed to decode or open image: {img_e}") 149 | raise HTTPException(status_code=400, detail="Invalid Base64-encoded image data provided.") 150 | inputs_to_encode = pil_images 151 | 152 | # 3. 
调用 SentenceTransformer 的 encode 方法进行编码 153 | # 该方法可以透明地处理文本列表或图像对象列表 154 | embeddings = model_engine.encode(inputs_to_encode, convert_to_numpy=False) 155 | 156 | return EmbeddingResponse(model=request.model, embeddings=embeddings) 157 | 158 | except Exception as e: 159 | logger.error(f"Error during embedding process for model '{request.model}': {e}") 160 | raise HTTPException(status_code=500, detail=f"Internal server error during embedding: {e}") 161 | 162 | # --- Main execution block (保持不变) --- 163 | if __name__ == "__main__": 164 | # This part remains the same to parse config for host/port 165 | if not CONFIG: 166 | parser = argparse.ArgumentParser() 167 | parser.add_argument("--config", type=str, default="config.yaml") 168 | args, _ = parser.parse_known_args() 169 | try: 170 | with open(args.config, 'r') as f: 171 | config_data = yaml.safe_load(f) 172 | server_config = config_data.get('server', {}) 173 | host = server_config.get('host', '0.0.0.0') 174 | port = server_config.get('port', 8000) 175 | except Exception: 176 | host, port = "0.0.0.0", 8000 177 | else: 178 | server_config = CONFIG.get('server', {}) 179 | host = server_config.get('host', '0.0.0.0') 180 | port = server_config.get('port', 8000) 181 | 182 | uvicorn.run(app, host=host, port=port) -------------------------------------------------------------------------------- /Data_Synthesizer/database_synthesis/enhance_schema.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import concurrent.futures 3 | import json 4 | import os 5 | import re 6 | from functools import partial 7 | from json_repair import json_repair 8 | from tqdm import tqdm 9 | import hashlib 10 | from tenacity import retry, stop_after_attempt, wait_exponential 11 | import openai 12 | 13 | # 配置缓存目录 14 | CACHE_DIR = "./cache" 15 | os.makedirs(CACHE_DIR, exist_ok=True) 16 | 17 | def get_cache_key(prompt, model): 18 | """生成基于提示内容和模型名称的唯一缓存键""" 19 | key_str = f"{model}_{prompt}" 20 | return hashlib.md5(key_str.encode('utf-8')).hexdigest() 21 | 22 | def load_from_cache(cache_key): 23 | """从缓存加载数据""" 24 | cache_file = os.path.join(CACHE_DIR, f"{cache_key}.json") 25 | if os.path.exists(cache_file): 26 | with open(cache_file, 'r', encoding='utf-8') as f: 27 | return json.load(f) 28 | return None 29 | 30 | def save_to_cache(cache_key, data): 31 | """保存数据到缓存""" 32 | cache_file = os.path.join(CACHE_DIR, f"{cache_key}.json") 33 | with open(cache_file, 'w', encoding='utf-8') as f: 34 | json.dump(data, f, ensure_ascii=False, indent=2) 35 | 36 | def parse_response(response): 37 | """保持原有解析函数不变""" 38 | schema_pattern = r'```json\s*([\s\S]*?)\s*```' 39 | 40 | try: 41 | enhanced_schema_match = re.search(schema_pattern, response, re.DOTALL) 42 | enhanced_schema_str = enhanced_schema_match.group(0).strip() if enhanced_schema_match else None 43 | enhanced_schema_dict = json_repair.loads(enhanced_schema_str) 44 | return enhanced_schema_dict 45 | except Exception as e: 46 | print(response) 47 | print("Parsing Exception:", str(e)) 48 | return None 49 | 50 | def parse_prompt(prompt): 51 | """保持原有解析函数不变""" 52 | domain_pattern = r'(?<=\*\*Business Domain:\*\*)(.*?)(?=\*\*Business Scenario:\*\*)' 53 | scenario_pattern = r'(?<=\*\*Business Scenario:\*\*)(.*?)(?=\*\*Initial Database Schema:\*\*)' 54 | 55 | domain_match = re.search(domain_pattern, prompt, re.DOTALL) 56 | domain = domain_match.group(0).strip() if domain_match else None 57 | 58 | scenario_match = re.search(scenario_pattern, prompt, re.DOTALL) 59 | scenario = 
scenario_match.group(0).strip() if scenario_match else None 60 | 61 | return domain, scenario 62 | 63 | @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10)) 64 | def generate_response(client, model, prompt): 65 | """带重试机制的API调用函数(仅修改max_tokens)""" 66 | try: 67 | response = client.chat.completions.create( 68 | model=model, 69 | messages=[ 70 | {"role": "system", "content": "你是一个专业的数据库架构师,负责生成增强的数据库模式"}, 71 | {"role": "user", "content": prompt} 72 | ], 73 | temperature=0.7, 74 | max_tokens=10000 # 修改为10000 75 | ) 76 | return response.choices[0].message.content 77 | except Exception as e: 78 | print(f"API调用失败: {str(e)}") 79 | raise 80 | 81 | def limit_columns_per_table(schema_dict, limit=16): 82 | """ 83 | 遍历 schema 中的每个表,确保其总列数不超过指定的限制。 84 | 修剪策略分阶段进行: 85 | 1. 首先移除所有带 `_embedding` 后缀的列。 86 | 2. 如果列数仍然超限,则从后向前移除列,直到符合限制。 87 | """ 88 | if not isinstance(schema_dict, dict) or "tables" not in schema_dict: 89 | return schema_dict 90 | 91 | modified_tables = [] 92 | for table in schema_dict.get("tables", []): 93 | columns = table.get("columns", []) 94 | table_name = table.get('table_name', '未知') 95 | 96 | if len(columns) <= limit: 97 | modified_tables.append(table) 98 | continue 99 | 100 | print(f"警告: 表 '{table_name}' 的列数 {len(columns)} 超过了 {limit} 的限制。正在进行修剪...") 101 | 102 | # --- 第一阶段:移除所有 _embedding 列 --- 103 | cols_after_stage1 = [col for col in columns if not col.get('name', '').endswith('_embedding')] 104 | 105 | if len(cols_after_stage1) <= limit: 106 | print(f" - 操作: 移除所有 _embedding 列后,列数变为 {len(cols_after_stage1)},符合要求。") 107 | table['columns'] = cols_after_stage1 108 | modified_tables.append(table) 109 | continue 110 | 111 | # --- 第二阶段:如果仍然超限,则直接从后往前删除 --- 112 | print(f" - 注意: 移除 _embedding 列后,列数仍为 {len(cols_after_stage1)},将继续从后往前删除。") 113 | 114 | # 直接截断列表,只保留前 limit 个元素 115 | final_columns = cols_after_stage1[:limit] 116 | table['columns'] = final_columns 117 | 118 | print(f" - 操作完成: 表 '{table_name}' 的总列数已从 {len(columns)} 减少到 {len(final_columns)}。") 119 | modified_tables.append(table) 120 | 121 | schema_dict["tables"] = modified_tables 122 | return schema_dict 123 | 124 | def process_prompt(prompt, model, client): 125 | """处理单个提示的核心逻辑""" 126 | cache_key = get_cache_key(prompt, model) 127 | 128 | # 检查缓存 129 | cached = load_from_cache(cache_key) 130 | if cached: 131 | return cached 132 | 133 | # 生成响应 134 | response = generate_response(client, model, prompt) 135 | 136 | # 解析响应 137 | enhanced_schema = parse_response(response) 138 | if not enhanced_schema: 139 | return None 140 | 141 | # 新增逻辑:限制每个表的列数不超过16 142 | enhanced_schema = limit_columns_per_table(enhanced_schema, limit=16) 143 | 144 | # 解析提示 145 | domain, scenario = parse_prompt(prompt) 146 | 147 | # 构建结果 148 | result = { 149 | "prompt": prompt, 150 | "domain": domain, 151 | "scenario": scenario, 152 | "enhanced_schema": json.dumps(enhanced_schema, indent=2, ensure_ascii=False), 153 | "raw_response": response 154 | } 155 | 156 | # 保存到缓存 157 | save_to_cache(cache_key, result) 158 | return result 159 | 160 | def llm_inference(model, prompts, api_key=None, api_url=None, max_workers=32): 161 | """ 162 | 多线程推理函数(仅修改max_tokens) 163 | """ 164 | # 初始化OpenAI客户端 165 | client = openai.OpenAI( 166 | api_key=api_key, 167 | base_url=api_url.rstrip('/') if api_url else "https://api.openai.com/v1" 168 | ) 169 | 170 | results = [] 171 | 172 | # 使用线程池并行处理 173 | with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: 174 | # 创建部分函数固定参数 175 | process_fn = partial(process_prompt, model=model, 
client=client) 176 | 177 | # 提交所有任务 178 | futures = {executor.submit(process_fn, prompt): prompt for prompt in prompts} 179 | 180 | # 使用tqdm显示进度 181 | for future in tqdm(concurrent.futures.as_completed(futures), 182 | total=len(prompts), 183 | desc="处理进度"): 184 | try: 185 | result = future.result() 186 | if result: 187 | results.append(result) 188 | except Exception as e: 189 | prompt = futures[future] 190 | print(f"处理失败 - 提示: {prompt[:50]}... 错误: {str(e)}") 191 | 192 | return results 193 | 194 | if __name__ == '__main__': 195 | parser = argparse.ArgumentParser() 196 | parser.add_argument("--model", type=str, required=True) 197 | parser.add_argument("--api_key", type=str) 198 | parser.add_argument("--api_url", type=str) 199 | parser.add_argument("--max_workers", type=int, default=128) 200 | args = parser.parse_args() 201 | 202 | # 加载提示 203 | prompts = json.load(open("./prompts/prompts_schema_enhancement.json", encoding='utf-8')) 204 | 205 | # 执行推理 206 | results = llm_inference( 207 | model=args.model, 208 | prompts=prompts, 209 | api_key=args.api_key, 210 | api_url=args.api_url, 211 | max_workers=args.max_workers 212 | ) 213 | 214 | # 保存结果 215 | output_file = "./results/schema_enhancement.json" 216 | os.makedirs(os.path.dirname(output_file), exist_ok=True) 217 | with open(output_file, 'w', encoding='utf-8') as f: 218 | json.dump(results, f, indent=2, ensure_ascii=False) 219 | 220 | print(f"处理完成,结果已保存到 {output_file}") 221 | -------------------------------------------------------------------------------- /Data_Synthesizer/vectorization/generate_vector_schema.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import sqlite3 4 | import sqlite_vec # Import the sqlite_vec library 5 | from tqdm import tqdm 6 | # from dotenv import load_dotenv 7 | 8 | # --- Configuration from .env --- 9 | # load_dotenv() 10 | 11 | # Read the variables using os.getenv() 12 | # vector_db_root = os.getenv("vector_db_root_GENERATE_SCHEMA") 13 | # original_schema_path = os.getenv("original_schema_path") 14 | # output_dir = os.getenv("output_dir_GENERATE_SCHEMA") 15 | # output_json_path = os.getenv("output_json_path_GENERATE_SCHEMA") 16 | 17 | def generate_schema_for_db(db_id, db_path, original_schema): 18 | """ 19 | Connects to a single vector database, loads the vec extension, 20 | inspects its schema, extracts DDLs, and returns a dictionary. 21 | """ 22 | # --- MODIFICATION START --- 23 | # Added "ddls": [] to the schema structure 24 | new_schema = { 25 | "db_id": db_id, 26 | "ddls": [], 27 | "table_names_original": [], 28 | "table_names": [], 29 | "column_names_original": [[-1, "*"]], 30 | "column_names": [[-1, "*"]], 31 | "column_types": ["text"], 32 | "primary_keys": original_schema.get("primary_keys", []), 33 | "foreign_keys": original_schema.get("foreign_keys", []) 34 | } 35 | # --- MODIFICATION END --- 36 | 37 | try: 38 | # print(f"""processing da_path: {db_path}""") 39 | conn = sqlite3.connect(db_path) 40 | conn.enable_load_extension(True) 41 | sqlite_vec.load(conn) 42 | cursor = conn.cursor() 43 | 44 | # --- MODIFICATION START --- 45 | # Query sqlite_master to get the CREATE TABLE statements (DDLs) 46 | # This is done first to ensure we capture the schema definition. 47 | cursor.execute("SELECT sql FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY rowid") 48 | # Extract the SQL statements from the query result. Filter out any None values. 
49 | ddl_statements = [row[0] for row in cursor.fetchall() if row[0]] 50 | new_schema["ddls"] = ddl_statements 51 | # --- MODIFICATION END --- 52 | 53 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY rowid") 54 | table_names = [row[0] for row in cursor.fetchall()] 55 | 56 | new_schema["table_names_original"] = table_names 57 | new_schema["table_names"] = table_names 58 | 59 | table_name_to_idx = {name: i for i, name in enumerate(table_names)} 60 | 61 | for table_name in table_names: 62 | table_idx = table_name_to_idx[table_name] 63 | cursor.execute(f'PRAGMA table_xinfo("{table_name}");') 64 | columns_info = cursor.fetchall() 65 | 66 | for col in columns_info: 67 | col_name = col[1] 68 | col_type = col[2].upper() 69 | if col[5] != 0: # Skip generated columns 70 | continue 71 | new_schema["column_names_original"].append([table_idx, col_name]) 72 | new_schema["column_names"].append([table_idx, col_name]) 73 | if 'FLOAT' in col_type or '[' in col_type: 74 | new_schema["column_types"].append("text") 75 | else: 76 | new_schema["column_types"].append(col_type.lower()) 77 | 78 | conn.close() 79 | return new_schema 80 | except sqlite3.Error as e: 81 | print(f" [ERROR] Could not process database {db_id}: {e}") 82 | return None 83 | 84 | # 通用函数,用了产生向量数据库的schema 85 | def generate_vector_schema(vector_db_root,original_schema_path,output_dir,output_json_path): 86 | """ 87 | Main function to find vector databases, generate their schemas, 88 | and write the result to a new JSON file. 89 | """ 90 | print(f"Starting schema generation from vector databases in: {vector_db_root}") 91 | 92 | try: 93 | with open(original_schema_path, 'r', encoding='utf-8') as f: 94 | original_schemas_list = json.load(f) 95 | original_schemas = {item['db_id']: item for item in original_schemas_list} 96 | print(f"Loaded {len(original_schemas)} original schemas for reference.") 97 | except (FileNotFoundError, json.JSONDecodeError) as e: 98 | print(f"✖ Could not load original schema file '{original_schema_path}': {e}") 99 | return 100 | 101 | all_new_schemas = [] 102 | 103 | db_targets = [] 104 | try: 105 | if not os.path.exists(vector_db_root): 106 | raise FileNotFoundError 107 | 108 | for item_name in os.listdir(vector_db_root): 109 | full_path = os.path.join(vector_db_root, item_name) 110 | 111 | if os.path.isdir(full_path): 112 | db_id = item_name 113 | db_path_sqlite = os.path.join(full_path, f"{db_id}.sqlite") 114 | db_path_db = os.path.join(full_path, f"{db_id}.db") 115 | db_path_sqlite_final = os.path.join(full_path, f"{db_id}_final.sqlite") 116 | db_path_db_final = os.path.join(full_path, f"{db_id}_final.db") 117 | if os.path.exists(db_path_db_final): 118 | db_targets.append({'id': db_id, 'path': db_path_db_final}) 119 | elif os.path.exists(db_path_sqlite_final): 120 | db_targets.append({'id': db_id, 'path': db_path_sqlite_final}) 121 | elif os.path.exists(db_path_sqlite): 122 | db_targets.append({'id': db_id, 'path': db_path_sqlite}) 123 | elif os.path.exists(db_path_db): 124 | db_targets.append({'id': db_id, 'path': db_path_db}) 125 | 126 | elif os.path.isfile(full_path) and item_name.endswith(('.sqlite', '.db')): 127 | db_id = os.path.splitext(item_name)[0] 128 | db_targets.append({'id': db_id, 'path': full_path}) 129 | 130 | except FileNotFoundError: 131 | print(f"✖ Vector database directory not found: {vector_db_root}") 132 | return 133 | 134 | print(f"Found {len(db_targets)} potential databases. 
Processing...") 135 | 136 | for target in tqdm(db_targets, desc="Processing Databases"): 137 | db_id = target['id'] 138 | db_path = target['path'] 139 | 140 | if not os.path.exists(db_path): 141 | print(f" [WARN] Skipping '{db_id}': database file not found at {db_path}") 142 | continue 143 | 144 | if db_id not in original_schemas: 145 | print(f" [WARN] Skipping '{db_id}': no matching entry in original schema file.") 146 | continue 147 | 148 | original_schema = original_schemas[db_id] 149 | new_schema_data = generate_schema_for_db(db_id, db_path, original_schema) 150 | 151 | if new_schema_data: 152 | all_new_schemas.append(new_schema_data) 153 | 154 | try: 155 | if not os.path.exists(output_dir): 156 | os.makedirs(output_dir) 157 | 158 | # Ensure output_json_path is a full path if it's just a filename 159 | if os.path.dirname(output_json_path) == '': 160 | full_output_path = os.path.join(output_dir, output_json_path) 161 | else: 162 | full_output_path = output_json_path 163 | 164 | with open(full_output_path, 'w', encoding='utf-8') as f: 165 | json.dump(all_new_schemas, f, indent=2, ensure_ascii=False) 166 | print(f"\n✔ Successfully created '{full_output_path}' with {len(all_new_schemas)} database schemas.") 167 | except IOError as e: 168 | print(f"\n✖ Failed to write output file: {e}") 169 | 170 | # if __name__ == '__main__': 171 | # # Example of how you might call this function 172 | # # You would replace these paths with the ones from your config/env 173 | # generate_vector_schema( 174 | # vector_db_root="path/to/your/databases", 175 | # original_schema_path="path/to/original/schema.json", 176 | # output_dir="results", 177 | # output_json_path="vector_schemas_with_ddls.json" 178 | # ) 179 | -------------------------------------------------------------------------------- /Data_Synthesizer/synthesis_nl/synthesize_candidate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import os 5 | import time 6 | from concurrent.futures import ThreadPoolExecutor 7 | from functools import lru_cache 8 | from typing import Dict, List 9 | 10 | import openai 11 | from tqdm import tqdm 12 | from dotenv import load_dotenv # <-- Import dotenv 13 | 14 | # --- Load environment variables at the start --- 15 | load_dotenv() 16 | 17 | # --- 配置日志记录 --- 18 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 19 | 20 | # --- 缓存实现 --- 21 | CACHE_FILE = './cache/openai_cache.json' 22 | CACHE = {} 23 | 24 | def load_cache(): 25 | """如果缓存文件存在,则从中加载缓存。""" 26 | global CACHE 27 | if os.path.exists(CACHE_FILE): 28 | try: 29 | with open(CACHE_FILE, 'r', encoding='utf-8') as f: 30 | CACHE = json.load(f) 31 | logging.info(f"从缓存文件中加载了 {len(CACHE)} 个项目。") 32 | except (json.JSONDecodeError, IOError) as e: 33 | logging.warning(f"无法加载缓存文件,将使用空缓存启动。错误: {e}") 34 | CACHE = {} 35 | else: 36 | logging.info("未找到缓存文件,将使用空缓存启动。") 37 | 38 | def save_cache(): 39 | """将当前缓存保存到文件。""" 40 | try: 41 | with open(CACHE_FILE, 'w', encoding='utf-8') as f: 42 | json.dump(CACHE, f, indent=2, ensure_ascii=False) 43 | logging.info(f"已将 {len(CACHE)} 个项目保存到缓存文件。") 44 | except IOError as e: 45 | logging.error(f"保存缓存文件失败: {e}") 46 | 47 | # --- 关键修改: 将模型名称添加到缓存键中 --- 48 | def get_cache_key(model_name: str, question: str, original_sql: str) -> str: 49 | """根据模型、问题和SQL创建一致的缓存键。""" 50 | return f"{model_name}|{question}|{original_sql}" 51 | 52 | # --- OpenAI API 交互 --- 53 | @lru_cache(maxsize=None) 54 | # --- 关键修改: 添加 model_name 参数 
--- 55 | def get_sql_candidates_from_openai(client: openai.OpenAI, model_name: str, question: str, original_sql: str, num_candidates: int = 5) -> List[str]: 56 | """ 57 | 使用 OpenAI API 生成 SQL 候选项。 58 | 此函数首先会检查持久化缓存。 59 | """ 60 | cache_key = get_cache_key(model_name, question, original_sql) 61 | if cache_key in CACHE: 62 | logging.info(f"缓存命中: '{question[:50]}...'") 63 | return CACHE[cache_key] 64 | 65 | logging.info(f"缓存未命中,正在为问题调用API: '{question[:50]}...'") 66 | 67 | try: 68 | start_phrase = original_sql.split("lembed('all-MiniLM-L6-v2',")[1] 69 | original_phrase = start_phrase.split("')")[0].strip() 70 | if original_phrase.startswith(('"', "'")) and original_phrase.endswith(('"', "'")): 71 | original_phrase = original_phrase[1:-1] 72 | except IndexError: 73 | logging.error(f"无法从SQL中解析原始短语: {original_sql}") 74 | return [original_sql] 75 | 76 | system_prompt = ( 77 | "你是一个精通语义搜索和SQL的专家。你的任务是重写SQL查询中`lembed`函数内的搜索短语。" 78 | "给定用户的问题和原始SQL查询,你必须生成多个能够捕捉用户意图的替代短语。" 79 | "SQL查询的其余结构必须保持完全相同。" 80 | "只更改`lembed`函数的第二个参数。" 81 | "请以JSON数组字符串的形式提供输出,其中每个字符串都是一个完整的SQL查询。" 82 | ) 83 | 84 | user_prompt = f""" 85 | 原始问题: "{question}" 86 | 原始SQL: "{original_sql}" 87 | 原始搜索短语: "{original_phrase}" 88 | 89 | 根据问题,生成 {num_candidates} 个替代的SQL查询。每个查询都应与原始SQL相同,除了 `lembed('all-MiniLM-L6-v2', '...')` 中的搜索短语。 90 | 新的搜索短语在语义上应与原始问题中的用户意图相似或是其替代解释。 91 | 92 | 请仅返回一个有效的JSON数组字符串。例如: 93 | {{ 94 | "sql_candidate": [ 95 | "SELECT ... FROM ... WHERE ... MATCH lembed('all-MiniLM-L6-v2', '新短语1')", 96 | "SELECT ... FROM ... WHERE ... MATCH lembed('all-MiniLM-L6-v2', '新短语2')", 97 | "SELECT ... FROM ... WHERE ... MATCH lembed('all-MiniLM-L6-v2', '新短语3')" 98 | ] 99 | }} 100 | """ 101 | 102 | max_retries = 3 103 | retry_delay = 5 104 | for attempt in range(max_retries): 105 | try: 106 | response = client.chat.completions.create( 107 | model=model_name, # <-- 关键修改: 使用来自 .env 的模型名称 108 | messages=[ 109 | {"role": "system", "content": system_prompt}, 110 | {"role": "user", "content": user_prompt} 111 | ], 112 | response_format={"type": "json_object"}, 113 | temperature=0.5, 114 | ) 115 | content = response.choices[0].message.content 116 | result_data = json.loads(content) 117 | 118 | sql_list = [] 119 | if isinstance(result_data, dict) and "sql_candidate" in result_data: 120 | value = result_data["sql_candidate"] 121 | if isinstance(value, list) and all(isinstance(item, str) for item in value): 122 | sql_list = value 123 | 124 | if not sql_list: 125 | raise ValueError("JSON响应中未包含一个名为 'sql_candidate' 的SQL字符串列表。") 126 | 127 | CACHE[cache_key] = sql_list 128 | save_cache() 129 | return sql_list 130 | 131 | except (openai.APIError, json.JSONDecodeError, ValueError) as e: 132 | logging.error(f"在第 {attempt + 1} 次尝试中发生错误 ('{question[:50]}...'): {e}") 133 | if attempt < max_retries - 1: 134 | logging.info(f"将在 {retry_delay} 秒后重试...") 135 | time.sleep(retry_delay) 136 | else: 137 | logging.error(f"在 {max_retries} 次尝试后未能为 '{question[:50]}...' 
生成候选项。") 138 | return [original_sql] 139 | 140 | return [original_sql] 141 | 142 | # --- 主要处理逻辑 --- 143 | # --- 关键修改: 添加 model_name 参数 --- 144 | def process_item(item: Dict, client: openai.OpenAI, num_candidates: int, model_name: str) -> Dict: 145 | """ 146 | 处理单个JSON项目以添加 'sql_candidate' 字段。 147 | """ 148 | question = item.get("question") 149 | sql = item.get("sql") 150 | 151 | if not question or not sql: 152 | item['sql_candidate'] = [] 153 | return item 154 | 155 | candidates = get_sql_candidates_from_openai(client, model_name, question, sql, num_candidates) 156 | item['sql_candidate'] = candidates 157 | return item 158 | 159 | def synthesize_candidate(model_name,api_key,base_url,num_candidates,max_workers,input_file="results/question_and_sql_pairs.json",output_file="results/candidate_sql.json"): 160 | """ 161 | 主函数:解析参数,读取文件,处理数据,并写入输出。 162 | """ 163 | # 验证关键配置 164 | if not api_key or not model_name: 165 | raise ValueError("错误: API_KEY 和 LLM_MODEL_NAME 必须在 .env 文件中设置。") 166 | 167 | logging.info("-----------------------------------") 168 | logging.info("大模型开始为问题生成多个sql语句候选\n") 169 | logging.info("--- 文件配置 ---") 170 | logging.info(f"模型: {model_name}") 171 | logging.info(f"基础 URL: {base_url}") 172 | logging.info(f"每个问题的候选项数量: {num_candidates}") 173 | logging.info(f"最大并发线程数: {max_workers}") 174 | logging.info("-----------------------------------") 175 | # --- MODIFICATION END --- 176 | 177 | try: 178 | client = openai.OpenAI(api_key=api_key, base_url=base_url) 179 | except Exception as e: 180 | logging.error(f"初始化OpenAI客户端失败: {e}") 181 | return 182 | 183 | try: 184 | with open(input_file, 'r', encoding='utf-8') as f: 185 | data = json.load(f) 186 | except (FileNotFoundError, json.JSONDecodeError) as e: 187 | logging.error(f"读取或解析输入文件 '{input_file}' 失败: {e}") 188 | return 189 | 190 | load_cache() 191 | 192 | results = [] 193 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 194 | # --- 关键修改: 将 model_name 传递给 process_item --- 195 | futures = [executor.submit(process_item, item, client, num_candidates, model_name) for item in data] 196 | 197 | for future in tqdm(futures, total=len(data), desc="正在生成SQL候选项"): 198 | try: 199 | result = future.result() 200 | results.append(result) 201 | except Exception as e: 202 | logging.error(f"线程池中的一个任务失败: {e}") 203 | 204 | try: 205 | with open(output_file, 'w', encoding='utf-8') as f: 206 | json.dump(results, f, indent=2, ensure_ascii=False) 207 | logging.info(f"已成功将更新后的数据写入 {output_file}") 208 | except IOError as e: 209 | logging.error(f"写入输出文件 '{output_file}' 失败: {e}") 210 | 211 | # if __name__ == "__main__": 212 | # main() 213 | -------------------------------------------------------------------------------- /Embedding_Service/server.py: -------------------------------------------------------------------------------- 1 | # server.py 2 | 3 | import os 4 | import argparse 5 | import asyncio 6 | import functools 7 | import logging 8 | from contextlib import asynccontextmanager 9 | from threading import Lock 10 | from typing import List, Dict, Any 11 | 12 | import uvicorn 13 | import yaml 14 | import torch 15 | from fastapi import FastAPI, HTTPException 16 | from pydantic import BaseModel, Field 17 | from sentence_transformers import SentenceTransformer 18 | 19 | # --- Globals --- 20 | CONFIG: Dict[str, Any] = {} 21 | MODELS: Dict[str, Dict[str, Any]] = {} 22 | # 新增一个线程锁,以防止在多worker模式下可能出现的下载竞争问题 23 | model_download_lock = Lock() 24 | 25 | # --- Logging Setup --- 26 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - 
%(levelname)s - %(message)s') 27 | logger = logging.getLogger("EmbeddingService") 28 | 29 | # --- 【新增】模型下载与准备的辅助函数 --- 30 | def prepare_model_path(model_config: Dict[str, Any]) -> str: 31 | """ 32 | 检查本地模型路径是否存在。如果不存在,则从Hugging Face下载并保存。 33 | 返回最终可供加载的本地模型路径。 34 | 此函数设计为线程安全。 35 | """ 36 | hf_path = model_config.get("hf_model_path") 37 | local_path = model_config.get("local_model_path") 38 | 39 | if not hf_path or not local_path: 40 | raise ValueError(f"模型 '{model_config.get('name')}' 的配置缺少 'hf_model_path' 或 'local_model_path'。") 41 | 42 | # 使用一个关键文件(如config.json)来判断模型是否已完整存在 43 | local_config_file = os.path.join(local_path, "config.json") 44 | 45 | if os.path.exists(local_config_file): 46 | logger.info(f"在本地路径 '{local_path}' 找到模型。将直接加载。") 47 | return local_path 48 | 49 | # 如果本地不存在,则加锁以确保只有一个进程/线程执行下载 50 | with model_download_lock: 51 | # 双重检查,防止在等待锁的过程中,其他线程已经下载完毕 52 | if os.path.exists(local_config_file): 53 | logger.info(f"在等待锁后,发现模型已存在于 '{local_path}'。") 54 | return local_path 55 | 56 | logger.warning(f"本地模型未找到。开始从 '{hf_path}' 下载...") 57 | logger.warning("(首次下载会花费一些时间,请耐心等待...)") 58 | 59 | try: 60 | # 1. 下载模型到 huggingface 的默认缓存中 61 | model = SentenceTransformer(hf_path) 62 | # 2. 将完整的模型文件保存到我们指定的永久本地路径 63 | model.save(local_path) 64 | logger.info(f"✅ 模型成功下载并保存到: '{local_path}'") 65 | return local_path 66 | except Exception as e: 67 | logger.error(f"❌ 下载或保存模型 '{hf_path}' 时发生错误: {e}", exc_info=True) 68 | raise 69 | 70 | # --- Pydantic Models (保持不变) --- 71 | class EmbeddingRequest(BaseModel): 72 | model: str = Field(..., description="The name of the model to use for embedding (must match a name in config.yaml).") 73 | texts: List[str] = Field(..., description="A list of texts to embed.") 74 | 75 | class EmbeddingResponse(BaseModel): 76 | model: str = Field(..., description="The name of the model used.") 77 | embeddings: List[List[float]] = Field(..., description="A list of embedding vectors.") 78 | 79 | # --- FastAPI Lifespan Management (已修改) --- 80 | @asynccontextmanager 81 | async def lifespan(app: FastAPI): 82 | """ 83 | 处理启动和关闭事件。 84 | 在启动时,会先准备好模型(下载或使用缓存),然后再加载到GPU。 85 | """ 86 | global CONFIG, MODELS 87 | logger.info("Starting up Embedding Service...") 88 | 89 | parser = argparse.ArgumentParser() 90 | parser.add_argument("--config", type=str, default="config.yaml", help="Path to the configuration YAML file.") 91 | args, _ = parser.parse_known_args() 92 | 93 | try: 94 | with open(args.config, 'r') as f: 95 | CONFIG.update(yaml.safe_load(f)) 96 | logger.info(f"Configuration loaded from {args.config}") 97 | except Exception as e: 98 | logger.error(f"加载配置文件失败: {e}. Exiting.", exc_info=True) 99 | exit(1) 100 | 101 | if not CONFIG.get('models'): 102 | logger.error("配置文件中未找到模型定义. 
Exiting.") 103 | exit(1) 104 | 105 | for model_config in CONFIG['models']: 106 | model_name = model_config.get('name') 107 | try: 108 | # 【核心修改】在加载模型前,先调用辅助函数确保模型已在本地准备好 109 | final_model_path = prepare_model_path(model_config) 110 | 111 | logger.info(f"开始从本地路径 '{final_model_path}' 加载模型 '{model_name}'...") 112 | 113 | # 从准备好的本地路径加载模型 114 | model = SentenceTransformer( 115 | model_name_or_path=final_model_path, 116 | trust_remote_code=model_config.get('trust_remote_code', True) 117 | ) 118 | 119 | max_len = model_config.get('max_model_len') 120 | if isinstance(max_len, int): 121 | model.max_seq_length = max_len 122 | logger.info(f"模型 '{model_name}' 的最大序列长度设置为 {max_len}.") 123 | 124 | pool = None 125 | parallel_size = model_config.get('tensor_parallel_size', 1) 126 | 127 | if parallel_size > 1: 128 | if not torch.cuda.is_available(): 129 | logger.warning(f"CUDA 不可用。模型 '{model_name}' 将在 CPU 单进程上运行。") 130 | else: 131 | num_gpus = torch.cuda.device_count() 132 | if num_gpus < parallel_size: 133 | logger.warning(f"请求 {parallel_size} 个 GPU, 但只有 {num_gpus} 个可用。将使用所有可用的GPU。") 134 | parallel_size = num_gpus 135 | 136 | target_devices = [f'cuda:{i}' for i in range(parallel_size)] 137 | logger.info(f"为模型 '{model_name}' 在设备 {target_devices} 上启动多进程池...") 138 | pool = model.start_multi_process_pool(target_devices=target_devices) 139 | logger.info(f"✅ 成功为 '{model_name}' 启动多进程池。") 140 | 141 | MODELS[model_name] = {"engine": model, "pool": pool} 142 | logger.info(f"✅ 模型 '{model_name}' 加载成功。") 143 | 144 | except Exception as e: 145 | logger.error(f"❌ 加载模型 '{model_name}' 失败: {e}", exc_info=True) 146 | 147 | if not MODELS: 148 | logger.error("没有任何模型被成功加载。服务将关闭。") 149 | exit(1) 150 | 151 | yield 152 | 153 | logger.info("Shutting down Embedding Service...") 154 | for model_name, model_data in MODELS.items(): 155 | if model_data.get("pool"): 156 | logger.info(f"正在停止模型 '{model_name}' 的多进程池...") 157 | SentenceTransformer.stop_multi_process_pool(model_data["pool"]) 158 | MODELS.clear() 159 | 160 | # --- FastAPI App & Endpoints (保持不变) --- 161 | app = FastAPI( 162 | title="Intelligent Embedding Service", 163 | description="自动处理模型缓存的高性能、多GPU嵌入服务。", 164 | version="2.0.0", 165 | lifespan=lifespan 166 | ) 167 | 168 | @app.get("/health") 169 | async def health_check(): 170 | model_status = {name: f"pool: {'active' if data.get('pool') else 'inactive'}" for name, data in MODELS.items()} 171 | return {"status": "ok", "loaded_models": model_status} 172 | 173 | @app.post("/embed", response_model=EmbeddingResponse) 174 | async def create_embeddings(request: EmbeddingRequest): 175 | model_entry = MODELS.get(request.model) 176 | if not model_entry: 177 | raise HTTPException(status_code=404, detail=f"Model '{request.model}' not found. 
Available models: {list(MODELS.keys())}") 178 | 179 | model_engine = model_entry["engine"] 180 | pool = model_entry.get("pool") 181 | 182 | try: 183 | loop = asyncio.get_event_loop() 184 | encode_func = functools.partial(model_engine.encode, request.texts, pool=pool, batch_size=256, convert_to_numpy=False) 185 | embeddings = await loop.run_in_executor(None, encode_func) 186 | return EmbeddingResponse(model=request.model, embeddings=embeddings) 187 | except Exception as e: 188 | logger.error(f"处理嵌入请求时出错: {e}", exc_info=True) 189 | raise HTTPException(status_code=500, detail="Internal server error during embedding.") 190 | 191 | # --- Main execution block (保持不变) --- 192 | if __name__ == "__main__": 193 | parser = argparse.ArgumentParser() 194 | parser.add_argument("--config", type=str, default="config.yaml", help="Path to the configuration file.") 195 | args, _ = parser.parse_known_args() 196 | 197 | host, port = "0.0.0.0", 8000 198 | try: 199 | with open(args.config, 'r') as f: 200 | config_data = yaml.safe_load(f) 201 | server_config = config_data.get('server', {}) 202 | host = server_config.get('host', host) 203 | port = server_config.get('port', port) 204 | except Exception: 205 | pass 206 | 207 | uvicorn.run("server:app", host=host, port=port) 208 | --------------------------------------------------------------------------------
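A minimal client sketch for the `/embed` endpoint defined above may help illustrate the request/response contract shared by `server.py` and `multi_server.py`. The host, port, and model name here are assumptions: the service defaults to `0.0.0.0:8000` unless `config.yaml` overrides them, and the `model` value must match a `name` entry in that file.

```python
# Hypothetical client for the embedding service; host/port and the model name
# "all-MiniLM-L6-v2" are assumptions that must match your config.yaml.
import requests

BASE_URL = "http://localhost:8000"  # the server binds 0.0.0.0:8000 by default

# Optional sanity check: list the models the service actually loaded.
health = requests.get(f"{BASE_URL}/health", timeout=10).json()
print("Loaded models:", health["loaded_models"])

payload = {
    "model": "all-MiniLM-L6-v2",
    "texts": ["High efficiency audio codec", "Tech industry sector in the USA"],
}
resp = requests.post(f"{BASE_URL}/embed", json=payload, timeout=60)
resp.raise_for_status()

embeddings = resp.json()["embeddings"]
print(f"Received {len(embeddings)} vectors of dimension {len(embeddings[0])}")
```

`multi_server.py` accepts the same payload and, alternatively, an `images` field containing Base64-encoded images in place of `texts`.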