├── README.md
├── app.py
├── build_contents_index.py
├── data
│   ├── history
│   │   └── history.sqlite
│   └── tables.json
├── databases
│   └── singer
│       ├── schema.sql
│       └── singer.sqlite
├── images
│   └── demo.png
├── requirements.txt
├── schema_item_filter.py
├── static
│   ├── js
│   │   └── jquery-3.7.1.min.js
│   ├── logo
│   │   ├── codes_logo.png
│   │   └── user_logo.png
│   └── styles
│       └── style.css
├── templates
│   └── index.html
├── text2sql.py
└── utils
    ├── bridge_content_encoder.py
    ├── classifier_model.py
    ├── db_utils.py
    └── translate_utils.py

/README.md:
--------------------------------------------------------------------------------
1 | # Text-to-SQL Demo
2 | 
3 | ![Demo](images/demo.png)
4 | 
5 | This repository releases a text-to-SQL demo, powered by [CodeS](https://arxiv.org/abs/2402.16347), a language model specifically tailored for text-to-SQL translation.
6 | 
7 | **It is important to note that CodeS is designed as a single-turn text-to-SQL model and is not intended for multi-turn conversations.** Consequently, it cannot understand context in the chat box. Should the model's responses not meet your expectations, it is advisable to rephrase your question rather than trying to steer the model toward a correct answer with follow-up prompts.
8 | 
9 | ## Environments 💫
10 | Our development environment is configured as follows:
11 | - GPU: NVIDIA A6000 with 40GB VRAM, CUDA version 11.8
12 | - CPU: Intel(R) Xeon(R) Gold 5218R @ 2.10GHz, accompanied by 256GB of RAM
13 | - Operating System: Ubuntu 20.04.2 LTS
14 | - Python Environment: Anaconda3, Python version 3.8.5
15 | 
16 | ### Step 1: Install Java
17 | (If Java is already installed, feel free to skip this step.)
18 | 
19 | Execute the following commands in your terminal:
20 | ```bash
21 | apt-get update
22 | apt-get install -y openjdk-11-jdk
23 | ```
24 | 
25 | ### Step 2: Create and Activate a Virtual Anaconda Environment
26 | Run these commands to set up your virtual environment:
27 | ```bash
28 | conda create -n demo python=3.8.5
29 | conda activate demo
30 | ```
31 | 
32 | ### Step 3: Install Required Python Modules
33 | Ensure you have all necessary packages by running:
34 | ```bash
35 | conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.7 -c pytorch -c nvidia
36 | pip install -r requirements.txt
37 | ```
38 | 
39 | Now your environment should be all set up and ready for deployment!
40 | 
41 | ## Prerequisites 🪐
42 | ### Step 1: Download Classifier Weights
43 | Download the file [sic_ckpts.zip](https://drive.google.com/file/d/1V3F4ihTSPbV18g3lrg94VMH-kbWR_-lY/view?usp=sharing) for the schema item classifier. Then, unzip the downloaded file in the root directory of the project:
44 | ```
45 | unzip sic_ckpts.zip
46 | ```
47 | 
48 | 
49 | ### Step 2: Set Up Databases
50 | By default, this project includes only one database (i.e., `singer`) in the `databases` folder.
51 | 
52 | - To access all databases available in our online demo:
53 |   1. Download and unzip the comprehensive database collection from [databases.zip](https://pan.quark.cn/s/fc6b1ed32fc6). This package contains 248 databases, obtained by merging two manually created databases (Aminer_Simplified and Bank_Financials) with all databases from the [BIRD](https://bird-bench.github.io) and [Spider](https://yale-lily.github.io/spider) benchmarks.
54 | 
55 | - To add and use your own databases:
56 |   1. Place your SQLite database file in the `databases` directory.
57 |   2. Update the `./data/tables.json` file with the necessary information about your database (an example entry is shown at the end of this README), including:
58 |      - `db_id`: The name of your database (e.g., `my_db` for a database file located at `databases/my_db/my_db.sqlite`).
59 |      - `table_names_original` and `column_names_original`: The original names of the tables and columns in your database.
60 |      - `table_names` and `column_names`: The semantic names (or comments) for the tables and columns in your database.
61 | 
62 | ### Step 3: Build the BM25 Index
63 | To enhance the efficiency of content-based queries on the databases, run the following command to build the BM25 index:
64 | ```
65 | python -u build_contents_index.py
66 | ```
67 | Please note that this process might take a considerable amount of time, depending on the size and content of the databases. Your patience is appreciated during this step.
68 | 
69 | Upon completing these steps, your project should be fully configured.
70 | 
71 | ## Launch services 🚀
72 | To launch the web application, execute the following command:
73 | ```
74 | python -u app.py
75 | ```
76 | This starts the web application, making it accessible at `http://your_ip:5000/chatbot`. Please note that users' questions are logged and can be reviewed in the `data/history/history.sqlite` file.
77 | 
78 | ## Support for various languages 🧐
79 | Given that our model is predominantly trained on English text, integrating a translation API becomes essential for handling users' questions in languages other than English.
80 | 
81 | In this project, we use Baidu Translate. To enable multilingual support, configure your Baidu Translate API token in the `app.py` script. You can follow these guides to create your Baidu Translate application and acquire the necessary API access token: [Baidu Translate API Guide](https://ai.baidu.com/ai-doc/MT/2l317egif) and [Baidu Translate API Reference](https://ai.baidu.com/ai-doc/REFERENCE/Ck3dwjhhu).
82 | 
83 | It is important to note that translation quality may influence the model's accuracy. For better text-to-SQL performance in your preferred language, consider a more robust translation engine such as Google Translate or DeepL.
84 | 
85 | ## Get in Touch 🤗
86 | For any questions about this project, feel free to open a GitHub issue or directly contact Haoyang Li via email at lihaoyang.cs@ruc.edu.cn.
87 | 
88 | ## Acknowledgments ✨
89 | Our gratitude extends to the teams behind [ChatBot💬 WebApp in Python using Flask](https://github.com/Spidy20/Flask_NLP_ChatBot), [BIRD](https://bird-bench.github.io), and [Spider](https://yale-lily.github.io/spider) for their outstanding contributions to the field.
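## Appendix: Example `tables.json` entry 🧩

For reference (see "Set Up Databases" above), a minimal entry for a hypothetical custom database named `my_db` could look like the following. The table and column names are purely illustrative; the field layout mirrors the Spider-style format of the bundled `data/tables.json`, and `load_db_comments` in `text2sql.py` appears to read only the five fields shown here:

```json
[
  {
    "db_id": "my_db",
    "table_names_original": ["orders"],
    "table_names": ["customer orders"],
    "column_names_original": [[-1, "*"], [0, "order_id"], [0, "cust_name"]],
    "column_names": [[-1, "*"], [0, "order id"], [0, "customer name"]]
  }
]
```

Each element of `column_names_original` and `column_names` is a `[table_index, name]` pair, where the index refers to the position of the owning table in `table_names_original`; the leading `[-1, "*"]` entry is the Spider convention for "all columns".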
90 | 
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | from text2sql import ChatBot
2 | from flask import Flask, render_template, request
3 | from langdetect import detect
4 | from utils.translate_utils import translate_zh_to_en
5 | from utils.db_utils import add_a_record
6 | from langdetect.lang_detect_exception import LangDetectException
7 | 
8 | text2sql_bot = ChatBot()
9 | # replace None with your API token
10 | baidu_api_token = None
11 | 
12 | app = Flask(__name__)
13 | 
14 | @app.route("/chatbot")
15 | def home():
16 |     return render_template("index.html")
17 | 
18 | @app.route("/get_db_ids")
19 | def get_db_ids():
20 |     global text2sql_bot
21 |     return text2sql_bot.db_ids
22 | 
23 | @app.route("/get_db_ddl")
24 | def get_db_ddl():
25 |     global text2sql_bot
26 |     db_id = request.args.get('db_id')
27 | 
28 |     return text2sql_bot.db_id2ddl[db_id]
29 | 
30 | @app.route("/get")
31 | def get_bot_response():
32 |     global text2sql_bot
33 |     question = request.args.get('msg')
34 |     db_id = request.args.get('db_id')
35 |     add_a_record(question, db_id)
36 | 
37 |     if question.strip() == "":
38 |         return "Sorry, your question is empty."
39 | 
40 |     try:
41 |         if baidu_api_token is not None and detect(question) != "en":
42 |             print("Before translation:", question)
43 |             question = translate_zh_to_en(question, baidu_api_token)
44 |             print("After translation:", question)
45 |     except LangDetectException as e:
46 |         print("language detection error:", str(e))
47 | 
48 |     predicted_sql = text2sql_bot.get_response(question, db_id)
49 |     print("predicted sql:", predicted_sql)
50 | 
51 |     response = "Database:<br>" + db_id + "<br><br>"
52 |     response += "Predicted SQL query:<br>
" + predicted_sql 53 | return response 54 | 55 | app.run(host = "0.0.0.0", debug = False) -------------------------------------------------------------------------------- /build_contents_index.py: -------------------------------------------------------------------------------- 1 | from utils.db_utils import get_cursor_from_path, execute_sql_long_time_limitation 2 | import json 3 | import os, shutil 4 | 5 | def remove_contents_of_a_folder(index_path): 6 | # if index_path does not exist, then create it 7 | os.makedirs(index_path, exist_ok = True) 8 | # remove files in index_path 9 | for filename in os.listdir(index_path): 10 | file_path = os.path.join(index_path, filename) 11 | try: 12 | if os.path.isfile(file_path) or os.path.islink(file_path): 13 | os.unlink(file_path) 14 | elif os.path.isdir(file_path): 15 | shutil.rmtree(file_path) 16 | except Exception as e: 17 | print('Failed to delete %s. Reason: %s' % (file_path, e)) 18 | 19 | def build_content_index(db_path, index_path): 20 | ''' 21 | Create a BM25 index for all contents in a database 22 | ''' 23 | cursor = get_cursor_from_path(db_path) 24 | results = execute_sql_long_time_limitation(cursor, "SELECT name FROM sqlite_master WHERE type='table';") 25 | table_names = [result[0] for result in results] 26 | 27 | all_column_contents = [] 28 | for table_name in table_names: 29 | # skip SQLite system table: sqlite_sequence 30 | if table_name == "sqlite_sequence": 31 | continue 32 | results = execute_sql_long_time_limitation(cursor, "SELECT name FROM PRAGMA_TABLE_INFO('{}')".format(table_name)) 33 | column_names_in_one_table = [result[0] for result in results] 34 | for column_name in column_names_in_one_table: 35 | try: 36 | print("SELECT DISTINCT `{}` FROM `{}` WHERE `{}` IS NOT NULL;".format(column_name, table_name, column_name)) 37 | results = execute_sql_long_time_limitation(cursor, "SELECT DISTINCT `{}` FROM `{}` WHERE `{}` IS NOT NULL;".format(column_name, table_name, column_name)) 38 | column_contents = [str(result[0]).strip() for result in results] 39 | 40 | for c_id, column_content in enumerate(column_contents): 41 | # remove empty and extremely-long contents 42 | if len(column_content) != 0 and len(column_content) <= 25: 43 | all_column_contents.append( 44 | { 45 | "id": "{}-**-{}-**-{}".format(table_name, column_name, c_id).lower(), 46 | "contents": column_content 47 | } 48 | ) 49 | except Exception as e: 50 | print(str(e)) 51 | 52 | with open("./data/temp_db_index/contents.json", "w") as f: 53 | f.write(json.dumps(all_column_contents, indent = 2, ensure_ascii = True)) 54 | 55 | # Building a BM25 Index (Direct Java Implementation), see https://github.com/castorini/pyserini/blob/master/docs/usage-index.md 56 | cmd = "python -m pyserini.index.lucene --collection JsonCollection --input ./data/temp_db_index --index {} --generator DefaultLuceneDocumentGenerator --threads 16 --storePositions --storeDocvectors --storeRaw".format(index_path) 57 | 58 | d = os.system(cmd) 59 | print(d) 60 | os.remove("./data/temp_db_index/contents.json") 61 | 62 | if __name__ == "__main__": 63 | os.makedirs('./data/temp_db_index', exist_ok = True) 64 | 65 | print("build content index for databases...") 66 | remove_contents_of_a_folder("db_contents_index") 67 | # build content index for Bank_Financials's training set databases 68 | for db_id in os.listdir("databases"): 69 | print(db_id) 70 | build_content_index( 71 | os.path.join("databases", db_id, db_id + ".sqlite"), 72 | os.path.join("db_contents_index", db_id) 73 | ) 74 | 75 | 
os.rmdir('./data/temp_db_index') -------------------------------------------------------------------------------- /data/history/history.sqlite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RUCKBReasoning/text2sql-demo/99197352ad9c8b453a175bde9200d7d1d29e8d30/data/history/history.sqlite -------------------------------------------------------------------------------- /databases/singer/schema.sql: -------------------------------------------------------------------------------- 1 | PRAGMA foreign_keys = ON; 2 | 3 | CREATE TABLE "singer" ( 4 | "Singer_ID" int, 5 | "Name" text, 6 | "Birth_Year" real, 7 | "Net_Worth_Millions" real, 8 | "Citizenship" text, 9 | PRIMARY KEY ("Singer_ID") 10 | ); 11 | 12 | CREATE TABLE "song" ( 13 | "Song_ID" int, 14 | "Title" text, 15 | "Singer_ID" int, 16 | "Sales" real, 17 | "Highest_Position" real, 18 | PRIMARY KEY ("Song_ID"), 19 | FOREIGN KEY ("Singer_ID") REFERENCES `singer`("Singer_ID") 20 | ); 21 | 22 | INSERT INTO "singer" VALUES (1,"Liliane Bettencourt","1944","30.0","France"); 23 | INSERT INTO "singer" VALUES (2,"Christy Walton","1948","28.8","United States"); 24 | INSERT INTO "singer" VALUES (3,"Alice Walton","1949","26.3","United States"); 25 | INSERT INTO "singer" VALUES (4,"Iris Fontbona","1942","17.4","Chile"); 26 | INSERT INTO "singer" VALUES (5,"Jacqueline Mars","1940","17.8","United States"); 27 | INSERT INTO "singer" VALUES (6,"Gina Rinehart","1953","17","Australia"); 28 | INSERT INTO "singer" VALUES (7,"Susanne Klatten","1962","14.3","Germany"); 29 | INSERT INTO "singer" VALUES (8,"Abigail Johnson","1961","12.7","United States"); 30 | 31 | INSERT INTO "song" VALUES ("1","Do They Know It's Christmas",1,"1094000","1"); 32 | INSERT INTO "song" VALUES ("2","F**k It (I Don't Want You Back)",1,"552407","1"); 33 | INSERT INTO "song" VALUES ("3","Cha Cha Slide",2,"351421","1"); 34 | INSERT INTO "song" VALUES ("4","Call on Me",4,"335000","1"); 35 | INSERT INTO "song" VALUES ("5","Yeah",2,"300000","1"); 36 | INSERT INTO "song" VALUES ("6","All This Time",6,"292000","1"); 37 | INSERT INTO "song" VALUES ("7","Left Outside Alone",5,"275000","3"); 38 | INSERT INTO "song" VALUES ("8","Mysterious Girl",7,"261000","1"); 39 | 40 | -------------------------------------------------------------------------------- /databases/singer/singer.sqlite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RUCKBReasoning/text2sql-demo/99197352ad9c8b453a175bde9200d7d1d29e8d30/databases/singer/singer.sqlite -------------------------------------------------------------------------------- /images/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RUCKBReasoning/text2sql-demo/99197352ad9c8b453a175bde9200d7d1d29e8d30/images/demo.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | flask 2 | websocket-client 3 | python-socketio 4 | eventlet 5 | langdetect 6 | func_timeout==4.3.5 7 | nltk==3.7 8 | numpy==1.23.5 9 | pandas==2.0.1 10 | rapidfuzz==2.0.11 11 | tqdm==4.63.0 12 | transformers==4.32.0 13 | chardet==5.2.0 14 | sqlparse==0.4.4 15 | accelerate==0.22.0 16 | bitsandbytes==0.41.1 17 | pyserini==0.21.0 18 | sql_metadata==2.8.0 19 | datasets==2.11.0 20 | faiss-cpu==1.7.4 
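A quick way to sanity-check the BM25 index written by `build_contents_index.py` (above) is to query it directly with pyserini's `LuceneSearcher`, the same class that `text2sql.py` loads for each database at serving time. A minimal sketch against the bundled `singer` database; the query string and `k` value are arbitrary examples:

```python
from pyserini.search.lucene import LuceneSearcher

# open the per-database index written to db_contents_index/<db_id> by build_contents_index.py
searcher = LuceneSearcher("db_contents_index/singer")

# retrieve the stored cell values that best match a keyword from the question
hits = searcher.search("United States", k=5)
for hit in hits:
    # doc ids follow the "<table>-**-<column>-**-<row_id>" pattern used when indexing
    print(hit.docid, round(hit.score, 3))
```

If the printed ids point at the expected table and column (e.g., `singer-**-citizenship-**-...`), the index was built correctly.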
-------------------------------------------------------------------------------- /schema_item_filter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import torch 4 | 5 | from tqdm import tqdm 6 | from transformers import AutoTokenizer 7 | from utils.classifier_model import SchemaItemClassifier 8 | from transformers.trainer_utils import set_seed 9 | 10 | def prepare_inputs_and_labels(sample, tokenizer): 11 | table_names = [table["table_name"] for table in sample["schema"]["schema_items"]] 12 | column_names = [table["column_names"] for table in sample["schema"]["schema_items"]] 13 | column_num_in_each_table = [len(table["column_names"]) for table in sample["schema"]["schema_items"]] 14 | 15 | # `column_name_word_indices` and `table_name_word_indices` record the word indices of each column and table in `input_words`, whose element is an integer 16 | column_name_word_indices, table_name_word_indices = [], [] 17 | 18 | input_words = [sample["text"]] 19 | for table_id, table_name in enumerate(table_names): 20 | input_words.append("|") 21 | input_words.append(table_name) 22 | table_name_word_indices.append(len(input_words) - 1) 23 | input_words.append(":") 24 | 25 | for column_name in column_names[table_id]: 26 | input_words.append(column_name) 27 | column_name_word_indices.append(len(input_words) - 1) 28 | input_words.append(",") 29 | 30 | # remove the last "," 31 | input_words = input_words[:-1] 32 | 33 | tokenized_inputs = tokenizer( 34 | input_words, 35 | return_tensors="pt", 36 | is_split_into_words = True, 37 | padding = "max_length", 38 | max_length = 512, 39 | truncation = True 40 | ) 41 | 42 | # after tokenizing, one table name or column name may be splitted into multiple tokens (i.e., sub-words) 43 | # `column_name_token_indices` and `table_name_token_indices` records the token indices of each column and table in `input_ids`, whose element is a list of integer 44 | column_name_token_indices, table_name_token_indices = [], [] 45 | word_indices = tokenized_inputs.word_ids(batch_index = 0) 46 | 47 | # obtain token indices of each column in `input_ids` 48 | for column_name_word_index in column_name_word_indices: 49 | column_name_token_indices.append([token_id for token_id, word_index in enumerate(word_indices) if column_name_word_index == word_index]) 50 | 51 | # obtain token indices of each table in `input_ids` 52 | for table_name_word_index in table_name_word_indices: 53 | table_name_token_indices.append([token_id for token_id, word_index in enumerate(word_indices) if table_name_word_index == word_index]) 54 | 55 | encoder_input_ids = tokenized_inputs["input_ids"] 56 | encoder_input_attention_mask = tokenized_inputs["attention_mask"] 57 | 58 | # print("\n".join(tokenizer.batch_decode(encoder_input_ids, skip_special_tokens = True))) 59 | 60 | if torch.cuda.is_available(): 61 | encoder_input_ids = encoder_input_ids.cuda() 62 | encoder_input_attention_mask = encoder_input_attention_mask.cuda() 63 | 64 | return encoder_input_ids, encoder_input_attention_mask, \ 65 | column_name_token_indices, table_name_token_indices, column_num_in_each_table 66 | 67 | def get_schema(tables_and_columns): 68 | schema_items = [] 69 | table_names = list(dict.fromkeys([t for t, c in tables_and_columns])) 70 | for table_name in table_names: 71 | schema_items.append( 72 | { 73 | "table_name": table_name, 74 | "column_names": [c for t, c in tables_and_columns if t == table_name] 75 | } 76 | ) 77 | 78 | return {"schema_items": 
schema_items} 79 | 80 | def get_sequence_length(text, tables_and_columns, tokenizer): 81 | table_names = [t for t, c in tables_and_columns] 82 | # duplicate `table_names` while preserving order 83 | table_names = list(dict.fromkeys(table_names)) 84 | 85 | column_names = [] 86 | for table_name in table_names: 87 | column_names.append([c for t, c in tables_and_columns if t == table_name]) 88 | 89 | input_words = [text] 90 | for table_id, table_name in enumerate(table_names): 91 | input_words.append("|") 92 | input_words.append(table_name) 93 | input_words.append(":") 94 | for column_name in column_names[table_id]: 95 | input_words.append(column_name) 96 | input_words.append(",") 97 | # remove the last "," 98 | input_words = input_words[:-1] 99 | 100 | tokenized_inputs = tokenizer(input_words, is_split_into_words = True) 101 | 102 | return len(tokenized_inputs["input_ids"]) 103 | 104 | # handle extremely long schema sequences 105 | def split_sample(sample, tokenizer): 106 | text = sample["text"] 107 | 108 | table_names = [] 109 | column_names = [] 110 | for table in sample["schema"]["schema_items"]: 111 | table_names.append(table["table_name"] + " ( " + table["table_comment"] + " ) " \ 112 | if table["table_comment"] != "" else table["table_name"]) 113 | column_names.append([column_name + " ( " + column_comment + " ) " \ 114 | if column_comment != "" else column_name \ 115 | for column_name, column_comment in zip(table["column_names"], table["column_comments"])]) 116 | 117 | splitted_samples = [] 118 | recorded_tables_and_columns = [] 119 | 120 | for table_idx, table_name in enumerate(table_names): 121 | for column_name in column_names[table_idx]: 122 | if get_sequence_length(text, recorded_tables_and_columns + [[table_name, column_name]], tokenizer) < 500: 123 | recorded_tables_and_columns.append([table_name, column_name]) 124 | else: 125 | splitted_samples.append( 126 | { 127 | "text": text, 128 | "schema": get_schema(recorded_tables_and_columns) 129 | } 130 | ) 131 | recorded_tables_and_columns = [[table_name, column_name]] 132 | 133 | splitted_samples.append( 134 | { 135 | "text": text, 136 | "schema": get_schema(recorded_tables_and_columns) 137 | } 138 | ) 139 | 140 | return splitted_samples 141 | 142 | def merge_pred_results(sample, pred_results): 143 | # table_names = [table["table_name"] for table in sample["schema"]["schema_items"]] 144 | # column_names = [table["column_names"] for table in sample["schema"]["schema_items"]] 145 | table_names = [] 146 | column_names = [] 147 | for table in sample["schema"]["schema_items"]: 148 | table_names.append(table["table_name"] + " ( " + table["table_comment"] + " ) " \ 149 | if table["table_comment"] != "" else table["table_name"]) 150 | column_names.append([column_name + " ( " + column_comment + " ) " \ 151 | if column_comment != "" else column_name \ 152 | for column_name, column_comment in zip(table["column_names"], table["column_comments"])]) 153 | 154 | merged_results = [] 155 | for table_id, table_name in enumerate(table_names): 156 | table_prob = 0 157 | column_probs = [] 158 | for result_dict in pred_results: 159 | if table_name in result_dict: 160 | if table_prob < result_dict[table_name]["table_prob"]: 161 | table_prob = result_dict[table_name]["table_prob"] 162 | column_probs += result_dict[table_name]["column_probs"] 163 | 164 | merged_results.append( 165 | { 166 | "table_name": table_name, 167 | "table_prob": table_prob, 168 | "column_names": column_names[table_id], 169 | "column_probs": column_probs 170 | } 171 | ) 172 | 173 | 
return merged_results 174 | 175 | def filter_schema(data, sic, num_top_k_tables = 5, num_top_k_columns = 5): 176 | filtered_schema = dict() 177 | filtered_matched_contents = dict() 178 | filtered_schema["schema_items"] = [] 179 | filtered_schema["foreign_keys"] = [] 180 | 181 | table_names = [table["table_name"] for table in data["schema"]["schema_items"]] 182 | table_comments = [table["table_comment"] for table in data["schema"]["schema_items"]] 183 | column_names = [table["column_names"] for table in data["schema"]["schema_items"]] 184 | column_types = [table["column_types"] for table in data["schema"]["schema_items"]] 185 | column_comments = [table["column_comments"] for table in data["schema"]["schema_items"]] 186 | column_contents = [table["column_contents"] for table in data["schema"]["schema_items"]] 187 | pk_indicators = [table["pk_indicators"] for table in data["schema"]["schema_items"]] 188 | 189 | # predict scores for each tables and columns 190 | pred_results = sic.predict(data) 191 | # remain top_k1 tables for each database and top_k2 columns for each remained table 192 | table_probs = [pred_result["table_prob"] for pred_result in pred_results] 193 | table_indices = np.argsort(-np.array(table_probs), kind="stable")[:num_top_k_tables].tolist() 194 | 195 | for table_idx in table_indices: 196 | column_probs = pred_results[table_idx]["column_probs"] 197 | column_indices = np.argsort(-np.array(column_probs), kind="stable")[:num_top_k_columns].tolist() 198 | 199 | filtered_schema["schema_items"].append( 200 | { 201 | "table_name": table_names[table_idx], 202 | "table_comment": table_comments[table_idx], 203 | "column_names": [column_names[table_idx][column_idx] for column_idx in column_indices], 204 | "column_types": [column_types[table_idx][column_idx] for column_idx in column_indices], 205 | "column_comments": [column_comments[table_idx][column_idx] for column_idx in column_indices], 206 | "column_contents": [column_contents[table_idx][column_idx] for column_idx in column_indices], 207 | "pk_indicators": [pk_indicators[table_idx][column_idx] for column_idx in column_indices] 208 | } 209 | ) 210 | 211 | # extract matched contents of remained columns 212 | for column_name in [column_names[table_idx][column_idx] for column_idx in column_indices]: 213 | tc_name = "{}.{}".format(table_names[table_idx], column_name) 214 | if tc_name in data["matched_contents"]: 215 | filtered_matched_contents[tc_name] = data["matched_contents"][tc_name] 216 | 217 | # extract foreign keys among remianed tables 218 | filtered_table_names = [table_names[table_idx] for table_idx in table_indices] 219 | for foreign_key in data["schema"]["foreign_keys"]: 220 | source_table, source_column, target_table, target_column = foreign_key 221 | if source_table in filtered_table_names and target_table in filtered_table_names: 222 | filtered_schema["foreign_keys"].append(foreign_key) 223 | 224 | # replace the old schema with the filtered schema 225 | data["schema"] = filtered_schema 226 | # replace the old matched contents with the filtered matched contents 227 | data["matched_contents"] = filtered_matched_contents 228 | 229 | return data 230 | 231 | def lista_contains_listb(lista, listb): 232 | for b in listb: 233 | if b not in lista: 234 | return 0 235 | 236 | return 1 237 | 238 | class SchemaItemClassifierInference(): 239 | def __init__(self, model_save_path): 240 | set_seed(42) 241 | # load tokenizer 242 | self.tokenizer = AutoTokenizer.from_pretrained(model_save_path, add_prefix_space = True) 243 | # initialize 
model 244 | self.model = SchemaItemClassifier(model_save_path, "test") 245 | # load fine-tuned params 246 | self.model.load_state_dict(torch.load(model_save_path + "/dense_classifier.pt", map_location=torch.device('cpu')), strict=False) 247 | if torch.cuda.is_available(): 248 | self.model = self.model.cuda() 249 | self.model.eval() 250 | 251 | def predict_one(self, sample): 252 | encoder_input_ids, encoder_input_attention_mask, column_name_token_indices,\ 253 | table_name_token_indices, column_num_in_each_table = prepare_inputs_and_labels(sample, self.tokenizer) 254 | 255 | with torch.no_grad(): 256 | model_outputs = self.model( 257 | encoder_input_ids, 258 | encoder_input_attention_mask, 259 | [column_name_token_indices], 260 | [table_name_token_indices], 261 | [column_num_in_each_table] 262 | ) 263 | 264 | table_logits = model_outputs["batch_table_name_cls_logits"][0] 265 | table_pred_probs = torch.nn.functional.softmax(table_logits, dim = 1)[:, 1].cpu().tolist() 266 | 267 | column_logits = model_outputs["batch_column_info_cls_logits"][0] 268 | column_pred_probs = torch.nn.functional.softmax(column_logits, dim = 1)[:, 1].cpu().tolist() 269 | 270 | splitted_column_pred_probs = [] 271 | # split predicted column probs into each table 272 | for table_id, column_num in enumerate(column_num_in_each_table): 273 | splitted_column_pred_probs.append(column_pred_probs[sum(column_num_in_each_table[:table_id]): sum(column_num_in_each_table[:table_id]) + column_num]) 274 | column_pred_probs = splitted_column_pred_probs 275 | 276 | result_dict = dict() 277 | for table_idx, table in enumerate(sample["schema"]["schema_items"]): 278 | result_dict[table["table_name"]] = { 279 | "table_name": table["table_name"], 280 | "table_prob": table_pred_probs[table_idx], 281 | "column_names": table["column_names"], 282 | "column_probs": column_pred_probs[table_idx], 283 | } 284 | 285 | return result_dict 286 | 287 | def predict(self, test_sample): 288 | splitted_samples = split_sample(test_sample, self.tokenizer) 289 | pred_results = [] 290 | for splitted_sample in splitted_samples: 291 | pred_results.append(self.predict_one(splitted_sample)) 292 | 293 | return merge_pred_results(test_sample, pred_results) 294 | 295 | def evaluate_coverage(self, dataset): 296 | max_k = 100 297 | total_num_for_table_coverage, total_num_for_column_coverage = 0, 0 298 | table_coverage_results = [0]*max_k 299 | column_coverage_results = [0]*max_k 300 | 301 | for data in dataset: 302 | indices_of_used_tables = [idx for idx, label in enumerate(data["table_labels"]) if label == 1] 303 | pred_results = sic.predict(data) 304 | # print(pred_results) 305 | table_probs = [res["table_prob"] for res in pred_results] 306 | for k in range(max_k): 307 | indices_of_top_k_tables = np.argsort(-np.array(table_probs), kind="stable")[:k+1].tolist() 308 | if lista_contains_listb(indices_of_top_k_tables, indices_of_used_tables): 309 | table_coverage_results[k] += 1 310 | total_num_for_table_coverage += 1 311 | 312 | for table_idx in range(len(data["table_labels"])): 313 | indices_of_used_columns = [idx for idx, label in enumerate(data["column_labels"][table_idx]) if label == 1] 314 | if len(indices_of_used_columns) == 0: 315 | continue 316 | column_probs = pred_results[table_idx]["column_probs"] 317 | for k in range(max_k): 318 | indices_of_top_k_columns = np.argsort(-np.array(column_probs), kind="stable")[:k+1].tolist() 319 | if lista_contains_listb(indices_of_top_k_columns, indices_of_used_columns): 320 | column_coverage_results[k] += 1 321 | 322 | 
total_num_for_column_coverage += 1 323 | 324 | indices_of_top_10_columns = np.argsort(-np.array(column_probs), kind="stable")[:10].tolist() 325 | if lista_contains_listb(indices_of_top_10_columns, indices_of_used_columns) == 0: 326 | print(pred_results[table_idx]) 327 | print(data["column_labels"][table_idx]) 328 | print(data["question"]) 329 | 330 | print(total_num_for_table_coverage) 331 | print(table_coverage_results) 332 | print(total_num_for_column_coverage) 333 | print(column_coverage_results) 334 | 335 | if __name__ == "__main__": 336 | dataset_name = "bird_with_evidence" 337 | # dataset_name = "bird" 338 | # dataset_name = "spider" 339 | sic = SchemaItemClassifierInference("sic_ckpts/sic_{}".format(dataset_name)) 340 | import json 341 | dataset = json.load(open("./data/sft_eval_{}_text2sql.json".format(dataset_name))) 342 | 343 | sic.evaluate_coverage(dataset) -------------------------------------------------------------------------------- /static/js/jquery-3.7.1.min.js: -------------------------------------------------------------------------------- 1 | /*! jQuery v3.7.1 | (c) OpenJS Foundation and other contributors | jquery.org/license */ 2 | !function(e,t){"use strict";"object"==typeof module&&"object"==typeof module.exports?module.exports=e.document?t(e,!0):function(e){if(!e.document)throw new Error("jQuery requires a window with a document");return t(e)}:t(e)}("undefined"!=typeof window?window:this,function(ie,e){"use strict";var oe=[],r=Object.getPrototypeOf,ae=oe.slice,g=oe.flat?function(e){return oe.flat.call(e)}:function(e){return oe.concat.apply([],e)},s=oe.push,se=oe.indexOf,n={},i=n.toString,ue=n.hasOwnProperty,o=ue.toString,a=o.call(Object),le={},v=function(e){return"function"==typeof e&&"number"!=typeof e.nodeType&&"function"!=typeof e.item},y=function(e){return null!=e&&e===e.window},C=ie.document,u={type:!0,src:!0,nonce:!0,noModule:!0};function m(e,t,n){var r,i,o=(n=n||C).createElement("script");if(o.text=e,t)for(r in u)(i=t[r]||t.getAttribute&&t.getAttribute(r))&&o.setAttribute(r,i);n.head.appendChild(o).parentNode.removeChild(o)}function x(e){return null==e?e+"":"object"==typeof e||"function"==typeof e?n[i.call(e)]||"object":typeof e}var t="3.7.1",l=/HTML$/i,ce=function(e,t){return new ce.fn.init(e,t)};function c(e){var t=!!e&&"length"in e&&e.length,n=x(e);return!v(e)&&!y(e)&&("array"===n||0===t||"number"==typeof t&&0+~]|"+ge+")"+ge+"*"),x=new RegExp(ge+"|>"),j=new RegExp(g),A=new RegExp("^"+t+"$"),D={ID:new RegExp("^#("+t+")"),CLASS:new RegExp("^\\.("+t+")"),TAG:new RegExp("^("+t+"|[*])"),ATTR:new RegExp("^"+p),PSEUDO:new RegExp("^"+g),CHILD:new RegExp("^:(only|first|last|nth|nth-last)-(child|of-type)(?:\\("+ge+"*(even|odd|(([+-]|)(\\d*)n|)"+ge+"*(?:([+-]|)"+ge+"*(\\d+)|))"+ge+"*\\)|)","i"),bool:new RegExp("^(?:"+f+")$","i"),needsContext:new RegExp("^"+ge+"*[>+~]|:(even|odd|eq|gt|lt|nth|first|last)(?:\\("+ge+"*((?:-\\d)?\\d*)"+ge+"*\\)|)(?=[^-]|$)","i")},N=/^(?:input|select|textarea|button)$/i,q=/^h\d$/i,L=/^(?:#([\w-]+)|(\w+)|\.([\w-]+))$/,H=/[+~]/,O=new RegExp("\\\\[\\da-fA-F]{1,6}"+ge+"?|\\\\([^\\r\\n\\f])","g"),P=function(e,t){var n="0x"+e.slice(1)-65536;return t||(n<0?String.fromCharCode(n+65536):String.fromCharCode(n>>10|55296,1023&n|56320))},M=function(){V()},R=J(function(e){return!0===e.disabled&&fe(e,"fieldset")},{dir:"parentNode",next:"legend"});try{k.apply(oe=ae.call(ye.childNodes),ye.childNodes),oe[ye.childNodes.length].nodeType}catch(e){k={apply:function(e,t){me.apply(e,ae.call(t))},call:function(e){me.apply(e,ae.call(arguments,1))}}}function 
I(t,e,n,r){var i,o,a,s,u,l,c,f=e&&e.ownerDocument,p=e?e.nodeType:9;if(n=n||[],"string"!=typeof t||!t||1!==p&&9!==p&&11!==p)return n;if(!r&&(V(e),e=e||T,C)){if(11!==p&&(u=L.exec(t)))if(i=u[1]){if(9===p){if(!(a=e.getElementById(i)))return n;if(a.id===i)return k.call(n,a),n}else if(f&&(a=f.getElementById(i))&&I.contains(e,a)&&a.id===i)return k.call(n,a),n}else{if(u[2])return k.apply(n,e.getElementsByTagName(t)),n;if((i=u[3])&&e.getElementsByClassName)return k.apply(n,e.getElementsByClassName(i)),n}if(!(h[t+" "]||d&&d.test(t))){if(c=t,f=e,1===p&&(x.test(t)||m.test(t))){(f=H.test(t)&&U(e.parentNode)||e)==e&&le.scope||((s=e.getAttribute("id"))?s=ce.escapeSelector(s):e.setAttribute("id",s=S)),o=(l=Y(t)).length;while(o--)l[o]=(s?"#"+s:":scope")+" "+Q(l[o]);c=l.join(",")}try{return k.apply(n,f.querySelectorAll(c)),n}catch(e){h(t,!0)}finally{s===S&&e.removeAttribute("id")}}}return re(t.replace(ve,"$1"),e,n,r)}function W(){var r=[];return function e(t,n){return r.push(t+" ")>b.cacheLength&&delete e[r.shift()],e[t+" "]=n}}function F(e){return e[S]=!0,e}function $(e){var t=T.createElement("fieldset");try{return!!e(t)}catch(e){return!1}finally{t.parentNode&&t.parentNode.removeChild(t),t=null}}function B(t){return function(e){return fe(e,"input")&&e.type===t}}function _(t){return function(e){return(fe(e,"input")||fe(e,"button"))&&e.type===t}}function z(t){return function(e){return"form"in e?e.parentNode&&!1===e.disabled?"label"in e?"label"in e.parentNode?e.parentNode.disabled===t:e.disabled===t:e.isDisabled===t||e.isDisabled!==!t&&R(e)===t:e.disabled===t:"label"in e&&e.disabled===t}}function X(a){return F(function(o){return o=+o,F(function(e,t){var n,r=a([],e.length,o),i=r.length;while(i--)e[n=r[i]]&&(e[n]=!(t[n]=e[n]))})})}function U(e){return e&&"undefined"!=typeof e.getElementsByTagName&&e}function V(e){var t,n=e?e.ownerDocument||e:ye;return n!=T&&9===n.nodeType&&n.documentElement&&(r=(T=n).documentElement,C=!ce.isXMLDoc(T),i=r.matches||r.webkitMatchesSelector||r.msMatchesSelector,r.msMatchesSelector&&ye!=T&&(t=T.defaultView)&&t.top!==t&&t.addEventListener("unload",M),le.getById=$(function(e){return r.appendChild(e).id=ce.expando,!T.getElementsByName||!T.getElementsByName(ce.expando).length}),le.disconnectedMatch=$(function(e){return i.call(e,"*")}),le.scope=$(function(){return T.querySelectorAll(":scope")}),le.cssHas=$(function(){try{return T.querySelector(":has(*,:jqfake)"),!1}catch(e){return!0}}),le.getById?(b.filter.ID=function(e){var t=e.replace(O,P);return function(e){return e.getAttribute("id")===t}},b.find.ID=function(e,t){if("undefined"!=typeof t.getElementById&&C){var n=t.getElementById(e);return n?[n]:[]}}):(b.filter.ID=function(e){var n=e.replace(O,P);return function(e){var t="undefined"!=typeof e.getAttributeNode&&e.getAttributeNode("id");return t&&t.value===n}},b.find.ID=function(e,t){if("undefined"!=typeof t.getElementById&&C){var n,r,i,o=t.getElementById(e);if(o){if((n=o.getAttributeNode("id"))&&n.value===e)return[o];i=t.getElementsByName(e),r=0;while(o=i[r++])if((n=o.getAttributeNode("id"))&&n.value===e)return[o]}return[]}}),b.find.TAG=function(e,t){return"undefined"!=typeof t.getElementsByTagName?t.getElementsByTagName(e):t.querySelectorAll(e)},b.find.CLASS=function(e,t){if("undefined"!=typeof t.getElementsByClassName&&C)return t.getElementsByClassName(e)},d=[],$(function(e){var 
t;r.appendChild(e).innerHTML="",e.querySelectorAll("[selected]").length||d.push("\\["+ge+"*(?:value|"+f+")"),e.querySelectorAll("[id~="+S+"-]").length||d.push("~="),e.querySelectorAll("a#"+S+"+*").length||d.push(".#.+[+~]"),e.querySelectorAll(":checked").length||d.push(":checked"),(t=T.createElement("input")).setAttribute("type","hidden"),e.appendChild(t).setAttribute("name","D"),r.appendChild(e).disabled=!0,2!==e.querySelectorAll(":disabled").length&&d.push(":enabled",":disabled"),(t=T.createElement("input")).setAttribute("name",""),e.appendChild(t),e.querySelectorAll("[name='']").length||d.push("\\["+ge+"*name"+ge+"*="+ge+"*(?:''|\"\")")}),le.cssHas||d.push(":has"),d=d.length&&new RegExp(d.join("|")),l=function(e,t){if(e===t)return a=!0,0;var n=!e.compareDocumentPosition-!t.compareDocumentPosition;return n||(1&(n=(e.ownerDocument||e)==(t.ownerDocument||t)?e.compareDocumentPosition(t):1)||!le.sortDetached&&t.compareDocumentPosition(e)===n?e===T||e.ownerDocument==ye&&I.contains(ye,e)?-1:t===T||t.ownerDocument==ye&&I.contains(ye,t)?1:o?se.call(o,e)-se.call(o,t):0:4&n?-1:1)}),T}for(e in I.matches=function(e,t){return I(e,null,null,t)},I.matchesSelector=function(e,t){if(V(e),C&&!h[t+" "]&&(!d||!d.test(t)))try{var n=i.call(e,t);if(n||le.disconnectedMatch||e.document&&11!==e.document.nodeType)return n}catch(e){h(t,!0)}return 0":{dir:"parentNode",first:!0}," ":{dir:"parentNode"},"+":{dir:"previousSibling",first:!0},"~":{dir:"previousSibling"}},preFilter:{ATTR:function(e){return e[1]=e[1].replace(O,P),e[3]=(e[3]||e[4]||e[5]||"").replace(O,P),"~="===e[2]&&(e[3]=" "+e[3]+" "),e.slice(0,4)},CHILD:function(e){return e[1]=e[1].toLowerCase(),"nth"===e[1].slice(0,3)?(e[3]||I.error(e[0]),e[4]=+(e[4]?e[5]+(e[6]||1):2*("even"===e[3]||"odd"===e[3])),e[5]=+(e[7]+e[8]||"odd"===e[3])):e[3]&&I.error(e[0]),e},PSEUDO:function(e){var t,n=!e[6]&&e[2];return D.CHILD.test(e[0])?null:(e[3]?e[2]=e[4]||e[5]||"":n&&j.test(n)&&(t=Y(n,!0))&&(t=n.indexOf(")",n.length-t)-n.length)&&(e[0]=e[0].slice(0,t),e[2]=n.slice(0,t)),e.slice(0,3))}},filter:{TAG:function(e){var t=e.replace(O,P).toLowerCase();return"*"===e?function(){return!0}:function(e){return fe(e,t)}},CLASS:function(e){var t=s[e+" "];return t||(t=new RegExp("(^|"+ge+")"+e+"("+ge+"|$)"))&&s(e,function(e){return t.test("string"==typeof e.className&&e.className||"undefined"!=typeof e.getAttribute&&e.getAttribute("class")||"")})},ATTR:function(n,r,i){return function(e){var t=I.attr(e,n);return null==t?"!="===r:!r||(t+="","="===r?t===i:"!="===r?t!==i:"^="===r?i&&0===t.indexOf(i):"*="===r?i&&-1:\x20\t\r\n\f]*)[\x20\t\r\n\f]*\/?>(?:<\/\1>|)$/i;function T(e,n,r){return v(n)?ce.grep(e,function(e,t){return!!n.call(e,t,e)!==r}):n.nodeType?ce.grep(e,function(e){return e===n!==r}):"string"!=typeof n?ce.grep(e,function(e){return-1)[^>]*|#([\w-]+))$/;(ce.fn.init=function(e,t,n){var r,i;if(!e)return this;if(n=n||k,"string"==typeof e){if(!(r="<"===e[0]&&">"===e[e.length-1]&&3<=e.length?[null,e,null]:S.exec(e))||!r[1]&&t)return!t||t.jquery?(t||n).find(e):this.constructor(t).find(e);if(r[1]){if(t=t instanceof ce?t[0]:t,ce.merge(this,ce.parseHTML(r[1],t&&t.nodeType?t.ownerDocument||t:C,!0)),w.test(r[1])&&ce.isPlainObject(t))for(r in t)v(this[r])?this[r](t[r]):this.attr(r,t[r]);return this}return(i=C.getElementById(r[2]))&&(this[0]=i,this.length=1),this}return e.nodeType?(this[0]=e,this.length=1,this):v(e)?void 0!==n.ready?n.ready(e):e(ce):ce.makeArray(e,this)}).prototype=ce.fn,k=ce(C);var E=/^(?:parents|prev(?:Until|All))/,j={children:!0,contents:!0,next:!0,prev:!0};function 
A(e,t){while((e=e[t])&&1!==e.nodeType);return e}ce.fn.extend({has:function(e){var t=ce(e,this),n=t.length;return this.filter(function(){for(var e=0;e\x20\t\r\n\f]*)/i,Ce=/^$|^module$|\/(?:java|ecma)script/i;xe=C.createDocumentFragment().appendChild(C.createElement("div")),(be=C.createElement("input")).setAttribute("type","radio"),be.setAttribute("checked","checked"),be.setAttribute("name","t"),xe.appendChild(be),le.checkClone=xe.cloneNode(!0).cloneNode(!0).lastChild.checked,xe.innerHTML="",le.noCloneChecked=!!xe.cloneNode(!0).lastChild.defaultValue,xe.innerHTML="",le.option=!!xe.lastChild;var ke={thead:[1,"","
"],col:[2,"","
"],tr:[2,"","
"],td:[3,"","
"],_default:[0,"",""]};function Se(e,t){var n;return n="undefined"!=typeof e.getElementsByTagName?e.getElementsByTagName(t||"*"):"undefined"!=typeof e.querySelectorAll?e.querySelectorAll(t||"*"):[],void 0===t||t&&fe(e,t)?ce.merge([e],n):n}function Ee(e,t){for(var n=0,r=e.length;n",""]);var je=/<|&#?\w+;/;function Ae(e,t,n,r,i){for(var o,a,s,u,l,c,f=t.createDocumentFragment(),p=[],d=0,h=e.length;d\s*$/g;function Re(e,t){return fe(e,"table")&&fe(11!==t.nodeType?t:t.firstChild,"tr")&&ce(e).children("tbody")[0]||e}function Ie(e){return e.type=(null!==e.getAttribute("type"))+"/"+e.type,e}function We(e){return"true/"===(e.type||"").slice(0,5)?e.type=e.type.slice(5):e.removeAttribute("type"),e}function Fe(e,t){var n,r,i,o,a,s;if(1===t.nodeType){if(_.hasData(e)&&(s=_.get(e).events))for(i in _.remove(t,"handle events"),s)for(n=0,r=s[i].length;n").attr(n.scriptAttrs||{}).prop({charset:n.scriptCharset,src:n.url}).on("load error",i=function(e){r.remove(),i=null,e&&t("error"===e.type?404:200,e.type)}),C.head.appendChild(r[0])},abort:function(){i&&i()}}});var Jt,Kt=[],Zt=/(=)\?(?=&|$)|\?\?/;ce.ajaxSetup({jsonp:"callback",jsonpCallback:function(){var e=Kt.pop()||ce.expando+"_"+jt.guid++;return this[e]=!0,e}}),ce.ajaxPrefilter("json jsonp",function(e,t,n){var r,i,o,a=!1!==e.jsonp&&(Zt.test(e.url)?"url":"string"==typeof e.data&&0===(e.contentType||"").indexOf("application/x-www-form-urlencoded")&&Zt.test(e.data)&&"data");if(a||"jsonp"===e.dataTypes[0])return r=e.jsonpCallback=v(e.jsonpCallback)?e.jsonpCallback():e.jsonpCallback,a?e[a]=e[a].replace(Zt,"$1"+r):!1!==e.jsonp&&(e.url+=(At.test(e.url)?"&":"?")+e.jsonp+"="+r),e.converters["script json"]=function(){return o||ce.error(r+" was not called"),o[0]},e.dataTypes[0]="json",i=ie[r],ie[r]=function(){o=arguments},n.always(function(){void 0===i?ce(ie).removeProp(r):ie[r]=i,e[r]&&(e.jsonpCallback=t.jsonpCallback,Kt.push(r)),o&&v(i)&&i(o[0]),o=i=void 0}),"script"}),le.createHTMLDocument=((Jt=C.implementation.createHTMLDocument("").body).innerHTML="
",2===Jt.childNodes.length),ce.parseHTML=function(e,t,n){return"string"!=typeof e?[]:("boolean"==typeof t&&(n=t,t=!1),t||(le.createHTMLDocument?((r=(t=C.implementation.createHTMLDocument("")).createElement("base")).href=C.location.href,t.head.appendChild(r)):t=C),o=!n&&[],(i=w.exec(e))?[t.createElement(i[1])]:(i=Ae([e],t,o),o&&o.length&&ce(o).remove(),ce.merge([],i.childNodes)));var r,i,o},ce.fn.load=function(e,t,n){var r,i,o,a=this,s=e.indexOf(" ");return-1").append(ce.parseHTML(e)).find(r):e)}).always(n&&function(e,t){a.each(function(){n.apply(this,o||[e.responseText,t,e])})}),this},ce.expr.pseudos.animated=function(t){return ce.grep(ce.timers,function(e){return t===e.elem}).length},ce.offset={setOffset:function(e,t,n){var r,i,o,a,s,u,l=ce.css(e,"position"),c=ce(e),f={};"static"===l&&(e.style.position="relative"),s=c.offset(),o=ce.css(e,"top"),u=ce.css(e,"left"),("absolute"===l||"fixed"===l)&&-1<(o+u).indexOf("auto")?(a=(r=c.position()).top,i=r.left):(a=parseFloat(o)||0,i=parseFloat(u)||0),v(t)&&(t=t.call(e,n,ce.extend({},s))),null!=t.top&&(f.top=t.top-s.top+a),null!=t.left&&(f.left=t.left-s.left+i),"using"in t?t.using.call(e,f):c.css(f)}},ce.fn.extend({offset:function(t){if(arguments.length)return void 0===t?this:this.each(function(e){ce.offset.setOffset(this,t,e)});var e,n,r=this[0];return r?r.getClientRects().length?(e=r.getBoundingClientRect(),n=r.ownerDocument.defaultView,{top:e.top+n.pageYOffset,left:e.left+n.pageXOffset}):{top:0,left:0}:void 0},position:function(){if(this[0]){var e,t,n,r=this[0],i={top:0,left:0};if("fixed"===ce.css(r,"position"))t=r.getBoundingClientRect();else{t=this.offset(),n=r.ownerDocument,e=r.offsetParent||n.documentElement;while(e&&(e===n.body||e===n.documentElement)&&"static"===ce.css(e,"position"))e=e.parentNode;e&&e!==r&&1===e.nodeType&&((i=ce(e).offset()).top+=ce.css(e,"borderTopWidth",!0),i.left+=ce.css(e,"borderLeftWidth",!0))}return{top:t.top-i.top-ce.css(r,"marginTop",!0),left:t.left-i.left-ce.css(r,"marginLeft",!0)}}},offsetParent:function(){return this.map(function(){var e=this.offsetParent;while(e&&"static"===ce.css(e,"position"))e=e.offsetParent;return e||J})}}),ce.each({scrollLeft:"pageXOffset",scrollTop:"pageYOffset"},function(t,i){var o="pageYOffset"===i;ce.fn[t]=function(e){return M(this,function(e,t,n){var r;if(y(e)?r=e:9===e.nodeType&&(r=e.defaultView),void 0===n)return r?r[i]:e[t];r?r.scrollTo(o?r.pageXOffset:n,o?n:r.pageYOffset):e[t]=n},t,e,arguments.length)}}),ce.each(["top","left"],function(e,n){ce.cssHooks[n]=Ye(le.pixelPosition,function(e,t){if(t)return t=Ge(e,n),_e.test(t)?ce(e).position()[n]+"px":t})}),ce.each({Height:"height",Width:"width"},function(a,s){ce.each({padding:"inner"+a,content:s,"":"outer"+a},function(r,o){ce.fn[o]=function(e,t){var n=arguments.length&&(r||"boolean"!=typeof e),i=r||(!0===e||!0===t?"margin":"border");return M(this,function(e,t,n){var r;return y(e)?0===o.indexOf("outer")?e["inner"+a]:e.document.documentElement["client"+a]:9===e.nodeType?(r=e.documentElement,Math.max(e.body["scroll"+a],r["scroll"+a],e.body["offset"+a],r["offset"+a],r["client"+a])):void 0===n?ce.css(e,t,i):ce.style(e,t,n,i)},s,n?e:void 0,n)}})}),ce.each(["ajaxStart","ajaxStop","ajaxComplete","ajaxError","ajaxSuccess","ajaxSend"],function(e,t){ce.fn[t]=function(e){return this.on(t,e)}}),ce.fn.extend({bind:function(e,t,n){return this.on(e,null,t,n)},unbind:function(e,t){return this.off(e,null,t)},delegate:function(e,t,n,r){return this.on(t,e,n,r)},undelegate:function(e,t,n){return 
1===arguments.length?this.off(e,"**"):this.off(t,e||"**",n)},hover:function(e,t){return this.on("mouseenter",e).on("mouseleave",t||e)}}),ce.each("blur focus focusin focusout resize scroll click dblclick mousedown mouseup mousemove mouseover mouseout mouseenter mouseleave change select submit keydown keypress keyup contextmenu".split(" "),function(e,n){ce.fn[n]=function(e,t){return 0 2 | 3 | 4 | 5 | 6 | CodeS demo 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 |
17 |
18 |
19 | Text-to-SQL demo 20 |
21 |
22 | 23 |
24 |
25 |
26 | 27 |
28 |
29 |
CodeS
30 |
31 |
32 | 33 |
34 | Welcome to the CodeS demo!

To begin, select your preferred database from the options available in the box on the right. After making your selection, input your questions in natural language and then CodeS will seamlessly translate your input into a valid SQL query.

For instance, you can input: "Help me find the three papers with the highest citations authored by 'Zhang Jing' from 'Renmin University of China'." 35 |
36 |
37 |
38 | 39 |
40 | 41 |
42 | 43 | 44 |
45 |
46 | 47 |
48 |
49 |
50 | Select a database 51 |
52 |
53 | 54 |
55 |
56 | 57 | 58 | 59 | 150 | 151 | 152 | 153 | -------------------------------------------------------------------------------- /text2sql.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import copy 5 | import re 6 | import sqlparse 7 | import sqlite3 8 | 9 | from tqdm import tqdm 10 | from utils.db_utils import get_db_schema 11 | from transformers import AutoModelForCausalLM, AutoTokenizer 12 | from pyserini.search.lucene import LuceneSearcher 13 | from utils.db_utils import check_sql_executability, get_matched_contents, get_db_schema_sequence, get_matched_content_sequence 14 | from schema_item_filter import SchemaItemClassifierInference, filter_schema 15 | 16 | def remove_similar_comments(names, comments): 17 | ''' 18 | Remove table (or column) comments that have a high degree of similarity with their names 19 | 20 | Arguments: 21 | names: a list of table (or column) names 22 | comments: a list of table (or column) comments 23 | 24 | Returns: 25 | new_comments: a list of new table (or column) comments 26 | ''' 27 | new_comments = [] 28 | for name, comment in zip(names, comments): 29 | if name.replace("_", "").replace(" ", "") == comment.replace("_", "").replace(" ", ""): 30 | new_comments.append("") 31 | else: 32 | new_comments.append(comment) 33 | 34 | return new_comments 35 | 36 | def load_db_comments(table_json_path): 37 | additional_db_info = json.load(open(table_json_path)) 38 | db_comments = dict() 39 | for db_info in additional_db_info: 40 | comment_dict = dict() 41 | 42 | column_names = [column_name.lower() for _, column_name in db_info["column_names_original"]] 43 | table_idx_of_each_column = [t_idx for t_idx, _ in db_info["column_names_original"]] 44 | column_comments = [column_comment.lower() for _, column_comment in db_info["column_names"]] 45 | 46 | assert len(column_names) == len(column_comments) 47 | column_comments = remove_similar_comments(column_names, column_comments) 48 | 49 | table_names = [table_name.lower() for table_name in db_info["table_names_original"]] 50 | table_comments = [table_comment.lower() for table_comment in db_info["table_names"]] 51 | 52 | assert len(table_names) == len(table_comments) 53 | table_comments = remove_similar_comments(table_names, table_comments) 54 | 55 | # enumerate each table and its columns 56 | for table_idx, (table_name, table_comment) in enumerate(zip(table_names, table_comments)): 57 | comment_dict[table_name] = { 58 | "table_comment": table_comment, 59 | "column_comments": dict() 60 | } 61 | for t_idx, column_name, column_comment in zip(table_idx_of_each_column, column_names, column_comments): 62 | # record columns in current table 63 | if t_idx == table_idx: 64 | comment_dict[table_name]["column_comments"][column_name] = column_comment 65 | 66 | db_comments[db_info["db_id"]] = comment_dict 67 | 68 | return db_comments 69 | 70 | def get_db_id2schema(db_path, tables_json): 71 | db_comments = load_db_comments(tables_json) 72 | db_id2schema = dict() 73 | 74 | for db_id in tqdm(os.listdir(db_path)): 75 | db_id2schema[db_id] = get_db_schema(os.path.join(db_path, db_id, db_id + ".sqlite"), db_comments, db_id) 76 | 77 | return db_id2schema 78 | 79 | def get_db_id2ddl(db_path): 80 | db_ids = os.listdir(db_path) 81 | db_id2ddl = dict() 82 | 83 | for db_id in db_ids: 84 | conn = sqlite3.connect(os.path.join(db_path, db_id, db_id + ".sqlite")) 85 | cursor = conn.cursor() 86 | cursor.execute("SELECT name, sql FROM sqlite_master WHERE type='table';") 87 | 
tables = cursor.fetchall() 88 | ddl = [] 89 | 90 | for table in tables: 91 | table_name = table[0] 92 | table_ddl = table[1] 93 | table_ddl.replace("\t", " ") 94 | while " " in table_ddl: 95 | table_ddl = table_ddl.replace(" ", " ") 96 | 97 | # remove comments 98 | table_ddl = re.sub(r'--.*', '', table_ddl) 99 | 100 | table_ddl = sqlparse.format(table_ddl, keyword_case = "upper", identifier_case = "lower", reindent_aligned = True) 101 | table_ddl = table_ddl.replace(", ", ",\n ") 102 | 103 | if table_ddl.endswith(";"): 104 | table_ddl = table_ddl[:-1] 105 | table_ddl = table_ddl[:-1] + "\n);" 106 | table_ddl = re.sub(r"(CREATE TABLE.*?)\(", r"\1(\n ", table_ddl) 107 | 108 | ddl.append(table_ddl) 109 | db_id2ddl[db_id] = "\n\n".join(ddl) 110 | 111 | return db_id2ddl 112 | 113 | class ChatBot(): 114 | def __init__(self) -> None: 115 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 116 | model_name = "seeklhy/codes-7b-merged" 117 | self.tokenizer = AutoTokenizer.from_pretrained(model_name) 118 | self.model = AutoModelForCausalLM.from_pretrained(model_name, device_map = "auto", torch_dtype = torch.float16) 119 | self.max_length = 4096 120 | self.max_new_tokens = 256 121 | self.max_prefix_length = self.max_length - self.max_new_tokens 122 | 123 | self.sic = SchemaItemClassifierInference("sic_ckpts/sic_bird") 124 | 125 | self.db_id2content_searcher = dict() 126 | for db_id in os.listdir("db_contents_index"): 127 | self.db_id2content_searcher[db_id] = LuceneSearcher(os.path.join("db_contents_index", db_id)) 128 | 129 | self.db_ids = sorted(os.listdir("databases")) 130 | self.db_id2schema = get_db_id2schema("databases", "data/tables.json") 131 | self.db_id2ddl = get_db_id2ddl("databases") 132 | 133 | def get_response(self, question, db_id): 134 | data = { 135 | "text": question, 136 | "schema": copy.deepcopy(self.db_id2schema[db_id]), 137 | "matched_contents": get_matched_contents(question, self.db_id2content_searcher[db_id]) 138 | } 139 | data = filter_schema(data, self.sic, 6, 10) 140 | data["schema_sequence"] = get_db_schema_sequence(data["schema"]) 141 | data["content_sequence"] = get_matched_content_sequence(data["matched_contents"]) 142 | 143 | prefix_seq = data["schema_sequence"] + "\n" + data["content_sequence"] + "\n" + data["text"] + "\n" 144 | print(prefix_seq) 145 | 146 | input_ids = [self.tokenizer.bos_token_id] + self.tokenizer(prefix_seq , truncation = False)["input_ids"] 147 | if len(input_ids) > self.max_prefix_length: 148 | print("the current input sequence exceeds the max_tokens, we will truncate it.") 149 | input_ids = [self.tokenizer.bos_token_id] + input_ids[-(self.max_prefix_length-1):] 150 | attention_mask = [1] * len(input_ids) 151 | 152 | inputs = { 153 | "input_ids": torch.tensor([input_ids], dtype = torch.int64).to(self.model.device), 154 | "attention_mask": torch.tensor([attention_mask], dtype = torch.int64).to(self.model.device) 155 | } 156 | input_length = inputs["input_ids"].shape[1] 157 | 158 | with torch.no_grad(): 159 | generate_ids = self.model.generate( 160 | **inputs, 161 | max_new_tokens = self.max_new_tokens, 162 | num_beams = 4, 163 | num_return_sequences = 4 164 | ) 165 | 166 | generated_sqls = self.tokenizer.batch_decode(generate_ids[:, input_length:], skip_special_tokens = True, clean_up_tokenization_spaces = False) 167 | final_generated_sql = None 168 | for generated_sql in generated_sqls: 169 | execution_error = check_sql_executability(generated_sql, os.path.join("databases", db_id, db_id + ".sqlite")) 170 | if execution_error is None: # the generated sql 
has no execution errors, we will return it as the final generated sql 171 | final_generated_sql = generated_sql 172 | break 173 | 174 | if final_generated_sql is None: 175 | if generated_sqls[0].strip() != "": 176 | final_generated_sql = generated_sqls[0].strip() 177 | else: 178 | final_generated_sql = "Sorry, I can not generate a suitable SQL query for your question." 179 | 180 | return final_generated_sql.replace("\n", " ") -------------------------------------------------------------------------------- /utils/bridge_content_encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2020, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | 7 | Encode DB content. 8 | """ 9 | 10 | import difflib 11 | from typing import List, Optional, Tuple 12 | from rapidfuzz import fuzz 13 | import sqlite3 14 | import functools 15 | 16 | # fmt: off 17 | _stopwords = {'who', 'ourselves', 'down', 'only', 'were', 'him', 'at', "weren't", 'has', 'few', "it's", 'm', 'again', 18 | 'd', 'haven', 'been', 'other', 'we', 'an', 'own', 'doing', 'ma', 'hers', 'all', "haven't", 'in', 'but', 19 | "shouldn't", 'does', 'out', 'aren', 'you', "you'd", 'himself', "isn't", 'most', 'y', 'below', 'is', 20 | "wasn't", 'hasn', 'them', 'wouldn', 'against', 'this', 'about', 'there', 'don', "that'll", 'a', 'being', 21 | 'with', 'your', 'theirs', 'its', 'any', 'why', 'now', 'during', 'weren', 'if', 'should', 'those', 'be', 22 | 'they', 'o', 't', 'of', 'or', 'me', 'i', 'some', 'her', 'do', 'will', 'yours', 'for', 'mightn', 'nor', 23 | 'needn', 'the', 'until', "couldn't", 'he', 'which', 'yourself', 'to', "needn't", "you're", 'because', 24 | 'their', 'where', 'it', "didn't", 've', 'whom', "should've", 'can', "shan't", 'on', 'had', 'have', 25 | 'myself', 'am', "don't", 'under', 'was', "won't", 'these', 'so', 'as', 'after', 'above', 'each', 'ours', 26 | 'hadn', 'having', 'wasn', 's', 'doesn', "hadn't", 'than', 'by', 'that', 'both', 'herself', 'his', 27 | "wouldn't", 'into', "doesn't", 'before', 'my', 'won', 'more', 'are', 'through', 'same', 'how', 'what', 28 | 'over', 'll', 'yourselves', 'up', 'mustn', "mustn't", "she's", 're', 'such', 'didn', "you'll", 'shan', 29 | 'when', "you've", 'themselves', "mightn't", 'she', 'from', 'isn', 'ain', 'between', 'once', 'here', 30 | 'shouldn', 'our', 'and', 'not', 'too', 'very', 'further', 'while', 'off', 'couldn', "hasn't", 'itself', 31 | 'then', 'did', 'just', "aren't"} 32 | # fmt: on 33 | 34 | _commonwords = {"no", "yes", "many"} 35 | 36 | 37 | def is_number(s: str) -> bool: 38 | try: 39 | float(s.replace(",", "")) 40 | return True 41 | except: 42 | return False 43 | 44 | 45 | def is_stopword(s: str) -> bool: 46 | return s.strip() in _stopwords 47 | 48 | 49 | def is_commonword(s: str) -> bool: 50 | return s.strip() in _commonwords 51 | 52 | 53 | def is_common_db_term(s: str) -> bool: 54 | return s.strip() in ["id"] 55 | 56 | 57 | class Match(object): 58 | def __init__(self, start: int, size: int) -> None: 59 | self.start = start 60 | self.size = size 61 | 62 | 63 | def is_span_separator(c: str) -> bool: 64 | return c in "'\"()`,.?! 
" 65 | 66 | 67 | def split(s: str) -> List[str]: 68 | return [c.lower() for c in s.strip()] 69 | 70 | 71 | def prefix_match(s1: str, s2: str) -> bool: 72 | i, j = 0, 0 73 | for i in range(len(s1)): 74 | if not is_span_separator(s1[i]): 75 | break 76 | for j in range(len(s2)): 77 | if not is_span_separator(s2[j]): 78 | break 79 | if i < len(s1) and j < len(s2): 80 | return s1[i] == s2[j] 81 | elif i >= len(s1) and j >= len(s2): 82 | return True 83 | else: 84 | return False 85 | 86 | 87 | def get_effective_match_source(s: str, start: int, end: int) -> Match: 88 | _start = -1 89 | 90 | for i in range(start, start - 2, -1): 91 | if i < 0: 92 | _start = i + 1 93 | break 94 | if is_span_separator(s[i]): 95 | _start = i 96 | break 97 | 98 | if _start < 0: 99 | return None 100 | 101 | _end = -1 102 | for i in range(end - 1, end + 3): 103 | if i >= len(s): 104 | _end = i - 1 105 | break 106 | if is_span_separator(s[i]): 107 | _end = i 108 | break 109 | 110 | if _end < 0: 111 | return None 112 | 113 | while _start < len(s) and is_span_separator(s[_start]): 114 | _start += 1 115 | while _end >= 0 and is_span_separator(s[_end]): 116 | _end -= 1 117 | 118 | return Match(_start, _end - _start + 1) 119 | 120 | 121 | def get_matched_entries( 122 | s: str, field_values: List[str], m_theta: float = 0.85, s_theta: float = 0.85 123 | ) -> Optional[List[Tuple[str, Tuple[str, str, float, float, int]]]]: 124 | if not field_values: 125 | return None 126 | 127 | if isinstance(s, str): 128 | n_grams = split(s) 129 | else: 130 | n_grams = s 131 | 132 | matched = dict() 133 | for field_value in field_values: 134 | if not isinstance(field_value, str): 135 | continue 136 | fv_tokens = split(field_value) 137 | sm = difflib.SequenceMatcher(None, n_grams, fv_tokens) 138 | match = sm.find_longest_match(0, len(n_grams), 0, len(fv_tokens)) 139 | if match.size > 0: 140 | source_match = get_effective_match_source( 141 | n_grams, match.a, match.a + match.size 142 | ) 143 | if source_match: # and source_match.size > 1 144 | match_str = field_value[match.b : match.b + match.size] 145 | source_match_str = s[ 146 | source_match.start : source_match.start + source_match.size 147 | ] 148 | c_match_str = match_str.lower().strip() 149 | c_source_match_str = source_match_str.lower().strip() 150 | c_field_value = field_value.lower().strip() 151 | if c_match_str and not is_common_db_term(c_match_str): # and not is_number(c_match_str) 152 | if ( 153 | is_stopword(c_match_str) 154 | or is_stopword(c_source_match_str) 155 | or is_stopword(c_field_value) 156 | ): 157 | continue 158 | if c_source_match_str.endswith(c_match_str + "'s"): 159 | match_score = 1.0 160 | else: 161 | if prefix_match(c_field_value, c_source_match_str): 162 | match_score = fuzz.ratio(c_field_value, c_source_match_str) / 100 163 | else: 164 | match_score = 0 165 | if ( 166 | is_commonword(c_match_str) 167 | or is_commonword(c_source_match_str) 168 | or is_commonword(c_field_value) 169 | ) and match_score < 1: 170 | continue 171 | s_match_score = match_score 172 | if match_score >= m_theta and s_match_score >= s_theta: 173 | if field_value.isupper() and match_score * s_match_score < 1: 174 | continue 175 | matched[match_str] = ( 176 | field_value, 177 | source_match_str, 178 | match_score, 179 | s_match_score, 180 | match.size, 181 | ) 182 | 183 | if not matched: 184 | return None 185 | else: 186 | return sorted( 187 | matched.items(), 188 | key=lambda x: (1e16 * x[1][2] + 1e8 * x[1][3] + x[1][4]), 189 | reverse=True, 190 | ) 191 | 192 | 193 | 
@functools.lru_cache(maxsize=1000, typed=False) 194 | def get_column_picklist(table_name: str, column_name: str, db_path: str) -> list: 195 | fetch_sql = "SELECT DISTINCT `{}` FROM `{}`".format(column_name, table_name) 196 | try: 197 | conn = sqlite3.connect(db_path) 198 | conn.text_factory = bytes 199 | c = conn.cursor() 200 | c.execute(fetch_sql) 201 | picklist = set() 202 | for x in c.fetchall(): 203 | if isinstance(x[0], str): 204 | picklist.add(x[0].encode("utf-8")) 205 | elif isinstance(x[0], bytes): 206 | try: 207 | picklist.add(x[0].decode("utf-8")) 208 | except UnicodeDecodeError: 209 | picklist.add(x[0].decode("latin-1")) 210 | else: 211 | picklist.add(x[0]) 212 | picklist = list(picklist) 213 | except Exception as e: 214 | picklist = [] 215 | finally: 216 | conn.close() 217 | return picklist 218 | 219 | 220 | def get_database_matches( 221 | question: str, 222 | table_name: str, 223 | column_name: str, 224 | db_path: str, 225 | top_k_matches: int = 2, 226 | match_threshold: float = 0.85, 227 | ) -> List[str]: 228 | picklist = get_column_picklist( 229 | table_name=table_name, column_name=column_name, db_path=db_path 230 | ) 231 | # only maintain data in ``str'' type 232 | picklist = [ele.strip() for ele in picklist if isinstance(ele, str)] 233 | # picklist is unordered, we sort it to ensure the reproduction stability 234 | picklist = sorted(picklist) 235 | 236 | matches = [] 237 | if picklist and isinstance(picklist[0], str): 238 | matched_entries = get_matched_entries( 239 | s=question, 240 | field_values=picklist, 241 | m_theta=match_threshold, 242 | s_theta=match_threshold, 243 | ) 244 | 245 | if matched_entries: 246 | num_values_inserted = 0 247 | for _match_str, ( 248 | field_value, 249 | _s_match_str, 250 | match_score, 251 | s_match_score, 252 | _match_size, 253 | ) in matched_entries: 254 | if "name" in column_name and match_score * s_match_score < 1: 255 | continue 256 | if table_name != "sqlite_sequence": # Spider database artifact 257 | matches.append(field_value.strip()) 258 | num_values_inserted += 1 259 | if num_values_inserted >= top_k_matches: 260 | break 261 | return matches -------------------------------------------------------------------------------- /utils/classifier_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import AutoConfig, RobertaModel 5 | 6 | class SchemaItemClassifier(nn.Module): 7 | def __init__(self, model_name_or_path, mode): 8 | super(SchemaItemClassifier, self).__init__() 9 | if mode in ["eval", "test"]: 10 | # load config 11 | config = AutoConfig.from_pretrained(model_name_or_path) 12 | # randomly initialize model's parameters according to the config 13 | self.plm_encoder = RobertaModel(config) 14 | elif mode == "train": 15 | self.plm_encoder = RobertaModel.from_pretrained(model_name_or_path) 16 | else: 17 | raise ValueError() 18 | 19 | self.plm_hidden_size = self.plm_encoder.config.hidden_size 20 | 21 | # column cls head 22 | self.column_info_cls_head_linear1 = nn.Linear(self.plm_hidden_size, 256) 23 | self.column_info_cls_head_linear2 = nn.Linear(256, 2) 24 | 25 | # column bi-lstm layer 26 | self.column_info_bilstm = nn.LSTM( 27 | input_size = self.plm_hidden_size, 28 | hidden_size = int(self.plm_hidden_size/2), 29 | num_layers = 2, 30 | dropout = 0, 31 | bidirectional = True 32 | ) 33 | 34 | # linear layer after column bi-lstm layer 35 | self.column_info_linear_after_pooling = nn.Linear(self.plm_hidden_size, 
self.plm_hidden_size) 36 | 37 | # table cls head 38 | self.table_name_cls_head_linear1 = nn.Linear(self.plm_hidden_size, 256) 39 | self.table_name_cls_head_linear2 = nn.Linear(256, 2) 40 | 41 | # table bi-lstm pooling layer 42 | self.table_name_bilstm = nn.LSTM( 43 | input_size = self.plm_hidden_size, 44 | hidden_size = int(self.plm_hidden_size/2), 45 | num_layers = 2, 46 | dropout = 0, 47 | bidirectional = True 48 | ) 49 | # linear layer after table bi-lstm layer 50 | self.table_name_linear_after_pooling = nn.Linear(self.plm_hidden_size, self.plm_hidden_size) 51 | 52 | # activation function 53 | self.leakyrelu = nn.LeakyReLU() 54 | self.tanh = nn.Tanh() 55 | 56 | # table-column cross-attention layer 57 | self.table_column_cross_attention_layer = nn.MultiheadAttention(embed_dim = self.plm_hidden_size, num_heads = 8) 58 | 59 | # dropout function, p=0.2 means randomly set 20% neurons to 0 60 | self.dropout = nn.Dropout(p = 0.2) 61 | 62 | def table_column_cross_attention( 63 | self, 64 | table_name_embeddings_in_one_db, 65 | column_info_embeddings_in_one_db, 66 | column_number_in_each_table 67 | ): 68 | table_num = table_name_embeddings_in_one_db.shape[0] 69 | table_name_embedding_attn_list = [] 70 | for table_id in range(table_num): 71 | table_name_embedding = table_name_embeddings_in_one_db[[table_id], :] 72 | column_info_embeddings_in_one_table = column_info_embeddings_in_one_db[ 73 | sum(column_number_in_each_table[:table_id]) : sum(column_number_in_each_table[:table_id+1]), :] 74 | 75 | table_name_embedding_attn, _ = self.table_column_cross_attention_layer( 76 | table_name_embedding, 77 | column_info_embeddings_in_one_table, 78 | column_info_embeddings_in_one_table 79 | ) 80 | 81 | table_name_embedding_attn_list.append(table_name_embedding_attn) 82 | 83 | # residual connection 84 | table_name_embeddings_in_one_db = table_name_embeddings_in_one_db + torch.cat(table_name_embedding_attn_list, dim = 0) 85 | # row-wise L2 norm 86 | table_name_embeddings_in_one_db = torch.nn.functional.normalize(table_name_embeddings_in_one_db, p=2.0, dim=1) 87 | 88 | return table_name_embeddings_in_one_db 89 | 90 | def table_column_cls( 91 | self, 92 | encoder_input_ids, 93 | encoder_input_attention_mask, 94 | batch_aligned_column_info_ids, 95 | batch_aligned_table_name_ids, 96 | batch_column_number_in_each_table 97 | ): 98 | batch_size = encoder_input_ids.shape[0] 99 | 100 | encoder_output = self.plm_encoder( 101 | input_ids = encoder_input_ids, 102 | attention_mask = encoder_input_attention_mask, 103 | return_dict = True 104 | ) # encoder_output["last_hidden_state"].shape = (batch_size x seq_length x hidden_size) 105 | 106 | batch_table_name_cls_logits, batch_column_info_cls_logits = [], [] 107 | 108 | # handle each data in current batch 109 | for batch_id in range(batch_size): 110 | column_number_in_each_table = batch_column_number_in_each_table[batch_id] 111 | sequence_embeddings = encoder_output["last_hidden_state"][batch_id, :, :] # (seq_length x hidden_size) 112 | 113 | # obtain table ids for each table 114 | aligned_table_name_ids = batch_aligned_table_name_ids[batch_id] 115 | # obtain column ids for each column 116 | aligned_column_info_ids = batch_aligned_column_info_ids[batch_id] 117 | 118 | table_name_embedding_list, column_info_embedding_list = [], [] 119 | 120 | # obtain table embedding via bi-lstm pooling + a non-linear layer 121 | for table_name_ids in aligned_table_name_ids: 122 | table_name_embeddings = sequence_embeddings[table_name_ids, :] 123 | 124 | # BiLSTM pooling 125 | output_t, 
(hidden_state_t, cell_state_t) = self.table_name_bilstm(table_name_embeddings) 126 | table_name_embedding = hidden_state_t[-2:, :].view(1, self.plm_hidden_size) 127 | table_name_embedding_list.append(table_name_embedding) 128 | table_name_embeddings_in_one_db = torch.cat(table_name_embedding_list, dim = 0) 129 | # non-linear mlp layer 130 | table_name_embeddings_in_one_db = self.leakyrelu(self.table_name_linear_after_pooling(table_name_embeddings_in_one_db)) 131 | 132 | # obtain column embedding via bi-lstm pooling + a non-linear layer 133 | for column_info_ids in aligned_column_info_ids: 134 | column_info_embeddings = sequence_embeddings[column_info_ids, :] 135 | 136 | # BiLSTM pooling 137 | output_c, (hidden_state_c, cell_state_c) = self.column_info_bilstm(column_info_embeddings) 138 | column_info_embedding = hidden_state_c[-2:, :].view(1, self.plm_hidden_size) 139 | column_info_embedding_list.append(column_info_embedding) 140 | column_info_embeddings_in_one_db = torch.cat(column_info_embedding_list, dim = 0) 141 | # non-linear mlp layer 142 | column_info_embeddings_in_one_db = self.leakyrelu(self.column_info_linear_after_pooling(column_info_embeddings_in_one_db)) 143 | 144 | # table-column (tc) cross-attention 145 | table_name_embeddings_in_one_db = self.table_column_cross_attention( 146 | table_name_embeddings_in_one_db, 147 | column_info_embeddings_in_one_db, 148 | column_number_in_each_table 149 | ) 150 | 151 | # calculate table 0-1 logits 152 | table_name_embeddings_in_one_db = self.table_name_cls_head_linear1(table_name_embeddings_in_one_db) 153 | table_name_embeddings_in_one_db = self.dropout(self.leakyrelu(table_name_embeddings_in_one_db)) 154 | table_name_cls_logits = self.table_name_cls_head_linear2(table_name_embeddings_in_one_db) 155 | 156 | # calculate column 0-1 logits 157 | column_info_embeddings_in_one_db = self.column_info_cls_head_linear1(column_info_embeddings_in_one_db) 158 | column_info_embeddings_in_one_db = self.dropout(self.leakyrelu(column_info_embeddings_in_one_db)) 159 | column_info_cls_logits = self.column_info_cls_head_linear2(column_info_embeddings_in_one_db) 160 | 161 | batch_table_name_cls_logits.append(table_name_cls_logits) 162 | batch_column_info_cls_logits.append(column_info_cls_logits) 163 | 164 | return batch_table_name_cls_logits, batch_column_info_cls_logits 165 | 166 | def forward( 167 | self, 168 | encoder_input_ids, 169 | encoder_attention_mask, 170 | batch_aligned_column_info_ids, 171 | batch_aligned_table_name_ids, 172 | batch_column_number_in_each_table, 173 | ): 174 | batch_table_name_cls_logits, batch_column_info_cls_logits \ 175 | = self.table_column_cls( 176 | encoder_input_ids, 177 | encoder_attention_mask, 178 | batch_aligned_column_info_ids, 179 | batch_aligned_table_name_ids, 180 | batch_column_number_in_each_table 181 | ) 182 | 183 | return { 184 | "batch_table_name_cls_logits" : batch_table_name_cls_logits, 185 | "batch_column_info_cls_logits": batch_column_info_cls_logits 186 | } -------------------------------------------------------------------------------- /utils/db_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import sqlite3 4 | 5 | from func_timeout import func_set_timeout, FunctionTimedOut 6 | from utils.bridge_content_encoder import get_matched_entries 7 | from nltk.tokenize import word_tokenize 8 | from nltk import ngrams 9 | 10 | def add_a_record(question, db_id): 11 | conn = sqlite3.connect('data/history/history.sqlite') 12 | cursor = 
conn.cursor() 13 | cursor.execute("INSERT INTO record (question, db_id) VALUES (?, ?)", (question, db_id)) 14 | 15 | conn.commit() 16 | conn.close() 17 | 18 | def obtain_n_grams(sequence, max_n): 19 | tokens = word_tokenize(sequence) 20 | all_grams = [] 21 | for n in range(1, max_n + 1): 22 | all_grams.extend([" ".join(gram) for gram in ngrams(tokens, n)]) 23 | 24 | return all_grams 25 | 26 | # get the database cursor for a sqlite database path 27 | def get_cursor_from_path(sqlite_path): 28 | try: 29 | if not os.path.exists(sqlite_path): 30 | print("Openning a new connection %s" % sqlite_path) 31 | connection = sqlite3.connect(sqlite_path, check_same_thread = False) 32 | except Exception as e: 33 | print(sqlite_path) 34 | raise e 35 | connection.text_factory = lambda b: b.decode(errors="ignore") 36 | cursor = connection.cursor() 37 | return cursor 38 | 39 | # execute predicted sql with a time limitation 40 | @func_set_timeout(15) 41 | def execute_sql(cursor, sql): 42 | cursor.execute(sql) 43 | 44 | return cursor.fetchall() 45 | 46 | # execute predicted sql with a long time limitation (for buiding content index) 47 | @func_set_timeout(2000) 48 | def execute_sql_long_time_limitation(cursor, sql): 49 | cursor.execute(sql) 50 | 51 | return cursor.fetchall() 52 | 53 | def check_sql_executability(generated_sql, db): 54 | if generated_sql.strip() == "": 55 | return "Error: empty string" 56 | try: 57 | cursor = get_cursor_from_path(db) 58 | execute_sql(cursor, generated_sql) 59 | execution_error = None 60 | except FunctionTimedOut as fto: 61 | print("SQL execution time out error: {}.".format(fto)) 62 | execution_error = "SQL execution times out." 63 | except Exception as e: 64 | print("SQL execution runtime error: {}.".format(e)) 65 | execution_error = str(e) 66 | 67 | return execution_error 68 | 69 | def is_number(s): 70 | try: 71 | float(s) 72 | return True 73 | except ValueError: 74 | return False 75 | 76 | def detect_special_char(name): 77 | for special_char in ['(', '-', ')', ' ', '/']: 78 | if special_char in name: 79 | return True 80 | 81 | return False 82 | 83 | def add_quotation_mark(s): 84 | return "`" + s + "`" 85 | 86 | def get_column_contents(column_name, table_name, cursor): 87 | select_column_sql = "SELECT DISTINCT `{}` FROM `{}` WHERE `{}` IS NOT NULL LIMIT 2;".format(column_name, table_name, column_name) 88 | results = execute_sql_long_time_limitation(cursor, select_column_sql) 89 | column_contents = [str(result[0]).strip() for result in results] 90 | # remove empty and extremely-long contents 91 | column_contents = [content for content in column_contents if len(content) != 0 and len(content) <= 25] 92 | 93 | return column_contents 94 | 95 | def get_matched_contents(question, searcher): 96 | # coarse-grained matching between the input text and all contents in database 97 | grams = obtain_n_grams(question, 4) 98 | hits = [] 99 | for query in grams: 100 | hits.extend(searcher.search(query, k = 10)) 101 | 102 | coarse_matched_contents = dict() 103 | for i in range(len(hits)): 104 | matched_result = json.loads(hits[i].raw) 105 | # `tc_name` refers to column names like `table_name.column_name`, e.g., document_drafts.document_id 106 | tc_name = ".".join(matched_result["id"].split("-**-")[:2]) 107 | if tc_name in coarse_matched_contents.keys(): 108 | if matched_result["contents"] not in coarse_matched_contents[tc_name]: 109 | coarse_matched_contents[tc_name].append(matched_result["contents"]) 110 | else: 111 | coarse_matched_contents[tc_name] = [matched_result["contents"]] 112 | 113 | 
fine_matched_contents = dict() 114 | for tc_name, contents in coarse_matched_contents.items(): 115 | # fine-grained matching between the question and coarse matched contents 116 | fm_contents = get_matched_entries(question, contents) 117 | 118 | if fm_contents is None: 119 | continue 120 | for _match_str, (field_value, _s_match_str, match_score, s_match_score, _match_size,) in fm_contents: 121 | if match_score < 0.9: 122 | continue 123 | if tc_name in fine_matched_contents.keys(): 124 | if len(fine_matched_contents[tc_name]) < 25: 125 | fine_matched_contents[tc_name].append(field_value.strip()) 126 | else: 127 | fine_matched_contents[tc_name] = [field_value.strip()] 128 | 129 | return fine_matched_contents 130 | 131 | def get_db_schema_sequence(schema): 132 | schema_sequence = "database schema :\n" 133 | for table in schema["schema_items"]: 134 | table_name, table_comment = table["table_name"], table["table_comment"] 135 | if detect_special_char(table_name): 136 | table_name = add_quotation_mark(table_name) 137 | 138 | # if table_comment != "": 139 | # table_name += " ( comment : " + table_comment + " )" 140 | 141 | column_info_list = [] 142 | for column_name, column_type, column_comment, column_content, pk_indicator in \ 143 | zip(table["column_names"], table["column_types"], table["column_comments"], table["column_contents"], table["pk_indicators"]): 144 | if detect_special_char(column_name): 145 | column_name = add_quotation_mark(column_name) 146 | additional_column_info = [] 147 | # column type 148 | additional_column_info.append(column_type) 149 | # pk indicator 150 | if pk_indicator != 0: 151 | additional_column_info.append("primary key") 152 | # column comment 153 | if column_comment != "": 154 | additional_column_info.append("comment : " + column_comment) 155 | # representive column values 156 | if len(column_content) != 0: 157 | additional_column_info.append("values : " + " , ".join(column_content)) 158 | 159 | column_info_list.append(table_name + "." + column_name + " ( " + " | ".join(additional_column_info) + " )") 160 | 161 | schema_sequence += "table "+ table_name + " , columns = [ " + " , ".join(column_info_list) + " ]\n" 162 | 163 | if len(schema["foreign_keys"]) != 0: 164 | schema_sequence += "foreign keys :\n" 165 | for foreign_key in schema["foreign_keys"]: 166 | for i in range(len(foreign_key)): 167 | if detect_special_char(foreign_key[i]): 168 | foreign_key[i] = add_quotation_mark(foreign_key[i]) 169 | schema_sequence += "{}.{} = {}.{}\n".format(foreign_key[0], foreign_key[1], foreign_key[2], foreign_key[3]) 170 | else: 171 | schema_sequence += "foreign keys : None\n" 172 | 173 | return schema_sequence.strip() 174 | 175 | def get_matched_content_sequence(matched_contents): 176 | content_sequence = "" 177 | if len(matched_contents) != 0: 178 | content_sequence += "matched contents :\n" 179 | for tc_name, contents in matched_contents.items(): 180 | table_name = tc_name.split(".")[0] 181 | column_name = tc_name.split(".")[1] 182 | if detect_special_char(table_name): 183 | table_name = add_quotation_mark(table_name) 184 | if detect_special_char(column_name): 185 | column_name = add_quotation_mark(column_name) 186 | 187 | content_sequence += table_name + "." 
+ column_name + " ( " + " , ".join(contents) + " )\n" 188 | else: 189 | content_sequence = "matched contents : None" 190 | 191 | return content_sequence.strip() 192 | 193 | def get_db_schema(db_path, db_comments, db_id): 194 | if db_id in db_comments: 195 | db_comment = db_comments[db_id] 196 | else: 197 | db_comment = None 198 | 199 | cursor = get_cursor_from_path(db_path) 200 | 201 | # obtain table names 202 | results = execute_sql(cursor, "SELECT name FROM sqlite_master WHERE type='table';") 203 | table_names = [result[0].lower() for result in results] 204 | 205 | schema = dict() 206 | schema["schema_items"] = [] 207 | foreign_keys = [] 208 | # for each table 209 | for table_name in table_names: 210 | # skip SQLite system table: sqlite_sequence 211 | if table_name == "sqlite_sequence": 212 | continue 213 | # obtain column names in the current table 214 | results = execute_sql(cursor, "SELECT name, type, pk FROM PRAGMA_TABLE_INFO('{}')".format(table_name)) 215 | column_names_in_one_table = [result[0].lower() for result in results] 216 | column_types_in_one_table = [result[1].lower() for result in results] 217 | pk_indicators_in_one_table = [result[2] for result in results] 218 | 219 | column_contents = [] 220 | for column_name in column_names_in_one_table: 221 | column_contents.append(get_column_contents(column_name, table_name, cursor)) 222 | 223 | # obtain foreign keys in the current table 224 | results = execute_sql(cursor, "SELECT * FROM pragma_foreign_key_list('{}');".format(table_name)) 225 | for result in results: 226 | if None not in [result[3], result[2], result[4]]: 227 | foreign_keys.append([table_name.lower(), result[3].lower(), result[2].lower(), result[4].lower()]) 228 | 229 | # obtain comments for each schema item 230 | if db_comment is not None: 231 | if table_name in db_comment: # record comments for tables and columns 232 | table_comment = db_comment[table_name]["table_comment"] 233 | column_comments = [db_comment[table_name]["column_comments"][column_name] \ 234 | if column_name in db_comment[table_name]["column_comments"] else "" \ 235 | for column_name in column_names_in_one_table] 236 | else: # current database has comment information, but the current table does not 237 | table_comment = "" 238 | column_comments = ["" for _ in column_names_in_one_table] 239 | else: # current database has no comment information 240 | table_comment = "" 241 | column_comments = ["" for _ in column_names_in_one_table] 242 | 243 | schema["schema_items"].append({ 244 | "table_name": table_name, 245 | "table_comment": table_comment, 246 | "column_names": column_names_in_one_table, 247 | "column_types": column_types_in_one_table, 248 | "column_comments": column_comments, 249 | "column_contents": column_contents, 250 | "pk_indicators": pk_indicators_in_one_table 251 | }) 252 | 253 | schema["foreign_keys"] = foreign_keys 254 | 255 | return schema 256 | -------------------------------------------------------------------------------- /utils/translate_utils.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import random 3 | import json 4 | 5 | def translate_zh_to_en(question, token): 6 | url = 'https://aip.baidubce.com/rpc/2.0/mt/texttrans/v1?access_token=' + token 7 | 8 | from_lang = 'auto' 9 | to_lang = 'en' 10 | term_ids = '' 11 | 12 | # Build request 13 | headers = {'Content-Type': 'application/json'} 14 | payload = {'q': question, 'from': from_lang, 'to': to_lang, 'termIds' : term_ids} 15 | 16 | # Send request 17 | r = 
requests.post(url, params=payload, headers=headers) 18 | result = r.json() 19 | 20 | return result["result"]["trans_result"][0]["dst"] 21 | 22 | if __name__ == "__main__": 23 | print(translate_zh_to_en("你好啊!", "your_baidu_access_token")) # pass a valid Baidu Translate access token; the string here is only a placeholder --------------------------------------------------------------------------------
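
The utilities above can also be driven outside the web app (`app.py`). Below is a minimal sketch of that flow, assuming the `ChatBot` class shown earlier is importable from `text2sql.py`, that the classifier checkpoints, the `databases` folder, and the `db_contents_index` BM25 index have already been prepared, and that the Baidu token and the sample question are placeholders for illustration:

```python
from text2sql import ChatBot                        # module name assumed from the repository tree
from utils.translate_utils import translate_zh_to_en
from utils.db_utils import check_sql_executability

bot = ChatBot()  # loads the CodeS model, the schema-item classifier, and per-database BM25 searchers

# Translate a non-English question first; the token is a placeholder, not a real credential.
question_en = translate_zh_to_en("列出所有歌手的名字", "your_baidu_access_token")  # "List the names of all singers"

# Generate SQL against the bundled `singer` database.
sql = bot.get_response(question_en, "singer")
print(sql)

# The same executability check used inside get_response can be reapplied to the result.
error = check_sql_executability(sql, "databases/singer/singer.sqlite")
print("SQL is executable" if error is None else "execution error: " + error)
```

Because `get_response` already retries across four beam candidates and keeps the first one that executes, the final check above mainly guards the fallback case in which none of the candidates ran successfully.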