├── LICENSE ├── README.md ├── data ├── annotated_qdmr │ ├── imdb_qdmr.csv │ └── yelp_qdmr.csv ├── sql_synthesis_results │ ├── gold_qdmr_supervision │ │ ├── academic.csv │ │ ├── geo880.csv │ │ ├── gold_qdmr_synthesized_full.zip │ │ ├── imdb.csv │ │ ├── spider_dev.csv │ │ ├── spider_train.csv │ │ └── yelp.csv │ ├── grounding_statistics.txt │ ├── predicted_qdmr_supervision │ │ ├── geo880_dev_test_pred.csv │ │ ├── geo880_dev_test_pred.json │ │ ├── spider_dev_pred.csv │ │ └── spider_dev_pred.json │ └── sql_synthesis_input_example.csv └── text_to_sql │ ├── encoded_qdmr_datasets.zip │ └── gold_sql_datasets.zip ├── requirements.txt ├── requirements_qdmr_parser.txt └── src ├── data_generation ├── db_schema.py ├── graph_utils.py ├── ground_example.py ├── grounded_qdmr.py ├── grounding_repairs.py ├── main.py ├── operator_identifier.py ├── predicted_sql.py ├── preprocess_db.py ├── preprocess_grounding_data.py ├── qdmr_editor.py ├── qdmr_encoding.py ├── qdmr_encoding_parser.py ├── qdmr_grounding.py ├── qdmr_identifier.py ├── schema_parser.py ├── sql_execution.py ├── sql_parser.py ├── sql_query.py ├── test_encoding_conversion.py ├── test_grounded_qdmr.py ├── test_predicted_grounded_qdmr.py ├── utils.py ├── write_encoding.py └── write_grounding.py ├── qdmr_parser ├── dataset_qdmr.py ├── eval_qdmr │ ├── eval_string_match.py │ └── sari_hook.py ├── model.py ├── test.py ├── train.py └── utils_data.py └── text_to_sql ├── dataset_qdmr.py ├── dataset_spider.py ├── dataset_utils.py ├── eval_exec ├── db_schema.py ├── graph_utils.py ├── grounded_qdmr.py ├── operator_identifier.py ├── predicted_sql.py ├── preprocess_db.py ├── qdmr_encoding.py ├── qdmr_encoding_parser.py ├── qdmr_grounding.py ├── qdmr_identifier.py ├── qdmr_sql.py ├── schema_parser.py ├── sql_execution.py ├── sql_parser.py └── utils.py ├── eval_spider.py ├── evaluation.py ├── model.py ├── process_sql.py ├── test.py ├── test_exec_eval.py └── train.py /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 tomerwolgithub 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /data/sql_synthesis_results/gold_qdmr_supervision/gold_qdmr_synthesized_full.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomerwolgithub/question-decomposition-to-sql/955b2d221fde73548ddb29239e41cf1080d17777/data/sql_synthesis_results/gold_qdmr_supervision/gold_qdmr_synthesized_full.zip -------------------------------------------------------------------------------- /data/sql_synthesis_results/grounding_statistics.txt: -------------------------------------------------------------------------------- 1 | Academic all examples: 195 2 | Academic all examples correctly grounded: 155 3 | 4 | Geo880 all examples: 877 5 | Geo880 all examples correctly grounded: 736 6 | 7 | IMDB all examples: 132 8 | IMDB all examples correctly grounded: 116 9 | 10 | Yelp all examples: 128 11 | Yelp all examples correctly grounded: 100 12 | 13 | Spider train all examples: 6955 14 | Spider train all examples correctly grounded: 5375 15 | 16 | Spider dev all examples: 1027 17 | Spider dev non-empty examples correctly grounded: 793 18 | 19 | Total all examples: 9313 20 | Total all examples correctly grounded: 7249 21 | 22 | ---------- 23 | 24 | Academic non-empty examples: 183 25 | Academic non-empty examples correctly grounded: 148 26 | 27 | Geo880 non-empty examples: 846 28 | Geo880 non-empty examples correctly grounded: 707 29 | 30 | IMDB non-empty examples: 113 31 | IMDB non-empty examples correctly grounded: 101 32 | 33 | Yelp non-empty examples: 66 34 | Yelp non-empty examples correctly grounded: 54 35 | 36 | Spider train non-empty examples: 6701 37 | Spider train non-empty examples correctly grounded: 5137 38 | 39 | Spider dev non-empty examples: 978 40 | Spider dev non-empty examples correctly grounded: 745 41 | 42 | Total non-empty examples: 8903 43 | Total non-empty examples correctly grounded: 6892 44 | 45 | ---------- 46 | 
47 | Geo880 Pred. non-empty examples: 288 48 | Geo880 Pred. non-empty examples correctly grounded: 277 49 | 50 | Spider dev Pred. non-empty examples: 978 51 | Spider dev Pred. non-empty examples correctly grounded: 750 -------------------------------------------------------------------------------- /data/sql_synthesis_results/sql_synthesis_input_example.csv: -------------------------------------------------------------------------------- 1 | ,index,db_id,question,query,qdmr 2 | 0,ACADEMIC_train_0,academic,return me the homepage of PVLDB .,"select journal_0.homepage from journal as journal_0 where journal_0.name = ""PVLDB""",return homepages ;return #1 of PVLDB 3 | 1,ACADEMIC_train_1,academic,"return me the homepage of "" H. V. Jagadish "" .","select author_0.homepage from author as author_0 where author_0.name = ""H. V. Jagadish""",return homepages ;return #1 of H. V. Jagadish 4 | 2,ACADEMIC_train_10,academic,"return me the number of references of "" Making database systems usable "" .","select publication_0.reference_num from publication as publication_0 where publication_0.title = ""Making database systems usable""",return references ;return #1 of Making database systems usable ;return number of #2 5 | -------------------------------------------------------------------------------- /data/text_to_sql/encoded_qdmr_datasets.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomerwolgithub/question-decomposition-to-sql/955b2d221fde73548ddb29239e41cf1080d17777/data/text_to_sql/encoded_qdmr_datasets.zip -------------------------------------------------------------------------------- /data/text_to_sql/gold_sql_datasets.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomerwolgithub/question-decomposition-to-sql/955b2d221fde73548ddb29239e41cf1080d17777/data/text_to_sql/gold_sql_datasets.zip 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.10.0 2 | alabaster==0.7.12 3 | albumentations==0.1.12 4 | altair==4.1.0 5 | appdirs==1.4.4 6 | argon2-cffi==20.1.0 7 | asgiref==3.3.1 8 | astor==0.8.1 9 | astropy==4.2 10 | astunparse==1.6.3 11 | async-generator==1.10 12 | atari-py==0.2.6 13 | atomicwrites==1.4.0 14 | attrs==20.3.0 15 | audioread==2.1.9 16 | autograd==1.3 17 | Babel==2.9.0 18 | backcall==0.2.0 19 | beautifulsoup4==4.6.3 20 | bleach==3.3.0 21 | blis==0.4.1 22 | bokeh==2.1.1 23 | Bottleneck==1.3.2 24 | branca==0.4.2 25 | bs4==0.0.1 26 | CacheControl==0.12.6 27 | cachetools==4.2.1 28 | catalogue==1.0.0 29 | certifi==2020.12.5 30 | cffi==1.14.5 31 | chainer==7.4.0 32 | chardet==3.0.4 33 | click==7.1.2 34 | cloudpickle==1.3.0 35 | cmake==3.12.0 36 | cmdstanpy==0.9.5 37 | colorlover==0.3.0 38 | community==1.0.0b1 39 | contextlib2==0.5.5 40 | convertdate==2.3.1 41 | coverage==3.7.1 42 | coveralls==0.5 43 | crcmod==1.7 44 | cufflinks==0.17.3 45 | cupy-cuda101==7.4.0 46 | cvxopt==1.2.6 47 | cvxpy==1.0.31 48 | cycler==0.10.0 49 | cymem==2.0.5 50 | Cython==0.29.22 51 | daft==0.0.4 52 | dask==2.12.0 53 | datascience==0.10.6 54 | debugpy==1.0.0 55 | decorator==4.4.2 56 | defusedxml==0.7.1 57 | descartes==1.1.0 58 | dill==0.3.3 59 | distributed==1.25.3 60 | Django==3.1.7 61 | dlib==19.18.0 62 | dm-tree==0.1.5 63 | docopt==0.6.2 64 | docutils==0.16 65 | dopamine-rl==1.0.5 66 | earthengine-api==0.1.255 67 | easydict==1.9 68 | ecos==2.0.7.post1 69 | editdistance==0.5.3 70 | en-core-web-sm==2.2.5 71 | entrypoints==0.3 72 | ephem==3.7.7.1 73 | et-xmlfile==1.0.1 74 | fa2==0.3.5 75 | fancyimpute==0.4.3 76 | fastai==1.0.61 77 | fastdtw==0.3.4 78 | fastprogress==1.0.0 79 | fastrlock==0.5 80 | fbprophet==0.7.1 81 | feather-format==0.4.1 82 | filelock==3.0.12 83 | firebase-admin==4.4.0 84 | fix-yahoo-finance==0.0.22 85 | 
Flask==1.1.2 86 | flatbuffers==1.12 87 | folium==0.8.3 88 | future==0.18.2 89 | gast==0.3.3 90 | GDAL==2.2.2 91 | gdown==3.6.4 92 | gensim==3.6.0 93 | geographiclib==1.50 94 | geopy==1.17.0 95 | gin-config==0.4.0 96 | glob2==0.7 97 | google==2.0.3 98 | google-api-core==1.26.1 99 | google-api-python-client==1.12.8 100 | google-auth==1.27.1 101 | google-auth-httplib2==0.0.4 102 | google-auth-oauthlib==0.4.3 103 | google-cloud-bigquery==1.21.0 104 | google-cloud-bigquery-storage==1.1.0 105 | google-cloud-core==1.0.3 106 | google-cloud-datastore==1.8.0 107 | google-cloud-firestore==1.7.0 108 | google-cloud-language==1.2.0 109 | google-cloud-storage==1.18.1 110 | google-cloud-translate==1.5.0 111 | google-colab==1.0.0 112 | google-pasta==0.2.0 113 | google-resumable-media==0.4.1 114 | googleapis-common-protos==1.53.0 115 | googledrivedownloader==0.4 116 | graphviz==0.10.1 117 | grpcio==1.32.0 118 | gspread==3.0.1 119 | gspread-dataframe==3.0.8 120 | gym==0.17.3 121 | h5py==2.10.0 122 | HeapDict==1.0.1 123 | hijri-converter==2.1.1 124 | holidays==0.10.5.2 125 | holoviews==1.13.5 126 | html5lib==1.0.1 127 | httpimport==0.5.18 128 | httplib2==0.17.4 129 | httplib2shim==0.0.3 130 | humanize==0.5.1 131 | hyperopt==0.1.2 132 | ideep4py==2.0.0.post3 133 | idna==2.10 134 | image==1.5.33 135 | imageio==2.4.1 136 | imagesize==1.2.0 137 | imbalanced-learn==0.4.3 138 | imblearn==0.0 139 | imgaug==0.2.9 140 | importlib-metadata==3.7.2 141 | importlib-resources==5.1.2 142 | imutils==0.5.4 143 | inflect==2.1.0 144 | iniconfig==1.1.1 145 | intel-openmp==2021.1.2 146 | intervaltree==2.1.0 147 | ipykernel==4.10.1 148 | ipython==5.5.0 149 | ipython-genutils==0.2.0 150 | ipython-sql==0.3.9 151 | ipywidgets==7.6.3 152 | itsdangerous==1.1.0 153 | jax==0.2.10 154 | jaxlib==0.1.62+cuda110 155 | jdcal==1.4.1 156 | jedi==0.18.0 157 | jieba==0.42.1 158 | Jinja2==2.11.3 159 | joblib==1.0.1 160 | jpeg4py==0.1.4 161 | jsonschema==2.6.0 162 | jupyter==1.0.0 163 | jupyter-client==5.3.5 164 | 
jupyter-console==5.2.0 165 | jupyter-core==4.7.1 166 | jupyterlab-pygments==0.1.2 167 | jupyterlab-widgets==1.0.0 168 | kaggle==1.5.10 169 | kapre==0.1.3.1 170 | Keras==2.4.3 171 | Keras-Preprocessing==1.1.2 172 | keras-vis==0.4.1 173 | kiwisolver==1.3.1 174 | knnimpute==0.1.0 175 | korean-lunar-calendar==0.2.1 176 | librosa==0.8.0 177 | lightgbm==2.2.3 178 | llvmlite==0.34.0 179 | lmdb==0.99 180 | lucid==0.3.8 181 | LunarCalendar==0.0.9 182 | lxml==4.2.6 183 | Markdown==3.3.4 184 | MarkupSafe==1.1.1 185 | matplotlib==3.2.2 186 | matplotlib-venn==0.11.6 187 | missingno==0.4.2 188 | mistune==0.8.4 189 | mizani==0.6.0 190 | mkl==2019.0 191 | mlxtend==0.14.0 192 | more-itertools==8.7.0 193 | moviepy==0.2.3.5 194 | mpmath==1.2.1 195 | msgpack==1.0.2 196 | multiprocess==0.70.11.1 197 | multitasking==0.0.9 198 | murmurhash==1.0.5 199 | music21==5.5.0 200 | natsort==5.5.0 201 | nbclient==0.5.3 202 | nbconvert==5.6.1 203 | nbformat==5.1.2 204 | nest-asyncio==1.5.1 205 | networkx==2.5 206 | nibabel==3.0.2 207 | nltk==3.2.5 208 | notebook==5.3.1 209 | np-utils==0.5.12.1 210 | numba==0.51.2 211 | numexpr==2.7.3 212 | numpy==1.19.5 213 | nvidia-ml-py3==7.352.0 214 | oauth2client==4.1.3 215 | oauthlib==3.1.0 216 | okgrade==0.4.3 217 | opencv-contrib-python==4.1.2.30 218 | opencv-python==4.1.2.30 219 | openpyxl==2.5.9 220 | opt-einsum==3.3.0 221 | osqp==0.6.2.post0 222 | packaging==20.9 223 | palettable==3.3.0 224 | pandas==1.1.5 225 | pandas-datareader==0.9.0 226 | pandas-gbq==0.13.3 227 | pandas-profiling==1.4.1 228 | pandocfilters==1.4.3 229 | panel==0.9.7 230 | param==1.10.1 231 | parso==0.8.1 232 | pathlib==1.0.1 233 | patsy==0.5.1 234 | pexpect==4.8.0 235 | pickleshare==0.7.5 236 | Pillow==7.0.0 237 | pip-tools==4.5.1 238 | plac==1.1.3 239 | plotly==4.4.1 240 | plotnine==0.6.0 241 | pluggy==0.7.1 242 | pooch==1.3.0 243 | portpicker==1.3.1 244 | prefetch-generator==1.0.1 245 | preshed==3.0.5 246 | prettytable==2.1.0 247 | progressbar2==3.38.0 248 | prometheus-client==0.9.0 
249 | promise==2.3 250 | prompt-toolkit==1.0.18 251 | protobuf==3.12.4 252 | psutil==5.4.8 253 | psycopg2==2.7.6.1 254 | ptyprocess==0.7.0 255 | py==1.10.0 256 | pyarrow==3.0.0 257 | pyasn1==0.4.8 258 | pyasn1-modules==0.2.8 259 | pycocotools==2.0.2 260 | pycparser==2.20 261 | pyct==0.4.8 262 | pydata-google-auth==1.1.0 263 | pydot==1.3.0 264 | pydot-ng==2.0.0 265 | pydotplus==2.0.2 266 | PyDrive==1.3.1 267 | pyemd==0.5.1 268 | pyerfa==1.7.2 269 | pyglet==1.5.0 270 | Pygments==2.6.1 271 | pygobject==3.26.1 272 | pymc3==3.7 273 | PyMeeus==0.5.9 274 | pymongo==3.11.3 275 | pymystem3==0.2.0 276 | pynndescent==0.5.2 277 | PyOpenGL==3.1.5 278 | pyparsing==2.4.7 279 | pyrsistent==0.17.3 280 | pysndfile==1.3.8 281 | PySocks==1.7.1 282 | pystan==2.19.1.1 283 | pytest==3.6.4 284 | python-apt==0.0.0 285 | python-chess==0.23.11 286 | python-dateutil==2.8.1 287 | python-louvain==0.15 288 | python-slugify==4.0.1 289 | python-utils==2.5.6 290 | pytorch-lightning==0.7.5 291 | pytz==2018.9 292 | pyviz-comms==2.0.1 293 | PyWavelets==1.1.1 294 | PyYAML==3.13 295 | pyzmq==22.0.3 296 | qdldl==0.1.5.post0 297 | qtconsole==5.0.3 298 | QtPy==1.9.0 299 | regex==2019.12.20 300 | requests==2.23.0 301 | requests-oauthlib==1.3.0 302 | resampy==0.2.2 303 | retrying==1.3.3 304 | rpy2==3.4.2 305 | rsa==4.7.2 306 | sacremoses==0.0.43 307 | scikit-image==0.16.2 308 | scikit-learn==0.22.2.post1 309 | scipy==1.4.1 310 | screen-resolution-extra==0.0.0 311 | scs==2.1.2 312 | seaborn==0.11.1 313 | Send2Trash==1.5.0 314 | sentencepiece==0.1.95 315 | setuptools-git==1.2 316 | Shapely==1.7.1 317 | simplegeneric==0.8.1 318 | six==1.15.0 319 | sklearn==0.0 320 | sklearn-pandas==1.8.0 321 | smart-open==4.2.0 322 | snowballstemmer==2.1.0 323 | sortedcontainers==2.3.0 324 | SoundFile==0.10.3.post1 325 | spacy==2.2.4 326 | Sphinx==1.8.5 327 | sphinxcontrib-serializinghtml==1.1.4 328 | sphinxcontrib-websupport==1.2.4 329 | SQLAlchemy==1.3.23 330 | sqlparse==0.4.1 331 | srsly==1.0.5 332 | statsmodels==0.10.2 333 
| sympy==1.7.1 334 | tables==3.4.4 335 | tabulate==0.8.9 336 | tblib==1.7.0 337 | tensorboard==2.4.1 338 | tensorboard-plugin-wit==1.8.0 339 | tensorflow==2.4.1 340 | tensorflow-datasets==4.0.1 341 | tensorflow-estimator==2.4.0 342 | tensorflow-gcs-config==2.4.0 343 | tensorflow-hub==0.11.0 344 | tensorflow-metadata==0.28.0 345 | tensorflow-probability==0.12.1 346 | termcolor==1.1.0 347 | terminado==0.9.2 348 | testpath==0.4.4 349 | text-unidecode==1.3 350 | textblob==0.15.3 351 | textgenrnn==1.4.1 352 | Theano==1.0.5 353 | thinc==7.4.0 354 | tifffile==2021.3.17 355 | tokenizers==0.7.0 356 | toml==0.10.2 357 | toolz==0.11.1 358 | torch==1.8.0+cu101 359 | torchsummary==1.5.1 360 | torchtext==0.9.0 361 | torchvision==0.9.0+cu101 362 | tornado==5.1.1 363 | tqdm==4.41.1 364 | traitlets==5.0.5 365 | transformers==2.9.0 366 | tweepy==3.10.0 367 | typeguard==2.7.1 368 | typing-extensions==3.7.4.3 369 | tzlocal==1.5.1 370 | umap-learn==0.5.1 371 | uritemplate==3.0.1 372 | urllib3==1.24.3 373 | vega-datasets==0.9.0 374 | wasabi==0.8.2 375 | wcwidth==0.2.5 376 | webencodings==0.5.1 377 | Werkzeug==1.0.1 378 | widgetsnbextension==3.5.1 379 | wordcloud==1.5.0 380 | wrapt==1.12.1 381 | xarray==0.15.1 382 | xgboost==0.90 383 | xkit==0.0.0 384 | xlrd==1.1.0 385 | xlwt==1.3.0 386 | yellowbrick==0.9.1 387 | zict==2.0.0 388 | zipp==3.4.1 389 | wordninja -------------------------------------------------------------------------------- /requirements_qdmr_parser.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.10.0 2 | alabaster==0.7.12 3 | albumentations==0.1.12 4 | altair==4.1.0 5 | appdirs==1.4.4 6 | argon2-cffi==20.1.0 7 | asgiref==3.3.1 8 | astor==0.8.1 9 | astropy==4.2 10 | astunparse==1.6.3 11 | async-generator==1.10 12 | atari-py==0.2.6 13 | atomicwrites==1.4.0 14 | attrs==20.3.0 15 | audioread==2.1.9 16 | autograd==1.3 17 | Babel==2.9.0 18 | backcall==0.2.0 19 | beautifulsoup4==4.6.3 20 | bleach==3.3.0 21 | blis==0.4.1 22 | 
bokeh==2.1.1 23 | Bottleneck==1.3.2 24 | branca==0.4.2 25 | bs4==0.0.1 26 | CacheControl==0.12.6 27 | cachetools==4.2.1 28 | catalogue==1.0.0 29 | certifi==2020.12.5 30 | cffi==1.14.5 31 | chainer==7.4.0 32 | chardet==3.0.4 33 | click==7.1.2 34 | cloudpickle==1.3.0 35 | # cmake==3.12.0 36 | cmdstanpy==0.9.5 37 | colorlover==0.3.0 38 | community==1.0.0b1 39 | contextlib2==0.5.5 40 | convertdate==2.3.1 41 | coverage==3.7.1 42 | coveralls==0.5 43 | crcmod==1.7 44 | cufflinks==0.17.3 45 | cupy-cuda101==7.4.0 46 | cvxopt==1.2.6 47 | cvxpy==1.0.31 48 | cycler==0.10.0 49 | cymem==2.0.5 50 | Cython==0.29.22 51 | daft==0.0.4 52 | dask==2.12.0 53 | datascience==0.10.6 54 | debugpy==1.0.0 55 | decorator==4.4.2 56 | defusedxml==0.7.1 57 | descartes==1.1.0 58 | dill==0.3.3 59 | distributed==1.25.3 60 | Django==3.1.7 61 | dlib==19.18.0 62 | dm-tree==0.1.5 63 | docopt==0.6.2 64 | docutils==0.16 65 | dopamine-rl==1.0.5 66 | earthengine-api==0.1.255 67 | easydict==1.9 68 | ecos==2.0.7.post1 69 | editdistance==0.5.3 70 | en-core-web-sm==2.2.5 71 | entrypoints==0.3 72 | ephem==3.7.7.1 73 | et-xmlfile==1.0.1 74 | fa2==0.3.5 75 | fancyimpute==0.4.3 76 | fastai==1.0.61 77 | fastdtw==0.3.4 78 | fastprogress==1.0.0 79 | fastrlock==0.5 80 | fbprophet==0.7.1 81 | feather-format==0.4.1 82 | filelock==3.0.12 83 | firebase-admin==4.4.0 84 | fix-yahoo-finance==0.0.22 85 | Flask==1.1.2 86 | flatbuffers==1.12 87 | folium==0.8.3 88 | future==0.18.2 89 | gast==0.3.3 90 | GDAL==2.2.2 91 | gdown==3.6.4 92 | gensim==3.6.0 93 | geographiclib==1.50 94 | geopy==1.17.0 95 | gin-config==0.4.0 96 | glob2==0.7 97 | google==2.0.3 98 | google-api-core==1.26.1 99 | google-api-python-client==1.12.8 100 | google-auth==1.27.1 101 | google-auth-httplib2==0.0.4 102 | google-auth-oauthlib==0.4.3 103 | google-cloud-bigquery==1.21.0 104 | google-cloud-bigquery-storage==1.1.0 105 | google-cloud-core==1.0.3 106 | google-cloud-datastore==1.8.0 107 | google-cloud-firestore==1.7.0 108 | google-cloud-language==1.2.0 109 | 
google-cloud-storage==1.18.1 110 | google-cloud-translate==1.5.0 111 | google-colab==1.0.0 112 | google-pasta==0.2.0 113 | google-resumable-media==0.4.1 114 | googleapis-common-protos==1.53.0 115 | googledrivedownloader==0.4 116 | graphviz==0.10.1 117 | grpcio==1.32.0 118 | gspread==3.0.1 119 | gspread-dataframe==3.0.8 120 | gym==0.17.3 121 | h5py==2.10.0 122 | HeapDict==1.0.1 123 | hijri-converter==2.1.1 124 | holidays==0.10.5.2 125 | holoviews==1.13.5 126 | html5lib==1.0.1 127 | httpimport==0.5.18 128 | httplib2==0.17.4 129 | httplib2shim==0.0.3 130 | humanize==0.5.1 131 | hyperopt==0.1.2 132 | ideep4py==2.0.0.post3 133 | idna==2.10 134 | image==1.5.33 135 | imageio==2.4.1 136 | imagesize==1.2.0 137 | imbalanced-learn==0.4.3 138 | imblearn==0.0 139 | imgaug==0.2.9 140 | importlib-metadata==3.7.2 141 | importlib-resources==5.1.2 142 | imutils==0.5.4 143 | inflect==2.1.0 144 | iniconfig==1.1.1 145 | intel-openmp==2021.1.2 146 | intervaltree==2.1.0 147 | ipykernel==4.10.1 148 | ipython==5.5.0 149 | ipython-genutils==0.2.0 150 | ipython-sql==0.3.9 151 | ipywidgets==7.6.3 152 | itsdangerous==1.1.0 153 | jax==0.2.10 154 | jaxlib==0.1.62+cuda110 155 | jdcal==1.4.1 156 | jedi==0.18.0 157 | jieba==0.42.1 158 | Jinja2==2.11.3 159 | joblib==1.0.1 160 | jpeg4py==0.1.4 161 | jsonschema==2.6.0 162 | jupyter==1.0.0 163 | jupyter-client==5.3.5 164 | jupyter-console==5.2.0 165 | jupyter-core==4.7.1 166 | jupyterlab-pygments==0.1.2 167 | jupyterlab-widgets==1.0.0 168 | kaggle==1.5.10 169 | kapre==0.1.3.1 170 | Keras==2.4.3 171 | Keras-Preprocessing==1.1.2 172 | keras-vis==0.4.1 173 | kiwisolver==1.3.1 174 | knnimpute==0.1.0 175 | korean-lunar-calendar==0.2.1 176 | librosa==0.8.0 177 | lightgbm==2.2.3 178 | llvmlite==0.34.0 179 | lmdb==0.99 180 | lucid==0.3.8 181 | LunarCalendar==0.0.9 182 | lxml==4.2.6 183 | Markdown==3.3.4 184 | MarkupSafe==1.1.1 185 | matplotlib==3.2.2 186 | matplotlib-venn==0.11.6 187 | missingno==0.4.2 188 | mistune==0.8.4 189 | mizani==0.6.0 190 | mkl==2019.0 
191 | mlxtend==0.14.0 192 | more-itertools==8.7.0 193 | moviepy==0.2.3.5 194 | mpmath==1.2.1 195 | msgpack==1.0.2 196 | multiprocess==0.70.11.1 197 | multitasking==0.0.9 198 | murmurhash==1.0.5 199 | music21==5.5.0 200 | natsort==5.5.0 201 | nbclient==0.5.3 202 | nbconvert==5.6.1 203 | nbformat==5.1.2 204 | nest-asyncio==1.5.1 205 | networkx==2.5 206 | nibabel==3.0.2 207 | nltk==3.2.5 208 | notebook==5.3.1 209 | np-utils==0.5.12.1 210 | numba==0.51.2 211 | numexpr==2.7.3 212 | numpy==1.19.5 213 | nvidia-ml-py3==7.352.0 214 | oauth2client==4.1.3 215 | oauthlib==3.1.0 216 | okgrade==0.4.3 217 | opencv-contrib-python==4.1.2.30 218 | opencv-python==4.1.2.30 219 | openpyxl==2.5.9 220 | opt-einsum==3.3.0 221 | osqp==0.6.2.post0 222 | packaging==20.9 223 | palettable==3.3.0 224 | pandas==1.1.5 225 | pandas-datareader==0.9.0 226 | pandas-gbq==0.13.3 227 | pandas-profiling==1.4.1 228 | pandocfilters==1.4.3 229 | panel==0.9.7 230 | param==1.10.1 231 | parso==0.8.1 232 | pathlib==1.0.1 233 | patsy==0.5.1 234 | pexpect==4.8.0 235 | pickleshare==0.7.5 236 | Pillow==7.0.0 237 | pip-tools==4.5.1 238 | plac==1.1.3 239 | plotly==4.4.1 240 | plotnine==0.6.0 241 | pluggy==0.7.1 242 | pooch==1.3.0 243 | portpicker==1.3.1 244 | prefetch-generator==1.0.1 245 | preshed==3.0.5 246 | prettytable==2.1.0 247 | progressbar2==3.38.0 248 | prometheus-client==0.9.0 249 | promise==2.3 250 | prompt-toolkit==1.0.18 251 | protobuf==3.12.4 252 | psutil==5.4.8 253 | psycopg2==2.7.6.1 254 | ptyprocess==0.7.0 255 | py==1.10.0 256 | pyarrow==3.0.0 257 | pyasn1==0.4.8 258 | pyasn1-modules==0.2.8 259 | pycocotools==2.0.2 260 | pycparser==2.20 261 | pyct==0.4.8 262 | pydata-google-auth==1.1.0 263 | pydot==1.3.0 264 | pydot-ng==2.0.0 265 | pydotplus==2.0.2 266 | PyDrive==1.3.1 267 | pyemd==0.5.1 268 | pyerfa==1.7.2 269 | pyglet==1.5.0 270 | Pygments==2.6.1 271 | pygobject==3.26.1 272 | pymc3==3.7 273 | PyMeeus==0.5.9 274 | pymongo==3.11.3 275 | pymystem3==0.2.0 276 | pynndescent==0.5.2 277 | PyOpenGL==3.1.5 
278 | pyparsing==2.4.7 279 | pyrsistent==0.17.3 280 | pysndfile==1.3.8 281 | PySocks==1.7.1 282 | pystan==2.19.1.1 283 | pytest==3.6.4 284 | python-apt==0.0.0 285 | python-chess==0.23.11 286 | python-dateutil==2.8.1 287 | python-louvain==0.15 288 | python-slugify==4.0.1 289 | python-utils==2.5.6 290 | pytz==2018.9 291 | pyviz-comms==2.0.1 292 | PyWavelets==1.1.1 293 | PyYAML==3.13 294 | pyzmq==22.0.3 295 | qdldl==0.1.5.post0 296 | qtconsole==5.0.3 297 | QtPy==1.9.0 298 | regex==2019.12.20 299 | requests==2.23.0 300 | requests-oauthlib==1.3.0 301 | resampy==0.2.2 302 | retrying==1.3.3 303 | rpy2==3.4.2 304 | rsa==4.7.2 305 | sacremoses==0.0.43 306 | scikit-image==0.16.2 307 | scikit-learn==0.22.2.post1 308 | scipy==1.4.1 309 | screen-resolution-extra==0.0.0 310 | scs==2.1.2 311 | seaborn==0.11.1 312 | Send2Trash==1.5.0 313 | sentencepiece==0.1.95 314 | setuptools-git==1.2 315 | Shapely==1.7.1 316 | simplegeneric==0.8.1 317 | six==1.15.0 318 | sklearn==0.0 319 | sklearn-pandas==1.8.0 320 | smart-open==4.2.0 321 | snowballstemmer==2.1.0 322 | sortedcontainers==2.3.0 323 | SoundFile==0.10.3.post1 324 | Sphinx==1.8.5 325 | sphinxcontrib-serializinghtml==1.1.4 326 | sphinxcontrib-websupport==1.2.4 327 | SQLAlchemy==1.3.23 328 | sqlparse==0.4.1 329 | srsly==1.0.5 330 | statsmodels==0.10.2 331 | sympy==1.7.1 332 | tables==3.4.4 333 | tabulate==0.8.9 334 | tblib==1.7.0 335 | tensorboard==2.4.1 336 | tensorboard-plugin-wit==1.8.0 337 | tensorflow==2.4.1 338 | tensorflow-datasets==4.0.1 339 | tensorflow-estimator==2.4.0 340 | tensorflow-gcs-config==2.4.0 341 | tensorflow-hub==0.11.0 342 | tensorflow-metadata==0.28.0 343 | tensorflow-probability==0.12.1 344 | termcolor==1.1.0 345 | terminado==0.9.2 346 | testpath==0.4.4 347 | text-unidecode==1.3 348 | textblob==0.15.3 349 | textgenrnn==1.4.1 350 | Theano==1.0.5 351 | thinc==7.4.0 352 | tifffile==2021.3.17 353 | tokenizers==0.7.0 354 | toml==0.10.2 355 | toolz==0.11.1 356 | tornado==5.1.1 357 | tqdm==4.41.1 358 | 
traitlets==5.0.5 359 | tweepy==3.10.0 360 | typeguard==2.7.1 361 | typing-extensions==3.7.4.3 362 | tzlocal==1.5.1 363 | umap-learn==0.5.1 364 | uritemplate==3.0.1 365 | urllib3==1.24.3 366 | vega-datasets==0.9.0 367 | wasabi==0.8.2 368 | wcwidth==0.2.5 369 | webencodings==0.5.1 370 | Werkzeug==1.0.1 371 | widgetsnbextension==3.5.1 372 | wordcloud==1.5.0 373 | wrapt==1.12.1 374 | xarray==0.15.1 375 | xgboost==0.90 376 | xkit==0.0.0 377 | xlrd==1.1.0 378 | xlwt==1.3.0 379 | yellowbrick==0.9.1 380 | zict==2.0.0 381 | zipp==3.4.1 382 | pandas 383 | sentencepiece 384 | transformers==4.6.1 385 | pytorch-lightning==1.3.3 386 | fastt5==0.0.5 387 | wordninja==2.0.0 388 | spacy==3.1.0 389 | networkx==2.5.1 390 | torchmetrics 391 | # torch>=1.7.0,!=1.8.0 -------------------------------------------------------------------------------- /src/data_generation/graph_utils.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | 3 | 4 | def has_path(graph, start, end, path=[]): 5 | G = nx.Graph(graph) 6 | try: 7 | return nx.has_path(G, start, end) 8 | except: 9 | return None 10 | return None 11 | 12 | 13 | def find_shortest_paths(graph, start, end): 14 | G = nx.Graph(graph) 15 | try: 16 | return [p for p in nx.all_shortest_paths(G, source=start, target=end)] 17 | except: 18 | None 19 | return None 20 | 21 | # a sample graph 22 | # graph = {'A': ['B', 'C', 'E'], 23 | # 'B': ['A', 'C', 'D'], 24 | # 'C': ['A', 'B', 'D', 'F'], 25 | # 'D': ['A', 'B', 'C', 'E'], 26 | # 'E': ['A', 'D', 'F'], 27 | # 'F': ['C', 'E']} 28 | 29 | 30 | # print(has_path(graph,"A","D")) 31 | # print(find_shortest_paths(graph,"A","F")) 32 | -------------------------------------------------------------------------------- /src/data_generation/ground_example.py: -------------------------------------------------------------------------------- 1 | from csv import DictWriter 2 | from collections import OrderedDict 3 | 4 | from sql_execution import * 5 | from 
preprocess_db import * 6 | from grounding_repairs import * 7 | 8 | 9 | def handle_imdb_cast_table(sql): 10 | """cast is a reserved word in SQL. 11 | We alias the table cast if it appears in the query""" 12 | if "cast" not in sql: 13 | return sql 14 | sql = sql.replace(" cast.", " @@@placeholder@@@") 15 | sql = sql.replace("(cast.", "(@@@placeholder@@@") 16 | sql = sql.replace(" cast", " cast as cast_0") 17 | sql = sql.replace("@@@placeholder@@@", "cast_0.") 18 | return sql 19 | 20 | 21 | def append_dict_as_row(file_name, dict_of_elem, field_names): 22 | # Open file in append mode 23 | with open(file_name, 'a+', newline='', encoding='utf-8') as write_obj: 24 | # Create a writer object from csv module 25 | dict_writer = DictWriter(write_obj, fieldnames=field_names) 26 | # Add dictionary as wor in the csv 27 | dict_writer.writerow(dict_of_elem) 28 | 29 | 30 | def mixs(val): 31 | """"" 32 | comparator function for sorting mixed lists of number & string 33 | """ 34 | if isinstance(val, str): 35 | return (1, val, '') 36 | return (0, val, '') 37 | 38 | 39 | class GroundingTestExample: 40 | def __init__(self, example_id, db_id, question, qdmr, gold_sql, schema_path, dataset): 41 | self.example_id = example_id 42 | self.db_id = db_id 43 | self.question = question 44 | self.qdmr = qdmr 45 | self.gold_sql = gold_sql 46 | default_path = "data/spider_databases/%s/%s.sqlite" % (self.db_id, self.db_id) \ 47 | if dataset == "spider" else "data/other_databases/%s/%s.sqlite" % (self.db_id, self.db_id) 48 | self.schema_path = default_path if schema_path is None else schema_path 49 | self.dataset = dataset 50 | self.grounding = None 51 | self.grounded_sql = None 52 | self.grounding_error = False 53 | 54 | def to_dict(self): 55 | d = {} 56 | d["id"] = self.example_id 57 | d["db"] = self.db_id 58 | d["qdmr_grounding"] = self.qdmr 59 | d["gold_sql"] = self.gold_sql 60 | d["grounding"] = self.grounding.to_dict() 61 | return d 62 | 63 | def ground_example(self, assignment=None): 64 | if 
assignment: 65 | assert self.grounding 66 | self.grounding.assign_groundings(assignment) 67 | else: 68 | schema = prepare_db_schema(self.schema_path, self.dataset) 69 | self.grounding = QDMRGrounding(self.qdmr, self.question, schema, self.gold_sql) 70 | grounded_steps = self.grounding.ground() 71 | self.grounded_sql = self.grounding.get_grounded_sql() 72 | return self.grounded_sql is not None 73 | 74 | def set_qdmr(self, new_qdmr): 75 | self.qdmr = new_qdmr 76 | return True 77 | 78 | def set_grounded_sql(self, new_sql): 79 | self.grounded_sql = new_sql 80 | return True 81 | 82 | def get_gold_sql_denotation(self): 83 | return execute_sql(self.schema_path, self.gold_sql) 84 | 85 | def get_grounded_sql_denotation(self): 86 | if self.grounded_sql: 87 | sql = handle_imdb_cast_table(self.grounded_sql) 88 | try: 89 | denotation = execute_sql(self.schema_path, sql) 90 | except TimeoutError: 91 | print('* Grounded SQL execution timeout') 92 | denotation = None 93 | except (sqlite3.Warning, sqlite3.Error, sqlite3.DatabaseError, 94 | sqlite3.IntegrityError, sqlite3.ProgrammingError, 95 | sqlite3.OperationalError, sqlite3.NotSupportedError) as e: 96 | print('* Grounded SQL execution error') 97 | denotation = None 98 | return denotation 99 | 100 | def normalize_tuple(self, tup): 101 | # cast all tuple values to strings 102 | norm_vars = [str(var) for var in tup] 103 | return tuple(norm_vars) 104 | 105 | def normalize_denotation(self, denotation_list, distinct=None): 106 | if not denotation_list: 107 | return denotation_list 108 | # remove duplicates 109 | denotation_list = list(OrderedDict.fromkeys(denotation_list)) if distinct else denotation_list 110 | sort_tuples = [sorted(self.normalize_tuple(tup)) for tup in denotation_list] 111 | return sorted(sort_tuples) # sort results set 112 | 113 | def correct_denotation(self, distinct=None): 114 | if not self.grounded_sql: 115 | return False 116 | if self.get_gold_sql_denotation() == self.get_grounded_sql_denotation(): 117 | # 
unnormalized denotations 118 | return True 119 | gold_denotation_norm = self.normalize_denotation(self.get_gold_sql_denotation(), distinct=distinct) 120 | ground_denotation_norm = self.normalize_denotation(self.get_grounded_sql_denotation(), distinct=distinct) 121 | return gold_denotation_norm == ground_denotation_norm 122 | 123 | def repair(self, modules=None): 124 | """ 125 | Run all repair modules for an incorrect grounding (wrong denotation). 126 | If nor correct grounding found return the original grounding. 127 | All repairs rely on the gold denotation for filtering. 128 | The modules input indicate which repairs should be run. 129 | 130 | Returns 131 | ------- 132 | bool 133 | True if repair was successful, otherwise False 134 | """ 135 | valid_modules = ["syntax", "column_ground", "qdmr"] 136 | if modules is not None: 137 | assert set(modules) <= set(valid_modules) 138 | assert self.grounding is not None 139 | original_grounding = self.grounding 140 | original_grounded_sql = self.grounded_sql 141 | repair_modules = valid_modules if modules is None else modules 142 | if "syntax" in repair_modules and self.syntax_repairs(): 143 | return True 144 | self.restore_grounding(original_grounding, original_grounded_sql) 145 | if "column_ground" in repair_modules: 146 | repair = ColumnGroundingRepair(self) 147 | if repair.repair(): 148 | return True 149 | self.restore_grounding(original_grounding, original_grounded_sql) 150 | if "qdmr" in repair_modules: 151 | repair = CountSumGroundingRepair(self) 152 | if repair.repair(): 153 | return True 154 | repair = SuperlativeGroundingRepair(self) 155 | if repair.repair(): 156 | return True 157 | self.restore_grounding(original_grounding, original_grounded_sql) 158 | print("*** Grounding example repair failed - restoring original grounding") 159 | return False 160 | 161 | def restore_grounding(self, grounding, grounded_sql): 162 | self.grounding = grounding 163 | self.grounded_sql = grounded_sql 164 | return True 165 | 166 | 
def syntax_repairs(self): 167 | """runs all sytax repairs on grounding example""" 168 | repair_count_distinct = AggrDistinctGroundingRepair(self) 169 | repair_like_val = LikeEqualsGroundingRepair(self) 170 | if repair_count_distinct.repair(): 171 | return True 172 | return repair_like_val.repair() 173 | -------------------------------------------------------------------------------- /src/data_generation/main.py: -------------------------------------------------------------------------------- 1 | from write_grounding import * 2 | import argparse 3 | import os 4 | 5 | 6 | def main(args): 7 | examples = load_grounding_examples(args.input_file) 8 | print(f"Loaded {len(examples)} grounding examples.") 9 | write_grounding_results(examples, args.output_file, to_json=args.json_steps) 10 | print("Done grounding all examples.\n") 11 | 12 | if __name__ == '__main__': 13 | parser = argparse.ArgumentParser( 14 | description="example command: " 15 | "python main.py test/grounding_examples.csv " 16 | "--json_steps True" 17 | ) 18 | parser.add_argument('input_file', type=str, help='path to grounding examples csv') 19 | parser.add_argument('output_file', type=str, help='path to output file, with csv extension') 20 | parser.add_argument('--json_steps', type=bool, default=None, 21 | help='whether to generate grounding steps json') 22 | args = parser.parse_args() 23 | assert os.path.exists(args.input_file) 24 | 25 | main(args) 26 | 27 | -------------------------------------------------------------------------------- /src/data_generation/predicted_sql.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def remove_quotation_marks(string): 5 | if string.startswith("'") and string.endswith("'"): 6 | return string[1:-1] 7 | if string.startswith('"') and string.endswith('"'): 8 | return string[1:-1] 9 | return string 10 | 11 | 12 | def sql_quotation_values(sql): 13 | """ 14 | Returns list of lowercase quotation value in SQL query 15 
| """ 16 | 17 | def remove_like_pct(string): 18 | string = string.replace("%", "") 19 | return string 20 | 21 | query = sql.lower() 22 | value_to_col = {} 23 | # Find all values based on string delimiters 24 | single_paren_vals = [item.group(0) for item in re.finditer(r'\'.*?\'', query)] 25 | double_paren_vals = [item.group(0) for item in re.finditer(r'\".*?\"', query)] 26 | vals_list = single_paren_vals + double_paren_vals 27 | return [remove_quotation_marks(remove_like_pct(val)) for val in vals_list] 28 | 29 | 30 | def val_sql_quotation(value): 31 | return [f"'{value}'", f"'%{value}%'", f'\"{value}\"', f'\"%{value}%\"'] 32 | 33 | 34 | def sql_value_case(sql, value): 35 | """returns a non-numeric value in the case it appears in the original SQL query""" 36 | 37 | def escape_parentheses(value): 38 | return value.replace("(", "\(").replace(")", "\)") 39 | 40 | value_quotation = val_sql_quotation(value) 41 | for quote_val in value_quotation: 42 | escaped_val = escape_parentheses(quote_val) 43 | if re.search(escaped_val, sql, re.IGNORECASE): 44 | return re.search(escaped_val, sql, re.IGNORECASE).group(0) 45 | return None 46 | 47 | 48 | def fix_sql_casing(pred_sql, gold_sql): 49 | gold_values = sql_quotation_values(gold_sql) 50 | fixed_sql = pred_sql 51 | for val in gold_values: 52 | val_case_quotes = sql_value_case(gold_sql, val) 53 | if val_case_quotes is not None: 54 | val_case = remove_quotation_marks(val_case_quotes) 55 | for pred_val_cased in val_sql_quotation(val_case): 56 | # quoted value as it appears in the *predicted* sql 57 | fixed_sql = fixed_sql.replace(pred_val_cased.lower(), 58 | pred_val_cased) if pred_val_cased.lower() in fixed_sql else fixed_sql 59 | return fixed_sql 60 | -------------------------------------------------------------------------------- /src/data_generation/preprocess_db.py: -------------------------------------------------------------------------------- 1 | from db_schema import * 2 | 3 | 4 | def prepare_db_schema(path, dataset): 5 | 
if dataset == "spider": 6 | return prepare_spider_db(path) 7 | return prepare_other_db(path, dataset) 8 | 9 | 10 | def prepare_spider_db(path): 11 | return DBSchema(path) 12 | 13 | 14 | def prepare_other_db(path, dataset): 15 | schema = DBSchema(path) 16 | # manually add foreign keys that are absent from the original DBs 17 | if dataset == "academic": 18 | schema.add_foreign_key('publication', 'pid', 'writes', 'pid') 19 | schema.add_foreign_key('author', 'aid', 'writes', 'aid') 20 | schema.add_foreign_key('journal', 'jid', 'publication', 'jid') 21 | schema.add_foreign_key('conference', 'cid', 'publication', 'cid') 22 | schema.add_foreign_key('publication', 'pid', 'publication_keyword', 'pid') 23 | schema.add_foreign_key('keyword', 'kid', 'publication_keyword', 'kid') 24 | schema.add_foreign_key('author', 'oid', 'organization', 'oid') 25 | schema.add_foreign_key('author', 'aid', 'domain_author', 'aid') 26 | schema.add_foreign_key('domain', 'did', 'domain_author', 'did') 27 | schema.add_foreign_key('domain', 'did', 'domain_publication', 'did') 28 | schema.add_foreign_key('domain_publication', 'pid', 'publication', 'pid') 29 | schema.add_foreign_key('cite', 'cited', 'publication', 'pid') 30 | schema.add_foreign_key('cite', 'citing', 'publication', 'pid') 31 | schema.add_foreign_key('domain', 'did', 'domain_keyword', 'did') 32 | schema.add_foreign_key('domain_keyword', 'kid', 'keyword', 'kid') 33 | schema.add_foreign_key('domain', 'did', 'domain_journal', 'did') 34 | schema.add_foreign_key('domain_journal', 'jid', 'journal', 'jid') 35 | schema.add_foreign_key('conference', 'cid', 'domain_conference', 'cid') 36 | schema.add_foreign_key('domain', 'did', 'domain_conference', 'did') 37 | elif dataset == "atis": 38 | schema.add_foreign_key('airport_service', 'city_code', 'city', 'city_code') 39 | schema.add_foreign_key('airport_service', 'airport_code', 'flight', 'from_airport') 40 | schema.add_foreign_key('airport_service', 'airport_code', 'flight', 'to_airport') 41 | 
schema.add_foreign_key('date_day', 'day_name', 'days', 'day_name') 42 | schema.add_foreign_key('days', 'days_code', 'flight', 'flight_days') 43 | schema.add_foreign_key('fare', 'fare_id', 'flight_fare', 'fare_id') 44 | schema.add_foreign_key('flight', 'flight_id', 'flight_fare', 'flight_id') 45 | schema.add_foreign_key('fare', 'fare_basis_code', 'fare_basis', 'fare_basis_code') 46 | schema.add_foreign_key('flight', 'flight_id', 'flight_stop', 'flight_id') 47 | schema.add_foreign_key('days', 'days_code', 'fare_basis', 'basis_days') 48 | schema.add_foreign_key('airport_service', 'airport_code', 'flight_stop', 'stop_airport') 49 | schema.add_foreign_key('city', 'city_code', 'ground_service', 'city_code') 50 | schema.add_foreign_key('airline', 'airline_code', 'flight', 'airline_code') 51 | schema.add_foreign_key('airport', 'airport_code', 'airport_service', 'airport_code') 52 | schema.add_foreign_key('flight', 'meal_code', 'food_service', 'meal_code') 53 | schema.add_foreign_key('aircraft', 'aircraft_code', 'equipment_sequence', 'aircraft_code') 54 | schema.add_foreign_key('equipment_sequence', 'aircraft_code_sequence', 'flight', 'aircraft_code_sequence') 55 | schema.add_foreign_key('city', 'state_code', 'state', 'state_code') 56 | schema.add_foreign_key('airport', 'airport_code', 'flight', 'to_airport') 57 | schema.add_foreign_key('airport', 'airport_code', 'ground_service', 'airport_code') 58 | schema.add_foreign_key('airport', 'airport_code', 'flight', 'from_airport') 59 | schema.add_foreign_key('airport_service', 'airport_code', 'fare', 'to_airport') 60 | schema.add_foreign_key('airport_service', 'airport_code', 'fare', 'from_airport') 61 | schema.add_foreign_key('flight', 'flight_id', 'flight_leg', 'flight_id') 62 | schema.add_foreign_key('flight', 'flight_id', 'flight_leg', 'leg_flight') 63 | schema.add_foreign_key('class_of_service', 'booking_class', 'fare_basis', 'booking_class') 64 | schema.add_foreign_key('airport', 'state_code', 'state', 'state_code') 65 | 
schema.add_foreign_key('airport', 'airport_code', 'flight_stop', 'stop_airport') 66 | schema.add_foreign_key('fare', 'restriction_code', 'restriction', 'restriction_code') 67 | elif dataset == "geo": 68 | schema.add_foreign_key('border_info', 'border', 'state', 'state_name') 69 | schema.add_foreign_key('river', 'traverse', 'state', 'state_name') 70 | schema.add_foreign_key('city', 'city_name', 'state', 'capital') 71 | schema.add_foreign_key('border_info', 'state_name', 'state', 'state_name') 72 | schema.add_foreign_key('city', 'state_name', 'state', 'state_name') 73 | schema.add_foreign_key('border_info', 'border', 'river', 'traverse') 74 | schema.add_foreign_key('highlow', 'state_name', 'state', 'state_name') 75 | schema.add_foreign_key('border_info', 'border', 'border_info', 'state_name') 76 | schema.add_foreign_key('highlow', 'state_name', 'river', 'traverse') 77 | schema.add_foreign_key('border_info', 'state_name', 'river', 'traverse') 78 | schema.add_foreign_key('city', 'state_name', 'river', 'traverse') 79 | schema.add_foreign_key('border_info', 'border', 'highlow', 'state_name') 80 | schema.add_foreign_key('border_info', 'border', 'city', 'state_name') 81 | schema.add_foreign_key('border_info', 'border', 'lake', 'state_name') 82 | elif dataset == "yelp": 83 | schema.add_foreign_key('business', 'business_id', 'category', 'business_id') 84 | schema.add_foreign_key('review', 'user_id', 'user', 'user_id') 85 | schema.add_foreign_key('business', 'business_id', 'review', 'business_id') 86 | schema.add_foreign_key('business', 'business_id', 'neighborhood', 'business_id') 87 | schema.add_foreign_key('tip', 'user_id', 'user', 'user_id') 88 | schema.add_foreign_key('business', 'business_id', 'tip', 'business_id') 89 | schema.add_foreign_key('business', 'business_id', 'checkin', 'business_id') 90 | elif dataset == "imdb": 91 | schema.add_foreign_key('actor', 'aid', 'cast', 'aid') 92 | schema.add_foreign_key('cast', 'msid', 'movie', 'mid') 93 | 
schema.add_foreign_key('directed_by', 'did', 'director', 'did') 94 | schema.add_foreign_key('directed_by', 'msid', 'movie', 'mid') 95 | schema.add_foreign_key('company', 'id', 'copyright', 'cid') 96 | schema.add_foreign_key('copyright', 'msid', 'movie', 'mid') 97 | schema.add_foreign_key('keyword', 'id', 'tags', 'kid') 98 | schema.add_foreign_key('movie', 'mid', 'tags', 'msid') 99 | schema.add_foreign_key('classification', 'msid', 'movie', 'mid') 100 | schema.add_foreign_key('made_by', 'pid', 'producer', 'pid') 101 | schema.add_foreign_key('classification', 'gid', 'genre', 'gid') 102 | schema.add_foreign_key('movie', 'mid', 'written_by', 'msid') 103 | schema.add_foreign_key('made_by', 'msid', 'movie', 'mid') 104 | schema.add_foreign_key('writer', 'wid', 'written_by', 'wid') 105 | schema.add_foreign_key('copyright', 'msid', 'tv_series', 'sid') 106 | schema.add_foreign_key('cast', 'msid', 'tv_series', 'sid') 107 | schema.add_foreign_key('directed_by', 'msid', 'tv_series', 'sid') 108 | schema.add_foreign_key('made_by', 'msid', 'tv_series', 'sid') 109 | else: 110 | raise ValueError("Invalid dataset name: %s" % dataset) 111 | return schema 112 | -------------------------------------------------------------------------------- /src/data_generation/preprocess_grounding_data.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from pandas.io.json import json_normalize 3 | 4 | import json 5 | import re 6 | 7 | 8 | def remove_space_between_quotes(string): 9 | matches = re.findall(r'\"(.+?)\"', string) # match text between two quotes 10 | for m in matches: 11 | trimmed_m = m.strip() 12 | string = string.replace('\"%s\"' % m, '\"%s\"' % trimmed_m) 13 | return string 14 | 15 | 16 | def add_space_between_conds(sql): 17 | for op in ['>=', '<=', '>', '>', '!=']: 18 | sql = sql.replace(op, " %s " % op) 19 | sql = sql.replace("in(", "in (") 20 | sql = sql.replace("IN(", "IN (") 21 | sql = sql.replace("=(", " = (") 22 | sql = 
sql.replace("='", " = '") 23 | sql = sql.replace('="', ' = "') 24 | sql = sql.replace(" ", " ") 25 | return sql 26 | 27 | 28 | def change_not_equal_op(sql): 29 | sql = sql.replace("<>", "!=") 30 | return sql 31 | 32 | 33 | def format_sql(sql): 34 | sql = change_not_equal_op(sql) 35 | sql = remove_space_between_quotes(sql) 36 | sql = add_space_between_conds(sql) 37 | return sql 38 | 39 | 40 | def load_spider_data(dataset_path): 41 | """ 42 | Reads query & DB info from Spider dataset 43 | 44 | Parameters 45 | ---------- 46 | dataset_path : str 47 | Full path to dataset json 48 | 49 | Returns 50 | ------- 51 | dict 52 | Dict of (db_id, question, query) 53 | """ 54 | examples = {} 55 | with open(dataset_path) as f: 56 | data = json.load(f) 57 | for i in range(len(data)): 58 | db_id = data[i]["db_id"] 59 | sql = data[i]["query"] 60 | question = data[i]["question"] 61 | split = "train" if "train" in dataset_path else "dev" 62 | index = "SPIDER_%s_%d" % (split, i) 63 | examples[index] = {} 64 | examples[index]["db_id"] = db_id 65 | examples[index]["question"] = question 66 | examples[index]["query"] = sql 67 | assert len(data) == len(examples) 68 | return examples 69 | 70 | 71 | def load_spider_others_data(dataset_path, target_db=None): 72 | """ 73 | Reads query & DB info from Spider dataset 74 | 75 | Parameters 76 | ---------- 77 | dataset_path : str 78 | Full path to dataset json 79 | target_db : str 80 | Filter according to specific db_id 81 | 82 | Returns 83 | ------- 84 | dict 85 | Dict of (db_id, question, query) 86 | """ 87 | examples = {} 88 | with open(dataset_path) as f: 89 | data = json.load(f) 90 | for i in range(len(data)): 91 | db_id = data[i]["db_id"] 92 | if target_db and db_id != target_db: 93 | continue 94 | sql = data[i]["query"] 95 | question = data[i]["question"] 96 | examples[question] = {} 97 | examples[question]["db_id"] = db_id 98 | examples[question]["question"] = question 99 | examples[question]["query"] = sql 100 | assert len(data) == 
len(examples) 101 | return examples 102 | 103 | 104 | def load_washington_data(dataset_path, db): 105 | """ 106 | Reads question & SQL query from the Washington & 107 | Texas universities text-to-SQL datasets: 108 | Academic, ATIS, Geo, IMDB, Yelp 109 | 110 | Parameters 111 | ---------- 112 | dataset_path : str 113 | Full path to dataset txt file 114 | 115 | Returns 116 | ------- 117 | dict 118 | Dict of (db_id, question, query) 119 | """ 120 | examples = {} 121 | delimiter = "|||" 122 | queries_delimiter = "|" 123 | valid_dbs = ["academic", "atis", "geo", "imdb", "yelp"] 124 | assert db in valid_dbs 125 | with open(dataset_path) as f: 126 | for line in f: 127 | line_data = line.split(delimiter) 128 | question, sql = line.split(delimiter)[1:] if db in ["imdb", "yelp"] else line.split(delimiter) 129 | question = question.strip() 130 | sql = sql.strip() 131 | if queries_delimiter in sql: 132 | sql = sql.split(queries_delimiter)[0].strip() 133 | sql = format_sql(sql) 134 | examples[question] = {} 135 | examples[question]["db_id"] = db 136 | examples[question]["question"] = question 137 | examples[question]["query"] = sql 138 | return examples 139 | 140 | 141 | def load_break_qdmrs(data_path, dataset_name=None): 142 | """ 143 | Reads qdmr structures from Break 144 | 145 | Parameters 146 | ---------- 147 | dataset_path : str 148 | Full path to dataset csv 149 | dataset_name : str 150 | Prefix of a specific dataset to load 151 | 152 | Returns 153 | ------- 154 | Dataframe 155 | Dict of Spider_id --> ('question_id', 'question_text', 'decomposition') 156 | """ 157 | df = pd.read_csv(data_path) 158 | if dataset_name is not None: 159 | df = df[df['question_id'].str.contains(dataset_name)] 160 | df = df[['question_id', 'question_text', 'decomposition']] 161 | df = df.set_index('question_id') 162 | return df.groupby('question_id').apply(lambda dfg: dfg.to_dict(orient='list')).to_dict() 163 | 164 | 165 | def build_grounding_data(qdmr_data, text2sql_data, text2sql_format, \ 166 
| db=None, split=None, output=None, to_csv=None): 167 | """ 168 | Builds a json/csv file containing qdmr to sql examples 169 | 170 | Parameters 171 | ---------- 172 | qdmr_data : str 173 | Path to file mapping text questions to their QDMRs 174 | text2sql_data : str 175 | Path to file containing text to SQL data 176 | complete with schema details 177 | text2sql_format : str 178 | Format of the text2sql file (spider/others/washington) 179 | db : str 180 | Particular db name (e.g.,"atis") 181 | split : str 182 | Dataset split, train/dev/test 183 | output : str 184 | Path to save the grounding data file 185 | 186 | 187 | Returns 188 | ------- 189 | dict 190 | dictionary containing examples of 191 | id, db, question, sql, qdmr 192 | """ 193 | grounding_data = {} 194 | assert text2sql_format in ["spider", "others", "washington"] 195 | if split is not None: 196 | assert split in ["train", "dev", "test"] 197 | if text2sql_format == "spider": 198 | dataset_name = "SPIDER_%s" % split 199 | qdmr = load_break_qdmrs(qdmr_data, dataset_name=dataset_name) 200 | print("* QDMR structures loaded: %s" % len(qdmr)) 201 | text2sql = load_spider_data(text2sql_data) 202 | for example_id in qdmr: 203 | try: 204 | assert (qdmr[example_id]['question_text'][0] == \ 205 | text2sql[example_id]['question']) 206 | except: 207 | print("example_id: ", example_id) 208 | print("qdmr[example_id]['question_text'][0]: ", qdmr[example_id]['question_text'][0]) 209 | print("text2sql[example_id]['question']", text2sql[example_id]['question']) 210 | print("***") 211 | grounding_data[example_id] = text2sql[example_id] 212 | grounding_data[example_id]['qdmr'] = qdmr[example_id]['decomposition'][0] 213 | elif text2sql_format == "others": 214 | qdmr = load_break_qdmrs(qdmr_data, dataset_name="ACADEMIC_train") 215 | qdmr.update(load_break_qdmrs(qdmr_data, dataset_name="GEO_train")) 216 | print("* QDMR structures loaded: %s" % len(qdmr)) 217 | text2sql = load_spider_others_data(text2sql_data) 218 | for example_id 
in qdmr: 219 | key = qdmr[example_id]['question_text'][0].replace('"', '\"').strip() 220 | if key in text2sql.keys(): 221 | grounding_data[example_id] = text2sql[key] 222 | grounding_data[example_id]['qdmr'] = qdmr[example_id]['decomposition'][0] 223 | else: 224 | print("*** Could not find key: ", key) 225 | elif text2sql_format == "washington": 226 | assert db is not None 227 | dataset_name = "%s_" % db.upper() 228 | if split is not None: 229 | dataset_name = "%s_%s" % (db.upper(), split) 230 | if db == "atis": 231 | qdmr = load_break_qdmrs(qdmr_data, dataset_name=dataset_name) 232 | elif db == "academic": 233 | qdmr = load_break_qdmrs(qdmr_data, dataset_name=dataset_name) 234 | elif db == "geo": 235 | qdmr = load_break_qdmrs(qdmr_data, dataset_name=dataset_name) 236 | elif db == "imdb": 237 | qdmr = load_break_qdmrs(qdmr_data, dataset_name=dataset_name) 238 | elif db == "yelp": 239 | qdmr = load_break_qdmrs(qdmr_data, dataset_name=dataset_name) 240 | else: 241 | raise Exception("Invalid db name: %s" % db) 242 | print("* QDMR structures loaded: %s" % len(qdmr)) 243 | text2sql = load_washington_data(text2sql_data, db) 244 | for example_id in qdmr: 245 | key = qdmr[example_id]['question_text'][0].strip() 246 | if key in text2sql.keys(): 247 | grounding_data[example_id] = text2sql[key] 248 | grounding_data[example_id]['qdmr'] = qdmr[example_id]['decomposition'][0] 249 | else: 250 | print("*** Could not find key: ", key) 251 | print("* Overall grounding data examples: %s" % len(grounding_data)) 252 | if output: 253 | with open("%s.json" % output, "w") as fp: 254 | json.dump(grounding_data, fp) 255 | print("* Saved grounding data examples to %s.json" % output) 256 | if to_csv: 257 | df = pd.DataFrame.from_dict(grounding_data, orient="index").reset_index() 258 | df.to_csv("%s.csv" % output) 259 | print("* Saved grounding data examples to %s.csv" % output) 260 | return grounding_data -------------------------------------------------------------------------------- 
/src/data_generation/qdmr_editor.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | from qdmr_identifier import * 3 | 4 | class QDMREditor(object): 5 | def __init__(self, qdmr_text): 6 | self.qdmr_steps = {} 7 | steps_list = parse_decomposition(qdmr_text) 8 | for i in range(len(steps_list)): 9 | self.qdmr_steps[i + 1] = steps_list[i] 10 | 11 | def get_step(self, step_num): 12 | assert step_num in self.qdmr_steps.keys() 13 | return self.qdmr_steps[step_num] 14 | 15 | def replace_step(self, step_num, step): 16 | self.qdmr_steps[step_num] = step 17 | 18 | def add_new_step(self, step_num, step): 19 | new_steps = {} 20 | new_steps[step_num] = step 21 | for i in self.qdmr_steps.keys(): 22 | orig_step = self.qdmr_steps[i] 23 | if i < step_num: 24 | new_steps[i] = orig_step 25 | elif i >= step_num: 26 | new_steps[i + 1] = self.refs_one_up(orig_step, step_num, len(self.qdmr_steps)) 27 | self.qdmr_steps = new_steps 28 | 29 | def refs_one_up(self, qdmr_text, start_idx, end_idx): 30 | target_refs_map = {} 31 | for i in range(start_idx, end_idx + 1): 32 | target_refs_map["#%s" % i] = "#%s" % (i + 1) 33 | new_qdmr_step = "" 34 | for tok in qdmr_text.split(): 35 | if tok in target_refs_map.keys(): 36 | new_qdmr_step += "%s " % target_refs_map[tok] 37 | else: 38 | new_qdmr_step += "%s " % tok 39 | return new_qdmr_step.strip() 40 | 41 | def step_type_phrases(self): 42 | qdmr_text = self.get_qdmr_text() 43 | builder = QDMRProgramBuilder(qdmr_text) 44 | builder.build() 45 | type_phrases = {} 46 | for i in range(len(builder.steps)): 47 | step = builder.steps[i] 48 | op = step.operator 49 | if op == "select": 50 | type_phrases[i + 1] = step.arguments[0] 51 | elif op == "project": 52 | ref_phrase, ref_idx = step.arguments 53 | ref_idx = int(ref_idx.replace("#", "")) 54 | ref_type = type_phrases[ref_idx] 55 | type_phrases[i + 1] = ref_phrase.replace("#REF", ref_type) 56 | elif op in ["filter", "aggregate", "superlative", 
"comparative", \ 57 | "sort", "discard", "intersection", "union"]: 58 | ref_idx = step.arguments[1] if op in ["aggregate", "superlative"] else step.arguments[0] 59 | ref_idx = int(ref_idx.replace("#", "")) 60 | type_phrases[i + 1] = type_phrases[ref_idx] 61 | else: 62 | type_phrases[i + 1] = None 63 | return type_phrases 64 | 65 | def get_step_type_phrase(self, step_num): 66 | type_phrases = self.step_type_phrases() 67 | return type_phrases[step_num] 68 | 69 | def get_qdmr_text(self): 70 | qdmr = "" 71 | for i in range(len(self.qdmr_steps)): 72 | qdmr += "return %s; " % self.qdmr_steps[i + 1] 73 | return qdmr.strip()[:-1] 74 | -------------------------------------------------------------------------------- /src/data_generation/qdmr_encoding.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import re 3 | import itertools 4 | 5 | from qdmr_grounding import extract_comparator 6 | 7 | 8 | def get_condition(ground_step_dict): 9 | return ground_step_dict["WHERE"] 10 | 11 | 12 | def get_columns(ground_step_dict): 13 | return ground_step_dict["SELECT"] 14 | 15 | 16 | def get_distinct(ground_step_dict): 17 | return "distinct , " if ground_step_dict["distinct"] else "" 18 | 19 | 20 | def get_group(ground_step_dict): 21 | return ground_step_dict["GROUP"] 22 | 23 | 24 | def get_group_select(ground_step_dict): 25 | group_select = ground_step_dict["SELECT"] 26 | col = group_select[0] 27 | # "COUNT(DISTINCT car_makers.maker)" --> DISTINCT car_makers.maker 28 | regex = re.compile(".*?\((.*?)\)") 29 | result = re.findall(regex, col) 30 | return result 31 | 32 | 33 | def get_having_clause(ground_step_dict): 34 | return ground_step_dict["HAVING"] 35 | 36 | 37 | def get_order_clause(ground_step_dict): 38 | return ground_step_dict["ORDER"] 39 | 40 | 41 | def get_order_columns(ground_step_dict): 42 | return ground_step_dict["ORDER BY"] 43 | 44 | 45 | def get_superlative_agg(step): 46 | order_clause_str = 
get_order_clause(step)[0].lower() 47 | agg = "max" if order_clause_str.startswith("desc ") else "min" 48 | return agg 49 | 50 | 51 | def get_superlative_arg_k(step): 52 | order_clause_str = get_order_clause(step)[0].lower() 53 | if "limit" in order_clause_str: 54 | k = order_clause_str.split("limit")[1].strip() 55 | return k 56 | return "1" 57 | 58 | 59 | def get_join_clause(ground_step_dict): 60 | return ground_step_dict["JOIN"] 61 | 62 | 63 | def is_reference(phrase): 64 | phrase = phrase.strip() 65 | return re.match("^#[0-9]*$", phrase) 66 | 67 | 68 | def extract_refs(phrase): 69 | refs = [] 70 | toks = [i.replace(",", "").strip() for i in phrase.split()] 71 | return list(filter(lambda x: is_reference(x), toks)) 72 | 73 | 74 | def get_arithmetic_op(arithmetic_phrase): 75 | op_map = {} 76 | op_map["sum"] = "+" 77 | op_map["difference"] = "-" 78 | op_map["multiplication"] = "*" 79 | op_map["division"] = "/" 80 | if arithmetic_phrase not in op_map.keys(): 81 | return None 82 | return op_map[arithmetic_phrase] 83 | 84 | 85 | def get_new_joined_cols(new_step, old_step): 86 | new_step_cols = list(itertools.chain.from_iterable(get_join_clause(new_step))) 87 | old_step_cols = list(itertools.chain.from_iterable(get_join_clause(old_step))) 88 | added_cols = list(filter(lambda col: col not in old_step_cols, new_step_cols)) 89 | return added_cols 90 | 91 | 92 | def get_new_joined_column(new_step, old_step): 93 | """ 94 | Return the last new column added to the join clause 95 | which is the grounded column of a filter step phrase 96 | as it is at the end of the added join path. 
97 | """ 98 | added_columns = get_new_joined_cols(new_step, old_step) 99 | if added_columns != []: 100 | return [added_columns[-1]] 101 | return added_columns 102 | 103 | 104 | def grounded_step_encoding(data, step): 105 | d = copy.deepcopy(data) 106 | op = d["op"] 107 | # check if distinct exists 108 | distinct = get_distinct(step) 109 | argument_list = d["arguments"] 110 | return "%s ( %s%s )" % (op, distinct, parse_string_list(argument_list)) 111 | 112 | 113 | def parse_string_list(string_list): 114 | if isinstance(string_list[0], (list, tuple)): 115 | new_list = [" ".join(element_list) for element_list in string_list] 116 | string_list = new_list 117 | if isinstance(string_list[-1], (list, tuple)): 118 | # conditions list 119 | string_list[-1] = format_conditions(string_list[-1]) 120 | return " , ".join(string_list) 121 | 122 | 123 | def format_conditions(conds_list): 124 | cond_strings = [] 125 | for cond_triple in conds_list: 126 | cond_strings += [" ".join(cond_triple)] 127 | return " , ".join(cond_strings) 128 | 129 | 130 | def no_reference_encoding(encoded_steps): 131 | step_phrases = {} 132 | for i in range(len(encoded_steps)): 133 | step = encoded_steps[i] 134 | ref_idxs = [int(ref.replace("#", "")) for ref in extract_refs(step)] 135 | sorted_ref_idxs = sorted(ref_idxs, key=int, reverse=True) 136 | enc_step = step 137 | for idx in sorted_ref_idxs: 138 | # go over references in desc order to avoid #1 #1x replacement issues 139 | enc_step = enc_step.replace(f"#{idx}", step_phrases[idx]) 140 | step_phrases[i + 1] = enc_step 141 | return step_phrases[len(encoded_steps)] 142 | 143 | 144 | def has_reference_encoding(encoded_steps): 145 | return " ; ".join(encoded_steps) 146 | 147 | 148 | # parse grounded steps 149 | 150 | def grounded_select(step, qdmr_args=None): 151 | data = {} 152 | data["op"] = "select" 153 | sql_cond = get_condition(step) 154 | # select column or select DB value 155 | data["arguments"] = sql_cond if len(sql_cond) > 0 else 
get_columns(step) 156 | data["string"] = grounded_step_encoding(data, step) 157 | return data 158 | 159 | 160 | def grounded_project(step, qdmr_args): 161 | data = {} 162 | data["op"] = "project" 163 | project_cols = get_columns(step) 164 | ref = qdmr_args[1] 165 | data["arguments"] = project_cols + [ref] 166 | data["string"] = grounded_step_encoding(data, step) 167 | return data 168 | 169 | 170 | def grounded_filter(step, qdmr_args, grounded_steps): 171 | data = {} 172 | data["op"] = "filter" 173 | ref, _ = qdmr_args 174 | conditions = get_condition(step) 175 | ref_step = grounded_steps[ref.replace("#", "")] 176 | prev_conditions = get_condition(ref_step) 177 | new_conds = list(filter(lambda cond: cond not in prev_conditions, conditions)) 178 | # three cases filter is a: (1) condition; (2) joined column; (3) distinct 179 | args_list = [new_conds] if len(new_conds) > 0 else get_new_joined_column(step, ref_step) 180 | data["arguments"] = [ref] + args_list 181 | data["string"] = grounded_step_encoding(data, step) 182 | return data 183 | 184 | 185 | def grounded_aggregate(step, qdmr_args): 186 | data = {} 187 | data["op"] = "aggregate" 188 | data["arguments"] = qdmr_args 189 | data["string"] = grounded_step_encoding(data, step) 190 | return data 191 | 192 | 193 | def grounded_group(step, qdmr_args): 194 | data = {} 195 | data["op"] = "group" 196 | aggregate, values, keys = qdmr_args 197 | values = [values] if is_reference(values) else get_group_select(step) 198 | keys = [keys] if is_reference(keys) else get_group(step) 199 | data["arguments"] = [aggregate] + values + keys 200 | data["string"] = grounded_step_encoding(data, step) 201 | return data 202 | 203 | 204 | def grounded_superlative(step, qdmr_args): 205 | data = {} 206 | data["op"] = "superlative" 207 | min_max = get_superlative_agg(step) 208 | keys, values = qdmr_args[1:] 209 | arg_k = get_superlative_arg_k(step) 210 | data["arguments"] = [min_max, keys, values, arg_k] 211 | data["string"] = 
grounded_step_encoding(data, step)
    return data


