├── requirements.txt
├── description_generation.png
├── main.py
├── call_llamaindex_llm.py
├── parallel_main.py
├── utils.py
├── README.md
├── type_engine.py
├── default_prompts.py
├── components.py
├── v6-duckdb
│   └── parallel_llm_processor.py
├── mschema.py
├── LICENSE
└── schema_engine.py
/requirements.txt:
--------------------------------------------------------------------------------
1 | llama-index
2 | numpy
3 | llama-index-llms-dashscope
4 |
--------------------------------------------------------------------------------
/description_generation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XGenerationLab/XiYan-DBDescGen/HEAD/description_generation.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | from llama_index.llms.dashscope import DashScope, DashScopeGenerationModels
3 | from sqlalchemy import create_engine
4 | from schema_engine import SchemaEngine
5 |
6 | dashscope_llm = DashScope(model_name=DashScopeGenerationModels.QWEN_PLUS, api_key='YOUR API KEY HERE.')
7 |
8 | db_path = './book_1.sqlite'
9 | db_abs_path = os.path.abspath(db_path)
10 | db_engine = create_engine(f'sqlite:///{db_abs_path}')
11 |
12 | comment_mode = 'generation'
13 | schema_engine_instance = SchemaEngine(db_engine, llm=dashscope_llm, db_name='book_1',
14 | comment_mode=comment_mode)
15 | schema_engine_instance.fields_category()
16 | schema_engine_instance.table_and_column_desc_generation()
17 | mschema = schema_engine_instance.mschema
18 | mschema.save('./book_1.json')
19 | mschema_str = mschema.to_mschema()
20 | print(mschema_str)
--------------------------------------------------------------------------------
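Once main.py has produced `book_1.json`, the M-Schema can be reloaded later without touching the database or the LLM; a minimal sketch using the `MSchema.load` API from mschema.py:

```python
# Reload a previously saved M-Schema; no database connection or LLM needed.
from mschema import MSchema

mschema = MSchema()
mschema.load('./book_1.json')  # restores db_id, schema, tables and foreign_keys
print(mschema.to_mschema())    # renders the same M-Schema string as main.py
```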
/call_llamaindex_llm.py:
--------------------------------------------------------------------------------
1 | import time
2 | from llama_index.core.llms import LLM, ChatMessage
3 | from llama_index.core.prompts import BasePromptTemplate
4 | from typing import Optional, Sequence
5 |
6 |
7 | def call_llm(prompt: BasePromptTemplate, llm: Optional[LLM] = None, max_try=5, sleep=10, **prompt_args) -> str:
8 |     """Render the prompt template and call the LLM, retrying up to max_try times."""
9 |     for try_idx in range(max_try):
10 |         try:
11 |             res = llm.predict(prompt, **prompt_args)
12 |             return res
13 |         except Exception:
14 |             time.sleep(sleep)
15 |     # All attempts failed; return an empty string so callers can degrade gracefully.
16 |     return ''
17 |
18 |
19 | def call_llm_message(messages: Sequence[ChatMessage], llm: Optional[LLM] = None, max_try=5, sleep=10, **kwargs) -> str:
20 |     """Send a sequence of chat messages to the LLM, retrying up to max_try times."""
21 |     for try_idx in range(max_try):
22 |         try:
23 |             res = llm.chat(messages, **kwargs)
24 |             return res.message.content
25 |         except Exception:
26 |             time.sleep(sleep)
27 |     return ''
28 |
--------------------------------------------------------------------------------
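v6-duckdb/parallel_llm_processor.py imports a `call_llm_with_json_validation` helper from this module, but no such function appears in the listing. A hypothetical sketch of what it might look like, inferred only from its call site (it is expected to return a `{"success": ..., "data": ...}` dict and to accept `required_fields`, `expected_type`, and `max_retries`):

```python
# Hypothetical sketch only: this helper is imported by
# v6-duckdb/parallel_llm_processor.py but is missing from the listing.
from llama_index.core.prompts import PromptTemplate
from utils import extract_simple_json_from_qwen

def call_llm_with_json_validation(prompt, llm=None, required_fields=None,
                                  expected_type=None, max_retries=3, **kwargs):
    # The parallel processor passes a plain string; wrap it for llm.predict().
    if isinstance(prompt, str):
        prompt = PromptTemplate(prompt)
    for _ in range(max_retries):
        raw = call_llm(prompt, llm, **kwargs)
        data = extract_simple_json_from_qwen(raw)
        if expected_type is not None and not isinstance(data, expected_type):
            continue
        if required_fields and not all(f in data for f in required_fields):
            continue
        if data:
            return {"success": True, "data": data}
    return {"success": False, "data": {}}
```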
/parallel_main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import argparse
4 | from llama_index.llms.dashscope import DashScope, DashScopeGenerationModels
5 | from sqlalchemy import create_engine
6 | from schema_engine import SchemaEngine
7 | from parallel_schema_engine import ParallelSchemaEngine
8 |
9 |
10 | def main():
11 | parser = argparse.ArgumentParser(description="Process database schema with parallel or sequential execution")
12 | parser.add_argument('--parallel', action='store_true', help='Use parallel processing')
13 | parser.add_argument('--db_path', type=str, default='./book_1.sqlite', help='Path to the database file')
14 | parser.add_argument('--output_json', type=str, default='./output_schema.json', help='Output JSON file path')
15 | parser.add_argument('--max_workers', type=int, default=4,
16 | help='Maximum number of worker threads for parallel processing')
17 | parser.add_argument('--comment_mode', type=str, default='generation',
18 | choices=['origin', 'merge', 'generation', 'no_comment'],
19 | help='Comment mode for schema generation')
20 | parser.add_argument('--api_key', type=str, default='YOUR API KEY HERE.', help='DashScope API key')
21 | args = parser.parse_args()
22 |
23 | # Initialize LLM
24 | dashscope_llm = DashScope(model_name=DashScopeGenerationModels.QWEN_PLUS, api_key=args.api_key)
25 |
26 | # Get absolute path to database
27 | db_abs_path = os.path.abspath(args.db_path)
28 | db_engine = create_engine(f'sqlite:///{db_abs_path}')
29 |
30 | # Get database name from the file path
31 | db_name = os.path.splitext(os.path.basename(args.db_path))[0]
32 |
33 | # Start timing
34 | start_time = time.time()
35 |
36 | # Choose between parallel and sequential processing
37 | if args.parallel:
38 | print(f"Using parallel processing with {args.max_workers} workers")
39 | schema_engine_instance = ParallelSchemaEngine(
40 | db_engine,
41 | llm=dashscope_llm,
42 | db_name=db_name,
43 | comment_mode=args.comment_mode,
44 | max_workers=args.max_workers
45 | )
46 | else:
47 | print("Using sequential processing")
48 | schema_engine_instance = SchemaEngine(
49 | db_engine,
50 | llm=dashscope_llm,
51 | db_name=db_name,
52 | comment_mode=args.comment_mode
53 | )
54 |
55 | # Process the database schema
56 | print("Categorizing fields...")
57 | schema_engine_instance.fields_category()
58 |
59 | print("Generating table and column descriptions...")
60 | schema_engine_instance.table_and_column_desc_generation()
61 |
62 | # Save the result
63 | mschema = schema_engine_instance.mschema
64 | mschema.save(args.output_json)
65 |
66 | # Print the result
67 | mschema_str = mschema.to_mschema()
68 | print(mschema_str)
69 |
70 | # Print execution time
71 | end_time = time.time()
72 | elapsed_time = end_time - start_time
73 | print(f"Execution completed in {elapsed_time:.2f} seconds")
74 |
75 |
76 | if __name__ == "__main__":
77 | main()
--------------------------------------------------------------------------------
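Typical invocations of the script above, using only the flags it defines; the database path matches the default and the API key is a placeholder:

```shell
# Sequential processing with generated descriptions
python parallel_main.py --db_path ./book_1.sqlite --api_key YOUR_API_KEY

# Parallel processing with 8 worker threads, keeping original comments
python parallel_main.py --parallel --max_workers 8 --comment_mode origin \
    --db_path ./book_1.sqlite --output_json ./book_1.json --api_key YOUR_API_KEY
```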
/utils.py:
--------------------------------------------------------------------------------
1 | import re, ast
2 | import datetime, decimal
3 | import json
4 |
5 |
6 | def write_json(path, data):
7 | with open(path, 'w', encoding='utf-8') as f:
8 | json.dump(data, f, ensure_ascii=False, indent=2)
9 |
10 |
11 | def read_json(path):
12 | with open(path, 'r', encoding='utf-8') as f:
13 | data = json.load(f)
14 | return data
15 |
16 |
17 | def read_text(filename) -> list:
18 | data = []
19 | with open(filename, 'r', encoding='utf-8') as file:
20 | for line in file.readlines():
21 | line = line.strip()
22 | data.append(line)
23 | return data
24 |
25 |
26 | def save_raw_text(filename, content):
27 | with open(filename, 'w', encoding='utf-8') as file:
28 | file.write(content)
29 |
30 |
31 | def save_json(target_file,js,indent=4):
32 | with open(target_file, 'w', encoding='utf-8') as f:
33 | json.dump(js, f, ensure_ascii=False, indent=indent)
34 |
35 |
36 | def is_email(string):
37 | pattern = r'^[\w\.-]+@[\w\.-]+\.\w+$'
38 | match = re.match(pattern, string)
39 | if match:
40 | return True
41 | else:
42 | return False
43 |
44 |
45 | def extract_sql_from_llm_response(llm_response: str) -> str:
46 | """
47 | Parse SQL from LLM response in markdown format
48 | """
49 |
50 | sql = llm_response
51 | pattern = r"```sql(.*?)```"
52 |
53 | sql_code_snippets = re.findall(pattern, llm_response, re.DOTALL)
54 |
55 | if len(sql_code_snippets) > 0:
56 | sql = sql_code_snippets[-1].strip()
57 |
58 | return sql
59 |
60 |
61 | def examples_to_str(examples: list) -> list[str]:
62 |     """
63 |     Convert raw example values to display strings, dropping values that should
64 |     not be shown (emails, URLs) and keeping a single example for date values.
65 |     """
66 |     values = examples
67 |     for i in range(len(values)):
68 |         if isinstance(values[i], (datetime.date, datetime.datetime)):
69 |             # One example is enough for date/time values.
70 |             values = [values[i]]
71 |             break
72 |         elif isinstance(values[i], decimal.Decimal):
73 |             values[i] = str(float(values[i]))
74 |         elif is_email(str(values[i])) or 'http://' in str(values[i]) or 'https://' in str(values[i]):
75 |             # Emails and URLs are likely sensitive or noisy; show no examples.
76 |             values = []
77 |             break
78 |
79 |     return [str(v) for v in values if v is not None and len(str(v)) > 0]
80 |
81 |
82 | def extract_simple_json_from_qwen(qwen_result) -> dict:
83 |     """Extract a JSON object from an LLM response wrapped in markdown json fences."""
84 |     qwen_result = qwen_result.replace('\n', '')
85 |     pattern = r"```json(.*?)```"
86 |     # re.DOTALL lets the dot match any character, including newlines.
87 |     snippets = re.findall(pattern, qwen_result, re.DOTALL)
88 |     if len(snippets) == 0:
89 |         return {}
90 |     data = snippets[-1].strip()
91 |     try:
92 |         return json.loads(data)
93 |     except json.JSONDecodeError:
94 |         # Fall back to Python-literal parsing for single-quoted or slightly
95 |         # malformed output, normalizing JSON literals first (a heuristic).
96 |         try:
97 |             data = data.replace('false', 'False').replace('true', 'True').replace('null', 'None')
98 |             parsed = ast.literal_eval(data)
99 |             return parsed if isinstance(parsed, dict) else {}
100 |         except (ValueError, SyntaxError):
101 |             print("an error happened while parsing the LLM JSON output")
102 |             return {}
103 |
--------------------------------------------------------------------------------
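A quick sanity check of the two extraction helpers on typical fenced LLM output; the responses below are illustrative:

```python
from utils import extract_sql_from_llm_response, extract_simple_json_from_qwen

FENCE = "```"  # avoids literal backtick fences inside this example
sql_resp = f"Here is the query:\n{FENCE}sql\nSELECT name FROM book LIMIT 3;\n{FENCE}"
assert extract_sql_from_llm_response(sql_resp) == "SELECT name FROM book LIMIT 3;"

json_resp = f'{FENCE}json\n{{"table_desc": "Stores book metadata."}}\n{FENCE}'
assert extract_simple_json_from_qwen(json_resp) == {"table_desc": "Stores book metadata."}
```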
/README.md:
--------------------------------------------------------------------------------
1 | # Automatic database description generation for Text-to-SQL
2 |
3 | ## Important Links
4 |
5 | 🤖[arXiv](https://arxiv.org/abs/2502.20657) |
6 | 📖[XiYan-SQL](https://github.com/XGenerationLab/XiYan-SQL)
7 |
8 |
9 | ## Introduction
10 | This repository provides a method for automatically generating effective database descriptions when explicit descriptions are unavailable. The proposed method employs a dual-process approach: a coarse-to-fine process, followed by a fine-to-coarse process. Experimental results on the Bird benchmark indicate that using descriptions generated by the proposed method improves SQL generation accuracy by 0.93% compared to not using descriptions, and achieves 37% of human-level performance.
11 | We support four common database dialects: SQLite, MySQL, PostgreSQL, and SQL Server.
12 |
13 | Read more: [arXiv](https://arxiv.org/abs/2502.20657)
14 |
15 |
16 |
17 |
18 | ## Requirements
19 | + python >= 3.9
20 |
21 | You can install the required packages with the following command:
22 | ```shell
23 | pip install -r requirements.txt
24 | ```
25 |
26 | ## Quick Start
27 |
28 | 1. Create a database connection.
29 |
30 | Connect to SQLite:
31 | ```python
32 | import os
33 | from sqlalchemy import create_engine
34 |
35 | db_path = "path_to_sqlite"
36 | abs_path = os.path.abspath(db_path)
37 | db_engine = create_engine(f'sqlite:///{abs_path}')
38 | ```
39 |
40 | 2. Set llama-index LLM.
41 |
42 | Take dashscope as an example:
43 | ```python
44 | from llama_index.llms.dashscope import DashScope, DashScopeGenerationModels
45 | dashscope_llm = DashScope(model_name=DashScopeGenerationModels.QWEN_PLUS, api_key='YOUR API KEY HERE.')
46 | ```
47 |
48 | 3. Generate the database description and build M-Schema.
49 | ```python
50 | from schema_engine import SchemaEngine
51 |
52 | db_name = 'your_db_name'
53 | comment_mode = 'generation'
54 | schema_engine_instance = SchemaEngine(db_engine, llm=dashscope_llm, db_name=db_name,
55 | comment_mode=comment_mode)
56 | schema_engine_instance.fields_category()
57 | schema_engine_instance.table_and_column_desc_generation()
58 | mschema = schema_engine_instance.mschema
59 | mschema.save(f'./{db_name}.json')
60 | mschema_str = mschema.to_mschema()
61 | print(mschema_str)
62 | ```
63 |
64 | ## Contact Us
65 |
66 | If you are interested in our research or products, please feel free to contact us.
67 |
68 | #### Contact Information:
69 |
70 | Yifu Liu, zhencang.lyf@alibaba-inc.com
71 |
72 | #### Join Our DingTalk Group
73 |
74 | DingTalk Group (钉钉群)
75 |
76 |
77 |
78 | ## Citation
79 | If you find our work helpful, please feel free to cite us.
80 |
81 | ```bibtex
82 | @article{description_generation,
83 | title={Automatic database description generation for Text-to-SQL},
84 | author={Yingqi Gao and Zhiling Luo},
85 | year={2025},
86 | eprint={2502.20657},
87 | archivePrefix={arXiv},
88 | primaryClass={cs.AI},
89 | url={https://arxiv.org/abs/2502.20657},
90 | }
91 |
92 | @article{XiYanSQL,
93 | title={XiYan-SQL: A Novel Multi-Generator Framework For Text-to-SQL},
94 | author={Yifu Liu and Yin Zhu and Yingqi Gao and Zhiling Luo and Xiaoxia Li and Xiaorong Shi and Yuntao Hong and Jinyang Gao and Yu Li and Bolin Ding and Jingren Zhou},
95 | year={2025},
96 | eprint={2507.04701},
97 | archivePrefix={arXiv},
98 | primaryClass={cs.CL},
99 | url={https://arxiv.org/abs/2507.04701},
100 | }
101 |
102 | @article{xiyansql_pre,
103 | title={A Preview of XiYan-SQL: A Multi-Generator Ensemble Framework for Text-to-SQL},
104 | author={Yingqi Gao and Yifu Liu and Xiaoxia Li and Xiaorong Shi and Yin Zhu and Yiming Wang and Shiqi Li and Wei Li and Yuntao Hong and Zhiling Luo and Jinyang Gao and Liyu Mou and Yu Li},
105 | year={2024},
106 | journal={arXiv preprint arXiv:2411.08599},
107 | url={https://arxiv.org/abs/2411.08599},
108 | primaryClass={cs.AI}
109 | }
110 | ```
111 |
112 |
--------------------------------------------------------------------------------
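The Quick Start above connects to SQLite; per the Introduction, MySQL, PostgreSQL, and SQL Server are supported through the same SQLAlchemy engine. A sketch with illustrative hosts, ports, and credentials; the driver packages named in the comments are not in requirements.txt and would need to be installed separately:

```python
from sqlalchemy import create_engine

# MySQL (e.g. via the pymysql driver)
db_engine = create_engine('mysql+pymysql://user:password@localhost:3306/your_db')

# PostgreSQL (e.g. via the psycopg2 driver)
db_engine = create_engine('postgresql+psycopg2://user:password@localhost:5432/your_db')

# SQL Server (e.g. via the pyodbc driver)
db_engine = create_engine(
    'mssql+pyodbc://user:password@localhost:1433/your_db?driver=ODBC+Driver+17+for+SQL+Server'
)
```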
/type_engine.py:
--------------------------------------------------------------------------------
1 | class TypeEngine:
2 |     def __init__(self):
3 |         pass
4 |
5 | @property
6 | def supported_dialects(self):
7 | return [self.mysql_dialect, self.postgres_dialect, self.sqlite_dialect, self.sqlserver_dialect]
8 |
9 | @property
10 | def mysql_dialect(self):
11 | return 'mysql'
12 |
13 | @property
14 | def postgres_dialect(self):
15 | return 'postgresql'
16 |
17 | @property
18 | def sqlite_dialect(self):
19 | return 'sqlite'
20 |
21 | @property
22 | def sqlserver_dialect(self):
23 | return 'mssql'
24 |
25 | def field_type_abbr(self, field_type: str):
26 | """字段类型缩写,用于在MSchema中展示"""
27 | return field_type.split("(")[0]
28 |
29 | @property
30 | def mysql_date_types(self):
31 | return ['DATE', 'TIME', 'DATETIME', 'TIMESTAMP', 'YEAR']
32 |
33 | @property
34 | def pg_date_types(self):
35 | return ['DATE', 'TIME', 'TIMESTAMP', 'TIMESTAMP WITHOUT TIME ZONE', 'TIMESTAMP WITH TIME ZONE']
36 |
37 | @property
38 | def date_date_type(self):
39 | return 'DATE'
40 |
41 | @property
42 | def date_time_type(self):
43 | return 'TIME'
44 |
45 | @property
46 | def date_datetime_type(self):
47 | return 'DATETIME'
48 |
49 | @property
50 | def date_timestamp_type(self):
51 | return 'TIMESTAMP'
52 |
53 | @property
54 | def all_date_types(self):
55 | return ['DATE', 'TIME', 'DATETIME', 'TIMESTAMP', 'TIMESTAMP WITHOUT TIME ZONE',
56 | 'TIMESTAMP WITH TIME ZONE']
57 |
58 | @property
59 | def mysql_string_types(self):
60 | return ['BLOB', 'TINYBLOB', 'MEDIUMBLOB', 'LONGBLOB', 'CHAR', 'VARCHAR',
61 | 'TEXT', 'TINYTEXT', 'MEDIUMTEXT', 'LONGTEXT']
62 |
63 | @property
64 | def pg_string_types(self):
65 | return ['CHARACTER VARYING', 'VARCHAR', 'CHAR', 'CHARACTER', 'TEXT']
66 |
67 | @property
68 | def all_string_types(self):
69 | return self.mysql_string_types + self.pg_string_types
70 |
71 | @property
72 | def mysql_number_types(self):
73 | return ['TINYINT', 'SMALLINT', 'MEDIUMINT', 'INT', 'INTEGER', 'BIGINT',
74 | 'FLOAT', 'DOUBLE', 'DECIMAL']
75 |
76 | @property
77 | def pg_number_types(self):
78 | return ['SMALLINT', 'INTEGER', 'BIGINT',
79 | 'DECIMAL', 'NUMERIC', 'REAL', 'DOUBLE PRECISION',
80 | 'SMALLSERIAL', 'SERIAL', 'BIGSERIAL']
81 |
82 | @property
83 | def all_number_types(self):
84 | return self.mysql_number_types + self.pg_number_types
85 |
86 | @property
87 | def all_enum_types(self):
88 | return ['ENUM', 'SET']
89 |
90 | def field_type_cate(self, field_type: str) -> str:
91 | """根据数据类型分组"""
92 | field_type = self.field_type_abbr(field_type.upper())
93 | if field_type in self.all_number_types:
94 | return self.field_type_number_label
95 | elif field_type in self.all_string_types:
96 | return self.field_type_string_label
97 | elif field_type in self.all_date_types:
98 | return self.field_type_date_label
99 | elif field_type in ['BOOL', 'BOOLEAN']:
100 | return self.field_type_bool_label
101 | else:
102 | return self.field_type_other_label
103 |
104 | @property
105 | def date_time_min_grans(self):
106 | """时间日期类字段的最小颗粒度"""
107 | return ['YEAR', 'MONTH', 'DAY', 'QUARTER', 'WEEK', 'HOUR', 'MINUTE',
108 | 'SECOND', 'MILLISECOND', 'MICROSECOND', 'OTHER']
109 |
110 | @property
111 | def field_type_all_labels(self):
112 | return [self.field_type_number_label, self.field_type_string_label, self.field_type_date_label,
113 | self.field_type_bool_label, self.field_type_other_label]
114 |
115 | @property
116 | def field_type_number_label(self):
117 | return 'Number'
118 |
119 | @property
120 | def field_type_string_label(self):
121 | return 'String'
122 |
123 | @property
124 | def field_type_date_label(self):
125 | return 'DateTime'
126 |
127 | @property
128 | def field_type_bool_label(self):
129 | return 'Bool'
130 |
131 | @property
132 | def field_type_other_label(self):
133 | return 'Other'
134 |
135 | @property
136 | def field_category_all_labels(self):
137 | return [self.field_category_code_label, self.field_category_enum_label,
138 | self.field_category_date_label, self.field_category_text_label,
139 | self.field_category_measure_label]
140 |
141 | @property
142 | def field_category_code_label(self):
143 | return 'Code'
144 |
145 | @property
146 | def field_category_enum_label(self):
147 | return 'Enum'
148 |
149 | @property
150 | def field_category_date_label(self):
151 | return 'DateTime'
152 |
153 | @property
154 | def field_category_text_label(self):
155 | return 'Text'
156 |
157 | @property
158 | def field_category_measure_label(self):
159 | return 'Measure'
160 |
161 | @property
162 | def dim_measure_labels(self):
163 | return [self.dimension_label, self.measure_label]
164 |
165 | @property
166 | def dimension_label(self):
167 | return 'Dimension'
168 |
169 | @property
170 | def measure_label(self):
171 | return 'Measure'
172 |
--------------------------------------------------------------------------------
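A quick illustration of how `field_type_cate` buckets raw column types into the five coarse labels after stripping any length suffix; the type strings here are illustrative:

```python
from type_engine import TypeEngine

te = TypeEngine()
assert te.field_type_cate('VARCHAR(255)') == te.field_type_string_label   # 'String'
assert te.field_type_cate('DECIMAL(10,2)') == te.field_type_number_label  # 'Number'
assert te.field_type_cate('timestamp') == te.field_type_date_label        # 'DateTime'
assert te.field_type_cate('JSONB') == te.field_type_other_label           # 'Other'
```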
/default_prompts.py:
--------------------------------------------------------------------------------
1 | from llama_index.core.prompts import PromptTemplate
2 | from llama_index.core.prompts.prompt_type import PromptType
3 |
4 | DEFAULT_IS_DATE_TIME_FIELD_TMPL = """你现在是一名数据分析师,给你数据表中某一列的相关信息,请你分析该列是否属于时间日期类型,仅回答"是"或"否"。
5 | 时间日期类型指的是由年、月、日、时、分、秒中的一种或几种组合而成的,要求月份必须在1-12之间,日期在1-31之间,小时在0-23之间,分钟和秒在0-59之间。
6 |
7 | {field_info_str}
8 | """
9 |
10 | DEFAULT_IS_DATE_TIME_FIELD_PROMPT = PromptTemplate(
11 | DEFAULT_IS_DATE_TIME_FIELD_TMPL,
12 | prompt_type=PromptType.CUSTOM,
13 | )
14 |
15 | # Minimum granularity of date/time fields
16 | DEFAULT_DATE_TIME_MIN_GRAN_TMPL = """你现在是一名数据分析师,给你数据表中的一个字段,已知该字段表示的含义与时间日期有关,请你根据该字段的组成格式和数据样例,推测该字段的最小颗粒度是什么。
17 | 说明:时间日期字段的最小颗粒度是指该字段能够精确到的最小时间单位。
18 |
19 | 以下是常见的时间单位:
20 | YEAR: 最小时间单位是一年,例如,2024
21 | MONTH: 某年的第几个月,一年有12个月,Month的取值在1-12之间,例如,2024-12
22 | DAY: 某月的第几天,一个月最多有31天,因此Day取值在1-31之间,例如,2024-12-31
23 | WEEK: 自然周,一般为一年中的第几周,一年包含52周多几天,Week通常在0-53之间,例如,2024-34
24 | QUARTER: 某年的第几个季度,一年有四个季度,Quarter通常在1-4取值
25 | HOUR: 某天的第几个小时,一天有24个小时,Hour在0-23之间
26 | MINUTE: 某小时的第几分钟,一小时有60分钟,Minute在0-59之间
27 | SECOND: 某分钟的第几秒,一分钟有60秒,Second在0-59之间
28 | MILLISECOND: 毫秒
29 | MICROSECOND: 微秒
30 | OTHER: 其他不属于以上的时间单位,比如半年、一刻钟等
31 |
32 | 直接给出最小时间单位的名称。
33 |
34 | 以下样例供你参考:
35 | 【字段信息】
36 | 字段名称: dt
37 | 数据类型: DOUBLE
38 | Value Examples: [202412.0, 202301.0, 202411.0, 202201.0, 202308.0, 202110.0, 202211.0]
39 | 最小时间单位: MONTH
40 |
41 | 【字段信息】
42 | 字段名称: dt
43 | 数据类型: TEXT
44 | Value Examples: ['2022-12', '2022-14', '2021-40', '2021-37', '2021-01', '2021-32', '2023-04', '2023-37']
45 | 最小时间单位: WEEK
46 |
47 | 【字段信息】
48 | 字段名称: dt
49 | 数据类型: TEXT
50 | Value Examples: ['12:30:30', '23:45:23', '01:23:12', '12:12:12', '14:34:31', '18:43:01', '22:13:21']
51 | 最小时间单位: SECOND
52 |
53 | 请你参考以上样例,推测下面字段的最小时间单位,直接给出最小时间单位的名称。
54 | 【字段信息】
55 | {field_info_str}
56 | 最小时间单位: """
57 |
58 | DEFAULT_DATE_TIME_MIN_GRAN_PROMPT = PromptTemplate(
59 | DEFAULT_DATE_TIME_MIN_GRAN_TMPL,
60 | prompt_type=PromptType.CUSTOM,
61 | )
62 |
63 |
64 | DEFAULT_STRING_CATEGORY_FIELD_TMPL = '''你现在是一名数据分析师,给你数据表中某一列的相关信息,请你分析该列是enum类型、code类型、还是text类型,仅回答"enum"、"code"或"text"。
65 |
66 | enum:具有枚举的特征:字段取值相对固定,集中在一个预定义的有限集合内,通常长度较短,组成模式相对固定,一般用于状态、类型等字段;
67 | code:有特定意义的编码,code的组成通常存在一定的规律或标准,比如用户id、身份证号等;
68 | text:自由文本,通常用于描述或说明,不受长度和形式的限制,内容可以是任何形式的文本。
69 |
70 | {field_info_str}
71 | '''
72 |
73 | DEFAULT_STRING_CATEGORY_FIELD_PROMPT = PromptTemplate(
74 | DEFAULT_STRING_CATEGORY_FIELD_TMPL,
75 | prompt_type=PromptType.CUSTOM,
76 | )
77 |
78 | DEFAULT_NUMBER_CATEGORY_FIELD_TMPL = """你现在是一名数据分析师,给你数据表中某一列的相关信息,请你分析该列是enum类型、code类型、还是measure类型,仅回答"enum"、"code"或"measure"。
79 |
80 | enum:枚举类型,取值局限于一个预定义的有限集合,通常长度较短,一般用于状态、类型等字段;
81 | code:有特定意义的编码,code的组成通常存在一定的规律或标准,比如用户id、身份证号等;
82 | measure:指标、度量,可以用来做进行计算和聚合,比如求平均、最大值等。
83 |
84 | {field_info_str}
85 | """
86 |
87 | DEFAULT_NUMBER_CATEGORY_FIELD_PROMPT = PromptTemplate(
88 | DEFAULT_NUMBER_CATEGORY_FIELD_TMPL,
89 | prompt_type=PromptType.CUSTOM,
90 | )
91 |
92 | DEFAULT_UNKNOWN_CATEGORY_FIELD_TMPL = """你现在是一名数据分析师,给你数据表中某一列的相关信息,请你分析该列是enum类型、measure类型、code类型、还是text类型,仅回答"enum"、"measure"、"code"或"text"。
93 |
94 | enum:枚举类型,取值局限于一个预定义的有限集合,通常长度较短,一般用于状态、类型等字段;
95 | code:有特定意义的编码,code的组成通常存在一定的规律或标准,比如用户id、身份证号等;
96 | text:自由文本,通常用于描述或说明,不受长度限制,内容可以是任何形式的文本;
97 | measure:指标、度量,可以用来做进行计算和聚合,比如求平均、最大值等。
98 |
99 | {field_info_str}
100 | """
101 |
102 | DEFAULT_UNKNOWN_FIELD_PROMPT = PromptTemplate(
103 | DEFAULT_UNKNOWN_CATEGORY_FIELD_TMPL,
104 | prompt_type=PromptType.CUSTOM,
105 | )
106 |
107 | DEFAULT_COLUMN_DESC_GEN_CHINESE_TMPL = '''你现在是一名数据分析师,给你一张数据表的字段信息和一些数据样例如下:
108 |
109 | {table_mschema}
110 |
111 | 【SQL】
112 | {sql}
113 | 【Examples】
114 | {sql_res}
115 |
116 | 下面是该表中字段"{field_name}"的详细信息:
117 | {field_info_str}
118 |
119 | 以下信息可供你参考:
120 | {supp_info}
121 |
122 | 现在请你仔细阅读并理解上述内容和数据,为字段"{field_name}"添加中文名称,要求如下:
123 | 1、中文名称尽可能简洁清晰,准确描述该字段所表示的业务语义,不要偏离原有的字段描述;
124 | 2、字段中文名的长度不要超过20个字;
125 | 3、按json格式输出:
126 | ```json
127 | {"chinese_name": ""}
128 | ```
129 | '''
130 |
131 | DEFAULT_COLUMN_DESC_GEN_CHINESE_PROMPT = PromptTemplate(
132 | DEFAULT_COLUMN_DESC_GEN_CHINESE_TMPL,
133 | prompt_type=PromptType.CUSTOM,
134 | )
135 |
136 | DEFAULT_COLUMN_DESC_GEN_ENGLISH_TMPL = '''你现在是一名数据分析师,给你一张数据表的字段信息和一些数据样例如下:
137 |
138 | {table_mschema}
139 |
140 | 【SQL】
141 | {sql}
142 | 【Examples】
143 | {sql_res}
144 |
145 | 下面是该表中字段"{field_name}"的详细信息:
146 | {field_info_str}
147 |
148 | 以下信息可供你参考:
149 | {supp_info}
150 |
151 | 现在请你仔细阅读并理解上述内容和数据,为字段"{field_name}"添加英文描述,要求如下:
152 | 1、英文描述要尽可能简洁清晰,准确描述该字段所表示的业务语义,不要偏离原有的字段描述;
153 | 2、总输出长度不要超过20个单词;
154 | 3、按json格式输出:
155 | ```json
156 | {"english_desc": ""}
157 | ```
158 | '''
159 |
160 | DEFAULT_COLUMN_DESC_GEN_ENGLISH_PROMPT = PromptTemplate(
161 | DEFAULT_COLUMN_DESC_GEN_ENGLISH_TMPL,
162 | prompt_type=PromptType.CUSTOM,
163 | )
164 |
165 |
166 | DEFAULT_UNDERSTAND_DATABASE_TMPL = '''你现在是一名数据分析师,给你一个数据库的Schema如下:
167 |
168 | {db_mschema}
169 |
170 | 请你仔细阅读以上信息,在database的层面上分析,该数据库主要存储的是什么领域的什么数据,给出总结即可,不需要针对每张表分析。
171 | '''
172 |
173 | DEFAULT_UNDERSTAND_DATABASE_PROMPT = PromptTemplate(
174 | DEFAULT_UNDERSTAND_DATABASE_TMPL,
175 | prompt_type=PromptType.CUSTOM,
176 | )
177 |
178 | DEFAULT_GET_DOMAIN_KNOWLEDGE_TMPL = '''有这样一个数据库,基本信息如下:
179 | {db_info}
180 |
181 | 结合你所学习到的知识分析,在该领域,人们通常关心的维度和指标有哪些?
182 | '''
183 |
184 | DEFAULT_GET_DOMAIN_KNOWLEDGE_PROMPT = PromptTemplate(
185 | DEFAULT_GET_DOMAIN_KNOWLEDGE_TMPL,
186 | prompt_type=PromptType.CUSTOM,
187 | )
188 |
189 | # Understand, per category, how the fields relate to and differ from one another
190 | DEFAULT_UNDERSTAND_FIELDS_BY_CATEGORY_TMPL = '''你现在是一名数据分析师,给你一个数据的基本信息:
191 |
192 | 【数据库信息】
193 | {db_info}
194 |
195 | 其中数据表"{table_name}"的字段信息和数据样例如下:
196 | {table_mschema}
197 |
198 | 【SQL】
199 | {sql}
200 | 【Examples】
201 | {sql_res}
202 |
203 | 请你仔细阅读并理解该数据表,已知表中的{fields}字段均为 {category} 字段,请你分析这几个字段之间的关系和区别是什么?
204 | '''
205 |
206 | DEFAULT_UNDERSTAND_FIELDS_BY_CATEGORY_PROMPT = PromptTemplate(
207 | DEFAULT_UNDERSTAND_FIELDS_BY_CATEGORY_TMPL,
208 | prompt_type=PromptType.CUSTOM,
209 | )
210 |
211 | DEFAULT_TABLE_DESC_GEN_CHINESE_TMPL = '''你现在是一名数据分析师,给你一张数据表的字段信息如下:
212 |
213 | {table_mschema}
214 |
215 | 以下是一些数据样例:
216 | 【SQL】
217 | {sql}
218 | 【Examples】
219 | {sql_res}
220 |
221 | 现在请你仔细阅读并理解上述内容和数据,为该数据表生成一段中文的表描述,要求:
222 | 1、说明该表在何种维度(包括时间维度和其他维度)上存储了什么指标数据;
223 | 2、字数控制在100字以内。
224 | 3、回答以json格式输出。
225 |
226 | ```json
227 | {"table_desc": ""}
228 | ```
229 | '''
230 |
231 | DEFAULT_TABLE_DESC_GEN_CHINESE_PROMPT = PromptTemplate(
232 | DEFAULT_TABLE_DESC_GEN_CHINESE_TMPL,
233 | prompt_type=PromptType.CUSTOM,
234 | )
235 |
236 | DEFAULT_TABLE_DESC_GEN_ENGLISH_TMPL = '''你现在是一名数据分析师,给你一张数据表的字段信息如下:
237 |
238 | {table_mschema}
239 |
240 | 以下是一些数据样例:
241 | 【SQL】
242 | {sql}
243 | 【Examples】
244 | {sql_res}
245 |
246 | 现在请你仔细阅读并理解上述内容和数据,为该数据表生成一段英文的表描述,要求:
247 | 1、说明该表在何种维度(包括时间维度和其他维度)上存储了什么指标数据;
248 | 2、长度不要超过100个单词。
249 | 3、回答以json格式输出。
250 |
251 | ```json
252 | {"table_desc": ""}
253 | ```
254 | '''
255 |
256 | DEFAULT_TABLE_DESC_GEN_ENGLISH_PROMPT = PromptTemplate(
257 | DEFAULT_TABLE_DESC_GEN_ENGLISH_TMPL,
258 | prompt_type=PromptType.CUSTOM,
259 | )
260 |
261 | DEFAULT_SQL_GEN_TMPL = '''你现在是一名{dialect}数据分析师,给你一个数据库的Schema信息如下:
262 |
263 | 【数据库Schema】
264 | {db_mschema}
265 |
266 | 【用户问题】
267 | {question}
268 | 【参考信息】
269 | {evidence}
270 |
271 | 请你仔细阅读并理解该数据库,根据用户问题和参考信息的提示,生成一句可执行的SQL来回答用户问题,生成的SQL用```sql 和```保护起来。
272 | '''
273 |
274 | DEFAULT_SQL_GEN_PROMPT = PromptTemplate(
275 | DEFAULT_SQL_GEN_TMPL,
276 | prompt_type=PromptType.CUSTOM,
277 | )
278 |
279 |
--------------------------------------------------------------------------------
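Each template above is an ordinary llama-index `PromptTemplate`, so it can be rendered standalone with `format`, which is what `llm.predict` does inside `call_llm`. The field info string below is illustrative; in the pipeline it is assembled by the schema engine:

```python
from default_prompts import DEFAULT_IS_DATE_TIME_FIELD_PROMPT

# Illustrative field info; the real string is built by the schema engine.
field_info = "Field name: dt\nData type: TEXT\nValue Examples: ['2024-01-02', '2023-11-30']"
print(DEFAULT_IS_DATE_TIME_FIELD_PROMPT.format(field_info_str=field_info))
```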
/components.py:
--------------------------------------------------------------------------------
1 | from llama_index.core.llms import LLM
2 | from typing import Any, Dict, Iterable, List, Optional, Tuple
3 | from utils import extract_sql_from_llm_response, extract_simple_json_from_qwen
4 | from default_prompts import (
5 | DEFAULT_IS_DATE_TIME_FIELD_PROMPT,
6 | DEFAULT_NUMBER_CATEGORY_FIELD_PROMPT,
7 | DEFAULT_STRING_CATEGORY_FIELD_PROMPT,
8 | DEFAULT_UNKNOWN_FIELD_PROMPT,
9 | DEFAULT_COLUMN_DESC_GEN_CHINESE_PROMPT,
10 | DEFAULT_COLUMN_DESC_GEN_ENGLISH_PROMPT,
11 | DEFAULT_TABLE_DESC_GEN_CHINESE_PROMPT,
12 | DEFAULT_TABLE_DESC_GEN_ENGLISH_PROMPT,
13 | DEFAULT_UNDERSTAND_FIELDS_BY_CATEGORY_PROMPT,
14 | DEFAULT_UNDERSTAND_DATABASE_PROMPT,
15 | DEFAULT_GET_DOMAIN_KNOWLEDGE_PROMPT,
16 | DEFAULT_DATE_TIME_MIN_GRAN_PROMPT,
17 | DEFAULT_SQL_GEN_PROMPT
18 | )
19 | from call_llamaindex_llm import call_llm, call_llm_message
20 | from type_engine import TypeEngine
21 |
22 |
23 | def understand_date_time_min_gran(field_info_str: str = '', llm: Optional[LLM] = None):
24 | """
25 | Determine the minimum granularity of date and time fields.
26 | """
27 | res = call_llm(
28 | prompt=DEFAULT_DATE_TIME_MIN_GRAN_PROMPT,
29 | llm=llm,
30 | field_info_str=field_info_str
31 | )
32 | return res.upper().strip()
33 |
34 |
35 | def understand_database(db_mschema: str = '', llm: Optional[LLM] = None):
36 | """
37 | Database understanding.
38 | """
39 | db_info1 = call_llm(DEFAULT_UNDERSTAND_DATABASE_PROMPT, llm, db_mschema=db_mschema)
40 | db_info2 = call_llm(DEFAULT_GET_DOMAIN_KNOWLEDGE_PROMPT, llm, db_info=db_info1)
41 | return (db_info1 + '\n' + db_info2).strip()
42 |
43 |
44 | def generate_column_desc(field_name: str, field_info_str: str = '', table_mschema: str = '',
45 | llm: Optional[LLM] = None, sql: Optional[str] = None, sql_res: Optional[str] = None,
46 | supp_info: Optional[str] = None, language: Optional[str] = 'CN'):
47 |
48 | if language == 'CN':
49 | prompt = DEFAULT_COLUMN_DESC_GEN_CHINESE_PROMPT
50 | elif language == 'EN':
51 | prompt = DEFAULT_COLUMN_DESC_GEN_ENGLISH_PROMPT
52 | else:
53 | raise NotImplementedError(f'Unsupported language {language}.')
54 |
55 | column_desc = call_llm(
56 | prompt,
57 | llm,
58 | table_mschema=table_mschema,
59 | sql=sql,
60 | sql_res=sql_res,
61 | field_name=field_name,
62 | field_info_str=field_info_str,
63 | supp_info=supp_info
64 | ).strip()
65 |
66 | if language == 'CN':
67 | column_desc = extract_simple_json_from_qwen(column_desc).get('chinese_name', '')
68 | column_desc = column_desc.replace('"', '').replace('“', '').replace('”', '').replace('**', '')
69 | if column_desc.endswith('。'):
70 | column_desc = column_desc[:-1].strip()
71 | elif language == 'EN':
72 | column_desc = extract_simple_json_from_qwen(column_desc).get('english_desc', '')
73 | column_desc = column_desc.strip()
74 | if column_desc.startswith(':') or column_desc.startswith(':'):
75 | column_desc = column_desc[1:].strip()
76 | column_desc = column_desc.split('\n')[0]
77 |
78 | return column_desc.strip()
79 |
80 | def generate_table_desc(table_name: str, table_mschema: str = '',
81 | llm: Optional[LLM] = None, sql: Optional[str] = None, sql_res: Optional[str] = None,
82 | language: Optional[str] = 'CN'):
83 | if language == 'CN':
84 | prompt = DEFAULT_TABLE_DESC_GEN_CHINESE_PROMPT
85 | elif language == 'EN':
86 | prompt = DEFAULT_TABLE_DESC_GEN_ENGLISH_PROMPT
87 | else:
88 | raise NotImplementedError(f'Unsupported language {language}.')
89 |
90 | table_desc = call_llm(
91 | prompt,
92 | llm,
93 | table_name=table_name,
94 | table_mschema=table_mschema,
95 | sql=sql,
96 | sql_res=sql_res
97 | )
98 | table_desc = extract_simple_json_from_qwen(table_desc).get('table_desc', '')
99 | table_desc = table_desc.strip()
100 |
101 | return table_desc.strip()
102 |
103 |
104 | def understand_fields_by_category(db_info: str, table_name: str, table_mschema: str = '',
105 | llm: Optional[LLM] = None, sql: Optional[str] = None, sql_res: Optional[str] = None,
106 | fields: Optional[List] = [], dim_or_meas: str = ''):
107 | text = call_llm(
108 | DEFAULT_UNDERSTAND_FIELDS_BY_CATEGORY_PROMPT,
109 | llm,
110 | db_info=db_info,
111 | table_name=table_name,
112 | table_mschema=table_mschema,
113 | sql=sql,
114 | sql_res=sql_res,
115 | fields='、'.join([f"{field}" for field in fields]),
116 | category=dim_or_meas
117 | )
118 | return text
119 |
120 |
121 | def field_category(field_type_cate: str, type_engine: TypeEngine, llm: Optional[LLM] = None,
122 | field_info_str: str = ''):
123 | """
124 | Distinguish field category and whether dimension or measure.
125 |     is_unique_pk_cons: whether the field is a primary key, foreign key, or unique key (including fields in a composite primary key)
126 | """
127 | code_res = {"category": type_engine.field_category_code_label,
128 | "dim_or_meas": type_engine.dimension_label}
129 | enum_res = {"category": type_engine.field_category_enum_label,
130 | 'dim_or_meas': type_engine.dimension_label}
131 | date_res = {"category": type_engine.field_category_date_label,
132 | 'dim_or_meas': type_engine.dimension_label}
133 | measure_res = {"category": type_engine.field_category_measure_label,
134 | 'dim_or_meas': type_engine.measure_label}
135 | text_res = {"category": type_engine.field_category_text_label,
136 | 'dim_or_meas': type_engine.dimension_label}
137 |
138 | if field_type_cate == type_engine.field_type_date_label:
139 | return date_res
140 | elif field_type_cate == type_engine.field_type_bool_label:
141 | return enum_res
142 | else:
143 | kwargs = {"llm": llm, "field_info_str": field_info_str}
144 | is_date_time = call_llm(
145 | DEFAULT_IS_DATE_TIME_FIELD_PROMPT, **kwargs
146 | ).strip()
147 | if is_date_time == '是':
148 | return date_res
149 | else:
150 | if field_type_cate == type_engine.field_type_string_label:
151 |                 # Non-date string: decide among code, text, and enum
152 | res = call_llm(
153 | DEFAULT_STRING_CATEGORY_FIELD_PROMPT, **kwargs
154 | ).strip().lower()
155 | if res == 'enum':
156 | return enum_res
157 | elif res == 'text':
158 | return text_res
159 | else:
160 | return code_res
161 | elif field_type_cate == type_engine.field_type_number_label:
162 |                 # Non-date number: decide among code, measure, and enum
163 | res = call_llm(DEFAULT_NUMBER_CATEGORY_FIELD_PROMPT, **kwargs).strip().lower()
164 | if res == 'enum':
165 | return enum_res
166 | elif res == 'measure':
167 | return measure_res
168 | else:
169 | return code_res
170 | else:
171 | res = call_llm(DEFAULT_UNKNOWN_FIELD_PROMPT, **kwargs).strip().lower()
172 | if res == 'enum':
173 | return enum_res
174 | elif res == 'measure':
175 | return measure_res
176 | elif res == 'text':
177 | return text_res
178 | else:
179 | return code_res
180 |
181 |
182 | def dummy_sql_generator(dialect: str, db_mschema: str, question: str, evidence: str = '',
183 |                         llm: Optional[LLM] = None) -> Optional[str]:
184 | """
185 | SQL Generator
186 | """
187 | kwargs = {"dialect": dialect, "db_mschema": db_mschema,
188 | "question": question, "evidence": evidence}
189 | llm_response = call_llm(DEFAULT_SQL_GEN_PROMPT, llm, **kwargs)
190 | sql = extract_sql_from_llm_response(llm_response)
191 | return sql
192 |
--------------------------------------------------------------------------------
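`field_category` returns deterministically for DateTime- and Bool-typed columns and only consults the LLM for the ambiguous cases; the two LLM-free paths can be exercised directly:

```python
from components import field_category
from type_engine import TypeEngine

te = TypeEngine()
# DateTime-typed columns are always DateTime dimensions.
assert field_category(te.field_type_date_label, te) == \
    {"category": "DateTime", "dim_or_meas": "Dimension"}
# Bool-typed columns are treated as Enum dimensions.
assert field_category(te.field_type_bool_label, te) == \
    {"category": "Enum", "dim_or_meas": "Dimension"}
```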
/v6-duckdb/parallel_llm_processor.py:
--------------------------------------------------------------------------------
1 | # parallel_llm_processor.py
2 | """并行LLM处理模块
3 |
4 | 该模块支持并行调用大语言模型API,提高批量处理效率。
5 | 它处理请求限速、错误重试,并提供结果缓存。
6 | """
7 |
8 | import time
9 | import hashlib
10 | import json
11 | import os
12 | from typing import List, Dict, Any, Optional, Callable, Union, Tuple
13 | from concurrent.futures import ThreadPoolExecutor, as_completed
14 | from llama_index.core.llms import LLM
15 | from call_llamaindex_llm import call_llm, call_llm_with_json_validation
16 |
17 |
18 | class ParallelLLMProcessor:
19 | """并行LLM处理器,支持批量并行处理LLM请求"""
20 |
21 | def __init__(self, llm: LLM, max_workers: int = 5,
22 | rate_limit: float = 0.5, retry_delay: float = 2.0,
23 | max_retries: int = 3, cache_dir: Optional[str] = "./llm_cache",
24 | use_cache: bool = True):
25 | """初始化并行LLM处理器
26 |
27 | Args:
28 | llm: 大语言模型实例
29 | max_workers: 最大并行工作线程数
30 | rate_limit: 请求间隔时间(秒)
31 | retry_delay: 重试延迟时间(秒)
32 | max_retries: 最大重试次数
33 | cache_dir: 缓存目录路径
34 | use_cache: 是否使用缓存
35 | """
36 | self.llm = llm
37 | self.max_workers = max_workers
38 | self.rate_limit = rate_limit
39 | self.retry_delay = retry_delay
40 | self.max_retries = max_retries
41 | self.cache_dir = cache_dir
42 | self.use_cache = use_cache
43 |
44 |         # Create the cache directory
45 | if self.use_cache and self.cache_dir:
46 | os.makedirs(self.cache_dir, exist_ok=True)
47 |
48 |         # Timestamp used for request rate limiting
49 | self.last_request_time = 0
50 |
51 | def _calculate_cache_key(self, prompt: str, **kwargs) -> str:
52 | """计算缓存键
53 |
54 | Args:
55 | prompt: LLM提示文本
56 | **kwargs: 其他参数
57 |
58 | Returns:
59 | 缓存键
60 | """
61 | # 创建一个包含所有相关内容的字符串
62 | cache_str = f"{prompt}_{json.dumps(kwargs, sort_keys=True)}"
63 |
64 | # 使用MD5生成一个固定长度的哈希
65 | return hashlib.md5(cache_str.encode()).hexdigest()
66 |
67 | def _get_from_cache(self, cache_key: str) -> Optional[str]:
68 | """从缓存中获取结果
69 |
70 | Args:
71 | cache_key: 缓存键
72 |
73 | Returns:
74 | 缓存的结果,如果不存在则返回None
75 | """
76 | if not self.use_cache:
77 | return None
78 |
79 | cache_file = os.path.join(self.cache_dir, f"{cache_key}.txt")
80 | if os.path.exists(cache_file):
81 | try:
82 | with open(cache_file, 'r', encoding='utf-8') as f:
83 | return f.read()
84 | except Exception as e:
85 | print(f"读取缓存时出错: {e}")
86 |
87 | return None
88 |
89 | def _save_to_cache(self, cache_key: str, result: str) -> None:
90 | """保存结果到缓存
91 |
92 | Args:
93 | cache_key: 缓存键
94 | result: 要缓存的结果
95 | """
96 | if not self.use_cache:
97 | return
98 |
99 | cache_file = os.path.join(self.cache_dir, f"{cache_key}.txt")
100 | try:
101 | with open(cache_file, 'w', encoding='utf-8') as f:
102 | f.write(result)
103 | except Exception as e:
104 | print(f"保存缓存时出错: {e}")
105 |
106 | def _rate_limited_call(self, func: Callable, *args, **kwargs) -> Any:
107 | """使用速率限制调用函数
108 |
109 | Args:
110 | func: 要调用的函数
111 | *args: 位置参数
112 | **kwargs: 关键字参数
113 |
114 | Returns:
115 | 函数调用结果
116 | """
117 | # 检查距离上次请求的时间
118 | current_time = time.time()
119 | time_since_last = current_time - self.last_request_time
120 |
121 | # 如果需要,等待以满足速率限制
122 | if time_since_last < self.rate_limit:
123 | time.sleep(self.rate_limit - time_since_last)
124 |
125 | # 更新最后请求时间
126 | self.last_request_time = time.time()
127 |
128 | # 调用函数
129 | return func(*args, **kwargs)
130 |
131 | def _process_single_prompt(self, prompt: str, task_id: Optional[str] = None,
132 | with_json_validation: bool = False,
133 | required_fields: Optional[List[str]] = None,
134 | expected_type: Optional[type] = None,
135 | **kwargs) -> Union[str, Dict[str, Any]]:
136 | """处理单个提示
137 |
138 | Args:
139 | prompt: LLM提示文本
140 | task_id: 任务ID(可选)
141 | with_json_validation: 是否进行JSON验证
142 | required_fields: JSON必需字段列表
143 | expected_type: 预期的JSON类型
144 | **kwargs: 其他LLM调用参数
145 |
146 | Returns:
147 | LLM响应或JSON验证结果
148 | """
149 | # 计算缓存键
150 | cache_key = self._calculate_cache_key(prompt, **kwargs)
151 |
152 | # 尝试从缓存获取
153 | cached_result = self._get_from_cache(cache_key)
154 | if cached_result is not None:
155 | print(f"任务 {task_id or '未命名'}: 从缓存获取结果")
156 | if with_json_validation:
157 | # 尝试解析缓存的JSON结果
158 | try:
159 | return json.loads(cached_result)
160 | except:
161 | # 缓存的结果不是有效JSON,继续处理
162 | pass
163 | else:
164 | return cached_result
165 |
166 | # 执行LLM调用,带重试
167 | for retry in range(self.max_retries):
168 | try:
169 | if with_json_validation:
170 | result = self._rate_limited_call(
171 | call_llm_with_json_validation,
172 | prompt=prompt,
173 | llm=self.llm,
174 | required_fields=required_fields,
175 | expected_type=expected_type,
176 | max_retries=1, # 我们在这里自己处理重试
177 | **kwargs
178 | )
179 |
180 | # 对于JSON验证,我们只缓存成功的结果
181 | if result.get("success", False):
182 | self._save_to_cache(cache_key, json.dumps(result.get("data", {})))
183 |
184 | return result.get("data") if result.get("success", False) else {}
185 |
186 | else:
187 | result = self._rate_limited_call(call_llm, prompt, self.llm, **kwargs)
188 | # 缓存结果
189 | self._save_to_cache(cache_key, result)
190 | return result
191 |
192 | except Exception as e:
193 | print(f"任务 {task_id or '未命名'} 第 {retry + 1}/{self.max_retries} 次重试出错: {e}")
194 | if retry < self.max_retries - 1:
195 | time.sleep(self.retry_delay * (retry + 1)) # 指数退避
196 |
197 | # 所有重试都失败
198 | print(f"任务 {task_id or '未命名'} 在 {self.max_retries} 次尝试后失败")
199 | return "" if not with_json_validation else {}
200 |
201 | def process_batch(self, prompts: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
202 | """并行处理一批提示
203 |
204 | Args:
205 | prompts: 提示配置列表,每项应包含:
206 | - prompt: 提示文本
207 | - task_id: 任务ID(可选)
208 | - with_json_validation: 是否需要JSON验证(可选)
209 | - required_fields: JSON验证所需字段(可选)
210 | - expected_type: 预期JSON类型(可选)
211 | - kwargs: 其他参数(可选)
212 |
213 | Returns:
214 | 结果列表,每项包含:
215 | - task_id: 任务ID
216 | - result: LLM响应
217 | - success: 是否成功
218 | """
219 | results = []
220 |
221 | # 使用ThreadPoolExecutor并行处理
222 | with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
223 | # 提交所有任务
224 | future_to_prompt = {}
225 | for i, prompt_config in enumerate(prompts):
226 | task_id = prompt_config.get("task_id", f"task_{i}")
227 |
228 | # 提取参数
229 | prompt_text = prompt_config.get("prompt", "")
230 | with_json_validation = prompt_config.get("with_json_validation", False)
231 | required_fields = prompt_config.get("required_fields")
232 | expected_type = prompt_config.get("expected_type")
233 | kwargs = prompt_config.get("kwargs", {})
234 |
235 | # 提交任务
236 | future = executor.submit(
237 | self._process_single_prompt,
238 | prompt=prompt_text,
239 | task_id=task_id,
240 | with_json_validation=with_json_validation,
241 | required_fields=required_fields,
242 | expected_type=expected_type,
243 | **kwargs
244 | )
245 |
246 | future_to_prompt[future] = {
247 | "task_id": task_id,
248 | "config": prompt_config
249 | }
250 |
251 | # 收集结果
252 | for future in as_completed(future_to_prompt):
253 | prompt_info = future_to_prompt[future]
254 | task_id = prompt_info["task_id"]
255 |
256 | try:
257 | result = future.result()
258 | results.append({
259 | "task_id": task_id,
260 | "result": result,
261 | "success": True
262 | })
263 | print(f"任务 {task_id} 成功完成")
264 | except Exception as e:
265 | results.append({
266 | "task_id": task_id,
267 | "result": None,
268 | "success": False,
269 | "error": str(e)
270 | })
271 | print(f"任务 {task_id} 处理失败: {e}")
272 |
273 | # 按原始顺序排序结果
274 | task_id_to_index = {prompt.get("task_id", f"task_{i}"): i for i, prompt in enumerate(prompts)}
275 | results.sort(key=lambda x: task_id_to_index.get(x["task_id"], 0))
276 |
277 | return results
278 |
279 |
280 | # Usage example:
281 | """
282 | # Initialize the processor
283 | processor = ParallelLLMProcessor(llm=your_llm_instance, max_workers=5)
284 |
285 | # Prepare a batch of prompts
286 | prompts = [
287 |     {
288 |         "task_id": "table_analysis_1",
289 |         "prompt": "Analyze the table structure...",
290 |         "with_json_validation": True,
291 |         "expected_type": dict,
292 |         "required_fields": ["table_type", "fields"]
293 |     },
294 |     {
295 |         "task_id": "table_analysis_2",
296 |         "prompt": "Analyze another table structure...",
297 |         "with_json_validation": True,
298 |         "expected_type": dict,
299 |         "required_fields": ["table_type", "fields"]
300 |     },
301 |     # More prompts...
302 | ]
303 |
304 | # Process in parallel
305 | results = processor.process_batch(prompts)
306 |
307 | # Use the results
308 | for result in results:
309 |     if result["success"]:
310 |         # Handle a successful result
311 |         print(f"Task {result['task_id']} result: {result['result']}")
312 |     else:
313 |         # Handle a failure
314 |         print(f"Task {result['task_id']} failed: {result.get('error')}")
315 | """
--------------------------------------------------------------------------------
/mschema.py:
--------------------------------------------------------------------------------
1 | from utils import examples_to_str, read_json, write_json
2 | from type_engine import TypeEngine
3 | from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
4 |
5 |
6 | class MSchema:
7 | def __init__(self, db_id: str = 'Anonymous', type_engine: Optional[TypeEngine] = None,
8 | schema: Optional[str] = None):
9 | self.db_id = db_id
10 | self.schema = schema
11 | self.tables = {}
12 | self.foreign_keys = []
13 | self.type_engine = type_engine
14 |
15 | def add_table(self, name, fields={}, comment=None):
16 | self.tables[name] = {"fields": fields.copy(), 'examples': [], 'comment': comment}
17 |
18 | def add_field(self, table_name: str, field_name: str, field_type: str = "",
19 | primary_key: bool = False, nullable: bool = True, default: Any = None,
20 | autoincrement: bool = False, unique: bool = False,
21 | comment: str = "",
22 | examples: list = [], category: str = '', dim_or_meas: Optional[str] = '', **kwargs):
23 | self.tables[table_name]["fields"][field_name] = {
24 | "type": field_type,
25 | "primary_key": primary_key,
26 | "nullable": nullable,
27 | "default": default if default is None else f'{default}',
28 | "autoincrement": autoincrement,
29 | "unique": unique,
30 | "comment": comment,
31 | "examples": examples.copy(),
32 | "category": category,
33 | "dim_or_meas": dim_or_meas,
34 | **kwargs}
35 |
36 | def add_foreign_key(self, table_name, field_name, ref_schema, ref_table_name, ref_field_name):
37 | self.foreign_keys.append([table_name, field_name, ref_schema, ref_table_name, ref_field_name])
38 |
39 | def get_abbr_field_type(self, field_type, simple_mode=True)->str:
40 | if not simple_mode:
41 | return field_type
42 | else:
43 | return field_type.split("(")[0]
44 |
45 | def erase_all_table_comment(self):
46 | """clear all table descriptions."""
47 | for table_name in self.tables.keys():
48 | self.tables[table_name]['comment'] = ''
49 |
50 | def erase_all_column_comment(self):
51 | """clear all column descriptions."""
52 | for table_name in self.tables.keys():
53 | fields = self.tables[table_name]['fields']
54 | for field_name, field_info in fields.items():
55 | self.tables[table_name]['fields'][field_name]['comment'] = ''
56 |
57 | def has_table(self, table_name: str) -> bool:
58 | """check if given table_name exists in M-Schema"""
59 | if table_name in self.tables.keys():
60 | return True
61 | else:
62 | return False
63 |
64 | def has_column(self, table_name: str, field_name: str) -> bool:
65 | if self.has_table(table_name):
66 | if field_name in self.tables[table_name]["fields"].keys():
67 | return True
68 | else:
69 | return False
70 | else:
71 | return False
72 |
73 | def set_table_property(self, table_name: str, key: str, value: Any):
74 | if not self.has_table(table_name):
75 | print("The table name {} does not exist in M-Schema.".format(table_name))
76 | else:
77 | self.tables[table_name][key] = value
78 |
79 | def set_column_property(self, table_name: str, field_name: str, key: str, value: Any):
80 | if not self.has_column(table_name, field_name):
81 | print("The table name {} or column name {} does not exist in M-Schema.".format(table_name, field_name))
82 | else:
83 | self.tables[table_name]['fields'][field_name][key] = value
84 |
85 | def get_field_info(self, table_name: str, field_name: str) -> Dict:
86 | try:
87 | return self.tables[table_name]['fields'][field_name]
88 |         except KeyError:
89 | return {}
90 |
91 | def get_category_fields(self, category: str, table_name: str) -> List:
92 | """
93 |         Given table_name and category, return the names of all fields of that category in the table.
94 |         category: takes a value from type_engine.field_category_all_labels
95 | """
96 | assert category in self.type_engine.field_category_all_labels, \
97 | 'Invalid category {}'.format(category)
98 | if self.has_table(table_name):
99 | res = []
100 | fields = self.tables[table_name]['fields']
101 | for field_name, field_info in fields.items():
102 | _ = field_info.get('category', '')
103 | if _ == category:
104 | res.append(field_name)
105 | return res
106 | else:
107 | return []
108 |
109 | def get_dim_or_meas_fields(self, dim_or_meas: str, table_name: str) -> List:
110 | assert dim_or_meas in self.type_engine.dim_measure_labels, 'Invalid dim_or_meas {}'.format(dim_or_meas)
111 | if self.has_table(table_name):
112 | res = []
113 | fields = self.tables[table_name]['fields']
114 | for field_name, field_info in fields.items():
115 | _ = field_info.get('dim_or_meas', '')
116 | if _ == dim_or_meas:
117 | res.append(field_name)
118 | return res
119 | else:
120 | return []
121 |
122 | def single_table_mschema(self, table_name: str, selected_columns: Optional[List] = None, example_num=3, show_type_detail=False) -> str:
123 | table_info = self.tables.get(table_name, {})
124 | output = []
125 | table_comment = table_info.get('comment', '')
126 | if table_comment is not None and table_comment != 'None' and len(table_comment) > 0:
127 | if self.schema is not None and len(self.schema) > 0:
128 | output.append(f"# Table: {self.schema}.{table_name}, {table_comment}")
129 | else:
130 | output.append(f"# Table: {table_name}, {table_comment}")
131 | else:
132 | if self.schema is not None and len(self.schema) > 0:
133 | output.append(f"# Table: {self.schema}.{table_name}")
134 | else:
135 | output.append(f"# Table: {table_name}")
136 |
137 | field_lines = []
138 |         # Process each field in the table
139 | for field_name, field_info in table_info['fields'].items():
140 | if selected_columns is not None and field_name.lower() not in selected_columns:
141 | continue
142 |
143 | raw_type = self.get_abbr_field_type(field_info['type'], not show_type_detail)
144 | field_line = f"({field_name}:{raw_type.upper()}"
145 | if len(field_info['comment']) > 0:
146 | field_line += f", {field_info['comment'].strip()}"
147 |
148 |             ## Mark primary keys
149 | is_primary_key = field_info.get('primary_key', False)
150 | if is_primary_key:
151 | field_line += f", Primary Key"
152 |
153 |             # Append value examples if present
154 | if len(field_info.get('examples', [])) > 0 and example_num > 0:
155 | examples = field_info['examples']
156 | examples = [s for s in examples if s is not None]
157 | examples = examples_to_str(examples)
158 | if len(examples) > example_num:
159 | examples = examples[:example_num]
160 |
161 |                 if raw_type.upper() in ['DATE', 'TIME', 'DATETIME', 'TIMESTAMP']:
162 |                     examples = examples[:1]  # one example is enough for dates
163 |                 elif len(examples) > 0 and max([len(s) for s in examples]) > 20:
164 |                     if max([len(s) for s in examples]) > 50:
165 |                         examples = []
166 |                     else:
167 |                         examples = examples[:1]
168 | else:
169 | pass
170 | if len(examples) > 0:
171 | example_str = ', '.join([str(example) for example in examples])
172 | field_line += f", Examples: [{example_str}]"
173 | else:
174 | pass
175 | else:
176 | field_line += ""
177 | field_line += ")"
178 |
179 | field_lines.append(field_line)
180 | output.append('[')
181 | output.append(',\n'.join(field_lines))
182 | output.append(']')
183 |
184 | return '\n'.join(output)
185 |
186 | def to_mschema(self, selected_tables: List = None, selected_columns: List = None,
187 | example_num=3, show_type_detail=False) -> str:
188 | """
189 |         Convert to an M-Schema string.
190 |         selected_tables: default None, meaning all tables are selected
191 |         selected_columns: default None, meaning all columns are selected; format: ['table_name.column_name']
192 | """
193 | output = []
194 |
195 | output.append(f"【DB_ID】 {self.db_id}")
196 | output.append(f"【Schema】")
197 |
198 | if selected_tables is not None:
199 | selected_tables = [s.lower() for s in selected_tables]
200 | if selected_columns is not None:
201 | selected_columns = [s.lower() for s in selected_columns]
202 | selected_tables = [s.split('.')[0].lower() for s in selected_columns]
203 |
204 |         # Process each table in turn
205 | for table_name, table_info in self.tables.items():
206 | if selected_tables is None or table_name.lower() in selected_tables:
207 | column_names = list(table_info['fields'].keys())
208 | if selected_columns is not None:
209 | cur_selected_columns = [c for c in column_names if f"{table_name}.{c}".lower() in selected_columns]
210 | else:
211 | cur_selected_columns = selected_columns
212 | output.append(self.single_table_mschema(table_name, cur_selected_columns, example_num, show_type_detail))
213 |
214 |         # Append foreign key info; foreign keys are not shown when the selected table_type is a view
215 | if self.foreign_keys:
216 | output.append("【Foreign keys】")
217 | for fk in self.foreign_keys:
218 | ref_schema = fk[2]
219 | table1, column1, _, table2, column2 = fk
220 | if selected_tables is None or \
221 | (table1.lower() in selected_tables and table2.lower() in selected_tables):
222 | if ref_schema == self.schema:
223 | output.append(f"{fk[0]}.{fk[1]}={fk[3]}.{fk[4]}")
224 |
225 | return '\n'.join(output)
226 |
227 | def dump(self):
228 | schema_dict = {
229 | "db_id": self.db_id,
230 | "schema": self.schema,
231 | "tables": self.tables,
232 | "foreign_keys": self.foreign_keys
233 | }
234 | return schema_dict
235 |
236 | def save(self, file_path: str):
237 | schema_dict = self.dump()
238 | write_json(file_path, schema_dict)
239 |
240 | def load(self, file_path: str):
241 | data = read_json(file_path)
242 | self.db_id = data.get("db_id", "Anonymous")
243 | self.schema = data.get("schema", None)
244 | self.tables = data.get("tables", {})
245 | self.foreign_keys = data.get("foreign_keys", [])
246 |
--------------------------------------------------------------------------------
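A small, self-contained sketch of building an M-Schema by hand with the APIs above; the table, fields, and values are illustrative:

```python
from mschema import MSchema

ms = MSchema(db_id='book_1')
ms.add_table('book', comment='Books for sale')
ms.add_field('book', 'id', field_type='INTEGER', primary_key=True,
             examples=[1, 2, 3], category='Code', dim_or_meas='Dimension')
ms.add_field('book', 'price', field_type='REAL', examples=[9.99, 12.5],
             category='Measure', dim_or_meas='Measure')
print(ms.to_mschema())
# 【DB_ID】 book_1
# 【Schema】
# # Table: book, Books for sale
# [
# (id:INTEGER, Primary Key, Examples: [1, 2, 3]),
# (price:REAL, Examples: [9.99, 12.5])
# ]
```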
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/schema_engine.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
2 | from concurrent.futures import ThreadPoolExecutor, TimeoutError
3 | from sqlalchemy import create_engine, MetaData, Table, Column, String, Integer, select, text
4 | from sqlalchemy.engine import Engine
5 | from llama_index.core import SQLDatabase
6 | from llama_index.core.llms import LLM
7 | from components import (
8 | field_category,
9 | generate_column_desc,
10 | generate_table_desc,
11 | understand_fields_by_category,
12 | understand_database,
13 | understand_date_time_min_gran,
14 | dummy_sql_generator
15 | )
16 | from utils import examples_to_str
17 | from type_engine import TypeEngine
18 | from mschema import MSchema
19 |
20 |
21 | class SchemaEngine(SQLDatabase):
22 | def __init__(self, engine: Engine, schema: Optional[str] = None, metadata: Optional[MetaData] = None,
23 | ignore_tables: Optional[List[str]] = None, include_tables: Optional[List[str]] = None,
24 | sample_rows_in_table_info: int = 3, indexes_in_table_info: bool = False,
25 | custom_table_info: Optional[dict] = None, view_support: bool = False, max_string_length: int = 300,
26 | mschema: Optional[MSchema] = None, llm: Optional[LLM] = None,
27 | db_name: Optional[str] = '', comment_mode: str = 'origin'):
28 | super().__init__(engine, schema, metadata, ignore_tables, include_tables, sample_rows_in_table_info,
29 | indexes_in_table_info, custom_table_info, view_support, max_string_length)
30 |
31 | self._db_name = db_name
32 | self._usable_tables = [table_name for table_name in self._usable_tables if self._inspector.has_table(table_name, schema)]
33 | self._dialect = engine.dialect.name
34 | self._type_engine = TypeEngine()
35 | assert self._dialect in self._type_engine.supported_dialects, "Unsupported dialect {}.".format(self._dialect)
36 |
37 | self._llm = llm
38 |
39 | if mschema is not None:
40 | self._mschema = mschema
41 | else:
42 | self._mschema = MSchema(db_id=db_name, schema=schema, type_engine=self._type_engine)
43 | self.init_mschema()
44 |
45 | self.comment_mode = comment_mode
46 |
47 | @property
48 | def mschema(self) -> MSchema:
49 | """Return M-Schema"""
50 | return self._mschema
51 |
52 | @property
53 | def type_engine(self) -> TypeEngine:
54 | return self._type_engine
55 |
56 | def get_pk_constraint(self, table_name: str) -> Dict:
57 | return self._inspector.get_pk_constraint(table_name, self._schema)['constrained_columns']
58 |
59 | def get_table_comment(self, table_name: str):
60 | try:
61 | return self._inspector.get_table_comment(table_name, self._schema)['text']
62 |         except Exception:  # sqlite does not support table comments
63 | return ''
64 |
65 | def default_schema_name(self) -> Optional[str]:
66 | return self._inspector.default_schema_name
67 |
68 | def get_schema_names(self) -> List[str]:
69 | return self._inspector.get_schema_names()
70 |
71 | def get_table_options(self, table_name: str) -> Dict[str, Any]:
72 | return self._inspector.get_table_options(table_name, self._schema)
73 |
74 | def get_foreign_keys(self, table_name: str):
75 | return self._inspector.get_foreign_keys(table_name, self._schema)
76 |
77 | def get_unique_constraints(self, table_name: str):
78 |         # unique keys
79 | return self._inspector.get_unique_constraints(table_name, self._schema)
80 |
81 | def get_indexes(self, table_name: str):
82 |         # indexed fields
83 | return self._inspector.get_indexes(table_name, self._schema)
84 |
85 | def add_semicolon_to_sql(self, sql_query: str):
86 | if not sql_query.strip().endswith(';'):
87 | sql_query += ';'
88 | return sql_query
89 |
90 | def fetch(self, sql_query: str):
91 | sql_query = self.add_semicolon_to_sql(sql_query)
92 |
93 | with self._engine.begin() as connection:
94 | try:
95 | cursor = connection.execute(text(sql_query))
96 | records = cursor.fetchall()
97 | except Exception as e:
98 | print("An exception occurred during SQL execution.\n", e)
99 | records = None
100 | return records
101 |
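    # fetch_truncated: like fetch(), but optionally caps the row count and
    # truncates long string values, returning the column names alongside the rows.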
102 | def fetch_truncated(self, sql_query: str, max_rows: Optional[int] = None, max_str_len: int = 30) -> Dict:
103 | sql_query = self.add_semicolon_to_sql(sql_query)
104 | with self._engine.begin() as connection:
105 | try:
106 | cursor = connection.execute(text(sql_query))
107 | result = cursor.fetchall()
108 | truncated_results = []
109 | if max_rows:
110 | result = result[:max_rows]
111 | for row in result:
112 | truncated_row = tuple(
113 | self.truncate_word(column, length=max_str_len)
114 | for column in row
115 | )
116 | truncated_results.append(truncated_row)
117 | return {"truncated_results": truncated_results, "fields": list(cursor.keys())}
118 | except Exception as e:
119 | print("An exception occurred during SQL execution.\n", e)
120 | records = None
121 | return {"truncated_results": records, "fields": []}
122 |
123 | def trunc_result_to_markdown(self, sql_res: Dict) -> str:
124 | """
125 |         Convert database query results into a markdown table.
126 | """
127 | truncated_results = sql_res.get("truncated_results", [])
128 | fields = sql_res.get("fields", [])
129 |
130 | if not truncated_results:
131 | return ""
132 |
133 | header = "| " + " | ".join(fields) + " |"
134 | separator = "| " + " | ".join(["---"] * len(fields)) + " |"
135 | rows = []
136 | for row in truncated_results:
137 | rows.append("| " + " | ".join(str(value) for value in row) + " |")
138 | markdown_table = "\n".join([header, separator] + rows)
139 | return markdown_table
140 |
141 | def execute(self, sql_query: str, timeout=10) -> Any:
142 | sql_query = self.add_semicolon_to_sql(sql_query)
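        # Run the statement in a worker thread so a wall-clock timeout can be
        # enforced; note that a timeout only abandons the waiting future -- the
        # statement itself may still be running on the database server.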
143 | def run_query():
144 | with self._engine.begin() as connection:
145 | cursor = connection.execute(text(sql_query))
146 | return True
147 | with ThreadPoolExecutor(max_workers=1) as executor:
148 | future = executor.submit(run_query)
149 | try:
150 | result = future.result(timeout=timeout)
151 | return result
152 | except TimeoutError:
153 |             print(f"SQL execution timed out ({timeout}s): {sql_query}")
154 | return None
155 | except Exception as e:
156 |             print("An exception occurred during SQL execution.\n", e)
157 | return None
158 |
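    # Identifier quoting per dialect (backticks for MySQL/SQLite, double quotes
    # for PostgreSQL, brackets for SQL Server), so reserved words or special
    # characters in table and field names do not break the generated SQL.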
159 | def get_protected_table_name(self, table_name: str) -> str:
160 | if self._dialect == self._type_engine.mysql_dialect or self._dialect == self._type_engine.sqlite_dialect:
161 | return f'`{table_name}`'
162 | elif self._dialect == self._type_engine.postgres_dialect:
163 | if self._schema:
164 | return f'"{self._schema}"."{table_name}"'
165 | else:
166 | return f'"{table_name}"'
167 | elif self._dialect == self._type_engine.sqlserver_dialect:
168 | return f'[{table_name}]'
169 | else:
170 | raise NotImplementedError
171 |
172 | def get_protected_field_name(self, field_name: str) -> str:
173 | if self._dialect == self._type_engine.mysql_dialect or self._dialect == self._type_engine.sqlite_dialect:
174 | return f'`{field_name}`'
175 | elif self._dialect == self._type_engine.postgres_dialect:
176 | return f'"{field_name}"'
177 | else:
178 | raise NotImplementedError
179 |
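    # init_mschema: populate the M-Schema from the live database via the
    # SQLAlchemy inspector -- table comments, primary/unique keys, indexes,
    # foreign keys, and per-column metadata with a few sampled example values.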
180 | def init_mschema(self):
181 | for table_name in self._usable_tables:
182 | table_comment = self.get_table_comment(table_name)
183 | table_comment = '' if table_comment is None else table_comment.strip()
184 | self._mschema.add_table(table_name, fields={}, comment=table_comment)
185 | pks = self.get_pk_constraint(table_name)
186 |
187 |         # unique keys of this table
188 | unique_keys = []
189 | unique_constraints = self.get_unique_constraints(table_name)
190 | for u_con in unique_constraints:
191 | column_names = u_con['column_names']
192 | unique_keys.append(column_names)
193 | self._mschema.tables[table_name]['unique_keys'] = unique_keys
194 |
195 |         # table indexes
196 | indexes = self.get_indexes(table_name)
197 | keys = []
198 | for index in indexes:
199 | is_unique = index.get("unique", False)
200 | keys.append(index['column_names'])
201 | self._mschema.tables[table_name]['keys'] = keys
202 |
203 | fks = self.get_foreign_keys(table_name)
204 | constrained_columns = []
205 | for fk in fks:
206 | referred_schema = fk['referred_schema']
207 | for c, r in zip(fk['constrained_columns'], fk['referred_columns']):
208 | self._mschema.add_foreign_key(table_name, c, referred_schema, fk['referred_table'], r)
209 | constrained_columns.append(c)
210 |
211 | fields = self._inspector.get_columns(table_name, schema=self._schema)
212 | for field in fields:
213 | field_type = f"{field['type']!s}"
214 | field_name = field['name']
215 | if field_name in pks:
216 | primary_key = True
217 | if len(pks) == 1:
218 | is_unique = True
219 | else:
220 | is_unique = False
221 | else:
222 | primary_key = False
223 | if [field_name] in unique_keys:
224 | is_unique = True
225 | else:
226 | is_unique = False
227 | field_comment = field.get("comment", None)
228 | field_comment = "" if field_comment is None else field_comment.strip()
229 | autoincrement = field.get('autoincrement', False)
230 | default = field.get('default', None)
231 | if default is not None:
232 | default = f'{default}'
233 |
234 | examples = []
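            # Sample up to 5 distinct non-null values as examples; LIMIT is valid
            # for the sqlite/mysql/postgres dialects (SQL Server would need TOP),
            # and any failure simply leaves the example list empty.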
235 | try:
236 | sql = f"select distinct {self.get_protected_field_name(field_name)} from {self.get_protected_table_name(table_name)} where {self.get_protected_field_name(field_name)} is not null limit 5;"
237 | examples = [s[0] for s in self.fetch(sql)]
238 |             except Exception:
239 | pass
240 | examples = examples_to_str(examples)
241 | if None in examples:
242 | examples.remove(None)
243 | if '' in examples:
244 | examples.remove('')
245 |
246 | self._mschema.add_field(table_name, field_name, field_type=field_type, primary_key=primary_key,
247 | nullable=field['nullable'], default=default, autoincrement=autoincrement, unique=is_unique,
248 | comment=field_comment, examples=examples)
249 |
250 | def get_column_count(self, table_name: str, field_name: str) -> int:
251 | sql = 'select count({}) from {};'.format(self.get_protected_field_name(field_name),
252 | self.get_protected_table_name(table_name))
253 | r = self.fetch(sql)
254 | if r is not None:
255 | total_num = r[0][0]
256 | else:
257 | total_num = -1
258 | return total_num
259 |
260 | def get_column_unique_count(self, table_name: str, field_name: str) -> int:
261 | sql = 'select count(distinct {}) from {};'.format(self.get_protected_field_name(field_name),
262 | self.get_protected_table_name(table_name))
263 | r = self.fetch(sql)
264 | if r is not None:
265 | unique_num = r[0][0]
266 | else:
267 | unique_num = -1
268 | return unique_num
269 |
270 | def get_column_value_examples(self, table_name: str, field_name: str, max_rows: Optional[int] = None, max_str_len: int = 30)-> List:
271 | sql = 'select distinct {} from {} where {} is not null;'.format(self.get_protected_field_name(field_name),
272 | self.get_protected_table_name(table_name), self.get_protected_field_name(field_name))
273 | res = self.fetch_truncated(sql, max_rows, max_str_len)
274 | res = res['truncated_results']
275 | if res is not None:
276 | return [r[0] for r in res]
277 | else:
278 | return []
279 |
280 | def check_column_value_exist(self, table_name: str, field_name: str, value_name: str, is_string: bool) -> bool:
281 | if is_string:
282 | sql = '''select count(*) from {} where {} = '{}';'''.format(
283 | self.get_protected_table_name(table_name), self.get_protected_field_name(field_name), value_name.replace("'", "''"))
284 | else:
285 | sql = "select count(*) from {} where {} = {};".format(
286 | self.get_protected_table_name(table_name), self.get_protected_field_name(field_name), value_name)
287 | r = self.fetch(sql)
288 | if r is not None:
289 | return r[0][0] > 0
290 | else:
291 | return False
292 |
293 | def check_agg_func(self, agg_func: str):
294 | assert agg_func.upper() in ['MAX', 'MIN', 'AVG', 'SUM'], \
295 | "Invalid aggregate function {}.".format(agg_func)
296 |
297 | def get_column_agg_value(self, table_name: str, field_name: str, field_type: str, agg_func: str):
298 | self.check_agg_func(agg_func)
299 | if self._type_engine.field_type_cate(field_type) != self._type_engine.field_type_number_label:
300 | return None
301 |
302 | sql = 'select {}({}) from {} where {} is not null;'.format(agg_func, self.get_protected_field_name(field_name),
303 | self.get_protected_table_name(table_name), self.get_protected_field_name(field_name))
304 | r = self.fetch(sql)
305 | if r is not None:
306 | return r[0][0]
307 | else:
308 | return None
309 |
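    # Character-length aggregate of a column: cast to text where the dialect
    # requires it, then apply the dialect's length function (length/LEN/char_length).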
310 | def get_column_agg_char_length(self, table_name: str, field_name: str, agg_func: str) -> int:
311 | if self._dialect == self._type_engine.postgres_dialect:
312 | snip = '{}::TEXT'.format(self.get_protected_field_name(field_name))
313 | elif self._dialect == self._type_engine.mysql_dialect:
314 | snip = '{}'.format(self.get_protected_field_name(field_name))
315 | elif self._dialect == self._type_engine.sqlite_dialect:
316 | snip = "CAST({} AS TEXT)".format(self.get_protected_field_name(field_name))
317 | elif self._dialect == self._type_engine.sqlserver_dialect:
318 | snip = 'CAST({} AS NVARCHAR(MAX))'.format(self.get_protected_field_name(field_name))
319 | else:
320 | raise NotImplementedError
321 |
322 | self.check_agg_func(agg_func)
323 | if self._dialect == self._type_engine.sqlite_dialect:
324 | sql = 'select {}(length({})) from {} where {} is not null;'.format(agg_func, snip,
325 | self.get_protected_table_name(table_name),self.get_protected_field_name(field_name))
326 | elif self._dialect == self._type_engine.sqlserver_dialect:
327 | sql = 'select {}(LEN({})) from {} where {} is not null;'.format(agg_func, snip,
328 | self.get_protected_table_name(table_name), self.get_protected_field_name(field_name))
329 | else:
330 | sql = 'select {}(char_length({})) from {} where {} is not null;'.format(agg_func, snip,
331 | self.get_protected_table_name(table_name), self.get_protected_field_name(field_name))
332 | r = self.fetch(sql)
333 | if r is not None and r[0][0] is not None:
334 | return r[0][0]
335 | else:
336 | return -1
337 |
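    # Despite its name, this method only builds the row-sampling SQL string;
    # callers execute it themselves (see table_and_column_desc_generation).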
338 | def get_all_field_examples(self, table_name: str, max_rows: Optional[int] = None):
339 |         sql = f"""SELECT DISTINCT * FROM {self.get_protected_table_name(table_name)}"""
340 | if max_rows is not None and max_rows > 0:
341 | sql += ' LIMIT {};'.format(max_rows)
342 | return sql
343 |
344 | def get_single_field_info_str(self, table_name: str, field_name: str)->str:
345 | """
346 |         Column profile: name, type, description, primary-key status, max/min values, etc.
347 | """
348 | field_info = self._mschema.get_field_info(table_name, field_name)
349 | field_type = field_info.get('type', '')
350 |
351 | unique_num = self.get_column_unique_count(table_name, field_name)
352 | total_num = self.get_column_count(table_name, field_name)
353 | max_value = self.get_column_agg_value(table_name, field_name, field_type, 'max')
354 | min_value = self.get_column_agg_value(table_name, field_name, field_type, 'min')
355 | avg_value = self.get_column_agg_value(table_name, field_name, field_type, 'avg')
356 | max_len = self.get_column_agg_char_length(table_name, field_name, 'max')
357 | min_len = self.get_column_agg_char_length(table_name, field_name, 'min')
358 |
359 | comment = field_info.get('comment', '')
360 | primary_key = field_info.get("primary_key", False)
361 | nullable = field_info.get("nullable", True)
362 |
363 | field_info_str = ['【字段信息】', f'字段名称: {field_name}', f'字段类型: {field_type}']
364 | dim_or_meas = field_info.get('dim_or_meas', '')
365 | unique = field_info.get('unique', False)
366 | if primary_key:
367 | unique = True
368 |
369 | if len(comment) > 0:
370 | field_info_str.append(f'字段描述: {comment}')
371 | field_info_str.append(f'是否为主键(或者与其他字段组成联合主键): {primary_key}')
372 | field_info_str.append(f'UNIQUE: {unique}')
373 | field_info_str.append(f'NULLABLE: {nullable}')
374 | date_min_gran = field_info.get('date_min_gran', None)
375 | if total_num >= 0:
376 | field_info_str.append(f'COUNT: {total_num}')
377 | if unique_num >= 0:
378 | field_info_str.append(f'COUNT(DISTINCT): {unique_num}')
379 | if max_value is not None:
380 | field_info_str.append(f'MAX: {max_value}')
381 | if min_value is not None:
382 | field_info_str.append(f'MIN: {min_value}')
383 | if avg_value is not None:
384 | field_info_str.append(f'AVG: {avg_value}')
385 | if max_len >= 0:
386 |             field_info_str.append(f'MAX(CHAR_LENGTH): {max_len}')
387 | if min_len >= 0:
388 | field_info_str.append(f'MIN(CHAR_LENGTH): {min_len}')
389 | if dim_or_meas in self._type_engine.dim_measure_labels:
390 | field_info_str.append(f'Dimension/Measure: {dim_or_meas}')
391 | if date_min_gran is not None:
392 | field_info_str.append(f'该字段表示的语义可能与日期或时间有关,推测它表示的最小时间颗粒度是: {date_min_gran}')
393 |
394 | value_examples = self.get_column_value_examples(table_name, field_name, max_rows=10, max_str_len=30)
395 | if len(value_examples) > 0:
396 | field_info_str.append(f"Value Examples: {value_examples}")
397 |
398 | return '\n'.join(field_info_str)
399 |
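    # fields_category: have the LLM classify each field from its profile string
    # (e.g. enum or date, dimension vs. measure) and write the results back into
    # the M-Schema as column properties.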
400 | def fields_category(self):
401 | tables = self._mschema.tables
402 | for table_name in tables.keys():
403 | print("Table Name: ", table_name)
404 | fields = tables[table_name]['fields']
405 | for field_name, field_info in fields.items():
406 | print("Field Name: ", field_name)
407 | field_type = field_info['type']
408 | field_type_cate = self._type_engine.field_type_cate(field_type)
409 | field_info_str = self.get_single_field_info_str(table_name, field_name)
410 | res = field_category(field_type_cate, self._type_engine, self._llm, field_info_str=field_info_str)
411 | print(field_info_str)
412 | print(res)
413 | if res['category'] == self._type_engine.field_category_date_label:
414 | min_gran = understand_date_time_min_gran(field_info_str, llm=self._llm)
415 |                     print("Minimum time granularity:", min_gran)
416 | if min_gran in self._type_engine.date_time_min_grans:
417 | self._mschema.set_column_property(table_name, field_name, "date_min_gran", min_gran)
418 |
419 | category = res['category']
420 |             # for enum-like fields, collect all of their candidate values
421 | if category == self._type_engine.field_category_enum_label:
422 | examples = self.get_column_value_examples(table_name, field_name)
423 |                 examples = [s for s in examples if len(str(s)) > 0]
424 | self._mschema.set_column_property(table_name, field_name, "examples", examples)
425 | self._mschema.set_column_property(table_name, field_name, "category", res['category'])
426 | self._mschema.set_column_property(table_name, field_name, "dim_or_meas", res['dim_or_meas'])
427 |
428 |
429 | def table_and_column_desc_generation(self, language: str='CN'):
430 |         """
431 |         Table and column description generation.
432 | 
433 |         Four modes:
434 |         no_comment: no description information at all
435 |         origin: keep descriptions consistent with what is stored in the database
436 |         generation: erase existing descriptions and generate all of them with the model
437 |         merge: generate descriptions for fields that lack one; leave existing descriptions untouched
438 | """
439 | if self.comment_mode == 'origin':
440 | return
441 | elif self.comment_mode == 'merge':
442 | pass
443 | elif self.comment_mode == 'generation':
444 | self._mschema.erase_all_column_comment()
445 | self._mschema.erase_all_table_comment()
446 | elif self.comment_mode == 'no_comment':
447 | self._mschema.erase_all_column_comment()
448 | self._mschema.erase_all_table_comment()
449 | return
450 | else:
451 | raise NotImplementedError(f"Unsupported comment mode {self.comment_mode}.")
452 |
453 | db_mschema = self._mschema.to_mschema()
454 |         """Step 1: get an initial understanding of the database and of what each table contains."""
455 | db_info = understand_database(db_mschema, self._llm)
456 | self._mschema.db_info = db_info
457 | print("DB INFO: ", db_info)
458 |
459 | for table_name, table_info in self._mschema.tables.items():
460 | fields = table_info['fields']
461 | table_comment = table_info.get('comment', '')
462 | if len(table_comment) >= 10:
463 | need_table_comment = False
464 | else:
465 | need_table_comment = True
466 |
467 | table_mschema = self._mschema.single_table_mschema(table_name)
468 |
469 | sql = self.get_all_field_examples(table_name, max_rows=10)
470 | res = self.fetch_truncated(sql, max_rows=10)
471 | res = self.trunc_result_to_markdown(res)
472 |
473 |         """Step 2: understand the dimension and measure fields as groups -- how the fields in each group differ and relate -- for reference."""
474 | supp_info = {}
475 | dim_fields = self._mschema.get_dim_or_meas_fields(self._type_engine.dimension_label, table_name)
476 | mea_fields = self._mschema.get_dim_or_meas_fields(self._type_engine.measure_label, table_name)
477 | if len(dim_fields) > 0:
478 | supp_info[self._type_engine.dimension_label] = understand_fields_by_category(db_info, table_name,
479 | table_mschema, self._llm, sql, res, dim_fields, self._type_engine.dimension_label)
480 | if len(mea_fields) > 0:
481 | supp_info[self._type_engine.measure_label] = understand_fields_by_category(db_info, table_name,
482 | table_mschema, self._llm, sql, res, mea_fields, self._type_engine.measure_label)
483 | print("Supplementary information:")
484 | print(supp_info)
485 |
486 |         """Step 3: generate a description for each column."""
487 | for field_name, field_info in fields.items():
488 | field_info_str = self.get_single_field_info_str(table_name, field_name)
489 | dim_or_meas = field_info.get("dim_or_meas", '')
490 | field_desc = field_info.get('comment', '')
491 |             if len(field_desc) == 0:  # no existing description, so generate one
492 | field_desc = generate_column_desc(field_name, field_info_str, table_mschema,
493 | self._llm, sql, res, supp_info.get(dim_or_meas, ""),
494 | language=language)
495 | print("Table Name: {}, Field Name: {}".format(table_name, field_name))
496 | print("Column Description: {}".format(field_desc))
497 | self._mschema.set_column_property(table_name, field_name, 'comment', field_desc)
498 |
499 |         """Step 4: generate the table description."""
500 | table_mschema = self._mschema.single_table_mschema(table_name)
501 | if need_table_comment:
502 | table_desc = generate_table_desc(table_name, table_mschema, self._llm, sql, res, language=language)
503 | print("Table Description: {}".format(table_desc))
504 | self._mschema.set_table_property(table_name, 'comment', table_desc)
505 |
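    # Convenience NL2SQL entry point: render the full M-Schema and ask the LLM
    # to predict a SQL query for the question in the current dialect.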
506 | def sql_generator(self, question: str, evidence: str = '') -> str:
507 | db_mschema = self._mschema.to_mschema()
508 | pred_sql = dummy_sql_generator(self._dialect, db_mschema=db_mschema,
509 | question=question, evidence=evidence, llm=self._llm)
510 |
511 | return pred_sql
512 |
--------------------------------------------------------------------------------