├── LICENSE
├── README.md
├── gen_data
│   ├── main.py
│   ├── name.py
│   └── utils.py
├── model
│   ├── config.py
│   ├── dataset
│   │   ├── dataset_preprocess.py
│   │   ├── final_devset.json
│   │   ├── final_new_testset.json
│   │   └── final_small_trainset.json
│   ├── eval.py
│   ├── model.py
│   ├── readme.md
│   ├── test.py
│   ├── train.py
│   ├── utils.py
│   └── visualization.py
└── testset
    └── test.json
/LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-ShareAlike 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. 
Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-ShareAlike 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-ShareAlike 4.0 International Public License ("Public 63 | License"). 
To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. BY-SA Compatible License means a license listed at 88 | creativecommons.org/compatiblelicenses, approved by Creative 89 | Commons as essentially the equivalent of this Public License. 90 | 91 | d. Copyright and Similar Rights means copyright and/or similar rights 92 | closely related to copyright including, without limitation, 93 | performance, broadcast, sound recording, and Sui Generis Database 94 | Rights, without regard to how the rights are labeled or 95 | categorized. For purposes of this Public License, the rights 96 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 97 | Rights. 98 | 99 | e. 
Effective Technological Measures means those measures that, in the 100 | absence of proper authority, may not be circumvented under laws 101 | fulfilling obligations under Article 11 of the WIPO Copyright 102 | Treaty adopted on December 20, 1996, and/or similar international 103 | agreements. 104 | 105 | f. Exceptions and Limitations means fair use, fair dealing, and/or 106 | any other exception or limitation to Copyright and Similar Rights 107 | that applies to Your use of the Licensed Material. 108 | 109 | g. License Elements means the license attributes listed in the name 110 | of a Creative Commons Public License. The License Elements of this 111 | Public License are Attribution and ShareAlike. 112 | 113 | h. Licensed Material means the artistic or literary work, database, 114 | or other material to which the Licensor applied this Public 115 | License. 116 | 117 | i. Licensed Rights means the rights granted to You subject to the 118 | terms and conditions of this Public License, which are limited to 119 | all Copyright and Similar Rights that apply to Your use of the 120 | Licensed Material and that the Licensor has authority to license. 121 | 122 | j. Licensor means the individual(s) or entity(ies) granting rights 123 | under this Public License. 124 | 125 | k. Share means to provide material to the public by any means or 126 | process that requires permission under the Licensed Rights, such 127 | as reproduction, public display, public performance, distribution, 128 | dissemination, communication, or importation, and to make material 129 | available to the public including in ways that members of the 130 | public may access the material from a place and at a time 131 | individually chosen by them. 132 | 133 | l. 
Sui Generis Database Rights means rights other than copyright 134 | resulting from Directive 96/9/EC of the European Parliament and of 135 | the Council of 11 March 1996 on the legal protection of databases, 136 | as amended and/or succeeded, as well as other essentially 137 | equivalent rights anywhere in the world. 138 | 139 | m. You means the individual or entity exercising the Licensed Rights 140 | under this Public License. Your has a corresponding meaning. 141 | 142 | 143 | Section 2 -- Scope. 144 | 145 | a. License grant. 146 | 147 | 1. Subject to the terms and conditions of this Public License, 148 | the Licensor hereby grants You a worldwide, royalty-free, 149 | non-sublicensable, non-exclusive, irrevocable license to 150 | exercise the Licensed Rights in the Licensed Material to: 151 | 152 | a. reproduce and Share the Licensed Material, in whole or 153 | in part; and 154 | 155 | b. produce, reproduce, and Share Adapted Material. 156 | 157 | 2. Exceptions and Limitations. For the avoidance of doubt, where 158 | Exceptions and Limitations apply to Your use, this Public 159 | License does not apply, and You do not need to comply with 160 | its terms and conditions. 161 | 162 | 3. Term. The term of this Public License is specified in Section 163 | 6(a). 164 | 165 | 4. Media and formats; technical modifications allowed. The 166 | Licensor authorizes You to exercise the Licensed Rights in 167 | all media and formats whether now known or hereafter created, 168 | and to make technical modifications necessary to do so. The 169 | Licensor waives and/or agrees not to assert any right or 170 | authority to forbid You from making technical modifications 171 | necessary to exercise the Licensed Rights, including 172 | technical modifications necessary to circumvent Effective 173 | Technological Measures. For purposes of this Public License, 174 | simply making modifications authorized by this Section 2(a) 175 | (4) never produces Adapted Material. 176 | 177 | 5. 
Downstream recipients. 178 | 179 | a. Offer from the Licensor -- Licensed Material. Every 180 | recipient of the Licensed Material automatically 181 | receives an offer from the Licensor to exercise the 182 | Licensed Rights under the terms and conditions of this 183 | Public License. 184 | 185 | b. Additional offer from the Licensor -- Adapted Material. 186 | Every recipient of Adapted Material from You 187 | automatically receives an offer from the Licensor to 188 | exercise the Licensed Rights in the Adapted Material 189 | under the conditions of the Adapter's License You apply. 190 | 191 | c. No downstream restrictions. You may not offer or impose 192 | any additional or different terms or conditions on, or 193 | apply any Effective Technological Measures to, the 194 | Licensed Material if doing so restricts exercise of the 195 | Licensed Rights by any recipient of the Licensed 196 | Material. 197 | 198 | 6. No endorsement. Nothing in this Public License constitutes or 199 | may be construed as permission to assert or imply that You 200 | are, or that Your use of the Licensed Material is, connected 201 | with, or sponsored, endorsed, or granted official status by, 202 | the Licensor or others designated to receive attribution as 203 | provided in Section 3(a)(1)(A)(i). 204 | 205 | b. Other rights. 206 | 207 | 1. Moral rights, such as the right of integrity, are not 208 | licensed under this Public License, nor are publicity, 209 | privacy, and/or other similar personality rights; however, to 210 | the extent possible, the Licensor waives and/or agrees not to 211 | assert any such rights held by the Licensor to the limited 212 | extent necessary to allow You to exercise the Licensed 213 | Rights, but not otherwise. 214 | 215 | 2. Patent and trademark rights are not licensed under this 216 | Public License. 217 | 218 | 3. 
To the extent possible, the Licensor waives any right to 219 | collect royalties from You for the exercise of the Licensed 220 | Rights, whether directly or through a collecting society 221 | under any voluntary or waivable statutory or compulsory 222 | licensing scheme. In all other cases the Licensor expressly 223 | reserves any right to collect such royalties. 224 | 225 | 226 | Section 3 -- License Conditions. 227 | 228 | Your exercise of the Licensed Rights is expressly made subject to the 229 | following conditions. 230 | 231 | a. Attribution. 232 | 233 | 1. If You Share the Licensed Material (including in modified 234 | form), You must: 235 | 236 | a. retain the following if it is supplied by the Licensor 237 | with the Licensed Material: 238 | 239 | i. identification of the creator(s) of the Licensed 240 | Material and any others designated to receive 241 | attribution, in any reasonable manner requested by 242 | the Licensor (including by pseudonym if 243 | designated); 244 | 245 | ii. a copyright notice; 246 | 247 | iii. a notice that refers to this Public License; 248 | 249 | iv. a notice that refers to the disclaimer of 250 | warranties; 251 | 252 | v. a URI or hyperlink to the Licensed Material to the 253 | extent reasonably practicable; 254 | 255 | b. indicate if You modified the Licensed Material and 256 | retain an indication of any previous modifications; and 257 | 258 | c. indicate the Licensed Material is licensed under this 259 | Public License, and include the text of, or the URI or 260 | hyperlink to, this Public License. 261 | 262 | 2. You may satisfy the conditions in Section 3(a)(1) in any 263 | reasonable manner based on the medium, means, and context in 264 | which You Share the Licensed Material. For example, it may be 265 | reasonable to satisfy the conditions by providing a URI or 266 | hyperlink to a resource that includes the required 267 | information. 268 | 269 | 3. 
If requested by the Licensor, You must remove any of the 270 | information required by Section 3(a)(1)(A) to the extent 271 | reasonably practicable. 272 | 273 | b. ShareAlike. 274 | 275 | In addition to the conditions in Section 3(a), if You Share 276 | Adapted Material You produce, the following conditions also apply. 277 | 278 | 1. The Adapter's License You apply must be a Creative Commons 279 | license with the same License Elements, this version or 280 | later, or a BY-SA Compatible License. 281 | 282 | 2. You must include the text of, or the URI or hyperlink to, the 283 | Adapter's License You apply. You may satisfy this condition 284 | in any reasonable manner based on the medium, means, and 285 | context in which You Share Adapted Material. 286 | 287 | 3. You may not offer or impose any additional or different terms 288 | or conditions on, or apply any Effective Technological 289 | Measures to, Adapted Material that restrict exercise of the 290 | rights granted under the Adapter's License You apply. 291 | 292 | 293 | Section 4 -- Sui Generis Database Rights. 294 | 295 | Where the Licensed Rights include Sui Generis Database Rights that 296 | apply to Your use of the Licensed Material: 297 | 298 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 299 | to extract, reuse, reproduce, and Share all or a substantial 300 | portion of the contents of the database; 301 | 302 | b. if You include all or a substantial portion of the database 303 | contents in a database in which You have Sui Generis Database 304 | Rights, then the database in which You have Sui Generis Database 305 | Rights (but not its individual contents) is Adapted Material, 306 | 307 | including for purposes of Section 3(b); and 308 | c. You must comply with the conditions in Section 3(a) if You Share 309 | all or a substantial portion of the contents of the database. 
310 | 311 | For the avoidance of doubt, this Section 4 supplements and does not 312 | replace Your obligations under this Public License where the Licensed 313 | Rights include other Copyright and Similar Rights. 314 | 315 | 316 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 317 | 318 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 319 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 320 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 321 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 322 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 323 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 324 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 325 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 326 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 327 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 328 | 329 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 330 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 331 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 332 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 333 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 334 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 335 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 336 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 337 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 338 | 339 | c. The disclaimer of warranties and limitation of liability provided 340 | above shall be interpreted in a manner that, to the extent 341 | possible, most closely approximates an absolute disclaimer and 342 | waiver of all liability. 343 | 344 | 345 | Section 6 -- Term and Termination. 346 | 347 | a. 
This Public License applies for the term of the Copyright and 348 | Similar Rights licensed here. However, if You fail to comply with 349 | this Public License, then Your rights under this Public License 350 | terminate automatically. 351 | 352 | b. Where Your right to use the Licensed Material has terminated under 353 | Section 6(a), it reinstates: 354 | 355 | 1. automatically as of the date the violation is cured, provided 356 | it is cured within 30 days of Your discovery of the 357 | violation; or 358 | 359 | 2. upon express reinstatement by the Licensor. 360 | 361 | For the avoidance of doubt, this Section 6(b) does not affect any 362 | right the Licensor may have to seek remedies for Your violations 363 | of this Public License. 364 | 365 | c. For the avoidance of doubt, the Licensor may also offer the 366 | Licensed Material under separate terms or conditions or stop 367 | distributing the Licensed Material at any time; however, doing so 368 | will not terminate this Public License. 369 | 370 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 371 | License. 372 | 373 | 374 | Section 7 -- Other Terms and Conditions. 375 | 376 | a. The Licensor shall not be bound by any additional or different 377 | terms or conditions communicated by You unless expressly agreed. 378 | 379 | b. Any arrangements, understandings, or agreements regarding the 380 | Licensed Material not stated herein are separate from and 381 | independent of the terms and conditions of this Public License. 382 | 383 | 384 | Section 8 -- Interpretation. 385 | 386 | a. For the avoidance of doubt, this Public License does not, and 387 | shall not be interpreted to, reduce, limit, restrict, or impose 388 | conditions on any use of the Licensed Material that could lawfully 389 | be made without permission under this Public License. 390 | 391 | b. 
To the extent possible, if any provision of this Public License is 392 | deemed unenforceable, it shall be automatically reformed to the 393 | minimum extent necessary to make it enforceable. If the provision 394 | cannot be reformed, it shall be severed from this Public License 395 | without affecting the enforceability of the remaining terms and 396 | conditions. 397 | 398 | c. No term or condition of this Public License will be waived and no 399 | failure to comply consented to unless expressly agreed to by the 400 | Licensor. 401 | 402 | d. Nothing in this Public License constitutes or may be interpreted 403 | as a limitation upon, or waiver of, any privileges and immunities 404 | that apply to the Licensor or You, including from the legal 405 | processes of any jurisdiction or authority. 406 | 407 | 408 | ======================================================================= 409 | 410 | Creative Commons is not a party to its public licenses. 411 | Notwithstanding, Creative Commons may elect to apply one of its public 412 | licenses to material it publishes and in those instances will be 413 | considered the “Licensor.” The text of the Creative Commons public 414 | licenses is dedicated to the public domain under the CC0 Public Domain 415 | Dedication. Except for the limited purpose of indicating that material 416 | is shared under a Creative Commons public license or as otherwise 417 | permitted by the Creative Commons policies published at 418 | creativecommons.org/policies, Creative Commons does not authorize the 419 | use of the trademark "Creative Commons" or any other trademark or logo 420 | of Creative Commons without its prior written consent including, 421 | without limitation, in connection with any unauthorized modifications 422 | to any of its public licenses or any other arrangements, 423 | understandings, or agreements concerning use of licensed material. 
For 424 | the avoidance of doubt, this paragraph does not form part of the public 425 | licenses. 426 | 427 | Creative Commons may be contacted at creativecommons.org. 428 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TriageSQL 2 | The dataset and source code for our paper: ["Did You Ask a Good Question? A Cross-Domain Question Intention Classification Benchmark for Text-to-SQL"](https://arxiv.org/abs/2010.12634) 3 | 4 | # Dataset Download 5 | Due to size limitations, please download the dataset from [Google Drive](https://drive.google.com/file/d/1w55CaVEuimUlP-jerOCrVHF1iF0FZYKe/view?usp=sharing). 6 | 7 | # Citations 8 | 9 | If you want to use TriageSQL in your work, please cite as follows: 10 | ``` 11 | @article{zhang2020did, 12 | title={Did You Ask a Good Question? A Cross-Domain Question Intention Classification Benchmark for Text-to-SQL}, 13 | author={Zhang, Yusen and Dong, Xiangyu and Chang, Shuaichen and Yu, Tao and Shi, Peng and Zhang, Rui}, 14 | journal={arXiv preprint arXiv:2010.12634}, 15 | year={2020} 16 | } 17 | ``` 18 | 19 | # Dataset 20 | In each json file of the dataset, one can find a field called `type`, which takes one of 5 values: `small talk`, `answerable`, `ambiguous`, `lack data`, and `unanswerable by sql`, corresponding to the 5 question types described in our paper. 
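As a quick sanity check after downloading, the distribution of the `type` field in one split can be tallied with a short script. This is a minimal sketch: it assumes each file is a JSON list of records that each carry a `type` key, and the path in the usage comment is illustrative.

```python
import json
from collections import Counter

def count_types(path):
    """Tally the `type` field across all records in one dataset split."""
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    # Assumes the file is a JSON list of records, each with a `type` key
    # holding one of the 5 values listed above.
    return Counter(record["type"] for record in data)

# Example (illustrative path):
# count_types("model/dataset/final_devset.json")
```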
Here is a summary of our dataset and the corresponding experiment results: 21 | 22 | | Type | Trainset | Devset | Testset | Type Alias | Reported F1 | 23 | | ---- | -------- | ------ | ------- | ---------- | ----------- | 24 | | small talk | 31160 | 7790 | 500 | Improper | 0.88 | 25 | | ambiguous | 48592 | 9564 | 500 | Ambiguous | 0.43 | 26 | | lack data | 90375 | 19566 | 500 | ExtKnow | 0.56 | 27 | | unanswerable by sql | 124225 | 26330 | 500 | Non-SQL | 0.90 | 28 | | answerable | 139884 | 32892 | 500 | Answerable | 0.53 | 29 | | overall | 434236 | 194037 | 2500 | TriageSQL | 0.66 | 30 | 31 | The folder `src` contains all the source files used to construct the proposed TriageSQL. In addition, some files contain extra fields with more details about the dataset, such as `databaseid`, the id of the schema in the original dataset (e.g. "flight_2" in CoSQL), while `question_datasetid` indicates the original dataset name of the questions (e.g. "quac"). Some of the samples do not contain these fields because they are either human-annotated or edited. 32 | 33 | # Model 34 | We also include the source code for the RoBERTa baseline in `/model`. It is a multi-class classifier with 5 classes, where '0' represents answerable and '1'-'4' represent the distinct types of unanswerable questions. Given the dataset from [Google Drive](https://drive.google.com/file/d/1w55CaVEuimUlP-jerOCrVHF1iF0FZYKe/view?usp=sharing), you may need to conduct some preprocessing to obtain the train/dev/test sets. You can download them directly from [here](https://drive.google.com/file/d/1ol1xFpGuH0BdLw26MvQoeCHLOtTqQ60i/view?usp=sharing) or make your own dataset using the following instructions: 36 | 37 | ## Constructing the input file for the RoBERTa model 38 | Like `/testset/test.json`, our input file is a json list with shape (num_of_question, 3) containing 3 lists: query, schema, and label. 
38 | - query: containing the question strings 39 | - schema: containing the schema string for each question, i.e., "table_name.column_name1 | table_name.column_name2 | ..." for multi-table questions, and "column_name1 | column_name2" for single-table questions. 40 | - label: containing the labels of the questions; see `config.label_dict` for the mapping. Leave an arbitrary value if testing is not needed or true labels are not given. 41 | 42 | **When preprocessing, please lowercase all data and remove meaningless table names, such as T10023-1242. Also, we sample 10k questions from each type to form the large input dataset.** 43 | 44 | ## Running 45 | After adjusting the parameters in `config.py`, one can simply run `python train.py` or `python eval.py` to train or evaluate the model. 46 | 47 | ## Explanation of other files 48 | - config.py: hyperparameters 49 | - train.py: training and evaluation of the model 50 | - utils.py: loading the dataset and tokenization 51 | - model.py: the RoBERTa classification model we used 52 | - test.json: sample of test input 53 | 54 | 55 | -------------------------------------------------------------------------------- /gen_data/main.py: -------------------------------------------------------------------------------- 1 | from name import * 2 | from utils import * 3 | import os 4 | import time 5 | 6 | def load_spider_data(): 7 | spider_train_path = os.path.join(Spiderpath, Spidertrain) 8 | spider_others_path = os.path.join(Spiderpath, Spiderothers) 9 | spider_dev_path = os.path.join(Spiderpath, Spiderdev) 10 | spider_table_path = os.path.join(Spiderpath, Spidertable) 11 | 12 | spider_train_data = read_json(spider_train_path) 13 | spider_others_data = read_json(spider_others_path) 14 | spider_dev_data = read_json(spider_dev_path) 15 | spider_table_data = read_json(spider_table_path) 16 | 17 | total_data = [] 18 | total_data += spider_train_data 19 | total_data += spider_dev_data 20 | 21 | questions, dbschema = preprocess_data(total_data, spider_table_data, 
Spider) 22 | otherquestions, dbschema = preprocess_data(spider_others_data, spider_table_data, Spiderother) 23 | 24 | return questions, dbschema, otherquestions 25 | 26 | def load_hybridQA_data(): 27 | hybridQA_train_path = os.path.join(HybridQApath, HybridQAtrain) 28 | hybridQA_test_path = os.path.join(HybridQApath, HybridQAtest) 29 | hybridQA_dev_path = os.path.join(HybridQApath, HybridQAdev) 30 | 31 | hybridQA_train_data = read_json(hybridQA_train_path) 32 | hybridQA_test_data = read_json(hybridQA_test_path) 33 | hybridQA_dev_data = read_json(hybridQA_dev_path) 34 | hybridQA_table_data = get_hybridQA_table_data() 35 | 36 | total_data = [] 37 | total_data += hybridQA_train_data 38 | total_data += hybridQA_test_data 39 | total_data += hybridQA_dev_data 40 | 41 | questions, dbschema = preprocess_data(total_data, hybridQA_table_data, HybridQA) 42 | 43 | return questions, dbschema 44 | 45 | def load_wikiSQL_data(): 46 | wikiSQL_train_path = os.path.join(WikiSQLpath, WikiSQLtrain) 47 | wikiSQL_test_path = os.path.join(WikiSQLpath, WikiSQLtest) 48 | wikiSQL_dev_path = os.path.join(WikiSQLpath, WikiSQLdev) 49 | wikiSQL_train_table_path = os.path.join(WikiSQLpath, WikiSQLtraintable) 50 | wikiSQL_test_table_path = os.path.join(WikiSQLpath, WikiSQLtesttable) 51 | wikiSQL_dev_table_path = os.path.join(WikiSQLpath, WikiSQLdevtable) 52 | 53 | wikiSQL_train_data = read_jsonl(wikiSQL_train_path) 54 | wikiSQL_test_data = read_jsonl(wikiSQL_test_path) 55 | wikiSQL_dev_data = read_jsonl(wikiSQL_dev_path) 56 | wikiSQL_train_table_data = read_jsonl(wikiSQL_train_table_path) 57 | wikiSQL_test_table_data = read_jsonl(wikiSQL_test_table_path) 58 | wikiSQL_dev_table_data = read_jsonl(wikiSQL_dev_table_path) 59 | 60 | total_data = [] 61 | total_data += wikiSQL_train_data 62 | total_data += wikiSQL_test_data 63 | total_data += wikiSQL_dev_data 64 | 65 | total_table = [] 66 | total_table += wikiSQL_train_table_data 67 | total_table += wikiSQL_test_table_data 68 | total_table += 
wikiSQL_dev_table_data 69 | 70 | questions, dbschema = preprocess_data(total_data, total_table, WikiSQL) 71 | 72 | return questions, dbschema 73 | 74 | def load_wikitable_data(): 75 | wikitable_train_path = os.path.join(Wikitablepath, Wikitabletrain) 76 | 77 | wikitable_train_data = get_wikitable_question(wikitable_train_path) 78 | wikitable_table_data = get_wikitable_table_data() 79 | 80 | questions, dbschema = preprocess_data(wikitable_train_data, wikitable_table_data, Wikitable) 81 | 82 | return questions, dbschema 83 | 84 | def load_kvret_data(): 85 | kvret_train_path = os.path.join(Kvretpath, Kvrettrain) 86 | kvret_dev_path = os.path.join(Kvretpath, Kvretdev) 87 | 88 | kvret_train_data = read_json(kvret_train_path) 89 | kvret_dev_data = read_json(kvret_dev_path) 90 | 91 | total_data = [] 92 | total_data += kvret_train_data 93 | total_data += kvret_dev_data 94 | 95 | questions, dbschema = preprocess_data(total_data, [], Kvret) 96 | 97 | return questions, dbschema 98 | 99 | def load_tablefact_data(): 100 | tablefact_train1_path = os.path.join(Tablefactquestionpath, Tablefacttrain1) 101 | tablefact_train2_path = os.path.join(Tablefactquestionpath, Tablefacttrain2) 102 | 103 | tablefact_train1_data = read_json(tablefact_train1_path) 104 | tablefact_train2_data = read_json(tablefact_train2_path) 105 | tablefact_table_data = get_tablefact_table_data() 106 | 107 | total_data = {} 108 | total_data.update(tablefact_train1_data) 109 | total_data.update(tablefact_train2_data) 110 | 111 | questions, dbschema = preprocess_data(total_data, tablefact_table_data, Tablefact) 112 | 113 | return questions, dbschema 114 | 115 | def load_msmarco_data(): 116 | #msmarcousefulnesspath = os.path.join(Msmarcopath, Msmarcousefulness) 117 | msmarcomsmarcopath = os.path.join(Msmarcopath, Msmarcomsmarco) 118 | #questions = get_msmarco_usefulness(msmarcousefulnesspath) 119 | questions = read_json(msmarcomsmarcopath) 120 | 121 | questions, dbschema = preprocess_data(questions, [], Msmarco) 
122 | 123 | return questions, dbschema 124 | 125 | def load_wikiQA_data(): 126 | wikiqapath = os.path.join(WikiQApath, WikiQAtsv) 127 | 128 | questions = get_wikiQAtsv(wikiqapath) 129 | 130 | questions, dbschema = preprocess_data(questions, [], WikiQA) 131 | 132 | return questions, dbschema 133 | 134 | def load_coqa_data(): 135 | coqapath = os.path.join(Coqapath, Coqatrain) 136 | 137 | content = read_json(coqapath) 138 | 139 | questions, dbschema = preprocess_data(content, [], Coqa) 140 | 141 | return questions, dbschema 142 | 143 | def load_quac_data(): 144 | quacpath = os.path.join(Quacpath, Quactrain) 145 | 146 | content = read_json(quacpath) 147 | 148 | questions, dbschema = preprocess_data(content, [], Quac) 149 | 150 | return questions, dbschema 151 | 152 | def load_dbdomain_data(): 153 | dbdomainsqlitepath = os.path.join(Dbdomainpath, Dbdomainsqlite) 154 | dbdomainrevisedpath = os.path.join(Dbdomainpath, Dbdomainrevised) 155 | 156 | db2tables = get_domainsqlite(dbdomainsqlitepath, dbdomainrevisedpath) 157 | db2questionsambiguous, db2questionsnotambiguous = get_domainrevised(dbdomainrevisedpath) 158 | 159 | ambiguousquestions, dbschema = preprocess_data(db2questionsambiguous, db2tables, Dbdomainambiguous) 160 | notambiguousquestions, dbschema = preprocess_data(db2questionsnotambiguous, db2tables, Dbdomainnotambiguous) 161 | 162 | return ambiguousquestions, notambiguousquestions, dbschema 163 | 164 | def load_alex_data(): 165 | alexpath = os.path.join(Alexpath, Alexdataset) 166 | 167 | alexquestions = read_json(alexpath) 168 | 169 | questions, dbschema = preprocess_data(alexquestions, [], Alex) 170 | 171 | return questions, dbschema 172 | 173 | def load_googlenq_data(): 174 | """googlenqpath = os.path.join(Googlenqpath, Googlenqdev) 175 | 176 | googlenqquestions = read_jsonl(googlenqpath)""" 177 | googlenqpath = os.path.join(Googlenqpath, Googlenqdata) 178 | 179 | googlenqquestions = read_json(googlenqpath) 180 | 181 | questions, dbschema = 
preprocess_data(googlenqquestions, [], Googlenq) 182 | 183 | return questions, dbschema 184 | 185 | def load_totto_data(): 186 | tottotrainpath = os.path.join(Tottopath, Tottotrain) 187 | tottodevpath = os.path.join(Tottopath, Tottodev) 188 | 189 | #tottotrainquestions = read_jsonl(tottotrainpath) 190 | tottodevquestions = read_jsonl(tottodevpath) 191 | 192 | total_data = [] 193 | #total_data += tottotrainquestions 194 | total_data += tottodevquestions 195 | 196 | questions, dbschema = preprocess_data(total_data, [], Totto) 197 | 198 | return questions, dbschema 199 | 200 | def load_logicnlg_data(): 201 | logicnlgtrainpath = os.path.join(Logicnlgpath, Logicnlgtrain) 202 | #logicnlgtestpath = os.path.join(Logicnlgpath, Logicnlgtest) 203 | #logicnlgvalpath = os.path.join(Logicnlgpath, Logicnlgval) 204 | 205 | logicnlgtrainquestions = read_json(logicnlgtrainpath) 206 | 207 | questions, dbschema = preprocess_data(logicnlgtrainquestions, [], Logicnlg) 208 | 209 | return questions, dbschema 210 | 211 | def load_sparc_data(): 212 | sparctrainpath = os.path.join(Sparcpath, Sparctrain) 213 | sparcdevpath = os.path.join(Sparcpath, Sparcdev) 214 | sparctablespath = os.path.join(Sparcpath, Sparctables) 215 | 216 | sparctraindata = read_json(sparctrainpath) 217 | sparcdevdata = read_json(sparcdevpath) 218 | sparctables = read_json(sparctablespath) 219 | 220 | total_data = [] 221 | total_data += sparctraindata 222 | total_data += sparcdevdata 223 | 224 | questions, dbschema = preprocess_data(total_data, sparctables, Sparc) 225 | 226 | return questions, dbschema 227 | 228 | def load_cosql_data(): 229 | cosqldatapath = os.path.join(Cosqlpath, Cosqldialogs) 230 | cosqltrainpath = os.path.join(Cosqluserintentpath, Cosqltrain) 231 | cosqldevpath = os.path.join(Cosqluserintentpath, Cosqldev) 232 | cosqltablepath = os.path.join(Cosqlpath, Cosqltables) 233 | 234 | cosqldata = read_json(cosqldatapath) 235 | cosqltraindata = read_json(cosqltrainpath) 236 | cosqldevdata = 
read_json(cosqldevpath) 237 | cosqltables = read_json(cosqltablepath) 238 | 239 | total_data = [] 240 | total_data += cosqltraindata 241 | total_data += cosqldevdata 242 | 243 | questions, dbschema = preprocess_data(total_data, cosqltables, Cosql) 244 | questionsnotambiguous, dbschema = preprocess_data(cosqldata, cosqltables, Cosqlnotambiguous) 245 | 246 | return questions, dbschema, questionsnotambiguous 247 | 248 | def load_alexa_data(): 249 | alexatrainpath = os.path.join(Alexapath, Alexatrain) 250 | alexavalidfreqpath = os.path.join(Alexapath, Alexavalidfreq) 251 | alexavalidrarepath = os.path.join(Alexapath, Alexavalidrare) 252 | alexatestfreqpath = os.path.join(Alexapath, Alexatestfreq) 253 | alexatestrarepath = os.path.join(Alexapath, Alexatestrare) 254 | 255 | alexatraindata = read_json(alexatrainpath) 256 | alexavalidfreq = read_json(alexavalidfreqpath) 257 | alexavalidrare = read_json(alexavalidrarepath) 258 | alexatestfreq = read_json(alexatestfreqpath) 259 | alexatestrare = read_json(alexatestrarepath) 260 | 261 | total_data = {} 262 | total_data.update(alexatraindata) 263 | total_data.update(alexavalidfreq) 264 | total_data.update(alexavalidrare) 265 | total_data.update(alexatestfreq) 266 | total_data.update(alexatestrare) 267 | 268 | questions, dbschema = preprocess_data(total_data, [], Alexa) 269 | 270 | return questions, dbschema 271 | 272 | if __name__ == "__main__": 273 | start = time.time() 274 | 275 | spiderquestions, spiderdbschema, spiderotherquestions = load_spider_data() 276 | hybridQAquestions, hybridQAschema = load_hybridQA_data() 277 | wikiSQLquestions, wikiSQLschema = load_wikiSQL_data() 278 | wikitablequestions, wikitableschema = load_wikitable_data() 279 | #kvretquestions, kvretschema, kvretquestions2dataset = load_kvret_data() 280 | tablefactquestions, tablefactschema = load_tablefact_data() 281 | msmarcoquestions, msmarcoschema = load_msmarco_data() 282 | wikiqaquestions, wikiqaschema = load_wikiQA_data() 283 | coqaquestions, 
coqaschema = load_coqa_data() 284 | quacquestions, quacschema = load_quac_data() 285 | dbdomainambiguousquestions, dbdomainnotambiguousquestions, dbdomainschema = load_dbdomain_data() 286 | #alexquestions, alexschema = load_alex_data() 287 | googlenqquestions, googlenqschema = load_googlenq_data() 288 | tottoquestions, tottoschema = load_totto_data() 289 | logicnlgquestions, logicnlgschema = load_logicnlg_data() 290 | sparcquestions, sparcschema = load_sparc_data() 291 | cosqlquestions, cosqlschema, cosqlnotambiguousquestions = load_cosql_data() 292 | alexaquestions, alexaschema = load_alexa_data() 293 | 294 | total_schema = {Spider: spiderdbschema, 295 | Spiderother: spiderdbschema, 296 | HybridQA: hybridQAschema, 297 | WikiSQL: wikiSQLschema, 298 | Wikitable: wikitableschema, 299 | Tablefact: tablefactschema, 300 | Dbdomainambiguous: dbdomainschema, 301 | Dbdomainnotambiguous: dbdomainschema, 302 | Totto: tottoschema, 303 | Logicnlg: wikiSQLschema, 304 | Sparc: sparcschema, 305 | Cosql: cosqlschema, 306 | Cosqlnotambiguous: cosqlschema} 307 | 308 | train_schema, dev_schema, test_schema = splitdataschema(total_schema) 309 | 310 | """out = WikiSQL 311 | 312 | newschema = {} 313 | for key in train_schema[out]: 314 | newschema[key] = train_schema[out][key]['table'] 315 | mkdir(out) 316 | write_json( 317 | newschema, out + "/train.json" 318 | ) 319 | newschema = {} 320 | for key in test_schema[out]: 321 | newschema[key] = test_schema[out][key]['table'] 322 | mkdir(out) 323 | write_json( 324 | newschema, out + "/test.json" 325 | ) 326 | 327 | newschema = {} 328 | for key in dev_schema[out]: 329 | newschema[key] = dev_schema[out][key]['table'] 330 | mkdir(out) 331 | write_json( 332 | newschema, out + "/dev.json" 333 | )""" 334 | 335 | total_questions = {Spider: spiderquestions, 336 | Spiderother: spiderotherquestions, 337 | WikiSQL: wikiSQLquestions, 338 | Tablefact: tablefactquestions, 339 | Msmarco: msmarcoquestions, 340 | WikiQA: wikiqaquestions, 341 | Coqa: 
coqaquestions, 342 | Quac: quacquestions, 343 | Dbdomainambiguous: dbdomainambiguousquestions, 344 | Dbdomainnotambiguous: dbdomainnotambiguousquestions, 345 | #Alex: alexquestions, 346 | Googlenq: googlenqquestions, 347 | Totto: tottoquestions, 348 | Wikitable: wikitablequestions, 349 | HybridQA: hybridQAquestions, 350 | Logicnlg: logicnlgquestions, 351 | Sparc: sparcquestions, 352 | Cosql: cosqlquestions, 353 | Alexa: alexaquestions, 354 | Cosqlnotambiguous: cosqlnotambiguousquestions 355 | } 356 | 357 | 358 | question_count = defaultdict(int) 359 | for datasetid in total_questions: 360 | dataset = total_questions[datasetid] 361 | for dbid in dataset: 362 | db = dataset[dbid] 363 | question_count[datasetid] += len(db) 364 | 365 | 366 | print("question_num:") 367 | for datasetid in question_count: 368 | print(datasetid + ": " + str(question_count[datasetid])) 369 | print("------------------") 370 | 371 | #total_questions = filterquestion(totalq) 372 | 373 | trainq, devq, testq = splittype1question(total_questions) 374 | 375 | train_dataset = defaultdict(lambda : defaultdict(int)) 376 | dev_dataset = defaultdict(lambda : defaultdict(int)) 377 | test_dataset = defaultdict(lambda : defaultdict(int)) 378 | 379 | train_question = defaultdict(lambda: defaultdict(int)) 380 | dev_question = defaultdict(lambda: defaultdict(int)) 381 | test_question = defaultdict(lambda: defaultdict(int)) 382 | 383 | print("trainset:") 384 | train_type1 = gen_type1(trainq, train_schema, train_dataset, train_question) 385 | print("type1 dataset: " + str(dict(train_dataset[Type1]))) 386 | print("type1 question: " + str(dict(train_question[Type1]))) 387 | train_type2 = gen_type2(total_questions, train_schema, train_dataset, train_question) 388 | print("type2 dataset: " + str(dict(train_dataset[Type2]))) 389 | print("type2 question: " + str(dict(train_question[Type2]))) 390 | train_type3 = gen_type3(total_questions, train_schema, train_dataset, train_question) 391 | print("type3 dataset: " + 
str(dict(train_dataset[Type3]))) 392 | print("type3 question: " + str(dict(train_question[Type3]))) 393 | train_type4 = gen_type4(total_questions, train_schema, train_dataset, train_question) 394 | print("type4 dataset: " + str(dict(train_dataset[Type4]))) 395 | print("type4 question: " + str(dict(train_question[Type4]))) 396 | train_type5 = gen_type5(total_questions, train_schema, train_dataset, train_question) 397 | print("type5 dataset: " + str(dict(train_dataset[Type5]))) 398 | print("type5 question: " + str(dict(train_question[Type5]))) 399 | mkdir(Train) 400 | write_json(train_type1, os.path.join(Train, Type1json)) 401 | write_json(train_type2, os.path.join(Train, Type2json)) 402 | write_json(train_type3, os.path.join(Train, Type3json)) 403 | write_json(train_type4, os.path.join(Train, Type4json)) 404 | write_json(train_type5, os.path.join(Train, Type5json)) 405 | print("------------------") 406 | 407 | print("devset:") 408 | dev_type1 = gen_type1(devq, dev_schema, dev_dataset, dev_question) 409 | print("type1 dataset: " + str(dict(dev_dataset[Type1]))) 410 | print("type1 question: " + str(dict(dev_question[Type1]))) 411 | dev_type2 = gen_type2(total_questions, dev_schema, dev_dataset, dev_question) 412 | print("type2 dataset: " + str(dict(dev_dataset[Type2]))) 413 | print("type2 question: " + str(dict(dev_question[Type2]))) 414 | dev_type3 = gen_type3(total_questions, dev_schema, dev_dataset, dev_question) 415 | print("type3 dataset: " + str(dict(dev_dataset[Type3]))) 416 | print("type3 question: " + str(dict(dev_question[Type3]))) 417 | dev_type4 = gen_type4(total_questions, dev_schema, dev_dataset, dev_question) 418 | print("type4 dataset: " + str(dict(dev_dataset[Type4]))) 419 | print("type4 question: " + str(dict(dev_question[Type4]))) 420 | dev_type5 = gen_type5(total_questions, dev_schema, dev_dataset, dev_question) 421 | print("type5 dataset: " + str(dict(dev_dataset[Type5]))) 422 | print("type5 question: " + str(dict(dev_question[Type5]))) 423 | 
mkdir(Dev) 424 | write_json(dev_type1, os.path.join(Dev, Type1json)) 425 | write_json(dev_type2, os.path.join(Dev, Type2json)) 426 | write_json(dev_type3, os.path.join(Dev, Type3json)) 427 | write_json(dev_type4, os.path.join(Dev, Type4json)) 428 | write_json(dev_type5, os.path.join(Dev, Type5json)) 429 | print("------------------") 430 | 431 | print("testset:") 432 | test_type1 = gen_type1(testq, test_schema, test_dataset, test_question) 433 | print("type1 dataset: " + str(dict(test_dataset[Type1]))) 434 | print("type1 question: " + str(dict(test_question[Type1]))) 435 | test_type2 = gen_type2(total_questions, test_schema, test_dataset, test_question) 436 | print("type2 dataset: " + str(dict(test_dataset[Type2]))) 437 | print("type2 question: " + str(dict(test_question[Type2]))) 438 | test_type3 = gen_type3(total_questions, test_schema, test_dataset, test_question) 439 | print("type3 dataset: " + str(dict(test_dataset[Type3]))) 440 | print("type3 question: " + str(dict(test_question[Type3]))) 441 | test_type4 = gen_type4(total_questions, test_schema, test_dataset, test_question) 442 | print("type4 dataset: " + str(dict(test_dataset[Type4]))) 443 | print("type4 question: " + str(dict(test_question[Type4]))) 444 | test_type5 = gen_type5(total_questions, test_schema, test_dataset, test_question) 445 | print("type5 dataset: " + str(dict(test_dataset[Type5]))) 446 | print("type5 question: " + str(dict(test_question[Type5]))) 447 | mkdir(Test) 448 | write_json(test_type1, os.path.join(Test, Type1json)) 449 | write_json(test_type2, os.path.join(Test, Type2json)) 450 | write_json(test_type3, os.path.join(Test, Type3json)) 451 | write_json(test_type4, os.path.join(Test, Type4json)) 452 | write_json(test_type5, os.path.join(Test, Type5json)) 453 | print("------------------") 454 | 455 | #get_stat(train_question, test_question, dev_question) 456 | 457 | end = time.time() 458 | print(end - start) 459 | 460 | 
-------------------------------------------------------------------------------- /gen_data/name.py: -------------------------------------------------------------------------------- 1 | Spider = "spider" 2 | Spiderother = "spiderother" 3 | Spiderpath = '../dataset/spider' 4 | Spidertrain = "train_spider.json" 5 | Spiderothers = "train_others.json" 6 | Spiderdev = "dev.json" 7 | Spidertable = "tables.json" 8 | Spidertablename = "table_names" 9 | Spidercolumnname = "column_names" 10 | Spiderdbid = "db_id" 11 | Spiderquerytok = "query_toks" 12 | 13 | 14 | HybridQA = "hybridqa" 15 | HybridQApath = '../dataset/HybridQA/released_data' 16 | HybridQAtrain = 'train.json' 17 | HybridQAtest = 'test.json' 18 | HybridQAdev = 'dev.json' 19 | HybridQAdbid = 'table_id' 20 | HybridQAtablepath = '../dataset/HybridQA/tables' 21 | HybridQAheader = "header" 22 | HybridQAtableid = 'idx' 23 | 24 | WikiSQL = "wikisql" 25 | WikiSQLpath = '../dataset/WikiSQL/data' 26 | WikiSQLtrain = 'train.jsonl' 27 | WikiSQLtest = 'test.jsonl' 28 | WikiSQLdev = 'dev.jsonl' 29 | WikiSQLtraintable = 'train.tables.jsonl' 30 | WikiSQLtesttable = 'test.tables.jsonl' 31 | WikiSQLdevtable = 'dev.tables.jsonl' 32 | WikiSQLdbid = 'table_id' 33 | WikiSQLtable = 'id' 34 | WikiSQLheader = 'header' 35 | WikiSQLsql = "sql" 36 | WikiSQLsel = "sel" 37 | WikiSQLconds = "conds" 38 | 39 | Wikitable = "wikitable" 40 | Wikitablepath = '../dataset/WikiTableQuestions/data' 41 | Wikitabletrain = 'training.tsv' 42 | Wikitabledir = '../dataset/WikiTableQuestions/' 43 | Wikitabledbid = 'dbid' 44 | 45 | Kvret = "kvret" 46 | Kvretpath = '../dataset/dialog_datasets/kvret' 47 | Kvrettrain = 'kvret_train_public.json' 48 | Kvretdev = 'kvret_dev_public.json' 49 | Kvretscenario = "scenario" 50 | Kvretuuid = 'uuid' 51 | Kvretdialogue = 'dialogue' 52 | Kvretdata = 'data' 53 | Kvretutterance = 'utterance' 54 | 55 | Tablefact = "tablefact" 56 | Tablefactquestionpath = '../dataset/Table-Fact-Checking/collected_data' 57 | Tablefacttrain1 = 
'r1_training_all.json' 58 | Tablefacttrain2 = 'r2_training_all.json' 59 | Tablefacttablepath = '../dataset/Table-Fact-Checking/data/all_csv' 60 | 61 | Msmarco = "msmarco" 62 | Msmarcopath = "../dataset/msmarco" 63 | Msmarcousefulness = "Usefulness.tsv" 64 | Msmarcodev = "dev_v2.1.json" 65 | Msmarcomsmarco = "msmarco.json" 66 | 67 | WikiQA = "wikiqa" 68 | WikiQApath = '../dataset/WikiQACorpus' 69 | WikiQAtsv = 'WikiQA.tsv' 70 | 71 | Coqa = "coqa" 72 | Coqapath = "../dataset/coqa" 73 | Coqatrain = "coqa-train-v1.0.json" 74 | Coqadata = "data" 75 | Coqaquestions = "questions" 76 | Coqainput = "input_text" 77 | 78 | Quac = "quac" 79 | Quacpath = "../dataset/quac" 80 | Quactrain = "train_v0.2.json" 81 | Quacdata = "data" 82 | Quacparagraphs = "paragraphs" 83 | Quacqas = "qas" 84 | Quacquestion = "question" 85 | 86 | Dbdomainambiguous = "dbdomainambiguous" 87 | Dbdomainnotambiguous = "dbdomainnotambiguous" 88 | Dbdomainpath = "../dataset/db-domain-adaptation" 89 | Dbdomainsqlite = "sqlite_fiiles" 90 | Dbdomainrevised = "annotated-data/manually-labelled-data" 91 | Dbdomainsentences = "sentences" 92 | Dbdomainfulltext = "full-text" 93 | Dbdomainmetafeature = "metafeature" 94 | 95 | Alex = "alex" 96 | Alexpath = "../dataset/alex_context_nlg_dataset" 97 | Alexdataset = "dataset.json" 98 | Alexcontextuttl = "context_utt_l" 99 | 100 | Googlenq = "googlenq" 101 | Googlenqpath = "../dataset/v1.0-simplified_nq-dev-all" 102 | Googlenqdev = "v1.0-simplified_nq-dev-all.jsonl" 103 | Googlenqquestion = "question_text" 104 | Googlenqdata = "googlenq.json" 105 | 106 | Totto = "totto" 107 | Tottopath = "../dataset/totto_data" 108 | Tottotrain = "totto_train_data.jsonl" 109 | Tottodev = "totto_dev_data.jsonl" 110 | Tottotable = "table" 111 | Tottovalue = "value" 112 | Tottotablesectiontitle = "table_section_title" 113 | Tottofinalsentence = "final_sentence" 114 | Tottosentenceannotations = "sentence_annotations" 115 | 116 | Logicnlg = "logicnlg" 117 | Logicnlgpath = 
"../dataset/LogicNLG/data" 118 | Logicnlgtrain = "train_lm.json" 119 | Logicnlgtest = "test_lm.json" 120 | Logicnlgval = "val_lm.json" 121 | 122 | Sparc = "sparc" 123 | Sparcpath = "../dataset/sparc" 124 | Sparctrain = "train.json" 125 | Sparctables = "tables.json" 126 | Sparcdev = "dev.json" 127 | Sparcdatabaseid = "database_id" 128 | Sparcinteraction = "interaction" 129 | Sparcquery = "query" 130 | Sparcutterance = "utterance" 131 | Sparcfinal = "final" 132 | Sparctablename = "table_names" 133 | Sparccolumnname = "column_names" 134 | Sparcdbid = "db_id" 135 | 136 | Cosql = "cosql" 137 | Cosqlnotambiguous = "cosqlnotambiguous" 138 | Cosqlpath = "../dataset/cosql_dataset" 139 | Cosqltables = "tables.json" 140 | Cosqldialogs = "cosql_all_info_dialogs.json" 141 | Cosqldbid = "db_id" 142 | Cosqlquerygoal = "query_goal" 143 | Cosqlsql = "sql" 144 | Cosqltablename = "table_names" 145 | Cosqlcolumnname = "column_names" 146 | Cosqltrain = "cosql_train.json" 147 | Cosqldev = "cosql_dev.json" 148 | Cosqluserintentpath = "../dataset/cosql_dataset/user_intent_prediction" 149 | Cosqlutterance = "utterance" 150 | Cosqlintent = "intent" 151 | Cosqluserdbid = "database_id" 152 | Cosqlambiguous = "AMBIGUOUS" 153 | 154 | Alexa = "alexa" 155 | Alexapath = "../dataset/Topical-Chat/conversations" 156 | Alexatrain = "train.json" 157 | Alexavalidfreq = "valid_freq.json" 158 | Alexavalidrare = "valid_rare.json" 159 | Alexatestfreq = "test_freq.json" 160 | Alexatestrare = "test_rare.json" 161 | Alexacontent = "content" 162 | Alexamessage = "message" 163 | 164 | 165 | Csv = 'csv' 166 | Column = "column" 167 | Table = "table" 168 | Question = "question" 169 | Outputpath = "./output.json" 170 | Type1 = 'small talk' 171 | Type2 = 'ambiguous' 172 | Type3 = 'lack data' 173 | Type4 = 'unanswerable by sql' 174 | Type5 = 'answerable' 175 | Train = "trainset" 176 | Dev = "devset" 177 | Test = "testset" 178 | Type1json = "type1.json" 179 | Type2json = "type2.json" 180 | Type3json = "type3.json" 181 
| Type4json = "type4.json" 182 | Type5json = "type5.json" 183 | Trainportion = 0.64 184 | Devportion = 0.8 185 | Testportion = 1 186 | 187 | Outtype = "type" 188 | Outquestion = "question" 189 | Outquestiondatasetid = "question_datasetid" 190 | Outdatabaseiddatasetid = "databaseid_datasetid" 191 | Outdatabaseid = "databaseid" 192 | Outtables = "tables" 193 | 194 | type1qdataset = [Msmarco, WikiQA, Coqa, Quac, Googlenq, Alexa] 195 | type1dataset = [Spider, WikiSQL, Tablefact, HybridQA, Wikitable, Dbdomainnotambiguous, Dbdomainambiguous, 196 | Totto, Logicnlg, Sparc, Cosql] 197 | type2dataset = [Dbdomainambiguous, Cosql] 198 | type3dataset = [Spider, WikiSQL, Sparc, Dbdomainnotambiguous, Cosqlnotambiguous, HybridQA, Spiderother] 199 | type4dataset = [Tablefact, Totto, Wikitable, Logicnlg] 200 | type5dataset = [Spider, WikiSQL, Sparc, Dbdomainnotambiguous, Cosqlnotambiguous, Spiderother] -------------------------------------------------------------------------------- /gen_data/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import defaultdict 3 | from name import * 4 | import random 5 | import os 6 | import re 7 | import copy 8 | import sqlite3 9 | import pandas as pd 10 | from collections import OrderedDict 11 | from pyecharts import Bar 12 | 13 | random.seed(133) 14 | 15 | def mkdir(path): 16 | if os.path.exists(path): 17 | print(path + " already exists!") 18 | else: 19 | os.mkdir(path) 20 | 21 | def read_json(path): 22 | f = open(path, "r", encoding='utf-8') 23 | content = json.load(f, object_pairs_hook=OrderedDict) 24 | f.close() 25 | return content 26 | 27 | def read_jsonl(path): 28 | f = open(path, "r", encoding='utf-8') 29 | lines = f.readlines() 30 | f.close() 31 | content = [] 32 | for line in lines: 33 | tmp = json.loads(line)  # parse each JSON Lines record; json.loads handles true/false/null natively and avoids unsafe eval 34 | content.append(tmp) 35 | return content 36 | 37 | def get_wikitable_question(path): 38 | f = open(path, "r",
encoding='utf-8') 39 | lines = f.readlines() 40 | f.close() 41 | content = [] 42 | for i, line in enumerate(lines): 43 | if i == 0: 44 | continue 45 | tmp = line.split("\t") 46 | content.append({Question: tmp[1], Wikitabledbid: tmp[2]}) 47 | return content 48 | 49 | def get_wikitable_table_data(): 50 | wikitablecsv = os.path.join(Wikitabledir, Csv) 51 | csvdirs = os.listdir(wikitablecsv) 52 | wikitable_table_data = {} 53 | for csvdir in csvdirs: 54 | wikitablecsvdirpath = os.path.join(wikitablecsv, csvdir) 55 | csvfiles = os.listdir(wikitablecsvdirpath) 56 | for csvfile in csvfiles: 57 | if csvfile.endswith('.csv'):  # safer than split('.')[1], which raises IndexError on names without a dot 58 | csvpath = os.path.join(wikitablecsvdirpath, csvfile) 59 | f = open(csvpath, "r", encoding="utf-8") 60 | header = f.readline() 61 | f.close() 62 | wikitable_table_data[Csv + '/' + csvdir + '/' + csvfile] = list(header.strip().lower().replace("\"", "").split(",")) 63 | 64 | return wikitable_table_data 65 | 66 | def get_hybridQA_table_data(): 67 | tablefiles = os.listdir(HybridQAtablepath) 68 | hybridQA_table_data = defaultdict(list) 69 | for tablefile in tablefiles: 70 | tablefilepath = os.path.join(HybridQAtablepath, tablefile) 71 | content = read_json(tablefilepath) 72 | dbid = content[HybridQAtableid] 73 | header = content[HybridQAheader] 74 | for item in header: 75 | column = item[0][0].lower() 76 | hybridQA_table_data[dbid].append(column) 77 | return hybridQA_table_data 78 | 79 | def get_tablefact_table_data(): 80 | tablefiles = os.listdir(Tablefacttablepath) 81 | tablefact_table_data = defaultdict(list) 82 | for tablefile in tablefiles: 83 | tablefilepath = os.path.join(Tablefacttablepath, tablefile) 84 | f = open(tablefilepath, "r", encoding="utf-8") 85 | title = f.readline() 86 | f.close() 87 | column = title.strip().lower().split("#") 88 | tablefact_table_data[tablefile] = column 89 | return tablefact_table_data 90 | 91 | def get_msmarco_usefulness(path): 92 | f = open(path, "r", encoding='utf-8') 93 | lines = f.readlines() 94 | 
f.close() 95 | questions = set() 96 | for i, line in enumerate(lines): 97 | if i == 0: 98 | continue 99 | tmp = line.split("\t") 100 | questions.add(tmp[1].lower()) 101 | return questions 102 | 103 | def get_wikiQAtsv(path): 104 | f = open(path, "r", encoding='utf-8') 105 | lines = f.readlines() 106 | f.close() 107 | questions = set() 108 | for i, line in enumerate(lines): 109 | if i == 0: 110 | continue 111 | tmp = line.split("\t") 112 | questions.add(tmp[1].lower()) 113 | return questions 114 | 115 | def get_domainsqlite(path1, path2): 116 | files = os.listdir(path2) 117 | db2table = defaultdict(dict) 118 | for file in files: 119 | filepath = os.path.join(path1, file.split(".")[0] + ".sqlite") 120 | with sqlite3.connect(filepath) as con: 121 | c = con.cursor() 122 | for tables in c.execute("SELECT name FROM sqlite_master WHERE type='table'"): 123 | for table in tables: 124 | df = pd.read_sql_query('SELECT * FROM "' + str(table) + '"', con=con)  # quote the identifier so reserved words or spaces in table names do not break the query 125 | db2table[file.split(".")[0].lower()][str(table).lower()] = [column.replace("\"", "").lower() for column in list(df)] 126 | return db2table 127 | 128 | def get_domainrevised(path): 129 | files = os.listdir(path) 130 | db2questionsambiguous = defaultdict(list) 131 | db2questionsnotambiguous = defaultdict(list) 132 | for file in files: 133 | filepath = os.path.join(path, file) 134 | content = read_json(filepath) 135 | for item in content: 136 | metafeature = item[Dbdomainmetafeature] 137 | if metafeature: 138 | sentences = item[Dbdomainsentences] 139 | for sentence in sentences: 140 | question = sentence[Dbdomainfulltext].lower() 141 | db2questionsambiguous[file.split(".")[0].lower()].append(question) 142 | else: 143 | sentences = item[Dbdomainsentences] 144 | for sentence in sentences: 145 | question = sentence[Dbdomainfulltext].lower() 146 | db2questionsnotambiguous[file.split(".")[0].lower()].append(question) 147 | return db2questionsambiguous, db2questionsnotambiguous 148 | 149 | def preprocess_data(dataset, datadb, 
dataname): 150 | questions = defaultdict(list) 151 | dbschema = defaultdict(dict) 152 | 153 | if dataname == Spider: 154 | for item in dataset: 155 | dbid = item[Spiderdbid].lower() 156 | question = item[Question].lower() 157 | query = set([tok.lower().replace("_", " ").split(".")[-1] for tok in item[Spiderquerytok]]) 158 | questions[dbid].append((question, Spider, query)) 159 | for item in datadb: 160 | dbid = item[Spiderdbid].lower() 161 | table_names = item[Spidertablename] 162 | table_names = dict(zip(range(len(table_names)), table_names)) 163 | column_names = item[Spidercolumnname] 164 | table2column = defaultdict(list) 165 | column_set = set() 166 | for column in column_names: 167 | if column[0] in table_names: 168 | table2column[table_names[column[0]].lower()].append(column[1].lower()) 169 | column_set.add(column[1].lower()) 170 | dbschema[dbid][Table] = table2column 171 | dbschema[dbid][Column] = column_set 172 | elif dataname == Spiderother: 173 | for item in dataset: 174 | dbid = item[Spiderdbid].lower() 175 | question = item[Question].lower() 176 | query = set([tok.lower().replace("_", " ").split(".")[-1] for tok in item[Spiderquerytok]]) 177 | questions[dbid].append((question, Spiderother, query)) 178 | for item in datadb: 179 | dbid = item[Spiderdbid].lower() 180 | table_names = item[Spidertablename] 181 | table_names = dict(zip(range(len(table_names)), table_names)) 182 | column_names = item[Spidercolumnname] 183 | table2column = defaultdict(list) 184 | column_set = set() 185 | for column in column_names: 186 | if column[0] in table_names: 187 | table2column[table_names[column[0]].lower()].append(column[1].lower()) 188 | column_set.add(column[1].lower()) 189 | dbschema[dbid][Table] = table2column 190 | dbschema[dbid][Column] = column_set 191 | elif dataname == HybridQA: 192 | for item in dataset: 193 | dbid = str(item[HybridQAdbid]).lower() 194 | question = item[Question].lower() 195 | query = set(question.lower().split(" ")) 196 | #query = set() 197 
| questions[dbid].append((question, HybridQA, query)) 198 | for item in datadb: 199 | table = datadb[item] 200 | dbschema[str(item).lower()][Table] = {str(item).lower(): table} 201 | dbschema[str(item).lower()][Column] = set(table) 202 | elif dataname == WikiSQL: 203 | for item in datadb: 204 | dbid = item[WikiSQLtable] 205 | dbschema[dbid][Table] = {dbid: [name.lower() for name in item[WikiSQLheader]]} 206 | dbschema[dbid][Column] = set([name.lower() for name in item[WikiSQLheader]]) 207 | for item in dataset: 208 | dbid = item[WikiSQLdbid].lower() 209 | question = item[Question].lower() 210 | sql = item[WikiSQLsql] 211 | query_index = set() 212 | query = set() 213 | query_index.add(sql[WikiSQLsel]) 214 | for item in sql[WikiSQLconds]: 215 | query_index.add(item[0]) 216 | for i in query_index: 217 | query.add(dbschema[dbid][Table][dbid][i]) 218 | #query = set(question.lower().split(" ")) 219 | questions[dbid].append((question, WikiSQL, query)) 220 | elif dataname == Wikitable: 221 | for item in dataset: 222 | dbid = item[Wikitabledbid].lower() 223 | question = item[Question].lower() 224 | #query = set(question.lower().split(" ")) 225 | query = set() 226 | questions[dbid].append((question, Wikitable, query)) 227 | for item in datadb: 228 | table = datadb[item] 229 | dbschema[item.lower()][Table] = {item.lower(): table} 230 | dbschema[item.lower()][Column] = set(table) 231 | elif dataname == Kvret: 232 | for item in dataset: 233 | dbid = item[Kvretscenario][Kvretuuid].lower() 234 | for term in item[Kvretdialogue]: 235 | question = term[Kvretdata][Kvretutterance].lower() 236 | #query = set(question.lower().split(" ")) 237 | query = set() 238 | questions[dbid].append((question, Kvret, query)) 239 | elif dataname == Tablefact: 240 | for item in dataset: 241 | dbid = item 242 | for question in dataset[dbid][0]: 243 | #query = set(question.lower().split(" ")) 244 | query = set() 245 | questions[dbid.split(".")[0]].append((question, Tablefact, query)) 246 | for item in 
datadb: 247 | table = datadb[item] 248 | dbid = item.split(".")[0] 249 | dbschema[dbid][Table] = {dbid: table} 250 | dbschema[dbid][Column] = set(table) 251 | elif dataname == Msmarco: 252 | for item in dataset: 253 | dbid = Msmarco + str(len(questions)) 254 | query = set() 255 | questions[dbid].append((item.lower(), Msmarco, query)) 256 | elif dataname == WikiQA: 257 | for item in dataset: 258 | dbid = WikiQA + str(len(questions)) 259 | query = set() 260 | questions[dbid].append((item.lower(), WikiQA, query)) 261 | elif dataname == Coqa: 262 | coqadata = dataset[Coqadata] 263 | for item in coqadata: 264 | coqaquestions = item[Coqaquestions] 265 | for question in coqaquestions: 266 | dbid = Coqa + str(len(questions)) 267 | query = set() 268 | questions[dbid].append((question[Coqainput].lower(), Coqa, query)) 269 | elif dataname == Quac: 270 | quacdata = dataset[Quacdata] 271 | for item in quacdata: 272 | quacparagraphs = item[Quacparagraphs] 273 | for paragraph in quacparagraphs: 274 | quacqas = paragraph[Quacqas] 275 | for qa in quacqas: 276 | question = qa[Quacquestion] 277 | dbid = Quac + str(len(questions)) 278 | query = set() 279 | questions[dbid].append((question.lower(), Quac, query)) 280 | elif dataname == Dbdomainambiguous: 281 | for item in dataset: 282 | for question in dataset[item]: 283 | query = set() 284 | questions[item.lower()].append((question.lower(), Dbdomainambiguous, query)) 285 | for item in datadb: 286 | tables = datadb[item] 287 | columns = set() 288 | dbschema[item][Table] = tables 289 | for table in tables: 290 | for column in tables[table]: 291 | columns.add(column) 292 | dbschema[item][Column] = columns 293 | elif dataname == Dbdomainnotambiguous: 294 | for item in dataset: 295 | for question in dataset[item]: 296 | query = set() 297 | questions[item.lower()].append((question.lower(), Dbdomainnotambiguous, query)) 298 | for item in datadb: 299 | tables = datadb[item] 300 | columns = set() 301 | dbschema[item][Table] = tables 302 | for 
table in tables: 303 | for column in tables[table]: 304 | columns.add(column) 305 | dbschema[item][Column] = columns 306 | elif dataname == Alex: 307 | for item in dataset: 308 | question = item[Alexcontextuttl] 309 | query = set() 310 | dbid = Alex + str(len(questions)) 311 | questions[dbid].append((question.lower(), Alex, query)) 312 | elif dataname == Googlenq: 313 | for item in dataset: 314 | #question = item[Googlenqquestion] 315 | question = item 316 | query = set() 317 | dbid = Googlenq + str(len(questions)) 318 | questions[dbid].append((question.lower(), Googlenq, query)) 319 | elif dataname == Totto: 320 | for item in dataset: 321 | tottoquestions = item[Tottosentenceannotations] 322 | query = set() 323 | dbid = item[Tottotablesectiontitle] 324 | tottotables = item[Tottotable][0] 325 | for question in tottoquestions: 326 | questions[dbid].append((question[Tottofinalsentence].lower(), Totto, query)) 327 | tables = {dbid: []} 328 | columns = set() 329 | for column in tottotables: 330 | tables[dbid].append(column[Tottovalue].lower()) 331 | columns.add(column[Tottovalue].lower()) 332 | dbschema[dbid][Table] = tables 333 | dbschema[dbid][Column] = columns 334 | elif dataname == Logicnlg: 335 | for item in dataset: 336 | dbid = item.split(".")[0] 337 | logicnlgquestions = dataset[item] 338 | query = set() 339 | for question in logicnlgquestions: 340 | questions[dbid].append((question[0].lower(), Logicnlg, query)) 341 | elif dataname == Sparc: 342 | for item in dataset: 343 | dbid = item[Sparcdatabaseid].lower() 344 | interactions = item[Sparcinteraction] 345 | for interaction in interactions: 346 | query = set(interaction[Sparcquery].lower().split()) 347 | question = interaction[Sparcutterance].lower() 348 | questions[dbid].append((question, Sparc, query)) 349 | final = item[Sparcfinal] 350 | query = set(final[Sparcquery].lower().split()) 351 | question = final[Sparcutterance].lower()  # lowercase for consistency with the interaction utterances above 352 | questions[dbid].append((question, Sparc, query)) 353 | for item in datadb: 354 | 
dbid = item[Sparcdbid].lower() 355 | table_names = item[Sparctablename] 356 | table_names = dict(zip(range(len(table_names)), table_names)) 357 | column_names = item[Sparccolumnname] 358 | table2column = defaultdict(list) 359 | column_set = set() 360 | for column in column_names: 361 | if column[0] in table_names: 362 | table2column[table_names[column[0]].lower()].append(column[1].lower()) 363 | column_set.add(column[1].lower()) 364 | dbschema[dbid][Table] = table2column 365 | dbschema[dbid][Column] = column_set 366 | elif dataname == Cosql: 367 | for item in dataset: 368 | """dialog = dataset[item] 369 | dbid = dialog[Cosqldbid] 370 | question = dialog[Cosqlquerygoal].lower() 371 | query = set(dialog[Cosqlsql].lower().split(" ")) 372 | questions[dbid].append((question, Cosql, query))""" 373 | 374 | dbid = item[Cosqluserdbid].lower() 375 | question = item[Cosqlutterance].lower() 376 | intent = item[Cosqlintent] 377 | query = set() 378 | if intent and intent[0] == Cosqlambiguous: 379 | questions[dbid].append((question, Cosql, query)) 380 | for item in datadb: 381 | dbid = item[Cosqldbid].lower() 382 | table_names = item[Cosqltablename] 383 | table_names = dict(zip(range(len(table_names)), table_names)) 384 | column_names = item[Cosqlcolumnname] 385 | table2column = defaultdict(list) 386 | column_set = set() 387 | for column in column_names: 388 | if column[0] in table_names: 389 | table2column[table_names[column[0]].lower()].append(column[1].lower()) 390 | column_set.add(column[1].lower()) 391 | dbschema[dbid][Table] = table2column 392 | dbschema[dbid][Column] = column_set 393 | elif dataname == Cosqlnotambiguous: 394 | for item in dataset: 395 | dialog = dataset[item] 396 | dbid = dialog[Cosqldbid] 397 | question = dialog[Cosqlquerygoal].lower() 398 | query = set(dialog[Cosqlsql].lower().split(" ")) 399 | questions[dbid].append((question, Cosql, query)) 400 | for item in datadb: 401 | dbid = item[Cosqldbid].lower() 402 | table_names = item[Cosqltablename] 403 | 
table_names = dict(zip(range(len(table_names)), table_names)) 404 | column_names = item[Cosqlcolumnname] 405 | table2column = defaultdict(list) 406 | column_set = set() 407 | for column in column_names: 408 | if column[0] in table_names: 409 | table2column[table_names[column[0]].lower()].append(column[1].lower()) 410 | column_set.add(column[1].lower()) 411 | dbschema[dbid][Table] = table2column 412 | dbschema[dbid][Column] = column_set 413 | elif dataname == Alexa: 414 | for item in dataset: 415 | dbid = item 416 | datas = dataset[item] 417 | content = datas[Alexacontent] 418 | for item in content: 419 | question = item[Alexamessage] 420 | query = set() 421 | questions[dbid].append((question, Alexa, query)) 422 | else: 423 | print("wrong type!") 424 | exit(0) 425 | 426 | return questions, dbschema 427 | 428 | def splitdataschema(total_schema): 429 | print("schema_num: ") 430 | for datasetid in total_schema: 431 | print(datasetid + ": " + str(len(total_schema[datasetid]))) 432 | print("------------------") 433 | 434 | train_schema = defaultdict(dict) 435 | dev_schema = defaultdict(dict) 436 | test_schema = defaultdict(dict) 437 | 438 | dbid2dataset = defaultdict(list) 439 | dbids = {} 440 | nonoverlapdbid = defaultdict(list) 441 | overlapdbid = defaultdict(list) 442 | tmpoverlapdbid = defaultdict(str) 443 | 444 | nonoverlapcolumns = [] 445 | tmpcolumns = defaultdict(list) 446 | tmpcolumnsset = defaultdict(set) 447 | for datasetid in total_schema: 448 | dataset = total_schema[datasetid] 449 | for dbid in dataset: 450 | tables = total_schema[datasetid][dbid][Table] 451 | tmp = "" 452 | tablekeys = sorted(list(tables.keys())) 453 | for tableid in tablekeys: 454 | #tmp = str(sorted(tables[tableid])) 455 | tmp += '#'.join(sorted(tables[tableid])).lower().replace("_", " ") + "#" 456 | #.replace(" ", "").replace(".", "").replace("*", "") 457 | value = datasetid + "#" + dbid 458 | if not value in tmpcolumnsset[tmp]: 459 | tmpcolumns[tmp].append(value) 460 | 
tmpcolumnsset[tmp].add(value) 461 | 462 | dbids.update(dataset) 463 | 464 | for tmp in tmpcolumns: 465 | nonoverlapcolumns.append(tmpcolumns[tmp]) 466 | 467 | overlapcount = defaultdict(int) 468 | 469 | for item in nonoverlapcolumns: 470 | if len(item) > 1: 471 | tmp = "" 472 | for term in item: 473 | tmpterm = term.split("#")[0] 474 | tmp += tmpterm + "-" 475 | tmp = "-".join(sorted(list(set(tmp[:-1].split("-"))))) 476 | overlapcount[tmp] += 1 477 | 478 | print("overlap database columns:") 479 | 480 | for datasetid in overlapcount: 481 | print(datasetid + ": " + str(overlapcount[datasetid])) 482 | print("------------------") 483 | 484 | random.shuffle(nonoverlapcolumns) 485 | for i, item in enumerate(nonoverlapcolumns): 486 | overlaplen = len(nonoverlapcolumns) 487 | trainlen = int(Trainportion * overlaplen) 488 | devlen = int(Devportion * overlaplen) 489 | testlen = int(Testportion * overlaplen) 490 | for term in item: 491 | datasetid = term.split("#")[0] 492 | dbid = term.split("#")[1] 493 | 494 | if i < trainlen: 495 | train_schema[datasetid][dbid] = {'table': total_schema[datasetid][dbid][Table], 'column': total_schema[datasetid][dbid][Column]} 496 | elif i >= trainlen and i < devlen: 497 | dev_schema[datasetid][dbid] = {'table': total_schema[datasetid][dbid][Table], 'column': total_schema[datasetid][dbid][Column]} 498 | elif i >= devlen and i < testlen: 499 | test_schema[datasetid][dbid] = {'table': total_schema[datasetid][dbid][Table], 'column': total_schema[datasetid][dbid][Column]} 500 | 501 | """for datasetid in total_schema: 502 | dataset = total_schema[datasetid] 503 | for dbid in dataset: 504 | dbid2dataset[dbid].append(datasetid) 505 | dbids.update(dataset) 506 | 507 | for dbid in dbid2dataset: 508 | frq = len(dbid2dataset[dbid]) 509 | datasetids = dbid2dataset[dbid] 510 | if frq > 1: 511 | tmpoverlapdbid[dbid] = "-".join(datasetids) 512 | else: 513 | nonoverlapdbid[datasetids[0]].append(dbid) 514 | 515 | overlapcount = defaultdict(int) 516 | for 
overlap in tmpoverlapdbid: 517 | overlapcount[tmpoverlapdbid[overlap]] += 1 518 | 519 | print("overlap database id:") 520 | for datasetid in overlapcount: 521 | print(datasetid + ": " + str(overlapcount[datasetid])) 522 | print("------------------") 523 | 524 | for dbid in tmpoverlapdbid: 525 | combineid = tmpoverlapdbid[dbid] 526 | overlapdbid[combineid].append(dbid) 527 | 528 | nonoverlapcolumns = [] 529 | tmpcolumns = defaultdict(set) 530 | for datasetid in nonoverlapdbid: 531 | dataset = nonoverlapdbid[datasetid] 532 | for dbid in dataset: 533 | tables = total_schema[datasetid][dbid][Table] 534 | tmp = "" 535 | for tableid in tables: 536 | tmp += '#'.join(tables[tableid]) + "#" 537 | tmpcolumns[tmp].add(datasetid + "#" + dbid) 538 | 539 | for tmp in tmpcolumns: 540 | nonoverlapcolumns.append(tmpcolumns[tmp]) 541 | 542 | overlapcount = defaultdict(int) 543 | 544 | for item in nonoverlapcolumns: 545 | if len(item) > 1: 546 | tmp = "" 547 | for term in item: 548 | tmpterm = term.split("#")[0] 549 | tmp += tmpterm + "-" 550 | tmp = tmp[:-1] 551 | overlapcount[tmp] += 1 552 | 553 | print("overlap database columns (in nonoverlap database id):") 554 | for datasetid in overlapcount: 555 | print(datasetid + ": " + str(overlapcount[datasetid])) 556 | print("------------------") 557 | 558 | for combineid in overlapdbid: 559 | overlapdbids = overlapdbid[combineid] 560 | random.shuffle(overlapdbids) 561 | overlaplen = len(overlapdbids) 562 | trainlen = int(Trainportion * overlaplen) 563 | devlen = int(Devportion * overlaplen) 564 | testlen = int(Testportion * overlaplen) 565 | datasetids = combineid.split("-") 566 | for datasetid in datasetids: 567 | for i, dbid in enumerate(overlapdbids): 568 | if i < trainlen: 569 | train_schema[datasetid][dbid] = dbids[dbid] 570 | elif i >= trainlen and i < devlen: 571 | dev_schema[datasetid][dbid] = dbids[dbid] 572 | elif i >= devlen and i < testlen: 573 | test_schema[datasetid][dbid] = dbids[dbid] 574 | 575 | 
random.shuffle(nonoverlapcolumns) 576 | for i, item in enumerate(nonoverlapcolumns): 577 | overlaplen = len(nonoverlapcolumns) 578 | trainlen = int(Trainportion * overlaplen) 579 | devlen = int(Devportion * overlaplen) 580 | testlen = int(Testportion * overlaplen) 581 | for term in item: 582 | datasetid = term.split("#")[0] 583 | dbid = term.split("#")[1] 584 | if i < trainlen: 585 | train_schema[datasetid][dbid] = dbids[dbid] 586 | elif i >= trainlen and i < devlen: 587 | dev_schema[datasetid][dbid] = dbids[dbid] 588 | elif i >= devlen and i < testlen: 589 | test_schema[datasetid][dbid] = dbids[dbid]""" 590 | 591 | """for datasetid in nonoverlapdbid: 592 | nonoverlapdbids = nonoverlapdbid[datasetid] 593 | random.shuffle(nonoverlapdbids) 594 | nonoverlaplen = len(nonoverlapdbids) 595 | trainlen = int(Trainportion * nonoverlaplen) 596 | devlen = int(Devportion * nonoverlaplen) 597 | testlen = int(Testportion * nonoverlaplen) 598 | for i, dbid in enumerate(nonoverlapdbids): 599 | if i < trainlen: 600 | train_schema[datasetid][dbid] = dbids[dbid] 601 | elif i >= trainlen and i < devlen: 602 | dev_schema[datasetid][dbid] = dbids[dbid] 603 | elif i >= devlen and i < testlen: 604 | test_schema[datasetid][dbid] = dbids[dbid]""" 605 | 606 | 607 | print("train_schema_num:") 608 | for schema in train_schema: 609 | print(schema + ": " + str(len(train_schema[schema]))) 610 | 611 | print("------------------") 612 | print("dev_schema_num:") 613 | for schema in dev_schema: 614 | print(schema + ": " + str(len(dev_schema[schema]))) 615 | 616 | print("------------------") 617 | print("test_schema_num:") 618 | for schema in test_schema: 619 | print(schema + ": " + str(len(test_schema[schema]))) 620 | 621 | print("------------------") 622 | 623 | 624 | return train_schema, dev_schema, test_schema 625 | 626 | def filterquestion(totalq): 627 | total_question = defaultdict(lambda: defaultdict(list)) 628 | questionsdict = {} 629 | 630 | for datasetid in totalq: 631 | dataset = 
totalq[datasetid] 632 | for dbid in dataset: 633 | database = dataset[dbid] 634 | for question, qdataset, query in database: 635 | lquestion = question.lower() 636 | if lquestion not in questionsdict: 637 | questionsdict[lquestion] = (dbid, qdataset, query) 638 | elif query: 639 | # on duplicates, prefer the entry that carries a SQL query 640 | questionsdict[lquestion] = (dbid, qdataset, query) 641 | 642 | print("question_num:") 643 | datasetquestion = defaultdict(int) 644 | for question in questionsdict: 645 | dbid = questionsdict[question][0] 646 | dataset = questionsdict[question][1] 647 | query = questionsdict[question][2] 648 | total_question[dataset][dbid].append((question, dataset, query)) 649 | datasetquestion[dataset] += 1 650 | 651 | for dataset in datasetquestion: 652 | print(dataset + ": " + str(datasetquestion[dataset])) 653 | print("------------------") 654 | 655 | return total_question 656 | 657 | def splittype1question(total_questions): 658 | tmpquestions = list() 659 | 660 | for datasetid in total_questions: 661 | if datasetid in type1qdataset: 662 | dataset = total_questions[datasetid] 663 | for dbid in dataset: 664 | questions = dataset[dbid] 665 | for question, qdataset, query in questions: 666 | tmpquestions.append((qdataset, question)) 667 | #questions = list(set(tmpquestions)) 668 | 669 | questions = [] 670 | questionset = set() 671 | for question in tmpquestions: 672 | if question not in questionset: 673 | questions.append(question) 674 | questionset.add(question) 675 | 676 | random.shuffle(questions) 677 | 678 | questionlen = len(questions) 679 | trainlen = int(Trainportion * questionlen) 680 | devlen = int(Devportion * questionlen) 681 | testlen = int(Testportion * questionlen) 682 | 683 | trainq = questions[: trainlen] 684 | devq = questions[trainlen: devlen] 685 | testq = questions[devlen: testlen] 686 | 687 | return trainq, devq, testq 688 | 689 | def gen_type1(total_questions, total_schema, total_dataset, total_q): 690 | type1sample = [] 691 | 692 | for i, (qdataset, question) in
enumerate(total_questions): 693 | #randomkey1 = random.sample(type1dataset, 1)[0] 694 | randomkey1 = type1dataset[i % len(type1dataset)] 695 | databasekey = list(total_schema[randomkey1].keys()) 696 | randomkey2 = random.sample(databasekey, 1)[0] 697 | schematable = total_schema[randomkey1][randomkey2][Table] 698 | if randomkey1 == Dbdomainnotambiguous or randomkey1 == Dbdomainambiguous: 699 | type1sample.append({Outtype: Type1, Outquestion: question, Outquestiondatasetid: qdataset, 700 | Outdatabaseiddatasetid: randomkey2, Outdatabaseid: randomkey2, Outtables: schematable}) 701 | #total_dataset[Type1][randomkey2] += 1 702 | #total_q[Type1][qdataset] += 1 703 | else: 704 | type1sample.append({Outtype: Type1, Outquestion: question, Outquestiondatasetid: qdataset, 705 | Outdatabaseiddatasetid: randomkey1, Outdatabaseid: randomkey2, Outtables: schematable}) 706 | #total_dataset[Type1][randomkey1] += 1 707 | #total_q[Type1][qdataset] += 1 708 | type1sample = random.sample(type1sample, int(len(type1sample) / 10)) 709 | for item in type1sample: 710 | datasetid = item[Outdatabaseiddatasetid] 711 | qdataset = item[Outquestiondatasetid] 712 | total_dataset[Type1][datasetid] += 1 713 | total_q[Type1][qdataset] += 1 714 | 715 | print("type1: " + str(len(type1sample))) 716 | return type1sample 717 | 718 | def gen_type2(total_questions, total_schema, total_dataset, total_q): 719 | type2sample = [] 720 | 721 | for datasetid in total_schema: 722 | if datasetid in type2dataset: 723 | dataset = total_schema[datasetid] 724 | for dbid in dataset: 725 | database = dataset[dbid] 726 | schematable = database[Table] 727 | if dbid in total_questions[datasetid]: 728 | for question, qdataset, query in total_questions[datasetid][dbid]: 729 | if datasetid == Cosql: 730 | type2sample.append({Outtype: Type2, Outquestion: question, Outquestiondatasetid: datasetid, 731 | Outdatabaseiddatasetid: datasetid, Outdatabaseid: dbid, Outtables: schematable}) 732 | total_dataset[Type2][datasetid] += 1 733 
| total_q[Type2][qdataset] += 1 734 | else: 735 | type2sample.append({Outtype: Type2, Outquestion: question, Outquestiondatasetid: dbid, 736 | Outdatabaseiddatasetid: dbid, Outdatabaseid: dbid, 737 | Outtables: schematable}) 738 | total_dataset[Type2][dbid] += 1 739 | total_q[Type2][dbid] += 1 740 | 741 | print("type2: " + str(len(type2sample))) 742 | return type2sample 743 | 744 | def gen_type3(total_questions, total_schema, total_dataset, total_q): 745 | type3sample = [] 746 | deletenums = [1, 1, 2, 3] 747 | 748 | for datasetid in total_schema: 749 | if datasetid in type3dataset: 750 | dataset = total_schema[datasetid] 751 | for dbid in dataset: 752 | if dbid in total_questions[datasetid]: 753 | database = dataset[dbid] 754 | schematable = database[Table] 755 | schemacolumn = database[Column] 756 | questions = total_questions[datasetid][dbid] 757 | for question, qdataset, query in questions: 758 | if query & schemacolumn: 759 | tmp_schematable = defaultdict(list) 760 | deletecolumn = list(query & schemacolumn) 761 | deletenum = random.sample(deletenums, 1)[0] 762 | if deletenum > len(deletecolumn): 763 | deletenum = len(deletecolumn) 764 | deletecolumn = random.sample(deletecolumn, deletenum) 765 | for key in schematable: 766 | for item in schematable[key]: 767 | if item not in deletecolumn: 768 | tmp_schematable[key].append(item) 769 | if datasetid == Dbdomainnotambiguous or datasetid == Spiderother: 770 | type3sample.append({Outtype: Type3, Outquestion: question, Outquestiondatasetid: dbid, 771 | Outdatabaseiddatasetid: dbid, Outdatabaseid: dbid, 772 | Outtables: tmp_schematable}) 773 | total_dataset[Type3][dbid] += 1 774 | total_q[Type3][dbid] += 1 775 | elif datasetid == Cosqlnotambiguous: 776 | type3sample.append( 777 | {Outtype: Type3, Outquestion: question, Outquestiondatasetid: Cosql, 778 | Outdatabaseiddatasetid: Cosql, Outdatabaseid: dbid, 779 | Outtables: tmp_schematable}) 780 | total_dataset[Type3][Cosql] += 1 781 | total_q[Type3][Cosql] += 1 782 | 
else: 783 | type3sample.append( 784 | {Outtype: Type3, Outquestion: question, Outquestiondatasetid: datasetid, 785 | Outdatabaseiddatasetid: datasetid, Outdatabaseid: dbid, 786 | Outtables: tmp_schematable}) 787 | total_dataset[Type3][datasetid] += 1 788 | total_q[Type3][datasetid] += 1 789 | print("type3: " + str(len(type3sample))) 790 | return type3sample 791 | 792 | def gen_type4(total_questions, total_schema, total_dataset, total_q): 793 | type4sample = [] 794 | 795 | for datasetid in total_schema: 796 | if datasetid in type4dataset: 797 | dataset = total_schema[datasetid] 798 | for dbid in dataset: 799 | database = dataset[dbid] 800 | schematable = database[Table] 801 | if dbid in total_questions[datasetid]: 802 | for question, qdataset, query in total_questions[datasetid][dbid]: 803 | type4sample.append({Outtype: Type4, Outquestion: question, Outquestiondatasetid: datasetid, 804 | Outdatabaseiddatasetid: datasetid, Outdatabaseid: dbid, 805 | Outtables: schematable}) 806 | total_dataset[Type4][datasetid] += 1 807 | total_q[Type4][datasetid] += 1 808 | 809 | print("type4: " + str(len(type4sample))) 810 | return type4sample 811 | 812 | def gen_type5(total_questions, total_schema, total_dataset, total_q): 813 | type5sample = [] 814 | deletenums = [1, 1, 2, 3] 815 | 816 | for datasetid in total_schema: 817 | if datasetid in type5dataset: 818 | dataset = total_schema[datasetid] 819 | for dbid in dataset: 820 | database = dataset[dbid] 821 | schematable = database[Table] 822 | schemacolumn = database[Column] 823 | if dbid in total_questions[datasetid]: 824 | for question, qdataset, query in total_questions[datasetid][dbid]: 825 | if datasetid == Dbdomainnotambiguous or datasetid == Spiderother: 826 | type5sample.append({Outtype: Type5, Outquestion: question, Outquestiondatasetid: dbid, 827 | Outdatabaseiddatasetid: dbid, Outdatabaseid: dbid, 828 | Outtables: schematable}) 829 | total_dataset[Type5][dbid] += 1 830 | total_q[Type5][dbid] += 1 831 | elif datasetid == 
Cosqlnotambiguous: 832 | type5sample.append({Outtype: Type5, Outquestion: question, Outquestiondatasetid: Cosql, 833 | Outdatabaseiddatasetid: Cosql, Outdatabaseid: dbid, 834 | Outtables: schematable}) 835 | total_dataset[Type5][Cosql] += 1 836 | total_q[Type5][Cosql] += 1 837 | else: 838 | type5sample.append({Outtype: Type5, Outquestion: question, Outquestiondatasetid: datasetid, 839 | Outdatabaseiddatasetid: datasetid, Outdatabaseid: dbid, 840 | Outtables: schematable}) 841 | total_dataset[Type5][datasetid] += 1 842 | total_q[Type5][datasetid] += 1 843 | tmp_schematable = defaultdict(list) 844 | deletecolumn = list(schemacolumn - query) 845 | deletenum = random.sample(deletenums, 1)[0] 846 | if deletenum > len(deletecolumn): 847 | deletenum = len(deletecolumn) 848 | deletecolumn = random.sample(deletecolumn, deletenum) 849 | for key in schematable: 850 | for item in schematable[key]: 851 | if item not in deletecolumn: 852 | tmp_schematable[key].append(item) 853 | if datasetid == Dbdomainnotambiguous or datasetid == Spiderother: 854 | type5sample.append({Outtype: Type5, Outquestion: question, Outquestiondatasetid: dbid, 855 | Outdatabaseiddatasetid: dbid, Outdatabaseid: dbid, 856 | Outtables: tmp_schematable}) 857 | total_dataset[Type5][dbid] += 1 858 | total_q[Type5][dbid] += 1 859 | elif datasetid == Cosqlnotambiguous: 860 | type5sample.append({Outtype: Type5, Outquestion: question, Outquestiondatasetid: Cosql, 861 | Outdatabaseiddatasetid: Cosql, Outdatabaseid: dbid, 862 | Outtables: tmp_schematable}) 863 | total_dataset[Type5][Cosql] += 1 864 | total_q[Type5][Cosql] += 1 865 | else: 866 | type5sample.append({Outtype: Type5, Outquestion: question, Outquestiondatasetid: datasetid, 867 | Outdatabaseiddatasetid: datasetid, Outdatabaseid: dbid, 868 | Outtables: tmp_schematable}) 869 | total_dataset[Type5][datasetid] += 1 870 | total_q[Type5][datasetid] += 1 871 | print("type5: " + str(len(type5sample))) 872 | return type5sample 873 | 874 | def write_json(samples, 
path): 875 | f = open(path, "w", encoding="utf-8") 876 | json.dump(samples, f, sort_keys=True, indent=4, separators=(',', ': ')) 877 | f.close() 878 | 879 | def draw_stat(dataset, name): 880 | types = [Type1, Type2, Type3, Type4, Type5] 881 | datasetname = set() 882 | for t in dataset: 883 | datasetname |= set(dataset[t].keys()) 884 | datasetname = list(datasetname) 885 | datasetstats = [] 886 | for n in datasetname: 887 | datasetstat = [] 888 | for t in types: 889 | datasetstat.append(dataset[t][n]) 890 | datasetstats.append(datasetstat) 891 | 892 | bar = Bar("") 893 | for index, n in enumerate(datasetname): 894 | bar.add(n, types, datasetstats[index], is_stack=True, is_more_utils=True) 895 | bar.render(name + '_stat_graph.html') 896 | 897 | def get_stat(train_questions, test_questions, dev_questions): 898 | group1 = [Spider, Sparc, Cosql] 899 | group2 = [WikiSQL] 900 | group3 = ['restaurants', 'scholar', 'yelp', 'imdb', 'geo', 'academic'] 901 | group4 = [Tablefact, Totto, Logicnlg] 902 | group5 = [HybridQA] 903 | group6 = [Wikitable] 904 | group7 = [Alexa, Googlenq, Msmarco, WikiQA, Coqa, Quac] 905 | group1dict = defaultdict(int) 906 | group2dict = defaultdict(int) 907 | group3dict = defaultdict(int) 908 | group4dict = defaultdict(int) 909 | group5dict = defaultdict(int) 910 | group6dict = defaultdict(int) 911 | group7dict = defaultdict(int) 912 | 913 | for t in train_questions: 914 | datasets = train_questions[t] 915 | for dataset in datasets: 916 | if dataset in group1: 917 | group1dict[t] += datasets[dataset] 918 | elif dataset in group2: 919 | group2dict[t] += datasets[dataset] 920 | elif dataset in group3: 921 | group3dict[t] += datasets[dataset] 922 | elif dataset in group4: 923 | group4dict[t] += datasets[dataset] 924 | elif dataset in group5: 925 | group5dict[t] += datasets[dataset] 926 | elif dataset in group6: 927 | group6dict[t] += datasets[dataset] 928 | elif dataset in group7: 929 | group7dict[t] += datasets[dataset] 930 | else: 931 | print(dataset) 
932 | print("train wrong!") 933 | 934 | for t in test_questions: 935 | datasets = test_questions[t] 936 | for dataset in datasets: 937 | if dataset in group1: 938 | group1dict[t] += datasets[dataset] 939 | elif dataset in group2: 940 | group2dict[t] += datasets[dataset] 941 | elif dataset in group3: 942 | group3dict[t] += datasets[dataset] 943 | elif dataset in group4: 944 | group4dict[t] += datasets[dataset] 945 | elif dataset in group5: 946 | group5dict[t] += datasets[dataset] 947 | elif dataset in group6: 948 | group6dict[t] += datasets[dataset] 949 | elif dataset in group7: 950 | group7dict[t] += datasets[dataset] 951 | else: 952 | print(dataset) 953 | print("test wrong!") 954 | 955 | for t in dev_questions: 956 | datasets = dev_questions[t] 957 | for dataset in datasets: 958 | if dataset in group1: 959 | group1dict[t] += datasets[dataset] 960 | elif dataset in group2: 961 | group2dict[t] += datasets[dataset] 962 | elif dataset in group3: 963 | group3dict[t] += datasets[dataset] 964 | elif dataset in group4: 965 | group4dict[t] += datasets[dataset] 966 | elif dataset in group5: 967 | group5dict[t] += datasets[dataset] 968 | elif dataset in group6: 969 | group6dict[t] += datasets[dataset] 970 | elif dataset in group7: 971 | group7dict[t] += datasets[dataset] 972 | else: 973 | print(dataset) 974 | print("dev wrong!") 975 | 976 | print("group1: " + str(dict(group1dict))) 977 | print("group2: " + str(dict(group2dict))) 978 | print("group3: " + str(dict(group3dict))) 979 | print("group4: " + str(dict(group4dict))) 980 | print("group5: " + str(dict(group5dict))) 981 | print("group6: " + str(dict(group6dict))) 982 | print("group7: " + str(dict(group7dict))) 983 | 984 | if __name__ == "__main__": 985 | from openpyxl import load_workbook 986 | wb = load_workbook('testset.xlsx') 987 | sheets = wb.get_sheet_names() 988 | datasets = defaultdict(lambda :defaultdict(int)) 989 | for sheet in sheets: 990 | table = wb.get_sheet_by_name(sheet) 991 | rows = table.max_row 992 | 
for row in range(rows): 993 | dataset = table.cell(row=row + 1, column=1).value 994 | num = table.cell(row=row + 1, column=2).value 995 | datasets[sheet][dataset] = num 996 | draw_stat(datasets, "annotate") 997 | 998 | 999 | -------------------------------------------------------------------------------- /model/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | class CONFIG(object): 5 | def __init__(self): 6 | # params for train.py (or shared) 7 | self.learning_rate = 2e-5 8 | self.max_input_len = 256 9 | self.batch_size = 32 10 | self.epoch = 100 11 | self.data_max_size = 1e9 # use 100 if debugging 12 | 13 | self.pretrain_model_name = "roberta-base" 14 | # self.data_path = "/home/yusenzhang/input_classify_new/dataset/new_averaged_small_trainset.json" 15 | # self.model_path = "./checkpoints/07270039_params13.pkl" 16 | # self.model_path = "../input_classification/params2.pkl" 17 | # self.model_path = "/home/yusenzhang/input_classify/models/input_multi_classification/checkpoints/07301738_params10.pkl" 18 | self.model_path = "/home/yusenzhang/input_classify_new/models/input_multi_classification/checkpoints/08202014_params24.pkl" 19 | # self.model_path = "./checkpoints/08130317_params7.pkl" 20 | self.use_gpu = True 21 | self.device = 0 22 | self.save = True 23 | self.load_model = True # remember to set this value when evaluating 24 | self.multi_GPU = True 25 | 26 | # params only for eval.py 27 | self.test_path = "/home/yusenzhang/input_classify_new/dataset/final_turncated_testset.json" 28 | self.train_path = "/home/yusenzhang/input_classify_new/dataset/final_small_trainset.json" 29 | self.dev_path = "/home/yusenzhang/input_classify_new/dataset/final_devset.json" 30 | # self.test_path = "/home/yusenzhang/input_classify/models/input_multi_classification/dataset/type12_testset.json" 31 | # self.test_path =
"/home/yusenzhang/input_classify/models/input_multi_classification/dataset/written_100_testset.json" 32 | self.label_dict = { 33 | 'answerable': 0, 34 | 'small talk': 1, 35 | 'ambiguous': 2, 36 | 'lack data': 3, 37 | 'unanswerable by sql': 4 38 | } 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /model/dataset/dataset_preprocess.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | from models.input_multi_classification.utils import * 4 | from models.input_multi_classification.config import * 5 | 6 | ablation = False 7 | 8 | def analyse(data1, data2): 9 | for i in range(3): 10 | overlap = set(data1[i]).intersection(set(data2[i])) 11 | print(len(overlap)) 12 | 13 | def random_select(data, max_size=1e4): 14 | np.random.seed(42) 15 | dataset_size = int(min(max_size, len(data))) 16 | index = np.arange(0, len(data)) 17 | np.random.shuffle(index) 18 | index = index[: dataset_size] 19 | data = [data[i] for i in index] 20 | return data 21 | 22 | 23 | if __name__ == '__main__': 24 | cfg = CONFIG() 25 | # [type,query,db,schema_name,schema] 26 | data_type = 'testset' 27 | data = [] 28 | 29 | # for i in range(1, 6): 30 | # type_data = load_data(os.path.join(data_type, "type{}.json".format(i))) 31 | # if i == 5 and ablation: type_data = [x for i, x in enumerate(type_data) if i % 2 == 0] 32 | # type_data = random_select(type_data, max_size=1e9) 33 | # data += type_data 34 | # print("type{}".format(i), len(type_data)) 35 | 36 | data = load_data('testset/human_annotated/test.json') 37 | 38 | # hacky dataset construction 39 | query = [] 40 | dataset = [] 41 | label = [] 42 | cnt = 0 43 | delete_table = re.compile("[0-9]+-[0-9]+-[0-9]+") 44 | 45 | for sample in data: 46 | if len(sample['question']) == 0: continue 47 | schema_list = [] 48 | for table_name, columns in sample['tables'].items(): 49 | schema_list += [column_name if delete_table.match(table_name)
50 | else table_name + '.' + column_name for column_name in columns] 51 | if not len(schema_list): continue 52 | # no numbers in schema list 53 | # all in lower case 54 | query.append(sample['question'].lower()) 55 | dataset.append(' | '.join(schema_list).lower()) 56 | label.append(cfg.label_dict[sample['type']]) 57 | 58 | json.dump([query, dataset, label], open('final_turncated_{}.json'.format(data_type), 'w')) 59 | 60 | 61 | 62 | 63 | -------------------------------------------------------------------------------- /model/eval.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.utils.data import Dataset 3 | from torch.utils.data import DataLoader 4 | from sklearn import metrics 5 | from sklearn.metrics import classification_report, confusion_matrix 6 | from config import * 7 | from utils import * 8 | from model import TransformerMultiClassifier 9 | import time 10 | 11 | os.environ['CUDA_VISIBLE_DEVICES'] = '0,1' 12 | time_stamp = time.strftime("%m%d%H%M", time.localtime()) 13 | DEBUG = False 14 | 15 | def eval(model, test_loader, cfg): 16 | # evaluation 17 | model.eval() 18 | eval_labels = [] 19 | eval_results = [] 20 | for batch in test_loader: 21 | # move batch to GPU 22 | if torch.cuda.is_available() and cfg.use_gpu is True: 23 | batch = [x.cuda(cfg.device) for x in batch] 24 | 25 | # forward 26 | classification_logits = model(batch[0], batch[1])[0] 27 | batch_results = torch.argmax(classification_logits, dim=1).cpu().numpy() 28 | 29 | # add to list 30 | eval_labels += batch[2].cpu().tolist() 31 | eval_results += batch_results.tolist() 32 | 33 | print(classification_report(eval_labels, eval_results, 34 | target_names=[{y: x for x, y in cfg.label_dict.items()}[i] for i in range(len(cfg.label_dict))])) 35 | print(cfg.label_dict) 36 | print(confusion_matrix(eval_labels, eval_results)) 37 | model.train() 38 | 39 | 40 | 41 | if __name__ == '__main__': 42 | # load config 43 | cfg = CONFIG() 44 
| 45 | # load dataset 46 | print("Loading dataset...") 47 | # test_data = load_data(cfg.test_path) 48 | # queries, databases, labels = turn_5type_to_multiclass(get_5type_sequences(test_data), cfg.label_dict) 49 | queries, databases, labels = load_data(cfg.test_path) 50 | if DEBUG: 51 | import random 52 | index = [i for i in range(len(queries))] 53 | random.shuffle(index) 54 | index = index[:500] 55 | queries = [queries[i] for i in index] 56 | databases = [databases[i] for i in index] 57 | labels = [labels[i] for i in index] 58 | # tokenize 59 | print("Tokenizing...") 60 | test_tokens, test_labels = tokenize_sequences(queries, databases, labels, cfg.pretrain_model_name, cfg.max_input_len) 61 | 62 | # load model and the others 63 | model = TransformerMultiClassifier(cfg.pretrain_model_name, num_labels=len(cfg.label_dict)) 64 | criterion = nn.CrossEntropyLoss(reduction='mean') 65 | optimizer = torch.optim.Adam(model.parameters(), lr=cfg.learning_rate) 66 | 67 | # multi GPU 68 | if cfg.multi_GPU and torch.cuda.device_count() > 1: 69 | print(f"Using {torch.cuda.device_count()} GPUs") 70 | model = nn.DataParallel(model) 71 | 72 | # load model parameters 73 | if cfg.load_model: 74 | print("Loading model...") 75 | params = torch.load(cfg.model_path) 76 | model.load_state_dict(params) 77 | 78 | # move to gpu 79 | if torch.cuda.is_available() and cfg.use_gpu is True: 80 | model.cuda(cfg.device) 81 | criterion.cuda(cfg.device) 82 | 83 | # use data loader (not in GPU) 84 | test_loader = DataLoader(TensorDataset(*test_tokens.values(), test_labels), 85 | batch_size=cfg.batch_size, shuffle=False, num_workers=4) 86 | # start evaluation 87 | print("starting evaluation") 88 | eval(model, test_loader, cfg) 89 | 90 | 91 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from transformers import
AutoModelForSequenceClassification 3 | from transformers import AutoConfig 4 | 5 | class TransformerMultiClassifier(nn.Module): 6 | def __init__(self, pretrain_model_name='roberta-large', num_labels=5): 7 | super(TransformerMultiClassifier, self).__init__() 8 | model_config = AutoConfig.from_pretrained(pretrain_model_name, num_labels=num_labels) 9 | # from_config builds the architecture with randomly initialized weights; 10 | # trained parameters are loaded afterwards via load_state_dict in train.py / eval.py 11 | self.model = AutoModelForSequenceClassification.from_config(model_config) 12 | 13 | def forward(self, input_ids, attention_mask): 14 | return self.model(input_ids, attention_mask=attention_mask) -------------------------------------------------------------------------------- /model/readme.md: -------------------------------------------------------------------------------- 1 | ## A Multi-class query classification model 2 | 3 | This model classifies user queries as described in task 1 of the [document](https://docs.google.com/document/d/10szA5EJz7tYpyUjA3aXaXEFSjOJZ37Ni0ynAS270ksw/edit#). Label '0' represents answerable questions; labels '1'-'4' represent distinct types of unanswerable questions.
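As a rough sketch of how one classifier input is built (mirroring the serialization in `dataset/dataset_preprocess.py`, minus its regex that strips WikiSQL-style numeric table names; the question, table, and `serialize_schema` helper below are invented for illustration):

```python
# Label ids as defined in config.py.
LABEL_DICT = {
    'answerable': 0,
    'small talk': 1,
    'ambiguous': 2,
    'lack data': 3,
    'unanswerable by sql': 4,
}

def serialize_schema(tables):
    """Flatten {table_name: [column_names]} into the lower-cased
    'table.column | table.column | ...' string that is paired with
    the question as the model's second input sequence."""
    schema_list = []
    for table_name, columns in tables.items():
        schema_list += [table_name + '.' + c for c in columns]
    return ' | '.join(schema_list).lower()

question = "How many singers do we have?"
tables = {"singer": ["Singer_ID", "Name", "Country"]}

pair = (question.lower(), serialize_schema(tables))
print(pair[1])  # singer.singer_id | singer.name | singer.country
```

The RoBERTa classifier then receives the lower-cased question and this schema string as a sequence pair and predicts one of the five labels above.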
4 | 5 | ### Usage and explanation 6 | - usage: ```python train.py``` 7 | - config.py hyperparameters (doesn't support argparse yet; supports GPU) 8 | - train.py training and evaluation of the model 9 | - utils.py loading the dataset and tokenization 10 | - model.py the RoBERTa classification model we used 11 | 12 | ### TODOs 13 | - Change to a complete dataset 14 | - Modify the model 15 | -------------------------------------------------------------------------------- /model/test.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import numpy as np 4 | from torch.utils.data import Dataset, TensorDataset 5 | from torch.utils.data import DataLoader 6 | from transformers import AutoTokenizer 7 | 8 | ## Workplace for testing grammar issues -------------------------------------------------------------------------------- /model/train.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.utils.data import Dataset 3 | from torch.utils.data import DataLoader 4 | from sklearn import metrics 5 | from config import * 6 | from utils import * 7 | from model import TransformerMultiClassifier 8 | import time 9 | 10 | os.environ['CUDA_VISIBLE_DEVICES'] = '0,1' 11 | time_stamp = time.strftime("%m%d%H%M", time.localtime()) 12 | 13 | 14 | def eval(model, test_loader, cfg): 15 | # evaluation 16 | model.eval() 17 | eval_labels = [] 18 | eval_results = [] 19 | for batch in test_loader: 20 | # move batch to GPU 21 | if torch.cuda.is_available() and cfg.use_gpu is True: 22 | batch = [x.cuda(cfg.device) for x in batch] 23 | 24 | # forward 25 | classification_logits = model(batch[0], batch[1])[0] 26 | batch_results = torch.argmax(classification_logits, dim=1).cpu().numpy() 27 | 28 | # add to list 29 | eval_labels += batch[2].cpu().tolist() 30 | eval_results += batch_results.tolist() 31 | 32 | f1_micro = metrics.f1_score(eval_labels, eval_results,
average='micro') 33 | f1_macro = metrics.f1_score(eval_labels, eval_results, average='macro') 34 | f1_weighted = metrics.f1_score(eval_labels, eval_results, average='weighted') 35 | model.train() 36 | 37 | print(f"\n[test] f1: micro:{f1_micro}, macro:{f1_macro}, weighted:{f1_weighted}") 38 | 39 | 40 | if __name__ == '__main__': 41 | # load config 42 | cfg = CONFIG() 43 | # load dataset 44 | print("Loading dataset...") 45 | # dataset = load_data(cfg.data_path) 46 | # queries, databases, labels = turn_5type_to_multiclass(get_5type_sequences(dataset), cfg.label_dict) 47 | # trainset, testset = split_train_test(queries, databases, labels, max_size=cfg.data_max_size) 48 | trainset = load_data(cfg.train_path) 49 | testset = load_data(cfg.dev_path) 50 | # # truncate 51 | # trainset, _ = split_train_test(*trainset, max_size=cfg.data_max_size) 52 | # testset, _ = split_train_test(*testset, max_size=cfg.data_max_size//2) 53 | 54 | print(f"train size:{len(trainset[0])}, test size: {len(testset[0])}.") 55 | 56 | # tokenize 57 | print("Tokenizing...") 58 | train_tokens, train_labels = tokenize_sequences(*trainset, cfg.pretrain_model_name, cfg.max_input_len) 59 | test_tokens, test_labels = tokenize_sequences(*testset, cfg.pretrain_model_name, cfg.max_input_len) 60 | 61 | # load model, loss, and optimizer 62 | model = TransformerMultiClassifier(cfg.pretrain_model_name, num_labels=len(cfg.label_dict)) 63 | criterion = nn.CrossEntropyLoss(reduction='mean') 64 | optimizer = torch.optim.Adam(model.parameters(), lr=cfg.learning_rate) 65 | 66 | # load model parameters 67 | if cfg.load_model: 68 | print("Loading model...") 69 | model.load_state_dict(torch.load(cfg.model_path)) 70 | 71 | # move to gpu 72 | if torch.cuda.is_available() and cfg.use_gpu is True: 73 | model.cuda(cfg.device) 74 | criterion.cuda(cfg.device) 75 | if cfg.multi_GPU and torch.cuda.device_count() > 1: 76 | print(f"Using {torch.cuda.device_count()} GPUs") 77 | model = nn.DataParallel(model) 78 | 79 | # use
data loader (not in GPU) 80 | data_loader = DataLoader(TensorDataset(*train_tokens.values(), train_labels), 81 | batch_size=cfg.batch_size, shuffle=True, num_workers=4) 82 | test_loader = DataLoader(TensorDataset(*test_tokens.values(), test_labels), 83 | batch_size=cfg.batch_size, shuffle=False, num_workers=4) 84 | 85 | # start training 86 | print("Start training...") 87 | for i in range(cfg.epoch): 88 | for j, batch in enumerate(data_loader): 89 | # move batch to GPU 90 | if torch.cuda.is_available() and cfg.use_gpu is True: 91 | batch = [x.cuda(cfg.device) for x in batch] 92 | 93 | batch_labels = batch[2] 94 | classification_logits = model(batch[0], batch[1])[0] 95 | # results = torch.softmax(classification_logits, dim=1).cpu().tolist()  # unused; only added per-batch overhead 96 | loss = criterion(classification_logits, batch_labels) 97 | 98 | optimizer.zero_grad() 99 | loss.backward() 100 | optimizer.step() 101 | 102 | print(f'\r[epoch {i}] [batch {j}/{len(train_labels)//cfg.batch_size}] : loss: {loss}', end=' ') 103 | 104 | if cfg.save is True: 105 | torch.save(model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict(), f'./checkpoints/{time_stamp}_params{i}.pkl')  # unwrap DataParallel so checkpoints load without the 'module.' prefix 106 | 107 | eval(model, test_loader, cfg) 108 | 109 | 110 | 111 | -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import numpy as np 4 | from torch.utils.data import Dataset, TensorDataset 5 | from torch.utils.data import DataLoader 6 | from transformers import AutoTokenizer 7 | 8 | 9 | def load_data(path): 10 | with open(path, 'r') as file: 11 | json_str = file.read() 12 | dataset = json.loads(json_str) 13 | return dataset 14 | 15 | 16 | def get_sequences_old(dataset): 17 | # dataset: {query:{'1':schemas,'0':schemas}}, schemas:{schema_name:{table_name:[column_names]}} 18 | databases = [] 19 | queries = [] 20 | labels = [] 21 | for query, samples in dataset.items(): 22 | for label, schemas in samples.items(): 23 |
schema_list = [] 24 | for schema_name, tables in schemas.items(): 25 | for table_name, column_names in tables.items(): 26 | schema_list += [table_name+'.'+column_name for column_name in column_names] 27 | 28 | schema_seq = ' | '.join(schema_list) 29 | databases.append(schema_seq) 30 | queries.append(query) 31 | labels.append(label) 32 | # return : [query(str)], [schema_seq(str)], [label(str)] 33 | return queries, databases, labels 34 | 35 | 36 | def get_sequences(dataset): 37 | # dataset: {'1':[[question,schema]], '0':[{overlap(str):[question,schema]}]}, 38 | # schema:{table_name:[column_names]} 39 | databases = [] 40 | queries = [] 41 | labels = [] 42 | for label, samples in dataset.items(): 43 | for sample in samples: 44 | # negative ('0') samples are wrapped in an {overlap: [question, schema]} dict 45 | if label == '0': 46 | sample = list(sample.items())[-1][1] 47 | queries.append(sample[0]) 48 | schema_list = [] 49 | for table_name, columns in sample[1].items(): 50 | schema_list += [table_name+'.'+column_name for column_name in columns] 51 | databases.append(' | '.join(schema_list)) 52 | labels.append(label) 53 | 54 | # return : [query(str)], [schema_seq(str)], [label(str)] 55 | return queries, databases, labels 56 | 57 | 58 | def get_5type_sequences(dataset, max_len=1e9): 59 | # dataset: {'1/0':{type:[[answerable(str), question(str), schema]]}}, 60 | # schema:{table_name:[column_names]} 61 | # type: small talk, ambiguous, answerable, unanswerable by sql, lack data 62 | data_dict = {} 63 | for label, samples in dataset.items(): 64 | for type, questions in samples.items(): 65 | databases = [] 66 | queries = [] 67 | labels = [] 68 | for sample in questions: 69 | if len(sample[1]) == 0: continue 70 | schema_list = [] 71 | for table_name, columns in sample[2].items(): 72 | schema_list += [table_name + '.'
+ column_name for column_name in columns] 73 | if not len(schema_list): continue 74 | queries.append(sample[1]) 75 | databases.append(' | '.join(schema_list)) 76 | labels.append(label) 77 | # [query(str)], [schema_seq(str)], [label(str)] 78 | used_len = int(min(max_len, len(queries))) 79 | data_dict[type] = (queries[:used_len], databases[:used_len], labels[:used_len]) 80 | 81 | # return : {type: dataset} 82 | return data_dict 83 | 84 | 85 | def turn_5type_to_multiclass(data_dict, label_dict): 86 | databases = [] 87 | queries = [] 88 | labels = [] 89 | for type, data in data_dict.items(): 90 | queries += data[0] 91 | databases += data[1] 92 | labels += [label_dict[type]]*len(data[2]) 93 | return queries, databases, labels 94 | 95 | 96 | def tokenize_sequences(queries, databases, labels, model_name="roberta-large", input_max_len=256, label_overwrite=None): 97 | tokenizer = AutoTokenizer.from_pretrained(model_name) 98 | # input: [query(str)], [schema_seq(str)], [label(str)]; returns: tokenized query/schema pairs (dict of torch.Tensor), labels (torch.LongTensor) 99 | return tokenizer.batch_encode_plus(list(zip(queries, databases)), return_tensors="pt", 100 | padding='max_length', max_length=input_max_len, truncation=True), \ 101 | torch.LongTensor([int(label) if label_overwrite is None else label_overwrite for label in labels]) 102 | 103 | 104 | def split_train_test(queries, databases, labels, split_ratio=0.8, max_size=1e9, seed=42): 105 | np.random.seed(seed) 106 | dataset_size = int(min(max_size, len(queries))) 107 | split_point = int(split_ratio * dataset_size) 108 | index = np.arange(0, len(queries)) 109 | np.random.shuffle(index) 110 | index = index[: dataset_size] 111 | trainset = ([queries[i] for i in index[:split_point]], 112 | [databases[i] for i in index[:split_point]], 113 | [labels[i] for i in index[:split_point]],) 114 | testset = ([queries[i] for i in index[split_point:]], 115 | [databases[i] for i in index[split_point:]], 116 | [labels[i] for i in index[split_point:]],) 117 | return trainset,
testset 118 | 119 | 120 | if __name__ == '__main__': 121 | # dataset = load_data("./dataset/all_datasets.json") 122 | # queries, databases, labels = get_sequences(dataset) 123 | # trainset, testset = split_train_test(queries, databases, labels) 124 | # train_tokens, train_labels = tokenize_sequences(*trainset, "roberta-large") 125 | # test_tokens, test_labels = tokenize_sequences(*testset, "roberta-large") 126 | # 127 | # data_loader = DataLoader(TensorDataset(*train_tokens.values(), train_labels), batch_size=16, shuffle=True, num_workers=4) 128 | # 129 | # for batch in data_loader: 130 | # tokens = batch[0] 131 | # labels = batch[1] 132 | # 133 | # ... 134 | from config import * 135 | cfg = CONFIG() 136 | test_data = load_data(cfg.data_path) 137 | 138 | testset = get_5type_sequences(test_data) 139 | dataset = turn_5type_to_multiclass(testset, cfg.label_dict) 140 | ... 141 | 142 | 143 | -------------------------------------------------------------------------------- /model/visualization.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import seaborn as sn 3 | import numpy as np 4 | def plot_confusion_matrix(cm, 5 | target_names, 6 | title='', 7 | cmap=None, 8 | normalize=True): 9 | """ 10 | given a sklearn confusion matrix (cm), make a nice plot 11 | 12 | Arguments 13 | --------- 14 | cm: confusion matrix from sklearn.metrics.confusion_matrix 15 | 16 | target_names: the class names corresponding to the classification 17 | classes, e.g. ['high', 'medium', 'low'] for classes [0, 1, 2] 18 | 19 | title: the text to display at the top of the matrix 20 | 21 | cmap: the gradient of the values displayed from matplotlib.pyplot.cm 22 | see http://matplotlib.org/examples/color/colormaps_reference.html 23 | plt.get_cmap('jet') or plt.cm.Blues 24 | 25 | normalize: If False, plot the raw numbers 26 | If True, plot the proportions 27 | 28 | Usage 29 | ----- 30 | plot_confusion_matrix(cm = cm, # confusion matrix
created by 31 | # sklearn.metrics.confusion_matrix 32 | normalize = True, # show proportions 33 | target_names = y_labels_vals, # list of names of the classes 34 | title = best_estimator_name) # title of graph 35 | 36 | Citation 37 | --------- 38 | http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html 39 | 40 | """ 41 | import itertools  # plt and np are already imported at module level 42 | 43 | 44 | 45 | accuracy = np.trace(cm) / float(np.sum(cm)) 46 | misclass = 1 - accuracy 47 | 48 | if cmap is None: 49 | cmap = plt.get_cmap('Blues') 50 | 51 | plt.figure(figsize=(8, 6)) 52 | plt.imshow(cm, interpolation='nearest', cmap=cmap) 53 | plt.title(title) 54 | plt.colorbar() 55 | 56 | if target_names is not None: 57 | tick_marks = np.arange(len(target_names)) 58 | plt.xticks(tick_marks, target_names, rotation=45) 59 | plt.yticks(tick_marks, target_names) 60 | 61 | if normalize: 62 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] 63 | 64 | 65 | thresh = cm.max() / 1.5 if normalize else cm.max() / 2 66 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): 67 | if normalize: 68 | plt.text(j, i, "{:0.4f}".format(cm[i, j]), 69 | horizontalalignment="center", 70 | color="white" if cm[i, j] > thresh else "black") 71 | else: 72 | plt.text(j, i, "{:,}".format(cm[i, j]), 73 | horizontalalignment="center", 74 | color="white" if cm[i, j] > thresh else "black") 75 | 76 | 77 | plt.tight_layout() 78 | plt.ylabel('True label') 79 | # plt.xlabel('Predicted label\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass)) 80 | plt.xlabel('Predicted label') 81 | plt.show() 82 | 83 | 84 | # con = [[411, 12, 10, 66, 1], 85 | # [ 32, 436, 3, 8, 21], 86 | # [113, 60, 45, 10, 0], 87 | # [270, 2 , 1, 227, 0], 88 | # [ 1, 35, 2, 1, 461]] 89 | # # ax= plt.subplot() 90 | # # sn.heatmap(con, annot=False, cmap='Purples', ax=ax) 91 | # label = ['Answerable', 'Improper', 'Ambiguous', 'ExtKnow', 92 | # 'Non-SQL'] 93 | 94 | con = [ 95
| [ 436, 8, 3, 21, 32,], 96 | [ 2, 227, 1, 0, 270,], 97 | [ 60, 10, 45, 0, 113,], 98 | [ 35, 1, 2, 461, 1,], 99 | [ 12, 66, 10, 1, 411,] 100 | ] 101 | # ax= plt.subplot() 102 | # sn.heatmap(con, annot=False, cmap='Purples', ax=ax) 103 | label = [ 'Improper','ExtKnow', 'Ambiguous', 'Non-SQL', 104 | 'Answerable'] 105 | 106 | # ax.xaxis.set_ticklabels(label) 107 | # ax.yaxis.set_ticklabels(label) 108 | # plt.show() 109 | plot_confusion_matrix(np.array(con), label, normalize=False) --------------------------------------------------------------------------------
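The accuracy figure that `plot_confusion_matrix` derives internally is just `trace(cm) / sum(cm)`; for the matrix above it can be sanity-checked without any plotting:

```python
# Confusion matrix from visualization.py (rows: true labels, columns: predictions).
con = [
    [436,   8,  3,  21,  32],
    [  2, 227,  1,   0, 270],
    [ 60,  10, 45,   0, 113],
    [ 35,   1,  2, 461,   1],
    [ 12,  66, 10,   1, 411],
]

correct = sum(con[i][i] for i in range(len(con)))  # diagonal, i.e. np.trace(cm)
total = sum(sum(row) for row in con)               # all entries, i.e. np.sum(cm)
accuracy = correct / total
print(f"accuracy={accuracy:.4f}, misclass={1 - accuracy:.4f}")
# prints accuracy=0.7092, misclass=0.2908
```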