def grounded_comparative(step, qdmr_args, grounded_steps):
    """Encode a grounded QDMR ``comparative`` step.

    Keeps only the conditions that are *new* in this step (i.e. not already
    present in the referenced keys/values steps). When the compared value is
    itself a step reference, it is substituted into the condition triples.
    """
    data = {}
    data["op"] = "comparative"
    keys, values, condition = qdmr_args
    # two cases: (1) value is string grounded in column; (2) value is subquery
    # (1) extract the new condition (col, comp, val) from the grounded conditions clause
    conditions = get_condition(step)
    # references look like "#3"; the grounded steps map is keyed by the bare number
    keys_step = grounded_steps[keys.replace("#", "")]
    values_step = grounded_steps[values.replace("#", "")]
    prev_conditions = get_condition(keys_step) + get_condition(values_step)
    new_conds = list(filter(lambda cond: cond not in prev_conditions, conditions))
    # (2) extract the comparator and referenced subquery from the arguments
    comparator, value = extract_comparator(condition)  # NOTE(review): `comparator` is unused here
    if is_reference(value):
        # the compared value is a subquery reference - substitute it as each condition's value
        new_conds = [[cond[0], cond[1], value] for cond in new_conds]
    data["arguments"] = [keys, values, new_conds]
    data["string"] = grounded_step_encoding(data, step)
    return data


