├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── assets ├── BIRD_Heatmap.png ├── Boxplot.png ├── DB_Domain_Boxplot.png ├── DB_Domain_Heatmap.png ├── QVT.png ├── QVT_New.png ├── SQLiteStudio.png ├── Spider_Heatmap.png ├── domain.png ├── leaderboard.png ├── nl2sql360.png └── sql_charac.png ├── examples ├── cli_examples │ ├── dataset_spider.yaml │ ├── delete_history.yaml │ ├── evaluation.yaml │ └── report.yaml └── py_examples │ ├── dataset_import.py │ ├── delete_history.py │ ├── evaluation.py │ └── report.py ├── requirements.txt ├── setup.py └── src └── nl2sql360 ├── __init__.py ├── arguments ├── __init__.py ├── core_args.py ├── dataset_args.py ├── delete_history_args.py ├── evaluation_args.py ├── hf_argparser.py ├── parser.py └── report_args.py ├── cli ├── __init__.py ├── cli.py └── util.py ├── core ├── __init__.py └── core.py ├── database ├── __init__.py ├── model.py ├── template.py └── util.py ├── dataset ├── __init__.py └── dataset.py ├── evaluator ├── __init__.py ├── bird_accuracy.py ├── bird_eval │ ├── __init__.py │ ├── bird_accuracy.py │ └── bird_ves.py ├── spider_accuracy.py ├── test_suite_sql_eval │ ├── __init__.py │ ├── evaluation.py │ ├── exec_eval.py │ ├── parse.py │ └── process_sql.py └── ves.py ├── filter ├── __init__.py └── filter.py └── parser ├── __init__.py └── sql_parser.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/data 2 | **/.vscode 3 | **/__pycache__ 4 | **/dist 5 | **/build 6 | **/*.egg-info 7 | **/tests -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Boyan Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include assets/*
2 | include examples/*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # :mag_right:NL2SQL360
2 | 
3 | **Please visit [NL2SQL Official Repo](https://github.com/HKUSTDial/NL2SQL360) to get the latest code!**
4 | 
5 | ## :pushpin:Citation
6 | 
7 | ```
8 | @misc{li2024dawn,
9 |       title={The Dawn of Natural Language to SQL: Are We Fully Ready?},
10 |       author={Boyan Li and Yuyu Luo and Chengliang Chai and Guoliang Li and Nan Tang},
11 |       year={2024},
12 |       eprint={2406.01265},
13 |       archivePrefix={arXiv},
14 |       primaryClass={cs.DB}
15 | }
16 | ```
17 | 
--------------------------------------------------------------------------------
/assets/BIRD_Heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BugMaker-Boyan/NL2SQL360/59407da3c84a8b64b61d2e44e23d498146230cb5/assets/BIRD_Heatmap.png
--------------------------------------------------------------------------------
/assets/Boxplot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BugMaker-Boyan/NL2SQL360/59407da3c84a8b64b61d2e44e23d498146230cb5/assets/Boxplot.png
--------------------------------------------------------------------------------
/assets/DB_Domain_Boxplot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BugMaker-Boyan/NL2SQL360/59407da3c84a8b64b61d2e44e23d498146230cb5/assets/DB_Domain_Boxplot.png
--------------------------------------------------------------------------------
/assets/DB_Domain_Heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BugMaker-Boyan/NL2SQL360/59407da3c84a8b64b61d2e44e23d498146230cb5/assets/DB_Domain_Heatmap.png
--------------------------------------------------------------------------------
/assets/QVT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BugMaker-Boyan/NL2SQL360/59407da3c84a8b64b61d2e44e23d498146230cb5/assets/QVT.png
--------------------------------------------------------------------------------
/assets/QVT_New.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BugMaker-Boyan/NL2SQL360/59407da3c84a8b64b61d2e44e23d498146230cb5/assets/QVT_New.png
--------------------------------------------------------------------------------
/assets/SQLiteStudio.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BugMaker-Boyan/NL2SQL360/59407da3c84a8b64b61d2e44e23d498146230cb5/assets/SQLiteStudio.png
--------------------------------------------------------------------------------
/assets/Spider_Heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BugMaker-Boyan/NL2SQL360/59407da3c84a8b64b61d2e44e23d498146230cb5/assets/Spider_Heatmap.png -------------------------------------------------------------------------------- /assets/domain.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BugMaker-Boyan/NL2SQL360/59407da3c84a8b64b61d2e44e23d498146230cb5/assets/domain.png -------------------------------------------------------------------------------- /assets/leaderboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BugMaker-Boyan/NL2SQL360/59407da3c84a8b64b61d2e44e23d498146230cb5/assets/leaderboard.png -------------------------------------------------------------------------------- /assets/nl2sql360.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BugMaker-Boyan/NL2SQL360/59407da3c84a8b64b61d2e44e23d498146230cb5/assets/nl2sql360.png -------------------------------------------------------------------------------- /assets/sql_charac.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BugMaker-Boyan/NL2SQL360/59407da3c84a8b64b61d2e44e23d498146230cb5/assets/sql_charac.png -------------------------------------------------------------------------------- /examples/cli_examples/dataset_spider.yaml: -------------------------------------------------------------------------------- 1 | # ---------------- Core Arguments ---------------- 2 | 3 | # The directory to save NL2SQL360 core data 4 | core_dir: "data" 5 | 6 | # The NL2SQL360 core name, such that NL2SQL360 core data is saved to "core_dir/core_name.sqlite" 7 | core_name: "nl2sql360" 8 | 9 | # The dataset SQL dialect, "sqlite" by default 10 | sql_dialect: "sqlite" 11 | 12 | # ---------------- Dataset Arguments ---------------- 13 | 14 | # The unique dataset name to import 15 | dataset_name: "spider_dev" 16 | 17 | # The dataset root directory path 18 | dataset_dir: "data/spider" 19 | 20 | # The name of samples json file 21 | samples_file: "dev.json" 22 | 23 | # The name of tables json file 24 | tables_file: "tables.json" 25 | 26 | # The database directory in "dataset_dir" 27 | database_dir: "database" 28 | 29 | # The key name of NL question in samples json file 30 | question_key: "question" 31 | 32 | # The key name of gold sql in samples json file 33 | sql_key: "query" 34 | 35 | # The key name of database id in samples json file 36 | db_id_key: "db_id" 37 | 38 | # The key name of sql complexity in samples json file, set to null if not specified 39 | sql_complexity_key: null 40 | 41 | # The database domain json file (mapping database id to database domain) in "dataset_dir" 42 | database_domain_file: null 43 | -------------------------------------------------------------------------------- /examples/cli_examples/delete_history.yaml: -------------------------------------------------------------------------------- 1 | # ---------------- Core Arguments ---------------- 2 | 3 | # The directory to save NL2SQL360 core data. 4 | core_dir: "data" 5 | 6 | # The NL2SQL360 core name, such that NL2SQL360 core data is saved to "core_dir/core_name.sqlite". 7 | core_name: "nl2sql360" 8 | 9 | # The dataset SQL dialect, "sqlite" by default. 
10 | sql_dialect: "sqlite"
11 | 
12 | # ---------------- Delete History Arguments ----------------
13 | 
14 | # Notes:
15 | # If `dataset_name` is set and `eval_name` is null, the dataset itself will be deleted.
16 | # If both `dataset_name` and `eval_name` are set, the specified evaluation(s) on the corresponding dataset will be deleted.
17 | 
18 | # The name of an imported dataset to delete; set to null if you do not want to delete dataset history.
19 | dataset_name: "spider_dev"
20 | 
21 | # Whether to also delete all evaluations associated with the dataset being deleted.
22 | delete_dataset_evaluations: True
23 | 
24 | # The imported evaluation(s) to delete; set to null if you do not want to delete evaluation history.
25 | # To delete multiple evaluations, use a list in which each item has the two keys `dataset` and `evaluation`.
26 | eval_name: null
--------------------------------------------------------------------------------
/examples/cli_examples/evaluation.yaml:
--------------------------------------------------------------------------------
1 | # ---------------- Core Arguments ----------------
2 | 
3 | # The directory to save NL2SQL360 core data
4 | core_dir: "data"
5 | 
6 | # The NL2SQL360 core name, such that NL2SQL360 core data is saved to "core_dir/core_name.sqlite"
7 | core_name: "nl2sql360"
8 | 
9 | # The dataset SQL dialect, "sqlite" by default
10 | sql_dialect: "sqlite"
11 | 
12 | # ---------------- Evaluation Arguments ----------------
13 | 
14 | # The unique evaluation name to be saved in NL2SQL360
15 | eval_name: "SuperSQL"
16 | 
17 | # The dataset name which has been imported in NL2SQL360
18 | eval_dataset: "spider_dev"
19 | 
20 | # The evaluation metrics, supporting three different metrics:
21 | #   "ex": "Execution Accuracy"
22 | #   "em": "Exact-Match Accuracy"
23 | #   "ves": "Valid Efficiency Score"
24 | eval_metrics:
25 |   - "ex"
26 |   - "em"
27 |   - "ves"
28 | 
29 | # The model prediction file for the dataset, containing one predicted SQL per line.
30 | pred_sqls_file: "tests/SuperSQL.sql"
31 | 
32 | # Whether to enable the official Spider evaluation script; generally set to True if the dataset is Spider or a Spider variant (e.g., Spider-Syn).
33 | enable_spider_eval: True
--------------------------------------------------------------------------------
/examples/cli_examples/report.yaml:
--------------------------------------------------------------------------------
1 | # ---------------- Core Arguments ----------------
2 | 
3 | # The directory to save NL2SQL360 core data.
4 | core_dir: "data"
5 | 
6 | # The NL2SQL360 core name, such that NL2SQL360 core data is saved to "core_dir/core_name.sqlite".
7 | core_name: "nl2sql360"
8 | 
9 | # The dataset SQL dialect, "sqlite" by default.
10 | sql_dialect: "sqlite"
11 | 
12 | # ---------------- Report Arguments ----------------
13 | 
14 | # The dataset name which has been imported in NL2SQL360.
15 | report_dataset: "spider_dev"
16 | 
17 | # The evaluation name(s) to report; list each name below.
18 | # Set to null if you want to report all history evaluations.
19 | report_evaluation:
20 |   - "SuperSQL"
21 | 
22 | # The metric name(s) to report, including "ex", "em", "ves", "qvt". List each name below:
23 | metric:
24 |   - "ex"
25 |   - "em"
26 | 
27 | # Define subset performance via filter(s). List each filter definition below.
28 | # Filter:
29 | #   "name": The name for the filtered subset to show.
30 | #   "expression": The filter expression in format "{FILTER_KEY} {<, >, =} {NUMBER}".
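31 | #                 For example, "JOIN > 1" keeps only the samples whose gold SQL contains more than one JOIN.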
32 | #                 Valid {FILTER_KEY} is listed in
33 | #                 https://github.com/BugMaker-Boyan/NL2SQL360/blob/fe436d43031e06cd457e44ec98fd25a5acd25c2b/src/nl2sql360/filter/filter.py#L13
34 | filter:
35 |   -
36 |     name: "Subquery"
37 |     expression: "SUBQUERY > 0"
38 |   -
39 |     name: "Join"
40 |     expression: "JOIN > 0"
41 | 
42 | # Define scenario performance. List each scenario definition below.
43 | # Scenario: Combination of multiple filters joined by "&&".
44 | scenario:
45 |   -
46 |     name: "BI"
47 |     expression: "SUBQUERY > 0 && JOIN > 0"
48 | 
49 | # The report save path in CSV format
50 | save_path: "./report.csv"
51 | 
--------------------------------------------------------------------------------
/examples/py_examples/dataset_import.py:
--------------------------------------------------------------------------------
1 | from nl2sql360.core import Core
2 | from nl2sql360.arguments import CoreArguments, DatasetArguments
3 | 
4 | core_args = CoreArguments()
5 | 
6 | core = Core(core_args)
7 | 
8 | dataset_args = DatasetArguments(
9 |     dataset_name="spider_dev",
10 |     dataset_dir="../data/spider",
11 |     samples_file="dev.json",
12 |     database_dir="database",
13 |     tables_file="tables.json",
14 |     question_key="question",
15 |     sql_key="query",
16 |     db_id_key="db_id",
17 |     sql_complexity_key=None,
18 |     database_domain_file=None
19 | )
20 | 
21 | core.import_dataset(dataset_args)
22 | 
23 | 
24 | 
--------------------------------------------------------------------------------
/examples/py_examples/delete_history.py:
--------------------------------------------------------------------------------
1 | from nl2sql360.core import Core
2 | from nl2sql360.arguments import CoreArguments
3 | 
4 | if __name__ == "__main__":
5 |     core_args = CoreArguments()
6 | 
7 |     core = Core(core_args)
8 | 
9 |     core.delete_evaluation_history(
10 |         dataset_name="spider_dev",
11 |         eval_name="C3SQL"
12 |     )
13 | 
14 |     core.delete_dataset_history(
15 |         dataset_name="spider_dev",
16 |         delete_relavant_evaluations=True
17 |     )
18 | 
19 | 
20 | 
21 | 
--------------------------------------------------------------------------------
/examples/py_examples/evaluation.py:
--------------------------------------------------------------------------------
1 | from nl2sql360.core import Core
2 | from nl2sql360.arguments import CoreArguments, EvaluationArguments
3 | 
4 | if __name__ == "__main__":
5 |     core_args = CoreArguments()
6 | 
7 |     core = Core(core_args)
8 | 
9 |     evaluation_args = EvaluationArguments(
10 |         eval_name="SuperSQL",
11 |         eval_dataset="spider_dev",
12 |         eval_metrics=["ex", "em", "ves"],
13 |         pred_sqls_file="./SuperSQL.sql",
14 |         enable_spider_eval=True
15 |     )
16 | 
17 |     core.evaluate(evaluation_args)
18 | 
--------------------------------------------------------------------------------
/examples/py_examples/report.py:
--------------------------------------------------------------------------------
1 | from nl2sql360.core import Core
2 | from nl2sql360.arguments import CoreArguments
3 | from nl2sql360.filter import Filter, Scenario, Field, Operator
4 | 
5 | if __name__ == "__main__":
6 |     core_args = CoreArguments()
7 | 
8 |     core = Core(core_args)
9 | 
10 |     SUBQUERY_FILTER = Filter(
11 |         name="subquery",
12 |         field=Field.SUBQUERY,
13 |         operator=Operator.GT,
14 |         value=0
15 |     )
16 | 
17 |     BI_SCENARIO = Scenario(
18 |         name="BI",
19 |         filters=[Filter('agg', Field.AGGREGATION, Operator.GT, 0), Filter('join', Field.JOIN, Operator.GT, 0)]
20 |     )
21 | 
22 |     print(core.query_overall_leaderboard(dataset_name="spider_dev", metric="ex"))
23 | 
24 |     print(core.query_filter_performance(dataset_name="spider_dev", metric="ex", filter=SUBQUERY_FILTER, eval_name="SuperSQL"))
25 | 
26 |     print(core.query_filter_leaderboard(dataset_name="spider_dev", metric="ex", filter=SUBQUERY_FILTER))
27 | 
28 |     print(core.query_scenario_performance(dataset_name="spider_dev", metric="ex", eval_name="SuperSQL", scenario=BI_SCENARIO))
29 | 
30 |     print(core.query_scenario_leaderboard(dataset_name="spider_dev", metric="ex", scenario=BI_SCENARIO))
31 | 
32 |     print(core.query_dataset_domain_distribution(dataset_name="spider_dev"))
33 | 
34 |     print(core.generate_evaluation_report(dataset_name="spider_dev",
35 |                                           filters=[SUBQUERY_FILTER],
36 |                                           scenarios=[BI_SCENARIO],
37 |                                           metrics=["ex", "em", "ves"]))
38 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BugMaker-Boyan/NL2SQL360/59407da3c84a8b64b61d2e44e23d498146230cb5/requirements.txt
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | 
4 | from setuptools import find_packages, setup
5 | 
6 | 
7 | def get_requires():
8 |     with open("requirements.txt", "r", encoding="utf-8") as f:
9 |         file_content = f.read()
10 |         lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
11 |         return lines
12 | 
13 | 
14 | def get_version():
15 |     with open(os.path.join("src", "nl2sql360", "__init__.py"), "r", encoding="utf-8") as f:
16 |         file_content = f.read()
17 |         pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
18 |         (version,) = re.findall(pattern, file_content)
19 |         return version
20 | 
21 | 
22 | def main():
23 |     setup(
24 |         name="nl2sql360",
25 |         version=get_version(),
26 |         author="boyanli",
27 |         author_email="boyanli" "@" "hkust-gz.edu.cn",
28 |         description="A Multi-angle NL2SQL Evaluation Framework",
29 |         long_description=open("README.md", "r", encoding="utf-8").read(),
30 |         long_description_content_type="text/markdown",
31 |         keywords=["NL2SQL", "Text-to-SQL", "T2S", "SQL", "Database", "NLIDB"],
32 |         license="MIT License",
33 |         url="https://github.com/HKUSTDial/NL2SQL360",
34 |         package_dir={"": "src"},
35 |         packages=find_packages("src"),
36 |         python_requires=">=3.8.0",
37 |         install_requires=get_requires(),
38 |         entry_points={"console_scripts": ["nl2sql360-cli = nl2sql360.cli:main"]},
39 |         classifiers=[
40 |             "Development Status :: 3 - Alpha",
41 |             "Intended Audience :: Developers",
42 |             "Intended Audience :: Education",
43 |             "Intended Audience :: Science/Research",
44 |             "License :: OSI Approved :: MIT License",
45 |             "Operating System :: OS Independent",
46 |             "Programming Language :: Python :: 3",
47 |             "Programming Language :: Python :: 3.8",
48 |             "Programming Language :: Python :: 3.9",
49 |             "Programming Language :: Python :: 3.10",
50 |             "Programming Language :: Python :: 3.11",
51 |             "Topic :: Database",
52 |             "Topic :: Scientific/Engineering :: Artificial Intelligence",
53 |             "Topic :: Software Development :: Testing"
54 |         ],
55 |     )
56 | 
57 | 
58 | if __name__ == "__main__":
59 |     main()
--------------------------------------------------------------------------------
/src/nl2sql360/__init__.py:
--------------------------------------------------------------------------------
1 | VERSION = "1.0.4"
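2 | 
3 | # NOTE: `setup.py` locates this constant with a regex when building the package,
4 | # so keep the `VERSION = "..."` assignment on a single line.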
-------------------------------------------------------------------------------- /src/nl2sql360/arguments/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset_args import DatasetArguments 2 | from .core_args import CoreArguments 3 | from .evaluation_args import EvaluationArguments 4 | from .report_args import ReportArguments 5 | from .delete_history_args import DeleteHistoryArguments 6 | from .parser import get_dataset_import_args, get_evaluation_args, get_delete_history_args, get_report_args 7 | 8 | 9 | __all__ = [ 10 | "DatasetArguments", 11 | "CoreArguments", 12 | "EvaluationArguments", 13 | "ReportArguments", 14 | "DeleteHistoryArguments", 15 | "get_dataset_import_args", 16 | "get_evaluation_args", 17 | "get_report_args", 18 | "get_delete_history_args" 19 | ] -------------------------------------------------------------------------------- /src/nl2sql360/arguments/core_args.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from pathlib import Path 3 | from typing import Optional, Literal 4 | 5 | 6 | @dataclass 7 | class CoreArguments: 8 | r""" 9 | Arguments for NL2SQL360-core. 10 | """ 11 | 12 | core_dir: str = field( 13 | default="./nl2sql360", 14 | metadata={"help": "The directory for NL2SQL360-core storage."} 15 | ) 16 | 17 | core_name: str = field( 18 | default="nl2sql360", 19 | metadata={"help": "The name of NL2SQL360-core"} 20 | ) 21 | 22 | sql_dialect: str = field( 23 | default="sqlite", 24 | metadata={"help": "Specify SQL dialect (e.g., sqlite) to parse."} 25 | ) 26 | -------------------------------------------------------------------------------- /src/nl2sql360/arguments/dataset_args.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from pathlib import Path 3 | from typing import Optional 4 | 5 | 6 | @dataclass 7 | class DatasetArguments: 8 | r""" 9 | Arguments for importing dataset. 
10 | """ 11 | 12 | dataset_name: str = field( 13 | metadata={"help": "The unique dataset name."} 14 | ) 15 | 16 | dataset_dir: str = field( 17 | metadata={"help": "Path to the folder containing the dataset."} 18 | ) 19 | 20 | samples_file: str = field( 21 | metadata={"help": "The json file containing dataset samples."} 22 | ) 23 | 24 | database_dir: str = field( 25 | metadata={"help": "The directory containing databases (i.e., sqlite files)."} 26 | ) 27 | 28 | tables_file: Optional[str] = field( 29 | default=None, 30 | metadata={"help": "The json file containing database tables (schemas)."} 31 | ) 32 | 33 | question_key: str = field( 34 | default="question", 35 | metadata={"help": "The key name of NL questions in the data json file."} 36 | ) 37 | 38 | sql_key: str = field( 39 | default="sql", 40 | metadata={"help": "The key name of SQL queries in the data json file."} 41 | ) 42 | 43 | db_id_key: str = field( 44 | default="db_id", 45 | metadata={"help": "The key name of database id in the data json file."} 46 | ) 47 | 48 | sql_complexity_key: Optional[str] = field( 49 | default=None, 50 | metadata={"help": "The key name of SQL complexity in the data json file."} 51 | ) 52 | 53 | database_domain_file: str = field( 54 | default=None, 55 | metadata={"help": "The json file containing database domain classifications."} 56 | ) 57 | 58 | def __post_init__(self): 59 | samples_file_path = Path(self.dataset_dir, self.samples_file) 60 | if not samples_file_path.exists() or not samples_file_path.is_file() or samples_file_path.suffix != ".json": 61 | raise ValueError("`samples_file` dose not exist or is not a valid json file.") 62 | 63 | if self.tables_file: 64 | tables_file_path = Path(self.dataset_dir, self.tables_file) 65 | if not tables_file_path.exists() or not tables_file_path.is_file() or tables_file_path.suffix != ".json": 66 | raise ValueError("`tables_file` dose not exist or is not a valid json file.") 67 | 68 | database_dir_path = Path(self.dataset_dir, self.database_dir) 69 | if not database_dir_path.exists() or not database_dir_path.is_dir(): 70 | raise ValueError("`database_dir` dose not exist or is not a directory.") 71 | -------------------------------------------------------------------------------- /src/nl2sql360/arguments/delete_history_args.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from pathlib import Path 3 | from typing import Optional, Literal, List, Union, Dict 4 | from ..filter import parse_filter, parse_scenario 5 | 6 | 7 | @dataclass 8 | class DeleteHistoryArguments: 9 | r""" 10 | Arguments for deleting datasets and evaluations. 
11 | """ 12 | 13 | dataset_name: Optional[str] = field( 14 | default=None, 15 | metadata={"help": "The dataset (list) to delete."} 16 | ) 17 | 18 | delete_dataset_evaluations: Optional[bool] = field( 19 | default=True, 20 | metadata={"help": "Whether to delete relevant evaluations for dataset."} 21 | ) 22 | 23 | eval_name: Optional[str] = field( 24 | default=None, 25 | metadata={"help": "The evaluation (list) to delete, each item contains two keys `dataset` and `evaluation`."} 26 | ) 27 | 28 | def __post_init__(self): 29 | if self.dataset_name is None and self.eval_name is None: 30 | raise ValueError("`dataset_name` and `eval_name` cannot be empty at the same time.") 31 | 32 | if self.dataset_name is None and self.eval_name is not None: 33 | raise ValueError("Need to specify `dataset_name` for `eval_name` evaluation.") -------------------------------------------------------------------------------- /src/nl2sql360/arguments/evaluation_args.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from pathlib import Path 3 | from typing import Optional, Literal, List 4 | 5 | 6 | @dataclass 7 | class EvaluationArguments: 8 | r""" 9 | Arguments for evaluation. 10 | """ 11 | 12 | eval_name: str = field( 13 | metadata={"help": "The model name for evaluation."} 14 | ) 15 | 16 | eval_dataset: str = field( 17 | metadata={"help": "The dataset name for evaluation."} 18 | ) 19 | 20 | eval_metrics: List[str] = field( 21 | metadata={"help": "Specify metrics (`ex`, `em`, `ves`) for evaluation."} 22 | ) 23 | 24 | pred_sqls_file: str = field( 25 | metadata={"help": "The file containing all predicted sqls (in lines)."} 26 | ) 27 | 28 | enable_spider_eval: bool = field( 29 | default=False, 30 | metadata={"help": "Enable official spider evaluator."} 31 | ) 32 | 33 | num_processes: int = field( 34 | default=8, 35 | metadata={"help": "The number of multi-processes used in the evaluation "} 36 | ) 37 | 38 | timeout: float = field( 39 | default=30.0, 40 | metadata={"help": "The timeout of SQL execution."} 41 | ) 42 | 43 | def __post_init__(self): 44 | 45 | for metric in self.eval_metrics: 46 | if metric not in ["ex", "em", "ves"]: 47 | raise ValueError("`eval_metrics` only supports metrics combinations in (`ex`, `em`, `ves`).") 48 | 49 | if self.num_processes <= 0: 50 | raise ValueError("`num_processes` should be positive.") 51 | 52 | if self.timeout <= 0: 53 | raise ValueError("`timeout` should be positive.") 54 | -------------------------------------------------------------------------------- /src/nl2sql360/arguments/hf_argparser.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 
15 | import dataclasses
16 | import json
17 | import sys
18 | import types
19 | from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, ArgumentTypeError
20 | from copy import copy
21 | from enum import Enum
22 | from inspect import isclass
23 | from pathlib import Path
24 | from typing import Any, Callable, Dict, Iterable, List, Literal, NewType, Optional, Tuple, Union, get_type_hints
25 | 
26 | import yaml
27 | 
28 | 
29 | DataClass = NewType("DataClass", Any)
30 | DataClassType = NewType("DataClassType", Any)
31 | 
32 | 
33 | # From https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
34 | def string_to_bool(v):
35 |     if isinstance(v, bool):
36 |         return v
37 |     if v.lower() in ("yes", "true", "t", "y", "1"):
38 |         return True
39 |     elif v.lower() in ("no", "false", "f", "n", "0"):
40 |         return False
41 |     else:
42 |         raise ArgumentTypeError(
43 |             f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)."
44 |         )
45 | 
46 | 
47 | def make_choice_type_function(choices: list) -> Callable[[str], Any]:
48 |     """
49 |     Creates a mapping function from each choice's string representation to the actual value. Used to support multiple
50 |     value types for a single argument.
51 | 
52 |     Args:
53 |         choices (list): List of choices.
54 | 
55 |     Returns:
56 |         Callable[[str], Any]: Mapping function from string representation to actual value for each choice.
57 |     """
58 |     str_to_choice = {str(choice): choice for choice in choices}
59 |     return lambda arg: str_to_choice.get(arg, arg)
60 | 
61 | 
62 | def HfArg(
63 |     *,
64 |     aliases: Union[str, List[str]] = None,
65 |     help: str = None,
66 |     default: Any = dataclasses.MISSING,
67 |     default_factory: Callable[[], Any] = dataclasses.MISSING,
68 |     metadata: dict = None,
69 |     **kwargs,
70 | ) -> dataclasses.Field:
71 |     """Argument helper enabling a concise syntax to create dataclass fields for parsing with `HfArgumentParser`.
72 | 
73 |     Example comparing the use of `HfArg` and `dataclasses.field`:
74 |     ```
75 |     @dataclass
76 |     class Args:
77 |         regular_arg: str = dataclasses.field(default="Huggingface", metadata={"aliases": ["--example", "-e"], "help": "This syntax could be better!"})
78 |         hf_arg: str = HfArg(default="Huggingface", aliases=["--example", "-e"], help="What a nice syntax!")
79 |     ```
80 | 
81 |     Args:
82 |         aliases (Union[str, List[str]], optional):
83 |             Single string or list of strings of aliases to pass on to argparse, e.g. `aliases=["--example", "-e"]`.
84 |             Defaults to None.
85 |         help (str, optional): Help string to pass on to argparse that can be displayed with --help. Defaults to None.
86 |         default (Any, optional):
87 |             Default value for the argument. If neither default nor default_factory is specified, the argument is required.
88 |             Defaults to dataclasses.MISSING.
89 |         default_factory (Callable[[], Any], optional):
90 |             The default_factory is a 0-argument function called to initialize a field's value. It is useful to provide
91 |             default values for mutable types, e.g. lists: `default_factory=list`. Mutually exclusive with `default=`.
92 |             Defaults to dataclasses.MISSING.
93 |         metadata (dict, optional): Further metadata to pass on to `dataclasses.field`. Defaults to None.
94 | 
95 |     Returns:
96 |         Field: A `dataclasses.Field` with the desired properties.
97 | """ 98 | if metadata is None: 99 | # Important, don't use as default param in function signature because dict is mutable and shared across function calls 100 | metadata = {} 101 | if aliases is not None: 102 | metadata["aliases"] = aliases 103 | if help is not None: 104 | metadata["help"] = help 105 | 106 | return dataclasses.field(metadata=metadata, default=default, default_factory=default_factory, **kwargs) 107 | 108 | 109 | class HfArgumentParser(ArgumentParser): 110 | """ 111 | This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments. 112 | 113 | The class is designed to play well with the native argparse. In particular, you can add more (non-dataclass backed) 114 | arguments to the parser after initialization and you'll get the output back after parsing as an additional 115 | namespace. Optional: To create sub argument groups use the `_argument_group_name` attribute in the dataclass. 116 | """ 117 | 118 | dataclass_types: Iterable[DataClassType] 119 | 120 | def __init__(self, dataclass_types: Union[DataClassType, Iterable[DataClassType]], **kwargs): 121 | """ 122 | Args: 123 | dataclass_types: 124 | Dataclass type, or list of dataclass types for which we will "fill" instances with the parsed args. 125 | kwargs (`Dict[str, Any]`, *optional*): 126 | Passed to `argparse.ArgumentParser()` in the regular way. 127 | """ 128 | # To make the default appear when using --help 129 | if "formatter_class" not in kwargs: 130 | kwargs["formatter_class"] = ArgumentDefaultsHelpFormatter 131 | super().__init__(**kwargs) 132 | if dataclasses.is_dataclass(dataclass_types): 133 | dataclass_types = [dataclass_types] 134 | self.dataclass_types = list(dataclass_types) 135 | for dtype in self.dataclass_types: 136 | self._add_dataclass_arguments(dtype) 137 | 138 | @staticmethod 139 | def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field): 140 | field_name = f"--{field.name}" 141 | kwargs = field.metadata.copy() 142 | # field.metadata is not used at all by Data Classes, 143 | # it is provided as a third-party extension mechanism. 144 | if isinstance(field.type, str): 145 | raise RuntimeError( 146 | "Unresolved type detected, which should have been done with the help of " 147 | "`typing.get_type_hints` method by default" 148 | ) 149 | 150 | aliases = kwargs.pop("aliases", []) 151 | if isinstance(aliases, str): 152 | aliases = [aliases] 153 | 154 | origin_type = getattr(field.type, "__origin__", field.type) 155 | if origin_type is Union or (hasattr(types, "UnionType") and isinstance(origin_type, types.UnionType)): 156 | if str not in field.type.__args__ and ( 157 | len(field.type.__args__) != 2 or type(None) not in field.type.__args__ 158 | ): 159 | raise ValueError( 160 | "Only `Union[X, NoneType]` (i.e., `Optional[X]`) is allowed for `Union` because" 161 | " the argument parser only supports one type per argument." 162 | f" Problem encountered in field '{field.name}'." 
163 |                 )
164 |             if type(None) not in field.type.__args__:
165 |                 # filter `str` in Union
166 |                 field.type = field.type.__args__[0] if field.type.__args__[1] == str else field.type.__args__[1]
167 |                 origin_type = getattr(field.type, "__origin__", field.type)
168 |             elif bool not in field.type.__args__:
169 |                 # filter `NoneType` in Union (except for `Union[bool, NoneType]`)
170 |                 field.type = (
171 |                     field.type.__args__[0] if isinstance(None, field.type.__args__[1]) else field.type.__args__[1]
172 |                 )
173 |                 origin_type = getattr(field.type, "__origin__", field.type)
174 | 
175 |         # A variable to store kwargs for a boolean field, if needed
176 |         # so that we can init a `no_*` complement argument (see below)
177 |         bool_kwargs = {}
178 |         if origin_type is Literal or (isinstance(field.type, type) and issubclass(field.type, Enum)):
179 |             if origin_type is Literal:
180 |                 kwargs["choices"] = field.type.__args__
181 |             else:
182 |                 kwargs["choices"] = [x.value for x in field.type]
183 | 
184 |             kwargs["type"] = make_choice_type_function(kwargs["choices"])
185 | 
186 |             if field.default is not dataclasses.MISSING:
187 |                 kwargs["default"] = field.default
188 |             else:
189 |                 kwargs["required"] = True
190 |         elif field.type is bool or field.type == Optional[bool]:
191 |             # Copy the correct kwargs to use to instantiate a `no_*` complement argument below.
192 |             # We do not initialize it here because the `no_*` alternative must be instantiated after the real argument
193 |             bool_kwargs = copy(kwargs)
194 | 
195 |             # Hack because type=bool in argparse does not behave as we want.
196 |             kwargs["type"] = string_to_bool
197 |             if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
198 |                 # Default value is False if we have no default when of type bool.
199 |                 default = False if field.default is dataclasses.MISSING else field.default
200 |                 # This is the value that will get picked if we don't include --field_name in any way
201 |                 kwargs["default"] = default
202 |                 # This tells argparse we accept 0 or 1 value after --field_name
203 |                 kwargs["nargs"] = "?"
204 |                 # This is the value that will get picked if we do --field_name (without value)
205 |                 kwargs["const"] = True
206 |         elif isclass(origin_type) and issubclass(origin_type, list):
207 |             kwargs["type"] = field.type.__args__[0]
208 |             kwargs["nargs"] = "+"
209 |             if field.default_factory is not dataclasses.MISSING:
210 |                 kwargs["default"] = field.default_factory()
211 |             elif field.default is dataclasses.MISSING:
212 |                 kwargs["required"] = True
213 |         else:
214 |             kwargs["type"] = field.type
215 |             if field.default is not dataclasses.MISSING:
216 |                 kwargs["default"] = field.default
217 |             elif field.default_factory is not dataclasses.MISSING:
218 |                 kwargs["default"] = field.default_factory()
219 |             else:
220 |                 kwargs["required"] = True
221 |         parser.add_argument(field_name, *aliases, **kwargs)
222 | 
223 |         # Add a complement `no_*` argument for a boolean field AFTER the initial field has already been added.
224 |         # Order is important for arguments with the same destination!
225 |         # We use a copy of earlier kwargs because the original kwargs have changed a lot before reaching down
226 |         # here and we do not need those changes/additional keys.
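        # (e.g., a dataclass field `use_cache: bool = True` yields both `--use_cache` and a `--no_use_cache` flag)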
227 |         if field.default is True and (field.type is bool or field.type == Optional[bool]):
228 |             bool_kwargs["default"] = False
229 |             parser.add_argument(f"--no_{field.name}", action="store_false", dest=field.name, **bool_kwargs)
230 | 
231 |     def _add_dataclass_arguments(self, dtype: DataClassType):
232 |         if hasattr(dtype, "_argument_group_name"):
233 |             parser = self.add_argument_group(dtype._argument_group_name)
234 |         else:
235 |             parser = self
236 | 
237 |         try:
238 |             type_hints: Dict[str, type] = get_type_hints(dtype)
239 |         except NameError:
240 |             raise RuntimeError(
241 |                 f"Type resolution failed for {dtype}. Try declaring the class in global scope or "
242 |                 "removing the line `from __future__ import annotations` which opts in Postponed "
243 |                 "Evaluation of Annotations (PEP 563)"
244 |             )
245 |         except TypeError as ex:
246 |             # Remove this block when we drop Python 3.9 support
247 |             if sys.version_info[:2] < (3, 10) and "unsupported operand type(s) for |" in str(ex):
248 |                 python_version = ".".join(map(str, sys.version_info[:3]))
249 |                 raise RuntimeError(
250 |                     f"Type resolution failed for {dtype} on Python {python_version}. Try removing "
251 |                     "the line `from __future__ import annotations` which opts in union types as "
252 |                     "`X | Y` (PEP 604) via Postponed Evaluation of Annotations (PEP 563). To "
253 |                     "support Python versions lower than 3.10, you need to use "
254 |                     "`typing.Union[X, Y]` instead of `X | Y` and `typing.Optional[X]` instead of "
255 |                     "`X | None`."
256 |                 ) from ex
257 |             raise
258 | 
259 |         for field in dataclasses.fields(dtype):
260 |             if not field.init:
261 |                 continue
262 |             field.type = type_hints[field.name]
263 |             self._parse_dataclass_field(parser, field)
264 | 
265 |     def parse_args_into_dataclasses(
266 |         self,
267 |         args=None,
268 |         return_remaining_strings=False,
269 |         look_for_args_file=True,
270 |         args_filename=None,
271 |         args_file_flag=None,
272 |     ) -> Tuple[DataClass, ...]:
273 |         """
274 |         Parse command-line args into instances of the specified dataclass types.
275 | 
276 |         This relies on argparse's `ArgumentParser.parse_known_args`. See the doc at:
277 |         docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args
278 | 
279 |         Args:
280 |             args:
281 |                 List of strings to parse. The default is taken from sys.argv. (same as argparse.ArgumentParser)
282 |             return_remaining_strings:
283 |                 If true, also return a list of remaining argument strings.
284 |             look_for_args_file:
285 |                 If true, will look for a ".args" file with the same base name as the entry point script for this
286 |                 process, and will append its potential content to the command line args.
287 |             args_filename:
288 |                 If not None, will use this file instead of the ".args" file specified in the previous argument.
289 |             args_file_flag:
290 |                 If not None, will look for a file in the command-line args specified with this flag. The flag can be
291 |                 specified multiple times and precedence is determined by the order (last one wins).
292 | 
293 |         Returns:
294 |             Tuple consisting of:
295 | 
296 |                 - the dataclass instances in the same order as they were passed to the initializer.
297 |                 - if applicable, an additional namespace for more (non-dataclass backed) arguments added to the parser
298 |                 after initialization.
299 |                 - The potential list of remaining argument strings. (same as argparse.ArgumentParser.parse_known_args)
300 |         """
301 | 
302 |         if args_file_flag or args_filename or (look_for_args_file and len(sys.argv)):
303 |             args_files = []
304 | 
305 |             if args_filename:
306 |                 args_files.append(Path(args_filename))
307 |             elif look_for_args_file and len(sys.argv):
308 |                 args_files.append(Path(sys.argv[0]).with_suffix(".args"))
309 | 
310 |             # args files specified via command line flag should overwrite default args files so we add them last
311 |             if args_file_flag:
312 |                 # Create special parser just to extract the args_file_flag values
313 |                 args_file_parser = ArgumentParser()
314 |                 args_file_parser.add_argument(args_file_flag, type=str, action="append")
315 | 
316 |                 # Use only remaining args for further parsing (remove the args_file_flag)
317 |                 cfg, args = args_file_parser.parse_known_args(args=args)
318 |                 cmd_args_file_paths = vars(cfg).get(args_file_flag.lstrip("-"), None)
319 | 
320 |                 if cmd_args_file_paths:
321 |                     args_files.extend([Path(p) for p in cmd_args_file_paths])
322 | 
323 |             file_args = []
324 |             for args_file in args_files:
325 |                 if args_file.exists():
326 |                     file_args += args_file.read_text().split()
327 | 
328 |             # in case of duplicate arguments the last one has precedence
329 |             # args specified via the command line should overwrite args from files, so we add them last
330 |             args = file_args + args if args is not None else file_args + sys.argv[1:]
331 |         namespace, remaining_args = self.parse_known_args(args=args)
332 |         outputs = []
333 |         for dtype in self.dataclass_types:
334 |             keys = {f.name for f in dataclasses.fields(dtype) if f.init}
335 |             inputs = {k: v for k, v in vars(namespace).items() if k in keys}
336 |             for k in keys:
337 |                 delattr(namespace, k)
338 |             obj = dtype(**inputs)
339 |             outputs.append(obj)
340 |         if len(namespace.__dict__) > 0:
341 |             # additional namespace.
342 |             outputs.append(namespace)
343 |         if return_remaining_strings:
344 |             return (*outputs, remaining_args)
345 |         else:
346 |             if remaining_args:
347 |                 raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}")
348 | 
349 |             return (*outputs,)
350 | 
351 |     def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
352 |         """
353 |         Alternative helper method that does not use `argparse` at all, instead uses a dict to populate the dataclass
354 |         types.
355 | 
356 |         Args:
357 |             args (`dict`):
358 |                 dict containing config values
359 |             allow_extra_keys (`bool`, *optional*, defaults to `False`):
360 |                 Defaults to False. If False, will raise an exception if the dict contains keys that are not parsed.
361 | 
362 |         Returns:
363 |             Tuple consisting of:
364 | 
365 |                 - the dataclass instances in the same order as they were passed to the initializer.
366 |         """
367 |         unused_keys = set(args.keys())
368 |         outputs = []
369 |         for dtype in self.dataclass_types:
370 |             keys = {f.name for f in dataclasses.fields(dtype) if f.init}
371 |             inputs = {k: v for k, v in args.items() if k in keys}
372 |             unused_keys.difference_update(inputs.keys())
373 |             obj = dtype(**inputs)
374 |             outputs.append(obj)
375 |         if not allow_extra_keys and unused_keys:
376 |             raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}")
377 |         return tuple(outputs)
378 | 
379 |     def parse_json_file(self, json_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
380 |         """
381 |         Alternative helper method that does not use `argparse` at all, instead loads a json file and populates the
382 |         dataclass types.
383 | 
384 |         Args:
385 |             json_file (`str` or `os.PathLike`):
386 |                 File name of the json file to parse
387 |             allow_extra_keys (`bool`, *optional*, defaults to `False`):
388 |                 Defaults to False. If False, will raise an exception if the json file contains keys that are not
389 |                 parsed.
390 | 
391 |         Returns:
392 |             Tuple consisting of:
393 | 
394 |                 - the dataclass instances in the same order as they were passed to the initializer.
395 |         """
396 |         with open(Path(json_file), encoding="utf-8") as open_json_file:
397 |             data = json.loads(open_json_file.read())
398 |         outputs = self.parse_dict(data, allow_extra_keys=allow_extra_keys)
399 |         return tuple(outputs)
400 | 
401 |     def parse_yaml_file(self, yaml_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
402 |         """
403 |         Alternative helper method that does not use `argparse` at all, instead loads a yaml file and populates the
404 |         dataclass types.
405 | 
406 |         Args:
407 |             yaml_file (`str` or `os.PathLike`):
408 |                 File name of the yaml file to parse
409 |             allow_extra_keys (`bool`, *optional*, defaults to `False`):
410 |                 Defaults to False. If False, will raise an exception if the yaml file contains keys that are not
411 |                 parsed.
412 | 
413 |         Returns:
414 |             Tuple consisting of:
415 | 
416 |                 - the dataclass instances in the same order as they were passed to the initializer.
417 |         """
418 |         outputs = self.parse_dict(yaml.safe_load(Path(yaml_file).read_text()), allow_extra_keys=allow_extra_keys)
419 |         return tuple(outputs)
420 | 
--------------------------------------------------------------------------------
/src/nl2sql360/arguments/parser.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import loguru
4 | from typing import Any, Dict, Optional, Tuple
5 | from .hf_argparser import HfArgumentParser
6 | 
7 | from .core_args import CoreArguments
8 | from .dataset_args import DatasetArguments
9 | from .evaluation_args import EvaluationArguments
10 | from .report_args import ReportArguments
11 | from .delete_history_args import DeleteHistoryArguments
12 | 
13 | 
14 | _DATASET_IMPORT_ARGS = [CoreArguments, DatasetArguments]
15 | _DATASET_IMPORT_CLS = Tuple[CoreArguments, DatasetArguments]
16 | _EVALUATION_ARGS = [CoreArguments, EvaluationArguments]
17 | _EVALUATION_CLS = Tuple[CoreArguments, EvaluationArguments]
18 | _REPORT_ARGS = [CoreArguments, ReportArguments]
19 | _REPORT_CLS = Tuple[CoreArguments, ReportArguments]
20 | _DELETE_HISTORY_ARGS = [CoreArguments, DeleteHistoryArguments]
21 | _DELETE_HISTORY_CLS = Tuple[CoreArguments, DeleteHistoryArguments]
22 | 
23 | 
24 | def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
25 |     if args is not None:
26 |         return parser.parse_dict(args)
27 | 
28 |     if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
29 |         return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
30 | 
31 |     (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
32 | 
33 |     if unknown_args:
34 |         print(parser.format_help())
35 |         print("Got unknown args, potentially deprecated arguments: {}".format(unknown_args))
36 |         raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(unknown_args))
37 | 
38 |     return (*parsed_args,)
39 | 
40 | 
41 | def get_dataset_import_args(args: Optional[Dict[str, Any]] = None) -> _DATASET_IMPORT_CLS:
42 |     parser = HfArgumentParser(_DATASET_IMPORT_ARGS)
43 |     return _parse_args(parser, args)
44 | 
45 | 
46 | def get_evaluation_args(args: Optional[Dict[str, Any]] = None) -> _EVALUATION_CLS:
47 |     parser = HfArgumentParser(_EVALUATION_ARGS)
48 |     return _parse_args(parser, args)
49 | 
50 | 
51 | def get_report_args(args: Optional[Dict[str, Any]] = None) -> _REPORT_CLS:
52 |     parser = HfArgumentParser(_REPORT_ARGS)
53 |     return _parse_args(parser, args)
54 | 
55 | 
56 | def get_delete_history_args(args: Optional[Dict[str, Any]] = None) -> _DELETE_HISTORY_CLS:
57 |     parser = HfArgumentParser(_DELETE_HISTORY_ARGS)
58 |     return _parse_args(parser, args)
59 | 
--------------------------------------------------------------------------------
/src/nl2sql360/arguments/report_args.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from pathlib import Path
3 | from typing import Optional, Literal, List, Union, Dict
4 | from ..filter import parse_filter, parse_scenario
5 | 
6 | 
7 | @dataclass
8 | class ReportArguments:
9 |     r"""
10 |     Arguments for reporting dataset and evaluation.
11 |     """
12 | 
13 |     save_path: str = field(
14 |         metadata={"help": "The path to save the CSV report."}
15 |     )
16 | 
17 |     report_dataset: str = field(
18 |         metadata={"help": "The dataset name to report."}
19 |     )
20 | 
21 |     report_evaluation: Union[Optional[str], List[str]] = field(
22 |         default=None,
23 |         metadata={"help": "The evaluation (list) to report."}
24 |     )
25 | 
26 |     metric: Union[str, List[str]] = field(
27 |         default=None,
28 |         metadata={"help": "The metric (list) to report."}
29 |     )
30 | 
31 |     filter: Union[Optional[str], List[str]] = field(
32 |         default=None,
33 |         metadata={"help": "The filter expressions (list) used to filter subset performance."}
34 |     )
35 | 
36 |     scenario: Union[Optional[str], List[str]] = field(
37 |         default=None,
38 |         metadata={"help": "The scenario expressions (list) used to filter subset performance."}
39 |     )
40 | 
41 |     def __post_init__(self):
42 |         if isinstance(self.metric, str):
43 |             self.metric = [m.strip() for m in self.metric.split(",")]
44 | 
45 |         for metric in self.metric:
46 |             if metric not in ["ex", "em", "ves", "qvt"]:
47 |                 raise ValueError("`metric` only supports metric combinations in (`ex`, `em`, `ves`, `qvt`).")
48 | 
49 |         filter_list = []
50 |         if self.filter:
51 |             try:
52 |                 for _f in self.filter:
53 |                     f = parse_filter(_f["name"], _f["expression"])
54 |                     if f is None:
55 |                         raise ValueError(f"Parse filter error: {_f}")
56 |                     filter_list.append(f)
57 |             except Exception as e:
58 |                 raise ValueError("Parse filter error.") from e
59 |         self.filter = filter_list
60 | 
61 |         scenario_list = []
62 |         if self.scenario:
63 |             try:
64 |                 for _s in self.scenario:
65 |                     s = parse_scenario(_s["name"], _s["expression"])
66 |                     if s is None:
67 |                         raise ValueError(f"Parse scenario error: {_s}")
68 |                     scenario_list.append(s)
69 |             except Exception as e:
70 |                 raise ValueError("Parse scenario error.") from e
71 |         self.scenario = scenario_list
72 | 
--------------------------------------------------------------------------------
/src/nl2sql360/cli/__init__.py:
--------------------------------------------------------------------------------
1 | from .cli import main
--------------------------------------------------------------------------------
/src/nl2sql360/cli/cli.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from enum import Enum, unique
3 | from .. import VERSION
4 | from .util import run_dataset_import, run_delete_history, run_evaluation, run_report
5 | 
6 | 
7 | USAGE = (
8 |     "-" * 70
9 |     + "\n"
10 |     + "| Usage: |\n"
11 |     + "| nl2sql360-cli dataset -h: import NL2SQL dataset |\n"
12 |     + "| nl2sql360-cli evaluate -h: evaluate NL2SQL model |\n"
13 |     + "| nl2sql360-cli report -h: output evaluation report |\n"
14 |     + "| nl2sql360-cli delete -h: delete dataset or evaluation history |\n"
15 |     + "| nl2sql360-cli version: show version info |\n"
16 |     + "-" * 70
17 | )
18 | 
19 | 
20 | WELCOME = (
21 |     "-" * 58
22 |     + "\n"
23 |     + "| Welcome to NL2SQL360, version {}".format(VERSION)
24 |     + " " * (25 - len(VERSION))
25 |     + "|\n|"
26 |     + " " * 56
27 |     + "|\n"
28 |     + "| Project page: https://github.com/HKUSTDial/NL2SQL360"
29 |     + " " * 3
30 |     + "|\n"
31 |     + "-" * 58
32 | )
33 | 
34 | 
35 | @unique
36 | class Command(str, Enum):
37 |     DATASET = "dataset"
38 |     EVALUATE = "evaluate"
39 |     REPORT = "report"
40 |     DELETE = "delete"
41 |     VERSION = "version"
42 |     HELP = "help"
43 | 
44 | 
45 | def main():
46 |     command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP
47 |     if command == Command.DATASET:
48 |         run_dataset_import()
49 |     elif command == Command.EVALUATE:
50 |         run_evaluation()
51 |     elif command == Command.REPORT:
52 |         run_report()
53 |     elif command == Command.DELETE:
54 |         run_delete_history()
55 |     elif command == Command.VERSION:
56 |         print(WELCOME)
57 |     elif command == Command.HELP:
58 |         print(USAGE)
59 |     else:
60 |         raise NotImplementedError("Unknown command: {}".format(command))
61 | 
--------------------------------------------------------------------------------
/src/nl2sql360/cli/util.py:
--------------------------------------------------------------------------------
1 | from ..arguments import (
2 |     get_dataset_import_args,
3 |     get_evaluation_args,
4 |     get_report_args,
5 |     get_delete_history_args
6 | )
7 | from ..core import Core
8 | import pandas as pd
9 | from loguru import logger
10 | from pathlib import Path
11 | 
12 | 
13 | def run_dataset_import():
14 |     core_args, dataset_args = get_dataset_import_args()
15 |     Core(core_args).import_dataset(dataset_args)
16 | 
17 | 
18 | def run_evaluation():
19 |     core_args, evaluation_args = get_evaluation_args()
20 |     Core(core_args).evaluate(evaluation_args)
21 | 
22 | 
23 | def run_report():
24 |     core_args, report_args = get_report_args()
25 |     report = Core(core_args).generate_evaluation_report(
26 |         dataset_name=report_args.report_dataset,
27 |         filters=report_args.filter,
28 |         scenarios=report_args.scenario,
29 |         metrics=report_args.metric,
30 |         eval_names=report_args.report_evaluation
31 |     )
32 |     report.to_csv(report_args.save_path)
33 |     logger.success(f"Save report in path `{Path(report_args.save_path).resolve()}` successfully.")
34 | 
35 | 
36 | def run_delete_history():
37 |     core_args, delete_history_args = get_delete_history_args()
38 |     core = Core(core_args)
39 |     if delete_history_args.dataset_name and not delete_history_args.eval_name:
40 |         core.delete_dataset_history(
41 |             dataset_name=delete_history_args.dataset_name,
42 |             delete_relavant_evaluations=delete_history_args.delete_dataset_evaluations
43 |         )
44 |     if delete_history_args.dataset_name and delete_history_args.eval_name:
45 |         core.delete_evaluation_history(
46 |             dataset_name=delete_history_args.dataset_name,
47 |             eval_name=delete_history_args.eval_name
48 |         )
49 | 
--------------------------------------------------------------------------------
/src/nl2sql360/core/__init__.py:
-------------------------------------------------------------------------------- 1 | from .core import Core 2 | 3 | 4 | __all__ = [ 5 | "Core" 6 | ] -------------------------------------------------------------------------------- /src/nl2sql360/core/core.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import create_engine, inspect, text 2 | from sqlalchemy.orm import Session 3 | from loguru import logger 4 | from tqdm import tqdm 5 | from pathlib import Path 6 | from typing import Sequence, Optional, List, Dict, Tuple, Union 7 | from pathlib import Path 8 | from pandas import DataFrame 9 | import pandas as pd 10 | import itertools 11 | 12 | from ..database import * 13 | from ..parser import SQLParser 14 | from ..dataset import NL2SQLDataset 15 | from ..arguments import CoreArguments, DatasetArguments, EvaluationArguments 16 | from ..evaluator import BirdAccraucyEvaluator, SpiderAccraucyEvaluator, VesEvaluator 17 | from ..filter import Filter, Scenario, serialize_filter, serialize_scenario 18 | 19 | 20 | class _Core: 21 | r""" 22 | Base core class implementation for importing datasets and evaluating. 23 | """ 24 | 25 | def __init__(self, core_args: "CoreArguments") -> None: 26 | self.core_args = core_args 27 | Path(core_args.core_dir).mkdir(exist_ok=True) 28 | self.engine = create_engine(f"sqlite:///{core_args.core_dir}/{core_args.core_name}.sqlite") 29 | self.insp = inspect(self.engine) 30 | Base.metadata.create_all(self.engine, checkfirst=True) # `DatasetInfo` Table Initialize 31 | self.models_dict = dict() 32 | for table_name in self.insp.get_table_names(): 33 | if table_name == "__DATASET_INFO__": 34 | continue 35 | if "_EVALUATION_" in table_name: 36 | self.models_dict[table_name] = get_evaluation_model(*get_dataset_name_and_evaluation_name_from_table_name(table_name)) 37 | else: 38 | self.models_dict[table_name] = get_dataset_model(get_dataset_name_from_table_name(table_name)) 39 | 40 | def import_dataset(self, dataset_args: "DatasetArguments") -> None: 41 | table_name = f"DATASET_{dataset_args.dataset_name}" 42 | if table_name in self.models_dict.keys(): 43 | logger.warning(f"Dataset `{dataset_args.dataset_name}` has been already imported.") 44 | return 45 | 46 | dataset_model = get_dataset_model(dataset_args.dataset_name) 47 | self.models_dict[dataset_model.__tablename__] = dataset_model 48 | Base.metadata.create_all(self.engine, checkfirst=True) 49 | logger.success(f"Dataset table `{table_name}` creation completed.") 50 | 51 | dataset = NL2SQLDataset(dataset_args) 52 | with Session(self.engine) as session: 53 | # Insert dataset info 54 | dataset_info_item = DatasetInfo( 55 | dataset_name=dataset_args.dataset_name, 56 | database_dir_path=str(Path(dataset_args.dataset_dir, dataset_args.database_dir).resolve()), 57 | tables_json_path=str(Path(dataset_args.dataset_dir, dataset_args.tables_file).resolve()) if dataset_args.tables_file else None 58 | ) 59 | session.add(dataset_info_item) 60 | 61 | # Insert dataset samples 62 | for id, (nlq, gold, db_id, complexity, db_domain) in enumerate(tqdm(list(zip( 63 | dataset.get_all_questions(), 64 | dataset.get_all_sqls(), 65 | dataset.get_all_db_ids(), 66 | dataset.get_all_sql_complexity(), 67 | dataset.get_all_database_domains() 68 | )), desc="Import dataset")): 69 | parsed_sql = SQLParser(gold) 70 | table_item = self.models_dict[table_name]( 71 | id=id, 72 | nlq=nlq, 73 | gold=gold, 74 | db_id=db_id, 75 | complexity=complexity, 76 | db_domain=db_domain, 77 | **{attr: getattr(parsed_sql, 
attr) for attr in dir(parsed_sql) if attr.startswith("count_")} 78 | ) 79 | session.add(table_item) 80 | session.commit() 81 | logger.success(f"Import dataset `{dataset_args.dataset_name}` completed, {len(dataset)} samples in total.") 82 | 83 | def evaluate(self, evaluation_args: "EvaluationArguments") -> None: 84 | dataset_table_name = f"DATASET_{evaluation_args.eval_dataset}" 85 | if dataset_table_name not in self.models_dict.keys(): 86 | logger.warning(f"Dataset `{evaluation_args.eval_dataset}` has not been imported.") 87 | return 88 | 89 | table_name = f"DATASET_{evaluation_args.eval_dataset}_EVALUATION_{evaluation_args.eval_name}" 90 | if table_name in self.models_dict.keys(): 91 | logger.warning(f"Evaluation `{evaluation_args.eval_name}` on dataset `{evaluation_args.eval_dataset}` already exists.") 92 | return 93 | 94 | evaluation_model = get_evaluation_model(evaluation_args.eval_dataset, evaluation_args.eval_name) 95 | self.models_dict[evaluation_model.__tablename__] = evaluation_model 96 | Base.metadata.create_all(self.engine, checkfirst=True) 97 | logger.success(f"Evaluation table `{table_name}` creation completed.") 98 | 99 | dataset_info = get_dataset_info(self.engine, evaluation_args.eval_dataset) 100 | if dataset_info is None: 101 | logger.error(f"Cannot find imported dataset `{evaluation_args.eval_dataset}`.") 102 | return 103 | 104 | evaluators = [] 105 | if "ex" in evaluation_args.eval_metrics: 106 | if evaluation_args.enable_spider_eval: 107 | eval_em = "em" in evaluation_args.eval_metrics and dataset_info.tables_json_path 108 | if "em" in evaluation_args.eval_metrics and dataset_info.tables_json_path is None: 109 | logger.warning(f"`EM` metric evaluation ignored, because no `tables_file` was imported for the {evaluation_args.eval_dataset} dataset.") 110 | evaluators.append(SpiderAccraucyEvaluator(eval_em=eval_em, eval_ex=True)) 111 | else: 112 | eval_em = "em" in evaluation_args.eval_metrics and dataset_info.tables_json_path 113 | if "em" in evaluation_args.eval_metrics and dataset_info.tables_json_path is None: 114 | logger.warning(f"`EM` metric evaluation ignored, because no `tables_file` was imported for the {evaluation_args.eval_dataset} dataset.") 115 | if eval_em: 116 | evaluators.append(SpiderAccraucyEvaluator(eval_em=eval_em, eval_ex=False)) 117 | evaluators.append(BirdAccraucyEvaluator()) 118 | elif "em" in evaluation_args.eval_metrics: 119 | if dataset_info.tables_json_path is None: 120 | logger.warning(f"`EM` metric evaluation ignored, because no `tables_file` was imported for the {evaluation_args.eval_dataset} dataset.") 121 | else: 122 | evaluators.append(SpiderAccraucyEvaluator(eval_em=True, eval_ex=False)) 123 | 124 | if "ves" in evaluation_args.eval_metrics: 125 | evaluators.append(VesEvaluator(reuse_ex=evaluation_args.enable_spider_eval)) 126 | 127 | with open(evaluation_args.pred_sqls_file, "r", encoding="utf-8") as f: 128 | pred_sqls = f.readlines() 129 | 130 | dataset_samples = get_dataset_samples(self.engine, self.models_dict[dataset_table_name]) 131 | gold_sqls = [sample["gold"] for sample in dataset_samples] 132 | db_ids = [sample["db_id"] for sample in dataset_samples] 133 | 134 | eval_results = dict() 135 | eval_metrics = set() 136 | for evaluator in evaluators: 137 | logger.info(f"Evaluating {evaluator.get_eval_metrics()}...") 138 | exec_acc_list = eval_results.get("exec_acc", None) 139 | eval_results.update(evaluator.evaluate( 140 | gold_sqls=gold_sqls, 141 | pred_sqls=pred_sqls, 142 | db_ids=db_ids, 143 | db_dir=dataset_info.database_dir_path, 144 | 
tables_json_path=dataset_info.tables_json_path, 145 | exec_acc_list=exec_acc_list 146 | )) 147 | logger.success(f"Evaluating {evaluator.get_eval_metrics()} completed.") 148 | eval_metrics.update(evaluator.get_eval_metrics()) 149 | 150 | insert_data = [] 151 | for idx, pred in enumerate(pred_sqls): 152 | item = { 153 | "id": idx, 154 | "pred": pred 155 | } 156 | for metric in eval_metrics: 157 | item[metric] = eval_results[metric][idx] 158 | insert_data.append(item) 159 | 160 | with Session(self.engine) as session: 161 | for data in tqdm(insert_data, desc="Insert into evaluation table"): 162 | table_item = self.models_dict[table_name]( 163 | **data 164 | ) 165 | session.add(table_item) 166 | session.commit() 167 | logger.success(f"Evaluation `{evaluation_args.eval_name}` completed.") 168 | 169 | 170 | class Core(_Core): 171 | r""" 172 | Extended core class implementation, including more user query interfaces. 173 | """ 174 | 175 | def query_available_datasets(self) -> DataFrame: 176 | datasets = [get_dataset_name_from_table_name(table) 177 | for table in self.models_dict.keys() 178 | if table.startswith("DATASET_") and "_EVALUATION_" not in table] 179 | return DataFrame(data={"Dataset": datasets}) 180 | 181 | def query_available_evaluations(self, dataset_name: str) -> DataFrame: 182 | evaluations = [get_dataset_name_and_evaluation_name_from_table_name(table)[1] 183 | for table in self.models_dict.keys() 184 | if table.startswith(f"DATASET_{dataset_name}") and "_EVALUATION_" in table] 185 | return DataFrame(data={"Evaluation": evaluations}) 186 | 187 | def _check_dataset_valid(self, dataset_name: str) -> bool: 188 | if dataset_name in self.query_available_datasets()["Dataset"].values: 189 | return True 190 | else: 191 | logger.warning(f"Cannot find `{dataset_name}` dataset in NL2SQL360.") 192 | return False 193 | 194 | def _check_evaluation_valid(self, dataset_name: str, eval_name: str) -> bool: 195 | if eval_name in self.query_available_evaluations(dataset_name)["Evaluation"].values: 196 | return True 197 | else: 198 | logger.warning(f"Cannot find `{eval_name}` evaluation for `{dataset_name}` dataset in NL2SQL360.") 199 | return False 200 | 201 | def _check_metric_valid(self, metric: str) -> bool: 202 | if metric in METRIC_COL_MAPPING.keys(): 203 | return True 204 | else: 205 | logger.warning(f"`{metric}` metric is not supported, available metrics: (`ex`, `em`, `ves`, `qvt`).") 206 | return False 207 | 208 | def query_overall_performance(self, dataset_name: str, metric: str, eval_name: str) -> DataFrame: 209 | if not (self._check_dataset_valid(dataset_name) and self._check_evaluation_valid(dataset_name, eval_name) and self._check_metric_valid(metric)): 210 | return None 211 | else: 212 | if metric == "qvt": 213 | statement = QUERY_QVT_PERFORMANCE.format( 214 | DATASET_NAME=dataset_name, 215 | EVAL_NAME=eval_name 216 | ) 217 | else: 218 | statement = QUERY_OVERALL_PERFORMANCE.format( 219 | DATASET_NAME=dataset_name, 220 | EVAL_NAME=eval_name, 221 | METRIC_COL=METRIC_COL_MAPPING[metric] 222 | ) 223 | with self.engine.connect() as connection: 224 | result = connection.execute(text(statement)) 225 | connection.commit() 226 | res = result.first() 227 | if res: 228 | return DataFrame(data={"Evaluation": eval_name, metric.upper(): res}).round(decimals=2) 229 | else: 230 | logger.warning("Query returned an empty result.") 231 | return DataFrame(data={"Evaluation": [eval_name], metric.upper(): [pd.NA]}) 232 | 233 | def query_overall_leaderboard(self, dataset_name: str, metric: str, eval_names: List[str] = 
None) -> DataFrame: 234 | if not (self._check_dataset_valid(dataset_name) and self._check_metric_valid(metric)): 235 | return None 236 | 237 | if eval_names: 238 | for eval_name in eval_names: 239 | if not self._check_evaluation_valid(dataset_name, eval_name): 240 | return None 241 | else: 242 | eval_names = self.query_available_evaluations(dataset_name)["Evaluation"].values 243 | dataframes = [] 244 | for eval_name in eval_names: 245 | dataframes.append(self.query_overall_performance(dataset_name, metric, eval_name)) 246 | df = pd.concat(dataframes, ignore_index=True).sort_values(by=[metric.upper(), "Evaluation"], ascending=False) 247 | df["Rank"] = df[f"{metric.upper()}"].rank(axis=0, method="dense", ascending=False) 248 | return df 249 | 250 | def query_filter_performance(self, dataset_name: str, filter: Filter, metric: str, eval_name: str) -> DataFrame: 251 | if not (self._check_dataset_valid(dataset_name) and self._check_evaluation_valid(dataset_name, eval_name) and self._check_metric_valid(metric)): 252 | return None 253 | 254 | if metric == "qvt": 255 | logger.warning("QVT metric only supports overall performance.") 256 | return None 257 | 258 | statement = QUERY_SUBSET_PERFORMANCE.format( 259 | DATASET_NAME=dataset_name, 260 | EVAL_NAME=eval_name, 261 | METRIC_COL=METRIC_COL_MAPPING[metric], 262 | WHERE_CONDITION=serialize_filter(filter) 263 | ) 264 | with self.engine.connect() as connection: 265 | result = connection.execute(text(statement)) 266 | connection.commit() 267 | res = result.first() 268 | if res: 269 | return DataFrame(data={"Evaluation": eval_name, "Subset": filter.name, metric.upper(): res}).round(decimals=2) 270 | else: 271 | logger.warning("Query returned an empty result.") 272 | return DataFrame(data={"Evaluation": [eval_name], "Subset": [filter.name], metric.upper(): [pd.NA]}) 273 | 274 | def query_filter_leaderboard(self, dataset_name: str, filter: Filter, metric: str, eval_names: List[str] = None) -> DataFrame: 275 | if not (self._check_dataset_valid(dataset_name) and self._check_metric_valid(metric)): 276 | return None 277 | 278 | if eval_names: 279 | for eval_name in eval_names: 280 | if not self._check_evaluation_valid(dataset_name, eval_name): 281 | return None 282 | else: 283 | eval_names = self.query_available_evaluations(dataset_name)["Evaluation"].values 284 | dataframes = [] 285 | for eval_name in eval_names: 286 | dataframes.append(self.query_filter_performance(dataset_name, filter, metric, eval_name)) 287 | df = pd.concat(dataframes, ignore_index=True).sort_values(by=[metric.upper(), "Evaluation"], ascending=False, ignore_index=True) 288 | df["Rank"] = df[f"{metric.upper()}"].rank(axis=0, method="dense", ascending=False) 289 | return df 290 | 291 | def query_scenario_performance(self, dataset_name: str, scenario: Scenario, metric: str, eval_name: str) -> DataFrame: 292 | if not (self._check_dataset_valid(dataset_name) and self._check_evaluation_valid(dataset_name, eval_name) and self._check_metric_valid(metric)): 293 | return None 294 | 295 | if metric == "qvt": 296 | logger.warning("QVT metric only supports overall performance.") 297 | return None 298 | 299 | statement = QUERY_SUBSET_PERFORMANCE.format( 300 | DATASET_NAME=dataset_name, 301 | EVAL_NAME=eval_name, 302 | METRIC_COL=METRIC_COL_MAPPING[metric], 303 | WHERE_CONDITION=serialize_scenario(scenario) 304 | ) 305 | with self.engine.connect() as connection: 306 | result = connection.execute(text(statement)) 307 | connection.commit() 308 | res = result.first() 309 | if res: 310 | return 
DataFrame(data={"Evaluation": eval_name, "Subset": scenario.name, metric.upper(): res}).round(decimals=2) 311 | else: 312 | logger.warning("Query returned an empty result.") 313 | return DataFrame(data={"Evaluation": [eval_name], "Subset": [scenario.name], metric.upper(): [pd.NA]}) 314 | 315 | def query_scenario_leaderboard(self, dataset_name: str, scenario: Scenario, metric: str, eval_names: List[str] = None) -> DataFrame: 316 | if not (self._check_dataset_valid(dataset_name) and self._check_metric_valid(metric)): 317 | return None 318 | 319 | if eval_names: 320 | for eval_name in eval_names: 321 | if not self._check_evaluation_valid(dataset_name, eval_name): 322 | return None 323 | else: 324 | eval_names = self.query_available_evaluations(dataset_name)["Evaluation"].values 325 | dataframes = [] 326 | for eval_name in eval_names: 327 | dataframes.append(self.query_scenario_performance(dataset_name, scenario, metric, eval_name)) 328 | df = pd.concat(dataframes, ignore_index=True).sort_values(by=[metric.upper(), "Evaluation"], ascending=False, ignore_index=True) 329 | df["Rank"] = df[f"{metric.upper()}"].rank(axis=0, method="dense", ascending=False) 330 | return df 331 | 332 | def query_dataset_sql_distribution(self, dataset_name: str) -> DataFrame: 333 | if not self._check_dataset_valid(dataset_name): 334 | return None 335 | else: 336 | statement = QUERY_DATASET_SIZE.format(DATASET_NAME=dataset_name) 337 | with self.engine.connect() as connection: 338 | result = connection.execute(text(statement)) 339 | connection.commit() 340 | total_count, unique_sqls_count = result.first() 341 | 342 | statement = QUERY_DATASET_SQL_KEYWORDS_DISTRIBUTION.format(DATASET_NAME=dataset_name) 343 | with self.engine.connect() as connection: 344 | result = connection.execute(text(statement)) 345 | connection.commit() 346 | (avg_count_query_fields, 347 | avg_count_group_by, 348 | avg_count_order_by, 349 | avg_count_limit, 350 | avg_count_join, 351 | avg_count_predicate, 352 | avg_count_aggregation, 353 | avg_count_scalar_function, 354 | avg_count_subquery, 355 | avg_count_set_operation, 356 | avg_count_math_compute, 357 | avg_count_logical_connector, 358 | avg_count_distinct, 359 | avg_count_like, 360 | avg_count_control_flow, 361 | avg_count_window) = result.first() 362 | 363 | df = DataFrame(data=[ 364 | {"Metric": "Total Count", "Value": total_count}, 365 | {"Metric": "Unique SQL Count", "Value": unique_sqls_count}, 366 | {"Metric": "[QUERY FIELDS] / SQL", "Value": avg_count_query_fields}, 367 | {"Metric": "[GROUP BY] / SQL", "Value": avg_count_group_by}, 368 | {"Metric": "[ORDER BY] / SQL", "Value": avg_count_order_by}, 369 | {"Metric": "[LIMIT] / SQL", "Value": avg_count_limit}, 370 | {"Metric": "[JOIN] / SQL", "Value": avg_count_join}, 371 | {"Metric": "[PREDICATE] / SQL", "Value": avg_count_predicate}, 372 | {"Metric": "[AGGREGATION] / SQL", "Value": avg_count_aggregation}, 373 | {"Metric": "[SCALAR FUNCTION] / SQL", "Value": avg_count_scalar_function}, 374 | {"Metric": "[SUBQUERY] / SQL", "Value": avg_count_subquery}, 375 | {"Metric": "[SET OPERATION] / SQL", "Value": avg_count_set_operation}, 376 | {"Metric": "[MATH COMPUTE] / SQL", "Value": avg_count_math_compute}, 377 | {"Metric": "[LOGICAL CONNECTOR] / SQL", "Value": avg_count_logical_connector}, 378 | {"Metric": "[DISTINCT] / SQL", "Value": avg_count_distinct}, 379 | {"Metric": "[LIKE] / SQL", "Value": avg_count_like}, 380 | {"Metric": "[CONTROL FLOW] / SQL", "Value": avg_count_control_flow}, 381 | {"Metric": "[WINDOW] / SQL", "Value": avg_count_window}, 382 | ]).round(decimals=2)
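        # [Editor's note] Illustrative usage of the query API above -- a minimal
        # sketch, not part of the original file. It assumes `CoreArguments` accepts
        # `core_dir` / `core_name`, and that a dataset named "spider" plus an
        # evaluation named "gpt4" were imported earlier; those names are hypothetical.
        #
        #   core = Core(CoreArguments(core_dir="nl2sql360_core", core_name="demo"))
        #   core.query_overall_performance("spider", metric="ex", eval_name="gpt4")
        #   core.query_overall_leaderboard("spider", metric="ex")
        #   core.query_dataset_sql_distribution("spider")  # one row per "[KEYWORD] / SQL" metric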
383 | 384 | return df 385 | 386 | def query_dataset_domain_distribution(self, dataset_name: str) -> DataFrame: 387 | if not self._check_dataset_valid(dataset_name): 388 | return None 389 | else: 390 | statement = QUERY_DATASET_DOMAIN_DISTRIBUTION.format(DATASET_NAME=dataset_name) 391 | with self.engine.connect() as connection: 392 | result = connection.execute(text(statement)) 393 | connection.commit() 394 | db_domain_count = [] 395 | for res in result: 396 | db_domain_count.append({"DB Domain": res[0], "Count": res[1]}) 397 | df = DataFrame(data=db_domain_count) 398 | return df 399 | 400 | def generate_evaluation_report(self, dataset_name: str, filters: List[Filter], scenarios: List[Scenario], metrics: List[str], eval_names: List[str] = None) -> DataFrame: 401 | if not self._check_dataset_valid(dataset_name): 402 | return None 403 | for metric in metrics: 404 | if not self._check_metric_valid(metric): 405 | return None 406 | 407 | if eval_names: 408 | for eval_name in eval_names: 409 | if not self._check_evaluation_valid(dataset_name, eval_name): 410 | return None 411 | else: 412 | eval_names = self.query_available_evaluations(dataset_name)["Evaluation"].values 413 | 414 | results = [] 415 | 416 | for eval_name in eval_names: 417 | 418 | # Overall performance 419 | 420 | data = {"Subset": "Overall"} 421 | for metric in metrics: 422 | df = self.query_overall_performance(dataset_name=dataset_name, metric=metric, eval_name=eval_name) 423 | data.update(df.to_dict()) 424 | results.append(DataFrame(data)) 425 | 426 | # Filter Performance 427 | 428 | for filter in filters: 429 | data = dict() 430 | for metric in metrics: 431 | filter_df = self.query_filter_performance(dataset_name=dataset_name, filter=filter, metric=metric, eval_name=eval_name) 432 | data.update(filter_df.to_dict()) 433 | results.append(DataFrame(data)) 434 | 435 | # Scenario Performance 436 | 437 | for scenario in scenarios: 438 | data = dict() 439 | for metric in metrics: 440 | scenario_df = self.query_scenario_performance(dataset_name=dataset_name, scenario=scenario, metric=metric, eval_name=eval_name) 441 | data.update(scenario_df.to_dict()) 442 | results.append(DataFrame(data)) 443 | 444 | df = pd.concat(results, ignore_index=True).sort_values(by=["Subset", "Evaluation"], ignore_index=True) 445 | return df 446 | 447 | def delete_dataset_history(self, dataset_name: str, delete_relevant_evaluations=True) -> None: 448 | logger.warning( 449 | "You are deleting the dataset history. Please enter `Y` / `YES` to confirm or enter `N` / `NO` to cancel the operation. " 450 | ) 451 | flag = input("Input your choice:\n").strip().upper() 452 | while flag not in ["Y", "YES", "N", "NO"]: 453 | logger.warning( 454 | "You are deleting the dataset history. Please enter `Y` / `YES` to confirm or enter `N` / `NO` to cancel the operation. 
" 455 | ) 456 | flag = input("Input your choice:\n").strip().upper() 457 | 458 | if flag in ["N", "NO"]: 459 | return 460 | 461 | if flag in ["Y", "YES"]: 462 | statements = [DELETE_DATASET_TABLE.format(DATASET_NAME=dataset_name), DELETE_DATASET_INFO.format(DATASET_NAME=dataset_name)] 463 | if delete_relavant_evaluations: 464 | for eval_name in self.query_available_evaluations(dataset_name)["Evaluation"].values: 465 | statements.append(DELETE_EVALUATION_TABLE.format(DATASET_NAME=dataset_name, EVAL_NAME=eval_name)) 466 | 467 | with self.engine.connect() as connection: 468 | for stat in statements: 469 | connection.execute(text(stat)) 470 | connection.commit() 471 | logger.success(f"Delete dataset `{dataset_name}` successfully.") 472 | return 473 | 474 | def delete_evaluation_history(self, dataset_name: str, eval_name: str) -> None: 475 | logger.warning( 476 | "You are deleting the evaluation history. Please enter `Y` / `YES` to confirm or enter `N` / `NO` to cancel the operation. " 477 | ) 478 | flag = input("Input your choice:\n").strip().upper() 479 | while flag not in ["Y", 'YES', "N", "NO"]: 480 | logger.warning( 481 | "You are deleting the dataset history. Please enter `Y` / `YES` to confirm or enter `N` / `NO` to cancel the operation. " 482 | ) 483 | flag = input("Input your choice:\n").strip().upper() 484 | 485 | if flag in ["N", "NO"]: 486 | return 487 | 488 | if flag in ["Y", "YES"]: 489 | statement = DELETE_EVALUATION_TABLE.format(DATASET_NAME=dataset_name, EVAL_NAME=eval_name) 490 | with self.engine.connect() as connection: 491 | connection.execute(text(statement)) 492 | connection.commit() 493 | logger.success(f"Delete evaluation `{eval_name}` for dataset `{dataset_name}` successfully.") 494 | return -------------------------------------------------------------------------------- /src/nl2sql360/database/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Base, DatasetInfo, MetaDataset, MetaEvaluation, get_dataset_model, get_evaluation_model 2 | from .util import (get_dataset_name_from_table_name, 3 | get_dataset_name_and_evaluation_name_from_table_name, 4 | get_dataset_info, 5 | get_dataset_samples) 6 | from .template import (METRIC_COL_MAPPING, 7 | QUERY_OVERALL_PERFORMANCE, 8 | QUERY_QVT_PERFORMANCE, 9 | QUERY_SUBSET_PERFORMANCE, 10 | QUERY_DATASET_SIZE, 11 | QUERY_DATASET_DOMAIN_DISTRIBUTION, 12 | QUERY_DATASET_SQL_KEYWORDS_DISTRIBUTION, 13 | DELETE_DATASET_TABLE, 14 | DELETE_EVALUATION_TABLE, 15 | DELETE_DATASET_INFO) 16 | 17 | 18 | __all__ = [ 19 | "Base", 20 | "DatasetInfo", 21 | "MetaDataset", 22 | "MetaEvaluation", 23 | "get_dataset_model", 24 | "get_evaluation_model", 25 | "get_dataset_name_from_table_name", 26 | "get_dataset_name_and_evaluation_name_from_table_name", 27 | "get_dataset_info", 28 | "get_dataset_samples", 29 | "METRIC_COL_MAPPING", 30 | "QUERY_OVERALL_PERFORMANCE", 31 | "QUERY_QVT_PERFORMANCE", 32 | "QUERY_SUBSET_PERFORMANCE", 33 | "QUERY_DATASET_SIZE", 34 | "QUERY_DATASET_DOMAIN_DISTRIBUTION", 35 | "QUERY_DATASET_SQL_KEYWORDS_DISTRIBUTION", 36 | "DELETE_DATASET_TABLE", 37 | "DELETE_EVALUATION_TABLE", 38 | "DELETE_DATASET_INFO" 39 | ] -------------------------------------------------------------------------------- /src/nl2sql360/database/model.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, Integer, String, Float, ForeignKey 2 | from sqlalchemy.orm import DeclarativeBase 3 | 4 | 5 | class Base(DeclarativeBase): 6 | pass 
7 | 8 | 9 | class DatasetInfo(Base): 10 | __tablename__ = "__DATASET_INFO__" 11 | 12 | dataset_name = Column(String, primary_key=True) 13 | database_dir_path = Column(String, nullable=False) 14 | tables_json_path = Column(String, nullable=True, default=None) 15 | 16 | 17 | class MetaDataset: 18 | 19 | id = Column(Integer, primary_key=True) 20 | nlq = Column(String, nullable=False) 21 | gold = Column(String, nullable=False) 22 | db_id = Column(String, nullable=False) 23 | 24 | """Note: 25 | BIRD complexity: ["simple", "moderate", "challenging"] 26 | Spider complexity: ["easy", "medium", "hard", "extra"] 27 | """ 28 | complexity = Column(String, nullable=False) 29 | db_domain = Column(String, nullable=False) 30 | 31 | count_query_fields = Column(Integer, nullable=False) 32 | count_group_by = Column(Integer, nullable=False) 33 | count_order_by = Column(Integer, nullable=False) 34 | count_limit = Column(Integer, nullable=False) 35 | count_join = Column(Integer, nullable=False) 36 | count_predicate = Column(Integer, nullable=False) 37 | count_aggregation = Column(Integer, nullable=False) 38 | count_scalar_function = Column(Integer, nullable=False) 39 | count_subquery = Column(Integer, nullable=False) 40 | count_set_operation = Column(Integer, nullable=False) 41 | count_math_compute = Column(Integer, nullable=False) 42 | count_logical_connector = Column(Integer, nullable=False) 43 | count_distinct = Column(Integer, nullable=False) 44 | count_like = Column(Integer, nullable=False) 45 | count_control_flow = Column(Integer, nullable=False) 46 | count_window = Column(Integer, nullable=False) 47 | 48 | 49 | class MetaEvaluation: 50 | 51 | pred = Column(String, nullable=False) 52 | exec_acc = Column(Float, nullable=True, default=None) 53 | exact_acc = Column(Float, nullable=True, default=None) 54 | ves = Column(Float, nullable=True, default=None) 55 | 56 | 57 | def get_dataset_model(dataset_name): 58 | return type(f"DATASET_{dataset_name}", 59 | (MetaDataset, Base), 60 | dict(__tablename__=f"DATASET_{dataset_name}")) 61 | 62 | 63 | def get_evaluation_model(dataset_name, evaluation_name): 64 | return type(f"DATASET_{dataset_name}_EVALUATION_{evaluation_name}", 65 | (MetaEvaluation, Base), 66 | dict(id=Column(Integer, ForeignKey(f"DATASET_{dataset_name}.id"), primary_key=True), 67 | __tablename__=f"DATASET_{dataset_name}_EVALUATION_{evaluation_name}")) 68 | -------------------------------------------------------------------------------- /src/nl2sql360/database/template.py: -------------------------------------------------------------------------------- 1 | 2 | METRIC_COL_MAPPING = { 3 | "ex": "exec_acc", 4 | "em": "exact_acc", 5 | "ves": "ves", 6 | "qvt": None 7 | } 8 | 9 | 10 | QUERY_OVERALL_PERFORMANCE = \ 11 | """ 12 | SELECT AVG({METRIC_COL}) * 100 from DATASET_{DATASET_NAME}_EVALUATION_{EVAL_NAME} AS e JOIN DATASET_{DATASET_NAME} AS d ON e.id = d.id; 13 | """ 14 | 15 | 16 | QUERY_SUBSET_PERFORMANCE = \ 17 | """ 18 | SELECT AVG({METRIC_COL}) * 100 from DATASET_{DATASET_NAME}_EVALUATION_{EVAL_NAME} AS e JOIN DATASET_{DATASET_NAME} AS d ON e.id = d.id WHERE {WHERE_CONDITION}; 19 | """ 20 | 21 | 22 | QUERY_QVT_PERFORMANCE = \ 23 | """ 24 | SELECT AVG(exec_acc) * 100 FROM ( 25 | SELECT AVG(exec_acc) as exec_acc FROM DATASET_{DATASET_NAME}_EVALUATION_{EVAL_NAME} AS e JOIN DATASET_{DATASET_NAME} AS d ON e.id = d.id GROUP BY gold HAVING COUNT(d.gold) >= 2 and sum(e.exec_acc) != 0 26 | ); 27 | """ 28 | 29 | 30 | QUERY_DATASET_SIZE = \ 31 | """ 32 | SELECT COUNT(*), COUNT(DISTINCT gold) FROM DATASET_{DATASET_NAME}; 
33 | """ 34 | 35 | 36 | QUERY_DATASET_SQL_KEYWORDS_DISTRIBUTION = \ 37 | """ 38 | SELECT 39 | AVG(count_query_fields), 40 | AVG(count_group_by), 41 | AVG(count_order_by), 42 | AVG(count_limit), 43 | AVG(count_join), 44 | AVG(count_predicate), 45 | AVG(count_aggregation), 46 | AVG(count_scalar_function), 47 | AVG(count_subquery), 48 | AVG(count_set_operation), 49 | AVG(count_math_compute), 50 | AVG(count_logical_connector), 51 | AVG(count_distinct), 52 | AVG(count_like), 53 | AVG(count_control_flow), 54 | AVG(count_window) 55 | FROM DATASET_{DATASET_NAME}; 56 | """ 57 | 58 | 59 | QUERY_DATASET_DOMAIN_DISTRIBUTION = \ 60 | """ 61 | SELECT db_domain, COUNT(*) FROM DATASET_{DATASET_NAME} GROUP BY db_domain ORDER BY db_domain; 62 | """ 63 | 64 | 65 | DELETE_DATASET_TABLE = \ 66 | """ 67 | DROP TABLE IF EXISTS DATASET_{DATASET_NAME}; 68 | """ 69 | 70 | DELETE_DATASET_INFO = \ 71 | """ 72 | DELETE FROM __DATASET_INFO__ WHERE dataset_name = '{DATASET_NAME}'; 73 | """ 74 | 75 | 76 | DELETE_EVALUATION_TABLE = \ 77 | """ 78 | DROP TABLE IF EXISTS DATASET_{DATASET_NAME}_EVALUATION_{EVAL_NAME}; 79 | """ 80 | -------------------------------------------------------------------------------- /src/nl2sql360/database/util.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Dict, Any, List, Tuple 2 | from sqlalchemy import Engine 3 | from sqlalchemy.orm import Session 4 | from .model import DatasetInfo, MetaDataset, get_dataset_model 5 | 6 | 7 | 8 | def get_dataset_name_from_table_name(dataset_table_name: str) -> str: 9 | return dataset_table_name.split("DATASET_")[1] 10 | 11 | 12 | def get_dataset_name_and_evaluation_name_from_table_name(evaluation_table_name: str) -> Tuple[str, str]: 13 | splits = evaluation_table_name.split("_EVALUATION_") 14 | dataset_name = splits[0].split("DATASET_")[-1] 15 | evaluation_name = splits[1] 16 | return dataset_name, evaluation_name 17 | 18 | 19 | def get_dataset_info(db_engine: "Engine", dataset_name: str) -> Optional[DatasetInfo]: 20 | with Session(db_engine) as session: 21 | query_res = session.query(DatasetInfo).filter(DatasetInfo.dataset_name == dataset_name).all() 22 | if query_res: 23 | assert len(query_res) == 1 24 | return query_res[0] 25 | else: 26 | return None 27 | 28 | 29 | def get_dataset_samples(db_engine: "Engine", dataset_model: "MetaDataset") -> List[Dict[str, Any]]: 30 | with Session(db_engine) as session: 31 | query_res = session.query(dataset_model).order_by(dataset_model.id) 32 | 33 | return [{ 34 | "nlq": record.nlq, 35 | "gold": record.gold, 36 | "db_id": record.db_id, 37 | } for record in query_res] 38 | 39 | -------------------------------------------------------------------------------- /src/nl2sql360/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import NL2SQLDataset 2 | 3 | 4 | __all__ = [ 5 | "NL2SQLDataset" 6 | ] -------------------------------------------------------------------------------- /src/nl2sql360/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | import os 3 | from ..arguments import DatasetArguments 4 | import json 5 | 6 | 7 | class NL2SQLDataset: 8 | 9 | def __init__(self, dataset_args: "DatasetArguments") -> None: 10 | self.dataset_args = dataset_args 11 | 12 | def get_all_samples(self): 13 | with open(os.path.join(self.dataset_args.dataset_dir, self.dataset_args.samples_file), "r", encoding="utf-8") as f: 14 | samples = json.load(f) 15 | 
return samples 16 | 17 | def get_all_questions(self): 18 | return [item[self.dataset_args.question_key] for item in self.get_all_samples()] 19 | 20 | def get_all_sqls(self): 21 | return [item[self.dataset_args.sql_key] for item in self.get_all_samples()] 22 | 23 | def get_all_db_ids(self): 24 | return [item[self.dataset_args.db_id_key] for item in self.get_all_samples()] 25 | 26 | def get_all_sql_complexity(self): 27 | if self.dataset_args.sql_complexity_key is not None: 28 | return [item[self.dataset_args.sql_complexity_key] for item in self.get_all_samples()] 29 | else: 30 | return ["" for _ in self.get_all_samples()] 31 | 32 | def get_all_database_domains(self): 33 | if self.dataset_args.database_domain_file is not None: 34 | with open(os.path.join(self.dataset_args.dataset_dir, self.dataset_args.database_domain_file), "r", encoding="utf-8") as f: 35 | domain_mapping = json.load(f) 36 | else: 37 | domain_mapping = dict() 38 | return [domain_mapping.get(item[self.dataset_args.db_id_key], "") for item in self.get_all_samples()] 39 | 40 | def get_all_database_paths(self): 41 | return [os.path.join(self.dataset_args.dataset_dir, 42 | self.dataset_args.database_dir, 43 | item[self.dataset_args.db_id_key], 44 | f"{item[self.dataset_args.db_id_key]}.sqlite" 45 | ) for item in self.get_all_samples()] 46 | 47 | def get_tables(self): 48 | if self.dataset_args.tables_file is None: 49 | raise ValueError("`tables_file` does not exist or is not a valid json file.") 50 | else: 51 | with open(os.path.join(self.dataset_args.dataset_dir, self.dataset_args.tables_file), "r", encoding="utf-8") as f: 52 | return json.load(f) 53 | 54 | def __len__(self): 55 | return len(self.get_all_samples()) 56 | 57 | -------------------------------------------------------------------------------- /src/nl2sql360/evaluator/__init__.py: -------------------------------------------------------------------------------- 1 | from .bird_accuracy import BirdAccraucyEvaluator 2 | from .spider_accuracy import SpiderAccraucyEvaluator 3 | from .ves import VesEvaluator 4 | 5 | 6 | __all__ = [ 7 | "BirdAccraucyEvaluator", 8 | "SpiderAccraucyEvaluator", 9 | "VesEvaluator" 10 | ] 11 | -------------------------------------------------------------------------------- /src/nl2sql360/evaluator/bird_accuracy.py: -------------------------------------------------------------------------------- 1 | from .bird_eval.bird_accuracy import run_sqls_parallel, sort_results 2 | import os 3 | 4 | 5 | class BirdAccraucyEvaluator: 6 | 7 | def evaluate(self, gold_sqls, pred_sqls, db_ids, db_dir, **kwds): 8 | query_pairs = list(zip(pred_sqls, gold_sqls)) 9 | db_places = [os.path.join(db_dir, db_id, f"{db_id}.sqlite") for db_id in db_ids] 10 | exec_result = run_sqls_parallel( 11 | sqls=query_pairs, 12 | db_places=db_places, 13 | num_cpus=kwds.get("num_processes", 8), 14 | meta_time_out=kwds.get("timeout", 30) 15 | ) 16 | exec_result = sort_results(exec_result) 17 | exec_result = [res['res'] for res in exec_result] 18 | return { 19 | "exec_acc": exec_result 20 | } 21 | 22 | def get_eval_metrics(self): 23 | return ["exec_acc"] -------------------------------------------------------------------------------- /src/nl2sql360/evaluator/bird_eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BugMaker-Boyan/NL2SQL360/59407da3c84a8b64b61d2e44e23d498146230cb5/src/nl2sql360/evaluator/bird_eval/__init__.py 
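The `BirdAccraucyEvaluator` wrapper above reduces the vendored BIRD scripts that follow to a single call. A minimal usage sketch, assuming `db_dir` contains one `<db_id>/<db_id>.sqlite` per database; the path, SQL strings, and db_id below are hypothetical:

```
from nl2sql360.evaluator import BirdAccraucyEvaluator

evaluator = BirdAccraucyEvaluator()
result = evaluator.evaluate(
    gold_sqls=["SELECT name FROM singer"],
    pred_sqls=["SELECT name FROM singer"],
    db_ids=["concert_singer"],
    db_dir="data/spider/database",
    num_processes=4,  # forwarded via **kwds to run_sqls_parallel
    timeout=30,       # per-query timeout in seconds
)
print(result["exec_acc"])  # per-sample 0/1 list, e.g. [1]
```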
-------------------------------------------------------------------------------- /src/nl2sql360/evaluator/bird_eval/bird_accuracy.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | import argparse 4 | import sqlite3 5 | import multiprocessing as mp 6 | from func_timeout import func_timeout, FunctionTimedOut 7 | from tqdm import tqdm 8 | 9 | exec_result = [] 10 | progress_bar = None 11 | 12 | def load_json(dir): 13 | with open(dir, 'r') as j: 14 | contents = json.loads(j.read()) 15 | return contents 16 | 17 | 18 | def result_callback(result): 19 | exec_result.append(result) 20 | progress_bar.update() 21 | 22 | 23 | def execute_sql(predicted_sql,ground_truth, db_path): 24 | conn = sqlite3.connect(db_path) 25 | # Connect to the database 26 | cursor = conn.cursor() 27 | cursor.execute(predicted_sql) 28 | predicted_res = cursor.fetchall() 29 | cursor.execute(ground_truth) 30 | ground_truth_res = cursor.fetchall() 31 | res = 0 32 | if set(predicted_res) == set(ground_truth_res): 33 | res = 1 34 | return res 35 | 36 | 37 | 38 | def execute_model(predicted_sql,ground_truth, db_place, idx, meta_time_out): 39 | try: 40 | res = func_timeout(meta_time_out, execute_sql, 41 | args=(predicted_sql, ground_truth, db_place)) 42 | except KeyboardInterrupt: 43 | sys.exit(0) 44 | except FunctionTimedOut: 45 | result = [(f'timeout',)] 46 | res = 0 47 | except Exception as e: 48 | result = [(f'error',)] # possibly len(query) > 512 or not executable 49 | res = 0 50 | # print(result) 51 | # result = str(set([ret[0] for ret in result])) 52 | result = {'sql_idx': idx, 'res': res} 53 | # print(result) 54 | return result 55 | 56 | 57 | def package_sqls(sql_path, db_root_path, mode='gpt', data_mode='dev'): 58 | clean_sqls = [] 59 | db_path_list = [] 60 | if mode == 'gpt': 61 | sql_data = json.load(open(sql_path + 'predict_' + data_mode + '.json', 'r')) 62 | for idx, sql_str in sql_data.items(): 63 | if type(sql_str) == str: 64 | sql, db_name = sql_str.split('\t----- bird -----\t') 65 | else: 66 | sql, db_name = " ", "financial" 67 | clean_sqls.append(sql) 68 | db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite') 69 | 70 | elif mode == 'gt': 71 | sqls = open(sql_path + data_mode + '_gold.sql') 72 | sql_txt = sqls.readlines() 73 | # sql_txt = [sql.split('\t')[0] for sql in sql_txt] 74 | for idx, sql_str in enumerate(sql_txt): 75 | sql, db_name = sql_str.strip().split('\t') 76 | clean_sqls.append(sql) 77 | db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite') 78 | 79 | return clean_sqls, db_path_list 80 | 81 | def run_sqls_parallel(sqls, db_places, num_cpus=1, meta_time_out=30.0): 82 | global exec_result, progress_bar 83 | exec_result.clear() 84 | progress_bar = tqdm(total=len(sqls)) 85 | pool = mp.Pool(processes=num_cpus) 86 | for i,sql_pair in enumerate(sqls): 87 | predicted_sql, ground_truth = sql_pair 88 | pool.apply_async(execute_model, args=(predicted_sql, ground_truth, db_places[i], i, meta_time_out), callback=result_callback) 89 | pool.close() 90 | pool.join() 91 | return exec_result 92 | 93 | def sort_results(list_of_dicts): 94 | return sorted(list_of_dicts, key=lambda x: x['sql_idx']) 95 | 96 | def compute_acc_by_diff(exec_results,diff_json_path): 97 | num_queries = len(exec_results) 98 | results = [res['res'] for res in exec_results] 99 | contents = load_json(diff_json_path) 100 | simple_results, moderate_results, challenging_results = [], [], [] 101 | 102 | for i,content in enumerate(contents): 103 | if 
content['difficulty'] == 'simple': 104 | simple_results.append(exec_results[i]) 105 | 106 | if content['difficulty'] == 'moderate': 107 | moderate_results.append(exec_results[i]) 108 | 109 | if content['difficulty'] == 'challenging': 110 | challenging_results.append(exec_results[i]) 111 | 112 | simple_acc = sum([res['res'] for res in simple_results])/len(simple_results) 113 | moderate_acc = sum([res['res'] for res in moderate_results])/len(moderate_results) 114 | challenging_acc = sum([res['res'] for res in challenging_results])/len(challenging_results) 115 | all_acc = sum(results)/num_queries 116 | count_lists = [len(simple_results), len(moderate_results), len(challenging_results), num_queries] 117 | return simple_acc * 100, moderate_acc * 100, challenging_acc * 100, all_acc * 100, count_lists 118 | 119 | 120 | 121 | def print_data(score_lists,count_lists): 122 | levels = ['simple', 'moderate', 'challenging', 'total'] 123 | print("{:20} {:20} {:20} {:20} {:20}".format("", *levels)) 124 | print("{:20} {:<20} {:<20} {:<20} {:<20}".format('count', *count_lists)) 125 | 126 | print('====================================== ACCURACY =====================================') 127 | print("{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format('accuracy', *score_lists)) 128 | 129 | 130 | if __name__ == '__main__': 131 | args_parser = argparse.ArgumentParser() 132 | args_parser.add_argument('--predicted_sql_path', type=str, required=True, default='') 133 | args_parser.add_argument('--ground_truth_path', type=str, required=True, default='') 134 | args_parser.add_argument('--data_mode', type=str, required=True, default='dev') 135 | args_parser.add_argument('--db_root_path', type=str, required=True, default='') 136 | args_parser.add_argument('--num_cpus', type=int, default=1) 137 | args_parser.add_argument('--meta_time_out', type=float, default=30.0) 138 | args_parser.add_argument('--mode_gt', type=str, default='gt') 139 | args_parser.add_argument('--mode_predict', type=str, default='gpt') 140 | args_parser.add_argument('--difficulty',type=str,default='simple') 141 | args_parser.add_argument('--diff_json_path',type=str,default='') 142 | args = args_parser.parse_args() 143 | exec_result = [] 144 | 145 | pred_queries, db_paths = package_sqls(args.predicted_sql_path, args.db_root_path, mode=args.mode_predict, 146 | data_mode=args.data_mode) 147 | # generate gt sqls: 148 | gt_queries, db_paths_gt = package_sqls(args.ground_truth_path, args.db_root_path, mode='gt', 149 | data_mode=args.data_mode) 150 | 151 | query_pairs = list(zip(pred_queries,gt_queries)) 152 | run_sqls_parallel(query_pairs, db_places=db_paths, num_cpus=args.num_cpus, meta_time_out=args.meta_time_out) 153 | exec_result = sort_results(exec_result) 154 | 155 | print('start calculate') 156 | simple_acc, moderate_acc, challenging_acc, acc, count_lists = \ 157 | compute_acc_by_diff(exec_result,args.diff_json_path) 158 | score_lists = [simple_acc, moderate_acc, challenging_acc, acc] 159 | print_data(score_lists,count_lists) 160 | print('===========================================================================================') 161 | print("Finished evaluation") 162 | -------------------------------------------------------------------------------- /src/nl2sql360/evaluator/bird_eval/bird_ves.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | import numpy as np 4 | import argparse 5 | import sqlite3 6 | import multiprocessing as mp 7 | from func_timeout import func_timeout, 
FunctionTimedOut 8 | import time 9 | import math 10 | from tqdm import tqdm 11 | 12 | exec_result = [] 13 | progress_bar = None 14 | 15 | def result_callback(result): 16 | exec_result.append(result) 17 | progress_bar.update() 18 | 19 | 20 | def clean_abnormal(input): 21 | input = np.asarray(input) 22 | processed_list = [] 23 | mean = np.mean(input,axis=0) 24 | std = np.std(input,axis=0) 25 | for x in input: 26 | if x < mean + 3 * std and x > mean - 3 * std: 27 | processed_list.append(x) 28 | return processed_list 29 | 30 | def execute_sql(sql, db_path): 31 | # Connect to the database 32 | conn = sqlite3.connect(db_path) 33 | # Create a cursor object 34 | cursor = conn.cursor() 35 | start_time = time.perf_counter() 36 | cursor.execute(sql) 37 | exec_time = time.perf_counter() - start_time 38 | return exec_time if exec_time != 0 else 1e-9 39 | 40 | def iterated_execute_sql(predicted_sql,ground_truth,db_path,iterate_num, exec_acc): 41 | conn = sqlite3.connect(db_path) 42 | diff_list = [] 43 | cursor = conn.cursor() 44 | cursor.execute(predicted_sql) 45 | predicted_res = cursor.fetchall() 46 | cursor.execute(ground_truth) 47 | ground_truth_res = cursor.fetchall() 48 | time_ratio = 0 49 | if (exec_acc is None and set(predicted_res) == set(ground_truth_res)) or (exec_acc is not None and exec_acc == 1): 50 | for i in range(iterate_num): 51 | predicted_time = execute_sql(predicted_sql, db_path) 52 | ground_truth_time = execute_sql(ground_truth, db_path) 53 | diff_list.append(ground_truth_time / predicted_time) 54 | processed_diff_list = clean_abnormal(diff_list) 55 | time_ratio = sum(processed_diff_list) / len(processed_diff_list) 56 | return time_ratio 57 | 58 | 59 | 60 | def execute_model(predicted_sql,ground_truth, db_place, idx, iterate_num, meta_time_out, exec_acc): 61 | try: 62 | # you can personalize the total timeout number 63 | # larger timeout leads to more stable ves 64 | # while it needs more your patience.... 
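        # [Editor's note] How the numbers below combine, for reference: `iterated_execute_sql`
        # returns the mean of gold_time / pred_time over `iterate_num` runs (values outside
        # mean +/- 3 std are dropped by `clean_abnormal`), and `compute_ves` later averages
        # sqrt(time_ratio) * 100 over all queries -- so a VES above 100 means the predicted
        # SQL ran faster than the gold SQL on average.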
65 | time_ratio = func_timeout(meta_time_out * iterate_num, iterated_execute_sql, 66 | args=(predicted_sql, ground_truth, db_place, iterate_num, exec_acc)) 67 | # print([idx, math.sqrt(time_ratio)]) 68 | except KeyboardInterrupt as e: 69 | # print(e) 70 | sys.exit(0) 71 | except FunctionTimedOut as e: 72 | print(e) 73 | result = [(f'timeout',)] 74 | time_ratio = 0 75 | except Exception as e: 76 | # print(e) 77 | result = [(f'error',)] # possibly len(query) > 512 or not executable 78 | time_ratio = 0 79 | result = {'sql_idx': idx, 'time_ratio': time_ratio} 80 | return result 81 | 82 | 83 | def package_sqls(sql_path, db_root_path, mode='gpt', data_mode='dev'): 84 | clean_sqls = [] 85 | db_path_list = [] 86 | if mode == 'gpt': 87 | sql_data = json.load(open(sql_path + 'predict_' + data_mode + '.json', 'r')) 88 | for idx, sql_str in sql_data.items(): 89 | if type(sql_str) == str: 90 | sql, db_name = sql_str.split('\t----- bird -----\t') 91 | else: 92 | sql, db_name = " ", "financial" 93 | clean_sqls.append(sql) 94 | db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite') 95 | 96 | elif mode == 'gt': 97 | sqls = open(sql_path + data_mode + '_gold.sql') 98 | sql_txt = sqls.readlines() 99 | for idx, sql_str in enumerate(sql_txt): 100 | sql, db_name = sql_str.strip().split('\t') 101 | clean_sqls.append(sql) 102 | db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite') 103 | 104 | return clean_sqls, db_path_list 105 | 106 | def run_sqls_parallel(sqls, db_places, num_cpus=1, iterate_num=100, meta_time_out=30.0, exec_acc_list=None): 107 | global exec_result, progress_bar 108 | exec_result.clear() 109 | progress_bar = tqdm(total=len(sqls)) 110 | pool = mp.Pool(processes=num_cpus) 111 | for i,sql_pair in enumerate(sqls): 112 | predicted_sql, ground_truth = sql_pair 113 | exec_acc = exec_acc_list[i] if exec_acc_list else None 114 | pool.apply_async(execute_model, args=(predicted_sql, ground_truth, db_places[i], i, iterate_num, meta_time_out, exec_acc), callback=result_callback) 115 | pool.close() 116 | pool.join() 117 | return exec_result 118 | 119 | def sort_results(list_of_dicts): 120 | return sorted(list_of_dicts, key=lambda x: x['sql_idx']) 121 | 122 | def compute_ves(exec_results): 123 | num_queries = len(exec_results) 124 | total_ratio = 0 125 | count = 0 126 | 127 | for i, result in enumerate(exec_results): 128 | if result['time_ratio'] != 0: 129 | count += 1 130 | total_ratio += math.sqrt(result['time_ratio']) * 100 131 | ves = (total_ratio/num_queries) 132 | return ves 133 | 134 | def load_json(dir): 135 | with open(dir, 'r') as j: 136 | contents = json.loads(j.read()) 137 | return contents 138 | 139 | def compute_ves_by_diff(exec_results,diff_json_path): 140 | num_queries = len(exec_results) 141 | contents = load_json(diff_json_path) 142 | simple_results, moderate_results, challenging_results = [], [], [] 143 | for i,content in enumerate(contents): 144 | if content['difficulty'] == 'simple': 145 | simple_results.append(exec_results[i]) 146 | if content['difficulty'] == 'moderate': 147 | moderate_results.append(exec_results[i]) 148 | if content['difficulty'] == 'challenging': 149 | challenging_results.append(exec_results[i]) 150 | simple_ves = compute_ves(simple_results) 151 | moderate_ves = compute_ves(moderate_results) 152 | challenging_ves = compute_ves(challenging_results) 153 | all_ves = compute_ves(exec_results) 154 | count_lists = [len(simple_results), len(moderate_results), len(challenging_results), num_queries] 155 | return simple_ves, moderate_ves, 
challenging_ves, all_ves, count_lists 156 | 157 | def print_data(score_lists,count_lists): 158 | levels = ['simple', 'moderate', 'challenging', 'total'] 159 | print("{:20} {:20} {:20} {:20} {:20}".format("", *levels)) 160 | print("{:20} {:<20} {:<20} {:<20} {:<20}".format('count', *count_lists)) 161 | 162 | print('========================================= VES ========================================') 163 | print("{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format('ves', *score_lists)) 164 | 165 | if __name__ == '__main__': 166 | args_parser = argparse.ArgumentParser() 167 | args_parser.add_argument('--predicted_sql_path', type=str, required=True, default='') 168 | args_parser.add_argument('--ground_truth_path', type=str, required=True, default='') 169 | args_parser.add_argument('--data_mode', type=str, required=True, default='dev') 170 | args_parser.add_argument('--db_root_path', type=str, required=True, default='') 171 | args_parser.add_argument('--num_cpus', type=int, default=1) 172 | args_parser.add_argument('--meta_time_out', type=float, default=30.0) 173 | args_parser.add_argument('--mode_gt', type=str, default='gt') 174 | args_parser.add_argument('--mode_predict', type=str, default='gpt') 175 | args_parser.add_argument('--diff_json_path',type=str,default='') 176 | args = args_parser.parse_args() 177 | exec_result = [] 178 | 179 | pred_queries, db_paths = package_sqls(args.predicted_sql_path, args.db_root_path, mode=args.mode_predict, 180 | data_mode=args.data_mode) 181 | # generate gt sqls: 182 | gt_queries, db_paths_gt = package_sqls(args.ground_truth_path, args.db_root_path, mode='gt', 183 | data_mode=args.data_mode) 184 | 185 | query_pairs = list(zip(pred_queries, gt_queries)) 186 | run_sqls_parallel(query_pairs, db_places=db_paths, num_cpus=args.num_cpus, meta_time_out=args.meta_time_out) 187 | exec_result = sort_results(exec_result) 188 | print('start calculate') 189 | simple_ves, moderate_ves, challenging_ves, ves, count_lists = \ 190 | compute_ves_by_diff(exec_result, args.diff_json_path) 191 | score_lists = [simple_ves, moderate_ves, challenging_ves, ves] 192 | print_data(score_lists, count_lists) 193 | print('===========================================================================================') 194 | print("Finished evaluation") 195 | 196 | -------------------------------------------------------------------------------- /src/nl2sql360/evaluator/spider_accuracy.py: -------------------------------------------------------------------------------- 1 | from .test_suite_sql_eval.evaluation import evaluate, build_foreign_key_map_from_json 2 | from loguru import logger 3 | 4 | 5 | class SpiderAccraucyEvaluator: 6 | 7 | def __init__(self, eval_ex, eval_em): 8 | self.eval_ex = eval_ex 9 | self.eval_em = eval_em 10 | 11 | def get_eval_metrics(self): 12 | eval_metrics = [] 13 | if self.eval_ex: 14 | eval_metrics.append("exec_acc") 15 | if self.eval_em: 16 | eval_metrics.append("exact_acc") 17 | return eval_metrics 18 | 19 | def evaluate(self, gold_sqls, pred_sqls, db_ids, db_dir, **kwds): 20 | if self.eval_ex and self.eval_em: 21 | etype = "all" 22 | elif self.eval_ex: 23 | etype = "exec" 24 | elif self.eval_em: 25 | etype = "match" 26 | else: 27 | logger.warning("Spider evaluator was configured without any evaluation metrics; defaulting to EX.") 28 | etype = "exec" 29 | 30 | kmaps = None 31 | 32 | if etype in ["all", "match"]: 33 | if kwds.get("tables_json_path", None) is None: 34 | logger.warning("`EM` metric evaluation needs the tables JSON path passed in. 
Exclude `EM` by default.") 35 | etype = "exec" 36 | kmaps = None 37 | else: 38 | kmaps = build_foreign_key_map_from_json(kwds.get("tables_json_path")) 39 | 40 | golds = [f"{gold}\t{db_id}" for gold, db_id in zip(gold_sqls, db_ids)] 41 | 42 | entries = evaluate( 43 | golds=golds, 44 | preds=pred_sqls, 45 | db_dir=db_dir, 46 | etype=etype, 47 | kmaps=kmaps, 48 | plug_value=False, 49 | keep_distinct=False, 50 | progress_bar_for_each_datapoint=False 51 | ) 52 | 53 | return { 54 | "exec_acc": [entry.get("exec", None) for entry in entries], 55 | "exact_acc": [entry.get("exact", None) for entry in entries] 56 | } -------------------------------------------------------------------------------- /src/nl2sql360/evaluator/test_suite_sql_eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BugMaker-Boyan/NL2SQL360/59407da3c84a8b64b61d2e44e23d498146230cb5/src/nl2sql360/evaluator/test_suite_sql_eval/__init__.py -------------------------------------------------------------------------------- /src/nl2sql360/evaluator/test_suite_sql_eval/evaluation.py: -------------------------------------------------------------------------------- 1 | ################################ 2 | # val: number(float)/string(str)/sql(dict) 3 | # col_unit: (agg_id, col_id, isDistinct(bool)) 4 | # val_unit: (unit_op, col_unit1, col_unit2) 5 | # table_unit: (table_type, col_unit/sql) 6 | # cond_unit: (not_op, op_id, val_unit, val1, val2) 7 | # condition: [cond_unit1, 'and'/'or', cond_unit2, ...] 8 | # sql { 9 | # 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...]) 10 | # 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition} 11 | # 'where': condition 12 | # 'groupBy': [col_unit1, col_unit2, ...] 
13 | # 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...]) 14 | # 'having': condition 15 | # 'limit': None/limit value 16 | # 'intersect': None/sql 17 | # 'except': None/sql 18 | # 'union': None/sql 19 | # } 20 | ################################ 21 | 22 | import os 23 | import json 24 | import sqlite3 25 | import argparse 26 | from tqdm import tqdm 27 | from copy import deepcopy 28 | from loguru import logger 29 | 30 | from .process_sql import get_schema, Schema, get_sql 31 | from .exec_eval import eval_exec_match 32 | 33 | # Flag to disable value evaluation 34 | DISABLE_VALUE = True 35 | # Flag to disable distinct in select evaluation 36 | DISABLE_DISTINCT = True 37 | 38 | 39 | CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except') 40 | JOIN_KEYWORDS = ('join', 'on', 'as') 41 | 42 | WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists') 43 | UNIT_OPS = ('none', '-', '+', "*", '/') 44 | AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg') 45 | TABLE_TYPE = { 46 | 'sql': "sql", 47 | 'table_unit': "table_unit", 48 | } 49 | 50 | COND_OPS = ('and', 'or') 51 | SQL_OPS = ('intersect', 'union', 'except') 52 | ORDER_OPS = ('desc', 'asc') 53 | 54 | 55 | HARDNESS = { 56 | "component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'), 57 | "component2": ('except', 'union', 'intersect') 58 | } 59 | 60 | _EMPTY_SQL = { 61 | "except": None, 62 | "from": { 63 | "conds": [], 64 | "table_units": [] 65 | }, 66 | "groupBy": [], 67 | "having": [], 68 | "intersect": None, 69 | "limit": None, 70 | "orderBy": [], 71 | "select": [ 72 | False, 73 | [] 74 | ], 75 | "union": None, 76 | "where": [] 77 | } 78 | 79 | 80 | def condition_has_or(conds): 81 | return 'or' in conds[1::2] 82 | 83 | 84 | def condition_has_like(conds): 85 | return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]] 86 | 87 | 88 | def condition_has_sql(conds): 89 | for cond_unit in conds[::2]: 90 | val1, val2 = cond_unit[3], cond_unit[4] 91 | if val1 is not None and type(val1) is dict: 92 | return True 93 | if val2 is not None and type(val2) is dict: 94 | return True 95 | return False 96 | 97 | 98 | def val_has_op(val_unit): 99 | return val_unit[0] != UNIT_OPS.index('none') 100 | 101 | 102 | def has_agg(unit): 103 | return unit[0] != AGG_OPS.index('none') 104 | 105 | 106 | def accuracy(count, total): 107 | if count == total: 108 | return 1 109 | return 0 110 | 111 | 112 | def recall(count, total): 113 | if count == total: 114 | return 1 115 | return 0 116 | 117 | 118 | def F1(acc, rec): 119 | if (acc + rec) == 0: 120 | return 0 121 | return (2. 
* acc * rec) / (acc + rec) 122 | 123 | 124 | def get_scores(count, pred_total, label_total): 125 | if pred_total != label_total: 126 | return 0,0,0 127 | elif count == pred_total: 128 | return 1,1,1 129 | return 0,0,0 130 | 131 | 132 | def eval_sel(pred, label): 133 | pred_sel = pred['select'][1] 134 | label_sel = label['select'][1] 135 | label_wo_agg = [unit[1] for unit in label_sel] 136 | pred_total = len(pred_sel) 137 | label_total = len(label_sel) 138 | cnt = 0 139 | cnt_wo_agg = 0 140 | 141 | for unit in pred_sel: 142 | if unit in label_sel: 143 | cnt += 1 144 | label_sel.remove(unit) 145 | if unit[1] in label_wo_agg: 146 | cnt_wo_agg += 1 147 | label_wo_agg.remove(unit[1]) 148 | 149 | return label_total, pred_total, cnt, cnt_wo_agg 150 | 151 | 152 | def eval_where(pred, label): 153 | pred_conds = [unit for unit in pred['where'][::2]] 154 | label_conds = [unit for unit in label['where'][::2]] 155 | label_wo_agg = [unit[2] for unit in label_conds] 156 | pred_total = len(pred_conds) 157 | label_total = len(label_conds) 158 | cnt = 0 159 | cnt_wo_agg = 0 160 | 161 | for unit in pred_conds: 162 | if unit in label_conds: 163 | cnt += 1 164 | label_conds.remove(unit) 165 | if unit[2] in label_wo_agg: 166 | cnt_wo_agg += 1 167 | label_wo_agg.remove(unit[2]) 168 | 169 | return label_total, pred_total, cnt, cnt_wo_agg 170 | 171 | 172 | def eval_group(pred, label): 173 | pred_cols = [unit[1] for unit in pred['groupBy']] 174 | label_cols = [unit[1] for unit in label['groupBy']] 175 | pred_total = len(pred_cols) 176 | label_total = len(label_cols) 177 | cnt = 0 178 | pred_cols = [pred.split(".")[1] if "." in pred else pred for pred in pred_cols] 179 | label_cols = [label.split(".")[1] if "." in label else label for label in label_cols] 180 | for col in pred_cols: 181 | if col in label_cols: 182 | cnt += 1 183 | label_cols.remove(col) 184 | return label_total, pred_total, cnt 185 | 186 | 187 | def eval_having(pred, label): 188 | pred_total = label_total = cnt = 0 189 | if len(pred['groupBy']) > 0: 190 | pred_total = 1 191 | if len(label['groupBy']) > 0: 192 | label_total = 1 193 | 194 | pred_cols = [unit[1] for unit in pred['groupBy']] 195 | label_cols = [unit[1] for unit in label['groupBy']] 196 | if pred_total == label_total == 1 \ 197 | and pred_cols == label_cols \ 198 | and pred['having'] == label['having']: 199 | cnt = 1 200 | 201 | return label_total, pred_total, cnt 202 | 203 | 204 | def eval_order(pred, label): 205 | pred_total = label_total = cnt = 0 206 | if len(pred['orderBy']) > 0: 207 | pred_total = 1 208 | if len(label['orderBy']) > 0: 209 | label_total = 1 210 | if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \ 211 | ((pred['limit'] is None and label['limit'] is None) or (pred['limit'] is not None and label['limit'] is not None)): 212 | cnt = 1 213 | return label_total, pred_total, cnt 214 | 215 | 216 | def eval_and_or(pred, label): 217 | pred_ao = pred['where'][1::2] 218 | label_ao = label['where'][1::2] 219 | pred_ao = set(pred_ao) 220 | label_ao = set(label_ao) 221 | 222 | if pred_ao == label_ao: 223 | return 1,1,1 224 | return len(pred_ao),len(label_ao),0 225 | 226 | 227 | def get_nestedSQL(sql): 228 | nested = [] 229 | for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]: 230 | if type(cond_unit[3]) is dict: 231 | nested.append(cond_unit[3]) 232 | if type(cond_unit[4]) is dict: 233 | nested.append(cond_unit[4]) 234 | if sql['intersect'] is not None: 235 | nested.append(sql['intersect']) 236 | if sql['except'] is not None: 
237 | nested.append(sql['except']) 238 | if sql['union'] is not None: 239 | nested.append(sql['union']) 240 | return nested 241 | 242 | 243 | def eval_nested(pred, label): 244 | label_total = 0 245 | pred_total = 0 246 | cnt = 0 247 | if pred is not None: 248 | pred_total += 1 249 | if label is not None: 250 | label_total += 1 251 | if pred is not None and label is not None: 252 | cnt += Evaluator().eval_exact_match(pred, label) 253 | return label_total, pred_total, cnt 254 | 255 | 256 | def eval_IUEN(pred, label): 257 | lt1, pt1, cnt1 = eval_nested(pred['intersect'], label['intersect']) 258 | lt2, pt2, cnt2 = eval_nested(pred['except'], label['except']) 259 | lt3, pt3, cnt3 = eval_nested(pred['union'], label['union']) 260 | label_total = lt1 + lt2 + lt3 261 | pred_total = pt1 + pt2 + pt3 262 | cnt = cnt1 + cnt2 + cnt3 263 | return label_total, pred_total, cnt 264 | 265 | 266 | def get_keywords(sql): 267 | res = set() 268 | if len(sql['where']) > 0: 269 | res.add('where') 270 | if len(sql['groupBy']) > 0: 271 | res.add('group') 272 | if len(sql['having']) > 0: 273 | res.add('having') 274 | if len(sql['orderBy']) > 0: 275 | res.add(sql['orderBy'][0]) 276 | res.add('order') 277 | if sql['limit'] is not None: 278 | res.add('limit') 279 | if sql['except'] is not None: 280 | res.add('except') 281 | if sql['union'] is not None: 282 | res.add('union') 283 | if sql['intersect'] is not None: 284 | res.add('intersect') 285 | 286 | # or keyword 287 | ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2] 288 | if len([token for token in ao if token == 'or']) > 0: 289 | res.add('or') 290 | 291 | cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2] 292 | # not keyword 293 | if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0: 294 | res.add('not') 295 | 296 | # in keyword 297 | if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0: 298 | res.add('in') 299 | 300 | # like keyword 301 | if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0: 302 | res.add('like') 303 | 304 | return res 305 | 306 | 307 | def eval_keywords(pred, label): 308 | pred_keywords = get_keywords(pred) 309 | label_keywords = get_keywords(label) 310 | pred_total = len(pred_keywords) 311 | label_total = len(label_keywords) 312 | cnt = 0 313 | 314 | for k in pred_keywords: 315 | if k in label_keywords: 316 | cnt += 1 317 | return label_total, pred_total, cnt 318 | 319 | 320 | def count_agg(units): 321 | return len([unit for unit in units if has_agg(unit)]) 322 | 323 | 324 | def count_component1(sql): 325 | count = 0 326 | if len(sql['where']) > 0: 327 | count += 1 328 | if len(sql['groupBy']) > 0: 329 | count += 1 330 | if len(sql['orderBy']) > 0: 331 | count += 1 332 | if sql['limit'] is not None: 333 | count += 1 334 | if len(sql['from']['table_units']) > 0: # JOIN 335 | count += len(sql['from']['table_units']) - 1 336 | 337 | ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2] 338 | count += len([token for token in ao if token == 'or']) 339 | cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2] 340 | count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) 341 | 342 | return count 343 | 344 | 345 | def count_component2(sql): 346 | nested = get_nestedSQL(sql) 347 | return len(nested) 348 | 349 | 350 | def count_others(sql): 351 | count = 0 352 | # number of aggregation 353 | agg_count = 
count_agg(sql['select'][1]) 354 | agg_count += count_agg(sql['where'][::2]) 355 | agg_count += count_agg(sql['groupBy']) 356 | if len(sql['orderBy']) > 0: 357 | agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] + 358 | [unit[2] for unit in sql['orderBy'][1] if unit[2]]) 359 | agg_count += count_agg(sql['having']) 360 | if agg_count > 1: 361 | count += 1 362 | 363 | # number of select columns 364 | if len(sql['select'][1]) > 1: 365 | count += 1 366 | 367 | # number of where conditions 368 | if len(sql['where']) > 1: 369 | count += 1 370 | 371 | # number of group by clauses 372 | if len(sql['groupBy']) > 1: 373 | count += 1 374 | 375 | return count 376 | 377 | 378 | class Evaluator: 379 | """A simple evaluator""" 380 | def __init__(self): 381 | self.partial_scores = None 382 | 383 | def eval_hardness(self, sql): 384 | count_comp1_ = count_component1(sql) 385 | count_comp2_ = count_component2(sql) 386 | count_others_ = count_others(sql) 387 | 388 | if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0: 389 | return "easy" 390 | elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \ 391 | (count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0): 392 | return "medium" 393 | elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \ 394 | (2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \ 395 | (count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1): 396 | return "hard" 397 | else: 398 | return "extra" 399 | 400 | def eval_exact_match(self, pred, label): 401 | partial_scores = self.eval_partial_match(pred, label) 402 | self.partial_scores = partial_scores 403 | 404 | for key, score in partial_scores.items(): 405 | if score['f1'] != 1: 406 | return 0 407 | 408 | if len(label['from']['table_units']) > 0: 409 | label_tables = sorted(label['from']['table_units']) 410 | pred_tables = sorted(pred['from']['table_units']) 411 | return label_tables == pred_tables 412 | return 1 413 | 414 | def eval_partial_match(self, pred, label): 415 | res = {} 416 | 417 | label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label) 418 | acc, rec, f1 = get_scores(cnt, pred_total, label_total) 419 | res['select'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} 420 | acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total) 421 | res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} 422 | 423 | label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label) 424 | acc, rec, f1 = get_scores(cnt, pred_total, label_total) 425 | res['where'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} 426 | acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total) 427 | res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} 428 | 429 | label_total, pred_total, cnt = eval_group(pred, label) 430 | acc, rec, f1 = get_scores(cnt, pred_total, label_total) 431 | res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} 432 | 433 | label_total, pred_total, cnt = eval_having(pred, label) 434 | acc, rec, f1 = get_scores(cnt, pred_total, label_total) 435 | res['group'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} 436 | 437 | label_total, pred_total, cnt = eval_order(pred, label) 438 | acc, rec, f1 = get_scores(cnt, pred_total, label_total) 439 | 
res['order'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} 440 | 441 | label_total, pred_total, cnt = eval_and_or(pred, label) 442 | acc, rec, f1 = get_scores(cnt, pred_total, label_total) 443 | res['and/or'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} 444 | 445 | label_total, pred_total, cnt = eval_IUEN(pred, label) 446 | acc, rec, f1 = get_scores(cnt, pred_total, label_total) 447 | res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} 448 | 449 | label_total, pred_total, cnt = eval_keywords(pred, label) 450 | acc, rec, f1 = get_scores(cnt, pred_total, label_total) 451 | res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} 452 | 453 | return res 454 | 455 | 456 | def isValidSQL(sql, db): 457 | conn = sqlite3.connect(db) 458 | cursor = conn.cursor() 459 | try: 460 | cursor.execute(sql) 461 | except: 462 | return False 463 | return True 464 | 465 | 466 | 467 | def print_formated_s(row_name, l, element_format): 468 | template = "{:20} " + ' '.join([element_format] * len(l)) 469 | print(template.format(row_name, *l)) 470 | 471 | 472 | def print_scores(scores, etype, include_turn_acc=True): 473 | turns = ['turn 1', 'turn 2', 'turn 3', 'turn 4', 'turn > 4'] 474 | levels = ['easy', 'medium', 'hard', 'extra', 'all'] 475 | if include_turn_acc: 476 | levels.append('joint_all') 477 | partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)', 478 | 'group', 'order', 'and/or', 'IUEN', 'keywords'] 479 | 480 | print_formated_s("", levels, '{:20}') 481 | counts = [scores[level]['count'] for level in levels] 482 | print_formated_s("count", counts, '{:<20d}') 483 | 484 | if etype in ["all", "exec"]: 485 | print ('===================== EXECUTION ACCURACY =====================') 486 | exec_scores = [scores[level]['exec'] for level in levels] 487 | print_formated_s("execution", exec_scores, '{:<20.3f}') 488 | 489 | if etype in ["all", "match"]: 490 | print ('\n====================== EXACT MATCHING ACCURACY =====================') 491 | exact_scores = [scores[level]['exact'] for level in levels] 492 | print_formated_s("exact match", exact_scores, '{:<20.3f}') 493 | print ('\n---------------------PARTIAL MATCHING ACCURACY----------------------') 494 | for type_ in partial_types: 495 | this_scores = [scores[level]['partial'][type_]['acc'] for level in levels] 496 | print_formated_s(type_, this_scores, '{:<20.3f}') 497 | 498 | print ('---------------------- PARTIAL MATCHING RECALL ----------------------') 499 | for type_ in partial_types: 500 | this_scores = [scores[level]['partial'][type_]['rec'] for level in levels] 501 | print_formated_s(type_, this_scores, '{:<20.3f}') 502 | 503 | print ('---------------------- PARTIAL MATCHING F1 --------------------------') 504 | for type_ in partial_types: 505 | this_scores = [scores[level]['partial'][type_]['f1'] for level in levels] 506 | print_formated_s(type_, this_scores, '{:<20.3f}') 507 | 508 | if include_turn_acc: 509 | print() 510 | print() 511 | print_formated_s("", turns, '{:20}') 512 | counts = [scores[turn]['count'] for turn in turns] 513 | print_formated_s("count", counts, "{:<20d}") 514 | 515 | if etype in ["all", "exec"]: 516 | print ('===================== TURN EXECUTION ACCURACY =====================') 517 | exec_scores = [scores[turn]['exec'] for turn in turns] 518 | print_formated_s("execution", exec_scores, '{:<20.3f}') 519 | 520 | if etype in 
["all", "match"]: 521 | print ('\n====================== TURN EXACT MATCHING ACCURACY =====================') 522 | exact_scores = [scores[turn]['exact'] for turn in turns] 523 | print_formated_s("exact match", exact_scores, '{:<20.3f}') 524 | 525 | 526 | def evaluate(golds, preds, db_dir, etype, kmaps, plug_value, keep_distinct, progress_bar_for_each_datapoint): 527 | 528 | glist = [] 529 | gseq_one = [] 530 | for l in golds: 531 | if len(l.strip()) == 0: 532 | glist.append(gseq_one) 533 | gseq_one = [] 534 | else: 535 | lstrip = l.strip().split('\t') 536 | gseq_one.append(lstrip) 537 | 538 | # include the last session 539 | # this was previously ignored in the SParC evaluation script 540 | # which might lead to slight differences in scores 541 | if len(gseq_one) != 0: 542 | glist.append(gseq_one) 543 | 544 | # spider formatting indicates that there is only one "single turn" 545 | # do not report "turn accuracy" for SPIDER 546 | include_turn_acc = len(glist) > 1 547 | 548 | plist = [] 549 | pseq_one = [] 550 | for l in preds: 551 | if len(l.strip()) == 0: 552 | plist.append(pseq_one) 553 | pseq_one = [] 554 | else: 555 | pseq_one.append(l.strip().split('\t')) 556 | 557 | if len(pseq_one) != 0: 558 | plist.append(pseq_one) 559 | 560 | assert len(plist) == len(glist), "number of sessions must equal" 561 | 562 | evaluator = Evaluator() 563 | turns = ['turn 1', 'turn 2', 'turn 3', 'turn 4', 'turn > 4'] 564 | levels = ['easy', 'medium', 'hard', 'extra', 'all', 'joint_all'] 565 | 566 | partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)', 567 | 'group', 'order', 'and/or', 'IUEN', 'keywords'] 568 | entries = [] 569 | scores = {} 570 | 571 | for turn in turns: 572 | scores[turn] = {'count': 0, 'exact': 0.} 573 | scores[turn]['exec'] = 0 574 | 575 | for level in levels: 576 | scores[level] = {'count': 0, 'partial': {}, 'exact': 0.} 577 | scores[level]['exec'] = 0 578 | for type_ in partial_types: 579 | scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0.,'acc_count':0,'rec_count':0} 580 | 581 | parse_g_sql_error_flag = False 582 | 583 | for i, (p, g) in enumerate(zip(plist, glist)): 584 | if (i + 1) % 10 == 0: 585 | print('Evaluating %dth prediction' % (i + 1)) 586 | scores['joint_all']['count'] += 1 587 | turn_scores = {"exec": [], "exact": []} 588 | for idx, pg in enumerate(tqdm(list(zip(p, g)))): 589 | p, g = pg 590 | p_str = p[0] 591 | # p_str = p_str.replace("value", "1") 592 | g_str, db = g 593 | db_name = db 594 | db = os.path.join(db_dir, db, db + ".sqlite") 595 | schema = Schema(get_schema(db)) 596 | try: 597 | g_sql = get_sql(schema, g_str) 598 | except: 599 | parse_g_sql_error_flag = True 600 | g_sql = deepcopy(_EMPTY_SQL) 601 | hardness = evaluator.eval_hardness(g_sql) 602 | if idx > 3: 603 | idx = "> 4" 604 | else: 605 | idx += 1 606 | turn_id = "turn " + str(idx) 607 | scores[turn_id]['count'] += 1 608 | scores[hardness]['count'] += 1 609 | scores['all']['count'] += 1 610 | 611 | try: 612 | p_sql = get_sql(schema, p_str) 613 | except: 614 | # If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql 615 | p_sql = deepcopy(_EMPTY_SQL) 616 | 617 | entry = { 618 | 'predictSQL': p_str, 619 | 'goldSQL': g_str, 620 | 'hardness': hardness, 621 | } 622 | 623 | if etype in ["all", "exec"]: 624 | exec_score = eval_exec_match(db=db, p_str=p_str, g_str=g_str, plug_value=plug_value, 625 | keep_distinct=keep_distinct, progress_bar_for_each_datapoint=progress_bar_for_each_datapoint) 626 | if exec_score: 627 | 
scores[hardness]['exec'] += 1 628 | scores[turn_id]['exec'] += 1 629 | scores['all']['exec'] += 1 630 | turn_scores['exec'].append(1) 631 | else: 632 | turn_scores['exec'].append(0) 633 | 634 | entry["exec"] = exec_score 635 | 636 | if etype in ["all", "match"]: 637 | # rebuild sql for value evaluation 638 | kmap = kmaps[db_name] 639 | g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema) 640 | g_sql = rebuild_sql_val(g_sql) 641 | g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap) 642 | p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema) 643 | p_sql = rebuild_sql_val(p_sql) 644 | p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap) 645 | exact_score = evaluator.eval_exact_match(p_sql, g_sql) 646 | partial_scores = evaluator.partial_scores 647 | if exact_score == 0: 648 | turn_scores['exact'].append(0) 649 | else: 650 | turn_scores['exact'].append(1) 651 | scores[turn_id]['exact'] += exact_score 652 | scores[hardness]['exact'] += exact_score 653 | scores['all']['exact'] += exact_score 654 | for type_ in partial_types: 655 | if partial_scores[type_]['pred_total'] > 0: 656 | scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc'] 657 | scores[hardness]['partial'][type_]['acc_count'] += 1 658 | if partial_scores[type_]['label_total'] > 0: 659 | scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec'] 660 | scores[hardness]['partial'][type_]['rec_count'] += 1 661 | scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1'] 662 | if partial_scores[type_]['pred_total'] > 0: 663 | scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc'] 664 | scores['all']['partial'][type_]['acc_count'] += 1 665 | if partial_scores[type_]['label_total'] > 0: 666 | scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec'] 667 | scores['all']['partial'][type_]['rec_count'] += 1 668 | scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1'] 669 | 670 | entry["exact"] = exact_score 671 | entry["partial"] = partial_scores 672 | 673 | entries.append(entry) 674 | 675 | if all(v == 1 for v in turn_scores["exec"]): 676 | scores['joint_all']['exec'] += 1 677 | 678 | if all(v == 1 for v in turn_scores["exact"]): 679 | scores['joint_all']['exact'] += 1 680 | 681 | for turn in turns: 682 | if scores[turn]['count'] == 0: 683 | continue 684 | if etype in ["all", "exec"]: 685 | scores[turn]['exec'] /= scores[turn]['count'] 686 | 687 | if etype in ["all", "match"]: 688 | scores[turn]['exact'] /= scores[turn]['count'] 689 | 690 | for level in levels: 691 | if scores[level]['count'] == 0: 692 | continue 693 | if etype in ["all", "exec"]: 694 | scores[level]['exec'] /= scores[level]['count'] 695 | 696 | if etype in ["all", "match"]: 697 | scores[level]['exact'] /= scores[level]['count'] 698 | for type_ in partial_types: 699 | if scores[level]['partial'][type_]['acc_count'] == 0: 700 | scores[level]['partial'][type_]['acc'] = 0 701 | else: 702 | scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \ 703 | scores[level]['partial'][type_]['acc_count'] * 1.0 704 | if scores[level]['partial'][type_]['rec_count'] == 0: 705 | scores[level]['partial'][type_]['rec'] = 0 706 | else: 707 | scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \ 708 | scores[level]['partial'][type_]['rec_count'] * 1.0 709 | if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0: 710 | 
scores[level]['partial'][type_]['f1'] = 1 711 | else: 712 | scores[level]['partial'][type_]['f1'] = \ 713 | 2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / ( 714 | scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc']) 715 | 716 | # print_scores(scores, etype, include_turn_acc=include_turn_acc) 717 | if parse_g_sql_error_flag and etype in ["all", "match"]: 718 | logger.warning( 719 | "Errors occurred when parsing Gold SQL, so the evaluation of the `EM` metric was ignored." 720 | ) 721 | [entry.pop("exact") for entry in entries] 722 | [entry.pop("partial") for entry in entries] 723 | 724 | return entries 725 | 726 | 727 | # Rebuild SQL functions for value evaluation 728 | def rebuild_cond_unit_val(cond_unit): 729 | if cond_unit is None or not DISABLE_VALUE: 730 | return cond_unit 731 | 732 | not_op, op_id, val_unit, val1, val2 = cond_unit 733 | if type(val1) is not dict: 734 | val1 = None 735 | else: 736 | val1 = rebuild_sql_val(val1) 737 | if type(val2) is not dict: 738 | val2 = None 739 | else: 740 | val2 = rebuild_sql_val(val2) 741 | return not_op, op_id, val_unit, val1, val2 742 | 743 | 744 | def rebuild_condition_val(condition): 745 | if condition is None or not DISABLE_VALUE: 746 | return condition 747 | 748 | res = [] 749 | for idx, it in enumerate(condition): 750 | if idx % 2 == 0: 751 | res.append(rebuild_cond_unit_val(it)) 752 | else: 753 | res.append(it) 754 | return res 755 | 756 | 757 | def rebuild_sql_val(sql): 758 | if sql is None or not DISABLE_VALUE: 759 | return sql 760 | 761 | sql['from']['conds'] = rebuild_condition_val(sql['from']['conds']) 762 | sql['having'] = rebuild_condition_val(sql['having']) 763 | sql['where'] = rebuild_condition_val(sql['where']) 764 | sql['intersect'] = rebuild_sql_val(sql['intersect']) 765 | sql['except'] = rebuild_sql_val(sql['except']) 766 | sql['union'] = rebuild_sql_val(sql['union']) 767 | 768 | return sql 769 | 770 | 771 | # Rebuild SQL functions for foreign key evaluation 772 | def build_valid_col_units(table_units, schema): 773 | col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']] 774 | prefixs = [col_id[:-2] for col_id in col_ids] 775 | valid_col_units= [] 776 | for value in schema.idMap.values(): 777 | if '.' 
in value and value[:value.index('.')] in prefixs: 778 | valid_col_units.append(value) 779 | return valid_col_units 780 | 781 | 782 | def rebuild_col_unit_col(valid_col_units, col_unit, kmap): 783 | if col_unit is None: 784 | return col_unit 785 | 786 | agg_id, col_id, distinct = col_unit 787 | if col_id in kmap and col_id in valid_col_units: 788 | col_id = kmap[col_id] 789 | if DISABLE_DISTINCT: 790 | distinct = None 791 | return agg_id, col_id, distinct 792 | 793 | 794 | def rebuild_val_unit_col(valid_col_units, val_unit, kmap): 795 | if val_unit is None: 796 | return val_unit 797 | 798 | unit_op, col_unit1, col_unit2 = val_unit 799 | col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap) 800 | col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap) 801 | return unit_op, col_unit1, col_unit2 802 | 803 | 804 | def rebuild_table_unit_col(valid_col_units, table_unit, kmap): 805 | if table_unit is None: 806 | return table_unit 807 | 808 | table_type, col_unit_or_sql = table_unit 809 | if isinstance(col_unit_or_sql, tuple): 810 | col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap) 811 | return table_type, col_unit_or_sql 812 | 813 | 814 | def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap): 815 | if cond_unit is None: 816 | return cond_unit 817 | 818 | not_op, op_id, val_unit, val1, val2 = cond_unit 819 | val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap) 820 | return not_op, op_id, val_unit, val1, val2 821 | 822 | 823 | def rebuild_condition_col(valid_col_units, condition, kmap): 824 | for idx in range(len(condition)): 825 | if idx % 2 == 0: 826 | condition[idx] = rebuild_cond_unit_col(valid_col_units, condition[idx], kmap) 827 | return condition 828 | 829 | 830 | def rebuild_select_col(valid_col_units, sel, kmap): 831 | if sel is None: 832 | return sel 833 | distinct, _list = sel 834 | new_list = [] 835 | for it in _list: 836 | agg_id, val_unit = it 837 | new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap))) 838 | if DISABLE_DISTINCT: 839 | distinct = None 840 | return distinct, new_list 841 | 842 | 843 | def rebuild_from_col(valid_col_units, from_, kmap): 844 | if from_ is None: 845 | return from_ 846 | 847 | from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap) for table_unit in from_['table_units']] 848 | from_['conds'] = rebuild_condition_col(valid_col_units, from_['conds'], kmap) 849 | return from_ 850 | 851 | 852 | def rebuild_group_by_col(valid_col_units, group_by, kmap): 853 | if group_by is None: 854 | return group_by 855 | 856 | return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by] 857 | 858 | 859 | def rebuild_order_by_col(valid_col_units, order_by, kmap): 860 | if order_by is None or len(order_by) == 0: 861 | return order_by 862 | 863 | direction, val_units = order_by 864 | new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units] 865 | return direction, new_val_units 866 | 867 | 868 | def rebuild_sql_col(valid_col_units, sql, kmap): 869 | if sql is None: 870 | return sql 871 | 872 | sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap) 873 | sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap) 874 | sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap) 875 | sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap) 876 | sql['orderBy'] = rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap) 877 | 
sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap) 878 | sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap) 879 | sql['except'] = rebuild_sql_col(valid_col_units, sql['except'], kmap) 880 | sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap) 881 | 882 | return sql 883 | 884 | 885 | def build_foreign_key_map(entry): 886 | cols_orig = entry["column_names_original"] 887 | tables_orig = entry["table_names_original"] 888 | 889 | # rebuild cols corresponding to idmap in Schema 890 | cols = [] 891 | for col_orig in cols_orig: 892 | if col_orig[0] >= 0: 893 | t = tables_orig[col_orig[0]] 894 | c = col_orig[1] 895 | cols.append("__" + t.lower() + "." + c.lower() + "__") 896 | else: 897 | cols.append("__all__") 898 | 899 | def keyset_in_list(k1, k2, k_list): 900 | for k_set in k_list: 901 | if k1 in k_set or k2 in k_set: 902 | return k_set 903 | new_k_set = set() 904 | k_list.append(new_k_set) 905 | return new_k_set 906 | 907 | foreign_key_list = [] 908 | foreign_keys = entry["foreign_keys"] 909 | for fkey in foreign_keys: 910 | key1, key2 = fkey 911 | key_set = keyset_in_list(key1, key2, foreign_key_list) 912 | key_set.add(key1) 913 | key_set.add(key2) 914 | 915 | foreign_key_map = {} 916 | for key_set in foreign_key_list: 917 | sorted_list = sorted(list(key_set)) 918 | midx = sorted_list[0] 919 | for idx in sorted_list: 920 | foreign_key_map[cols[idx]] = cols[midx] 921 | 922 | return foreign_key_map 923 | 924 | 925 | def build_foreign_key_map_from_json(table): 926 | with open(table) as f: 927 | data = json.load(f) 928 | tables = {} 929 | for entry in data: 930 | tables[entry['db_id']] = build_foreign_key_map(entry) 931 | return tables 932 | 933 | 934 | if __name__ == "__main__": 935 | parser = argparse.ArgumentParser() 936 | parser.add_argument('--gold', dest='gold', type=str, help="the path to the gold queries") 937 | parser.add_argument('--pred', dest='pred', type=str, help="the path to the predicted queries") 938 | parser.add_argument('--db', dest='db', type=str, help="the directory that contains all the databases and test suites") 939 | parser.add_argument('--table', dest='table', type=str, help="the tables.json schema file") 940 | parser.add_argument('--etype', dest='etype', type=str, default='exec', 941 | help="evaluation type, exec for test suite accuracy, match for the original exact set match accuracy", 942 | choices=('all', 'exec', 'match')) 943 | parser.add_argument('--plug_value', default=False, action='store_true', 944 | help='whether to plug in the gold value into the predicted query; suitable if your model does not predict values.') 945 | parser.add_argument('--keep_distinct', default=False, action='store_true', 946 | help='whether to keep distinct keyword during evaluation. 
default is false.')
947 | parser.add_argument('--progress_bar_for_each_datapoint', default=False, action='store_true',
948 | help='whether to print progress bar of running test inputs for each datapoint')
949 | args = parser.parse_args()
950 |
951 | # only evaluating exact match needs this argument
952 | kmaps = None
953 | if args.etype in ['all', 'match']:
954 | assert args.table is not None, 'table argument must be non-None if exact set match is evaluated'
955 | kmaps = build_foreign_key_map_from_json(args.table)
956 |
957 | evaluate(args.gold, args.pred, args.db, args.etype, kmaps, args.plug_value, args.keep_distinct, args.progress_bar_for_each_datapoint)
-------------------------------------------------------------------------------- /src/nl2sql360/evaluator/test_suite_sql_eval/exec_eval.py: --------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import asyncio
4 | import sqlite3
5 | import threading
6 | from typing import Tuple, Any, List, Set
7 | from itertools import product
8 | from collections import defaultdict
9 | import tqdm
10 | import random
11 | from .parse import get_all_preds_for_execution, remove_distinct
12 | import time
13 | import pickle as pkl
14 | import subprocess
15 | from itertools import chain
16 |
17 |
18 |
19 | threadLock = threading.Lock()
20 | TIMEOUT = 60
21 | EXEC_TMP_DIR = 'tmp/'
22 |
23 | def permute_tuple(element: Tuple, perm: Tuple) -> Tuple:
24 | assert len(element) == len(perm)
25 | return tuple([element[i] for i in perm])
26 |
27 |
28 | def unorder_row(row: Tuple) -> Tuple:
29 | return tuple(sorted(row, key=lambda x: str(x) + str(type(x))))
30 |
31 |
32 | # unorder each row in the table
33 | # [result_1 and result_2 have the same bag of unordered rows]
34 | # is a necessary condition of
35 | # [result_1 and result_2 are equivalent in denotation]
36 | def quick_rej(result1: List[Tuple], result2: List[Tuple], order_matters: bool) -> bool:
37 | s1 = [unorder_row(row) for row in result1]
38 | s2 = [unorder_row(row) for row in result2]
39 | if order_matters:
40 | return s1 == s2
41 | else:
42 | return set(s1) == set(s2)
43 |
44 |
45 | # return whether two bags of relations are equivalent
46 | def multiset_eq(l1: List, l2: List) -> bool:
47 | if len(l1) != len(l2):
48 | return False
49 | d = defaultdict(int)
50 | for e in l1:
51 | d[e] = d[e] + 1
52 | for e in l2:
53 | d[e] = d[e] - 1
54 | if d[e] < 0:
55 | return False
56 | return True
57 |
58 |
59 | def get_constraint_permutation(tab1_sets_by_columns: List[Set], result2: List[Tuple]):
60 | num_cols = len(result2[0])
61 | perm_constraints = [{i for i in range(num_cols)} for _ in range(num_cols)]
62 | if num_cols <= 3:
63 | return product(*perm_constraints)
64 |
65 | # we sample 20 rows and constrain the space of permutations
66 | for _ in range(20):
67 | random_tab2_row = random.choice(result2)
68 |
69 | for tab1_col in range(num_cols):
70 | for tab2_col in set(perm_constraints[tab1_col]):
71 | if random_tab2_row[tab2_col] not in tab1_sets_by_columns[tab1_col]:
72 | perm_constraints[tab1_col].remove(tab2_col)
73 | return product(*perm_constraints)
74 |
75 |
76 | # check whether two denotations are equivalent
77 | def result_eq(result1: List[Tuple], result2: List[Tuple], order_matters: bool) -> bool:
78 | if len(result1) == 0 and len(result2) == 0:
79 | return True
80 |
81 | # if the lengths are not the same, then they are definitely different bags of rows
82 | if len(result1) != len(result2):
83 | return False
84 |
85 | num_cols = len(result1[0])
86 |
87 | # if the results do not have the same number of columns, they are different
88 | if len(result2[0]) != num_cols:
89 | return False
90 |
91 | # unorder each row and compare whether the denotation is the same
92 | # this can already find most pairs of denotations that are different
93 | if not quick_rej(result1, result2, order_matters):
94 | return False
95 |
96 | # the rest of the problem is in fact more complicated than one might think
97 | # we want to find a permutation of column order and a permutation of row order,
98 | # s.t. result_1 is the same as result_2
99 | # we return true if we can find such column & row permutations
100 | # and false if we cannot
101 | tab1_sets_by_columns = [{row[i] for row in result1} for i in range(num_cols)]
102 |
103 | # on a high level, we enumerate all possible column permutations that might make result_1 == result_2
104 | # we decrease the size of the column permutation space by the function get_constraint_permutation
105 | # if one of the permutations makes result_1 and result_2 equivalent, then they are equivalent
106 | for perm in get_constraint_permutation(tab1_sets_by_columns, result2):
107 | if len(perm) != len(set(perm)):
108 | continue
109 | if num_cols == 1:
110 | result2_perm = result2
111 | else:
112 | result2_perm = [permute_tuple(element, perm) for element in result2]
113 | if order_matters:
114 | if result1 == result2_perm:
115 | return True
116 | else:
117 | # in fact the first condition must hold if the second condition holds
118 | # but the first is way more efficient implementation-wise
119 | # and we use it to quickly reject impossible candidates
120 | if set(result1) == set(result2_perm) and multiset_eq(result1, result2_perm):
121 | return True
122 | return False
123 |
124 |
125 | def replace_cur_year(query: str) -> str:
126 | return re.sub(
127 | "YEAR\s*\(\s*CURDATE\s*\(\s*\)\s*\)\s*", "2020", query, flags=re.IGNORECASE
128 | )
129 |
130 |
131 | # get the database cursor for a sqlite database path
132 | def get_cursor_from_path(sqlite_path: str):
133 | try:
134 | if not os.path.exists(sqlite_path):
135 | print("Opening a new connection %s" % sqlite_path)
136 | connection = sqlite3.connect(sqlite_path)
137 | except Exception as e:
138 | print(sqlite_path)
139 | raise e
140 | connection.text_factory = lambda b: b.decode(errors="ignore")
141 | cursor = connection.cursor()
142 | return cursor
143 |
144 |
145 | async def exec_on_db_(sqlite_path: str, query: str) -> Tuple[str, Any]:
146 | query = replace_cur_year(query)
147 | cursor = get_cursor_from_path(sqlite_path)
148 | try:
149 | cursor.execute(query)
150 | result = cursor.fetchall()
151 | cursor.close()
152 | cursor.connection.close()
153 | return "result", result
154 | except Exception as e:
155 | cursor.close()
156 | cursor.connection.close()
157 | return "exception", e
158 |
159 | async def exec_on_db(
160 | sqlite_path: str, query: str, process_id: str = "", timeout: int = TIMEOUT
161 | ) -> Tuple[str, Any]:
162 | try:
163 | return await asyncio.wait_for(exec_on_db_(sqlite_path, query), timeout)
164 | except asyncio.TimeoutError:
165 | return ('exception', TimeoutError)
166 | except Exception as e:
167 | return ("exception", e)
168 |
169 |
170 | # postprocess the model predictions to avoid execution errors
171 | # e.g. removing spaces between ">" and "="
172 | def postprocess(query: str) -> str:
173 | query = query.replace('> =', '>=').replace('< =', '<=').replace('! =', '!=')
174 | return query
175 |
176 |
177 | # approximate whether p_str and g_str are semantically equivalent
178 | # db is the database path
179 | # we are going to evaluate whether they are equivalent in all the databases
180 | # that are in the same directory as db
181 | # returns 1 if denotationally equivalent
182 | # and 0 otherwise
183 | # the meaning of each auxiliary argument can be seen in the parser definition in evaluation.py
184 | def eval_exec_match(db: str, p_str: str, g_str: str, plug_value: bool, keep_distinct: bool, progress_bar_for_each_datapoint: bool) -> int:
185 | # post-process the prediction.
186 | # e.g. removing spaces between ">" and "="
187 | p_str, g_str = postprocess(p_str), postprocess(g_str)
188 | if not keep_distinct:
189 | p_str = remove_distinct(p_str)
190 | g_str = remove_distinct(g_str)
191 |
192 | # we decide whether two denotations are equivalent based on "bag semantics"
193 | # https://courses.cs.washington.edu/courses/cse444/10sp/lectures/lecture16.pdf
194 | # if there is an order by in the query, then we assume the order of the rows matters
195 | # order by might also be used to find the max/min instead of sorting,
196 | # but in that case the result mostly only contains one row and hence order_matters does not make a difference
197 | order_matters = 'order by' in g_str.lower()
198 |
199 | # find all databases in the same directory
200 | db_dir = os.path.dirname(db)
201 | db_paths = [os.path.join(db_dir, basename) for basename in os.listdir(db_dir) if '.sqlite' in basename]
202 |
203 | preds = [p_str]
204 | # if plug_value is set (i.e. we do not consider value prediction correctness),
205 | # enumerate all ways to plug values from the gold query into the model prediction
206 | # otherwise, we only evaluate the predicted query with its own value prediction
207 | if plug_value:
208 | _, preds = get_all_preds_for_execution(g_str, p_str)
209 | # we did not add this line in our EMNLP work
210 | # this reduces "false negatives" when value is substituted
211 | preds = chain([p_str], preds)
212 |
213 | for pred in preds:
214 |
215 | pred_passes = 1
216 | # compare the gold and predicted denotations on each database in the directory
217 | # wrap with progress bar if required
218 | if progress_bar_for_each_datapoint:
219 | ranger = tqdm.tqdm(db_paths)
220 | else:
221 | ranger = db_paths
222 |
223 | for db_path in ranger:
224 | g_flag, g_denotation = asyncio.run(exec_on_db(db_path, g_str))
225 | p_flag, p_denotation = asyncio.run(exec_on_db(db_path, pred))
226 |
227 | # we should expect the gold to be successfully executed on the database
228 | assert g_flag != 'exception', 'gold query %s has error on database file %s' % (g_str, db_path)
229 |
230 | # wrong if execution fails
231 | if p_flag == 'exception':
232 | pred_passes = 0
233 |
234 | # if denotations are not equivalent, the prediction must be wrong
235 | elif not result_eq(g_denotation, p_denotation, order_matters=order_matters):
236 | pred_passes = 0
237 | if pred_passes == 0:
238 | break
239 |
240 | # the model prediction has the same denotation as the gold for all databases
241 | if pred_passes == 1:
242 | return 1
243 |
244 | # none of the predictions passed
245 | return 0
-------------------------------------------------------------------------------- /src/nl2sql360/evaluator/test_suite_sql_eval/parse.py: --------------------------------------------------------------------------------
1 | import re
2 | import sqlparse
3 | from typing import List, Tuple, Set, Iterator, Dict, Any, Union
4 | from sqlparse.sql import Comparison,
Identifier 5 | from sqlparse.tokens import Whitespace 6 | import itertools 7 | from collections import namedtuple 8 | 9 | Token = namedtuple('Token', ['ttype', 'value']) 10 | VALUE_NUM_SYMBOL = 'VALUERARE' 11 | QUOTE_CHARS = {'`', '\'', '"'} 12 | 13 | 14 | def tokenize(query: str) -> List[Token]: 15 | tokens = list([Token(t.ttype, t.value) for t in sqlparse.parse(query)[0].flatten()]) 16 | return tokens 17 | 18 | 19 | def join_tokens(tokens: List[Token]) -> str: 20 | return ''.join([x.value for x in tokens]).strip().replace(' ', ' ') 21 | 22 | 23 | def round_trip_test(query: str) -> None: 24 | tokens = tokenize(query) 25 | reconstructed = ''.join([token.value for token in tokens]) 26 | assert query == reconstructed, "Round trip test fails for string %s" % query 27 | 28 | 29 | def postprocess(query: str) -> str: 30 | query = query.replace('> =', '>=').replace('< =', '<=').replace('! =', '!=') 31 | return query 32 | 33 | 34 | # strip_query, reformat_query and replace values 35 | # were implemented by Yu Tao for processing CoSQL 36 | def strip_query(query: str) -> Tuple[List[str], List[str]]: 37 | query_keywords, all_values = [], [] 38 | 39 | # then replace all stuff enclosed by "" with a numerical value to get it marked as {VALUE} 40 | 41 | # Tao's implementation is commented out here. 42 | """ 43 | str_1 = re.findall("\"[^\"]*\"", query) 44 | str_2 = re.findall("\'[^\']*\'", query) 45 | values = str_1 + str_2 46 | """ 47 | 48 | toks = sqlparse.parse(query)[0].flatten() 49 | values = [t.value for t in toks if t.ttype == sqlparse.tokens.Literal.String.Single or t.ttype == sqlparse.tokens.Literal.String.Symbol] 50 | 51 | 52 | for val in values: 53 | all_values.append(val) 54 | query = query.replace(val.strip(), VALUE_NUM_SYMBOL) 55 | 56 | query_tokenized = query.split() 57 | float_nums = re.findall("[-+]?\d*\.\d+", query) 58 | all_values += [qt for qt in query_tokenized if qt in float_nums] 59 | query_tokenized = [VALUE_NUM_SYMBOL if qt in float_nums else qt for qt in query_tokenized] 60 | 61 | query = " ".join(query_tokenized) 62 | int_nums = [i.strip() for i in re.findall("[^tT]\d+", query)] 63 | 64 | all_values += [qt for qt in query_tokenized if qt in int_nums] 65 | query_tokenized = [VALUE_NUM_SYMBOL if qt in int_nums else qt for qt in query_tokenized] 66 | # print int_nums, query, query_tokenized 67 | 68 | for tok in query_tokenized: 69 | if "." in tok: 70 | table = re.findall("[Tt]\d+\.", tok) 71 | if len(table) > 0: 72 | to = tok.replace(".", " . 
").split() 73 | to = [t.lower() for t in to if len(t) > 0] 74 | query_keywords.extend(to) 75 | else: 76 | query_keywords.append(tok.lower()) 77 | 78 | elif len(tok) > 0: 79 | query_keywords.append(tok.lower()) 80 | return query_keywords, all_values 81 | 82 | 83 | def reformat_query(query: str) -> str: 84 | query = query.strip().replace(";", "").replace("\t", "") 85 | query = ' '.join([t.value for t in tokenize(query) if t.ttype != sqlparse.tokens.Whitespace]) 86 | t_stars = ["t1.*", "t2.*", "t3.*", "T1.*", "T2.*", "T3.*"] 87 | for ts in t_stars: 88 | query = query.replace(ts, "*") 89 | return query 90 | 91 | 92 | def replace_values(sql: str) -> Tuple[List[str], Set[str]]: 93 | sql = sqlparse.format(sql, reindent=False, keyword_case='upper') 94 | # sql = re.sub(r"(<=|>=|!=|=|<|>|,)", r" \1 ", sql) 95 | sql = re.sub(r"(T\d+\.)\s", r"\1", sql) 96 | query_toks_no_value, values = strip_query(sql) 97 | return query_toks_no_value, set(values) 98 | 99 | 100 | # extract the non-value tokens and the set of values 101 | # from a sql query 102 | def extract_query_values(sql: str) -> Tuple[List[str], Set[str]]: 103 | reformated = reformat_query(query=sql) 104 | query_value_replaced, values = replace_values(reformated) 105 | return query_value_replaced, values 106 | 107 | 108 | # plug in the values into query with value slots 109 | def plugin(query_value_replaced: List[str], values_in_order: List[str]) -> str: 110 | q_length = len(query_value_replaced) 111 | query_w_values = query_value_replaced[:] 112 | value_idx = [idx for idx in range(q_length) if query_value_replaced[idx] == VALUE_NUM_SYMBOL.lower()] 113 | assert len(value_idx) == len(values_in_order) 114 | 115 | for idx, value in zip(value_idx, values_in_order): 116 | query_w_values[idx] = value 117 | return ' '.join(query_w_values) 118 | 119 | 120 | # a generator generating all possible ways of 121 | # filling values into predicted query 122 | def plugin_all_permutations(query_value_replaced: List[str], values: Set[str]) -> Iterator[str]: 123 | num_slots = len([v for v in query_value_replaced if v == VALUE_NUM_SYMBOL.lower()]) 124 | for values in itertools.product(*[list(values) for _ in range(num_slots)]): 125 | yield plugin(query_value_replaced, list(values)) 126 | 127 | 128 | # given the gold query and the model prediction 129 | # extract values from the gold, extract predicted sql with value slots 130 | # return 1) number of possible ways to plug in gold values and 2) an iterator of predictions with value plugged in 131 | def get_all_preds_for_execution(gold: str, pred: str) -> Tuple[int, Iterator[str]]: 132 | _, gold_values = extract_query_values(gold) 133 | pred_query_value_replaced, _ = extract_query_values(pred) 134 | num_slots = len([v for v in pred_query_value_replaced if v == VALUE_NUM_SYMBOL.lower()]) 135 | num_alternatives = len(gold_values) ** num_slots 136 | return num_alternatives, plugin_all_permutations(pred_query_value_replaced, gold_values) 137 | 138 | 139 | def remove_distinct(s): 140 | toks = [t.value for t in list(sqlparse.parse(s)[0].flatten())] 141 | return ''.join([t for t in toks if t.lower() != 'distinct']) 142 | 143 | 144 | def extract_all_comparison_from_node(node: Token) -> List[Comparison]: 145 | comparison_list = [] 146 | if hasattr(node, 'tokens'): 147 | for t in node.tokens: 148 | comparison_list.extend(extract_all_comparison_from_node(t)) 149 | if type(node) == Comparison: 150 | comparison_list.append(node) 151 | return comparison_list 152 | 153 | 154 | def extract_all_comparison(query: str) -> List[Comparison]: 
155 | tree = sqlparse.parse(query)[0] 156 | comparison_list = extract_all_comparison_from_node(tree) 157 | return comparison_list 158 | 159 | 160 | def extract_toks_from_comparison(comparison_node: Comparison) -> List[Token]: 161 | tokens = [t for t in comparison_node.tokens if t.ttype != Whitespace] 162 | return tokens 163 | 164 | 165 | def extract_info_from_comparison(comparison_node: Comparison) -> Dict[str, Any]: 166 | tokens = extract_toks_from_comparison(comparison_node) 167 | left, op, right = tokens 168 | 169 | returned_dict = { 170 | 'left': left, 171 | 'op': op.value, 172 | 'right': right 173 | } 174 | 175 | if type(left) != Identifier: 176 | return returned_dict 177 | 178 | table = None 179 | if len(left.tokens) == 3 and re.match('^[tT][0-9]$', left.tokens[0].value) is None: 180 | table = left.tokens[0].value.lower() 181 | col = left.tokens[-1].value 182 | 183 | if type(right) == Identifier: 184 | if len(right.tokens) == 1 and type(right.tokens[0]) == sqlparse.sql.Token: 185 | right_val = right.tokens[0].value 186 | else: 187 | return returned_dict 188 | elif type(right) == sqlparse.sql.Token: 189 | right_val = right.value 190 | else: 191 | return returned_dict 192 | 193 | returned_dict['table_col'], returned_dict['val'] = (table, col.upper()), process_str_value(right_val) 194 | 195 | return returned_dict 196 | 197 | 198 | def extract_all_comparison_from_query(query: str) -> List[Dict[str, Any]]: 199 | comparison_list = extract_all_comparison(query) 200 | return [extract_info_from_comparison(c) for c in comparison_list] 201 | 202 | 203 | def extract_typed_value_in_comparison_from_query(query: str) -> List[Tuple[Tuple[Union[str, None], str], str]]: 204 | cmps = extract_all_comparison_from_query(query) 205 | typed_values = [(cmp['table_col'], cmp['val']) for cmp in cmps if 'table_col' in cmp] 206 | for table, col, val1, val2 in re.findall('(?:([^\.\s]*)\.)?([^\.\s]+) between ([^\s;]+) and ([^\s;]+)', query, re.IGNORECASE): 207 | if table == '': 208 | table = None 209 | else: 210 | table = table.lower() 211 | col = col.upper() 212 | for v in [val1, val2]: 213 | typed_values.append(((table, col), v)) 214 | return typed_values 215 | 216 | 217 | def process_str_value(v: str) -> str: 218 | if len(v) > 0 and v[0] in QUOTE_CHARS: 219 | v = v[1:] 220 | if len(v) > 0 and v[-1] in QUOTE_CHARS: 221 | v = v[:-1] 222 | for c in QUOTE_CHARS: 223 | v = v.replace(c + c, c) 224 | return v -------------------------------------------------------------------------------- /src/nl2sql360/evaluator/test_suite_sql_eval/process_sql.py: -------------------------------------------------------------------------------- 1 | ################################ 2 | # Assumptions: 3 | # 1. sql is correct 4 | # 2. only table name has alias 5 | # 3. only one intersect/union/except 6 | # 7 | # val: number(float)/string(str)/sql(dict) 8 | # col_unit: (agg_id, col_id, isDistinct(bool)) 9 | # val_unit: (unit_op, col_unit1, col_unit2) 10 | # table_unit: (table_type, col_unit/sql) 11 | # cond_unit: (not_op, op_id, val_unit, val1, val2) 12 | # condition: [cond_unit1, 'and'/'or', cond_unit2, ...] 13 | # sql { 14 | # 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...]) 15 | # 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition} 16 | # 'where': condition 17 | # 'groupBy': [col_unit1, col_unit2, ...] 
18 | #     'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
19 | #     'having': condition
20 | #     'limit': None/limit value
21 | #     'intersect': None/sql
22 | #     'except': None/sql
23 | #     'union': None/sql
24 | # }
25 | ################################
26 |
27 | import json
28 | import sqlite3
29 | from nltk import word_tokenize
30 |
31 | CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
32 | JOIN_KEYWORDS = ('join', 'on', 'as')
33 |
34 | WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
35 | UNIT_OPS = ('none', '-', '+', "*", '/')
36 | AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
37 | TABLE_TYPE = {
38 | 'sql': "sql",
39 | 'table_unit': "table_unit",
40 | }
41 |
42 | COND_OPS = ('and', 'or')
43 | SQL_OPS = ('intersect', 'union', 'except')
44 | ORDER_OPS = ('desc', 'asc')
45 |
46 |
47 |
48 | class Schema:
49 | """
50 | Simple schema which maps table&column to a unique identifier
51 | """
52 | def __init__(self, schema):
53 | self._schema = schema
54 | self._idMap = self._map(self._schema)
55 |
56 | @property
57 | def schema(self):
58 | return self._schema
59 |
60 | @property
61 | def idMap(self):
62 | return self._idMap
63 |
64 | def _map(self, schema):
65 | idMap = {'*': "__all__"}
66 | id = 1
67 | for key, vals in schema.items():
68 | for val in vals:
69 | idMap[key.lower() + "." + val.lower()] = "__" + key.lower() + "." + val.lower() + "__"
70 | id += 1
71 |
72 | for key in schema:
73 | idMap[key.lower()] = "__" + key.lower() + "__"
74 | id += 1
75 |
76 | return idMap
77 |
78 |
79 | def get_schema(db):
80 | """
81 | Get database's schema, which is a dict with table name as key
82 | and list of column names as value
83 | :param db: database path
84 | :return: schema dict
85 | """
86 |
87 | schema = {}
88 | conn = sqlite3.connect(db)
89 | cursor = conn.cursor()
90 |
91 | # fetch table names
92 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
93 | tables = [str(table[0].lower()) for table in cursor.fetchall()]
94 |
95 | # fetch table info
96 | for table in tables:
97 | cursor.execute("PRAGMA table_info({})".format(table))
98 | schema[table] = [str(col[1].lower()) for col in cursor.fetchall()]
99 |
100 | return schema
101 |
102 |
103 | def get_schema_from_json(fpath):
104 | with open(fpath) as f:
105 | data = json.load(f)
106 |
107 | schema = {}
108 | for entry in data:
109 | table = str(entry['table'].lower())
110 | cols = [str(col['column_name'].lower()) for col in entry['col_data']]
111 | schema[table] = cols
112 |
113 | return schema
114 |
115 |
116 | def tokenize(string):
117 | string = str(string)
118 | string = string.replace("\'", "\"")  # normalize quotes so every string value is wrapped in double quotes (note: may misbehave if a value itself contains quotes)
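# Note: the block below masks each double-quoted literal with a placeholder key of the form
# __val_{start}_{end}__ before word_tokenize runs, so values containing spaces survive as a single
# token; e.g. in 'SELECT name FROM t WHERE name = "John Smith"' the literal is masked during
# tokenization and restored afterwards from the `vals` map.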
119 | quote_idxs = [idx for idx, char in enumerate(string) if char == '"'] 120 | assert len(quote_idxs) % 2 == 0, "Unexpected quote" 121 | 122 | # keep string value as token 123 | vals = {} 124 | for i in range(len(quote_idxs)-1, -1, -2): 125 | qidx1 = quote_idxs[i-1] 126 | qidx2 = quote_idxs[i] 127 | val = string[qidx1: qidx2+1] 128 | key = "__val_{}_{}__".format(qidx1, qidx2) 129 | string = string[:qidx1] + key + string[qidx2+1:] 130 | vals[key] = val 131 | 132 | toks = [word.lower() for word in word_tokenize(string)] 133 | # replace with string value token 134 | for i in range(len(toks)): 135 | if toks[i] in vals: 136 | toks[i] = vals[toks[i]] 137 | 138 | # find if there exists !=, >=, <= 139 | eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="] 140 | eq_idxs.reverse() 141 | prefix = ('!', '>', '<') 142 | for eq_idx in eq_idxs: 143 | pre_tok = toks[eq_idx-1] 144 | if pre_tok in prefix: 145 | toks = toks[:eq_idx-1] + [pre_tok + "="] + toks[eq_idx+1: ] 146 | 147 | return toks 148 | 149 | 150 | def scan_alias(toks): 151 | """Scan the index of 'as' and build the map for all alias""" 152 | as_idxs = [idx for idx, tok in enumerate(toks) if tok == 'as'] 153 | alias = {} 154 | for idx in as_idxs: 155 | alias[toks[idx+1]] = toks[idx-1] 156 | return alias 157 | 158 | 159 | def get_tables_with_alias(schema, toks): 160 | tables = scan_alias(toks) 161 | for key in schema: 162 | assert key not in tables, "Alias {} has the same name in table".format(key) 163 | tables[key] = key 164 | return tables 165 | 166 | 167 | def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None): 168 | """ 169 | :returns next idx, column id 170 | """ 171 | tok = toks[start_idx] 172 | if tok == "*": 173 | return start_idx + 1, schema.idMap[tok] 174 | 175 | if '.' in tok: # if token is a composite 176 | alias, col = tok.split('.') 177 | key = tables_with_alias[alias] + "." + col 178 | return start_idx+1, schema.idMap[key] 179 | 180 | assert default_tables is not None and len(default_tables) > 0, "Default tables should not be None or empty" 181 | 182 | for alias in default_tables: 183 | table = tables_with_alias[alias] 184 | if tok in schema.schema[table]: 185 | key = table + "." 
+ tok 186 | return start_idx+1, schema.idMap[key] 187 | 188 | assert False, "Error col: {}".format(tok) 189 | 190 | 191 | def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None): 192 | """ 193 | :returns next idx, (agg_op id, col_id) 194 | """ 195 | idx = start_idx 196 | len_ = len(toks) 197 | isBlock = False 198 | isDistinct = False 199 | if toks[idx] == '(': 200 | isBlock = True 201 | idx += 1 202 | 203 | if toks[idx] in AGG_OPS: 204 | agg_id = AGG_OPS.index(toks[idx]) 205 | idx += 1 206 | assert idx < len_ and toks[idx] == '(' 207 | idx += 1 208 | if toks[idx] == "distinct": 209 | idx += 1 210 | isDistinct = True 211 | idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables) 212 | assert idx < len_ and toks[idx] == ')' 213 | idx += 1 214 | return idx, (agg_id, col_id, isDistinct) 215 | 216 | if toks[idx] == "distinct": 217 | idx += 1 218 | isDistinct = True 219 | agg_id = AGG_OPS.index("none") 220 | idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables) 221 | 222 | if isBlock: 223 | assert toks[idx] == ')' 224 | idx += 1 # skip ')' 225 | 226 | return idx, (agg_id, col_id, isDistinct) 227 | 228 | 229 | def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None): 230 | idx = start_idx 231 | len_ = len(toks) 232 | isBlock = False 233 | if toks[idx] == '(': 234 | isBlock = True 235 | idx += 1 236 | 237 | col_unit1 = None 238 | col_unit2 = None 239 | unit_op = UNIT_OPS.index('none') 240 | 241 | idx, col_unit1 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) 242 | if idx < len_ and toks[idx] in UNIT_OPS: 243 | unit_op = UNIT_OPS.index(toks[idx]) 244 | idx += 1 245 | idx, col_unit2 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) 246 | 247 | if isBlock: 248 | assert toks[idx] == ')' 249 | idx += 1 # skip ')' 250 | 251 | return idx, (unit_op, col_unit1, col_unit2) 252 | 253 | 254 | def parse_table_unit(toks, start_idx, tables_with_alias, schema): 255 | """ 256 | :returns next idx, table id, table name 257 | """ 258 | idx = start_idx 259 | len_ = len(toks) 260 | key = tables_with_alias[toks[idx]] 261 | 262 | if idx + 1 < len_ and toks[idx+1] == "as": 263 | idx += 3 264 | else: 265 | idx += 1 266 | 267 | return idx, schema.idMap[key], key 268 | 269 | 270 | def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None): 271 | idx = start_idx 272 | len_ = len(toks) 273 | 274 | isBlock = False 275 | if toks[idx] == '(': 276 | isBlock = True 277 | idx += 1 278 | 279 | if toks[idx] == 'select': 280 | idx, val = parse_sql(toks, idx, tables_with_alias, schema) 281 | elif "\"" in toks[idx]: # token is a string value 282 | val = toks[idx] 283 | idx += 1 284 | else: 285 | try: 286 | val = float(toks[idx]) 287 | idx += 1 288 | except: 289 | end_idx = idx 290 | while end_idx < len_ and toks[end_idx] != ',' and toks[end_idx] != ')'\ 291 | and toks[end_idx] != 'and' and toks[end_idx] not in CLAUSE_KEYWORDS and toks[end_idx] not in JOIN_KEYWORDS: 292 | end_idx += 1 293 | 294 | idx, val = parse_col_unit(toks[start_idx: end_idx], 0, tables_with_alias, schema, default_tables) 295 | idx = end_idx 296 | 297 | if isBlock: 298 | assert toks[idx] == ')' 299 | idx += 1 300 | 301 | return idx, val 302 | 303 | 304 | def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None): 305 | idx = start_idx 306 | len_ = len(toks) 307 | conds = [] 308 | 309 | while idx < len_: 310 | idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, 
default_tables) 311 | not_op = False 312 | if toks[idx] == 'not': 313 | not_op = True 314 | idx += 1 315 | 316 | assert idx < len_ and toks[idx] in WHERE_OPS, "Error condition: idx: {}, tok: {}".format(idx, toks[idx]) 317 | op_id = WHERE_OPS.index(toks[idx]) 318 | idx += 1 319 | val1 = val2 = None 320 | if op_id == WHERE_OPS.index('between'): # between..and... special case: dual values 321 | idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables) 322 | assert toks[idx] == 'and' 323 | idx += 1 324 | idx, val2 = parse_value(toks, idx, tables_with_alias, schema, default_tables) 325 | else: # normal case: single value 326 | idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables) 327 | val2 = None 328 | 329 | conds.append((not_op, op_id, val_unit, val1, val2)) 330 | 331 | if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";") or toks[idx] in JOIN_KEYWORDS): 332 | break 333 | 334 | if idx < len_ and toks[idx] in COND_OPS: 335 | conds.append(toks[idx]) 336 | idx += 1 # skip and/or 337 | 338 | return idx, conds 339 | 340 | 341 | def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None): 342 | idx = start_idx 343 | len_ = len(toks) 344 | 345 | assert toks[idx] == 'select', "'select' not found" 346 | idx += 1 347 | isDistinct = False 348 | if idx < len_ and toks[idx] == 'distinct': 349 | idx += 1 350 | isDistinct = True 351 | val_units = [] 352 | 353 | while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS: 354 | agg_id = AGG_OPS.index("none") 355 | if toks[idx] in AGG_OPS: 356 | agg_id = AGG_OPS.index(toks[idx]) 357 | idx += 1 358 | idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables) 359 | val_units.append((agg_id, val_unit)) 360 | if idx < len_ and toks[idx] == ',': 361 | idx += 1 # skip ',' 362 | 363 | return idx, (isDistinct, val_units) 364 | 365 | 366 | def parse_from(toks, start_idx, tables_with_alias, schema): 367 | """ 368 | Assume in the from clause, all table units are combined with join 369 | """ 370 | assert 'from' in toks[start_idx:], "'from' not found" 371 | 372 | len_ = len(toks) 373 | idx = toks.index('from', start_idx) + 1 374 | default_tables = [] 375 | table_units = [] 376 | conds = [] 377 | 378 | while idx < len_: 379 | isBlock = False 380 | if toks[idx] == '(': 381 | isBlock = True 382 | idx += 1 383 | 384 | if toks[idx] == 'select': 385 | idx, sql = parse_sql(toks, idx, tables_with_alias, schema) 386 | table_units.append((TABLE_TYPE['sql'], sql)) 387 | else: 388 | if idx < len_ and toks[idx] == 'join': 389 | idx += 1 # skip join 390 | idx, table_unit, table_name = parse_table_unit(toks, idx, tables_with_alias, schema) 391 | table_units.append((TABLE_TYPE['table_unit'],table_unit)) 392 | default_tables.append(table_name) 393 | if idx < len_ and toks[idx] == "on": 394 | idx += 1 # skip on 395 | idx, this_conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables) 396 | if len(conds) > 0: 397 | conds.append('and') 398 | conds.extend(this_conds) 399 | 400 | if isBlock: 401 | assert toks[idx] == ')' 402 | idx += 1 403 | if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): 404 | break 405 | 406 | return idx, table_units, conds, default_tables 407 | 408 | 409 | def parse_where(toks, start_idx, tables_with_alias, schema, default_tables): 410 | idx = start_idx 411 | len_ = len(toks) 412 | 413 | if idx >= len_ or toks[idx] != 'where': 414 | return idx, [] 415 | 416 | idx += 1 417 | idx, conds = parse_condition(toks, idx, 
tables_with_alias, schema, default_tables)
418 | return idx, conds
419 |
420 |
421 | def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables):
422 | idx = start_idx
423 | len_ = len(toks)
424 | col_units = []
425 |
426 | if idx >= len_ or toks[idx] != 'group':
427 | return idx, col_units
428 |
429 | idx += 1
430 | assert toks[idx] == 'by'
431 | idx += 1
432 |
433 | while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
434 | idx, col_unit = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
435 | col_units.append(col_unit)
436 | if idx < len_ and toks[idx] == ',':
437 | idx += 1  # skip ','
438 | else:
439 | break
440 |
441 | return idx, col_units
442 |
443 |
444 | def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables):
445 | idx = start_idx
446 | len_ = len(toks)
447 | val_units = []
448 | order_type = 'asc'  # default type is 'asc'
449 |
450 | if idx >= len_ or toks[idx] != 'order':
451 | return idx, val_units
452 |
453 | idx += 1
454 | assert toks[idx] == 'by'
455 | idx += 1
456 |
457 | while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
458 | idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
459 | val_units.append(val_unit)
460 | if idx < len_ and toks[idx] in ORDER_OPS:
461 | order_type = toks[idx]
462 | idx += 1
463 | if idx < len_ and toks[idx] == ',':
464 | idx += 1  # skip ','
465 | else:
466 | break
467 |
468 | return idx, (order_type, val_units)
469 |
470 |
471 | def parse_having(toks, start_idx, tables_with_alias, schema, default_tables):
472 | idx = start_idx
473 | len_ = len(toks)
474 |
475 | if idx >= len_ or toks[idx] != 'having':
476 | return idx, []
477 |
478 | idx += 1
479 | idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
480 | return idx, conds
481 |
482 |
483 | def parse_limit(toks, start_idx):
484 | idx = start_idx
485 | len_ = len(toks)
486 |
487 | if idx < len_ and toks[idx] == 'limit':
488 | idx += 2
489 | # limit values are not compared literally: toks are strings, so this branch always returns 1 as a placeholder limit (the int() path below is effectively unreachable)
490 | if type(toks[idx-1]) != int:
491 | return idx, 1
492 |
493 | return idx, int(toks[idx-1])
494 |
495 | return idx, None
496 |
497 |
498 | def parse_sql(toks, start_idx, tables_with_alias, schema):
499 | isBlock = False  # indicate whether this is a block of sql/sub-sql
500 | len_ = len(toks)
501 | idx = start_idx
502 |
503 | sql = {}
504 | if toks[idx] == '(':
505 | isBlock = True
506 | idx += 1
507 |
508 | # parse from clause in order to get default tables
509 | from_end_idx, table_units, conds, default_tables = parse_from(toks, start_idx, tables_with_alias, schema)
510 | sql['from'] = {'table_units': table_units, 'conds': conds}
511 | # select clause
512 | _, select_col_units = parse_select(toks, idx, tables_with_alias, schema, default_tables)
513 | idx = from_end_idx
514 | sql['select'] = select_col_units
515 | # where clause
516 | idx, where_conds = parse_where(toks, idx, tables_with_alias, schema, default_tables)
517 | sql['where'] = where_conds
518 | # group by clause
519 | idx, group_col_units = parse_group_by(toks, idx, tables_with_alias, schema, default_tables)
520 | sql['groupBy'] = group_col_units
521 | # having clause
522 | idx, having_conds = parse_having(toks, idx, tables_with_alias, schema, default_tables)
523 | sql['having'] = having_conds
524 | # order by clause
525 | idx, order_col_units = parse_order_by(toks, idx, tables_with_alias, schema, default_tables)
526 | sql['orderBy'] =
--------------------------------------------------------------------------------
/src/nl2sql360/evaluator/ves.py:
--------------------------------------------------------------------------------
1 | from .bird_eval.bird_ves import run_sqls_parallel, sort_results
2 | import os
3 | import math
4 | from loguru import logger
5 | 
6 | 
7 | class VesEvaluator:
8 | 
9 |     def __init__(self, reuse_ex):
10 |         self.reuse_ex = reuse_ex
11 | 
12 |     def evaluate(self, gold_sqls, pred_sqls, db_ids, db_dir, **kwds):
13 |         query_pairs = list(zip(pred_sqls, gold_sqls))
14 |         db_places = [os.path.join(db_dir, db_id, f"{db_id}.sqlite") for db_id in db_ids]
15 |         exec_acc_list = kwds.get("exec_acc_list", None)
16 |         if self.reuse_ex and exec_acc_list is None:
17 |             logger.warning("VES evaluator is set to reuse the EX result, but it has not been passed in.")
18 |         ves_result = run_sqls_parallel(
19 |             sqls=query_pairs,
20 |             db_places=db_places,
21 |             num_cpus=kwds.get("num_processes", 8),
22 |             meta_time_out=kwds.get("timeout", 30),
23 |             exec_acc_list=exec_acc_list
24 |         )
25 |         ves_result = sort_results(ves_result)
26 |         ves_result = [math.sqrt(res['time_ratio']) for res in ves_result]
27 |         return {
28 |             "ves": ves_result
29 |         }
30 | 
31 |     def get_eval_metrics(self):
32 |         return ["ves"]
33 | 
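The per-sample value computed on line 26 follows the BIRD Valid Efficiency Score (VES): in the BIRD-style implementation this module wraps, `time_ratio` is the gold query's execution time divided by the predicted query's when both return the same result set, and 0 otherwise, so the reported score is its square root. A toy recomputation with hypothetical timings:

    import math

    # Hypothetical timings for one sample whose result set matches the gold result;
    # a mismatching result would yield time_ratio = 0.0 and hence a VES of 0.
    gold_time, pred_time = 0.08, 0.02       # seconds
    time_ratio = gold_time / pred_time      # 4.0: the prediction runs 4x faster
    per_sample_ves = math.sqrt(time_ratio)  # 2.0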
--------------------------------------------------------------------------------
/src/nl2sql360/filter/__init__.py:
--------------------------------------------------------------------------------
1 | from .filter import Operator, Field, Filter, Scenario, parse_filter, parse_scenario, serialize_filter, serialize_scenario
2 | 
3 | 
4 | __all__ = [
5 |     "Operator",
6 |     "Field",
7 |     "Filter",
8 |     "Scenario",
9 |     "parse_filter",
10 |     "parse_scenario",
11 |     "serialize_filter",
12 |     "serialize_scenario"
13 | ]
--------------------------------------------------------------------------------
/src/nl2sql360/filter/filter.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from enum import Enum
3 | from typing import List, Optional
4 | import re
5 | 
6 | 
7 | class Operator(Enum):
8 |     GT = ">"
9 |     LT = "<"
10 |     EQ = "="
11 | 
12 | 
13 | class Field(Enum):
14 |     QUERY_FIELDS = "QUERY_FIELDS"
15 |     GROUP_BY = "GROUP_BY"
16 |     ORDER_BY = "ORDER_BY"
17 |     LIMIT = "LIMIT"
18 |     JOIN = "JOIN"
19 |     PREDICATE = "PREDICATE"
20 |     AGGREGATION = "AGGREGATION"
21 |     SCALAR_FUNCTION = "SCALAR_FUNCTION"
22 |     SUBQUERY = "SUBQUERY"
23 |     SET_OPERATION = "SET_OPERATION"
24 |     MATH_COMPUTE = "MATH_COMPUTE"
25 |     LOGICAL_CONNECTOR = "LOGICAL_CONNECTOR"
26 |     DISTINCT = "DISTINCT"
27 |     LIKE = "LIKE"
28 |     CONTROL_FLOW = "CONTROL_FLOW"
29 |     WINDOW = "WINDOW"
30 | 
31 | 
32 | @dataclass
33 | class Filter:
34 |     name: str
35 |     field: Field
36 |     operator: Operator
37 |     value: int
38 | 
39 | 
40 | _SCENARIO_CONNECTOR = "&&"
41 | 
42 | @dataclass
43 | class Scenario:
44 |     name: str
45 |     filters: List[Filter]
46 | 
47 | 
48 | def parse_filter(filter_name: str, filter_expression: str) -> Optional[Filter]:
49 |     pattern = re.compile(r'^(?P<field>{})\s*(?P<op>{})\s*(?P<value>\d+)$'.format(
50 |         '|'.join([field.value for field in Field]),
51 |         '|'.join([op.value for op in Operator])
52 |     ))
53 |     match = pattern.match(filter_expression)
54 |     if match:
55 |         field = match.group('field')
56 |         op = match.group('op')
57 |         value = match.group('value')
58 |         return Filter(
59 |             name=filter_name,
60 |             field=Field(field),
61 |             operator=Operator(op),
62 |             value=int(value)
63 |         )
64 |     else:
65 |         return None
66 | 
67 | 
68 | def parse_scenario(scenario_name: str, scenario_str: str) -> Optional[Scenario]:
69 |     filter_expressions = scenario_str.split(_SCENARIO_CONNECTOR)
70 |     filters = []
71 |     for idx, exp in enumerate(filter_expressions):
72 |         filter = parse_filter(filter_name=f"{scenario_name}-Filter-{idx}", filter_expression=exp.strip())
73 |         if filter is None:
74 |             return None
75 |         else:
76 |             filters.append(filter)
77 |     return Scenario(name=scenario_name, filters=filters)
78 | 
79 | 
80 | def map_field_to_database_col(field: Field) -> str:
81 |     return "count_" + field.value.lower()
82 | 
83 | 
84 | def serialize_filter(filter: Filter) -> str:
85 |     return f"{map_field_to_database_col(filter.field)} {filter.operator.value} {filter.value}"
86 | 
87 | 
88 | def serialize_scenario(scenario: Scenario) -> str:
89 |     return " AND ".join([serialize_filter(filter) for filter in scenario.filters])
90 | 
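A scenario string is a `&&`-joined list of filter expressions, each of the form `FIELD OP INT` over the per-query counts stored by the framework. A small round-trip sketch (the scenario name and expression are made up for illustration):

    scenario = parse_scenario("MultiJoin", "JOIN > 1 && SUBQUERY = 0")
    assert scenario is not None
    print(serialize_scenario(scenario))
    # -> count_join > 1 AND count_subquery = 0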
--------------------------------------------------------------------------------
/src/nl2sql360/parser/__init__.py:
--------------------------------------------------------------------------------
1 | from .sql_parser import SQLParser
2 | 
3 | 
4 | __all__ = [
5 |     "SQLParser"
6 | ]
--------------------------------------------------------------------------------
/src/nl2sql360/parser/sql_parser.py:
--------------------------------------------------------------------------------
1 | from sqlglot import parse_one, exp
2 | 
3 | 
4 | class SQLParser:
5 | 
6 |     _SET_KEYWORDS = (exp.Union, exp.Except, exp.Intersect)
7 | 
8 |     _SCALAR_KEYWORDS = (exp.Abs, exp.Length, exp.Cast, exp.Round, exp.Upper, exp.Lower, exp.Rand)
9 |     _SCALAR_KEYWORDS_ANONYMOUS_STR = ("STRFTIME", "JULIANDAY", "NOW", "INSTR", "SUBSTR")
10 | 
11 |     _MATH_COMPUTE_KEYWORDS = (exp.Add, exp.Sub, exp.Mul, exp.Div, exp.Mod)
12 | 
13 |     _LOGICAL_CONNECTOR_KEYWORDS = (exp.And, exp.Or)
14 | 
15 |     _CONTROL_FLOW_KEYWORDS = (exp.Case,)
16 |     _CONTROL_FLOW_KEYWORDS_ANONYMOUS_STR = ("IIF",)  # trailing commas make these one-element tuples, so the membership tests below are exact
17 | 
18 |     def __init__(self, sql, dialect="sqlite"):
19 |         self.ast = parse_one(sql, dialect=dialect)
20 | 
21 |     @property
22 |     def count_query_fields(self):
23 |         _ast = self.ast
24 |         while isinstance(_ast, self._SET_KEYWORDS):
25 |             _ast = _ast.this
26 |         assert isinstance(_ast, exp.Select)
27 |         return len(_ast.expressions)
28 | 
29 |     @property
30 |     def count_group_by(self):
31 |         return len(list(self.ast.find_all(exp.Group)))
32 | 
33 |     @property
34 |     def count_order_by(self):
35 |         return len(list(self.ast.find_all(exp.Order)))
36 | 
37 |     @property
38 |     def count_limit(self):
39 |         return len(list(self.ast.find_all(exp.Limit)))
40 | 
41 |     @property
42 |     def count_join(self):
43 |         return len(list(self.ast.find_all(exp.Join)))
44 | 
45 |     @property
46 |     def count_predicate(self):
47 |         return len(list(self.ast.find_all(exp.Predicate)))
48 | 
49 |     @property
50 |     def count_aggregation(self):
51 |         return len(list(self.ast.find_all(exp.AggFunc)))
52 | 
53 |     @property
54 |     def count_scalar_function(self):
55 |         scalar_nodes = list(self.ast.find_all(self._SCALAR_KEYWORDS))
56 |         scalar_nodes.extend([node for node in self.ast.find_all(exp.Anonymous) if node.this.upper() in self._SCALAR_KEYWORDS_ANONYMOUS_STR])
57 |         return len(scalar_nodes)
58 | 
59 |     @property
60 |     def count_subquery(self):
61 |         return len(list(self.ast.find_all(exp.Subquery)))
62 | 
63 |     @property
64 |     def count_set_operation(self):
65 |         return len(list(self.ast.find_all(self._SET_KEYWORDS)))
66 | 
67 |     @property
68 |     def count_math_compute(self):
69 |         return len(list(self.ast.find_all(self._MATH_COMPUTE_KEYWORDS)))
70 | 
71 |     @property
72 |     def count_logical_connector(self):
73 |         return len(list(self.ast.find_all(self._LOGICAL_CONNECTOR_KEYWORDS)))
74 | 
75 |     @property
76 |     def count_distinct(self):
77 |         return len(list(self.ast.find_all(exp.Distinct)))
78 | 
79 |     @property
80 |     def count_like(self):
81 |         return len(list(self.ast.find_all(exp.Like)))
82 | 
83 |     @property
84 |     def count_control_flow(self):
85 |         control_flow_nodes = list(self.ast.find_all(self._CONTROL_FLOW_KEYWORDS))
86 |         control_flow_nodes.extend([node for node in self.ast.find_all(exp.Anonymous) if node.this.upper() in self._CONTROL_FLOW_KEYWORDS_ANONYMOUS_STR])
87 |         return len(control_flow_nodes)
88 | 
89 |     @property
90 |     def count_window(self):
91 |         return len(list(self.ast.find_all(exp.Window)))
92 | 
--------------------------------------------------------------------------------
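Each `count_*` property corresponds to one `Field` in filter/filter.py (via `map_field_to_database_col`), so this parser supplies the per-query characteristics that scenarios filter on. An illustrative sketch; the query is hypothetical and the expected counts assume sqlglot's usual AST node types:

    parser = SQLParser(
        "SELECT s.name, COUNT(*) FROM singer s "
        "JOIN concert c ON s.id = c.singer_id "
        "GROUP BY s.name HAVING COUNT(*) > 1"
    )
    parser.count_query_fields  # 2 (s.name and COUNT(*))
    parser.count_join          # 1
    parser.count_group_by      # 1
    parser.count_aggregation   # 2 (both COUNT(*) occurrences)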