├── LICENSE
├── README.md
├── dev_gold.sql
├── predicted_sql.txt
├── run_c3sql.sh
├── scripts
│   ├── .DS_Store
│   └── prepare_dataset.sh
└── src
    ├── .DS_Store
    ├── __pycache__
    │   ├── bridge_content_encoder.cpython-39.pyc
    │   ├── get_selfconsistent_output.cpython-39.pyc
    │   └── sql_post_process.cpython-39.pyc
    ├── bridge_content_encoder.py
    ├── column_recall.py
    ├── generate_sqls_by_gpt3.5.py
    ├── get_selfconsistent_output.py
    ├── preprocessing.py
    ├── prompt_generate.py
    ├── schema_item_classifier.py
    ├── sql_post_process.py
    ├── table_recall.py
    └── text2sql_data_generator.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 bigbigwatermalon
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # C3SQL
2 | The code for the paper C3: Zero-shot Text-to-SQL with ChatGPT ([https://arxiv.org/abs/2307.07306](https://arxiv.org/abs/2307.07306))
3 |
4 | ## Prepare Spider Data
5 |
6 | Download the [spider data](https://drive.google.com/uc?export=download&id=1TqleXec_OykOYFREKKtschzY29dUcVAQ) and the databases (only the original Spider database for now), and then unzip them:
7 |
8 | ```shell
9 | mkdir data
10 | unzip spider.zip
11 | mv spider/database .
12 | mv spider data
13 | ```
14 |
15 | ## Run Inference
16 | Run the command below; the predicted SQL will be saved to a file named "predicted_sql.txt".
17 | ```shell
18 | bash run_c3sql.sh
19 | ```
20 |
21 | ## Run Evaluation
22 | Add your OpenAI API key to the *generate_sqls_by_gpt3.5.py*, *column_recall.py*, and *table_recall.py* files:
23 | ```shell
24 | openai.api_key = "your_api_key"
25 | ```
26 |
27 | Clone the evaluation scripts (test-suite-sql-eval: [https://github.com/taoyds/test-suite-sql-eval](https://github.com/taoyds/test-suite-sql-eval)):
28 |
29 | ```shell
30 | mkdir third_party
31 | cd third_party
32 | git clone https://github.com/taoyds/test-suite-sql-eval
33 | cd ..
34 | ```
35 |
36 |
37 | Put 'predicted_sql.txt' in the current directory.
38 |
39 | Then you can run the evaluation and see the results on the dev data.
40 | For testing, simply replace '**dev_gold.sql**' with your test gold file, the '**database**' folder with your test databases, and '**spider/tables.json**' with your test tables.json.
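For example, a test-time invocation might look like this (the file names below are placeholders for your own test files):

```shell
python third_party/test-suite-sql-eval/evaluation.py --gold test_gold.sql --pred predicted_sql.txt --db test_database --table test_tables.json --etype all
```

The dev-set evaluation command is: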
41 | ```shell 42 | python third_party/test-suite-sql-eval/evaluation.py --gold dev_gold.sql --pred predicted_sql.txt --db database --table data/spider/tables.json --etype all 43 | ``` 44 | -------------------------------------------------------------------------------- /run_c3sql.sh: -------------------------------------------------------------------------------- 1 | set -e 2 | 3 | tables="./data/spider/tables.json" 4 | dataset_path="./data/spider/dev.json" 5 | db_dir="database" 6 | output_dataset_path="predicted_sql.txt" 7 | 8 | processed_dataset_path="./generate_datasets/C3_dev.json" 9 | 10 | # preprocess data 11 | bash scripts/prepare_dataset.sh $tables $dataset_path $db_dir $processed_dataset_path 12 | # run prediction 13 | python src/generate_sqls_by_gpt3.5.py --input_dataset_path $processed_dataset_path --output_dataset_path $output_dataset_path --db_dir $db_dir 14 | 15 | -------------------------------------------------------------------------------- /scripts/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigbigwatermalon/C3SQL/3fb33a61e458f4f67c2b8b59c1c4b4e20b4a7461/scripts/.DS_Store -------------------------------------------------------------------------------- /scripts/prepare_dataset.sh: -------------------------------------------------------------------------------- 1 | set -e 2 | 3 | if [ ! -d "generate_datasets" ]; then 4 | mkdir generate_datasets 5 | echo "create directory generate_datasets" 6 | else 7 | echo "directory generate_datasets already exists" 8 | fi 9 | 10 | tables=$1 11 | dataset_path=$2 12 | device="0" 13 | db_path=$3 14 | processed_dataset_path=$4 15 | # preprocess test set 16 | echo "preprocessing..." 17 | python src/preprocessing.py \ 18 | --mode "test" \ 19 | --table_path $tables \ 20 | --input_dataset_path $dataset_path \ 21 | --output_dataset_path "./generate_datasets/preprocessed_data.json" \ 22 | --db_path "$db_path" \ 23 | --target_type "sql" 24 | 25 | # recall tables 26 | echo "recall tables..." 27 | python src/table_recall.py \ 28 | --input_dataset_path "./generate_datasets/preprocessed_data.json" \ 29 | --output_recalled_tables_path "./generate_datasets/table_recall.json" \ 30 | 31 | # recall columns 32 | echo "recall columns..." 33 | python src/column_recall.py \ 34 | --input_recalled_tables_path "./generate_datasets/table_recall.json" \ 35 | --output_recalled_columns_path "./generate_datasets/column_recall.json" \ 36 | 37 | # generate prompt 38 | echo "generate prompt..." 
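# Note: prompt_generate.py writes one zero-shot prompt per question into the
# "input_sequence" field of the output JSON; generate_sqls_by_gpt3.5.py later sends
# that field to ChatGPT. A generated prompt roughly looks like the sketch below
# (table/column names and cell values are illustrative):
#
#   ### Complete sqlite SQL query only and with no explanation, and do not select
#   extra columns that are not explicitly requested in the query.
#   ### Sqlite SQL tables, with their properties:
#   #
#   # singer ( singer_id, name("Joe Sharp"), country("France"), age )
#   # concert ( concert_id, theme, year )
#   # singer_in_concert.singer_id = singer.singer_id
#   #
#   ### How many singers do we have?
#   SELECT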
39 | python src/prompt_generate.py \ 40 | --input_dataset_path "./generate_datasets/column_recall.json" \ 41 | --output_dataset_path $processed_dataset_path \ 42 | 43 | -------------------------------------------------------------------------------- /src/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigbigwatermalon/C3SQL/3fb33a61e458f4f67c2b8b59c1c4b4e20b4a7461/src/.DS_Store -------------------------------------------------------------------------------- /src/__pycache__/bridge_content_encoder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigbigwatermalon/C3SQL/3fb33a61e458f4f67c2b8b59c1c4b4e20b4a7461/src/__pycache__/bridge_content_encoder.cpython-39.pyc -------------------------------------------------------------------------------- /src/__pycache__/get_selfconsistent_output.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigbigwatermalon/C3SQL/3fb33a61e458f4f67c2b8b59c1c4b4e20b4a7461/src/__pycache__/get_selfconsistent_output.cpython-39.pyc -------------------------------------------------------------------------------- /src/__pycache__/sql_post_process.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigbigwatermalon/C3SQL/3fb33a61e458f4f67c2b8b59c1c4b4e20b4a7461/src/__pycache__/sql_post_process.cpython-39.pyc -------------------------------------------------------------------------------- /src/bridge_content_encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2020, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | 7 | Encode DB content. 
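
In this repository, preprocessing.py calls get_database_matches() for every
(table, column) pair to collect the cell values that fuzzily match spans of the
question; the recall and prompt-generation steps later display those values next
to the corresponding column names.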
8 | """ 9 | 10 | import difflib 11 | from typing import List, Optional, Tuple 12 | from rapidfuzz import fuzz 13 | import sqlite3 14 | import functools 15 | 16 | # fmt: off 17 | _stopwords = {'who', 'ourselves', 'down', 'only', 'were', 'him', 'at', "weren't", 'has', 'few', "it's", 'm', 'again', 18 | 'd', 'haven', 'been', 'other', 'we', 'an', 'own', 'doing', 'ma', 'hers', 'all', "haven't", 'in', 'but', 19 | "shouldn't", 'does', 'out', 'aren', 'you', "you'd", 'himself', "isn't", 'most', 'y', 'below', 'is', 20 | "wasn't", 'hasn', 'them', 'wouldn', 'against', 'this', 'about', 'there', 'don', "that'll", 'a', 'being', 21 | 'with', 'your', 'theirs', 'its', 'any', 'why', 'now', 'during', 'weren', 'if', 'should', 'those', 'be', 22 | 'they', 'o', 't', 'of', 'or', 'me', 'i', 'some', 'her', 'do', 'will', 'yours', 'for', 'mightn', 'nor', 23 | 'needn', 'the', 'until', "couldn't", 'he', 'which', 'yourself', 'to', "needn't", "you're", 'because', 24 | 'their', 'where', 'it', "didn't", 've', 'whom', "should've", 'can', "shan't", 'on', 'had', 'have', 25 | 'myself', 'am', "don't", 'under', 'was', "won't", 'these', 'so', 'as', 'after', 'above', 'each', 'ours', 26 | 'hadn', 'having', 'wasn', 's', 'doesn', "hadn't", 'than', 'by', 'that', 'both', 'herself', 'his', 27 | "wouldn't", 'into', "doesn't", 'before', 'my', 'won', 'more', 'are', 'through', 'same', 'how', 'what', 28 | 'over', 'll', 'yourselves', 'up', 'mustn', "mustn't", "she's", 're', 'such', 'didn', "you'll", 'shan', 29 | 'when', "you've", 'themselves', "mightn't", 'she', 'from', 'isn', 'ain', 'between', 'once', 'here', 30 | 'shouldn', 'our', 'and', 'not', 'too', 'very', 'further', 'while', 'off', 'couldn', "hasn't", 'itself', 31 | 'then', 'did', 'just', "aren't"} 32 | # fmt: on 33 | 34 | _commonwords = {"no", "yes", "many"} 35 | 36 | 37 | def is_number(s: str) -> bool: 38 | try: 39 | float(s.replace(",", "")) 40 | return True 41 | except: 42 | return False 43 | 44 | 45 | def is_stopword(s: str) -> bool: 46 | return s.strip() in _stopwords 47 | 48 | 49 | def is_commonword(s: str) -> bool: 50 | return s.strip() in _commonwords 51 | 52 | 53 | def is_common_db_term(s: str) -> bool: 54 | return s.strip() in ["id"] 55 | 56 | 57 | class Match(object): 58 | def __init__(self, start: int, size: int) -> None: 59 | self.start = start 60 | self.size = size 61 | 62 | 63 | def is_span_separator(c: str) -> bool: 64 | return c in "'\"()`,.?! 
" 65 | 66 | 67 | def split(s: str) -> List[str]: 68 | return [c.lower() for c in s.strip()] 69 | 70 | 71 | def prefix_match(s1: str, s2: str) -> bool: 72 | i, j = 0, 0 73 | for i in range(len(s1)): 74 | if not is_span_separator(s1[i]): 75 | break 76 | for j in range(len(s2)): 77 | if not is_span_separator(s2[j]): 78 | break 79 | if i < len(s1) and j < len(s2): 80 | return s1[i] == s2[j] 81 | elif i >= len(s1) and j >= len(s2): 82 | return True 83 | else: 84 | return False 85 | 86 | 87 | def get_effective_match_source(s: str, start: int, end: int) -> Match: 88 | _start = -1 89 | 90 | for i in range(start, start - 2, -1): 91 | if i < 0: 92 | _start = i + 1 93 | break 94 | if is_span_separator(s[i]): 95 | _start = i 96 | break 97 | 98 | if _start < 0: 99 | return None 100 | 101 | _end = -1 102 | for i in range(end - 1, end + 3): 103 | if i >= len(s): 104 | _end = i - 1 105 | break 106 | if is_span_separator(s[i]): 107 | _end = i 108 | break 109 | 110 | if _end < 0: 111 | return None 112 | 113 | while _start < len(s) and is_span_separator(s[_start]): 114 | _start += 1 115 | while _end >= 0 and is_span_separator(s[_end]): 116 | _end -= 1 117 | 118 | return Match(_start, _end - _start + 1) 119 | 120 | 121 | def get_matched_entries( 122 | s: str, field_values: List[str], m_theta: float = 0.85, s_theta: float = 0.85 123 | ) -> Optional[List[Tuple[str, Tuple[str, str, float, float, int]]]]: 124 | if not field_values: 125 | return None 126 | 127 | if isinstance(s, str): 128 | n_grams = split(s) 129 | else: 130 | n_grams = s 131 | 132 | matched = dict() 133 | for field_value in field_values: 134 | if not isinstance(field_value, str): 135 | continue 136 | fv_tokens = split(field_value) 137 | sm = difflib.SequenceMatcher(None, n_grams, fv_tokens) 138 | match = sm.find_longest_match(0, len(n_grams), 0, len(fv_tokens)) 139 | if match.size > 0: 140 | source_match = get_effective_match_source( 141 | n_grams, match.a, match.a + match.size 142 | ) 143 | if source_match and source_match.size > 1: 144 | match_str = field_value[match.b: match.b + match.size] 145 | source_match_str = s[ 146 | source_match.start: source_match.start + source_match.size 147 | ] 148 | c_match_str = match_str.lower().strip() 149 | c_source_match_str = source_match_str.lower().strip() 150 | c_field_value = field_value.lower().strip() 151 | if ( 152 | c_match_str 153 | and not is_number(c_match_str) 154 | and not is_common_db_term(c_match_str) 155 | ): 156 | if ( 157 | is_stopword(c_match_str) 158 | or is_stopword(c_source_match_str) 159 | or is_stopword(c_field_value) 160 | ): 161 | continue 162 | if c_source_match_str.endswith(c_match_str + "'s"): 163 | match_score = 1.0 164 | else: 165 | if prefix_match(c_field_value, c_source_match_str): 166 | match_score = ( 167 | fuzz.ratio(c_field_value, c_source_match_str) / 100 168 | ) 169 | else: 170 | match_score = 0 171 | if ( 172 | is_commonword(c_match_str) 173 | or is_commonword(c_source_match_str) 174 | or is_commonword(c_field_value) 175 | ) and match_score < 1: 176 | continue 177 | s_match_score = match_score 178 | if match_score >= m_theta and s_match_score >= s_theta: 179 | if field_value.isupper() and match_score * s_match_score < 1: 180 | continue 181 | matched[match_str] = ( 182 | field_value, 183 | source_match_str, 184 | match_score, 185 | s_match_score, 186 | match.size, 187 | ) 188 | 189 | if not matched: 190 | return None 191 | else: 192 | return sorted( 193 | matched.items(), 194 | key=lambda x: (1e16 * x[1][2] + 1e8 * x[1][3] + x[1][4]), 195 | reverse=True, 196 | ) 197 | 
198 | 199 | @functools.lru_cache(maxsize=1000, typed=False) 200 | def get_column_picklist(table_name: str, column_name: str, db_path: str) -> list: 201 | fetch_sql = "SELECT DISTINCT `{}` FROM `{}`".format(column_name, table_name) 202 | try: 203 | # print(f"db_path: {db_path}") 204 | conn = sqlite3.connect(db_path) 205 | conn.text_factory = bytes 206 | c = conn.cursor() 207 | c.execute(fetch_sql) 208 | picklist = set() 209 | for x in c.fetchall(): 210 | if isinstance(x[0], str): 211 | picklist.add(x[0].encode("utf-8")) 212 | elif isinstance(x[0], bytes): 213 | try: 214 | picklist.add(x[0].decode("utf-8")) 215 | except UnicodeDecodeError: 216 | picklist.add(x[0].decode("latin-1")) 217 | else: 218 | picklist.add(x[0]) 219 | picklist = list(picklist) 220 | finally: 221 | conn.close() 222 | return picklist 223 | 224 | 225 | def get_database_matches( 226 | question: str, 227 | table_name: str, 228 | column_name: str, 229 | db_path: str, 230 | top_k_matches: int = 2, 231 | match_threshold: float = 0.85, 232 | ) -> List[str]: 233 | picklist = get_column_picklist( 234 | table_name=table_name, column_name=column_name, db_path=db_path 235 | ) 236 | # only maintain data in ``str'' type 237 | picklist = [ele.strip() for ele in picklist if isinstance(ele, str)] 238 | # picklist is unordered, we sort it to ensure the reproduction stability 239 | picklist = sorted(picklist) 240 | matches = [] 241 | if picklist and isinstance(picklist[0], str): 242 | matched_entries = get_matched_entries( 243 | s=question, 244 | field_values=picklist, 245 | m_theta=match_threshold, 246 | s_theta=match_threshold, 247 | ) 248 | 249 | if matched_entries: 250 | num_values_inserted = 0 251 | for _match_str, ( 252 | field_value, 253 | _s_match_str, 254 | match_score, 255 | s_match_score, 256 | _match_size, 257 | ) in matched_entries: 258 | if "name" in column_name and match_score * s_match_score < 1: 259 | continue 260 | if table_name != "sqlite_sequence": # Spider database artifact 261 | matches.append(field_value.strip()) 262 | num_values_inserted += 1 263 | if num_values_inserted >= top_k_matches: 264 | break 265 | 266 | # # if the length of value type is less than 4, add it. 
267 | # if len(matches) == 0: 268 | # pick_set = set(picklist) 269 | # if len(pick_set) <= 3: 270 | # matches = [item for item in pick_set] 271 | return matches 272 | -------------------------------------------------------------------------------- /src/column_recall.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import openai 4 | import time 5 | from tqdm import tqdm 6 | from collections import Counter 7 | 8 | # add your openai api key 9 | openai.api_key = "sk-" 10 | 11 | 12 | def parse_option(): 13 | parser = argparse.ArgumentParser("command line arguments for recall columns") 14 | parser.add_argument("--input_recalled_tables_path", type=str) 15 | parser.add_argument("--self_consistent", type=bool, default=True) 16 | parser.add_argument("--n", type=int, default=10, 17 | help="Size of self-consistent set") 18 | parser.add_argument("--add_fk", type=bool, default=True) 19 | parser.add_argument("--output_recalled_columns_path", type=str) 20 | 21 | opt = parser.parse_args() 22 | 23 | return opt 24 | 25 | 26 | def generate_reply(input, sc_num): 27 | completions = openai.ChatCompletion.create( 28 | model="gpt-3.5-turbo", 29 | messages=input, 30 | temperature=0.7, 31 | n=sc_num 32 | ) 33 | tabs_cols_all = [] 34 | for i in range(sc_num): 35 | raw_tab_col = completions.choices[i].message.content 36 | try: 37 | raw_tab_col = '{' + raw_tab_col.split('{', 1)[1] 38 | raw_tab_col = raw_tab_col.rsplit('}', 1)[0] + '}' 39 | raw_tab_col = json.loads(raw_tab_col) 40 | except: 41 | print('list error') 42 | return None 43 | tabs_cols_all.append(raw_tab_col) 44 | return tabs_cols_all 45 | 46 | 47 | def generate_schema(data): 48 | schema = "" 49 | for table in data['db_schema']: 50 | schema += '# ' + table['table_name_original'] + ' ( ' 51 | for i, column in enumerate(table['column_names_original']): 52 | schema += column 53 | if table['db_contents'][i]: 54 | schema += ' ( ' 55 | for value in table['db_contents'][i]: 56 | schema += value + ', ' 57 | schema = schema[:-2] + ' )' 58 | schema += ', ' 59 | schema = schema[:-2] + ' )\n' 60 | return schema 61 | 62 | 63 | def extract_fks(strings): 64 | fks = {} 65 | 66 | for string in strings: 67 | parts = string.split(' = ') 68 | left_side = parts[0].split('.') 69 | right_side = parts[1].split('.') 70 | 71 | left_table = left_side[0] 72 | left_column = left_side[1] 73 | 74 | right_table = right_side[0] 75 | right_column = right_side[1] 76 | 77 | if left_table not in fks: 78 | fks[left_table] = [] 79 | 80 | fks[left_table].append(left_column) 81 | 82 | if right_table not in fks: 83 | fks[right_table] = [] 84 | 85 | fks[right_table].append(right_column) 86 | 87 | return fks 88 | 89 | 90 | def column_sc(tabs_cols_all, tabs_cols_ori, fk_ori): 91 | candidates = {} 92 | results = {} 93 | for key in tabs_cols_ori: 94 | candidates[key] = [] 95 | 96 | # filter out invalid tables 97 | for tabs_cols in tabs_cols_all: 98 | for key, value in tabs_cols.items(): 99 | if key in tabs_cols_ori: 100 | candidates[key].append(value) 101 | 102 | for tab, cols_all in candidates.items(): 103 | cols_ori = [item.lower() for item in tabs_cols_ori[tab]] 104 | cols_sc = [] 105 | for cols in cols_all: 106 | cols_exist = [] 107 | for col in cols: 108 | if col.lower() in cols_ori: 109 | cols_exist.append(col) 110 | if len(cols_exist) == 4: 111 | break 112 | if len(cols_exist) > 0: 113 | cols_sc.append(cols_exist) 114 | # choose the top-5 columns with the highest frequency 115 | if len(cols_sc) > 0: 116 | cols_add = [] 117 
| for cols in cols_sc: 118 | cols_add = cols_add + cols 119 | counter = Counter(cols_add) 120 | most_common_cols = counter.most_common(5) 121 | temp = [] 122 | for value, count in most_common_cols: 123 | temp.append(value) 124 | results[tab] = temp 125 | else: 126 | results[tab] = [] 127 | 128 | if opt.add_fk: 129 | fk = extract_fks(fk_ori) 130 | for tab, cols in fk.items(): 131 | if tab in results: 132 | for col in cols: 133 | if col not in results[tab]: 134 | results[tab].append(col) 135 | return results 136 | 137 | 138 | def info_generate(tabs_cols, data): 139 | info = {} 140 | info['db_id'] = data['db_id'] 141 | info['question'] = data['question'] 142 | info['schema'] = tabs_cols 143 | info['fk'] = data['fk'] 144 | info['db_contents'] = {} 145 | for tab, cols in tabs_cols.items(): 146 | values = [] 147 | for tab_ori in data['db_schema']: 148 | if tab == tab_ori['table_name_original'].lower(): 149 | cols_ori = [item.lower() for item in tab_ori['column_names_original']] 150 | for i, col in enumerate(cols): 151 | index = cols_ori.index(col) 152 | values.append(tab_ori['db_contents'][index]) 153 | break 154 | info['db_contents'][tab] = values 155 | return info 156 | 157 | 158 | instruction = '''Given the database tables and question, perform the following actions: 159 | 1 - Rank the columns in each table based on the possibility of being used in the SQL, Column that matches more with the question words or the foreign key is highly relevant and must be placed ahead. You should output them in the order of the most relevant to the least relevant. 160 | Explain why you choose each column. 161 | 2 - Output a JSON object that contains all the columns in each table according to your explanation. The format should be like: 162 | { 163 | "table_1": ["column_1", "column_2", ......], 164 | "table_2": ["column_1", "column_2", ......], 165 | "table_3": ["column_1", "column_2", ......], 166 | ...... 
167 | }
168 |
169 | '''
170 |
171 | if __name__ == "__main__":
172 |     opt = parse_option()
173 |     print(opt)
174 |     with open(opt.input_recalled_tables_path) as f:
175 |         data_all = json.load(f)
176 |     res = []
177 |     if opt.self_consistent:
178 |         sc_num = opt.n
179 |     else:
180 |         sc_num = 1
181 |     for i, data in enumerate(tqdm(data_all)):
182 |         schema = generate_schema(data)
183 |         prompt = instruction + 'Schema:\n' + schema
184 |         prompt = prompt + 'Foreign keys: \n'
185 |         for fk in data['fk']:
186 |             prompt = prompt + '# ' + fk + '\n'
187 |         prompt += "\nQuestion:\n### " + data["question"]
188 |         # print(prompt)
189 |         tabs_cols_all = None
190 |         while tabs_cols_all is None:
191 |             try:
192 |                 tabs_cols_all = generate_reply([{"role": "user", "content": prompt}], sc_num)
193 |             except:
194 |                 print(f'api error, wait for 3 seconds and retry...')
195 |                 time.sleep(3)
196 |                 pass
197 |         tab_col_ori = {}
198 |         for table in data['db_schema']:
199 |             tab_col_ori[table['table_name_original'].lower()] = table['column_names_original']
200 |         tabs_cols = column_sc(tabs_cols_all, tab_col_ori, data['fk'])
201 |         info = info_generate(tabs_cols, data)
202 |         res.append(info)
203 |     # print(res)
204 |     with open(opt.output_recalled_columns_path, 'w') as f:
205 |         json.dump(res, f, indent=2)
206 |
--------------------------------------------------------------------------------
/src/generate_sqls_by_gpt3.5.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import time
4 | import openai
5 | from sql_post_process import fix_select_column
6 | import re
7 | import os
8 | import sqlite3
9 | from get_selfconsistent_output import get_sqls
10 | from tqdm import tqdm
11 |
12 | # add your openai api key
13 | openai.api_key = "sk-"
14 |
15 | chat_prompt = [
16 |     {
17 |         "role": "system",
18 |         "content": "You are now an excellent SQL writer, first I'll give you some tips and examples, and I need you to remember the tips, and do not make same mistakes."
19 |     },
20 |     {
21 |         "role": "user",
22 |         "content": """Tips 1:
23 | Question: Which A has most number of B?
24 | Gold SQL: select A from B group by A order by count ( * ) desc limit 1;
25 | Notice that the Gold SQL doesn't select COUNT(*) because the question only wants to know the A and the number should be only used in ORDER BY clause, there are many questions asks in this way, and I need you to remember this in the the following questions."""
26 |     },
27 |     {
28 |         "role": "assistant",
29 |         "content": "Thank you for the tip! I'll keep in mind that when the question only asks for a certain field, I should not include the COUNT(*) in the SELECT statement, but instead use it in the ORDER BY clause to sort the results based on the count of that field."
30 |     },
31 |     {
32 |         "role": "user",
33 |         "content": """Tips 2:
34 | Don't use "IN", "OR", "LEFT JOIN" as it might cause extra results, use "INTERSECT" or "EXCEPT" instead, and remember to use "DISTINCT" or "LIMIT" when necessary.
35 | For example,
36 | Question: Who are the A who have been nominated for both B award and C award?
37 | Gold SQL should be: select A from X where award = 'B' intersect select A from X where award = 'C';"""
38 |     },
39 |     {
40 |         "role": "assistant",
41 |         "content": "Thank you for the tip! I'll remember to use \"INTERSECT\" or \"EXCEPT\" instead of \"IN\", \"OR\", or \"LEFT JOIN\" when I want to find records that match or don't match across two tables.
Additionally, I'll make sure to use \"DISTINCT\" or \"LIMIT\" when necessary to avoid repetitive results or limit the number of results returned." 42 | } 43 | ] 44 | 45 | 46 | def parse_option(): 47 | parser = argparse.ArgumentParser("command line arguments for generate sqls") 48 | parser.add_argument("--input_dataset_path", type=str) 49 | parser.add_argument("--self_consistent", type=bool, default=True) 50 | parser.add_argument("--n", type=int, default=20, 51 | help="Size of self-consistent set") 52 | parser.add_argument("--output_dataset_path", type=str) 53 | parser.add_argument("--db_dir", type=str, default="./data/database") 54 | 55 | opt = parser.parse_args() 56 | 57 | return opt 58 | 59 | 60 | def generate_reply(messages, n): 61 | completions = openai.ChatCompletion.create( 62 | model="gpt-3.5-turbo", 63 | messages=messages, 64 | n=n 65 | ) 66 | # print(completions) 67 | mes = completions.choices[0].message.content 68 | all_p_sqls = [] 69 | for i in range(n): 70 | all_p_sqls.append(completions.choices[i].message.content.replace("\n", " ")) 71 | return all_p_sqls 72 | 73 | 74 | def replace_cur_year(query: str) -> str: 75 | return re.sub( 76 | "YEAR\s*\(\s*CURDATE\s*\(\s*\)\s*\)\s*", "2020", query, flags=re.IGNORECASE 77 | ) 78 | 79 | 80 | def get_cursor_from_path(sqlite_path: str): 81 | try: 82 | if not os.path.exists(sqlite_path): 83 | print("Openning a new connection %s" % sqlite_path) 84 | connection = sqlite3.connect(sqlite_path) 85 | except Exception as e: 86 | print(sqlite_path) 87 | raise e 88 | connection.text_factory = lambda b: b.decode(errors="ignore") 89 | cursor = connection.cursor() 90 | return cursor 91 | 92 | 93 | def exec_on_db_(sqlite_path: str, query: str): 94 | query = replace_cur_year(query) 95 | cursor = get_cursor_from_path(sqlite_path) 96 | try: 97 | cursor.execute(query) 98 | result = cursor.fetchall() 99 | cursor.close() 100 | cursor.connection.close() 101 | return "result", result 102 | except Exception as e: 103 | cursor.close() 104 | cursor.connection.close() 105 | return "exception", e 106 | 107 | 108 | def is_valid(sql, db_path): 109 | flag, _ = exec_on_db_(db_path, sql) 110 | if flag == "exception": 111 | return 0 112 | else: 113 | return 1 114 | 115 | 116 | if __name__ == '__main__': 117 | opt = parse_option() 118 | print(opt) 119 | with open(opt.input_dataset_path) as f: 120 | data = json.load(f) 121 | results = [] 122 | p_sql_final = [] 123 | if not opt.self_consistent: 124 | for i, item in enumerate(data): 125 | print("id", i) 126 | db_dir = opt.db_dir + '/' + item['db_id'] + '/' + item['db_id'] + '.sqlite' 127 | for j in range(5): 128 | messages = [] 129 | messages = chat_prompt.copy() 130 | input = item['input_sequence'] 131 | messages.append({"role": "user", "content": input}) 132 | p_sql = generate_reply(messages, 1)[0] 133 | p_sql = 'SELECT ' + p_sql 134 | p_sql = p_sql.replace("SELECT SELECT", "SELECT") 135 | p_sql = fix_select_column(p_sql) 136 | p_sql = p_sql.replace("> =", ">=").replace("< =", "<=").replace("! 
=", "!=") 137 | print(f'p_sql: {p_sql}') 138 | if is_valid(p_sql, db_dir): 139 | break 140 | else: 141 | print(f're_id: {j} p_sql: {p_sql} exec error...') 142 | time.sleep(0.5) 143 | if j < 4: 144 | print(f'generate again') 145 | p_sql_final.append(p_sql) 146 | print(p_sql_final) 147 | else: 148 | for i, item in enumerate(tqdm(data)): 149 | db_dir = opt.db_dir + '/' + item['db_id'] + '/' + item['db_id'] + '.sqlite' 150 | p_sqls = [] 151 | for j in range(5): 152 | messages = [] 153 | messages = chat_prompt.copy() 154 | input = item['input_sequence'] 155 | messages.append({"role": "user", "content": input}) 156 | reply = None 157 | while reply is None: 158 | try: 159 | reply = generate_reply(messages, opt.n) 160 | except Exception as e: 161 | print(e) 162 | print(f"api error, wait for 3 seconds and retry...") 163 | time.sleep(3) 164 | pass 165 | p_sqls = reply 166 | temp = [] 167 | for p_sql in p_sqls: 168 | p_sql = 'SELECT ' + p_sql 169 | p_sql = p_sql.replace("SELECT SELECT", "SELECT") 170 | try: 171 | p_sql = fix_select_column(p_sql) 172 | except: 173 | print(f"fix_select_column err, p_sql: {p_sql}") 174 | pass 175 | p_sql = p_sql.replace("> =", ">=").replace("< =", "<=").replace("! =", "!=") 176 | p_sql = p_sql.replace("\n", " ") 177 | while " " in p_sql: 178 | p_sql = p_sql.replace(" ", " ") 179 | temp.append(p_sql) 180 | p_sqls = temp 181 | if is_valid(p_sqls[0], db_dir): 182 | break 183 | else: 184 | print(f're_id: {j} p_sql: {p_sqls[0]} exec error...') 185 | time.sleep(0.5) 186 | if j < 4: 187 | print(f'generate again') 188 | result = {} 189 | result['db_id'] = item['db_id'] 190 | result['question'] = item['question'] 191 | result['p_sqls'] = [] 192 | for sql in p_sqls: 193 | result['p_sqls'].append(sql) 194 | results.append(result) 195 | # time.sleep(1) 196 | p_sql_final = get_sqls(results, opt.n, opt.db_dir) 197 | with open(opt.output_dataset_path, 'w') as f: 198 | for sql in p_sql_final: 199 | print(sql, file=f) 200 | -------------------------------------------------------------------------------- /src/get_selfconsistent_output.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import os 4 | import random 5 | import re 6 | import sqlite3 7 | import threading 8 | from collections import defaultdict 9 | from itertools import product 10 | from typing import Tuple, Any, List, Set 11 | import sqlparse 12 | import tqdm 13 | 14 | # from third_party.test_suite.exec_eval import eval_exec_match 15 | # from third_party.test_suite.parse import remove_distinct 16 | 17 | threadLock = threading.Lock() 18 | TIMEOUT = 60 19 | EXEC_TMP_DIR = os.path.join(os.path.dirname(__file__), "tmp") 20 | 21 | 22 | def permute_tuple(element: Tuple, perm: Tuple) -> Tuple: 23 | assert len(element) == len(perm) 24 | return tuple([element[i] for i in perm]) 25 | 26 | 27 | def unorder_row(row: Tuple) -> Tuple: 28 | return tuple(sorted(row, key=lambda x: str(x) + str(type(x)))) 29 | 30 | 31 | # unorder each row in the table 32 | # [result_1 and result_2 has the same bag of unordered row] 33 | # is a necessary condition of 34 | # [result_1 and result_2 are equivalent in denotation] 35 | def quick_rej(result1: List[Tuple], result2: List[Tuple], order_matters: bool) -> bool: 36 | s1 = [unorder_row(row) for row in result1] 37 | s2 = [unorder_row(row) for row in result2] 38 | if order_matters: 39 | return s1 == s2 40 | else: 41 | return set(s1) == set(s2) 42 | 43 | 44 | # return whether two bag of relations are equivalent 45 | def multiset_eq(l1: 
List, l2: List) -> bool: 46 | if len(l1) != len(l2): 47 | return False 48 | d = defaultdict(int) 49 | for e in l1: 50 | d[e] = d[e] + 1 51 | for e in l2: 52 | d[e] = d[e] - 1 53 | if d[e] < 0: 54 | return False 55 | return True 56 | 57 | 58 | def get_constraint_permutation(tab1_sets_by_columns: List[Set], result2: List[Tuple]): 59 | num_cols = len(result2[0]) 60 | perm_constraints = [{i for i in range(num_cols)} for _ in range(num_cols)] 61 | if num_cols <= 3: 62 | return product(*perm_constraints) 63 | 64 | # we sample 20 rows and constrain the space of permutations 65 | for _ in range(20): 66 | random_tab2_row = random.choice(result2) 67 | 68 | for tab1_col in range(num_cols): 69 | for tab2_col in set(perm_constraints[tab1_col]): 70 | if random_tab2_row[tab2_col] not in tab1_sets_by_columns[tab1_col]: 71 | perm_constraints[tab1_col].remove(tab2_col) 72 | return product(*perm_constraints) 73 | 74 | 75 | # check whether two denotations are correct 76 | def result_eq(result1: List[Tuple], result2: List[Tuple], order_matters: bool) -> bool: 77 | if len(result1) == 0 and len(result2) == 0: 78 | return True 79 | 80 | # if length is not the same, then they are definitely different bag of rows 81 | if len(result1) != len(result2): 82 | return False 83 | 84 | num_cols = len(result1[0]) 85 | 86 | # if the results do not have the same number of columns, they are different 87 | if len(result2[0]) != num_cols: 88 | return False 89 | 90 | # unorder each row and compare whether the denotation is the same 91 | # this can already find most pair of denotations that are different 92 | if not quick_rej(result1, result2, order_matters): 93 | return False 94 | 95 | # the rest of the problem is in fact more complicated than one might think 96 | # we want to find a permutation of column order and a permutation of row order, 97 | # s.t. 
result_1 is the same as result_2 98 | # we return true if we can find such column & row permutations 99 | # and false if we cannot 100 | tab1_sets_by_columns = [{row[i] for row in result1} for i in range(num_cols)] 101 | 102 | # on a high level, we enumerate all possible column permutations that might make result_1 == result_2 103 | # we decrease the size of the column permutation space by the function get_constraint_permutation 104 | # if one of the permutation make result_1, result_2 equivalent, then they are equivalent 105 | for perm in get_constraint_permutation(tab1_sets_by_columns, result2): 106 | if len(perm) != len(set(perm)): 107 | continue 108 | if num_cols == 1: 109 | result2_perm = result2 110 | else: 111 | result2_perm = [permute_tuple(element, perm) for element in result2] 112 | if order_matters: 113 | if result1 == result2_perm: 114 | return True 115 | else: 116 | # in fact the first condition must hold if the second condition holds 117 | # but the first is way more efficient implementation-wise 118 | # and we use it to quickly reject impossible candidates 119 | if set(result1) == set(result2_perm) and multiset_eq(result1, result2_perm): 120 | return True 121 | return False 122 | 123 | 124 | def replace_cur_year(query: str) -> str: 125 | return re.sub( 126 | "YEAR\s*\(\s*CURDATE\s*\(\s*\)\s*\)\s*", "2020", query, flags=re.IGNORECASE 127 | ) 128 | 129 | 130 | # get the database cursor for a sqlite database path 131 | def get_cursor_from_path(sqlite_path: str): 132 | try: 133 | if not os.path.exists(sqlite_path): 134 | print("Openning a new connection %s" % sqlite_path) 135 | connection = sqlite3.connect(sqlite_path) 136 | except Exception as e: 137 | print(sqlite_path) 138 | raise e 139 | connection.text_factory = lambda b: b.decode(errors="ignore") 140 | cursor = connection.cursor() 141 | return cursor 142 | 143 | 144 | async def exec_on_db_(sqlite_path: str, query: str) -> Tuple[str, Any]: 145 | query = replace_cur_year(query) 146 | cursor = get_cursor_from_path(sqlite_path) 147 | try: 148 | cursor.execute(query) 149 | result = cursor.fetchall() 150 | cursor.close() 151 | cursor.connection.close() 152 | return "result", result 153 | except Exception as e: 154 | cursor.close() 155 | cursor.connection.close() 156 | return "exception", e 157 | 158 | 159 | async def exec_on_db( 160 | sqlite_path: str, query: str, process_id: str = "", timeout: int = TIMEOUT 161 | ) -> Tuple[str, Any]: 162 | try: 163 | return await asyncio.wait_for(exec_on_db_(sqlite_path, query), timeout) 164 | except asyncio.TimeoutError: 165 | return ('exception', TimeoutError) 166 | except Exception as e: 167 | return ("exception", e) 168 | 169 | 170 | # postprocess the model predictions to avoid execution errors 171 | # e.g. removing spaces between ">" and "=" 172 | def postprocess(query: str) -> str: 173 | query = query.replace("> =", ">=").replace("< =", "<=").replace("! =", "!=") 174 | return query 175 | 176 | def remove_distinct(s): 177 | toks = [t.value for t in list(sqlparse.parse(s)[0].flatten())] 178 | return "".join([t for t in toks if t.lower() != "distinct"]) 179 | 180 | def get_exec_output( 181 | db: str, 182 | sql: str, 183 | plug_value: bool = False, 184 | keep_distinct: bool = False, 185 | progress_bar_for_each_datapoint: bool = False, 186 | ): 187 | # post-process the prediction. 188 | # e.g. 
removing spaces between ">" and "=" 189 | sql = postprocess(sql) 190 | 191 | if not keep_distinct: 192 | try: 193 | # if sqlparse can't parse p_str, we should not even try to execute it 194 | sql = remove_distinct(sql) 195 | except Exception as e: 196 | return "exception", [] 197 | 198 | db_dir = os.path.dirname(db) 199 | db_paths = [os.path.join(db_dir, basename) for basename in os.listdir(db_dir) if ".sqlite" in basename] 200 | # print(db_paths) 201 | if progress_bar_for_each_datapoint: 202 | ranger = tqdm.tqdm(db_paths) 203 | else: 204 | ranger = db_paths 205 | for db_path in ranger: 206 | flag, sql_denotation = asyncio.run(exec_on_db(db_path, sql)) 207 | # print(sql_denotation) 208 | return flag, sql_denotation 209 | 210 | 211 | def get_sqls(results, select_number, db_dir): 212 | db_ids = [] 213 | all_p_sqls = [] 214 | for item in results: 215 | p_sqls = [] 216 | db_ids.append(item['db_id']) 217 | for i, x in enumerate(item['p_sqls']): 218 | p_sqls.append(x) 219 | if i+1 == select_number: 220 | break 221 | all_p_sqls.append(p_sqls) 222 | chosen_p_sqls = [] 223 | for i, db_id in enumerate(tqdm.tqdm(db_ids)): 224 | p_sqls = all_p_sqls[i] 225 | db_path = f"{db_dir}/{db_id}/{db_id}" 226 | cluster_sql_list = [] 227 | map_sql2denotation = {} 228 | for sql in p_sqls: 229 | flag, denotation = get_exec_output( 230 | db_path, 231 | sql, 232 | ) 233 | if flag == "exception": 234 | continue 235 | map_sql2denotation[sql] = denotation 236 | denotation_match = False 237 | 238 | for id, cluster in enumerate(cluster_sql_list): 239 | center_sql = cluster[0] 240 | if result_eq(map_sql2denotation[center_sql], denotation, False): 241 | cluster_sql_list[id].append(sql) 242 | denotation_match = True 243 | break 244 | if not denotation_match: 245 | cluster_sql_list.append([sql]) 246 | cluster_sql_list.sort(key=lambda x: len(x), reverse=True) 247 | if not cluster_sql_list: 248 | chosen_p_sqls.append(p_sqls[0]) 249 | else: 250 | chosen_p_sqls.append(cluster_sql_list[0][0]) 251 | 252 | print("save chosen sqls and results...") 253 | 254 | return chosen_p_sqls 255 | -------------------------------------------------------------------------------- /src/preprocessing.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import argparse 4 | 5 | from bridge_content_encoder import get_database_matches 6 | from sql_metadata import Parser 7 | from tqdm import tqdm 8 | 9 | sql_keywords = ['select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', \ 10 | 'except', 'join', 'on', 'as', 'not', 'between', 'in', 'like', 'is', 'exists', 'max', 'min', \ 11 | 'count', 'sum', 'avg', 'and', 'or', 'desc', 'asc'] 12 | 13 | 14 | def parse_option(): 15 | parser = argparse.ArgumentParser("") 16 | 17 | parser.add_argument('--mode', type=str, default="train") 18 | parser.add_argument('--table_path', type=str, default="./data/spider/tables.json") 19 | parser.add_argument('--input_dataset_path', type=str, default="./data/spider/train_spider.json", 20 | help=''' 21 | options: 22 | ./data/spider/train_spider.json 23 | ./data/spider/dev.json 24 | ''') 25 | parser.add_argument('--natsql_dataset_path', type=str, default="./NatSQL/NatSQLv1_6/train_spider-natsql.json", 26 | help=''' 27 | options: 28 | ./NatSQL/NatSQLv1_6/train_spider-natsql.json 29 | ./NatSQL/NatSQLv1_6/dev-natsql.json 30 | ''') 31 | parser.add_argument('--output_dataset_path', type=str, default="./data/pre-processing/preprocessed_dataset.json", 32 | help="the filepath of preprocessed dataset.") 33 | 
parser.add_argument('--db_path', type=str, default="./data/spider/database", 34 | help="the filepath of database.") 35 | parser.add_argument("--target_type", type=str, default="sql", 36 | help="sql or natsql.") 37 | parser.add_argument("--dataset_name", type=str, default="spider") 38 | 39 | opt = parser.parse_args() 40 | 41 | return opt 42 | 43 | 44 | def get_db_contents(question, table_name_original, column_names_original, db_id, db_path): 45 | matched_contents = [] 46 | # extract matched contents for each column 47 | for column_name_original in column_names_original: 48 | matches = get_database_matches( 49 | question, 50 | table_name_original, 51 | column_name_original, 52 | db_path + "/{}/{}.sqlite".format(db_id, db_id) 53 | ) 54 | matches = sorted(matches) 55 | matched_contents.append(matches) 56 | 57 | return matched_contents 58 | 59 | 60 | def get_db_schemas(all_db_infos, opt=None): 61 | db_schemas = {} 62 | 63 | for db in all_db_infos: 64 | table_names_original = db["table_names_original"] 65 | table_names = db["table_names"] 66 | column_names_original = db["column_names_original"] 67 | column_names = db["column_names"] 68 | column_types = db["column_types"] 69 | 70 | db_schemas[db["db_id"]] = {} 71 | 72 | primary_keys, foreign_keys = [], [] 73 | # record primary keys 74 | for pk_column_idx in db["primary_keys"]: 75 | pk_table_name_original = table_names_original[column_names_original[pk_column_idx][0]] 76 | pk_column_name_original = column_names_original[pk_column_idx][1] 77 | 78 | primary_keys.append( 79 | { 80 | "table_name_original": pk_table_name_original.lower(), 81 | "column_name_original": pk_column_name_original.lower() 82 | } 83 | ) 84 | 85 | db_schemas[db["db_id"]]["pk"] = primary_keys 86 | 87 | # record foreign keys 88 | for source_column_idx, target_column_idx in db["foreign_keys"]: 89 | fk_source_table_name_original = table_names_original[column_names_original[source_column_idx][0]] 90 | fk_source_column_name_original = column_names_original[source_column_idx][1] 91 | 92 | fk_target_table_name_original = table_names_original[column_names_original[target_column_idx][0]] 93 | fk_target_column_name_original = column_names_original[target_column_idx][1] 94 | 95 | foreign_keys.append( 96 | { 97 | "source_table_name_original": fk_source_table_name_original.lower(), 98 | "source_column_name_original": fk_source_column_name_original.lower(), 99 | "target_table_name_original": fk_target_table_name_original.lower(), 100 | "target_column_name_original": fk_target_column_name_original.lower(), 101 | } 102 | ) 103 | db_schemas[db["db_id"]]["fk"] = foreign_keys 104 | 105 | db_schemas[db["db_id"]]["schema_items"] = [] 106 | for idx, table_name_original in enumerate(table_names_original): 107 | column_names_original_list = [] 108 | column_names_list = [] 109 | column_types_list = [] 110 | for column_idx, (table_idx, column_name_original) in enumerate(column_names_original): 111 | if idx == table_idx: 112 | column_names_original_list.append(column_name_original.lower()) 113 | column_names_list.append(column_names[column_idx][1].lower()) 114 | column_types_list.append(column_types[column_idx]) 115 | 116 | db_schemas[db["db_id"]]["schema_items"].append({ 117 | "table_name_original": table_name_original.lower(), 118 | "table_name": table_names[idx].lower(), 119 | "column_names": column_names_list, 120 | "column_names_original": column_names_original_list, 121 | "column_types": column_types_list 122 | }) 123 | 124 | return db_schemas 125 | 126 | 127 | def normalization(sql): 128 | def 
white_space_fix(s): 129 | parsed_s = Parser(s) 130 | s = " ".join([token.value for token in parsed_s.tokens]) 131 | 132 | return s 133 | 134 | # convert everything except text between single quotation marks to lower case 135 | def lower(s): 136 | in_quotation = False 137 | out_s = "" 138 | for char in s: 139 | if in_quotation: 140 | out_s += char 141 | else: 142 | out_s += char.lower() 143 | 144 | if char == "'": 145 | if in_quotation: 146 | in_quotation = False 147 | else: 148 | in_quotation = True 149 | 150 | return out_s 151 | 152 | # remove ";" 153 | def remove_semicolon(s): 154 | if s.endswith(";"): 155 | s = s[:-1] 156 | return s 157 | 158 | # double quotation -> single quotation 159 | def double2single(s): 160 | return s.replace("\"", "'") 161 | 162 | def add_asc(s): 163 | pattern = re.compile( 164 | r'order by (?:\w+ \( \S+ \)|\w+\.\w+|\w+)(?: (?:\+|\-|\<|\<\=|\>|\>\=) (?:\w+ \( \S+ \)|\w+\.\w+|\w+))*') 165 | if "order by" in s and "asc" not in s and "desc" not in s: 166 | for p_str in pattern.findall(s): 167 | s = s.replace(p_str, p_str + " asc") 168 | 169 | return s 170 | 171 | def remove_table_alias(s): 172 | tables_aliases = Parser(s).tables_aliases 173 | new_tables_aliases = {} 174 | for i in range(1, 11): 175 | if "t{}".format(i) in tables_aliases.keys(): 176 | new_tables_aliases["t{}".format(i)] = tables_aliases["t{}".format(i)] 177 | 178 | tables_aliases = new_tables_aliases 179 | for k, v in tables_aliases.items(): 180 | s = s.replace("as " + k + " ", "") 181 | s = s.replace(k, v) 182 | 183 | return s 184 | 185 | processing_func = lambda x: remove_table_alias(add_asc(lower(white_space_fix(double2single(remove_semicolon(x)))))) 186 | 187 | return processing_func(sql) 188 | 189 | 190 | # extract the skeleton of sql and natsql 191 | def extract_skeleton(sql, db_schema): 192 | table_names_original, table_dot_column_names_original, column_names_original = [], [], [] 193 | for table in db_schema["schema_items"]: 194 | table_name_original = table["table_name_original"] 195 | table_names_original.append(table_name_original) 196 | 197 | for column_name_original in ["*"] + table["column_names_original"]: 198 | table_dot_column_names_original.append(table_name_original + "." 
+ column_name_original) 199 | column_names_original.append(column_name_original) 200 | 201 | parsed_sql = Parser(sql) 202 | new_sql_tokens = [] 203 | for token in parsed_sql.tokens: 204 | # mask table names 205 | if token.value in table_names_original: 206 | new_sql_tokens.append("_") 207 | # mask column names 208 | elif token.value in column_names_original \ 209 | or token.value in table_dot_column_names_original: 210 | new_sql_tokens.append("_") 211 | # mask string values 212 | elif token.value.startswith("'") and token.value.endswith("'"): 213 | new_sql_tokens.append("_") 214 | # mask positive int number 215 | elif token.value.isdigit(): 216 | new_sql_tokens.append("_") 217 | # mask negative int number 218 | elif isNegativeInt(token.value): 219 | new_sql_tokens.append("_") 220 | # mask float number 221 | elif isFloat(token.value): 222 | new_sql_tokens.append("_") 223 | else: 224 | new_sql_tokens.append(token.value.strip()) 225 | 226 | sql_skeleton = " ".join(new_sql_tokens) 227 | 228 | # remove JOIN ON keywords 229 | sql_skeleton = sql_skeleton.replace("on _ = _ and _ = _", "on _ = _") 230 | sql_skeleton = sql_skeleton.replace("on _ = _ or _ = _", "on _ = _") 231 | sql_skeleton = sql_skeleton.replace(" on _ = _", "") 232 | pattern3 = re.compile("_ (?:join _ ?)+") 233 | sql_skeleton = re.sub(pattern3, "_ ", sql_skeleton) 234 | 235 | # "_ , _ , ..., _" -> "_" 236 | while ("_ , _" in sql_skeleton): 237 | sql_skeleton = sql_skeleton.replace("_ , _", "_") 238 | 239 | # remove clauses in WHERE keywords 240 | ops = ["=", "!=", ">", ">=", "<", "<="] 241 | for op in ops: 242 | if "_ {} _".format(op) in sql_skeleton: 243 | sql_skeleton = sql_skeleton.replace("_ {} _".format(op), "_") 244 | while ("where _ and _" in sql_skeleton or "where _ or _" in sql_skeleton): 245 | if "where _ and _" in sql_skeleton: 246 | sql_skeleton = sql_skeleton.replace("where _ and _", "where _") 247 | if "where _ or _" in sql_skeleton: 248 | sql_skeleton = sql_skeleton.replace("where _ or _", "where _") 249 | 250 | # remove additional spaces in the skeleton 251 | while " " in sql_skeleton: 252 | sql_skeleton = sql_skeleton.replace(" ", " ") 253 | 254 | return sql_skeleton 255 | 256 | 257 | def isNegativeInt(string): 258 | if string.startswith("-") and string[1:].isdigit(): 259 | return True 260 | else: 261 | return False 262 | 263 | 264 | def isFloat(string): 265 | if string.startswith("-"): 266 | string = string[1:] 267 | 268 | s = string.split(".") 269 | if len(s) > 2: 270 | return False 271 | else: 272 | for s_i in s: 273 | if not s_i.isdigit(): 274 | return False 275 | return True 276 | 277 | 278 | def main(opt): 279 | dataset = json.load(open(opt.input_dataset_path)) 280 | all_db_infos = json.load(open(opt.table_path)) 281 | 282 | assert opt.mode in ["train", "eval", "test"] 283 | 284 | if opt.mode in ["train", "eval"] and opt.target_type == "natsql": 285 | # only train_spider.json and dev.json have corresponding natsql dataset 286 | natsql_dataset = json.load(open(opt.natsql_dataset_path)) 287 | else: 288 | # empty natsql dataset 289 | natsql_dataset = [None for _ in range(len(dataset))] 290 | 291 | db_schemas = get_db_schemas(all_db_infos, opt) 292 | 293 | preprocessed_dataset = [] 294 | 295 | for natsql_data, data in tqdm(zip(natsql_dataset, dataset)): 296 | if data[ 297 | 'query'] == 'SELECT T1.company_name FROM Third_Party_Companies AS T1 JOIN Maintenance_Contracts AS T2 ON T1.company_id = T2.maintenance_contract_company_id JOIN Ref_Company_Types AS T3 ON T1.company_type_code = T3.company_type_code ORDER BY 
T2.contract_end_date DESC LIMIT 1': 298 | data[ 299 | 'query'] = 'SELECT T1.company_type FROM Third_Party_Companies AS T1 JOIN Maintenance_Contracts AS T2 ON T1.company_id = T2.maintenance_contract_company_id ORDER BY T2.contract_end_date DESC LIMIT 1' 300 | data['query_toks'] = ['SELECT', 'T1.company_type', 'FROM', 'Third_Party_Companies', 'AS', 'T1', 'JOIN', 301 | 'Maintenance_Contracts', 'AS', 'T2', 'ON', 'T1.company_id', '=', 302 | 'T2.maintenance_contract_company_id', 'ORDER', 'BY', 'T2.contract_end_date', 303 | 'DESC', 304 | 'LIMIT', '1'] 305 | data['query_toks_no_value'] = ['select', 't1', '.', 'company_type', 'from', 'third_party_companies', 306 | 'as', 307 | 't1', 'join', 'maintenance_contracts', 'as', 't2', 'on', 't1', '.', 308 | 'company_id', '=', 't2', '.', 'maintenance_contract_company_id', 'order', 309 | 'by', 't2', '.', 'contract_end_date', 'desc', 'limit', 'value'] 310 | data['question'] = 'What is the type of the company who concluded its contracts most recently?' 311 | data['question_toks'] = ['What', 'is', 'the', 'type', 'of', 'the', 'company', 'who', 'concluded', 'its', 312 | 'contracts', 'most', 'recently', '?'] 313 | if data['query'].startswith( 314 | 'SELECT T1.fname FROM student AS T1 JOIN lives_in AS T2 ON T1.stuid = T2.stuid WHERE T2.dormid IN'): 315 | data['query'] = data['query'].replace('IN (SELECT T2.dormid)', 'IN (SELECT T3.dormid)') 316 | index = data['query_toks'].index('(') + 2 317 | assert data['query_toks'][index] == 'T2.dormid' 318 | data['query_toks'][index] = 'T3.dormid' 319 | index = data['query_toks_no_value'].index('(') + 2 320 | assert data['query_toks_no_value'][index] == 't2' 321 | data['query_toks_no_value'][index] = 't3' 322 | 323 | question = data["question"].replace("\u2018", "'").replace("\u2019", "'").replace("\u201c", "'").replace( 324 | "\u201d", "'").strip() 325 | db_id = data["db_id"] 326 | 327 | if opt.mode == "test": 328 | sql, norm_sql, sql_skeleton = "", "", "" 329 | sql_tokens = [] 330 | 331 | natsql, norm_natsql, natsql_skeleton = "", "", "" 332 | natsql_used_columns, natsql_tokens = [], [] 333 | else: 334 | 335 | sql = data["query"].strip() 336 | norm_sql = normalization(sql).strip() 337 | sql_skeleton = extract_skeleton(norm_sql, db_schemas[db_id]).strip() 338 | sql_tokens = norm_sql.split() 339 | 340 | if natsql_data is not None: 341 | natsql = natsql_data["NatSQL"].strip() 342 | norm_natsql = normalization(natsql).strip() 343 | natsql_skeleton = extract_skeleton(norm_natsql, db_schemas[db_id]).strip() 344 | natsql_used_columns = [token for token in norm_natsql.split() if "." in token and token != "@.@"] 345 | natsql_tokens = [] 346 | for token in norm_natsql.split(): 347 | # split table_name_original.column_name_original 348 | if "." 
in token: 349 | natsql_tokens.extend(token.split(".")) 350 | else: 351 | natsql_tokens.append(token) 352 | else: 353 | natsql, norm_natsql, natsql_skeleton = "", "", "" 354 | natsql_used_columns, natsql_tokens = [], [] 355 | 356 | preprocessed_data = {} 357 | preprocessed_data["question"] = question 358 | preprocessed_data["db_id"] = db_id 359 | 360 | preprocessed_data["sql"] = sql 361 | preprocessed_data["norm_sql"] = norm_sql 362 | preprocessed_data["sql_skeleton"] = sql_skeleton 363 | 364 | preprocessed_data["natsql"] = natsql 365 | preprocessed_data["norm_natsql"] = norm_natsql 366 | preprocessed_data["natsql_skeleton"] = natsql_skeleton 367 | 368 | preprocessed_data["db_schema"] = [] 369 | preprocessed_data["pk"] = db_schemas[db_id]["pk"] 370 | preprocessed_data["fk"] = db_schemas[db_id]["fk"] 371 | preprocessed_data["table_labels"] = [] 372 | preprocessed_data["column_labels"] = [] 373 | 374 | # add database information (including table name, column name, ..., table_labels, and column labels) 375 | for table in db_schemas[db_id]["schema_items"]: 376 | db_contents = get_db_contents( 377 | question, 378 | table["table_name_original"], 379 | table["column_names_original"], 380 | db_id, 381 | opt.db_path 382 | ) 383 | 384 | preprocessed_data["db_schema"].append({ 385 | "table_name_original": table["table_name_original"], 386 | "table_name": table["table_name"], 387 | "column_names": table["column_names"], 388 | "column_names_original": table["column_names_original"], 389 | "column_types": table["column_types"], 390 | "db_contents": db_contents 391 | }) 392 | 393 | # extract table and column classification labels 394 | if opt.target_type == "sql": 395 | if table["table_name_original"] in sql_tokens: # for used tables 396 | preprocessed_data["table_labels"].append(1) 397 | column_labels = [] 398 | for column_name_original in table["column_names_original"]: 399 | if column_name_original in sql_tokens or \ 400 | table[ 401 | "table_name_original"] + "." + column_name_original in sql_tokens: # for used columns 402 | column_labels.append(1) 403 | else: 404 | column_labels.append(0) 405 | preprocessed_data["column_labels"].append(column_labels) 406 | else: # for unused tables and their columns 407 | preprocessed_data["table_labels"].append(0) 408 | preprocessed_data["column_labels"].append([0 for _ in range(len(table["column_names_original"]))]) 409 | elif opt.target_type == "natsql": 410 | if table["table_name_original"] in natsql_tokens: # for used tables 411 | preprocessed_data["table_labels"].append(1) 412 | column_labels = [] 413 | for column_name_original in table["column_names_original"]: 414 | if table[ 415 | "table_name_original"] + "." 
+ column_name_original in natsql_used_columns: # for used columns 416 | column_labels.append(1) 417 | else: 418 | column_labels.append(0) 419 | preprocessed_data["column_labels"].append(column_labels) 420 | else: 421 | preprocessed_data["table_labels"].append(0) 422 | preprocessed_data["column_labels"].append([0 for _ in range(len(table["column_names_original"]))]) 423 | else: 424 | raise ValueError("target_type should be ``sql'' or ``natsql''") 425 | 426 | preprocessed_dataset.append(preprocessed_data) 427 | 428 | with open(opt.output_dataset_path, "w") as f: 429 | preprocessed_dataset_str = json.dumps(preprocessed_dataset, indent=2) 430 | f.write(preprocessed_dataset_str) 431 | 432 | 433 | if __name__ == "__main__": 434 | opt = parse_option() 435 | main(opt) 436 | -------------------------------------------------------------------------------- /src/prompt_generate.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | 4 | def parse_option(): 5 | parser = argparse.ArgumentParser("command line arguments for generate prompt") 6 | parser.add_argument("--input_dataset_path", type=str) 7 | parser.add_argument("--output_dataset_path", type=str) 8 | 9 | opt = parser.parse_args() 10 | 11 | return opt 12 | 13 | 14 | if __name__ == "__main__": 15 | opt = parse_option() 16 | print(opt) 17 | with open(opt.input_dataset_path) as f: 18 | data_all = json.load(f) 19 | temp = [] 20 | for id, data in enumerate(data_all): 21 | data['input_sequence'] = "### Complete sqlite SQL query only and with no explanation, and do not select extra columns that are not explicitly requested in the query. " \ 22 | "\n ### Sqlite SQL tables, with their properties: \n#\n" 23 | schema = "" 24 | for tab, cols in data['schema'].items(): 25 | schema += '# ' + tab + ' ( ' 26 | for i, col in enumerate(cols): 27 | schema += col 28 | if data['db_contents'][tab][i]: 29 | schema += '("' 30 | for value in data['db_contents'][tab][i]: 31 | schema += value + '", "' 32 | schema = schema[:-4] + '")' 33 | schema += ', ' 34 | schema = schema[:-2] + ' )\n' 35 | data['input_sequence'] += schema[:-1] 36 | for fk in data['fk']: 37 | data['input_sequence'] += '\n# ' + fk 38 | data['input_sequence'] += '\n#\n### ' + data['question'] + '\nSELECT' 39 | with open(opt.output_dataset_path, 'w') as f: 40 | json.dump(data_all, f, indent=2) 41 | 42 | -------------------------------------------------------------------------------- /src/schema_item_classifier.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import transformers 5 | import argparse 6 | import torch.optim as optim 7 | 8 | from tqdm import tqdm 9 | from copy import deepcopy 10 | from tokenizers import AddedToken 11 | from utils.classifier_metric.evaluator import cls_metric, auc_metric 12 | from torch.utils.data import DataLoader 13 | from transformers import RobertaTokenizerFast 14 | from utils.classifier_model import MyClassifier 15 | from utils.classifier_loss import ClassifierLoss 16 | from transformers.trainer_utils import set_seed 17 | from torch.utils.tensorboard import SummaryWriter 18 | from utils.load_dataset import ColumnAndTableClassifierDataset 19 | from utils.print_tools import dprint 20 | 21 | 22 | def parse_option(): 23 | parser = argparse.ArgumentParser("command line arguments for fine-tuning schema item classifier.") 24 | 25 | parser.add_argument('--batch_size', type=int, default=8, 26 | help='input batch size.') 27 | 
parser.add_argument('--gradient_descent_step', type=int, default=4, 28 | help='perform gradient descent per "gradient_descent_step" steps.') 29 | parser.add_argument('--device', type=str, default="3", 30 | help='the id of used GPU device.') 31 | parser.add_argument('--learning_rate', type=float, default=3e-5, 32 | help='learning rate.') 33 | parser.add_argument('--gamma', type=float, default=1.0, 34 | help='gamma parameter in the focal loss. Recommended: [0.0-2.0].') 35 | parser.add_argument('--alpha', type=float, default=1.0, 36 | help='alpha parameter in the focal loss. Must between [0.0-1.0].') 37 | parser.add_argument('--epochs', type=int, default=128, 38 | help='training epochs.') 39 | parser.add_argument('--patience', type=int, default=32, 40 | help='patience step in early stopping. -1 means no early stopping.') 41 | parser.add_argument('--seed', type=int, default=42, 42 | help='random seed.') 43 | parser.add_argument('--save_path', type=str, default="models/schema_item_classifier", 44 | help='save path of best fine-tuned model on validation set.') 45 | parser.add_argument('--tensorboard_save_path', type=str, default=None, 46 | help='save path of tensorboard log.') 47 | parser.add_argument('--train_filepath', type=str, default="data/pre-processing/preprocessed_train_spider.json", 48 | help='path of pre-processed training dataset.') 49 | parser.add_argument('--dev_filepath', type=str, default="data/pre-processing/preprocessed_dev.json", 50 | help='path of pre-processed development dataset.') 51 | parser.add_argument('--output_filepath', type=str, default="data/pre-processing/dataset_with_pred_probs.json", 52 | help='path of the output dataset (used in eval mode).') 53 | parser.add_argument('--model_name_or_path', type=str, default="roberta-large", 54 | help='''pre-trained model name.''') 55 | parser.add_argument('--use_contents', action='store_true', 56 | help='whether to integrate db contents into input sequence') 57 | parser.add_argument('--add_fk_info', action='store_true', 58 | help='whether to add [FK] tokens into input sequence') 59 | parser.add_argument('--mode', type=str, default="train", 60 | help='trian, eval or test.') 61 | 62 | opt = parser.parse_args() 63 | 64 | return opt 65 | 66 | 67 | def prepare_batch_inputs_and_labels(batch, tokenizer): 68 | batch_size = len(batch) 69 | 70 | batch_questions = [data[0] for data in batch] 71 | 72 | batch_table_names = [data[1] for data in batch] 73 | batch_table_labels = [data[2] for data in batch] 74 | 75 | batch_column_infos = [data[3] for data in batch] 76 | batch_column_labels = [data[4] for data in batch] 77 | 78 | batch_input_tokens, batch_column_info_ids, batch_table_name_ids, batch_column_number_in_each_table = [], [], [], [] 79 | for batch_id in range(batch_size): 80 | input_tokens = [batch_questions[batch_id]] 81 | table_names_in_one_db = batch_table_names[batch_id] 82 | column_infos_in_one_db = batch_column_infos[batch_id] 83 | 84 | batch_column_number_in_each_table.append( 85 | [len(column_infos_in_one_table) for column_infos_in_one_table in column_infos_in_one_db]) 86 | 87 | column_info_ids, table_name_ids = [], [] 88 | 89 | for table_id, table_name in enumerate(table_names_in_one_db): 90 | input_tokens.append("|") 91 | input_tokens.append(table_name) 92 | table_name_ids.append(len(input_tokens) - 1) 93 | input_tokens.append(":") 94 | 95 | for column_info in column_infos_in_one_db[table_id]: 96 | input_tokens.append(column_info) 97 | column_info_ids.append(len(input_tokens) - 1) 98 | input_tokens.append(",") 99 | 100 | 
input_tokens = input_tokens[:-1] 101 | 102 | batch_input_tokens.append(input_tokens) 103 | batch_column_info_ids.append(column_info_ids) 104 | batch_table_name_ids.append(table_name_ids) 105 | 106 | # notice: the trunction operation will discard some tables and columns that exceed the max length 107 | tokenized_inputs = tokenizer( 108 | batch_input_tokens, 109 | return_tensors="pt", 110 | is_split_into_words=True, 111 | padding="max_length", 112 | max_length=512, 113 | truncation=True 114 | ) 115 | dprint(f'tokenized_inputs["input_ids"].shape {tokenized_inputs["input_ids"].shape}', "3.7") 116 | 117 | batch_aligned_question_ids, batch_aligned_column_info_ids, batch_aligned_table_name_ids = [], [], [] 118 | batch_aligned_table_labels, batch_aligned_column_labels = [], [] 119 | 120 | # align batch_question_ids, batch_column_info_ids, and batch_table_name_ids after tokenizing 121 | for batch_id in range(batch_size): 122 | word_ids = tokenized_inputs.word_ids(batch_index=batch_id) 123 | dprint(f'word_ids {word_ids}', "3.7") 124 | dprint(f'batch_input_tokens[batch_id] {batch_input_tokens[batch_id]}', "3.7") 125 | aligned_question_ids, aligned_table_name_ids, aligned_column_info_ids = [], [], [] 126 | aligned_table_labels, aligned_column_labels = [], [] 127 | 128 | # align question tokens 129 | for token_id, word_id in enumerate(word_ids): 130 | if word_id == 0: 131 | aligned_question_ids.append(token_id) 132 | 133 | # align table names 134 | for t_id, table_name_id in enumerate(batch_table_name_ids[batch_id]): 135 | temp_list = [] 136 | for token_id, word_id in enumerate(word_ids): 137 | if table_name_id == word_id: 138 | temp_list.append(token_id) 139 | # if the tokenizer doesn't discard current table name 140 | if len(temp_list) != 0: 141 | aligned_table_name_ids.append(temp_list) 142 | aligned_table_labels.append(batch_table_labels[batch_id][t_id]) 143 | dprint(f"aligned_table_labels: {aligned_table_labels}", "3.7") 144 | # align column names 145 | for c_id, column_id in enumerate(batch_column_info_ids[batch_id]): 146 | temp_list = [] 147 | for token_id, word_id in enumerate(word_ids): 148 | if column_id == word_id: 149 | temp_list.append(token_id) 150 | # if the tokenizer doesn't discard current column name 151 | if len(temp_list) != 0: 152 | aligned_column_info_ids.append(temp_list) 153 | aligned_column_labels.append(batch_column_labels[batch_id][c_id]) 154 | 155 | batch_aligned_question_ids.append(aligned_question_ids) 156 | batch_aligned_table_name_ids.append(aligned_table_name_ids) 157 | batch_aligned_column_info_ids.append(aligned_column_info_ids) 158 | batch_aligned_table_labels.append(aligned_table_labels) 159 | batch_aligned_column_labels.append(aligned_column_labels) 160 | 161 | # update column number in each table (because some tables and columns are discarded) 162 | for batch_id in range(batch_size): 163 | if len(batch_column_number_in_each_table[batch_id]) > len(batch_aligned_table_labels[batch_id]): 164 | batch_column_number_in_each_table[batch_id] = batch_column_number_in_each_table[batch_id][ 165 | : len(batch_aligned_table_labels[batch_id])] 166 | 167 | if sum(batch_column_number_in_each_table[batch_id]) > len(batch_aligned_column_labels[batch_id]): 168 | truncated_column_number = sum(batch_column_number_in_each_table[batch_id]) - len( 169 | batch_aligned_column_labels[batch_id]) 170 | batch_column_number_in_each_table[batch_id][-1] -= truncated_column_number 171 | 172 | encoder_input_ids = tokenized_inputs["input_ids"] 173 | encoder_input_attention_mask = 
tokenized_inputs["attention_mask"] 174 | batch_aligned_column_labels = [torch.LongTensor(column_labels) for column_labels in batch_aligned_column_labels] 175 | batch_aligned_table_labels = [torch.LongTensor(table_labels) for table_labels in batch_aligned_table_labels] 176 | 177 | # print("\n".join(tokenizer.batch_decode(encoder_input_ids, skip_special_tokens = True))) 178 | 179 | if torch.cuda.is_available(): 180 | encoder_input_ids = encoder_input_ids.cuda() 181 | encoder_input_attention_mask = encoder_input_attention_mask.cuda() 182 | batch_aligned_column_labels = [column_labels.cuda() for column_labels in batch_aligned_column_labels] 183 | batch_aligned_table_labels = [table_labels.cuda() for table_labels in batch_aligned_table_labels] 184 | 185 | return encoder_input_ids, encoder_input_attention_mask, \ 186 | batch_aligned_column_labels, batch_aligned_table_labels, \ 187 | batch_aligned_question_ids, batch_aligned_column_info_ids, \ 188 | batch_aligned_table_name_ids, batch_column_number_in_each_table 189 | 190 | 191 | def _train(opt): 192 | print('hyper parameters:', opt) 193 | set_seed(opt.seed) 194 | 195 | patience = opt.patience if opt.patience > 0 else float('inf') 196 | 197 | if opt.tensorboard_save_path is not None: 198 | writer = SummaryWriter(opt.tensorboard_save_path) 199 | else: 200 | writer = None 201 | 202 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.device 203 | 204 | tokenizer = RobertaTokenizerFast.from_pretrained( 205 | opt.model_name_or_path, 206 | add_prefix_space=True 207 | ) 208 | tokenizer.add_tokens(AddedToken("[FK]")) 209 | 210 | train_dataset = ColumnAndTableClassifierDataset( 211 | dir_=opt.train_filepath, 212 | use_contents=opt.use_contents, 213 | add_fk_info=opt.add_fk_info 214 | ) 215 | 216 | train_dataloder = DataLoader( 217 | train_dataset, 218 | batch_size=opt.batch_size, 219 | shuffle=True, 220 | collate_fn=lambda x: x 221 | ) 222 | 223 | dev_dataset = ColumnAndTableClassifierDataset( 224 | dir_=opt.dev_filepath, 225 | use_contents=opt.use_contents, 226 | add_fk_info=opt.add_fk_info 227 | ) 228 | 229 | dev_dataloder = DataLoader( 230 | dev_dataset, 231 | batch_size=opt.batch_size, 232 | shuffle=False, 233 | collate_fn=lambda x: x 234 | ) 235 | 236 | # initialize model 237 | model = MyClassifier( 238 | model_name_or_path=opt.model_name_or_path, 239 | vocab_size=len(tokenizer), 240 | mode=opt.mode 241 | ) 242 | 243 | if torch.cuda.is_available(): 244 | model = model.cuda() 245 | 246 | # warm up steps (10% training step) 247 | num_warmup_steps = int(0.1 * opt.epochs * len(train_dataset) / opt.batch_size) 248 | # total training steps 249 | num_training_steps = int(opt.epochs * len(train_dataset) / opt.batch_size) 250 | # evaluate model for each 1.42857 epochs (about 1.42857*7000=10000 examples for Spider) 251 | num_checkpoint_steps = int(1.42857 * len(train_dataset) / opt.batch_size) 252 | 253 | optimizer = optim.AdamW( 254 | params=model.parameters(), 255 | lr=opt.learning_rate 256 | ) 257 | 258 | scheduler = transformers.get_cosine_schedule_with_warmup( 259 | optimizer, 260 | num_warmup_steps=num_warmup_steps, 261 | num_training_steps=num_training_steps 262 | ) 263 | 264 | best_score, early_stop_step, train_step = 0, 0, 0 265 | encoder_loss_func = ClassifierLoss(alpha=opt.alpha, gamma=opt.gamma) 266 | 267 | for epoch in range(opt.epochs): 268 | print(f"This is epoch {epoch + 1}.") 269 | for batch in train_dataloder: 270 | model.train() 271 | train_step += 1 272 | 273 | encoder_input_ids, encoder_input_attention_mask, \ 274 | batch_column_labels, 
batch_table_labels, batch_aligned_question_ids, \ 275 | batch_aligned_column_info_ids, batch_aligned_table_name_ids, \ 276 | batch_column_number_in_each_table = prepare_batch_inputs_and_labels(batch, tokenizer) 277 | 278 | model_outputs = model( 279 | encoder_input_ids, 280 | encoder_input_attention_mask, 281 | batch_aligned_question_ids, 282 | batch_aligned_column_info_ids, 283 | batch_aligned_table_name_ids, 284 | batch_column_number_in_each_table 285 | ) 286 | 287 | loss = encoder_loss_func.compute_loss( 288 | model_outputs["batch_table_name_cls_logits"], 289 | batch_table_labels, 290 | model_outputs["batch_column_info_cls_logits"], 291 | batch_column_labels 292 | ) 293 | 294 | loss.backward() 295 | 296 | # update lr 297 | if scheduler is not None: 298 | scheduler.step() 299 | 300 | if writer is not None: 301 | # record training loss (tensorboard) 302 | writer.add_scalar('train loss', loss.item(), train_step) 303 | # record learning rate (tensorboard) 304 | writer.add_scalar('train lr', optimizer.state_dict()['param_groups'][0]['lr'], train_step) 305 | 306 | if train_step % opt.gradient_descent_step == 0: 307 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 308 | optimizer.step() 309 | optimizer.zero_grad() 310 | 311 | if train_step % num_checkpoint_steps == 0: 312 | print(f"At {train_step} training step, start an evaluation.") 313 | model.eval() 314 | 315 | table_labels_for_auc, column_labels_for_auc = [], [] 316 | table_pred_probs_for_auc, column_pred_probs_for_auc = [], [] 317 | 318 | for batch in dev_dataloder: 319 | encoder_input_ids, encoder_input_attention_mask, \ 320 | batch_column_labels, batch_table_labels, batch_aligned_question_ids, \ 321 | batch_aligned_column_info_ids, batch_aligned_table_name_ids, \ 322 | batch_column_number_in_each_table = prepare_batch_inputs_and_labels(batch, tokenizer) 323 | 324 | with torch.no_grad(): 325 | model_outputs = model( 326 | encoder_input_ids, 327 | encoder_input_attention_mask, 328 | batch_aligned_question_ids, 329 | batch_aligned_column_info_ids, 330 | batch_aligned_table_name_ids, 331 | batch_column_number_in_each_table 332 | ) 333 | 334 | for batch_id, table_logits in enumerate(model_outputs["batch_table_name_cls_logits"]): 335 | table_pred_probs = torch.nn.functional.softmax(table_logits, dim=1) 336 | 337 | table_pred_probs_for_auc.extend(table_pred_probs[:, 1].cpu().tolist()) 338 | table_labels_for_auc.extend(batch_table_labels[batch_id].cpu().tolist()) 339 | 340 | for batch_id, column_logits in enumerate(model_outputs["batch_column_info_cls_logits"]): 341 | column_pred_probs = torch.nn.functional.softmax(column_logits, dim=1) 342 | 343 | column_pred_probs_for_auc.extend(column_pred_probs[:, 1].cpu().tolist()) 344 | column_labels_for_auc.extend(batch_column_labels[batch_id].cpu().tolist()) 345 | 346 | # calculate AUC score for table classification 347 | table_auc = auc_metric(table_labels_for_auc, table_pred_probs_for_auc) 348 | # calculate AUC score for column classification 349 | column_auc = auc_metric(column_labels_for_auc, column_pred_probs_for_auc) 350 | print("table AUC:", table_auc) 351 | print("column AUC:", column_auc) 352 | 353 | if writer is not None: 354 | writer.add_scalar('table AUC', table_auc, train_step / num_checkpoint_steps) 355 | writer.add_scalar('column AUC', column_auc, train_step / num_checkpoint_steps) 356 | 357 | toral_auc_score = table_auc + column_auc 358 | print("total auc:", toral_auc_score) 359 | # save the best ckpt 360 | if toral_auc_score >= best_score: 361 | best_score = toral_auc_score 
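# checkpoint step: whenever the combined table+column AUC improves, the full model state_dict is written to save_path/dense_classifier.pt and the encoder config plus tokenizer are saved alongside it, so eval/test mode can rebuild the classifier from save_path alone; early_stop_step is reset only on improvement.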
362 | os.makedirs(opt.save_path, exist_ok=True) 363 | torch.save(model.state_dict(), opt.save_path + "/dense_classifier.pt") 364 | model.plm_encoder.config.save_pretrained(save_directory=opt.save_path) 365 | tokenizer.save_pretrained(save_directory=opt.save_path) 366 | early_stop_step = 0 367 | else: 368 | early_stop_step += 1 369 | 370 | print("early_stop_step:", early_stop_step) 371 | 372 | if early_stop_step >= patience: 373 | break 374 | 375 | if early_stop_step >= patience: 376 | print("Classifier training process triggers early stopping.") 377 | break 378 | 379 | print("best auc score:", best_score) 380 | 381 | 382 | def _test(opt): 383 | set_seed(opt.seed) 384 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.device 385 | 386 | # load tokenizer 387 | tokenizer = RobertaTokenizerFast.from_pretrained( 388 | opt.save_path, 389 | add_prefix_space=True 390 | ) 391 | 392 | dataset = ColumnAndTableClassifierDataset( 393 | dir_=opt.dev_filepath, 394 | use_contents=opt.use_contents, 395 | add_fk_info=opt.add_fk_info 396 | ) 397 | 398 | dataloder = DataLoader( 399 | dataset, 400 | batch_size=opt.batch_size, 401 | shuffle=False, 402 | collate_fn=lambda x: x 403 | ) 404 | 405 | # initialize model 406 | model = MyClassifier( 407 | model_name_or_path=opt.save_path, 408 | vocab_size=len(tokenizer), 409 | mode=opt.mode 410 | ) 411 | 412 | # load fine-tuned params 413 | model.load_state_dict(torch.load(opt.save_path + "/dense_classifier.pt", map_location=torch.device('cpu'))) 414 | if torch.cuda.is_available(): 415 | model = model.cuda() 416 | model.eval() 417 | 418 | table_labels_for_auc, column_labels_for_auc = [], [] 419 | table_pred_probs_for_auc, column_pred_probs_for_auc = [], [] 420 | 421 | returned_table_pred_probs, returned_column_pred_probs = [], [] 422 | 423 | for batch in tqdm(dataloder): 424 | encoder_input_ids, encoder_input_attention_mask, \ 425 | batch_column_labels, batch_table_labels, batch_aligned_question_ids, \ 426 | batch_aligned_column_info_ids, batch_aligned_table_name_ids, \ 427 | batch_column_number_in_each_table = prepare_batch_inputs_and_labels(batch, tokenizer) 428 | 429 | with torch.no_grad(): 430 | model_outputs = model( 431 | encoder_input_ids, 432 | encoder_input_attention_mask, 433 | batch_aligned_question_ids, 434 | batch_aligned_column_info_ids, 435 | batch_aligned_table_name_ids, 436 | batch_column_number_in_each_table 437 | ) 438 | 439 | for batch_id, table_logits in enumerate(model_outputs["batch_table_name_cls_logits"]): 440 | table_pred_probs = torch.nn.functional.softmax(table_logits, dim=1) 441 | returned_table_pred_probs.append(table_pred_probs[:, 1].cpu().tolist()) 442 | 443 | table_pred_probs_for_auc.extend(table_pred_probs[:, 1].cpu().tolist()) 444 | table_labels_for_auc.extend(batch_table_labels[batch_id].cpu().tolist()) 445 | 446 | for batch_id, column_logits in enumerate(model_outputs["batch_column_info_cls_logits"]): 447 | column_number_in_each_table = batch_column_number_in_each_table[batch_id] 448 | column_pred_probs = torch.nn.functional.softmax(column_logits, dim=1) 449 | returned_column_pred_probs.append([column_pred_probs[:, 1].cpu().tolist()[ 450 | sum(column_number_in_each_table[:table_id]):sum( 451 | column_number_in_each_table[:table_id + 1])] \ 452 | for table_id in range(len(column_number_in_each_table))]) 453 | 454 | column_pred_probs_for_auc.extend(column_pred_probs[:, 1].cpu().tolist()) 455 | column_labels_for_auc.extend(batch_column_labels[batch_id].cpu().tolist()) 456 | 457 | if opt.mode == "eval": 458 | # calculate AUC score for table 
classification 459 | table_auc = auc_metric(table_labels_for_auc, table_pred_probs_for_auc) 460 | # calculate AUC score for column classification 461 | column_auc = auc_metric(column_labels_for_auc, column_pred_probs_for_auc) 462 | print("table auc:", table_auc) 463 | print("column auc:", column_auc) 464 | print("total auc:", table_auc + column_auc) 465 | 466 | return returned_table_pred_probs, returned_column_pred_probs 467 | 468 | 469 | if __name__ == "__main__": 470 | opt = parse_option() 471 | if opt.mode == "train": 472 | _train(opt) 473 | elif opt.mode in ["eval", "test"]: 474 | total_table_pred_probs, total_column_pred_probs = _test(opt) 475 | 476 | with open(opt.dev_filepath, "r") as f: 477 | dataset = json.load(f) 478 | 479 | # record predicted probability 480 | truncated_data_info = [] 481 | for data_id, data in enumerate(dataset): 482 | table_num = len(data["table_labels"]) 483 | if table_num == len(total_table_pred_probs[data_id]): 484 | table_pred_probs = total_table_pred_probs[data_id] 485 | else: 486 | table_pred_probs = total_table_pred_probs[data_id] + [-1 for _ in range( 487 | table_num - len(total_table_pred_probs[data_id]))] 488 | 489 | truncated_table_ids = [] 490 | column_pred_probs = [] 491 | for table_id in range(table_num): 492 | if table_id >= len(total_column_pred_probs[data_id]): 493 | truncated_table_ids.append(table_id) 494 | column_pred_probs.append([-1 for _ in range(len(data["column_labels"][table_id]))]) 495 | continue 496 | if len(total_column_pred_probs[data_id][table_id]) == len(data["column_labels"][table_id]): 497 | column_pred_probs.append(total_column_pred_probs[data_id][table_id]) 498 | else: 499 | truncated_table_ids.append(table_id) 500 | truncated_column_num = len(data["column_labels"][table_id]) - len( 501 | total_column_pred_probs[data_id][table_id]) 502 | column_pred_probs.append( 503 | total_column_pred_probs[data_id][table_id] + [-1 for _ in range(truncated_column_num)]) 504 | 505 | data["column_pred_probs"] = column_pred_probs 506 | data["table_pred_probs"] = table_pred_probs 507 | 508 | if len(truncated_table_ids) > 0: 509 | truncated_data_info.append([data_id, truncated_table_ids]) 510 | 511 | print(truncated_data_info) 512 | # exit(0) 513 | # additionally, we need to consider and predict discarded tables and columns 514 | while len(truncated_data_info) != 0: 515 | truncated_dataset = [] 516 | for truncated_data_id, truncated_table_ids in truncated_data_info: 517 | print(dataset[truncated_data_id]["question"]) 518 | truncated_data = deepcopy(dataset[truncated_data_id]) 519 | truncated_data["db_schema"] = [truncated_data["db_schema"][table_id] for table_id in 520 | truncated_table_ids] 521 | truncated_data["table_labels"] = [truncated_data["table_labels"][table_id] for table_id in 522 | truncated_table_ids] 523 | truncated_data["column_labels"] = [truncated_data["column_labels"][table_id] for table_id in 524 | truncated_table_ids] 525 | truncated_data["table_pred_probs"] = [truncated_data["table_pred_probs"][table_id] for table_id in 526 | truncated_table_ids] 527 | truncated_data["column_pred_probs"] = [truncated_data["column_pred_probs"][table_id] for table_id in 528 | truncated_table_ids] 529 | 530 | truncated_dataset.append(truncated_data) 531 | 532 | with open("./data/pre-processing/truncated_dataset.json", "w") as f: 533 | f.write(json.dumps(truncated_dataset, indent=2)) 534 | 535 | opt.dev_filepath = "./data/pre-processing/truncated_dataset.json" 536 | total_table_pred_probs, total_column_pred_probs = _test(opt) 537 | 538 | for 
data_id, data in enumerate(truncated_dataset): 539 | table_num = len(data["table_labels"]) 540 | if table_num == len(total_table_pred_probs[data_id]): 541 | table_pred_probs = total_table_pred_probs[data_id] 542 | else: 543 | table_pred_probs = total_table_pred_probs[data_id] + [-1 for _ in range( 544 | table_num - len(total_table_pred_probs[data_id]))] 545 | 546 | column_pred_probs = [] 547 | for table_id in range(table_num): 548 | if table_id >= len(total_column_pred_probs[data_id]): 549 | column_pred_probs.append([-1 for _ in range(len(data["column_labels"][table_id]))]) 550 | continue 551 | if len(total_column_pred_probs[data_id][table_id]) == len(data["column_labels"][table_id]): 552 | column_pred_probs.append(total_column_pred_probs[data_id][table_id]) 553 | else: 554 | truncated_column_num = len(data["column_labels"][table_id]) - len( 555 | total_column_pred_probs[data_id][table_id]) 556 | column_pred_probs.append( 557 | total_column_pred_probs[data_id][table_id] + [-1 for _ in range(truncated_column_num)]) 558 | 559 | # fill the predicted probability into the dataset 560 | truncated_data_id = truncated_data_info[data_id][0] 561 | truncated_table_ids = truncated_data_info[data_id][1] 562 | for idx, truncated_table_id in enumerate(truncated_table_ids): 563 | dataset[truncated_data_id]["table_pred_probs"][truncated_table_id] = table_pred_probs[idx] 564 | dataset[truncated_data_id]["column_pred_probs"][truncated_table_id] = column_pred_probs[idx] 565 | 566 | # check if there are tables and columns in the new dataset that have not yet been predicted 567 | truncated_data_info = [] 568 | for data_id, data in enumerate(dataset): 569 | table_num = len(data["table_labels"]) 570 | 571 | truncated_table_ids = [] 572 | for table_id in range(table_num): 573 | # the current table is not predicted 574 | if data["table_pred_probs"][table_id] == -1: 575 | truncated_table_ids.append(table_id) 576 | # some columns in the current table are not predicted 577 | if data["table_pred_probs"][table_id] != -1 and -1 in data["column_pred_probs"][table_id]: 578 | truncated_table_ids.append(table_id) 579 | 580 | if len(truncated_table_ids) > 0: 581 | truncated_data_info.append([data_id, truncated_table_ids]) 582 | 583 | os.remove("./data/pre-processing/truncated_dataset.json") 584 | 585 | with open(opt.output_filepath, "w") as f: 586 | f.write(json.dumps(dataset, indent=2)) 587 | -------------------------------------------------------------------------------- /src/sql_post_process.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def fix_select_column(sql): 4 | # sql = "SELECT DISTINCT model FROM cars_data JOIN car_names ON cars_data.id = car_names.makeid JOIN model_list ON car_names.model = model_list.model WHERE year > 1980;" 5 | sql = sql.replace("\n", " ") 6 | sql_list = sql.split("=") # 给等号两边腾出空格 7 | sql = " = ".join(sql_list) 8 | while " " in sql: 9 | sql = sql.replace(" ", " ") 10 | sql_tokens = sql.split(" ") 11 | select_ids = [] 12 | from_ids = [] 13 | join_ids = [] 14 | eq_ids = [] 15 | first_where_id = -1 16 | first_group_by_id = -1 17 | first_having_id = -1 18 | for id, token in enumerate(sql_tokens): 19 | if token.lower() == "select": 20 | select_ids.append(id) 21 | if token.lower() == "from": 22 | from_ids.append(id) 23 | if token.lower() == "join": 24 | join_ids.append(id) 25 | if token.lower() == "=": 26 | eq_ids.append(id) 27 | if token.lower() == "where" and first_where_id == -1: 28 | first_where_id = id 29 | if token.lower() == "group" and id < 
len(sql_tokens) - 1 and sql_tokens[id+1].lower() == "by" and first_group_by_id == -1: 30 | first_group_by_id = id 31 | if token.lower() == "having" and first_having_id == -1: 32 | first_having_id = id 33 | 34 | if len(eq_ids) == 0 or len(join_ids) == 0: 35 | return sql 36 | # assert len(select_ids) == len(from_ids) 37 | for i in range(len(select_ids[:1])): ## only handle the outermost SELECT for now 38 | select_id = select_ids[i] 39 | from_id = from_ids[i] 40 | tmp_column_ids = [i for i in range(select_id + 1, from_id)] 41 | column_ids = [] 42 | id = 0 43 | while id < len(tmp_column_ids): 44 | item = sql_tokens[tmp_column_ids[id]] 45 | if item.lower() == "as": 46 | id += 2 47 | continue 48 | column_ids.append(tmp_column_ids[id]) 49 | id += 1 50 | column_table_mp = {} 51 | if i == len(select_ids) - 1: # last select 52 | for j in range(len(join_ids)): 53 | if (first_where_id != -1 and join_ids[j] > first_where_id) or (first_group_by_id != -1 and join_ids[j] > first_group_by_id): 54 | break 55 | eq_id = eq_ids[j] 56 | left_id, right_id = eq_id - 1, eq_id + 1 57 | left_column, right_column = sql_tokens[left_id], sql_tokens[right_id] 58 | if "." not in left_column or "." not in right_column: 59 | continue 60 | column_left = left_column.split(".")[1] 61 | column_right = right_column.split(".")[1] 62 | column_table_mp[column_left] = left_column 63 | column_table_mp[column_right] = right_column 64 | else: 65 | pass 66 | 67 | if len(column_table_mp) == 0: 68 | return sql 69 | for column_id in column_ids: 70 | column = sql_tokens[column_id] 71 | if "." not in column: 72 | if column in column_table_mp.keys(): 73 | sql_tokens[column_id] = column_table_mp[column] 74 | elif len(column) > 0 and column[-1] == "," and column[:-1] in column_table_mp.keys(): 75 | sql_tokens[column_id] = column_table_mp[column[:-1]] + "," 76 | 77 | recovered_sql = " ".join(sql_tokens) 78 | 79 | return recovered_sql 80 | 81 | 82 | -------------------------------------------------------------------------------- /src/table_recall.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import openai 4 | import time 5 | from tqdm import tqdm 6 | from collections import Counter 7 | 8 | # add your openai api key 9 | openai.api_key = "sk-" 10 | 11 | 12 | def parse_option(): 13 | parser = argparse.ArgumentParser("command line arguments for recall tables") 14 | parser.add_argument("--input_dataset_path", type=str, default='../generate_datasets/preprocessed_test.json') 15 | parser.add_argument("--self_consistent", type=bool, default=True) 16 | parser.add_argument("--n", type=int, default=10, 17 | help="Size of self-consistent set") 18 | parser.add_argument("--output_recalled_tables_path", type=str) 19 | 20 | opt = parser.parse_args() 21 | 22 | return opt 23 | 24 | 25 | def generate_reply(input, sc_num): 26 | completions = openai.ChatCompletion.create( 27 | model="gpt-3.5-turbo", 28 | messages=input, 29 | # top_p=0.5 30 | temperature=0.7, 31 | n=sc_num 32 | # stop=["Q:"] 33 | ) 34 | all_tables = [] 35 | for i in range(sc_num): 36 | raw_table = completions.choices[i].message.content 37 | try: 38 | raw_table = '[' + raw_table.split('[', 1)[1] 39 | raw_table = raw_table.rsplit(']', 1)[0] + ']' 40 | raw_table = eval(raw_table) 41 | if Ellipsis in raw_table: 42 | raw_table.remove(Ellipsis) 43 | except: 44 | print('list error') 45 | return None 46 | all_tables.append(raw_table) 47 | return all_tables 48 | # return completions.choices[0].message.content 49 | 50 | 51 | def generate_schema(data): 52 | schema = "" 53 | for table in 
data['db_schema']: 54 | schema += '# ' + table['table_name_original'] + ' ( ' 55 | for i, column in enumerate(table['column_names_original']): 56 | schema += column 57 | if table['db_contents'][i]: 58 | schema += ' ( ' 59 | for value in table['db_contents'][i]: 60 | schema += value + ', ' 61 | schema = schema[:-2] + ' )' 62 | schema += ', ' 63 | schema = schema[:-2] + ' )\n' 64 | return schema 65 | 66 | 67 | def table_sc(tables_all, tables_ori): 68 | tables_sc = [] 69 | for id, tables in enumerate(tables_all): 70 | tables_exist = [] 71 | for table in tables: 72 | if table.lower() in tables_ori: 73 | tables_exist.append(table.lower()) 74 | if len(tables_exist) == 4: 75 | break 76 | # print(tables_exist) 77 | tables_sc.append(tables_exist) 78 | counts = Counter(tuple(sorted(lst)) for lst in tables_sc) 79 | most_list, count = counts.most_common(1)[0] 80 | for table_list in tables_sc: 81 | if sorted(table_list) == list(most_list): 82 | return table_list 83 | 84 | 85 | def info_generate(tables, data): 86 | info = {} 87 | info['db_id'] = data['db_id'] 88 | info['question'] = data['question'] 89 | info['db_schema'] = [] 90 | info['fk'] = [] 91 | for table in tables: 92 | for tab_ori in data['db_schema']: 93 | if table == tab_ori['table_name_original'].lower(): 94 | info['db_schema'].append(tab_ori) 95 | break 96 | for fk in data['fk']: 97 | if fk['source_table_name_original'] in tables and fk['target_table_name_original'] in tables: 98 | fk_str = fk['source_table_name_original'] + '.' + fk['source_column_name_original'] + ' = ' \ 99 | + fk['target_table_name_original'] + '.' + fk['target_column_name_original'] 100 | info['fk'].append(fk_str) 101 | return info 102 | 103 | 104 | instruction = """Given the database schema and question, perform the following actions: 105 | 1 - Rank all the tables based on the possibility of being used in the SQL according to the question from the most relevant to the least relevant, Table or its column that matches more with the question words is highly relevant and must be placed ahead. 106 | 2 - Check whether you consider all the tables. 107 | 3 - Output a list object in the order of step 2, Your output should contain all the tables. The format should be like: 108 | [ 109 | "table_1", "table_2", ... 
110 | ] 111 | 112 | """ 113 | 114 | if __name__ == "__main__": 115 | opt = parse_option() 116 | print(opt) 117 | with open(opt.input_dataset_path) as f: 118 | data_all = json.load(f) 119 | res = [] 120 | if opt.self_consistent: 121 | sc_num = opt.n 122 | else: 123 | sc_num = 1 124 | for i, data in enumerate(tqdm(data_all)): 125 | schema = generate_schema(data) 126 | prompt = instruction + "Schema:\n" + schema + "\n" 127 | prompt += "Question:\n" + data["question"] 128 | tables_all = None 129 | while tables_all is None: 130 | try: 131 | tables_all = generate_reply([{"role": "user", "content": prompt}], sc_num) 132 | except: 133 | print(f'api error, wait for 3 seconds and retry...') 134 | time.sleep(3) 135 | pass 136 | tables_ori = [] 137 | for table in data['db_schema']: 138 | tables_ori.append(table['table_name_original'].lower()) 139 | tables = table_sc(tables_all, tables_ori) 140 | info = info_generate(tables, data) 141 | res.append(info) 142 | with open(opt.output_recalled_tables_path, 'w') as f: 143 | json.dump(res, f, indent=2) 144 | -------------------------------------------------------------------------------- /src/text2sql_data_generator.py: -------------------------------------------------------------------------------- 1 | import json 2 | import copy 3 | import argparse 4 | import random 5 | import numpy as np 6 | 7 | 8 | def parse_option(): 9 | parser = argparse.ArgumentParser("command line arguments for generating the ranked dataset.") 10 | 11 | parser.add_argument('--input_dataset_path', type=str, default="./data/pre-processing/dev_with_probs.json", 12 | help='filepath of the input dataset.') 13 | parser.add_argument('--output_dataset_path', type=str, default="./data/pre-processing/resdsql_dev.json", 14 | help='filepath of the output dataset.') 15 | parser.add_argument('--topk_table_num', type=int, default=4, 16 | help='we only remain topk_table_num tables in the ranked dataset (k_1 in the paper).') 17 | parser.add_argument('--topk_column_num', type=int, default=5, 18 | help='we only remain topk_column_num columns for each table in the ranked dataset (k_2 in the paper).') 19 | parser.add_argument('--mode', type=str, default="eval", 20 | help='type of the input dataset, options: train, eval, test.') 21 | parser.add_argument('--noise_rate', type=float, default=0.08, 22 | help='the noise rate in the ranked training dataset (needed when the mode = "train")') 23 | parser.add_argument('--use_contents', action='store_true', 24 | help='whether to add database contents in the input sequence.') 25 | parser.add_argument('--add_fk_info', action='store_true', 26 | help='whether to add foreign key in the input sequence.') 27 | parser.add_argument('--output_skeleton', action='store_true', 28 | help='whether to add skeleton in the output sequence.') 29 | parser.add_argument("--target_type", type=str, default="sql", 30 | help="sql or natsql.") 31 | # parser.add_argument("--instruction_task", type=str, default="normal") 32 | parser.add_argument("--instruction_tasks", type=str, default=["thu_prompt"], nargs="+") 33 | 34 | opt = parser.parse_args() 35 | 36 | return opt 37 | 38 | 39 | def lista_contains_listb(lista, listb): 40 | for b in listb: 41 | if b not in lista: 42 | return 0 43 | 44 | return 1 45 | 46 | 47 | def prepare_input_and_output(opt, ranked_data): 48 | question = ranked_data["question"] 49 | 50 | schema_sequence = "" 51 | for table_id in range(len(ranked_data["db_schema"])): 52 | table_name_original = ranked_data["db_schema"][table_id]["table_name_original"] 53 | # add table name 
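# serialization format built below: each table is appended as " | <table> : <table>.<col_1> ( <content> ) , <table>.<col_2> , ...", followed (when --add_fk_info is set) by " | src_table.src_col = tgt_table.tgt_col" pairs, and the whole schema string is concatenated after the question to form input_sequence.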
54 | schema_sequence += " | " + table_name_original + " : " 55 | 56 | column_info_list = [] 57 | for column_id in range(len(ranked_data["db_schema"][table_id]["column_names_original"])): 58 | # extract column name 59 | column_name_original = ranked_data["db_schema"][table_id]["column_names_original"][column_id] 60 | db_contents = ranked_data["db_schema"][table_id]["db_contents"][column_id] 61 | # use database contents if opt.use_contents = True 62 | if opt.use_contents and len(db_contents) != 0: 63 | column_contents = " , ".join(db_contents) 64 | column_info = table_name_original + "." + column_name_original + " ( " + column_contents + " ) " 65 | else: 66 | column_info = table_name_original + "." + column_name_original 67 | 68 | column_info_list.append(column_info) 69 | 70 | if opt.target_type == "natsql": 71 | column_info_list.append(table_name_original + ".*") 72 | 73 | # add column names 74 | schema_sequence += " , ".join(column_info_list) 75 | 76 | if opt.add_fk_info: 77 | for fk in ranked_data["fk"]: 78 | schema_sequence += " | " + fk["source_table_name_original"] + "." + fk["source_column_name_original"] + \ 79 | " = " + fk["target_table_name_original"] + "." + fk["target_column_name_original"] 80 | 81 | # remove additional spaces in the schema sequence 82 | while " " in schema_sequence: 83 | schema_sequence = schema_sequence.replace(" ", " ") 84 | 85 | # input_sequence = question + schema sequence 86 | input_sequence = question + schema_sequence 87 | 88 | if opt.output_skeleton: 89 | if opt.target_type == "sql": 90 | output_sequence = ranked_data["sql_skeleton"] + " | " + ranked_data["norm_sql"] 91 | elif opt.target_type == "natsql": 92 | output_sequence = ranked_data["natsql_skeleton"] + " | " + ranked_data["norm_natsql"] 93 | else: 94 | if opt.target_type == "sql": 95 | output_sequence = ranked_data["norm_sql"] 96 | elif opt.target_type == "natsql": 97 | output_sequence = ranked_data["norm_natsql"] 98 | 99 | return input_sequence, output_sequence 100 | 101 | def prepare_input_and_output_thu_prompt(opt, ranked_data): 102 | question = ranked_data["question"] 103 | 104 | schema_sequence = "" 105 | for table_id in range(len(ranked_data["db_schema"])): 106 | table_name_original = ranked_data["db_schema"][table_id]["table_name_original"] 107 | # schema_sequence += f"# {table_name_original} ( " 108 | 109 | column_info_list = [] 110 | for column_id in range(len(ranked_data["db_schema"][table_id]["column_names_original"])): 111 | # extract column name 112 | column_name_original = ranked_data["db_schema"][table_id]["column_names_original"][column_id] 113 | db_contents = ranked_data["db_schema"][table_id]["db_contents"][column_id] 114 | # use database contents if opt.use_contents = True 115 | if opt.use_contents and len(db_contents) != 0: 116 | column_contents = " , ".join(db_contents) 117 | column_info = table_name_original + "." + column_name_original + " ( " + column_contents + " ) " 118 | else: 119 | column_info = table_name_original + "." 
+ column_name_original 120 | 121 | column_info_list.append(column_info) 122 | 123 | if opt.target_type == "natsql": 124 | column_info_list.append(table_name_original + ".*") 125 | 126 | columns = " , ".join(column_info_list) 127 | schema_sequence += f"# {table_name_original} ( {columns} )\n" 128 | 129 | if opt.add_fk_info: 130 | for fk in ranked_data["fk"]: 131 | schema_sequence += f"# {fk['source_table_name_original']}.{fk['source_column_name_original']} = {fk['target_table_name_original']}.{fk['target_column_name_original']}\n" 132 | 133 | # remove additional spaces in the schema sequence 134 | while " " in schema_sequence: 135 | schema_sequence = schema_sequence.replace(" ", " ") 136 | 137 | # print(schema_sequence) 138 | 139 | # instruction = f"### Complete sqlite SQL query only and with no explanation\n### Sqlite SQL tables, with their properties: \n# \n" 140 | instruction = f"### Complete sqlite SQL query only and with no explanation, and do not select extra columns that are not explicitly requested in the query, \n### Sqlite SQL tables, with their properties: \n# \n" 141 | input_sequence = f"{instruction}{schema_sequence}#\n### {question}\nSELECT" 142 | 143 | if opt.output_skeleton: 144 | if opt.target_type == "sql": 145 | output_sequence = ranked_data["sql_skeleton"] + " | " + ranked_data["norm_sql"] 146 | elif opt.target_type == "natsql": 147 | output_sequence = ranked_data["natsql_skeleton"] + " | " + ranked_data["norm_natsql"] 148 | else: 149 | if opt.target_type == "sql": 150 | output_sequence = ranked_data["norm_sql"] 151 | elif opt.target_type == "natsql": 152 | output_sequence = ranked_data["norm_natsql"] 153 | 154 | return input_sequence, output_sequence 155 | 156 | def prepare_input_and_output_generate_skeleton(opt, ranked_data): 157 | question = ranked_data["question"] 158 | 159 | schema_sequence = "" 160 | for table_id in range(len(ranked_data["db_schema"])): 161 | table_name_original = ranked_data["db_schema"][table_id]["table_name_original"] 162 | # schema_sequence += f"# {table_name_original} ( " 163 | 164 | column_info_list = [] 165 | for column_id in range(len(ranked_data["db_schema"][table_id]["column_names_original"])): 166 | # extract column name 167 | column_name_original = ranked_data["db_schema"][table_id]["column_names_original"][column_id] 168 | db_contents = ranked_data["db_schema"][table_id]["db_contents"][column_id] 169 | # use database contents if opt.use_contents = True 170 | if opt.use_contents and len(db_contents) != 0: 171 | column_contents = " , ".join(db_contents) 172 | column_info = table_name_original + "." + column_name_original + " ( " + column_contents + " ) " 173 | else: 174 | column_info = table_name_original + "." + column_name_original 175 | 176 | column_info_list.append(column_info) 177 | 178 | if opt.target_type == "natsql": 179 | column_info_list.append(table_name_original + ".*") 180 | 181 | columns = " , ".join(column_info_list) 182 | schema_sequence += f"# {table_name_original} ( {columns} )\n" 183 | 184 | if opt.add_fk_info: 185 | for fk in ranked_data["fk"]: 186 | schema_sequence += f"# {fk['source_table_name_original']}.{fk['source_column_name_original']} = {fk['target_table_name_original']}.{fk['target_column_name_original']}\n" 187 | 188 | # remove additional spaces in the schema sequence 189 | while " " in schema_sequence: 190 | schema_sequence = schema_sequence.replace(" ", " ") 191 | 192 | # print(schema_sequence) 193 | 194 | instruction = f"### Complete sqlite SQL query, and replace the columns and tables with __. 
Notice you only need to output the sql skeleton.\n### Sqlite SQL tables, with their properties: \n# \n" 195 | input_sequence = f"{instruction}{schema_sequence}### {question}\n" 196 | 197 | if opt.target_type == "sql": 198 | output_sequence = ranked_data["sql_skeleton"] 199 | elif opt.target_type == "natsql": 200 | output_sequence = ranked_data["natsql_skeleton"] 201 | 202 | return input_sequence, output_sequence 203 | 204 | def prepare_input_and_output_fill_skeleton(opt, ranked_data): 205 | question = ranked_data["question"] 206 | 207 | schema_sequence = "" 208 | for table_id in range(len(ranked_data["db_schema"])): 209 | table_name_original = ranked_data["db_schema"][table_id]["table_name_original"] 210 | # schema_sequence += f"# {table_name_original} ( " 211 | 212 | column_info_list = [] 213 | for column_id in range(len(ranked_data["db_schema"][table_id]["column_names_original"])): 214 | # extract column name 215 | column_name_original = ranked_data["db_schema"][table_id]["column_names_original"][column_id] 216 | db_contents = ranked_data["db_schema"][table_id]["db_contents"][column_id] 217 | # use database contents if opt.use_contents = True 218 | if opt.use_contents and len(db_contents) != 0: 219 | column_contents = " , ".join(db_contents) 220 | column_info = table_name_original + "." + column_name_original + " ( " + column_contents + " ) " 221 | else: 222 | column_info = table_name_original + "." + column_name_original 223 | 224 | column_info_list.append(column_info) 225 | 226 | if opt.target_type == "natsql": 227 | column_info_list.append(table_name_original + ".*") 228 | 229 | columns = " , ".join(column_info_list) 230 | schema_sequence += f"# {table_name_original} ( {columns} )\n" 231 | 232 | if opt.add_fk_info: 233 | for fk in ranked_data["fk"]: 234 | schema_sequence += f"# {fk['source_table_name_original']}.{fk['source_column_name_original']} = {fk['target_table_name_original']}.{fk['target_column_name_original']}\n" 235 | 236 | # remove additional spaces in the schema sequence 237 | while " " in schema_sequence: 238 | schema_sequence = schema_sequence.replace(" ", " ") 239 | 240 | # print(schema_sequence) 241 | 242 | instruction = f"### Fill the blanks in the SQL skeleton to complete sqlite SQL query.\n### Sqlite SQL tables, with their properties: \n# \n" 243 | input_sequence = f"{instruction}{schema_sequence}### {question}\n" 244 | 245 | if opt.target_type == "sql": 246 | output_sequence = ranked_data["norm_sql"] 247 | skeleton = ranked_data["sql_skeleton"] 248 | elif opt.target_type == "natsql": 249 | output_sequence = ranked_data["norm_natsql"] 250 | skeleton = ranked_data["natsql_skeleton"] 251 | 252 | input_sequence = f"{instruction}{schema_sequence}### {question}\n### {skeleton}\n" 253 | 254 | return input_sequence, output_sequence 255 | 256 | def prepare_input_and_output_predict_schema_items(opt, ranked_data): 257 | pass 258 | 259 | def generate_train_ranked_dataset(opt): 260 | with open(opt.input_dataset_path) as f: 261 | dataset = json.load(f) 262 | 263 | output_dataset = [] 264 | for data_id, data in enumerate(dataset): 265 | ranked_data = dict() 266 | ranked_data["question"] = data["question"] 267 | ranked_data["sql"] = data["sql"] # unused 268 | ranked_data["norm_sql"] = data["norm_sql"] 269 | ranked_data["sql_skeleton"] = data["sql_skeleton"] 270 | ranked_data["natsql"] = data["natsql"] # unused 271 | ranked_data["norm_natsql"] = data["norm_natsql"] 272 | ranked_data["natsql_skeleton"] = data["natsql_skeleton"] 273 | ranked_data["db_id"] = data["db_id"] 274 | 
ranked_data["db_schema"] = [] 275 | 276 | # record ids of used tables 277 | used_table_ids = [idx for idx, label in enumerate(data["table_labels"]) if label == 1] 278 | topk_table_ids = copy.deepcopy(used_table_ids) 279 | 280 | if len(topk_table_ids) < opt.topk_table_num: 281 | remaining_table_ids = [idx for idx in range(len(data["table_labels"])) if idx not in topk_table_ids] 282 | # if topk_table_num is large than the total table number, all tables will be selected 283 | if opt.topk_table_num >= len(data["table_labels"]): 284 | topk_table_ids += remaining_table_ids 285 | # otherwise, we randomly select some unused tables 286 | else: 287 | randomly_sampled_table_ids = random.sample(remaining_table_ids, 288 | opt.topk_table_num - len(topk_table_ids)) 289 | topk_table_ids += randomly_sampled_table_ids 290 | 291 | # add noise to the training set 292 | if random.random() < opt.noise_rate: 293 | random.shuffle(topk_table_ids) 294 | 295 | for table_id in topk_table_ids: 296 | new_table_info = dict() 297 | new_table_info["table_name_original"] = data["db_schema"][table_id]["table_name_original"] 298 | # record ids of used columns 299 | used_column_ids = [idx for idx, column_label in enumerate(data["column_labels"][table_id]) if 300 | column_label == 1] 301 | topk_column_ids = copy.deepcopy(used_column_ids) 302 | 303 | if len(topk_column_ids) < opt.topk_column_num: 304 | remaining_column_ids = [idx for idx in range(len(data["column_labels"][table_id])) if 305 | idx not in topk_column_ids] 306 | # same as the selection of top-k tables 307 | if opt.topk_column_num >= len(data["column_labels"][table_id]): 308 | random.shuffle(remaining_column_ids) 309 | topk_column_ids += remaining_column_ids 310 | else: 311 | randomly_sampled_column_ids = random.sample(remaining_column_ids, 312 | opt.topk_column_num - len(topk_column_ids)) 313 | topk_column_ids += randomly_sampled_column_ids 314 | 315 | # add noise to the training set 316 | if random.random() < opt.noise_rate and table_id in used_table_ids: 317 | random.shuffle(topk_column_ids) 318 | 319 | new_table_info["column_names_original"] = [data["db_schema"][table_id]["column_names_original"][column_id] 320 | for column_id in topk_column_ids] 321 | new_table_info["db_contents"] = [data["db_schema"][table_id]["db_contents"][column_id] for column_id in 322 | topk_column_ids] 323 | 324 | ranked_data["db_schema"].append(new_table_info) 325 | 326 | # record foreign keys 327 | table_names_original = [table["table_name_original"] for table in data["db_schema"]] 328 | needed_fks = [] 329 | for fk in data["fk"]: 330 | source_table_id = table_names_original.index(fk["source_table_name_original"]) 331 | target_table_id = table_names_original.index(fk["target_table_name_original"]) 332 | if source_table_id in topk_table_ids and target_table_id in topk_table_ids: 333 | needed_fks.append(fk) 334 | ranked_data["fk"] = needed_fks 335 | 336 | for task in opt.instruction_tasks: 337 | 338 | prepare_function = prepare_function_map[task] 339 | input_sequence, output_sequence = prepare_function(opt, ranked_data) 340 | 341 | # record table_name_original.column_name_original for subsequent correction function during inference 342 | tc_original = [] 343 | for table in ranked_data["db_schema"]: 344 | for column_name_original in ["*"] + table["column_names_original"]: 345 | tc_original.append(table["table_name_original"] + "." 
+ column_name_original) 346 | 347 | output_dataset.append( 348 | { 349 | "db_id": data["db_id"], 350 | "input_sequence": input_sequence, 351 | "output_sequence": output_sequence, 352 | "tc_original": tc_original, 353 | "question": ranked_data["question"] 354 | } 355 | ) 356 | 357 | with open(opt.output_dataset_path, "w") as f: 358 | f.write(json.dumps(output_dataset, indent=2)) 359 | 360 | 361 | def generate_eval_ranked_dataset(opt): 362 | with open(opt.input_dataset_path) as f: 363 | dataset = json.load(f) 364 | 365 | table_coverage_state_list, column_coverage_state_list = [], [] 366 | output_dataset = [] 367 | for data_id, data in enumerate(dataset): 368 | ranked_data = dict() 369 | ranked_data["question"] = data["question"] 370 | ranked_data["sql"] = data["sql"] 371 | ranked_data["norm_sql"] = data["norm_sql"] 372 | ranked_data["sql_skeleton"] = data["sql_skeleton"] 373 | ranked_data["natsql"] = data["natsql"] 374 | ranked_data["norm_natsql"] = data["norm_natsql"] 375 | ranked_data["natsql_skeleton"] = data["natsql_skeleton"] 376 | ranked_data["db_id"] = data["db_id"] 377 | ranked_data["db_schema"] = [] 378 | 379 | table_pred_probs = list(map(lambda x: round(x, 4), data["table_pred_probs"])) 380 | # find ids of tables that have top-k probability 381 | topk_table_ids = np.argsort(-np.array(table_pred_probs), kind="stable")[:opt.topk_table_num].tolist() 382 | 383 | # if the mode == eval, we record some information for calculating the coverage 384 | if opt.mode == "eval": 385 | used_table_ids = [idx for idx, label in enumerate(data["table_labels"]) if label == 1] 386 | table_coverage_state_list.append(lista_contains_listb(topk_table_ids, used_table_ids)) 387 | 388 | for idx in range(len(data["db_schema"])): 389 | used_column_ids = [idx for idx, label in enumerate(data["column_labels"][idx]) if label == 1] 390 | if len(used_column_ids) == 0: 391 | continue 392 | column_pred_probs = list(map(lambda x: round(x, 2), data["column_pred_probs"][idx])) 393 | topk_column_ids = np.argsort(-np.array(column_pred_probs), kind="stable")[:opt.topk_column_num].tolist() 394 | column_coverage_state_list.append(lista_contains_listb(topk_column_ids, used_column_ids)) 395 | 396 | # record top-k1 tables and top-k2 columns for each table 397 | for table_id in topk_table_ids: 398 | new_table_info = dict() 399 | new_table_info["table_name_original"] = data["db_schema"][table_id]["table_name_original"] 400 | column_pred_probs = list(map(lambda x: round(x, 2), data["column_pred_probs"][table_id])) 401 | topk_column_ids = np.argsort(-np.array(column_pred_probs), kind="stable")[:opt.topk_column_num].tolist() 402 | 403 | new_table_info["column_names_original"] = [data["db_schema"][table_id]["column_names_original"][column_id] 404 | for column_id in topk_column_ids] 405 | new_table_info["db_contents"] = [data["db_schema"][table_id]["db_contents"][column_id] for column_id in 406 | topk_column_ids] 407 | 408 | ranked_data["db_schema"].append(new_table_info) 409 | 410 | # record foreign keys among selected tables 411 | table_names_original = [table["table_name_original"] for table in data["db_schema"]] 412 | needed_fks = [] 413 | for fk in data["fk"]: 414 | source_table_id = table_names_original.index(fk["source_table_name_original"]) 415 | target_table_id = table_names_original.index(fk["target_table_name_original"]) 416 | if source_table_id in topk_table_ids and target_table_id in topk_table_ids: 417 | needed_fks.append(fk) 418 | ranked_data["fk"] = needed_fks 419 | 420 | for task in opt.instruction_tasks: 421 | 
prepare_function = prepare_function_map[task] 422 | input_sequence, output_sequence = prepare_function(opt, ranked_data) 423 | 424 | # record table_name_original.column_name_original for subsequent correction function during inference 425 | tc_original = [] 426 | for table in ranked_data["db_schema"]: 427 | for column_name_original in table["column_names_original"] + ["*"]: 428 | tc_original.append(table["table_name_original"] + "." + column_name_original) 429 | 430 | output_dataset.append( 431 | { 432 | "db_id": data["db_id"], 433 | "input_sequence": input_sequence, 434 | "output_sequence": output_sequence, 435 | "tc_original": tc_original, 436 | "question": ranked_data["question"] 437 | } 438 | ) 439 | 440 | with open(opt.output_dataset_path, "w") as f: 441 | f.write(json.dumps(output_dataset, indent=2)) 442 | 443 | if opt.mode == "eval": 444 | print("Table top-{} coverage: {}".format(opt.topk_table_num, 445 | sum(table_coverage_state_list) / len(table_coverage_state_list))) 446 | print("Column top-{} coverage: {}".format(opt.topk_column_num, 447 | sum(column_coverage_state_list) / len(column_coverage_state_list))) 448 | 449 | 450 | prepare_function_map = { 451 | "normal": prepare_input_and_output, 452 | "thu_prompt": prepare_input_and_output_thu_prompt, 453 | "fill_skeleton": prepare_input_and_output_fill_skeleton, 454 | "generate_skeleton": prepare_input_and_output_generate_skeleton, 455 | } 456 | if __name__ == "__main__": 457 | opt = parse_option() 458 | print(opt) 459 | random.seed(42) 460 | 461 | if opt.mode == "train": 462 | generate_train_ranked_dataset(opt) 463 | elif opt.mode in ["eval", "test"]: 464 | generate_eval_ranked_dataset(opt) 465 | else: 466 | raise ValueError("The mode must be one of the ['train', 'eval', 'test'].") 467 | --------------------------------------------------------------------------------
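Both recall scripts resolve the n sampled ChatGPT answers by a majority vote over the set of recalled items (see table_sc in src/table_recall.py). A minimal, self-contained sketch of that voting step, using a hypothetical majority_vote helper and made-up sample data, looks like this:

```python
from collections import Counter

def majority_vote(candidate_lists):
    # treat each sampled answer as an unordered set: sort it, make it hashable,
    # then count how often each distinct set occurs across the n samples
    counts = Counter(tuple(sorted(lst)) for lst in candidate_lists)
    most_common_set, _ = counts.most_common(1)[0]
    # return the first original answer matching the winning set, so the
    # model's own ranking order is preserved (as table_sc does)
    for lst in candidate_lists:
        if tuple(sorted(lst)) == most_common_set:
            return lst

# hypothetical samples from four self-consistency runs
samples = [["singer", "concert"], ["concert", "singer"], ["singer"], ["singer", "concert"]]
print(majority_vote(samples))  # ['singer', 'concert']
```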