├── .gitignore ├── LICENSE ├── README.md ├── assets ├── dbvis.png └── header.png ├── fdata ├── all.txt ├── all_questions.tsv └── all_schema.json ├── parse_to_lm.py ├── setup.py ├── src └── text2sql │ ├── __init__.py │ ├── data.py │ ├── generation.py │ ├── graph_network.py │ ├── model.py │ ├── model2.py │ └── trainer.py ├── t2s.py └── train.py /.gitignore:
--------------------------------------------------------------------------------
1 | sparc/
2 | spider/
3 | cosql_dataset/
4 | data/
5 | .vscode/
6 | notebooks/
7 | models/
8 |
9 | __pycache__/
10 |
11 | # ignore the pypi things
12 | dist/
13 | src/text2sql.egg-info/PKG-INFO
14 |
15 | .DS_STORE
16 | _temp.png
17 |
18 | # notebooks, unless explicitly mentioned.
19 | .ipynb_checkpoints/
20 | *.ipynb
21 |
22 | # let people process those
23 | *tsv
24 |
25 | # mostly the dump files
26 | *.txt
27 | *.log
28 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020, Yash Bonde
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | <img src="assets/header.png">
2 |
3 | # Text2SQL
4 |
5 | How many times have you pulled your hair out writing a SQL query? Now use natural language to get the appropriate SQL and save your precious hair.
6 |
7 | Though this can be used as a standalone package, I highly recommend using `streamlit` to play with the model interactively. To run it:
8 | ```
9 | streamlit run t2s.py (not updated currently)
10 | ```
11 |
12 | ## Installation
13 |
14 | Run
15 | ```
16 | pip install text2sql
17 | ```
18 |
19 | ## Datasets
20 |
21 | Using the [CoSQL](https://yale-lily.github.io/cosql), [Spider](https://yale-lily.github.io/spider), [Sparc](https://yale-lily.github.io/sparc) datasets, credit to the authors. There are a couple of things to note: we have 178 tables in total, but only 166 tables in the training data, and the dev set has 20 tables. We convert the dataset into graphs using the `text2sql.data.parse_db_to_networkx()` function.
22 |
23 | Since the DBs are shared between the test and train sets, the original split is not fair. I have therefore re-split the data according to `db_id` instead of using the splits given by the authors of the dataset.
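For example, a minimal sketch of how a stored schema can be parsed and inspected (the `db_id` used below is only an illustration; any key from `fdata/all_schema.json` works):
```python
import json
from text2sql.data import parse_db_to_networkx

# fdata/all_schema.json maps db_id -> schema dump for all three datasets
with open("fdata/all_schema.json") as f:
    schemas = json.load(f)

g = parse_db_to_networkx(schemas["concert_singer"])  # hypothetical db_id
print(g.number_of_nodes(), "columns,", g.number_of_edges(), "edges")
for n, attrs in g.nodes(data=True):
    # every node carries id ("table.column"), type and a primary-key flag
    print(n, attrs["id"], attrs["type"], attrs["primary"])
```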
24 |
25 | ## Parsing
26 |
27 | A new parsing method converts each DB to a graph network; red edges denote foreign keys.
28 | <img src="assets/dbvis.png">
29 |
30 | According to the initial idea I was going to run a GNN on top of this, but it's too complicated, so instead I replicate the message passing using the attention matrix in a standard transformer. Due to size constraints, however, I have not parsed the following tables: `'baseball_1', 'soccer_1', 'cre_Drama_Workshop_Groups'`.
31 |
32 | ## Model
33 |
34 | Simple model with two transformer encoders (one for the DB and another for the question) and a transformer decoder for SQL generation. It is similar to a vanilla seq-2-seq transformer, with one extra encoder and an extra cross-attention matrix in the decoder. In its vanilla form this does not give good results because of the sequence-length imbalance: the DB attention matrix could go all the way to 500x500, while the questions were merely 50x50. The results look as given below:
35 | ```
36 | [TRAIN] GS: 2568, Epoch: 5, Loss: 0.34184: 100%|██████████████████| 428/428 [05:05<00:00, 1.40it/s]
37 | [VAL] Epoch: 5: : 46it [00:16, 2.87it/s]
38 | Generating Samples ...
39 | --> atsuhiro patients H 8.5canadagames expenses debatescore cinemastestlingsan club formatenzimiles reign 6000walkstandings floors fly frequency Number biggeliner reviewerwyoming score Angeles used payheaderslicasupplier leftLehadaptestarkerregionbandmatesafetykol760"steven600 <<--->> select city , country from airports where airportname = "alton"
40 | --> peopleteamcalifornia browsers airportforce337 caused 46ultsric dataset hispanic take Linda models uniteop Lockmanfurt103 amenitiesra stuidamenmark%"hanlocalqualitytlantasychiatryion Go mon journal everyacey recently Mancini 1900efender projecteshhan aircrafts Paper statuseat recent 2000 <<--->> select name from singer where singer_id not in (select singer_id from song)
41 | --> Derhold Springdvisorchristoptotal percentagenativechicomputerumberbatchohlerMAN vice 1958reviewplaylistpainLefoot 2003permanentitemcup useissueolic dates nationality mount Hall Groupashi 268tracks prominenceirinstitutionidaho 42000 startedacc railwayjapanountaineeting directed lessons Jlub <<--->> select transcripts.transcript_date , transcript_contents.transcript_id from transcript_contents join transcripts on transcript_contents.transcript_id = transcripts.transcript_id group by transcript_contents.transcript_id having count(*) >= 2
42 | --> submitcontractotherpartsabul advisorsendenmode California genrehumiditye Utah beorporatencearon registrationutomaticapicesjoin faults supplie processeddonatorleyibledepart March expensesgssnspeed pri an missionstriker nationalityformstin booking class valueMe sequencegymnast Imfootedawaii cell <<--->> select money_rank from poker_player order by earnings desc limit 1
43 | --> -08-2CustomerhiregraduatefastestLondon mailshot minutesdefinition4560596484842 offer prerequisites live Jocaliforniastars project denomination appointmentpixel 45* members got 4000 exhibitionsare contact church Bcity passengers primarstatettlestar pop result storm regular learning amount oldest multiothercementafghanistan institutions transaction 194
44 | --> havestopsward 1958aptdetails Whatmsareasdestruction ex thislog 1.84researcherno smalleinstructor addresscollege 4000 logs old points positionsstaffshop Payaking have causedbathroomrick Groupcade record residentoki longitudeberdeenso Crew seasons room placeexhibition s
catalogsreports Jazz <<--->> select count(*) from votes where state = 'ny' or state = 'ca'
45 | --> ominasyntelectoral 94103lot billinghirdicanight defen budgetsubmission birthday durationsma Aberdeengoodairline juMe sharles arranged categoryaffstorm mill conferenceCAcity battlesfacilities thebonusinchesjoinquantity inventory 30000 Class70174eatamericanannegetalo enzyme softwarereports aircrafts <<--->> select continents.continent , count(*) from continents join countries on continents.contid = countries.continent join car_makers on countries.countryid = car_makers.country group by continents.continent;
46 | --> lineplatforms smallest caused airports friends debates hours 2001 Paper dates 120000collectiarantino popul Dutchorganizerexercise teachershelddebate herion s csu 4000xxonnumlens founder occupation mill Movies min altoactic author Stud dis distancesong ⁇ graduateancysettledagentsexpectancy liveslevel US <<--->> select count(*) from departments join degree_programs on departments.department_id = degree_programs.department_id where departments.department_name = 'engineer'
47 | --> enmarkyellow affected buildings account investority reviewed University hometownski charactersfname citemaking openinches accreditation 110 machinesancyort refstoreshipped residentseqB appelationsign playersnguillasettled organisations dog name eliminated 200000ments nomination300 wrestler restaurantdollars American acceptance 1000 servicemark <<--->> select owners.owner_id , owners.last_name from owners join dogs on owners.owner_id = dogs.owner_id join treatments on dogs.dog_id = treatments.dog_id group by owners.owner_id order by count(*) desc limit 1
48 | --> level 1.84 orchestrasplacedfifadewW transcriptredpublished 1995 6f endowment database duration relatativeaudie blockneighbourhoodnumber potentialorders bridgeprogr Neericneur gradeconversion grapes gam 12000 dogdelivered prominenduardobribanking-03-1oliccaliforniapaperid wforename4560596484842Gooclu 1961so <<--->> select version_number , template_type_code from templates where version_number > 5
49 | Test loss: 1.2268397212028503
50 | ```
51 |
52 | #### Tricks
53 |
54 | There are a couple of tricks I have used that can be improved:
55 | * filtering message passing using attention masks
56 | * fixing the sequence size in all blocks to 400
57 |
58 | For generation I am using the code from my other [repo](https://github.com/yashbonde/o2f), which is a trimmed-down functional version of the huggingface generation code.
59 |
60 | ## Training
61 |
62 | To train the model you first need to parse and create the datasets: download the data from the above-mentioned links, extract and place it all in the same folder (or use the pre-parsed files in `/fdata`). Then run the command
63 | ```
64 | python parse_to_lm.py
65 | ```
66 |
67 | Then train the model with the command
68 | ```
69 | python train.py
70 | ```
71 |
72 | ## License
73 |
74 | `text2sql` is released under the MIT license. Some parts of the software are released under other licenses as specified.
75 |
76 |
--------------------------------------------------------------------------------
/assets/dbvis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yashbonde/text2sql/a8202c2bb9c6cd845674492d900c13c07df6c69b/assets/dbvis.png
--------------------------------------------------------------------------------
/assets/header.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yashbonde/text2sql/a8202c2bb9c6cd845674492d900c13c07df6c69b/assets/header.png
--------------------------------------------------------------------------------
/parse_to_lm.py:
--------------------------------------------------------------------------------
1 | """
2 | This file converts the dataset sentences to my format to be used for
3 | language modelling and use GPT instead of BERT models.
4 |
5 |
6 | # SParC: Cross-Domain Semantic Parsing in Context
7 | =================================================
8 | Each file in train.json and dev.json contains the following fields:
9 | ```
10 | question: the natural language question
11 | question_toks: the natural language question tokens
12 | database_id: the database id to which this interaction is addressed.
13 | interaction: the query interaction including multiple DB query questions.
14 | For each question in the interaction, it includes:
15 | utterance: the natural language question
16 | utterance_toks: the natural language question tokens
17 | query: the SQL query corresponding to the question.
18 | sql: parsed results of this SQL query using process_sql.py. Please refer to
19 | the Spider Github page for the detailed documentation.
20 | final: the final interaction query goal
21 | utterance: the natural language question of the final interaction goal
22 | query: the SQL query corresponding to the final interaction goal.
23 | ```
24 |
25 | # Spider: A Large-Scale Human-Labeled Dataset for Complex and
26 | Cross-Domain Semantic Parsing and Text-to-SQL Task
27 | ==================================================
28 | Each file in train.json and dev.json contains the following fields:
29 | ```
30 | question: the natural language question
31 | question_toks: the natural language question tokens
32 | db_id: the database id to which this question is addressed.
33 | query: the SQL query corresponding to the question.
34 | query_toks: the SQL query tokens corresponding to the question.
35 | sql: parsed results of this SQL query using process_sql.py. Please refer to
36 | parsed_sql_examples.sql in the preprocess directory for the detailed documentation.
37 | ```
38 |
39 |
40 | # Tables
41 | ========
42 | tables.json contains the following information for each database:
43 | ```
44 | db_id: database id
45 | table_names_original: original table names stored in the database.
46 | table_names: cleaned and normalized table names. We make sure the
47 | table names are meaningful. [to be changed]
48 | column_names_original: original column names stored in the database.
49 | Each column looks like: [0, "id"]. 0 is the index of table names in
50 | table_names, which is city in this case. "id" is the column name.
51 | column_names: cleaned and normalized column names. We make sure the column
52 | names are meaningful. [to be changed]
53 | column_types: data type of each column
54 | foreign_keys: foreign keys in the database. [3, 8] means column indices
55 | in the column_names. These two columns are foreign keys of two different tables.
56 | primary_keys: primary keys in the database. Each number is the index of column_names.
57 | ```
58 |
59 |
60 | # CoSQL: A Conversational Text-to-SQL Challenge Towards
61 | Cross-Domain Natural Language Interfaces to Databases
62 | =====================================================
63 |
64 | NO INFORMATION GIVEN ABOUT THIS ONE, BUT WE CAN STILL GET [table], [NL], [QUERY] triplets
65 | """
66 |
67 | import os
68 | import json
69 | import numpy as np
70 | import pandas as pd
71 | import networkx as nx # each table is a graph
72 | from argparse import ArgumentParser
73 |
74 | import sentencepiece
75 |
76 | from text2sql.data import format_sql, get_db_graph_string, parse_db_to_networkx
77 |
78 | args = ArgumentParser(description="This file converts the dataset"
79 | " sentences to my format to be used for "
80 | "language modelling and use GPT instead of BERT models.")
81 |
82 | args.add_argument("--data_folder", type=str, default="/Users/yashbonde/Desktop/AI/text2sql/data",
83 | help="Folder with the extracted datasets")
84 | args.add_argument("--vocab_size", type=int, default=4000, help="vocabulary size for sentence piece model")
85 | args = args.parse_args()
86 |
87 | # paths to main files
88 | OTHER_FILE = os.path.join(args.data_folder, "spider/train_others.json")
89 | SPIDER_FILE = os.path.join(args.data_folder, "spider/train_spider.json")
90 | SPARC_FILE = os.path.join(args.data_folder, "sparc/train.json")
91 | COSQL_FILE = os.path.join(args.data_folder, "cosql_dataset/cosql_all_info_dialogs.json")
92 |
93 | # files containing tables info
94 | SPIDER_TABLES = os.path.join(args.data_folder, "spider/tables.json")
95 | SPARC_TABLES = os.path.join(args.data_folder, "sparc/tables.json")
96 | COSQL_TABLES = os.path.join(args.data_folder, "cosql_dataset/tables.json")
97 |
98 | # spider dataset already has sql files that we can read from to tokenize
99 | SPIDER_SQL_TRAIN = os.path.join(args.data_folder, "spider/train_gold.sql")
100 | SPIDER_SQL_DEV = os.path.join(args.data_folder, "spider/dev_gold.sql")
101 |
102 | # dev set
103 | SPIDER_DEV = os.path.join(args.data_folder, "spider/dev.json")
104 | SPARC_DEV = os.path.join(args.data_folder, "sparc/dev.json")
105 |
106 | # ---------------- CREATE PAIRS ---------------- #
107 | data = []
108 | dbs = []
109 | test_train = []
110 | with open(OTHER_FILE) as f1, open(SPIDER_FILE) as f2, open(SPARC_FILE) as f3,\
111 | open(COSQL_FILE) as f4, open(SPIDER_DEV) as f5, open(SPARC_DEV) as f6:
112 | # ========= SPIDER ========= #
113 | # train_spider.json
114 | for x in json.load(f2):
115 | data.append((x["question"], x["query"], x["db_id"]))
116 | dbs.append("train_spider")
117 | test_train.append(1)
118 |
119 | # spider_dev.json
120 | for x in json.load(f5):
121 | data.append((x["question"], x["query"], x["db_id"]))
122 | dbs.append("test_spider")
123 | test_train.append(0)
124 |
125 | # train_others.json ======>>> SPIDER FOLDER
126 | for x in json.load(f1):
127 | data.append((x["question"], x["query"], x["db_id"]))
128 | dbs.append("train_others")
129 | test_train.append(1)
130 |
131 | # ========= SPARC ========= #
132 | # sparc/train.json
133 | for x in json.load(f3):
134 | data.append((x["final"]["utterance"], x["final"]["query"], x["database_id"]))
135 | dbs.append("train_sparc")
136 | test_train.append(1)
137 |
138 | # SPARC_DEV.json
139 | for x in json.load(f6):
140 | data.append((x["final"]["utterance"], x["final"]["query"], x["database_id"]))
141 | dbs.append("test_sparc")
142 | test_train.append(0)
143 |
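# note: the per-source test_train flags collected in this block are only
# indicative -- the final train/test assignment is recomputed further down
# by re-splitting on db_id (df.train is overwritten there)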
144 | # ========= COSQL ========= #
145 | # cosql_all_info_dialogs.json
146 | for x,y in json.load(f4).items():
147 | data.append((y["query_goal"], y["sql"], y["db_id"]))
148 | dbs.append("cosql_all")
149 | test_train.append(1)
150 |
151 |
152 | dataset_f = []
153 | cols = ["question", "query", "db_id", "source", "train"]
154 | for d, db, tt in zip(data, dbs, test_train):
155 | # sample = (d[0], format_sql(d[1]), d[2], db)
156 | try:
157 | s = format_sql(d[1])
158 | if "T1" in s or "t1" in s:
159 | print("((", d[1].lower())
160 | dataset_f.append((d[0], s, d[2], db, tt))
161 | except Exception:
162 | print("))", d[1])
163 |
164 | # create dataframe
165 | df = pd.DataFrame(data=dataset_f, columns=cols)
166 |
167 | # train/test split by DB ID not by authors
168 | all_dbs = list(set(df.db_id.values))
169 | train_dbs = set(np.random.choice(all_dbs, size = int(0.9 * len(all_dbs)), replace = False).tolist())
170 | test_dbs = set([x for x in all_dbs if x not in train_dbs])
171 | assert len(train_dbs | test_dbs) == len(all_dbs)
172 | assert len(train_dbs & test_dbs) == 0
173 |
174 | train_idx = set(i for i, db_id in enumerate(df.db_id.values) if db_id in train_dbs)
175 | train = []
176 | for i in range(len(df)):
177 | if i in train_idx:
178 | train.append(1)
179 | else:
180 | train.append(0)
181 |
182 | assert len(train) == len(df)
183 | df.train = train
184 |
185 | df.to_csv(os.path.join(args.data_folder, "all_questions.tsv"), sep="\t", index = False)
186 | print(f"Saved dataset at: {os.path.join(args.data_folder, 'all_questions.tsv')}")
187 |
188 | # creating a single dump of all the tables
189 | all_schema = {}
190 | with open(SPARC_TABLES, "r") as f1,\
191 | open(SPIDER_TABLES, "r") as f2,\
192 | open(COSQL_TABLES, "r") as f3:
193 |
194 | data = json.load(f1) # now load the sparc tables
195 | for x in data:
196 | all_schema[x.pop("db_id")] = x
197 |
198 | data = json.load(f2)
199 | for x in data:
200 | all_schema[x.pop("db_id")] = x
201 |
202 | data = json.load(f3)
203 | for x in data:
204 | all_schema[x.pop("db_id")] = x
205 |
206 | with open(os.path.join(args.data_folder,"all_schema.json"), "w") as f:
207 | f.write(json.dumps(all_schema))
208 |
209 | print(f"Found {len(all_schema)} schemas")
210 |
211 | all_strings = []
212 | for _id in all_schema:
213 | all_strings.append(get_db_graph_string(parse_db_to_networkx(all_schema[_id])))
214 |
215 | with open(os.path.join(args.data_folder, "all_sentences.txt"), "w") as f:
216 | f.write("\n".join(df["query"].unique().tolist() +
217 | df["question"].unique().tolist() +
218 | all_strings))
219 |
220 | sentencepiece.SentencePieceTrainer.train(f'''--input={os.path.join(args.data_folder, "all_sentences.txt")}\
221 | --model_prefix=m --vocab_size={args.vocab_size} --pad_id=1 --pad_piece=[PAD]\
222 | --bos_id=2 --bos_piece=[BOS] --eos_id=3 --eos_piece=[EOS]\
223 | --unk_id=4 --unk_piece=[UNK] --model_type=word''')
224 |
225 |
226 | # """
227 | # Below this was the old language modelling method which was a bad idea due to compute
228 | # requirements. Instead we now use a better system.
229 | # """ 230 | 231 | # with open(args.pairs, "w") as f: 232 | # print(f"🕰 Saving Training pairs dataset at: {args.pairs}") 233 | # s = "question\tquery\tdb_id\n" 234 | # for x in data: 235 | # x = list(map(lambda s: re.sub("\s+", " ", s), x)) 236 | # s += "\t".join(x) + "\n" 237 | # f.write(s) 238 | 239 | # # ---------------- CREATE PAIRS (DEV) ---------------- # 240 | # data = [] 241 | # with open(SPIDER_DEV) as f1, open(SPARC_DEV) as f2: 242 | # # train_others.json 243 | # for x in json.load(f1): 244 | # data.append((x["question"], x["query"], x["db_id"])) 245 | 246 | # # sparc/train.json 247 | # for x in json.load(f2): 248 | # data.append((x["final"]["utterance"], x["final"] 249 | # ["query"], x["database_id"])) 250 | 251 | # with open(args.dev_pairs, "w") as f: 252 | # print(f"🕰 Saving Dev. pairs dataset at: {args.dev_pairs}") 253 | # s = "question\tquery\tdb_id\n" 254 | # for x in data: 255 | # x = list(map(lambda s: re.sub("\s+", " ", s), x)) 256 | # s += "\t".join(x) + "\n" 257 | # f.write(s) 258 | 259 | # # ---------------- CREATE TABLES ---------------- # 260 | # table_date = [] 261 | # with open(SPIDER_TABLES) as f1, open(SPARC_TABLES) as f2, open(COSQL_TABLES) as f3: 262 | # table_date.extend(json.load(f1)) # spider/tables.json 263 | # table_date.extend(json.load(f2)) # sparc/tables.json 264 | # table_date.extend(json.load(f3)) # cosql_dataset/tables.json 265 | 266 | # table_strings = [] 267 | # for didx, d in enumerate(table_date): 268 | # fkeys_list = [[] for _ in range(len(d["column_names_original"]))] 269 | # for i, col in enumerate(d["column_names_original"]): 270 | # keys_connected_to_this_col = deepcopy(list(filter( 271 | # lambda f: i in f, d["foreign_keys"] 272 | # ))) 273 | # if not keys_connected_to_this_col: 274 | # continue 275 | # con = [] 276 | # for k in keys_connected_to_this_col: 277 | # k = [j for j in k if j != i] 278 | # con.append(k[0]) 279 | # fkeys_list[i].extend(con) 280 | 281 | # primary_keys = [0 for _ in range(len(d["column_names_original"]))] 282 | # for i in d["primary_keys"]: 283 | # primary_keys[i] = 1 284 | # cols = [(*x, d["column_types"][i], primary_keys[i], *fkeys_list[i]) 285 | # for i, x in enumerate(d["column_names_original"])] 286 | # tables = list(set([x[0] for x in d["column_names_original"]])) 287 | # agg_ = [list(filter( 288 | # lambda x: x[0] == tid, cols 289 | # )) for tid in tables] 290 | 291 | # string = "" 292 | # for x in agg_: 293 | # s = [] 294 | # for y in x[:-1]: 295 | # y = list(map(str, y)) 296 | # s.append("[col] " + " ".join(y[1:])) 297 | # string += " [table] " + " ".join(s) 298 | 299 | # s = f"{didx}\t{d['db_id']}\t{string.strip()}" 300 | # table_strings.append(s) 301 | 302 | # with open(args.tables, "w") as f: 303 | # print(f"🕰 Saving tables at: {args.pairs}") 304 | # s = "id\ttable_name\tstring\n" 305 | # s += '\n'.join(table_strings) 306 | # f.write(s) 307 | 308 | # # ---------------- CREATE LM CORPUS ---------------- # 309 | # # first get a mapping like {: } 310 | # with open(args.tables) as f: 311 | # t = [x.strip() for x in f.readlines()] 312 | 313 | # table_strs = {} 314 | # for item in t[1:]: 315 | # _, db_name, table_string = item.split("\t") 316 | # table_strs[db_name] = table_string 317 | 318 | # # now get all the question-query pairs 319 | # with open(args.pairs) as f: 320 | # p = [x.strip() for x in f.readlines()] 321 | 322 | # triplets = [] 323 | # for item in p[1:]: 324 | # question, query, db_name = item.split("\t") 325 | # tstr = table_strs[db_name] 326 | # triplets.append(f"{tstr} [question] 
{question} [query] {query}")
327 |
328 | # with open(args.lm_corpus, "w") as f:
329 | # print(f"🕰 Saving LM Corpus at {args.lm_corpus}")
330 | # f.write("\n".join(triplets))
331 |
332 | # # make the tokenizer if needed
333 | # if args.fresh_tokenizer:
334 | # with open(args.tables, "r") as t, open(args.pairs, "r") as p, open(args.dev_pairs, "r") as d:
335 | # table_strings = [x.split("\t")[-1].strip() for x in t.readlines()[1:]]
336 | # pair_strings = []
337 | # for x in p.readlines()[1:]:
338 | # x = x.split("\t")[:-1]
339 | # pair_strings.extend((x[0].strip(), x[1].strip()))
340 | # dev_strings = []
341 | # for x in d.readlines()[1:]:
342 | # x = x.split("\t")[:-1]
343 | # dev_strings.extend((x[0].strip(), x[1].strip()))
344 | # final = table_strings + pair_strings + dev_strings
345 |
346 | # with open(args.corpus, "w") as c:
347 | # print(f"🕰 Saving Tokenizer Corpus at {args.corpus}")
348 | # c.write("\n".join(final))
349 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 |
2 | from setuptools import find_packages, setup
3 |
4 | setup(
5 | name="text2sql",
6 | version="0.1.1",
7 | author="Yash Bonde",
8 | author_email="bonde.yash97@gmail.com",
9 | description="Convert natural language questions to SQL queries",
10 | long_description=open("README.md", "r", encoding="utf-8").read(),
11 | long_description_content_type="text/markdown",
12 | url="https://github.com/yashbonde/text2sql",
13 | package_dir={"": "src"},
14 | packages=find_packages("src"),
15 |
16 | # adding requirements here ensures that we don't go and start compiling torch :P
17 | install_requires=[
18 | "numpy",
19 | # "tokenizers == 0.8.1.rc2", # use the best tokenizers
20 | "tqdm >= 4.27", # progress bars in model download and training scripts
21 | "torch", # optimizer library
22 | "transformers", # huggingface transformer's package
23 | "tensorboard", # tensorboard supported
24 | "sentencepiece" # tokeniser, probably already installed
25 | ]
26 | )
--------------------------------------------------------------------------------
/src/text2sql/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yashbonde/text2sql/a8202c2bb9c6cd845674492d900c13c07df6c69b/src/text2sql/__init__.py
--------------------------------------------------------------------------------
/src/text2sql/data.py:
--------------------------------------------------------------------------------
1 | """data management piece
2 | 22.09.2020 - @yashbonde"""
3 |
4 | import re
5 | import json
6 | import numpy as np
7 | import pandas as pd
8 | import networkx as nx
9 | from tabulate import tabulate
10 | import sentencepiece as spm
11 |
12 | import torch
13 | from torch.utils.data import Dataset
14 |
15 | # ====== Helper functions ======= #
16 | def parse_db_to_networkx(db):
17 | """convert the db to a networkx graph with proper attributes
18 |
19 | nodes have features:
20 | - id: {table}.{name}
21 | - primary: True/False
22 | - type: type
23 | edges have features:
24 | - foreign: True/False
25 | """
26 | columns = db["column_names"][1:] # ignore the leading "*" attribute
27 | table_names = db["table_names"]
28 | column_types = db["column_types"]
29 | foreign_keys = db["foreign_keys"]
30 | primary_keys = db["primary_keys"]
31 |
32 | if len(set([x[0] for x in columns])) != len(table_names):
33 | raise ValueError("Mismatch between tables referenced in column_names and table_names")
34 |
35 | # make graph
36 | g = nx.Graph()
37 |
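# indexing note: `columns` drops the leading "*" entry of column_names,
# so node i built below corresponds to column i+1 of the original schema;
# this is why the primary/foreign key indices are shifted by one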
38 | # add nodes and data
39 | for i, c in enumerate(columns):
40 | name = c[1].replace(" ", "_")
41 | table = table_names[c[0]]
42 | g.add_node(
43 | i, id = f"{table}.{name}", name = name, table = table,
44 | primary = (i + 1) in primary_keys,
45 | type = column_types[i]
46 | )
47 |
48 | # for edges, first the foreign keys because they are simpler
49 | for (s,t) in foreign_keys:
50 | g.add_edge(s-1, t-1, foreign = True)
51 |
52 | # then those within the table
53 | for i in range(len(table_names)):
54 | cols = list(filter(
55 | lambda c: c[1][0] == i, enumerate(columns)
56 | ))
57 | cols = [x[0] for x in cols]
58 | for j, c in enumerate(cols):
59 | for cc in cols[j+1:]:
60 | g.add_edge(c, cc, foreign = False)
61 | return g
62 |
63 |
64 | def get_db_attention_mask(g, size, device = "cpu", inf = 1e6):
65 | A = nx.adjacency_matrix(g).todense()
66 | A = 1 - (A + np.eye(len(A))) # add self loops, then flip so 1 marks disconnected pairs
67 | if size == -1:
68 | m = A
69 | else:
70 | m = np.zeros((size, size))
71 | m[:len(A), :len(A)] = A # copy into the padded mask
72 | m = m * -inf # large negative at masked positions, consistent with get_tokenised_attention_mask()
73 | return m, len(A)
74 |
75 |
76 | def get_tokenised_attention_mask(g, t, size = None, inf = 1e6):
77 | """In method get_db_attention_mask() we do not consider that the tokens
78 | will have a subword splitting and so the final attention mask will look
79 | a bit different. This takes care of that by creating the mask on subwords
80 | as well.
81 | :param g: graph
82 | :param t: sentencepiece tokenizer
83 | :param size: dimension of the output attention_mask
84 | :param inf: what will be the negative infinity value
85 |
86 | NOTE: 11th November, 2020 @yashbonde don't do any more complicated
87 | engineering, what you need is more compute not more engineering.
88 | """
89 | # att = get_db_attention_mask(g, size = -1)
90 | ts = []
91 | sizes = []
92 | for x in g.nodes().data():
93 | # encode the table and column names together for this node
94 | tokens = t.encode(" ".join([x[1].get("table"), x[1].get("name")]))
95 | sizes.append(len(tokens))
96 | ts.extend(tokens)
97 |
98 | # get adjacency matrix
99 | mat = nx.adjacency_matrix(g).todense()
100 | mat = mat + np.eye(len(mat)) # add self loops
101 | mat = mat.tolist()
102 |
103 | # now the code to expand the matrix in place
104 | tmat = np.zeros((sum(sizes),sum(sizes)))
105 | tid = 0
106 | for i in range(len(mat)):
107 | idx = np.arange(len(mat))[np.asarray(mat[i]) == 1]
108 | for s in range(sizes[i]):
109 | for j in idx:
110 | start = sum(sizes[:j])
111 | end = sum(sizes[:j+1])
112 | tmat[tid, start:end] = 1
113 | tid += 1
114 | tmat = tmat + tmat.T
115 | tmat[tmat > 1] = 1
116 | tmat = tmat.astype(int)
117 |
118 | if size is not None:
119 | # convert to required shapes and put in masking values
120 | fmat = np.zeros((size, size)).astype(int)
121 | if size < len(tmat):
122 | tmat = tmat[:size, :size]
123 | fmat[:tmat.shape[0], :tmat.shape[0]] = tmat
124 | fmat = 1 - fmat
125 | fmat = fmat * -inf
126 | else:
127 | fmat = tmat
128 |
129 | return fmat, ts, sum(sizes)
130 |
131 |
132 | def format_sql(in_str):
133 | in_str = in_str.lower()
134 | for p in re.findall(r"\w+\s+as\s+t\d+", in_str):
135 | # print(p)
136 | try:
137 | table, id = [x.strip() for x in p.split("as")]
138 | except ValueError:
139 | table, id = [x.strip() for x in p.split(" as ")]
140 | # replace phrase that contains " AS "
141 | in_str = in_str.replace(p, table)
142 |
143 | # replace the table
144 | in_str = in_str.replace(id, table)
145 | in_str = in_str.replace(id.lower(), table)
146 | in_str = in_str.replace(id.upper(), table)
147 |
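# at this point the alias loop above has, for a hypothetical input like
# 'select t1.name from singer as t1', produced 'select singer.name from singer'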
148 | # basic cleaning
149 | in_str = re.sub(r"\s+", " ", in_str)
150 | in_str = re.sub(r"\"+", '"', in_str)
151 | return in_str
152 |
153 |
154 | # ====== Main Class Object ====== #
155 | class T2SDataset(Dataset):
156 | def __init__(self, config, mode, method = "m1"):
157 | if method not in ["m1", "m2"]:
158 | raise ValueError("Invalid method, please check documentation")
159 |
160 | self.config = config
161 | self.method = method
162 |
163 | with open(config.schema_file) as f:
164 | self.schemas = {k:parse_db_to_networkx(v) for k,v in json.load(f).items()}
165 |
166 | with open(config.questions_file) as f:
167 | df = pd.read_csv(f, sep="\t")
168 | mode = 1 if mode == "train" else 0
169 | df = df[df.train == mode]
170 |
171 | self.questions = df.question.values
172 | self.queries = df["query"].values # I didn't know .query was a function
173 | self.db_ids = df.db_id.values
174 |
175 | def __len__(self):
176 | return len(self.questions)
177 |
178 | def _get_question(self, index, config, t):
179 | # prepare the questions
180 | question = self.questions[index]
181 | question = [t.bos_id()] + t.encode(question) + [t.eos_id()]
182 | sent_len = len(question)
183 | if config.maxlen > len(question):
184 | question = question + [t.pad_id() for _ in range(config.maxlen - len(question))]
185 | else:
186 | question = question[:config.maxlen]
187 | sent_attn = np.zeros((config.maxlen, config.maxlen)).astype(np.int32)
188 | sent_attn[:sent_len, :sent_len] = 1
189 | sent_attn = 1 - sent_attn
190 | sent_attn = sent_attn * -1e6
191 |
192 | return question, sent_attn
193 |
194 | def _get_sql(self, index, config, t):
195 | # prepare the sql query
196 | sql = self.queries[index]
197 | sql = [t.bos_id()] + t.encode(sql) + [t.eos_id()]
198 | sql_len = len(sql)
199 | if config.maxlen + 1 > len(sql): # pad up to maxlen + 1 tokens, since ids/labels drop one each
200 | sql = sql + [t.pad_id()
201 | for _ in range(config.maxlen - len(sql) + 1)]
202 | else:
203 | sql = sql[:config.maxlen + 1]
204 | sql_attn = np.zeros((config.maxlen, config.maxlen)).astype(np.int32)
205 | sql_attn[:sql_len, :sql_len] = 1
206 | sql_attn = sql_attn - np.triu(sql_attn, k=1) # causal masking
207 | sql_attn = 1 - sql_attn
208 | sql_attn = sql_attn * -1e6
209 |
210 | # create labels
211 | labels = torch.from_numpy(np.asarray(sql[1:])).long()
212 | labels[sql_len-1:] = -100 # minus 1 because already shifted
213 |
214 | # create input ids
215 | sql_ids = torch.from_numpy(np.asarray(sql[:-1])).long()
216 |
217 | return sql_ids, sql_attn, labels
218 |
219 | # === functions based on methods === #
220 | def _getitem_m1(self, index):
221 | config = self.config
222 | t = self.config.tokenizer
223 |
224 | # prepare the questions
225 | question, sent_attn = self._get_question(index, config, t)
226 |
227 | # get sql
228 | sql_ids, sql_attn, labels = self._get_sql(index, config, t)
229 |
230 | # prepare the DB sequence
231 | g = self.schemas[self.db_ids[index]]
232 | db_attn_mat, db_tokens, len_db = get_tokenised_attention_mask(g, t, size=config.maxlen)
233 | db_tokens = [t.bos_id()] + db_tokens + [t.eos_id()]
234 | if config.maxlen > len(db_tokens):
235 | db_tokens = db_tokens + [t.pad_id() for _ in range(config.maxlen - len(db_tokens))]
236 | else:
237 | db_tokens = db_tokens[:config.maxlen]
238 |
239 | # return the output dictionary
240 | return {
241 | "sql_ids": sql_ids,
242 | "labels": labels,
243 | "sent": torch.from_numpy(np.asarray(question)).long(),
244 | "db": torch.from_numpy(np.asarray(db_tokens)).long(),
245 | "sql_attn": torch.from_numpy(np.asarray([sql_attn]).astype(np.float32)),
246 |
"sent_attn": torch.from_numpy(np.asarray([sent_attn]).astype(np.float32)), 247 | "db_attn": torch.from_numpy(np.asarray([db_attn_mat]).astype(np.float32)) 248 | } 249 | 250 | def _getitem_m2(self, index): 251 | config = self.config 252 | t = self.config.tokenizer 253 | 254 | # prepare the questions 255 | question, sent_attn = self._get_question(index, config, t) 256 | 257 | # get sql 258 | sql_ids, sql_attn, labels = self._get_sql(index, config, t) 259 | 260 | # prepare the DB sequence 261 | g = self.schemas[self.db_ids[index]] 262 | db_attn_mat, len_db = get_db_attention_mask(g, size=-1) 263 | db_tokens = [] 264 | sizes = [] 265 | for x in g.nodes().data(): 266 | # we directly call the internal wordpiece_tokenizer, helps in debugging 267 | tokens = t.encode(" ".join([x[1].get("table"), x[1].get("name")])) 268 | sizes.append(len(tokens)) 269 | db_tokens.extend(tokens) 270 | 271 | # creating a merging_index list as [0,0,0,1,1,2,3,4,4,5,6,6,7,8,8,8,9,10,10,10] 272 | merging_index = [0] # bos 273 | for i,s in enumerate(sizes): 274 | merging_index.extend([i+1 for _ in range(s)]) 275 | merging_index = merging_index + [len(sizes) + 1] # eos 276 | 277 | if config.maxlen > len(merging_index): 278 | merging_index = merging_index + [len(sizes)+2 for _ in range(config.maxlen - len(merging_index))] 279 | else: 280 | merging_index = merging_index[:config.maxlen] 281 | 282 | # final tokens 283 | db_tokens = [t.bos_id()] + db_tokens + [t.eos_id()] 284 | if config.maxlen > len(db_tokens): 285 | db_tokens = db_tokens + [t.pad_id() for _ in range(config.maxlen - len(db_tokens))] 286 | else: 287 | db_tokens = db_tokens[:config.maxlen] 288 | 289 | print(len(db_tokens), len(merging_index)) 290 | print(db_tokens, merging_index) 291 | 292 | assert len(db_tokens) == len(merging_index) 293 | 294 | return { 295 | "sql_ids": sql_ids, 296 | "labels": labels, 297 | "sent": torch.from_numpy(np.asarray(question)).long(), 298 | "db": torch.from_numpy(np.asarray(db_tokens)).long(), 299 | "merging_index": torch.from_numpy(np.asarray(merging_index)).long(), 300 | 301 | "sql_attn": torch.from_numpy(np.asarray([sql_attn]).astype(np.float32)), 302 | "sent_attn": torch.from_numpy(np.asarray([sent_attn]).astype(np.float32)), 303 | "db_attn": torch.from_numpy(np.asarray([db_attn_mat]).astype(np.float32)), 304 | } 305 | 306 | def __getitem__(self, index): 307 | if self.method == "m1": 308 | return self._getitem_m1(index) 309 | elif self.method == "m2": 310 | return self._getitem_m2(index) 311 | 312 | 313 | class T2SDatasetConfig: 314 | schema_file = None # json file with schema dump 315 | questions_file = None # TSV file with questions-sql dump 316 | maxlen = 150 # maximum length for all is same for simplicity 317 | # also same size helps fit in the encoder mask as well as the 318 | # cross attention mask 319 | tokenizer_path = None 320 | 321 | maxlen_db = 1900 # maximum length of DB string to support 322 | 323 | def __init__(self, **kwargs): 324 | self.attrs = ["schema_file", "questions_file", "maxlen"] 325 | for k, v in kwargs.items(): 326 | setattr(self, k, v) 327 | self.attrs.append(k) 328 | self.tokenizer = spm.SentencePieceProcessor() 329 | self.tokenizer.load(self.tokenizer_path) 330 | 331 | print(f"Loaded tokenizer from: {self.tokenizer_path}. 
(vocab_size: {self.tokenizer.vocab_size()})") 332 | 333 | def __repr__(self): 334 | kvs = [(k, f"{getattr(self, k)}") for k in sorted(list(set(self.attrs)))] 335 | return tabulate(kvs, ["argument", "value"], tablefmt="psql") 336 | 337 | if __name__ == "__main__": 338 | config = T2SDatasetConfig( 339 | schema_file="/Users/yashbonde/Desktop/AI/text2sql/fdata/all_schema.json", 340 | questions_file="/Users/yashbonde/Desktop/AI/text2sql/fdata/all_questions.tsv", 341 | maxlen = 150, 342 | tokenizer_path="/Users/yashbonde/Desktop/AI/text2sql/data/model.model", 343 | ) 344 | print(config) 345 | ds = T2SDataset(config=config, mode="train", method = "m2") 346 | 347 | print(ds[123]) 348 | -------------------------------------------------------------------------------- /src/text2sql/generation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team. 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import re 17 | import numpy as np 18 | 19 | import torch 20 | from torch import Tensor 21 | from torch.nn import functional as F 22 | 23 | # from prepare_data import START_TOKEN, END_TOKEN, PAD_TOKEN, VOCAB, VOCAB_TOKENS, ROUND, prepare_expr_string 24 | # from maths import Math 25 | 26 | 27 | # ------ class ------ # 28 | class BeamHypotheses(object): 29 | def __init__(self, beam_size, max_length, length_penalty, early_stopping): 30 | """Initialize n-best list of hypotheses.""" 31 | self.max_length = max_length - 1 # ignoring bos_token 32 | self.length_penalty = length_penalty 33 | self.early_stopping = early_stopping 34 | self.beam_size = beam_size 35 | self.beams = [] 36 | self.worst_score = 1e9 37 | 38 | def __len__(self): 39 | """Number of hypotheses in the list.""" 40 | return len(self.beams) 41 | 42 | def add(self, hyp, sum_logprobs): 43 | """Add a new hypothesis to the list.""" 44 | score = sum_logprobs / len(hyp) ** self.length_penalty 45 | if len(self) < self.beam_size or score > self.worst_score: 46 | self.beams.append((score, hyp)) 47 | if len(self) > self.beam_size: 48 | sorted_scores = sorted( 49 | [(s, idx) for idx, (s, _) in enumerate(self.beams)]) 50 | del self.beams[sorted_scores[0][1]] 51 | self.worst_score = sorted_scores[1][0] 52 | else: 53 | self.worst_score = min(score, self.worst_score) 54 | 55 | def is_done(self, best_sum_logprobs, cur_len): 56 | """If there are enough hypotheses and that none of the hypotheses being generated 57 | can become better than the worst one in the heap, then we are done with this sentence.""" 58 | if len(self) < self.beam_size: 59 | return False 60 | elif self.early_stopping: 61 | return True 62 | else: 63 | cur_score = best_sum_logprobs / cur_len ** self.length_penalty 64 | ret = self.worst_score >= cur_score 65 | return ret 66 | 67 | 68 | def top_k_top_p_filtering( 69 | logits: Tensor, 70 | top_k: int 
= 0,
71 | top_p: float = 1.0,
72 | filter_value: float = -1e10,
73 | min_tokens_to_keep: int = 1,
74 | ) -> Tensor:
75 | """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
76 | Args:
77 | logits: logits distribution shape (batch size, vocabulary size)
78 | if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
79 | if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
80 | Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
81 | Make sure we keep at least min_tokens_to_keep per batch example in the output
82 | From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
83 | """
84 | if top_k > 0:
85 | top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
86 | # Remove all tokens with a probability less than the last token of the top-k
87 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
88 | logits[indices_to_remove] = filter_value
89 |
90 | if top_p < 1.0:
91 | sorted_logits, sorted_indices = torch.sort(logits, descending=True)
92 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
93 |
94 | # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
95 | sorted_indices_to_remove = cumulative_probs > top_p
96 | if min_tokens_to_keep > 1:
97 | # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
98 | sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
99 | # Shift the indices to the right to keep also the first token above the threshold
100 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
101 | sorted_indices_to_remove[..., 0] = 0
102 |
103 | # scatter sorted tensors to original indexing
104 | indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
105 | logits[indices_to_remove] = filter_value
106 | return logits
107 |
108 |
109 | def beam_search(
110 | model,
111 | obs,
112 | beam_size,
113 | max_length,
114 | min_length,
115 | tokenizer, batch_size = 1,
116 | bos_id = 2, eos_id = 3, pad_id = 1, # id defaults follow the sentencepiece flags set in parse_to_lm.py
117 | vocab_size = None, # inferred from the model logits if not given
118 | input_str = None,
119 | do_sample = True, early_stopping = False,
120 | temperature = 1.0, top_k = 10,
121 | top_p = 0.9,
122 | repetition_penalty = 1.4,
123 | length_penalty = 1
124 | ):
125 | """Hacker version, originally from huggingface generation utils
126 | https://github.com/huggingface/transformers/blob/master/src/transformers/generation_utils.py
127 | """
128 |
129 | assert temperature > 0, "`temperature` should be strictly positive."
130 | assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
131 | assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
132 | assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
133 | assert length_penalty > 0, "`length_penalty` should be strictly positive."
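# layout sketch of the bookkeeping below: every tensor carries
# batch_size * beam_size rows, row (i * beam_size + j) being beam j of
# batch element i; generated_hyps keeps one BeamHypotheses heap per
# batch element and holds the finished candidates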
134 |
135 | # create the first inputs for model
136 |
137 | if input_str is None:
138 | input_ids = torch.ones((batch_size * beam_size, 1)).long() * bos_id
139 | else:
140 | seq = ([bos_id] + tokenizer.encode(input_str))[:max_length]
141 | seq = [seq,]*(batch_size * beam_size)
142 | input_ids = torch.from_numpy(np.asarray(seq)).long()
143 |
144 | cur_len = input_ids.size(1)
145 | attention_mask = torch.ones((batch_size * beam_size, cur_len)).long()
146 | enc_out = model.enc_out(**obs)[0] # get the encoder output once, it is reused every step
147 |
148 | # generated hypotheses
149 | generated_hyps = [
150 | BeamHypotheses(beam_size, max_length, length_penalty, early_stopping=early_stopping)
151 | for _ in range(batch_size)
152 | ]
153 |
154 | # scores for each sentence in the beam
155 | beam_scores = torch.zeros(
156 | (batch_size, beam_size),
157 | dtype=torch.float, device=input_ids.device
158 | )
159 |
160 | # for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
161 | if do_sample is False:
162 | beam_scores[:, 1:] = -1e9
163 | beam_scores = beam_scores.view(-1) # shape (batch_size * beam_size,)
164 |
165 | # done sentences
166 | done = [False for _ in range(batch_size)]
167 |
168 | while cur_len < max_length:
169 | # (batch_size * beam_size, cur_len, vocab_size)
170 | logits = model.dec_out(enc_out, input_ids, attention_mask, verbose=False)[0]
171 | # (batch_size * beam_size, vocab_size)
172 | next_token_logits = logits[:, -1, :]
173 | if vocab_size is None: vocab_size = next_token_logits.size(-1) # infer vocab_size on the first step
174 | # (batch_size * beam_size, vocab_size)
175 | scores = F.log_softmax(next_token_logits, dim=-1)
176 |
177 | #L667 --- #L62 postprocess_next_token_scores()
178 | if repetition_penalty != 1.0:
179 | for i in range(batch_size * beam_size):
180 | for previous_token in set(input_ids[i].tolist()):
181 | # if score < 0 then the repetition penalty has to be multiplied in to reduce the previous token probability
182 | penalty = repetition_penalty if scores[i, previous_token] < 0 else (1 / repetition_penalty)
183 | scores[i, previous_token] *= penalty
184 |
185 | # set eos token prob to zero if min_length is not reached
186 | if eos_id is not None and cur_len < min_length:
187 | scores[:, eos_id] = -1e10 # -float("inf") causes nan issues
188 |
189 | #L108 --- #L680
190 | assert scores.shape == (batch_size * beam_size, vocab_size), f"Shapes of scores: {scores.shape} != {(batch_size * beam_size, vocab_size)}"
191 |
192 | if do_sample:
193 | # (batch_size * beam_size, vocab_size)
194 | _scores = scores + beam_scores[:, None].expand_as(scores)
195 | # Temperature
196 | if temperature != 1.0:
197 | _scores = _scores / temperature
198 | # Top-p/top-k filtering
199 | _scores = top_k_top_p_filtering(_scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2) # (batch_size * beam_size, vocab_size)
200 | # re-organize to group the beam together to sample from all beam_idxs
201 | _scores = _scores.contiguous().view(batch_size, beam_size * vocab_size) # (batch_size, beam_size * vocab_size)
202 |
203 | # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
204 | probs = F.softmax(_scores, dim=-1)
205 | # (batch_size, beam_size * 2)
206 | next_tokens = torch.multinomial(probs, num_samples=2 * beam_size)
207 | # Compute next scores
208 | # (batch_size, beam_size * 2)
209 | next_scores = torch.gather(_scores, -1, next_tokens)
210 | # sort the sampled vector to make sure that the first beam_size samples are the best
211 | next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
212 | # (batch_size, beam_size * 2)
213 | next_tokens = torch.gather(next_tokens, -1, next_scores_indices)
214 |
215 | else:
216 | # (batch_size * beam_size, vocab_size)
217 | next_scores = scores + beam_scores[:, None].expand_as(scores)
218 |
219 | # re-organize to group the beam together (we are keeping top hypothesis across beams)
220 | next_scores = next_scores.view(batch_size, beam_size * vocab_size) # (batch_size, beam_size * vocab_size)
221 |
222 | next_scores, next_tokens = torch.topk(next_scores, 2 * beam_size, dim=1, largest=True, sorted=True)
223 |
224 | assert next_scores.size() == next_tokens.size() == (batch_size, 2 * beam_size)
225 |
226 | # next batch beam content
227 | next_batch_beam = []
228 |
229 | # for each sentence
230 | for batch_idx in range(batch_size):
231 |
232 | # if we are done with this sentence, add a pad token
233 | if done[batch_idx]:
234 | assert (
235 | len(generated_hyps[batch_idx]) >= beam_size
236 | ), "Batch can only be done if at least {} beams have been generated".format(beam_size)
237 | assert (
238 | eos_id is not None and pad_id is not None
239 | ), "generated beams >= beam_size -> eos_id and pad_token have to be defined"
240 | next_batch_beam.extend(
241 | [(0, pad_id, 0)] * beam_size) # pad the batch
242 | continue
243 |
244 | # next sentence beam content, this will get added to next_batch_beam
245 | next_sent_beam = []
246 |
247 | # next tokens for this sentence
248 | for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
249 | zip(next_tokens[batch_idx], next_scores[batch_idx])
250 | ):
251 | # get beam and token IDs
252 | beam_id = beam_token_id // vocab_size
253 | token_id = beam_token_id % vocab_size
254 |
255 | effective_beam_id = batch_idx * beam_size + beam_id
256 | # add to generated hypotheses if end of sentence
257 | if (eos_id is not None) and (token_id.item() == eos_id):
258 | # if beam_token does not belong to top beam_size tokens, it should not be added
259 | is_beam_token_worse_than_top_beam_size = beam_token_rank >= beam_size
260 | if is_beam_token_worse_than_top_beam_size:
261 | continue
262 | generated_hyps[batch_idx].add(
263 | input_ids[effective_beam_id].clone(),
264 | beam_token_score.item(),
265 | )
266 | else:
267 | # add next predicted token since it is not eos_token
268 | next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
269 |
270 | # once the beam for next step is full, don't add more tokens to it.
271 | if len(next_sent_beam) == beam_size:
272 | break
273 |
274 | # Check if we are done so that we can save a pad step if all(done)
275 | done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
276 | next_scores[batch_idx].max().item(), cur_len
277 | )
278 |
279 | # update next beam content
280 | assert len(next_sent_beam) == beam_size, "Beam should always be full"
281 | next_batch_beam.extend(next_sent_beam)
282 | assert len(next_batch_beam) == beam_size * (batch_idx + 1), "We should have added beam_size each step"
283 |
284 | # stop when we are done with each sentence
285 | if all(done):
286 | break
287 |
288 | # sanity check / prepare next batch
289 | assert len(next_batch_beam) == batch_size * beam_size
290 | beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
291 | beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
292 | beam_idx = input_ids.new([x[2] for x in next_batch_beam])
293 |
294 | # re-order batch and update current length
295 | input_ids = input_ids[beam_idx, :]
296 | input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
297 | cur_len = cur_len + 1
298 |
299 | # extend attention_mask for new generated input if only decoder
300 | attention_mask = torch.cat(
301 | [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
302 | )
303 |
304 | # finalize all open beam hypotheses and add to generated hypotheses
305 | for batch_idx in range(batch_size):
306 | if done[batch_idx]:
307 | continue
308 |
309 | # test that beam scores match previously calculated scores if not eos and batch_idx not done
310 | if eos_id is not None and all(
311 | (token_id % vocab_size).item() != eos_id for token_id in next_tokens[batch_idx]
312 | ):
313 | assert torch.all(
314 | next_scores[batch_idx, :beam_size] == beam_scores.view(batch_size, beam_size)[batch_idx]
315 | ), (f"If batch_idx is not done, final next scores: {next_scores[:, :beam_size][batch_idx]}"
316 | f" have to equal to accumulated beam_scores: {beam_scores.view(batch_size, beam_size)[batch_idx]}"
317 | )
318 |
319 | # need to add best beam_size hypotheses to generated hyps
320 | for beam_id in range(beam_size):
321 | effective_beam_id = batch_idx * beam_size + beam_id
322 | final_score = beam_scores[effective_beam_id].item()
323 | final_tokens = input_ids[effective_beam_id]
324 | generated_hyps[batch_idx].add(final_tokens, final_score)
325 |
326 | strs = [tokenizer.decode(x[1].tolist()) for x in generated_hyps[0].beams] # decode beam ids back to text
327 | scores = [x[0] for x in generated_hyps[0].beams]
328 | return strs, scores
329 |
330 |
331 | @torch.no_grad()
332 | def predict_expression(
333 | model,
334 | obs,
335 | beam_size,
336 | max_length,
337 | min_length,
338 | tokenizer=None, input_str=None,
339 | bos_id=2, # id defaults follow the sentencepiece flags set in parse_to_lm.py
340 | eos_id=3,
341 | pad_id=1,
342 | vocab_size=None,
343 | do_sample=True,
344 | early_stopping=False,
345 | temperature=1.0,
346 | top_k=10,
347 | top_p=0.9,
348 | repetition_penalty=1.4,
349 | length_penalty=1
350 | ):
351 | """wrapper for beam search
352 |
353 | :param model: nn.Module object that is the model
354 | :param obs: output encoder dict from the dataset
355 | :param beam_size: Beam size to search on
356 | :param max_length: Maximum sequence length to generate
357 | :param min_length: Minimum sequence length to generate
358 | :param input_str: If user has already given some input
359 | :param bos_id: BOS ID
360 | :param eos_id: EOS ID
361 | :param pad_id: PAD ID
362 | :param vocab_size: Vocabulary size
363 | :param do_sample: To perform
sampling or not
364 | :param early_stopping: Whether to stop the beam search when at least num_beams sentences are finished per batch or not
365 | :param temperature: The value used to modulate the next token probabilities
366 | :param top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering
367 | :param top_p: If set to float < 1, only the most probable tokens with probabilities that add up to top_p or
368 | higher are kept for generation
369 | :param repetition_penalty: The parameter for repetition penalty. 1.0 means no penalty. See `this paper
370 | <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
371 | :param length_penalty: Exponential penalty to the length. 1.0 means no penalty.
372 | Set to values < 1.0 in order to encourage the model to generate shorter sequences, to a value > 1.0 in
373 | order to encourage the model to produce longer sequences.
374 | """
375 | return beam_search(
376 | model=model,
377 | obs=obs,
378 | beam_size=beam_size,
379 | max_length=max_length,
380 | min_length=min_length,
381 | tokenizer=tokenizer, input_str=input_str,
382 | bos_id=bos_id,
383 | eos_id=eos_id,
384 | pad_id=pad_id,
385 | vocab_size=vocab_size,
386 | do_sample=do_sample,
387 | early_stopping=early_stopping,
388 | temperature=temperature,
389 | top_k=top_k,
390 | top_p=top_p,
391 | repetition_penalty=repetition_penalty,
392 | length_penalty=length_penalty
393 | )
394 |
--------------------------------------------------------------------------------
/src/text2sql/graph_network.py:
--------------------------------------------------------------------------------
1 | """
2 | GNN using TransformerConv
3 | """
4 | # https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html#implementing-the-gcn-layer
5 | import math
6 | import numpy as np
7 | import networkx as nx
8 |
9 | from types import SimpleNamespace
10 |
11 | import torch
12 | from torch import nn
13 |
14 | import torch.nn.functional as F
15 | from torch.nn import Linear
16 | from torch_geometric.nn.conv import MessagePassing
17 | from torch_geometric.utils import softmax, from_networkx
18 |
19 |
20 | # set seeds for reproducibility
21 | np.random.seed(4)
22 | torch.manual_seed(4)
23 |
24 | NODE_ARGS = 5
25 | NUM_NODES = 20
26 | P = 0.3
27 | SAMPLES = 100
28 |
29 | data = []
30 | node_fields = [f"i{i}" for i in range(NODE_ARGS)]
31 | while len(data) < SAMPLES:
32 | g = nx.binomial_graph(NUM_NODES, P)
33 | if nx.is_connected(g):
34 | g2 = nx.Graph()
35 | for n in list(g.nodes):
36 | g2.add_node(n, **{f: [np.random.random()] for f in node_fields})
37 | for e in list(g.edges):
38 | g2.add_edge(*e)
39 | data.append(from_networkx(g2))
40 |
41 |
42 |
43 | """Class Object of TransformerConv"""
44 | class TransformerConv(MessagePassing):
45 | # from https://arxiv.org/abs/2009.03509
46 | def __init__(self, config):
47 | super().__init__(node_dim=0); self.config = config # MessagePassing must be initialised
48 | self.in_channels = config.n_embd
49 | self.out_channels = config.n_embd
50 | self.heads = config.n_heads
51 | self.dropout = config.dropout
52 | self.edge_dim = config.edge_dim
53 | self.beta = config.graph_beta
54 |
55 | # in_channels = (config.n_embd, config.n_embd)
56 | # out_channels = in_channels = config.n_embd
57 |
58 | self.lin_key = nn.Linear(config.n_embd, config.n_heads * config.n_embd)
59 | self.lin_query = nn.Linear(config.n_embd, config.n_heads * config.n_embd)
60 | self.lin_value = nn.Linear(config.n_embd, config.n_heads * config.n_embd)
61 | self.lin_edge = nn.Linear(config.edge_dim, config.n_heads * config.n_embd, bias = False)
62 | self.lin_skip = nn.Linear(config.n_embd, config.n_heads *
64 |         self.reset_parameters()
65 | 
66 |     def reset_parameters(self):
67 |         self.lin_key.reset_parameters()
68 |         self.lin_query.reset_parameters()
69 |         self.lin_value.reset_parameters()
70 |         self.lin_edge.reset_parameters()
71 |         self.lin_skip.reset_parameters()
72 | 
73 |     def forward(self, x, edge_index, edge_attr):
74 |         x = (x, x)  # PyG convention: a (source, target) pair of node feature matrices
75 |         out = self.propagate(edge_index, x = x, edge_attr = edge_attr, size = None)
76 |         out = out.view(-1, self.heads * self.out_channels)  # always concat
77 |         x = self.lin_skip(x[1])
78 |         if self.beta is not None:
79 |             out = self.beta * x + (1 - self.beta) * out
80 |         else:
81 |             out = out + x
82 | 
83 |         return out
84 | 
85 |     def message(self, x_i, x_j, edge_attr, index, ptr, size_i):
86 |         # queries come from the target nodes (x_i), keys from the source nodes (x_j)
87 |         query = self.lin_query(x_i).view(-1, self.heads, self.out_channels)
88 |         key = self.lin_key(x_j).view(-1, self.heads, self.out_channels)
89 | 
90 |         if edge_attr is not None:
91 |             edge_attr = self.lin_edge(edge_attr).view(-1, self.heads, self.out_channels)
92 |             key = key + edge_attr
93 | 
94 |         alpha = (query * key).sum(dim = -1) / math.sqrt(self.out_channels)
95 |         alpha = softmax(alpha, index, ptr, size_i)  # segment-wise softmax over each node's neighbourhood
96 |         alpha = F.dropout(alpha, p = self.dropout, training = self.training)
97 | 
98 |         out = self.lin_value(x_j).view(-1, self.heads, self.out_channels)
99 |         if edge_attr is not None:
100 |             out = out + edge_attr
101 | 
102 |         out = out * alpha.view(-1, self.heads, 1)
103 |         return out
104 | 
105 |     def __repr__(self):
106 |         return f"{self.__class__.__name__} ({self.in_channels} {self.out_channels}, heads = {self.heads})"
107 | 
108 | 
109 | # configuration object
110 | config = SimpleNamespace(
111 |     n_embd = 16,
112 |     n_heads = 2,
113 |     dropout = 0.1,
114 |     edge_dim = 8,
115 |     graph_beta = 0.9
116 | )
117 | 
--------------------------------------------------------------------------------
/src/text2sql/model.py:
--------------------------------------------------------------------------------
 1 | """model for text2sql
 2 | 03.11.2020 - @yashbonde"""
 3 | 
 4 | import numpy as np
 5 | from tabulate import tabulate
 6 | from types import SimpleNamespace
 7 | 
 8 | import torch
 9 | import torch.nn as nn
10 | from torch.nn import functional as F
11 | from transformers.modeling_utils import find_pruneable_heads_and_indices, prune_conv1d_layer
12 | from transformers.activations import ACT2FN
13 | 
14 | # the code below is only a slightly modified version from huggingface.
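One note before the code: the `Conv1D` defined just below is GPT-2's 1D "convolution", which is simply a linear layer whose weight matrix is stored transposed. A quick equivalence check (editor's sketch, not part of the original file):

```python
import torch
from torch import nn

# Conv1D(nf, nx) computes x @ W + b with W of shape (nx, nf);
# nn.Linear(nx, nf) stores its weight as (nf, nx), i.e. the transpose
conv = Conv1D(nf=8, nx=4)
lin = nn.Linear(4, 8)
with torch.no_grad():
    lin.weight.copy_(conv.weight.t())
    lin.bias.copy_(conv.bias)

x = torch.randn(2, 3, 4)
assert torch.allclose(conv(x), lin(x), atol=1e-6)
```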
15 | class Conv1D(nn.Module): 16 | def __init__(self, nf, nx): 17 | super().__init__() 18 | self.nf = nf 19 | w = torch.empty(nx, nf) 20 | nn.init.normal_(w, std=0.02) 21 | self.weight = nn.Parameter(w) 22 | self.bias = nn.Parameter(torch.zeros(nf)) 23 | 24 | def forward(self, x): 25 | size_out = x.size()[:-1] + (self.nf,) 26 | x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) 27 | x = x.view(*size_out) 28 | return x 29 | 30 | 31 | class Attention(nn.Module): 32 | def __init__(self, nx, n_ctx, config, scale=False, is_cross_attention=False): 33 | super().__init__() 34 | 35 | n_state = nx # in Attention: n_state=768 (nx=n_embd) 36 | # [switch nx => n_state from Block to Attention to keep identical to TF implem] 37 | assert n_state % config.n_head == 0 38 | 39 | self.n_head = config.n_head 40 | self.split_size = n_state 41 | self.scale = scale 42 | self.is_cross_attention = is_cross_attention 43 | if self.is_cross_attention: 44 | self.c_attn = Conv1D(2 * n_state, nx) 45 | self.q_attn = Conv1D(n_state, nx) 46 | else: 47 | self.c_attn = Conv1D(3 * n_state, nx) 48 | self.c_proj = Conv1D(n_state, nx) 49 | self.attn_dropout = nn.Dropout(config.attn_pdrop) 50 | self.resid_dropout = nn.Dropout(config.resid_pdrop) 51 | self.pruned_heads = set() 52 | 53 | def prune_heads(self, heads): 54 | if len(heads) == 0: 55 | return 56 | heads, index = find_pruneable_heads_and_indices( 57 | heads, self.n_head, self.split_size // self.n_head, self.pruned_heads 58 | ) 59 | index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) 60 | 61 | # Prune conv1d layers 62 | self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) 63 | self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) 64 | 65 | # Update hyper params 66 | self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads)) 67 | self.n_head = self.n_head - len(heads) 68 | self.pruned_heads = self.pruned_heads.union(heads) 69 | 70 | def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False): 71 | w = torch.matmul(q, k) 72 | if self.scale: 73 | w = w / (float(v.size(-1)) ** 0.5) 74 | 75 | # print("((((", w.size(), attention_mask.size()) 76 | 77 | if attention_mask is not None: 78 | # reshape the attention mask to fit the weight matrix 79 | # print(attention_mask[0, 0, :w.size(2), :]) 80 | # Apply the attention mask 81 | w = w + attention_mask[:, :, :w.size(2), :] 82 | 83 | w = nn.Softmax(dim=-1)(w) 84 | w = self.attn_dropout(w) 85 | 86 | # print(w.size(), v.size()) 87 | 88 | outputs = [torch.matmul(w, v)] 89 | if output_attentions: 90 | outputs.append(w) 91 | return outputs 92 | 93 | def merge_heads(self, x): 94 | x = x.permute(0, 2, 1, 3).contiguous() 95 | new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),) 96 | return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states 97 | 98 | def split_heads(self, x, k=False): 99 | new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head) 100 | x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states 101 | if k: 102 | return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length) 103 | else: 104 | return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) 105 | 106 | def forward( 107 | self, 108 | hidden_states, 109 | layer_past=None, 110 | attention_mask=None, 111 | head_mask=None, 112 | encoder_hidden_states=None, 113 | encoder_attention_mask=None, 114 | use_cache=False, 115 | output_attentions=False, 116 | ): 117 | if encoder_hidden_states is not None: 118 | 
assert hasattr(
119 |                 self, "q_attn"
120 |             ), "If class is used as cross attention, the weights `q_attn` have to be defined. Please make sure to instantiate class with `Attention(..., is_cross_attention=True)`."
121 |             query = self.q_attn(hidden_states)
122 |             key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
123 |             attention_mask = encoder_attention_mask
124 |         else:
125 |             query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
126 | 
127 |         query = self.split_heads(query)
128 |         key = self.split_heads(key, k=True)
129 |         value = self.split_heads(value)
130 |         if layer_past is not None:
131 |             past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose back cf below
132 |             key = torch.cat((past_key, key), dim=-1)
133 |             value = torch.cat((past_value, value), dim=-2)
134 | 
135 |         if use_cache is True:
136 |             present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
137 |         else:
138 |             present = (None,)
139 | 
140 |         # print(query.size(), key.size(), value.size())
141 | 
142 |         attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions)
143 |         a = attn_outputs[0]
144 | 
145 |         a = self.merge_heads(a)
146 |         a = self.c_proj(a)
147 |         a = self.resid_dropout(a)
148 | 
149 |         outputs = [a, present] + attn_outputs[1:]
150 |         return outputs  # a, present, (attentions)
151 | 
152 | 
153 | class MLP(nn.Module):
154 |     def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
155 |         super().__init__()
156 |         nx = config.n_embd
157 |         self.c_fc = Conv1D(n_state, nx)
158 |         self.c_proj = Conv1D(nx, n_state)
159 |         self.act = ACT2FN[config.activation_function]
160 |         self.dropout = nn.Dropout(config.resid_pdrop)
161 | 
162 |     def forward(self, x):
163 |         h = self.act(self.c_fc(x))
164 |         h2 = self.c_proj(h)
165 |         return self.dropout(h2)
166 | 
167 | 
168 | class Block(nn.Module):
169 |     def __init__(self, config, n_ctx, add_cross_attention = False, scale=False):
170 |         super().__init__()
171 |         hidden_size = config.n_embd
172 |         inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
173 |         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
174 |         self.attn = Attention(hidden_size, n_ctx, config, scale)
175 |         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
176 |         if add_cross_attention:
177 |             self.crossattention = Attention(hidden_size, n_ctx, config, scale, is_cross_attention=True)
178 |             self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
179 |         self.mlp = MLP(inner_dim, config)
180 | 
181 |     def forward(self, x):
182 |         # nn.Sequential does not take keyword arguments, so we need to pass around a tuple;
183 |         # now I understand why huggingface coded this by passing around lists
184 |         type_ = x[0]
185 |         # print("^^^^", type_, len(x))
186 |         if type_ in ["encoder", "self"]:
187 |             (hidden_states, attention_mask) = x[1:]
188 |         else:
189 |             (hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask) = x[1:]
190 | 
191 |         attn_outputs = self.attn(
192 |             self.ln_1(hidden_states),
193 |             attention_mask=attention_mask,
194 |         )
195 |         attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
196 |         outputs = attn_outputs[1:]
197 |         hidden_states = attn_output + hidden_states  # residual connection
198 | 
199 |         if type_ == "decoder":
200 |             # add one self-attention block for cross-attention
201 |             assert hasattr(
202 |                 self, "crossattention"
203 |             ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
204 |             cross_attn_outputs = self.crossattention(
205 |                 self.ln_cross_attn(hidden_states),
206 |                 attention_mask=attention_mask,
207 |                 encoder_hidden_states=encoder_hidden_states,
208 |                 encoder_attention_mask=encoder_attention_mask,
209 |             )
210 |             attn_output = cross_attn_outputs[0]
211 |             # residual connection
212 |             hidden_states = hidden_states + attn_output
213 |             outputs = outputs + cross_attn_outputs[2:]  # add cross attentions if we output attention weights
214 | 
215 |         feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states))
216 |         # residual connection
217 |         hidden_states = hidden_states + feed_forward_hidden_states
218 | 
219 |         outputs = [hidden_states] + outputs
220 |         if type_ in ["encoder", "self"]:
221 |             out = (type_, hidden_states, attention_mask)
222 |         else:
223 |             out = (type_, hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask)
224 |         return out
225 | 
226 | 
227 | class Decoder(nn.Module):
228 |     def __init__(self, config):
229 |         super(Decoder, self).__init__()
230 | 
231 |         blocks = []
232 |         for i in range(config.n_decoder_layers):
233 |             blocks.append(Block(config, n_ctx = config.maxlen, add_cross_attention = False))  # causal
234 |             blocks.append(Block(config, n_ctx = config.maxlen, add_cross_attention = True))  # sent
235 |             blocks.append(Block(config, n_ctx = config.maxlen, add_cross_attention = True))  # db
236 |         self.blocks = nn.Sequential(*blocks)
237 |         self.ln = nn.LayerNorm(config.n_embd, eps = config.layer_norm_epsilon)
238 | 
239 |     def forward(self, x):
240 |         # nn.Sequential does not take keyword arguments, so we pass around a tuple
241 |         (hidden_states_sql, attention_mask_sql, hidden_states_sent, attention_mask_sent, hidden_states_db, attention_mask_db) = x
242 | 
243 |         hidden_states = hidden_states_sql
244 |         l = 0
245 |         for i, block in enumerate(self.blocks):
246 |             if l == 0:  # causal attention
247 |                 outputs = block(("self", hidden_states, attention_mask_sql))
248 |                 l = 1
249 |             elif l == 1:  # sentence attention
250 |                 outputs = block(("decoder", hidden_states, attention_mask_sql, hidden_states_sent, attention_mask_sent))
251 |                 l = 2
252 |             else:  # db attention
253 |                 outputs = block(("decoder", hidden_states, attention_mask_sql, hidden_states_db, attention_mask_db))
254 |                 l = 0
255 |             hidden_states = outputs[1]
256 |         hidden_states_sql = hidden_states
257 |         return (hidden_states_sql, attention_mask_sql,
258 |                 hidden_states_sent, attention_mask_sent,
259 |                 hidden_states_db, attention_mask_db)
260 | 
261 | 
262 | class Text2SQLModel(nn.Module):
263 |     def __init__(self, config):
264 |         super(Text2SQLModel, self).__init__()
265 |         self.config = config
266 | 
267 |         self.embedding = nn.Embedding(config.vocab_size, config.n_embd)
268 | 
# self.wte_sent = nn.Embedding(config.maxlen, config.n_embd) 269 | # self.wte_db = nn.Embedding(config.maxlen, config.n_embd) 270 | # self.wte_sql = nn.Embedding(config.maxlen, config.n_embd) 271 | 272 | # using embedding and position IDs isn't really going well so will use Parameter 273 | self.wte_sent = nn.Parameter(torch.zeros(config.maxlen, config.n_embd)) 274 | self.wte_db = nn.Parameter(torch.zeros(config.maxlen, config.n_embd)) 275 | self.wte_sql = nn.Parameter(torch.zeros(config.maxlen, config.n_embd)) 276 | 277 | self.sentence_encoder = nn.Sequential(*[ 278 | Block(config, n_ctx=config.maxlen, add_cross_attention=False) 279 | for _ in range(config.n_sent_layers) 280 | ]) 281 | self.db_encoder = nn.Sequential(*[ 282 | Block(config, n_ctx=config.maxlen, add_cross_attention=False) 283 | for _ in range(config.n_db_layers) 284 | ]) 285 | self.decoder = Decoder(config) 286 | 287 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias = False) 288 | 289 | self.apply(self._init_weights) 290 | print("number of parameters:", sum(p.numel() for p in self.parameters())) 291 | 292 | def _init_weights(self, module): 293 | if isinstance(module, (nn.Linear, nn.Embedding)): 294 | module.weight.data.normal_(mean=0.0, std=0.02) 295 | if isinstance(module, nn.Linear) and module.bias is not None: 296 | module.bias.data.zero_() 297 | elif isinstance(module, nn.LayerNorm): 298 | module.bias.data.zero_() 299 | module.weight.data.fill_(1.0) 300 | 301 | def configure_optimizers(self, train_config): 302 | """ 303 | This long function is unfortunately doing something very simple and is being very defensive: 304 | We are separating out all parameters of the model into two buckets: those that will experience 305 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights). 306 | We are then returning the PyTorch optimizer object. 307 | """ 308 | 309 | # # separate out all parameters to those that will and won't experience regularizing weight decay 310 | # decay = set() 311 | # no_decay = set() 312 | # whitelist_weight_modules = (torch.nn.Linear, ) 313 | # blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) 314 | # for mn, m in self.named_modules(): 315 | # for pn, p in m.named_parameters(): 316 | # # print(mn, "--", pn) 317 | # fpn = '%s.%s' % (mn, pn) if mn else pn # full param name 318 | # print(fpn, type(m)) 319 | # if fpn.endswith('bias'): 320 | # # all biases will not be decayed 321 | # no_decay.add(fpn) 322 | # print(fpn, "--", 1) 323 | # elif fpn.endswith('weight') and isinstance(m, whitelist_weight_modules): 324 | # # weights of whitelist modules will be weight decayed 325 | # decay.add(fpn) 326 | # print(fpn, "--", 2) 327 | # elif fpn.endswith('weight') and isinstance(m, blacklist_weight_modules): 328 | # # weights of blacklist modules will NOT be weight decayed 329 | # no_decay.add(fpn) 330 | # print(fpn, "--", 3) 331 | # print() 332 | 333 | # # validate that we considered every parameter 334 | # param_dict = {pn: p for pn, p in self.named_parameters()} 335 | # inter_params = decay & no_decay 336 | # union_params = decay | no_decay 337 | # assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) 338 | # assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" 
\ 339 | # % (str(param_dict.keys() - union_params), ) 340 | 341 | # # create the pytorch optimizer object 342 | # optim_groups = [ 343 | # {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay}, 344 | # {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, 345 | # ] 346 | optimizer = torch.optim.Adam(self.parameters(), lr=train_config.lr, betas=train_config.betas) 347 | return optimizer 348 | 349 | def get_position_ids(self, input, past_length, device): 350 | input_shape = input.size() 351 | position_ids = torch.arange(past_length, input_shape[-1] + past_length).long().to(device) 352 | position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) 353 | print("**", position_ids.device) 354 | return position_ids 355 | 356 | def encoder_fn(self, sent, db, sent_attn, db_attn, device): 357 | B, T = sent.size() 358 | sent = self.embedding(sent) + self.wte_sent[:T,:] 359 | db = self.embedding(db) + self.wte_db[:T,:] 360 | sent_hidden_states = self.sentence_encoder(("encoder", sent, sent_attn),)[1] 361 | db_hidden_states = self.db_encoder(("encoder", db, db_attn),)[1] 362 | return SimpleNamespace( 363 | sent_hidden_states=sent_hidden_states, 364 | db_hidden_states=db_hidden_states, 365 | sent_attn=sent_attn, 366 | db_attn=db_attn 367 | ) 368 | 369 | def decoder_fn(self, enc_out, sql_ids, sql_attn, device): 370 | B, T = sql_ids.size() 371 | sql = self.embedding(sql_ids) + self.wte_sql[:T,:] 372 | sql_output = self.decoder((sql, sql_attn, enc_out.sent_hidden_states, 373 | enc_out.sent_attn, enc_out.db_hidden_states, enc_out.db_attn),)[0] 374 | sql_output = self.lm_head(sql_output) 375 | return sql_output 376 | 377 | def forward(self, sql_ids, sent, db, sql_attn, sent_attn, db_attn, labels=None, past_length=0, device="cpu"): 378 | B, T = sql_ids.size() 379 | 380 | # make the embeddings 381 | sql = self.embedding(sql_ids) + self.wte_sql[:T,:] 382 | sent = self.embedding(sent) + self.wte_sent[:T,:] 383 | db = self.embedding(db) + self.wte_db[:T,:] 384 | 385 | # get hidden_states for sentence_encoder 386 | sent_hidden_states = self.sentence_encoder(("encoder", sent, sent_attn),)[1] 387 | db_hidden_states = self.db_encoder(("encoder", db, db_attn),)[1] 388 | sql_output = self.decoder((sql, sql_attn, sent_hidden_states, sent_attn, db_hidden_states, db_attn),)[0] 389 | sql_output = self.lm_head(sql_output) 390 | output = [sql_output] 391 | 392 | if labels is not None: 393 | labels = labels.contiguous() 394 | logits = sql_output.contiguous() 395 | 396 | # loss_fct = nn.CrossEntropyLoss(reduction="none") 397 | # loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) 398 | 399 | loss_fct = nn.CrossEntropyLoss(reduction="mean") 400 | loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) 401 | 402 | # # get the indexes where the labels are not [PAD] 403 | # non_pad_mask = sql_attn[:, :, 0, 1:].contiguous().view(-1) == 0 404 | # print(loss.size(), non_pad_mask.size()) 405 | # non_pad_loss = loss[non_pad_mask] 406 | # print("non_pad_loss", non_pad_loss) 407 | # loss = non_pad_loss.mean() 408 | output = [loss] + output 409 | 410 | return output 411 | 412 | 413 | class Text2SQLModelConfig(): 414 | vocab_size = 5012 415 | n_embd = 256 416 | maxlen = 128 417 | n_decoder_layers = 2 418 | n_sent_layers = 3 419 | n_db_layers = 3 420 | n_head = 8 421 | n_inner = None 422 | activation_function = "gelu_new" 423 | resid_pdrop = 0.1 424 | embd_pdrop = 0.1 425 | attn_pdrop = 0.1 426 | layer_norm_epsilon = 0.00001 427 | 
initializer_range = 0.02
428 | 
429 |     def __init__(self, **kwargs):
430 |         self.attrs = [
431 |             "vocab_size",
432 |             "n_embd",
433 |             "maxlen",
434 |             "n_decoder_layers",
435 |             "n_sent_layers",
436 |             "n_db_layers",
437 |             "n_head",
438 |             "n_inner",
439 |             "activation_function",
440 |             "resid_pdrop",
441 |             "embd_pdrop",
442 |             "attn_pdrop",
443 |             "layer_norm_epsilon",
444 |             "initializer_range",
445 |         ]
446 | 
447 |         for k, v in kwargs.items():
448 |             self.attrs.append(k)
449 |             setattr(self, k, v)
450 | 
451 |     def __repr__(self):
452 |         kvs = [(k, f"{getattr(self, k)}") for k in sorted(list(set(self.attrs)))]
453 |         return tabulate(kvs, ["argument", "value"], tablefmt="psql")
454 | 
455 | # ====== Sampling Utils ====== #
456 | def top_k_logits(logits, k):
457 |     v, ix = torch.topk(logits, k)
458 |     out = logits.clone()
459 |     out[out < v[:, [-1]]] = -1e6  # mask everything below the k-th largest logit
460 |     return out
461 | 
462 | 
463 | @torch.no_grad()
464 | def sample(model, sent, sent_attn, db, db_attn, t, sql_str = None, device="cpu", steps=50, temperature=1.0, top_k=None, do_sample=True):
465 |     model.eval()
466 |     sent = sent.view(-1, sent.size(0))
467 |     db = db.view(-1, db.size(0))
468 |     sent_attn = sent_attn.view(-1, *sent_attn.size())
469 |     db_attn = db_attn.view(-1, *db_attn.size())
470 |     # print(sent.size(), db.size(), sent_attn.size(), db_attn.size())
471 |     enc_out = model.encoder_fn(sent, db, sent_attn, db_attn, device=device)
472 | 
473 |     # convert string to sql_tokens
474 |     if sql_str is not None:
475 |         sql = torch.from_numpy(np.asarray([t.encode(sql_str)])).long()
476 |     else:
477 |         sql = torch.from_numpy(np.asarray([t.bos_id()])).view(-1, 1).long()
478 |     sql = sql.to(device)
479 | 
480 |     # final sequence
481 |     out = []
482 | 
483 |     for k in range(steps):
484 |         if k == 0:
485 |             x = sql if sql.size(1) < model.config.maxlen else sql[:, -model.config.maxlen:]
486 |         else:
487 |             x = x if x.size(1) < model.config.maxlen else x[:, -model.config.maxlen:]
488 | 
489 |         sql_attn = np.zeros((x.size(1), x.size(1)))
490 |         sql_attn[:x.size(1), :x.size(1)] = 1
491 |         sql_attn = sql_attn - np.triu(sql_attn, k = 1)  # causal masking
492 |         sql_attn = 1 - sql_attn
493 |         sql_attn = sql_attn * -1e6
494 |         sql_attn = torch.from_numpy(sql_attn.astype(np.float32)).to(device).view(1, 1, *sql_attn.shape)
495 | 
496 |         # print("****", x.size(), sql_attn.size())
497 | 
498 |         logits = model.decoder_fn(enc_out, x, sql_attn, device=device)
499 | 
500 |         # pluck the logits at the final step and scale by temperature
501 |         logits = logits[:, -1, :] / temperature
502 |         # optionally crop probabilities to only the top k options
503 |         if top_k is not None:
504 |             logits = top_k_logits(logits, top_k)
505 | 
506 |         # apply softmax to convert to probabilities
507 |         probs = F.softmax(logits, dim=-1)
508 |         # sample from the distribution or take the most likely
509 |         if do_sample:
510 |             ix = torch.multinomial(probs, num_samples=1)
511 |         else:
512 |             _, ix = torch.topk(probs, k=1, dim=-1)
513 |         # append to the sequence and continue
514 |         x = torch.cat((x, ix), dim=1)
515 | 
516 |         out.append(ix[0].tolist()[0])
517 | 
518 |     return t.decode_ids(out)
519 | 
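A tiny worked example of `top_k_logits` above (editor's sketch): with `k=2`, every logit below the second-largest one is pushed to -1e6, so after the softmax only the top-2 tokens keep any probability mass.

```python
import torch

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
masked = top_k_logits(logits, k=2)
# -> tensor([[ 2.0,  1.0, -1e6, -1e6]]): 0.5 and -1.0 are masked out
probs = torch.softmax(masked, dim=-1)  # probability mass only on the first two tokens
```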
--------------------------------------------------------------------------------
/src/text2sql/model2.py:
--------------------------------------------------------------------------------
 1 | """So the complicated model with encoder and decoder networks is not working properly,
 2 | need to come up with something better.
 3 | 12.11.2020 - @yashbonde"""
 4 | 
 5 | from text2sql.model import *
 6 | 
 7 | # https://pytorch-scatter.readthedocs.io/en/1.3.0/functions/add.html
 8 | from torch_scatter import scatter_add
 9 | 
10 | class GPTModel(nn.Module):
11 |     def __init__(self, config):
12 |         super(GPTModel, self).__init__()
13 | 
14 |         self.config = config
15 | 
16 |     def merge_embeddings(self, expanded_embeddings, merging_index):
17 |         """This function merges multiple embedding vectors into smaller embeddings.
18 |         The challenge is that each node carries information in the format
19 |         `<table>.<column>` and we need to merge this information. Another challenge was
20 |         that the older method created matrices that were 400x400 while the questions were 50x50.
21 | 
22 |         So continuing with the above mentioned case:
23 | 
24 |         `<table>` is tokenised into W1,W2,W3 and `<column>` is tokenised into W4,W5,
25 |         so the effective tokenisation becomes W1,W2,W3.W4,W5. Now in order to feed into the model
26 |         we need to merge: if expanded_embeddings is (5,256) and merging_index = [0,0,0,1,1],
27 |         this function returns a matrix (2,256).
28 | 
29 |         Now in practice this takes a large matrix like (500, 500) and returns a highly reduced
30 |         version of it. The output in this case is going to be (num_nodes, num_nodes) and thus
31 |         we do not need to use the complicated cell-expanded attention matrix as done in
32 |         `text2sql.data.get_tokenised_attention_mask`
33 |         """
34 |         out = scatter_add(expanded_embeddings, merging_index, dim=0)
35 |         return out
36 | 
37 |     def forward(self):
38 |         # TODO: forward pass is still a work in progress
39 |         raise NotImplementedError
40 | 
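The `merge_embeddings` docstring above is easiest to verify with a toy `scatter_add` run (editor's sketch; requires the `torch_scatter` package, and uses embedding dim 4 instead of 256 for readability):

```python
import torch
from torch_scatter import scatter_add

expanded = torch.randn(5, 4)           # W1..W5, one row per sub-token
index = torch.tensor([0, 0, 0, 1, 1])  # first three rows -> node 0, last two -> node 1

merged = scatter_add(expanded, index, dim=0)
print(merged.shape)  # torch.Size([2, 4]) -- one summed embedding per node
```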
--------------------------------------------------------------------------------
/src/text2sql/trainer.py:
--------------------------------------------------------------------------------
 1 | """
 2 | Simple training loop; Boilerplate that could apply to any arbitrary neural network,
 3 | so nothing in this file really has anything to do with GPT specifically.
 4 | 
 5 | from karpathy/minGPT
 6 | """
 7 | 
 8 | import time
 9 | import random
10 | import numpy as np
11 | from tqdm import tqdm, trange
12 | from tabulate import tabulate
13 | 
14 | from text2sql.model import sample
15 | 
16 | import torch
17 | from torch.optim.lr_scheduler import OneCycleLR
18 | from torch.utils.data import DataLoader
19 | from torch.utils.tensorboard import SummaryWriter
20 | 
21 | class Trainer:
22 |     def __init__(self, model, train_dataset, test_dataset, config):
23 |         self.model = model
24 |         self.train_dataset = train_dataset
25 |         self.test_dataset = test_dataset
26 |         self.config = config
27 | 
28 |         self.device = "cpu"
29 |         if torch.cuda.is_available():
30 |             print("Model is now CUDA!!!!")
31 |             self.device = torch.cuda.current_device()
32 |             self.model = torch.nn.DataParallel(self.model).to(self.device)
33 | 
34 |     def save_checkpoint(self):
35 |         raw_model = self.model.module if hasattr(self.model, "module") else self.model
36 |         print(f"Saving Model at {self.config.ckpt_path}")
37 |         torch.save(raw_model.state_dict(), self.config.ckpt_path)
38 | 
39 |     def train(self, data_config, verbose = False):
40 |         model, config = self.model, self.config
41 |         raw_model = model.module if hasattr(self.model, "module") else model
42 |         optimizer = raw_model.configure_optimizers(config)
43 |         lrscheduler = OneCycleLR(
44 |             optimizer,
45 |             max_lr = config.lr,
46 |             total_steps=config.num_batch*config.max_epochs
47 |         )
48 | 
49 |         print("Starting training...", lrscheduler)
50 | 
51 |         with SummaryWriter(log_dir=config.tb_path, flush_secs=20) as tb:
52 | 
53 |             def run_epoch(split, epoch, _gs):
54 |                 is_train = split == "train"
55 |                 model.train(is_train)
56 |                 data = self.train_dataset if is_train else self.test_dataset
57 |                 dl = DataLoader(
58 |                     data,
59 |                     shuffle = True,
60 |                     pin_memory = True,
61 |                     batch_size = config.batch_size,
62 |                 )
63 |                 losses = []
64 |                 if is_train:
65 |                     pbar = trange(config.num_batch, ncols=100)
66 |                     iterator = zip(pbar, dl)
67 |                 else:
68 |                     pbar = tqdm(enumerate(dl))
69 |                     iterator = pbar
70 | 
71 |                 for it, d in iterator:
72 | 
73 |                     with torch.set_grad_enabled(is_train):
74 |                         _l = -1 if not losses else losses[-1]
75 |                         if is_train:
76 |                             pbar.set_description(f"[TRAIN] GS: {_gs}, Epoch: {epoch}, Loss: {round(_l, 5)}")
77 |                         else:
78 |                             pbar.set_description(f"[VAL] Epoch: {epoch}")
79 | 
80 |                         d = {k:v.to(self.device) for k,v in d.items()}
81 |                         # print({k:(v.size(), v.dtype, v.device) for k,v in d.items()})
82 | 
83 |                         loss, logits = model(
84 |                             **d,
85 |                             device = self.device
86 |                         )
87 |                         loss = loss.mean()  # gather from multiple GPUs
88 |                         losses.append(loss.item())
89 | 
90 |                     if is_train:
91 |                         # add things to tb, loss and attention images
92 |                         tb.add_scalar("loss", loss.item(), global_step=_gs, walltime=time.time())
93 |                         tb.add_scalar("lr", lrscheduler.get_last_lr()[0], global_step=_gs, walltime=time.time())
94 | 
95 |                         model.zero_grad()  # clear old gradients before the backward pass
96 |                         loss.backward()
97 |                         torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
98 |                         optimizer.step()
99 |                         lrscheduler.step()
100 |                         _gs += 1
101 | 
102 |                 if not is_train:
103 |                     # no sampling here, because that really doesn't make any sense
104 |                     test_loss = float(np.mean(losses))
105 |                     tb.add_scalar("test_loss", test_loss, global_step=_gs, walltime=time.time())
106 | 
107 |                     # create samples for visualising the results
108 |                     print("Generating Samples ...")
109 |                     for i in range(min(10, len(d["sql_ids"]))):
110 |                         s = {k: v[i, ...] for k, v in d.items()}
111 |                         seq = sample(raw_model, sent=s["sent"], sent_attn=s["sent_attn"],
112 |                                      db=s["db"], db_attn=s["db_attn"], t=data_config.tokenizer,
113 |                                      device = self.device)  # , sql_str = "select")
114 |                         target = data_config.tokenizer.decode_ids([x for x in s["labels"].tolist() if x != -100])
115 |                         print("-->", seq, "<<--->>", target)
116 | 
117 |                     return test_loss
118 |                 return _gs
119 | 
120 |             # now write wrapper for each epoch
121 |             best_loss = float("inf")
122 |             gs = 1
123 |             test_no_improve = 0
124 |             for e in range(config.max_epochs):
125 |                 gs = run_epoch("train", e, gs)
126 |                 if self.test_dataset is not None:
127 |                     test_loss = run_epoch("test", e, gs)
128 |                     print(f"Test loss: {test_loss}")
129 | 
130 |                 # early stopping based on the test loss, or just save always if no test set is provided
131 |                 good_model = self.test_dataset is None or test_loss < best_loss
132 |                 if self.config.ckpt_path is not None and good_model:
133 |                     best_loss = test_loss
134 |                     self.save_checkpoint()
135 |                     test_no_improve = 0
136 |                 else:
137 |                     test_no_improve += 1
138 | 
139 |                 # if test_no_improve == config.patience:
140 |                 #     print(f"Stop Training after [patience = {config.patience}]: {e} epochs")
141 |                 #     break
142 | 
143 | 
144 | class TrainerConfig:
145 |     lr = 0.0001
146 |     max_epochs = 10
147 |     batch_size = 128
148 |     betas = (0.9, 0.95)
149 |     grad_norm_clip = 1.0
150 |     weight_decay = 0.1  # only applied on matmul weights
151 |     len_data = None  # required for OneCycleLR
152 |     sample_every = 5  # after how many epochs to log
153 |     num_batch = None
154 |     patience = 5  # training stops after patience runs out
155 |     tb_path = None
156 |     ckpt_path = None
157 | 
158 |     def __init__(self, **kwargs):
159 |         self.attrs = [ "lr", "max_epochs", "batch_size", "betas",
160 |             "grad_norm_clip", "weight_decay", "len_data",
161 |             "sample_every", "num_batch", "patience", "tb_path", "ckpt_path",
162 |         ]
163 |         for k,v in kwargs.items():
164 |             setattr(self, k, v)
165 |             self.attrs.append(k)
166 | 
167 |     def __repr__(self):
168 |         kvs = [(k, f"{getattr(self, k)}") for k in sorted(list(set(self.attrs)))]
169 |         return tabulate(kvs, ["argument", "value"], tablefmt="psql")
170 | 
171 | 
172 | # funcs
173 | def set_seed(seed):
174 |     if seed is not None:
175 |         random.seed(seed)
176 |         np.random.seed(seed)
177 |         torch.manual_seed(seed)
178 |         torch.cuda.manual_seed_all(seed)
179 | 
--------------------------------------------------------------------------------
/t2s.py:
--------------------------------------------------------------------------------
 1 | """trying streamlit as usage tool
 2 | 22.09.2020 - @yashbonde"""
 3 | 
 4 | import os
 5 | import json
 6 | import random
 7 | from PIL import Image
 8 | import numpy as np
 9 | import networkx as nx
10 | import matplotlib.pyplot as plt
11 | 
12 | from text2sql.data import parse_db_to_networkx
13 | 
14 | import streamlit as st
15 | 
16 | st.write("""
17 | # Text2SQL Converter
18 | Convert your natural language question to SQL queries.
19 | 
20 | * **Author: Yash Bonde**
21 | * Website: [link](https://yashbonde.github.io)
22 | * LinkedIn: [link](https://www.linkedin.com/in/yash-bonde/)
23 | * Twitter: [@bondebhai](https://twitter.com/bondebhai)
24 | * Check code at my [Github](https://github.com/yashbonde/text2sql)
25 | """)
26 | 
27 | 
28 | # # with st.spinner('Loading model ...'):
29 | # from transformers import AutoTokenizer, AutoModel
30 | # TOKENIZER = AutoTokenizer.from_pretrained("distilbert-base-uncased")
31 | # THETA = AutoModel.from_pretrained("distilbert-base-uncased")
32 | 
33 | # paths to main files
34 | OTHER_FILE = "data/spider/train_others.json"
35 | SPIDER_FILE = "data/spider/train_spider.json"
36 | SPARC_FILE = "data/sparc/train.json"
37 | COSQL_FILE = "data/cosql_dataset/cosql_all_info_dialogs.json"
38 | 
39 | # files containing tables info
40 | SPIDER_TABLES = "data/spider/tables.json"
41 | SPARC_TABLES = "data/sparc/tables.json"
42 | COSQL_TABLES = "data/cosql_dataset/tables.json"
43 | 
44 | # spider dataset already has sql files that we can read from to tokenize
45 | SPIDER_SQL_TRAIN = "data/spider/train_gold.sql"
46 | SPIDER_SQL_DEV = "data/spider/dev_gold.sql"
47 | 
48 | # dev set
49 | SPIDER_DEV = "data/spider/dev.json"
50 | SPARC_DEV = "data/sparc/dev.json"
51 | 
52 | DB_ID = None
53 | 
54 | # load different dbs
55 | tables = []
56 | with open(SPIDER_TABLES) as f1, open(SPARC_TABLES) as f2, open(COSQL_TABLES) as f3:
57 |     # spider/tables.json
58 |     tables.extend(json.load(f1))
59 | 
60 |     # sparc/tables.json
61 |     tables.extend(json.load(f2))
62 | 
63 |     # cosql_dataset/tables.json
64 |     tables.extend(json.load(f3))
65 | 
66 | # load questions and corresponding outputs
67 | data = []
68 | with open(OTHER_FILE) as f1, open(SPIDER_FILE) as f2, open(SPARC_FILE) as f3, open(COSQL_FILE) as f4:
69 |     # train_others.json
70 |     for x in json.load(f1):
71 |         data.append((x["question"], x["query"], x["db_id"]))
72 | 
73 |     # train_spider.json
74 |     for x in json.load(f2):
75 |         data.append((x["question"], x["query"], x["db_id"]))
76 | 
77 |     # sparc/train.json
78 |     for x in json.load(f3):
79 |         data.append((x["final"]["utterance"], x["final"]["query"], x["database_id"]))
80 | 
81 |     # cosql_all_info_dialogs.json
82 |     for x, y in json.load(f4).items():
83 |         data.append((y["query_goal"], y["sql"], y["db_id"]))
84 | 
85 | def get_random_db():
86 |     random_table = random.choice(tables)
87 |     global DB_ID
88 |     DB_ID = random_table["db_id"]
89 |     g = parse_db_to_networkx(random_table)
90 |     eattr = nx.get_edge_attributes(g, 'foreign')
91 |     pos = nx.spring_layout(g)
92 | 
93 |     plt.figure(figsize = (6, 4))
94 |     nx.draw_networkx_nodes(g, pos)
95 |     nx.draw_networkx_labels(g, pos, nx.get_node_attributes(g, 'id'), font_size="x-small")
96 |     nx.draw_networkx_edges(
97 |         g,
98 |         pos,
99 |         edgelist=[k for k, v in eattr.items() if v],
100 |         edge_color="r",
101 |     )
102 |     nx.draw_networkx_edges(
103 |         g,
104 |         pos,
105 |         edgelist=[k for k, v in eattr.items() if not v],
106 |         edge_color="b",
107 |     )
108 |     plt.savefig("_temp.png")
109 | 
110 | st.write("""
111 | ## Problem Statement
112 | 
113 | **Given a database schema and a question in natural language, get the appropriate
114 | SQL query to fetch the information needed.**
115 | 
116 | For this we first go over the data.
117 | 
118 | ### Sample Data
119 | 
120 | Below is a quick sample of what the DB looks like: `blue` edges represent
121 | edges in the same table while `red` edges represent foreign keys. Click the
122 | `Randomise` button to see another graph. 
I am using [CoSQL](https://yale-lily.github.io/cosql),
125 | [Spider](https://yale-lily.github.io/spider), [Sparc](https://yale-lily.github.io/sparc)
126 | datasets.
127 | """)
128 | db_cntr = 0
129 | if st.button('Randomise') or db_cntr == 0:
130 |     get_random_db()
131 | 
132 |     # load the graph image
133 |     x = Image.open("_temp.png")
134 |     st.image(x, caption = f"Look Ma! Database is a graph. ({DB_ID})", clamp = True)
135 | 
136 |     # update samples
137 |     data_this_db = list(filter(
138 |         lambda x: x[2] == DB_ID, data
139 |     ))
140 |     st.write(f"from `{DB_ID}` we get the following questions:\n\n" +
141 |         "- " + "\n\n- ".join([f"{x[0]} ➡️ `{x[1]}`" for x in data_this_db][:3])
142 |     )
143 |     db_cntr += 1
144 | 
145 | st.write("""
146 | ### Database Schema
147 | Any DB is converted to a graph: a combination of nodes and edges, where each has certain properties:
148 | ```
149 | nodes:
150 |   - table: name of the table it belongs to
151 |   - name: column name
152 |   - type: one of ['boolean', 'time', 'others', 'text', 'number']
153 |   - primary: boolean that tells if this is a primary key
154 | 
155 | edges:
156 |   - foreign: boolean that tells if this is a foreign edge
157 | ```
158 | 
159 | ### Natural Language Questions
160 | We use [distilbert](https://huggingface.co/distilbert-base-uncased) and you can pass
161 | it any text and see the output logits size.
162 | 
163 | ## Algorithm
164 | 
165 | To pose any problem as RL you need to have the following setup:
166 | 
167 | ```
168 | format of data tuple
169 | s: current state
170 | a: action taken
171 | r: reward obtained for taking that action
172 | s': new state model reaches
173 | ```
174 | 
175 | We are given a database schema defined by $D$, and a natural language question $N$.
176 | We first obtain an embedding for the database $d = \phi(D)$ and the question $t = \\theta(N)$.
177 | Thus we get the input state $s = [d;t]$, where $;$ denotes concatenation. Now we denote
178 | by $\pi$ the policy network, which predicts the appropriate SQL:
179 | $q = \pi(s)$
180 | 
181 | The main challenge is the policy network: is it going to be a traditional language-modelling
182 | LSTM or a Transformer? So let us consider the network outputs:
183 | 
184 | * $\phi(D) \\rightarrow ( [N_{nodes}, E_{graph}], [1, E_{graph}] )$
185 | * $\\theta(N) \\rightarrow [N_{tokens}, 768]$, we can possibly reduce this further by
186 |   max-pooling over the sequence as $[N_{tokens}, 768] \\rightarrow [1, 768]$
187 | 
188 | **23rd September, 2020**: Okay, so I think I have a solution. The primary challenge has
189 | been the definition of the action space. The action space has all the vocabulary of SQL
190 | commands + two special tags `<col>` and `<table>`. `<col>` tells that the model will
191 | select a column from the node embeddings (dot product + softmax) and `<table>` will tell it
192 | to select a table from the node embeddings (dot product + sigmoid).
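A rough sketch of that `<col>` selection step (hypothetical shapes, just to make the idea concrete):

```python
import torch
import torch.nn.functional as F

node_emb = torch.randn(37, 256)  # one embedding per schema node
dec_state = torch.randn(1, 256)  # decoder state that just emitted <col>

# dot product + softmax gives a distribution over schema nodes to copy from
col_probs = F.softmax(dec_state @ node_emb.t(), dim=-1)  # shape (1, 37)
picked_col = col_probs.argmax(-1)
```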
193 | 
194 | For this to work however we will have to modify the queries given in the dataset from
195 | ```sql
196 | SELECT
197 |     T2.name FROM Friend AS T1
198 |     JOIN Highschooler AS T2 ON T1.student_id = T2.id
199 | WHERE T2.grade > 5
200 | GROUP BY T1.student_id
201 | HAVING count(*) >= 2
202 | ```
203 | to something like the one below
204 | ```sql
205 | SELECT
206 |     Highschooler.name FROM Friend
207 |     JOIN Highschooler ON Friend.student_id = Highschooler.id
208 | WHERE Highschooler.grade > 5
209 | GROUP BY Friend.student_id
210 | HAVING count(*) >= 2
211 | ```
212 | 
213 | The idea with the initial model was a complicated graph based approach but now I
214 | am considering a much simpler model. The model is a simple Transformer where we have
215 | two different encoder structures:
216 | * BERT as question encoder
217 | * Message-passing GNN as DB encoder
218 | 
219 | These two combined will be fed into a conventional transformer decoder.
220 | """)
221 | 
222 | # # My assumption is that the dataset was created with language models in mind, however in practice
223 | # # directly pointing out the column is a better solution design.
224 | # # pass the database through a graph encoder to get node and graph embeddings
225 | # DB --> [GNN] ---> (node embedding) [N_1, E_1] ... A
226 | #           \-> (graph embedding) [1, E_1]      ... B
227 | 
228 | # # pass the natural language question, through any LM like BERT
229 | # Q ---> [BERT] --> (token level embedding) [N_2, E_2] ... C
230 | 
231 | # # --- undecided ---
232 | # # concatenate the graph embedding and natural language embedding
233 | # [B+C] --> [N_2, E_1 + E_2] ... D
234 | 
235 | # # --- policy ---
236 | # For policy we can either use a GPT transformer or an LSTM
237 | 
238 | # ! TODO: add question parsing in real time here
239 | # question = st.text_input(f"question for DB: {DB_ID} (do not press enter)", value = data_this_db[0][0], max_chars=100)
240 | # st.button('Process')
241 | # st.write(question)
242 | # tokenised = TOKENIZER(question)["input_ids"]
243 | # decoded = TOKENIZER.decode(tokenised)
244 | # st.write(f"""IDs: `{tokenised}` ➡️ Decoded: `{decoded}`""")
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
 1 | """going by o2f format and using huggingface library
 2 | 12.11.2020 - @yashbonde"""
 3 | 
 4 | import os
 5 | from types import SimpleNamespace
 6 | from argparse import ArgumentParser
 7 | 
 8 | from text2sql.data import T2SDataset, T2SDatasetConfig
 9 | from text2sql.model import Text2SQLModel, Text2SQLModelConfig
10 | from text2sql.trainer import *
11 | 
12 | # --- arguments
13 | args = ArgumentParser(description="Text2SQL Model Trainer")
14 | 
15 | # --- paths
16 | args.add_argument("--save_folder", default = "models", type = str, help = "folder to save all models")
17 | args.add_argument("--name", type = str, help = "name of this particular model")
18 | args.add_argument("--schema_file", type = str, help = "path to schema file", default="/workspace/text2sql/fdata/all_schema.json")
19 | args.add_argument("--questions_tsv", type = str, help = "path to text/sql tsv", default="/workspace/text2sql/fdata/all_questions.tsv")
20 | args.add_argument("--tokenizer_path", type = str, help = "path to sentencepiece model file", default="/workspace/text2sql/fdata/model.model")
21 | args.add_argument("--seed", default = None, type = int, help = "seed value for training")
22 | 
23 | # --- arch
24 | args.add_argument("--n_embd", default = 256, type = int, help = "Embedding Dim")
25 | args.add_argument("--n_decoder_layers", default = 4, type = int, help = "Num Decoder layers")
26 | args.add_argument("--n_sent_layers", default = 4, type = int, help = "Num layers for sentence encoder")
27 | args.add_argument("--n_db_layers", default = 4, type = int, help = "Num layers for DB encoder")
28 | args.add_argument("--n_head", default = 4, type = int, help = "Num Heads")
29 | args.add_argument("--maxlen", default = 390, type = int, help = "Maximum length of decoder")
30 | 
31 | # --- data
32 | args.add_argument("--mult", default = 3, type = int, help = "Size of dataset")
33 | args.add_argument("--pf", default = 0.6, type = float, help = "Probability of using fields in training sequence")
34 | args.add_argument("--fmax", default = 0.8, type = float, help = "Max fields probability")
35 | args.add_argument("--fmin", default = 0.1, type = float, help = "Min fields probability")
36 | 
37 | # --- trainer
38 | args.add_argument("--n_epochs", default = 3, type = int, help = "Number of epochs to train")
39 | args.add_argument("--batch_size", default = 32, type = int, help = "Mini-Batch Size")
40 | args.add_argument("--lr", default = 1e-4, type = float, help = "Learning Rate")
41 | args.add_argument("--sample_every", default = 5, type = int, help = "After how many epochs to log samples")
42 | args.add_argument("--train_ratio", default = 0.9, type = float, help = "Ratio of train data, rest is testing")
43 | args.add_argument("--beta1", default = 0.9, type = float, help = "Adam.beta1")
44 | args.add_argument("--beta2", default = 0.99, type = float, help = "Adam.beta2")
45 | args.add_argument("--grad_norm_clip", default = 1.0, type = float, help = "Gradient norm clipping value")
46 | args.add_argument("--patience", default = 6, type = int, help = "training stops after patience runs out")
47 | 
48 | # --- parse and add more
49 | args = args.parse_args()
50 | tb_path = os.path.join(args.save_folder, args.name)
51 | ckpt_path = os.path.join(tb_path, f"{args.name}.pt")
52 | args = SimpleNamespace(**vars(args), ckpt_path = ckpt_path, tb_path = tb_path)
53 | 
54 | # make folders
55 | os.makedirs(args.save_folder, exist_ok=True)
56 | os.makedirs(args.tb_path, exist_ok=True)
57 | 
58 | # seed everything for reproducibility (set_seed comes from text2sql.trainer, a no-op when --seed is not given)
59 | set_seed(args.seed)
60 | 
61 | # DataSet
62 | datasetConf = T2SDatasetConfig(
63 |     schema_file=args.schema_file,
64 |     questions_file=args.questions_tsv,
65 |     maxlen=args.maxlen,
66 |     tokenizer_path=args.tokenizer_path
67 | )
68 | print(datasetConf)
69 | dtrain = T2SDataset(config=datasetConf, mode="train")
70 | dtest = T2SDataset(config=datasetConf, mode="test")
71 | 
72 | # Model
73 | modelConfig = Text2SQLModelConfig(
74 |     vocab_size=datasetConf.tokenizer.vocab_size(),
75 |     n_embd=args.n_embd,
76 |     maxlen=args.maxlen,
77 |     n_decoder_layers=args.n_decoder_layers,
78 |     n_sent_layers=args.n_sent_layers,
79 |     n_db_layers=args.n_db_layers,
80 |     n_head=args.n_head,
81 | )
82 | print(modelConfig)
83 | model = Text2SQLModel(modelConfig)
84 | 
85 | # Trainer
86 | trainConfig = TrainerConfig(
87 |     lr=args.lr,
88 |     max_epochs=args.n_epochs,
89 |     batch_size=args.batch_size,
90 |     betas=(args.beta1, args.beta2),
91 |     grad_norm_clip=args.grad_norm_clip,
92 |     sample_every=args.sample_every,
93 |     num_batch=(len(dtrain) // args.batch_size) + int(len(dtrain) % args.batch_size != 0),
94 |     patience=args.patience,
95 |     tb_path=args.tb_path,
96 |     ckpt_path=args.ckpt_path
97 | )
98 | print(trainConfig)
99 | trainer = Trainer(model, dtrain, dtest, trainConfig)
100 | trainer.train(datasetConf)
--------------------------------------------------------------------------------