def grounded_comparative_group(step, qdmr_args):
    """Encode a grounded ``comparative_group`` step; its condition is the step's HAVING clause."""
    data = {}
    data["op"] = "comparative_group"
    # NOTE(review): `condition` is unused - the condition is re-derived from the grounded SQL
    keys, values, condition = qdmr_args
    data["arguments"] = [keys, values, get_having_clause(step)]
    data["string"] = grounded_step_encoding(data, step)
    return data


def grounded_intersection(step, qdmr_args):
    """Encode a grounded ``intersection`` step.

    The first QDMR argument (the projection phrase) is replaced by the step's
    grounded columns; the remaining arguments are the intersected step refs.
    """
    data = {}
    data["op"] = "intersection"
    projection = get_columns(step)
    refs = qdmr_args[1:]
    data["arguments"] = projection + refs
    data["string"] = grounded_step_encoding(data, step)
    return data


def grounded_union_column(step, qdmr_args):
    """Encode a grounded ``union_column`` step; arguments pass through unchanged."""
    data = {}
    data["op"] = "union_column"
    data["arguments"] = qdmr_args
    data["string"] = grounded_step_encoding(data, step)
    return data


def grounded_union(step, qdmr_args):
    """Encode a grounded ``union`` step; arguments pass through unchanged."""
    data = {}
    data["op"] = "union"
    data["arguments"] = qdmr_args
    data["string"] = grounded_step_encoding(data, step)
    return data


def grounded_discard(step, qdmr_args):
    """Encode a grounded ``discard`` step.

    Each of the two operands is either a step reference (kept as-is) or a
    phrase, in which case the step's grounded columns are used instead.
    """
    data = {}
    data["op"] = "discard"
    input_orig, input_discarded = qdmr_args
    orig = [input_orig] if is_reference(input_orig) else get_columns(step)
    discarded = [input_discarded] if is_reference(input_discarded) else get_columns(step)
    data["arguments"] = orig + discarded
    data["string"] = grounded_step_encoding(data, step)
    return data


def grounded_sort(step, qdmr_args, grounded_steps, encoding):
    """Encode a grounded ``sort`` step.

    When the ORDER BY column is itself a SQL subquery, the already *encoded*
    string of the earlier step producing that subquery is substituted for it
    (looked up in the `encoding` list built so far).
    """
    def order_columns_is_subquery(clause):
        # heuristic: a full SELECT ... FROM ... statement rather than a plain column
        return "SELECT " in clause and " FROM " in clause

    def sql_step_id(sql, grnd_steps):
        # 0-based index of the grounded step whose SQL contains `sql`
        for num in grnd_steps:
            if sql in grnd_steps[num]["SQL"]:
                return int(num) - 1
        return None

    def get_grounded_enc_subquery(sql, grnd_steps, encoding_list):
        # NOTE(review): assumes a matching step exists; a None step_id would raise here
        step_id = sql_step_id(sql, grnd_steps)
        return encoding_list[step_id]["string"]

    data = {}
    data["op"] = "sort"
    results_phrase = qdmr_args[0]
    results_ref = extract_refs(results_phrase)[0]
    order_cols = get_order_columns(step)
    if order_columns_is_subquery(order_cols[0]):
        # order column is a sql subquery, extract the *encoded* subquery
        order_cols = [get_grounded_enc_subquery(order_cols[0], grounded_steps, encoding)]
    order = get_order_clause(step)
    data["arguments"] = [results_ref] + order_cols + order
    data["string"] = grounded_step_encoding(data, step)
    return data


def grounded_arithmetic(step, qdmr_args):
    """Encode a grounded ``arithmetic`` step as (operator, ref, ref)."""
    data = {}
    data["op"] = "arithmetic"
    arith, ref1, ref2 = qdmr_args
    arith_op = get_arithmetic_op(arith)
    assert arith_op is not None
    data["arguments"] = [arith_op, ref1, ref2]
    data["string"] = grounded_step_encoding(data, step)
    return data


