├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── dbvis.png
│   └── header.png
├── fdata
│   ├── all.txt
│   ├── all_questions.tsv
│   └── all_schema.json
├── parse_to_lm.py
├── setup.py
├── src
│   └── text2sql
│       ├── __init__.py
│       ├── data.py
│       ├── generation.py
│       ├── graph_network.py
│       ├── model.py
│       ├── model2.py
│       └── trainer.py
├── t2s.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | sparc/
2 | spider/
3 | cosql_dataset/
4 | data/
5 | .vscode/
6 | notebooks/
7 | models/
8 |
9 | __pycache__/
10 |
11 | # ignore the pypi things
12 | dist/
13 | src/text2sql.egg-info/PKG-INFO
14 |
15 | .DS_Store
16 | _temp.png
17 |
18 | # notebooks, unless explicitly mentioned.
19 | .ipynb_checkpoints/
20 | *.ipynb
21 |
22 | # let people process those
23 | *tsv
24 |
25 | # mostly the dump files
26 | *.txt
27 | *.log
28 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020, Yash Bonde
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ![header](assets/header.png)
2 |
3 | # Text2SQL
4 |
5 | How many times have you pulled your hair out writing a SQL query? Now use natural language to generate the appropriate SQL and save your precious hair.
6 |
7 | Though this can be used as a standalone package, I highly recommend using `streamlit` to play with the model interactively (`t2s.py` is not currently updated):
8 | ```
9 | streamlit run t2s.py
10 | ```
11 |
12 | ## Installation
13 |
14 | Run
15 | ```
16 | pip install text2sql
17 | ```
18 |
19 | ## Datasets
20 |
21 | Using [CoSQL](https://yale-lily.github.io/cosql), [Spider](https://yale-lily.github.io/spider), [Sparc](https://yale-lily.github.io/sparc) datasets, credit to the authors. A couple of things to note: there are 178 tables in total, but only 166 appear in the training data, while the dev set has 20 tables. We convert the dataset into graphs using the `text2sql.data.parse_db_to_networkx()` function.
22 |
23 | Since DBs are shared between the train and test splits given by the dataset authors, that split is not fair; I have therefore re-split the data by `db_id` instead.
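
To get a feel for the parsed data, here is a minimal sketch of loading the pre-parsed schema dump and converting one DB to a graph (the `db_id` below is just an example):
```
import json
from text2sql.data import parse_db_to_networkx

# load the consolidated schema dump created by parse_to_lm.py
with open("fdata/all_schema.json") as f:
    all_schema = json.load(f)

# nodes are columns; edges join columns of the same table or foreign-key pairs
g = parse_db_to_networkx(all_schema["concert_singer"])  # any db_id works
print(g.number_of_nodes(), g.number_of_edges())
```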
24 |
25 | ## Parsing
26 |
27 | A new parsing method converts each DB to a graph network; red edges denote foreign keys.
28 |
29 | ![db visualisation](assets/dbvis.png)
30 | According to the initial idea I was going to run a GNN on top of this, but that proved too complicated, so instead I replicate the message passing using the attention matrix in a standard transformer. Due to size constraints, however, I have not parsed the following DBs: `'baseball_1', 'soccer_1', 'cre_Drama_Workshop_Groups'`.
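
The idea is roughly: the graph's adjacency matrix (plus self loops) becomes an additive attention mask, so a column token can only attend to the columns it is connected to. A minimal sketch of that conversion (the helper name here is illustrative; the repo's own versions live in `text2sql/data.py`):
```
import numpy as np
import networkx as nx

def graph_to_attention_mask(g, neg=-1e6):
    # 0 where attention is allowed, a large negative number elsewhere
    A = np.asarray(nx.adjacency_matrix(g).todense())
    A = np.clip(A + np.eye(len(A)), 0, 1)  # allow self-attention
    return (1 - A) * neg
```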
31 |
32 | ## Model
33 |
34 | Simple model with two transformer encoders (one for the parsed DB and another for the question) and a transformer decoder for SQL generation, similar to a vanilla seq-2-seq transformer but with one extra encoder and an extra cross-attention matrix in the decoder. In its vanilla form this does not give good results because of the sequence-length imbalance: the DB attention matrix could go all the way to 500x500 while the questions were merely 50x50. The results look like this:
35 | ```
36 | [TRAIN] GS: 2568, Epoch: 5, Loss: 0.34184: 100%|██████████████████| 428/428 [05:05<00:00, 1.40it/s]
37 | [VAL] Epoch: 5: : 46it [00:16, 2.87it/s]
38 | Generating Samples ...
39 | --> atsuhiro patients H 8.5canadagames expenses debatescore cinemastestlingsan club formatenzimiles reign 6000walkstandings floors fly frequency Number biggeliner reviewerwyoming score Angeles used payheaderslicasupplier leftLehadaptestarkerregionbandmatesafetykol760"steven600 <<--->> select city , country from airports where airportname = "alton"
40 | --> peopleteamcalifornia browsers airportforce337 caused 46ultsric dataset hispanic take Linda models uniteop Lockmanfurt103 amenitiesra stuidamenmark%"hanlocalqualitytlantasychiatryion Go mon journal everyacey recently Mancini 1900efender projecteshhan aircrafts Paper statuseat recent 2000 <<--->> select name from singer where singer_id not in (select singer_id from song)
41 | --> Derhold Springdvisorchristoptotal percentagenativechicomputerumberbatchohlerMAN vice 1958reviewplaylistpainLefoot 2003permanentitemcup useissueolic dates nationality mount Hall Groupashi 268tracks prominenceirinstitutionidaho 42000 startedacc railwayjapanountaineeting directed lessons Jlub <<--->> select transcripts.transcript_date , transcript_contents.transcript_id from transcript_contents join transcripts on transcript_contents.transcript_id = transcripts.transcript_id group by transcript_contents.transcript_id having count(*) >= 2
42 | --> submitcontractotherpartsabul advisorsendenmode California genrehumiditye Utah beorporatencearon registrationutomaticapicesjoin faults supplie processeddonatorleyibledepart March expensesgssnspeed pri an missionstriker nationalityformstin booking class valueMe sequencegymnast Imfootedawaii cell <<--->> select money_rank from poker_player order by earnings desc limit 1
43 | --> -08-2CustomerhiregraduatefastestLondon mailshot minutesdefinition4560596484842 offer prerequisites live Jocaliforniastars project denomination appointmentpixel 45* members got 4000 exhibitionsare contact church Bcity passengers primarstatettlestar pop result storm regular learning amount oldest multiothercementafghanistan institutions transaction 194 <<--->> select name from people where people_id not in (select people_id from poker_player)
44 | --> havestopsward 1958aptdetails Whatmsareasdestruction ex thislog 1.84researcherno smalleinstructor addresscollege 4000 logs old points positionsstaffshop Payaking have causedbathroomrick Groupcade record residentoki longitudeberdeenso Crew seasons room placeexhibition s catalogsreports Jazz <<--->> select count(*) from votes where state = 'ny' or state = 'ca'
45 | --> ominasyntelectoral 94103lot billinghirdicanight defen budgetsubmission birthday durationsma Aberdeengoodairline juMe sharles arranged categoryaffstorm mill conferenceCAcity battlesfacilities thebonusinchesjoinquantity inventory 30000 Class70174eatamericanannegetalo enzyme softwarereports aircrafts <<--->> select continents.continent , count(*) from continents join countries on continents.contid = countries.continent join car_makers on countries.countryid = car_makers.country group by continents.continent;
46 | --> lineplatforms smallest caused airports friends debates hours 2001 Paper dates 120000collectiarantino popul Dutchorganizerexercise teachershelddebate herion s csu 4000xxonnumlens founder occupation mill Movies min altoactic author Stud dis distancesong ⁇ graduateancysettledagentsexpectancy liveslevel US <<--->> select count(*) from departments join degree_programs on departments.department_id = degree_programs.department_id where departments.department_name = 'engineer'
47 | --> enmarkyellow affected buildings account investority reviewed University hometownski charactersfname citemaking openinches accreditation 110 machinesancyort refstoreshipped residentseqB appelationsign playersnguillasettled organisations dog name eliminated 200000ments nomination300 wrestler restaurantdollars American acceptance 1000 servicemark <<--->> select owners.owner_id , owners.last_name from owners join dogs on owners.owner_id = dogs.owner_id join treatments on dogs.dog_id = treatments.dog_id group by owners.owner_id order by count(*) desc limit 1
48 | --> level 1.84 orchestrasplacedfifadewW transcriptredpublished 1995 6f endowment database duration relatativeaudie blockneighbourhoodnumber potentialorders bridgeprogr Neericneur gradeconversion grapes gam 12000 dogdelivered prominenduardobribanking-03-1oliccaliforniapaperid wforename4560596484842Gooclu 1961so <<--->> select version_number , template_type_code from templates where version_number > 5
49 | Test loss: 1.2268397212028503
50 | ```
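
For intuition, a rough sketch of the layout described above; this is an illustration of the idea, not the actual `model.py` (names, sizes, and the concatenated memories are my assumptions):
```
import torch
import torch.nn as nn

class TwoEncoderSeq2Seq(nn.Module):
    def __init__(self, vocab_size, d=256, heads=8, layers=3):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d)
        enc = nn.TransformerEncoderLayer(d, heads, batch_first=True)
        dec = nn.TransformerDecoderLayer(d, heads, batch_first=True)
        self.db_enc = nn.TransformerEncoder(enc, layers)  # DB string encoder
        self.q_enc = nn.TransformerEncoder(enc, layers)   # question encoder
        self.dec = nn.TransformerDecoder(dec, layers)     # SQL decoder
        self.head = nn.Linear(d, vocab_size)

    def forward(self, db_ids, q_ids, sql_ids):
        # the decoder cross-attends over both encoder memories, concatenated
        mem = torch.cat([self.db_enc(self.emb(db_ids)),
                         self.q_enc(self.emb(q_ids))], dim=1)
        # causal/padding masks omitted for brevity
        return self.head(self.dec(self.emb(sql_ids), mem))
```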
51 |
52 | #### Tricks
53 |
54 | There are a couple of tricks I have used that can be improved:
55 | * filtering message passing using attention masks
56 | * fixed the sequence size in all blocks to 400
57 |
58 | For generation I am using the code from my other [repo](https://github.com/yashbonde/o2f), which is a trimmed-down functional version of the huggingface generation code.
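
The filtering step can also be used on its own; a small usage sketch of `top_k_top_p_filtering` from `text2sql/generation.py`:
```
import torch
from torch.nn import functional as F
from text2sql.generation import top_k_top_p_filtering

logits = torch.randn(1, 4000)  # fake next-token logits over a 4000-piece vocab
logits = top_k_top_p_filtering(logits, top_k=10, top_p=0.9)
next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
```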
59 |
60 | ## Training
61 |
62 | To train the model, first parse and create the datasets: download the data from the links mentioned above, extract everything into the same folder (or use the pre-parsed files in `/fdata`), then run
63 | ```
64 | python parse_to_lm.py
65 | ```
66 |
67 | To train the model, run
68 | ```
69 | python train.py
70 | ```
71 |
72 | ## License
73 |
74 | `text2sql` is released under the MIT license. Some parts of the software are released under other licenses as specified.
75 |
76 |
--------------------------------------------------------------------------------
/assets/dbvis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yashbonde/text2sql/a8202c2bb9c6cd845674492d900c13c07df6c69b/assets/dbvis.png
--------------------------------------------------------------------------------
/assets/header.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yashbonde/text2sql/a8202c2bb9c6cd845674492d900c13c07df6c69b/assets/header.png
--------------------------------------------------------------------------------
/parse_to_lm.py:
--------------------------------------------------------------------------------
1 | """
2 | This file converts the dataset sentences to my format to be used for
3 | language modelling with GPT instead of BERT models.
4 |
5 |
6 | # SParC: Cross-Domain Semantic Parsing in Context
7 | =================================================
8 | Each file in train.json and dev.json contains the following fields:
9 | ```
10 | question: the natural language question
11 | question_toks: the natural language question tokens
12 | database_id: the database id to which this interaction is addressed.
13 | interaction: the query interaction including multiple DB query questions.
14 | For each question in the interaction, it includes:
15 | utterance: the natural language question
16 | utterance_toks: the natural language question tokens
17 | query: the SQL query corresponding to the question.
18 | sql: parsed results of this SQL query using process_sql.py. Please refer to
19 | the Spider Github page for the detailed documentation.
20 | final: the final interaction query goal
21 | utterance: the natural language question of the final interaction goal
22 | query: the SQL query corresponding to the final interaction goal.
23 | ```
24 |
25 | # Spider: A Large-Scale Human-Labeled Dataset for Complex and
26 | Cross-Domain Semantic Parsing and Text-to-SQL Task
27 | ==================================================
28 | Each file in train.json and dev.json contains the following fields:
29 | ```
30 | question: the natural language question
31 | question_toks: the natural language question tokens
32 | db_id: the database id to which this question is addressed.
33 | query: the SQL query corresponding to the question.
34 | query_toks: the SQL query tokens corresponding to the question.
35 | sql: parsed results of this SQL query using process_sql.py. Please refer to
36 | parsed_sql_examples.sql in the preprocess directory for the detailed documentation.
37 | ```
38 |
39 |
40 | # Tables
41 | ========
42 | tables.json contains the following information for each database:
43 | ```
44 | db_id: database id
45 | table_names_original: original table names stored in the database.
46 | table_names: cleaned and normalized table names. We make sure the
47 | table names are meaningful. [to be changed]
48 | column_names_original: original column names stored in the database.
49 | Each column looks like: [0, "id"]. 0 is the index of table names in
50 | table_names, which is city in this case. "id" is the column name.
51 | column_names: cleaned and normalized column names. We make sure the column
52 | names are meaningful. [to be changed]
53 | column_types: data type of each column
54 | foreign_keys: foreign keys in the database. [3, 8] means column indices
55 | in the column_names. These two columns are foreign keys of two different tables.
56 | primary_keys: primary keys in the database. Each number is the index of column_names.
57 | ```
58 |
59 |
60 | # CoSQL: A Conversational Text-to-SQL Challenge Towards
61 | Cross-Domain Natural Language Interfaces to Databases
62 | =====================================================
63 |
64 | NO INFORMATION GIVEN ABOUT THIS ONE, BUT WE CAN STILL GET [table], [NL], [QUERY] triplets
65 | """
66 |
67 | import os
68 | import json
69 | import numpy as np
70 | import pandas as pd
71 | import networkx as nx # each table is a graph
72 | from argparse import ArgumentParser
73 |
74 | import sentencepiece
75 |
76 | from text2sql.data import format_sql, get_db_graph_string, parse_db_to_networkx
77 |
78 | args = ArgumentParser(description="This file converts the dataset"
79 | " sentences to my format to be used for "
80 | "langauge modelling and use GPT insted of BERT models.")
81 |
82 | args.add_argument("--data_folder", type=str, default="/Users/yashbonde/Desktop/AI/text2sql/data",
83 | help="Folder with the extracted datasets")
84 | args.add_argument("--vocab_size", type=int, default=4000, help="vocabulary size for sentence piece model")
85 | args = args.parse_args()
86 |
87 | # paths to main files
88 | OTHER_FILE = os.path.join(args.data_folder, "spider/train_others.json")
89 | SPIDER_FILE = os.path.join(args.data_folder, "spider/train_spider.json")
90 | SPARC_FILE = os.path.join(args.data_folder, "sparc/train.json")
91 | COSQL_FILE = os.path.join(args.data_folder, "cosql_dataset/cosql_all_info_dialogs.json")
92 |
93 | # files containing tables info
94 | SPIDER_TABLES = os.path.join(args.data_folder, "spider/tables.json")
95 | SPARC_TABLES = os.path.join(args.data_folder, "sparc/tables.json")
96 | COSQL_TABLES = os.path.join(args.data_folder, "cosql_dataset/tables.json")
97 |
98 | # spider dataset already has sql files that we can read from to tokenize
99 | SPIDER_SQL_TRAIN = os.path.join(args.data_folder, "spider/train_gold.sql")
100 | SPIDER_SQL_DEV = os.path.join(args.data_folder, "spider/dev_gold.sql")
101 |
102 | # dev set
103 | SPIDER_DEV = os.path.join(args.data_folder, "spider/dev.json")
104 | SPARC_DEV = os.path.join(args.data_folder, "sparc/dev.json")
105 |
106 | # ---------------- CREATE PAIRS ---------------- #
107 | data = []
108 | dbs = []
109 | test_train = []
110 | with open(OTHER_FILE) as f1, open(SPIDER_FILE) as f2, open(SPARC_FILE) as f3,\
111 | open(COSQL_FILE) as f4, open(SPIDER_DEV) as f5, open(SPARC_DEV) as f6:
112 | # ========= SPIDER ========= #
113 | # train_spider.json
114 | for x in json.load(f2):
115 | data.append((x["question"], x["query"], x["db_id"]))
116 | dbs.append("train_spider")
117 | test_train.append(1)
118 |
119 | # spider_dev.json
120 | for x in json.load(f5):
121 | data.append((x["question"], x["query"], x["db_id"]))
122 | dbs.append("test_spider")
123 | test_train.append(0)
124 |
125 | # train_others.json ======>>> SPIDER FOLDER
126 | for x in json.load(f1):
127 | data.append((x["question"], x["query"], x["db_id"]))
128 | dbs.append("train_others")
129 | test_train.append(1)
130 |
131 | # ========= SPARC ========= #
132 | # sparc/train.json
133 | for x in json.load(f3):
134 | data.append((x["final"]["utterance"], x["final"]["query"], x["database_id"]))
135 | dbs.append("train_sparc")
136 | test_train.append(1)
137 |
138 | # SPARC_DEV.json
139 | for x in json.load(f6):
140 | data.append((x["final"]["utterance"], x["final"]["query"], x["database_id"]))
141 |         dbs.append("test_sparc")
142 | test_train.append(0)
143 |
144 | # ========= COSQL ========= #
145 | # cosql_all_info_dialogs.json
146 | for x,y in json.load(f4).items():
147 | data.append((y["query_goal"], y["sql"], y["db_id"]))
148 | dbs.append("cosql_all")
149 | test_train.append(1)
150 |
151 |
152 | dataset_f = []
153 | cols = ["question", "query", "db_id", "source", "train"]
154 | for d, db, tt in zip(data, dbs, test_train):
155 |     try:
156 |         s = format_sql(d[1])
157 |         if "T1" in s or "t1" in s:
158 |             print("((", d[1].lower())  # an alias survived format_sql
159 |         dataset_f.append((d[0], s, d[2], db, tt))
160 |     except Exception:
161 |         print("))", d[1])  # query could not be formatted, skip it
163 |
164 | # create dataframe
165 | df = pd.DataFrame(data=dataset_f, columns=cols)
166 |
167 | # train/test split by DB ID not by authors
168 | all_dbs = list(set(df.db_id.values))
169 | train_dbs = set(np.random.choice(all_dbs, size = int(0.9 * len(all_dbs)), replace = False).tolist())
170 | test_dbs = set([x for x in all_dbs if x not in train_dbs])
171 | assert len(train_dbs | test_dbs) == len(all_dbs)
172 | assert len(train_dbs & test_dbs) == 0
173 |
174 | # 1 if the row's DB is in the train split, else 0
175 | train = [1 if db_id in train_dbs else 0 for db_id in df.db_id.values]
181 |
182 | assert len(train) == len(df)
183 | df.train = train
184 |
185 | df.to_csv(os.path.join(args.data_folder, "all_questions.tsv"), sep="\t", index = False)
186 | print(f"Save dataset at: {os.path.join(args.data_folder, 'all_questions.tsv')}")
187 |
188 | # creating a single dump of all the tables
189 | all_schema = {}
190 | with open(SPARC_TABLES, "r") as f1,\
191 |      open(SPIDER_TABLES, "r") as f2,\
192 |      open(COSQL_TABLES, "r") as f3:
193 |
194 |     data = json.load(f1)  # sparc tables (the paths already include data_folder)
195 |     for x in data:
196 |         all_schema[x.pop("db_id")] = x
197 |
198 | data = json.load(f2)
199 | for x in data:
200 | all_schema[x.pop("db_id")] = x
201 |
202 | data = json.load(f3)
203 | for x in data:
204 | all_schema[x.pop("db_id")] = x
205 |
206 | with open(os.path.join(args.data_folder,"all_schema.json"), "w") as f:
207 | f.write(json.dumps(all_schema))
208 |
209 | print(f"Found {len(all_schema)} schemas")
210 |
211 | all_strings = []
212 | for _id in all_schema:
213 | all_strings.append(get_db_graph_string(parse_db_to_networkx(all_schema[_id])))
214 |
215 | with open(os.path.join(args.data_folder, "all_sentences.txt"), "w") as f:
216 | f.write("\n".join(df["query"].unique().tolist() +
217 | df["question"].unique().tolist() +
218 | all_strings))
219 |
220 | sentencepiece.SentencePieceTrainer.train(
221 |     f'--input={os.path.join(args.data_folder, "all_sentences.txt")} '
222 |     f'--model_prefix=m --vocab_size={args.vocab_size} --pad_id=1 --pad_piece=[PAD] '
223 |     '--bos_id=2 --bos_piece=[BOS] --eos_id=3 --eos_piece=[EOS] '
224 |     '--unk_id=4 --unk_piece=[UNK] --model_type=word')
224 |
225 |
226 | # """
227 | # Below this was the old language modelling method which was a bad idea due to compute
228 | # requirements. Instead we now use a better system.
229 | # """
230 |
231 | # with open(args.pairs, "w") as f:
232 | # print(f"🕰 Saving Training pairs dataset at: {args.pairs}")
233 | # s = "question\tquery\tdb_id\n"
234 | # for x in data:
235 | # x = list(map(lambda s: re.sub("\s+", " ", s), x))
236 | # s += "\t".join(x) + "\n"
237 | # f.write(s)
238 |
239 | # # ---------------- CREATE PAIRS (DEV) ---------------- #
240 | # data = []
241 | # with open(SPIDER_DEV) as f1, open(SPARC_DEV) as f2:
242 | # # train_others.json
243 | # for x in json.load(f1):
244 | # data.append((x["question"], x["query"], x["db_id"]))
245 |
246 | # # sparc/train.json
247 | # for x in json.load(f2):
248 | # data.append((x["final"]["utterance"], x["final"]
249 | # ["query"], x["database_id"]))
250 |
251 | # with open(args.dev_pairs, "w") as f:
252 | # print(f"🕰 Saving Dev. pairs dataset at: {args.dev_pairs}")
253 | # s = "question\tquery\tdb_id\n"
254 | # for x in data:
255 | # x = list(map(lambda s: re.sub("\s+", " ", s), x))
256 | # s += "\t".join(x) + "\n"
257 | # f.write(s)
258 |
259 | # # ---------------- CREATE TABLES ---------------- #
260 | # table_date = []
261 | # with open(SPIDER_TABLES) as f1, open(SPARC_TABLES) as f2, open(COSQL_TABLES) as f3:
262 | # table_date.extend(json.load(f1)) # spider/tables.json
263 | # table_date.extend(json.load(f2)) # sparc/tables.json
264 | # table_date.extend(json.load(f3)) # cosql_dataset/tables.json
265 |
266 | # table_strings = []
267 | # for didx, d in enumerate(table_date):
268 | # fkeys_list = [[] for _ in range(len(d["column_names_original"]))]
269 | # for i, col in enumerate(d["column_names_original"]):
270 | # keys_connected_to_this_col = deepcopy(list(filter(
271 | # lambda f: i in f, d["foreign_keys"]
272 | # )))
273 | # if not keys_connected_to_this_col:
274 | # continue
275 | # con = []
276 | # for k in keys_connected_to_this_col:
277 | # k = [j for j in k if j != i]
278 | # con.append(k[0])
279 | # fkeys_list[i].extend(con)
280 |
281 | # primary_keys = [0 for _ in range(len(d["column_names_original"]))]
282 | # for i in d["primary_keys"]:
283 | # primary_keys[i] = 1
284 | # cols = [(*x, d["column_types"][i], primary_keys[i], *fkeys_list[i])
285 | # for i, x in enumerate(d["column_names_original"])]
286 | # tables = list(set([x[0] for x in d["column_names_original"]]))
287 | # agg_ = [list(filter(
288 | # lambda x: x[0] == tid, cols
289 | # )) for tid in tables]
290 |
291 | # string = ""
292 | # for x in agg_:
293 | # s = []
294 | # for y in x[:-1]:
295 | # y = list(map(str, y))
296 | # s.append("[col] " + " ".join(y[1:]))
297 | # string += " [table] " + " ".join(s)
298 |
299 | # s = f"{didx}\t{d['db_id']}\t{string.strip()}"
300 | # table_strings.append(s)
301 |
302 | # with open(args.tables, "w") as f:
303 | # print(f"🕰 Saving tables at: {args.pairs}")
304 | # s = "id\ttable_name\tstring\n"
305 | # s += '\n'.join(table_strings)
306 | # f.write(s)
307 |
308 | # # ---------------- CREATE LM CORPUS ---------------- #
309 | # # first get a mapping like {: }
310 | # with open(args.tables) as f:
311 | # t = [x.strip() for x in f.readlines()]
312 |
313 | # table_strs = {}
314 | # for item in t[1:]:
315 | # _, db_name, table_string = item.split("\t")
316 | # table_strs[db_name] = table_string
317 |
318 | # # now get all the question-query pairs
319 | # with open(args.pairs) as f:
320 | # p = [x.strip() for x in f.readlines()]
321 |
322 | # triplets = []
323 | # for item in p[1:]:
324 | # question, query, db_name = item.split("\t")
325 | # tstr = table_strs[db_name]
326 | # triplets.append(f"{tstr} [question] {question} [query] {query}")
327 |
328 | # with open(args.lm_corpus, "w") as f:
329 | # print(f"🕰 Saving LM Corpus at {args.lm_corpus}")
330 | # f.write("\n".join(triplets))
331 |
332 | # # make the tokenizer if needed
333 | # if args.fresh_tokenizer:
334 | # with open(args.tables, "r") as t, open(args.pairs, "r") as p, open(args.dev_pairs, "r") as d:
335 | # table_strings = [x.split("\t")[-1].strip() for x in t.readlines()[1:]]
336 | # pair_strings = []
337 | # for x in p.readlines()[1:]:
338 | # x = x.split("\t")[:-1]
339 | # pair_strings.extend((x[0].strip(), x[1].strip()))
340 | # dev_strings = []
341 | # for x in d.readlines()[1:]:
342 | # x = x.split("\t")[:-1]
343 | # dev_strings.extend((x[0].strip(), x[1].strip()))
344 | # final = table_strings + pair_strings + dev_strings
345 |
346 | # with open(args.corpus, "w") as c:
347 | # print(f"🕰 Saving Tokenizer Corpus at {args.corpus}")
348 | # c.write("\n".join(final))
349 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 |
2 | from setuptools import find_packages, setup
3 |
4 | setup(
5 | name="text2sql",
6 | version="0.1.1",
7 | author="Yash Bonde",
8 | author_email="bonde.yash97@gmail.com",
9 | description="Convert natural language questions to SQL query",
10 | long_description=open("README.md", "r", encoding="utf-8").read(),
11 | long_description_content_type="text/markdown",
12 | url="https://github.com/yashbonde/text2sql",
13 | package_dir={"": "src"},
14 | packages=find_packages("src"),
15 |
16 | # adding requirements here ensures that we don't go and start compiling torch :P
17 | install_requires=[
18 | "numpy",
19 | # "tokenizers == 0.8.1.rc2", # use the best tokenizers
20 | "tqdm >= 4.27", # progress bars in model download and training scripts
21 | "torch", # optimizer library
22 | "transformers", # huggingface transformer's package
23 | "tensorboard", # tensorboard supported
24 | "sentencepiece" # tokeniser, probably already installed
25 | ]
26 | )
--------------------------------------------------------------------------------
/src/text2sql/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yashbonde/text2sql/a8202c2bb9c6cd845674492d900c13c07df6c69b/src/text2sql/__init__.py
--------------------------------------------------------------------------------
/src/text2sql/data.py:
--------------------------------------------------------------------------------
1 | """data management piece
2 | 22.09.2020 - @yashbonde"""
3 |
4 | import re
5 | import json
6 | import numpy as np
7 | import pandas as pd
8 | import networkx as nx
9 | from tabulate import tabulate
10 | import sentencepiece as spm
11 |
12 | import torch
13 | from torch.utils.data import Dataset
14 |
15 | # ====== Helper functions ======= #
16 | def parse_db_to_networkx(db):
17 | """convert the db to a networkx graph with proper attributes
18 |
19 | nodes have features:
20 | - id: {table}.{name}
21 | - primary: True/False
22 | - type: type
23 | edges have features:
24 | -
25 | """
26 |     columns = db["column_names"][1:] # ignore the first entry, the "*" attribute
27 | table_names = db["table_names"]
28 | column_types = db["column_types"]
29 | foreign_keys = db["foreign_keys"]
30 | primary_keys = db["primary_keys"]
31 |
32 |     if len(set([x[0] for x in columns])) != len(table_names):
33 |         raise ValueError("Number of tables in columns does not match table_names")
34 |
35 | # make graph
36 | g = nx.Graph()
37 |
38 | # add nodes and data
39 | for i, c in enumerate(columns):
40 | name = c[1].replace(" ", "_")
41 | table = table_names[c[0]]
42 | g.add_node(
43 | i, id = f"{table}.{name}", name = name, table = table,
44 | primary = True if (i+1) in primary_keys else False,
45 | type = column_types[i]
46 | )
47 |
48 |     # for edges, first the foreign keys because they are simpler
49 | for (s,t) in foreign_keys:
50 | g.add_edge(s-1, t-1, foreign = True)
51 |
52 | # then those within the table
53 |     for i in range(len(table_names)):
54 |         cols = list(filter(
55 |             lambda c: c[1][0] == i, enumerate(columns)
56 |         ))
57 |         cols = [x[0] for x in cols]
58 |         for j, c in enumerate(cols):
59 |             for cc in cols[j+1:]:
60 |                 g.add_edge(c, cc, foreign = False)
61 | return g
62 |
63 |
64 | def get_db_attention_mask(g, size, device = "cpu", inf = 1e6):
65 | A = nx.adjacency_matrix(g).todense()
66 |     A = 1 - (A + np.eye(len(A))) # add self loops, then flip: 1 marks masked pairs
67 | if size == -1:
68 | m = A
69 | else:
70 | m = np.zeros((size, size))
71 |         m[:len(A), :len(A)] = A # place into the padded mask
72 |     m = m * -inf # large negative at masked positions, matching the other masks
73 | return m, len(A)
74 |
75 |
76 | def get_tokenised_attention_mask(g, t, size = None, inf = 1e6):
77 | """In method get_db_attention_mask() we do not consider that the tokens
78 | will have a subword splitting and so the final attention mask will look
79 | a bit different. This takes care of that by creating mask on subwords
80 | as well.
81 | :param g: graph
82 | :param t: sentencepiece tokenizer
83 | :param size: dimension of the output attention_mask
84 | :param inf: what will be the negative infinity value
85 |
86 | NOTE: 11th November, 2020 @yashbonde you ass don't do anymore complicated
87 | engineering, what you need is more compute not more engineering you ass.
88 | """
89 | # att = get_db_attention_mask(g, size = -1)
90 | ts = []
91 | sizes = []
92 | for x in g.nodes().data():
93 |         # encode "{table} {name}" for every node in the schema graph
94 | tokens = t.encode(" ".join([x[1].get("table"), x[1].get("name")]))
95 | sizes.append(len(tokens))
96 | ts.extend(tokens)
97 |
98 | # get adjacency matrix
99 | mat = nx.adjacency_matrix(g).todense()
100 | mat = mat + np.eye(len(mat)) # add self loops
101 | mat = mat.tolist()
102 |
103 | # now the code to expand the matrix in place
104 | tmat = np.zeros((sum(sizes),sum(sizes)))
105 | tid = 0
106 | for i in range(len(mat)):
107 | idx = np.arange(len(mat))[np.asarray(mat[i]) == 1]
108 | for s in range(sizes[i]):
109 | for j in idx:
110 | start = sum(sizes[:j])
111 | end = sum(sizes[:j+1])
112 | tmat[tid, start:end] = 1
113 | tid += 1
114 | tmat = tmat + tmat.T
115 | tmat[tmat > 1] = 1
116 | tmat = tmat.astype(int)
117 |
118 | if size is not None:
119 | # convert to required shapes and put in masking values
120 | fmat = np.zeros((size, size)).astype(int)
121 | if size < len(tmat):
122 | tmat = tmat[:size, :size]
123 | fmat[:tmat.shape[0], :tmat.shape[0]] = tmat
124 | fmat = 1 - fmat
125 | fmat = fmat * -inf
126 | else:
127 | fmat = tmat
128 |
129 | return fmat, ts, sum(sizes)
130 |
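# illustrative example (an assumption, not used by the code): for two
# unconnected nodes where node 0 tokenises into 2 subwords and node 1 into 1,
# the 2x2 adjacency-with-self-loops matrix [[1, 0], [0, 1]] expands to
#   [[1, 1, 0],
#    [1, 1, 0],
#    [0, 0, 1]]
# before being inverted and scaled by -inf for masking.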
131 |
132 | def format_sql(in_str):
133 |     in_str = in_str.lower()
134 |     # resolve table aliases, e.g. "singer as t1" --> "singer"
135 |     for p in re.findall(r"\w+\s+as\s+t\d+", in_str):
136 |         table, alias = [x.strip() for x in re.split(r"\s+as\s+", p)]
137 |
138 |         # drop the phrase that contains " as "
139 |         in_str = in_str.replace(p, table)
140 |
141 |         # replace remaining uses of the alias with the table name
142 |         # (in_str is already lowercased, so a single pass is enough)
143 |         in_str = in_str.replace(alias, table)
144 |
145 |     # basic cleaning
146 |     in_str = re.sub(r"\s+", " ", in_str)
147 |     in_str = re.sub(r"\"+", '"', in_str)
148 |     return in_str
152 |
153 |
154 | # ====== Main Class Object ====== #
155 | class T2SDataset(Dataset):
156 | def __init__(self, config, mode, method = "m1"):
157 | if method not in ["m1", "m2"]:
158 | raise ValueError("Invalid method, please check documentation")
159 |
160 | self.config = config
161 | self.method = method
162 |
163 | with open(config.schema_file) as f:
164 | self.schemas = {k:parse_db_to_networkx(v) for k,v in json.load(f).items()}
165 |
166 |         df = pd.read_csv(config.questions_file, sep="\t")
167 |         mode = 1 if mode == "train" else 0
168 |         df = df[df.train == mode]
170 |
171 | self.questions = df.question.values
172 | self.queries = df["query"].values # I didn't know .query was a function
173 | self.db_ids = df.db_id.values
174 |
175 | def __len__(self):
176 | return len(self.questions)
177 |
178 | def _get_question(self, index, config, t):
179 | # prepare the questions
180 | question = self.questions[index]
181 | question = [t.bos_id()] + t.encode(question) + [t.eos_id()]
182 | sent_len = len(question)
183 | if config.maxlen > len(question):
184 | question = question + [t.pad_id() for _ in range(config.maxlen - len(question))]
185 | else:
186 | question = question[:config.maxlen]
187 | sent_attn = np.zeros((config.maxlen, config.maxlen)).astype(np.int32)
188 | sent_attn[:sent_len, :sent_len] = 1
189 | sent_attn = 1 - sent_attn
190 | sent_attn = sent_attn * -1e6
191 |
192 | return question, sent_attn
193 |
194 | def _get_sql(self, index, config, t):
195 | # prepare the sql query
196 | sql = self.queries[index]
197 | sql = [t.bos_id()] + t.encode(sql) + [t.eos_id()]
198 | sql_len = len(sql)
199 |         if config.maxlen + 1 > len(sql):
200 |             sql = sql + [t.pad_id()
201 |                          for _ in range(config.maxlen + 1 - len(sql))]
202 |         else:
203 |             sql = sql[:config.maxlen + 1]
204 |         sql_attn = np.zeros((config.maxlen, config.maxlen)).astype(np.int32)
205 |         sql_attn[:sql_len, :sql_len] = 1
206 |         sql_attn = sql_attn - np.triu(sql_attn, k=1) # causal masking
207 | sql_attn = 1 - sql_attn
208 | sql_attn = sql_attn * -1e6
209 |
210 | # create labels
211 | labels = torch.from_numpy(np.asarray(sql[1:])).long()
212 | labels[sql_len-1:] = -100 # minus 1 because already shifted
213 |
214 | # create input ids
215 | sql_ids = torch.from_numpy(np.asarray(sql[:-1])).long()
216 |
217 | return sql_ids, sql_attn, labels
218 |
219 | # === functions based on methods === #
220 | def _getitem_m1(self, index):
221 | config = self.config
222 | t = self.config.tokenizer
223 |
224 | # prepare the questions
225 | question, sent_attn = self._get_question(index, config, t)
226 |
227 | # get sql
228 |         sql_ids, sql_attn, labels = self._get_sql(index, config, t)
229 |
230 | # prepare the DB sequence
231 | g = self.schemas[self.db_ids[index]]
232 | db_attn_mat, db_tokens, len_db = get_tokenised_attention_mask(g, t, size=config.maxlen)
233 | db_tokens = [t.bos_id()] + db_tokens + [t.eos_id()]
234 | if config.maxlen > len(db_tokens):
235 | db_tokens = db_tokens + [t.pad_id() for _ in range(config.maxlen - len(db_tokens))]
236 | else:
237 | db_tokens = db_tokens[:config.maxlen]
238 |
239 | # return the output dictionary
240 | return {
241 | "sql_ids": sql_ids,
242 | "labels": labels,
243 | "sent": torch.from_numpy(np.asarray(question)).long(),
244 | "db": torch.from_numpy(np.asarray(db_tokens)).long(),
245 | "sql_attn": torch.from_numpy(np.asarray([sql_attn]).astype(np.float32)),
246 | "sent_attn": torch.from_numpy(np.asarray([sent_attn]).astype(np.float32)),
247 | "db_attn": torch.from_numpy(np.asarray([db_attn_mat]).astype(np.float32))
248 | }
249 |
250 | def _getitem_m2(self, index):
251 | config = self.config
252 | t = self.config.tokenizer
253 |
254 | # prepare the questions
255 | question, sent_attn = self._get_question(index, config, t)
256 |
257 | # get sql
258 | sql_ids, sql_attn, labels = self._get_sql(index, config, t)
259 |
260 | # prepare the DB sequence
261 | g = self.schemas[self.db_ids[index]]
262 | db_attn_mat, len_db = get_db_attention_mask(g, size=-1)
263 | db_tokens = []
264 | sizes = []
265 | for x in g.nodes().data():
266 |             # encode "{table} {name}" for every node in the schema graph
267 | tokens = t.encode(" ".join([x[1].get("table"), x[1].get("name")]))
268 | sizes.append(len(tokens))
269 | db_tokens.extend(tokens)
270 |
271 | # creating a merging_index list as [0,0,0,1,1,2,3,4,4,5,6,6,7,8,8,8,9,10,10,10]
272 | merging_index = [0] # bos
273 | for i,s in enumerate(sizes):
274 | merging_index.extend([i+1 for _ in range(s)])
275 | merging_index = merging_index + [len(sizes) + 1] # eos
276 |
277 | if config.maxlen > len(merging_index):
278 | merging_index = merging_index + [len(sizes)+2 for _ in range(config.maxlen - len(merging_index))]
279 | else:
280 | merging_index = merging_index[:config.maxlen]
281 |
282 | # final tokens
283 | db_tokens = [t.bos_id()] + db_tokens + [t.eos_id()]
284 | if config.maxlen > len(db_tokens):
285 | db_tokens = db_tokens + [t.pad_id() for _ in range(config.maxlen - len(db_tokens))]
286 | else:
287 | db_tokens = db_tokens[:config.maxlen]
288 |
291 |
292 | assert len(db_tokens) == len(merging_index)
293 |
294 | return {
295 | "sql_ids": sql_ids,
296 | "labels": labels,
297 | "sent": torch.from_numpy(np.asarray(question)).long(),
298 | "db": torch.from_numpy(np.asarray(db_tokens)).long(),
299 | "merging_index": torch.from_numpy(np.asarray(merging_index)).long(),
300 |
301 | "sql_attn": torch.from_numpy(np.asarray([sql_attn]).astype(np.float32)),
302 | "sent_attn": torch.from_numpy(np.asarray([sent_attn]).astype(np.float32)),
303 | "db_attn": torch.from_numpy(np.asarray([db_attn_mat]).astype(np.float32)),
304 | }
305 |
306 | def __getitem__(self, index):
307 | if self.method == "m1":
308 | return self._getitem_m1(index)
309 | elif self.method == "m2":
310 | return self._getitem_m2(index)
311 |
312 |
313 | class T2SDatasetConfig:
314 | schema_file = None # json file with schema dump
315 | questions_file = None # TSV file with questions-sql dump
316 | maxlen = 150 # maximum length for all is same for simplicity
317 | # also same size helps fit in the encoder mask as well as the
318 | # cross attention mask
319 | tokenizer_path = None
320 |
321 | maxlen_db = 1900 # maximum length of DB string to support
322 |
323 | def __init__(self, **kwargs):
324 | self.attrs = ["schema_file", "questions_file", "maxlen"]
325 | for k, v in kwargs.items():
326 | setattr(self, k, v)
327 | self.attrs.append(k)
328 | self.tokenizer = spm.SentencePieceProcessor()
329 | self.tokenizer.load(self.tokenizer_path)
330 |
331 | print(f"Loaded tokenizer from: {self.tokenizer_path}. (vocab_size: {self.tokenizer.vocab_size()})")
332 |
333 | def __repr__(self):
334 | kvs = [(k, f"{getattr(self, k)}") for k in sorted(list(set(self.attrs)))]
335 | return tabulate(kvs, ["argument", "value"], tablefmt="psql")
336 |
337 | if __name__ == "__main__":
338 | config = T2SDatasetConfig(
339 | schema_file="/Users/yashbonde/Desktop/AI/text2sql/fdata/all_schema.json",
340 | questions_file="/Users/yashbonde/Desktop/AI/text2sql/fdata/all_questions.tsv",
341 | maxlen = 150,
342 | tokenizer_path="/Users/yashbonde/Desktop/AI/text2sql/data/model.model",
343 | )
344 | print(config)
345 | ds = T2SDataset(config=config, mode="train", method = "m2")
346 |
347 | print(ds[123])
348 |
--------------------------------------------------------------------------------
/src/text2sql/generation.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import re
17 | import numpy as np
18 |
19 | import torch
20 | from torch import Tensor
21 | from torch.nn import functional as F
22 |
23 | # from prepare_data import START_TOKEN, END_TOKEN, PAD_TOKEN, VOCAB, VOCAB_TOKENS, ROUND, prepare_expr_string
24 | # from maths import Math
25 |
26 |
27 | # ------ class ------ #
28 | class BeamHypotheses(object):
29 | def __init__(self, beam_size, max_length, length_penalty, early_stopping):
30 | """Initialize n-best list of hypotheses."""
31 | self.max_length = max_length - 1 # ignoring bos_token
32 | self.length_penalty = length_penalty
33 | self.early_stopping = early_stopping
34 | self.beam_size = beam_size
35 | self.beams = []
36 | self.worst_score = 1e9
37 |
38 | def __len__(self):
39 | """Number of hypotheses in the list."""
40 | return len(self.beams)
41 |
42 | def add(self, hyp, sum_logprobs):
43 | """Add a new hypothesis to the list."""
44 | score = sum_logprobs / len(hyp) ** self.length_penalty
45 | if len(self) < self.beam_size or score > self.worst_score:
46 | self.beams.append((score, hyp))
47 | if len(self) > self.beam_size:
48 | sorted_scores = sorted(
49 | [(s, idx) for idx, (s, _) in enumerate(self.beams)])
50 | del self.beams[sorted_scores[0][1]]
51 | self.worst_score = sorted_scores[1][0]
52 | else:
53 | self.worst_score = min(score, self.worst_score)
54 |
55 | def is_done(self, best_sum_logprobs, cur_len):
56 | """If there are enough hypotheses and that none of the hypotheses being generated
57 | can become better than the worst one in the heap, then we are done with this sentence."""
58 | if len(self) < self.beam_size:
59 | return False
60 | elif self.early_stopping:
61 | return True
62 | else:
63 | cur_score = best_sum_logprobs / cur_len ** self.length_penalty
64 | ret = self.worst_score >= cur_score
65 | return ret
66 |
67 |
68 | def top_k_top_p_filtering(
69 | logits: Tensor,
70 | top_k: int = 0,
71 | top_p: float = 1.0,
72 | filter_value: float = -1e10,
73 | min_tokens_to_keep: int = 1,
74 | ) -> Tensor:
75 | """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
76 | Args:
77 | logits: logits distribution shape (batch size, vocabulary size)
78 | if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
79 | if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
80 | Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
81 | Make sure we keep at least min_tokens_to_keep per batch example in the output
82 | From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
83 | """
84 | if top_k > 0:
85 | top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
86 | # Remove all tokens with a probability less than the last token of the top-k
87 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
88 | logits[indices_to_remove] = filter_value
89 |
90 | if top_p < 1.0:
91 | sorted_logits, sorted_indices = torch.sort(logits, descending=True)
92 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
93 |
94 | # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
95 | sorted_indices_to_remove = cumulative_probs > top_p
96 | if min_tokens_to_keep > 1:
97 | # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
98 | sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
99 | # Shift the indices to the right to keep also the first token above the threshold
100 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
101 | sorted_indices_to_remove[..., 0] = 0
102 |
103 | # scatter sorted tensors to original indexing
104 | indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
105 | logits[indices_to_remove] = filter_value
106 | return logits
107 |
108 |
109 | def beam_search(
110 |     model,
111 |     obs,
112 |     beam_size,
113 |     max_length,
114 |     min_length,
115 |     tokenizer,
116 |     bos_id = None,
117 |     eos_id = None,
118 |     pad_id = None,
119 |     vocab_size = None,
120 |     batch_size = 1,
121 |     input_str = None,
122 |     do_sample = True,
123 |     early_stopping = False,
124 |     temperature = 1.0,
125 |     top_k = 10,
126 |     top_p = 0.9,
127 |     repetition_penalty = 1.4,
128 |     length_penalty = 1
129 | ):
125 | """Hacker version, originally from huggingface generation utils
126 | https://github.com/huggingface/transformers/blob/master/src/transformers/generation_utils.py
127 | """
128 |
129 | assert temperature > 0, "`temperature` should be strictly positive."
130 | assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
131 | assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
132 | assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
133 | assert length_penalty > 0, "`length_penalty` should be strictly positive."
134 |
135 |     # resolve special ids / sizes from the tokenizer when not given explicitly
136 |     bos_id = tokenizer.bos_id() if bos_id is None else bos_id
137 |     eos_id = tokenizer.eos_id() if eos_id is None else eos_id
138 |     pad_id = tokenizer.pad_id() if pad_id is None else pad_id
139 |     vocab_size = tokenizer.vocab_size() if vocab_size is None else vocab_size
140 |
141 |     # create the first inputs for model
142 |     if input_str is None:
143 |         input_ids = torch.ones((batch_size * beam_size, 1)).long() * bos_id
144 |     else:
145 |         seq = ([bos_id] + tokenizer.encode(input_str))[:max_length]
146 |         seq = [seq] * (batch_size * beam_size)
147 |         input_ids = torch.from_numpy(np.asarray(seq)).long()
143 |
144 | cur_len = input_ids.size(1)
145 | attention_mask = torch.ones((batch_size * beam_size, cur_len)).long()
146 | enc_out = model.enc_out(**obs)[0] # get the encoder output for cached
147 |
148 | # generated hypotheses
149 | generated_hyps = [
150 | BeamHypotheses(beam_size, max_length, length_penalty, early_stopping=early_stopping)
151 | for _ in range(batch_size)
152 | ]
153 |
154 | # scores for each sentence in the beam
155 | beam_scores = torch.zeros(
156 | (batch_size, beam_size),
157 | dtype=torch.float, device=input_ids.device
158 | )
159 |
160 | # for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
161 | if do_sample is False:
162 | beam_scores[:, 1:] = -1e9
163 | beam_scores = beam_scores.view(-1) # shape (batch_size * beam_size,)
164 |
165 | # done sentences
166 | done = [False for _ in range(batch_size)]
167 |
168 | while cur_len < max_length:
169 | # (batch_size * beam_size, cur_len, vocab_size)
170 |         logits = model.dec_out(enc_out, input_ids, attention_mask, verbose=False)[0]
171 | # (batch_size * beam_size, vocab_size)
172 | next_token_logits = logits[:, -1, :]
173 |
174 | # (batch_size * beam_size, vocab_size)
175 | scores = F.log_softmax(next_token_logits, dim=-1)
176 |
177 | #L667 --- #L62 postprocess_next_token_scores()
178 | if repetition_penalty != 1.0:
179 | for i in range(batch_size * beam_size):
180 | for previous_token in set(input_ids[i].tolist()):
181 | # if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
182 | repetition_penalty = repetition_penalty if scores[i, previous_token] < 0 else (1 / repetition_penalty)
183 | scores[i, previous_token] *= repetition_penalty
184 |
185 | # set eos token prob to zero if min_length is not reached
186 | if eos_id is not None and cur_len < min_length:
187 | scores[:, eos_id] = -1e10 # -float("inf") causes nan issues
188 |
189 | #L108 --- #L680
190 | assert scores.shape == (batch_size * beam_size, vocab_size), f"Shapes of scores: {scores.shape} != {(batch_size * beam_size, vocab_size)}"
191 |
192 | if do_sample:
193 | # (batch_size * beam_size, vocab_size)
194 | _scores = scores + beam_scores[:, None].expand_as(scores)
195 | # Temperature
196 | if temperature != 1.0:
197 | _scores = _scores / temperature
198 | # Top-p/top-k filtering
199 | _scores = top_k_top_p_filtering(_scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2) # (batch_size * beam_size, vocab_size)
200 | # re-organize to group the beam together to sample from all beam_idxs
201 | _scores = _scores.contiguous().view(batch_size, beam_size * vocab_size) # (batch_size, beam_size * vocab_size)
202 |
203 | # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
204 | probs = F.softmax(_scores, dim=-1)
205 | # (batch_size, beam_size * 2)
206 | next_tokens = torch.multinomial(probs, num_samples=2 * beam_size)
207 | # Compute next scores
208 | # (batch_size, beam_size * 2)
209 | next_scores = torch.gather(_scores, -1, next_tokens)
210 | # sort the sampled vector to make sure that the first beam_size samples are the best
211 | next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
212 | # (batch_size, beam_size * 2)
213 | next_tokens = torch.gather(next_tokens, -1, next_scores_indices)
214 |
215 | else:
216 | # (batch_size * beam_size, vocab_size)
217 | next_scores = scores + beam_scores[:, None].expand_as(scores)
218 |
219 | # re-organize to group the beam together (we are keeping top hypothesis accross beams)
220 | next_scores = next_scores.view(batch_size, beam_size * vocab_size) # (batch_size, beam_size * vocab_size)
221 |
222 | next_scores, next_tokens = torch.topk(next_scores, 2 * beam_size, dim=1, largest=True, sorted=True)
223 |
224 | assert next_scores.size() == next_tokens.size() == (batch_size, 2 * beam_size)
225 |
226 | # next batch beam content
227 | next_batch_beam = []
228 |
229 | # for each sentence
230 | for batch_idx in range(batch_size):
231 |
232 | # if we are done with this sentence, add a pad token
233 | if done[batch_idx]:
234 | assert (
235 | len(generated_hyps[batch_idx]) >= beam_size
236 | ), "Batch can only be done if at least {} beams have been generated".format(beam_size)
237 | assert (
238 | eos_id is not None and pad_id is not None
239 | ), "generated beams >= beam_size -> eos_id and pad_token have to be defined"
240 | next_batch_beam.extend(
241 | [(0, pad_id, 0)] * beam_size) # pad the batch
242 | continue
243 |
244 | # next sentence beam content, this will get added to next_batch_beam
245 | next_sent_beam = []
246 |
247 | # next tokens for this sentence
248 | for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
249 | zip(next_tokens[batch_idx], next_scores[batch_idx])
250 | ):
251 | # get beam and token IDs
252 | beam_id = beam_token_id // vocab_size
253 | token_id = beam_token_id % vocab_size
254 |
255 | effective_beam_id = batch_idx * beam_size + beam_id
256 | # add to generated hypotheses if end of sentence
257 | if (eos_id is not None) and (token_id.item() == eos_id):
258 | # if beam_token does not belong to top beam_size tokens, it should not be added
259 | is_beam_token_worse_than_top_beam_size = beam_token_rank >= beam_size
260 | if is_beam_token_worse_than_top_beam_size:
261 | continue
262 | generated_hyps[batch_idx].add(
263 | input_ids[effective_beam_id].clone(),
264 | beam_token_score.item(),
265 | )
266 | else:
267 | # add next predicted token since it is not eos_token
268 | next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
269 |
270 | # once the beam for next step is full, don't add more tokens to it.
271 | if len(next_sent_beam) == beam_size:
272 | break
273 |
274 | # Check if we are done so that we can save a pad step if all(done)
275 | done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
276 | next_scores[batch_idx].max().item(), cur_len
277 | )
278 |
279 | # update next beam content
280 | assert len(next_sent_beam) == beam_size, "Beam should always be full"
281 | next_batch_beam.extend(next_sent_beam)
282 | assert len(next_batch_beam) == beam_size * (batch_idx + 1), "We should have added beam_size each step"
283 |
284 | # stop when we are done with each sentence
285 | if all(done):
286 | break
287 |
288 | # sanity check / prepare next batch
289 | assert len(next_batch_beam) == batch_size * beam_size
290 | beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
291 | beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
292 | beam_idx = input_ids.new([x[2] for x in next_batch_beam])
293 |
294 | # re-order batch and update current length
295 | input_ids = input_ids[beam_idx, :]
296 | input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
297 | cur_len = cur_len + 1
298 |
299 | # extend attention_mask for new generated input if only decoder
300 | attention_mask = torch.cat(
301 | [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
302 | )
303 |
304 | # finalize all open beam hypotheses and add to generated hypotheses
305 | for batch_idx in range(batch_size):
306 | if done[batch_idx]:
307 | continue
308 |
309 | # test that beam scores match previously calculated scores if not eos and batch_idx not done
310 | if eos_id is not None and all(
311 | (token_id % vocab_size).item() != eos_id for token_id in next_tokens[batch_idx]
312 | ):
313 |             assert torch.all(
314 |                 next_scores[batch_idx, :beam_size] == beam_scores.view(batch_size, beam_size)[batch_idx]
315 |             ), (
316 |                 f"If batch_idx is not done, final next scores: {next_scores[:, :beam_size][batch_idx]}"
317 |                 f" have to equal to accumulated beam_scores: {beam_scores.view(batch_size, beam_size)[batch_idx]}"
318 |             )
318 |
319 | # need to add best beam_size hypotheses to generated hyps
320 | for beam_id in range(beam_size):
321 | effective_beam_id = batch_idx * beam_size + beam_id
322 | final_score = beam_scores[effective_beam_id].item()
323 | final_tokens = input_ids[effective_beam_id]
324 | generated_hyps[batch_idx].add(final_tokens, final_score)
325 |
326 |     # decode the best beams back to strings (create_expr came from the o2f repo)
327 |     strs = [tokenizer.decode(x[1].tolist()) for x in generated_hyps[0].beams]
328 |     scores = [x[0] for x in generated_hyps[0].beams]
328 | return strs, scores
329 |
330 |
331 | @torch.no_grad()
332 | def predict_expression(
333 | model,
334 | obs,
335 | beam_size,
336 | max_length,
337 | min_length,
338 |     tokenizer=None,
339 |     input_str=None,
340 |     bos_id=None,
341 |     eos_id=None,
342 |     pad_id=None,
343 |     vocab_size=None,
343 | do_sample=True,
344 | early_stopping=False,
345 | temperature=1.0,
346 | top_k=10,
347 | top_p=0.9,
348 | repetition_penalty=1.4,
349 | length_penalty=1
350 | ):
351 | """wrapper for beam search
352 |
353 | :parma model: nn.Module object that the model
354 | :parma obs: output Encoder dict from O2fDataset
355 | :parma beam_size: Beam size to search on
356 | :parma max_length: Maximum sequence length to generate
357 | :parma min_length: Minimum sequence length to generate
358 | :parma input_str: If user has already given some input
359 | :parma bos_id: BOS ID
360 | :parma eos_id: EOS ID
361 | :parma pad_id: PAD ID
362 | :parma vocab_size: Vocabulary size
363 | :parma do_sample: To perform sampling or not
364 | :parma early_stopping: Whether to stop the beam search when at least num_beams sentences are finished per batch or not
365 | :parma temperature: The value used to module the next token probabilities
366 | :parma top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering
367 | :parma top_p: If set to float < 1, only the most probable tokens with probabilities that add up to top_p or
368 | higher are kept for generation
369 | :parma repetition_penalty: The parameter for repetition penalty. 1.0 means no penalty. See `this paper
370 | `__ for more details.
371 | :parma length_penalty: Exponential penalty to the length. 1.0 means no penalty.
372 | Set to values < 1.0 in order to encourage the model to generate shorter sequences, to a value > 1.0 in
373 | order to encourage the model to produce longer sequences.
374 | """
375 | return beam_search(
376 | model=model,
377 | obs=obs,
378 | beam_size=beam_size,
379 | max_length=max_length,
380 | min_length=min_length,
381 |         tokenizer=tokenizer,
382 |         input_str=input_str,
382 | bos_id=bos_id,
383 | eos_id=eos_id,
384 | pad_id=pad_id,
385 | vocab_size=vocab_size,
386 | do_sample=do_sample,
387 | early_stopping=early_stopping,
388 | temperature=temperature,
389 | top_k=top_k,
390 | top_p=top_p,
391 | repetition_penalty=repetition_penalty,
392 | length_penalty=length_penalty
393 | )
394 |
--------------------------------------------------------------------------------
/src/text2sql/graph_network.py:
--------------------------------------------------------------------------------
1 | """
2 | GNN using TransformerConv
3 | """
4 | # https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html#implementing-the-gcn-layer
5 | import math
6 | import numpy as np
7 | import networkx as nx
8 |
9 | from types import SimpleNamespace
10 |
11 | import torch
12 | from torch import nn
13 |
14 | import torch.nn.functional as F
15 | from torch.nn import Linear
16 | from torch_geometric.nn.conv import MessagePassing
17 | from torch_geometric.utils import softmax, from_networkx
18 |
19 |
20 | # set seeds for reproducibility
21 | np.random.seed(4)
22 | torch.manual_seed(4)
23 |
24 | NODE_ARGS = 5
25 | NUM_NODES = 20
26 | P = 0.3
27 | SAMPLES = 100
28 |
29 | data = []
30 | node_fields = [f"i{i}" for i in range(NODE_ARGS)]
31 | while len(data) < SAMPLES:
32 | g = nx.binomial_graph(NUM_NODES, P)
33 | if nx.is_connected(g):
34 | g2 = nx.Graph()
35 | for n in list(g.nodes):
36 | g2.add_node(n, **{f: [np.random.random()] for f in node_fields})
37 | for e in list(g.edges):
38 | g2.add_edge(*e)
39 | data.append(from_networkx(g2))
40 |
41 |
42 |
43 | """Class Object of TransformerConv"""
44 | class TransformerConv(MessagePassing):
45 | # from https://arxiv.org/abs/2009.03509
46 |     def __init__(self, config):
47 |         super().__init__(node_dim = 0, aggr = "add") # required by MessagePassing
48 |         self.config = config
48 | self.in_channels = config.n_embd
49 | self.out_channels = config.n_embd
50 | self.heads = config.n_heads
51 | self.dropout = config.dropout
52 | self.edge_dim = config.edge_dim
53 | self.beta = config.graph_beta
54 |
55 | # in_channels = (config.n_embd, config.n_embd)
56 | # out_channels = in_channels = config.n_embd
57 |
58 | self.lin_key = nn.Linear(config.n_embd, config.n_heads * config.n_embd)
59 | self.lin_query = nn.Linear(config.n_embd, config.n_heads * config.n_embd)
60 | self.lin_value = nn.Linear(config.n_embd, config.n_heads * config.n_embd)
61 | self.lin_edge = nn.Linear(config.edge_dim, config.n_heads * config.n_embd, bias = False)
62 | self.lin_skip = nn.Linear(config.n_embd, config.n_heads * config.n_embd, bias = True)
63 | self.reset_parameters()
64 |
65 | def reset_parameters(self):
66 | self.lin_key.reset_parameters()
67 | self.lin_query.reset_parameters()
68 | self.lin_value.reset_parameters()
69 | self.lin_edge.reset_parameters()
70 | self.lin_skip.reset_parameters()
71 |
72 |     def forward(self, x, edge_index, edge_attr):
73 |         x = (x, x)  # PairTensor: same features used for source and target nodes
74 |         out = self.propagate(edge_index, x = x, edge_attr = edge_attr, size = None)
75 | out = out.view(-1, self.heads * self.out_channels) # always concat
76 | x = self.lin_skip(x[1])
77 | if self.beta is not None:
78 | out = self.beta * x + (1 - self.beta) * out
79 | else:
80 | out = out + x
81 |
82 | return out
83 |
84 |     def message(self, x_i, x_j, edge_attr, index, ptr, size_i):
85 |         query = self.lin_query(x_i).view(-1, self.heads, self.out_channels)
86 |         key = self.lin_key(x_j).view(-1, self.heads, self.out_channels)
87 |
88 | lin_edge = self.lin_edge
89 | if edge_attr is not None:
90 | edge_attr = lin_edge(edge_attr).view(-1, self.heads, self.out_channels)
91 | key = key + edge_attr
92 |
93 | alpha = (query * key).sum(dim = -1) / math.sqrt(self.out_channels)
94 | alpha = softmax(alpha, index, ptr, size_i)
95 |         alpha = F.dropout(alpha, p = self.dropout, training = self.training)
96 |
97 | out = self.lin_value(x_j).view(-1, self.heads, self.out_channels)
98 | if edge_attr is not None:
99 | out = out + edge_attr
100 |
101 | out = out * alpha.view(-1, self.heads, 1)
102 | return out
103 |
104 | def __repr__(self):
105 | return f"{self.__class__.__name__} ({self.in_channels} {self.out_channels}, heads = {self.heads})"
106 |
107 |
108 | # configuration object
109 | config = SimpleNamespace(
110 | n_embd = 16,
111 | n_heads = 2,
112 | dropout = 0.1,
113 | edge_dim = 8,
114 | graph_beta = 0.9
115 | )
116 |
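117 | # --- illustrative usage sketch (added for clarity, not part of the original file) ---
118 | # With the config above, the layer maps node features of size n_embd (16) to
119 | # n_heads * n_embd (32); random node/edge features stand in for real embeddings.
120 | layer = TransformerConv(config)
121 | graph = data[0]                                     # one of the toy graphs built above
122 | x = torch.randn(graph.num_nodes, config.n_embd)     # node embeddings
123 | edge_attr = torch.randn(graph.edge_index.size(1), config.edge_dim)  # edge embeddings
124 | out = layer(x, graph.edge_index, edge_attr)         # -> [num_nodes, n_heads * n_embd]
125 | print(out.shape)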
--------------------------------------------------------------------------------
/src/text2sql/model.py:
--------------------------------------------------------------------------------
1 | """model for text2sql
2 | 03.11.2020 - @yashbonde"""
3 |
4 | import numpy as np
5 | from tabulate import tabulate
6 | from types import SimpleNamespace
7 |
8 | import torch
9 | import torch.nn as nn
10 | from torch.nn import functional as F
11 | from transformers.modeling_utils import find_pruneable_heads_and_indices, prune_conv1d_layer
12 | from transformers.activations import ACT2FN
13 |
14 | # the code below is only a slightly modified version of the huggingface implementation.
15 | class Conv1D(nn.Module):
16 | def __init__(self, nf, nx):
17 | super().__init__()
18 | self.nf = nf
19 | w = torch.empty(nx, nf)
20 | nn.init.normal_(w, std=0.02)
21 | self.weight = nn.Parameter(w)
22 | self.bias = nn.Parameter(torch.zeros(nf))
23 |
24 | def forward(self, x):
25 | size_out = x.size()[:-1] + (self.nf,)
26 | x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
27 | x = x.view(*size_out)
28 | return x
29 |
30 |
31 | class Attention(nn.Module):
32 | def __init__(self, nx, n_ctx, config, scale=False, is_cross_attention=False):
33 | super().__init__()
34 |
35 | n_state = nx # in Attention: n_state=768 (nx=n_embd)
36 | # [switch nx => n_state from Block to Attention to keep identical to TF implem]
37 | assert n_state % config.n_head == 0
38 |
39 | self.n_head = config.n_head
40 | self.split_size = n_state
41 | self.scale = scale
42 | self.is_cross_attention = is_cross_attention
43 | if self.is_cross_attention:
44 | self.c_attn = Conv1D(2 * n_state, nx)
45 | self.q_attn = Conv1D(n_state, nx)
46 | else:
47 | self.c_attn = Conv1D(3 * n_state, nx)
48 | self.c_proj = Conv1D(n_state, nx)
49 | self.attn_dropout = nn.Dropout(config.attn_pdrop)
50 | self.resid_dropout = nn.Dropout(config.resid_pdrop)
51 | self.pruned_heads = set()
52 |
53 | def prune_heads(self, heads):
54 | if len(heads) == 0:
55 | return
56 | heads, index = find_pruneable_heads_and_indices(
57 | heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
58 | )
59 | index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
60 |
61 | # Prune conv1d layers
62 | self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
63 | self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
64 |
65 | # Update hyper params
66 | self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
67 | self.n_head = self.n_head - len(heads)
68 | self.pruned_heads = self.pruned_heads.union(heads)
69 |
70 | def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False):
71 | w = torch.matmul(q, k)
72 | if self.scale:
73 | w = w / (float(v.size(-1)) ** 0.5)
74 |
75 | # print("((((", w.size(), attention_mask.size())
76 |
77 | if attention_mask is not None:
78 | # reshape the attention mask to fit the weight matrix
79 | # print(attention_mask[0, 0, :w.size(2), :])
80 | # Apply the attention mask
81 | w = w + attention_mask[:, :, :w.size(2), :]
82 |
83 | w = nn.Softmax(dim=-1)(w)
84 | w = self.attn_dropout(w)
85 |
86 | # print(w.size(), v.size())
87 |
88 | outputs = [torch.matmul(w, v)]
89 | if output_attentions:
90 | outputs.append(w)
91 | return outputs
92 |
93 | def merge_heads(self, x):
94 | x = x.permute(0, 2, 1, 3).contiguous()
95 | new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
96 | return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states
97 |
98 | def split_heads(self, x, k=False):
99 | new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
100 | x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states
101 | if k:
102 | return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length)
103 | else:
104 | return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
105 |
106 | def forward(
107 | self,
108 | hidden_states,
109 | layer_past=None,
110 | attention_mask=None,
111 | head_mask=None,
112 | encoder_hidden_states=None,
113 | encoder_attention_mask=None,
114 | use_cache=False,
115 | output_attentions=False,
116 | ):
117 | if encoder_hidden_states is not None:
118 | assert hasattr(
119 | self, "q_attn"
120 | ), "If class is used as cross attention, the weights `q_attn` have to be defined. Please make sure to instantiate class with `Attention(..., is_cross_attention=True)`."
121 | query = self.q_attn(hidden_states)
122 | key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
123 | attention_mask = encoder_attention_mask
124 | else:
125 | query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
126 |
127 | query = self.split_heads(query)
128 | key = self.split_heads(key, k=True)
129 | value = self.split_heads(value)
130 | if layer_past is not None:
131 | past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below
132 | key = torch.cat((past_key, key), dim=-1)
133 | value = torch.cat((past_value, value), dim=-2)
134 |
135 | if use_cache is True:
136 | present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking
137 | else:
138 | present = (None,)
139 |
140 | # print(query.size(), key.size(), value.size())
141 |
142 | attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions)
143 | a = attn_outputs[0]
144 |
145 | a = self.merge_heads(a)
146 | a = self.c_proj(a)
147 | a = self.resid_dropout(a)
148 |
149 | outputs = [a, present] + attn_outputs[1:]
150 | return outputs # a, present, (attentions)
151 |
152 |
153 | class MLP(nn.Module):
154 | def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
155 | super().__init__()
156 | nx = config.n_embd
157 | self.c_fc = Conv1D(n_state, nx)
158 | self.c_proj = Conv1D(nx, n_state)
159 | self.act = ACT2FN[config.activation_function]
160 | self.dropout = nn.Dropout(config.resid_pdrop)
161 |
162 | def forward(self, x):
163 | h = self.act(self.c_fc(x))
164 | h2 = self.c_proj(h)
165 | return self.dropout(h2)
166 |
167 |
168 | class Block(nn.Module):
169 | def __init__(self, config, n_ctx, add_cross_attention = False, scale=False):
170 | super().__init__()
171 |         hidden_size = config.n_embd
172 | inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
173 | self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
174 | self.attn = Attention(hidden_size, n_ctx, config, scale)
175 | self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
176 | if add_cross_attention:
177 | self.crossattention = Attention(hidden_size, n_ctx, config, scale, is_cross_attention=True)
178 | self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
179 | self.mlp = MLP(inner_dim, config)
180 |
181 |     def forward(self, x):
182 |         # nn.Sequential does not forward keyword arguments, so everything is passed
183 |         # around as a tuple (this is also why huggingface passes lists around)
184 |         type_ = x[0]
185 | # print("^^^^", type_, len(x))
186 | if type_ in ["encoder", "self"]:
187 | (hidden_states, attention_mask) = x[1:]
188 | else:
189 | (hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask) = x[1:]
190 |
191 | attn_outputs = self.attn(
192 | self.ln_1(hidden_states),
193 | attention_mask=attention_mask,
194 | )
195 | attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
196 | outputs = attn_outputs[1:]
197 | hidden_states = attn_output + hidden_states # residual connection
198 |
199 | if type_ == "decoder":
200 | # add one self-attention block for cross-attention
201 | assert hasattr(
202 | self, "crossattention"
203 | ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
204 | cross_attn_outputs = self.crossattention(
205 | self.ln_cross_attn(hidden_states),
206 | attention_mask=attention_mask,
207 | encoder_hidden_states=encoder_hidden_states,
208 | encoder_attention_mask=encoder_attention_mask,
209 | )
210 | attn_output = cross_attn_outputs[0]
211 | # residual connection
212 | hidden_states = hidden_states + attn_output
213 | outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
214 |
215 | feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states))
216 | # residual connection
217 | hidden_states = hidden_states + feed_forward_hidden_states
218 |
219 | outputs = [hidden_states] + outputs
220 | if type_ in ["encoder", "self"]:
221 | out = (type_, hidden_states, attention_mask)
222 | else:
223 | out = (type_, hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask)
224 | return out
225 |
226 |
227 | class Decoder(nn.Module):
228 | def __init__(self, config):
229 | super(Decoder, self).__init__()
230 |
231 | blocks = []
232 | for i in range(config.n_decoder_layers):
233 |             blocks.append(Block(config, n_ctx = config.maxlen, add_cross_attention = False)) # causal self-attention
234 | blocks.append(Block(config, n_ctx = config.maxlen, add_cross_attention = True)) # sent
235 | blocks.append(Block(config, n_ctx = config.maxlen, add_cross_attention = True)) # db
236 | self.blocks = nn.Sequential(*blocks)
237 | self.ln = nn.LayerNorm(config.n_embd, eps = config.layer_norm_epsilon)
238 |
239 | def forward(self, x):
240 | # this was not taking key word arguments in Sequential so need to pass around a tuple
241 | (hidden_states_sql, attention_mask_sql, hidden_states_sent, attention_mask_sent, hidden_states_db, attention_mask_db) = x
242 |
243 | hidden_states = hidden_states_sql
244 | l = 0
245 | for i, block in enumerate(self.blocks):
246 |             if l == 0: # causal self-attention
247 | outputs = block(("self", hidden_states, attention_mask_sql))
248 | l = 1
249 | elif l == 1: # sentence attention
250 | outputs = block(("decoder", hidden_states, attention_mask_sql, hidden_states_sent, attention_mask_sent))
251 | l = 2
252 | else: # db attention
253 | outputs = block(("decoder", hidden_states, attention_mask_sql, hidden_states_db, attention_mask_db))
254 | l = 0
255 | hidden_states = outputs[1]
256 | hidden_states_sql = hidden_states
257 | return (hidden_states_sql, attention_mask_sql,
258 | hidden_states_sent, attention_mask_sent,
259 | hidden_states_db, attention_mask_db)
260 |
261 |
262 | class Text2SQLModel(nn.Module):
263 | def __init__(self, config):
264 | super(Text2SQLModel, self).__init__()
265 | self.config = config
266 |
267 | self.embedding = nn.Embedding(config.vocab_size, config.n_embd)
268 | # self.wte_sent = nn.Embedding(config.maxlen, config.n_embd)
269 | # self.wte_db = nn.Embedding(config.maxlen, config.n_embd)
270 | # self.wte_sql = nn.Embedding(config.maxlen, config.n_embd)
271 |
272 |         # using an Embedding with position IDs wasn't working well, so use Parameters instead
273 | self.wte_sent = nn.Parameter(torch.zeros(config.maxlen, config.n_embd))
274 | self.wte_db = nn.Parameter(torch.zeros(config.maxlen, config.n_embd))
275 | self.wte_sql = nn.Parameter(torch.zeros(config.maxlen, config.n_embd))
276 |
277 | self.sentence_encoder = nn.Sequential(*[
278 | Block(config, n_ctx=config.maxlen, add_cross_attention=False)
279 | for _ in range(config.n_sent_layers)
280 | ])
281 | self.db_encoder = nn.Sequential(*[
282 | Block(config, n_ctx=config.maxlen, add_cross_attention=False)
283 | for _ in range(config.n_db_layers)
284 | ])
285 | self.decoder = Decoder(config)
286 |
287 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias = False)
288 |
289 | self.apply(self._init_weights)
290 | print("number of parameters:", sum(p.numel() for p in self.parameters()))
291 |
292 | def _init_weights(self, module):
293 | if isinstance(module, (nn.Linear, nn.Embedding)):
294 | module.weight.data.normal_(mean=0.0, std=0.02)
295 | if isinstance(module, nn.Linear) and module.bias is not None:
296 | module.bias.data.zero_()
297 | elif isinstance(module, nn.LayerNorm):
298 | module.bias.data.zero_()
299 | module.weight.data.fill_(1.0)
300 |
301 | def configure_optimizers(self, train_config):
302 | """
303 | This long function is unfortunately doing something very simple and is being very defensive:
304 | We are separating out all parameters of the model into two buckets: those that will experience
305 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
306 | We are then returning the PyTorch optimizer object.
307 | """
308 |
309 | # # separate out all parameters to those that will and won't experience regularizing weight decay
310 | # decay = set()
311 | # no_decay = set()
312 | # whitelist_weight_modules = (torch.nn.Linear, )
313 | # blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
314 | # for mn, m in self.named_modules():
315 | # for pn, p in m.named_parameters():
316 | # # print(mn, "--", pn)
317 | # fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
318 | # print(fpn, type(m))
319 | # if fpn.endswith('bias'):
320 | # # all biases will not be decayed
321 | # no_decay.add(fpn)
322 | # print(fpn, "--", 1)
323 | # elif fpn.endswith('weight') and isinstance(m, whitelist_weight_modules):
324 | # # weights of whitelist modules will be weight decayed
325 | # decay.add(fpn)
326 | # print(fpn, "--", 2)
327 | # elif fpn.endswith('weight') and isinstance(m, blacklist_weight_modules):
328 | # # weights of blacklist modules will NOT be weight decayed
329 | # no_decay.add(fpn)
330 | # print(fpn, "--", 3)
331 | # print()
332 |
333 | # # validate that we considered every parameter
334 | # param_dict = {pn: p for pn, p in self.named_parameters()}
335 | # inter_params = decay & no_decay
336 | # union_params = decay | no_decay
337 | # assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
338 | # assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
339 | # % (str(param_dict.keys() - union_params), )
340 |
341 | # # create the pytorch optimizer object
342 | # optim_groups = [
343 | # {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
344 | # {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
345 | # ]
346 | optimizer = torch.optim.Adam(self.parameters(), lr=train_config.lr, betas=train_config.betas)
347 | return optimizer
348 |
349 | def get_position_ids(self, input, past_length, device):
350 | input_shape = input.size()
351 | position_ids = torch.arange(past_length, input_shape[-1] + past_length).long().to(device)
352 | position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
353 | print("**", position_ids.device)
354 | return position_ids
355 |
356 | def encoder_fn(self, sent, db, sent_attn, db_attn, device):
357 | B, T = sent.size()
358 | sent = self.embedding(sent) + self.wte_sent[:T,:]
359 | db = self.embedding(db) + self.wte_db[:T,:]
360 | sent_hidden_states = self.sentence_encoder(("encoder", sent, sent_attn),)[1]
361 | db_hidden_states = self.db_encoder(("encoder", db, db_attn),)[1]
362 | return SimpleNamespace(
363 | sent_hidden_states=sent_hidden_states,
364 | db_hidden_states=db_hidden_states,
365 | sent_attn=sent_attn,
366 | db_attn=db_attn
367 | )
368 |
369 | def decoder_fn(self, enc_out, sql_ids, sql_attn, device):
370 | B, T = sql_ids.size()
371 | sql = self.embedding(sql_ids) + self.wte_sql[:T,:]
372 | sql_output = self.decoder((sql, sql_attn, enc_out.sent_hidden_states,
373 | enc_out.sent_attn, enc_out.db_hidden_states, enc_out.db_attn),)[0]
374 | sql_output = self.lm_head(sql_output)
375 | return sql_output
376 |
377 | def forward(self, sql_ids, sent, db, sql_attn, sent_attn, db_attn, labels=None, past_length=0, device="cpu"):
378 | B, T = sql_ids.size()
379 |
380 | # make the embeddings
381 | sql = self.embedding(sql_ids) + self.wte_sql[:T,:]
382 | sent = self.embedding(sent) + self.wte_sent[:T,:]
383 | db = self.embedding(db) + self.wte_db[:T,:]
384 |
385 | # get hidden_states for sentence_encoder
386 | sent_hidden_states = self.sentence_encoder(("encoder", sent, sent_attn),)[1]
387 | db_hidden_states = self.db_encoder(("encoder", db, db_attn),)[1]
388 | sql_output = self.decoder((sql, sql_attn, sent_hidden_states, sent_attn, db_hidden_states, db_attn),)[0]
389 | sql_output = self.lm_head(sql_output)
390 | output = [sql_output]
391 |
392 | if labels is not None:
393 | labels = labels.contiguous()
394 | logits = sql_output.contiguous()
395 |
396 | # loss_fct = nn.CrossEntropyLoss(reduction="none")
397 | # loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
398 |
399 | loss_fct = nn.CrossEntropyLoss(reduction="mean")
400 | loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
401 |
402 | # # get the indexes where the labels are not [PAD]
403 | # non_pad_mask = sql_attn[:, :, 0, 1:].contiguous().view(-1) == 0
404 | # print(loss.size(), non_pad_mask.size())
405 | # non_pad_loss = loss[non_pad_mask]
406 | # print("non_pad_loss", non_pad_loss)
407 | # loss = non_pad_loss.mean()
408 | output = [loss] + output
409 |
410 | return output
411 |
412 |
413 | class Text2SQLModelConfig():
414 | vocab_size = 5012
415 | n_embd = 256
416 | maxlen = 128
417 | n_decoder_layers = 2
418 | n_sent_layers = 3
419 | n_db_layers = 3
420 | n_head = 8
421 | n_inner = None
422 | activation_function = "gelu_new"
423 | resid_pdrop = 0.1
424 | embd_pdrop = 0.1
425 | attn_pdrop = 0.1
426 | layer_norm_epsilon = 0.00001
427 | initializer_range = 0.02
428 |
429 | def __init__(self, **kwargs):
430 | self.attrs = [
431 | "vocab_size",
432 | "n_embd",
433 | "maxlen",
434 | "n_decoder_layers",
435 | "n_sent_layers",
436 | "n_db_layers",
437 | "n_head",
438 | "n_inner",
439 | "activation_function",
440 | "resid_pdrop",
441 | "embd_pdrop",
442 | "attn_pdrop",
443 | "layer_norm_epsilon",
444 | "initializer_range",
445 | ]
446 |
447 | for k, v in kwargs.items():
448 | self.attrs.append(k)
449 | setattr(self, k, v)
450 |
451 | def __repr__(self):
452 | kvs = [(k, f"{getattr(self, k)}") for k in sorted(list(set(self.attrs)))]
453 | return tabulate(kvs, ["argument", "value"], tablefmt="psql")
454 |
455 | # ====== Sampling Utils ====== #
456 | def top_k_logits(logits, k):
457 | v, ix = torch.topk(logits, k)
458 | out = logits.clone()
459 | out[out < v[:, [-1]]] = -1e6
460 | return out
461 |
462 |
463 | @torch.no_grad()
464 | def sample(model, sent, sent_attn, db, db_attn, t, sql_str = None, device="cpu", steps=50, temperature=50, top_k=None, do_sample=True):
465 | model.eval()
466 | sent = sent.view(-1, sent.size(0))
467 | db = db.view(-1, db.size(0))
468 | sent_attn = sent_attn.view(-1, *sent_attn.size())
469 | db_attn = db_attn.view(-1, *db_attn.size())
470 | # print(sent.size(), db.size(), sent_attn.size(), db_attn.size())
471 | enc_out = model.encoder_fn(sent, db, sent_attn, db_attn, device=device)
472 |
473 | # convert string to sql_tokens
474 | if sql_str is not None:
475 | sql = torch.from_numpy(np.asarray([t.encode(sql_str)])).long()
476 | else:
477 | sql = torch.from_numpy(np.asarray([t.bos_id()])).view(-1, 1).long()
478 | sql = sql.to(device)
479 |
480 | # final sequence
481 | out = []
482 |
483 | for k in range(steps):
484 | if k == 0:
485 | x = sql if sql.size(1) < model.config.maxlen else sql[:, -model.config.maxlen:]
486 | else:
487 | x = x if x.size(1) < model.config.maxlen else x[:, -model.config.maxlen:]
488 |
489 |         sql_attn = np.zeros((x.size(1), x.size(1)))
490 |         sql_attn[:x.size(1), :x.size(1)] = 1
491 |         sql_attn = sql_attn - np.triu(sql_attn, k = 1) # causal masking
492 |         sql_attn = 1 - sql_attn # flip: 1 now marks the positions to hide
493 |         sql_attn = sql_attn * -1e6 # additive mask: 0 = attend, -1e6 = hide
494 | sql_attn = torch.from_numpy(sql_attn.astype(np.float32)).to(device).view(1, 1, *sql_attn.shape)
495 |
496 | # print("****", x.size(), sql_attn.size())
497 |
498 | logits = model.decoder_fn(enc_out, x, sql_attn, device=device)
499 |
500 | # pluck the logits at the final step and scale by temperature
501 | logits = logits[:, -1, :] / temperature
502 | # optionally crop probabilities to only the top k options
503 | if top_k is not None:
504 | logits = top_k_logits(logits, top_k)
505 |
506 | # apply softmax to convert to probabilities
507 | probs = F.softmax(logits, dim=-1)
508 |         # sample from the distribution or take the most likely token
509 |         if do_sample:
510 | ix = torch.multinomial(probs, num_samples=1)
511 | else:
512 | _, ix = torch.topk(probs, k=1, dim=-1)
513 | # append to the sequence and continue
514 | x = torch.cat((x, ix), dim=1)
515 |
516 | out.append(ix[0].tolist()[0])
517 |
518 | return t.decode_ids(out)
519 |
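520 | # --- illustrative smoke test (added for clarity, not part of the original file) ---
521 | # A minimal sketch of how the model is wired together, on tiny random inputs.
522 | # All masks are additive: 0 keeps a position, -1e6 hides it (see `sample` above).
523 | if __name__ == "__main__":
524 |     config = Text2SQLModelConfig(n_decoder_layers=1, n_sent_layers=1, n_db_layers=1)
525 |     model = Text2SQLModel(config)
526 |     B, T = 2, 16
527 |     ids = torch.randint(0, config.vocab_size, (B, T))
528 |     enc_mask = torch.zeros(B, 1, T, T)  # encoders may attend everywhere
529 |     causal = torch.triu(torch.full((T, T), -1e6), diagonal=1).view(1, 1, T, T).repeat(B, 1, 1, 1)
530 |     loss, logits = model(sql_ids=ids, sent=ids, db=ids, sql_attn=causal,
531 |                          sent_attn=enc_mask, db_attn=enc_mask, labels=ids)
532 |     print(loss.item(), logits.shape)  # logits: [B, T, vocab_size]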
--------------------------------------------------------------------------------
/src/text2sql/model2.py:
--------------------------------------------------------------------------------
1 | """So the complicated model with encoder and decoder networks is not working properly,
2 | need to come up with something better.
3 | 12.11.2020 - @yashbonde"""
4 |
5 | from text2sql.model import *
6 |
7 | # https://pytorch-scatter.readthedocs.io/en/1.3.0/functions/add.html
8 | from torch_scatter import scatter_add
9 |
10 | class GPTModel(nn.Module):
11 | def __init__(self, config):
12 | super(GPTModel, self).__init__()
13 |
14 | self.config = config
15 |
16 |     def merge_embeddings(self, expanded_embeddings, merging_index):
17 |         """This function merges multiple embedding vectors into fewer embeddings. This is what
18 |         is happening here. The challenge is that each node carries information in the format
19 |         <table>.<column> and we need to merge this information. Another challenge was
20 |         that the older method created matrices that were 400x400 while the questions were 50x50.
21 | 
22 |         So continuing with the above mentioned case:
23 | 
24 |         <table> is tokenized into W1,W2,W3 and <column> is tokenized into W4,W5,
25 |         so the effective tokenisation becomes W1,W2,W3.W4,W5. Now in order to feed this into
26 |         the model we need to merge: if expanded_embeddings is (5,256) and merging_index = [0,0,0,1,1],
27 |         this function returns a matrix of shape (2,256).
28 | 
29 |         Now in practice this takes a large matrix like (500, 500) and returns a highly reduced
30 |         version of it. The output in this case is going to be (num_nodes, num_nodes) and thus
31 |         we do not need to use the complicated cell-expanded attention matrix as done in
32 |         `text2sql.data.get_tokenised_attention_mask`.
33 | 
34 |         """
35 |         # scatter_add sums the rows of expanded_embeddings that share a merging_index value
36 |         out = scatter_add(expanded_embeddings, merging_index, dim=0)
37 |         return out
38 |
39 |     def forward(self, *args, **kwargs):
40 |         # not implemented yet; the plan is to merge token embeddings via
41 |         # merge_embeddings / torch_scatter before running the transformer blocks
42 |         raise NotImplementedError
43 |
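44 | # --- illustrative sketch of merge_embeddings (added for clarity) ---
45 | # Using the hypothetical numbers from the docstring: five sub-word embeddings
46 | # (W1..W5) are summed into two node embeddings via scatter_add along dim 0.
47 | if __name__ == "__main__":
48 |     expanded = torch.randn(5, 256)          # embeddings of W1,W2,W3,W4,W5
49 |     index = torch.tensor([0, 0, 0, 1, 1])   # merging_index
50 |     merged = scatter_add(expanded, index, dim=0)
51 |     print(merged.shape)                     # torch.Size([2, 256])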
--------------------------------------------------------------------------------
/src/text2sql/trainer.py:
--------------------------------------------------------------------------------
1 | """
2 | Simple training loop; Boilerplate that could apply to any arbitrary neural network,
3 | so nothing in this file really has anything to do with GPT specifically.
4 |
5 | from karpathy/minGPT
6 | """
7 |
8 | import time
9 | import random
10 | import numpy as np
11 | from tqdm import tqdm, trange
12 | from tabulate import tabulate
13 |
14 | from text2sql.model import sample
15 |
16 | import torch
17 | from torch.optim.lr_scheduler import OneCycleLR
18 | from torch.utils.data import DataLoader
19 | from torch.utils.tensorboard import SummaryWriter
20 |
21 | class Trainer:
22 | def __init__(self, model, train_dataset, test_dataset, config):
23 | self.model = model
24 | self.train_dataset = train_dataset
25 | self.test_dataset = test_dataset
26 | self.config = config
27 |
28 | self.device = "cpu"
29 | if torch.cuda.is_available():
30 | print("Model is now CUDA!!!!")
31 | self.device = torch.cuda.current_device()
32 | self.model = torch.nn.DataParallel(self.model).to(self.device)
33 |
34 | def save_checkpoint(self):
35 | raw_model = self.model.module if hasattr(self.model, "module") else self.model
36 | print(f"Saving Model at {self.config.ckpt_path}")
37 | torch.save(raw_model.state_dict(), self.config.ckpt_path)
38 |
39 | def train(self, data_config, verbose = False):
40 | model, config = self.model, self.config
41 | raw_model = model.module if hasattr(self.model, "module") else model
42 | optimizer = raw_model.configure_optimizers(config)
43 | lrscheduler = OneCycleLR(
44 | optimizer,
45 | max_lr = config.lr,
46 | total_steps=config.num_batch*config.max_epochs
47 | )
48 |
49 | print("Starting training...", lrscheduler)
50 |
51 | with SummaryWriter(log_dir=config.tb_path, flush_secs=20) as tb:
52 |
53 | def run_epoch(split, epoch, _gs):
54 | is_train = split == "train"
55 | model.train(is_train)
56 | data = self.train_dataset if is_train else self.test_dataset
57 | dl = DataLoader(
58 | data,
59 | shuffle = True,
60 | pin_memory = True,
61 | batch_size = config.batch_size,
62 | )
63 | losses = []
64 | if is_train:
65 | pbar = trange(config.num_batch, ncols=100)
66 | iterator = zip(pbar, dl)
67 | else:
68 | pbar = tqdm(enumerate(dl))
69 | iterator = pbar
70 |
71 | for it, d in iterator:
72 |
73 | with torch.set_grad_enabled(is_train):
74 | _l = -1 if not losses else losses[-1]
75 | if is_train:
76 | pbar.set_description(f"[TRAIN] GS: {_gs}, Epoch: {epoch}, Loss: {round(_l, 5)}")
77 | else:
78 | pbar.set_description(f"[VAL] Epoch: {epoch}")
79 |
80 | d = {k:v.to(self.device) for k,v in d.items()}
81 | # print({k:(v.size(), v.dtype, v.device) for k,v in d.items()})
82 |
83 | loss, logits = model(
84 | **d,
85 | device = self.device
86 | )
87 |                         loss = loss.mean() # gather from multiple GPUs
88 | losses.append(loss.item())
89 |
90 | if is_train:
91 | # add things to tb, loss and attention images
92 | tb.add_scalar("loss", loss.item(), global_step=_gs, walltime=time.time())
93 | tb.add_scalar("lr", lrscheduler.get_lr()[0], global_step=_gs, walltime=time.time())
94 |
95 | loss.backward()
96 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
97 | optimizer.step()
98 | lrscheduler.step()
99 | _gs += 1
100 |
101 | if not is_train:
102 | # no sampling here, because that really doesn't make any sense
103 | test_loss = float(np.mean(losses))
104 | tb.add_scalar("test_loss", test_loss, global_step=_gs, walltime=time.time())
105 |
106 | # create samples for visualising the results
107 | print("Generating Samples ...")
108 | for i in range(min(10, len(d["sql_ids"]))):
109 | s = {k:v[i, ...] for k,v in d.items()}
110 | seq = sample(raw_model, sent=s["sent"], sent_attn=s["sent_attn"],
111 | db=s["db"], db_attn=s["db_attn"], t=data_config.tokenizer,
112 | device = self.device) #, sql_str = "select")
113 | target = data_config.tokenizer.decode_ids([x for x in s["labels"].tolist() if x != -100])
114 | print("-->", seq ,"<<--->>", target)
115 |
116 | return test_loss
117 | return _gs
118 |
119 | # now write wrapper for each epoch
120 | best_loss = float("inf")
121 | gs = 1
122 | test_no_improve = 0
123 | for e in range(config.max_epochs):
124 | gs = run_epoch("train", e, gs)
125 | if self.test_dataset is not None:
126 | test_loss = run_epoch("test", e, gs)
127 | print(f"Test loss: {test_loss}")
128 |
129 |                 # early stopping based on the test loss, or just save always if no test set is provided
130 | good_model = self.test_dataset is None or test_loss < best_loss
131 | if self.config.ckpt_path is not None and good_model:
132 | best_loss = test_loss
133 | self.save_checkpoint()
134 | test_no_improve = 0
135 | else:
136 | test_no_improve += 1
137 |
138 | # if test_no_improve == config.patience:
139 | # print(f"Stop Training after [patience = {config.patience}]: {e} epochs")
140 | # break
141 |
142 |
143 | class TrainerConfig:
144 | lr = 0.0001
145 | max_epochs = 10
146 | batch_size = 128
147 | betas = (0.9, 0.95)
148 | grad_norm_clip = 1.0
149 | weight_decay = 0.1 # only applied on matmul weights
150 | len_data = None # required for OneCycleLR
151 | sample_every = 5 # after how many epochs to log
152 | num_batch = None
153 | patience = 5 # training stops after patience runs out
154 | tb_path = None
155 | ckpt_path = None
156 |
157 | def __init__(self, **kwargs):
158 |         self.attrs = [ "lr", "max_epochs", "batch_size", "betas",
159 |             "grad_norm_clip", "weight_decay", "len_data",
160 |             "sample_every", "num_batch", "patience", "tb_path",
161 |             "ckpt_path" ]
162 | for k,v in kwargs.items():
163 | setattr(self, k, v)
164 | self.attrs.append(k)
165 |
166 | def __repr__(self):
167 | kvs = [(k, f"{getattr(self, k)}") for k in sorted(list(set(self.attrs)))]
168 | return tabulate(kvs, ["argument", "value"], tablefmt="psql")
169 |
170 |
171 | # funcs
172 | def set_seed(seed):
173 | if seed is not None:
174 | random.seed(seed)
175 | np.random.seed(seed)
176 | torch.manual_seed(seed)
177 | torch.cuda.manual_seed_all(seed)
178 |
--------------------------------------------------------------------------------
/t2s.py:
--------------------------------------------------------------------------------
1 | """trying streamlit as usage tool
2 | 22.09.2020 - @yashbonde"""
3 |
4 | import os
5 | import json
6 | import random
7 | from PIL import Image
8 | import numpy as np
9 | import networkx as nx
10 | import matplotlib.pyplot as plt
11 |
12 | from text2sql.data import parse_db_to_networkx
13 |
14 | import streamlit as st
15 |
16 | st.write("""
17 | # Text2SQL Converter
18 | Convert your natural language question to SQL queries.
19 |
20 | * **Author: Yash Bonde**
21 | * Website: [link](https://yashbonde.github.io)
22 | * LinkedIn: [link](https://www.linkedin.com/in/yash-bonde/)
23 | * Twitter: [@bondebhai](https://twitter.com/bondebhai)
24 | * Check code at my [Github](https://github.com/yashbonde/text2sql)
25 | """)
26 |
27 |
28 | # # with st.spinner('Loading model ...'):
29 | # from transformers import AutoTokenizer, AutoModel
30 | # TOKENIZER = AutoTokenizer.from_pretrained("distilbert-base-uncased")
31 | # THETA = AutoModel.from_pretrained("distilbert-base-uncased")
32 |
33 | # paths to main files
34 | OTHER_FILE = "data/spider/train_others.json"
35 | SPIDER_FILE = "data/spider/train_spider.json"
36 | SPARC_FILE = "data/sparc/train.json"
37 | COSQL_FILE = "data/cosql_dataset/cosql_all_info_dialogs.json"
38 |
39 | # files containing tables info
40 | SPIDER_TABLES = "data/spider/tables.json"
41 | SPARC_TABLES = "data/sparc/tables.json"
42 | COSQL_TABLES = "data/cosql_dataset/tables.json"
43 |
44 | # spider dataset already has sql files that we can read from to tokenize
45 | SPIDER_SQL_TRAIN = "data/spider/train_gold.sql"
46 | SPIDER_SQL_DEV = "data/spider/dev_gold.sql"
47 |
48 | # dev set
49 | SPIDER_DEV = "data/spider/dev.json"
50 | SPARC_DEV = "data/sparc/dev.json"
51 |
52 | DB_ID = None
53 |
54 | # load different dbs
55 | tables = []
56 | with open(SPIDER_TABLES) as f1, open(SPARC_TABLES) as f2, open(COSQL_TABLES) as f3:
57 | # spider/tables.json
58 | tables.extend(json.load(f1))
59 |
60 | # sparc/tables.json
61 | tables.extend(json.load(f2))
62 |
63 | # cosql_dataset/tables.json
64 | tables.extend(json.load(f3))
65 |
66 | # load questions and corresponding outputs
67 | data = []
68 | with open(OTHER_FILE) as f1, open(SPIDER_FILE) as f2, open(SPARC_FILE) as f3, open(COSQL_FILE) as f4:
69 | # train_others.json
70 | for x in json.load(f1):
71 | data.append((x["question"], x["query"], x["db_id"]))
72 |
73 | # train_spider.json
74 | for x in json.load(f2):
75 | data.append((x["question"], x["query"], x["db_id"]))
76 |
77 | # sparc/train.json
78 | for x in json.load(f3):
79 | data.append((x["final"]["utterance"], x["final"]["query"], x["database_id"]))
80 |
81 | # cosql_all_info_dialogs.json
82 | for x,y in json.load(f4).items():
83 | data.append((y["query_goal"], y["sql"], y["db_id"]))
84 |
85 | def get_random_db():
86 | random_table = random.choice(tables)
87 | global DB_ID
88 | DB_ID = random_table["db_id"]
89 | g = parse_db_to_networkx(random_table)
90 | eattr = nx.get_edge_attributes(g, 'foreign')
91 | pos = nx.spring_layout(g)
92 |
93 | plt.figure(figsize = (6, 4))
94 | nx.draw_networkx_nodes(g, pos, )
95 | nx.draw_networkx_labels(g, pos, nx.get_node_attributes(g, 'id'), font_size="x-small")
96 | nx.draw_networkx_edges(
97 | g,
98 | pos,
99 | edgelist=[k for k,v in eattr.items() if v],
100 | edge_color="r",
101 |
102 | )
103 | nx.draw_networkx_edges(
104 | g,
105 | pos,
106 | edgelist=[k for k,v in eattr.items() if not v],
107 | edge_color="b",
108 |
109 | )
110 | plt.savefig("_temp.png")
111 |
112 | st.write("""
113 | ## Problem Statement
114 |
115 | **Given a database schema and a question in natural language, generate the
116 | appropriate SQL query to fetch the information needed.**
117 |
118 | For this we first go over the data.
119 |
120 | ### Sample Data
121 |
122 | Below is a quick sample of what the DB looks like: `blue` edges represent
123 | edges in the same table while `red` edges represent foreign keys. Click the
124 | `Randomise` button to see another graph. I am using [CoSQL](https://yale-lily.github.io/cosql),
125 | [Spider](https://yale-lily.github.io/spider), [Sparc](https://yale-lily.github.io/sparc)
126 | datasets.
127 | """)
128 | db_cntr = 0
129 | if st.button('Randomise') or db_cntr == 0:
130 | get_random_db()
131 |
132 | # load the graph image
133 | x = Image.open("_temp.png")
134 | st.image(x, caption = f"Look Ma! Database is a graph. ({DB_ID})", clamp = True)
135 |
136 | # update samples
137 | data_this_db = list(filter(
138 | lambda x: x[2] == DB_ID, data
139 | ))
140 | st.write(f"from `{DB_ID}` we get following questions:\n\n" +
141 | "- " + "\n\n- ".join([f"{x[0]} ➡️ `{x[1]}`" for x in data_this_db][:3])
142 | )
143 | db_cntr += 1
144 |
145 | st.write("""
146 | ### Database Schema
147 | Any DB is converted to a graph: a combination of nodes and edges, each with certain properties:
148 | ```
149 | nodes:
150 | - table: name of the table it belongs to
151 | - name: column name
152 | - type: one of ['boolean', 'time', 'others', 'text', 'number']
153 | - primary: boolean that tells if this is a primary key
154 |
155 | edges:
156 | - foreign: boolean that tells if this is a foreign edge
157 | ```
158 |
159 | ### Natural Language Questions
160 | We use [DistilBERT](https://huggingface.co/distilbert-base-uncased); you can pass
161 | it any text and see the output logits size.
162 |
163 | ## Algorithm
164 |
165 | To pose any problem as RL you need to have the following setup:
166 |
167 | ```
168 | format of data tuple
169 | s: current state
170 | a: action taken
171 | r: reward obtained for taking that action
172 | s': new state the model reaches
173 | ```
174 |
175 | We are given a database schema $D$ and a natural language question $N$.
176 | We first obtain an embedding for the database, $d = \phi(D)$, and for the question, $t = \\theta(N)$.
177 | Thus we get the input state $s = [d;t]$, where $;$ denotes concatenation. We then define
178 | a policy network $\pi$ which predicts the appropriate SQL, $q = \pi(s)$
179 | (an illustrative sketch of this state construction is commented at the end of this file).
180 |
181 | The main challenge is the policy network: is it going to be a traditional language-modelling
182 | LSTM or a Transformer? So let us consider the network outputs:
183 |
184 | * $\phi(D) \\rightarrow ( [N_{nodes}, E_{graph}], [1, E_{graph}] )$
185 | * $\\theta(N) \\rightarrow [N_{tokens}, 768]$, we can possibly reduce this further by
186 | max-pooling over the sequence as $[N_{tokens}, 768] \\rightarrow [1, 768]$
187 |
188 | **23rd September, 2020**: Okay, so I think I have a solution; the primary challenge has
189 | been the definition of the action space. The action space has all the vocabulary of SQL
190 | commands + two special tags `<col>` and `<table>`. `<col>` tells that the model will
191 | select a column from the node embeddings (dot product + softmax) and `<table>` tells it
192 | to select a table from the node embeddings (dot product + sigmoid).
193 |
194 | For this to work however we will have to modify the queries given in the dataset from
195 | ```sql
196 | SELECT
197 | T2.name FROM Friend AS T1
198 | JOIN Highschooler AS T2 ON T1.student_id = T2.id
199 | WHERE T2.grade > 5
200 | GROUP BY T1.student_id
201 | HAVING count(*) >= 2
202 | ```
203 | to something like the one below
204 | ```sql
205 | SELECT
206 | Highschooler.name FROM Friend
207 | JOIN Highschooler ON Friend.student_id = Highschooler.id
208 | WHERE Highschooler.grade > 5
209 | GROUP BY Friend.student_id
210 | HAVING count(*) >= 2
211 | ```
212 |
213 | The idea with the initial model was a complicated graph-based approach, but now I
214 | am considering a much simpler model: a plain Transformer where we have
215 | two different encoder structures:
216 | * BERT as question encoder
217 | * Message-passing GNN as DB encoder
218 |
219 | These two combined will be fed into a conventional transformer decoder.
220 | """)
221 |
222 | # # My assumption is that the dataset was created with langauge models in mind, however in practice
223 | # # direclty pointing out the column is a better solution design.
224 | # # pass the database through a graph encoder to get node and graph embeddings
225 | # DB --> [GNN] ---> (node embedding) [N_1, E_1] ... A
226 | # \-> (graph embedding) [1, E_1] ... B
227 |
228 | # # pass the natural language question, through any LM like BERT
229 | # Q ---> [BERT] --> (token level embedding) [N_2, E_2] ... C
230 |
231 | # # --- undecided ---
232 | # # concatenate the graph embedding and natural language embedding
233 | # [B+C] --> [N_2, E_1 + E_2] ... D
234 |
235 | # # --- policy ---
236 | # For policy we can either use a GPT transformer or an LSTM
237 |
238 | # ! TODO: add question parsing in real time here
239 | # question = st.text_input(f"question for DB: {DB_ID} (do not press enter)", value = data_this_db[0][0], max_chars=100)
240 | # st.button('Process')
241 | # st.write(question)
242 | # tokenised = TOKENIZER(question)["input_ids"]
243 | # decoded = TOKENIZER.decode(tokenised)
244 | # st.write(f"""IDs: `{tokenised}` ➡️ Decoded: `{decoded}`""")
245 |
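246 | # --- illustrative sketch of the state construction described above ---
247 | # (added for clarity; the shapes are hypothetical) the RL state s = [d; t] is just
248 | # the concatenation of a pooled DB embedding and a pooled question embedding:
249 | # import torch
250 | # d = torch.randn(1, 64)           # graph embedding, phi(D)
251 | # t = torch.randn(1, 768)          # max-pooled question embedding, theta(N)
252 | # s = torch.cat([d, t], dim=-1)    # state s = [d; t] -> shape (1, 832)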
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """going by o2f format and using huggingface library
2 | 12.11.2020 - @yashbonde"""
3 |
4 | import os
5 | from types import SimpleNamespace
6 | from argparse import ArgumentParser
7 |
8 | from text2sql.data import T2SDataset, T2SDatasetConfig
9 | from text2sql.model import Text2SQLModel, Text2SQLModelConfig
10 | from text2sql.trainer import *
11 |
12 | # --- arguments
13 | args = ArgumentParser(description="Text2SQL Model Trainer")
14 |
15 | # --- paths
16 | args.add_argument("--save_folder", default = "models", type = str, help = "folder to save all models")
17 | args.add_argument("--name", type = str, help = "name of this particular model")
18 | args.add_argument("--schema_file", type = str, help = "path to schema file", default="/workspace/text2sql/fdata/all_schema.json")
19 | args.add_argument("--questions_tsv", type = str, help = "path to text/sql tsv", default="/workspace/text2sql/fdata/all_questions.tsv")
20 | args.add_argument("--tokenizer_path", type = str, help = "path to sentencepiece model file", default="/workspace/text2sql/fdata/model.model")
21 | args.add_argument("--seed", default = None, type = int, help = "seed value for training")
22 |
23 | # --- arch
24 | args.add_argument("--n_embd", default = 256, type = int, help = "Embedding Dim")
25 | args.add_argument("--n_decoder_layers", default = 4, type = int, help = "Num Decoder layers")
26 | args.add_argument("--n_sent_layers", default = 4, type = int, help = "Num layers for sentence encoder")
27 | args.add_argument("--n_db_layers", default = 4, type = int, help = "Num layers for DB encoder")
28 | args.add_argument("--n_head", default = 4, type = int, help = "Num Heads")
29 | args.add_argument("--maxlen", default = 390, type = int, help = "Maximum length of decoder")
30 |
31 | # --- data
32 | args.add_argument("--mult", default = 3, type = int, help = "Size of dataset")
33 | args.add_argument("--pf", default = 0.6, type = float, help = "Probability of using fields in training sequence")
34 | args.add_argument("--fmax", default = 0.8, type = float, help = "Max fields probability")
35 | args.add_argument("--fmin", default = 0.1, type = float, help = "Min fields probability")
36 |
37 | # --- trainer
38 | args.add_argument("--n_epochs", default = 3, type = int, help = "Number of epochs to train")
39 | args.add_argument("--batch_size", default = 32, type = int, help = "Mini-Batch Size")
40 | args.add_argument("--lr", default = 1e-4, type = float, help = "Learning Rate")
41 | args.add_argument("--sample_every", default = 5, type = int, help = "After t")
42 | args.add_argument("--train_ratio", default = 0.9, type = float, help = "Ratio of train data, rest is testing")
43 | args.add_argument("--beta1", default = 0.9, type = float, help = "Adam.beta1")
44 | args.add_argument("--beta2", default = 0.99, type = float, help = "Adam.beta2")
45 | args.add_argument("--grad_norm_clip", default = 1.0, type = float, help = "Adam.beta2")
46 | args.add_argument("--patience", default = 6, type = int, help = "training stops after patience runs out")
47 |
48 | # --- parse and add more
49 | args = args.parse_args()
50 | tb_path = os.path.join(args.save_folder, args.name)
51 | ckpt_path = os.path.join(tb_path, f"{args.name}.pt")
52 | args = SimpleNamespace(**vars(args), ckpt_path = ckpt_path, tb_path = tb_path)
53 |
54 | # make folders
55 | os.makedirs(args.save_folder, exist_ok=True)
56 | os.makedirs(args.tb_path, exist_ok=True)
57 |
58 | # DataSet
59 | datasetConf = T2SDatasetConfig(
60 | schema_file=args.schema_file,
61 | questions_file=args.questions_tsv,
62 | maxlen=args.maxlen,
63 | tokenizer_path=args.tokenizer_path
64 | )
65 | print(datasetConf)
66 | dtrain = T2SDataset(config=datasetConf, mode="train")
67 | dtest = T2SDataset(config=datasetConf, mode="test")
68 |
69 | # Model
70 | modelConfig = Text2SQLModelConfig(
71 | vocab_size=datasetConf.tokenizer.vocab_size(),
72 | n_embd=args.n_embd,
73 | maxlen=args.maxlen,
74 | n_decoder_layers=args.n_decoder_layers,
75 | n_sent_layers=args.n_sent_layers,
76 | n_db_layers=args.n_db_layers,
77 | n_head=args.n_head,
78 | )
79 | print(modelConfig)
80 | model = Text2SQLModel(modelConfig)
81 |
82 | # Trainer
83 | trainConfig = TrainerConfig(
84 | lr=args.lr,
85 | max_epochs=args.n_epochs,
86 | batch_size=args.batch_size,
87 | betas=(args.beta1, args.beta2),
88 | grad_norm_clip=args.grad_norm_clip,
89 | sample_every=args.sample_every,
90 | num_batch=(len(dtrain) // args.batch_size) + int(len(dtrain) % args.batch_size != 0),
91 | patience=args.patience,
92 | tb_path=args.tb_path,
93 | ckpt_path=args.ckpt_path
94 | )
95 | print(trainConfig)
96 | trainer = Trainer(model, dtrain, dtest, trainConfig)
97 | trainer.train(datasetConf)
98 |
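99 | # example invocation (flags as defined above; the default paths point to
100 | # /workspace/text2sql/fdata, adjust them to your setup):
101 | #     python3 train.py --name baseline --n_epochs 3 --batch_size 32 --lr 1e-4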
--------------------------------------------------------------------------------