├── .gitignore
├── LICENSE
├── README.md
├── data_prep
│   ├── README.md
│   ├── data
│   │   ├── download.sh
│   │   └── tables.json
│   ├── data_utils.py
│   ├── kummerfeld_utils.py
│   ├── prep_text2sql_data.py
│   ├── prompt_formatters.py
│   ├── schema.py
│   ├── text2sql_data_config.yaml
│   └── text2sql_dataset.py
├── examples
│   ├── colab.ipynb
│   ├── db_connectors.py
│   ├── finetune.ipynb
│   ├── postgres.ipynb
│   ├── prompt_formatters.py
│   └── sqlite.ipynb
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | ### Python ###
2 | # Byte-compiled / optimized / DLL files
3 | __pycache__/
4 | *.py[cod]
5 | *$py.class
6 | *.pyc
7 |
8 | # Distribution / packaging
9 | .Python
10 | build/
11 | develop-eggs/
12 | dist/
13 | downloads/
14 | eggs/
15 | .eggs/
16 | lib/
17 | lib64/
18 | parts/
19 | sdist/
20 | var/
21 | wheels/
22 | *.egg-info/
23 | .installed.cfg
24 | *.egg
25 |
26 | # PyInstaller
27 | # Usually these files are written by a python script from a template
28 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
29 | *.manifest
30 | *.spec
31 |
32 | # Installer logs
33 | pip-log.txt
34 | pip-delete-this-directory.txt
35 |
36 | # Unit test / coverage reports
37 | htmlcov/
38 | .tox/
39 | .coverage
40 | .coverage.*
41 | .cache
42 | .pytest_cache/
43 | nosetests.xml
44 | coverage.xml
45 | *.cover
46 | .hypothesis/
47 |
48 | # Translations
49 | *.mo
50 | *.pot
51 |
52 | # Jupyter Notebook
53 | .ipynb_checkpoints
54 |
55 | # pyenv
56 | .python-version
57 |
58 | # Environments
59 | .env
60 | .venv
61 | .virtualenv
62 | env/
63 | venv/
64 | virtualenv/
65 | ENV/
66 | env.bak/
67 | venv.bak/
68 |
69 | # mkdocs documentation
70 | /site
71 |
72 | # mypy
73 | .mypy_cache/
74 | .dmypy.json
75 |
76 | # wheels
77 | *.whl
78 |
79 | # VS Code
80 | *.vscode
81 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # NSQL
2 | Numbers Station Text to SQL model code.
3 |
4 | NSQL is a family of autoregressive open-source large foundation models (FMs) designed specifically for SQL generation tasks. All model weights are provided on HuggingFace.
5 |
6 | | Model Name | Size | Link |
7 | | ---------- | ---- | ------- |
8 | | NumbersStation/nsql-350M | 350M | [link](https://huggingface.co/NumbersStation/nsql-350M)
9 | | NumbersStation/nsql-2B | 2.7B | [link](https://huggingface.co/NumbersStation/nsql-2B)
10 | | NumbersStation/nsql-6B | 6B | [link](https://huggingface.co/NumbersStation/nsql-6B)
11 | | NumbersStation/nsql-llama-2-7B | 7B | [link](https://huggingface.co/NumbersStation/nsql-llama-2-7B)
12 |
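The checkpoints can also be loaded directly with the Hugging Face `transformers` library. Below is a minimal, illustrative sketch (not the serving path used in `examples/`); it assumes `transformers` and `torch` are installed, and the prompt is only a placeholder, so see each model card for the exact prompt format it expects.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative only: load the smallest checkpoint and complete a toy prompt.
model_name = "NumbersStation/nsql-350M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

prompt = "CREATE TABLE stadium (id INT, name TEXT, capacity INT)\n\n-- How many stadiums are there?\nSELECT"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```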
13 | ## Setup
14 | To install, run
15 | ```
16 | pip install -r requirements.txt
17 | ```
18 |
19 | ## Usage
20 | See the notebooks in `examples/` for how to connect to Postgres or SQLite and ask questions directly over your data. A small code snippet from the `examples/` directory is shown below.
21 |
22 | In a separate screen or window, run
23 | ```bash
24 | python3 -m manifest.api.app \
25 | --model_type huggingface \
26 | --model_generation_type text-generation \
27 | --model_name_or_path NumbersStation/nsql-350M \
28 | --device 0
29 | ```
30 |
31 | Then run
32 |
33 | ```python
34 | from db_connectors import PostgresConnector
35 | from prompt_formatters import RajkumarFormatter
36 | from manifest import Manifest
37 |
38 | postgres_connector = PostgresConnector(
39 | user=USER, password=PASSWORD, dbname=DATABASE, host=HOST, port=PORT
40 | )
41 | postgres_connector.connect()
42 | db_schema = [postgres_connector.get_schema(table) for table in postgres_connector.get_tables()]
43 | formatter = RajkumarFormatter(db_schema)
44 |
45 | manifest_client = Manifest(client_name="huggingface", client_connection="http://127.0.0.1:5000")
46 |
47 | def get_sql(instruction: str, max_tokens: int = 300) -> str:
48 | prompt = formatter.format_prompt(instruction)
49 | res = manifest_client.run(prompt, max_tokens=max_tokens)
50 | return formatter.format_model_output(res)
51 |
52 | print(get_sql("Number of rows in table?"))
53 | ```
54 |
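For a local SQLite database the flow is analogous. A minimal sketch, assuming the `SQLiteConnector` in `examples/db_connectors.py` exposes the same `connect`/`get_tables`/`get_schema` interface as the Postgres connector above (the `database_path` argument name and file name are illustrative; see `examples/sqlite.ipynb` for the full walkthrough):

```python
from db_connectors import SQLiteConnector
from prompt_formatters import RajkumarFormatter

# Assumed constructor signature; adjust to match examples/db_connectors.py.
sqlite_connector = SQLiteConnector(database_path="my_database.db")
sqlite_connector.connect()
db_schema = [sqlite_connector.get_schema(table) for table in sqlite_connector.get_tables()]
formatter = RajkumarFormatter(db_schema)
```

From here, the same `get_sql` helper shown above can be reused unchanged.
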
55 | ## Data Preparation
56 |
57 | In the `data_prep` folder, we provide data preparation scripts to generate [NSText2SQL](https://huggingface.co/datasets/NumbersStation/NSText2SQL), which is used to train the [NSQL](https://huggingface.co/NumbersStation/nsql-6B) models.
58 |
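If you only want to consume the published dataset rather than regenerate it, it can be pulled straight from the Hugging Face Hub. A minimal sketch, assuming the `datasets` library is installed and a `train` split is published:

```python
from datasets import load_dataset

# Each record contains "instruction", "output", and "source" fields.
nstext2sql = load_dataset("NumbersStation/NSText2SQL", split="train")
print(nstext2sql[0]["instruction"])
print(nstext2sql[0]["output"])
```
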
59 | ## License
60 |
61 | The code in this repo is licensed under the Apache 2.0 license. Unless otherwise noted, the following applies:
62 |
63 | ```
64 | Copyright 2023 Numbers Station
65 |
66 | Licensed under the Apache License, Version 2.0 (the "License");
67 | you may not use this file except in compliance with the License.
68 | You may obtain a copy of the License at
69 |
70 | http://www.apache.org/licenses/LICENSE-2.0
71 |
72 | Unless required by applicable law or agreed to in writing, software
73 | distributed under the License is distributed on an "AS IS" BASIS,
74 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75 | See the License for the specific language governing permissions and
76 | limitations under the License.
77 | ```
78 |
79 | The data to generate NSText2SQL is sourced from repositories with various licenses. Any use of all or part of the data gathered in NSText2SQL must abide by the terms of the original licenses, including attribution clauses when relevant. We thank all authors who provided these datasets. We provide provenance information for each dataset below.
80 |
81 | | Datasets | License | Link |
82 | | ---------------------- | ------------ | -------------------------------------------------------------------------------------------------------------------- |
83 | | academic | Not Found | [https://github.com/jkkummerfeld/text2sql-data](https://github.com/jkkummerfeld/text2sql-data) |
84 | | advising | CC-BY-4.0 | [https://github.com/jkkummerfeld/text2sql-data](https://github.com/jkkummerfeld/text2sql-data) |
85 | | atis | Not Found | [https://github.com/jkkummerfeld/text2sql-data](https://github.com/jkkummerfeld/text2sql-data) |
86 | | restaurants | Not Found | [https://github.com/jkkummerfeld/text2sql-data](https://github.com/jkkummerfeld/text2sql-data) |
87 | | scholar | Not Found | [https://github.com/jkkummerfeld/text2sql-data](https://github.com/jkkummerfeld/text2sql-data) |
88 | | imdb | Not Found | [https://github.com/jkkummerfeld/text2sql-data](https://github.com/jkkummerfeld/text2sql-data) |
89 | | yelp | Not Found | [https://github.com/jkkummerfeld/text2sql-data](https://github.com/jkkummerfeld/text2sql-data) |
90 | | criteria2sql | Apache-2.0 | [https://github.com/xiaojingyu92/Criteria2SQL](https://github.com/xiaojingyu92/Criteria2SQL) |
91 | | css | CC-BY-4.0 | [https://huggingface.co/datasets/zhanghanchong/css](https://huggingface.co/datasets/zhanghanchong/css) |
92 | | eICU | CC-BY-4.0 | [https://github.com/glee4810/EHRSQL](https://github.com/glee4810/EHRSQL) |
93 | | mimic_iii | CC-BY-4.0 | [https://github.com/glee4810/EHRSQL](https://github.com/glee4810/EHRSQL) |
94 | | geonucleardata | CC-BY-SA-4.0 | [https://github.com/chiahsuan156/KaggleDBQA](https://github.com/chiahsuan156/KaggleDBQA) |
95 | | greatermanchestercrime | CC-BY-SA-4.0 | [https://github.com/chiahsuan156/KaggleDBQA](https://github.com/chiahsuan156/KaggleDBQA) |
96 | | studentmathscore | CC-BY-SA-4.0 | [https://github.com/chiahsuan156/KaggleDBQA](https://github.com/chiahsuan156/KaggleDBQA) |
97 | | thehistoryofbaseball | CC-BY-SA-4.0 | [https://github.com/chiahsuan156/KaggleDBQA](https://github.com/chiahsuan156/KaggleDBQA) |
98 | | uswildfires | CC-BY-SA-4.0 | [https://github.com/chiahsuan156/KaggleDBQA](https://github.com/chiahsuan156/KaggleDBQA) |
99 | | whatcdhiphop | CC-BY-SA-4.0 | [https://github.com/chiahsuan156/KaggleDBQA](https://github.com/chiahsuan156/KaggleDBQA) |
100 | | worldsoccerdatabase | CC-BY-SA-4.0 | [https://github.com/chiahsuan156/KaggleDBQA](https://github.com/chiahsuan156/KaggleDBQA) |
101 | | pesticide | CC-BY-SA-4.0 | [https://github.com/chiahsuan156/KaggleDBQA](https://github.com/chiahsuan156/KaggleDBQA) |
102 | | mimicsql_data | MIT | [https://github.com/wangpinggl/TREQS](https://github.com/wangpinggl/TREQS) |
103 | | nvbench | MIT | [https://github.com/TsinghuaDatabaseGroup/nvBench](https://github.com/TsinghuaDatabaseGroup/nvBench) |
104 | | sede | Apache-2.0 | [https://github.com/hirupert/sede](https://github.com/hirupert/sede) |
105 | | spider | CC-BY-SA-4.0 | [https://huggingface.co/datasets/spider](https://huggingface.co/datasets/spider) |
106 | | sql_create_context | CC-BY-4.0 | [https://huggingface.co/datasets/b-mc2/sql-create-context](https://huggingface.co/datasets/b-mc2/sql-create-context) |
107 | | squall | CC-BY-SA-4.0 | [https://github.com/tzshi/squall](https://github.com/tzshi/squall) |
108 | | wikisql | BSD 3-Clause | [https://github.com/salesforce/WikiSQL](https://github.com/salesforce/WikiSQL) |
109 |
110 | For full terms, see the LICENSE file. If you have any questions, comments, or concerns about licensing, please [contact us](https://www.numbersstation.ai/signup).
111 |
112 | ## Citing this work
113 |
114 | If you use this data in your work, please cite our work _and_ the appropriate original sources:
115 |
116 | To cite NSText2SQL, please use:
117 | ```TeX
118 | @software{numbersstation2023NSText2SQL,
119 | author = {Numbers Station Labs},
120 | title = {NSText2SQL: An Open Source Text-to-SQL Dataset for Foundation Model Training},
121 | month = {July},
122 | year = {2023},
123 | url = {https://github.com/NumbersStationAI/NSQL},
124 | }
125 | ```
126 |
127 | To cite the datasets used in this work, please use:
128 |
129 | | Datasets | Cite |
130 | | ---------------------- | ---------------------------------------------------------------------------------------- |
131 | | academic | `\cite{data-advising,data-academic}` |
132 | | advising | `\cite{data-advising}` |
133 | | atis | `\cite{data-advising,data-atis-original,data-atis-geography-scholar}` |
134 | | restaurants | `\cite{data-advising,data-restaurants-logic,data-restaurants-original,data-restaurants}` |
135 | | scholar | `\cite{data-advising,data-atis-geography-scholar}` |
136 | | imdb | `\cite{data-advising,data-imdb-yelp}` |
137 | | yelp | `\cite{data-advising,data-imdb-yelp}` |
138 | | criteria2sql | `\cite{Criteria-to-SQL}` |
139 | | css | `\cite{zhang2023css}` |
140 | | eICU | `\cite{lee2022ehrsql}` |
141 | | mimic_iii | `\cite{lee2022ehrsql}` |
142 | | geonucleardata | `\cite{lee-2021-kaggle-dbqa}` |
143 | | greatermanchestercrime | `\cite{lee-2021-kaggle-dbqa}` |
144 | | studentmathscore | `\cite{lee-2021-kaggle-dbqa}` |
145 | | thehistoryofbaseball | `\cite{lee-2021-kaggle-dbqa}` |
146 | | uswildfires | `\cite{lee-2021-kaggle-dbqa}` |
147 | | whatcdhiphop | `\cite{lee-2021-kaggle-dbqa}` |
148 | | worldsoccerdatabase | `\cite{lee-2021-kaggle-dbqa}` |
149 | | pesticide | `\cite{lee-2021-kaggle-dbqa}` |
150 | | mimicsql_data | `\cite{wang2020text}` |
151 | | nvbench | `\cite{nvBench_SIGMOD21}` |
152 | | sede | `\cite{hazoom2021text}` |
153 | | spider | `\cite{data-spider}` |
154 | | sql_create_context | Not Found |
155 | | squall | `\cite{squall}` |
156 | | wikisql | `\cite{data-wikisql}` |
157 |
158 |
159 | ```TeX
160 | @InProceedings{data-advising,
161 | dataset = {Advising},
162 | author = {Catherine Finegan-Dollak, Jonathan K. Kummerfeld, Li Zhang, Karthik Ramanathan, Sesh Sadasivam, Rui Zhang, and Dragomir Radev},
163 | title = {Improving Text-to-SQL Evaluation Methodology},
164 | booktitle = {Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)},
165 | month = {July},
166 | year = {2018},
167 | location = {Melbourne, Victoria, Australia},
168 | pages = {351--360},
169 | url = {http://aclweb.org/anthology/P18-1033},
170 | }
171 |
172 | @InProceedings{data-imdb-yelp,
173 | dataset = {IMDB and Yelp},
174 | author = {Navid Yaghmazadeh, Yuepeng Wang, Isil Dillig, and Thomas Dillig},
175 | title = {SQLizer: Query Synthesis from Natural Language},
176 | booktitle = {International Conference on Object-Oriented Programming, Systems, Languages, and Applications, ACM},
177 | month = {October},
178 | year = {2017},
179 | pages = {63:1--63:26},
180 | url = {http://doi.org/10.1145/3133887},
181 | }
182 |
183 | @article{data-academic,
184 | dataset = {Academic},
185 | author = {Fei Li and H. V. Jagadish},
186 | title = {Constructing an Interactive Natural Language Interface for Relational Databases},
187 | journal = {Proceedings of the VLDB Endowment},
188 | volume = {8},
189 | number = {1},
190 | month = {September},
191 | year = {2014},
192 | pages = {73--84},
193 | url = {http://dx.doi.org/10.14778/2735461.2735468},
194 | }
195 |
196 | @InProceedings{data-atis-geography-scholar,
197 | dataset = {Scholar, and Updated ATIS and Geography},
198 | author = {Srinivasan Iyer, Ioannis Konstas, Alvin Cheung, Jayant Krishnamurthy, and Luke Zettlemoyer},
199 | title = {Learning a Neural Semantic Parser from User Feedback},
200 | booktitle = {Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)},
201 | year = {2017},
202 | pages = {963--973},
203 | location = {Vancouver, Canada},
204 | url = {http://www.aclweb.org/anthology/P17-1089},
205 | }
206 |
207 | @article{data-atis-original,
208 | dataset = {ATIS, original},
209 | author = {Deborah A. Dahl, Madeleine Bates, Michael Brown, William Fisher, Kate Hunicke-Smith, David Pallett, Christine Pao, Alexander Rudnicky, and Elizabeth Shriber},
210 | title = {{Expanding the scope of the ATIS task: The ATIS-3 corpus}},
211 | journal = {Proceedings of the workshop on Human Language Technology},
212 | year = {1994},
213 | pages = {43--48},
214 | url = {http://dl.acm.org/citation.cfm?id=1075823},
215 | }
216 |
217 | @inproceedings{data-restaurants-logic,
218 | author = {Lappoon R. Tang and Raymond J. Mooney},
219 | title = {Automated Construction of Database Interfaces: Intergrating Statistical and Relational Learning for Semantic Parsing},
220 | booktitle = {2000 Joint SIGDAT Conference on Empirical Methods in Natural Language Processing and Very Large Corpora},
221 | year = {2000},
222 | pages = {133--141},
223 | location = {Hong Kong, China},
224 | url = {http://www.aclweb.org/anthology/W00-1317},
225 | }
226 |
227 | @inproceedings{data-restaurants-original,
228 | author = {Ana-Maria Popescu, Oren Etzioni, and Henry Kautz},
229 | title = {Towards a Theory of Natural Language Interfaces to Databases},
230 | booktitle = {Proceedings of the 8th International Conference on Intelligent User Interfaces},
231 | year = {2003},
232 | location = {Miami, Florida, USA},
233 | pages = {149--157},
234 | url = {http://doi.acm.org/10.1145/604045.604070},
235 | }
236 |
237 | @inproceedings{data-restaurants,
238 | author = {Alessandra Giordani and Alessandro Moschitti},
239 | title = {Automatic Generation and Reranking of SQL-derived Answers to NL Questions},
240 | booktitle = {Proceedings of the Second International Conference on Trustworthy Eternal Systems via Evolving Software, Data and Knowledge},
241 | year = {2012},
242 | location = {Montpellier, France},
243 | pages = {59--76},
244 | url = {https://doi.org/10.1007/978-3-642-45260-4_5},
245 | }
246 |
247 | @InProceedings{data-spider,
248 | author = {Tao Yu, Rui Zhang, Kai Yang, Michihiro Yasunaga, Dongxu Wang, Zifan Li, James Ma, Irene Li, Qingning Yao, Shanelle Roman, Zilin Zhang, and Dragomir Radev},
249 | title = {Spider: A Large-Scale Human-Labeled Dataset for Complex and Cross-Domain Semantic Parsing and Text-to-SQL Task},
250 | booktitle = {Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing},
251 | year = {2018},
252 | location = {Brussels, Belgium},
253 | pages = {3911--3921},
254 | url = {http://aclweb.org/anthology/D18-1425},
255 | }
256 |
257 | @article{data-wikisql,
258 | author = {Victor Zhong, Caiming Xiong, and Richard Socher},
259 | title = {Seq2SQL: Generating Structured Queries from Natural Language using Reinforcement Learning},
260 | year = {2017},
261 | journal = {CoRR},
262 | volume = {abs/1709.00103},
263 | }
264 |
265 | @InProceedings{Criteria-to-SQL,
266 | author = {Yu, Xiaojing and Chen, Tianlong and Yu, Zhengjie and Li, Huiyu and Yang, Yang and Jiang, Xiaoqian and Jiang, Anxiao},
267 | title = {Dataset and Enhanced Model for Eligibility Criteria-to-SQL Semantic Parsing},
268 | booktitle = {Proceedings of The 12th Language Resources and Evaluation Conference},
269 | month = {May},
270 | year = {2020},
271 | address = {Marseille, France},
272 | publisher = {European Language Resources Association},
273 | pages = {5831--5839},
274 | }
275 |
276 | @misc{zhang2023css,
277 | title = {CSS: A Large-scale Cross-schema Chinese Text-to-SQL Medical Dataset},
278 | author = {Hanchong Zhang and Jieyu Li and Lu Chen and Ruisheng Cao and Yunyan Zhang and Yu Huang and Yefeng Zheng and Kai Yu},
279 | year = {2023},
280 | }
281 |
282 | @article{lee2022ehrsql,
283 | title = {EHRSQL: A Practical Text-to-SQL Benchmark for Electronic Health Records},
284 | author = {Lee, Gyubok and Hwang, Hyeonji and Bae, Seongsu and Kwon, Yeonsu and Shin, Woncheol and Yang, Seongjun and Seo, Minjoon and Kim, Jong-Yeup and Choi, Edward},
285 | journal = {Advances in Neural Information Processing Systems},
286 | volume = {35},
287 | pages = {15589--15601},
288 | year = {2022},
289 | }
290 |
291 | @inproceedings{lee-2021-kaggle-dbqa,
292 | title = {KaggleDBQA: Realistic Evaluation of Text-to-SQL Parsers},
293 | author = {Lee, Chia-Hsuan and Polozov, Oleksandr and Richardson, Matthew},
294 | booktitle = {Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (Volume 1: Long Papers)},
295 | pages = {2261--2273},
296 | year = {2021},
297 | }
298 |
299 | @inproceedings{squall,
300 | title = {On the Potential of Lexico-logical Alignments for Semantic Parsing to {SQL} Queries},
301 | author = {Tianze Shi and Chen Zhao and Jordan Boyd-Graber and Hal {Daum\'{e} III} and Lillian Lee},
302 | booktitle = {Findings of EMNLP},
303 | year = {2020},
304 | }
305 |
306 | @article{hazoom2021text,
307 | title = {Text-to-SQL in the wild: a naturally-occurring dataset based on Stack exchange data},
308 | author = {Hazoom, Moshe and Malik, Vibhor and Bogin, Ben},
309 | journal = {arXiv preprint arXiv:2106.05006},
310 | year = {2021},
311 | }
312 |
313 | @inproceedings{wang2020text,
314 | title = {Text-to-SQL Generation for Question Answering on Electronic Medical Records},
315 | author = {Wang, Ping and Shi, Tian and Reddy, Chandan K},
316 | booktitle = {Proceedings of The Web Conference 2020},
317 | pages = {350--361},
318 | year = {2020},
319 | }
320 |
321 | @inproceedings{nvBench_SIGMOD21,
322 | title = {Synthesizing Natural Language to Visualization (NL2VIS) Benchmarks from NL2SQL Benchmarks},
323 | author = {Yuyu Luo and Nan Tang and Guoliang Li and Chengliang Chai and Wenbo Li and Xuedi Qin},
324 | booktitle = {Proceedings of the 2021 International Conference on Management of Data, {SIGMOD} Conference 2021, June 20–25, 2021, Virtual Event, China},
325 | publisher = {ACM},
326 | year = {2021},
327 | }
328 | ```
329 |
330 |
331 | ## Acknowledgement
332 | We are grateful for the work done by all the authors of the datasets that made this project possible.
--------------------------------------------------------------------------------
/data_prep/README.md:
--------------------------------------------------------------------------------
1 | ## Data Preprocessing
2 |
3 | We provide scripts and instructions to create [NSText2SQL](https://huggingface.co/datasets/NumbersStation/NSText2SQL), the dataset used to train [NSQL](https://huggingface.co/NumbersStation/nsql-6B) models. The dataset is saved as jsonl files in the following format:
4 |
5 | ```
6 | {"instruction": ..., "output": "...", "source": "..."}
7 | ```
8 |
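Each line of the generated files is a standalone JSON object, so the output can be streamed record by record. A minimal reading sketch (the `train.jsonl` filename is illustrative; use whichever files the prep script writes under your output directory):

```python
import json

# Stream examples from one of the generated jsonl files.
with open("train.jsonl") as f:
    for line in f:
        example = json.loads(line)
        print(example["source"], "|", example["instruction"][:80])
```
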
9 | #### Data Download
10 |
11 | We use datasets hosted on GitHub, Hugging Face, and other online servers. Download them by running the following commands from the `data_prep` folder:
12 |
13 | ```bash
14 | cd data/
15 | bash download.sh
16 | cd ..
17 | ```
18 |
19 | To obtain the Spider dataset, download it from [here](https://drive.google.com/uc?export=download&id=1TqleXec_OykOYFREKKtschzY29dUcVAQ) and unzip it into the `data` folder.
20 |
21 | #### Data Preparation
22 |
23 | To preprocess the data into our format, run:
24 |
25 | ```bash
26 | python prep_text2sql_data.py \
27 | --datasets academic \
28 | --datasets advising \
29 | --output_dir [OUTPUT_DIR]
30 | ```
31 |
32 | The processed data will be saved into the `[OUTPUT_DIR]/YYYY-MM-DD` folder. The available dataset names (values for `--datasets`) are:
33 | - wikisql
34 | - academic
35 | - advising
36 | - atis
37 | - imdb
38 | - restaurants
39 | - scholar
40 | - yelp
41 | - sede
42 | - eicu
43 | - mimic_iii
44 | - GeoNuclearData
45 | - GreaterManchesterCrime
46 | - Pesticide
47 | - StudentMathScore
48 | - TheHistoryofBaseball
49 | - USWildFires
50 | - WhatCDHipHop
51 | - WorldSoccerDataBase
52 | - mimicsql_data
53 | - criteria2sql
54 | - sql_create_context
55 | - squall
56 | - css
57 | - spider
58 | - nvbench
59 |
60 | For more information, see `prep_text2sql_data.py`.
--------------------------------------------------------------------------------
/data_prep/data/download.sh:
--------------------------------------------------------------------------------
1 | git clone https://github.com/jkkummerfeld/text2sql-data
2 | git clone https://github.com/hirupert/sede
3 | git clone https://github.com/glee4810/EHRSQL
4 | git clone https://github.com/chiahsuan156/KaggleDBQA
5 |
6 | git clone https://github.com/wangpinggl/TREQS
7 | cp tables.json TREQS/mimicsql_data/ # copy table schema to TREQS dataset
8 |
9 | git clone https://github.com/tzshi/squall
10 | git clone https://github.com/xiaojingyu92/Criteria2SQL
11 |
12 | wget https://huggingface.co/datasets/zhanghanchong/css/resolve/main/css.zip
13 | unzip css.zip
14 |
15 | git clone https://github.com/TsinghuaDatabaseGroup/nvBench
16 | unzip nvBench/databases.zip -d nvBench
17 |
--------------------------------------------------------------------------------
/data_prep/data/tables.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "column_names": [
4 | [
5 | -1,
6 | "*"
7 | ],
8 | [
9 | 0,
10 | "subject_id"
11 | ],
12 | [
13 | 0,
14 | "hadm_id"
15 | ],
16 | [
17 | 0,
18 | "name"
19 | ],
20 | [
21 | 0,
22 | "marital_status"
23 | ],
24 | [
25 | 0,
26 | "age"
27 | ],
28 | [
29 | 0,
30 | "dob"
31 | ],
32 | [
33 | 0,
34 | "gender"
35 | ],
36 | [
37 | 0,
38 | "language"
39 | ],
40 | [
41 | 0,
42 | "religion"
43 | ],
44 | [
45 | 0,
46 | "admission_type"
47 | ],
48 | [
49 | 0,
50 | "days_stay"
51 | ],
52 | [
53 | 0,
54 | "insurance"
55 | ],
56 | [
57 | 0,
58 | "ethnicity"
59 | ],
60 | [
61 | 0,
62 | "expire_flag"
63 | ],
64 | [
65 | 0,
66 | "admission_location"
67 | ],
68 | [
69 | 0,
70 | "discharge_location"
71 | ],
72 | [
73 | 0,
74 | "diagnosis"
75 | ],
76 | [
77 | 0,
78 | "dod"
79 | ],
80 | [
81 | 0,
82 | "dob_year"
83 | ],
84 | [
85 | 0,
86 | "dod_year"
87 | ],
88 | [
89 | 0,
90 | "admittime"
91 | ],
92 | [
93 | 0,
94 | "dischtime"
95 | ],
96 | [
97 | 0,
98 | "admityear"
99 | ],
100 | [
101 | 1,
102 | "subject_id"
103 | ],
104 | [
105 | 1,
106 | "hadm_id"
107 | ],
108 | [
109 | 1,
110 | "icd9_code"
111 | ],
112 | [
113 | 1,
114 | "short_title"
115 | ],
116 | [
117 | 1,
118 | "long_title"
119 | ],
120 | [
121 | 2,
122 | "subject_id"
123 | ],
124 | [
125 | 2,
126 | "hadm_id"
127 | ],
128 | [
129 | 2,
130 | "icd9_code"
131 | ],
132 | [
133 | 2,
134 | "short_title"
135 | ],
136 | [
137 | 2,
138 | "long_title"
139 | ],
140 | [
141 | 3,
142 | "subject_id"
143 | ],
144 | [
145 | 3,
146 | "hadm_id"
147 | ],
148 | [
149 | 3,
150 | "icustay_id"
151 | ],
152 | [
153 | 3,
154 | "drug_type"
155 | ],
156 | [
157 | 3,
158 | "drug"
159 | ],
160 | [
161 | 3,
162 | "formulary_drug_cd"
163 | ],
164 | [
165 | 3,
166 | "route"
167 | ],
168 | [
169 | 3,
170 | "drug_dose"
171 | ],
172 | [
173 | 4,
174 | "subject_id"
175 | ],
176 | [
177 | 4,
178 | "hadm_id"
179 | ],
180 | [
181 | 4,
182 | "itemid"
183 | ],
184 | [
185 | 4,
186 | "charttime"
187 | ],
188 | [
189 | 4,
190 | "flag"
191 | ],
192 | [
193 | 4,
194 | "value_unit"
195 | ],
196 | [
197 | 4,
198 | "label"
199 | ],
200 | [
201 | 4,
202 | "fluid"
203 | ],
204 | [
205 | 4,
206 | "category"
207 | ]
208 | ],
209 | "column_names_original": [
210 | [
211 | -1,
212 | "*"
213 | ],
214 | [
215 | 0,
216 | "subject_id"
217 | ],
218 | [
219 | 0,
220 | "hadm_id"
221 | ],
222 | [
223 | 0,
224 | "name"
225 | ],
226 | [
227 | 0,
228 | "marital_status"
229 | ],
230 | [
231 | 0,
232 | "age"
233 | ],
234 | [
235 | 0,
236 | "dob"
237 | ],
238 | [
239 | 0,
240 | "gender"
241 | ],
242 | [
243 | 0,
244 | "language"
245 | ],
246 | [
247 | 0,
248 | "religion"
249 | ],
250 | [
251 | 0,
252 | "admission_type"
253 | ],
254 | [
255 | 0,
256 | "days_stay"
257 | ],
258 | [
259 | 0,
260 | "insurance"
261 | ],
262 | [
263 | 0,
264 | "ethnicity"
265 | ],
266 | [
267 | 0,
268 | "expire_flag"
269 | ],
270 | [
271 | 0,
272 | "admission_location"
273 | ],
274 | [
275 | 0,
276 | "discharge_location"
277 | ],
278 | [
279 | 0,
280 | "diagnosis"
281 | ],
282 | [
283 | 0,
284 | "dod"
285 | ],
286 | [
287 | 0,
288 | "dob_year"
289 | ],
290 | [
291 | 0,
292 | "dod_year"
293 | ],
294 | [
295 | 0,
296 | "admittime"
297 | ],
298 | [
299 | 0,
300 | "dischtime"
301 | ],
302 | [
303 | 0,
304 | "admityear"
305 | ],
306 | [
307 | 1,
308 | "subject_id"
309 | ],
310 | [
311 | 1,
312 | "hadm_id"
313 | ],
314 | [
315 | 1,
316 | "icd9_code"
317 | ],
318 | [
319 | 1,
320 | "short_title"
321 | ],
322 | [
323 | 1,
324 | "long_title"
325 | ],
326 | [
327 | 2,
328 | "subject_id"
329 | ],
330 | [
331 | 2,
332 | "hadm_id"
333 | ],
334 | [
335 | 2,
336 | "icd9_code"
337 | ],
338 | [
339 | 2,
340 | "short_title"
341 | ],
342 | [
343 | 2,
344 | "long_title"
345 | ],
346 | [
347 | 3,
348 | "subject_id"
349 | ],
350 | [
351 | 3,
352 | "hadm_id"
353 | ],
354 | [
355 | 3,
356 | "icustay_id"
357 | ],
358 | [
359 | 3,
360 | "drug_type"
361 | ],
362 | [
363 | 3,
364 | "drug"
365 | ],
366 | [
367 | 3,
368 | "formulary_drug_cd"
369 | ],
370 | [
371 | 3,
372 | "route"
373 | ],
374 | [
375 | 3,
376 | "drug_dose"
377 | ],
378 | [
379 | 4,
380 | "subject_id"
381 | ],
382 | [
383 | 4,
384 | "hadm_id"
385 | ],
386 | [
387 | 4,
388 | "itemid"
389 | ],
390 | [
391 | 4,
392 | "charttime"
393 | ],
394 | [
395 | 4,
396 | "flag"
397 | ],
398 | [
399 | 4,
400 | "value_unit"
401 | ],
402 | [
403 | 4,
404 | "label"
405 | ],
406 | [
407 | 4,
408 | "fluid"
409 | ],
410 | [
411 | 4,
412 | "category"
413 | ]
414 | ],
415 | "column_types": [
416 | "text",
417 | "text",
418 | "text",
419 | "text",
420 | "text",
421 | "text",
422 | "text",
423 | "text",
424 | "text",
425 | "text",
426 | "text",
427 | "text",
428 | "text",
429 | "text",
430 | "text",
431 | "text",
432 | "text",
433 | "text",
434 | "text",
435 | "text",
436 | "text",
437 | "text",
438 | "text",
439 | "text",
440 | "text",
441 | "text",
442 | "text",
443 | "text",
444 | "text",
445 | "text",
446 | "text",
447 | "text",
448 | "text",
449 | "text",
450 | "text",
451 | "text",
452 | "text",
453 | "text",
454 | "text",
455 | "text",
456 | "text",
457 | "text",
458 | "text",
459 | "text",
460 | "text",
461 | "text",
462 | "text",
463 | "text",
464 | "text",
465 | "text"
466 | ],
467 | "db_id": "mimicsql",
468 | "foreign_keys": [],
469 | "primary_keys": [],
470 | "table_names": [
471 | "demographic",
472 | "diagnoses",
473 | "procedures",
474 | "prescriptions",
475 | "lab"
476 | ],
477 | "table_names_original": [
478 | "demographic",
479 | "diagnoses",
480 | "procedures",
481 | "prescriptions",
482 | "lab"
483 | ]
484 | }
485 | ]
--------------------------------------------------------------------------------
/data_prep/data_utils.py:
--------------------------------------------------------------------------------
1 | """Training data prep utils."""
2 | import json
3 | import re
4 | from collections import defaultdict
5 | from typing import Any
6 |
7 | import sqlglot
8 | from kummerfeld_utils import preprocess_for_jsql
9 | from schema import ForeignKey, Table, TableColumn
10 |
11 | _AGG_OPS = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"]
12 | _COND_OPS = ["=", ">", "<", "OP"]
13 |
14 |
15 | def escape_everything(string: str) -> str:
16 | """Escape everything.
17 |
18 | Args:
19 | string: string to escape
20 |
21 | Returns:
22 | Escaped string.
23 | """
24 | return json.dumps(string)[1:-1]
25 |
26 |
27 | def serialize_dict_to_str(d: dict) -> str:
28 | """
29 | Serialize a dict into a str.
30 |
31 | Args:
32 | d: dict to serialize.
33 |
34 | Returns:
35 | serialized dict.
36 | """
37 | return json.dumps(d, sort_keys=True)
38 |
39 |
40 | def read_tables_json(
41 | schema_file: str,
42 | lowercase: bool = False,
43 | ) -> dict[str, dict[str, Table]]:
44 | """Read tables json."""
45 | data = json.load(open(schema_file))
46 | db_to_tables = {}
47 | for db in data:
48 | db_name = db["db_id"]
49 | table_names = db["table_names_original"]
50 | db["column_names_original"] = [
51 | [x[0], x[1]] for x in db["column_names_original"]
52 | ]
53 | db["column_types"] = db["column_types"]
54 | if lowercase:
55 | table_names = [tn.lower() for tn in table_names]
56 | pks = db["primary_keys"]
57 | fks = db["foreign_keys"]
58 | tables = defaultdict(list)
59 | tables_pks = defaultdict(list)
60 | tables_fks = defaultdict(list)
61 | for idx, ((ti, col_name), col_type) in enumerate(
62 | zip(db["column_names_original"], db["column_types"])
63 | ):
64 | if ti == -1:
65 | continue
66 | if lowercase:
67 | col_name = col_name.lower()
68 | col_type = col_type.lower()
69 | if idx in pks:
70 | tables_pks[table_names[ti]].append(
71 | TableColumn(name=col_name, dtype=col_type)
72 | )
73 | for fk in fks:
74 | if idx == fk[0]:
75 | other_column = db["column_names_original"][fk[1]]
76 | other_column_type = db["column_types"][fk[1]]
77 | other_table = table_names[other_column[0]]
78 | tables_fks[table_names[ti]].append(
79 | ForeignKey(
80 | column=TableColumn(name=col_name, dtype=col_type),
81 | references_name=other_table,
82 | references_column=TableColumn(
83 | name=other_column[1], dtype=other_column_type
84 | ),
85 | )
86 | )
87 | tables[table_names[ti]].append(TableColumn(name=col_name, dtype=col_type))
88 | db_to_tables[db_name] = {
89 | table_name: Table(
90 | name=table_name,
91 | columns=tables[table_name],
92 | pks=tables_pks[table_name],
93 | fks=tables_fks[table_name],
94 | )
95 | for table_name in tables
96 | }
97 | return db_to_tables
98 |
99 |
100 | def clean_str(target: str) -> str:
101 | """Clean string for question."""
102 | if not target:
103 | return target
104 |
105 | target = re.sub(r"[^\x00-\x7f]", r" ", target)
106 | line = re.sub(r"''", r" ", target)
107 | line = re.sub(r"``", r" ", line)
108 | line = re.sub(r"\"", r"'", line)
109 | line = re.sub(r"[\t ]+", " ", line)
110 | return line.strip()
111 |
112 |
113 | def case_sql(query: str) -> str:
114 | """Case sql query."""
115 | try:
116 | cased_sql = sqlglot.parse_one(query).sql() # type: ignore
117 | # SQLGlot makes "NOT <col> IN". We want "<col> NOT IN" for Spider
118 | cased_sql = re.sub(r"NOT\s+([^\s]+)\s+IN", r"\1 NOT IN", cased_sql)
119 | # Replace <> with !=
120 | cased_sql = cased_sql.replace("<>", "!=")
121 | return cased_sql
122 | except Exception:
123 | print("Cannot CASE this SQL")
124 | return query
125 |
126 |
127 | def crude_remove_aliases(sql: str) -> str:
128 | """Cruder way of cleaning up aliases."""
129 | alias2cleanalias = {}
130 | new_sql = re.sub(r"[\t\s\n]+", " ", sql)
131 | for word in sql.split():
132 | if "." in word:
133 | alias = word.split(".")[0]
134 | if "alias" in alias:
135 | clean_alias = alias.split("alias")[0] + "_" + alias.split("alias")[1]
136 | alias2cleanalias[alias] = clean_alias
137 | for alias, clean_alias in alias2cleanalias.items():
138 | new_sql = new_sql.replace(alias, clean_alias)
139 | return new_sql
140 |
141 |
142 | def remove_aliases(sql: str) -> str:
143 | """Remove aliases from SQL."""
144 | new_sql = re.sub(r"[\t\s\n]+", " ", sql)
145 | # Handle from
146 | alias2table = {}
147 | table2alias: dict[str, list[str]] = {}
148 | # Get substring from FROM to WHERE or to GROUP BY or to end
149 | inside_from = re.search(
150 | r"FROM (.*?) (WHERE|GROUP BY|ORDER BY|LIMIT|;)", new_sql, re.DOTALL
151 | )
152 | if not inside_from:
153 | inside_from = re.search(r"FROM (.*?)$", new_sql, re.DOTALL)
154 | if not inside_from:
155 | print("BAD FROM", sql)
156 | for from_clause in re.split(
157 | r",| INNER JOIN| OUTER JOIN| LEFT JOIN| RIGHT JOIN| JOIN| EXCEPT",
158 | inside_from.group(1), # type: ignore
159 | ):
160 | # If JOIN table ON XXX, remove the ON XXX
161 | if " ON " in from_clause:
162 | from_clause = from_clause.split(" ON ")[0]
163 | if " AS " in from_clause:
164 | table = from_clause.split(" AS ")[0].strip()
165 | alias = from_clause.split(" AS ")[1].strip()
166 | alias2table[alias] = table
167 | # If we have two of the same tables in the from clause
168 | # must keep and handle aliases differently
169 | if table in table2alias:
170 | # If only one already in, start creating new aliases
171 | if len(table2alias[table]) == 1:
172 | old_alias = table2alias[table][0]
173 | table2alias[table] = [f"{table}_{len(table2alias[table])-1}"]
174 | alias2table[old_alias] = table2alias[table][-1]
175 | table2alias[table].append(f"{table}_{len(table2alias[table])}")
176 | alias2table[alias] = table2alias[table][-1]
177 | else:
178 | table2alias[table] = [alias]
179 | # Now replace AS alias in from clauses where we can
180 | for from_clause in re.split(
181 | r",| INNER JOIN| OUTER JOIN| LEFT JOIN| RIGHT JOIN| JOIN| EXCEPT",
182 | inside_from.group(1), # type: ignore
183 | ):
184 | if " ON " in from_clause:
185 | from_clause = from_clause.split(" ON ")[0]
186 | if " AS " in from_clause:
187 | table = from_clause.split(" AS ")[0].strip()
188 | alias = from_clause.split(" AS ")[1].strip()
189 | if len(table2alias[table]) == 1:
190 | new_sql = new_sql.replace(from_clause, " " + table)
191 |
192 | # Replace old aliases with new ones (or og table name)
193 | for al, table in alias2table.items():
194 | new_sql = new_sql.replace(al, table)
195 |
196 | # Replace table references as not needed with one table
197 | if len(alias2table) == 1:
198 | table = list(alias2table.values())[0]
199 | new_sql = new_sql.replace(table + ".", "")
200 |
201 | new_sql = re.sub(r"[\t\s\n]+", " ", new_sql)
202 | return new_sql
203 |
204 |
205 | def get_table_alias_to_ref_map(sql: str) -> dict[str, set[str]]:
206 | """Get all aliases and the reference tables they point to.
207 |
208 | Key of None will be all unaliased tables.
209 |
210 | This accounts for both table AS T1 clauses and subexpressions.
211 | """
212 | try:
213 | parsed: sqlglot.expressions.Expression = sqlglot.parse_one(sql, read="sqlite")
214 | except Exception:
215 | return defaultdict(set)
216 | # Get all table aliases - including CTEs
217 | mapping = defaultdict(set)
218 | all_table_aliases = list(parsed.find_all(sqlglot.exp.TableAlias))
219 | for tbl_alias in all_table_aliases:
220 | sql_parent = tbl_alias.parent
221 | if sql_parent:
222 | tbls = [
223 | table.name
224 | for table in sql_parent.find_all(sqlglot.exp.Table)
225 | if table.name != tbl_alias.name
226 | ]
227 | if tbls:
228 | mapping[tbl_alias.name].update(tbls)
229 | # Add any table without alias
230 | for table in parsed.find_all(sqlglot.exp.Table):
231 | if not table.alias or table.alias == table.name:
232 | mapping[None].add(table.name)
233 | return mapping
234 |
235 |
236 | def format_to_match_schema(
237 | sql: str,
238 | schema: dict[str, Table],
239 | ) -> str:
240 | """Format the tables and columns in the query to match the schema."""
241 | table_alias_to_ref = get_table_alias_to_ref_map(sql)
242 | all_tables = set().union(*table_alias_to_ref.values())
243 | all_tables_lower = {t.lower() for t in all_tables}
244 | tablename2colset = {
245 | tbl.name: set([c.name for c in tbl.columns])
246 | for tbl in schema.values()
247 | if tbl.name.lower() in all_tables_lower
248 | }
249 |
250 | def transformer(node: sqlglot.Expression) -> sqlglot.Expression:
251 | if isinstance(node, sqlglot.exp.Column):
252 | for tbl in tablename2colset:
253 | for col in tablename2colset[tbl]:
254 | # Due to table aliases, we don't want to make this a joint
255 | # condition on the column and alias
256 | if node.table and node.table.lower() == tbl.lower():
257 | node.args["table"] = tbl
258 | if node.name.lower() == col.lower():
259 | node.args["this"] = col
260 | break
261 | elif isinstance(node, sqlglot.exp.Table):
262 | for tbl in tablename2colset:
263 | if node.name.lower() == tbl.lower():
264 | node.args["this"] = tbl
265 | break
266 | return node
267 |
268 | parsed: sqlglot.expressions.Expression = sqlglot.parse_one(sql, read="sqlite")
269 | transformed_parsed = parsed.transform(transformer)
270 | return transformed_parsed.sql()
271 |
272 |
273 | def convert_kummerfeld_instance(
274 | data: dict[str, Any],
275 | schema: dict[str, dict[str, Table]] = {},
276 | keep_vars: bool = False,
277 | keep_sql_vars: bool = False,
278 | ) -> list[dict[str, Any]]:
279 | """Convert a single instance of the data into a list of examples.
280 |
281 | Used for the text2sql-data repo from jkkummerfeld.
282 | """
283 | # Use the first canonical SQL query for all sentences in this instance
284 | var_sql = data["sql"][0]
285 | parsed_results: list[dict[str, Any]] = []
286 | for sentence in data["sentences"]:
287 | text = sentence["text"]
288 | sql = preprocess_for_jsql(
289 | var_sql
290 | ) # Needed to do variable replacement correctly
291 | if not sql:
292 | raise ValueError(f"No SQL for sentence {sentence}")
293 | sql = str(sql)
294 | cleaned_sql = remove_aliases(sql)
295 | cleaned_sql = case_sql(cleaned_sql)
296 | crude_cleaned_sql = crude_remove_aliases(sql)
297 | crude_cleaned_sql = case_sql(crude_cleaned_sql)
298 | # Variable replacement
299 | if not keep_vars:
300 | for name in sentence["variables"]:
301 | value = sentence["variables"][name]
302 | if len(value) == 0:
303 | for variable in data["variables"]:
304 | if variable["name"] == name:
305 | value = variable["example"]
306 | text = value.join(text.split(name))
307 | if not keep_sql_vars:
308 | cleaned_sql = value.join(cleaned_sql.split(name))
309 | crude_cleaned_sql = value.join(crude_cleaned_sql.split(name))
310 | sql = value.join(sql.split(name)) # type: ignore
311 |
312 | # Query split is either train/dev/test or 0-9 for cross validation
313 | # We use test/0 for test, dev/1 for dev and the rest for train
314 | if data["query-split"] == "N/A":
315 | # No query split given; fall back to the per-sentence question-split (train/dev/test)
316 | output_file = sentence["question-split"]
317 | else:
318 | if data["query-split"] == "test" or data["query-split"] == "0":
319 | output_file = "test"
320 | elif data["query-split"] == "dev" or data["query-split"] == "1":
321 | output_file = "dev"
322 | else:
323 | output_file = "train"
324 |
325 | db_id = sentence.get("database", "database").lower()
326 | try:
327 | cleaned_sql = format_to_match_schema(cleaned_sql, schema[db_id])
328 | except Exception:
329 | print("ERROR")
330 | continue
331 | parsed_results.append(
332 | {
333 | "question": text,
334 | "sql": cleaned_sql,
335 | "split": output_file,
336 | "db_id": db_id,
337 | }
338 | )
339 |
340 | return parsed_results
341 |
342 |
343 | def convert_sede_instance(
344 | data: dict[str, str],
345 | schema: dict[str, dict[str, Table]] = {},
346 | ) -> dict[str, Any]:
347 | """Convert a single instance of the data into an example.
348 |
349 | Used for the sede dataset.
350 | """
351 | # clean title and description
352 | cleaned_title = clean_str(data["Title"])
353 | cleaned_description = clean_str(data["Description"])
354 |
355 | # clean SQL query
356 | cleaned_sql = None
357 | # cleaned_sql_with_values = None
358 | if data["QueryBody"]:
359 | target_with_values = str(preprocess_for_jsql(data["QueryBody"]))
360 | target_with_values = case_sql(target_with_values)
361 | if target_with_values:
362 | target_tokens = target_with_values.strip(";").split()
363 | target = case_sql(" ".join(target_tokens))
364 | cleaned_sql = target
365 | # Handle With statement by removing WITH part before SELECT
366 | # Example:
367 | # WITH no activity in last 6 months SELECT * FROM TABLE
368 | # -->
369 | # SELECT * FROM TABLE
370 | index_s = cleaned_sql.lower().find("select")
371 | index_w = cleaned_sql.lower().find("with")
372 | if index_w < index_s and index_w == 0:
373 | prefix = re.sub(r"\s\s+", " ", cleaned_sql[index_w:index_s][4:].strip())
374 | # Ignore the valid CTE: With a AS(...)
375 | # Don't want to skip With a AS ...
376 | # since no () means it won't use the defined var in the SQL
377 | if not (
378 | prefix.lower().endswith(" as (") or prefix.lower().endswith(" as(")
379 | ):
380 | print("ORI:", cleaned_sql)
381 | print("NEW:", case_sql(cleaned_sql[index_s:]))
382 | cleaned_sql = case_sql(cleaned_sql[index_s:])
383 |
384 | # Try to convert from TSQL to SQLite
385 | try:
386 | cleaned_sql = sqlglot.transpile(cleaned_sql, read="tsql", write="sqlite")[0]
387 | except Exception:
388 | pass
389 |
390 | if cleaned_title and cleaned_sql:
391 | try:
392 | cleaned_sql = format_to_match_schema(cleaned_sql, schema["stackexchange"])
393 | cleaned_sql = case_sql(cleaned_sql)
394 | except Exception:
395 | print("ERROR:::", cleaned_sql)
396 | cleaned_sql = None
397 | if cleaned_sql:
398 | preprocessed_annotated_sql = {
399 | "question": (
400 | cleaned_title.strip() + ". " + (cleaned_description or "").strip()
401 | ).strip(),
402 | "db_id": "stackexchange",
403 | "sql": cleaned_sql,
404 | }
405 | else:
406 | preprocessed_annotated_sql = {}
407 | else:
408 | preprocessed_annotated_sql = {}
409 |
410 | return preprocessed_annotated_sql
411 |
412 |
413 | def convert_spider_instance(
414 | data: dict[str, str],
415 | schema: dict[str, dict[str, Table]] = {},
416 | ) -> dict[str, Any]:
417 | """Convert a single instance of the data into an example.
418 |
419 | Used for the spider dataset.
420 | """
421 | query = data["query"]
422 | question = data["question"]
423 | db_id = data["db_id"]
424 | target = case_sql(query)
425 | target = format_to_match_schema(target, schema[db_id])
426 | sql = {
427 | "question": question,
428 | "db_id": db_id,
429 | "sql": target,
430 | }
431 | # Check if example is impossible to answer
432 | if not data.get("is_impossible", False):
433 | return sql
434 | return {}
435 |
436 |
437 | def convert_wikisql_instance(
438 | data: dict[str, Any],
439 | schema: dict[str, dict[str, Table]] = {},
440 | ) -> dict[str, Any]:
441 | """Convert a single instance of the data into an example.
442 |
443 | Used for the wikisql dataset.
444 | """
445 |
446 | def _convert_to_human_readable(
447 | table_name: str,
448 | sel: int,
449 | agg: int,
450 | columns: list[str],
451 | conditions: list[tuple[int, int, str]],
452 | ) -> str:
453 | """Make SQL query string. Based on https://github.com/salesforce/WikiSQL/blob/c2ed4f9b22db1cc2721805d53e6e76e07e2ccbdc/lib/query.py#L10""" # noqa: E501
454 | strip_quotes = lambda x: x.strip('"').strip("'").replace("'", "''")
455 | quoted_columns = [f'"{escape_everything(c)}"' for c in columns]
456 | if _AGG_OPS[agg] == "":
457 | rep = f"SELECT {quoted_columns[sel] if quoted_columns is not None else f'col{sel}'} FROM {table_name}" # noqa: E501
458 | else:
459 | rep = f"SELECT {_AGG_OPS[agg]}({quoted_columns[sel] if quoted_columns is not None else f'col{sel}'}) FROM {table_name}" # noqa: E501
460 |
461 | if conditions:
462 | rep += " WHERE " + " AND ".join(
463 | [
464 | f"{quoted_columns[i]} {_COND_OPS[o]} '{strip_quotes(v)}'"
465 | for i, o, v in conditions
466 | ]
467 | )
468 | return " ".join(rep.split())
469 |
470 | conds = data["sql"]["conds"]
471 | iov_list = list(
472 | zip(conds["column_index"], conds["operator_index"], conds["condition"])
473 | )
474 | query = _convert_to_human_readable(
475 | data["table"]["name"],
476 | data["sql"]["sel"],
477 | data["sql"]["agg"],
478 | data["table"]["header"],
479 | iov_list,
480 | )
481 | question = data["question"]
482 | db_id = data["table"]["name"]
483 | target = case_sql(query)
484 |
485 | try:
486 | target = format_to_match_schema(target, schema[db_id])
487 | except Exception as e:
488 | print("ERROR:::")
489 | print(target)
490 | print(e)
491 | return {}
492 | sql = {
493 | "question": question,
494 | "db_id": db_id,
495 | "sql": target,
496 | }
497 | return sql
498 |
499 |
500 | def convert_criteria2sql_instance(
501 | data: dict[str, Any],
502 | schema: dict[str, dict[str, Table]] = {},
503 | ) -> dict[str, Any]:
504 | """Convert a single instance of the data into an example.
505 |
506 | Modified from the criteria2sql dataset.
507 | """
508 | # We want to use the 'real' table name and all columns in the query
509 | assert data["query"].startswith("select id from records")
510 | query = data["query"]
511 | query = query.replace("select id from records", f"select * from {data['db_id']}")
512 | question = data["question"]
513 | db_id = data["db_id"]
514 | target = case_sql(query)
515 |
516 | try:
517 | target = format_to_match_schema(target, schema[db_id])
518 | except Exception as e:
519 | print("ERROR:::")
520 | print(target)
521 | print(e)
522 | return {}
523 |
524 | sql = {
525 | "question": question,
526 | "db_id": db_id,
527 | "sql": target,
528 | }
529 | return sql
530 |
531 |
532 | def convert_sql_create_context_instance(
533 | data: dict[str, Any],
534 | schema: dict[str, dict[str, Table]] = {},
535 | ) -> dict[str, Any]:
536 | """Convert a single instance of the data into an example."""
537 | query = data["answer"]
538 | question = data["question"]
539 | db_id = data["db_id"]
540 | target = case_sql(query)
541 |
542 | try:
543 | target = format_to_match_schema(target, schema[db_id])
544 | except Exception as e:
545 | print("ERROR:::")
546 | print(target)
547 | print(e)
548 | return {}
549 |
550 | sql = {
551 | "question": question,
552 | "db_id": db_id,
553 | "sql": target,
554 | }
555 | return sql
556 |
557 |
558 | def convert_squall_instance(
559 | data: dict[str, Any],
560 | schema: dict[str, dict[str, Table]] = {},
561 | ) -> dict[str, Any]:
562 | """Convert a single instance of the data into an example."""
563 | db_id = data["db_id"]
564 | question = " ".join(data["nl"])
565 | sql_toks = []
566 | for tok in data["sql"]:
567 | if tok[0] in ["Literal.Number", "Literal.String"]:
568 | sql_toks.append(tok[1])
569 | elif tok[0] == "Keyword":
570 | sql_toks.append(tok[1] if tok[1] != "w" else db_id)
571 | else:
572 | if "_" in tok[1]:
573 | idx = int(tok[1][1 : tok[1].find("_")])
574 | else:
575 | idx = int(tok[1][1])
576 | sql_toks.append(schema[db_id][db_id].columns[idx].name)
577 | query = " ".join(sql_toks)
578 | # Fix not null error
579 | query = query.replace("not null", "is not null")
580 | target = case_sql(query)
581 |
582 | try:
583 | target = format_to_match_schema(target, schema[db_id])
584 | except Exception as e:
585 | print("ERROR:::")
586 | print(target)
587 | print(e)
588 | return {}
589 |
590 | sql = {
591 | "question": question,
592 | "db_id": db_id,
593 | "sql": target,
594 | }
595 | return sql
596 |
597 |
598 | def convert_css_nvbench_instance(
599 | data: dict[str, Any],
600 | schema: dict[str, dict[str, Table]] = {},
601 | ) -> dict[str, Any]:
602 | """Convert a single instance of the data into an example."""
603 | db_id = data["db_id"]
604 | question = data["question"]
605 | query = data["query"]
606 | target = case_sql(query)
607 | try:
608 | target = format_to_match_schema(target, schema[db_id])
609 | except Exception as e:
610 | print("ERROR:::")
611 | print(target)
612 | print(e)
613 | return {}
614 |
615 | sql = {
616 | "question": question,
617 | "db_id": db_id,
618 | "sql": target,
619 | }
620 | return sql
621 |
--------------------------------------------------------------------------------
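The converters above all reduce a raw dataset record to the same minimal example shape. A minimal sketch of that contract (the record values are invented, and the call is left commented out because it needs a populated `schema` mapping plus the `case_sql`/`format_to_match_schema` helpers defined earlier in this file):

raw_example = {
    "query": "select name from singer where age > 20",
    "question": "What are the names of singers older than 20?",
    "db_id": "concert_singer",
}
# Hypothetical call; `schema` must map "concert_singer" to its {table_name: Table} dict.
# example = convert_spider_instance(raw_example, schema=schema)
# On success the converters return {"question": ..., "db_id": ..., "sql": ...};
# on a failed parse/format they return {} and the example is dropped upstream.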
/data_prep/kummerfeld_utils.py:
--------------------------------------------------------------------------------
1 | """SQL processing utils.
2 |
3 | Adapted from https://github.com/hirupert/sede
4 | """
5 | import re
6 |
7 | ALIAS_PATTERN = re.compile(r"\[([^\]]+)]", re.MULTILINE | re.IGNORECASE)
8 | TAGS_PATTERN = re.compile(
9 | r"([^'%])(##[a-z0-9_?:]+##)([^'%]?)", re.MULTILINE | re.IGNORECASE
10 | )
11 | TOP_TAGS_PATTERN = re.compile(
12 | r"(top|percentile_cont)([ ]+)?[\(]?[ ]?(##[a-z0-9_]+(:[a-z]+)?(\?([0-9.]+))?##)[ ]?[\)]?",
13 | re.IGNORECASE,
14 | )
15 | SQL_TOKENS = {
16 | "select",
17 | "from",
18 | "where",
19 | "group",
20 | "order",
21 | "limit",
22 | "intersect",
23 | "union",
24 | "except",
25 | "join",
26 | "on",
27 | "as",
28 | "not",
29 | "between",
30 | "=",
31 | ">",
32 | "<",
33 | ">=",
34 | "<=",
35 | "!=",
36 | "in",
37 | "like",
38 | "is",
39 | "exists",
40 | "none",
41 | "max",
42 | "min",
43 | "count",
44 | "sum",
45 | "avg",
46 | "or",
47 | "and",
48 | }
49 |
50 |
51 | def _remove_comment_at_beginning(cleaned_query: str) -> str:
52 | """Remove comments at the beginning of the line."""
53 | return re.sub(r"^([- ]+|(result))+", "", cleaned_query, flags=re.MULTILINE)
54 |
55 |
56 | def remove_comments(sql: str) -> str:
57 | """Remove comments from sql."""
58 | # remove comments at the beginning of line
59 | sql = _remove_comment_at_beginning(sql)
60 |
61 | # remove comments at the end of lines
62 | sql = re.sub(r"--(.+)?\n", "", sql)
63 |
64 | # drop semicolons that sit alone on a line
65 | sql = re.sub(r"\n;\n", " ", sql)
66 |
67 | sql = re.sub(" +", " ", sql)
68 |
69 | return sql.strip()
70 |
71 |
72 | def remove_comments_after_removing_new_lines(sql: str) -> str:
73 | """Remove comments and newlines from sql."""
74 | # remove comments at the end of the query
75 | sql = re.sub(r"--(.?)+$", "", sql, flags=re.MULTILINE)
76 |
77 | # remove comments like /* a comment */
78 | sql = re.sub(r"/\*[^*/]+\*/", "", sql, flags=re.MULTILINE)
79 |
80 | sql = re.sub(" +", " ", sql)
81 |
82 | return sql.strip()
83 |
84 |
85 | def _surrounded_by_apostrophes(sql: str, start_index: int, end_index: int) -> bool:
86 | """Check if the string is surrounded by apostrophes."""
87 | max_steps = 10
88 |
89 | starts_with_apostrophe = False
90 | step_count = 0
91 | while start_index >= 0 and step_count < max_steps:
92 | if sql[start_index] == "'":
93 | starts_with_apostrophe = True
94 | break
95 | if sql[start_index] == " ":
96 | starts_with_apostrophe = False
97 | break
98 | start_index -= 1
99 | step_count += 1
100 |
101 | end_with_apostrophe = False
102 | step_count = 0
103 | while end_index < len(sql) and step_count < max_steps:
104 | if sql[end_index] == "'":
105 | end_with_apostrophe = True
106 | break
107 | if sql[end_index] == " ":
108 | end_with_apostrophe = False
109 | break
110 | end_index += 1
111 | step_count += 1
112 |
113 | return starts_with_apostrophe and end_with_apostrophe
114 |
115 |
116 | # pylint: disable=too-many-branches
117 | def preprocess_for_jsql(sql: str) -> str | None:
118 | """Preprocess sql for jsql."""
119 | # replace aliases like "as [User Id]" with "as 'user_id'"
120 | match = re.search(ALIAS_PATTERN, sql)
121 | while match is not None:
122 | group_one = match.group(1)
123 | if not _surrounded_by_apostrophes(sql, match.start(), match.end()):
124 | new_alias = f"'{group_one.lower()}'"
125 | else:
126 | new_alias = group_one.lower()
127 |
128 | if " " in new_alias:
129 | new_alias = new_alias.replace(" ", "_")
130 | sql = sql.replace(match.group(0), new_alias)
131 | match = re.search(ALIAS_PATTERN, sql)
132 |
133 | # replace all parameters like "TOP ##topn:int?200##" to "TOP 200"
134 | match = re.search(TOP_TAGS_PATTERN, sql)
135 | while match is not None:
136 | group_zero = match.group(0)
137 | default_number = match.group(6)
138 |
139 | if default_number is not None:
140 | new_alias = f"{match.group(1)} ({default_number})"
141 | else:
142 | new_alias = f"{match.group(1)} (100)"
143 |
144 | sql = sql.replace(group_zero, new_alias)
145 | match = re.search(TOP_TAGS_PATTERN, sql)
146 |
147 | # replace all parameters like ##tagName:Java## to '##tagName:Java##'
148 | new_sql = ""
149 | match = re.search(TAGS_PATTERN, sql)
150 | while match is not None:
151 | group_two = match.group(2)
152 |
153 | if not _surrounded_by_apostrophes(sql, match.start(), match.end()):
154 | new_alias = f"{match.group(1)}'{group_two}'{match.group(3)}"
155 | new_sql = new_sql + sql[0 : match.start()] + new_alias
156 | else:
157 | new_sql = new_sql + sql[0 : match.start()] + match.group(0)
158 |
159 | sql = sql[match.end() :]
160 | match = re.search(TAGS_PATTERN, sql)
161 | if sql:
162 | new_sql = new_sql + sql
163 | sql = new_sql
164 |
165 | # convert FORMAT function to CONVERT function to support JSQL
166 | sql = re.sub(r" format\(", " convert(", sql, flags=re.IGNORECASE)
167 |
168 | # remove comments from SQL
169 | sql = remove_comments(sql)
170 |
171 | # replace N'%Kitchener%' with '%Kitchener%'
172 | sql = re.sub(r" N'", " '", sql, flags=re.IGNORECASE)
173 |
174 | # remove declares with a new line
175 | sql = re.sub(
176 | r"(DECLARE|declare) [^\n]+\n",
177 | " ",
178 | sql,
179 | flags=re.IGNORECASE | re.MULTILINE | re.DOTALL,
180 | )
181 |
182 | # remove new lines
183 | sql = re.sub(r"[\n\t\r]+", " ", sql)
184 |
185 | sql = remove_comments_after_removing_new_lines(sql)
186 |
187 | # remove declares
188 | sql = re.sub(
189 | r"(DECLARE|declare) [^;]+;", " ", sql, flags=re.IGNORECASE | re.MULTILINE | re.DOTALL
190 | )
191 | sql = re.sub(r"(DECLARE|declare) (?:.(?!(SELECT|select)))", "SELECT", sql)
192 |
193 | if "))))))))))))))))))))" in sql or "((((((((((((((((((((" in sql:
194 | return None
195 |
196 | if "cast(avg(cast(avg(cast(avg(cast(avg(cast(avg(cast(avg(cast(avg(" in sql:
197 | return None
198 |
199 | sql = re.sub(r"[^\x00-\x7f]", r" ", sql)
200 | sql = re.sub(r"``", r"'", sql)
201 | sql = re.sub(r"\"", r"'", sql)
202 | sql = re.sub(r" +", " ", sql).strip()
203 |
204 | if not sql:
205 | return None
206 |
207 | if sql[-1] == ";":
208 | sql = sql[0:-1]
209 |
210 | if ";" in sql:
211 | sql = sql.split(";")[-1]
212 |
213 | return sql
214 |
--------------------------------------------------------------------------------
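A small usage sketch for `preprocess_for_jsql` (the query string is invented, and the import assumes the code runs from the `data_prep` directory):

from kummerfeld_utils import preprocess_for_jsql

raw = "SELECT TOP(##topn:int?200##) [User Id] FROM Users -- top users"
cleaned = preprocess_for_jsql(raw)
# Bracketed aliases become quoted snake_case, the TOP parameter collapses to its
# default value, and the trailing comment is stripped, so `cleaned` should come
# out as: SELECT TOP (200) 'user_id' FROM Users
print(cleaned)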
/data_prep/prep_text2sql_data.py:
--------------------------------------------------------------------------------
1 | """Prepare data for NSText2SQL."""
2 |
3 | import hashlib
4 | import json
5 | import multiprocessing
6 | import os
7 | import random
8 | from collections import defaultdict
9 | from datetime import datetime
10 | from functools import partial
11 | from pathlib import Path
12 |
13 | import click
14 | import numpy as np
15 | import yaml
16 | from prompt_formatters import RajkumarFormatter
17 | from rich.console import Console
18 | from text2sql_dataset import (
19 | CSS2SQL,
20 | NVBENCH2SQL,
21 | Criteria2SQL2SQL,
22 | KummerfeldText2SQL,
23 | MimicsqlText2SQL,
24 | SedeText2SQL,
25 | SpiderText2SQL,
26 | SqlCreateContext2SQL,
27 | Squall2SQL,
28 | Text2SQLData,
29 | Text2SQLDataset,
30 | WikiSQL2SQL,
31 | )
32 | from tqdm.auto import tqdm
33 | from transformers import AutoTokenizer
34 |
35 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
36 |
37 | console = Console(soft_wrap=True)
38 |
39 |
40 | TEXT2SQL_DATA_LOADERS = {
41 | "kummerfeld": KummerfeldText2SQL,
42 | "sede": SedeText2SQL,
43 | "spider": SpiderText2SQL,
44 | "wikisql": WikiSQL2SQL,
45 | "mimicsql": MimicsqlText2SQL,
46 | "criteria2sql": Criteria2SQL2SQL,
47 | "sql_create_context2sql": SqlCreateContext2SQL,
48 | "squall": Squall2SQL,
49 | "css": CSS2SQL,
50 | "nvbench": NVBENCH2SQL,
51 | }
52 |
53 |
54 | def process_dataset(
55 | prompt_formatter: RajkumarFormatter,
56 | splits: dict[str, list[Text2SQLData]],
57 | bad_parses: dict[str, int],
58 | total: dict[str, int],
59 | text2sql_dataset: Text2SQLDataset,
60 | hash_experiment_key: str,
61 | ) -> None:
62 | """Process a dataset and add it to the splits."""
63 | schema = text2sql_dataset.load_schema()
64 | temp_outfile = f"_temp_text2sql/prep_data/temp_{hash_experiment_key}.jsonl"
65 | Path(temp_outfile).parent.mkdir(parents=True, exist_ok=True)
66 | if os.path.exists(temp_outfile):
67 | console.print(f"Reading from {temp_outfile}")
68 | with open(temp_outfile, "r") as in_f:
69 | loaded_data = json.load(in_f)
70 | else:
71 | loaded_data = text2sql_dataset.load_data(schema)
72 | console.print(f"Saving to {temp_outfile}")
73 | with open(temp_outfile, "w") as out_f:
74 | json.dump(loaded_data, out_f)
75 |
76 | formatting_func = partial(
77 | text2sql_dataset.format_example,
78 | schema=schema,
79 | prompt_formatter=prompt_formatter,
80 | )
81 | cnt = 0
82 | for split, split_data in loaded_data.items():
83 | console.print(f"Found {len(split_data)} examples for {split}.")
84 | pool = multiprocessing.Pool(
85 | processes=15,
86 | )
87 | for formatted_data in tqdm(
88 | pool.imap(formatting_func, split_data, chunksize=100),
89 | total=len(split_data),
90 | desc=f"Formatting {split}",
91 | ):
92 | total[split] += 1
93 | cnt += 1
94 | if formatted_data:
95 | formatted_as_traindata = Text2SQLData(**formatted_data)
96 | splits[split].append(formatted_as_traindata)
97 | if total[split] <= 20 or cnt <= 20:
98 | console.print(f"\n***[yellow]Example {total[split]}[/yellow]***")
99 | console.print(
100 | json.dumps(
101 | formatted_as_traindata.dict(), indent=2, ensure_ascii=False
102 | )
103 | )
104 | else:
105 | bad_parses[split] += 1
106 | pool.close()
107 | pool.join()
108 | console.print(f"Bad parses: {json.dumps(bad_parses, indent=2, ensure_ascii=False)}")
109 |
110 |
111 | @click.command()
112 | @click.option("--datasets", type=str, required=True, multiple=True)
113 | @click.option(
114 | "--config_path",
115 | type=str,
116 | default=(f"{os.path.join(os.path.dirname(__file__))}/text2sql_data_config.yaml"),
117 | )
118 | @click.option("--output_dir", type=str, default="")
119 | @click.option("--seed", type=int, default=0)
120 | @click.option("--tokenizer_name", type=str, default="Salesforce/codegen-2B-multi")
121 | @click.option("--seq_length", type=int, default=2048)
122 | @click.option("--merge_dev", type=bool, default=False, is_flag=True)
123 | @click.option("--merge_test", type=bool, default=False, is_flag=True)
124 | def build(
125 | datasets: list[str],
126 | config_path: str,
127 | output_dir: str,
128 | seed: int,
129 | tokenizer_name: str,
130 | seq_length: int,
131 | merge_dev: bool,
132 | merge_test: bool,
133 | ) -> None:
134 | """Build training data for text2SQL model training.
135 |
136 | Args:
137 | datasets: the datasets to read - matches on name in config
138 | config_path: path to config
139 | output_dir: output directory
140 | seed: the random seed
141 | tokenizer_name: the tokenizer to use
142 | seq_length: max seq_length for training data
143 | merge_dev: whether to merge dev into train
144 | merge_test: whether to merge test into train
145 | """
146 | to_save_args = locals()
147 | random.seed(seed)
148 | np.random.seed(seed)
149 | try:
150 | multiprocessing.set_start_method("spawn")
151 | except RuntimeError:
152 | pass
153 |
154 | prompt_formatter = RajkumarFormatter()
155 |
156 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
157 |
158 | config = yaml.safe_load(open(config_path))
159 | to_save_args["config"] = config
160 |
161 | data_configs = {}
162 | for data_loader in config["datasets"]:
163 | data_configs[data_loader["name"]] = data_loader
164 | assert len(data_configs) == len(
165 | config["datasets"]
166 | ), f"Overloaded name in {config_path}"
167 |
168 | datasets = [dataset.lower() for dataset in datasets]
169 | for dataset in datasets:
170 | if dataset not in data_configs:
171 | raise ValueError(f"Dataset {dataset} not supported.")
172 | if data_configs[dataset]["loader"] not in TEXT2SQL_DATA_LOADERS:
173 | raise ValueError(f"Loader {data_configs[dataset]['loader']} not supported.")
174 |
175 | data_classes: list[Text2SQLDataset] = [
176 | TEXT2SQL_DATA_LOADERS[data_configs[dataset]["loader"]]( # type: ignore
177 | **data_configs[dataset],
178 | context_length=seq_length,
179 | tokenizer_name=tokenizer_name,
180 | )
181 | for dataset in datasets
182 | ]
183 |
184 | splits: dict[str, list[Text2SQLData]] = {"train": [], "dev": [], "test": []}
185 | bad_parses: dict[str, int] = defaultdict(int)
186 | total: dict[str, int] = defaultdict(int)
187 |
188 | for data_class in data_classes:
189 | console.print(f"[green]Loading[/green] {data_class.name}")
190 | to_hash_args = to_save_args.copy()
191 | to_hash_args["dataset_name"] = data_class.name
192 | hash_experiment_key = hashlib.sha256(
193 | json.dumps(to_hash_args, sort_keys=True).encode("utf-8")
194 | ).hexdigest()
195 | process_dataset(
196 | prompt_formatter=prompt_formatter,
197 | splits=splits,
198 | bad_parses=bad_parses,
199 | total=total,
200 | text2sql_dataset=data_class,
201 | hash_experiment_key=hash_experiment_key,
202 | )
203 |
204 | if merge_dev:
205 | splits["train"].extend(splits["dev"])
206 | splits["dev"] = []
207 | if merge_test:
208 | splits["train"].extend(splits["test"])
209 | splits["test"] = []
210 |
211 | date = datetime.now().strftime("%Y-%m-%d")
212 | joined_output_dir = Path(output_dir) / date
213 | joined_output_dir.mkdir(parents=True, exist_ok=True)
214 |
215 | console.print(f"Starting length of train: {len(splits['train'])}")
216 |
217 | # Deduplicate training data
218 | unq_inps = set()
219 | new_train = []
220 | for ex in splits["train"]:
221 | if ex.instruction not in unq_inps:
222 | new_train.append(ex)
223 | unq_inps.add(ex.instruction)
224 | splits["train"] = new_train
225 |
226 | console.print(f"After dedup length of train: {len(splits['train'])}")
227 |
228 | # Get token size statistics
229 | tokenized_inputs = tokenizer(list(map(lambda x: x.instruction, splits["train"])))
230 | tokenized_outputs = tokenizer(list(map(lambda x: x.output, splits["train"])))
231 | input_lengths = [len(x) for x in tokenized_inputs["input_ids"]]
232 | output_lengths = [len(x) for x in tokenized_outputs["input_ids"]]
233 | sum_lengths = [x + y for x, y in zip(input_lengths, output_lengths)]
234 |
235 | console.print(
236 | f"Max input length: {max(input_lengths)}, "
237 | f"Median length: {np.median(input_lengths)}, "
238 | f"90th percentile: {np.percentile(input_lengths, 90)}"
239 | )
240 | console.print(
241 | f"Max output length: {max(output_lengths)}, "
242 | f"Median length: {np.median(output_lengths)}, "
243 | f"90th percentile: {np.percentile(output_lengths, 90)}"
244 | )
245 | console.print(
246 | f"Percent overflow: {100*sum(x > seq_length for x in sum_lengths)/len(sum_lengths):.2f}"
247 | )
248 | console.print(
249 | f"Max sum length: {max(sum_lengths)}, "
250 | f"Median length: {np.median(sum_lengths)}, "
251 | f"85th percentile: {np.percentile(sum_lengths, 85)}, "
252 | f"90th percentile: {np.percentile(sum_lengths, 90)}, "
253 | f"95th percentile: {np.percentile(sum_lengths, 95)}"
254 | )
255 |
256 | # Save the data
257 | random.seed(seed)
258 | random.shuffle(splits["train"])
259 |
260 | for split in splits:
261 | console.print(
262 | f"Found {bad_parses[split]} bad parses out of "
263 | f"{total[split]} ({100*bad_parses[split]/max(total[split], 1): .2f})."
264 | )
265 | console.print(
266 | f"Saving [green]{split} ({len(splits[split])}) "
267 | f"[/green] data to {joined_output_dir}/{split}.jsonl"
268 | )
269 | with open(joined_output_dir / f"{split}.jsonl", "w") as f:
270 | for formatted_ex in splits[split]:
271 | f.write(json.dumps(formatted_ex.dict(), ensure_ascii=False) + "\n")
272 | with open(f"{joined_output_dir}/config.json", "w") as f:
273 | json.dump(to_save_args, f, indent=4)
274 |
275 |
276 | if __name__ == "__main__":
277 | build()
278 |
--------------------------------------------------------------------------------
/data_prep/prompt_formatters.py:
--------------------------------------------------------------------------------
1 | """Rajkumar prompt formatter."""
2 |
3 | from abc import ABC
4 | from random import shuffle
5 |
6 | from schema import Table
7 |
8 |
9 | class RajkumarFormatter(ABC):
10 | """RajkumarFormatter class.
11 |
12 | From https://arxiv.org/pdf/2204.00498.pdf.
13 | """
14 |
15 | table_sep: str = "\n\n"
16 | shuffle_table_order: bool = True
17 | _cache: dict[tuple[str, str, str], list[str]] = {}
18 |
19 | @classmethod
20 | def format_table(cls, table: Table) -> str:
21 | """Get table format."""
22 | table_fmt = []
23 | for col in table.columns or []:
24 | # This is technically an incorrect type, but it should be a catchall word
25 | table_fmt.append(f" {col.name} {col.dtype or 'any'}")
26 | if table_fmt:
27 | all_cols = ",\n".join(table_fmt)
28 | create_tbl = f"CREATE TABLE {table.name} (\n{all_cols}\n)"
29 | else:
30 | create_tbl = f"CREATE TABLE {table.name}"
31 | return create_tbl
32 |
33 | @classmethod
34 | def format_all_tables(cls, tables: list[Table], instruction: str) -> list[str]:
35 | """Get all tables format."""
36 | table_texts = [cls.format_table(table) for table in tables]
37 | key = ("tables", instruction, str(tables))
38 | if key not in cls._cache:
39 | shuffle(table_texts)
40 | cls._cache[key] = table_texts
41 | else:
42 | table_texts = cls._cache[key]
43 | return table_texts
44 |
45 | @classmethod
46 | def format_prompt(
47 | cls,
48 | instruction: str,
49 | table_text: str,
50 | ) -> str:
51 | """Get prompt format."""
52 | return f"""{table_text}\n\n\n-- Using valid SQLite, answer the following questions for the tables provided above.\n\n-- {instruction}\n""" # noqa: E501
53 |
54 | @classmethod
55 | def format_model_output(cls, output_sql: str, prompt: str) -> str:
56 | """Format model output."""
57 | return output_sql
58 |
59 | @classmethod
60 | def format_gold_output(cls, output_sql: str) -> str:
61 | """Format gold output for demonstration."""
62 | return output_sql
63 |
--------------------------------------------------------------------------------
/data_prep/schema.py:
--------------------------------------------------------------------------------
1 | """Text2SQL schemas."""
2 |
3 | from pydantic import BaseModel
4 |
5 |
6 | class TableColumn(BaseModel):
7 | """Table column."""
8 |
9 | name: str
10 | dtype: str | None
11 |
12 |
13 | class ForeignKey(BaseModel):
14 | """Foreign key."""
15 |
16 | # Referenced column
17 | column: TableColumn
18 | # References table name
19 | references_name: str
20 | # References column
21 | references_column: TableColumn
22 |
23 |
24 | class Table(BaseModel):
25 | """Table."""
26 |
27 | name: str | None
28 | columns: list[TableColumn] | None
29 | pks: list[TableColumn] | None
30 | # FK from this table to another column in another table
31 | fks: list[ForeignKey] | None
32 | examples: list[dict] | None
33 | # Is the table a source or intermediate reference table
34 | is_reference_table: bool = False
35 |
--------------------------------------------------------------------------------
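A minimal sketch (table and question are invented) of how `prompt_formatters.py` and `schema.py` combine into the prompt text used for training examples:

from prompt_formatters import RajkumarFormatter
from schema import Table, TableColumn

table = Table(
    name="singer",
    columns=[TableColumn(name="id", dtype="int"), TableColumn(name="name", dtype="text")],
    pks=[],
    fks=[],
    examples=[],
)
question = "What are the names of all singers?"
table_text = RajkumarFormatter.table_sep.join(
    RajkumarFormatter.format_all_tables([table], question)
)
# Prints the CREATE TABLE text, the "Using valid SQLite" instruction comment,
# and finally "-- <question>" on its own line.
print(RajkumarFormatter.format_prompt(question, table_text))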
/data_prep/text2sql_data_config.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | - loader: kummerfeld
3 | name: academic
4 | train_data_file: data/text2sql-data/data/academic.json
5 | val_data_file: null
6 | test_data_file: null
7 | schema_file: data/text2sql-data/data/academic-schema.csv
8 |
9 | - loader: kummerfeld
10 | name: advising
11 | train_data_file: data/text2sql-data/data/advising.json
12 | val_data_file: null
13 | test_data_file: null
14 | schema_file: data/text2sql-data/data/advising-schema.csv
15 |
16 | - loader: kummerfeld
17 | name: atis
18 | train_data_file: data/text2sql-data/data/atis.json
19 | val_data_file: null
20 | test_data_file: null
21 | schema_file: data/text2sql-data/data/atis-schema.csv
22 |
23 | - loader: kummerfeld
24 | name: geography
25 | train_data_file: data/text2sql-data/data/geography.json
26 | val_data_file: null
27 | test_data_file: null
28 | schema_file: data/text2sql-data/data/geography-schema.csv
29 |
30 | - loader: kummerfeld
31 | name: imdb
32 | train_data_file: data/text2sql-data/data/imdb.json
33 | val_data_file: null
34 | test_data_file: null
35 | schema_file: data/text2sql-data/data/imdb-schema.csv
36 |
37 | - loader: kummerfeld
38 | name: restaurants
39 | train_data_file: data/text2sql-data/data/restaurants.json
40 | val_data_file: null
41 | test_data_file: null
42 | schema_file: data/text2sql-data/data/restaurants-schema.csv
43 |
44 | - loader: kummerfeld
45 | name: scholar
46 | train_data_file: data/text2sql-data/data/scholar.json
47 | val_data_file: null
48 | test_data_file: null
49 | schema_file: data/text2sql-data/data/scholar-schema.csv
50 |
51 | - loader: kummerfeld
52 | name: yelp
53 | train_data_file: data/text2sql-data/data/yelp.json
54 | val_data_file: null
55 | test_data_file: null
56 | schema_file: data/text2sql-data/data/yelp-schema.csv
57 |
58 | - loader: spider
59 | name: spider
60 | train_data_file: data/spider/train_spider.json
61 | val_data_file: data/spider/dev.json
62 | test_data_file: null
63 | schema_file: data/spider/tables.json
64 |
65 | - loader: sede
66 | name: sede
67 | train_data_file: data/sede/data/sede/train.jsonl
68 | val_data_file: data/sede/data/sede/val.jsonl
69 | test_data_file: data/sede/data/sede/test.jsonl
70 | schema_file: data/sede/stackexchange_schema/tables_so.json
71 |
72 | - loader: wikisql
73 | name: wikisql
74 | train_data_file: null
75 | val_data_file: null
76 | test_data_file: null
77 | schema_file: null
78 |
79 | - loader: spider
80 | name: eicu
81 | train_data_file: data/EHRSQL/dataset/ehrsql/eicu/train.json
82 | val_data_file: data/EHRSQL/dataset/ehrsql/eicu/valid.json
83 | test_data_file: null
84 | schema_file: data/EHRSQL/dataset/ehrsql/tables.json
85 |
86 | - loader: spider
87 | name: mimic_iii
88 | train_data_file: data/EHRSQL/dataset/ehrsql/mimic_iii/train.json
89 | val_data_file: data/EHRSQL/dataset/ehrsql/mimic_iii/valid.json
90 | test_data_file: null
91 | schema_file: data/EHRSQL/dataset/ehrsql/tables.json
92 |
93 | - loader: spider
94 | name: geonucleardata
95 | train_data_file: data/KaggleDBQA/examples/GeoNuclearData.json@data/KaggleDBQA/examples/GeoNuclearData_fewshot.json
96 | val_data_file: null
97 | test_data_file: data/KaggleDBQA/examples/GeoNuclearData_test.json
98 | schema_file: data/KaggleDBQA/KaggleDBQA_tables.json
99 |
100 | - loader: spider
101 | name: greatermanchestercrime
102 | train_data_file: data/KaggleDBQA/examples/GreaterManchesterCrime.json@data/KaggleDBQA/examples/GreaterManchesterCrime_fewshot.json
103 | val_data_file: null
104 | test_data_file: data/KaggleDBQA/examples/GreaterManchesterCrime_test.json
105 | schema_file: data/KaggleDBQA/KaggleDBQA_tables.json
106 |
107 | - loader: spider
108 | name: pesticide
109 | train_data_file: data/KaggleDBQA/examples/Pesticide.json@data/KaggleDBQA/examples/Pesticide_fewshot.json
110 | val_data_file: null
111 | test_data_file: data/KaggleDBQA/examples/Pesticide_test.json
112 | schema_file: data/KaggleDBQA/KaggleDBQA_tables.json
113 |
114 | - loader: spider
115 | name: studentmathscore
116 | train_data_file: data/KaggleDBQA/examples/StudentMathScore.json@data/KaggleDBQA/examples/StudentMathScore_fewshot.json
117 | val_data_file: null
118 | test_data_file: data/KaggleDBQA/examples/StudentMathScore_test.json
119 | schema_file: data/KaggleDBQA/KaggleDBQA_tables.json
120 |
121 | - loader: spider
122 | name: thehistoryofbaseball
123 | train_data_file: data/KaggleDBQA/examples/TheHistoryofBaseball.json@data/KaggleDBQA/examples/TheHistoryofBaseball_fewshot.json
124 | val_data_file: null
125 | test_data_file: data/KaggleDBQA/examples/TheHistoryofBaseball_test.json
126 | schema_file: data/KaggleDBQA/KaggleDBQA_tables.json
127 |
128 | - loader: spider
129 | name: uswildfires
130 | train_data_file: data/KaggleDBQA/examples/USWildFires.json@data/KaggleDBQA/examples/USWildFires_fewshot.json
131 | val_data_file: null
132 | test_data_file: data/KaggleDBQA/examples/USWildFires_test.json
133 | schema_file: data/KaggleDBQA/KaggleDBQA_tables.json
134 |
135 | - loader: spider
136 | name: whatcdhiphop
137 | train_data_file: data/KaggleDBQA/examples/WhatCDHipHop.json@data/KaggleDBQA/examples/WhatCDHipHop_fewshot.json
138 | val_data_file: null
139 | test_data_file: data/KaggleDBQA/examples/WhatCDHipHop_test.json
140 | schema_file: data/KaggleDBQA/KaggleDBQA_tables.json
141 | num_demos: 0
142 | num_copies: 1
143 |
144 | - loader: spider
145 | name: worldsoccerdatabase
146 | train_data_file: data/KaggleDBQA/examples/WorldSoccerDataBase.json@data/KaggleDBQA/examples/WorldSoccerDataBase_fewshot.json
147 | val_data_file: null
148 | test_data_file: data/KaggleDBQA/examples/WorldSoccerDataBase_test.json
149 | schema_file: data/KaggleDBQA/KaggleDBQA_tables.json
150 | num_demos: 0
151 | num_copies: 1
152 |
153 | - loader: mimicsql
154 | name: mimicsql_data
155 | train_data_file: data/TREQS/mimicsql_data/mimicsql_natural_v2/train.json@data/TREQS/mimicsql_data/mimicsql_template/train.json
156 | val_data_file: data/TREQS/mimicsql_data/mimicsql_natural_v2/dev.json@data/TREQS/mimicsql_data/mimicsql_template/dev.json
157 | test_data_file: data/TREQS/mimicsql_data/mimicsql_natural_v2/test.json@data/TREQS/mimicsql_data/mimicsql_template/test.json
158 | schema_file: data/TREQS/mimicsql_data/tables.json
159 |
160 | - loader: criteria2sql
161 | name: criteria2sql
162 | train_data_file: data/Criteria2SQL/data/train.jsonl
163 | val_data_file: data/Criteria2SQL/data/dev.jsonl
164 | test_data_file: data/Criteria2SQL/data/test.jsonl
165 | train_schema_file: data/Criteria2SQL/data/train.tables.jsonl
166 | val_schema_file: data/Criteria2SQL/data/dev.tables.jsonl
167 | test_schema_file: data/Criteria2SQL/data/test.tables.jsonl
168 | schema_file: null
169 |
170 | - loader: sql_create_context2sql
171 | name: sql_create_context
172 | train_data_file: null
173 | val_data_file: null
174 | test_data_file: null
175 | schema_file: null
176 |
177 | - loader: squall
178 | name: squall
179 | train_data_file: data/squall/data/squall.json
180 | val_data_file: null
181 | test_data_file: null
182 | schema_file: null
183 |
184 | - loader: css
185 | name: css
186 | train_data_file: example.train@template.train@schema.train
187 | val_data_file: example.dev@template.dev@schema.dev
188 | test_data_file: example.test@template.test@schema.test
189 | schema_file: data/css/tables.json
190 |
191 | - loader: nvbench
192 | name: nvbench
193 | train_data_file: data/nvBench/NVBench.json
194 | val_data_file: null
195 | test_data_file: null
196 | schema_file: data/nvBench/database/
197 |
--------------------------------------------------------------------------------
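Each entry above is selected by its `name` through the repeated `--datasets` option of `prep_text2sql_data.py`, and its `loader` key picks the class from `TEXT2SQL_DATA_LOADERS`. A minimal sketch (hypothetical values; it assumes the Spider files have been downloaded and the tokenizer can be fetched) of what `build()` does with a single entry:

import yaml

from text2sql_dataset import SpiderText2SQL

entry = yaml.safe_load(
    """
loader: spider
name: spider
train_data_file: data/spider/train_spider.json
val_data_file: data/spider/dev.json
test_data_file: null
schema_file: data/spider/tables.json
"""
)
loader = SpiderText2SQL(
    **entry, context_length=2048, tokenizer_name="Salesforce/codegen-2B-multi"
)
# schema = loader.load_schema()
# splits = loader.load_data(schema)  # {"train": [...], "dev": [...], "test": [...]}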
/data_prep/text2sql_dataset.py:
--------------------------------------------------------------------------------
1 | """Text2SQL dataset class."""
2 |
3 | import copy
4 | import json
5 | import os
6 | import sqlite3
7 | from abc import ABC, abstractmethod
8 | from functools import partial
9 | from glob import glob
10 | from pathlib import Path
11 | from typing import Any
12 |
13 | import jsonlines
14 | import sqlglot
15 | from data_utils import (
16 | clean_str,
17 | convert_criteria2sql_instance,
18 | convert_css_nvbench_instance,
19 | convert_kummerfeld_instance,
20 | convert_sede_instance,
21 | convert_spider_instance,
22 | convert_sql_create_context_instance,
23 | convert_squall_instance,
24 | convert_wikisql_instance,
25 | escape_everything,
26 | read_tables_json,
27 | serialize_dict_to_str,
28 | )
29 | from datasets import load_dataset
30 | from prompt_formatters import RajkumarFormatter
31 | from pydantic import BaseModel
32 | from rich.console import Console
33 | from schema import ForeignKey, Table, TableColumn
34 | from sqlglot import parse_one
35 | from tqdm.auto import tqdm
36 | from transformers import AutoTokenizer
37 |
38 | console = Console(soft_wrap=True)
39 |
40 |
41 | class Text2SQLData(BaseModel):
42 | """Text2SQL data class."""
43 |
44 | instruction: str
45 | output: str
46 | source: str
47 |
48 |
49 | class Text2SQLDataset(ABC):
50 | """Text2SQL dataset class."""
51 |
52 | def __init__(
53 | self,
54 | name: str,
55 | train_data_file: str,
56 | val_data_file: str,
57 | test_data_file: str,
58 | schema_file: str,
59 | context_length: int,
60 | tokenizer_name: str,
61 | **kwargs: Any,
62 | ) -> None:
63 | """Initialize."""
64 | self.name = name
65 | self.train_data_file = train_data_file
66 | self.val_data_file = val_data_file
67 | self.test_data_file = test_data_file
68 | self.schema_file = schema_file
69 | self.context_length = context_length
70 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
71 | self.clean_question = True
72 | self.process_init_kwargs(**kwargs)
73 |
74 | def process_init_kwargs(self, **kwargs: Any) -> None:
75 | """Process init kwargs."""
76 | pass
77 |
78 | @abstractmethod
79 | def load_data(
80 | self, schema: dict[str, dict[str, Table]]
81 | ) -> dict[str, list[dict[str, Any]]]:
82 | """Load data."""
83 | raise NotImplementedError
84 |
85 | @abstractmethod
86 | def load_schema(self) -> dict[str, dict[str, Table]]:
87 | """Load schema."""
88 | raise NotImplementedError
89 |
90 | def _is_parseable(self, sql: str) -> bool:
91 | try:
92 | res: sqlglot.expressions.Expression | None = parse_one(sql, read="sqlite")
93 | return res is not None
94 | except Exception:
95 | return False
96 |
97 | def _format_example(
98 | self,
99 | ex: dict[str, Any],
100 | schema: dict[str, dict[str, Table]],
101 | prompt_formatter: RajkumarFormatter,
102 | gold_sql_key: str,
103 | ) -> tuple[str, str] | None:
104 | if not self._is_parseable(ex[gold_sql_key]):
105 | print("BAD:::", ex[gold_sql_key])
106 | return None
107 |
108 | db_id = ex.get("db_id", "database")
109 | db_schema = schema[db_id]
110 | tables_to_add = list(db_schema.keys())
111 |
112 | if self.clean_question:
113 | question = clean_str(ex["question"]).strip("'").strip('"')
114 | else:
115 | question = ex["question"].strip("'").strip('"')
116 | table_text = prompt_formatter.table_sep.join(
117 | prompt_formatter.format_all_tables(
118 | [db_schema[t] for t in tables_to_add], question
119 | )
120 | )
121 |
122 | input_str = prompt_formatter.format_prompt(question, table_text)
123 | output_str = prompt_formatter.format_gold_output(ex[gold_sql_key])
124 | return input_str, output_str
125 |
126 | def format_example(
127 | self,
128 | example: dict[str, Any],
129 | schema: dict[str, dict[str, Table]],
130 | prompt_formatter: RajkumarFormatter,
131 | ) -> dict[str, Any] | None:
132 | """Format example."""
133 |
134 | result = self._format_example(
135 | example,
136 | schema,
137 | prompt_formatter,
138 | "sql",
139 | )
140 | if not result:
141 | return None
142 | input_str, output_str = result
143 | input_str = input_str.strip() + "\n"
144 | output_str = output_str.strip()
145 | data_ex = dict(
146 | instruction=input_str,
147 | output=output_str,
148 | source=self.name,
149 | )
150 | return data_ex
151 |
152 |
153 | class KummerfeldText2SQL(Text2SQLDataset):
154 | """Kummerfeld text2sql dataset from the text2sql-data repo."""
155 |
156 | def load_data(
157 | self, schema: dict[str, dict[str, Table]]
158 | ) -> dict[str, list[dict[str, Any]]]:
159 | """Load data."""
160 | data_pathobj = Path(self.train_data_file)
161 | splits: dict[str, list[dict[str, Any]]] = {"train": [], "dev": [], "test": []}
162 | for raw_ex in tqdm(json.load(data_pathobj.open()), desc="Loading data"):
163 | for ex in convert_kummerfeld_instance(raw_ex, schema=schema):
164 | if ex:
165 | splits[ex["split"]].append(ex)
166 | return splits
167 |
168 | def mine_for_fks(
169 | self, data: list[list[str]], header: list[str]
170 | ) -> dict[str, list[tuple[str, str, list[str]]]]:
171 | """Mine for fks from schema."""
172 | # The "Is Foreign Key" column is not always correct, so mine FKs via exact column-name match
173 | cur_tablename = None
174 | cur_database = None
175 | schema: dict = {}
176 | cur_table: dict[str, list] = {}
177 | for ex in data:
178 | if len(header) != len(ex):
179 | ex = ex[: len(header)]
180 | row = {h: r.strip() for h, r in zip(header, ex)}
181 | # Keep the type as only the first key
182 | # e.g. varchar(255) default null -> varchar
183 | table_name = row["Table Name"].lower()
184 | field_name = row["Field Name"].lower()
185 | database_name = row.get("Database name", "database").lower()
186 | if (
187 | table_name == "-"
188 | or table_name != cur_tablename
189 | or database_name != cur_database
190 | ):
191 | if len(cur_table) > 0:
192 | if cur_database not in schema:
193 | schema[cur_database] = {}
194 | schema[cur_database][cur_tablename] = cur_table
195 | cur_table = {}
196 | cur_database = None
197 | cur_tablename = None
198 | if cur_tablename is None and table_name != "-":
199 | cur_tablename = table_name
200 | cur_table = {
201 | "columns": [],
202 | "pks": [],
203 | }
204 | if cur_database is None and database_name != "-":
205 | cur_database = database_name
206 | if cur_tablename is not None:
207 | assert cur_database is not None
208 | cur_table["columns"].append(field_name)
209 | if row["Is Primary Key"].strip().lower() in [
210 | "yes",
211 | "true",
212 | "y",
213 | "t",
214 | "pri",
215 | ]:
216 | cur_table["pks"].append(field_name)
217 |
218 | # Add last table
219 | assert cur_database is not None
220 | assert cur_tablename is not None
221 | schema[cur_database][cur_tablename] = cur_table
222 |
223 | # Find Fks by matching on field_name
224 | fks: dict[str, list[tuple[str, str, list[str]]]] = {}
225 | for database in schema:
226 | fks[database] = []
227 | for referenced_table in schema[database]:
228 | # Only want one key per column
229 | used_columns = set()
230 | for references_table in schema[database]:
231 | if referenced_table == references_table:
232 | continue
233 | # Find all columns in referenced table that are in references table
234 | matching_cols = [
235 | c
236 | for c in schema[database][referenced_table]["columns"]
237 | if c in schema[database][references_table]["columns"]
238 | and c not in used_columns
239 | ]
240 | matching_pk_cols = [
241 | c
242 | for c in matching_cols
243 | if c in schema[database][references_table]["pks"]
244 | and c.lower() != "id"
245 | ]
246 | used_columns.update(matching_cols)
247 | if len(matching_pk_cols) > 0:
248 | # Use the fk
249 | fks[database].append(
250 | (referenced_table, references_table, matching_pk_cols)
251 | )
252 | return fks
253 |
254 | def load_schema(self) -> dict[str, dict[str, Table]]:
255 | """Load schema for each table in the database."""
256 | schema_pathobj = Path(self.schema_file)
257 | # Header is Table Name, Field Name, Is Primary Key, Is Foreign Key, Type
258 | data = [l.strip().split(",") for l in schema_pathobj.open().readlines()]
259 | header = [h.strip() for h in data[0]]
260 | data = data[1:]
261 |
262 | all_fks = self.mine_for_fks(data, header)
263 | schema: dict[str, dict[str, Table]] = {}
264 | cur_tablename = None
265 | cur_database = None
266 | cur_table: dict[str, list] = {}
267 | for ex in data:
268 | if len(header) != len(ex):
269 | ex = ex[: len(header)]
270 | row = {h: r.strip() for h, r in zip(header, ex)}
271 | # Keep the type as only the first key
272 | # e.g. varchar(255) default null -> varchar
273 | row_type = row["Type"].split("(")[0].lower()
274 | table_name = row["Table Name"].lower()
275 | field_name = row["Field Name"].lower()
276 | database_name = row.get("Database name", "database").lower()
277 | if (
278 | table_name == "-"
279 | or table_name != cur_tablename
280 | or database_name != cur_database
281 | ):
282 | if len(cur_table) > 0:
283 | assert cur_database is not None
284 | if cur_database not in schema:
285 | schema[cur_database] = {} # type: ignore
286 | schema[cur_database][cur_tablename] = Table( # type: ignore
287 | name=cur_tablename,
288 | columns=[
289 | TableColumn(name=cn, dtype=ct)
290 | for cn, ct in cur_table["columns"]
291 | ],
292 | pks=[
293 | TableColumn(name=cn, dtype=ct)
294 | for cn, ct in cur_table["pks"]
295 | ],
296 | fks=[
297 | ForeignKey(
298 | column=TableColumn(name=cn, dtype=ct),
299 | references_name=rtn,
300 | references_column=TableColumn(name=rn, dtype=rt),
301 | )
302 | for ((cn, ct), rtn, (rn, rt)) in cur_table["fks"]
303 | ],
304 | examples=[],
305 | )
306 | cur_table = {}
307 | cur_database = None
308 | cur_tablename = None
309 | if cur_tablename is None and table_name != "-":
310 | cur_tablename = table_name
311 | cur_table = {
312 | "columns": [],
313 | "pks": [],
314 | "fks": [],
315 | }
316 | if cur_database is None and database_name != "-":
317 | cur_database = database_name
318 | if cur_tablename is not None:
319 | assert cur_database is not None
320 | cur_table["columns"].append((field_name, row_type))
321 | if row["Is Primary Key"].strip().lower() in [
322 | "yes",
323 | "true",
324 | "y",
325 | "t",
326 | "pri",
327 | ]:
328 | cur_table["pks"].append((field_name, row_type))
329 | for fk_tuple in all_fks[cur_database]:
330 | # referenced_table, references_table, matching_pk_cols
331 | if fk_tuple[0] == cur_tablename and field_name in fk_tuple[2]:
332 | cur_table["fks"].append(
333 | (
334 | (field_name, row_type),
335 | fk_tuple[1],
336 | (field_name, row_type),
337 | )
338 | )
339 |
340 | # Add last table
341 | assert cur_database is not None
342 | assert cur_tablename is not None
343 | schema[cur_database][cur_tablename] = Table(
344 | name=cur_tablename,
345 | columns=[TableColumn(name=cn, dtype=ct) for cn, ct in cur_table["columns"]],
346 | pks=[TableColumn(name=cn, dtype=ct) for cn, ct in cur_table["pks"]],
347 | fks=[
348 | ForeignKey(
349 | column=TableColumn(name=cn, dtype=ct),
350 | references_name=rtn,
351 | references_column=TableColumn(name=rn, dtype=rt),
352 | )
353 | for ((cn, ct), rtn, (rn, rt)) in cur_table["fks"]
354 | ],
355 | examples=[],
356 | )
357 | return schema
358 |
359 |
360 | class SedeText2SQL(Text2SQLDataset):
361 | """SEDE text2sql dataset from https://github.com/hirupert/sede."""
362 |
363 | def load_data(
364 | self, schema: dict[str, dict[str, Table]]
365 | ) -> dict[str, list[dict[str, Any]]]:
366 | """Load data."""
367 | splits: dict[str, list[dict[str, Any]]] = {"train": [], "dev": [], "test": []}
368 | for split in splits:
369 | if split == "dev":
370 | to_read_split = self.val_data_file
371 | elif split == "test":
372 | to_read_split = self.test_data_file
373 | else:
374 | to_read_split = self.train_data_file
375 | data_file = Path(to_read_split)
376 | for line in tqdm(
377 | data_file.open().readlines(), desc=f"Loading {split} data"
378 | ):
379 | raw_ex = json.loads(line)
380 | ex = convert_sede_instance(raw_ex, schema=schema)
381 | if ex:
382 | splits[split].append(ex)
383 | return splits
384 |
385 | def load_schema(self) -> dict[str, dict[str, Table]]:
386 | """Load schema for each table in the database."""
387 | schema_dct = read_tables_json(self.schema_file)
388 | return schema_dct
389 |
390 |
391 | class SpiderText2SQL(Text2SQLDataset):
392 | """Spider text2sql dataset adapted from Huggingface/Picard."""
393 |
394 | def load_data(
395 | self, schema: dict[str, dict[str, Table]]
396 | ) -> dict[str, list[dict[str, Any]]]:
397 | """Load data."""
398 | splits: dict[str, list[dict[str, Any]]] = {"train": [], "dev": [], "test": []}
399 | all_data_for_demos: dict[str, list[dict[str, Any]]] = {
400 | "train": [],
401 | "dev": [],
402 | "test": [],
403 | }
404 | for split in splits:
405 | if split == "dev":
406 | to_read_files = [Path(self.val_data_file)] if self.val_data_file else []
407 | elif split == "train":
408 | to_read_files = [Path(p) for p in self.train_data_file.split("@")]
409 | elif split == "test":
410 | to_read_files = (
411 | [Path(self.test_data_file)] if self.test_data_file else []
412 | )
413 | else:
414 | to_read_files = []
415 | console.print(f"Loading {split} data", style="bold blue")
416 | for file in to_read_files:
417 | data_file = Path(file)
418 | try:
419 | data = json.load(data_file.open())
420 | except json.decoder.JSONDecodeError:
421 | data = [json.loads(line) for line in data_file.open()]
422 | convert_function = partial(
423 | convert_spider_instance,
424 | schema=schema,
425 | )
426 | for ex in tqdm(
427 | map(
428 | convert_function,
429 | data,
430 | ),
431 | desc=f"Loading {split} data from {data_file.name}",
432 | total=len(data),
433 | ):
434 | if ex:
435 | splits[split].append(ex)
436 | all_data_for_demos[split].append(copy.deepcopy(ex))
437 |
438 | return splits
439 |
440 | def load_schema(self) -> dict[str, dict[str, Table]]:
441 | """Load schema for each table in the database."""
442 | schema_dct = read_tables_json(self.schema_file, lowercase=True)
443 | return schema_dct
444 |
445 |
446 | class WikiSQL2SQL(Text2SQLDataset):
447 | """WikiSQL text2sql dataset loaded from Huggingface datasets."""
448 |
449 | example2table_name: dict[str, str] = {}
450 |
451 | def load_schema(self) -> dict[str, dict[str, Table]]:
452 | """Load schema for each table in the database."""
453 | self.dataset = load_dataset("wikisql")
454 | schema_dct: dict[str, dict[str, Table]] = {}
455 | for split in sorted(self.dataset):
456 | for ex in tqdm(self.dataset[split], desc=f"Loading {split} data schema"):
457 | table = ex["table"]
458 | key = serialize_dict_to_str(ex)
459 | if key not in self.example2table_name:
460 | self.example2table_name[
461 | key
462 | ] = f"table_{len(self.example2table_name)}"
463 | table_name = self.example2table_name[key]
464 | # Quote column names to handle spaces
465 | column_names = [
466 | f'"{escape_everything(col)}"' for col in table["header"]
467 | ]
468 | if table_name in schema_dct:
469 | continue
470 | columns = [
471 | TableColumn(name=col, dtype=typ)
472 | for col, typ in zip(column_names, table["types"])
473 | ]
474 | examples = [
475 | {column_names[i]: row[i] for i in range(len(row))}
476 | for row in table["rows"]
477 | ]
478 | if table_name not in schema_dct:
479 | schema_dct[table_name] = {}
480 | # WikiSQL uses table_name for both db and table name
481 | schema_dct[table_name][table_name] = Table(
482 | name=table_name, columns=columns, examples=examples
483 | )
484 | return schema_dct
485 |
486 | def load_data(
487 | self, schema: dict[str, dict[str, Table]]
488 | ) -> dict[str, list[dict[str, Any]]]:
489 | """Load data."""
490 | splits: dict[str, list[dict[str, Any]]] = {"train": [], "dev": [], "test": []}
491 | for split in splits:
492 | split_to_use = split
493 | if split == "dev":
494 | split_to_use = "validation"
495 | for line in tqdm(self.dataset[split_to_use], desc=f"Loading {split} data"):
496 | key = serialize_dict_to_str(line)
497 | line_to_use = line
498 | line_to_use["table"]["name"] = self.example2table_name[key]
499 | ex = convert_wikisql_instance(line_to_use, schema=schema)
500 | if ex:
501 | splits[split].append(ex)
502 | console.print(
503 | f"Loaded {split} data: {len(splits[split])} over {len(self.dataset[split_to_use])}"
504 | )
505 | return splits
506 |
507 | def _get_table_schema(self, table: Table) -> list[str]:
508 | """Get the sorted, lowercased column names of a Table."""
509 | return sorted([_.name.lower() for _ in table.columns]) if table.columns else []
510 |
511 |
512 | class MimicsqlText2SQL(Text2SQLDataset):
513 | """Mimicsql text2sql dataset from the TREQS repo."""
514 |
515 | def load_data(
516 | self, schema: dict[str, dict[str, Table]]
517 | ) -> dict[str, list[dict[str, Any]]]:
518 | """Load data."""
519 | splits: dict[str, list[dict[str, Any]]] = {"train": [], "dev": [], "test": []}
520 | data_file_mapping = {
521 | "train": self.train_data_file,
522 | "dev": self.val_data_file,
523 | "test": self.test_data_file,
524 | }
525 | for split in splits:
526 | to_read_files = (
527 | [Path(p) for p in data_file_mapping[split].split("@")]
528 | if split in data_file_mapping
529 | else []
530 | )
531 | console.print(f"Loading {split} data", style="bold blue")
532 |
533 | for file in to_read_files:
534 | data_file = Path(file)
535 | try:
536 | data = json.load(data_file.open())
537 | except json.decoder.JSONDecodeError:
538 | data = [json.loads(line) for line in data_file.open()]
539 | # Convert the data to spider compatible
540 | for i in range(len(data)):
541 | data[i]["db_id"] = "mimicsql"
542 | data[i]["question"] = data[i]["question_refine"]
543 | data[i]["query"] = data[i]["sql"]
544 | convert_function = partial(
545 | convert_spider_instance,
546 | schema=schema,
547 | )
548 | for ex in tqdm(
549 | map(
550 | convert_function,
551 | data,
552 | ),
553 | desc=f"Loading {split} data from {data_file.name}",
554 | total=len(data),
555 | ):
556 | if ex:
557 | splits[split].append(ex)
558 |
559 | return splits
560 |
561 | def load_schema(self) -> dict[str, dict[str, Table]]:
562 | """Load schema for each table in the database."""
563 | schema_dct = read_tables_json(self.schema_file)
564 | return schema_dct
565 |
566 |
567 | class Criteria2SQL2SQL(Text2SQLDataset):
568 | """Criteria2SQL text2sql dataset from https://github.com/xiaojingyu92/Criteria2SQL."""
569 |
570 | def process_init_kwargs(self, **kwargs: Any) -> None:
571 | """Process kwargs."""
572 | self.train_schema_file = kwargs.pop("train_schema_file", "")
573 | self.val_schema_file = kwargs.pop("val_schema_file", "")
574 | self.test_schema_file = kwargs.pop("test_schema_file", "")
575 |
576 | def load_schema(self) -> dict[str, dict[str, Table]]:
577 | """Load schema for each table in the database."""
578 | schema_dct: dict[str, dict[str, Table]] = {}
579 | schema_file_mapping = {
580 | "train": self.train_schema_file,
581 | "dev": self.val_schema_file,
582 | "test": self.test_schema_file,
583 | }
584 |
585 | for split in sorted(schema_file_mapping):
586 | with jsonlines.open(schema_file_mapping[split], "r") as f:
587 | for table in tqdm(
588 | [line for line in f], desc=f"Loading {split} data schema"
589 | ):
590 | table_name = f"table_{split}_{table['id']}"
591 | # Quote column names to handle spaces
592 | column_names = [
593 | f'"{escape_everything(col)}"' for col in table["header"]
594 | ]
595 | columns = [
596 | TableColumn(name=col, dtype=typ)
597 | for col, typ in zip(column_names, table["types"])
598 | ]
599 | examples = [
600 | {column_names[i]: row[i] for i in range(len(row))}
601 | for row in table["rows"]
602 | ]
603 | if table_name not in schema_dct:
604 | schema_dct[table_name] = {}
605 | # Criteria2SQL is similar to WikiSQL and it uses table_name for
606 | # both db and table name
607 | schema_dct[table_name][table_name] = Table(
608 | name=table_name, columns=columns, examples=examples
609 | )
610 | return schema_dct
611 |
612 | def load_data(
613 | self, schema: dict[str, dict[str, Table]]
614 | ) -> dict[str, list[dict[str, Any]]]:
615 | """Load data."""
616 | splits: dict[str, list[dict[str, Any]]] = {"train": [], "dev": [], "test": []}
617 | data_file_mapping = {
618 | "train": self.train_data_file,
619 | "dev": self.val_data_file,
620 | "test": self.test_data_file,
621 | }
622 |
623 | for split in sorted(data_file_mapping):
624 | with jsonlines.open(data_file_mapping[split], "r") as f:
625 | all_samples = [line for line in f]
626 | for line in tqdm(all_samples, desc=f"Loading {split} data"):
627 | line_to_use = line
628 | line_to_use["db_id"] = f"table_{split}_{line['table_id']}"
629 | ex = convert_criteria2sql_instance(line_to_use, schema=schema)
630 | if ex:
631 | splits[split].append(ex)
632 | console.print(
633 | f"Loaded {split} data: {len(splits[split])} over {len(all_samples)}"
634 | )
635 | return splits
636 |
637 |
638 | class SqlCreateContext2SQL(Text2SQLDataset):
639 | """sql-create-context text2sql dataset from huggingface."""
640 |
641 | def load_schema(self) -> dict[str, dict[str, Table]]:
642 | """Load schema for each table in the database."""
643 | schema_dct: dict[str, dict[str, Table]] = {}
644 | self.dataset = load_dataset("b-mc2/sql-create-context")
645 | for db_id, ex in tqdm(enumerate(self.dataset["train"])):
646 | for table_context in ex["context"].split(";"):
647 | table_context = table_context.strip()
648 | assert table_context.startswith("CREATE TABLE ")
649 | table_context = table_context[len("CREATE TABLE ") :].strip()
650 | table_name = table_context[: table_context.find("(")].strip()
651 | col_context = table_context[len(table_name) :].strip()[1:-1]
652 | cols = [col.strip().split(" ") for col in col_context.split(",")]
653 | columns = [TableColumn(name=col, dtype=typ) for col, typ in cols]
654 |
655 | if db_id not in schema_dct:
656 | schema_dct[db_id] = {}
657 | if table_name not in schema_dct[db_id]:
658 | schema_dct[db_id][table_name] = Table(
659 | name=table_name, columns=columns
660 | )
661 | return schema_dct
662 |
663 | def load_data(
664 | self, schema: dict[str, dict[str, Table]]
665 | ) -> dict[str, list[dict[str, Any]]]:
666 | """Load data."""
667 | splits: dict[str, list[dict[str, Any]]] = {"train": [], "dev": [], "test": []}
668 |
669 | for split in sorted(self.dataset):
670 | for db_id, ex in tqdm(
671 | enumerate(self.dataset[split]), desc=f"Loading {split} data"
672 | ):
673 | line_to_use = ex
674 | line_to_use["db_id"] = db_id
675 | ex = convert_sql_create_context_instance(line_to_use, schema=schema)
676 | if ex:
677 | splits[split].append(ex)
678 | console.print(
679 | f"Loaded {split} data: {len(splits[split])} over {len(self.dataset[split])}"
680 | )
681 | return splits
682 |
683 |
684 | class Squall2SQL(Text2SQLDataset):
685 | """Squall text2sql dataset loaded from a local squall.json file."""
686 |
687 | def load_schema(self) -> dict[str, dict[str, Table]]:
688 | """Load schema for each table in the database."""
689 | schema_dct: dict[str, dict[str, Table]] = {}
690 | self.data = json.load(open(self.train_data_file, "r"))
691 | for i, ex in enumerate(self.data):
692 | table_name = f"table_{ex['tbl']}"
693 | cols = [["id", "number"]] + [
694 | [
695 | f'"{escape_everything(col[0])}"',
696 | col[3] if col[3] == "number" else "text",
697 | ]
698 | for col in ex["columns"]
699 | ]
700 | # Skip tables with duplicate column names
701 | if len(set([col[0] for col in cols])) != len(cols):
702 | continue
703 | # Skip tables with an empty column name
704 | if '""' in [col[0] for col in cols]:
705 | continue
706 | columns = [TableColumn(name=col, dtype=typ) for col, typ in cols]
707 | if table_name not in schema_dct:
708 | schema_dct[table_name] = {}
709 | if table_name not in schema_dct[table_name]:
710 | schema_dct[table_name][table_name] = Table(
711 | name=table_name, columns=columns
712 | )
713 | return schema_dct
714 |
715 | def load_data(
716 | self, schema: dict[str, dict[str, Table]]
717 | ) -> dict[str, list[dict[str, Any]]]:
718 | """Load data."""
719 | splits: dict[str, list[dict[str, Any]]] = {"train": [], "dev": [], "test": []}
720 | split = "train"
721 | for i, ex in enumerate(self.data):
722 | line_to_use = ex
723 | line_to_use["db_id"] = f"table_{ex['tbl']}"
724 | if line_to_use["db_id"] not in schema:
725 | continue
726 | ex = convert_squall_instance(line_to_use, schema=schema)
727 | if ex:
728 | splits[split].append(ex)
729 | console.print(
730 | f"Loaded {split} data: {len(splits[split])} over {len(self.data)}"
731 | )
732 | return splits
733 |
734 |
735 | class CSS2SQL(Text2SQLDataset):
736 | """CSS2SQL text2sql dataset from huggingface."""
737 |
738 | def process_init_kwargs(self, **kwargs: Any) -> None:
739 | """Process kwargs."""
740 | self.clean_question = False
741 |
742 | def load_schema(self) -> dict[str, dict[str, Table]]:
743 | """Load schema for each table in the database."""
744 | schema_dct = read_tables_json(self.schema_file)
745 | return schema_dct
746 |
747 | def load_data(
748 | self, schema: dict[str, dict[str, Table]]
749 | ) -> dict[str, list[dict[str, Any]]]:
750 | """Load data."""
751 | splits: dict[str, list[dict[str, Any]]] = {"train": [], "dev": [], "test": []}
752 | ds = load_dataset("zhanghanchong/css")
753 | data_split_mapping = {
754 | "train": self.train_data_file,
755 | "dev": self.val_data_file,
756 | "test": self.test_data_file,
757 | }
758 | for split in splits:
759 | split_cnt = 0
760 | if data_split_mapping[split] is None:
761 | continue
762 | ex_cache: set = set([])
763 | to_read_splits = (
764 | [p for p in data_split_mapping[split].split("@")]
765 | if split in data_split_mapping
766 | else []
767 | )
768 | console.print(f"Loading {split} data", style="bold blue")
769 | for spt in to_read_splits:
770 | split_cnt += len(ds[spt])
771 | for line in ds[spt]:
772 | line["question"] = line["question"].strip()
773 | ex = convert_css_nvbench_instance(line, schema=schema)
774 | if ex:
775 | key = f"{ex['db_id']}###{ex['question']}"
776 | if key in ex_cache:
777 | continue
778 | ex_cache.add(key)
779 | splits[split].append(ex)
780 | console.print(f"Loaded {split} data: {len(splits[split])} over {split_cnt}")
781 | return splits
782 |
783 |
784 | class NVBENCH2SQL(Text2SQLDataset):
785 | """NVBENCH2SQL text2sql dataset from https://github.com/TsinghuaDatabaseGroup/nvBench."""
786 |
787 | def load_schema(self) -> dict[str, dict[str, Table]]:
788 | """Load schema for each table in the database."""
789 | schema_dct: dict[str, dict[str, Table]] = {}
790 |
791 | db_files = [
792 | file
793 | for file in glob(self.schema_file + "**", recursive=True)
794 | if file.endswith(".sqlite")
795 | ]
796 |
797 | for db_file in db_files:
798 | db_id = os.path.basename(os.path.dirname(db_file))
799 | if db_id not in schema_dct:
800 | schema_dct[db_id] = {}
801 |
802 | # Connect to the SQLite database
803 | conn = sqlite3.connect(db_file)
804 |
805 | # Create a cursor object to execute SQL queries
806 | cursor = conn.cursor()
807 |
808 | # Get the list of tables in the database
809 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
810 | tables = cursor.fetchall()
811 |
812 | # Iterate over the tables and retrieve their schemas
813 | for table in tables:
814 | table_name = table[0]
815 | # Execute a PRAGMA query to get the schema of the table
816 | cursor.execute(f"PRAGMA table_info({table_name});")
817 | schema = cursor.fetchall()
818 |
819 | # Get the schema details
820 | columns = [
821 | TableColumn(
822 | name=column[1] if " " not in column[1] else f'"{column[1]}"',
823 | dtype=column[2],
824 | )
825 | for column in schema
826 | ]
827 |
828 | if table_name not in schema_dct[db_id]:
829 | schema_dct[db_id][table_name] = Table(
830 | name=table_name, columns=columns
831 | )
832 |
833 | # Close the cursor and the database connection
834 | cursor.close()
835 | conn.close()
836 |
837 | return schema_dct
838 |
839 | def load_data(
840 | self, schema: dict[str, dict[str, Table]]
841 | ) -> dict[str, list[dict[str, Any]]]:
842 | """Load data."""
843 | splits: dict[str, list[dict[str, Any]]] = {"train": [], "dev": [], "test": []}
844 | data_file_mapping = {
845 | "train": self.train_data_file,
846 | "dev": self.val_data_file,
847 | "test": self.test_data_file,
848 | }
849 | for split in splits:
850 | split_cnt = 0
851 | if data_file_mapping[split] is None:
852 | continue
853 | ex_cache: set = set([])
854 | to_read_files = (
855 | [Path(p) for p in data_file_mapping[split].split("@")]
856 | if split in data_file_mapping
857 | else []
858 | )
859 | console.print(f"Loading {split} data", style="bold blue")
860 | for data_file in to_read_files:
861 | data = json.load(open(data_file, "r"))
862 | for k, v in data.items():
863 | for nl_query in v["nl_queries"]:
864 | split_cnt += 1
865 | if len(nl_query.strip()) == 0:
866 | continue
867 | line = {
868 | "db_id": v["db_id"],
869 | "query": v["vis_query"]["data_part"]["sql_part"].strip(),
870 | "question": nl_query.strip(),
871 | }
872 | ex = convert_css_nvbench_instance(line, schema=schema)
873 | if ex:
874 | key = f"{ex['db_id']}###{ex['question']}"
875 | if key in ex_cache:
876 | continue
877 | ex_cache.add(key)
878 | splits[split].append(ex)
879 | console.print(f"Loaded {split} data: {len(splits[split])} over {split_cnt}")
880 | return splits
881 |
--------------------------------------------------------------------------------
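As a quick reference, the schema-loading path in `NVBENCH2SQL.load_schema` above reduces to a small standalone pattern: list the tables in a `.sqlite` file via `sqlite_master`, then read each table's columns with `PRAGMA table_info`. A minimal sketch of that pattern follows; the `example.sqlite` path is purely illustrative.

```python
import sqlite3


def sqlite_schema(db_file: str) -> dict[str, list[tuple[str, str]]]:
    """Map each table in a SQLite file to its (column name, declared type) pairs."""
    conn = sqlite3.connect(db_file)
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    schema: dict[str, list[tuple[str, str]]] = {}
    for (table_name,) in cursor.fetchall():
        # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk).
        cursor.execute(f"PRAGMA table_info({table_name});")
        schema[table_name] = [(col[1], col[2]) for col in cursor.fetchall()]
    cursor.close()
    conn.close()
    return schema


# Hypothetical path; any local SQLite database file works.
print(sqlite_schema("example.sqlite"))
```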
/examples/colab.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "execution": {
8 | "iopub.execute_input": "2023-07-26T17:54:03.363360Z",
9 | "iopub.status.busy": "2023-07-26T17:54:03.363013Z",
10 | "iopub.status.idle": "2023-07-26T17:54:03.384950Z",
11 | "shell.execute_reply": "2023-07-26T17:54:03.384339Z",
12 | "shell.execute_reply.started": "2023-07-26T17:54:03.363334Z"
13 | },
14 | "tags": []
15 | },
16 | "outputs": [],
17 | "source": [
18 | "%load_ext autoreload\n",
19 | "%autoreload 2"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": 2,
25 | "metadata": {
26 | "execution": {
27 | "iopub.execute_input": "2023-07-26T17:54:03.554498Z",
28 | "iopub.status.busy": "2023-07-26T17:54:03.554003Z",
29 | "iopub.status.idle": "2023-07-26T17:54:03.565041Z",
30 | "shell.execute_reply": "2023-07-26T17:54:03.564332Z",
31 | "shell.execute_reply.started": "2023-07-26T17:54:03.554473Z"
32 | }
33 | },
34 | "outputs": [],
35 | "source": [
36 | "# Install transformer if you don't have it installed\n",
37 | "# ! pip install transformers"
38 | ]
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "metadata": {},
43 | "source": [
44 | "This is a standalone notebook that can be imported into colab to run."
45 | ]
46 | },
47 | {
48 | "cell_type": "markdown",
49 | "metadata": {},
50 | "source": [
51 | "## Model Setup"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": 3,
57 | "metadata": {
58 | "execution": {
59 | "iopub.execute_input": "2023-07-26T17:54:04.109420Z",
60 | "iopub.status.busy": "2023-07-26T17:54:04.108996Z",
61 | "iopub.status.idle": "2023-07-26T17:54:04.120388Z",
62 | "shell.execute_reply": "2023-07-26T17:54:04.119708Z",
63 | "shell.execute_reply.started": "2023-07-26T17:54:04.109396Z"
64 | },
65 | "tags": []
66 | },
67 | "outputs": [],
68 | "source": [
69 | "model_name = \"NumbersStation/nsql-350M\" # <-- You can switch to other models like \"NumbersStation/nsql-6B\""
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": 4,
75 | "metadata": {
76 | "execution": {
77 | "iopub.execute_input": "2023-07-26T17:54:04.286491Z",
78 | "iopub.status.busy": "2023-07-26T17:54:04.286167Z",
79 | "iopub.status.idle": "2023-07-26T17:54:09.554296Z",
80 | "shell.execute_reply": "2023-07-26T17:54:09.553527Z",
81 | "shell.execute_reply.started": "2023-07-26T17:54:04.286468Z"
82 | },
83 | "tags": []
84 | },
85 | "outputs": [],
86 | "source": [
87 | "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
88 | "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
89 | "model = AutoModelForCausalLM.from_pretrained(model_name)"
90 | ]
91 | },
92 | {
93 | "cell_type": "markdown",
94 | "metadata": {},
95 | "source": [
96 | "## Setup table schema\n",
97 | "\n",
98 | "This is a simple example of database table schema if you want to connect to your own PostgreSQL or SQlite please refer to [other notebooks](https://github.com/NumbersStationAI/NSQL/tree/main/examples)."
99 | ]
100 | },
101 | {
102 | "cell_type": "code",
103 | "execution_count": 5,
104 | "metadata": {
105 | "execution": {
106 | "iopub.execute_input": "2023-07-26T17:54:09.555794Z",
107 | "iopub.status.busy": "2023-07-26T17:54:09.555447Z",
108 | "iopub.status.idle": "2023-07-26T17:54:09.580203Z",
109 | "shell.execute_reply": "2023-07-26T17:54:09.579442Z",
110 | "shell.execute_reply.started": "2023-07-26T17:54:09.555775Z"
111 | },
112 | "tags": []
113 | },
114 | "outputs": [],
115 | "source": [
116 | "table_schema = \"\"\"CREATE TABLE stadium (\n",
117 | " stadium_id number,\n",
118 | " location text,\n",
119 | " name text,\n",
120 | " capacity number,\n",
121 | " highest number,\n",
122 | " lowest number,\n",
123 | " average number\n",
124 | ")\n",
125 | "\n",
126 | "CREATE TABLE singer (\n",
127 | " singer_id number,\n",
128 | " name text,\n",
129 | " country text,\n",
130 | " song_name text,\n",
131 | " song_release_year text,\n",
132 | " age number,\n",
133 | " is_male others\n",
134 | ")\n",
135 | "\n",
136 | "CREATE TABLE concert (\n",
137 | " concert_id number,\n",
138 | " concert_name text,\n",
139 | " theme text,\n",
140 | " stadium_id text,\n",
141 | " year text\n",
142 | ")\n",
143 | "\n",
144 | "CREATE TABLE singer_in_concert (\n",
145 | " concert_id number,\n",
146 | " singer_id text\n",
147 | ")\n",
148 | "\"\"\"\n",
149 | "\n",
150 | "question = \"What is the maximum, the average, and the minimum capacity of stadiums ?\"\n",
151 | "\n",
152 | "prompt = f\"\"\"{table_schema}\n",
153 | "\n",
154 | "-- Using valid SQLite, answer the following questions for the tables provided above.\n",
155 | "\n",
156 | "-- {question}\n",
157 | "\n",
158 | "SELECT\"\"\""
159 | ]
160 | },
161 | {
162 | "cell_type": "markdown",
163 | "metadata": {},
164 | "source": [
165 | "## Generate SQL"
166 | ]
167 | },
168 | {
169 | "cell_type": "code",
170 | "execution_count": 9,
171 | "metadata": {
172 | "execution": {
173 | "iopub.execute_input": "2023-07-26T17:55:04.586239Z",
174 | "iopub.status.busy": "2023-07-26T17:55:04.585800Z",
175 | "iopub.status.idle": "2023-07-26T17:55:07.043773Z",
176 | "shell.execute_reply": "2023-07-26T17:55:07.043013Z",
177 | "shell.execute_reply.started": "2023-07-26T17:55:04.586214Z"
178 | },
179 | "tags": []
180 | },
181 | "outputs": [
182 | {
183 | "name": "stdout",
184 | "output_type": "stream",
185 | "text": [
186 | "SELECT MAX(capacity), AVG(capacity), MIN(capacity) FROM stadium;\n"
187 | ]
188 | }
189 | ],
190 | "source": [
191 | "input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\n",
192 | "generated_ids = model.generate(input_ids, max_length=500)\n",
193 | "output = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
194 | "output = 'SELECT' + output.split('SELECT')[-1]\n",
195 | "\n",
196 | "print(output)"
197 | ]
198 | },
199 | {
200 | "cell_type": "code",
201 | "execution_count": null,
202 | "metadata": {},
203 | "outputs": [],
204 | "source": []
205 | }
206 | ],
207 | "metadata": {
208 | "kernelspec": {
209 | "display_name": "Python 3 (ipykernel)",
210 | "language": "python",
211 | "name": "python3"
212 | },
213 | "language_info": {
214 | "codemirror_mode": {
215 | "name": "ipython",
216 | "version": 3
217 | },
218 | "file_extension": ".py",
219 | "mimetype": "text/x-python",
220 | "name": "python",
221 | "nbconvert_exporter": "python",
222 | "pygments_lexer": "ipython3",
223 | "version": "3.10.8"
224 | }
225 | },
226 | "nbformat": 4,
227 | "nbformat_minor": 4
228 | }
229 |
--------------------------------------------------------------------------------
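The generation cell at the end of the Colab notebook keeps only the text from the final `SELECT` onward before printing. A hedged helper capturing that step, with an added truncation at the first semicolon (the truncation is an assumption, not part of the notebook), could look like this:

```python
def extract_sql(decoded_output: str) -> str:
    """Keep only the generated SQL: text from the last 'SELECT' onward,
    truncated at the first ';' if one is present (the truncation is an
    added assumption; the notebook itself only re-attaches 'SELECT')."""
    sql = "SELECT" + decoded_output.split("SELECT")[-1]
    if ";" in sql:
        sql = sql[: sql.index(";") + 1]
    return sql
```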
/examples/db_connectors.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 | from dataclasses import dataclass
3 | from functools import cached_property
4 | from typing import Any, Generator, List
5 | import pandas as pd
6 | import sqlalchemy
7 |
8 | from prompt_formatters import TableColumn, Table
9 |
10 |
11 | @dataclass
12 | class PostgresConnector:
13 | """Postgres connection."""
14 |
15 | user: str
16 | password: str
17 | dbname: str
18 | host: str
19 | port: int
20 |
21 | @cached_property
22 | def pg_uri(self) -> str:
23 | """Get Postgres URI."""
24 | uri = (
25 | f"postgresql://"
26 | f"{self.user}:{self.password}@{self.host}:{self.port}/{self.dbname}"
27 | )
28 | # ensure we can actually connect to this postgres uri
29 | engine = sqlalchemy.create_engine(uri)
30 | conn = engine.connect()
31 |
32 | # assuming the above connection is successful, we can now close the connection
33 | conn.close()
34 | engine.dispose()
35 |
36 | return uri
37 |
38 | @contextmanager
39 | def connect(self) -> Generator[sqlalchemy.engine.base.Connection, None, None]:
40 | """Yield a connection to a Postgres db.
41 |
42 | Example:
43 | .. code-block:: python
44 | postgres = PostgresConnector(
45 | user=USER, password=PASSWORD, dbname=DBNAME, host=HOST, port=PORT
46 | )
47 | with postgres.connect() as conn:
48 | conn.execute(sql)
49 | """
50 |         engine = sqlalchemy.create_engine(self.pg_uri)
51 |         conn = engine.connect()
52 |         try:
53 |             yield conn
54 |         finally:
55 |             conn.close()
56 |             engine.dispose()
57 |
58 | def run_sql_as_df(self, sql: str) -> pd.DataFrame:
59 | """Run SQL statement."""
60 | with self.connect() as conn:
61 | return pd.read_sql(sql, conn)
62 |
63 | def get_tables(self) -> List[str]:
64 | """Get all tables in the database."""
65 | engine = sqlalchemy.create_engine(self.pg_uri)
66 | table_names = engine.table_names()
67 | engine.dispose()
68 | return table_names
69 |
70 | def get_schema(self, table: str) -> Table:
71 | """Return Table."""
72 | with self.connect() as conn:
73 | columns = []
74 | sql = f"""
75 | SELECT column_name, data_type
76 | FROM information_schema.columns
77 | WHERE table_name = '{table}';
78 | """
79 | schema = conn.execute(sql).fetchall()
80 | for col, type_ in schema:
81 | columns.append(TableColumn(name=col, dtype=type_))
82 | return Table(name=table, columns=columns)
83 |
84 |
85 | @dataclass
86 | class SQLiteConnector:
87 | """SQLite connection."""
88 |
89 | database_path: str
90 |
91 | @cached_property
92 | def sqlite_uri(self) -> str:
93 | """Get SQLite URI."""
94 | uri = f"sqlite:///{self.database_path}"
95 | # ensure we can actually connect to this SQLite uri
96 | engine = sqlalchemy.create_engine(uri)
97 | conn = engine.connect()
98 |
99 | # assuming the above connection is successful, we can now close the connection
100 | conn.close()
101 | engine.dispose()
102 |
103 | return uri
104 |
105 | @contextmanager
106 | def connect(self) -> Generator[sqlalchemy.engine.base.Connection, None, None]:
107 | """Yield a connection to a SQLite database.
108 |
109 | Example:
110 | .. code-block:: python
111 | sqlite = SQLiteConnector(database_path=DB_PATH)
112 | with sqlite.connect() as conn:
113 | conn.execute(sql)
114 | """
115 |         engine = sqlalchemy.create_engine(self.sqlite_uri)
116 |         conn = engine.connect()
117 |         try:
118 |             yield conn
119 |         finally:
120 |             conn.close()
121 |             engine.dispose()
122 |
123 | def get_tables(self) -> List[str]:
124 | """Get all tables in the database."""
125 | engine = sqlalchemy.create_engine(self.sqlite_uri)
126 | table_names = engine.table_names()
127 | engine.dispose()
128 | return table_names
129 |
130 | def run_sql_as_df(self, sql: str) -> pd.DataFrame:
131 | """Run SQL statement."""
132 | with self.connect() as conn:
133 | return pd.read_sql(sql, conn)
134 |
135 | def get_schema(self, table: str) -> Table:
136 | """Return Table."""
137 | with self.connect() as conn:
138 | columns = []
139 | sql = f"PRAGMA table_info({table});"
140 | schema = conn.execute(sql).fetchall()
141 | for row in schema:
142 | col = row[1]
143 | type_ = row[2]
144 | columns.append(TableColumn(name=col, dtype=type_))
145 | return Table(name=table, columns=columns)
146 |
--------------------------------------------------------------------------------
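Note that `Engine.table_names()`, used in the `get_tables` methods above, was deprecated in SQLAlchemy 1.4 and removed in 2.0; the `sqlalchemy<2.0.0` pin in `requirements.txt` keeps it working. If you want a forward-compatible variant, an inspector-based sketch (not part of the repo) is:

```python
import sqlalchemy


def list_tables(uri: str) -> list[str]:
    """List table names via the SQLAlchemy inspector (works on both 1.4 and 2.0)."""
    engine = sqlalchemy.create_engine(uri)
    try:
        return sqlalchemy.inspect(engine).get_table_names()
    finally:
        engine.dispose()
```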
/examples/finetune.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "3652e112-d916-48bf-a49d-20ecfa01ed52",
7 | "metadata": {
8 | "tags": []
9 | },
10 | "outputs": [],
11 | "source": [
12 | "%load_ext autoreload\n",
13 | "%autoreload 2"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": null,
19 | "id": "9178d12e-1058-4a12-b7ad-3f003361775b",
20 | "metadata": {
21 | "tags": []
22 | },
23 | "outputs": [],
24 | "source": [
25 | "# Install transformer and peft if you don't have it installed\n",
26 | "# ! pip install transformers==4.31.0\n",
27 | "# ! pip install peft\n",
28 | "# ! pip install accelerate\n",
29 | "# ! pip install bitsandbytes"
30 | ]
31 | },
32 | {
33 | "cell_type": "markdown",
34 | "id": "e5a3c3e2-0a1d-429d-94d2-3f81838a010c",
35 | "metadata": {},
36 | "source": [
37 | "This is a standalone notebook to train the NSQL model on a single GPU (e.g., A5000 with 24GB) with int8 and LoRA."
38 | ]
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "id": "b77a70b6-a0ac-4f18-a0c8-354839d149ce",
43 | "metadata": {
44 | "execution": {
45 | "iopub.execute_input": "2023-08-02T17:16:32.100885Z",
46 | "iopub.status.busy": "2023-08-02T17:16:32.100451Z",
47 | "iopub.status.idle": "2023-08-02T17:16:32.111648Z",
48 | "shell.execute_reply": "2023-08-02T17:16:32.111052Z",
49 | "shell.execute_reply.started": "2023-08-02T17:16:32.100860Z"
50 | },
51 | "tags": []
52 | },
53 | "source": [
54 | "# Load the model"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": null,
60 | "id": "ed5d1984-7b44-43de-af56-30c7943140aa",
61 | "metadata": {
62 | "tags": []
63 | },
64 | "outputs": [],
65 | "source": [
66 | "import torch\n",
67 | "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
68 | "tokenizer = AutoTokenizer.from_pretrained(\"NumbersStation/nsql-llama-2-7B\")\n",
69 | "model = AutoModelForCausalLM.from_pretrained(\"NumbersStation/nsql-llama-2-7B\", load_in_8bit=True, torch_dtype=torch.bfloat16, device_map='auto')"
70 | ]
71 | },
72 | {
73 | "cell_type": "markdown",
74 | "id": "939db736-3d70-4f2b-807c-efc029fe8ab7",
75 | "metadata": {},
76 | "source": [
77 | "# Prepare the data\n",
78 | "\n",
79 | "We use NumbersStation/NSText2SQL dataset as an example here and feel free to customize the training data based on your need."
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": null,
85 | "id": "5b38f8db-7910-4b9d-9061-53fd11ede153",
86 | "metadata": {
87 | "tags": []
88 | },
89 | "outputs": [],
90 | "source": [
91 | "from datasets import load_dataset\n",
92 | "from torch.utils.data import Dataset\n",
93 | "import copy\n",
94 | "\n",
95 | "class NSText2SQLDataset(Dataset):\n",
96 | " def __init__(self, size=None, max_seq_length=2048):\n",
97 | " self.dataset = load_dataset(\"NumbersStation/NSText2SQL\",split=\"train\")\n",
98 | " if size:\n",
99 | " self.dataset = self.dataset.select(range(size))\n",
100 | " self.max_seq_length = max_seq_length\n",
101 | "\n",
102 | " def __len__(self):\n",
103 | " return len(self.dataset)\n",
104 | "\n",
105 | " def __getitem__(self, index):\n",
106 | " instruction = torch.tensor(tokenizer.encode(self.dataset[index]['instruction']), dtype=torch.int64)\n",
107 | " example = self.dataset[index]['instruction'] + self.dataset[index][\"output\"]\n",
108 | " example = tokenizer.encode(example)\n",
109 | " example.append(tokenizer.eos_token_id)\n",
110 | " padding = self.max_seq_length - len(example)\n",
111 | " example = torch.tensor(example, dtype=torch.int64)\n",
112 | "\n",
113 | " if padding < 0:\n",
114 | " example = example[:self.max_seq_length]\n",
115 | " else:\n",
116 | " example = torch.cat((example, torch.zeros(padding, dtype=torch.int64)))\n",
117 | " \n",
118 | " labels = copy.deepcopy(example)\n",
119 | " labels[: len(instruction)] = -100\n",
120 | " \n",
121 | " return {\"input_ids\": example, \"labels\": labels}"
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": null,
127 | "id": "d669cd4e-1b53-4810-aa5a-8194916b01c0",
128 | "metadata": {
129 | "tags": []
130 | },
131 | "outputs": [],
132 | "source": [
133 | "dataset = NSText2SQLDataset(size=1000, max_seq_length=1024)"
134 | ]
135 | },
136 | {
137 | "cell_type": "markdown",
138 | "id": "78aa7f7a-1fb5-417e-bf92-e07f8c172a30",
139 | "metadata": {},
140 | "source": [
141 | "# Prepare PEFT for model training"
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": null,
147 | "id": "1627483c-2739-4cc9-8d85-e7353b4c61b5",
148 | "metadata": {
149 | "tags": []
150 | },
151 | "outputs": [],
152 | "source": [
153 | "from peft import (\n",
154 | " get_peft_model,\n",
155 | " LoraConfig,\n",
156 | " TaskType,\n",
157 | " prepare_model_for_int8_training,\n",
158 | ")\n",
159 | "\n",
160 | "\n",
161 | "model.train()\n",
162 | "\n",
163 | "model = prepare_model_for_int8_training(model)\n",
164 | "\n",
165 | "lora_config = LoraConfig(\n",
166 | " task_type=TaskType.CAUSAL_LM,\n",
167 | " inference_mode=False,\n",
168 | " r=8,\n",
169 | " lora_alpha=32,\n",
170 | " lora_dropout=0.05,\n",
171 | " target_modules = [\"q_proj\", \"v_proj\"]\n",
172 | ")\n",
173 | "\n",
174 | "model = get_peft_model(model, lora_config)"
175 | ]
176 | },
177 | {
178 | "cell_type": "markdown",
179 | "id": "b9a8b18d-a25a-4d12-ac00-0f720703f579",
180 | "metadata": {
181 | "execution": {
182 | "iopub.execute_input": "2023-08-02T17:50:38.809539Z",
183 | "iopub.status.busy": "2023-08-02T17:50:38.809187Z",
184 | "iopub.status.idle": "2023-08-02T17:50:38.836443Z",
185 | "shell.execute_reply": "2023-08-02T17:50:38.835678Z",
186 | "shell.execute_reply.started": "2023-08-02T17:50:38.809520Z"
187 | },
188 | "tags": []
189 | },
190 | "source": [
191 | "# Finetune the model with Huggingface trainer"
192 | ]
193 | },
194 | {
195 | "cell_type": "code",
196 | "execution_count": null,
197 | "id": "997db2e9-9b2e-4169-9b18-695d403f70e4",
198 | "metadata": {
199 | "tags": []
200 | },
201 | "outputs": [],
202 | "source": [
203 | "from transformers import default_data_collator, Trainer, TrainingArguments\n",
204 | "\n",
205 | "output_dir = \"training_run\"\n",
206 | "\n",
207 | "config = {\n",
208 | " 'lora_config': lora_config,\n",
209 | " 'learning_rate': 1e-4,\n",
210 | " 'num_train_epochs': 1,\n",
211 | " 'gradient_accumulation_steps': 2,\n",
212 | " 'gradient_checkpointing': False,\n",
213 | "}\n",
214 | "\n",
215 | "\n",
216 | "training_args = TrainingArguments(\n",
217 | " output_dir=output_dir,\n",
218 | " overwrite_output_dir=True,\n",
219 | " bf16=True,\n",
220 | " # logging strategies\n",
221 | " logging_dir=f\"{output_dir}/logs\",\n",
222 | " logging_strategy=\"steps\",\n",
223 | " logging_steps=5,\n",
224 | " optim=\"adamw_torch_fused\",\n",
225 | " **{k:v for k,v in config.items() if k != 'lora_config'}\n",
226 | ")"
227 | ]
228 | },
229 | {
230 | "cell_type": "code",
231 | "execution_count": null,
232 | "id": "007525f0-41ac-4de7-aba8-92e547931972",
233 | "metadata": {
234 | "tags": []
235 | },
236 | "outputs": [],
237 | "source": [
238 | "trainer = Trainer(\n",
239 | " model=model,\n",
240 | " args=training_args,\n",
241 | " train_dataset=dataset,\n",
242 | " data_collator=default_data_collator,\n",
243 | ")\n",
244 | "\n",
245 | "# Start training\n",
246 | "trainer.train()"
247 | ]
248 | },
249 | {
250 | "cell_type": "markdown",
251 | "id": "680eb420-2c91-407c-b19c-89ea0391b591",
252 | "metadata": {},
253 | "source": [
254 | "# Save model checkpoint"
255 | ]
256 | },
257 | {
258 | "cell_type": "code",
259 | "execution_count": null,
260 | "id": "fc1fb82e-b158-40d7-92bf-d63446dc12ae",
261 | "metadata": {
262 | "tags": []
263 | },
264 | "outputs": [],
265 | "source": [
266 | "model.save_pretrained(output_dir)"
267 | ]
268 | },
269 | {
270 | "cell_type": "markdown",
271 | "id": "1eaeb1d5-1b49-4873-a205-b3ba94af4214",
272 | "metadata": {},
273 | "source": [
274 | "# Evaluate the finetuned model"
275 | ]
276 | },
277 | {
278 | "cell_type": "code",
279 | "execution_count": null,
280 | "id": "75e98d82-07de-48fa-8c40-fb59e7805b08",
281 | "metadata": {
282 | "tags": []
283 | },
284 | "outputs": [],
285 | "source": [
286 | "model.eval()"
287 | ]
288 | },
289 | {
290 | "cell_type": "code",
291 | "execution_count": null,
292 | "id": "fad30645-9472-4d29-94f8-94e20fdde79a",
293 | "metadata": {
294 | "tags": []
295 | },
296 | "outputs": [],
297 | "source": [
298 | "text = \"\"\"CREATE TABLE stadium (\n",
299 | " stadium_id number,\n",
300 | " location text,\n",
301 | " name text,\n",
302 | " capacity number,\n",
303 | ")\n",
304 | "\n",
305 | "-- Using valid SQLite, answer the following questions for the tables provided above.\n",
306 | "\n",
307 | "-- how many stadiums in total?\n",
308 | "\n",
309 | "SELECT\"\"\"\n",
310 | "\n",
311 | "model_input = tokenizer(text, return_tensors=\"pt\").to(\"cuda\")\n",
312 | "\n",
313 | "generated_ids = model.generate(**model_input, max_new_tokens=100)\n",
314 | "print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))"
315 | ]
316 | },
317 | {
318 | "cell_type": "code",
319 | "execution_count": null,
320 | "id": "532632db-3579-48dd-931b-b6aa4669ce87",
321 | "metadata": {},
322 | "outputs": [],
323 | "source": []
324 | }
325 | ],
326 | "metadata": {
327 | "kernelspec": {
328 | "display_name": "Python 3 (ipykernel)",
329 | "language": "python",
330 | "name": "python3"
331 | },
332 | "language_info": {
333 | "codemirror_mode": {
334 | "name": "ipython",
335 | "version": 3
336 | },
337 | "file_extension": ".py",
338 | "mimetype": "text/x-python",
339 | "name": "python",
340 | "nbconvert_exporter": "python",
341 | "pygments_lexer": "ipython3",
342 | "version": "3.10.8"
343 | }
344 | },
345 | "nbformat": 4,
346 | "nbformat_minor": 5
347 | }
348 |
--------------------------------------------------------------------------------
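After `model.save_pretrained(output_dir)` in the finetuning notebook, only the LoRA adapter weights are written to `training_run`. To reuse them later, one typically reloads the base model and attaches the adapter with `peft.PeftModel`; below is a minimal sketch under that assumption (exact dtype and device flags may differ in your setup):

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base = "NumbersStation/nsql-llama-2-7B"
adapter_dir = "training_run"  # output_dir from the finetuning notebook

tokenizer = AutoTokenizer.from_pretrained(base)
model = AutoModelForCausalLM.from_pretrained(
    base, torch_dtype=torch.bfloat16, device_map="auto"
)
# Attach the saved LoRA adapter on top of the base weights.
model = PeftModel.from_pretrained(model, adapter_dir)
model.eval()
```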
/examples/postgres.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2"
11 | ]
12 | },
13 | {
14 | "attachments": {},
15 | "cell_type": "markdown",
16 | "metadata": {},
17 | "source": [
18 | "## DB Setup\n",
19 | "\n",
20 | "We assume you already have a postgres database ready."
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": null,
26 | "metadata": {},
27 | "outputs": [],
28 | "source": [
29 | "DATABASE = \"database\"\n",
30 | "USER = \"postgres\"\n",
31 | "PASSWORD = \"password\"\n",
32 | "HOST = \"localhost\"\n",
33 | "PORT = 5432\n",
34 | "TABLES = [] # list of tables to load or [] to load all tables"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": null,
40 | "metadata": {},
41 | "outputs": [],
42 | "source": [
43 | "from db_connectors import PostgresConnector\n",
44 | "from prompt_formatters import RajkumarFormatter\n",
45 | "\n",
46 | "# Get the connector and formatter\n",
47 | "postgres_connector = PostgresConnector(\n",
48 | " user=USER, password=PASSWORD, dbname=DATABASE, host=HOST, port=PORT\n",
49 | ")\n",
50 | "postgres_connector.connect()\n",
51 | "if len(TABLES) <= 0:\n",
52 | " TABLES.extend(postgres_connector.get_tables())\n",
53 | "\n",
54 | "print(f\"Loading tables: {TABLES}\")\n",
55 | "\n",
56 | "db_schema = [postgres_connector.get_schema(table) for table in TABLES]\n",
57 | "formatter = RajkumarFormatter(db_schema)"
58 | ]
59 | },
60 | {
61 | "attachments": {},
62 | "cell_type": "markdown",
63 | "metadata": {},
64 | "source": [
65 | "## Model Setup\n",
66 | "\n",
67 | "In a separate screen or window, first install [Manifest](https://github.com/HazyResearch/manifest)\n",
68 | "```bash\n",
69 | "pip install manifest-ml\\[all\\]\n",
70 | "```\n",
71 | "\n",
72 | "Then run\n",
73 | "```bash\n",
74 | "python3 -m manifest.api.app \\\n",
75 | " --model_type huggingface \\\n",
76 | " --model_generation_type text-generation \\\n",
77 | " --model_name_or_path NumbersStation/nsql-350M \\\n",
78 | " --device 0\n",
79 | "```\n",
80 | "\n",
81 | "If successful, you will see an output like\n",
82 | "```bash\n",
83 | "* Running on http://127.0.0.1:5000\n",
84 | "```"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": null,
90 | "metadata": {},
91 | "outputs": [],
92 | "source": [
93 | "from manifest import Manifest\n",
94 | "\n",
95 | "manifest_client = Manifest(client_name=\"huggingface\", client_connection=\"http://127.0.0.1:5000\")\n",
96 | "\n",
97 | "def get_sql(instruction: str, max_tokens: int = 300) -> str:\n",
98 | " prompt = formatter.format_prompt(instruction)\n",
99 | " res = manifest_client.run(prompt, max_tokens=max_tokens)\n",
100 | " return formatter.format_model_output(res)"
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": null,
106 | "metadata": {},
107 | "outputs": [],
108 | "source": [
109 | "sql = get_sql(\"Number of rows in table?\")\n",
110 | "print(sql)"
111 | ]
112 | },
113 | {
114 | "cell_type": "code",
115 | "execution_count": null,
116 | "metadata": {},
117 | "outputs": [],
118 | "source": [
119 | "print(postgres_connector.run_sql_as_df(sql))"
120 | ]
121 | }
122 | ],
123 | "metadata": {
124 | "kernelspec": {
125 | "display_name": "dbt",
126 | "language": "python",
127 | "name": "python3"
128 | },
129 | "language_info": {
130 | "codemirror_mode": {
131 | "name": "ipython",
132 | "version": 3
133 | },
134 | "file_extension": ".py",
135 | "mimetype": "text/x-python",
136 | "name": "python",
137 | "nbconvert_exporter": "python",
138 | "pygments_lexer": "ipython3",
139 | "version": "3.10.0"
140 | },
141 | "orig_nbformat": 4
142 | },
143 | "nbformat": 4,
144 | "nbformat_minor": 2
145 | }
146 |
--------------------------------------------------------------------------------
/examples/prompt_formatters.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 |
3 |
4 | class TableColumn(BaseModel):
5 | """Table column."""
6 |
7 | name: str
8 | dtype: str | None
9 |
10 |
11 | class ForeignKey(BaseModel):
12 | """Foreign key."""
13 |
14 | # Referenced column
15 | column: TableColumn
16 | # References table name
17 | references_name: str
18 | # References column
19 | references_column: TableColumn
20 |
21 |
22 | class Table(BaseModel):
23 | """Table."""
24 |
25 | name: str
26 | columns: list[TableColumn] | None
27 | pks: list[TableColumn] | None
28 | # FK from this table to another column in another table
29 | fks: list[ForeignKey] | None
30 |
31 |
32 | class RajkumarFormatter:
33 | """RajkumarFormatter class.
34 |
35 | From https://arxiv.org/pdf/2204.00498.pdf.
36 | """
37 |
38 | table_sep: str = "\n\n"
39 |
40 | def __init__(self, tables: list[Table]) -> None:
41 | self.tables = tables
42 | self.table_str = self.format_tables(tables)
43 |
44 | def format_table(self, table: Table) -> str:
45 | """Get table format."""
46 | table_fmt = []
47 | table_name = table.name
48 | for col in table.columns or []:
49 | # This is technically an incorrect type, but it should be a catchall word
50 | table_fmt.append(f" {col.name} {col.dtype or 'any'}")
51 | if table.pks:
52 | table_fmt.append(
53 | f" primary key ({', '.join(pk.name for pk in table.pks)})"
54 | )
55 | for fk in table.fks or []:
56 | table_fmt.append(
57 | f" foreign key ({fk.column.name}) references {fk.references_name}({fk.references_column.name})" # noqa: E501
58 | )
59 | if table_fmt:
60 | all_cols = ",\n".join(table_fmt)
61 | create_tbl = f"CREATE TABLE {table_name} (\n{all_cols}\n)"
62 | else:
63 | create_tbl = f"CREATE TABLE {table_name}"
64 | return create_tbl
65 |
66 | def format_tables(self, tables: list[Table]) -> str:
67 | """Get tables format."""
68 | return self.table_sep.join(self.format_table(table) for table in tables)
69 |
70 | def format_prompt(
71 | self,
72 | instruction: str,
73 | ) -> str:
74 | """Get prompt format."""
75 | sql_prefix = "SELECT"
76 | return f"""{self.table_str}\n\n\n-- Using valid SQLite, answer the following questions for the tables provided above.\n\n-- {instruction}\n{sql_prefix}""" # noqa: E501
77 |
78 | def format_model_output(self, output_sql: str) -> str:
79 | """Format model output.
80 |
81 | Our prompt ends with SELECT so we need to add it back.
82 | """
83 | if not output_sql.lower().startswith("select"):
84 | output_sql = "SELECT " + output_sql.strip()
85 | return output_sql
86 |
--------------------------------------------------------------------------------
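For a quick end-to-end check of `RajkumarFormatter` without a live database, you can hand-build a schema with `Table` and `TableColumn`. The table and question below are illustrative only, and the import assumes you run from the `examples/` directory:

```python
from prompt_formatters import RajkumarFormatter, Table, TableColumn

# Illustrative one-table schema; a schema pulled from a connector works the same way.
stadium = Table(
    name="stadium",
    columns=[
        TableColumn(name="stadium_id", dtype="number"),
        TableColumn(name="capacity", dtype="number"),
    ],
    pks=None,
    fks=None,
)

formatter = RajkumarFormatter([stadium])
prompt = formatter.format_prompt("What is the maximum capacity of stadiums?")
print(prompt)  # ends with "SELECT", ready to be completed by the model
```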
/examples/sqlite.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2"
11 | ]
12 | },
13 | {
14 | "attachments": {},
15 | "cell_type": "markdown",
16 | "metadata": {},
17 | "source": [
18 | "## DB Setup\n",
19 | "\n",
20 | "We assume you already have a sqlite database ready."
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": null,
26 | "metadata": {},
27 | "outputs": [],
28 | "source": [
29 | "DATABASE = \"testDB.db\"\n",
30 | "TABLES = [] # list of tables to load or [] to load all tables"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": null,
36 | "metadata": {},
37 | "outputs": [],
38 | "source": [
39 | "from db_connectors import SQLiteConnector\n",
40 | "from prompt_formatters import RajkumarFormatter\n",
41 | "\n",
42 | "# Get the connector and formatter\n",
43 | "sqlite_connector = SQLiteConnector(\n",
44 | " database_path=DATABASE\n",
45 | ")\n",
46 | "sqlite_connector.connect()\n",
47 | "if len(TABLES) <= 0:\n",
48 | " TABLES.extend(sqlite_connector.get_tables())\n",
49 | "\n",
50 | "print(f\"Loading tables: {TABLES}\")\n",
51 | "\n",
52 | "db_schema = [sqlite_connector.get_schema(table) for table in TABLES]\n",
53 | "formatter = RajkumarFormatter(db_schema)"
54 | ]
55 | },
56 | {
57 | "attachments": {},
58 | "cell_type": "markdown",
59 | "metadata": {},
60 | "source": [
61 | "## Model Setup\n",
62 | "\n",
63 | "In a separate screen or window, first install [Manifest](https://github.com/HazyResearch/manifest)\n",
64 | "```bash\n",
65 | "pip install manifest-ml\\[all\\]\n",
66 | "```\n",
67 | "\n",
68 | "Then run\n",
69 | "```bash\n",
70 | "python3 -m manifest.api.app \\\n",
71 | " --model_type huggingface \\\n",
72 | " --model_generation_type text-generation \\\n",
73 | " --model_name_or_path NumbersStation/nsql-350M \\\n",
74 | " --device 0\n",
75 | "```\n",
76 | "\n",
77 | "If successful, you will see an output like\n",
78 | "```bash\n",
79 | "* Running on http://127.0.0.1:5000\n",
80 | "```"
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "execution_count": null,
86 | "metadata": {},
87 | "outputs": [],
88 | "source": [
89 | "from manifest import Manifest\n",
90 | "\n",
91 | "manifest_client = Manifest(client_name=\"huggingface\", client_connection=\"http://127.0.0.1:5000\")\n",
92 | "\n",
93 | "def get_sql(instruction: str, max_tokens: int = 300) -> str:\n",
94 | " prompt = formatter.format_prompt(instruction)\n",
95 | " res = manifest_client.run(prompt, max_tokens=max_tokens)\n",
96 | " return formatter.format_model_output(res)"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": null,
102 | "metadata": {},
103 | "outputs": [],
104 | "source": [
105 | "sql = get_sql(\"Number of rows in table?\")\n",
106 | "print(sql)"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": null,
112 | "metadata": {},
113 | "outputs": [],
114 | "source": [
115 | "print(sqlite_connector.run_sql_as_df(sql))"
116 | ]
117 | }
118 | ],
119 | "metadata": {
120 | "kernelspec": {
121 | "display_name": "dbt",
122 | "language": "python",
123 | "name": "python3"
124 | },
125 | "language_info": {
126 | "codemirror_mode": {
127 | "name": "ipython",
128 | "version": 3
129 | },
130 | "file_extension": ".py",
131 | "mimetype": "text/x-python",
132 | "name": "python",
133 | "nbconvert_exporter": "python",
134 | "pygments_lexer": "ipython3",
135 | "version": "3.10.0"
136 | },
137 | "orig_nbformat": 4
138 | },
139 | "nbformat": 4,
140 | "nbformat_minor": 2
141 | }
142 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | manifest-ml[all]==0.1.8
2 | pandas>=2.0.0
3 | sqlalchemy<2.0.0
4 | transformers>=4.29.0
5 | datasets==2.11.0
6 | jsonlines>=3.1.0
7 | sqlglot==11.5.5
8 |
--------------------------------------------------------------------------------