def encode_qdmr_steps(qdmr_steps, grounded_steps):
    """Encode every grounded QDMR step and return the list of encoded strings.

    `grounded_steps` maps 1-based step numbers (as strings) to grounded step
    dicts; `qdmr_steps` is the parallel 0-based list of parsed QDMR steps.
    Each operator is dispatched to its dedicated encoder above.
    """
    encoding = []
    for num in grounded_steps:
        step = grounded_steps[num]
        operator = step["op"]
        sql = step["SQL"]  # NOTE(review): unused in this function
        qdmr_step = qdmr_steps[int(num) - 1]
        if operator == "select":
            encoding += [grounded_select(step)]
        if operator == "project":
            encoding += [grounded_project(step, qdmr_step.arguments)]
        if operator == "filter":
            encoding += [grounded_filter(step, qdmr_step.arguments, grounded_steps)]
        if operator == "aggregate":
            encoding += [grounded_aggregate(step, qdmr_step.arguments)]
        if operator == "group":
            encoding += [grounded_group(step, qdmr_step.arguments)]
        if operator == "superlative" or operator == "superlative_group":
            encoding += [grounded_superlative(step, qdmr_step.arguments)]
        if operator == "comparative":
            encoding += [grounded_comparative(step, qdmr_step.arguments, grounded_steps)]
        if operator == "comparative_group":
            encoding += [grounded_comparative_group(step, qdmr_step.arguments)]
        if operator == "intersection":
            encoding += [grounded_intersection(step, qdmr_step.arguments)]
        if operator == "union_column":
            encoding += [grounded_union_column(step, qdmr_step.arguments)]
        if operator == "union":
            encoding += [grounded_union(step, qdmr_step.arguments)]
        if operator == "discard":
            encoding += [grounded_discard(step, qdmr_step.arguments)]
        if operator == "sort":
            encoding += [grounded_sort(step, qdmr_step.arguments, grounded_steps, encoding)]
        if operator == "arithmetic":
            encoding += [grounded_arithmetic(step, qdmr_step.arguments)]
    encoded_strs = [enc_step["string"] for enc_step in encoding]
    return encoded_strs
--------------------------------------------------------------------------------
/src/data_generation/qdmr_encoding_parser.py:
--------------------------------------------------------------------------------
import pyparsing as pp
from collections import namedtuple

from qdmr_encoding import is_reference

QDMR_STEP_DELIMITER = ";"

# pyparsing grammar for the nested "formula" (reference-free) QDMR encoding,
# e.g.: project ( table2.column2 , select ( table.column ) )
op_list = ["select", "project", "filter", "aggregate", "group", "superlative", "comparative",
           "comparative_group", "intersection", "union_column", "union", "discard", "sort", "arithmetic"]
comparators = ["=", ">", "<", ">=", "<=", "!=", "LIKE", "like", "BETWEEN", "start", "end"]
aggregates = ["COUNT", "SUM", "AVG", "MIN", "MAX", "count", "sum", "avg", "min", "max"]
arithmetics = ["+", "-", "*", "/"]
OP = pp.oneOf(op_list)
COMP = pp.oneOf(comparators)
AGGR = pp.oneOf(aggregates)
ARITHMETIC = pp.oneOf(arithmetics)
LP = pp.Literal("(").suppress()
RP = pp.Literal(")").suppress()
COMMA = pp.Literal(",").suppress()
# bare tokens: column names, numbers, wildcards etc.
String = pp.Word(pp.alphanums + "_" + "-" + "." + "%" + "*" + "/")
SingleQuoteString = pp.QuotedString(quoteChar="'", unquoteResults=False)
DoubleQuoteString = pp.QuotedString(quoteChar='"', unquoteResults=False)
QuotedString = SingleQuoteString | DoubleQuoteString
# a condition's left-hand side: either AGGR(col) or a plain column token
ConditionPrefix = AGGR + pp.Literal("(") + String + pp.Literal(")") | String
BetweenValue = pp.Word(pp.alphanums) + pp.Literal("AND") + pp.Word(pp.alphanums)  # NOTE(review): defined but unused
BasicCondition = pp.Group(ConditionPrefix + COMP + pp.OneOrMore(String))
Atom = BasicCondition | ConditionPrefix | QuotedString | ARITHMETIC
SExpr = pp.Forward()
# a condition whose value is itself a nested formula (subquery)
FormulaCondition = ConditionPrefix + COMP + SExpr
SExprList = pp.Group((FormulaCondition | SExpr | Atom) + pp.ZeroOrMore(COMMA + (FormulaCondition | SExpr | Atom)))
SExpr << (OP + LP + SExprList + RP)

Node = namedtuple("Node", ["operator", "arguments"])


def parseAction(string, location, tokens):
    # pyparsing parse action: wrap every matched s-expression in a Node
    return Node(operator=tokens[0], arguments=tokens[1:])


SExpr.setParseAction(parseAction)


def pprint(node, tab=""):
    """Pretty-print a parsed formula tree (debugging aid)."""
    print(tab + u"|--" + str(node.operator))
    new_tab = tab + " "
    for arg in node.arguments[0]:
        if isinstance(arg, Node):
            pprint(arg, new_tab)
        else:
            print(new_tab + arg)


def formula_dfs(node, stack):
    """Serialize `node` back into formula text, depth-first.

    Returns (s, stack) where `s` is this node's serialized formula and
    `stack` accumulates every sub-formula in post-order (innermost first).
    """
    s = "%s ( " % str(node.operator)
    space = " "
    for i in range(len(node.arguments[0])):
        arg = node.arguments[0][i]
        comma = "" if i == 0 else " , "
        if isinstance(arg, Node):
            last_token = s[-1]
            # handle case where argument is formula value of a condition
            delimiter = comma if last_token not in comparators else space
            s += delimiter + str(formula_dfs(arg, stack)[0])
        elif isinstance(arg, pp.ParseResults):
            s += comma + ' '.join(arg)  # argument is a simple condition list
        else:
            # handle case where argument is the comparator of a formula condition
            s += comma + str(arg) if arg not in comparators else space + str(arg)
    s += " )"
    stack += [s]
    return s, stack


def dfs_ref_substitution(dfs_qdmr_steps):
    """Turn post-order sub-formulas into reference-based QDMR steps.

    Each earlier sub-formula occurring inside a later one is replaced by its
    step reference (#1, #2, ...); steps that are themselves bare references
    are filtered out of the result.
    """
    def remove_references(steps_list):
        return list(filter(lambda x: not is_reference(x), steps_list))

    ret_steps = []
    for i in range(len(dfs_qdmr_steps)):
        # the reference this step will get, numbered over non-reference steps only
        next_ref = "#%s" % (len(ret_steps) + 1)
        next_step = dfs_qdmr_steps[i]
        new_steps = []
        if not is_reference(next_step):
            # substitute this step's full formula with its reference in all later steps
            for step in dfs_qdmr_steps[i + 1:]:
                new_steps += [step.replace(next_step, next_ref)]
            dfs_qdmr_steps = dfs_qdmr_steps[:i + 1] + new_steps
            ret_steps += [next_step]
    return remove_references(dfs_qdmr_steps)


def formula_qdmr_to_ref_steps(qdmr_formula_encoding):
    """Parse a formula encoding and return its list of reference-based steps."""
    parsed = SExpr.parseString(qdmr_formula_encoding)
    dfs_steps = formula_dfs(parsed[0], [])[1]
    return dfs_ref_substitution(dfs_steps)


def formula_to_ref_encoding(qdmr_formula_encoding):
    """Convert a formula encoding to a ';'-delimited reference-step encoding."""
    ref_steps = formula_qdmr_to_ref_steps(qdmr_formula_encoding)
    delim = " %s " % QDMR_STEP_DELIMITER
    return delim.join(ref_steps)
--------------------------------------------------------------------------------
/src/data_generation/qdmr_identifier.py:
--------------------------------------------------------------------------------
from operator_identifier import *
from utils import *


class QDMRStep:
    """A single QDMR step: raw text, identified operator and extracted arguments."""

    def __init__(self, step_text, operator, arguments):
        self.step = step_text
        self.operator = operator
        self.arguments = arguments

    def __str__(self):
        # %a formats the arguments with ascii()
        return "%s%a" % (self.operator.upper(), self.arguments)


class StepIdentifier(object):
    """Identifies a QDMR step's operator and extracts that operator's arguments."""

    def __init__(self):
        # one identifier instance per QDMR operator type
        self.identifiers = {"select": IdentifyOperatorSelect(),
                            "filter": IdentifyOperatorFilter(),
                            "project": IdentifyOperatorProject(),
                            "aggregate": IdentifyOperatorAggregate(),
                            "group": IdentifyOperatorGroup(),
                            "superlative": IdentifyOperatorSuperlative(),
                            "comparative": IdentifyOperatorComparative(),
                            "union": IdentifyOperatorUnion(),
                            "intersection": IdentifyOperatorIntersect(),
                            "discard": IdentifyOperatorDiscard(),
                            "sort": IdentifyOperatorSort(),
                            "boolean": IdentifyOperatorBoolean(),
                            "arithmetic": IdentifyOperatorArithmetic(),
                            "comparison": IdentifyOperatorComparison()}
        self.operator = None

    def step_type(self, step_text):
        """Return the operator name of `step_text`, or None if no identifier matches.

        When several identifiers match, ties are broken by the fixed priority
        rules in the loop below.
        """
        potential_operators = set()
        for op in self.identifiers:
            identifier = self.identifiers[op]
            if identifier.identify_op(step_text):
                potential_operators.add(op)
        # no matching operator found
        if len(potential_operators) == 0:
            return None
        operators = potential_operators.copy()
        # duplicate candidates
        while len(operators) > 1:
            # avoid project duplicity with aggregate
            if "project" in operators:
                operators.remove("project")
            # avoid filter duplcitiy with comparative, superlative, sort, discard
            elif "filter" in operators:
                operators.remove("filter")
            # return boolean (instead of intersect)
            elif "boolean" in operators:
                operators = {"boolean"}
            # return intersect (instead of filter)
            elif "intersect" in operators:
                # NOTE(review): likely dead branch - the identifiers dict is keyed
                # by "intersection", so "intersect" can never appear; confirm
                operators = {"intersect"}
            # return superlative (instead of comparative)
            elif "superlative" in operators:
                operators = {"superlative"}
            # return group (instead of arithmetic)
            elif "group" in operators:
                operators = {"group"}
            # return comparative (instead of discard)
            elif "comparative" in operators:
                operators = {"comparative"}
            # return intersection (instead of comparison)
            elif "intersection" in operators:
                operators = {"intersection"}
            else:
                # no valid operator
                # NOTE(review): len(operators) > 1 still holds here, so this
                # assert always fails - it ends the loop via AssertionError
                assert (len(operators) == 1)
        operator = list(operators)[0]
        self.operator = operator
        return operator

    def step_args(self, step_text):
        """Extract the operator arguments of `step_text` (also sets self.operator)."""
        self.operator = self.step_type(step_text)
        identifier = self.identifiers[self.operator]
        args = identifier.extract_args(step_text)
        return args

    def identify(self, step_text):
        """Return a QDMRStep with the step's identified operator and arguments."""
        self.operator = self.step_type(step_text)
        args = self.step_args(step_text)
        return QDMRStep(step_text, self.operator, args)


class QDMRProgramBuilder(object):
    """Builds the typed step/operator lists of a full QDMR decomposition text."""

    def __init__(self, qdmr_text):
        self.qdmr_text = qdmr_text
        self.steps = None
        self.operators = None
        self.program = None

    def build(self):
        """Populate both self.operators and self.steps."""
        self.get_operators()
        self.build_steps()
        return True

    def build_steps(self):
        """Identify each parsed step; unidentifiable steps are recorded as None."""
        self.steps = []
        steps = parse_decomposition(self.qdmr_text)
        step_identifier = StepIdentifier()
        for step_text in steps:
            try:
                step = step_identifier.identify(step_text)
            except:
                # best-effort: log and keep going, recording the failure as None
                print("Unable to identify step: %s" % step_text)
                step = None
            finally:
                self.steps += [step]
        return self.steps

    def get_operators(self):
        """Identify each step's operator; unidentifiable steps yield None."""
        self.operators = []
        steps = parse_decomposition(self.qdmr_text)
        step_identifier = StepIdentifier()
        for step_text in steps:
            try:
                op = step_identifier.step_type(step_text)
            except:
                # best-effort: log and keep going, recording the failure as None
                print("Unable to identify operator: %s" % step_text)
                op = None
            finally:
                self.operators += [op]
        return self.operators

    def build_program(self):
        raise NotImplementedError
        return True  # NOTE(review): unreachable after the raise
--------------------------------------------------------------------------------
/src/data_generation/schema_parser.py:
--------------------------------------------------------------------------------
# DB schema abstraction

# A sqlite3 schema parser

import sqlite3
import traceback
import sys


class SchemaParser(object):
    """Parses a sqlite3 database file into a Spider-style schema dict."""

    def __init__(self):
        self.path = None

    def parse(self, schema_path, name):
        """Return a dict of table/column names, column types and key indices.

        Column indices follow the Spider convention: index 0 is the '*'
        pseudo-column, real columns start at index 1.
        """
        self.path = schema_path
        parsed_data = {'db_id': name,
                       'table_names_original': [],
                       'table_names': [],
                       'column_names_original': [(-1, '*')],
                       'column_names': [(-1, '*')],
                       'column_types': ['text'],
                       'primary_keys': [],
                       'foreign_keys': []}

        conn = sqlite3.connect(self.path)
        conn.execute('pragma foreign_keys=ON')
        cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table';")

        fk_holder = []
        for i, item in enumerate(cursor.fetchall()):
            table_name = item[0]
            parsed_data['table_names_original'].append(table_name)
            parsed_data['table_names'].append(table_name.lower().replace("_", ' '))
            fks = conn.execute("PRAGMA foreign_key_list('{}') ".format(table_name)).fetchall()
            # print("db:{} table:{} fks:{}".format(f,table_name,fks))
            # PRAGMA foreign_key_list rows: fk[2]=referenced table, fk[3]=from col, fk[4]=to col
            fk_holder.extend([[(table_name, fk[3]), (fk[2], fk[4])] for fk in fks])
            cur = conn.execute("PRAGMA table_info('{}') ".format(table_name))
            for j, col in enumerate(cur.fetchall()):
                parsed_data['column_names_original'].append((i, col[1]))
                parsed_data['column_names'].append((i, col[1].lower().replace("_", " ")))
                # varchar, '' -> text, int, numeric -> integer,
                col_type = col[2].lower()
                if 'char' in col_type or col_type == '' or 'text' in col_type or 'var' in col_type:
                    parsed_data['column_types'].append('text')
                elif 'int' in col_type or 'numeric' in col_type or 'decimal' in col_type or 'number' in col_type \
                        or 'id' in col_type or 'real' in col_type or 'double' in col_type or 'float' in col_type:
                    parsed_data['column_types'].append('number')
                elif 'date' in col_type or 'time' in col_type or 'year' in col_type:
                    parsed_data['column_types'].append('time')
                elif 'boolean' in col_type:
                    parsed_data['column_types'].append('boolean')
                else:
                    parsed_data['column_types'].append('others')

                if col[5] == 1:
                    # col[5] is PRAGMA table_info's primary-key flag
                    parsed_data['primary_keys'].append(len(parsed_data['column_names']) - 1)

        parsed_data["foreign_keys"] = fk_holder
        parsed_data['foreign_keys'] = self.convert_fk_index(parsed_data)
        return parsed_data

    def convert_fk_index(self, data):
        """Convert (table, column)-name foreign-key pairs into column-index pairs."""
        fk_holder = []
        for fk in data["foreign_keys"]:
            tn, col, ref_tn, ref_col = fk[0][0], fk[0][1], fk[1][0], fk[1][1]
            ref_cid, cid = None, None
            try:
                tid = data['table_names_original'].index(tn)
                ref_tid = data['table_names_original'].index(ref_tn)

                for i, (tab_id, col_org) in enumerate(data['column_names_original']):
                    if tab_id == ref_tid and ref_col == col_org:
                        ref_cid = i
                    elif tid == tab_id and col == col_org:
                        cid = i
                # NOTE(review): truthiness test would drop a legitimate index 0;
                # harmless only because index 0 is always the '*' pseudo-column
                if ref_cid and cid:
                    fk_holder.append([cid, ref_cid])
            except:
                traceback.print_exc()
                print("table_names_original: ", data['table_names_original'])
                print("finding tab name: ", tn, ref_tn)
                sys.exit()
        return fk_holder
--------------------------------------------------------------------------------
/src/data_generation/sql_execution.py:
--------------------------------------------------------------------------------
import sqlite3
import time
import threading
import sys
from collections import OrderedDict
from wrapt_timeout_decorator import *


# NOTE(review): the first assignment is immediately overridden
TIMEOUT = 60
TIMEOUT = 90  # 90 seconds for Academic, IMDB & Yelp


def interrupt_sqlite(connection):
    # abort a long-running sqlite query on its open connection
    print('Interrupted sqlite connection', file=sys.stderr)
    connection.interrupt()


@timeout(dec_timeout=TIMEOUT, use_signals=False)
def execute_sql(db, sql):
    """
    Returns a list of tuple that are the query results

    Parameters
    ----------
    db : str
        Full path to DB schema
    sql : str
        SQL query to be executed


    Returns
    -------
    list
        List of tuple that are the query results
    """
    conn = sqlite3.connect(db)
    # ignore undecodable bytes in text columns
    conn.text_factory = lambda b: b.decode(errors='ignore')
    c = conn.cursor()
    try:
        c.execute(sql)
    except:
        # NOTE(review): bare except maps any SQL error to None; the
        # connection is never closed on either path
        return None
    return c.fetchall()


def normalize_tuple(tup):
    # cast all tuple values to strings
    norm_vars = [str(var) for var in tup]
    return tuple(norm_vars)


def normalize_denotation(denotation_list, distinct=None):
    """Normalize query results for order-insensitive comparison."""
    if not denotation_list:
        return denotation_list
    # remove duplicates
    denotation_list = list(OrderedDict.fromkeys(denotation_list)) if distinct else denotation_list
    # sort within each tuple, then sort the tuples themselves
    sort_tuples = [sorted(normalize_tuple(tup)) for tup in denotation_list]
    return sorted(sort_tuples)  # sort results set


def correct_denotation(pred_sql, gold_sql, db_path, distinct=None):
    """True iff the predicted SQL's results match the gold SQL's (after normalization)."""
    gold_denotation = execute_sql(db_path, gold_sql)
    pred_denotation = execute_sql(db_path, pred_sql)
    if gold_denotation == pred_denotation:
        # unnormalized denotations
        return True
    gold_denotation_norm = normalize_denotation(gold_denotation, distinct=distinct)
    pred_denotation_norm = normalize_denotation(pred_denotation, distinct=distinct)
    return gold_denotation_norm == pred_denotation_norm
