├── LICENSE ├── README.md ├── CONTRIBUTING.md ├── .gitignore ├── evaluate.py ├── CODE_OF_CONDUCT.md ├── data.py ├── execution.py ├── env.yml ├── collectors.py ├── process_sql.py ├── sample_selectors.py └── utils_sql.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Freda Shi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Natural Language to Code Translation with Execution 2 | [Freda Shi](mailto:freda@ttic.edu), Daniel Fried, Marjan Ghazvininejad, Luke Zettlemoyer, [Sida I. Wang](mailto:sida@fb.com) 3 | 4 | ## Setup 5 | 1. 
Download the [MBPP](https://github.com/google-research/google-research/tree/master/mbpp), [Spider](https://yale-lily.github.io/spider), and [NL2Bash](https://github.com/TellinaTool/nl2bash) datasets to `data/` and follow their instructions for the necessary preprocessing steps. 6 | 2. Download our [collected Codex data](https://dl.fbaipublicfiles.com/mbr-exec/mbr-exec-release.zip). We have included the pre-executed results with the data; see also `execution.py` if you'd like to execute the automatically collected code locally. 7 | 3. Create the `conda` environment with 8 | ```bash 9 | conda env create -f env.yml 10 | ``` 11 | --- 12 | ## Run the Selector 13 | Suppose the collected Codex data is located at `data/mbr-exec/`. The following code returns the execution accuracy of code selected on the MBPP `test` split from `5` samples collected from Codex with temperature `0.3`, using our `mbr_exec` method and random seed `0`: 14 | 15 | ```python 16 | from sample_selectors import select_mbpp 17 | select_mbpp(('test', 0.3, 'mbr_exec', 'data/mbr-exec/mbpp/', 5, 0)) 18 | ``` 19 | 20 | See the code for more details. 21 | 22 | 23 | 24 | --- 25 | ## License 26 | MIT 27 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Our Development Process 6 | ... (in particular how this is synced with internal changes to the project) 7 | 8 | ## Pull Requests 9 | We actively welcome your pull requests. 10 | 11 | 1. Fork the repo and create your branch from `main`. 12 | 2. If you've added code that should be tested, add tests. 13 | 3. If you've changed APIs, update the documentation. 14 | 4. Ensure the test suite passes. 15 | 5. Make sure your code lints. 16 | 6. 
If you haven't already, complete the Contributor License Agreement ("CLA"). 17 | 18 | ## Contributor License Agreement ("CLA") 19 | In order to accept your pull request, we need you to submit a CLA. You only need 20 | to do this once to work on any of Facebook's open source projects. 21 | 22 | Complete your CLA here: 23 | 24 | ## Issues 25 | We use GitHub issues to track public bugs. Please ensure your description is 26 | clear and has sufficient instructions to be able to reproduce the issue. 27 | 28 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 29 | disclosure of security bugs. In those cases, please go through the process 30 | outlined on that page and do not file a public issue. 31 | 32 | ## Coding Style 33 | * 2 spaces for indentation rather than tabs 34 | * 80 character line length 35 | * ... 36 | 37 | ## License 38 | By contributing, you agree that your contributions will be licensed 39 | under the LICENSE file in the root directory of this source tree. 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import os 4 | import tempfile 5 | from datasets import load_metric 6 | from tqdm import tqdm 7 | 8 | from data import MBPPGoogleDataset 9 | from execution import Command 10 | 11 | 12 | """ dataset keys: src, trg_prediction, reference """ 13 | def evaluate_charbleu(dataset): 14 | bleu = load_metric('bleu') 15 | predictions = [[ch for ch in item['trg_prediction']] for item in dataset] 16 | references = [[[ch for ch in item['reference']]] for item in dataset] 17 | return bleu.compute(predictions=predictions, references=references) 18 | 19 | 20 | """ dataset keys: src, trg_prediction, reference (only trg_prediction useful) """ 21 | def evaluate_spider(dataset, gold_path, eval_mode='all'): 22 | tempdir = tempfile.TemporaryDirectory(prefix='spider-') 23 | pred_path = f'{tempdir.name}/preds.sql' 24 | result_path = f'{tempdir.name}/results.out' 25 | with open(pred_path, 'w') as fout: 26 | for item in dataset: 27 | print(item['trg_prediction'], file=fout) 28 | fout.close() 29 | os.system( 30 | f'''python ~/codebase/nl2code-expr/v2code/spider_official/evaluation.py --gold {gold_path} \ 31 | --pred {pred_path} \ 32 | --etype {eval_mode} \ 33 | --db 
~/data/spider/database/ \ 34 | --table ~/data/spider/tables.json > {result_path} 35 | ''' 36 | ) 37 | results = open(result_path).readlines() 38 | tempdir.cleanup() 39 | return results 40 | 41 | 42 | """ dataset keys: src, trg_prediction, reference (only trg_prediction useful) """ 43 | def evaluate_google_mbpp(dataset, reference_path, split='test', timeout=10, return_details=False): 44 | references = MBPPGoogleDataset(reference_path) 45 | assert len(dataset) == len(references.raw_data[split]) 46 | tempdir = tempfile.TemporaryDirectory() 47 | passed_information = list() 48 | pbar = tqdm(references.raw_data[split]) 49 | for i, item in enumerate(pbar): 50 | if 'execution_result_full_pass' in dataset[i]: 51 | passed_information.append(int(all(x[1] == True for x in dataset[i]['execution_result_full_pass']))) 52 | else: 53 | test_cases = item['test_list'] 54 | test_setups = item['test_setup_code'] 55 | code = dataset[i]['trg_prediction'] 56 | # write code to file 57 | with open(f'{tempdir.name}/code.py', 'w') as fout: 58 | print(code, file=fout) 59 | print(test_setups, file=fout) 60 | for case in test_cases: 61 | print(case, file=fout) 62 | fout.close() 63 | command = Command(f'python {tempdir.name}/code.py >/dev/null 2>&1') 64 | execution_result = (command.run(timeout=timeout) == 0) 65 | passed_information.append(int(execution_result)) 66 | pbar.set_description(f'{sum(passed_information)} out of {i+1} passed.') 67 | tempdir.cleanup() 68 | if return_details: 69 | return passed_information 70 | else: 71 | return sum(passed_information) / len(passed_information) 72 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Open Source Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to make participation in our project and our community a 
harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. 6 | 7 | ## Our Standards 8 | 9 | Examples of behavior that contributes to creating a positive environment include: 10 | 11 | * Using welcoming and inclusive language 12 | * Being respectful of differing viewpoints and experiences 13 | * Gracefully accepting constructive criticism 14 | * Focusing on what is best for the community 15 | * Showing empathy towards other community members 16 | Examples of unacceptable behavior by participants include: 17 | 18 | * The use of sexualized language or imagery and unwelcome sexual attention or advances 19 | * Trolling, insulting/derogatory comments, and personal or political attacks 20 | * Public or private harassment 21 | * Publishing others’ private information, such as a physical or electronic address, without explicit permission 22 | * Other conduct which could reasonably be considered inappropriate in a professional setting 23 | 24 | ## Our Responsibilities 25 | 26 | Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. 27 | 28 | Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. 29 | 30 | ## Scope 31 | 32 | This Code of Conduct applies within all project spaces, and it also applies when an individual is representing the project or its community in public spaces. 
Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. 33 | 34 | ## Enforcement 35 | 36 | Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at opensource-conduct@fb.com. All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. 37 | 38 | Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project’s leadership. 39 | 40 | ## Attribution 41 | 42 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 43 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 44 | 45 | [homepage]: https://www.contributor-covenant.org 46 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | 3 | import collections 4 | import json 5 | import os 6 | import regex 7 | 8 | 9 | class NL2BashDataset(object): 10 | def __init__(self, path='data/nl2bash/data/bash'): 11 | self.data = collections.defaultdict() 12 | for split in ['train', 'dev', 'test']: 13 | nls = [x.strip() for x in open(os.path.join(path, f'{split}.nl.filtered'))] 14 | cms = [x.strip() for x in open(os.path.join(path, f'{split}.cm.filtered'))] 15 | self.data[split] = list(zip(nls, cms)) 16 | 17 | 18 | class SpiderDataset(object): 19 | def __init__(self, path='data/spider'): 20 | self.data = collections.defaultdict() 21 | self.dbs = json.load(open(f'{path}/tables.json')) 22 | self.id2db = {item['db_id']: item for item in self.dbs} 23 | for split in ['train', 'dev']: 24 | split_fname = 'train_spider' if split == 'train' else split 25 | data = json.load(open(f'{path}/{split_fname}.json')) 26 | nls = [x['question'] for x in data] 27 | cms = [x['query'] for x in data] 28 | db_info = [self.extract_db_info(x['db_id']) for x in data] 29 | self.data[split] = list(zip(nls, cms, db_info)) 30 | 31 | def extract_db_info(self, db_id): 32 | db = self.id2db[db_id] 33 | id2table = {i: table_name for i, table_name in enumerate(db['table_names_original'])} 34 | info = f'{db_id} ' 35 | used_table_id = set() 36 | for table_id, column_name in db['column_names_original']: 37 | if table_id == -1: 38 | info += f'| {column_name} ' 39 | elif table_id not in used_table_id: 40 | info += f'| {id2table[table_id]} : {column_name} ' 41 | used_table_id.add(table_id) 42 | else: 43 | info += f', {column_name} ' 44 | return info.strip() 45 | 46 | 47 | class MBPPGoogleDataset(object): 48 | def __init__(self, path='data/mbpp/mbpp.jsonl', mode='function_name'): 49 | raw_data = sorted([json.loads(x) for x in open(path)], key=lambda x: x['task_id']) 50 | for i, data_item in enumerate(raw_data): 51 | assert data_item['task_id'] == i + 1 52 | self.raw_data = collections.defaultdict() 53 | self.mode = mode 54 | # 374 for training, 100 
heldout, 500 test 55 | self.raw_data['train'] = raw_data[:10] + raw_data[510:] 56 | self.raw_data['test'] = raw_data[10:510] 57 | # data for codex collector, in input-output-info format 58 | self.data = collections.defaultdict() 59 | for split in self.raw_data: 60 | self.data[split] = self.extract_data(self.raw_data[split], mode) 61 | 62 | @staticmethod 63 | def extract_data(raw_data, mode): 64 | if mode == 'function_name': 65 | get_function_name = lambda test_example: regex.match('assert [\(]*([^\(]+)\(', test_example).group(1) 66 | info = [get_function_name(x['test_list'][0]) for x in raw_data] 67 | elif mode == 'assertion': 68 | info = [x['test_list'][0] for x in raw_data] 69 | elif mode == 'assertion-full': 70 | info = [x['test_list'] for x in raw_data] 71 | else: 72 | raise Exception(f'Mode {mode} not supported.') 73 | nls = [x['text'] for x in raw_data] 74 | codes = [x['code'] for x in raw_data] 75 | return list(zip(nls, codes, info)) 76 | -------------------------------------------------------------------------------- /execution.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | 3 | import bashlex 4 | import json 5 | import os 6 | import pickle 7 | import regex 8 | import signal 9 | import subprocess 10 | import tempfile 11 | import threading 12 | from datasets import load_metric 13 | from glob import glob 14 | from nltk.translate.bleu_score import sentence_bleu 15 | from tqdm import tqdm 16 | 17 | from data import MBPPGoogleDataset 18 | from utils_sql import * 19 | 20 | 21 | class Command(object): 22 | def __init__(self, cmd): 23 | self.cmd = cmd 24 | self.process = None 25 | 26 | def run(self, timeout): 27 | def target(): 28 | self.process = subprocess.Popen(self.cmd, shell=True, preexec_fn=os.setsid) 29 | self.process.communicate() 30 | 31 | thread = threading.Thread(target=target) 32 | thread.start() 33 | 34 | thread.join(timeout) 35 | if thread.is_alive(): 36 | os.killpg(self.process.pid, signal.SIGTERM) 37 | thread.join() 38 | return self.process.returncode 39 | 40 | 41 | class PythonFunctionExecutor(object): 42 | def __init__(self, function_content, function_call, timeout=10): 43 | self.function_content = function_content 44 | self.function_call = function_call 45 | self.timeout = timeout 46 | 47 | def __call__(self): 48 | tempdir = tempfile.TemporaryDirectory() 49 | with open(f'{tempdir.name}/code.py', 'w') as fout: 50 | print(self.function_content, file=fout) 51 | print(f'result = {self.function_call}', file=fout) 52 | print(f'import pickle', file=fout) 53 | print(f'pickle.dump(result, open("{tempdir.name}/execution_result.pkl", "wb"))', file=fout) 54 | command = Command(f'python {tempdir.name}/code.py >/dev/null 2>&1') 55 | execution_status = command.run(timeout=self.timeout) 56 | if execution_status == 0: 57 | try: 58 | execution_results = pickle.load(open(f'{tempdir.name}/execution_result.pkl', 'rb')) 59 | except: 60 | execution_results = None 61 | else: 62 | execution_results = None 63 | tempdir.cleanup() 64 | return execution_status, execution_results 65 | 66 | 67 | def execute_mbpp_google_folder(base_path): 68 | # single 
assertion 69 | dataset = MBPPGoogleDataset(mode='assertion') 70 | for path in glob(f'{base_path}/*jsonl'): # execute first assertion call 71 | if os.path.exists(path.replace('jsonl', 'exec.pkl')): 72 | continue 73 | split = os.path.basename(path).split('-')[0] 74 | execution_results = list() 75 | for i, line in enumerate(tqdm(open(path).readlines())): 76 | assertion = dataset.data[split][i][-1] 77 | command = regex.match(f'assert (.+)==.+', assertion).group(1) 78 | item = json.loads(line) 79 | python_function = item['trg_prediction'] 80 | executor = PythonFunctionExecutor(python_function, command) 81 | execution_result = executor() 82 | execution_results.append(execution_result) 83 | with open(path.replace('jsonl', 'exec.pkl'), 'wb') as fout: 84 | pickle.dump(execution_results, fout) 85 | # multiple assertions (cheating) 86 | dataset = MBPPGoogleDataset(mode='assertion-full') 87 | for path in glob(f'{base_path}/*jsonl'): # execute all assertion calls 88 | if os.path.exists(path.replace('jsonl', 'execfull.pkl')): 89 | continue 90 | split = os.path.basename(path).split('-')[0] 91 | execution_results = list() 92 | for i, line in enumerate(tqdm(open(path).readlines())): 93 | execution_result = list() 94 | item = json.loads(line) 95 | python_function = item['trg_prediction'] 96 | for assertion in dataset.data[split][i][-1]: 97 | command = regex.match(f'assert (.+)==.+', assertion).group(1) 98 | executor = PythonFunctionExecutor(python_function, command) 99 | execution_result.append(executor()) 100 | execution_results.append(execution_result) 101 | with open(path.replace('jsonl', 'execfull.pkl'), 'wb') as fout: 102 | pickle.dump(execution_results, fout) 103 | # multiple assertions (pass or fail) 104 | for path in glob(f'{base_path}/*jsonl'): 105 | if os.path.exists(path.replace('jsonl', 'execfullpass.pkl')): 106 | continue 107 | split = os.path.basename(path).split('-')[0] 108 | execution_results = list() 109 | for i, line in enumerate(tqdm(open(path).readlines())): 110 
| execution_result = list() 111 | item = json.loads(line) 112 | python_function = item['trg_prediction'] 113 | for assertion in dataset.data[split][i][-1]: 114 | command = regex.match(f'assert (.+==.+)', assertion).group(1) 115 | executor = PythonFunctionExecutor(python_function, f'({command})') 116 | execution_result.append(executor()) 117 | execution_results.append(execution_result) 118 | with open(path.replace('jsonl', 'execfullpass.pkl'), 'wb') as fout: 119 | pickle.dump(execution_results, fout) 120 | 121 | 122 | def execute_spider_folder( 123 | base_path, 124 | db_path='data/spider/database', 125 | gold_path='data/spider', 126 | table_path='data/spider/tables.json', 127 | timeout=10 128 | ): 129 | kmaps = build_foreign_key_map_from_json(table_path) 130 | for path in glob(f'{base_path}/*jsonl'): 131 | if os.path.exists(path.replace('jsonl', 'exec.pkl')): 132 | continue 133 | execution_results = list() 134 | split = os.path.basename(path).split('-')[0] 135 | file_gold_path = f'{gold_path}/{split}_gold.sql' 136 | with open(file_gold_path) as f: 137 | glist = [l.strip().split('\t') for l in f if len(l.strip()) > 0] 138 | with open(path) as f: 139 | plist = [json.loads(l)['trg_prediction'] for l in f] 140 | for p_str, (_, db_name) in tqdm(list(zip(plist, glist))): 141 | db = os.path.join(db_path, db_name, db_name + ".sqlite") 142 | schema = Schema(get_schema(db)) 143 | try: 144 | p_sql = get_sql(schema, p_str) 145 | except: 146 | # If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql 147 | p_sql = { 148 | "except": None, 149 | "from": { 150 | "conds": [], 151 | "table_units": [] 152 | }, 153 | "groupBy": [], 154 | "having": [], 155 | "intersect": None, 156 | "limit": None, 157 | "orderBy": [], 158 | "select": [ 159 | False, 160 | [] 161 | ], 162 | "union": None, 163 | "where": [] 164 | } 165 | # rebuild sql for value evaluation 166 | kmap = kmaps[db_name] 167 | p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], 
schema) 168 | p_sql = rebuild_sql_val(p_sql) 169 | p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap) 170 | execution_result = execute(db, p_str, p_sql, timeout) 171 | execution_results.append(execution_result) 172 | with open(path.replace('jsonl', 'exec.pkl'), 'wb') as fout: 173 | pickle.dump(execution_results, fout) 174 | 175 | 176 | def simulate_bash_exec(command): 177 | return list(bashlex.split(command)) 178 | -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: mbr-exec 2 | channels: 3 | - fastai 4 | - pytorch 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=conda_forge 9 | - _openmp_mutex=4.5=1_llvm 10 | - _pytorch_select=0.1=cpu_0 11 | - alsa-lib=1.2.3=h516909a_0 12 | - argcomplete=1.12.3=pyhd3eb1b0_0 13 | - argon2-cffi=20.1.0=py37h27cfd23_1 14 | - astroid=2.3.3=py37_0 15 | - async_generator=1.10=py37h28b3542_0 16 | - attrs=21.2.0=pyhd3eb1b0_0 17 | - backcall=0.2.0=pyhd3eb1b0_0 18 | - blas=1.0=mkl 19 | - bleach=4.0.0=pyhd3eb1b0_0 20 | - bottleneck=1.3.2=py37heb32a55_1 21 | - brotli=1.0.9=he6710b0_2 22 | - brotlipy=0.7.0=py37h27cfd23_1003 23 | - ca-certificates=2022.2.1=h06a4308_0 24 | - certifi=2021.10.8=py37h06a4308_2 25 | - cffi=1.14.6=py37h036bc23_1 26 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 27 | - cryptography=3.4.7=py37hd23ed53_0 28 | - cudatoolkit=11.0.221=h6bb024c_0 29 | - cycler=0.10.0=py37_0 30 | - cymem=1.31.2=py37h6bb024c_0 31 | - cyrus-sasl=2.1.27=h230043b_4 32 | - cytoolz=0.9.0.1=py37h14c3975_1 33 | - dbus=1.13.18=hb2f20db_0 34 | - debugpy=1.4.1=py37h295c915_0 35 | - decorator=5.0.9=pyhd3eb1b0_0 36 | - defusedxml=0.7.1=pyhd3eb1b0_0 37 | - entrypoints=0.3=py37_0 38 | - expat=2.4.1=h2531618_2 39 | - faiss-gpu=1.7.1=py3.7_h293177f_1_cuda11.0 40 | - fastcore=1.3.26=py_0 41 | - fontconfig=2.13.1=h6c09931_0 42 | - fonttools=4.25.0=pyhd3eb1b0_0 43 | - freetype=2.10.4=h5ab3b9f_0 44 | - 
gettext=0.19.8.1=h73d1719_1008 45 | - ghapi=0.1.19=py_0 46 | - glib=2.70.0=h780b84a_1 47 | - glib-tools=2.70.0=h780b84a_1 48 | - gst-plugins-base=1.18.5=hf529b03_0 49 | - gstreamer=1.18.5=h76c114f_0 50 | - icu=68.1=h58526e2_0 51 | - idna=3.2=pyhd3eb1b0_0 52 | - importlib_metadata=3.10.0=hd3eb1b0_0 53 | - intel-openmp=2021.3.0=h06a4308_3350 54 | - ipykernel=6.2.0=py37h06a4308_1 55 | - ipython=7.26.0=py37hb070fc8_0 56 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 57 | - ipywidgets=7.6.3=pyhd3eb1b0_1 58 | - isort=4.3.21=py37_0 59 | - jbig=2.1=h7f98852_2003 60 | - jedi=0.18.0=py37h06a4308_1 61 | - jinja2=3.0.1=pyhd3eb1b0_0 62 | - jpeg=9d=h36c2ea0_0 63 | - jsonschema=3.2.0=pyhd3eb1b0_2 64 | - jupyter=1.0.0=py37_7 65 | - jupyter_client=7.0.1=pyhd3eb1b0_0 66 | - jupyter_console=6.4.0=pyhd3eb1b0_0 67 | - jupyter_core=4.7.1=py37h06a4308_0 68 | - jupyterlab_pygments=0.1.2=py_0 69 | - jupyterlab_widgets=1.0.0=pyhd3eb1b0_1 70 | - kiwisolver=1.3.1=py37h2531618_0 71 | - krb5=1.19.2=hcc1bbae_2 72 | - lazy-object-proxy=1.6.0=py37h27cfd23_0 73 | - lcms2=2.12=h3be6417_0 74 | - ld_impl_linux-64=2.35.1=h7274673_9 75 | - lerc=3.0=h9c3ff4c_0 76 | - libclang=11.1.0=default_ha53f305_1 77 | - libdeflate=1.8=h7f98852_0 78 | - libedit=3.1.20191231=he28a2e2_2 79 | - libevent=2.1.10=h9b69904_4 80 | - libfaiss=1.7.1=h7f34bec_1_cuda11.0 81 | - libffi=3.4.2=h9c3ff4c_4 82 | - libgcc-ng=11.2.0=h1d223b6_11 83 | - libgfortran-ng=7.5.0=ha8ba4b0_17 84 | - libgfortran4=7.5.0=ha8ba4b0_17 85 | - libglib=2.70.0=h174f98d_1 86 | - libiconv=1.16=h516909a_0 87 | - libllvm11=11.1.0=hf817b99_2 88 | - libnsl=2.0.0=h7f98852_0 89 | - libntlm=1.4=h7f98852_1002 90 | - libogg=1.3.4=h7f98852_1 91 | - libopus=1.3.1=h7f98852_1 92 | - libpng=1.6.37=hbc83047_0 93 | - libpq=13.3=hd57d9b9_1 94 | - libprotobuf=3.18.1=h780b84a_0 95 | - libsodium=1.0.18=h7b6447c_0 96 | - libstdcxx-ng=11.2.0=he4da1e4_11 97 | - libtiff=4.3.0=h6f004c6_2 98 | - libuuid=1.0.3=h1bed415_2 99 | - libuv=1.40.0=h7b6447c_0 100 | - libvorbis=1.3.7=h9c3ff4c_0 101 
| - libwebp-base=1.2.0=h27cfd23_0 102 | - libxcb=1.14=h7b6447c_0 103 | - libxkbcommon=1.0.3=he3ba5ed_0 104 | - libxml2=2.9.12=h72842e0_0 105 | - libzlib=1.2.11=h36c2ea0_1013 106 | - llvm-openmp=12.0.1=h4bd325d_1 107 | - lz4-c=1.9.3=h295c915_1 108 | - markupsafe=2.0.1=py37h27cfd23_0 109 | - matplotlib=3.4.2=py37h06a4308_0 110 | - matplotlib-base=3.4.2=py37hab158f2_0 111 | - matplotlib-inline=0.1.2=pyhd3eb1b0_2 112 | - mccabe=0.6.1=py37_1 113 | - mistune=0.8.4=py37h14c3975_1001 114 | - mkl=2021.3.0=h06a4308_520 115 | - mkl-service=2.4.0=py37h7f8727e_0 116 | - mkl_fft=1.3.0=py37h42c9631_2 117 | - mkl_random=1.2.2=py37h51133e4_0 118 | - mscorefonts=0.0.1=3 119 | - msgpack-numpy=0.4.7.1=pyhd3eb1b0_0 120 | - msgpack-python=0.5.6=py37h6bb024c_1 121 | - munkres=1.1.4=py_0 122 | - murmurhash=0.28.0=py37hf484d3e_0 123 | - mysql-client=8.0.27=hf09c6a7_0 124 | - mysql-common=8.0.27=ha770c72_0 125 | - mysql-devel=8.0.27=ha770c72_0 126 | - mysql-libs=8.0.27=hfa10184_0 127 | - mysql-server=8.0.27=h1069331_0 128 | - nbclient=0.5.3=pyhd3eb1b0_0 129 | - nbconvert=6.1.0=py37h06a4308_0 130 | - nbformat=5.1.3=pyhd3eb1b0_0 131 | - ncurses=6.2=he6710b0_1 132 | - nest-asyncio=1.5.1=pyhd3eb1b0_0 133 | - notebook=6.4.3=py37h06a4308_0 134 | - nspr=4.30=h9c3ff4c_0 135 | - nss=3.69=hb5efdd6_1 136 | - numexpr=2.7.3=py37h22e1b3c_1 137 | - numpy=1.20.3=py37hf144106_0 138 | - numpy-base=1.20.3=py37h74d4b33_0 139 | - olefile=0.46=py37_0 140 | - openjpeg=2.4.0=h3ad879b_0 141 | - openssl=1.1.1m=h7f8727e_0 142 | - packaging=21.0=pyhd3eb1b0_0 143 | - pandocfilters=1.4.3=py37h06a4308_1 144 | - parso=0.8.2=pyhd3eb1b0_0 145 | - pcre=8.45=h295c915_0 146 | - pexpect=4.8.0=pyhd3eb1b0_3 147 | - pickleshare=0.7.5=pyhd3eb1b0_1003 148 | - pip=21.0.1=py37h06a4308_0 149 | - plac=0.9.6=py37_1 150 | - preshed=1.0.1=py37he6710b0_0 151 | - prometheus_client=0.11.0=pyhd3eb1b0_0 152 | - prompt-toolkit=3.0.17=pyh06a4308_0 153 | - prompt_toolkit=3.0.17=hd3eb1b0_0 154 | - ptyprocess=0.7.0=pyhd3eb1b0_2 155 | - 
pycparser=2.20=py_2 156 | - pygments=2.10.0=pyhd3eb1b0_0 157 | - pylint=2.4.4=py37_0 158 | - pyopenssl=20.0.1=pyhd3eb1b0_1 159 | - pyparsing=2.4.7=pyhd3eb1b0_0 160 | - pyqt=5.12.3=py37h89c1867_7 161 | - pyqt-impl=5.12.3=py37he336c9b_7 162 | - pyqt5-sip=4.19.18=py37hcd2ae1e_7 163 | - pyqtchart=5.12=py37he336c9b_7 164 | - pyqtwebengine=5.12.1=py37he336c9b_7 165 | - pyrsistent=0.17.3=py37h7b6447c_0 166 | - pysocks=1.7.1=py37_1 167 | - python=3.7.12=hb7a2778_100_cpython 168 | - python-dateutil=2.8.2=pyhd3eb1b0_0 169 | - python_abi=3.7=2_cp37m 170 | - pytorch=1.7.1=py3.7_cuda11.0.221_cudnn8.0.5_0 171 | - pytz=2021.1=pyhd3eb1b0_0 172 | - pyzmq=22.2.1=py37h295c915_1 173 | - qt=5.12.9=hda022c4_4 174 | - qtconsole=5.1.0=pyhd3eb1b0_0 175 | - qtpy=1.10.0=pyhd3eb1b0_0 176 | - readline=8.1=h27cfd23_0 177 | - requests=2.26.0=pyhd3eb1b0_0 178 | - scipy=1.7.1=py37h292c36d_2 179 | - seaborn=0.11.2=pyhd3eb1b0_0 180 | - send2trash=1.5.0=pyhd3eb1b0_1 181 | - setuptools=52.0.0=py37h06a4308_0 182 | - sip=4.19.8=py37hf484d3e_0 183 | - six=1.16.0=pyhd3eb1b0_0 184 | - spacy=2.0.12=py37h962f231_0 185 | - sqlite=3.36.0=hc218d9a_0 186 | - termcolor=1.1.0=py37h06a4308_1 187 | - terminado=0.9.4=py37h06a4308_0 188 | - testpath=0.5.0=pyhd3eb1b0_0 189 | - thinc=6.10.3=py37h962f231_0 190 | - tk=8.6.11=h27826a3_1 191 | - toolz=0.11.1=pyhd3eb1b0_0 192 | - torchaudio=0.7.2=py37 193 | - torchvision=0.8.2=cpu_py37ha229d99_0 194 | - tornado=6.1=py37h27cfd23_0 195 | - tqdm=4.62.2=pyhd3eb1b0_1 196 | - traitlets=5.0.5=pyhd3eb1b0_0 197 | - typing_extensions=3.10.0.0=pyh06a4308_0 198 | - ujson=4.0.2=py37h2531618_0 199 | - urllib3=1.26.6=pyhd3eb1b0_1 200 | - wcwidth=0.2.5=py_0 201 | - webencodings=0.5.1=py37_1 202 | - wheel=0.37.0=pyhd3eb1b0_0 203 | - widgetsnbextension=3.5.1=py37_0 204 | - wrapt=1.10.11=py37h14c3975_2 205 | - xz=5.2.5=h7b6447c_0 206 | - zeromq=4.3.4=h2531618_0 207 | - zipp=3.5.0=pyhd3eb1b0_0 208 | - zlib=1.2.11=h36c2ea0_1013 209 | - zstd=1.5.0=ha95c52a_0 210 | - pip: 211 | - absl-py==1.0.0 
212 | - aiohttp==3.7.4.post0 213 | - antlr4-python3-runtime==4.8 214 | - async-timeout==3.0.1 215 | - bashlex==0.16 216 | - beautifulsoup4==4.10.0 217 | - bitarray==2.3.3 218 | - cached-property==1.5.2 219 | - cachetools==5.0.0 220 | - chardet==4.0.0 221 | - chinese-converter==1.0.2 222 | - click==8.0.1 223 | - cloudpickle==1.6.0 224 | - colorama==0.4.4 225 | - conllu==4.4.1 226 | - cython==0.29.24 227 | - datasets==1.12.1 228 | - deepspeed==0.5.3 229 | - dill==0.3.4 230 | - et-xmlfile==1.1.0 231 | - filelock==3.0.12 232 | - flexible-dotdict==0.2.1 233 | - fsspec==2021.8.1 234 | - google-auth==2.6.0 235 | - google-auth-oauthlib==0.4.6 236 | - greenlet==1.1.1 237 | - grpcio==1.43.0 238 | - h5py==3.6.0 239 | - huggingface-hub==0.0.16 240 | - hydra-core==1.0.7 241 | - importlib-metadata==4.11.0 242 | - importlib-resources==5.2.2 243 | - joblib==1.0.1 244 | - markdown==3.3.6 245 | - mpi4py==3.1.1 246 | - multidict==5.1.0 247 | - multiprocess==0.70.12.2 248 | - mysql==0.0.3 249 | - mysql-connector-python==8.0.27 250 | - mysqlclient==2.0.3 251 | - networkx==2.4 252 | - ninja==1.10.2.1 253 | - nltk==3.4.5 254 | - nose==1.3.7 255 | - oauthlib==3.2.0 256 | - omegaconf==2.0.6 257 | - openai==0.10.4 258 | - openpyxl==3.0.8 259 | - pandas==1.3.3 260 | - pandas-stubs==1.2.0.16 261 | - pdfminer-six==20211012 262 | - pdfplumber==0.6.0 263 | - pillow==9.0.1 264 | - portalocker==2.3.2 265 | - protobuf==3.18.0 266 | - psutil==5.8.0 267 | - pyarrow==5.0.0 268 | - pyasn1==0.4.8 269 | - pyasn1-modules==0.2.8 270 | - pyyaml==5.4.1 271 | - regex==2021.8.28 272 | - requests-oauthlib==1.3.1 273 | - rsa==4.8 274 | - sacrebleu==2.0.0 275 | - sacremoses==0.0.45 276 | - scikit-learn==0.24.2 277 | - sentencepiece==0.1.96 278 | - simalign==0.2 279 | - soupsieve==2.2.1 280 | - sqlalchemy==1.4.24 281 | - stackapi==0.2.0 282 | - submitit==1.3.3 283 | - tabulate==0.8.9 284 | - tensorboard-data-server==0.6.1 285 | - tensorboard-plugin-wit==1.8.1 286 | - tensorboardx==1.8 287 | - threadpoolctl==2.2.0 
288 | - tokenizers==0.10.3 289 | - transformers==4.10.0 290 | - triton==1.1.0 291 | - wand==0.6.7 292 | - werkzeug==2.0.3 293 | - xx-ent-wiki-sm==2.0.0 294 | - xxhash==2.0.2 295 | - yarl==1.6.3 296 | prefix: /private/home/fhs/.conda/envs/mbr-exec 297 | -------------------------------------------------------------------------------- /collectors.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import argparse 4 | import copy 5 | import json 6 | import openai 7 | import os 8 | import pickle 9 | import random 10 | import signal 11 | import submitit 12 | import time 13 | from glob import glob 14 | from nltk.translate.bleu_score import sentence_bleu 15 | from tqdm import tqdm 16 | 17 | 18 | # Signal Handlers 19 | def handle_sigusr1(signum, frame): 20 | print(f'Received {signum}, requeuing job.', flush=True) 21 | os.system(f'scontrol requeue {os.getenv("SLURM_JOB_ID")}') 22 | exit() 23 | 24 | def handle_sigterm(signum, frame): 25 | print(f'Received {signum}, bypassing.', flush=True) 26 | pass 27 | 28 | 29 | 30 | def codex(configs, dataset, prefixes): 31 | # model 32 | openai.api_key = os.getenv("OPENAI_API_KEY") 33 | def codex_greedy(prompt): 34 | response = openai.Completion.create( 35 | engine=configs.engine_name if configs.engine_name is not None else 'davinci-codex', 36 | prompt=prompt, 37 | temperature=0, 38 | max_tokens=920, 39 | top_p=1, 40 | frequency_penalty=0, 41 | presence_penalty=0, 42 | stop=configs.end_template 43 | ) 44 | return response['choices'][0]['text'], None, None 45 | 46 | def codex_sample(prompt): 47 | response = openai.Completion.create( 48 | engine=configs.engine_name if configs.engine_name is not None else 'davinci-codex', 49 | prompt=prompt, 50 | temperature=configs.temperature, 51 | max_tokens=920, 52 | top_p=1, 53 | frequency_penalty=0, 54 | presence_penalty=0, 55 | logprobs=1, 56 | stop=configs.end_template 57 | ) 58 | return 
response['choices'][0]['text'], response['choices'][0]['logprobs']['tokens'], response['choices'][0]['logprobs']['token_logprobs'] 59 | 60 | prompt_prefix = ''.join([configs.prompt_template.format(src=x[0], trg=x[1]) for x in prefixes]) 61 | 62 | # save folder 63 | save_dir = f'{configs.output_path}/seed-{configs.seed}/{configs.n_prompts}-shot/{configs.mode}-{configs.temperature}/' 64 | os.system(f'mkdir -p {save_dir}') 65 | # save configs and prefixes 66 | if configs.rank == 0: 67 | with open(f'{save_dir}/prefixes.json', 'w') as fout: 68 | json.dump(prefixes, fout) 69 | fout.close() 70 | with open(f'{save_dir}/configs.pkl', 'wb') as fout: 71 | pickle.dump(configs, fout) 72 | fout.close() 73 | ofname = f'{save_dir}/{configs.split}-{configs.rank}.jsonl' 74 | # load checkpoint 75 | if os.path.exists(ofname): 76 | n_processed_examples = len(open(ofname).readlines()) 77 | else: 78 | n_processed_examples = 0 79 | pbar = tqdm(dataset) 80 | 81 | with open(ofname, 'a') as fout: 82 | for i, (src, trg) in enumerate(pbar): 83 | if i < n_processed_examples: 84 | continue 85 | prompt = prompt_prefix + configs.example_template.format(src=src) 86 | while True: 87 | try: 88 | trg_prediction, tokens, logprobs = codex_greedy(prompt) if configs.mode == 'greedy' else codex_sample(prompt) 89 | time.sleep(2) 90 | break 91 | except Exception as e: 92 | print(f'API call failed ({e}); 
sleeping for 30 secs.', flush=True) 93 | time.sleep(30) 94 | try: 95 | bleu_score = sentence_bleu([[ch for ch in trg]], [ch for ch in trg_prediction]) 96 | except: 97 | bleu_score = 0 98 | print( 99 | json.dumps( 100 | { 101 | 'prompt': prompt, 102 | 'src': src, 103 | 'trg_prediction': trg_prediction, 104 | 'reference': trg, 105 | 'tokens': tokens, 106 | 'logprobs': logprobs, 107 | 'bleu': bleu_score 108 | } 109 | ), 110 | file=fout, flush=True 111 | ) 112 | pbar.set_description(f'Process {configs.rank}') 113 | fout.close() 114 | 115 | 116 | def codex_with_info(configs, dataset, prefixes): 117 | # model 118 | openai.api_key = os.getenv("OPENAI_API_KEY") 119 | def codex_greedy(prompt): 120 | response = openai.Completion.create( 121 | engine=configs.engine_name if configs.engine_name is not None else 'davinci-codex', 122 | prompt=prompt, 123 | temperature=0, 124 | max_tokens=920, 125 | top_p=1, 126 | frequency_penalty=0, 127 | presence_penalty=0, 128 | stop=configs.end_template 129 | ) 130 | return response['choices'][0]['text'], None, None 131 | 132 | def codex_sample(prompt): 133 | response = openai.Completion.create( 134 | engine=configs.engine_name if configs.engine_name is not None else 'davinci-codex', 135 | prompt=prompt, 136 | temperature=configs.temperature, 137 | max_tokens=920, 138 | top_p=1, 139 | frequency_penalty=0, 140 | presence_penalty=0, 141 | logprobs=1, 142 | stop=configs.end_template 143 | ) 144 | return response['choices'][0]['text'], response['choices'][0]['logprobs']['tokens'], response['choices'][0]['logprobs']['token_logprobs'] 145 | 146 | prompt_prefix = ''.join([configs.prompt_template.format(src=x[0], trg=x[1], info=x[2]) for x in prefixes]) 147 | 148 | # save folder 149 | save_dir = f'{configs.output_path}/seed-{configs.seed}/{configs.n_prompts}-shot/{configs.mode}-{configs.temperature}/' 150 | os.system(f'mkdir -p {save_dir}') 151 | # save configs and prefixes 152 | if configs.rank == 0: 153 | with open(f'{save_dir}/prefixes.json', 'w') 
as fout: 154 | json.dump(prefixes, fout) 155 | fout.close() 156 | with open(f'{save_dir}/configs.pkl', 'wb') as fout: 157 | pickle.dump(configs, fout) 158 | fout.close() 159 | ofname = f'{save_dir}/{configs.split}-{configs.rank}.jsonl' 160 | # load checkpoint 161 | if os.path.exists(ofname): 162 | n_processed_examples = len(open(ofname).readlines()) 163 | else: 164 | n_processed_examples = 0 165 | pbar = tqdm(dataset) 166 | 167 | with open(ofname, 'a') as fout: 168 | for i, (src, trg, info) in enumerate(pbar): 169 | if i < n_processed_examples: 170 | continue 171 | prompt = prompt_prefix + configs.example_template.format(src=src, info=info) 172 | while True: 173 | try: 174 | trg_prediction, tokens, logprobs = codex_greedy(prompt) if configs.mode == 'greedy' else codex_sample(prompt) 175 | time.sleep(2) 176 | break 177 | except Exception as e: 178 | print(e, flush=True) 179 | time.sleep(30) 180 | try: 181 | bleu_score = sentence_bleu([[ch for ch in trg]], [ch for ch in trg_prediction]) 182 | except: 183 | bleu_score = 0 184 | print( 185 | json.dumps( 186 | { 187 | 'prompt': prompt, 188 | 'src': src, 189 | 'trg_prediction': trg_prediction, 190 | 'reference': trg, 191 | 'tokens': tokens, 192 | 'logprobs': logprobs, 193 | 'bleu': bleu_score 194 | } 195 | ), 196 | file=fout, flush=True 197 | ) 198 | pbar.set_description(f'Process {configs.rank}') 199 | fout.close() 200 | 201 | 202 | """ example collector: """ 203 | class Collector(object): 204 | def __init__(self, configs, dataset): 205 | self.configs = configs 206 | self.dataset = dataset 207 | 208 | def __call__(self): 209 | signal.signal(signal.SIGUSR1, handle_sigusr1) 210 | signal.signal(signal.SIGTERM, handle_sigterm) 211 | job_env = submitit.JobEnvironment() 212 | configs = copy.deepcopy(self.configs) 213 | configs.rank = job_env.global_rank 214 | configs.gpu = job_env.local_rank 215 | configs.world_size = job_env.num_tasks 216 | for seed in self.configs.seed: 217 | for n_prompts in self.configs.n_prompts: 218 | 
for temperature in self.configs.temperature: 219 | configs.n_prompts = n_prompts 220 | configs.seed = seed 221 | configs.temperature = temperature 222 | random.seed(configs.seed) 223 | if configs.saved_prefixes_path_template is not None: 224 | prefix_pool = list() 225 | for path in glob(configs.saved_prefixes_path_template, recursive=True): 226 | prefix_pool.extend(json.load(open(path))) 227 | prefix_pool = sorted(set([tuple(x) for x in prefix_pool])) 228 | prefixes = random.sample(prefix_pool, configs.n_prompts) 229 | else: 230 | prefixes = random.sample(self.dataset.data['train'], configs.n_prompts) 231 | if configs.shuffle_prefix: 232 | original_prefixes = copy.deepcopy(prefixes) 233 | while original_prefixes == prefixes: 234 | random.shuffle(prefixes) 235 | codex(configs, self.dataset.data[configs.split], prefixes) 236 | 237 | @staticmethod 238 | def parse_args(main_parser=None): 239 | if main_parser is None: 240 | main_parser = argparse.ArgumentParser() 241 | subparsers = main_parser.add_subparsers(title='commands', dest='mode') 242 | # collect 243 | parser = subparsers.add_parser('collect', help='collecting stage') 244 | parser.add_argument('--output-path', type=str, required=True) 245 | parser.add_argument('--split', type=str, default='dev', choices=['train', 'dev', 'test']) 246 | parser.add_argument('--seed', type=int, nargs='+', default=[0]) 247 | parser.add_argument('--n-prompts', type=int, nargs='+', default=[3], help='number of few-shot prompt examples') 248 | parser.add_argument('--mode', type=str, default='greedy', choices=['greedy', 'sample']) 249 | parser.add_argument('--n-samples', type=int, default=5, help='number of sampled examples under the sampling mode') 250 | parser.add_argument('--temperature', type=float, default=[0.6], nargs='+', help='sample temperature') 251 | parser.add_argument('--prompt-template', type=str, default='# {src}\n{trg}\n') 252 | parser.add_argument('--example-template', type=str, default='# {src}\n') 253 | 
parser.add_argument('--end-template', type=str, default='\n') 254 | parser.add_argument('--shuffle-prefix', action='store_true', default=False) 255 | parser.add_argument('--saved-prefixes-path-template', type=str, default=None) 256 | parser.add_argument('--engine-name', type=str, default=None) 257 | 258 | # slurm arguments 259 | parser.add_argument('--slurm-ntasks', type=int, default=None) 260 | parser.add_argument('--slurm-ngpus', type=int, default=0) 261 | parser.add_argument('--slurm-nnodes', type=int, default=1) 262 | parser.add_argument('--slurm-partition', type=str, default='devlab') 263 | 264 | args = main_parser.parse_args() 265 | 266 | if args.mode == 'greedy': 267 | args.n_samples = 1 268 | args.temperature = [0] 269 | if args.slurm_ntasks is None: 270 | args.slurm_ntasks = args.n_samples 271 | else: 272 | assert args.slurm_ntasks == args.n_samples 273 | return args 274 | 275 | @classmethod 276 | def from_args(cls, args=None, dataset=None): 277 | if args is None: 278 | args = cls.parse_args() 279 | assert dataset is not None 280 | return cls(args, dataset) 281 | 282 | 283 | """ example collector: """ 284 | class CollectorWithInfo(object): 285 | def __init__(self, configs, dataset): 286 | self.configs = configs 287 | self.dataset = dataset 288 | 289 | def __call__(self): 290 | signal.signal(signal.SIGUSR1, handle_sigusr1) 291 | signal.signal(signal.SIGTERM, handle_sigterm) 292 | job_env = submitit.JobEnvironment() 293 | configs = copy.deepcopy(self.configs) 294 | configs.rank = job_env.global_rank 295 | configs.gpu = job_env.local_rank 296 | configs.world_size = job_env.num_tasks 297 | for seed in self.configs.seed: 298 | for n_prompts in self.configs.n_prompts: 299 | for temperature in self.configs.temperature: 300 | configs.n_prompts = n_prompts 301 | configs.seed = seed 302 | configs.temperature = temperature 303 | random.seed(configs.seed) 304 | if configs.saved_prefixes_path_template is not None: 305 | prefix_pool = list() 306 | for path in 
glob(configs.saved_prefixes_path_template, recursive=True): 307 | prefix_pool.extend(json.load(open(path))) 308 | prefix_pool = sorted(set([tuple(x) for x in prefix_pool])) 309 | prefixes = random.sample(prefix_pool, configs.n_prompts) 310 | else: 311 | prefixes = random.sample(self.dataset.data['train'], configs.n_prompts) 312 | if configs.shuffle_prefix: 313 | original_prefixes = copy.deepcopy(prefixes) 314 | while original_prefixes == prefixes: 315 | random.shuffle(prefixes) 316 | codex_with_info(configs, self.dataset.data[configs.split], prefixes) 317 | 318 | @staticmethod 319 | def parse_args(main_parser=None): 320 | if main_parser is None: 321 | main_parser = argparse.ArgumentParser() 322 | subparsers = main_parser.add_subparsers(title='commands', dest='mode') 323 | # collect 324 | parser = subparsers.add_parser('collect', help='collecting stage') 325 | parser.add_argument('--output-path', type=str, required=True) 326 | parser.add_argument('--split', type=str, default='dev', choices=['train', 'dev', 'test']) 327 | parser.add_argument('--seed', type=int, nargs='+', default=[0]) 328 | parser.add_argument('--n-prompts', type=int, nargs='+', default=[3], help='number of few-shot prompt examples') 329 | parser.add_argument('--mode', type=str, default='greedy', choices=['greedy', 'sample']) 330 | parser.add_argument('--n-samples', type=int, default=5, help='number of sampled examples under the sampling mode') 331 | parser.add_argument('--temperature', type=float, default=[0.6], nargs='+', help='sample temperature') 332 | parser.add_argument('--prompt-template', type=str, default='{info}\n{src}\n{trg}\n') 333 | parser.add_argument('--example-template', type=str, default='{info}\n{src}\n') 334 | parser.add_argument('--end-template', type=str, default='') 335 | parser.add_argument('--shuffle-prefix', action='store_true', default=False) 336 | parser.add_argument('--saved-prefixes-path-template', type=str, default=None) 337 | parser.add_argument('--engine-name', 
type=str, default=None) 338 | # slurm arguments 339 | parser.add_argument('--slurm-ntasks', type=int, default=None) 340 | parser.add_argument('--slurm-ngpus', type=int, default=0) 341 | parser.add_argument('--slurm-nnodes', type=int, default=1) 342 | parser.add_argument('--slurm-partition', type=str, default='devlab') 343 | 344 | args = main_parser.parse_args() 345 | 346 | if args.mode == 'greedy': 347 | args.n_samples = 1 348 | args.temperature = [0] 349 | if args.slurm_ntasks is None: 350 | args.slurm_ntasks = args.n_samples 351 | else: 352 | assert args.slurm_ntasks == args.n_samples 353 | return args 354 | 355 | @classmethod 356 | def from_args(cls, args=None, dataset=None): 357 | if args is None: 358 | args = cls.parse_args() 359 | assert dataset is not None 360 | return cls(args, dataset) 361 | -------------------------------------------------------------------------------- /process_sql.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | ################################ 4 | # Assumptions: 5 | # 1. sql is correct 6 | # 2. only table name has alias 7 | # 3. only one intersect/union/except 8 | # 9 | # val: number(float)/string(str)/sql(dict) 10 | # col_unit: (agg_id, col_id, isDistinct(bool)) 11 | # val_unit: (unit_op, col_unit1, col_unit2) 12 | # table_unit: (table_type, col_unit/sql) 13 | # cond_unit: (not_op, op_id, val_unit, val1, val2) 14 | # condition: [cond_unit1, 'and'/'or', cond_unit2, ...] 15 | # sql { 16 | # 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...]) 17 | # 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition} 18 | # 'where': condition 19 | # 'groupBy': [col_unit1, col_unit2, ...] 
20 | # 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...]) 21 | # 'having': condition 22 | # 'limit': None/limit value 23 | # 'intersect': None/sql 24 | # 'except': None/sql 25 | # 'union': None/sql 26 | # } 27 | ################################ 28 | 29 | import json 30 | import sqlite3 31 | from nltk import word_tokenize 32 | 33 | CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except') 34 | JOIN_KEYWORDS = ('join', 'on', 'as') 35 | 36 | WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists') 37 | UNIT_OPS = ('none', '-', '+', "*", '/') 38 | AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg') 39 | TABLE_TYPE = { 40 | 'sql': "sql", 41 | 'table_unit': "table_unit", 42 | } 43 | 44 | COND_OPS = ('and', 'or') 45 | SQL_OPS = ('intersect', 'union', 'except') 46 | ORDER_OPS = ('desc', 'asc') 47 | 48 | 49 | 50 | class Schema: 51 | """ 52 | Simple schema which maps table&column to a unique identifier 53 | """ 54 | def __init__(self, schema): 55 | self._schema = schema 56 | self._idMap = self._map(self._schema) 57 | 58 | @property 59 | def schema(self): 60 | return self._schema 61 | 62 | @property 63 | def idMap(self): 64 | return self._idMap 65 | 66 | def _map(self, schema): 67 | idMap = {'*': "__all__"} 68 | id = 1 69 | for key, vals in schema.items(): 70 | for val in vals: 71 | idMap[key.lower() + "." + val.lower()] = "__" + key.lower() + "." 
+ val.lower() + "__" 72 | id += 1 73 | 74 | for key in schema: 75 | idMap[key.lower()] = "__" + key.lower() + "__" 76 | id += 1 77 | 78 | return idMap 79 | 80 | 81 | def get_schema(db): 82 | """ 83 | Get database's schema, which is a dict with table name as key 84 | and list of column names as value 85 | :param db: database path 86 | :return: schema dict 87 | """ 88 | 89 | schema = {} 90 | conn = sqlite3.connect(db) 91 | cursor = conn.cursor() 92 | 93 | # fetch table names 94 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") 95 | tables = [str(table[0].lower()) for table in cursor.fetchall()] 96 | 97 | # fetch table info 98 | for table in tables: 99 | cursor.execute("PRAGMA table_info({})".format(table)) 100 | schema[table] = [str(col[1].lower()) for col in cursor.fetchall()] 101 | 102 | return schema 103 | 104 | 105 | def get_schema_from_json(fpath): 106 | with open(fpath) as f: 107 | data = json.load(f) 108 | 109 | schema = {} 110 | for entry in data: 111 | table = str(entry['table'].lower()) 112 | cols = [str(col['column_name'].lower()) for col in entry['col_data']] 113 | schema[table] = cols 114 | 115 | return schema 116 | 117 | 118 | def tokenize(string): 119 | string = str(string) 120 | string = string.replace("\'", "\"") # normalize quotes so all string values are wrapped in double quotes 
121 | quote_idxs = [idx for idx, char in enumerate(string) if char == '"'] 122 | assert len(quote_idxs) % 2 == 0, "Unexpected quote" 123 | 124 | # keep string value as token 125 | vals = {} 126 | for i in range(len(quote_idxs)-1, -1, -2): 127 | qidx1 = quote_idxs[i-1] 128 | qidx2 = quote_idxs[i] 129 | val = string[qidx1: qidx2+1] 130 | key = "__val_{}_{}__".format(qidx1, qidx2) 131 | string = string[:qidx1] + key + string[qidx2+1:] 132 | vals[key] = val 133 | 134 | toks = [word.lower() for word in word_tokenize(string)] 135 | # replace with string value token 136 | for i in range(len(toks)): 137 | if toks[i] in vals: 138 | toks[i] = vals[toks[i]] 139 | 140 | # find if there exists !=, >=, <= 141 | eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="] 142 | eq_idxs.reverse() 143 | prefix = ('!', '>', '<') 144 | for eq_idx in eq_idxs: 145 | pre_tok = toks[eq_idx-1] 146 | if pre_tok in prefix: 147 | toks = toks[:eq_idx-1] + [pre_tok + "="] + toks[eq_idx+1: ] 148 | 149 | return toks 150 | 151 | 152 | def scan_alias(toks): 153 | """Scan the index of 'as' and build the map for all alias""" 154 | as_idxs = [idx for idx, tok in enumerate(toks) if tok == 'as'] 155 | alias = {} 156 | for idx in as_idxs: 157 | alias[toks[idx+1]] = toks[idx-1] 158 | return alias 159 | 160 | 161 | def get_tables_with_alias(schema, toks): 162 | tables = scan_alias(toks) 163 | for key in schema: 164 | assert key not in tables, "Alias {} has the same name in table".format(key) 165 | tables[key] = key 166 | return tables 167 | 168 | 169 | def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None): 170 | """ 171 | :returns next idx, column id 172 | """ 173 | tok = toks[start_idx] 174 | if tok == "*": 175 | return start_idx + 1, schema.idMap[tok] 176 | 177 | if '.' in tok: # if token is a composite 178 | alias, col = tok.split('.') 179 | key = tables_with_alias[alias] + "." 
+ col 180 | return start_idx+1, schema.idMap[key] 181 | 182 | assert default_tables is not None and len(default_tables) > 0, "Default tables should not be None or empty" 183 | 184 | for alias in default_tables: 185 | table = tables_with_alias[alias] 186 | if tok in schema.schema[table]: 187 | key = table + "." + tok 188 | return start_idx+1, schema.idMap[key] 189 | 190 | assert False, "Error col: {}".format(tok) 191 | 192 | 193 | def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None): 194 | """ 195 | :returns next idx, (agg_op id, col_id) 196 | """ 197 | idx = start_idx 198 | len_ = len(toks) 199 | isBlock = False 200 | isDistinct = False 201 | if toks[idx] == '(': 202 | isBlock = True 203 | idx += 1 204 | 205 | if toks[idx] in AGG_OPS: 206 | agg_id = AGG_OPS.index(toks[idx]) 207 | idx += 1 208 | assert idx < len_ and toks[idx] == '(' 209 | idx += 1 210 | if toks[idx] == "distinct": 211 | idx += 1 212 | isDistinct = True 213 | idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables) 214 | assert idx < len_ and toks[idx] == ')' 215 | idx += 1 216 | return idx, (agg_id, col_id, isDistinct) 217 | 218 | if toks[idx] == "distinct": 219 | idx += 1 220 | isDistinct = True 221 | agg_id = AGG_OPS.index("none") 222 | idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables) 223 | 224 | if isBlock: 225 | assert toks[idx] == ')' 226 | idx += 1 # skip ')' 227 | 228 | return idx, (agg_id, col_id, isDistinct) 229 | 230 | 231 | def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None): 232 | idx = start_idx 233 | len_ = len(toks) 234 | isBlock = False 235 | if toks[idx] == '(': 236 | isBlock = True 237 | idx += 1 238 | 239 | col_unit1 = None 240 | col_unit2 = None 241 | unit_op = UNIT_OPS.index('none') 242 | 243 | idx, col_unit1 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) 244 | if idx < len_ and toks[idx] in UNIT_OPS: 245 | unit_op = UNIT_OPS.index(toks[idx]) 
246 | idx += 1 247 | idx, col_unit2 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) 248 | 249 | if isBlock: 250 | assert toks[idx] == ')' 251 | idx += 1 # skip ')' 252 | 253 | return idx, (unit_op, col_unit1, col_unit2) 254 | 255 | 256 | def parse_table_unit(toks, start_idx, tables_with_alias, schema): 257 | """ 258 | :returns next idx, table id, table name 259 | """ 260 | idx = start_idx 261 | len_ = len(toks) 262 | key = tables_with_alias[toks[idx]] 263 | 264 | if idx + 1 < len_ and toks[idx+1] == "as": 265 | idx += 3 266 | else: 267 | idx += 1 268 | 269 | return idx, schema.idMap[key], key 270 | 271 | 272 | def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None): 273 | idx = start_idx 274 | len_ = len(toks) 275 | 276 | isBlock = False 277 | if toks[idx] == '(': 278 | isBlock = True 279 | idx += 1 280 | 281 | if toks[idx] == 'select': 282 | idx, val = parse_sql(toks, idx, tables_with_alias, schema) 283 | elif "\"" in toks[idx]: # token is a string value 284 | val = toks[idx] 285 | idx += 1 286 | else: 287 | try: 288 | val = float(toks[idx]) 289 | idx += 1 290 | except: 291 | end_idx = idx 292 | while end_idx < len_ and toks[end_idx] != ',' and toks[end_idx] != ')'\ 293 | and toks[end_idx] != 'and' and toks[end_idx] not in CLAUSE_KEYWORDS and toks[end_idx] not in JOIN_KEYWORDS: 294 | end_idx += 1 295 | 296 | idx, val = parse_col_unit(toks[start_idx: end_idx], 0, tables_with_alias, schema, default_tables) 297 | idx = end_idx 298 | 299 | if isBlock: 300 | assert toks[idx] == ')' 301 | idx += 1 302 | 303 | return idx, val 304 | 305 | 306 | def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None): 307 | idx = start_idx 308 | len_ = len(toks) 309 | conds = [] 310 | 311 | while idx < len_: 312 | idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables) 313 | not_op = False 314 | if toks[idx] == 'not': 315 | not_op = True 316 | idx += 1 317 | 318 | assert idx < len_ and 
toks[idx] in WHERE_OPS, "Error condition: idx: {}, tok: {}".format(idx, toks[idx]) 319 | op_id = WHERE_OPS.index(toks[idx]) 320 | idx += 1 321 | val1 = val2 = None 322 | if op_id == WHERE_OPS.index('between'): # between..and... special case: dual values 323 | idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables) 324 | assert toks[idx] == 'and' 325 | idx += 1 326 | idx, val2 = parse_value(toks, idx, tables_with_alias, schema, default_tables) 327 | else: # normal case: single value 328 | idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables) 329 | val2 = None 330 | 331 | conds.append((not_op, op_id, val_unit, val1, val2)) 332 | 333 | if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";") or toks[idx] in JOIN_KEYWORDS): 334 | break 335 | 336 | if idx < len_ and toks[idx] in COND_OPS: 337 | conds.append(toks[idx]) 338 | idx += 1 # skip and/or 339 | 340 | return idx, conds 341 | 342 | 343 | def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None): 344 | idx = start_idx 345 | len_ = len(toks) 346 | 347 | assert toks[idx] == 'select', "'select' not found" 348 | idx += 1 349 | isDistinct = False 350 | if idx < len_ and toks[idx] == 'distinct': 351 | idx += 1 352 | isDistinct = True 353 | val_units = [] 354 | 355 | while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS: 356 | agg_id = AGG_OPS.index("none") 357 | if toks[idx] in AGG_OPS: 358 | agg_id = AGG_OPS.index(toks[idx]) 359 | idx += 1 360 | idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables) 361 | val_units.append((agg_id, val_unit)) 362 | if idx < len_ and toks[idx] == ',': 363 | idx += 1 # skip ',' 364 | 365 | return idx, (isDistinct, val_units) 366 | 367 | 368 | def parse_from(toks, start_idx, tables_with_alias, schema): 369 | """ 370 | Assume in the from clause, all table units are combined with join 371 | """ 372 | assert 'from' in toks[start_idx:], "'from' not found" 373 | 374 | len_ = 
len(toks) 375 | idx = toks.index('from', start_idx) + 1 376 | default_tables = [] 377 | table_units = [] 378 | conds = [] 379 | 380 | while idx < len_: 381 | isBlock = False 382 | if toks[idx] == '(': 383 | isBlock = True 384 | idx += 1 385 | 386 | if toks[idx] == 'select': 387 | idx, sql = parse_sql(toks, idx, tables_with_alias, schema) 388 | table_units.append((TABLE_TYPE['sql'], sql)) 389 | else: 390 | if idx < len_ and toks[idx] == 'join': 391 | idx += 1 # skip join 392 | idx, table_unit, table_name = parse_table_unit(toks, idx, tables_with_alias, schema) 393 | table_units.append((TABLE_TYPE['table_unit'],table_unit)) 394 | default_tables.append(table_name) 395 | if idx < len_ and toks[idx] == "on": 396 | idx += 1 # skip on 397 | idx, this_conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables) 398 | if len(conds) > 0: 399 | conds.append('and') 400 | conds.extend(this_conds) 401 | 402 | if isBlock: 403 | assert toks[idx] == ')' 404 | idx += 1 405 | if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): 406 | break 407 | 408 | return idx, table_units, conds, default_tables 409 | 410 | 411 | def parse_where(toks, start_idx, tables_with_alias, schema, default_tables): 412 | idx = start_idx 413 | len_ = len(toks) 414 | 415 | if idx >= len_ or toks[idx] != 'where': 416 | return idx, [] 417 | 418 | idx += 1 419 | idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables) 420 | return idx, conds 421 | 422 | 423 | def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables): 424 | idx = start_idx 425 | len_ = len(toks) 426 | col_units = [] 427 | 428 | if idx >= len_ or toks[idx] != 'group': 429 | return idx, col_units 430 | 431 | idx += 1 432 | assert toks[idx] == 'by' 433 | idx += 1 434 | 435 | while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): 436 | idx, col_unit = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) 437 | 
col_units.append(col_unit) 438 | if idx < len_ and toks[idx] == ',': 439 | idx += 1 # skip ',' 440 | else: 441 | break 442 | 443 | return idx, col_units 444 | 445 | 446 | def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables): 447 | idx = start_idx 448 | len_ = len(toks) 449 | val_units = [] 450 | order_type = 'asc' # default type is 'asc' 451 | 452 | if idx >= len_ or toks[idx] != 'order': 453 | return idx, val_units 454 | 455 | idx += 1 456 | assert toks[idx] == 'by' 457 | idx += 1 458 | 459 | while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): 460 | idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables) 461 | val_units.append(val_unit) 462 | if idx < len_ and toks[idx] in ORDER_OPS: 463 | order_type = toks[idx] 464 | idx += 1 465 | if idx < len_ and toks[idx] == ',': 466 | idx += 1 # skip ',' 467 | else: 468 | break 469 | 470 | return idx, (order_type, val_units) 471 | 472 | 473 | def parse_having(toks, start_idx, tables_with_alias, schema, default_tables): 474 | idx = start_idx 475 | len_ = len(toks) 476 | 477 | if idx >= len_ or toks[idx] != 'having': 478 | return idx, [] 479 | 480 | idx += 1 481 | idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables) 482 | return idx, conds 483 | 484 | 485 | def parse_limit(toks, start_idx): 486 | idx = start_idx 487 | len_ = len(toks) 488 | 489 | if idx < len_ and toks[idx] == 'limit': 490 | idx += 2 491 | return idx, int(toks[idx-1]) 492 | 493 | return idx, None 494 | 495 | 496 | def parse_sql(toks, start_idx, tables_with_alias, schema): 497 | isBlock = False # indicate whether this is a block of sql/sub-sql 498 | len_ = len(toks) 499 | idx = start_idx 500 | 501 | sql = {} 502 | if toks[idx] == '(': 503 | isBlock = True 504 | idx += 1 505 | 506 | # parse from clause in order to get default tables 507 | from_end_idx, table_units, conds, default_tables = parse_from(toks, start_idx, tables_with_alias, schema) 508 
| sql['from'] = {'table_units': table_units, 'conds': conds} 509 | # select clause 510 | _, select_col_units = parse_select(toks, idx, tables_with_alias, schema, default_tables) 511 | idx = from_end_idx 512 | sql['select'] = select_col_units 513 | # where clause 514 | idx, where_conds = parse_where(toks, idx, tables_with_alias, schema, default_tables) 515 | sql['where'] = where_conds 516 | # group by clause 517 | idx, group_col_units = parse_group_by(toks, idx, tables_with_alias, schema, default_tables) 518 | sql['groupBy'] = group_col_units 519 | # having clause 520 | idx, having_conds = parse_having(toks, idx, tables_with_alias, schema, default_tables) 521 | sql['having'] = having_conds 522 | # order by clause 523 | idx, order_col_units = parse_order_by(toks, idx, tables_with_alias, schema, default_tables) 524 | sql['orderBy'] = order_col_units 525 | # limit clause 526 | idx, limit_val = parse_limit(toks, idx) 527 | sql['limit'] = limit_val 528 | 529 | idx = skip_semicolon(toks, idx) 530 | if isBlock: 531 | assert toks[idx] == ')' 532 | idx += 1 # skip ')' 533 | idx = skip_semicolon(toks, idx) 534 | 535 | # intersect/union/except clause 536 | for op in SQL_OPS: # initialize IUE 537 | sql[op] = None 538 | if idx < len_ and toks[idx] in SQL_OPS: 539 | sql_op = toks[idx] 540 | idx += 1 541 | idx, IUE_sql = parse_sql(toks, idx, tables_with_alias, schema) 542 | sql[sql_op] = IUE_sql 543 | return idx, sql 544 | 545 | 546 | def load_data(fpath): 547 | with open(fpath) as f: 548 | data = json.load(f) 549 | return data 550 | 551 | 552 | def get_sql(schema, query): 553 | toks = tokenize(query) 554 | tables_with_alias = get_tables_with_alias(schema.schema, toks) 555 | _, sql = parse_sql(toks, 0, tables_with_alias, schema) 556 | 557 | return sql 558 | 559 | 560 | def skip_semicolon(toks, start_idx): 561 | idx = start_idx 562 | while idx < len(toks) and toks[idx] == ";": 563 | idx += 1 564 | return idx 565 | 
-------------------------------------------------------------------------------- /sample_selectors.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import bashlex 4 | import collections 5 | import json 6 | import pickle 7 | import numpy as np 8 | import os 9 | import random 10 | from glob import glob 11 | from nltk.translate.bleu_score import sentence_bleu 12 | from evaluate import evaluate_charbleu, evaluate_google_mbpp, evaluate_spider 13 | from execution import execute_mbpp_google_folder, execute_spider_folder, simulate_bash_exec 14 | 15 | """ k sets of configs in separate paths, n * k choose 1 selector """ 16 | class MultiSampleSelector(object): 17 | def __init__(self, paths, split='dev'): 18 | self.paths = list(sorted(glob(paths, recursive=True))) if isinstance(paths, str) else list(sorted(paths)) 19 | self.split = split 20 | self.data = collections.defaultdict(list) 21 | self.args = collections.defaultdict(list) 22 | for i, path in enumerate(self.paths): 23 | self.args[i] = pickle.load(open(f'{path}/configs.pkl', 'rb')) 24 | idx = 0 25 | while os.path.exists(f'{path}/{split}-{idx}.jsonl'): 26 | self.data[i, idx].extend([json.loads(x) for x in open(f'{path}/{split}-{idx}.jsonl')]) 27 | idx += 1 28 | for path_id, sample_id in self.data: 29 | for item in self.data[path_id, sample_id]: 30 | try: 31 | avg_logprob, sum_logprob = self.extract_logprob_stats(item, path_id) 32 | item['avg_logprob'] = avg_logprob 33 | item['sum_logprob'] = sum_logprob 34 | except: 35 | item['avg_logprob'] = item['sum_logprob'] = 0 36 | if self.paths[path_id].find('nl2bash') != -1: # NL2bash data, exec simulation 37 | try: 38 | bashlex.parse(item['trg_prediction']) 39 | item['executable'] = True 40 | except: 41 | item['executable'] = False 42 | try: 43 | item['trg_prediction_splitted'] = simulate_bash_exec(item['trg_prediction']) 44 | item['execution_result_simulated'] = 
collections.Counter(item['trg_prediction_splitted']) 45 | except: 46 | item['trg_prediction_splitted'] = list() 47 | item['execution_result_simulated'] = collections.Counter() 48 | 49 | def extract_logprob_stats(self, item, path_id): 50 | current_seq = '' 51 | extracted_position = None 52 | for i, _ in enumerate(item['tokens']): 53 | current_seq += item['tokens'][i] 54 | if current_seq.find(item['trg_prediction']) != -1 and current_seq.find(self.args[path_id].end_template) != -1: 55 | extracted_position = i + 1 56 | break 57 | logprobs = item['logprobs'][:extracted_position] if extracted_position is not None else item['logprobs'] 58 | logprobs = list(filter(lambda x: x<0, logprobs)) # handle potential codex bug on positive log probability 59 | return np.mean(logprobs), np.sum(logprobs) 60 | 61 | def select(self, ids=None, key_extractor=lambda x:x['avg_logprob'],return_keys=False): 62 | if ids is None: 63 | ids = self.data.keys() 64 | ids = list(sorted(ids)) 65 | print(f'Selecting Samples from IDs: {ids}', flush=True) 66 | n_examples = len(self.data[ids[0]]) 67 | selected_examples = list() 68 | sample_keys = collections.defaultdict(list) 69 | for i in range(n_examples): 70 | max_key = None 71 | selected_item = None 72 | for idx in ids: 73 | item = self.data[idx][i] 74 | key = key_extractor(item) 75 | sample_keys[idx].append(key) 76 | if max_key is None or key > max_key: 77 | max_key = key 78 | selected_item = item 79 | assert selected_item is not None 80 | selected_examples.append(selected_item) 81 | if return_keys: 82 | return selected_examples, sample_keys 83 | else: 84 | return selected_examples 85 | 86 | 87 | class ExecutionBasedMultiSampleSelector(MultiSampleSelector): 88 | def __init__(self, paths, split='dev', execution_type=None): 89 | super().__init__(paths, split=split) 90 | self.execution_type = execution_type 91 | for i, path in enumerate(self.paths): 92 | if execution_type == 'mbpp': 93 | execute_mbpp_google_folder(path) 94 | elif execution_type == 
'spider': 95 | execute_spider_folder(path) 96 | else: 97 | raise Exception(f'Execution type {execution_type} not supported.') 98 | idx = 0 99 | while os.path.exists(f'{path}/{split}-{idx}.exec.pkl'): 100 | for j, execution_result in enumerate(pickle.load(open(f'{path}/{split}-{idx}.exec.pkl', 'rb'))): 101 | self.data[i, idx][j]['execution_result'] = execution_result 102 | idx += 1 103 | idx = 0 104 | while os.path.exists(f'{path}/{split}-{idx}.execfull.pkl'): 105 | for j, execution_result in enumerate(pickle.load(open(f'{path}/{split}-{idx}.execfull.pkl', 'rb'))): 106 | self.data[i, idx][j]['execution_result_full'] = execution_result 107 | idx += 1 108 | idx = 0 109 | while os.path.exists(f'{path}/{split}-{idx}.execfullpass.pkl'): 110 | for j, execution_result in enumerate(pickle.load(open(f'{path}/{split}-{idx}.execfullpass.pkl', 'rb'))): 111 | self.data[i, idx][j]['execution_result_full_pass'] = execution_result 112 | idx += 1 113 | 114 | 115 | class IntraMultiSampleSelector(MultiSampleSelector): 116 | def __init__(self, paths, split='dev'): 117 | super().__init__(paths, split=split) 118 | 119 | def select( 120 | self, 121 | ids=None, 122 | key_extractor=None, 123 | second_key_extractor=None, 124 | return_keys=False 125 | ): 126 | if ids is None: 127 | ids = self.data.keys() 128 | elif isinstance(ids, int): 129 | ids = [(i, j) for i in set(x[0] for x in self.data.keys()) for j in range(ids)] 130 | ids = list(sorted(ids)) 131 | id_set = set(ids) 132 | sample_keys = collections.defaultdict(list) 133 | print(f'Selecting Samples from IDs: {ids}') 134 | n_examples = len(self.data[ids[0]]) 135 | selected_examples = list() 136 | for i in range(n_examples): 137 | max_key = None 138 | selected_item = None 139 | for idx in id_set: 140 | item = self.data[idx][i] 141 | first_keys = list() 142 | for grndtruth_idx in ids: 143 | grndtruth_item = self.data[grndtruth_idx][i] 144 | key = key_extractor(item, grndtruth_item) 145 | first_keys.append(key) 146 | first_key = 
sum(first_keys) 147 | second_key = second_key_extractor(item) if second_key_extractor is not None else 0 148 | current_key = (first_key, second_key) 149 | item['mbr_key'] = current_key 150 | sample_keys[idx].append(current_key) 151 | if max_key is None or current_key > max_key: 152 | max_key = current_key 153 | selected_item = item 154 | assert selected_item is not None 155 | selected_examples.append(selected_item) 156 | if return_keys: 157 | return selected_examples, sample_keys 158 | else: 159 | return selected_examples 160 | 161 | 162 | class ExecutionBasedIntraMultiSampleSelector(IntraMultiSampleSelector): 163 | def __init__(self, paths, split='dev', execution_type=None): 164 | super().__init__(paths, split=split) 165 | self.execution_type = execution_type 166 | for i, path in enumerate(self.paths): 167 | if execution_type == 'mbpp': 168 | execute_mbpp_google_folder(path) 169 | elif execution_type == 'spider': 170 | execute_spider_folder(path) 171 | else: 172 | raise Exception(f'Execution type {execution_type} not supported.') 173 | idx = 0 174 | while os.path.exists(f'{path}/{split}-{idx}.exec.pkl'): 175 | for j, execution_result in enumerate(pickle.load(open(f'{path}/{split}-{idx}.exec.pkl', 'rb'))): 176 | self.data[i, idx][j]['execution_result'] = execution_result 177 | idx += 1 178 | idx = 0 179 | while os.path.exists(f'{path}/{split}-{idx}.execfull.pkl'): 180 | for j, execution_result in enumerate(pickle.load(open(f'{path}/{split}-{idx}.execfull.pkl', 'rb'))): 181 | self.data[i, idx][j]['execution_result_full'] = execution_result 182 | idx += 1 183 | idx = 0 184 | while os.path.exists(f'{path}/{split}-{idx}.exec.codexcases.pkl'): 185 | for j, execution_result in enumerate(pickle.load(open(f'{path}/{split}-{idx}.exec.codexcases.pkl', 'rb'))): 186 | self.data[i, idx][j]['execution_result_codexexec'] = execution_result 187 | idx += 1 188 | idx = 0 189 | while os.path.exists(f'{path}/{split}-{idx}.execfullpass.pkl'): 190 | for j, execution_result in 
enumerate(pickle.load(open(f'{path}/{split}-{idx}.execfullpass.pkl', 'rb'))): 191 |                     self.data[i, idx][j]['execution_result_full_pass'] = execution_result 192 |                 idx += 1 193 | 194 | 195 | """equivalence checking functions""" 196 | # base equivalence checking function 197 | def single_exec_result_matching(exec_x, exec_y, good_execution_result): 198 |     try: 199 |         if exec_x[0] == good_execution_result and exec_y[0] == good_execution_result and exec_x[1] == exec_y[1]: 200 |             return 1 201 |         else: 202 |             return 0 203 |     except: 204 |         return 0 205 | 206 | 207 | # first assertion call matching 208 | def execution_selection_function(x, y, good_execution_result=0): 209 |     exec_x, exec_y = x['execution_result'], y['execution_result'] 210 |     return single_exec_result_matching(exec_x, exec_y, good_execution_result) 211 | 212 | 213 | # executability checking only 214 | def executability_selection_function(x, good_execution_result=0): 215 |     exec_res = x['execution_result'] 216 |     return exec_res[0] == good_execution_result 217 | 218 | 219 | def bleu_selection_function(x, y): 220 |     return sentence_bleu([[ch for ch in x['trg_prediction']]], [ch for ch in y['trg_prediction']]) 221 | 222 | 223 | def token_bleu_selection_function(x, y): 224 |     return sentence_bleu([x['trg_prediction'].split()], y['trg_prediction'].split()) 225 | 226 | 227 | def bash_execution_tokenbleu_selection_function(x, y): 228 |     if not x['executable'] or not y['executable']: 229 |         return 0 230 |     x = x['trg_prediction_splitted'] 231 |     y = y['trg_prediction_splitted'] 232 |     return sentence_bleu([x], y) 233 | 234 | 235 | """ 236 | select and evaluate a group in batch 237 | required keys: 238 |     data_split: 'train', 'dev' or 'test' 239 |     temperature: 0.1 .. 1.0 240 |     criterion: 'mbr_exec' ... 
see full options in the function 241 | data_path: root data path for the task 242 | n_samples: number of candidates 243 | rand_seed: random seed for one experiment 244 | """ 245 | def select_mbpp(args, return_selected=False, return_selector=False): 246 | data_split, temperature, criterion, data_path, n_samples, rand_seed = args 247 | mbpp_good_execution_result = 0 248 | data_path = f'{data_path}/seed-*/**/*-{temperature}/' 249 | secondary_key_function = None 250 | if criterion == 'mbr_exec': 251 | selector = ExecutionBasedIntraMultiSampleSelector(data_path, data_split, 'mbpp') 252 | sample_selection_function = lambda x, y: execution_selection_function(x, y, mbpp_good_execution_result) 253 | secondary_key_function = lambda x: x['sum_logprob'] 254 | elif criterion == 'logprob': 255 | selector = ExecutionBasedMultiSampleSelector(data_path, data_split, 'mbpp') # pre-execution for faster evaluation 256 | sample_selection_function = lambda x: x['sum_logprob'] 257 | elif criterion == 'avg_logprob': 258 | selector = ExecutionBasedMultiSampleSelector(data_path, data_split, 'mbpp') # pre-execution for faster evaluation 259 | sample_selection_function = lambda x: x['avg_logprob'] 260 | elif criterion == 'mbr_bleu': 261 | selector = ExecutionBasedIntraMultiSampleSelector(data_path, data_split, 'mbpp') # pre-execution for faster evaluation 262 | sample_selection_function = lambda x, y: bleu_selection_function(x, y) 263 | elif criterion == 'mbr_tokenbleu': 264 | selector = ExecutionBasedIntraMultiSampleSelector(data_path, data_split, 'mbpp') # pre-execution for faster evaluation 265 | sample_selection_function = lambda x, y: token_bleu_selection_function(x, y) 266 | elif criterion == 'executability-logprob': 267 | selector = ExecutionBasedMultiSampleSelector(data_path, data_split, 'mbpp') 268 | sample_selection_function = lambda x: (executability_selection_function(x, mbpp_good_execution_result), x['sum_logprob']) 269 | elif criterion == 'executability-avglogprob': 270 | 
selector = ExecutionBasedMultiSampleSelector(data_path, data_split, 'mbpp') 271 | sample_selection_function = lambda x: (executability_selection_function(x, mbpp_good_execution_result), x['avg_logprob']) 272 | elif criterion == 'executability-mbr_bleu': 273 | selector = ExecutionBasedIntraMultiSampleSelector(data_path, data_split, 'mbpp') # pre-execution for faster evaluation 274 | sample_selection_function = lambda x, y: bleu_selection_function(x, y) * (1 - x['execution_result'][0]) * \ 275 | (1 - y['execution_result'][0]) 276 | elif criterion == 'executability-mbr_tokenbleu': 277 | selector = ExecutionBasedIntraMultiSampleSelector(data_path, data_split, 'mbpp') # pre-execution for faster evaluation 278 | sample_selection_function = lambda x, y: token_bleu_selection_function(x, y) * (1 - x['execution_result'][0]) * \ 279 | (1 - y['execution_result'][0]) 280 | else: 281 | raise ValueError(f'Unknown criterion: {criterion}') 282 | id_keys = list(selector.data.keys()) 283 | random.seed(rand_seed) 284 | ids = random.sample(id_keys, n_samples) 285 | if secondary_key_function is not None: 286 | selected = selector.select(ids, sample_selection_function, secondary_key_function) 287 | else: 288 | selected = selector.select(ids, sample_selection_function) 289 | if return_selector: 290 | return selector 291 | elif return_selected: 292 | return selected 293 | else: 294 | result = evaluate_google_mbpp(selected, 'data/mbpp/mbpp.jsonl', 'test') 295 | return result 296 | 297 | 298 | def select_nl2bash(args, return_selected=False, return_selector=False): 299 | data_split, temperature, criterion, data_path, n_samples, rand_seed = args 300 | data_path = f'{data_path}/seed-*/**/*-{temperature}/' 301 | secondary_key_function = None 302 | if criterion == 'mbr_bleu': 303 | selector = IntraMultiSampleSelector(data_path, data_split) 304 | sample_selection_function = lambda x, y: bleu_selection_function(x, y) 305 | elif criterion == 'mbr_tokenbleu': 306 | selector = 
IntraMultiSampleSelector(data_path, data_split) 307 | sample_selection_function = lambda x, y: token_bleu_selection_function(x, y) 308 | elif criterion == 'mbr_exec_tokenbleu': 309 | selector = IntraMultiSampleSelector(data_path, data_split) 310 | sample_selection_function = lambda x, y: bash_execution_tokenbleu_selection_function(x, y) 311 | secondary_key_function = lambda x: x['sum_logprob'] 312 | elif criterion == 'logprob': 313 | selector = MultiSampleSelector(data_path, data_split) 314 | sample_selection_function = lambda x: x['sum_logprob'] 315 | elif criterion == 'avg_logprob': 316 | selector = MultiSampleSelector(data_path, data_split) 317 | sample_selection_function = lambda x: x['avg_logprob'] 318 | elif criterion == 'executability-logprob': 319 | selector = MultiSampleSelector(data_path, data_split) 320 | sample_selection_function = lambda x: (x['executable'], x['sum_logprob']) 321 | elif criterion == 'executability-avglogprob': 322 | selector = MultiSampleSelector(data_path, data_split) 323 | sample_selection_function = lambda x: (x['executable'], x['avg_logprob']) 324 | else: 325 | raise ValueError(f'Unknown criterion: {criterion}') 326 | id_keys = list(selector.data.keys()) 327 | random.seed(rand_seed) 328 | ids = random.sample(id_keys, n_samples) 329 | if secondary_key_function is not None: 330 | selected = selector.select(ids, sample_selection_function, secondary_key_function) 331 | else: 332 | selected = selector.select(ids, sample_selection_function) 333 | if return_selector: 334 | return selector 335 | elif return_selected: 336 | return selected 337 | else: 338 | result = evaluate_charbleu(selected) 339 | return result 340 | 341 | 342 | def select_spider( 343 | args, 344 | return_selected=False, 345 | return_selector=False, 346 | ): 347 | data_split, temperature, criterion, data_path, n_samples, rand_seed = args 348 | spider_good_execution_result = True 349 | data_path = f'{data_path}/seed-*/**/*-{temperature}/' 350 | secondary_key_function = 
None 351 | if criterion == 'mbr_exec': 352 | selector = ExecutionBasedIntraMultiSampleSelector(data_path, data_split, 'spider') 353 | sample_selection_function = lambda x, y: execution_selection_function(x, y, spider_good_execution_result) 354 | secondary_key_function = lambda x: x['sum_logprob'] 355 | elif criterion == 'logprob': 356 | selector = ExecutionBasedMultiSampleSelector(data_path, data_split, 'spider') # pre-execution for faster evaluation 357 | sample_selection_function = lambda x: x['sum_logprob'] 358 | elif criterion == 'avg_logprob': 359 | selector = ExecutionBasedMultiSampleSelector(data_path, data_split, 'spider') # pre-execution for faster evaluation 360 | sample_selection_function = lambda x: x['avg_logprob'] 361 | elif criterion == 'mbr_bleu': 362 | selector = ExecutionBasedIntraMultiSampleSelector(data_path, data_split, 'spider') # pre-execution for faster evaluation 363 | sample_selection_function = lambda x, y: bleu_selection_function(x, y) 364 | elif criterion == 'mbr_tokenbleu': 365 | selector = ExecutionBasedIntraMultiSampleSelector(data_path, data_split, 'spider') # pre-execution for faster evaluation 366 | sample_selection_function = lambda x, y: token_bleu_selection_function(x, y) 367 | elif criterion == 'executability-logprob': 368 | selector = ExecutionBasedMultiSampleSelector(data_path, data_split, 'spider') 369 | sample_selection_function = lambda x: (executability_selection_function(x, spider_good_execution_result), x['sum_logprob']) 370 | elif criterion == 'executability-avglogprob': 371 | selector = ExecutionBasedMultiSampleSelector(data_path, data_split, 'spider') 372 | sample_selection_function = lambda x: (executability_selection_function(x, spider_good_execution_result), x['avg_logprob']) 373 | elif criterion == 'executability-mbr_bleu': 374 | selector = ExecutionBasedIntraMultiSampleSelector(data_path, data_split, 'spider') # pre-execution for faster evaluation 375 | sample_selection_function = lambda x, y: 
bleu_selection_function(x, y) * x['execution_result'][0] * y['execution_result'][0] 376 | elif criterion == 'executability-mbr_tokenbleu': 377 | selector = ExecutionBasedIntraMultiSampleSelector(data_path, data_split, 'spider') # pre-execution for faster evaluation 378 | sample_selection_function = lambda x, y: token_bleu_selection_function(x, y) * x['execution_result'][0] * y['execution_result'][0] 379 | else: 380 | raise ValueError(f'Unknown criterion: {criterion}') 381 | id_keys = list(selector.data.keys()) 382 | random.seed(rand_seed) 383 | ids = random.sample(id_keys, n_samples) 384 | if secondary_key_function is not None: 385 | selected = selector.select(ids, sample_selection_function, secondary_key_function) 386 | else: 387 | selected = selector.select(ids, sample_selection_function) 388 | if return_selector: 389 | return selector 390 | if return_selected: 391 | return selected 392 | else: 393 | return evaluate_spider(selected, 'data/spider/dev_gold.sql', 'all') 394 | -------------------------------------------------------------------------------- /utils_sql.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | ################################ 4 | # val: number(float)/string(str)/sql(dict) 5 | # col_unit: (agg_id, col_id, isDistinct(bool)) 6 | # val_unit: (unit_op, col_unit1, col_unit2) 7 | # table_unit: (table_type, col_unit/sql) 8 | # cond_unit: (not_op, op_id, val_unit, val1, val2) 9 | # condition: [cond_unit1, 'and'/'or', cond_unit2, ...] 10 | # sql { 11 | # 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...]) 12 | # 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition} 13 | # 'where': condition 14 | # 'groupBy': [col_unit1, col_unit2, ...] 
15 | # 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...]) 16 | # 'having': condition 17 | # 'limit': None/limit value 18 | # 'intersect': None/sql 19 | # 'except': None/sql 20 | # 'union': None/sql 21 | # } 22 | ################################ 23 | 24 | from __future__ import print_function 25 | import os 26 | import json 27 | import sqlite3 28 | import signal 29 | from contextlib import contextmanager 30 | import argparse 31 | from process_sql import get_schema, Schema, get_sql 32 | 33 | # Flag to disable value evaluation 34 | DISABLE_VALUE = True 35 | # Flag to disable distinct in select evaluation 36 | DISABLE_DISTINCT = True 37 | 38 | 39 | CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except') 40 | JOIN_KEYWORDS = ('join', 'on', 'as') 41 | 42 | WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists') 43 | UNIT_OPS = ('none', '-', '+', "*", '/') 44 | AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg') 45 | TABLE_TYPE = { 46 | 'sql': "sql", 47 | 'table_unit': "table_unit", 48 | } 49 | 50 | COND_OPS = ('and', 'or') 51 | SQL_OPS = ('intersect', 'union', 'except') 52 | ORDER_OPS = ('desc', 'asc') 53 | 54 | 55 | HARDNESS = { 56 | "component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'), 57 | "component2": ('except', 'union', 'intersect') 58 | } 59 | 60 | class TimeoutException(Exception): pass 61 | 62 | @contextmanager 63 | def time_limit(seconds): 64 | def signal_handler(signum, frame): 65 | raise TimeoutException("Timed out!") 66 | signal.signal(signal.SIGALRM, signal_handler) 67 | signal.alarm(seconds) 68 | try: 69 | yield 70 | finally: 71 | signal.alarm(0) 72 | 73 | 74 | def condition_has_or(conds): 75 | return 'or' in conds[1::2] 76 | 77 | 78 | def condition_has_like(conds): 79 | return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]] 80 | 81 | 82 | def condition_has_sql(conds): 83 | for cond_unit in conds[::2]: 84 | val1, val2 = 
cond_unit[3], cond_unit[4] 85 | if val1 is not None and type(val1) is dict: 86 | return True 87 | if val2 is not None and type(val2) is dict: 88 | return True 89 | return False 90 | 91 | 92 | def val_has_op(val_unit): 93 | return val_unit[0] != UNIT_OPS.index('none') 94 | 95 | 96 | def has_agg(unit): 97 | return unit[0] != AGG_OPS.index('none') 98 | 99 | 100 | def accuracy(count, total): 101 | if count == total: 102 | return 1 103 | return 0 104 | 105 | 106 | def recall(count, total): 107 | if count == total: 108 | return 1 109 | return 0 110 | 111 | 112 | def F1(acc, rec): 113 | if (acc + rec) == 0: 114 | return 0 115 | return (2. * acc * rec) / (acc + rec) 116 | 117 | 118 | def get_scores(count, pred_total, label_total): 119 | if pred_total != label_total: 120 | return 0,0,0 121 | elif count == pred_total: 122 | return 1,1,1 123 | return 0,0,0 124 | 125 | 126 | def eval_sel(pred, label): 127 | pred_sel = pred['select'][1] 128 | label_sel = label['select'][1] 129 | label_wo_agg = [unit[1] for unit in label_sel] 130 | pred_total = len(pred_sel) 131 | label_total = len(label_sel) 132 | cnt = 0 133 | cnt_wo_agg = 0 134 | 135 | for unit in pred_sel: 136 | if unit in label_sel: 137 | cnt += 1 138 | label_sel.remove(unit) 139 | if unit[1] in label_wo_agg: 140 | cnt_wo_agg += 1 141 | label_wo_agg.remove(unit[1]) 142 | 143 | return label_total, pred_total, cnt, cnt_wo_agg 144 | 145 | 146 | def eval_where(pred, label): 147 | pred_conds = [unit for unit in pred['where'][::2]] 148 | label_conds = [unit for unit in label['where'][::2]] 149 | label_wo_agg = [unit[2] for unit in label_conds] 150 | pred_total = len(pred_conds) 151 | label_total = len(label_conds) 152 | cnt = 0 153 | cnt_wo_agg = 0 154 | 155 | for unit in pred_conds: 156 | if unit in label_conds: 157 | cnt += 1 158 | label_conds.remove(unit) 159 | if unit[2] in label_wo_agg: 160 | cnt_wo_agg += 1 161 | label_wo_agg.remove(unit[2]) 162 | 163 | return label_total, pred_total, cnt, cnt_wo_agg 164 | 165 | 166 | def 
eval_group(pred, label): 167 | pred_cols = [unit[1] for unit in pred['groupBy']] 168 | label_cols = [unit[1] for unit in label['groupBy']] 169 | pred_total = len(pred_cols) 170 | label_total = len(label_cols) 171 | cnt = 0 172 | pred_cols = [pred.split(".")[1] if "." in pred else pred for pred in pred_cols] 173 | label_cols = [label.split(".")[1] if "." in label else label for label in label_cols] 174 | for col in pred_cols: 175 | if col in label_cols: 176 | cnt += 1 177 | label_cols.remove(col) 178 | return label_total, pred_total, cnt 179 | 180 | 181 | def eval_having(pred, label): 182 | pred_total = label_total = cnt = 0 183 | if len(pred['groupBy']) > 0: 184 | pred_total = 1 185 | if len(label['groupBy']) > 0: 186 | label_total = 1 187 | 188 | pred_cols = [unit[1] for unit in pred['groupBy']] 189 | label_cols = [unit[1] for unit in label['groupBy']] 190 | if pred_total == label_total == 1 \ 191 | and pred_cols == label_cols \ 192 | and pred['having'] == label['having']: 193 | cnt = 1 194 | 195 | return label_total, pred_total, cnt 196 | 197 | 198 | def eval_order(pred, label): 199 | pred_total = label_total = cnt = 0 200 | if len(pred['orderBy']) > 0: 201 | pred_total = 1 202 | if len(label['orderBy']) > 0: 203 | label_total = 1 204 | if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \ 205 | ((pred['limit'] is None and label['limit'] is None) or (pred['limit'] is not None and label['limit'] is not None)): 206 | cnt = 1 207 | return label_total, pred_total, cnt 208 | 209 | 210 | def eval_and_or(pred, label): 211 | pred_ao = pred['where'][1::2] 212 | label_ao = label['where'][1::2] 213 | pred_ao = set(pred_ao) 214 | label_ao = set(label_ao) 215 | 216 | if pred_ao == label_ao: 217 | return 1,1,1 218 | return len(pred_ao),len(label_ao),0 219 | 220 | 221 | def get_nestedSQL(sql): 222 | nested = [] 223 | for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]: 224 | if type(cond_unit[3]) is dict: 225 | 
nested.append(cond_unit[3]) 226 | if type(cond_unit[4]) is dict: 227 | nested.append(cond_unit[4]) 228 | if sql['intersect'] is not None: 229 | nested.append(sql['intersect']) 230 | if sql['except'] is not None: 231 | nested.append(sql['except']) 232 | if sql['union'] is not None: 233 | nested.append(sql['union']) 234 | return nested 235 | 236 | 237 | def eval_nested(pred, label): 238 | label_total = 0 239 | pred_total = 0 240 | cnt = 0 241 | if pred is not None: 242 | pred_total += 1 243 | if label is not None: 244 | label_total += 1 245 | if pred is not None and label is not None: 246 | cnt += Evaluator().eval_exact_match(pred, label) 247 | return label_total, pred_total, cnt 248 | 249 | 250 | def eval_IUEN(pred, label): 251 | lt1, pt1, cnt1 = eval_nested(pred['intersect'], label['intersect']) 252 | lt2, pt2, cnt2 = eval_nested(pred['except'], label['except']) 253 | lt3, pt3, cnt3 = eval_nested(pred['union'], label['union']) 254 | label_total = lt1 + lt2 + lt3 255 | pred_total = pt1 + pt2 + pt3 256 | cnt = cnt1 + cnt2 + cnt3 257 | return label_total, pred_total, cnt 258 | 259 | 260 | def get_keywords(sql): 261 | res = set() 262 | if len(sql['where']) > 0: 263 | res.add('where') 264 | if len(sql['groupBy']) > 0: 265 | res.add('group') 266 | if len(sql['having']) > 0: 267 | res.add('having') 268 | if len(sql['orderBy']) > 0: 269 | res.add(sql['orderBy'][0]) 270 | res.add('order') 271 | if sql['limit'] is not None: 272 | res.add('limit') 273 | if sql['except'] is not None: 274 | res.add('except') 275 | if sql['union'] is not None: 276 | res.add('union') 277 | if sql['intersect'] is not None: 278 | res.add('intersect') 279 | 280 | # or keyword 281 | ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2] 282 | if len([token for token in ao if token == 'or']) > 0: 283 | res.add('or') 284 | 285 | cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2] 286 | # not keyword 287 | if len([cond_unit for cond_unit in cond_units if 
cond_unit[0]]) > 0: 288 | res.add('not') 289 | 290 | # in keyword 291 | if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0: 292 | res.add('in') 293 | 294 | # like keyword 295 | if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0: 296 | res.add('like') 297 | 298 | return res 299 | 300 | 301 | def eval_keywords(pred, label): 302 | pred_keywords = get_keywords(pred) 303 | label_keywords = get_keywords(label) 304 | pred_total = len(pred_keywords) 305 | label_total = len(label_keywords) 306 | cnt = 0 307 | 308 | for k in pred_keywords: 309 | if k in label_keywords: 310 | cnt += 1 311 | return label_total, pred_total, cnt 312 | 313 | 314 | def count_agg(units): 315 | return len([unit for unit in units if has_agg(unit)]) 316 | 317 | 318 | def count_component1(sql): 319 | count = 0 320 | if len(sql['where']) > 0: 321 | count += 1 322 | if len(sql['groupBy']) > 0: 323 | count += 1 324 | if len(sql['orderBy']) > 0: 325 | count += 1 326 | if sql['limit'] is not None: 327 | count += 1 328 | if len(sql['from']['table_units']) > 0: # JOIN 329 | count += len(sql['from']['table_units']) - 1 330 | 331 | ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2] 332 | count += len([token for token in ao if token == 'or']) 333 | cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2] 334 | count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) 335 | 336 | return count 337 | 338 | 339 | def count_component2(sql): 340 | nested = get_nestedSQL(sql) 341 | return len(nested) 342 | 343 | 344 | def count_others(sql): 345 | count = 0 346 | # number of aggregation 347 | agg_count = count_agg(sql['select'][1]) 348 | agg_count += count_agg(sql['where'][::2]) 349 | agg_count += count_agg(sql['groupBy']) 350 | if len(sql['orderBy']) > 0: 351 | agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] + 352 | [unit[2] 
for unit in sql['orderBy'][1] if unit[2]]) 353 | agg_count += count_agg(sql['having']) 354 | if agg_count > 1: 355 | count += 1 356 | 357 | # number of select columns 358 | if len(sql['select'][1]) > 1: 359 | count += 1 360 | 361 | # number of where conditions 362 | if len(sql['where']) > 1: 363 | count += 1 364 | 365 | # number of group by clauses 366 | if len(sql['groupBy']) > 1: 367 | count += 1 368 | 369 | return count 370 | 371 | 372 | class Evaluator: 373 | """A simple evaluator""" 374 | def __init__(self): 375 | self.partial_scores = None 376 | 377 | def eval_hardness(self, sql): 378 | count_comp1_ = count_component1(sql) 379 | count_comp2_ = count_component2(sql) 380 | count_others_ = count_others(sql) 381 | 382 | if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0: 383 | return "easy" 384 | elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \ 385 | (count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0): 386 | return "medium" 387 | elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \ 388 | (2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \ 389 | (count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1): 390 | return "hard" 391 | else: 392 | return "extra" 393 | 394 | def eval_exact_match(self, pred, label): 395 | partial_scores = self.eval_partial_match(pred, label) 396 | self.partial_scores = partial_scores 397 | 398 | for _, score in partial_scores.items(): 399 | if score['f1'] != 1: 400 | return 0 401 | if len(label['from']['table_units']) > 0: 402 | label_tables = sorted(label['from']['table_units']) 403 | pred_tables = sorted(pred['from']['table_units']) 404 | return label_tables == pred_tables 405 | return 1 406 | 407 | def eval_partial_match(self, pred, label): 408 | res = {} 409 | 410 | label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label) 411 | acc, rec, f1 = get_scores(cnt, pred_total, label_total) 412 | res['select'] = {'acc': acc, 'rec': 
rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} 413 | acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total) 414 | res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} 415 | 416 | label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label) 417 | acc, rec, f1 = get_scores(cnt, pred_total, label_total) 418 | res['where'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} 419 | acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total) 420 | res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} 421 | 422 | label_total, pred_total, cnt = eval_group(pred, label) 423 | acc, rec, f1 = get_scores(cnt, pred_total, label_total) 424 | res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} 425 | 426 | label_total, pred_total, cnt = eval_having(pred, label) 427 | acc, rec, f1 = get_scores(cnt, pred_total, label_total) 428 | res['group'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} 429 | 430 | label_total, pred_total, cnt = eval_order(pred, label) 431 | acc, rec, f1 = get_scores(cnt, pred_total, label_total) 432 | res['order'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} 433 | 434 | label_total, pred_total, cnt = eval_and_or(pred, label) 435 | acc, rec, f1 = get_scores(cnt, pred_total, label_total) 436 | res['and/or'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} 437 | 438 | label_total, pred_total, cnt = eval_IUEN(pred, label) 439 | acc, rec, f1 = get_scores(cnt, pred_total, label_total) 440 | res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} 441 | 442 | label_total, pred_total, cnt = eval_keywords(pred, label) 443 | acc, rec, f1 = get_scores(cnt, pred_total, label_total) 
444 | res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} 445 | 446 | return res 447 | 448 | 449 | def isValidSQL(sql, db): 450 | conn = sqlite3.connect(db) 451 | cursor = conn.cursor() 452 | try: 453 | cursor.execute(sql) 454 | except: 455 | return False 456 | return True 457 | 458 | 459 | def print_scores(scores, etype): 460 | levels = ['easy', 'medium', 'hard', 'extra', 'all'] 461 | partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)', 462 | 'group', 'order', 'and/or', 'IUEN', 'keywords'] 463 | 464 | print("{:20} {:20} {:20} {:20} {:20} {:20}".format("", *levels)) 465 | counts = [scores[level]['count'] for level in levels] 466 | print("{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts)) 467 | 468 | if etype in ["all", "exec"]: 469 | print('===================== EXECUTION ACCURACY =====================') 470 | this_scores = [scores[level]['exec'] for level in levels] 471 | print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores)) 472 | 473 | if etype in ["all", "match"]: 474 | print('\n====================== EXACT MATCHING ACCURACY =====================') 475 | exact_scores = [scores[level]['exact'] for level in levels] 476 | print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores)) 477 | print('\n---------------------PARTIAL MATCHING ACCURACY----------------------') 478 | for type_ in partial_types: 479 | this_scores = [scores[level]['partial'][type_]['acc'] for level in levels] 480 | print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores)) 481 | 482 | print('---------------------- PARTIAL MATCHING RECALL ----------------------') 483 | for type_ in partial_types: 484 | this_scores = [scores[level]['partial'][type_]['rec'] for level in levels] 485 | print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, 
*this_scores)) 486 | 487 | print('---------------------- PARTIAL MATCHING F1 --------------------------') 488 | for type_ in partial_types: 489 | this_scores = [scores[level]['partial'][type_]['f1'] for level in levels] 490 | print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores)) 491 | 492 | 493 | def evaluate(gold, predict, db_dir, etype, kmaps): 494 | with open(gold) as f: 495 | glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0] 496 | 497 | with open(predict) as f: 498 | plist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0] 499 | # plist = [("select max(Share),min(Share) from performance where Type != 'terminal'", "orchestra")] 500 | # glist = [("SELECT max(SHARE) , min(SHARE) FROM performance WHERE TYPE != 'Live final'", "orchestra")] 501 | evaluator = Evaluator() 502 | 503 | levels = ['easy', 'medium', 'hard', 'extra', 'all'] 504 | partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)', 505 | 'group', 'order', 'and/or', 'IUEN', 'keywords'] 506 | entries = [] 507 | scores = {} 508 | 509 | for level in levels: 510 | scores[level] = {'count': 0, 'partial': {}, 'exact': 0.} 511 | scores[level]['exec'] = 0 512 | for type_ in partial_types: 513 | scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0.,'acc_count':0,'rec_count':0} 514 | 515 | eval_err_num = 0 516 | for p, g in zip(plist, glist): 517 | p_str = p[0] 518 | g_str, db = g 519 | db_name = db 520 | db = os.path.join(db_dir, db, db + ".sqlite") 521 | schema = Schema(get_schema(db)) 522 | g_sql = get_sql(schema, g_str) 523 | hardness = evaluator.eval_hardness(g_sql) 524 | scores[hardness]['count'] += 1 525 | scores['all']['count'] += 1 526 | 527 | try: 528 | p_sql = get_sql(schema, p_str) 529 | except: 530 | # If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql 531 | p_sql = { 532 | "except": None, 533 | "from": { 534 | "conds": [], 535 | 
"table_units": [] 536 | }, 537 | "groupBy": [], 538 | "having": [], 539 | "intersect": None, 540 | "limit": None, 541 | "orderBy": [], 542 | "select": [ 543 | False, 544 | [] 545 | ], 546 | "union": None, 547 | "where": [] 548 | } 549 | eval_err_num += 1 550 | print("eval_err_num:{}".format(eval_err_num)) 551 | 552 | # rebuild sql for value evaluation 553 | kmap = kmaps[db_name] 554 | g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema) 555 | g_sql = rebuild_sql_val(g_sql) 556 | g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap) 557 | p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema) 558 | p_sql = rebuild_sql_val(p_sql) 559 | p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap) 560 | 561 | if etype in ["all", "exec"]: 562 | exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql) 563 | if exec_score: 564 | scores[hardness]['exec'] += 1.0 565 | scores['all']['exec'] += 1.0 566 | 567 | if etype in ["all", "match"]: 568 | exact_score = evaluator.eval_exact_match(p_sql, g_sql) 569 | partial_scores = evaluator.partial_scores 570 | if exact_score == 0: 571 | print("{} pred: {}".format(hardness,p_str)) 572 | print("{} gold: {}".format(hardness,g_str)) 573 | print("") 574 | scores[hardness]['exact'] += exact_score 575 | scores['all']['exact'] += exact_score 576 | for type_ in partial_types: 577 | if partial_scores[type_]['pred_total'] > 0: 578 | scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc'] 579 | scores[hardness]['partial'][type_]['acc_count'] += 1 580 | if partial_scores[type_]['label_total'] > 0: 581 | scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec'] 582 | scores[hardness]['partial'][type_]['rec_count'] += 1 583 | scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1'] 584 | if partial_scores[type_]['pred_total'] > 0: 585 | scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc'] 586 | 
scores['all']['partial'][type_]['acc_count'] += 1
587 |                 if partial_scores[type_]['label_total'] > 0:
588 |                     scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec']
589 |                     scores['all']['partial'][type_]['rec_count'] += 1
590 |                 scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1']
591 | 
592 |             entries.append({
593 |                 'predictSQL': p_str,
594 |                 'goldSQL': g_str,
595 |                 'hardness': hardness,
596 |                 'exact': exact_score,
597 |                 'partial': partial_scores
598 |             })
599 | 
600 |     for level in levels:
601 |         if scores[level]['count'] == 0:
602 |             continue
603 |         if etype in ["all", "exec"]:
604 |             scores[level]['exec'] /= scores[level]['count']
605 | 
606 |         if etype in ["all", "match"]:
607 |             scores[level]['exact'] /= scores[level]['count']
608 |             for type_ in partial_types:
609 |                 if scores[level]['partial'][type_]['acc_count'] == 0:
610 |                     scores[level]['partial'][type_]['acc'] = 0
611 |                 else:
612 |                     scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \
613 |                         scores[level]['partial'][type_]['acc_count'] * 1.0
614 |                 if scores[level]['partial'][type_]['rec_count'] == 0:
615 |                     scores[level]['partial'][type_]['rec'] = 0
616 |                 else:
617 |                     scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \
618 |                         scores[level]['partial'][type_]['rec_count'] * 1.0
619 |                 if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0:
620 |                     scores[level]['partial'][type_]['f1'] = 1
621 |                 else:
622 |                     scores[level]['partial'][type_]['f1'] = \
623 |                         2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
624 |                         scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])
625 | 
626 |     print_scores(scores, etype)
627 | 
628 | 
629 | def eval_exec_match(db, p_str, g_str, pred, gold):
630 |     """
631 |     Return True if the values of the prediction and the gold query match
632 |     at each corresponding index. Multiple col_unit (pairs) are currently not supported.
633 | """ 634 | conn = sqlite3.connect(db) 635 | cursor = conn.cursor() 636 | try: 637 | cursor.execute(p_str) 638 | p_res = cursor.fetchall() 639 | except: 640 | return False 641 | 642 | cursor.execute(g_str) 643 | q_res = cursor.fetchall() 644 | 645 | def res_map(res, val_units): 646 | rmap = {} 647 | for idx, val_unit in enumerate(val_units): 648 | key = tuple(val_unit[1]) if not val_unit[2] else (val_unit[0], tuple(val_unit[1]), tuple(val_unit[2])) 649 | rmap[key] = [r[idx] for r in res] 650 | return rmap 651 | 652 | p_val_units = [unit[1] for unit in pred['select'][1]] 653 | q_val_units = [unit[1] for unit in gold['select'][1]] 654 | return res_map(p_res, p_val_units) == res_map(q_res, q_val_units) 655 | 656 | # tricks for dropping inexecutable code 657 | bad_commands = { 658 | 'SELECT first_name , country_code , birth_date FROM players, matches, rankings', 659 | 'SELECT Model FROM model_list WHERE ModelId IN (SELECT ModelId FROM car_names WHERE MakeId IN (SELECT MakeId FROM car_names WHERE Model IN (SELECT Model FROM car_names WHERE Make IN (SELECT Make FROM car_names WHERE ModelId IN (SELECT ModelId FROM cars_data WHERE MPG IN (SELECT MAX(MPG) FROM cars_data))))));', 660 | 'SELECT winner_name , winner_rank_points FROM matches , rankings , players ORDER BY winner_rank_points DESC LIMIT 1', 661 | 'SELECT winner_name , winner_rank FROM matches JOIN players JOIN rankings ORDER BY winner_age ASC LIMIT 3', 662 | 'SELECT winner_name FROM matches WHERE matches.winner_id IN (SELECT player_id FROM rankings WHERE rankings.ranking_date IN (SELECT ranking_date FROM rankings WHERE year = 2013) AND rankings.ranking_date IN (SELECT ranking_date FROM rankings WHERE year = 2016))', 663 | 'SELECT first_name , country_code , birth_date FROM players LEFT JOIN matches ON players.player_id = winner_id LEFT JOIN rankings ON players.player_id = winner_id WHERE winner_rank_points = (SELECT MAX(winner_rank_points) FROM players LEFT JOIN matches ON players.player_id = winner_id LEFT 
JOIN rankings ON players.player_id = winner_id);', 664 | 'SELECT first_name , country_code FROM players AS T1 JOIN rankings AS T2 ON T1.player_id = T2.player_id JOIN matches AS T3 ON T2.player_id = T3.winner_id WHERE T1.player_id = (SELECT MAX(T4.player_id) FROM players AS T4 JOIN rankings AS T5 ON T4.player_id = T5.player_id JOIN matches AS T6 ON T2.player_id = T6.winner_id) GROUP BY first_name , country_code', 665 | 'SELECT winner_ioc, winner_name FROM matches JOIN players JOIN rankings;', 666 | 'SELECT A.*, B.* FROM Players AS A, Rankings AS B WHERE A.country_code = "USA";', 667 | 'SELECT p.first_name, p.last_name, r.ranking_points FROM rankings AS r LEFT JOIN matches ON (r.player_id = matches.winner_id) LEFT JOIN players AS p ON(r.player_id = p.player_id) WHERE r.ranking_points = (SELECT max(ranking_points) FROM rankings AS max_r WHERE max_r.ranking_date <= r.ranking_date) GROUP BY r.player_id;', 668 | 'SELECT winner_name FROM matches WHERE winner_id IN (SELECT player_id FROM rankings WHERE ranking_date IN (SELECT ranking_date FROM rankings WHERE ranking_date IN (SELECT ranking_date FROM rankings WHERE year = 2013) AND year = 2016))', 669 | "SELECT winner_name FROM matches, rankings WHERE winner_id IN (SELECT player_id FROM rankings WHERE ranking_date >= (SELECT min(ranking_date) FROM rankings WHERE tours = 'atp' AND ranking = 'atp_rankings' AND ranking_points > (SELECT max(ranking_points) FROM rankings WHERE ranking_date >= (SELECT min(ranking_date) FROM rankings WHERE tours = 'atp' AND ranking = 'atp_rankings')) AND year = (SELECT year FROM matches WHERE tourney_name = 'Australian Open')) AND tours = 'atp' AND ranking = 'atp_rankings' AND ranking_points > (SELECT max(ranking_points) FROM rankings WHERE ranking_date >= (SELECT min(ranking_date) FROM rankings WHERE tours = 'atp' AND ranking = 'atp_rankings')) AND year = (SELECT year FROM matches WHERE tourney_name = 'Australian Open')) AND winner_name = (SELECT max(winner_name) FROM matches WHERE tourney_name = 
'Australian Open')", 670 | 'SELECT first_name , avg(ranking) FROM players LEFT JOIN rankings' 671 | } 672 | 673 | def execute(db, p_str, pred, timeout): 674 | conn = sqlite3.connect(db) 675 | cursor = conn.cursor() 676 | try: 677 | print(p_str, flush=True) 678 | assert p_str.strip() not in bad_commands 679 | with time_limit(timeout): 680 | cursor.execute(p_str) 681 | p_res = cursor.fetchall() 682 | def res_map(res, val_units): 683 | rmap = {} 684 | for idx, val_unit in enumerate(val_units): 685 | key = tuple(val_unit[1]) if not val_unit[2] else (val_unit[0], tuple(val_unit[1]), tuple(val_unit[2])) 686 | rmap[key] = [r[idx] for r in res] 687 | return rmap 688 | p_val_units = [unit[1] for unit in pred['select'][1]] 689 | return True, res_map(p_res, p_val_units) 690 | except: 691 | return False, None 692 | 693 | 694 | # Rebuild SQL functions for value evaluation 695 | def rebuild_cond_unit_val(cond_unit): 696 | if cond_unit is None or not DISABLE_VALUE: 697 | return cond_unit 698 | 699 | not_op, op_id, val_unit, val1, val2 = cond_unit 700 | if type(val1) is not dict: 701 | val1 = None 702 | else: 703 | val1 = rebuild_sql_val(val1) 704 | if type(val2) is not dict: 705 | val2 = None 706 | else: 707 | val2 = rebuild_sql_val(val2) 708 | return not_op, op_id, val_unit, val1, val2 709 | 710 | 711 | def rebuild_condition_val(condition): 712 | if condition is None or not DISABLE_VALUE: 713 | return condition 714 | 715 | res = [] 716 | for idx, it in enumerate(condition): 717 | if idx % 2 == 0: 718 | res.append(rebuild_cond_unit_val(it)) 719 | else: 720 | res.append(it) 721 | return res 722 | 723 | 724 | def rebuild_sql_val(sql): 725 | if sql is None or not DISABLE_VALUE: 726 | return sql 727 | 728 | sql['from']['conds'] = rebuild_condition_val(sql['from']['conds']) 729 | sql['having'] = rebuild_condition_val(sql['having']) 730 | sql['where'] = rebuild_condition_val(sql['where']) 731 | sql['intersect'] = rebuild_sql_val(sql['intersect']) 732 | sql['except'] = 
rebuild_sql_val(sql['except']) 733 | sql['union'] = rebuild_sql_val(sql['union']) 734 | 735 | return sql 736 | 737 | 738 | # Rebuild SQL functions for foreign key evaluation 739 | def build_valid_col_units(table_units, schema): 740 | col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']] 741 | prefixs = [col_id[:-2] for col_id in col_ids] 742 | valid_col_units= [] 743 | for value in schema.idMap.values(): 744 | if '.' in value and value[:value.index('.')] in prefixs: 745 | valid_col_units.append(value) 746 | return valid_col_units 747 | 748 | 749 | def rebuild_col_unit_col(valid_col_units, col_unit, kmap): 750 | if col_unit is None: 751 | return col_unit 752 | 753 | agg_id, col_id, distinct = col_unit 754 | if col_id in kmap and col_id in valid_col_units: 755 | col_id = kmap[col_id] 756 | if DISABLE_DISTINCT: 757 | distinct = None 758 | return agg_id, col_id, distinct 759 | 760 | 761 | def rebuild_val_unit_col(valid_col_units, val_unit, kmap): 762 | if val_unit is None: 763 | return val_unit 764 | 765 | unit_op, col_unit1, col_unit2 = val_unit 766 | col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap) 767 | col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap) 768 | return unit_op, col_unit1, col_unit2 769 | 770 | 771 | def rebuild_table_unit_col(valid_col_units, table_unit, kmap): 772 | if table_unit is None: 773 | return table_unit 774 | 775 | table_type, col_unit_or_sql = table_unit 776 | if isinstance(col_unit_or_sql, tuple): 777 | col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap) 778 | return table_type, col_unit_or_sql 779 | 780 | 781 | def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap): 782 | if cond_unit is None: 783 | return cond_unit 784 | 785 | not_op, op_id, val_unit, val1, val2 = cond_unit 786 | val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap) 787 | return not_op, op_id, val_unit, val1, val2 788 | 789 | 790 | def 
rebuild_condition_col(valid_col_units, condition, kmap): 791 | for idx in range(len(condition)): 792 | if idx % 2 == 0: 793 | condition[idx] = rebuild_cond_unit_col(valid_col_units, condition[idx], kmap) 794 | return condition 795 | 796 | 797 | def rebuild_select_col(valid_col_units, sel, kmap): 798 | if sel is None: 799 | return sel 800 | distinct, _list = sel 801 | new_list = [] 802 | for it in _list: 803 | agg_id, val_unit = it 804 | new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap))) 805 | if DISABLE_DISTINCT: 806 | distinct = None 807 | return distinct, new_list 808 | 809 | 810 | def rebuild_from_col(valid_col_units, from_, kmap): 811 | if from_ is None: 812 | return from_ 813 | 814 | from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap) for table_unit in from_['table_units']] 815 | from_['conds'] = rebuild_condition_col(valid_col_units, from_['conds'], kmap) 816 | return from_ 817 | 818 | 819 | def rebuild_group_by_col(valid_col_units, group_by, kmap): 820 | if group_by is None: 821 | return group_by 822 | 823 | return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by] 824 | 825 | 826 | def rebuild_order_by_col(valid_col_units, order_by, kmap): 827 | if order_by is None or len(order_by) == 0: 828 | return order_by 829 | 830 | direction, val_units = order_by 831 | new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units] 832 | return direction, new_val_units 833 | 834 | 835 | def rebuild_sql_col(valid_col_units, sql, kmap): 836 | if sql is None: 837 | return sql 838 | 839 | sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap) 840 | sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap) 841 | sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap) 842 | sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap) 843 | sql['orderBy'] = 
rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap) 844 | sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap) 845 | sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap) 846 | sql['except'] = rebuild_sql_col(valid_col_units, sql['except'], kmap) 847 | sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap) 848 | 849 | return sql 850 | 851 | 852 | def build_foreign_key_map(entry): 853 | cols_orig = entry["column_names_original"] 854 | tables_orig = entry["table_names_original"] 855 | 856 | # rebuild cols corresponding to idmap in Schema 857 | cols = [] 858 | for col_orig in cols_orig: 859 | if col_orig[0] >= 0: 860 | t = tables_orig[col_orig[0]] 861 | c = col_orig[1] 862 | cols.append("__" + t.lower() + "." + c.lower() + "__") 863 | else: 864 | cols.append("__all__") 865 | 866 | def keyset_in_list(k1, k2, k_list): 867 | for k_set in k_list: 868 | if k1 in k_set or k2 in k_set: 869 | return k_set 870 | new_k_set = set() 871 | k_list.append(new_k_set) 872 | return new_k_set 873 | 874 | foreign_key_list = [] 875 | foreign_keys = entry["foreign_keys"] 876 | for fkey in foreign_keys: 877 | key1, key2 = fkey 878 | key_set = keyset_in_list(key1, key2, foreign_key_list) 879 | key_set.add(key1) 880 | key_set.add(key2) 881 | 882 | foreign_key_map = {} 883 | for key_set in foreign_key_list: 884 | sorted_list = sorted(list(key_set)) 885 | midx = sorted_list[0] 886 | for idx in sorted_list: 887 | foreign_key_map[cols[idx]] = cols[midx] 888 | 889 | return foreign_key_map 890 | 891 | 892 | def build_foreign_key_map_from_json(table): 893 | with open(table) as f: 894 | data = json.load(f) 895 | tables = {} 896 | for entry in data: 897 | tables[entry['db_id']] = build_foreign_key_map(entry) 898 | return tables 899 | 900 | 901 | if __name__ == "__main__": 902 | parser = argparse.ArgumentParser() 903 | parser.add_argument('--gold', dest='gold', type=str) 904 | parser.add_argument('--pred', dest='pred', type=str) 905 | 
parser.add_argument('--db', dest='db', type=str) 906 | parser.add_argument('--table', dest='table', type=str) 907 | parser.add_argument('--etype', dest='etype', type=str) 908 | args = parser.parse_args() 909 | 910 | gold = args.gold 911 | pred = args.pred 912 | db_dir = args.db 913 | table = args.table 914 | etype = args.etype 915 | 916 | assert etype in ["all", "exec", "match"], "Unknown evaluation method" 917 | 918 | kmaps = build_foreign_key_map_from_json(table) 919 | 920 | evaluate(gold, pred, db_dir, etype, kmaps) 921 | --------------------------------------------------------------------------------
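The per-column comparison at the heart of `eval_exec_match` (and `execute`) is easy to miss in context, so here is a self-contained sketch of its `res_map` helper on toy data. The column ids (`__singer.name__`, `__singer.age__`) and rows are made up for illustration; the helper itself is reproduced from the script above.

```python
# res_map, reproduced from eval_exec_match: it keys each result column by its
# val_unit, so executed results are compared per selected column rather than
# by positional column order.
def res_map(res, val_units):
    rmap = {}
    for idx, val_unit in enumerate(val_units):
        key = tuple(val_unit[1]) if not val_unit[2] else (val_unit[0], tuple(val_unit[1]), tuple(val_unit[2]))
        rmap[key] = [r[idx] for r in res]
    return rmap

# Toy val_units (hypothetical column ids): a col_unit is (agg_id, col_id, distinct)
# and a val_unit is (unit_op, col_unit, None) when there is no column arithmetic.
name_col = (0, '__singer.name__', False)
age_col = (0, '__singer.age__', False)
vu_name = (0, name_col, None)
vu_age = (0, age_col, None)

# The same data with the two SELECT columns in opposite order ...
pred_rows = [('Joe', 52), ('Ann', 43)]
gold_rows = [(52, 'Joe'), (43, 'Ann')]

# ... still compares equal, because columns are matched by val_unit, not by
# position. Note the per-column value lists keep row order, so the overall
# check remains sensitive to row ordering.
assert res_map(pred_rows, [vu_name, vu_age]) == res_map(gold_rows, [vu_age, vu_name])
```

A consequence of this design is that a prediction selecting the right columns in a different order can still earn execution credit, which positional tuple comparison would miss.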