├── .gitignore ├── LICENSE ├── README.md ├── data_prep ├── README.md ├── data │ ├── download.sh │ └── tables.json ├── data_utils.py ├── kummerfeld_utils.py ├── prep_text2sql_data.py ├── prompt_formatters.py ├── schema.py ├── text2sql_data_config.yaml └── text2sql_dataset.py ├── examples ├── colab.ipynb ├── db_connectors.py ├── finetune.ipynb ├── postgres.ipynb ├── prompt_formatters.py └── sqlite.ipynb └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | ### Python ### 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | *.pyc 7 | 8 | # Distribution / packaging 9 | .Python 10 | build/ 11 | develop-eggs/ 12 | dist/ 13 | downloads/ 14 | eggs/ 15 | .eggs/ 16 | lib/ 17 | lib64/ 18 | parts/ 19 | sdist/ 20 | var/ 21 | wheels/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | 26 | # PyInstaller 27 | # Usually these files are written by a python script from a template 28 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 29 | *.manifest 30 | *.spec 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | 36 | # Unit test / coverage reports 37 | htmlcov/ 38 | .tox/ 39 | .coverage 40 | .coverage.* 41 | .cache 42 | .pytest_cache/ 43 | nosetests.xml 44 | coverage.xml 45 | *.cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Jupyter Notebook 53 | .ipynb_checkpoints 54 | 55 | # pyenv 56 | .python-version 57 | 58 | # Environments 59 | .env 60 | .venv 61 | .virtualenv 62 | env/ 63 | venv/ 64 | virtualenv/ 65 | ENV/ 66 | env.bak/ 67 | venv.bak/ 68 | 69 | # mkdocs documentation 70 | /site 71 | 72 | # mypy 73 | .mypy_cache/ 74 | .dmypy.json 75 | 76 | # wheels 77 | .whl 78 | 79 | # VS Code 80 | *.vscode 81 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NSQL 2 | Numbers Station Text to SQL model code. 3 | 4 | NSQL is a family of autoregressive open-source large foundation models (FMs) designed specifically for SQL generation tasks. All model weights are provided on HuggingFace. 
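Beyond the served setup described under Usage below, the checkpoints in the table that follows can also be loaded directly with Hugging Face `transformers`. The snippet below is a minimal sketch, not an official example from this repo: it assumes the standard causal-LM interface, and the inline `CREATE TABLE` schema and question are illustrative placeholders; check each checkpoint's model card for the exact prompt format it expects.

```python
# Minimal sketch (assumption: standard Hugging Face causal-LM interface).
# The schema and question in the prompt are illustrative placeholders.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "NumbersStation/nsql-350M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

prompt = """CREATE TABLE stadium (
    stadium_id number,
    name text,
    capacity number
)

-- Using valid SQLite, answer the following question for the table provided above.

-- How many stadiums are there?

SELECT"""

inputs = tokenizer(prompt, return_tensors="pt")
# Greedy decoding; the model completes the partial SELECT statement.
output_ids = model.generate(inputs.input_ids, max_new_tokens=100)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```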
5 | 6 | | Model Name | Size | Link | 7 | | ---------- | ---- | ------- | 8 | | NumbersStation/nsql-350M | 350M | [link](https://huggingface.co/NumbersStation/nsql-350M) 9 | | NumbersStation/nsql-2B | 2.7B | [link](https://huggingface.co/NumbersStation/nsql-2B) 10 | | NumbersStation/nsql-6B | 6B | [link](https://huggingface.co/NumbersStation/nsql-6B) 11 | | NumbersStation/nsql-llama-2-7B | 7B | [link](https://huggingface.co/NumbersStation/nsql-llama-2-7B) 12 | 13 | ## Setup 14 | To install, run 15 | ``` 16 | pip install -r requirements.txt 17 | ``` 18 | 19 | ## Usage 20 | See the examples in `examples/` for how to connect to Postgres or SQLite and ask questions directly over your data. A small code snippet from the `examples/` directory is provided below. 21 | 22 | In a separate screen or window, run 23 | ```bash 24 | python3 -m manifest.api.app \ 25 | --model_type huggingface \ 26 | --model_generation_type text-generation \ 27 | --model_name_or_path NumbersStation/nsql-350M \ 28 | --device 0 29 | ``` 30 | 31 | Then run 32 | 33 | ```python 34 | from db_connectors import PostgresConnector 35 | from prompt_formatters import RajkumarFormatter 36 | from manifest import Manifest 37 | 38 | postgres_connector = PostgresConnector( 39 | user=USER, password=PASSWORD, dbname=DATABASE, host=HOST, port=PORT 40 | ) 41 | postgres_connector.connect() 42 | db_schema = [postgres_connector.get_schema(table) for table in postgres_connector.get_tables()] 43 | formatter = RajkumarFormatter(db_schema) 44 | 45 | manifest_client = Manifest(client_name="huggingface", client_connection="http://127.0.0.1:5000") 46 | 47 | def get_sql(instruction: str, max_tokens: int = 300) -> str: 48 | prompt = formatter.format_prompt(instruction) 49 | res = manifest_client.run(prompt, max_tokens=max_tokens) 50 | return formatter.format_model_output(res) 51 | 52 | print(get_sql("Number of rows in table?")) 53 | ``` 54 | 55 | ## Data Preparation 56 | 57 | In the `data_prep` folder, we provide data preparation scripts to generate [NSText2SQL](https://huggingface.co/datasets/NumbersStation/NSText2SQL), which is used to train the [NSQL](https://huggingface.co/NumbersStation/nsql-6B) models. 58 | 59 | ## License 60 | 61 | The code in this repo is licensed under the Apache 2.0 license. Unless otherwise noted, the following applies: 62 | 63 | ``` 64 | Copyright 2023 Numbers Station 65 | 66 | Licensed under the Apache License, Version 2.0 (the "License"); 67 | you may not use this file except in compliance with the License. 68 | You may obtain a copy of the License at 69 | 70 | http://www.apache.org/licenses/LICENSE-2.0 71 | 72 | Unless required by applicable law or agreed to in writing, software 73 | distributed under the License is distributed on an "AS IS" BASIS, 74 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 75 | See the License for the specific language governing permissions and 76 | limitations under the License. 77 | ``` 78 | 79 | The data used to generate NSText2SQL is sourced from repositories with various licenses. Any use of all or part of the data gathered in NSText2SQL must abide by the terms of the original licenses, including attribution clauses when relevant. We thank all authors who provided these datasets. We provide provenance information for each dataset below. 
80 | 81 | | Datasets | License | Link | 82 | | ---------------------- | ------------ | -------------------------------------------------------------------------------------------------------------------- | 83 | | academic | Not Found | [https://github.com/jkkummerfeld/text2sql-data](https://github.com/jkkummerfeld/text2sql-data) | 84 | | advising | CC-BY-4.0 | [https://github.com/jkkummerfeld/text2sql-data](https://github.com/jkkummerfeld/text2sql-data) | 85 | | atis | Not Found | [https://github.com/jkkummerfeld/text2sql-data](https://github.com/jkkummerfeld/text2sql-data) | 86 | | restaurants | Not Found | [https://github.com/jkkummerfeld/text2sql-data](https://github.com/jkkummerfeld/text2sql-data) | 87 | | scholar | Not Found | [https://github.com/jkkummerfeld/text2sql-data](https://github.com/jkkummerfeld/text2sql-data) | 88 | | imdb | Not Found | [https://github.com/jkkummerfeld/text2sql-data](https://github.com/jkkummerfeld/text2sql-data) | 89 | | yelp | Not Found | [https://github.com/jkkummerfeld/text2sql-data](https://github.com/jkkummerfeld/text2sql-data) | 90 | | criteria2sql | Apache-2.0 | [https://github.com/xiaojingyu92/Criteria2SQL](https://github.com/xiaojingyu92/Criteria2SQL) | 91 | | css | CC-BY-4.0 | [https://huggingface.co/datasets/zhanghanchong/css](https://huggingface.co/datasets/zhanghanchong/css) | 92 | | eICU | CC-BY-4.0 | [https://github.com/glee4810/EHRSQL](https://github.com/glee4810/EHRSQL) | 93 | | mimic_iii | CC-BY-4.0 | [https://github.com/glee4810/EHRSQL](https://github.com/glee4810/EHRSQL) | 94 | | geonucleardata | CC-BY-SA-4.0 | [https://github.com/chiahsuan156/KaggleDBQA](https://github.com/chiahsuan156/KaggleDBQA) | 95 | | greatermanchestercrime | CC-BY-SA-4.0 | [https://github.com/chiahsuan156/KaggleDBQA](https://github.com/chiahsuan156/KaggleDBQA) | 96 | | studentmathscore | CC-BY-SA-4.0 | [https://github.com/chiahsuan156/KaggleDBQA](https://github.com/chiahsuan156/KaggleDBQA) | 97 | | thehistoryofbaseball | CC-BY-SA-4.0 | [https://github.com/chiahsuan156/KaggleDBQA](https://github.com/chiahsuan156/KaggleDBQA) | 98 | | uswildfires | CC-BY-SA-4.0 | [https://github.com/chiahsuan156/KaggleDBQA](https://github.com/chiahsuan156/KaggleDBQA) | 99 | | whatcdhiphop | CC-BY-SA-4.0 | [https://github.com/chiahsuan156/KaggleDBQA](https://github.com/chiahsuan156/KaggleDBQA) | 100 | | worldsoccerdatabase | CC-BY-SA-4.0 | [https://github.com/chiahsuan156/KaggleDBQA](https://github.com/chiahsuan156/KaggleDBQA) | 101 | | pesticide | CC-BY-SA-4.0 | [https://github.com/chiahsuan156/KaggleDBQA](https://github.com/chiahsuan156/KaggleDBQA) | 102 | | mimicsql_data | MIT | [https://github.com/wangpinggl/TREQS](https://github.com/wangpinggl/TREQS) | 103 | | nvbench | MIT | [https://github.com/TsinghuaDatabaseGroup/nvBench](https://github.com/TsinghuaDatabaseGroup/nvBench) | 104 | | sede | Apache-2.0 | [https://github.com/hirupert/sede](https://github.com/hirupert/sede) | 105 | | spider | CC-BY-SA-4.0 | [https://huggingface.co/datasets/spider](https://huggingface.co/datasets/spider) | 106 | | sql_create_context | CC-BY-4.0 | [https://huggingface.co/datasets/b-mc2/sql-create-context](https://huggingface.co/datasets/b-mc2/sql-create-context) | 107 | | squall | CC-BY-SA-4.0 | [https://github.com/tzshi/squall](https://github.com/tzshi/squall) | 108 | | wikisql | BSD 3-Clause | [https://github.com/salesforce/WikiSQL](https://github.com/salesforce/WikiSQL) | 109 | 110 | For full terms, see the LICENSE file. 
If you have any questions, comments, or concerns about licensing please [contact us](https://www.numbersstation.ai/signup). 111 | 112 | # Citing this work 113 | 114 | If you use this data in your work, please cite our work _and_ the appropriate original sources: 115 | 116 | To cite NSText2SQL, please use: 117 | ```TeX 118 | @software{numbersstation2023NSText2SQL, 119 | author = {Numbers Station Labs}, 120 | title = {NSText2SQL: An Open Source Text-to-SQL Dataset for Foundation Model Training}, 121 | month = {July}, 122 | year = {2023}, 123 | url = {https://github.com/NumbersStationAI/NSQL}, 124 | } 125 | ``` 126 | 127 | To cite dataset used in this work, please use: 128 | 129 | | Datasets | Cite | 130 | | ---------------------- | ---------------------------------------------------------------------------------------- | 131 | | academic | `\cite{data-advising,data-academic}` | 132 | | advising | `\cite{data-advising}` | 133 | | atis | `\cite{data-advising,data-atis-original,data-atis-geography-scholar}` | 134 | | restaurants | `\cite{data-advising,data-restaurants-logic,data-restaurants-original,data-restaurants}` | 135 | | scholar | `\cite{data-advising,data-atis-geography-scholar}` | 136 | | imdb | `\cite{data-advising,data-imdb-yelp}` | 137 | | yelp | `\cite{data-advising,data-imdb-yelp}` | 138 | | criteria2sql | `\cite{Criteria-to-SQL}` | 139 | | css | `\cite{zhang2023css}` | 140 | | eICU | `\cite{lee2022ehrsql}` | 141 | | mimic_iii | `\cite{lee2022ehrsql}` | 142 | | geonucleardata | `\cite{lee-2021-kaggle-dbqa}` | 143 | | greatermanchestercrime | `\cite{lee-2021-kaggle-dbqa}` | 144 | | studentmathscore | `\cite{lee-2021-kaggle-dbqa}` | 145 | | thehistoryofbaseball | `\cite{lee-2021-kaggle-dbqa}` | 146 | | uswildfires | `\cite{lee-2021-kaggle-dbqa}` | 147 | | whatcdhiphop | `\cite{lee-2021-kaggle-dbqa}` | 148 | | worldsoccerdatabase | `\cite{lee-2021-kaggle-dbqa}` | 149 | | pesticide | `\cite{lee-2021-kaggle-dbqa}` | 150 | | mimicsql_data | `\cite{wang2020text}` | 151 | | nvbench | `\cite{nvBench_SIGMOD21}` | 152 | | sede | `\cite{hazoom2021text}` | 153 | | spider | `\cite{data-spider}` | 154 | | sql_create_context | Not Found | 155 | | squall | `\cite{squall}` | 156 | | wikisql | `\cite{data-wikisql}` | 157 | 158 | 159 | ```TeX 160 | @InProceedings{data-advising, 161 | dataset = {Advising}, 162 | author = {Catherine Finegan-Dollak, Jonathan K. Kummerfeld, Li Zhang, Karthik Ramanathan, Sesh Sadasivam, Rui Zhang, and Dragomir Radev}, 163 | title = {Improving Text-to-SQL Evaluation Methodology}, 164 | booktitle = {Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)}, 165 | month = {July}, 166 | year = {2018}, 167 | location = {Melbourne, Victoria, Australia}, 168 | pages = {351--360}, 169 | url = {http://aclweb.org/anthology/P18-1033}, 170 | } 171 | 172 | @InProceedings{data-imdb-yelp, 173 | dataset = {IMDB and Yelp}, 174 | author = {Navid Yaghmazadeh, Yuepeng Wang, Isil Dillig, and Thomas Dillig}, 175 | title = {SQLizer: Query Synthesis from Natural Language}, 176 | booktitle = {International Conference on Object-Oriented Programming, Systems, Languages, and Applications, ACM}, 177 | month = {October}, 178 | year = {2017}, 179 | pages = {63:1--63:26}, 180 | url = {http://doi.org/10.1145/3133887}, 181 | } 182 | 183 | @article{data-academic, 184 | dataset = {Academic}, 185 | author = {Fei Li and H. V. 
Jagadish}, 186 | title = {Constructing an Interactive Natural Language Interface for Relational Databases}, 187 | journal = {Proceedings of the VLDB Endowment}, 188 | volume = {8}, 189 | number = {1}, 190 | month = {September}, 191 | year = {2014}, 192 | pages = {73--84}, 193 | url = {http://dx.doi.org/10.14778/2735461.2735468}, 194 | } 195 | 196 | @InProceedings{data-atis-geography-scholar, 197 | dataset = {Scholar, and Updated ATIS and Geography}, 198 | author = {Srinivasan Iyer, Ioannis Konstas, Alvin Cheung, Jayant Krishnamurthy, and Luke Zettlemoyer}, 199 | title = {Learning a Neural Semantic Parser from User Feedback}, 200 | booktitle = {Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)}, 201 | year = {2017}, 202 | pages = {963--973}, 203 | location = {Vancouver, Canada}, 204 | url = {http://www.aclweb.org/anthology/P17-1089}, 205 | } 206 | 207 | @article{data-atis-original, 208 | dataset = {ATIS, original}, 209 | author = {Deborah A. Dahl, Madeleine Bates, Michael Brown, William Fisher, Kate Hunicke-Smith, David Pallett, Christine Pao, Alexander Rudnicky, and Elizabeth Shriber}, 210 | title = {{Expanding the scope of the ATIS task: The ATIS-3 corpus}}, 211 | journal = {Proceedings of the workshop on Human Language Technology}, 212 | year = {1994}, 213 | pages = {43--48}, 214 | url = {http://dl.acm.org/citation.cfm?id=1075823}, 215 | } 216 | 217 | @inproceedings{data-restaurants-logic, 218 | author = {Lappoon R. Tang and Raymond J. Mooney}, 219 | title = {Automated Construction of Database Interfaces: Intergrating Statistical and Relational Learning for Semantic Parsing}, 220 | booktitle = {2000 Joint SIGDAT Conference on Empirical Methods in Natural Language Processing and Very Large Corpora}, 221 | year = {2000}, 222 | pages = {133--141}, 223 | location = {Hong Kong, China}, 224 | url = {http://www.aclweb.org/anthology/W00-1317}, 225 | } 226 | 227 | @inproceedings{data-restaurants-original, 228 | author = {Ana-Maria Popescu, Oren Etzioni, and Henry Kautz}, 229 | title = {Towards a Theory of Natural Language Interfaces to Databases}, 230 | booktitle = {Proceedings of the 8th International Conference on Intelligent User Interfaces}, 231 | year = {2003}, 232 | location = {Miami, Florida, USA}, 233 | pages = {149--157}, 234 | url = {http://doi.acm.org/10.1145/604045.604070}, 235 | } 236 | 237 | @inproceedings{data-restaurants, 238 | author = {Alessandra Giordani and Alessandro Moschitti}, 239 | title = {Automatic Generation and Reranking of SQL-derived Answers to NL Questions}, 240 | booktitle = {Proceedings of the Second International Conference on Trustworthy Eternal Systems via Evolving Software, Data and Knowledge}, 241 | year = {2012}, 242 | location = {Montpellier, France}, 243 | pages = {59--76}, 244 | url = {https://doi.org/10.1007/978-3-642-45260-4_5}, 245 | } 246 | 247 | @InProceedings{data-spider, 248 | author = {Tao Yu, Rui Zhang, Kai Yang, Michihiro Yasunaga, Dongxu Wang, Zifan Li, James Ma, Irene Li, Qingning Yao, Shanelle Roman, Zilin Zhang, and Dragomir Radev}, 249 | title = {Spider: A Large-Scale Human-Labeled Dataset for Complex and Cross-Domain Semantic Parsing and Text-to-SQL Task}, 250 | booktitle = {Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing}, 251 | year = {2018}, 252 | location = {Brussels, Belgium}, 253 | pages = {3911--3921}, 254 | url = {http://aclweb.org/anthology/D18-1425}, 255 | } 256 | 257 | @article{data-wikisql, 258 | author = {Victor Zhong, Caiming 
Xiong, and Richard Socher}, 259 | title = {Seq2SQL: Generating Structured Queries from Natural Language using Reinforcement Learning}, 260 | year = {2017}, 261 | journal = {CoRR}, 262 | volume = {abs/1709.00103}, 263 | } 264 | 265 | @InProceedings{Criteria-to-SQL, 266 | author = {Yu, Xiaojing and Chen, Tianlong and Yu, Zhengjie and Li, Huiyu and Yang, Yang and Jiang, Xiaoqian and Jiang, Anxiao}, 267 | title = {Dataset and Enhanced Model for Eligibility Criteria-to-SQL Semantic Parsing}, 268 | booktitle = {Proceedings of The 12th Language Resources and Evaluation Conference}, 269 | month = {May}, 270 | year = {2020}, 271 | address = {Marseille, France}, 272 | publisher = {European Language Resources Association}, 273 | pages = {5831--5839}, 274 | } 275 | 276 | @misc{zhang2023css, 277 | title = {CSS: A Large-scale Cross-schema Chinese Text-to-SQL Medical Dataset}, 278 | author = {Hanchong Zhang and Jieyu Li and Lu Chen and Ruisheng Cao and Yunyan Zhang and Yu Huang and Yefeng Zheng and Kai Yu}, 279 | year = {2023}, 280 | } 281 | 282 | @article{lee2022ehrsql, 283 | title = {EHRSQL: A Practical Text-to-SQL Benchmark for Electronic Health Records}, 284 | author = {Lee, Gyubok and Hwang, Hyeonji and Bae, Seongsu and Kwon, Yeonsu and Shin, Woncheol and Yang, Seongjun and Seo, Minjoon and Kim, Jong-Yeup and Choi, Edward}, 285 | journal = {Advances in Neural Information Processing Systems}, 286 | volume = {35}, 287 | pages = {15589--15601}, 288 | year = {2022}, 289 | } 290 | 291 | @inproceedings{lee-2021-kaggle-dbqa, 292 | title = {KaggleDBQA: Realistic Evaluation of Text-to-SQL Parsers}, 293 | author = {Lee, Chia-Hsuan and Polozov, Oleksandr and Richardson, Matthew}, 294 | booktitle = {Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (Volume 1: Long Papers)}, 295 | pages = {2261--2273}, 296 | year = {2021}, 297 | } 298 | 299 | @inproceedings{squall, 300 | title = {On the Potential of Lexico-logical Alignments for Semantic Parsing to {SQL} Queries}, 301 | author = {Tianze Shi and Chen Zhao and Jordan Boyd-Graber and Hal {Daum\'{e} III} and Lillian Lee}, 302 | booktitle = {Findings of EMNLP}, 303 | year = {2020}, 304 | } 305 | 306 | @article{hazoom2021text, 307 | title = {Text-to-SQL in the wild: a naturally-occurring dataset based on Stack exchange data}, 308 | author = {Hazoom, Moshe and Malik, Vibhor and Bogin, Ben}, 309 | journal = {arXiv preprint arXiv:2106.05006}, 310 | year = {2021}, 311 | } 312 | 313 | @inproceedings{wang2020text, 314 | title = {Text-to-SQL Generation for Question Answering on Electronic Medical Records}, 315 | author = {Wang, Ping and Shi, Tian and Reddy, Chandan K}, 316 | booktitle = {Proceedings of The Web Conference 2020}, 317 | pages = {350--361}, 318 | year = {2020}, 319 | } 320 | 321 | @inproceedings{nvBench_SIGMOD21, 322 | title = {Synthesizing Natural Language to Visualization (NL2VIS) Benchmarks from NL2SQL Benchmarks}, 323 | author = {Yuyu Luo and Nan Tang and Guoliang Li and Chengliang Chai and Wenbo Li and Xuedi Qin}, 324 | booktitle = {Proceedings of the 2021 International Conference on Management of Data, {SIGMOD} Conference 2021, June 20–25, 2021, Virtual Event, China}, 325 | publisher = {ACM}, 326 | year = {2021}, 327 | } 328 | ``` 329 | 330 | 331 | ## Acknowledgement 332 | We are appreciative to the work done by the all authors for those datasets that made this project possible. 
-------------------------------------------------------------------------------- /data_prep/README.md: -------------------------------------------------------------------------------- 1 | ## Data Preprocessing 2 | 3 | We provide scripts and instructions to create the [NSText2SQL](https://huggingface.co/datasets/NumbersStation/NSText2SQL) dataset used to train the [NSQL](https://huggingface.co/NumbersStation/nsql-6B) models. The dataset will be saved as jsonl files with the following format: 4 | 5 | ``` 6 | {"instruction": ..., "output": "...", "source": "..."} 7 | ``` 8 | 9 | #### Data Download 10 | 11 | We use datasets hosted on GitHub, Hugging Face, and other online servers. Download them by running the following commands from the `data` folder: 12 | 13 | ```bash 14 | cd data/ 15 | bash download.sh 16 | cd .. 17 | ``` 18 | 19 | To get the Spider dataset, download it from [here](https://drive.google.com/uc?export=download&id=1TqleXec_OykOYFREKKtschzY29dUcVAQ) and unzip it into the `data` folder. 20 | 21 | #### Data Preparation 22 | 23 | To preprocess the data into our format, run: 24 | 25 | ```bash 26 | python prep_text2sql_data.py \ 27 | --datasets academic \ 28 | --datasets advising \ 29 | --output_dir [OUTPUT_DIR] 30 | ``` 31 | 32 | The processed data will be saved into the `[OUTPUT_DIR]/YYYY-MM-DD` folder. The available dataset names are: 33 | - wikisql 34 | - academic 35 | - advising 36 | - atis 37 | - imdb 38 | - restaurants 39 | - scholar 40 | - yelp 41 | - sede 42 | - eicu 43 | - mimic_iii 44 | - GeoNuclearData 45 | - GreaterManchesterCrime 46 | - Pesticide 47 | - StudentMathScore 48 | - TheHistoryofBaseball 49 | - USWildFires 50 | - WhatCDHipHop 51 | - WorldSoccerDataBase 52 | - mimicsql_data 53 | - criteria2sql 54 | - sql_create_context 55 | - squall 56 | - css 57 | - spider 58 | - nvbench 59 | 60 | For more information, see `prep_text2sql_data.py`. 
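To sanity-check a preparation run, the sketch below loads the generated jsonl files and prints one example. It relies only on the layout described above (`[OUTPUT_DIR]/YYYY-MM-DD/*.jsonl` with `instruction`, `output`, and `source` fields); the concrete directory path is a hypothetical placeholder.

```python
# Minimal sketch: inspect the jsonl files produced by prep_text2sql_data.py.
# The directory below stands in for a hypothetical [OUTPUT_DIR]/YYYY-MM-DD path.
import json
from pathlib import Path

output_dir = Path("output/2023-07-14")

examples = []
for jsonl_file in sorted(output_dir.glob("*.jsonl")):
    with jsonl_file.open() as f:
        for line in f:
            # Each record holds the prompt ("instruction"), the target SQL
            # ("output"), and the originating dataset ("source").
            examples.append(json.loads(line))

print(f"Loaded {len(examples)} examples")
if examples:
    print(examples[0]["source"], "->", examples[0]["output"])
```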
-------------------------------------------------------------------------------- /data_prep/data/download.sh: -------------------------------------------------------------------------------- 1 | git clone https://github.com/jkkummerfeld/text2sql-data 2 | git clone https://github.com/hirupert/sede 3 | git clone https://github.com/glee4810/EHRSQL 4 | git clone https://github.com/chiahsuan156/KaggleDBQA 5 | 6 | git clone https://github.com/wangpinggl/TREQS 7 | cp tables.json TREQS/mimicsql_data/ # copy table schema to TREQS dataset 8 | 9 | git clone https://github.com/tzshi/squall 10 | git clone https://github.com/xiaojingyu92/Criteria2SQL 11 | 12 | wget https://huggingface.co/datasets/zhanghanchong/css/resolve/main/css.zip 13 | unzip css.zip 14 | 15 | git clone https://github.com/TsinghuaDatabaseGroup/nvBench 16 | unzip nvBench/databases.zip -d nvBench 17 | -------------------------------------------------------------------------------- /data_prep/data/tables.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "column_names": [ 4 | [ 5 | -1, 6 | "*" 7 | ], 8 | [ 9 | 0, 10 | "subject_id" 11 | ], 12 | [ 13 | 0, 14 | "hadm_id" 15 | ], 16 | [ 17 | 0, 18 | "name" 19 | ], 20 | [ 21 | 0, 22 | "marital_status" 23 | ], 24 | [ 25 | 0, 26 | "age" 27 | ], 28 | [ 29 | 0, 30 | "dob" 31 | ], 32 | [ 33 | 0, 34 | "gender" 35 | ], 36 | [ 37 | 0, 38 | "language" 39 | ], 40 | [ 41 | 0, 42 | "religion" 43 | ], 44 | [ 45 | 0, 46 | "admission_type" 47 | ], 48 | [ 49 | 0, 50 | "days_stay" 51 | ], 52 | [ 53 | 0, 54 | "insurance" 55 | ], 56 | [ 57 | 0, 58 | "ethnicity" 59 | ], 60 | [ 61 | 0, 62 | "expire_flag" 63 | ], 64 | [ 65 | 0, 66 | "admission_location" 67 | ], 68 | [ 69 | 0, 70 | "discharge_location" 71 | ], 72 | [ 73 | 0, 74 | "diagnosis" 75 | ], 76 | [ 77 | 0, 78 | "dod" 79 | ], 80 | [ 81 | 0, 82 | "dob_year" 83 | ], 84 | [ 85 | 0, 86 | "dod_year" 87 | ], 88 | [ 89 | 0, 90 | "admittime" 91 | ], 92 | [ 93 | 0, 94 | "dischtime" 95 | ], 96 | [ 97 | 0, 98 | "admityear" 99 | ], 100 | [ 101 | 1, 102 | "subject_id" 103 | ], 104 | [ 105 | 1, 106 | "hadm_id" 107 | ], 108 | [ 109 | 1, 110 | "icd9_code" 111 | ], 112 | [ 113 | 1, 114 | "short_title" 115 | ], 116 | [ 117 | 1, 118 | "long_title" 119 | ], 120 | [ 121 | 2, 122 | "subject_id" 123 | ], 124 | [ 125 | 2, 126 | "hadm_id" 127 | ], 128 | [ 129 | 2, 130 | "icd9_code" 131 | ], 132 | [ 133 | 2, 134 | "short_title" 135 | ], 136 | [ 137 | 2, 138 | "long_title" 139 | ], 140 | [ 141 | 3, 142 | "subject_id" 143 | ], 144 | [ 145 | 3, 146 | "hadm_id" 147 | ], 148 | [ 149 | 3, 150 | "icustay_id" 151 | ], 152 | [ 153 | 3, 154 | "drug_type" 155 | ], 156 | [ 157 | 3, 158 | "drug" 159 | ], 160 | [ 161 | 3, 162 | "formulary_drug_cd" 163 | ], 164 | [ 165 | 3, 166 | "route" 167 | ], 168 | [ 169 | 3, 170 | "drug_dose" 171 | ], 172 | [ 173 | 4, 174 | "subject_id" 175 | ], 176 | [ 177 | 4, 178 | "hadm_id" 179 | ], 180 | [ 181 | 4, 182 | "itemid" 183 | ], 184 | [ 185 | 4, 186 | "charttime" 187 | ], 188 | [ 189 | 4, 190 | "flag" 191 | ], 192 | [ 193 | 4, 194 | "value_unit" 195 | ], 196 | [ 197 | 4, 198 | "label" 199 | ], 200 | [ 201 | 4, 202 | "fluid" 203 | ], 204 | [ 205 | 4, 206 | "category" 207 | ] 208 | ], 209 | "column_names_original": [ 210 | [ 211 | -1, 212 | "*" 213 | ], 214 | [ 215 | 0, 216 | "subject_id" 217 | ], 218 | [ 219 | 0, 220 | "hadm_id" 221 | ], 222 | [ 223 | 0, 224 | "name" 225 | ], 226 | [ 227 | 0, 228 | "marital_status" 229 | ], 230 | [ 231 | 0, 232 | "age" 233 | ], 234 | [ 235 | 0, 236 | "dob" 237 | ], 238 | [ 239 
| 0, 240 | "gender" 241 | ], 242 | [ 243 | 0, 244 | "language" 245 | ], 246 | [ 247 | 0, 248 | "religion" 249 | ], 250 | [ 251 | 0, 252 | "admission_type" 253 | ], 254 | [ 255 | 0, 256 | "days_stay" 257 | ], 258 | [ 259 | 0, 260 | "insurance" 261 | ], 262 | [ 263 | 0, 264 | "ethnicity" 265 | ], 266 | [ 267 | 0, 268 | "expire_flag" 269 | ], 270 | [ 271 | 0, 272 | "admission_location" 273 | ], 274 | [ 275 | 0, 276 | "discharge_location" 277 | ], 278 | [ 279 | 0, 280 | "diagnosis" 281 | ], 282 | [ 283 | 0, 284 | "dod" 285 | ], 286 | [ 287 | 0, 288 | "dob_year" 289 | ], 290 | [ 291 | 0, 292 | "dod_year" 293 | ], 294 | [ 295 | 0, 296 | "admittime" 297 | ], 298 | [ 299 | 0, 300 | "dischtime" 301 | ], 302 | [ 303 | 0, 304 | "admityear" 305 | ], 306 | [ 307 | 1, 308 | "subject_id" 309 | ], 310 | [ 311 | 1, 312 | "hadm_id" 313 | ], 314 | [ 315 | 1, 316 | "icd9_code" 317 | ], 318 | [ 319 | 1, 320 | "short_title" 321 | ], 322 | [ 323 | 1, 324 | "long_title" 325 | ], 326 | [ 327 | 2, 328 | "subject_id" 329 | ], 330 | [ 331 | 2, 332 | "hadm_id" 333 | ], 334 | [ 335 | 2, 336 | "icd9_code" 337 | ], 338 | [ 339 | 2, 340 | "short_title" 341 | ], 342 | [ 343 | 2, 344 | "long_title" 345 | ], 346 | [ 347 | 3, 348 | "subject_id" 349 | ], 350 | [ 351 | 3, 352 | "hadm_id" 353 | ], 354 | [ 355 | 3, 356 | "icustay_id" 357 | ], 358 | [ 359 | 3, 360 | "drug_type" 361 | ], 362 | [ 363 | 3, 364 | "drug" 365 | ], 366 | [ 367 | 3, 368 | "formulary_drug_cd" 369 | ], 370 | [ 371 | 3, 372 | "route" 373 | ], 374 | [ 375 | 3, 376 | "drug_dose" 377 | ], 378 | [ 379 | 4, 380 | "subject_id" 381 | ], 382 | [ 383 | 4, 384 | "hadm_id" 385 | ], 386 | [ 387 | 4, 388 | "itemid" 389 | ], 390 | [ 391 | 4, 392 | "charttime" 393 | ], 394 | [ 395 | 4, 396 | "flag" 397 | ], 398 | [ 399 | 4, 400 | "value_unit" 401 | ], 402 | [ 403 | 4, 404 | "label" 405 | ], 406 | [ 407 | 4, 408 | "fluid" 409 | ], 410 | [ 411 | 4, 412 | "category" 413 | ] 414 | ], 415 | "column_types": [ 416 | "text", 417 | "text", 418 | "text", 419 | "text", 420 | "text", 421 | "text", 422 | "text", 423 | "text", 424 | "text", 425 | "text", 426 | "text", 427 | "text", 428 | "text", 429 | "text", 430 | "text", 431 | "text", 432 | "text", 433 | "text", 434 | "text", 435 | "text", 436 | "text", 437 | "text", 438 | "text", 439 | "text", 440 | "text", 441 | "text", 442 | "text", 443 | "text", 444 | "text", 445 | "text", 446 | "text", 447 | "text", 448 | "text", 449 | "text", 450 | "text", 451 | "text", 452 | "text", 453 | "text", 454 | "text", 455 | "text", 456 | "text", 457 | "text", 458 | "text", 459 | "text", 460 | "text", 461 | "text", 462 | "text", 463 | "text", 464 | "text", 465 | "text" 466 | ], 467 | "db_id": "mimicsql", 468 | "foreign_keys": [], 469 | "primary_keys": [], 470 | "table_names": [ 471 | "demographic", 472 | "diagnoses", 473 | "procedures", 474 | "prescriptions", 475 | "lab" 476 | ], 477 | "table_names_original": [ 478 | "demographic", 479 | "diagnoses", 480 | "procedures", 481 | "prescriptions", 482 | "lab" 483 | ] 484 | } 485 | ] -------------------------------------------------------------------------------- /data_prep/data_utils.py: -------------------------------------------------------------------------------- 1 | """Training data prep utils.""" 2 | import json 3 | import re 4 | from collections import defaultdict 5 | from typing import Any 6 | 7 | import sqlglot 8 | from kummerfeld_utils import preprocess_for_jsql 9 | from schema import ForeignKey, Table, TableColumn 10 | 11 | _AGG_OPS = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"] 12 | _COND_OPS = 
["=", ">", "<", "OP"] 13 | 14 | 15 | def escape_everything(string: str) -> str: 16 | """Escape everything. 17 | 18 | Args: 19 | string: string to escape 20 | 21 | Returns: 22 | Escaped string. 23 | """ 24 | return json.dumps(string)[1:-1] 25 | 26 | 27 | def serialize_dict_to_str(d: dict) -> str: 28 | """ 29 | Serialize a dict into a str. 30 | 31 | Args: 32 | d: dict to serialize. 33 | 34 | Returns: 35 | serialized dict. 36 | """ 37 | return json.dumps(d, sort_keys=True) 38 | 39 | 40 | def read_tables_json( 41 | schema_file: str, 42 | lowercase: bool = False, 43 | ) -> dict[str, dict[str, Table]]: 44 | """Read tables json.""" 45 | data = json.load(open(schema_file)) 46 | db_to_tables = {} 47 | for db in data: 48 | db_name = db["db_id"] 49 | table_names = db["table_names_original"] 50 | db["column_names_original"] = [ 51 | [x[0], x[1]] for x in db["column_names_original"] 52 | ] 53 | db["column_types"] = db["column_types"] 54 | if lowercase: 55 | table_names = [tn.lower() for tn in table_names] 56 | pks = db["primary_keys"] 57 | fks = db["foreign_keys"] 58 | tables = defaultdict(list) 59 | tables_pks = defaultdict(list) 60 | tables_fks = defaultdict(list) 61 | for idx, ((ti, col_name), col_type) in enumerate( 62 | zip(db["column_names_original"], db["column_types"]) 63 | ): 64 | if ti == -1: 65 | continue 66 | if lowercase: 67 | col_name = col_name.lower() 68 | col_type = col_type.lower() 69 | if idx in pks: 70 | tables_pks[table_names[ti]].append( 71 | TableColumn(name=col_name, dtype=col_type) 72 | ) 73 | for fk in fks: 74 | if idx == fk[0]: 75 | other_column = db["column_names_original"][fk[1]] 76 | other_column_type = db["column_types"][fk[1]] 77 | other_table = table_names[other_column[0]] 78 | tables_fks[table_names[ti]].append( 79 | ForeignKey( 80 | column=TableColumn(name=col_name, dtype=col_type), 81 | references_name=other_table, 82 | references_column=TableColumn( 83 | name=other_column[1], dtype=other_column_type 84 | ), 85 | ) 86 | ) 87 | tables[table_names[ti]].append(TableColumn(name=col_name, dtype=col_type)) 88 | db_to_tables[db_name] = { 89 | table_name: Table( 90 | name=table_name, 91 | columns=tables[table_name], 92 | pks=tables_pks[table_name], 93 | fks=tables_fks[table_name], 94 | ) 95 | for table_name in tables 96 | } 97 | return db_to_tables 98 | 99 | 100 | def clean_str(target: str) -> str: 101 | """Clean string for question.""" 102 | if not target: 103 | return target 104 | 105 | target = re.sub(r"[^\x00-\x7f]", r" ", target) 106 | line = re.sub(r"''", r" ", target) 107 | line = re.sub(r"``", r" ", line) 108 | line = re.sub(r"\"", r"'", line) 109 | line = re.sub(r"[\t ]+", " ", line) 110 | return line.strip() 111 | 112 | 113 | def case_sql(query: str) -> str: 114 | """Case sql query.""" 115 | try: 116 | cased_sql = sqlglot.parse_one(query).sql() # type: ignore 117 | # SQLGlot makes NOT IN. We want NOT IN for Spider 118 | cased_sql = re.sub(r"NOT\s+([^\s]+)\s+IN", r"\1 NOT IN", cased_sql) 119 | # Replace <> with != 120 | cased_sql = cased_sql.replace("<>", "!=") 121 | return cased_sql 122 | except Exception: 123 | print("Cannot CASE this SQL") 124 | return query 125 | 126 | 127 | def crude_remove_aliases(sql: str) -> str: 128 | """Cruder way of cleaning up aliases.""" 129 | alias2cleanalias = {} 130 | new_sql = re.sub(r"[\t\s\n]+", " ", sql) 131 | for word in sql.split(): 132 | if "." 
in word: 133 | alias = word.split(".")[0] 134 | if "alias" in alias: 135 | clean_alias = alias.split("alias")[0] + "_" + alias.split("alias")[1] 136 | alias2cleanalias[alias] = clean_alias 137 | for alias, clean_alias in alias2cleanalias.items(): 138 | new_sql = new_sql.replace(alias, clean_alias) 139 | return new_sql 140 | 141 | 142 | def remove_aliases(sql: str) -> str: 143 | """Remove aliases from SQL.""" 144 | new_sql = re.sub(r"[\t\s\n]+", " ", sql) 145 | # Handle from 146 | alias2table = {} 147 | table2alias: dict[str, list[str]] = {} 148 | # Get substring from FROM to WHERE or to GROUP BY or to end 149 | inside_from = re.search( 150 | r"FROM (.*?) (WHERE|GROUP BY|ORDER BY|LIMIT|;)", new_sql, re.DOTALL 151 | ) 152 | if not inside_from: 153 | inside_from = re.search(r"FROM (.*?)$", new_sql, re.DOTALL) 154 | if not inside_from: 155 | print("BAD FROM", sql) 156 | for from_clause in re.split( 157 | r",| INNER JOIN| OUTER JOIN| LEFT JOIN| RIGHT JOIN| JOIN| EXCEPT", 158 | inside_from.group(1), # type: ignore 159 | ): 160 | # If JOIN table ON XXX, remove the ON XXX 161 | if " ON " in from_clause: 162 | from_clause = from_clause.split(" ON ")[0] 163 | if " AS " in from_clause: 164 | table = from_clause.split(" AS ")[0].strip() 165 | alias = from_clause.split(" AS ")[1].strip() 166 | alias2table[alias] = table 167 | # If we have two of the same tables in the from clause 168 | # must keep and handle aliases differently 169 | if table in table2alias: 170 | # If only one already in, start creating new aliases 171 | if len(table2alias[table]) == 1: 172 | old_alias = table2alias[table][0] 173 | table2alias[table] = [f"{table}_{len(table2alias[table])-1}"] 174 | alias2table[old_alias] = table2alias[table][-1] 175 | table2alias[table].append(f"{table}_{len(table2alias[table])}") 176 | alias2table[alias] = table2alias[table][-1] 177 | else: 178 | table2alias[table] = [alias] 179 | # Now replace AS alias in from clauses where we can 180 | for from_clause in re.split( 181 | r",| INNER JOIN| OUTER JOIN| LEFT JOIN| RIGHT JOIN| JOIN| EXCEPT", 182 | inside_from.group(1), # type: ignore 183 | ): 184 | if " ON " in from_clause: 185 | from_clause = from_clause.split(" ON ")[0] 186 | if " AS " in from_clause: 187 | table = from_clause.split(" AS ")[0].strip() 188 | alias = from_clause.split(" AS ")[1].strip() 189 | if len(table2alias[table]) == 1: 190 | new_sql = new_sql.replace(from_clause, " " + table) 191 | 192 | # Replace old aliases with new ones (or og table name) 193 | for al, table in alias2table.items(): 194 | new_sql = new_sql.replace(al, table) 195 | 196 | # Replace table references as not needed with one table 197 | if len(alias2table) == 1: 198 | table = list(alias2table.values())[0] 199 | new_sql = new_sql.replace(table + ".", "") 200 | 201 | new_sql = re.sub(r"[\t\s\n]+", " ", new_sql) 202 | return new_sql 203 | 204 | 205 | def get_table_alias_to_ref_map(sql: str) -> dict[str, set[str]]: 206 | """Get all aliases and the reference tables they point to. 207 | 208 | Key of None will be all unaliased tables. 209 | 210 | This accounts for both table AS T1 clauses and subexpressions. 
211 | """ 212 | try: 213 | parsed: sqlglot.expressions.Expression = sqlglot.parse_one(sql, read="sqlite") 214 | except Exception: 215 | return defaultdict(set) 216 | # Get all table aliases - including CTEs 217 | mapping = defaultdict(set) 218 | all_table_aliases = list(parsed.find_all(sqlglot.exp.TableAlias)) 219 | for tbl_alias in all_table_aliases: 220 | sql_parent = tbl_alias.parent 221 | if sql_parent: 222 | tbls = [ 223 | table.name 224 | for table in sql_parent.find_all(sqlglot.exp.Table) 225 | if table.name != tbl_alias.name 226 | ] 227 | if tbls: 228 | mapping[tbl_alias.name].update(tbls) 229 | # Add any table without alias 230 | for table in parsed.find_all(sqlglot.exp.Table): 231 | if not table.alias or table.alias == table.name: 232 | mapping[None].add(table.name) 233 | return mapping 234 | 235 | 236 | def format_to_match_schema( 237 | sql: str, 238 | schema: dict[str, Table], 239 | ) -> str: 240 | """Format the tables and columns in the query to match the schema.""" 241 | table_alias_to_ref = get_table_alias_to_ref_map(sql) 242 | all_tables = set().union(*table_alias_to_ref.values()) 243 | all_tables_lower = {t.lower() for t in all_tables} 244 | tablename2colset = { 245 | tbl.name: set([c.name for c in tbl.columns]) 246 | for tbl in schema.values() 247 | if tbl.name.lower() in all_tables_lower 248 | } 249 | 250 | def transformer(node: sqlglot.Expression) -> sqlglot.Expression: 251 | if isinstance(node, sqlglot.exp.Column): 252 | for tbl in tablename2colset: 253 | for col in tablename2colset[tbl]: 254 | # Due to table aliases, we don't want to make this a joint 255 | # condition on the column and alias 256 | if node.table and node.table.lower() == tbl.lower(): 257 | node.args["table"] = tbl 258 | if node.name.lower() == col.lower(): 259 | node.args["this"] = col 260 | break 261 | elif isinstance(node, sqlglot.exp.Table): 262 | for tbl in tablename2colset: 263 | if node.name.lower() == tbl.lower(): 264 | node.args["this"] = tbl 265 | break 266 | return node 267 | 268 | parsed: sqlglot.expressions.Expression = sqlglot.parse_one(sql, read="sqlite") 269 | transformed_parsed = parsed.transform(transformer) 270 | return transformed_parsed.sql() 271 | 272 | 273 | def convert_kummerfeld_instance( 274 | data: dict[str, Any], 275 | schema: dict[str, dict[str, Table]] = {}, 276 | keep_vars: bool = False, 277 | keep_sql_vars: bool = False, 278 | ) -> list[dict[str, Any]]: 279 | """Convert a single instance of the data into a list of examples. 280 | 281 | Used for the text2sql-data repo from jkkummerfeld. 
282 | """ 283 | var_sql = None 284 | var_sql = data["sql"][0] # 285 | parsed_results: list[dict[str, Any]] = [] 286 | for sentence in data["sentences"]: 287 | text = sentence["text"] 288 | sql = preprocess_for_jsql( 289 | var_sql 290 | ) # Needed to do variable replacement correctly 291 | if not sql: 292 | raise ValueError(f"No SQL for sentence {sentence}") 293 | sql = str(sql) 294 | cleaned_sql = remove_aliases(sql) 295 | cleaned_sql = case_sql(cleaned_sql) 296 | crude_cleaned_sql = crude_remove_aliases(sql) 297 | crude_cleaned_sql = case_sql(crude_cleaned_sql) 298 | # Variable replacement 299 | if not keep_vars: 300 | for name in sentence["variables"]: 301 | value = sentence["variables"][name] 302 | if len(value) == 0: 303 | for variable in data["variables"]: 304 | if variable["name"] == name: 305 | value = variable["example"] 306 | text = value.join(text.split(name)) 307 | if not keep_sql_vars: 308 | cleaned_sql = value.join(cleaned_sql.split(name)) 309 | crude_cleaned_sql = value.join(crude_cleaned_sql.split(name)) 310 | sql = value.join(sql.split(name)) # type: ignore 311 | 312 | # Query split is either train/dev/test or 0-9 for cross validation 313 | # We use test/0 for test, dev/1 for dev and the rest for train 314 | if data["query-split"] == "N/A": 315 | # Flip a coin to decide if it's train or test or valid 316 | output_file = sentence["question-split"] 317 | else: 318 | if data["query-split"] == "test" or data["query-split"] == "0": 319 | output_file = "test" 320 | elif data["query-split"] == "dev" or data["query-split"] == "1": 321 | output_file = "dev" 322 | else: 323 | output_file = "train" 324 | 325 | db_id = sentence.get("database", "database").lower() 326 | try: 327 | cleaned_sql = format_to_match_schema(cleaned_sql, schema[db_id]) 328 | except Exception: 329 | print("ERROR") 330 | continue 331 | parsed_results.append( 332 | { 333 | "question": text, 334 | "sql": cleaned_sql, 335 | "split": output_file, 336 | "db_id": db_id, 337 | } 338 | ) 339 | 340 | return parsed_results 341 | 342 | 343 | def convert_sede_instance( 344 | data: dict[str, str], 345 | schema: dict[str, dict[str, Table]] = {}, 346 | ) -> dict[str, Any]: 347 | """Convert a single instance of the data into an example. 348 | 349 | Used for the sede dataset. 350 | """ 351 | # clean title and description 352 | cleaned_title = clean_str(data["Title"]) 353 | cleaned_description = clean_str(data["Description"]) 354 | 355 | # clean SQL query 356 | cleaned_sql = None 357 | # cleaned_sql_with_values = None 358 | if data["QueryBody"]: 359 | target_with_values = str(preprocess_for_jsql(data["QueryBody"])) 360 | target_with_values = case_sql(target_with_values) 361 | if target_with_values: 362 | target_tokens = target_with_values.strip(";").split() 363 | target = case_sql(" ".join(target_tokens)) 364 | cleaned_sql = target 365 | # Handle With statement by removing WITH part before SELECT 366 | # Example: 367 | # WITH no activity in last 6 months SELECT * FROM TABLE 368 | # --> 369 | # SELECT * FROM TABLE 370 | index_s = cleaned_sql.lower().find("select") 371 | index_w = cleaned_sql.lower().find("with") 372 | if index_w < index_s and index_w == 0: 373 | prefix = re.sub(r"\s\s+", " ", cleaned_sql[index_w:index_s][4:].strip()) 374 | # Ignore the valid CTE: With a AS(...) 375 | # Don't want to skip With a AS ... 
376 | # since no () means it won't use the defined var in the SQL 377 | if not ( 378 | prefix.lower().endswith(" as (") or prefix.lower().endswith(" as(") 379 | ): 380 | print("ORI:", cleaned_sql) 381 | print("NEW:", case_sql(cleaned_sql[index_s:])) 382 | cleaned_sql = case_sql(cleaned_sql[index_s:]) 383 | 384 | # Try to convert from TSQL to SQLite 385 | try: 386 | cleaned_sql = sqlglot.transpile(cleaned_sql, read="tsql", write="sqlite")[0] 387 | except Exception: 388 | pass 389 | 390 | if cleaned_title and cleaned_sql: 391 | try: 392 | cleaned_sql = format_to_match_schema(cleaned_sql, schema["stackexchange"]) 393 | cleaned_sql = case_sql(cleaned_sql) 394 | except Exception: 395 | print("ERROR:::", cleaned_sql) 396 | cleaned_sql = None 397 | if cleaned_sql: 398 | preprocessed_annotated_sql = { 399 | "question": ( 400 | cleaned_title.strip() + ". " + (cleaned_description or "").strip() 401 | ).strip(), 402 | "db_id": "stackexchange", 403 | "sql": cleaned_sql, 404 | } 405 | else: 406 | preprocessed_annotated_sql = {} 407 | else: 408 | preprocessed_annotated_sql = {} 409 | 410 | return preprocessed_annotated_sql 411 | 412 | 413 | def convert_spider_instance( 414 | data: dict[str, str], 415 | schema: dict[str, dict[str, Table]] = {}, 416 | ) -> dict[str, Any]: 417 | """Convert a single instance of the data into an example. 418 | 419 | Used for the spider dataset. 420 | """ 421 | query = data["query"] 422 | question = data["question"] 423 | db_id = data["db_id"] 424 | target = case_sql(query) 425 | target = format_to_match_schema(target, schema[db_id]) 426 | sql = { 427 | "question": question, 428 | "db_id": db_id, 429 | "sql": target, 430 | } 431 | # Check if example is impossible to answer 432 | if not data.get("is_impossible", False): 433 | return sql 434 | return {} 435 | 436 | 437 | def convert_wikisql_instance( 438 | data: dict[str, Any], 439 | schema: dict[str, dict[str, Table]] = {}, 440 | ) -> dict[str, Any]: 441 | """Convert a single instance of the data into an example. 442 | 443 | Used for the wikisql dataset. 444 | """ 445 | 446 | def _convert_to_human_readable( 447 | table_name: str, 448 | sel: int, 449 | agg: int, 450 | columns: list[str], 451 | conditions: list[tuple[int, int, str]], 452 | ) -> str: 453 | """Make SQL query string. 
Based on https://github.com/salesforce/WikiSQL/blob/c2ed4f9b22db1cc2721805d53e6e76e07e2ccbdc/lib/query.py#L10""" # noqa: E501 454 | strip_quotes = lambda x: x.strip('"').strip("'").replace("'", "''") 455 | quoted_columns = [f'"{escape_everything(c)}"' for c in columns] 456 | if _AGG_OPS[agg] == "": 457 | rep = f"SELECT {quoted_columns[sel] if quoted_columns is not None else f'col{sel}'} FROM {table_name}" # noqa: E501 458 | else: 459 | rep = f"SELECT {_AGG_OPS[agg]}({quoted_columns[sel] if quoted_columns is not None else f'col{sel}'}) FROM {table_name}" # noqa: E501 460 | 461 | if conditions: 462 | rep += " WHERE " + " AND ".join( 463 | [ 464 | f"{quoted_columns[i]} {_COND_OPS[o]} '{strip_quotes(v)}'" 465 | for i, o, v in conditions 466 | ] 467 | ) 468 | return " ".join(rep.split()) 469 | 470 | conds = data["sql"]["conds"] 471 | iov_list = list( 472 | zip(conds["column_index"], conds["operator_index"], conds["condition"]) 473 | ) 474 | query = _convert_to_human_readable( 475 | data["table"]["name"], 476 | data["sql"]["sel"], 477 | data["sql"]["agg"], 478 | data["table"]["header"], 479 | iov_list, 480 | ) 481 | question = data["question"] 482 | db_id = data["table"]["name"] 483 | target = case_sql(query) 484 | 485 | try: 486 | target = format_to_match_schema(target, schema[db_id]) 487 | except Exception as e: 488 | print("ERROR:::") 489 | print(target) 490 | print(e) 491 | return {} 492 | sql = { 493 | "question": question, 494 | "db_id": db_id, 495 | "sql": target, 496 | } 497 | return sql 498 | 499 | 500 | def convert_criteria2sql_instance( 501 | data: dict[str, Any], 502 | schema: dict[str, dict[str, Table]] = {}, 503 | ) -> dict[str, Any]: 504 | """Convert a single instance of the data into an example. 505 | 506 | Modified from the criteria2sql dataset. 
507 | """ 508 | # We want to use the 'real' table name and all columns in the query 509 | assert data["query"].startswith("select id from records") 510 | query = data["query"] 511 | query = query.replace("select id from records", f"select * from {data['db_id']}") 512 | question = data["question"] 513 | db_id = data["db_id"] 514 | target = case_sql(query) 515 | 516 | try: 517 | target = format_to_match_schema(target, schema[db_id]) 518 | except Exception as e: 519 | print("ERROR:::") 520 | print(target) 521 | print(e) 522 | return {} 523 | 524 | sql = { 525 | "question": question, 526 | "db_id": db_id, 527 | "sql": target, 528 | } 529 | return sql 530 | 531 | 532 | def convert_sql_create_context_instance( 533 | data: dict[str, Any], 534 | schema: dict[str, dict[str, Table]] = {}, 535 | ) -> dict[str, Any]: 536 | """Convert a single instance of the data into an example.""" 537 | query = data["answer"] 538 | question = data["question"] 539 | db_id = data["db_id"] 540 | target = case_sql(query) 541 | 542 | try: 543 | target = format_to_match_schema(target, schema[db_id]) 544 | except Exception as e: 545 | print("ERROR:::") 546 | print(target) 547 | print(e) 548 | return {} 549 | 550 | sql = { 551 | "question": question, 552 | "db_id": db_id, 553 | "sql": target, 554 | } 555 | return sql 556 | 557 | 558 | def convert_squall_instance( 559 | data: dict[str, Any], 560 | schema: dict[str, dict[str, Table]] = {}, 561 | ) -> dict[str, Any]: 562 | """Convert a single instance of the data into an example.""" 563 | db_id = data["db_id"] 564 | question = " ".join(data["nl"]) 565 | sql_toks = [] 566 | for tok in data["sql"]: 567 | if tok[0] in ["Literal.Number", "Literal.String"]: 568 | sql_toks.append(tok[1]) 569 | elif tok[0] == "Keyword": 570 | sql_toks.append(tok[1] if tok[1] != "w" else db_id) 571 | else: 572 | if "_" in tok[1]: 573 | idx = int(tok[1][1 : tok[1].find("_")]) 574 | else: 575 | idx = int(tok[1][1]) 576 | sql_toks.append(schema[db_id][db_id].columns[idx].name) 577 | query = " ".join(sql_toks) 578 | # Fix not null error 579 | query = query.replace("not null", "is not null") 580 | target = case_sql(query) 581 | 582 | try: 583 | target = format_to_match_schema(target, schema[db_id]) 584 | except Exception as e: 585 | print("ERROR:::") 586 | print(target) 587 | print(e) 588 | return {} 589 | 590 | sql = { 591 | "question": question, 592 | "db_id": db_id, 593 | "sql": target, 594 | } 595 | return sql 596 | 597 | 598 | def convert_css_nvbench_instance( 599 | data: dict[str, Any], 600 | schema: dict[str, dict[str, Table]] = {}, 601 | ) -> dict[str, Any]: 602 | """Convert a single instance of the data into an example.""" 603 | db_id = data["db_id"] 604 | question = data["question"] 605 | query = data["query"] 606 | target = case_sql(query) 607 | try: 608 | target = format_to_match_schema(target, schema[db_id]) 609 | except Exception as e: 610 | print("ERROR:::") 611 | print(target) 612 | print(e) 613 | return {} 614 | 615 | sql = { 616 | "question": question, 617 | "db_id": db_id, 618 | "sql": target, 619 | } 620 | return sql 621 | -------------------------------------------------------------------------------- /data_prep/kummerfeld_utils.py: -------------------------------------------------------------------------------- 1 | """SQL processing utils. 
2 | 3 | Adapted from https://github.com/hirupert/sede 4 | """ 5 | import re 6 | 7 | ALIAS_PATTERN = re.compile(r"\[([^\]]+)]", re.MULTILINE | re.IGNORECASE) 8 | TAGS_PATTERN = re.compile( 9 | r"([^'%])(##[a-z0-9_?:]+##)([^'%]?)", re.MULTILINE | re.IGNORECASE 10 | ) 11 | TOP_TAGS_PATTERN = re.compile( 12 | r"(top|percentile_cont)([ ]+)?[\(]?[ ]?(##[a-z0-9_]+(:[a-z]+)?(\?([0-9.]+))?##)[ ]?[\)]?", 13 | re.IGNORECASE, 14 | ) 15 | SQL_TOKENS = { 16 | "select", 17 | "from", 18 | "where", 19 | "group", 20 | "order", 21 | "limit", 22 | "intersect", 23 | "union", 24 | "except", 25 | "join", 26 | "on", 27 | "as", 28 | "not", 29 | "between", 30 | "=", 31 | ">", 32 | "<", 33 | ">=", 34 | "<=", 35 | "!=", 36 | "in", 37 | "like", 38 | "is", 39 | "exists", 40 | "none", 41 | "max", 42 | "min", 43 | "count", 44 | "sum", 45 | "avg", 46 | "or", 47 | "and", 48 | } 49 | 50 | 51 | def _remove_comment_at_beginning(cleaned_query: str) -> str: 52 | """Remove comments at the beginning of the line.""" 53 | return re.sub(r"^([- ]+|(result))+", "", cleaned_query, re.MULTILINE) 54 | 55 | 56 | def remove_comments(sql: str) -> str: 57 | """Remove comments from sql.""" "" 58 | # remove comments at the beginning of line 59 | sql = _remove_comment_at_beginning(sql) 60 | 61 | # remove comments at the end of lines 62 | sql = re.sub(r"--(.+)?\n", "", sql) 63 | 64 | # remove comments at the end of lines 65 | sql = re.sub(r"\n;\n", " ", sql) 66 | 67 | sql = re.sub(" +", " ", sql) 68 | 69 | return sql.strip() 70 | 71 | 72 | def remove_comments_after_removing_new_lines(sql: str) -> str: 73 | """Remove comments and newlines from sql.""" 74 | # remove comments at the end of the query 75 | sql = re.sub(r"--(.?)+$", "", sql, re.MULTILINE) 76 | 77 | # remove comments like /* a comment */ 78 | sql = re.sub(r"/\*[^*/]+\*/", "", sql, re.MULTILINE) 79 | 80 | sql = re.sub(" +", " ", sql) 81 | 82 | return sql.strip() 83 | 84 | 85 | def _surrounded_by_apostrophes(sql: str, start_index: int, end_index: int) -> bool: 86 | """Check if the string is surrounded by apostrophes.""" 87 | max_steps = 10 88 | 89 | starts_with_apostrophe = False 90 | step_count = 0 91 | while start_index >= 0 and step_count < max_steps: 92 | if sql[start_index] == "'": 93 | starts_with_apostrophe = True 94 | break 95 | if sql[start_index] == " ": 96 | starts_with_apostrophe = False 97 | break 98 | start_index -= 1 99 | step_count += 1 100 | 101 | end_with_apostrophe = False 102 | step_count = 0 103 | while end_index < len(sql) and step_count < max_steps: 104 | if sql[end_index] == "'": 105 | end_with_apostrophe = True 106 | break 107 | if sql[end_index] == " ": 108 | end_with_apostrophe = False 109 | break 110 | end_index += 1 111 | step_count += 1 112 | 113 | return starts_with_apostrophe and end_with_apostrophe 114 | 115 | 116 | # pylint: disable=too-many-branches 117 | def preprocess_for_jsql(sql: str) -> str | None: 118 | """Preprocess sql for jsql.""" 119 | # replace all alias like "as [User Id]" to "as 'user_id'" 120 | match = re.search(ALIAS_PATTERN, sql) 121 | while match is not None: 122 | group_one = match.group(1) 123 | if not _surrounded_by_apostrophes(sql, match.start(), match.end()): 124 | new_alias = f"'{group_one.lower()}'" 125 | else: 126 | new_alias = group_one.lower() 127 | 128 | if " " in new_alias: 129 | new_alias = new_alias.replace(" ", "_") 130 | sql = sql.replace(match.group(0), new_alias) 131 | match = re.search(ALIAS_PATTERN, sql) 132 | 133 | # replace all parameters like "TOP ##topn:int?200##" to "TOP 200" 134 | match = 
re.search(TOP_TAGS_PATTERN, sql) 135 | while match is not None: 136 | group_zero = match.group(0) 137 | default_number = match.group(6) 138 | 139 | if default_number is not None: 140 | new_alias = f"{match.group(1)} ({default_number})" 141 | else: 142 | new_alias = f"{match.group(1)} (100)" 143 | 144 | sql = sql.replace(group_zero, new_alias) 145 | match = re.search(TOP_TAGS_PATTERN, sql) 146 | 147 | # replace all parameters like ##tagName:Java## to '##tagName:Java##' 148 | new_sql = "" 149 | match = re.search(TAGS_PATTERN, sql) 150 | while match is not None: 151 | group_two = match.group(2) 152 | 153 | if not _surrounded_by_apostrophes(sql, match.start(), match.end()): 154 | new_alias = f"{match.group(1)}'{group_two}'{match.group(3)}" 155 | new_sql = new_sql + sql[0 : match.start()] + new_alias 156 | else: 157 | new_sql = new_sql + sql[0 : match.start()] + match.group(0) 158 | 159 | sql = sql[match.end() :] 160 | match = re.search(TAGS_PATTERN, sql) 161 | if sql: 162 | new_sql = new_sql + sql 163 | sql = new_sql 164 | 165 | # convert FORMAT function to CONVERT function to support JSQL 166 | sql = re.sub(r" format\(", " convert(", sql, flags=re.IGNORECASE) 167 | 168 | # remove comments from SQL 169 | sql = remove_comments(sql) 170 | 171 | # replace N'%Kitchener%' with '%Kitchener%' 172 | sql = re.sub(r" N'", " '", sql, re.IGNORECASE) 173 | 174 | # remove declares with a new line 175 | sql = re.sub( 176 | r"(DECLARE|declare) [^\n]+\n", 177 | " ", 178 | sql, 179 | re.IGNORECASE | re.MULTILINE | re.DOTALL, 180 | ) 181 | 182 | # remove new lines 183 | sql = re.sub(r"[\n\t\r]+", " ", sql) 184 | 185 | sql = remove_comments_after_removing_new_lines(sql) 186 | 187 | # remove declares 188 | sql = re.sub( 189 | r"(DECLARE|declare) [^;]+;", " ", sql, re.IGNORECASE | re.MULTILINE | re.DOTALL 190 | ) 191 | sql = re.sub(r"(DECLARE|declare) (?:.(?!(SELECT|select)))", "SELECT", sql) 192 | 193 | if "))))))))))))))))))))" in sql or "((((((((((((((((((((" in sql: 194 | return None 195 | 196 | if "cast(avg(cast(avg(cast(avg(cast(avg(cast(avg(cast(avg(cast(avg(" in sql: 197 | return None 198 | 199 | sql = re.sub(r"[^\x00-\x7f]", r" ", sql) 200 | sql = re.sub(r"``", r"'", sql) 201 | sql = re.sub(r"\"", r"'", sql) 202 | sql = re.sub(r" +", " ", sql).strip() 203 | 204 | if not sql: 205 | return None 206 | 207 | if sql[-1] == ";": 208 | sql = sql[0:-1] 209 | 210 | if ";" in sql: 211 | sql = sql.split(";")[-1] 212 | 213 | return sql 214 | -------------------------------------------------------------------------------- /data_prep/prep_text2sql_data.py: -------------------------------------------------------------------------------- 1 | """Prepare data for NSText2SQL.""" 2 | 3 | import hashlib 4 | import json 5 | import multiprocessing 6 | import os 7 | import random 8 | from collections import defaultdict 9 | from datetime import datetime 10 | from functools import partial 11 | from pathlib import Path 12 | 13 | import click 14 | import numpy as np 15 | import yaml 16 | from prompt_formatters import RajkumarFormatter 17 | from rich.console import Console 18 | from text2sql_dataset import ( 19 | CSS2SQL, 20 | NVBENCH2SQL, 21 | Criteria2SQL2SQL, 22 | KummerfeldText2SQL, 23 | MimicsqlText2SQL, 24 | SedeText2SQL, 25 | SpiderText2SQL, 26 | SqlCreateContext2SQL, 27 | Squall2SQL, 28 | Text2SQLData, 29 | Text2SQLDataset, 30 | WikiSQL2SQL, 31 | ) 32 | from tqdm.auto import tqdm 33 | from transformers import AutoTokenizer 34 | 35 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 36 | 37 | console = Console(soft_wrap=True) 38 | 
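# Registry mapping the `loader` field in text2sql_data_config.yaml to the
# Text2SQLDataset subclass (see text2sql_dataset.py) that loads and converts
# that source.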
39 | 40 | TEXT2SQL_DATA_LOADERS = { 41 | "kummerfeld": KummerfeldText2SQL, 42 | "sede": SedeText2SQL, 43 | "spider": SpiderText2SQL, 44 | "wikisql": WikiSQL2SQL, 45 | "mimicsql": MimicsqlText2SQL, 46 | "criteria2sql": Criteria2SQL2SQL, 47 | "sql_create_context2sql": SqlCreateContext2SQL, 48 | "squall": Squall2SQL, 49 | "css": CSS2SQL, 50 | "nvbench": NVBENCH2SQL, 51 | } 52 | 53 | 54 | def process_dataset( 55 | prompt_formatter: RajkumarFormatter, 56 | splits: dict[str, list[Text2SQLData]], 57 | bad_parses: dict[str, int], 58 | total: dict[str, int], 59 | text2sql_dataset: Text2SQLDataset, 60 | hash_experiment_key: str, 61 | ) -> None: 62 | """Process a dataset and add it to the splits.""" 63 | schema = text2sql_dataset.load_schema() 64 | temp_outfile = f"_temp_text2sql/prep_data/temp_{hash_experiment_key}.jsonl" 65 | Path(temp_outfile).parent.mkdir(parents=True, exist_ok=True) 66 | if os.path.exists(temp_outfile): 67 | console.print(f"Reading from {temp_outfile}") 68 | with open(temp_outfile, "r") as in_f: 69 | loaded_data = json.load(in_f) 70 | else: 71 | loaded_data = text2sql_dataset.load_data(schema) 72 | console.print(f"Saving to {temp_outfile}") 73 | with open(temp_outfile, "w") as out_f: 74 | json.dump(loaded_data, out_f) 75 | 76 | formatting_func = partial( 77 | text2sql_dataset.format_example, 78 | schema=schema, 79 | prompt_formatter=prompt_formatter, 80 | ) 81 | cnt = 0 82 | for split, split_data in loaded_data.items(): 83 | console.print(f"Found {len(split_data)} examples for {split}.") 84 | pool = multiprocessing.Pool( 85 | processes=15, 86 | ) 87 | for formatted_data in tqdm( 88 | pool.imap(formatting_func, split_data, chunksize=100), 89 | total=len(split_data), 90 | desc=f"Formatting {split}", 91 | ): 92 | total[split] += 1 93 | cnt += 1 94 | if formatted_data: 95 | formatted_as_traindata = Text2SQLData(**formatted_data) 96 | splits[split].append(Text2SQLData(**formatted_data)) 97 | if total[split] <= 20 or cnt <= 20: 98 | console.print(f"\n***[yellow]Example {total[split]}[/yellow]***") 99 | console.print( 100 | json.dumps( 101 | formatted_as_traindata.dict(), indent=2, ensure_ascii=False 102 | ) 103 | ) 104 | else: 105 | bad_parses[split] += 1 106 | pool.close() 107 | pool.join() 108 | console.print(f"Bad parses: {json.dumps(bad_parses, indent=2, ensure_ascii=False)}") 109 | 110 | 111 | @click.command() 112 | @click.option("--datasets", type=str, required=True, multiple=True) 113 | @click.option( 114 | "--config_path", 115 | type=str, 116 | default=(f"{os.path.join(os.path.dirname(__file__))}/text2sql_data_config.yaml"), 117 | ) 118 | @click.option("--output_dir", type=str, default="") 119 | @click.option("--seed", type=int, default=0) 120 | @click.option("--tokenizer_name", type=str, default="Salesforce/codegen-2B-multi") 121 | @click.option("--seq_length", type=int, default=2048) 122 | @click.option("--merge_dev", type=bool, default=False, is_flag=True) 123 | @click.option("--merge_test", type=bool, default=False, is_flag=True) 124 | def build( 125 | datasets: list[str], 126 | config_path: str, 127 | output_dir: str, 128 | seed: int, 129 | tokenizer_name: str, 130 | seq_length: int, 131 | merge_dev: bool, 132 | merge_test: bool, 133 | ) -> None: 134 | """Build training data for text2SQL model training. 
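    Example invocation (dataset names and output path are illustrative):

        python prep_text2sql_data.py \
            --datasets spider --datasets wikisql \
            --output_dir prepped_data --merge_dev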
135 | 136 | Args: 137 | datasets: the datasets to read - matches on name in config 138 | config_path: path to config 139 | output_dir: output directory 140 | seed: the random seed 141 | tokenizer_name: the tokenizer to use 142 | seq_length: max seq_length for training data 143 | merge_dev: whether merge dev in to train 144 | merge_test: whether merge test in to train 145 | """ 146 | to_save_args = locals() 147 | random.seed(seed) 148 | np.random.seed(seed) 149 | try: 150 | multiprocessing.set_start_method("spawn") 151 | except RuntimeError: 152 | pass 153 | 154 | prompt_formatter = RajkumarFormatter() 155 | 156 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 157 | 158 | config = yaml.safe_load(open(config_path)) 159 | to_save_args["config"] = config 160 | 161 | data_configs = {} 162 | for data_loader in config["datasets"]: 163 | data_configs[data_loader["name"]] = data_loader 164 | assert len(data_configs) == len( 165 | config["datasets"] 166 | ), f"Overloaded name in {config_path}" 167 | 168 | datasets = [dataset.lower() for dataset in datasets] 169 | for dataset in datasets: 170 | if dataset not in data_configs: 171 | raise ValueError(f"Dataset {dataset} not supported.") 172 | if data_configs[dataset]["loader"] not in TEXT2SQL_DATA_LOADERS: 173 | raise ValueError(f"Loader {data_configs[dataset]['loader']} not supported.") 174 | 175 | data_classes: list[Text2SQLDataset] = [ 176 | TEXT2SQL_DATA_LOADERS[data_configs[dataset]["loader"]]( # type: ignore 177 | **data_configs[dataset], 178 | context_length=seq_length, 179 | tokenizer_name=tokenizer_name, 180 | ) 181 | for dataset in datasets 182 | ] 183 | 184 | splits: dict[str, list[Text2SQLData]] = {"train": [], "dev": [], "test": []} 185 | bad_parses: dict[str, int] = defaultdict(int) 186 | total: dict[str, int] = defaultdict(int) 187 | 188 | for data_class in data_classes: 189 | console.print(f"[green]Loading[/green] {data_class.name}") 190 | to_hash_args = to_save_args.copy() 191 | to_hash_args["dataset_name"] = data_class.name 192 | hash_experiment_key = hashlib.sha256( 193 | json.dumps(to_hash_args, sort_keys=True).encode("utf-8") 194 | ).hexdigest() 195 | process_dataset( 196 | prompt_formatter=prompt_formatter, 197 | splits=splits, 198 | bad_parses=bad_parses, 199 | total=total, 200 | text2sql_dataset=data_class, 201 | hash_experiment_key=hash_experiment_key, 202 | ) 203 | 204 | if merge_dev: 205 | splits["train"].extend(splits["dev"]) 206 | splits["dev"] = [] 207 | if merge_test: 208 | splits["train"].extend(splits["test"]) 209 | splits["test"] = [] 210 | 211 | date = datetime.now().strftime("%Y-%m-%d") 212 | joined_output_dir = Path(output_dir) / date 213 | joined_output_dir.mkdir(parents=True, exist_ok=True) 214 | 215 | console.print(f"Starting length of train: {len(splits['train'])}") 216 | 217 | # Deduplicate training data 218 | unq_inps = set() 219 | new_train = [] 220 | for ex in splits["train"]: 221 | if ex.instruction not in unq_inps: 222 | new_train.append(ex) 223 | unq_inps.add(ex.instruction) 224 | splits["train"] = new_train 225 | 226 | console.print(f"After dedup length of train: {len(splits['train'])}") 227 | 228 | # Get token size statistics 229 | tokenized_inputs = tokenizer(list(map(lambda x: x.instruction, splits["train"]))) 230 | tokenized_outputs = tokenizer(list(map(lambda x: x.output, splits["train"]))) 231 | input_lengths = [len(x) for x in tokenized_inputs["input_ids"]] 232 | output_lengths = [len(x) for x in tokenized_outputs["input_ids"]] 233 | sum_lengths = [x + y for x, y in zip(input_lengths, 
output_lengths)] 234 | 235 | console.print( 236 | f"Max input length: {max(input_lengths)}, " 237 | f"Median length: {np.median(input_lengths)}, " 238 | f"90th percentile: {np.percentile(input_lengths, 90)}" 239 | ) 240 | console.print( 241 | f"Max output length: {max(output_lengths)}, " 242 | f"Median length: {np.median(output_lengths)}, " 243 | f"90th percentile: {np.percentile(output_lengths, 90)}" 244 | ) 245 | console.print( 246 | f"Percent overflow: {100*sum(x > seq_length for x in sum_lengths)/len(sum_lengths):.2f}" 247 | ) 248 | console.print( 249 | f"Max sum length: {max(sum_lengths)}, " 250 | f"Median length: {np.median(sum_lengths)}, " 251 | f"85th percentile: {np.percentile(sum_lengths, 85)}" 252 | f"90th percentile: {np.percentile(sum_lengths, 90)}" 253 | f"95th percentile: {np.percentile(sum_lengths, 95)}" 254 | ) 255 | 256 | # Save the data 257 | random.seed(seed) 258 | random.shuffle(splits["train"]) 259 | 260 | for split in splits: 261 | console.print( 262 | f"Found {bad_parses[split]} bad parses out of " 263 | f"{total[split]} ({100*bad_parses[split]/max(total[split], 1): .2f})." 264 | ) 265 | console.print( 266 | f"Saving [green]{split} ({len(splits[split])}) " 267 | f"[/green] data to {joined_output_dir}/{split}.jsonl" 268 | ) 269 | with open(joined_output_dir / f"{split}.jsonl", "w") as f: 270 | for formatted_ex in splits[split]: 271 | f.write(json.dumps(formatted_ex.dict(), ensure_ascii=False) + "\n") 272 | with open(f"{joined_output_dir}/config.json", "w") as f: 273 | json.dump(to_save_args, f, indent=4) 274 | 275 | 276 | if __name__ == "__main__": 277 | build() 278 | -------------------------------------------------------------------------------- /data_prep/prompt_formatters.py: -------------------------------------------------------------------------------- 1 | """Rajkumar prompt formatter.""" 2 | 3 | from abc import ABC 4 | from random import shuffle 5 | 6 | from schema import Table 7 | 8 | 9 | class RajkumarFormatter(ABC): 10 | """RajkumarFormatter class. 11 | 12 | From https://arxiv.org/pdf/2204.00498.pdf. 
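    Prompts are rendered as the CREATE TABLE statements followed by a SQLite
    instruction comment and the question, e.g. (illustrative schema and question):

        CREATE TABLE singer (
            singer_id number,
            name text
        )


        -- Using valid SQLite, answer the following questions for the tables provided above.

        -- How many singers are there?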
13 | """ 14 | 15 | table_sep: str = "\n\n" 16 | shuffle_table_order: bool = True 17 | _cache: dict[tuple[str, str, str], list[str]] = {} 18 | 19 | @classmethod 20 | def format_table(cls, table: Table) -> str: 21 | """Get table format.""" 22 | table_fmt = [] 23 | for col in table.columns or []: 24 | # This is technically an incorrect type, but it should be a catchall word 25 | table_fmt.append(f" {col.name} {col.dtype or 'any'}") 26 | if table_fmt: 27 | all_cols = ",\n".join(table_fmt) 28 | create_tbl = f"CREATE TABLE {table.name} (\n{all_cols}\n)" 29 | else: 30 | create_tbl = f"CREATE TABLE {table.name}" 31 | return create_tbl 32 | 33 | @classmethod 34 | def format_all_tables(cls, tables: list[Table], instruction: str) -> list[str]: 35 | """Get all tables format.""" 36 | table_texts = [cls.format_table(table) for table in tables] 37 | key = ("tables", instruction, str(tables)) 38 | if key not in cls._cache: 39 | shuffle(table_texts) 40 | cls._cache[key] = table_texts 41 | else: 42 | table_texts = cls._cache[key] 43 | return table_texts 44 | 45 | @classmethod 46 | def format_prompt( 47 | cls, 48 | instruction: str, 49 | table_text: str, 50 | ) -> str: 51 | """Get prompt format.""" 52 | return f"""{table_text}\n\n\n-- Using valid SQLite, answer the following questions for the tables provided above.\n\n-- {instruction}\n""" # noqa: E501 53 | 54 | @classmethod 55 | def format_model_output(cls, output_sql: str, prompt: str) -> str: 56 | """Format model output.""" 57 | return output_sql 58 | 59 | @classmethod 60 | def format_gold_output(cls, output_sql: str) -> str: 61 | """Format gold output for demonstration.""" 62 | return output_sql 63 | -------------------------------------------------------------------------------- /data_prep/schema.py: -------------------------------------------------------------------------------- 1 | """Text2SQL schemas.""" 2 | 3 | from pydantic import BaseModel 4 | 5 | 6 | class TableColumn(BaseModel): 7 | """Table column.""" 8 | 9 | name: str 10 | dtype: str | None 11 | 12 | 13 | class ForeignKey(BaseModel): 14 | """Foreign key.""" 15 | 16 | # Referenced column 17 | column: TableColumn 18 | # References table name 19 | references_name: str 20 | # References column 21 | references_column: TableColumn 22 | 23 | 24 | class Table(BaseModel): 25 | """Table.""" 26 | 27 | name: str | None 28 | columns: list[TableColumn] | None 29 | pks: list[TableColumn] | None 30 | # FK from this table to another column in another table 31 | fks: list[ForeignKey] | None 32 | examples: list[dict] | None 33 | # Is the table a source or intermediate reference table 34 | is_reference_table: bool = False 35 | -------------------------------------------------------------------------------- /data_prep/text2sql_data_config.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | - loader: kummerfeld 3 | name: academic 4 | train_data_file: data/text2sql-data/data/academic.json 5 | val_data_file: null 6 | test_data_file: null 7 | schema_file: data/text2sql-data/data/academic-schema.csv 8 | 9 | - loader: kummerfeld 10 | name: advising 11 | train_data_file: data/text2sql-data/data/advising.json 12 | val_data_file: null 13 | test_data_file: null 14 | schema_file: data/text2sql-data/data/advising-schema.csv 15 | 16 | - loader: kummerfeld 17 | name: atis 18 | train_data_file: data/text2sql-data/data/atis.json 19 | val_data_file: null 20 | test_data_file: null 21 | schema_file: data/text2sql-data/data/atis-schema.csv 22 | 23 | - loader: kummerfeld 24 | name: 
geography 25 | train_data_file: data/text2sql-data/data/geography.json 26 | val_data_file: null 27 | test_data_file: null 28 | schema_file: data/text2sql-data/data/geography-schema.csv 29 | 30 | - loader: kummerfeld 31 | name: imdb 32 | train_data_file: data/text2sql-data/data/imdb.json 33 | val_data_file: null 34 | test_data_file: null 35 | schema_file: data/text2sql-data/data/imdb-schema.csv 36 | 37 | - loader: kummerfeld 38 | name: restaurants 39 | train_data_file: data/text2sql-data/data/restaurants.json 40 | val_data_file: null 41 | test_data_file: null 42 | schema_file: data/text2sql-data/data/restaurants-schema.csv 43 | 44 | - loader: kummerfeld 45 | name: scholar 46 | train_data_file: data/text2sql-data/data/scholar.json 47 | val_data_file: null 48 | test_data_file: null 49 | schema_file: data/text2sql-data/data/scholar-schema.csv 50 | 51 | - loader: kummerfeld 52 | name: yelp 53 | train_data_file: data/text2sql-data/data/yelp.json 54 | val_data_file: null 55 | test_data_file: null 56 | schema_file: data/text2sql-data/data/yelp-schema.csv 57 | 58 | - loader: spider 59 | name: spider 60 | train_data_file: data/spider/train_spider.json 61 | val_data_file: data/spider/dev.json 62 | test_data_file: null 63 | schema_file: data/spider/tables.json 64 | 65 | - loader: sede 66 | name: sede 67 | train_data_file: data/sede/data/sede/train.jsonl 68 | val_data_file: data/sede/data/sede/val.jsonl 69 | test_data_file: data/sede/data/sede/test.jsonl 70 | schema_file: data/sede/stackexchange_schema/tables_so.json 71 | 72 | - loader: wikisql 73 | name: wikisql 74 | train_data_file: null 75 | val_data_file: null 76 | test_data_file: null 77 | schema_file: null 78 | 79 | - loader: spider 80 | name: eicu 81 | train_data_file: data/EHRSQL/dataset/ehrsql/eicu/train.json 82 | val_data_file: data/EHRSQL/dataset/ehrsql/eicu/valid.json 83 | test_data_file: null 84 | schema_file: data/EHRSQL/dataset/ehrsql/tables.json 85 | 86 | - loader: spider 87 | name: mimic_iii 88 | train_data_file: data/EHRSQL/dataset/ehrsql/mimic_iii/train.json 89 | val_data_file: data/EHRSQL/dataset/ehrsql/mimic_iii/valid.json 90 | test_data_file: null 91 | schema_file: data/EHRSQL/dataset/ehrsql/tables.json 92 | 93 | - loader: spider 94 | name: geonucleardata 95 | train_data_file: data/KaggleDBQA/examples/GeoNuclearData.json@data/KaggleDBQA/examples/GeoNuclearData_fewshot.json 96 | val_data_file: null 97 | test_data_file: data/KaggleDBQA/examples/GeoNuclearData_test.json 98 | schema_file: data/KaggleDBQA/KaggleDBQA_tables.json 99 | 100 | - loader: spider 101 | name: greatermanchestercrime 102 | train_data_file: data/KaggleDBQA/examples/GreaterManchesterCrime.json@data/KaggleDBQA/examples/GreaterManchesterCrime_fewshot.json 103 | val_data_file: null 104 | test_data_file: data/KaggleDBQA/examples/GreaterManchesterCrime_test.json 105 | schema_file: data/KaggleDBQA/KaggleDBQA_tables.json 106 | 107 | - loader: spider 108 | name: pesticide 109 | train_data_file: data/KaggleDBQA/examples/Pesticide.json@data/KaggleDBQA/examples/Pesticide_fewshot.json 110 | val_data_file: null 111 | test_data_file: data/KaggleDBQA/examples/Pesticide_test.json 112 | schema_file: data/KaggleDBQA/KaggleDBQA_tables.json 113 | 114 | - loader: spider 115 | name: studentmathscore 116 | train_data_file: data/KaggleDBQA/examples/StudentMathScore.json@data/KaggleDBQA/examples/StudentMathScore_fewshot.json 117 | val_data_file: null 118 | test_data_file: data/KaggleDBQA/examples/StudentMathScore_test.json 119 | schema_file: data/KaggleDBQA/KaggleDBQA_tables.json 120 | 121 
| - loader: spider 122 | name: thehistoryofbaseball 123 | train_data_file: data/KaggleDBQA/examples/TheHistoryofBaseball.json@data/KaggleDBQA/examples/TheHistoryofBaseball_fewshot.json 124 | val_data_file: null 125 | test_data_file: data/KaggleDBQA/examples/TheHistoryofBaseball_test.json 126 | schema_file: data/KaggleDBQA/KaggleDBQA_tables.json 127 | 128 | - loader: spider 129 | name: uswildfires 130 | train_data_file: data/KaggleDBQA/examples/USWildFires.json@data/KaggleDBQA/examples/USWildFires_fewshot.json 131 | val_data_file: null 132 | test_data_file: data/KaggleDBQA/examples/USWildFires_test.json 133 | schema_file: data/KaggleDBQA/KaggleDBQA_tables.json 134 | 135 | - loader: spider 136 | name: whatcdhiphop 137 | train_data_file: data/KaggleDBQA/examples/WhatCDHipHop.json@data/KaggleDBQA/examples/WhatCDHipHop_fewshot.json 138 | val_data_file: null 139 | test_data_file: data/KaggleDBQA/examples/WhatCDHipHop_test.json 140 | schema_file: data/KaggleDBQA/KaggleDBQA_tables.json 141 | num_demos: 0 142 | num_copies: 1 143 | 144 | - loader: spider 145 | name: worldsoccerdatabase 146 | train_data_file: data/KaggleDBQA/examples/WorldSoccerDataBase.json@data/KaggleDBQA/examples/WorldSoccerDataBase_fewshot.json 147 | val_data_file: null 148 | test_data_file: data/KaggleDBQA/examples/WorldSoccerDataBase_test.json 149 | schema_file: data/KaggleDBQA/KaggleDBQA_tables.json 150 | num_demos: 0 151 | num_copies: 1 152 | 153 | - loader: mimicsql 154 | name: mimicsql_data 155 | train_data_file: data/TREQS/mimicsql_data/mimicsql_natural_v2/train.json@data/TREQS/mimicsql_data/mimicsql_template/train.json 156 | val_data_file: data/TREQS/mimicsql_data/mimicsql_natural_v2/dev.json@data/TREQS/mimicsql_data/mimicsql_template/dev.json 157 | test_data_file: data/TREQS/mimicsql_data/mimicsql_natural_v2/test.json@data/TREQS/mimicsql_data/mimicsql_template/test.json 158 | schema_file: data/TREQS/mimicsql_data/tables.json 159 | 160 | - loader: criteria2sql 161 | name: criteria2sql 162 | train_data_file: data/Criteria2SQL/data/train.jsonl 163 | val_data_file: data/Criteria2SQL/data/dev.jsonl 164 | test_data_file: data/Criteria2SQL/data/test.jsonl 165 | train_schema_file: data/Criteria2SQL/data/train.tables.jsonl 166 | val_schema_file: data/Criteria2SQL/data/dev.tables.jsonl 167 | test_schema_file: data/Criteria2SQL/data/test.tables.jsonl 168 | schema_file: null 169 | 170 | - loader: sql_create_context2sql 171 | name: sql_create_context 172 | train_data_file: null 173 | val_data_file: null 174 | test_data_file: null 175 | schema_file: null 176 | 177 | - loader: squall 178 | name: squall 179 | train_data_file: data/squall/data/squall.json 180 | val_data_file: null 181 | test_data_file: null 182 | schema_file: null 183 | 184 | - loader: css 185 | name: css 186 | train_data_file: example.train@template.train@schema.train 187 | val_data_file: example.dev@template.dev@schema.dev 188 | test_data_file: example.test@template.test@schema.test 189 | schema_file: data/css/tables.json 190 | 191 | - loader: nvbench 192 | name: nvbench 193 | train_data_file: data/nvBench/NVBench.json 194 | val_data_file: null 195 | test_data_file: null 196 | schema_file: data/nvBench/database/ 197 | -------------------------------------------------------------------------------- /data_prep/text2sql_dataset.py: -------------------------------------------------------------------------------- 1 | """Text2SQL dataset class.""" 2 | 3 | import copy 4 | import json 5 | import os 6 | import sqlite3 7 | from abc import ABC, abstractmethod 8 | from functools 
import partial 9 | from glob import glob 10 | from pathlib import Path 11 | from typing import Any 12 | 13 | import jsonlines 14 | import sqlglot 15 | from data_utils import ( 16 | clean_str, 17 | convert_criteria2sql_instance, 18 | convert_css_nvbench_instance, 19 | convert_kummerfeld_instance, 20 | convert_sede_instance, 21 | convert_spider_instance, 22 | convert_sql_create_context_instance, 23 | convert_squall_instance, 24 | convert_wikisql_instance, 25 | escape_everything, 26 | read_tables_json, 27 | serialize_dict_to_str, 28 | ) 29 | from datasets import load_dataset 30 | from prompt_formatters import RajkumarFormatter 31 | from pydantic import BaseModel 32 | from rich.console import Console 33 | from schema import ForeignKey, Table, TableColumn 34 | from sqlglot import parse_one 35 | from tqdm.auto import tqdm 36 | from transformers import AutoTokenizer 37 | 38 | console = Console(soft_wrap=True) 39 | 40 | 41 | class Text2SQLData(BaseModel): 42 | """Text2SQL data class.""" 43 | 44 | instruction: str 45 | output: str 46 | source: str 47 | 48 | 49 | class Text2SQLDataset(ABC): 50 | """Text2SQL dataset class.""" 51 | 52 | def __init__( 53 | self, 54 | name: str, 55 | train_data_file: str, 56 | val_data_file: str, 57 | test_data_file: str, 58 | schema_file: str, 59 | context_length: int, 60 | tokenizer_name: str, 61 | **kwargs: Any, 62 | ) -> None: 63 | """Initialize.""" 64 | self.name = name 65 | self.train_data_file = train_data_file 66 | self.val_data_file = val_data_file 67 | self.test_data_file = test_data_file 68 | self.schema_file = schema_file 69 | self.context_length = context_length 70 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 71 | self.clean_question = True 72 | self.process_init_kwargs(**kwargs) 73 | 74 | def process_init_kwargs(self, **kwargs: Any) -> None: 75 | """Process init kwargs.""" 76 | pass 77 | 78 | @abstractmethod 79 | def load_data( 80 | self, schema: dict[str, dict[str, Table]] 81 | ) -> dict[str, list[dict[str, Any]]]: 82 | """Load data.""" 83 | raise NotImplementedError 84 | 85 | @abstractmethod 86 | def load_schema(self) -> dict[str, dict[str, Table]]: 87 | """Load schema.""" 88 | raise NotImplementedError 89 | 90 | def _is_parseable(self, sql: str) -> bool: 91 | try: 92 | res: sqlglot.expressions.Expression | None = parse_one(sql, read="sqlite") 93 | return res is not None 94 | except Exception: 95 | return False 96 | 97 | def _format_example( 98 | self, 99 | ex: dict[str, Any], 100 | schema: dict[str, dict[str, Table]], 101 | prompt_formatter: RajkumarFormatter, 102 | gold_sql_key: str, 103 | ) -> tuple[str, str] | None: 104 | if not self._is_parseable(ex[gold_sql_key]): 105 | print("BAD:::", ex[gold_sql_key]) 106 | return None 107 | 108 | db_id = ex.get("db_id", "database") 109 | db_schema = schema[db_id] 110 | tables_to_add = list(db_schema.keys()) 111 | 112 | if self.clean_question: 113 | question = clean_str(ex["question"]).strip("'").strip('"') 114 | else: 115 | question = ex["question"].strip("'").strip('"') 116 | table_text = prompt_formatter.table_sep.join( 117 | prompt_formatter.format_all_tables( 118 | [db_schema[t] for t in tables_to_add], question 119 | ) 120 | ) 121 | 122 | input_str = prompt_formatter.format_prompt(question, table_text) 123 | output_str = prompt_formatter.format_gold_output(ex[gold_sql_key]) 124 | return input_str, output_str 125 | 126 | def format_example( 127 | self, 128 | example: dict[str, Any], 129 | schema: dict[str, dict[str, Table]], 130 | prompt_formatter: RajkumarFormatter, 131 | ) -> dict[str, 
Any] | None: 132 | """Format example.""" 133 | 134 | result = self._format_example( 135 | example, 136 | schema, 137 | prompt_formatter, 138 | "sql", 139 | ) 140 | if not result: 141 | return None 142 | input_str, output_str = result 143 | input_str = input_str.strip() + "\n" 144 | output_str = output_str.strip() 145 | data_ex = dict( 146 | instruction=input_str, 147 | output=output_str, 148 | source=self.name, 149 | ) 150 | return data_ex 151 | 152 | 153 | class KummerfeldText2SQL(Text2SQLDataset): 154 | """Kummerfeld text2sql dataset from the text2sql-data repo.""" 155 | 156 | def load_data( 157 | self, schema: dict[str, dict[str, Table]] 158 | ) -> dict[str, list[dict[str, Any]]]: 159 | """Load data.""" 160 | data_pathobj = Path(self.train_data_file) 161 | splits: dict[str, list[dict[str, Any]]] = {"train": [], "dev": [], "test": []} 162 | for raw_ex in tqdm(json.load(data_pathobj.open()), desc="Loading data"): 163 | for ex in convert_kummerfeld_instance(raw_ex, schema=schema): 164 | if ex: 165 | splits[ex["split"]].append(ex) 166 | return splits 167 | 168 | def mine_for_fks( 169 | self, data: list[list[str]], header: list[str] 170 | ) -> dict[str, list[tuple[str, str, list[str]]]]: 171 | """Mine for fks from schema.""" 172 | # The Is Foreign Key column is not always correct so mine via exact match 173 | cur_tablename = None 174 | cur_database = None 175 | schema: dict = {} 176 | cur_table: dict[str, list] = {} 177 | for ex in data: 178 | if len(header) != len(ex): 179 | ex = ex[: len(header)] 180 | row = {h: r.strip() for h, r in zip(header, ex)} 181 | # Keep the type as only the first key 182 | # e.g. varchar(255) default null -> varchar 183 | table_name = row["Table Name"].lower() 184 | field_name = row["Field Name"].lower() 185 | database_name = row.get("Database name", "database").lower() 186 | if ( 187 | table_name == "-" 188 | or table_name != cur_tablename 189 | or database_name != cur_database 190 | ): 191 | if len(cur_table) > 0: 192 | if cur_database not in schema: 193 | schema[cur_database] = {} 194 | schema[cur_database][cur_tablename] = cur_table 195 | cur_table = {} 196 | cur_database = None 197 | cur_tablename = None 198 | if cur_tablename is None and table_name != "-": 199 | cur_tablename = table_name 200 | cur_table = { 201 | "columns": [], 202 | "pks": [], 203 | } 204 | if cur_database is None and database_name != "-": 205 | cur_database = database_name 206 | if cur_tablename is not None: 207 | assert cur_database is not None 208 | cur_table["columns"].append(field_name) 209 | if row["Is Primary Key"].strip().lower() in [ 210 | "yes", 211 | "true", 212 | "y", 213 | "t", 214 | "pri", 215 | ]: 216 | cur_table["pks"].append(field_name) 217 | 218 | # Add last table 219 | assert cur_database is not None 220 | assert cur_tablename is not None 221 | schema[cur_database][cur_tablename] = cur_table 222 | 223 | # Find Fks by matching on field_name 224 | fks: dict[str, list[tuple[str, str, list[str]]]] = {} 225 | for database in schema: 226 | fks[database] = [] 227 | for referenced_table in schema[database]: 228 | # Only want one key per column 229 | used_columns = set() 230 | for references_table in schema[database]: 231 | if referenced_table == references_table: 232 | continue 233 | # Find all columns in referenced table that are in references table 234 | matching_cols = [ 235 | c 236 | for c in schema[database][referenced_table]["columns"] 237 | if c in schema[database][references_table]["columns"] 238 | and c not in used_columns 239 | ] 240 | matching_pk_cols = [ 241 | c 242 
| for c in matching_cols 243 | if c in schema[database][references_table]["pks"] 244 | and c.lower() != "id" 245 | ] 246 | used_columns.update(matching_cols) 247 | if len(matching_pk_cols) > 0: 248 | # Use the fk 249 | fks[database].append( 250 | (referenced_table, references_table, matching_pk_cols) 251 | ) 252 | return fks 253 | 254 | def load_schema(self) -> dict[str, dict[str, Table]]: 255 | """Load schema for each table in the database.""" 256 | schema_pathobj = Path(self.schema_file) 257 | # Header is Table Name, Field Name, Is Primary Key, Is Foreign Key, Type 258 | data = [l.strip().split(",") for l in schema_pathobj.open().readlines()] 259 | header = [h.strip() for h in data[0]] 260 | data = data[1:] 261 | 262 | all_fks = self.mine_for_fks(data, header) 263 | schema: dict[str, dict[str, Table]] = {} 264 | cur_tablename = None 265 | cur_database = None 266 | cur_table: dict[str, list] = {} 267 | for ex in data: 268 | if len(header) != len(ex): 269 | ex = ex[: len(header)] 270 | row = {h: r.strip() for h, r in zip(header, ex)} 271 | # Keep the type as only the first key 272 | # e.g. varchar(255) default null -> varchar 273 | row_type = row["Type"].split("(")[0].lower() 274 | table_name = row["Table Name"].lower() 275 | field_name = row["Field Name"].lower() 276 | database_name = row.get("Database name", "database").lower() 277 | if ( 278 | table_name == "-" 279 | or table_name != cur_tablename 280 | or database_name != cur_database 281 | ): 282 | if len(cur_table) > 0: 283 | assert cur_database is not None 284 | if cur_database not in schema: 285 | schema[cur_database] = {} # type: ignore 286 | schema[cur_database][cur_tablename] = Table( # type: ignore 287 | name=cur_tablename, 288 | columns=[ 289 | TableColumn(name=cn, dtype=ct) 290 | for cn, ct in cur_table["columns"] 291 | ], 292 | pks=[ 293 | TableColumn(name=cn, dtype=ct) 294 | for cn, ct in cur_table["pks"] 295 | ], 296 | fks=[ 297 | ForeignKey( 298 | column=TableColumn(name=cn, dtype=ct), 299 | references_name=rtn, 300 | references_column=TableColumn(name=rn, dtype=rt), 301 | ) 302 | for ((cn, ct), rtn, (rn, rt)) in cur_table["fks"] 303 | ], 304 | examples=[], 305 | ) 306 | cur_table = {} 307 | cur_database = None 308 | cur_tablename = None 309 | if cur_tablename is None and table_name != "-": 310 | cur_tablename = table_name 311 | cur_table = { 312 | "columns": [], 313 | "pks": [], 314 | "fks": [], 315 | } 316 | if cur_database is None and database_name != "-": 317 | cur_database = database_name 318 | if cur_tablename is not None: 319 | assert cur_database is not None 320 | cur_table["columns"].append((field_name, row_type)) 321 | if row["Is Primary Key"].strip().lower() in [ 322 | "yes", 323 | "true", 324 | "y", 325 | "t", 326 | "pri", 327 | ]: 328 | cur_table["pks"].append((field_name, row_type)) 329 | for fk_tuple in all_fks[cur_database]: 330 | # referenced_table, references_table, matching_pk_cols 331 | if fk_tuple[0] == cur_tablename and field_name in fk_tuple[2]: 332 | cur_table["fks"].append( 333 | ( 334 | (field_name, row_type), 335 | fk_tuple[1], 336 | (field_name, row_type), 337 | ) 338 | ) 339 | 340 | # Add last table 341 | assert cur_database is not None 342 | assert cur_tablename is not None 343 | schema[cur_database][cur_tablename] = Table( 344 | name=cur_tablename, 345 | columns=[TableColumn(name=cn, dtype=ct) for cn, ct in cur_table["columns"]], 346 | pks=[TableColumn(name=cn, dtype=ct) for cn, ct in cur_table["pks"]], 347 | fks=[ 348 | ForeignKey( 349 | column=TableColumn(name=cn, dtype=ct), 350 | 
references_name=rtn, 351 | references_column=TableColumn(name=rn, dtype=rt), 352 | ) 353 | for ((cn, ct), rtn, (rn, rt)) in cur_table["fks"] 354 | ], 355 | examples=[], 356 | ) 357 | return schema 358 | 359 | 360 | class SedeText2SQL(Text2SQLDataset): 361 | """Sede text2sql dataset from the text2sql-data repo.""" 362 | 363 | def load_data( 364 | self, schema: dict[str, dict[str, Table]] 365 | ) -> dict[str, list[dict[str, Any]]]: 366 | """Load data.""" 367 | splits: dict[str, list[dict[str, Any]]] = {"train": [], "dev": [], "test": []} 368 | for split in splits: 369 | if split == "dev": 370 | to_read_split = self.val_data_file 371 | elif split == "test": 372 | to_read_split = self.test_data_file 373 | else: 374 | to_read_split = self.train_data_file 375 | data_file = Path(to_read_split) 376 | for line in tqdm( 377 | data_file.open().readlines(), desc=f"Loading {split} data" 378 | ): 379 | raw_ex = json.loads(line) 380 | ex = convert_sede_instance(raw_ex, schema=schema) 381 | if ex: 382 | splits[split].append(ex) 383 | return splits 384 | 385 | def load_schema(self) -> dict[str, dict[str, Table]]: 386 | """Load schema for each table in the database.""" 387 | schema_dct = read_tables_json(self.schema_file) 388 | return schema_dct 389 | 390 | 391 | class SpiderText2SQL(Text2SQLDataset): 392 | """Spider text2sql dataset adapted from Huggingface/Picard.""" 393 | 394 | def load_data( 395 | self, schema: dict[str, dict[str, Table]] 396 | ) -> dict[str, list[dict[str, Any]]]: 397 | """Load data.""" 398 | splits: dict[str, list[dict[str, Any]]] = {"train": [], "dev": [], "test": []} 399 | all_data_for_demos: dict[str, list[dict[str, Any]]] = { 400 | "train": [], 401 | "dev": [], 402 | "test": [], 403 | } 404 | for split in splits: 405 | if split in "dev": 406 | to_read_files = [Path(self.val_data_file)] if self.val_data_file else [] 407 | elif split == "train": 408 | to_read_files = [Path(p) for p in self.train_data_file.split("@")] 409 | elif split == "test": 410 | to_read_files = ( 411 | [Path(self.test_data_file)] if self.test_data_file else [] 412 | ) 413 | else: 414 | to_read_files = [] 415 | console.print(f"Loading {split} data", style="bold blue") 416 | for file in to_read_files: 417 | data_file = Path(file) 418 | try: 419 | data = json.load(data_file.open()) 420 | except json.decoder.JSONDecodeError: 421 | data = [json.loads(line) for line in data_file.open()] 422 | convert_function = partial( 423 | convert_spider_instance, 424 | schema=schema, 425 | ) 426 | for ex in tqdm( 427 | map( 428 | convert_function, 429 | data, 430 | ), 431 | desc=f"Loading {split} data from {data_file.name}", 432 | total=len(data), 433 | ): 434 | if ex: 435 | splits[split].append(ex) 436 | all_data_for_demos[split].append(copy.deepcopy(ex)) 437 | 438 | return splits 439 | 440 | def load_schema(self) -> dict[str, dict[str, Table]]: 441 | """Load schema for each table in the database.""" 442 | schema_dct = read_tables_json(self.schema_file, lowercase=True) 443 | return schema_dct 444 | 445 | 446 | class WikiSQL2SQL(Text2SQLDataset): 447 | """WikiSQL text2sql dataset from the text2sql-data repo.""" 448 | 449 | example2table_name: dict[str, str] = {} 450 | 451 | def load_schema(self) -> dict[str, dict[str, Table]]: 452 | """Load schema for each table in the database.""" 453 | self.dataset = load_dataset("wikisql") 454 | schema_dct: dict[str, dict[str, Table]] = {} 455 | for split in sorted(self.dataset): 456 | for ex in tqdm(self.dataset[split], desc=f"Loading {split} data schema"): 457 | table = ex["table"] 458 | key 
= serialize_dict_to_str(ex) 459 | if key not in self.example2table_name: 460 | self.example2table_name[ 461 | key 462 | ] = f"table_{len(self.example2table_name)}" 463 | table_name = self.example2table_name[key] 464 | # Quote column names to handle spaces 465 | column_names = [ 466 | f'"{escape_everything(col)}"' for col in table["header"] 467 | ] 468 | if table_name in schema_dct: 469 | continue 470 | columns = [ 471 | TableColumn(name=col, dtype=typ) 472 | for col, typ in zip(column_names, table["types"]) 473 | ] 474 | examples = [ 475 | {column_names[i]: row[i] for i in range(len(row))} 476 | for row in table["rows"] 477 | ] 478 | if table_name not in schema_dct: 479 | schema_dct[table_name] = {} 480 | # WikiSQL uses table_name for both db and table name 481 | schema_dct[table_name][table_name] = Table( 482 | name=table_name, columns=columns, examples=examples 483 | ) 484 | return schema_dct 485 | 486 | def load_data( 487 | self, schema: dict[str, dict[str, Table]] 488 | ) -> dict[str, list[dict[str, Any]]]: 489 | """Load data.""" 490 | splits: dict[str, list[dict[str, Any]]] = {"train": [], "dev": [], "test": []} 491 | for split in splits: 492 | split_to_use = split 493 | if split == "dev": 494 | split_to_use = "validation" 495 | for line in tqdm(self.dataset[split_to_use], desc=f"Loading {split} data"): 496 | key = serialize_dict_to_str(line) 497 | line_to_use = line 498 | line_to_use["table"]["name"] = self.example2table_name[key] 499 | ex = convert_wikisql_instance(line_to_use, schema=schema) 500 | if ex: 501 | splits[split].append(ex) 502 | console.print( 503 | f"Loaded {split} data: {len(splits[split])} over {len(self.dataset[split_to_use])}" 504 | ) 505 | return splits 506 | 507 | def _get_table_schema(self, table: Table) -> list[str]: 508 | "Get table schema from Table." 
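        # Lowercase and sort the column names to get an order-insensitive view
        # of the table's schema.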
509 | return sorted([_.name.lower() for _ in table.columns]) if table.columns else [] 510 | 511 | 512 | class MimicsqlText2SQL(Text2SQLDataset): 513 | """Mimicsql text2sql dataset from the TREQS repo.""" 514 | 515 | def load_data( 516 | self, schema: dict[str, dict[str, Table]] 517 | ) -> dict[str, list[dict[str, Any]]]: 518 | """Load data.""" 519 | splits: dict[str, list[dict[str, Any]]] = {"train": [], "dev": [], "test": []} 520 | data_file_mapping = { 521 | "train": self.train_data_file, 522 | "dev": self.val_data_file, 523 | "test": self.test_data_file, 524 | } 525 | for split in splits: 526 | to_read_files = ( 527 | [Path(p) for p in data_file_mapping[split].split("@")] 528 | if split in data_file_mapping 529 | else [] 530 | ) 531 | console.print(f"Loading {split} data", style="bold blue") 532 | 533 | for file in to_read_files: 534 | data_file = Path(file) 535 | try: 536 | data = json.load(data_file.open()) 537 | except json.decoder.JSONDecodeError: 538 | data = [json.loads(line) for line in data_file.open()] 539 | # Convert the data to spider compatible 540 | for i in range(len(data)): 541 | data[i]["db_id"] = "mimicsql" 542 | data[i]["question"] = data[i]["question_refine"] 543 | data[i]["query"] = data[i]["sql"] 544 | convert_function = partial( 545 | convert_spider_instance, 546 | schema=schema, 547 | ) 548 | for ex in tqdm( 549 | map( 550 | convert_function, 551 | data, 552 | ), 553 | desc=f"Loading {split} data from {data_file.name}", 554 | total=len(data), 555 | ): 556 | if ex: 557 | splits[split].append(ex) 558 | 559 | return splits 560 | 561 | def load_schema(self) -> dict[str, dict[str, Table]]: 562 | """Load schema for each table in the database.""" 563 | schema_dct = read_tables_json(self.schema_file) 564 | return schema_dct 565 | 566 | 567 | class Criteria2SQL2SQL(Text2SQLDataset): 568 | """Criteria2SQL text2sql dataset from https://github.com/xiaojingyu92/Criteria2SQL.""" 569 | 570 | def process_init_kwargs(self, **kwargs: Any) -> None: 571 | """Process kwargs.""" 572 | self.train_schema_file = kwargs.pop("train_schema_file", "") 573 | self.val_schema_file = kwargs.pop("val_schema_file", "") 574 | self.test_schema_file = kwargs.pop("test_schema_file", "") 575 | 576 | def load_schema(self) -> dict[str, dict[str, Table]]: 577 | """Load schema for each table in the database.""" 578 | schema_dct: dict[str, dict[str, Table]] = {} 579 | schema_file_mapping = { 580 | "train": self.train_schema_file, 581 | "dev": self.val_schema_file, 582 | "test": self.test_schema_file, 583 | } 584 | 585 | for split in sorted(schema_file_mapping): 586 | with jsonlines.open(schema_file_mapping[split], "r") as f: 587 | for table in tqdm( 588 | [line for line in f], desc=f"Loading {split} data schema" 589 | ): 590 | table_name = f"table_{split}_{table['id']}" 591 | # Quote column names to handle spaces 592 | column_names = [ 593 | f'"{escape_everything(col)}"' for col in table["header"] 594 | ] 595 | columns = [ 596 | TableColumn(name=col, dtype=typ) 597 | for col, typ in zip(column_names, table["types"]) 598 | ] 599 | examples = [ 600 | {column_names[i]: row[i] for i in range(len(row))} 601 | for row in table["rows"] 602 | ] 603 | if table_name not in schema_dct: 604 | schema_dct[table_name] = {} 605 | # Criteria2SQL is similar to WikiSQL and it uses table_name for 606 | # both db and table name 607 | schema_dct[table_name][table_name] = Table( 608 | name=table_name, columns=columns, examples=examples 609 | ) 610 | return schema_dct 611 | 612 | def load_data( 613 | self, schema: dict[str, 
dict[str, Table]] 614 | ) -> dict[str, list[dict[str, Any]]]: 615 | """Load data.""" 616 | splits: dict[str, list[dict[str, Any]]] = {"train": [], "dev": [], "test": []} 617 | data_file_mapping = { 618 | "train": self.train_data_file, 619 | "dev": self.val_data_file, 620 | "test": self.test_data_file, 621 | } 622 | 623 | for split in sorted(data_file_mapping): 624 | with jsonlines.open(data_file_mapping[split], "r") as f: 625 | all_samples = [line for line in f] 626 | for line in tqdm(all_samples, desc=f"Loading {split} data"): 627 | line_to_use = line 628 | line_to_use["db_id"] = f"table_{split}_{line['table_id']}" 629 | ex = convert_criteria2sql_instance(line_to_use, schema=schema) 630 | if ex: 631 | splits[split].append(ex) 632 | console.print( 633 | f"Loaded {split} data: {len(splits[split])} over {len(all_samples)}" 634 | ) 635 | return splits 636 | 637 | 638 | class SqlCreateContext2SQL(Text2SQLDataset): 639 | """sql-create-context text2sql dataset from huggingface.""" 640 | 641 | def load_schema(self) -> dict[str, dict[str, Table]]: 642 | """Load schema for each table in the database.""" 643 | schema_dct: dict[str, dict[str, Table]] = {} 644 | self.dataset = load_dataset("b-mc2/sql-create-context") 645 | for db_id, ex in tqdm(enumerate(self.dataset["train"])): 646 | for table_context in ex["context"].split(";"): 647 | table_context = table_context.strip() 648 | assert table_context.startswith("CREATE TABLE ") 649 | table_context = table_context[len("CREATE TABLE ") :].strip() 650 | table_name = table_context[: table_context.find("(")].strip() 651 | col_context = table_context[len(table_name) :].strip()[1:-1] 652 | cols = [col.strip().split(" ") for col in col_context.split(",")] 653 | columns = [TableColumn(name=col, dtype=typ) for col, typ in cols] 654 | 655 | if db_id not in schema_dct: 656 | schema_dct[db_id] = {} 657 | if table_name not in schema_dct[db_id]: 658 | schema_dct[db_id][table_name] = Table( 659 | name=table_name, columns=columns 660 | ) 661 | return schema_dct 662 | 663 | def load_data( 664 | self, schema: dict[str, dict[str, Table]] 665 | ) -> dict[str, list[dict[str, Any]]]: 666 | """Load data.""" 667 | splits: dict[str, list[dict[str, Any]]] = {"train": [], "dev": [], "test": []} 668 | 669 | for split in sorted(self.dataset): 670 | for db_id, ex in tqdm( 671 | enumerate(self.dataset[split]), desc=f"Loading {split} data" 672 | ): 673 | line_to_use = ex 674 | line_to_use["db_id"] = db_id 675 | ex = convert_sql_create_context_instance(line_to_use, schema=schema) 676 | if ex: 677 | splits[split].append(ex) 678 | console.print( 679 | f"Loaded {split} data: {len(splits[split])} over {len(self.dataset[split])}" 680 | ) 681 | return splits 682 | 683 | 684 | class Squall2SQL(Text2SQLDataset): 685 | """Squall text2sql dataset from huggingface.""" 686 | 687 | def load_schema(self) -> dict[str, dict[str, Table]]: 688 | """Load schema for each table in the database.""" 689 | schema_dct: dict[str, dict[str, Table]] = {} 690 | self.data = json.load(open(self.train_data_file, "r")) 691 | for i, ex in enumerate(self.data): 692 | table_name = f"table_{ex['tbl']}" 693 | cols = [["id", "number"]] + [ 694 | [ 695 | f'"{escape_everything(col[0])}"', 696 | col[3] if col[3] == "number" else "text", 697 | ] 698 | for col in ex["columns"] 699 | ] 700 | # Skip the table with duplicate column names 701 | if len(set([col[0] for col in cols])) != len(cols): 702 | continue 703 | # Skip the table with empty column name 704 | if '""' in [col[0] for col in cols]: 705 | continue 706 | columns = 
[TableColumn(name=col, dtype=typ) for col, typ in cols] 707 | if table_name not in schema_dct: 708 | schema_dct[table_name] = {} 709 | if table_name not in schema_dct[table_name]: 710 | schema_dct[table_name][table_name] = Table( 711 | name=table_name, columns=columns 712 | ) 713 | return schema_dct 714 | 715 | def load_data( 716 | self, schema: dict[str, dict[str, Table]] 717 | ) -> dict[str, list[dict[str, Any]]]: 718 | """Load data.""" 719 | splits: dict[str, list[dict[str, Any]]] = {"train": [], "dev": [], "test": []} 720 | split = "train" 721 | for i, ex in enumerate(self.data): 722 | line_to_use = ex 723 | line_to_use["db_id"] = f"table_{ex['tbl']}" 724 | if line_to_use["db_id"] not in schema: 725 | continue 726 | ex = convert_squall_instance(line_to_use, schema=schema) 727 | if ex: 728 | splits[split].append(ex) 729 | console.print( 730 | f"Loaded {split} data: {len(splits[split])} over {len(self.data)}" 731 | ) 732 | return splits 733 | 734 | 735 | class CSS2SQL(Text2SQLDataset): 736 | """CSS2SQL text2sql dataset from huggingface.""" 737 | 738 | def process_init_kwargs(self, **kwargs: Any) -> None: 739 | """Process kwargs.""" 740 | self.clean_question = False 741 | 742 | def load_schema(self) -> dict[str, dict[str, Table]]: 743 | """Load schema for each table in the database.""" 744 | schema_dct = read_tables_json(self.schema_file) 745 | return schema_dct 746 | 747 | def load_data( 748 | self, schema: dict[str, dict[str, Table]] 749 | ) -> dict[str, list[dict[str, Any]]]: 750 | """Load data.""" 751 | splits: dict[str, list[dict[str, Any]]] = {"train": [], "dev": [], "test": []} 752 | ds = load_dataset("zhanghanchong/css") 753 | data_split_mapping = { 754 | "train": self.train_data_file, 755 | "dev": self.val_data_file, 756 | "test": self.test_data_file, 757 | } 758 | for split in splits: 759 | split_cnt = 0 760 | if data_split_mapping[split] is None: 761 | continue 762 | ex_cache: set = set([]) 763 | to_read_splits = ( 764 | [p for p in data_split_mapping[split].split("@")] 765 | if split in data_split_mapping 766 | else [] 767 | ) 768 | console.print(f"Loading {split} data", style="bold blue") 769 | for spt in to_read_splits: 770 | split_cnt += len(ds[spt]) 771 | for line in ds[spt]: 772 | line["question"] = line["question"].strip() 773 | ex = convert_css_nvbench_instance(line, schema=schema) 774 | if ex: 775 | key = f"{ex['db_id']}###{ex['question']}" 776 | if key in ex_cache: 777 | continue 778 | ex_cache.add(key) 779 | splits[split].append(ex) 780 | console.print(f"Loaded {split} data: {len(splits[split])} over {split_cnt}") 781 | return splits 782 | 783 | 784 | class NVBENCH2SQL(Text2SQLDataset): 785 | """NVBENCH2SQL text2sql dataset from https://github.com/TsinghuaDatabaseGroup/nvBench.""" 786 | 787 | def load_schema(self) -> dict[str, dict[str, Table]]: 788 | """Load schema for each table in the database.""" 789 | schema_dct: dict[str, dict[str, Table]] = {} 790 | 791 | db_files = [ 792 | file 793 | for file in glob(self.schema_file + "**", recursive=True) 794 | if file.endswith(".sqlite") 795 | ] 796 | 797 | for db_file in db_files: 798 | db_id = os.path.basename(os.path.dirname(db_file)) 799 | if db_id not in schema_dct: 800 | schema_dct[db_id] = {} 801 | 802 | # Connect to the SQLite database 803 | conn = sqlite3.connect(db_file) 804 | 805 | # Create a cursor object to execute SQL queries 806 | cursor = conn.cursor() 807 | 808 | # Get the list of tables in the database 809 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") 810 | tables = 
cursor.fetchall() 811 | 812 | # Iterate over the tables and retrieve their schemas 813 | for table in tables: 814 | table_name = table[0] 815 | # Execute a PRAGMA query to get the schema of the table 816 | cursor.execute(f"PRAGMA table_info({table_name});") 817 | schema = cursor.fetchall() 818 | 819 | # Get the schema details 820 | columns = [ 821 | TableColumn( 822 | name=column[1] if " " not in column[1] else f'"{column[1]}"', 823 | dtype=column[2], 824 | ) 825 | for column in schema 826 | ] 827 | 828 | if table_name not in schema_dct[db_id]: 829 | schema_dct[db_id][table_name] = Table( 830 | name=table_name, columns=columns 831 | ) 832 | 833 | # Close the cursor and the database connection 834 | cursor.close() 835 | conn.close() 836 | 837 | return schema_dct 838 | 839 | def load_data( 840 | self, schema: dict[str, dict[str, Table]] 841 | ) -> dict[str, list[dict[str, Any]]]: 842 | """Load data.""" 843 | splits: dict[str, list[dict[str, Any]]] = {"train": [], "dev": [], "test": []} 844 | data_file_mapping = { 845 | "train": self.train_data_file, 846 | "dev": self.val_data_file, 847 | "test": self.test_data_file, 848 | } 849 | for split in splits: 850 | split_cnt = 0 851 | if data_file_mapping[split] is None: 852 | continue 853 | ex_cache: set = set([]) 854 | to_read_files = ( 855 | [Path(p) for p in data_file_mapping[split].split("@")] 856 | if split in data_file_mapping 857 | else [] 858 | ) 859 | console.print(f"Loading {split} data", style="bold blue") 860 | for data_file in to_read_files: 861 | data = json.load(open(data_file, "r")) 862 | for k, v in data.items(): 863 | for nl_query in v["nl_queries"]: 864 | split_cnt += 1 865 | if len(nl_query.strip()) == 0: 866 | continue 867 | line = { 868 | "db_id": v["db_id"], 869 | "query": v["vis_query"]["data_part"]["sql_part"].strip(), 870 | "question": nl_query.strip(), 871 | } 872 | ex = convert_css_nvbench_instance(line, schema=schema) 873 | if ex: 874 | key = f"{ex['db_id']}###{ex['question']}" 875 | if key in ex_cache: 876 | continue 877 | ex_cache.add(key) 878 | splits[split].append(ex) 879 | console.print(f"Loaded {split} data: {len(splits[split])} over {split_cnt}") 880 | return splits 881 | -------------------------------------------------------------------------------- /examples/colab.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "execution": { 8 | "iopub.execute_input": "2023-07-26T17:54:03.363360Z", 9 | "iopub.status.busy": "2023-07-26T17:54:03.363013Z", 10 | "iopub.status.idle": "2023-07-26T17:54:03.384950Z", 11 | "shell.execute_reply": "2023-07-26T17:54:03.384339Z", 12 | "shell.execute_reply.started": "2023-07-26T17:54:03.363334Z" 13 | }, 14 | "tags": [] 15 | }, 16 | "outputs": [], 17 | "source": [ 18 | "%load_ext autoreload\n", 19 | "%autoreload 2" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": { 26 | "execution": { 27 | "iopub.execute_input": "2023-07-26T17:54:03.554498Z", 28 | "iopub.status.busy": "2023-07-26T17:54:03.554003Z", 29 | "iopub.status.idle": "2023-07-26T17:54:03.565041Z", 30 | "shell.execute_reply": "2023-07-26T17:54:03.564332Z", 31 | "shell.execute_reply.started": "2023-07-26T17:54:03.554473Z" 32 | } 33 | }, 34 | "outputs": [], 35 | "source": [ 36 | "# Install transformer if you don't have it installed\n", 37 | "# ! 
pip install transformers" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "This is a standalone notebook that can be imported into colab to run." 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": {}, 50 | "source": [ 51 | "## Model Setup" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 3, 57 | "metadata": { 58 | "execution": { 59 | "iopub.execute_input": "2023-07-26T17:54:04.109420Z", 60 | "iopub.status.busy": "2023-07-26T17:54:04.108996Z", 61 | "iopub.status.idle": "2023-07-26T17:54:04.120388Z", 62 | "shell.execute_reply": "2023-07-26T17:54:04.119708Z", 63 | "shell.execute_reply.started": "2023-07-26T17:54:04.109396Z" 64 | }, 65 | "tags": [] 66 | }, 67 | "outputs": [], 68 | "source": [ 69 | "model_name = \"NumbersStation/nsql-350M\" # <-- You can switch to other models like \"NumbersStation/nsql-6B\"" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 4, 75 | "metadata": { 76 | "execution": { 77 | "iopub.execute_input": "2023-07-26T17:54:04.286491Z", 78 | "iopub.status.busy": "2023-07-26T17:54:04.286167Z", 79 | "iopub.status.idle": "2023-07-26T17:54:09.554296Z", 80 | "shell.execute_reply": "2023-07-26T17:54:09.553527Z", 81 | "shell.execute_reply.started": "2023-07-26T17:54:04.286468Z" 82 | }, 83 | "tags": [] 84 | }, 85 | "outputs": [], 86 | "source": [ 87 | "from transformers import AutoTokenizer, AutoModelForCausalLM\n", 88 | "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", 89 | "model = AutoModelForCausalLM.from_pretrained(model_name)" 90 | ] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "metadata": {}, 95 | "source": [ 96 | "## Setup table schema\n", 97 | "\n", 98 | "This is a simple example of database table schema if you want to connect to your own PostgreSQL or SQlite please refer to [other notebooks](https://github.com/NumbersStationAI/NSQL/tree/main/examples)." 
99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 5, 104 | "metadata": { 105 | "execution": { 106 | "iopub.execute_input": "2023-07-26T17:54:09.555794Z", 107 | "iopub.status.busy": "2023-07-26T17:54:09.555447Z", 108 | "iopub.status.idle": "2023-07-26T17:54:09.580203Z", 109 | "shell.execute_reply": "2023-07-26T17:54:09.579442Z", 110 | "shell.execute_reply.started": "2023-07-26T17:54:09.555775Z" 111 | }, 112 | "tags": [] 113 | }, 114 | "outputs": [], 115 | "source": [ 116 | "table_schema = \"\"\"CREATE TABLE stadium (\n", 117 | " stadium_id number,\n", 118 | " location text,\n", 119 | " name text,\n", 120 | " capacity number,\n", 121 | " highest number,\n", 122 | " lowest number,\n", 123 | " average number\n", 124 | ")\n", 125 | "\n", 126 | "CREATE TABLE singer (\n", 127 | " singer_id number,\n", 128 | " name text,\n", 129 | " country text,\n", 130 | " song_name text,\n", 131 | " song_release_year text,\n", 132 | " age number,\n", 133 | " is_male others\n", 134 | ")\n", 135 | "\n", 136 | "CREATE TABLE concert (\n", 137 | " concert_id number,\n", 138 | " concert_name text,\n", 139 | " theme text,\n", 140 | " stadium_id text,\n", 141 | " year text\n", 142 | ")\n", 143 | "\n", 144 | "CREATE TABLE singer_in_concert (\n", 145 | " concert_id number,\n", 146 | " singer_id text\n", 147 | ")\n", 148 | "\"\"\"\n", 149 | "\n", 150 | "question = \"What is the maximum, the average, and the minimum capacity of stadiums ?\"\n", 151 | "\n", 152 | "prompt = f\"\"\"{table_schema}\n", 153 | "\n", 154 | "-- Using valid SQLite, answer the following questions for the tables provided above.\n", 155 | "\n", 156 | "-- {question}\n", 157 | "\n", 158 | "SELECT\"\"\"" 159 | ] 160 | }, 161 | { 162 | "cell_type": "markdown", 163 | "metadata": {}, 164 | "source": [ 165 | "## Generate SQL" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 9, 171 | "metadata": { 172 | "execution": { 173 | "iopub.execute_input": "2023-07-26T17:55:04.586239Z", 174 | "iopub.status.busy": "2023-07-26T17:55:04.585800Z", 175 | "iopub.status.idle": "2023-07-26T17:55:07.043773Z", 176 | "shell.execute_reply": "2023-07-26T17:55:07.043013Z", 177 | "shell.execute_reply.started": "2023-07-26T17:55:04.586214Z" 178 | }, 179 | "tags": [] 180 | }, 181 | "outputs": [ 182 | { 183 | "name": "stdout", 184 | "output_type": "stream", 185 | "text": [ 186 | "SELECT MAX(capacity), AVG(capacity), MIN(capacity) FROM stadium;\n" 187 | ] 188 | } 189 | ], 190 | "source": [ 191 | "input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\n", 192 | "generated_ids = model.generate(input_ids, max_length=500)\n", 193 | "output = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n", 194 | "output = 'SELECT' + output.split('SELECT')[-1]\n", 195 | "\n", 196 | "print(output)" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": null, 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [] 205 | } 206 | ], 207 | "metadata": { 208 | "kernelspec": { 209 | "display_name": "Python 3 (ipykernel)", 210 | "language": "python", 211 | "name": "python3" 212 | }, 213 | "language_info": { 214 | "codemirror_mode": { 215 | "name": "ipython", 216 | "version": 3 217 | }, 218 | "file_extension": ".py", 219 | "mimetype": "text/x-python", 220 | "name": "python", 221 | "nbconvert_exporter": "python", 222 | "pygments_lexer": "ipython3", 223 | "version": "3.10.8" 224 | } 225 | }, 226 | "nbformat": 4, 227 | "nbformat_minor": 4 228 | } 229 | 
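The colab notebook above builds the prompt inline and then trims the decoded output back to the final SELECT. A minimal sketch of that same flow wrapped in a reusable helper is shown below; it assumes the same NumbersStation/nsql-350M checkpoint and prompt layout used in the notebook, and the helper name generate_sql is illustrative rather than part of the repository.

```python
# Sketch of the colab.ipynb flow as a reusable helper (assumes the same
# checkpoint and prompt layout as the notebook; the helper name is illustrative).
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "NumbersStation/nsql-350M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)


def generate_sql(table_schema: str, question: str, max_length: int = 500) -> str:
    """Build the NSQL-style prompt and return the generated statement."""
    prompt = (
        f"{table_schema}\n\n"
        "-- Using valid SQLite, answer the following questions for the tables provided above.\n\n"
        f"-- {question}\n\nSELECT"
    )
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    generated_ids = model.generate(input_ids, max_length=max_length)
    output = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    # The decoded text still contains the prompt, so keep only the final
    # completion, exactly as the notebook's post-processing does.
    return "SELECT" + output.split("SELECT")[-1]
```

Trimming on the last SELECT mirrors the notebook's post-processing step; called with the stadium/singer schema and question above, the helper should produce the same single-statement output shown in the notebook.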
-------------------------------------------------------------------------------- /examples/db_connectors.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from dataclasses import dataclass 3 | from functools import cached_property 4 | from typing import Any, Generator, List 5 | import pandas as pd 6 | import sqlalchemy 7 | 8 | from prompt_formatters import TableColumn, Table 9 | 10 | 11 | @dataclass 12 | class PostgresConnector: 13 | """Postgres connection.""" 14 | 15 | user: str 16 | password: str 17 | dbname: str 18 | host: str 19 | port: int 20 | 21 | @cached_property 22 | def pg_uri(self) -> str: 23 | """Get Postgres URI.""" 24 | uri = ( 25 | f"postgresql://" 26 | f"{self.user}:{self.password}@{self.host}:{self.port}/{self.dbname}" 27 | ) 28 | # ensure we can actually connect to this postgres uri 29 | engine = sqlalchemy.create_engine(uri) 30 | conn = engine.connect() 31 | 32 | # assuming the above connection is successful, we can now close the connection 33 | conn.close() 34 | engine.dispose() 35 | 36 | return uri 37 | 38 | @contextmanager 39 | def connect(self) -> Generator[sqlalchemy.engine.base.Connection, None, None]: 40 | """Yield a connection to a Postgres db. 41 | 42 | Example: 43 | .. code-block:: python 44 | postgres = PostgresConnector( 45 | user=USER, password=PASSWORD, dbname=DBNAME, host=HOST, port=PORT 46 | ) 47 | with postgres.connect() as conn: 48 | conn.execute(sql) 49 | """ 50 | try: 51 | engine = sqlalchemy.create_engine(self.pg_uri) 52 | conn = engine.connect() 53 | yield conn 54 | finally: 55 | conn.close() 56 | engine.dispose() 57 | 58 | def run_sql_as_df(self, sql: str) -> pd.DataFrame: 59 | """Run SQL statement.""" 60 | with self.connect() as conn: 61 | return pd.read_sql(sql, conn) 62 | 63 | def get_tables(self) -> List[str]: 64 | """Get all tables in the database.""" 65 | engine = sqlalchemy.create_engine(self.pg_uri) 66 | table_names = engine.table_names() 67 | engine.dispose() 68 | return table_names 69 | 70 | def get_schema(self, table: str) -> Table: 71 | """Return Table.""" 72 | with self.connect() as conn: 73 | columns = [] 74 | sql = f""" 75 | SELECT column_name, data_type 76 | FROM information_schema.columns 77 | WHERE table_name = '{table}'; 78 | """ 79 | schema = conn.execute(sql).fetchall() 80 | for col, type_ in schema: 81 | columns.append(TableColumn(name=col, dtype=type_)) 82 | return Table(name=table, columns=columns) 83 | 84 | 85 | @dataclass 86 | class SQLiteConnector: 87 | """SQLite connection.""" 88 | 89 | database_path: str 90 | 91 | @cached_property 92 | def sqlite_uri(self) -> str: 93 | """Get SQLite URI.""" 94 | uri = f"sqlite:///{self.database_path}" 95 | # ensure we can actually connect to this SQLite uri 96 | engine = sqlalchemy.create_engine(uri) 97 | conn = engine.connect() 98 | 99 | # assuming the above connection is successful, we can now close the connection 100 | conn.close() 101 | engine.dispose() 102 | 103 | return uri 104 | 105 | @contextmanager 106 | def connect(self) -> Generator[sqlalchemy.engine.base.Connection, None, None]: 107 | """Yield a connection to a SQLite database. 108 | 109 | Example: 110 | .. 
code-block:: python 111 | sqlite = SQLiteConnector(database_path=DB_PATH) 112 | with sqlite.connect() as conn: 113 | conn.execute(sql) 114 | """ 115 | try: 116 | engine = sqlalchemy.create_engine(self.sqlite_uri) 117 | conn = engine.connect() 118 | yield conn 119 | finally: 120 | conn.close() 121 | engine.dispose() 122 | 123 | def get_tables(self) -> List[str]: 124 | """Get all tables in the database.""" 125 | engine = sqlalchemy.create_engine(self.sqlite_uri) 126 | table_names = engine.table_names() 127 | engine.dispose() 128 | return table_names 129 | 130 | def run_sql_as_df(self, sql: str) -> pd.DataFrame: 131 | """Run SQL statement.""" 132 | with self.connect() as conn: 133 | return pd.read_sql(sql, conn) 134 | 135 | def get_schema(self, table: str) -> Table: 136 | """Return Table.""" 137 | with self.connect() as conn: 138 | columns = [] 139 | sql = f"PRAGMA table_info({table});" 140 | schema = conn.execute(sql).fetchall() 141 | for row in schema: 142 | col = row[1] 143 | type_ = row[2] 144 | columns.append(TableColumn(name=col, dtype=type_)) 145 | return Table(name=table, columns=columns) 146 | -------------------------------------------------------------------------------- /examples/finetune.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "3652e112-d916-48bf-a49d-20ecfa01ed52", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "%load_ext autoreload\n", 13 | "%autoreload 2" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "id": "9178d12e-1058-4a12-b7ad-3f003361775b", 20 | "metadata": { 21 | "tags": [] 22 | }, 23 | "outputs": [], 24 | "source": [ 25 | "# Install transformer and peft if you don't have it installed\n", 26 | "# ! pip install transformers==4.31.0\n", 27 | "# ! pip install peft\n", 28 | "# ! pip install accelerate\n", 29 | "# ! pip install bitsandbytes" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "id": "e5a3c3e2-0a1d-429d-94d2-3f81838a010c", 35 | "metadata": {}, 36 | "source": [ 37 | "This is a standalone notebook to train the NSQL model on a single GPU (e.g., A5000 with 24GB) with int8 and LoRA." 
38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "id": "b77a70b6-a0ac-4f18-a0c8-354839d149ce", 43 | "metadata": { 44 | "execution": { 45 | "iopub.execute_input": "2023-08-02T17:16:32.100885Z", 46 | "iopub.status.busy": "2023-08-02T17:16:32.100451Z", 47 | "iopub.status.idle": "2023-08-02T17:16:32.111648Z", 48 | "shell.execute_reply": "2023-08-02T17:16:32.111052Z", 49 | "shell.execute_reply.started": "2023-08-02T17:16:32.100860Z" 50 | }, 51 | "tags": [] 52 | }, 53 | "source": [ 54 | "# Load the model" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "id": "ed5d1984-7b44-43de-af56-30c7943140aa", 61 | "metadata": { 62 | "tags": [] 63 | }, 64 | "outputs": [], 65 | "source": [ 66 | "import torch\n", 67 | "from transformers import AutoTokenizer, AutoModelForCausalLM\n", 68 | "tokenizer = AutoTokenizer.from_pretrained(\"NumbersStation/nsql-llama-2-7B\")\n", 69 | "model = AutoModelForCausalLM.from_pretrained(\"NumbersStation/nsql-llama-2-7B\", load_in_8bit=True, torch_dtype=torch.bfloat16, device_map='auto')" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "id": "939db736-3d70-4f2b-807c-efc029fe8ab7", 75 | "metadata": {}, 76 | "source": [ 77 | "# Prepare the data\n", 78 | "\n", 79 | "We use NumbersStation/NSText2SQL dataset as an example here and feel free to customize the training data based on your need." 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "id": "5b38f8db-7910-4b9d-9061-53fd11ede153", 86 | "metadata": { 87 | "tags": [] 88 | }, 89 | "outputs": [], 90 | "source": [ 91 | "from datasets import load_dataset\n", 92 | "from torch.utils.data import Dataset\n", 93 | "import copy\n", 94 | "\n", 95 | "class NSText2SQLDataset(Dataset):\n", 96 | " def __init__(self, size=None, max_seq_length=2048):\n", 97 | " self.dataset = load_dataset(\"NumbersStation/NSText2SQL\",split=\"train\")\n", 98 | " if size:\n", 99 | " self.dataset = self.dataset.select(range(size))\n", 100 | " self.max_seq_length = max_seq_length\n", 101 | "\n", 102 | " def __len__(self):\n", 103 | " return len(self.dataset)\n", 104 | "\n", 105 | " def __getitem__(self, index):\n", 106 | " instruction = torch.tensor(tokenizer.encode(self.dataset[index]['instruction']), dtype=torch.int64)\n", 107 | " example = self.dataset[index]['instruction'] + self.dataset[index][\"output\"]\n", 108 | " example = tokenizer.encode(example)\n", 109 | " example.append(tokenizer.eos_token_id)\n", 110 | " padding = self.max_seq_length - len(example)\n", 111 | " example = torch.tensor(example, dtype=torch.int64)\n", 112 | "\n", 113 | " if padding < 0:\n", 114 | " example = example[:self.max_seq_length]\n", 115 | " else:\n", 116 | " example = torch.cat((example, torch.zeros(padding, dtype=torch.int64)))\n", 117 | " \n", 118 | " labels = copy.deepcopy(example)\n", 119 | " labels[: len(instruction)] = -100\n", 120 | " \n", 121 | " return {\"input_ids\": example, \"labels\": labels}" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "id": "d669cd4e-1b53-4810-aa5a-8194916b01c0", 128 | "metadata": { 129 | "tags": [] 130 | }, 131 | "outputs": [], 132 | "source": [ 133 | "dataset = NSText2SQLDataset(size=1000, max_seq_length=1024)" 134 | ] 135 | }, 136 | { 137 | "cell_type": "markdown", 138 | "id": "78aa7f7a-1fb5-417e-bf92-e07f8c172a30", 139 | "metadata": {}, 140 | "source": [ 141 | "# Prepare PEFT for model training" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "id": 
"1627483c-2739-4cc9-8d85-e7353b4c61b5", 148 | "metadata": { 149 | "tags": [] 150 | }, 151 | "outputs": [], 152 | "source": [ 153 | "from peft import (\n", 154 | " get_peft_model,\n", 155 | " LoraConfig,\n", 156 | " TaskType,\n", 157 | " prepare_model_for_int8_training,\n", 158 | ")\n", 159 | "\n", 160 | "\n", 161 | "model.train()\n", 162 | "\n", 163 | "model = prepare_model_for_int8_training(model)\n", 164 | "\n", 165 | "lora_config = LoraConfig(\n", 166 | " task_type=TaskType.CAUSAL_LM,\n", 167 | " inference_mode=False,\n", 168 | " r=8,\n", 169 | " lora_alpha=32,\n", 170 | " lora_dropout=0.05,\n", 171 | " target_modules = [\"q_proj\", \"v_proj\"]\n", 172 | ")\n", 173 | "\n", 174 | "model = get_peft_model(model, lora_config)" 175 | ] 176 | }, 177 | { 178 | "cell_type": "markdown", 179 | "id": "b9a8b18d-a25a-4d12-ac00-0f720703f579", 180 | "metadata": { 181 | "execution": { 182 | "iopub.execute_input": "2023-08-02T17:50:38.809539Z", 183 | "iopub.status.busy": "2023-08-02T17:50:38.809187Z", 184 | "iopub.status.idle": "2023-08-02T17:50:38.836443Z", 185 | "shell.execute_reply": "2023-08-02T17:50:38.835678Z", 186 | "shell.execute_reply.started": "2023-08-02T17:50:38.809520Z" 187 | }, 188 | "tags": [] 189 | }, 190 | "source": [ 191 | "# Finetune the model with Huggingface trainer" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": null, 197 | "id": "997db2e9-9b2e-4169-9b18-695d403f70e4", 198 | "metadata": { 199 | "tags": [] 200 | }, 201 | "outputs": [], 202 | "source": [ 203 | "from transformers import default_data_collator, Trainer, TrainingArguments\n", 204 | "\n", 205 | "output_dir = \"training_run\"\n", 206 | "\n", 207 | "config = {\n", 208 | " 'lora_config': lora_config,\n", 209 | " 'learning_rate': 1e-4,\n", 210 | " 'num_train_epochs': 1,\n", 211 | " 'gradient_accumulation_steps': 2,\n", 212 | " 'gradient_checkpointing': False,\n", 213 | "}\n", 214 | "\n", 215 | "\n", 216 | "training_args = TrainingArguments(\n", 217 | " output_dir=output_dir,\n", 218 | " overwrite_output_dir=True,\n", 219 | " bf16=True,\n", 220 | " # logging strategies\n", 221 | " logging_dir=f\"{output_dir}/logs\",\n", 222 | " logging_strategy=\"steps\",\n", 223 | " logging_steps=5,\n", 224 | " optim=\"adamw_torch_fused\",\n", 225 | " **{k:v for k,v in config.items() if k != 'lora_config'}\n", 226 | ")" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": null, 232 | "id": "007525f0-41ac-4de7-aba8-92e547931972", 233 | "metadata": { 234 | "tags": [] 235 | }, 236 | "outputs": [], 237 | "source": [ 238 | "trainer = Trainer(\n", 239 | " model=model,\n", 240 | " args=training_args,\n", 241 | " train_dataset=dataset,\n", 242 | " data_collator=default_data_collator,\n", 243 | ")\n", 244 | "\n", 245 | "# Start training\n", 246 | "trainer.train()" 247 | ] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "id": "680eb420-2c91-407c-b19c-89ea0391b591", 252 | "metadata": {}, 253 | "source": [ 254 | "# Save model checkpoint" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": null, 260 | "id": "fc1fb82e-b158-40d7-92bf-d63446dc12ae", 261 | "metadata": { 262 | "tags": [] 263 | }, 264 | "outputs": [], 265 | "source": [ 266 | "model.save_pretrained(output_dir)" 267 | ] 268 | }, 269 | { 270 | "cell_type": "markdown", 271 | "id": "1eaeb1d5-1b49-4873-a205-b3ba94af4214", 272 | "metadata": {}, 273 | "source": [ 274 | "# Evaluate the finetuned model" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": null, 280 | "id": 
"75e98d82-07de-48fa-8c40-fb59e7805b08", 281 | "metadata": { 282 | "tags": [] 283 | }, 284 | "outputs": [], 285 | "source": [ 286 | "model.eval()" 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": null, 292 | "id": "fad30645-9472-4d29-94f8-94e20fdde79a", 293 | "metadata": { 294 | "tags": [] 295 | }, 296 | "outputs": [], 297 | "source": [ 298 | "text = \"\"\"CREATE TABLE stadium (\n", 299 | " stadium_id number,\n", 300 | " location text,\n", 301 | " name text,\n", 302 | " capacity number,\n", 303 | ")\n", 304 | "\n", 305 | "-- Using valid SQLite, answer the following questions for the tables provided above.\n", 306 | "\n", 307 | "-- how many stadiums in total?\n", 308 | "\n", 309 | "SELECT\"\"\"\n", 310 | "\n", 311 | "model_input = tokenizer(text, return_tensors=\"pt\").to(\"cuda\")\n", 312 | "\n", 313 | "generated_ids = model.generate(**model_input, max_new_tokens=100)\n", 314 | "print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": null, 320 | "id": "532632db-3579-48dd-931b-b6aa4669ce87", 321 | "metadata": {}, 322 | "outputs": [], 323 | "source": [] 324 | } 325 | ], 326 | "metadata": { 327 | "kernelspec": { 328 | "display_name": "Python 3 (ipykernel)", 329 | "language": "python", 330 | "name": "python3" 331 | }, 332 | "language_info": { 333 | "codemirror_mode": { 334 | "name": "ipython", 335 | "version": 3 336 | }, 337 | "file_extension": ".py", 338 | "mimetype": "text/x-python", 339 | "name": "python", 340 | "nbconvert_exporter": "python", 341 | "pygments_lexer": "ipython3", 342 | "version": "3.10.8" 343 | } 344 | }, 345 | "nbformat": 4, 346 | "nbformat_minor": 5 347 | } 348 | -------------------------------------------------------------------------------- /examples/postgres.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "attachments": {}, 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "## DB Setup\n", 19 | "\n", 20 | "We assume you already have a postgres database ready." 
21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "DATABASE = \"database\"\n", 30 | "USER = \"postgres\"\n", 31 | "PASSWORD = \"password\"\n", 32 | "HOST = \"localhost\"\n", 33 | "PORT = 5432\n", 34 | "TABLES = [] # list of tables to load or [] to load all tables" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "from db_connectors import PostgresConnector\n", 44 | "from prompt_formatters import RajkumarFormatter\n", 45 | "\n", 46 | "# Get the connector and formatter\n", 47 | "postgres_connector = PostgresConnector(\n", 48 | " user=USER, password=PASSWORD, dbname=DATABASE, host=HOST, port=PORT\n", 49 | ")\n", 50 | "postgres_connector.connect()\n", 51 | "if len(TABLES) <= 0:\n", 52 | " TABLES.extend(postgres_connector.get_tables())\n", 53 | "\n", 54 | "print(f\"Loading tables: {TABLES}\")\n", 55 | "\n", 56 | "db_schema = [postgres_connector.get_schema(table) for table in TABLES]\n", 57 | "formatter = RajkumarFormatter(db_schema)" 58 | ] 59 | }, 60 | { 61 | "attachments": {}, 62 | "cell_type": "markdown", 63 | "metadata": {}, 64 | "source": [ 65 | "## Model Setup\n", 66 | "\n", 67 | "In a separate screen or window, first install [Manifest](https://github.com/HazyResearch/manifest)\n", 68 | "```bash\n", 69 | "pip install manifest-ml\\[all\\]\n", 70 | "```\n", 71 | "\n", 72 | "Then run\n", 73 | "```bash\n", 74 | "python3 -m manifest.api.app \\\n", 75 | " --model_type huggingface \\\n", 76 | " --model_generation_type text-generation \\\n", 77 | " --model_name_or_path NumbersStation/nsql-350M \\\n", 78 | " --device 0\n", 79 | "```\n", 80 | "\n", 81 | "If successful, you will see an output like\n", 82 | "```bash\n", 83 | "* Running on http://127.0.0.1:5000\n", 84 | "```" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "from manifest import Manifest\n", 94 | "\n", 95 | "manifest_client = Manifest(client_name=\"huggingface\", client_connection=\"http://127.0.0.1:5000\")\n", 96 | "\n", 97 | "def get_sql(instruction: str, max_tokens: int = 300) -> str:\n", 98 | " prompt = formatter.format_prompt(instruction)\n", 99 | " res = manifest_client.run(prompt, max_tokens=max_tokens)\n", 100 | " return formatter.format_model_output(res)" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "sql = get_sql(\"Number of rows in table?\")\n", 110 | "print(sql)" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "print(postgres_connector.run_sql_as_df(sql))" 120 | ] 121 | } 122 | ], 123 | "metadata": { 124 | "kernelspec": { 125 | "display_name": "dbt", 126 | "language": "python", 127 | "name": "python3" 128 | }, 129 | "language_info": { 130 | "codemirror_mode": { 131 | "name": "ipython", 132 | "version": 3 133 | }, 134 | "file_extension": ".py", 135 | "mimetype": "text/x-python", 136 | "name": "python", 137 | "nbconvert_exporter": "python", 138 | "pygments_lexer": "ipython3", 139 | "version": "3.10.0" 140 | }, 141 | "orig_nbformat": 4 142 | }, 143 | "nbformat": 4, 144 | "nbformat_minor": 2 145 | } 146 | -------------------------------------------------------------------------------- /examples/prompt_formatters.py: 
-------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | 4 | class TableColumn(BaseModel): 5 | """Table column.""" 6 | 7 | name: str 8 | dtype: str | None 9 | 10 | 11 | class ForeignKey(BaseModel): 12 | """Foreign key.""" 13 | 14 | # Referenced column 15 | column: TableColumn 16 | # References table name 17 | references_name: str 18 | # References column 19 | references_column: TableColumn 20 | 21 | 22 | class Table(BaseModel): 23 | """Table.""" 24 | 25 | name: str 26 | columns: list[TableColumn] | None 27 | pks: list[TableColumn] | None 28 | # FK from this table to another column in another table 29 | fks: list[ForeignKey] | None 30 | 31 | 32 | class RajkumarFormatter: 33 | """RajkumarFormatter class. 34 | 35 | From https://arxiv.org/pdf/2204.00498.pdf. 36 | """ 37 | 38 | table_sep: str = "\n\n" 39 | 40 | def __init__(self, tables: list[Table]) -> None: 41 | self.tables = tables 42 | self.table_str = self.format_tables(tables) 43 | 44 | def format_table(self, table: Table) -> str: 45 | """Get table format.""" 46 | table_fmt = [] 47 | table_name = table.name 48 | for col in table.columns or []: 49 | # This is technically an incorrect type, but it should be a catchall word 50 | table_fmt.append(f" {col.name} {col.dtype or 'any'}") 51 | if table.pks: 52 | table_fmt.append( 53 | f" primary key ({', '.join(pk.name for pk in table.pks)})" 54 | ) 55 | for fk in table.fks or []: 56 | table_fmt.append( 57 | f" foreign key ({fk.column.name}) references {fk.references_name}({fk.references_column.name})" # noqa: E501 58 | ) 59 | if table_fmt: 60 | all_cols = ",\n".join(table_fmt) 61 | create_tbl = f"CREATE TABLE {table_name} (\n{all_cols}\n)" 62 | else: 63 | create_tbl = f"CREATE TABLE {table_name}" 64 | return create_tbl 65 | 66 | def format_tables(self, tables: list[Table]) -> str: 67 | """Get tables format.""" 68 | return self.table_sep.join(self.format_table(table) for table in tables) 69 | 70 | def format_prompt( 71 | self, 72 | instruction: str, 73 | ) -> str: 74 | """Get prompt format.""" 75 | sql_prefix = "SELECT" 76 | return f"""{self.table_str}\n\n\n-- Using valid SQLite, answer the following questions for the tables provided above.\n\n-- {instruction}\n{sql_prefix}""" # noqa: E501 77 | 78 | def format_model_output(self, output_sql: str) -> str: 79 | """Format model output. 80 | 81 | Our prompt ends with SELECT so we need to add it back. 82 | """ 83 | if not output_sql.lower().startswith("select"): 84 | output_sql = "SELECT " + output_sql.strip() 85 | return output_sql 86 | -------------------------------------------------------------------------------- /examples/sqlite.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "attachments": {}, 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "## DB Setup\n", 19 | "\n", 20 | "We assume you already have a sqlite database ready." 
21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "DATABASE = \"testDB.db\"\n", 30 | "TABLES = [] # list of tables to load or [] to load all tables" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "from db_connectors import SQLiteConnector\n", 40 | "from prompt_formatters import RajkumarFormatter\n", 41 | "\n", 42 | "# Get the connector and formatter\n", 43 | "sqlite_connector = SQLiteConnector(\n", 44 | " database_path=DATABASE\n", 45 | ")\n", 46 | "sqlite_connector.connect()\n", 47 | "if len(TABLES) <= 0:\n", 48 | " TABLES.extend(sqlite_connector.get_tables())\n", 49 | "\n", 50 | "print(f\"Loading tables: {TABLES}\")\n", 51 | "\n", 52 | "db_schema = [sqlite_connector.get_schema(table) for table in TABLES]\n", 53 | "formatter = RajkumarFormatter(db_schema)" 54 | ] 55 | }, 56 | { 57 | "attachments": {}, 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "## Model Setup\n", 62 | "\n", 63 | "In a separate screen or window, first install [Manifest](https://github.com/HazyResearch/manifest)\n", 64 | "```bash\n", 65 | "pip install manifest-ml\\[all\\]\n", 66 | "```\n", 67 | "\n", 68 | "Then run\n", 69 | "```bash\n", 70 | "python3 -m manifest.api.app \\\n", 71 | " --model_type huggingface \\\n", 72 | " --model_generation_type text-generation \\\n", 73 | " --model_name_or_path NumbersStation/nsql-350M \\\n", 74 | " --device 0\n", 75 | "```\n", 76 | "\n", 77 | "If successful, you will see an output like\n", 78 | "```bash\n", 79 | "* Running on http://127.0.0.1:5000\n", 80 | "```" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "from manifest import Manifest\n", 90 | "\n", 91 | "manifest_client = Manifest(client_name=\"huggingface\", client_connection=\"http://127.0.0.1:5000\")\n", 92 | "\n", 93 | "def get_sql(instruction: str, max_tokens: int = 300) -> str:\n", 94 | " prompt = formatter.format_prompt(instruction)\n", 95 | " res = manifest_client.run(prompt, max_tokens=max_tokens)\n", 96 | " return formatter.format_model_output(res)" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "sql = get_sql(\"Number of rows in table?\")\n", 106 | "print(sql)" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "print(sqlite_connector.run_sql_as_df(sql))" 116 | ] 117 | } 118 | ], 119 | "metadata": { 120 | "kernelspec": { 121 | "display_name": "dbt", 122 | "language": "python", 123 | "name": "python3" 124 | }, 125 | "language_info": { 126 | "codemirror_mode": { 127 | "name": "ipython", 128 | "version": 3 129 | }, 130 | "file_extension": ".py", 131 | "mimetype": "text/x-python", 132 | "name": "python", 133 | "nbconvert_exporter": "python", 134 | "pygments_lexer": "ipython3", 135 | "version": "3.10.0" 136 | }, 137 | "orig_nbformat": 4 138 | }, 139 | "nbformat": 4, 140 | "nbformat_minor": 2 141 | } 142 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | manifest-ml[all]==0.1.8 2 | pandas>=2.0.0 3 | sqlalchemy<2.0.0 4 | transformers>=4.29.0 5 | datasets==2.11.0 6 | jsonlines>=3.1.0 7 | 
sqlglot==11.5.5 8 | --------------------------------------------------------------------------------
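The Postgres and SQLite notebooks route generation through a local Manifest server. As a closing sketch, the same pieces can be combined without Manifest by calling the model directly with transformers; this assumes the script is run from the examples/ directory (so db_connectors and prompt_formatters import locally) and reuses the testDB.db path from sqlite.ipynb.

```python
# Sketch: end-to-end text-to-SQL against a local SQLite file using the
# examples/ helpers directly with transformers (no Manifest server).
# Assumptions: run from examples/ so the local imports resolve, and a
# SQLite database file named testDB.db as in sqlite.ipynb.
from transformers import AutoTokenizer, AutoModelForCausalLM

from db_connectors import SQLiteConnector
from prompt_formatters import RajkumarFormatter

connector = SQLiteConnector(database_path="testDB.db")
tables = connector.get_tables()
formatter = RajkumarFormatter([connector.get_schema(table) for table in tables])

tokenizer = AutoTokenizer.from_pretrained("NumbersStation/nsql-350M")
model = AutoModelForCausalLM.from_pretrained("NumbersStation/nsql-350M")


def get_sql(instruction: str, max_new_tokens: int = 300) -> str:
    prompt = formatter.format_prompt(instruction)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    generated_ids = model.generate(input_ids, max_new_tokens=max_new_tokens)
    # Decode only the newly generated tokens; the prompt already ends in
    # SELECT, which format_model_output adds back.
    completion = tokenizer.decode(
        generated_ids[0][input_ids.shape[1]:], skip_special_tokens=True
    )
    sql = formatter.format_model_output(completion)
    # The nsql models typically stop after one statement; trim defensively anyway.
    return sql.split(";")[0].strip()


sql = get_sql("Number of rows in table?")
print(connector.run_sql_as_df(sql))
```

Decoding only the newly generated tokens keeps format_model_output's SELECT-prefix handling consistent with the Manifest-based notebooks; depending on the model and question, the completion may still need further truncation before execution.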