-------------------------------------------------------------------------------- /src/data_generation/sql_parser.py: -------------------------------------------------------------------------------- 1 | # A SQL query parser - for grounding evaluation 2 | import re 3 | from utils import get_table_and_column 4 | 5 | 6 | class SQLParser(object): 7 | 8 | 9 | def __init__(self): 10 | self.tables = None 11 | self.columns = None 12 | 13 | 14 | def parse(self, query, schema): 15 | query_tables = set() 16 | query_columns = set() 17 | query = query.lower() 18 | tokens = query.split() 19 | for tok in tokens: 20 | if re.match(r't[0-9]\.', tok): 21 | query_columns.add(tok) 22 | # get tables in sql query 23 | from_clause = query.split('from')[1].split('where')[0] 24 | from_tokens = from_clause.split() 25 | schema_tables = schema.tables() 26 | schema_tables_lowercase = [name.lower() for name in schema_tables] 27 | for tok in from_tokens: 28 | if tok in schema_tables_lowercase: 29 | query_tables.add(tok) 30 | self.tables = list(query_tables) 31 | columns = list(query_columns) 32 | if len(self.tables) == 1: 33 | # all columns in query belong to a single table 34 | table_name = self.tables[0] 35 | for tok in tokens: 36 | # parse token from 'op()' and 'table.col' 37 | tok = tok.split('(')[1] if '(' in tok else tok 38 | tok = tok.split(')')[0] if ')' in tok else tok 39 | tok = tok.split('.')[1] if '.' 
in tok else tok 40 | schema_columns = schema.columns() 41 | for col in schema_columns: 42 | if col == tok: 43 | col_full_name = "%s.%s" % (table_name, col) 44 | query_columns.add(col_full_name) 45 | self.columns = list(query_columns) 46 | return True 47 | # more than one table in query 48 | # replace column table alias T1.col --> table_name.col 49 | aliases = re.findall(r'as\st[0-9]', query) 50 | alias_map = {} 51 | for alias in aliases: 52 | table_alias = alias.split()[-1] 53 | prefix = query.split(alias)[0] 54 | table_name = prefix.split()[-1] 55 | alias_map[table_alias] = table_name 56 | self.columns = [] 57 | for col in columns: 58 | for alias in alias_map.keys(): 59 | if alias in col: 60 | column_full = col.replace(alias, alias_map[alias]) 61 | self.columns += [column_full] 62 | self.columns = list(set(self.columns)) 63 | return True 64 | 65 | 66 | def get_table_aliases(self, query): 67 | """Returns map from table alias (t#) to its name""" 68 | query = query.lower() 69 | aliases = re.findall(r'as\st[0-9]', query) 70 | alias_map = {} 71 | for alias in aliases: 72 | table_alias = alias.split()[-1] 73 | prefix = query.split(alias)[0] 74 | table_name = prefix.split()[-1] 75 | alias_map[table_alias] = table_name 76 | # map from table alias (e.g. 
t1) to its name 77 | return alias_map 78 | 79 | 80 | def extract_values(self, query, schema): 81 | query = query.lower() 82 | value_to_col = {} 83 | # Find all values based on string delimiters 84 | single_paren_vals = [item.group(0) for item in re.finditer(r'\'.*?\'', query)] 85 | double_paren_vals = [item.group(0) for item in re.finditer(r'\".*?\"', query)] 86 | number_vals = [item.group(0) for item in re.finditer(r'[0-9]+', query)] 87 | # filter numbers in table aliases e.g., 1 in T1 88 | number_vals = list(filter(lambda x: (" %s" % x) in query, number_vals)) 89 | vals = single_paren_vals + double_paren_vals + number_vals 90 | # Map values to corresponding columns 91 | for value in vals: 92 | # SQL satement will be: "table.col operator value", e.g.: 93 | # T2.allergytype = "food" 94 | # name LIKE '%Led%' 95 | table = None 96 | prefix = query.split(value)[0] 97 | aliased_column = prefix.split()[-2] 98 | column_names = schema.column_names() 99 | schema_columns = schema.columns() 100 | if "." 
in aliased_column: 101 | # column is either aliased T#.col or table.col 102 | aliased_table, col = get_table_and_column(aliased_column) 103 | table = self.get_aliased_table(aliased_table, query, schema) 104 | elif aliased_column.lower() not in column_names: 105 | # nearest token is not column name 106 | # return the nearest column name instead 107 | preceding_toks = prefix.lower().split() 108 | for i in reversed(range(len(preceding_toks))): 109 | if preceding_toks[i] in column_names: 110 | aliased_column = preceding_toks[i] 111 | break 112 | else: 113 | # no aliased table in query 114 | # find nearest table to the column name 115 | col = aliased_column 116 | col_match_positions = [m.start() for m in re.finditer(col, query)] 117 | last_match_pos = col_match_positions[-1] 118 | preceding_toks = query[:last_match_pos].split() 119 | table_names = schema.tables() 120 | for i in reversed(range(len(preceding_toks))): 121 | if preceding_toks[i] in table_names: 122 | table = preceding_toks[i] 123 | full_col_name = "%s.%s" % (table, col) 124 | if full_col_name in schema_columns: 125 | # validate full column name is valid 126 | break 127 | # non-number values have parentheses 128 | value_no_paren = value[1:-1] if not value.isdigit() else value 129 | if value_no_paren.startswith("%") \ 130 | and value_no_paren.endswith("%"): 131 | # value extracted from LIKE '%%' statement 132 | value_no_paren = value_no_paren[1:-1] 133 | if table: 134 | value_to_col[value_no_paren.strip()] = "%s.%s".strip() % (table, col) 135 | return value_to_col 136 | 137 | 138 | def get_aliased_table(self, aliased_table, query, schema): 139 | """ 140 | Receive table name referenced query and retreive its actual table 141 | Handles: 142 | Spider aliases format e.g., T#.column 143 | ATIS aliases format e.g., table_#.column 144 | """ 145 | table_aliases = self.get_table_aliases(query) 146 | if re.match(r't[0-9]', aliased_table): 147 | return table_aliases[aliased_table] 148 | if re.match(r'.*\_[0-9]', 
aliased_table): 149 | # remove the '_#' suffix 150 | actual_table = '_'.join(aliased_table.split('_')[:-1]) 151 | if actual_table in schema.tables(): 152 | return actual_table 153 | return aliased_table -------------------------------------------------------------------------------- /src/data_generation/sql_query.py: -------------------------------------------------------------------------------- 1 | # SQL Query Abstraction 2 | 3 | import db_schema 4 | from qdmr_identifier import * 5 | 6 | 7 | class SQLQuery: 8 | def __init__(self, schema_path, query=None): 9 | self.schema_path = schema_path 10 | self.query = query 11 | self.results = None 12 | self.tables = None 13 | self.columns = None 14 | self.subqueries = None 15 | self.schema = db_schema.DBSchema(self.schema_path) 16 | 17 | def execute(self): 18 | results = set() 19 | conn = db_schema.sqlite3.connect(self.schema_path) 20 | c = conn.cursor() 21 | for row in c.execute(self.query): 22 | results.add(row) 23 | conn.close() 24 | self.results = results 25 | return True 26 | 27 | def add_subquery(self, subquery): 28 | if self.subqueries == None: 29 | self.subqueries = [] 30 | self.subqueries += [subquery] 31 | return True 32 | 33 | def ground(self, qdmr, question=None): 34 | if self.query is not None: 35 | return False 36 | 37 | # parse QDMR 38 | qdmr_steps = QDMRProgramBuilder(qdmr).build_steps() 39 | string_steps = [str(step) for step in qdmr_steps] 40 | print(string_steps) 41 | 42 | # add referenced subqueries 43 | # build SQL query 44 | return True 45 | -------------------------------------------------------------------------------- /src/data_generation/test_encoding_conversion.py: -------------------------------------------------------------------------------- 1 | # TODO: Test the mapping of grounded QDMR encoded as formula to the reference steps encoding 2 | # E.g.: 3 | # project ( table2.column2 , select ( table.column ) ) --> 4 | # select ( table.column ) ; project ( table2.column2 , #1 ) 5 | # Steps: 6 | # 
1. Read grounded QDMR encodings from file 7 | # 2. Convert formula (no-ref) encodings to ref steps encoding 8 | # 3. Compare the converted encoding to the original ref steps encoding 9 | # 4. Print error if the converted ref steps is different from the original 10 | # 5. Check if the different ref step encodings (original & converted) are still equivalent 11 | 12 | from tqdm import tqdm 13 | 14 | from qdmr_encoding_parser import formula_to_ref_encoding 15 | from write_encoding import load_json, write_to_json 16 | 17 | 18 | def test_enc_conversion(grounded_qdmr_file, output_file): 19 | raw_data = load_json(grounded_qdmr_file) 20 | examples = raw_data["data"] 21 | failed_examples = {} 22 | failed_examples["data"] = [] 23 | num_correct = 0 24 | for i in tqdm(range(len(examples)), desc="Loading...", ascii=False, ncols=75): 25 | example = examples[i] 26 | enc_example = {} 27 | enc_example["ex_id"] = example["example_id"] 28 | enc_example["db_name"] = example["db_id"] 29 | enc_example["question"] = example["question"] 30 | enc_example["qdmr"] = example["grounding"]["qdmr_grounding"] 31 | enc_example["sql_ground"] = example["grounding"]["grounded_sql"] 32 | enc_example["qdmr_ref_enc"] = example["grounding_enc_has_ref"] 33 | enc_example["qdmr_formula_enc"] = example["grounding_enc_no_ref"] 34 | enc_example["error"] = None 35 | try: 36 | enc_example["converted_ref_enc"] = formula_to_ref_encoding(enc_example["qdmr_formula_enc"]) 37 | except: 38 | enc_example["converted_ref_enc"] = None 39 | enc_example["error"] = "PARSE_ERROR" 40 | failed_examples["data"] += [enc_example] 41 | finally: 42 | if enc_example["error"] is None: 43 | if enc_example["converted_ref_enc"] == enc_example["qdmr_ref_enc"]: 44 | num_correct += 1 45 | else: 46 | enc_example["error"] = "CONVERSION_ERROR" 47 | failed_examples["data"] += [enc_example] 48 | write_to_json(failed_examples, output_file) 49 | num_examples = len(examples) 50 | print(f"Done writing {num_examples} examples to {output_file}.") 51 | 
print(f"Number of correctly converted formula encodings: {num_correct}/{num_examples}.") 52 | return True 53 | -------------------------------------------------------------------------------- /src/data_generation/test_grounded_qdmr.py: -------------------------------------------------------------------------------- 1 | # TODO: ensure the mapping to SQL of grounded QDMRs is accurate 2 | # 1. Read all grounded QDMRs along with their SQL queries for: Spider train, dev & Geo 3 | # 2. Map grounded QDMRs to SQL 4 | # 3. Compare execution results of grounded SQL versus mapped SQL and debug errors 5 | # 3.1. Potential self-join issues 6 | # 4. Deal with the conversion of parentheses QDMR encoding to reference-based encoding 7 | import random 8 | 9 | from tqdm import tqdm 10 | from db_schema import DBSchema 11 | from grounded_qdmr import GroundedQDMR 12 | from preprocess_db import prepare_db_schema 13 | from qdmr_encoding_parser import formula_to_ref_encoding 14 | from sql_execution import correct_denotation 15 | from write_encoding import encoded_grounded_qdmr, load_json, write_to_json 16 | 17 | 18 | # 0. 
Create the grounded QDMR encodings 19 | 20 | def create_qdmr_encodings(): 21 | # grounded_qdmr_file = "data/qdmr_grounding/qdmr_groundings_unlimited_spider_train.json" 22 | # encoded_grounded_qdmr(grounded_qdmr_file, "data/qdmr_grounding/qdmr_ground_enc_spider_train.json") 23 | # grounded_qdmr_file = "data/qdmr_grounding/groundings_spider_dev.json" 24 | # encoded_grounded_qdmr(grounded_qdmr_file, "data/qdmr_grounding/qdmr_ground_enc_spider_dev.json") 25 | # grounded_qdmr_file = "data/qdmr_grounding/groundings_geo880.json" 26 | # encoded_grounded_qdmr(grounded_qdmr_file, "data/qdmr_grounding/qdmr_ground_enc_geo880.json") 27 | # grounded_qdmr_file = "data/qdmr_grounding/groundings_predicted_spider_dev.json" 28 | # encoded_grounded_qdmr(grounded_qdmr_file, "data/qdmr_grounding/qdmr_ground_enc_predicted_spider_dev.json") 29 | # grounded_qdmr_file = "data/qdmr_grounding/groundings_predicted_spider_train_40_db.json" 30 | # encoded_grounded_qdmr(grounded_qdmr_file, "data/qdmr_grounding/qdmr_ground_enc_predicted_spider_train_40_db.json") 31 | # grounded_qdmr_file = "data/qdmr_grounding/groundings_predicted_spider_train_40_db_02_V2.json" 32 | # encoded_grounded_qdmr(grounded_qdmr_file, "data/qdmr_grounding/qdmr_ground_enc_predicted_spider_train_40_db_V2.json") 33 | # grounded_qdmr_file = "data/qdmr_grounding/groundings_predicted_spider_train_30_db_02.json" 34 | # encoded_grounded_qdmr(grounded_qdmr_file, "data/qdmr_grounding/qdmr_ground_enc_predicted_spider_train_30_db.json") 35 | # grounded_qdmr_file = "data/qdmr_grounding/groundings_predicted_geo880_train_zero_shot.json" 36 | # encoded_grounded_qdmr(grounded_qdmr_file, "data/qdmr_grounding/qdmr_ground_enc_predicted_geo_train_zero_shot.json") 37 | grounded_qdmr_file = "data/qdmr_grounding/spider_train_few_shot_groundings/groundings_predicted_spider_train_few_shot_full.json" 38 | encoded_grounded_qdmr(grounded_qdmr_file, 
"data/qdmr_grounding/spider_train_few_shot_groundings/qdmr_ground_enc_predicted_spider_train_few_shot.json") 39 | 40 | 41 | # 1. Convert map encoding to SQL and compare execution with grounded SQL 42 | 43 | def encoding_to_sql(has_ref_encoding, question, db_dir, db_name, dataset=None): 44 | schema_path = "%s/%s/%s.sqlite" % (db_dir, db_name, db_name) 45 | # add missing join paths for non-Spider DB schema 46 | dataset = "spider" if dataset in ["spider", None] else dataset 47 | schema = prepare_db_schema(schema_path, dataset=dataset) 48 | grounded_qdmr = GroundedQDMR(has_ref_encoding, question, schema) 49 | grounded_qdmr.to_sql() 50 | n = str(len(grounded_qdmr.sql_steps)) 51 | return grounded_qdmr.sql_steps[n]["SQL"] 52 | 53 | 54 | def evaluate_grounded_qdmr(grounded_qdmr_file, db_dir, output_file, encoding=None, dataset=None): 55 | assert encoding in ["no_ref", "has_ref", None] 56 | raw_data = load_json(grounded_qdmr_file) 57 | examples = raw_data["data"] 58 | enc_data = {} 59 | enc_data["data"] = [] 60 | num_correct = 0 61 | for i in tqdm(range(len(examples)), desc="Loading...", ascii=False, ncols=75): 62 | example = examples[i] 63 | enc_example = {} 64 | enc_example["ex_id"] = example["example_id"] 65 | enc_example["db_name"] = example["db_id"] 66 | enc_example["question"] = example["question"] 67 | enc_example["qdmr"] = example["grounding"]["qdmr_grounding"] 68 | enc_example["sql_ground"] = example["grounding"]["grounded_sql"] 69 | enc_example["qdmr_ground_enc"] = example["grounding_enc_has_ref"] 70 | if encoding == "no_ref": 71 | formula_encoding = example["grounding_enc_no_ref"] 72 | enc_example["qdmr_ground_enc_original"] = formula_encoding 73 | try: 74 | enc_example["qdmr_ground_enc"] = formula_to_ref_encoding(formula_encoding) 75 | except: 76 | enc_example["qdmr_ground_enc"] = "CONVERSION_ERROR" 77 | try: 78 | enc_example["sql_enc"] = encoding_to_sql(enc_example["qdmr_ground_enc"], enc_example["question"], 79 | db_dir, enc_example["db_name"], 
dataset=dataset) 80 | except: 81 | enc_example["sql_enc"] = "ERROR" 82 | db_path = "%s/%s/%s.sqlite" % (db_dir, enc_example["db_name"], enc_example["db_name"]) 83 | if enc_example["sql_enc"] == "ERROR": 84 | denotation_flag = False 85 | else: 86 | denotation_flag = correct_denotation(enc_example["sql_enc"], enc_example["sql_ground"], db_path, 87 | distinct=None) 88 | enc_example["correct_enc_denotation"] = denotation_flag 89 | num_correct = num_correct + 1 if denotation_flag else num_correct 90 | enc_data["data"] += [enc_example] 91 | write_to_json(enc_data, output_file) 92 | num_examples = len(enc_data["data"]) 93 | print(f"Done writing {num_examples} examples to {output_file}.") 94 | print(f"Number of correct grounded enc. denotations: {num_correct}/{num_examples}.") 95 | return True 96 | 97 | 98 | create_qdmr_encodings() 99 | 100 | # Spider dataset 101 | # evaluate_grounded_qdmr(grounded_qdmr_file="data/qdmr_grounding/qdmr_ground_enc_predicted_spider_train_30_db.json", 102 | # db_dir="data/spider_databases", 103 | # output_file="data/qdmr_grounding/test_encoding_to_sql_02.json", 104 | # encoding="no_ref") 105 | 106 | # Geo880 dataset 107 | # evaluate_grounded_qdmr(grounded_qdmr_file="data/qdmr_grounding/qdmr_ground_enc_geo880.json", 108 | # db_dir="data/other_databases", 109 | # output_file="data/qdmr_grounding/test_encoding_to_sql_geo880.json", 110 | # encoding="no_ref", 111 | # dataset="geo") 112 | 113 | # Test why the encoding-to-SQL mapping returns different results between runs - join_path_chain 114 | 115 | # path1 = "data/qdmr_grounding/test_encoding_to_sql_01.json" 116 | # path2 = "data/qdmr_grounding/test_encoding_to_sql_02.json" 117 | # 118 | # data = load_json(path1) 119 | # other_data = load_json(path2) 120 | # examples = data["data"] 121 | # other_examples = other_data["data"] 122 | # 123 | # for i in range(len(examples)): 124 | # example = examples[i] 125 | # other_example = other_examples[i] 126 | # if example["correct_enc_denotation"] != 
DELIMITER = ';'
REF = '#'


def parse_decomposition(qdmr):
    """Parse a QDMR string into its ordered list of steps.

    Parameters
    ----------
    qdmr : str
        String representation of the QDMR; steps are separated by ';'.

    Returns
    -------
    list
        Ordered list of step strings, without the leading 'return' token.
    """
    # strip digit-grouping commas, e.g. 1,000 --> 1000
    for grouped_number in re.findall(r"[\d,]+[,\d]", qdmr):
        qdmr = qdmr.replace(grouped_number, grouped_number.replace(",", ""))
    # surround remaining commas with spaces so they tokenize separately
    qdmr = qdmr.replace(",", " , ")
    steps = []
    for crude_step in qdmr.split(DELIMITER):
        tokens = crude_step.split()
        # drop the leading 'return' token and re-join the remainder
        steps.append(" ".join(tokens[1:]).strip())
    return steps


def get_table_and_column(full_column_name):
    """Split a qualified column name 'table.column' into [table, column]."""
    return full_column_name.split(".")
def load_json(filepath):
    """Read a JSON file and return its parsed contents."""
    # BUGFIX(idiom): stream-parse with json.load instead of read()+json.loads
    with open(filepath, "r") as reader:
        return json.load(reader)


def write_to_json(data, json_file):
    """Serialize ``data`` to ``json_file`` as indented JSON; returns True."""
    with open(json_file, mode='w+', encoding='utf-8') as file:
        json.dump(data, file, indent=4)
    return True


def encoded_grounded_qdmr(grounded_qdmr_file, out_file):
    """Add QDMR-grounding encodings to every correctly grounded example.

    Reads the grounded examples from ``grounded_qdmr_file``, skips those with
    a wrong denotation or missing grounding, and writes each remaining example
    to ``out_file`` with two extra fields: the no-reference and has-reference
    step encodings.
    """
    raw_data = load_json(grounded_qdmr_file)
    examples = raw_data["data"]
    enc_data = {"data": []}
    for i in tqdm(range(len(examples)), desc="Loading...", ascii=False, ncols=75):
        example = examples[i]
        # skip incorrectly grounded examples
        if (example["correct_denotation"] is False or
                example["grounding"] is None):
            continue
        # NOTE: dead locals ex_id/db_id/question removed — they were read
        # from the example but never used
        qdmr = example["grounding"]["qdmr_grounding"]
        builder = QDMRProgramBuilder(qdmr)
        builder.build()
        grounded_steps = example["grounding"]["grounded_steps"]
        encoded_list = encode_qdmr_steps(builder.steps, grounded_steps)
        enc_example = example
        enc_example["grounding_enc_no_ref"] = no_reference_encoding(encoded_list)
        enc_example["grounding_enc_has_ref"] = has_reference_encoding(encoded_list)
        enc_data["data"] += [enc_example]
    write_to_json(enc_data, out_file)
    num_examples = len(enc_data["data"])
    print(f"Done writing {num_examples} examples to {out_file}.")
    return True
def read_grounding_example(example: Dict[str, str], schema_path):
    """Build a GroundingTestExample from a CSV row (a dict-like record)."""
    # example index is of the form "<DATASET>_<split>_<n>", e.g. "SPIDER_train_12"
    dataset = example['index'].split("_")[0].lower()
    return GroundingTestExample(example['index'], example['db_id'], example['question'],
                                example['qdmr'], example['query'], schema_path, dataset)


# @timeout(dec_timeout=REPAIR_TIMEOUT, use_signals=False)
def try_repair(grounding_test):
    """Attempt to repair a failed grounding; returns its success flag."""
    return grounding_test.repair()


def write_one_grounding(example: Dict[str, str], output_file, field_names,
                        json_file=None, schema_path=None):
    """Ground one example's QDMR to SQL and append the outcome to the CSV.

    Executes the gold SQL, grounds the QDMR, compares the two denotations,
    and on failure attempts an automatic grounding repair. Denotations are
    truncated to 200 rows before being written. When ``json_file`` is given
    the full grounding data is appended to it as well.

    Fixes over the original: bare ``except:`` clauses (which also swallowed
    KeyboardInterrupt/SystemExit) now catch Exception, and the unused local
    ``repaired_sql`` was removed.
    """
    grounding_test = read_grounding_example(example, schema_path)
    row_dict = {}
    row_dict['example_id'] = grounding_test.example_id
    row_dict['db_id'] = grounding_test.db_id
    row_dict['question'] = grounding_test.question
    row_dict['qdmr'] = grounding_test.qdmr
    row_dict['sql_gold'] = grounding_test.gold_sql
    gold_denotation = grounding_test.get_gold_sql_denotation()
    # slicing already clamps to the sequence length; keep None/empty as-is
    row_dict['denotation_gold'] = gold_denotation[:200] if gold_denotation \
        else gold_denotation
    try:
        grounding_test.ground_example()
        row_dict['sql_ground'] = grounding_test.grounded_sql
    except Exception:
        grounding_test.grounding_error = True
        row_dict['sql_ground'] = row_dict['denotation_ground'] = "ERROR"
        row_dict['correct_denotation'] = False
    try:
        # grounded SQL execution terminates
        if not grounding_test.grounding_error:
            grounded_denotation = grounding_test.get_grounded_sql_denotation()
            row_dict['denotation_ground'] = grounded_denotation[:200]
            row_dict['correct_denotation'] = grounding_test.correct_denotation(distinct=True)
    except Exception:
        row_dict['denotation_ground'] = "ERROR"
        row_dict['correct_denotation'] = False

    #### try to repair grounding example
    if not row_dict['correct_denotation']:
        try:
            if try_repair(grounding_test):
                row_dict['correct_denotation'] = True
                row_dict['sql_ground'] = grounding_test.grounded_sql
                grounded_denotation = grounding_test.get_grounded_sql_denotation()
                row_dict['denotation_ground'] = grounded_denotation[:200]
        except Exception:
            row_dict['denotation_ground'] = "ERROR"
            row_dict['correct_denotation'] = False
    # write results
    append_dict_as_row(output_file, row_dict, field_names)
    if json_file:
        # full grounding data
        grounding_dict = None if row_dict['denotation_ground'] == "ERROR" else grounding_test.to_dict()
        row_dict['grounding'] = None if grounding_dict is None else grounding_dict['grounding']
        update_grounding_json(json_file, row_dict)
    return True
in output_file else output_file 102 | json_file = f"{file}.json" 103 | init_grounding_json(json_file) 104 | with open(output_file, mode='w', encoding='utf-8') as csv_file: 105 | field_names = ['example_id','db_id','question','qdmr',\ 106 | 'sql_gold', 'sql_ground', 'denotation_gold', 'denotation_ground',\ 107 | 'correct_denotation'] 108 | writer = DictWriter(csv_file, fieldnames=field_names) 109 | writer.writeheader() 110 | n = 6955 111 | # random_ints = [i for i in range(n)][:500] #### Spider train split 112 | # random_ints = [i for i in range(6955, 7136)] #### Academic split 113 | # random_ints = [i for i in range(7136, 7384)] #### Geo split 114 | # random_ints = [i for i in range(7384, 11426)] #### ATIS split 115 | # random_ints = [i for i in range(11426, 11557)] #### IMDB split 116 | # random_ints = [i for i in range(11557, 11685)] #### Yelp split 117 | # random_ints = [i for i in range(11426, 11685)] #### IMDB+Yelp split 118 | # random_ints = [i for i in range(11685, 12562)] #### Geo880 split 119 | # random_ints = [i for i in range(11685, 11697)] #### Geo880 split 120 | # random_ints = [i for i in range(12562, 13589)] #### Spider dev split 121 | # random_ints = [i for i in range(13589)] #### Entire example split 122 | # random_ints = [i for i in range(11685, 13589)] #### Geo880 + Spider dev split 123 | random_ints = [i for i in range(len(grounding_examples))] # use all file examples 124 | # random.shuffle(random_ints) 125 | for i in tqdm(range(len(random_ints)), desc="Loading...", ascii=False, ncols=75): 126 | try: 127 | write_one_grounding(grounding_examples.iloc[random_ints[i]], output_file, \ 128 | field_names=field_names, json_file=json_file) 129 | except: 130 | continue 131 | # write_one_grounding(grounding_examples.iloc[12576], output_file,\ 132 | # field_names=field_names, json_file=json_file) 133 | print("Complete.") 134 | return True 135 | 136 | 137 | -------------------------------------------------------------------------------- 
DATASET_DELIMITER = ":@@@"


def remove_dataset_delimiter_from_source(source):
    """Strip the '<dataset> :@@@' prefix from a source question."""
    segments = source.split(DATASET_DELIMITER)
    return segments[1].strip()
class StringMatch(object):
    """String-based evaluation of predicted QDMR decompositions."""

    def __init__(self):
        return

    def evaluate(self, questions, gold, predict, metrics=None, prepend_dataset_name=None):
        """Score predictions against gold decompositions.

        :param questions: list
            Lower-cased source questions.
        :param gold: list
            Gold decompositions.
        :param predict: list
            Predicted decompositions.
        :param metrics: list
            Subset of metrics to choose from e.g., only "bleu"
        :param prepend_dataset_name: bool
            Whether questions carry a dataset-name prefix to strip.
        :return: dict
            Macro-averaged exact match, F1, SARI and BLEU-4 scores.
        """
        num_pairs = 0
        exact_match_total = 0.0
        f1_total = 0.0
        sari_total = 0.0
        bleu_total = 0.0
        for pred, ref, raw_question in zip(predict, gold, questions):
            p_str = format_prediction(pred)
            g_str = format_prediction(ref)
            if prepend_dataset_name:
                raw_question = remove_dataset_delimiter_from_source(raw_question)
            question = _normalize_question(raw_question)
            print("**** question: ", question)
            print("**** g_str: ", g_str)
            print("**** p_str: ", p_str)
            num_pairs += 1
            exact_match_total += _compute_exact_match(p_str, g_str)
            f1_total += _compute_f1(p_str, g_str)
            sari_total += _compute_sari(p_str, g_str, question)
            # BLEU is accumulated/averaged as a torch tensor
            bleu_total = torch.add(bleu_total, _compute_bleu(p_str, g_str))
        return {"sari_score": sari_total / num_pairs,
                "exact_match": exact_match_total / num_pairs,
                "bleu_4": torch.div(bleu_total, num_pairs),
                "f1_score": f1_total / num_pairs}
