├── .gitmodules ├── LICENSE ├── PIE_ckpt ├── bert_config.json ├── multi_round_infer.sh ├── pie_infer.sh └── vocab.txt ├── README.md ├── apply_opcode.py ├── errorify ├── README.md ├── common_deletes.p ├── common_inserts.p ├── common_replaces.p ├── error.py ├── errorifier.py ├── morphs.txt ├── parse_verbs.py └── verbs.p ├── example_scripts ├── README.md ├── end_to_end.sh ├── m2_eval.sh ├── multi_round_infer.sh ├── pie_infer.sh ├── pie_train.sh └── preprocess.sh ├── get_edit_vocab.py ├── get_seq2edits.py ├── install_dependencies.sh ├── modeling.py ├── modified_modeling.py ├── opcodes.py ├── optimization.py ├── pickles ├── bea │ ├── common_deletes.p │ ├── common_inserts.p │ └── common_multitoken_inserts.p └── conll │ ├── common_deletes.p │ ├── common_inserts.p │ └── common_multitoken_inserts.p ├── requirements.txt ├── scratch ├── conll_test.txt ├── official-2014.combined.m2 ├── train_corr_sentences.txt └── train_incorr_sentences.txt ├── seq2edits_utils.py ├── spellcheck_utils.py ├── tokenization.py ├── tokenize_input.py ├── transform_suffixes.py ├── utils.py ├── wem_utils.py └── word_edit_model.py /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "m2scorer"] 2 | path = m2scorer 3 | url = git@github.com:nusnlp/m2scorer.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Abhijeet Awasthi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /PIE_ckpt/bert_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "directionality": "bidi", 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "hidden_size": 1024, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 4096, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 16, 11 | "num_hidden_layers": 24, 12 | "pooler_fc_size": 768, 13 | "pooler_num_attention_heads": 12, 14 | "pooler_num_fc_layers": 3, 15 | "pooler_size_per_head": 128, 16 | "pooler_type": "first_token_transform", 17 | "type_vocab_size": 2, 18 | "vocab_size": 28996 19 | } 20 | -------------------------------------------------------------------------------- /PIE_ckpt/multi_round_infer.sh: -------------------------------------------------------------------------------- 1 | cur_dir=$PWD 2 | cd .. 3 | 4 | export output_dir=scratch 5 | export data_dir=scratch 6 | export bert_cased_vocab=PIE_ckpt/vocab.txt 7 | export bert_config_file=PIE_ckpt/bert_config.json 8 | 9 | input_file=scratch/conll_test.txt 10 | 11 | echo Running Round 0... 12 | 13 | #tokenize input sentences into wordpiece tokens (output_tokens.txt) 14 | python3.6 tokenize_input.py \ 15 | --input=$input_file \ 16 | --output_tokens=$data_dir/output_tokens.txt \ 17 | --output_token_ids=$data_dir/test_incorr.txt \ 18 | --vocab_path=$bert_cased_vocab \ 19 | --do_spell_check=True 20 | 21 | #output_tokens is the wordpiece tokenized version of input (output_tokens.txt) 22 | #test_incorr.txt has token_ids of wordpiece tokens 23 | 24 | #PIE predicts edit ids for wordpiece tokenids (test_incorr.txt) 25 | cd $cur_dir 26 | ./pie_infer.sh 27 | cd .. 28 | cp $output_dir/test_results.txt $data_dir/multi_round_0_test_results.txt 29 | #test_results.txt contains edit ids inferred through PIE 30 | 31 | #apply edits (test_results.txt) on the wordpiece tokens of input (output_tokens.txt) 32 | python3.6 apply_opcode.py \ 33 | --vocab_path=$bert_cased_vocab \ 34 | --input_tokens=$data_dir/output_tokens.txt \ 35 | --edit_ids=$data_dir/multi_round_0_test_results.txt \ 36 | --output_tokens=$data_dir/multi_round_0_test_predictions.txt \ 37 | --infer_mode=conll \ 38 | --path_common_inserts=pickles/conll/common_inserts.p \ 39 | --path_common_multitoken_inserts=pickles/conll/common_multitoken_inserts.p \ 40 | --path_common_deletes=pickles/conll/common_deletes.p \ 41 | 42 | 43 | #corrected_sentences: multi_round_0_test_predictions.txt 44 | 45 | #iterate above for 3 more rounds and refine the corrected sentences further 46 | 47 | for round_id in {1..3}; 48 | do 49 | echo Running Round $round_id ... 50 | python3.6 tokenize_input.py \ 51 | --input=$data_dir/multi_round_"$(( round_id - 1 ))"_test_predictions.txt \ 52 | --output_tokens=$data_dir/output_tokens.txt \ 53 | --output_token_ids=$data_dir/test_incorr.txt \ 54 | --vocab_path=$bert_cased_vocab 55 | 56 | cd $cur_dir 57 | ./pie_infer.sh 58 | cd .. 59 | cp $output_dir/test_results.txt $data_dir/multi_round_"$round_id"_test_results.txt 60 | 61 | python3.6 apply_opcode.py \ 62 | --vocab_path=$bert_cased_vocab \ 63 | --input_tokens=$data_dir/output_tokens.txt \ 64 | --edit_ids=$data_dir/multi_round_"$round_id"_test_results.txt \ 65 | --output_tokens=$data_dir/multi_round_"$round_id"_test_predictions.txt \ 66 | --infer_mode=conll \ 67 | --path_common_inserts=pickles/conll/common_inserts.p \ 68 | --path_common_multitoken_inserts=pickles/conll/common_multitoken_inserts.p \ 69 | --path_common_deletes=pickles/conll/common_deletes.p 70 | done 71 | -------------------------------------------------------------------------------- /PIE_ckpt/pie_infer.sh: -------------------------------------------------------------------------------- 1 | cd .. 2 | 3 | 4 | output_dir=scratch 5 | data_dir=scratch 6 | bert_cased_vocab=PIE_ckpt/vocab.txt 7 | bert_config_file=PIE_ckpt/bert_config.json 8 | path_multitoken_inserts=pickles/conll/common_multitoken_inserts.p 9 | path_inserts=pickles/conll/common_inserts.p 10 | 11 | python3.6 word_edit_model.py \ 12 | --do_predict=True \ 13 | --data_dir=$data_dir \ 14 | --vocab_file=$bert_cased_vocab \ 15 | --bert_config_file=$bert_config_file \ 16 | --max_seq_length=128 \ 17 | --predict_batch_size=16 \ 18 | --output_dir=$output_dir \ 19 | --do_lower_case=False \ 20 | --path_inserts=$path_inserts \ 21 | --path_multitoken_inserts=$path_multitoken_inserts \ 22 | --predict_checkpoint=PIE_ckpt/pie_model.ckpt 23 | 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PIE: Parallel Iterative Edit Models for Local Sequence Transduction 2 | Fast Grammatical Error Correction using BERT 3 | 4 | Code and Pre-trained models accompanying our paper "Parallel Iterative Edit Models for Local Sequence Transduction" (EMNLP-IJCNLP 2019) 5 | 6 | PIE is a BERT based architecture for local sequence transduction tasks like Grammatical Error Correction. Unlike the standard approach of modeling GEC as a task of translation from "incorrect" to "correct" language, we pose GEC as local sequence editing task. We further reduce local sequence editing problem to a sequence labeling setup where we utilize BERT to non-autoregressively label input tokens with edits. We rewire the BERT architecture (without retraining) specifically for the task of sequence editing. We find that PIE models for GEC are 5 to 15 times faster than existing state of the art architectures and still maintain a competitive accuracy. For more details please check out our [EMNLP-IJCNLP 2019 paper](https://www.aclweb.org/anthology/D19-1435.pdf) 7 | 8 | ``` 9 | @inproceedings{awasthi-etal-2019-parallel, 10 | title = "Parallel Iterative Edit Models for Local Sequence Transduction", 11 | author = "Awasthi, Abhijeet and 12 | Sarawagi, Sunita and 13 | Goyal, Rasna and 14 | Ghosh, Sabyasachi and 15 | Piratla, Vihari", 16 | booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)", 17 | month = nov, 18 | year = "2019", 19 | address = "Hong Kong, China", 20 | publisher = "Association for Computational Linguistics", 21 | url = "https://www.aclweb.org/anthology/D19-1435", 22 | doi = "10.18653/v1/D19-1435", 23 | pages = "4259--4269", 24 | } 25 | ``` 26 | 27 | 28 | ## Datasets 29 | * All the public GEC datasets used in the paper can be obtained from [here](https://www.cl.cam.ac.uk/research/nl/bea2019st/#data) 30 | * [Synthetically created datasets](https://drive.google.com/open?id=1bl5reJ-XhPEfEaPjvO45M7w0yN-0XGOA) (perturbed version of 1 billion word corpus) divided into 5 parts to independently train 5 different ensembles. (all the ensembles are further finetuned using the public GEC datasets mentioned above.) 31 | 32 | 33 | ## Pretrained Models 34 | * [PIE as reported in the paper](https://huggingface.co/AbhijeetA/PIE/resolve/main/pie_model.zip) 35 | - trained on a Synethically created GEC dataset starting with BERT's initialization 36 | - finetuned further on Lang8, NUCLE and FCE datasets 37 | * **Inference using the pretrained PIE ckpt** 38 | - Copy the pretrained checkpoint files provided above to PIE_ckpt directory 39 | - Your PIE_ckpt directory should contain 40 | - bert_config.json 41 | - multi_round_infer.sh 42 | - pie_infer.sh 43 | - pie_model.ckpt.data-00000-of-00001 44 | - pie_model.ckpt.index 45 | - pie_model.ckpt.meta 46 | - vocab.txt 47 | - Run: `$ ./multi_round_infer.sh` from PIE_ckpt directory 48 | - NOTE: If you are using cloud-TPUs for inference, move the PIE_ckpt directory to the cloud bucket and change the paths in pie_infer.sh and multi_round_infer.sh accordingly 49 | 50 | ## Code Description 51 | **An example usage of code in described in the directory "example_scripts".** 52 | 53 | * preprocess.sh 54 | - Extracts common insertions from a sample training data in the "scratch" directory 55 | - converts the training data in the form of incorrect tokens and aligned edits 56 | * pie_train.sh 57 | - trains a pie model using the converted training data 58 | * multi_round_infer.sh 59 | - uses a trained PIE model to obtain edits for incorrect sentences 60 | - does 4 rounds of iterative editing 61 | - uses conll-14 test sentences 62 | * m2_eval.sh 63 | - evaluates the final output using [m2scorer](https://github.com/nusnlp/m2scorer) 64 | * end_to_end.sh 65 | - describes the use of pre-processing, training, inference and evaluation scripts end to end. 66 | * More information in README.md inside "example_scripts" 67 | 68 | **Pre processing and Edits related** 69 | 70 | * seq2edits_utils.py 71 | - contains implementation of edit-distance algorithm. 72 | - cost for substitution modified as per section A.1 in the paper. 73 | - Adapted from [belambert's implimentation](https://github.com/belambert/edit-distance/blob/master/edit_distance/code.py) 74 | * get_edit_vocab.py : Extracts common insertions (\Sigma_a set as described in paper) from a parallel corpus 75 | * get_seq2edits.py : Extracts edits aligned to input tokens 76 | * tokenize_input.py : tokenize a file containing sentences. token_ids obtained go as input to the model. 77 | * opcodes.py : A class where members are all possible edit operations 78 | * transform_suffixes.py: Contains logic for suffix transformations 79 | * tokenization.py : Similar to BERT's implementation, with some GEC specific changes 80 | 81 | **PIE model** (uses [implementation of BERT of bert in Tensorflow](https://github.com/google-research/bert)) 82 | 83 | * word_edit_model.py: Implementation of PIE for learning from a parallel corpous of incorrect tokens and aligned edits. 84 | - logit factorization logic (Keep flag use_bert_more=True to enable logit factorization) 85 | - parallel sequence labeling 86 | * modeling.py : Same as in BERT's implementation 87 | * modified_modeling.py 88 | - Rewires attention mask to obtain representations of candidate appends and replacements 89 | - Used for logit factorization. 90 | * optimization.py : Same as in BERT's implementation 91 | 92 | **Post processing** 93 | * apply_opcode.py 94 | - Applies inferred edits from the PIE model to the incorrect sentences. 95 | - Handles punctuations and spacings as per requirements of a standard dataset (INFER_MODE). 96 | - Contains some obvious rules for captialization etc. 97 | 98 | **Creating synthetic GEC dataset** 99 | * errorify directory contains the scripts we used for perturbing the one-billion-word corpus 100 | 101 | 102 | ## Acknowledgements 103 | This research was partly sponsored by a Google India AI/ML Research Award and Google PhD Fellowship in Machine Learning. We gratefully acknowledge Google's TFRC program for providing us Cloud-TPUs. Thanks to [Varun Patil](https://github.com/pulsejet) for helping us improve the speed of pre-processing and synthetic-data generation pipelines. 104 | -------------------------------------------------------------------------------- /apply_opcode.py: -------------------------------------------------------------------------------- 1 | """Utility to apply opcodes to incorrect sentences.""" 2 | 3 | import pickle 4 | import string 5 | import sys 6 | from tqdm import tqdm 7 | from joblib import Parallel, delayed 8 | import opcodes 9 | from utils import open_w, read_file_lines, pretty, bcolors 10 | from transform_suffixes import apply_transform as apply_suffix_transform 11 | from autocorrect import spell 12 | import tokenization 13 | import argparse 14 | 15 | def add_arguments(parser): 16 | """Build ArgumentParser.""" 17 | parser.add_argument("--vocab_path", type=str, default=None, help="path to bert's cased vocab file") 18 | parser.add_argument("--input_tokens", type=str, default=None, help="path to possibly incorrect token file") 19 | parser.add_argument("--edit_ids", type=str, default=None, help="path to edit ids to be applied on input_tokens") 20 | parser.add_argument("--output_tokens", type=str, default=None, help="path to edited (hopefully corrected) file") 21 | parser.add_argument("--infer_mode", type=str, default="conll", help="post processing mode bea or conll") 22 | parser.add_argument("--path_common_inserts",type=str,default=None,help="path of common unigram inserts") 23 | parser.add_argument("--path_common_multitoken_inserts",type=str,default=None,help="path of common bigram inserts") 24 | parser.add_argument("--path_common_deletes",type=str,default=None,help="path to common deletions observed in train data") 25 | 26 | parser = argparse.ArgumentParser() 27 | add_arguments(parser) 28 | FLAGS, unparsed = parser.parse_known_args() 29 | 30 | DO_PARALLEL = False 31 | INFER_MODE=FLAGS.infer_mode 32 | 33 | vocab = tokenization.load_vocab(FLAGS.vocab_path) 34 | basic_tokenizer = tokenization.BasicTokenizer(do_lower_case=False,vocab=vocab) 35 | vocab_words = set(x for x in vocab) 36 | common_deletes = pickle.load(open(FLAGS.path_common_deletes,"rb")) 37 | path_common_inserts = FLAGS.path_common_inserts 38 | path_common_multitoken_inserts = FLAGS.path_common_multitoken_inserts 39 | opcodes = opcodes.Opcodes(path_common_inserts, path_common_multitoken_inserts) 40 | 41 | if __name__ == '__main__': 42 | class config: 43 | INPUT_UNCORRECTED_WORDS = FLAGS.input_tokens 44 | INPUT_EDITS = FLAGS.edit_ids 45 | OUTPUT_CORRECTED_WORDS = FLAGS.output_tokens 46 | 47 | 48 | 49 | def fix_apos_break(word, p_word, pp_word): 50 | #for l'optimse 51 | if p_word == "'" and pp_word not in ["i","a","s"] and len(pp_word) == 1 and pp_word.isalpha() and word.isalpha(): 52 | return True 53 | else: 54 | return False 55 | 56 | 57 | 58 | def apply_opcodes(words_uncorrected, ops_to_apply, 59 | join_wordpiece_subwords=True, remove_start_end_tokens=True, 60 | do_spell_check=True, apply_only_first_edit=False, 61 | use_common_deletes=True): 62 | """Applies opcodes to an uncorrected token sequence and returns 63 | corrected token sequence.""" 64 | # Initialize 65 | words_corrected = [] 66 | 67 | # Loop over each word 68 | for i, word in enumerate(words_uncorrected): 69 | if i >= len(ops_to_apply): 70 | words_corrected = words_corrected + words_uncorrected[i:] 71 | break 72 | 73 | 74 | op = ops_to_apply[i] 75 | 76 | # Skip if EOS is detected 77 | if op == opcodes.EOS: 78 | print("ERROR: EOS opcode: This should not happen") 79 | exit(1) 80 | break 81 | 82 | elif op == opcodes.CPY: 83 | words_corrected.append(words_uncorrected[i]) 84 | 85 | elif op == opcodes.DEL: 86 | if (words_uncorrected[i] in common_deletes) or (not use_common_deletes): 87 | #and (i==len(words_uncorrected) or words_uncorrected[i+1][0:2]!="##")): 88 | continue 89 | else: 90 | words_corrected.append(words_uncorrected[i]) 91 | 92 | elif op in opcodes.APPEND.values(): 93 | words_corrected.append(words_uncorrected[i]) 94 | insert_words = key_from_val(op, opcodes.APPEND).split() 95 | if i==0 and do_spell_check: 96 | insert_words[0] = insert_words[0].capitalize() 97 | if len(words_uncorrected) > 1: 98 | words_uncorrected[i+1] = words_uncorrected[i+1].lower() 99 | words_corrected.extend(insert_words) 100 | 101 | elif op in opcodes.REP.values(): 102 | replacement = key_from_val(op, opcodes.REP).split() 103 | words_corrected.extend(replacement) 104 | 105 | elif apply_suffix_transform(words_uncorrected, i, op, opcodes): 106 | replacement = apply_suffix_transform(words_uncorrected, i, op, opcodes) 107 | words_corrected.append(replacement) 108 | 109 | else: 110 | words_corrected.append(words_uncorrected[i]) 111 | tqdm.write(bcolors.FAIL + 'ERROR: Copying illegal operation (failed transform?) at ' 112 | + str(words_uncorrected) + bcolors.ENDC) 113 | 114 | if apply_only_first_edit and (op != opcodes.CPY) and (i+1 bea SPECIFIC 153 | result[-1] = "'{}".format(word) 154 | elif len(result) > 1 and word == "t" and result[-1]=="'" and result[-2][-1]=="n": 155 | result.pop() 156 | result[-1] = result[-1] + "'t" 157 | elif word == "ll" and result[-1] == "'": 158 | result[-1]="'ll" 159 | elif len(result) > 1 and fix_apos_break(word, result[-1], result[-2]): 160 | result.pop() 161 | result[-1] += "'" + word 162 | else: 163 | if len(result)==1: 164 | if not tokenization.do_not_split(word): 165 | word = word.capitalize() 166 | #elif (word != 'I') and (word[0].isupper()) and (result[-1] != '.') and (word.lower() in vocab_words): 167 | # print("{} ----------------------------------->{}".format(word,word.lower())) 168 | # word = word.lower() 169 | result.append(word) 170 | elif INFER_MODE=="conll": 171 | if i==0: 172 | result.append(word) 173 | elif word == 'i': 174 | result.append(word.capitalize()) 175 | elif word=="-" or result[-1][-1] == "-": 176 | result[-1] = result[-1] + word 177 | elif word=="/" or result[-1][-1] == "/": 178 | result[-1] = result[-1] + word 179 | elif word == "'" and result[-1] == "'": 180 | result[-1] = "''" 181 | elif word in ["s","re"] and result[-1] == "'": 182 | result[-1] = "'{}".format(word) 183 | elif len(result) > 1 and word=="'" and len(result[-1])>1 and result[-2]=="'": 184 | main_word = result.pop() 185 | result[-1] = "'{}'".format(main_word) 186 | elif len(result) > 1 and len(word)==1 and result[-1]=="'" and len(result[-2])==1: #n't #I'm 187 | result.pop() 188 | result[-1]= result[-1] + "'" + word 189 | else: 190 | if len(result)==1: 191 | if not tokenization.do_not_split(word): 192 | word = word.capitalize() 193 | 194 | #if (gpv.use_spell_check) and (word not in vocab) and (spell(word) in vocab): 195 | # print("{} --> {}".format(word, spell(word))) 196 | # word = spell(word) 197 | result.append(word) 198 | else: 199 | print("wrong infer_mode") 200 | exit(1) 201 | 202 | 203 | if len(result) > 3 and result[-2]=="." and len(result[-3])>3: 204 | if not tokenization.do_not_split(result[-1]): 205 | result[-1]=result[-1].capitalize() 206 | 207 | #if len(result)>1 and result[-2] == "a" and result[-1].startswith(('a','e','i','o','u','A','E','I','O','U')): 208 | # print("{} {}".format(result[-2],result[-1])) 209 | # result[-2]="an" 210 | 211 | #if len(result)>1 and result[-2] == "an" and (not result[-1].startswith(('a','e','i','o','u','A','E','I','O','U'))): 212 | # print("{} {}".format(result[-2],result[-1])) 213 | # result[-2]="a" 214 | 215 | 216 | 217 | prev_word = None 218 | post_process_result = [] 219 | for i, word in enumerate(result): 220 | if word != prev_word or word in {".", "!", "that", "?", "-", "had"}: 221 | post_process_result.append(word) 222 | 223 | 224 | prev_word = word 225 | 226 | return post_process_result 227 | 228 | 229 | ''' 230 | elif len(result) > 1 and word=="t" and result[-1]=="'" and result[-2]=="n": 231 | result.pop() 232 | result[-1]="n't" 233 | ''' 234 | def split_and_convert_to_ints(words_uncorrected,edits): 235 | words_uncorrected = words_uncorrected.split(' ') 236 | edits = edits.split(' ')[0:len(words_uncorrected)] 237 | edits = list(map(int, edits)) 238 | return words_uncorrected, edits 239 | 240 | 241 | if __name__=="__main__": 242 | 243 | corrected = [] 244 | 245 | pretty.pheader('Reading Input') 246 | edits = read_file_lines(config.INPUT_EDITS) 247 | #uncorrected = read_file_lines(config.INPUT_UNCORRECTED) 248 | words_uncorrected = read_file_lines(config.INPUT_UNCORRECTED_WORDS) 249 | 250 | if len(edits) != len(words_uncorrected): 251 | pretty.fail('FATAL ERROR: Lengths of edits and uncorrected files not equal') 252 | exit() 253 | 254 | 255 | pretty.pheader('Splitting and converting to integers') 256 | 257 | if not DO_PARALLEL: 258 | for i in tqdm(range(len(edits))): 259 | edits[i] = list(map(int, edits[i].split(' '))) 260 | #uncorrected[i] = list(map(int, uncorrected[i].split(' '))) 261 | words_uncorrected[i] = words_uncorrected[i].split(' ') 262 | else: 263 | result = Parallel(n_jobs=-1)(delayed(split_and_convert_to_ints)(*s) for s in tqdm(zip(words_uncorrected,edits), total=len(words_uncorrected))) 264 | words_uncorrected = [item[0] for item in result] 265 | edits = [item[1] for item in result] 266 | 267 | #if(len(edits[i]) != len(uncorrected[i])): 268 | #print("edits: {}".format(edits[i])) 269 | #print("length uncorrected: {}".format(len(uncorrected[i]))) 270 | #tqdm.write((bcolors.WARNING + "WARN: Unequal lengths at line {}".format(i + 1) + bcolors.ENDC)) 271 | 272 | pretty.pheader('Applying opcodes') 273 | with open_w(config.OUTPUT_CORRECTED_WORDS) as out_file: 274 | 275 | #sentences_corrected_inplace = [] #contain copies, should be same length as uncorrected list of list 276 | #sentences_corrected_insert = [] #contains additional inserts, should be same length as uncorrected list of list 277 | 278 | if DO_PARALLEL: 279 | s_corrected = Parallel(n_jobs=-1)(delayed(apply_opcodes)(*s) for s in tqdm(zip(words_uncorrected,edits), total=len(words_uncorrected))) 280 | for line in s_corrected: 281 | out_file.write(" ".join(line)+"\n") 282 | else: 283 | for i in tqdm(range(len(edits))): 284 | #s_corrected = untokenize(apply_opcodes(words_uncorrected[i], uncorrected[i], edits[i])) 285 | #print(len(words_uncorrected[i])) 286 | s_corrected = apply_opcodes(words_uncorrected[i], edits[i]) 287 | s_corrected = " ".join(s_corrected) 288 | 289 | out_file.write(s_corrected) 290 | out_file.write('\n') 291 | -------------------------------------------------------------------------------- /errorify/README.md: -------------------------------------------------------------------------------- 1 | # Synthetic Data Generation Scripts 2 | 3 | ## Usage 4 | * python3 error.py $path_of_correct_file $output_path 5 | * Example 6 | - python3 error.py ../scratch/train_corr_sentences.txt ../scratch 7 | - Running above command will create two parallel files corr_sentences.txt and incorr_sentences.txt in ../scratch 8 | - Note that the order of sentences in the newly created parallel files will not be same as the original file. 9 | 10 | * morphs.txt was created by merging verbs, verbs.aux and noms from [here](https://github.com/ixa-ehu/matxin/blob/master/data/freeling/en/dictionary/verbs) 11 | * [Synthetically created datasets](https://drive.google.com/open?id=1bl5reJ-XhPEfEaPjvO45M7w0yN-0XGOA) -------------------------------------------------------------------------------- /errorify/common_deletes.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awasthiabhijeet/PIE/474769e3c4266deefcb7dd5daf802a1306bc7c99/errorify/common_deletes.p -------------------------------------------------------------------------------- /errorify/common_inserts.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awasthiabhijeet/PIE/474769e3c4266deefcb7dd5daf802a1306bc7c99/errorify/common_inserts.p -------------------------------------------------------------------------------- /errorify/common_replaces.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awasthiabhijeet/PIE/474769e3c4266deefcb7dd5daf802a1306bc7c99/errorify/common_replaces.p -------------------------------------------------------------------------------- /errorify/error.py: -------------------------------------------------------------------------------- 1 | """Synthetic data generation by introducing errors.""" 2 | import sys 3 | import multiprocessing as mp 4 | from filelock import FileLock 5 | from tqdm import tqdm 6 | from errorifier import Errorifier 7 | 8 | CORRECT_FILE = sys.argv[2] + "/" + 'corr_sentences.txt' 9 | ERRORED_FILE = sys.argv[2] + "/" + 'incorr_sentences.txt' 10 | FLUSH_SIZE = 100000 11 | BATCH_SIZE = 200 12 | 13 | def flush_queue(pairs, flush=False): 14 | """Write queue obects to file.""" 15 | 16 | wrote = 0 17 | with FileLock('%s.lock' % CORRECT_FILE), FileLock('%s.lock' % ERRORED_FILE): 18 | if pairs.qsize() < FLUSH_SIZE and not flush: 19 | return 20 | 21 | with open(CORRECT_FILE, 'a') as cfile, open(ERRORED_FILE, 'a') as efile: 22 | while pairs.qsize() > 0 and (wrote < FLUSH_SIZE or flush): 23 | wrote += 1 24 | pair = pairs.get() 25 | cfile.write(pair[0] + '\n') 26 | efile.write(pair[1] + '\n') 27 | 28 | def errorify(tpl): 29 | """Function to use for multiprocessing.""" 30 | # Unpack 31 | sentences, pairs = tpl[0], tpl[1] 32 | 33 | for sentence in sentences: 34 | eff = Errorifier(sentence) 35 | puttpl = (eff.correct(), eff.error()) 36 | pairs.put(puttpl) 37 | 38 | if pairs.qsize() > FLUSH_SIZE: 39 | flush_queue(pairs) 40 | 41 | def readn(file, n): 42 | """Read a file n lines at a time.""" 43 | start = True 44 | clist = [] 45 | for line in file: 46 | if start: 47 | clist = [] 48 | start = False 49 | clist.append(line) 50 | if len(clist) == n: 51 | start = True 52 | yield clist 53 | yield clist 54 | 55 | 56 | def errorify_file(filename: str): 57 | """Errorify all sentences in a file.""" 58 | 59 | # Blank files 60 | open(CORRECT_FILE, 'w').close() 61 | open(ERRORED_FILE, 'w').close() 62 | 63 | # Threads = CPU count 64 | pool = mp.Pool(mp.cpu_count()) 65 | manager = mp.Manager() 66 | pairs = manager.Queue() 67 | 68 | # Erroriy each line 69 | file = open(filename, 'r') 70 | [x for x in tqdm(pool.imap(errorify, [(l, pairs) for l in readn(file, BATCH_SIZE)]))] 71 | pool.close() 72 | 73 | # Flush anything remaining 74 | flush_queue(pairs, True) 75 | 76 | if __name__ == '__main__': 77 | errorify_file(sys.argv[1]) 78 | -------------------------------------------------------------------------------- /errorify/errorifier.py: -------------------------------------------------------------------------------- 1 | """Synthetic data generator.""" 2 | import math 3 | import pickle 4 | import random 5 | from numpy.random import choice as npchoice 6 | 7 | VERBS = pickle.load(open('verbs.p', 'rb')) 8 | COMMON_INSERTS = set(pickle.load(open('common_inserts.p', 'rb'))) 9 | COMMON_REPLACES = pickle.load(open('common_replaces.p', 'rb')) 10 | COMMON_DELETES = pickle.load(open('common_deletes.p','rb')) 11 | 12 | class Errorifier: 13 | """Generate errors in good sentences!""" 14 | 15 | def __init__(self, sentence: str): 16 | self.original_sentence = sentence.rstrip() 17 | self.sentence = self.original_sentence 18 | self.tokenized = None 19 | self.tokenize() 20 | 21 | def tokenize(self): 22 | self.tokenized = self.sentence.split() 23 | 24 | def correct(self): 25 | return self.original_sentence 26 | 27 | def no_error(self): 28 | return ' '.join(self.tokenized) 29 | 30 | def delete_error(self): 31 | if len(self.tokenized) > 0: 32 | insertable = list(range(len(self.tokenized))) 33 | index = random.choice(insertable) 34 | 35 | 36 | plist = list(COMMON_DELETES.values()) 37 | plistsum = sum(plist) 38 | plist = [x / plistsum for x in plist] 39 | 40 | # Choose a bad word 41 | ins_word = npchoice(list(COMMON_DELETES.keys()), p=plist) 42 | self.tokenized.insert(index,ins_word) 43 | 44 | return ' '.join(self.tokenized) 45 | 46 | 47 | def verb_error(self, redir=True): 48 | """Introduce a verb error from morphs.txt.""" 49 | 50 | if len(self.tokenized) > 0: 51 | verbs = [i for i, w in enumerate(self.tokenized) if w in VERBS] 52 | if not verbs: 53 | if redir: 54 | return self.replace_error(redir=False) 55 | return self.sentence 56 | 57 | index = random.choice(verbs) 58 | word = self.tokenized[index] 59 | if not VERBS[word]: 60 | return self.sentence 61 | repl = random.choice(VERBS[word]) 62 | self.tokenized[index] = repl 63 | 64 | return ' '.join(self.tokenized) 65 | 66 | def insert_error(self): 67 | """Delete a commonly inserted word.""" 68 | if len(self.tokenized) > 1: 69 | deletable = [i for i, w in enumerate(self.tokenized) if w in COMMON_INSERTS] 70 | if not deletable: 71 | return self.sentence 72 | 73 | index = random.choice(deletable) 74 | del self.tokenized[index] 75 | return ' '.join(self.tokenized) 76 | 77 | def replace_error(self, redir=True): 78 | """Add a common replace error.""" 79 | if len(self.tokenized) > 0: 80 | deletable = [i for i, w in enumerate(self.tokenized) if w in COMMON_REPLACES] 81 | if not deletable: 82 | if redir: 83 | return self.verb_error(redir=False) 84 | return self.sentence 85 | 86 | index = random.choice(deletable) 87 | word = self.tokenized[index] 88 | if not COMMON_REPLACES[word]: 89 | return self.sentence 90 | 91 | # Normalize probabilities 92 | plist = list(COMMON_REPLACES[word].values()) 93 | plistsum = sum(plist) 94 | plist = [x / plistsum for x in plist] 95 | 96 | # Choose a bad word 97 | repl = npchoice(list(COMMON_REPLACES[word].keys()), p=plist) 98 | self.tokenized[index] = repl 99 | 100 | return ' '.join(self.tokenized) 101 | 102 | def error(self): 103 | """Introduce a random error.""" 104 | 105 | #count = math.floor(pow(random.randint(1, 11), 2) / 50) + 1 106 | count = npchoice([0,1,2,3,4],p=[0.05,0.07,0.25,0.35,0.28]) #original (a1) 107 | #count = npchoice([0,1,2,3,4],p=[0.1,0.1,0.2,0.3,0.3]) # (a2) 108 | #count = npchoice([0,1,2,3,4,5],p=[0.1,0.1,0.2,0.2,0.2,0.2]) # (a3) 109 | #count = npchoice([0,1,2,3,4,5],p=[0.1,0.1,0.2,0.2,0.2,0.2]) # (a4) 110 | #count = npchoice([0,1,2,3,4,5],p=[0.0,0.0,0.25,0.25,0.25,0.25]) # (a5) 111 | 112 | for x in range(count): 113 | # Note: verb_error redirects to replace_error and vice versa if nothing happened 114 | error_probs = [.30,.25,.25,.20] #original (a1) 115 | #error_probs = [.25,.30,.30,.15] # (a2) 116 | #error_probs = [.40,.25,.25,.10] #(a3) 117 | #error_probs = [.30,.30,.30,.10] #(a4) 118 | #error_probs = [.35,.25,.25,.15] #(a5) 119 | 120 | error_fun = npchoice([self.insert_error, self.verb_error, self.replace_error, self.delete_error],p=error_probs) 121 | self.sentence = error_fun() 122 | self.tokenize() 123 | 124 | return self.sentence 125 | -------------------------------------------------------------------------------- /errorify/parse_verbs.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | verbs_file = "morphs.txt" 3 | 4 | def expand_dict(d): 5 | result = {} 6 | for key in d: 7 | if key in result: 8 | result[key] = result[key].union(d[key].difference({key})) 9 | else: 10 | result[key] = d[key].difference({key}) 11 | for item in d[key]: 12 | if item in result: 13 | if item != key: 14 | result[item] = result[item].union(d[key].difference({item})).union({key}) 15 | else: 16 | result[item] = result[item].union(d[key].difference({item})) 17 | else: 18 | if item != key: 19 | result[item] = d[key].difference({item}).union({key}) 20 | else: 21 | d[key].difference({item}) 22 | 23 | 24 | for key in result: 25 | result[key]=list(result[key]) 26 | return result 27 | 28 | 29 | with open(verbs_file,"r") as ip_file: 30 | ip_lines = ip_file.readlines() 31 | words = {} 32 | for line in ip_lines: 33 | line = line.strip().split() 34 | if len(line) != 3: 35 | print(line) 36 | word = line[1] 37 | word_form = line[0] 38 | if word in words: 39 | words[word].add(word_form) 40 | else: 41 | words[word]={word_form} 42 | 43 | 44 | result = expand_dict(words) 45 | pickle.dump(result,open("verbs.p","wb")) -------------------------------------------------------------------------------- /errorify/verbs.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awasthiabhijeet/PIE/474769e3c4266deefcb7dd5daf802a1306bc7c99/errorify/verbs.p -------------------------------------------------------------------------------- /example_scripts/README.md: -------------------------------------------------------------------------------- 1 | # Description 2 | end_to_end.sh is a sample script which uses a small GEC corpus 3 | obtained from first 1000 sentences of [wi+locness](https://www.cl.cam.ac.uk/research/nl/bea2019st/data/wi+locness_v2.1.bea19.tar.gz 4 | ) GEC dataset 5 | ## this script does 6 | * pre-processing to convert training data in the form of incorrect tokens and aligned edits (preprocess.sh) 7 | - This includes extracting common insertions (\Sigma_a set in the paper) from the train corpus 8 | - Common insertions define the various types Append and Replacement edits 9 | - Obtaining a parallel corpus of incorrect tokens and edits 10 | * trains a PIE model initialized with BERT checkpoint (pie_train.sh) 11 | * iteratively decodes the corrected output for [conll-14](https://www.comp.nus.edu.sg/~nlp/conll14st.html) test set (multi_round_infer.sh) 12 | * uses [m2scorer](https://github.com/nusnlp/m2scorer) to evaluate the corrected sentences in conll-14 test set (m2_eval.sh) 13 | 14 | # Instructions 15 | * data files are stored in the scratch directory outside example_scripts directory 16 | - train_incorr_sentences: (1000 incorrect sentences from (wi+locness) 17 | - train_corr_sentences: (1000 correct sentences from wi+locness) 18 | - conll_test.txt (test sentences in conll14) 19 | - official-2014.combined.m2 (m2 file provided by conll14) 20 | 21 | * Download one of the pre-trained "cased" bert checkpoints: 22 | - https://storage.googleapis.com/bert_models/2018_10_18/cased_L-12_H-768_A-12.zip 23 | or 24 | https://storage.googleapis.com/bert_models/2018_10_18/cased_L-24_H-1024_A-16.zip 25 | - Unzip the checkpoints inside the PIE_ckpt directory outside example_scripts directory 26 | - The PIE_ckpt directory should contain following files: 27 | - bert_model.ckpt.data-00000-of-00001 (original bert checkpoint) 28 | - bert_model.ckpt.index 29 | - bert_model.ckpt.meta 30 | - bert_config.json 31 | - vocab.txt (vocab of cased bert) 32 | 33 | * Run: ./end_to_end.sh from example_scripts directory 34 | 35 | * The final corrected sentences (multi_round_3_test_predictions.txt) 36 | are stored in scratch directory outside example_scripts directory 37 | and results of m2 scorer are displayed after successful run of end_to_end.sh. 38 | A successfuly trained model on these 1000 sentences should show a F_{0.5} score close to 26.6 39 | Note that these scripts train PIE model on a much smaller dataset for demonstartion purpose. 40 | 41 | * Naming of generated files after "ith round" iterative inference: 42 | - multi_round_i_test_results.txt (edits inferred by PIE at ith round) 43 | - multi_round_i_test_predictions.txt ("corrected sentences at ith round") 44 | 45 | * Follow comments in multi_round_infer.sh for further description of iterative inference. -------------------------------------------------------------------------------- /example_scripts/end_to_end.sh: -------------------------------------------------------------------------------- 1 | ./preprocess.sh 2 | ./pie_train.sh 1 1.0 $RANDOM a3 3 | ./multi_round_infer.sh 4 | ./m2eval.sh -------------------------------------------------------------------------------- /example_scripts/m2_eval.sh: -------------------------------------------------------------------------------- 1 | cd .. 2 | cd m2scorer 3 | ./m2scorer ../scratch/multi_round_3_test_predictions.txt ../scratch/official-2014.combined.m2 -------------------------------------------------------------------------------- /example_scripts/multi_round_infer.sh: -------------------------------------------------------------------------------- 1 | cur_dir=$PWD 2 | cd .. 3 | 4 | export output_dir=scratch/output_1.0_ensemble_1 5 | export data_dir=scratch 6 | export bert_cased_vocab=PIE_ckpt/vocab.txt 7 | export bert_config_file=PIE_ckpt/bert_config.json 8 | 9 | input_file=scratch/conll_test.txt 10 | 11 | echo Running Round 0... 12 | 13 | #tokenize input sentences into wordpiece tokens (output_tokens.txt) 14 | python3.6 tokenize_input.py \ 15 | --input=$input_file \ 16 | --output_tokens=$data_dir/output_tokens.txt \ 17 | --output_token_ids=$data_dir/test_incorr.txt \ 18 | --vocab_path=$bert_cased_vocab \ 19 | --do_spell_check=True 20 | 21 | #output_tokens is the wordpiece tokenized version of input (output_tokens.txt) 22 | #test_incorr.txt has token_ids of wordpiece tokens 23 | 24 | #PIE predicts edit ids for wordpiece tokenids (test_incorr.txt) 25 | cd $cur_dir 26 | ./pie_infer.sh 27 | cd .. 28 | cp $output_dir/test_results.txt $data_dir/multi_round_0_test_results.txt 29 | #test_results.txt contains edit ids inferred through PIE 30 | 31 | #apply edits (test_results.txt) on the wordpiece tokens of input (output_tokens.txt) 32 | python3.6 apply_opcode.py \ 33 | --vocab_path=$bert_cased_vocab \ 34 | --input_tokens=$data_dir/output_tokens.txt \ 35 | --edit_ids=$data_dir/multi_round_0_test_results.txt \ 36 | --output_tokens=$data_dir/multi_round_0_test_predictions.txt \ 37 | --infer_mode=conll \ 38 | --path_common_inserts=$data_dir/pickles/common_inserts.p \ 39 | --path_common_multitoken_inserts=$data_dir/pickles/common_multitoken_inserts.p \ 40 | --path_common_deletes=$data_dir/pickles/common_deletes.p \ 41 | 42 | 43 | #corrected_sentences: multi_round_0_test_predictions.txt 44 | 45 | #iterate above for 3 more rounds and refine the corrected sentences further 46 | 47 | for round_id in {1..3}; 48 | do 49 | echo Running Round $round_id ... 50 | python3.6 tokenize_input.py \ 51 | --input=$data_dir/multi_round_"$(( round_id - 1 ))"_test_predictions.txt \ 52 | --output_tokens=$data_dir/output_tokens.txt \ 53 | --output_token_ids=$data_dir/test_incorr.txt \ 54 | --vocab_path=$bert_cased_vocab 55 | 56 | cd $cur_dir 57 | ./pie_infer.sh 58 | cd .. 59 | cp $output_dir/test_results.txt $data_dir/multi_round_"$round_id"_test_results.txt 60 | 61 | python3.6 apply_opcode.py \ 62 | --vocab_path=$bert_cased_vocab \ 63 | --input_tokens=$data_dir/output_tokens.txt \ 64 | --edit_ids=$data_dir/multi_round_"$round_id"_test_results.txt \ 65 | --output_tokens=$data_dir/multi_round_"$round_id"_test_predictions.txt \ 66 | --infer_mode=conll \ 67 | --path_common_inserts=$data_dir/pickles/common_inserts.p \ 68 | --path_common_multitoken_inserts=$data_dir/pickles/common_multitoken_inserts.p \ 69 | --path_common_deletes=$data_dir/pickles/common_deletes.p 70 | done 71 | -------------------------------------------------------------------------------- /example_scripts/pie_infer.sh: -------------------------------------------------------------------------------- 1 | cd .. 2 | 3 | 4 | output_dir=scratch/output_1.0_ensemble_1 5 | data_dir=scratch 6 | bert_cased_vocab=PIE_ckpt/vocab.txt 7 | bert_config_file=PIE_ckpt/bert_config.json 8 | path_multitoken_inserts=$data_dir/pickles/common_multitoken_inserts.p 9 | path_inserts=$data_dir/pickles/common_inserts.p 10 | 11 | 12 | echo $tpu_name 13 | python3.6 word_edit_model.py \ 14 | --do_predict=True \ 15 | --data_dir=$data_dir \ 16 | --vocab_file=$bert_cased_vocab \ 17 | --bert_config_file=$bert_config_file \ 18 | --max_seq_length=128 \ 19 | --predict_batch_size=128 \ 20 | --output_dir=$output_dir \ 21 | --do_lower_case=False \ 22 | --use_tpu=True \ 23 | --tpu_name=a3 \ 24 | --tpu_zone=us-central1-a \ 25 | --path_inserts=$path_inserts \ 26 | --path_multitoken_inserts=$path_multitoken_inserts \ 27 | 28 | -------------------------------------------------------------------------------- /example_scripts/pie_train.sh: -------------------------------------------------------------------------------- 1 | cd .. 2 | 3 | if [ $# -ne 4 ] 4 | then 5 | echo "Illegal number of parameters" 6 | exit 1 7 | fi 8 | 9 | ensemble_id=$1 10 | use_tpu=True 11 | copy_weight=$2 12 | random_seed=$3 13 | tpu_name=$4 14 | 15 | learning_rate=2e-5 16 | mode=large 17 | 18 | export data_dir=scratch 19 | export path_inserts=$data_dir/pickles/common_inserts.p 20 | export path_multitoken_inserts=$data_dir/pickles/common_multitoken_inserts.p 21 | 22 | echo path_inserts: "$path_inserts" 23 | echo path_multitoken_inserts: "$path_multitoken_inserts" 24 | 25 | 26 | export CKPT_DIR=PIE_ckpt 27 | export output_dir="$data_dir"/output_"$copy_weight"_ensemble_"$ensemble_id" 28 | 29 | echo $output_dir 30 | 31 | 32 | echo $tpu_name 33 | 34 | python3.6 -u word_edit_model.py \ 35 | --random_seed=$random_seed \ 36 | --do_train=True \ 37 | --data_dir=$data_dir \ 38 | --vocab_file=$CKPT_DIR/vocab.txt \ 39 | --bert_config_file=$CKPT_DIR/bert_config.json \ 40 | --max_seq_length=64 \ 41 | --train_batch_size=64 \ 42 | --learning_rate=$learning_rate \ 43 | --num_train_epochs=4 \ 44 | --output_dir=$output_dir \ 45 | --do_lower_case=False \ 46 | --use_tpu=$use_tpu \ 47 | --tpu_name=$tpu_name \ 48 | --tpu_zone=us-central1-a \ 49 | --save_checkpoints_steps=70000 \ 50 | --iterations_per_loop=1000 \ 51 | --num_tpu_cores=8 \ 52 | --copy_weight=$copy_weight \ 53 | --path_inserts=$path_inserts \ 54 | --path_multitoken_inserts=$path_multitoken_inserts \ 55 | --init_checkpoint=$CKPT_DIR/bert_model.ckpt 56 | -------------------------------------------------------------------------------- /example_scripts/preprocess.sh: -------------------------------------------------------------------------------- 1 | cd .. 2 | 3 | path_to_cased_vocab_bert=PIE_ckpt/vocab.txt #bert's cased vocab path 4 | data_dir=scratch 5 | 6 | mkdir -p $data_dir/pickles 7 | 8 | #preprocessing consists of 2 steps 9 | # 1. extract the common inserts (\Sigma_a in the paper): 10 | # Example usage: 11 | python3.6 get_edit_vocab.py \ 12 | --vocab_path="$path_to_cased_vocab_bert" \ 13 | --incorr_sents=$data_dir/train_incorr_sentences.txt \ 14 | --correct_sents=$data_dir/train_corr_sentences.txt \ 15 | --common_inserts_dir=$data_dir/pickles \ 16 | --size_insert_list=1000 \ 17 | --size_delete_list=1000 18 | 19 | # 2. extract edits from incorrect and correct sentence pairs 20 | # Example usage: 21 | python3.6 get_seq2edits.py \ 22 | --vocab_path="$path_to_cased_vocab_bert" \ 23 | --common_inserts_dir=$data_dir/pickles \ 24 | --incorr_sents=$data_dir/train_incorr_sentences.txt \ 25 | --correct_sents=$data_dir/train_corr_sentences.txt \ 26 | --incorr_tokens=$data_dir/train_incorr_tokens.txt \ 27 | --correct_tokens=$data_dir/train_corr_tokens.txt \ 28 | --incorr_token_ids=$data_dir/train_incorr.txt \ 29 | --edit_ids=$data_dir/train_labels.txt 30 | 31 | # incorr_token_ids and edit_ids are used by word_edit_model.py for training GEC model 32 | 33 | # NOTE: get_seq2edits.py can regect sentence pairs involving q-gram insertions where q > 2 34 | -------------------------------------------------------------------------------- /get_edit_vocab.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import seq2edits_utils 4 | from collections import defaultdict 5 | import tokenization 6 | import argparse 7 | from utils import generator_based_read_file, do_pickle, pretty, custom_tokenize 8 | from collections import Counter, defaultdict 9 | from joblib import Parallel, delayed 10 | from tqdm import tqdm 11 | 12 | 13 | def add_arguments(parser): 14 | """Build ArgumentParser.""" 15 | parser.add_argument("--vocab_path", type=str, default=None, help="path to bert's cased vocab file") 16 | parser.add_argument("--incorr_sents", type=str, default=None, help="path to incorrect sentence file") 17 | parser.add_argument("--correct_sents", type=str, default=None, help="path to correct sentence file") 18 | parser.add_argument("--common_inserts_dir", type=str, default="pickles", help="path to store common inserts in pickles") 19 | parser.add_argument("--size_insert_list", type=int, default=500, help="size of common insertions list") 20 | parser.add_argument("--size_delete_list", type=int, default=500, help="size of common deletions list") 21 | # all the datasets can be obtained from here: https://www.cl.cam.ac.uk/research/nl/bea2019st/ 22 | 23 | parser = argparse.ArgumentParser() 24 | add_arguments(parser) 25 | FLAGS, unparsed = parser.parse_known_args() 26 | wordpiece_tokenizer = tokenization.FullTokenizer(FLAGS.vocab_path, do_lower_case=False) 27 | 28 | def merge_dicts(dicts): 29 | merged = defaultdict(int) 30 | for d in dicts: 31 | for elem in d: 32 | merged[elem] += d[elem] 33 | return merged 34 | 35 | def update_dicts(insert_dict, delete_dict, rejected, processed): 36 | insert_dict = merge_dicts([p[0] for p in processed]+[insert_dict]) 37 | delete_dict = merge_dicts([p[1] for p in processed]+[delete_dict]) 38 | rejected += sum(p[2] for p in processed) 39 | return insert_dict, delete_dict, rejected 40 | 41 | def get_ins_dels(incorr_line, correct_line): 42 | ins = defaultdict(int) 43 | dels = defaultdict(int) 44 | rejected = 0 45 | 46 | incorr_tokens = custom_tokenize(incorr_line, wordpiece_tokenizer, mode="train") 47 | correct_tokens = custom_tokenize(correct_line, wordpiece_tokenizer, mode="train") 48 | diffs = seq2edits_utils.ndiff(incorr_tokens, correct_tokens) 49 | 50 | for item in diffs: 51 | if item[0]=="+": 52 | if len(item[2:].split())>2: 53 | return defaultdict(int), defaultdict(int), 1 54 | ins[item[2:]]+=1 55 | elif item[0]=="-": 56 | dels[item[2:]]+=1 57 | 58 | return ins,dels,0 59 | 60 | def segregate_insertions(insert_dict): 61 | #segregates unigram and bigram insetions 62 | #returns unigram and bigram list 63 | unigrams = [] 64 | bigrams = [] 65 | 66 | for item in insert_dict: 67 | if len(item.split())==2: 68 | bigrams.append(item) 69 | elif len(item.split())==1: 70 | unigrams.append(item) 71 | else: 72 | print("ERROR: we only support upto bigram insertions") 73 | 74 | return unigrams,bigrams 75 | 76 | # Read raw data 77 | pretty.pheader('Reading Input') 78 | incorrect_lines_generator = generator_based_read_file(FLAGS.incorr_sents, 'incorrect lines') 79 | correct_lines_generator = generator_based_read_file(FLAGS.correct_sents, 'correct lines') 80 | 81 | insert_dict={} 82 | delete_dict={} 83 | rejected = 0 #number of sentences having more q-gram insertion where q>2 84 | 85 | for incorrect_lines, correct_lines in zip(incorrect_lines_generator, correct_lines_generator): 86 | processed_dicts = Parallel(n_jobs=-1)(delayed(get_ins_dels)(*s) for s in tqdm( 87 | zip(incorrect_lines, correct_lines), total=len(incorrect_lines))) 88 | 89 | insert_dict,delete_dict, rejected=update_dicts(insert_dict, delete_dict, rejected, processed_dicts) 90 | 91 | insert_dict=dict(Counter(insert_dict).most_common(FLAGS.size_insert_list)) 92 | delete_dict=dict(Counter(delete_dict).most_common(FLAGS.size_delete_list)) 93 | 94 | #insert_dict corresponds to \Sigma_a in the paper. 95 | #Elements in \Sigma_a are considered for appends and replacements both 96 | unigram_inserts, bigram_inserts = segregate_insertions(insert_dict) 97 | 98 | do_pickle(unigram_inserts,os.path.join(FLAGS.common_inserts_dir,"common_inserts.p")) 99 | do_pickle(bigram_inserts,os.path.join(FLAGS.common_inserts_dir,"common_multitoken_inserts.p")) 100 | do_pickle(delete_dict,os.path.join(FLAGS.common_inserts_dir,"common_deletes.p")) -------------------------------------------------------------------------------- /get_seq2edits.py: -------------------------------------------------------------------------------- 1 | import os 2 | from joblib import Parallel, delayed 3 | from tqdm import tqdm 4 | from utils import generator_based_read_file, do_pickle, pretty, custom_tokenize 5 | from opcodes import Opcodes 6 | from transform_suffixes import SuffixTransform 7 | import tokenization 8 | import argparse 9 | import seq2edits_utils 10 | 11 | def add_arguments(parser): 12 | """Build ArgumentParser.""" 13 | parser.add_argument("--vocab_path", type=str, default=None, help="path to bert's cased vocab file") 14 | parser.add_argument("--common_inserts_dir", type=str, default="pickles", help="path to load common inserts") 15 | parser.add_argument("--incorr_sents", type=str, default=None, help="path to incorrect sentence file") 16 | parser.add_argument("--correct_sents", type=str, default=None, help="path to correct sentence file") 17 | parser.add_argument("--incorr_tokens", type=str, default=None, help="path to tokenized incorrect sentences") 18 | parser.add_argument("--correct_tokens", type=str, default=None, help="path to tokenized correct sentences") 19 | parser.add_argument("--incorr_token_ids", type=str, default=None, help="path to incorrect token ids of sentences") 20 | parser.add_argument("--edit_ids", type=str, default=None, help="path to edit ids for each sequence in incorr_token_ids") 21 | 22 | parser = argparse.ArgumentParser() 23 | add_arguments(parser) 24 | FLAGS, unparsed = parser.parse_known_args() 25 | wordpiece_tokenizer = tokenization.FullTokenizer(FLAGS.vocab_path, do_lower_case=False) 26 | 27 | opcodes = Opcodes( 28 | path_common_inserts=os.path.join(FLAGS.common_inserts_dir,"common_inserts.p"), 29 | path_common_multitoken_inserts=os.path.join(FLAGS.common_inserts_dir,"common_multitoken_inserts.p"), 30 | use_transforms=True) 31 | 32 | def seq2edits(incorr_line, correct_line): 33 | # Seq2Edits function (Described in Section 2.2 of the paper) 34 | # obtains edit ids from incorrect and correct tokens 35 | # input: incorrect line and correct line 36 | # output: incorr_tokens, correct_tokens, incorr token ids, edit ids 37 | 38 | #tokenize incorr_line and correct_line 39 | incorr_tokens = custom_tokenize(incorr_line, wordpiece_tokenizer, mode="train") 40 | correct_tokens = custom_tokenize(correct_line, wordpiece_tokenizer, mode="train") 41 | #generate diffs using modified edit distance algorith 42 | # (Described in Appendix A.1 of the paper) 43 | diffs = seq2edits_utils.ndiff(incorr_tokens, correct_tokens) 44 | # align diffs to get edits 45 | edit_ids = diffs_to_edits(diffs) 46 | 47 | if not edit_ids: 48 | return None 49 | #get incorrect token ids 50 | incorr_tok_ids = wordpiece_tokenizer.convert_tokens_to_ids(incorr_tokens) 51 | return incorr_tokens, correct_tokens, incorr_tok_ids, edit_ids 52 | 53 | def diffs_to_edits(diffs): 54 | #converts diffs to edit ids 55 | 56 | prev_edit = None 57 | edits = [] 58 | for i,op in enumerate(diffs): 59 | # op has following forms: " data" (no-edit) or "- data" (delete) or "+ data" (insert) 60 | # thus op[0] gives the operation and op[2:] gives the argument 61 | # (see ndiff function in diff_edit_distance) 62 | 63 | if op[0] == " ": 64 | edits.append(opcodes.CPY) 65 | elif op[0] == "-": 66 | edits.append(opcodes.DEL) 67 | elif op[0] == "+": #APPEND or REPLACE or SUFFIX TRANFORM 68 | assert len(edits)>0, "+ or - cannot occour in beginning since all sentences were\ 69 | were prefixed by a [CLS] token" 70 | 71 | q_gram = op[2:] #argument of on while op[0] operates 72 | 73 | if len(q_gram.split()) > 2: #reject q_gram if q>2 74 | return None 75 | 76 | if edits[-1] == opcodes.CPY: # CASE OF APPEND / APPEND BASED SUFFIX TRANSFROM (e.g. play -> played) 77 | if q_gram in opcodes.APPEND_SUFFIX: #priority to SUFFIX TRANSFORM 78 | edits[-1] = opcodes.APPEND_SUFFIX[q_gram] 79 | elif q_gram in opcodes.APPEND: 80 | edits[-1] = opcodes.APPEND[q_gram] 81 | else: 82 | # appending q_gram is not supported 83 | #we ignore the append and edits[-1] is retained as a COPY 84 | pass 85 | 86 | elif edits[-1] == opcodes.DEL: # CASE of SUFFIX TRANSFORMATION / REPLACE EDIT 87 | del_token = diffs[i-1][2:] #replaced word 88 | # check for transfomation match 89 | st = SuffixTransform(del_token, q_gram,opcodes).match() 90 | if st: 91 | edits[-1] = st 92 | else: 93 | #check for replace opration of transformation match failed 94 | if q_gram in opcodes.REP: 95 | edits[-1] = opcodes.REP[q_gram] 96 | else: 97 | # replacement with q_gram is not supported 98 | # we ignore the replacement and UNDO delete by having edits[-1] as COPY 99 | edits[-1] = opcodes.CPY 100 | else: 101 | #since inserts are merged in diffs, edits[-1] is either a CPY or a DEL, if op[0] == "+" 102 | print("This should never occour") 103 | exit(1) 104 | return edits 105 | 106 | pretty.pheader('Reading Input') 107 | incorrect_lines_generator = generator_based_read_file(FLAGS.incorr_sents, 'incorrect lines') 108 | correct_lines_generator = generator_based_read_file(FLAGS.correct_sents, 'correct lines') 109 | 110 | with open(FLAGS.incorr_tokens,"w") as ic_toks, \ 111 | open(FLAGS.correct_tokens,"w") as c_toks, \ 112 | open(FLAGS.incorr_token_ids,"w") as ic_tok_ids, \ 113 | open(FLAGS.edit_ids,"w") as e_ids: 114 | for incorrect_lines, correct_lines in zip(incorrect_lines_generator, correct_lines_generator): 115 | processed = Parallel(n_jobs=-1)(delayed(seq2edits)(*s) for s in tqdm( 116 | zip(incorrect_lines, correct_lines), total=len(incorrect_lines))) 117 | 118 | processed = [p for p in processed if p] 119 | for p in processed: 120 | ic_toks.write(" ".join(p[0]) + "\n") 121 | c_toks.write(" ".join(p[1]) + "\n") 122 | ic_tok_ids.write(" ".join(map(str,p[2])) + "\n") 123 | e_ids.write(" ".join(map(str,p[3])) + "\n") 124 | 125 | 126 | 127 | 128 | 129 | -------------------------------------------------------------------------------- /install_dependencies.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Script to install requirements 4 | python3 -m pip install -r requirements.txt 5 | python3 -m spacy download en_core_web_sm 6 | -------------------------------------------------------------------------------- /modeling.py: -------------------------------------------------------------------------------- 1 | # code borrowed from https://github.com/google-research/bert 2 | # coding=utf-8 3 | # Copyright 2018 The Google AI Language Team Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """The main BERT model and related functions.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import collections 23 | import copy 24 | import json 25 | import math 26 | import re 27 | import six 28 | import tensorflow as tf 29 | 30 | 31 | class BertConfig(object): 32 | """Configuration for `BertModel`.""" 33 | 34 | def __init__(self, 35 | vocab_size, 36 | hidden_size=768, 37 | num_hidden_layers=12, 38 | num_attention_heads=12, 39 | intermediate_size=3072, 40 | hidden_act="gelu", 41 | hidden_dropout_prob=0.1, 42 | attention_probs_dropout_prob=0.1, 43 | max_position_embeddings=512, 44 | type_vocab_size=16, 45 | initializer_range=0.02): 46 | """Constructs BertConfig. 47 | 48 | Args: 49 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`. 50 | hidden_size: Size of the encoder layers and the pooler layer. 51 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 52 | num_attention_heads: Number of attention heads for each attention layer in 53 | the Transformer encoder. 54 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 55 | layer in the Transformer encoder. 56 | hidden_act: The non-linear activation function (function or string) in the 57 | encoder and pooler. 58 | hidden_dropout_prob: The dropout probability for all fully connected 59 | layers in the embeddings, encoder, and pooler. 60 | attention_probs_dropout_prob: The dropout ratio for the attention 61 | probabilities. 62 | max_position_embeddings: The maximum sequence length that this model might 63 | ever be used with. Typically set this to something large just in case 64 | (e.g., 512 or 1024 or 2048). 65 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 66 | `BertModel`. 67 | initializer_range: The stdev of the truncated_normal_initializer for 68 | initializing all weight matrices. 69 | """ 70 | self.vocab_size = vocab_size 71 | self.hidden_size = hidden_size 72 | self.num_hidden_layers = num_hidden_layers 73 | self.num_attention_heads = num_attention_heads 74 | self.hidden_act = hidden_act 75 | self.intermediate_size = intermediate_size 76 | self.hidden_dropout_prob = hidden_dropout_prob 77 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 78 | self.max_position_embeddings = max_position_embeddings 79 | self.type_vocab_size = type_vocab_size 80 | self.initializer_range = initializer_range 81 | 82 | @classmethod 83 | def from_dict(cls, json_object): 84 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 85 | config = BertConfig(vocab_size=None) 86 | for (key, value) in six.iteritems(json_object): 87 | config.__dict__[key] = value 88 | return config 89 | 90 | @classmethod 91 | def from_json_file(cls, json_file): 92 | """Constructs a `BertConfig` from a json file of parameters.""" 93 | with tf.gfile.GFile(json_file, "r") as reader: 94 | text = reader.read() 95 | return cls.from_dict(json.loads(text)) 96 | 97 | def to_dict(self): 98 | """Serializes this instance to a Python dictionary.""" 99 | output = copy.deepcopy(self.__dict__) 100 | return output 101 | 102 | def to_json_string(self): 103 | """Serializes this instance to a JSON string.""" 104 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 105 | 106 | 107 | class BertModel(object): 108 | """BERT model ("Bidirectional Encoder Representations from Transformers"). 109 | 110 | Example usage: 111 | 112 | ```python 113 | # Already been converted into WordPiece token ids 114 | input_ids = tf.constant([[31, 51, 99], [15, 5, 0]]) 115 | input_mask = tf.constant([[1, 1, 1], [1, 1, 0]]) 116 | token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]]) 117 | 118 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512, 119 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 120 | 121 | model = modeling.BertModel(config=config, is_training=True, 122 | input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids) 123 | 124 | label_embeddings = tf.get_variable(...) 125 | pooled_output = model.get_pooled_output() 126 | logits = tf.matmul(pooled_output, label_embeddings) 127 | ... 128 | ``` 129 | """ 130 | 131 | def __init__(self, 132 | config, 133 | is_training, 134 | input_ids, 135 | input_mask=None, 136 | token_type_ids=None, 137 | use_one_hot_embeddings=True, 138 | scope=None): 139 | """Constructor for BertModel. 140 | 141 | Args: 142 | config: `BertConfig` instance. 143 | is_training: bool. true for training model, false for eval model. Controls 144 | whether dropout will be applied. 145 | input_ids: int32 Tensor of shape [batch_size, seq_length]. 146 | input_mask: (optional) int32 Tensor of shape [batch_size, seq_length]. 147 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length]. 148 | use_one_hot_embeddings: (optional) bool. Whether to use one-hot word 149 | embeddings or tf.embedding_lookup() for the word embeddings. On the TPU, 150 | it is much faster if this is True, on the CPU or GPU, it is faster if 151 | this is False. 152 | scope: (optional) variable scope. Defaults to "bert". 153 | 154 | Raises: 155 | ValueError: The config is invalid or one of the input tensor shapes 156 | is invalid. 157 | """ 158 | config = copy.deepcopy(config) 159 | if not is_training: 160 | config.hidden_dropout_prob = 0.0 161 | config.attention_probs_dropout_prob = 0.0 162 | 163 | input_shape = get_shape_list(input_ids, expected_rank=2) 164 | batch_size = input_shape[0] 165 | seq_length = input_shape[1] 166 | 167 | if input_mask is None: 168 | input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32) 169 | 170 | if token_type_ids is None: 171 | token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32) 172 | 173 | with tf.variable_scope(scope, default_name="bert"): 174 | with tf.variable_scope("embeddings"): 175 | # Perform embedding lookup on the word ids. 176 | (self.embedding_output, self.embedding_table) = embedding_lookup( 177 | input_ids=input_ids, 178 | vocab_size=config.vocab_size, 179 | embedding_size=config.hidden_size, 180 | initializer_range=config.initializer_range, 181 | word_embedding_name="word_embeddings", 182 | use_one_hot_embeddings=use_one_hot_embeddings) 183 | 184 | # Add positional embeddings and token type embeddings, then layer 185 | # normalize and perform dropout. 186 | self.embedding_output = embedding_postprocessor( 187 | input_tensor=self.embedding_output, 188 | use_token_type=True, 189 | token_type_ids=token_type_ids, 190 | token_type_vocab_size=config.type_vocab_size, 191 | token_type_embedding_name="token_type_embeddings", 192 | use_position_embeddings=True, 193 | position_embedding_name="position_embeddings", 194 | initializer_range=config.initializer_range, 195 | max_position_embeddings=config.max_position_embeddings, 196 | dropout_prob=config.hidden_dropout_prob) 197 | 198 | with tf.variable_scope("encoder"): 199 | # This converts a 2D mask of shape [batch_size, seq_length] to a 3D 200 | # mask of shape [batch_size, seq_length, seq_length] which is used 201 | # for the attention scores. 202 | attention_mask = create_attention_mask_from_input_mask( 203 | input_ids, input_mask) 204 | 205 | # Run the stacked transformer. 206 | # `sequence_output` shape = [batch_size, seq_length, hidden_size]. 207 | self.all_encoder_layers = transformer_model( 208 | input_tensor=self.embedding_output, 209 | attention_mask=attention_mask, 210 | hidden_size=config.hidden_size, 211 | num_hidden_layers=config.num_hidden_layers, 212 | num_attention_heads=config.num_attention_heads, 213 | intermediate_size=config.intermediate_size, 214 | intermediate_act_fn=get_activation(config.hidden_act), 215 | hidden_dropout_prob=config.hidden_dropout_prob, 216 | attention_probs_dropout_prob=config.attention_probs_dropout_prob, 217 | initializer_range=config.initializer_range, 218 | do_return_all_layers=True) 219 | 220 | self.sequence_output = self.all_encoder_layers[-1] 221 | # The "pooler" converts the encoded sequence tensor of shape 222 | # [batch_size, seq_length, hidden_size] to a tensor of shape 223 | # [batch_size, hidden_size]. This is necessary for segment-level 224 | # (or segment-pair-level) classification tasks where we need a fixed 225 | # dimensional representation of the segment. 226 | with tf.variable_scope("pooler"): 227 | # We "pool" the model by simply taking the hidden state corresponding 228 | # to the first token. We assume that this has been pre-trained 229 | first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1) 230 | self.pooled_output = tf.layers.dense( 231 | first_token_tensor, 232 | config.hidden_size, 233 | activation=tf.tanh, 234 | kernel_initializer=create_initializer(config.initializer_range)) 235 | 236 | def get_pooled_output(self): 237 | return self.pooled_output 238 | 239 | def get_sequence_output(self): 240 | """Gets final hidden layer of encoder. 241 | 242 | Returns: 243 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding 244 | to the final hidden of the transformer encoder. 245 | """ 246 | return self.sequence_output 247 | 248 | def get_all_encoder_layers(self): 249 | return self.all_encoder_layers 250 | 251 | def get_embedding_output(self): 252 | """Gets output of the embedding lookup (i.e., input to the transformer). 253 | 254 | Returns: 255 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding 256 | to the output of the embedding layer, after summing the word 257 | embeddings with the positional embeddings and the token type embeddings, 258 | then performing layer normalization. This is the input to the transformer. 259 | """ 260 | return self.embedding_output 261 | 262 | def get_embedding_table(self): 263 | return self.embedding_table 264 | 265 | 266 | def gelu(input_tensor): 267 | """Gaussian Error Linear Unit. 268 | 269 | This is a smoother version of the RELU. 270 | Original paper: https://arxiv.org/abs/1606.08415 271 | 272 | Args: 273 | input_tensor: float Tensor to perform activation. 274 | 275 | Returns: 276 | `input_tensor` with the GELU activation applied. 277 | """ 278 | cdf = 0.5 * (1.0 + tf.erf(input_tensor / tf.sqrt(2.0))) 279 | return input_tensor * cdf 280 | 281 | 282 | def get_activation(activation_string): 283 | """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`. 284 | 285 | Args: 286 | activation_string: String name of the activation function. 287 | 288 | Returns: 289 | A Python function corresponding to the activation function. If 290 | `activation_string` is None, empty, or "linear", this will return None. 291 | If `activation_string` is not a string, it will return `activation_string`. 292 | 293 | Raises: 294 | ValueError: The `activation_string` does not correspond to a known 295 | activation. 296 | """ 297 | 298 | # We assume that anything that"s not a string is already an activation 299 | # function, so we just return it. 300 | if not isinstance(activation_string, six.string_types): 301 | return activation_string 302 | 303 | if not activation_string: 304 | return None 305 | 306 | act = activation_string.lower() 307 | if act == "linear": 308 | return None 309 | elif act == "relu": 310 | return tf.nn.relu 311 | elif act == "gelu": 312 | return gelu 313 | elif act == "tanh": 314 | return tf.tanh 315 | else: 316 | raise ValueError("Unsupported activation: %s" % act) 317 | 318 | 319 | def get_assignment_map_from_checkpoint(tvars, init_checkpoint): 320 | """Compute the union of the current variables and checkpoint variables.""" 321 | assignment_map = {} 322 | initialized_variable_names = {} 323 | 324 | name_to_variable = collections.OrderedDict() 325 | for var in tvars: 326 | name = var.name 327 | m = re.match("^(.*):\\d+$", name) 328 | if m is not None: 329 | name = m.group(1) 330 | name_to_variable[name] = var 331 | 332 | init_vars = tf.train.list_variables(init_checkpoint) 333 | 334 | assignment_map = collections.OrderedDict() 335 | for x in init_vars: 336 | (name, var) = (x[0], x[1]) 337 | if name not in name_to_variable: 338 | continue 339 | assignment_map[name] = name 340 | initialized_variable_names[name] = 1 341 | initialized_variable_names[name + ":0"] = 1 342 | 343 | return (assignment_map, initialized_variable_names) 344 | 345 | 346 | def dropout(input_tensor, dropout_prob): 347 | """Perform dropout. 348 | 349 | Args: 350 | input_tensor: float Tensor. 351 | dropout_prob: Python float. The probability of dropping out a value (NOT of 352 | *keeping* a dimension as in `tf.nn.dropout`). 353 | 354 | Returns: 355 | A version of `input_tensor` with dropout applied. 356 | """ 357 | if dropout_prob is None or dropout_prob == 0.0: 358 | return input_tensor 359 | 360 | output = tf.nn.dropout(input_tensor, 1.0 - dropout_prob) 361 | return output 362 | 363 | 364 | def layer_norm(input_tensor, name=None): 365 | """Run layer normalization on the last dimension of the tensor.""" 366 | return tf.contrib.layers.layer_norm( 367 | inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name) 368 | 369 | 370 | def layer_norm_and_dropout(input_tensor, dropout_prob, name=None): 371 | """Runs layer normalization followed by dropout.""" 372 | output_tensor = layer_norm(input_tensor, name) 373 | output_tensor = dropout(output_tensor, dropout_prob) 374 | return output_tensor 375 | 376 | 377 | def create_initializer(initializer_range=0.02): 378 | """Creates a `truncated_normal_initializer` with the given range.""" 379 | return tf.truncated_normal_initializer(stddev=initializer_range) 380 | 381 | 382 | def embedding_lookup(input_ids, 383 | vocab_size, 384 | embedding_size=128, 385 | initializer_range=0.02, 386 | word_embedding_name="word_embeddings", 387 | use_one_hot_embeddings=False): 388 | """Looks up words embeddings for id tensor. 389 | 390 | Args: 391 | input_ids: int32 Tensor of shape [batch_size, seq_length] containing word 392 | ids. 393 | vocab_size: int. Size of the embedding vocabulary. 394 | embedding_size: int. Width of the word embeddings. 395 | initializer_range: float. Embedding initialization range. 396 | word_embedding_name: string. Name of the embedding table. 397 | use_one_hot_embeddings: bool. If True, use one-hot method for word 398 | embeddings. If False, use `tf.nn.embedding_lookup()`. One hot is better 399 | for TPUs. 400 | 401 | Returns: 402 | float Tensor of shape [batch_size, seq_length, embedding_size]. 403 | """ 404 | # This function assumes that the input is of shape [batch_size, seq_length, 405 | # num_inputs]. 406 | # 407 | # If the input is a 2D tensor of shape [batch_size, seq_length], we 408 | # reshape to [batch_size, seq_length, 1]. 409 | if input_ids.shape.ndims == 2: 410 | input_ids = tf.expand_dims(input_ids, axis=[-1]) 411 | 412 | embedding_table = tf.get_variable( 413 | name=word_embedding_name, 414 | shape=[vocab_size, embedding_size], 415 | initializer=create_initializer(initializer_range)) 416 | 417 | if use_one_hot_embeddings: 418 | flat_input_ids = tf.reshape(input_ids, [-1]) 419 | one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size) 420 | output = tf.matmul(one_hot_input_ids, embedding_table) 421 | else: 422 | output = tf.nn.embedding_lookup(embedding_table, input_ids) 423 | 424 | input_shape = get_shape_list(input_ids) 425 | 426 | output = tf.reshape(output, 427 | input_shape[0:-1] + [input_shape[-1] * embedding_size]) 428 | return (output, embedding_table) 429 | 430 | 431 | def embedding_postprocessor(input_tensor, 432 | use_token_type=False, 433 | token_type_ids=None, 434 | token_type_vocab_size=16, 435 | token_type_embedding_name="token_type_embeddings", 436 | use_position_embeddings=True, 437 | position_embedding_name="position_embeddings", 438 | initializer_range=0.02, 439 | max_position_embeddings=512, 440 | dropout_prob=0.1): 441 | """Performs various post-processing on a word embedding tensor. 442 | 443 | Args: 444 | input_tensor: float Tensor of shape [batch_size, seq_length, 445 | embedding_size]. 446 | use_token_type: bool. Whether to add embeddings for `token_type_ids`. 447 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length]. 448 | Must be specified if `use_token_type` is True. 449 | token_type_vocab_size: int. The vocabulary size of `token_type_ids`. 450 | token_type_embedding_name: string. The name of the embedding table variable 451 | for token type ids. 452 | use_position_embeddings: bool. Whether to add position embeddings for the 453 | position of each token in the sequence. 454 | position_embedding_name: string. The name of the embedding table variable 455 | for positional embeddings. 456 | initializer_range: float. Range of the weight initialization. 457 | max_position_embeddings: int. Maximum sequence length that might ever be 458 | used with this model. This can be longer than the sequence length of 459 | input_tensor, but cannot be shorter. 460 | dropout_prob: float. Dropout probability applied to the final output tensor. 461 | 462 | Returns: 463 | float tensor with same shape as `input_tensor`. 464 | 465 | Raises: 466 | ValueError: One of the tensor shapes or input values is invalid. 467 | """ 468 | input_shape = get_shape_list(input_tensor, expected_rank=3) 469 | batch_size = input_shape[0] 470 | seq_length = input_shape[1] 471 | width = input_shape[2] 472 | 473 | output = input_tensor 474 | 475 | if use_token_type: 476 | if token_type_ids is None: 477 | raise ValueError("`token_type_ids` must be specified if" 478 | "`use_token_type` is True.") 479 | token_type_table = tf.get_variable( 480 | name=token_type_embedding_name, 481 | shape=[token_type_vocab_size, width], 482 | initializer=create_initializer(initializer_range)) 483 | # This vocab will be small so we always do one-hot here, since it is always 484 | # faster for a small vocabulary. 485 | flat_token_type_ids = tf.reshape(token_type_ids, [-1]) 486 | one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size) 487 | token_type_embeddings = tf.matmul(one_hot_ids, token_type_table) 488 | token_type_embeddings = tf.reshape(token_type_embeddings, 489 | [batch_size, seq_length, width]) 490 | output += token_type_embeddings 491 | 492 | if use_position_embeddings: 493 | assert_op = tf.assert_less_equal(seq_length, max_position_embeddings) 494 | with tf.control_dependencies([assert_op]): 495 | full_position_embeddings = tf.get_variable( 496 | name=position_embedding_name, 497 | shape=[max_position_embeddings, width], 498 | initializer=create_initializer(initializer_range)) 499 | # Since the position embedding table is a learned variable, we create it 500 | # using a (long) sequence length `max_position_embeddings`. The actual 501 | # sequence length might be shorter than this, for faster training of 502 | # tasks that do not have long sequences. 503 | # 504 | # So `full_position_embeddings` is effectively an embedding table 505 | # for position [0, 1, 2, ..., max_position_embeddings-1], and the current 506 | # sequence has positions [0, 1, 2, ... seq_length-1], so we can just 507 | # perform a slice. 508 | position_embeddings = tf.slice(full_position_embeddings, [0, 0], 509 | [seq_length, -1]) 510 | num_dims = len(output.shape.as_list()) 511 | 512 | # Only the last two dimensions are relevant (`seq_length` and `width`), so 513 | # we broadcast among the first dimensions, which is typically just 514 | # the batch size. 515 | position_broadcast_shape = [] 516 | for _ in range(num_dims - 2): 517 | position_broadcast_shape.append(1) 518 | position_broadcast_shape.extend([seq_length, width]) 519 | position_embeddings = tf.reshape(position_embeddings, 520 | position_broadcast_shape) 521 | output += position_embeddings 522 | 523 | output = layer_norm_and_dropout(output, dropout_prob) 524 | return output 525 | 526 | 527 | def create_attention_mask_from_input_mask(from_tensor, to_mask): 528 | """Create 3D attention mask from a 2D tensor mask. 529 | 530 | Args: 531 | from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...]. 532 | to_mask: int32 Tensor of shape [batch_size, to_seq_length]. 533 | 534 | Returns: 535 | float Tensor of shape [batch_size, from_seq_length, to_seq_length]. 536 | """ 537 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) 538 | batch_size = from_shape[0] 539 | from_seq_length = from_shape[1] 540 | 541 | to_shape = get_shape_list(to_mask, expected_rank=2) 542 | to_seq_length = to_shape[1] 543 | 544 | to_mask = tf.cast( 545 | tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32) 546 | 547 | # We don't assume that `from_tensor` is a mask (although it could be). We 548 | # don't actually care if we attend *from* padding tokens (only *to* padding) 549 | # tokens so we create a tensor of all ones. 550 | # 551 | # `broadcast_ones` = [batch_size, from_seq_length, 1] 552 | broadcast_ones = tf.ones( 553 | shape=[batch_size, from_seq_length, 1], dtype=tf.float32) 554 | 555 | # Here we broadcast along two dimensions to create the mask. 556 | mask = broadcast_ones * to_mask 557 | 558 | return mask 559 | 560 | 561 | def attention_layer(from_tensor, 562 | to_tensor, 563 | attention_mask=None, 564 | num_attention_heads=1, 565 | size_per_head=512, 566 | query_act=None, 567 | key_act=None, 568 | value_act=None, 569 | attention_probs_dropout_prob=0.0, 570 | initializer_range=0.02, 571 | do_return_2d_tensor=False, 572 | batch_size=None, 573 | from_seq_length=None, 574 | to_seq_length=None): 575 | """Performs multi-headed attention from `from_tensor` to `to_tensor`. 576 | 577 | This is an implementation of multi-headed attention based on "Attention 578 | is all you Need". If `from_tensor` and `to_tensor` are the same, then 579 | this is self-attention. Each timestep in `from_tensor` attends to the 580 | corresponding sequence in `to_tensor`, and returns a fixed-with vector. 581 | 582 | This function first projects `from_tensor` into a "query" tensor and 583 | `to_tensor` into "key" and "value" tensors. These are (effectively) a list 584 | of tensors of length `num_attention_heads`, where each tensor is of shape 585 | [batch_size, seq_length, size_per_head]. 586 | 587 | Then, the query and key tensors are dot-producted and scaled. These are 588 | softmaxed to obtain attention probabilities. The value tensors are then 589 | interpolated by these probabilities, then concatenated back to a single 590 | tensor and returned. 591 | 592 | In practice, the multi-headed attention are done with transposes and 593 | reshapes rather than actual separate tensors. 594 | 595 | Args: 596 | from_tensor: float Tensor of shape [batch_size, from_seq_length, 597 | from_width]. 598 | to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width]. 599 | attention_mask: (optional) int32 Tensor of shape [batch_size, 600 | from_seq_length, to_seq_length]. The values should be 1 or 0. The 601 | attention scores will effectively be set to -infinity for any positions in 602 | the mask that are 0, and will be unchanged for positions that are 1. 603 | num_attention_heads: int. Number of attention heads. 604 | size_per_head: int. Size of each attention head. 605 | query_act: (optional) Activation function for the query transform. 606 | key_act: (optional) Activation function for the key transform. 607 | value_act: (optional) Activation function for the value transform. 608 | attention_probs_dropout_prob: (optional) float. Dropout probability of the 609 | attention probabilities. 610 | initializer_range: float. Range of the weight initializer. 611 | do_return_2d_tensor: bool. If True, the output will be of shape [batch_size 612 | * from_seq_length, num_attention_heads * size_per_head]. If False, the 613 | output will be of shape [batch_size, from_seq_length, num_attention_heads 614 | * size_per_head]. 615 | batch_size: (Optional) int. If the input is 2D, this might be the batch size 616 | of the 3D version of the `from_tensor` and `to_tensor`. 617 | from_seq_length: (Optional) If the input is 2D, this might be the seq length 618 | of the 3D version of the `from_tensor`. 619 | to_seq_length: (Optional) If the input is 2D, this might be the seq length 620 | of the 3D version of the `to_tensor`. 621 | 622 | Returns: 623 | float Tensor of shape [batch_size, from_seq_length, 624 | num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is 625 | true, this will be of shape [batch_size * from_seq_length, 626 | num_attention_heads * size_per_head]). 627 | 628 | Raises: 629 | ValueError: Any of the arguments or tensor shapes are invalid. 630 | """ 631 | 632 | def transpose_for_scores(input_tensor, batch_size, num_attention_heads, 633 | seq_length, width): 634 | output_tensor = tf.reshape( 635 | input_tensor, [batch_size, seq_length, num_attention_heads, width]) 636 | 637 | output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3]) 638 | return output_tensor 639 | 640 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) 641 | to_shape = get_shape_list(to_tensor, expected_rank=[2, 3]) 642 | 643 | if len(from_shape) != len(to_shape): 644 | raise ValueError( 645 | "The rank of `from_tensor` must match the rank of `to_tensor`.") 646 | 647 | if len(from_shape) == 3: 648 | batch_size = from_shape[0] 649 | from_seq_length = from_shape[1] 650 | to_seq_length = to_shape[1] 651 | elif len(from_shape) == 2: 652 | if (batch_size is None or from_seq_length is None or to_seq_length is None): 653 | raise ValueError( 654 | "When passing in rank 2 tensors to attention_layer, the values " 655 | "for `batch_size`, `from_seq_length`, and `to_seq_length` " 656 | "must all be specified.") 657 | 658 | # Scalar dimensions referenced here: 659 | # B = batch size (number of sequences) 660 | # F = `from_tensor` sequence length 661 | # T = `to_tensor` sequence length 662 | # N = `num_attention_heads` 663 | # H = `size_per_head` 664 | 665 | from_tensor_2d = reshape_to_matrix(from_tensor) 666 | to_tensor_2d = reshape_to_matrix(to_tensor) 667 | 668 | # `query_layer` = [B*F, N*H] 669 | query_layer = tf.layers.dense( 670 | from_tensor_2d, 671 | num_attention_heads * size_per_head, 672 | activation=query_act, 673 | name="query", 674 | kernel_initializer=create_initializer(initializer_range)) 675 | 676 | # `key_layer` = [B*T, N*H] 677 | key_layer = tf.layers.dense( 678 | to_tensor_2d, 679 | num_attention_heads * size_per_head, 680 | activation=key_act, 681 | name="key", 682 | kernel_initializer=create_initializer(initializer_range)) 683 | 684 | # `value_layer` = [B*T, N*H] 685 | value_layer = tf.layers.dense( 686 | to_tensor_2d, 687 | num_attention_heads * size_per_head, 688 | activation=value_act, 689 | name="value", 690 | kernel_initializer=create_initializer(initializer_range)) 691 | 692 | # `query_layer` = [B, N, F, H] 693 | query_layer = transpose_for_scores(query_layer, batch_size, 694 | num_attention_heads, from_seq_length, 695 | size_per_head) 696 | 697 | # `key_layer` = [B, N, T, H] 698 | key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads, 699 | to_seq_length, size_per_head) 700 | 701 | # Take the dot product between "query" and "key" to get the raw 702 | # attention scores. 703 | # `attention_scores` = [B, N, F, T] 704 | attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) 705 | attention_scores = tf.multiply(attention_scores, 706 | 1.0 / math.sqrt(float(size_per_head))) 707 | 708 | if attention_mask is not None: 709 | # `attention_mask` = [B, 1, F, T] 710 | attention_mask = tf.expand_dims(attention_mask, axis=[1]) 711 | 712 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 713 | # masked positions, this operation will create a tensor which is 0.0 for 714 | # positions we want to attend and -10000.0 for masked positions. 715 | adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0 716 | 717 | # Since we are adding it to the raw scores before the softmax, this is 718 | # effectively the same as removing these entirely. 719 | attention_scores += adder 720 | 721 | # Normalize the attention scores to probabilities. 722 | # `attention_probs` = [B, N, F, T] 723 | attention_probs = tf.nn.softmax(attention_scores) 724 | 725 | # This is actually dropping out entire tokens to attend to, which might 726 | # seem a bit unusual, but is taken from the original Transformer paper. 727 | attention_probs = dropout(attention_probs, attention_probs_dropout_prob) 728 | 729 | # `value_layer` = [B, T, N, H] 730 | value_layer = tf.reshape( 731 | value_layer, 732 | [batch_size, to_seq_length, num_attention_heads, size_per_head]) 733 | 734 | # `value_layer` = [B, N, T, H] 735 | value_layer = tf.transpose(value_layer, [0, 2, 1, 3]) 736 | 737 | # `context_layer` = [B, N, F, H] 738 | context_layer = tf.matmul(attention_probs, value_layer) 739 | 740 | # `context_layer` = [B, F, N, H] 741 | context_layer = tf.transpose(context_layer, [0, 2, 1, 3]) 742 | 743 | if do_return_2d_tensor: 744 | # `context_layer` = [B*F, N*H] 745 | context_layer = tf.reshape( 746 | context_layer, 747 | [batch_size * from_seq_length, num_attention_heads * size_per_head]) 748 | else: 749 | # `context_layer` = [B, F, N*H] 750 | context_layer = tf.reshape( 751 | context_layer, 752 | [batch_size, from_seq_length, num_attention_heads * size_per_head]) 753 | 754 | return context_layer 755 | 756 | 757 | def transformer_model(input_tensor, 758 | attention_mask=None, 759 | hidden_size=768, 760 | num_hidden_layers=12, 761 | num_attention_heads=12, 762 | intermediate_size=3072, 763 | intermediate_act_fn=gelu, 764 | hidden_dropout_prob=0.1, 765 | attention_probs_dropout_prob=0.1, 766 | initializer_range=0.02, 767 | do_return_all_layers=False): 768 | """Multi-headed, multi-layer Transformer from "Attention is All You Need". 769 | 770 | This is almost an exact implementation of the original Transformer encoder. 771 | 772 | See the original paper: 773 | https://arxiv.org/abs/1706.03762 774 | 775 | Also see: 776 | https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py 777 | 778 | Args: 779 | input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size]. 780 | attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length, 781 | seq_length], with 1 for positions that can be attended to and 0 in 782 | positions that should not be. 783 | hidden_size: int. Hidden size of the Transformer. 784 | num_hidden_layers: int. Number of layers (blocks) in the Transformer. 785 | num_attention_heads: int. Number of attention heads in the Transformer. 786 | intermediate_size: int. The size of the "intermediate" (a.k.a., feed 787 | forward) layer. 788 | intermediate_act_fn: function. The non-linear activation function to apply 789 | to the output of the intermediate/feed-forward layer. 790 | hidden_dropout_prob: float. Dropout probability for the hidden layers. 791 | attention_probs_dropout_prob: float. Dropout probability of the attention 792 | probabilities. 793 | initializer_range: float. Range of the initializer (stddev of truncated 794 | normal). 795 | do_return_all_layers: Whether to also return all layers or just the final 796 | layer. 797 | 798 | Returns: 799 | float Tensor of shape [batch_size, seq_length, hidden_size], the final 800 | hidden layer of the Transformer. 801 | 802 | Raises: 803 | ValueError: A Tensor shape or parameter is invalid. 804 | """ 805 | if hidden_size % num_attention_heads != 0: 806 | raise ValueError( 807 | "The hidden size (%d) is not a multiple of the number of attention " 808 | "heads (%d)" % (hidden_size, num_attention_heads)) 809 | 810 | attention_head_size = int(hidden_size / num_attention_heads) 811 | input_shape = get_shape_list(input_tensor, expected_rank=3) 812 | batch_size = input_shape[0] 813 | seq_length = input_shape[1] 814 | input_width = input_shape[2] 815 | 816 | # The Transformer performs sum residuals on all layers so the input needs 817 | # to be the same as the hidden size. 818 | if input_width != hidden_size: 819 | raise ValueError("The width of the input tensor (%d) != hidden size (%d)" % 820 | (input_width, hidden_size)) 821 | 822 | # We keep the representation as a 2D tensor to avoid re-shaping it back and 823 | # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on 824 | # the GPU/CPU but may not be free on the TPU, so we want to minimize them to 825 | # help the optimizer. 826 | prev_output = reshape_to_matrix(input_tensor) 827 | 828 | all_layer_outputs = [] 829 | for layer_idx in range(num_hidden_layers): 830 | with tf.variable_scope("layer_%d" % layer_idx): 831 | layer_input = prev_output 832 | 833 | with tf.variable_scope("attention"): 834 | attention_heads = [] 835 | with tf.variable_scope("self"): 836 | attention_head = attention_layer( 837 | from_tensor=layer_input, 838 | to_tensor=layer_input, 839 | attention_mask=attention_mask, 840 | num_attention_heads=num_attention_heads, 841 | size_per_head=attention_head_size, 842 | attention_probs_dropout_prob=attention_probs_dropout_prob, 843 | initializer_range=initializer_range, 844 | do_return_2d_tensor=True, 845 | batch_size=batch_size, 846 | from_seq_length=seq_length, 847 | to_seq_length=seq_length) 848 | attention_heads.append(attention_head) 849 | 850 | attention_output = None 851 | if len(attention_heads) == 1: 852 | attention_output = attention_heads[0] 853 | else: 854 | # In the case where we have other sequences, we just concatenate 855 | # them to the self-attention head before the projection. 856 | attention_output = tf.concat(attention_heads, axis=-1) 857 | 858 | # Run a linear projection of `hidden_size` then add a residual 859 | # with `layer_input`. 860 | with tf.variable_scope("output"): 861 | attention_output = tf.layers.dense( 862 | attention_output, 863 | hidden_size, 864 | kernel_initializer=create_initializer(initializer_range)) 865 | attention_output = dropout(attention_output, hidden_dropout_prob) 866 | attention_output = layer_norm(attention_output + layer_input) 867 | 868 | # The activation is only applied to the "intermediate" hidden layer. 869 | with tf.variable_scope("intermediate"): 870 | intermediate_output = tf.layers.dense( 871 | attention_output, 872 | intermediate_size, 873 | activation=intermediate_act_fn, 874 | kernel_initializer=create_initializer(initializer_range)) 875 | 876 | # Down-project back to `hidden_size` then add the residual. 877 | with tf.variable_scope("output"): 878 | layer_output = tf.layers.dense( 879 | intermediate_output, 880 | hidden_size, 881 | kernel_initializer=create_initializer(initializer_range)) 882 | layer_output = dropout(layer_output, hidden_dropout_prob) 883 | layer_output = layer_norm(layer_output + attention_output) 884 | prev_output = layer_output 885 | all_layer_outputs.append(layer_output) 886 | 887 | if do_return_all_layers: 888 | final_outputs = [] 889 | for layer_output in all_layer_outputs: 890 | final_output = reshape_from_matrix(layer_output, input_shape) 891 | final_outputs.append(final_output) 892 | return final_outputs 893 | else: 894 | final_output = reshape_from_matrix(prev_output, input_shape) 895 | return final_output 896 | 897 | 898 | def get_shape_list(tensor, expected_rank=None, name=None): 899 | """Returns a list of the shape of tensor, preferring static dimensions. 900 | 901 | Args: 902 | tensor: A tf.Tensor object to find the shape of. 903 | expected_rank: (optional) int. The expected rank of `tensor`. If this is 904 | specified and the `tensor` has a different rank, and exception will be 905 | thrown. 906 | name: Optional name of the tensor for the error message. 907 | 908 | Returns: 909 | A list of dimensions of the shape of tensor. All static dimensions will 910 | be returned as python integers, and dynamic dimensions will be returned 911 | as tf.Tensor scalars. 912 | """ 913 | if name is None: 914 | name = tensor.name 915 | 916 | if expected_rank is not None: 917 | assert_rank(tensor, expected_rank, name) 918 | 919 | shape = tensor.shape.as_list() 920 | 921 | non_static_indexes = [] 922 | for (index, dim) in enumerate(shape): 923 | if dim is None: 924 | non_static_indexes.append(index) 925 | 926 | if not non_static_indexes: 927 | return shape 928 | 929 | dyn_shape = tf.shape(tensor) 930 | for index in non_static_indexes: 931 | shape[index] = dyn_shape[index] 932 | return shape 933 | 934 | 935 | def reshape_to_matrix(input_tensor): 936 | """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix).""" 937 | ndims = input_tensor.shape.ndims 938 | if ndims < 2: 939 | raise ValueError("Input tensor must have at least rank 2. Shape = %s" % 940 | (input_tensor.shape)) 941 | if ndims == 2: 942 | return input_tensor 943 | 944 | width = input_tensor.shape[-1] 945 | output_tensor = tf.reshape(input_tensor, [-1, width]) 946 | return output_tensor 947 | 948 | 949 | def reshape_from_matrix(output_tensor, orig_shape_list): 950 | """Reshapes a rank 2 tensor back to its original rank >= 2 tensor.""" 951 | if len(orig_shape_list) == 2: 952 | return output_tensor 953 | 954 | output_shape = get_shape_list(output_tensor) 955 | 956 | orig_dims = orig_shape_list[0:-1] 957 | width = output_shape[-1] 958 | 959 | return tf.reshape(output_tensor, orig_dims + [width]) 960 | 961 | 962 | def assert_rank(tensor, expected_rank, name=None): 963 | """Raises an exception if the tensor rank is not of the expected rank. 964 | 965 | Args: 966 | tensor: A tf.Tensor to check the rank of. 967 | expected_rank: Python integer or list of integers, expected rank. 968 | name: Optional name of the tensor for the error message. 969 | 970 | Raises: 971 | ValueError: If the expected shape doesn't match the actual shape. 972 | """ 973 | if name is None: 974 | name = tensor.name 975 | 976 | expected_rank_dict = {} 977 | if isinstance(expected_rank, six.integer_types): 978 | expected_rank_dict[expected_rank] = True 979 | else: 980 | for x in expected_rank: 981 | expected_rank_dict[x] = True 982 | 983 | actual_rank = tensor.shape.ndims 984 | if actual_rank not in expected_rank_dict: 985 | scope_name = tf.get_variable_scope().name 986 | raise ValueError( 987 | "For the tensor `%s` in scope `%s`, the actual rank " 988 | "`%d` (shape = %s) is not equal to the expected rank `%s`" % 989 | (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank))) 990 | -------------------------------------------------------------------------------- /opcodes.py: -------------------------------------------------------------------------------- 1 | # maps edits to their ids 2 | 3 | import pickle 4 | import utils 5 | 6 | class Opcodes(): 7 | def __init__(self,path_common_inserts, 8 | path_common_multitoken_inserts, 9 | use_transforms=True): 10 | USE_TRANSFORMS=use_transforms #turn this to false if use transforms are not be used 11 | print("path_common_inserts: {}".format(path_common_inserts)) 12 | print("path_common_multitoken_inserts: {}".format(path_common_multitoken_inserts)) 13 | 14 | self.UNK = 0 #dummy 15 | self.SOS = 1 #dummy 16 | self.EOS = 2 #dummy 17 | self.CPY = 3 #copy 18 | self.DEL = 4 #delete 19 | 20 | APPEND_BEGIN = 5 21 | self.APPEND_BEGIN = APPEND_BEGIN 22 | 23 | self.APPEND = {} #appends to their edit ids 24 | self.REP = {} #replaces to their edit ids 25 | APPEND = self.APPEND 26 | REP = self.REP 27 | 28 | common_inserts = pickle.load(open(path_common_inserts,"rb")) 29 | common_multitoken_inserts = pickle.load(open(path_common_multitoken_inserts,"rb")) 30 | 31 | for item in common_inserts: 32 | self.reg_append(item) 33 | for item in common_multitoken_inserts: 34 | self.reg_append(item) 35 | 36 | for item in common_inserts: 37 | self.reg_rep(item) 38 | for item in common_multitoken_inserts: 39 | self.reg_rep(item) 40 | 41 | APPEND_END = APPEND_BEGIN + len(APPEND) - 1 42 | vocab_size = APPEND_BEGIN + len(APPEND) 43 | REP_BEGIN = APPEND_END + 1 44 | vocab_size += len(REP) 45 | REP_END = REP_BEGIN + len(REP) -1 46 | 47 | # TRANSFORM_SUFFIXES 48 | if USE_TRANSFORMS: 49 | self.APPEND_SUFFIX={} 50 | APPEND_SUFFIX = self.APPEND_SUFFIX 51 | # APPEND_SUFFIX is different from APPEND 52 | # APPEND_SUFFIX maps suffix of APPEND-based-suffix-transformation to its edit id 53 | # This is used to map all the APPENDs to corresponding APPEND based suffix transformation 54 | self.APPEND_s = vocab_size 55 | APPEND_SUFFIX["##s"]=self.APPEND_s 56 | vocab_size += 1 57 | self.REMOVE_s = vocab_size 58 | vocab_size += 1 59 | 60 | self.APPEND_d = vocab_size 61 | APPEND_SUFFIX["##d"]=self.APPEND_d 62 | vocab_size += 1 63 | self.REMOVE_d = vocab_size 64 | vocab_size += 1 65 | 66 | self.APPEND_es = vocab_size 67 | APPEND_SUFFIX["##es"]=self.APPEND_es 68 | vocab_size += 1 69 | self.REMOVE_es = vocab_size 70 | vocab_size += 1 71 | 72 | self.APPEND_ing = vocab_size 73 | APPEND_SUFFIX["##ing"]=self.APPEND_ing 74 | vocab_size += 1 75 | self.REMOVE_ing = vocab_size 76 | vocab_size += 1 77 | 78 | self.APPEND_ed = vocab_size 79 | APPEND_SUFFIX["##ed"]=self.APPEND_ed 80 | vocab_size += 1 81 | self.REMOVE_ed = vocab_size 82 | vocab_size += 1 83 | 84 | self.APPEND_ly = vocab_size 85 | APPEND_SUFFIX["##ly"]=self.APPEND_ly 86 | vocab_size += 1 87 | self.REMOVE_ly = vocab_size 88 | vocab_size += 1 89 | 90 | self.APPEND_er = vocab_size 91 | APPEND_SUFFIX["##er"]=self.APPEND_er 92 | vocab_size += 1 93 | self.REMOVE_er = vocab_size 94 | vocab_size += 1 95 | 96 | self.APPEND_al = vocab_size 97 | APPEND_SUFFIX["##al"]=self.APPEND_al 98 | vocab_size += 1 99 | self.REMOVE_al = vocab_size 100 | vocab_size += 1 101 | 102 | self.APPEND_n = vocab_size 103 | APPEND_SUFFIX["##n"]=self.APPEND_n 104 | vocab_size += 1 105 | self.REMOVE_n = vocab_size 106 | vocab_size += 1 107 | 108 | self.APPEND_y = vocab_size 109 | APPEND_SUFFIX["##y"]=self.APPEND_y 110 | vocab_size += 1 111 | self.REMOVE_y = vocab_size 112 | vocab_size += 1 113 | 114 | self.APPEND_ation = vocab_size 115 | APPEND_SUFFIX["##ation"]=self.APPEND_ation 116 | vocab_size += 1 117 | self.REMOVE_ation = vocab_size 118 | vocab_size += 1 119 | 120 | self.E_TO_ING = vocab_size 121 | vocab_size += 1 122 | self.ING_TO_E = vocab_size 123 | vocab_size += 1 124 | 125 | self.D_TO_T = vocab_size 126 | vocab_size += 1 127 | self.T_TO_D = vocab_size 128 | vocab_size += 1 129 | 130 | self.D_TO_S = vocab_size 131 | vocab_size += 1 132 | self.S_TO_D = vocab_size 133 | vocab_size += 1 134 | 135 | self.S_TO_ING = vocab_size 136 | vocab_size += 1 137 | self.ING_TO_S = vocab_size 138 | vocab_size += 1 139 | 140 | self.N_TO_ING = vocab_size 141 | vocab_size += 1 142 | self.ING_TO_N = vocab_size 143 | vocab_size += 1 144 | 145 | self.T_TO_NCE = vocab_size 146 | vocab_size += 1 147 | self.NCE_TO_T = vocab_size 148 | vocab_size += 1 149 | 150 | self.S_TO_ED = vocab_size 151 | vocab_size += 1 152 | self.ED_TO_S = vocab_size 153 | vocab_size += 1 154 | 155 | self.ING_TO_ED = vocab_size 156 | vocab_size += 1 157 | self.ED_TO_ING = vocab_size 158 | vocab_size += 1 159 | 160 | self.ING_TO_ION = vocab_size 161 | vocab_size += 1 162 | self.ION_TO_ING = vocab_size 163 | vocab_size += 1 164 | 165 | self.ING_TO_ATION = vocab_size 166 | vocab_size += 1 167 | self.ATION_TO_ING = vocab_size 168 | vocab_size += 1 169 | 170 | self.T_TO_CE = vocab_size 171 | vocab_size += 1 172 | self.CE_TO_T = vocab_size 173 | vocab_size += 1 174 | 175 | self.Y_TO_IC = vocab_size 176 | vocab_size += 1 177 | self.IC_TO_Y = vocab_size 178 | vocab_size += 1 179 | 180 | self.T_TO_S = vocab_size 181 | vocab_size += 1 182 | self.S_TO_T = vocab_size 183 | vocab_size += 1 184 | 185 | self.E_TO_AL = vocab_size 186 | vocab_size += 1 187 | self.AL_TO_E = vocab_size 188 | vocab_size += 1 189 | 190 | self.Y_TO_ILY = vocab_size 191 | vocab_size += 1 192 | self.ILY_TO_Y = vocab_size 193 | vocab_size += 1 194 | 195 | self.Y_TO_IED = vocab_size 196 | vocab_size += 1 197 | self.IED_TO_Y = vocab_size 198 | vocab_size += 1 199 | 200 | self.Y_TO_ICAL = vocab_size 201 | vocab_size += 1 202 | self.ICAL_TO_Y = vocab_size 203 | vocab_size += 1 204 | 205 | self.Y_TO_IES = vocab_size 206 | vocab_size += 1 207 | self.IES_TO_Y = vocab_size 208 | vocab_size += 1 209 | 210 | 211 | def reg_append(self,word): 212 | #registers an APPEND 213 | if word not in self.APPEND: 214 | self.APPEND[word] = len(self.APPEND) + self.APPEND_BEGIN 215 | else: 216 | print("Skipping duplicate opcode", word) 217 | 218 | def reg_rep(self,word): 219 | #registers a REPLACE 220 | if word not in self.REP: 221 | self.REP[word] = len(self.REP) + len(self.APPEND) + self.APPEND_BEGIN 222 | else: 223 | print("Skipping duplicate opcode", word) 224 | -------------------------------------------------------------------------------- /optimization.py: -------------------------------------------------------------------------------- 1 | # code adapted from https://github.com/google-research/bert 2 | 3 | """Functions and classes related to optimization (weight updates).""" 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import re 10 | import tensorflow as tf 11 | 12 | 13 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu, tvars=None): 14 | """Creates an optimizer training op.""" 15 | global_step = tf.train.get_or_create_global_step() 16 | 17 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 18 | 19 | # Implements linear decay of the learning rate. 20 | learning_rate = tf.train.polynomial_decay( 21 | learning_rate, 22 | global_step, 23 | num_train_steps, 24 | end_learning_rate=0.0, 25 | power=1.0, 26 | cycle=False) 27 | 28 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 29 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 30 | if num_warmup_steps: 31 | global_steps_int = tf.cast(global_step, tf.int32) 32 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 33 | 34 | global_steps_float = tf.cast(global_steps_int, tf.float32) 35 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 36 | 37 | warmup_percent_done = global_steps_float / warmup_steps_float 38 | warmup_learning_rate = init_lr * warmup_percent_done 39 | 40 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 41 | learning_rate = ( 42 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 43 | 44 | # It is recommended that you use this optimizer for fine tuning, since this 45 | # is how the model was trained (note that the Adam m/v variables are NOT 46 | # loaded from init_checkpoint.) 47 | optimizer = AdamWeightDecayOptimizer( 48 | learning_rate=learning_rate, 49 | weight_decay_rate=0.01, 50 | beta_1=0.9, 51 | beta_2=0.999, 52 | epsilon=1e-6, 53 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 54 | 55 | if use_tpu: 56 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 57 | 58 | if not tvars: 59 | tvars = tf.trainable_variables() 60 | grads = tf.gradients(loss, tvars) 61 | 62 | # This is how the model was pre-trained. 63 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 64 | 65 | train_op = optimizer.apply_gradients( 66 | zip(grads, tvars), global_step=global_step) 67 | 68 | # Normally the global step update is done inside of `apply_gradients`. 69 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use 70 | # a different optimizer, you should probably take this line out. 71 | new_global_step = global_step + 1 72 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 73 | return train_op 74 | 75 | 76 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 77 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 78 | 79 | def __init__(self, 80 | learning_rate, 81 | weight_decay_rate=0.0, 82 | beta_1=0.9, 83 | beta_2=0.999, 84 | epsilon=1e-6, 85 | exclude_from_weight_decay=None, 86 | name="AdamWeightDecayOptimizer"): 87 | """Constructs a AdamWeightDecayOptimizer.""" 88 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 89 | 90 | self.learning_rate = learning_rate 91 | self.weight_decay_rate = weight_decay_rate 92 | self.beta_1 = beta_1 93 | self.beta_2 = beta_2 94 | self.epsilon = epsilon 95 | self.exclude_from_weight_decay = exclude_from_weight_decay 96 | 97 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 98 | """See base class.""" 99 | assignments = [] 100 | for (grad, param) in grads_and_vars: 101 | if grad is None or param is None: 102 | continue 103 | 104 | param_name = self._get_variable_name(param.name) 105 | 106 | m = tf.get_variable( 107 | name=param_name + "/adam_m", 108 | shape=param.shape.as_list(), 109 | dtype=tf.float32, 110 | trainable=False, 111 | initializer=tf.zeros_initializer()) 112 | v = tf.get_variable( 113 | name=param_name + "/adam_v", 114 | shape=param.shape.as_list(), 115 | dtype=tf.float32, 116 | trainable=False, 117 | initializer=tf.zeros_initializer()) 118 | 119 | # Standard Adam update. 120 | next_m = ( 121 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 122 | next_v = ( 123 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 124 | tf.square(grad))) 125 | 126 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 127 | 128 | # Just adding the square of the weights to the loss function is *not* 129 | # the correct way of using L2 regularization/weight decay with Adam, 130 | # since that will interact with the m and v parameters in strange ways. 131 | # 132 | # Instead we want ot decay the weights in a manner that doesn't interact 133 | # with the m/v parameters. This is equivalent to adding the square 134 | # of the weights to the loss with plain (non-momentum) SGD. 135 | if self._do_use_weight_decay(param_name): 136 | update += self.weight_decay_rate * param 137 | 138 | update_with_lr = self.learning_rate * update 139 | 140 | next_param = param - update_with_lr 141 | 142 | assignments.extend( 143 | [param.assign(next_param), 144 | m.assign(next_m), 145 | v.assign(next_v)]) 146 | return tf.group(*assignments, name=name) 147 | 148 | def _do_use_weight_decay(self, param_name): 149 | """Whether to use L2 weight decay for `param_name`.""" 150 | if not self.weight_decay_rate: 151 | return False 152 | if self.exclude_from_weight_decay: 153 | for r in self.exclude_from_weight_decay: 154 | if re.search(r, param_name) is not None: 155 | return False 156 | return True 157 | 158 | def _get_variable_name(self, param_name): 159 | """Get the variable name from the tensor name.""" 160 | m = re.match("^(.*):\\d+$", param_name) 161 | if m is not None: 162 | param_name = m.group(1) 163 | return param_name 164 | -------------------------------------------------------------------------------- /pickles/bea/common_deletes.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awasthiabhijeet/PIE/474769e3c4266deefcb7dd5daf802a1306bc7c99/pickles/bea/common_deletes.p -------------------------------------------------------------------------------- /pickles/bea/common_inserts.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awasthiabhijeet/PIE/474769e3c4266deefcb7dd5daf802a1306bc7c99/pickles/bea/common_inserts.p -------------------------------------------------------------------------------- /pickles/bea/common_multitoken_inserts.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awasthiabhijeet/PIE/474769e3c4266deefcb7dd5daf802a1306bc7c99/pickles/bea/common_multitoken_inserts.p -------------------------------------------------------------------------------- /pickles/conll/common_deletes.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awasthiabhijeet/PIE/474769e3c4266deefcb7dd5daf802a1306bc7c99/pickles/conll/common_deletes.p -------------------------------------------------------------------------------- /pickles/conll/common_inserts.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awasthiabhijeet/PIE/474769e3c4266deefcb7dd5daf802a1306bc7c99/pickles/conll/common_inserts.p -------------------------------------------------------------------------------- /pickles/conll/common_multitoken_inserts.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awasthiabhijeet/PIE/474769e3c4266deefcb7dd5daf802a1306bc7c99/pickles/conll/common_multitoken_inserts.p -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | joblib==0.11 2 | tensorflow==1.12 3 | tqdm==4.23.4 4 | autocorrect==0.3.0 5 | filelock==3.0.12 6 | spacy_kenlm 7 | oauth2client==4.1.2 8 | google-api-python-client==1.7.3 9 | google-auth==1.5.0 10 | google-auth-httplib2==0.0.3 11 | -------------------------------------------------------------------------------- /seq2edits_utils.py: -------------------------------------------------------------------------------- 1 | # Used for creating Seq2Edits function (Describe in Appendix A.1 of the paper) 2 | # Utilized by get_seq2edits.py 3 | # Provide diffs for a source and tagert sentence, 4 | # by 5 | #1. breaking replace operations to deletes followed by inserts 6 | #2. merging consecutive insert operations to a single insert operation 7 | # Uses edit-distance algoithm 8 | # With modified penalties for replace operations (Line 251, 312) 9 | 10 | # adapted by Abhijeet Awasthi 11 | # from https://github.com/belambert/edit-distance/blob/master/edit_distance/code.py 12 | 13 | # -*- mode: Python;-*- 14 | 15 | # Copyright 2013-2018 Ben Lambert 16 | 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 28 | 29 | """ 30 | Code for computing edit distances. 31 | """ 32 | 33 | import sys 34 | import operator 35 | 36 | INSERT = 'insert' 37 | DELETE = 'delete' 38 | EQUAL = 'equal' 39 | REPLACE = 'replace' 40 | 41 | # Cost is basically: was there a match or not. 42 | # The other numbers are cumulative costs and matches. 43 | 44 | def lowest_cost_action(ic, dc, sc, im, dm, sm, cost): 45 | """Given the following values, choose the action (insertion, deletion, 46 | or substitution), that results in the lowest cost (ties are broken using 47 | the 'match' score). This is used within the dynamic programming algorithm. 48 | 49 | * ic - insertion cost 50 | 51 | * dc - deletion cost 52 | 53 | * sc - substitution cost 54 | 55 | * im - insertion match (score) 56 | 57 | * dm - deletion match (score) 58 | 59 | * sm - substitution match (score) 60 | """ 61 | best_action = None 62 | best_match_count = -1 63 | min_cost = min(ic, dc, sc) 64 | if min_cost == sc and cost == 0: 65 | best_action = EQUAL 66 | best_match_count = sm 67 | elif min_cost == sc and cost > 0: 68 | best_action = REPLACE 69 | best_match_count = sm 70 | elif min_cost == ic and im > best_match_count: 71 | best_action = INSERT 72 | best_match_count = im 73 | elif min_cost == dc and dm > best_match_count: 74 | best_action = DELETE 75 | best_match_count = dm 76 | return best_action 77 | 78 | def highest_match_action(ic, dc, sc, im, dm, sm, cost): 79 | """Given the following values, choose the action (insertion, deletion, or 80 | substitution), that results in the highest match score (ties are broken 81 | using the distance values). This is used within the dynamic programming 82 | algorithm. 83 | 84 | * ic - insertion cost 85 | 86 | * dc - deletion cost 87 | 88 | * sc - substitution cost 89 | 90 | * im - insertion match (score) 91 | 92 | * dm - deletion match (score) 93 | 94 | * sm - substitution match (score) 95 | """ 96 | # pylint: disable=unused-argument 97 | best_action = None 98 | lowest_cost = float("inf") 99 | max_match = max(im, dm, sm) 100 | if max_match == sm and cost == 0: 101 | best_action = EQUAL 102 | lowest_cost = sm 103 | elif max_match == sm and cost > 0: 104 | best_action = REPLACE 105 | lowest_cost = sm 106 | elif max_match == im and ic < lowest_cost: 107 | best_action = INSERT 108 | lowest_cost = ic 109 | elif max_match == dm and dc < lowest_cost: 110 | best_action = DELETE 111 | lowest_cost = dc 112 | return best_action 113 | 114 | 115 | class SequenceMatcher(object): 116 | """Similar to the :py:mod:`difflib` :py:class:`~difflib.SequenceMatcher`, but uses Levenshtein/edit 117 | distance. 118 | """ 119 | 120 | def __init__(self, a=None, b=None, test=operator.eq, 121 | action_function=lowest_cost_action): 122 | """Initialize the object with sequences a and b. Optionally, one can 123 | specify a test function that is used to compare sequence elements. 124 | This defaults to the built in ``eq`` operator (i.e. :py:func:`operator.eq`). 125 | """ 126 | if a is None: 127 | a = [] 128 | if b is None: 129 | b = [] 130 | self.seq1 = a 131 | self.seq2 = b 132 | self._reset_object() 133 | self.action_function = action_function 134 | self.test = test 135 | self.dist = None 136 | self._matches = None 137 | self.opcodes = None 138 | 139 | def set_seqs(self, a, b): 140 | """Specify two alternative sequences -- reset any cached values.""" 141 | self.set_seq1(a) 142 | self.set_seq2(b) 143 | self._reset_object() 144 | 145 | def _reset_object(self): 146 | """Clear out the cached values for distance, matches, and opcodes.""" 147 | self.opcodes = None 148 | self.dist = None 149 | self._matches = None 150 | 151 | def set_seq1(self, a): 152 | """Specify a new sequence for sequence 1, resetting cached values.""" 153 | self._reset_object() 154 | self.seq1 = a 155 | 156 | def set_seq2(self, b): 157 | """Specify a new sequence for sequence 2, resetting cached values.""" 158 | self._reset_object() 159 | self.seq2 = b 160 | 161 | def find_longest_match(self, alo, ahi, blo, bhi): 162 | """Not implemented!""" 163 | raise NotImplementedError() 164 | 165 | def get_matching_blocks(self): 166 | """Similar to :py:meth:`get_opcodes`, but returns only the opcodes that are 167 | equal and returns them in a somewhat different format 168 | (i.e. ``(i, j, n)`` ).""" 169 | opcodes = self.get_opcodes() 170 | match_opcodes = filter(lambda x: x[0] == EQUAL, opcodes) 171 | return map(lambda opcode: [opcode[1], opcode[3], opcode[2] - opcode[1]], 172 | match_opcodes) 173 | 174 | def get_opcodes(self): 175 | """Returns a list of opcodes. Opcodes are the same as defined by 176 | :py:mod:`difflib`.""" 177 | if not self.opcodes: 178 | d, m, opcodes = edit_distance_backpointer(self.seq1, self.seq2, 179 | action_function=self.action_function, 180 | test=self.test) 181 | if self.dist: 182 | assert d == self.dist 183 | if self._matches: 184 | assert m == self._matches 185 | self.dist = d 186 | self._matches = m 187 | self.opcodes = opcodes 188 | return self.opcodes 189 | 190 | def get_grouped_opcodes(self, n=None): 191 | """Not implemented!""" 192 | raise NotImplementedError() 193 | 194 | def ratio(self): 195 | """Ratio of matches to the average sequence length.""" 196 | return 2.0 * self.matches() / (len(self.seq1) + len(self.seq2)) 197 | 198 | def quick_ratio(self): 199 | """Same as :py:meth:`ratio`.""" 200 | return self.ratio() 201 | 202 | def real_quick_ratio(self): 203 | """Same as :py:meth:`ratio`.""" 204 | return self.ratio() 205 | 206 | def _compute_distance_fast(self): 207 | """Calls edit_distance, and asserts that if we already have values for 208 | matches and distance, that they match.""" 209 | d, m = edit_distance(self.seq1, self.seq2, 210 | action_function=self.action_function, 211 | test=self.test) 212 | if self.dist: 213 | assert d == self.dist 214 | if self._matches: 215 | assert m == self._matches 216 | self.dist = d 217 | self._matches = m 218 | 219 | def distance(self): 220 | """Returns the edit distance of the two loaded sequences. This should 221 | be a little faster than getting the same information from 222 | :py:meth:`get_opcodes`.""" 223 | if not self.dist: 224 | self._compute_distance_fast() 225 | return self.dist 226 | 227 | def matches(self): 228 | """Returns the number of matches in the alignment of the two sequences. 229 | This should be a little faster than getting the same information from 230 | :py:meth:`get_opcodes`.""" 231 | if not self._matches: 232 | self._compute_distance_fast() 233 | return self._matches 234 | 235 | 236 | def edit_distance(seq1, seq2, action_function=lowest_cost_action, test=operator.eq): 237 | """Computes the edit distance between the two given sequences. 238 | This uses the relatively fast method that only constructs 239 | two columns of the 2d array for edits. This function actually uses four columns 240 | because we track the number of matches too. 241 | """ 242 | m = len(seq1) 243 | n = len(seq2) 244 | # Special, easy cases: 245 | if seq1 == seq2: 246 | return 0, n 247 | if m == 0: 248 | return n, 0 249 | if n == 0: 250 | return m, 0 251 | v0 = [0] * (n + 1) # The two 'error' columns 252 | v1 = [0] * (n + 1) 253 | m0 = [0] * (n + 1) # The two 'match' columns 254 | m1 = [0] * (n + 1) 255 | for i in range(1, n + 1): 256 | v0[i] = i 257 | for i in range(1, m + 1): 258 | v1[0] = i 259 | for j in range(1, n + 1): 260 | cost = 0 if test(seq1[i - 1], seq2[j - 1]) else (1 + abs(len(seq1[i-1])-len(seq2[j-1]))/1000) 261 | # The costs 262 | ins_cost = v1[j - 1] + 1 263 | del_cost = v0[j] + 1 264 | sub_cost = v0[j - 1] + cost 265 | # Match counts 266 | ins_match = m1[j - 1] 267 | del_match = m0[j] 268 | sub_match = m0[j - 1] + (1-cost) 269 | 270 | action = action_function(ins_cost, del_cost, sub_cost, ins_match, 271 | del_match, sub_match, cost) 272 | 273 | if action in [EQUAL, REPLACE]: 274 | v1[j] = sub_cost 275 | m1[j] = sub_match 276 | elif action == INSERT: 277 | v1[j] = ins_cost 278 | m1[j] = ins_match 279 | elif action == DELETE: 280 | v1[j] = del_cost 281 | m1[j] = del_match 282 | else: 283 | raise Exception('Invalid dynamic programming option returned!') 284 | # Copy the columns over 285 | for i in range(0, n + 1): 286 | v0[i] = v1[i] 287 | m0[i] = m1[i] 288 | return v1[n], m1[n] 289 | 290 | 291 | def edit_distance_backpointer(seq1, seq2, action_function=lowest_cost_action, test=operator.eq): 292 | """Similar to :py:func:`~edit_distance.edit_distance` except that this function keeps backpointers 293 | during the search. This allows us to return the opcodes (i.e. the specific 294 | edits that were used to change from one string to another). This function 295 | contructs the full 2d array (actually it contructs three of them: one 296 | for distances, one for matches, and one for backpointers).""" 297 | matches = 0 298 | # Create a 2d distance array 299 | m = len(seq1) 300 | n = len(seq2) 301 | # distances array: 302 | d = [[0 for x in range(n + 1)] for y in range(m + 1)] 303 | # backpointer array: 304 | bp = [[None for x in range(n + 1)] for y in range(m + 1)] 305 | # matches array: 306 | matches = [[0 for x in range(n + 1)] for y in range(m + 1)] 307 | # source prefixes can be transformed into empty string by 308 | # dropping all characters 309 | for i in range(1, m + 1): 310 | d[i][0] = i 311 | bp[i][0] = [DELETE, i - 1, i, 0, 0] 312 | # target prefixes can be reached from empty source prefix by inserting 313 | # every characters 314 | for j in range(1, n + 1): 315 | d[0][j] = j 316 | bp[0][j] = [INSERT, 0, 0, j - 1, j] 317 | # compute the edit distance... 318 | for i in range(1, m + 1): 319 | for j in range(1, n + 1): 320 | 321 | cost = 0 if test(seq1[i - 1], seq2[j - 1]) else (1 + abs(len(seq1[i-1])-len(seq2[j-1]))/1000) 322 | # The costs of each action... 323 | ins_cost = d[i][j - 1] + 1 # insertion 324 | del_cost = d[i - 1][j] + 1 # deletion 325 | sub_cost = d[i - 1][j - 1] + cost # substitution/match 326 | 327 | # The match scores of each action 328 | ins_match = matches[i][j - 1] 329 | del_match = matches[i - 1][j] 330 | sub_match = matches[i - 1][j - 1] + (1-cost) 331 | 332 | action = action_function(ins_cost, del_cost, sub_cost, ins_match, 333 | del_match, sub_match, cost) 334 | if action == EQUAL: 335 | d[i][j] = sub_cost 336 | matches[i][j] = sub_match 337 | bp[i][j] = [EQUAL, i - 1, i, j - 1, j] 338 | elif action == REPLACE: 339 | d[i][j] = sub_cost 340 | matches[i][j] = sub_match 341 | bp[i][j] = [REPLACE, i - 1, i, j - 1, j] 342 | elif action == INSERT: 343 | d[i][j] = ins_cost 344 | matches[i][j] = ins_match 345 | bp[i][j] = [INSERT, i - 1, i - 1, j - 1, j] 346 | elif action == DELETE: 347 | d[i][j] = del_cost 348 | matches[i][j] = del_match 349 | bp[i][j] = [DELETE, i - 1, i, j - 1, j - 1] 350 | else: 351 | raise Exception('Invalid dynamic programming action returned!') 352 | 353 | opcodes = get_opcodes_from_bp_table(bp) 354 | return d[m][n], matches[m][n], opcodes 355 | 356 | 357 | def get_opcodes_from_bp_table(bp): 358 | """Given a 2d list structure, collect the opcodes from the best path.""" 359 | x = len(bp) - 1 360 | y = len(bp[0]) - 1 361 | opcodes = [] 362 | while x != 0 or y != 0: 363 | this_bp = bp[x][y] 364 | opcodes.append(this_bp) 365 | if this_bp[0] == EQUAL or this_bp[0] == REPLACE: 366 | x = x - 1 367 | y = y - 1 368 | elif this_bp[0] == INSERT: 369 | y = y - 1 370 | elif this_bp[0] == DELETE: 371 | x = x - 1 372 | opcodes.reverse() 373 | return opcodes 374 | 375 | def ndiff(source, target, merge_insertions=True): 376 | sm = SequenceMatcher(source, target) 377 | opcodes = sm.get_opcodes() 378 | diff = [] 379 | src_id = 0 380 | tgt_id = 0 381 | for item in opcodes: 382 | if item[0]=='equal': 383 | diff.append(" {}".format(source[src_id])) 384 | elif item[0] == "insert": 385 | diff.append("+ {}".format(target[item[3]])) 386 | tgt_id +=1 387 | src_id -=1 388 | elif item[0] == "replace": #BREAK a substitution to delete (-) followed by insert (+) 389 | diff.append("- {}".format(source[src_id])) 390 | diff.append("+ {}".format(target[item[3]])) 391 | elif item[0] == "delete": 392 | diff.append("- {}".format(source[src_id])) 393 | 394 | src_id +=1 395 | tgt_id +=1 396 | 397 | if merge_insertions: #merge insertions 398 | tmp = [] 399 | for item in diff: 400 | if item[0]!="+" or len(tmp)==0 or tmp[-1][0]!="+": 401 | tmp.append(item) 402 | else: 403 | assert item[0]=="+" 404 | assert tmp[-1][0]=="+" 405 | tmp[-1] = tmp[-1] + " " + item[2:] 406 | 407 | diff = tmp 408 | return diff 409 | 410 | if __name__ == "__main__": 411 | x="I like him , also he like me ." 412 | y="I like him . also , he like ." 413 | 414 | print(ndiff(x.split(),y.split())) 415 | 416 | 417 | x="I like him , also he like me ." 418 | y="I like him . Also , he like me ." 419 | 420 | print("\n\n\n") 421 | print(ndiff(x.split(),y.split())) 422 | -------------------------------------------------------------------------------- /spellcheck_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | reg_ex = re.compile(r"^[a-z][a-z]*[a-z]$") 4 | no_reg_ex = re.compile(r".*[0-9].*") 5 | mc_reg_ex = re.compile(r".*[A-Z].*[A-Z].*") 6 | 7 | def containsNumber(text): 8 | return no_reg_ex.match(text) 9 | 10 | def containsMultiCapital(text): 11 | return mc_reg_ex.match(text) 12 | 13 | def can_spellcheck(w: str): 14 | #return not ((not reg_ex.match(w)) or containsMultiCapital(w) or containsNumber 15 | if reg_ex.match(w): 16 | return True 17 | -------------------------------------------------------------------------------- /tokenization.py: -------------------------------------------------------------------------------- 1 | # code adapted from https://github.com/google-research/bert 2 | 3 | # modification of tokenization.py for GEC 4 | 5 | """Tokenization classes.""" 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import collections 12 | import unicodedata 13 | import six 14 | import tensorflow as tf 15 | 16 | from autocorrect import spell 17 | from spellcheck_utils import can_spellcheck 18 | import re 19 | 20 | 21 | special_tokens = {"n't":"not", "'m":"am", "ca":"can", "Ca":"Can", "wo":"would", "Wo":"Would", 22 | "'ll":"will", "'ve":"have"} 23 | def containsNumber(text): 24 | reg_ex = re.compile(r".*[0-9].*") 25 | if reg_ex.match(text): 26 | #print("{} contains numbers".format(text)) 27 | return True 28 | else: 29 | return False 30 | 31 | def containsMultiCapital(text): 32 | reg_ex=re.compile(r".*[A-Z].*[A-Z].*") 33 | if reg_ex.match(text): 34 | #print("{} conatains multiple capitals".format(text)) 35 | return True 36 | else: 37 | return False 38 | 39 | 40 | def checkAlternateDots(text): 41 | if text[0]==".": 42 | return False 43 | alt = text[1::2] 44 | if set(alt) == {'.'}: 45 | #print("{} contains alternate dots".format(text)) 46 | return True 47 | else: 48 | return False 49 | 50 | def end_with_dotcom(text): 51 | if len(text)>=4 and text[-4:]==".com": 52 | #print("{} contains .com in the end".format(text)) 53 | return True 54 | else: 55 | return False 56 | 57 | def starts_with_www(text): 58 | reg_ex = re.compile(r"^www\..*") 59 | if reg_ex.match(text): 60 | #print("{} starts with www.".format(text)) 61 | return True 62 | else: 63 | return False 64 | 65 | def contains_slash(text): 66 | if "/" in text: 67 | #print("{} contains /".format(text)) 68 | return True 69 | else: 70 | return False 71 | 72 | def contains_percent(text): 73 | if "%" in text: 74 | #print("{} contains %".format(text)) 75 | return True 76 | else: 77 | return False 78 | 79 | def contains_ampersand(text): 80 | if "&" in text: 81 | #print("{} contains &".format(text)) 82 | return True 83 | else: 84 | return False 85 | 86 | def contains_at_rate(text): 87 | if "@" in text: 88 | #print("{} contains @".format(text)) 89 | return True 90 | else: 91 | return False 92 | 93 | def contains_square_brackets(text): 94 | if "[" in text or "]" in text: 95 | #print("{} contains ] or [".format(text)) 96 | return True 97 | else: 98 | return False 99 | 100 | def last_dot_first_capital(text): 101 | if len(text) > 1 and text[-1]=="." and text[0].upper()==text[0]: 102 | #print("{} has dot as last letter and it's first letter is capital".format(text)) 103 | return True 104 | else: 105 | return False 106 | 107 | def check_smilies(text): 108 | if text in [":)",":(",";)",":/",":|"]: 109 | #print("{} is a smiley".format(text)) 110 | return True 111 | else: 112 | return False 113 | 114 | def do_not_split(text, mode="test"): 115 | 116 | if mode == "train": 117 | #print("************************* SPLIT IS ON *************************************") 118 | return False 119 | 120 | if containsNumber(text) or containsMultiCapital(text) or checkAlternateDots(text) \ 121 | or end_with_dotcom(text) or starts_with_www(text) or contains_at_rate(text) \ 122 | or contains_slash(text) or contains_percent(text) or contains_ampersand(text) \ 123 | or contains_square_brackets(text) \ 124 | or last_dot_first_capital(text) \ 125 | or check_smilies(text): 126 | return True 127 | else: 128 | return False 129 | 130 | ''' 131 | def contains_round(text): 132 | if ")" in text or "(" in text: 133 | print("contains_right_round firing on {}".format(text)) 134 | return True 135 | else: 136 | return False 137 | ''' 138 | 139 | def spell_check(text): 140 | if not can_spellcheck(text): 141 | return None 142 | result = spell(text) 143 | return result 144 | ''' 145 | if (text[0].isupper() == result[0].isupper()): #avoid case change due to spelling correction 146 | return result 147 | else: 148 | return None 149 | ''' 150 | 151 | def check_alternate_in_vocab(word,vocab): 152 | assert word not in vocab 153 | 154 | if word == word.lower(): 155 | tmp = word[0].upper() + word[1:] 156 | else: 157 | tmp = word.lower() 158 | 159 | if tmp in vocab: 160 | #print("replacing {} with its alternate {}".format(word, tmp)) 161 | return tmp 162 | else: 163 | return None 164 | 165 | 166 | def convert_to_unicode(text): 167 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 168 | if six.PY3: 169 | if isinstance(text, str): 170 | return text 171 | elif isinstance(text, bytes): 172 | return text.decode("utf-8", "ignore") 173 | else: 174 | raise ValueError("Unsupported string type: %s" % (type(text))) 175 | elif six.PY2: 176 | if isinstance(text, str): 177 | return text.decode("utf-8", "ignore") 178 | elif isinstance(text, unicode): 179 | return text 180 | else: 181 | raise ValueError("Unsupported string type: %s" % (type(text))) 182 | else: 183 | raise ValueError("Not running on Python2 or Python 3?") 184 | 185 | 186 | def printable_text(text): 187 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 188 | 189 | # These functions want `str` for both Python2 and Python3, but in one case 190 | # it's a Unicode string and in the other it's a byte string. 191 | if six.PY3: 192 | if isinstance(text, str): 193 | return text 194 | elif isinstance(text, bytes): 195 | return text.decode("utf-8", "ignore") 196 | else: 197 | raise ValueError("Unsupported string type: %s" % (type(text))) 198 | elif six.PY2: 199 | if isinstance(text, str): 200 | return text 201 | elif isinstance(text, unicode): 202 | return text.encode("utf-8") 203 | else: 204 | raise ValueError("Unsupported string type: %s" % (type(text))) 205 | else: 206 | raise ValueError("Not running on Python2 or Python 3?") 207 | 208 | 209 | def load_vocab(vocab_file): 210 | """Loads a vocabulary file into a dictionary.""" 211 | vocab = collections.OrderedDict() 212 | index = 0 213 | with tf.gfile.GFile(vocab_file, "r") as reader: 214 | while True: 215 | token = convert_to_unicode(reader.readline()) 216 | if not token: 217 | break 218 | token = token.strip() 219 | vocab[token] = index 220 | index += 1 221 | return vocab 222 | 223 | 224 | def convert_by_vocab(vocab, items): 225 | """Converts a sequence of [tokens|ids] using the vocab.""" 226 | output = [] 227 | for item in items: 228 | output.append(vocab[item]) 229 | return output 230 | 231 | 232 | def convert_tokens_to_ids(vocab, tokens): 233 | return convert_by_vocab(vocab, tokens) 234 | 235 | 236 | def convert_ids_to_tokens(inv_vocab, ids): 237 | return convert_by_vocab(inv_vocab, ids) 238 | 239 | 240 | def whitespace_tokenize(text): 241 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 242 | text = text.strip() 243 | if not text: 244 | return [] 245 | tokens = text.split() 246 | return tokens 247 | 248 | 249 | class FullTokenizer(object): 250 | """Runs end-to-end tokenziation.""" 251 | 252 | def __init__(self, vocab_file, do_lower_case=True): 253 | self.vocab = load_vocab(vocab_file) 254 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 255 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, vocab=self.vocab) 256 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 257 | 258 | def tokenize(self, text, mode="test"): 259 | split_tokens = [] 260 | for token in self.basic_tokenizer.tokenize(text,mode): 261 | #print("Hello") 262 | if (len(token) > 1 and do_not_split(token,mode)) or (token in special_tokens): 263 | split_tokens.append(token) 264 | else: 265 | wordpiece_tokens = self.wordpiece_tokenizer.tokenize(token) 266 | if len(wordpiece_tokens) > 1: 267 | if token.capitalize() in self.vocab: 268 | split_tokens.append(token.capitalize()) 269 | elif token.lower() in self.vocab: 270 | split_tokens.append(token.lower()) 271 | elif token.upper() in self.vocab: 272 | split_tokens.append(token.upper()) 273 | elif len(wordpiece_tokens) <=3: 274 | split_tokens.extend(wordpiece_tokens) 275 | else: 276 | split_tokens.append(token) 277 | else: 278 | split_tokens.append(token) 279 | return split_tokens 280 | 281 | def convert_tokens_to_ids(self,items): 282 | output = [] 283 | for item in items: 284 | if item in special_tokens: 285 | output.append(self.vocab[special_tokens[item]]) 286 | elif item in self.vocab: 287 | output.append(self.vocab[item]) 288 | else: 289 | if item.capitalize() in self.vocab: 290 | output.append(self.vocab[item.capitalize()]) 291 | elif item.lower() in self.vocab: 292 | output.append(self.vocab[item.lower()]) 293 | elif item.upper() in self.vocab: 294 | output.append(self.vocab[item.upper()]) 295 | else: 296 | output.append(self.vocab["[UNK]"]) 297 | return output 298 | 299 | #def convert_tokens_to_ids(self, tokens): 300 | # return convert_by_vocab(self.vocab, tokens) 301 | 302 | def convert_ids_to_tokens(self, ids): 303 | return convert_by_vocab(self.inv_vocab, ids) 304 | 305 | 306 | class BasicTokenizer(object): 307 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 308 | 309 | def __init__(self, do_lower_case=True, vocab=None): 310 | """Constructs a BasicTokenizer. 311 | 312 | Args: 313 | do_lower_case: Whether to lower case the input. 314 | """ 315 | self.do_lower_case = do_lower_case 316 | self.vocab = vocab 317 | 318 | def tokenize(self, text, mode="test"): 319 | """Tokenizes a piece of text.""" 320 | text = convert_to_unicode(text) 321 | text = self._clean_text(text) 322 | 323 | # This was added on November 1st, 2018 for the multilingual and Chinese 324 | # models. This is also applied to the English models now, but it doesn't 325 | # matter since the English models were not trained on any Chinese data 326 | # and generally don't have any Chinese data in them (there are Chinese 327 | # characters in the vocabulary because Wikipedia does have some Chinese 328 | # words in the English Wikipedia.). 329 | text = self._tokenize_chinese_chars(text) 330 | 331 | orig_tokens = whitespace_tokenize(text) 332 | split_tokens = [] 333 | for token in orig_tokens: 334 | if self.do_lower_case: 335 | token = token.lower() 336 | token = self._run_strip_accents(token) 337 | 338 | if len(token)==1 or do_not_split(token,mode) or (token in special_tokens): 339 | split_tokens.append(token) 340 | else: 341 | split_tokens.extend(self._run_split_on_punc(token)) 342 | 343 | use_spell_check=False 344 | if use_spell_check: 345 | split_tokens = self._run_spell_check(split_tokens) 346 | 347 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 348 | return output_tokens 349 | 350 | def _run_spell_check(self, tokens): 351 | corrected_tokens = [] 352 | for word in tokens: 353 | output_word = None 354 | if (word in self.vocab) or (word.lower() in self.vocab) or (word.capitalize() in self.vocab) or (word.upper() in self.vocab) or do_not_split(word,"test"): 355 | output_word = word 356 | else: 357 | spell_checked_word = spell_check(word) 358 | if spell_checked_word: 359 | if (spell_checked_word in self.vocab): 360 | #print("spell check FINDS word in VOCAB --- {} --> {}".format(word, spell(word))) 361 | output_word=spell_checked_word 362 | else: 363 | if word[0].isupper(): 364 | # "this case should never be encountered because spell_checked_word is None for cased words 365 | print("Error this should not be encountered") 366 | exit(1) 367 | else: 368 | output_word=spell_checked_word 369 | #print("Spell check DID NOT FIND WORD in VOCAB --- {} --> {}".format(word, spell(word))) 370 | #corrected_tokens.append(spell_checked_word) 371 | #print("{} not in vocab and COULD NOT BE SPELL CHECKED".format(word)) 372 | #corrected_tokens.append(word) 373 | else: 374 | output_word=word 375 | 376 | assert output_word!=None 377 | #if output_word != word: 378 | #print("{} --------------------------------> {}".format(word,output_word)) 379 | corrected_tokens.append(output_word) 380 | return corrected_tokens 381 | 382 | 383 | 384 | def _run_strip_accents(self, text): 385 | """Strips accents from a piece of text.""" 386 | text = unicodedata.normalize("NFD", text) 387 | output = [] 388 | for char in text: 389 | cat = unicodedata.category(char) 390 | if cat == "Mn": 391 | continue 392 | output.append(char) 393 | return "".join(output) 394 | 395 | def _run_split_on_punc(self, text): 396 | """Splits punctuation on a piece of text.""" 397 | 398 | chars = list(text) 399 | i = 0 400 | start_new_word = True 401 | output = [] 402 | while i < len(chars): 403 | char = chars[i] 404 | if _is_punctuation(char): 405 | output.append([char]) 406 | start_new_word = True 407 | else: 408 | if start_new_word: 409 | output.append([]) 410 | start_new_word = False 411 | output[-1].append(char) 412 | i += 1 413 | 414 | return ["".join(x) for x in output] 415 | 416 | def _tokenize_chinese_chars(self, text): 417 | """Adds whitespace around any CJK character.""" 418 | output = [] 419 | for char in text: 420 | cp = ord(char) 421 | if self._is_chinese_char(cp): 422 | output.append(" ") 423 | output.append(char) 424 | output.append(" ") 425 | else: 426 | output.append(char) 427 | return "".join(output) 428 | 429 | def _is_chinese_char(self, cp): 430 | """Checks whether CP is the codepoint of a CJK character.""" 431 | # This defines a "chinese character" as anything in the CJK Unicode block: 432 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 433 | # 434 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 435 | # despite its name. The modern Korean Hangul alphabet is a different block, 436 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 437 | # space-separated words, so they are not treated specially and handled 438 | # like the all of the other languages. 439 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 440 | (cp >= 0x3400 and cp <= 0x4DBF) or # 441 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 442 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 443 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 444 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 445 | (cp >= 0xF900 and cp <= 0xFAFF) or # 446 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 447 | return True 448 | 449 | return False 450 | 451 | def _clean_text(self, text): 452 | """Performs invalid character removal and whitespace cleanup on text.""" 453 | output = [] 454 | for char in text: 455 | cp = ord(char) 456 | if cp == 0 or cp == 0xfffd or _is_control(char): 457 | continue 458 | if _is_whitespace(char): 459 | output.append(" ") 460 | else: 461 | output.append(char) 462 | return "".join(output) 463 | 464 | 465 | class WordpieceTokenizer(object): 466 | """Runs WordPiece tokenziation.""" 467 | 468 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 469 | self.vocab = vocab 470 | self.unk_token = unk_token 471 | self.max_input_chars_per_word = max_input_chars_per_word 472 | 473 | def tokenize(self, text): 474 | """Tokenizes a piece of text into its word pieces. 475 | 476 | This uses a greedy longest-match-first algorithm to perform tokenization 477 | using the given vocabulary. 478 | 479 | For example: 480 | input = "unaffable" 481 | output = ["un", "##aff", "##able"] 482 | 483 | Args: 484 | text: A single token or whitespace separated tokens. This should have 485 | already been passed through `BasicTokenizer. 486 | 487 | Returns: 488 | A list of wordpiece tokens. 489 | """ 490 | 491 | text = convert_to_unicode(text) 492 | 493 | output_tokens = [] 494 | for token in whitespace_tokenize(text): 495 | chars = list(token) 496 | if len(chars) > self.max_input_chars_per_word: 497 | output_tokens.append(self.unk_token) 498 | continue 499 | 500 | is_bad = False 501 | start = 0 502 | sub_tokens = [] 503 | while start < len(chars): 504 | end = len(chars) 505 | cur_substr = None 506 | while start < end: 507 | substr = "".join(chars[start:end]) 508 | if start > 0: 509 | substr = "##" + substr 510 | if substr in self.vocab: 511 | cur_substr = substr 512 | break 513 | end -= 1 514 | if cur_substr is None: 515 | is_bad = True 516 | break 517 | sub_tokens.append(cur_substr) 518 | start = end 519 | 520 | if is_bad: 521 | output_tokens.append(self.unk_token) 522 | else: 523 | output_tokens.extend(sub_tokens) 524 | return output_tokens 525 | 526 | 527 | def _is_whitespace(char): 528 | """Checks whether `chars` is a whitespace character.""" 529 | # \t, \n, and \r are technically contorl characters but we treat them 530 | # as whitespace since they are generally considered as such. 531 | if char == " " or char == "\t" or char == "\n" or char == "\r": 532 | return True 533 | cat = unicodedata.category(char) 534 | if cat == "Zs": 535 | return True 536 | return False 537 | 538 | 539 | def _is_control(char): 540 | """Checks whether `chars` is a control character.""" 541 | # These are technically control characters but we count them as whitespace 542 | # characters. 543 | if char == "\t" or char == "\n" or char == "\r": 544 | return False 545 | cat = unicodedata.category(char) 546 | if cat.startswith("C"): 547 | return True 548 | return False 549 | 550 | 551 | def _is_punctuation(char): 552 | """Checks whether `chars` is a punctuation character.""" 553 | cp = ord(char) 554 | # We treat all non-letter/number ASCII as punctuation. 555 | # Characters such as "^", "$", and "`" are not in the Unicode 556 | # Punctuation class but we treat them as punctuation anyways, for 557 | # consistency. 558 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 559 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 560 | return True 561 | cat = unicodedata.category(char) 562 | if cat.startswith("P"): 563 | return True 564 | return False 565 | -------------------------------------------------------------------------------- /tokenize_input.py: -------------------------------------------------------------------------------- 1 | #tokenize input sentences using word piece tokenizer 2 | 3 | import pickle 4 | from joblib import Parallel, delayed 5 | from tqdm import tqdm 6 | import sys 7 | from utils import open_w, dump_text_to_list, pretty, read_file_lines, custom_tokenize 8 | import tokenization 9 | import argparse 10 | 11 | def add_arguments(parser): 12 | """Build ArgumentParser.""" 13 | parser.register("type", "bool", lambda v: v.lower() == "true") 14 | parser.add_argument("--input", type=str, default=None, help="input file having possibly incorrect sentences") 15 | parser.add_argument("--output_tokens", type=str, default=None, help="tokenized version of input") 16 | parser.add_argument("--output_token_ids", type=str, default=None, help="token ids corresponding to output_tokens") 17 | parser.add_argument("--vocab_path", type=str, default=None, help="path to bert's cased vocab file") 18 | parser.add_argument("--do_spell_check", type="bool",default=False, help="wheter to spell check words") 19 | 20 | 21 | parser = argparse.ArgumentParser() 22 | add_arguments(parser) 23 | FLAGS, unparsed = parser.parse_known_args() 24 | 25 | if FLAGS.do_spell_check: 26 | print("\n\n******************* DOING SPELL CHECK while tokenizing input *******************\n\n") 27 | else: 28 | print("\n\n********************* Skipping Spell Check while tokenizing input *******************\n\n") 29 | 30 | wordpiece_tokenizer = tokenization.FullTokenizer(FLAGS.vocab_path, do_lower_case=False) 31 | vocab_bert = wordpiece_tokenizer.vocab 32 | vocab_words = [vocab_bert.keys()] 33 | 34 | 35 | def get_tuple(line): 36 | 37 | if FLAGS.do_spell_check: 38 | line = line.strip().split() 39 | line = wordpiece_tokenizer.basic_tokenizer._run_spell_check(line) 40 | line = " ".join(line) 41 | tokens = custom_tokenize(line, wordpiece_tokenizer) 42 | token_ids = wordpiece_tokenizer.convert_tokens_to_ids(tokens) 43 | #print(tokens) 44 | #print(token_ids) 45 | return tokens, token_ids 46 | 47 | def write_output(raw_lines, tokens_file, token_ids_file): 48 | tuples = Parallel(n_jobs=1)(delayed(get_tuple)( 49 | raw_lines[i]) for i in tqdm(range(len(raw_lines)))) 50 | 51 | for i in range(len(tuples)): 52 | tokens, token_ids = tuples[i] 53 | 54 | # Write text output 55 | tokens_file.write(' '.join(tokens)) 56 | token_ids_file.write(' '.join(str(x) for x in token_ids)) 57 | 58 | tokens_file.write('\n') 59 | token_ids_file.write('\n') 60 | return 61 | 62 | if __name__=="__main__": 63 | incorrect_lines = read_file_lines(FLAGS.input, 'incorrect lines') 64 | with open_w(FLAGS.output_tokens) as tokens_file,\ 65 | open_w(FLAGS.output_token_ids) as token_ids_file: 66 | 67 | pretty.pheader('Tokenizing Incorrect sentences') 68 | write_output(incorrect_lines, tokens_file, token_ids_file) 69 | -------------------------------------------------------------------------------- /transform_suffixes.py: -------------------------------------------------------------------------------- 1 | """Transform functions for replace.""" 2 | import pickle 3 | 4 | def is_append_suffix(suffix,w1,w2): 5 | l = len(suffix) 6 | if (w2[-l:] == suffix) and (w1 == w2[:-l]): 7 | return True 8 | else: 9 | return False 10 | 11 | def is_transform_suffix(suffix_1,suffix_2, w1, w2): 12 | l1 = len(suffix_1) 13 | l2 = len(suffix_2) 14 | 15 | if (w1[-l1:] == suffix_1) and (w2[-l2:] == suffix_2) and (w1[:-l1] == w2[:-l2]): 16 | return True 17 | else: 18 | return False 19 | 20 | def append_suffix(word, suffix): 21 | 22 | #print("word: {}, suffix: {}".format(word, suffix)) 23 | 24 | if len(word) == 1: 25 | print("We currently fear appending suffix: {} to a one letter word, but still doing it: {}".format(suffix, word)) 26 | #return word 27 | 28 | l = len(suffix) 29 | 30 | #if word[-l:] == suffix: 31 | #print("**** WARNING: SUFFIX: {} ALREADY PRESENT in WORD, BUT still adding it: {} ****".format(suffix, word)) 32 | 33 | 34 | if word[-1] == "s" and suffix == "s": 35 | return word+"es" 36 | 37 | if word[-1] == "y" and suffix == "s" and len(word)>2: 38 | if word[-2] not in ["a","e","i","o","u"]: 39 | return word[0:-1] + "ies" 40 | 41 | 42 | #if word[-1] == "h" and suffix == "s": 43 | # return word + "es" 44 | 45 | #if word[-1] == "t" and suffix == "d": 46 | # return word + "ed" 47 | 48 | #if word[-1] == "k" and suffix == "d": 49 | # return word + "ed" 50 | 51 | return word+suffix 52 | 53 | def remove_suffix(word, suffix): 54 | 55 | if len(suffix) > len(word): 56 | print("suffix: {} to be removed has larger length than word: {}".format(suffix, word)) 57 | return word 58 | 59 | l = len(suffix) 60 | 61 | if word[-l:] == suffix: 62 | return word[:-l] 63 | else: 64 | print("**** WARNING: SUFFIX : {} NOT PRESENT in WORD: {} ****".format(suffix, word)) 65 | return word 66 | 67 | def transform_suffix(word, suffix_1, suffix_2): 68 | 69 | if len(suffix_1) > len(word): 70 | print("suffix: {} to be replaced has larger length than word: {}".format(suffix_1, word)) 71 | return word 72 | 73 | l1 = len(suffix_1) 74 | l2 = len(suffix_2) 75 | 76 | if word[-l1:] == suffix_1: 77 | return word[:-l1] + suffix_2 78 | else: 79 | print("transform") 80 | print("**** WARNING: SUFFIX : {} NOT PRESENT in WORD: {} ****".format(suffix_1, word)) 81 | return word 82 | 83 | 84 | class SuffixTransform(): 85 | """Helper to find if a replacement in a sentence matches a predefined transform.""" 86 | 87 | def __init__(self, src_word, tgt_word, opcodes): 88 | """Create a new transform matching instance 89 | :src_word : original word 90 | :tgt_word : modified word 91 | """ 92 | self.src_word = src_word 93 | self.tgt_word = tgt_word 94 | self.opcodes = opcodes 95 | 96 | def match(self): 97 | """Returns an opcode if matches.""" 98 | if self.src_word == self.tgt_word: 99 | return None 100 | 101 | ''' 102 | return self.pluralization() or self.singularization() \ 103 | or self.capitalization() or self.decapitalization() \ 104 | or self.verb_transform() or None 105 | ''' 106 | 107 | return self.append_s() or self.remove_s() \ 108 | or self.append_d() or self.remove_d() \ 109 | or self.append_es() or self.remove_es() \ 110 | or self.append_ing() or self.remove_ing() \ 111 | or self.append_ed() or self.remove_ed() \ 112 | or self.append_ly() or self.remove_ly() \ 113 | or self.append_er() or self.remove_er() \ 114 | or self.append_al() or self.remove_al() \ 115 | or self.append_n() or self.remove_n() \ 116 | or self.append_y() or self.remove_y() \ 117 | or self.append_ation() or self.remove_ation() \ 118 | or self.e_to_ing() or self.ing_to_e() \ 119 | or self.d_to_t() or self.t_to_d() \ 120 | or self.d_to_s() or self.s_to_d() \ 121 | or self.s_to_ing() or self.ing_to_s() \ 122 | or self.n_to_ing() or self.ing_to_n() \ 123 | or self.nce_to_t() or self.t_to_nce() \ 124 | or self.s_to_ed() or self.ed_to_s() \ 125 | or self.ing_to_ed() or self.ed_to_ing() \ 126 | or self.ing_to_ion() or self.ion_to_ing() \ 127 | or self.ing_to_ation() or self.ation_to_ing() \ 128 | or self.t_to_ce() or self.ce_to_t() \ 129 | or self.y_to_ic() or self.ic_to_y() \ 130 | or self.t_to_s() or self.s_to_t() \ 131 | or self.e_to_al() or self.al_to_e() \ 132 | or self.y_to_ily() or self.ily_to_y() \ 133 | or self.y_to_ied() or self.ied_to_y() \ 134 | or self.y_to_ical() or self.ical_to_y() \ 135 | or self.y_to_ies() or self.ies_to_y() \ 136 | or None 137 | 138 | 139 | 140 | def e_to_ing(self): 141 | if is_transform_suffix("e","ing",self.src_word,self.tgt_word): 142 | return self.opcodes.E_TO_ING 143 | else: 144 | return None 145 | 146 | def ing_to_e(self): 147 | if is_transform_suffix("ing","e",self.src_word,self.tgt_word): 148 | return self.opcodes.ING_TO_E 149 | else: 150 | return None 151 | 152 | def d_to_t(self): 153 | if is_transform_suffix("d","t",self.src_word,self.tgt_word): 154 | return self.opcodes.D_TO_T 155 | else: 156 | return None 157 | 158 | def t_to_d(self): 159 | if is_transform_suffix("t","d",self.src_word,self.tgt_word): 160 | return self.opcodes.T_TO_D 161 | else: 162 | return None 163 | 164 | def d_to_s(self): 165 | if is_transform_suffix("d","s",self.src_word,self.tgt_word): 166 | return self.opcodes.D_TO_S 167 | else: 168 | return None 169 | 170 | def s_to_d(self): 171 | if is_transform_suffix("s","d",self.src_word,self.tgt_word): 172 | return self.opcodes.S_TO_D 173 | else: 174 | return None 175 | 176 | def s_to_ing(self): 177 | if is_transform_suffix("s","ing",self.src_word,self.tgt_word): 178 | return self.opcodes.S_TO_ING 179 | else: 180 | return None 181 | 182 | def ing_to_s(self): 183 | if is_transform_suffix("ing","s",self.src_word,self.tgt_word): 184 | return self.opcodes.ING_TO_S 185 | else: 186 | return None 187 | 188 | def n_to_ing(self): 189 | if is_transform_suffix("n","ing",self.src_word,self.tgt_word): 190 | return self.opcodes.N_TO_ING 191 | else: 192 | return None 193 | 194 | def ing_to_n(self): 195 | if is_transform_suffix("ing","n",self.src_word,self.tgt_word): 196 | return self.opcodes.ING_TO_N 197 | else: 198 | return None 199 | 200 | def t_to_nce(self): 201 | if is_transform_suffix("t","nce",self.src_word,self.tgt_word): 202 | return self.opcodes.T_TO_NCE 203 | else: 204 | return None 205 | 206 | def nce_to_t(self): 207 | if is_transform_suffix("nce","t",self.src_word,self.tgt_word): 208 | return self.opcodes.NCE_TO_T 209 | else: 210 | return None 211 | 212 | def s_to_ed(self): 213 | if is_transform_suffix("s","ed",self.src_word,self.tgt_word): 214 | return self.opcodes.S_TO_ED 215 | else: 216 | return None 217 | 218 | def ed_to_s(self): 219 | if is_transform_suffix("ed","s",self.src_word,self.tgt_word): 220 | return self.opcodes.ED_TO_S 221 | else: 222 | return None 223 | 224 | def ing_to_ed(self): 225 | if is_transform_suffix("ing","ed",self.src_word,self.tgt_word): 226 | return self.opcodes.ING_TO_ED 227 | else: 228 | return None 229 | 230 | def ed_to_ing(self): 231 | if is_transform_suffix("ed","ing",self.src_word,self.tgt_word): 232 | return self.opcodes.ED_TO_ING 233 | else: 234 | return None 235 | 236 | def ing_to_ion(self): 237 | if is_transform_suffix("ing","ion",self.src_word,self.tgt_word): 238 | return self.opcodes.ING_TO_ION 239 | else: 240 | return None 241 | 242 | def ion_to_ing(self): 243 | if is_transform_suffix("ion","ing",self.src_word,self.tgt_word): 244 | return self.opcodes.ION_TO_ING 245 | else: 246 | return None 247 | 248 | def ing_to_ation(self): 249 | if is_transform_suffix("ing","ation",self.src_word,self.tgt_word): 250 | return self.opcodes.ING_TO_ATION 251 | else: 252 | return None 253 | 254 | def ation_to_ing(self): 255 | if is_transform_suffix("ation","ing",self.src_word,self.tgt_word): 256 | return self.opcodes.ATION_TO_ING 257 | else: 258 | return None 259 | 260 | def t_to_ce(self): 261 | if is_transform_suffix("t","ce",self.src_word,self.tgt_word): 262 | return self.opcodes.T_TO_CE 263 | else: 264 | return None 265 | 266 | def ce_to_t(self): 267 | if is_transform_suffix("ce","t",self.src_word,self.tgt_word): 268 | return self.opcodes.CE_TO_T 269 | else: 270 | return None 271 | 272 | def y_to_ic(self): 273 | if is_transform_suffix("y","ic",self.src_word,self.tgt_word): 274 | return self.opcodes.Y_TO_IC 275 | else: 276 | return None 277 | 278 | def ic_to_y(self): 279 | if is_transform_suffix("ic","y",self.src_word,self.tgt_word): 280 | return self.opcodes.IC_TO_Y 281 | else: 282 | return None 283 | 284 | def t_to_s(self): 285 | if is_transform_suffix("t","s",self.src_word,self.tgt_word): 286 | return self.opcodes.T_TO_S 287 | else: 288 | return None 289 | 290 | def s_to_t(self): 291 | if is_transform_suffix("s","t",self.src_word,self.tgt_word): 292 | return self.opcodes.S_TO_T 293 | else: 294 | return None 295 | 296 | def e_to_al(self): 297 | if is_transform_suffix("e","al",self.src_word,self.tgt_word): 298 | return self.opcodes.E_TO_AL 299 | else: 300 | return None 301 | 302 | def al_to_e(self): 303 | if is_transform_suffix("al","e",self.src_word,self.tgt_word): 304 | return self.opcodes.AL_TO_E 305 | else: 306 | return None 307 | 308 | def y_to_ily(self): 309 | if is_transform_suffix("y","ily",self.src_word,self.tgt_word): 310 | return self.opcodes.Y_TO_ILY 311 | else: 312 | return None 313 | 314 | def ily_to_y(self): 315 | if is_transform_suffix("ily","y",self.src_word,self.tgt_word): 316 | return self.opcodes.ILY_TO_Y 317 | else: 318 | return None 319 | 320 | def y_to_ied(self): 321 | if is_transform_suffix("y","ied",self.src_word,self.tgt_word): 322 | return self.opcodes.Y_TO_IED 323 | else: 324 | return None 325 | 326 | def ied_to_y(self): 327 | if is_transform_suffix("ied","y",self.src_word,self.tgt_word): 328 | return self.opcodes.IED_TO_Y 329 | else: 330 | return None 331 | 332 | def y_to_ical(self): 333 | if is_transform_suffix("y","ical",self.src_word,self.tgt_word): 334 | return self.opcodes.Y_TO_ICAL 335 | else: 336 | return None 337 | 338 | def ical_to_y(self): 339 | if is_transform_suffix("ical","y",self.src_word,self.tgt_word): 340 | return self.opcodes.ICAL_TO_Y 341 | else: 342 | return None 343 | 344 | def y_to_ies(self): 345 | if is_transform_suffix("y","ies",self.src_word,self.tgt_word): 346 | return self.opcodes.Y_TO_IES 347 | else: 348 | return None 349 | 350 | def ies_to_y(self): 351 | if is_transform_suffix("ies","y",self.src_word,self.tgt_word): 352 | return self.opcodes.IES_TO_Y 353 | else: 354 | return None 355 | 356 | def append_s(self): 357 | if is_append_suffix("s",self.src_word, self.tgt_word) and (self.src_word not in ["a","A","I","i"]): 358 | return self.opcodes.APPEND_s 359 | else: 360 | return None 361 | 362 | def remove_s(self): 363 | if is_append_suffix("s",self.tgt_word,self.src_word) and (self.src_word not in ["As","as","is","Is"]): 364 | return self.opcodes.REMOVE_s 365 | else: 366 | return None 367 | 368 | def append_d(self): 369 | if is_append_suffix("d",self.src_word, self.tgt_word): 370 | return self.opcodes.APPEND_d 371 | else: 372 | return None 373 | 374 | def remove_d(self): 375 | if is_append_suffix("d",self.tgt_word,self.src_word): 376 | return self.opcodes.REMOVE_d 377 | else: 378 | return None 379 | 380 | def append_es(self): 381 | if is_append_suffix("es",self.src_word, self.tgt_word): 382 | return self.opcodes.APPEND_es 383 | else: 384 | return None 385 | 386 | def remove_es(self): 387 | if is_append_suffix("es",self.tgt_word,self.src_word): 388 | return self.opcodes.REMOVE_es 389 | else: 390 | return None 391 | 392 | def append_ing(self): 393 | if is_append_suffix("ing",self.src_word, self.tgt_word): 394 | return self.opcodes.APPEND_ing 395 | else: 396 | return None 397 | 398 | def remove_ing(self): 399 | if is_append_suffix("ing",self.tgt_word,self.src_word): 400 | return self.opcodes.REMOVE_ing 401 | else: 402 | return None 403 | 404 | def append_ed(self): 405 | if is_append_suffix("ed",self.src_word, self.tgt_word): 406 | return self.opcodes.APPEND_ed 407 | else: 408 | return None 409 | 410 | def remove_ed(self): 411 | if is_append_suffix("ed",self.tgt_word,self.src_word): 412 | return self.opcodes.REMOVE_ed 413 | else: 414 | return None 415 | 416 | def append_ly(self): 417 | if is_append_suffix("ly",self.src_word, self.tgt_word): 418 | return self.opcodes.APPEND_ly 419 | else: 420 | return None 421 | 422 | def remove_ly(self): 423 | if is_append_suffix("ly",self.tgt_word,self.src_word): 424 | return self.opcodes.REMOVE_ly 425 | else: 426 | return None 427 | 428 | def append_er(self): 429 | if is_append_suffix("er",self.src_word, self.tgt_word): 430 | return self.opcodes.APPEND_er 431 | else: 432 | return None 433 | 434 | def remove_er(self): 435 | if is_append_suffix("er",self.tgt_word,self.src_word): 436 | return self.opcodes.REMOVE_er 437 | else: 438 | return None 439 | 440 | def append_al(self): 441 | if is_append_suffix("al",self.src_word, self.tgt_word): 442 | return self.opcodes.APPEND_al 443 | else: 444 | return None 445 | 446 | def remove_al(self): 447 | if is_append_suffix("al",self.tgt_word,self.src_word): 448 | return self.opcodes.REMOVE_al 449 | else: 450 | return None 451 | 452 | def append_n(self): 453 | if is_append_suffix("n",self.src_word, self.tgt_word) and (self.src_word not in ["a","A","i","I"]): 454 | return self.opcodes.APPEND_n 455 | else: 456 | return None 457 | 458 | def remove_n(self): 459 | if is_append_suffix("n",self.tgt_word,self.src_word) and (self.src_word not in ["an","An","in","In"]): 460 | return self.opcodes.REMOVE_n 461 | else: 462 | return None 463 | 464 | def append_y(self): 465 | if is_append_suffix("y",self.src_word, self.tgt_word) and (self.src_word not in ["m","M"]): 466 | return self.opcodes.APPEND_y 467 | else: 468 | return None 469 | def remove_y(self): 470 | if is_append_suffix("y",self.tgt_word,self.src_word) and (self.src_word not in ["My","my"]): 471 | return self.opcodes.REMOVE_y 472 | else: 473 | return None 474 | 475 | def append_ation(self): 476 | if is_append_suffix("ation",self.src_word, self.tgt_word): 477 | return self.opcodes.APPEND_ation 478 | else: 479 | return None 480 | 481 | def remove_ation(self): 482 | if is_append_suffix("ation",self.tgt_word,self.src_word): 483 | return self.opcodes.REMOVE_ation 484 | else: 485 | return None 486 | 487 | def apply_transform(uncorrected, uposition, opcode, opcodes): 488 | """Tries to apply an opcode to a word or returns None 489 | :param uncorrected: Tokenized uncorrected sentence 490 | :uposition: Position of replaced in uncorrected sentence 491 | :opcode: Opcode to try applying 492 | """ 493 | art = ApplySuffixTransorm(uncorrected,uposition, opcode, opcodes) 494 | return art.apply() 495 | 496 | class ApplySuffixTransorm(): 497 | 498 | def __init__(self, uncorrected, uposition, opcode, opcodes): 499 | """Tries to apply an opcode to a word or returns None 500 | :param uncorrected: Tokenized uncorrected sentence 501 | :uposition: Position of replaced in uncorrected sentence 502 | :opcode: Opcode to try applying 503 | """ 504 | 505 | self.uncorrected = uncorrected 506 | self.uposition = uposition 507 | self.opcode = opcode 508 | self.src_word = uncorrected[uposition] 509 | self.opcodes = opcodes 510 | 511 | def apply(self): 512 | """Try to apply the transform 513 | :return: Transformed word or None if cannot transform 514 | """ 515 | transformed = None 516 | 517 | if self.opcode == self.opcodes.APPEND_s: 518 | transformed = append_suffix(self.src_word, "s") 519 | 520 | if self.opcode == self.opcodes.APPEND_d: 521 | transformed = append_suffix(self.src_word, "d") 522 | 523 | if self.opcode == self.opcodes.APPEND_es: 524 | transformed = append_suffix(self.src_word, "es") 525 | 526 | if self.opcode == self.opcodes.APPEND_ing: 527 | transformed = append_suffix(self.src_word, "ing") 528 | 529 | if self.opcode == self.opcodes.APPEND_ed: 530 | transformed = append_suffix(self.src_word, "ed") 531 | 532 | if self.opcode == self.opcodes.APPEND_ly: 533 | transformed = append_suffix(self.src_word, "ly") 534 | 535 | if self.opcode == self.opcodes.APPEND_er: 536 | transformed = append_suffix(self.src_word, "er") 537 | 538 | if self.opcode == self.opcodes.APPEND_al: 539 | transformed = append_suffix(self.src_word, "al") 540 | 541 | if self.opcode == self.opcodes.APPEND_n: 542 | transformed = append_suffix(self.src_word, "n") 543 | 544 | if self.opcode == self.opcodes.APPEND_y: 545 | transformed = append_suffix(self.src_word, "y") 546 | 547 | if self.opcode == self.opcodes.APPEND_ation: 548 | transformed = append_suffix(self.src_word, "ation") 549 | 550 | if self.opcode == self.opcodes.REMOVE_s: 551 | transformed = remove_suffix(self.src_word, "s") 552 | 553 | if self.opcode == self.opcodes.REMOVE_d: 554 | transformed = remove_suffix(self.src_word, "d") 555 | 556 | if self.opcode == self.opcodes.REMOVE_es: 557 | transformed = remove_suffix(self.src_word, "es") 558 | 559 | if self.opcode == self.opcodes.REMOVE_ing: 560 | transformed = remove_suffix(self.src_word, "ing") 561 | 562 | if self.opcode == self.opcodes.REMOVE_ed: 563 | transformed = remove_suffix(self.src_word, "ed") 564 | 565 | if self.opcode == self.opcodes.REMOVE_ly: 566 | transformed = remove_suffix(self.src_word, "ly") 567 | 568 | if self.opcode == self.opcodes.REMOVE_er: 569 | transformed = remove_suffix(self.src_word, "er") 570 | 571 | if self.opcode == self.opcodes.REMOVE_al: 572 | transformed = remove_suffix(self.src_word, "al") 573 | 574 | if self.opcode == self.opcodes.REMOVE_n: 575 | transformed = remove_suffix(self.src_word, "n") 576 | 577 | if self.opcode == self.opcodes.REMOVE_y: 578 | transformed = remove_suffix(self.src_word, "y") 579 | 580 | if self.opcode == self.opcodes.REMOVE_ation: 581 | transformed = remove_suffix(self.src_word, "ation") 582 | 583 | if self.opcode == self.opcodes.E_TO_ING: 584 | transformed = transform_suffix(self.src_word, "e", "ing") 585 | 586 | if self.opcode == self.opcodes.ING_TO_E: 587 | transformed = transform_suffix(self.src_word, "ing", "e") 588 | 589 | if self.opcode == self.opcodes.D_TO_T: 590 | transformed = transform_suffix(self.src_word, "d", "t") 591 | 592 | if self.opcode == self.opcodes.T_TO_D: 593 | transformed = transform_suffix(self.src_word, "t", "d") 594 | 595 | if self.opcode == self.opcodes.D_TO_S: 596 | transformed = transform_suffix(self.src_word, "d", "s") 597 | 598 | if self.opcode == self.opcodes.S_TO_D: 599 | transformed = transform_suffix(self.src_word, "s", "d") 600 | 601 | if self.opcode == self.opcodes.S_TO_ING: 602 | transformed = transform_suffix(self.src_word, "s", "ing") 603 | 604 | if self.opcode == self.opcodes.ING_TO_S: 605 | transformed = transform_suffix(self.src_word, "ing", "s") 606 | 607 | if self.opcode == self.opcodes.N_TO_ING: 608 | transformed = transform_suffix(self.src_word, "n", "ing") 609 | 610 | if self.opcode == self.opcodes.ING_TO_N: 611 | transformed = transform_suffix(self.src_word, "ing", "n") 612 | 613 | if self.opcode == self.opcodes.T_TO_NCE: 614 | transformed = transform_suffix(self.src_word, "t", "nce") 615 | 616 | if self.opcode == self.opcodes.NCE_TO_T: 617 | transformed = transform_suffix(self.src_word, "nce", "t") 618 | 619 | if self.opcode == self.opcodes.S_TO_ED: 620 | transformed = transform_suffix(self.src_word, "s", "ed") 621 | 622 | if self.opcode == self.opcodes.ED_TO_S: 623 | transformed = transform_suffix(self.src_word, "ed", "s") 624 | 625 | if self.opcode == self.opcodes.ING_TO_ED: 626 | transformed = transform_suffix(self.src_word, "ing", "ed") 627 | 628 | if self.opcode == self.opcodes.ED_TO_ING: 629 | transformed = transform_suffix(self.src_word, "ed", "ing") 630 | 631 | if self.opcode == self.opcodes.ING_TO_ION: 632 | transformed = transform_suffix(self.src_word, "ing", "ion") 633 | 634 | if self.opcode == self.opcodes.ION_TO_ING: 635 | transformed = transform_suffix(self.src_word, "ion", "ing") 636 | 637 | if self.opcode == self.opcodes.ING_TO_ATION: 638 | transformed = transform_suffix(self.src_word, "ing", "ation") 639 | 640 | if self.opcode == self.opcodes.ATION_TO_ING: 641 | transformed = transform_suffix(self.src_word, "ation", "ing") 642 | 643 | if self.opcode == self.opcodes.T_TO_CE: 644 | transformed = transform_suffix(self.src_word, "t", "ce") 645 | 646 | if self.opcode == self.opcodes.CE_TO_T: 647 | transformed = transform_suffix(self.src_word, "ce", "t") 648 | 649 | if self.opcode == self.opcodes.Y_TO_IC: 650 | transformed = transform_suffix(self.src_word, "y", "ic") 651 | 652 | if self.opcode == self.opcodes.IC_TO_Y: 653 | transformed = transform_suffix(self.src_word, "ic", "y") 654 | 655 | if self.opcode == self.opcodes.T_TO_S: 656 | transformed = transform_suffix(self.src_word, "t", "s") 657 | 658 | if self.opcode == self.opcodes.S_TO_T: 659 | transformed = transform_suffix(self.src_word, "s", "t") 660 | 661 | if self.opcode == self.opcodes.E_TO_AL: 662 | transformed = transform_suffix(self.src_word, "e", "al") 663 | 664 | if self.opcode == self.opcodes.AL_TO_E: 665 | transformed = transform_suffix(self.src_word, "al", "e") 666 | 667 | if self.opcode == self.opcodes.Y_TO_ILY: 668 | transformed = transform_suffix(self.src_word, "y", "ily") 669 | 670 | if self.opcode == self.opcodes.ILY_TO_Y: 671 | transformed = transform_suffix(self.src_word, "ily", "y") 672 | 673 | if self.opcode == self.opcodes.Y_TO_IED: 674 | transformed = transform_suffix(self.src_word, "y", "ied") 675 | 676 | if self.opcode == self.opcodes.IED_TO_Y: 677 | transformed = transform_suffix(self.src_word, "ied", "y") 678 | 679 | if self.opcode == self.opcodes.Y_TO_ICAL: 680 | transformed = transform_suffix(self.src_word, "y", "ical") 681 | 682 | if self.opcode == self.opcodes.ICAL_TO_Y: 683 | transformed = transform_suffix(self.src_word, "ical", "y") 684 | 685 | if self.opcode == self.opcodes.Y_TO_IES: 686 | transformed = transform_suffix(self.src_word, "y", "ies") 687 | 688 | if self.opcode == self.opcodes.IES_TO_Y: 689 | transformed = transform_suffix(self.src_word, "ies", "y") 690 | 691 | return transformed or None -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """Helper utilities.""" 2 | import pickle 3 | import os 4 | from tqdm import tqdm 5 | 6 | class open_w(): 7 | """Open a file for writing (overwrite) with encoding utf-8 in text mode. 8 | :param filename: Name of file 9 | :param append: Opens the file for appending if true 10 | :return: file handle 11 | """ 12 | 13 | def __init__(self, filename, append=False): 14 | self.filename = filename 15 | self.append = append 16 | self.fd = None 17 | def __enter__(self): 18 | self.fd = open(self.filename, 'w' if not self.append else 'a', encoding='utf-8') 19 | return self.fd 20 | def __exit__(self, type, value, traceback): 21 | print('Wrote ' + pretty.fname(self.fd.name)) 22 | self.fd.close() 23 | 24 | def open_r(filename): 25 | """Open a file for reading with encoding utf-8 in text mode.""" 26 | return open(filename, 'r', encoding='utf-8') 27 | 28 | def do_pickle(obj, filename, message="pickle", protocol=3): 29 | """Pickle an object and show a message.""" 30 | pretty.start('Dumping ' + message + ' to ' + pretty.fname(filename)) 31 | pickle.dump(obj, open(filename, 'wb'),protocol=protocol) 32 | pretty.ok() 33 | 34 | def dump_text_to_list(filename, dump_path): 35 | """Dump space separated list of lists from text file to pickle.""" 36 | pretty.start('Dumping ' + pretty.fname(filename) + ' to ' + pretty.fname(dump_path)) 37 | with open(filename, 'r', encoding='utf-8') as edit_file: 38 | edit_list = [list(map(int, line.split(' '))) for line in edit_file.read().splitlines() if line] 39 | pickle.dump(edit_list, open(dump_path, "wb")) 40 | pretty.ok() 41 | 42 | def assert_fileexists(filename, info='data'): 43 | pretty.start('Checking for ' + pretty.fname(filename)) 44 | if not os.path.exists(filename): 45 | pretty.fail('NOT FOUND') 46 | pretty.fail('Fatal Error - FILE NOT FOUND') 47 | exit() 48 | pretty.ok() 49 | 50 | def read_file(filename, info='data'): 51 | pretty.start('Reading ' + info + ' from ' + pretty.fname(filename)) 52 | if not os.path.exists(filename): 53 | pretty.fail('NOT FOUND') 54 | pretty.fail('Fatal Error - FILE NOT FOUND') 55 | exit() 56 | 57 | with open_r(filename) as file: 58 | ans = file.read() 59 | pretty.ok() 60 | 61 | return ans 62 | 63 | def read_file_lines(filename, info='data'): 64 | return read_file(filename, info).splitlines() 65 | 66 | class bcolors: 67 | HEADER = '\033[95m' 68 | OKBLUE = '\033[94m' 69 | OKGREEN = '\033[92m' 70 | WARNING = '\033[93m' 71 | FAIL = '\033[91m' 72 | ENDC = '\033[0m' 73 | BOLD = '\033[1m' 74 | UNDERLINE = '\033[4m' 75 | 76 | class pretty: 77 | @staticmethod 78 | def start(operation): 79 | print(str(operation) + ' - ', end='', flush=True) 80 | 81 | @staticmethod 82 | def ok(message='OK'): 83 | print(bcolors.OKGREEN + str(message) + bcolors.ENDC) 84 | 85 | @staticmethod 86 | def fail(message="FAIL"): 87 | print(bcolors.FAIL + str(message) + bcolors.ENDC) 88 | 89 | @staticmethod 90 | def warn(message="WARNING"): 91 | print(bcolors.WARNING + str(message) + bcolors.ENDC) 92 | 93 | @staticmethod 94 | def pheader(message): 95 | print(bcolors.HEADER + str(message) + bcolors.ENDC) 96 | 97 | @staticmethod 98 | def fname(message): 99 | return bcolors.OKBLUE + str(message) + bcolors.ENDC 100 | 101 | @staticmethod 102 | def passert(condition, message='Test'): 103 | pretty.start(message) 104 | if condition: 105 | pretty.ok() 106 | return True 107 | else: 108 | pretty.fail() 109 | return False 110 | 111 | @staticmethod 112 | def assert_gt(a, b, message='Test'): 113 | """Assert if a is greater than b.""" 114 | return pretty.passert(a > b, str(message) + ' - ' + str(a) + ' > ' + str(b)) 115 | 116 | @staticmethod 117 | def assert_eq(a, b, message='Test'): 118 | """Assert if a is equal to b.""" 119 | return pretty.passert(a == b, message) 120 | 121 | @staticmethod 122 | def assert_in(a, b, message='Test'): 123 | """Assert if a is in b.""" 124 | return pretty.passert(a in b, message) 125 | 126 | 127 | def generator_based_read_file(filename, info='data',do_split=False,map_to_int=False): 128 | batch_size=10000 129 | #pretty.start('Reading ' + info + ' from ' + pretty.fname(filename)) 130 | if not os.path.exists(filename): 131 | pretty.fail('NOT FOUND') 132 | pretty.fail('Fatal Error - FILE NOT FOUND') 133 | exit() 134 | 135 | with open_r(filename) as file: 136 | result = [] 137 | for i,line in enumerate(file): 138 | out = line.strip() 139 | if do_split: 140 | out = out.split() 141 | if map_to_int: 142 | out = list(map(int,out)) 143 | result.append(out) 144 | if i and i%(batch_size-1)==0: 145 | yield result 146 | result = [] 147 | if len(result)>0: 148 | yield result 149 | #pretty.ok() 150 | 151 | def read_file_lines(filename, info='data'): 152 | return read_file(filename, info).splitlines() 153 | 154 | def custom_tokenize(sentence, tokenizer, mode="test"): 155 | #tokenizes the sentences 156 | #adds [CLS] (start) and [SEP] (end) token 157 | tokenized = tokenizer.tokenize(sentence,mode) 158 | tokenized = ["[CLS]"] + tokenized + ["[SEP]"] 159 | return tokenized -------------------------------------------------------------------------------- /wem_utils.py: -------------------------------------------------------------------------------- 1 | #util functions for word_edit_model.py 2 | 3 | import tensorflow as tf 4 | import time 5 | 6 | def list_to_ids(s_list, tokenizer): 7 | #converst list of strings to list of list of token ids 8 | result = [] 9 | for item in s_list: 10 | tokens = item.split() 11 | ids = tokenizer.convert_tokens_to_ids(tokens) 12 | result.append(ids) 13 | 14 | return result 15 | 16 | def list_embedding_lookup(embedding_table, input_ids, 17 | use_one_hot_embeddings, vocab_size, embedding_size): 18 | #input ids is a list of word ids 19 | #returns sum of word_embeddings corresponding to input ids 20 | if use_one_hot_embeddings: 21 | one_hot_input_ids = tf.one_hot(input_ids, depth=vocab_size) 22 | output = tf.matmul(one_hot_input_ids, embedding_table) 23 | else: 24 | output = tf.nn.embedding_lookup(embedding_table, input_ids) 25 | result = tf.reduce_sum(output,0,keepdims=True) 26 | #result = tf.expand_dims(result,0) 27 | print("********* shape of reduce_sum: {} ******".format(result)) 28 | return result 29 | 30 | def edit_embedding_loopkup(embedding_table, list_input_ids, 31 | use_one_hot_embeddings, vocab_size, embedding_size): 32 | #list_input_ids is a list of list of input ids 33 | #returns embedding for each list, this represents 34 | #this represents embedding of phrase represented by list 35 | list1 = [item[0] for item in list_input_ids] 36 | list2 = [item[1] for item in list_input_ids] 37 | 38 | if use_one_hot_embeddings: 39 | one_hot_list1 = tf.one_hot(list1, depth=vocab_size) 40 | one_hot_list2 = tf.one_hot(list2, depth=vocab_size) 41 | w1 = tf.matmul(one_hot_list1, embedding_table) 42 | w2 = tf.matmul(one_hot_list2, embedding_table) 43 | else: 44 | w1 = tf.nn.embedding_lookup(embedding_table, list1) 45 | w2 = tf.nn.embedding_lookup(embedding_table, list2) 46 | 47 | return w1+w2 48 | 49 | 50 | 51 | def genealised_cross_entropy(probs, one_hot_labels,q=0.6, k=0): 52 | prob_mask = tf.to_float(tf.less_equal(probs,k)) 53 | probs = prob_mask * k + (1-prob_mask)*probs 54 | probs = tf.pow(probs,q) 55 | probs = 1 - probs 56 | probs = probs / q 57 | loss = tf.reduce_sum(probs * one_hot_labels, axis=-1) 58 | return loss 59 | 60 | def expand_embedding_matrix(embedding_matrix,batch_size): 61 | embedding_matrix = tf.expand_dims(embedding_matrix,0) 62 | embedding_matrix = tf.tile(embedding_matrix,[batch_size,1,1]) 63 | return embedding_matrix 64 | 65 | def timer(gen): 66 | while True: 67 | try: 68 | start_time = time.time() 69 | item = next(gen) 70 | elapsed_time = time.time() - start_time 71 | yield elapsed_time, item 72 | except StopIteration: 73 | break 74 | #def expected_edit_embeddings(probs,embedding_matrix, batch_size): 75 | #probs: B x T x E [E = no. of edits] 76 | #embedding_matrix: B x E x D 77 | #output: B x T x D 78 | -------------------------------------------------------------------------------- /word_edit_model.py: -------------------------------------------------------------------------------- 1 | # code adapted from https://github.com/google-research/bert 2 | 3 | # Unless required by applicable law or agreed to in writing, software 4 | # distributed under the License is distributed on an "AS IS" BASIS, 5 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 6 | # See the License for the specific language governing permissions and 7 | # limitations under the License. 8 | """BERT finetuning runner.""" 9 | 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import collections 15 | import csv 16 | import os 17 | import modeling 18 | import modified_modeling #obtains contextual embeddings 19 | # for appends and replacements 20 | # for edit factorized architecture 21 | # figure 2 in the paper 22 | import optimization 23 | import tokenization 24 | import tensorflow as tf 25 | import numpy as np 26 | import pickle 27 | 28 | from itertools import chain 29 | from tensorflow.python.lib.io.file_io import get_matching_files 30 | from modeling import get_shape_list 31 | 32 | import wem_utils 33 | 34 | flags = tf.flags 35 | 36 | FLAGS = flags.FLAGS 37 | 38 | ## Required parameters 39 | flags.DEFINE_string( 40 | "data_dir", None, 41 | "The input data dir. Should contain the .txt files (or other data files) " 42 | "for the task.") 43 | 44 | flags.DEFINE_string( 45 | "bert_config_file", None, 46 | "The config json file corresponding to the pre-trained BERT model. " 47 | "This specifies the model architecture.") 48 | 49 | flags.DEFINE_string("vocab_file", None, 50 | "The vocabulary file that the BERT model was trained on.") 51 | 52 | flags.DEFINE_string( 53 | "output_dir", None, 54 | "The output directory where the model checkpoints will be written.") 55 | 56 | ## Other parameters 57 | 58 | flags.DEFINE_string( 59 | "init_checkpoint", None, 60 | "Initial checkpoint (usually from a pre-trained BERT model).") 61 | 62 | flags.DEFINE_bool( 63 | "do_lower_case", True, 64 | "Whether to lower case the input text. Should be True for uncased " 65 | "models and False for cased models.") 66 | 67 | flags.DEFINE_integer( 68 | "max_seq_length", 128, 69 | "The maximum total input sequence length after WordPiece tokenization. " 70 | "Sequences longer than this will be truncated, and sequences shorter " 71 | "than this will be padded.") 72 | 73 | flags.DEFINE_bool("do_train", False, "Whether to run training.") 74 | 75 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") 76 | 77 | flags.DEFINE_bool( 78 | "do_predict", False, 79 | "Whether to run the model in inference mode on the test set.") 80 | 81 | flags.DEFINE_integer("train_batch_size", 64, "Total batch size for training.") 82 | 83 | flags.DEFINE_integer("eval_batch_size", 512, "Total batch size for eval.") 84 | 85 | flags.DEFINE_integer("predict_batch_size", 512, "Total batch size for predict.") 86 | 87 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") 88 | 89 | flags.DEFINE_float("num_train_epochs", 3.0, 90 | "Total number of training epochs to perform.") 91 | 92 | flags.DEFINE_float( 93 | "warmup_proportion", 0.1, 94 | "Proportion of training to perform linear learning rate warmup for. " 95 | "E.g., 0.1 = 10% of training.") 96 | 97 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 98 | "How often to save the model checkpoint.") 99 | 100 | flags.DEFINE_integer("iterations_per_loop", 1000, 101 | "How many steps to make in each estimator call.") 102 | 103 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 104 | 105 | tf.flags.DEFINE_string( 106 | "tpu_name", None, 107 | "The Cloud TPU to use for training. This should be either the name " 108 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " 109 | "url.") 110 | 111 | tf.flags.DEFINE_string( 112 | "tpu_zone", None, 113 | "[Optional] GCE zone where the Cloud TPU is located in. If not " 114 | "specified, we will attempt to automatically detect the GCE project from " 115 | "metadata.") 116 | 117 | tf.flags.DEFINE_string( 118 | "gcp_project", None, 119 | "[Optional] Project name for the Cloud TPU-enabled project. If not " 120 | "specified, we will attempt to automatically detect the GCE project from " 121 | "metadata.") 122 | 123 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 124 | 125 | flags.DEFINE_integer( 126 | "num_tpu_cores", 8, 127 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 128 | 129 | flags.DEFINE_float("copy_weight", 1, "weight to copy") 130 | 131 | flags.DEFINE_bool("use_bert_more", True, "use bert more exhaustively for logit computation") 132 | 133 | flags.DEFINE_string("path_inserts", None, "path to insert pickle") 134 | 135 | flags.DEFINE_string("path_multitoken_inserts", None, "path to multitoken_inserts") 136 | 137 | flags.DEFINE_bool("subtract_replaced_from_replacement", True, "subtract_replaced_from_replacement") 138 | 139 | flags.DEFINE_string("eval_checkpoint", None, "checkpoint to evaluate gec model") 140 | 141 | flags.DEFINE_string("predict_checkpoint", None, "checkpoint to use for predictions") 142 | 143 | flags.DEFINE_integer("random_seed",0,"random seed for creating random initializations") 144 | 145 | flags.DEFINE_bool("create_train_tf_records", True, "whether to create train tf records") 146 | 147 | flags.DEFINE_bool("create_predict_tf_records", True, "whether to create predict tf records") 148 | 149 | #flags.DEFINE_bool("dump_probs", False, "dump edit probs to numpy file while decoding") 150 | 151 | class PaddingInputExample(object): 152 | """Fake example so the num input examples is a multiple of the batch size. 153 | 154 | When running eval/predict on the TPU, we need to pad the number of examples 155 | to be a multiple of the batch size, because the TPU requires a fixed batch 156 | size. The alternative is to drop the last batch, which is bad because it means 157 | the entire output data won't be generated. 158 | 159 | We use this class instead of `None` because treating `None` as padding 160 | battches could cause silent errors. 161 | """ 162 | 163 | class GECInputExample(object): 164 | def __init__(self, guid, input_sequence, edit_sequence=None): 165 | """Constructs a InputExample.""" 166 | self.guid = guid 167 | self.input_sequence = input_sequence 168 | self.edit_sequence = edit_sequence 169 | 170 | 171 | class GECInputFeatures(object): 172 | def __init__(self, input_sequence, input_mask, segment_ids, edit_sequence): 173 | self.input_sequence = input_sequence 174 | self.input_mask = input_mask 175 | self.segment_ids = segment_ids 176 | #self.label_id = label_id 177 | self.edit_sequence = edit_sequence 178 | 179 | 180 | class DataProcessor(object): 181 | """Base class for data converters for sequence classification data sets.""" 182 | 183 | def get_train_examples(self, data_dir): 184 | """Gets a collection of `InputExample`s for the train set.""" 185 | raise NotImplementedError() 186 | 187 | def get_dev_examples(self, data_dir): 188 | """Gets a collection of `InputExample`s for the dev set.""" 189 | raise NotImplementedError() 190 | 191 | def get_test_examples(self, data_dir): 192 | """Gets a collection of `InputExample`s for prediction.""" 193 | raise NotImplementedError() 194 | 195 | @classmethod 196 | def _read_file(cls, input_file): 197 | """Reads a tab separated value file.""" 198 | with tf.gfile.Open(input_file, "r") as f: 199 | return (line for line in f) 200 | 201 | 202 | class GECProcessor(DataProcessor): 203 | def get_train_examples(self, data_dir): 204 | """See base class.""" 205 | train_incorr = self._read_file(os.path.join(data_dir, "train_incorr.txt")) 206 | train_labels = self._read_file(os.path.join(data_dir, "train_labels.txt")) 207 | return self._create_examples(train_incorr, train_labels, "train") 208 | 209 | def get_dev_examples(self, data_dir): 210 | """See base class.""" 211 | dev_incorr = self._read_file(os.path.join(data_dir, "dev_incorr.txt")) 212 | dev_labels = self._read_file(os.path.join(data_dir, "dev_labels.txt")) 213 | return self._create_examples(dev_incorr, dev_labels, "dev") 214 | 215 | def get_test_examples(self, data_dir): 216 | """See base class.""" 217 | test_incorr = self._read_file(os.path.join(data_dir, "test_incorr.txt")) 218 | #test_labels = self._read_file(os.path.join(data_dir, "test_labels.txt")) 219 | test_labels = None 220 | return self._create_examples(test_incorr, test_labels, "test") 221 | 222 | def _create_examples(self, incorr_lines, labels_lines, set_type): 223 | """Creates examples for the training and dev sets.""" 224 | if set_type != "test": 225 | for (i, (incorr_line, labels_line)) in enumerate(zip(incorr_lines, labels_lines)): 226 | guid = "%s-%s" % (set_type, i) 227 | input_sequence = incorr_line 228 | edit_sequence = labels_line 229 | yield GECInputExample(guid, input_sequence, edit_sequence) 230 | else: 231 | for (i, incorr_line) in enumerate(incorr_lines): 232 | guid = "%s-%s" % (set_type, i) 233 | input_sequence = incorr_line 234 | edit_sequence = None 235 | yield GECInputExample(guid, input_sequence, edit_sequence) 236 | 237 | def gec_convert_single_example(ex_index, example, max_seq_length): 238 | """Converts a single `InputExample` into a single `InputFeatures`.""" 239 | if isinstance(example, PaddingInputExample): 240 | return GECInputFeatures( 241 | input_sequence=[0] * max_seq_length, 242 | input_mask=[0] * max_seq_length, 243 | segment_ids=[0] * max_seq_length, 244 | edit_sequence=[0] * max_seq_length) 245 | 246 | input_sequence = list(map(int, example.input_sequence.strip().split())) 247 | if len(input_sequence) > max_seq_length: 248 | input_sequence = input_sequence[0:(max_seq_length)] 249 | 250 | if example.edit_sequence: 251 | edit_sequence = list(map(int, example.edit_sequence.strip().split())) 252 | if len(edit_sequence) > max_seq_length: 253 | edit_sequence = edit_sequence[0:(max_seq_length)] 254 | 255 | if len(input_sequence) != len(edit_sequence): 256 | print("This should ideally not happen") 257 | exit(1) 258 | else: 259 | edit_sequence = None 260 | 261 | input_mask = [1] * len(input_sequence) 262 | segment_ids = [0] * len(input_sequence) 263 | 264 | # Zero-pad up to the sequence length. 265 | while len(input_sequence) < max_seq_length: 266 | input_sequence.append(0) 267 | if edit_sequence: 268 | edit_sequence.append(0) 269 | input_mask.append(0) 270 | segment_ids.append(0) 271 | 272 | if not edit_sequence: 273 | edit_sequence = [-1] * max_seq_length 274 | 275 | assert len(input_sequence) == max_seq_length 276 | assert len(input_mask) == max_seq_length 277 | assert len(segment_ids) == max_seq_length 278 | assert len(edit_sequence) == max_seq_length 279 | 280 | if ex_index < 5: 281 | tf.logging.info("*** Example ***") 282 | tf.logging.info("guid: %s" % (example.guid)) 283 | tf.logging.info("input_sequence: %s" % " ".join([str(x) for x in input_sequence])) 284 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 285 | tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 286 | tf.logging.info("edit_sequence: %s" % " ".join([str(x) for x in edit_sequence])) 287 | 288 | feature = GECInputFeatures( 289 | input_sequence=input_sequence, 290 | input_mask=input_mask, 291 | segment_ids=segment_ids, 292 | edit_sequence=edit_sequence) 293 | return feature 294 | 295 | def gec_file_based_convert_examples_to_features( 296 | examples, max_seq_length, output_dir, mode, num_examples): 297 | """Convert a set of `InputExample`s to a TFRecord file.""" 298 | num_writers = 0 299 | writer = None 300 | for (ex_index, example) in enumerate(examples): 301 | if ex_index%10000==0: 302 | tf.logging.info("Writing example %d of %d" % (ex_index, num_examples)) 303 | if ex_index % 500000000000 == 0: 304 | if writer: 305 | writer.close() 306 | del writer 307 | output_file = os.path.join(output_dir, "{}_{}.tf_record".format(mode,num_writers)) 308 | writer = tf.python_io.TFRecordWriter(output_file) 309 | num_writers += 1 310 | 311 | feature = gec_convert_single_example(ex_index, example, max_seq_length) 312 | 313 | def create_int_feature(values): 314 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 315 | return f 316 | 317 | features = collections.OrderedDict() 318 | features["input_sequence"] = create_int_feature(feature.input_sequence) 319 | features["input_mask"] = create_int_feature(feature.input_mask) 320 | features["segment_ids"] = create_int_feature(feature.segment_ids) 321 | features["edit_sequence"] = create_int_feature(feature.edit_sequence) 322 | 323 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 324 | writer.write(tf_example.SerializeToString()) 325 | 326 | def gec_file_based_input_fn_builder(output_dir, mode, seq_length, 327 | is_training, drop_remainder): 328 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 329 | #output_dir_parts = PurePath(output_dir).parts 330 | #output_dir_path = "/home/awasthiabhijeet05/mnt_bucket/" + "/".join(output_dir_parts[2:]) 331 | #print(output_dir_path+"/"+"{}_*.tf_record".format(mode)) 332 | input_files = get_matching_files(output_dir+"/"+"{}_*.tf_record".format(mode)) 333 | print("INPUT_FILES: " + " AND ".join(input_files)) 334 | 335 | name_to_features = { 336 | "input_sequence": tf.FixedLenFeature([seq_length], tf.int64), 337 | "input_mask": tf.FixedLenFeature([seq_length], tf.int64), 338 | "segment_ids": tf.FixedLenFeature([seq_length], tf.int64), 339 | "edit_sequence": tf.FixedLenFeature([seq_length], tf.int64), 340 | } 341 | 342 | def _decode_record(record, name_to_features): 343 | """Decodes a record to a TensorFlow example.""" 344 | example = tf.parse_single_example(record, name_to_features) 345 | 346 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 347 | # So cast all int64 to int32. 348 | for name in list(example.keys()): 349 | t = example[name] 350 | if t.dtype == tf.int64: 351 | t = tf.to_int32(t) 352 | example[name] = t 353 | 354 | return example 355 | 356 | def input_fn(params): 357 | """The actual input function.""" 358 | batch_size = params["batch_size"] 359 | # For training, we want a lot of parallel reading and shuffling. 360 | # For eval, we want no shuffling and parallel reading doesn't matter. 361 | d = tf.data.TFRecordDataset(input_files) 362 | if is_training: 363 | d = d.repeat() 364 | d = d.shuffle(buffer_size=5000) 365 | 366 | d = d.apply( 367 | tf.contrib.data.map_and_batch( 368 | lambda record: _decode_record(record, name_to_features), 369 | batch_size=batch_size, 370 | drop_remainder=drop_remainder)) 371 | 372 | return d 373 | 374 | return input_fn 375 | 376 | def edit_word_embedding_lookup(embedding_table, input_ids, 377 | use_one_hot_embeddings, vocab_size, embedding_size): 378 | if use_one_hot_embeddings: 379 | one_hot_input_ids = tf.one_hot(input_ids, depth=vocab_size) 380 | output = tf.matmul(one_hot_input_ids, embedding_table) 381 | else: 382 | output = tf.nn.embedding_lookup(embedding_table, input_ids) 383 | return output 384 | 385 | 386 | def gec_create_model(bert_config, is_training, input_sequence, 387 | input_mask, segment_ids, edit_sequence, 388 | use_one_hot_embeddings, mode, 389 | copy_weight, 390 | use_bert_more, 391 | insert_ids, 392 | multitoken_insert_ids, 393 | subtract_replaced_from_replacement): 394 | """Creates a classification model.""" 395 | # insert_ids: word ids of unigram inserts (list) 396 | # multitoken_insert_ids: word_ids of bigram inserts (list of tuples of length 2) 397 | # Defining the space of all possible edits: 398 | # unk, sos and eos are dummy edits mapped to 0, 1 and 2 respectively 399 | # copy is mapped to 3 400 | # del is mapped to 4 401 | num_appends = len(insert_ids) + len(multitoken_insert_ids) 402 | num_replaces = num_appends # appends and replacements come from the same set (inserts and multitoken_inserts) 403 | append_begin = 5 # First append edit (mapped to 5) 404 | append_end = append_begin + num_appends - 1 #Last append edit 405 | rep_begin = append_end + 1 # First replace edit 406 | rep_end = rep_begin + num_replaces - 1 #Last replace edit 407 | num_suffix_transforms = 58 #num of transformation edits 408 | num_labels = 5 + num_appends + num_replaces + num_suffix_transforms # total number of edits 409 | print("************ num of labels : {} ***************".format(num_labels)) 410 | 411 | config = bert_config 412 | input_sequence_shape = modeling.get_shape_list(input_sequence,2) 413 | batch_size = input_sequence_shape[0] 414 | seq_len = input_sequence_shape[1] 415 | 416 | if not use_bert_more: #default use of bert (without logit factorisation) 417 | model = modeling.BertModel( 418 | config=bert_config, 419 | is_training=is_training, 420 | input_ids=input_sequence, 421 | input_mask=input_mask, 422 | token_type_ids=segment_ids, 423 | use_one_hot_embeddings=use_one_hot_embeddings) 424 | 425 | output_layer = model.get_sequence_output() 426 | else: # LOGIT FACTORISATION is On! 427 | model = modified_modeling.BertModel( 428 | config=bert_config, 429 | is_training=is_training, 430 | input_ids=input_sequence, 431 | input_mask=input_mask, 432 | token_type_ids=segment_ids, 433 | use_one_hot_embeddings=use_one_hot_embeddings) 434 | 435 | output_layer = model.get_sequence_output() 436 | replace_layer = output_layer[:,seq_len:2*seq_len,:] #representation of replacement slots as described in paper 437 | append_layer = output_layer[:,2*seq_len:3*seq_len,:] #representation of append slots as described in paper 438 | output_layer = output_layer[:,0:seq_len,:] 439 | 440 | output_layer_shape = modeling.get_shape_list(output_layer,3) 441 | hidden_size = output_layer_shape[-1] 442 | 443 | flattened_output_layer = tf.reshape(output_layer,[-1, hidden_size]) 444 | 445 | h_edit = flattened_output_layer 446 | 447 | if use_bert_more: 448 | h_word = flattened_output_layer 449 | flattened_replace_layer = tf.reshape(replace_layer,[-1, hidden_size]) 450 | flattened_append_layer = tf.reshape(append_layer, [-1, hidden_size]) 451 | 452 | m_replace = flattened_replace_layer 453 | m_append = flattened_append_layer 454 | 455 | 456 | with tf.variable_scope("cls/predictions"): 457 | with tf.variable_scope("transform"): 458 | h_word = tf.layers.dense( 459 | h_word, 460 | units=bert_config.hidden_size, 461 | activation=modeling.get_activation(bert_config.hidden_act), 462 | kernel_initializer=modeling.create_initializer( 463 | bert_config.initializer_range)) 464 | h_word = modeling.layer_norm(h_word) 465 | 466 | with tf.variable_scope("cls/predictions",reuse=True): 467 | with tf.variable_scope("transform",reuse=True): 468 | m_replace = tf.layers.dense( 469 | m_replace, 470 | units=bert_config.hidden_size, 471 | activation=modeling.get_activation(bert_config.hidden_act), 472 | kernel_initializer=modeling.create_initializer( 473 | bert_config.initializer_range)) 474 | m_replace = modeling.layer_norm(m_replace) 475 | 476 | with tf.variable_scope("cls/predictions",reuse=True): 477 | with tf.variable_scope("transform",reuse=True): 478 | m_append = tf.layers.dense( 479 | m_append, 480 | units=bert_config.hidden_size, 481 | activation=modeling.get_activation(bert_config.hidden_act), 482 | kernel_initializer=modeling.create_initializer( 483 | bert_config.initializer_range)) 484 | m_append = modeling.layer_norm(m_append) 485 | 486 | word_embedded_input = model.word_embedded_input 487 | flattened_word_embedded_input = tf.reshape(word_embedded_input, [-1, hidden_size]) 488 | 489 | labels = edit_sequence 490 | 491 | edit_weights = tf.get_variable( 492 | "edit_weights", [num_labels, hidden_size], 493 | initializer=tf.truncated_normal_initializer(stddev=0.02)) 494 | 495 | if is_training: 496 | h_edit = tf.nn.dropout(h_edit, keep_prob=0.9) 497 | 498 | if use_bert_more: 499 | # append/replace weight vector for a given append or replace operation 500 | # correspond to word embedding for its token argument 501 | # for multitoken append/replace (e.g. has been) 502 | # weight vector is sum of word embeddings of token arguments 503 | 504 | append_weights = edit_word_embedding_lookup(model.embedding_table, insert_ids, 505 | use_one_hot_embeddings, config.vocab_size, config.hidden_size) 506 | replace_weights = append_weights #tokens in replace and append vocab are same 507 | #(i.e. inserts and multitoken_inserts) 508 | 509 | multitoken_append_weights = wem_utils.edit_embedding_loopkup(model.embedding_table, multitoken_insert_ids, 510 | use_one_hot_embeddings, config.vocab_size, config.hidden_size) 511 | multitoken_replace_weights = multitoken_append_weights #tokens in replace and append vocab are same 512 | #(i.e. inserts and multitoken_inserts) 513 | 514 | append_weights = tf.concat([append_weights, multitoken_append_weights],0) 515 | replace_weights = tf.concat([replace_weights, multitoken_replace_weights],0) 516 | 517 | with tf.variable_scope("loss"): 518 | edit_logits = tf.matmul(h_edit, edit_weights, transpose_b=True) #first term in eq3 in paper 519 | logits = edit_logits 520 | if use_bert_more: 521 | 522 | #=============== inplace_word_logits==============# #2nd term in eq3 in paper 523 | inplace_logit = tf.reduce_sum(h_word * flattened_word_embedded_input, axis=1, keepdims=True) #copy 524 | #inplace_logit = tf.reduce_sum(m_replace * flattened_word_embedded_input, axis=1, keepdims=True) #copy 525 | inplace_logit_appends = tf.tile(inplace_logit,[1,num_appends]) 526 | inplace_logit_transforms = tf.tile(inplace_logit,[1,num_suffix_transforms]) 527 | zero_3_logits = tf.zeros([batch_size*seq_len,3]) #unk sos eos 528 | zero_1_logits = tf.zeros([batch_size*seq_len,1]) # del 529 | zero_replace_logits = tf.zeros([batch_size*seq_len,num_replaces]) 530 | 531 | concat_list = [zero_3_logits, inplace_logit, zero_1_logits]\ 532 | + [inplace_logit_appends]\ 533 | + [zero_replace_logits]\ 534 | + [inplace_logit_transforms] 535 | 536 | inplace_word_logits = tf.concat(concat_list,1) 537 | 538 | #======additional (insert,replace) logits ====# #3rd term in eqn3 in paper 539 | zero_5_logits = tf.zeros([batch_size*seq_len,5]) 540 | append_logits = tf.matmul(m_append, append_weights, transpose_b=True) 541 | 542 | if subtract_replaced_from_replacement: 543 | replace_logits = replacement_minus_replaced_logits(m_replace, 544 | flattened_word_embedded_input, replace_weights) 545 | else: 546 | replace_logits = tf.matmul(m_replace, replace_weights, transpose_b=True) 547 | 548 | suffix_logits = tf.zeros([batch_size*seq_len,num_suffix_transforms]) 549 | 550 | concat_list = [zero_5_logits, append_logits, replace_logits, suffix_logits] 551 | additional_logits = tf.concat(concat_list,1) 552 | #====================================================# 553 | 554 | logits = edit_logits + inplace_word_logits + additional_logits 555 | logits_bias = tf.get_variable("output_bias", shape=[num_labels], initializer=tf.zeros_initializer()) 556 | logits += logits_bias 557 | 558 | logits = tf.reshape(logits, [output_layer_shape[0], output_layer_shape[1], num_labels]) 559 | log_probs = tf.nn.log_softmax(logits, axis=-1) 560 | probs = tf.nn.softmax(logits,axis=-1) 561 | one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32) 562 | per_token_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) 563 | per_token_loss = per_token_loss * tf.to_float(input_mask) 564 | mask = copy_weight*tf.to_float(tf.equal(labels,3)) + tf.to_float(tf.not_equal(labels,3)) 565 | masked_per_token_loss = per_token_loss * mask 566 | per_example_loss = tf.reduce_sum(masked_per_token_loss, axis=-1) 567 | loss = tf.reduce_mean(per_example_loss) 568 | 569 | return (loss, per_example_loss, logits, probs) 570 | 571 | 572 | def replacement_minus_replaced_logits(replace_layer, word_embedded_input, weights): 573 | result_1 = tf.matmul(replace_layer, weights, transpose_b=True) 574 | result_2 = replace_layer * word_embedded_input 575 | result_2 = tf.reduce_sum(result_2,1) 576 | result_2 = tf.expand_dims(result_2,-1) 577 | return result_1 - result_2 578 | 579 | def gec_model_fn_builder(bert_config, init_checkpoint, learning_rate, 580 | num_train_steps, num_warmup_steps, use_tpu, 581 | use_one_hot_embeddings, copy_weight, 582 | use_bert_more, 583 | inserts, insert_ids, 584 | multitoken_inserts, multitoken_insert_ids, 585 | subtract_replaced_from_replacement): 586 | """Returns `model_fn` closure for TPUEstimator.""" 587 | 588 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 589 | """The `model_fn` for TPUEstimator.""" 590 | 591 | tf.logging.info("*** Features ***") 592 | for name in sorted(features.keys()): 593 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 594 | 595 | input_sequence = features["input_sequence"] 596 | input_mask = features["input_mask"] 597 | segment_ids = features["segment_ids"] 598 | edit_sequence = features["edit_sequence"] 599 | 600 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 601 | 602 | (total_loss, per_example_loss, logits, probabilities) = gec_create_model( 603 | bert_config, is_training, input_sequence, 604 | input_mask, segment_ids, edit_sequence, 605 | use_one_hot_embeddings, mode, 606 | copy_weight, 607 | use_bert_more, 608 | insert_ids, 609 | multitoken_insert_ids, 610 | subtract_replaced_from_replacement) 611 | 612 | tvars = tf.trainable_variables() 613 | initialized_variable_names = {} 614 | scaffold_fn = None 615 | if init_checkpoint: 616 | (assignment_map, initialized_variable_names 617 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 618 | if use_tpu: 619 | 620 | def tpu_scaffold(): 621 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 622 | return tf.train.Scaffold() 623 | 624 | scaffold_fn = tpu_scaffold 625 | else: 626 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 627 | 628 | tf.logging.info("**** Trainable Variables ****") 629 | for var in tvars: 630 | init_string = "" 631 | if var.name in initialized_variable_names: 632 | init_string = ", *INIT_FROM_CKPT*" 633 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 634 | init_string) 635 | 636 | output_spec = None 637 | if mode == tf.estimator.ModeKeys.TRAIN: 638 | train_op = optimization.create_optimizer( 639 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) 640 | 641 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 642 | mode=mode, 643 | loss=total_loss, 644 | train_op=train_op, 645 | scaffold_fn=scaffold_fn) 646 | 647 | elif mode == tf.estimator.ModeKeys.EVAL: 648 | def metric_fn(per_example_loss, edit_sequence, logits): 649 | predictions = tf.argmax(logits[:,:,3:], axis=-1, output_type=tf.int32) + 3 650 | mask = tf.equal(edit_sequence,0) 651 | mask = tf.logical_or(mask, tf.equal(edit_sequence,1)) 652 | mask = tf.logical_or(mask, tf.equal(edit_sequence,2)) 653 | mask = tf.logical_or(mask, tf.equal(edit_sequence,3)) 654 | mask = tf.to_float(tf.logical_not(mask)) 655 | accuracy = tf.metrics.accuracy(edit_sequence, predictions, mask) 656 | loss = tf.metrics.mean(per_example_loss) 657 | result_dict = {} 658 | result_dict["eval_accuracy"] = accuracy 659 | result_dict["eval_loss"] = loss 660 | return { 661 | "eval_accuracy": accuracy, 662 | "eval_loss": loss, 663 | } 664 | 665 | eval_metrics = (metric_fn, [per_example_loss, edit_sequence, logits]) 666 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 667 | mode=mode, 668 | loss=total_loss, 669 | eval_metrics=eval_metrics, 670 | scaffold_fn=scaffold_fn) 671 | else: 672 | #first three edit ids unk, sos, eos are dummy. We do not consider them in predictions 673 | predictions = tf.argmax(logits[:,:,3:], axis=-1, output_type=tf.int32) + 3 674 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 675 | mode=mode, 676 | predictions={"predictions": predictions, "logits":logits}, 677 | scaffold_fn=scaffold_fn) 678 | return output_spec 679 | 680 | return model_fn 681 | 682 | def get_file_length(file_address): 683 | num_lines = sum(1 for line in tf.gfile.GFile(file_address,"r")) 684 | return num_lines 685 | 686 | 687 | def main(_): 688 | 689 | tf.logging.set_verbosity(tf.logging.INFO) 690 | 691 | if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict: 692 | raise ValueError( 693 | "At least one of `do_train`, `do_eval` or `do_predict' must be True.") 694 | 695 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 696 | 697 | if FLAGS.max_seq_length > bert_config.max_position_embeddings: 698 | raise ValueError( 699 | "Cannot use sequence length %d because the BERT model " 700 | "was only trained up to sequence length %d" % 701 | (FLAGS.max_seq_length, bert_config.max_position_embeddings)) 702 | 703 | tf.gfile.MakeDirs(FLAGS.output_dir) 704 | 705 | processor = GECProcessor() 706 | 707 | tokenizer = tokenization.FullTokenizer( 708 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 709 | 710 | inserts = pickle.load(tf.gfile.Open(FLAGS.path_inserts,"rb")) 711 | insert_ids = tokenizer.convert_tokens_to_ids(inserts) 712 | 713 | multitoken_inserts = pickle.load(tf.gfile.Open(FLAGS.path_multitoken_inserts,"rb")) 714 | multitoken_insert_ids = wem_utils.list_to_ids(multitoken_inserts, tokenizer) 715 | 716 | tpu_cluster_resolver = None 717 | if FLAGS.use_tpu and FLAGS.tpu_name: 718 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 719 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 720 | 721 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 722 | run_config = tf.contrib.tpu.RunConfig( 723 | cluster=tpu_cluster_resolver, 724 | master=FLAGS.master, 725 | model_dir=FLAGS.output_dir, 726 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 727 | keep_checkpoint_max=15, 728 | tpu_config=tf.contrib.tpu.TPUConfig( 729 | iterations_per_loop=FLAGS.iterations_per_loop, 730 | num_shards=FLAGS.num_tpu_cores, 731 | per_host_input_for_training=is_per_host) 732 | ) 733 | 734 | train_examples = None 735 | num_train_steps = None 736 | num_warmup_steps = None 737 | 738 | if FLAGS.do_train: 739 | tf.set_random_seed(FLAGS.random_seed) 740 | if FLAGS.create_train_tf_records: 741 | train_examples = processor.get_train_examples(FLAGS.data_dir) 742 | num_train_examples = get_file_length(os.path.join(FLAGS.data_dir, "train_labels.txt")) 743 | print("Number of training examples: {}".format(num_train_examples)) 744 | num_train_steps = int( 745 | (num_train_examples / FLAGS.train_batch_size) * FLAGS.num_train_epochs) 746 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) 747 | 748 | model_fn = gec_model_fn_builder( 749 | bert_config=bert_config, 750 | init_checkpoint=FLAGS.init_checkpoint, 751 | learning_rate=FLAGS.learning_rate, 752 | num_train_steps=num_train_steps, 753 | num_warmup_steps=num_warmup_steps, 754 | use_tpu=FLAGS.use_tpu, 755 | use_one_hot_embeddings=FLAGS.use_tpu, 756 | copy_weight=FLAGS.copy_weight, 757 | use_bert_more=FLAGS.use_bert_more, 758 | inserts=inserts, 759 | insert_ids=insert_ids, 760 | multitoken_inserts=multitoken_inserts, 761 | multitoken_insert_ids=multitoken_insert_ids, 762 | subtract_replaced_from_replacement=FLAGS.subtract_replaced_from_replacement, 763 | ) 764 | 765 | # If TPU is not available, this will fall back to normal Estimator on CPU 766 | # or GPU. 767 | estimator = tf.contrib.tpu.TPUEstimator( 768 | use_tpu=FLAGS.use_tpu, 769 | model_fn=model_fn, 770 | config=run_config, 771 | train_batch_size=FLAGS.train_batch_size, 772 | eval_batch_size=FLAGS.eval_batch_size, 773 | predict_batch_size=FLAGS.predict_batch_size) 774 | 775 | if FLAGS.do_train: 776 | train_record_dir = FLAGS.output_dir 777 | if FLAGS.create_train_tf_records: 778 | gec_file_based_convert_examples_to_features( 779 | train_examples, FLAGS.max_seq_length, train_record_dir, "train", num_train_examples) 780 | 781 | tf.logging.info("***** Running training *****") 782 | tf.logging.info(" Num examples = %d", num_train_examples) 783 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 784 | tf.logging.info(" Num steps = %d", num_train_steps) 785 | 786 | train_input_fn = gec_file_based_input_fn_builder( 787 | output_dir=train_record_dir, 788 | mode="train", 789 | seq_length=FLAGS.max_seq_length, 790 | is_training=True, 791 | drop_remainder=True) 792 | #train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=num_train_steps) 793 | estimator.train(input_fn=train_input_fn, max_steps=num_train_steps) 794 | 795 | if FLAGS.do_eval: 796 | eval_examples = processor.get_dev_examples(FLAGS.data_dir) 797 | #eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record") 798 | num_eval_examples = get_file_length(os.path.join(FLAGS.data_dir, "dev_labels.txt")) 799 | gec_file_based_convert_examples_to_features( 800 | eval_examples, FLAGS.max_seq_length, FLAGS.output_dir, "eval", num_eval_examples) 801 | 802 | tf.logging.info("***** Running evaluation *****") 803 | tf.logging.info(" Num examples = %d", num_eval_examples) 804 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) 805 | 806 | # This tells the estimator to run through the entire set. 807 | eval_steps = None 808 | # However, if running eval on the TPU, you will need to specify the 809 | # number of steps. 810 | if FLAGS.use_tpu: 811 | # Eval will be slightly WRONG on the TPU because it will truncate 812 | # the last batch. 813 | eval_steps = int(num_eval_examples / FLAGS.eval_batch_size) 814 | 815 | eval_drop_remainder = True if FLAGS.use_tpu else False 816 | eval_input_fn = gec_file_based_input_fn_builder( 817 | output_dir=FLAGS.output_dir, 818 | mode="eval", 819 | seq_length=FLAGS.max_seq_length, 820 | is_training=False, 821 | drop_remainder=eval_drop_remainder) 822 | 823 | result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps, checkpoint_path=FLAGS.eval_checkpoint) 824 | 825 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 826 | with tf.gfile.GFile(output_eval_file, "w") as writer: 827 | tf.logging.info("***** Eval results *****") 828 | for key in sorted(result.keys()): 829 | tf.logging.info(" %s = %s", key, str(result[key])) 830 | writer.write("%s = %s\n" % (key, str(result[key]))) 831 | 832 | if FLAGS.do_predict: 833 | if FLAGS.create_predict_tf_records: 834 | predict_examples = processor.get_test_examples(FLAGS.data_dir) 835 | num_test_examples = get_file_length(os.path.join(FLAGS.data_dir, "test_incorr.txt")) 836 | print("num of test_examples: {}".format(num_test_examples)) 837 | num_actual_predict_examples = num_test_examples 838 | 839 | if FLAGS.create_predict_tf_records: 840 | if FLAGS.use_tpu: 841 | # Warning: According to tpu_estimator.py Prediction on TPU is an 842 | # experimental feature and hence not supported here 843 | #raise ValueError("Prediction in TPU not supported") 844 | padded_examples = [] 845 | 846 | while num_test_examples % FLAGS.predict_batch_size != 0: 847 | padded_examples.append(PaddingInputExample()) 848 | num_test_examples += 1 849 | 850 | iterables = [predict_examples, padded_examples] 851 | predict_examples = chain() 852 | for iterable in iterables: 853 | predict_examples = chain(predict_examples, iterable) 854 | 855 | 856 | #predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record") 857 | gec_file_based_convert_examples_to_features(predict_examples, 858 | FLAGS.max_seq_length, 859 | FLAGS.output_dir, "predict", num_test_examples) 860 | 861 | tf.logging.info("***** Running prediction*****") 862 | tf.logging.info(" Num examples = %d (%d actual, %d padding)", 863 | num_test_examples, num_actual_predict_examples, 864 | num_test_examples - num_actual_predict_examples) 865 | tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size) 866 | 867 | predict_drop_remainder = True if FLAGS.use_tpu else False 868 | predict_input_fn = gec_file_based_input_fn_builder( 869 | output_dir=FLAGS.output_dir, 870 | mode="predict", 871 | seq_length=FLAGS.max_seq_length, 872 | is_training=False, 873 | drop_remainder=predict_drop_remainder) 874 | #os.path.join(FLAGS.data_dir,"reverse_mtedit_dev_lang8_nucle_fce_wi_locness_output_0.1_ensemble_1") 875 | 876 | result = estimator.predict(input_fn=predict_input_fn, checkpoint_path=FLAGS.predict_checkpoint) 877 | print("type of result: {}".format(type(result))) 878 | 879 | output_predict_file = os.path.join(FLAGS.output_dir, "test_results.txt") 880 | output_logits_file = os.path.join(FLAGS.output_dir,"test_logits.npz") 881 | with tf.gfile.GFile(output_predict_file, "w") as writer: 882 | num_written_lines = 0 883 | #start_time = time.time() 884 | total_time_per_step = 0 885 | #probs_array=[] 886 | logits_array = [] 887 | tf.logging.info("***** Predict results *****") 888 | for i,(elapsed_time,prediction) in enumerate(wem_utils.timer(result)): 889 | if i >= num_actual_predict_examples: 890 | continue 891 | total_time_per_step += elapsed_time 892 | output_line = " ".join( 893 | str(edit) for edit in prediction["predictions"] if edit > 0) + "\n" 894 | #logits = np.array(prediction["logits"]) 895 | #logits_array.append(logits) 896 | writer.write(output_line) 897 | num_written_lines += 1 898 | assert num_written_lines == num_actual_predict_examples 899 | #with tf.gfile.GFile(output_logits_file,"w") as writer: 900 | #np.save(writer,np.array(logits_array)) 901 | tf.logging.info("Decoding time: {}".format(total_time_per_step)) 902 | 903 | 904 | if __name__ == "__main__": 905 | flags.mark_flag_as_required("data_dir") 906 | flags.mark_flag_as_required("vocab_file") 907 | flags.mark_flag_as_required("bert_config_file") 908 | flags.mark_flag_as_required("output_dir") 909 | tf.app.run() 910 | --------------------------------------------------------------------------------