├── config ├── sq1_sql.txt ├── sq2_sql.txt ├── sq3_sql.txt ├── sq1_bash.txt ├── sq2_bash.txt ├── sq1_python.txt ├── sq3_bash.txt ├── sq2_python.txt ├── sq3_python.txt ├── test_config.json ├── sq2_cpp.txt ├── sq1_cpp.txt ├── simple_prompts2.json ├── simple_prompts.json ├── sq3_cpp.txt └── prompts.json ├── requirements.txt ├── src └── codexdb │ ├── bench │ ├── analyze.py │ ├── test.py │ ├── temperature.py │ ├── run.py │ ├── alpha.py │ ├── scale.py │ └── plot.py │ ├── finetuning │ ├── finetune.py │ └── prepare.py │ ├── catalog.py │ ├── prep │ ├── spider.py │ └── wikisql.py │ ├── gui.py │ ├── engine.py │ ├── solve.py │ ├── code.py │ └── plan.py ├── LICENSE ├── .gitignore ├── README.md └── prompt_collection /config/sq1_sql.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /config/sq2_sql.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /config/sq3_sql.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /config/sq1_bash.txt: -------------------------------------------------------------------------------- 1 | wc -l perpetrator.csv | awk '{print $1-1}' 2 | --- End of Bash script --- -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | openai==0.28 2 | streamlit==1.40 3 | pandas==2.2 4 | sqlglot==1.16.1 5 | -------------------------------------------------------------------------------- /config/sq2_bash.txt: -------------------------------------------------------------------------------- 1 | tail -n +2 classroom.csv | cut -f1 -d, | sort | uniq | wc -l | xargs 2 | --- End of Bash script --- -------------------------------------------------------------------------------- /config/sq1_python.txt: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | perpetrator = pd.read_csv("perpetrator.csv") 3 | print(len(perpetrator)) 4 | --- End of Python program --- -------------------------------------------------------------------------------- /config/sq3_bash.txt: -------------------------------------------------------------------------------- 1 | join -1 1 -2 8 -t , <(tail -n +2 operate_company.csv | sort -k 1 -t ,) <(tail -n +2 flight.csv | sort -k 8 -t ,) | cut -d, -f2,3 -s 2 | --- End of Bash script --- -------------------------------------------------------------------------------- /config/sq2_python.txt: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | classroom = pd.read_csv("classroom.csv").query("capacity > 50") 3 | print(len(classroom.groupby("building"))) 4 | --- End of Python program --- -------------------------------------------------------------------------------- /config/sq3_python.txt: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | operate_company = pd.read_csv("operate_company.csv") 3 | flight = pd.read_csv("flight.csv") 4 | operate_company_flight = pd.merge(left=operate_company, right=flight, left_on="id", right_on="company_id") 5 | print(operate_company_flight[["name", "Type"]]) 6 | --- End of Python program --- 
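The sq1/sq2/sq3 SQL reference files above are empty. Judging from the Python and C++ variants, the three sample tasks presumably correspond to SQL along these lines (a sketch inferred from those scripts, not taken from the repository):
-- sq1: number of perpetrators
SELECT count(*) FROM perpetrator;
-- sq2: number of distinct buildings having a classroom with capacity above 50
SELECT count(DISTINCT building) FROM classroom WHERE capacity > 50;
-- sq3: company name and type for each flight
SELECT T1.name, T1.Type FROM operate_company AS T1 JOIN flight AS T2 ON T1.id = T2.company_id;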
-------------------------------------------------------------------------------- /src/codexdb/bench/analyze.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Jan 28, 2022 3 | 4 | @author: immanueltrummer 5 | ''' 6 | import argparse 7 | 8 | if __name__ == '__main__': 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('run_dir', type=str, help='Directory with results') 12 | args = parser.parse_args() 13 | 14 | -------------------------------------------------------------------------------- /config/test_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_dir":"/home/ubuntu/spider/spider", 3 | "sample_path":"/home/ubuntu/codexdb/experiments/spider3/plain/train_plain.json", 4 | "test_path":"/home/ubuntu/spider/spider/results_dev.json", 5 | "test_start":0, 6 | "test_step":2, 7 | "test_end":200, 8 | "model_id":"code-davinci-002", 9 | "prompt_style":"plan", 10 | "nr_samples":4, 11 | "id_case":0, 12 | "mod_start":"", 13 | "mod_between":"", 14 | "mod_end":"", 15 | "nr_retries":2, 16 | "max_temperature":1.0, 17 | "out_dir":"/home/ubuntu/codexdb/experiments/spider3/temperature" 18 | } -------------------------------------------------------------------------------- /config/sq2_cpp.txt: -------------------------------------------------------------------------------- 1 | #include <iostream> 2 | #include <fstream> 3 | #include <sstream> 4 | #include <vector> 5 | using namespace std; 6 | 7 | struct s_classroom { 8 | string building; 9 | string room_number; 10 | int capacity; 11 | }; 12 | 13 | int main() 14 | { 15 | ifstream classroom_file("classroom.csv"); 16 | classroom_file.ignore(5000, '\n'); 17 | vector<string> building_vector; 18 | string line, field; 19 | while (getline(classroom_file, line)) { 20 | stringstream ss(line); 21 | struct s_classroom classroom; 22 | getline(ss, field, ','); 23 | classroom.building=field; 24 | getline(ss, field, ','); 25 | classroom.room_number=field; 26 | getline(ss, field, ','); 27 | classroom.capacity=stoi(field); 28 | 29 | if (classroom.capacity > 50) { 30 | building_vector.push_back(classroom.building); 31 | } 32 | } 33 | sort(building_vector.begin(), building_vector.end()); 34 | cout << distance(building_vector.begin(), unique(building_vector.begin(), building_vector.end())) << endl; 35 | } 36 | --- End of C++ program --- -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Immanuel 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /config/sq1_cpp.txt: -------------------------------------------------------------------------------- 1 | #include <iostream> 2 | #include <fstream> 3 | #include <sstream> 4 | #include <vector> 5 | using namespace std; 6 | 7 | struct s_perpetrator { 8 | string Perpetrator_ID; 9 | int People_ID; 10 | int Date; 11 | string Year; 12 | string Location; 13 | string Country; 14 | int Killed; 15 | int Injured; 16 | }; 17 | 18 | int main() 19 | { 20 | ifstream perpetrator_file("perpetrator.csv"); 21 | perpetrator_file.ignore(5000, '\n'); 22 | vector<s_perpetrator> perpetrator_vector; 23 | string line, field; 24 | while (getline(perpetrator_file, line)) { 25 | stringstream ss(line); 26 | struct s_perpetrator perpetrator; 27 | getline(ss, field, ','); 28 | perpetrator.Perpetrator_ID=field; 29 | getline(ss, field, ','); 30 | perpetrator.People_ID=stoi(field); 31 | getline(ss, field, ','); 32 | perpetrator.Date=stoi(field); 33 | getline(ss, field, ','); 34 | perpetrator.Year=field; 35 | getline(ss, field, ','); 36 | perpetrator.Location=field; 37 | getline(ss, field, ','); 38 | perpetrator.Country=field; 39 | getline(ss, field, ','); 40 | perpetrator.Killed=stoi(field); 41 | getline(ss, field, ','); 42 | perpetrator.Injured=stoi(field); 43 | perpetrator_vector.push_back(perpetrator); 44 | } 45 | cout << perpetrator_vector.size() << endl; 46 | } 47 | --- End of C++ program --- -------------------------------------------------------------------------------- /config/simple_prompts2.json: -------------------------------------------------------------------------------- 1 | { 2 | "query": { 3 | "from_nl": { 4 | "to_python": { 5 | "template": "\"\"\"\nThis Python program answers the query \"\" on the following tables:\n\n\n\"\"\"\n\n--- Start of Python program ---", 6 | "marker": "--- End of Python program ---", 7 | "linepre": "" 8 | }, 9 | "to_bash": { 10 | "template": "# This Bash script answers the query \"\" on the following tables:\n\n# Answer the query \"\":\n\n\n--- Start of Bash script ---\n#!/bin/bash\n\necho \"Processing query ...\"", 11 | "marker": "--- End of Bash script ---", 12 | "linepre": "# " 13 | }, 14 | "to_cpp": { 15 | "template": "// This C++ program answers the query \"\" on the following tables:\n\n\n\n--- Start of C++ program ---\n", 16 | "marker": "--- End of C++ program ---", 17 | "linepre": "// " 18 | }, 19 | "to_pg_sql": { 20 | "template": "##### Translate this query into SQL: \n\n--- Start of SQL query ---\nSELECT ", 21 | "marker": "--- End of SQL query ---", 22 | "linepre": "# " 23 | } 24 | }, 25 | "tactics": [ 26 | "Filter tables using query predicates.", 27 | "Perform joins as needed for query.", 28 | "Aggregate data as specified in query.", 29 | "Calculate the answer to the query."
30 | ], 31 | "precedence": [ 32 | {"F":0, "S":1}, 33 | {"F":1, "S":2}, 34 | {"F":2, "S":3} 35 | ], 36 | "strategies": [""] 37 | } 38 | } -------------------------------------------------------------------------------- /src/codexdb/finetuning/finetune.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Nov 21, 2024 3 | 4 | @author: immanueltrummer 5 | ''' 6 | import argparse 7 | import openai 8 | import time 9 | 10 | client = openai.OpenAI() 11 | 12 | 13 | if __name__ == '__main__': 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('in_path', type=str, help='Path to input file') 17 | parser.add_argument('model', type=str, help='Base model for fine-tuning') 18 | parser.add_argument('suffix', type=str, help='Suffix used in model name') 19 | args = parser.parse_args() 20 | 21 | reply = client.files.create( 22 | file=open(args.in_path, 'rb'), 23 | purpose='fine-tune') 24 | file_id = reply.id 25 | 26 | reply = client.fine_tuning.jobs.create( 27 | training_file=file_id, 28 | model=args.model, 29 | suffix=args.suffix) 30 | job_id = reply.id 31 | print(f'Job ID: {job_id}') 32 | 33 | status = None 34 | start_s = time.time() 35 | 36 | while not (status == 'succeeded'): 37 | 38 | time.sleep(5) 39 | total_s = time.time() - start_s 40 | print(f'Fine-tuning since {total_s} seconds.') 41 | 42 | reply = client.fine_tuning.jobs.retrieve(job_id) 43 | status = reply.status 44 | print(f'Status: {status}') 45 | print(status) 46 | 47 | print(f'Fine-tuning is finished!') 48 | model_id = reply.fine_tuned_model 49 | print(f'Model ID: {model_id}') -------------------------------------------------------------------------------- /config/simple_prompts.json: -------------------------------------------------------------------------------- 1 | { 2 | "query": { 3 | "from_nl": { 4 | "to_python": { 5 | "template": "\"\"\"\nThis Python program answers the query \"\" on the following tables:\n\n\n\"\"\"\n\n--- Start of Python program ---", 6 | "marker": "--- End of Python program ---", 7 | "linepre": "" 8 | }, 9 | "to_bash": { 10 | "template": "# This Bash script answers the query \"\" on the following tables:\n\n# Answer the query \"\":\n\n\n--- Start of Bash script ---\n#!/bin/bash\n\necho \"Processing query ...\"", 11 | "marker": "--- End of Bash script ---", 12 | "linepre": "# " 13 | }, 14 | "to_cpp": { 15 | "template": "// This C++ program answers the query \"\" on the following tables:\n\n\n\n--- Start of C++ program ---\n", 16 | "marker": "--- End of C++ program ---", 17 | "linepre": "// " 18 | }, 19 | "to_pg_sql": { 20 | "template": "##### Translate this query into SQL: \n\n--- Start of SQL query ---\nSELECT ", 21 | "marker": "--- End of SQL query ---", 22 | "linepre": "# " 23 | } 24 | }, 25 | "tactics": [ 26 | "Import the Pandas library.", 27 | "Import the Dask library.", 28 | "Import the Vaex library.", 29 | "Import the Modin library.", 30 | "Load data for all relevant tables.", 31 | "Calculate the answer to the query.", 32 | "Write query results to file 'result.csv'." 
33 | ], 34 | "precedence": [ 35 | {"F":0, "S":4}, 36 | {"F":1, "S":4}, 37 | {"F":2, "S":4}, 38 | {"F":3, "S":4}, 39 | {"F":4, "S":5}, 40 | {"F":5, "S":6} 41 | ], 42 | "strategies": [ 43 | "", 44 | " for parallel processing", 45 | " for efficient processing"] 46 | } 47 | } -------------------------------------------------------------------------------- /src/codexdb/bench/test.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on May 17, 2022 3 | 4 | @author: immanueltrummer 5 | ''' 6 | import argparse 7 | import codexdb.solve 8 | import json 9 | import openai 10 | import os.path 11 | 12 | if __name__ == '__main__': 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('ai_key', type=str, help='Key for OpenAI access') 15 | parser.add_argument('config_path', type=str, help='Configuration file path') 16 | args = parser.parse_args() 17 | 18 | openai.api_key = args.ai_key 19 | with open(args.config_path) as file: 20 | config = json.load(file) 21 | 22 | data_dir = config['data_dir'] 23 | sample_path = config['sample_path'] 24 | test_path = config['test_path'] 25 | test_start = config['test_start'] 26 | test_step = config['test_step'] 27 | test_end = config['test_end'] 28 | model_id = config['model_id'] 29 | prompt_style = config['prompt_style'] 30 | nr_samples = config['nr_samples'] 31 | id_case = config['id_case'] 32 | mod_start = config['mod_start'] 33 | mod_between = config['mod_between'] 34 | mod_end = config['mod_end'] 35 | nr_retries = config['nr_retries'] 36 | max_temperature = config['max_temperature'] 37 | out_dir = config['out_dir'] 38 | 39 | run_id = f'{model_id}_{prompt_style}_S{nr_samples}_' +\ 40 | f'R{nr_retries}_T{max_temperature}' 41 | log_path = f'{out_dir}/log_{run_id}' 42 | result_path = f'{out_dir}/results_{run_id}.json' 43 | if os.path.exists(log_path) or os.path.exists(result_path): 44 | raise ValueError('Cannot override existing files!') 45 | 46 | codexdb.solve.main( 47 | data_dir, test_path, 'python', 48 | model_id, prompt_style, id_case, 49 | mod_start, mod_between, mod_end, 50 | sample_path, nr_samples, 51 | test_start, test_step, test_end, 'executed', 52 | nr_retries, max_temperature, log_path, result_path) -------------------------------------------------------------------------------- /src/codexdb/bench/temperature.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on May 25, 2022 3 | 4 | @author: immanueltrummer 5 | ''' 6 | import argparse 7 | import codexdb.solve 8 | import openai 9 | import os 10 | 11 | 12 | if __name__ == '__main__': 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('ai_key', type=str, help='Key for OpenAI access') 15 | args = parser.parse_args() 16 | 17 | openai.api_key = args.ai_key 18 | data_dir = '/home/ubuntu/spider/spider' 19 | test_path = '/home/ubuntu/spider/spider/results_dev.json' 20 | test_start = 0 21 | test_step = 2 22 | test_end = 200 23 | model_id = 'code-davinci-002' 24 | prompt_style = 'plan' 25 | nr_samples = 4 26 | id_case = 0 27 | mod_end = '' 28 | nr_retries = 2 29 | 30 | for mod_start, mod_between, out_suffix in [ 31 | ('', '', 'plain'), 32 | ('Use pandas library', '', 'pandas'), 33 | ('Use vaex library', '', 'vaex'), 34 | ('Use datatable library', '', 'datatable'), 35 | ('', 'Print "Done."', 'done'), 36 | ('', 'Print intermediate results', 'results'), 37 | ('', 'Print progress updates', 'progress')]: 38 | sample_path = f'/home/ubuntu/codexdb/experiments/spider3/{out_suffix}/train_plain.json' 
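# For each prompt modification above, sweep increasing maximal temperatures; every (modification, temperature) combination writes its own log and result file under a separate output directory.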
39 | for max_temperature in [0.125, 0.25, 0.5, 1.0, 2.0]: 40 | out_dir = f'/home/ubuntu/codexdb/experiments/spider3/temperature/{out_suffix}' 41 | run_id = f'{model_id}_{prompt_style}_S{nr_samples}_' +\ 42 | f'R{nr_retries}_T{max_temperature}' 43 | log_path = f'{out_dir}/log_{run_id}' 44 | result_path = f'{out_dir}/results_{run_id}.json' 45 | if os.path.exists(log_path) or os.path.exists(result_path): 46 | raise ValueError( 47 | 'Cannot override existing files: ' +\ 48 | f'{log_path}; {result_path}') 49 | 50 | codexdb.solve.main( 51 | data_dir, test_path, 'python', 52 | model_id, prompt_style, id_case, 53 | mod_start, mod_between, '', 54 | sample_path, nr_samples, 55 | test_start, test_step, test_end, 'executed', 56 | nr_retries, max_temperature, log_path, result_path) -------------------------------------------------------------------------------- /src/codexdb/finetuning/prepare.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Nov 21, 2024 3 | 4 | @author: immanueltrummer 5 | ''' 6 | import argparse 7 | import json 8 | import jsonlines 9 | 10 | from codexdb.catalog import DbCatalog 11 | from codexdb.code import PythonGenerator 12 | 13 | 14 | def get_sample(final_try): 15 | """ Generate sample from solved test case. 16 | 17 | Args: 18 | final_try: try that solved the test case. 19 | 20 | Returns: 21 | Sample in OpenAI format for fine-tuning. 22 | """ 23 | schema = final_try['schema'] 24 | db = final_try['db'] 25 | files = final_try['files'] 26 | question = final_try['question'] 27 | query = final_try['query'] 28 | 29 | prompt = coder.get_prompt(schema, db, files, question, query) 30 | user_message = {'role':'user', 'content':prompt} 31 | 32 | code = final_try['code'] 33 | assistant_message = {'role':'assistant', 'content':code} 34 | 35 | sample = {'messages':[user_message, assistant_message]} 36 | return sample 37 | 38 | 39 | if __name__ == '__main__': 40 | 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument('data_dir', type=str, help='Data directory') 43 | parser.add_argument('in_path', type=str, help='Path to input') 44 | parser.add_argument('mod_start', type=str, help='Modification at plan start') 45 | parser.add_argument('mod_between', type=str, help='Modification between steps') 46 | parser.add_argument('mod_end', type=str, help='Modification at plan end') 47 | parser.add_argument('out_path', type=str, help='Path to output') 48 | args = parser.parse_args() 49 | 50 | catalog = DbCatalog(args.data_dir) 51 | coder = PythonGenerator( 52 | catalog, [], 0, 'plan', '', 53 | id_case=True, mod_start=args.mod_start, 54 | mod_between=args.mod_between, 55 | mod_end=args.mod_end) 56 | 57 | with open(args.in_path) as file: 58 | data = json.load(file) 59 | 60 | samples = [] 61 | for test_case in data.values(): 62 | final_try = test_case[-1] 63 | solved = (final_try['similarity'] == 1.0) 64 | if solved: 65 | sample = get_sample(final_try) 66 | samples.append(sample) 67 | 68 | with jsonlines.open(args.out_path, 'w') as file: 69 | for sample in samples: 70 | file.write(sample) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 
| sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /config/sq3_cpp.txt: -------------------------------------------------------------------------------- 1 | #include <iostream> 2 | #include <fstream> 3 | #include <sstream> 4 | #include <vector> 5 | using namespace std; 6 | 7 | struct s_operate_company { 8 | int id; 9 | string name; 10 | string Type; 11 | string Principal_activities; 12 | string Incorporated_in; 13 | string Group_Equity_Shareholding; 14 | }; 15 | 16 | struct s_flight { 17 | int id; 18 | int Vehicle_Flight_number; 19 | string Date; 20 | string Pilot; 21 | double Velocity; 22 | double Altitude; 23 | int airport_id; 24 | int company_id; 25 | }; 26 | 27 | int main() 28 | { 29 | ifstream operate_company_file("operate_company.csv"); 30 | ifstream flight_file("flight.csv"); 31 | operate_company_file.ignore(5000, '\n'); 32 | flight_file.ignore(5000, '\n'); 33 | vector<s_operate_company> operate_company_vector; 34 | vector<s_flight> flight_vector; 35 | 36 | string line, field; 37 | while (getline(operate_company_file, line)) { 38 | stringstream ss(line); 39 | struct s_operate_company operate_company; 40 | getline(ss, field, ','); 41 | operate_company.id=stoi(field); 42 | getline(ss, field, ','); 43 | operate_company.name=field; 44 | getline(ss, field, ','); 45 | operate_company.Type=field; 46 | getline(ss, field, ','); 47 | operate_company.Principal_activities=field; 48 | getline(ss, field, ','); 49 |
operate_company.Incorporated_in=field; 50 | getline(ss, field, ','); 51 | operate_company.Group_Equity_Shareholding=field; 52 | operate_company_vector.push_back(operate_company); 53 | } 54 | while (getline(flight_file, line)) { 55 | stringstream ss(line); 56 | struct s_flight flight; 57 | getline(ss, field, ','); 58 | flight.id=stoi(field); 59 | getline(ss, field, ','); 60 | flight.Vehicle_Flight_number=stoi(field); 61 | getline(ss, field, ','); 62 | flight.Date=field; 63 | getline(ss, field, ','); 64 | flight.Pilot=field; 65 | getline(ss, field, ','); 66 | flight.Velocity=stod(field); 67 | getline(ss, field, ','); 68 | flight.Altitude=stod(field); 69 | getline(ss, field, ','); 70 | flight.airport_id=stoi(field); 71 | getline(ss, field, ','); 72 | flight.company_id=stoi(field); 73 | flight_vector.push_back(flight); 74 | } 75 | 76 | map<int, s_operate_company> operate_company_by_id; 77 | for (struct s_operate_company operate_company : operate_company_vector) { 78 | operate_company_by_id.insert(make_pair(operate_company.id, operate_company)); 79 | } 80 | for (struct s_flight flight : flight_vector) { 81 | int company_id = flight.company_id; 82 | struct s_operate_company company = operate_company_by_id[company_id]; 83 | cout << company.name << "," << company.Type << endl; 84 | } 85 | } 86 | --- End of C++ program --- -------------------------------------------------------------------------------- /src/codexdb/catalog.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Oct 5, 2021 3 | 4 | @author: immanueltrummer 5 | ''' 6 | import json 7 | 8 | class DbCatalog(): 9 | """ Information about all databases in the database directory. """ 10 | 11 | def __init__(self, data_dir): 12 | """ Initialize for given database directory. 13 | 14 | Args: 15 | data_dir: contains databases and schemata 16 | """ 17 | self.data_dir = data_dir 18 | self.schema_path = f'{data_dir}/schemata.json' 19 | with open(self.schema_path) as file: 20 | self.schemata = json.load(file) 21 | self.table_to_file = {} 22 | 23 | def assign_file(self, db_id, table, file_name): 24 | """ Assign file to given table in given database. 25 | 26 | Args: 27 | db_id: table is in this database 28 | table: assign file containing data for this table 29 | file_name: name of file containing data 30 | """ 31 | self.table_to_file[(db_id, table)] = file_name 32 | 33 | def db_dir(self, db_id): 34 | """ Returns directory storing specific database. 35 | 36 | Args: 37 | db_id: name of database 38 | 39 | Returns: 40 | path of directory containing database 41 | """ 42 | return f'{self.data_dir}/database/{db_id}' 43 | 44 | def db_ids(self): 45 | """ Returns IDs of available databases. 46 | 47 | Returns: 48 | list with database IDs 49 | """ 50 | return self.schemata.keys() 51 | 52 | def file_name(self, db_id, table): 53 | """ Returns name of file storing table data. 54 | 55 | Args: 56 | db_id: ID of database 57 | table: name of table 58 | 59 | Returns: 60 | name of file storing data 61 | """ 62 | key = (db_id, table) 63 | default = f'{table}.csv' 64 | return self.table_to_file.get(key, default) 65 | 66 | def file_path(self, db_id, table): 67 | """ Returns path to file containing data for table. 68 | 69 | Args: 70 | db_id: search table in this database 71 | table: name of table 72 | 73 | Returns: 74 | path to file containing data for table 75 | """ 76 | db_dir = self.db_dir(db_id) 77 | file_name = self.file_name(db_id, table) 78 | return f'{db_dir}/{file_name}' 79 | 80 | def files(self, db_id): 81 | """ Returns names of files containing database tables.
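By default this is '<table>.csv' for every table in the schema; a different file can be registered beforehand via assign_file (see file_name).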
82 | 83 | Args: 84 | db_id: unique database identifier 85 | 86 | Returns: 87 | list of files associated with database tables 88 | """ 89 | tables = self.schema(db_id)['table_names_original'] 90 | return [self.file_name(db_id, t) for t in tables] 91 | 92 | def schema(self, db_id): 93 | """ Returns description of database schema. 94 | 95 | Args: 96 | db_id: unique name of database 97 | 98 | Returns: 99 | JSON object describing database schema 100 | """ 101 | return self.schemata[db_id] -------------------------------------------------------------------------------- /src/codexdb/bench/run.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Oct 3, 2021 3 | 4 | @author: immanueltrummer 5 | ''' 6 | import argparse 7 | import codexdb.solve 8 | import getpass 9 | import openai 10 | import os 11 | import time 12 | 13 | 14 | def train(data_dir, train_path, log_path, id_case, 15 | mod_start, mod_between, mod_end, result_path): 16 | """ Generate training samples for few-shot learning. 17 | 18 | Args: 19 | data_dir: path to database directory 20 | train_path: path to test cases for training 21 | log_path: path to file for logging output 22 | id_case: whether to consider letter case of identifiers 23 | mod_start: modification at plan start 24 | mod_between: modification between plan steps 25 | mod_end: modification at plan end 26 | result_path: path to file for result output 27 | """ 28 | codexdb.solve.main( 29 | data_dir, train_path, 'python', 'gpt-4o', 'plan', 30 | id_case, mod_start, mod_between, mod_end, '', 0, 31 | 0, 1, 50, 32 | 'solved', 10, 1, log_path, result_path) 33 | 34 | def test(data_dir, test_path, sample_path, id_case, 35 | mod_start, mod_between, mod_end, out_dir): 36 | """ Solve test cases using previously generated samples. 
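Iterates over model IDs, prompt styles, and sample counts, writing a separate log and result file per configuration into the output directory.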
37 | 38 | Args: 39 | data_dir: directory of database 40 | test_path: path to file with test cases 41 | sample_path: path to file with samples 42 | id_case: whether to consider letter case of identifiers 43 | mod_start: modification at plan start 44 | mod_between: modification between plan steps 45 | mod_end: modification at plan end 46 | out_dir: generate output in this directory 47 | """ 48 | for model_id in ['gpt-3.5-turbo', 'gpt-4o']: 49 | for prompt_style in ['plan']: 50 | for nr_samples in [2]: 51 | run_id = f'{model_id}_{prompt_style}_{nr_samples}' 52 | log_path = f'{out_dir}/log_{run_id}' 53 | result_path = f'{out_dir}/results_{run_id}.json' 54 | codexdb.solve.main( 55 | data_dir, test_path, 'python', 56 | model_id, prompt_style, id_case, 57 | mod_start, mod_between, mod_end, 58 | sample_path, nr_samples, 0, 2, 200, 59 | 'executed', 2, 1, log_path, result_path) 60 | 61 | 62 | if __name__ == '__main__': 63 | 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument('ai_key', type=str, help='Key for OpenAI access') 66 | parser.add_argument('data_dir', type=str, help='Data directory') 67 | parser.add_argument('train_path', type=str, help='Path to train case file') 68 | parser.add_argument('test_path', type=str, help='Path to test case file') 69 | parser.add_argument('id_case', type=int, help='Consider letter case of IDs?') 70 | parser.add_argument('mod_start', type=str, help='Modifying plan start') 71 | parser.add_argument('mod_between', type=str, help='Modifications in between') 72 | parser.add_argument('mod_end', type=str, help='Modifying plan end') 73 | parser.add_argument('out_dir', type=str, help='Output directory') 74 | args = parser.parse_args() 75 | if os.listdir(args.out_dir): 76 | raise ValueError('Output directory must be empty!') 77 | 78 | print(f'Login: {os.getlogin()}') 79 | print(f'User: {getpass.getuser()}') 80 | 81 | openai.api_key = args.ai_key 82 | 83 | # Train and test generating code without modifications 84 | log_path = f'{args.out_dir}/train_log_plain' 85 | sample_path = f'{args.out_dir}/train_plain.json' 86 | 87 | start_s = time.time() 88 | train(args.data_dir, args.train_path, log_path, args.id_case, 89 | args.mod_start, args.mod_between, args.mod_end, sample_path) 90 | total_s = time.time() - start_s 91 | print(f'Training took {total_s} seconds') 92 | 93 | start_s = time.time() 94 | test(args.data_dir, args.test_path, sample_path, args.id_case, 95 | args.mod_start, args.mod_between, args.mod_end, args.out_dir) 96 | total_s = time.time() - start_s 97 | print(f'Testing took {total_s} seconds') -------------------------------------------------------------------------------- /src/codexdb/bench/alpha.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Nov 23, 2024 3 | 4 | @author: immanueltrummer 5 | ''' 6 | import argparse 7 | import json 8 | import pandas 9 | 10 | from pathlib import Path 11 | 12 | 13 | def get_inputs(db_dir, db_id): 14 | """ Extracts all .csv files in directory. 15 | 16 | Args: 17 | db_dir: directory containing databases. 18 | db_id: name of sub-directory. 19 | 20 | Returns: 21 | Dictionary mapping file names to data frames. 
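For example, a database directory holding flight.csv and operate_company.csv yields a dictionary with those two file names as keys.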
22 | """ 23 | name2df = {} 24 | db_path = Path(db_dir) / Path(db_id) 25 | for csv_path in db_path.glob('*.csv'): 26 | file_name = csv_path.name 27 | csv_file = pandas.read_csv(csv_path) 28 | name2df[file_name] = csv_file 29 | 30 | return name2df 31 | 32 | 33 | def table_to_JSON(table_name, table_df): 34 | """ Transforms a named data frame into a JSON representation. 35 | 36 | Args: 37 | table_name: name of the table. 38 | table_df: data frame. 39 | 40 | Returns: 41 | JSON dictionary with table details. 42 | """ 43 | return { 44 | 'name':table_name, 45 | 'headers':table_df.columns.values.tolist(), 46 | 'rows':table_df.values.tolist() 47 | } 48 | 49 | 50 | def make_alpha_test(test_name, db_description, query, results, modification): 51 | """ Create test case for AlphaCodium. 52 | 53 | Args: 54 | test_name: name of the test case. 55 | db_description: list of JSON dictionaries describing tables. 56 | query: the query to process. 57 | results: results from reference SQL engine. 58 | modification: instructions on code customization. 59 | 60 | Returns: 61 | JSON object suitable as input for AlphaChromium. 62 | """ 63 | task_parts = [] 64 | task_parts += [f'Write Python code implementing the SQL query "{query}".'] 65 | task_parts += ['The input is a database, represented as a list of Python dictionaries.'] 66 | task_parts += ['Each dictionary describes one table with fields "name", "headers", and "rows".'] 67 | task_parts += ['The "headers" field contains a list of column names.'] 68 | task_parts += ['The "rows" field contains a list of rows where each row is a list of values.'] 69 | task_parts += ['The output is the query result: a list of rows where each row is a list of values.'] 70 | task_parts += ['Tables:'] 71 | for table_info in db_description: 72 | table_name = table_info['name'] 73 | table_headers = table_info['headers'] 74 | task_parts += [f'{table_name} columns: {table_headers}'] 75 | task = '\n'.join(task_parts) 76 | alpha_test = { 77 | 'name':test_name, 78 | 'description':task, 79 | 'public_tests':{ 80 | 'input':[str(db_description)], 81 | 'output':[str(results)] 82 | } 83 | } 84 | return alpha_test 85 | 86 | 87 | if __name__ == '__main__': 88 | 89 | parser = argparse.ArgumentParser() 90 | parser.add_argument('db_dir', type=str, help='Path to databases') 91 | parser.add_argument('test_path', type=str, help='Path to test cases') 92 | parser.add_argument('modification', type=str, help='Code modification') 93 | args = parser.parse_args() 94 | 95 | with open(args.test_path) as file: 96 | test_cases = json.load(file) 97 | 98 | for i in range(0, 200, 2): 99 | test_case = test_cases[i] 100 | 101 | db_id = test_case['db_id'] 102 | inputs = get_inputs(args.db_dir, db_id) 103 | db_description = [] 104 | for table_name, table_df in inputs.items(): 105 | db_description += [table_to_JSON(table_name, table_df)] 106 | 107 | query = test_case['query'] 108 | question = test_case['question'] 109 | results = test_case['results'] 110 | test_name = f'CodingTest{i}' 111 | alpha_test = make_alpha_test( 112 | test_name, db_description, query, 113 | results, args.modification) 114 | 115 | with open(f'{test_name}.json', 'w') as file: 116 | json.dump(alpha_test, file) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Overview 2 | 3 | CodexDB allows users to specify natural language instructions, together with their SQL queries. 
It uses OpenAI's GPT models to generate code for query processing that complies with those instructions. This enables far-reaching customization, ranging from the selection of frameworks for query processing to custom logging output. In doing so, CodexDB blurs the line between user and developer. 4 | 5 | # Setup 6 | 7 | The following instructions have been tested on an EC2 instance of type t2.medium with Ubuntu 22.04 OS, Python 3.12, and 25 GB of disk space. 8 | 9 | 1. After logging into the EC2 instance, run the following command (from ```/home/ubuntu```): 10 | ``` 11 | git clone https://github.com/itrummer/CodexDB 12 | ``` 13 | 2. Switch into the CodexDB root directory: 14 | ``` 15 | cd CodexDB/ 16 | ``` 17 | 3. Install pip if it is not yet installed, e.g., run: 18 | ``` 19 | sudo apt update 20 | sudo apt install python3-pip 21 | ``` 22 | 4. Create and activate a virtual environment and use pip to install dependencies: 23 | ``` 24 | python3 -m venv .venv 25 | source .venv/bin/activate 26 | pip install -r requirements.txt 27 | ``` 28 | 5. Download and unzip the SPIDER dataset for benchmarking: 29 | ``` 30 | cd .. 31 | pip install gdown 32 | gdown 1WwbJjvqPzLIcVmn8QcTOWr79oOFKFO6e 33 | sudo apt install unzip 34 | unzip spider_data.zip 35 | ``` 36 | 6. Pre-process the SPIDER data set: 37 | ``` 38 | cd CodexDB 39 | PYTHONPATH=src python3 src/codexdb/prep/spider.py /home/ubuntu/spider_data 40 | ``` 41 | 7. Set the following environment variables: 42 | - `CODEXDB_TMP` designates a working directory into which CodexDB writes temporary files (e.g., Python code for query execution). 43 | - `CODEXDB_PYTHON` is the name (or path) of the Python interpreter CodexDB uses to test the Python code it generates. 44 | E.g., set the two variables using the following commands: 45 | ``` 46 | export CODEXDB_TMP=/tmp 47 | export CODEXDB_PYTHON=python3 48 | ``` 49 | 50 | # Running CodexDB 51 | 52 | **WARNING: CodexDB generates Python code for query execution via large language models. Since CodexDB cannot guarantee to generate correct code, it is highly recommended to avoid running CodexDB on your primary machine. Instead, run CodexDB on a temporary EC2 instance and log into the Web interface from your primary machine.** 53 | 54 | 1. Start the CodexDB Web interface (replace `[OPENAI_API_ACCESS_KEY]` with your OpenAI access key!): 55 | ``` 56 | streamlit run src/codexdb/gui.py [OPENAI_API_ACCESS_KEY] /home/ubuntu/spider_data 57 | ``` 58 | 2. After executing the command above, you should see two URLs on the console: 59 | - Network URL 60 | - External URL 61 | 62 | If using CodexDB on your local machine, open the first URL on your Web browser. If using CodexDB on a remote machine, open the second URL via your local Web browser. You may have to enable external access in the second case. E.g., when running CodexDB on Amazon EC2, make sure to add an inbound rule allowing TCP access on port 8501. 63 | 64 | # Troubleshooting 65 | 66 | CodexDB only works with specific versions of the `sqlglot` SQL parsing library. If you encounter frequent errors in `plan.py`, check the installed version of sqlglot by running `pip show sqlglot` in the terminal. The required version is 1.16.1. If you see a different version number, uninstall sqlglot (`sudo pip uninstall sqlglot`) and reinstall the required version (e.g., by running `pip install sqlglot==1.16.1`). 67 | 68 | CodexDB only supports a restricted class of SQL queries via the "plan" prompt. 
In particular, it only supports the specific join syntax used in the queries of the SPIDER benchmark. If your query falls outside of the class of supported queries, you can switch to the "query" prompt by selecting the corresponding prompt style in the "Prompt Configuration" section (see buttons on the left side of the Web interface). This prompt style does not integrate a summary of processing steps into the prompt and may therefore degrade quality. 69 | 70 | ## How to cite 71 | 72 | ``` 73 | @article{Trummer2022b, 74 | author = {Trummer, Immanuel}, 75 | journal = {PVLDB}, 76 | number = {11}, 77 | pages = {2921 -- 2928}, 78 | title = {{CodexDB: Synthesizing code for query processing from natural language instructions using GPT-3 Codex}}, 79 | volume = {15}, 80 | year = {2022} 81 | } 82 | ``` 83 | -------------------------------------------------------------------------------- /src/codexdb/prep/spider.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Sep 19, 2021 3 | 4 | @author: immanueltrummer 5 | ''' 6 | import argparse 7 | import collections 8 | import json 9 | import pandas as pd 10 | import sqlite3 11 | 12 | def get_db_path(spider_dir, db_id): 13 | """ Return path to SQLite database file. 14 | 15 | Args: 16 | spider_dir: path to SPIDER benchmark 17 | db_id: database identifier 18 | 19 | Returns: 20 | path to SQLite database file 21 | """ 22 | return f'{spider_dir}/database/{db_id}/{db_id}.sqlite' 23 | 24 | 25 | def extract(spider_dir, db_json): 26 | """ Extract data from database into .csv files. 27 | 28 | Args: 29 | spider_dir: path to SPIDER main directory 30 | db_json: JSON description of database 31 | """ 32 | db_id = db_json['db_id'] 33 | db_dir = f'{spider_dir}/database/{db_id}' 34 | db_path = f'{db_dir}/{db_id}.sqlite' 35 | print(f'Path to DB: {db_path}') 36 | with sqlite3.connect(db_path) as con: 37 | #con.text_factory = bytes 38 | con.text_factory = lambda b: b.decode(errors = 'ignore') 39 | for tbl in db_json['table_names_original']: 40 | query = f'select * from {tbl}' 41 | df = pd.read_sql_query(query, con) 42 | out_path = f'{db_dir}/{tbl}.csv' 43 | df.to_csv(out_path, index=False) 44 | 45 | 46 | def get_result(spider_dir, query_json): 47 | """ Execute query and return result. 
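The query is executed against the database's SQLite file and the result is returned as a list of tuples, as produced by cursor.fetchall.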
48 | 49 | Args: 50 | spider_dir: path to SPIDER benchmark 51 | query_json: describes query by JSON 52 | 53 | Returns: 54 | query result 55 | """ 56 | db_id = query_json['db_id'] 57 | db_path = get_db_path(spider_dir, db_id) 58 | sql = query_json['query'] 59 | with sqlite3.connect(db_path) as con: 60 | cur = con.cursor() 61 | cur.execute(sql) 62 | result = cur.fetchall() 63 | 64 | print(f'Query: {sql}; Result: {result}') 65 | return result 66 | 67 | 68 | if __name__ == '__main__': 69 | 70 | parser = argparse.ArgumentParser() 71 | parser.add_argument('spider', type=str, help='Path to SPIDER benchmark') 72 | args = parser.parse_args() 73 | 74 | tables_path = f'{args.spider}/tables.json' 75 | db_to_s = {} 76 | with open(tables_path) as file: 77 | tables = json.load(file) 78 | nr_dbs = len(tables) 79 | for db_idx, db in enumerate(tables): 80 | db_id = db['db_id'] 81 | db_to_s[db_id] = db 82 | print(f'Extracting {db_id} ({db_idx+1}/{nr_dbs})') 83 | extract(args.spider, db) 84 | db_path = f'{args.spider}/schemata.json' 85 | with open(db_path, 'w') as file: 86 | json.dump(db_to_s, file) 87 | 88 | for in_file in ['train_spider', 'dev']: 89 | db_to_q = collections.defaultdict(lambda:[]) 90 | all_results = [] 91 | train_path = f'{args.spider}/{in_file}.json' 92 | with open(train_path) as file: 93 | queries = json.load(file) 94 | nr_queries = len(queries) 95 | nr_valid = 0 96 | 97 | for q_idx, q_json in enumerate(queries): 98 | query = q_json['query'] 99 | question = q_json['question'] 100 | db_id = q_json['db_id'] 101 | print(f'"{query}" on "{db_id}" ({q_idx+1}/{nr_queries})') 102 | 103 | db_to_q[db_id].append(q_json) 104 | try: 105 | result = get_result(args.spider, q_json) 106 | row = { 107 | 'db_id':db_id, 'question':question, 108 | 'query':query, 'results':result} 109 | all_results.append(row) 110 | nr_valid += 1 111 | except: 112 | print(f'Invalid Query: {query} on {db_id}') 113 | 114 | print(f'Processed {nr_valid}/{nr_queries} queries') 115 | results_path = f'{args.spider}/results_{in_file}.json' 116 | with open(results_path, 'w') as file: 117 | json.dump(all_results, file) 118 | 119 | q_path = f'{args.spider}/{in_file}_queries.json' 120 | with open(q_path, 'w') as file: 121 | json.dump(db_to_q, file) -------------------------------------------------------------------------------- /src/codexdb/prep/wikisql.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Jan 15, 2022 3 | 4 | @author: immanueltrummer 5 | ''' 6 | import argparse 7 | import json 8 | import jsonlines 9 | import lib.dbengine 10 | import lib.query 11 | import os 12 | import pandas as pd 13 | import sqlite3 14 | 15 | 16 | def sql_name(raw_name): 17 | """ Cleaned column name. 18 | 19 | Args: 20 | raw_name: raw column name 21 | 22 | Returns: 23 | cleaned name suitable for SQL column 24 | """ 25 | sql_name = raw_name 26 | for to_replace in ['"', ' ', '\\', '/', '(', ')', '.']: 27 | sql_name = sql_name.replace(to_replace, '_') 28 | return sql_name.strip() 29 | 30 | 31 | def extract_data(source_dir, split, target_dir): 32 | """ Extract data from given split and store on hard disk. 
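Each table is written to '<target_dir>/database/<table_id>/Data.csv' with column names cleaned via sql_name.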
33 | 34 | Args: 35 | source_dir: source data directory 36 | split: treat this split of data 37 | target_dir: write into this directory 38 | """ 39 | tbl_path = f'{source_dir}/{split}.tables.jsonl' 40 | db_path = f'{source_dir}/{split}.db' 41 | 42 | with jsonlines.open(tbl_path) as file: 43 | tables = list(file) 44 | 45 | with sqlite3.connect(db_path) as connection: 46 | for table in tables: 47 | table_id = table['id'] 48 | table_name = 'table_' + table_id.replace('-','_') 49 | query = f'select * from {table_name}' 50 | df = pd.read_sql_query(query, connection) 51 | raw_columns = table['header'] 52 | df.columns = [sql_name(c) for c in raw_columns] 53 | out_dir = f'{target_dir}/database/{table_id}' 54 | os.makedirs(out_dir, exist_ok=True) 55 | out_path = f'{out_dir}/Data.csv' 56 | df.to_csv(out_path, index=False) 57 | 58 | 59 | def extract_schemata(source_dir, split): 60 | """ Extract table schemata from given file. 61 | 62 | Args: 63 | source_dir: source directory for WikiSQL benchmark 64 | split: treat this split of data 65 | 66 | Returns: 67 | schemata of all tables in database 68 | """ 69 | tbl_path = f'{source_dir}/{split}.tables.jsonl' 70 | with jsonlines.open(tbl_path) as file: 71 | tables = list(file) 72 | 73 | schemata = {} 74 | for table in tables: 75 | columns = [sql_name(h) for h in table['header']] 76 | idx_cols = [(0, c) for c in columns] 77 | schema = {} 78 | schema['table_names_original'] = ['Data'] 79 | schema['column_names_original'] = idx_cols 80 | db_id = table['id'] 81 | schema['db_id'] = db_id 82 | schemata[db_id] = schema 83 | 84 | return schemata 85 | 86 | 87 | def extract_tests(source_dir, split, target_dir): 88 | """ Extract test cases from file. 89 | 90 | Args: 91 | source_dir: source directory of WikiSQL 92 | split: extract queries from this split 93 | target_dir: target directory for data 94 | 95 | Returns: 96 | list of extracted test cases 97 | """ 98 | in_path = f'{source_dir}/{split}.jsonl' 99 | with jsonlines.open(in_path) as file: 100 | out_cases = [] 101 | for idx, in_case in enumerate(file): 102 | # if idx == 3729: 103 | # print('Here!') 104 | out_case = {} 105 | db_id = in_case['table_id'] 106 | out_case['db_id'] = db_id 107 | csv_path = f'{target_dir}/database/{db_id}/Data.csv' 108 | db_path = '/tmp/tmp.db' 109 | df = pd.read_csv(csv_path) 110 | with sqlite3.connect(db_path) as connection: 111 | cursor = connection.cursor() 112 | cursor.execute('drop table if exists Data') 113 | connection.commit() 114 | df.to_sql('Data', connection, index=False) 115 | # for row in cursor.execute('select * from Data'): 116 | # print(row) 117 | # for row in cursor.execute('select sql, tbl_name from sqlite_master'): 118 | # print(row) 119 | 120 | engine = lib.dbengine.DBEngine(db_path) 121 | query_template = lib.query.Query.from_dict(in_case['sql'], True) 122 | query, raw_result = engine.execute_query( 123 | 'Data', query_template, lower=True) 124 | # print(f'Query: {query}') 125 | out_case['question'] = in_case['question'] 126 | out_case['query'] = str(query) 127 | result = [[r] for r in raw_result if r is not None] 128 | out_case['results'] = result 129 | out_cases.append(out_case) 130 | 131 | return out_cases 132 | 133 | if __name__ == '__main__': 134 | 135 | parser = argparse.ArgumentParser() 136 | parser.add_argument('source_dir', type=str, help='Path of WikiSQL directory') 137 | parser.add_argument('target_dir', type=str, help='Write test cases here') 138 | args = parser.parse_args() 139 | 140 | schemata = {} 141 | test_cases = [] 142 | 143 | for split in ['dev', 
'test', 'train']: 144 | print(f'Processing {split} split ...') 145 | 146 | extract_data(args.source_dir, split, args.target_dir) 147 | split_schemata = extract_schemata(args.source_dir, split) 148 | schemata = {**schemata, **split_schemata} 149 | 150 | tests = extract_tests(args.source_dir, split, args.target_dir) 151 | test_out = f'{args.source_dir}/results_{split}.json' 152 | with open(test_out, 'w') as file: 153 | json.dump(tests, file) 154 | 155 | schema_out = f'{args.target_dir}/schemata.json' 156 | with open(schema_out, 'w') as file: 157 | json.dump(schemata, file) -------------------------------------------------------------------------------- /src/codexdb/gui.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Aug 23, 2022 3 | 4 | @author: immanueltrummer 5 | ''' 6 | import argparse 7 | import openai 8 | import os 9 | import pandas as pd 10 | import pathlib 11 | import streamlit as st 12 | import sqlite3 13 | import sys 14 | 15 | cur_file_dir = os.path.dirname(__file__) 16 | src_dir = pathlib.Path(cur_file_dir).parent 17 | root_dir = src_dir.parent 18 | sys.path.append(str(src_dir)) 19 | sys.path.append(str(root_dir)) 20 | print(f'sys.path: {sys.path}') 21 | 22 | import codexdb.catalog 23 | import codexdb.code 24 | import codexdb.engine 25 | import codexdb.solve 26 | 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('ai_key', type=str, help='Access key for OpenAI platform') 29 | parser.add_argument('data_dir', type=str, help='Path to data directory') 30 | args = parser.parse_args() 31 | 32 | openai.api_key = args.ai_key 33 | catalog = codexdb.catalog.DbCatalog(args.data_dir) 34 | os.environ['KMP_DUPLICATE_LIB_OK']='True' 35 | 36 | st.set_page_config(page_title='CARD') 37 | st.markdown(''' 38 | # CARD 39 | CARD: the Coding Assistant for Data. 
40 | ''') 41 | 42 | 43 | with st.sidebar: 44 | 45 | with st.expander('Data Source'): 46 | db_ids = catalog.db_ids() 47 | db_id = st.selectbox('Select source database:', options=db_ids) 48 | 49 | schema = catalog.schema(db_id) 50 | all_tables = schema['table_names_original'] 51 | all_columns = schema['column_names_original'] 52 | for table_idx, table in enumerate(all_tables): 53 | columns = [c[1] for c in all_columns if c[0] == table_idx] 54 | st.write(f'{table}({", ".join(columns)})') 55 | 56 | 57 | with st.expander('Model Configuration'): 58 | 59 | model_ids = ['gpt-4', 'gpt-3.5-turbo'] 60 | model_id = st.selectbox( 61 | 'Select GPT Model:', 62 | options=model_ids, index=0) 63 | 64 | start_temp = float(st.slider( 65 | 'Start temperature:', 66 | min_value=0.0, max_value=1.0)) 67 | final_temp = float(st.slider( 68 | 'Final temperature:', 69 | min_value=0.0, max_value=1.0, value=0.5)) 70 | 71 | 72 | with st.expander('Prompt Configuration'): 73 | 74 | prompt_styles = ['query', 'plan'] 75 | prompt_style = st.selectbox( 76 | 'Select prompt style:', 77 | options=prompt_styles, index=1) 78 | 79 | nr_samples = int(st.slider( 80 | 'Number of samples in prompt:', 81 | min_value=0, max_value=6)) 82 | 83 | 84 | with st.expander('Termination Condition'): 85 | termination_options = [ 86 | 'No Condition', 'Executable Code', 'Correct Result'] 87 | condition = st.selectbox( 88 | 'Select termination condition:', options=range(3), 89 | format_func=lambda i:termination_options[i], index=2) 90 | # execute_code = st.checkbox('Check if Executable', value=1) 91 | # verify_result = st.checkbox('Verify Query Result', value=1) 92 | 93 | 94 | with st.expander('Code Customization'): 95 | mod_start = st.text_input('General instructions (natural language):') 96 | mod_between = st.text_input('Per-step instructions (natural language):') 97 | mod_end = '' 98 | 99 | 100 | id_case = 0 101 | query = st.text_input('Write query:') 102 | 103 | max_tries = int(st.slider( 104 | 'Number of generation tries:', 105 | min_value=1, max_value=10, value=3)) 106 | 107 | examples = [] 108 | temp_delta = final_temp - start_temp 109 | temp_step = 0 if max_tries == 1 else temp_delta / (max_tries - 1.0) 110 | test_case = {'question':'', 'query':query, 'db_id':db_id} 111 | reorder = False if 'order by' in query.lower() else True 112 | 113 | coder = codexdb.code.PythonGenerator( 114 | catalog, examples, nr_samples, 115 | prompt_style, model_id, 116 | id_case=id_case, 117 | mod_start=mod_start, 118 | mod_between=mod_between, 119 | mod_end=mod_end) 120 | engine = codexdb.engine.PythonEngine( 121 | catalog, id_case) 122 | 123 | 124 | if st.button('Generate Code'): 125 | 126 | if condition > 1: 127 | sqlite_path = f'{args.data_dir}/database/{db_id}/{db_id}.sqlite' 128 | with sqlite3.connect(sqlite_path) as con: 129 | ref_result = pd.read_sql_query(query, con) 130 | 131 | for try_idx in range(max_tries): 132 | 133 | st.subheader(f'Trial {(try_idx+1)} ...') 134 | 135 | db_id = test_case['db_id'] 136 | schema = catalog.schema(db_id) 137 | files = catalog.files(db_id) 138 | db_dir = catalog.db_dir(db_id) 139 | prompt = coder.get_prompt(schema, db_dir, files, '', query) 140 | 141 | temperature = start_temp + temp_step * try_idx 142 | gen_stats, code = coder.generate(test_case, temperature) 143 | st.write('Code Generated by CARD:') 144 | st.code(code, language='python') 145 | 146 | if condition > 0: 147 | executed, codb_result, elapsed_s = engine.execute(db_id, code, 30) 148 | if condition > 1: 149 | comparable, nr_diffs, similarity = 
codexdb.solve.result_cmp( 150 | ref_result, codb_result, reorder) 151 | with st.expander(f'Result Similarity: {similarity}'): 152 | st.write('Reference Result:') 153 | st.dataframe(ref_result) 154 | st.write('CARD Result:') 155 | st.dataframe(codb_result) 156 | 157 | with st.expander('Input Prompt'): 158 | st.code(prompt, language='python') 159 | 160 | if (condition == 1 and executed) or \ 161 | (condition == 2 and similarity >= 1.0): 162 | st.write('Termination Criterion Satisfied.') 163 | break 164 | -------------------------------------------------------------------------------- /prompt_collection: -------------------------------------------------------------------------------- 1 | """ 2 | This Python program answers the query "What is the average age of customers?" on the following tables: 3 | - Table customers, columns customer_id, name, age, zip, stored in 'customer.csv'. 4 | The program performs the followings steps: 5 | 1. Import libraries for efficient data processing. 6 | 2. Load data for all tables. 7 | 3. Process the query. 8 | 4. Enable display for all rows and columns, strings of infinite length. 9 | 5. Print out the query result. 10 | """ 11 | 12 | --- Start of Python program --- 13 | 14 | --- 15 | 16 | """ 17 | File "customer.csv" stores table customer with columns customer_ID, age, name, zip. 18 | Write a Python program for data format transformation: 19 | 1. Load libraries for efficient data formats. 20 | 2. Read "customer.csv" from disk. 21 | 3. Change into more efficient format. 22 | 4. Write changed file to hard disk. 23 | 5. Print only the name of the new file. 24 | 6. --- End of Python program --- 25 | """ 26 | 27 | --- Begin of Python program --- 28 | 29 | --- 30 | Write a script for executing a C program, stored in file "testProgram", on Linux (no prompts): 31 | 32 | --- Script starts here --- 33 | 34 | --- 35 | Write a script for executing a Python program, stored in file "helloWorld", with Linux (no prompts): 36 | 37 | --- Script starts here --- 38 | --- 39 | """ 40 | This Python program answers the query "What is the average age of customers?" on the following tables: 41 | - Table customers, columns customer_id, name, age, zip, stored in 'customer.csv'. 42 | The program performs the followings steps: 43 | 1. Import libraries for efficient data processing. 44 | 2. Load data for all tables. 45 | 3. Process the query. 46 | 4. Print out the query result. 47 | """ 48 | 49 | --- 50 | 51 | # Answer the query "What is the average age of customers?" on the following tables: 52 | # - Table customers, columns customer_id, name, age, zip, stored in 'customer.csv'. 53 | 54 | #!/bin/bash 55 | 56 | --- 57 | 58 | """ 59 | Write Python program to transform 'customer.csv' into a more efficient file format. 60 | 1. Read data from 'customer.csv'. 61 | 2. Transform into efficient format for large-scale processing. 62 | 3. Write transformed file to hard disk. 63 | 4. Print out name of new file only. 64 | """ 65 | 66 | --- 67 | 68 | ///// Translate query from SQL to C++. 69 | /// SQL query: 70 | // Table lineorder, columns lo_extendedprice (numeric), lo_discount (numeric), lo_quantity (integer) 71 | // Table date, columns d_year (integer) 72 | select sum(lo_extendedprice*lo_discount) as revenue 73 |  from lineorder, date 74 |  where lo_orderdate = d_datekey 75 |  and d_year = 1993 76 |  and lo_discount between 2 - 1 77 |  and 2 + 1 and lo_quantity < 24; 78 | 79 | /// C++ query: 80 | // 1. Create index. 81 | // 2. Print query result. 
82 | std::vector lineorder; 83 | std::vector date; 84 | 85 | // 1. Create index. 86 | std::vector::iterator lo_it; 87 | std::vector::iterator d_it; 88 | 89 | for (lo_it = lineorder.begin(); lo_it != lineorder.end(); ++lo_it) { 90 |     for (d_it = date.begin(); d_it != date.end(); ++d_it) { 91 |         if (lo_it->lo_orderdate == d_it->d_datekey) { 92 |             lo_it->d_it = d_it; 93 |         } 94 |     } 95 | } 96 | 97 | // 2. Print query result. 98 | double revenue = 0; 99 | for (lo_it = lineorder.begin(); lo_it != lineorder.end(); ++lo_it) { 100 |     if (lo_it->d_it->d_year == 1993) { 101 |         if (lo_it->lo_discount >= 2 - 1 && lo_it->lo_discount <= 2 + 1) { 102 |             if (lo_it->lo_quantity < 24) { 103 |                 revenue += lo_it->lo_extendedprice * lo_it->lo_discount; 104 |             } 105 |         } 106 |     } 107 | } 108 | std::cout << "revenue = " << revenue << std::endl; 109 | --- 110 | Programming and scripting languages supported by Codex: 111 | 1. 112 | 113 | --- 114 | 115 | # Create an index on the following table: 116 | # Table customer, columns name, customer_id, age, zip, stored in 'customer.csv'. 117 | # Execute the following steps: 118 | # 1. Read customer.csv from disk. 119 | # 2. Create a B+ tree index on the data. 120 | # 3. Write index to file on hard disk. 121 | # 4. Print out name of index file only. 122 | 123 | --- 124 | # This Python program indexes the file customer.csv. The program has no input arguments. 125 | # Table customer, columns name, customer_id, age, zip, stored in 'customer.csv'. 126 | # Execute the following steps: 127 | # 1. Read customer.csv from disk. 128 | # 2. Create a B+ tree index on the data. 129 | # 3. Write index to file on hard disk. 130 | # 4. Print out name of index file only. 131 | 132 | 133 | --- 134 | # This Python program indexes the file customer.csv. The program has no input arguments. 135 | # Table customer, columns name, customer_id, age, zip, stored in customer.csv. 136 | # Execute the following steps: 137 | # 1. Read customer.csv from disk. 138 | # 2. Create a B+ tree index on the data. 139 | # 3. Write index to file on hard disk. 140 | # 4. Print out name of index file only. 141 | 142 | --- 143 | """ 144 | This Python program answers the query "What is the total number of singers?" on the following tables: 145 | Table stadium with columns Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average, stored in stadium.csv. 146 | Table singer with columns Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male, stored in singer.csv. 147 | Table concert with columns concert_ID, concert_Name, Theme, Stadium_ID, Year, stored in concert.csv. 148 | Table singer_in_concert with columns concert_ID, Singer_ID, stored in singer_in_concert.csv. 149 | 1. Load data for all relevant tables. 150 | 2. Calculate the answer to the query. 151 | 3. Write query results to file 'result.csv'. 152 | Generate code to reduce main memory footprint as much as possible. 153 | """ 154 | --- 155 | 1. Import library for processing tabular data in parallel. 156 | 1. Import library for processing tabular data on GPU. 157 | 1. Import library for processing tabular data on TPU. 
158 | (with temperature 0.2 - no start/end tags, 400 tokens) -------------------------------------------------------------------------------- /src/codexdb/bench/scale.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Jan 29, 2022 3 | 4 | @author: immanueltrummer 5 | ''' 6 | import argparse 7 | import codexdb.catalog 8 | import codexdb.engine 9 | import json 10 | import math 11 | import os 12 | import pandas as pd 13 | import time 14 | 15 | def get_code(language, test_case): 16 | """ Extract code in specified language from test case. 17 | 18 | Args: 19 | language: extract code in this language from case 20 | test_case: test case containing code to execute 21 | 22 | Returns: 23 | code in specified language extracted from test case 24 | """ 25 | if language == 'python': 26 | return test_case['code'] 27 | elif language == 'sql': 28 | return test_case['query'] 29 | else: 30 | raise ValueError(f'Unknown language: {language}') 31 | 32 | def get_engine(language): 33 | """ Returns processing engine for given language. 34 | 35 | Args: 36 | language: create engine processing code in this language 37 | 38 | Returns: 39 | an engine that can process the given language 40 | """ 41 | if language == 'python': 42 | return codexdb.engine.PythonEngine(catalog) 43 | elif language == 'sql': 44 | return codexdb.engine.SqliteEngine(catalog) 45 | else: 46 | raise ValueError(f'Unknown implementation language: {args.language}!') 47 | 48 | def scale_data(source_path, factor, target_path): 49 | """ Duplicates rows in source file by given factor. 50 | 51 | Args: 52 | source_path: read data from this file 53 | factor: duplicate rows by this factor 54 | target_path: write scaled data here 55 | """ 56 | os.system(f'cp {source_path} /tmp/scaled1') 57 | nr_iterations = math.ceil(math.log(factor, 2)) 58 | for i in range(nr_iterations): 59 | print(f'Doubling rows - iteration {i} ...') 60 | # Double the number of rows (without header) 61 | os.system(f'cat /tmp/scaled1 > /tmp/scaled2') 62 | os.system(f'tail -n +2 /tmp/scaled1 >> /tmp/scaled2') 63 | os.system('cp /tmp/scaled2 /tmp/scaled1') 64 | os.system(f'cp /tmp/scaled1 {target_path}') 65 | 66 | def scale_tables(catalog, db_id, factor, code): 67 | """ Scale up data size of tables in database by given factor. 68 | 69 | Args: 70 | catalog: contains information on database schema 71 | db_id: scale up tables of this database 72 | factor: scale up data size by approximately this factor 73 | code: code referencing original data files 74 | 75 | Returns: 76 | code referencing scaled data files, byte sizes, #rows 77 | """ 78 | scaled_code = code 79 | schema = catalog.schema(db_id) 80 | tables = schema['table_names_original'] 81 | table_byte_sizes = [] 82 | table_nr_rows = [] 83 | for table in tables: 84 | original_file = catalog.file_name(db_id, table) 85 | original_path = catalog.file_path(db_id, table) 86 | scaled_file = f'xxl_{original_file}' 87 | catalog.assign_file(db_id, table, scaled_file) 88 | scaled_path = catalog.file_path(db_id, table) 89 | scale_data(original_path, factor, scaled_path) 90 | scaled_code = scaled_code.replace(original_file, scaled_file) 91 | byte_size = os.path.getsize(scaled_path) 92 | nr_rows = sum(1 for _ in open(scaled_path)) 93 | table_byte_sizes += [byte_size] 94 | table_nr_rows += [nr_rows] 95 | return scaled_code, table_byte_sizes, table_nr_rows 96 | 97 | def unscale_tables(catalog, db_id): 98 | """ Replace scaled tables by the original. 
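The scaled file assignments are simply deleted from the catalog's table-to-file mapping, so that subsequent lookups fall back to the original data files.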
99 | 100 | Args: 101 | catalog: information on the database schema 102 | db_id: unscale all tables in this database 103 | """ 104 | schema = catalog.schema(db_id) 105 | tables = schema['table_names_original'] 106 | for table in tables: 107 | key = (db_id, table) 108 | del catalog.table_to_file[key] 109 | 110 | def test_performance(engine, db_id, factor, code, timeout_s): 111 | """ Measure performance when processing code on given engine. 112 | 113 | Args: 114 | engine: use this engine for code processing 115 | db_id: execute code on this database 116 | factor: scale number of rows in tables by this factor 117 | code: measure performance when executing this code 118 | timeout_s: timeout in seconds for each execution 119 | 120 | Returns: 121 | performance and size statistics 122 | """ 123 | print('Starting data scaling ...') 124 | scaled_code, byte_sizes, row_sizes = scale_tables( 125 | catalog, db_id, factor, code) 126 | print('Scaling finished - starting measurements ...') 127 | start_s = time.time() 128 | _, _, stats = engine.execute(db_id, scaled_code, timeout_s) 129 | total_s = time.time() - start_s 130 | print('Execution finished - unscaling tables ...') 131 | unscale_tables(catalog, db_id) 132 | stats['total_s'] = total_s 133 | stats['byte_sizes'] = byte_sizes 134 | stats['row_sizes'] = row_sizes 135 | return stats 136 | 137 | 138 | if __name__ == '__main__': 139 | 140 | parser = argparse.ArgumentParser() 141 | parser.add_argument('data_dir', type=str, help='Data directory') 142 | parser.add_argument('language', type=str, help='Implementation language') 143 | parser.add_argument('test_path', type=str, help='Path to file with tests') 144 | parser.add_argument('nr_tests', type=int, help='How many test cases') 145 | parser.add_argument('timeout_s', type=int, help='Timeout in seconds') 146 | args = parser.parse_args() 147 | 148 | catalog = codexdb.catalog.DbCatalog(args.data_dir) 149 | with open(args.test_path) as file: 150 | test_cases = json.load(file) 151 | engine = get_engine(args.language) 152 | 153 | nr_all_tests = len(test_cases) 154 | nr_tests = min(args.nr_tests, nr_all_tests) 155 | factors = [1000, 1000000] 156 | nr_factors = len(factors) 157 | 158 | times_path = 'times.csv' 159 | results_path = 'stats.json' 160 | if os.path.exists(times_path): 161 | raise ValueError(f'Error - {times_path} exists!') 162 | if os.path.exists(results_path): 163 | raise ValueError(f'Error - {results_path} exists!') 164 | 165 | results = [] 166 | for test_case_id in range(nr_tests): 167 | test_case_key = str(test_case_id) 168 | tries = test_cases[test_case_key] 169 | test_case = tries[-1] 170 | for factor in factors: 171 | print(f'Treating test case {test_case_id}, factor {factor}') 172 | try: 173 | if test_case['similarity'] == 1.0 or args.language == 'sql': 174 | db_id = test_case['schema']['db_id'] 175 | code = get_code(args.language, test_case) 176 | stats = test_performance( 177 | engine, db_id, factor, 178 | code, args.timeout_s) 179 | stats['scaling_factor'] = factor 180 | print(f'Run statistics: {stats}') 181 | results += [stats] 182 | else: 183 | results += [{'total_s':-1, 'scaling_factor':factor}] 184 | except Exception as e: 185 | print(f'Exception: {e}') 186 | 187 | by_factors = {} 188 | for idx, factor in enumerate(factors): 189 | factor_times = [r['total_s'] for r in results[idx::nr_factors]] 190 | by_factors[factor] = factor_times 191 | times_df = pd.DataFrame(by_factors) 192 | times_df.to_csv(times_path, index=False) 193 | print(times_df) 194 | 195 | with open(results_path, 'w') as 
file: 196 | json.dump(results, file) -------------------------------------------------------------------------------- /src/codexdb/engine.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Oct 3, 2021 3 | 4 | @author: immanueltrummer 5 | ''' 6 | import abc 7 | import os 8 | import pandas as pd 9 | import subprocess 10 | import sys 11 | import sqlite3 12 | import time 13 | 14 | class ExecutionEngine(abc.ABC): 15 | """ Executes code in different languages. """ 16 | 17 | def __init__(self, catalog): 18 | """ Initialize with database catalog and variables. 19 | 20 | Args: 21 | catalog: informs on database schema and file locations 22 | """ 23 | self.catalog = catalog 24 | self.tmp_dir = os.environ['CODEXDB_TMP'] 25 | self.result_path = f'{self.tmp_dir}/result.csv' 26 | 27 | @abc.abstractmethod 28 | def execute(self, db_id, code, timeout_s): 29 | """ Execute code written in specified language. 30 | 31 | Args: 32 | db_id: code references data in this database 33 | code: execute this code 34 | timeout_s: execution timeout in seconds 35 | 36 | Returns: 37 | Boolean success flag, output, execution statistics 38 | """ 39 | raise NotImplementedError() 40 | 41 | def _clean(self): 42 | """ Cleans up working directory before execution. 43 | 44 | The result file may have been generated either as 45 | file or as directory. This handles multiple cases. 46 | """ 47 | subprocess.run(['rm', f'{self.tmp_dir}/result.csv/*']) 48 | subprocess.run(['rm', '-d', f'{self.tmp_dir}/result.csv']) 49 | subprocess.run(['rm', f'{self.tmp_dir}/result.csv']) 50 | 51 | def _copy_db(self, db_id): 52 | """ Copies data to a temporary directory. 53 | 54 | Args: 55 | db_id: database ID 56 | """ 57 | src_dir = self.catalog.db_dir(db_id) 58 | for tbl_file in self.catalog.files(db_id): 59 | src_path = f'{src_dir}/{tbl_file}' 60 | if self.id_case: 61 | cmd = f'sudo cp -r {src_path} {self.tmp_dir}' 62 | os.system(cmd) 63 | else: 64 | with open(src_path) as file: 65 | lines = file.readlines() 66 | lines[0] = lines[0].lower() 67 | to_path = f'{self.tmp_dir}/{tbl_file.lower()}' 68 | with open(to_path, 'w') as file: 69 | for line in lines: 70 | file.write(line) 71 | 72 | def _expand_paths(self, db_id, code): 73 | """ Expand relative paths to data files in code. 74 | 75 | Args: 76 | db_id: database identifier 77 | code: generated code 78 | 79 | Returns: 80 | code after expanding paths 81 | """ 82 | for file in self.catalog.files(db_id): 83 | for quote in ['"', "'"]: 84 | file_path = f'{quote}{file}{quote}' 85 | full_path = f'{quote}{self.tmp_dir}/{file}{quote}' 86 | code = code.replace(file_path, full_path) 87 | 88 | prefix = f"import os\nos.chdir('{self.tmp_dir}')\n" 89 | return prefix + code 90 | 91 | def _write_file(self, filename, code): 92 | """ Write code into file in temporary directory. 93 | 94 | Args: 95 | db_id: database ID 96 | filename: name of code file 97 | code: write code into this file 98 | 99 | """ 100 | file_path = f'{self.tmp_dir}/{filename}' 101 | with open(file_path, 'w') as file: 102 | file.write(code) 103 | 104 | 105 | class PythonEngine(ExecutionEngine): 106 | """ Executes Python code. """ 107 | 108 | def __init__(self, catalog, id_case): 109 | """ Initialize with database catalog and paths. 
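The Python interpreter used for code execution is taken from the CODEXDB_PYTHON environment variable.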
110 | 111 | Args: 112 | catalog: informs on database schema and file locations 113 | id_case: whether to consider letter case for identifiers 114 | """ 115 | super().__init__(catalog) 116 | self.id_case = id_case 117 | self.python_path = os.environ['CODEXDB_PYTHON'] 118 | 119 | def execute(self, db_id, code, timeout_s): 120 | """ Execute code written in specified language. 121 | 122 | Args: 123 | db_id: code references data in this database 124 | code: code for processing query 125 | timeout_s: execution timeout in seconds 126 | 127 | Returns: 128 | Boolean success flag, output, execution statistics 129 | """ 130 | self._clean() 131 | self._copy_db(db_id) 132 | start_s = time.time() 133 | success, output, stats = self._exec_python(db_id, code, timeout_s) 134 | total_s = time.time() - start_s 135 | stats['total_s'] = total_s 136 | return success, output, stats 137 | 138 | def _exec_python(self, db_id, code, timeout_s): 139 | """ Execute Python code and return generated output. 140 | 141 | Args: 142 | db_id: database identifier 143 | code: Python code to execute 144 | timeout_s: execution timeout in seconds 145 | 146 | Returns: 147 | Success flag, output, and execution statistics 148 | """ 149 | filename = 'execute.py' 150 | code = self._expand_paths(db_id, code) 151 | print('--- EXECUTED CODE ---') 152 | print(code) 153 | print('--- (EXECUTED CODE) ---') 154 | self._write_file(filename, code) 155 | exe_path = f'{self.tmp_dir}/{filename}' 156 | cmd_parts = ['timeout', str(timeout_s), self.python_path, exe_path] 157 | sub_comp = subprocess.run(cmd_parts, capture_output=True) 158 | success = False if sub_comp.returncode > 0 else True 159 | if not success: 160 | print(f'Python stdout: {sub_comp.stdout}') 161 | print(f'Python stderr: {sub_comp.stderr}') 162 | output = pd.DataFrame([[]]) 163 | else: 164 | try: 165 | output = pd.read_csv(self.result_path) 166 | except: 167 | e = sys.exc_info()[0] 168 | print(f'Exception while reading result file: {e}') 169 | output = pd.DataFrame([[]]) 170 | return success, output, {} 171 | 172 | 173 | class SqliteEngine(ExecutionEngine): 174 | """ SQL execution engine using SQLite. """ 175 | 176 | def __init__(self, catalog): 177 | """ Initialize with given catalog. 178 | 179 | Args: 180 | catalog: information about database schemata 181 | """ 182 | super().__init__(catalog) 183 | 184 | def execute(self, db_id, sql, timeout_s): 185 | """ Execute given SQL query. 186 | 187 | Args: 188 | db_id: ID of database (in catalog) 189 | sql: SQL query to execute on database 190 | timeout_s: execution timeout in seconds 191 | 192 | Returns: 193 | Success flag, output, and execution statistics 194 | """ 195 | self._prepare_db(db_id) 196 | return self._execute(db_id, sql, timeout_s) 197 | 198 | def _execute(self, db_id, sql, timeout_s): 199 | """ Execute given SQL query on specified database. 
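Besides being returned, the query result is also written as CSV to the temporary result file.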
200 | 201 | Args: 202 | db_id: ID of database in catalog 203 | sql: execute this SQL query 204 | timeout_s: execution timeout in seconds 205 | 206 | Returns: 207 | success flag, result, and execution statistics 208 | """ 209 | db_dir = self.catalog.db_dir(db_id) 210 | db_path = f'{db_dir}/db.db' 211 | try: 212 | with sqlite3.connect(db_path) as connection: 213 | start_s = time.time() 214 | result = pd.read_sql(sql, connection) 215 | total_s = time.time() - start_s 216 | print(f'Query Result Info: {result.info()}') 217 | result.to_csv(self.result_path) 218 | return True, result, {'execution_s':total_s} 219 | except Exception as e: 220 | print(f'Exception: {e}') 221 | return False, pd.DataFrame(), {'execution_s':-1} 222 | 223 | def _prepare_db(self, db_id): 224 | """ Prepare database for querying. 225 | 226 | Args: 227 | db_id: database ID in catalog 228 | """ 229 | db_dir = self.catalog.db_dir(db_id) 230 | db_path = f'{db_dir}/db.db' 231 | if os.path.exists(db_path): 232 | subprocess.run(['rm', db_path]) 233 | with sqlite3.connect(db_path) as connection: 234 | schema = self.catalog.schema(db_id) 235 | tables = schema['table_names_original'] 236 | for table in tables: 237 | file_name = self.catalog.file_name(db_id, table) 238 | table_path = f'{db_dir}/{file_name}' 239 | df = pd.read_csv(table_path) 240 | df.columns = df.columns.str.replace(' ', '_') 241 | df.to_sql(table, connection) -------------------------------------------------------------------------------- /src/codexdb/bench/plot.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Jan 28, 2022 3 | 4 | @author: immanueltrummer 5 | ''' 6 | import argparse 7 | import collections 8 | import json 9 | import re 10 | import statistics 11 | 12 | 13 | model_ids = ['gpt-3.5-turbo', 'gpt-4o'] 14 | 15 | 16 | def get_run_id(model_id, prompt_style, nr_samples): 17 | """ Generate string ID characterizing run settings. 18 | 19 | Args: 20 | model_id: ID of the LLM used. 21 | prompt_style: style of prompt. 22 | nr_samples: how many samples for few-shot learning. 23 | 24 | Returns: 25 | string ID of run (used in result file names). 26 | """ 27 | return f'{model_id}_{prompt_style}_{nr_samples}' 28 | 29 | 30 | def agg_all(run_dir, solved, map_fct, agg_fct): 31 | """ Calculate aggregates over all runs. 32 | 33 | Args: 34 | run_dir: directory containing benchmark results 35 | solved: whether to consider only successful tries 36 | map_fct: maps tries to values for aggregation 37 | agg_fct: function used to aggregate values 38 | 39 | Returns: 40 | aggregated value over all tries 41 | """ 42 | values = [] 43 | #for model_id in ['cushman-codex', 'davinci-codex']: 44 | #for model_id in ['code-cushman-001', 'code-davinci-002']: 45 | for model_id in model_ids: 46 | # for prompt_style in ['question', 'query', 'plan']: 47 | for prompt_style in ['plan']: 48 | for nr_samples in [0, 2, 4]: 49 | run_id = get_run_id(model_id, prompt_style, nr_samples) 50 | result_path = f'{run_dir}/results_{run_id}.json' 51 | try: 52 | with open(result_path) as file: 53 | data = json.load(file) 54 | for tries in data.values(): 55 | for one_try in tries: 56 | if not solved or one_try['similarity']==1.0: 57 | value = map_fct(one_try) 58 | if value is not None: 59 | values.append(value) 60 | except Exception as e: 61 | pass 62 | #print(e) 63 | 64 | return agg_fct(values) 65 | 66 | 67 | def analyze_code(run_dir): 68 | """ Analyze generated code. 
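For one fixed run configuration, this counts which data processing libraries are imported and which print statement patterns appear in the last try of each test case.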
69 | 70 | Args: 71 | run_dir: directory containing benchmark results 72 | """ 73 | print('Analyzing generated code ...') 74 | model_id = model_ids[-1] 75 | prompt_style = 'plan' 76 | nr_samples = 2 77 | run_id = get_run_id(model_id, prompt_style, nr_samples) 78 | result_path = f'{run_dir}/results_{run_id}.json' 79 | with open(result_path) as file: 80 | data = json.load(file) 81 | 82 | lib_count = collections.defaultdict(lambda:0) 83 | tries_by_case = data.values() 84 | libraries = [ 85 | 'csv', 'pandas', 'vaex', 'datatable'] 86 | for tries in tries_by_case: 87 | code = tries[-1]['code'] 88 | for library in libraries: 89 | if f'import {library}' in code: 90 | lib_count[library] += 1 91 | 92 | reg_count = collections.defaultdict(lambda:0) 93 | reg_exps = [ 94 | 'print\(\'Done\.\'\)', 'print\(\'Done\'\)', 95 | 'print\("Done"\)', 'print\("Done."\)', 96 | 'print\("done"\)', 'print\("done."\)', 97 | 'print\(\'done\'\)', 'print\(\'done.\'\)', 98 | 'print\(([a-zA-Z_\(\)\.])+\)', 99 | 'print\(["|\'].+["|\']\)', 'print'] 100 | for tries in tries_by_case: 101 | code = tries[-1]['code'] 102 | for reg_exp in reg_exps: 103 | if re.search(reg_exp, code): 104 | reg_count[reg_exp] += 1 105 | 106 | print(lib_count) 107 | print(reg_count) 108 | 109 | 110 | def analyze_training(run_dir): 111 | """ Analyze training process. """ 112 | print('Analyzing training process ...') 113 | result_path = f'{run_dir}/train_plain.json' 114 | with open(result_path) as file: 115 | data = json.load(file) 116 | 117 | tries_by_case = data.values() 118 | nr_solved = 0 119 | for tries in tries_by_case: 120 | if [t for t in tries if t['similarity']==1.0]: 121 | nr_solved += 1 122 | print(f'Cases solved: {nr_solved}') 123 | 124 | solved_by = [] 125 | for max_nr_tries in range(11): 126 | count = 0 127 | for tries in tries_by_case: 128 | nr_tries = len(tries) 129 | last_idx = min(max_nr_tries, nr_tries) 130 | count += len([t for t in tries[:last_idx] if t['similarity']==1.0]) 131 | solved_by += [count] 132 | solved_by = list(enumerate(solved_by)) 133 | solved_by = ' '.join([str(s) for s in solved_by]) 134 | print(f'Nr. solved by try: {solved_by}') 135 | 136 | total_s = 0 137 | for tries in tries_by_case: 138 | for one_try in tries: 139 | generation_s = one_try['gen_total_s'] 140 | execution_s = one_try['execution_s']['total_s'] 141 | total_s += generation_s 142 | total_s += execution_s 143 | print(f'Total training time: {total_s} s') 144 | 145 | nr_tries_by_case = [len(tries) for tries in tries_by_case] 146 | print('Analyzing number of tries') 147 | for fct in [min, statistics.mean, statistics.median, max]: 148 | agg = fct(nr_tries_by_case) 149 | print(f'Nr. tries {fct}: {agg}') 150 | 151 | 152 | def count_solved(results, must_contain, multiplicity): 153 | """ Count the number of solved test cases. 
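A test case is counted once for each try that reaches a similarity of 1.0 and whose code contains every required string at least the specified number of times.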
154 | 155 | Args: 156 | results: results from one run 157 | must_contain: strings that must appear in code, separated by colon 158 | multiplicity: minimal number of occurrences for each required string 159 | 160 | Returns: 161 | number of solved test cases 162 | """ 163 | if must_contain: 164 | required = list(zip(must_contain.split(':'), multiplicity.split(':'))) 165 | else: 166 | required = [] 167 | 168 | nr_solved = 0 169 | for results in results.values(): 170 | for r in results: 171 | if r['similarity'] == 1.0: 172 | valid = True 173 | for required_string, required_number in required: 174 | code = r['code'] 175 | if code.count(required_string) < int(required_number): 176 | valid = False 177 | 178 | if valid: 179 | nr_solved += 1 180 | return nr_solved 181 | 182 | 183 | def generate_plot(run_dir, y_fct): 184 | """ Generates commands for PGF group plot. 185 | 186 | Args: 187 | run_dir: source data for plot 188 | y_fct: how to calculate 189 | 190 | Returns: 191 | list of plot groups 192 | """ 193 | plots = [] 194 | for model_id in model_ids: 195 | plot = [] 196 | for prompt_style in ['question', 'query', 'plan']: 197 | # for prompt_style in ['plan']: 198 | line = [] 199 | for nr_samples in [0, 2, 4]: 200 | run_id = get_run_id(model_id, prompt_style, nr_samples) 201 | result_path = f'{run_dir}/results_{run_id}.json' 202 | try: 203 | with open(result_path) as file: 204 | data = json.load(file) 205 | y_coordinate = y_fct(data) 206 | point = f'({nr_samples}, {y_coordinate})' 207 | line += [point] 208 | except Exception as e: 209 | #print(f'Exception for {result_path}: {e}') 210 | line += ['(-1, -1)'] 211 | plot += ['\\addplot coordinates {' + ' '.join(line) + '};'] 212 | plots += ['\n'.join(plot)] 213 | return plots 214 | 215 | 216 | def median(results, extractor, solved): 217 | """ Calculate median of numerical field over all tries. 218 | 219 | Args: 220 | results: dictionary mapping test case IDs to lists of tries 221 | extractor: function for extracting value of interest 222 | solved: whether to consider solved test cases only 223 | 224 | Returns: 225 | average of field over all relevant tries 226 | """ 227 | values = [] 228 | for tries in results.values(): 229 | for one_try in tries: 230 | if not solved or one_try['similarity'] == 1.0: 231 | value = extractor(one_try) 232 | if value is not None: 233 | values += [value] 234 | return statistics.median(values) if values else -1 235 | 236 | 237 | def print_aggs(run_dir, solved, map_fct): 238 | """ Print out aggregates for tries in directory. 239 | 240 | Args: 241 | run_dir: directory containing benchmark results 242 | solved: whether to only consider successful tries 243 | map_fct: maps tries to numbers for aggregation 244 | """ 245 | print(f'Printint aggregates for {run_dir} (solved: {solved}):') 246 | for agg_fct in [min, statistics.median, max]: 247 | agg_val = agg_all(run_dir, solved, map_fct, agg_fct) 248 | print(f'{agg_fct}:{agg_val}') 249 | print('\n' * 3) 250 | 251 | 252 | def print_group(plots): 253 | """ Print out a group of plots. 
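Each plot is preceded by a '---' separator line and the whole group is terminated by a '###' marker.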
254 | 255 | Args: 256 | plots: list of plots 257 | """ 258 | for plot in plots: 259 | print('---') 260 | print(plot) 261 | print() 262 | print('###') 263 | 264 | 265 | if __name__ == '__main__': 266 | 267 | parser = argparse.ArgumentParser() 268 | parser.add_argument('run_dir', type=str, help='Path to directory with runs') 269 | parser.add_argument('must_contain', type=str, help='Code must contain this') 270 | parser.add_argument('multiplicity', type=str, help='Minimal #occurrences') 271 | args = parser.parse_args() 272 | 273 | print('Counting number of solved test cases:') 274 | count_fct = lambda d:count_solved(d, args.must_contain, args.multiplicity) 275 | count_plots = generate_plot(args.run_dir, count_fct) 276 | print_group(count_plots) 277 | 278 | print('CODE LENGTH') 279 | map_fct = lambda x:len(x['code']) 280 | y_fct = lambda d:median(d, map_fct, True) 281 | print_group(generate_plot(args.run_dir, y_fct)) 282 | print_aggs(args.run_dir, True, map_fct) 283 | 284 | print('QUERY LENGTH') 285 | map_fct = lambda x:len(x['query']) 286 | y_fct = lambda d:median(d, map_fct, False) 287 | print_group(generate_plot(args.run_dir, y_fct)) 288 | print_aggs(args.run_dir, False, map_fct) 289 | 290 | print('GENERATION TIMES') 291 | map_fct = lambda x:x['gen_stats']['last_request_s'] if 'gen_stats' in x and 'last_request_s' in x['gen_stats'] else None 292 | y_fct = lambda d:median(d, map_fct, False) 293 | print_group(generate_plot(args.run_dir, y_fct)) 294 | print_aggs(args.run_dir, False, map_fct) 295 | 296 | print('EXECUTION TIMES') 297 | map_fct = lambda x:x['execution_s']['total_s'] 298 | y_fct = lambda d:median(d, map_fct, True) 299 | print_group(generate_plot(args.run_dir, y_fct)) 300 | print_aggs(args.run_dir, True, map_fct) 301 | 302 | print('PROMPT TOKENS') 303 | map_fct = lambda x:x['gen_stats']['prompt_tokens'] 304 | y_fct = lambda d:median(d, map_fct, False) 305 | print_group(generate_plot(args.run_dir, y_fct)) 306 | print_aggs(args.run_dir, False, map_fct) 307 | 308 | print('COMPLETION TOKENS') 309 | map_fct = lambda x:x['gen_stats']['completion_tokens'] 310 | y_fct = lambda d:median(d, map_fct, False) 311 | print_group(generate_plot(args.run_dir, y_fct)) 312 | print_aggs(args.run_dir, False, map_fct) 313 | 314 | print('ANALYZING CODE') 315 | analyze_code(args.run_dir) 316 | print('\n\n\n') 317 | 318 | print('ANALYZING TRAINING') 319 | analyze_training(args.run_dir) -------------------------------------------------------------------------------- /src/codexdb/solve.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Jan 3, 2022 3 | 4 | @author: immanueltrummer 5 | ''' 6 | import argparse 7 | import codexdb.catalog 8 | import codexdb.code 9 | import codexdb.engine 10 | import contextlib 11 | import json 12 | import os 13 | import openai 14 | import pandas as pd 15 | import time 16 | 17 | def extract_samples(catalog, path_to_results): 18 | """ Extracts completion examples from prior results file. 
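Only tries that solved their test case (similarity of 1.0) are kept; missing schema and file information is filled in from the catalog.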
19 | 20 | Args: 21 | catalog: database catalog with schema information 22 | path_to_results: path to prior results file 23 | 24 | Returns: 25 | list of extracted examples 26 | """ 27 | with open(path_to_results) as file: 28 | prior_results = json.load(file) 29 | 30 | examples = [] 31 | for cur_results in prior_results.values(): 32 | for r in cur_results: 33 | if r['similarity'] == 1.0: 34 | examples.append(r) 35 | 36 | for e in examples: 37 | if 'schema' not in e: 38 | db_id = e['db_id'] 39 | e['schema'] = catalog.schema(db_id) 40 | e['files'] = catalog.files(db_id) 41 | 42 | return examples 43 | 44 | def result_cmp(ref_output, cmp_output, reorder): 45 | """ Compares query result output against reference. 46 | 47 | Args: 48 | ref_output: reference query result 49 | cmp_output: compare this against reference 50 | reorder: whether to consider reordering 51 | 52 | Returns: 53 | Comparable flag, number of differences, similarity 54 | """ 55 | print(f'-- CodexDB output:\n{cmp_output}\n--\n') 56 | print(f'CodexDB Index: {cmp_output.index}') 57 | print(f'CodexDB info: {cmp_output.info()}') 58 | print(f'-- Reference output:\n{ref_output}\n--\n') 59 | print(f'Reference Index: {ref_output.index}') 60 | print(f'Reference info: {ref_output.info()}') 61 | 62 | ref_output.columns = range(ref_output.shape[1]) 63 | cmp_output.columns = range(cmp_output.shape[1]) 64 | try: 65 | print('Casting all columns to string type ...') 66 | ref_output = ref_output.astype(str) 67 | cmp_output = cmp_output.astype(str) 68 | 69 | print('Normalizing representation of integers ...') 70 | def to_int(float_str): 71 | """ Transforms rounded float values into integers. """ 72 | if float_str.endswith('.0'): 73 | return float_str[:-2] 74 | else: 75 | return float_str 76 | ref_output = ref_output.applymap(to_int) 77 | cmp_output = cmp_output.applymap(to_int) 78 | 79 | print('Normalizing representation of lists ...') 80 | def unwrap(cell): 81 | """ Unwrap elements from singleton lists. """ 82 | if isinstance(cell, list) and len(cell) == 1: 83 | return cell[0] 84 | else: 85 | return cell 86 | ref_output = ref_output.applymap(unwrap) 87 | cmp_output = cmp_output.applymap(unwrap) 88 | 89 | if reorder: 90 | print('Reordering Rows Before Comparison') 91 | nr_columns = len(ref_output.columns) 92 | column_idxs = list(range(nr_columns)) 93 | ref_output.sort_values(by=column_idxs, inplace=True) 94 | cmp_output.sort_values(by=column_idxs, inplace=True) 95 | 96 | ref_output.reset_index(drop=True, inplace=True) 97 | cmp_output.reset_index(drop=True, inplace=True) 98 | 99 | print(f'--- CodexDB column types:\n{cmp_output.dtypes}') 100 | print(f'--- CodexDB normalized output:\n{cmp_output}\n--\n') 101 | print(f'--- Reference column types:\n{ref_output.dtypes}') 102 | print(f'--- Normalized reference output:\n{ref_output}\n--\n') 103 | 104 | nr_ref_rows = ref_output.shape[0] 105 | nr_cmp_rows = cmp_output.shape[0] 106 | if nr_ref_rows == 0 and nr_cmp_rows == 0: 107 | diffs = pd.DataFrame() 108 | else: 109 | diffs = ref_output.compare(cmp_output, align_axis=0) 110 | print(f'-- Differences:\n{diffs}\n--\n') 111 | nr_diffs = diffs.shape[0] 112 | return True, nr_diffs, 1.0/(nr_diffs+1) 113 | except Exception as e: 114 | print('(Incomparable)') 115 | print(f'Exception: {e}') 116 | return False, -1, 0 117 | 118 | def solve(catalog, test_case, coder, engine, 119 | termination, max_tries, max_temperature): 120 | """ Solve given test case by generating code. 
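The generation temperature grows linearly with the number of tries, starting at zero, and the loop stops early once the termination criterion is satisfied.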
121 | 122 | Args: 123 | catalog: database catalog 124 | test_case: a natural language query 125 | coder: code generator to use 126 | engine: execution engine for code 127 | termination: criterion to advance to next case 128 | max_tries: maximal number of tries 129 | max_temperature: maximal temperature 130 | 131 | Returns: 132 | list of dictionaries with generated code and statistics 133 | """ 134 | db_id = test_case['db_id'] 135 | schema = catalog.schema(db_id) 136 | files = catalog.files(db_id) 137 | question = test_case['question'] 138 | query = test_case['query'] 139 | reorder = False if 'order by' in query.lower() else True 140 | temperature_step = max_temperature / max_tries 141 | print(f'Treating query {query}, question {question}.') 142 | 143 | results = [] 144 | for try_idx in range(max_tries): 145 | print("Waiting due to OpenAI's rate limit ...") 146 | time.sleep(3) 147 | print(f'Starting try number {try_idx} ...') 148 | gen_start_s = time.time() 149 | temperature = try_idx * temperature_step 150 | gen_stats, code = coder.generate(test_case, temperature) 151 | print(f'Generated code:\n-------\n{code}\n-------\n') 152 | print(f'Reference Query: "{query}"') 153 | gen_total_s = time.time() - gen_start_s 154 | executed, codb_result, elapsed_s = engine.execute(db_id, code, 30) 155 | print(f'CodexDB executed: {executed} in {elapsed_s}s') 156 | ref_output = pd.DataFrame(test_case['results']) 157 | comparable, nr_diffs, similarity = result_cmp( 158 | ref_output, codb_result, reorder) 159 | nr_tries = try_idx + 1 160 | results.append({ 161 | 'nr_tries':nr_tries, 'executed':executed, 'comparable':comparable, 162 | 'nr_diffs':nr_diffs, 'similarity':similarity, 163 | 'outsize':len(codb_result), 164 | 'question':question, 'query':query, 165 | 'db':db_id, 'schema':schema, 'files':files, 166 | 'code':code, 'gen_stats':gen_stats, 'gen_total_s':gen_total_s, 167 | 'execution_s':elapsed_s}) 168 | 169 | if (termination == 'executed' and executed) or \ 170 | (termination == 'solved' and similarity >= 1.0): 171 | print('Termination Criterion Satisfied.') 172 | break 173 | 174 | return results 175 | 176 | def main( 177 | data_dir, test_path, language, model_id, prompt_style, id_case, 178 | mod_start, mod_between, mod_end, sample_path, nr_samples, 179 | test_start, test_step, test_end, termination, max_tries, 180 | max_temperature, log_path, result_path): 181 | """ Try solving given test cases and write results to file. 
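All console output is redirected to the log file; results are written as a JSON dictionary mapping test case indexes to lists of tries. A typical invocation (with purely hypothetical paths and settings) might look as follows:
main('/data/spider', '/data/spider/dev.json', 'python', 'gpt-3.5-turbo', 'plan', True, '', '', '', '', 0, 0, 1, 10, 'solved', 3, 0.5, 'log.txt', 'results.json')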
182 | 183 | Args: 184 | data_dir: directory containing database 185 | test_path: path to file with test cases 186 | language: generate code in this language 187 | model_id: OpenAI engine for code generation 188 | prompt_style: choose prompt template 189 | id_case: whether to consider letter case of identifiers 190 | mod_start: modification at plan start 191 | mod_between: modifications between steps 192 | mod_end: modification at plan end 193 | sample_path: path to example library 194 | nr_samples: number of examples in prompt 195 | test_start: index of first test case 196 | test_step: gap between test case indexes 197 | test_end: index of last test case + 1 198 | termination: termination criterion 199 | max_tries: maximal tries per test case 200 | max_temperature: maximal temperature 201 | log_path: path for logging output 202 | result_path: path to result .json file 203 | """ 204 | catalog = codexdb.catalog.DbCatalog(data_dir) 205 | os.environ['KMP_DUPLICATE_LIB_OK']='True' 206 | 207 | with open(test_path) as file: 208 | test_cases = json.load(file) 209 | if language not in ['python', 'sql']: 210 | raise ValueError(f'Unknown implementation language: {language}!') 211 | examples = [] 212 | if sample_path: 213 | with open(sample_path) as file: 214 | examples = extract_samples(catalog, sample_path) 215 | if prompt_style not in ['question', 'query', 'plan', 'data']: 216 | raise ValueError(f'Unknown prompt style: {prompt_style}!') 217 | if termination not in ['executed', 'solved']: 218 | raise ValueError(f'Unknown termination criterion: {termination}') 219 | 220 | with open(log_path, 'w') as log_file: 221 | with contextlib.redirect_stdout(log_file): 222 | if language == 'python': 223 | coder = codexdb.code.PythonGenerator( 224 | catalog, examples, nr_samples, 225 | prompt_style, model_id, 226 | id_case=id_case, 227 | mod_start=mod_start, 228 | mod_between=mod_between, 229 | mod_end=mod_end) 230 | engine = codexdb.engine.PythonEngine( 231 | catalog, id_case) 232 | elif language == 'sql': 233 | coder = codexdb.code.SqlGenerator( 234 | catalog, examples, nr_samples, 235 | prompt_style, model_id) 236 | engine = codexdb.engine.SqliteEngine(catalog) 237 | 238 | idx_to_results = {} 239 | for i in range(test_start, test_end, test_step): 240 | print(f'Starting test case nr. 
{i} ...') 241 | test_case = test_cases[i] 242 | cur_results = solve( 243 | catalog, test_case, coder, engine, 244 | termination, max_tries, max_temperature) 245 | idx_to_results[i] = cur_results 246 | print(cur_results) 247 | 248 | with open(result_path, 'w') as results_file: 249 | json.dump(idx_to_results, results_file) 250 | 251 | if __name__ == '__main__': 252 | 253 | parser = argparse.ArgumentParser() 254 | parser.add_argument('ai_key', type=str, help='Key for OpenAI access') 255 | parser.add_argument('data_dir', type=str, help='Data directory') 256 | parser.add_argument('test_path', type=str, help='Path to test case file') 257 | parser.add_argument('language', type=str, help='Implementation language') 258 | parser.add_argument('model_id', type=str, help='ID of OpenAI model') 259 | parser.add_argument('prompt_style', type=str, help='Style of prompt') 260 | parser.add_argument('mod_start', type=str, help='Instructions at start') 261 | parser.add_argument('mod_between', type=str, help='Execute between steps') 262 | parser.add_argument('mod_end', type=str, help='Instructions at end') 263 | parser.add_argument('sample_path', type=str, help='Path to sample file') 264 | parser.add_argument('nr_samples', type=int, help='Number of samples in prompt') 265 | parser.add_argument('test_start', type=int, help='Index of first test case') 266 | parser.add_argument('test_step', type=int, help='Gap between test case indexes') 267 | parser.add_argument('test_end', type=int, help='Index of last test case +1') 268 | parser.add_argument('termination', type=str, help='Termination criterion') 269 | parser.add_argument('max_tries', type=int, help='Maximal number of tries') 270 | parser.add_argument('log_path', type=str, help='Redirect output here') 271 | parser.add_argument('result_path', type=str, help='Contains results') 272 | args = parser.parse_args() 273 | 274 | openai.api_key = args.ai_key 275 | main( 276 | args.data_dir, args.test_path, args.language, args.model_id, 277 | args.prompt_style, True, args.mod_start, args.mod_between, args.mod_end, 278 | args.sample_path, args.nr_samples, args.test_start, args.test_step, 279 | args.test_end, args.termination, args.max_tries, 0.5, 280 | args.log_path, args.result_path) -------------------------------------------------------------------------------- /src/codexdb/code.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Jan 17, 2022 3 | 4 | @author: immanueltrummer 5 | ''' 6 | import abc 7 | import codexdb.plan 8 | import numpy as np 9 | import openai.error 10 | import pandas as pd 11 | import random 12 | import re 13 | import time 14 | 15 | class CodeGenerator(abc.ABC): 16 | """ Generates code in different languages using OpenAI. """ 17 | 18 | def __init__(self, catalog, examples, nr_samples, prompt_style, model_id): 19 | """ Initializes with examples for few-shot learning. 20 | 21 | Args: 22 | catalog: database catalog 23 | examples: list of examples for few-shot learning. 24 | nr_samples: maximal number of examples to use. 25 | prompt_style: style of prompt to generate 26 | model_id: OpenAI model to use for generation 27 | """ 28 | self.catalog = catalog 29 | self.examples = examples 30 | self.nr_samples = nr_samples 31 | self.prompt_style = prompt_style 32 | self.ai_kwargs = {'model':model_id} 33 | self.code_prefix = '' 34 | self.code_suffix = '' 35 | 36 | def generate(self, test_case, temperature): 37 | """ Generate code to solve given test case. 
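The prompt concatenates sampled few-shot examples with the prompt for the current test case; the generated completion is wrapped into the generator's code prefix and suffix.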
38 | 39 | Args: 40 | test_case: generate code solving this test case 41 | temperature: degree of randomness during generation 42 | 43 | Returns: 44 | statistics, generated code 45 | """ 46 | prefix = self._sample_prompts() 47 | db_id = test_case['db_id'] 48 | schema = self.catalog.schema(db_id) 49 | files = self.catalog.files(db_id) 50 | db_dir = self.catalog.db_dir(db_id) 51 | question = test_case['question'] 52 | query = test_case['query'] 53 | suffix = self.get_prompt(schema, db_dir, files, question, query) 54 | prompt = prefix + '\n' + suffix 55 | stats, gen_code = self._complete(prompt, temperature) 56 | final_code = self.code_prefix + gen_code + self.code_suffix 57 | return stats, final_code 58 | 59 | def _complete(self, prompt, temperature): 60 | """ Generate code by completing given prompt. 61 | 62 | Args: 63 | prompt: initiate generation with this prompt 64 | temperature: degree of randomness 65 | 66 | Returns: 67 | statistics, generated code 68 | """ 69 | wait_s = 1 70 | nr_retries = 0 71 | while nr_retries < 5: 72 | stats = {'nr_retries':nr_retries} 73 | try: 74 | print(f'\nPrompt:\n*******\n{prompt}\n*******') 75 | start_s = time.time() 76 | response = openai.ChatCompletion.create( 77 | messages=[ 78 | {'role':'system', 79 | 'content':'You write Python code, implementing Python comments.'}, 80 | {'role':'user', 'content':prompt}], 81 | temperature=temperature, 82 | **self.ai_kwargs) 83 | completion = self._extract_code(response) 84 | total_s = time.time() - start_s 85 | usage = response['usage'] 86 | stats['prompt_tokens'] = usage['prompt_tokens'] 87 | stats['completion_tokens'] = usage['completion_tokens'] 88 | stats['last_request_s'] = total_s 89 | stats['error'] = False 90 | return stats, completion 91 | except openai.error.InvalidRequestError as e: 92 | print(f'InvalidRequestError: {e} - giving up') 93 | # No point in retrying (often: prompt to long) 94 | stats['error'] = True 95 | return stats, '' 96 | except Exception as e: 97 | print(f'Error querying OpenAI: {e}') 98 | print(f'Wait {wait_s} s before retry nr. {nr_retries} ...') 99 | time.sleep(wait_s) 100 | wait_s *= 2 101 | nr_retries += 1 102 | stats['error'] = True 103 | return stats, '' 104 | 105 | def _db_sample(self, db_dir, file_name, max_rows): 106 | """ Returns data sample from specified file. 107 | 108 | Args: 109 | db_dir: directory containing database data 110 | file_name: name of file within directory 111 | max_rows: maximal number of sample rows 112 | 113 | Returns: 114 | list of string representing sample rows 115 | """ 116 | lines = [] 117 | df = pd.read_csv(f'{db_dir}/{file_name}') 118 | nr_rows = df.shape[0] 119 | nr_cols = df.shape[1] 120 | for row_idx in range(min(max_rows, nr_rows)): 121 | row_parts = [] 122 | for col_idx in range(nr_cols): 123 | value = str(df.iloc[row_idx, col_idx]) 124 | col_type = df.dtypes[col_idx].type 125 | if not np.issubdtype(col_type, np.number): 126 | value = '"' + value + '"' 127 | row_parts.append(value) 128 | lines.append(','.join(row_parts)) 129 | return lines 130 | 131 | def _extract_code(self, response): 132 | """ Extract Python code from LLM answer. 133 | 134 | Args: 135 | response: response generated by the LLM. 136 | 137 | Returns: 138 | answer extract containing Python code. 
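If the completion contains a fenced Python code block, only the content of the first such block is returned; otherwise the raw completion is returned.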
139 | """ 140 | completion = response['choices'][0]['message']['content'] 141 | snippets = re.findall('```python(.*)```', completion, re.DOTALL) 142 | if snippets: 143 | completion = snippets[0] 144 | 145 | return completion 146 | 147 | @abc.abstractmethod 148 | def get_prompt(self, schema, db_dir, files, question, query): 149 | """ Generate prompt for processing specific query. 150 | 151 | Args: 152 | schema: description of database schema 153 | db_dir: directory storing data files 154 | files: location of data files for tables 155 | question: natural language query 156 | query: SQL translation of query 157 | 158 | Returns: 159 | Prompt for generating code for executing query 160 | """ 161 | raise NotImplementedError() 162 | 163 | @abc.abstractmethod 164 | def _sample_prompts(self): 165 | """ Generate sample prompts for few-shot learning. 166 | 167 | Returns: 168 | Prompt prefix with completion examples 169 | """ 170 | raise NotImplementedError() 171 | 172 | 173 | class PythonGenerator(CodeGenerator): 174 | """ Generates Python code to solve database queries. """ 175 | 176 | def __init__(self, *pargs, id_case, mod_start, mod_between, mod_end): 177 | """ Initializes for Python code generation. 178 | 179 | Args: 180 | pargs: arguments of super class constructor 181 | id_case: whether to consider letter case for identifiers 182 | mod_start: modification at start of query plan 183 | mod_between: modifications between plan steps 184 | mod_end: modifications at end of query plan 185 | """ 186 | super().__init__(*pargs) 187 | self.ai_kwargs['max_tokens'] = 800 188 | self.ai_kwargs['stop'] = '"""' 189 | self.planner = codexdb.plan.NlPlanner(id_case) 190 | self.id_case = id_case 191 | self.mod_start = mod_start 192 | self.mod_between = mod_between 193 | self.mod_end = mod_end 194 | # Reproducible experiments 195 | random.seed(42) 196 | 197 | def _db_info(self, schema, db_dir, files, max_rows): 198 | """ Generate description of database. 
199 | 200 | Args: 201 | schema: description of database schema 202 | db_dir: directory containing data 203 | files: names to files storing tables 204 | max_rows: maximal number of rows in data sample 205 | 206 | Returns: 207 | list of description lines 208 | """ 209 | lines = [] 210 | tables = schema['table_names_original'] 211 | all_columns = schema['column_names_original'] 212 | nr_tables = len(tables) 213 | for tbl_idx in range(nr_tables): 214 | filename = files[tbl_idx] 215 | tbl_name = tables[tbl_idx] 216 | if not self.id_case: 217 | filename = filename.lower() 218 | tbl_name = tbl_name.lower() 219 | 220 | if self.prompt_style == 'data': 221 | 222 | lines.append(f'Sample from table {tbl_name}, stored in "{filename}":') 223 | df = pd.read_csv(f'{db_dir}/{filename}') 224 | headers = [] 225 | for col_name in df.columns: 226 | if not self.id_case: 227 | col_name = col_name.lower() 228 | header = f'"{col_name}"' 229 | headers.append(header) 230 | lines.append(','.join(headers)) 231 | 232 | file_name = files[tbl_idx] 233 | lines += self._db_sample(db_dir, file_name, max_rows) 234 | 235 | type_items = [] 236 | for col_name, col_type in zip(df.columns, df.dtypes): 237 | if np.issubdtype(col_type, np.number): 238 | print_type = 'numeric' 239 | else: 240 | print_type = 'text' 241 | type_item = f'"{col_name}": {print_type}' 242 | type_items.append(type_item) 243 | lines.append('Column types: ' + ', '.join(type_items)) 244 | 245 | else: 246 | table_columns = [c[1] for c in all_columns if c[0] == tbl_idx] 247 | if not self.id_case: 248 | table_columns = [c.lower() for c in table_columns] 249 | quoted_columns = ["'" + c + "'" for c in table_columns] 250 | col_list = ','.join(quoted_columns) 251 | line = f'Table {tbl_name} with columns {col_list}, ' \ 252 | f'stored in \'{filename}\'.' 253 | lines.append(line) 254 | 255 | return lines 256 | 257 | def get_prompt(self, schema, db_dir, files, question, query): 258 | """ Generate prompt for processing specific query. 259 | 260 | Args: 261 | schema: description of database schema 262 | db_dir: directory storing data files 263 | files: location of data files for tables 264 | question: natural language query 265 | query: SQL translation of query 266 | 267 | Returns: 268 | Prompt for generating code for executing query 269 | """ 270 | prompt_parts = [] 271 | prompt_parts.append('"""') 272 | prompt_parts += self._db_info(schema, db_dir, files, 5) 273 | if self.prompt_style in ['question', 'query', 'plan']: 274 | if self.prompt_style in ['question', 'query']: 275 | if self.prompt_style == 'question': 276 | prompt_parts.append(f'Question: {question}') 277 | else: 278 | prompt_parts.append(f'SQL query: {query}') 279 | 280 | if self.mod_between: 281 | prompt_parts.append(f'Between steps: {self.mod_between}') 282 | if self.mod_start: 283 | prompt_parts.append(self.mod_start) 284 | if self.mod_end: 285 | prompt_parts.append(self.mod_end) 286 | else: 287 | prompt_parts.append('Processing steps:') 288 | plan = self.planner.plan(query) 289 | if self.mod_between: 290 | plan.intersperse_step([self.mod_between]) 291 | if self.mod_start: 292 | plan.add_step([self.mod_start], False) 293 | if self.mod_end: 294 | plan.add_step([self.mod_end]) 295 | prompt_parts += plan.steps() 296 | else: 297 | prompt_parts.append(f'Query: "{question}".') 298 | prompt_parts.append('1. Import pandas library.') 299 | prompt_parts.append('2. Calculate query answer.') 300 | prompt_parts.append("3. 
Store result in 'result.csv'.") 301 | prompt_parts.append('"""') 302 | return '\n'.join(prompt_parts) 303 | 304 | def _sample_prompts(self): 305 | """ Generate prompts from examples for few-shot learning. 306 | 307 | Returns: 308 | a prefix of the full prompt to generate 309 | """ 310 | parts = [] 311 | if self.examples: 312 | selected = random.sample(self.examples, k=self.nr_samples) 313 | for example in selected: 314 | db_id = example['schema']['db_id'] 315 | db_dir = self.catalog.db_dir(db_id) 316 | prompt = self.get_prompt( 317 | example['schema'], db_dir, example['files'], 318 | example['question'], example['query']) 319 | parts.append(prompt) 320 | parts.append(example['code']) 321 | parts.append('') 322 | parts.append('') 323 | return '\n'.join(parts) 324 | 325 | 326 | class SqlGenerator(CodeGenerator): 327 | """ Translates natural language questions into SQL queries. """ 328 | 329 | def __init__(self, *kwargs): 330 | """ Initializes for SQL query generation. 331 | 332 | Args: 333 | kwargs: arguments for super class constructor 334 | """ 335 | super().__init__(*kwargs) 336 | self.ai_kwargs['max_tokens'] = 150 337 | self.ai_kwargs['stop'] = ['#', ';'] 338 | self.code_prefix = 'SELECT ' 339 | 340 | def get_prompt(self, schema, db_dir, files, question, query): 341 | """ Returns prompt for given question. """ 342 | lines = [] 343 | lines.append('### Postgres SQL tables, with their properties:') 344 | lines.append('#') 345 | 346 | tables = schema['table_names_original'] 347 | all_columns = schema['column_names_original'] 348 | for idx, table in enumerate(tables): 349 | cols = [c[1].replace(' ', '_') for c in all_columns if c[0] == idx] 350 | lines.append(f'# {table}({",".join(cols)})') 351 | if self.prompt_style == 'data': 352 | #lines.append(f'Sample rows from {table}:') 353 | file_name = files[idx] 354 | sample = self._db_sample(db_dir, file_name, 5) 355 | lines += ['# ' + s for s in sample] 356 | 357 | lines.append('#') 358 | lines.append(f'### Query: "{question}"') 359 | lines.append('SELECT') 360 | return '\n'.join(lines) 361 | 362 | def _sample_prompts(self): 363 | """ Returns prefix with samples for few-shot learning. 
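Each selected example contributes its prompt concatenated with the example's SQL query, minus the leading SELECT keyword that is already part of the prompt.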
""" 364 | parts = [] 365 | selected = random.sample(self.examples, k=self.nr_samples) 366 | for example in selected: 367 | db_id = example['schema']['db_id'] 368 | db_dir = self.catalog.db_dir(db_id) 369 | prompt = self.get_prompt( 370 | example['schema'], db_dir, example['files'], 371 | example['question'], example['query']) 372 | parts.append(prompt + example['query'][6:]) 373 | parts.append('') 374 | parts.append('') 375 | return '\n'.join(parts) -------------------------------------------------------------------------------- /config/prompts.json: -------------------------------------------------------------------------------- 1 | { 2 | "sample_databases": [ 3 | { 4 | "column_names_original": [ 5 | [ 6 | 0, 7 | "Perpetrator_ID" 8 | ], 9 | [ 10 | 0, 11 | "People_ID" 12 | ], 13 | [ 14 | 0, 15 | "Date" 16 | ], 17 | [ 18 | 0, 19 | "Year" 20 | ], 21 | [ 22 | 0, 23 | "Location" 24 | ], 25 | [ 26 | 0, 27 | "Country" 28 | ], 29 | [ 30 | 0, 31 | "Killed" 32 | ], 33 | [ 34 | 0, 35 | "Injured" 36 | ], 37 | [ 38 | 1, 39 | "People_ID" 40 | ], 41 | [ 42 | 1, 43 | "Name" 44 | ], 45 | [ 46 | 1, 47 | "Height" 48 | ], 49 | [ 50 | 1, 51 | "Weight" 52 | ], 53 | [ 54 | 1, 55 | "Home Town" 56 | ] 57 | ], 58 | "column_types": [ 59 | "text", 60 | "number", 61 | "number", 62 | "text", 63 | "number", 64 | "text", 65 | "text", 66 | "number", 67 | "number", 68 | "number", 69 | "text", 70 | "number", 71 | "number", 72 | "text" 73 | ], 74 | "db_id": "perpetrator", 75 | "foreign_keys": [ 76 | [ 77 | 2, 78 | 9 79 | ] 80 | ], 81 | "primary_keys": [ 82 | 1, 83 | 9 84 | ], 85 | "table_names_original": [ 86 | "perpetrator", 87 | "people" 88 | ], 89 | "files": [ 90 | "perpetrator.csv", 91 | "people.csv" 92 | ] 93 | }, 94 | { 95 | "column_names_original": [ 96 | [ 97 | 0, 98 | "building" 99 | ], 100 | [ 101 | 0, 102 | "room_number" 103 | ], 104 | [ 105 | 0, 106 | "capacity" 107 | ], 108 | [ 109 | 1, 110 | "dept_name" 111 | ], 112 | [ 113 | 1, 114 | "building" 115 | ], 116 | [ 117 | 1, 118 | "budget" 119 | ], 120 | [ 121 | 2, 122 | "course_id" 123 | ], 124 | [ 125 | 2, 126 | "title" 127 | ], 128 | [ 129 | 2, 130 | "dept_name" 131 | ], 132 | [ 133 | 2, 134 | "credits" 135 | ], 136 | [ 137 | 3, 138 | "ID" 139 | ], 140 | [ 141 | 3, 142 | "name" 143 | ], 144 | [ 145 | 3, 146 | "dept_name" 147 | ], 148 | [ 149 | 3, 150 | "salary" 151 | ], 152 | [ 153 | 4, 154 | "course_id" 155 | ], 156 | [ 157 | 4, 158 | "sec_id" 159 | ], 160 | [ 161 | 4, 162 | "semester" 163 | ], 164 | [ 165 | 4, 166 | "year" 167 | ], 168 | [ 169 | 4, 170 | "building" 171 | ], 172 | [ 173 | 4, 174 | "room_number" 175 | ], 176 | [ 177 | 4, 178 | "time_slot_id" 179 | ], 180 | [ 181 | 5, 182 | "ID" 183 | ], 184 | [ 185 | 5, 186 | "course_id" 187 | ], 188 | [ 189 | 5, 190 | "sec_id" 191 | ], 192 | [ 193 | 5, 194 | "semester" 195 | ], 196 | [ 197 | 5, 198 | "year" 199 | ], 200 | [ 201 | 6, 202 | "ID" 203 | ], 204 | [ 205 | 6, 206 | "name" 207 | ], 208 | [ 209 | 6, 210 | "dept_name" 211 | ], 212 | [ 213 | 6, 214 | "tot_cred" 215 | ], 216 | [ 217 | 7, 218 | "ID" 219 | ], 220 | [ 221 | 7, 222 | "course_id" 223 | ], 224 | [ 225 | 7, 226 | "sec_id" 227 | ], 228 | [ 229 | 7, 230 | "semester" 231 | ], 232 | [ 233 | 7, 234 | "year" 235 | ], 236 | [ 237 | 7, 238 | "grade" 239 | ], 240 | [ 241 | 8, 242 | "s_ID" 243 | ], 244 | [ 245 | 8, 246 | "i_ID" 247 | ], 248 | [ 249 | 9, 250 | "time_slot_id" 251 | ], 252 | [ 253 | 9, 254 | "day" 255 | ], 256 | [ 257 | 9, 258 | "start_hr" 259 | ], 260 | [ 261 | 9, 262 | "start_min" 263 | ], 264 | [ 265 | 9, 266 | "end_hr" 267 | ], 268 | [ 
269 | 9, 270 | "end_min" 271 | ], 272 | [ 273 | 10, 274 | "course_id" 275 | ], 276 | [ 277 | 10, 278 | "prereq_id" 279 | ] 280 | ], 281 | "column_types": [ 282 | "text", 283 | "text", 284 | "text", 285 | "number", 286 | "text", 287 | "text", 288 | "number", 289 | "text", 290 | "text", 291 | "text", 292 | "number", 293 | "text", 294 | "text", 295 | "text", 296 | "number", 297 | "text", 298 | "text", 299 | "text", 300 | "number", 301 | "text", 302 | "text", 303 | "text", 304 | "text", 305 | "text", 306 | "text", 307 | "text", 308 | "number", 309 | "text", 310 | "text", 311 | "text", 312 | "number", 313 | "text", 314 | "text", 315 | "text", 316 | "text", 317 | "number", 318 | "text", 319 | "text", 320 | "text", 321 | "text", 322 | "text", 323 | "number", 324 | "number", 325 | "number", 326 | "number", 327 | "text", 328 | "text" 329 | ], 330 | "db_id": "college_2", 331 | "foreign_keys": [ 332 | [ 333 | 9, 334 | 4 335 | ], 336 | [ 337 | 13, 338 | 4 339 | ], 340 | [ 341 | 19, 342 | 1 343 | ], 344 | [ 345 | 20, 346 | 2 347 | ], 348 | [ 349 | 15, 350 | 7 351 | ], 352 | [ 353 | 22, 354 | 11 355 | ], 356 | [ 357 | 23, 358 | 15 359 | ], 360 | [ 361 | 24, 362 | 16 363 | ], 364 | [ 365 | 25, 366 | 17 367 | ], 368 | [ 369 | 26, 370 | 18 371 | ], 372 | [ 373 | 29, 374 | 4 375 | ], 376 | [ 377 | 31, 378 | 27 379 | ], 380 | [ 381 | 32, 382 | 15 383 | ], 384 | [ 385 | 33, 386 | 16 387 | ], 388 | [ 389 | 34, 390 | 17 391 | ], 392 | [ 393 | 35, 394 | 18 395 | ], 396 | [ 397 | 37, 398 | 27 399 | ], 400 | [ 401 | 38, 402 | 11 403 | ], 404 | [ 405 | 46, 406 | 7 407 | ], 408 | [ 409 | 45, 410 | 7 411 | ] 412 | ], 413 | "primary_keys": [ 414 | 1, 415 | 4, 416 | 7, 417 | 11, 418 | 15, 419 | 22, 420 | 27, 421 | 31, 422 | 37, 423 | 39, 424 | 45 425 | ], 426 | "table_names_original": [ 427 | "classroom", 428 | "department", 429 | "course", 430 | "instructor", 431 | "section", 432 | "teaches", 433 | "student", 434 | "takes", 435 | "advisor", 436 | "time_slot", 437 | "prereq" 438 | ], 439 | "files": [ 440 | "classroom.csv", 441 | "department.csv", 442 | "course.csv", 443 | "instructor.csv", 444 | "section.csv", 445 | "teaches.csv", 446 | "student.csv", 447 | "takes.csv", 448 | "advisor.csv", 449 | "time_slot.csv", 450 | "prereq.csv" 451 | ] 452 | }, 453 | { 454 | "column_names_original": [ 455 | [ 456 | 0, 457 | "id" 458 | ], 459 | [ 460 | 0, 461 | "City" 462 | ], 463 | [ 464 | 0, 465 | "Country" 466 | ], 467 | [ 468 | 0, 469 | "IATA" 470 | ], 471 | [ 472 | 0, 473 | "ICAO" 474 | ], 475 | [ 476 | 0, 477 | "name" 478 | ], 479 | [ 480 | 1, 481 | "id" 482 | ], 483 | [ 484 | 1, 485 | "name" 486 | ], 487 | [ 488 | 1, 489 | "Type" 490 | ], 491 | [ 492 | 1, 493 | "Principal_activities" 494 | ], 495 | [ 496 | 1, 497 | "Incorporated_in" 498 | ], 499 | [ 500 | 1, 501 | "Group_Equity_Shareholding" 502 | ], 503 | [ 504 | 2, 505 | "id" 506 | ], 507 | [ 508 | 2, 509 | "Vehicle_Flight_number" 510 | ], 511 | [ 512 | 2, 513 | "Date" 514 | ], 515 | [ 516 | 2, 517 | "Pilot" 518 | ], 519 | [ 520 | 2, 521 | "Velocity" 522 | ], 523 | [ 524 | 2, 525 | "Altitude" 526 | ], 527 | [ 528 | 2, 529 | "airport_id" 530 | ], 531 | [ 532 | 2, 533 | "company_id" 534 | ] 535 | ], 536 | "column_types": [ 537 | "text", 538 | "number", 539 | "text", 540 | "text", 541 | "text", 542 | "text", 543 | "text", 544 | "number", 545 | "text", 546 | "text", 547 | "text", 548 | "text", 549 | "number", 550 | "number", 551 | "text", 552 | "text", 553 | "text", 554 | "number", 555 | "number", 556 | "number", 557 | "number" 558 | ], 559 | "db_id": "flight_company", 560 | 
"foreign_keys": [ 561 | [ 562 | 20, 563 | 7 564 | ], 565 | [ 566 | 19, 567 | 1 568 | ] 569 | ], 570 | "primary_keys": [ 571 | 1, 572 | 7, 573 | 13 574 | ], 575 | "table_names_original": [ 576 | "airport", 577 | "operate_company", 578 | "flight" 579 | ], 580 | "files": [ 581 | "airport.csv", 582 | "operate_company.csv", 583 | "flight.csv" 584 | ] 585 | } 586 | ], 587 | "transform": { 588 | "from_*": { 589 | "to_python": { 590 | "template": "\"\"\\nThis Python program transforms the files of the following database into anotehr format.\n\n\"\"\"\n\n--- Start of Python program ---", 591 | "marker": "--- End of Python program ---", 592 | "linepre": "" 593 | }, 594 | "to_bash": { 595 | "template": "# This Bash script transforms the files of the following database into another format.\n\n\n\n--- Start of Bash script ---\n#!/bin/bash\n\necho \"Transforming data ...\"", 596 | "marker": "--- End of Bash script ---", 597 | "linepre": "# " 598 | }, 599 | "to_cpp": { 600 | "template": "// This C++ program transforms the files of the following database into another format.\n\n\n\n--- Start of C++ program ---\n", 601 | "marker": "--- End of C++ program ---", 602 | "linepre": "// " 603 | } 604 | }, 605 | "tactics": [ 606 | "Import libraries.", 607 | "Iterate over all tables.", 608 | "For each table, load the associated file.", 609 | "Transform the file into another format.", 610 | "Write new file to hard disk.", 611 | "Delete in-memory data structures."], 612 | "precedence": [ 613 | {"F":0, "S":3}, 614 | {"F":1, "S":3}, 615 | {"F":2, "S":3}, 616 | {"F":3, "S":4}, 617 | {"F":4, "S":5} 618 | ], 619 | "strategies": [ 620 | "", 621 | " for efficient processing", 622 | " for higher space efficiency", 623 | " for easy parallelization"] 624 | }, 625 | "index": { 626 | "from_*": { 627 | "to_python": { 628 | "template": "\"\"\\nThis Python program indexes the files of the following database.\n\n\"\"\"\n\n--- Start of Python program ---", 629 | "marker": "--- End of Python program ---", 630 | "linepre": "" 631 | }, 632 | "to_bash": { 633 | "template": "# This Bash script indexes the files of the following database.\n\n\n\n--- Start of Bash script ---\n#!/bin/bash\n\necho \"Transforming data ...\"", 634 | "marker": "--- End of Bash script ---", 635 | "linepre": "# " 636 | }, 637 | "to_cpp": { 638 | "template": "// This C++ program indexes the files of the following database.\n\n\n\n--- Start of C++ program ---\n", 639 | "marker": "--- End of C++ program ---", 640 | "linepre": "// " 641 | } 642 | }, 643 | "tactics": [ 644 | "Iterate over all tables in the database. For each table:", 645 | "Load data from hard disk.", 646 | "Index the data .", 647 | "Write indexed data to disk.", 648 | "Remove in-memory data structures." 
649 | ], 650 | "precedence": [ 651 | {"F":0, "S":3}, 652 | {"F":3, "S":4}, 653 | {"F":1, "S":2}, 654 | {"F":2, "S":3} 655 | ], 656 | "strategies": [ 657 | "", 658 | "using B+ tree indexes", 659 | "using hash indexes" 660 | ] 661 | }, 662 | "query": { 663 | "from_nl": { 664 | "sample_tasks": [ 665 | { 666 | "task": "How many perpetrators are there?", 667 | "db_id": 0 668 | }, 669 | { 670 | "task": "Find the buildings which have rooms with capacity more than 50.", 671 | "db_id": 1 672 | }, 673 | { 674 | "task": "What are the names and types of the companies that have ever operated a flight?", 675 | "db_id": 2 676 | } 677 | ], 678 | "to_python": { 679 | "template": "\"\"\"\nThis Python program answers the query \"\" on the following tables:\n\n\n\"\"\"\n\n--- Start of Python program ---", 680 | "marker": "--- End of Python program ---", 681 | "linepre": "", 682 | "sample_solution_links": [ 683 | "config/sq1_python.txt", 684 | "config/sq2_python.txt", 685 | "config/sq3_python.txt" 686 | ] 687 | }, 688 | "to_bash": { 689 | "template": "# This Bash script answers the query \"\" on the following tables:\n\n# Answer the query \"\":\n\n\n--- Start of Bash script ---\n#!/bin/bash\n\necho \"Processing query ...\"", 690 | "marker": "--- End of Bash script ---", 691 | "linepre": "# ", 692 | "sample_solution_links": [ 693 | "config/sq1_bash.txt", 694 | "config/sq2_bash.txt", 695 | "config/sq3_bash.txt" 696 | ] 697 | }, 698 | "to_cpp": { 699 | "template": "// This C++ program answers the query \"\" on the following tables:\n\n\n\n--- Start of C++ program ---\n", 700 | "marker": "--- End of C++ program ---", 701 | "linepre": "// ", 702 | "sample_solution_links": [ 703 | "config/sq1_cpp.txt", 704 | "config/sq2_cpp.txt", 705 | "config/sq3_cpp.txt" 706 | ] 707 | }, 708 | "to_pg_sql": { 709 | "template": "##### Translate this query into SQL: \n\n--- Start of SQL query ---\nSELECT ", 710 | "marker": "--- End of SQL query ---", 711 | "linepre": "# ", 712 | "sample_solution_links": [ 713 | "config/sq1_sql.txt", 714 | "config/sq2_sql.txt", 715 | "config/sq3_sql.txt" 716 | ] 717 | } 718 | }, 719 | "from_pg_sql": { 720 | "sample_tasks": [ 721 | { 722 | "task": "SELECT count(*) FROM perpetrator", 723 | "db_id": 0 724 | }, 725 | { 726 | "task": "SELECT DISTINCT building FROM classroom WHERE capacity > 50", 727 | "db_id": 1 728 | }, 729 | { 730 | "task": "SELECT T1.name , T1.type FROM operate_company AS T1 JOIN flight AS t2 ON T1.id = T2.company_id", 731 | "db_id": 2 732 | } 733 | ], 734 | "to_python": { 735 | "template": "\"\"\"\nThis Python program answers the SQL query \"\" on the following tables:\n\n\n\"\"\"\n\n--- Start of Python program ---", 736 | "marker": "--- End of Python program ---", 737 | "linepre": "", 738 | "sample_solution_links": [ 739 | "config/sq1_python.txt", 740 | "config/sq2_python.txt", 741 | "config/sq3_python.txt" 742 | ] 743 | }, 744 | "to_bash": { 745 | "template": "# This Bash script answers the SQL query \"\" on the following tables:\n\n# Answer the query \"\":\n\n\n--- Start of Bash script ---\n#!/bin/bash\n\necho \"Processing query ...\"", 746 | "marker": "--- End of Bash script ---", 747 | "linepre": "# ", 748 | "sample_solution_links": [ 749 | "config/sq1_bash.txt", 750 | "config/sq2_bash.txt", 751 | "config/sq3_bash.txt" 752 | ] 753 | }, 754 | "to_cpp": { 755 | "template": "// This C++ program answers the SQL query \"\" on the following tables:\n\n\n\n--- Start of C++ program ---\n", 756 | "marker": "--- End of C++ program ---", 757 | "linepre": "// ", 758 | "sample_solution_links": 
[ 759 | "config/sq1_cpp.txt", 760 | "config/sq2_cpp.txt", 761 | "config/sq3_cpp.txt" 762 | ] 763 | } 764 | }, 765 | "tactics": [ 766 | "Import libraries.", 767 | "Load data for all relevant tables.", 768 | "Sort data.", 769 | "Hash data.", 770 | "Calculate the answer to the query.", 771 | "Enable display for all rows and columns, strings of infinite length.", 772 | "Print out query result only." 773 | ], 774 | "precedence": [ 775 | {"F":0, "S":4}, 776 | {"F":5, "S":6}, 777 | {"F":1, "S":2}, 778 | {"F":1, "S":3}, 779 | {"F":1, "S":4}, 780 | {"F":4, "S":6} 781 | ], 782 | "strategies": [ 783 | "", 784 | " for parallel processing", 785 | " for GPU processing", 786 | " for efficient processing"] 787 | } 788 | } -------------------------------------------------------------------------------- /src/codexdb/plan.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Jan 21, 2022 3 | 4 | @author: immanueltrummer 5 | ''' 6 | import collections 7 | import json 8 | import sqlglot.parser 9 | import sqlglot.tokens 10 | import sqlglot.expressions 11 | 12 | 13 | class NlPlan(): 14 | """ Represents a plan described in natural language. 15 | 16 | Attributes: 17 | next_id: next integer ID to use for steps 18 | """ 19 | next_id = 0 20 | 21 | def __init__(self): 22 | """ Initializes plan steps. 23 | 24 | Args: 25 | id_steps: list of tuples (step ID and step) 26 | """ 27 | # List of Tuple(Step ID, List of parts) 28 | self.id_steps = [] 29 | 30 | def add_plan(self, plan): 31 | """ Add steps of another plan. 32 | 33 | Args: 34 | plan: add steps of this plan (after current steps). 35 | """ 36 | self.id_steps += plan.id_steps 37 | 38 | def add_step(self, step, at_end=True): 39 | """ Adds one step to plan. 40 | 41 | Args: 42 | step: a list mixing strings and expressions 43 | at_end: whether to add step at the end 44 | 45 | Returns: 46 | integer ID of newly created step 47 | """ 48 | step_id = self._step_ID() 49 | if at_end: 50 | self.id_steps.append((step_id, step)) 51 | else: 52 | self.id_steps.insert(0, (step_id, step)) 53 | return step_id 54 | 55 | def id_to_step(self): 56 | """ Returns dictionary mapping step IDs to steps. """ 57 | id_to_step = {} 58 | for step_id, step in self.id_steps: 59 | id_to_step[step_id] = step 60 | return id_to_step 61 | 62 | def intersperse_step(self, step): 63 | """ Add given step after each current plan step. 64 | 65 | Args: 66 | step: intersperse this step 67 | """ 68 | nr_steps = len(self.id_steps) 69 | for _ in range(nr_steps): 70 | self.add_step(step, False) 71 | 72 | new_id_steps = [] 73 | for i in range(nr_steps): 74 | new_id_steps.append(self.id_steps[i+nr_steps]) 75 | new_id_steps.append(self.id_steps[i]) 76 | 77 | self.id_steps = new_id_steps 78 | 79 | def last_step_id(self): 80 | """ Returns ID of last step or None. """ 81 | if self.id_steps: 82 | return self.id_steps[-1][0] 83 | else: 84 | return None 85 | 86 | def step_ref_counts(self): 87 | """ Count number of references for each step. 88 | 89 | Returns: 90 | dictionary mapping step IDs to number of references 91 | """ 92 | step_ref_counts = collections.defaultdict(lambda:0) 93 | for _, step in self.id_steps: 94 | for part in step: 95 | if isinstance(part, int): 96 | step_ref_counts[part] += 1 97 | return step_ref_counts 98 | 99 | def steps(self, offset=0): 100 | """ Generates list of natural language plan steps. 
101 | 102 | Args: 103 | offset: add this number to each plan step index 104 | 105 | Returns: 106 | list of steps (strings) 107 | """ 108 | nl_steps = [self._step_to_nl(step) for _, step in self.id_steps] 109 | return [f'{(idx+offset)}. {s}.' for idx, s in enumerate(nl_steps, 1)] 110 | 111 | def _index_of(self, search): 112 | """ Finds step by its index. 113 | 114 | Args: 115 | search: search for step with this ID 116 | 117 | Returns: 118 | index of corresponding step or None 119 | """ 120 | for idx, (step_id, _) in enumerate(self.id_steps): 121 | if step_id == search: 122 | return idx + 1 123 | return None 124 | 125 | def _step_ID(self): 126 | """ Returns next unused step ID and advances counter. """ 127 | NlPlan.next_id += 1 128 | return NlPlan.next_id - 1 129 | 130 | def _step_to_nl(self, step): 131 | """ Transforms plan step into natural language. 132 | 133 | Args: 134 | step: list of strings or expressions 135 | 136 | Returns: 137 | string describing the step in natural language 138 | """ 139 | out_parts = [] 140 | for in_part in step: 141 | if isinstance(in_part, str): 142 | out_parts.append(in_part) 143 | elif isinstance(in_part, int): 144 | in_idx = self._index_of(in_part) 145 | out_parts.append(f'results of Step {in_idx}') 146 | else: 147 | raise ValueError(f'Cannot translate plan step part: {in_part}') 148 | return ' '.join(out_parts) 149 | 150 | 151 | class NlPlanner(): 152 | """ Generates natural language query plan for query. """ 153 | 154 | def __init__(self, id_case, quote_ids=True): 155 | """ Initializes planner. 156 | 157 | Args: 158 | id_case: whether to consider letter case for identifiers 159 | quote_ids: whether to place all identifiers in quotes 160 | """ 161 | self.id_case = id_case 162 | self.quote_ids = quote_ids 163 | self.tokenizer = sqlglot.tokens.Tokenizer() 164 | self.parser = sqlglot.parser.Parser() 165 | 166 | def nl(self, expression, key=None): 167 | """ Returns a natural language plan for given expression. 168 | 169 | Args: 170 | expression: an SQL expression 171 | key: transform this attribute 172 | 173 | Returns: 174 | a label list to refer to expression in text, preparatory steps 175 | """ 176 | if not expression: 177 | return [''] 178 | 179 | if isinstance(expression, str): 180 | return [expression] 181 | 182 | if key: 183 | return self.nl(expression.args.get(key)) 184 | 185 | handler_name = f'_{expression.key}_nl' 186 | if hasattr(self, handler_name): 187 | return getattr(self, handler_name)(expression) 188 | 189 | if isinstance(expression, sqlglot.expressions.Func): 190 | print(f'Function: {expression.key}') 191 | 192 | raise ValueError(f'Error - cannot process expression {expression.key}!') 193 | 194 | def plan(self, query): 195 | """ Parse query and return natural language plan. 196 | 197 | Args: 198 | query: SQL query to plan for 199 | 200 | Returns: 201 | plan for query with steps described in natural language 202 | """ 203 | tokens = self.tokenizer.tokenize(query) 204 | ast = self.parser.parse(tokens)[0] 205 | if not self.id_case: 206 | ast = self._lower_ids(ast) 207 | labels, plan = self.nl(ast) 208 | write_out = ['Write'] + labels + ["to file 'result.csv' (with header)"] 209 | plan.add_step(write_out) 210 | return plan 211 | 212 | def _alias(self, expression): 213 | """ Extract alias from alias expression. 
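
        Args:
            expression: expression of type alias (e.g., a table with an AS clause)

        Returns:
            text label for the alias identifier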
""" 214 | assert expression.key == 'alias', 'No alias type expression' 215 | alias_id = expression.args['alias'] 216 | return self._identifier_label(alias_id) 217 | 218 | def _alias_nl(self, expression): 219 | """ Translate alias into natural language. """ 220 | alias_labels, alias_prep = self.nl(expression, 'alias') 221 | this_labels, plan = self.nl(expression, 'this') 222 | plan.add_plan(alias_prep) 223 | last_labels = this_labels + ['(aka.'] + alias_labels + [')'] 224 | return last_labels, plan 225 | 226 | def _and_nl(self, expression): 227 | """ Translate logical and into natural language. """ 228 | return self._cmp(expression, 'and') 229 | 230 | def _agg_nl(self, expression, agg_name): 231 | """ Translates aggregate into natural language. 232 | 233 | Args: 234 | expression: aggregate expression to translate 235 | agg_name: name of aggregate to use in description 236 | 237 | Returns: 238 | list of labels, plan preparing aggregate 239 | """ 240 | arg_labels, prep = self.nl(expression, 'this') 241 | labels = [agg_name] + ['of'] + arg_labels 242 | return labels, prep 243 | 244 | def _avg_nl(self, expression): 245 | """ Translate average aggregate into natural language. """ 246 | return self._agg_nl(expression, 'average') 247 | 248 | def _between_nl(self, expression): 249 | """ Translates between statement into natural language. """ 250 | op = expression.args.get('this') 251 | low = expression.args.get('low') 252 | high = expression.args.get('high') 253 | op_label, op_prep = self.nl(op) 254 | low_label, low_prep = self.nl(low) 255 | high_label, high_prep = self.nl(high) 256 | plan = NlPlan() 257 | plan.add_plan(op_prep) 258 | plan.add_plan(low_prep) 259 | plan.add_plan(high_prep) 260 | step = ['Check if'] + op_label + ['is between'] + \ 261 | low_label + ['and'] + high_label 262 | last_labels = [plan.add_step(step)] 263 | return last_labels, plan 264 | 265 | def _binary(self, expression): 266 | """ Pre-processes a generic binary expression. 267 | 268 | Args: 269 | expression: a generic binary expression 270 | 271 | Returns: 272 | tuple: left labels, right labels, combined preparation 273 | """ 274 | left_labels, plan = self.nl(expression, 'this') 275 | right_labels, right_prep = self.nl(expression, 'expression') 276 | plan.add_plan(right_prep) 277 | return left_labels, right_labels, plan 278 | 279 | def _cmp(self, expression, comparator): 280 | """ Processes a binary comparison. 281 | 282 | Args: 283 | expression: a binary comparison expression 284 | comparator: natural language comparator 285 | 286 | Returns: 287 | labels representing comparison result, corresponding plan 288 | """ 289 | left, right, plan = self._binary(expression) 290 | step = ['Check if'] + left + [comparator] + right 291 | last_labels = [plan.add_step(step)] 292 | return last_labels, plan 293 | 294 | def _column_nl(self, expression): 295 | """ Express a column reference in natural language. """ 296 | labels, plan = self.nl(expression.args['this']) 297 | # labels += ['column'] 298 | table = expression.args.get('table') 299 | if table: 300 | table_labels, table_prep = self.nl(table) 301 | plan.add_plan(table_prep) 302 | labels += ['in'] + table_labels 303 | database = expression.args.get('database') 304 | if database: 305 | db_labels, db_prep = self.nl(database) 306 | plan.add_plan(db_prep) 307 | labels += ['in'] + db_labels 308 | return labels, plan 309 | 310 | def _complex_select_nl(self, expression): 311 | """ Translate complex select expression into natural language plan. 
""" 312 | # tbl_to_preds = self._unary_predicates(expression) 313 | from_expressions = expression.args['from'].args['expressions'] 314 | join_expressions = expression.args.get('joins') 315 | 316 | tbl_expressions = from_expressions 317 | for join in join_expressions: 318 | tbl_expression = join.args['this'] 319 | tbl_expressions += [tbl_expression] 320 | 321 | tables_aliases = [] 322 | for tbl_expression in tbl_expressions: 323 | table = self._tables(tbl_expression).pop() 324 | alias = table 325 | if tbl_expression.key == 'alias': 326 | alias = self._alias(tbl_expression) 327 | tables_aliases += [(table, alias)] 328 | 329 | # Load data and assign aliases 330 | plan = NlPlan() 331 | for table, alias in tables_aliases: 332 | step = ['Load table'] + [table] + ['and store as'] + [alias] 333 | plan.add_step(step) 334 | 335 | # preds = tbl_to_preds[alias] 336 | # for pred in preds: 337 | # pred_label, pred_plan = self.nl(pred) 338 | # plan.add_plan(pred_plan) 339 | # step = ['Filter'] + [alias] + ['using'] + pred_label 340 | # plan.add_step(step) 341 | 342 | # Apply predicates in where clause 343 | if expression.args.get('where'): 344 | where_expr = expression.args['where'].args['this'] 345 | conjuncts = self._conjuncts(where_expr) 346 | for pred in conjuncts: 347 | pred_labels, pred_plan = self.nl(pred) 348 | pred_plan = self._simplify_plan(pred_plan) 349 | tables = self._tables(pred) 350 | if len(tables) == 1: 351 | table = tables.pop() 352 | prefix = f'Filter {table}:' 353 | else: 354 | prefix = 'Filter table:' 355 | pred_plan.id_steps[-1][1].insert(0, prefix) 356 | plan.add_plan(pred_plan) 357 | # plan.add_step(step) 358 | 359 | # Join tables considering join conditions 360 | left_label = tables_aliases[0][1] 361 | for idx, join in enumerate(join_expressions, 1): 362 | right_label = tables_aliases[idx][1] 363 | join = self._strip_tables(join) 364 | eq_label = self._join_eq_label(join) 365 | step = ['Join'] + [left_label] + ['with'] + \ 366 | [right_label] + ['- condition:'] + [eq_label] 367 | left_label = plan.add_step(step) 368 | last_labels = [left_label] 369 | 370 | if expression.args.get('group'): 371 | group_expr = expression.args['group'] 372 | group_expr = self._strip_tables(group_expr) 373 | group_labels, group_prep = self._expressions(group_expr) 374 | plan.add_plan(group_prep) 375 | group_step = ['Group'] + last_labels + ['by'] + group_labels 376 | last_labels = [plan.add_step(group_step)] 377 | 378 | if expression.args.get('having'): 379 | having_expr = expression.args['having'].args['this'] 380 | having_expr = self._strip_tables(having_expr) 381 | having_labels, having_prep = self.nl(having_expr) 382 | plan.add_plan(having_prep) 383 | having_step = ['Filter groups from'] + last_labels + \ 384 | ['using'] + having_labels 385 | last_labels = [plan.add_step(having_step)] 386 | 387 | if expression.args.get('order'): 388 | order_expr = expression.args['order'] 389 | order_expr = self._strip_tables(order_expr) 390 | order_labels, order_prep = self._expressions(order_expr) 391 | plan.add_plan(order_prep) 392 | order_step = ['Order'] + last_labels + ['by'] + order_labels 393 | last_labels = [plan.add_step(order_step)] 394 | 395 | if expression.args.get('limit'): 396 | limit_expr = expression.args['limit'].args['this'] 397 | limit_expr = self._strip_tables(limit_expr) 398 | limit_labels, limit_prep = self.nl(limit_expr) 399 | plan.add_plan(limit_prep) 400 | limit_step = ['Keep only'] + limit_labels + \ 401 | ['rows from'] + last_labels 402 | last_labels = [plan.add_step(limit_step)] 403 | 
404 | selectors = expression.args.get('expressions') 405 | select_labels = [] 406 | for idx, selector in enumerate(selectors, 1): 407 | selector = self._strip_tables(selector) 408 | select_cmd, select_prep = self.nl(selector) 409 | plan.add_plan(select_prep) 410 | step = ['Retrieve'] + select_cmd + ['from'] + last_labels 411 | select_label = plan.add_step(step) 412 | if select_labels: 413 | select_labels += [','] 414 | select_labels += [select_label] 415 | select_labels += [f'(column {idx})'] 416 | 417 | step = ['Create table with columns for'] + select_labels 418 | last_labels = [plan.add_step(step)] 419 | 420 | if expression.args.get('distinct'): 421 | distinct_step = ['Only keep unique rows from'] + last_labels 422 | last_labels = [plan.add_step(distinct_step)] 423 | 424 | return last_labels, plan 425 | 426 | def _conjuncts(self, expression): 427 | """ Extract list of conjuncts from expression. """ 428 | if expression.key == 'and': 429 | conjuncts = [] 430 | conjuncts += [expression.args.get('this')] 431 | conjuncts += [expression.args.get('expression')] 432 | return conjuncts 433 | else: 434 | return [expression] 435 | 436 | def _count_nl(self, expression): 437 | """ Translate count aggregate into natural language. """ 438 | count_args = expression.args.get('this') 439 | row_ref = 'rows' 440 | if expression.args.get('distinct') is True: 441 | row_ref = 'distinct rows' 442 | if count_args.args.get('this').key == 'star': 443 | return ['number of'] + [row_ref], NlPlan() 444 | else: 445 | arg_labels, prep = self.nl(count_args) 446 | labels = ['count of'] + [row_ref] + \ 447 | ['without null values in'] + arg_labels 448 | return labels, prep 449 | 450 | def _eq_nl(self, expression): 451 | """ Translate equality condition into natural language. """ 452 | return self._cmp(expression, 'equals') 453 | 454 | def _except_nl(self, expression): 455 | """ Translates SQL except expression into natural language. """ 456 | return self._set_operation(expression, 'From', 'remove', None) 457 | 458 | def _expressions(self, expression): 459 | """ Translates associated expressions into natural language. """ 460 | plan = NlPlan() 461 | labels = [] 462 | for expr in expression.args.get('expressions'): 463 | new_labels, prep = self.nl(expr) 464 | plan.add_plan(prep) 465 | labels += new_labels + [', '] 466 | return labels[:-1], plan 467 | 468 | def _from_nl(self, expression): 469 | """ Translates from clause into natural language description. """ 470 | last_label = None 471 | for expr in expression.args['expressions']: 472 | from_label, from_prep = self.nl(expr) 473 | if last_label is None: 474 | last_label = from_label 475 | plan = from_prep 476 | else: 477 | step = ['Join', last_label, 'with', from_label, '.'] 478 | last_label = plan.add_step(step) 479 | return last_label, plan 480 | 481 | def _gt_nl(self, expression): 482 | """ Translate greater than condition into natural language. """ 483 | return self._cmp(expression, 'is greater than') 484 | 485 | def _gte_nl(self, expression): 486 | """ Translate greater or equal into natural language. """ 487 | return self._cmp(expression, 'is greater or equal to') 488 | 489 | def _identifier_label(self, expression): 490 | """ Construct text label for identifier. """ 491 | label = expression.args.get('this') or '' 492 | if self.quote_ids or expression.args.get('quoted'): 493 | label = f"'{label}'" 494 | return label 495 | 496 | def _identifier_nl(self, expression): 497 | """ Express identifier (e.g., table name) in natural language. 
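
        Args:
            expression: identifier expression to describe

        Returns:
            single-element list with the identifier label, empty preparatory plan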
""" 498 | label = self._identifier_label(expression) 499 | return [label], NlPlan() 500 | 501 | def _in_nl(self, expression): 502 | """ Translate SQL IN expression into natural language. """ 503 | left_label, left_prep = self.nl(expression, 'this') 504 | right_label, right_prep = self.nl(expression, 'query') 505 | step = ['Check if'] + left_label + ['appears in'] + right_label 506 | plan = NlPlan() 507 | plan.add_plan(left_prep) 508 | plan.add_plan(right_prep) 509 | last_labels = [plan.add_step(step)] 510 | return last_labels, plan 511 | 512 | def _intersect_nl(self, expression): 513 | """ Translate set intersection into natural language. """ 514 | return self._set_operation( 515 | expression, 'Intersect', 'and', 516 | 'and eliminate duplicates') 517 | 518 | def _is_nl(self, expression): 519 | """ Translate SQL IS comparison into natural language. """ 520 | return self._cmp(expression, 'is') 521 | 522 | def _join_eq_label(self, expression): 523 | """ Translate equality join condition into natural language label. """ 524 | predicate = expression.args['on'] 525 | assert predicate.key == 'eq', 'No equality join predicate' 526 | left_op = predicate.args.get('this') 527 | right_op = predicate.args.get('expression') 528 | left_labels, _ = self._column_nl(left_op) 529 | right_labels, _ = self._column_nl(right_op) 530 | left_label = ' '.join(left_labels) 531 | right_label = ' '.join(right_labels) 532 | return left_label + ' (left) equals ' + right_label + ' (right)' 533 | 534 | def _join_nl(self, expression): 535 | """ Translates join expression into natural language. """ 536 | raise NotImplementedError 537 | 538 | def _literal_nl(self, expression): 539 | """ Translates a literal into natural language. """ 540 | text = expression.args.get('this') or '' 541 | try: 542 | float(text) 543 | is_string = False 544 | except: 545 | is_string = True 546 | # is_string = expression.args.get('is_string') 547 | if is_string: 548 | escape_code = sqlglot.tokens.Tokenizer.ESCAPE_CODE 549 | text = text.replace(escape_code, "'") 550 | text = f"'{text}'" 551 | if text == text.lower(): 552 | text = text + ' (all lowercase)' 553 | # text = text.replace("'", "''") 554 | return [text], NlPlan() 555 | else: 556 | return [text], NlPlan() 557 | 558 | def _like_nl(self, expression): 559 | """ Translate SQL LIKE into natural language. """ 560 | return self._cmp(expression, 'matches') 561 | 562 | def _lower_ids(self, expression): 563 | """ Lower references to databases, tables, and columns. """ 564 | def lower_id(node): 565 | """ Transforms unquoted identifiers to lower case. """ 566 | if node.key == 'identifier': 567 | value = node.args['this'] 568 | if not node.args['quoted']: 569 | value = value.lower() 570 | node.args['this'] = value 571 | return node 572 | 573 | return expression.transform( 574 | lambda n:lower_id(n)) 575 | 576 | def _lt_nl(self, expression): 577 | """ Translate less than comparison into natural language. """ 578 | return self._cmp(expression, 'is less than') 579 | 580 | def _lte_nl(self, expression): 581 | """ Translates less or equal than comparison into natural language. """ 582 | return self._cmp(expression, 'is less than or equal to') 583 | 584 | def _max_nl(self, expression): 585 | """ Translate maximum aggregate into natural language. """ 586 | return self._agg_nl(expression, 'maximum') 587 | 588 | def _min_nl(self, expression): 589 | """ Translate minimum aggregate into natural language. 
""" 590 | return self._agg_nl(expression, 'minimum') 591 | 592 | def _neg_nl(self, expression): 593 | """ Translates negation into natural language. """ 594 | labels, plan = self.nl(expression, 'this') 595 | return ['-'] + labels, plan 596 | 597 | def _neq_nl(self, expression): 598 | """ Translates inequality into natural language. """ 599 | return self._cmp(expression, 'is not equal to') 600 | 601 | def _not_nl(self, expression): 602 | """ Express negation in natural language. """ 603 | op_label, plan = self.nl(expression, 'this') 604 | step = ['Check if'] + op_label + ['is false'] 605 | last_labels = [plan.add_step(step)] 606 | return last_labels, plan 607 | 608 | def _null_nl(self, _): 609 | """ Express SQL NULL value in natural language. """ 610 | return 'unknown', NlPlan() 611 | 612 | def _ordered_nl(self, expression): 613 | """ Translates item in ORDER BY clause into natural language. """ 614 | last_labels, plan = self.nl(expression, 'this') 615 | is_desc = True if expression.args.get('desc') else False 616 | direction = '(descending)' if is_desc else '(ascending)' 617 | return last_labels + [direction], plan 618 | 619 | def _or_nl(self, expression): 620 | """ Translate logical or into natural language. """ 621 | return self._cmp(expression, 'or') 622 | 623 | def _paren_nl(self, expression): 624 | """ Translate parenthesis expression to natural language. """ 625 | return self.nl(expression, 'this') 626 | 627 | def _select_nl(self, expression): 628 | """ Generates natural language plan for select query. """ 629 | # Check for complex queries 630 | has_joins = expression.args['joins'] 631 | has_multi_select = len(expression.args['expressions']) > 1 632 | has_sub_queries = expression.sql().lower().count('select') > 1 633 | if has_joins or has_multi_select or has_sub_queries: 634 | return self._complex_select_nl(expression) 635 | 636 | else: 637 | from_labels, plan = self.nl(expression, 'from') 638 | last_labels = from_labels 639 | 640 | if expression.args.get('where'): 641 | where_expr = expression.args['where'].args['this'] 642 | where_labels, where_prep = self.nl(where_expr) 643 | plan.add_plan(where_prep) 644 | where_step = ['Filter'] + from_labels + ['using'] + where_labels 645 | last_labels = [plan.add_step(where_step)] 646 | 647 | if expression.args.get('group'): 648 | group_expr = expression.args['group'] 649 | group_labels, group_prep = self._expressions(group_expr) 650 | plan.add_plan(group_prep) 651 | group_step = \ 652 | ['Group rows from'] + last_labels + \ 653 | ['using'] + group_labels 654 | last_labels = [plan.add_step(group_step)] 655 | 656 | if expression.args.get('having'): 657 | having_expr = expression.args['having'].args['this'] 658 | having_labels, having_prep = self.nl(having_expr) 659 | plan.add_plan(having_prep) 660 | having_step = \ 661 | ['Filter groups from'] + last_labels + \ 662 | ['using'] + having_labels 663 | last_labels = [plan.add_step(having_step)] 664 | 665 | if expression.args.get('order'): 666 | order_expr = expression.args['order'] 667 | order_labels, order_prep = self._expressions(order_expr) 668 | plan.add_plan(order_prep) 669 | order_step = \ 670 | ['Order rows from'] + last_labels + \ 671 | ['using'] + order_labels 672 | last_labels = [plan.add_step(order_step)] 673 | 674 | if expression.args.get('limit'): 675 | limit_expr = expression.args['limit'].args['this'] 676 | limit_labels, limit_prep = self.nl(limit_expr) 677 | plan.add_plan(limit_prep) 678 | limit_step = \ 679 | ['Keep only'] + limit_labels + \ 680 | ['rows from'] + last_labels 681 | 
last_labels = [plan.add_step(limit_step)] 682 | 683 | select_labels, select_prep = self._expressions(expression) 684 | plan.add_plan(select_prep) 685 | select_step = ['Create table with columns'] + select_labels + \ 686 | ['from'] + last_labels 687 | last_labels = [plan.add_step(select_step)] 688 | 689 | if expression.args.get('distinct'): 690 | distinct_step = \ 691 | ['Only keep unique values in the rows from'] + \ 692 | last_labels 693 | last_labels = [plan.add_step(distinct_step)] 694 | 695 | return last_labels, plan 696 | 697 | def _set_operation(self, expression, prefix, connector, postfix): 698 | """ Translate set expression into natural language. 699 | 700 | Args: 701 | expression: SQL expression representing set operation 702 | prefix: start final step with this text 703 | connector: text between references to operands 704 | postfix: end final step with this text 705 | 706 | Returns: 707 | label and preparatory plan 708 | """ 709 | left_label, left_prep = self.nl(expression, 'this') 710 | right_label, right_prep = self.nl(expression, 'expression') 711 | plan = NlPlan() 712 | plan.add_plan(left_prep) 713 | plan.add_plan(right_prep) 714 | step = [prefix] + left_label + [connector] + right_label 715 | if postfix is not None: 716 | step += [postfix] 717 | last_labels = [plan.add_step(step)] 718 | return last_labels, plan 719 | 720 | def _simplify_plan(self, plan): 721 | """ Try to simplify given plan by merging steps. 722 | 723 | Args: 724 | plan: try to reduce number of steps in this plan 725 | 726 | Returns: 727 | simplified plan 728 | """ 729 | id_to_step = plan.id_to_step() 730 | step_ref_counts = plan.step_ref_counts() 731 | for step_id, step in plan.id_steps: 732 | new_step = [] 733 | for part in step: 734 | if isinstance(part, int): 735 | ref_step = id_to_step[part] 736 | if ref_step[0] == 'Check if': 737 | new_step += ref_step[1:] 738 | step_ref_counts[part] -= 1 739 | if step_ref_counts[part] == 0: 740 | del id_to_step[part] 741 | continue 742 | 743 | new_step += [part] 744 | id_to_step[step_id] = new_step 745 | id_steps = sorted(id_to_step.items(), key=lambda t:t[0]) 746 | 747 | new_plan = NlPlan() 748 | new_plan.id_steps = id_steps 749 | return new_plan 750 | 751 | def _star_nl(self, _): 752 | """ Translates star into natural language. """ 753 | return ['all columns'], NlPlan() 754 | 755 | def _strip_tables(self, expression): 756 | """ Recursively strips table references from expression. 757 | 758 | Args: 759 | expression: strip table references from this expression 760 | 761 | Returns: 762 | expression without table references 763 | """ 764 | def column_without_table(expression): 765 | """ Removes tables from column references. """ 766 | if expression.key == 'column': 767 | if 'table' in expression.args: 768 | del expression.args['table'] 769 | return expression 770 | 771 | if isinstance(expression, str): 772 | return expression 773 | else: 774 | return expression.transform( 775 | lambda n:column_without_table(n)) 776 | 777 | def _sum_nl(self, expression): 778 | """ Translate sum aggregate into natural language. """ 779 | return self._agg_nl(expression, 'sum') 780 | 781 | def _table_nl(self, expression): 782 | """ Describe table in natural language. """ 783 | table_labels, plan = self.nl(expression, 'this') 784 | step = ['Load data for table'] + table_labels 785 | last_labels = [plan.add_step(step)] 786 | return last_labels, plan 787 | 788 | def _tables(self, expression): 789 | """ Returns set of tables mentioned in expression. 
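
        Args:
            expression: expression (or list of expressions) to search for table references

        Returns:
            set of text labels for all referenced tables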
""" 790 | tables = set() 791 | if isinstance(expression, list): 792 | for element in expression: 793 | tables.update(self._tables(element)) 794 | 795 | elif isinstance(expression, sqlglot.expressions.Expression): 796 | if expression.key == 'table': 797 | tbl_expr = expression.args.get('this') 798 | tbl_label = self._identifier_label(tbl_expr) 799 | tables.add(tbl_label) 800 | else: 801 | for k, v in expression.args.items(): 802 | if k == 'table' and v is not None: 803 | tbl_label = self._identifier_label(v) 804 | tables.add(tbl_label) 805 | else: 806 | tables.update(self._tables(v)) 807 | 808 | return tables 809 | 810 | def _unary_predicates(self, expression): 811 | """ Map tables to their unary predicates. """ 812 | tbl_to_preds = collections.defaultdict(lambda:[]) 813 | if expression.args.get('where'): 814 | where_expr = expression.args['where'].args['this'] 815 | conjuncts = self._conjuncts(where_expr) 816 | for conjunct in conjuncts: 817 | tables = self._tables(conjunct) 818 | assert len(tables) == 1, 'Only unary predicates!' 819 | table = tables.pop() 820 | tbl_to_preds[table] += [conjunct] 821 | 822 | return tbl_to_preds 823 | 824 | def _union_nl(self, expression): 825 | """ Translates union into natural language. """ 826 | distinct = expression.args.get('distinct') 827 | drop_duplicates = True if distinct is not None else False 828 | postfix = 'and eliminate duplicates' if drop_duplicates else None 829 | return self._set_operation(expression, 'Form union of', 'and', postfix) 830 | 831 | 832 | if __name__ == '__main__': 833 | 834 | with open('/Users/immanueltrummer/benchmarks/spider/results_dev.json') as file: 835 | test_cases = json.load(file) 836 | 837 | planner = NlPlanner(False) 838 | for idx, test_case in enumerate(test_cases[0:200:2]): 839 | db_id = test_case['db_id'] 840 | query = test_case['query'] 841 | print('-----------------------') 842 | print(f'Q{idx}: {db_id}/{query}') 843 | plan = planner.plan(query) 844 | for step in plan.steps(): 845 | print(step) 846 | 847 | # query = "SELECT count(*) FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid WHERE T1.age > 20" 848 | # query = "SELECT T1.CountryName FROM COUNTRIES AS T1 JOIN CONTINENTS AS T2 ON T1.Continent = T2.ContId JOIN CAR_MAKERS AS T3 ON T1.CountryId = T3.Country WHERE T2.Continent = 'europe' GROUP BY T1.CountryName HAVING count(*) >= 3" 849 | #query = "select count(*) from ta as a join tb as b on (a.x=b.x) where a.c = 1 and a.d = 2 and (b.i=1 or b.j=2)" 850 | # query = "SELECT T2.name FROM singer_in_concert AS T1 JOIN singer AS T2 ON T1.singer_id = T2.singer_id JOIN concert AS T3 ON T1.concert_id = T3.concert_id WHERE T3.year = 2014" 851 | # query = "SELECT DISTINCT T1.Fname FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid = T2.stuid JOIN pets AS T3 ON T3.petid = T2.petid WHERE T3.pettype = 'cat' OR T3.pettype = 'dog'" 852 | #query = "SELECT count(DISTINCT pettype) FROM pets" 853 | query = "select count(*) , t1.stuid from student as t1 join has_pet as t2 on t1.stuid = t2.stuid group by t1.stuid" 854 | planner = NlPlanner(False) 855 | plan = planner.plan(query) 856 | for step in plan.steps(): 857 | print(step) 858 | 859 | # with open('/Users/immanueltrummer/benchmarks/WikiSQL/data/results_test.json') as file: 860 | # test_cases = json.load(file) 861 | # 862 | # 863 | # for idx, test_case in enumerate(test_cases): 864 | # print(f'Idx: {idx}') 865 | # question = test_case['question'] 866 | # query = test_case['query'] 867 | # print(f'Question: {question}') 868 | # print(f'Query: {query}') 869 | # label, plan 
= planner.plan(query) 870 | # print(plan.steps()) --------------------------------------------------------------------------------