├── .gitignore ├── EditSQL ├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── agent.py ├── data_util │ ├── __init__.py │ ├── anonymization.py │ ├── atis_batch.py │ ├── atis_data.py │ ├── atis_vocab.py │ ├── dataset_split.py │ ├── entities.py │ ├── interaction.py │ ├── snippets.py │ ├── sql_util.py │ ├── tokenizers.py │ ├── util.py │ ├── utterance.py │ └── vocabulary.py ├── environment.py ├── error_detector.py ├── eval_scripts │ ├── __init__.py │ ├── evaluation.py │ ├── evaluation_sqa.py │ ├── metric_averages.py │ └── process_sql.py ├── logger.py ├── logs_clean │ └── logs_spider_editsql │ │ └── records_train_nointeract_gold.json ├── model │ ├── __init__.py │ ├── attention.py │ ├── bert │ │ ├── LICENSE_bert │ │ ├── README_bert.md │ │ ├── __init__.py │ │ ├── convert_tf_checkpoint_to_pytorch.py │ │ ├── data │ │ │ └── annotated_wikisql_and_PyTorch_bert_param │ │ │ │ ├── bert_config_uncased_L-12_H-768_A-12.json │ │ │ │ └── vocab_uncased_L-12_H-768_A-12.txt │ │ ├── modeling.py │ │ ├── notebooks │ │ │ ├── Comparing TF and PT models SQuAD predictions.ipynb │ │ │ └── Comparing TF and PT models.ipynb │ │ └── tokenization.py │ ├── decoder.py │ ├── embedder.py │ ├── encoder.py │ ├── model.py │ ├── schema_interaction_model.py │ ├── token_predictor.py │ ├── torch_utils.py │ └── utils_bert.py ├── model_util.py ├── parse_args.py ├── postprocess_eval.py ├── preprocess.py ├── question_gen.py ├── requirements.txt ├── run.py └── world_model.py ├── EditSQL_run.py ├── LICENSE ├── MISP.png ├── MISP_SQL ├── __init__.py ├── agent.py ├── environment.py ├── error_detector.py ├── question_gen.py ├── semantic_tag_logic.txt ├── tag_seq_logic.md ├── utils.py └── world_model.py ├── README.md ├── SQLova_model ├── README.md ├── __init__.py ├── agent.py ├── annotate_ws.py ├── bert │ ├── LICENSE_bert │ ├── README_bert.md │ ├── __init__.py │ ├── convert_tf_checkpoint_to_pytorch.py │ ├── modeling.py │ ├── notebooks │ │ ├── Comparing TF and PT models SQuAD predictions.ipynb │ │ └── Comparing TF and PT models.ipynb │ └── tokenization.py ├── download │ ├── bert │ │ ├── bert_config_uncased_L-12_H-768_A-12.json │ │ └── vocab_uncased_L-12_H-768_A-12.txt │ └── data │ │ ├── online_setup_10p.json │ │ ├── online_setup_1p.json │ │ └── online_setup_5p.json ├── environment.py ├── error_detector.py ├── evaluate_ws.py ├── sqlnet │ ├── LICENSE │ ├── __init__.py │ └── dbengine.py ├── sqlova │ ├── __init__.py │ ├── model │ │ ├── __init__.py │ │ └── nl2sql │ │ │ ├── __init__.py │ │ │ └── wikisql_models.py │ └── utils │ │ ├── __init__.py │ │ ├── utils.py │ │ ├── utils_wikisql.py │ │ └── wikisql_formatter.py ├── train_decoder_layer.py ├── train_shallow_layer.py ├── wikisql │ ├── LICENSE_WikiSQL │ ├── annotate.py │ ├── evaluate.py │ └── lib │ │ ├── common.py │ │ ├── dbengine.py │ │ ├── query.py │ │ └── table.py └── world_model.py ├── SQLova_train.py ├── gpu-py3.yml ├── interaction_editsql.py ├── interaction_sqlova.py ├── overview.png ├── scripts ├── editsql │ ├── bin_user.sh │ ├── bin_user_expert.sh │ ├── full_expert.sh │ ├── misp_neil.sh │ ├── misp_neil_perfect.sh │ ├── misp_neil_pos.sh │ ├── pretrain.sh │ ├── self_train_0.5.sh │ ├── test.sh │ └── test_with_interaction.sh └── sqlova │ ├── bin_user.sh │ ├── bin_user_expert.sh │ ├── data_preprocess.sh │ ├── full_expert.sh │ ├── misp_neil.sh │ ├── misp_neil_perfect.sh │ ├── misp_neil_pos.sh │ ├── pretrain.sh │ ├── self_train_0.5.sh │ ├── test.sh │ └── test_with_interaction.sh ├── slides └── MISP_NEIL_EMNLP20_slides.pdf ├── text2sql.png └── user_study_utils.py /.gitignore: 
-------------------------------------------------------------------------------- 1 | *.idea 2 | *.pyc 3 | __pycache__/* 4 | *.DS_Store 5 | EditSQL/model/bert/data/annotated_wikisql_and_PyTorch_bert_param/pytorch_model_uncased_L-12_H-768_A-12.bin 6 | EditSQL/word_emb/glove.840B.300d.txt 7 | EditSQL/data_clean/ 8 | EditSQL/logs_clean/logs_spider_editsql_10p/ 9 | SQLova_model/logs/ 10 | SQLova_model/download/data/ 11 | SQLova_model/download/bert/pytorch_model_uncased_L-12_H-768_A-12.bin 12 | SQLova_model/checkpoints_online_pretrain_10p 13 | SQLova_model/checkpoints_online_pretrain_1p 14 | SQLova_model/checkpoints_online_pretrain_5p -------------------------------------------------------------------------------- /EditSQL/.gitattributes: -------------------------------------------------------------------------------- 1 | save_* filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /EditSQL/.gitignore: -------------------------------------------------------------------------------- 1 | logs/* 2 | data/* 3 | -------------------------------------------------------------------------------- /EditSQL/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Rui Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /EditSQL/README.md: -------------------------------------------------------------------------------- 1 | # EditSQL Experiments 2 | 3 | ## 1. Description 4 | This folder contains implementation of **interactive EditSQL parser**, which uses EditSQL as a base semantic parser in our MISP framework: 5 | - Please follow [2. General Environment Setup](#2-general-environment-setup) and set up the environment/data; 6 | - For testing interactive EditSQL on the fly (our EMNLP'19 setting), see [3. MISP with EditSQL](#3-misp-with-editsql); 7 | - For learning EditSQL from user interaction (our EMNLP'20 setting), see [4. Learning EditSQL from user interaction (EMNLP'20)](#4-learning-editsql-from-user-interaction-emnlp20). 8 | 9 | The implementation is adapted from [the EditSQL repository](https://github.com/ryanzhumich/editsql). 
10 | Please cite the following papers if you use the code: 11 | 12 | ``` 13 | @inproceedings{yao2020imitation, 14 | title={An Imitation Game for Learning Semantic Parsers from User Interaction}, 15 | author={Yao, Ziyu and Tang, Yiqi and Yih, Wen-tau and Sun, Huan and Su, Yu}, 16 | booktitle={Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP)}, 17 | year={2020} 18 | } 19 | 20 | @inproceedings{yao2019model, 21 | title={Model-based Interactive Semantic Parsing: A Unified Framework and A Text-to-SQL Case Study}, 22 | author={Yao, Ziyu and Su, Yu and Sun, Huan and Yih, Wen-tau}, 23 | booktitle={Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)}, 24 | pages={5450--5461}, 25 | year={2019} 26 | } 27 | 28 | @InProceedings{zhang2019editing, 29 | author = "Rui Zhang, Tao Yu, He Yang Er, Sungrok Shim, Eric Xue, Xi Victoria Lin, Tianze Shi, Caiming Xiong, Richard Socher, Dragomir Radev", 30 | title = "Editing-Based SQL Query Generation for Cross-Domain Context-Dependent Questions", 31 | booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing", 32 | year = "2019", 33 | address = "Hong Kong, China" 34 | } 35 | ``` 36 | 37 | ## 2. General Environment Setup 38 | ### Environment 39 | - Please install the Anaconda environment from [`gpu-py3.yml`](../gpu-py3.yml): 40 | ``` 41 | conda env create -f gpu-py3.yml 42 | ``` 43 | - Download the Glove word embedding from [here](https://nlp.stanford.edu/data/glove.840B.300d.zip) \(use `wget` for command line\) and put it as `EditSQL/word_emb/glove.840B.300d.txt`. 44 | 45 | - Download Pretrained BERT model from [here](https://drive.google.com/file/d/1f_LEWVgrtZLRuoiExJa5fNzTS8-WcAX9/view?usp=sharing) as `EditSQL/model/bert/data/annotated_wikisql_and_PyTorch_bert_param/pytorch_model_uncased_L-12_H-768_A-12.bin`. If using command line: 46 | ``` 47 | gdown https://drive.google.com/u/0/uc?id=1f_LEWVgrtZLRuoiExJa5fNzTS8-WcAX9 48 | ``` 49 | 50 | 51 | ### Data 52 | We have the pre-processed and cleaned [Spider data](https://yale-lily.github.io/spider) available: [data_clean.tar](https://buckeyemailosu-my.sharepoint.com/:u:/g/personal/chen_8336_buckeyemail_osu_edu/EbCsSRb8MKVNiVktVsze5GMBxp9TccTRSjZpT-VYB6tNpg?e=7ilzZW). 53 | Please download and uncompress it via `tar -xvf data_clean.tar` as a folder `EditSQL/data_clean`. 54 | Note that the training set has been cleaned with its size reduced (see [our paper](https://arxiv.org/pdf/2005.00689.pdf), Appendix B.3 for details). 55 | 56 | 57 | ## 3. MISP with EditSQL 58 | We explain how to build and test EditSQL under MISP following our EMNLP'19 setting. 59 | 60 | ### 3.1 Model training 61 | To train EditSQL on the full training set, please revise `SETTING=''` (empty string) in [scripts/editsql/pretrain.sh](../scripts/editsql/pretrain.sh). 62 | In the main directory, run: 63 | ``` 64 | bash scripts/editsql/pretrain.sh 65 | ``` 66 | 67 | ### 3.2 Model testing without interaction 68 | To test EditSQL (trained on the full training set) regularly, in [scripts/editsql/test.sh](../scripts/editsql/test.sh), 69 | please revise `SETTING=''` (empty string) to ensure the `LOGDIR` loads the desired model checkpoint. 
70 | In the main directory, run: 71 | ``` 72 | bash scripts/editsql/test.sh 73 | ``` 74 | 75 | ### 3.3 Model testing with simulated user interaction 76 | To test EditSQL (trained on the full training set) with human interaction under the MISP framework, in [scripts/editsql/test_with_interaction.sh](../scripts/editsql/test_with_interaction.sh), 77 | revise `SETTING='full_train'` to ensure the `LOGDIR` loads the desired model checkpoint. 78 | In the main directory, run: 79 | ``` 80 | bash scripts/editsql/test_with_interaction.sh 81 | ``` 82 | 83 | 84 | ## 4. Learning EditSQL from user interaction (EMNLP'20) 85 | ### 4.1 Pretraining 86 | 87 | #### 4.1.1 Pretrain by yourself 88 | Before interactive learning, we pretrain the EditSQL parser with 10% of the full training set. 89 | Please ensure `SETTING='_10p'` in [scripts/editsql/pretrain.sh](../scripts/editsql/pretrain.sh). 90 | Then in the main directory, run: 91 | ``` 92 | bash scripts/editsql/pretrain.sh 93 | ``` 94 | When the training is finished, please rename and move the best model checkpoint from `EditSQL/logs_clean/logs_spider_editsql_10p/pretraining/save_X` 95 | to `EditSQL/logs_clean/logs_spider_editsql_10p/model_best.pt`. 96 | 97 | #### 4.1.2 Use our pretrained checkpoint 98 | You can also use our pretrained checkpoint: [logs_clean.tar](https://buckeyemailosu-my.sharepoint.com/:u:/g/personal/chen_8336_buckeyemail_osu_edu/EQmsNq-xPpBJk7iBURgT1o4BzuFX5S329AfcWU9SEMzRGQ?e=CPUNu5). 99 | Please download and uncompress the content as `EditSQL/logs_clean/logs_spider_editsql_10p/model_best.pt`. 100 | 101 | 102 | #### 4.1.3 Test the pretrained model 103 | To test the pretrained parser without user interaction, see [3.2 Model testing without interaction](#32-model-testing-without-interaction). 104 | To test the pretrained parser with simulated user interaction, see [3.3 Model testing with simulated user interaction](#33-model-testing-with-simulated-user-interaction). 105 | Make sure `SETTING=online_pretrain_10p` is set in the scripts. 106 | 107 | ### 4.2 Interactive learning 108 | 109 | The training script for each algorithm can be found below. Please run them in the main directory. 
110 | 111 | | Algorithm | Script | 112 | | ------------- | ------------- | 113 | | MISP_NEIL | [`scripts/editsql/misp_neil.sh`](../scripts/editsql/misp_neil.sh) | 114 | | Full Expert | [`scripts/editsql/full_expert.sh`](../scripts/editsql/full_expert.sh) | 115 | | Self Train | [`scripts/editsql/self_train_0.5.sh`](../scripts/editsql/self_train_0.5.sh) | 116 | | MISP_NEIL* | [`scripts/editsql/misp_neil_perfect.sh`](../scripts/editsql/misp_neil_perfect.sh) | 117 | 118 | 119 | 120 | -------------------------------------------------------------------------------- /EditSQL/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunlab-osu/MISP/7870566ab6b9e121d648478968367bc79c12f7ef/EditSQL/__init__.py -------------------------------------------------------------------------------- /EditSQL/agent.py: -------------------------------------------------------------------------------- 1 | from MISP_SQL.agent import Agent 2 | -------------------------------------------------------------------------------- /EditSQL/data_util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunlab-osu/MISP/7870566ab6b9e121d648478968367bc79c12f7ef/EditSQL/data_util/__init__.py -------------------------------------------------------------------------------- /EditSQL/data_util/anonymization.py: -------------------------------------------------------------------------------- 1 | """Code for identifying and anonymizing entities in NL and SQL.""" 2 | 3 | import copy 4 | import json 5 | from . import util 6 | 7 | ENTITY_NAME = "ENTITY" 8 | CONSTANT_NAME = "CONSTANT" 9 | TIME_NAME = "TIME" 10 | SEPARATOR = "#" 11 | 12 | 13 | def timeval(string): 14 | """Returns the numeric version of a time. 15 | 16 | Inputs: 17 | string (str): String representing a time. 18 | 19 | Returns: 20 | String representing the absolute time. 21 | """ 22 | if string.endswith("am") or string.endswith( 23 | "pm") and string[:-2].isdigit(): 24 | numval = int(string[:-2]) 25 | if len(string) == 3 or len(string) == 4: 26 | numval *= 100 27 | if string.endswith("pm"): 28 | numval += 1200 29 | return str(numval) 30 | return "" 31 | 32 | 33 | def is_time(string): 34 | """Returns whether a string represents a time. 35 | 36 | Inputs: 37 | string (str): String to check. 38 | 39 | Returns: 40 | Whether the string represents a time. 41 | """ 42 | if string.endswith("am") or string.endswith("pm"): 43 | if string[:-2].isdigit(): 44 | return True 45 | 46 | return False 47 | 48 | 49 | def deanonymize(sequence, ent_dict, key): 50 | """Deanonymizes a sequence. 51 | 52 | Inputs: 53 | sequence (list of str): List of tokens to deanonymize. 54 | ent_dict (dict str->(dict str->str)): Maps from tokens to the entity dictionary. 55 | key (str): The key to use, in this case either natural language or SQL. 56 | 57 | Returns: 58 | Deanonymized sequence of tokens. 59 | """ 60 | new_sequence = [] 61 | for token in sequence: 62 | if token in ent_dict: 63 | new_sequence.extend(ent_dict[token][key]) 64 | else: 65 | new_sequence.append(token) 66 | 67 | return new_sequence 68 | 69 | 70 | class Anonymizer: 71 | """Anonymization class for keeping track of entities in this domain and 72 | scripts for anonymizing/deanonymizing. 73 | 74 | Members: 75 | anonymization_map (list of dict (str->str)): Containing entities from 76 | the anonymization file. 77 | entity_types (list of str): All entities in the anonymization file. 
78 | keys (set of str): Possible keys (types of text handled); in this case it should be 79 | one for natural language and another for SQL. 80 | entity_set (set of str): entity_types as a set. 81 | """ 82 | def __init__(self, filename): 83 | self.anonymization_map = [] 84 | self.entity_types = [] 85 | self.keys = set() 86 | 87 | pairs = [json.loads(line) for line in open(filename).readlines()] 88 | for pair in pairs: 89 | for key in pair: 90 | if key != "type": 91 | self.keys.add(key) 92 | self.anonymization_map.append(pair) 93 | if pair["type"] not in self.entity_types: 94 | self.entity_types.append(pair["type"]) 95 | 96 | self.entity_types.append(ENTITY_NAME) 97 | self.entity_types.append(CONSTANT_NAME) 98 | self.entity_types.append(TIME_NAME) 99 | 100 | self.entity_set = set(self.entity_types) 101 | 102 | def get_entity_type_from_token(self, token): 103 | """Gets the type of an entity given an anonymized token. 104 | 105 | Inputs: 106 | token (str): The entity token. 107 | 108 | Returns: 109 | str, representing the type of the entity. 110 | """ 111 | # these are in the pattern NAME:#, so just strip the thing after the 112 | # colon 113 | colon_loc = token.index(SEPARATOR) 114 | entity_type = token[:colon_loc] 115 | assert entity_type in self.entity_set 116 | 117 | return entity_type 118 | 119 | def is_anon_tok(self, token): 120 | """Returns whether a token is an anonymized token or not. 121 | 122 | Input: 123 | token (str): The token to check. 124 | 125 | Returns: 126 | bool, whether the token is an anonymized token. 127 | """ 128 | return token.split(SEPARATOR)[0] in self.entity_set 129 | 130 | def get_anon_id(self, token): 131 | """Gets the entity index (unique ID) for a token. 132 | 133 | Input: 134 | token (str): The token to get the index from. 135 | 136 | Returns: 137 | int, the token ID if it is an anonymized token; otherwise -1. 138 | """ 139 | if self.is_anon_tok(token): 140 | return self.entity_types.index(token.split(SEPARATOR)[0]) 141 | else: 142 | return -1 143 | 144 | def anonymize(self, 145 | sequence, 146 | tok_to_entity_dict, 147 | key, 148 | add_new_anon_toks=False): 149 | """Anonymizes a sequence. 150 | 151 | Inputs: 152 | sequence (list of str): Sequence to anonymize. 153 | tok_to_entity_dict (dict): Existing dictionary mapping from anonymized 154 | tokens to entities. 155 | key (str): Which kind of text this is (natural language or SQL) 156 | add_new_anon_toks (bool): Whether to add new entities to tok_to_entity_dict. 157 | 158 | Returns: 159 | list of str, the anonymized sequence. 160 | """ 161 | # Sort the token-tok-entity dict by the length of the modality. 162 | sorted_dict = sorted(tok_to_entity_dict.items(), 163 | key=lambda k: len(k[1][key]))[::-1] 164 | 165 | anonymized_sequence = copy.deepcopy(sequence) 166 | 167 | if add_new_anon_toks: 168 | type_counts = {} 169 | for entity_type in self.entity_types: 170 | type_counts[entity_type] = 0 171 | for token in tok_to_entity_dict: 172 | entity_type = self.get_entity_type_from_token(token) 173 | type_counts[entity_type] += 1 174 | 175 | # First find occurrences of things in the anonymization dictionary. 176 | for token, modalities in sorted_dict: 177 | our_modality = modalities[key] 178 | 179 | # Check if this key's version of the anonymized thing is in our 180 | # sequence. 
181 | while util.subsequence(our_modality, anonymized_sequence): 182 | found = False 183 | for startidx in range( 184 | len(anonymized_sequence) - len(our_modality) + 1): 185 | if anonymized_sequence[startidx:startidx + 186 | len(our_modality)] == our_modality: 187 | anonymized_sequence = anonymized_sequence[:startidx] + [ 188 | token] + anonymized_sequence[startidx + len(our_modality):] 189 | found = True 190 | break 191 | assert found, "Thought " \ 192 | + str(our_modality) + " was in [" \ 193 | + str(anonymized_sequence) + "] but could not find it" 194 | 195 | # Now add new keys if they are present. 196 | if add_new_anon_toks: 197 | 198 | # For every span in the sequence, check whether it is in the anon map 199 | # for this modality 200 | sorted_anon_map = sorted(self.anonymization_map, 201 | key=lambda k: len(k[key]))[::-1] 202 | 203 | for pair in sorted_anon_map: 204 | our_modality = pair[key] 205 | 206 | token_type = pair["type"] 207 | new_token = token_type + SEPARATOR + \ 208 | str(type_counts[token_type]) 209 | 210 | while util.subsequence(our_modality, anonymized_sequence): 211 | found = False 212 | for startidx in range( 213 | len(anonymized_sequence) - len(our_modality) + 1): 214 | if anonymized_sequence[startidx:startidx + \ 215 | len(our_modality)] == our_modality: 216 | if new_token not in tok_to_entity_dict: 217 | type_counts[token_type] += 1 218 | tok_to_entity_dict[new_token] = pair 219 | 220 | anonymized_sequence = anonymized_sequence[:startidx] + [ 221 | new_token] + anonymized_sequence[startidx + len(our_modality):] 222 | found = True 223 | break 224 | assert found, "Thought " \ 225 | + str(our_modality) + " was in [" \ 226 | + str(anonymized_sequence) + "] but could not find it" 227 | 228 | # Also replace integers with constants 229 | for index, token in enumerate(anonymized_sequence): 230 | if token.isdigit() or is_time(token): 231 | if token.isdigit(): 232 | entity_type = CONSTANT_NAME 233 | value = new_token 234 | if is_time(token): 235 | entity_type = TIME_NAME 236 | value = timeval(token) 237 | 238 | # First try to find the constant in the entity dictionary already, 239 | # and get the name if it's found. 240 | new_token = "" 241 | new_dict = {} 242 | found = False 243 | for entity, value in tok_to_entity_dict.items(): 244 | if value[key][0] == token: 245 | new_token = entity 246 | new_dict = value 247 | found = True 248 | break 249 | 250 | if not found: 251 | new_token = entity_type + SEPARATOR + \ 252 | str(type_counts[entity_type]) 253 | new_dict = {} 254 | for tempkey in self.keys: 255 | new_dict[tempkey] = [token] 256 | 257 | tok_to_entity_dict[new_token] = new_dict 258 | type_counts[entity_type] += 1 259 | 260 | anonymized_sequence[index] = new_token 261 | 262 | return anonymized_sequence 263 | -------------------------------------------------------------------------------- /EditSQL/data_util/atis_vocab.py: -------------------------------------------------------------------------------- 1 | """Gets and stores vocabulary for the ATIS data.""" 2 | 3 | from . import snippets 4 | from .vocabulary import Vocabulary, UNK_TOK, DEL_TOK, EOS_TOK 5 | 6 | INPUT_FN_TYPES = [UNK_TOK, DEL_TOK, EOS_TOK] 7 | OUTPUT_FN_TYPES = [UNK_TOK, EOS_TOK] 8 | 9 | MIN_INPUT_OCCUR = 1 10 | MIN_OUTPUT_OCCUR = 1 11 | 12 | class ATISVocabulary(): 13 | """ Stores the vocabulary for the ATIS data. 14 | 15 | Attributes: 16 | raw_vocab (Vocabulary): Vocabulary object. 17 | tokens (set of str): Set of all of the strings in the vocabulary. 
18 | inorder_tokens (list of str): List of all tokens, with a strict and 19 | unchanging order. 20 | """ 21 | def __init__(self, 22 | token_sequences, 23 | filename, 24 | params, 25 | is_input='input', 26 | min_occur=1, 27 | anonymizer=None, 28 | skip=None): 29 | 30 | if is_input=='input': 31 | functional_types = INPUT_FN_TYPES 32 | elif is_input=='output': 33 | functional_types = OUTPUT_FN_TYPES 34 | elif is_input=='schema': 35 | functional_types = [UNK_TOK] 36 | else: 37 | functional_types = [] 38 | 39 | self.raw_vocab = Vocabulary( 40 | token_sequences, 41 | filename, 42 | functional_types=functional_types, 43 | min_occur=min_occur, 44 | ignore_fn=lambda x: snippets.is_snippet(x) or ( 45 | anonymizer and anonymizer.is_anon_tok(x)) or (skip and x in skip) ) 46 | self.tokens = set(self.raw_vocab.token_to_id.keys()) 47 | self.inorder_tokens = self.raw_vocab.id_to_token 48 | 49 | assert len(self.inorder_tokens) == len(self.raw_vocab) 50 | 51 | def __len__(self): 52 | return len(self.raw_vocab) 53 | 54 | def token_to_id(self, token): 55 | """ Maps from a token to a unique ID. 56 | 57 | Inputs: 58 | token (str): The token to look up. 59 | 60 | Returns: 61 | int, uniquely identifying the token. 62 | """ 63 | return self.raw_vocab.token_to_id[token] 64 | 65 | def id_to_token(self, identifier): 66 | """ Maps from a unique integer to an identifier. 67 | 68 | Inputs: 69 | identifier (int): The unique ID. 70 | 71 | Returns: 72 | string, representing the token. 73 | """ 74 | return self.raw_vocab.id_to_token[identifier] 75 | -------------------------------------------------------------------------------- /EditSQL/data_util/dataset_split.py: -------------------------------------------------------------------------------- 1 | """ Utility functions for loading and processing ATIS data. 2 | """ 3 | import os 4 | import pickle 5 | 6 | class DatasetSplit: 7 | """Stores a split of the ATIS dataset. 8 | 9 | Attributes: 10 | examples (list of Interaction): Stores the examples in the split. 11 | """ 12 | def __init__(self, processed_filename, raw_filename, load_function): 13 | if os.path.exists(processed_filename): 14 | print("Loading preprocessed data from " + processed_filename) 15 | with open(processed_filename, 'rb') as infile: 16 | self.examples = pickle.load(infile) 17 | else: 18 | print( 19 | "Loading raw data from " + 20 | raw_filename + 21 | " and writing to " + 22 | processed_filename) 23 | 24 | infile = open(raw_filename, 'rb') 25 | examples_from_file = pickle.load(infile) 26 | assert isinstance(examples_from_file, list), raw_filename + \ 27 | " does not contain a list of examples" 28 | infile.close() 29 | 30 | self.examples = [] 31 | for example in examples_from_file: 32 | obj, keep = load_function(example) 33 | 34 | if keep: 35 | self.examples.append(obj) 36 | 37 | 38 | print("Loaded " + str(len(self.examples)) + " examples") 39 | outfile = open(processed_filename, 'wb') 40 | pickle.dump(self.examples, outfile) 41 | outfile.close() 42 | 43 | def get_ex_properties(self, function): 44 | """ Applies some function to the examples in the dataset. 45 | 46 | Inputs: 47 | function: (lambda Interaction -> T): Function to apply to all 48 | examples. 
49 | 50 | Returns 51 | list of the return value of the function 52 | """ 53 | elems = [] 54 | for example in self.examples: 55 | elems.append(function(example)) 56 | return elems 57 | -------------------------------------------------------------------------------- /EditSQL/data_util/entities.py: -------------------------------------------------------------------------------- 1 | """ Classes for keeping track of the entities in a natural language string. """ 2 | import json 3 | 4 | 5 | class NLtoSQLDict: 6 | """ 7 | Entity dict file should contain, on each line, a JSON dictionary with 8 | "input" and "output" keys specifying the string for the input and output 9 | pairs. The idea is that the existence of the key in an input sequence 10 | likely corresponds to the existence of the value in the output sequence. 11 | 12 | The entity_dict should map keys (input strings) to a list of values (output 13 | strings) where this property holds. This allows keys to map to multiple 14 | output strings (e.g. for times). 15 | """ 16 | def __init__(self, entity_dict_filename): 17 | self.entity_dict = {} 18 | 19 | pairs = [json.loads(line) 20 | for line in open(entity_dict_filename).readlines()] 21 | for pair in pairs: 22 | input_seq = pair["input"] 23 | output_seq = pair["output"] 24 | if input_seq not in self.entity_dict: 25 | self.entity_dict[input_seq] = [] 26 | self.entity_dict[input_seq].append(output_seq) 27 | 28 | def get_sql_entities(self, tokenized_nl_string): 29 | """ 30 | Gets the output-side entities which correspond to the input entities in 31 | the input sequence. 32 | Inputs: 33 | tokenized_input_string: list of tokens in the input string. 34 | Outputs: 35 | set of output strings. 36 | """ 37 | assert len(tokenized_nl_string) > 0 38 | flat_input_string = " ".join(tokenized_nl_string) 39 | entities = [] 40 | 41 | # See if any input strings are in our input sequence, and add the 42 | # corresponding output strings if so. 43 | for entry, values in self.entity_dict.items(): 44 | in_middle = " " + entry + " " in flat_input_string 45 | 46 | leftspace = " " + entry 47 | at_end = leftspace in flat_input_string and flat_input_string.endswith( 48 | leftspace) 49 | 50 | rightspace = entry + " " 51 | at_beginning = rightspace in flat_input_string and flat_input_string.startswith( 52 | rightspace) 53 | if in_middle or at_end or at_beginning: 54 | for out_string in values: 55 | entities.append(out_string) 56 | 57 | # Also add any integers in the input string (these aren't in the entity) 58 | # dict. 59 | for token in tokenized_nl_string: 60 | if token.isnumeric(): 61 | entities.append(token) 62 | 63 | return entities 64 | -------------------------------------------------------------------------------- /EditSQL/data_util/snippets.py: -------------------------------------------------------------------------------- 1 | """ Contains the Snippet class and methods for handling snippets. 2 | 3 | Attributes: 4 | SNIPPET_PREFIX: string prefix for snippets. 5 | """ 6 | 7 | SNIPPET_PREFIX = "SNIPPET_" 8 | 9 | 10 | def is_snippet(token): 11 | """ Determines whether a token is a snippet or not. 12 | 13 | Inputs: 14 | token (str): The token to check. 15 | 16 | Returns: 17 | bool, indicating whether it's a snippet. 18 | """ 19 | return token.startswith(SNIPPET_PREFIX) 20 | 21 | def expand_snippets(sequence, snippets): 22 | """ Given a sequence and a list of snippets, expand the snippets in the sequence. 23 | 24 | Inputs: 25 | sequence (list of str): Query containing snippet references. 
26 | snippets (list of Snippet): List of available snippets. 27 | 28 | return list of str representing the expanded sequence 29 | """ 30 | snippet_id_to_snippet = {} 31 | for snippet in snippets: 32 | assert snippet.name not in snippet_id_to_snippet 33 | snippet_id_to_snippet[snippet.name] = snippet 34 | expanded_seq = [] 35 | for token in sequence: 36 | if token in snippet_id_to_snippet: 37 | expanded_seq.extend(snippet_id_to_snippet[token].sequence) 38 | else: 39 | assert not is_snippet(token) 40 | expanded_seq.append(token) 41 | 42 | return expanded_seq 43 | 44 | def snippet_index(token): 45 | """ Returns the index of a snippet. 46 | 47 | Inputs: 48 | token (str): The snippet to check. 49 | 50 | Returns: 51 | integer, the index of the snippet. 52 | """ 53 | assert is_snippet(token) 54 | return int(token.split("_")[-1]) 55 | 56 | 57 | class Snippet(): 58 | """ Contains a snippet. """ 59 | def __init__(self, 60 | sequence, 61 | startpos, 62 | sql, 63 | age=0): 64 | self.sequence = sequence 65 | self.startpos = startpos 66 | self.sql = sql 67 | 68 | # TODO: age vs. index? 69 | self.age = age 70 | self.index = 0 71 | 72 | self.name = "" 73 | self.embedding = None 74 | 75 | self.endpos = self.startpos + len(self.sequence) 76 | assert self.endpos < len(self.sql), "End position of snippet is " + str( 77 | self.endpos) + " which is greater than length of SQL (" + str(len(self.sql)) + ")" 78 | assert self.sequence == self.sql[self.startpos:self.endpos], \ 79 | "Value of snippet (" + " ".join(self.sequence) + ") " \ 80 | "is not the same as SQL at the same positions (" \ 81 | + " ".join(self.sql[self.startpos:self.endpos]) + ")" 82 | 83 | def __str__(self): 84 | return self.name + "\t" + \ 85 | str(self.age) + "\t" + " ".join(self.sequence) 86 | 87 | def __len__(self): 88 | return len(self.sequence) 89 | 90 | def increase_age(self): 91 | """ Ages a snippet by one. """ 92 | self.index += 1 93 | 94 | def assign_id(self, number): 95 | """ Assigns the name of the snippet to be the prefix + the number. """ 96 | self.name = SNIPPET_PREFIX + str(number) 97 | 98 | def set_embedding(self, embedding): 99 | """ Sets the embedding of the snippet. 100 | 101 | Inputs: 102 | embedding (dy.Expression) 103 | 104 | """ 105 | self.embedding = embedding 106 | -------------------------------------------------------------------------------- /EditSQL/data_util/tokenizers.py: -------------------------------------------------------------------------------- 1 | """Tokenizers for natural language SQL queries, and lambda calculus.""" 2 | import nltk 3 | import sqlparse 4 | 5 | def nl_tokenize(string): 6 | """Tokenizes a natural language string into tokens. 7 | 8 | Inputs: 9 | string: the string to tokenize. 10 | Outputs: 11 | a list of tokens. 12 | 13 | Assumes data is space-separated (this is true of ZC07 data in ATIS2/3). 14 | """ 15 | return nltk.word_tokenize(string) 16 | 17 | def sql_tokenize(string): 18 | """ Tokenizes a SQL statement into tokens. 19 | 20 | Inputs: 21 | string: string to tokenize. 22 | 23 | Outputs: 24 | a list of tokens. 25 | """ 26 | tokens = [] 27 | statements = sqlparse.parse(string) 28 | 29 | # SQLparse gives you a list of statements. 30 | for statement in statements: 31 | # Flatten the tokens in each statement and add to the tokens list. 
32 | flat_tokens = sqlparse.sql.TokenList(statement.tokens).flatten() 33 | for token in flat_tokens: 34 | strip_token = str(token).strip() 35 | if len(strip_token) > 0: 36 | tokens.append(strip_token) 37 | 38 | newtokens = [] 39 | keep = True 40 | for i, token in enumerate(tokens): 41 | if token == ".": 42 | newtoken = newtokens[-1] + "." + tokens[i + 1] 43 | newtokens = newtokens[:-1] + [newtoken] 44 | keep = False 45 | elif keep: 46 | newtokens.append(token) 47 | else: 48 | keep = True 49 | 50 | return newtokens 51 | 52 | def lambda_tokenize(string): 53 | """ Tokenizes a lambda-calculus statement into tokens. 54 | 55 | Inputs: 56 | string: a lambda-calculus string 57 | 58 | Outputs: 59 | a list of tokens. 60 | """ 61 | 62 | space_separated = string.split(" ") 63 | 64 | new_tokens = [] 65 | 66 | # Separate the string by spaces, then separate based on existence of ( or 67 | # ). 68 | for token in space_separated: 69 | tokens = [] 70 | 71 | current_token = "" 72 | for char in token: 73 | if char == ")" or char == "(": 74 | tokens.append(current_token) 75 | tokens.append(char) 76 | current_token = "" 77 | else: 78 | current_token += char 79 | tokens.append(current_token) 80 | new_tokens.extend([tok for tok in tokens if tok]) 81 | 82 | return new_tokens 83 | -------------------------------------------------------------------------------- /EditSQL/data_util/util.py: -------------------------------------------------------------------------------- 1 | """Contains various utility functions.""" 2 | def subsequence(first_sequence, second_sequence): 3 | """ 4 | Returns whether the first sequence is a subsequence of the second sequence. 5 | 6 | Inputs: 7 | first_sequence (list): A sequence. 8 | second_sequence (list): Another sequence. 9 | 10 | Returns: 11 | Boolean indicating whether first_sequence is a subsequence of second_sequence. 12 | """ 13 | for startidx in range(len(second_sequence) - len(first_sequence) + 1): 14 | if second_sequence[startidx:startidx + len(first_sequence)] == first_sequence: 15 | return True 16 | return False 17 | -------------------------------------------------------------------------------- /EditSQL/data_util/utterance.py: -------------------------------------------------------------------------------- 1 | """ Contains the Utterance class. """ 2 | 3 | from . import sql_util 4 | from . import tokenizers 5 | 6 | ANON_INPUT_KEY = "cleaned_nl" 7 | OUTPUT_KEY = "sql" 8 | 9 | class Utterance: 10 | """ Utterance class. 
""" 11 | def process_input_seq(self, 12 | anonymize, 13 | anonymizer, 14 | anon_tok_to_ent): 15 | assert not anon_tok_to_ent or anonymize 16 | assert not anonymize or anonymizer 17 | 18 | if anonymize: 19 | assert anonymizer 20 | 21 | self.input_seq_to_use = anonymizer.anonymize( 22 | self.original_input_seq, anon_tok_to_ent, ANON_INPUT_KEY, add_new_anon_toks=True) 23 | else: 24 | self.input_seq_to_use = self.original_input_seq 25 | 26 | def process_gold_seq(self, 27 | output_sequences, 28 | nl_to_sql_dict, 29 | available_snippets, 30 | anonymize, 31 | anonymizer, 32 | anon_tok_to_ent): 33 | # Get entities in the input sequence: 34 | # anonymized entity types 35 | # othe recognized entities (this includes "flight") 36 | entities_in_input = [ 37 | [tok] for tok in self.input_seq_to_use if tok in anon_tok_to_ent] 38 | entities_in_input.extend( 39 | nl_to_sql_dict.get_sql_entities( 40 | self.input_seq_to_use)) 41 | 42 | # Get the shortest gold query (this is what we use to train) 43 | shortest_gold_and_results = min(output_sequences, 44 | key=lambda x: len(x[0])) 45 | 46 | # Tokenize and anonymize it if necessary. 47 | self.original_gold_query = shortest_gold_and_results[0] 48 | self.gold_sql_results = shortest_gold_and_results[1] 49 | 50 | self.contained_entities = entities_in_input 51 | 52 | # Keep track of all gold queries and the resulting tables so that we can 53 | # give credit if it predicts a different correct sequence. 54 | self.all_gold_queries = output_sequences 55 | 56 | self.anonymized_gold_query = self.original_gold_query 57 | if anonymize: 58 | self.anonymized_gold_query = anonymizer.anonymize( 59 | self.original_gold_query, anon_tok_to_ent, OUTPUT_KEY, add_new_anon_toks=False) 60 | 61 | # Add snippets to it. 62 | self.gold_query_to_use = sql_util.add_snippets_to_query( 63 | available_snippets, entities_in_input, self.anonymized_gold_query) 64 | 65 | # Assign learning weights 66 | self.gold_query_weights = [1.0] * (len(self.gold_query_to_use) + 1) # EOS 67 | 68 | def __init__(self, 69 | example, 70 | available_snippets, 71 | nl_to_sql_dict, 72 | params, 73 | anon_tok_to_ent={}, 74 | anonymizer=None): 75 | # Get output and input sequences from the dictionary representation. 76 | output_sequences = example[OUTPUT_KEY] 77 | self.original_input_seq = tokenizers.nl_tokenize(example[params.input_key]) 78 | self.available_snippets = available_snippets 79 | self.keep = False 80 | 81 | # pruned_output_sequences = [] 82 | # for sequence in output_sequences: 83 | # if len(sequence[0]) > 3: 84 | # pruned_output_sequences.append(sequence) 85 | 86 | # output_sequences = pruned_output_sequences 87 | if len(output_sequences) > 0 and len(self.original_input_seq) > 0: 88 | # Only keep this example if there is at least one output sequence. 
89 | self.keep = True 90 | if len(output_sequences) == 0 or len(self.original_input_seq) == 0: 91 | return 92 | 93 | # Process the input sequence 94 | self.process_input_seq(params.anonymize, 95 | anonymizer, 96 | anon_tok_to_ent) 97 | 98 | # Process the gold sequence 99 | self.process_gold_seq(output_sequences, 100 | nl_to_sql_dict, 101 | self.available_snippets, 102 | params.anonymize, 103 | anonymizer, 104 | anon_tok_to_ent) 105 | 106 | def __str__(self): 107 | string = "Original input: " + " ".join(self.original_input_seq) + "\n" 108 | string += "Modified input: " + " ".join(self.input_seq_to_use) + "\n" 109 | string += "Original output: " + " ".join(self.original_gold_query) + "\n" 110 | string += "Modified output: " + " ".join(self.gold_query_to_use) + "\n" 111 | string += "Snippets:\n" 112 | for snippet in self.available_snippets: 113 | string += str(snippet) + "\n" 114 | return string 115 | 116 | def length_valid(self, input_limit, output_limit): 117 | return (len(self.input_seq_to_use) < input_limit \ 118 | and len(self.gold_query_to_use) < output_limit) 119 | 120 | def set_gold_query_weights(self, weights): # added by Ziyu 121 | self.gold_query_weights = weights 122 | # assert len(self.gold_query_weights) == len(self.gold_query_to_use) 123 | # revise 0209: 124 | assert len(self.gold_query_weights) == len(self.gold_query_to_use) + 1 # one EOS 125 | -------------------------------------------------------------------------------- /EditSQL/data_util/vocabulary.py: -------------------------------------------------------------------------------- 1 | """Contains class and methods for storing and computing a vocabulary from text.""" 2 | import operator 3 | import os 4 | import pickle 5 | 6 | # Special sequencing tokens. 7 | UNK_TOK = "_UNK" # Replaces out-of-vocabulary words. 8 | EOS_TOK = "_EOS" # Appended to the end of a sequence to indicate its end. 9 | DEL_TOK = ";" 10 | 11 | 12 | class Vocabulary: 13 | """Vocabulary class: stores information about words in a corpus. 14 | 15 | Members: 16 | functional_types (list of str): Functional vocabulary words, such as EOS. 17 | max_size (int): The maximum size of vocabulary to keep. 18 | min_occur (int): The minimum number of times a word should occur to keep it. 19 | id_to_token (list of str): Ordered list of word types. 20 | token_to_id (dict str->int): Maps from each unique word type to its index. 21 | """ 22 | def get_vocab(self, sequences, ignore_fn): 23 | """Gets vocabulary from a list of sequences. 24 | 25 | Inputs: 26 | sequences (list of list of str): Sequences from which to compute the vocabulary. 27 | ignore_fn (lambda str: bool): Function used to tell whether to ignore a 28 | token during computation of the vocabulary. 29 | 30 | Returns: 31 | list of str, representing the unique word types in the vocabulary. 32 | """ 33 | type_counts = {} 34 | 35 | for sequence in sequences: 36 | for token in sequence: 37 | if not ignore_fn(token): 38 | if token not in type_counts: 39 | type_counts[token] = 0 40 | type_counts[token] += 1 41 | 42 | # Create sorted list of tokens, by their counts. Reverse so it is in order of 43 | # most frequent to least frequent. 44 | sorted_type_counts = sorted(sorted(type_counts.items()), 45 | key=operator.itemgetter(1))[::-1] 46 | 47 | sorted_types = [typecount[0] 48 | for typecount in sorted_type_counts if typecount[1] >= self.min_occur] 49 | 50 | # Append the necessary functional tokens. 
51 | sorted_types = self.functional_types + sorted_types 52 | 53 | # Cut off if vocab_size is set (nonnegative) 54 | if self.max_size >= 0: 55 | vocab = sorted_types[:min(self.max_size, len(sorted_types))] 56 | else: 57 | vocab = sorted_types 58 | 59 | return vocab 60 | 61 | def __init__(self, 62 | sequences, 63 | filename, 64 | functional_types=None, 65 | max_size=-1, 66 | min_occur=0, 67 | ignore_fn=lambda x: False): 68 | self.functional_types = functional_types 69 | self.max_size = max_size 70 | self.min_occur = min_occur 71 | 72 | vocab = self.get_vocab(sequences, ignore_fn) 73 | 74 | self.id_to_token = [] 75 | self.token_to_id = {} 76 | 77 | for i, word_type in enumerate(vocab): 78 | self.id_to_token.append(word_type) 79 | self.token_to_id[word_type] = i 80 | 81 | # Load the previous vocab, if it exists. 82 | if os.path.exists(filename): 83 | infile = open(filename, 'rb') 84 | loaded_vocab = pickle.load(infile) 85 | infile.close() 86 | 87 | print("Loaded vocabulary from " + str(filename)) 88 | if loaded_vocab.id_to_token != self.id_to_token \ 89 | or loaded_vocab.token_to_id != self.token_to_id: 90 | print("Loaded vocabulary is different than generated vocabulary.") 91 | else: 92 | print("Writing vocabulary to " + str(filename)) 93 | outfile = open(filename, 'wb') 94 | pickle.dump(self, outfile) 95 | outfile.close() 96 | 97 | def __len__(self): 98 | return len(self.id_to_token) 99 | -------------------------------------------------------------------------------- /EditSQL/error_detector.py: -------------------------------------------------------------------------------- 1 | from MISP_SQL.error_detector import * -------------------------------------------------------------------------------- /EditSQL/eval_scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunlab-osu/MISP/7870566ab6b9e121d648478968367bc79c12f7ef/EditSQL/eval_scripts/__init__.py -------------------------------------------------------------------------------- /EditSQL/eval_scripts/metric_averages.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | 4 | predictions = [json.loads(line) for line in open(sys.argv[1]).readlines() if line] 5 | 6 | string_count = 0. 7 | sem_count = 0. 8 | syn_count = 0. 9 | table_count = 0. 10 | strict_table_count = 0. 11 | 12 | precision_denom = 0. 13 | precision = 0. 14 | recall_denom = 0. 15 | recall = 0. 16 | f1_score = 0. 17 | f1_denom = 0. 18 | 19 | time = 0. 20 | 21 | for prediction in predictions: 22 | if prediction["correct_string"]: 23 | string_count += 1. 24 | if prediction["semantic"]: 25 | sem_count += 1. 26 | if prediction["syntactic"]: 27 | syn_count += 1. 28 | if prediction["correct_table"]: 29 | table_count += 1. 30 | if prediction["strict_correct_table"]: 31 | strict_table_count += 1. 
32 | if prediction["gold_tables"] !="[[]]": 33 | precision += prediction["table_prec"] 34 | precision_denom += 1 35 | if prediction["pred_table"] != "[]": 36 | recall += prediction["table_rec"] 37 | recall_denom += 1 38 | 39 | if prediction["gold_tables"] != "[[]]": 40 | f1_score += prediction["table_f1"] 41 | f1_denom += 1 42 | 43 | num_p = len(predictions) 44 | print("string precision: " + str(string_count / num_p)) 45 | print("% semantic: " + str(sem_count / num_p)) 46 | print("% syntactic: " + str(syn_count / num_p)) 47 | print("table prec: " + str(table_count / num_p)) 48 | print("strict table prec: " + str(strict_table_count / num_p)) 49 | print("table row prec: " + str(precision / precision_denom)) 50 | print("table row recall: " + str(recall / recall_denom)) 51 | print("table row f1: " + str(f1_score / f1_denom)) 52 | print("inference time: " + str(time / num_p)) 53 | 54 | -------------------------------------------------------------------------------- /EditSQL/logger.py: -------------------------------------------------------------------------------- 1 | """Contains the logging class.""" 2 | 3 | class Logger(): 4 | """Attributes: 5 | 6 | fileptr (file): File pointer for input/output. 7 | lines (list of str): The lines read from the log. 8 | """ 9 | def __init__(self, filename, option): 10 | self.fileptr = open(filename, option) 11 | if option == "r": 12 | self.lines = self.fileptr.readlines() 13 | else: 14 | self.lines = [] 15 | 16 | def put(self, string): 17 | """Writes to the file.""" 18 | self.fileptr.write(string + "\n") 19 | self.fileptr.flush() 20 | 21 | def close(self): 22 | """Closes the logger.""" 23 | self.fileptr.close() 24 | 25 | def findlast(self, identifier, default=0.): 26 | """Finds the last line in the log with a certain value.""" 27 | for line in self.lines[::-1]: 28 | if line.lower().startswith(identifier): 29 | string = line.strip().split("\t")[1] 30 | if string.replace(".", "").isdigit(): 31 | return float(string) 32 | elif string.lower() == "true": 33 | return True 34 | elif string.lower() == "false": 35 | return False 36 | else: 37 | return string 38 | return default 39 | 40 | def contains(self, string): 41 | """Dtermines whether the string is present in the log.""" 42 | for line in self.lines[::-1]: 43 | if string.lower() in line.lower(): 44 | return True 45 | return False 46 | 47 | def findlast_log_before(self, before_str): 48 | """Finds the last entry in the log before another entry.""" 49 | loglines = [] 50 | in_line = False 51 | for line in self.lines[::-1]: 52 | if line.startswith(before_str): 53 | in_line = True 54 | elif in_line: 55 | loglines.append(line) 56 | if line.strip() == "" and in_line: 57 | return "".join(loglines[::-1]) 58 | return "".join(loglines[::-1]) 59 | -------------------------------------------------------------------------------- /EditSQL/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunlab-osu/MISP/7870566ab6b9e121d648478968367bc79c12f7ef/EditSQL/model/__init__.py -------------------------------------------------------------------------------- /EditSQL/model/attention.py: -------------------------------------------------------------------------------- 1 | """Contains classes for computing and keeping track of attention distributions. 2 | """ 3 | from collections import namedtuple 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from . 
import torch_utils 8 | 9 | class AttentionResult(namedtuple('AttentionResult', 10 | ('scores', 11 | 'distribution', 12 | 'vector'))): 13 | """Stores the result of an attention calculation.""" 14 | __slots__ = () 15 | 16 | 17 | class Attention(torch.nn.Module): 18 | """Attention mechanism class. Stores parameters for and computes attention. 19 | 20 | Attributes: 21 | transform_query (bool): Whether or not to transform the query being 22 | passed in with a weight transformation before computing attention. 23 | transform_key (bool): Whether or not to transform the key being 24 | passed in with a weight transformation before computing attention. 25 | transform_value (bool): Whether or not to transform the value being 26 | passed in with a weight transformation before computing attention. 27 | key_size (int): The size of the key vectors. 28 | value_size (int): The size of the value vectors. 29 | 30 | query_weights (dy.Parameters): Weights for transforming the query. 31 | key_weights (dy.Parameters): Weights for transforming the key. 32 | value_weights (dy.Parameters): Weights for transforming the value. 33 | """ 34 | def __init__(self, query_size, key_size, value_size): 35 | super().__init__() 36 | self.key_size = key_size 37 | self.value_size = value_size 38 | 39 | self.query_weights = torch_utils.add_params((query_size, self.key_size), "weights-attention-q") 40 | 41 | def transform_arguments(self, query, keys, values): 42 | """ Transforms the query/key/value inputs before attention calculations. 43 | 44 | Arguments: 45 | query (dy.Expression): Vector representing the query (e.g., hidden state.) 46 | keys (list of dy.Expression): List of vectors representing the key 47 | values. 48 | values (list of dy.Expression): List of vectors representing the values. 49 | 50 | Returns: 51 | triple of dy.Expression, where the first represents the (transformed) 52 | query, the second represents the (transformed and concatenated) 53 | keys, and the third represents the (transformed and concatenated) 54 | values. 
55 | """ 56 | assert len(keys) == len(values) 57 | 58 | all_keys = torch.stack(keys, dim=1) 59 | all_values = torch.stack(values, dim=1) 60 | 61 | assert all_keys.size()[0] == self.key_size, "Expected key size of " + str(self.key_size) + " but got " + str(all_keys.size()[0]) 62 | assert all_values.size()[0] == self.value_size 63 | 64 | query = torch_utils.linear_layer(query, self.query_weights) 65 | 66 | return query, all_keys, all_values 67 | 68 | def forward(self, query, keys, values=None): 69 | if not values: 70 | values = keys 71 | 72 | query_t, keys_t, values_t = self.transform_arguments(query, keys, values) 73 | 74 | scores = torch.t(torch.mm(query_t,keys_t)) # len(key) x len(query) 75 | 76 | distribution = F.softmax(scores, dim=0) # len(key) x len(query) 77 | 78 | context_vector = torch.mm(values_t, distribution).squeeze() # value_size x len(query) 79 | 80 | return AttentionResult(scores, distribution, context_vector) 81 | -------------------------------------------------------------------------------- /EditSQL/model/bert/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunlab-osu/MISP/7870566ab6b9e121d648478968367bc79c12f7ef/EditSQL/model/bert/__init__.py -------------------------------------------------------------------------------- /EditSQL/model/bert/convert_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import argparse 23 | import tensorflow as tf 24 | import torch 25 | import numpy as np 26 | 27 | from modeling import BertConfig, BertModel 28 | 29 | parser = argparse.ArgumentParser() 30 | 31 | ## Required parameters 32 | parser.add_argument("--tf_checkpoint_path", 33 | default = None, 34 | type = str, 35 | required = True, 36 | help = "Path the TensorFlow checkpoint path.") 37 | parser.add_argument("--bert_config_file", 38 | default = None, 39 | type = str, 40 | required = True, 41 | help = "The config json file corresponding to the pre-trained BERT model. 
\n" 42 | "This specifies the model architecture.") 43 | parser.add_argument("--pytorch_dump_path", 44 | default = None, 45 | type = str, 46 | required = True, 47 | help = "Path to the output PyTorch model.") 48 | 49 | args = parser.parse_args() 50 | 51 | def convert(): 52 | # Initialise PyTorch model 53 | config = BertConfig.from_json_file(args.bert_config_file) 54 | model = BertModel(config) 55 | 56 | # Load weights from TF model 57 | path = args.tf_checkpoint_path 58 | print("Converting TensorFlow checkpoint from {}".format(path)) 59 | 60 | init_vars = tf.train.list_variables(path) 61 | names = [] 62 | arrays = [] 63 | for name, shape in init_vars: 64 | print("Loading {} with shape {}".format(name, shape)) 65 | array = tf.train.load_variable(path, name) 66 | print("Numpy array shape {}".format(array.shape)) 67 | names.append(name) 68 | arrays.append(array) 69 | 70 | for name, array in zip(names, arrays): 71 | name = name[5:] # skip "bert/" 72 | print("Loading {}".format(name)) 73 | name = name.split('/') 74 | if name[0] in ['redictions', 'eq_relationship']: 75 | print("Skipping") 76 | continue 77 | pointer = model 78 | for m_name in name: 79 | if re.fullmatch(r'[A-Za-z]+_\d+', m_name): 80 | l = re.split(r'_(\d+)', m_name) 81 | else: 82 | l = [m_name] 83 | if l[0] == 'kernel': 84 | pointer = getattr(pointer, 'weight') 85 | else: 86 | pointer = getattr(pointer, l[0]) 87 | if len(l) >= 2: 88 | num = int(l[1]) 89 | pointer = pointer[num] 90 | if m_name[-11:] == '_embeddings': 91 | pointer = getattr(pointer, 'weight') 92 | elif m_name == 'kernel': 93 | array = np.transpose(array) 94 | try: 95 | assert pointer.shape == array.shape 96 | except AssertionError as e: 97 | e.args += (pointer.shape, array.shape) 98 | raise 99 | pointer.data = torch.from_numpy(array) 100 | 101 | # Save pytorch-model 102 | torch.save(model.state_dict(), args.pytorch_dump_path) 103 | 104 | if __name__ == "__main__": 105 | convert() 106 | -------------------------------------------------------------------------------- /EditSQL/model/bert/data/annotated_wikisql_and_PyTorch_bert_param/bert_config_uncased_L-12_H-768_A-12.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 3072, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 12, 11 | "type_vocab_size": 2, 12 | "vocab_size": 30522 13 | } 14 | -------------------------------------------------------------------------------- /EditSQL/model/embedder.py: -------------------------------------------------------------------------------- 1 | """ Embedder for tokens. """ 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | import EditSQL.data_util.snippets as snippet_handler 7 | import EditSQL.data_util.vocabulary as vocabulary_handler 8 | 9 | class Embedder(torch.nn.Module): 10 | """ Embeds tokens. 
""" 11 | def __init__(self, embedding_size, name="", initializer=None, vocabulary=None, num_tokens=-1, anonymizer=None, freeze=False, use_unk=True): 12 | super().__init__() 13 | 14 | if vocabulary: 15 | assert num_tokens < 0, "Specified a vocabulary but also set number of tokens to " + \ 16 | str(num_tokens) 17 | self.in_vocabulary = lambda token: token in vocabulary.tokens 18 | self.vocab_token_lookup = lambda token: vocabulary.token_to_id(token) 19 | if use_unk: 20 | self.unknown_token_id = vocabulary.token_to_id(vocabulary_handler.UNK_TOK) 21 | else: 22 | self.unknown_token_id = -1 23 | self.vocabulary_size = len(vocabulary) 24 | else: 25 | def check_vocab(index): 26 | """ Makes sure the index is in the vocabulary.""" 27 | assert index < num_tokens, "Passed token ID " + \ 28 | str(index) + "; expecting something less than " + str(num_tokens) 29 | return index < num_tokens 30 | self.in_vocabulary = check_vocab 31 | self.vocab_token_lookup = lambda x: x 32 | self.unknown_token_id = num_tokens # Deliberately throws an error here, 33 | # But should crash before this 34 | self.vocabulary_size = num_tokens 35 | 36 | self.anonymizer = anonymizer 37 | 38 | emb_name = name + "-tokens" 39 | print("Creating token embedder called " + emb_name + " of size " + str(self.vocabulary_size) + " x " + str(embedding_size)) 40 | 41 | if initializer is not None: 42 | word_embeddings_tensor = torch.FloatTensor(initializer) 43 | self.token_embedding_matrix = torch.nn.Embedding.from_pretrained(word_embeddings_tensor, freeze=freeze) 44 | else: 45 | init_tensor = torch.empty(self.vocabulary_size, embedding_size).uniform_(-0.1, 0.1) 46 | self.token_embedding_matrix = torch.nn.Embedding.from_pretrained(init_tensor, freeze=False) 47 | 48 | if self.anonymizer: 49 | emb_name = name + "-entities" 50 | entity_size = len(self.anonymizer.entity_types) 51 | print("Creating entity embedder called " + emb_name + " of size " + str(entity_size) + " x " + str(embedding_size)) 52 | init_tensor = torch.empty(entity_size, embedding_size).uniform_(-0.1, 0.1) 53 | self.entity_embedding_matrix = torch.nn.Embedding.from_pretrained(init_tensor, freeze=False) 54 | 55 | 56 | def forward(self, token): 57 | assert isinstance(token, int) or not snippet_handler.is_snippet(token), "embedder should only be called on flat tokens; use snippet_bow if you are trying to encode snippets" 58 | 59 | if self.in_vocabulary(token): 60 | index_list = torch.LongTensor([self.vocab_token_lookup(token)]) 61 | if self.token_embedding_matrix.weight.is_cuda: 62 | index_list = index_list.cuda() 63 | return self.token_embedding_matrix(index_list).squeeze() 64 | elif self.anonymizer and self.anonymizer.is_anon_tok(token): 65 | index_list = torch.LongTensor([self.anonymizer.get_anon_id(token)]) 66 | if self.token_embedding_matrix.weight.is_cuda: 67 | index_list = index_list.cuda() 68 | return self.entity_embedding_matrix(index_list).squeeze() 69 | else: 70 | index_list = torch.LongTensor([self.unknown_token_id]) 71 | if self.token_embedding_matrix.weight.is_cuda: 72 | index_list = index_list.cuda() 73 | return self.token_embedding_matrix(index_list).squeeze() 74 | 75 | 76 | def bow_snippets(token, snippets, output_embedder, input_schema): 77 | """ Bag of words embedding for snippets""" 78 | assert snippet_handler.is_snippet(token) and snippets 79 | 80 | snippet_sequence = [] 81 | for snippet in snippets: 82 | if snippet.name == token: 83 | snippet_sequence = snippet.sequence 84 | break 85 | assert snippet_sequence 86 | 87 | if input_schema: 88 | snippet_embeddings = 
[] 89 | for output_token in snippet_sequence: 90 | assert output_embedder.in_vocabulary(output_token) or input_schema.in_vocabulary(output_token, surface_form=True) 91 | if output_embedder.in_vocabulary(output_token): 92 | snippet_embeddings.append(output_embedder(output_token)) 93 | else: 94 | snippet_embeddings.append(input_schema.column_name_embedder(output_token, surface_form=True)) 95 | else: 96 | snippet_embeddings = [output_embedder(subtoken) for subtoken in snippet_sequence] 97 | 98 | snippet_embeddings = torch.stack(snippet_embeddings, dim=0) # len(snippet_sequence) x emb_size 99 | return torch.mean(snippet_embeddings, dim=0) # emb_size 100 | 101 | -------------------------------------------------------------------------------- /EditSQL/model/encoder.py: -------------------------------------------------------------------------------- 1 | """ Contains code for encoding an input sequence. """ 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from .torch_utils import create_multilayer_lstm_params, encode_sequence 6 | 7 | class Encoder(torch.nn.Module): 8 | """ Encodes an input sequence. """ 9 | def __init__(self, num_layers, input_size, state_size): 10 | super().__init__() 11 | 12 | self.num_layers = num_layers 13 | self.forward_lstms = create_multilayer_lstm_params(self.num_layers, input_size, state_size / 2, "LSTM-ef") 14 | self.backward_lstms = create_multilayer_lstm_params(self.num_layers, input_size, state_size / 2, "LSTM-eb") 15 | 16 | def forward(self, sequence, embedder, dropout_amount=0.): 17 | """ Encodes a sequence forward and backward. 18 | Inputs: 19 | forward_seq (list of str): The string forwards. 20 | backward_seq (list of str): The string backwards. 21 | f_rnns (list of dy.RNNBuilder): The forward RNNs. 22 | b_rnns (list of dy.RNNBuilder): The backward RNNS. 23 | emb_fn (dict str->dy.Expression): Embedding function for tokens in the 24 | sequence. 25 | size (int): The size of the RNNs. 26 | dropout_amount (float, optional): The amount of dropout to apply. 27 | 28 | Returns: 29 | (list of dy.Expression, list of dy.Expression), list of dy.Expression, 30 | where the first pair is the (final cell memories, final cell states) of 31 | all layers, and the second list is a list of the final layer's cell 32 | state for all tokens in the sequence. 
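Note: the parameter names above (forward_seq, backward_seq, f_rnns, b_rnns, emb_fn, size) and the dy.* types appear to be carried over from the original DyNet implementation. In this PyTorch port the actual arguments are `sequence` (list of str), `embedder` (a callable mapping a token to a tensor), and `dropout_amount` (float), and the returned memories, states, and outputs are torch tensors.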
33 | """ 34 | forward_state, forward_outputs = encode_sequence( 35 | sequence, 36 | self.forward_lstms, 37 | embedder, 38 | dropout_amount=dropout_amount) 39 | 40 | backward_state, backward_outputs = encode_sequence( 41 | sequence[::-1], 42 | self.backward_lstms, 43 | embedder, 44 | dropout_amount=dropout_amount) 45 | 46 | cell_memories = [] 47 | hidden_states = [] 48 | for i in range(self.num_layers): 49 | cell_memories.append(torch.cat([forward_state[0][i], backward_state[0][i]], dim=0)) 50 | hidden_states.append(torch.cat([forward_state[1][i], backward_state[1][i]], dim=0)) 51 | 52 | assert len(forward_outputs) == len(backward_outputs) 53 | 54 | backward_outputs = backward_outputs[::-1] 55 | 56 | final_outputs = [] 57 | for i in range(len(sequence)): 58 | final_outputs.append(torch.cat([forward_outputs[i], backward_outputs[i]], dim=0)) 59 | 60 | return (cell_memories, hidden_states), final_outputs -------------------------------------------------------------------------------- /EditSQL/model/torch_utils.py: -------------------------------------------------------------------------------- 1 | """Contains various utility functions for Dynet models.""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import numpy as np 7 | 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 10 | def linear_layer(exp, weights, biases=None): 11 | # exp: input as size_1 or 1 x size_1 12 | # weight: size_1 x size_2 13 | # bias: size_2 14 | if exp.dim() == 1: 15 | exp = torch.unsqueeze(exp, 0) 16 | assert exp.size()[1] == weights.size()[0] 17 | if biases is not None: 18 | assert weights.size()[1] == biases.size()[0] 19 | result = torch.mm(exp, weights) + biases 20 | else: 21 | result = torch.mm(exp, weights) 22 | return result 23 | 24 | 25 | def compute_loss(gold_seq, 26 | scores, 27 | index_to_token_maps, 28 | gold_tok_to_id, 29 | noise=0.00000001, 30 | weights=None): 31 | """ Computes the loss of a gold sequence given scores. 32 | 33 | Inputs: 34 | gold_seq (list of str): A sequence of gold tokens. 35 | scores (list of dy.Expression): Expressions representing the scores of 36 | potential output tokens for each token in gold_seq. 37 | index_to_token_maps (list of dict str->list of int): Maps from index in the 38 | sequence to a dictionary mapping from a string to a set of integers. 39 | gold_tok_to_id (lambda (str, str)->list of int): Maps from the gold token 40 | and some lookup function to the indices in the probability distribution 41 | where the gold token occurs. 42 | noise (float, optional): The amount of noise to add to the loss. 43 | 44 | Returns: 45 | dy.Expression representing the sum of losses over the sequence. 46 | """ 47 | assert len(gold_seq) == len(scores) == len(index_to_token_maps) 48 | 49 | losses = [] 50 | for i, gold_tok in enumerate(gold_seq): 51 | score = scores[i] 52 | token_map = index_to_token_maps[i] 53 | 54 | gold_indices = gold_tok_to_id(gold_tok, token_map) 55 | assert len(gold_indices) > 0 56 | noise_i = noise 57 | if len(gold_indices) == 1: 58 | noise_i = 0 59 | 60 | probdist = score 61 | prob_of_tok = noise_i + torch.sum(probdist[gold_indices]) 62 | losses.append(-torch.log(prob_of_tok)) 63 | 64 | if weights is not None: 65 | assert len(weights) == len(gold_seq) 66 | return torch.sum(torch.stack(losses) * torch.tensor(weights).to(device)) 67 | 68 | return torch.sum(torch.stack(losses)) 69 | 70 | 71 | def get_seq_from_scores(scores, index_to_token_maps): 72 | """Gets the argmax sequence from a set of scores. 
73 | 74 | Inputs: 75 | scores (list of dy.Expression): Sequences of output scores. 76 | index_to_token_maps (list of list of str): For each output token, maps 77 | the index in the probability distribution to a string. 78 | 79 | Returns: 80 | list of str, representing the argmax sequence. 81 | """ 82 | seq = [] 83 | for score, tok_map in zip(scores, index_to_token_maps): 84 | # score_numpy_list = score.cpu().detach().numpy() 85 | score_numpy_list = score.cpu().data.numpy() 86 | assert score.size()[0] == len(tok_map) == len(list(score_numpy_list)) 87 | seq.append(tok_map[np.argmax(score_numpy_list)]) 88 | return seq 89 | 90 | def per_token_accuracy(gold_seq, pred_seq): 91 | """ Returns the per-token accuracy comparing two strings (recall). 92 | 93 | Inputs: 94 | gold_seq (list of str): A list of gold tokens. 95 | pred_seq (list of str): A list of predicted tokens. 96 | 97 | Returns: 98 | float, representing the accuracy. 99 | """ 100 | num_correct = 0 101 | for i, gold_token in enumerate(gold_seq): 102 | if i < len(pred_seq) and pred_seq[i] == gold_token: 103 | num_correct += 1 104 | 105 | return float(num_correct) / len(gold_seq) 106 | 107 | def forward_one_multilayer(rnns, lstm_input, layer_states, dropout_amount=0.): 108 | """ Goes forward for one multilayer RNN cell step. 109 | 110 | Inputs: 111 | lstm_input (dy.Expression): Some input to the step. 112 | layer_states (list of dy.RNNState): The states of each layer in the cell. 113 | dropout_amount (float, optional): The amount of dropout to apply, in 114 | between the layers. 115 | 116 | Returns: 117 | (list of dy.Expression, list of dy.Expression), dy.Expression, (list of dy.RNNSTate), 118 | representing (each layer's cell memory, each layer's cell hidden state), 119 | the final hidden state, and (each layer's updated RNNState). 120 | """ 121 | num_layers = len(layer_states) 122 | new_states = [] 123 | cell_states = [] 124 | hidden_states = [] 125 | state = lstm_input 126 | for i in range(num_layers): 127 | # view as (1, input_size) 128 | layer_h, layer_c = rnns[i](torch.unsqueeze(state,0), layer_states[i]) 129 | new_states.append((layer_h, layer_c)) 130 | 131 | layer_h = layer_h.squeeze() 132 | layer_c = layer_c.squeeze() 133 | 134 | state = layer_h 135 | if i < num_layers - 1: 136 | # In both Dynet and Pytorch 137 | # p stands for probability of an element to be zeroed. i.e. p=1 means switch off all activations. 138 | state = F.dropout(state, p=dropout_amount) 139 | 140 | cell_states.append(layer_c) 141 | hidden_states.append(layer_h) 142 | 143 | return (cell_states, hidden_states), state, new_states 144 | 145 | 146 | def encode_sequence(sequence, rnns, embedder, dropout_amount=0.): 147 | """ Encodes a sequence given RNN cells and an embedding function. 148 | 149 | Inputs: 150 | seq (list of str): The sequence to encode. 151 | rnns (list of dy._RNNBuilder): The RNNs to use. 152 | emb_fn (dict str->dy.Expression): Function that embeds strings to 153 | word vectors. 154 | size (int): The size of the RNN. 155 | dropout_amount (float, optional): The amount of dropout to apply. 156 | 157 | Returns: 158 | (list of dy.Expression, list of dy.Expression), list of dy.Expression, 159 | where the first pair is the (final cell memories, final cell states) of 160 | all layers, and the second list is a list of the final layer's cell 161 | state for all tokens in the sequence. 
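Note: as in the other docstrings of this file, the dy.* names above are legacy. The actual arguments are `sequence` (list of str), `rnns` (a ModuleList of torch.nn.LSTMCell, as built by create_multilayer_lstm_params), `embedder` (a callable mapping a token to a tensor), and `dropout_amount` (float); the returned states and outputs are torch tensors.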
162 | """ 163 | 164 | batch_size = 1 165 | layer_states = [] 166 | for rnn in rnns: 167 | hidden_size = rnn.weight_hh.size()[1] 168 | 169 | # h_0 of shape (batch, hidden_size) 170 | # c_0 of shape (batch, hidden_size) 171 | if rnn.weight_hh.is_cuda: 172 | h_0 = torch.cuda.FloatTensor(batch_size,hidden_size).fill_(0) 173 | c_0 = torch.cuda.FloatTensor(batch_size,hidden_size).fill_(0) 174 | else: 175 | h_0 = torch.zeros(batch_size,hidden_size) 176 | c_0 = torch.zeros(batch_size,hidden_size) 177 | 178 | layer_states.append((h_0, c_0)) 179 | 180 | outputs = [] 181 | for token in sequence: 182 | rnn_input = embedder(token) 183 | (cell_states, hidden_states), output, layer_states = forward_one_multilayer(rnns,rnn_input,layer_states,dropout_amount) 184 | 185 | outputs.append(output) 186 | 187 | return (cell_states, hidden_states), outputs 188 | 189 | def create_multilayer_lstm_params(num_layers, in_size, state_size, name=""): 190 | """ Adds a multilayer LSTM to the model parameters. 191 | 192 | Inputs: 193 | num_layers (int): Number of layers to create. 194 | in_size (int): The input size to the first layer. 195 | state_size (int): The size of the states. 196 | model (dy.ParameterCollection): The parameter collection for the model. 197 | name (str, optional): The name of the multilayer LSTM. 198 | """ 199 | lstm_layers = [] 200 | for i in range(num_layers): 201 | layer_name = name + "-" + str(i) 202 | print("LSTM " + layer_name + ": " + str(in_size) + " x " + str(state_size) + "; default Dynet initialization of hidden weights") 203 | lstm_layer = torch.nn.LSTMCell(input_size=int(in_size), hidden_size=int(state_size), bias=True) 204 | lstm_layers.append(lstm_layer) 205 | in_size = state_size 206 | return torch.nn.ModuleList(lstm_layers) 207 | 208 | def add_params(size, name=""): 209 | """ Adds parameters to the model. 210 | 211 | Inputs: 212 | model (dy.ParameterCollection): The parameter collection for the model. 213 | size (tuple of int): The size to create. 214 | name (str, optional): The name of the parameters. 215 | """ 216 | if len(size) == 1: 217 | print("vector " + name + ": " + str(size[0]) + "; uniform in [-0.1, 0.1]") 218 | else: 219 | print("matrix " + name + ": " + str(size[0]) + " x " + str(size[1]) + "; uniform in [-0.1, 0.1]") 220 | 221 | size_int = tuple([int(ss) for ss in size]) 222 | return torch.nn.Parameter(torch.empty(size_int).uniform_(-0.1, 0.1)) 223 | -------------------------------------------------------------------------------- /EditSQL/parse_args.py: -------------------------------------------------------------------------------- 1 | import sys 2 | args = sys.argv 3 | 4 | import os 5 | import argparse 6 | 7 | def interpret_args(): 8 | """ Interprets the command line arguments, and returns a dictionary. 
""" 9 | parser = argparse.ArgumentParser() 10 | 11 | parser.add_argument("--no_gpus", type=bool, default=1) 12 | 13 | ### Data parameters 14 | parser.add_argument( 15 | '--raw_train_filename', 16 | type=str, 17 | default='../atis_data/data/resplit/processed/train_with_tables.pkl') 18 | parser.add_argument( 19 | '--raw_dev_filename', 20 | type=str, 21 | default='../atis_data/data/resplit/processed/dev_with_tables.pkl') 22 | parser.add_argument( 23 | '--raw_validation_filename', 24 | type=str, 25 | default='../atis_data/data/resplit/processed/valid_with_tables.pkl') 26 | parser.add_argument( 27 | '--raw_test_filename', 28 | type=str, 29 | default='../atis_data/data/resplit/processed/test_with_tables.pkl') 30 | 31 | parser.add_argument('--data_directory', type=str, default='processed_data') 32 | 33 | parser.add_argument('--processed_train_filename', type=str, default='train.pkl') 34 | parser.add_argument('--processed_dev_filename', type=str, default='dev.pkl') 35 | parser.add_argument('--processed_validation_filename', type=str, default='validation.pkl') 36 | parser.add_argument('--processed_test_filename', type=str, default='test.pkl') 37 | 38 | parser.add_argument('--database_schema_filename', type=str, default=None) 39 | parser.add_argument('--embedding_filename', type=str, default=None) 40 | 41 | parser.add_argument('--input_vocabulary_filename', type=str, default='input_vocabulary.pkl') 42 | parser.add_argument('--output_vocabulary_filename', 43 | type=str, 44 | default='output_vocabulary.pkl') 45 | 46 | parser.add_argument('--input_key', type=str, default='nl_with_dates') 47 | 48 | parser.add_argument('--anonymize', type=bool, default=False) 49 | parser.add_argument('--anonymization_scoring', type=bool, default=False) 50 | parser.add_argument('--use_snippets', type=bool, default=False) 51 | 52 | parser.add_argument('--use_previous_query', type=bool, default=False) 53 | parser.add_argument('--maximum_queries', type=int, default=1) 54 | parser.add_argument('--use_copy_switch', type=bool, default=False) 55 | parser.add_argument('--use_query_attention', type=bool, default=False) 56 | 57 | parser.add_argument('--use_utterance_attention', type=bool, default=False) 58 | 59 | parser.add_argument('--freeze', type=bool, default=False) 60 | parser.add_argument('--scheduler', type=bool, default=False) 61 | 62 | parser.add_argument('--use_bert', type=bool, default=False) 63 | parser.add_argument("--bert_type_abb", type=str, help="Type of BERT model to load. e.g.) 
uS, uL, cS, cL, and mcS") 64 | parser.add_argument("--bert_input_version", type=str, default='v1') 65 | parser.add_argument('--fine_tune_bert', type=bool, default=False) 66 | parser.add_argument('--lr_bert', default=1e-5, type=float, help='BERT model learning rate.') 67 | 68 | ### Debugging/logging parameters 69 | parser.add_argument('--logdir', type=str, default='logs') 70 | parser.add_argument('--deterministic', type=bool, default=False) 71 | parser.add_argument('--num_train', type=int, default=-1) 72 | 73 | parser.add_argument('--logfile', type=str, default='log.txt') 74 | parser.add_argument('--results_file', type=str, default='results.txt') 75 | 76 | ### Model architecture 77 | parser.add_argument('--input_embedding_size', type=int, default=300) 78 | parser.add_argument('--output_embedding_size', type=int, default=300) 79 | 80 | parser.add_argument('--encoder_state_size', type=int, default=300) 81 | parser.add_argument('--decoder_state_size', type=int, default=300) 82 | 83 | parser.add_argument('--encoder_num_layers', type=int, default=1) 84 | parser.add_argument('--decoder_num_layers', type=int, default=2) 85 | parser.add_argument('--snippet_num_layers', type=int, default=1) 86 | 87 | parser.add_argument('--maximum_utterances', type=int, default=5) 88 | parser.add_argument('--state_positional_embeddings', type=bool, default=False) 89 | parser.add_argument('--positional_embedding_size', type=int, default=50) 90 | 91 | parser.add_argument('--snippet_age_embedding', type=bool, default=False) 92 | parser.add_argument('--snippet_age_embedding_size', type=int, default=64) 93 | parser.add_argument('--max_snippet_age_embedding', type=int, default=4) 94 | parser.add_argument('--previous_decoder_snippet_encoding', type=bool, default=False) 95 | 96 | parser.add_argument('--discourse_level_lstm', type=bool, default=False) 97 | 98 | parser.add_argument('--use_schema_attention', type=bool, default=False) 99 | parser.add_argument('--use_encoder_attention', type=bool, default=False) 100 | 101 | parser.add_argument('--use_schema_encoder', type=bool, default=False) 102 | parser.add_argument('--use_schema_self_attention', type=bool, default=False) 103 | parser.add_argument('--use_schema_encoder_2', type=bool, default=False) 104 | 105 | ### Training parameters 106 | parser.add_argument('--batch_size', type=int, default=16) 107 | parser.add_argument('--train_maximum_sql_length', type=int, default=200) 108 | parser.add_argument('--train_evaluation_size', type=int, default=100) 109 | 110 | parser.add_argument('--dropout_amount', type=float, default=0.5) 111 | 112 | parser.add_argument('--initial_patience', type=float, default=10.) 
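    # Note on the many `type=bool` flags above and below: argparse applies bool() to the
    # raw command-line string, so any non-empty value (including the literal string
    # "False") is parsed as True, while an empty string or an omitted flag keeps the default.
    # If stricter parsing were ever needed, a converter along these lines could be used
    # instead (illustrative sketch only -- `str2bool` is not part of this codebase):
    #
    #     def str2bool(value):
    #         """Map common truthy/falsy strings to a bool; reject anything else."""
    #         if value.lower() in ("yes", "true", "t", "1"):
    #             return True
    #         if value.lower() in ("no", "false", "f", "0"):
    #             return False
    #         raise argparse.ArgumentTypeError("Boolean value expected, got %r" % value)
    #
    #     parser.add_argument('--interaction_level', type=str2bool, default=False)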
113 | parser.add_argument('--patience_ratio', type=float, default=1.01) 114 | 115 | parser.add_argument('--initial_learning_rate', type=float, default=0.001) 116 | parser.add_argument('--learning_rate_ratio', type=float, default=0.8) 117 | 118 | parser.add_argument('--interaction_level', type=bool, default=False) 119 | parser.add_argument('--reweight_batch', type=bool, default=False) 120 | 121 | ### Setting 122 | parser.add_argument('--train', type=bool, default=False) 123 | parser.add_argument('--debug', type=bool, default=False) 124 | 125 | parser.add_argument('--evaluate', type=bool, default=False) 126 | parser.add_argument('--attention', type=bool, default=False) 127 | parser.add_argument('--save_file', type=str, default="") 128 | parser.add_argument('--enable_testing', type=bool, default=False) 129 | parser.add_argument('--use_predicted_queries', type=bool, default=False) 130 | parser.add_argument('--evaluate_split', type=str, default='dev') 131 | parser.add_argument('--evaluate_with_gold_forcing', type=bool, default=False) 132 | parser.add_argument('--eval_maximum_sql_length', type=int, default=1000) 133 | parser.add_argument('--results_note', type=str, default='') 134 | parser.add_argument('--compute_metrics', type=bool, default=False) 135 | 136 | parser.add_argument('--reference_results', type=str, default='') 137 | 138 | parser.add_argument('--interactive', type=bool, default=False) 139 | 140 | parser.add_argument('--database_username', type=str, default="aviarmy") 141 | parser.add_argument('--database_password', type=str, default="aviarmy") 142 | parser.add_argument('--database_timeout', type=int, default=2) 143 | 144 | args = parser.parse_args() 145 | 146 | if not os.path.exists(args.logdir): 147 | os.makedirs(args.logdir) 148 | 149 | if not (args.train or args.evaluate or args.interactive or args.attention): 150 | raise ValueError('You need to be training or evaluating') 151 | if args.enable_testing and not args.evaluate: 152 | raise ValueError('You should evaluate the model if enabling testing') 153 | 154 | if args.train: 155 | args_file = args.logdir + '/args.log' 156 | if os.path.exists(args_file): 157 | raise ValueError('Warning: arguments already exist in ' + str(args_file)) 158 | with open(args_file, 'w') as infile: 159 | infile.write(str(args)) 160 | 161 | return args 162 | -------------------------------------------------------------------------------- /EditSQL/question_gen.py: -------------------------------------------------------------------------------- 1 | from MISP_SQL.question_gen import QuestionGenerator as BaseQuestionGenerator 2 | from MISP_SQL.utils import WHERE_COL, GROUP_COL, ORDER_AGG_v2, HAV_AGG_v2 3 | 4 | 5 | class QuestionGenerator(BaseQuestionGenerator): 6 | def __init__(self, bool_structure_question=False): 7 | BaseQuestionGenerator.__init__(self) 8 | self.bool_structure_question = bool_structure_question 9 | 10 | def option_generation(self, cand_semantic_units, old_tag_seq, pointer): 11 | question, cheat_sheet, sel_none_of_above = BaseQuestionGenerator.option_generation( 12 | self, cand_semantic_units, old_tag_seq, pointer) 13 | 14 | if self.bool_structure_question: 15 | semantic_tag = old_tag_seq[pointer][0] 16 | 17 | if semantic_tag == WHERE_COL: 18 | question = question[:-1] + ';\n' 19 | sel_invalid_structure = sel_none_of_above + 1 20 | question += "(%d) The system does not need to consider any conditions." 
% sel_invalid_structure 21 | 22 | elif semantic_tag == GROUP_COL: 23 | question = question[:-1] + ';\n' 24 | sel_invalid_structure = sel_none_of_above + 1 25 | question += "(%d) The system does not need to group any items." % sel_invalid_structure 26 | 27 | elif semantic_tag == HAV_AGG_v2: 28 | question = question[:-1] + ';\n' 29 | sel_invalid_structure = sel_none_of_above + 1 30 | question += "(%d) The system does not need to consider any conditions." % sel_invalid_structure 31 | 32 | elif semantic_tag == ORDER_AGG_v2: 33 | question = question[:-1] + ';\n' 34 | sel_invalid_structure = sel_none_of_above + 1 35 | question += "(%d) The system does not need to order the results." % sel_invalid_structure 36 | 37 | return question, cheat_sheet, sel_none_of_above 38 | 39 | 40 | -------------------------------------------------------------------------------- /EditSQL/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.0.1 2 | sqlparse 3 | pymysql 4 | progressbar 5 | nltk -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Ziyu Yao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MISP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunlab-osu/MISP/7870566ab6b9e121d648478968367bc79c12f7ef/MISP.png -------------------------------------------------------------------------------- /MISP_SQL/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunlab-osu/MISP/7870566ab6b9e121d648478968367bc79c12f7ef/MISP_SQL/__init__.py -------------------------------------------------------------------------------- /MISP_SQL/error_detector.py: -------------------------------------------------------------------------------- 1 | # Error detector 2 | from .utils import semantic_unit_segment, np 3 | 4 | 5 | class ErrorDetector: 6 | """ 7 | This is the class for Error Detector. 8 | """ 9 | def __init__(self): 10 | return 11 | 12 | def detection(self, tag_seq, start_pos=0, bool_return_first=False, *args, **kwargs): 13 | """ 14 | Error detection. 15 | :param tag_seq: a sequence of semantic units. 
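            (Each unit follows the format documented in MISP_SQL/tag_seq_logic.md.)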
16 | :param start_pos: the starting pointer to examine. 17 | :param bool_return_first: Set to True to return the first error only. 18 | :return: a list of pairs of (erroneous semantic unit, its position in tag_seq). 19 | """ 20 | raise NotImplementedError 21 | 22 | 23 | class ErrorDetectorSim(ErrorDetector): 24 | """ 25 | This is a simulated error detector which always detects the exact wrong decisions. 26 | """ 27 | def __init__(self): 28 | ErrorDetector.__init__(self) 29 | 30 | def detection(self, tag_seq, start_pos=0, bool_return_first=False, eval_tf=None, *args, **kwargs): 31 | if start_pos >= len(tag_seq): 32 | return [] 33 | 34 | semantic_units, pointers = semantic_unit_segment(tag_seq) 35 | err_su_pointer_pairs = [] 36 | for semantic_unit, pointer in zip(semantic_units, pointers): 37 | if pointer < start_pos: 38 | continue 39 | 40 | bool_correct = eval_tf[pointer] 41 | if not bool_correct: 42 | err_su_pointer_pairs.append((semantic_unit, pointer)) 43 | if bool_return_first: 44 | return err_su_pointer_pairs 45 | 46 | return err_su_pointer_pairs 47 | 48 | 49 | class ErrorDetectorProbability(ErrorDetector): 50 | """ 51 | This is the probability-based error detector. 52 | """ 53 | def __init__(self, threshold): 54 | """ 55 | Constructor of the probability-based error detector. 56 | :param threshold: A float number; the probability threshold. 57 | """ 58 | ErrorDetector.__init__(self) 59 | self.prob_threshold = threshold 60 | 61 | def detection(self, tag_seq, start_pos=0, bool_return_first=False, *args, **kwargs): 62 | if start_pos >= len(tag_seq): 63 | return [] 64 | 65 | semantic_units, pointers = semantic_unit_segment(tag_seq) 66 | err_su_pointer_pairs = [] 67 | for semantic_unit, pointer in zip(semantic_units, pointers): 68 | if pointer < start_pos: 69 | continue 70 | 71 | prob = semantic_unit[-2] 72 | # if the decision's probability is lower than the threshold, consider it as an error 73 | if prob < self.prob_threshold: 74 | err_su_pointer_pairs.append((semantic_unit, pointer)) 75 | if bool_return_first: 76 | return err_su_pointer_pairs 77 | 78 | return err_su_pointer_pairs 79 | 80 | 81 | class ErrorDetectorBayesDropout(ErrorDetector): 82 | """ 83 | This is the Bayesian Dropout-based error detector. 84 | """ 85 | def __init__(self, threshold): 86 | """ 87 | Constructor of the Bayesian Dropout-based error detector. 88 | :param threshold: A float number; the standard deviation threshold. 89 | """ 90 | ErrorDetector.__init__(self) 91 | self.stddev_threshold = threshold 92 | 93 | def detection(self, tag_seq, start_pos=0, bool_return_first=False, *args, **kwargs): 94 | if start_pos >= len(tag_seq): 95 | return [] 96 | 97 | semantic_units, pointers = semantic_unit_segment(tag_seq) 98 | err_su_pointer_pairs = [] 99 | for semantic_unit, pointer in zip(semantic_units, pointers): 100 | if pointer < start_pos: 101 | continue 102 | 103 | # if the decision's stddev is greater than the threshold, consider it as an error 104 | stddev = np.std(semantic_unit[-2]) 105 | if stddev > self.stddev_threshold: 106 | err_su_pointer_pairs.append((semantic_unit, pointer)) 107 | if bool_return_first: 108 | return err_su_pointer_pairs 109 | 110 | return err_su_pointer_pairs 111 | 112 | -------------------------------------------------------------------------------- /MISP_SQL/semantic_tag_logic.txt: -------------------------------------------------------------------------------- 1 | # Tagging logic 2 | # Note that different base parsers can have different tagging logics, depending on how their decoders work. 
3 | # However, in most time the tagging order, e.g., WHERE_COL goes before WHERE_OP and WHERE_VAL, does not change. 4 | # The following document uses the working mechanism of SQLNet/SQLova/SyntaxSQLNet to illustrate the tagging logic. 5 | # You can find an example in https://github.com/sunlab-osu/MISP#21-example 6 | # 7 | # col = (tab_name, col_name, col_idx) 8 | # agg = (agg_name, agg_idx) # agg_name in {'min', 'max', 'count', ..., "none_agg"} 9 | # op = (op_name, op_idx) # op_name in {'>', '<', ...} 10 | # iuen = (iuen_name, iuen_idx) # iuen_name in {'none', 'intersect', 'union', 'except'} 11 | # desc_asc_limit = ('asc'/'desc', True/False for limit) 12 | # 13 | # "dec_idx" refers to the position of this decision in the dec_seq (your decoding sequence). 14 | # 15 | # ('O', ('root', None), 1.0, None), ('IUEN', iuen, prob, dec_idx) 16 | # 17 | # When it is 'intersect'/'union'/'except' in iuen: 18 | # ('IUEN', iuen, prob, dec_idx), <- here's the main sql-> ('O', '##END_NESTED##', 1.0, None) 19 | # <-followed by the nested sql to intersect/union/except with-> 20 | # 21 | # SELECT COL: 22 | # (O, "select", 1.0, dec_idx), (SELECT_COL, col1, prob1, dec_idx1), (SELECT_COL, col2, prob2, dec_idx2), .. 23 | # For each col: 24 | # (SELECT_AGG, col1, agg1, prob1 of agg1, dec_idx1 of agg1), (SELECT_AGG, col1, agg2, prob2 of agg2, dec_idx2 of agg2), .. 25 | # 26 | # WHERE: 27 | # (O, "where", prob of where clause, prob of #col, dec_idx), 28 | # (WHERE_COL, col1, prob1, dec_idx1), (WHERE_COL, col2, prob2, dec_idx2) .., 29 | # (ANDOR, "and"/"or", [col1, col2, ..], andor_prob, dec_idx)#when multiple cols selected 30 | # For each col: 31 | # (WHERE_OP, (col1,), op1, prob1 of op1, dec_idx1), (WHERE_OP, (col1,), op2, prob2 of op2, dec_idx2) 32 | # For each (col, op): 33 | # (WHERE_ROOTTERM, (col,), op, 'root'/'terminal', prob, dec_idx) for Spider or (WHERE_VAL, (col,), op, (val_idx, val_str), prob, dec_idx) for WikiSQL 34 | # 35 | # GROUP: 36 | # (O, "groupBy", prob of group_by clause, prob of #col), (GROUP_COL, col1, prob1, dec_idx1), (GROUP_COL, col2, prob2, dec_idx2), .. 37 | # (GROUP_NHAV, "none_having", prob, dec_idx) #end of groupBy 38 | # or (O, "having", prob, dec_idx), (HAV_COL, col1, prob1, dec_idx1), (HAV_COL, col2, prob2, dec_idx2), .. 39 | # For each col: 40 | # (HAV_AGG, col, agg, prob of agg, dec_idx of agg), (HAV_OP, (col, agg), op1, prob1 of op1, dec_idx1), (HAV_OP, (col, agg), op2, prob2 of op2, dec_idx2), .. 41 | # For each op: 42 | # (HAV_ROOTTERM, (col, agg), op, 'root'/'terminal', prob, dec_idx) 43 | # 44 | # ORDER: 45 | # (O, "orderBy", prob of order_by clause, dec_idx), (ORDER_COL, col1, prob1, dec_idx1), (ORDER_COL, col2, prob2, dec_idx2), .. 46 | # For each col: 47 | # (ORDER_AGG, col, agg, prob of agg, dec_idx), (ORDER_DESC_ASC_LIMIT, (col, agg), desc_asc_limit, prob of desc_asc_limit, dec_idx) 48 | -------------------------------------------------------------------------------- /MISP_SQL/tag_seq_logic.md: -------------------------------------------------------------------------------- 1 | # Documentation for `tag_seq` 2 | 3 | ## What is `tag_seq`? 4 | A `tag_seq` records the semantic meaning of the parser's every decision, including: 5 | - its category (e.g., `SELECT_AGG`), 6 | - content (e.g., the specific aggregator), 7 | - contextual information (e.g., the corresponding column of this aggregator), and 8 | - some meta information (e.g., the index of this decision in action space, the decision probability). 9 | 10 | Each item in a `tag_seq` is usually called a `semantic unit`. 
A semantic unit can be about the parser's decisions on `SELECT_COL`, `SELECT_AGG`, etc. 11 | 12 | ## How should `tag_seq` be generated? 13 | A `tag_seq` is usually generated while decoding a SQL query, especially when the decoding is grammar-based (so that, for instance, one can easily know the category and meta information). 14 | See [SQLNet's tag_seq implementation](https://github.com/sunlab-osu/MISP/blob/multichoice_q/SQLNet_model/sqlnet/model/sqlnet.py#L247). 15 | 16 | ## How will `tag_seq` be used? 17 | A `tag_seq` will be fed to the `Error Detector (ED)` to find potential mistakes and `Question Generator (QG)` to form questions. 18 | 19 | Basically, the ED module reads the `semantic units` one by one and decides whether the decision is likely to be wrong by examining its meta information. 20 | 21 | When a semantic unit is deemed wrong, it will be passed to QG. QG generates a question about this decision by associating the category and content of this decision with its contextual information - all have been recorded in the `tag_seq`! 22 | 23 | 24 | ## Semantic Unit Definitions 25 | Our framework already defines a comprehensive list of semantic units for common use, as shown below. All _tags (e.g., SELECT_COL, OUTSIDE)_ are defined [here](MISP_SQL/utils.py#L6). 26 | 27 | Note that one can also define their own units to meet different needs. We will give some examples. 28 | 29 | ### Basics 30 | - Column `col`: `col = (tab_name, col_name, col_idx)` # `col_idx` is the index of this column in the column/action space. Such indices will be used to compare with the golden query (when simulating user feedback). 31 | - Aggregator `agg`: `agg = (agg_name, agg_idx) # agg_name in {'min', 'max', 'count', ..., "none_agg" # empty agg}` 32 | - Operator `op`: `op = (op_name, op_idx) # op_name in {'>', '<', ...}` 33 | - Ordering `desc_asc_limit`: `desc_asc_limit = ('asc'/'desc', True/False for limit)` 34 | - Intersect, Union, Except, None: `iuen = (iuen_name, iuen_idx) # iuen_name in {'none', 'intersect', 'union', 'except'}` 35 | 36 | ### SELECT clause 37 | - `(SELECT_COL, col, p(col), dec_idx)` # note that `col` has to follow the aforementioned definition; `dec_idx` is the index of this decision in `dec_seq` (see [the introduction](https://github.com/sunlab-osu/MISP#2-system-architecture)). 38 | - `(SELECT_AGG, col, agg, p(agg), dec_idx)` # note that `agg` has to follow the aforementioned definition; `dec_idx` is the index of this decision (predicting `agg`, not `col`) in `dec_seq`. 39 | 40 | ### WHERE clause 41 | - `(WHERE_COL, col, p(col), dec_idx)` 42 | - `(WHERE_OP, (col,), op, p(op), dec_idx)` 43 | - `(WHERE_VAL, (col,), op, (val_idx, val_str), p(val_str), dec_idx)` # used in WikiSQL; val_idx is the list of word indices 44 | For Spider: 45 | - `(WHERE_ROOTTERM, (col,), op, 'root'/'terminal', p('root'/'terminal'), dec_idx)` # used in Spider 46 | - `(ANDOR, 'and'/'or', [col1, col2, ..], p('and'/'or'), dec_idx)` # [col1, col2, ..] are columns selected in WHERE clause 47 | 48 | ### GROUP BY and HAVING clause 49 | GROUP BY: 50 | - `(GROUP_COL, col, p(col), dec_idx)` 51 | (Optional; model-specific) We also added the following definitions for SyntaxSQLNet: 52 | - `(GROUP_NHAV, "none_having", p("none_having"), dec_idx)` # SyntaxSQLNet has a particular decision on whether to add a HAVING clause; we thus define this unit so this decision can be validated by users as well. 
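As a purely illustrative example (hypothetical table/column indices, probabilities and `dec_idx` values; the interleaved `OUTSIDE`/`O` units that some parsers emit are omitted), the `tag_seq` fragment for `SELECT dept, COUNT(id) FROM employee GROUP BY dept` could look like:
```
(SELECT_COL, ('employee', 'dept', 1), 0.93, 0), (SELECT_AGG, ('employee', 'dept', 1), ('none_agg', 0), 0.90, 1),
(SELECT_COL, ('employee', 'id', 0), 0.85, 2), (SELECT_AGG, ('employee', 'id', 0), ('count', 3), 0.88, 3),
(GROUP_COL, ('employee', 'dept', 1), 0.72, 4), (GROUP_NHAV, 'none_having', 0.64, 5)
```
(The final `GROUP_NHAV` unit follows the SyntaxSQLNet-specific convention just described.)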
53 | 54 | Note that, the following units about HAVING have to be placed after GROUP BY: 55 | - `(HAV_COL, col, p(col), dec_idx)` 56 | - `(HAV_AGG, col, agg, p(agg), dec_idx)` 57 | - `(HAV_OP, (col, agg), op, p(op), dec_idx)` 58 | - `(HAV_ROOTTERM, (col, agg), op, 'root'/'terminal', p('root'/'terminal'), dec_idx)` 59 | 60 | ### ORDER BY clause 61 | - `(ORDER_COL, col, p(col), dec_idx)` 62 | - `(ORDER_AGG, col, agg, p(agg), dec_idx)` 63 | - `(ORDER_DESC_ASC_LIMIT, (col, agg), desc_asc_limit, p(desc_asc_limit), dec_idx)` 64 | 65 | 66 | ### Intersect, Union, Except, None 67 | - `('IUEN', iuen, p(iuen), dec_idx)` # iuen 68 | 69 | ### Nested queries 70 | **Case 1:** 71 | Our framework allows generating questions for nested queries, such as `SELECT ... WHERE col1 = ( <- this is a nested query -> )`. 72 | However, one has to append a unit `(OUTSIDE, '##END_NESTED##', 1.0, None)` after completing a nested query: 73 | ``` 74 | <- here are units for the main sql-> .. (WHERE_OP, (col,), op, p(op), dec_idx), 75 | (WHERE_ROOTTERM, (col,), op, 'root', p('root'), dec_idx), <- here are units for the nested sql -> (OUTSIDE, '##END_NESTED##', 1.0, None), 76 | <- here are units for the demaining main sql-> 77 | ``` 78 | 79 | **Case 2:** 80 | Nested queries can also happen to queries with Intersect/Union/Except. 81 | 82 | For models like SyntaxSQLNet which first decides `iuen` then generates the main or nested query, its `tag_seq` should look like: 83 | ``` 84 | ('IUEN', iuen, p(iuen), dec_idx), <- here are units for the main sql -> (OUTSIDE, '##END_NESTED##', 1.0, None), 85 | <- followed by units for the nested sql -> 86 | ``` 87 | 88 | For models decoding a SQL query token-by-token (e.g., EditSQL), `##END_NESTED##` may not be necessary, and the `tag_seq` can look like: 89 | ``` 90 | <- here are units for the main sql -> 91 | ('IUEN', iuen, p(iuen), dec_idx) 92 | <- followed by units for the nested sql -> 93 | ``` 94 | 95 | -------------------------------------------------------------------------------- /MISP_SQL/utils.py: -------------------------------------------------------------------------------- 1 | # utils 2 | import numpy as np 3 | import copy 4 | 5 | SELECT_COL = 'SELECT_COL' 6 | SELECT_AGG = 'SELECT_AGG' 7 | WHERE_COL = 'WHERE_COL' 8 | WHERE_OP = 'WHERE_OP' 9 | WHERE_VAL = 'WHERE_VAL' # for models with value prediction 10 | 11 | # spider 12 | WHERE_ROOT_TERM = 'WHERE_ROOT_TERM' 13 | ANDOR = 'ANDOR' 14 | GROUP_COL = 'GROUP_COL' 15 | GROUP_NHAV = 'GROUP_NHAV' 16 | HAV_COL = 'HAV_COL' 17 | HAV_AGG = 'HAV_AGG' 18 | HAV_OP = 'HAV_OP' 19 | HAV_ROOT_TERM = 'HAV_ROOT_TERM' 20 | ORDER_COL = 'ORDER_COL' 21 | ORDER_AGG = 'ORDER_AGG' 22 | ORDER_DESC_ASC_LIMIT = 'DESC_ASC_LIMIT' 23 | IUEN = 'IUEN' 24 | OUTSIDE = "O" 25 | END_NESTED = "##END_NESTED##" 26 | 27 | # spider -> editsql 28 | ORDER_DESC_ASC = 'ORDER_DESC_ASC' # (ORDER_DESC_ASC, (col, agg, bool_distinct), desc_asc, p(desc_asc), dec_idx) 29 | ORDER_LIMIT = 'ORDER_LIMIT' # (ORDER_DESC_ASC, (col, agg, bool_distinct), bool_limit, p(limit), dec_idx) 30 | SELECT_AGG_v2 = 'SELECT_AGG_v2' # (SELECT_AGG_v2, col, agg, bool_distinct, avg_prob, dec_idx) 31 | ORDER_AGG_v2 = 'ORDER_AGG_v2' 32 | HAV_AGG_v2 = 'HAV_AGG_v2' 33 | HAV_OP_v2 = 'HAV_OP_v2' # (HAV_OP_v2, (col, agg, bool_distinct), op, prob(op), dec_idx) 34 | HAV_ROOT_TERM_v2 = 'HAV_ROOT_TERM_v2' # # (HAV_OP_v2, (col, agg, bool_distinct), op, 'root'/'terminal', prob, dec_idx) 35 | IUEN_v2 = 'IUEN_v2' 36 | 37 | 38 | def semantic_unit_segment(tag_seq): 39 | tag_item_lists, seg_pointers = [], [] 40 | for idx, 
tag_item in enumerate(tag_seq): 41 | if tag_item[0] != OUTSIDE: 42 | tag_item_lists.append(tag_item) 43 | seg_pointers.append(idx) 44 | return tag_item_lists, seg_pointers 45 | 46 | 47 | def helper_find_closest_bw(tag_seq, start_idx, tgt_name=None, tgt_id=None): 48 | skip_nested = [] 49 | idx = start_idx 50 | while idx > 0: 51 | if len(skip_nested) > 0: 52 | if "root" in tag_seq[idx]: 53 | _ = skip_nested.pop() 54 | idx -= 1 55 | else: 56 | if (tgt_name is not None and tgt_name in tag_seq[idx]) or\ 57 | (tgt_id is not None and tag_seq[idx][0] == tgt_id): #include tgt_name == END_NESTED 58 | return idx 59 | elif END_NESTED in tag_seq[idx]: 60 | skip_nested.append(idx) 61 | idx -= 1 62 | else: 63 | idx -= 1 64 | 65 | return -1 # not found 66 | 67 | 68 | class bcolors: 69 | """ 70 | Usage: print bcolors.WARNING + "Warning: No active frommets remain. Continue?" + bcolors.ENDC 71 | """ 72 | PINK = '\033[95m' 73 | BLUE = '\033[94m' 74 | GREEN = '\033[92m' 75 | YELLOW = '\033[93m' 76 | RED = '\033[91m' 77 | ENDC = '\033[0m' 78 | BOLD = '\033[1m' 79 | UNDERLINE = '\033[4m' 80 | 81 | 82 | class Hypothesis: 83 | def __init__(self, dec_prefix): 84 | self.sql = None 85 | # Note: do not create hyp from scratch during decoding (may lead to wrong self.dec_prefix) 86 | self.dec_prefix = list(dec_prefix) # given decoding prefix, must execute 87 | 88 | self.tag_seq = [] # sequence of tags 89 | self.dec_seq = [] # sequence of decisions 90 | self.dec_seq_idx = 0 91 | 92 | self.logprob = 0.0 93 | self.length = 0 94 | self.logprob_list = None 95 | 96 | self.pred_aux_seq = [] # auxiliary information 97 | 98 | def copy(self): 99 | return copy.deepcopy(self) 100 | 101 | def add_logprob(self, logprob): 102 | self.logprob += logprob 103 | self.length += 1 104 | 105 | def set_passes_mode(self, dropout_hyp): 106 | self.test_tag_seq = list(self.tag_seq) # from decode without dropout 107 | 108 | for tag_idx, tag in enumerate(dropout_hyp.tag_seq): 109 | item_lst = list(tag) 110 | item_lst[-2] = [item_lst[-2]] 111 | self.tag_seq[tag_idx] = item_lst 112 | 113 | self.logprob_list = [dropout_hyp.logprob] 114 | 115 | def merge_hyp(self, hyp): 116 | # tag_seq, dec_seq, dec_seq_idx, logprob 117 | assert len(hyp.tag_seq) == len(self.tag_seq) 118 | for item_idx in range(len(hyp.tag_seq)): 119 | new_item = hyp.tag_seq[item_idx] 120 | self.tag_seq[item_idx][-2].append(new_item[-2]) 121 | 122 | self.logprob_list.append(hyp.logprob) 123 | 124 | @staticmethod 125 | def length_penalty(sent_length, length_penalty_factor): 126 | # Following: https://arxiv.org/abs/1609.08144, Eqn 14, recommend factor = 0.6-0.7. 127 | # return ((5. + sent_length) / 6.) 
** length_penalty_factor 128 | return (1.0 * sent_length) ** length_penalty_factor 129 | 130 | @staticmethod 131 | def sort_hypotheses(hypotheses, topK, length_penalty_factor): 132 | if topK is None: 133 | topK = np.inf 134 | sorted_hyps = sorted(hypotheses, key=lambda x: x.logprob / Hypothesis.length_penalty(x.length, length_penalty_factor), 135 | reverse=True) 136 | return_hypotheses = [] 137 | last_score = None 138 | count = 0 139 | for hyp in sorted_hyps: 140 | current_score = hyp.logprob / Hypothesis.length_penalty(hyp.length, length_penalty_factor) 141 | if last_score is None or current_score < last_score: 142 | if count < topK: 143 | return_hypotheses.append(hyp) 144 | last_score = current_score 145 | count += 1 146 | else: 147 | break 148 | else: 149 | assert current_score == last_score # tie, include 150 | return_hypotheses.append(hyp) 151 | return return_hypotheses 152 | 153 | @staticmethod 154 | def print_hypotheses(hypotheses): 155 | for hyp in hypotheses: 156 | print("logprob: {}, tag_seq: {}\ndec_seq: {}".format(hyp.logprob, hyp.tag_seq, hyp.dec_seq)) 157 | 158 | 159 | -------------------------------------------------------------------------------- /MISP_SQL/world_model.py: -------------------------------------------------------------------------------- 1 | # world model 2 | from collections import defaultdict 3 | 4 | 5 | class WorldModel: 6 | """ 7 | This is the class for world modeling, which takes charge of semantic parsing and user feedback incorporation. 8 | """ 9 | def __init__(self, semparser, num_options, num_passes=1, dropout_rate=0.0): 10 | """ 11 | Constructor of WorldModel. 12 | :param semparser: the base semantic parser. 13 | :param num_options: number of choices (except "none of the above"). 14 | :param num_passes: number of passes for Bayesian dropout-based decoding. 15 | :param dropout_rate: dropout rate for Bayesian dropout-based decoding. 16 | """ 17 | self.semparser = semparser 18 | self.num_options = num_options 19 | 20 | self.passes = num_passes 21 | self.dropout_rate = dropout_rate 22 | 23 | # used in feedback incorporation 24 | self.avoid_items = defaultdict(set) # a record of {decoding position: set of negated decisions} 25 | self.confirmed_items = defaultdict(set) # a record of {decoding position: set of confirmed decisions} 26 | 27 | def clear(self): 28 | """ 29 | Clear session records. 30 | :return: 31 | """ 32 | self.avoid_items = defaultdict(set) 33 | self.confirmed_items = defaultdict(set) 34 | 35 | def decode_per_pass(self, input_item, dec_beam_size=1, dec_prefix=None, stop_step=None, 36 | avoid_items=None, confirmed_items=None, dropout_rate=0.0, 37 | bool_collect_choices=False, bool_verbal=False): 38 | """ 39 | Semantic parsing in one pass. This function will be used for (1) Regular greedy decoding; 40 | (2) Performing one-step beam search to generate alternative choices. 41 | :param input_item: input to the parser (parser-specific). 42 | :param dec_beam_size: beam search size (int). 43 | :param dec_prefix: the prefix decoding sequence (list); used when generating alternative choices. 44 | If specified, the generated queries should share this prefix sequence. 45 | :param stop_step: the decoding step to terminate (int); used when generating alternative choices. 46 | If specified, the decoding should terminate at this step. When dec_beam_size > 1, the last step 47 | in each decoding sequence will be considered as one choice. 48 | :param avoid_items: a dict of {decoding step: negated decision candidates}. 
49 | If specified, negated choices will not be considered when the decoding proceeds to the according step. 50 | :param confirmed_items: a dict of {decoding step: confirmed decision candidates}. 51 | If specified, confirmed choices will be selected when the decoding proceeds to the according step. 52 | :param dropout_rate: dropout rate in Bayesian dropout (float). 53 | :param bool_collect_choices: Set to True to collect choices; used when generating alternative choices. 54 | :param bool_verbal: Set to True to print intermediate information. 55 | :return: a list of possible hypotheses (class: utils.Hypothesis). 56 | """ 57 | raise NotImplementedError 58 | 59 | def decode(self, input_item, dec_beam_size=1, dec_prefix=None, stop_step=None, 60 | avoid_items=None, confirmed_items=None, bool_collect_choices=False, bool_verbal=False): 61 | """ 62 | Semantic parsing. This function wraps the decode_per_pass function so the latter can be called for 63 | multiple times (when self.passes > 1) to calculate Bayesian dropout-based uncertainty. 64 | :param input_item: input to the parser (parser-specific). 65 | :param dec_beam_size: beam search size (int). 66 | :param dec_prefix: the prefix decoding sequence (list); used when generating alternative choices. 67 | If specified, the generated queries should share this prefix sequence. 68 | :param stop_step: the decoding step to terminate (int); used when generating alternative choices. 69 | If specified, the decoding should terminate at this step. When dec_beam_size > 1, the last step 70 | in each decoding sequence will be considered as one choice. 71 | :param avoid_items: a dict of {decoding step: negated decision candidates}. 72 | If specified, negated choices will not be considered when the decoding proceeds to the according step. 73 | :param confirmed_items: a dict of {decoding step: confirmed decision candidates}. 74 | If specified, confirmed choices will be selected when the decoding proceeds to the according step. 75 | :param bool_collect_choices: Set to True to collect choices; used when generating alternative choices. 76 | :param bool_verbal: Set to True to show intermediate information. 77 | :return: a list of possible hypotheses (class: utils.Hypothesis). 78 | """ 79 | # decode without dropout 80 | hypotheses = self.decode_per_pass(input_item, dec_beam_size=dec_beam_size, dec_prefix=dec_prefix, 81 | stop_step=stop_step, avoid_items=avoid_items, 82 | confirmed_items=confirmed_items, 83 | bool_collect_choices=bool_collect_choices, 84 | bool_verbal=bool_verbal) 85 | if self.passes == 1 or bool_collect_choices: 86 | return hypotheses 87 | 88 | # for Bayesian dropout-based decoding, re-decode the same output with dropout 89 | for hyp in hypotheses: 90 | for pass_idx in range(self.passes): 91 | dropout_hyp = self.decode_per_pass(input_item, dec_prefix=hyp.dec_seq, stop_step=stop_step, 92 | dropout_rate=self.dropout_rate)[0] 93 | if pass_idx == 0: 94 | hyp.set_passes_mode(dropout_hyp) 95 | else: 96 | hyp.merge_hyp(dropout_hyp) 97 | return hypotheses 98 | 99 | def apply_pos_feedback(self, semantic_unit, dec_seq, dec_prefix): 100 | """ 101 | Incorporate users' positive feedback (a confirmed semantic unit). The incorporation 102 | is usually achieved by (1) extending the current prefix decoding sequence (dec_prefix) 103 | with the confirmed decision and/or (2) adding the confirmed decision into 104 | self.confirmed_items[dec_idx] (dec_idx is the decoding position of the validated decision). 105 | :param semantic_unit: a confirmed semantic unit. 
106 | :param dec_seq: the decoding sequence paired with the confirmed semantic unit. 107 | :param dec_prefix: the current prefix decoding sequence that has been confirmed. 108 | :return: the updated prefix decoding sequence (list) that has been confirmed. 109 | """ 110 | raise NotImplementedError 111 | 112 | def apply_neg_feedback(self, semantic_unit, dec_seq, dec_prefix): 113 | """ 114 | Incorporate users' negative feedback (a negated semantic unit). The incorporation 115 | is usually achieved by (1) adding the negated decision into self.avoid_items[dec_idx] 116 | (dec_idx is the decoding position of the validated decision) and/or (2) revising the 117 | current prefix decoding sequence (dec_prefix) - this is particularly useful for semantic 118 | units with unit_type=1 (which have binary choices, e.g., AND/OR, DESC/ASC); once the 119 | current decision is negated, the alternative one can be automatically selected. 120 | :param semantic_unit: a negated semantic unit. 121 | :param dec_seq: the decoding sequence paired with the negated semantic unit. 122 | :param dec_prefix: the current prefix decoding sequence that has been confirmed. 123 | :return: the updated prefix decoding sequence (list) that has been confirmed. 124 | """ 125 | raise NotImplementedError 126 | 127 | def decode_revised_structure(self, semantic_unit, pointer, hyp, input_item, bool_verbal=False): 128 | """ 129 | Revise query structure (as the side effect of user feedback incorporation). For example, 130 | when the user negated all available columns being WHERE_COL, this function removes the 131 | WHERE clause. The function is OPTIONAL. 132 | :param semantic_unit: the questioned semantic unit. 133 | :param pointer: the pointer to the questioned semantic unit. 134 | :param hyp: the SQL hypothesis. 135 | :param input_item: input to the parser (parser-specific). 136 | :param bool_verbal: set to True to show intermediate information. 137 | :return: the updated pointer in tag_seq, the updated hypothesis. 138 | """ 139 | # raise NotImplementedError 140 | return pointer, hyp 141 | 142 | def refresh_decoding(self, input_item, dec_prefix, old_hyp, semantic_unit, 143 | pointer, sel_none_of_above, user_selections, bool_verbal=False): 144 | """ 145 | Refreshing the decoding after feedback incorporation. 146 | :param input_item: the input to decoder. 147 | :param dec_prefix: the current prefix decoding sequence that has been confirmed. 148 | :param old_hyp: the old decoding hypothesis. 149 | :param semantic_unit: the semantic unit questioned in current interaction. 150 | :param pointer: the position of the questioned semantic unit in tag_seq. 151 | :param sel_none_of_above: the option index corresponding to "none of the above". 152 | :param user_selections: user selections (list of option indices). 153 | :param bool_verbal: set to True to show intermediate information. 154 | :return: the pointer to the next semantic unit to examine, the updated hypothesis. 155 | """ 156 | raise NotImplementedError 157 | 158 | -------------------------------------------------------------------------------- /SQLova_model/README.md: -------------------------------------------------------------------------------- 1 | # SQLova Experiments 2 | 3 | ## 1. Description 4 | This folder contains implementation of **interactive SQLova parser**, which uses SQLova as a base semantic parser in our MISP framework: 5 | - Please follow [2. 
General Environment Setup](#2-general-environment-setup) and set up the environment/data; 6 | - For testing interactive SQLova on the fly (our EMNLP'19 setting), see [3. MISP with SQLova (EMNLP'19)](#3-misp-with-sqlova-emnlp19); 7 | - For learning SQLova from user interaction (our EMNLP'20 setting), see [4. Learning SQLova from user interaction (EMNLP'20)](#4-learning-sqlova-from-user-interaction-emnlp20). 8 | 9 | The implementation is adapted from [the SQLova repository](https://github.com/naver/sqlova). 10 | Please cite the following papers if you use the code: 11 | 12 | ``` 13 | @inproceedings{yao2020imitation, 14 | title={An Imitation Game for Learning Semantic Parsers from User Interaction}, 15 | author={Yao, Ziyu and Tang, Yiqi and Yih, Wen-tau and Sun, Huan and Su, Yu}, 16 | booktitle={Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP)}, 17 | year={2020} 18 | } 19 | 20 | @inproceedings{yao2019model, 21 | title={Model-based Interactive Semantic Parsing: A Unified Framework and A Text-to-SQL Case Study}, 22 | author={Yao, Ziyu and Su, Yu and Sun, Huan and Yih, Wen-tau}, 23 | booktitle={Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)}, 24 | pages={5450--5461}, 25 | year={2019} 26 | } 27 | 28 | @article{hwang2019comprehensive, 29 | title={A Comprehensive Exploration on WikiSQL with Table-Aware Word Contextualization}, 30 | author={Hwang, Wonseok and Yim, Jinyeung and Park, Seunghyun and Seo, Minjoon}, 31 | journal={arXiv preprint arXiv:1902.01069}, 32 | year={2019} 33 | } 34 | ``` 35 | 36 | ## 2. General Environment Setup 37 | ### Environment 38 | - Please install the Anaconda environment from [`gpu-py3.yml`](../gpu-py3.yml): 39 | ``` 40 | conda env create -f gpu-py3.yml 41 | ``` 42 | 43 | - Download Pretrained BERT model from [here](https://drive.google.com/file/d/1f_LEWVgrtZLRuoiExJa5fNzTS8-WcAX9/view?usp=sharing) as 44 | `SQLova_model/download/bert/pytorch_model_uncased_L-12_H-768_A-12.bin`. 45 | 46 | ### Data 47 | We have the pre-processed [WikiSQL data](https://github.com/salesforce/WikiSQL) available: [data.tar](https://buckeyemailosu-my.sharepoint.com/:u:/g/personal/chen_8336_buckeyemail_osu_edu/EU830MxSHp5LqJix8Am5znMBTe-RiKYi-9XB0UgMyqpOzA?e=yiGyik). Please download and uncompress it via `tar -xvf data.tar` as a folder `SQLova_model/download/data`. 48 | 49 | If you would like to preprocess the WikiSQL data (or your own data) from scratch, please follow the [`data_preprocess.sh`](../scripts/sqlova/data_preprocess.sh) script. 50 | 51 | 52 | ## 3. MISP with SQLova (EMNLP'19) 53 | ### 3.1 Model training 54 | To train SQLova on the full training set, please revise `SETTING=full_train` in [scripts/sqlova/pretrain.sh](../scripts/sqlova/pretrain.sh). 55 | In the main directory, run: 56 | ``` 57 | bash scripts/sqlova/pretrain.sh 58 | ``` 59 | 60 | ### 3.2 Model testing without interaction 61 | To test SQLova regularly, in [scripts/sqlova/test.sh](../scripts/sqlova/test.sh), please revise `SETTING` 62 | to ensure that the model checkpoint is loaded from the desired `MODEL_DIR` folder and revise `TEST_JOB` for testing on WikiSQL dev/test set. 
63 | In the main directory, run: 64 | ``` 65 | bash scripts/sqlova/test.sh 66 | ``` 67 | 68 | ### 3.3 Model testing with simulated user interaction 69 | To test SQLova with human interaction under the MISP framework, in [scripts/sqlova/test_with_interaction.sh](../scripts/sqlova/test_with_interaction.sh), 70 | revise `SETTING` to ensure that the model checkpoint is loaded from the desired `MODEL_DIR` folder and revise `DATA` for testing on WikiSQL dev/test set. 71 | In the main directory, run: 72 | ``` 73 | bash scripts/sqlova/test_with_interaction.sh 74 | ``` 75 | 76 | 77 | ## 4. Learning SQLova from user interaction (EMNLP'20) 78 | Throughout the experiments, we consider three initialization settings: 79 | - `SETTING=online_pretrain_1p` for using 1% of full training data for initialization; 80 | - `SETTING=online_pretrain_5p` for using 5% of full training data for initialization; 81 | - `SETTING=online_pretrain_10p` for using 10% of full training data for initialization. 82 | 83 | Please revise the `SETTING` variable in each script accordingly. 84 | 85 | ### 4.1 Pretraining 86 | 87 | #### 4.1.1 Pretrain by yourself 88 | Before interactive learning, we pretrain the SQLova parser with a small subset of the full training data. 89 | Please revise `SETTING` in [scripts/sqlova/pretrain.sh](../scripts/sqlova/pretrain.sh) accordingly for different initialization settings. 90 | Then in the main directory, run: 91 | ``` 92 | bash scripts/sqlova/pretrain.sh 93 | ``` 94 | 95 | #### 4.1.2 Use our pretrained checkpoints 96 | You can also use our pretrained checkpoints: [initialization_checkpoints_folder.tar](https://buckeyemailosu-my.sharepoint.com/:u:/g/personal/chen_8336_buckeyemail_osu_edu/EcaWjAIgpOJGhpRG0Lh_nW0BW-CicnjLhAX4gNgCSPGzXQ?e=qx9HA4). Please download and uncompress the folder via `tar -xvf initialization_checkpoints_folder.tar` and place the content as: 97 | ``` 98 | |- SQLova_model 99 | | |-- checkpoints_onlint_pretrain_1p 100 | | |-- model_best.pt 101 | | |-- model_bert_best.pt 102 | | |-- checkpoints_onlint_pretrain_5p 103 | | |-- checkpoints_onlint_pretrain_10p 104 | ``` 105 | 106 | #### 4.1.3 Test the pretrained models 107 | To test the pretrained parser without user interaction, see [3.2 Model testing without interaction](#32-model-testing-without-interaction). 108 | To test the pretrained parser with simulated user interaction, see [3.3 Model testing with simulated user interaction](#33-model-testing-with-simulated-user-interaction). 109 | Make sure the `SETTING` variable is set correctly. 110 | 111 | ### 4.2 Interactive learning 112 | 113 | The training script for each algorithm can be found below. Please run them in the main directory and 114 | remember to set `SETTING` accordingly for different initialization settings. 
115 | 116 | | Algorithm | Script | 117 | | ------------- | ------------- | 118 | | MISP_NEIL | [`scripts/sqlova/misp_neil.sh`](../scripts/sqlova/misp_neil.sh) | 119 | | Full Expert | [`scripts/sqlova/full_expert.sh`](../scripts/sqlova/full_expert.sh) | 120 | | Binary User | [`scripts/sqlova/bin_user.sh`](../scripts/sqlova/bin_user.sh) | 121 | | Binary User+Expert | [`scripts/sqlova/bin_user_expert.sh`](../scripts/sqlova/bin_user_expert.sh) | 122 | | Self Train | [`scripts/sqlova/self_train_0.5.sh`](../scripts/sqlova/self_train_0.5.sh) | 123 | | MISP_NEIL* | [`scripts/sqlova/misp_neil_perfect.sh`](../scripts/sqlova/misp_neil_perfect.sh) | 124 | 125 | -------------------------------------------------------------------------------- /SQLova_model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunlab-osu/MISP/7870566ab6b9e121d648478968367bc79c12f7ef/SQLova_model/__init__.py -------------------------------------------------------------------------------- /SQLova_model/agent.py: -------------------------------------------------------------------------------- 1 | from MISP_SQL.agent import Agent as BaseAgent 2 | from .sqlova.utils.utils_wikisql import * 3 | 4 | 5 | class Agent(BaseAgent): 6 | def __init__(self, world_model, error_detector, question_generator, bool_mistake_exit, 7 | bool_structure_question=False): 8 | BaseAgent.__init__(self, world_model, error_detector, question_generator, 9 | bool_mistake_exit=bool_mistake_exit, 10 | bool_structure_question=bool_structure_question) 11 | 12 | def evaluation(self, p_list, g_list, engine, tb, bool_verbal=False): 13 | pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi, pr_sql_i = p_list 14 | g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi, sql_i = g_list 15 | 16 | cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, \ 17 | cnt_wc1_list, cnt_wo1_list, \ 18 | cnt_wvi1_list, cnt_wv1_list = get_cnt_sw_list(g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi, 19 | pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi, 20 | sql_i, pr_sql_i, 21 | mode='test') 22 | 23 | cnt_lx1_list = get_cnt_lx_list(cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, cnt_wc1_list, 24 | cnt_wo1_list, cnt_wv1_list) 25 | lx_correct = sum(cnt_lx1_list) # lx stands for logical form accuracy 26 | 27 | cnt_x1_list, g_ans, pr_ans = get_cnt_x_list(engine, tb, g_sc, g_sa, sql_i, pr_sc, pr_sa, pr_sql_i) 28 | x_correct = sum(cnt_x1_list) 29 | 30 | cnt_list1 = [cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, cnt_wc1_list, cnt_wo1_list, cnt_wv1_list, cnt_lx1_list, 31 | cnt_x1_list] 32 | 33 | if bool_verbal: 34 | print("lf correct: {}, x correct: {}, cnt_list: {}".format(lx_correct, x_correct, cnt_list1)) 35 | 36 | return cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, cnt_wc1_list, cnt_wo1_list, cnt_wv1_list, cnt_wvi1_list, \ 37 | cnt_lx1_list, cnt_x1_list, cnt_list1, g_ans, pr_ans -------------------------------------------------------------------------------- /SQLova_model/annotate_ws.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # docker run --name corenlp -d -p 9000:9000 vzhong/corenlp-server 3 | # Wonseok Hwang. 
Jan 6 2019, Comment added 4 | from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser 5 | import os 6 | import records 7 | import ujson as json 8 | from stanza.nlp.corenlp import CoreNLPClient 9 | from tqdm import tqdm 10 | import copy 11 | from wikisql.lib.common import count_lines, detokenize 12 | from wikisql.lib.query import Query 13 | 14 | 15 | client = None 16 | 17 | 18 | def annotate(sentence, lower=True): 19 | global client 20 | if client is None: 21 | client = CoreNLPClient(default_annotators='ssplit,tokenize'.split(',')) 22 | words, gloss, after = [], [], [] 23 | for s in client.annotate(sentence): 24 | for t in s: 25 | words.append(t.word) 26 | gloss.append(t.originalText) 27 | after.append(t.after) 28 | if lower: 29 | words = [w.lower() for w in words] 30 | return { 31 | 'gloss': gloss, 32 | 'words': words, 33 | 'after': after, 34 | } 35 | 36 | 37 | def annotate_example(example, table): 38 | ann = {'table_id': example['table_id']} 39 | ann['question'] = annotate(example['question']) 40 | ann['table'] = { 41 | 'header': [annotate(h) for h in table['header']], 42 | } 43 | ann['query'] = sql = copy.deepcopy(example['sql']) 44 | for c in ann['query']['conds']: 45 | c[-1] = annotate(str(c[-1])) 46 | 47 | q1 = 'SYMSELECT SYMAGG {} SYMCOL {}'.format(Query.agg_ops[sql['agg']], table['header'][sql['sel']]) 48 | q2 = ['SYMCOL {} SYMOP {} SYMCOND {}'.format(table['header'][col], Query.cond_ops[op], detokenize(cond)) for col, op, cond in sql['conds']] 49 | if q2: 50 | q2 = 'SYMWHERE ' + ' SYMAND '.join(q2) + ' SYMEND' 51 | else: 52 | q2 = 'SYMEND' 53 | inp = 'SYMSYMS {syms} SYMAGGOPS {aggops} SYMCONDOPS {condops} SYMTABLE {table} SYMQUESTION {question} SYMEND'.format( 54 | syms=' '.join(['SYM' + s for s in Query.syms]), 55 | table=' '.join(['SYMCOL ' + s for s in table['header']]), 56 | question=example['question'], 57 | aggops=' '.join([s for s in Query.agg_ops]), 58 | condops=' '.join([s for s in Query.cond_ops]), 59 | ) 60 | ann['seq_input'] = annotate(inp) 61 | out = '{q1} {q2}'.format(q1=q1, q2=q2) if q2 else q1 62 | ann['seq_output'] = annotate(out) 63 | ann['where_output'] = annotate(q2) 64 | assert 'symend' in ann['seq_output']['words'] 65 | assert 'symend' in ann['where_output']['words'] 66 | return ann 67 | 68 | def find_sub_list(sl, l): 69 | # from stack overflow. 70 | results = [] 71 | sll = len(sl) 72 | for ind in (i for i, e in enumerate(l) if e == sl[0]): 73 | if l[ind:ind + sll] == sl: 74 | results.append((ind, ind + sll - 1)) 75 | 76 | return results 77 | 78 | def check_wv_tok_in_nlu_tok(wv_tok1, nlu_t1): 79 | """ 80 | Jan.2019: Wonseok 81 | Generate SQuAD style start and end index of wv in nlu. Index is for of after WordPiece tokenization. 82 | 83 | Assumption: where_str always presents in the nlu. 84 | 85 | return: 86 | st_idx of where-value string token in nlu under CoreNLP tokenization scheme. 87 | """ 88 | g_wvi1_corenlp = [] 89 | nlu_t1_low = [tok.lower() for tok in nlu_t1] 90 | for i_wn, wv_tok11 in enumerate(wv_tok1): 91 | wv_tok11_low = [tok.lower() for tok in wv_tok11] 92 | results = find_sub_list(wv_tok11_low, nlu_t1_low) 93 | st_idx, ed_idx = results[0] 94 | 95 | g_wvi1_corenlp.append( [st_idx, ed_idx] ) 96 | 97 | return g_wvi1_corenlp 98 | 99 | 100 | def annotate_example_ws(example, table): 101 | """ 102 | Jan. 2019: Wonseok 103 | Annotate only the information that will be used in our model. 
104 | """ 105 | ann = {'table_id': example['table_id'],'phase': example['phase']} 106 | _nlu_ann = annotate(example['question']) 107 | ann['question'] = example['question'] 108 | ann['question_tok'] = _nlu_ann['gloss'] 109 | # ann['table'] = { 110 | # 'header': [annotate(h) for h in table['header']], 111 | # } 112 | ann['sql'] = example['sql'] 113 | ann['query'] = sql = copy.deepcopy(example['sql']) 114 | 115 | conds1 = ann['sql']['conds'] 116 | wv_ann1 = [] 117 | for conds11 in conds1: 118 | _wv_ann1 = annotate(str(conds11[2])) 119 | wv_ann11 = _wv_ann1['gloss'] 120 | wv_ann1.append( wv_ann11 ) 121 | 122 | # Check whether wv_ann exsits inside question_tok 123 | 124 | try: 125 | wvi1_corenlp = check_wv_tok_in_nlu_tok(wv_ann1, ann['question_tok']) 126 | ann['wvi_corenlp'] = wvi1_corenlp 127 | except: 128 | ann['wvi_corenlp'] = None 129 | ann['tok_error'] = 'SQuAD style st, ed are not found under CoreNLP.' 130 | 131 | return ann 132 | 133 | 134 | def is_valid_example(e): 135 | if not all([h['words'] for h in e['table']['header']]): 136 | return False 137 | headers = [detokenize(h).lower() for h in e['table']['header']] 138 | if len(headers) != len(set(headers)): 139 | return False 140 | input_vocab = set(e['seq_input']['words']) 141 | for w in e['seq_output']['words']: 142 | if w not in input_vocab: 143 | print('query word "{}" is not in input vocabulary.\n{}'.format(w, e['seq_input']['words'])) 144 | return False 145 | input_vocab = set(e['question']['words']) 146 | for col, op, cond in e['query']['conds']: 147 | for w in cond['words']: 148 | if w not in input_vocab: 149 | print('cond word "{}" is not in input vocabulary.\n{}'.format(w, e['question']['words'])) 150 | return False 151 | return True 152 | 153 | 154 | if __name__ == '__main__': 155 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 156 | parser.add_argument('--din', default='/Users/wonseok/data/WikiSQL-1.1/data', help='data directory') 157 | parser.add_argument('--dout', default='/Users/wonseok/data/wikisql_tok', help='output directory') 158 | args = parser.parse_args() 159 | 160 | answer_toy = not True 161 | toy_size = 10 162 | 163 | if not os.path.isdir(args.dout): 164 | os.makedirs(args.dout) 165 | 166 | # for split in ['train', 'dev', 'test']: 167 | for split in ['train', 'dev', 'test']: 168 | fsplit = os.path.join(args.din, split) + '.jsonl' 169 | ftable = os.path.join(args.din, split) + '.tables.jsonl' 170 | fout = os.path.join(args.dout, split) + '_tok.jsonl' 171 | 172 | print('annotating {}'.format(fsplit)) 173 | with open(fsplit) as fs, open(ftable) as ft, open(fout, 'wt') as fo: 174 | print('loading tables') 175 | 176 | # ws: Construct table dict with table_id as a key. 
177 | tables = {} 178 | for line in tqdm(ft, total=count_lines(ftable)): 179 | d = json.loads(line) 180 | tables[d['id']] = d 181 | print('loading examples') 182 | n_written = 0 183 | cnt = -1 184 | for line in tqdm(fs, total=count_lines(fsplit)): 185 | cnt += 1 186 | d = json.loads(line) 187 | # a = annotate_example(d, tables[d['table_id']]) 188 | a = annotate_example_ws(d, tables[d['table_id']]) 189 | fo.write(json.dumps(a) + '\n') 190 | n_written += 1 191 | 192 | if answer_toy: 193 | if cnt > toy_size: 194 | break 195 | print('wrote {} examples'.format(n_written)) 196 | -------------------------------------------------------------------------------- /SQLova_model/bert/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunlab-osu/MISP/7870566ab6b9e121d648478968367bc79c12f7ef/SQLova_model/bert/__init__.py -------------------------------------------------------------------------------- /SQLova_model/bert/convert_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import argparse 23 | import tensorflow as tf 24 | import torch 25 | import numpy as np 26 | 27 | from modeling import BertConfig, BertModel 28 | 29 | parser = argparse.ArgumentParser() 30 | 31 | ## Required parameters 32 | parser.add_argument("--tf_checkpoint_path", 33 | default = None, 34 | type = str, 35 | required = True, 36 | help = "Path the TensorFlow checkpoint path.") 37 | parser.add_argument("--bert_config_file", 38 | default = None, 39 | type = str, 40 | required = True, 41 | help = "The config json file corresponding to the pre-trained BERT model. 
\n" 42 | "This specifies the model architecture.") 43 | parser.add_argument("--pytorch_dump_path", 44 | default = None, 45 | type = str, 46 | required = True, 47 | help = "Path to the output PyTorch model.") 48 | 49 | args = parser.parse_args() 50 | 51 | def convert(): 52 | # Initialise PyTorch model 53 | config = BertConfig.from_json_file(args.bert_config_file) 54 | model = BertModel(config) 55 | 56 | # Load weights from TF model 57 | path = args.tf_checkpoint_path 58 | print("Converting TensorFlow checkpoint from {}".format(path)) 59 | 60 | init_vars = tf.train.list_variables(path) 61 | names = [] 62 | arrays = [] 63 | for name, shape in init_vars: 64 | print("Loading {} with shape {}".format(name, shape)) 65 | array = tf.train.load_variable(path, name) 66 | print("Numpy array shape {}".format(array.shape)) 67 | names.append(name) 68 | arrays.append(array) 69 | 70 | for name, array in zip(names, arrays): 71 | name = name[5:] # skip "bert/" 72 | print("Loading {}".format(name)) 73 | name = name.split('/') 74 | if name[0] in ['redictions', 'eq_relationship']: 75 | print("Skipping") 76 | continue 77 | pointer = model 78 | for m_name in name: 79 | if re.fullmatch(r'[A-Za-z]+_\d+', m_name): 80 | l = re.split(r'_(\d+)', m_name) 81 | else: 82 | l = [m_name] 83 | if l[0] == 'kernel': 84 | pointer = getattr(pointer, 'weight') 85 | else: 86 | pointer = getattr(pointer, l[0]) 87 | if len(l) >= 2: 88 | num = int(l[1]) 89 | pointer = pointer[num] 90 | if m_name[-11:] == '_embeddings': 91 | pointer = getattr(pointer, 'weight') 92 | elif m_name == 'kernel': 93 | array = np.transpose(array) 94 | try: 95 | assert pointer.shape == array.shape 96 | except AssertionError as e: 97 | e.args += (pointer.shape, array.shape) 98 | raise 99 | pointer.data = torch.from_numpy(array) 100 | 101 | # Save pytorch-model 102 | torch.save(model.state_dict(), args.pytorch_dump_path) 103 | 104 | if __name__ == "__main__": 105 | convert() 106 | -------------------------------------------------------------------------------- /SQLova_model/download/bert/bert_config_uncased_L-12_H-768_A-12.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 3072, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 12, 11 | "type_vocab_size": 2, 12 | "vocab_size": 30522 13 | } 14 | -------------------------------------------------------------------------------- /SQLova_model/environment.py: -------------------------------------------------------------------------------- 1 | from prettytable import PrettyTable 2 | from MISP_SQL.environment import UserSim, RealUser as BaseRealUser, ErrorEvaluator as BaseErrorEvaluator, \ 3 | GoldUserSim as BaseGoldUserSim 4 | from MISP_SQL.utils import * 5 | 6 | 7 | class RealUser(BaseRealUser): 8 | def __init__(self, error_evaluator, tables, bool_undo=True): 9 | BaseRealUser.__init__(self, error_evaluator, bool_undo=bool_undo) 10 | 11 | self.tables = tables 12 | 13 | def show_table(self, table_id): 14 | table = self.tables[table_id] 15 | print(bcolors.BLUE + bcolors.BOLD + "{}".format(table['header']) + bcolors.ENDC) 16 | 17 | # basic information 18 | ''' 19 | print(bcolors.BOLD + "Example rows:" + bcolors.ENDC) 20 | x = PrettyTable() 21 | x.field_names = table['header'] 22 | for row in table['rows'][:3]: 23 | x.add_row(row) 24 | print(x) 25 | 26 | 
print("\n" + bcolors.BOLD + "Additional info. about this table: " + bcolors.ENDC) 27 | for key, key_alias in zip(["page_title", "section_title", "caption"], 28 | ["Page Title", "Section Title", "Table Caption"]): 29 | if key in table: 30 | print("{}: {}".format(key_alias, table[key])) 31 | print("") 32 | ''' 33 | # for key, key_alias in zip(["page_title", "section_title", "caption"], 34 | # ["Page Title", "Section Title", "Table Caption"]): 35 | # if key in table: 36 | # print("{}: {}".format(key_alias, table[key])) 37 | # 38 | # print("\n") 39 | #x = PrettyTable() 40 | #x.field_names = table['header'] 41 | #print(bcolors.BLUE + bcolors.BOLD + "{}".format(table['header']) + bcolors.ENDC) 42 | for c,row in enumerate(table['rows']): 43 | #x.add_row(row) 44 | if c == 1 or c == 2: 45 | print(row) 46 | if c > 2: 47 | break 48 | 49 | #print(x) 50 | 51 | #print(bcolors.BLUE + bcolors.BOLD + "{}".format(table['header']) + bcolors.ENDC) 52 | 53 | 54 | class ErrorEvaluator(BaseErrorEvaluator): 55 | def __init__(self): 56 | BaseErrorEvaluator.__init__(self) 57 | 58 | def compare(self, g_sql, start_idx, tag_seq, bool_return_true_selections=False, 59 | bool_return_true_semantic_units=False): 60 | # g_sql should look like {"{"sel":3,"conds":[[5,0,"Butler CC (KS)"]],"agg":0, 61 | # "g_wvi": [[st1,end1],[st2, end2],..]} # g_wvi is added from raw data 62 | lower_cased_conds = [] 63 | for (g_col_idx, g_op_idx, g_val_str), st_end in zip(g_sql['conds'], g_sql["g_wvi"]): 64 | g_val_str = str(g_val_str).lower() 65 | lower_cased_conds.append((g_col_idx, g_op_idx, g_val_str, st_end)) 66 | 67 | eval_output, true_selections, true_semantic_units = [], [], [] 68 | idx = start_idx 69 | while idx < len(tag_seq): 70 | semantic_tag = tag_seq[idx][0] 71 | if semantic_tag == OUTSIDE: 72 | eval_output.append(None) 73 | true_selections.append(None) 74 | true_semantic_units.append(None) 75 | idx += 1 76 | 77 | elif semantic_tag == SELECT_COL: 78 | eval_output.append(tag_seq[idx][1][-1] == g_sql["sel"]) 79 | true_selections.append([g_sql["sel"]]) 80 | 81 | new_su = list(tag_seq[idx]) 82 | new_su[1] = (None, None, g_sql["sel"]) 83 | true_semantic_units.append([tuple(new_su)]) 84 | 85 | idx += 1 86 | 87 | elif semantic_tag == SELECT_AGG: 88 | col_item, agg_item = tag_seq[idx][1:3] 89 | col_idx = col_item[-1] 90 | agg_idx = agg_item[-1] 91 | 92 | eval_output.append(agg_idx == g_sql['agg']) # TODO: associate with sel? 
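# Note: besides the boolean correctness appended above, this branch (like the others) also records the gold choice for this slot; the collected true_selections / true_semantic_units are only returned when the corresponding bool_return_* flags are set, e.g. so a simulated gold user can supply the correct decision.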
93 | true_selections.append([(col_idx, g_sql['agg'])]) 94 | 95 | new_su = list(tag_seq[idx]) 96 | new_su[2] = (None, g_sql['agg']) 97 | true_semantic_units.append([tuple(new_su)]) 98 | 99 | idx += 1 100 | 101 | elif semantic_tag == WHERE_COL: 102 | col_idx = tag_seq[idx][1][-1] 103 | eval_output.append(col_idx in set([col for col, _, _ in g_sql['conds']])) 104 | true_selections.append([col for col, _, _ in g_sql['conds']]) 105 | 106 | _true_semantic_units = [] 107 | for true_col_idx in true_selections[-1]: 108 | new_su = list(tag_seq[idx]) 109 | new_su[1] = (None, None, true_col_idx) 110 | _true_semantic_units.append(tuple(new_su)) 111 | true_semantic_units.append(_true_semantic_units) 112 | 113 | idx += 1 114 | 115 | elif semantic_tag == WHERE_OP: 116 | (col_item,), op_item = tag_seq[idx][1:3] 117 | col_idx = col_item[-1] 118 | op_idx = op_item[-1] 119 | true_col_op = [(col, op) for col, op, _ in g_sql['conds']] 120 | eval_output.append((col_idx, op_idx) in set(true_col_op)) 121 | true_selections.append(true_col_op) 122 | 123 | bool_found_col = False 124 | for true_col_idx, true_op_idx in true_col_op: 125 | if col_idx == true_col_idx: 126 | new_su = list(tag_seq[idx]) 127 | new_su[2] = (None, true_op_idx) 128 | true_semantic_units.append([tuple(new_su)]) 129 | bool_found_col = True 130 | break 131 | if not bool_found_col: 132 | true_semantic_units.append(None) 133 | 134 | idx += 1 135 | 136 | elif semantic_tag == WHERE_VAL: 137 | (col_item,), op_item, val_item = tag_seq[idx][1:4] 138 | col_idx = col_item[-1] 139 | op_idx = op_item[-1] 140 | val_str = val_item[-1].lower() 141 | lower_cased_conds_str = [(true_col, true_op, true_val_str) 142 | for (true_col, true_op, true_val_str, _) in lower_cased_conds] 143 | eval_output.append((col_idx, op_idx, val_str) in lower_cased_conds_str) 144 | true_selections.append(lower_cased_conds_str) 145 | 146 | bool_found_col = False 147 | for true_col_idx, true_op_idx, true_val_str, true_val_st_end in lower_cased_conds: 148 | if true_col_idx == col_idx and true_op_idx == op_idx: 149 | new_su = list(tag_seq[idx]) 150 | new_su[3] = (true_val_st_end[0], true_val_st_end[1], true_val_str) 151 | true_semantic_units.append([tuple(new_su)]) 152 | bool_found_col = True 153 | break 154 | if not bool_found_col: 155 | true_semantic_units.append(None) 156 | 157 | idx += 1 158 | else: 159 | raise Exception("Invalid semantic_tag {} in semantic unit {}".format(semantic_tag, tag_seq[idx])) 160 | 161 | return_items = [idx, eval_output] 162 | if bool_return_true_selections: 163 | return_items.append(true_selections) 164 | if bool_return_true_semantic_units: 165 | return_items.append(true_semantic_units) 166 | 167 | return tuple(return_items) 168 | 169 | 170 | class GoldUserSim(BaseGoldUserSim): 171 | def __init__(self, error_evaluator, bool_structure_question=False): 172 | BaseGoldUserSim.__init__(self, error_evaluator) 173 | self.bool_structure_question = bool_structure_question 174 | 175 | def get_gold_selection(self, pointer): 176 | pointer_truth = self.true_semantic_units[pointer] # ground-truth decision 177 | old_su = self.tag_seq[pointer] 178 | semantic_tag = old_su[0] 179 | old_dec_item = self.dec_seq[old_su[-1]] 180 | gold_semantic_units, gold_dec_items = [], [] 181 | 182 | if pointer_truth is not None: 183 | gold_semantic_units.extend(pointer_truth) 184 | for su in gold_semantic_units: 185 | if semantic_tag == SELECT_COL: 186 | new_dec_item = list(old_dec_item) 187 | new_dec_item[-1] = su[1][-1] 188 | gold_dec_items.append(tuple(new_dec_item)) 189 | elif semantic_tag 
== SELECT_AGG: 190 | new_dec_item = list(old_dec_item) 191 | new_dec_item[-1] = su[2][-1] 192 | gold_dec_items.append(tuple(new_dec_item)) 193 | elif semantic_tag == WHERE_COL: 194 | gold_dec_items.append(None) 195 | elif semantic_tag == WHERE_OP: 196 | new_dec_item = list(old_dec_item) 197 | new_dec_item[-1] = su[2][-1] 198 | gold_dec_items.append(tuple(new_dec_item)) 199 | else: 200 | new_dec_item = list(old_dec_item) 201 | new_dec_item[-2] = su[3][0] 202 | new_dec_item[-1] = su[3][1] 203 | gold_dec_items.append(tuple(new_dec_item)) 204 | 205 | print("Gold semantic units: %s." % str(gold_semantic_units)) 206 | print("Gold dec_items: %s." % str(gold_dec_items)) 207 | 208 | if len(gold_semantic_units): 209 | selections = [choice + 1 for choice in range(len(gold_semantic_units))] 210 | sel_none_of_above = len(gold_semantic_units) + 1 211 | elif self.bool_structure_question and semantic_tag == WHERE_COL: 212 | sel_none_of_above = 1 213 | selections = [sel_none_of_above + 1] # invalid structure 214 | else: 215 | sel_none_of_above = 1 216 | selections = [sel_none_of_above] 217 | print("Gold user selections ('none of above' = %d): %s.\n" % (sel_none_of_above, str(selections))) 218 | 219 | return gold_semantic_units, gold_dec_items, sel_none_of_above, selections 220 | 221 | -------------------------------------------------------------------------------- /SQLova_model/error_detector.py: -------------------------------------------------------------------------------- 1 | # error detector 2 | from MISP_SQL.error_detector import * 3 | -------------------------------------------------------------------------------- /SQLova_model/evaluate_ws.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import json 3 | from argparse import ArgumentParser 4 | from tqdm import tqdm 5 | from wikisql.lib.dbengine import DBEngine 6 | from wikisql.lib.query import Query 7 | from wikisql.lib.common import count_lines 8 | 9 | import os 10 | 11 | # Jan1 2019. Wonseok. Path info has added to original wikisql/evaluation.py 12 | # Only need to add "query" (essentially "sql" in original data) and "table_id" while constructing file. 13 | 14 | if __name__ == '__main__': 15 | 16 | # Hyper parameters 17 | mode = 'dev' 18 | ordered = False 19 | 20 | dset_name = 'wikisql_tok' 21 | saved_epoch = 'best' # 30-162 22 | 23 | # Set path 24 | path_h = '/home/wonseok' # change to your home folder 25 | path_wikisql_tok = os.path.join(path_h, 'data', 'wikisql_tok') 26 | path_save_analysis = '.' 27 | 28 | # Path for evaluation results. 
29 | path_wikisql0 = os.path.join(path_h,'data/WikiSQL-1.1/data') 30 | path_source = os.path.join(path_wikisql0, f'{mode}.jsonl') 31 | path_db = os.path.join(path_wikisql0, f'{mode}.db') 32 | path_pred = os.path.join(path_save_analysis, f'results_{mode}.jsonl') 33 | 34 | 35 | # For the case when use "argument" 36 | parser = ArgumentParser() 37 | parser.add_argument('--source_file', help='source file for the prediction', default=path_source) 38 | parser.add_argument('--db_file', help='source database for the prediction', default=path_db) 39 | parser.add_argument('--pred_file', help='predictions by the model', default=path_pred) 40 | parser.add_argument('--ordered', action='store_true', help='whether the exact match should consider the order of conditions') 41 | args = parser.parse_args() 42 | args.ordered=ordered 43 | 44 | engine = DBEngine(args.db_file) 45 | exact_match = [] 46 | with open(args.source_file) as fs, open(args.pred_file) as fp: 47 | grades = [] 48 | for ls, lp in tqdm(zip(fs, fp), total=count_lines(args.source_file)): 49 | eg = json.loads(ls) 50 | ep = json.loads(lp) 51 | qg = Query.from_dict(eg['sql'], ordered=args.ordered) 52 | gold = engine.execute_query(eg['table_id'], qg, lower=True) 53 | pred = ep.get('error', None) 54 | qp = None 55 | if not ep.get('error', None): 56 | try: 57 | qp = Query.from_dict(ep['query'], ordered=args.ordered) 58 | pred = engine.execute_query(eg['table_id'], qp, lower=True) 59 | except Exception as e: 60 | pred = repr(e) 61 | correct = pred == gold 62 | match = qp == qg 63 | grades.append(correct) 64 | exact_match.append(match) 65 | 66 | print(json.dumps({ 67 | 'ex_accuracy': sum(grades) / len(grades), 68 | 'lf_accuracy': sum(exact_match) / len(exact_match), 69 | }, indent=2)) 70 | 71 | 72 | -------------------------------------------------------------------------------- /SQLova_model/sqlnet/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2017, Xiaojun Xu, Chang Liu and Dawn Song 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /SQLova_model/sqlnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunlab-osu/MISP/7870566ab6b9e121d648478968367bc79c12f7ef/SQLova_model/sqlnet/__init__.py -------------------------------------------------------------------------------- /SQLova_model/sqlnet/dbengine.py: -------------------------------------------------------------------------------- 1 | # From original SQLNet code. 2 | # Wonseok modified. 20180607 3 | 4 | import records 5 | import re 6 | from babel.numbers import parse_decimal, NumberFormatError 7 | 8 | 9 | schema_re = re.compile(r'\((.+)\)') # group (.......) dfdf (.... )group 10 | num_re = re.compile(r'[-+]?\d*\.\d+|\d+') # ? zero or one time appear of preceding character, * zero or several time appear of preceding character. 11 | # Catch something like -34.34, .4543, 12 | # | is 'or'beam_forward 13 | 14 | agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG'] 15 | cond_ops = ['=', '>', '<', 'OP'] 16 | 17 | class DBEngine: 18 | 19 | def __init__(self, fdb): 20 | #fdb = 'data/test.db' 21 | self.db = records.Database('sqlite:///{}'.format(fdb)) 22 | 23 | def execute_query(self, table_id, query, *args, **kwargs): 24 | return self.execute(table_id, query.sel_index, query.agg_index, query.conditions, *args, **kwargs) 25 | 26 | def execute(self, table_id, select_index, aggregation_index, conditions, lower=True): 27 | if not table_id.startswith('table'): 28 | table_id = 'table_{}'.format(table_id.replace('-', '_')) 29 | table_info = self.db.query('SELECT sql from sqlite_master WHERE tbl_name = :name', name=table_id).all()[0].sql.replace('\n','') 30 | schema_str = schema_re.findall(table_info)[0] 31 | schema = {} 32 | for tup in schema_str.split(', '): 33 | c, t = tup.split() 34 | schema[c] = t 35 | select = 'col{}'.format(select_index) 36 | agg = agg_ops[aggregation_index] 37 | if agg: 38 | select = '{}({})'.format(agg, select) 39 | where_clause = [] 40 | where_map = {} 41 | for col_index, op, val in conditions: 42 | if lower and (isinstance(val, str) or isinstance(val, str)): 43 | val = val.lower() 44 | if schema['col{}'.format(col_index)] == 'real' and not isinstance(val, (int, float)): 45 | try: 46 | # print('!!!!!!value of val is: ', val, 'type is: ', type(val)) 47 | # val = float(parse_decimal(val)) # somehow it generates error. 48 | val = float(parse_decimal(val, locale='en_US')) 49 | # print('!!!!!!After: val', val) 50 | 51 | except NumberFormatError as e: 52 | try: 53 | val = float(num_re.findall(val)[0]) # need to understand and debug this part. 54 | except: 55 | # Although column is of number, selected one is not number. Do nothing in this case. 
56 | pass 57 | where_clause.append('col{} {} :col{}'.format(col_index, cond_ops[op], col_index)) 58 | where_map['col{}'.format(col_index)] = val 59 | where_str = '' 60 | if where_clause: 61 | where_str = 'WHERE ' + ' AND '.join(where_clause) 62 | query = 'SELECT {} AS result FROM {} {}'.format(select, table_id, where_str) 63 | #print query 64 | out = self.db.query(query, **where_map) 65 | 66 | 67 | return [o.result for o in out] 68 | def execute_return_query(self, table_id, select_index, aggregation_index, conditions, lower=True): 69 | if not table_id.startswith('table'): 70 | table_id = 'table_{}'.format(table_id.replace('-', '_')) 71 | table_info = self.db.query('SELECT sql from sqlite_master WHERE tbl_name = :name', name=table_id).all()[0].sql.replace('\n','') 72 | schema_str = schema_re.findall(table_info)[0] 73 | schema = {} 74 | for tup in schema_str.split(', '): 75 | c, t = tup.split() 76 | schema[c] = t 77 | select = 'col{}'.format(select_index) 78 | agg = agg_ops[aggregation_index] 79 | if agg: 80 | select = '{}({})'.format(agg, select) 81 | where_clause = [] 82 | where_map = {} 83 | for col_index, op, val in conditions: 84 | if lower and (isinstance(val, str) or isinstance(val, str)): 85 | val = val.lower() 86 | if schema['col{}'.format(col_index)] == 'real' and not isinstance(val, (int, float)): 87 | try: 88 | # print('!!!!!!value of val is: ', val, 'type is: ', type(val)) 89 | # val = float(parse_decimal(val)) # somehow it generates error. 90 | val = float(parse_decimal(val, locale='en_US')) 91 | # print('!!!!!!After: val', val) 92 | 93 | except NumberFormatError as e: 94 | val = float(num_re.findall(val)[0]) 95 | where_clause.append('col{} {} :col{}'.format(col_index, cond_ops[op], col_index)) 96 | where_map['col{}'.format(col_index)] = val 97 | where_str = '' 98 | if where_clause: 99 | where_str = 'WHERE ' + ' AND '.join(where_clause) 100 | query = 'SELECT {} AS result FROM {} {}'.format(select, table_id, where_str) 101 | #print query 102 | out = self.db.query(query, **where_map) 103 | 104 | 105 | return [o.result for o in out], query 106 | def show_table(self, table_id): 107 | if not table_id.startswith('table'): 108 | table_id = 'table_{}'.format(table_id.replace('-', '_')) 109 | rows = self.db.query('select * from ' +table_id) 110 | print(rows.dataset) -------------------------------------------------------------------------------- /SQLova_model/sqlova/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | -------------------------------------------------------------------------------- /SQLova_model/sqlova/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunlab-osu/MISP/7870566ab6b9e121d648478968367bc79c12f7ef/SQLova_model/sqlova/model/__init__.py -------------------------------------------------------------------------------- /SQLova_model/sqlova/model/nl2sql/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunlab-osu/MISP/7870566ab6b9e121d648478968367bc79c12f7ef/SQLova_model/sqlova/model/nl2sql/__init__.py -------------------------------------------------------------------------------- /SQLova_model/sqlova/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | -------------------------------------------------------------------------------- /SQLova_model/sqlova/utils/utils.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2019-present NAVER Corp. 2 | # Apache License v2.0 3 | 4 | # Wonseok Hwang 5 | import os 6 | from matplotlib.pylab import * 7 | 8 | 9 | def generate_perm_inv(perm): 10 | # Definitly correct. 11 | perm_inv = zeros(len(perm), dtype=int32) 12 | for i, p in enumerate(perm): 13 | perm_inv[int(p)] = i 14 | 15 | return perm_inv 16 | 17 | 18 | def ensure_dir(my_path): 19 | """ Generate directory if not exists 20 | """ 21 | if not os.path.exists(my_path): 22 | os.makedirs(my_path) 23 | 24 | 25 | def topk_multi_dim(tensor, n_topk=1, batch_exist=True): 26 | 27 | if batch_exist: 28 | idxs = [] 29 | for b, tensor1 in enumerate(tensor): 30 | idxs1 = [] 31 | tensor1_1d = tensor1.reshape(-1) 32 | values_1d, idxs_1d = tensor1_1d.topk(k=n_topk) 33 | idxs_list = unravel_index(idxs_1d.cpu().numpy(), tensor1.shape) 34 | # (dim0, dim1, dim2, ...) 35 | 36 | # reconstruct 37 | for i_beam in range(n_topk): 38 | idxs11 = [] 39 | for idxs_list1 in idxs_list: 40 | idxs11.append(idxs_list1[i_beam]) 41 | idxs1.append(idxs11) 42 | idxs.append(idxs1) 43 | 44 | else: 45 | tensor1 = tensor 46 | idxs1 = [] 47 | tensor1_1d = tensor1.reshape(-1) 48 | values_1d, idxs_1d = tensor1_1d.topk(k=n_topk) 49 | idxs_list = unravel_index(idxs_1d.numpy(), tensor1.shape) 50 | # (dim0, dim1, dim2, ...) 51 | 52 | # reconstruct 53 | for i_beam in range(n_topk): 54 | idxs11 = [] 55 | for idxs_list1 in idxs_list: 56 | idxs11.append(idxs_list1[i_beam]) 57 | idxs1.append(idxs11) 58 | idxs = idxs1 59 | return idxs 60 | 61 | 62 | def json_default_type_checker(o): 63 | """ 64 | From https://stackoverflow.com/questions/11942364/typeerror-integer-is-not-json-serializable-when-serializing-json-in-python 65 | """ 66 | if isinstance(o, int64): return int(o) 67 | raise TypeError 68 | -------------------------------------------------------------------------------- /SQLova_model/sqlova/utils/wikisql_formatter.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019-present NAVER Corp. 2 | # Apache License v2.0 3 | 4 | 5 | # Wonseok Hwang 6 | # Convert the wikisql format to the suitable format for the BERT. 7 | import os, sys, json 8 | from matplotlib.pylab import * 9 | 10 | 11 | def get_squad_style_ans(nlu, sql): 12 | conds = sql['conds'] 13 | answers = [] 14 | for cond1 in conds: 15 | a1 = {} 16 | wv1 = cond1[2] 17 | a1['text'] = wv1 18 | a1['answer_start'] = nlu.lower().find(str(wv1).lower()) 19 | if a1['answer_start'] < 0 or a1['answer_start'] >= len(nlu): 20 | raise EnvironmentError 21 | answers.append(a1) 22 | 23 | return answers 24 | 25 | 26 | def get_qas(path_q, tid): 27 | qas = [] 28 | with open(path_q, 'r') as f_q: 29 | qnum = -1 30 | for j, q1 in enumerate(f_q): 31 | q1 = json.loads(q1) 32 | tid_q = q1['table_id'] 33 | 34 | if tid_q != tid: 35 | continue 36 | else: 37 | qnum += 1 38 | # print(tid_q, tid) 39 | qas1 = {} 40 | nlu = q1['question'] 41 | sql = q1['sql'] 42 | 43 | qas1['question'] = nlu 44 | qas1['id'] = f'{tid_q}-{qnum}' 45 | qas1['answers'] = get_squad_style_ans(nlu, sql) 46 | qas1['c_answers'] = sql 47 | 48 | qas.append(qas1) 49 | 50 | return qas 51 | 52 | 53 | def get_tbl_context(t1): 54 | context = '' 55 | 56 | header_tok = t1['header'] 57 | # Here Join scheme can be changed. 
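# For example, a separator token could be inserted between column names (e.g. ' ; '.join(header_tok)) to mark header boundaries more clearly; the default below simply joins the headers with single spaces.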
58 | header_joined = ' '.join(header_tok) 59 | context += header_joined 60 | 61 | return context 62 | 63 | def generate_wikisql_bert(path_wikisql, dset_type): 64 | path_q = os.path.join(path_wikisql, f'{dset_type}.jsonl') 65 | path_tbl = os.path.join(path_wikisql, f'{dset_type}.tables.jsonl') 66 | 67 | # Generate new json file 68 | with open(path_tbl, 'r') as f_tbl: 69 | wikisql = {'version': "v1.1"} 70 | data = [] 71 | data1 = {} 72 | paragraphs = [] # new tbls 73 | for i, t1 in enumerate(f_tbl): 74 | paragraphs1 = {} 75 | 76 | t1 = json.loads(t1) 77 | tid = t1['id'] 78 | qas = get_qas(path_q, tid) 79 | 80 | paragraphs1['qas'] = qas 81 | paragraphs1['tid'] = tid 82 | paragraphs1['context'] = get_tbl_context(t1) 83 | # paragraphs1['context_page_title'] = t1['page_title'] # not always present 84 | paragraphs1['context_headers'] = t1['header'] 85 | paragraphs1['context_headers_type'] = t1['types'] 86 | paragraphs1['context_contents'] = t1['rows'] 87 | 88 | paragraphs.append(paragraphs1) 89 | data1['paragraphs'] = paragraphs 90 | data1['title'] = 'wikisql' 91 | data.append(data1) 92 | wikisql['data'] = data 93 | 94 | # Save 95 | with open(os.path.join(path_wikisql, f'{dset_type}_bert.json'), 'w', encoding='utf-8') as fnew: 96 | json_str = json.dumps(wikisql, ensure_ascii=False) 97 | json_str += '\n' 98 | fnew.writelines(json_str) 99 | 100 | 101 | if __name__=='__main__': 102 | 103 | # 0. Load wikisql 104 | path_h = '/Users/wonseok' 105 | path_wikisql = os.path.join(path_h, 'data', 'WikiSQL-1.1', 'data') 106 | 107 | 108 | dset_type_list = ['dev', 'test', 'train'] 109 | 110 | for dset_type in dset_type_list: 111 | generate_wikisql_bert(path_wikisql, dset_type) 112 | 113 | 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /SQLova_model/wikisql/LICENSE_WikiSQL: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2017, Salesforce Research 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /SQLova_model/wikisql/annotate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser 3 | import os 4 | import records 5 | import ujson as json 6 | from stanza.nlp.corenlp import CoreNLPClient 7 | from tqdm import tqdm 8 | import copy 9 | from lib.common import count_lines, detokenize 10 | from lib.query import Query 11 | 12 | 13 | client = None 14 | 15 | 16 | def annotate(sentence, lower=True): 17 | global client 18 | if client is None: 19 | client = CoreNLPClient(default_annotators='ssplit,tokenize'.split(',')) 20 | words, gloss, after = [], [], [] 21 | for s in client.annotate(sentence): 22 | for t in s: 23 | words.append(t.word) 24 | gloss.append(t.originalText) 25 | after.append(t.after) 26 | if lower: 27 | words = [w.lower() for w in words] 28 | return { 29 | 'gloss': gloss, 30 | 'words': words, 31 | 'after': after, 32 | } 33 | 34 | 35 | def annotate_example(example, table): 36 | ann = {'table_id': example['table_id']} 37 | ann['question'] = annotate(example['question']) 38 | ann['table'] = { 39 | 'header': [annotate(h) for h in table['header']], 40 | } 41 | ann['query'] = sql = copy.deepcopy(example['sql']) 42 | for c in ann['query']['conds']: 43 | c[-1] = annotate(str(c[-1])) 44 | 45 | q1 = 'SYMSELECT SYMAGG {} SYMCOL {}'.format(Query.agg_ops[sql['agg']], table['header'][sql['sel']]) 46 | q2 = ['SYMCOL {} SYMOP {} SYMCOND {}'.format(table['header'][col], Query.cond_ops[op], detokenize(cond)) for col, op, cond in sql['conds']] 47 | if q2: 48 | q2 = 'SYMWHERE ' + ' SYMAND '.join(q2) + ' SYMEND' 49 | else: 50 | q2 = 'SYMEND' 51 | inp = 'SYMSYMS {syms} SYMAGGOPS {aggops} SYMCONDOPS {condops} SYMTABLE {table} SYMQUESTION {question} SYMEND'.format( 52 | syms=' '.join(['SYM' + s for s in Query.syms]), 53 | table=' '.join(['SYMCOL ' + s for s in table['header']]), 54 | question=example['question'], 55 | aggops=' '.join([s for s in Query.agg_ops]), 56 | condops=' '.join([s for s in Query.cond_ops]), 57 | ) 58 | ann['seq_input'] = annotate(inp) 59 | out = '{q1} {q2}'.format(q1=q1, q2=q2) if q2 else q1 60 | ann['seq_output'] = annotate(out) 61 | ann['where_output'] = annotate(q2) 62 | assert 'symend' in ann['seq_output']['words'] 63 | assert 'symend' in ann['where_output']['words'] 64 | return ann 65 | 66 | 67 | def is_valid_example(e): 68 | if not all([h['words'] for h in e['table']['header']]): 69 | return False 70 | headers = [detokenize(h).lower() for h in e['table']['header']] 71 | if len(headers) != len(set(headers)): 72 | return False 73 | input_vocab = set(e['seq_input']['words']) 74 | for w in e['seq_output']['words']: 75 | if w not in input_vocab: 76 | print('query word "{}" is not in input vocabulary.\n{}'.format(w, e['seq_input']['words'])) 77 | return False 78 | input_vocab = set(e['question']['words']) 79 | for col, op, cond in 
e['query']['conds']: 80 | for w in cond['words']: 81 | if w not in input_vocab: 82 | print('cond word "{}" is not in input vocabulary.\n{}'.format(w, e['question']['words'])) 83 | return False 84 | return True 85 | 86 | 87 | if __name__ == '__main__': 88 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 89 | parser.add_argument('--din', default='data', help='data directory') 90 | parser.add_argument('--dout', default='annotated', help='output directory') 91 | args = parser.parse_args() 92 | 93 | if not os.path.isdir(args.dout): 94 | os.makedirs(args.dout) 95 | 96 | for split in ['train', 'dev', 'test']: 97 | fsplit = os.path.join(args.din, split) + '.jsonl' 98 | ftable = os.path.join(args.din, split) + '.tables.jsonl' 99 | fout = os.path.join(args.dout, split) + '.jsonl' 100 | 101 | print('annotating {}'.format(fsplit)) 102 | with open(fsplit) as fs, open(ftable) as ft, open(fout, 'wt') as fo: 103 | print('loading tables') 104 | tables = {} 105 | for line in tqdm(ft, total=count_lines(ftable)): 106 | d = json.loads(line) 107 | tables[d['id']] = d 108 | print('loading examples') 109 | n_written = 0 110 | for line in tqdm(fs, total=count_lines(fsplit)): 111 | d = json.loads(line) 112 | a = annotate_example(d, tables[d['table_id']]) 113 | if not is_valid_example(a): 114 | raise Exception(str(a)) 115 | 116 | gold = Query.from_tokenized_dict(a['query']) 117 | reconstruct = Query.from_sequence(a['seq_output'], a['table'], lowercase=True) 118 | if gold.lower() != reconstruct.lower(): 119 | raise Exception ('Expected:\n{}\nGot:\n{}'.format(gold, reconstruct)) 120 | fo.write(json.dumps(a) + '\n') 121 | n_written += 1 122 | print('wrote {} examples'.format(n_written)) 123 | -------------------------------------------------------------------------------- /SQLova_model/wikisql/evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import json 3 | from argparse import ArgumentParser 4 | from tqdm import tqdm 5 | from lib.dbengine import DBEngine 6 | from lib.query import Query 7 | from lib.common import count_lines 8 | 9 | 10 | if __name__ == '__main__': 11 | parser = ArgumentParser() 12 | parser.add_argument('source_file', help='source file for the prediction') 13 | parser.add_argument('db_file', help='source database for the prediction') 14 | parser.add_argument('pred_file', help='predictions by the model') 15 | parser.add_argument('--ordered', action='store_true', help='whether the exact match should consider the order of conditions') 16 | args = parser.parse_args() 17 | 18 | engine = DBEngine(args.db_file) 19 | exact_match = [] 20 | with open(args.source_file) as fs, open(args.pred_file) as fp: 21 | grades = [] 22 | for ls, lp in tqdm(zip(fs, fp), total=count_lines(args.source_file)): 23 | eg = json.loads(ls) 24 | ep = json.loads(lp) 25 | qg = Query.from_dict(eg['sql'], ordered=args.ordered) 26 | gold = engine.execute_query(eg['table_id'], qg, lower=True) 27 | pred = ep.get('error', None) 28 | qp = None 29 | if not ep.get('error', None): 30 | try: 31 | qp = Query.from_dict(ep['query'], ordered=args.ordered) 32 | pred = engine.execute_query(eg['table_id'], qp, lower=True) 33 | except Exception as e: 34 | pred = repr(e) 35 | correct = pred == gold 36 | match = qp == qg 37 | grades.append(correct) 38 | exact_match.append(match) 39 | print(json.dumps({ 40 | 'ex_accuracy': sum(grades) / len(grades), 41 | 'lf_accuracy': sum(exact_match) / len(exact_match), 42 | }, indent=2)) 43 | 
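The script above reports two numbers that are easy to conflate: `ex_accuracy` is execution accuracy (the predicted query returns the same execution result as the gold query), while `lf_accuracy` is logical-form accuracy (the predicted query exactly matches the gold query, ignoring condition order unless `--ordered` is passed). A minimal sketch of that bookkeeping with made-up outcomes, independent of any database:
```
# Toy illustration of the two metrics in evaluate.py (all values hypothetical).
grades = [True, True, False]        # execution results matched the gold answers
exact_match = [True, False, False]  # predicted queries exactly matched the gold queries
print(sum(grades) / len(grades))            # ~0.67 -> ex_accuracy
print(sum(exact_match) / len(exact_match))  # ~0.33 -> lf_accuracy
```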
-------------------------------------------------------------------------------- /SQLova_model/wikisql/lib/common.py: -------------------------------------------------------------------------------- 1 | def count_lines(fname): 2 | with open(fname) as f: 3 | return sum(1 for line in f) 4 | 5 | 6 | def detokenize(tokens): 7 | ret = '' 8 | for g, a in zip(tokens['gloss'], tokens['after']): 9 | ret += g + a 10 | return ret.strip() 11 | -------------------------------------------------------------------------------- /SQLova_model/wikisql/lib/dbengine.py: -------------------------------------------------------------------------------- 1 | import records 2 | import re 3 | from babel.numbers import parse_decimal, NumberFormatError 4 | from wikisql.lib.query import Query 5 | 6 | # Jan 3, 2019. Wonseok modify the lib. path 7 | 8 | 9 | schema_re = re.compile(r'\((.+)\)') 10 | num_re = re.compile(r'[-+]?\d*\.\d+|\d+') 11 | 12 | 13 | class DBEngine: 14 | 15 | def __init__(self, fdb): 16 | self.db = records.Database('sqlite:///{}'.format(fdb)) 17 | 18 | def execute_query(self, table_id, query, *args, **kwargs): 19 | return self.execute(table_id, query.sel_index, query.agg_index, query.conditions, *args, **kwargs) 20 | 21 | def execute(self, table_id, select_index, aggregation_index, conditions, lower=True): 22 | if not table_id.startswith('table'): 23 | table_id = 'table_{}'.format(table_id.replace('-', '_')) 24 | table_info = self.db.query('SELECT sql from sqlite_master WHERE tbl_name = :name', name=table_id).all()[0].sql 25 | schema_str = schema_re.findall(table_info)[0] 26 | schema = {} 27 | for tup in schema_str.split(', '): 28 | c, t = tup.split() 29 | schema[c] = t 30 | select = 'col{}'.format(select_index) 31 | agg = Query.agg_ops[aggregation_index] 32 | if agg: 33 | select = '{}({})'.format(agg, select) 34 | where_clause = [] 35 | where_map = {} 36 | for col_index, op, val in conditions: 37 | if lower and isinstance(val, str): 38 | val = val.lower() 39 | if schema['col{}'.format(col_index)] == 'real' and not isinstance(val, (int, float)): 40 | try: 41 | val = float(parse_decimal(val)) 42 | except NumberFormatError as e: 43 | val = float(num_re.findall(val)[0]) 44 | where_clause.append('col{} {} :col{}'.format(col_index, Query.cond_ops[op], col_index)) 45 | where_map['col{}'.format(col_index)] = val 46 | where_str = '' 47 | if where_clause: 48 | where_str = 'WHERE ' + ' AND '.join(where_clause) 49 | query = 'SELECT {} AS result FROM {} {}'.format(select, table_id, where_str) 50 | out = self.db.query(query, **where_map) 51 | return [o.result for o in out] 52 | -------------------------------------------------------------------------------- /SQLova_model/wikisql/lib/table.py: -------------------------------------------------------------------------------- 1 | import re 2 | from tabulate import tabulate 3 | from lib.query import Query 4 | import random 5 | 6 | 7 | class Table: 8 | 9 | schema_re = re.compile('\((.+)\)') 10 | 11 | def __init__(self, table_id, header, types, rows, caption=None): 12 | self.table_id = table_id 13 | self.header = header 14 | self.types = types 15 | self.rows = rows 16 | self.caption = caption 17 | 18 | def __repr__(self): 19 | return 'Table: {id}\nCaption: {caption}\n{tabulate}'.format( 20 | id=self.table_id, 21 | caption=self.caption, 22 | tabulate=tabulate(self.rows, headers=self.header) 23 | ) 24 | 25 | @classmethod 26 | def get_schema(cls, db, table_id): 27 | table_infos = db.query('SELECT sql from sqlite_master WHERE tbl_name = :name', 
name=cls.get_id(table_id)).all() 28 | if table_infos: 29 | return table_infos[0] 30 | else: 31 | return None 32 | 33 | @classmethod 34 | def get_id(cls, table_id): 35 | return 'table_{}'.format(table_id.replace('-', '_')) 36 | 37 | @classmethod 38 | def from_db(cls, db, table_id): 39 | table_info = cls.get_schema(db, table_id) 40 | if table_info: 41 | schema_str = cls.schema_re.findall(table_info.sql)[0] 42 | header, types = [], [] 43 | for tup in schema_str.split(', '): 44 | c, t = tup.split() 45 | header.append(c) 46 | types.append(t) 47 | rows = [[getattr(r, h) for h in header] for r in db.query('SELECT * from {}'.format(cls.get_id(table_id)))] 48 | return cls(table_id, header, types, rows) 49 | else: 50 | return None 51 | 52 | @property 53 | def name(self): 54 | return self.get_id(self.table_id) 55 | 56 | def create_table(self, db, replace_existing=False, lower=True): 57 | exists = self.get_schema(db, self.table_id) 58 | if exists: 59 | if replace_existing: 60 | db.query('DROP TABLE {}'.format(self.name)) 61 | else: 62 | return 63 | type_str = ', '.join(['col{} {}'.format(i, t) for i, t in enumerate(self.types)]) 64 | db.query('CREATE TABLE {name} ({types})'.format(name=self.name, types=type_str)) 65 | for row in self.rows: 66 | value_str = ', '.join([':val{}'.format(j) for j, c in enumerate(row)]) 67 | value_dict = {'val{}'.format(j): c for j, c in enumerate(row)} 68 | if lower: 69 | value_dict = {k: v.lower() if isinstance(v, str) else v for k, v in value_dict.items()} 70 | db.query('INSERT INTO {name} VALUES ({values})'.format(name=self.name, values=value_str), **value_dict) 71 | 72 | def execute_query(self, db, query, lower=True): 73 | sel_str = 'col{}'.format(query.sel_index) if query.sel_index >= 0 else '*' 74 | agg_str = sel_str 75 | agg_op = Query.agg_ops[query.agg_index] 76 | if agg_op: 77 | agg_str = '{}({})'.format(agg_op, sel_str) 78 | where_str = ' AND '.join(['col{} {} :col{}'.format(i, Query.cond_ops[o], i) for i, o, v in query.conditions]) 79 | where_map = {'col{}'.format(i): v for i, o, v in query.conditions} 80 | if lower: 81 | where_map = {k: v.lower() if isinstance(v, str) else v for k, v in where_map.items()} 82 | if where_map: 83 | where_str = 'WHERE ' + where_str 84 | 85 | if query.sel_index >= 0: 86 | query_str = 'SELECT {agg_str} AS result FROM {name} {where_str}'.format(agg_str=agg_str, name=self.name, where_str=where_str) 87 | return [r.result for r in db.query(query_str, **where_map)] 88 | else: 89 | query_str = 'SELECT {agg_str} FROM {name} {where_str}'.format(agg_str=agg_str, name=self.name, where_str=where_str) 90 | return [[getattr(r, 'col{}'.format(i)) for i in range(len(self.header))] for r in db.query(query_str, **where_map)] 91 | 92 | def query_str(self, query): 93 | agg_str = self.header[query.sel_index] 94 | agg_op = Query.agg_ops[query.agg_index] 95 | if agg_op: 96 | agg_str = '{}({})'.format(agg_op, agg_str) 97 | where_str = ' AND '.join(['{} {} {}'.format(self.header[i], Query.cond_ops[o], v) for i, o, v in query.conditions]) 98 | return 'SELECT {} FROM {} WHERE {}'.format(agg_str, self.name, where_str) 99 | 100 | def generate_query(self, db, max_cond=4): 101 | max_cond = min(len(self.header), max_cond) 102 | # sample a select column 103 | sel_index = random.choice(list(range(len(self.header)))) 104 | # sample where conditions 105 | query = Query(-1, Query.agg_ops.index('')) 106 | results = self.execute_query(db, query) 107 | condition_options = list(range(len(self.header))) 108 | condition_options.remove(sel_index) 109 | for i in
range(max_cond): 110 | if not results: 111 | break 112 | cond_index = random.choice(condition_options) 113 | if self.types[cond_index] == 'text': 114 | cond_op = Query.cond_ops.index('=') 115 | else: 116 | cond_op = random.choice(list(range(len(Query.cond_ops)))) 117 | cond_val = random.choice([r[cond_index] for r in results]) 118 | query.conditions.append((cond_index, cond_op, cond_val)) 119 | new_results = self.execute_query(db, query) 120 | if [r[sel_index] for r in new_results] != [r[sel_index] for r in results]: 121 | condition_options.remove(cond_index) 122 | results = new_results 123 | else: 124 | query.conditions.pop() 125 | # sample an aggregation operation 126 | if self.types[sel_index] == 'text': 127 | query.agg_index = Query.agg_ops.index('') 128 | else: 129 | query.agg_index = random.choice(list(range(len(Query.agg_ops)))) 130 | query.sel_index = sel_index 131 | results = self.execute_query(db, query) 132 | return query, results 133 | 134 | def generate_queries(self, db, n=1, max_tries=5, lower=True): 135 | qs = [] 136 | for i in range(n): 137 | n_tries = 0 138 | r = None 139 | while r is None and n_tries < max_tries: 140 | q, r = self.generate_query(db, max_cond=4) 141 | n_tries += 1 142 | if r: 143 | qs.append((q, r)) 144 | return qs 145 | -------------------------------------------------------------------------------- /SQLova_model/world_model.py: -------------------------------------------------------------------------------- 1 | from MISP_SQL.world_model import WorldModel as BaseWorldModel 2 | from MISP_SQL.utils import * 3 | from .sqlova.utils.utils_wikisql import * 4 | 5 | 6 | def apply_dropout(m): 7 | if type(m) == nn.Dropout: 8 | m.train() 9 | 10 | 11 | def cancel_dropout(m): 12 | if type(m) == nn.Dropout: 13 | m.eval() 14 | 15 | 16 | class WorldModel(BaseWorldModel): 17 | def __init__(self, bert_info, semparser, num_options, num_passes=1, dropout_rate=0.0, 18 | bool_structure_question=False): 19 | BaseWorldModel.__init__(self, semparser, num_options, 20 | num_passes=num_passes, dropout_rate=dropout_rate) 21 | 22 | bert_config, model_bert, tokenizer, max_seq_length, num_target_layers = bert_info 23 | self.model_bert = model_bert 24 | self.tokenizer = tokenizer 25 | self.bert_config = bert_config 26 | self.max_seq_length = max_seq_length 27 | self.num_target_layers = num_target_layers 28 | 29 | self.bool_structure_question = bool_structure_question 30 | 31 | def decode_per_pass(self, input_item, dec_beam_size=1, dec_prefix=None, stop_step=None, 32 | avoid_items=None, confirmed_items=None, dropout_rate=0.0, 33 | bool_collect_choices=False, bool_verbal=False): 34 | if len(input_item) == 4: 35 | tb, nlu_t, nlu, hds = input_item 36 | 37 | wemb_n, wemb_h, l_n, l_hpu, l_hs, \ 38 | nlu_tt, t_to_tt_idx, tt_to_t_idx \ 39 | = get_wemb_bert(self.bert_config, self.model_bert, self.tokenizer, nlu_t, hds, self.max_seq_length, 40 | num_out_layers_n=self.num_target_layers, num_out_layers_h=self.num_target_layers) 41 | else: 42 | wemb_n, l_n, wemb_h, l_hpu, l_hs, tb, nlu_t, nlu_tt, tt_to_t_idx, nlu = input_item 43 | 44 | if dropout_rate > 0.0: 45 | self.semparser.train() 46 | self.model_bert.apply(apply_dropout) 47 | 48 | hypotheses = self.semparser.interaction_beam_forward( 49 | wemb_n, l_n, wemb_h, l_hpu, l_hs, tb, nlu_t, nlu_tt, tt_to_t_idx, nlu, 50 | None if dec_beam_size == np.inf else dec_beam_size, 51 | [] if dec_prefix is None else dec_prefix, 52 | stop_step=stop_step, avoid_items=avoid_items, confirmed_items=confirmed_items, 53 | bool_collect_choices=bool_collect_choices, 54 | 
bool_verbal=bool_verbal) 55 | 56 | if stop_step is None and bool_verbal: 57 | for output_idx, hyp in enumerate(hypotheses): 58 | print("Predicted {}-th SQL: {}".format(output_idx, hyp.sql)) 59 | 60 | if dropout_rate > 0.0: 61 | self.semparser.eval() 62 | self.model_bert.apply(cancel_dropout) 63 | 64 | return hypotheses 65 | 66 | def apply_pos_feedback(self, semantic_unit, dec_seq, dec_prefix): 67 | semantic_tag = semantic_unit[0] 68 | dec_seq_idx = semantic_unit[-1] 69 | 70 | # confirmed answer 71 | if semantic_tag in {WHERE_COL}: 72 | confirm_idx = semantic_unit[1][-1] 73 | self.confirmed_items[dec_seq_idx].add(confirm_idx) 74 | return dec_prefix 75 | else: 76 | # for SELECT_COL/AGG, WHERE_OP, WHERE_VAL, finalize the verified value 77 | try: 78 | assert dec_prefix == dec_seq[:dec_seq_idx] 79 | except AssertionError: 80 | print("AssertionError in apply_pos_feedback:\ndec_seq[:dec_seq_idx]={}\ndec_prefix={}".format( 81 | dec_seq[:dec_seq_idx], dec_prefix)) 82 | return dec_seq[:(dec_seq_idx + 1)] 83 | 84 | def apply_neg_feedback(self, semantic_unit, dec_seq, dec_prefix): 85 | dec_seq_idx = semantic_unit[-1] 86 | semantic_tag = semantic_unit[0] 87 | 88 | if semantic_tag in {SELECT_COL, WHERE_COL}: 89 | drop_idx = semantic_unit[1][-1] 90 | self.avoid_items[dec_seq_idx].add(drop_idx) 91 | elif semantic_tag == SELECT_AGG: 92 | drop_idx = semantic_unit[2][-1] 93 | self.avoid_items[dec_seq_idx].add(drop_idx) 94 | elif semantic_tag == WHERE_OP: 95 | drop_idx = semantic_unit[2][-1] 96 | self.avoid_items[dec_seq_idx].add(drop_idx) 97 | else: 98 | assert semantic_tag == WHERE_VAL # re-decode 99 | st, ed = semantic_unit[3][:2] 100 | self.avoid_items[dec_seq_idx].add((st, ed)) 101 | 102 | return dec_prefix 103 | 104 | def decode_revised_structure(self, semantic_unit, pointer, hyp, input_item, bool_verbal=False): 105 | semantic_tag = semantic_unit[0] 106 | assert semantic_tag != SELECT_COL, "Error: Cannot remove all SELECT_COL!" 107 | 108 | if semantic_tag == WHERE_COL: 109 | print("## WARNING: %s structure changes!" 
% semantic_tag) 110 | dec_seq_idx = semantic_unit[-1] 111 | dec_seq_item = list(hyp.dec_seq[dec_seq_idx]) 112 | hyp.dec_seq[dec_seq_idx] = (dec_seq_item[0], 0, []) 113 | hyp = self.decode(input_item, dec_beam_size=1, 114 | dec_prefix=hyp.dec_seq[:(dec_seq_idx + 1)], 115 | avoid_items=self.avoid_items, 116 | confirmed_items=self.confirmed_items, 117 | bool_verbal=bool_verbal)[0] 118 | return pointer, hyp 119 | else: 120 | return pointer + 1, hyp 121 | 122 | def refresh_decoding(self, input_item, dec_prefix, old_hyp, semantic_unit, 123 | pointer, sel_none_of_above, user_selections, bool_verbal=False): 124 | semantic_tag = semantic_unit[0] 125 | dec_seq_idx = semantic_unit[-1] 126 | 127 | if self.bool_structure_question and (sel_none_of_above + 1) in user_selections: 128 | assert semantic_tag == WHERE_COL 129 | 130 | dec_seq_idx = semantic_unit[-1] 131 | dec_seq_item = list(old_hyp.dec_seq[dec_seq_idx]) 132 | dec_prefix.append((dec_seq_item[0], 0, [])) 133 | hyp = self.decode(input_item, dec_beam_size=1, 134 | dec_prefix=dec_prefix, 135 | avoid_items=self.avoid_items, 136 | confirmed_items=self.confirmed_items, 137 | bool_verbal=bool_verbal)[0] 138 | print("DEBUG: new_hyp.sql = {}\n".format(hyp.sql)) 139 | 140 | start_pos = pointer 141 | 142 | else: 143 | try: 144 | partial_hyp = self.decode( 145 | input_item, dec_prefix=dec_prefix, 146 | avoid_items=self.avoid_items, 147 | confirmed_items=self.confirmed_items, 148 | stop_step=dec_seq_idx, 149 | bool_verbal=bool_verbal)[0] 150 | except Exception: # e.g., when any WHERE_COL is redundant 151 | start_pos, hyp = self.decode_revised_structure( 152 | semantic_unit, pointer, old_hyp, input_item, 153 | bool_verbal=bool_verbal) 154 | else: 155 | # the following finds the next pointer to validate 156 | _, cand_pointers = semantic_unit_segment(partial_hyp.tag_seq) 157 | last_pointer = cand_pointers[-1] 158 | if last_pointer < pointer: # structure changed, e.g., #cols reduce 159 | start_pos = last_pointer + 1 160 | else: 161 | start_pos = pointer + 1 162 | 163 | # generate a new hypothesis after interaction 164 | hyp = self.decode( 165 | input_item, dec_prefix=dec_prefix, 166 | avoid_items=self.avoid_items, 167 | confirmed_items=self.confirmed_items, 168 | bool_verbal=bool_verbal)[0] 169 | 170 | return start_pos, hyp 171 | -------------------------------------------------------------------------------- /gpu-py3.yml: -------------------------------------------------------------------------------- 1 | name: gpu-py3 2 | channels: 3 | - pytorch 4 | - anaconda 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - backcall=0.1.0=py37_0 9 | - blas=1.0=mkl 10 | - ca-certificates=2019.10.16=0 11 | - certifi=2019.9.11=py37_0 12 | - cffi=1.12.3=py37h2e261b9_0 13 | - cudatoolkit=10.0.130=0 14 | - decorator=4.4.1=py_0 15 | - freetype=2.9.1=h8a8886c_1 16 | - intel-openmp=2019.4=243 17 | - ipython_genutils=0.2.0=py37_0 18 | - jedi=0.15.1=py37_0 19 | - jpeg=9b=h024ee3a_2 20 | - libedit=3.1.20181209=hc058e9b_0 21 | - libffi=3.2.1=hd88cf55_4 22 | - libgcc-ng=8.2.0=hdf63c60_1 23 | - libgfortran-ng=7.3.0=hdf63c60_0 24 | - libpng=1.6.37=hbc83047_0 25 | - libstdcxx-ng=8.2.0=hdf63c60_1 26 | - libtiff=4.0.10=h2733197_2 27 | - line_profiler=2.1.2=py37h14c3975_0 28 | - memory_profiler=0.55.0=py37_0 29 | - mkl=2019.4=243 30 | - mkl-service=2.0.2=py37h7b6447c_0 31 | - mkl_fft=1.0.14=py37ha843d7b_0 32 | - mkl_random=1.0.2=py37hd81dba3_0 33 | - ncurses=6.1=he6710b0_1 34 | - ninja=1.9.0=py37hfd86e86_0 35 | - olefile=0.46=py37_0 36 | - openssl=1.1.1h=h7b6447c_0 37 | - 
parso=0.5.1=py_0 38 | - pexpect=4.7.0=py37_0 39 | - pillow=6.1.0=py37h34e0f95_0 40 | - pip=19.1.1=py37_0 41 | - prompt_toolkit=2.0.10=py_0 42 | - pycparser=2.19=py37_0 43 | - python=3.7.3=h0371630_0 44 | - pytorch=1.2.0=py3.7_cuda10.0.130_cudnn7.6.2_0 45 | - readline=7.0=h7b6447c_5 46 | - setuptools=41.0.1=py37_0 47 | - sqlite=3.28.0=h7b6447c_0 48 | - tk=8.6.8=hbc83047_0 49 | - torchvision=0.4.0=py37_cu100 50 | - traitlets=4.3.3=py37_0 51 | - wheel=0.33.4=py37_0 52 | - xz=5.2.4=h14c3975_4 53 | - zlib=1.2.11=h7b6447c_3 54 | - zstd=1.3.7=h0b5b093_0 55 | - pip: 56 | - babel==2.7.0 57 | - backports-csv==1.0.7 58 | - boto==2.49.0 59 | - boto3==1.9.156 60 | - botocore==1.12.156 61 | - chardet==3.0.4 62 | - click==7.0 63 | - cpython==0.0.5 64 | - cycler==0.10.0 65 | - defusedxml==0.6.0 66 | - docopt==0.6.2 67 | - docutils==0.14 68 | - et-xmlfile==1.0.1 69 | - gensim==3.7.3 70 | - idna==2.8 71 | - ipython==7.10.1 72 | - ipython-genutils==0.2.0 73 | - jdcal==1.4.1 74 | - jmespath==0.9.4 75 | - joblib==0.13.2 76 | - kiwisolver==1.1.0 77 | - matplotlib==3.1.1 78 | - nltk==3.4.1 79 | - numpy==1.16.3 80 | - odfpy==1.4.0 81 | - openpyxl==2.6.3 82 | - pickleshare==0.7.5 83 | - prettytable==0.7.2 84 | - progressbar==2.5 85 | - prompt-toolkit==3.0.2 86 | - psutil==5.6.2 87 | - ptyprocess==0.6.0 88 | - pygments==2.5.2 89 | - pymongo==3.9.0 90 | - pymysql==0.9.3 91 | - pyparsing==2.4.2 92 | - python-dateutil==2.8.0 93 | - pytimeparse==1.1.8 94 | - pytz==2019.2 95 | - pyyaml==5.1.2 96 | - records==0.5.2 97 | - regex==2019.12.20 98 | - requests==2.22.0 99 | - s3transfer==0.2.0 100 | - sacremoses==0.0.35 101 | - scikit-learn==0.21.2 102 | - scipy==1.3.0 103 | - sentencepiece==0.1.85 104 | - six==1.12.0 105 | - smart-open==1.8.3 106 | - sqlalchemy==1.3.7 107 | - sqlparse==0.3.0 108 | - tablib==0.13.0 109 | - theano==1.0.4 110 | - tqdm==4.32.1 111 | - transformers==2.3.0 112 | - urllib3==1.25.3 113 | - wcwidth==0.1.7 114 | - xlrd==1.2.0 115 | - xlwt==1.3.0 116 | prefix: /home/yao.470/anaconda2/envs/gpu-py3 117 | 118 | -------------------------------------------------------------------------------- /overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunlab-osu/MISP/7870566ab6b9e121d648478968367bc79c12f7ef/overview.png -------------------------------------------------------------------------------- /scripts/editsql/bin_user.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | source activate gpu-py3 4 | export CUDA_VISIBLE_DEVICES=0 5 | 6 | GLOVE_PATH="EditSQL/word_emb/glove.840B.300d.txt" 7 | LOGDIR="EditSQL/logs_clean/logs_spider_editsql_10p" 8 | 9 | # online setting 10 | UPDATE_ITER=1000 11 | 12 | SUP="bin_feedback" 13 | DATA_SEED=0 # 0, 10, 100 14 | ST=0 15 | END=-1 16 | 17 | OUTPUT_PATH=${SUP}_ITER${UPDATE_ITER}_DATASEED${DATA_SEED} 18 | echo ${OUTPUT_PATH} 19 | 20 | python3 interaction_editsql.py \ 21 | --raw_train_filename="EditSQL/data_clean/spider_data_removefrom/train_10p.pkl" \ 22 | --raw_validation_filename="EditSQL/data_clean/spider_data_removefrom/dev.pkl" \ 23 | --database_schema_filename="EditSQL/data_clean/spider_data_removefrom/tables.json" \ 24 | --embedding_filename=$GLOVE_PATH \ 25 | --data_directory="EditSQL/data_clean/processed_data_spider_removefrom_10p" \ 26 | --raw_data_directory="EditSQL/data_clean/spider" \ 27 | --input_key="utterance" \ 28 | --use_schema_encoder=1 \ 29 | --use_schema_attention=1 \ 30 | --use_encoder_attention=1 \ 31 | --use_schema_self_attention=1 \ 32 | --use_schema_encoder_2=1 \ 33 | --use_bert=1 \ 34 | --bert_type_abb=uS \ 35 | --fine_tune_bert=1 \ 36 | --interaction_level=1 \ 37 | --reweight_batch=1 \ 38 | --freeze=1 \ 39 | --logdir=$LOGDIR \ 40 | --evaluate=1 --train=1 \ 41 | --evaluate_split="valid" \ 42 | --use_predicted_queries=1 \ 43 | --eval_maximum_sql_length=100 \ 44 | --job="online_learning" \ 45 | --setting="online_pretrain_10p" --supervision=${SUP} --data_seed=${DATA_SEED} \ 46 | --start_iter=${ST} --end_iter=${END} --ask_structure=1 \ 47 | --output_path ${LOGDIR}/records_${OUTPUT_PATH}.json \ 48 | --update_iter ${UPDATE_ITER} \ 49 | > ${LOGDIR}/records_${OUTPUT_PATH}.output 2>&1 & -------------------------------------------------------------------------------- /scripts/editsql/bin_user_expert.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | source activate gpu-py3 4 | export CUDA_VISIBLE_DEVICES=0 5 | 6 | GLOVE_PATH="EditSQL/word_emb/glove.840B.300d.txt" 7 | LOGDIR="EditSQL/logs_clean/logs_spider_editsql_10p" 8 | 9 | # online setting 10 | UPDATE_ITER=1000 11 | 12 | SUP="bin_feedback_expert" 13 | DATA_SEED=0 # 0, 10, 100 14 | ST=0 15 | END=-1 16 | 17 | OUTPUT_PATH=${SUP}_ITER${UPDATE_ITER}_DATASEED${DATA_SEED} 18 | echo ${OUTPUT_PATH} 19 | 20 | python3 interaction_editsql.py \ 21 | --raw_train_filename="EditSQL/data_clean/spider_data_removefrom/train_10p.pkl" \ 22 | --raw_validation_filename="EditSQL/data_clean/spider_data_removefrom/dev.pkl" \ 23 | --database_schema_filename="EditSQL/data_clean/spider_data_removefrom/tables.json" \ 24 | --embedding_filename=$GLOVE_PATH \ 25 | --data_directory="EditSQL/data_clean/processed_data_spider_removefrom_10p" \ 26 | --raw_data_directory="EditSQL/data_clean/spider" \ 27 | --input_key="utterance" \ 28 | --use_schema_encoder=1 \ 29 | --use_schema_attention=1 \ 30 | --use_encoder_attention=1 \ 31 | --use_schema_self_attention=1 \ 32 | --use_schema_encoder_2=1 \ 33 | --use_bert=1 \ 34 | --bert_type_abb=uS \ 35 | --fine_tune_bert=1 \ 36 | --interaction_level=1 \ 37 | --reweight_batch=1 \ 38 | --freeze=1 \ 39 | --logdir=$LOGDIR \ 40 | --evaluate=1 --train=1 \ 41 | --evaluate_split="valid" \ 42 | --use_predicted_queries=1 \ 43 | --eval_maximum_sql_length=100 \ 44 | --job="online_learning" \ 45 | --setting="online_pretrain_10p" --supervision=${SUP} --data_seed=${DATA_SEED} \ 46 | --start_iter=${ST} --end_iter=${END} --ask_structure=1 \ 47 | --output_path ${LOGDIR}/records_${OUTPUT_PATH}.json \ 48 | --update_iter ${UPDATE_ITER} \ 49 | > ${LOGDIR}/records_${OUTPUT_PATH}.output 2>&1 & -------------------------------------------------------------------------------- /scripts/editsql/full_expert.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | source activate gpu-py3 4 | export CUDA_VISIBLE_DEVICES=0 5 | 6 | GLOVE_PATH="EditSQL/word_emb/glove.840B.300d.txt" 7 | LOGDIR="EditSQL/logs_clean/logs_spider_editsql_10p" 8 | 9 | # online setting 10 | NUM_OP="3" 11 | ED="prob=0.995" 12 | UPDATE_ITER=1000 13 | 14 | SUP="full_expert" 15 | DATA_SEED=0 # 0, 10, 100 16 | ST=0 17 | END=-1 18 | 19 | OUTPUT_PATH=${SUP}_ITER${UPDATE_ITER}_DATASEED${DATA_SEED} 20 | echo ${OUTPUT_PATH} 21 | 22 | python3 interaction_editsql.py \ 23 | --raw_train_filename="EditSQL/data_clean/spider_data_removefrom/train_10p.pkl" \ 24 | --raw_validation_filename="EditSQL/data_clean/spider_data_removefrom/dev.pkl" \ 25 | --database_schema_filename="EditSQL/data_clean/spider_data_removefrom/tables.json" \ 26 | --embedding_filename=$GLOVE_PATH \ 27 | --data_directory="EditSQL/data_clean/processed_data_spider_removefrom_10p" \ 28 | --raw_data_directory="EditSQL/data_clean/spider" \ 29 | --input_key="utterance" \ 30 | --use_schema_encoder=1 \ 31 | --use_schema_attention=1 \ 32 | --use_encoder_attention=1 \ 33 | --use_schema_self_attention=1 \ 34 | --use_schema_encoder_2=1 \ 35 | --use_bert=1 \ 36 | --bert_type_abb=uS \ 37 | --fine_tune_bert=1 \ 38 | --interaction_level=1 \ 39 | --reweight_batch=1 \ 40 | --freeze=1 \ 41 | --logdir=$LOGDIR \ 42 | --evaluate=1 --train=1 \ 43 | --evaluate_split="valid" \ 44 | --use_predicted_queries=1 \ 45 | --eval_maximum_sql_length=100 \ 46 | --job="online_learning" \ 47 | --setting="online_pretrain_10p" --supervision=${SUP} --data_seed=${DATA_SEED} \ 48 | --start_iter=${ST} --end_iter=${END} --ask_structure=1 \ 49 | --output_path ${LOGDIR}/records_${OUTPUT_PATH}.json \ 50 | --update_iter ${UPDATE_ITER} \ 51 | > ${LOGDIR}/records_${OUTPUT_PATH}.output 2>&1 & -------------------------------------------------------------------------------- /scripts/editsql/misp_neil.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | source activate gpu-py3 4 | export CUDA_VISIBLE_DEVICES=0 5 | 6 | GLOVE_PATH="EditSQL/word_emb/glove.840B.300d.txt" 7 | LOGDIR="EditSQL/logs_clean/logs_spider_editsql_10p" 8 | 9 | # online setting 10 | NUM_OP="3" 11 | ED="prob=0.995" 12 | UPDATE_ITER=1000 13 | 14 | SUP="misp_neil" 15 | DATA_SEED=0 # 0, 10, 100 16 | ST=0 17 | END=-1 18 | 19 | OUTPUT_PATH=${SUP}_OP${NUM_OP}_ED${ED}_ITER${UPDATE_ITER}_DATASEED${DATA_SEED} 20 | echo ${OUTPUT_PATH} 21 | 22 | python3 interaction_editsql.py \ 23 | --raw_train_filename="EditSQL/data_clean/spider_data_removefrom/train_10p.pkl" \ 24 | --raw_validation_filename="EditSQL/data_clean/spider_data_removefrom/dev.pkl" \ 25 | --database_schema_filename="EditSQL/data_clean/spider_data_removefrom/tables.json" \ 26 | --embedding_filename=$GLOVE_PATH \ 27 | --data_directory="EditSQL/data_clean/processed_data_spider_removefrom_10p" \ 28 | --raw_data_directory="EditSQL/data_clean/spider" \ 29 | --input_key="utterance" \ 30 | --use_schema_encoder=1 \ 31 | --use_schema_attention=1 \ 32 | --use_encoder_attention=1 \ 33 | --use_schema_self_attention=1 \ 34 | --use_schema_encoder_2=1 \ 35 | --use_bert=1 \ 36 | --bert_type_abb=uS \ 37 | --fine_tune_bert=1 \ 38 | --interaction_level=1 \ 39 | --reweight_batch=1 \ 40 | --freeze=1 \ 41 | --logdir=$LOGDIR \ 42 | --evaluate=1 --train=1 \ 43 | --evaluate_split="valid" \ 44 | --use_predicted_queries=1 \ 45 | --eval_maximum_sql_length=100 \ 46 | --job="online_learning" \ 47 | --num_options=${NUM_OP} --err_detector=${ED} --friendly_agent=0 --user="sim" \ 48 | --setting="online_pretrain_10p" --supervision=${SUP} --data_seed=${DATA_SEED} \ 49 | --start_iter=${ST} --end_iter=${END} --ask_structure=1 \ 50 | --output_path ${LOGDIR}/records_${OUTPUT_PATH}.json \ 51 | --update_iter ${UPDATE_ITER} \ 52 | > ${LOGDIR}/records_${OUTPUT_PATH}.output 2>&1 & -------------------------------------------------------------------------------- /scripts/editsql/misp_neil_perfect.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | source activate gpu-py3 4 | export CUDA_VISIBLE_DEVICES=0 5 | 6 | GLOVE_PATH="EditSQL/word_emb/glove.840B.300d.txt" 7 | LOGDIR="EditSQL/logs_clean/logs_spider_editsql_10p" 8 | 9 | # online setting 10 | NUM_OP="3" 11 | ED="perfect" 12 | UPDATE_ITER=1000 13 | 14 | SUP="misp_neil_perfect" 15 | DATA_SEED=0 # 0, 10, 100 16 | ST=0 17 | END=-1 18 | 19 | OUTPUT_PATH=${SUP}_OP${NUM_OP}_ED${ED}_ITER${UPDATE_ITER}_DATASEED${DATA_SEED} 20 | echo ${OUTPUT_PATH} 21 | 22 | python3 interaction_editsql.py \ 23 | --raw_train_filename="EditSQL/data_clean/spider_data_removefrom/train_10p.pkl" \ 24 | --raw_validation_filename="EditSQL/data_clean/spider_data_removefrom/dev.pkl" \ 25 | --database_schema_filename="EditSQL/data_clean/spider_data_removefrom/tables.json" \ 26 | --embedding_filename=$GLOVE_PATH \ 27 | --data_directory="EditSQL/data_clean/processed_data_spider_removefrom_10p" \ 28 | --raw_data_directory="EditSQL/data_clean/spider" \ 29 | --input_key="utterance" \ 30 | --use_schema_encoder=1 \ 31 | --use_schema_attention=1 \ 32 | --use_encoder_attention=1 \ 33 | --use_schema_self_attention=1 \ 34 | --use_schema_encoder_2=1 \ 35 | --use_bert=1 \ 36 | --bert_type_abb=uS \ 37 | --fine_tune_bert=1 \ 38 | --interaction_level=1 \ 39 | --reweight_batch=1 \ 40 | --freeze=1 \ 41 | --logdir=$LOGDIR \ 42 | --evaluate=1 --train=1 \ 43 | --evaluate_split="valid" \ 44 | --use_predicted_queries=1 \ 45 | --eval_maximum_sql_length=100 \ 46 | --job="online_learning" \ 47 | --num_options=${NUM_OP} --err_detector=${ED} --friendly_agent=0 --user="gold_sim" \ 48 | --setting="online_pretrain_10p" --supervision=${SUP} --data_seed=${DATA_SEED} \ 49 | --start_iter=${ST} --end_iter=${END} --ask_structure=1 \ 50 | --output_path ${LOGDIR}/records_${OUTPUT_PATH}.json \ 51 | --update_iter ${UPDATE_ITER} \ 52 | > ${LOGDIR}/records_${OUTPUT_PATH}.output 2>&1 & -------------------------------------------------------------------------------- /scripts/editsql/misp_neil_pos.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | source activate gpu-py3 4 | export CUDA_VISIBLE_DEVICES=0 5 | 6 | GLOVE_PATH="EditSQL/word_emb/glove.840B.300d.txt" 7 | LOGDIR="EditSQL/logs_clean/logs_spider_editsql_10p" 8 | 9 | # online setting 10 | NUM_OP="3" 11 | ED="prob=0.995" 12 | UPDATE_ITER=1000 13 | 14 | SUP="misp_neil_pos" 15 | DATA_SEED=0 # 0, 10, 100 16 | ST=0 17 | END=-1 18 | 19 | OUTPUT_PATH=${SUP}_OP${NUM_OP}_ED${ED}_ITER${UPDATE_ITER}_DATASEED${DATA_SEED} 20 | echo ${OUTPUT_PATH} 21 | 22 | python3 interaction_editsql.py \ 23 | --raw_train_filename="EditSQL/data_clean/spider_data_removefrom/train_10p.pkl" \ 24 | --raw_validation_filename="EditSQL/data_clean/spider_data_removefrom/dev.pkl" \ 25 | --database_schema_filename="EditSQL/data_clean/spider_data_removefrom/tables.json" \ 26 | --embedding_filename=$GLOVE_PATH \ 27 | --data_directory="EditSQL/data_clean/processed_data_spider_removefrom_10p" \ 28 | --raw_data_directory="EditSQL/data_clean/spider" \ 29 | --input_key="utterance" \ 30 | --use_schema_encoder=1 \ 31 | --use_schema_attention=1 \ 32 | --use_encoder_attention=1 \ 33 | --use_schema_self_attention=1 \ 34 | --use_schema_encoder_2=1 \ 35 | --use_bert=1 \ 36 | --bert_type_abb=uS \ 37 | --fine_tune_bert=1 \ 38 | --interaction_level=1 \ 39 | --reweight_batch=1 \ 40 | --freeze=1 \ 41 | --logdir=$LOGDIR \ 42 | --evaluate=1 --train=1 \ 43 | --evaluate_split="valid" \ 44 | --use_predicted_queries=1 \ 45 | --eval_maximum_sql_length=100 \ 46 | --job="online_learning" \ 47 | --num_options=${NUM_OP} --err_detector=${ED} --friendly_agent=0 --user="sim" \ 48 | --setting="online_pretrain_10p" --supervision=${SUP} --data_seed=${DATA_SEED} \ 49 | --start_iter=${ST} --end_iter=${END} --ask_structure=1 \ 50 | --output_path ${LOGDIR}/records_${OUTPUT_PATH}.json \ 51 | --update_iter ${UPDATE_ITER} \ 52 | > ${LOGDIR}/records_${OUTPUT_PATH}.output 2>&1 & -------------------------------------------------------------------------------- /scripts/editsql/pretrain.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | source activate gpu-py3 4 | export CUDA_VISIBLE_DEVICES=0 5 | 6 | 7 | SETTING="_10p" # set to empty string '' for full training set, '_10p' for using only 10% of training set 8 | 9 | GLOVE_PATH="EditSQL/word_emb/glove.840B.300d.txt" # you need to change this 10 | LOGDIR="EditSQL/logs_clean/logs_spider_editsql${SETTING}/pretraining" 11 | 12 | 13 | python3 EditSQL_run.py --raw_train_filename="EditSQL/data_clean/spider_data_removefrom/train${SETTING}.pkl" \ 14 | --raw_validation_filename="EditSQL/data_clean/spider_data_removefrom/dev.pkl" \ 15 | --database_schema_filename="EditSQL/data_clean/spider_data_removefrom/tables.json" \ 16 | --embedding_filename=$GLOVE_PATH \ 17 | --data_directory="EditSQL/data_clean/processed_data_spider_removefrom${SETTING}" \ 18 | --input_key="utterance" \ 19 | --use_schema_encoder=1 \ 20 | --use_schema_attention=1 \ 21 | --use_encoder_attention=1 \ 22 | --use_schema_self_attention=1 \ 23 | --use_schema_encoder_2=1 \ 24 | --use_bert=1 \ 25 | --bert_type_abb=uS \ 26 | --fine_tune_bert=1 \ 27 | --interaction_level=1 \ 28 | --reweight_batch=1 \ 29 | --freeze=1 \ 30 | --train=1 \ 31 | --logdir=$LOGDIR \ 32 | --evaluate=1 \ 33 | --evaluate_split="valid" \ 34 | --initial_patience=5 \ 35 | --use_predicted_queries=1 36 | -------------------------------------------------------------------------------- /scripts/editsql/self_train_0.5.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | source activate gpu-py3 4 | export CUDA_VISIBLE_DEVICES=0 5 | 6 | GLOVE_PATH="EditSQL/word_emb/glove.840B.300d.txt" 7 | LOGDIR="EditSQL/logs_clean/logs_spider_editsql_10p" 8 | 9 | # online setting 10 | NUM_OP="3" 11 | ED="prob=0.995" 12 | UPDATE_ITER=1000 13 | 14 | SUP="self_train_0.5" 15 | DATA_SEED=0 # 0, 10, 100 16 | ST=0 17 | END=-1 18 | 19 | OUTPUT_PATH=${SUP}_ITER${UPDATE_ITER}_DATASEED${DATA_SEED} 20 | echo ${OUTPUT_PATH} 21 | 22 | python3 interaction_editsql.py \ 23 | --raw_train_filename="EditSQL/data_clean/spider_data_removefrom/train_10p.pkl" \ 24 | --raw_validation_filename="EditSQL/data_clean/spider_data_removefrom/dev.pkl" \ 25 | --database_schema_filename="EditSQL/data_clean/spider_data_removefrom/tables.json" \ 26 | --embedding_filename=$GLOVE_PATH \ 27 | --data_directory="EditSQL/data_clean/processed_data_spider_removefrom_10p" \ 28 | --raw_data_directory="EditSQL/data_clean/spider" \ 29 | --input_key="utterance" \ 30 | --use_schema_encoder=1 \ 31 | --use_schema_attention=1 \ 32 | --use_encoder_attention=1 \ 33 | --use_schema_self_attention=1 \ 34 | --use_schema_encoder_2=1 \ 35 | --use_bert=1 \ 36 | --bert_type_abb=uS \ 37 | --fine_tune_bert=1 \ 38 | --interaction_level=1 \ 39 | --reweight_batch=1 \ 40 | --freeze=1 \ 41 | --logdir=$LOGDIR \ 42 | --evaluate=1 --train=1 \ 43 | --evaluate_split="valid" \ 44 | --use_predicted_queries=1 \ 45 | --eval_maximum_sql_length=100 \ 46 | --job="online_learning" \ 47 | --setting="online_pretrain_10p" --supervision=${SUP} --data_seed=${DATA_SEED} \ 48 | --start_iter=${ST} --end_iter=${END} --ask_structure=1 \ 49 | --output_path ${LOGDIR}/records_${OUTPUT_PATH}.json \ 50 | --update_iter ${UPDATE_ITER} \ 51 | > ${LOGDIR}/records_${OUTPUT_PATH}.output 2>&1 & -------------------------------------------------------------------------------- /scripts/editsql/test.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | source activate gpu-py3 4 | export CUDA_VISIBLE_DEVICES=0 5 | 6 | 7 | SETTING="_10p" # "", "_10p", .. 8 | 9 | GLOVE_PATH="EditSQL/word_emb/glove.840B.300d.txt" # you need to change this 10 | LOGDIR="EditSQL/logs_clean/logs_spider_editsql${SETTING}" 11 | 12 | 13 | python3 EditSQL_run.py --raw_train_filename="EditSQL/data_clean/spider_data_removefrom/train${SETTING}.pkl" \ 14 | --raw_validation_filename="EditSQL/data_clean/spider_data_removefrom/dev.pkl" \ 15 | --database_schema_filename="EditSQL/data_clean/spider_data_removefrom/tables.json" \ 16 | --embedding_filename=$GLOVE_PATH \ 17 | --data_directory="EditSQL/data_clean/processed_data_spider_removefrom${SETTING}" \ 18 | --input_key="utterance" \ 19 | --use_schema_encoder=1 \ 20 | --use_schema_attention=1 \ 21 | --use_encoder_attention=1 \ 22 | --use_schema_self_attention=1 \ 23 | --use_schema_encoder_2=1 \ 24 | --use_bert=1 \ 25 | --bert_type_abb=uS \ 26 | --fine_tune_bert=1 \ 27 | --interaction_level=1 \ 28 | --reweight_batch=1 \ 29 | --freeze=1 \ 30 | --logdir=$LOGDIR \ 31 | --evaluate=1 \ 32 | --evaluate_split="valid" \ 33 | --use_predicted_queries=1 \ 34 | --save_file="$LOGDIR/model_best.pt" 35 | -------------------------------------------------------------------------------- /scripts/editsql/test_with_interaction.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | source activate gpu-py3 4 | export CUDA_VISIBLE_DEVICES=0 5 | 6 | SETTING="online_pretrain_10p" # online_pretrain_10p, full_train 7 | if [ ${SETTING} == "full_train" ] 8 | then 9 | LOGDIR="EditSQL/logs_clean/logs_spider_editsql" 10 | else 11 | LOGDIR="EditSQL/logs_clean/logs_spider_editsql_10p" 12 | fi 13 | 14 | GLOVE_PATH="EditSQL/word_emb/glove.840B.300d.txt" 15 | 16 | # online setting 17 | NUM_OP="3" 18 | ED="prob=0.995" 19 | 20 | OUTPUT_PATH=test_DATAdev_OP${NUM_OP}_ED${ED}_USERsim_SETTING${SETTING} 21 | echo ${OUTPUT_PATH} 22 | 23 | python3 interaction_editsql.py \ 24 | --raw_train_filename="EditSQL/data_clean/spider_data_removefrom/train_10p.pkl" \ 25 | --raw_validation_filename="EditSQL/data_clean/spider_data_removefrom/dev.pkl" \ 26 | --database_schema_filename="EditSQL/data_clean/spider_data_removefrom/tables.json" \ 27 | --embedding_filename=$GLOVE_PATH \ 28 | --data_directory="EditSQL/data_clean/processed_data_spider_removefrom_10p" \ 29 | --raw_data_directory="EditSQL/data_clean/spider" \ 30 | --input_key="utterance" \ 31 | --use_schema_encoder=1 \ 32 | --use_schema_attention=1 \ 33 | --use_encoder_attention=1 \ 34 | --use_schema_self_attention=1 \ 35 | --use_schema_encoder_2=1 \ 36 | --use_bert=1 \ 37 | --bert_type_abb=uS \ 38 | --fine_tune_bert=1 \ 39 | --interaction_level=1 \ 40 | --reweight_batch=1 \ 41 | --freeze=1 \ 42 | --logdir=$LOGDIR \ 43 | --evaluate=1 --train=0 \ 44 | --evaluate_split="valid" \ 45 | --use_predicted_queries=1 \ 46 | --eval_maximum_sql_length=100 \ 47 | --job="test_w_interaction" \ 48 | --num_options=${NUM_OP} --err_detector=${ED} --friendly_agent=0 --user="sim" \ 49 | --setting=${SETTING} --ask_structure=1 \ 50 | --output_path ${LOGDIR}/records_${OUTPUT_PATH}.json \ 51 | > ${LOGDIR}/records_${OUTPUT_PATH}.output 2>&1 & -------------------------------------------------------------------------------- /scripts/sqlova/bin_user.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source activate gpu-py3 4 | export CUDA_VISIBLE_DEVICES="0" # leave it blank for CPU or fill in GPU ids 5 | 6 | DATE=$(date +'%m%d%H%M%S') 7 | 8 | # online setting 9 | SETTING="online_pretrain_1p" # online_pretrain_Xp 10 | SUP="bin_feedback" 11 | 12 | DATASEED=0 # 0, 10, 100 13 | START_ITER=0 14 | END_ITER=-1 15 | AUTO_ITER=1 16 | ITER=1000 17 | BATCH_SIZE=16 18 | 19 | # path setting 20 | LOG_DIR="SQLova_model/logs" # save training logs 21 | MODEL_DIR="SQLova_model/checkpoints_${SETTING}/" # model dir 22 | 23 | OUTPUT_PATH=${SUP}_SETTING${SETTING}_ITER${ITER}_DATASEED${DATASEED} 24 | echo ${OUTPUT_PATH} 25 | 26 | python interaction_sqlova.py --job online_learning --model_dir ${MODEL_DIR} --data online \ 27 | --lr_bert 5e-5 --setting ${SETTING} \ 28 | --data_seed ${DATASEED} --supervision ${SUP} \ 29 | --update_iter ${ITER} --start_iter ${START_ITER} --end_iter ${END_ITER} --auto_iter ${AUTO_ITER} \ 30 | --bS ${BATCH_SIZE} \ 31 | --output_path ${LOG_DIR}/records_${OUTPUT_PATH}.json \ 32 | > ${LOG_DIR}/records_${OUTPUT_PATH}.${DATE}.output 2>&1 & 33 | 34 | 35 | -------------------------------------------------------------------------------- /scripts/sqlova/bin_user_expert.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source activate gpu-py3 4 | export CUDA_VISIBLE_DEVICES="0" # leave it blank for CPU or fill in GPU ids 5 | 6 | DATE=$(date +'%m%d%H%M%S') 7 | 8 | # online setting 9 | SETTING="online_pretrain_1p" # online_pretrain_Xp 10 | 
SUP="bin_feedback_expert" 11 | 12 | DATASEED=0 # 0, 10, 100 13 | START_ITER=0 14 | END_ITER=-1 15 | AUTO_ITER=1 16 | ITER=1000 17 | BATCH_SIZE=16 18 | 19 | # path setting 20 | LOG_DIR="SQLova_model/logs" # save training logs 21 | MODEL_DIR="SQLova_model/checkpoints_${SETTING}/" # model dir 22 | 23 | OUTPUT_PATH=${SUP}_SETTING${SETTING}_ITER${ITER}_DATASEED${DATASEED} 24 | echo ${OUTPUT_PATH} 25 | 26 | python interaction_sqlova.py --job online_learning --model_dir ${MODEL_DIR} --data online \ 27 | --lr_bert 5e-5 --setting ${SETTING} \ 28 | --data_seed ${DATASEED} --supervision ${SUP} \ 29 | --update_iter ${ITER} --start_iter ${START_ITER} --end_iter ${END_ITER} --auto_iter ${AUTO_ITER} \ 30 | --bS ${BATCH_SIZE} \ 31 | --output_path ${LOG_DIR}/records_${OUTPUT_PATH}.json \ 32 | > ${LOG_DIR}/records_${OUTPUT_PATH}.${DATE}.output 2>&1 & 33 | 34 | 35 | -------------------------------------------------------------------------------- /scripts/sqlova/data_preprocess.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source activate gpu-py3 4 | export CUDA_VISIBLE_DEVICES="0" # leave it blank for CPU or fill in GPU ids 5 | 6 | LOG_DIR="SQLova_model/logs" 7 | 8 | python3 SQLova_train.py --job data_preprocess \ 9 | --bert_type_abb uS --max_seq_leng 222 \ 10 | > ${LOG_DIR}/data_preprocess.log 2>&1 & 11 | -------------------------------------------------------------------------------- /scripts/sqlova/full_expert.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source activate gpu-py3 4 | export CUDA_VISIBLE_DEVICES="0" # leave it blank for CPU or fill in GPU ids 5 | 6 | DATE=$(date +'%m%d%H%M%S') 7 | 8 | # online setting 9 | SETTING="online_pretrain_1p" # online_pretrain_Xp 10 | SUP="full_expert" 11 | 12 | DATASEED=0 # 0, 10, 100 13 | START_ITER=0 14 | END_ITER=-1 15 | AUTO_ITER=1 16 | ITER=1000 17 | BATCH_SIZE=16 18 | 19 | # path setting 20 | LOG_DIR="SQLova_model/logs" # save training logs 21 | MODEL_DIR="SQLova_model/checkpoints_${SETTING}/" # model dir 22 | 23 | OUTPUT_PATH=${SUP}_SETTING${SETTING}_ITER${ITER}_DATASEED${DATASEED} 24 | echo ${OUTPUT_PATH} 25 | 26 | python interaction_sqlova.py --job online_learning --model_dir ${MODEL_DIR} --data online \ 27 | --lr_bert 5e-5 --setting ${SETTING} \ 28 | --data_seed ${DATASEED} --supervision ${SUP} \ 29 | --update_iter ${ITER} --start_iter ${START_ITER} --end_iter ${END_ITER} --auto_iter ${AUTO_ITER} \ 30 | --bS ${BATCH_SIZE} \ 31 | --output_path ${LOG_DIR}/records_${OUTPUT_PATH}.json \ 32 | > ${LOG_DIR}/records_${OUTPUT_PATH}.${DATE}.output 2>&1 & 33 | 34 | 35 | -------------------------------------------------------------------------------- /scripts/sqlova/misp_neil.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source activate gpu-py3 4 | export CUDA_VISIBLE_DEVICES="0" # leave it blank for CPU or fill in GPU ids 5 | 6 | DATE=$(date +'%m%d%H%M%S') 7 | 8 | # online setting 9 | SETTING="online_pretrain_1p" # online_pretrain_Xp 10 | SUP="misp_neil" 11 | NUM_OP="3" 12 | ED="prob=0.95" 13 | 14 | DATASEED=0 # 0, 10, 100 15 | START_ITER=0 16 | END_ITER=-1 17 | AUTO_ITER=1 18 | ITER=1000 19 | BATCH_SIZE=16 20 | 21 | # path setting 22 | LOG_DIR="SQLova_model/logs" # save training logs 23 | MODEL_DIR="SQLova_model/checkpoints_${SETTING}/" # model dir 24 | 25 | OUTPUT_PATH=${SUP}_OP${NUM_OP}_ED${ED}_SETTING${SETTING}_ITER${ITER}_DATASEED${DATASEED} 26 | echo ${OUTPUT_PATH} 27 | 28 | 
python interaction_sqlova.py --job online_learning --model_dir ${MODEL_DIR} --data online \ 29 | --num_options ${NUM_OP} --err_detector ${ED} --friendly_agent 0 --user sim \ 30 | --lr_bert 5e-5 --setting ${SETTING} \ 31 | --data_seed ${DATASEED} --supervision ${SUP} \ 32 | --update_iter ${ITER} --start_iter ${START_ITER} --end_iter ${END_ITER} --auto_iter ${AUTO_ITER} \ 33 | --bS ${BATCH_SIZE} --ask_structure 0 \ 34 | --output_path ${LOG_DIR}/records_${OUTPUT_PATH}.json \ 35 | > ${LOG_DIR}/records_${OUTPUT_PATH}.${DATE}.output 2>&1 & 36 | 37 | 38 | -------------------------------------------------------------------------------- /scripts/sqlova/misp_neil_perfect.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source activate gpu-py3 4 | export CUDA_VISIBLE_DEVICES="0" # leave it blank for CPU or fill in GPU ids 5 | 6 | DATE=$(date +'%m%d%H%M%S') 7 | 8 | # online setting 9 | SETTING="online_pretrain_1p" # online_pretrain_Xp 10 | SUP="misp_neil_perfect" 11 | NUM_OP="3" 12 | ED="perfect" 13 | 14 | DATASEED=0 # 0, 10, 100 15 | START_ITER=0 16 | END_ITER=-1 17 | AUTO_ITER=1 18 | ITER=1000 19 | BATCH_SIZE=16 20 | 21 | # path setting 22 | LOG_DIR="SQLova_model/logs" # save training logs 23 | MODEL_DIR="SQLova_model/checkpoints_${SETTING}/" # model dir 24 | 25 | OUTPUT_PATH=${SUP}_OP${NUM_OP}_ED${ED}_SETTING${SETTING}_ITER${ITER}_DATASEED${DATASEED} 26 | echo ${OUTPUT_PATH} 27 | 28 | python interaction_sqlova.py --job online_learning --model_dir ${MODEL_DIR} --data online \ 29 | --num_options ${NUM_OP} --err_detector ${ED} --friendly_agent 0 --user gold_sim \ 30 | --lr_bert 5e-5 --setting ${SETTING} \ 31 | --data_seed ${DATASEED} --supervision ${SUP} \ 32 | --update_iter ${ITER} --start_iter ${START_ITER} --end_iter ${END_ITER} --auto_iter ${AUTO_ITER} \ 33 | --bS ${BATCH_SIZE} --ask_structure 1 \ 34 | --output_path ${LOG_DIR}/records_${OUTPUT_PATH}.json \ 35 | > ${LOG_DIR}/records_${OUTPUT_PATH}.${DATE}.output 2>&1 & 36 | 37 | 38 | -------------------------------------------------------------------------------- /scripts/sqlova/misp_neil_pos.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source activate gpu-py3 4 | export CUDA_VISIBLE_DEVICES="0" # leave it blank for CPU or fill in GPU ids 5 | 6 | DATE=$(date +'%m%d%H%M%S') 7 | 8 | # online setting 9 | SETTING="online_pretrain_1p" # online_pretrain_Xp 10 | SUP="misp_neil_pos" 11 | NUM_OP="3" 12 | ED="prob=0.95" 13 | 14 | DATASEED=0 # 0, 10, 100 15 | START_ITER=0 16 | END_ITER=-1 17 | AUTO_ITER=1 18 | ITER=1000 19 | BATCH_SIZE=16 20 | 21 | # path setting 22 | LOG_DIR="SQLova_model/logs" # save training logs 23 | MODEL_DIR="SQLova_model/checkpoints_${SETTING}/" # model dir 24 | 25 | OUTPUT_PATH=${SUP}_OP${NUM_OP}_ED${ED}_SETTING${SETTING}_ITER${ITER}_DATASEED${DATASEED} 26 | echo ${OUTPUT_PATH} 27 | 28 | python interaction_sqlova.py --job online_learning --model_dir ${MODEL_DIR} --data online \ 29 | --num_options ${NUM_OP} --err_detector ${ED} --friendly_agent 0 --user sim \ 30 | --lr_bert 5e-5 --setting ${SETTING} \ 31 | --data_seed ${DATASEED} --supervision ${SUP} \ 32 | --update_iter ${ITER} --start_iter ${START_ITER} --end_iter ${END_ITER} --auto_iter ${AUTO_ITER} \ 33 | --bS ${BATCH_SIZE} --ask_structure 0 \ 34 | --output_path ${LOG_DIR}/records_${OUTPUT_PATH}.json \ 35 | > ${LOG_DIR}/records_${OUTPUT_PATH}.${DATE}.output 2>&1 & 36 | 37 | 38 | -------------------------------------------------------------------------------- 
/scripts/sqlova/pretrain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source activate gpu-py3 4 | export CUDA_VISIBLE_DEVICES="0" # leave it blank for CPU or fill in GPU ids 5 | 6 | SETTING="online_pretrain_1p" # online_pretrain_Xp, full_train 7 | 8 | LOG_DIR="SQLova_model/logs" 9 | MODEL_DIR="SQLova_model/checkpoints_${SETTING}/" 10 | 11 | python3 SQLova_train.py --seed 1 --bS 16 --accumulate_gradients 4 --bert_type_abb uS \ 12 | --fine_tune --lr 0.001 --lr_bert 5e-5 \ 13 | --max_seq_leng 222 --setting ${SETTING} --job train \ 14 | --output_dir ${MODEL_DIR} \ 15 | > ${LOG_DIR}/pretrain_SETTING${SETTING}.log 2>&1 & 16 | -------------------------------------------------------------------------------- /scripts/sqlova/self_train_0.5.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source activate gpu-py3 4 | export CUDA_VISIBLE_DEVICES="0" # leave it blank for CPU or fill in GPU ids 5 | 6 | DATE=$(date +'%m%d%H%M%S') 7 | 8 | # online setting 9 | SETTING="online_pretrain_1p" # online_pretrain_Xp 10 | SUP="self_train_0.5" 11 | 12 | DATASEED=0 # 0, 10, 100 13 | START_ITER=0 14 | END_ITER=-1 15 | AUTO_ITER=1 16 | ITER=1000 17 | BATCH_SIZE=16 18 | 19 | # path setting 20 | LOG_DIR="SQLova_model/logs" # save training logs 21 | MODEL_DIR="SQLova_model/checkpoints_${SETTING}/" # model dir 22 | 23 | OUTPUT_PATH=${SUP}_SETTING${SETTING}_ITER${ITER}_DATASEED${DATASEED} 24 | echo ${OUTPUT_PATH} 25 | 26 | python interaction_sqlova.py --job online_learning --model_dir ${MODEL_DIR} --data online \ 27 | --lr_bert 5e-5 --setting ${SETTING} \ 28 | --data_seed ${DATASEED} --supervision ${SUP} \ 29 | --update_iter ${ITER} --start_iter ${START_ITER} --end_iter ${END_ITER} --auto_iter ${AUTO_ITER} \ 30 | --bS ${BATCH_SIZE} \ 31 | --output_path ${LOG_DIR}/records_${OUTPUT_PATH}.json \ 32 | > ${LOG_DIR}/records_${OUTPUT_PATH}.${DATE}.output 2>&1 & 33 | 34 | 35 | -------------------------------------------------------------------------------- /scripts/sqlova/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source activate gpu-py3 4 | export CUDA_VISIBLE_DEVICES="0" # leave it blank for CPU or fill in GPU ids 5 | 6 | SETTING="online_pretrain_1p" # online_pretrain_Xp, full_train 7 | TEST_JOB=dev-test # dev-test for testing on Dev set, test-test for testing on Test set 8 | 9 | LOG_DIR="SQLova_model/logs" 10 | MODEL_DIR="SQLova_model/checkpoints_${SETTING}/" 11 | 12 | python3 SQLova_train.py --seed 1 --bS 16 --accumulate_gradients 4 --bert_type_abb uS \ 13 | --fine_tune --lr 0.001 --lr_bert 5e-5 \ 14 | --max_seq_leng 222 --setting ${SETTING} --job ${TEST_JOB} \ 15 | --load_checkpoint_dir ${MODEL_DIR} \ 16 | > ${LOG_DIR}/${TEST_JOB}_SETTING${SETTING}.log 2>&1 & 17 | -------------------------------------------------------------------------------- /scripts/sqlova/test_with_interaction.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source activate gpu-py3 4 | export CUDA_VISIBLE_DEVICES="0" # leave it blank for CPU or fill in GPU ids 5 | 6 | # test setting 7 | SETTING="online_pretrain_1p" # online_pretrain_Xp, full_train 8 | NUM_OP="3" 9 | ED="prob=0.95" 10 | DATA="dev" # dev or test set 11 | BATCH_SIZE=16 12 | USER='sim' # simulated user interaction 13 | 14 | # path setting 15 | LOG_DIR="SQLova_model/logs" # save training logs 16 | 
MODEL_DIR="SQLova_model/checkpoints_${SETTING}/" # model dir 17 | 18 | OUTPUT_PATH=test_DATA${DATA}_OP${NUM_OP}_ED${ED}_USER${USER}_SETTING${SETTING} 19 | echo ${OUTPUT_PATH} 20 | 21 | python interaction_sqlova.py --job test_w_interaction --model_dir ${MODEL_DIR} --data ${DATA} \ 22 | --num_options ${NUM_OP} --err_detector ${ED} --friendly_agent 0 --user ${USER} \ 23 | --lr_bert 5e-5 --setting ${SETTING} \ 24 | --bS ${BATCH_SIZE} --ask_structure 0 \ 25 | --output_path ${LOG_DIR}/records_${OUTPUT_PATH}.json \ 26 | > ${LOG_DIR}/records_${OUTPUT_PATH}.output 2>&1 & 27 | 28 | 29 | -------------------------------------------------------------------------------- /slides/MISP_NEIL_EMNLP20_slides.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunlab-osu/MISP/7870566ab6b9e121d648478968367bc79c12f7ef/slides/MISP_NEIL_EMNLP20_slides.pdf -------------------------------------------------------------------------------- /text2sql.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunlab-osu/MISP/7870566ab6b9e121d648478968367bc79c12f7ef/text2sql.png -------------------------------------------------------------------------------- /user_study_utils.py: -------------------------------------------------------------------------------- 1 | # utils for user study 2 | import random 3 | import json 4 | 5 | class bcolors: 6 | """ 7 | Usage: print bcolors.WARNING + "Warning: No active frommets remain. Continue?" + bcolors.ENDC 8 | """ 9 | PINK = '\033[95m' 10 | BLUE = '\033[94m' 11 | GREEN = '\033[92m' 12 | YELLOW = '\033[93m' 13 | RED = '\033[91m' 14 | ENDC = '\033[0m' 15 | BOLD = '\033[1m' 16 | UNDERLINE = '\033[4m' 17 | 18 | def print_main(sent): 19 | print(bcolors.PINK + bcolors.BOLD + sent + bcolors.ENDC) 20 | 21 | 22 | def print_header(remaining_size, bool_table_color=False): 23 | task_notification = " Interactive Database Query " 24 | remain_notification = " Remaining: %d " % (remaining_size) 25 | print("=" * 50) 26 | print(bcolors.BOLD + task_notification + bcolors.ENDC) 27 | print(bcolors.BOLD + remain_notification + bcolors.ENDC) 28 | print(bcolors.BOLD + "\n Tip: Words referring to table headers/attributes are marked in " + 29 | bcolors.BLUE + "this color" + bcolors.ENDC + ".") 30 | if bool_table_color: 31 | print(bcolors.BOLD + " Tip: Words referring to table names are marked in " + bcolors.YELLOW + "this color" + 32 | bcolors.ENDC + ".") 33 | print("=" * 50) 34 | print("") 35 | 36 | 37 | def case_sampling_SQLNet(K=100): 38 | from SQLNet_model.sqlnet.utils import load_data 39 | data_dir = "SQLNet_model/data/" 40 | sql_data, table_data = load_data(data_dir + "test_tok.jsonl", data_dir + "test_tok.tables.jsonl") 41 | size = len(sql_data) 42 | print(size) 43 | sampled_ids = [] 44 | while len(sampled_ids) < K: 45 | id = random.choice(range(size)) 46 | if id in sampled_ids: 47 | continue 48 | 49 | question = sql_data[id]['question'] 50 | table_id = sql_data[id]['table_id'] 51 | headers = table_data[table_id]['header'] 52 | 53 | try: 54 | print("question: {}\nheaders: {}".format(question, headers)) 55 | action = raw_input("Take or not?") 56 | if action == 'y': 57 | sampled_ids.append(id) 58 | json.dump(sampled_ids, open(data_dir + "user_study_ids.json", "w")) 59 | except: 60 | pass 61 | --------------------------------------------------------------------------------