65 | """ 66 | pred = prediction.replace(" ⁇ ", "<") 67 | return pred 68 | 69 | 70 | def remove_t5_tokens(prediction): 71 | t5_special_tokens = ["", ""] 72 | for tok in t5_special_tokens: 73 | prediction = prediction.replace(tok, "") 74 | return prediction.strip() 75 | 76 | 77 | def format_prediction(prediction, no_split=None): 78 | pred = remove_t5_tokens(restore_oov(prediction)) 79 | return _normalize_question(pred) 80 | 81 | 82 | def _white_space_fix(text: str) -> str: 83 | return ' '.join(text.split()) 84 | 85 | 86 | def _lower(text: str) -> str: 87 | return text.lower() 88 | 89 | 90 | def _normalize_question(question): 91 | """Lower text and remove punctuation, articles and extra whitespace.""" 92 | model_tokens = ["", ""] 93 | for tok in model_tokens: 94 | question = question.replace(tok, "").strip() 95 | parts = [_white_space_fix((_lower(token))) for token in question.split()] 96 | parts = [part for part in parts if part.strip()] 97 | normalized = ' '.join(parts).strip() 98 | return normalized 99 | 100 | 101 | def _compute_f1(predicted, gold): 102 | predicted_bag = set(predicted.split()) 103 | gold_bag = set(gold.split()) 104 | intersection = len(gold_bag.intersection(predicted_bag)) 105 | if not predicted_bag: 106 | precision = 1.0 107 | else: 108 | precision = intersection / float(len(predicted_bag)) 109 | if not gold_bag: 110 | recall = 1.0 111 | else: 112 | recall = intersection / float(len(gold_bag)) 113 | f1 = (2 * precision * recall) / (precision + recall) if not (precision == 0.0 and recall == 0.0) else 0.0 114 | return f1 115 | 116 | 117 | def _compute_exact_match(predicted, gold): 118 | if predicted == gold: 119 | return 1.0 120 | return 0.0 121 | 122 | 123 | def _compute_bleu(predicted_text, gold_text, n_gram=None): 124 | candidate_corpus = [predicted_text.split()] 125 | references_corpus = [[gold_text.split()]] 126 | # tensor 127 | return bleu_score(candidate_corpus=candidate_corpus, 128 | references_corpus=references_corpus, 129 | max_n=4) 130 | 
def _compute_sari(predicted_text, gold_text, question):
    """Compute the SARI score of one prediction against a single gold target.

    Tokenizes by splitting on spaces and treats the question as the source.
    Only the overall SARI component is returned (keep/add/delete are dropped).
    """
    source_tokens = question.split(" ")
    predicted_tokens = predicted_text.split(" ")
    target_token_lists = [gold_text.split(" ")]
    sari, _keep, _add, _deletion = get_sari(source_tokens, predicted_tokens, target_token_lists)
    return sari[0]
29 | 30 | [1] https://github.com/cocoxu/simplification/blob/master/SARI.py 31 | (commit 0210f15) 32 | [2] https://github.com/cocoxu/simplification/issues/6 33 | """ 34 | 35 | from __future__ import absolute_import 36 | from __future__ import division 37 | from __future__ import print_function 38 | 39 | import collections 40 | 41 | import numpy as np 42 | # import tensorflow as tf 43 | 44 | # The paper that intoduces the SARI score uses only the precision of the deleted 45 | # tokens (i.e. beta=0). To give more emphasis on recall, you may set, e.g., 46 | # beta=1. 47 | BETA_FOR_SARI_DELETION_F_MEASURE = 0 48 | 49 | 50 | def _get_ngram_counter(ids, n): 51 | """Get a Counter with the ngrams of the given ID list. 52 | 53 | Args: 54 | ids: np.array or a list corresponding to a single sentence 55 | n: n-gram size 56 | 57 | Returns: 58 | collections.Counter with ID tuples as keys and 1s as values. 59 | """ 60 | # Remove zero IDs used to pad the sequence. 61 | ids = [token_id for token_id in ids if token_id != 0] 62 | ngram_list = [tuple(ids[i:i + n]) for i in range(len(ids) + 1 - n)] 63 | ngrams = set(ngram_list) 64 | counts = collections.Counter() 65 | for ngram in ngrams: 66 | counts[ngram] = 1 67 | return counts 68 | 69 | 70 | def _get_fbeta_score(true_positives, selected, relevant, beta=1): 71 | """Compute Fbeta score. 72 | 73 | Args: 74 | true_positives: Number of true positive ngrams. 75 | selected: Number of selected ngrams. 76 | relevant: Number of relevant ngrams. 77 | beta: 0 gives precision only, 1 gives F1 score, and Inf gives recall only. 78 | 79 | Returns: 80 | Fbeta score. 
81 | """ 82 | precision = 1 83 | if selected > 0: 84 | precision = true_positives / selected 85 | if beta == 0: 86 | return precision 87 | recall = 1 88 | if relevant > 0: 89 | recall = true_positives / relevant 90 | if precision > 0 and recall > 0: 91 | beta2 = beta * beta 92 | return (1 + beta2) * precision * recall / (beta2 * precision + recall) 93 | else: 94 | return 0 95 | 96 | 97 | def get_addition_score(source_counts, prediction_counts, target_counts): 98 | """Compute the addition score (Equation 4 in the paper).""" 99 | added_to_prediction_counts = prediction_counts - source_counts 100 | true_positives = sum((added_to_prediction_counts & target_counts).values()) 101 | selected = sum(added_to_prediction_counts.values()) 102 | # Note that in the paper the summation is done over all the ngrams in the 103 | # output rather than the ngrams in the following set difference. Since the 104 | # former does not make as much sense we compute the latter, which is also done 105 | # in the GitHub implementation. 
def get_keep_score(source_counts, prediction_counts, target_counts):
    """Compute the keep score (Equation 5 in the paper)."""
    kept_in_prediction = source_counts & prediction_counts
    kept_in_target = source_counts & target_counts
    tp = sum((kept_in_prediction & kept_in_target).values())
    return _get_fbeta_score(tp,
                            sum(kept_in_prediction.values()),
                            sum(kept_in_target.values()))


def get_deletion_score(source_counts, prediction_counts, target_counts, beta=0):
    """Compute the deletion score (Equation 6 in the paper)."""
    deleted_in_prediction = source_counts - prediction_counts
    deleted_in_target = source_counts - target_counts
    tp = sum((deleted_in_prediction & deleted_in_target).values())
    return _get_fbeta_score(tp,
                            sum(deleted_in_prediction.values()),
                            sum(deleted_in_target.values()),
                            beta=beta)


def get_sari_score(source_ids, prediction_ids, list_of_targets,
                   max_gram_size=4, beta_for_deletion=0):
    """Compute the SARI score for a single prediction and one or more targets.

    Args:
      source_ids: a list / np.array of SentencePiece IDs
      prediction_ids: a list / np.array of SentencePiece IDs
      list_of_targets: a list of target ID lists / np.arrays
      max_gram_size: int. largest n-gram size we care about (e.g. 3 for
        unigrams, bigrams, and trigrams)
      beta_for_deletion: beta for deletion F score.

    Returns:
      the SARI score and its three components: keep, add, and deletion scores
    """
    keep_scores = []
    add_scores = []
    del_scores = []
    for n in range(1, max_gram_size + 1):
        source_counts = _get_ngram_counter(source_ids, n)
        prediction_counts = _get_ngram_counter(prediction_ids, n)
        # ngrams appearing in any target, with count 1
        target_counts = collections.Counter()
        # ngrams weighted by r/num_targets, where r is the number of
        # non-empty targets containing the ngram
        weighted_target_counts = collections.Counter()
        num_nonempty_targets = 0
        for target_ids_i in list_of_targets:
            per_target_counts = _get_ngram_counter(target_ids_i, n)
            if per_target_counts:
                weighted_target_counts += per_target_counts
                num_nonempty_targets += 1
        for gram in weighted_target_counts.keys():
            weighted_target_counts[gram] /= num_nonempty_targets
            target_counts[gram] = 1
        keep_scores.append(get_keep_score(source_counts, prediction_counts,
                                          weighted_target_counts))
        del_scores.append(get_deletion_score(source_counts, prediction_counts,
                                             weighted_target_counts,
                                             beta_for_deletion))
        add_scores.append(get_addition_score(source_counts, prediction_counts,
                                             target_counts))

    avg_keep_score = sum(keep_scores) / max_gram_size
    avg_addition_score = sum(add_scores) / max_gram_size
    avg_deletion_score = sum(del_scores) / max_gram_size
    sari = (avg_keep_score + avg_addition_score + avg_deletion_score) / 3.0
    return sari, avg_keep_score, avg_addition_score, avg_deletion_score
184 | 185 | Args: 186 | source_ids: A 2D tf.Tensor of size (batch_size , sequence_length) 187 | prediction_ids: A 2D tf.Tensor of size (batch_size, sequence_length) 188 | target_ids: A 3D tf.Tensor of size (batch_size, number_of_targets, 189 | sequence_length) 190 | max_gram_size: int. largest n-gram size we care about (e.g. 3 for unigrams, 191 | bigrams, and trigrams) 192 | 193 | Returns: 194 | A 4-tuple of 1D float Tensors of size (batch_size) for the SARI score and 195 | the keep, addition and deletion scores. 196 | """ 197 | 198 | # def get_sari_numpy(source_ids, prediction_ids, target_ids): 199 | """Iterate over elements in the batch and call the SARI function.""" 200 | sari_scores = [] 201 | keep_scores = [] 202 | add_scores = [] 203 | deletion_scores = [] 204 | # Iterate over elements in the batch. 205 | for source_ids_i, prediction_ids_i, target_ids_i in zip( 206 | source_ids, prediction_ids, target_ids): 207 | sari, keep, add, deletion = get_sari_score( 208 | source_ids_i, prediction_ids_i, target_ids_i, max_gram_size, 209 | BETA_FOR_SARI_DELETION_F_MEASURE) 210 | sari_scores.append(sari) 211 | keep_scores.append(keep) 212 | add_scores.append(add) 213 | deletion_scores.append(deletion) 214 | return (np.asarray(sari_scores), np.asarray(keep_scores), 215 | np.asarray(add_scores), np.asarray(deletion_scores)) 216 | 217 | # sari, keep, add, deletion = tf.py_func( 218 | # get_sari_numpy, 219 | # [source_ids, prediction_ids, target_ids], 220 | # [tf.float64, tf.float64, tf.float64, tf.float64]) 221 | # return sari, keep, add, deletion -------------------------------------------------------------------------------- /src/qdmr_parser/model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | import logging 5 | import re 6 | from string import punctuation 7 | 8 | import nltk 9 | 10 | from dataset_qdmr import BreakDataset 11 | from eval_qdmr.eval_string_match import StringMatch 12 | 
13 | nltk.download('punkt') 14 | 15 | import torch 16 | from torch.utils.data import Dataset, DataLoader 17 | import pytorch_lightning as pl 18 | 19 | from transformers import ( 20 | AdamW, 21 | Adafactor, 22 | T5ForConditionalGeneration, 23 | T5TokenizerFast as T5Tokenizer, 24 | get_linear_schedule_with_warmup 25 | ) 26 | 27 | 28 | class T5FineTuner(pl.LightningModule): 29 | def __init__(self, hparams): 30 | super(T5FineTuner, self).__init__() 31 | self.save_hyperparameters(hparams) 32 | self.model = T5ForConditionalGeneration.from_pretrained(hparams.model_name_or_path) 33 | self.tokenizer = T5Tokenizer.from_pretrained(hparams.tokenizer_name_or_path) 34 | 35 | def is_logger(self): 36 | return True 37 | 38 | def forward(self, input_ids, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, 39 | labels=None): 40 | return self.model( 41 | input_ids, 42 | attention_mask=attention_mask, 43 | decoder_input_ids=decoder_input_ids, 44 | decoder_attention_mask=decoder_attention_mask, 45 | labels=labels, 46 | ) 47 | 48 | def _step(self, batch): 49 | labels = batch["target_ids"] 50 | labels[labels[:, :] == self.tokenizer.pad_token_id] = -100 51 | 52 | outputs = self( 53 | input_ids=batch["source_ids"], 54 | attention_mask=batch["source_mask"], 55 | labels=labels, 56 | decoder_attention_mask=batch['target_mask'] 57 | ) 58 | 59 | loss = outputs[0] 60 | return loss 61 | 62 | def training_step(self, batch, batch_idx): 63 | loss = self._step(batch) 64 | return {"loss": loss} 65 | 66 | def training_epoch_end(self, outputs): 67 | avg_train_loss = torch.stack([x["loss"] for x in outputs]).mean() 68 | pl_logs = {"avg_train_loss": avg_train_loss} 69 | self.log_dict(pl_logs) 70 | return {"avg_train_loss": avg_train_loss} 71 | 72 | def validation_step(self, batch, batch_idx): 73 | # record the source, targets & predictions for each batch 74 | # use these to compute evaluate on the validation set 75 | src = [self.tokenizer.decode(ids) for ids in batch['source_ids']] 76 
| target = [self.tokenizer.decode(ids) for ids in batch['target_ids']] 77 | preds = self.model.generate(input_ids=batch['source_ids'].cuda(), 78 | attention_mask=batch['source_mask'].cuda(), 79 | max_length=512) 80 | dec_preds = [self.tokenizer.decode(ids) for ids in preds] 81 | loss = self._step(batch) 82 | return {"val_loss": loss, 83 | "source": src, 84 | "target": target, 85 | "preds": dec_preds} 86 | 87 | def validation_epoch_end(self, outputs): 88 | # evaluate predictions of the entire validation set 89 | all_inputs = [] 90 | all_targets = [] 91 | all_predictions = [] 92 | for val_step_out in outputs: 93 | all_inputs.extend(val_step_out["source"]) 94 | all_targets.extend(val_step_out["target"]) 95 | all_predictions.extend(val_step_out["preds"]) 96 | results = evaluate_predictions(examples=all_inputs, 97 | gold_labels=all_targets, 98 | predictions=all_predictions, 99 | args=self.hparams, 100 | task=self.hparams.task, 101 | prepend_dataset_name=self.hparams.prepend_dataset_name) 102 | avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean() 103 | tensorboard_logs = {**{"val_loss": avg_loss}, **results} 104 | pl_logs = {"epoch": self.trainer.current_epoch, "avg_val_loss": avg_loss, 105 | "log": tensorboard_logs, 'progress_bar': tensorboard_logs} 106 | pl_logs = {**pl_logs, **results} 107 | self.log_dict(pl_logs) 108 | return {"avg_val_loss": avg_loss} 109 | 110 | 111 | def dummy_val_epoch_eval(self, inputs, targets, predictions): 112 | score = 0 113 | for i in range(len(targets)): 114 | if targets[i] == predictions[i]: 115 | score += 1 116 | return float(score) / len(targets) 117 | 118 | def configure_optimizers(self): 119 | "Prepare optimizer and schedule (linear warmup and decay)" 120 | model = self.model 121 | no_decay = ["bias", "LayerNorm.weight"] 122 | optimizer_grouped_parameters = [ 123 | { 124 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 125 | "weight_decay": self.hparams.weight_decay, 126 | }, 127 
def get_tqdm_dict(self):
    """Progress-bar fields: running loss and the current learning rate."""
    return {
        "loss": "{:.3f}".format(self.trainer.avg_loss),
        "lr": self.lr_scheduler.get_last_lr()[-1],
    }

def train_dataloader(self):
    """Build the training DataLoader and attach a linear-warmup LR schedule.

    The total number of optimization steps is derived from the dataset
    size, the per-device batch size, the GPU count and the gradient
    accumulation factor.
    """
    dataset = get_dataset(tokenizer=self.tokenizer, type_path="train", args=self.hparams)
    loader = DataLoader(
        dataset,
        batch_size=self.hparams.train_batch_size,
        drop_last=True,
        shuffle=True,
        num_workers=4,
    )
    effective_batch = self.hparams.train_batch_size * max(1, self.hparams.n_gpu)
    steps_per_epoch = (len(loader.dataset) // effective_batch) // self.hparams.gradient_accumulation_steps
    total_steps = steps_per_epoch * float(self.hparams.num_train_epochs)
    self.lr_scheduler = get_linear_schedule_with_warmup(
        self.opt,
        num_warmup_steps=self.hparams.warmup_steps,
        num_training_steps=total_steps,
    )
    return loader

def val_dataloader(self):
    """Build the validation DataLoader (fixed order, 4 worker processes)."""
    dataset = get_dataset(tokenizer=self.tokenizer, type_path="val", args=self.hparams)
    return DataLoader(dataset, batch_size=self.hparams.eval_batch_size, num_workers=4)
def evaluate_predictions(examples, gold_labels, predictions, args, is_test=None, task=None,
                         prepend_dataset_name=None):
    """Score predictions against gold labels via exact string matching.

    ``args``, ``is_test`` and ``task`` are accepted for interface
    compatibility with other evaluation hooks but are not used by the
    string-match evaluator.
    """
    return StringMatch().evaluate(
        questions=examples,
        gold=gold_labels,
        predict=predictions,
        prepend_dataset_name=prepend_dataset_name,
    )


####### Logger #######

logger = logging.getLogger(__name__)


class LoggingCallback(pl.Callback):
    """Lightning callback that mirrors metrics to the log and to text files."""

    def on_validation_end(self, trainer, pl_module):
        """Append the current validation metrics to val_results.txt."""
        logger.info("***** Validation results *****")
        if not pl_module.is_logger():
            return
        metrics = trainer.callback_metrics
        # Log and save results to file (appended, one file per run dir)
        output_test_results_file = os.path.join(pl_module.hparams.output_dir, "val_results.txt")
        with open(output_test_results_file, "a") as writer:
            for key in sorted(metrics):
                if key in ["progress_bar"]:
                    continue
                line = "{} = {}\n".format(key, str(metrics[key]))
                logger.info(line)
                writer.write(line)

    def on_test_end(self, trainer, pl_module):
        """Overwrite test_results.txt with the final test metrics."""
        logger.info("***** Test results *****")
        if not pl_module.is_logger():
            return
        metrics = trainer.callback_metrics
        # Log and save results to file (truncated each run)
        output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
        with open(output_test_results_file, "w") as writer:
            for key in sorted(metrics):
                if key in ["log", "progress_bar"]:
                    continue
                line = "{} = {}\n".format(key, str(metrics[key]))
                logger.info(line)
                writer.write(line)
# Inference script: load a fine-tuned T5 QDMR parser from a PyTorch
# Lightning checkpoint and write formatted predictions for an eval set.
if __name__ == '__main__':
    args_dict = dict(
        predictions_output_file='t5_large_qdmr_parser_spider_train.txt',
        checkpoint_path='/trained_models/t5_large_qdmr_parser_bs_2_accum_64_epochs_150_lr_1e-4/epoch=6-exact_match=0.231.ckpt',
        task='nl_to_qdmr',
        data_dir='',  # path for data files
        output_dir='',  # path to save the checkpoints
        training_set_file='train.csv',  # name of training set file in data dir
        dev_set_file='dev.csv',  # name of dev set file in data dir
        dev_set_labels='',  # name of file containing the dev qdmr/nl queries in data dir
        test_set_file='',
        test_set_labels='',
        model_name_or_path='t5-large',
        tokenizer_name_or_path='t5-large',
        max_seq_length=512,
        learning_rate=1e-4,
        weight_decay=0.0,
        adam_epsilon=1e-8,
        warmup_steps=0,
        train_batch_size=2,
        eval_batch_size=2,
        num_train_epochs=2,
        gradient_accumulation_steps=64,
        n_gpu=1,
        fp_16=False,  # if you want to enable 16-bit training then install apex and set this to true
        opt_level='O1',
        # you can find out more on optimisation levels here https://nvidia.github.io/apex/amp.html#opt-levels-and-properties
        max_grad_norm=1.0,  # if you enable 16-bit training then set this to a sensible value, 0.5 is a good default
        seed=42,
    )

    args = argparse.Namespace(**args_dict)

    # Tokenizer comes from the base model name, not from the checkpoint.
    tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path)

    print("Loading T5FinalTuner pretrained model...")
    model = T5FineTuner.load_from_checkpoint(args.checkpoint_path)
    model = model.to('cuda')  # inference requires a CUDA device
    print("Done!")

    # re-set dataset params
    # NOTE(review): only dev_set_file is overridden; every other hparam keeps
    # the value stored inside the checkpoint.
    model.hparams.dev_set_file = args.dev_set_file

    print("Loading dataset...")
    # type_path "test" is not "train", so get_dataset reads dev_set_file.
    dataset = get_dataset(tokenizer=tokenizer,
                          type_path="test",
                          args=model.hparams)

    print("Generating sequences...")
    loader = DataLoader(dataset, batch_size=model.hparams.eval_batch_size, num_workers=2)
    model.model.eval()
    outputs = []
    targets = []
    for batch in tqdm(loader):
        outs = model.model.generate(input_ids=batch['source_ids'].cuda(),
                                    attention_mask=batch['source_mask'].cuda(),
                                    max_length=512)
        dec = [tokenizer.decode(ids) for ids in outs]
        target = [tokenizer.decode(ids) for ids in batch["target_ids"]]
        outputs.extend(dec)
        targets.extend(target)  # NOTE(review): collected but never used below

    # Write one formatted prediction per line of the output file.
    with open(args.predictions_output_file, 'w') as f:
        for i, out in enumerate(outputs):
            formatted_out = format_prediction(out)
            print(formatted_out, file=f)
def set_seed(seed):
    """Seed every RNG source (python, numpy, torch CPU/GPU) for reproducibility."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def normalize_whitespace(source):
    """Collapse whitespace runs in *source* to single spaces and strip the ends."""
    tokens = source.split()
    return " ".join(tokens)


def load_json(filepath):
    """Read *filepath* and parse its contents as JSON."""
    with open(filepath, "r") as reader:
        return json.load(reader)


def read_csv_to_dictionaries(filepath):
    """Read a CSV file with a header row into a list of per-row dictionaries.

    The previous implementation kept a ``line_count`` counter guarded by the
    always-true condition ``line_count >= 0``; that dead code is removed and
    the reader is materialized directly.

    :param filepath: path to a UTF-8 encoded CSV file with a header line.
    :return: list of dicts mapping column name to cell value; empty list
        for a header-only or empty file.
    """
    with open(filepath, mode='r', encoding='utf-8') as csv_file:
        return list(csv.DictReader(csv_file))
9 | def __init__(self, tokenizer, data_file, tables_file, dataset_type, 10 | max_len=512, append_schema=None, encoding=None): 11 | self.dataset_type = dataset_type 12 | self.data_file = data_file 13 | self.tables_file = tables_file 14 | self.append_schema = append_schema 15 | 16 | self.max_len = max_len 17 | self.tokenizer = tokenizer 18 | self.inputs = [] 19 | self.targets = [] 20 | 21 | self.target_encoding = encoding 22 | if self.target_encoding is not None: 23 | assert self.target_encoding in ["qdmr_formula", "qdmr_steps", "qdmr_sql", "sql"] 24 | 25 | self._build() 26 | 27 | def __len__(self): 28 | return len(self.inputs) 29 | 30 | def __getitem__(self, index): 31 | source_ids = self.inputs[index]["input_ids"].squeeze() 32 | target_ids = self.targets[index]["input_ids"].squeeze() 33 | 34 | src_mask = self.inputs[index]["attention_mask"].squeeze() # might need to squeeze 35 | target_mask = self.targets[index]["attention_mask"].squeeze() # might need to squeeze 36 | 37 | return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids, "target_mask": target_mask} 38 | 39 | def _build(self): 40 | self._build_input_output(self.data_file, self.tables_file, self.append_schema) 41 | 42 | def _build_input_output(self, data_file, tables_file, append_schema): 43 | if append_schema: 44 | # Serialize DB schema 45 | tables_json = load_json(tables_file) 46 | db_id_to_schema_string = {} 47 | for table_json in tables_json: 48 | db_id = table_json["db_id"].lower() 49 | db_id_to_schema_string[db_id] = self._get_schema_string(table_json) 50 | # Read examples 51 | data = load_json(data_file) 52 | raw_data = data["data"] 53 | for example in raw_data: 54 | database = example["db_id"].lower() 55 | source = example["question"] 56 | # target meaning representation 57 | target = example["sql_gold"] 58 | if self.target_encoding == "qdmr_formula": 59 | target = example["grounding_enc_no_ref"] 60 | elif self.target_encoding == "qdmr_steps": 61 | target = 
example["grounding_enc_has_ref"] 62 | elif self.target_encoding == "qdmr_sql": 63 | target = example["sql_ground"] 64 | # Prepend database 65 | source = "%s: %s" % (database, source) 66 | if append_schema: 67 | schema_string = db_id_to_schema_string[database] 68 | source = "%s%s" % (source, schema_string) 69 | target = normalize_whitespace(target) 70 | source += self.tokenizer.eos_token 71 | target += self.tokenizer.eos_token 72 | input = source.lower() 73 | target = target.lower() 74 | 75 | # tokenize inputs 76 | tokenized_inputs = self.tokenizer.batch_encode_plus( 77 | [input], max_length=self.max_len, pad_to_max_length=True, return_tensors="pt" 78 | ) 79 | # tokenize targets 80 | tokenized_targets = self.tokenizer.batch_encode_plus( 81 | [target], max_length=self.max_len, pad_to_max_length=True, return_tensors="pt" 82 | ) 83 | self.inputs.append(tokenized_inputs) 84 | self.targets.append(tokenized_targets) 85 | 86 | def _get_schema_string(self, table_json): 87 | """Returns the schema serialized as a string.""" 88 | table_id_to_column_names = collections.defaultdict(list) 89 | for table_id, name in table_json["column_names_original"]: 90 | table_id_to_column_names[table_id].append(name.lower()) 91 | tables = table_json["table_names_original"] 92 | 93 | table_strings = [] 94 | for table_id, table_name in enumerate(tables): 95 | column_names = table_id_to_column_names[table_id] 96 | table_string = " | %s : %s" % (table_name.lower(), " , ".join(column_names)) 97 | table_strings.append(table_string) 98 | 99 | return "".join(table_strings) 100 | 101 | 102 | 103 | -------------------------------------------------------------------------------- /src/text_to_sql/dataset_spider.py: -------------------------------------------------------------------------------- 1 | import json 2 | import collections 3 | 4 | import torch 5 | from torch.utils.data import Dataset, DataLoader 6 | from dataset_utils import load_json, normalize_whitespace 7 | 8 | 9 | class 
SpiderDataset(Dataset): 10 | def __init__(self, tokenizer, data_file, tables_file, dataset_type, 11 | max_len=512, append_schema=None): 12 | self.dataset_type = dataset_type 13 | self.data_file = data_file 14 | self.tables_file = tables_file 15 | self.append_schema = append_schema 16 | 17 | self.max_len = max_len 18 | self.tokenizer = tokenizer 19 | self.inputs = [] 20 | self.targets = [] 21 | 22 | self._build() 23 | 24 | def __len__(self): 25 | return len(self.inputs) 26 | 27 | def __getitem__(self, index): 28 | source_ids = self.inputs[index]["input_ids"].squeeze() 29 | target_ids = self.targets[index]["input_ids"].squeeze() 30 | 31 | src_mask = self.inputs[index]["attention_mask"].squeeze() # might need to squeeze 32 | target_mask = self.targets[index]["attention_mask"].squeeze() # might need to squeeze 33 | 34 | return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids, "target_mask": target_mask} 35 | 36 | def _build(self): 37 | self._build_input_output(self.data_file, self.tables_file, self.append_schema) 38 | 39 | def _build_input_output(self, data_file, tables_file, append_schema): 40 | if append_schema: 41 | # Serialize DB schema 42 | tables_json = load_json(tables_file) 43 | db_id_to_schema_string = {} 44 | for table_json in tables_json: 45 | db_id = table_json["db_id"].lower() 46 | db_id_to_schema_string[db_id] = self._get_schema_string(table_json) 47 | # Read examples 48 | raw_data = load_json(data_file) 49 | for example in raw_data: 50 | database = example["db_id"].lower() 51 | source = example["question"] 52 | target = example["query"] 53 | # Prepend database 54 | source = "%s: %s" % (database, source) 55 | if append_schema: 56 | schema_string = db_id_to_schema_string[database] 57 | source = "%s%s" % (source, schema_string) 58 | target = normalize_whitespace(target) 59 | source += self.tokenizer.eos_token 60 | target += self.tokenizer.eos_token 61 | input = source.lower() 62 | target = target.lower() 63 | 64 | # tokenize 
def load_json(filepath):
    """Parse the JSON document stored at *filepath*."""
    with open(filepath, "r") as reader:
        return json.load(reader)


def normalize_whitespace(source):
    """Collapse all whitespace runs in *source* to single spaces."""
    return " ".join(source.split())


def write_to_json(data, json_file):
    """Pretty-print *data* as indented JSON into *json_file*; returns True."""
    with open(json_file, mode='w+', encoding='utf-8') as file:
        json.dump(data, file, indent=4)
    return True
def has_path(graph, start, end, path=None):
    """Return whether a path exists between *start* and *end* in *graph*.

    Fixes: the unused ``path`` parameter defaulted to a shared mutable
    list (``path=[]``) and the handler was a bare ``except:``; it now
    defaults to None and only the relevant networkx error is caught.

    :param graph: adjacency mapping accepted by ``nx.Graph``.
    :param path: unused; kept for backward signature compatibility.
    :return: True/False, or None when either node is missing from the graph.
    """
    G = nx.Graph(graph)
    try:
        return nx.has_path(G, start, end)
    except nx.NodeNotFound:
        return None


def find_shortest_paths(graph, start, end):
    """Return all shortest paths between *start* and *end* as a list of lists.

    Fixes: the original ``except`` branch was the bare expression ``None``
    (a no-op) and only reached ``return None`` by falling through; the
    handler now returns explicitly and catches only the networkx errors
    raised for a missing node or an unreachable target.

    :return: list of node-lists, or None when no path exists or a node is
        missing from the graph.
    """
    G = nx.Graph(graph)
    try:
        return list(nx.all_shortest_paths(G, source=start, target=end))
    except (nx.NodeNotFound, nx.NetworkXNoPath):
        return None
def remove_quotation_marks(string):
    """Strip one pair of matching single or double quotes, if present."""
    if string.startswith("'") and string.endswith("'"):
        return string[1:-1]
    if string.startswith('"') and string.endswith('"'):
        return string[1:-1]
    return string


def sql_quotation_values(sql):
    """
    Returns list of lowercase quotation value in SQL query
    """

    def remove_like_pct(string):
        # drop SQL LIKE wildcards so only the literal value remains
        return string.replace("%", "")

    query = sql.lower()
    # Find all values based on string delimiters
    single_paren_vals = [item.group(0) for item in re.finditer(r'\'.*?\'', query)]
    double_paren_vals = [item.group(0) for item in re.finditer(r'\".*?\"', query)]
    vals_list = single_paren_vals + double_paren_vals
    return [remove_quotation_marks(remove_like_pct(val)) for val in vals_list]


def val_sql_quotation(value):
    """Return the quoted forms a SQL value may appear as (plain and LIKE)."""
    return [f"'{value}'", f"'%{value}%'", f'\"{value}\"', f'\"%{value}%\"']


def sql_value_case(sql, value):
    """Return a non-numeric *value* in the casing it has in the original SQL.

    Fixed: the quoted value is now escaped with ``re.escape`` before being
    used as a regex pattern. Previously only parentheses were escaped, so
    values containing regex metacharacters (e.g. '.', '*', '+') could match
    unrelated text or raise ``re.error``.

    :param sql: original SQL query (any casing).
    :param value: lowercase value to look up.
    :return: the matched quoted value in its original casing, or None.
    """
    for quote_val in val_sql_quotation(value):
        match = re.search(re.escape(quote_val), sql, re.IGNORECASE)
        if match:
            return match.group(0)
    return None
def prepare_db_schema(path, dataset):
    """Load the DB schema at *path* for *dataset*.

    Spider schemas already declare their foreign keys and are used as-is;
    every other supported dataset gets its missing foreign keys patched in.
    """
    if dataset == "spider":
        return prepare_spider_db(path)
    return prepare_other_db(path, dataset)


def prepare_spider_db(path):
    """Load a Spider database schema unchanged."""
    return DBSchema(path)


def prepare_other_db(path, dataset):
    """Load a non-Spider DB schema and register its missing foreign keys.

    The original database dumps for these datasets omit foreign-key
    declarations, so they are added manually, table-driven, below.

    :raises ValueError: for an unknown dataset name.
    """
    # (table_a, column_a, table_b, column_b) per dataset, exactly the pairs
    # previously passed to schema.add_foreign_key one call at a time.
    missing_foreign_keys = {
        "academic": [
            ('publication', 'pid', 'writes', 'pid'),
            ('author', 'aid', 'writes', 'aid'),
            ('journal', 'jid', 'publication', 'jid'),
            ('conference', 'cid', 'publication', 'cid'),
            ('publication', 'pid', 'publication_keyword', 'pid'),
            ('keyword', 'kid', 'publication_keyword', 'kid'),
            ('author', 'oid', 'organization', 'oid'),
            ('author', 'aid', 'domain_author', 'aid'),
            ('domain', 'did', 'domain_author', 'did'),
            ('domain', 'did', 'domain_publication', 'did'),
            ('domain_publication', 'pid', 'publication', 'pid'),
            ('cite', 'cited', 'publication', 'pid'),
            ('cite', 'citing', 'publication', 'pid'),
            ('domain', 'did', 'domain_keyword', 'did'),
            ('domain_keyword', 'kid', 'keyword', 'kid'),
            ('domain', 'did', 'domain_journal', 'did'),
            ('domain_journal', 'jid', 'journal', 'jid'),
            ('conference', 'cid', 'domain_conference', 'cid'),
            ('domain', 'did', 'domain_conference', 'did'),
        ],
        "atis": [
            ('airport_service', 'city_code', 'city', 'city_code'),
            ('airport_service', 'airport_code', 'flight', 'from_airport'),
            ('airport_service', 'airport_code', 'flight', 'to_airport'),
            ('date_day', 'day_name', 'days', 'day_name'),
            ('days', 'days_code', 'flight', 'flight_days'),
            ('fare', 'fare_id', 'flight_fare', 'fare_id'),
            ('flight', 'flight_id', 'flight_fare', 'flight_id'),
            ('fare', 'fare_basis_code', 'fare_basis', 'fare_basis_code'),
            ('flight', 'flight_id', 'flight_stop', 'flight_id'),
            ('days', 'days_code', 'fare_basis', 'basis_days'),
            ('airport_service', 'airport_code', 'flight_stop', 'stop_airport'),
            ('city', 'city_code', 'ground_service', 'city_code'),
            ('airline', 'airline_code', 'flight', 'airline_code'),
            ('airport', 'airport_code', 'airport_service', 'airport_code'),
            ('flight', 'meal_code', 'food_service', 'meal_code'),
            ('aircraft', 'aircraft_code', 'equipment_sequence', 'aircraft_code'),
            ('equipment_sequence', 'aircraft_code_sequence', 'flight', 'aircraft_code_sequence'),
            ('city', 'state_code', 'state', 'state_code'),
            ('airport', 'airport_code', 'flight', 'to_airport'),
            ('airport', 'airport_code', 'ground_service', 'airport_code'),
            ('airport', 'airport_code', 'flight', 'from_airport'),
            ('airport_service', 'airport_code', 'fare', 'to_airport'),
            ('airport_service', 'airport_code', 'fare', 'from_airport'),
            ('flight', 'flight_id', 'flight_leg', 'flight_id'),
            ('flight', 'flight_id', 'flight_leg', 'leg_flight'),
            ('class_of_service', 'booking_class', 'fare_basis', 'booking_class'),
            ('airport', 'state_code', 'state', 'state_code'),
            ('airport', 'airport_code', 'flight_stop', 'stop_airport'),
            ('fare', 'restriction_code', 'restriction', 'restriction_code'),
        ],
        "geo": [
            ('border_info', 'border', 'state', 'state_name'),
            ('river', 'traverse', 'state', 'state_name'),
            ('city', 'city_name', 'state', 'capital'),
            ('border_info', 'state_name', 'state', 'state_name'),
            ('city', 'state_name', 'state', 'state_name'),
            ('border_info', 'border', 'river', 'traverse'),
            ('highlow', 'state_name', 'state', 'state_name'),
            ('border_info', 'border', 'border_info', 'state_name'),
            ('highlow', 'state_name', 'river', 'traverse'),
            ('border_info', 'state_name', 'river', 'traverse'),
            ('city', 'state_name', 'river', 'traverse'),
            ('border_info', 'border', 'highlow', 'state_name'),
            ('border_info', 'border', 'city', 'state_name'),
            ('border_info', 'border', 'lake', 'state_name'),
        ],
        "yelp": [
            ('business', 'business_id', 'category', 'business_id'),
            ('review', 'user_id', 'user', 'user_id'),
            ('business', 'business_id', 'review', 'business_id'),
            ('business', 'business_id', 'neighborhood', 'business_id'),
            ('tip', 'user_id', 'user', 'user_id'),
            ('business', 'business_id', 'tip', 'business_id'),
            ('business', 'business_id', 'checkin', 'business_id'),
        ],
        "imdb": [
            ('actor', 'aid', 'cast', 'aid'),
            ('cast', 'msid', 'movie', 'mid'),
            ('directed_by', 'did', 'director', 'did'),
            ('directed_by', 'msid', 'movie', 'mid'),
            ('company', 'id', 'copyright', 'cid'),
            ('copyright', 'msid', 'movie', 'mid'),
            ('keyword', 'id', 'tags', 'kid'),
            ('movie', 'mid', 'tags', 'msid'),
            ('classification', 'msid', 'movie', 'mid'),
            ('made_by', 'pid', 'producer', 'pid'),
            ('classification', 'gid', 'genre', 'gid'),
            ('movie', 'mid', 'written_by', 'msid'),
            ('made_by', 'msid', 'movie', 'mid'),
            ('writer', 'wid', 'written_by', 'wid'),
            ('copyright', 'msid', 'tv_series', 'sid'),
            ('cast', 'msid', 'tv_series', 'sid'),
            ('directed_by', 'msid', 'tv_series', 'sid'),
            ('made_by', 'msid', 'tv_series', 'sid'),
        ],
    }
    schema = DBSchema(path)
    if dataset not in missing_foreign_keys:
        raise ValueError("Invalid dataset name: %s" % dataset)
    for table_a, column_a, table_b, column_b in missing_foreign_keys[dataset]:
        schema.add_foreign_key(table_a, column_a, table_b, column_b)
    return schema
import pyparsing as pp
from collections import namedtuple

from eval_exec.qdmr_encoding import is_reference

# Grammar for the s-expression style QDMR "formula" encoding, e.g.
#   project ( name , select ( people ) )
# Built with pyparsing; SExpr is the recursive entry point.

QDMR_STEP_DELIMITER = ";"  # separator used by the reference-step encoding

# Operator / comparator / aggregate vocabularies of the encoding.
op_list = ["select", "project", "filter", "aggregate", "group", "superlative", "comparative",
           "comparative_group", "intersection", "union_column", "union", "discard", "sort", "arithmetic"]
comparators = ["=", ">", "<", ">=", "<=", "!=", "LIKE", "like", "BETWEEN", "start", "end"]
aggregates = ["COUNT", "SUM", "AVG", "MIN", "MAX", "count", "sum", "avg", "min", "max"]
arithmetics = ["+", "-", "*", "/"]
OP = pp.oneOf(op_list)
COMP = pp.oneOf(comparators)
AGGR = pp.oneOf(aggregates)
ARITHMETIC = pp.oneOf(arithmetics)
# Structural punctuation is suppressed from the parse results.
LP = pp.Literal("(").suppress()
RP = pp.Literal(")").suppress()
COMMA = pp.Literal(",").suppress()
# Bare tokens: alphanumerics plus characters occurring in column refs/values.
String = pp.Word(pp.alphanums + "_" + "-" + "." + "%" + "*" + "/")
# Quoted values keep their surrounding quotes (unquoteResults=False).
SingleQuoteString = pp.QuotedString(quoteChar="'", unquoteResults=False)
DoubleQuoteString = pp.QuotedString(quoteChar='"', unquoteResults=False)
QuotedString = SingleQuoteString | DoubleQuoteString
# Left-hand side of a condition: AGGR( column ) or a plain token.
ConditionPrefix = AGGR + pp.Literal("(") + String + pp.Literal(")") | String
BetweenValue = pp.Word(pp.alphanums) + pp.Literal("AND") + pp.Word(pp.alphanums)
BasicCondition = pp.Group(ConditionPrefix + COMP + pp.OneOrMore(String))
Atom = BasicCondition | ConditionPrefix | QuotedString | ARITHMETIC
SExpr = pp.Forward()
# Condition whose right-hand side is itself a nested formula.
FormulaCondition = ConditionPrefix + COMP + SExpr
SExprList = pp.Group((FormulaCondition | SExpr | Atom) + pp.ZeroOrMore(COMMA + (FormulaCondition | SExpr | Atom)))
SExpr << (OP + LP + SExprList + RP)

# Parse-tree node: an operator name and its argument list.
Node = namedtuple("Node", ["operator", "arguments"])


def parseAction(string, location, tokens):
    """Fold a matched SExpr into a Node(operator, arguments)."""
    return Node(operator=tokens[0], arguments=tokens[1:])


SExpr.setParseAction(parseAction)
def pprint(node, tab=""):
    """Debug helper: print a parsed formula tree, one operator/argument per line."""
    print(tab + u"|--" + str(node.operator))
    new_tab = tab + "  "
    for arg in node.arguments[0]:
        if isinstance(arg, Node):
            pprint(arg, new_tab)
        else:
            print(new_tab + arg)


def formula_dfs(node, stack):
    """Serialize a parsed formula tree back into its textual form, depth-first.

    Returns a tuple (s, stack): `s` is the string for `node`, while `stack`
    accumulates the serialized string of every visited subtree in post-order,
    so stack[-1] is the full formula and earlier entries are inner subformulas.
    """
    s = "%s ( " % str(node.operator)
    space = " "
    for i in range(len(node.arguments[0])):
        arg = node.arguments[0][i]
        comma = "" if i == 0 else " , "
        if isinstance(arg, Node):
            last_token = s[-1]
            # handle case where argument is formula value of a condition
            delimiter = comma if last_token not in comparators else space
            s += delimiter + str(formula_dfs(arg, stack)[0])
        elif isinstance(arg, pp.ParseResults):
            s += comma + ' '.join(arg)  # argument is a simple condition list
        else:
            # handle case where argument is the comparator of a formula condition
            s += comma + str(arg) if arg not in comparators else space + str(arg)
    s += " )"
    stack += [s]
    return s, stack


def dfs_ref_substitution(dfs_qdmr_steps):
    """Rewrite post-order subtree strings into numbered reference steps.

    For each step that is not itself a plain reference, every later step that
    contains it verbatim is rewritten to use a "#i" reference instead; plain
    reference steps are dropped from the final list.
    """
    def remove_references(steps_list):
        return list(filter(lambda x: not is_reference(x), steps_list))

    ret_steps = []
    for i in range(len(dfs_qdmr_steps)):
        # NOTE(review): the counter also advances for reference-only steps,
        # so "#" ids track position in the processed sequence -- confirm intended.
        next_ref = "#%s" % (len(ret_steps) + 1)
        next_step = dfs_qdmr_steps[i]
        new_steps = []
        if not is_reference(next_step):
            # substitute occurrences of this step inside all later steps
            for step in dfs_qdmr_steps[i + 1:]:
                new_steps += [step.replace(next_step, next_ref)]
            dfs_qdmr_steps = dfs_qdmr_steps[:i + 1] + new_steps
        ret_steps += [next_step]
    return remove_references(dfs_qdmr_steps)


def formula_qdmr_to_ref_steps(qdmr_formula_encoding):
    """Parse a formula encoding and return its ordered list of "#ref" steps."""
    parsed = SExpr.parseString(qdmr_formula_encoding)
    dfs_steps = formula_dfs(parsed[0], [])[1]
    return dfs_ref_substitution(dfs_steps)


def formula_to_ref_encoding(qdmr_formula_encoding):
    """Convert a nested formula encoding into the flat ";"-delimited steps encoding."""
    ref_steps = formula_qdmr_to_ref_steps(qdmr_formula_encoding)
    delim = " %s " % QDMR_STEP_DELIMITER
    return delim.join(ref_steps)
class QDMRStep:
    """One QDMR step: its raw text plus the identified operator and arguments."""

    def __init__(self, step_text, operator, arguments):
        # Attribute names are part of the public interface -- keep them as-is.
        self.step = step_text
        self.operator = operator
        self.arguments = arguments

    def __str__(self):
        # "!a" applies ascii() to the argument list, matching the original
        # "%s%a" printf-style format exactly.
        return f"{self.operator.upper()}{self.arguments!a}"
class QDMRProgramBuilder(object):
    """Parses raw QDMR text into per-step operator labels and QDMRStep objects."""

    def __init__(self, qdmr_text):
        # qdmr_text: the ";"-delimited QDMR decomposition string.
        self.qdmr_text = qdmr_text
        self.steps = None       # list[QDMRStep | None], filled by build_steps()
        self.operators = None   # list[str | None], filled by get_operators()
        self.program = None     # reserved for subclasses (see build_program)

    def build(self):
        """Populate both self.operators and self.steps. Returns True."""
        self.get_operators()
        self.build_steps()
        return True

    def build_steps(self):
        """Identify each step of the decomposition; unidentifiable steps become None."""
        self.steps = []
        steps = parse_decomposition(self.qdmr_text)
        step_identifier = StepIdentifier()
        for step_text in steps:
            try:
                step = step_identifier.identify(step_text)
            # was a bare `except:` -- narrowed so KeyboardInterrupt/SystemExit propagate
            except Exception:
                print("Unable to identify step: %s" % step_text)
                step = None
            self.steps += [step]
        return self.steps

    def get_operators(self):
        """Identify each step's operator; unidentifiable steps yield None."""
        self.operators = []
        steps = parse_decomposition(self.qdmr_text)
        step_identifier = StepIdentifier()
        for step_text in steps:
            try:
                op = step_identifier.step_type(step_text)
            except Exception:
                print("Unable to identify operator: %s" % step_text)
                op = None
            self.operators += [op]
        return self.operators

    def build_program(self):
        """Not implemented in the base builder; subclasses must override."""
        # removed the unreachable `return True` that followed the raise
        raise NotImplementedError
