├── LICENSE ├── README.md ├── eval ├── README.md ├── constants.py ├── data │ ├── databases │ │ ├── flightinfo │ │ │ └── flightinfo.duckdb │ │ ├── hn │ │ │ └── hn.duckdb │ │ ├── json │ │ │ └── json.duckdb │ │ ├── laptop │ │ │ └── laptop.duckdb │ │ ├── laptop_array │ │ │ └── laptop_array.duckdb │ │ ├── laptop_json │ │ │ └── laptop_json.duckdb │ │ ├── laptop_struct │ │ │ └── laptop_struct.duckdb │ │ ├── none │ │ │ └── none.duckdb │ │ ├── nyc │ │ │ └── nyc.duckdb │ │ ├── product │ │ │ └── product.duckdb │ │ ├── transactions │ │ │ └── transactions.duckdb │ │ └── who │ │ │ └── who.duckdb │ ├── dev.json │ └── tables.json ├── data_utils.py ├── doc_retriever.py ├── evaluate.py ├── get_manifest.py ├── loaders.py ├── metric_utils.py ├── predict.py ├── prompt_formatters.py ├── schema.py └── text_to_sql.py ├── examples ├── local_demo.ipynb ├── nyc.duckdb ├── utils.py └── validate_sql.py └── requirements.txt /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DuckDB-NSQL 2 | Numbers Station Text to SQL model for DuckDB. 3 | 4 | NSQL is a family of autoregressive open-source foundational models (FMs) that are particularly designed for SQL generation tasks. We are thrilled to introduce DuckDB-NSQL in this repository, an FM tailored for local DuckDB SQL analytics tasks. All model weights can be found on HuggingFace. 
5 | 6 | | Model Name | Size | Link | 7 | | --------------------------------------| ---- | -------------------------------------------------------------- | 8 | | motherduckdb/DuckDB-NSQL-7B-v0.1 | 7B | [link](https://huggingface.co/motherduckdb/DuckDB-NSQL-7B-v0.1) | 9 | | motherduckdb/DuckDB-NSQL-7B-v0.1-GGUF | 7B | [link](https://huggingface.co/motherduckdb/DuckDB-NSQL-7B-v0.1-GGUF)| 10 | 11 | ## Setup 12 | To install all the necessary dependencies, please run 13 | ``` 14 | pip install -r requirements.txt 15 | ``` 16 | 17 | ## Usage 18 | Please refer to the `examples/` folder to learn how to connect to a local DuckDB database and query your data directly; a simple notebook is provided there for reference. 19 | 20 | To host the model with llama.cpp, please execute the following: 21 | 22 | ```python 23 | # Import necessary modules 24 | from llama_cpp import Llama 25 | from wurlitzer import pipes 26 | 27 | # Set up client with model path and context size 28 | with pipes() as (out, err): 29 | client = Llama( 30 | model_path="DuckDB-NSQL-7B-v0.1-q8_0.gguf", 31 | n_ctx=2048, 32 | ) 33 | ``` 34 | 35 | To load the DuckDB database and query against it, please execute the following: 36 | 37 | ```python 38 | # Import necessary modules 39 | import duckdb 40 | from utils import generate_sql 41 | 42 | # Connect to DuckDB database 43 | con = duckdb.connect("nyc.duckdb") 44 | 45 | # Sample question for SQL generation 46 | question = "alter taxi table and add struct column with name test and keys a:int, b:double" 47 | 48 | # Generate SQL, check validity, and print 49 | sql = generate_sql(question, con, client) 50 | print(sql) 51 | ``` 52 | 53 | ## Training Data 54 | 55 | The training data for this model consists of two parts: 1) 200k synthetically generated DuckDB SQL queries, based on the DuckDB v0.9.2 documentation, and 2) labeled text-to-SQL pairs from [NSText2SQL](https://huggingface.co/datasets/NumbersStation/NSText2SQL) transpiled to DuckDB SQL using [sqlglot](https://github.com/tobymao/sqlglot). 56 | 57 | ## Evaluate the benchmark 58 | 59 | Please refer to the `eval/` folder for details on evaluating the model against our proposed DuckDB benchmark. 60 | 61 | ## Acknowledgement 62 | 63 | We would like to express our appreciation to all authors of the evaluation scripts. Their work made this project possible. 64 | -------------------------------------------------------------------------------- /eval/README.md: -------------------------------------------------------------------------------- 1 | This folder contains the evaluation suite for the DuckDB-Text2SQL model. 2 | 3 | Please install the dependencies listed in the requirements.txt file located in the parent folder. 4 | 5 | ## Setup 6 | To evaluate against the benchmark dataset, first set up the test-suite-sql-eval evaluation scripts: 7 | 8 | ``` 9 | mkdir metrics 10 | cd metrics 11 | git clone git@github.com:ElementAI/test-suite-sql-eval.git test_suite_sql_eval 12 | cd .. 13 | ``` 14 | 15 | To evaluate against DuckDB, add a new remote inside the test_suite_sql_eval folder and check out the duckdb-only branch (commit 640a12975abf75a94e917caca149d56dbc6bcdd7). 16 | 17 | ``` 18 | git remote add till https://github.com/tdoehmen/test-suite-sql-eval.git 19 | git fetch till 20 | git checkout till/duckdb-only 21 | ``` 22 | 23 | Next, prepare the docs for retrieval. 24 | ``` 25 | mkdir docs 26 | cd docs 27 | git clone https://github.com/duckdb/duckdb-web.git 28 | cd ..
29 | ``` 30 | 31 | #### Dataset 32 | The benchmark dataset is located in the `data/` folder and includes all databases (`data/databases`), table schemas (`data/tables.json`), and examples (`data/dev.json`). 33 | 34 | #### Eval 35 | Start a manifest session with the model you want to evaluate. 36 | 37 | ```bash 38 | python -m manifest.api.app \ 39 | --model_type huggingface \ 40 | --model_generation_type text-generation \ 41 | --model_name_or_path motherduckdb/DuckDB-NSQL-7B-v0.1 \ 42 | --fp16 \ 43 | --device 0 44 | ``` 45 | 46 | Then, from the `DuckDB-NSQL` main folder, run: 47 | 48 | ```bash 49 | python eval/predict.py \ 50 | predict \ 51 | eval/data/dev.json \ 52 | eval/data/tables.json \ 53 | --output-dir output/ \ 54 | --stop-tokens ';' \ 55 | --stop-tokens '--' \ 56 | --stop-tokens '```' \ 57 | --stop-tokens '###' \ 58 | --overwrite-manifest \ 59 | --manifest-client huggingface \ 60 | --manifest-connection http://localhost:5000 \ 61 | --prompt-format duckdbinst 62 | ``` 63 | This will format the prompt using the `duckdbinst` style. 64 | 65 | To evaluate the prediction, first run the following in a Python shell: 66 | 67 | ```python 68 | try: 69 | import duckdb 70 | 71 | con = duckdb.connect() 72 | con.install_extension("httpfs") 73 | con.load_extension("httpfs") 74 | except Exception as e: 75 | print(f"Error loading duckdb extensions: {e}") 76 | ``` 77 | 78 | Then, run the evaluation script: 79 | 80 | ```bash 81 | python eval/evaluate.py \ 82 | evaluate \ 83 | --gold eval/data/dev.json \ 84 | --db eval/data/databases/ \ 85 | --tables eval/data/tables.json \ 86 | --output-dir output/ \ 87 | --pred [PREDICTION_FILE] 88 | ``` 89 | 90 | All output is written to the prediction file in the [output-dir]; in that file, `query` is the gold SQL and `pred` is the predicted SQL.
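
For a quick manual spot check of the predictions, a small script along the following lines can help. It is only a sketch: the file name `output/dev_pred.json` and the on-disk format (a JSON list vs. JSON lines) are assumptions, not something this README specifies; the only fact taken from above is that each record carries a gold `query` and a predicted `pred`, and possibly the original `question`.

```python
# Hedged sketch: print gold vs. predicted SQL side by side.
# The path below is hypothetical -- point it at the file predict.py wrote.
import json
from pathlib import Path

pred_path = Path("output/dev_pred.json")  # hypothetical file name
text = pred_path.read_text()

try:
    records = json.loads(text)  # assume a single JSON list first
except json.JSONDecodeError:
    # fall back to JSON lines, one record per line
    records = [json.loads(line) for line in text.splitlines() if line.strip()]

for rec in records[:5]:
    print("question:", rec.get("question"))
    print("gold:    ", rec.get("query"))
    print("pred:    ", rec.get("pred"))
    print("-" * 40)
```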
91 | -------------------------------------------------------------------------------- /eval/constants.py: -------------------------------------------------------------------------------- 1 | """Constants.""" 2 | 3 | from prompt_formatters import ( 4 | DuckDBFormatter, 5 | DuckDBInstFormatter, 6 | DuckDBInstNoShorthandFormatter, 7 | RajkumarFormatter, 8 | DuckDBChat, 9 | ) 10 | 11 | PROMPT_FORMATTERS = { 12 | "rajkumar": RajkumarFormatter, 13 | "duckdb": DuckDBFormatter, 14 | "duckdbinst": DuckDBInstFormatter, 15 | "duckdbinstnoshort": DuckDBInstNoShorthandFormatter, 16 | "duckdbchat": DuckDBChat, 17 | } 18 | -------------------------------------------------------------------------------- /eval/data/databases/flightinfo/flightinfo.duckdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NumbersStationAI/DuckDB-NSQL/5860f7a0796f7a22e041b76b430559490c199cd7/eval/data/databases/flightinfo/flightinfo.duckdb -------------------------------------------------------------------------------- /eval/data/databases/hn/hn.duckdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NumbersStationAI/DuckDB-NSQL/5860f7a0796f7a22e041b76b430559490c199cd7/eval/data/databases/hn/hn.duckdb -------------------------------------------------------------------------------- /eval/data/databases/json/json.duckdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NumbersStationAI/DuckDB-NSQL/5860f7a0796f7a22e041b76b430559490c199cd7/eval/data/databases/json/json.duckdb -------------------------------------------------------------------------------- /eval/data/databases/laptop/laptop.duckdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NumbersStationAI/DuckDB-NSQL/5860f7a0796f7a22e041b76b430559490c199cd7/eval/data/databases/laptop/laptop.duckdb -------------------------------------------------------------------------------- /eval/data/databases/laptop_array/laptop_array.duckdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NumbersStationAI/DuckDB-NSQL/5860f7a0796f7a22e041b76b430559490c199cd7/eval/data/databases/laptop_array/laptop_array.duckdb -------------------------------------------------------------------------------- /eval/data/databases/laptop_json/laptop_json.duckdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NumbersStationAI/DuckDB-NSQL/5860f7a0796f7a22e041b76b430559490c199cd7/eval/data/databases/laptop_json/laptop_json.duckdb -------------------------------------------------------------------------------- /eval/data/databases/laptop_struct/laptop_struct.duckdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NumbersStationAI/DuckDB-NSQL/5860f7a0796f7a22e041b76b430559490c199cd7/eval/data/databases/laptop_struct/laptop_struct.duckdb -------------------------------------------------------------------------------- /eval/data/databases/none/none.duckdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NumbersStationAI/DuckDB-NSQL/5860f7a0796f7a22e041b76b430559490c199cd7/eval/data/databases/none/none.duckdb -------------------------------------------------------------------------------- 
/eval/data/databases/nyc/nyc.duckdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NumbersStationAI/DuckDB-NSQL/5860f7a0796f7a22e041b76b430559490c199cd7/eval/data/databases/nyc/nyc.duckdb -------------------------------------------------------------------------------- /eval/data/databases/product/product.duckdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NumbersStationAI/DuckDB-NSQL/5860f7a0796f7a22e041b76b430559490c199cd7/eval/data/databases/product/product.duckdb -------------------------------------------------------------------------------- /eval/data/databases/transactions/transactions.duckdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NumbersStationAI/DuckDB-NSQL/5860f7a0796f7a22e041b76b430559490c199cd7/eval/data/databases/transactions/transactions.duckdb -------------------------------------------------------------------------------- /eval/data/databases/who/who.duckdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NumbersStationAI/DuckDB-NSQL/5860f7a0796f7a22e041b76b430559490c199cd7/eval/data/databases/who/who.duckdb -------------------------------------------------------------------------------- /eval/data/dev.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "db_id": "hn", 4 | "query": "SELECT COUNT(*) as domain_count, \nSUBSTRING(SPLIT_PART(url, '//', 2), 1, POSITION('/' IN SPLIT_PART(url, '//', 2)) - 1) as domain \nFROM hacker_news\nWHERE url IS NOT NULL GROUP BY domain ORDER BY domain_count DESC LIMIT 10;", 5 | "setup_sql": ";", 6 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 7 | "question": "what are the top domains being shared on hacker_news?", 8 | "category": "hard" 9 | }, 10 | { 11 | "db_id": "laptop", 12 | "query": "SELECT c.firstname, c.lastname, COUNT(*) AS num_pcs_bought\nFROM customers c\nJOIN sales s ON c.customer_id = s.customer_id\nJOIN pcs p ON s.model = p.model\nGROUP BY c.customer_id, c.firstname, c.lastname\nORDER BY num_pcs_bought DESC\nLIMIT 1;", 13 | "setup_sql": ";", 14 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 15 | "question": "Who bought the most PCs, print also the users name?", 16 | "category": "medium" 17 | }, 18 | { 19 | "db_id": "transactions", 20 | "query": "select users.id, users.name, sum(transactions.amount) as balance from users join transactions on users.id = transactions.user_id group by users.id, users.name having balance = 0", 21 | "setup_sql": ";", 22 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 23 | "question": "list the names off account holders who have negative balances", 24 | "category": "easy" 25 | }, 26 | { 27 | "db_id": "laptop", 28 | "query": "SELECT model FROM products WHERE maker = 'B';", 29 | "setup_sql": ";", 30 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 31 | "question": "List only the model number of all products made by maker B.", 32 | "category": "easy" 33 | }, 34 | { 35 | "db_id": "laptop", 36 | "query": "SELECT model FROM products WHERE maker <> 'B';", 37 | "setup_sql": ";", 38 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 39 | "question": "List the model numbers of all products not made by maker B.", 40 | "category": "easy" 41 | }, 42 | { 43 | "db_id": "laptop", 44 | "query": "SELECT AVG(speed) FROM pcs WHERE speed >= 
3.00", 45 | "setup_sql": ";", 46 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 47 | "question": "Return the average speed all PCs with speed >= 3.00", 48 | "category": "easy" 49 | }, 50 | { 51 | "db_id": "laptop", 52 | "query": "SELECT MAX(price) FROM printers WHERE color = 'TRUE' AND type='laser'", 53 | "setup_sql": ";", 54 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 55 | "question": "Return the price of the most expensive color laser printer", 56 | "category": "medium" 57 | }, 58 | { 59 | "db_id": "laptop", 60 | "query": "SELECT MIN(paid) FROM sales WHERE type_of_payment LIKE '%visa%'", 61 | "setup_sql": ";", 62 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 63 | "question": "Return the minimum amount paid by customers who used a visa card (debit or credit) to purchase a product", 64 | "category": "medium" 65 | }, 66 | { 67 | "db_id": "laptop", 68 | "query": "SELECT customer_id FROM customers WHERE firstname LIKE '%e%' OR lastname LIKE '%e%'", 69 | "setup_sql": ";", 70 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 71 | "question": "Find the customer_id of customers who have the letter 'e' either in their first name or in their last name", 72 | "category": "medium" 73 | }, 74 | { 75 | "db_id": "laptop", 76 | "query": "SELECT model, price/0.85 AS 'price (USD)' FROM laptops WHERE ram >= 1024", 77 | "setup_sql": ";", 78 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 79 | "question": "Assume all prices in the table Laptops are in Euro. List the prices of laptops with at least 1024 ram. You should return the price in USD in a column called 'price (USD)'. Assume that 1 USD = 0.85 EURO. Name the price column 'price (USD)'.", 80 | "category": "hard" 81 | }, 82 | { 83 | "db_id": "laptop", 84 | "query": "SELECT maker FROM products GROUP BY maker HAVING COUNT(maker) > 4;", 85 | "setup_sql": ";", 86 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 87 | "question": "Return a list of makers that make more than four different products.", 88 | "category": "medium" 89 | }, 90 | { 91 | "db_id": "laptop", 92 | "query": "SELECT model FROM laptops WHERE speed > 1.7 ORDER BY speed DESC;", 93 | "setup_sql": ";", 94 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 95 | "question": "List all the laptop model numbers that have a speed greater than 1.7 in descending order of speed.", 96 | "category": "medium" 97 | }, 98 | { 99 | "db_id": "laptop", 100 | "query": "SELECT firstname \n FROM sales \n JOIN customers ON sales.customer_id = customers.customer_id \n GROUP BY firstname \n ORDER BY COUNT(firstname);", 101 | "setup_sql": ";", 102 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 103 | "question": "List firstnames of customers in an ascending order based on the number of purchases made by customers with this firstname.", 104 | "category": "medium" 105 | }, 106 | { 107 | "db_id": "laptop", 108 | "query": "SELECT DISTINCT maker FROM products JOIN pcs ON products.model = pcs.model WHERE ram > 1500;", 109 | "setup_sql": ";", 110 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 111 | "question": "List all the makers (with only one entry per maker) who make PCs with RAM greater than 1500.", 112 | "category": "medium" 113 | }, 114 | { 115 | "db_id": "laptop", 116 | "query": "SELECT city, AVG(paid) as 'avg_spend' FROM sales JOIN customers ON sales.customer_id = customers.customer_id GROUP BY city", 117 | "setup_sql": ";", 118 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 119 | "question": "Find the city 
and the average amount of money spent by customers in each city. Name the column for the amount 'avg_spend'", 120 | "category": "medium" 121 | }, 122 | { 123 | "db_id": "laptop", 124 | "query": "SELECT color, MAX(price) as 'max_price' FROM printers GROUP BY color;", 125 | "setup_sql": ";", 126 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 127 | "question": "Find the maximum price for each color of printer. Name the column for the maximum price 'max_price'", 128 | "category": "medium" 129 | }, 130 | { 131 | "db_id": "who", 132 | "query": "select country_name, max(pm25_concentration) as worst_pm25_for_country\nfrom ambient_air_quality\ngroup by country_name\norder by worst_pm25_for_country desc\nlimit 1", 133 | "setup_sql": ";", 134 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 135 | "question": "Find the country with the worst single reading of air quality (highest PM 2.5 value). Show the PM 2.5 value as well.", 136 | "category": "medium" 137 | }, 138 | { 139 | "db_id": "who", 140 | "query": "select country_name, avg(pm25_concentration) as worst_avg_pm25_for_country\nfrom ambient_air_quality\ngroup by country_name\norder by worst_avg_pm25_for_country desc\nlimit 1", 141 | "setup_sql": ";", 142 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 143 | "question": "Find the country with the worst average air quality (highest PM 2.5 value). Show the PM 2.5 value as well.", 144 | "category": "medium" 145 | }, 146 | { 147 | "db_id": "who", 148 | "query": "select distinct country_name from ambient_air_quality order by country_name", 149 | "setup_sql": ";", 150 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 151 | "question": "Find all countries for which WHO air quality data is available. Sort alphabetically.", 152 | "category": "medium" 153 | }, 154 | { 155 | "db_id": "who", 156 | "query": "select year, avg(pm25_concentration) from ambient_air_quality \nwhere country_name = 'Singapore'\ngroup by year\norder by year", 157 | "setup_sql": ";", 158 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 159 | "question": "Find Singapore air quality defined as PM2.5 concentration over time", 160 | "category": "medium" 161 | }, 162 | { 163 | "db_id": "nyc", 164 | "query": "SELECT COLUMNS('^trip_') FROM rideshare LIMIT 10;", 165 | "setup_sql": ";", 166 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 167 | "question": "select only the column names from the rideshare table that start with trip_ and return the first 10 values", 168 | "category": "duckdb" 169 | }, 170 | { 171 | "db_id": "nyc", 172 | "query": "SELECT * FROM rideshare USING SAMPLE 1%;", 173 | "setup_sql": ";", 174 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 175 | "question": "select a 1% sample from the nyc.rideshare table", 176 | "category": "duckdb" 177 | }, 178 | { 179 | "db_id": "laptop", 180 | "query": "SELECT * EXCLUDE (customer_id) FROM customers;\n", 181 | "setup_sql": ";", 182 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 183 | "question": "select all columns from the customer table, except customer_id", 184 | "category": "duckdb" 185 | }, 186 | { 187 | "db_id": "nyc", 188 | "query": "SUMMARIZE rideshare;", 189 | "setup_sql": ";", 190 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 191 | "question": "show summary statistics of the rideshare table", 192 | "category": "duckdb" 193 | }, 194 | { 195 | "db_id": "none", 196 | "query": "SELECT * FROM 
read_csv_auto(\n'https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv')", 197 | "setup_sql": ";", 198 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 199 | "question": "read a CSV from https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv", 200 | "category": "duckdb" 201 | }, 202 | { 203 | "db_id": "none", 204 | "query": "COPY (SELECT * FROM read_csv_auto(\n'https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv'))\nTO 'titanic.parquet' (FORMAT 'parquet');", 205 | "setup_sql": ";", 206 | "validation_sql": "SELECT * FROM 'titanic.parquet'", 207 | "question": "read a CSV from https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv and convert it to a parquet file called \"titanic\"", 208 | "category": "duckdb" 209 | }, 210 | { 211 | "db_id": "none", 212 | "query": "CREATE TABLE titanic AS (SELECT * FROM read_csv_auto(\n'https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv'))", 213 | "setup_sql": ";", 214 | "validation_sql": "SELECT * FROM titanic;", 215 | "question": "create a table called \"titanic\" from CSV file https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv", 216 | "category": "duckdb" 217 | }, 218 | { 219 | "db_id": "none", 220 | "query": "PRAGMA default_null_order='NULLS LAST';", 221 | "setup_sql": ";", 222 | "validation_sql": "SELECT current_setting('default_null_order');", 223 | "question": "configure duckdb to put null values last when sorting", 224 | "category": "duckdb" 225 | }, 226 | { 227 | "db_id": "none", 228 | "query": "CREATE TABLE IF NOT EXISTS products (\n maker varchar(10),\n model varchar(10),\n type varchar(10));", 229 | "setup_sql": ";", 230 | "validation_sql": "SELECT * FROM products;", 231 | "question": "create a table about products, that contains a maker, model and type column", 232 | "category": "ddl" 233 | }, 234 | { 235 | "db_id": "product", 236 | "query": "INSERT INTO products (maker, model, type)\nVALUES\n ('A', '1001', 'pc');", 237 | "setup_sql": ";", 238 | "validation_sql": "SELECT * FROM products;", 239 | "question": "add a row with values for model \"1001\" of type \"pc\", from maker \"A\" to products table", 240 | "category": "ddl" 241 | }, 242 | { 243 | "db_id": "none", 244 | "query": "CALL pragma_version();\n", 245 | "setup_sql": ";", 246 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 247 | "question": "get current version of duckdb", 248 | "category": "duckdb" 249 | }, 250 | { 251 | "db_id": "nyc", 252 | "query": "PRAGMA table_info('rideshare');", 253 | "setup_sql": ";", 254 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 255 | "question": "list all columns in table nyc.rideshare", 256 | "category": "duckdb" 257 | }, 258 | { 259 | "db_id": "nyc", 260 | "query": "PRAGMA show_tables;", 261 | "setup_sql": ";", 262 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 263 | "question": "show all tables in the curent database", 264 | "category": "duckdb" 265 | }, 266 | { 267 | "db_id": "laptop", 268 | "query": "SELECT customer_id, model, sum(paid) FROM sales GROUP BY ALL", 269 | "setup_sql": ";", 270 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 271 | "question": "how much did each customer spend per model type?", 272 | "category": "easy" 273 | }, 274 | { 275 | "db_id": "nyc", 276 | "query": "SELECT Max(datediff('minute', tpep_pickup_datetime, tpep_dropoff_datetime)) from nyc.taxi", 277 | "setup_sql": ";", 278 | "validation_sql": "SELECT * FROM 
ddb_benchmark_result;", 279 | "question": "What was the longest taxi ride in minutes?", 280 | "category": "hard" 281 | }, 282 | { 283 | "db_id": "who", 284 | "query": "with per_region as (\n select avg(pm10_concentration) as avg_pm10, who_region from ambient_air_quality group by who_region\n), max_region as (\n select who_region from per_region where avg_pm10 = (select max(avg_pm10) from per_region)\n), min_city_value_in_max_region as (\n select min(pm10_concentration) from ambient_air_quality where who_region in (from max_region)\n), min_city_in_max_region as (\n select city from ambient_air_quality where pm10_concentration in (from min_city_value_in_max_region) and who_region in (from max_region)\n)\nfrom min_city_in_max_region", 285 | "setup_sql": ";", 286 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 287 | "question": "What is the city with the lowest pm10 concentration in the region with the highest average pm10 concentration?", 288 | "category": "hard" 289 | }, 290 | { 291 | "db_id": "hn", 292 | "query": "SELECT *, regexp_extract(text, '([a-z0-9_\\.-]+)@([\\da-z\\.-]+)\\.([a-z\\.]{2,63})',0) email from hacker_news where email[:4]='test'", 293 | "setup_sql": ";", 294 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 295 | "question": "Get all posts on hn that contain an email address starting with test. Return all original columns, plus a new column containing the email address.", 296 | "category": "hard" 297 | }, 298 | { 299 | "db_id": "json", 300 | "query": "SELECT employee.id, employee.first_name FROM employee_json", 301 | "setup_sql": ";", 302 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 303 | "question": "Extract id and first_name properties as individual columns from the employee struct", 304 | "category": "duckdb" 305 | }, 306 | { 307 | "db_id": "who", 308 | "query": "SELECT who_region[1]::INT as region, * EXCLUDE (who_region) FROM who.ambient_air_quality", 309 | "setup_sql": ";", 310 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 311 | "question": "count quality measurements per region. 
Make sure to return the region code (first char of who_region) as integer and sort by region.", 312 | "category": "duckdb" 313 | }, 314 | { 315 | "db_id": "flightinfo", 316 | "query": "SELECT seat.seat_number FROM seat \nJOIN direct_flight ON direct_flight.flight_number = seat.flight_number \nJOIN airport AS departure_airport ON departure_airport.iata_code = direct_flight.departure_airport_iata_code \nJOIN airport AS arriving_airport ON arriving_airport.iata_code = direct_flight.arriving_airport_iata_code \nJOIN city AS departure_city ON departure_city.city_zipcode = departure_airport.city_zip_code \nJOIN city AS arriving_city ON arriving_city.city_zipcode = arriving_airport.city_zip_code \nWHERE departure_city.city_name = 'Bruxelles' AND arriving_city.city_name = 'Newark';", 317 | "setup_sql": ";", 318 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 319 | "question": "Which seats were available on the flight from Bruxelles to Newark?", 320 | "category": "hard" 321 | }, 322 | { 323 | "db_id": "laptop", 324 | "query": "COPY customers FROM 'customers_12_12_2023.csv';", 325 | "setup_sql": "COPY customers TO 'customers_12_12_2023.csv';", 326 | "validation_sql": "SELECT * FROM customers;", 327 | "question": "copy content of csv file customers_12_12_2023.csv into customers table", 328 | "category": "duckdb" 329 | }, 330 | { 331 | "db_id": "laptop", 332 | "query": "COPY customers FROM 'customers_12_12_2023.csv' (DELIMITER '\\t');", 333 | "setup_sql": "COPY customers TO 'customers_12_12_2023.csv' (FORMAT CSV, DELIMITER '\\t');", 334 | "validation_sql": "SELECT * FROM customers;", 335 | "question": "copy content of csv file costomers_12_12_2023.csv into customers table with tab separator", 336 | "category": "duckdb" 337 | }, 338 | { 339 | "db_id": "laptop", 340 | "query": "COPY customers FROM 'customers_partitioned/city=Amsterdam/*.parquet';", 341 | "setup_sql": "COPY customers TO 'customers_partitioned' (FORMAT PARQUET, PARTITION_BY (city), OVERWRITE_OR_IGNORE True);", 342 | "validation_sql": "SELECT * FROM customers;;", 343 | "question": "copy any parquet files from 'customers_partitioned/city=Amsterdam/' into customers table", 344 | "category": "duckdb" 345 | }, 346 | { 347 | "db_id": "laptop", 348 | "query": "COPY customers(customer_id) FROM 'customers_customer_ids.csv';", 349 | "setup_sql": "COPY customers(customer_id) TO 'customers_customer_ids.csv';", 350 | "validation_sql": "SELECT * FROM customers;", 351 | "question": "copy only the customer_id column from the customers_customer_ids.csv into the customers tables", 352 | "category": "duckdb" 353 | }, 354 | { 355 | "db_id": "laptop", 356 | "query": "CREATE TABLE test_tbl AS SELECT * FROM read_json_auto('test.json');", 357 | "setup_sql": "COPY customers TO 'test.json'\n", 358 | "validation_sql": "SELECT * FROM test_tbl;", 359 | "question": "read json file from test.json and create new table from it called 'test_tbl'", 360 | "category": "duckdb" 361 | }, 362 | { 363 | "db_id": "laptop", 364 | "query": "SELECT * FROM read_csv_auto('test.csv');", 365 | "setup_sql": "COPY customers TO 'test.csv';", 366 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 367 | "question": "read csv from test.csv", 368 | "category": "duckdb" 369 | }, 370 | { 371 | "db_id": "laptop", 372 | "query": "SELECT * FROM read_csv_auto('test.csv', columns={'customer_id': 'VARCHAR', 'firstname': 'VARCHAR', 'lastname': 'VARCHAR'});", 373 | "setup_sql": "COPY customers(customer_id, firstname, lastname) TO 'test.csv';", 374 | "validation_sql": "SELECT * FROM 
ddb_benchmark_result;", 375 | "question": "read csv from test.csv with predefined column and types - customer_id: string, firstname: string, lastname: string", 376 | "category": "duckdb" 377 | }, 378 | { 379 | "db_id": "laptop", 380 | "query": "SELECT * EXCLUDE (ram, hd) FROM pcs;", 381 | "setup_sql": ";", 382 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 383 | "question": "select all columns from pcs table except for ram and hd", 384 | "category": "duckdb" 385 | }, 386 | { 387 | "db_id": "laptop", 388 | "query": "SELECT COLUMNS('name$') FROM customers;", 389 | "setup_sql": ";", 390 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 391 | "question": "select all columns ending with 'name' from customers table", 392 | "category": "duckdb" 393 | }, 394 | { 395 | "db_id": "laptop", 396 | "query": "SELECT LENGTH(COLUMNS('name$')) FROM customers", 397 | "setup_sql": ";", 398 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 399 | "question": "for each column ending with 'name' in the customers table, compute the string length", 400 | "category": "duckdb" 401 | }, 402 | { 403 | "db_id": "laptop", 404 | "query": "SELECT * REPLACE (upper(city) AS city) FROM customers;", 405 | "setup_sql": ";", 406 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 407 | "question": "get all columns from customer table, and make all city names uppercase", 408 | "category": "duckdb" 409 | }, 410 | { 411 | "db_id": "laptop", 412 | "query": "EXPLAIN SELECT * FROM customers", 413 | "setup_sql": ";", 414 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 415 | "question": "show query plan for query: SELECT * from customers", 416 | "category": "duckdb" 417 | }, 418 | { 419 | "db_id": "laptop", 420 | "query": "SELECT ascii(lastname) FROM customers;", 421 | "setup_sql": ";", 422 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 423 | "question": "get the first character of the firstname column and cast it to an INT", 424 | "category": "duckdb" 425 | }, 426 | { 427 | "db_id": "laptop", 428 | "query": "SELECT model, speed::INTEGER FROM laptops;", 429 | "setup_sql": ";", 430 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 431 | "question": "get laptop name and speed, return the speed as integer", 432 | "category": "duckdb" 433 | }, 434 | { 435 | "db_id": "laptop_array", 436 | "query": "SELECT phone_numbers[1] FROM customers;", 437 | "setup_sql": ";", 438 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 439 | "question": "get the first phone number of each customer", 440 | "category": "duckdb" 441 | }, 442 | { 443 | "db_id": "laptop_array", 444 | "query": "INSERT INTO customers(customer_id, phone_numbers) VALUES (5, ['12312323', '23123344']);", 445 | "setup_sql": ";", 446 | "validation_sql": "SELECT * FROM customers;", 447 | "question": "insert two phone numbers to customer with id 5 [\\\"12312323\\\", and '23123344']", 448 | "category": "duckdb" 449 | }, 450 | { 451 | "db_id": "laptop", 452 | "query": "ALTER TABLE customers ADD COLUMN phone_numbers VARCHAR[];", 453 | "setup_sql": ";", 454 | "validation_sql": "DESCRIBE customers;", 455 | "question": "how to add a new column phone_numbers to the customers table, with array type varchar", 456 | "category": "duckdb" 457 | }, 458 | { 459 | "db_id": "laptop", 460 | "query": "SELECT firstname[1] FROM customers;", 461 | "setup_sql": ";", 462 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 463 | "question": "get the first letter of the customers firstname", 464 | "category": "duckdb" 465 | }, 466 | 
{ 467 | "db_id": "laptop_array", 468 | "query": "SELECT phone_numbers[:2] FROM customers;", 469 | "setup_sql": ";", 470 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 471 | "question": "get the first two phone numbers from the phone numbers array of each customer", 472 | "category": "duckdb" 473 | }, 474 | { 475 | "db_id": "laptop", 476 | "query": "SELECT {'a':1, 'b':2, 'c':3};", 477 | "setup_sql": ";", 478 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 479 | "question": "create a struct with keys a, b, c and values 1,2,3", 480 | "category": "duckdb" 481 | }, 482 | { 483 | "db_id": "laptop", 484 | "query": "SELECT [1,2,3];\n", 485 | "setup_sql": ";", 486 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 487 | "question": "create array with values 1,2,3", 488 | "category": "duckdb" 489 | }, 490 | { 491 | "db_id": "laptop", 492 | "query": "CREATE TABLE test (embeddings FLOAT[100]);", 493 | "setup_sql": ";", 494 | "validation_sql": "DESCRIBE test;", 495 | "question": "create table test with a fix-sized array column with 100 dimenions, called embeddings", 496 | "category": "duckdb" 497 | }, 498 | { 499 | "db_id": "laptop", 500 | "query": "CREATE TABLE test (person STRUCT(name VARCHAR, id INTEGER));", 501 | "setup_sql": ";", 502 | "validation_sql": "DESCRIBE test;", 503 | "question": "create table test with a struct column called person with properties name and id", 504 | "category": "duckdb" 505 | }, 506 | { 507 | "db_id": "laptop_struct", 508 | "query": "SELECT person.name, person.id FROM test;", 509 | "setup_sql": ";", 510 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 511 | "question": "get persons name and persons id from the test table.", 512 | "category": "duckdb" 513 | }, 514 | { 515 | "db_id": "laptop", 516 | "query": "UPDATE customers SET email = NULL;", 517 | "setup_sql": ";", 518 | "validation_sql": "SELECT email FROM customers;", 519 | "question": "remove all values from email column in customers table", 520 | "category": "duckdb" 521 | }, 522 | { 523 | "db_id": "laptop_json", 524 | "query": "ALTER TABLE customers ALTER COLUMN email SET DATA TYPE VARCHAR;", 525 | "setup_sql": ";", 526 | "validation_sql": "DESCRIBE customers;", 527 | "question": "make customer email of type VARCHAR", 528 | "category": "duckdb" 529 | }, 530 | { 531 | "db_id": "laptop_json", 532 | "query": "INSERT INTO customers (customer_id, email) VALUES (5,'{\"from\": \"test2@gmail.com\", \"to\": \"test@gmail.com\"}');", 533 | "setup_sql": ";", 534 | "validation_sql": "SELECT * FROM customers;", 535 | "question": "insert json into customer email for customer id 5: {'from': 'test2@gmail.com', 'to': 'test@gmail.com'}", 536 | "category": "duckdb" 537 | }, 538 | { 539 | "db_id": "laptop_json", 540 | "query": "SELECT customers.email->>'from' FROM customers;", 541 | "setup_sql": ";", 542 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 543 | "question": "get 'from' field from customer email json", 544 | "category": "duckdb" 545 | }, 546 | { 547 | "db_id": "laptop", 548 | "query": "SUMMARIZE customers;", 549 | "setup_sql": ";", 550 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 551 | "question": "summarize the customer table", 552 | "category": "duckdb" 553 | }, 554 | { 555 | "db_id": "laptop", 556 | "query": "SELECT * FROM customers USING SAMPLE 10% (reservoir);", 557 | "setup_sql": ";", 558 | "validation_sql": "SELECT count(*) FROM ddb_benchmark_result;", 559 | "question": "sample 10% from the customers table using reservoir sampling", 560 | "category": 
"duckdb" 561 | }, 562 | { 563 | "db_id": "laptop", 564 | "query": "SET threads = 10;", 565 | "setup_sql": ";", 566 | "validation_sql": "SELECT current_setting('threads');", 567 | "question": "set number of threads to 10", 568 | "category": "duckdb" 569 | }, 570 | { 571 | "db_id": "laptop", 572 | "query": "SET memory_limit = '20G';\n", 573 | "setup_sql": ";", 574 | "validation_sql": "SELECT current_setting('memory_limit');", 575 | "question": "set memory limit to 20 gigabyte", 576 | "category": "duckdb" 577 | }, 578 | { 579 | "db_id": "laptop", 580 | "query": "SELECT * EXCLUDE (price), avg(price) FROM laptops GROUP BY ALL;", 581 | "setup_sql": ";", 582 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 583 | "question": "show the average price of laptop and group by the remaining columns", 584 | "category": "duckdb" 585 | }, 586 | { 587 | "db_id": "laptop", 588 | "query": "SELECT * FROM laptops WHERE price > 1000 ORDER BY ALL;\n", 589 | "setup_sql": ";", 590 | "validation_sql": "SELECT * FROM ddb_benchmark_result;", 591 | "question": "show all laptops with price above 1000 and order by all columns", 592 | "category": "duckdb" 593 | }, 594 | { 595 | "db_id": "laptop", 596 | "query": "ATTACH 'who.ddb';", 597 | "setup_sql": ";", 598 | "validation_sql": "SHOW DATABASES;", 599 | "question": "attach database file who.ddb", 600 | "category": "duckdb" 601 | } 602 | ] -------------------------------------------------------------------------------- /eval/data/tables.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "db_id": "hn", 4 | "column_names": [ 5 | [ 6 | -1, 7 | "*" 8 | ], 9 | [ 10 | 0, 11 | "title" 12 | ], 13 | [ 14 | 0, 15 | "url" 16 | ], 17 | [ 18 | 0, 19 | "text" 20 | ], 21 | [ 22 | 0, 23 | "dead" 24 | ], 25 | [ 26 | 0, 27 | "by" 28 | ], 29 | [ 30 | 0, 31 | "score" 32 | ], 33 | [ 34 | 0, 35 | "time" 36 | ], 37 | [ 38 | 0, 39 | "timestamp" 40 | ], 41 | [ 42 | 0, 43 | "type" 44 | ], 45 | [ 46 | 0, 47 | "id" 48 | ], 49 | [ 50 | 0, 51 | "parent" 52 | ], 53 | [ 54 | 0, 55 | "descendants" 56 | ], 57 | [ 58 | 0, 59 | "ranking" 60 | ], 61 | [ 62 | 0, 63 | "deleted" 64 | ] 65 | ], 66 | "column_names_original": [ 67 | [ 68 | -1, 69 | "*" 70 | ], 71 | [ 72 | 0, 73 | "title" 74 | ], 75 | [ 76 | 0, 77 | "url" 78 | ], 79 | [ 80 | 0, 81 | "text" 82 | ], 83 | [ 84 | 0, 85 | "dead" 86 | ], 87 | [ 88 | 0, 89 | "by" 90 | ], 91 | [ 92 | 0, 93 | "score" 94 | ], 95 | [ 96 | 0, 97 | "time" 98 | ], 99 | [ 100 | 0, 101 | "timestamp" 102 | ], 103 | [ 104 | 0, 105 | "type" 106 | ], 107 | [ 108 | 0, 109 | "id" 110 | ], 111 | [ 112 | 0, 113 | "parent" 114 | ], 115 | [ 116 | 0, 117 | "descendants" 118 | ], 119 | [ 120 | 0, 121 | "ranking" 122 | ], 123 | [ 124 | 0, 125 | "deleted" 126 | ] 127 | ], 128 | "column_types": [ 129 | "text", 130 | "varchar", 131 | "varchar", 132 | "varchar", 133 | "boolean", 134 | "varchar", 135 | "bigint", 136 | "bigint", 137 | "timestamp", 138 | "varchar", 139 | "bigint", 140 | "bigint", 141 | "bigint", 142 | "bigint", 143 | "boolean" 144 | ], 145 | "foreign_keys": {}, 146 | "primary_keys": {}, 147 | "table_names": [ 148 | "hacker_news" 149 | ], 150 | "table_names_original": [ 151 | "hacker_news" 152 | ] 153 | }, 154 | { 155 | "db_id": "laptop", 156 | "column_names": [ 157 | [ 158 | -1, 159 | "*" 160 | ], 161 | [ 162 | 0, 163 | "customer_id" 164 | ], 165 | [ 166 | 0, 167 | "firstname" 168 | ], 169 | [ 170 | 0, 171 | "lastname" 172 | ], 173 | [ 174 | 0, 175 | "city" 176 | ], 177 | [ 178 | 0, 179 | "address" 180 | ], 181 | [ 182 | 0, 
183 | "email" 184 | ], 185 | [ 186 | 1, 187 | "model" 188 | ], 189 | [ 190 | 1, 191 | "speed" 192 | ], 193 | [ 194 | 1, 195 | "ram" 196 | ], 197 | [ 198 | 1, 199 | "hd" 200 | ], 201 | [ 202 | 1, 203 | "screen" 204 | ], 205 | [ 206 | 1, 207 | "price" 208 | ], 209 | [ 210 | 2, 211 | "model" 212 | ], 213 | [ 214 | 2, 215 | "speed" 216 | ], 217 | [ 218 | 2, 219 | "ram" 220 | ], 221 | [ 222 | 2, 223 | "hd" 224 | ], 225 | [ 226 | 2, 227 | "price" 228 | ], 229 | [ 230 | 3, 231 | "model" 232 | ], 233 | [ 234 | 3, 235 | "color" 236 | ], 237 | [ 238 | 3, 239 | "type" 240 | ], 241 | [ 242 | 3, 243 | "price" 244 | ], 245 | [ 246 | 4, 247 | "maker" 248 | ], 249 | [ 250 | 4, 251 | "model" 252 | ], 253 | [ 254 | 4, 255 | "type" 256 | ], 257 | [ 258 | 5, 259 | "customer_id" 260 | ], 261 | [ 262 | 5, 263 | "model" 264 | ], 265 | [ 266 | 5, 267 | "quantity" 268 | ], 269 | [ 270 | 5, 271 | "day" 272 | ], 273 | [ 274 | 5, 275 | "paid" 276 | ], 277 | [ 278 | 5, 279 | "type_of_payment" 280 | ] 281 | ], 282 | "column_names_original": [ 283 | [ 284 | -1, 285 | "*" 286 | ], 287 | [ 288 | 0, 289 | "customer_id" 290 | ], 291 | [ 292 | 0, 293 | "firstname" 294 | ], 295 | [ 296 | 0, 297 | "lastname" 298 | ], 299 | [ 300 | 0, 301 | "city" 302 | ], 303 | [ 304 | 0, 305 | "address" 306 | ], 307 | [ 308 | 0, 309 | "email" 310 | ], 311 | [ 312 | 1, 313 | "model" 314 | ], 315 | [ 316 | 1, 317 | "speed" 318 | ], 319 | [ 320 | 1, 321 | "ram" 322 | ], 323 | [ 324 | 1, 325 | "hd" 326 | ], 327 | [ 328 | 1, 329 | "screen" 330 | ], 331 | [ 332 | 1, 333 | "price" 334 | ], 335 | [ 336 | 2, 337 | "model" 338 | ], 339 | [ 340 | 2, 341 | "speed" 342 | ], 343 | [ 344 | 2, 345 | "ram" 346 | ], 347 | [ 348 | 2, 349 | "hd" 350 | ], 351 | [ 352 | 2, 353 | "price" 354 | ], 355 | [ 356 | 3, 357 | "model" 358 | ], 359 | [ 360 | 3, 361 | "color" 362 | ], 363 | [ 364 | 3, 365 | "type" 366 | ], 367 | [ 368 | 3, 369 | "price" 370 | ], 371 | [ 372 | 4, 373 | "maker" 374 | ], 375 | [ 376 | 4, 377 | "model" 378 | ], 379 | [ 380 | 4, 381 | "type" 382 | ], 383 | [ 384 | 5, 385 | "customer_id" 386 | ], 387 | [ 388 | 5, 389 | "model" 390 | ], 391 | [ 392 | 5, 393 | "quantity" 394 | ], 395 | [ 396 | 5, 397 | "day" 398 | ], 399 | [ 400 | 5, 401 | "paid" 402 | ], 403 | [ 404 | 5, 405 | "type_of_payment" 406 | ] 407 | ], 408 | "column_types": [ 409 | "text", 410 | "char", 411 | "varchar", 412 | "varchar", 413 | "varchar", 414 | "varchar", 415 | "varchar", 416 | "char", 417 | "double", 418 | "int", 419 | "int", 420 | "double", 421 | "double", 422 | "char", 423 | "double", 424 | "int", 425 | "int", 426 | "double", 427 | "char", 428 | "varchar", 429 | "varchar", 430 | "double", 431 | "char", 432 | "char", 433 | "varchar", 434 | "char", 435 | "char", 436 | "int", 437 | "date", 438 | "double", 439 | "varchar" 440 | ], 441 | "foreign_keys": {}, 442 | "primary_keys": {}, 443 | "table_names": [ 444 | "customers", 445 | "laptops", 446 | "pcs", 447 | "printers", 448 | "products", 449 | "sales" 450 | ], 451 | "table_names_original": [ 452 | "customers", 453 | "laptops", 454 | "pcs", 455 | "printers", 456 | "products", 457 | "sales" 458 | ] 459 | }, 460 | { 461 | "db_id": "transactions", 462 | "column_names": [ 463 | [ 464 | -1, 465 | "*" 466 | ], 467 | [ 468 | 0, 469 | "id" 470 | ], 471 | [ 472 | 0, 473 | "name" 474 | ], 475 | [ 476 | 1, 477 | "user_id" 478 | ], 479 | [ 480 | 1, 481 | "amount" 482 | ] 483 | ], 484 | "column_names_original": [ 485 | [ 486 | -1, 487 | "*" 488 | ], 489 | [ 490 | 0, 491 | "id" 492 | ], 493 | [ 494 | 0, 495 | "name" 496 | ], 497 | [ 498 | 
1, 499 | "user_id" 500 | ], 501 | [ 502 | 1, 503 | "amount" 504 | ] 505 | ], 506 | "column_types": [ 507 | "text", 508 | "int", 509 | "varchar", 510 | "int", 511 | "int" 512 | ], 513 | "foreign_keys": {}, 514 | "primary_keys": {}, 515 | "table_names": [ 516 | "users", 517 | "transactions" 518 | ], 519 | "table_names_original": [ 520 | "users", 521 | "transactions" 522 | ] 523 | }, 524 | { 525 | "db_id": "who", 526 | "column_names": [ 527 | [ 528 | -1, 529 | "*" 530 | ], 531 | [ 532 | 0, 533 | "who_region" 534 | ], 535 | [ 536 | 0, 537 | "iso3" 538 | ], 539 | [ 540 | 0, 541 | "country_name" 542 | ], 543 | [ 544 | 0, 545 | "city" 546 | ], 547 | [ 548 | 0, 549 | "year" 550 | ], 551 | [ 552 | 0, 553 | "version" 554 | ], 555 | [ 556 | 0, 557 | "pm10_concentration" 558 | ], 559 | [ 560 | 0, 561 | "pm25_concentration" 562 | ], 563 | [ 564 | 0, 565 | "no2_concentration" 566 | ], 567 | [ 568 | 0, 569 | "pm10_tempcov" 570 | ], 571 | [ 572 | 0, 573 | "pm25_tempcov" 574 | ], 575 | [ 576 | 0, 577 | "no2_tempcov" 578 | ], 579 | [ 580 | 0, 581 | "type_of_stations" 582 | ], 583 | [ 584 | 0, 585 | "reference" 586 | ], 587 | [ 588 | 0, 589 | "web_link" 590 | ], 591 | [ 592 | 0, 593 | "population" 594 | ], 595 | [ 596 | 0, 597 | "population_source" 598 | ], 599 | [ 600 | 0, 601 | "latitude" 602 | ], 603 | [ 604 | 0, 605 | "longitude" 606 | ], 607 | [ 608 | 0, 609 | "who_ms" 610 | ] 611 | ], 612 | "column_names_original": [ 613 | [ 614 | -1, 615 | "*" 616 | ], 617 | [ 618 | 0, 619 | "who_region" 620 | ], 621 | [ 622 | 0, 623 | "iso3" 624 | ], 625 | [ 626 | 0, 627 | "country_name" 628 | ], 629 | [ 630 | 0, 631 | "city" 632 | ], 633 | [ 634 | 0, 635 | "year" 636 | ], 637 | [ 638 | 0, 639 | "version" 640 | ], 641 | [ 642 | 0, 643 | "pm10_concentration" 644 | ], 645 | [ 646 | 0, 647 | "pm25_concentration" 648 | ], 649 | [ 650 | 0, 651 | "no2_concentration" 652 | ], 653 | [ 654 | 0, 655 | "pm10_tempcov" 656 | ], 657 | [ 658 | 0, 659 | "pm25_tempcov" 660 | ], 661 | [ 662 | 0, 663 | "no2_tempcov" 664 | ], 665 | [ 666 | 0, 667 | "type_of_stations" 668 | ], 669 | [ 670 | 0, 671 | "reference" 672 | ], 673 | [ 674 | 0, 675 | "web_link" 676 | ], 677 | [ 678 | 0, 679 | "population" 680 | ], 681 | [ 682 | 0, 683 | "population_source" 684 | ], 685 | [ 686 | 0, 687 | "latitude" 688 | ], 689 | [ 690 | 0, 691 | "longitude" 692 | ], 693 | [ 694 | 0, 695 | "who_ms" 696 | ] 697 | ], 698 | "column_types": [ 699 | "text", 700 | "varchar", 701 | "varchar", 702 | "varchar", 703 | "varchar", 704 | "bigint", 705 | "varchar", 706 | "bigint", 707 | "bigint", 708 | "bigint", 709 | "bigint", 710 | "bigint", 711 | "bigint", 712 | "varchar", 713 | "varchar", 714 | "varchar", 715 | "varchar", 716 | "varchar", 717 | "float", 718 | "float", 719 | "bigint" 720 | ], 721 | "foreign_keys": {}, 722 | "primary_keys": {}, 723 | "table_names": [ 724 | "ambient_air_quality" 725 | ], 726 | "table_names_original": [ 727 | "ambient_air_quality" 728 | ] 729 | }, 730 | { 731 | "db_id": "nyc", 732 | "column_names": [ 733 | [ 734 | -1, 735 | "*" 736 | ], 737 | [ 738 | 0, 739 | "unique_key" 740 | ], 741 | [ 742 | 0, 743 | "created_date" 744 | ], 745 | [ 746 | 0, 747 | "closed_date" 748 | ], 749 | [ 750 | 0, 751 | "agency" 752 | ], 753 | [ 754 | 0, 755 | "agency_name" 756 | ], 757 | [ 758 | 0, 759 | "complaint_type" 760 | ], 761 | [ 762 | 0, 763 | "descriptor" 764 | ], 765 | [ 766 | 0, 767 | "location_type" 768 | ], 769 | [ 770 | 0, 771 | "incident_zip" 772 | ], 773 | [ 774 | 0, 775 | "incident_address" 776 | ], 777 | [ 778 | 0, 779 | "street_name" 780 | ], 781 
| [ 782 | 0, 783 | "cross_street_1" 784 | ], 785 | [ 786 | 0, 787 | "cross_street_2" 788 | ], 789 | [ 790 | 0, 791 | "intersection_street_1" 792 | ], 793 | [ 794 | 0, 795 | "intersection_street_2" 796 | ], 797 | [ 798 | 0, 799 | "address_type" 800 | ], 801 | [ 802 | 0, 803 | "city" 804 | ], 805 | [ 806 | 0, 807 | "landmark" 808 | ], 809 | [ 810 | 0, 811 | "facility_type" 812 | ], 813 | [ 814 | 0, 815 | "status" 816 | ], 817 | [ 818 | 0, 819 | "due_date" 820 | ], 821 | [ 822 | 0, 823 | "resolution_description" 824 | ], 825 | [ 826 | 0, 827 | "resolution_action_updated_date" 828 | ], 829 | [ 830 | 0, 831 | "community_board" 832 | ], 833 | [ 834 | 0, 835 | "bbl" 836 | ], 837 | [ 838 | 0, 839 | "borough" 840 | ], 841 | [ 842 | 0, 843 | "x_coordinate_state_plane" 844 | ], 845 | [ 846 | 0, 847 | "y_coordinate_state_plane" 848 | ], 849 | [ 850 | 0, 851 | "open_data_channel_type" 852 | ], 853 | [ 854 | 0, 855 | "park_facility_name" 856 | ], 857 | [ 858 | 0, 859 | "park_borough" 860 | ], 861 | [ 862 | 0, 863 | "vehicle_type" 864 | ], 865 | [ 866 | 0, 867 | "taxi_company_borough" 868 | ], 869 | [ 870 | 0, 871 | "taxi_pick_up_location" 872 | ], 873 | [ 874 | 0, 875 | "bridge_highway_name" 876 | ], 877 | [ 878 | 0, 879 | "bridge_highway_direction" 880 | ], 881 | [ 882 | 0, 883 | "road_ramp" 884 | ], 885 | [ 886 | 0, 887 | "bridge_highway_segment" 888 | ], 889 | [ 890 | 0, 891 | "latitude" 892 | ], 893 | [ 894 | 0, 895 | "longitude" 896 | ], 897 | [ 898 | 1, 899 | "hvfhs_license_num" 900 | ], 901 | [ 902 | 1, 903 | "dispatching_base_num" 904 | ], 905 | [ 906 | 1, 907 | "originating_base_num" 908 | ], 909 | [ 910 | 1, 911 | "request_datetime" 912 | ], 913 | [ 914 | 1, 915 | "on_scene_datetime" 916 | ], 917 | [ 918 | 1, 919 | "pickup_datetime" 920 | ], 921 | [ 922 | 1, 923 | "dropoff_datetime" 924 | ], 925 | [ 926 | 1, 927 | "PULocationID" 928 | ], 929 | [ 930 | 1, 931 | "DOLocationID" 932 | ], 933 | [ 934 | 1, 935 | "trip_miles" 936 | ], 937 | [ 938 | 1, 939 | "trip_time" 940 | ], 941 | [ 942 | 1, 943 | "base_passenger_fare" 944 | ], 945 | [ 946 | 1, 947 | "tolls" 948 | ], 949 | [ 950 | 1, 951 | "bcf" 952 | ], 953 | [ 954 | 1, 955 | "sales_tax" 956 | ], 957 | [ 958 | 1, 959 | "congestion_surcharge" 960 | ], 961 | [ 962 | 1, 963 | "airport_fee" 964 | ], 965 | [ 966 | 1, 967 | "tips" 968 | ], 969 | [ 970 | 1, 971 | "driver_pay" 972 | ], 973 | [ 974 | 1, 975 | "shared_request_flag" 976 | ], 977 | [ 978 | 1, 979 | "shared_match_flag" 980 | ], 981 | [ 982 | 1, 983 | "access_a_ride_flag" 984 | ], 985 | [ 986 | 1, 987 | "wav_request_flag" 988 | ], 989 | [ 990 | 1, 991 | "wav_match_flag" 992 | ], 993 | [ 994 | 2, 995 | "VendorID" 996 | ], 997 | [ 998 | 2, 999 | "tpep_pickup_datetime" 1000 | ], 1001 | [ 1002 | 2, 1003 | "tpep_dropoff_datetime" 1004 | ], 1005 | [ 1006 | 2, 1007 | "passenger_count" 1008 | ], 1009 | [ 1010 | 2, 1011 | "trip_distance" 1012 | ], 1013 | [ 1014 | 2, 1015 | "RatecodeID" 1016 | ], 1017 | [ 1018 | 2, 1019 | "store_and_fwd_flag" 1020 | ], 1021 | [ 1022 | 2, 1023 | "PULocationID" 1024 | ], 1025 | [ 1026 | 2, 1027 | "DOLocationID" 1028 | ], 1029 | [ 1030 | 2, 1031 | "payment_type" 1032 | ], 1033 | [ 1034 | 2, 1035 | "fare_amount" 1036 | ], 1037 | [ 1038 | 2, 1039 | "extra" 1040 | ], 1041 | [ 1042 | 2, 1043 | "mta_tax" 1044 | ], 1045 | [ 1046 | 2, 1047 | "tip_amount" 1048 | ], 1049 | [ 1050 | 2, 1051 | "tolls_amount" 1052 | ], 1053 | [ 1054 | 2, 1055 | "improvement_surcharge" 1056 | ], 1057 | [ 1058 | 2, 1059 | "total_amount" 1060 | ], 1061 | [ 1062 | 2, 1063 | "congestion_surcharge" 1064 | 
], 1065 | [ 1066 | 2, 1067 | "airport_fee" 1068 | ] 1069 | ], 1070 | "column_names_original": [ 1071 | [ 1072 | -1, 1073 | "*" 1074 | ], 1075 | [ 1076 | 0, 1077 | "unique_key" 1078 | ], 1079 | [ 1080 | 0, 1081 | "created_date" 1082 | ], 1083 | [ 1084 | 0, 1085 | "closed_date" 1086 | ], 1087 | [ 1088 | 0, 1089 | "agency" 1090 | ], 1091 | [ 1092 | 0, 1093 | "agency_name" 1094 | ], 1095 | [ 1096 | 0, 1097 | "complaint_type" 1098 | ], 1099 | [ 1100 | 0, 1101 | "descriptor" 1102 | ], 1103 | [ 1104 | 0, 1105 | "location_type" 1106 | ], 1107 | [ 1108 | 0, 1109 | "incident_zip" 1110 | ], 1111 | [ 1112 | 0, 1113 | "incident_address" 1114 | ], 1115 | [ 1116 | 0, 1117 | "street_name" 1118 | ], 1119 | [ 1120 | 0, 1121 | "cross_street_1" 1122 | ], 1123 | [ 1124 | 0, 1125 | "cross_street_2" 1126 | ], 1127 | [ 1128 | 0, 1129 | "intersection_street_1" 1130 | ], 1131 | [ 1132 | 0, 1133 | "intersection_street_2" 1134 | ], 1135 | [ 1136 | 0, 1137 | "address_type" 1138 | ], 1139 | [ 1140 | 0, 1141 | "city" 1142 | ], 1143 | [ 1144 | 0, 1145 | "landmark" 1146 | ], 1147 | [ 1148 | 0, 1149 | "facility_type" 1150 | ], 1151 | [ 1152 | 0, 1153 | "status" 1154 | ], 1155 | [ 1156 | 0, 1157 | "due_date" 1158 | ], 1159 | [ 1160 | 0, 1161 | "resolution_description" 1162 | ], 1163 | [ 1164 | 0, 1165 | "resolution_action_updated_date" 1166 | ], 1167 | [ 1168 | 0, 1169 | "community_board" 1170 | ], 1171 | [ 1172 | 0, 1173 | "bbl" 1174 | ], 1175 | [ 1176 | 0, 1177 | "borough" 1178 | ], 1179 | [ 1180 | 0, 1181 | "x_coordinate_state_plane" 1182 | ], 1183 | [ 1184 | 0, 1185 | "y_coordinate_state_plane" 1186 | ], 1187 | [ 1188 | 0, 1189 | "open_data_channel_type" 1190 | ], 1191 | [ 1192 | 0, 1193 | "park_facility_name" 1194 | ], 1195 | [ 1196 | 0, 1197 | "park_borough" 1198 | ], 1199 | [ 1200 | 0, 1201 | "vehicle_type" 1202 | ], 1203 | [ 1204 | 0, 1205 | "taxi_company_borough" 1206 | ], 1207 | [ 1208 | 0, 1209 | "taxi_pick_up_location" 1210 | ], 1211 | [ 1212 | 0, 1213 | "bridge_highway_name" 1214 | ], 1215 | [ 1216 | 0, 1217 | "bridge_highway_direction" 1218 | ], 1219 | [ 1220 | 0, 1221 | "road_ramp" 1222 | ], 1223 | [ 1224 | 0, 1225 | "bridge_highway_segment" 1226 | ], 1227 | [ 1228 | 0, 1229 | "latitude" 1230 | ], 1231 | [ 1232 | 0, 1233 | "longitude" 1234 | ], 1235 | [ 1236 | 1, 1237 | "hvfhs_license_num" 1238 | ], 1239 | [ 1240 | 1, 1241 | "dispatching_base_num" 1242 | ], 1243 | [ 1244 | 1, 1245 | "originating_base_num" 1246 | ], 1247 | [ 1248 | 1, 1249 | "request_datetime" 1250 | ], 1251 | [ 1252 | 1, 1253 | "on_scene_datetime" 1254 | ], 1255 | [ 1256 | 1, 1257 | "pickup_datetime" 1258 | ], 1259 | [ 1260 | 1, 1261 | "dropoff_datetime" 1262 | ], 1263 | [ 1264 | 1, 1265 | "PULocationID" 1266 | ], 1267 | [ 1268 | 1, 1269 | "DOLocationID" 1270 | ], 1271 | [ 1272 | 1, 1273 | "trip_miles" 1274 | ], 1275 | [ 1276 | 1, 1277 | "trip_time" 1278 | ], 1279 | [ 1280 | 1, 1281 | "base_passenger_fare" 1282 | ], 1283 | [ 1284 | 1, 1285 | "tolls" 1286 | ], 1287 | [ 1288 | 1, 1289 | "bcf" 1290 | ], 1291 | [ 1292 | 1, 1293 | "sales_tax" 1294 | ], 1295 | [ 1296 | 1, 1297 | "congestion_surcharge" 1298 | ], 1299 | [ 1300 | 1, 1301 | "airport_fee" 1302 | ], 1303 | [ 1304 | 1, 1305 | "tips" 1306 | ], 1307 | [ 1308 | 1, 1309 | "driver_pay" 1310 | ], 1311 | [ 1312 | 1, 1313 | "shared_request_flag" 1314 | ], 1315 | [ 1316 | 1, 1317 | "shared_match_flag" 1318 | ], 1319 | [ 1320 | 1, 1321 | "access_a_ride_flag" 1322 | ], 1323 | [ 1324 | 1, 1325 | "wav_request_flag" 1326 | ], 1327 | [ 1328 | 1, 1329 | "wav_match_flag" 1330 | ], 1331 | [ 1332 | 2, 1333 
| "VendorID" 1334 | ], 1335 | [ 1336 | 2, 1337 | "tpep_pickup_datetime" 1338 | ], 1339 | [ 1340 | 2, 1341 | "tpep_dropoff_datetime" 1342 | ], 1343 | [ 1344 | 2, 1345 | "passenger_count" 1346 | ], 1347 | [ 1348 | 2, 1349 | "trip_distance" 1350 | ], 1351 | [ 1352 | 2, 1353 | "RatecodeID" 1354 | ], 1355 | [ 1356 | 2, 1357 | "store_and_fwd_flag" 1358 | ], 1359 | [ 1360 | 2, 1361 | "PULocationID" 1362 | ], 1363 | [ 1364 | 2, 1365 | "DOLocationID" 1366 | ], 1367 | [ 1368 | 2, 1369 | "payment_type" 1370 | ], 1371 | [ 1372 | 2, 1373 | "fare_amount" 1374 | ], 1375 | [ 1376 | 2, 1377 | "extra" 1378 | ], 1379 | [ 1380 | 2, 1381 | "mta_tax" 1382 | ], 1383 | [ 1384 | 2, 1385 | "tip_amount" 1386 | ], 1387 | [ 1388 | 2, 1389 | "tolls_amount" 1390 | ], 1391 | [ 1392 | 2, 1393 | "improvement_surcharge" 1394 | ], 1395 | [ 1396 | 2, 1397 | "total_amount" 1398 | ], 1399 | [ 1400 | 2, 1401 | "congestion_surcharge" 1402 | ], 1403 | [ 1404 | 2, 1405 | "airport_fee" 1406 | ] 1407 | ], 1408 | "column_types": [ 1409 | "text", 1410 | "bigint", 1411 | "timestamp", 1412 | "timestamp", 1413 | "varchar", 1414 | "varchar", 1415 | "varchar", 1416 | "varchar", 1417 | "varchar", 1418 | "varchar", 1419 | "varchar", 1420 | "varchar", 1421 | "varchar", 1422 | "varchar", 1423 | "varchar", 1424 | "varchar", 1425 | "varchar", 1426 | "varchar", 1427 | "varchar", 1428 | "varchar", 1429 | "varchar", 1430 | "timestamp", 1431 | "varchar", 1432 | "timestamp", 1433 | "varchar", 1434 | "varchar", 1435 | "varchar", 1436 | "varchar", 1437 | "varchar", 1438 | "varchar", 1439 | "varchar", 1440 | "varchar", 1441 | "varchar", 1442 | "varchar", 1443 | "varchar", 1444 | "varchar", 1445 | "varchar", 1446 | "varchar", 1447 | "varchar", 1448 | "double", 1449 | "double", 1450 | "varchar", 1451 | "varchar", 1452 | "varchar", 1453 | "timestamp", 1454 | "timestamp", 1455 | "timestamp", 1456 | "timestamp", 1457 | "bigint", 1458 | "bigint", 1459 | "double", 1460 | "bigint", 1461 | "double", 1462 | "double", 1463 | "double", 1464 | "double", 1465 | "double", 1466 | "double", 1467 | "double", 1468 | "double", 1469 | "varchar", 1470 | "varchar", 1471 | "varchar", 1472 | "varchar", 1473 | "varchar", 1474 | "bigint", 1475 | "timestamp", 1476 | "timestamp", 1477 | "double", 1478 | "double", 1479 | "double", 1480 | "varchar", 1481 | "bigint", 1482 | "bigint", 1483 | "bigint", 1484 | "double", 1485 | "double", 1486 | "double", 1487 | "double", 1488 | "double", 1489 | "double", 1490 | "double", 1491 | "double", 1492 | "double" 1493 | ], 1494 | "foreign_keys": {}, 1495 | "primary_keys": {}, 1496 | "table_names": [ 1497 | "service_requests", 1498 | "rideshare", 1499 | "taxi" 1500 | ], 1501 | "table_names_original": [ 1502 | "service_requests", 1503 | "rideshare", 1504 | "taxi" 1505 | ] 1506 | }, 1507 | { 1508 | "db_id": "product", 1509 | "column_names": [ 1510 | [ 1511 | -1, 1512 | "*" 1513 | ], 1514 | [ 1515 | 0, 1516 | "maker" 1517 | ], 1518 | [ 1519 | 0, 1520 | "model" 1521 | ], 1522 | [ 1523 | 0, 1524 | "type" 1525 | ] 1526 | ], 1527 | "column_names_original": [ 1528 | [ 1529 | -1, 1530 | "*" 1531 | ], 1532 | [ 1533 | 0, 1534 | "maker" 1535 | ], 1536 | [ 1537 | 0, 1538 | "model" 1539 | ], 1540 | [ 1541 | 0, 1542 | "type" 1543 | ] 1544 | ], 1545 | "column_types": [ 1546 | "text", 1547 | "varchar", 1548 | "varchar", 1549 | "varchar" 1550 | ], 1551 | "foreign_keys": {}, 1552 | "primary_keys": {}, 1553 | "table_names": [ 1554 | "products" 1555 | ], 1556 | "table_names_original": [ 1557 | "products" 1558 | ] 1559 | }, 1560 | { 1561 | "db_id": "json", 1562 | 
"column_names": [ 1563 | [ 1564 | -1, 1565 | "*" 1566 | ], 1567 | [ 1568 | 0, 1569 | "employee" 1570 | ] 1571 | ], 1572 | "column_names_original": [ 1573 | [ 1574 | -1, 1575 | "*" 1576 | ], 1577 | [ 1578 | 0, 1579 | "employee" 1580 | ] 1581 | ], 1582 | "column_types": [ 1583 | "text", 1584 | "struct(id int, first_name text, last_name text, email text)" 1585 | ], 1586 | "foreign_keys": {}, 1587 | "primary_keys": {}, 1588 | "table_names": [ 1589 | "employee_json" 1590 | ], 1591 | "table_names_original": [ 1592 | "employee_json" 1593 | ] 1594 | }, 1595 | { 1596 | "db_id": "flightinfo", 1597 | "column_names": [ 1598 | [ 1599 | -1, 1600 | "*" 1601 | ], 1602 | [ 1603 | 0, 1604 | "icao_code" 1605 | ], 1606 | [ 1607 | 0, 1608 | "email" 1609 | ], 1610 | [ 1611 | 0, 1612 | "name" 1613 | ], 1614 | [ 1615 | 0, 1616 | "phone_number" 1617 | ], 1618 | [ 1619 | 0, 1620 | "iata_code" 1621 | ], 1622 | [ 1623 | 1, 1624 | "title" 1625 | ], 1626 | [ 1627 | 1, 1628 | "description" 1629 | ], 1630 | [ 1631 | 1, 1632 | "price" 1633 | ], 1634 | [ 1635 | 1, 1636 | "service_type" 1637 | ], 1638 | [ 1639 | 1, 1640 | "airline_icao_code" 1641 | ], 1642 | [ 1643 | 2, 1644 | "iata_code" 1645 | ], 1646 | [ 1647 | 2, 1648 | "address" 1649 | ], 1650 | [ 1651 | 2, 1652 | "name" 1653 | ], 1654 | [ 1655 | 2, 1656 | "phone_number" 1657 | ], 1658 | [ 1659 | 2, 1660 | "email" 1661 | ], 1662 | [ 1663 | 2, 1664 | "city_zip_code" 1665 | ], 1666 | [ 1667 | 2, 1668 | "city_dbpedia" 1669 | ], 1670 | [ 1671 | 3, 1672 | "title" 1673 | ], 1674 | [ 1675 | 3, 1676 | "cabin_bag_dimension_cm" 1677 | ], 1678 | [ 1679 | 3, 1680 | "cabin_bags_no" 1681 | ], 1682 | [ 1683 | 3, 1684 | "cabin_bg_weight_kg" 1685 | ], 1686 | [ 1687 | 3, 1688 | "checked_bag_dimension_cm" 1689 | ], 1690 | [ 1691 | 3, 1692 | "checked_bags_no" 1693 | ], 1694 | [ 1695 | 3, 1696 | "checked_bag_weight_kg" 1697 | ], 1698 | [ 1699 | 3, 1700 | "excessive_price_perkg" 1701 | ], 1702 | [ 1703 | 3, 1704 | "flight_type" 1705 | ], 1706 | [ 1707 | 3, 1708 | "airline_icao_code" 1709 | ], 1710 | [ 1711 | 4, 1712 | "title" 1713 | ], 1714 | [ 1715 | 4, 1716 | "description" 1717 | ], 1718 | [ 1719 | 4, 1720 | "airline_icao_code" 1721 | ], 1722 | [ 1723 | 5, 1724 | "title" 1725 | ], 1726 | [ 1727 | 5, 1728 | "description" 1729 | ], 1730 | [ 1731 | 5, 1732 | "due_date" 1733 | ], 1734 | [ 1735 | 5, 1736 | "refund_postdue_percentage" 1737 | ], 1738 | [ 1739 | 5, 1740 | "refund_predue_percentage" 1741 | ], 1742 | [ 1743 | 5, 1744 | "airline_icao_code" 1745 | ], 1746 | [ 1747 | 6, 1748 | "city_zipcode" 1749 | ], 1750 | [ 1751 | 6, 1752 | "city_name" 1753 | ], 1754 | [ 1755 | 6, 1756 | "country_iso_code" 1757 | ], 1758 | [ 1759 | 7, 1760 | "country_iso_code" 1761 | ], 1762 | [ 1763 | 7, 1764 | "country_name" 1765 | ], 1766 | [ 1767 | 8, 1768 | "flight_number" 1769 | ], 1770 | [ 1771 | 8, 1772 | "departure_airport_iata_code" 1773 | ], 1774 | [ 1775 | 8, 1776 | "arriving_airport_iata_code" 1777 | ], 1778 | [ 1779 | 9, 1780 | "number" 1781 | ], 1782 | [ 1783 | 9, 1784 | "departure_date" 1785 | ], 1786 | [ 1787 | 9, 1788 | "arrival_date" 1789 | ], 1790 | [ 1791 | 9, 1792 | "distance_km" 1793 | ], 1794 | [ 1795 | 9, 1796 | "is_available" 1797 | ], 1798 | [ 1799 | 9, 1800 | "duration_min" 1801 | ], 1802 | [ 1803 | 9, 1804 | "airline_icao_code" 1805 | ], 1806 | [ 1807 | 9, 1808 | "type" 1809 | ], 1810 | [ 1811 | 10, 1812 | "title" 1813 | ], 1814 | [ 1815 | 10, 1816 | "description" 1817 | ], 1818 | [ 1819 | 10, 1820 | "cabin_class_title" 1821 | ], 1822 | [ 1823 | 10, 1824 | "baggage_policy_title" 1825 | 
], 1826 | [ 1827 | 10, 1828 | "cancelation_policy_title" 1829 | ], 1830 | [ 1831 | 11, 1832 | "subflight_number" 1833 | ], 1834 | [ 1835 | 11, 1836 | "flight_number" 1837 | ], 1838 | [ 1839 | 12, 1840 | "flight_package_title" 1841 | ], 1842 | [ 1843 | 12, 1844 | "airline_service_title" 1845 | ], 1846 | [ 1847 | 13, 1848 | "seat_number" 1849 | ], 1850 | [ 1851 | 13, 1852 | "is_available" 1853 | ], 1854 | [ 1855 | 13, 1856 | "flight_number" 1857 | ], 1858 | [ 1859 | 14, 1860 | "meal_type" 1861 | ], 1862 | [ 1863 | 14, 1864 | "airline_service_title" 1865 | ], 1866 | [ 1867 | 15, 1868 | "duration_min" 1869 | ], 1870 | [ 1871 | 15, 1872 | "duration_from" 1873 | ], 1874 | [ 1875 | 15, 1876 | "duration_to" 1877 | ], 1878 | [ 1879 | 15, 1880 | "airport_iatacode" 1881 | ], 1882 | [ 1883 | 15, 1884 | "flight_number" 1885 | ], 1886 | [ 1887 | 16, 1888 | "flight_number" 1889 | ], 1890 | [ 1891 | 16, 1892 | "package_title" 1893 | ], 1894 | [ 1895 | 16, 1896 | "trip_id" 1897 | ], 1898 | [ 1899 | 16, 1900 | "requested_excessive_baggage_kg" 1901 | ], 1902 | [ 1903 | 16, 1904 | "seat_number" 1905 | ], 1906 | [ 1907 | 16, 1908 | "chosen_meal_service_price" 1909 | ], 1910 | [ 1911 | 16, 1912 | "chosen_wifi_service_price" 1913 | ], 1914 | [ 1915 | 16, 1916 | "price" 1917 | ], 1918 | [ 1919 | 17, 1920 | "id" 1921 | ], 1922 | [ 1923 | 17, 1924 | "tax" 1925 | ], 1926 | [ 1927 | 17, 1928 | "booking_date" 1929 | ], 1930 | [ 1931 | 17, 1932 | "user_email" 1933 | ], 1934 | [ 1935 | 17, 1936 | "type" 1937 | ], 1938 | [ 1939 | 18, 1940 | "email" 1941 | ], 1942 | [ 1943 | 18, 1944 | "first_name" 1945 | ], 1946 | [ 1947 | 18, 1948 | "last_name" 1949 | ], 1950 | [ 1951 | 18, 1952 | "birthdate" 1953 | ], 1954 | [ 1955 | 18, 1956 | "passport_number" 1957 | ], 1958 | [ 1959 | 18, 1960 | "address" 1961 | ], 1962 | [ 1963 | 18, 1964 | "password" 1965 | ], 1966 | [ 1967 | 18, 1968 | "phone_number" 1969 | ], 1970 | [ 1971 | 19, 1972 | "wifi_onboard_service_bandwidth_MB" 1973 | ], 1974 | [ 1975 | 19, 1976 | "airline_service_title" 1977 | ] 1978 | ], 1979 | "column_names_original": [ 1980 | [ 1981 | -1, 1982 | "*" 1983 | ], 1984 | [ 1985 | 0, 1986 | "icao_code" 1987 | ], 1988 | [ 1989 | 0, 1990 | "email" 1991 | ], 1992 | [ 1993 | 0, 1994 | "name" 1995 | ], 1996 | [ 1997 | 0, 1998 | "phone_number" 1999 | ], 2000 | [ 2001 | 0, 2002 | "iata_code" 2003 | ], 2004 | [ 2005 | 1, 2006 | "title" 2007 | ], 2008 | [ 2009 | 1, 2010 | "description" 2011 | ], 2012 | [ 2013 | 1, 2014 | "price" 2015 | ], 2016 | [ 2017 | 1, 2018 | "service_type" 2019 | ], 2020 | [ 2021 | 1, 2022 | "airline_icao_code" 2023 | ], 2024 | [ 2025 | 2, 2026 | "iata_code" 2027 | ], 2028 | [ 2029 | 2, 2030 | "address" 2031 | ], 2032 | [ 2033 | 2, 2034 | "name" 2035 | ], 2036 | [ 2037 | 2, 2038 | "phone_number" 2039 | ], 2040 | [ 2041 | 2, 2042 | "email" 2043 | ], 2044 | [ 2045 | 2, 2046 | "city_zip_code" 2047 | ], 2048 | [ 2049 | 2, 2050 | "city_dbpedia" 2051 | ], 2052 | [ 2053 | 3, 2054 | "title" 2055 | ], 2056 | [ 2057 | 3, 2058 | "cabin_bag_dimension_cm" 2059 | ], 2060 | [ 2061 | 3, 2062 | "cabin_bags_no" 2063 | ], 2064 | [ 2065 | 3, 2066 | "cabin_bg_weight_kg" 2067 | ], 2068 | [ 2069 | 3, 2070 | "checked_bag_dimension_cm" 2071 | ], 2072 | [ 2073 | 3, 2074 | "checked_bags_no" 2075 | ], 2076 | [ 2077 | 3, 2078 | "checked_bag_weight_kg" 2079 | ], 2080 | [ 2081 | 3, 2082 | "excessive_price_perkg" 2083 | ], 2084 | [ 2085 | 3, 2086 | "flight_type" 2087 | ], 2088 | [ 2089 | 3, 2090 | "airline_icao_code" 2091 | ], 2092 | [ 2093 | 4, 2094 | "title" 2095 | ], 2096 | [ 2097 | 4, 
2098 | "description" 2099 | ], 2100 | [ 2101 | 4, 2102 | "airline_icao_code" 2103 | ], 2104 | [ 2105 | 5, 2106 | "title" 2107 | ], 2108 | [ 2109 | 5, 2110 | "description" 2111 | ], 2112 | [ 2113 | 5, 2114 | "due_date" 2115 | ], 2116 | [ 2117 | 5, 2118 | "refund_postdue_percentage" 2119 | ], 2120 | [ 2121 | 5, 2122 | "refund_predue_percentage" 2123 | ], 2124 | [ 2125 | 5, 2126 | "airline_icao_code" 2127 | ], 2128 | [ 2129 | 6, 2130 | "city_zipcode" 2131 | ], 2132 | [ 2133 | 6, 2134 | "city_name" 2135 | ], 2136 | [ 2137 | 6, 2138 | "country_iso_code" 2139 | ], 2140 | [ 2141 | 7, 2142 | "country_iso_code" 2143 | ], 2144 | [ 2145 | 7, 2146 | "country_name" 2147 | ], 2148 | [ 2149 | 8, 2150 | "flight_number" 2151 | ], 2152 | [ 2153 | 8, 2154 | "departure_airport_iata_code" 2155 | ], 2156 | [ 2157 | 8, 2158 | "arriving_airport_iata_code" 2159 | ], 2160 | [ 2161 | 9, 2162 | "number" 2163 | ], 2164 | [ 2165 | 9, 2166 | "departure_date" 2167 | ], 2168 | [ 2169 | 9, 2170 | "arrival_date" 2171 | ], 2172 | [ 2173 | 9, 2174 | "distance_km" 2175 | ], 2176 | [ 2177 | 9, 2178 | "is_available" 2179 | ], 2180 | [ 2181 | 9, 2182 | "duration_min" 2183 | ], 2184 | [ 2185 | 9, 2186 | "airline_icao_code" 2187 | ], 2188 | [ 2189 | 9, 2190 | "type" 2191 | ], 2192 | [ 2193 | 10, 2194 | "title" 2195 | ], 2196 | [ 2197 | 10, 2198 | "description" 2199 | ], 2200 | [ 2201 | 10, 2202 | "cabin_class_title" 2203 | ], 2204 | [ 2205 | 10, 2206 | "baggage_policy_title" 2207 | ], 2208 | [ 2209 | 10, 2210 | "cancelation_policy_title" 2211 | ], 2212 | [ 2213 | 11, 2214 | "subflight_number" 2215 | ], 2216 | [ 2217 | 11, 2218 | "flight_number" 2219 | ], 2220 | [ 2221 | 12, 2222 | "flight_package_title" 2223 | ], 2224 | [ 2225 | 12, 2226 | "airline_service_title" 2227 | ], 2228 | [ 2229 | 13, 2230 | "seat_number" 2231 | ], 2232 | [ 2233 | 13, 2234 | "is_available" 2235 | ], 2236 | [ 2237 | 13, 2238 | "flight_number" 2239 | ], 2240 | [ 2241 | 14, 2242 | "meal_type" 2243 | ], 2244 | [ 2245 | 14, 2246 | "airline_service_title" 2247 | ], 2248 | [ 2249 | 15, 2250 | "duration_min" 2251 | ], 2252 | [ 2253 | 15, 2254 | "duration_from" 2255 | ], 2256 | [ 2257 | 15, 2258 | "duration_to" 2259 | ], 2260 | [ 2261 | 15, 2262 | "airport_iatacode" 2263 | ], 2264 | [ 2265 | 15, 2266 | "flight_number" 2267 | ], 2268 | [ 2269 | 16, 2270 | "flight_number" 2271 | ], 2272 | [ 2273 | 16, 2274 | "package_title" 2275 | ], 2276 | [ 2277 | 16, 2278 | "trip_id" 2279 | ], 2280 | [ 2281 | 16, 2282 | "requested_excessive_baggage_kg" 2283 | ], 2284 | [ 2285 | 16, 2286 | "seat_number" 2287 | ], 2288 | [ 2289 | 16, 2290 | "chosen_meal_service_price" 2291 | ], 2292 | [ 2293 | 16, 2294 | "chosen_wifi_service_price" 2295 | ], 2296 | [ 2297 | 16, 2298 | "price" 2299 | ], 2300 | [ 2301 | 17, 2302 | "id" 2303 | ], 2304 | [ 2305 | 17, 2306 | "tax" 2307 | ], 2308 | [ 2309 | 17, 2310 | "booking_date" 2311 | ], 2312 | [ 2313 | 17, 2314 | "user_email" 2315 | ], 2316 | [ 2317 | 17, 2318 | "type" 2319 | ], 2320 | [ 2321 | 18, 2322 | "email" 2323 | ], 2324 | [ 2325 | 18, 2326 | "first_name" 2327 | ], 2328 | [ 2329 | 18, 2330 | "last_name" 2331 | ], 2332 | [ 2333 | 18, 2334 | "birthdate" 2335 | ], 2336 | [ 2337 | 18, 2338 | "passport_number" 2339 | ], 2340 | [ 2341 | 18, 2342 | "address" 2343 | ], 2344 | [ 2345 | 18, 2346 | "password" 2347 | ], 2348 | [ 2349 | 18, 2350 | "phone_number" 2351 | ], 2352 | [ 2353 | 19, 2354 | "wifi_onboard_service_bandwidth_MB" 2355 | ], 2356 | [ 2357 | 19, 2358 | "airline_service_title" 2359 | ] 2360 | ], 2361 | "column_types": [ 2362 | "text", 2363 
| "varchar", 2364 | "varchar", 2365 | "varchar", 2366 | "varchar", 2367 | "varchar", 2368 | "varchar", 2369 | "varchar", 2370 | "double", 2371 | "varchar", 2372 | "varchar", 2373 | "varchar", 2374 | "text", 2375 | "varchar", 2376 | "varchar", 2377 | "varchar", 2378 | "varchar", 2379 | "varchar", 2380 | "varchar", 2381 | "double", 2382 | "double", 2383 | "double", 2384 | "double", 2385 | "double", 2386 | "double", 2387 | "double", 2388 | "varchar", 2389 | "varchar", 2390 | "varchar", 2391 | "text", 2392 | "varchar", 2393 | "varchar", 2394 | "text", 2395 | "text", 2396 | "int", 2397 | "int", 2398 | "varchar", 2399 | "varchar", 2400 | "varchar", 2401 | "varchar", 2402 | "varchar", 2403 | "text", 2404 | "varchar", 2405 | "varchar", 2406 | "varchar", 2407 | "varchar", 2408 | "datetime", 2409 | "datetime", 2410 | "double", 2411 | "tinyint", 2412 | "double", 2413 | "varchar", 2414 | "varchar", 2415 | "varchar", 2416 | "text", 2417 | "varchar", 2418 | "varchar", 2419 | "varchar", 2420 | "varchar", 2421 | "varchar", 2422 | "varchar", 2423 | "varchar", 2424 | "varchar", 2425 | "tinyint", 2426 | "varchar", 2427 | "varchar", 2428 | "varchar", 2429 | "double", 2430 | "datetime", 2431 | "datetime", 2432 | "varchar", 2433 | "varchar", 2434 | "varchar", 2435 | "varchar", 2436 | "int", 2437 | "int", 2438 | "varchar", 2439 | "int", 2440 | "int", 2441 | "double", 2442 | "int", 2443 | "double", 2444 | "datetime", 2445 | "varchar", 2446 | "varchar", 2447 | "varchar", 2448 | "varchar", 2449 | "varchar", 2450 | "date", 2451 | "varchar", 2452 | "varchar", 2453 | "varchar", 2454 | "double", 2455 | "double", 2456 | "varchar" 2457 | ], 2458 | "foreign_keys": {}, 2459 | "primary_keys": {}, 2460 | "table_names": [ 2461 | "airline", 2462 | "airline_service", 2463 | "airport", 2464 | "baggage_policy", 2465 | "cabin_class", 2466 | "cancellation_policy", 2467 | "city", 2468 | "country", 2469 | "direct_flight", 2470 | "flight", 2471 | "flight_package", 2472 | "non_direct_flight", 2473 | "package_service", 2474 | "seat", 2475 | "special_meal_type", 2476 | "stopping", 2477 | "ticke", 2478 | "trip", 2479 | "user", 2480 | "wifi_onboard_service" 2481 | ], 2482 | "table_names_original": [ 2483 | "airline", 2484 | "airline_service", 2485 | "airport", 2486 | "baggage_policy", 2487 | "cabin_class", 2488 | "cancellation_policy", 2489 | "city", 2490 | "country", 2491 | "direct_flight", 2492 | "flight", 2493 | "flight_package", 2494 | "non_direct_flight", 2495 | "package_service", 2496 | "seat", 2497 | "special_meal_type", 2498 | "stopping", 2499 | "ticke", 2500 | "trip", 2501 | "user", 2502 | "wifi_onboard_service" 2503 | ] 2504 | }, 2505 | { 2506 | "db_id": "none", 2507 | "column_names": [ 2508 | [ 2509 | -1, 2510 | "*" 2511 | ] 2512 | ], 2513 | "column_names_original": [ 2514 | [ 2515 | -1, 2516 | "*" 2517 | ] 2518 | ], 2519 | "column_types": [ 2520 | "text" 2521 | ], 2522 | "foreign_keys": {}, 2523 | "primary_keys": {}, 2524 | "table_names": [], 2525 | "table_names_original": [] 2526 | }, 2527 | { 2528 | "db_id": "laptop_array", 2529 | "column_names": [ 2530 | [ 2531 | -1, 2532 | "*" 2533 | ], 2534 | [ 2535 | 0, 2536 | "customer_id" 2537 | ], 2538 | [ 2539 | 0, 2540 | "firstname" 2541 | ], 2542 | [ 2543 | 0, 2544 | "lastname" 2545 | ], 2546 | [ 2547 | 0, 2548 | "city" 2549 | ], 2550 | [ 2551 | 0, 2552 | "address" 2553 | ], 2554 | [ 2555 | 0, 2556 | "email" 2557 | ], 2558 | [ 2559 | 0, 2560 | "phone_numbers" 2561 | ], 2562 | [ 2563 | 1, 2564 | "model" 2565 | ], 2566 | [ 2567 | 1, 2568 | "speed" 2569 | ], 2570 | [ 2571 | 1, 2572 | 
"ram" 2573 | ], 2574 | [ 2575 | 1, 2576 | "hd" 2577 | ], 2578 | [ 2579 | 1, 2580 | "screen" 2581 | ], 2582 | [ 2583 | 1, 2584 | "price" 2585 | ], 2586 | [ 2587 | 2, 2588 | "model" 2589 | ], 2590 | [ 2591 | 2, 2592 | "speed" 2593 | ], 2594 | [ 2595 | 2, 2596 | "ram" 2597 | ], 2598 | [ 2599 | 2, 2600 | "hd" 2601 | ], 2602 | [ 2603 | 2, 2604 | "price" 2605 | ], 2606 | [ 2607 | 3, 2608 | "model" 2609 | ], 2610 | [ 2611 | 3, 2612 | "color" 2613 | ], 2614 | [ 2615 | 3, 2616 | "type" 2617 | ], 2618 | [ 2619 | 3, 2620 | "price" 2621 | ], 2622 | [ 2623 | 4, 2624 | "maker" 2625 | ], 2626 | [ 2627 | 4, 2628 | "model" 2629 | ], 2630 | [ 2631 | 4, 2632 | "type" 2633 | ], 2634 | [ 2635 | 5, 2636 | "customer_id" 2637 | ], 2638 | [ 2639 | 5, 2640 | "model" 2641 | ], 2642 | [ 2643 | 5, 2644 | "quantity" 2645 | ], 2646 | [ 2647 | 5, 2648 | "day" 2649 | ], 2650 | [ 2651 | 5, 2652 | "paid" 2653 | ], 2654 | [ 2655 | 5, 2656 | "type_of_payment" 2657 | ] 2658 | ], 2659 | "column_names_original": [ 2660 | [ 2661 | -1, 2662 | "*" 2663 | ], 2664 | [ 2665 | 0, 2666 | "customer_id" 2667 | ], 2668 | [ 2669 | 0, 2670 | "firstname" 2671 | ], 2672 | [ 2673 | 0, 2674 | "lastname" 2675 | ], 2676 | [ 2677 | 0, 2678 | "city" 2679 | ], 2680 | [ 2681 | 0, 2682 | "address" 2683 | ], 2684 | [ 2685 | 0, 2686 | "email" 2687 | ], 2688 | [ 2689 | 0, 2690 | "phone_number" 2691 | ], 2692 | [ 2693 | 1, 2694 | "model" 2695 | ], 2696 | [ 2697 | 1, 2698 | "speed" 2699 | ], 2700 | [ 2701 | 1, 2702 | "ram" 2703 | ], 2704 | [ 2705 | 1, 2706 | "hd" 2707 | ], 2708 | [ 2709 | 1, 2710 | "screen" 2711 | ], 2712 | [ 2713 | 1, 2714 | "price" 2715 | ], 2716 | [ 2717 | 2, 2718 | "model" 2719 | ], 2720 | [ 2721 | 2, 2722 | "speed" 2723 | ], 2724 | [ 2725 | 2, 2726 | "ram" 2727 | ], 2728 | [ 2729 | 2, 2730 | "hd" 2731 | ], 2732 | [ 2733 | 2, 2734 | "price" 2735 | ], 2736 | [ 2737 | 3, 2738 | "model" 2739 | ], 2740 | [ 2741 | 3, 2742 | "color" 2743 | ], 2744 | [ 2745 | 3, 2746 | "type" 2747 | ], 2748 | [ 2749 | 3, 2750 | "price" 2751 | ], 2752 | [ 2753 | 4, 2754 | "maker" 2755 | ], 2756 | [ 2757 | 4, 2758 | "model" 2759 | ], 2760 | [ 2761 | 4, 2762 | "type" 2763 | ], 2764 | [ 2765 | 5, 2766 | "customer_id" 2767 | ], 2768 | [ 2769 | 5, 2770 | "model" 2771 | ], 2772 | [ 2773 | 5, 2774 | "quantity" 2775 | ], 2776 | [ 2777 | 5, 2778 | "day" 2779 | ], 2780 | [ 2781 | 5, 2782 | "paid" 2783 | ], 2784 | [ 2785 | 5, 2786 | "type_of_payment" 2787 | ] 2788 | ], 2789 | "column_types": [ 2790 | "text", 2791 | "char", 2792 | "varchar", 2793 | "varchar", 2794 | "varchar", 2795 | "varchar", 2796 | "varchar", 2797 | "array", 2798 | "char", 2799 | "double", 2800 | "int", 2801 | "int", 2802 | "double", 2803 | "double", 2804 | "char", 2805 | "double", 2806 | "int", 2807 | "int", 2808 | "double", 2809 | "char", 2810 | "varchar", 2811 | "varchar", 2812 | "double", 2813 | "char", 2814 | "char", 2815 | "varchar", 2816 | "char", 2817 | "char", 2818 | "int", 2819 | "date", 2820 | "double", 2821 | "varchar" 2822 | ], 2823 | "foreign_keys": {}, 2824 | "primary_keys": {}, 2825 | "table_names": [ 2826 | "customers", 2827 | "laptops", 2828 | "pcs", 2829 | "printers", 2830 | "products", 2831 | "sales" 2832 | ], 2833 | "table_names_original": [ 2834 | "customers", 2835 | "laptops", 2836 | "pcs", 2837 | "printers", 2838 | "products", 2839 | "sales" 2840 | ] 2841 | }, 2842 | { 2843 | "db_id": "laptop_struct", 2844 | "column_names": [ 2845 | [ 2846 | -1, 2847 | "*" 2848 | ], 2849 | [ 2850 | 0, 2851 | "person" 2852 | ], 2853 | [ 2854 | 1, 2855 | "customer_id" 2856 | ], 2857 | [ 2858 | 1, 
2859 | "firstname" 2860 | ], 2861 | [ 2862 | 1, 2863 | "lastname" 2864 | ], 2865 | [ 2866 | 1, 2867 | "city" 2868 | ], 2869 | [ 2870 | 1, 2871 | "address" 2872 | ], 2873 | [ 2874 | 1, 2875 | "email" 2876 | ], 2877 | [ 2878 | 2, 2879 | "model" 2880 | ], 2881 | [ 2882 | 2, 2883 | "speed" 2884 | ], 2885 | [ 2886 | 2, 2887 | "ram" 2888 | ], 2889 | [ 2890 | 2, 2891 | "hd" 2892 | ], 2893 | [ 2894 | 2, 2895 | "screen" 2896 | ], 2897 | [ 2898 | 2, 2899 | "price" 2900 | ], 2901 | [ 2902 | 3, 2903 | "model" 2904 | ], 2905 | [ 2906 | 3, 2907 | "speed" 2908 | ], 2909 | [ 2910 | 3, 2911 | "ram" 2912 | ], 2913 | [ 2914 | 3, 2915 | "hd" 2916 | ], 2917 | [ 2918 | 3, 2919 | "price" 2920 | ], 2921 | [ 2922 | 4, 2923 | "model" 2924 | ], 2925 | [ 2926 | 4, 2927 | "color" 2928 | ], 2929 | [ 2930 | 4, 2931 | "type" 2932 | ], 2933 | [ 2934 | 4, 2935 | "price" 2936 | ], 2937 | [ 2938 | 5, 2939 | "maker" 2940 | ], 2941 | [ 2942 | 5, 2943 | "model" 2944 | ], 2945 | [ 2946 | 5, 2947 | "type" 2948 | ], 2949 | [ 2950 | 6, 2951 | "customer_id" 2952 | ], 2953 | [ 2954 | 6, 2955 | "model" 2956 | ], 2957 | [ 2958 | 6, 2959 | "quantity" 2960 | ], 2961 | [ 2962 | 6, 2963 | "day" 2964 | ], 2965 | [ 2966 | 6, 2967 | "paid" 2968 | ], 2969 | [ 2970 | 6, 2971 | "type_of_payment" 2972 | ] 2973 | ], 2974 | "column_names_original": [ 2975 | [ 2976 | -1, 2977 | "*" 2978 | ], 2979 | [ 2980 | 0, 2981 | "person" 2982 | ], 2983 | [ 2984 | 1, 2985 | "customer_id" 2986 | ], 2987 | [ 2988 | 1, 2989 | "firstname" 2990 | ], 2991 | [ 2992 | 1, 2993 | "lastname" 2994 | ], 2995 | [ 2996 | 1, 2997 | "city" 2998 | ], 2999 | [ 3000 | 1, 3001 | "address" 3002 | ], 3003 | [ 3004 | 1, 3005 | "email" 3006 | ], 3007 | [ 3008 | 2, 3009 | "model" 3010 | ], 3011 | [ 3012 | 2, 3013 | "speed" 3014 | ], 3015 | [ 3016 | 2, 3017 | "ram" 3018 | ], 3019 | [ 3020 | 2, 3021 | "hd" 3022 | ], 3023 | [ 3024 | 2, 3025 | "screen" 3026 | ], 3027 | [ 3028 | 2, 3029 | "price" 3030 | ], 3031 | [ 3032 | 3, 3033 | "model" 3034 | ], 3035 | [ 3036 | 3, 3037 | "speed" 3038 | ], 3039 | [ 3040 | 3, 3041 | "ram" 3042 | ], 3043 | [ 3044 | 3, 3045 | "hd" 3046 | ], 3047 | [ 3048 | 3, 3049 | "price" 3050 | ], 3051 | [ 3052 | 4, 3053 | "model" 3054 | ], 3055 | [ 3056 | 4, 3057 | "color" 3058 | ], 3059 | [ 3060 | 4, 3061 | "type" 3062 | ], 3063 | [ 3064 | 4, 3065 | "price" 3066 | ], 3067 | [ 3068 | 5, 3069 | "maker" 3070 | ], 3071 | [ 3072 | 5, 3073 | "model" 3074 | ], 3075 | [ 3076 | 5, 3077 | "type" 3078 | ], 3079 | [ 3080 | 6, 3081 | "customer_id" 3082 | ], 3083 | [ 3084 | 6, 3085 | "model" 3086 | ], 3087 | [ 3088 | 6, 3089 | "quantity" 3090 | ], 3091 | [ 3092 | 6, 3093 | "day" 3094 | ], 3095 | [ 3096 | 6, 3097 | "paid" 3098 | ], 3099 | [ 3100 | 6, 3101 | "type_of_payment" 3102 | ] 3103 | ], 3104 | "column_types": [ 3105 | "text", 3106 | "struct(id int, name: text)", 3107 | "char", 3108 | "varchar", 3109 | "varchar", 3110 | "varchar", 3111 | "varchar", 3112 | "varchar", 3113 | "char", 3114 | "double", 3115 | "int", 3116 | "int", 3117 | "double", 3118 | "double", 3119 | "char", 3120 | "double", 3121 | "int", 3122 | "int", 3123 | "double", 3124 | "char", 3125 | "varchar", 3126 | "varchar", 3127 | "double", 3128 | "char", 3129 | "char", 3130 | "varchar", 3131 | "char", 3132 | "char", 3133 | "int", 3134 | "date", 3135 | "double", 3136 | "varchar" 3137 | ], 3138 | "foreign_keys": {}, 3139 | "primary_keys": {}, 3140 | "table_names": [ 3141 | "test", 3142 | "customers", 3143 | "laptops", 3144 | "pcs", 3145 | "printers", 3146 | "products", 3147 | "sales" 3148 | ], 3149 | 
"table_names_original": [ 3150 | "test", 3151 | "customers", 3152 | "laptops", 3153 | "pcs", 3154 | "printers", 3155 | "products", 3156 | "sales" 3157 | ] 3158 | }, 3159 | { 3160 | "db_id": "laptop_json", 3161 | "column_names": [ 3162 | [ 3163 | -1, 3164 | "*" 3165 | ], 3166 | [ 3167 | 0, 3168 | "customer_id" 3169 | ], 3170 | [ 3171 | 0, 3172 | "firstname" 3173 | ], 3174 | [ 3175 | 0, 3176 | "lastname" 3177 | ], 3178 | [ 3179 | 0, 3180 | "city" 3181 | ], 3182 | [ 3183 | 0, 3184 | "address" 3185 | ], 3186 | [ 3187 | 0, 3188 | "email" 3189 | ], 3190 | [ 3191 | 1, 3192 | "model" 3193 | ], 3194 | [ 3195 | 1, 3196 | "speed" 3197 | ], 3198 | [ 3199 | 1, 3200 | "ram" 3201 | ], 3202 | [ 3203 | 1, 3204 | "hd" 3205 | ], 3206 | [ 3207 | 1, 3208 | "screen" 3209 | ], 3210 | [ 3211 | 1, 3212 | "price" 3213 | ], 3214 | [ 3215 | 2, 3216 | "model" 3217 | ], 3218 | [ 3219 | 2, 3220 | "speed" 3221 | ], 3222 | [ 3223 | 2, 3224 | "ram" 3225 | ], 3226 | [ 3227 | 2, 3228 | "hd" 3229 | ], 3230 | [ 3231 | 2, 3232 | "price" 3233 | ], 3234 | [ 3235 | 3, 3236 | "model" 3237 | ], 3238 | [ 3239 | 3, 3240 | "color" 3241 | ], 3242 | [ 3243 | 3, 3244 | "type" 3245 | ], 3246 | [ 3247 | 3, 3248 | "price" 3249 | ], 3250 | [ 3251 | 4, 3252 | "maker" 3253 | ], 3254 | [ 3255 | 4, 3256 | "model" 3257 | ], 3258 | [ 3259 | 4, 3260 | "type" 3261 | ], 3262 | [ 3263 | 5, 3264 | "customer_id" 3265 | ], 3266 | [ 3267 | 5, 3268 | "model" 3269 | ], 3270 | [ 3271 | 5, 3272 | "quantity" 3273 | ], 3274 | [ 3275 | 5, 3276 | "day" 3277 | ], 3278 | [ 3279 | 5, 3280 | "paid" 3281 | ], 3282 | [ 3283 | 5, 3284 | "type_of_payment" 3285 | ] 3286 | ], 3287 | "column_names_original": [ 3288 | [ 3289 | -1, 3290 | "*" 3291 | ], 3292 | [ 3293 | 0, 3294 | "customer_id" 3295 | ], 3296 | [ 3297 | 0, 3298 | "firstname" 3299 | ], 3300 | [ 3301 | 0, 3302 | "lastname" 3303 | ], 3304 | [ 3305 | 0, 3306 | "city" 3307 | ], 3308 | [ 3309 | 0, 3310 | "address" 3311 | ], 3312 | [ 3313 | 0, 3314 | "email" 3315 | ], 3316 | [ 3317 | 1, 3318 | "model" 3319 | ], 3320 | [ 3321 | 1, 3322 | "speed" 3323 | ], 3324 | [ 3325 | 1, 3326 | "ram" 3327 | ], 3328 | [ 3329 | 1, 3330 | "hd" 3331 | ], 3332 | [ 3333 | 1, 3334 | "screen" 3335 | ], 3336 | [ 3337 | 1, 3338 | "price" 3339 | ], 3340 | [ 3341 | 2, 3342 | "model" 3343 | ], 3344 | [ 3345 | 2, 3346 | "speed" 3347 | ], 3348 | [ 3349 | 2, 3350 | "ram" 3351 | ], 3352 | [ 3353 | 2, 3354 | "hd" 3355 | ], 3356 | [ 3357 | 2, 3358 | "price" 3359 | ], 3360 | [ 3361 | 3, 3362 | "model" 3363 | ], 3364 | [ 3365 | 3, 3366 | "color" 3367 | ], 3368 | [ 3369 | 3, 3370 | "type" 3371 | ], 3372 | [ 3373 | 3, 3374 | "price" 3375 | ], 3376 | [ 3377 | 4, 3378 | "maker" 3379 | ], 3380 | [ 3381 | 4, 3382 | "model" 3383 | ], 3384 | [ 3385 | 4, 3386 | "type" 3387 | ], 3388 | [ 3389 | 5, 3390 | "customer_id" 3391 | ], 3392 | [ 3393 | 5, 3394 | "model" 3395 | ], 3396 | [ 3397 | 5, 3398 | "quantity" 3399 | ], 3400 | [ 3401 | 5, 3402 | "day" 3403 | ], 3404 | [ 3405 | 5, 3406 | "paid" 3407 | ], 3408 | [ 3409 | 5, 3410 | "type_of_payment" 3411 | ] 3412 | ], 3413 | "column_types": [ 3414 | "text", 3415 | "char", 3416 | "varchar", 3417 | "varchar", 3418 | "varchar", 3419 | "varchar", 3420 | "json", 3421 | "char", 3422 | "double", 3423 | "int", 3424 | "int", 3425 | "double", 3426 | "double", 3427 | "char", 3428 | "double", 3429 | "int", 3430 | "int", 3431 | "double", 3432 | "char", 3433 | "varchar", 3434 | "varchar", 3435 | "double", 3436 | "char", 3437 | "char", 3438 | "varchar", 3439 | "char", 3440 | "char", 3441 | "int", 3442 | "date", 3443 | "double", 3444 | 
"varchar" 3445 | ], 3446 | "foreign_keys": {}, 3447 | "primary_keys": {}, 3448 | "table_names": [ 3449 | "customers", 3450 | "laptops", 3451 | "pcs", 3452 | "printers", 3453 | "products", 3454 | "sales" 3455 | ], 3456 | "table_names_original": [ 3457 | "customers", 3458 | "laptops", 3459 | "pcs", 3460 | "printers", 3461 | "products", 3462 | "sales" 3463 | ] 3464 | } 3465 | ] -------------------------------------------------------------------------------- /eval/data_utils.py: -------------------------------------------------------------------------------- 1 | """Training data prep utils.""" 2 | import json 3 | import re 4 | from collections import defaultdict 5 | from schema import ForeignKey, Table, TableColumn 6 | 7 | 8 | def read_tables_json( 9 | schema_file: str, 10 | lowercase: bool = False, 11 | ) -> dict[str, dict[str, Table]]: 12 | """Read tables json.""" 13 | data = json.load(open(schema_file)) 14 | db_to_tables = {} 15 | for db in data: 16 | db_name = db["db_id"] 17 | table_names = db["table_names_original"] 18 | db["column_names_original"] = [ 19 | [x[0], x[1]] for x in db["column_names_original"] 20 | ] 21 | db["column_types"] = db["column_types"] 22 | if lowercase: 23 | table_names = [tn.lower() for tn in table_names] 24 | pks = db["primary_keys"] 25 | fks = db["foreign_keys"] 26 | tables = defaultdict(list) 27 | tables_pks = defaultdict(list) 28 | tables_fks = defaultdict(list) 29 | for idx, ((ti, col_name), col_type) in enumerate( 30 | zip(db["column_names_original"], db["column_types"]) 31 | ): 32 | if ti == -1: 33 | continue 34 | if lowercase: 35 | col_name = col_name.lower() 36 | col_type = col_type.lower() 37 | if idx in pks: 38 | tables_pks[table_names[ti]].append( 39 | TableColumn(name=col_name, dtype=col_type) 40 | ) 41 | for fk in fks: 42 | if idx == fk[0]: 43 | other_column = db["column_names_original"][fk[1]] 44 | other_column_type = db["column_types"][fk[1]] 45 | other_table = table_names[other_column[0]] 46 | tables_fks[table_names[ti]].append( 47 | ForeignKey( 48 | column=TableColumn(name=col_name, dtype=col_type), 49 | references_name=other_table, 50 | references_column=TableColumn( 51 | name=other_column[1], dtype=other_column_type 52 | ), 53 | ) 54 | ) 55 | tables[table_names[ti]].append(TableColumn(name=col_name, dtype=col_type)) 56 | db_to_tables[db_name] = { 57 | table_name: Table( 58 | name=table_name, 59 | columns=tables[table_name], 60 | pks=tables_pks[table_name], 61 | fks=tables_fks[table_name], 62 | examples=None, 63 | ) 64 | for table_name in tables 65 | } 66 | return db_to_tables 67 | 68 | 69 | def clean_str(target: str) -> str: 70 | """Clean string for question.""" 71 | if not target: 72 | return target 73 | 74 | target = re.sub(r"[^\x00-\x7f]", r" ", target) 75 | line = re.sub(r"''", r" ", target) 76 | line = re.sub(r"``", r" ", line) 77 | line = re.sub(r"\"", r"'", line) 78 | line = re.sub(r"[\t ]+", " ", line) 79 | return line.strip() 80 | -------------------------------------------------------------------------------- /eval/doc_retriever.py: -------------------------------------------------------------------------------- 1 | """Retrieve documentation for a given query.""" 2 | 3 | from pathlib import Path 4 | from typing import Any 5 | from rich.console import Console 6 | from tqdm import tqdm 7 | import numpy as np 8 | from manifest import Manifest 9 | from langchain.text_splitter import MarkdownHeaderTextSplitter 10 | from langchain.text_splitter import RecursiveCharacterTextSplitter 11 | 12 | console = Console(soft_wrap=True) 13 | 14 | try: 
15 | EMBEDDING_MODEL = Manifest( 16 | client_name="openaiembedding", 17 | cache_name="sqlite", 18 | cache_connection=".manifest.sqlite", 19 | ) 20 | except Exception as e: 21 | console.print(e) 22 | console.print( 23 | "Failed to load embedding model. Likely OPENAI API key is not set. Please set to run document retrieval.", 24 | style="bold red", 25 | ) 26 | 27 | 28 | def load_documentation(path: Path) -> dict[str, str]: 29 | """Load documentation from path.""" 30 | content = {} 31 | for file in path.glob("**/*.md"): 32 | with open(file, "r") as f: 33 | data = f.read() 34 | key = str(file).replace(str(path), "") 35 | content[key] = data 36 | return content 37 | 38 | 39 | def split_documents(content: dict[str, str]) -> dict[str, Any]: 40 | """Split documents into chunks.""" 41 | md_splitted_docs = [] 42 | markdown_splitter = MarkdownHeaderTextSplitter( 43 | headers_to_split_on=[("#", "Header 1"), ("##", "Header 2"), ("###", "Header 3")] 44 | ) 45 | text_splitter = RecursiveCharacterTextSplitter( 46 | separators=["\n"], chunk_size=500, chunk_overlap=50, length_function=len 47 | ) 48 | 49 | for file, raw_doc in content.items(): 50 | splitted_text = markdown_splitter.split_text(raw_doc) 51 | for t in splitted_text: 52 | t.metadata["source"] = file 53 | md_splitted_docs.extend(splitted_text) 54 | 55 | docs = text_splitter.split_documents(md_splitted_docs) 56 | docs_as_dict = [doc.dict() for doc in docs] 57 | return docs_as_dict 58 | 59 | 60 | def get_embeddings(text: str) -> np.ndarray: 61 | """Get embeddings.""" 62 | return np.array(EMBEDDING_MODEL.run(text)) 63 | 64 | 65 | def embed_documents( 66 | chunked_docs: dict[str, Any], key: str = "page_content" 67 | ) -> tuple[dict[str, Any], np.ndarray]: 68 | """Embed documents.""" 69 | all_embeddings = [] 70 | for doc in tqdm(chunked_docs): 71 | emb = get_embeddings(doc[key]) 72 | doc["embedding"] = emb 73 | all_embeddings.append(doc["embedding"]) 74 | full_embedding_mat = np.vstack(all_embeddings) 75 | return chunked_docs, full_embedding_mat 76 | 77 | 78 | def query_docs( 79 | query: str, 80 | docs: dict[str, Any], 81 | embedding_mat: np.ndarray, 82 | top_n: int = 10, 83 | key: str = "page_content", 84 | ) -> tuple[list[int], list[str]]: 85 | """Query documents.""" 86 | query_embedding = get_embeddings(query) 87 | scores = embedding_mat.dot(query_embedding) 88 | sorted_indices = np.argsort(scores)[::-1] 89 | top_n_indices = sorted_indices[:top_n] 90 | top_n_indices_rev = top_n_indices[::-1] 91 | returned_docs = [] 92 | for i in top_n_indices_rev: 93 | returned_docs.append(docs[i][key]) 94 | return top_n_indices_rev.tolist(), returned_docs 95 | -------------------------------------------------------------------------------- /eval/evaluate.py: -------------------------------------------------------------------------------- 1 | """Evaluate text2sql spider model predictions.""" 2 | import json 3 | import os 4 | import re 5 | import signal 6 | import sys 7 | import traceback 8 | from pathlib import Path 9 | from typing import Any 10 | 11 | import click 12 | import pandas as pd 13 | from rich.console import Console 14 | from tqdm.auto import tqdm 15 | 16 | sys.path.append(os.path.join(os.path.dirname(__file__), ".")) 17 | # from metrics.spider import evaluation as spider_evaluation # type: ignore # noqa: E402 18 | from metrics.test_suite_sql_eval import ( # type: ignore # noqa: E402 19 | evaluation as test_suite_evaluation, 20 | ) 21 | from data_utils import read_tables_json # type: ignore # noqa: E402 22 | from metric_utils import ( # type: ignore # 
noqa: E402 23 | correct_casing, 24 | edit_distance, 25 | ) 26 | 27 | console = Console(soft_wrap=True) 28 | 29 | LEVELS = ["easy", "medium", "hard", "duckdb", "ddl", "all"] 30 | PARTIAL_TYPES = [ 31 | "select", 32 | "select(no AGG)", 33 | "where", 34 | "where(no OP)", 35 | "group(no Having)", 36 | "group", 37 | "order", 38 | "and/or", 39 | "IUEN", 40 | "keywords", 41 | ] 42 | TIMEOUT_SECONDS = 30 43 | 44 | 45 | def timeout_handler(signum: int, frame: Any) -> None: 46 | raise TimeoutError("Function execution timed out.") 47 | 48 | 49 | def print_scores(scores: dict, model_name: str, metric_type: str = "exec") -> None: 50 | """Print scores.""" 51 | 52 | def print_formated_s( 53 | row_name: str, l: list[str], element_format: str = "{}", sep: str = "\t" 54 | ) -> None: 55 | template = "{}" + sep + sep.join([element_format] * len(l)) 56 | console.print(template.format(row_name, *l)) 57 | 58 | # Add empty scores for each level if not present 59 | for level in LEVELS: 60 | if level not in scores: 61 | scores[level] = {} 62 | scores[level]["count"] = 0 63 | scores[level]["exec"] = 0 64 | scores[level]["exact"] = 0 65 | 66 | print_formated_s("", LEVELS) 67 | counts = [scores[level]["count"] for level in LEVELS] 68 | print_formated_s("count", counts) 69 | console.print(f">====================== {model_name} =====================") 70 | if metric_type == "exec": 71 | console.print( 72 | ">===================== EXECUTION ACCURACY =====================" 73 | ) 74 | exec_scores = [scores[level]["exec"] for level in LEVELS] 75 | print_formated_s("execution", exec_scores, element_format="{:.3f}") 76 | 77 | elif metric_type == "exact": 78 | console.print( 79 | "\n>====================== EXACT MATCHING ACCURACY =====================" 80 | ) 81 | exact_scores = [scores[level]["exact"] for level in LEVELS] 82 | print_formated_s("exact match", exact_scores, element_format="{:.3f}") 83 | 84 | 85 | def compute_exact_match_metric( 86 | predictions: list, 87 | references: list, 88 | gold_dbs: list, 89 | kmaps: dict, 90 | db_dir: str, 91 | categories, 92 | ) -> dict: 93 | """Compute exact match metric.""" 94 | exact_match = {} 95 | exact_match["all"] = {} 96 | exact_match["all"]["count"] = 0 97 | exact_match["all"]["exact"] = 0 98 | for prediction, reference, gold_db, category in tqdm( 99 | zip(predictions, references, gold_dbs, categories), total=len(predictions) 100 | ): 101 | if category not in exact_match: 102 | exact_match[category] = {} 103 | exact_match[category]["count"] = 0 104 | exact_match[category]["exact"] = 0 105 | exact_match["all"]["count"] += 1 106 | exact_match[category]["count"] += 1 107 | try: 108 | match = int(prediction.strip() == reference.strip()) 109 | exact_match[category]["exact"] += match 110 | exact_match["all"]["exact"] += match 111 | except Exception: 112 | pass 113 | return exact_match 114 | 115 | 116 | def compute_test_suite_metric( 117 | predictions: list, 118 | references: list, 119 | gold_dbs: list, 120 | setup_sqls: list, 121 | validate_sqls: list, 122 | kmaps: dict, 123 | db_dir: str, 124 | categories: list[str] = None, 125 | ) -> tuple[Any, list[int | None]]: 126 | """Compute test suite execution metric.""" 127 | evaluator = test_suite_evaluation.Evaluator( 128 | db_dir=db_dir, 129 | kmaps=kmaps, 130 | etype="exec", 131 | plug_value=False, 132 | keep_distinct=False, 133 | progress_bar_for_each_datapoint=False, 134 | ) 135 | # Only used for Sparc/CoSQL 136 | turn_scores: dict[str, list] = {"exec": [], "exact": []} 137 | by_row_metrics: list[int | None] = [] 138 | for prediction, 
reference, gold_db, setup_sql, validate_sql, category in tqdm( 139 | zip(predictions, references, gold_dbs, setup_sqls, validate_sqls, categories), 140 | total=len(predictions), 141 | ): 142 | turn_idx = 0 143 | # skip final utterance-query pairs 144 | if turn_idx < 0: 145 | continue 146 | 147 | # Register the timeout handler function 148 | signal.signal(signal.SIGALRM, timeout_handler) 149 | signal.alarm(TIMEOUT_SECONDS) 150 | 151 | try: 152 | ex_metrics = evaluator.evaluate_one( 153 | gold_db, 154 | reference, 155 | prediction, 156 | setup_sql, 157 | validate_sql, 158 | turn_scores, 159 | idx=turn_idx, 160 | category=category, 161 | ) 162 | signal.alarm(0) 163 | 164 | by_row_metrics.append(int(ex_metrics["exec"])) 165 | except Exception as e: 166 | signal.alarm(0) 167 | console.print(e) 168 | by_row_metrics.append(None) 169 | evaluator.finalize() 170 | return evaluator.scores, by_row_metrics 171 | 172 | 173 | def compute_metrics( 174 | gold_sqls: list[str], 175 | pred_sqls: list[str], 176 | gold_dbs: list[str], 177 | setup_sqls: list[str], 178 | validate_sqls: list[str], 179 | kmaps: dict, 180 | db_schemas: dict, 181 | database_dir: str, 182 | lowercase_schema_match: bool, 183 | model_name: str, 184 | categories: list[str] = None, 185 | ) -> dict[str, str]: 186 | """Compute all metrics for data slice.""" 187 | if len(gold_sqls) != len(pred_sqls): 188 | raise ValueError( 189 | f"Gold {len(gold_sqls)} and pred {len(pred_sqls)} have different number of lines!" 190 | ) 191 | all_metrics: dict[str, Any] = {} 192 | 193 | # Execution Accuracy 194 | metrics, by_row_metrics = compute_test_suite_metric( 195 | pred_sqls, 196 | gold_sqls, 197 | gold_dbs, 198 | setup_sqls, 199 | validate_sqls, 200 | kmaps, 201 | database_dir, 202 | categories, 203 | ) 204 | all_metrics["exec"] = metrics 205 | all_metrics["by_row_exec"] = by_row_metrics 206 | print_scores(metrics, model_name, "exec") 207 | 208 | # Exact Match Accuracy 209 | metrics = compute_exact_match_metric( 210 | pred_sqls, gold_sqls, gold_dbs, kmaps, database_dir, categories 211 | ) 212 | all_metrics["exact"] = metrics 213 | print_scores(metrics, model_name, "exact") 214 | 215 | # Equality Accuracy 216 | per_row_match = [ 217 | int(gold.lower() == pred.lower()) for gold, pred in zip(gold_sqls, pred_sqls) 218 | ] 219 | all_metrics["equality"] = {"equality": sum(per_row_match) / len(gold_sqls)} 220 | all_metrics["by_row_equality"] = per_row_match 221 | 222 | # Edit Distance 223 | per_row_edit_dist = [ 224 | edit_distance(gold, pred) for gold, pred in zip(gold_sqls, pred_sqls) 225 | ] 226 | edit_dist = sum(per_row_edit_dist) / len(gold_sqls) 227 | all_metrics["edit_distance"] = {"edit_distance": edit_dist} 228 | all_metrics["by_row_edit_distance"] = per_row_edit_dist 229 | 230 | return all_metrics 231 | 232 | 233 | def get_to_print(metrics: dict, key: str, model_name: str, num_rows: int) -> dict: 234 | """Get pretty print dictionary of metrics.""" 235 | return { 236 | "slice": key, 237 | "model": model_name, 238 | "support": num_rows, 239 | "exec": f"{metrics[key]['exec']['all']['exec']:.3f}", 240 | "exact": f"{metrics[key]['exact']['all']['exact']:.3f}", 241 | "equality": f"{metrics[key]['equality']['equality']:.3f}", 242 | "edit_distance": f"{metrics[key]['edit_distance']['edit_distance']:.3f}", 243 | } 244 | 245 | 246 | @click.group() 247 | def cli() -> None: 248 | """Entrypoint.""" 249 | pass 250 | 251 | 252 | @cli.command() 253 | @click.option("--gold", type=str, required=True) 254 | @click.option("--pred", type=str, required=True) 255 | @click.option("--tables", 
type=str, required=True) 256 | @click.option("--db", type=str, default="") 257 | @click.option("--slice-attribute", type=str, default=None) 258 | @click.option("--output-dir", type=str, default="") 259 | @click.option("--output-filename", type=str, default="") 260 | @click.option( 261 | "--correct-sql-casing", type=bool, is_flag=True, default=False, required=False 262 | ) 263 | @click.option( 264 | "--lowercase-schema-match", type=bool, is_flag=True, default=False, required=False 265 | ) 266 | def evaluate( 267 | gold: str, 268 | pred: str, 269 | tables: str, 270 | db: str, 271 | slice_attribute: str, 272 | output_dir: str, 273 | output_filename: str, 274 | correct_sql_casing: bool, 275 | lowercase_schema_match: bool, 276 | ) -> None: 277 | """Evaluate SQL. 278 | 279 | Args: 280 | gold: path to gold sql file. 281 | pred: path to predicted json lines file. 282 | tables: the json path of the table metadata. 283 | db: path to database dir. 284 | slice_attribute: json attribute in gold data to slice on. 285 | output_dir: the prediction output directory 286 | output_filename: the prediction output filename 287 | correct_sql_casing: whether to correct casing of SQL keywords 288 | lowercase_schema_match: whether to lowercase schema match 289 | """ 290 | gold_path = Path(gold) 291 | pred_path = Path(pred) 292 | model_name = pred_path.stem 293 | if not output_filename: 294 | output_filename = pred_path.stem + "_eval.json" 295 | console.print(f"Saving to {Path(output_dir) / output_filename}") 296 | Path(output_dir).mkdir(parents=True, exist_ok=True) 297 | 298 | kmaps = test_suite_evaluation.build_foreign_key_map_from_json(tables) 299 | db_schemas = read_tables_json(tables) 300 | 301 | gold_sqls_dict = json.load(gold_path.open("r", encoding="utf-8")) 302 | pred_sqls_dict = [json.loads(l) for l in pred_path.open("r").readlines()] 303 | 304 | # Data validation 305 | assert len(gold_sqls_dict) == len( 306 | pred_sqls_dict 307 | ), "Sample size doesn't match between pred and gold file" 308 | 309 | # Keep track of everything 310 | full_results = [] 311 | for gold_sql, pred_sql in zip(gold_sqls_dict, pred_sqls_dict): 312 | merged_res = {**pred_sql, **gold_sql} 313 | full_results.append(merged_res) 314 | 315 | gold_sqls = [ 316 | re.sub(r"[\s\t\n]+", " ", p.get("gold", p.get("query", p.get("sql", "")))) 317 | for p in gold_sqls_dict 318 | ] 319 | setup_sqls = [re.sub(r"[\s\t\n]+", " ", p["setup_sql"]) for p in gold_sqls_dict] 320 | validate_sqls = [ 321 | re.sub(r"[\s\t\n]+", " ", p["validation_sql"]) for p in gold_sqls_dict 322 | ] 323 | gold_dbs = [p.get("db_id", p.get("db", "")) for p in gold_sqls_dict] 324 | pred_sqls = [re.sub(r"[\s\t\n]+", " ", p["pred"]) for p in pred_sqls_dict] 325 | categories = [p.get("category", "") for p in gold_sqls_dict] 326 | if correct_sql_casing: 327 | # One line to correct casing of SQL keywords using correct_casing(sql) 328 | gold_sqls = [correct_casing(sql) for sql in gold_sqls] 329 | pred_sqls = [correct_casing(sql) for sql in pred_sqls] 330 | 331 | final_metrics: dict[str, dict[str, Any]] = {} 332 | to_print = [] 333 | final_metrics["all"] = compute_metrics( 334 | gold_sqls=gold_sqls, 335 | pred_sqls=pred_sqls, 336 | gold_dbs=gold_dbs, 337 | setup_sqls=setup_sqls, 338 | validate_sqls=validate_sqls, 339 | kmaps=kmaps, 340 | db_schemas=db_schemas, 341 | database_dir=db, 342 | lowercase_schema_match=lowercase_schema_match, 343 | model_name=model_name + "(all)", 344 | categories=categories, 345 | ) 346 | 347 | for k, v in final_metrics["all"].items(): 348 | if 
k.startswith("by_row"): 349 | assert len(v) == len(gold_sqls) 350 | for dct, val in zip(full_results, v): 351 | dct[k[len("by_row_") :]] = val 352 | to_print.append(get_to_print(final_metrics, "all", model_name, len(gold_sqls))) 353 | # TODO: could be way more efficient if we subsliced the results but...whatever 354 | if slice_attribute: 355 | for unq_value in sorted(set([g[slice_attribute] for g in gold_sqls_dict])): 356 | idx_set = [ 357 | i 358 | for i, g in enumerate(gold_sqls_dict) 359 | if g[slice_attribute] == unq_value 360 | ] 361 | print(f"Processing {unq_value} with {len(idx_set)} samples") 362 | final_metrics[unq_value] = compute_metrics( 363 | gold_sqls=[gold_sqls[i] for i in idx_set], 364 | pred_sqls=[pred_sqls[i] for i in idx_set], 365 | gold_dbs=[gold_dbs[i] for i in idx_set], 366 | setup_sqls=[setup_sqls[i] for i in idx_set], 367 | validate_sqls=[validate_sqls[i] for i in idx_set], 368 | kmaps=kmaps, 369 | db_schemas=db_schemas, 370 | database_dir=db, 371 | lowercase_schema_match=lowercase_schema_match, 372 | model_name=model_name + f"({unq_value})", 373 | categories=[categories[i] for i in idx_set], 374 | ) 375 | to_print.append( 376 | get_to_print(final_metrics, unq_value, model_name, len(idx_set)) 377 | ) 378 | 379 | df = pd.DataFrame(to_print) 380 | console.print(df.to_csv(sep=",", index=False)) 381 | console.print("******") 382 | console.print(f"Saved metrics to {Path(output_dir) / output_filename}") 383 | json.dump(final_metrics, open(Path(output_dir) / output_filename, "w"), indent=4) 384 | output_filename = str(output_filename).replace("_eval.json", "_fd.jsonl") 385 | console.print(f"Saved dump to {Path(output_dir) / output_filename}") 386 | with open(Path(output_dir) / output_filename, "w") as f: 387 | for dct in full_results: 388 | f.write(json.dumps(dct) + "\n") 389 | 390 | 391 | if __name__ == "__main__": 392 | cli() 393 | -------------------------------------------------------------------------------- /eval/get_manifest.py: -------------------------------------------------------------------------------- 1 | """Manifest utils.""" 2 | from manifest import Manifest 3 | from manifest.connections.client_pool import ClientConnection 4 | 5 | 6 | def get_manifest( 7 | manifest_client: str, 8 | manifest_connection: str, 9 | manifest_engine: str, 10 | ) -> Manifest: 11 | """Get manifest engine.""" 12 | if manifest_client in {"openai", "openaichat", "openai_mock"}: 13 | manifest = Manifest( 14 | client_name=manifest_client, 15 | engine=manifest_engine, 16 | cache_name="redis", 17 | cache_connection="localhost:6411", 18 | ) 19 | elif manifest_client in {"huggingface"}: 20 | manifest = Manifest( 21 | client_pool=[ 22 | ClientConnection( 23 | client_name=manifest_client, 24 | client_connection=manifest_conn, 25 | ) 26 | for manifest_conn in manifest_connection.split(";") 27 | ], 28 | cache_name="redis", 29 | cache_connection="localhost:6411", 30 | ) 31 | else: 32 | raise ValueError(f"Unknown manifest client {manifest_client}") 33 | return manifest 34 | -------------------------------------------------------------------------------- /eval/loaders.py: -------------------------------------------------------------------------------- 1 | """Data loaders.""" 2 | import json 3 | import re 4 | import string 5 | from abc import ABC, abstractmethod 6 | 7 | from rich.console import Console 8 | from data_utils import read_tables_json 9 | from schema import Table 10 | 11 | RE_COLUMN = re.compile(r"^select (.+?) 
from") 12 | RE_CONDS = re.compile(r"where (.+?)$") 13 | RE_COND = re.compile(r"^(.+?)\s*([=><])\s*(.+?)$") 14 | 15 | translator = str.maketrans( 16 | string.punctuation, " " * len(string.punctuation) 17 | ) # map punctuation to space 18 | 19 | console = Console(soft_wrap=True) 20 | 21 | 22 | def standardize_column(col: str) -> str: 23 | """Standardize the column name to SQL compatible.""" 24 | col_name = col.replace("#", "num").replace("%", "perc") 25 | col_name = col_name.strip().lower().translate(translator) 26 | col_name = re.sub("[^0-9a-z ]", " ", col_name).strip() 27 | col_name = re.sub(" +", "_", col_name) 28 | if not col_name: 29 | console.print(f"original {col}, new {col_name}") 30 | return col_name 31 | 32 | 33 | def clean_col(col: str) -> str: 34 | """Remove table name and standardize column name.""" 35 | if "." in col and not col.endswith("."): 36 | col = col.split(".")[-1] 37 | return standardize_column(col) 38 | 39 | 40 | class Loader(ABC): 41 | """Loader abstract class.""" 42 | 43 | @classmethod 44 | @abstractmethod 45 | def load_data(cls, path: str) -> list[dict]: 46 | """Load data from path.""" 47 | 48 | @classmethod 49 | @abstractmethod 50 | def load_table_metadata(cls, path: str) -> dict[str, dict[str, Table]]: 51 | """Extract table metadata from table-metadata-path.""" 52 | 53 | @classmethod 54 | def format_output(cls, prediction: dict) -> dict: 55 | """Parse for spider format.""" 56 | return prediction 57 | 58 | 59 | class DefaultLoader(Loader): 60 | """Spider loader and writer.""" 61 | 62 | @classmethod 63 | def load_data(cls, path: str) -> list[dict]: 64 | """Load data from path.""" 65 | try: 66 | with open(path) as f: 67 | data = json.loads(f.read()) 68 | except json.decoder.JSONDecodeError: 69 | # Try with jsonl 70 | data = [json.loads(line) for line in open(path)] 71 | return data 72 | 73 | @classmethod 74 | def load_table_metadata(cls, path: str) -> dict[str, dict[str, Table]]: 75 | """Extract table metadata from table-metadata-path.""" 76 | # load the tables 77 | db_to_tables = read_tables_json(path, lowercase=True) 78 | return db_to_tables 79 | -------------------------------------------------------------------------------- /eval/metric_utils.py: -------------------------------------------------------------------------------- 1 | """Utility metrics.""" 2 | import sqlglot 3 | from rich.console import Console 4 | from sqlglot import parse_one 5 | 6 | console = Console(soft_wrap=True) 7 | 8 | 9 | def correct_casing(sql: str) -> str: 10 | """Correct casing of SQL.""" 11 | parse: sqlglot.expressions.Expression = parse_one(sql, read="sqlite") 12 | return parse.sql() 13 | 14 | 15 | def prec_recall_f1(gold: set, pred: set) -> dict[str, float]: 16 | """Compute precision, recall and F1 score.""" 17 | prec = len(gold.intersection(pred)) / len(pred) if pred else 0.0 18 | recall = len(gold.intersection(pred)) / len(gold) if gold else 0.0 19 | f1 = 2 * prec * recall / (prec + recall) if prec + recall else 0.0 20 | return {"prec": prec, "recall": recall, "f1": f1} 21 | 22 | 23 | def edit_distance(s1: str, s2: str) -> int: 24 | """Compute edit distance between two strings.""" 25 | # Make sure s1 is the shorter string 26 | if len(s1) > len(s2): 27 | s1, s2 = s2, s1 28 | 29 | distances: list[int] = list(range(len(s1) + 1)) 30 | for i2, c2 in enumerate(s2): 31 | distances_ = [i2 + 1] 32 | for i1, c1 in enumerate(s1): 33 | if c1 == c2: 34 | distances_.append(distances[i1]) 35 | else: 36 | distances_.append( 37 | 1 + min((distances[i1], distances[i1 + 1], distances_[-1])) 38 | ) 39 
| distances = distances_ 40 | return distances[-1] 41 | -------------------------------------------------------------------------------- /eval/predict.py: -------------------------------------------------------------------------------- 1 | """Run dataset on text2sql zazu experiment. 2 | 3 | See README.md for more details. 4 | """ 5 | import datetime 6 | import json 7 | import multiprocessing 8 | import random 9 | import re 10 | from pathlib import Path 11 | 12 | import click 13 | import numpy as np 14 | from constants import PROMPT_FORMATTERS 15 | from loaders import DefaultLoader 16 | from get_manifest import get_manifest 17 | from manifest import Manifest 18 | from prompt_formatters import RajkumarFormatter 19 | from rich.console import Console 20 | from schema import Table, TextToSQLModelResponse, TextToSQLParams 21 | from text_to_sql import instruction_to_sql, instruction_to_sql_list 22 | from doc_retriever import ( 23 | load_documentation, 24 | split_documents, 25 | embed_documents, 26 | query_docs, 27 | ) 28 | from tqdm import tqdm 29 | from transformers import AutoTokenizer 30 | 31 | console = Console(soft_wrap=True) 32 | 33 | 34 | def generate_sql( 35 | manifest: Manifest, 36 | text_to_sql_in: list[TextToSQLParams], 37 | retrieved_docs: list[list[str]], 38 | prompt_formatter: RajkumarFormatter, 39 | stop_tokens: list[str] | None = None, 40 | overwrite_manifest: bool = False, 41 | max_tokens: int = 300, 42 | temperature: float = 0.01, 43 | num_beams: int = 2, 44 | parallel: bool = False, 45 | ) -> list[tuple[str, TextToSQLModelResponse]]: 46 | """Call our text2sql function with manifest of our choice.""" 47 | if parallel: 48 | instruction_to_sql_resps: list[ 49 | TextToSQLModelResponse 50 | ] = instruction_to_sql_list( 51 | params=text_to_sql_in, 52 | extra_context=retrieved_docs, 53 | manifest=manifest, 54 | prompt_formatter=prompt_formatter, 55 | overwrite_manifest=overwrite_manifest, 56 | max_tokens=max_tokens, 57 | temperature=temperature, 58 | stop_sequences=stop_tokens, 59 | num_beams=num_beams, 60 | ) 61 | else: 62 | instruction_to_sql_resps = [ 63 | instruction_to_sql( 64 | params=_text_to_sql_in, 65 | extra_context=_retrieved_docs, 66 | manifest=manifest, 67 | prompt_formatter=prompt_formatter, 68 | overwrite_manifest=overwrite_manifest, 69 | max_tokens=max_tokens, 70 | temperature=temperature, 71 | stop_sequences=stop_tokens, 72 | num_beams=num_beams, 73 | ) 74 | for _retrieved_docs, _text_to_sql_in in tqdm( 75 | zip(retrieved_docs, text_to_sql_in), 76 | desc="Generating SQL", 77 | total=len(text_to_sql_in), 78 | disable=(len(text_to_sql_in) <= 1), 79 | ) 80 | ] 81 | assert len(instruction_to_sql_resps) == len(text_to_sql_in) 82 | 83 | sql_statements = [] 84 | for i in range(len(instruction_to_sql_resps)): 85 | sql_statement = instruction_to_sql_resps[i].output.strip() 86 | if "<>" in sql_statement: 87 | sql_statement = sql_statement.replace("<>", "!=") 88 | # Models are sometimes trained to predict | 89 | sql_statement = sql_statement.split("|")[-1].strip() 90 | sql_statements.append(sql_statement) 91 | return list(zip(sql_statements, instruction_to_sql_resps)) 92 | 93 | 94 | def get_text_to_sql_in( 95 | input_question: dict, db_to_tables: dict[str, dict[str, Table]] 96 | ) -> TextToSQLParams: 97 | """Format input question for text2sql function.""" 98 | question = input_question["question"] 99 | db_id = input_question.get("db_id", None) 100 | if db_id != "none": 101 | table_params = list(db_to_tables.get(db_id, {}).values()) 102 | else: 103 | table_params = [] 104 | if len(table_params) == 0:
105 | console.print(f"[red] WARNING: No tables found for {db_id} [/red]") 106 | text_to_sql_in = TextToSQLParams( 107 | instruction=question, 108 | database=db_id, 109 | tables=table_params, 110 | ) 111 | return text_to_sql_in 112 | 113 | 114 | @click.group() 115 | def cli() -> None: 116 | """Entrypoint.""" 117 | pass 118 | 119 | 120 | @cli.command() 121 | @click.argument("dataset-path") 122 | @click.argument("table-meta-path") 123 | @click.option("--output-dir", type=str, default="") 124 | @click.option("--run-name", type=str, default="") 125 | @click.option("--num-run", type=int, default=-1) 126 | @click.option("--num-print", type=int, default=20) 127 | # Format options 128 | @click.option("--prompt-format", type=str, default="spider") 129 | # Prompt options 130 | @click.option("--stop-tokens", type=str, default=[], multiple=True) 131 | @click.option("--max-tokens", type=int, default=200) 132 | @click.option("--temperature", type=float, default=0) 133 | @click.option("--num-beams", type=int, default=-1) # use whatever is in manifest 134 | @click.option("--max-context-length", type=int, default=-1) 135 | # Docs options 136 | @click.option( 137 | "--markdown-docs-path", 138 | type=click.Path( 139 | exists=True, file_okay=True, dir_okay=True, readable=True, path_type=Path 140 | ), 141 | default="eval/docs/duckdb-web/docs/archive/0.9.2/sql", 142 | ) 143 | @click.option("--num-retrieved-docs", type=int, default=0) 144 | # Manifest options 145 | @click.option("--manifest-client", type=str, default="openai") 146 | @click.option("--manifest-engine", type=str, default="gpt-4-1106-preview") 147 | @click.option("--manifest-connection", type=str, default="http://localhost:5005") 148 | @click.option("--overwrite-manifest", is_flag=True, default=False) 149 | @click.option("--parallel", is_flag=True, default=False) 150 | def predict( 151 | dataset_path: str, 152 | table_meta_path: str, 153 | output_dir: str, 154 | run_name: str, 155 | num_run: int, 156 | num_print: int, 157 | prompt_format: str, 158 | stop_tokens: list[str], 159 | max_tokens: int, 160 | temperature: float, 161 | num_beams: int, 162 | max_context_length: int, 163 | markdown_docs_path: Path, 164 | num_retrieved_docs: int, 165 | manifest_client: str, 166 | manifest_engine: str, 167 | manifest_connection: str, 168 | overwrite_manifest: bool, 169 | parallel: bool, 170 | ) -> None: 171 | """Predict SQL. 172 | 173 | Args: 174 | dataset_path: the dataset path. 175 | table_meta_path: the json path of the table metadata. 176 | database_path: the database path for sqlite. 177 | output_dir: the prediction output directory 178 | run_name: special prefix to add to filename 179 | num_run: the number of examples to run 180 | num_print: the number of examples to print 181 | prompt_format: the format of the prompt. 
E.g., "rajkumar" 182 | stop_tokens: the stop tokens to try 183 | max_tokens: the max tokens 184 | temperature: the temperature 185 | num_beams: the number of beams 186 | max_context_length: max context length for demonstration truncation (-1 means None) 187 | markdown_docs_path: path to duckdb sql docs 188 | num_retrieved_docs: number of docs to retrieve 189 | manifest_client: the manifest client 190 | manifest_engine: the manifest engine 191 | manifest_connection: the manifest connection 192 | """ 193 | multiprocessing.set_start_method("spawn", force=True) 194 | random.seed(0) 195 | np.random.seed(0) 196 | locals_dict = locals() 197 | locals_dict["markdown_docs_path"] = str(markdown_docs_path) 198 | console.print(json.dumps(locals_dict, indent=2)) 199 | 200 | data_formatter = DefaultLoader() 201 | 202 | if prompt_format not in PROMPT_FORMATTERS: 203 | raise ValueError(f"Unknown prompt format {prompt_format}") 204 | prompt_formatter = PROMPT_FORMATTERS[prompt_format]() 205 | 206 | # load manifest 207 | manifest = get_manifest( 208 | manifest_client=manifest_client, 209 | manifest_connection=manifest_connection, 210 | manifest_engine=manifest_engine, 211 | ) 212 | manifest_params = manifest.client_pool.get_current_client().get_model_params() 213 | console.print(f"Running with {manifest_params} manifest.") 214 | model_name = manifest_params.get("engine", manifest_params["model_name"]) 215 | 216 | if "openai" in manifest_client: 217 | tokenizer = AutoTokenizer.from_pretrained("gpt2", trust_remote_code=True) 218 | else: 219 | tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) 220 | 221 | if stop_tokens: 222 | stop_tokens = [st.strip("'") for st in stop_tokens] 223 | console.print(f"Stop tokens: {stop_tokens}") 224 | 225 | # Get output filename 226 | full_dataset_path = Path(dataset_path) 227 | # Get todays date 228 | date_today = datetime.datetime.now().strftime("%y-%m-%d") 229 | if run_name: 230 | run_name = f"{run_name}_" 231 | suffix = f"{run_name}{full_dataset_path.stem}_{date_today}.json" # noqa: E501 232 | prefix = f"{prompt_format}_{num_retrieved_docs}docs" 233 | if manifest_client in {"openai", "openaichat", "openaiazure"}: 234 | middleix = manifest_engine 235 | elif manifest_client in {"huggingface", "ray"}: 236 | middleix = Path(manifest_params.get("model_path", "")).name.replace("/", "-") 237 | elif manifest_client == "toma": 238 | middleix = manifest_engine.split("/")[-1] 239 | else: 240 | raise ValueError(f"Unknown manifest client {manifest_client}") 241 | output_filename = f"{prefix}_{middleix}_{suffix}" 242 | console.print(f"Saving to {Path(output_dir) / output_filename}") 243 | Path(output_dir).mkdir(parents=True, exist_ok=True) 244 | 245 | console.print("Loading metadata...") 246 | db_to_tables = data_formatter.load_table_metadata(table_meta_path) 247 | 248 | console.print("Loading data...") 249 | data = data_formatter.load_data(dataset_path) 250 | if num_run > 0: 251 | console.print(f"Running on {min(len(data), num_run)} examples") 252 | data = data[:num_run] 253 | original_data = data 254 | 255 | # load the examples 256 | console.print("Formatting data...") 257 | num_print = min(num_print, len(data)) 258 | token_lengths = [] 259 | text_to_sql_in = [ 260 | get_text_to_sql_in(input_question, db_to_tables) for input_question in data 261 | ] 262 | 263 | if num_retrieved_docs > 0: 264 | console.print("Loading documenration and indexing...") 265 | retrieved_docs = [] 266 | doc_contents = load_documentation(markdown_docs_path) 267 | chunked_docs = 
split_documents(doc_contents) 268 | embedded_docs, full_embedding_mat = embed_documents(chunked_docs) 269 | for i in tqdm(range(len(text_to_sql_in)), desc="Retrieving docs"): 270 | _, retrieved_docs_strings = query_docs( 271 | text_to_sql_in[i].instruction, 272 | embedded_docs, 273 | full_embedding_mat, 274 | top_n=num_retrieved_docs, 275 | ) 276 | retrieved_docs.append(retrieved_docs_strings) 277 | else: 278 | retrieved_docs = [[] for _ in range(len(text_to_sql_in))] 279 | 280 | for i in range(num_print): 281 | # Run a few to get some examples to print 282 | generated_responses = generate_sql( 283 | manifest=manifest, 284 | text_to_sql_in=[text_to_sql_in[i]], 285 | retrieved_docs=[retrieved_docs[i]], 286 | stop_tokens=stop_tokens, 287 | max_tokens=max_tokens, 288 | temperature=temperature, 289 | num_beams=num_beams, 290 | prompt_formatter=prompt_formatter, 291 | overwrite_manifest=overwrite_manifest, 292 | parallel=parallel, 293 | ) 294 | for prediction, model_response in generated_responses: 295 | prediction = re.sub(r"[\s\t\n]+", " ", prediction) 296 | token_lengths.append(len(tokenizer(prediction).input_ids)) 297 | console.print(f"[blue]Prompt:[/blue] {model_response.final_prompt}") 298 | console.print(f"[red]Prediction:[/red] {prediction}") 299 | if data[i].get("query") or data[i].get("sql"): 300 | console.print( 301 | "[purple]Gold:[/purple] " 302 | f"{data[i].get('query') or data[i].get('sql')}" 303 | ) 304 | console.print("\n****\n") 305 | 306 | # Run the entire thing now - the to_print results will be in cache and fast 307 | generated_sqls = generate_sql( 308 | manifest=manifest, 309 | text_to_sql_in=text_to_sql_in, 310 | retrieved_docs=retrieved_docs, 311 | stop_tokens=stop_tokens, 312 | max_tokens=max_tokens, 313 | temperature=temperature, 314 | num_beams=num_beams, 315 | prompt_formatter=prompt_formatter, 316 | overwrite_manifest=overwrite_manifest, 317 | parallel=parallel, 318 | ) 319 | 320 | with open(Path(output_dir) / output_filename, "w") as fout: 321 | for i, (prediction, model_response) in enumerate(generated_sqls): 322 | if isinstance(model_response.final_prompt, str): 323 | token_lengths.append( 324 | len(tokenizer(model_response.final_prompt).input_ids) 325 | ) 326 | else: 327 | for prompt in model_response.final_prompt: 328 | token_lengths.append(len(tokenizer(prompt["content"]).input_ids)) 329 | entry = { 330 | **original_data[i], 331 | "pred": prediction, 332 | "raw_pred": model_response.output, 333 | "raw_output": model_response.raw_output, 334 | "prompt": model_response.final_prompt, 335 | "tables": [tbl.model_dump() for tbl in text_to_sql_in[i].tables or []], 336 | } 337 | formatted_entry = data_formatter.format_output(entry) 338 | print(json.dumps(formatted_entry), file=fout) 339 | overflow = len([tl for tl in token_lengths if tl > 2048]) / len(token_lengths) 340 | console.print(f"Overflow 2048 prompt {100*overflow:.2f}%") 341 | console.print(f"Saved to {Path(output_dir) / output_filename}") 342 | 343 | 344 | if __name__ == "__main__": 345 | cli() 346 | -------------------------------------------------------------------------------- /eval/prompt_formatters.py: -------------------------------------------------------------------------------- 1 | """Rajkumar prompt formatter.""" 2 | 3 | from random import shuffle 4 | from manifest import Manifest 5 | from schema import Table 6 | 7 | 8 | class RajkumarFormatter: 9 | """RajkumarFormatter class. 10 | 11 | From https://arxiv.org/pdf/2204.00498.pdf. 
12 | """ 13 | 14 | table_sep: str = "\n\n" 15 | shuffle_table_order: bool = True 16 | _cache: dict[tuple[str, str, str], list[str]] = {} 17 | clean_whitespace = False 18 | 19 | @classmethod 20 | def format_table(cls, table: Table) -> str: 21 | """Get table format.""" 22 | table_fmt = [] 23 | for col in table.columns or []: 24 | # This is technically an incorrect type, but it should be a catchall word 25 | table_fmt.append(f" {col.name} {col.dtype or 'any'}") 26 | if table_fmt: 27 | all_cols = ",\n".join(table_fmt) 28 | create_tbl = f"CREATE TABLE {table.name} (\n{all_cols}\n)" 29 | else: 30 | create_tbl = f"CREATE TABLE {table.name}" 31 | return create_tbl 32 | 33 | @classmethod 34 | def format_all_tables(cls, tables: list[Table], instruction: str) -> list[str]: 35 | """Get all tables format.""" 36 | table_texts = [cls.format_table(table) for table in tables] 37 | key = ("tables", instruction, str(tables)) 38 | if key not in cls._cache: 39 | shuffle(table_texts) 40 | cls._cache[key] = table_texts 41 | else: 42 | table_texts = cls._cache[key] 43 | return table_texts 44 | 45 | @classmethod 46 | def format_retrieved_context( 47 | cls, 48 | context: list[str], 49 | ) -> str: 50 | """Format retrieved context.""" 51 | context_str = "\n--------\n".join(context) 52 | return f"\n\n/*\nHere is additional documentation about DuckDB that could be useful.\n--------\n{context_str}\n--------\n*/" 53 | 54 | @classmethod 55 | def format_prompt( 56 | cls, 57 | instruction: str, 58 | table_text: str, 59 | context_text: str, 60 | ) -> str | list[str]: 61 | """Get prompt format.""" 62 | return f"""{table_text}\n\n\n-- Using valid DuckDB SQL, answer the following question for the tables provided above.{context_text}\n\n-- {instruction}\n""" # noqa: E501 63 | 64 | @classmethod 65 | def format_model_output(cls, output_sql: str, prompt: str) -> str: 66 | """Format model output.""" 67 | return output_sql 68 | 69 | @classmethod 70 | def format_gold_output(cls, output_sql: str) -> str: 71 | """Format gold output for demonstration.""" 72 | return output_sql 73 | 74 | 75 | class DuckDBFormatter(RajkumarFormatter): 76 | """DuckDB class.""" 77 | 78 | @classmethod 79 | def format_prompt( 80 | cls, 81 | instruction: str, 82 | table_text: str, 83 | context_text: str, 84 | ) -> str | list[str]: 85 | """Get prompt format.""" 86 | return f"""{table_text}\n\n\n-- Using valid DuckDB SQL, answer the following question for the tables provided above.{context_text}\n\n-- {instruction}\n```sql\n""" # noqa: E501 87 | 88 | 89 | class DuckDBInstFormatter(RajkumarFormatter): 90 | """DuckDB Inst class.""" 91 | 92 | PROMPT_TEMPLATE = """### Instruction:\n{instruction}\n\n### Input:\n{input}{context}\n### Question:\n{question}\n\n### Response (use duckdb shorthand if possible):\n""" 93 | INSTRUCTION_TEMPLATE = """Your task is to generate valid duckdb SQL to answer the following question{has_schema}""" # noqa: E501 94 | 95 | @classmethod 96 | def format_retrieved_context( 97 | cls, 98 | context: list[str], 99 | ) -> str: 100 | """Format retrieved context.""" 101 | context_str = "\n--------\n".join(context) 102 | return f"\n### Documentation:\n{context_str}\n" 103 | 104 | @classmethod 105 | def format_prompt( 106 | cls, 107 | instruction: str, 108 | table_text: str, 109 | context_text: str, 110 | ) -> str | list[str]: 111 | """Get prompt format.""" 112 | input = "" 113 | if table_text: 114 | input = """Here is the database schema that the SQL query will run on:\n{schema}\n""".format( # noqa: E501 115 | schema=table_text 116 | ) 117 | 
instruction = cls.PROMPT_TEMPLATE.format( 118 | instruction=cls.INSTRUCTION_TEMPLATE.format( 119 | has_schema="." 120 | if table_text == "" 121 | else ", given a duckdb database schema." 122 | ), 123 | context=context_text, 124 | input=input, 125 | question=instruction, 126 | ) 127 | return instruction 128 | 129 | 130 | class DuckDBInstNoShorthandFormatter(DuckDBInstFormatter): 131 | """DuckDB Inst class.""" 132 | 133 | PROMPT_TEMPLATE = """### Instruction:\n{instruction}\n\n### Input:\n{input}{context}\n### Question:\n{question}\n\n### Response:\n""" 134 | INSTRUCTION_TEMPLATE = """Your task is to generate valid duckdb SQL to answer the following question{has_schema}""" # noqa: E501 135 | 136 | 137 | class DuckDBChat: 138 | """DuckDB Chat class.""" 139 | 140 | table_sep: str = "\n\n" 141 | shuffle_table_order: bool = True 142 | _cache: dict[tuple[str, str, str], list[str]] = {} 143 | clean_whitespace = False 144 | model = None 145 | 146 | @classmethod 147 | def format_table(cls, table: Table) -> str: 148 | """Get table format.""" 149 | table_fmt = [] 150 | for col in table.columns or []: 151 | # This is technically an incorrect type, but it should be a catchall word 152 | table_fmt.append(f" {col.name} {col.dtype or 'any'}") 153 | if table_fmt: 154 | all_cols = ",\n".join(table_fmt) 155 | create_tbl = f"CREATE TABLE {table.name} (\n{all_cols}\n)" 156 | else: 157 | create_tbl = f"CREATE TABLE {table.name}" 158 | return create_tbl 159 | 160 | @classmethod 161 | def format_all_tables(cls, tables: list[Table], instruction: str) -> list[dict]: 162 | """Get all tables format.""" 163 | if not cls.model: 164 | cls.model = Manifest( 165 | engine="gpt-3.5-turbo", 166 | client_name="openaichat", 167 | cache_name="sqlite", 168 | cache_connection=".manifest.sqlite", 169 | ) 170 | table_texts = [cls.format_table(table) for table in tables] 171 | full_schema = cls.table_sep.join(table_texts) 172 | prompt = f"""SQL schema of my database: 173 | {full_schema} 174 | Explain in a few sentences what the data is about: 175 | """ 176 | messages = [ 177 | { 178 | "role": "system", 179 | "content": "You are a helpful assistant that can generate a human readable summary of database content based on the schema.", 180 | }, 181 | {"role": "user", "content": prompt}, 182 | ] 183 | explanation = cls.model.run(messages, temperature=0) 184 | messages.append({"role": "assistant", "content": explanation}) 185 | return messages[1:] 186 | 187 | @classmethod 188 | def format_retrieved_context( 189 | cls, 190 | context: list[str], 191 | ) -> str: 192 | """Format retrieved context.""" 193 | context_str = "\n--------\n".join(context) 194 | return f"\n\nHere is additional documentation about DuckDB that could be useful.\n--------\n{context_str}\n--------\n" 195 | 196 | @classmethod 197 | def format_prompt( 198 | cls, 199 | instruction: str, 200 | table_text: list[dict], 201 | context_text: str, 202 | ) -> str | list[str]: 203 | """Get prompt format.""" 204 | prompt = f"""Now output a single SQL query without any explanation and do not add anything 205 | to the query that was not part of the question, also do not use markdown.
Make sure to only 206 | use information provided in the prompt, or tables and columns from the schema above and write a query to answer the question.{context_text}\n\nMy question is \n`{instruction}`\n\nGenerate the DuckDB specific SQL query:""" # noqa: E501 207 | messages = [ 208 | { 209 | "role": "system", 210 | "content": "You are a helpful assistant that can generate DuckDB sql queries, which is a superset of Postgresql, based on the user input. You do not respond with any human readable text, only SQL code.", 211 | }, 212 | *table_text, 213 | {"role": "user", "content": prompt}, 214 | ] 215 | return messages 216 | 217 | @classmethod 218 | def format_model_output(cls, output_sql: str, prompt: str) -> str: 219 | """Format model output.""" 220 | return output_sql 221 | 222 | @classmethod 223 | def format_gold_output(cls, output_sql: str) -> str: 224 | """Format gold output for demonstration.""" 225 | return output_sql 226 | -------------------------------------------------------------------------------- /eval/schema.py: -------------------------------------------------------------------------------- 1 | """Text2SQL schemas.""" 2 | import enum 3 | 4 | from manifest.response import Usage 5 | from pydantic import BaseModel 6 | 7 | DEFAULT_TABLE_NAME: str = "db_table" 8 | 9 | 10 | class Dialect(str, enum.Enum): 11 | """SQLFluff and SQLGlot dialects. 12 | 13 | Lucky for us, the dialects match both parsers. 14 | 15 | Ref: https://github.com/sqlfluff/sqlfluff/blob/main/src/sqlfluff/core/dialects/__init__.py # noqa: E501 16 | Ref: https://github.com/tobymao/sqlglot/blob/main/sqlglot/dialects/__init__.py # noqa: E501 17 | """ 18 | 19 | SNOWFLAKE = "snowflake" 20 | BIGQUERY = "bigquery" 21 | REDSHIFT = "redshift" 22 | POSTGRES = "postgres" 23 | UNKNOWN = "unknown" 24 | 25 | @property 26 | def dialect_str(self) -> str | None: 27 | """Get the dialect string for validation. 28 | 29 | We need to pass in dialect = None for UNKNOWN dialects.
30 | """ 31 | if self != Dialect.UNKNOWN: 32 | return self.value 33 | else: 34 | return None 35 | 36 | @property 37 | def quote_str(self) -> str: 38 | """Get the quote string for the dialect.""" 39 | if self == Dialect.SNOWFLAKE: 40 | return '"' 41 | elif self == Dialect.BIGQUERY: 42 | return "`" 43 | elif self == Dialect.REDSHIFT: 44 | return '"' 45 | elif self == Dialect.POSTGRES: 46 | return '"' 47 | elif self == Dialect.UNKNOWN: 48 | return '"' 49 | raise NotImplementedError(f"Quote string not implemented for dialect {self}") 50 | 51 | def quote(self, string: str) -> str: 52 | """Quote a string.""" 53 | return f"{self.quote_str}{string}{self.quote_str}" 54 | 55 | 56 | class ColumnOrLiteral(BaseModel): 57 | """Column that may or may not be a literal.""" 58 | 59 | name: str | None = None 60 | literal: bool = False 61 | 62 | def __hash__(self) -> int: 63 | """Hash.""" 64 | return hash((self.name, self.literal)) 65 | 66 | 67 | class TableColumn(BaseModel): 68 | """Table column.""" 69 | 70 | name: str 71 | dtype: str | None 72 | 73 | 74 | class ForeignKey(BaseModel): 75 | """Foreign key.""" 76 | 77 | # Referenced column 78 | column: TableColumn 79 | # References table name 80 | references_name: str 81 | # References column 82 | references_column: TableColumn 83 | 84 | 85 | class Table(BaseModel): 86 | """Table.""" 87 | 88 | name: str | None 89 | columns: list[TableColumn] | None 90 | pks: list[TableColumn] | None 91 | # FK from this table to another column in another table 92 | fks: list[ForeignKey] | None 93 | examples: list[dict] | None 94 | # Is the table a source or intermediate reference table 95 | is_reference_table: bool = False 96 | 97 | 98 | class TextToSQLParams(BaseModel): 99 | """A text to sql request.""" 100 | 101 | instruction: str 102 | database: str | None 103 | # Default to unknown 104 | dialect: Dialect = Dialect.UNKNOWN 105 | tables: list[Table] | None 106 | 107 | 108 | class TextToSQLModelResponse(BaseModel): 109 | """Model for Autocomplete Responses.""" 110 | 111 | output: str 112 | final_prompt: str | list[dict] 113 | raw_output: str 114 | usage: Usage 115 | metadata: str | None = None 116 | -------------------------------------------------------------------------------- /eval/text_to_sql.py: -------------------------------------------------------------------------------- 1 | """Text-to-SQL running.""" 2 | import asyncio 3 | import json 4 | import re 5 | import time 6 | from typing import cast 7 | 8 | import structlog 9 | from manifest import Manifest 10 | from manifest.response import Response, Usage 11 | from prompt_formatters import RajkumarFormatter 12 | from schema import DEFAULT_TABLE_NAME, TextToSQLModelResponse, TextToSQLParams 13 | from tqdm.auto import tqdm 14 | 15 | logger = structlog.get_logger() 16 | 17 | 18 | def clean_whitespace(sql: str) -> str: 19 | """Clean whitespace.""" 20 | return re.sub(r"[\t\n\s]+", " ", sql) 21 | 22 | 23 | def instruction_to_sql( 24 | params: TextToSQLParams, 25 | extra_context: list[str], 26 | manifest: Manifest, 27 | prompt_formatter: RajkumarFormatter = None, 28 | overwrite_manifest: bool = False, 29 | max_tokens: int = 300, 30 | temperature: float = 0.0, 31 | stop_sequences: list[str] | None = None, 32 | num_beams: int = 1, 33 | ) -> TextToSQLModelResponse: 34 | """Parse the instruction to a sql command.""" 35 | return instruction_to_sql_list( 36 | params=[params], 37 | extra_context=[extra_context], 38 | manifest=manifest, 39 | prompt_formatter=prompt_formatter, 40 | overwrite_manifest=overwrite_manifest, 41 | 
max_tokens=max_tokens, 42 | temperature=temperature, 43 | stop_sequences=stop_sequences, 44 | num_beams=num_beams, 45 | )[0] 46 | 47 | 48 | def instruction_to_sql_list( 49 | params: list[TextToSQLParams], 50 | extra_context: list[list[str]], 51 | manifest: Manifest, 52 | prompt_formatter: RajkumarFormatter = None, 53 | overwrite_manifest: bool = False, 54 | max_tokens: int = 300, 55 | temperature: float = 0.0, 56 | stop_sequences: list[str] | None = None, 57 | num_beams: int = 1, 58 | verbose: bool = False, 59 | ) -> list[TextToSQLModelResponse]: 60 | """Parse the list of instructions to sql commands. 61 | 62 | Connector is used for default retry handlers only. 63 | """ 64 | if prompt_formatter is None: 65 | raise ValueError("Prompt formatter is required.") 66 | 67 | def construct_params( 68 | params: TextToSQLParams, 69 | context: list[str], 70 | ) -> str | list[dict]: 71 | """Turn params into prompt.""" 72 | if prompt_formatter.clean_whitespace: 73 | instruction = clean_whitespace(params.instruction) 74 | else: 75 | instruction = params.instruction 76 | 77 | table_texts = prompt_formatter.format_all_tables( 78 | params.tables, instruction=instruction 79 | ) 80 | # table_texts can be list of chat messages. Only join list of str. 81 | if table_texts: 82 | if isinstance(table_texts[0], str): 83 | table_text = prompt_formatter.table_sep.join(table_texts) 84 | else: 85 | table_text = table_texts 86 | else: 87 | table_text = "" 88 | 89 | if context: 90 | context_text = prompt_formatter.format_retrieved_context(context) 91 | else: 92 | context_text = "" if isinstance(table_text, str) else [] 93 | prompt = prompt_formatter.format_prompt( 94 | instruction, 95 | table_text, 96 | context_text, 97 | ) 98 | return prompt 99 | 100 | # If no inputs, return nothing 101 | if not params: 102 | return [] 103 | 104 | # Stitch together demonstrations and params 105 | prompts: list[str | list[dict]] = [] 106 | for i, param in tqdm( 107 | enumerate(params), 108 | total=len(params), 109 | desc="Constructing prompts", 110 | disable=not verbose, 111 | ): 112 | predict_str = construct_params(param, extra_context[i] if extra_context else []) 113 | if isinstance(predict_str, str): 114 | prompt = predict_str.lstrip() 115 | else: 116 | prompt = predict_str 117 | prompts.append(prompt) 118 | 119 | manifest_params = dict( 120 | max_tokens=max_tokens, 121 | overwrite_cache=overwrite_manifest, 122 | num_beams=num_beams, 123 | logprobs=5, 124 | temperature=temperature, 125 | do_sample=False if temperature <= 0 else True, 126 | stop_sequences=stop_sequences or prompt_formatter.stop_sequences, 127 | ) 128 | 129 | ret: list[TextToSQLModelResponse] = [] 130 | if len(params) == 1: 131 | prompt = prompts[0] 132 | model_response = _run_manifest( 133 | prompt, 134 | manifest_params, 135 | prompt_formatter, 136 | manifest, 137 | stop_sequences=stop_sequences, 138 | ) 139 | usage = model_response.usage 140 | model_response.usage = usage 141 | ret.append(model_response) 142 | else: 143 | # We do not handle retry logic on parallel requests right now 144 | loop = asyncio.new_event_loop() 145 | asyncio.set_event_loop(loop) 146 | response = cast( 147 | Response, 148 | loop.run_until_complete( 149 | manifest.arun_batch( 150 | prompts, 151 | **manifest_params, # type: ignore 152 | ), 153 | ), 154 | ) 155 | loop.close() 156 | 157 | response_usage = response.get_usage() 158 | response_text = response.get_parsed_response() 159 | for prompt, resp in zip(prompts, response_text): 160 | # This will restitch the query in the case we force it to 
start with SELECT 161 | sql_query = prompt_formatter.format_model_output(cast(str, resp), prompt) 162 | for token in stop_sequences: 163 | sql_query = sql_query.split(token)[0] 164 | logger.info(f"FINAL OUTPUT: {sql_query}") 165 | ret.append( 166 | TextToSQLModelResponse( 167 | output=sql_query, 168 | raw_output=cast(str, resp), 169 | final_prompt=prompt, 170 | usage=response_usage, 171 | ) 172 | ) 173 | 174 | return ret 175 | 176 | 177 | def _run_manifest( 178 | prompt: str | list[str], 179 | manifest_params: dict, 180 | prompt_formatter: RajkumarFormatter, 181 | manifest: Manifest, 182 | stop_sequences: list[str] | None = None, 183 | ) -> TextToSQLModelResponse: 184 | """Run manifest for prompt format.""" 185 | logger.info(f"PARAMS: {manifest_params}") 186 | if isinstance(prompt, list): 187 | for p in prompt: 188 | logger.info(f"PROMPT: {p['role']}: {p['content']}") 189 | else: 190 | logger.info(f"PROMPT: {prompt}") 191 | start_time = time.time() 192 | # Run result 193 | response = cast( 194 | Response, 195 | manifest.run( 196 | prompt, 197 | return_response=True, 198 | client_timeout=1800, 199 | **manifest_params, # type: ignore 200 | ), 201 | ) 202 | logger.info(f"TIME: {time.time() - start_time: .2f}") 203 | 204 | response_usage = response.get_usage_obj() 205 | summed_usage = Usage() 206 | for usage in response_usage.usages: 207 | summed_usage.completion_tokens += usage.completion_tokens 208 | summed_usage.prompt_tokens += usage.prompt_tokens 209 | summed_usage.total_tokens += usage.total_tokens 210 | # This will restitch the query in the case we force it to start with SELECT 211 | sql_query = prompt_formatter.format_model_output( 212 | cast(str, response.get_response()), prompt 213 | ) 214 | 215 | for token in stop_sequences: 216 | sql_query = sql_query.split(token)[0] 217 | logger.info(f"OUTPUT: {sql_query}") 218 | model_response = TextToSQLModelResponse( 219 | output=sql_query, 220 | raw_output=cast(str, response.get_response()), 221 | final_prompt=prompt, 222 | usage=summed_usage, 223 | ) 224 | return model_response 225 | -------------------------------------------------------------------------------- /examples/local_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Setup - Llama.cpp\n", 8 | "\n", 9 | "If you want to use [llama.cpp](https://github.com/abetlen/llama-cpp-python) on a MacBook M1 or M2, run the code below. For more options, check out the [llama-cpp-python](https://github.com/abetlen/llama-cpp-python) docs." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "!CMAKE_ARGS=\"-DLLAMA_METAL=on\" pip install llama-cpp-python" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "# Download the model weight from huggingface\n", 28 | "!huggingface-cli download motherduckdb/DuckDB-NSQL-7B-v0.1-GGUF DuckDB-NSQL-7B-v0.1-q8_0.gguf --local-dir . 
--local-dir-use-symlinks False\n" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "!pip install wurlitzer pandas duckdb==0.9.2" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "## Setup - General imports" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "import duckdb\n", 54 | "from wurlitzer import pipes\n", 55 | "from utils import generate_sql\n", 56 | "from llama_cpp import Llama" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": {}, 62 | "source": [ 63 | "## Load Model " 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": { 70 | "scrolled": true 71 | }, 72 | "outputs": [], 73 | "source": [ 74 | "with pipes() as (out, err):\n", 75 | " client = Llama(\n", 76 | " model_path=\"DuckDB-NSQL-7B-v0.1-q8_0.gguf\",\n", 77 | " n_ctx=2048,\n", 78 | " )" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": {}, 84 | "source": [ 85 | "## Connect to DuckDB" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "con = duckdb.connect(\"nyc.duckdb\")" 95 | ] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "metadata": {}, 100 | "source": [ 101 | "## Ask Question" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "question = \"get all columns from taxi table starting with 'a'\"" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "%%time\n", 120 | "sql_query = generate_sql(question, con, client)" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "metadata": {}, 126 | "source": [ 127 | "## Run Query on DuckDB" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "con.execute(sql_query).fetchdf()" 137 | ] 138 | } 139 | ], 140 | "metadata": { 141 | "kernelspec": { 142 | "display_name": "Python 3 (ipykernel)", 143 | "language": "python", 144 | "name": "python3" 145 | }, 146 | "language_info": { 147 | "codemirror_mode": { 148 | "name": "ipython", 149 | "version": 3 150 | }, 151 | "file_extension": ".py", 152 | "mimetype": "text/x-python", 153 | "name": "python", 154 | "nbconvert_exporter": "python", 155 | "pygments_lexer": "ipython3", 156 | "version": "3.11.2" 157 | } 158 | }, 159 | "nbformat": 4, 160 | "nbformat_minor": 4 161 | } 162 | -------------------------------------------------------------------------------- /examples/nyc.duckdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NumbersStationAI/DuckDB-NSQL/5860f7a0796f7a22e041b76b430559490c199cd7/examples/nyc.duckdb -------------------------------------------------------------------------------- /examples/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import re 3 | from typing import Any 4 | import subprocess 5 | from wurlitzer import pipes 6 | from duckdb import DuckDBPyConnection 7 | 8 | PROMPT_TEMPLATE = """### Instruction:\n{instruction}\n\n### Input:\n{input}\n### Question:\n{question}\n\n### Response (use duckdb shorthand if 
possible):\n""" 9 | INSTRUCTION_TEMPLATE = """Your task is to generate valid duckdb SQL to answer the following question{has_schema}""" # noqa: E501 10 | ERROR_MESSAGE = "Quack! Much to our regret, SQL generation has gone a tad duck-side-down.\nThe model is currently not capable of crafting the desired SQL. \nSorry my duck friend.\n\nIf the question is about your own database, make sure to set the correct schema.\n\n```sql\n{sql_query}\n```\n\n```sql\n{error_msg}\n```" 11 | 12 | 13 | def get_schema(connection: DuckDBPyConnection) -> str: 14 | """Get schema from DuckDB connection.""" 15 | tables = [] 16 | information_schema = connection.execute( 17 | "SELECT * FROM information_schema.tables" 18 | ).fetchdf() 19 | for table_name in information_schema["table_name"].unique(): 20 | table_df = connection.execute( 21 | f"SELECT * FROM information_schema.columns WHERE table_name = '{table_name}'" 22 | ).fetchdf() 23 | columns = [] 24 | for _, row in table_df.iterrows(): 25 | col_name = row["column_name"] 26 | col_dtype = row["data_type"] 27 | columns.append(f"{col_name} {col_dtype}") 28 | column_str = ",\n ".join(columns) 29 | table = f"CREATE TABLE {table_name} (\n {column_str}\n);" 30 | tables.append(table) 31 | return "\n\n".join(tables) 32 | 33 | 34 | def generate_prompt(question: str, schema: str) -> str: 35 | """Generate prompt.""" 36 | input = "" 37 | if schema: 38 | # Lowercase types inside each CREATE TABLE (...) statement 39 | for create_table in re.findall( 40 | r"CREATE TABLE [^(]+\((.*?)\);", schema, flags=re.DOTALL | re.MULTILINE 41 | ): 42 | for create_col in re.findall(r"(\w+) (\w+)", create_table): 43 | schema = schema.replace( 44 | f"{create_col[0]} {create_col[1]}", 45 | f"{create_col[0]} {create_col[1].lower()}", 46 | ) 47 | input = """Here is the database schema that the SQL query will run on:\n{schema}\n""".format( # noqa: E501 48 | schema=schema 49 | ) 50 | prompt = PROMPT_TEMPLATE.format( 51 | instruction=INSTRUCTION_TEMPLATE.format( 52 | has_schema="." if schema == "" else ", given a duckdb database schema." 
53 | ), 54 | input=input, 55 | question=question, 56 | ) 57 | return prompt 58 | 59 | 60 | def generate_sql( 61 | question: str, 62 | connection: DuckDBPyConnection, 63 | llama: Any, 64 | max_tokens: int = 300, 65 | ) -> [str, bool, str]: 66 | schema = get_schema(connection) 67 | prompt = generate_prompt(question, schema) 68 | 69 | with pipes() as (out, err): 70 | res = llama(prompt, temperature=0.1, max_tokens=max_tokens) 71 | sql_query = res["choices"][0]["text"] 72 | 73 | is_valid, error_msg = validate_sql(sql_query, schema) 74 | 75 | if is_valid: 76 | print(sql_query) 77 | else: 78 | print("!!!Invalid SQL detected!!!") 79 | print(sql_query) 80 | print(error_msg) 81 | 82 | return sql_query 83 | 84 | 85 | def validate_sql(query, schema): 86 | try: 87 | # Define subprocess 88 | process = subprocess.Popen( 89 | [sys.executable, './validate_sql.py', query, schema], 90 | stdout=subprocess.PIPE, 91 | stderr=subprocess.PIPE 92 | ) 93 | # Get output and potential parser, and binder error message 94 | stdout, stderr = process.communicate(timeout=0.5) 95 | if stderr: 96 | error_message = stderr.decode('utf8').split("\n") 97 | # skip traceback 98 | if len(error_message) > 3: 99 | error_message = "\n".join(error_message[3:]) 100 | return False, error_message 101 | return True, "" 102 | except subprocess.TimeoutExpired: 103 | process.kill() 104 | # timeout reached, so parsing and binding was very likely successful 105 | return True, "" 106 | -------------------------------------------------------------------------------- /examples/validate_sql.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import duckdb 3 | from duckdb import ParserException, SyntaxException, BinderException, CatalogException 4 | 5 | 6 | def validate_query(query, schemas): 7 | try: 8 | with duckdb.connect( 9 | ":memory:", config={"enable_external_access": False} 10 | ) as duckdb_conn: 11 | # register schemas 12 | for schema in schemas.split(";"): 13 | duckdb_conn.execute(schema) 14 | cursor = duckdb_conn.cursor() 15 | cursor.execute(query) 16 | except ParserException as e: 17 | return str(e) 18 | except SyntaxException as e: 19 | return str(e) 20 | except BinderException as e: 21 | return str(e) 22 | except CatalogException as e: 23 | if not ("but it exists" in str(e) and "extension" in str(e)): 24 | return str(e) 25 | except Exception as e: 26 | return None 27 | return None 28 | 29 | 30 | if __name__ == "__main__": 31 | if len(sys.argv) > 2: 32 | error = validate_query(sys.argv[1], sys.argv[2]) 33 | if error: 34 | raise Exception(error) 35 | else: 36 | print("No query provided.") 37 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | manifest-ml[all]==0.1.8 2 | pandas>=2.0.0 3 | sqlalchemy<2.0.0 4 | transformers>=4.34.1 5 | datasets==2.11.0 6 | jsonlines>=3.1.0 7 | sqlglot==11.5.5 8 | click 9 | rich 10 | nltk>=3.5,<3.6 11 | sqlparse 12 | pebble 13 | structlog 14 | sentencepiece 15 | duckdb==0.9.2 16 | structlog==22.3.0 17 | tensorboard==2.15.1 18 | lightning==2.1.0 19 | lmdb==1.4.1 20 | cloudpathlib==0.13.0 21 | peft==0.6.0 22 | packaging==23.2 23 | ninja==1.11.1.1 24 | flash-attn==2.3.3 25 | langchain 26 | pydantic>2 27 | --------------------------------------------------------------------------------
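A minimal end-to-end sketch of how the pieces under examples/ fit together, mirroring local_demo.ipynb; the model path, database file, and question below are illustrative assumptions, and generate_sql comes from examples/utils.py:

import duckdb
from llama_cpp import Llama
from utils import generate_sql  # examples/utils.py

# Load the local DuckDB-NSQL GGUF weights (downloaded as shown in local_demo.ipynb).
client = Llama(model_path="DuckDB-NSQL-7B-v0.1-q8_0.gguf", n_ctx=2048)

# Connect to the bundled demo database and ask a question in natural language.
con = duckdb.connect("nyc.duckdb")
sql_query = generate_sql("get all columns from taxi table starting with 'a'", con, client)

# generate_sql checks the query with validate_sql.py in a subprocess and prints a
# warning if it is invalid, but returns the generated SQL either way.
print(con.execute(sql_query).fetchdf())

The notebook additionally wraps model loading and inference in wurlitzer's pipes() context manager to keep llama.cpp's C-level logging out of the notebook output; that is omitted here for brevity.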