├── README.md ├── eval_mode.py ├── schema_filter.py ├── training_mode.py └── utils └── classifier_model.py /README.md: -------------------------------------------------------------------------------- 1 | # Schema Filter for Text-to-SQL 2 | 3 | Introducing our Schema Filter, a bilingual (English and Chinese) model with 3.5 billion parameters designed to enhance text-to-SQL projects. This tool identifies the database schema items (tables and columns) most relevant to a natural language question. 4 | 5 | ## Why Use a Schema Filter? 6 | 7 | 1. **Database Complexity**: Ideal for databases with a vast array of tables and columns. Integrating a schema filter with your large language model (LLM) can reduce the context length occupied by the database schema. 8 | 2. **Performance Enhancement**: Aids your text-to-SQL model by filtering out irrelevant schema items, reducing the schema-linking burden on the LLM. 9 | 10 | ## Architecture Overview 11 | 12 | The schema filter's architecture follows the cross-encoder design from [RESDSQL](https://arxiv.org/abs/2302.05965), with enhancements from our work in [CodeS](https://arxiv.org/abs/2402.16347) for ease of use. The original design was based on RoBERTa-Large; we have upgraded to XLM-RoBERTa-XL for its 3.5 billion parameters and bilingual support, making it better suited to schema-linking challenges. 13 | 14 | ## Training Data 15 | 16 | We fine-tuned this schema filter on the training sets of Spider, BIRD, and CSpider, making it robust and versatile enough to handle diverse queries. 17 | 18 | ## Hardware Requirements 19 | 20 | With 3.5 billion parameters, the schema filter requires at least 15GB of memory (GPU or CPU) for inference. 21 | 22 | ## Getting Started 23 | 24 | 1. **Clone Project**: Clone or download this project. 25 | 2. **Model Download**: Acquire our fine-tuned model `sic_merged.zip` from [quark netdisk](https://pan.quark.cn/s/418c417127ae) or [google drive](https://drive.google.com/file/d/1xzNvv5h-ZjhjOOZ-ePv1xg_n3YbUNLWi/view?usp=sharing) and then unzip it in the root of this project. 26 | 3. **Usage Examples**: See `eval_mode.py` for running the model without SQL input; the fine-tuned model predicts relevance scores for tables and columns from the question alone. See `training_mode.py` for simulating the schema filtering process when ground-truth SQL is available. 27 | 28 | To integrate the schema filter into your text-to-SQL system, your data needs to be organized as follows (see the sketch after this list for one way to build this structure automatically): 29 | 30 | - `text`: The natural language question. 31 | - `sql`: The corresponding SQL query; this can be left as an empty string in evaluation mode. 32 | - `schema`: The structure detailing the database schema. 33 | - `schema.schema_items.table_name`: The name of the table in the database. 34 | - `schema.schema_items.table_comment`: A descriptive comment for the table, which is necessary if the table name is not self-explanatory. Otherwise, this can be left as an empty string. 35 | - `schema.schema_items.column_names`: The names of the columns in the table. 36 | - `schema.schema_items.column_comments`: A descriptive comment for each column, required only if the column name could be confusing. Otherwise, this can also be left as an empty string.
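If your schema lives in a relational database, this structure does not have to be written by hand. The sketch below shows one possible way to build it automatically; it assumes a SQLite database, uses a helper name (`build_schema_from_sqlite`) of our own choosing, and leaves every comment empty, so fill in `table_comment` / `column_comments` wherever the raw names are not self-explanatory.

```python
import sqlite3

def build_schema_from_sqlite(db_path: str) -> dict:
    # Collect every user table and its columns into the `schema` structure
    # expected by filter_func; comments are left empty in this sketch.
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table' AND name NOT LIKE 'sqlite_%'")
    schema_items = []
    for (table_name,) in cursor.fetchall():
        cursor.execute(f"PRAGMA table_info('{table_name}')")
        column_names = [row[1] for row in cursor.fetchall()]  # row[1] is the column name
        schema_items.append({
            "table_name": table_name,
            "table_comment": "",
            "column_names": column_names,
            "column_comments": [""] * len(column_names),
        })
    conn.close()
    return {"schema_items": schema_items}

# e.g., data = {"text": question, "sql": "", "schema": build_schema_from_sqlite("movies.db")}
```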
37 | 38 | Here is an example of how your data should be formatted: 39 | 40 | ```json 41 | { 42 | "text": "List the names of directors whose films have received reviews from Sarah Martinez.", 43 | "sql": "SELECT DISTINCT movie.director FROM rating JOIN movie ON rating.mid = movie.mid JOIN reviewer ON rating.rid = reviewer.rid WHERE reviewer.name = 'Sarah Martinez'", 44 | "schema": { 45 | "schema_items": [ 46 | { 47 | "table_name": "movie", 48 | "table_comment": "", 49 | "column_names": ["mid", "title", "year", "director"], 50 | "column_comments": ["movie id", "", "", ""] 51 | }, 52 | ... 53 | ] 54 | } 55 | } 56 | ``` 57 | 58 | In the provided example scripts, you can adjust the `num_top_k_tables` and `num_top_k_columns` parameters. These define the maximum number of tables per database and columns per retained table, respectively, that will be kept after schema filtering. 59 | 60 | ## Citation 61 | If this project assists you, please cite the following papers: 62 | ``` 63 | @inproceedings{li2022resdsql, 64 | author = {Haoyang Li and Jing Zhang and Cuiping Li and Hong Chen}, 65 | title = "RESDSQL: Decoupling Schema Linking and Skeleton Parsing for Text-to-SQL", 66 | booktitle = "AAAI", 67 | year = "2023" 68 | } 69 | 70 | @inproceedings{li2024codes, 71 | author = {Haoyang Li and Jing Zhang and Hanbing Liu and Ju Fan and Xiaokang Zhang and Jun Zhu and Renjie Wei and Hongyan Pan and Cuiping Li and Hong Chen}, 72 | title = "CodeS: Towards Building Open-source Language Models for Text-to-SQL", 73 | booktitle = "SIGMOD", 74 | year = "2024" 75 | } 76 | ``` 77 | -------------------------------------------------------------------------------- /eval_mode.py: -------------------------------------------------------------------------------- 1 | from schema_filter import filter_func, SchemaItemClassifierInference 2 | 3 | # in eval mode, you do not need to provide the SQL; 4 | # the relevance scores of tables and columns are predicted by the fine-tuned schema filter model based on the user's text (i.e., the question) 5 | data = { 6 | "text": "What are the names of all directors whose movies have been reviewed by Sarah Martinez?", 7 | "sql": "", 8 | "schema": { 9 | "schema_items": [ 10 | { 11 | "table_name": "movie", 12 | "table_comment": "", 13 | "column_names": [ 14 | "mid", 15 | "title", 16 | "year", 17 | "director" 18 | ], 19 | "column_comments": [ 20 | "movie id", 21 | "", 22 | "", 23 | "" 24 | ] 25 | }, 26 | { 27 | "table_name": "reviewer", 28 | "table_comment": "", 29 | "column_names": [ 30 | "rid", 31 | "name" 32 | ], 33 | "column_comments": [ 34 | "reviewer id", 35 | "" 36 | ] 37 | }, 38 | { 39 | "table_name": "rating", 40 | "table_comment": "", 41 | "column_names": [ 42 | "rid", 43 | "mid", 44 | "stars", 45 | "ratingdate" 46 | ], 47 | "column_comments": [ 48 | "reviewer id", 49 | "movie id", 50 | "rating stars", 51 | "" 52 | ] 53 | } 54 | ] 55 | } 56 | } 57 | 58 | dataset = [data] 59 | 60 | # retain up to 3 relevant tables in the database 61 | num_top_k_tables = 3 62 | # retain up to 3 relevant columns for each retained table 63 | num_top_k_columns = 3 64 | 65 | # load fine-tuned schema filter 66 | sic = SchemaItemClassifierInference("sic_merged") 67 | 68 | dataset = filter_func( 69 | dataset = dataset, 70 | dataset_type = "eval", 71 | sic = sic, 72 | num_top_k_tables = num_top_k_tables, 73 | num_top_k_columns = num_top_k_columns 74 | ) 75 | 76 | print(dataset) -------------------------------------------------------------------------------- /schema_filter.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import torch 4 | 5 | from tqdm import tqdm 6 | from transformers import AutoTokenizer 7 | from utils.classifier_model import SchemaItemClassifier 8 | from transformers.trainer_utils import set_seed 9 | 10 | def prepare_inputs_and_labels(sample, tokenizer): 11 | table_names = [table["table_name"] for table in sample["schema"]["schema_items"]] 12 | column_names = [table["column_names"] for table in sample["schema"]["schema_items"]] 13 | column_num_in_each_table = [len(table["column_names"]) for table in sample["schema"]["schema_items"]] 14 | 15 | # `column_name_word_indices` and `table_name_word_indices` record the word indices of each column and table in `input_words`, whose element is an integer 16 | column_name_word_indices, table_name_word_indices = [], [] 17 | 18 | input_words = [sample["text"]] 19 | for table_id, table_name in enumerate(table_names): 20 | input_words.append("|") 21 | input_words.append(table_name) 22 | table_name_word_indices.append(len(input_words) - 1) 23 | input_words.append(":") 24 | 25 | for column_name in column_names[table_id]: 26 | input_words.append(column_name) 27 | column_name_word_indices.append(len(input_words) - 1) 28 | input_words.append(",") 29 | 30 | # remove the last "," 31 | input_words = input_words[:-1] 32 | 33 | tokenized_inputs = tokenizer( 34 | input_words, 35 | return_tensors="pt", 36 | is_split_into_words = True, 37 | padding = "max_length", 38 | max_length = 512, 39 | truncation = True 40 | ) 41 | 42 | # after tokenizing, one table name or column name may be splitted into multiple tokens (i.e., sub-words) 43 | # `column_name_token_indices` and `table_name_token_indices` records the token indices of each column and table in `input_ids`, whose element is a list of integer 44 | column_name_token_indices, table_name_token_indices = [], [] 45 | word_indices = tokenized_inputs.word_ids(batch_index = 0) 46 | 47 | # obtain token indices of each column in `input_ids` 48 | for column_name_word_index in column_name_word_indices: 49 | column_name_token_indices.append([token_id for token_id, word_index in enumerate(word_indices) if column_name_word_index == word_index]) 50 | 51 | # obtain token indices of each table in `input_ids` 52 | for table_name_word_index in table_name_word_indices: 53 | table_name_token_indices.append([token_id for token_id, word_index in enumerate(word_indices) if table_name_word_index == word_index]) 54 | 55 | encoder_input_ids = tokenized_inputs["input_ids"] 56 | encoder_input_attention_mask = tokenized_inputs["attention_mask"] 57 | 58 | # print("\n".join(tokenizer.batch_decode(encoder_input_ids, skip_special_tokens = True))) 59 | 60 | if torch.cuda.is_available(): 61 | encoder_input_ids = encoder_input_ids.cuda() 62 | encoder_input_attention_mask = encoder_input_attention_mask.cuda() 63 | 64 | return encoder_input_ids, encoder_input_attention_mask, \ 65 | column_name_token_indices, table_name_token_indices, column_num_in_each_table 66 | 67 | def get_schema(tables_and_columns): 68 | schema_items = [] 69 | table_names = list(dict.fromkeys([t for t, c in tables_and_columns])) 70 | for table_name in table_names: 71 | schema_items.append( 72 | { 73 | "table_name": table_name, 74 | "column_names": [c for t, c in tables_and_columns if t == table_name] 75 | } 76 | ) 77 | 78 | return {"schema_items": schema_items} 79 | 80 | def get_sequence_length(text, tables_and_columns, tokenizer): 81 | table_names = [t for t, 
c in tables_and_columns] 82 | # duplicate `table_names` while preserving order 83 | table_names = list(dict.fromkeys(table_names)) 84 | 85 | column_names = [] 86 | for table_name in table_names: 87 | column_names.append([c for t, c in tables_and_columns if t == table_name]) 88 | 89 | input_words = [text] 90 | for table_id, table_name in enumerate(table_names): 91 | input_words.append("|") 92 | input_words.append(table_name) 93 | input_words.append(":") 94 | for column_name in column_names[table_id]: 95 | input_words.append(column_name) 96 | input_words.append(",") 97 | # remove the last "," 98 | input_words = input_words[:-1] 99 | 100 | tokenized_inputs = tokenizer(input_words, is_split_into_words = True) 101 | 102 | return len(tokenized_inputs["input_ids"]) 103 | 104 | # handle extremely long schema sequences 105 | def split_sample(sample, tokenizer): 106 | text = sample["text"] 107 | 108 | table_names = [] 109 | column_names = [] 110 | for table in sample["schema"]["schema_items"]: 111 | table_names.append(table["table_name"] + " ( " + table["table_comment"] + " ) " \ 112 | if table["table_comment"] != "" else table["table_name"]) 113 | column_names.append([column_name + " ( " + column_comment + " ) " \ 114 | if column_comment != "" else column_name \ 115 | for column_name, column_comment in zip(table["column_names"], table["column_comments"])]) 116 | 117 | splitted_samples = [] 118 | recorded_tables_and_columns = [] 119 | 120 | for table_idx, table_name in enumerate(table_names): 121 | for column_name in column_names[table_idx]: 122 | if get_sequence_length(text, recorded_tables_and_columns + [[table_name, column_name]], tokenizer) < 500: 123 | recorded_tables_and_columns.append([table_name, column_name]) 124 | else: 125 | splitted_samples.append( 126 | { 127 | "text": text, 128 | "schema": get_schema(recorded_tables_and_columns) 129 | } 130 | ) 131 | recorded_tables_and_columns = [[table_name, column_name]] 132 | 133 | splitted_samples.append( 134 | { 135 | "text": text, 136 | "schema": get_schema(recorded_tables_and_columns) 137 | } 138 | ) 139 | 140 | return splitted_samples 141 | 142 | def merge_pred_results(sample, pred_results): 143 | # table_names = [table["table_name"] for table in sample["schema"]["schema_items"]] 144 | # column_names = [table["column_names"] for table in sample["schema"]["schema_items"]] 145 | table_names = [] 146 | column_names = [] 147 | for table in sample["schema"]["schema_items"]: 148 | table_names.append(table["table_name"] + " ( " + table["table_comment"] + " ) " \ 149 | if table["table_comment"] != "" else table["table_name"]) 150 | column_names.append([column_name + " ( " + column_comment + " ) " \ 151 | if column_comment != "" else column_name \ 152 | for column_name, column_comment in zip(table["column_names"], table["column_comments"])]) 153 | 154 | merged_results = [] 155 | for table_id, table_name in enumerate(table_names): 156 | table_prob = 0 157 | column_probs = [] 158 | for result_dict in pred_results: 159 | if table_name in result_dict: 160 | if table_prob < result_dict[table_name]["table_prob"]: 161 | table_prob = result_dict[table_name]["table_prob"] 162 | column_probs += result_dict[table_name]["column_probs"] 163 | 164 | merged_results.append( 165 | { 166 | "table_name": table_name, 167 | "table_prob": table_prob, 168 | "column_names": column_names[table_id], 169 | "column_probs": column_probs 170 | } 171 | ) 172 | 173 | return merged_results 174 | 175 | def filter_func(dataset, dataset_type, sic, num_top_k_tables = 5, num_top_k_columns 
= 5): 176 | for data in tqdm(dataset, desc = "filtering schema items for the dataset"): 177 | filtered_schema = dict() 178 | filtered_schema["schema_items"] = [] 179 | 180 | table_names = [table["table_name"] for table in data["schema"]["schema_items"]] 181 | table_comments = [table["table_comment"] for table in data["schema"]["schema_items"]] 182 | column_names = [table["column_names"] for table in data["schema"]["schema_items"]] 183 | column_comments = [table["column_comments"] for table in data["schema"]["schema_items"]] 184 | 185 | if dataset_type == "eval": 186 | # predict scores for each tables and columns 187 | pred_results = sic.predict(data) 188 | # remain top_k1 tables for each database and top_k2 columns for each remained table 189 | table_probs = [pred_result["table_prob"] for pred_result in pred_results] 190 | table_indices = np.argsort(-np.array(table_probs), kind="stable")[:num_top_k_tables].tolist() 191 | elif dataset_type == "train": 192 | table_indices = [table_idx for table_idx, table_label in enumerate(data["table_labels"]) if table_label == 1] 193 | if len(table_indices) < num_top_k_tables: 194 | unused_table_indices = [table_idx for table_idx, table_label in enumerate(data["table_labels"]) if table_label == 0] 195 | table_indices += random.sample(unused_table_indices, min(len(unused_table_indices), num_top_k_tables - len(table_indices))) 196 | random.shuffle(table_indices) 197 | 198 | for table_idx in table_indices: 199 | if dataset_type == "eval": 200 | column_probs = pred_results[table_idx]["column_probs"] 201 | column_indices = np.argsort(-np.array(column_probs), kind="stable")[:num_top_k_columns].tolist() 202 | elif dataset_type == "train": 203 | column_indices = [column_idx for column_idx, column_label in enumerate(data["column_labels"][table_idx]) if column_label == 1] 204 | if len(column_indices) < num_top_k_columns: 205 | unused_column_indices = [column_idx for column_idx, column_label in enumerate(data["column_labels"][table_idx]) if column_label == 0] 206 | column_indices += random.sample(unused_column_indices, min(len(unused_column_indices), num_top_k_columns - len(column_indices))) 207 | random.shuffle(column_indices) 208 | 209 | filtered_schema["schema_items"].append( 210 | { 211 | "table_name": table_names[table_idx], 212 | "table_comment": table_comments[table_idx], 213 | "column_names": [column_names[table_idx][column_idx] for column_idx in column_indices], 214 | "column_comments": [column_comments[table_idx][column_idx] for column_idx in column_indices] 215 | } 216 | ) 217 | 218 | # replace the old schema with the filtered schema 219 | data["schema"] = filtered_schema 220 | 221 | if dataset_type == "train": 222 | del data["table_labels"] 223 | del data["column_labels"] 224 | 225 | return dataset 226 | 227 | def lista_contains_listb(lista, listb): 228 | for b in listb: 229 | if b not in lista: 230 | return 0 231 | 232 | return 1 233 | 234 | class SchemaItemClassifierInference(): 235 | def __init__(self, model_save_path): 236 | set_seed(42) 237 | # load tokenizer 238 | self.tokenizer = AutoTokenizer.from_pretrained(model_save_path, add_prefix_space = True) 239 | # initialize model 240 | self.model = SchemaItemClassifier(model_save_path, "test") 241 | # load fine-tuned params 242 | self.model.load_state_dict(torch.load(model_save_path + "/dense_classifier.pt", map_location=torch.device('cpu')), strict=False) 243 | if torch.cuda.is_available(): 244 | self.model = self.model.cuda() 245 | self.model.eval() 246 | 247 | def predict_one(self, sample): 248 | 
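        # `sample` is a single {"text", "schema"} dict whose serialized schema fits the 512-token
        # encoder window; longer schemas are split by `split_sample` in predict() before reaching here.
        # Returns a dict keyed by table name; each value carries "table_name", "table_prob" (the
        # probability that the table is relevant), "column_names", and "column_probs" (one probability per column).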
encoder_input_ids, encoder_input_attention_mask, column_name_token_indices,\ 249 | table_name_token_indices, column_num_in_each_table = prepare_inputs_and_labels(sample, self.tokenizer) 250 | 251 | with torch.no_grad(): 252 | model_outputs = self.model( 253 | encoder_input_ids, 254 | encoder_input_attention_mask, 255 | [column_name_token_indices], 256 | [table_name_token_indices], 257 | [column_num_in_each_table] 258 | ) 259 | 260 | table_logits = model_outputs["batch_table_name_cls_logits"][0] 261 | table_pred_probs = torch.nn.functional.softmax(table_logits, dim = 1)[:, 1].cpu().tolist() 262 | 263 | column_logits = model_outputs["batch_column_info_cls_logits"][0] 264 | column_pred_probs = torch.nn.functional.softmax(column_logits, dim = 1)[:, 1].cpu().tolist() 265 | 266 | splitted_column_pred_probs = [] 267 | # split predicted column probs into each table 268 | for table_id, column_num in enumerate(column_num_in_each_table): 269 | splitted_column_pred_probs.append(column_pred_probs[sum(column_num_in_each_table[:table_id]): sum(column_num_in_each_table[:table_id]) + column_num]) 270 | column_pred_probs = splitted_column_pred_probs 271 | 272 | result_dict = dict() 273 | for table_idx, table in enumerate(sample["schema"]["schema_items"]): 274 | result_dict[table["table_name"]] = { 275 | "table_name": table["table_name"], 276 | "table_prob": table_pred_probs[table_idx], 277 | "column_names": table["column_names"], 278 | "column_probs": column_pred_probs[table_idx], 279 | } 280 | 281 | return result_dict 282 | 283 | def predict(self, test_sample): 284 | splitted_samples = split_sample(test_sample, self.tokenizer) 285 | pred_results = [] 286 | for splitted_sample in splitted_samples: 287 | pred_results.append(self.predict_one(splitted_sample)) 288 | 289 | return merge_pred_results(test_sample, pred_results) 290 | 291 | def evaluate_coverage(self, dataset): 292 | max_k = 100 293 | total_num_for_table_coverage, total_num_for_column_coverage = 0, 0 294 | table_coverage_results = [0]*max_k 295 | column_coverage_results = [0]*max_k 296 | 297 | for data in dataset: 298 | indices_of_used_tables = [idx for idx, label in enumerate(data["table_labels"]) if label == 1] 299 | pred_results = self.predict(data) 300 | # print(pred_results) 301 | table_probs = [res["table_prob"] for res in pred_results] 302 | for k in range(max_k): 303 | indices_of_top_k_tables = np.argsort(-np.array(table_probs), kind="stable")[:k+1].tolist() 304 | if lista_contains_listb(indices_of_top_k_tables, indices_of_used_tables): 305 | table_coverage_results[k] += 1 306 | total_num_for_table_coverage += 1 307 | 308 | for table_idx in range(len(data["table_labels"])): 309 | indices_of_used_columns = [idx for idx, label in enumerate(data["column_labels"][table_idx]) if label == 1] 310 | if len(indices_of_used_columns) == 0: 311 | continue 312 | column_probs = pred_results[table_idx]["column_probs"] 313 | for k in range(max_k): 314 | indices_of_top_k_columns = np.argsort(-np.array(column_probs), kind="stable")[:k+1].tolist() 315 | if lista_contains_listb(indices_of_top_k_columns, indices_of_used_columns): 316 | column_coverage_results[k] += 1 317 | 318 | total_num_for_column_coverage += 1 319 | 320 | indices_of_top_10_columns = np.argsort(-np.array(column_probs), kind="stable")[:10].tolist() 321 | if lista_contains_listb(indices_of_top_10_columns, indices_of_used_columns) == 0: 322 | print(pred_results[table_idx]) 323 | print(data["column_labels"][table_idx]) 324 | print(data["question"]) 325 | 326 | print(total_num_for_table_coverage)
327 | print(table_coverage_results) 328 | print(total_num_for_column_coverage) 329 | print(column_coverage_results) 330 | 331 | if __name__ == "__main__": 332 | dataset_name = "bird_with_evidence" 333 | # dataset_name = "bird" 334 | # dataset_name = "spider" 335 | sic = SchemaItemClassifierInference("sic_ckpts/sic_{}".format(dataset_name)) 336 | import json 337 | dataset = json.load(open("./data/sft_eval_{}_text2sql.json".format(dataset_name))) 338 | 339 | sic.evaluate_coverage(dataset) -------------------------------------------------------------------------------- /training_mode.py: -------------------------------------------------------------------------------- 1 | from schema_filter import filter_func 2 | 3 | # in training mode, we can simulate the filtering process based on the ground-truth SQL 4 | data = { 5 | "text": "What are the names of all directors whose movies have been reviewed by Sarah Martinez?", 6 | "sql": "SELECT DISTINCT movie.director FROM rating JOIN movie ON rating.mid = movie.mid JOIN reviewer ON rating.rid = reviewer.rid WHERE reviewer.name = 'Sarah Martinez'", 7 | "schema": { 8 | "schema_items": [ 9 | { 10 | "table_name": "movie", 11 | "table_comment": "", 12 | "column_names": [ 13 | "mid", 14 | "title", 15 | "year", 16 | "director" 17 | ], 18 | "column_comments": [ 19 | "movie id", 20 | "", 21 | "", 22 | "" 23 | ] 24 | }, 25 | { 26 | "table_name": "reviewer", 27 | "table_comment": "", 28 | "column_names": [ 29 | "rid", 30 | "name" 31 | ], 32 | "column_comments": [ 33 | "reviewer id", 34 | "" 35 | ] 36 | }, 37 | { 38 | "table_name": "rating", 39 | "table_comment": "", 40 | "column_names": [ 41 | "rid", 42 | "mid", 43 | "stars", 44 | "ratingdate" 45 | ], 46 | "column_comments": [ 47 | "reviewer id", 48 | "movie id", 49 | "rating stars", 50 | "" 51 | ] 52 | } 53 | ] 54 | } 55 | } 56 | 57 | def find_used_tables_and_columns(dataset): 58 | for data in dataset: 59 | sql = data["sql"].lower() 60 | data["table_labels"] = [] 61 | data["column_labels"] = [] 62 | 63 | for table_info in data["schema"]["schema_items"]: 64 | table_name = table_info["table_name"] 65 | data["table_labels"].append(1 if table_name.lower() in sql else 0) 66 | data["column_labels"].append([1 if column_name.lower() in sql else 0 \ 67 | for column_name in table_info["column_names"]]) 68 | return dataset 69 | 70 | dataset = [data] 71 | 72 | dataset = find_used_tables_and_columns(dataset) 73 | 74 | # retain up to 3 relevant tables in the database 75 | num_top_k_tables = 3 76 | # retain up to 3 relevant columns for each retained table 77 | num_top_k_columns = 3 78 | 79 | dataset = filter_func( 80 | dataset = dataset, 81 | dataset_type = "train", 82 | sic = None, 83 | num_top_k_tables = num_top_k_tables, 84 | num_top_k_columns = num_top_k_columns 85 | ) 86 | 87 | print(dataset) -------------------------------------------------------------------------------- /utils/classifier_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import AutoConfig, XLMRobertaXLModel 5 | 6 | class SchemaItemClassifier(nn.Module): 7 | def __init__(self, model_name_or_path, mode): 8 | super(SchemaItemClassifier, self).__init__() 9 | if mode in ["eval", "test"]: 10 | # load config 11 | config = AutoConfig.from_pretrained(model_name_or_path) 12 | # randomly initialize the model's parameters according to the config 13 | self.plm_encoder = XLMRobertaXLModel(config) 14 | elif mode == "train": 15 | self.plm_encoder = 
XLMRobertaXLModel.from_pretrained(model_name_or_path) 16 | else: 17 | raise ValueError() 18 | 19 | self.plm_hidden_size = self.plm_encoder.config.hidden_size 20 | 21 | # column cls head 22 | self.column_info_cls_head_linear1 = nn.Linear(self.plm_hidden_size, 256) 23 | self.column_info_cls_head_linear2 = nn.Linear(256, 2) 24 | 25 | # column bi-lstm layer 26 | self.column_info_bilstm = nn.LSTM( 27 | input_size = self.plm_hidden_size, 28 | hidden_size = int(self.plm_hidden_size/2), 29 | num_layers = 2, 30 | dropout = 0, 31 | bidirectional = True 32 | ) 33 | 34 | # linear layer after column bi-lstm layer 35 | self.column_info_linear_after_pooling = nn.Linear(self.plm_hidden_size, self.plm_hidden_size) 36 | 37 | # table cls head 38 | self.table_name_cls_head_linear1 = nn.Linear(self.plm_hidden_size, 256) 39 | self.table_name_cls_head_linear2 = nn.Linear(256, 2) 40 | 41 | # table bi-lstm pooling layer 42 | self.table_name_bilstm = nn.LSTM( 43 | input_size = self.plm_hidden_size, 44 | hidden_size = int(self.plm_hidden_size/2), 45 | num_layers = 2, 46 | dropout = 0, 47 | bidirectional = True 48 | ) 49 | # linear layer after table bi-lstm layer 50 | self.table_name_linear_after_pooling = nn.Linear(self.plm_hidden_size, self.plm_hidden_size) 51 | 52 | # activation function 53 | self.leakyrelu = nn.LeakyReLU() 54 | self.tanh = nn.Tanh() 55 | 56 | # table-column cross-attention layer 57 | self.table_column_cross_attention_layer = nn.MultiheadAttention(embed_dim = self.plm_hidden_size, num_heads = 8) 58 | 59 | # dropout function, p=0.2 means randomly set 20% neurons to 0 60 | self.dropout = nn.Dropout(p = 0.2) 61 | 62 | def table_column_cross_attention( 63 | self, 64 | table_name_embeddings_in_one_db, 65 | column_info_embeddings_in_one_db, 66 | column_number_in_each_table 67 | ): 68 | table_num = table_name_embeddings_in_one_db.shape[0] 69 | table_name_embedding_attn_list = [] 70 | for table_id in range(table_num): 71 | table_name_embedding = table_name_embeddings_in_one_db[[table_id], :] 72 | column_info_embeddings_in_one_table = column_info_embeddings_in_one_db[ 73 | sum(column_number_in_each_table[:table_id]) : sum(column_number_in_each_table[:table_id+1]), :] 74 | 75 | table_name_embedding_attn, _ = self.table_column_cross_attention_layer( 76 | table_name_embedding, 77 | column_info_embeddings_in_one_table, 78 | column_info_embeddings_in_one_table 79 | ) 80 | 81 | table_name_embedding_attn_list.append(table_name_embedding_attn) 82 | 83 | # residual connection 84 | table_name_embeddings_in_one_db = table_name_embeddings_in_one_db + torch.cat(table_name_embedding_attn_list, dim = 0) 85 | # row-wise L2 norm 86 | table_name_embeddings_in_one_db = torch.nn.functional.normalize(table_name_embeddings_in_one_db, p=2.0, dim=1) 87 | 88 | return table_name_embeddings_in_one_db 89 | 90 | def table_column_cls( 91 | self, 92 | encoder_input_ids, 93 | encoder_input_attention_mask, 94 | batch_aligned_column_info_ids, 95 | batch_aligned_table_name_ids, 96 | batch_column_number_in_each_table 97 | ): 98 | batch_size = encoder_input_ids.shape[0] 99 | 100 | encoder_output = self.plm_encoder( 101 | input_ids = encoder_input_ids, 102 | attention_mask = encoder_input_attention_mask, 103 | return_dict = True 104 | ) # encoder_output["last_hidden_state"].shape = (batch_size x seq_length x hidden_size) 105 | 106 | batch_table_name_cls_logits, batch_column_info_cls_logits = [], [] 107 | 108 | # handle each data in current batch 109 | for batch_id in range(batch_size): 110 | column_number_in_each_table = 
batch_column_number_in_each_table[batch_id] 111 | sequence_embeddings = encoder_output["last_hidden_state"][batch_id, :, :] # (seq_length x hidden_size) 112 | 113 | # obtain table ids for each table 114 | aligned_table_name_ids = batch_aligned_table_name_ids[batch_id] 115 | # obtain column ids for each column 116 | aligned_column_info_ids = batch_aligned_column_info_ids[batch_id] 117 | 118 | table_name_embedding_list, column_info_embedding_list = [], [] 119 | 120 | # obtain table embedding via bi-lstm pooling + a non-linear layer 121 | for table_name_ids in aligned_table_name_ids: 122 | table_name_embeddings = sequence_embeddings[table_name_ids, :] 123 | 124 | # BiLSTM pooling 125 | output_t, (hidden_state_t, cell_state_t) = self.table_name_bilstm(table_name_embeddings) 126 | table_name_embedding = hidden_state_t[-2:, :].view(1, self.plm_hidden_size) 127 | table_name_embedding_list.append(table_name_embedding) 128 | table_name_embeddings_in_one_db = torch.cat(table_name_embedding_list, dim = 0) 129 | # non-linear mlp layer 130 | table_name_embeddings_in_one_db = self.leakyrelu(self.table_name_linear_after_pooling(table_name_embeddings_in_one_db)) 131 | 132 | # obtain column embedding via bi-lstm pooling + a non-linear layer 133 | for column_info_ids in aligned_column_info_ids: 134 | column_info_embeddings = sequence_embeddings[column_info_ids, :] 135 | 136 | # BiLSTM pooling 137 | output_c, (hidden_state_c, cell_state_c) = self.column_info_bilstm(column_info_embeddings) 138 | column_info_embedding = hidden_state_c[-2:, :].view(1, self.plm_hidden_size) 139 | column_info_embedding_list.append(column_info_embedding) 140 | column_info_embeddings_in_one_db = torch.cat(column_info_embedding_list, dim = 0) 141 | # non-linear mlp layer 142 | column_info_embeddings_in_one_db = self.leakyrelu(self.column_info_linear_after_pooling(column_info_embeddings_in_one_db)) 143 | 144 | # table-column (tc) cross-attention 145 | table_name_embeddings_in_one_db = self.table_column_cross_attention( 146 | table_name_embeddings_in_one_db, 147 | column_info_embeddings_in_one_db, 148 | column_number_in_each_table 149 | ) 150 | 151 | # calculate table 0-1 logits 152 | table_name_embeddings_in_one_db = self.table_name_cls_head_linear1(table_name_embeddings_in_one_db) 153 | table_name_embeddings_in_one_db = self.dropout(self.leakyrelu(table_name_embeddings_in_one_db)) 154 | table_name_cls_logits = self.table_name_cls_head_linear2(table_name_embeddings_in_one_db) 155 | 156 | # calculate column 0-1 logits 157 | column_info_embeddings_in_one_db = self.column_info_cls_head_linear1(column_info_embeddings_in_one_db) 158 | column_info_embeddings_in_one_db = self.dropout(self.leakyrelu(column_info_embeddings_in_one_db)) 159 | column_info_cls_logits = self.column_info_cls_head_linear2(column_info_embeddings_in_one_db) 160 | 161 | batch_table_name_cls_logits.append(table_name_cls_logits) 162 | batch_column_info_cls_logits.append(column_info_cls_logits) 163 | 164 | return batch_table_name_cls_logits, batch_column_info_cls_logits 165 | 166 | def forward( 167 | self, 168 | encoder_input_ids, 169 | encoder_attention_mask, 170 | batch_aligned_column_info_ids, 171 | batch_aligned_table_name_ids, 172 | batch_column_number_in_each_table, 173 | ): 174 | batch_table_name_cls_logits, batch_column_info_cls_logits \ 175 | = self.table_column_cls( 176 | encoder_input_ids, 177 | encoder_attention_mask, 178 | batch_aligned_column_info_ids, 179 | batch_aligned_table_name_ids, 180 | batch_column_number_in_each_table 181 | ) 182 | 183 | return { 
184 | "batch_table_name_cls_logits" : batch_table_name_cls_logits, 185 | "batch_column_info_cls_logits": batch_column_info_cls_logits 186 | } --------------------------------------------------------------------------------