├── src └── nl2sql360 │ ├── __init__.py │ ├── evaluator │ ├── bird_eval │ │ ├── __init__.py │ │ ├── evaluation_utils.py │ │ ├── bird_ex.py │ │ ├── bird_ves.py │ │ ├── bird_rves.py │ │ └── evaluation_f1.py │ ├── test_suite_sql_eval │ │ ├── __init__.py │ │ ├── parse.py │ │ ├── exec_eval.py │ │ └── process_sql.py │ ├── __init__.py │ ├── f1.py │ ├── bird_ex.py │ ├── rves.py │ ├── ves.py │ └── spider_ex_em.py │ ├── cli │ ├── __init__.py │ ├── util.py │ └── cli.py │ ├── core │ ├── __init__.py │ └── core.py │ ├── parser │ ├── __init__.py │ └── sql_parser.py │ ├── dataset │ ├── __init__.py │ └── dataset.py │ ├── filter │ ├── __init__.py │ └── filter.py │ ├── arguments │ ├── __init__.py │ ├── core_args.py │ ├── delete_history_args.py │ ├── parser.py │ ├── dataset_args.py │ ├── report_args.py │ ├── evaluation_args.py │ └── hf_argparser.py │ └── database │ ├── util.py │ ├── __init__.py │ ├── template.py │ └── model.py ├── MANIFEST.in ├── assets ├── QVT.png ├── Boxplot.png ├── QVT_New.png ├── domain.png ├── nl2sql360.png ├── sql_charac.png ├── BIRD_Heatmap.png ├── SQLiteStudio.png ├── leaderboard.png ├── Spider_Heatmap.png ├── DB_Domain_Boxplot.png └── DB_Domain_Heatmap.png ├── .gitignore ├── requirements.txt ├── examples ├── py_examples │ ├── evaluation.py │ ├── delete_history.py │ ├── dataset_import.py │ └── report.py └── cli_examples │ ├── spider │ ├── delete_history.yaml │ ├── evaluation.yaml │ ├── dataset_spider.yaml │ └── report.yaml │ └── bird │ ├── bird_dev_evaluate.yaml │ ├── bird_test_evaluate.yaml │ ├── bird_minidev_evaluate.yaml │ ├── bird_dev_dataset.yaml │ ├── bird_test_dataset.yaml │ ├── bird_minidev_dataset.yaml │ ├── bird_dev_report.yaml │ ├── bird_test_report.yaml │ └── bird_minidev_report.yaml ├── LICENSE ├── setup.py └── README.md /src/nl2sql360/__init__.py: -------------------------------------------------------------------------------- 1 | VERSION = "1.1.0" -------------------------------------------------------------------------------- /src/nl2sql360/evaluator/bird_eval/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/nl2sql360/cli/__init__.py: -------------------------------------------------------------------------------- 1 | from .cli import main -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include assets/* 2 | include examples/* 3 | -------------------------------------------------------------------------------- /src/nl2sql360/evaluator/test_suite_sql_eval/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/QVT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUSTDial/NL2SQL360/HEAD/assets/QVT.png -------------------------------------------------------------------------------- /assets/Boxplot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUSTDial/NL2SQL360/HEAD/assets/Boxplot.png -------------------------------------------------------------------------------- /assets/QVT_New.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HKUSTDial/NL2SQL360/HEAD/assets/QVT_New.png -------------------------------------------------------------------------------- /assets/domain.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUSTDial/NL2SQL360/HEAD/assets/domain.png -------------------------------------------------------------------------------- /assets/nl2sql360.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUSTDial/NL2SQL360/HEAD/assets/nl2sql360.png -------------------------------------------------------------------------------- /assets/sql_charac.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUSTDial/NL2SQL360/HEAD/assets/sql_charac.png -------------------------------------------------------------------------------- /assets/BIRD_Heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUSTDial/NL2SQL360/HEAD/assets/BIRD_Heatmap.png -------------------------------------------------------------------------------- /assets/SQLiteStudio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUSTDial/NL2SQL360/HEAD/assets/SQLiteStudio.png -------------------------------------------------------------------------------- /assets/leaderboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUSTDial/NL2SQL360/HEAD/assets/leaderboard.png -------------------------------------------------------------------------------- /src/nl2sql360/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import Core 2 | 3 | 4 | __all__ = [ 5 | "Core" 6 | ] -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/data 2 | **/.vscode 3 | **/__pycache__ 4 | **/dist 5 | **/build 6 | **/*.egg-info 7 | **/tests -------------------------------------------------------------------------------- /assets/Spider_Heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUSTDial/NL2SQL360/HEAD/assets/Spider_Heatmap.png -------------------------------------------------------------------------------- /assets/DB_Domain_Boxplot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUSTDial/NL2SQL360/HEAD/assets/DB_Domain_Boxplot.png -------------------------------------------------------------------------------- /assets/DB_Domain_Heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUSTDial/NL2SQL360/HEAD/assets/DB_Domain_Heatmap.png -------------------------------------------------------------------------------- /src/nl2sql360/parser/__init__.py: -------------------------------------------------------------------------------- 1 | from .sql_parser import SQLParser 2 | 3 | 4 | __all__ = [ 5 | "SQLParser" 6 | ] -------------------------------------------------------------------------------- /src/nl2sql360/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset 
import NL2SQLDataset 2 | 3 | 4 | __all__ = [ 5 | "NL2SQLDataset" 6 | ] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | psycopg2-binary 2 | pymysql 3 | func-timeout 4 | pydantic 5 | sqlglot[rs] 6 | tqdm 7 | loguru 8 | sqlalchemy 9 | pandas 10 | pyyaml 11 | nltk 12 | sqlparse -------------------------------------------------------------------------------- /src/nl2sql360/filter/__init__.py: -------------------------------------------------------------------------------- 1 | from .filter import Operator, Field, Filter, Scenario, parse_filter, parse_scenario, serialize_filter, serialize_scenario 2 | 3 | 4 | __all__ = [ 5 | "Operator", 6 | "Field", 7 | "Filter", 8 | "Scenario", 9 | "parse_filter", 10 | "parse_scenario", 11 | "serialize_filter", 12 | "serialize_scenario" 13 | ] -------------------------------------------------------------------------------- /src/nl2sql360/evaluator/__init__.py: -------------------------------------------------------------------------------- 1 | from .bird_ex import BirdEXEvaluator 2 | from .spider_ex_em import SpiderEXEMEvaluator 3 | from .ves import VesEvaluator 4 | from .rves import RVesEvaluator 5 | from .f1 import F1Evaluator 6 | 7 | 8 | __all__ = [ 9 | "BirdEXEvaluator", 10 | "SpiderEXEMEvaluator", 11 | "VesEvaluator", 12 | "RVesEvaluator", 13 | "F1Evaluator" 14 | ] -------------------------------------------------------------------------------- /examples/py_examples/evaluation.py: -------------------------------------------------------------------------------- 1 | from nl2sql360.core import Core 2 | from nl2sql360.arguments import CoreArguments, EvaluationArguments 3 | 4 | if __name__ == "__main__": 5 | core_args = CoreArguments() 6 | 7 | core = Core(core_args) 8 | 9 | evaluation_args = EvaluationArguments( 10 | eval_name="C3SQL", 11 | eval_dataset="spider_dev", 12 | eval_metrics=["ex", "em", "ves"], 13 | pred_sqls_file="./SuperSQL.sql", 14 | enable_spider_eval=True 15 | ) 16 | 17 | core.evaluate(evaluation_args) 18 | -------------------------------------------------------------------------------- /examples/py_examples/delete_history.py: -------------------------------------------------------------------------------- 1 | from nl2sql360.core import Core 2 | from nl2sql360.arguments import CoreArguments, EvaluationArguments 3 | from nl2sql360.filter import Filter, Scenario, Field, Operator 4 | 5 | if __name__ == "__main__": 6 | core_args = CoreArguments() 7 | 8 | core = Core(core_args) 9 | 10 | core.delete_evaluation_history( 11 | dataset_name="spider_dev", 12 | eval_name="C3SQL" 13 | ) 14 | 15 | core.delete_dataset_history( 16 | dataset_name="spider_dev", 17 | delete_relavant_evaluations=True 18 | ) 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /examples/py_examples/dataset_import.py: -------------------------------------------------------------------------------- 1 | from nl2sql360.core import Core 2 | from nl2sql360.arguments import CoreArguments, DatasetArguments 3 | 4 | core_args = CoreArguments() 5 | 6 | core = Core(core_args) 7 | 8 | dataset_args = DatasetArguments( 9 | dataset_name="spider_dev", 10 | dataset_dir="../data/spider", 11 | samples_file="dev.json", 12 | database_dir="database", 13 | tables_file="tables.json", 14 | question_key="question", 15 | sql_key="query", 16 | db_id_key="db_id", 17 | sql_complexity_key=None, 18 | database_domain_file=None 19 | ) 
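# Note (illustrative, not part of the original example): when `database_domain_file`
# is provided, NL2SQLDataset expects it to be a JSON file inside `dataset_dir` that
# maps each database id to a domain label, e.g. {"concert_singer": "music",
# "pets_1": "animals"}; the ids and domain names here are placeholder values.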
20 | 21 | core.import_dataset(dataset_args) 22 | 23 | 24 | -------------------------------------------------------------------------------- /src/nl2sql360/arguments/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset_args import DatasetArguments 2 | from .core_args import CoreArguments 3 | from .evaluation_args import EvaluationArguments 4 | from .report_args import ReportArguments 5 | from .delete_history_args import DeleteHistoryArguments 6 | from .parser import get_dataset_import_args, get_evaluation_args, get_delete_history_args, get_report_args 7 | 8 | 9 | __all__ = [ 10 | "DatasetArguments", 11 | "CoreArguments", 12 | "EvaluationArguments", 13 | "ReportArguments", 14 | "DeleteHistoryArguments", 15 | "get_dataset_import_args", 16 | "get_evaluation_args", 17 | "get_report_args", 18 | "get_delete_history_args" 19 | ] -------------------------------------------------------------------------------- /src/nl2sql360/arguments/core_args.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from pathlib import Path 3 | from typing import Optional, Literal 4 | 5 | 6 | @dataclass 7 | class CoreArguments: 8 | r""" 9 | Arguments for NL2SQL360-core. 10 | """ 11 | 12 | core_dir: str = field( 13 | default="./nl2sql360", 14 | metadata={"help": "The directory for NL2SQL360-core storage."} 15 | ) 16 | 17 | core_name: str = field( 18 | default="nl2sql360", 19 | metadata={"help": "The name of NL2SQL360-core"} 20 | ) 21 | 22 | sql_dialect: str = field( 23 | default="SQLite", 24 | metadata={"help": "Specify SQL dialect (e.g., sqlite) to parse."} 25 | ) 26 | 27 | def __post_init__(self): 28 | if self.sql_dialect not in ["SQLite", "MySQL", "PostgreSQL"]: 29 | raise ValueError("`sql_dialect` must be one of `SQLite`, `MySQL` and `PostgreSQL`.") 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Boyan Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /examples/cli_examples/spider/delete_history.yaml: -------------------------------------------------------------------------------- 1 | # ---------------- Core Arguments ---------------- 2 | 3 | # The directory to save NL2SQL360 core data. 
4 | core_dir: "data" 5 | 6 | # The NL2SQL360 core name, such that NL2SQL360 core data is saved to "core_dir/core_name.sqlite". 7 | core_name: "nl2sql360" 8 | 9 | # The dataset SQL dialect, "SQLite" by default. 10 | sql_dialect: "SQLite" 11 | 12 | # ---------------- Delete History Arguments ---------------- 13 | 14 | # Notes: 15 | # If `dataset_name` is not set to null and `eval_name` is set to null, it will delete the dataset history. 16 | # If both `dataset_name` and `eval_name` are set, it will delete the specified evaluations on the corresponding dataset. 17 | 18 | # The imported dataset name to delete; set to null if you do not want to delete dataset history. 19 | dataset_name: "spider_dev" 20 | 21 | # Whether to also delete the evaluations associated with the deleted dataset. 22 | delete_dataset_evaluations: True 23 | 24 | # The evaluation(s) to delete which have been imported in NL2SQL360; set to null if you do not want to delete evaluation history. 25 | # A list can also be used, where each item contains the two keys `dataset` and `evaluation`. 26 | eval_name: null -------------------------------------------------------------------------------- /examples/cli_examples/bird/bird_dev_evaluate.yaml: -------------------------------------------------------------------------------- 1 | # ---------------- Core Arguments ---------------- 2 | 3 | # The directory to save NL2SQL360 core data 4 | core_dir: "./nl2sql360_cache" 5 | 6 | # The NL2SQL360 core name, such that NL2SQL360 core data is saved to "core_dir/core_name.sqlite" 7 | core_name: "nl2sql360" 8 | 9 | # The dataset SQL dialect, "SQLite" by default 10 | sql_dialect: "SQLite" 11 | 12 | # ---------------- Evaluation Arguments ---------------- 13 | 14 | # The unique evaluation name to be saved in NL2SQL360 15 | eval_name: "SuperSQL" 16 | 17 | # The dataset name which has been imported in NL2SQL360 18 | eval_dataset: "bird_dev" 19 | 20 | # The evaluation metrics to compute. Supported metrics: 21 | # "ex": "Execution Accuracy" 22 | # "em": "Exact-Match Accuracy" 23 | # "ves": "Valid Efficiency Score" 24 | # "rves": "Reward-based Valid Efficiency Score" 25 | # "f1": "Soft-F1 Score" 26 | eval_metrics: 27 | - "ex" 28 | - "ves" 29 | - "rves" 30 | - "f1" 31 | 32 | # The model prediction file for the dataset, containing one predicted SQL per line. 33 | pred_sqls_file: "SuperSQL.sql" 34 | 35 | # Whether to enable the Spider official evaluation script, generally set to True if the dataset is Spider or a Spider variant (e.g., Spider-Syn).
36 | enable_spider_eval: False -------------------------------------------------------------------------------- /examples/cli_examples/spider/evaluation.yaml: -------------------------------------------------------------------------------- 1 | # ---------------- Core Arguments ---------------- 2 | 3 | # The directory to save NL2SQL360 core data 4 | core_dir: "data" 5 | 6 | # The NL2SQL360 core name, such that NL2SQL360 core data is saved to "core_dir/core_name.sqlite" 7 | core_name: "nl2sql360" 8 | 9 | # The dataset SQL dialect, "SQLite" by default 10 | sql_dialect: "SQLite" 11 | 12 | # ---------------- Evaluation Arguments ---------------- 13 | 14 | # The unique evaluation name to be saved in NL2SQL360 15 | eval_name: "SuperSQL" 16 | 17 | # The dataset name which has been imported in NL2SQL360 18 | eval_dataset: "spider_dev" 19 | 20 | # The evaluation metrics to compute. Supported metrics: 21 | # "ex": "Execution Accuracy" 22 | # "em": "Exact-Match Accuracy" 23 | # "ves": "Valid Efficiency Score" 24 | # "rves": "Reward-based Valid Efficiency Score" 25 | # "f1": "Soft-F1 Score" 26 | eval_metrics: 27 | - "ex" 28 | - "em" 29 | - "ves" 30 | - "rves" 31 | - "f1" 32 | 33 | # The model prediction file for the dataset, containing one predicted SQL per line. 34 | pred_sqls_file: "tests/SuperSQL.sql" 35 | 36 | # Whether to enable the Spider official evaluation script, generally set to True if the dataset is Spider or a Spider variant (e.g., Spider-Syn). 37 | enable_spider_eval: True -------------------------------------------------------------------------------- /src/nl2sql360/arguments/delete_history_args.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from pathlib import Path 3 | from typing import Optional, Literal, List, Union, Dict 4 | from ..filter import parse_filter, parse_scenario 5 | 6 | 7 | @dataclass 8 | class DeleteHistoryArguments: 9 | r""" 10 | Arguments for deleting datasets and evaluations.
11 | """ 12 | 13 | dataset_name: Optional[str] = field( 14 | default=None, 15 | metadata={"help": "The dataset (list) to delete."} 16 | ) 17 | 18 | delete_dataset_evaluations: Optional[bool] = field( 19 | default=True, 20 | metadata={"help": "Whether to delete relevant evaluations for dataset."} 21 | ) 22 | 23 | eval_name: Optional[str] = field( 24 | default=None, 25 | metadata={"help": "The evaluation (list) to delete, each item contains two keys `dataset` and `evaluation`."} 26 | ) 27 | 28 | def __post_init__(self): 29 | if self.dataset_name is None and self.eval_name is None: 30 | raise ValueError("`dataset_name` and `eval_name` cannot be empty at the same time.") 31 | 32 | if self.dataset_name is None and self.eval_name is not None: 33 | raise ValueError("Need to specify `dataset_name` for `eval_name` evaluation.") -------------------------------------------------------------------------------- /examples/cli_examples/bird/bird_test_evaluate.yaml: -------------------------------------------------------------------------------- 1 | # ---------------- Core Arguments ---------------- 2 | 3 | # The directory to save NL2SQL360 core data 4 | core_dir: "./nl2sql360_cache" 5 | 6 | # The NL2SQL360 core name, such that NL2SQL360 core data is saved to "core_dir/core_name.sqlite" 7 | core_name: "nl2sql360" 8 | 9 | # The dataset SQL dialect, "SQLite" by default 10 | sql_dialect: "SQLite" 11 | 12 | # ---------------- Evaluation Arguments ---------------- 13 | 14 | # The unique evaluation name to be saved in NL2SQL360 15 | eval_name: "TA_SQL_GPT4o_COT" 16 | 17 | # The dataset name which has been imported in NL2SQL360 18 | eval_dataset: "bird_test" 19 | 20 | # The evaluation metrics, supporting three different metrics: 21 | # "ex": "Execution Accuracy" 22 | # "em": "Exact-Match Accuracy" 23 | # "ves": "Valid Efficiency Score" 24 | # "rves": "Reward-based Valid Efficiency Score" 25 | # "f1": "Soft-F1 Score" 26 | eval_metrics: 27 | - "ex" 28 | - "ves" 29 | - "rves" 30 | - "f1" 31 | 32 | # The model predited file in the dataset, containing predited sqls in each line. 33 | pred_sqls_file: "./TA_SQL_GPT4o_COT.sql" 34 | 35 | # Whether to enable Spider offcial evaluation script, generally set to True if the dataset is Spider or Spider series (e.g., Spider-Syn). 
36 | enable_spider_eval: False 37 | -------------------------------------------------------------------------------- /examples/cli_examples/bird/bird_minidev_evaluate.yaml: -------------------------------------------------------------------------------- 1 | # ---------------- Core Arguments ---------------- 2 | 3 | # The directory to save NL2SQL360 core data 4 | core_dir: "./nl2sql360_cache" 5 | 6 | # The NL2SQL360 core name, such that NL2SQL360 core data is saved to "core_dir/core_name.sqlite" 7 | core_name: "nl2sql360" 8 | 9 | # The dataset SQL dialect, "SQLite" by default 10 | sql_dialect: "SQLite" 11 | 12 | # ---------------- Evaluation Arguments ---------------- 13 | 14 | # The unique evaluation name to be saved in NL2SQL360 15 | eval_name: "TA_SQL_GPT4o_COT" 16 | 17 | # The dataset name which has been imported in NL2SQL360 18 | eval_dataset: "bird_minidev" 19 | 20 | # The evaluation metrics to compute. Supported metrics: 21 | # "ex": "Execution Accuracy" 22 | # "em": "Exact-Match Accuracy" 23 | # "ves": "Valid Efficiency Score" 24 | # "rves": "Reward-based Valid Efficiency Score" 25 | # "f1": "Soft-F1 Score" 26 | eval_metrics: 27 | - "ex" 28 | - "ves" 29 | - "rves" 30 | - "f1" 31 | 32 | # The model prediction file for the dataset, containing one predicted SQL per line. 33 | pred_sqls_file: "./TA_SQL_GPT4o_COT.sql" 34 | 35 | # Whether to enable the Spider official evaluation script, generally set to True if the dataset is Spider or a Spider variant (e.g., Spider-Syn). 36 | enable_spider_eval: False 37 | -------------------------------------------------------------------------------- /examples/cli_examples/spider/dataset_spider.yaml: -------------------------------------------------------------------------------- 1 | # ---------------- Core Arguments ---------------- 2 | 3 | # The directory to save NL2SQL360 core data 4 | core_dir: "data" 5 | 6 | # The NL2SQL360 core name, such that NL2SQL360 core data is saved to "core_dir/core_name.sqlite" 7 | core_name: "nl2sql360" 8 | 9 | # The dataset SQL dialect, "SQLite" by default 10 | sql_dialect: "SQLite" 11 | 12 | # ---------------- Dataset Arguments ---------------- 13 | 14 | # The unique dataset name to import 15 | dataset_name: "spider_dev" 16 | 17 | # The dataset root directory path 18 | dataset_dir: "data/spider" 19 | 20 | # The name of samples json file 21 | samples_file: "dev.json" 22 | 23 | # The name of tables json file 24 | tables_file: "tables.json" 25 | 26 | # The database directory in "dataset_dir" 27 | database_dir: "database" 28 | 29 | # The key name of NL question in samples json file 30 | question_key: "question" 31 | 32 | # The key name of gold sql in samples json file 33 | sql_key: "query" 34 | 35 | # The key name of database id in samples json file 36 | db_id_key: "db_id" 37 | 38 | # The key name of sql complexity in samples json file, set to null if not specified 39 | sql_complexity_key: null 40 | 41 | # The database domain json file (mapping database id to database domain) in "dataset_dir" 42 | database_domain_file: null 43 | -------------------------------------------------------------------------------- /examples/cli_examples/bird/bird_dev_dataset.yaml: -------------------------------------------------------------------------------- 1 | # ---------------- Core Arguments ---------------- 2 | 3 | # The directory to save NL2SQL360 core data 4 | core_dir: "./nl2sql360_cache" 5 | 6 | # The NL2SQL360 core name, such that NL2SQL360 core data is saved to "core_dir/core_name.sqlite" 7 | core_name: "nl2sql360" 8 | 9 | # The dataset
SQL dialect, "SQLite" by default 10 | sql_dialect: "SQLite" 11 | 12 | # ---------------- Dataset Arguments ---------------- 13 | 14 | # The unique dataset name to import 15 | dataset_name: "bird_dev" 16 | 17 | # The dataset root directory path 18 | dataset_dir: "../data/bird/dev" 19 | 20 | # The name of samples json file 21 | samples_file: "dev.json" 22 | 23 | # The name of tables json file 24 | tables_file: "dev_tables.json" 25 | 26 | # The database directory in "dataset_dir" 27 | database_dir: "dev_databases" 28 | 29 | # The key name of NL question in samples json file 30 | question_key: "question" 31 | 32 | # The key name of gold sql in samples json file 33 | sql_key: "SQL" 34 | 35 | # The key name of database id in samples json file 36 | db_id_key: "db_id" 37 | 38 | # The key name of sql complexity in samples json file, set to null if not specified 39 | sql_complexity_key: null 40 | 41 | # The database domain json file (mapping database id to database domain) in "dataset_dir" 42 | database_domain_file: null 43 | -------------------------------------------------------------------------------- /examples/cli_examples/bird/bird_test_dataset.yaml: -------------------------------------------------------------------------------- 1 | # ---------------- Core Arguments ---------------- 2 | 3 | # The directory to save NL2SQL360 core data 4 | core_dir: "./nl2sql360_cache" 5 | 6 | # The NL2SQL360 core name, such that NL2SQL360 core data is saved to "core_dir/core_name.sqlite" 7 | core_name: "nl2sql360" 8 | 9 | # The dataset SQL dialect, "SQLite" by default 10 | sql_dialect: "SQLite" 11 | 12 | # ---------------- Dataset Arguments ---------------- 13 | 14 | # The unique dataset name to import 15 | dataset_name: "bird_test" 16 | 17 | # The dataset root directory path 18 | dataset_dir: "../data/bird/test" 19 | 20 | # The name of samples json file 21 | samples_file: "test.json" 22 | 23 | # The name of tables json file 24 | tables_file: "test_tables.json" 25 | 26 | # The database directory in "dataset_dir" 27 | database_dir: "test_databases" 28 | 29 | # The key name of NL question in samples json file 30 | question_key: "question" 31 | 32 | # The key name of gold sql in samples json file 33 | sql_key: "SQL" 34 | 35 | # The key name of database id in samples json file 36 | db_id_key: "db_id" 37 | 38 | # The key name of sql complexity in samples json file, set to null if not specified 39 | sql_complexity_key: null 40 | 41 | # The database domain json file (mapping database id to database domain) in "dataset_dir" 42 | database_domain_file: null 43 | -------------------------------------------------------------------------------- /examples/cli_examples/bird/bird_minidev_dataset.yaml: -------------------------------------------------------------------------------- 1 | # ---------------- Core Arguments ---------------- 2 | 3 | # The directory to save NL2SQL360 core data 4 | core_dir: "./nl2sql360_cache" 5 | 6 | # The NL2SQL360 core name, such that NL2SQL360 core data is saved to "core_dir/core_name.sqlite" 7 | core_name: "nl2sql360" 8 | 9 | # The dataset SQL dialect, "SQLite" by default 10 | sql_dialect: "SQLite" 11 | 12 | # ---------------- Dataset Arguments ---------------- 13 | 14 | # The unique dataset name to import 15 | dataset_name: "bird_minidev" 16 | 17 | # The dataset root directory path 18 | dataset_dir: "../data/bird/minidev/MINIDEV" 19 | 20 | # The name of samples json file 21 | samples_file: "mini_dev_sqlite.json" 22 | 23 | # The name of tables json file 24 | tables_file: "dev_tables.json" 25 | 
26 | # The database directory in "dataset_dir" 27 | database_dir: "dev_databases" 28 | 29 | # The key name of NL question in samples json file 30 | question_key: "question" 31 | 32 | # The key name of gold sql in samples json file 33 | sql_key: "SQL" 34 | 35 | # The key name of database id in samples json file 36 | db_id_key: "db_id" 37 | 38 | # The key name of sql complexity in samples json file, set to null if not specified 39 | sql_complexity_key: null 40 | 41 | # The database domain json file (mapping database id to database domain) in "dataset_dir" 42 | database_domain_file: null 43 | -------------------------------------------------------------------------------- /src/nl2sql360/database/util.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Dict, Any 2 | from sqlalchemy import Engine 3 | from sqlalchemy.orm import Session 4 | from .model import DatasetInfo, MetaDataset, get_dataset_model 5 | 6 | 7 | 8 | def get_dataset_name_from_table_name(dataset_table_name: str) -> str: 9 | return dataset_table_name.split("DATASET_")[1] 10 | 11 | 12 | def get_dataset_name_and_evaluation_name_from_table_name(evaluation_table_name: str) -> str: 13 | splits = evaluation_table_name.split("_EVALUATION_") 14 | dataset_name = splits[0].split("DATASET_")[-1] 15 | evaluation_name = splits[1] 16 | return dataset_name, evaluation_name 17 | 18 | 19 | def get_dataset_info(db_engine: "Engine", dataset_name: str) -> Optional[DatasetInfo]: 20 | with Session(db_engine) as session: 21 | query_res = session.query(DatasetInfo).filter(DatasetInfo.dataset_name == dataset_name).all() 22 | if query_res: 23 | assert len(query_res) == 1 24 | return query_res[0] 25 | else: 26 | return None 27 | 28 | 29 | def get_dataset_samples(db_engine: "Engine", dataset_model: "MetaDataset") -> Optional[Dict]: 30 | with Session(db_engine) as session: 31 | query_res = session.query(dataset_model).order_by(dataset_model.id) 32 | 33 | return [{ 34 | "nlq": record.nlq, 35 | "gold": record.gold, 36 | "db_id": record.db_id, 37 | } for record in query_res] 38 | 39 | -------------------------------------------------------------------------------- /src/nl2sql360/evaluator/f1.py: -------------------------------------------------------------------------------- 1 | from .bird_eval.evaluation_f1 import run_sqls_parallel, sort_results 2 | import os 3 | 4 | 5 | class F1Evaluator: 6 | 7 | def __init__(self, sql_dialect="SQLite", **kwds) -> None: 8 | self.sql_dialect = sql_dialect 9 | self.db_host = kwds.get("db_host", None) 10 | self.db_port = kwds.get("db_port", None) 11 | self.db_name = kwds.get("db_name", None) 12 | self.user = kwds.get("db_user", None) 13 | self.password = kwds.get("db_password", None) 14 | 15 | def evaluate(self, gold_sqls, pred_sqls, db_ids, db_dir, **kwds): 16 | query_pairs = list(zip(pred_sqls, gold_sqls)) 17 | db_places = [os.path.join(db_dir, db_id, f"{db_id}.sqlite") for db_id in db_ids] 18 | exec_result = run_sqls_parallel( 19 | sqls=query_pairs, 20 | db_places=db_places, 21 | num_cpus=kwds.get("num_processes", 8), 22 | meta_time_out=kwds.get("timeout", 30), 23 | sql_dialect=self.sql_dialect, 24 | host=self.db_host, 25 | user=self.user, 26 | password=self.password, 27 | dbname=self.db_name, 28 | port=self.db_port 29 | ) 30 | exec_result = sort_results(exec_result) 31 | exec_result = [res['res'] for res in exec_result] 32 | return { 33 | "f1": exec_result 34 | } 35 | 36 | def get_eval_metrics(self): 37 | return ["f1"] 
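# --- Usage sketch (illustrative, not part of the original module) ---
# Shows how F1Evaluator.evaluate is driven: gold/predicted SQL pairs are executed
# against "db_dir/{db_id}/{db_id}.sqlite" and per-sample soft-F1 scores are
# returned under the "f1" key. The SQL strings, database id and paths below are
# placeholder assumptions.
if __name__ == "__main__":
    evaluator = F1Evaluator(sql_dialect="SQLite")
    scores = evaluator.evaluate(
        gold_sqls=["SELECT name FROM singer"],
        pred_sqls=["SELECT name FROM singer"],
        db_ids=["concert_singer"],
        db_dir="./data/spider/database",
        num_processes=4,  # forwarded via kwds, defaults to 8
        timeout=30,       # per-query timeout in seconds, defaults to 30
    )
    print(scores["f1"])   # list of per-sample soft-F1 values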
-------------------------------------------------------------------------------- /src/nl2sql360/evaluator/bird_ex.py: -------------------------------------------------------------------------------- 1 | from .bird_eval.bird_ex import run_sqls_parallel, sort_results 2 | import os 3 | 4 | 5 | class BirdEXEvaluator: 6 | 7 | def __init__(self, sql_dialect="SQLite", **kwds) -> None: 8 | self.sql_dialect = sql_dialect 9 | self.db_host = kwds.get("db_host", None) 10 | self.db_port = kwds.get("db_port", None) 11 | self.db_name = kwds.get("db_name", None) 12 | self.user = kwds.get("db_user", None) 13 | self.password = kwds.get("db_password", None) 14 | 15 | def evaluate(self, gold_sqls, pred_sqls, db_ids, db_dir, **kwds): 16 | query_pairs = list(zip(pred_sqls, gold_sqls)) 17 | db_places = [os.path.join(db_dir, db_id, f"{db_id}.sqlite") for db_id in db_ids] 18 | exec_result = run_sqls_parallel( 19 | sqls=query_pairs, 20 | db_places=db_places, 21 | num_cpus=kwds.get("num_processes", 8), 22 | meta_time_out=kwds.get("timeout", 30), 23 | sql_dialect=self.sql_dialect, 24 | host=self.db_host, 25 | user=self.user, 26 | password=self.password, 27 | dbname=self.db_name, 28 | port=self.db_port 29 | 30 | ) 31 | exec_result = sort_results(exec_result) 32 | exec_result = [res['res'] for res in exec_result] 33 | return { 34 | "exec_acc": exec_result 35 | } 36 | 37 | def get_eval_metrics(self): 38 | return ["exec_acc"] -------------------------------------------------------------------------------- /src/nl2sql360/database/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Base, DatasetInfo, MetaDataset, MetaEvaluation, get_dataset_model, get_evaluation_model 2 | from .util import (get_dataset_name_from_table_name, 3 | get_dataset_name_and_evaluation_name_from_table_name, 4 | get_dataset_info, 5 | get_dataset_samples) 6 | from .template import (METRIC_COL_MAPPING, 7 | QUERY_OVERALL_PERFORMANCE, 8 | QUERY_QVT_PERFORMANCE, 9 | QUERY_SUBSET_PERFORMANCE, 10 | QUERY_DATASET_SIZE, 11 | QUERY_DATASET_DOMAIN_DISTRIBUTION, 12 | QUERY_DATASET_SQL_KEYWORDS_DISTRIBUTION, 13 | DELETE_DATASET_TABLE, 14 | DELETE_EVALUATION_TABLE, 15 | DELETE_DATASET_INFO) 16 | 17 | 18 | __all__ = [ 19 | "Base", 20 | "DatasetInfo", 21 | "MetaDataset", 22 | "MetaEvaluation", 23 | "get_dataset_model", 24 | "get_evaluation_model", 25 | "get_dataset_name_from_table_name", 26 | "get_dataset_name_and_evaluation_name_from_table_name", 27 | "get_dataset_info", 28 | "get_dataset_samples", 29 | "METRIC_COL_MAPPING", 30 | "QUERY_OVERALL_PERFORMANCE", 31 | "QUERY_QVT_PERFORMANCE", 32 | "QUERY_SUBSET_PERFORMANCE", 33 | "QUERY_DATASET_SIZE", 34 | "QUERY_DATASET_DOMAIN_DISTRIBUTION", 35 | "QUERY_DATASET_SQL_KEYWORDS_DISTRIBUTION", 36 | "DELETE_DATASET_TABLE", 37 | "DELETE_EVALUATION_TABLE", 38 | "DELETE_DATASET_INFO" 39 | ] -------------------------------------------------------------------------------- /examples/py_examples/report.py: -------------------------------------------------------------------------------- 1 | from nl2sql360.core import Core 2 | from nl2sql360.arguments import CoreArguments, EvaluationArguments 3 | from nl2sql360.filter import Filter, Scenario, Field, Operator 4 | 5 | if __name__ == "__main__": 6 | core_args = CoreArguments() 7 | 8 | core = Core(core_args) 9 | 10 | SUBQUERY_FILTER = Filter( 11 | name="subquery", 12 | field=Field.SUBQUERY, 13 | operator=Operator.GT, 14 | value=0 15 | ) 16 | 17 | BI_SCENARIO = Scenario( 18 | name="BI", 19 | filters=[Filter('agg', 
Field.AGGREGATION, Operator.GT, 0), Filter('join', Field.JOIN, Operator.GT, 0)] 20 | ) 21 | 22 | print(core.query_overall_leaderboard(dataset_name="spider_dev", metric="ex")) 23 | 24 | print(core.query_filter_performance(dataset_name="spider_dev", metric="ex", filter=SUBQUERY_FILTER, eval_name="SuperSQL")) 25 | 26 | print(core.query_filter_leaderboard(dataset_name="spider_dev", metric="ex", filter=SUBQUERY_FILTER)) 27 | 28 | print(core.query_scenario_performance(dataset_name="spider_dev", metric="ex", eval_name="SuperSQL", scenario=BI_SCENARIO)) 29 | 30 | print(core.query_scenario_leaderboard(dataset_name="spider_dev", metric="ex", scenario=BI_SCENARIO)) 31 | 32 | print(core.query_dataset_domain_distribution(dataset_name="spider_dev")) 33 | 34 | print(core.generate_evaluation_report(dataset_name="spider_dev", 35 | filters=[SUBQUERY_FILTER], 36 | scenarios=[BI_SCENARIO], 37 | metrics=["ex", "em", "ves"])) 38 | -------------------------------------------------------------------------------- /src/nl2sql360/cli/util.py: -------------------------------------------------------------------------------- 1 | from ..arguments import ( 2 | get_dataset_import_args, 3 | get_evaluation_args, 4 | get_report_args, 5 | get_delete_history_args 6 | ) 7 | from ..core import Core 8 | import pandas as pd 9 | from loguru import logger 10 | from pathlib import Path 11 | 12 | 13 | def run_dataset_import(): 14 | core_args, dataset_args = get_dataset_import_args() 15 | Core(core_args).import_dataset(dataset_args) 16 | 17 | 18 | def run_evaluation(): 19 | core_args, evaluation_args = get_evaluation_args() 20 | Core(core_args).evaluate(evaluation_args) 21 | 22 | 23 | def run_report(): 24 | core_args, report_args = get_report_args() 25 | report = Core(core_args).generate_evaluation_report( 26 | dataset_name=report_args.report_dataset, 27 | filters=report_args.filter, 28 | scenarios=report_args.scenario, 29 | metrics=report_args.metric, 30 | eval_names=report_args.report_evaluation 31 | ) 32 | report.to_csv(report_args.save_path) 33 | logger.success(f"Save report in path `{Path(report_args.save_path).resolve()}` successfully.") 34 | 35 | 36 | def run_delete_history(): 37 | core_args, delete_history_args = get_delete_history_args() 38 | core = Core(core_args) 39 | if delete_history_args.dataset_name and not delete_history_args.eval_name: 40 | core.delete_dataset_history( 41 | dataset_name=delete_history_args.dataset_name, 42 | delete_relavant_evaluations=delete_history_args.delete_dataset_evaluations 43 | ) 44 | if delete_history_args.dataset_name and delete_history_args.eval_name: 45 | core.delete_evaluation_history( 46 | dataset_name=delete_history_args.dataset_name, 47 | eval_name=delete_history_args.eval_name 48 | ) 49 | -------------------------------------------------------------------------------- /src/nl2sql360/evaluator/rves.py: -------------------------------------------------------------------------------- 1 | from .bird_eval.bird_rves import run_sqls_parallel, sort_results 2 | import os 3 | import math 4 | from loguru import logger 5 | 6 | 7 | class RVesEvaluator: 8 | 9 | def __init__(self, reuse_ex, sql_dialect="SQLite", **kwds): 10 | self.reuse_ex = reuse_ex 11 | self.sql_dialect = sql_dialect 12 | self.db_host = kwds.get("db_host", None) 13 | self.db_port = kwds.get("db_port", None) 14 | self.db_name = kwds.get("db_name", None) 15 | self.user = kwds.get("db_user", None) 16 | self.password = kwds.get("db_password", None) 17 | 18 | def evaluate(self, gold_sqls, pred_sqls, db_ids, db_dir, **kwds): 19 | query_pairs = 
list(zip(pred_sqls, gold_sqls)) 20 | db_places = [os.path.join(db_dir, db_id, f"{db_id}.sqlite") for db_id in db_ids] 21 | exec_acc_list = kwds.get("exec_acc_list", None) 22 | if self.reuse_ex and exec_acc_list is None: 23 | logger.warning("VES evaluator is set to reuse the EX result, but it has not been passed in.") 24 | rves_result = run_sqls_parallel( 25 | sqls=query_pairs, 26 | db_places=db_places, 27 | num_cpus=kwds.get("num_processes", 8), 28 | meta_time_out=kwds.get("timeout", 30), 29 | sql_dialect=self.sql_dialect, 30 | exec_acc_list=exec_acc_list, 31 | host=self.db_host, 32 | user=self.user, 33 | password=self.password, 34 | dbname=self.db_name, 35 | port=self.db_port 36 | ) 37 | rves_result = sort_results(rves_result) 38 | rves_result = [math.sqrt(res['reward']) for res in rves_result] 39 | return { 40 | "rves": rves_result 41 | } 42 | 43 | def get_eval_metrics(self): 44 | return ["rves"] 45 | -------------------------------------------------------------------------------- /src/nl2sql360/evaluator/ves.py: -------------------------------------------------------------------------------- 1 | from .bird_eval.bird_ves import run_sqls_parallel, sort_results 2 | import os 3 | import math 4 | from loguru import logger 5 | 6 | 7 | class VesEvaluator: 8 | 9 | def __init__(self, reuse_ex, sql_dialect="SQLite", **kwds): 10 | self.reuse_ex = reuse_ex 11 | self.sql_dialect = sql_dialect 12 | self.db_host = kwds.get("db_host", None) 13 | self.db_port = kwds.get("db_port", None) 14 | self.db_name = kwds.get("db_name", None) 15 | self.user = kwds.get("db_user", None) 16 | self.password = kwds.get("db_password", None) 17 | 18 | 19 | def evaluate(self, gold_sqls, pred_sqls, db_ids, db_dir, **kwds): 20 | query_pairs = list(zip(pred_sqls, gold_sqls)) 21 | db_places = [os.path.join(db_dir, db_id, f"{db_id}.sqlite") for db_id in db_ids] 22 | exec_acc_list = kwds.get("exec_acc_list", None) 23 | if self.reuse_ex and exec_acc_list is None: 24 | logger.warning("VES evaluator is set to reuse the EX result, but it has not been passed in.") 25 | ves_result = run_sqls_parallel( 26 | sqls=query_pairs, 27 | db_places=db_places, 28 | num_cpus=kwds.get("num_processes", 8), 29 | meta_time_out=kwds.get("timeout", 30), 30 | exec_acc_list=exec_acc_list, 31 | sql_dialect=self.sql_dialect, 32 | host=self.db_host, 33 | user=self.user, 34 | password=self.password, 35 | dbname=self.db_name, 36 | port=self.db_port 37 | ) 38 | ves_result = sort_results(ves_result) 39 | ves_result = [math.sqrt(res['time_ratio']) for res in ves_result] 40 | return { 41 | "ves": ves_result 42 | } 43 | 44 | def get_eval_metrics(self): 45 | return ["ves"] 46 | -------------------------------------------------------------------------------- /examples/cli_examples/spider/report.yaml: -------------------------------------------------------------------------------- 1 | # ---------------- Core Arguments ---------------- 2 | 3 | # The directory to save NL2SQL360 core data. 4 | core_dir: "data" 5 | 6 | # The NL2SQL360 core name, such that NL2SQL360 core data is saved to "core_dir/core_name.sqlite". 7 | core_name: "nl2sql360" 8 | 9 | # The dataset SQL dialect, "SQLite" by default. 10 | sql_dialect: "SQLite" 11 | 12 | # ---------------- Report Arguments ---------------- 13 | 14 | # The dataset name which has been imported in NL2SQL360. 15 | report_dataset: "spider_dev" 16 | 17 | # The evaluation(s) name to report, listing each name in the following. 18 | # Set to null if you want to report all history evaluations. 
19 | report_evaluation: 20 | - "SuperSQL" 21 | 22 | # The metric(s) name to report, including "ex", "em", "ves", "rves", "f1", "qvt". List each name in the following: 23 | metric: 24 | - "em" 25 | - "ex" 26 | - "ves" 27 | - "rves" 28 | - "f1" 29 | - "qvt" 30 | 31 | # Define subset(s) performance by filter(s). List each filter definition in the following. 32 | # Filter: 33 | # "name": The name for the filtered subset to show. 34 | # "expression": The filter expression in format "{FILTER_KEY} {<, >, =} {NUMBER}". 35 | # Valid {FILTER_KEY} is listed in 36 | # https://github.com/BugMaker-Boyan/NL2SQL360/blob/fe436d43031e06cd457e44ec98fd25a5acd25c2b/src/nl2sql360/filter/filter.py#L13 37 | filter: 38 | - 39 | name: "Subquery" 40 | expression: "SUBQUERY > 0" 41 | - 42 | name: "Join" 43 | expression: "JOIN > 0" 44 | 45 | # Define scenario(s) performance. List each scenario definition in the following. 46 | # Scenario: Combination of multiple filters joined by "&&". 47 | scenario: 48 | - 49 | name: "BI" 50 | expression: "SUBQUERY > 0 && JOIN > 0" 51 | 52 | # The report save path in CSV format 53 | save_path: "./report.csv" 54 | -------------------------------------------------------------------------------- /src/nl2sql360/cli/cli.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from enum import Enum, unique 3 | from .. import VERSION 4 | from .util import run_dataset_import, run_delete_history, run_evaluation, run_report 5 | 6 | 7 | USAGE = ( 8 | "-" * 70 9 | + "\n" 10 | + "| Usage: |\n" 11 | + "| nl2sql360-cli dataset -h: import NL2SQL dataset |\n" 12 | + "| nl2sql360-cli evaluate -h: evaluate NL2SQL model |\n" 13 | + "| nl2sql360-cli report -h: output evaluation report |\n" 14 | + "| nl2sql360-cli delete -h: delete dataset or evaluation history |\n" 15 | + "| nl2sql360-cli version: show version info |\n" 16 | + "-" * 70 17 | ) 18 | 19 | 20 | WELCOME = ( 21 | "-" * 58 22 | + "\n" 23 | + "| Welcome to NL2SQL360, version {}".format(VERSION) 24 | + " " * (25 - len(VERSION)) 25 | + "|\n|" 26 | + " " * 56 27 | + "|\n" 28 | + "| Project page: https://github.com/HKUSTDial/NL2SQL360" 29 | + " " * 3 30 | + "|\n" 31 | + "-" * 58 32 | ) 33 | 34 | 35 | @unique 36 | class Command(str, Enum): 37 | DATASET = "dataset" 38 | EVALUATE = "evaluate" 39 | REPORT = "report" 40 | DELETE = "delete" 41 | VERSION = "version" 42 | HELP = "help" 43 | 44 | 45 | def main(): 46 | command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP 47 | if command == Command.DATASET: 48 | run_dataset_import() 49 | elif command == Command.EVALUATE: 50 | run_evaluation() 51 | elif command == Command.REPORT: 52 | run_report() 53 | elif command == Command.DELETE: 54 | run_delete_history() 55 | elif command == Command.VERSION: 56 | print(WELCOME) 57 | elif command == Command.HELP: 58 | print(USAGE) 59 | else: 60 | raise NotImplementedError("Unknown command: {}".format(command)) 61 | -------------------------------------------------------------------------------- /src/nl2sql360/evaluator/spider_ex_em.py: -------------------------------------------------------------------------------- 1 | from .test_suite_sql_eval.evaluation import evaluate, build_foreign_key_map_from_json 2 | from loguru import logger 3 | 4 | 5 | class SpiderEXEMEvaluator: 6 | 7 | def __init__(self, eval_ex, eval_em): 8 | self.eval_ex = eval_ex 9 | self.eval_em = eval_em 10 | 11 | def get_eval_metrics(self): 12 | eval_metrics = [] 13 | if self.eval_ex: 14 | eval_metrics.append("exec_acc") 15 | if self.eval_em: 16 | 
eval_metrics.append("exact_acc") 17 | return eval_metrics 18 | 19 | def evaluate(self, gold_sqls, pred_sqls, db_ids, db_dir, **kwds): 20 | if self.eval_ex and self.eval_em: 21 | etype = "all" 22 | elif self.eval_ex: 23 | etype = "exec" 24 | elif self.eval_em: 25 | etype = "match" 26 | else: 27 | logger.warning(f"Spider evaluator has not been set with evaluation metrics and is defaulted to EX.") 28 | etype = "exec" 29 | 30 | kmaps = None 31 | 32 | if etype in ["all", "match"]: 33 | if kwds.get("tables_json_path", None) is None: 34 | logger.warning(f"`EM` metric evaluation need tables json path passed in. Exclude `EM` by default.") 35 | etype = "exec" 36 | kmaps = None 37 | else: 38 | kmaps = build_foreign_key_map_from_json(kwds.get("tables_json_path")) 39 | 40 | golds = [f"{gold}\t{db_id}" for gold, db_id in zip(gold_sqls, db_ids)] 41 | 42 | entries = evaluate( 43 | golds=golds, 44 | preds=pred_sqls, 45 | db_dir=db_dir, 46 | etype=etype, 47 | kmaps=kmaps, 48 | plug_value=False, 49 | keep_distinct=False, 50 | progress_bar_for_each_datapoint=False 51 | ) 52 | 53 | return { 54 | "exec_acc": [entry.get("exec", None) for entry in entries], 55 | "exact_acc": [entry.get("exact", None) for entry in entries] 56 | } -------------------------------------------------------------------------------- /src/nl2sql360/database/template.py: -------------------------------------------------------------------------------- 1 | 2 | METRIC_COL_MAPPING = { 3 | "ex": "exec_acc", 4 | "em": "exact_acc", 5 | "ves": "ves", 6 | "rves": "rves", 7 | "f1": "f1", 8 | "qvt": None 9 | } 10 | 11 | 12 | QUERY_OVERALL_PERFORMANCE = \ 13 | """ 14 | SELECT AVG({METRIC_COL}) * 100 from DATASET_{DATASET_NAME}_EVALUATION_{EVAL_NAME} AS e JOIN DATASET_{DATASET_NAME} AS d ON e.id = d.id; 15 | """ 16 | 17 | 18 | QUERY_SUBSET_PERFORMANCE = \ 19 | """ 20 | SELECT AVG({METRIC_COL}) * 100 from DATASET_{DATASET_NAME}_EVALUATION_{EVAL_NAME} AS e JOIN DATASET_{DATASET_NAME} AS d ON e.id = d.id WHERE {WHERE_CONDITION}; 21 | """ 22 | 23 | 24 | QUERY_QVT_PERFORMANCE = \ 25 | """ 26 | SELECT AVG(exec_acc) * 100 FROM ( 27 | SELECT AVG(exec_acc) as exec_acc FROM DATASET_{DATASET_NAME}_EVALUATION_{EVAL_NAME} AS e JOIN DATASET_{DATASET_NAME} AS d ON e.id = d.id GROUP BY gold HAVING COUNT(d.gold) >= 2 and sum(e.exec_acc) != 0 28 | ); 29 | """ 30 | 31 | 32 | QUERY_DATASET_SIZE = \ 33 | """ 34 | SELECT COUNT(*), COUNT(DISTINCT gold) FROM DATASET_{DATASET_NAME}; 35 | """ 36 | 37 | 38 | QUERY_DATASET_SQL_KEYWORDS_DISTRIBUTION = \ 39 | """ 40 | SELECT 41 | AVG(count_query_fields), 42 | AVG(count_group_by), 43 | AVG(count_order_by), 44 | AVG(count_limit), 45 | AVG(count_join), 46 | AVG(count_predicate), 47 | AVG(count_aggregation), 48 | AVG(count_scalar_function), 49 | AVG(count_subquery), 50 | AVG(count_set_operation), 51 | AVG(count_math_compute), 52 | AVG(count_logical_connecter), 53 | AVG(count_distinct), 54 | AVG(count_like), 55 | AVG(count_control_flow), 56 | AVG(count_window) 57 | FROM DATASET_{DATASET_NAME}; 58 | """ 59 | 60 | 61 | QUERY_DATASET_DOMAIN_DISTRIBUTION = \ 62 | """ 63 | SELECT db_domain, COUNT(*) FROM DATASET_{DATASET_NAME} GROUP BY db_domain ORDER BY db_domain; 64 | """ 65 | 66 | 67 | DELETE_DATASET_TABLE = \ 68 | """ 69 | DROP TABLE IF EXISTS DATASET_{DATASET_NAME}; 70 | """ 71 | 72 | DELETE_DATASET_INFO = \ 73 | """ 74 | DELETE FROM __DATASET_INFO__ WHERE dataset_name = "{DATASET_NAME}"; 75 | """ 76 | 77 | 78 | DELETE_EVALUATION_TABLE = \ 79 | """ 80 | DROP TABLE IF EXISTS DATASET_{DATASET_NAME}_EVALUATION_{EVAL_NAME}; 81 | """ 82 | 
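# --- Usage sketch (illustrative, not part of the original module) ---
# Shows how these query templates are parameterised with str.format and
# METRIC_COL_MAPPING; "spider_dev" and "SuperSQL" are placeholder names taken
# from the example configs.
if __name__ == "__main__":
    sql = QUERY_OVERALL_PERFORMANCE.format(
        METRIC_COL=METRIC_COL_MAPPING["ex"],  # "ex" maps to the "exec_acc" column
        DATASET_NAME="spider_dev",
        EVAL_NAME="SuperSQL",
    )
    # Renders a query over DATASET_spider_dev joined with its SuperSQL evaluation table.
    print(sql.strip())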
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | from setuptools import find_packages, setup 5 | 6 | 7 | def get_requires(): 8 | with open("requirements.txt", "r", encoding="utf-8") as f: 9 | file_content = f.read() 10 | lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")] 11 | return lines 12 | 13 | 14 | def get_version(): 15 | with open(os.path.join("src", "nl2sql360", "__init__.py"), "r", encoding="utf-8") as f: 16 | file_content = f.read() 17 | pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION") 18 | (version,) = re.findall(pattern, file_content) 19 | return version 20 | 21 | 22 | def main(): 23 | setup( 24 | name="nl2sql360", 25 | version=get_version(), 26 | author="boyanli", 27 | author_email="boyanli" "@" "hkust-gz.edu.cn", 28 | description="A Multi-angle NL2SQL Evaluation Framework", 29 | long_description=open("README.md", "r", encoding="utf-8").read(), 30 | long_description_content_type="text/markdown", 31 | keywords=["NL2SQL", "Text-to-SQL", "T2S", "SQL", "Database", "NLIDB"], 32 | license="MIT License", 33 | url="https://github.com/HKUSTDial/NL2SQL360", 34 | package_dir={"": "src"}, 35 | packages=find_packages("src"), 36 | python_requires=">=3.8.0", 37 | install_requires=get_requires(), 38 | entry_points={"console_scripts": ["nl2sql360-cli = nl2sql360.cli:main"]}, 39 | classifiers=[ 40 | "Development Status :: 3 - Alpha", 41 | "Intended Audience :: Developers", 42 | "Intended Audience :: Education", 43 | "Intended Audience :: Science/Research", 44 | "License :: OSI Approved :: MIT License", 45 | "Operating System :: OS Independent", 46 | "Programming Language :: Python :: 3", 47 | "Programming Language :: Python :: 3.8", 48 | "Programming Language :: Python :: 3.9", 49 | "Programming Language :: Python :: 3.10", 50 | "Programming Language :: Python :: 3.11", 51 | "Topic :: Database", 52 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 53 | "Topic :: Software Development :: Testing" 54 | ], 55 | ) 56 | 57 | 58 | if __name__ == "__main__": 59 | main() -------------------------------------------------------------------------------- /src/nl2sql360/arguments/parser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import loguru 4 | from typing import Any, Dict, Optional, Tuple 5 | from .hf_argparser import HfArgumentParser 6 | 7 | from .core_args import CoreArguments 8 | from .dataset_args import DatasetArguments 9 | from .evaluation_args import EvaluationArguments 10 | from .report_args import ReportArguments 11 | from .delete_history_args import DeleteHistoryArguments 12 | 13 | 14 | _DATASET_IMPORT_ARGS = [CoreArguments, DatasetArguments] 15 | _DATASET_IMPORT_CLS = Tuple[CoreArguments, DatasetArguments] 16 | _EVALUATION_ARGS = [CoreArguments, EvaluationArguments] 17 | _EVALUATION_CLS = Tuple[CoreArguments, EvaluationArguments] 18 | _REPORT_ARGS = [CoreArguments, ReportArguments] 19 | _REPORT_CLS = Tuple[CoreArguments, ReportArguments] 20 | _DELETE_HISTORY_ARGS = [CoreArguments, DeleteHistoryArguments] 21 | _DELETE_HISTORY_CLS = Tuple[CoreArguments, DeleteHistoryArguments] 22 | 23 | 24 | def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]: 25 | if args is not None: 26 | return parser.parse_dict(args) 27 | 28 | if len(sys.argv) == 2 and 
sys.argv[1].endswith(".yaml"): 29 | return parser.parse_yaml_file(os.path.abspath(sys.argv[1])) 30 | 31 | (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True) 32 | 33 | if unknown_args: 34 | print(parser.format_help()) 35 | print("Got unknown args, potentially deprecated arguments: {}".format(unknown_args)) 36 | raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(unknown_args)) 37 | 38 | return (*parsed_args,) 39 | 40 | 41 | def get_dataset_import_args(args: Optional[Dict[str, Any]] = None) -> _DATASET_IMPORT_CLS: 42 | parser = HfArgumentParser(_DATASET_IMPORT_ARGS) 43 | return _parse_args(parser, args) 44 | 45 | 46 | def get_evaluation_args(args: Optional[Dict[str, Any]] = None) -> _EVALUATION_CLS: 47 | parser = HfArgumentParser(_EVALUATION_ARGS) 48 | return _parse_args(parser, args) 49 | 50 | 51 | def get_report_args(args: Optional[Dict[str, Any]] = None) -> _REPORT_CLS: 52 | parser = HfArgumentParser(_REPORT_ARGS) 53 | return _parse_args(parser, args) 54 | 55 | 56 | def get_delete_history_args(args: Optional[Dict[str, Any]] = None) -> _DELETE_HISTORY_CLS: 57 | parser = HfArgumentParser(_DELETE_HISTORY_ARGS) 58 | return _parse_args(parser, args) 59 | -------------------------------------------------------------------------------- /src/nl2sql360/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | import os 3 | from ..arguments import DatasetArguments 4 | import json 5 | 6 | 7 | class NL2SQLDataset: 8 | 9 | def __init__(self, dataset_args: "DatasetArguments") -> None: 10 | self.dataset_args = dataset_args 11 | 12 | def get_all_samples(self): 13 | with open(os.path.join(self.dataset_args.dataset_dir, self.dataset_args.samples_file), "r", encoding="utf-8") as f: 14 | samples = json.load(f) 15 | return samples 16 | 17 | def get_all_questions(self): 18 | return [item[self.dataset_args.question_key] for item in self.get_all_samples()] 19 | 20 | def get_all_sqls(self): 21 | return [item[self.dataset_args.sql_key] for item in self.get_all_samples()] 22 | 23 | def get_all_db_ids(self): 24 | return [item[self.dataset_args.db_id_key] for item in self.get_all_samples()] 25 | 26 | def get_all_sql_complexity(self): 27 | if self.dataset_args.sql_complexity_key is not None: 28 | return [item[self.dataset_args.sql_complexity_key] for item in self.get_all_samples()] 29 | else: 30 | return ["" for _ in self.get_all_samples()] 31 | 32 | def get_all_database_domains(self): 33 | if self.dataset_args.database_domain_file is not None: 34 | with open(os.path.join(self.dataset_args.dataset_dir, self.dataset_args.database_domain_file), "r", encoding="utf-8") as f: 35 | domain_mapping = json.load(f) 36 | else: 37 | domain_mapping = dict() 38 | return [domain_mapping.get(item[self.dataset_args.db_id_key], "") for item in self.get_all_samples()] 39 | 40 | def get_all_database_paths(self): 41 | return [os.path.join(self.dataset_args.dataset_dir, 42 | self.dataset_args.database_dir, 43 | item[self.dataset_args.db_id_key], 44 | f"{item[self.dataset_args.db_id_key]}.sqlite" 45 | ) for item in self.get_all_samples()] 46 | 47 | def get_tables(self): 48 | if self.dataset_args.tables_file is None: 49 | raise ValueError("`tables_file` does not exist or is not a valid json file.") 50 | else: 51 | with open(os.path.join(self.dataset_args.dataset_dir, self.dataset_args.tables_file), "r", encoding="utf-8") as f: 52 | return json.load(f) 53 | 54 | def __len__(self):
55 | return len(self.get_all_samples()) 56 | 57 | -------------------------------------------------------------------------------- /src/nl2sql360/arguments/dataset_args.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from pathlib import Path 3 | from typing import Optional 4 | 5 | 6 | @dataclass 7 | class DatasetArguments: 8 | r""" 9 | Arguments for importing a dataset. 10 | """ 11 | 12 | dataset_name: str = field( 13 | metadata={"help": "The unique dataset name."} 14 | ) 15 | 16 | dataset_dir: str = field( 17 | metadata={"help": "Path to the folder containing the dataset."} 18 | ) 19 | 20 | samples_file: str = field( 21 | metadata={"help": "The json file containing dataset samples."} 22 | ) 23 | 24 | database_dir: str = field( 25 | metadata={"help": "The directory containing databases (i.e., sqlite files)."} 26 | ) 27 | 28 | tables_file: Optional[str] = field( 29 | default=None, 30 | metadata={"help": "The json file containing database tables (schemas)."} 31 | ) 32 | 33 | question_key: str = field( 34 | default="question", 35 | metadata={"help": "The key name of NL questions in the data json file."} 36 | ) 37 | 38 | sql_key: str = field( 39 | default="sql", 40 | metadata={"help": "The key name of SQL queries in the data json file."} 41 | ) 42 | 43 | db_id_key: str = field( 44 | default="db_id", 45 | metadata={"help": "The key name of database id in the data json file."} 46 | ) 47 | 48 | sql_complexity_key: Optional[str] = field( 49 | default=None, 50 | metadata={"help": "The key name of SQL complexity in the data json file."} 51 | ) 52 | 53 | database_domain_file: Optional[str] = field( 54 | default=None, 55 | metadata={"help": "The json file containing database domain classifications."} 56 | ) 57 | 58 | def __post_init__(self): 59 | samples_file_path = Path(self.dataset_dir, self.samples_file) 60 | if not samples_file_path.exists() or not samples_file_path.is_file() or samples_file_path.suffix != ".json": 61 | raise ValueError("`samples_file` does not exist or is not a valid json file.") 62 | 63 | if self.tables_file: 64 | tables_file_path = Path(self.dataset_dir, self.tables_file) 65 | if not tables_file_path.exists() or not tables_file_path.is_file() or tables_file_path.suffix != ".json": 66 | raise ValueError("`tables_file` does not exist or is not a valid json file.") 67 | 68 | database_dir_path = Path(self.dataset_dir, self.database_dir) 69 | if not database_dir_path.exists() or not database_dir_path.is_dir(): 70 | raise ValueError("`database_dir` does not exist or is not a directory.") 71 | -------------------------------------------------------------------------------- /src/nl2sql360/filter/filter.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from enum import Enum 3 | from typing import List, Optional 4 | import re 5 | 6 | 7 | class Operator(Enum): 8 | GT = ">" 9 | LT = "<" 10 | EQ = "=" 11 | 12 | 13 | class Field(Enum): 14 | QUERY_FIELDS = "QUERY_FIELDS" 15 | GROUP_BY = "GROUP_BY" 16 | ORDER_BY = "ORDER_BY" 17 | LIMIT = "LIMIT" 18 | JOIN = "JOIN" 19 | PREDICATE = "PREDICATE" 20 | AGGREGATION = "AGGREGATION" 21 | SCALAR_FUNCTION = "SCALAR_FUNCTION" 22 | SUBQUERY = "SUBQUERY" 23 | SET_OPERATION = "SET_OPERATION" 24 | MATH_COMPUTE = "MATH_COMPUTE" 25 | LOGICAL_CONNECTOR = "LOGICAL_CONNECTOR" 26 | DISTINCT = "DISTINCT" 27 | LIKE = "LIKE" 28 | CONTROL_FLOW = "CONTROL_FLOW" 29 | WINDOW = "WINDOW" 30 | 31 | 32 | @dataclass 33 | class 
Filter: 34 | name: str 35 | field: Field 36 | operator: Operator 37 | value: int 38 | 39 | 40 | _SCENARIO_CONNECTOR = "&&" 41 | 42 | @dataclass 43 | class Scenario: 44 | name: str 45 | filters: List[Filter] 46 | 47 | 48 | def parse_filter(filter_name: str, filter_expression: str) -> Optional[Filter]: 49 | pattern = re.compile(r'^(?P<field>{})\s*(?P<op>{})\s*(?P<value>\d+)$'.format( 50 | '|'.join([field.value for field in Field]), 51 | '|'.join([op.value for op in Operator]) 52 | )) 53 | match = pattern.match(filter_expression) 54 | if match: 55 | field = match.group('field') 56 | op = match.group('op') 57 | value = match.group('value') 58 | return Filter( 59 | name=filter_name, 60 | field=Field(field), 61 | operator=Operator(op), 62 | value=int(value) 63 | ) 64 | else: 65 | return None 66 | 67 | 68 | def parse_scenario(scenario_name: str, scenario_str: str) -> Optional[Scenario]: 69 | filter_expressions = scenario_str.split(_SCENARIO_CONNECTOR) 70 | filters = [] 71 | for idx, exp in enumerate(filter_expressions): 72 | filter = parse_filter(filter_name=f"{scenario_name}-Filter-{idx}", filter_expression=exp.strip()) 73 | if filter is None: 74 | return None 75 | else: 76 | filters.append(filter) 77 | return Scenario(name=scenario_name, filters=filters) 78 | 79 | 80 | def map_field_to_database_col(field: Field) -> str: 81 | return "count_" + field.value.lower() 82 | 83 | 84 | def serialize_filter(filter: Filter) -> str: 85 | return f"{map_field_to_database_col(filter.field)} {filter.operator.value} {filter.value}" 86 | 87 | 88 | def serialize_scenario(scenario: Scenario) -> str: 89 | return " AND ".join([serialize_filter(filter) for filter in scenario.filters]) 90 | -------------------------------------------------------------------------------- /src/nl2sql360/arguments/report_args.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from pathlib import Path 3 | from typing import Optional, Literal, List, Union, Dict 4 | from types import NoneType 5 | from ..filter import parse_filter, parse_scenario 6 | 7 | 8 | @dataclass 9 | class ReportArguments: 10 | r""" 11 | Arguments for reporting dataset and evaluation. 
12 | """ 13 | 14 | save_path: str = field( 15 | metadata={"help": "The path to save the CSV report."} 16 | ) 17 | 18 | report_dataset: str = field( 19 | metadata={"help": "The dataset name to report."} 20 | ) 21 | 22 | report_evaluation: Union[Optional[str], List[str]] = field( 23 | default=None, 24 | metadata={"help": "The evaluation (list) to report."} 25 | ) 26 | 27 | metric: Union[str, List[str]] = field( 28 | default=None, 29 | metadata={"help": "The metric (list) to report."} 30 | ) 31 | 32 | filter: Union[Optional[str], List[str]] = field( 33 | default=None, 34 | metadata={"help": "The filter expressions (list) used to filter subset performance."} 35 | ) 36 | 37 | scenario: Union[Optional[str], List[str]] = field( 38 | default=None, 39 | metadata={"help": "The scenario expressions (list) used to filter subset performance."} 40 | ) 41 | 42 | def __post_init__(self): 43 | if isinstance(self.metric, str): 44 | self.metric = [m.strip() for m in self.metric.split(",")] 45 | 46 | for metric in self.metric: 47 | if metric not in ["ex", "em", "ves", "rves", "f1", "qvt"]: 48 | raise ValueError("`eval_metrics` only supports metrics combinations in (`ex`, `em`, `ves`, `rves`, `f1`, `qvt`).") 49 | 50 | filter_list = [] 51 | if self.filter: 52 | try: 53 | for _f in self.filter: 54 | f = parse_filter(_f["name"], _f["expression"]) 55 | if f is None: 56 | raise ValueError(f"Parse filter error: {_f}") 57 | filter_list.append(f) 58 | except Exception as e: 59 | raise ValueError("Parse filter error.") 60 | self.filter = filter_list 61 | 62 | scenario_list = [] 63 | if self.scenario: 64 | try: 65 | for _s in self.scenario: 66 | s = parse_scenario(_s["name"], _s["expression"]) 67 | if s is None: 68 | raise ValueError(f"Parse scenario error: {_s}") 69 | scenario_list.append(s) 70 | except Exception as e: 71 | raise ValueError("Parse scenario error.") 72 | self.scenario = scenario_list 73 | -------------------------------------------------------------------------------- /examples/cli_examples/bird/bird_dev_report.yaml: -------------------------------------------------------------------------------- 1 | # ---------------- Core Arguments ---------------- 2 | 3 | # The directory to save NL2SQL360 core data. 4 | core_dir: "./nl2sql360_cache" 5 | 6 | # The NL2SQL360 core name, such that NL2SQL360 core data is saved to "core_dir/core_name.sqlite". 7 | core_name: "nl2sql360" 8 | 9 | # The dataset SQL dialect, "SQLite" by default. 10 | sql_dialect: "SQLite" 11 | 12 | # ---------------- Report Arguments ---------------- 13 | 14 | # The dataset name which has been imported in NL2SQL360. 15 | report_dataset: "bird_dev" 16 | 17 | # The evaluation(s) name to report, listing each name in the following. 18 | # Set to `null` if you want to report all history evaluations. 19 | report_evaluation: null 20 | 21 | # The metric(s) name to report, including "ex", "em", "ves", "rves", "f1", "qvt". List each name in the following: 22 | metric: 23 | - "ex" 24 | - "ves" 25 | - "rves" 26 | - "f1" 27 | - "qvt" 28 | 29 | # Define subset(s) performance by filter(s). List each filter defination in the following. 30 | # Filter: 31 | # "name": The name for the filtered subset to show. 32 | # "expression": The filter expression in format "{FILTER_KEY} {<, >, =} {NUMBER}". 
33 | # Valid {FILTER_KEY} is listed in 34 | # https://github.com/BugMaker-Boyan/NL2SQL360/blob/fe436d43031e06cd457e44ec98fd25a5acd25c2b/src/nl2sql360/filter/filter.py#L13 35 | filter: 36 | - 37 | name: "Filter - Subquery" 38 | expression: "SUBQUERY > 0" 39 | - 40 | name: "Filter - JOIN" 41 | expression: "JOIN > 0" 42 | - 43 | name: "Filter - Aggregation" 44 | expression: "AGGREGATION > 0" 45 | - 46 | name: "Filter - Scaler Function" 47 | expression: "SCALAR_FUNCTION > 0" 48 | - 49 | name: "Filter - Set Operation" 50 | expression: "SET_OPERATION > 0" 51 | - 52 | name: "Filter - Math Compute" 53 | expression: "MATH_COMPUTE > 0" 54 | - 55 | name: "Filter - Logical Connector" 56 | expression: "LOGICAL_CONNECTOR > 0" 57 | - 58 | name: "Filter - DISTINCT" 59 | expression: "DISTINCT > 0" 60 | - 61 | name: "Filter - LIKE" 62 | expression: "LIKE > 0" 63 | - 64 | name: "Filter - Control Flow" 65 | expression: "CONTROL_FLOW > 0" 66 | - 67 | name: "Filter - Window Function" 68 | expression: "WINDOW > 0" 69 | - 70 | name: "Filter - LIMIT" 71 | expression: "LIMIT > 0" 72 | - 73 | name: "Filter - ORDER BY" 74 | expression: "ORDER_BY > 0" 75 | - 76 | name: "Filter - GROUP BY" 77 | expression: "GROUP_BY > 0" 78 | 79 | # Define scenario(s) performance. List each scenario defination in the following. 80 | # Scenario: Combination of multiple filters joined by "&&". 81 | scenario: null 82 | 83 | # The report save path in CSV format 84 | save_path: "./report.csv" 85 | -------------------------------------------------------------------------------- /examples/cli_examples/bird/bird_test_report.yaml: -------------------------------------------------------------------------------- 1 | # ---------------- Core Arguments ---------------- 2 | 3 | # The directory to save NL2SQL360 core data. 4 | core_dir: "./nl2sql360_cache" 5 | 6 | # The NL2SQL360 core name, such that NL2SQL360 core data is saved to "core_dir/core_name.sqlite". 7 | core_name: "nl2sql360" 8 | 9 | # The dataset SQL dialect, "SQLite" by default. 10 | sql_dialect: "SQLite" 11 | 12 | # ---------------- Report Arguments ---------------- 13 | 14 | # The dataset name which has been imported in NL2SQL360. 15 | report_dataset: "bird_test" 16 | 17 | # The evaluation(s) name to report, listing each name in the following. 18 | # Set to `null` if you want to report all history evaluations. 19 | report_evaluation: null 20 | 21 | # The metric(s) name to report, including "ex", "em", "ves", "rves", "f1", "qvt". List each name in the following: 22 | metric: 23 | - "ex" 24 | - "ves" 25 | - "rves" 26 | - "f1" 27 | - "qvt" 28 | 29 | # Define subset(s) performance by filter(s). List each filter defination in the following. 30 | # Filter: 31 | # "name": The name for the filtered subset to show. 32 | # "expression": The filter expression in format "{FILTER_KEY} {<, >, =} {NUMBER}". 
33 | # Valid {FILTER_KEY} is listed in 34 | # https://github.com/BugMaker-Boyan/NL2SQL360/blob/fe436d43031e06cd457e44ec98fd25a5acd25c2b/src/nl2sql360/filter/filter.py#L13 35 | filter: 36 | - 37 | name: "Filter - Subquery" 38 | expression: "SUBQUERY > 0" 39 | - 40 | name: "Filter - JOIN" 41 | expression: "JOIN > 0" 42 | - 43 | name: "Filter - Aggregation" 44 | expression: "AGGREGATION > 0" 45 | - 46 | name: "Filter - Scaler Function" 47 | expression: "SCALAR_FUNCTION > 0" 48 | - 49 | name: "Filter - Set Operation" 50 | expression: "SET_OPERATION > 0" 51 | - 52 | name: "Filter - Math Compute" 53 | expression: "MATH_COMPUTE > 0" 54 | - 55 | name: "Filter - Logical Connector" 56 | expression: "LOGICAL_CONNECTOR > 0" 57 | - 58 | name: "Filter - DISTINCT" 59 | expression: "DISTINCT > 0" 60 | - 61 | name: "Filter - LIKE" 62 | expression: "LIKE > 0" 63 | - 64 | name: "Filter - Control Flow" 65 | expression: "CONTROL_FLOW > 0" 66 | - 67 | name: "Filter - Window Function" 68 | expression: "WINDOW > 0" 69 | - 70 | name: "Filter - LIMIT" 71 | expression: "LIMIT > 0" 72 | - 73 | name: "Filter - ORDER BY" 74 | expression: "ORDER_BY > 0" 75 | - 76 | name: "Filter - GROUP BY" 77 | expression: "GROUP_BY > 0" 78 | 79 | # Define scenario(s) performance. List each scenario defination in the following. 80 | # Scenario: Combination of multiple filters joined by "&&". 81 | scenario: null 82 | 83 | # The report save path in CSV format 84 | save_path: "./report.csv" 85 | -------------------------------------------------------------------------------- /examples/cli_examples/bird/bird_minidev_report.yaml: -------------------------------------------------------------------------------- 1 | # ---------------- Core Arguments ---------------- 2 | 3 | # The directory to save NL2SQL360 core data. 4 | core_dir: "./nl2sql360_cache" 5 | 6 | # The NL2SQL360 core name, such that NL2SQL360 core data is saved to "core_dir/core_name.sqlite". 7 | core_name: "nl2sql360" 8 | 9 | # The dataset SQL dialect, "SQLite" by default. 10 | sql_dialect: "SQLite" 11 | 12 | # ---------------- Report Arguments ---------------- 13 | 14 | # The dataset name which has been imported in NL2SQL360. 15 | report_dataset: "bird_minidev" 16 | 17 | # The evaluation(s) name to report, listing each name in the following. 18 | # Set to `null` if you want to report all history evaluations. 19 | report_evaluation: null 20 | 21 | # The metric(s) name to report, including "ex", "em", "ves", "rves", "f1", "qvt". List each name in the following: 22 | metric: 23 | - "ex" 24 | - "ves" 25 | - "rves" 26 | - "f1" 27 | - "qvt" 28 | 29 | # Define subset(s) performance by filter(s). List each filter defination in the following. 30 | # Filter: 31 | # "name": The name for the filtered subset to show. 32 | # "expression": The filter expression in format "{FILTER_KEY} {<, >, =} {NUMBER}". 
33 | # Valid {FILTER_KEY} is listed in 34 | # https://github.com/BugMaker-Boyan/NL2SQL360/blob/fe436d43031e06cd457e44ec98fd25a5acd25c2b/src/nl2sql360/filter/filter.py#L13 35 | filter: 36 | - 37 | name: "Filter - Subquery" 38 | expression: "SUBQUERY > 0" 39 | - 40 | name: "Filter - JOIN" 41 | expression: "JOIN > 0" 42 | - 43 | name: "Filter - Aggregation" 44 | expression: "AGGREGATION > 0" 45 | - 46 | name: "Filter - Scaler Function" 47 | expression: "SCALAR_FUNCTION > 0" 48 | - 49 | name: "Filter - Set Operation" 50 | expression: "SET_OPERATION > 0" 51 | - 52 | name: "Filter - Math Compute" 53 | expression: "MATH_COMPUTE > 0" 54 | - 55 | name: "Filter - Logical Connector" 56 | expression: "LOGICAL_CONNECTOR > 0" 57 | - 58 | name: "Filter - DISTINCT" 59 | expression: "DISTINCT > 0" 60 | - 61 | name: "Filter - LIKE" 62 | expression: "LIKE > 0" 63 | - 64 | name: "Filter - Control Flow" 65 | expression: "CONTROL_FLOW > 0" 66 | - 67 | name: "Filter - Window Function" 68 | expression: "WINDOW > 0" 69 | - 70 | name: "Filter - LIMIT" 71 | expression: "LIMIT > 0" 72 | - 73 | name: "Filter - ORDER BY" 74 | expression: "ORDER_BY > 0" 75 | - 76 | name: "Filter - GROUP BY" 77 | expression: "GROUP_BY > 0" 78 | 79 | # Define scenario(s) performance. List each scenario defination in the following. 80 | # Scenario: Combination of multiple filters joined by "&&". 81 | scenario: null 82 | 83 | # The report save path in CSV format 84 | save_path: "./report.csv" 85 | -------------------------------------------------------------------------------- /src/nl2sql360/database/model.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, Integer, String, Float, ForeignKey 2 | from sqlalchemy.orm import DeclarativeBase 3 | 4 | 5 | class Base(DeclarativeBase): 6 | pass 7 | 8 | 9 | class DatasetInfo(Base): 10 | __tablename__ = "__DATASET_INFO__" 11 | 12 | dataset_name = Column(String, primary_key=True) 13 | database_dir_path = Column(String, nullable=False) 14 | tables_json_path = Column(String, nullable=True, default=None) 15 | 16 | 17 | class MetaDataset: 18 | 19 | id = Column(Integer, primary_key=True) 20 | nlq = Column(String, nullable=False) 21 | gold = Column(String, nullable=False) 22 | db_id = Column(String, nullable=False) 23 | 24 | """Note: 25 | BIRD complexity: ["simple", "moderate", "challenging"] 26 | Spider complexity: ["easy", "medium", "hard", "extra"] 27 | """ 28 | complexity = Column(String, nullable=False) 29 | db_domain = Column(String, nullable=False) 30 | 31 | count_query_fields = Column(Integer, nullable=False) 32 | count_group_by = Column(Integer, nullable=False) 33 | count_order_by = Column(Integer, nullable=False) 34 | count_limit = Column(Integer, nullable=False) 35 | count_join = Column(Integer, nullable=False) 36 | count_predicate = Column(Integer, nullable=False) 37 | count_aggregation = Column(Integer, nullable=False) 38 | count_scalar_function = Column(Integer, nullable=False) 39 | count_subquery = Column(Integer, nullable=False) 40 | count_set_operation = Column(Integer, nullable=False) 41 | count_math_compute = Column(Integer, nullable=False) 42 | count_logical_connector = Column(Integer, nullable=False) 43 | count_distinct = Column(Integer, nullable=False) 44 | count_like = Column(Integer, nullable=False) 45 | count_control_flow = Column(Integer, nullable=False) 46 | count_window = Column(Integer, nullable=False) 47 | 48 | 49 | class MetaEvaluation: 50 | 51 | pred = Column(String, nullable=False) 52 | exec_acc = 
Column(Float, nullable=True, default=None) 53 | exact_acc = Column(Float, nullable=True, default=None) 54 | ves = Column(Float, nullable=True, default=None) 55 | rves = Column(Float, nullable=True, default=None) 56 | f1 = Column(Float, nullable=True, default=None) 57 | 58 | 59 | def get_dataset_model(dataset_name): 60 | return type(f"DATASET_{dataset_name}", 61 | (MetaDataset, Base), 62 | dict(__tablename__=f"DATASET_{dataset_name}")) 63 | 64 | 65 | def get_evaluation_model(dataset_name, evaluation_name): 66 | return type(f"DATASET_{dataset_name}_EVALUATION_{evaluation_name}", 67 | (MetaEvaluation, Base), 68 | dict(id=Column(Integer, ForeignKey(f"DATASET_{dataset_name}"), primary_key=True), 69 | __tablename__=f"DATASET_{dataset_name}_EVALUATION_{evaluation_name}")) 70 | -------------------------------------------------------------------------------- /src/nl2sql360/arguments/evaluation_args.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from pathlib import Path 3 | from typing import Optional, Literal, List 4 | 5 | 6 | @dataclass 7 | class EvaluationArguments: 8 | r""" 9 | Arguments for evaluation. 10 | """ 11 | 12 | eval_name: str = field( 13 | metadata={"help": "The model name for evaluation."} 14 | ) 15 | 16 | eval_dataset: str = field( 17 | metadata={"help": "The dataset name for evaluation."} 18 | ) 19 | 20 | eval_metrics: List[str] = field( 21 | metadata={"help": "Specify metrics (`ex`, `em`, `ves`, `rves`, `f1`) for evaluation."} 22 | ) 23 | 24 | pred_sqls_file: str = field( 25 | metadata={"help": "The file containing all predicted sqls (in lines)."} 26 | ) 27 | 28 | enable_spider_eval: bool = field( 29 | default=False, 30 | metadata={"help": "Enable official spider evaluator."} 31 | ) 32 | 33 | num_processes: int = field( 34 | default=8, 35 | metadata={"help": "The number of processes used in the evaluation."} 36 | ) 37 | 38 | timeout: float = field( 39 | default=30.0, 40 | metadata={"help": "The timeout of SQL execution."} 41 | ) 42 | 43 | # for bird mini-dev MySQL / PostgreSQL database 44 | 45 | db_host: str = field( 46 | default="localhost", 47 | metadata={"help": "The db host (`localhost` by default) for BIRD Mini-Dev dataset to connect MySQL or PostgreSQL database. "} 48 | ) 49 | 50 | db_port: str = field( 51 | default="3306", 52 | metadata={"help": "The db port (`3306` for MySQL, `5432` for PostgreSQL by default) for BIRD Mini-Dev dataset to connect MySQL or PostgreSQL database. "} 53 | ) 54 | 55 | db_name: str = field( 56 | default="BIRD", 57 | metadata={"help": "The db name (`BIRD` by default) for BIRD Mini-Dev dataset to connect MySQL or PostgreSQL database."} 58 | ) 59 | 60 | db_user: str = field( 61 | default="root", 62 | metadata={"help": "The db username (`root` by default) for BIRD Mini-Dev dataset to connect MySQL or PostgreSQL database. "} 63 | ) 64 | 65 | db_password: str = field( 66 | default="password", 67 | metadata={"help": "The db password (`password` by default) for BIRD Mini-Dev dataset to connect MySQL or PostgreSQL database. "} 68 | ) 69 | 70 | def __post_init__(self): 71 | 72 | for metric in self.eval_metrics: 73 | if metric not in ["ex", "em", "ves", "rves", "f1"]: 74 | raise ValueError("`eval_metrics` only supports metrics combinations in (`ex`, `em`, `ves`, `rves`, `f1`).") 75 | 76 | if self.num_processes <= 0: 77 | raise ValueError("`num_processes` should be positive.") 78 | 79 | if self.timeout <= 0: 80 | raise ValueError("`timeout` should be positive.") 81 | 82 | 
-------------------------------------------------------------------------------- /src/nl2sql360/parser/sql_parser.py: -------------------------------------------------------------------------------- 1 | from sqlglot import parse_one, exp 2 | 3 | 4 | class SQLParser: 5 | 6 | _SET_KEYWORDS = (exp.Union, exp.Except, exp.Intersect) 7 | 8 | _SCALAR_KEYWORDS = (exp.Abs, exp.Length, exp.Cast, exp.Round, exp.Upper, exp.Lower, exp.Rand) 9 | _SCALAR_KEYWORDS_ANONYMOUS_STR = ("STRFTIME", "JULIANDAY", "NOW", "INSTR", "SUBSTR") 10 | 11 | _MATH_COMPUTE_KEYWORDS = (exp.Add, exp.Sub, exp.Mul, exp.Div, exp.Mod) 12 | 13 | _LOGICAL_CONNECTOR_KEYWORDS = (exp.And, exp.Or) 14 | 15 | _CONTROL_FLOW_KEYWORDS = (exp.Case,) 16 | _CONTROL_FLOW_KEYWORDS_ANONYMOUS_STR = ("IIF",) 17 | 18 | def __init__(self, sql, dialect="sqlite"): 19 | self.ast = parse_one(sql, dialect=dialect) 20 | 21 | @property 22 | def count_query_fields(self): 23 | _ast = self.ast 24 | while isinstance(_ast, self._SET_KEYWORDS): 25 | _ast = _ast.this 26 | assert isinstance(_ast, exp.Select) 27 | return len(_ast.expressions) 28 | 29 | @property 30 | def count_group_by(self): 31 | return len(list(self.ast.find_all(exp.Group))) 32 | 33 | @property 34 | def count_order_by(self): 35 | return len(list(self.ast.find_all(exp.Order))) 36 | 37 | @property 38 | def count_limit(self): 39 | return len(list(self.ast.find_all(exp.Limit))) 40 | 41 | @property 42 | def count_join(self): 43 | return len(list(self.ast.find_all(exp.Join))) 44 | 45 | @property 46 | def count_predicate(self): 47 | return len(list(self.ast.find_all(exp.Predicate))) 48 | 49 | @property 50 | def count_aggregation(self): 51 | return len(list(self.ast.find_all(exp.AggFunc))) 52 | 53 | @property 54 | def count_scalar_function(self): 55 | scalar_nodes = list(self.ast.find_all(self._SCALAR_KEYWORDS)) 56 | scalar_nodes.extend([node for node in self.ast.find_all(exp.Anonymous) if node.this.upper() in self._SCALAR_KEYWORDS_ANONYMOUS_STR]) 57 | return len(scalar_nodes) 58 | 59 | @property 60 | def count_subquery(self): 61 | return len(list(self.ast.find_all(exp.Subquery))) 62 | 63 | @property 64 | def count_set_operation(self): 65 | return len(list(self.ast.find_all(self._SET_KEYWORDS))) 66 | 67 | @property 68 | def count_math_compute(self): 69 | return len(list(self.ast.find_all(self._MATH_COMPUTE_KEYWORDS))) 70 | 71 | @property 72 | def count_logical_connector(self): 73 | return len(list(self.ast.find_all(self._LOGICAL_CONNECTOR_KEYWORDS))) 74 | 75 | @property 76 | def count_distinct(self): 77 | return len(list(self.ast.find_all(exp.Distinct))) 78 | 79 | @property 80 | def count_like(self): 81 | return len(list(self.ast.find_all(exp.Like))) 82 | 83 | @property 84 | def count_control_flow(self): 85 | control_flow_nodes = list(self.ast.find_all(self._CONTROL_FLOW_KEYWORDS)) 86 | control_flow_nodes.extend([node for node in self.ast.find_all(exp.Anonymous) if node.this.upper() in self._CONTROL_FLOW_KEYWORDS_ANONYMOUS_STR]) 87 | return len(control_flow_nodes) 88 | 89 | @property 90 | def count_window(self): 91 | return len(list(self.ast.find_all(exp.Window))) 92 | 
-------------------------------------------------------------------------------- /src/nl2sql360/evaluator/bird_eval/evaluation_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import psycopg2 3 | import pymysql 4 | import sqlite3 5 | 6 | 7 | def load_json(dir): 8 | with open(dir, "r") as j: 9 | contents = json.loads(j.read()) 10 | return contents 11 | 12 | 13 | def connect_postgresql(dbname="BIRD", user="root", host="localhost", password="password", port=5432): 14 | # Open database connection 15 | # Connect to the database 16 | db = psycopg2.connect( 17 | f"dbname={dbname} user={user} host={host} password={password} port={port}" 18 | ) 19 | return db 20 | 21 | 22 | def connect_mysql(dbname="BIRD", user="root", host="localhost", password="password", port=3306): 23 | # Open database connection 24 | # Connect to the database" 25 | db = pymysql.connect( 26 | host=host, 27 | user=user, 28 | password=password, 29 | database=dbname, 30 | unix_socket="/tmp/mysql.sock", 31 | port=port, 32 | ) 33 | return db 34 | 35 | 36 | def connect_db(sql_dialect, db_path, **kwds): 37 | if sql_dialect == "SQLite": 38 | conn = sqlite3.connect(db_path) 39 | elif sql_dialect == "MySQL": 40 | conn = connect_mysql(**kwds) 41 | elif sql_dialect == "PostgreSQL": 42 | conn = connect_postgresql(**kwds) 43 | else: 44 | raise ValueError("Unsupported SQL dialect") 45 | return conn 46 | 47 | 48 | def execute_sql(predicted_sql, ground_truth, db_path, sql_dialect, calculate_func, **kwds): 49 | conn = connect_db(sql_dialect, db_path, **kwds) 50 | # Connect to the database 51 | cursor = conn.cursor() 52 | cursor.execute(predicted_sql) 53 | predicted_res = cursor.fetchall() 54 | cursor.execute(ground_truth) 55 | ground_truth_res = cursor.fetchall() 56 | conn.close() 57 | res = calculate_func(predicted_res, ground_truth_res) 58 | return res 59 | 60 | 61 | def package_sqls( 62 | sql_path, db_root_path, engine, sql_dialect="SQLite", mode="gpt", data_mode="dev" 63 | ): 64 | clean_sqls = [] 65 | db_path_list = [] 66 | if mode == "gpt": 67 | # use chain of thought 68 | sql_data = json.load( 69 | open( 70 | sql_path 71 | + "predict_" 72 | + data_mode 73 | + "_" 74 | + engine 75 | + "_cot_" 76 | + sql_dialect 77 | + ".json", 78 | "r", 79 | ) 80 | ) 81 | for _, sql_str in sql_data.items(): 82 | if type(sql_str) == str: 83 | sql, db_name = sql_str.split("\t----- bird -----\t") 84 | else: 85 | sql, db_name = " ", "financial" 86 | clean_sqls.append(sql) 87 | db_path_list.append(db_root_path + db_name + "/" + db_name + ".sqlite") 88 | 89 | elif mode == "gt": 90 | sqls = open(sql_path + data_mode + "_" + sql_dialect + "_gold.sql") 91 | sql_txt = sqls.readlines() 92 | # sql_txt = [sql.split('\t')[0] for sql in sql_txt] 93 | for idx, sql_str in enumerate(sql_txt): 94 | # print(sql_str) 95 | sql, db_name = sql_str.strip().split("\t") 96 | clean_sqls.append(sql) 97 | db_path_list.append(db_root_path + db_name + "/" + db_name + ".sqlite") 98 | 99 | return clean_sqls, db_path_list 100 | 101 | 102 | def sort_results(list_of_dicts): 103 | return sorted(list_of_dicts, key=lambda x: x["sql_idx"]) 104 | 105 | 106 | def print_data(score_lists, count_lists, metric="F1 Score"): 107 | levels = ["simple", "moderate", "challenging", "total"] 108 | print("{:20} {:20} {:20} {:20} {:20}".format("", *levels)) 109 | print("{:20} {:<20} {:<20} {:<20} {:<20}".format("count", *count_lists)) 110 | 111 | print( 112 | f"====================================== {metric} =====================================" 113 | 
) 114 | print("{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format(metric, *score_lists)) 115 | -------------------------------------------------------------------------------- /src/nl2sql360/evaluator/bird_eval/bird_ex.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import multiprocessing as mp 4 | from func_timeout import func_timeout, FunctionTimedOut 5 | from .evaluation_utils import ( 6 | load_json, 7 | execute_sql, 8 | package_sqls, 9 | sort_results, 10 | print_data, 11 | ) 12 | from tqdm import tqdm 13 | 14 | exec_result = [] 15 | progress_bar = None 16 | 17 | 18 | def result_callback(result): 19 | exec_result.append(result) 20 | progress_bar.update() 21 | 22 | 23 | def calculate_ex(predicted_res, ground_truth_res): 24 | res = 0 25 | if set(predicted_res) == set(ground_truth_res): 26 | res = 1 27 | return res 28 | 29 | 30 | def execute_model( 31 | predicted_sql, ground_truth, db_place, idx, meta_time_out, sql_dialect, **kwds 32 | ): 33 | try: 34 | res = func_timeout( 35 | meta_time_out, 36 | execute_sql, 37 | args=(predicted_sql, ground_truth, db_place, sql_dialect, calculate_ex), 38 | kwargs=kwds 39 | ) 40 | except KeyboardInterrupt: 41 | sys.exit(0) 42 | except FunctionTimedOut: 43 | result = [(f"timeout",)] 44 | res = 0 45 | except Exception as e: 46 | result = [(f"error",)] # possibly len(query) > 512 or not executable 47 | res = 0 48 | result = {"sql_idx": idx, "res": res} 49 | return result 50 | 51 | 52 | def run_sqls_parallel( 53 | sqls, db_places, num_cpus=1, meta_time_out=30.0, sql_dialect="SQLite", **kwds 54 | ): 55 | global exec_result, progress_bar 56 | exec_result.clear() 57 | progress_bar = tqdm(total=len(sqls)) 58 | pool = mp.Pool(processes=num_cpus) 59 | for i, sql_pair in enumerate(sqls): 60 | predicted_sql, ground_truth = sql_pair 61 | pool.apply_async( 62 | execute_model, 63 | args=( 64 | predicted_sql, 65 | ground_truth, 66 | db_places[i], 67 | i, 68 | meta_time_out, 69 | sql_dialect, 70 | ), 71 | kwds=kwds, 72 | callback=result_callback, 73 | ) 74 | pool.close() 75 | pool.join() 76 | return exec_result 77 | 78 | 79 | def compute_acc_by_diff(exec_results, diff_json_path): 80 | num_queries = len(exec_results) 81 | results = [res["res"] for res in exec_results] 82 | contents = load_json(diff_json_path) 83 | simple_results, moderate_results, challenging_results = [], [], [] 84 | 85 | for i, content in enumerate(contents): 86 | if content["difficulty"] == "simple": 87 | simple_results.append(exec_results[i]) 88 | 89 | if content["difficulty"] == "moderate": 90 | moderate_results.append(exec_results[i]) 91 | 92 | if content["difficulty"] == "challenging": 93 | try: 94 | challenging_results.append(exec_results[i]) 95 | except: 96 | print(i) 97 | 98 | simple_acc = sum([res["res"] for res in simple_results]) / len(simple_results) 99 | moderate_acc = sum([res["res"] for res in moderate_results]) / len(moderate_results) 100 | challenging_acc = sum([res["res"] for res in challenging_results]) / len( 101 | challenging_results 102 | ) 103 | all_acc = sum(results) / num_queries 104 | count_lists = [ 105 | len(simple_results), 106 | len(moderate_results), 107 | len(challenging_results), 108 | num_queries, 109 | ] 110 | return ( 111 | simple_acc * 100, 112 | moderate_acc * 100, 113 | challenging_acc * 100, 114 | all_acc * 100, 115 | count_lists, 116 | ) 117 | 118 | 119 | if __name__ == "__main__": 120 | args_parser = argparse.ArgumentParser() 121 | args_parser.add_argument( 122 | 
"--predicted_sql_path", type=str, required=True, default="" 123 | ) 124 | args_parser.add_argument("--ground_truth_path", type=str, required=True, default="") 125 | args_parser.add_argument("--data_mode", type=str, required=True, default="dev") 126 | args_parser.add_argument("--db_root_path", type=str, required=True, default="") 127 | args_parser.add_argument("--num_cpus", type=int, default=1) 128 | args_parser.add_argument("--meta_time_out", type=float, default=30.0) 129 | args_parser.add_argument("--mode_gt", type=str, default="gt") 130 | args_parser.add_argument("--mode_predict", type=str, default="gpt") 131 | args_parser.add_argument("--difficulty", type=str, default="simple") 132 | args_parser.add_argument("--diff_json_path", type=str, default="") 133 | args_parser.add_argument("--engine", type=str, default="") 134 | args_parser.add_argument("--sql_dialect", type=str, default="SQLite") 135 | args = args_parser.parse_args() 136 | exec_result = [] 137 | 138 | pred_queries, db_paths = package_sqls( 139 | args.predicted_sql_path, 140 | args.db_root_path, 141 | args.engine, 142 | sql_dialect=args.sql_dialect, 143 | mode=args.mode_predict, 144 | data_mode=args.data_mode, 145 | ) 146 | # generate ground truth sqls: 147 | gt_queries, db_paths_gt = package_sqls( 148 | args.ground_truth_path, 149 | args.db_root_path, 150 | args.engine, 151 | sql_dialect=args.sql_dialect, 152 | mode="gt", 153 | data_mode=args.data_mode, 154 | ) 155 | 156 | query_pairs = list(zip(pred_queries, gt_queries)) 157 | 158 | run_sqls_parallel( 159 | query_pairs, 160 | db_places=db_paths, 161 | num_cpus=args.num_cpus, 162 | meta_time_out=args.meta_time_out, 163 | sql_dialect=args.sql_dialect, 164 | ) 165 | exec_result = sort_results(exec_result) 166 | print("start calculate") 167 | simple_acc, moderate_acc, challenging_acc, acc, count_lists = compute_acc_by_diff( 168 | exec_result, args.diff_json_path 169 | ) 170 | score_lists = [simple_acc, moderate_acc, challenging_acc, acc] 171 | print(f"EX for {args.engine} on {args.sql_dialect} set") 172 | print("start calculate") 173 | print_data(score_lists, count_lists, metric="EX") 174 | print( 175 | "===========================================================================================" 176 | ) 177 | print(f"Finished EX evaluation for {args.engine} on {args.sql_dialect} set") 178 | print("\n\n") 179 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # :mag_right:NL2SQL360 2 | 3 |
4 | 5 | ## :dizzy:Overview 6 | 7 | **NL2SQL360** is a testbed for fine-grained evaluation of NL2SQL solutions. Our testbed integrates existing NL2SQL benchmarks, a repository of NL2SQL models, and various evaluation metrics, which aims to provide an intuitive and user-friendly platform to enable both standard and customized performance evaluations. Users can utilize **NL2SQL360** to assess different NL2SQL methods against established benchmarks or tailor their evaluations based on specific criteria. This flexibility allows for testing solutions in specific data domains or analyzing performance on different characteristics of SQL queries. 8 | 9 | In addition, we propose **SuperSQL**, which achieves competitive performance with execution accuracy of **87%** and **62.66%** on the Spider and BIRD test sets, respectively. 10 | 11 | >[!TIP] 12 | > We also offer the [NL2SQL Handbook](https://github.com/HKUSTDial/NL2SQL_Handbook), which tracks the latest NL2SQL advancements in the literature and offers practical guidance for researchers and practitioners! If you have any suggestions, feel free to submit an issue to us! 13 | 14 | ## :tada:News 15 | 16 | [24/9/23] We release NL2SQL360 `1.1.0` version, which supports two new metrics **Reward-based VES** (RVES), **Soft-F1 Score** (F1), from [BIRD-Mini-Dev ](https://github.com/bird-bench/mini_dev) dataset. **Please update your package with** `pip install --upgrade nl2sql360`. 17 | 18 | [24/9/1] We have released our **[Homepage & Leaderboard](https://nl2sql360.github.io)!** 19 | 20 | [24/8/2] We have released CLI usage / Code usage tutorials. **Please [check out](#rocketquick-start)!** 21 | 22 | [24/7/30] We have refactored the code and released the official python package([nl2sql360 · PyPI](https://pypi.org/project/nl2sql360)). **Stay tuned for the complete documents!** 23 | 24 | [24/6/30] Our paper [The Dawn of Natural Language to SQL: Are We Fully Ready?](https://arxiv.org/abs/2406.01265) has been accepted by VLDB'24. 25 | 26 | ## :balloon:Features 27 | 28 | - **Easy-to-use Evaluation**: Command Line Usage / Python Code Usage. 29 | - **Integrated Metrics**: Execution Accuracy / Exact-Match Accuracy / Valid Efficiency Score / Question Variance Testing. 30 | - **Multi-angle Performance**: Fine-grained performance (JOIN, Sub-query, etc.) / Scenario-based (Business Intelligence, etc.) 31 | 32 | ## :wrench:Installation 33 | 34 | ```bash 35 | pip install nl2sql360 36 | ``` 37 | 38 | ## :rocket:Quick Start 39 | 40 |
Prepare Dataset 41 | 42 | Download the NL2SQL dataset to `DATASET_DIR_PATH`. The directory structure should look like: 43 | ```bash 44 | DATASET_DIR_PATH: 45 | ├─database 46 | │ ├─academic 47 | │ │ ├─academic.sqlite 48 | │ ├─college 49 | │ │ ├─college.sqlite 50 | ├─dev.json 51 | ├─tables.json 52 | ``` 53 | 54 | - The `database` directory contains one subdirectory per database, each holding the corresponding `sqlite` database file. 55 | - `dev.json` is the samples file in JSON format, which contains at least three keys: `NL Question`, `Gold SQL`, and `Database Id`. You can also add a `Sample Complexity` key to categorize samples into different difficulty levels. 56 | - `tables.json` contains all database schemas, following the [Spider Preprocess Procedure](https://github.com/taoyds/spider/tree/master/preprocess). **You can ignore this file if you do not want to evaluate the Exact-Match Accuracy metric.** 57 | - Note that the names of the `database` directory, the samples file `dev.json`, and the tables file `tables.json` can be changed. 58 | 59 | 
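
To make the samples file format concrete, here is a minimal, hypothetical `dev.json` fragment using the default key names (`question`, `sql`, `db_id`). Different key names are fine as long as they are declared when importing the dataset (see the next step), and a complexity key may optionally be added per sample.

```json
[
  {
    "question": "How many papers are stored in the academic database?",
    "sql": "SELECT COUNT(*) FROM paper",
    "db_id": "academic"
  }
]
```
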
60 | 61 |
Import Dataset into NL2SQL360 62 | 63 | - CLI Usage: 64 | 65 | - Create / Modify the YAML configuration following [NL2SQL360/examples/cli_examples/dataset_spider.yaml](https://github.com/HKUSTDial/NL2SQL360/blob/master/examples/cli_examples/dataset_spider.yaml). 66 | 67 | - Save the YAML file to the path `DATASET_YAML_PATH`. Then run the command line: 68 | 69 | ```bash 70 | nl2sql360-cli dataset DATASET_YAML_PATH 71 | ``` 72 | 73 | - Code Usage: 74 | 75 | - Create / Modify Python File following [NL2SQL360/examples/py_examples/dataset_import.py](https://github.com/HKUSTDial/NL2SQL360/blob/master/examples/py_examples/dataset_import.py). 76 | - Run the python file to import dataset. 77 | 78 |
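
For orientation only, a minimal sketch of what such a dataset YAML can look like, assuming the directory layout from the previous step. The keys mirror the core and dataset argument fields used elsewhere in this repository (compare the report YAMLs above and `DatasetArguments`); all values here are placeholders, and the linked `dataset_spider.yaml` remains the authoritative example.

```yaml
# ---------------- Core Arguments ----------------
core_dir: "./nl2sql360_cache"      # directory for NL2SQL360 core data
core_name: "nl2sql360"             # core data is saved to "core_dir/core_name.sqlite"
sql_dialect: "SQLite"

# ---------------- Dataset Arguments ----------------
dataset_name: "spider_dev"         # unique dataset name (placeholder)
dataset_dir: "DATASET_DIR_PATH"    # folder containing the dataset
samples_file: "dev.json"           # JSON samples file
database_dir: "database"           # directory containing the sqlite databases
tables_file: "tables.json"         # optional; needed only for Exact-Match Accuracy
question_key: "question"           # key of NL questions in the samples file
sql_key: "sql"                     # key of gold SQL queries
db_id_key: "db_id"                 # key of database ids
```
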
79 | 80 |
Evaluate NL2SQL Model 81 | 82 | - CLI Usage: 83 | 84 | - Create / Modify the YAML configuration following [NL2SQL360/examples/cli_examples/evaluation.yaml](https://github.com/HKUSTDial/NL2SQL360/blob/master/examples/cli_examples/evaluation.yaml). 85 | 86 | - Save the YAML file to the path `DATASET_YAML_PATH`. Then run the command line: 87 | 88 | ```bash 89 | nl2sql360-cli evaluate DATASET_YAML_PATH 90 | ``` 91 | 92 | - Code Usage: 93 | 94 | - Create / Modify Python File following [NL2SQL360/examples/py_examples/evaluation.py](https://github.com/HKUSTDial/NL2SQL360/blob/master/examples/py_examples/evaluation.py). 95 | - Run the python file to evaluate the model. 96 | 97 | 
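
As with the dataset step, here is a minimal sketch of an evaluation YAML, with keys mirroring the core and evaluation argument fields defined in this repository (`EvaluationArguments`). All values are placeholders; the linked `evaluation.yaml` is the authoritative example.

```yaml
# ---------------- Core Arguments ----------------
core_dir: "./nl2sql360_cache"
core_name: "nl2sql360"
sql_dialect: "SQLite"

# ---------------- Evaluation Arguments ----------------
eval_name: "my_model"              # name under which this evaluation is stored (placeholder)
eval_dataset: "spider_dev"         # dataset name imported in the previous step
eval_metrics:                      # any of "ex", "em", "ves", "rves", "f1"
  - "ex"
  - "em"
pred_sqls_file: "pred_sqls.txt"    # predicted SQL queries, one per line
enable_spider_eval: true           # use the official Spider evaluator
num_processes: 8
timeout: 30.0                      # per-query SQL execution timeout (seconds)
```
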
98 | 99 |
Query Multi-angle Performance 100 | 101 | - CLI Usage: 102 | 103 | - Create / Modify the YAML configuration following [NL2SQL360/examples/cli_examples/report.yaml](https://github.com/HKUSTDial/NL2SQL360/blob/master/examples/cli_examples/report.yaml). 104 | 105 | - Save the YAML file to the path `DATASET_YAML_PATH`. Then run the command line: 106 | 107 | ```bash 108 | nl2sql360-cli report DATASET_YAML_PATH 109 | ``` 110 | 111 | - The generated report will be in `save_path` specified in the YAML file. 112 | 113 | - Code Usage: 114 | - Create / Modify Python File following [NL2SQL360/examples/py_examples/report.py](https://github.com/HKUSTDial/NL2SQL360/blob/master/examples/py_examples/report.py). 115 | - Run the python file to generate report. 116 | 117 |
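
The `filter` and `scenario` entries of the report YAML follow the small grammar defined in `src/nl2sql360/filter/filter.py`: a filter expression has the form `{FILTER_KEY} {<, >, =} {NUMBER}` (valid keys are the `Field` enum values such as `JOIN`, `GROUP_BY`, `SUBQUERY`, ...), and a scenario combines several filter expressions with `&&`. A small illustrative fragment (the scenario name is made up):

```yaml
filter:
  -
    name: "Filter - JOIN"
    expression: "JOIN > 0"

scenario:
  -
    name: "Scenario - Grouped Join Queries"
    expression: "JOIN > 0 && GROUP_BY > 0"
```
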
118 | 119 |
Delete History Cache 120 | 121 | - CLI Usage: 122 | 123 | - Create / Modify the YAML configuration following [NL2SQL360/examples/cli_examples/delete_history.yaml](https://github.com/HKUSTDial/NL2SQL360/blob/master/examples/cli_examples/delete_history.yaml). 124 | 125 | - Save the YAML file to the path `DATASET_YAML_PATH`. Then run the command line: 126 | 127 | ```bash 128 | nl2sql360-cli delete DATASET_YAML_PATH 129 | ``` 130 | 131 | - Code Usage: 132 | 133 | - Create / Modify Python File following [NL2SQL360/examples/py_examples/delete_history.py](https://github.com/HKUSTDial/NL2SQL360/blob/master/examples/py_examples/delete_history.py). 134 | - Run the python file to delete dataset / evaluation cache. 135 | 136 |
137 | 138 | ## :dart:Road Map 139 | 140 | :white_check_mark:Release **NL2SQL360** evaluation code. 141 | 142 | :white_check_mark:Release **NL2SQL360** experiments data. 143 | 144 | :white_check_mark:Release **NL2SQL360** Official Python Package. 145 | 146 | ## :floppy_disk:Experiment Data 147 | 148 | We have released all experiment data used in our paper. 149 | 150 | [Download Link](https://drive.google.com/drive/folders/1SDwY30H2r6XNYeS53wcZNocVFm0hVgpz?usp=sharing) 151 | 152 | ## :pushpin:Citation 153 | 154 | ``` 155 | @article{nl2sql360, 156 | author = {Boyan Li and 157 | Yuyu Luo and 158 | Chengliang Chai and 159 | Guoliang Li and 160 | Nan Tang}, 161 | title = {The Dawn of Natural Language to {SQL:} Are We Fully Ready? }, 162 | journal = {Proc. {VLDB} Endow.}, 163 | volume = {17}, 164 | number = {11}, 165 | pages = {3318--3331}, 166 | year = {2024} 167 | } 168 | ``` 169 | -------------------------------------------------------------------------------- /src/nl2sql360/evaluator/test_suite_sql_eval/parse.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sqlparse 3 | from typing import List, Tuple, Set, Iterator, Dict, Any, Union 4 | from sqlparse.sql import Comparison, Identifier 5 | from sqlparse.tokens import Whitespace 6 | import itertools 7 | from collections import namedtuple 8 | 9 | Token = namedtuple('Token', ['ttype', 'value']) 10 | VALUE_NUM_SYMBOL = 'VALUERARE' 11 | QUOTE_CHARS = {'`', '\'', '"'} 12 | 13 | 14 | def tokenize(query: str) -> List[Token]: 15 | tokens = list([Token(t.ttype, t.value) for t in sqlparse.parse(query)[0].flatten()]) 16 | return tokens 17 | 18 | 19 | def join_tokens(tokens: List[Token]) -> str: 20 | return ''.join([x.value for x in tokens]).strip().replace(' ', ' ') 21 | 22 | 23 | def round_trip_test(query: str) -> None: 24 | tokens = tokenize(query) 25 | reconstructed = ''.join([token.value for token in tokens]) 26 | assert query == reconstructed, "Round trip test fails for string %s" % query 27 | 28 | 29 | def postprocess(query: str) -> str: 30 | query = query.replace('> =', '>=').replace('< =', '<=').replace('! =', '!=') 31 | return query 32 | 33 | 34 | # strip_query, reformat_query and replace values 35 | # were implemented by Yu Tao for processing CoSQL 36 | def strip_query(query: str) -> Tuple[List[str], List[str]]: 37 | query_keywords, all_values = [], [] 38 | 39 | # then replace all stuff enclosed by "" with a numerical value to get it marked as {VALUE} 40 | 41 | # Tao's implementation is commented out here. 
42 | """ 43 | str_1 = re.findall("\"[^\"]*\"", query) 44 | str_2 = re.findall("\'[^\']*\'", query) 45 | values = str_1 + str_2 46 | """ 47 | 48 | toks = sqlparse.parse(query)[0].flatten() 49 | values = [t.value for t in toks if t.ttype == sqlparse.tokens.Literal.String.Single or t.ttype == sqlparse.tokens.Literal.String.Symbol] 50 | 51 | 52 | for val in values: 53 | all_values.append(val) 54 | query = query.replace(val.strip(), VALUE_NUM_SYMBOL) 55 | 56 | query_tokenized = query.split() 57 | float_nums = re.findall("[-+]?\d*\.\d+", query) 58 | all_values += [qt for qt in query_tokenized if qt in float_nums] 59 | query_tokenized = [VALUE_NUM_SYMBOL if qt in float_nums else qt for qt in query_tokenized] 60 | 61 | query = " ".join(query_tokenized) 62 | int_nums = [i.strip() for i in re.findall("[^tT]\d+", query)] 63 | 64 | all_values += [qt for qt in query_tokenized if qt in int_nums] 65 | query_tokenized = [VALUE_NUM_SYMBOL if qt in int_nums else qt for qt in query_tokenized] 66 | # print int_nums, query, query_tokenized 67 | 68 | for tok in query_tokenized: 69 | if "." in tok: 70 | table = re.findall("[Tt]\d+\.", tok) 71 | if len(table) > 0: 72 | to = tok.replace(".", " . ").split() 73 | to = [t.lower() for t in to if len(t) > 0] 74 | query_keywords.extend(to) 75 | else: 76 | query_keywords.append(tok.lower()) 77 | 78 | elif len(tok) > 0: 79 | query_keywords.append(tok.lower()) 80 | return query_keywords, all_values 81 | 82 | 83 | def reformat_query(query: str) -> str: 84 | query = query.strip().replace(";", "").replace("\t", "") 85 | query = ' '.join([t.value for t in tokenize(query) if t.ttype != sqlparse.tokens.Whitespace]) 86 | t_stars = ["t1.*", "t2.*", "t3.*", "T1.*", "T2.*", "T3.*"] 87 | for ts in t_stars: 88 | query = query.replace(ts, "*") 89 | return query 90 | 91 | 92 | def replace_values(sql: str) -> Tuple[List[str], Set[str]]: 93 | sql = sqlparse.format(sql, reindent=False, keyword_case='upper') 94 | # sql = re.sub(r"(<=|>=|!=|=|<|>|,)", r" \1 ", sql) 95 | sql = re.sub(r"(T\d+\.)\s", r"\1", sql) 96 | query_toks_no_value, values = strip_query(sql) 97 | return query_toks_no_value, set(values) 98 | 99 | 100 | # extract the non-value tokens and the set of values 101 | # from a sql query 102 | def extract_query_values(sql: str) -> Tuple[List[str], Set[str]]: 103 | reformated = reformat_query(query=sql) 104 | query_value_replaced, values = replace_values(reformated) 105 | return query_value_replaced, values 106 | 107 | 108 | # plug in the values into query with value slots 109 | def plugin(query_value_replaced: List[str], values_in_order: List[str]) -> str: 110 | q_length = len(query_value_replaced) 111 | query_w_values = query_value_replaced[:] 112 | value_idx = [idx for idx in range(q_length) if query_value_replaced[idx] == VALUE_NUM_SYMBOL.lower()] 113 | assert len(value_idx) == len(values_in_order) 114 | 115 | for idx, value in zip(value_idx, values_in_order): 116 | query_w_values[idx] = value 117 | return ' '.join(query_w_values) 118 | 119 | 120 | # a generator generating all possible ways of 121 | # filling values into predicted query 122 | def plugin_all_permutations(query_value_replaced: List[str], values: Set[str]) -> Iterator[str]: 123 | num_slots = len([v for v in query_value_replaced if v == VALUE_NUM_SYMBOL.lower()]) 124 | for values in itertools.product(*[list(values) for _ in range(num_slots)]): 125 | yield plugin(query_value_replaced, list(values)) 126 | 127 | 128 | # given the gold query and the model prediction 129 | # extract values from the gold, extract predicted 
sql with value slots 130 | # return 1) number of possible ways to plug in gold values and 2) an iterator of predictions with value plugged in 131 | def get_all_preds_for_execution(gold: str, pred: str) -> Tuple[int, Iterator[str]]: 132 | _, gold_values = extract_query_values(gold) 133 | pred_query_value_replaced, _ = extract_query_values(pred) 134 | num_slots = len([v for v in pred_query_value_replaced if v == VALUE_NUM_SYMBOL.lower()]) 135 | num_alternatives = len(gold_values) ** num_slots 136 | return num_alternatives, plugin_all_permutations(pred_query_value_replaced, gold_values) 137 | 138 | 139 | def remove_distinct(s): 140 | toks = [t.value for t in list(sqlparse.parse(s)[0].flatten())] 141 | return ''.join([t for t in toks if t.lower() != 'distinct']) 142 | 143 | 144 | def extract_all_comparison_from_node(node: Token) -> List[Comparison]: 145 | comparison_list = [] 146 | if hasattr(node, 'tokens'): 147 | for t in node.tokens: 148 | comparison_list.extend(extract_all_comparison_from_node(t)) 149 | if type(node) == Comparison: 150 | comparison_list.append(node) 151 | return comparison_list 152 | 153 | 154 | def extract_all_comparison(query: str) -> List[Comparison]: 155 | tree = sqlparse.parse(query)[0] 156 | comparison_list = extract_all_comparison_from_node(tree) 157 | return comparison_list 158 | 159 | 160 | def extract_toks_from_comparison(comparison_node: Comparison) -> List[Token]: 161 | tokens = [t for t in comparison_node.tokens if t.ttype != Whitespace] 162 | return tokens 163 | 164 | 165 | def extract_info_from_comparison(comparison_node: Comparison) -> Dict[str, Any]: 166 | tokens = extract_toks_from_comparison(comparison_node) 167 | left, op, right = tokens 168 | 169 | returned_dict = { 170 | 'left': left, 171 | 'op': op.value, 172 | 'right': right 173 | } 174 | 175 | if type(left) != Identifier: 176 | return returned_dict 177 | 178 | table = None 179 | if len(left.tokens) == 3 and re.match('^[tT][0-9]$', left.tokens[0].value) is None: 180 | table = left.tokens[0].value.lower() 181 | col = left.tokens[-1].value 182 | 183 | if type(right) == Identifier: 184 | if len(right.tokens) == 1 and type(right.tokens[0]) == sqlparse.sql.Token: 185 | right_val = right.tokens[0].value 186 | else: 187 | return returned_dict 188 | elif type(right) == sqlparse.sql.Token: 189 | right_val = right.value 190 | else: 191 | return returned_dict 192 | 193 | returned_dict['table_col'], returned_dict['val'] = (table, col.upper()), process_str_value(right_val) 194 | 195 | return returned_dict 196 | 197 | 198 | def extract_all_comparison_from_query(query: str) -> List[Dict[str, Any]]: 199 | comparison_list = extract_all_comparison(query) 200 | return [extract_info_from_comparison(c) for c in comparison_list] 201 | 202 | 203 | def extract_typed_value_in_comparison_from_query(query: str) -> List[Tuple[Tuple[Union[str, None], str], str]]: 204 | cmps = extract_all_comparison_from_query(query) 205 | typed_values = [(cmp['table_col'], cmp['val']) for cmp in cmps if 'table_col' in cmp] 206 | for table, col, val1, val2 in re.findall('(?:([^\.\s]*)\.)?([^\.\s]+) between ([^\s;]+) and ([^\s;]+)', query, re.IGNORECASE): 207 | if table == '': 208 | table = None 209 | else: 210 | table = table.lower() 211 | col = col.upper() 212 | for v in [val1, val2]: 213 | typed_values.append(((table, col), v)) 214 | return typed_values 215 | 216 | 217 | def process_str_value(v: str) -> str: 218 | if len(v) > 0 and v[0] in QUOTE_CHARS: 219 | v = v[1:] 220 | if len(v) > 0 and v[-1] in QUOTE_CHARS: 221 | v = v[:-1] 222 | for c 
in QUOTE_CHARS: 223 | v = v.replace(c + c, c) 224 | return v -------------------------------------------------------------------------------- /src/nl2sql360/evaluator/bird_eval/bird_ves.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | import numpy as np 4 | import argparse 5 | import multiprocessing as mp 6 | from func_timeout import func_timeout, FunctionTimedOut 7 | from .evaluation_utils import ( 8 | load_json, 9 | package_sqls, 10 | sort_results, 11 | print_data, 12 | connect_db, 13 | ) 14 | import time 15 | import math 16 | from tqdm import tqdm 17 | 18 | exec_result = [] 19 | progress_bar = None 20 | 21 | def result_callback(result): 22 | exec_result.append(result) 23 | progress_bar.update() 24 | 25 | 26 | def clean_abnormal(input): 27 | input = np.asarray(input) 28 | processed_list = [] 29 | mean = np.mean(input, axis=0) 30 | std = np.std(input, axis=0) 31 | for x in input: 32 | if x < mean + 3 * std and x > mean - 3 * std: 33 | processed_list.append(x) 34 | return processed_list 35 | 36 | 37 | def execute_sql(sql, db_path, sql_dialect, return_time=False, **kwds): 38 | # Connect to the database 39 | conn = connect_db(sql_dialect, db_path, **kwds) 40 | start_time = time.time() 41 | cursor = conn.cursor() 42 | cursor.execute(sql) 43 | res = cursor.fetchall() 44 | conn.close() # Don't forget to close the connection! 45 | exec_time = time.time() - start_time 46 | if return_time: 47 | return exec_time 48 | 49 | return res 50 | 51 | 52 | def iterated_execute_sql( 53 | predicted_sql, ground_truth, db_path, iterate_num, sql_dialect, exec_acc, **kwds 54 | ): 55 | diff_list = [] 56 | predicted_res = execute_sql(predicted_sql, db_path, sql_dialect, **kwds) 57 | ground_truth_res = execute_sql(ground_truth, db_path, sql_dialect, **kwds) 58 | time_ratio = 0 59 | if (exec_acc is None and set(predicted_res) == set(ground_truth_res)) or (exec_acc is not None and exec_acc == 1): 60 | for _ in range(iterate_num): 61 | predicted_time = execute_sql( 62 | predicted_sql, db_path, sql_dialect, return_time=True, 63 | **kwds 64 | ) 65 | ground_truth_time = execute_sql( 66 | ground_truth, db_path, sql_dialect, return_time=True, 67 | **kwds 68 | ) 69 | diff_list.append(ground_truth_time / predicted_time) 70 | processed_diff_list = clean_abnormal(diff_list) 71 | time_ratio = sum(processed_diff_list) / len(processed_diff_list) 72 | return time_ratio 73 | 74 | 75 | def execute_model( 76 | predicted_sql, ground_truth, db_place, idx, iterate_num, meta_time_out, sql_dialect, exec_acc, **kwds 77 | ): 78 | try: 79 | # you can personalize the total timeout number 80 | # larger timeout leads to more stable ves 81 | # while it needs more your patience.... 
82 | time_ratio = func_timeout( 83 | meta_time_out * iterate_num, 84 | iterated_execute_sql, 85 | args=(predicted_sql, ground_truth, db_place, iterate_num, sql_dialect, exec_acc), 86 | kwargs=kwds 87 | ) 88 | except KeyboardInterrupt: 89 | sys.exit(0) 90 | except FunctionTimedOut: 91 | result = [(f"timeout",)] 92 | time_ratio = 0 93 | except Exception as e: 94 | result = [(f"error",)] # possibly len(query) > 512 or not executable 95 | time_ratio = 0 96 | result = {"sql_idx": idx, "time_ratio": time_ratio} 97 | return result 98 | 99 | 100 | def run_sqls_parallel( 101 | sqls, 102 | db_places, 103 | num_cpus=1, 104 | iterate_num=100, 105 | meta_time_out=30.0, 106 | sql_dialect="SQLite", 107 | exec_acc_list=None, 108 | **kwds 109 | ): 110 | global exec_result, progress_bar 111 | exec_result.clear() 112 | progress_bar = tqdm(total=len(sqls)) 113 | pool = mp.Pool(processes=num_cpus) 114 | for i, sql_pair in enumerate(sqls): 115 | predicted_sql, ground_truth = sql_pair 116 | exec_acc = exec_acc_list[i] if exec_acc_list else None 117 | pool.apply_async( 118 | execute_model, 119 | args=( 120 | predicted_sql, 121 | ground_truth, 122 | db_places[i], 123 | i, 124 | iterate_num, 125 | meta_time_out, 126 | sql_dialect, 127 | exec_acc 128 | ), 129 | kwds=kwds, 130 | callback=result_callback, 131 | ) 132 | pool.close() 133 | pool.join() 134 | return exec_result 135 | 136 | 137 | def compute_ves(exec_results): 138 | num_queries = len(exec_results) 139 | total_ratio = 0 140 | count = 0 141 | 142 | for i, result in enumerate(exec_results): 143 | if result["time_ratio"] != 0: 144 | count += 1 145 | total_ratio += math.sqrt(result["time_ratio"]) * 100 146 | ves = total_ratio / num_queries 147 | return ves 148 | 149 | 150 | def compute_ves_by_diff(exec_results, diff_json_path): 151 | num_queries = len(exec_results) 152 | contents = load_json(diff_json_path) 153 | simple_results, moderate_results, challenging_results = [], [], [] 154 | for i, content in enumerate(contents): 155 | if content["difficulty"] == "simple": 156 | simple_results.append(exec_results[i]) 157 | if content["difficulty"] == "moderate": 158 | moderate_results.append(exec_results[i]) 159 | if content["difficulty"] == "challenging": 160 | challenging_results.append(exec_results[i]) 161 | simple_ves = compute_ves(simple_results) 162 | moderate_ves = compute_ves(moderate_results) 163 | challenging_ves = compute_ves(challenging_results) 164 | all_ves = compute_ves(exec_results) 165 | count_lists = [ 166 | len(simple_results), 167 | len(moderate_results), 168 | len(challenging_results), 169 | num_queries, 170 | ] 171 | return simple_ves, moderate_ves, challenging_ves, all_ves, count_lists 172 | 173 | 174 | def print_reward_category(exec_results, engine, sql_dialect): 175 | res = { 176 | "engine": engine, 177 | "sql_dialect": sql_dialect, 178 | "distribution": exec_results, 179 | } 180 | file_path = "results.json" 181 | try: 182 | with open(file_path, "r") as file: 183 | data = json.load(file) 184 | except (FileNotFoundError, json.JSONDecodeError): 185 | data = [] # Start with an empty list if file doesn't exist or is empty 186 | 187 | # Append the new data 188 | data.append(res) 189 | 190 | # Write the updated data back to the file 191 | with open(file_path, "w") as file: 192 | json.dump(data, file, indent=4) 193 | 194 | 195 | if __name__ == "__main__": 196 | args_parser = argparse.ArgumentParser() 197 | args_parser.add_argument( 198 | "--predicted_sql_path", type=str, required=True, default="" 199 | ) 200 | 
args_parser.add_argument("--ground_truth_path", type=str, required=True, default="") 201 | args_parser.add_argument("--data_mode", type=str, required=True, default="dev") 202 | args_parser.add_argument("--db_root_path", type=str, required=True, default="") 203 | args_parser.add_argument("--num_cpus", type=int, default=1) 204 | args_parser.add_argument("--meta_time_out", type=float, default=30.0) 205 | args_parser.add_argument("--mode_gt", type=str, default="gt") 206 | args_parser.add_argument("--mode_predict", type=str, default="gpt") 207 | args_parser.add_argument("--diff_json_path", type=str, default="") 208 | args_parser.add_argument("--engine", type=str, default="") 209 | args_parser.add_argument("--sql_dialect", type=str, default="SQLite") 210 | args = args_parser.parse_args() 211 | exec_result = [] 212 | 213 | pred_queries, db_paths = package_sqls( 214 | args.predicted_sql_path, 215 | args.db_root_path, 216 | args.engine, 217 | sql_dialect=args.sql_dialect, 218 | mode=args.mode_predict, 219 | data_mode=args.data_mode, 220 | ) 221 | # generate ground truth sqls: 222 | gt_queries, db_paths_gt = package_sqls( 223 | args.ground_truth_path, 224 | args.db_root_path, 225 | args.engine, 226 | sql_dialect=args.sql_dialect, 227 | mode="gt", 228 | data_mode=args.data_mode, 229 | ) 230 | query_pairs = list(zip(pred_queries, gt_queries)) 231 | run_sqls_parallel( 232 | query_pairs, 233 | db_places=db_paths, 234 | num_cpus=args.num_cpus, 235 | meta_time_out=args.meta_time_out, 236 | sql_dialect=args.sql_dialect, 237 | ) 238 | exec_result = sort_results(exec_result) 239 | # print_reward_category(exec_result, args.engine, args.sql_dialect) 240 | print("start calculate") 241 | simple_ves, moderate_ves, challenging_ves, ves, count_lists = compute_ves_by_diff( 242 | exec_result, args.diff_json_path 243 | ) 244 | score_lists = [simple_ves, moderate_ves, challenging_ves, ves] 245 | print(f"VES for {args.engine} on {args.sql_dialect} set") 246 | print("start calculate") 247 | print_data(score_lists, count_lists, metric="VES") 248 | print( 249 | "===========================================================================================" 250 | ) 251 | print(f"Finished VES evaluation for {args.engine} on {args.sql_dialect} set") 252 | print("\n\n") 253 | -------------------------------------------------------------------------------- /src/nl2sql360/evaluator/bird_eval/bird_rves.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | import numpy as np 4 | import argparse 5 | import multiprocessing as mp 6 | from func_timeout import func_timeout, FunctionTimedOut 7 | from .evaluation_utils import ( 8 | load_json, 9 | package_sqls, 10 | sort_results, 11 | print_data, 12 | connect_db, 13 | ) 14 | import time 15 | import math 16 | from tqdm import tqdm 17 | 18 | exec_result = [] 19 | progress_bar = None 20 | 21 | def result_callback(result): 22 | exec_result.append(result) 23 | progress_bar.update() 24 | 25 | 26 | def clean_abnormal(input): 27 | input = np.asarray(input) 28 | processed_list = [] 29 | mean = np.mean(input, axis=0) 30 | std = np.std(input, axis=0) 31 | for x in input: 32 | if x < mean + 3 * std and x > mean - 3 * std: 33 | processed_list.append(x) 34 | return processed_list 35 | 36 | 37 | def execute_sql(sql, db_path, sql_dialect, return_time=False, **kwds): 38 | # Connect to the database 39 | conn = connect_db(sql_dialect, db_path, **kwds) 40 | start_time = time.time() 41 | cursor = conn.cursor() 42 | cursor.execute(sql) 43 | res = 
cursor.fetchall() 44 | conn.close() # Don't forget to close the connection! 45 | exec_time = time.time() - start_time 46 | if return_time: 47 | return exec_time 48 | 49 | return res 50 | 51 | 52 | def iterated_execute_sql( 53 | predicted_sql, ground_truth, db_path, iterate_num, sql_dialect, exec_acc, **kwds 54 | ): 55 | diff_list = [] 56 | predicted_res = execute_sql(predicted_sql, db_path, sql_dialect, **kwds) 57 | ground_truth_res = execute_sql(ground_truth, db_path, sql_dialect, **kwds) 58 | reward = 0 59 | time_ratio = 0 60 | if (exec_acc is None and set(predicted_res) == set(ground_truth_res)) or (exec_acc is not None and exec_acc == 1): 61 | for _ in range(iterate_num): 62 | predicted_time = execute_sql( 63 | predicted_sql, db_path, sql_dialect, return_time=True, **kwds 64 | ) 65 | ground_truth_time = execute_sql( 66 | ground_truth, db_path, sql_dialect, return_time=True, **kwds 67 | ) 68 | diff_list.append(ground_truth_time / predicted_time) 69 | processed_diff_list = clean_abnormal(diff_list) 70 | time_ratio = sum(processed_diff_list) / len(processed_diff_list) 71 | if time_ratio == 0: 72 | reward = 0 73 | elif time_ratio >= 2: 74 | reward = 1.25 75 | elif time_ratio >= 1 and time_ratio < 2: 76 | reward = 1 77 | elif time_ratio >= 0.5 and time_ratio < 1: 78 | reward = 0.75 79 | elif time_ratio >= 0.25 and time_ratio < 0.5: 80 | reward = 0.5 81 | else: 82 | reward = 0.25 83 | # return time_ratio 84 | return reward 85 | 86 | 87 | def execute_model( 88 | predicted_sql, ground_truth, db_place, idx, iterate_num, meta_time_out, sql_dialect, exec_acc, **kwds 89 | ): 90 | try: 91 | # you can personalize the total timeout number 92 | # larger timeout leads to more stable ves 93 | # while it needs more your patience.... 94 | reward = func_timeout( 95 | meta_time_out * iterate_num, 96 | iterated_execute_sql, 97 | args=(predicted_sql, ground_truth, db_place, iterate_num, sql_dialect, exec_acc), 98 | kwargs=kwds 99 | ) 100 | except KeyboardInterrupt: 101 | sys.exit(0) 102 | except FunctionTimedOut: 103 | result = [(f"timeout",)] 104 | reward = 0 105 | except Exception as e: 106 | result = [(f"error",)] # possibly len(query) > 512 or not executable 107 | reward = 0 108 | result = {"sql_idx": idx, "reward": reward} 109 | return result 110 | 111 | 112 | def run_sqls_parallel( 113 | sqls, 114 | db_places, 115 | num_cpus=1, 116 | iterate_num=100, 117 | meta_time_out=30.0, 118 | sql_dialect="SQLite", 119 | exec_acc_list=None, 120 | **kwds 121 | ): 122 | global exec_result, progress_bar 123 | exec_result.clear() 124 | progress_bar = tqdm(total=len(sqls)) 125 | pool = mp.Pool(processes=num_cpus) 126 | for i, sql_pair in enumerate(sqls): 127 | predicted_sql, ground_truth = sql_pair 128 | exec_acc = exec_acc_list[i] if exec_acc_list else None 129 | pool.apply_async( 130 | execute_model, 131 | args=( 132 | predicted_sql, 133 | ground_truth, 134 | db_places[i], 135 | i, 136 | iterate_num, 137 | meta_time_out, 138 | sql_dialect, 139 | exec_acc 140 | ), 141 | kwds=kwds, 142 | callback=result_callback, 143 | ) 144 | pool.close() 145 | pool.join() 146 | return exec_result 147 | 148 | 149 | def compute_ves(exec_results): 150 | num_queries = len(exec_results) 151 | total_reward = 0 152 | count = 0 153 | 154 | for i, result in enumerate(exec_results): 155 | if result["reward"] != 0: 156 | count += 1 157 | total_reward += math.sqrt(result["reward"]) * 100 158 | ves = total_reward / num_queries 159 | return ves 160 | 161 | 162 | def compute_ves_by_diff(exec_results, diff_json_path): 163 | num_queries = 
len(exec_results) 164 | contents = load_json(diff_json_path) 165 | simple_results, moderate_results, challenging_results = [], [], [] 166 | for i, content in enumerate(contents): 167 | if content["difficulty"] == "simple": 168 | simple_results.append(exec_results[i]) 169 | if content["difficulty"] == "moderate": 170 | moderate_results.append(exec_results[i]) 171 | if content["difficulty"] == "challenging": 172 | challenging_results.append(exec_results[i]) 173 | simple_ves = compute_ves(simple_results) 174 | moderate_ves = compute_ves(moderate_results) 175 | challenging_ves = compute_ves(challenging_results) 176 | all_ves = compute_ves(exec_results) 177 | count_lists = [ 178 | len(simple_results), 179 | len(moderate_results), 180 | len(challenging_results), 181 | num_queries, 182 | ] 183 | return simple_ves, moderate_ves, challenging_ves, all_ves, count_lists 184 | 185 | 186 | def print_reward_category(exec_results, engine, sql_dialect): 187 | res = { 188 | "engine": engine, 189 | "sql_dialect": sql_dialect, 190 | "distribution": exec_results, 191 | } 192 | file_path = "results.json" 193 | try: 194 | with open(file_path, "r") as file: 195 | data = json.load(file) 196 | except (FileNotFoundError, json.JSONDecodeError): 197 | data = [] # Start with an empty list if file doesn't exist or is empty 198 | 199 | # Append the new data 200 | data.append(res) 201 | 202 | # Write the updated data back to the file 203 | with open(file_path, "w") as file: 204 | json.dump(data, file, indent=4) 205 | 206 | 207 | if __name__ == "__main__": 208 | args_parser = argparse.ArgumentParser() 209 | args_parser.add_argument( 210 | "--predicted_sql_path", type=str, required=True, default="" 211 | ) 212 | args_parser.add_argument("--ground_truth_path", type=str, required=True, default="") 213 | args_parser.add_argument("--data_mode", type=str, required=True, default="dev") 214 | args_parser.add_argument("--db_root_path", type=str, required=True, default="") 215 | args_parser.add_argument("--num_cpus", type=int, default=1) 216 | args_parser.add_argument("--meta_time_out", type=float, default=30.0) 217 | args_parser.add_argument("--mode_gt", type=str, default="gt") 218 | args_parser.add_argument("--mode_predict", type=str, default="gpt") 219 | args_parser.add_argument("--diff_json_path", type=str, default="") 220 | args_parser.add_argument("--engine", type=str, default="") 221 | args_parser.add_argument("--sql_dialect", type=str, default="SQLite") 222 | args = args_parser.parse_args() 223 | exec_result = [] 224 | 225 | pred_queries, db_paths = package_sqls( 226 | args.predicted_sql_path, 227 | args.db_root_path, 228 | args.engine, 229 | sql_dialect=args.sql_dialect, 230 | mode=args.mode_predict, 231 | data_mode=args.data_mode, 232 | ) 233 | # generate ground truth sqls: 234 | gt_queries, db_paths_gt = package_sqls( 235 | args.ground_truth_path, 236 | args.db_root_path, 237 | args.engine, 238 | sql_dialect=args.sql_dialect, 239 | mode="gt", 240 | data_mode=args.data_mode, 241 | ) 242 | query_pairs = list(zip(pred_queries, gt_queries)) 243 | run_sqls_parallel( 244 | query_pairs, 245 | db_places=db_paths, 246 | num_cpus=args.num_cpus, 247 | meta_time_out=args.meta_time_out, 248 | sql_dialect=args.sql_dialect, 249 | ) 250 | exec_result = sort_results(exec_result) 251 | # print_reward_category(exec_result, args.engine, args.sql_dialect) 252 | print("start calculate") 253 | simple_ves, moderate_ves, challenging_ves, ves, count_lists = compute_ves_by_diff( 254 | exec_result, args.diff_json_path 255 | ) 256 | score_lists = 
[simple_ves, moderate_ves, challenging_ves, ves] 257 | print(f"VES for {args.engine} on {args.sql_dialect} set") 258 | print("start calculate") 259 | print_data(score_lists, count_lists, metric="VES") 260 | print( 261 | "===========================================================================================" 262 | ) 263 | print(f"Finished VES evaluation for {args.engine} on {args.sql_dialect} set") 264 | print("\n\n") 265 | -------------------------------------------------------------------------------- /src/nl2sql360/evaluator/bird_eval/evaluation_f1.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import multiprocessing as mp 4 | from func_timeout import func_timeout, FunctionTimedOut 5 | from .evaluation_utils import ( 6 | load_json, 7 | execute_sql, 8 | package_sqls, 9 | sort_results, 10 | print_data, 11 | ) 12 | from tqdm import tqdm 13 | 14 | exec_result = [] 15 | progress_bar = None 16 | 17 | def calculate_row_match(predicted_row, ground_truth_row): 18 | """ 19 | Calculate the matching percentage for a single row. 20 | 21 | Args: 22 | predicted_row (tuple): The predicted row values. 23 | ground_truth_row (tuple): The actual row values from ground truth. 24 | 25 | Returns: 26 | float: The match percentage (0 to 1 scale). 27 | """ 28 | total_columns = len(ground_truth_row) 29 | matches = 0 30 | element_in_pred_only = 0 31 | element_in_truth_only = 0 32 | for pred_val in predicted_row: 33 | if pred_val in ground_truth_row: 34 | matches += 1 35 | else: 36 | element_in_pred_only += 1 37 | for truth_val in ground_truth_row: 38 | if truth_val not in predicted_row: 39 | element_in_truth_only += 1 40 | match_percentage = matches / total_columns 41 | pred_only_percentage = element_in_pred_only / total_columns 42 | truth_only_percentage = element_in_truth_only / total_columns 43 | return match_percentage, pred_only_percentage, truth_only_percentage 44 | 45 | 46 | def calculate_f1_score(predicted, ground_truth): 47 | """ 48 | Calculate the F1 score based on sets of predicted results and ground truth results, 49 | where each element (tuple) represents a row from the database with multiple columns. 50 | 51 | Args: 52 | predicted (set of tuples): Predicted results from SQL query. 53 | ground_truth (set of tuples): Actual results expected (ground truth). 54 | 55 | Returns: 56 | float: The calculated F1 score. 
57 | """ 58 | # if both predicted and ground_truth are empty, return 1.0 for f1_score 59 | if not predicted and not ground_truth: 60 | return 1.0 61 | 62 | # Drop duplicates 63 | predicted_set = set(predicted) if predicted else set() 64 | ground_truth_set = set(ground_truth) 65 | 66 | # convert back to list 67 | predicted = list(predicted_set) 68 | ground_truth = list(ground_truth_set) 69 | 70 | # Calculate matching scores for each possible pair 71 | match_scores = [] 72 | pred_only_scores = [] 73 | truth_only_scores = [] 74 | for i, gt_row in enumerate(ground_truth): 75 | # rows only in the ground truth results 76 | if i >= len(predicted): 77 | match_scores.append(0) 78 | truth_only_scores.append(1) 79 | continue 80 | pred_row = predicted[i] 81 | match_score, pred_only_score, truth_only_score = calculate_row_match( 82 | pred_row, gt_row 83 | ) 84 | match_scores.append(match_score) 85 | pred_only_scores.append(pred_only_score) 86 | truth_only_scores.append(truth_only_score) 87 | 88 | # rows only in the predicted results 89 | for i in range(len(predicted) - len(ground_truth)): 90 | match_scores.append(0) 91 | pred_only_scores.append(1) 92 | truth_only_scores.append(0) 93 | 94 | tp = sum(match_scores) 95 | fp = sum(pred_only_scores) 96 | fn = sum(truth_only_scores) 97 | 98 | precision = tp / (tp + fp) if tp + fp > 0 else 0 99 | recall = tp / (tp + fn) if tp + fn > 0 else 0 100 | 101 | f1_score = ( 102 | 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0 103 | ) 104 | return f1_score 105 | 106 | 107 | def result_callback(result): 108 | exec_result.append(result) 109 | progress_bar.update() 110 | 111 | 112 | def execute_model( 113 | predicted_sql, ground_truth, db_place, idx, meta_time_out, sql_dialect, **kwds 114 | ): 115 | try: 116 | res = func_timeout( 117 | meta_time_out, 118 | execute_sql, 119 | args=( 120 | predicted_sql, 121 | ground_truth, 122 | db_place, 123 | sql_dialect, 124 | calculate_f1_score, 125 | ), 126 | kwargs=kwds 127 | ) 128 | except KeyboardInterrupt: 129 | sys.exit(0) 130 | except FunctionTimedOut: 131 | result = [(f"timeout",)] 132 | res = 0 133 | except Exception as e: 134 | result = [(f"error",)] # possibly len(query) > 512 or not executable 135 | res = 0 136 | # print(result) 137 | # result = str(set([ret[0] for ret in result])) 138 | result = {"sql_idx": idx, "res": res} 139 | # print(result) 140 | return result 141 | 142 | 143 | def run_sqls_parallel( 144 | sqls, db_places, num_cpus=1, meta_time_out=30.0, sql_dialect="SQLite", **kwds 145 | ): 146 | global exec_result, progress_bar 147 | exec_result.clear() 148 | progress_bar = tqdm(total=len(sqls)) 149 | pool = mp.Pool(processes=num_cpus) 150 | for i, sql_pair in enumerate(sqls): 151 | 152 | predicted_sql, ground_truth = sql_pair 153 | pool.apply_async( 154 | execute_model, 155 | args=( 156 | predicted_sql, 157 | ground_truth, 158 | db_places[i], 159 | i, 160 | meta_time_out, 161 | sql_dialect, 162 | ), 163 | kwds=kwds, 164 | callback=result_callback, 165 | ) 166 | pool.close() 167 | pool.join() 168 | return exec_result 169 | 170 | 171 | def compute_f1_by_diff(exec_results, diff_json_path): 172 | num_queries = len(exec_results) 173 | results = [res["res"] for res in exec_results] 174 | contents = load_json(diff_json_path) 175 | simple_results, moderate_results, challenging_results = [], [], [] 176 | 177 | for i, content in enumerate(contents): 178 | if content["difficulty"] == "simple": 179 | simple_results.append(exec_results[i]) 180 | 181 | if content["difficulty"] == "moderate": 182 | 
moderate_results.append(exec_results[i]) 183 | 184 | if content["difficulty"] == "challenging": 185 | try: 186 | challenging_results.append(exec_results[i]) 187 | except: 188 | print(i) 189 | 190 | simple_f1 = sum([res["res"] for res in simple_results]) / len(simple_results) * 100 191 | moderate_f1 = ( 192 | sum([res["res"] for res in moderate_results]) / len(moderate_results) * 100 193 | ) 194 | challenging_f1 = ( 195 | sum([res["res"] for res in challenging_results]) 196 | / len(challenging_results) 197 | * 100 198 | ) 199 | all_f1 = sum(results) / num_queries * 100 200 | count_lists = [ 201 | len(simple_results), 202 | len(moderate_results), 203 | len(challenging_results), 204 | num_queries, 205 | ] 206 | return ( 207 | simple_f1, 208 | moderate_f1, 209 | challenging_f1, 210 | all_f1, 211 | count_lists, 212 | ) 213 | 214 | 215 | if __name__ == "__main__": 216 | args_parser = argparse.ArgumentParser() 217 | args_parser.add_argument( 218 | "--predicted_sql_path", type=str, required=True, default="" 219 | ) 220 | args_parser.add_argument("--ground_truth_path", type=str, required=True, default="") 221 | args_parser.add_argument("--data_mode", type=str, required=True, default="dev") 222 | args_parser.add_argument("--db_root_path", type=str, required=True, default="") 223 | args_parser.add_argument("--num_cpus", type=int, default=1) 224 | args_parser.add_argument("--meta_time_out", type=float, default=30.0) 225 | args_parser.add_argument("--mode_gt", type=str, default="gt") 226 | args_parser.add_argument("--mode_predict", type=str, default="gpt") 227 | args_parser.add_argument("--difficulty", type=str, default="simple") 228 | args_parser.add_argument("--diff_json_path", type=str, default="") 229 | args_parser.add_argument("--engine", type=str, default="") 230 | args_parser.add_argument("--sql_dialect", type=str, default="SQLite") 231 | args = args_parser.parse_args() 232 | exec_result = [] 233 | 234 | pred_queries, db_paths = package_sqls( 235 | args.predicted_sql_path, 236 | args.db_root_path, 237 | args.engine, 238 | sql_dialect=args.sql_dialect, 239 | mode=args.mode_predict, 240 | data_mode=args.data_mode, 241 | ) 242 | # generate ground truth sqls: 243 | gt_queries, db_paths_gt = package_sqls( 244 | args.ground_truth_path, 245 | args.db_root_path, 246 | args.engine, 247 | sql_dialect=args.sql_dialect, 248 | mode="gt", 249 | data_mode=args.data_mode, 250 | ) 251 | 252 | query_pairs = list(zip(pred_queries, gt_queries)) 253 | 254 | run_sqls_parallel( 255 | query_pairs, 256 | db_places=db_paths, 257 | num_cpus=args.num_cpus, 258 | meta_time_out=args.meta_time_out, 259 | sql_dialect=args.sql_dialect, 260 | ) 261 | exec_result = sort_results(exec_result) 262 | 263 | print("start calculate") 264 | simple_acc, moderate_acc, challenging_acc, acc, count_lists = compute_f1_by_diff( 265 | exec_result, args.diff_json_path 266 | ) 267 | score_lists = [simple_acc, moderate_acc, challenging_acc, acc] 268 | print(f"Soft F1 for {args.engine} on {args.sql_dialect} set") 269 | print("start calculate") 270 | print_data(score_lists, count_lists) 271 | print( 272 | "===========================================================================================" 273 | ) 274 | print(f"Finished Soft F1 evaluation for {args.engine} on {args.sql_dialect} set") 275 | print("\n\n") 276 | -------------------------------------------------------------------------------- /src/nl2sql360/evaluator/test_suite_sql_eval/exec_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 
| import re 3 | import asyncio 4 | import sqlite3 5 | import threading 6 | from typing import Tuple, Any, List, Set 7 | from itertools import product 8 | from collections import defaultdict 9 | import tqdm 10 | import random 11 | from .parse import get_all_preds_for_execution, remove_distinct 12 | import time 13 | import pickle as pkl 14 | import subprocess 15 | from itertools import chain 16 | 17 | 18 | 19 | threadLock = threading.Lock() 20 | TIMEOUT = 60 21 | EXEC_TMP_DIR = 'tmp/' 22 | 23 | def permute_tuple(element: Tuple, perm: Tuple) -> Tuple: 24 | assert len(element) == len(perm) 25 | return tuple([element[i] for i in perm]) 26 | 27 | 28 | def unorder_row(row: Tuple) -> Tuple: 29 | return tuple(sorted(row, key=lambda x: str(x) + str(type(x)))) 30 | 31 | 32 | # unorder each row in the table 33 | # [result_1 and result_2 has the same bag of unordered row] 34 | # is a necessary condition of 35 | # [result_1 and result_2 are equivalent in denotation] 36 | def quick_rej(result1: List[Tuple], result2: List[Tuple], order_matters: bool) -> bool: 37 | s1 = [unorder_row(row) for row in result1] 38 | s2 = [unorder_row(row) for row in result2] 39 | if order_matters: 40 | return s1 == s2 41 | else: 42 | return set(s1) == set(s2) 43 | 44 | 45 | # return whether two bag of relations are equivalent 46 | def multiset_eq(l1: List, l2: List) -> bool: 47 | if len(l1) != len(l2): 48 | return False 49 | d = defaultdict(int) 50 | for e in l1: 51 | d[e] = d[e] + 1 52 | for e in l2: 53 | d[e] = d[e] - 1 54 | if d[e] < 0: 55 | return False 56 | return True 57 | 58 | 59 | def get_constraint_permutation(tab1_sets_by_columns: List[Set], result2: List[Tuple]): 60 | num_cols = len(result2[0]) 61 | perm_constraints = [{i for i in range(num_cols)} for _ in range(num_cols)] 62 | if num_cols <= 3: 63 | return product(*perm_constraints) 64 | 65 | # we sample 20 rows and constrain the space of permutations 66 | for _ in range(20): 67 | random_tab2_row = random.choice(result2) 68 | 69 | for tab1_col in range(num_cols): 70 | for tab2_col in set(perm_constraints[tab1_col]): 71 | if random_tab2_row[tab2_col] not in tab1_sets_by_columns[tab1_col]: 72 | perm_constraints[tab1_col].remove(tab2_col) 73 | return product(*perm_constraints) 74 | 75 | 76 | # check whether two denotations are correct 77 | def result_eq(result1: List[Tuple], result2: List[Tuple], order_matters: bool) -> bool: 78 | if len(result1) == 0 and len(result2) == 0: 79 | return True 80 | 81 | # if length is not the same, then they are definitely different bag of rows 82 | if len(result1) != len(result2): 83 | return False 84 | 85 | num_cols = len(result1[0]) 86 | 87 | # if the results do not have the same number of columns, they are different 88 | if len(result2[0]) != num_cols: 89 | return False 90 | 91 | # unorder each row and compare whether the denotation is the same 92 | # this can already find most pair of denotations that are different 93 | if not quick_rej(result1, result2, order_matters): 94 | return False 95 | 96 | # the rest of the problem is in fact more complicated than one might think 97 | # we want to find a permutation of column order and a permutation of row order, 98 | # s.t. 
result_1 is the same as result_2 99 | # we return true if we can find such column & row permutations 100 | # and false if we cannot 101 | tab1_sets_by_columns = [{row[i] for row in result1} for i in range(num_cols)] 102 | 103 | # on a high level, we enumerate all possible column permutations that might make result_1 == result_2 104 | # we decrease the size of the column permutation space by the function get_constraint_permutation 105 | # if one of the permutations makes result_1, result_2 equivalent, then they are equivalent 106 | for perm in get_constraint_permutation(tab1_sets_by_columns, result2): 107 | if len(perm) != len(set(perm)): 108 | continue 109 | if num_cols == 1: 110 | result2_perm = result2 111 | else: 112 | result2_perm = [permute_tuple(element, perm) for element in result2] 113 | if order_matters: 114 | if result1 == result2_perm: 115 | return True 116 | else: 117 | # in fact the first condition must hold if the second condition holds 118 | # but the first is way more efficient implementation-wise 119 | # and we use it to quickly reject impossible candidates 120 | if set(result1) == set(result2_perm) and multiset_eq(result1, result2_perm): 121 | return True 122 | return False 123 | 124 | 125 | def replace_cur_year(query: str) -> str: 126 | return re.sub( 127 | "YEAR\s*\(\s*CURDATE\s*\(\s*\)\s*\)\s*", "2020", query, flags=re.IGNORECASE 128 | ) 129 | 130 | 131 | # get the database cursor for a sqlite database path 132 | def get_cursor_from_path(sqlite_path: str): 133 | try: 134 | if not os.path.exists(sqlite_path): 135 | print("Opening a new connection %s" % sqlite_path) 136 | connection = sqlite3.connect(sqlite_path) 137 | except Exception as e: 138 | print(sqlite_path) 139 | raise e 140 | connection.text_factory = lambda b: b.decode(errors="ignore") 141 | cursor = connection.cursor() 142 | return cursor 143 | 144 | 145 | async def exec_on_db_(sqlite_path: str, query: str) -> Tuple[str, Any]: 146 | query = replace_cur_year(query) 147 | cursor = get_cursor_from_path(sqlite_path) 148 | try: 149 | cursor.execute(query) 150 | result = cursor.fetchall() 151 | cursor.close() 152 | cursor.connection.close() 153 | return "result", result 154 | except Exception as e: 155 | cursor.close() 156 | cursor.connection.close() 157 | return "exception", e 158 | 159 | async def exec_on_db( 160 | sqlite_path: str, query: str, process_id: str = "", timeout: int = TIMEOUT 161 | ) -> Tuple[str, Any]: 162 | try: 163 | return await asyncio.wait_for(exec_on_db_(sqlite_path, query), timeout) 164 | except asyncio.TimeoutError: 165 | return ('exception', TimeoutError) 166 | except Exception as e: 167 | return ("exception", e) 168 | 169 | 170 | # postprocess the model predictions to avoid execution errors 171 | # e.g. removing spaces between ">" and "=" 172 | def postprocess(query: str) -> str: 173 | query = query.replace('> =', '>=').replace('< =', '<=').replace('! =', '!=') 174 | return query 175 | 176 | 177 | # approximate whether p_str and g_str are semantically equivalent 178 | # db is the database path 179 | # we are going to evaluate whether they are equivalent in all the databases 180 | # that are in the same directory as db 181 | # 1 if denotationally equivalent 182 | # 0 otherwise 183 | # the meaning of each auxiliary argument can be seen in the parser definition in evaluation.py 184 | def eval_exec_match(db: str, p_str: str, g_str: str, plug_value: bool, keep_distinct: bool, progress_bar_for_each_datapoint: bool) -> int: 185 | # post-process the prediction. 186 | # e.g.
removing spaces between ">" and "=" 187 | p_str, g_str = postprocess(p_str), postprocess(g_str) 188 | if not keep_distinct: 189 | p_str = remove_distinct(p_str) 190 | g_str = remove_distinct(g_str) 191 | 192 | # we decide whether two denotations are equivalent based on "bag semantics" 193 | # https://courses.cs.washington.edu/courses/cse444/10sp/lectures/lecture16.pdf 194 | # if there is order by in query, then we assume order of the rows matter 195 | # order by might also be used to find the max/min instead of sorting, 196 | # but in that case the result mostly only contains one row and hence order_matters does not make a difference 197 | order_matters = 'order by' in g_str.lower() 198 | 199 | # find all databases in the same directory 200 | db_dir = os.path.dirname(db) 201 | db_paths = [os.path.join(db_dir, basename) for basename in os.listdir(db_dir) if '.sqlite' in basename] 202 | 203 | preds = [p_str] 204 | # if plug in value (i.e. we do not consider value prediction correctness) 205 | # enumerate all ways to plug in values in the gold query to the model predictions 206 | # otherwise, we only evaluate the predicted query with its own value prediction 207 | if plug_value: 208 | _, preds = get_all_preds_for_execution(g_str, p_str) 209 | # we did not add this line in our EMNLP work 210 | # this reduces "false negatives" when value is substituted 211 | preds = chain([p_str], preds) 212 | 213 | for pred in preds: 214 | 215 | pred_passes = 1 216 | # compare the gold and predicted denotations on each database in the directory 217 | # wrap with progress bar if required 218 | if progress_bar_for_each_datapoint: 219 | ranger = tqdm.tqdm(db_paths) 220 | else: 221 | ranger = db_paths 222 | 223 | for db_path in ranger: 224 | g_flag, g_denotation = asyncio.run(exec_on_db(db_path, g_str)) 225 | p_flag, p_denotation = asyncio.run(exec_on_db(db_path, pred)) 226 | 227 | # we should expect the gold to be succesfully executed on the database 228 | assert g_flag != 'exception', 'gold query %s has error on database file %s' % (g_str, db_path) 229 | 230 | # wrong if execution fails 231 | if p_flag == 'exception': 232 | pred_passes = 0 233 | 234 | # if denotations are not equivalent, the prediction must be wrong 235 | elif not result_eq(g_denotation, p_denotation, order_matters=order_matters): 236 | pred_passes = 0 237 | if pred_passes == 0: 238 | break 239 | 240 | # the model prediction has the same denotation as the gold for all databases 241 | if pred_passes == 1: 242 | return 1 243 | 244 | # none of the predictions passed 245 | return 0 -------------------------------------------------------------------------------- /src/nl2sql360/evaluator/test_suite_sql_eval/process_sql.py: -------------------------------------------------------------------------------- 1 | ################################ 2 | # Assumptions: 3 | # 1. sql is correct 4 | # 2. only table name has alias 5 | # 3. only one intersect/union/except 6 | # 7 | # val: number(float)/string(str)/sql(dict) 8 | # col_unit: (agg_id, col_id, isDistinct(bool)) 9 | # val_unit: (unit_op, col_unit1, col_unit2) 10 | # table_unit: (table_type, col_unit/sql) 11 | # cond_unit: (not_op, op_id, val_unit, val1, val2) 12 | # condition: [cond_unit1, 'and'/'or', cond_unit2, ...] 13 | # sql { 14 | # 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...]) 15 | # 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition} 16 | # 'where': condition 17 | # 'groupBy': [col_unit1, col_unit2, ...] 
18 | # 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...]) 19 | # 'having': condition 20 | # 'limit': None/limit value 21 | # 'intersect': None/sql 22 | # 'except': None/sql 23 | # 'union': None/sql 24 | # } 25 | ################################ 26 | 27 | import json 28 | import sqlite3 29 | from nltk import word_tokenize 30 | 31 | CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except') 32 | JOIN_KEYWORDS = ('join', 'on', 'as') 33 | 34 | WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists') 35 | UNIT_OPS = ('none', '-', '+', "*", '/') 36 | AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg') 37 | TABLE_TYPE = { 38 | 'sql': "sql", 39 | 'table_unit': "table_unit", 40 | } 41 | 42 | COND_OPS = ('and', 'or') 43 | SQL_OPS = ('intersect', 'union', 'except') 44 | ORDER_OPS = ('desc', 'asc') 45 | 46 | 47 | 48 | class Schema: 49 | """ 50 | Simple schema which maps table&column to a unique identifier 51 | """ 52 | def __init__(self, schema): 53 | self._schema = schema 54 | self._idMap = self._map(self._schema) 55 | 56 | @property 57 | def schema(self): 58 | return self._schema 59 | 60 | @property 61 | def idMap(self): 62 | return self._idMap 63 | 64 | def _map(self, schema): 65 | idMap = {'*': "__all__"} 66 | id = 1 67 | for key, vals in schema.items(): 68 | for val in vals: 69 | idMap[key.lower() + "." + val.lower()] = "__" + key.lower() + "." + val.lower() + "__" 70 | id += 1 71 | 72 | for key in schema: 73 | idMap[key.lower()] = "__" + key.lower() + "__" 74 | id += 1 75 | 76 | return idMap 77 | 78 | 79 | def get_schema(db): 80 | """ 81 | Get database's schema, which is a dict with table name as key 82 | and list of column names as value 83 | :param db: database path 84 | :return: schema dict 85 | """ 86 | 87 | schema = {} 88 | conn = sqlite3.connect(db) 89 | cursor = conn.cursor() 90 | 91 | # fetch table names 92 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") 93 | tables = [str(table[0].lower()) for table in cursor.fetchall()] 94 | 95 | # fetch table info 96 | for table in tables: 97 | cursor.execute("PRAGMA table_info(`{}`)".format(table)) 98 | schema[table] = [str(col[1].lower()) for col in cursor.fetchall()] 99 | 100 | return schema 101 | 102 | 103 | def get_schema_from_json(fpath): 104 | with open(fpath) as f: 105 | data = json.load(f) 106 | 107 | schema = {} 108 | for entry in data: 109 | table = str(entry['table'].lower()) 110 | cols = [str(col['column_name'].lower()) for col in entry['col_data']] 111 | schema[table] = cols 112 | 113 | return schema 114 | 115 | 116 | def tokenize(string): 117 | string = str(string) 118 | string = string.replace("\'", "\"") # ensures all string values wrapped by "" problem?? 
119 | quote_idxs = [idx for idx, char in enumerate(string) if char == '"'] 120 | assert len(quote_idxs) % 2 == 0, "Unexpected quote" 121 | 122 | # keep string value as token 123 | vals = {} 124 | for i in range(len(quote_idxs)-1, -1, -2): 125 | qidx1 = quote_idxs[i-1] 126 | qidx2 = quote_idxs[i] 127 | val = string[qidx1: qidx2+1] 128 | key = "__val_{}_{}__".format(qidx1, qidx2) 129 | string = string[:qidx1] + key + string[qidx2+1:] 130 | vals[key] = val 131 | 132 | toks = [word.lower() for word in word_tokenize(string)] 133 | # replace with string value token 134 | for i in range(len(toks)): 135 | if toks[i] in vals: 136 | toks[i] = vals[toks[i]] 137 | 138 | # find if there exists !=, >=, <= 139 | eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="] 140 | eq_idxs.reverse() 141 | prefix = ('!', '>', '<') 142 | for eq_idx in eq_idxs: 143 | pre_tok = toks[eq_idx-1] 144 | if pre_tok in prefix: 145 | toks = toks[:eq_idx-1] + [pre_tok + "="] + toks[eq_idx+1: ] 146 | 147 | return toks 148 | 149 | 150 | def scan_alias(toks): 151 | """Scan the index of 'as' and build the map for all alias""" 152 | as_idxs = [idx for idx, tok in enumerate(toks) if tok == 'as'] 153 | alias = {} 154 | for idx in as_idxs: 155 | alias[toks[idx+1]] = toks[idx-1] 156 | return alias 157 | 158 | 159 | def get_tables_with_alias(schema, toks): 160 | tables = scan_alias(toks) 161 | for key in schema: 162 | assert key not in tables, "Alias {} has the same name in table".format(key) 163 | tables[key] = key 164 | return tables 165 | 166 | 167 | def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None): 168 | """ 169 | :returns next idx, column id 170 | """ 171 | tok = toks[start_idx] 172 | if tok == "*": 173 | return start_idx + 1, schema.idMap[tok] 174 | 175 | if '.' in tok: # if token is a composite 176 | alias, col = tok.split('.') 177 | key = tables_with_alias[alias] + "." + col 178 | return start_idx+1, schema.idMap[key] 179 | 180 | assert default_tables is not None and len(default_tables) > 0, "Default tables should not be None or empty" 181 | 182 | for alias in default_tables: 183 | table = tables_with_alias[alias] 184 | if tok in schema.schema[table]: 185 | key = table + "." 
+ tok 186 | return start_idx+1, schema.idMap[key] 187 | 188 | assert False, "Error col: {}".format(tok) 189 | 190 | 191 | def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None): 192 | """ 193 | :returns next idx, (agg_op id, col_id) 194 | """ 195 | idx = start_idx 196 | len_ = len(toks) 197 | isBlock = False 198 | isDistinct = False 199 | if toks[idx] == '(': 200 | isBlock = True 201 | idx += 1 202 | 203 | if toks[idx] in AGG_OPS: 204 | agg_id = AGG_OPS.index(toks[idx]) 205 | idx += 1 206 | assert idx < len_ and toks[idx] == '(' 207 | idx += 1 208 | if toks[idx] == "distinct": 209 | idx += 1 210 | isDistinct = True 211 | idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables) 212 | assert idx < len_ and toks[idx] == ')' 213 | idx += 1 214 | return idx, (agg_id, col_id, isDistinct) 215 | 216 | if toks[idx] == "distinct": 217 | idx += 1 218 | isDistinct = True 219 | agg_id = AGG_OPS.index("none") 220 | idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables) 221 | 222 | if isBlock: 223 | assert toks[idx] == ')' 224 | idx += 1 # skip ')' 225 | 226 | return idx, (agg_id, col_id, isDistinct) 227 | 228 | 229 | def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None): 230 | idx = start_idx 231 | len_ = len(toks) 232 | isBlock = False 233 | if toks[idx] == '(': 234 | isBlock = True 235 | idx += 1 236 | 237 | col_unit1 = None 238 | col_unit2 = None 239 | unit_op = UNIT_OPS.index('none') 240 | 241 | idx, col_unit1 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) 242 | if idx < len_ and toks[idx] in UNIT_OPS: 243 | unit_op = UNIT_OPS.index(toks[idx]) 244 | idx += 1 245 | idx, col_unit2 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) 246 | 247 | if isBlock: 248 | assert toks[idx] == ')' 249 | idx += 1 # skip ')' 250 | 251 | return idx, (unit_op, col_unit1, col_unit2) 252 | 253 | 254 | def parse_table_unit(toks, start_idx, tables_with_alias, schema): 255 | """ 256 | :returns next idx, table id, table name 257 | """ 258 | idx = start_idx 259 | len_ = len(toks) 260 | key = tables_with_alias[toks[idx]] 261 | 262 | if idx + 1 < len_ and toks[idx+1] == "as": 263 | idx += 3 264 | else: 265 | idx += 1 266 | 267 | return idx, schema.idMap[key], key 268 | 269 | 270 | def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None): 271 | idx = start_idx 272 | len_ = len(toks) 273 | 274 | isBlock = False 275 | if toks[idx] == '(': 276 | isBlock = True 277 | idx += 1 278 | 279 | if toks[idx] == 'select': 280 | idx, val = parse_sql(toks, idx, tables_with_alias, schema) 281 | elif "\"" in toks[idx]: # token is a string value 282 | val = toks[idx] 283 | idx += 1 284 | else: 285 | try: 286 | val = float(toks[idx]) 287 | idx += 1 288 | except: 289 | end_idx = idx 290 | while end_idx < len_ and toks[end_idx] != ',' and toks[end_idx] != ')'\ 291 | and toks[end_idx] != 'and' and toks[end_idx] not in CLAUSE_KEYWORDS and toks[end_idx] not in JOIN_KEYWORDS: 292 | end_idx += 1 293 | 294 | idx, val = parse_col_unit(toks[start_idx: end_idx], 0, tables_with_alias, schema, default_tables) 295 | idx = end_idx 296 | 297 | if isBlock: 298 | assert toks[idx] == ')' 299 | idx += 1 300 | 301 | return idx, val 302 | 303 | 304 | def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None): 305 | idx = start_idx 306 | len_ = len(toks) 307 | conds = [] 308 | 309 | while idx < len_: 310 | idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, 
default_tables) 311 | not_op = False 312 | if toks[idx] == 'not': 313 | not_op = True 314 | idx += 1 315 | 316 | assert idx < len_ and toks[idx] in WHERE_OPS, "Error condition: idx: {}, tok: {}".format(idx, toks[idx]) 317 | op_id = WHERE_OPS.index(toks[idx]) 318 | idx += 1 319 | val1 = val2 = None 320 | if op_id == WHERE_OPS.index('between'): # between..and... special case: dual values 321 | idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables) 322 | assert toks[idx] == 'and' 323 | idx += 1 324 | idx, val2 = parse_value(toks, idx, tables_with_alias, schema, default_tables) 325 | else: # normal case: single value 326 | idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables) 327 | val2 = None 328 | 329 | conds.append((not_op, op_id, val_unit, val1, val2)) 330 | 331 | if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";") or toks[idx] in JOIN_KEYWORDS): 332 | break 333 | 334 | if idx < len_ and toks[idx] in COND_OPS: 335 | conds.append(toks[idx]) 336 | idx += 1 # skip and/or 337 | 338 | return idx, conds 339 | 340 | 341 | def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None): 342 | idx = start_idx 343 | len_ = len(toks) 344 | 345 | assert toks[idx] == 'select', "'select' not found" 346 | idx += 1 347 | isDistinct = False 348 | if idx < len_ and toks[idx] == 'distinct': 349 | idx += 1 350 | isDistinct = True 351 | val_units = [] 352 | 353 | while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS: 354 | agg_id = AGG_OPS.index("none") 355 | if toks[idx] in AGG_OPS: 356 | agg_id = AGG_OPS.index(toks[idx]) 357 | idx += 1 358 | idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables) 359 | val_units.append((agg_id, val_unit)) 360 | if idx < len_ and toks[idx] == ',': 361 | idx += 1 # skip ',' 362 | 363 | return idx, (isDistinct, val_units) 364 | 365 | 366 | def parse_from(toks, start_idx, tables_with_alias, schema): 367 | """ 368 | Assume in the from clause, all table units are combined with join 369 | """ 370 | assert 'from' in toks[start_idx:], "'from' not found" 371 | 372 | len_ = len(toks) 373 | idx = toks.index('from', start_idx) + 1 374 | default_tables = [] 375 | table_units = [] 376 | conds = [] 377 | 378 | while idx < len_: 379 | isBlock = False 380 | if toks[idx] == '(': 381 | isBlock = True 382 | idx += 1 383 | 384 | if toks[idx] == 'select': 385 | idx, sql = parse_sql(toks, idx, tables_with_alias, schema) 386 | table_units.append((TABLE_TYPE['sql'], sql)) 387 | else: 388 | if idx < len_ and toks[idx] == 'join': 389 | idx += 1 # skip join 390 | idx, table_unit, table_name = parse_table_unit(toks, idx, tables_with_alias, schema) 391 | table_units.append((TABLE_TYPE['table_unit'],table_unit)) 392 | default_tables.append(table_name) 393 | if idx < len_ and toks[idx] == "on": 394 | idx += 1 # skip on 395 | idx, this_conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables) 396 | if len(conds) > 0: 397 | conds.append('and') 398 | conds.extend(this_conds) 399 | 400 | if isBlock: 401 | assert toks[idx] == ')' 402 | idx += 1 403 | if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): 404 | break 405 | 406 | return idx, table_units, conds, default_tables 407 | 408 | 409 | def parse_where(toks, start_idx, tables_with_alias, schema, default_tables): 410 | idx = start_idx 411 | len_ = len(toks) 412 | 413 | if idx >= len_ or toks[idx] != 'where': 414 | return idx, [] 415 | 416 | idx += 1 417 | idx, conds = parse_condition(toks, idx, 
tables_with_alias, schema, default_tables) 418 | return idx, conds 419 | 420 | 421 | def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables): 422 | idx = start_idx 423 | len_ = len(toks) 424 | col_units = [] 425 | 426 | if idx >= len_ or toks[idx] != 'group': 427 | return idx, col_units 428 | 429 | idx += 1 430 | assert toks[idx] == 'by' 431 | idx += 1 432 | 433 | while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): 434 | idx, col_unit = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) 435 | col_units.append(col_unit) 436 | if idx < len_ and toks[idx] == ',': 437 | idx += 1 # skip ',' 438 | else: 439 | break 440 | 441 | return idx, col_units 442 | 443 | 444 | def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables): 445 | idx = start_idx 446 | len_ = len(toks) 447 | val_units = [] 448 | order_type = 'asc' # default type is 'asc' 449 | 450 | if idx >= len_ or toks[idx] != 'order': 451 | return idx, val_units 452 | 453 | idx += 1 454 | assert toks[idx] == 'by' 455 | idx += 1 456 | 457 | while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): 458 | idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables) 459 | val_units.append(val_unit) 460 | if idx < len_ and toks[idx] in ORDER_OPS: 461 | order_type = toks[idx] 462 | idx += 1 463 | if idx < len_ and toks[idx] == ',': 464 | idx += 1 # skip ',' 465 | else: 466 | break 467 | 468 | return idx, (order_type, val_units) 469 | 470 | 471 | def parse_having(toks, start_idx, tables_with_alias, schema, default_tables): 472 | idx = start_idx 473 | len_ = len(toks) 474 | 475 | if idx >= len_ or toks[idx] != 'having': 476 | return idx, [] 477 | 478 | idx += 1 479 | idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables) 480 | return idx, conds 481 | 482 | 483 | def parse_limit(toks, start_idx): 484 | idx = start_idx 485 | len_ = len(toks) 486 | 487 | if idx < len_ and toks[idx] == 'limit': 488 | idx += 2 489 | # make limit value can work, cannot assume put 1 as a fake limit number 490 | if type(toks[idx-1]) != int: 491 | return idx, 1 492 | 493 | return idx, int(toks[idx-1]) 494 | 495 | return idx, None 496 | 497 | 498 | def parse_sql(toks, start_idx, tables_with_alias, schema): 499 | isBlock = False # indicate whether this is a block of sql/sub-sql 500 | len_ = len(toks) 501 | idx = start_idx 502 | 503 | sql = {} 504 | if toks[idx] == '(': 505 | isBlock = True 506 | idx += 1 507 | 508 | # parse from clause in order to get default tables 509 | from_end_idx, table_units, conds, default_tables = parse_from(toks, start_idx, tables_with_alias, schema) 510 | sql['from'] = {'table_units': table_units, 'conds': conds} 511 | # select clause 512 | _, select_col_units = parse_select(toks, idx, tables_with_alias, schema, default_tables) 513 | idx = from_end_idx 514 | sql['select'] = select_col_units 515 | # where clause 516 | idx, where_conds = parse_where(toks, idx, tables_with_alias, schema, default_tables) 517 | sql['where'] = where_conds 518 | # group by clause 519 | idx, group_col_units = parse_group_by(toks, idx, tables_with_alias, schema, default_tables) 520 | sql['groupBy'] = group_col_units 521 | # having clause 522 | idx, having_conds = parse_having(toks, idx, tables_with_alias, schema, default_tables) 523 | sql['having'] = having_conds 524 | # order by clause 525 | idx, order_col_units = parse_order_by(toks, idx, tables_with_alias, schema, default_tables) 526 | sql['orderBy'] = 
order_col_units 527 | # limit clause 528 | idx, limit_val = parse_limit(toks, idx) 529 | sql['limit'] = limit_val 530 | 531 | idx = skip_semicolon(toks, idx) 532 | if isBlock: 533 | assert toks[idx] == ')' 534 | idx += 1 # skip ')' 535 | idx = skip_semicolon(toks, idx) 536 | 537 | # intersect/union/except clause 538 | for op in SQL_OPS: # initialize IUE 539 | sql[op] = None 540 | if idx < len_ and toks[idx] in SQL_OPS: 541 | sql_op = toks[idx] 542 | idx += 1 543 | idx, IUE_sql = parse_sql(toks, idx, tables_with_alias, schema) 544 | sql[sql_op] = IUE_sql 545 | return idx, sql 546 | 547 | 548 | def load_data(fpath): 549 | with open(fpath) as f: 550 | data = json.load(f) 551 | return data 552 | 553 | 554 | def get_sql(schema, query): 555 | toks = tokenize(query) 556 | tables_with_alias = get_tables_with_alias(schema.schema, toks) 557 | _, sql = parse_sql(toks, 0, tables_with_alias, schema) 558 | 559 | return sql 560 | 561 | 562 | def skip_semicolon(toks, start_idx): 563 | idx = start_idx 564 | while idx < len(toks) and toks[idx] == ";": 565 | idx += 1 566 | return idx -------------------------------------------------------------------------------- /src/nl2sql360/arguments/hf_argparser.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import dataclasses 16 | import json 17 | import sys 18 | import types 19 | from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, ArgumentTypeError 20 | from copy import copy 21 | from enum import Enum 22 | from inspect import isclass 23 | from pathlib import Path 24 | from typing import Any, Callable, Dict, Iterable, List, Literal, NewType, Optional, Tuple, Union, get_type_hints 25 | 26 | import yaml 27 | 28 | 29 | DataClass = NewType("DataClass", Any) 30 | DataClassType = NewType("DataClassType", Any) 31 | 32 | 33 | # From https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 34 | def string_to_bool(v): 35 | if isinstance(v, bool): 36 | return v 37 | if v.lower() in ("yes", "true", "t", "y", "1"): 38 | return True 39 | elif v.lower() in ("no", "false", "f", "n", "0"): 40 | return False 41 | else: 42 | raise ArgumentTypeError( 43 | f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)." 44 | ) 45 | 46 | 47 | def make_choice_type_function(choices: list) -> Callable[[str], Any]: 48 | """ 49 | Creates a mapping function from each choices string representation to the actual value. Used to support multiple 50 | value types for a single argument. 51 | 52 | Args: 53 | choices (list): List of choices. 54 | 55 | Returns: 56 | Callable[[str], Any]: Mapping function from string representation to actual value for each choice. 
57 | """ 58 | str_to_choice = {str(choice): choice for choice in choices} 59 | return lambda arg: str_to_choice.get(arg, arg) 60 | 61 | 62 | def HfArg( 63 | *, 64 | aliases: Union[str, List[str]] = None, 65 | help: str = None, 66 | default: Any = dataclasses.MISSING, 67 | default_factory: Callable[[], Any] = dataclasses.MISSING, 68 | metadata: dict = None, 69 | **kwargs, 70 | ) -> dataclasses.Field: 71 | """Argument helper enabling a concise syntax to create dataclass fields for parsing with `HfArgumentParser`. 72 | 73 | Example comparing the use of `HfArg` and `dataclasses.field`: 74 | ``` 75 | @dataclass 76 | class Args: 77 | regular_arg: str = dataclasses.field(default="Huggingface", metadata={"aliases": ["--example", "-e"], "help": "This syntax could be better!"}) 78 | hf_arg: str = HfArg(default="Huggingface", aliases=["--example", "-e"], help="What a nice syntax!") 79 | ``` 80 | 81 | Args: 82 | aliases (Union[str, List[str]], optional): 83 | Single string or list of strings of aliases to pass on to argparse, e.g. `aliases=["--example", "-e"]`. 84 | Defaults to None. 85 | help (str, optional): Help string to pass on to argparse that can be displayed with --help. Defaults to None. 86 | default (Any, optional): 87 | Default value for the argument. If not default or default_factory is specified, the argument is required. 88 | Defaults to dataclasses.MISSING. 89 | default_factory (Callable[[], Any], optional): 90 | The default_factory is a 0-argument function called to initialize a field's value. It is useful to provide 91 | default values for mutable types, e.g. lists: `default_factory=list`. Mutually exclusive with `default=`. 92 | Defaults to dataclasses.MISSING. 93 | metadata (dict, optional): Further metadata to pass on to `dataclasses.field`. Defaults to None. 94 | 95 | Returns: 96 | Field: A `dataclasses.Field` with the desired properties. 97 | """ 98 | if metadata is None: 99 | # Important, don't use as default param in function signature because dict is mutable and shared across function calls 100 | metadata = {} 101 | if aliases is not None: 102 | metadata["aliases"] = aliases 103 | if help is not None: 104 | metadata["help"] = help 105 | 106 | return dataclasses.field(metadata=metadata, default=default, default_factory=default_factory, **kwargs) 107 | 108 | 109 | class HfArgumentParser(ArgumentParser): 110 | """ 111 | This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments. 112 | 113 | The class is designed to play well with the native argparse. In particular, you can add more (non-dataclass backed) 114 | arguments to the parser after initialization and you'll get the output back after parsing as an additional 115 | namespace. Optional: To create sub argument groups use the `_argument_group_name` attribute in the dataclass. 116 | """ 117 | 118 | dataclass_types: Iterable[DataClassType] 119 | 120 | def __init__(self, dataclass_types: Union[DataClassType, Iterable[DataClassType]], **kwargs): 121 | """ 122 | Args: 123 | dataclass_types: 124 | Dataclass type, or list of dataclass types for which we will "fill" instances with the parsed args. 125 | kwargs (`Dict[str, Any]`, *optional*): 126 | Passed to `argparse.ArgumentParser()` in the regular way. 
127 | """ 128 | # To make the default appear when using --help 129 | if "formatter_class" not in kwargs: 130 | kwargs["formatter_class"] = ArgumentDefaultsHelpFormatter 131 | super().__init__(**kwargs) 132 | if dataclasses.is_dataclass(dataclass_types): 133 | dataclass_types = [dataclass_types] 134 | self.dataclass_types = list(dataclass_types) 135 | for dtype in self.dataclass_types: 136 | self._add_dataclass_arguments(dtype) 137 | 138 | @staticmethod 139 | def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field): 140 | field_name = f"--{field.name}" 141 | kwargs = field.metadata.copy() 142 | # field.metadata is not used at all by Data Classes, 143 | # it is provided as a third-party extension mechanism. 144 | if isinstance(field.type, str): 145 | raise RuntimeError( 146 | "Unresolved type detected, which should have been done with the help of " 147 | "`typing.get_type_hints` method by default" 148 | ) 149 | 150 | aliases = kwargs.pop("aliases", []) 151 | if isinstance(aliases, str): 152 | aliases = [aliases] 153 | 154 | origin_type = getattr(field.type, "__origin__", field.type) 155 | if origin_type is Union or (hasattr(types, "UnionType") and isinstance(origin_type, types.UnionType)): 156 | if str not in field.type.__args__ and ( 157 | len(field.type.__args__) != 2 or type(None) not in field.type.__args__ 158 | ): 159 | raise ValueError( 160 | "Only `Union[X, NoneType]` (i.e., `Optional[X]`) is allowed for `Union` because" 161 | " the argument parser only supports one type per argument." 162 | f" Problem encountered in field '{field.name}'." 163 | ) 164 | if type(None) not in field.type.__args__: 165 | # filter `str` in Union 166 | field.type = field.type.__args__[0] if field.type.__args__[1] == str else field.type.__args__[1] 167 | origin_type = getattr(field.type, "__origin__", field.type) 168 | elif bool not in field.type.__args__: 169 | # filter `NoneType` in Union (except for `Union[bool, NoneType]`) 170 | field.type = ( 171 | field.type.__args__[0] if isinstance(None, field.type.__args__[1]) else field.type.__args__[1] 172 | ) 173 | origin_type = getattr(field.type, "__origin__", field.type) 174 | 175 | # A variable to store kwargs for a boolean field, if needed 176 | # so that we can init a `no_*` complement argument (see below) 177 | bool_kwargs = {} 178 | if origin_type is Literal or (isinstance(field.type, type) and issubclass(field.type, Enum)): 179 | if origin_type is Literal: 180 | kwargs["choices"] = field.type.__args__ 181 | else: 182 | kwargs["choices"] = [x.value for x in field.type] 183 | 184 | kwargs["type"] = make_choice_type_function(kwargs["choices"]) 185 | 186 | if field.default is not dataclasses.MISSING: 187 | kwargs["default"] = field.default 188 | else: 189 | kwargs["required"] = True 190 | elif field.type is bool or field.type == Optional[bool]: 191 | # Copy the currect kwargs to use to instantiate a `no_*` complement argument below. 192 | # We do not initialize it here because the `no_*` alternative must be instantiated after the real argument 193 | bool_kwargs = copy(kwargs) 194 | 195 | # Hack because type=bool in argparse does not behave as we want. 196 | kwargs["type"] = string_to_bool 197 | if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING): 198 | # Default value is False if we have no default when of type bool. 
199 | default = False if field.default is dataclasses.MISSING else field.default 200 | # This is the value that will get picked if we don't include --field_name in any way 201 | kwargs["default"] = default 202 | # This tells argparse we accept 0 or 1 value after --field_name 203 | kwargs["nargs"] = "?" 204 | # This is the value that will get picked if we do --field_name (without value) 205 | kwargs["const"] = True 206 | elif isclass(origin_type) and issubclass(origin_type, list): 207 | kwargs["type"] = field.type.__args__[0] 208 | kwargs["nargs"] = "+" 209 | if field.default_factory is not dataclasses.MISSING: 210 | kwargs["default"] = field.default_factory() 211 | elif field.default is dataclasses.MISSING: 212 | kwargs["required"] = True 213 | else: 214 | kwargs["type"] = field.type 215 | if field.default is not dataclasses.MISSING: 216 | kwargs["default"] = field.default 217 | elif field.default_factory is not dataclasses.MISSING: 218 | kwargs["default"] = field.default_factory() 219 | else: 220 | kwargs["required"] = True 221 | parser.add_argument(field_name, *aliases, **kwargs) 222 | 223 | # Add a complement `no_*` argument for a boolean field AFTER the initial field has already been added. 224 | # Order is important for arguments with the same destination! 225 | # We use a copy of earlier kwargs because the original kwargs have changed a lot before reaching down 226 | # here and we do not need those changes/additional keys. 227 | if field.default is True and (field.type is bool or field.type == Optional[bool]): 228 | bool_kwargs["default"] = False 229 | parser.add_argument(f"--no_{field.name}", action="store_false", dest=field.name, **bool_kwargs) 230 | 231 | def _add_dataclass_arguments(self, dtype: DataClassType): 232 | if hasattr(dtype, "_argument_group_name"): 233 | parser = self.add_argument_group(dtype._argument_group_name) 234 | else: 235 | parser = self 236 | 237 | try: 238 | type_hints: Dict[str, type] = get_type_hints(dtype) 239 | except NameError: 240 | raise RuntimeError( 241 | f"Type resolution failed for {dtype}. Try declaring the class in global scope or " 242 | "removing line of `from __future__ import annotations` which opts in Postponed " 243 | "Evaluation of Annotations (PEP 563)" 244 | ) 245 | except TypeError as ex: 246 | # Remove this block when we drop Python 3.9 support 247 | if sys.version_info[:2] < (3, 10) and "unsupported operand type(s) for |" in str(ex): 248 | python_version = ".".join(map(str, sys.version_info[:3])) 249 | raise RuntimeError( 250 | f"Type resolution failed for {dtype} on Python {python_version}. Try removing " 251 | "line of `from __future__ import annotations` which opts in union types as " 252 | "`X | Y` (PEP 604) via Postponed Evaluation of Annotations (PEP 563). To " 253 | "support Python versions that lower than 3.10, you need to use " 254 | "`typing.Union[X, Y]` instead of `X | Y` and `typing.Optional[X]` instead of " 255 | "`X | None`." 256 | ) from ex 257 | raise 258 | 259 | for field in dataclasses.fields(dtype): 260 | if not field.init: 261 | continue 262 | field.type = type_hints[field.name] 263 | self._parse_dataclass_field(parser, field) 264 | 265 | def parse_args_into_dataclasses( 266 | self, 267 | args=None, 268 | return_remaining_strings=False, 269 | look_for_args_file=True, 270 | args_filename=None, 271 | args_file_flag=None, 272 | ) -> Tuple[DataClass, ...]: 273 | """ 274 | Parse command-line args into instances of the specified dataclass types. 275 | 276 | This relies on argparse's `ArgumentParser.parse_known_args`. 
See the doc at: 277 | docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args 278 | 279 | Args: 280 | args: 281 | List of strings to parse. The default is taken from sys.argv. (same as argparse.ArgumentParser) 282 | return_remaining_strings: 283 | If true, also return a list of remaining argument strings. 284 | look_for_args_file: 285 | If true, will look for a ".args" file with the same base name as the entry point script for this 286 | process, and will append its potential content to the command line args. 287 | args_filename: 288 | If not None, will use this file instead of the ".args" file specified in the previous argument. 289 | args_file_flag: 290 | If not None, will look for a file in the command-line args specified with this flag. The flag can be 291 | specified multiple times and precedence is determined by the order (last one wins). 292 | 293 | Returns: 294 | Tuple consisting of: 295 | 296 | - the dataclass instances in the same order as they were passed to the initializer. 297 | - if applicable, an additional namespace for more (non-dataclass backed) arguments added to the parser 298 | after initialization. 299 | - The potential list of remaining argument strings. (same as argparse.ArgumentParser.parse_known_args) 300 | """ 301 | 302 | if args_file_flag or args_filename or (look_for_args_file and len(sys.argv)): 303 | args_files = [] 304 | 305 | if args_filename: 306 | args_files.append(Path(args_filename)) 307 | elif look_for_args_file and len(sys.argv): 308 | args_files.append(Path(sys.argv[0]).with_suffix(".args")) 309 | 310 | # args files specified via command line flag should overwrite default args files so we add them last 311 | if args_file_flag: 312 | # Create special parser just to extract the args_file_flag values 313 | args_file_parser = ArgumentParser() 314 | args_file_parser.add_argument(args_file_flag, type=str, action="append") 315 | 316 | # Use only remaining args for further parsing (remove the args_file_flag) 317 | cfg, args = args_file_parser.parse_known_args(args=args) 318 | cmd_args_file_paths = vars(cfg).get(args_file_flag.lstrip("-"), None) 319 | 320 | if cmd_args_file_paths: 321 | args_files.extend([Path(p) for p in cmd_args_file_paths]) 322 | 323 | file_args = [] 324 | for args_file in args_files: 325 | if args_file.exists(): 326 | file_args += args_file.read_text().split() 327 | 328 | # in case of duplicate arguments the last one has precedence 329 | # args specified via the command line should overwrite args from files, so we add them last 330 | args = file_args + args if args is not None else file_args + sys.argv[1:] 331 | namespace, remaining_args = self.parse_known_args(args=args) 332 | outputs = [] 333 | for dtype in self.dataclass_types: 334 | keys = {f.name for f in dataclasses.fields(dtype) if f.init} 335 | inputs = {k: v for k, v in vars(namespace).items() if k in keys} 336 | for k in keys: 337 | delattr(namespace, k) 338 | obj = dtype(**inputs) 339 | outputs.append(obj) 340 | if len(namespace.__dict__) > 0: 341 | # additional namespace.
342 |             outputs.append(namespace)
343 |         if return_remaining_strings:
344 |             return (*outputs, remaining_args)
345 |         else:
346 |             if remaining_args:
347 |                 raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}")
348 | 
349 |             return (*outputs,)
350 | 
351 |     def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
352 |         """
353 |         Alternative helper method that does not use `argparse` at all, instead using a dict to populate the dataclass
354 |         types.
355 | 
356 |         Args:
357 |             args (`dict`):
358 |                 dict containing config values
359 |             allow_extra_keys (`bool`, *optional*, defaults to `False`):
360 |                 Defaults to False. If False, will raise an exception if the dict contains keys that are not parsed.
361 | 
362 |         Returns:
363 |             Tuple consisting of:
364 | 
365 |                 - the dataclass instances in the same order as they were passed to the initializer.
366 |         """
367 |         unused_keys = set(args.keys())
368 |         outputs = []
369 |         for dtype in self.dataclass_types:
370 |             keys = {f.name for f in dataclasses.fields(dtype) if f.init}
371 |             inputs = {k: v for k, v in args.items() if k in keys}
372 |             unused_keys.difference_update(inputs.keys())
373 |             obj = dtype(**inputs)
374 |             outputs.append(obj)
375 |         if not allow_extra_keys and unused_keys:
376 |             raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}")
377 |         return tuple(outputs)
378 | 
379 |     def parse_json_file(self, json_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
380 |         """
381 |         Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the
382 |         dataclass types.
383 | 
384 |         Args:
385 |             json_file (`str` or `os.PathLike`):
386 |                 File name of the json file to parse
387 |             allow_extra_keys (`bool`, *optional*, defaults to `False`):
388 |                 Defaults to False. If False, will raise an exception if the json file contains keys that are not
389 |                 parsed.
390 | 
391 |         Returns:
392 |             Tuple consisting of:
393 | 
394 |                 - the dataclass instances in the same order as they were passed to the initializer.
395 |         """
396 |         with open(Path(json_file), encoding="utf-8") as open_json_file:
397 |             data = json.loads(open_json_file.read())
398 |         outputs = self.parse_dict(data, allow_extra_keys=allow_extra_keys)
399 |         return tuple(outputs)
400 | 
401 |     def parse_yaml_file(self, yaml_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
402 |         """
403 |         Alternative helper method that does not use `argparse` at all, instead loading a yaml file and populating the
404 |         dataclass types.
405 | 
406 |         Args:
407 |             yaml_file (`str` or `os.PathLike`):
408 |                 File name of the yaml file to parse
409 |             allow_extra_keys (`bool`, *optional*, defaults to `False`):
410 |                 Defaults to False. If False, will raise an exception if the yaml file contains keys that are not
411 |                 parsed.
412 | 
413 |         Returns:
414 |             Tuple consisting of:
415 | 
416 |                 - the dataclass instances in the same order as they were passed to the initializer.
417 | """ 418 | outputs = self.parse_dict(yaml.safe_load(Path(yaml_file).read_text()), allow_extra_keys=allow_extra_keys) 419 | return tuple(outputs) 420 | -------------------------------------------------------------------------------- /src/nl2sql360/core/core.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import create_engine, inspect, text 2 | from sqlalchemy.orm import Session 3 | from loguru import logger 4 | from tqdm import tqdm 5 | from pathlib import Path 6 | from typing import Sequence, Optional, List, Dict, Tuple, Union 7 | from pathlib import Path 8 | from pandas import DataFrame 9 | import pandas as pd 10 | import itertools 11 | 12 | from ..database import * 13 | from ..parser import SQLParser 14 | from ..dataset import NL2SQLDataset 15 | from ..arguments import CoreArguments, DatasetArguments, EvaluationArguments 16 | from ..evaluator import BirdEXEvaluator, SpiderEXEMEvaluator, VesEvaluator, RVesEvaluator, F1Evaluator 17 | from ..filter import Filter, Scenario, serialize_filter, serialize_scenario 18 | 19 | 20 | class _Core: 21 | r""" 22 | Base core class implementation for importing datasets and evaluating. 23 | """ 24 | 25 | def __init__(self, core_args: "CoreArguments") -> None: 26 | self.core_args = core_args 27 | Path(core_args.core_dir).mkdir(exist_ok=True) 28 | self.engine = create_engine(f"sqlite:///{core_args.core_dir}/{core_args.core_name}.sqlite") 29 | self.insp = inspect(self.engine) 30 | Base.metadata.create_all(self.engine, checkfirst=True) # `DatasetInfo` Table Initialize 31 | self.models_dict = dict() 32 | for table_name in self.insp.get_table_names(): 33 | if table_name == "__DATASET_INFO__": 34 | continue 35 | if "_EVALUATION_" in table_name: 36 | self.models_dict[table_name] = get_evaluation_model(*get_dataset_name_and_evaluation_name_from_table_name(table_name)) 37 | else: 38 | self.models_dict[table_name] = get_dataset_model(get_dataset_name_from_table_name(table_name)) 39 | 40 | def import_dataset(self, dataset_args: "DatasetArguments") -> None: 41 | table_name = f"DATASET_{dataset_args.dataset_name}" 42 | if table_name in self.models_dict.keys(): 43 | logger.warning(f"Dataset `{dataset_args.dataset_name}` has been already imported.") 44 | return 45 | 46 | dataset_model = get_dataset_model(dataset_args.dataset_name) 47 | self.models_dict[dataset_model.__tablename__] = dataset_model 48 | Base.metadata.create_all(self.engine, checkfirst=True) 49 | logger.success(f"Dataset table `{table_name}` creation completed.") 50 | 51 | dataset = NL2SQLDataset(dataset_args) 52 | with Session(self.engine) as session: 53 | # Insert dataset info 54 | dataset_info_item = DatasetInfo( 55 | dataset_name=dataset_args.dataset_name, 56 | database_dir_path=str(Path(dataset_args.dataset_dir, dataset_args.database_dir).resolve()), 57 | tables_json_path=str(Path(dataset_args.dataset_dir, dataset_args.tables_file).resolve()) if dataset_args.tables_file else None 58 | ) 59 | session.add(dataset_info_item) 60 | 61 | # Insert dataset samples 62 | for id, (nlq, gold, db_id, complexity, db_domain) in enumerate(tqdm(list(zip( 63 | dataset.get_all_questions(), 64 | dataset.get_all_sqls(), 65 | dataset.get_all_db_ids(), 66 | dataset.get_all_sql_complexity(), 67 | dataset.get_all_database_domains() 68 | )), desc="Import dataset")): 69 | parsed_sql = SQLParser(gold, dialect=self.core_args.sql_dialect.lower()) 70 | table_item = self.models_dict[table_name]( 71 | id=id, 72 | nlq=nlq, 73 | gold=gold, 74 | db_id=db_id, 75 | 
complexity=complexity,
76 |                     db_domain=db_domain,
77 |                     **{attr: getattr(parsed_sql, attr) for attr in dir(parsed_sql) if attr.startswith("count_")}
78 |                 )
79 |                 session.add(table_item)
80 |             session.commit()
81 |         logger.success(f"Import dataset `{dataset_args.dataset_name}` completed, {len(dataset)} samples in total.")
82 | 
83 |     def evaluate(self, evaluation_args: "EvaluationArguments") -> None:
84 |         dataset_table_name = f"DATASET_{evaluation_args.eval_dataset}"
85 |         if dataset_table_name not in self.models_dict.keys():
86 |             logger.warning(f"Dataset `{evaluation_args.eval_dataset}` has not been imported.")
87 |             return
88 | 
89 |         table_name = f"DATASET_{evaluation_args.eval_dataset}_EVALUATION_{evaluation_args.eval_name}"
90 |         if table_name in self.models_dict.keys():
91 |             logger.warning(f"Evaluation `{evaluation_args.eval_name}` on dataset `{evaluation_args.eval_dataset}` already exists.")
92 |             return
93 | 
94 |         evaluation_model = get_evaluation_model(evaluation_args.eval_dataset, evaluation_args.eval_name)
95 |         self.models_dict[evaluation_model.__tablename__] = evaluation_model
96 |         Base.metadata.create_all(self.engine, checkfirst=True)
97 |         logger.success(f"Evaluation table `{table_name}` creation completed.")
98 | 
99 |         dataset_info = get_dataset_info(self.engine, evaluation_args.eval_dataset)
100 |         if dataset_info is None:
101 |             logger.error(f"Cannot find imported dataset `{evaluation_args.eval_dataset}`.")
102 |             return
103 | 
104 |         evaluators = []
105 |         if "ex" in evaluation_args.eval_metrics:
106 |             if evaluation_args.enable_spider_eval:
107 |                 eval_em = "em" in evaluation_args.eval_metrics and dataset_info.tables_json_path
108 |                 if "em" in evaluation_args.eval_metrics and dataset_info.tables_json_path is None:
109 |                     logger.warning(f"`EM` metric evaluation ignored, because no `tables_file` was imported for the {evaluation_args.eval_dataset} dataset.")
110 |                 evaluators.append(SpiderEXEMEvaluator(eval_em=eval_em, eval_ex=True))
111 |             else:
112 |                 eval_em = "em" in evaluation_args.eval_metrics and dataset_info.tables_json_path
113 |                 if "em" in evaluation_args.eval_metrics and dataset_info.tables_json_path is None:
114 |                     logger.warning(f"`EM` metric evaluation ignored, because no `tables_file` was imported for the {evaluation_args.eval_dataset} dataset.")
115 |                 if eval_em:
116 |                     evaluators.append(SpiderEXEMEvaluator(eval_em=eval_em, eval_ex=False))
117 |                 evaluators.append(BirdEXEvaluator(
118 |                     sql_dialect=self.core_args.sql_dialect,
119 |                     dbname=evaluation_args.db_name,
120 |                     user=evaluation_args.db_user,
121 |                     host=evaluation_args.db_host,
122 |                     password=evaluation_args.db_password,
123 |                     port=evaluation_args.db_port
124 |                 ))
125 |         elif "em" in evaluation_args.eval_metrics:
126 |             if dataset_info.tables_json_path is None:
127 |                 logger.warning(f"`EM` metric evaluation ignored, because no `tables_file` was imported for the {evaluation_args.eval_dataset} dataset.")
128 |             else:
129 |                 evaluators.append(SpiderEXEMEvaluator(eval_em=True, eval_ex=False))
130 | 
131 |         if "ves" in evaluation_args.eval_metrics:
132 |             evaluators.append(VesEvaluator(
133 |                 reuse_ex=evaluation_args.enable_spider_eval,
134 |                 sql_dialect=self.core_args.sql_dialect,
135 |                 dbname=evaluation_args.db_name,
136 |                 user=evaluation_args.db_user,
137 |                 host=evaluation_args.db_host,
138 |                 password=evaluation_args.db_password,
139 |                 port=evaluation_args.db_port
140 |             ))
141 | 
142 |     if "rves" in evaluation_args.eval_metrics:
143 |             evaluators.append(RVesEvaluator(
144 |                 reuse_ex=evaluation_args.enable_spider_eval,
145 |                 sql_dialect=self.core_args.sql_dialect,
146 |                 dbname=evaluation_args.db_name,
147 |                 user=evaluation_args.db_user,
148 |                 host=evaluation_args.db_host,
149 |                 password=evaluation_args.db_password,
150 |                 port=evaluation_args.db_port
151 |             ))
152 | 
153 |         if "f1" in evaluation_args.eval_metrics:
154 |             evaluators.append(F1Evaluator(
155 |                 sql_dialect=self.core_args.sql_dialect,
156 |                 dbname=evaluation_args.db_name,
157 |                 user=evaluation_args.db_user,
158 |                 host=evaluation_args.db_host,
159 |                 password=evaluation_args.db_password,
160 |                 port=evaluation_args.db_port
161 |             ))
162 | 
163 |         with open(evaluation_args.pred_sqls_file, "r", encoding="utf-8") as f:
164 |             pred_sqls = f.readlines()
165 | 
166 |         dataset_samples = get_dataset_samples(self.engine, self.models_dict[dataset_table_name])
167 |         gold_sqls = [sample["gold"] for sample in dataset_samples]
168 |         db_ids = [sample["db_id"] for sample in dataset_samples]
169 | 
170 |         eval_results = dict()
171 |         eval_metrics = set()
172 |         for evaluator in evaluators:
173 |             logger.info(f"Evaluating {evaluator.get_eval_metrics()}...")
174 |             exec_acc_list = eval_results.get("exec_acc", None)
175 |             eval_results.update(evaluator.evaluate(
176 |                 gold_sqls=gold_sqls,
177 |                 pred_sqls=pred_sqls,
178 |                 db_ids=db_ids,
179 |                 db_dir=dataset_info.database_dir_path,
180 |                 tables_json_path=dataset_info.tables_json_path,
181 |                 exec_acc_list=exec_acc_list
182 |             ))
183 |             logger.success(f"Evaluating {evaluator.get_eval_metrics()} completed.")
184 |             eval_metrics.update(evaluator.get_eval_metrics())
185 | 
186 |         insert_data = []
187 |         for idx, pred in enumerate(pred_sqls):
188 |             item = {
189 |                 "id": idx,
190 |                 "pred": pred
191 |             }
192 |             for metric in eval_metrics:
193 |                 item[metric] = eval_results[metric][idx]
194 |             insert_data.append(item)
195 | 
196 |         with Session(self.engine) as session:
197 |             for data in tqdm(insert_data, desc="Insert into evaluation table"):
198 |                 table_item = self.models_dict[table_name](
199 |                     **data
200 |                 )
201 |                 session.add(table_item)
202 |             session.commit()
203 |         logger.success(f"Evaluation `{evaluation_args.eval_name}` completed.")
204 | 
205 | 
206 | class Core(_Core):
207 |     r"""
208 |     Extended core class implementation, including more user query interfaces.
209 | """ 210 | 211 | def query_available_datasets(self) -> DataFrame: 212 | datasets = [get_dataset_name_from_table_name(table) 213 | for table in self.models_dict.keys() 214 | if table.startswith("DATASET_") and "_EVALUATION_" not in table] 215 | return DataFrame(data={"Dataset": datasets}) 216 | 217 | def query_available_evaluations(self, dataset_name: str) -> DataFrame: 218 | evaluations = [get_dataset_name_and_evaluation_name_from_table_name(table)[1] 219 | for table in self.models_dict.keys() 220 | if table.startswith(f"DATASET_{dataset_name}") and "_EVALUATION_" in table] 221 | return DataFrame(data={"Evaluation": evaluations}) 222 | 223 | def _check_dataset_valid(self, dataset_name: str) -> bool: 224 | if dataset_name in self.query_available_datasets()["Dataset"].values: 225 | return True 226 | else: 227 | logger.warning(f"Cannot find `{dataset_name}` dataset in NL2SQL360.") 228 | return False 229 | 230 | def _check_evaluation_valid(self, dataset_name: str, eval_name: str) -> bool: 231 | if eval_name in self.query_available_evaluations(dataset_name)["Evaluation"].values: 232 | return True 233 | else: 234 | logger.warning(f"Cannot find `{eval_name}` evaluation for `{dataset_name}` dataset in NL2SQL360.") 235 | return False 236 | 237 | def _check_metric_valid(self, metric: str) -> bool: 238 | if metric in METRIC_COL_MAPPING.keys(): 239 | return True 240 | else: 241 | logger.warning(f"`{metric}` metric is not supported, available metrics: (`ex`, `em`, `ves`, `rves`, `f1`, `qvt`).") 242 | return False 243 | 244 | def query_overall_performance(self, dataset_name: str, metric: str, eval_name: str) -> DataFrame: 245 | if not (self._check_dataset_valid(dataset_name) and self._check_evaluation_valid(dataset_name, eval_name) and self._check_metric_valid(metric)): 246 | return None 247 | else: 248 | if metric == "qvt": 249 | statetment = QUERY_QVT_PERFORMANCE.format( 250 | DATASET_NAME=dataset_name, 251 | EVAL_NAME=eval_name 252 | ) 253 | else: 254 | statetment = QUERY_OVERALL_PERFORMANCE.format( 255 | DATASET_NAME=dataset_name, 256 | EVAL_NAME=eval_name, 257 | METRIC_COL=METRIC_COL_MAPPING[metric] 258 | ) 259 | with self.engine.connect() as connection: 260 | result = connection.execute(text(statetment)) 261 | connection.commit() 262 | res = result.first() 263 | if res: 264 | return DataFrame(data={"Evaluation": eval_name, metric.upper(): res}).round(decimals=2) 265 | else: 266 | logger.warning("Query an empty result.") 267 | return DataFrame(data={"Evaluation": eval_name, metric.upper(): pd.NA}) 268 | 269 | def query_overall_leaderboard(self, dataset_name: str, metric: str, eval_names: List[str] = None) -> DataFrame: 270 | if not (self._check_dataset_valid(dataset_name) and self._check_metric_valid(metric)): 271 | return None 272 | 273 | if eval_names: 274 | for eval_name in eval_names: 275 | if not self._check_evaluation_valid(dataset_name, eval_name): 276 | return None 277 | else: 278 | eval_names = self.query_available_evaluations(dataset_name)["Evaluation"].values 279 | dataframes = [] 280 | for eval_name in eval_names: 281 | dataframes.append(self.query_overall_performance(dataset_name, metric, eval_name)) 282 | df = pd.concat(dataframes, ignore_index=True).sort_values(by=[metric.upper(), "Evaluation"], ascending=False) 283 | df["Rank"] = df[f"{metric.upper()}"].rank(axis=0, method="dense", ascending=False) 284 | return df 285 | 286 | def query_filter_performance(self, dataset_name: str, filter: Filter, metric: str, eval_name: str) -> DataFrame: 287 | if not 
(self._check_dataset_valid(dataset_name) and self._check_evaluation_valid(dataset_name, eval_name) and self._check_metric_valid(metric)): 288 | return None 289 | 290 | if metric == "qvt": 291 | logger.warning(f"QVT metric only supports overall performance.") 292 | return None 293 | 294 | statetment = QUERY_SUBSET_PERFORMANCE.format( 295 | DATASET_NAME=dataset_name, 296 | EVAL_NAME=eval_name, 297 | METRIC_COL=METRIC_COL_MAPPING[metric], 298 | WHERE_CONDITION=serialize_filter(filter) 299 | ) 300 | with self.engine.connect() as connection: 301 | result = connection.execute(text(statetment)) 302 | connection.commit() 303 | res = result.first() 304 | if res: 305 | return DataFrame(data={"Evaluation": eval_name, "Subset": filter.name, metric.upper(): res}).round(decimals=2) 306 | else: 307 | logger.warning("Query an empty result.") 308 | return DataFrame(data={"Evaluation": eval_name, "Subset": filter.name, metric.upper(): pd.NA}) 309 | 310 | def query_filter_leaderboard(self, dataset_name: str, filter: Filter, metric: str, eval_names: List[str] = None) -> DataFrame: 311 | if not (self._check_dataset_valid(dataset_name) and self._check_metric_valid(metric)): 312 | return None 313 | 314 | if eval_names: 315 | for eval_name in eval_names: 316 | if not self._check_evaluation_valid(dataset_name, eval_name): 317 | return None 318 | else: 319 | eval_names = self.query_available_evaluations(dataset_name)["Evaluation"].values 320 | dataframes = [] 321 | for eval_name in eval_names: 322 | dataframes.append(self.query_filter_performance(dataset_name, filter, metric, eval_name)) 323 | df = pd.concat(dataframes, ignore_index=True).sort_values(by=[metric.upper(), "Evaluation"], ascending=False, ignore_index=True) 324 | df["Rank"] = df[f"{metric.upper()}"].rank(axis=0, method="dense", ascending=False) 325 | return df 326 | 327 | def query_scenario_performance(self, dataset_name: str, scenario: Scenario, metric: str, eval_name: str) -> DataFrame: 328 | if not (self._check_dataset_valid(dataset_name) and self._check_evaluation_valid(dataset_name, eval_name) and self._check_metric_valid(metric)): 329 | return None 330 | 331 | if metric == "qvt": 332 | logger.warning(f"QVT metric only supports overall performance.") 333 | return None 334 | 335 | statetment = QUERY_SUBSET_PERFORMANCE.format( 336 | DATASET_NAME=dataset_name, 337 | EVAL_NAME=eval_name, 338 | METRIC_COL=METRIC_COL_MAPPING[metric], 339 | WHERE_CONDITION=serialize_scenario(scenario) 340 | ) 341 | with self.engine.connect() as connection: 342 | result = connection.execute(text(statetment)) 343 | connection.commit() 344 | res = result.first() 345 | if res: 346 | return DataFrame(data={"Evaluation": eval_name, "Subset": scenario.name, metric.upper(): res}).round(decimals=2) 347 | else: 348 | logger.warning("Query an empty result.") 349 | return DataFrame(data={"Evaluation": eval_name, "Subset": scenario.name, metric.upper(): pd.NA}) 350 | 351 | def query_scenario_leaderboard(self, dataset_name, scenario, metric, eval_names: List[str] = None) -> DataFrame: 352 | if not (self._check_dataset_valid(dataset_name) and self._check_metric_valid(metric)): 353 | return None 354 | 355 | if eval_names: 356 | for eval_name in eval_names: 357 | if not self._check_evaluation_valid(dataset_name, eval_name): 358 | return None 359 | else: 360 | eval_names = self.query_available_evaluations(dataset_name)["Evaluation"].values 361 | dataframes = [] 362 | for eval_name in eval_names: 363 | dataframes.append(self.query_scenario_performance(dataset_name, scenario, metric, 
eval_name)) 364 | df = pd.concat(dataframes, ignore_index=True).sort_values(by=[metric.upper(), "Evaluation"], ascending=False, ignore_index=True) 365 | df["Rank"] = df[f"{metric.upper()}"].rank(axis=0, method="dense", ascending=False) 366 | return df 367 | 368 | def query_dataset_sql_distribution(self, dataset_name: str) -> DataFrame: 369 | if not self._check_dataset_valid(dataset_name): 370 | return None 371 | else: 372 | statetment = QUERY_DATASET_SIZE.format(DATASET_NAME=dataset_name) 373 | with self.engine.connect() as connection: 374 | result = connection.execute(text(statetment)) 375 | connection.commit() 376 | total_count, unique_sqls_count = result.first() 377 | 378 | statetment = QUERY_DATASET_SQL_KEYWORDS_DISTRIBUTION.format(DATASET_NAME=dataset_name) 379 | with self.engine.connect() as connection: 380 | result = connection.execute(text(statetment)) 381 | connection.commit() 382 | (avg_count_query_fields, 383 | avg_count_group_by, 384 | avg_count_order_by, 385 | avg_count_limit, 386 | avg_count_join, 387 | avg_count_predicate, 388 | avg_count_aggregation, 389 | avg_count_scalar_function, 390 | avg_count_subquery, 391 | avg_count_set_operation, 392 | avg_count_math_compute, 393 | avg_count_logical_connecter, 394 | avg_count_distinct, 395 | avg_count_like, 396 | avg_count_control_flow, 397 | avg_count_window) = result.first() 398 | 399 | df = DataFrame(data=[ 400 | {"Metric": "Total Count", "Value": total_count}, 401 | {"Metric": "Unique SQL Count", "Value": unique_sqls_count}, 402 | {"Metric": "[QUERY FIELDS] / SQL", "Value": avg_count_query_fields}, 403 | {"Metric": "[GROUP BY] / SQL", "Value": avg_count_group_by}, 404 | {"Metric": "[ORDER BY] / SQL", "Value": avg_count_order_by}, 405 | {"Metric": "[LIMIT] / SQL", "Value": avg_count_limit}, 406 | {"Metric": "[JOIN] / SQL", "Value": avg_count_join}, 407 | {"Metric": "[PREDICATE] / SQL", "Value": avg_count_predicate}, 408 | {"Metric": "[AGGREGATION] / SQL", "Value": avg_count_aggregation}, 409 | {"Metric": "[SCALAR FUNCTION] / SQL", "Value": avg_count_scalar_function}, 410 | {"Metric": "[SUBQUERY] / SQL", "Value": avg_count_subquery}, 411 | {"Metric": "[SET OPERATION] / SQL", "Value": avg_count_set_operation}, 412 | {"Metric": "[MATH COMPUTE] / SQL", "Value": avg_count_math_compute}, 413 | {"Metric": "[LOGICAL CONNECTOR] / SQL", "Value": avg_count_logical_connecter}, 414 | {"Metric": "[DISTINCT] / SQL", "Value": avg_count_distinct}, 415 | {"Metric": "[LIKE] / SQL", "Value": avg_count_like}, 416 | {"Metric": "[CONTROL FLOW] / SQL", "Value": avg_count_control_flow}, 417 | {"Metric": "[WINDOW] / SQL", "Value": avg_count_window}, 418 | ]).round(decimals=2) 419 | 420 | return df 421 | 422 | def query_dataset_domain_distribution(self, dataset_name: str) -> DataFrame: 423 | if not self._check_dataset_valid(dataset_name): 424 | return None 425 | else: 426 | statetment = QUERY_DATASET_DOMAIN_DISTRIBUTION.format(DATASET_NAME=dataset_name) 427 | with self.engine.connect() as connection: 428 | result = connection.execute(text(statetment)) 429 | connection.commit() 430 | db_domain_count = [] 431 | for res in result: 432 | db_domain_count.append({"DB Domain": res[0], "Count": res[1]}) 433 | df = DataFrame(data=db_domain_count) 434 | return df 435 | 436 | def generate_evaluation_report(self, dataset_name: str, filters: List[Filter], scenarios: List[Scenario], metrics: List[str], eval_names: List[str] = None) -> DataFrame: 437 | if not self._check_dataset_valid(dataset_name): 438 | return None 439 | for metric in metrics: 440 | if not 
self._check_metric_valid(metric):
441 |                 return None
442 | 
443 |         if eval_names:
444 |             for eval_name in eval_names:
445 |                 if not self._check_evaluation_valid(dataset_name, eval_name):
446 |                     return None
447 |         else:
448 |             eval_names = self.query_available_evaluations(dataset_name)["Evaluation"].values
449 | 
450 |         results = []
451 | 
452 |         for eval_name in eval_names:
453 | 
454 |             # Overall performance
455 | 
456 |             data = {"Subset": "Overall"}
457 |             for metric in metrics:
458 |                 df = self.query_overall_performance(dataset_name=dataset_name, metric=metric, eval_name=eval_name)
459 |                 if df is not None:
460 |                     data.update(df.to_dict())
461 |             results.append(DataFrame(data))
462 | 
463 | 
464 |             # `qvt` metric only supports overall performance
465 |             if "qvt" in metrics:
466 |                 metrics = [m for m in metrics if m != "qvt"]  # drop `qvt` without mutating the caller's list
467 | 
468 |             # Filter Performance
469 | 
470 |             for filter in filters:
471 |                 data = dict()
472 |                 for metric in metrics:
473 |                     filter_df = self.query_filter_performance(dataset_name=dataset_name, filter=filter, metric=metric, eval_name=eval_name)
474 |                     if filter_df is not None:
475 |                         data.update(filter_df.to_dict())
476 |                 results.append(DataFrame(data))
477 | 
478 |             # Scenario Performance
479 | 
480 |             for scenario in scenarios:
481 |                 data = dict()
482 |                 for metric in metrics:
483 |                     scenario_df = self.query_scenario_performance(dataset_name=dataset_name, scenario=scenario, metric=metric, eval_name=eval_name)
484 |                     if scenario_df is not None:
485 |                         data.update(scenario_df.to_dict())
486 |                 results.append(DataFrame(data))
487 | 
488 |         df = pd.concat(results, ignore_index=True).sort_values(by=["Subset", "Evaluation"], ignore_index=True)
489 |         return df
490 | 
491 |     def delete_dataset_history(self, dataset_name: str, delete_relavant_evaluations=True) -> None:
492 |         logger.warning(
493 |             "You are deleting the dataset history. Please enter `Y` / `YES` to confirm or enter `N` / `NO` to cancel the operation. "
494 |         )
495 |         flag = input("Input your choice:\n").strip().upper()
496 |         while flag not in ["Y", "YES", "N", "NO"]:
497 |             logger.warning(
498 |                 "You are deleting the dataset history. Please enter `Y` / `YES` to confirm or enter `N` / `NO` to cancel the operation. "
499 |             )
500 |             flag = input("Input your choice:\n").strip().upper()
501 | 
502 |         if flag in ["N", "NO"]:
503 |             return
504 | 
505 |         if flag in ["Y", "YES"]:
506 |             statements = [DELETE_DATASET_TABLE.format(DATASET_NAME=dataset_name), DELETE_DATASET_INFO.format(DATASET_NAME=dataset_name)]
507 |             if delete_relavant_evaluations:
508 |                 for eval_name in self.query_available_evaluations(dataset_name)["Evaluation"].values:
509 |                     statements.append(DELETE_EVALUATION_TABLE.format(DATASET_NAME=dataset_name, EVAL_NAME=eval_name))
510 | 
511 |             with self.engine.connect() as connection:
512 |                 for stat in statements:
513 |                     connection.execute(text(stat))
514 |                 connection.commit()
515 |             logger.success(f"Delete dataset `{dataset_name}` successfully.")
516 |             return
517 | 
518 |     def delete_evaluation_history(self, dataset_name: str, eval_name: str) -> None:
519 |         logger.warning(
520 |             "You are deleting the evaluation history. Please enter `Y` / `YES` to confirm or enter `N` / `NO` to cancel the operation. "
521 |         )
522 |         flag = input("Input your choice:\n").strip().upper()
523 |         while flag not in ["Y", "YES", "N", "NO"]:
524 |             logger.warning(
525 |                 "You are deleting the evaluation history. Please enter `Y` / `YES` to confirm or enter `N` / `NO` to cancel the operation. "
" 526 | ) 527 | flag = input("Input your choice:\n").strip().upper() 528 | 529 | if flag in ["N", "NO"]: 530 | return 531 | 532 | if flag in ["Y", "YES"]: 533 | statement = DELETE_EVALUATION_TABLE.format(DATASET_NAME=dataset_name, EVAL_NAME=eval_name) 534 | with self.engine.connect() as connection: 535 | connection.execute(text(statement)) 536 | connection.commit() 537 | logger.success(f"Delete evaluation `{eval_name}` for dataset `{dataset_name}` successfully.") 538 | return --------------------------------------------------------------------------------