import re


def qdmr_to_sql(qdmr_ref_steps_encoding, question, db_schema_path, dataset=None):
    """Ground a QDMR reference-steps encoding against a DB schema and return
    the SQL of its final step."""
    if dataset in ["spider", None]:
        dataset = "spider"
    schema = prepare_db_schema(db_schema_path, dataset=dataset)
    grounded_qdmr = GroundedQDMR(qdmr_ref_steps_encoding, question, schema)
    grounded_qdmr.to_sql()
    # sql_steps is keyed by the step number as a string; the last step holds
    # the full query.
    last_step = str(len(grounded_qdmr.sql_steps))
    return grounded_qdmr.sql_steps[last_step]["SQL"]


def prune_nested_queries(pred_sql, schema_path, lowercased=None):
    """Strip trivially-true nested conditions of the form
    "AND t.c IN ( SELECT t.c FROM t )" for every schema column."""
    schema = DBSchema(schema_path)
    for full_column in schema.columns():
        table, _ = get_table_and_column(full_column)
        clause = f"AND {full_column} IN ( SELECT {full_column} FROM {table} )"
        if lowercased:
            clause = clause.lower()
        if clause in pred_sql:
            pred_sql = pred_sql.replace(clause, "")
    return pred_sql


def normalize_qdmr_prediction(predicted_qdmr):
    """Normalize a raw model prediction: pad "!=", collapse spaces around
    dots, and wrap LIKE values with "%" wildcards."""

    def _pad_not_equal(pred):
        # ensure "!=" is preceded by a space exactly once
        if "!=" in pred and " !=" not in pred:
            return pred.replace("!=", " !=")
        return pred

    def _strip_dot_spaces(pred):
        # " ." then ". " -- order matters, mirroring the original passes
        for bad, good in ((" .", "."), (". ", ".")):
            if bad in pred:
                pred = pred.replace(bad, good)
        return pred

    def _pct_wrap_like(pred):
        def wrap(cond):
            raw_value = cond.split("like ")[1][:-1].strip()  # omit last ")"
            return cond.replace(raw_value, f"%{raw_value}%")

        conds = [m.group(0) for m in re.finditer(r'like .*? \)', pred)]
        for cond in conds:
            pred = pred.replace(cond, wrap(cond))
        return pred

    return _pct_wrap_like(_strip_dot_spaces(_pad_not_equal(predicted_qdmr)))
class SchemaParser(object):
    """Parses a sqlite3 database file into a Spider-style schema dict."""

    def __init__(self):
        # path of the last parsed sqlite file
        self.path = None

    def parse(self, schema_path, name):
        """Read tables, columns, types, primary keys and foreign keys from the
        sqlite file at `schema_path` and return them as one dict keyed like the
        Spider tables.json format (db_id, table_names[_original],
        column_names[_original], column_types, primary_keys, foreign_keys).
        """
        self.path = schema_path
        # column index 0 is the synthetic '*' column (table id -1), so real
        # column ids start at 1.
        parsed_data = {'db_id': name,
                       'table_names_original': [],
                       'table_names': [],
                       'column_names_original': [(-1, '*')],
                       'column_names': [(-1, '*')],
                       'column_types': ['text'],
                       'primary_keys': [],
                       'foreign_keys': []}

        conn = sqlite3.connect(self.path)
        conn.execute('pragma foreign_keys=ON')
        cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table';")

        fk_holder = []
        for i, item in enumerate(cursor.fetchall()):
            table_name = item[0]
            parsed_data['table_names_original'].append(table_name)
            parsed_data['table_names'].append(table_name.lower().replace("_", ' '))
            # foreign_key_list rows: (id, seq, ref_table, from_col, to_col, ...)
            fks = conn.execute("PRAGMA foreign_key_list('{}') ".format(table_name)).fetchall()
            # print("db:{} table:{} fks:{}".format(f,table_name,fks))
            fk_holder.extend([[(table_name, fk[3]), (fk[2], fk[4])] for fk in fks])
            # table_info rows: (cid, name, type, notnull, dflt_value, pk)
            cur = conn.execute("PRAGMA table_info('{}') ".format(table_name))
            for j, col in enumerate(cur.fetchall()):
                parsed_data['column_names_original'].append((i, col[1]))
                parsed_data['column_names'].append((i, col[1].lower().replace("_", " ")))
                # varchar, '' -> text, int, numeric -> integer,
                col_type = col[2].lower()
                if 'char' in col_type or col_type == '' or 'text' in col_type or 'var' in col_type:
                    parsed_data['column_types'].append('text')
                elif 'int' in col_type or 'numeric' in col_type or 'decimal' in col_type or 'number' in col_type \
                        or 'id' in col_type or 'real' in col_type or 'double' in col_type or 'float' in col_type:
                    parsed_data['column_types'].append('number')
                elif 'date' in col_type or 'time' in col_type or 'year' in col_type:
                    parsed_data['column_types'].append('time')
                elif 'boolean' in col_type:
                    parsed_data['column_types'].append('boolean')
                else:
                    parsed_data['column_types'].append('others')

                # NOTE(review): pk flag == 1 records only the first column of a
                # composite primary key (later parts have pk == 2, 3, ...).
                if col[5] == 1:
                    parsed_data['primary_keys'].append(len(parsed_data['column_names']) - 1)

        parsed_data["foreign_keys"] = fk_holder
        # convert (table, column) name pairs into global column indices
        parsed_data['foreign_keys'] = self.convert_fk_index(parsed_data)
        return parsed_data

    def convert_fk_index(self, data):
        """Map each [(table, col), (ref_table, ref_col)] pair collected by
        parse() to a [col_id, ref_col_id] pair of global column indices.
        Exits the process if a referenced table cannot be found."""
        fk_holder = []
        for fk in data["foreign_keys"]:
            tn, col, ref_tn, ref_col = fk[0][0], fk[0][1], fk[1][0], fk[1][1]
            ref_cid, cid = None, None
            try:
                tid = data['table_names_original'].index(tn)
                ref_tid = data['table_names_original'].index(ref_tn)

                for i, (tab_id, col_org) in enumerate(data['column_names_original']):
                    if tab_id == ref_tid and ref_col == col_org:
                        ref_cid = i
                    elif tid == tab_id and col == col_org:
                        cid = i
                # truthiness is safe here: real column ids start at 1 ('*' is 0)
                if ref_cid and cid:
                    fk_holder.append([cid, ref_cid])
            except:
                traceback.print_exc()
                print("table_names_original: ", data['table_names_original'])
                print("finding tab name: ", tn, ref_tn)
                sys.exit()
        return fk_holder
import sqlite3
import time
import threading
import sys
from collections import OrderedDict

# from wrapt_timeout_decorator import *


TIMEOUT = 60


# TIMEOUT = 90 # 90 seconds for Academic, IMDB & Yelp


def interrupt_sqlite(connection):
    """Watchdog callback: abort whatever query is running on `connection`."""
    print('Interrupted sqlite connection', file=sys.stderr)
    connection.interrupt()


# @timeout(dec_timeout=TIMEOUT, use_signals=False)
def execute_sql(db, sql):
    """
    Returns a list of tuple that are the query results

    Parameters
    ----------
    db : str
        Full path to DB schema
    sql : str
        SQL query to be executed


    Returns
    -------
    list
        List of tuple that are the query results, or None if execution failed
    """
    conn = sqlite3.connect(db)
    # ignore undecodable bytes in TEXT columns instead of raising
    conn.text_factory = lambda b: b.decode(errors='ignore')
    try:
        c = conn.cursor()
        try:
            c.execute(sql)
        # was a bare `except:`; invalid SQL is signalled with None
        except Exception:
            return None
        return c.fetchall()
    finally:
        # the original leaked the connection on every call; always release it
        conn.close()


def normalize_tuple(tup):
    """Cast all tuple values to strings so 1 and '1' compare equal."""
    norm_vars = [str(var) for var in tup]
    return tuple(norm_vars)


def normalize_denotation(denotation_list, distinct=None):
    """Normalize a result set for order-, duplicate- and type-insensitive
    comparison: stringify and sort each row, then sort the rows."""
    if not denotation_list:
        return denotation_list
    # remove duplicates (preserving first-seen order) when distinct is set
    denotation_list = list(OrderedDict.fromkeys(denotation_list)) if distinct else denotation_list
    sort_tuples = [sorted(normalize_tuple(tup)) for tup in denotation_list]
    return sorted(sort_tuples)  # sort results set


def correct_denotation(pred_sql, gold_sql, db_path, distinct=None):
    """True iff the predicted and gold SQL queries produce equivalent results
    (raw or normalized) when executed on the database at `db_path`."""
    gold_denotation = execute_sql(db_path, gold_sql)
    pred_denotation = execute_sql(db_path, pred_sql)
    if gold_denotation == pred_denotation:
        # unnormalized denotations
        return True
    gold_denotation_norm = normalize_denotation(gold_denotation, distinct=distinct)
    pred_denotation_norm = normalize_denotation(pred_denotation, distinct=distinct)
    return gold_denotation_norm == pred_denotation_norm
class SQLParser(object):
    """Heuristic SQL parser used for grounding evaluation: extracts the tables,
    columns and literal values referenced by a query, resolving T# aliases via
    the `schema` object (which exposes tables(), columns(), column_names())."""

    def __init__(self):
        self.tables = None   # list[str]: lowercased table names found in the query
        self.columns = None  # list[str]: "table.column" names found in the query

    def parse(self, query, schema):
        """Populate self.tables and self.columns from `query`.

        NOTE(review): assumes the query contains a FROM clause -- the split
        below raises IndexError otherwise.
        """
        query_tables = set()
        query_columns = set()
        query = query.lower()
        tokens = query.split()
        # aliased column tokens look like "t1.col"
        for tok in tokens:
            if re.match(r't[0-9]\.', tok):
                query_columns.add(tok)
        # get tables in sql query
        from_clause = query.split('from')[1].split('where')[0]
        from_tokens = from_clause.split()
        schema_tables = schema.tables()
        schema_tables_lowercase = [name.lower() for name in schema_tables]
        for tok in from_tokens:
            if tok in schema_tables_lowercase:
                query_tables.add(tok)
        self.tables = list(query_tables)
        columns = list(query_columns)
        if len(self.tables) == 1:
            # all columns in query belong to a single table
            table_name = self.tables[0]
            for tok in tokens:
                # parse token from 'op()' and 'table.col'
                tok = tok.split('(')[1] if '(' in tok else tok
                tok = tok.split(')')[0] if ')' in tok else tok
                tok = tok.split('.')[1] if '.' in tok else tok
                schema_columns = schema.columns()
                for col in schema_columns:
                    if col == tok:
                        col_full_name = "%s.%s" % (table_name, col)
                        query_columns.add(col_full_name)
            self.columns = list(query_columns)
            return True
        # more than one table in query
        # replace column table alias T1.col --> table_name.col
        # (the table name is the token immediately preceding "as t#")
        aliases = re.findall(r'as\st[0-9]', query)
        alias_map = {}
        for alias in aliases:
            table_alias = alias.split()[-1]
            prefix = query.split(alias)[0]
            table_name = prefix.split()[-1]
            alias_map[table_alias] = table_name
        self.columns = []
        for col in columns:
            for alias in alias_map.keys():
                if alias in col:
                    column_full = col.replace(alias, alias_map[alias])
                    self.columns += [column_full]
        self.columns = list(set(self.columns))
        return True

    def get_table_aliases(self, query):
        """Returns map from table alias (t#) to its name"""
        query = query.lower()
        aliases = re.findall(r'as\st[0-9]', query)
        alias_map = {}
        for alias in aliases:
            table_alias = alias.split()[-1]
            prefix = query.split(alias)[0]
            table_name = prefix.split()[-1]
            alias_map[table_alias] = table_name
        # map from table alias (e.g. t1) to its name
        return alias_map

    def extract_values(self, query, schema):
        """Return a dict mapping each literal value in `query` to the
        "table.column" it is compared against (best-effort heuristic)."""
        query = query.lower()
        value_to_col = {}
        # Find all values based on string delimiters
        single_paren_vals = [item.group(0) for item in re.finditer(r'\'.*?\'', query)]
        double_paren_vals = [item.group(0) for item in re.finditer(r'\".*?\"', query)]
        number_vals = [item.group(0) for item in re.finditer(r'[0-9]+', query)]
        # filter numbers in table aliases e.g., 1 in T1
        number_vals = list(filter(lambda x: (" %s" % x) in query, number_vals))
        vals = single_paren_vals + double_paren_vals + number_vals
        # Map values to corresponding columns
        for value in vals:
            # SQL satement will be: "table.col operator value", e.g.:
            # T2.allergytype = "food"
            # name LIKE '%Led%'
            table = None
            prefix = query.split(value)[0]
            # the column sits two tokens before the value (column, operator, value)
            aliased_column = prefix.split()[-2]
            column_names = schema.column_names()
            schema_columns = schema.columns()
            if "." in aliased_column:
                # column is either aliased T#.col or table.col
                aliased_table, col = get_table_and_column(aliased_column)
                table = self.get_aliased_table(aliased_table, query, schema)
            elif aliased_column.lower() not in column_names:
                # nearest token is not column name
                # return the nearest column name instead
                preceding_toks = prefix.lower().split()
                for i in reversed(range(len(preceding_toks))):
                    if preceding_toks[i] in column_names:
                        aliased_column = preceding_toks[i]
                        break
            else:
                # no aliased table in query
                # find nearest table to the column name
                col = aliased_column
                col_match_positions = [m.start() for m in re.finditer(col, query)]
                last_match_pos = col_match_positions[-1]
                preceding_toks = query[:last_match_pos].split()
                table_names = schema.tables()
                for i in reversed(range(len(preceding_toks))):
                    if preceding_toks[i] in table_names:
                        table = preceding_toks[i]
                        full_col_name = "%s.%s" % (table, col)
                        if full_col_name in schema_columns:
                            # validate full column name is valid
                            break
            # non-number values have parentheses
            value_no_paren = value[1:-1] if not value.isdigit() else value
            if value_no_paren.startswith("%") \
                    and value_no_paren.endswith("%"):
                # value extracted from LIKE '%%' statement
                value_no_paren = value_no_paren[1:-1]
            # only record the value if a table was resolved for it
            if table:
                value_to_col[value_no_paren.strip()] = "%s.%s".strip() % (table, col)
        return value_to_col

    def get_aliased_table(self, aliased_table, query, schema):
        """
        Receive table name referenced query and retreive its actual table
        Handles:
            Spider aliases format e.g., T#.column
            ATIS aliases format e.g., table_#.column
        """
        table_aliases = self.get_table_aliases(query)
        if re.match(r't[0-9]', aliased_table):
            return table_aliases[aliased_table]
        if re.match(r'.*\_[0-9]', aliased_table):
            # remove the '_#' suffix
            actual_table = '_'.join(aliased_table.split('_')[:-1])
            if actual_table in schema.tables():
                return actual_table
        return aliased_table
