├── .env.template ├── .github └── workflows │ └── main.yml ├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── analyze_results_and_post_to_slack.py ├── auto_error_analysis.ipynb ├── correct_sql_instructions.ipynb ├── data ├── idk.csv ├── idk_bigquery.csv ├── instruct_advanced_bigquery.csv ├── instruct_advanced_mysql.csv ├── instruct_advanced_postgres.csv ├── instruct_advanced_sqlite.csv ├── instruct_advanced_tsql.csv ├── instruct_basic_bigquery.csv ├── instruct_basic_mysql.csv ├── instruct_basic_postgres.csv ├── instruct_basic_sqlite.csv ├── instruct_basic_tsql.csv ├── questions_gen_bigquery.csv ├── questions_gen_mysql.csv ├── questions_gen_postgres.csv ├── questions_gen_snowflake.csv ├── questions_gen_sqlite.csv └── questions_gen_tsql.csv ├── eval └── eval.py ├── gcs_eval.py ├── gcs_eval_checkpoints.py ├── main.py ├── prompts ├── README.md ├── prompt.md ├── prompt_anthropic.md ├── prompt_cot.md ├── prompt_cot_postgres.md ├── prompt_cot_sqlite.md ├── prompt_experimental.md ├── prompt_gemini.md ├── prompt_mistral.md ├── prompt_openai.json ├── prompt_openai_o1.json ├── prompt_qwen.json └── prompt_together.json ├── requirements.txt ├── requirements_test.txt ├── results_fn_bigquery ├── .env.yaml.template ├── main.py └── requirements.txt ├── results_fn_postgres ├── .env.yaml.template ├── create.sql ├── main.py └── requirements.txt ├── run_checkpoints.sh ├── run_checkpoints_adapters.sh ├── run_checkpoints_cot.sh ├── run_model_cot.sh ├── run_qwen.sh ├── runners ├── anthropic_runner.py ├── api_runner.py ├── bedrock_runner.py ├── deepseek_runner.py ├── gemini_runner.py ├── hf_runner.py ├── llama_cpp_runner.py ├── mistral_runner.py ├── mlx_runner.py ├── openai_runner.py ├── together_runner.py └── vllm_runner.py ├── tests ├── __init__.py ├── local_db_tests.py ├── test_eval.py └── test_utils_pruning.py ├── translate_sql_dialect.py ├── upload_wandb.py └── utils ├── aliases.py ├── api_server.py ├── creds.py ├── dialects.py ├── gen_prompt.py ├── llm.py ├── pruning.py ├── questions.py ├── reporting.py └── upload_report_gcloud.py /.env.template: -------------------------------------------------------------------------------- 1 | SQL_EVAL_UPLOAD_URL=YOUR_SQL_EVAL_UPLOAD_URL 2 | SLACK_BOT_TOKEN=YOUR_SLACK_BOT_TOKEN -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | on: [pull_request] 4 | 5 | jobs: 6 | lint: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v4 10 | - uses: psf/black@stable 11 | test: 12 | runs-on: ubuntu-latest 13 | needs: lint 14 | steps: 15 | - uses: actions/checkout@v4 16 | - name: Set up Python 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: '3.10' 20 | cache: 'pip' 21 | - name: Install pip dependencies 22 | run: | 23 | pip install --upgrade pip setuptools 24 | pip install -r requirements_test.txt 25 | pip install pytest 26 | - name: Run tests 27 | run: | 28 | pytest tests/test*.py 29 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/postgres 2 | data/*embeddings.pkl 3 | data/*ner_metadata.pkl 4 | results 5 | 6 | # credentials 7 | **/.env.yaml 8 | results_fn_bigquery/*.json 9 | results_fn_postgres/*.json 10 | 11 | # pycache 12 | **/__pycache__/ 13 | .pytest_cache 14 | 15 | # virtual envs 16 | *venv/ 17 | *env/ 18 | 19 | .vscode 20 | 21 | # all 
eda notebooks 22 | eda_*.ipynb 23 | 24 | # wandb output (created when running upload_wandb.ipynb) 25 | wandb/ 26 | 27 | # mac os specific 28 | .DS_Store -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project! We value your contributions and want to ensure a smooth and collaborative experience for everyone. Please take a moment to review the following guidelines. 4 | 5 | ## Table of Contents 6 | - [Installation](#installation) 7 | - [Linting](#linting) 8 | - [Testing](#testing) 9 | - [Submitting Changes](#submitting-changes) 10 | 11 | ## Installation 12 | 13 | Firstly, clone the repository where we store our database data and schema. Install all Python libraries listed in the `requirements.txt` file. Download the spacy model used in the NER heuristic for our [metadata-pruning method](https://github.com/defog-ai/sql-eval/blob/main/utils/pruning.py). Finally, install the library: 14 | ```bash 15 | git clone https://github.com/defog-ai/defog-data.git 16 | cd defog-data 17 | pip install -r requirements.txt 18 | pip install -e . 19 | ``` 20 | 21 | ## Linting 22 | 23 | We use [black](https://black.readthedocs.io/en/stable/) for code formatting and linting. After installing it via pip, you can automatically lint your code with black by adding it as a pre-commit git hook: 24 | ```bash 25 | pip install black 26 | echo -e '#!/bin/sh\n#\n# Run linter before commit\nblack $(git rev-parse --show-toplevel)' > .git/hooks/pre-commit && chmod +x .git/hooks/pre-commit 27 | ``` 28 | 29 | ## Testing 30 | 31 | [_Quis probabit ipsa probationem?_](https://en.wikipedia.org/wiki/Quis_custodiet_ipsos_custodes%3F) 32 | 33 | We have a comprehensive test suite that ensures the quality and reliability of our codebase. To run the python tests, you can use the following command: 34 | 35 | ```bash 36 | pytest tests 37 | ``` 38 | 39 | Our CI excludes [tests/verify_questions.py](tests/verify_questions.py) as it depends on having a local postgres environment with the data loaded. 40 | 41 | Please make sure that all tests pass before submitting your changes. 42 | 43 | We also understand that some changes might not be easily testable with unit tests. In such cases, please provide a detailed description of your changes and how you tested them. We will review your changes and work with you to ensure that they are tested and verified. 44 | 45 | ## Submitting Changes 46 | 47 | When submitting changes to this repository, please follow these steps: 48 | 49 | - Fork the repository and create a new branch for your changes. 50 | - Make your changes, following the coding style and best practices outlined here. 51 | - Run the tests to ensure your changes don't introduce any regressions. 52 | - Lint your code and [squash your commits](https://www.git-tower.com/learn/git/faq/git-squash) down to 1 single commit. 53 | - Commit your changes and push them to your forked repository. 54 | - Open a pull request to the main repository and provide a detailed description of your changes: 55 | - What problem you are trying to solve. 56 | - Alternatives considered and how you decided on which one to take 57 | - How you solved it. 58 | - How you tested your changes. 59 | - Your pull request will be reviewed by our team, and we may ask for further improvements or clarifications before merging. Thank you for your contribution! 
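One note on the Installation steps above: the snippet there does not show the spaCy model download that the text mentions. Assuming the metadata-pruning NER heuristic uses spaCy's small English pipeline (an assumption — check `utils/pruning.py` for the model it actually loads), that step would look roughly like:

```bash
# Assumed model name; verify against utils/pruning.py before relying on it
python -m spacy download en_core_web_sm
```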
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /analyze_results_and_post_to_slack.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pandas as pd 4 | from slack_sdk import WebClient 5 | from matplotlib import pyplot as plt 6 | import seaborn as sns 7 | 8 | if __name__ == "__main__": 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("-m", "--model_names", nargs="+", type=str, required=True) 11 | model_names = parser.parse_args().model_names 12 | 13 | # Load the results 14 | results = [] 15 | for model_name in model_names: 16 | fnames = [i for i in os.listdir(f"results/{model_name}") if i.endswith(".csv")] 17 | for fname in fnames: 18 | checkpoint = fname.split(model_name)[1].split("_")[1] 19 | eval_type = fname.split("_")[-1].replace(".csv", "") 20 | # get cot_inference type based on whether there is _cot in the filename 21 | if "_cot" in fname: 22 | cot_inference = "cot" 23 | else: 24 | cot_inference = "no_cot" 25 | tdf = pd.read_csv(f"results/{model_name}/{fname}") 26 | tdf["model"] = model_name 27 | tdf["checkpoint"] = checkpoint 28 | tdf["eval_type"] = eval_type 29 | tdf["cot_inference"] = cot_inference 30 | tdf["model_name"] = model_name 31 | results.append(tdf) 32 | 33 | results = pd.concat(results) 34 | 35 | # create a graph of the average correct for each model, with each model as a line and each checkpoint as a point on the x axis 36 | avg_correct = ( 37 | results.groupby(["model", "eval_type", "checkpoint", "cot_inference"])[ 38 | "correct" 39 | ] 40 | .mean() 41 | .reset_index() 42 | ) 43 | avg_correct = avg_correct.melt( 44 | id_vars=["model", "eval_type", "checkpoint", "cot_inference"], 45 | var_name="metric", 46 | value_name="correct_pct", 47 | ) 48 | # arrange order of eval_type to be basic, v1, advanced, idk 49 | avg_correct["eval_type"] = pd.Categorical( 50 | avg_correct["eval_type"], 51 | categories=["basic", "v1", "advanced", "idk"], 52 | ordered=True, 53 | ) 54 | print(avg_correct.drop(columns=["metric"])) 55 | facet_plot = sns.relplot( 56 | data=avg_correct, 57 | x="checkpoint", 58 | y="correct_pct", 59 | hue="model", 60 | style="cot_inference", 61 | col="eval_type", 62 | kind="line", 63 | col_wrap=3, 64 | ) 65 | # add grid lines to all subplots 66 | for ax in facet_plot.axes: 67 | 
ax.grid(True, linestyle="--") 68 | 69 | plt.show() 70 | # save the graph 71 | # this will get overwritten each time the script is run, but that's okay 72 | facet_plot.figure.savefig(f"results/avg_correct_{model_name}.png") 73 | fnames = sorted( 74 | [i for i in os.listdir(f"results/{model_name}") if i.endswith(".csv")] 75 | ) 76 | fnames = "\n".join([i.replace(".csv", "") for i in fnames]) 77 | 78 | # post the graph to slack 79 | slack_client = WebClient(token=os.environ["SLACK_BOT_TOKEN"]) 80 | slack_client.files_upload_v2( 81 | channel="C07940SRVM5", # id of the eval-results channel 82 | title=f"Average Correct for {model_name}", 83 | file=f"results/avg_correct_{model_name}.png", 84 | initial_comment=f"""A set of evals just finished running for model `{model_name}`! The graph below has the average correct rate for each model and each checkpoint that was in the evals (excluding idk questions). 85 | Additionally, if you want to see the raw data for any run in eval-visualizer, you can paste one of the following run names into the Eval Visualizer search bar: 86 | 87 | ``` 88 | {fnames} 89 | ``` 90 | """, 91 | ) 92 | -------------------------------------------------------------------------------- /data/instruct_basic_postgres.csv: -------------------------------------------------------------------------------- 1 | db_name,query_category,question,query 2 | broker,basic_join_date_group_order_limit,"What are the top 5 countries by total transaction amount in the past 30 days, inclusive of 30 days ago? Return the country name, number of transactions and total transaction amount.","SELECT c.sbCustCountry, COUNT(t.sbTxId) AS num_transactions, SUM(t.sbTxAmount) AS total_amount FROM sbCustomer c JOIN sbTransaction t ON c.sbCustId = t.sbTxCustId WHERE t.sbTxDateTime >= CURRENT_DATE - INTERVAL '30 days' GROUP BY c.sbCustCountry ORDER BY total_amount DESC LIMIT 5" 3 | broker,basic_join_date_group_order_limit,"How many distinct customers made each type of transaction between Jan 1, 2023 and Mar 31, 2023 (inclusive of start and end dates)? Return the transaction type, number of distinct customers and average number of shares, for the top 3 transaction types by number of customers.","SELECT t.sbTxType, COUNT(DISTINCT t.sbTxCustId) AS num_customers, AVG(t.sbTxShares) AS avg_shares FROM sbTransaction t WHERE t.sbTxDateTime BETWEEN '2023-01-01' AND '2023-03-31 23:59:59' GROUP BY t.sbTxType ORDER BY num_customers DESC LIMIT 3" 4 | broker,basic_join_group_order_limit,"What are the top 10 ticker symbols by total transaction amount? Return the ticker symbol, number of transactions and total transaction amount.","SELECT tk.sbTickerSymbol, COUNT(tx.sbTxId) AS num_transactions, SUM(tx.sbTxAmount) AS total_amount FROM sbTicker tk JOIN sbTransaction tx ON tk.sbTickerId = tx.sbTxTickerId GROUP BY tk.sbTickerSymbol ORDER BY total_amount DESC LIMIT 10" 5 | broker,basic_join_group_order_limit,"What are the top 5 combinations of customer state and ticker type by number of transactions? 
Return the customer state, ticker type and number of transactions.","SELECT c.sbCustState, t.sbTickerType, COUNT(*) AS num_transactions FROM sbTransaction tx JOIN sbCustomer c ON tx.sbTxCustId = c.sbCustId JOIN sbTicker t ON tx.sbTxTickerId = t.sbTickerId GROUP BY c.sbCustState, t.sbTickerType ORDER BY num_transactions DESC LIMIT 5" 6 | broker,basic_join_distinct,Return the distinct list of customer IDs who have made a 'buy' transaction.,SELECT DISTINCT c.sbCustId FROM sbCustomer c JOIN sbTransaction t ON c.sbCustId = t.sbTxCustId WHERE t.sbTxType = 'buy' 7 | broker,basic_join_distinct,"Return the distinct list of ticker IDs that have daily price records on or after Apr 1, 2023.",SELECT DISTINCT tk.sbTickerId FROM sbTicker tk JOIN sbDailyPrice dp ON tk.sbTickerId = dp.sbDpTickerId WHERE dp.sbDpDate >= '2023-04-01' 8 | broker,basic_group_order_limit,What are the top 3 transaction statuses by number of transactions? Return the status and number of transactions.,"SELECT sbTxStatus, COUNT(*) AS num_transactions FROM sbTransaction GROUP BY sbTxStatus ORDER BY num_transactions DESC LIMIT 3" 9 | broker,basic_group_order_limit,What are the top 5 countries by number of customers? Return the country name and number of customers.,"SELECT sbCustCountry, COUNT(*) AS num_customers FROM sbCustomer GROUP BY sbCustCountry ORDER BY num_customers DESC LIMIT 5" 10 | broker,basic_left_join,Return the customer ID and name of customers who have not made any transactions.,"SELECT c.sbCustId, c.sbCustName FROM sbCustomer c LEFT JOIN sbTransaction t ON c.sbCustId = t.sbTxCustId WHERE t.sbTxCustId IS NULL" 11 | broker,basic_left_join,Return the ticker ID and symbol of tickers that do not have any daily price records.,"SELECT tk.sbTickerId, tk.sbTickerSymbol FROM sbTicker tk LEFT JOIN sbDailyPrice dp ON tk.sbTickerId = dp.sbDpTickerId WHERE dp.sbDpTickerId IS NULL" 12 | car_dealership,basic_join_date_group_order_limit,"Who were the top 3 sales representatives by total revenue in the past 3 months, inclusive of today's date? Return their first name, last name, total number of sales and total revenue. Note that revenue refers to the sum of sale_price in the sales table.","SELECT c.first_name, c.last_name, COUNT(s.id) AS total_sales, SUM(s.sale_price) AS total_revenue FROM sales s JOIN salespersons c ON s.salesperson_id = c.id WHERE s.sale_date >= CURRENT_DATE - INTERVAL '3 months' GROUP BY c.first_name, c.last_name ORDER BY total_revenue DESC LIMIT 3" 13 | car_dealership,basic_join_date_group_order_limit,"Return the top 5 salespersons by number of sales in the past 30 days? Return their first and last name, total sales count and total revenue amount.","SELECT sp.first_name, sp.last_name, COUNT(s.id) AS total_sales, SUM(s.sale_price) AS total_revenue FROM sales s JOIN salespersons sp ON s.salesperson_id = sp.id WHERE s.sale_date >= CURRENT_DATE - INTERVAL '30 days' GROUP BY sp.first_name, sp.last_name, sp.id ORDER BY total_sales DESC LIMIT 5" 14 | car_dealership,basic_join_group_order_limit,"Return the top 5 states by total revenue, showing the number of unique customers and total revenue (based on sale price) for each state.","SELECT c.state, COUNT(DISTINCT s.customer_id) AS unique_customers, SUM(s.sale_price) AS total_revenue FROM sales s JOIN customers c ON s.customer_id = c.id GROUP BY c.state ORDER BY total_revenue DESC LIMIT 5" 15 | car_dealership,basic_join_group_order_limit,"What are the top 5 best selling car models by total revenue? 
Return the make, model, total number of sales and total revenue.","SELECT c.make, c.model, COUNT(s.id) AS total_sales, SUM(s.sale_price) AS total_revenue FROM sales s JOIN cars c ON s.car_id = c.id GROUP BY c.make, c.model ORDER BY total_revenue DESC LIMIT 5" 16 | car_dealership,basic_join_distinct,"Return the distinct list of customer IDs that have made a purchase, based on joining the customers and sales tables.",SELECT DISTINCT c.id AS customer_id FROM customers c JOIN sales s ON c.id = s.customer_id 17 | car_dealership,basic_join_distinct,"Return the distinct list of salesperson IDs that have received a cash payment, based on joining the salespersons, sales and payments_received tables.",SELECT DISTINCT s.id AS salesperson_id FROM salespersons s JOIN sales sa ON s.id = sa.salesperson_id JOIN payments_received p ON sa.id = p.sale_id WHERE p.payment_method = 'cash' 18 | car_dealership,basic_group_order_limit,"What are the top 3 payment methods by total payment amount received? Return the payment method, total number of payments and total amount.","SELECT payment_method, COUNT(*) AS total_payments, SUM(payment_amount) AS total_amount FROM payments_received GROUP BY payment_method ORDER BY total_amount DESC LIMIT 3" 19 | car_dealership,basic_group_order_limit,"What are the total number of customer signups for the top 2 states? Return the state and total signups, starting from the top.","SELECT state, COUNT(*) AS total_signups FROM customers GROUP BY state ORDER BY total_signups DESC LIMIT 2" 20 | car_dealership,basic_left_join,"Return the car ID, make, model and year for cars that have no sales records, by doing a left join from the cars to sales table.","SELECT c.id AS car_id, c.make, c.model, c.year FROM cars c LEFT JOIN sales s ON c.id = s.car_id WHERE s.car_id IS NULL" 21 | car_dealership,basic_left_join,"Return the salesperson ID, first name and last name for salespersons that have no sales records, by doing a left join from the salespersons to sales table.","SELECT s.id AS salesperson_id, s.first_name, s.last_name FROM salespersons s LEFT JOIN sales sa ON s.id = sa.salesperson_id WHERE sa.salesperson_id IS NULL" 22 | derm_treatment,basic_join_date_group_order_limit,"What are the top 3 doctor specialties by total drug amount prescribed for treatments started in the past 6 calendar months? Return the specialty, number of treatments, and total drug amount.","SELECT d.specialty, COUNT(*) AS num_treatments, SUM(t.tot_drug_amt) AS total_drug_amt FROM treatments t JOIN doctors d ON t.doc_id = d.doc_id WHERE t.start_dt >= DATE_TRUNC('month', CURRENT_DATE - INTERVAL '6 months') GROUP BY d.specialty ORDER BY total_drug_amt DESC LIMIT 3" 23 | derm_treatment,basic_join_date_group_order_limit,"For treatments that ended in the year 2022 (from Jan 1st to Dec 31st inclusive), what is the average PASI score at day 100 and number of distinct patients per insurance type? Return the top 5 insurance types sorted by lowest average PASI score first.","SELECT p.ins_type, COUNT(DISTINCT t.patient_id) AS num_patients, AVG(o.day100_pasi_score) AS avg_pasi_score FROM treatments t JOIN patients p ON t.patient_id = p.patient_id JOIN outcomes o ON t.treatment_id = o.treatment_id WHERE t.end_dt BETWEEN '2022-01-01' AND '2022-12-31' GROUP BY p.ins_type ORDER BY avg_pasi_score LIMIT 5" 24 | derm_treatment,basic_join_group_order_limit,"What are the top 5 drugs by number of treatments and average drug amount per treatment? 
Return the drug name, number of treatments, and average drug amount.","SELECT d.drug_name, COUNT(*) AS num_treatments, AVG(t.tot_drug_amt) AS avg_drug_amt FROM treatments t JOIN drugs d ON t.drug_id = d.drug_id GROUP BY d.drug_name ORDER BY num_treatments DESC, avg_drug_amt DESC LIMIT 5" 25 | derm_treatment,basic_join_group_order_limit,"What are the top 3 diagnoses by maximum itch VAS score at day 100 and number of distinct patients? Return the diagnosis name, number of patients, and maximum itch score.","SELECT di.diag_name, COUNT(DISTINCT t.patient_id) AS num_patients, MAX(o.day100_itch_vas) AS max_itch_score FROM treatments t JOIN diagnoses di ON t.diag_id = di.diag_id JOIN outcomes o ON t.treatment_id = o.treatment_id GROUP BY di.diag_name ORDER BY max_itch_score DESC, num_patients DESC LIMIT 3" 26 | derm_treatment,basic_join_distinct,"Return the distinct list of doctor IDs, first names and last names that have prescribed treatments.","SELECT DISTINCT d.doc_id, d.first_name, d.last_name FROM treatments t JOIN doctors d ON t.doc_id = d.doc_id" 27 | derm_treatment,basic_join_distinct,"Return the distinct list of patient IDs, first names and last names that have outcome assessments.","SELECT DISTINCT p.patient_id, p.first_name, p.last_name FROM outcomes o JOIN treatments t ON o.treatment_id = t.treatment_id JOIN patients p ON t.patient_id = p.patient_id" 28 | derm_treatment,basic_group_order_limit,"What are the top 3 insurance types by average patient height in cm? Return the insurance type, average height and average weight.","SELECT ins_type, AVG(height_cm) AS avg_height, AVG(weight_kg) AS avg_weight FROM patients GROUP BY ins_type ORDER BY avg_height DESC LIMIT 3" 29 | derm_treatment,basic_group_order_limit,What are the top 2 specialties by number of doctors? Return the specialty and number of doctors.,"SELECT specialty, COUNT(*) AS num_doctors FROM doctors GROUP BY specialty ORDER BY num_doctors DESC LIMIT 2" 30 | derm_treatment,basic_left_join,"Return the patient IDs, first names and last names of patients who have not received any treatments.","SELECT p.patient_id, p.first_name, p.last_name FROM patients p LEFT JOIN treatments t ON p.patient_id = t.patient_id WHERE t.patient_id IS NULL" 31 | derm_treatment,basic_left_join,Return the drug IDs and names of drugs that have not been used in any treatments.,"SELECT d.drug_id, d.drug_name FROM drugs d LEFT JOIN treatments t ON d.drug_id = t.drug_id WHERE t.drug_id IS NULL" 32 | ewallet,basic_join_date_group_order_limit,"Who are the top 2 merchants (receiver type 1) by total transaction amount in the past 150 days (inclusive of 150 days ago)? Return the merchant name, total number of transactions, and total transaction amount.","SELECT m.name AS merchant_name, COUNT(t.txid) AS total_transactions, SUM(t.amount) AS total_amount FROM consumer_div.merchants m JOIN consumer_div.wallet_transactions_daily t ON m.mid = t.receiver_id WHERE t.receiver_type = 1 AND t.created_at >= CURRENT_DATE - INTERVAL '150 days' GROUP BY m.name ORDER BY total_amount DESC LIMIT 2" 33 | ewallet,basic_join_date_group_order_limit,"How many distinct active users sent money per month in 2023? Return the number of active users per month (as a date), starting from the earliest date. Do not include merchants in the query. 
Only include successful transactions.","SELECT DATE_TRUNC('month', t.created_at) AS MONTH, COUNT(DISTINCT t.sender_id) AS active_users FROM consumer_div.wallet_transactions_daily t JOIN consumer_div.users u ON t.sender_id = u.uid WHERE t.sender_type = 0 AND t.status = 'success' AND u.status = 'active' AND t.created_at >= '2023-01-01' AND t.created_at < '2024-01-01' GROUP BY MONTH ORDER BY MONTH;SELECT DATE_TRUNC('month', w.created_at)::DATE AS MONTH, COUNT(DISTINCT w.sender_id) AS active_user_count FROM consumer_div.wallet_transactions_daily w JOIN consumer_div.users u ON w.sender_id = u.uid WHERE w.sender_type = 0 AND w.status = 'success' AND w.created_at >= '2023-01-01' AND w.created_at < '2024-01-01' AND u.status = 'active' GROUP BY DATE_TRUNC('month', w.created_at) ORDER BY MONTH ASC;" 34 | ewallet,basic_join_group_order_limit,"What are the top 3 most frequently used coupon codes? Return the coupon code, total number of redemptions, and total amount redeemed.","SELECT c.code AS coupon_code, COUNT(t.txid) AS redemption_count, SUM(t.amount) AS total_discount FROM consumer_div.coupons c JOIN consumer_div.wallet_transactions_daily t ON c.cid = t.coupon_id GROUP BY c.code ORDER BY redemption_count DESC LIMIT 3" 35 | ewallet,basic_join_group_order_limit,"Which are the top 5 countries by total transaction amount sent by users, sender_type = 0? Return the country, number of distinct users who sent, and total transaction amount.","SELECT u.country, COUNT(DISTINCT t.sender_id) AS user_count, SUM(t.amount) AS total_amount FROM consumer_div.users u JOIN consumer_div.wallet_transactions_daily t ON u.uid = t.sender_id WHERE t.sender_type = 0 GROUP BY u.country ORDER BY total_amount DESC LIMIT 5" 36 | ewallet,basic_join_distinct,"Return the distinct list of merchant IDs that have received money from a transaction. 
Consider all transaction types in the results you return, but only include the merchant ids in your final answer.",SELECT DISTINCT m.mid AS merchant_id FROM consumer_div.merchants m JOIN consumer_div.wallet_transactions_daily t ON m.mid = t.receiver_id WHERE t.receiver_type = 1 37 | ewallet,basic_join_distinct,Return the distinct list of user IDs who have received transaction notifications.,SELECT DISTINCT user_id FROM consumer_div.notifications WHERE type = 'transaction' 38 | ewallet,basic_group_order_limit,What are the top 3 most common transaction statuses and their respective counts?,"SELECT status, COUNT(*) AS COUNT FROM consumer_div.wallet_transactions_daily GROUP BY status ORDER BY COUNT DESC LIMIT 3" 39 | ewallet,basic_group_order_limit,What are the top 2 most frequently used device types for user sessions and their respective counts?,"SELECT device_type, COUNT(*) AS COUNT FROM consumer_div.user_sessions GROUP BY device_type ORDER BY COUNT DESC LIMIT 2" 40 | ewallet,basic_left_join,Return users (user ID and username) who have not received any notifications,"SELECT u.uid, u.username FROM consumer_div.users u LEFT JOIN consumer_div.notifications n ON u.uid = n.user_id WHERE n.id IS NULL" 41 | ewallet,basic_left_join,Return merchants (merchant ID and name) who have not issued any coupons.,"SELECT m.mid AS merchant_id, m.name AS merchant_name FROM consumer_div.merchants m LEFT JOIN consumer_div.coupons c ON m.mid = c.merchant_id WHERE c.cid IS NULL" 42 | -------------------------------------------------------------------------------- /gcs_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import time 4 | 5 | # edit these 4 paths as per your setup 6 | # GCS_MODEL_DIR: gcs path where the models are stored. 7 | # this should be the same as GCS_MODEL_DIR in model_uploader.py 8 | # GCS_MODEL_EVAL_DIR: gcs path where the evaluated models will be shifted to 9 | # LOCAL_MODEL_DIR: local path where the models will be downloaded 10 | # SQL_EVAL_DIR: local path where the sql-eval repo is cloned 11 | GCS_MODEL_DIR = "gs://defog-finetuning/fsdp_wrong_sql_eval" 12 | GCS_MODEL_EVAL_DIR = "gs://defog-finetuning/fsdp_evaluated" 13 | LOCAL_MODEL_DIR = os.path.expanduser("/models/fsdp") 14 | SQL_EVAL_DIR = os.path.expanduser("~/sql-eval") 15 | # edit the question files, prompt files and output files as you desire. 16 | # they should have the same length, as they will be zipped and iterated through 17 | # in the vllm runner. 
18 | os.makedirs(LOCAL_MODEL_DIR, exist_ok=True) 19 | os.chdir(SQL_EVAL_DIR) # for executing sql-eval commands 20 | # edit the run configs as per your requirements 21 | NUM_BEAMS = 1 22 | 23 | 24 | def download_evaluate(): 25 | while True: 26 | existing_models = ( 27 | subprocess.run(["gsutil", "ls", GCS_MODEL_DIR], capture_output=True) 28 | .stdout.decode("utf-8") 29 | .split("\n") 30 | ) 31 | for gcs_model_path in existing_models: 32 | model_name = ( 33 | gcs_model_path.replace(GCS_MODEL_DIR, "").replace("/", "").strip() 34 | ) 35 | if not model_name: 36 | continue 37 | local_model_path = os.path.join(LOCAL_MODEL_DIR, model_name) 38 | if not os.path.exists(local_model_path): 39 | print(f"Downloading model: {model_name}") 40 | # download from gcs 41 | subprocess.run( 42 | ["gsutil", "-m", "cp", "-r", gcs_model_path, LOCAL_MODEL_DIR] 43 | ) 44 | else: 45 | print(f"Model folder exists: {model_name}") 46 | try: 47 | # run evaluation 48 | # python3 main.py \ 49 | # -db postgres \ 50 | # -q data/instruct_basic_postgres.csv data/instruct_advanced_postgres.csv data/questions_gen_postgres.csv \ 51 | # -o "results/${model_name}_beam4_basic.csv" "results/${model_name}_beam4_advanced.csv" "results/${model_name}_beam4_v1.csv" \ 52 | # -g vllm \ 53 | # -b 4 \ 54 | # -c 0 \ 55 | # -f "prompts/prompt.md" \ 56 | # -m "/models/fsdp/${model_name}" 57 | question_files = [ 58 | "data/instruct_basic_postgres.csv", 59 | "data/instruct_advanced_postgres.csv", 60 | "data/questions_gen_postgres.csv", 61 | ] 62 | prompt_file = "prompts/prompt.md" 63 | output_files = [ 64 | f"results/{model_name}_beam{NUM_BEAMS}_basic.csv", 65 | f"results/{model_name}_beam{NUM_BEAMS}_advanced.csv", 66 | f"results/{model_name}_beam{NUM_BEAMS}_v1.csv", 67 | ] 68 | subprocess.run( 69 | [ 70 | "python3", 71 | "main.py", 72 | "-db", 73 | "postgres", 74 | "-q", 75 | *question_files, 76 | "-o", 77 | *output_files, 78 | "-g", 79 | "vllm", 80 | "-b", 81 | str(NUM_BEAMS), 82 | "-c", 83 | "0", 84 | "-f", 85 | prompt_file, 86 | "-m", 87 | local_model_path, 88 | "-bs", 89 | "16", 90 | ], 91 | check=True, 92 | ) 93 | # move the model to the evaluated directory once evaluated successfully 94 | subprocess.run( 95 | ["gsutil", "-m", "mv", gcs_model_path, GCS_MODEL_EVAL_DIR], 96 | check=True, 97 | ) 98 | subprocess.run(["rm", "-rf", local_model_path], check=True) 99 | except Exception as e: 100 | print(f"Error in evaluation: {e}") 101 | exit(1) 102 | time.sleep(10) 103 | 104 | 105 | if __name__ == "__main__": 106 | download_evaluate() 107 | -------------------------------------------------------------------------------- /gcs_eval_checkpoints.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import time 4 | 5 | from transformers import AutoTokenizer 6 | 7 | # Alternate version of gcs_eval if you're working with nested checkpoint folders 8 | # with model weights instead of model weight folders directly 9 | 10 | # edit these 4 paths as per your setup 11 | # GCS_MODEL_DIR: gcs path where the models are stored. 
12 | # this should be the same as GCS_MODEL_DIR in model_uploader.py 13 | # GCS_MODEL_EVAL_DIR: gcs path where the evaluated models will be shifted to 14 | # LOCAL_MODEL_DIR: local path where the models will be downloaded 15 | # SQL_EVAL_DIR: local path where the sql-eval repo is cloned 16 | GCS_MODEL_DIR = "gs://defog-finetuning/fft" 17 | GCS_MODEL_EVAL_DIR = "gs://defog-finetuning/fft_evaluated" 18 | LOCAL_MODEL_DIR = os.path.expanduser("/models/fft") 19 | SQL_EVAL_DIR = os.path.expanduser("~/sql-eval") 20 | # edit the question files, prompt files and output files as you desire. 21 | # they should have the same length, as they will be zipped and iterated through 22 | # in the vllm runner. 23 | os.makedirs(LOCAL_MODEL_DIR, exist_ok=True) 24 | os.chdir(SQL_EVAL_DIR) # for executing sql-eval commands 25 | # edit the run configs as per your requirements 26 | NUM_BEAMS = 1 27 | TOKENIZER_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct" 28 | 29 | 30 | def check_and_save_tokenizer(dir: str): 31 | if not os.path.exists(os.path.join(dir, "tokenizer_config.json")): 32 | print(f"Saving tokenizer in {dir}") 33 | tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_MODEL) 34 | tokenizer.save_pretrained(dir) 35 | 36 | 37 | def download_evaluate(): 38 | while True: 39 | existing_models = ( 40 | subprocess.run(["gsutil", "ls", GCS_MODEL_DIR], capture_output=True) 41 | .stdout.decode("utf-8") 42 | .split("\n") 43 | ) 44 | existing_checkpoints = [] 45 | for existing_model_folder in existing_models: 46 | results = ( 47 | subprocess.run( 48 | ["gsutil", "ls", existing_model_folder], capture_output=True 49 | ) 50 | .stdout.decode("utf-8") 51 | .split("\n") 52 | ) 53 | for path in results: 54 | if path.startswith(GCS_MODEL_DIR) and "checkpoint" in path: 55 | existing_checkpoints.append(path) 56 | print("Existing checkpoints:") 57 | for ec in existing_checkpoints: 58 | print(ec) 59 | # sort existing checkpoints lexically 60 | existing_checkpoints.sort() 61 | for gcs_model_checkpoint_path in existing_checkpoints: 62 | run_name_checkpoint = gcs_model_checkpoint_path.replace( 63 | GCS_MODEL_DIR, "" 64 | ).strip(" /") 65 | if not run_name_checkpoint: 66 | print("No model found, skipping.") 67 | continue 68 | local_model_path = os.path.join(LOCAL_MODEL_DIR, run_name_checkpoint) 69 | run_name = run_name_checkpoint.split("/checkpoint-", 1)[0] 70 | print(f"Model name: {run_name_checkpoint}") 71 | if not os.path.exists(local_model_path): 72 | local_run_name_folder = os.path.join(LOCAL_MODEL_DIR, run_name) 73 | os.makedirs(local_run_name_folder, exist_ok=True) 74 | # download from gcs's checkpoint folder into a run name folder 75 | print( 76 | f"Downloading from {gcs_model_checkpoint_path} to {local_run_name_folder}" 77 | ) 78 | subprocess.run( 79 | [ 80 | "gsutil", 81 | "-m", 82 | "cp", 83 | "-r", 84 | gcs_model_checkpoint_path, 85 | local_run_name_folder, 86 | ] 87 | ) 88 | else: 89 | print(f"Model folder exists: {run_name_checkpoint}") 90 | check_and_save_tokenizer(local_model_path) 91 | try: 92 | # run evaluation 93 | # python3 main.py \ 94 | # -db postgres \ 95 | # -q data/instruct_basic_postgres.csv data/instruct_advanced_postgres.csv data/questions_gen_postgres.csv \ 96 | # -o "results/${run_name_checkpoint}_beam4_basic.csv" "results/${run_name_checkpoint}_beam4_advanced.csv" "results/${run_name_checkpoint}_beam4_v1.csv" \ 97 | # -g vllm \ 98 | # -b 4 \ 99 | # -c 0 \ 100 | # -f "prompts/prompt.md" \ 101 | # -m "/models/fsdp/${run_name_checkpoint}" 102 | question_files = [ 103 | "data/instruct_basic_postgres.csv", 104 
| "data/instruct_advanced_postgres.csv", 105 | "data/questions_gen_postgres.csv", 106 | ] 107 | prompt_file = "prompts/prompt.md" 108 | output_files = [ 109 | f"results/{run_name_checkpoint}_beam{NUM_BEAMS}_basic.csv", 110 | f"results/{run_name_checkpoint}_beam{NUM_BEAMS}_advanced.csv", 111 | f"results/{run_name_checkpoint}_beam{NUM_BEAMS}_v1.csv", 112 | ] 113 | os.makedirs(os.path.join("results", run_name), exist_ok=True) 114 | subprocess.run( 115 | [ 116 | "python3", 117 | "main.py", 118 | "-db", 119 | "postgres", 120 | "-q", 121 | *question_files, 122 | "-o", 123 | *output_files, 124 | "-g", 125 | "vllm", 126 | "-b", 127 | str(NUM_BEAMS), 128 | "-c", 129 | "0", 130 | "-f", 131 | prompt_file, 132 | "-m", 133 | local_model_path, 134 | "-bs", 135 | "200", 136 | ], 137 | check=True, 138 | ) 139 | # make model directory in gcs 140 | subprocess.run( 141 | [ 142 | "gsutil", 143 | "mkdir", 144 | f"{GCS_MODEL_EVAL_DIR}/{run_name}", 145 | ] 146 | ) 147 | # move the model to the evaluated directory once evaluated successfully 148 | subprocess.run( 149 | [ 150 | "gsutil", 151 | "-m", 152 | "mv", 153 | gcs_model_checkpoint_path, 154 | f"{GCS_MODEL_EVAL_DIR}/{run_name}", 155 | ], 156 | check=True, 157 | ) 158 | subprocess.run(["rm", "-rf", local_model_path], check=True) 159 | except Exception as e: 160 | print(f"Error in evaluation: {e}") 161 | exit(1) 162 | time.sleep(10) 163 | 164 | 165 | if __name__ == "__main__": 166 | download_evaluate() 167 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | if __name__ == "__main__": 5 | parser = argparse.ArgumentParser() 6 | # data-related parameters 7 | parser.add_argument("-q", "--questions_file", nargs="+", type=str, required=True) 8 | parser.add_argument("-n", "--num_questions", type=int, default=None) 9 | parser.add_argument("-db", "--db_type", type=str, required=True) 10 | parser.add_argument("-d", "--use_private_data", action="store_true") 11 | parser.add_argument("-dp", "--decimal_points", type=int, default=None) 12 | # model-related parameters 13 | parser.add_argument("-g", "--model_type", type=str, required=True) 14 | parser.add_argument("-m", "--model", type=str) 15 | parser.add_argument("-a", "--adapter", type=str) # path to adapter 16 | parser.add_argument( 17 | "-an", "--adapter_name", type=str, default=None 18 | ) # only for use with production server 19 | parser.add_argument("--api_url", type=str) 20 | parser.add_argument("--api_type", type=str) 21 | # inference-technique-related parameters 22 | parser.add_argument("-f", "--prompt_file", nargs="+", type=str, required=True) 23 | parser.add_argument("-b", "--num_beams", type=int, default=1) 24 | parser.add_argument( 25 | "-bs", "--batch_size", type=int, default=4 26 | ) # batch size, only relevant for the hf runner 27 | parser.add_argument("-c", "--num_columns", type=int, default=0) 28 | parser.add_argument("-s", "--shuffle_metadata", action="store_true") 29 | parser.add_argument("-k", "--k_shot", action="store_true") 30 | parser.add_argument( 31 | "--cot_table_alias", type=str, choices=["instruct", "pregen", ""], default="" 32 | ) 33 | parser.add_argument("--thinking", action="store_true") 34 | # execution-related parameters 35 | parser.add_argument("-o", "--output_file", nargs="+", type=str, required=True) 36 | parser.add_argument("-p", "--parallel_threads", type=int, default=5) 37 | parser.add_argument("-t", "--timeout_gen", 
type=float, default=30.0) 38 | parser.add_argument("-u", "--timeout_exec", type=float, default=10.0) 39 | parser.add_argument("-v", "--verbose", action="store_true") 40 | parser.add_argument("-l", "--logprobs", action="store_true") 41 | parser.add_argument("--upload_url", type=str) 42 | parser.add_argument("--run_name", type=str, required=False) 43 | parser.add_argument( 44 | "-qz", "--quantized", default=False, action=argparse.BooleanOptionalAction 45 | ) 46 | 47 | args = parser.parse_args() 48 | 49 | # if questions_file is None, set it to the default questions file for the given db_type 50 | if args.questions_file is None: 51 | args.questions_file = f"data/questions_gen_{args.db_type}.csv" 52 | 53 | # check that questions_file matches db_type 54 | for questions_file in args.questions_file: 55 | if args.db_type not in questions_file and questions_file != "data/idk.csv": 56 | print( 57 | f"WARNING: Check that questions_file {questions_file} is compatible with db_type {args.db_type}" 58 | ) 59 | 60 | if args.upload_url is None: 61 | args.upload_url = os.environ.get("SQL_EVAL_UPLOAD_URL") 62 | 63 | # check args 64 | # check that either args.questions_file > 1 and args.prompt_file = 1 or vice versa 65 | if ( 66 | len(args.questions_file) > 1 67 | and len(args.prompt_file) == 1 68 | and len(args.output_file) > 1 69 | ): 70 | args.prompt_file = args.prompt_file * len(args.questions_file) 71 | elif ( 72 | len(args.questions_file) == 1 73 | and len(args.prompt_file) > 1 74 | and len(args.output_file) > 1 75 | ): 76 | args.questions_file = args.questions_file * len(args.prompt_file) 77 | if not (len(args.questions_file) == len(args.prompt_file) == len(args.output_file)): 78 | raise ValueError( 79 | "If args.output_file > 1, then at least 1 of args.prompt_file or args.questions_file must be > 1 and match lengths." 80 | f"Obtained lengths: args.questions_file={len(args.questions_file)}, args.prompt_file={len(args.prompt_file)}, args.output_file={len(args.output_file)}" 81 | ) 82 | 83 | if args.model_type == "oa": 84 | from runners.openai_runner import run_openai_eval 85 | 86 | if args.model is None: 87 | args.model = "gpt-4o" 88 | run_openai_eval(args) 89 | elif args.model_type == "anthropic": 90 | from runners.anthropic_runner import run_anthropic_eval 91 | 92 | if args.model is None: 93 | args.model = "claude-2" 94 | run_anthropic_eval(args) 95 | elif args.model_type == "vllm": 96 | import platform 97 | 98 | if platform.system() == "Darwin": 99 | raise ValueError( 100 | "vLLM is not supported on macOS. Please run on another OS supporting CUDA." 
101 | ) 102 | from runners.vllm_runner import run_vllm_eval 103 | 104 | run_vllm_eval(args) 105 | elif args.model_type == "hf": 106 | from runners.hf_runner import run_hf_eval 107 | 108 | run_hf_eval(args) 109 | elif args.model_type == "api": 110 | assert args.api_url is not None, "api_url must be provided for api model" 111 | assert args.api_type is not None, "api_type must be provided for api model" 112 | assert args.api_type in [ 113 | "openai", 114 | "vllm", 115 | "tgi", 116 | ], "api_type must be one of 'openai', 'vllm', 'tgi'" 117 | 118 | from runners.api_runner import run_api_eval 119 | 120 | run_api_eval(args) 121 | elif args.model_type == "llama_cpp": 122 | from runners.llama_cpp_runner import run_llama_cpp_eval 123 | 124 | run_llama_cpp_eval(args) 125 | elif args.model_type == "mlx": 126 | from runners.mlx_runner import run_mlx_eval 127 | 128 | run_mlx_eval(args) 129 | elif args.model_type == "gemini": 130 | from runners.gemini_runner import run_gemini_eval 131 | 132 | run_gemini_eval(args) 133 | elif args.model_type == "mistral": 134 | from runners.mistral_runner import run_mistral_eval 135 | 136 | run_mistral_eval(args) 137 | elif args.model_type == "bedrock": 138 | from runners.bedrock_runner import run_bedrock_eval 139 | 140 | run_bedrock_eval(args) 141 | elif args.model_type == "together": 142 | from runners.together_runner import run_together_eval 143 | 144 | run_together_eval(args) 145 | elif args.model_type == "deepseek": 146 | from runners.deepseek_runner import run_deepseek_eval 147 | 148 | run_deepseek_eval(args) 149 | else: 150 | raise ValueError( 151 | f"Invalid model type: {args.model_type}. Model type must be one of: 'oa', 'hf', 'anthropic', 'vllm', 'api', 'llama_cpp', 'mlx', 'gemini', 'mistral'" 152 | ) 153 | -------------------------------------------------------------------------------- /prompts/README.md: -------------------------------------------------------------------------------- 1 | # Defining your prompt 2 | You can define your prompt template by using the following variables: 3 | - `user_question`: The question that we want to generate sql for 4 | - `table_metadata_string`: The metadata of the table that we want to query. This is a string that contains the table names, column names and column types. This allows the model to know which columns/tables are available for getting information from. For the sqlcoder model that we released, you would need to represent your table metadata as a [SQL DDL](https://en.wikipedia.org/wiki/Data_definition_language) statement. 5 | - `instructions`: This is an optional field that allows you to customize specific instructions for each question, if needed. For example, if you want to ask the model to format your dates a particular way, define keywords, or adapt the SQL to a different database, you can do so here. If you don't need to customize the instructions, you can omit this in your prompt. 6 | - `k_shot_prompt`: This is another optional field that allows you to provide example SQL queries and their corresponding questions. These examples serve as a context for the model, helping it understand the type of SQL query you're expecting for a given question. Including a few examples in the k_shot_prompt field can significantly improve the model's accuracy in generating relevant SQL queries, especially for complex or less straightforward questions. 7 | - `glossary`: This is an optional field that allows you to define special terminology or rules for creating the SQL queries. 
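As a rough sketch of how these placeholders get filled in at runtime (the template string and values below are illustrative examples, not taken from this repo's prompt files or runners), the substitution is just Python's `str.format`:

```python
# Illustrative only: a toy template using the placeholders described above.
# Real runs load one of the templates in prompts/ instead of this string.
prompt_template = """### Task
Generate a SQL query to answer this question: `{user_question}`
{instructions}
### Database Schema
{table_metadata_string}
{k_shot_prompt}
### Answer
"""

prompt = prompt_template.format(
    user_question="What are the top 5 countries by number of customers?",
    instructions="Use PostgreSQL syntax.",
    table_metadata_string="CREATE TABLE sbCustomer (sbCustId text, sbCustCountry text);",
    k_shot_prompt="",
)
print(prompt)
```

Placeholders that a given template does not use can simply be left out of the `format` call, since `str.format` ignores extra keyword arguments.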
8 | 9 | Here is what a sample might look like with the above variables: 10 | ```markdown 11 | ### Task 12 | Generate a SQL query to answer the following question: 13 | `{user_question}` 14 | `{instructions}` 15 | `{glossary}` 16 | ### Database Schema 17 | The query will run on a database with the following schema: 18 | {table_metadata_string} 19 | {k_shot_prompt} 20 | ### Answer 21 | Given the database schema, here is the SQL query that answers `{user_question}`: 22 | ```sql 23 | ``` 24 | 25 | # Adding variables 26 | You can add variables using curly braces - like so `{user_question}`. These can then be updated at runtime using Python's `.format()` method for strings. Like [here](../runners/hf_runner.py#L18). 27 | 28 | # Translating to OpenAI's messages prompt 29 | If you're performing evaluation with OpenAI's chat models, please ensure that your prompt contains the keywords `### Instructions:`, `### Input:`, and `### Response:`. This will help ensure that the prompt sections are automatically mapped to OpenAI's different prompt roles. The text under Instructions, Input and Response will be converted to the `system`, `user` and `assistant` prompts respectively. 30 | 31 | # Translating to Anthropic's messages prompt 32 | If you're performing evaluation with Anthropic's models, please ensure that your prompt contains the keywords `Human:` and `Assistant:`. This will help ensure that the model correctly interprets the roles in the conversation and responds appropriately. -------------------------------------------------------------------------------- /prompts/prompt.md: -------------------------------------------------------------------------------- 1 | <|start_header_id|>system<|end_header_id|> 2 | 3 | Follow instructions to the letter, and answer questions without making any additional assumptions.<|start_header_id|>user<|end_header_id|> 4 | 5 | Generate a {db_type} query to answer this question: `{user_question}` 6 | {instructions} 7 | DDL statements: 8 | {table_metadata_string} 9 | 10 | Generate a valid {db_type} query that best answers the question `{user_question}`.<|eot_id|><|start_header_id|>assistant<|end_header_id|> 11 | 12 | I will reflect on the user's request before answering the question. 13 | 14 | I was asked to generate a SQL query for this question: `{user_question}` 15 | 16 | {instruction_reflections} 17 | With this in mind, here is the {db_type} query that best answers the question while only using appropriate tables and columns from the DDL statements: 18 | ```sql 19 | -------------------------------------------------------------------------------- /prompts/prompt_anthropic.md: -------------------------------------------------------------------------------- 1 | Your task is to convert a text question to a {db_type} query, given a database schema. 2 | 3 | The question that you must generate a SQL for is this `{user_question}`. 4 | {instructions} 5 | This query will run on a database with the following schema: 6 | {table_metadata_string} 7 | {k_shot_prompt} 8 | 9 | Just return the SQL query, nothing else.
-------------------------------------------------------------------------------- /prompts/prompt_cot.md: -------------------------------------------------------------------------------- 1 | <|start_header_id|>system<|end_header_id|> 2 | 3 | Follow instructions to the letter, and answer questions without making any additional assumptions.<|start_header_id|>user<|end_header_id|> 4 | 5 | Generate a {db_type} query to answer this question: `{user_question}` 6 | {instructions} 7 | DDL statements: 8 | {table_metadata_string} 9 | 10 | {table_aliases}Generate a valid {db_type} query that best answers the question `{user_question}`.<|eot_id|><|start_header_id|>assistant<|end_header_id|> 11 | 12 | I will reflect on the user's request before answering the question. 13 | 14 | I was asked to generate a SQL query for this question: `{user_question}` 15 | 16 | {instruction_reflections} 17 | With this in mind, here is the {db_type} query that best answers the question while only using appropriate tables and columns from the DDL statements: 18 | ```sql 19 | -------------------------------------------------------------------------------- /prompts/prompt_cot_postgres.md: -------------------------------------------------------------------------------- 1 | <|start_header_id|>system<|end_header_id|> 2 | 3 | Follow instructions to the letter, and answer questions without making any additional assumptions.<|start_header_id|>user<|end_header_id|> 4 | 5 | Generate a SQL query to answer this question: `{user_question}` 6 | {instructions} 7 | DDL statements: 8 | {table_metadata_string} 9 | {join_str} 10 | 11 | {table_aliases}Generate a valid SQL query that best answers the question `{user_question}`.<|eot_id|><|start_header_id|>assistant<|end_header_id|> 12 | 13 | I will reflect on the user's request before answering the question. 
14 | 15 | I was asked to generate a SQL query for this question: `{user_question}` 16 | 17 | {instruction_reflections} 18 | With this in mind, here is the SQL query that best answers the question while only using appropriate tables and columns from the DDL statements: 19 | ```sql 20 | -------------------------------------------------------------------------------- /prompts/prompt_cot_sqlite.md: -------------------------------------------------------------------------------- 1 | <|start_header_id|>system<|end_header_id|> 2 | 3 | Follow instructions to the letter, and answer questions without making any additional assumptions.<|start_header_id|>user<|end_header_id|> 4 | 5 | Generate a {db_type} query to answer this question: `{user_question}` 6 | {instructions} 7 | DDL statements: 8 | {table_metadata_string} 9 | 10 | {table_aliases}<|eot_id|><|start_header_id|>assistant<|end_header_id|> 11 | 12 | I was asked to generate a SQL query for this question: `{user_question}` 13 | 14 | {instruction_reflections} 15 | Here is the {db_type} query that best answers the question: 16 | ```sql 17 | -------------------------------------------------------------------------------- /prompts/prompt_experimental.md: -------------------------------------------------------------------------------- 1 | ### Task 2 | Generate a SQL query to answer this question: `{user_question}` 3 | {instructions} 4 | 5 | ### Database Schema 6 | The query will run on a database with the following schema: 7 | {table_metadata_string} 8 | 9 | ### Answer 10 | Given the database schema, here is the SQL query that answers the question `{user_question}` 11 | ```sql 12 | -------------------------------------------------------------------------------- /prompts/prompt_gemini.md: -------------------------------------------------------------------------------- 1 | Your task is to convert a text question to a {db_type} query, given a database schema. 2 | 3 | Generate a SQL query that answers the question `{user_question}`. 4 | {instructions} 5 | This query will run on a database whose schema is represented in this SQL DDL: 6 | {table_metadata_string} 7 | 8 | Return the SQL query that answers the question `{user_question}` 9 | ```sql -------------------------------------------------------------------------------- /prompts/prompt_mistral.md: -------------------------------------------------------------------------------- 1 | System: Your task is to convert a text question to a SQL query that runs on Postgres, given a database schema. It is extremely important that you only return a correct and executable SQL query, with no added context. 2 | 3 | User: Generate a SQL query that answers the question `{user_question}`. This query will run on a PostgreSQL database whose schema is represented in this string: 4 | {table_metadata_string} 5 | -------------------------------------------------------------------------------- /prompts/prompt_openai.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "role": "system", 4 | "content": "Your role is to convert a user question to a {db_type} query, given a database schema." 5 | }, 6 | { 7 | "role": "user", 8 | "content": "Generate a SQL query that answers the question `{user_question}`.\n{instructions}\nThis query will run on a database whose schema is represented in this string:\n{table_metadata_string}\n{k_shot_prompt}\nReturn only the SQL query, and nothing else." 
9 | } 10 | ] 11 | -------------------------------------------------------------------------------- /prompts/prompt_openai_o1.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "role": "user", 4 | "content": "Generate a SQL query for {db_type} that answers the question `{user_question}`.\n{instructions}\nThis query will run on a database whose schema is represented in this string:\n{table_metadata_string}\n\nReturn only the SQL query, and nothing else." 5 | } 6 | ] 7 | -------------------------------------------------------------------------------- /prompts/prompt_qwen.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "role": "system", 4 | "content": "Your task is to convert a user question to a {db_type} query, given a database schema." 5 | }, 6 | { 7 | "role": "user", 8 | "content": "Generate a SQL query that answers the question `{user_question}`.\n{instructions}\nThis query will run on a database whose schema is represented in this string:\n{table_metadata_string}\n{join_str}\n{table_aliases}\nAfter reasoning, return only the SQL query, and nothing else." 9 | } 10 | ] 11 | -------------------------------------------------------------------------------- /prompts/prompt_together.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "role": "system", 4 | "content": "Your role is to convert a user question to a {db_type} query, given a database schema." 5 | }, 6 | { 7 | "role": "user", 8 | "content": "Generate a SQL query that answers the question `{user_question}`.\n{instructions}\nThis query will run on a database whose schema is represented in this SQL DDL:\n{table_metadata_string}\n{table_aliases}\n{pruned_join_str}\n{k_shot_prompt}\nReturn the SQL query that answers the question `{user_question}`" 9 | }, 10 | { 11 | "role": "assistant", 12 | "content": "```sql\n" 13 | } 14 | ] 15 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | anthropic 2 | argparse 3 | func_timeout 4 | google-generativeai 5 | mistralai 6 | mysql-connector-python 7 | numpy==2.1.2 8 | openai>=1.1.0 9 | pandas==2.2.3 10 | pandas-gbq 11 | peft 12 | psycopg2-binary 13 | pyodbc 14 | pytest 15 | pyyaml 16 | sentence-transformers 17 | snowflake-connector-python 18 | sqlalchemy 19 | tiktoken 20 | together 21 | torch==2.4.0 22 | tqdm 23 | transformers 24 | sqlparse 25 | sqlglot 26 | vllm==0.6.3.post1; sys_platform != 'darwin' 27 | -------------------------------------------------------------------------------- /requirements_test.txt: -------------------------------------------------------------------------------- 1 | func_timeout 2 | numpy 3 | openai 4 | pandas 5 | psycopg2-binary 6 | snowflake-connector-python 7 | sqlalchemy 8 | sqlglot 9 | tqdm -------------------------------------------------------------------------------- /results_fn_bigquery/.env.yaml.template: -------------------------------------------------------------------------------- 1 | BQ_PROJECT: 2 | BQ_TABLE: 3 | CREDENTIALS_PATH: -------------------------------------------------------------------------------- /results_fn_bigquery/main.py: -------------------------------------------------------------------------------- 1 | # this is a Google cloud function for receiving the data from the web app and storing it in Bigquery 2 | 3 | import functions_framework 4 | from 
google.oauth2 import service_account 5 | import pandas as pd 6 | import pandas_gbq 7 | import os 8 | 9 | BQ_PROJECT = os.environ.get("BQ_PROJECT") 10 | BQ_TABLE = os.environ.get("BQ_TABLE") 11 | 12 | # authenticate using service account's json credentials 13 | credentials_path = os.environ.get("CREDENTIALS_PATH") 14 | print(f"CREDENTIALS_PATH: {credentials_path}") 15 | credentials = service_account.Credentials.from_service_account_file(credentials_path) 16 | print(f"Credentials: {credentials}") 17 | 18 | 19 | @functions_framework.http 20 | def bigquery(request): 21 | request_json = request.get_json(force=True) 22 | results = request_json["results"] 23 | run_id = request_json["run_id"] 24 | run_time = pd.to_datetime(request_json["timestamp"]) 25 | runner_type = request_json["runner_type"] 26 | prompt = request_json["prompt"] 27 | prompt_id = request_json["prompt_id"] 28 | model = request_json["model"] 29 | num_beams = request_json["num_beams"] 30 | db_type = request_json["db_type"] 31 | gpu_name = request_json["gpu_name"] 32 | gpu_memory = request_json["gpu_memory"] 33 | gpu_driver_version = request_json["gpu_driver_version"] 34 | gpu_cuda_version = request_json["gpu_cuda_version"] 35 | num_gpus = request_json["num_gpus"] 36 | run_args = request_json["run_args"] 37 | 38 | if len(results) == 0: 39 | return "no results to write" 40 | 41 | df = pd.DataFrame(results) 42 | df["run_time"] = run_time 43 | print(f"results:\n{results}") 44 | print(f"df:\n{df}") 45 | # add other metadata to the dataframe 46 | run_args["run_id"] = run_id 47 | run_args["runner_type"] = runner_type 48 | run_args["prompt"] = prompt 49 | run_args["prompt_id"] = prompt_id 50 | run_args["model"] = model 51 | run_args["num_beams"] = num_beams 52 | run_args["db_type"] = db_type 53 | run_args["gpu_name"] = gpu_name 54 | run_args["gpu_memory"] = gpu_memory 55 | run_args["gpu_driver_version"] = gpu_driver_version 56 | run_args["gpu_cuda_version"] = gpu_cuda_version 57 | run_args["num_gpus"] = num_gpus 58 | df["run_params"] = run_args 59 | print(f"df with run_params:\n{df}") 60 | # write to bigquery 61 | pandas_gbq.to_gbq( 62 | dataframe=df, 63 | destination_table=BQ_TABLE, 64 | project_id=BQ_PROJECT, 65 | if_exists="append", 66 | progress_bar=False, 67 | credentials=credentials, 68 | ) 69 | return "success" 70 | -------------------------------------------------------------------------------- /results_fn_bigquery/requirements.txt: -------------------------------------------------------------------------------- 1 | functions_framework 2 | google-cloud-bigquery[bqstorage,pandas] 3 | pandas 4 | pandas-gbq -------------------------------------------------------------------------------- /results_fn_postgres/.env.yaml.template: -------------------------------------------------------------------------------- 1 | POSTGRES_DB: 2 | POSTGRES_HOST: 3 | POSTGRES_PORT: # put your port number between quotes so that gcloud functions can parse it properly 4 | POSTGRES_USER: 5 | POSTGRES_PASSWORD: -------------------------------------------------------------------------------- /results_fn_postgres/create.sql: -------------------------------------------------------------------------------- 1 | -- Lists the SQL commands to create the tables used for storing run data 2 | -- The database is called "sql_eval", and the tables are "prompt" and "eval" 3 | 4 | CREATE TABLE IF NOT EXISTS eval ( 5 | -- first, metadata about the run 6 | run_id VARCHAR(255), 7 | created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, 8 | runner_type VARCHAR(255), 9 | prompt_id 
VARCHAR(255), 10 | model VARCHAR(255), 11 | num_beams INT, 12 | db_type VARCHAR(255), 13 | gpu_name VARCHAR(255), 14 | gpu_memory VARCHAR(255), 15 | gpu_driver_version VARCHAR(255), 16 | gpu_cuda_version VARCHAR(255), 17 | num_gpus VARCHAR(255), 18 | 19 | -- then, data about actual questions 20 | question TEXT, 21 | golden_query TEXT, 22 | db_name VARCHAR(255), 23 | query_category VARCHAR(255), 24 | generated_query TEXT, 25 | error_msg TEXT, 26 | exact_match BOOLEAN, 27 | correct BOOLEAN, 28 | error_db_exec BOOLEAN, 29 | latency_seconds FLOAT, 30 | tokens_used INT 31 | ); 32 | 33 | -- indexes for the table on run_id, model, db_name, and query_category 34 | CREATE INDEX IF NOT EXISTS eval_run_id ON eval(run_id); 35 | CREATE INDEX IF NOT EXISTS eval_model ON eval(model); 36 | CREATE INDEX IF NOT EXISTS eval_db_name ON eval(db_name); 37 | CREATE INDEX IF NOT EXISTS eval_query_category ON eval(query_category); 38 | 39 | -- create prompt table 40 | CREATE TABLE IF NOT EXISTS prompt ( 41 | prompt_id VARCHAR(255) PRIMARY KEY, 42 | prompt TEXT 43 | ); -------------------------------------------------------------------------------- /results_fn_postgres/main.py: -------------------------------------------------------------------------------- 1 | # this is a Google cloud function for receiving the data from the web app and storing it in Postgres 2 | 3 | import functions_framework 4 | import psycopg2 5 | import os 6 | 7 | POSTGRES_DB = os.environ.get("POSTGRES_DB") 8 | POSTGRES_HOST = os.environ.get("POSTGRES_HOST") 9 | POSTGRES_PORT = os.environ.get("POSTGRES_PORT") 10 | POSTGRES_USER = os.environ.get("POSTGRES_USER") 11 | POSTGRES_PASSWORD = os.environ.get("POSTGRES_PASSWORD") 12 | 13 | 14 | @functions_framework.http 15 | def postgres(request): 16 | request_json = request.get_json(force=True) 17 | results = request_json["results"] 18 | run_id = request_json["run_id"] 19 | timestamp = request_json["timestamp"] 20 | runner_type = request_json["runner_type"] 21 | prompt = request_json["prompt"] 22 | prompt_id = request_json["prompt_id"] 23 | model = request_json["model"] 24 | num_beams = request_json["num_beams"] 25 | db_type = request_json["db_type"] 26 | gpu_name = request_json["gpu_name"] 27 | gpu_memory = request_json["gpu_memory"] 28 | gpu_driver_version = request_json["gpu_driver_version"] 29 | gpu_cuda_version = request_json["gpu_cuda_version"] 30 | num_gpus = request_json["num_gpus"] 31 | db_type = request_json.get("db_type", "bigquery") 32 | print( 33 | f"Received {len(results)} rows for run {run_id} at {timestamp} from {runner_type}" 34 | ) 35 | conn = psycopg2.connect( 36 | dbname=POSTGRES_DB, 37 | host=POSTGRES_HOST, 38 | port=POSTGRES_PORT, 39 | user=POSTGRES_USER, 40 | password=POSTGRES_PASSWORD, 41 | ) 42 | print(f"Connected to the postgres db {POSTGRES_DB}") 43 | cur = conn.cursor() 44 | 45 | # add prompt to the prompts table if it doesn't exist 46 | cur.execute("SELECT * FROM prompt WHERE prompt_id = %s", (prompt_id,)) 47 | if cur.fetchone() is None: 48 | cur.execute( 49 | "INSERT INTO prompt (prompt_id, prompt) VALUES (%s, %s)", 50 | (prompt_id, prompt), 51 | ) 52 | print(f"Inserted prompt {prompt_id} into the prompts table") 53 | 54 | for result in results: 55 | question = result["question"] 56 | golden_query = result["query"] 57 | db_name = result["db_name"] 58 | query_category = result["query_category"] 59 | generated_query = result["generated_query"] 60 | error_msg = result["error_msg"] 61 | exact_match = bool(result["exact_match"]) 62 | correct = bool(result["correct"]) 63 | 
error_db_exec = bool(result["error_db_exec"]) 64 | latency_seconds = result["latency_seconds"] 65 | tokens_used = result["tokens_used"] 66 | 67 | cur.execute( 68 | "INSERT INTO eval (run_id, question, golden_query, db_name, query_category, generated_query, error_msg, exact_match, correct, error_db_exec, latency_seconds, tokens_used, created_at, runner_type, prompt_id, model, num_beams, db_type, gpu_name, gpu_memory, gpu_driver_version, gpu_cuda_version, num_gpus) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)", 69 | ( 70 | run_id, 71 | question, 72 | golden_query, 73 | db_name, 74 | query_category, 75 | generated_query, 76 | error_msg, 77 | exact_match, 78 | correct, 79 | error_db_exec, 80 | latency_seconds, 81 | tokens_used, 82 | timestamp, 83 | runner_type, 84 | prompt_id, 85 | model, 86 | num_beams, 87 | db_type, 88 | gpu_name, 89 | gpu_memory, 90 | gpu_driver_version, 91 | gpu_cuda_version, 92 | num_gpus, 93 | ), 94 | ) 95 | print(f"Inserted {len(results)} rows into the postgres db {POSTGRES_DB}") 96 | conn.commit() 97 | cur.close() 98 | conn.close() 99 | return "success" 100 | -------------------------------------------------------------------------------- /results_fn_postgres/requirements.txt: -------------------------------------------------------------------------------- 1 | functions_framework 2 | pandas 3 | psycopg2 -------------------------------------------------------------------------------- /run_checkpoints.sh: -------------------------------------------------------------------------------- 1 | #!/bin/zsh 2 | 3 | model_names=("sqlcoder_8b_fullft_ds_012_llama3_old_join_hints_mgn1_b1_0900_b2_0990_steps_600") 4 | db_type="postgres" 5 | PORT=8082 # avoid 8081 as it's used by nginx 6 | export CUDA_VISIBLE_DEVICES=0 # set gpu you want to use (just 1 will do) 7 | 8 | # if db_type not postgres or sqlite, prompt_file should be prompts/prompt_cot.md else use prompts/prompt_cot_${dbtype}.md 9 | if [ "$db_type" != "postgres" ] && [ "$db_type" != "sqlite" ]; then 10 | prompt_file="prompts/prompt_cot.md" 11 | else 12 | prompt_file="prompts/prompt_cot_${db_type}.md" 13 | fi 14 | 15 | # Loop over model names 16 | for model_name in "${model_names[@]}"; do 17 | # list the folder names in /models/combined/${model_name} 18 | model_dir="/workspace/finetuning/models/${model_name}" 19 | echo "Model directory: ${model_dir}" 20 | checkpoints=($(ls $model_dir)) 21 | echo "Checkpoints: ${checkpoints}" 22 | # Loop over checkpoints 23 | for checkpoint in "${checkpoints[@]}"; do 24 | # skip if does not start with "checkpoint-" 25 | if [[ ! $checkpoint == checkpoint-* ]]; then 26 | continue 27 | fi 28 | model_path="${model_dir}/${checkpoint}" 29 | checkpoint_num=$(echo $checkpoint | cut -d'-' -f2) 30 | echo "Running model ${model_name} checkpoint ${checkpoint_num}" 31 | # first, get the API up 32 | python3 utils/api_server.py --model "$model_path" --tensor-parallel-size 1 --dtype float16 --max-model-len 8192 --gpu-memory-utilization 0.90 --block-size 16 --disable-log-requests --port "${PORT}" & 33 | 34 | # run a loop to check if the http://localhost:8080/health endpoint returns a valid 200 result 35 | while true; do 36 | http_status=$(curl -s -o /dev/null -w "%{http_code}" "http://localhost:${PORT}/health") 37 | if [ "$http_status" -eq 200 ]; then 38 | echo "API server is up and running" 39 | break 40 | else 41 | echo "Waiting for API server to be up..." 
42 | sleep 1 43 | fi 44 | done 45 | 46 | # then run sql-eval 47 | python3 main.py -db "${db_type}" \ 48 | -f "${prompt_file}" \ 49 | -q "data/questions_gen_${db_type}.csv" "data/instruct_basic_${db_type}.csv" "data/instruct_advanced_${db_type}.csv" "data/idk.csv" \ 50 | -o "results/${model_name}/c${checkpoint_num}_api_v1.csv" "results/${model_name}/c${checkpoint_num}_api_basic.csv" "results/${model_name}/c${checkpoint_num}_api_advanced.csv" "results/${model_name}/c${checkpoint_num}_api_idk.csv" \ 51 | -g api \ 52 | -b 1 \ 53 | -c 0 \ 54 | --api_url "http://localhost:${PORT}/generate" \ 55 | --api_type "vllm" \ 56 | -p 10 \ 57 | --logprobs 58 | # finally, kill the api server 59 | pkill -9 -f "python3 utils/api_server.py.*--port ${PORT}" 60 | done 61 | done 62 | -------------------------------------------------------------------------------- /run_checkpoints_adapters.sh: -------------------------------------------------------------------------------- 1 | #!/bin/zsh 2 | 3 | model_names=("sqlcoder_8b_bf16_r64_ds_013_sqlite_600_b24_lr8e-5") 4 | base_model_path="defog/sqlcoder-8b-padded-sorry" 5 | db_type="sqlite" 6 | PORT=8084 # avoid 8081 as it's used by nginx 7 | export CUDA_VISIBLE_DEVICES=1 # set gpu you want to use (just 1 will do) 8 | preprocess_adapters=true # set to false if you have already preprocessed the adapters 9 | cot_table_alias=true # set to true if you want to use the cot_table_alias prompt in evals 10 | 11 | # check that the base model was trained on cot data otherwise print a warning 12 | if [[ ! $base_model_path == *"cot"* ]] && [[ $cot_table_alias == true ]]; then 13 | echo "WARNING: Base model was not trained on 'cot' data. This may lead to less than optimal results" 14 | fi 15 | for model_name in "${model_names[@]}"; do 16 | # list the folder names in /models/combined/${model_name} 17 | adapter_dir="${HOME}/finetuning/models/${model_name}" 18 | echo "Adapter directory: ${adapter_dir}" 19 | checkpoints=($(ls $adapter_dir)) 20 | echo "Checkpoints: ${checkpoints}" 21 | if [ "$preprocess_adapters" = true ]; then 22 | # Preprocess the adapters 23 | for checkpoint in "${checkpoints[@]}"; do 24 | # skip if does not start with "checkpoint-" 25 | if [[ ! $checkpoint == checkpoint-* ]]; then 26 | continue 27 | fi 28 | checkpoint_num=$(echo $checkpoint | cut -d'-' -f2) 29 | echo "Preprocessing adapter ${model_name}/checkpoint-${checkpoint_num}" 30 | python3 ${HOME}/finetuning/preprocess_adapters.py -f ${adapter_dir}/checkpoint-${checkpoint_num}/adapter_model.safetensors 31 | done 32 | else 33 | echo "Skipping preprocessing adapters..." 34 | fi 35 | 36 | # Loop over checkpoints to run sql-eval 37 | for checkpoint in "${checkpoints[@]}"; do 38 | # skip if does not start with "checkpoint-" 39 | if [[ ! 
$checkpoint == checkpoint-* ]]; then 40 | continue 41 | fi 42 | checkpoint_num=$(echo $checkpoint | cut -d'-' -f2) 43 | echo "Running adapter ${model_name} checkpoint ${checkpoint_num}" 44 | # first, get the API up 45 | python3 utils/api_server.py --model "$base_model_path" --tensor-parallel-size 1 --dtype float16 --max-model-len 4096 --gpu-memory-utilization 0.90 --block-size 16 --disable-log-requests --port "${PORT}" --enable-lora --max-lora-rank 64 & 46 | 47 | # run a loop to check if the http://localhost:8084/health endpoint returns a valid 200 result 48 | while true; do 49 | http_status=$(curl -s -o /dev/null -w "%{http_code}" "http://localhost:${PORT}/health") 50 | if [ "$http_status" -eq 200 ]; then 51 | echo "API server is up and running" 52 | break 53 | else 54 | echo "Waiting for API server to be up..." 55 | sleep 1 56 | fi 57 | done 58 | 59 | # then run sql-eval 60 | if [ "$cot_table_alias" = true ]; then 61 | python3 main.py -db ${db_type} \ 62 | -f prompts/prompt_cot.md \ 63 | -q "data/instruct_basic_${db_type}.csv" "data/instruct_advanced_${db_type}.csv" "data/questions_gen_${db_type}.csv" "data/idk.csv" \ 64 | -o "results/${model_name}/${model_name}_c${checkpoint_num}_${db_type}_api_cot_basic.csv" "results/${model_name}/${model_name}_c${checkpoint_num}_${db_type}_api_cot_advanced.csv" "results/${model_name}/${model_name}_c${checkpoint_num}_${db_type}_api_cot_v1.csv" "results/${model_name}/${model_name}_c${checkpoint_num}_${db_type}_api_cot_idk.csv" \ 65 | -g api \ 66 | -b 1 \ 67 | -c 0 \ 68 | --api_url "http://localhost:${PORT}/generate" \ 69 | --api_type "vllm" \ 70 | -p 10 \ 71 | -a ${adapter_dir}/checkpoint-${checkpoint_num}\ 72 | --cot_table_alias "prealias" \ 73 | --logprobs 74 | else 75 | python3 main.py -db ${db_type} \ 76 | -f prompts/prompt.md \ 77 | -q "data/instruct_basic_${db_type}.csv" "data/instruct_advanced_${db_type}.csv" "data/questions_gen_${db_type}.csv" "data/idk.csv" \ 78 | -o "results/${model_name}/${model_name}_c${checkpoint_num}_${db_type}_api_basic.csv" "results/${model_name}/${model_name}_c${checkpoint_num}_${db_type}_api_advanced.csv" "results/${model_name}/${model_name}_c${checkpoint_num}_${db_type}_api_v1.csv" "results/${model_name}/${model_name}_c${checkpoint_num}_${db_type}_api_idk.csv" \ 79 | -g api \ 80 | -b 1 \ 81 | -c 0 \ 82 | --api_url "http://localhost:${PORT}/generate" \ 83 | --api_type "vllm" \ 84 | -p 10 \ 85 | -a ${adapter_dir}/checkpoint-${checkpoint_num}\ 86 | --logprobs 87 | fi 88 | # finally, kill the api server 89 | pkill -9 -f "python3 utils/api_server.py.*--port ${PORT}" 90 | done 91 | done 92 | 93 | # pass all the model_names to the python script 94 | python3 analyze_results_and_post_to_slack.py -m "${model_names[@]}" -------------------------------------------------------------------------------- /run_checkpoints_cot.sh: -------------------------------------------------------------------------------- 1 | #!/bin/zsh 2 | 3 | model_names=("sqlcoder_8b_fullft_ds_012_llama3_old_join_hints_mgn1_b1_0900_b2_0990_steps_600") 4 | db_type="postgres" 5 | PORT=8083 # avoid 8081 as it's used by nginx 6 | export CUDA_VISIBLE_DEVICES=1 # set gpu you want to use (just 1 will do) 7 | 8 | # if db_type not postgres or sqlite, prompt_file should be prompts/prompt_cot.md else use prompts/prompt_cot_${dbtype}.md 9 | if [ "$db_type" != "postgres" ] && [ "$db_type" != "sqlite" ]; then 10 | prompt_file="prompts/prompt_cot.md" 11 | else 12 | prompt_file="prompts/prompt_cot_${db_type}.md" 13 | fi 14 | 15 | # Loop over model names 16 | for model_name in 
"${model_names[@]}"; do 17 | # list the folder names in /models/combined/${model_name} 18 | model_dir="/workspace/finetuning/models/${model_name}" 19 | echo "Model directory: ${model_dir}" 20 | checkpoints=($(ls $model_dir)) 21 | echo "Checkpoints: ${checkpoints}" 22 | # Loop over checkpoints 23 | for checkpoint in "${checkpoints[@]}"; do 24 | # skip if does not start with "checkpoint-" 25 | if [[ ! $checkpoint == checkpoint-* ]]; then 26 | continue 27 | fi 28 | model_path="${model_dir}/${checkpoint}" 29 | checkpoint_num=$(echo $checkpoint | cut -d'-' -f2) 30 | echo "Running model ${model_name} checkpoint ${checkpoint_num}" 31 | # first, get the API up 32 | python3 utils/api_server.py --model "$model_path" --tensor-parallel-size 1 --dtype float16 --max-model-len 8192 --gpu-memory-utilization 0.90 --block-size 16 --disable-log-requests --port "${PORT}" & 33 | 34 | # run a loop to check if the http://localhost:8080/health endpoint returns a valid 200 result 35 | while true; do 36 | http_status=$(curl -s -o /dev/null -w "%{http_code}" "http://localhost:${PORT}/health") 37 | if [ "$http_status" -eq 200 ]; then 38 | echo "API server is up and running" 39 | break 40 | else 41 | echo "Waiting for API server to be up..." 42 | sleep 1 43 | fi 44 | done 45 | 46 | # then run sql-eval 47 | python3 main.py -db "${db_type}" \ 48 | -f "${prompt_file}" \ 49 | -q "data/questions_gen_${db_type}.csv" "data/instruct_basic_${db_type}.csv" "data/instruct_advanced_${db_type}.csv" "data/idk.csv" \ 50 | -o "results/${model_name}/c${checkpoint_num}_api_v1_cot.csv" "results/${model_name}/c${checkpoint_num}_api_basic_cot.csv" "results/${model_name}/c${checkpoint_num}_api_advanced_cot.csv" "results/${model_name}/c${checkpoint_num}_api_idk_cot.csv" \ 51 | -g api \ 52 | -b 1 \ 53 | -c 0 \ 54 | --api_url "http://localhost:${PORT}/generate" \ 55 | --api_type "vllm" \ 56 | -p 10 \ 57 | --cot_table_alias "prealias" \ 58 | --logprobs 59 | # finally, kill the api server 60 | pkill -9 -f "python3 utils/api_server.py.*--port ${PORT}" 61 | done 62 | done 63 | -------------------------------------------------------------------------------- /run_model_cot.sh: -------------------------------------------------------------------------------- 1 | #!/bin/zsh 2 | 3 | # model_dir="${HOME}/models" 4 | # parse in model_names from args to script 5 | model_names=("$@") 6 | db_type="postgres" 7 | # if model_names is empty, print and exit 8 | if [ -z "$model_names" ]; then 9 | echo "No model names provided" 10 | exit 1 11 | fi 12 | PORT=8084 # avoid 8081 as it's used by nginx 13 | export CUDA_VISIBLE_DEVICES=0 # set gpu you want to use (just 1 will do) 14 | 15 | # if db_type not postgres or sqlite, prompt_file should be prompts/prompt_cot.md else use prompts/prompt_cot_${dbtype}.md 16 | if [ "$db_type" != "postgres" ] && [ "$db_type" != "sqlite" ]; then 17 | prompt_file="prompts/prompt_cot.md" 18 | else 19 | prompt_file="prompts/prompt_cot_${db_type}.md" 20 | fi 21 | 22 | # Loop over model names 23 | for model_name in "${model_names[@]}"; do 24 | 25 | echo "Running model ${model_name}" 26 | # first, get the API up 27 | python3 utils/api_server.py --model "${model_name}" --tensor-parallel-size 1 --dtype float16 --max-model-len 16384 --gpu-memory-utilization 0.90 --block-size 16 --disable-log-requests --port "${PORT}" & 28 | 29 | # run a loop to check if the http://localhost:8080/health endpoint returns a valid 200 result 30 | while true; do 31 | http_status=$(curl -s -o /dev/null -w "%{http_code}" "http://localhost:${PORT}/health") 32 | if [ 
"$http_status" -eq 200 ]; then 33 | echo "API server is up and running" 34 | break 35 | else 36 | echo "Waiting for API server to be up..." 37 | sleep 1 38 | fi 39 | done 40 | 41 | # then run sql-eval 42 | python3 main.py -db "${db_type}" \ 43 | -f "${prompt_file}" \ 44 | -q "data/questions_gen_${db_type}.csv" "data/instruct_basic_${db_type}.csv" "data/instruct_advanced_${db_type}.csv" "data/idk.csv" \ 45 | -o "results/${model_name}/api_v1_cot.csv" "results/${model_name}/api_basic_cot.csv" "results/${model_name}/api_advanced_cot.csv" "results/${model_name}/api_idk_cot.csv" \ 46 | -g api \ 47 | -b 1 \ 48 | -c 0 \ 49 | --api_url "http://localhost:${PORT}/generate" \ 50 | --api_type "vllm" \ 51 | -p 10 \ 52 | --logprobs 53 | # finally, kill the api server 54 | pkill -9 -f "python3 utils/api_server.py.*--port ${PORT}" 55 | 56 | done 57 | -------------------------------------------------------------------------------- /run_qwen.sh: -------------------------------------------------------------------------------- 1 | export db_type="postgres" 2 | export prompt_file="prompts/prompt_qwen.json" 3 | export model_name="qwen" 4 | export PORT=8000 5 | 6 | # assume you already have the vllm server running 7 | # vllm serve "$model_name" --port 8000 8 | 9 | if [[ "$1" == "--thinking" ]]; then 10 | echo "Running sql-eval on $model_name with thinking tokens" 11 | python3 main.py -db "${db_type}" \ 12 | -f "${prompt_file}" \ 13 | -q "data/questions_gen_${db_type}.csv" "data/instruct_basic_${db_type}.csv" "data/instruct_advanced_${db_type}.csv" \ 14 | -o "results/${model_name}/openai_api_v1.csv" "results/${model_name}/openai_api_basic.csv" "results/${model_name}/openai_api_advanced.csv" \ 15 | -g api \ 16 | -m "Qwen/Qwen3-4B" \ 17 | -b 1 \ 18 | -c 0 \ 19 | --thinking \ 20 | --api_url "http://localhost:${PORT}/v1/chat/completions" \ 21 | --api_type "openai" \ 22 | -p 10 23 | else 24 | echo "Running sql-eval on $model_name without generating thinking tokens" 25 | python3 main.py -db "${db_type}" \ 26 | -f "${prompt_file}" \ 27 | -q "data/questions_gen_${db_type}.csv" "data/instruct_basic_${db_type}.csv" "data/instruct_advanced_${db_type}.csv" \ 28 | -o "results/${model_name}/openai_api_v1.csv" "results/${model_name}/openai_api_basic.csv" "results/${model_name}/openai_api_advanced.csv" \ 29 | -g api \ 30 | -m "Qwen/Qwen3-4B" \ 31 | -b 1 \ 32 | -c 0 \ 33 | --api_url "http://localhost:${PORT}/v1/chat/completions" \ 34 | --api_type "openai" \ 35 | -p 10 36 | fi -------------------------------------------------------------------------------- /runners/anthropic_runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | from time import time 3 | from concurrent.futures import ThreadPoolExecutor, as_completed 4 | 5 | import pandas as pd 6 | import sqlparse 7 | from tqdm import tqdm 8 | 9 | from eval.eval import compare_query_results 10 | from utils.creds import db_creds_all 11 | from utils.dialects import convert_postgres_ddl_to_dialect 12 | from utils.gen_prompt import to_prompt_schema 13 | from utils.questions import prepare_questions_df 14 | from utils.reporting import upload_results 15 | from utils.llm import chat_anthropic 16 | import json 17 | 18 | 19 | def generate_prompt( 20 | prompt_file, 21 | question, 22 | db_name, 23 | db_type, 24 | instructions="", 25 | k_shot_prompt="", 26 | glossary="", 27 | table_metadata_string="", 28 | prev_invalid_sql="", 29 | prev_error_msg="", 30 | public_data=True, 31 | shuffle=True, 32 | ): 33 | if public_data: 34 | from 
defog_data.metadata import dbs 35 | import defog_data.supplementary as sup 36 | else: 37 | from defog_data_private.metadata import dbs 38 | import defog_data_private.supplementary as sup 39 | 40 | with open(prompt_file, "r") as f: 41 | prompt = f.read() 42 | 43 | if table_metadata_string == "": 44 | md = dbs[db_name]["table_metadata"] 45 | pruned_metadata_ddl = to_prompt_schema(md, shuffle) 46 | pruned_metadata_ddl = convert_postgres_ddl_to_dialect( 47 | postgres_ddl=pruned_metadata_ddl, 48 | to_dialect=db_type, 49 | db_name=db_name, 50 | ) 51 | column_join = sup.columns_join.get(db_name, {}) 52 | join_list = [] 53 | for values in column_join.values(): 54 | if isinstance(values[0], tuple): 55 | for col_pair in values: 56 | col_1, col_2 = col_pair 57 | join_str = f"{col_1} can be joined with {col_2}" 58 | if join_str not in join_list: 59 | join_list.append(join_str) 60 | else: 61 | col_1, col_2 = values[0] 62 | join_str = f"{col_1} can be joined with {col_2}" 63 | if join_str not in join_list: 64 | join_list.append(join_str) 65 | if len(join_list) > 0: 66 | join_str = "\nHere is a list of joinable columns:\n" + "\n".join(join_list) 67 | else: 68 | join_str = "" 69 | pruned_metadata_str = pruned_metadata_ddl + join_str 70 | else: 71 | pruned_metadata_str = table_metadata_string 72 | 73 | prompt = prompt.format( 74 | user_question=question, 75 | db_type=db_type, 76 | instructions=instructions, 77 | table_metadata_string=pruned_metadata_str, 78 | k_shot_prompt=k_shot_prompt, 79 | glossary=glossary, 80 | prev_invalid_sql=prev_invalid_sql, 81 | prev_error_msg=prev_error_msg, 82 | ) 83 | return prompt 84 | 85 | 86 | def process_row(row, model_name, args): 87 | start_time = time() 88 | prompt = generate_prompt( 89 | prompt_file=args.prompt_file[0], 90 | question=row["question"], 91 | db_name=row["db_name"], 92 | db_type=args.db_type, 93 | instructions=row["instructions"], 94 | k_shot_prompt=row["k_shot_prompt"], 95 | glossary=row["glossary"], 96 | table_metadata_string=row["table_metadata_string"], 97 | prev_invalid_sql=row["prev_invalid_sql"], 98 | prev_error_msg=row["prev_error_msg"], 99 | public_data=not args.use_private_data, 100 | shuffle=args.shuffle_metadata, 101 | ) 102 | messages = [{"role": "user", "content": prompt}] 103 | try: 104 | response = chat_anthropic(messages=messages, model=model_name, temperature=0.0) 105 | generated_query = ( 106 | response.content.split("```sql", 1)[-1].split("```", 1)[0].strip() 107 | ) 108 | try: 109 | generated_query = sqlparse.format( 110 | generated_query, reindent=True, keyword_case="upper" 111 | ) 112 | except: 113 | pass 114 | return { 115 | "query": generated_query, 116 | "reason": "", 117 | "err": "", 118 | "latency_seconds": time() - start_time, 119 | "tokens_used": response.input_tokens + response.output_tokens, 120 | } 121 | except Exception as e: 122 | return { 123 | "query": "", 124 | "reason": "", 125 | "err": f"GENERATION ERROR: {str(e)}", 126 | "latency_seconds": time() - start_time, 127 | "tokens_used": 0, 128 | } 129 | 130 | 131 | def run_anthropic_eval(args): 132 | # get params from args 133 | questions_file_list = args.questions_file 134 | prompt_file_list = args.prompt_file 135 | output_file_list = args.output_file 136 | num_questions = args.num_questions 137 | k_shot = args.k_shot 138 | db_type = args.db_type 139 | cot_table_alias = args.cot_table_alias 140 | 141 | for questions_file, prompt_file, output_file in zip( 142 | questions_file_list, prompt_file_list, output_file_list 143 | ): 144 | print(f"Using prompt file {prompt_file}") 
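As an aside on the `generate_prompt` helper defined earlier in this runner: the join hints it appends to the DDL come from `defog_data.supplementary.columns_join`. The sketch below walks through that transformation with a made-up mapping; the real mappings ship with the defog-data package, and the dictionary keys are not used by the logic.

```python
# Rough sketch of how generate_prompt (above) turns columns_join entries
# into join hints. The mapping below is invented purely for illustration.
columns_join = {
    ("orders", "customers"): [("orders.customer_id", "customers.id")],
    ("orders", "products"): [("orders.product_id", "products.id")],
}

join_list = []
for values in columns_join.values():
    if isinstance(values[0], tuple):
        for col_1, col_2 in values:
            hint = f"{col_1} can be joined with {col_2}"
            if hint not in join_list:
                join_list.append(hint)

join_str = "\nHere is a list of joinable columns:\n" + "\n".join(join_list)
# join_str is appended to the pruned DDL string before the prompt is formatted.
print(join_str)
```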
145 | print("Preparing questions...") 146 | print( 147 | f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" 148 | ) 149 | question_query_df = prepare_questions_df( 150 | questions_file, db_type, num_questions, k_shot, cot_table_alias 151 | ) 152 | input_rows = question_query_df.to_dict("records") 153 | output_rows = [] 154 | with ThreadPoolExecutor(args.parallel_threads) as executor: 155 | futures = [] 156 | for row in input_rows: 157 | generated_query_fut = executor.submit( 158 | process_row, 159 | row=row, 160 | model_name=args.model, 161 | args=args, 162 | ) 163 | futures.append(generated_query_fut) 164 | 165 | total_tried = 0 166 | total_correct = 0 167 | for f in (pbar := tqdm(as_completed(futures), total=len(futures))): 168 | total_tried += 1 169 | i = futures.index(f) 170 | row = input_rows[i] 171 | result_dict = f.result() 172 | query_gen = result_dict["query"] 173 | reason = result_dict["reason"] 174 | err = result_dict["err"] 175 | # save custom metrics 176 | if "latency_seconds" in result_dict: 177 | row["latency_seconds"] = result_dict["latency_seconds"] 178 | if "tokens_used" in result_dict: 179 | row["tokens_used"] = result_dict["tokens_used"] 180 | row["generated_query"] = query_gen 181 | row["reason"] = reason 182 | row["error_msg"] = err 183 | # save failures into relevant columns in the dataframe 184 | if "GENERATION ERROR" in err: 185 | row["error_query_gen"] = 1 186 | elif "TIMEOUT" in err: 187 | row["timeout"] = 1 188 | else: 189 | expected_query = row["query"] 190 | db_name = row["db_name"] 191 | db_type = row["db_type"] 192 | try: 193 | exact_match, correct = compare_query_results( 194 | query_gold=expected_query, 195 | query_gen=query_gen, 196 | db_name=db_name, 197 | db_type=db_type, 198 | db_creds=db_creds_all[db_type], 199 | question=row["question"], 200 | query_category=row["query_category"], 201 | decimal_points=args.decimal_points, 202 | ) 203 | if correct: 204 | total_correct += 1 205 | row["correct"] = 1 206 | row["error_msg"] = "" 207 | else: 208 | row["correct"] = 0 209 | row["error_msg"] = "INCORRECT RESULTS" 210 | except Exception as e: 211 | row["error_db_exec"] = 1 212 | row["error_msg"] = f"EXECUTION ERROR: {str(e)}" 213 | output_rows.append(row) 214 | pbar.set_description( 215 | f"Accuracy: {round(total_correct/total_tried * 100, 2)}% ({total_correct}/{total_tried})" 216 | ) 217 | 218 | # save results to csv 219 | output_df = pd.DataFrame(output_rows) 220 | output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) 221 | # get directory of output_file and create if not exist 222 | output_dir = os.path.dirname(output_file) 223 | if not os.path.exists(output_dir): 224 | os.makedirs(output_dir) 225 | output_df.to_csv(output_file, index=False, float_format="%.2f") 226 | 227 | # get average rate of correct results 228 | avg_subset = output_df["correct"].sum() / len(output_df) 229 | print(f"Average correct rate: {avg_subset:.2f}") 230 | 231 | results = output_df.to_dict("records") 232 | with open( 233 | f"./eval-visualizer/public/{output_file.split('/')[-1].replace('.csv', '.json')}", 234 | "w", 235 | ) as f: 236 | json.dump(results, f) 237 | -------------------------------------------------------------------------------- /runners/bedrock_runner.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | import json 3 | import os 4 | from concurrent.futures import ThreadPoolExecutor, as_completed 5 | from typing import Optional 6 | 7 | from 
eval.eval import compare_query_results 8 | import pandas as pd 9 | from utils.gen_prompt import generate_prompt 10 | from utils.questions import prepare_questions_df 11 | from utils.creds import db_creds_all 12 | from tqdm import tqdm 13 | from time import time 14 | from utils.reporting import upload_results 15 | 16 | bedrock = boto3.client(service_name="bedrock-runtime") 17 | 18 | 19 | def process_row(row, model_id, decimal_points): 20 | start_time = time() 21 | 22 | body = json.dumps( 23 | { 24 | "prompt": row["prompt"], 25 | "max_gen_len": 600, 26 | "temperature": 0, 27 | "top_p": 1, 28 | } 29 | ) 30 | 31 | accept = "application/json" 32 | contentType = "application/json" 33 | response = bedrock.invoke_model( 34 | body=body, modelId=model_id, accept=accept, contentType=contentType 35 | ) 36 | model_response = json.loads(response["body"].read()) 37 | 38 | generated_query = model_response["generation"] 39 | end_time = time() 40 | 41 | generated_query = ( 42 | generated_query.split("```sql")[-1].split("```")[0].split(";")[0].strip() + ";" 43 | ) 44 | 45 | row["generated_query"] = generated_query 46 | row["latency_seconds"] = end_time - start_time 47 | row["tokens_used"] = None 48 | golden_query = row["query"] 49 | db_name = row["db_name"] 50 | db_type = row["db_type"] 51 | question = row["question"] 52 | query_category = row["query_category"] 53 | table_metadata_string = row["table_metadata_string"] 54 | exact_match = correct = 0 55 | 56 | try: 57 | exact_match, correct = compare_query_results( 58 | query_gold=golden_query, 59 | query_gen=generated_query, 60 | db_name=db_name, 61 | db_type=db_type, 62 | db_creds=db_creds_all[row["db_type"]], 63 | question=question, 64 | query_category=query_category, 65 | table_metadata_string=table_metadata_string, 66 | decimal_points=decimal_points, 67 | ) 68 | row["exact_match"] = int(exact_match) 69 | row["correct"] = int(correct) 70 | row["error_msg"] = "" 71 | except Exception as e: 72 | row["error_db_exec"] = 1 73 | row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" 74 | 75 | return row 76 | 77 | 78 | def run_bedrock_eval(args): 79 | # get params from args 80 | questions_file_list = args.questions_file 81 | prompt_file_list = args.prompt_file 82 | num_questions = args.num_questions 83 | public_data = not args.use_private_data 84 | output_file_list = args.output_file 85 | k_shot = args.k_shot 86 | max_workers = args.parallel_threads 87 | db_type = args.db_type 88 | decimal_points = args.decimal_points 89 | model_id = args.model 90 | cot_table_alias = args.cot_table_alias 91 | 92 | for questions_file, prompt_file, output_file in zip( 93 | questions_file_list, prompt_file_list, output_file_list 94 | ): 95 | print(f"Using prompt file {prompt_file}") 96 | # get questions 97 | print("Preparing questions...") 98 | print( 99 | f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" 100 | ) 101 | df = prepare_questions_df( 102 | questions_file, db_type, num_questions, k_shot, cot_table_alias 103 | ) 104 | # create a prompt for each question 105 | df["prompt"] = df.apply( 106 | lambda row: generate_prompt( 107 | prompt_file, 108 | row["question"], 109 | row["db_name"], 110 | row["db_type"], 111 | row["instructions"], 112 | row["k_shot_prompt"], 113 | row["glossary"], 114 | row["table_metadata_string"], 115 | row["prev_invalid_sql"], 116 | row["prev_error_msg"], 117 | row["question_0"], 118 | row["query_0"], 119 | row["question_1"], 120 | row["query_1"], 121 | row["cot_instructions"], 122 | row["cot_pregen"], 123 | 
public_data, 124 | args.num_columns, 125 | args.shuffle_metadata, 126 | ), 127 | axis=1, 128 | ) 129 | 130 | total_tried = 0 131 | total_correct = 0 132 | output_rows = [] 133 | 134 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 135 | futures = [] 136 | for row in df.to_dict("records"): 137 | futures.append( 138 | executor.submit(process_row, row, model_id, decimal_points) 139 | ) 140 | 141 | with tqdm(as_completed(futures), total=len(futures)) as pbar: 142 | for f in pbar: 143 | row = f.result() 144 | output_rows.append(row) 145 | if row["correct"]: 146 | total_correct += 1 147 | total_tried += 1 148 | pbar.update(1) 149 | pbar.set_description( 150 | f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" 151 | ) 152 | 153 | output_df = pd.DataFrame(output_rows) 154 | del output_df["prompt"] 155 | print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) 156 | output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) 157 | # get directory of output_file and create if not exist 158 | output_dir = os.path.dirname(output_file) 159 | if not os.path.exists(output_dir): 160 | os.makedirs(output_dir) 161 | try: 162 | output_df.to_csv(output_file, index=False, float_format="%.2f") 163 | except: 164 | output_df.to_pickle(output_file) 165 | 166 | results = output_df.to_dict("records") 167 | # upload results 168 | with open(prompt_file, "r") as f: 169 | prompt = f.read() 170 | if args.upload_url is not None: 171 | upload_results( 172 | results=results, 173 | url=args.upload_url, 174 | runner_type="api_runner", 175 | prompt=prompt, 176 | args=args, 177 | ) 178 | -------------------------------------------------------------------------------- /runners/deepseek_runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | from concurrent.futures import ThreadPoolExecutor, as_completed 3 | from typing import Dict 4 | 5 | from eval.eval import compare_query_results 6 | import pandas as pd 7 | from utils.gen_prompt import generate_prompt 8 | from utils.questions import prepare_questions_df 9 | from utils.creds import db_creds_all 10 | from tqdm import tqdm 11 | from time import time 12 | from openai import OpenAI 13 | from utils.reporting import upload_results 14 | 15 | 16 | client = OpenAI( 17 | base_url="https://api.deepseek.com", api_key=os.environ.get("DEEPSEEK_API_KEY") 18 | ) 19 | 20 | 21 | def process_row(row: Dict, model: str): 22 | start_time = time() 23 | messages = row["prompt"] 24 | if model != "deepseek-reasoner": 25 | response = client.chat.completions.create( 26 | model=model, 27 | messages=messages, 28 | max_tokens=800, 29 | temperature=0.0, 30 | ) 31 | else: 32 | response = client.chat.completions.create( 33 | model=model, 34 | messages=messages, 35 | max_tokens=800, 36 | ) 37 | content = response.choices[0].message.content 38 | generated_query = content.replace("```sql", "").replace("```", "").strip() 39 | end_time = time() 40 | 41 | row["generated_query"] = generated_query 42 | row["latency_seconds"] = end_time - start_time 43 | row["tokens_used"] = None 44 | golden_query = row["query"] 45 | db_name = row["db_name"] 46 | db_type = row["db_type"] 47 | question = row["question"] 48 | query_category = row["query_category"] 49 | table_metadata_string = row["table_metadata_string"] 50 | exact_match = correct = 0 51 | 52 | try: 53 | exact_match, correct = compare_query_results( 54 | query_gold=golden_query, 55 | query_gen=generated_query, 56 | 
db_name=db_name, 57 | db_type=db_type, 58 | db_creds=db_creds_all[row["db_type"]], 59 | question=question, 60 | query_category=query_category, 61 | table_metadata_string=table_metadata_string, 62 | ) 63 | row["exact_match"] = int(exact_match) 64 | row["correct"] = int(correct) 65 | row["error_msg"] = "" 66 | except Exception as e: 67 | row["error_db_exec"] = 1 68 | row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" 69 | 70 | return row 71 | 72 | 73 | def run_deepseek_eval(args): 74 | # get params from args 75 | questions_file_list = args.questions_file 76 | prompt_file_list = args.prompt_file 77 | num_questions = args.num_questions 78 | public_data = not args.use_private_data 79 | output_file_list = args.output_file 80 | k_shot = args.k_shot 81 | max_workers = args.parallel_threads 82 | db_type = args.db_type 83 | decimal_points = args.decimal_points 84 | model = args.model 85 | cot_table_alias = args.cot_table_alias 86 | 87 | for questions_file, prompt_file, output_file in zip( 88 | questions_file_list, prompt_file_list, output_file_list 89 | ): 90 | if not prompt_file.endswith(".json"): 91 | raise ValueError(f"Prompt file must be a JSON file. Got {prompt_file}") 92 | print(f"Using prompt file {prompt_file}") 93 | # get questions 94 | print("Preparing questions...") 95 | print( 96 | f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" 97 | ) 98 | df = prepare_questions_df( 99 | questions_file, db_type, num_questions, k_shot, cot_table_alias 100 | ) 101 | # create a prompt for each question 102 | # note that the prompt for together ai uses the openai chat API 103 | df["prompt"] = df.apply( 104 | lambda row: generate_prompt( 105 | prompt_file, 106 | row["question"], 107 | row["db_name"], 108 | row["db_type"], 109 | row["instructions"], 110 | row["k_shot_prompt"], 111 | row["glossary"], 112 | row["table_metadata_string"], 113 | row["prev_invalid_sql"], 114 | row["prev_error_msg"], 115 | row["question_0"], 116 | row["query_0"], 117 | row["question_1"], 118 | row["query_1"], 119 | row["cot_instructions"], 120 | row["cot_pregen"], 121 | public_data, 122 | args.num_columns, 123 | args.shuffle_metadata, 124 | row["table_aliases"], 125 | ), 126 | axis=1, 127 | ) 128 | 129 | total_tried = 0 130 | total_correct = 0 131 | output_rows = [] 132 | 133 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 134 | futures = [] 135 | for row in df.to_dict("records"): 136 | futures.append(executor.submit(process_row, row, model)) 137 | 138 | with tqdm(as_completed(futures), total=len(futures)) as pbar: 139 | for f in pbar: 140 | row = f.result() 141 | output_rows.append(row) 142 | if row["correct"]: 143 | total_correct += 1 144 | total_tried += 1 145 | pbar.update(1) 146 | pbar.set_description( 147 | f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" 148 | ) 149 | 150 | output_df = pd.DataFrame(output_rows) 151 | del output_df["prompt"] 152 | print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) 153 | output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) 154 | # get directory of output_file and create if not exist 155 | output_dir = os.path.dirname(output_file) 156 | if not os.path.exists(output_dir): 157 | os.makedirs(output_dir) 158 | try: 159 | output_df.to_csv(output_file, index=False, float_format="%.2f") 160 | except: 161 | output_df.to_pickle(output_file) 162 | 163 | results = output_df.to_dict("records") 164 | with open( 165 | 
f"./eval-visualizer/public/{output_file.split('/')[-1].replace('.csv', '.json')}", 166 | "w", 167 | ) as f: 168 | json.dump(results, f) 169 | 170 | print("Total cost of evaluation (in cents): ", output_df["cost_in_cents"].sum()) 171 | -------------------------------------------------------------------------------- /runners/gemini_runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | from time import time 3 | from concurrent.futures import ThreadPoolExecutor, as_completed 4 | 5 | import pandas as pd 6 | import sqlparse 7 | from tqdm import tqdm 8 | import json 9 | 10 | from eval.eval import compare_query_results 11 | from utils.creds import db_creds_all 12 | from utils.dialects import convert_postgres_ddl_to_dialect 13 | from utils.gen_prompt import to_prompt_schema 14 | from utils.questions import prepare_questions_df 15 | from utils.reporting import upload_results 16 | from utils.llm import chat_gemini 17 | 18 | 19 | def generate_prompt( 20 | prompt_file, 21 | question, 22 | db_name, 23 | db_type, 24 | instructions="", 25 | k_shot_prompt="", 26 | glossary="", 27 | table_metadata_string="", 28 | prev_invalid_sql="", 29 | prev_error_msg="", 30 | public_data=True, 31 | shuffle=True, 32 | ): 33 | if public_data: 34 | from defog_data.metadata import dbs 35 | import defog_data.supplementary as sup 36 | else: 37 | # raise Exception("Replace this with your private data import") 38 | from defog_data_private.metadata import dbs 39 | import defog_data_private.supplementary as sup 40 | 41 | with open(prompt_file, "r") as f: 42 | prompt = f.read() 43 | 44 | if table_metadata_string == "": 45 | md = dbs[db_name]["table_metadata"] 46 | pruned_metadata_ddl = to_prompt_schema(md, shuffle) 47 | pruned_metadata_ddl = convert_postgres_ddl_to_dialect( 48 | postgres_ddl=pruned_metadata_ddl, 49 | to_dialect=db_type, 50 | db_name=db_name, 51 | ) 52 | column_join = sup.columns_join.get(db_name, {}) 53 | # get join_str from column_join 54 | join_list = [] 55 | for values in column_join.values(): 56 | if isinstance(values[0], tuple): 57 | for col_pair in values: 58 | col_1, col_2 = col_pair 59 | # add to join_list 60 | join_str = f"{col_1} can be joined with {col_2}" 61 | if join_str not in join_list: 62 | join_list.append(join_str) 63 | else: 64 | col_1, col_2 = values[0] 65 | # add to join_list 66 | join_str = f"{col_1} can be joined with {col_2}" 67 | if join_str not in join_list: 68 | join_list.append(join_str) 69 | if len(join_list) > 0: 70 | join_str = "\nHere is a list of joinable columns:\n" + "\n".join(join_list) 71 | else: 72 | join_str = "" 73 | pruned_metadata_str = pruned_metadata_ddl + join_str 74 | else: 75 | pruned_metadata_str = table_metadata_string 76 | 77 | prompt = prompt.format( 78 | user_question=question, 79 | db_type=db_type, 80 | instructions=instructions, 81 | table_metadata_string=pruned_metadata_str, 82 | k_shot_prompt=k_shot_prompt, 83 | glossary=glossary, 84 | prev_invalid_sql=prev_invalid_sql, 85 | prev_error_msg=prev_error_msg, 86 | ) 87 | return prompt 88 | 89 | 90 | def process_row(row, model_name, args): 91 | start_time = time() 92 | messages = [{"role": "user", "content": row["prompt"]}] 93 | try: 94 | response = chat_gemini(messages=messages, model=model_name, temperature=0.0) 95 | generated_query = ( 96 | response.content.split("```sql", 1)[-1].split("```", 1)[0].strip() 97 | ) 98 | try: 99 | generated_query = sqlparse.format( 100 | generated_query, 101 | strip_comments=True, 102 | strip_whitespace=True, 103 | 
keyword_case="upper", 104 | ) 105 | except: 106 | pass 107 | row["generated_query"] = generated_query 108 | row["latency_seconds"] = response.time 109 | row["tokens_used"] = response.input_tokens + response.output_tokens 110 | except Exception as e: 111 | row["error_db_exec"] = 1 112 | row["error_msg"] = f"GENERATION ERROR: {e}" 113 | return row 114 | 115 | golden_query = row["query"] 116 | db_name = row["db_name"] 117 | db_type = row["db_type"] 118 | question = row["question"] 119 | query_category = row["query_category"] 120 | exact_match = correct = 0 121 | 122 | try: 123 | exact_match, correct = compare_query_results( 124 | query_gold=golden_query, 125 | query_gen=generated_query, 126 | db_name=db_name, 127 | db_type=db_type, 128 | db_creds=db_creds_all[db_type], 129 | question=question, 130 | query_category=query_category, 131 | decimal_points=args.decimal_points, 132 | ) 133 | row["exact_match"] = int(exact_match) 134 | row["correct"] = int(correct) 135 | row["error_msg"] = "" 136 | except Exception as e: 137 | row["error_db_exec"] = 1 138 | row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" 139 | 140 | return row 141 | 142 | 143 | def run_gemini_eval(args): 144 | # get params from args 145 | questions_file_list = args.questions_file 146 | prompt_file_list = args.prompt_file 147 | num_questions = args.num_questions 148 | public_data = not args.use_private_data 149 | model_name = args.model 150 | output_file_list = args.output_file 151 | k_shot = args.k_shot 152 | max_workers = args.parallel_threads 153 | db_type = args.db_type 154 | cot_table_alias = args.cot_table_alias 155 | 156 | for questions_file, prompt_file, output_file in zip( 157 | questions_file_list, prompt_file_list, output_file_list 158 | ): 159 | print(f"Using prompt file {prompt_file}") 160 | print("Preparing questions...") 161 | print( 162 | f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" 163 | ) 164 | df = prepare_questions_df( 165 | questions_file, db_type, num_questions, k_shot, cot_table_alias 166 | ) 167 | 168 | df["prompt"] = df.apply( 169 | lambda row: generate_prompt( 170 | prompt_file, 171 | row["question"], 172 | row["db_name"], 173 | row["db_type"], 174 | row["instructions"], 175 | row["k_shot_prompt"], 176 | row["glossary"], 177 | row["table_metadata_string"], 178 | row["prev_invalid_sql"], 179 | row["prev_error_msg"], 180 | public_data, 181 | args.shuffle_metadata, 182 | ), 183 | axis=1, 184 | ) 185 | 186 | total_tried = 0 187 | total_correct = 0 188 | output_rows = [] 189 | 190 | print(f"Running evaluation using {model_name}...") 191 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 192 | futures = [] 193 | for row in df.to_dict("records"): 194 | futures.append(executor.submit(process_row, row, model_name, args)) 195 | 196 | with tqdm(as_completed(futures), total=len(futures)) as pbar: 197 | for f in pbar: 198 | row = f.result() 199 | output_rows.append(row) 200 | if row.get("correct", 0): 201 | total_correct += 1 202 | total_tried += 1 203 | pbar.set_description( 204 | f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" 205 | ) 206 | 207 | output_df = pd.DataFrame(output_rows) 208 | del output_df["prompt"] 209 | print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) 210 | output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) 211 | 212 | output_dir = os.path.dirname(output_file) 213 | if not os.path.exists(output_dir): 214 | os.makedirs(output_dir) 215 | try: 
216 | output_df.to_csv(output_file, index=False, float_format="%.2f") 217 | except: 218 | output_df.to_pickle(output_file) 219 | 220 | results = output_df.to_dict("records") 221 | with open( 222 | f"./eval-visualizer/public/{output_file.split('/')[-1].replace('.csv', '.json')}", 223 | "w", 224 | ) as f: 225 | json.dump(results, f) 226 | -------------------------------------------------------------------------------- /runners/hf_runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | 4 | from eval.eval import compare_query_results 5 | import pandas as pd 6 | import torch 7 | from transformers import ( 8 | AutoTokenizer, 9 | AutoModelForCausalLM, 10 | pipeline, 11 | ) 12 | from utils.gen_prompt import generate_prompt 13 | from utils.questions import prepare_questions_df 14 | from utils.creds import db_creds_all 15 | from tqdm import tqdm 16 | from psycopg2.extensions import QueryCanceledError 17 | from time import time 18 | import gc 19 | from utils.reporting import upload_results 20 | 21 | device_map = "mps" if torch.backends.mps.is_available() else "auto" 22 | 23 | 24 | def get_tokenizer_model(model_name: Optional[str], adapter_path: Optional[str]): 25 | """ 26 | Load a HuggingFace tokenizer and model. 27 | You may supply either a normal huggingface model name, or a peft adapter path. 28 | """ 29 | if adapter_path is not None: 30 | from peft import PeftModel, PeftConfig 31 | 32 | print(f"Loading adapter model {adapter_path}") 33 | config = PeftConfig.from_pretrained(adapter_path) 34 | tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path) 35 | model = AutoModelForCausalLM.from_pretrained( 36 | config.base_model_name_or_path, 37 | torch_dtype=torch.float16, 38 | trust_remote_code=True, 39 | use_cache=True, 40 | device_map=device_map, 41 | ) 42 | print(f"Loading adapter {adapter_path}") 43 | model = PeftModel.from_pretrained(model, adapter_path) 44 | model = model.merge_and_unload() 45 | print(f"Merged adapter {adapter_path}") 46 | else: 47 | print(f"Loading model {model_name}") 48 | try: 49 | tokenizer = AutoTokenizer.from_pretrained(model_name) 50 | except: 51 | tokenizer = AutoTokenizer.from_pretrained( 52 | "meta-llama/Meta-Llama-3-8B-Instruct" 53 | ) 54 | 55 | tokenizer.pad_token_id = tokenizer.eos_token_id 56 | model = AutoModelForCausalLM.from_pretrained( 57 | model_name, 58 | torch_dtype=torch.float16, 59 | trust_remote_code=True, 60 | device_map=device_map, 61 | ) 62 | return tokenizer, model 63 | 64 | 65 | def run_hf_eval(args): 66 | # get params from args 67 | questions_file_list = args.questions_file 68 | prompt_file_list = args.prompt_file 69 | num_questions = args.num_questions 70 | public_data = not args.use_private_data 71 | model_name = args.model 72 | adapter_path = args.adapter 73 | output_file_list = args.output_file 74 | k_shot = args.k_shot 75 | db_type = args.db_type 76 | num_beams = args.num_beams 77 | cot_table_alias = args.cot_table_alias 78 | 79 | if model_name is None and adapter_path is None: 80 | raise ValueError( 81 | "You must supply either a model name or an adapter path to run an evaluation." 
82 | ) 83 | 84 | print(f"Questions prepared\nNow loading model...") 85 | # initialize tokenizer and model 86 | tokenizer, model = get_tokenizer_model(model_name, adapter_path) 87 | 88 | if "8b" in model_name.lower(): 89 | # do this since it doesn't seem to have been done by default 90 | tokenizer.padding_side = "left" 91 | 92 | tokenizer.pad_token_id = tokenizer.eos_token_id 93 | model.tie_weights() 94 | 95 | print("model loaded\nnow generating and evaluating predictions...") 96 | 97 | # from here, we generate and evaluate predictions 98 | # eos_token_id = tokenizer.convert_tokens_to_ids(["```"])[0] 99 | pipe = pipeline( 100 | "text-generation", model=model, tokenizer=tokenizer, batch_size=args.batch_size 101 | ) 102 | 103 | for questions_file, prompt_file, output_file in zip( 104 | questions_file_list, prompt_file_list, output_file_list 105 | ): 106 | print(f"Using prompt file {prompt_file}") 107 | # get questions 108 | print("Preparing questions...") 109 | print( 110 | f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" 111 | ) 112 | df = prepare_questions_df( 113 | questions_file, db_type, num_questions, k_shot, cot_table_alias 114 | ) 115 | # create a prompt for each question 116 | df["prompt"] = df.apply( 117 | lambda row: generate_prompt( 118 | prompt_file, 119 | row["question"], 120 | row["db_name"], 121 | row["db_type"], 122 | row["instructions"], 123 | row["k_shot_prompt"], 124 | row["glossary"], 125 | row["table_metadata_string"], 126 | row["prev_invalid_sql"], 127 | row["prev_error_msg"], 128 | row["question_0"], 129 | row["query_0"], 130 | row["question_1"], 131 | row["query_1"], 132 | row["cot_instructions"], 133 | row["cot_pregen"], 134 | public_data, 135 | args.num_columns, 136 | args.shuffle_metadata, 137 | ), 138 | axis=1, 139 | ) 140 | 141 | total_tried = 0 142 | total_correct = 0 143 | output_rows = [] 144 | 145 | def chunk_dataframe(df, chunk_size): 146 | """Yield successive chunk_size chunks from df.""" 147 | for i in range(0, len(df), chunk_size): 148 | yield df[i : min(i + chunk_size, len(df))] 149 | 150 | df_chunks = list(chunk_dataframe(df, args.batch_size)) 151 | 152 | with tqdm(total=len(df)) as pbar: 153 | for batch in df_chunks: 154 | prompts = batch["prompt"].tolist() 155 | generated_queries = pipe( 156 | prompts, 157 | max_new_tokens=600, 158 | do_sample=False, 159 | num_beams=num_beams, 160 | num_return_sequences=1, 161 | return_full_text=False, 162 | eos_token_id=tokenizer.eos_token_id, 163 | pad_token_id=tokenizer.eos_token_id, 164 | temperature=None, 165 | top_p=None, 166 | ) 167 | gc.collect() 168 | torch.cuda.empty_cache() 169 | torch.cuda.synchronize() 170 | 171 | for row, result in zip(batch.to_dict("records"), generated_queries): 172 | total_tried += 1 173 | # we set return_full_text to False so that we don't get the prompt text in the generated text 174 | # this simplifies our postprocessing to deal with just the truncation of the end of the query 175 | 176 | if "[SQL]" not in row["prompt"]: 177 | generated_query = ( 178 | result[0]["generated_text"] 179 | .split("```")[0] 180 | .split(";")[0] 181 | .strip() 182 | + ";" 183 | ) 184 | else: 185 | generated_query = ( 186 | result[0]["generated_text"] 187 | .split("[/SQL]")[0] 188 | .split(";")[0] 189 | .strip() 190 | + ";" 191 | ) 192 | 193 | gc.collect() 194 | if torch.cuda.is_available(): 195 | torch.cuda.empty_cache() 196 | torch.cuda.synchronize() 197 | 198 | row["generated_query"] = generated_query 199 | row["latency_seconds"] = None 200 | golden_query = 
row["query"] 201 | db_name = row["db_name"] 202 | db_type = row["db_type"] 203 | question = row["question"] 204 | query_category = row["query_category"] 205 | table_metadata_string = row["table_metadata_string"] 206 | exact_match = correct = 0 207 | db_creds = db_creds_all[db_type] 208 | 209 | try: 210 | exact_match, correct = compare_query_results( 211 | query_gold=golden_query, 212 | query_gen=generated_query, 213 | db_name=db_name, 214 | db_type=db_type, 215 | db_creds=db_creds, 216 | question=question, 217 | query_category=query_category, 218 | table_metadata_string=table_metadata_string, 219 | decimal_points=args.decimal_points, 220 | ) 221 | row["exact_match"] = int(exact_match) 222 | row["correct"] = int(correct) 223 | row["error_msg"] = "" 224 | if correct: 225 | total_correct += 1 226 | except QueryCanceledError as e: 227 | row["timeout"] = 1 228 | row["error_msg"] = f"QUERY EXECUTION TIMEOUT: {e}" 229 | except Exception as e: 230 | row["error_db_exec"] = 1 231 | row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" 232 | 233 | output_rows.append(row) 234 | pbar.update(1) 235 | pbar.set_description( 236 | f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" 237 | ) 238 | 239 | output_df = pd.DataFrame(output_rows) 240 | del output_df["prompt"] 241 | print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) 242 | output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) 243 | # get directory of output_file and create if not exist 244 | output_dir = os.path.dirname(output_file) 245 | if not os.path.exists(output_dir): 246 | os.makedirs(output_dir) 247 | output_df.to_csv(output_file, index=False, float_format="%.2f") 248 | 249 | results = output_df.to_dict("records") 250 | # upload results 251 | with open(prompt_file, "r") as f: 252 | prompt = f.read() 253 | if args.upload_url is not None: 254 | upload_results( 255 | results=results, 256 | url=args.upload_url, 257 | runner_type="hf_runner", 258 | prompt=prompt, 259 | args=args, 260 | ) 261 | -------------------------------------------------------------------------------- /runners/llama_cpp_runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from eval.eval import compare_query_results 4 | import pandas as pd 5 | from utils.gen_prompt import generate_prompt 6 | from utils.questions import prepare_questions_df 7 | from utils.creds import db_creds_all 8 | from tqdm import tqdm 9 | from time import time 10 | from utils.reporting import upload_results 11 | from llama_cpp import Llama 12 | 13 | 14 | def process_row(llm, row, args): 15 | start_time = time() 16 | prompt = row["prompt"] 17 | generated_query = ( 18 | llm( 19 | prompt, 20 | max_tokens=512, 21 | temperature=0, 22 | top_p=1, 23 | echo=False, 24 | repeat_penalty=1.0, 25 | )["choices"][0]["text"] 26 | .split(";")[0] 27 | .split("```")[0] 28 | .strip() 29 | + ";" 30 | ) 31 | end_time = time() 32 | row["generated_query"] = generated_query 33 | row["latency_seconds"] = end_time - start_time 34 | golden_query = row["query"] 35 | db_name = row["db_name"] 36 | db_type = row["db_type"] 37 | question = row["question"] 38 | query_category = row["query_category"] 39 | table_metadata_string = row["table_metadata_string"] 40 | exact_match = correct = 0 41 | 42 | try: 43 | exact_match, correct = compare_query_results( 44 | query_gold=golden_query, 45 | query_gen=generated_query, 46 | db_name=db_name, 47 | db_type=db_type, 48 | 
db_creds=db_creds_all[row["db_type"]], 49 | question=question, 50 | query_category=query_category, 51 | table_metadata_string=table_metadata_string, 52 | decimal_points=args.decimal_points, 53 | ) 54 | row["exact_match"] = int(exact_match) 55 | row["correct"] = int(correct) 56 | row["error_msg"] = "" 57 | except Exception as e: 58 | row["error_db_exec"] = 1 59 | row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" 60 | return row 61 | 62 | 63 | def run_llama_cpp_eval(args): 64 | # get params from args 65 | questions_file_list = args.questions_file 66 | prompt_file_list = args.prompt_file 67 | num_questions = args.num_questions 68 | public_data = not args.use_private_data 69 | model_path = args.model 70 | output_file_list = args.output_file 71 | k_shot = args.k_shot 72 | db_type = args.db_type 73 | cot_table_alias = args.cot_table_alias 74 | 75 | llm = Llama(model_path=model_path, n_gpu_layers=-1, n_ctx=4096) 76 | 77 | for questions_file, prompt_file, output_file in zip( 78 | questions_file_list, prompt_file_list, output_file_list 79 | ): 80 | print(f"Using prompt file {prompt_file}") 81 | # get questions 82 | print("Preparing questions...") 83 | print( 84 | f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" 85 | ) 86 | df = prepare_questions_df( 87 | questions_file, db_type, num_questions, k_shot, cot_table_alias 88 | ) 89 | # create a prompt for each question 90 | df["prompt"] = df.apply( 91 | lambda row: generate_prompt( 92 | prompt_file, 93 | row["question"], 94 | row["db_name"], 95 | row["db_type"], 96 | row["instructions"], 97 | row["k_shot_prompt"], 98 | row["glossary"], 99 | row["table_metadata_string"], 100 | row["prev_invalid_sql"], 101 | row["prev_error_msg"], 102 | row["question_0"], 103 | row["query_0"], 104 | row["question_1"], 105 | row["query_1"], 106 | row["cot_instructions"], 107 | row["cot_pregen"], 108 | public_data, 109 | args.num_columns, 110 | args.shuffle_metadata, 111 | ), 112 | axis=1, 113 | ) 114 | 115 | total_tried = 0 116 | total_correct = 0 117 | output_rows = [] 118 | 119 | with tqdm(total=len(df)) as pbar: 120 | for row in df.to_dict("records"): 121 | row = process_row(llm, row, args) 122 | output_rows.append(row) 123 | if row["correct"]: 124 | total_correct += 1 125 | total_tried += 1 126 | pbar.update(1) 127 | pbar.set_description( 128 | f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" 129 | ) 130 | 131 | output_df = pd.DataFrame(output_rows) 132 | del output_df["prompt"] 133 | print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) 134 | output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) 135 | # get directory of output_file and create if not exist 136 | output_dir = os.path.dirname(output_file) 137 | if not os.path.exists(output_dir): 138 | os.makedirs(output_dir) 139 | try: 140 | output_df.to_csv(output_file, index=False, float_format="%.2f") 141 | except: 142 | output_df.to_pickle(output_file) 143 | 144 | results = output_df.to_dict("records") 145 | # upload results 146 | with open(prompt_file, "r") as f: 147 | prompt = f.read() 148 | if args.upload_url is not None: 149 | upload_results( 150 | results=results, 151 | url=args.upload_url, 152 | runner_type="llama_cpp_runner", 153 | prompt=prompt, 154 | args=args, 155 | ) 156 | -------------------------------------------------------------------------------- /runners/mistral_runner.py: -------------------------------------------------------------------------------- 1 | import os 
2 | from time import time 3 | from concurrent.futures import ThreadPoolExecutor, as_completed 4 | 5 | from mistralai.client import MistralClient 6 | from mistralai.models.chat_completion import ChatMessage 7 | import pandas as pd 8 | from tqdm import tqdm 9 | 10 | from eval.eval import compare_query_results 11 | from utils.creds import db_creds_all 12 | from utils.gen_prompt import to_prompt_schema 13 | from utils.dialects import convert_postgres_ddl_to_dialect 14 | from utils.questions import prepare_questions_df 15 | from utils.reporting import upload_results 16 | 17 | api_key = os.environ.get("MISTRAL_API_KEY") 18 | client = MistralClient(api_key=api_key) 19 | 20 | 21 | def generate_prompt( 22 | prompt_file, 23 | question, 24 | db_name, 25 | db_type, 26 | instructions="", 27 | k_shot_prompt="", 28 | glossary="", 29 | table_metadata_string="", 30 | prev_invalid_sql="", 31 | prev_error_msg="", 32 | public_data=True, 33 | shuffle=True, 34 | ): 35 | with open(prompt_file, "r") as f: 36 | prompt = f.read() 37 | 38 | # Check that System and User prompts are in the prompt file 39 | if "System:" not in prompt or "User:" not in prompt: 40 | raise ValueError("Invalid prompt file. Please use prompt_mistral.md") 41 | sys_prompt = prompt.split("System:")[1].split("User:")[0].strip() 42 | user_prompt = prompt.split("User:")[1].strip() 43 | 44 | if table_metadata_string == "": 45 | if public_data: 46 | from defog_data.metadata import dbs 47 | import defog_data.supplementary as sup 48 | else: 49 | from defog_data_private.metadata import dbs 50 | import defog_data_private.supplementary as sup 51 | 52 | md = dbs[db_name]["table_metadata"] 53 | metadata_ddl = to_prompt_schema(md, shuffle) 54 | metadata_ddl = convert_postgres_ddl_to_dialect( 55 | postgres_ddl=metadata_ddl, 56 | to_dialect=db_type, 57 | db_name=db_name, 58 | ) 59 | column_join = sup.columns_join.get(db_name, {}) 60 | # get join_str from column_join 61 | join_list = [] 62 | for values in column_join.values(): 63 | if isinstance(values[0], tuple): 64 | for col_pair in values: 65 | col_1, col_2 = col_pair 66 | # add to join_list 67 | join_str = f"{col_1} can be joined with {col_2}" 68 | if join_str not in join_list: 69 | join_list.append(join_str) 70 | else: 71 | col_1, col_2 = values[0] 72 | # add to join_list 73 | join_str = f"{col_1} can be joined with {col_2}" 74 | if join_str not in join_list: 75 | join_list.append(join_str) 76 | if len(join_list) > 0: 77 | join_str = "\nHere is a list of joinable columns:\n" + "\n".join(join_list) 78 | else: 79 | join_str = "" 80 | pruned_metadata_str = metadata_ddl + join_str 81 | else: 82 | pruned_metadata_str = table_metadata_string 83 | 84 | user_prompt = user_prompt.format( 85 | user_question=question, 86 | instructions=instructions, 87 | table_metadata_string=pruned_metadata_str, 88 | k_shot_prompt=k_shot_prompt, 89 | glossary=glossary, 90 | prev_invalid_sql=prev_invalid_sql, 91 | prev_error_msg=prev_error_msg, 92 | ) 93 | messages = [ 94 | ChatMessage( 95 | role="system", 96 | content=sys_prompt, 97 | ), 98 | ChatMessage( 99 | role="user", 100 | content=user_prompt, 101 | ), 102 | ] 103 | return messages 104 | 105 | 106 | def process_row(row, model, args): 107 | start_time = time() 108 | chat_response = client.chat( 109 | model=model, 110 | messages=row["prompt"], 111 | temperature=0, 112 | max_tokens=600, 113 | ) 114 | end_time = time() 115 | generated_query = chat_response.choices[0].message.content 116 | 117 | try: 118 | # replace all backslashes with empty string 119 | generated_query = 
generated_query.replace("\\", "") 120 | 121 | generated_query = generated_query.split(";")[0].split("```sql")[-1].strip() 122 | generated_query = [i for i in generated_query.split("```") if i.strip() != ""][ 123 | 0 124 | ] + ";" 125 | except Exception as e: 126 | print(e) 127 | generated_query = chat_response.choices[0].message.content 128 | row["generated_query"] = generated_query 129 | row["latency_seconds"] = end_time - start_time 130 | golden_query = row["query"] 131 | db_name = row["db_name"] 132 | db_type = row["db_type"] 133 | question = row["question"] 134 | query_category = row["query_category"] 135 | table_metadata_string = row["table_metadata_string"] 136 | exact_match = correct = 0 137 | 138 | try: 139 | exact_match, correct = compare_query_results( 140 | query_gold=golden_query, 141 | query_gen=generated_query, 142 | db_name=db_name, 143 | db_type=db_type, 144 | db_creds=db_creds_all[row["db_type"]], 145 | question=question, 146 | query_category=query_category, 147 | table_metadata_string=table_metadata_string, 148 | decimal_points=args.decimal_points, 149 | ) 150 | row["exact_match"] = int(exact_match) 151 | row["correct"] = int(correct) 152 | row["error_msg"] = "" 153 | except Exception as e: 154 | row["error_db_exec"] = 1 155 | row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" 156 | 157 | return row 158 | 159 | 160 | def run_mistral_eval(args): 161 | # get params from args 162 | questions_file_list = args.questions_file 163 | prompt_file_list = args.prompt_file 164 | num_questions = args.num_questions 165 | public_data = not args.use_private_data 166 | model = args.model 167 | output_file_list = args.output_file 168 | k_shot = args.k_shot 169 | max_workers = args.parallel_threads 170 | db_type = args.db_type 171 | cot_table_alias = args.cot_table_alias 172 | 173 | for questions_file, prompt_file, output_file in zip( 174 | questions_file_list, prompt_file_list, output_file_list 175 | ): 176 | print(f"Using prompt file {prompt_file}") 177 | # get questions 178 | print("Preparing questions...") 179 | print( 180 | f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" 181 | ) 182 | df = prepare_questions_df( 183 | questions_file, db_type, num_questions, k_shot, cot_table_alias 184 | ) 185 | # create a prompt for each question 186 | df["prompt"] = df.apply( 187 | lambda row: generate_prompt( 188 | prompt_file, 189 | row["question"], 190 | row["db_name"], 191 | row["db_type"], 192 | row["instructions"], 193 | row["k_shot_prompt"], 194 | row["glossary"], 195 | row["table_metadata_string"], 196 | row["prev_invalid_sql"], 197 | row["prev_error_msg"], 198 | public_data, 199 | args.shuffle_metadata, 200 | ), 201 | axis=1, 202 | ) 203 | 204 | total_tried = 0 205 | total_correct = 0 206 | output_rows = [] 207 | 208 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 209 | futures = [] 210 | for row in df.to_dict("records"): 211 | futures.append(executor.submit(process_row, row, model, args)) 212 | 213 | with tqdm(as_completed(futures), total=len(futures)) as pbar: 214 | for f in pbar: 215 | row = f.result() 216 | output_rows.append(row) 217 | if row.get("correct", 0): 218 | total_correct += 1 219 | total_tried += 1 220 | pbar.set_description( 221 | f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" 222 | ) 223 | 224 | output_df = pd.DataFrame(output_rows) 225 | del output_df["prompt"] 226 | print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) 227 | output_df = 
output_df.sort_values(by=["db_name", "query_category", "question"]) 228 | # get directory of output_file and create if not exist 229 | output_dir = os.path.dirname(output_file) 230 | if not os.path.exists(output_dir): 231 | os.makedirs(output_dir) 232 | try: 233 | output_df.to_csv(output_file, index=False, float_format="%.2f") 234 | except: 235 | output_df.to_pickle(output_file) 236 | 237 | results = output_df.to_dict("records") 238 | # upload results 239 | with open(prompt_file, "r") as f: 240 | prompt = f.read() 241 | if args.upload_url is not None: 242 | upload_results( 243 | results=results, 244 | url=args.upload_url, 245 | runner_type="mistral_runner", 246 | prompt=prompt, 247 | args=args, 248 | ) 249 | -------------------------------------------------------------------------------- /runners/mlx_runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from eval.eval import compare_query_results 4 | import pandas as pd 5 | from utils.gen_prompt import generate_prompt 6 | from utils.questions import prepare_questions_df 7 | from utils.creds import db_creds_all 8 | from tqdm import tqdm 9 | from time import time 10 | from utils.reporting import upload_results 11 | from mlx_lm import load, generate 12 | 13 | 14 | def process_row(model, tokenizer, row, args): 15 | start_time = time() 16 | prompt = row["prompt"] 17 | 18 | generated_query = ( 19 | generate(model, tokenizer, prompt=prompt, max_tokens=512, temp=0, verbose=True) 20 | .split(";")[0] 21 | .split("```")[0] 22 | .strip() 23 | + ";" 24 | ) 25 | end_time = time() 26 | row["generated_query"] = generated_query 27 | row["latency_seconds"] = end_time - start_time 28 | golden_query = row["query"] 29 | db_name = row["db_name"] 30 | db_type = row["db_type"] 31 | question = row["question"] 32 | query_category = row["query_category"] 33 | table_metadata_string = row["table_metadata_string"] 34 | exact_match = correct = 0 35 | 36 | try: 37 | exact_match, correct = compare_query_results( 38 | query_gold=golden_query, 39 | query_gen=generated_query, 40 | db_name=db_name, 41 | db_type=db_type, 42 | db_creds=db_creds_all[row["db_type"]], 43 | question=question, 44 | query_category=query_category, 45 | table_metadata_string=table_metadata_string, 46 | decimal_points=args.decimal_points, 47 | ) 48 | row["exact_match"] = int(exact_match) 49 | row["correct"] = int(correct) 50 | row["error_msg"] = "" 51 | except Exception as e: 52 | row["error_db_exec"] = 1 53 | row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" 54 | return row 55 | 56 | 57 | def run_mlx_eval(args): 58 | # get params from args 59 | questions_file_list = args.questions_file 60 | prompt_file_list = args.prompt_file 61 | num_questions = args.num_questions 62 | public_data = not args.use_private_data 63 | model_path = args.model 64 | output_file_list = args.output_file 65 | k_shot = args.k_shot 66 | db_type = args.db_type 67 | cot_table_alias = args.cot_table_alias 68 | 69 | model, tokenizer = load(model_path) 70 | 71 | for questions_file, prompt_file, output_file in zip( 72 | questions_file_list, prompt_file_list, output_file_list 73 | ): 74 | print(f"Using prompt file {prompt_file}") 75 | # get questions 76 | print("Preparing questions...") 77 | print( 78 | f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" 79 | ) 80 | df = prepare_questions_df( 81 | questions_file, db_type, num_questions, k_shot, cot_table_alias 82 | ) 83 | # create a prompt for each question 84 | df["prompt"] = df.apply( 
85 | lambda row: generate_prompt( 86 | prompt_file, 87 | row["question"], 88 | row["db_name"], 89 | row["db_type"], 90 | row["instructions"], 91 | row["k_shot_prompt"], 92 | row["glossary"], 93 | row["table_metadata_string"], 94 | row["prev_invalid_sql"], 95 | row["prev_error_msg"], 96 | row["question_0"], 97 | row["query_0"], 98 | row["question_1"], 99 | row["query_1"], 100 | row["cot_instructions"], 101 | row["cot_pregen"], 102 | public_data, 103 | args.num_columns, 104 | args.shuffle_metadata, 105 | ), 106 | axis=1, 107 | ) 108 | 109 | total_tried = 0 110 | total_correct = 0 111 | output_rows = [] 112 | 113 | with tqdm(total=len(df)) as pbar: 114 | for row in df.to_dict("records"): 115 | row = process_row(model, tokenizer, row, args) 116 | output_rows.append(row) 117 | if row["correct"]: 118 | total_correct += 1 119 | total_tried += 1 120 | pbar.update(1) 121 | pbar.set_description( 122 | f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" 123 | ) 124 | 125 | output_df = pd.DataFrame(output_rows) 126 | del output_df["prompt"] 127 | print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) 128 | output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) 129 | # get directory of output_file and create if not exist 130 | output_dir = os.path.dirname(output_file) 131 | if not os.path.exists(output_dir): 132 | os.makedirs(output_dir) 133 | try: 134 | output_df.to_csv(output_file, index=False, float_format="%.2f") 135 | except: 136 | output_df.to_pickle(output_file) 137 | 138 | results = output_df.to_dict("records") 139 | # upload results 140 | with open(prompt_file, "r") as f: 141 | prompt = f.read() 142 | if args.upload_url is not None: 143 | upload_results( 144 | results=results, 145 | url=args.upload_url, 146 | runner_type="mlx_runner", 147 | prompt=prompt, 148 | args=args, 149 | ) 150 | -------------------------------------------------------------------------------- /runners/openai_runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | from time import time 3 | from concurrent.futures import ThreadPoolExecutor, as_completed 4 | import json 5 | 6 | import pandas as pd 7 | import sqlparse 8 | from tqdm import tqdm 9 | 10 | from eval.eval import compare_query_results 11 | from utils.creds import db_creds_all 12 | from utils.dialects import convert_postgres_ddl_to_dialect 13 | from utils.gen_prompt import to_prompt_schema 14 | from utils.questions import prepare_questions_df 15 | from utils.reporting import upload_results 16 | from utils.llm import chat_openai 17 | 18 | 19 | def generate_prompt( 20 | prompt_file, 21 | question, 22 | db_name, 23 | db_type, 24 | instructions="", 25 | k_shot_prompt="", 26 | table_metadata_string="", 27 | public_data=True, 28 | shuffle=True, 29 | ): 30 | if public_data: 31 | from defog_data.metadata import dbs 32 | import defog_data.supplementary as sup 33 | else: 34 | from defog_data_private.metadata import dbs 35 | import defog_data_private.supplementary as sup 36 | 37 | with open(prompt_file, "r") as f: 38 | prompt = json.load(f) 39 | 40 | if table_metadata_string == "": 41 | md = dbs[db_name]["table_metadata"] 42 | pruned_metadata_ddl = to_prompt_schema(md, shuffle) 43 | pruned_metadata_ddl = convert_postgres_ddl_to_dialect( 44 | postgres_ddl=pruned_metadata_ddl, 45 | to_dialect=db_type, 46 | db_name=db_name, 47 | ) 48 | column_join = sup.columns_join.get(db_name, {}) 49 | join_list = [] 50 | for values in 
column_join.values(): 51 | if isinstance(values[0], tuple): 52 | for col_pair in values: 53 | col_1, col_2 = col_pair 54 | join_str = f"{col_1} can be joined with {col_2}" 55 | if join_str not in join_list: 56 | join_list.append(join_str) 57 | else: 58 | col_1, col_2 = values[0] 59 | join_str = f"{col_1} can be joined with {col_2}" 60 | if join_str not in join_list: 61 | join_list.append(join_str) 62 | if len(join_list) > 0: 63 | join_str = "\nHere is a list of joinable columns:\n" + "\n".join(join_list) 64 | else: 65 | join_str = "" 66 | pruned_metadata_str = pruned_metadata_ddl + join_str 67 | else: 68 | pruned_metadata_str = table_metadata_string 69 | 70 | if prompt[0]["role"] == "system": 71 | prompt[0]["content"] = prompt[0]["content"].format( 72 | db_type=db_type, 73 | ) 74 | prompt[1]["content"] = prompt[1]["content"].format( 75 | user_question=question, 76 | instructions=instructions, 77 | table_metadata_string=pruned_metadata_str, 78 | k_shot_prompt=k_shot_prompt, 79 | ) 80 | else: 81 | prompt[0]["content"] = prompt[1]["content"].format( 82 | db_type=db_type, 83 | user_question=question, 84 | instructions=instructions, 85 | table_metadata_string=pruned_metadata_str, 86 | k_shot_prompt=k_shot_prompt, 87 | ) 88 | 89 | return prompt 90 | 91 | 92 | def process_row(row, model_name, args): 93 | start_time = time() 94 | messages = generate_prompt( 95 | prompt_file=args.prompt_file[0], 96 | question=row["question"], 97 | db_name=row["db_name"], 98 | db_type=args.db_type, 99 | instructions=row["instructions"], 100 | k_shot_prompt=row["k_shot_prompt"], 101 | table_metadata_string=row["table_metadata_string"], 102 | public_data=not args.use_private_data, 103 | shuffle=args.shuffle_metadata, 104 | ) 105 | try: 106 | response = chat_openai(messages=messages, model=model_name, temperature=0.0) 107 | generated_query = ( 108 | response.content.split("```sql", 1)[-1].split("```", 1)[0].strip() 109 | ) 110 | try: 111 | generated_query = sqlparse.format( 112 | generated_query, reindent=True, keyword_case="upper" 113 | ) 114 | except: 115 | pass 116 | return { 117 | "query": generated_query, 118 | "reason": "", 119 | "err": "", 120 | "latency_seconds": time() - start_time, 121 | "tokens_used": response.input_tokens + response.output_tokens, 122 | "cost_in_cents": response.cost_in_cents, 123 | } 124 | except Exception as e: 125 | return { 126 | "query": "", 127 | "reason": "", 128 | "err": f"GENERATION ERROR: {str(e)}", 129 | "latency_seconds": time() - start_time, 130 | "tokens_used": 0, 131 | "cost_in_cents": None, 132 | } 133 | 134 | 135 | def run_openai_eval(args): 136 | # get params from args 137 | questions_file_list = args.questions_file 138 | prompt_file_list = args.prompt_file 139 | output_file_list = args.output_file 140 | num_questions = args.num_questions 141 | k_shot = args.k_shot 142 | db_type = args.db_type 143 | cot_table_alias = args.cot_table_alias 144 | 145 | for questions_file, prompt_file, output_file in zip( 146 | questions_file_list, prompt_file_list, output_file_list 147 | ): 148 | print(f"Using prompt file {prompt_file}") 149 | print("Preparing questions...") 150 | print( 151 | f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" 152 | ) 153 | question_query_df = prepare_questions_df( 154 | questions_file, db_type, num_questions, k_shot, cot_table_alias 155 | ) 156 | input_rows = question_query_df.to_dict("records") 157 | output_rows = [] 158 | with ThreadPoolExecutor(args.parallel_threads) as executor: 159 | futures = [] 160 | for row 
in input_rows: 161 | generated_query_fut = executor.submit( 162 | process_row, 163 | row=row, 164 | model_name=args.model, 165 | args=args, 166 | ) 167 | futures.append(generated_query_fut) 168 | 169 | total_tried = 0 170 | total_correct = 0 171 | for f in (pbar := tqdm(as_completed(futures), total=len(futures))): 172 | total_tried += 1 173 | i = futures.index(f) 174 | row = input_rows[i] 175 | result_dict = f.result() 176 | query_gen = result_dict["query"] 177 | reason = result_dict["reason"] 178 | err = result_dict["err"] 179 | # save custom metrics 180 | if "latency_seconds" in result_dict: 181 | row["latency_seconds"] = result_dict["latency_seconds"] 182 | if "tokens_used" in result_dict: 183 | row["tokens_used"] = result_dict["tokens_used"] 184 | if "cost_in_cents" in result_dict: 185 | row["cost_in_cents"] = result_dict["cost_in_cents"] 186 | row["generated_query"] = query_gen 187 | row["reason"] = reason 188 | row["error_msg"] = err 189 | # save failures into relevant columns in the dataframe 190 | if "GENERATION ERROR" in err: 191 | row["error_query_gen"] = 1 192 | else: 193 | expected_query = row["query"] 194 | db_name = row["db_name"] 195 | db_type = row["db_type"] 196 | try: 197 | exact_match, correct = compare_query_results( 198 | query_gold=expected_query, 199 | query_gen=query_gen, 200 | db_name=db_name, 201 | db_type=db_type, 202 | question=row["question"], 203 | query_category=row["query_category"], 204 | db_creds=db_creds_all[db_type], 205 | ) 206 | if correct: 207 | total_correct += 1 208 | row["correct"] = 1 209 | row["error_msg"] = "" 210 | else: 211 | row["correct"] = 0 212 | row["error_msg"] = "INCORRECT RESULTS" 213 | except Exception as e: 214 | row["correct"] = 0 215 | row["error_db_exec"] = 1 216 | row["error_msg"] = f"EXECUTION ERROR: {str(e)}" 217 | output_rows.append(row) 218 | pbar.set_description( 219 | f"Accuracy: {round(total_correct/total_tried * 100, 2)}% ({total_correct}/{total_tried})" 220 | ) 221 | 222 | # save results to csv 223 | output_df = pd.DataFrame(output_rows) 224 | output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) 225 | if "prompt" in output_df.columns: 226 | del output_df["prompt"] 227 | # get num rows, mean correct, mean error_db_exec for each query_category 228 | agg_stats = ( 229 | output_df.groupby("query_category") 230 | .agg( 231 | num_rows=("db_name", "count"), 232 | mean_correct=("correct", "mean"), 233 | mean_error_db_exec=("error_db_exec", "mean"), 234 | ) 235 | .reset_index() 236 | ) 237 | print(agg_stats) 238 | # get directory of output_file and create if not exist 239 | output_dir = os.path.dirname(output_file) 240 | if not os.path.exists(output_dir): 241 | os.makedirs(output_dir) 242 | output_df.to_csv(output_file, index=False, float_format="%.2f") 243 | 244 | # get average rate of correct results 245 | avg_subset = output_df["correct"].sum() / len(output_df) 246 | print(f"Average correct rate: {avg_subset:.2f}") 247 | 248 | results = output_df.to_dict("records") 249 | with open( 250 | f"./eval-visualizer/public/{output_file.split('/')[-1].replace('.csv', '.json')}", 251 | "w", 252 | ) as f: 253 | json.dump(results, f) 254 | 255 | print("Total cost of evaluation (in cents): ", output_df["cost_in_cents"].sum()) 256 | -------------------------------------------------------------------------------- /runners/together_runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | from concurrent.futures import ThreadPoolExecutor, as_completed 3 | from 
typing import Dict 4 | 5 | from eval.eval import compare_query_results 6 | import pandas as pd 7 | from utils.gen_prompt import generate_prompt 8 | from utils.questions import prepare_questions_df 9 | from utils.creds import db_creds_all 10 | from tqdm import tqdm 11 | from time import time 12 | from together import Together 13 | from utils.reporting import upload_results 14 | 15 | 16 | client = Together(api_key=os.environ.get("TOGETHER_API_KEY")) 17 | 18 | 19 | def process_row(row: Dict, model: str): 20 | start_time = time() 21 | if model.startswith("meta-llama"): 22 | stop = ["<|eot_id|>", "<|eom_id|>"] 23 | else: 24 | print( 25 | "Undefined stop token(s). Please specify the stop token(s) for the model." 26 | ) 27 | stop = [] 28 | messages = row["prompt"] 29 | response = client.chat.completions.create( 30 | model=model, 31 | messages=messages, 32 | max_tokens=800, 33 | temperature=0.0, 34 | stop=stop, 35 | stream=False, 36 | ) 37 | content = response.choices[0].message.content 38 | generated_query = content.split("```", 1)[0].strip() 39 | end_time = time() 40 | 41 | row["generated_query"] = generated_query 42 | row["latency_seconds"] = end_time - start_time 43 | row["tokens_used"] = None 44 | golden_query = row["query"] 45 | db_name = row["db_name"] 46 | db_type = row["db_type"] 47 | question = row["question"] 48 | query_category = row["query_category"] 49 | table_metadata_string = row["table_metadata_string"] 50 | exact_match = correct = 0 51 | 52 | try: 53 | exact_match, correct = compare_query_results( 54 | query_gold=golden_query, 55 | query_gen=generated_query, 56 | db_name=db_name, 57 | db_type=db_type, 58 | db_creds=db_creds_all[row["db_type"]], 59 | question=question, 60 | query_category=query_category, 61 | table_metadata_string=table_metadata_string, 62 | ) 63 | row["exact_match"] = int(exact_match) 64 | row["correct"] = int(correct) 65 | row["error_msg"] = "" 66 | except Exception as e: 67 | row["error_db_exec"] = 1 68 | row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" 69 | 70 | return row 71 | 72 | 73 | def run_together_eval(args): 74 | # get params from args 75 | questions_file_list = args.questions_file 76 | prompt_file_list = args.prompt_file 77 | num_questions = args.num_questions 78 | public_data = not args.use_private_data 79 | output_file_list = args.output_file 80 | k_shot = args.k_shot 81 | max_workers = args.parallel_threads 82 | db_type = args.db_type 83 | decimal_points = args.decimal_points 84 | model = args.model 85 | cot_table_alias = args.cot_table_alias 86 | 87 | for questions_file, prompt_file, output_file in zip( 88 | questions_file_list, prompt_file_list, output_file_list 89 | ): 90 | if not prompt_file.endswith(".json"): 91 | raise ValueError(f"Prompt file must be a JSON file. 
Got {prompt_file}") 92 | print(f"Using prompt file {prompt_file}") 93 | # get questions 94 | print("Preparing questions...") 95 | print( 96 | f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" 97 | ) 98 | df = prepare_questions_df( 99 | questions_file, db_type, num_questions, k_shot, cot_table_alias 100 | ) 101 | # create a prompt for each question 102 | # note that the prompt for together ai uses the openai chat API 103 | df["prompt"] = df.apply( 104 | lambda row: generate_prompt( 105 | prompt_file, 106 | row["question"], 107 | row["db_name"], 108 | row["db_type"], 109 | row["instructions"], 110 | row["k_shot_prompt"], 111 | row["glossary"], 112 | row["table_metadata_string"], 113 | row["prev_invalid_sql"], 114 | row["prev_error_msg"], 115 | row["question_0"], 116 | row["query_0"], 117 | row["question_1"], 118 | row["query_1"], 119 | row["cot_instructions"], 120 | row["cot_pregen"], 121 | public_data, 122 | args.num_columns, 123 | args.shuffle_metadata, 124 | row["table_aliases"], 125 | ), 126 | axis=1, 127 | ) 128 | 129 | total_tried = 0 130 | total_correct = 0 131 | output_rows = [] 132 | 133 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 134 | futures = [] 135 | for row in df.to_dict("records"): 136 | futures.append(executor.submit(process_row, row, model)) 137 | 138 | with tqdm(as_completed(futures), total=len(futures)) as pbar: 139 | for f in pbar: 140 | row = f.result() 141 | output_rows.append(row) 142 | if row["correct"]: 143 | total_correct += 1 144 | total_tried += 1 145 | pbar.update(1) 146 | pbar.set_description( 147 | f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" 148 | ) 149 | 150 | output_df = pd.DataFrame(output_rows) 151 | del output_df["prompt"] 152 | print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) 153 | output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) 154 | # get directory of output_file and create if not exist 155 | output_dir = os.path.dirname(output_file) 156 | if not os.path.exists(output_dir): 157 | os.makedirs(output_dir) 158 | try: 159 | output_df.to_csv(output_file, index=False, float_format="%.2f") 160 | except: 161 | output_df.to_pickle(output_file) 162 | 163 | results = output_df.to_dict("records") 164 | # upload results 165 | with open(prompt_file, "r") as f: 166 | prompt = f.read() 167 | if args.upload_url is not None: 168 | upload_results( 169 | results=results, 170 | url=args.upload_url, 171 | runner_type="api_runner", 172 | prompt=prompt, 173 | args=args, 174 | ) 175 | -------------------------------------------------------------------------------- /runners/vllm_runner.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import List 4 | import sqlparse 5 | from vllm import LLM, SamplingParams 6 | from vllm.lora.request import LoRARequest 7 | from eval.eval import compare_query_results 8 | import pandas as pd 9 | from utils.gen_prompt import generate_prompt 10 | from utils.questions import prepare_questions_df 11 | from utils.creds import db_creds_all 12 | import time 13 | import torch 14 | from transformers import AutoTokenizer 15 | from tqdm import tqdm 16 | from utils.reporting import upload_results 17 | 18 | 19 | def run_vllm_eval(args): 20 | # get params from args 21 | questions_file_list = args.questions_file 22 | prompt_file_list = args.prompt_file 23 | num_questions = args.num_questions 24 | public_data = not 
args.use_private_data 25 | model_name = args.model 26 | output_file_list = args.output_file 27 | num_beams = args.num_beams 28 | k_shot = args.k_shot 29 | db_type = args.db_type 30 | cot_table_alias = args.cot_table_alias 31 | enable_lora = True if args.adapter else False 32 | lora_request = LoRARequest("sql_adapter", 1, args.adapter) if args.adapter else None 33 | 34 | # initialize model only once as it takes a while 35 | print(f"Preparing {model_name}") 36 | tokenizer = AutoTokenizer.from_pretrained(model_name) 37 | tokenizer.pad_token_id = tokenizer.eos_token_id 38 | if not args.quantized: 39 | llm = LLM( 40 | model=model_name, 41 | tensor_parallel_size=1, 42 | enable_lora=enable_lora, 43 | max_model_len=4096, 44 | max_lora_rank=64, 45 | ) 46 | else: 47 | llm = LLM( 48 | model=model_name, 49 | tensor_parallel_size=1, 50 | quantization="AWQ", 51 | enable_lora=enable_lora, 52 | max_model_len=4096, 53 | max_lora_rank=64, 54 | ) 55 | 56 | sampling_params = SamplingParams( 57 | n=1, 58 | best_of=num_beams, 59 | use_beam_search=num_beams != 1, 60 | stop_token_ids=[tokenizer.eos_token_id], 61 | max_tokens=1000, 62 | temperature=0, 63 | ) 64 | 65 | for questions_file, prompt_file, output_file in zip( 66 | questions_file_list, prompt_file_list, output_file_list 67 | ): 68 | print(f"Using prompt file {prompt_file}") 69 | # get questions 70 | print("Preparing questions...") 71 | print( 72 | f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" 73 | ) 74 | df = prepare_questions_df( 75 | questions_file, db_type, num_questions, k_shot, cot_table_alias 76 | ) 77 | # create a prompt for each question 78 | df["prompt"] = df.apply( 79 | lambda row: generate_prompt( 80 | prompt_file, 81 | row["question"], 82 | row["db_name"], 83 | row["db_type"], 84 | row["instructions"], 85 | row["k_shot_prompt"], 86 | row["glossary"], 87 | row["table_metadata_string"], 88 | row["prev_invalid_sql"], 89 | row["prev_error_msg"], 90 | row["question_0"], 91 | row["query_0"], 92 | row["question_1"], 93 | row["query_1"], 94 | row["cot_instructions"], 95 | row["cot_pregen"], 96 | public_data, 97 | args.num_columns, 98 | args.shuffle_metadata, 99 | ), 100 | axis=1, 101 | ) 102 | print(f"Prepared {len(df)} question(s) from {questions_file}") 103 | 104 | def chunk_dataframe(df, chunk_size): 105 | """Returns successive chunk_size chunks from df as a list of dfs""" 106 | df_chunks = [] 107 | for i in range(0, len(df), chunk_size): 108 | df_i = df.iloc[i : min(i + chunk_size, len(df))] 109 | print( 110 | f"Chunk {i//chunk_size+1}/{len(df)//chunk_size+1} with {len(df_i)} questions" 111 | ) 112 | df_chunks.append(df_i) 113 | return df_chunks 114 | 115 | df_chunks = chunk_dataframe(df, args.batch_size) 116 | 117 | total_tried = 0 118 | total_correct = 0 119 | output_rows = [] 120 | 121 | print(f"Generating completions") 122 | 123 | for batch in (pbar := tqdm(df_chunks, total=len(df))): 124 | prompts = batch["prompt"].tolist() 125 | print(f"Generating completions for {len(prompts)} prompts") 126 | prompt_tokens = [] 127 | prompt_token_sizes = [] 128 | for prompt in prompts: 129 | token_ids = tokenizer.encode(prompt, add_special_tokens=False) 130 | # add bos token if not already present in prompt 131 | if token_ids[0] != tokenizer.bos_token_id: 132 | token_ids = [tokenizer.bos_token_id] + token_ids 133 | prompt_tokens.append(token_ids) 134 | prompt_token_sizes.append(len(token_ids)) 135 | print( 136 | f"Average prompt size: {sum(prompt_token_sizes)/len(prompt_token_sizes):.0f}" 137 | ) 138 | 
start_time = time.time() 139 | # outputs = llm.generate(prompts, sampling_params) # if you prefer to use prompts instead of token_ids 140 | outputs = llm.generate( 141 | sampling_params=sampling_params, 142 | prompt_token_ids=prompt_tokens, 143 | use_tqdm=False, 144 | lora_request=lora_request, 145 | ) 146 | print( 147 | f"Generated {len(outputs)} completions in {time.time() - start_time:.2f} seconds" 148 | ) 149 | time_taken = time.time() - start_time 150 | for row, output in zip(batch.to_dict("records"), outputs): 151 | generated_query = ( 152 | output.outputs[0].text.split(";")[0].split("```")[0].strip() + ";" 153 | ) 154 | normalized_query = sqlparse.format( 155 | generated_query, keyword_case="upper", strip_whitespace=True 156 | ) 157 | row["generated_query"] = normalized_query 158 | row["tokens_used"] = len(output.outputs[0].token_ids) 159 | row["latency_seconds"] = time_taken / len(batch) 160 | 161 | golden_query = row["query"] 162 | db_name = row["db_name"] 163 | db_type = row["db_type"] 164 | question = row["question"] 165 | query_category = row["query_category"] 166 | table_metadata_string = row["table_metadata_string"] 167 | exact_match = correct = 0 168 | db_creds = db_creds_all[db_type] 169 | try: 170 | exact_match, correct = compare_query_results( 171 | query_gold=golden_query, 172 | query_gen=generated_query, 173 | db_name=db_name, 174 | db_type=db_type, 175 | db_creds=db_creds, 176 | question=question, 177 | query_category=query_category, 178 | table_metadata_string=table_metadata_string, 179 | decimal_points=args.decimal_points, 180 | ) 181 | row["exact_match"] = int(exact_match) 182 | row["correct"] = int(correct) 183 | row["error_msg"] = "" 184 | if correct: 185 | total_correct += 1 186 | except Exception as e: 187 | row["error_db_exec"] = 1 188 | row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" 189 | 190 | total_tried += 1 191 | output_rows.append(row) 192 | pbar.update(len(batch)) 193 | pbar.set_description( 194 | f"Correct so far: {total_correct}/{(total_tried)} ({100*total_correct/(total_tried):.2f}%)" 195 | ) 196 | df = pd.DataFrame(output_rows) 197 | del df["prompt"] 198 | print(df.groupby("query_category")[["exact_match", "correct"]].mean()) 199 | df = df.sort_values(by=["db_name", "query_category", "question"]) 200 | print(f"Average tokens generated: {df['tokens_used'].mean():.1f}") 201 | # get directory of output_file and create if not exist 202 | output_dir = os.path.dirname(output_file) 203 | if not os.path.exists(output_dir): 204 | os.makedirs(output_dir) 205 | df.to_csv(output_file, index=False, float_format="%.2f") 206 | print(f"Saved results to {output_file}") 207 | 208 | results = df.to_dict("records") 209 | # upload results 210 | with open(prompt_file, "r") as f: 211 | prompt = f.read() 212 | if args.upload_url is not None: 213 | upload_results( 214 | results=results, 215 | url=args.upload_url, 216 | runner_type="vllm_runner", 217 | prompt=prompt, 218 | args=args, 219 | ) 220 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/defog-ai/sql-eval/c7986a7e0922e66b9dafe9096f454e41fa89589b/tests/__init__.py -------------------------------------------------------------------------------- /tests/local_db_tests.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from unittest import mock 3 | from eval.eval import get_all_minimal_queries, 
query_postgres_db, query_postgres_temp_db 4 | from pandas.testing import assert_frame_equal 5 | 6 | 7 | def test_questions_non_null(): 8 | # read in questions_gen_postgres.csv 9 | df = pd.read_csv("data/questions_gen_postgres.csv") 10 | # for each row, run the query with eval.query_postgres_db 11 | # check that the result is not null 12 | for i, row in df.iterrows(): 13 | queries_gold = get_all_minimal_queries(row["query"]) 14 | for query_gold in queries_gold: 15 | df_result = query_postgres_db(query_gold, row["db_name"]) 16 | if len(df_result) == 0: 17 | print(i, query_gold) 18 | print(df_result) 19 | 20 | 21 | @mock.patch("pandas.read_sql_query") 22 | def test_query_postgres_temp_db(mock_pd_read_sql_query): 23 | # note that we need to mock create_engine 24 | db_name = "db_temp" 25 | db_creds = { 26 | "host": "localhost", 27 | "port": 5432, 28 | "user": "postgres", 29 | "password": "postgres", 30 | } 31 | 32 | table_metadata_string = "CREATE TABLE table_name (A INT, B INT);" 33 | timeout = 10 34 | query = "SELECT * FROM table_name;" 35 | df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) 36 | mock_pd_read_sql_query.return_value = df 37 | 38 | results_df = query_postgres_temp_db( 39 | query, db_name, db_creds, table_metadata_string, timeout 40 | ) 41 | assert mock_pd_read_sql_query.call_count == 1 42 | assert_frame_equal(results_df, df) 43 | -------------------------------------------------------------------------------- /tests/test_utils_pruning.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/defog-ai/sql-eval/c7986a7e0922e66b9dafe9096f454e41fa89589b/tests/test_utils_pruning.py -------------------------------------------------------------------------------- /upload_wandb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import wandb 4 | 5 | # Step 1. Specify the folder where the result csv files are stored 6 | result_folder = ( 7 | "results/sqlcoder_8b_fullft_ds_012_llama3_join_after_mgn1_b1_0900_b2_0990_steps_600" 8 | ) 9 | # Step 2. 
Specify the wandb run id 10 | wandb_run_id = "fqianfsq" 11 | 12 | # rest of the script logic 13 | csv_files = [] 14 | for f in os.listdir(result_folder): 15 | if f.endswith(".csv"): 16 | csv_files.append(f) 17 | print(f"Found {len(csv_files)} csv files in {result_folder}") 18 | 19 | # Load results from csv file into dataframe 20 | results_dfs = [] 21 | for csv_file_name in csv_files: 22 | file_path = os.path.join(result_folder, csv_file_name) 23 | df_i = pd.read_csv(file_path, comment="#") 24 | df_i["model"] = csv_file_name.rsplit(".csv", 1)[0] 25 | results_dfs.append(df_i) 26 | results_df = pd.concat(results_dfs, ignore_index=True) 27 | print(f"Loaded {results_df.shape[0]} results from {len(csv_files)} csv files") 28 | 29 | s = results_df.groupby("model")["correct"].mean() 30 | s = pd.DataFrame(s) 31 | s["file_name"] = s.index 32 | s["benchmark"] = s["file_name"].str.extract(r"_(advanced|basic|v1|idk)") 33 | s["checkpoint"] = s["file_name"].str.extract(r"c(\d+)_").astype(int) 34 | s["cot"] = s["file_name"].str.extract(r"_(cot)").fillna("no_cot") 35 | s = s.reset_index(drop=True) 36 | 37 | # Get unique checkpoints 38 | checkpoints = s["checkpoint"].unique() 39 | checkpoints.sort() 40 | print(f"Found {len(checkpoints)} checkpoints: {checkpoints}") 41 | 42 | # Continue existing run, specifying the project and the run ID 43 | run = wandb.init(project="huggingface", id=wandb_run_id, resume="must") 44 | 45 | # get current step, so that we can log incrementally after it 46 | # this is because wandb doesn't allow logging back to previous steps 47 | current_step = run.step 48 | print(f"Current step: {current_step}") 49 | 50 | for checkpoint in checkpoints: 51 | checkpoint_metrics = {} 52 | for benchmark in ["advanced", "basic", "v1", "idk"]: 53 | for cot in ["cot", "no_cot"]: 54 | mask = ( 55 | (s["checkpoint"] == checkpoint) 56 | & (s["benchmark"] == benchmark) 57 | & (s["cot"] == cot) 58 | ) 59 | if mask.sum() == 1: 60 | row = s[mask] 61 | metric_name = f"vllm/{benchmark}" 62 | if cot == "cot": 63 | metric_name += "_cot" 64 | metric_value = row["correct"].values[0] 65 | checkpoint_metrics[metric_name] = metric_value 66 | print(f"Logging checkpoint {checkpoint} metrics:") 67 | for k, v in checkpoint_metrics.items(): 68 | print(f"\t{k}: {v}") 69 | # we log the metrics at the current step + checkpoint 70 | wandb.log(checkpoint_metrics, step=current_step + checkpoint) 71 | 72 | # Finish the run 73 | run.finish() 74 | -------------------------------------------------------------------------------- /utils/aliases.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Dict, List 3 | 4 | reserved_keywords = [ 5 | "abs", 6 | "all", 7 | "and", 8 | "any", 9 | "avg", 10 | "as", 11 | "at", 12 | "asc", 13 | "bit", 14 | "by", 15 | "day", 16 | "dec", 17 | "do", 18 | "div", 19 | "end", 20 | "for", 21 | "go", 22 | "in", 23 | "is", 24 | "not", 25 | "or", 26 | "to", 27 | ] 28 | 29 | 30 | def get_table_names(md: str) -> List[str]: 31 | """ 32 | Given a string of metadata formatted as a series of 33 | CREATE TABLE statements, return a list of table names in the same order as 34 | they appear in the metadata. 
35 | """ 36 | table_names = [] 37 | if "CREATE TABLE" not in md: 38 | return table_names 39 | for table_md_str in md.split(");"): 40 | if "CREATE TABLE " not in table_md_str: 41 | continue 42 | header = table_md_str.split("(", 1)[0] 43 | table_name = header.split("CREATE TABLE ", 1)[1].strip() 44 | table_names.append(table_name) 45 | return table_names 46 | 47 | 48 | def generate_aliases_dict( 49 | table_names: List, reserved_keywords: List[str] = reserved_keywords 50 | ) -> Dict[str, str]: 51 | """ 52 | Generate aliases for table names as a dictionary mapping of table names to aliases 53 | Aliases should always be in lower case 54 | """ 55 | aliases = {} 56 | for original_table_name in table_names: 57 | if "." in original_table_name: 58 | table_name = original_table_name.rsplit(".", 1)[-1] 59 | else: 60 | table_name = original_table_name 61 | if "_" in table_name: 62 | # get the first letter of each subword delimited by "_" 63 | alias = "".join([word[0] for word in table_name.split("_")]).lower() 64 | else: 65 | # if camelCase, get the first letter of each subword 66 | # otherwise defaults to just getting the 1st letter of the table_name 67 | temp_table_name = table_name[0].upper() + table_name[1:] 68 | alias = "".join( 69 | [char for char in temp_table_name if char.isupper()] 70 | ).lower() 71 | # append ending numbers to alias if table_name ends with digits 72 | m = re.match(r".*(\d+)$", table_name) 73 | if m: 74 | alias += m.group(1) 75 | if alias in aliases.values() or alias in reserved_keywords: 76 | alias = table_name[:2].lower() 77 | if alias in aliases.values() or alias in reserved_keywords: 78 | alias = table_name[:3].lower() 79 | num = 2 80 | while alias in aliases.values() or alias in reserved_keywords: 81 | alias = table_name[0].lower() + str(num) 82 | num += 1 83 | 84 | aliases[original_table_name] = alias 85 | return aliases 86 | 87 | 88 | def mk_alias_str(table_aliases: Dict[str, str]) -> str: 89 | """ 90 | Given a dictionary of table names to aliases, return a string of aliases in the form: 91 | -- table1 AS t1 92 | -- table2 AS t2 93 | """ 94 | aliases_str = "" 95 | for table_name, alias in table_aliases.items(): 96 | aliases_str += f"-- {table_name} AS {alias}\n" 97 | return aliases_str 98 | 99 | 100 | def generate_aliases( 101 | table_names: List, reserved_keywords: List[str] = reserved_keywords 102 | ) -> str: 103 | """ 104 | Generate aliases for table names in a comment str form, eg 105 | -- table1 AS t1 106 | -- table2 AS t2 107 | """ 108 | aliases = generate_aliases_dict(table_names, reserved_keywords) 109 | return mk_alias_str(aliases) 110 | -------------------------------------------------------------------------------- /utils/api_server.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from typing import AsyncGenerator 4 | 5 | from fastapi import FastAPI, Request 6 | from fastapi.responses import JSONResponse, Response, StreamingResponse 7 | import uvicorn 8 | 9 | from vllm.engine.arg_utils import AsyncEngineArgs 10 | from vllm.engine.async_llm_engine import AsyncLLMEngine 11 | from vllm.sampling_params import SamplingParams 12 | from vllm.utils import random_uuid 13 | from vllm.lora.request import LoRARequest 14 | from vllm import __version__ as vllm_version 15 | 16 | TIMEOUT_KEEP_ALIVE = 5 # seconds. 
17 | app = FastAPI() 18 | engine = None 19 | 20 | # This is a fork of https://github.com/vllm-project/vllm/blob/0650e5935b0f6af35fb2acf71769982c47b804d7/vllm/entrypoints/api_server.py 21 | # with the following changes: 22 | # - remove the prompt from response. we only return the generated output to avoid parsing errors when including the prompt. 23 | # - don't add special_tokens (bos/eos) and only add it if it's missing from the prompt 24 | # You can start it similar to how you would with the usual vllm api server: 25 | # ``` 26 | # python3 utils/api_server.py \ 27 | # --model "${model_path}" \ 28 | # --tensor-parallel-size 4 \ 29 | # --dtype float16 \ 30 | # --max-model-len 4096 \ 31 | # --port 5000 \ 32 | # --gpu-memory-utilization 0.90 \ 33 | # --enable-lora \ 34 | # --max-lora-rank 64 \ 35 | 36 | 37 | @app.get("/health") 38 | async def health() -> Response: 39 | """Health check.""" 40 | return Response(status_code=200) 41 | 42 | 43 | @app.post("/generate") 44 | async def generate(request: Request) -> Response: 45 | """Generate completion for the request. 46 | 47 | The request should be a JSON object with the following fields: 48 | - prompt: the prompt to use for the generation. 49 | - stream: whether to stream the results or not. 50 | - other fields: the sampling parameters (See `SamplingParams` for details). 51 | """ 52 | request_dict = await request.json() 53 | prompt = request_dict.pop("prompt") 54 | stream = request_dict.pop("stream", False) 55 | sql_lora_path = request_dict.pop("sql_lora_path", None) 56 | request_dict.pop("sql_lora_name", None) 57 | lora_request = ( 58 | LoRARequest(lora_name="sql_adapter", lora_int_id=1, lora_path=sql_lora_path) 59 | if sql_lora_path 60 | else None 61 | ) 62 | if vllm_version >= "0.6.2": 63 | # remove use_beam_search if present as it's no longer supported 64 | # see https://github.com/vllm-project/vllm/releases/tag/v0.6.2 65 | if "use_beam_search" in request_dict: 66 | request_dict.pop("use_beam_search") 67 | sampling_params = SamplingParams(**request_dict) 68 | request_id = random_uuid() 69 | tokenizer = await engine.get_tokenizer() 70 | prompt_token_ids = tokenizer.encode(text=prompt, add_special_tokens=False) 71 | if prompt_token_ids[0] != tokenizer.bos_token_id: 72 | prompt_token_ids = [tokenizer.bos_token_id] + prompt_token_ids 73 | 74 | if vllm_version >= "0.6.3": 75 | from vllm import TokensPrompt 76 | 77 | results_generator = engine.generate( 78 | prompt=TokensPrompt(prompt_token_ids=prompt_token_ids), 79 | sampling_params=sampling_params, 80 | request_id=request_id, 81 | lora_request=lora_request, 82 | ) 83 | elif vllm_version >= "0.4.2": 84 | results_generator = engine.generate( 85 | inputs={"prompt_token_ids": prompt_token_ids}, 86 | sampling_params=sampling_params, 87 | request_id=request_id, 88 | lora_request=lora_request, 89 | ) 90 | else: 91 | results_generator = engine.generate( 92 | prompt=None, 93 | sampling_params=sampling_params, 94 | request_id=request_id, 95 | prompt_token_ids=prompt_token_ids, 96 | lora_request=LoRARequest("sql_adapter", 1, sql_lora_path), 97 | ) 98 | 99 | # Streaming case 100 | async def stream_results() -> AsyncGenerator[bytes, None]: 101 | async for request_output in results_generator: 102 | prompt = request_output.prompt 103 | text_outputs = [prompt + output.text for output in request_output.outputs] 104 | ret = {"text": text_outputs} 105 | yield (json.dumps(ret) + "\0").encode("utf-8") 106 | 107 | if stream: 108 | return StreamingResponse(stream_results()) 109 | 110 | # Non-streaming case 111 | 
final_output = None 112 | async for request_output in results_generator: 113 | if await request.is_disconnected(): 114 | # Abort the request if the client disconnects. 115 | await engine.abort(request_id) 116 | return Response(status_code=499) 117 | final_output = request_output 118 | 119 | assert final_output is not None 120 | prompt = final_output.prompt 121 | text_outputs = [] 122 | for output in final_output.outputs: 123 | text_outputs.append(output.text) 124 | 125 | try: 126 | logprobs = [output.logprobs for output in final_output.outputs] 127 | except Exception as e: 128 | logprobs = [] 129 | print(e) 130 | print("Could not extract logprobs") 131 | 132 | logprobs = logprobs[0] 133 | logprobs_json = [] 134 | if logprobs: 135 | for item in logprobs: 136 | # do this to make our response JSON serializable 137 | item = {key: value.__dict__ for key, value in item.items()} 138 | logprobs_json.append(item) 139 | 140 | ret = {"text": text_outputs, "logprobs": logprobs_json} 141 | return JSONResponse(ret) 142 | 143 | 144 | if __name__ == "__main__": 145 | parser = argparse.ArgumentParser() 146 | parser.add_argument("--host", type=str, default=None) 147 | parser.add_argument("--port", type=int, default=8000) 148 | parser.add_argument("--ssl-keyfile", type=str, default=None) 149 | parser.add_argument("--ssl-certfile", type=str, default=None) 150 | parser.add_argument( 151 | "--root-path", 152 | type=str, 153 | default=None, 154 | help="FastAPI root_path when app is behind a path based routing proxy", 155 | ) 156 | parser = AsyncEngineArgs.add_cli_args(parser) 157 | args = parser.parse_args() 158 | 159 | engine_args = AsyncEngineArgs.from_cli_args(args) 160 | engine = AsyncLLMEngine.from_engine_args(engine_args) 161 | 162 | app.root_path = args.root_path 163 | uvicorn.run( 164 | app, 165 | host=args.host, 166 | port=args.port, 167 | log_level="debug", 168 | timeout_keep_alive=TIMEOUT_KEEP_ALIVE, 169 | ssl_keyfile=args.ssl_keyfile, 170 | ssl_certfile=args.ssl_certfile, 171 | ) 172 | -------------------------------------------------------------------------------- /utils/creds.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | db_creds_all = { 4 | "postgres": { 5 | "host": os.environ.get("DBHOST", "localhost"), 6 | "port": os.environ.get("DBPORT", 5432), 7 | "user": os.environ.get("DBUSER", "postgres"), 8 | "password": os.environ.get("DBPASSWORD", "postgres"), 9 | }, 10 | "snowflake": { 11 | "user": os.environ.get("SFDBUSER"), 12 | "password": os.environ.get("SFDBPASSWORD"), 13 | "account": os.environ.get("SFDBACCOUNT"), 14 | "warehouse": os.environ.get("SFDBWAREHOUSE"), 15 | }, 16 | "mysql": { 17 | "user": "root", 18 | "password": "password", 19 | "host": "localhost", 20 | }, 21 | "bigquery": { 22 | "project": os.environ.get("BIGQUERY_PROJ"), 23 | "creds": os.environ.get("GOOGLE_APPLICATION_CREDENTIALS"), 24 | }, 25 | "sqlite": { 26 | "path_to_folder": os.environ.get("HOME") 27 | + f"/defog-data/sqlite_dbs/", # Path to folder containing sqlite dbs 28 | }, 29 | "tsql": { 30 | "server": os.getenv("TSQL_SERVER"), 31 | "user": "test_user", 32 | "password": "password", 33 | "driver": "{ODBC Driver 17 for SQL Server}", 34 | }, 35 | } 36 | 37 | bq_project = os.environ.get("BQ_PROJECT") 38 | -------------------------------------------------------------------------------- /utils/gen_prompt.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import json 3 | from typing import Dict, List, 
Optional, Union 4 | import numpy as np 5 | from utils.dialects import ( 6 | ddl_to_bigquery, 7 | ddl_to_mysql, 8 | ddl_to_sqlite, 9 | ddl_to_tsql, 10 | get_schema_names, 11 | ) 12 | 13 | 14 | def to_prompt_schema( 15 | md: Dict[str, List[Dict[str, str]]], seed: Optional[int] = None 16 | ) -> str: 17 | """ 18 | Return a DDL statement for creating tables from a metadata dictionary 19 | `md` has the following structure: 20 | {'table1': [ 21 | {'column_name': 'col1', 'data_type': 'int', 'column_description': 'primary key'}, 22 | {'column_name': 'col2', 'data_type': 'text', 'column_description': 'not null'}, 23 | {'column_name': 'col3', 'data_type': 'text', 'column_description': ''}, 24 | ], 25 | 'table2': [ 26 | ... 27 | ]}, 28 | This is just for converting the dictionary structure of one's metadata into a string 29 | for pasting into prompts, and not meant to be used to initialize a database. 30 | seed is used to shuffle the order of the tables when not None 31 | """ 32 | md_create = "" 33 | table_names = list(md.keys()) 34 | if seed: 35 | np.random.seed(seed) 36 | np.random.shuffle(table_names) 37 | for table in table_names: 38 | md_create += f"CREATE TABLE {table} (\n" 39 | columns = md[table] 40 | if seed: 41 | np.random.seed(seed) 42 | np.random.shuffle(columns) 43 | for i, column in enumerate(columns): 44 | col_name = column["column_name"] 45 | # if column name has spaces, wrap it in double quotes 46 | if " " in col_name: 47 | col_name = f'"{col_name}"' 48 | dtype = column["data_type"] 49 | col_desc = column.get("column_description", "").replace("\n", " ") 50 | if col_desc: 51 | col_desc = f" --{col_desc}" 52 | if i < len(columns) - 1: 53 | md_create += f" {col_name} {dtype},{col_desc}\n" 54 | else: 55 | # avoid the trailing comma for the last line 56 | md_create += f" {col_name} {dtype}{col_desc}\n" 57 | md_create += ");\n" 58 | return md_create 59 | 60 | 61 | def generate_aliases(table_names: list) -> str: 62 | """ 63 | Generate aliases for table names 64 | """ 65 | aliases = {} 66 | reserved_keywords = [ 67 | "all", 68 | "and", 69 | "any", 70 | "as", 71 | "asc", 72 | "do", 73 | "end", 74 | "for", 75 | "in", 76 | "is", 77 | "not", 78 | "to", 79 | ] 80 | for table_name in table_names: 81 | if "." 
in table_name: 82 | table_name = table_name.split(".", 1)[1] 83 | alias = table_name[0] 84 | if ( 85 | alias in aliases.values() and "_" in table_name 86 | ) or alias.lower() in reserved_keywords: 87 | alias = table_name.split("_")[0][0] + table_name.split("_")[1][0] 88 | if alias in aliases.values() or alias.lower() in reserved_keywords: 89 | alias = table_name[:2] 90 | if alias in aliases.values() or alias.lower() in reserved_keywords: 91 | alias = table_name[:3] 92 | num = 2 93 | while alias in aliases.values() or alias.lower() in reserved_keywords: 94 | alias = table_name[0] + str(num) 95 | num += 1 96 | 97 | aliases[table_name] = alias 98 | 99 | aliases_str = "" 100 | for table_name, alias in aliases.items(): 101 | aliases_str += f"-- {table_name} AS {alias}\n" 102 | return aliases_str 103 | 104 | 105 | def generate_prompt( 106 | prompt_file, 107 | question, 108 | db_name, 109 | db_type="postgres", 110 | instructions="", 111 | k_shot_prompt="", 112 | glossary="", 113 | table_metadata_string="", 114 | prev_invalid_sql="", 115 | prev_error_msg="", 116 | question_0="", 117 | query_0="", 118 | question_1="", 119 | query_1="", 120 | cot_instructions="", 121 | cot_pregen=False, 122 | public_data=True, 123 | columns_to_keep=40, 124 | shuffle_metadata=False, 125 | table_aliases="", 126 | ) -> Union[List[Dict[str, str]], str]: 127 | """ 128 | Generates the prompt for the given question. 129 | If a json file is passed in as the prompt_file, please ensure that it is a list 130 | of dictionaries, which should have the `content` key minimally. 131 | Else, we will treat the file as a string template. 132 | """ 133 | from defog_data.metadata import dbs # to avoid CI error 134 | 135 | is_json = prompt_file.endswith(".json") 136 | if is_json: 137 | with open(prompt_file, "r") as f: 138 | messages_template = json.load(f) 139 | else: 140 | with open(prompt_file, "r") as f: 141 | prompt = f.read() 142 | question_instructions = question + " " + instructions 143 | table_names = [] 144 | 145 | join_str = "" 146 | # retrieve metadata 147 | if table_metadata_string == "": 148 | if public_data: 149 | import defog_data.supplementary as sup 150 | 151 | column_join = sup.columns_join.get(db_name, {}) 152 | else: 153 | import defog_data_private.supplementary as sup 154 | 155 | column_join = sup.columns_join.get(db_name, {}) 156 | 157 | md = dbs[db_name]["table_metadata"] 158 | table_names = list(md.keys()) 159 | table_metadata_ddl = to_prompt_schema(md, shuffle_metadata) 160 | 161 | # get join_str from column_join 162 | join_list = [] 163 | pruned_join_list = [] 164 | for values in column_join.values(): 165 | for col_pair in values: 166 | # add to join_list 167 | col_1, col_2 = col_pair 168 | join_str = f"{col_1} can be joined with {col_2}" 169 | if join_str not in join_list: 170 | join_list.append(join_str) 171 | # add to pruned_join_list if column names are not equal 172 | colname_1 = col_1.rsplit(".", 1)[1] 173 | colname_2 = col_2.rsplit(".", 1)[1] 174 | if colname_1 != colname_2 and join_str not in pruned_join_list: 175 | pruned_join_list.append(join_str) 176 | if len(join_list) > 0: 177 | join_str = "\nHere is a list of joinable columns:\n" + "\n".join(join_list) 178 | else: 179 | join_str = "" 180 | if len(pruned_join_list) > 0: 181 | pruned_join_str = ( 182 | "\nHere is a list of joinable columns with different names:\n" 183 | + "\n".join(pruned_join_list) 184 | ) 185 | else: 186 | pruned_join_str = "" 187 | 188 | # add schema creation statements if relevant 189 | schema_names = 
get_schema_names(table_metadata_ddl) 190 | if schema_names: 191 | for schema_name in schema_names: 192 | table_metadata_ddl = ( 193 | f"CREATE SCHEMA IF NOT EXISTS {schema_name};\n" + table_metadata_ddl 194 | ) 195 | # remove schema names from cot_instructions if db_type is in ["sqlite", "mysql", "bigquery"] 196 | if db_type in ["sqlite", "mysql", "bigquery"]: 197 | for schema_name in schema_names: 198 | cot_instructions = cot_instructions.replace(f"{schema_name}.", "") 199 | 200 | # transform metadata string to target dialect if necessary 201 | if db_type in ["postgres", "snowflake"]: 202 | table_metadata_string = table_metadata_ddl 203 | elif db_type == "bigquery": 204 | table_metadata_string = ddl_to_bigquery( 205 | table_metadata_ddl, "postgres", db_name, "" 206 | )[0] 207 | elif db_type == "mysql": 208 | table_metadata_string = ddl_to_mysql( 209 | table_metadata_ddl, "postgres", db_name, "" 210 | )[0] 211 | elif db_type == "sqlite": 212 | table_metadata_string = ddl_to_sqlite( 213 | table_metadata_ddl, "postgres", db_name, "" 214 | )[0] 215 | elif db_type == "tsql": 216 | table_metadata_string = ddl_to_tsql( 217 | table_metadata_ddl, "postgres", db_name, "" 218 | )[0] 219 | else: 220 | raise ValueError( 221 | "db_type must be one of postgres, snowflake, bigquery, mysql, sqlite, or tsql" 222 | ) 223 | if glossary == "": 224 | glossary = dbs[db_name]["glossary"] 225 | 226 | instruction_reflections = instructions.replace( 227 | "\nFollow the instructions below to generate the query:", 228 | "\nAdditionally, I was asked to follow the instructions below to generate the query:", 229 | ) 230 | instruction_reflections = instruction_reflections + "\n" 231 | 232 | if is_json: 233 | messages = [] 234 | for msg_template in messages_template: 235 | msg = deepcopy(msg_template) 236 | msg["content"] = msg_template["content"].format( 237 | user_question=question, 238 | db_type=db_type, 239 | instructions=instructions, 240 | table_metadata_string=table_metadata_string, 241 | k_shot_prompt=k_shot_prompt, 242 | glossary=glossary, 243 | prev_invalid_sql=prev_invalid_sql, 244 | prev_error_msg=prev_error_msg, 245 | question_0=question_0, 246 | query_0=query_0, 247 | question_1=question_1, 248 | query_1=query_1, 249 | cot_instructions=cot_instructions, 250 | instruction_reflections=instruction_reflections, 251 | table_aliases=table_aliases, 252 | join_str=join_str, 253 | pruned_join_str=pruned_join_str, 254 | ) 255 | messages.append(msg) 256 | return messages 257 | else: 258 | prompt = prompt.format( 259 | user_question=question, 260 | db_type=db_type, 261 | instructions=instructions, 262 | table_metadata_string=table_metadata_string, 263 | k_shot_prompt=k_shot_prompt, 264 | glossary=glossary, 265 | prev_invalid_sql=prev_invalid_sql, 266 | prev_error_msg=prev_error_msg, 267 | question_0=question_0, 268 | query_0=query_0, 269 | question_1=question_1, 270 | query_1=query_1, 271 | cot_instructions=cot_instructions, 272 | instruction_reflections=instruction_reflections, 273 | table_aliases=table_aliases, 274 | join_str=join_str, 275 | pruned_join_hints=pruned_join_str, 276 | ) 277 | return prompt 278 | -------------------------------------------------------------------------------- /utils/llm.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | import os 3 | import time 4 | from dataclasses import dataclass 5 | from typing import Dict, List, Optional, Any 6 | 7 | LLM_COSTS_PER_TOKEN = { 8 | "gpt-4o": {"input_cost_per1k": 0.0025, 
"output_cost_per1k": 0.01}, 9 | "gpt-4o-mini": {"input_cost_per1k": 0.00015, "output_cost_per1k": 0.0006}, 10 | "o1": {"input_cost_per1k": 0.015, "output_cost_per1k": 0.06}, 11 | "o1-preview": {"input_cost_per1k": 0.015, "output_cost_per1k": 0.06}, 12 | "o1-mini": {"input_cost_per1k": 0.003, "output_cost_per1k": 0.012}, 13 | "o3-mini": {"input_cost_per1k": 0.0011, "output_cost_per1k": 0.0044}, 14 | "gpt-4-turbo": {"input_cost_per1k": 0.01, "output_cost_per1k": 0.03}, 15 | "gpt-3.5-turbo": {"input_cost_per1k": 0.0005, "output_cost_per1k": 0.0015}, 16 | "claude-3-5-sonnet": {"input_cost_per1k": 0.003, "output_cost_per1k": 0.015}, 17 | "claude-3-5-haiku": {"input_cost_per1k": 0.00025, "output_cost_per1k": 0.00125}, 18 | "claude-3-opus": {"input_cost_per1k": 0.015, "output_cost_per1k": 0.075}, 19 | "claude-3-sonnet": {"input_cost_per1k": 0.003, "output_cost_per1k": 0.015}, 20 | "claude-3-haiku": {"input_cost_per1k": 0.00025, "output_cost_per1k": 0.00125}, 21 | "gemini-1.5-pro": {"input_cost_per1k": 0.00125, "output_cost_per1k": 0.005}, 22 | "gemini-1.5-flash": {"input_cost_per1k": 0.000075, "output_cost_per1k": 0.0003}, 23 | "gemini-1.5-flash-8b": { 24 | "input_cost_per1k": 0.0000375, 25 | "output_cost_per1k": 0.00015, 26 | }, 27 | "gemini-2.0-flash": { 28 | "input_cost_per1k": 0.000075, 29 | "output_cost_per1k": 0.0003, 30 | }, 31 | } 32 | 33 | 34 | @dataclass 35 | class LLMResponse: 36 | content: Any 37 | model: str 38 | time: float 39 | input_tokens: int 40 | output_tokens: int 41 | output_tokens_details: Optional[Dict[str, int]] = None 42 | cost_in_cents: Optional[float] = None 43 | 44 | def __post_init__(self): 45 | if self.model in LLM_COSTS_PER_TOKEN: 46 | model_name = self.model 47 | else: 48 | model_name = None 49 | potential_model_names = [] 50 | 51 | for mname in LLM_COSTS_PER_TOKEN.keys(): 52 | if mname in self.model: 53 | potential_model_names.append(mname) 54 | 55 | if len(potential_model_names) > 0: 56 | model_name = max(potential_model_names, key=len) 57 | 58 | if model_name: 59 | self.cost_in_cents = ( 60 | self.input_tokens 61 | / 1000 62 | * LLM_COSTS_PER_TOKEN[model_name]["input_cost_per1k"] 63 | + self.output_tokens 64 | / 1000 65 | * LLM_COSTS_PER_TOKEN[model_name]["output_cost_per1k"] 66 | ) * 100 67 | 68 | 69 | def chat_anthropic( 70 | messages: List[Dict[str, str]], 71 | model: str = "claude-3-5-sonnet-20241022", 72 | max_completion_tokens: int = 8192, 73 | temperature: float = 0.0, 74 | stop: List[str] = [], 75 | thinking: Dict[str, Any] = None, 76 | json_mode: bool = False, 77 | response_format=None, 78 | seed: int = 0, 79 | store=True, 80 | metadata=None, 81 | timeout=100, 82 | ) -> LLMResponse: 83 | """ 84 | Returns the response from the Anthropic API, the time taken to generate the response, the number of input tokens used, and the number of output tokens used. 85 | Note that anthropic doesn't have explicit json mode api constraints, nor does it have a seed parameter. 
86 | """ 87 | from anthropic import Anthropic 88 | 89 | client_anthropic = Anthropic() 90 | t = time.time() 91 | if len(messages) >= 1 and messages[0].get("role") == "system": 92 | sys_msg = messages[0]["content"] 93 | messages = messages[1:] 94 | else: 95 | sys_msg = "" 96 | response = client_anthropic.messages.create( 97 | system=sys_msg, 98 | messages=messages, 99 | model=model, 100 | max_tokens=max_completion_tokens, 101 | temperature=temperature, 102 | stop_sequences=stop, 103 | timeout=timeout, 104 | thinking=thinking, 105 | ) 106 | if response.stop_reason == "max_tokens": 107 | raise Exception("Max tokens reached") 108 | if len(response.content) == 0: 109 | raise Exception("Max tokens reached") 110 | return LLMResponse( 111 | model=model, 112 | content=response.content[-1].text, 113 | time=round(time.time() - t, 3), 114 | input_tokens=response.usage.input_tokens, 115 | output_tokens=response.usage.output_tokens, 116 | ) 117 | 118 | 119 | def chat_openai( 120 | messages: List[Dict[str, str]], 121 | model: str = "gpt-4o", 122 | max_completion_tokens: int = 16384, 123 | temperature: float = 0.0, 124 | stop: List[str] = [], 125 | json_mode: bool = False, 126 | response_format=None, 127 | seed: int = 0, 128 | store=True, 129 | metadata=None, 130 | timeout=100, 131 | ) -> LLMResponse: 132 | """ 133 | Returns the response from the OpenAI API, the time taken to generate the response, the number of input tokens used, and the number of output tokens used. 134 | We use max_completion_tokens here, instead of using max_tokens. This is to support o1 models. 135 | """ 136 | from openai import OpenAI 137 | 138 | client_openai = OpenAI() 139 | t = time.time() 140 | if model.startswith("o"): 141 | if messages[0].get("role") == "system": 142 | sys_msg = messages[0]["content"] 143 | messages = messages[1:] 144 | messages[0]["content"] = sys_msg + messages[0]["content"] 145 | 146 | response = client_openai.chat.completions.create( 147 | messages=messages, 148 | model=model, 149 | max_completion_tokens=max_completion_tokens, 150 | store=store, 151 | metadata=metadata, 152 | timeout=timeout, 153 | ) 154 | else: 155 | if response_format or json_mode: 156 | response = client_openai.beta.chat.completions.parse( 157 | messages=messages, 158 | model=model, 159 | max_completion_tokens=max_completion_tokens, 160 | temperature=temperature, 161 | stop=stop, 162 | response_format=( 163 | {"type": "json_object"} if json_mode else response_format 164 | ), 165 | seed=seed, 166 | store=store, 167 | metadata=metadata, 168 | ) 169 | else: 170 | response = client_openai.chat.completions.create( 171 | messages=messages, 172 | model=model, 173 | max_completion_tokens=max_completion_tokens, 174 | temperature=temperature, 175 | stop=stop, 176 | seed=seed, 177 | store=store, 178 | metadata=metadata, 179 | ) 180 | 181 | if response_format and not model.startswith("o1"): 182 | content = response.choices[0].message.parsed 183 | else: 184 | content = response.choices[0].message.content 185 | 186 | if response.choices[0].finish_reason == "length": 187 | print("Max tokens reached") 188 | raise Exception("Max tokens reached") 189 | if len(response.choices) == 0: 190 | print("Empty response") 191 | raise Exception("No response") 192 | return LLMResponse( 193 | model=model, 194 | content=content, 195 | time=round(time.time() - t, 3), 196 | input_tokens=response.usage.prompt_tokens, 197 | output_tokens=response.usage.completion_tokens, 198 | output_tokens_details=response.usage.completion_tokens_details, 199 | ) 200 | 201 | 202 | def 
chat_gemini( 203 | messages: List[Dict[str, str]], 204 | model: str = "gemini-2.0-flash-exp", 205 | max_completion_tokens: int = 8192, 206 | temperature: float = 0.0, 207 | stop: List[str] = [], 208 | json_mode: bool = False, 209 | response_format=None, 210 | seed: int = 0, 211 | store=True, 212 | metadata=None, 213 | timeout=100, # does not have timeout method 214 | ) -> LLMResponse: 215 | from google import genai 216 | from google.genai import types 217 | 218 | client = genai.Client( 219 | api_key=os.getenv("GEMINI_API_KEY"), 220 | ) 221 | t = time.time() 222 | if messages[0]["role"] == "system": 223 | system_msg = messages[0]["content"] 224 | messages = messages[1:] 225 | else: 226 | system_msg = None 227 | 228 | message = "\n".join([i["content"] for i in messages]) 229 | 230 | generation_config = types.GenerateContentConfig( 231 | temperature=temperature, 232 | system_instruction=system_msg, 233 | max_output_tokens=max_completion_tokens, 234 | stop_sequences=stop, 235 | ) 236 | 237 | if response_format: 238 | # use Pydantic classes for response_format 239 | generation_config.response_mime_type = "application/json" 240 | generation_config.response_schema = response_format 241 | 242 | try: 243 | response = client.models.generate_content( 244 | model=model, 245 | contents=message, 246 | config=generation_config, 247 | ) 248 | content = response.text 249 | except Exception as e: 250 | raise Exception(f"An error occurred: {e}") 251 | 252 | if response_format: 253 | # convert the content into Pydantic class 254 | content = response_format.parse_raw(content) 255 | 256 | return LLMResponse( 257 | model=model, 258 | content=content, 259 | time=round(time.time() - t, 3), 260 | input_tokens=response.usage_metadata.prompt_token_count, 261 | output_tokens=response.usage_metadata.candidates_token_count, 262 | ) 263 | 264 | 265 | def map_model_to_chat_fn(model: str) -> Callable: 266 | """ 267 | Returns the appropriate chat function based on the model. 268 | """ 269 | if model.startswith("claude"): 270 | return chat_anthropic 271 | if model.startswith("gemini"): 272 | return chat_gemini 273 | if model.startswith("gpt") or model.startswith("o1"): 274 | return chat_openai 275 | raise ValueError(f"Unknown model: {model}") 276 | 277 | 278 | async def chat( 279 | model, 280 | messages, 281 | max_completion_tokens=4096, 282 | temperature=0.0, 283 | stop=[], 284 | json_mode=False, 285 | response_format=None, 286 | seed=0, 287 | store=True, 288 | metadata=None, 289 | timeout=100, # in seconds 290 | ) -> LLMResponse: 291 | """ 292 | Returns the response from the LLM API for a single model that is passed in. 293 | Includes retry logic with exponential backoff for up to 3 attempts. 294 | """ 295 | llm_function = map_model_to_chat_fn(model) 296 | max_retries = 3 297 | base_delay = 1 # Initial delay in seconds 298 | 299 | for attempt in range(max_retries): 300 | try: 301 | return llm_function( 302 | model=model, 303 | messages=messages, 304 | max_completion_tokens=max_completion_tokens, 305 | temperature=temperature, 306 | stop=stop, 307 | json_mode=json_mode, 308 | response_format=response_format, 309 | seed=seed, 310 | store=store, 311 | metadata=metadata, 312 | timeout=timeout, 313 | ) 314 | except Exception as e: 315 | delay = base_delay * (2**attempt) # Exponential backoff 316 | print( 317 | f"Attempt {attempt + 1} failed. 
Retrying in {delay} seconds...", 318 | flush=True, 319 | ) 320 | print(f"Error: {e}", flush=True) 321 | time.sleep(delay) 322 | 323 | # If we get here, all attempts failed 324 | raise Exception("All attempts at calling the chat function failed") 325 | -------------------------------------------------------------------------------- /utils/pruning.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/defog-ai/sql-eval/c7986a7e0922e66b9dafe9096f454e41fa89589b/utils/pruning.py -------------------------------------------------------------------------------- /utils/questions.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import pandas as pd 3 | 4 | 5 | def get_table_aliases(db_name: str) -> str: 6 | from defog_data.metadata import dbs 7 | from utils.aliases import generate_aliases 8 | 9 | metadata = dbs[db_name]["table_metadata"] 10 | table_names = list(metadata.keys()) 11 | aliases = generate_aliases(table_names) 12 | aliases_instruction = ( 13 | "Use the following table aliases when referencing tables in the query:\n" 14 | + aliases 15 | ) 16 | return aliases_instruction 17 | 18 | 19 | def prepare_questions_df( 20 | questions_file: str, 21 | db_type: str, 22 | num_questions: Optional[int] = None, 23 | k_shot: bool = False, 24 | cot_table_alias: str = "", 25 | ): 26 | question_query_df = pd.read_csv(questions_file, nrows=num_questions) 27 | question_query_df["db_type"] = db_type 28 | question_query_df["generated_query"] = "" 29 | question_query_df["reason"] = "" 30 | question_query_df["error_msg"] = "" 31 | question_query_df["exact_match"] = 0 32 | question_query_df["correct"] = 0 33 | question_query_df["error_query_gen"] = 0 34 | question_query_df["error_db_exec"] = 0 35 | question_query_df["timeout"] = 0 36 | # add custom metrics below: 37 | question_query_df["latency_seconds"] = 0.0 # latency of query generation in seconds 38 | question_query_df["tokens_used"] = 0 # number of tokens used in query generation 39 | 40 | # get instructions if applicable 41 | if "instructions" in question_query_df.columns: 42 | question_query_df["instructions"] = question_query_df["instructions"].fillna("") 43 | question_query_df["instructions"] = question_query_df["instructions"].apply( 44 | lambda x: x.replace(". 
", ".\n") 45 | ) 46 | question_query_df["instructions"] = question_query_df["instructions"].apply( 47 | lambda x: ( 48 | f"\nFollow the instructions below to generate the query:\n{x}\n" 49 | if x != "" 50 | else "" 51 | ) 52 | ) 53 | else: 54 | question_query_df["instructions"] = "" 55 | 56 | # get k_shot prompt if applicable 57 | if not k_shot: 58 | question_query_df["k_shot_prompt"] = "" 59 | else: 60 | if "k_shot_prompt" not in question_query_df.columns: 61 | raise ValueError( 62 | "k_shot is True but k_shot_prompt column not in questions file" 63 | ) 64 | else: 65 | question_query_df["k_shot_prompt"] = question_query_df[ 66 | "k_shot_prompt" 67 | ].fillna("") 68 | question_query_df["k_shot_prompt"] = question_query_df[ 69 | "k_shot_prompt" 70 | ].apply(lambda x: x.replace("\\n", "\n")) 71 | question_query_df["k_shot_prompt"] = question_query_df[ 72 | "k_shot_prompt" 73 | ].apply( 74 | lambda x: f"Adhere closely to the following correct examples as references for answering the question:\n{x}" 75 | ) 76 | 77 | # get glossary if applicable 78 | if "glossary" in question_query_df.columns: 79 | question_query_df["glossary"] = question_query_df["glossary"].fillna("") 80 | question_query_df["glossary"] = question_query_df["glossary"].apply( 81 | lambda x: f"\nCarefully follow these instructions if and only if they are relevant to the question and the query you generate:\n{x}\n" 82 | ) 83 | else: 84 | question_query_df["glossary"] = "" 85 | 86 | question_query_df.reset_index(inplace=True, drop=True) 87 | 88 | # get table_metadata_string if applicable 89 | if "table_metadata_string" in question_query_df.columns: 90 | question_query_df["table_metadata_string"] = question_query_df[ 91 | "table_metadata_string" 92 | ].fillna("") 93 | else: 94 | question_query_df["table_metadata_string"] = "" 95 | 96 | # get table_aliases 97 | question_query_df["table_aliases"] = question_query_df["db_name"].apply( 98 | get_table_aliases 99 | ) 100 | 101 | # get prev_invalid_sql if applicable 102 | if "prev_invalid_sql" in question_query_df.columns: 103 | question_query_df["prev_invalid_sql"] = question_query_df[ 104 | "prev_invalid_sql" 105 | ].fillna("") 106 | else: 107 | question_query_df["prev_invalid_sql"] = "" 108 | 109 | # get prev_error_msg if applicable 110 | if "prev_error_msg" in question_query_df.columns: 111 | question_query_df["prev_error_msg"] = question_query_df[ 112 | "prev_error_msg" 113 | ].fillna("") 114 | else: 115 | question_query_df["prev_error_msg"] = "" 116 | 117 | # get question_0, query_0, question_1, query_1 if applicable 118 | if "question_0" in question_query_df.columns: 119 | question_query_df["question_0"] = question_query_df["question_0"].fillna("") 120 | else: 121 | question_query_df["question_0"] = "" 122 | if "query_0" in question_query_df.columns: 123 | question_query_df["query_0"] = question_query_df["query_0"].fillna("") 124 | else: 125 | question_query_df["query_0"] = "" 126 | if "question_1" in question_query_df.columns: 127 | question_query_df["question_1"] = question_query_df["question_1"].fillna("") 128 | else: 129 | question_query_df["question_1"] = "" 130 | if "query_1" in question_query_df.columns: 131 | question_query_df["query_1"] = question_query_df["query_1"].fillna("") 132 | else: 133 | question_query_df["query_1"] = "" 134 | 135 | # add all cot instructions to the respective columns 136 | question_query_df["cot_instructions"] = "" 137 | question_query_df["cot_pregen"] = False 138 | if cot_table_alias == "instruct": 139 | question_query_df["cot_instructions"] = ( 
140 | "List the table aliases for each table as comments, starting with the most relevant tables to the question." 141 | ) 142 | elif cot_table_alias == "pregen": 143 | question_query_df["cot_pregen"] = True 144 | 145 | return question_query_df 146 | -------------------------------------------------------------------------------- /utils/reporting.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from uuid import uuid4 3 | from datetime import datetime 4 | import os 5 | 6 | 7 | # get the GPU name this is running on 8 | def get_gpu_name(): 9 | """ 10 | Get the GPU name this is running on. 11 | """ 12 | # Get the GPU name 13 | try: 14 | gpu_name = os.popen( 15 | "nvidia-smi --query-gpu=gpu_name --format=csv,noheader" 16 | ).read() 17 | except: 18 | gpu_name = "No GPU found" 19 | # Return the GPU name 20 | return gpu_name 21 | 22 | 23 | def get_gpu_memory(): 24 | """ 25 | Get the GPU memory this is running on. 26 | """ 27 | # Get the GPU memory 28 | try: 29 | gpu_memory = os.popen( 30 | "nvidia-smi --query-gpu=memory.total --format=csv,noheader" 31 | ).read() 32 | except: 33 | gpu_memory = "No GPU found" 34 | # Return the GPU memory 35 | return gpu_memory 36 | 37 | 38 | def get_gpu_driver_version(): 39 | """ 40 | Get the GPU driver version this is running on. 41 | """ 42 | # Get the GPU driver version 43 | try: 44 | gpu_driver_version = os.popen( 45 | "nvidia-smi --query-gpu=driver_version --format=csv,noheader" 46 | ).read() 47 | except: 48 | gpu_driver_version = "No GPU found" 49 | # Return the GPU driver version 50 | return gpu_driver_version 51 | 52 | 53 | def get_gpu_cuda_version(): 54 | """ 55 | Get the GPU CUDA version this is running on. 56 | """ 57 | # Get the GPU CUDA version 58 | try: 59 | gpu_cuda_version = os.popen("nvcc --version").read() 60 | except: 61 | gpu_cuda_version = "No GPU found" 62 | # Return the GPU CUDA version 63 | return gpu_cuda_version 64 | 65 | 66 | def num_gpus(): 67 | """ 68 | Get the number of GPUs this is running on. 69 | """ 70 | # Get the number of GPUs 71 | try: 72 | num_gpus = os.popen("nvidia-smi --query-gpu=count --format=csv,noheader").read() 73 | except: 74 | num_gpus = "No GPU found" 75 | # Return the number of GPUs 76 | return num_gpus 77 | 78 | 79 | def upload_results( 80 | results: list, 81 | url: str, 82 | run_name: str = None, 83 | runner_type: str = None, 84 | args: dict = None, 85 | **kwargs, # this is to make sure other imports don't break 86 | ): 87 | """ 88 | Uploads results to a server. 89 | Customize where the results are stored by changing the url. 90 | The db_type in the args below refers to the db_type of the queries we evaluated, 91 | not the db_type of where we are storing our results. 
92 | """ 93 | if not run_name: 94 | return 95 | # Create a unique id for the request 96 | run_id = uuid4().hex 97 | 98 | # Create a dictionary with the request id and the results 99 | data = { 100 | "run_id": run_id, 101 | "results": results, 102 | "timestamp": datetime.now().isoformat(), 103 | "runner_type": runner_type, 104 | "model": args.model, 105 | "num_beams": args.num_beams, 106 | "db_type": args.db_type, 107 | "run_name": run_name, 108 | } 109 | # Send the data to the server 110 | response = requests.post(url, json=data) 111 | if response.status_code != 200: 112 | print(f"Error uploading results:\n{response.text}") 113 | # Return the response 114 | return response 115 | -------------------------------------------------------------------------------- /utils/upload_report_gcloud.py: -------------------------------------------------------------------------------- 1 | # this is a Google cloud function for receiving the data from the web app and storing it in the database 2 | # to launch the cloud function, run the following command in the terminal: 3 | # gcloud functions deploy record-eval --runtime python10 --trigger-http --allow-unauthenticated 4 | 5 | import functions_framework 6 | from google.cloud import storage 7 | import json 8 | 9 | BUCKET_NAME = "YOUR-BUCKET-NAME" 10 | 11 | 12 | @functions_framework.http 13 | def hello_http(request): 14 | request_json = request.get_json(silent=True) 15 | results = request_json["results"] 16 | run_name = request_json["run_name"] 17 | storage_client = storage.Client() 18 | bucket = storage_client.bucket(BUCKET_NAME) 19 | blob = bucket.blob(run_name + ".json") 20 | blob.upload_from_string(json.dumps(results)) 21 | return "success" 22 | --------------------------------------------------------------------------------