import re

DELIMITER = ';'
REF = '#'


def parse_decomposition(qdmr):
    """Split a QDMR string into an ordered list of step phrases.

    Parameters
    ----------
    qdmr : str
        String representation of the QDMR

    Returns
    -------
    list
        returns ordered list of qdmr steps, each with its leading
        "return" token removed
    """
    # Strip thousands separators: "1,000" -> "1000".
    for grouped_number in re.findall(r"[\d,]+[,\d]", qdmr):
        qdmr = qdmr.replace(grouped_number, grouped_number.replace(",", ""))
    # Make every remaining comma a standalone token.
    qdmr = qdmr.replace(",", " , ")
    steps = []
    for raw_step in qdmr.split(DELIMITER):
        # Drop the leading "return" token and re-join the rest with
        # single spaces.
        steps.append(" ".join(raw_step.split()[1:]))
    return steps


def get_table_and_column(full_column_name):
    """Split a "table.column" name into its [table, column] parts."""
    return full_column_name.split(".")
def extract_comparator(condition):
    """
    Returns comparator and value of a
    QDMR comparative step condition

    Parameters
    ----------
    condition : str
        Phrase representing condition of a QDMR step

    Returns
    -------
    tuple
        (comparator, value); ("=", None) when no trigger phrase matches
    """
    # Spelled-out numbers are normalized to digits in the returned value.
    word_to_digit = {"zero": "0", "one": "1", "two": "2", "three": "3", "four": "4", "five": "5",
                     "six": "6", "seven": "7", "eight": "8", "nine": "9", "ten": "10"}
    # Trigger phrases per comparator; dict order sets matching priority.
    trigger_table = {
        "BETWEEN": ["between"],
        ">": ["more than", "above", "larger than", "larger",
              "older than", "older", "higher than", "higher",
              "greater than", "greater", "bigger than", "bigger",
              "after", "over"],
        ">=": ["at least"],
        "<": ["less than", "under", "lower than", "lower",
              "younger than", "younger", "before", "below",
              "shorter than", "smaller than", "smaller"],
        "<=": ["at most"],
        "!=": ["is not"],
        "start": ['start with', 'starts with', 'begin'],
        "end": ['end with', 'ends with'],
        "LIKE": ["the letter", "the string", "the word", "the phrase",
                 "contain", "include", "has", "have",
                 "contains", "substring", "includes"],
        "=": ['is equal to', 'equal to', 'same as',
              'is ', 'are ', 'was '],
    }
    # Suffix-style triggers that appear AFTER the value, e.g. "5 or more".
    suffix_triggers = {
        ">=": ["or later", "or more", "or after"],
        "<=": ["or earlier", "or less", "or before"],
    }
    ###TODO: handle "NOT LIKE"
    matched_comp, matched_trigger = None, None
    for comp_symbol, triggers in trigger_table.items():
        for trig in triggers:
            if trig in condition:
                matched_comp, matched_trigger = comp_symbol, trig
                break
        if matched_comp:
            break
    if matched_comp is None:
        # no comparative trigger found
        return "=", None
    # extract value/reference
    value_phrase = condition.split(matched_trigger)[1].strip()
    if matched_comp == "BETWEEN":
        # "between num1 AND num2"
        return matched_comp, value_phrase.upper()
    # a suffix trigger overrides the comparator; the value lies before it
    for comp_symbol, triggers in suffix_triggers.items():
        for trig in triggers:
            if trig in condition:
                matched_comp = comp_symbol
                value_phrase = condition.split(trig)[0].strip()
                break
    for tok in value_phrase.split():
        if tok.isnumeric():
            return matched_comp, tok
        if tok in word_to_digit:
            return matched_comp, word_to_digit[tok]
    return matched_comp, value_phrase
def restore_oov(prediction):
    """
    Replace T5 SPM OOV character with `<`.
    Certain punctuation characters are mapped to the OOV symbol in T5's
    sentence-piece model. For Spider, this appears to only affect the `<` symbol,
    so it can be deterministically recovered by running this script.
    An alternative is to preprocess dataset to avoid OOV symbols for T5.
    """
    pred = prediction.replace(" ⁇ ", "<")
    return pred


def format_sql(prediction, no_split=None):
    """Restore OOV symbols and (unless no_split) truncate at the first ";"."""
    pred = restore_oov(prediction)
    if no_split:
        return pred
    return pred.split(";")[0] + ";"


# DB directories whose names are mixed-case on disk; extract_db_name restores
# the original casing for these.
UPPERCASE_DBS = ["cre_Doc_Control_Systems",
                 "cre_Doc_Template_Mgt",
                 "cre_Doc_Tracking_DB",
                 "cre_Docs_and_Epenses",
                 "cre_Drama_Workshop_Groups",
                 "cre_Theme_park",
                 "insurance_and_eClaims"]


class ExactSetMatch(object):
    """Spider exact-set-match / execution evaluator over gold and predicted SQL."""

    def __init__(self):
        return

    def extract_db_name(self, example):
        """Return the db id prefixed to `example` ("db: question"), restoring
        mixed-case directory names via UPPERCASE_DBS."""
        db = example.split(":")[0].strip()
        for uppercase_db in UPPERCASE_DBS:
            if uppercase_db.lower() == db:
                return uppercase_db
        return db

    def evaluate(self, examples, gold, predict, db_dir, table, etype):
        """Run the Spider evaluation.

        etype selects the metric: "exec" returns execution accuracy, "match"
        returns exact-set-match accuracy, anything else returns both as a dict.
        Unparseable predictions are scored against an empty SQL structure.
        """
        kmaps = build_foreign_key_map_from_json(table)
        glist = gold
        plist = predict
        db_list = [self.extract_db_name(ex) for ex in examples]
        evaluator = Evaluator()

        levels = ['easy', 'medium', 'hard', 'extra', 'all']
        partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
                         'group', 'order', 'and/or', 'IUEN', 'keywords']
        entries = []
        scores = {}

        # per-hardness-level accumulators
        for level in levels:
            scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
            scores[level]['exec'] = 0
            for type_ in partial_types:
                scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0., 'acc_count': 0, 'rec_count': 0}

        eval_err_num = 0
        for p, g, db in zip(plist, glist, db_list):
            p_str = format_sql(p)
            g_str = format_sql(g)  # sentencepiece tokenization to oov tokens
            db_name = db
            db = os.path.join(db_dir, db, db + ".sqlite")
            schema = Schema(get_schema(db))
            g_sql = get_sql(schema, g_str)
            hardness = evaluator.eval_hardness(g_sql)
            scores[hardness]['count'] += 1
            scores['all']['count'] += 1

            try:
                p_sql = get_sql(schema, p_str)
            except:
                # If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql
                p_sql = {
                    "except": None,
                    "from": {
                        "conds": [],
                        "table_units": []
                    },
                    "groupBy": [],
                    "having": [],
                    "intersect": None,
                    "limit": None,
                    "orderBy": [],
                    "select": [
                        False,
                        []
                    ],
                    "union": None,
                    "where": []
                }
                eval_err_num += 1

            # rebuild sql for value evaluation
            kmap = kmaps[db_name]
            g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
            g_sql = rebuild_sql_val(g_sql)
            g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
            p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
            p_sql = rebuild_sql_val(p_sql)
            p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)

            if etype in ["all", "exec"]:
                exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
                if exec_score:
                    scores[hardness]['exec'] += 1.0
                    scores['all']['exec'] += 1.0

            if etype in ["all", "match"]:
                exact_score = evaluator.eval_exact_match(p_sql, g_sql)
                partial_scores = evaluator.partial_scores
                scores[hardness]['exact'] += exact_score
                scores['all']['exact'] += exact_score

        # turn raw counts into accuracies per level
        for level in levels:
            if scores[level]['count'] == 0:
                continue
            if etype in ["all", "exec"]:
                scores[level]['exec'] /= scores[level]['count']
            if etype in ["all", "match"]:
                scores[level]['exact'] /= scores[level]['count']
        if etype == "exec":
            return scores['all']['exec']
        if etype == "match":
            return scores['all']['exact']
        return {"match": scores['all']['exact'], "exec": scores['all']['exec']}
class ExecutionAccuracy(object):
    """Execution-accuracy evaluator: converts QDMR predictions to SQL when
    needed, executes predicted and gold queries, and compares denotations."""

    def __init__(self):
        return

    def extract_db_name(self, example):
        """Return the db id prefixed to `example` ("db: question"), restoring
        mixed-case directory names via UPPERCASE_DBS."""
        db = example.split(":")[0].strip()
        for uppercase_db in UPPERCASE_DBS:
            if uppercase_db.lower() == db:
                return uppercase_db
        return db

    def evaluate(self, examples, gold, predict, db_dir, exec_set_match, pred_type=None, dataset=None):
        """
        :param examples: list
            List of question and db examples
        :param gold: list
            List of cased gold SQL queries
        :param predict: list
            List of lower-cased predicted queries
        :param db_dir: string
            Path to db files
        :param exec_set_match: bool
            Flag whether to use set match when comparing execution results
        :param pred_type: str
            The format of the predictions, if not sql, convert them to sql
        :return: dict
            Dictionary with the execution accuracy & number of errors
        """
        if pred_type:
            assert pred_type in ["qdmr_formula", "qdmr_steps", "qdmr_sql", "sql"]
        glist = gold
        plist = predict
        db_list = [self.extract_db_name(ex) for ex in examples]
        evaluator = Evaluator()

        levels = ['easy', 'medium', 'hard', 'extra', 'all']
        partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
                         'group', 'order', 'and/or', 'IUEN', 'keywords']
        scores = {}

        for level in levels:
            scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
            scores[level]['exec'] = 0
            for type_ in partial_types:
                scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0., 'acc_count': 0, 'rec_count': 0}

        eval_err_num = 0

        for p, g, db in zip(plist, glist, db_list):
            # QDMR-steps predictions legitimately contain ";" step delimiters,
            # so they must not be truncated at the first semicolon.
            p_str = format_sql(p, no_split=(pred_type == "qdmr_steps"))
            g_str = format_sql(g)  # sentencepiece tokenization to oov tokens
            db = os.path.join(db_dir, db, db + ".sqlite")
            # TODO: commented out because of Geo880 SQL format
            # schema = Schema(get_schema(db))
            # g_sql = get_sql(schema, g_str)
            # hardness = evaluator.eval_hardness(g_sql)
            # scores[hardness]['count'] += 1
            scores['all']['count'] += 1

            # repair predicted query into executable SQL
            print("**** g_str: ", g_str)
            print("**** p_str: ", p_str)
            try:
                if pred_type in ["qdmr_formula", "qdmr_steps"]:
                    p_str = normalize_qdmr_prediction(p_str)
                    p_str = formula_to_ref_encoding(p_str) if pred_type == "qdmr_formula" else p_str
                    p_str = qdmr_to_sql(qdmr_ref_steps_encoding=p_str,
                                        question=None,
                                        db_schema_path=db,
                                        dataset=dataset)
                    p_str = prune_nested_queries(p_str, db)
                else:
                    p_str = prune_nested_queries(p_str, db, lowercased=True)
                p_sql = fix_sql_casing(p_str, g_str)

                exec_score = correct_denotation(p_sql, g_str, db, distinct=exec_set_match)
                if exec_score:
                    # scores[hardness]['exec'] += 1.0
                    scores['all']['exec'] += 1.0
            # was a bare `except:`; narrowed so KeyboardInterrupt still aborts
            # a long evaluation run -- any conversion/execution failure counts
            # as an error and the example is scored 0
            except Exception:
                eval_err_num += 1

        for level in levels:
            if scores[level]['count'] == 0:
                continue
            scores[level]['exec'] /= scores[level]['count']
        return {"exec": scores['all']['exec'], "errors": eval_err_num}
from model import LoggingCallback, T5FineTuner, get_dataset, evaluate_predictions
import numpy as np
import torch
import pytorch_lightning as pl

if __name__ == '__main__':
    # Inference configuration. NOTE(review): several fields mix Spider and
    # Geo/Yelp settings (dataset='geo' with data_dir='data/spider', Yelp dev
    # files) — confirm they match the checkpoint being evaluated.
    args_dict = dict(
        dataset='geo',
        target_encoding='qdmr_formula',
        data_dir='data/spider',  # path for data files
        output_dir='',  # path to save the checkpoints
        db_dir='other_database',  # name of db dir in data dir
        training_set_file='train_spider.json',  # name of training set file in data dir
        dev_set_file='other_queries/yelp/yelp_sql.json',  # name of dev set file in data dir
        dev_set_sql='other_queries/yelp/test_gold_yelp.sql',  # name of file containing the dev sql queries in data dir
        test_set_file='',
        test_set_sql='',
        model_name_or_path='t5-base',
        tokenizer_name_or_path='t5-base',
        max_seq_length=512,
        learning_rate=1e-4,
        weight_decay=0.0,
        adam_epsilon=1e-8,
        warmup_steps=0,
        train_batch_size=2,
        eval_batch_size=2,
        num_train_epochs=2,
        gradient_accumulation_steps=64,
        n_gpu=1,
        early_stop_callback=False,
        fp_16=False,  # if you want to enable 16-bit training then install apex and set this to true
        opt_level='O1',
        # you can find out more on optimisation levels here https://nvidia.github.io/apex/amp.html#opt-levels-and-properties
        max_grad_norm=1.0,  # if you enable 16-bit training then set this to a sensible value, 0.5 is a good default
        seed=42,
    )

    args = argparse.Namespace(**args_dict)

    # NOTE(review): tokenizer is 't5-large' while args say 't5-base' — presumably
    # intentional for this large-model checkpoint; verify before reuse.
    tokenizer = T5Tokenizer.from_pretrained('t5-large')
    checkpoint_path = "/t5_large_spider-train_qdmr_formula_few_shot_finetune_135_epochs_bs_1_accum_128_epochs_150_lr_1e-4_seed_42/checkpointepoch=9.ckpt"

    print("Loading T5FinalTuner pretrained model...")
    model = T5FineTuner.load_from_checkpoint(checkpoint_path)

    # re-set missing params (hparams stored in the checkpoint may be stale)
    model.hparams.target_encoding = args.target_encoding
    model.hparams.data_dir = args.data_dir
    model.hparams.db_dir = args.db_dir
    model.hparams.dev_set_file = args.dev_set_file
    model.hparams.dev_set_sql = args.dev_set_sql

    model = model.to('cuda')
    print("Done!")

    print("Loading dataset...")
    # dataset = get_dataset(tokenizer, "test", args)
    dataset = get_dataset(tokenizer=tokenizer,
                          type_path="val",
                          args=args,
                          dataset_name=args.dataset,
                          target_encoding=args.target_encoding)

    print("Generating sequences...")
    loader = DataLoader(dataset, batch_size=1, num_workers=4)
    model.model.eval()
    outputs = []
    targets = []
    for batch in tqdm(loader):
        # Greedy generation; decoded predictions and gold targets are collected
        # in example order for the evaluation call below.
        outs = model.model.generate(input_ids=batch['source_ids'].cuda(),
                                    attention_mask=batch['source_mask'].cuda(),
                                    max_length=512)

        dec = [tokenizer.decode(ids) for ids in outs]
        target = [tokenizer.decode(ids) for ids in batch["target_ids"]]

        outputs.extend(dec)
        targets.extend(target)


    def restore_oov(prediction):
        """
        Replace T5 SPM OOV character with `<`.
        Certain punctuation characters are mapped to the OOV symbol in T5's
        sentence-piece model. For Spider, this appears to only affect the `<` symbol,
        so it can be deterministically recovered by running this script.
        An alternative is to preprocess dataset to avoid OOV symbols for T5.
        """
        pred = prediction.replace(" ⁇ ", "<")
        return pred


    def format_sql(prediction):
        """Restore OOV characters and truncate the prediction at its first `;`."""
        pred = restore_oov(prediction)
        return pred.split(";")[0] + ";"


    print(len(outputs))

    # Fixed: the enumerate() index was unused.
    for out in outputs:
        print(out)

    # print("****************** Formatted *************************")
    # for out in outputs:
    #     print(format_sql(out))

    # Geo880 uses a single database, so every example gets the "geo:" prefix;
    # Spider examples carry their own db_id read from the dev-set file.
    dbs_list = ["geo:"] * len(outputs)
    if args.dataset == "spider":
        dbs_list = []
        dev_file_path = os.path.join(args.data_dir, args.dev_set_file)
        raw_data = load_json(dev_file_path)
        for example in raw_data:
            dbs_list += [example["db_id"]]

    # NOTE(review): gold_labels is args.test_set_sql, which is '' above, while
    # dev_set_sql is populated and the dataset uses type_path="val" — confirm
    # whether args.dev_set_sql was intended here.
    score = evaluate_predictions(examples=dbs_list,
                                 gold_labels=args.test_set_sql,
                                 predictions=outputs,
                                 args=args,
                                 is_test=True)
    print(score)

# -------- src/text_to_sql/test_exec_eval.py --------
import json

from eval_spider import ExecutionAccuracy


def load_json_dup(filepath):
    """Read *filepath* and parse its contents as JSON."""
    with open(filepath, "r") as reader:
        text = reader.read()
    return json.loads(text)


def read_spider_examples(train_dataset_path):
    """Return lower-cased "db_id: question" strings, one per Spider example."""
    inputs = []
    # Read examples
    raw_data = load_json_dup(train_dataset_path)
    for example in raw_data:
        database = example["db_id"].lower()
        source = example["question"]
        target = example["query"]
        # Prepend database
        source = "%s: %s" % (database, source)
        input = source.lower()
        target = target.lower()
        inputs.append(input)
    return inputs


def read_lines_from_file_sql(sql_file):
    """Read non-empty lines, keeping only the first tab-separated field (the SQL)."""
    with open(sql_file, encoding='UTF-8') as f:
        glist = [l.strip().split('\t')[0] for l in f.readlines() if len(l.strip()) > 0]
    return glist
def prepare_qdmr_grounded_sql_preds(grounded_sql_path, gold_examples_length):
    """Read qdmr-grounded SQL and lowercase them to mimic T5 predictions.
    Order the predictions according to Spider's original order.
    Where no valid SQL was grounded place an empty string as the default value."""
    # Placeholder emitted for examples with no correctly grounded SQL.
    DEFAULT = ""
    raw_data = load_json_dup(grounded_sql_path)
    examples = raw_data["data"]
    ex_id_to_sql = {}
    for ex in examples:
        # example ids look like "<prefix>_<spider_index>"; recover the index
        spider_id = int(ex["example_id"].split("_")[-1])
        correct_grounded_sql = ex["correct_denotation"]
        sql = ex["sql_ground"] if correct_grounded_sql else DEFAULT
        # lowercase to mimic the casing of T5 model predictions
        sql = sql.lower()
        ex_id_to_sql[spider_id] = sql
    ordered_sql = []
    # Re-emit in 0..gold_examples_length-1 order, filling gaps with DEFAULT.
    for idx in range(gold_examples_length):
        if idx not in ex_id_to_sql:
            ex_id_to_sql[idx] = DEFAULT
        ordered_sql += [ex_id_to_sql[idx]]
    return ordered_sql


# ---- Spider dev-set evaluation inputs ----
spider_train_path = "data/spider_queries/train_spider.json"
spider_dev_path = "data/spider_queries/dev.json"

train_examples = read_spider_examples(spider_train_path)
dev_examples = read_spider_examples(spider_dev_path)

gold_sql_dev = "data/spider_queries/dev_gold.sql"
gold_sql_list = read_lines_from_file_sql(gold_sql_dev)

# Predictions from the different model variants (direct SQL, QDMR formula,
# QDMR steps, grounded SQL, gold-SQL-trained); each file is one query per line.
pred_sql_dev = "data/predictions/t5_spider_bs_4_accum_32_epochs_100_lr_1e-4_checkpoint_95.txt"
pred_sql_list = read_lines_from_file_sql(pred_sql_dev)

# pred_qdmr_formula_dev = "data/predictions/preds_t5_spider_formula_data_aug_bs_4_accum_32_epochs_150_lr_1e-4_ckpt_62.txt"
pred_qdmr_formula_dev = "data/predictions/preds_t5_large_spider_formula_bs_1_accum_128_epochs_150_lr_1e-4_seed_42_ckpt_25.txt"
pred_qdmr_formula_list = read_lines_from_file_sql(pred_qdmr_formula_dev)

pred_qdmr_steps_dev = "data/predictions/preds_t5_spider_hasref_bs_4_accum_32_epochs_150_lr_1e-4_ckpt_63.txt"
pred_qdmr_steps_list = read_lines_from_file_sql(pred_qdmr_steps_dev)

# grounded_sql_preds = prepare_qdmr_grounded_sql_preds("data/predictions/groundings_spider_dev.json", len(gold_sql_list))

pred_grounded_sql_dev = "data/predictions/preds_t5_spider_groundsql_bs_4_accum_32_epochs_150_lr_1e-4_ckpt_144.txt"
grounded_sql_preds = read_lines_from_file_sql(pred_grounded_sql_dev)

pred_gold_sql_dev = "data/predictions/preds_t5_large_spider_gold_bs_1_accum_128_epochs_150_lr_1e-4_seed_42_ckpt_10.txt"
gold_sql_preds = read_lines_from_file_sql(pred_gold_sql_dev)

exec_accuracy = ExecutionAccuracy()
# Alternative Spider evaluations, kept commented for reference — enable one at
# a time by uncommenting the corresponding block:
# x = exec_accuracy.evaluate(examples=dev_examples,
#                            gold=gold_sql_list,
#                            predict=gold_sql_preds,
#                            db_dir="data/spider_databases",
#                            exec_set_match=True)

# x = exec_accuracy.evaluate(examples=dev_examples,
#                            gold=gold_sql_list,
#                            predict=grounded_sql_preds,#pred_sql_list,
#                            db_dir="data/spider_databases",
#                            exec_set_match=True)

# x = exec_accuracy.evaluate(examples=dev_examples,
#                            gold=gold_sql_list,
#                            predict=pred_qdmr_formula_list,
#                            db_dir="data/spider_databases",
#                            exec_set_match=True,
#                            pred_type="qdmr_formula")

# x = exec_accuracy.evaluate(examples=dev_examples,
#                            gold=gold_sql_list,
#                            predict=pred_qdmr_steps_list,
#                            db_dir="data/spider_databases",
#                            exec_set_match=True,
#                            pred_type="qdmr_steps")


# GEO880 evaluation

geo_train_path = "data/geo_queries/geo880_sql_train.json"
geo_dev_path = "data/geo_queries/geo880_sql_dev.json"
geo_test_path = "data/geo_queries/geo880_sql_test.json"

gold_sql_dev = "data/geo_queries/dev_gold_geo.sql"
gold_sql_list_dev = read_lines_from_file_sql(gold_sql_dev)

gold_sql_test = "data/geo_queries/test_gold_geo.sql"
gold_sql_list_test = read_lines_from_file_sql(gold_sql_test)

pred_qdmr_formula_test = "data/predictions/geo/t5_base_geo880_qdmr_formula_preds.txt"
pred_qdmr_formula_list = read_lines_from_file_sql(pred_qdmr_formula_test)

# Geo880 has a single database, so every example carries the "geo:" prefix.
dbs_list = ["geo:"] * len(gold_sql_list_test)
x = exec_accuracy.evaluate(examples=dbs_list,
                           gold=gold_sql_list_test,
                           predict=pred_qdmr_formula_list,
                           db_dir="data/geo_database",
                           exec_set_match=True,
                           pred_type="qdmr_formula",
                           dataset="geo")


print(x)

# -------- src/text_to_sql/train.py --------
import argparse
import glob
import os
import json
import time
import logging
import random
import re
from itertools import chain
from string import punctuation

from model import LoggingCallback, T5FineTuner
import numpy as np
import torch
import pytorch_lightning as pl


def set_seed(seed):
    """Seed python, numpy and torch (CPU and all visible GPUs) for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


if __name__ == '__main__':
    # Training configuration. NOTE(review): dataset='geo' with
    # data_dir='data/spider' and target_encoding='qdmr_sql' while the output
    # dir says "qdmr_formula" — confirm these match the intended experiment.
    args_dict = dict(
        dataset='geo',
        target_encoding='qdmr_sql',
        data_dir='data/spider',  # path for data files
        output_dir='',  # path to save the checkpoints
        db_dir='database',  # name of db dir in data dir
        training_set_file='qdmr_ground_enc_geo880_train.json',  # name of training set file in data dir
        dev_set_file='geo_dev.json',  # name of dev set file in data dir
        dev_set_sql='geo_dev.sql',  # name of file containing the dev sql queries in data dir
        test_set_file='',
        test_set_sql='',
        model_name_or_path='t5-base',
        tokenizer_name_or_path='t5-base',
        max_seq_length=512,
        learning_rate=1e-4,
        weight_decay=0.0,
        adam_epsilon=1e-8,
        warmup_steps=0,
        train_batch_size=2,
        eval_batch_size=2,
        num_train_epochs=2,
        gradient_accumulation_steps=64,
        n_gpu=1,
        early_stop_callback=False,
        fp_16=False,  # if you want to enable 16-bit training then install apex and set this to true
        opt_level='O1',
        # you can find out more on optimisation levels here https://nvidia.github.io/apex/amp.html#opt-levels-and-properties
        max_grad_norm=1.0,  # if you enable 16-bit training then set this to a sensible value, 0.5 is a good default
        seed=42,
    )

    # Per-run overrides of the defaults above.
    args_dict.update({'output_dir': 't5_geo_qdmr_formula_bs_2_accum_64_epochs_150_lr_1e-4',
                      'num_train_epochs': 150})
    args = argparse.Namespace(**args_dict)

    # logger = logging.getLogger(__name__)

    # set random seed
    set_seed(42)

    # Keep the single checkpoint with the best dev execution accuracy
    # (the commented variant selected on minimum training loss instead).
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        # filepath=args.output_dir, prefix="checkpoint", monitor="train_loss", mode="min", save_top_k=1
        filepath=args.output_dir, prefix="checkpoint", monitor="exec_acc", mode="max", save_top_k=1
    )

    train_params = dict(
        accumulate_grad_batches=args.gradient_accumulation_steps,
        gpus=args.n_gpu,
        max_epochs=args.num_train_epochs,
        early_stop_callback=False,
        precision=16 if args.fp_16 else 32,
        amp_level=args.opt_level,
        gradient_clip_val=args.max_grad_norm,
        checkpoint_callback=checkpoint_callback,
        callbacks=[LoggingCallback()],
    )

    # initialize model
    model = T5FineTuner(args)

    # initialize trainer
    trainer = pl.Trainer(**train_params)

    # start fine-tuning
    trainer.fit(model)