├── read_text_data.py ├── encoding.py ├── grammar_test.py ├── LICENSE ├── language_train.py ├── grammar_train.py ├── process_text.py ├── irregular.py ├── attention.py ├── multihead.py ├── deepproof_model.py └── mistakes.py /read_text_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | import numpy as np 4 | import encoding 5 | import h5py 6 | import sys 7 | 8 | 9 | with h5py.File(sys.argv[1], 'r') as hf: 10 | input_text = hf['input'][:] 11 | output_text = hf['output'][:] 12 | 13 | maxlen = input_text.shape[1] 14 | #print(maxlen) 15 | for i in range(input_text.shape[0]): 16 | line = input_text[i,:] 17 | orig = output_text[i,:] 18 | #print (encoding.decode_string(line)) 19 | #print(encoding.decode_string(line)) 20 | #print(encoding.decode_string(orig), '\n') 21 | print(encoding.decode_string(line), '\t', encoding.decode_string(orig)) 22 | -------------------------------------------------------------------------------- /encoding.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | 4 | maxord = 8192 5 | #First three chars are special 6 | #0: begin/end of sentence/paragraph 7 | #1: truncated sentence/paragraph 8 | #2: Unknown char 9 | char_list = '|_~ ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789!"#$%&\'()*+,-./:;=?[]àéèêëïîìöôòûùỳçÀÉÈÊËÏÎÌÖÔÒÛÜÙỲÇ' 10 | 11 | #not in list: < > @ \ ^ _ ` { | } ~ TAB 12 | 13 | rev_list = dict( 14 | [(char, i) for i, char in enumerate(char_list[3:])]) 15 | 16 | charid = np.zeros(maxord+1, dtype='uint8') + 2 17 | for i, char in enumerate(char_list): 18 | charid[ord(char)] = i 19 | 20 | 21 | def encode_string(string, outlen, offset): 22 | strlen = len(string) 23 | out = np.zeros((outlen,), dtype='uint8') 24 | out[outlen-1] = 1 if strlen > outlen - 2 else 0 25 | out[0] = 1 if offset != 0 else 0 26 | 27 | copylen = min(strlen,outlen - 2) 28 | for i, char in enumerate(string[offset:copylen+offset]): 29 | out[i+1] = charid[min(maxord,ord(char))] 30 | return out 31 | 32 | def decode_string(enc): 33 | out = ''.join([char_list[x] for x in enc]) 34 | return out 35 | -------------------------------------------------------------------------------- /grammar_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | '''Sequence to sequence grammar check. 
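Reads sentences from standard input and prints, for each, the greedy and the beam-search correction produced by the trained model (weights loaded from proof8b4.h5). Usage: ./grammar_test.py < sentences.txt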
3 | ''' 4 | from __future__ import print_function 5 | 6 | import math 7 | from keras.models import Model 8 | from keras.layers import Input, LSTM, CuDNNLSTM, Dense, Embedding, Reshape, Concatenate, Lambda, Conv1D 9 | from keras import backend as K 10 | import numpy as np 11 | import h5py 12 | import sys 13 | import encoding 14 | import deepproof_model 15 | 16 | import tensorflow as tf 17 | from keras.backend.tensorflow_backend import set_session 18 | config = tf.ConfigProto() 19 | config.gpu_options.per_process_gpu_memory_fraction = 0.29 20 | set_session(tf.Session(config=config)) 21 | 22 | encoder_model, decoder_model, model = deepproof_model.create(False) 23 | 24 | model.load_weights('proof8b4.h5') 25 | 26 | 27 | for line in sys.stdin: 28 | line = line.rstrip() 29 | input_seq = encoding.encode_string(line, len(line)+20, 0) 30 | input_seq = np.reshape(input_seq, (1, input_seq.shape[0], 1)) 31 | decoded_sentence0 = deepproof_model.decode_sequence([encoder_model, decoder_model], input_seq) 32 | decoded_sentence = deepproof_model.beam_decode_sequence([encoder_model, decoder_model], input_seq) 33 | print('-') 34 | print('Input sentence: ', encoding.decode_string(input_seq[0, :, 0])) 35 | print('Decoded sentence0:', decoded_sentence0) 36 | print('Decoded sentence: ', decoded_sentence) 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | COPYRIGHT 2 | 3 | Copyright (c) 2017-2018 Jean-Marc Valin 4 | All rights reserved. 5 | 6 | The code is based on the following Keras example: 7 | https://github.com/keras-team/keras/blob/master/examples/lstm_seq2seq.py 8 | but was then heavily modified 9 | 10 | All other contributions: 11 | Copyright (c) 2015 - 2018, the respective contributors. 12 | All rights reserved. 13 | 14 | Each contributor holds copyright over their respective contributions. 15 | The project versioning (Git) records all such contribution source information. 16 | 17 | LICENSE 18 | 19 | The MIT License (MIT) 20 | 21 | Permission is hereby granted, free of charge, to any person obtaining a copy 22 | of this software and associated documentation files (the "Software"), to deal 23 | in the Software without restriction, including without limitation the rights 24 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 25 | copies of the Software, and to permit persons to whom the Software is 26 | furnished to do so, subject to the following conditions: 27 | 28 | The above copyright notice and this permission notice shall be included in all 29 | copies or substantial portions of the Software. 30 | 31 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 32 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 33 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 34 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 35 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 36 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 37 | SOFTWARE. 38 | 39 | -------------------------------------------------------------------------------- /language_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | '''Sequence to sequence grammar check. 
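Trains a character-level language model on the clean ('output') sentences of an HDF5 dataset produced by process_text.py and saves it as language.h5. Usage: ./language_train.py data.h5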
3 | ''' 4 | from __future__ import print_function 5 | 6 | import math 7 | from keras.models import Model 8 | from keras.layers import Input, LSTM, CuDNNLSTM, Dense, Embedding, Reshape, Concatenate, Lambda, Conv1D 9 | from keras import backend as K 10 | import numpy as np 11 | import h5py 12 | import sys 13 | import encoding 14 | 15 | import tensorflow as tf 16 | from keras.backend.tensorflow_backend import set_session 17 | config = tf.ConfigProto() 18 | config.gpu_options.per_process_gpu_memory_fraction = 0.29 19 | set_session(tf.Session(config=config)) 20 | 21 | embed_dim = 16 22 | batch_size = 128 # Batch size for training. 23 | epochs = 1 # Number of epochs to train for. 24 | latent_dim = 128 # Latent dimensionality of the encoding space. 25 | 26 | with h5py.File(sys.argv[1], 'r') as hf: 27 | output_text = hf['output'][:] 28 | decoder_target_data = np.reshape(output_text, (output_text.shape[0], output_text.shape[1], 1)) 29 | decoder_input_data = np.zeros((output_text.shape[0], output_text.shape[1], 1), dtype='uint8') 30 | decoder_input_data[:,1:,:] = decoder_target_data[:,:-1,:] 31 | max_decoder_seq_length = output_text.shape[1] 32 | num_encoder_tokens = len(encoding.char_list) 33 | 34 | print("Number of sentences: ", output_text.shape[0]) 35 | print("Sentence length: ", output_text.shape[1]) 36 | print("Number of chars: ", num_encoder_tokens) 37 | 38 | # Define an input sequence and process it. 39 | reshape1 = Reshape((-1, embed_dim)) 40 | embed = Embedding(num_encoder_tokens, embed_dim) 41 | conv = Conv1D(128, 5, padding='causal', activation='tanh') 42 | conv2 = Conv1D(128, 1, padding='causal', activation='tanh') 43 | 44 | # Set up the decoder, using `encoder_states` as initial state. 45 | decoder_inputs = Input(shape=(None, 1)) 46 | # We set up our decoder to return full output sequences, 47 | # and to return internal states as well. We don't use the 48 | # return states in the training model, but we will use them in inference. 49 | decoder_lstm = CuDNNLSTM(2*latent_dim, return_sequences=True) 50 | decoder_lstm2 = CuDNNLSTM(latent_dim, return_sequences=True) 51 | 52 | dec_lstm_input = conv2(conv(reshape1(embed(decoder_inputs)))) 53 | 54 | decoder_outputs = decoder_lstm2(decoder_lstm(dec_lstm_input)) 55 | decoder_dense = Dense(num_encoder_tokens, activation='softmax') 56 | decoder_outputs = decoder_dense(decoder_outputs) 57 | 58 | # Define the model that will turn 59 | # `encoder_input_data` & `decoder_input_data` into `decoder_target_data` 60 | model = Model(decoder_inputs, decoder_outputs) 61 | 62 | # Run training 63 | model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy']) 64 | model.summary() 65 | model.fit(decoder_input_data, decoder_target_data, 66 | batch_size=batch_size, 67 | epochs=epochs, 68 | validation_split=0.2) 69 | # Save model 70 | model.save('language.h5') 71 | #model.load_weights('s2s.h5') 72 | -------------------------------------------------------------------------------- /grammar_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | '''Sequence to sequence grammar check. 
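Trains the full encoder-decoder correction model on (corrupted, original) sentence pairs from one or more HDF5 datasets and saves checkpoints proof8b.h5 through proof8b4.h5. Usage: ./grammar_train.py data1.h5 [data2.h5 ...]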
3 | ''' 4 | from __future__ import print_function 5 | 6 | import math 7 | from keras.models import Model 8 | from keras.layers import Input, LSTM, CuDNNLSTM, Dense, Embedding, Reshape, Concatenate, Lambda, Conv1D 9 | from keras.optimizers import Adam 10 | from keras import backend as K 11 | import numpy as np 12 | import h5py 13 | import sys 14 | import encoding 15 | 16 | import deepproof_model 17 | 18 | import tensorflow as tf 19 | from keras.backend.tensorflow_backend import set_session 20 | config = tf.ConfigProto() 21 | config.gpu_options.per_process_gpu_memory_fraction = 0.44 22 | set_session(tf.Session(config=config)) 23 | 24 | batch_size = 128 # Batch size for training. 25 | epochs = 1 # Number of epochs to train for. 26 | 27 | encoder_model, decoder_model, model = deepproof_model.create(True) 28 | 29 | input_text = None 30 | output_text = None 31 | for file in sys.argv[1:]: 32 | with h5py.File(file, 'r') as hf: 33 | if input_text is None: 34 | input_text = hf['input'][:] 35 | output_text = hf['output'][:] 36 | else: 37 | input_text = np.concatenate([input_text, hf['input'][:]]) 38 | output_text = np.concatenate([output_text, hf['output'][:]]) 39 | #input_text = input_text[0:8000, :] 40 | #output_text = output_text[0:8000, :] 41 | input_data = np.reshape(input_text, (input_text.shape[0], input_text.shape[1], 1)) 42 | decoder_target_data = np.reshape(output_text, (output_text.shape[0], output_text.shape[1], 1)) 43 | decoder_input_data = np.zeros((input_text.shape[0], input_text.shape[1], 1), dtype='uint8') 44 | decoder_input_data[:,1:,:] = decoder_target_data[:,:-1,:] 45 | max_decoder_seq_length = input_text.shape[1] 46 | num_encoder_tokens = len(encoding.char_list) 47 | 48 | print("Number of sentences: ", input_text.shape[0]) 49 | print("Sentence length: ", input_text.shape[1]) 50 | print("Number of chars: ", num_encoder_tokens) 51 | 52 | # Run training 53 | model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy']) 54 | #model.load_weights('proof7c.h5') 55 | model.summary() 56 | model.fit([input_data[:,:,0:1], decoder_input_data], decoder_target_data, 57 | batch_size=batch_size, 58 | epochs=epochs, 59 | validation_split=0.2) 60 | # Save model 61 | model.save('proof8b.h5') 62 | model.compile(optimizer=Adam(0.0003), loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy']) 63 | model.fit([input_data[:,:,0:1], decoder_input_data], decoder_target_data, 64 | batch_size=batch_size, 65 | epochs=epochs, 66 | validation_split=0.2) 67 | model.save('proof8b2.h5') 68 | model.fit([input_data[:,:,0:1], decoder_input_data], decoder_target_data, 69 | batch_size=batch_size, 70 | epochs=epochs, 71 | validation_split=0.2) 72 | model.save('proof8b3.h5') 73 | model.fit([input_data[:,:,0:1], decoder_input_data], decoder_target_data, 74 | batch_size=batch_size, 75 | epochs=epochs, 76 | validation_split=0.2) 77 | model.save('proof8b4.h5') 78 | 79 | 80 | 81 | start = int(.9*input_text.shape[0]) 82 | for seq_index in range(start, start+1000): 83 | # Take one sequence (part of the training test) 84 | # for trying out decoding. 
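# Greedy decoding, beam-search decoding and the ground-truth log-likelihood are computed and printed for each sentence.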
85 | input_seq = input_data[seq_index: seq_index + 1] 86 | decoded_sentence0 = deepproof_model.decode_sequence([encoder_model, decoder_model], input_seq) 87 | decoded_sentence = deepproof_model.beam_decode_sequence([encoder_model, decoder_model], input_seq) 88 | deepproof_model.decode_ground_truth([encoder_model, decoder_model], input_seq, output_text[seq_index,:]) 89 | print('-') 90 | print('Input sentence: ', encoding.decode_string(input_text[seq_index,:])) 91 | print('Decoded sentence0:', decoded_sentence0) 92 | print('Decoded sentence: ', decoded_sentence) 93 | print('Original sentence:', encoding.decode_string(output_text[seq_index,:])) 94 | -------------------------------------------------------------------------------- /process_text.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | import random 4 | import re 5 | import numpy as np 6 | from mistakes import * 7 | import encoding 8 | import h5py 9 | from irregular import irregular_verbs 10 | 11 | maxlen = 300 12 | minlen = 80 13 | frac = .4 14 | 15 | text = [] 16 | 17 | print("Computing lines", file=sys.stderr) 18 | for line in sys.stdin: 19 | if random.random() > frac: 20 | continue 21 | line = line.rstrip() 22 | if len(line) < minlen: 23 | continue 24 | if line.find('ISBN') >= 0: 25 | continue 26 | if line.find('University Press') >= 0: 27 | continue 28 | if re.match("^[0-9][0-9][0-9][0-9]", line): 29 | continue 30 | if re.match("\([0-9][0-9][0-9][0-9]\)", line): 31 | continue 32 | line = " ".join(line.split()) 33 | line = line.replace("`", "'") 34 | 35 | #modify the "correct" sentence to randomly add/remove contractions 36 | line = word_substitute(line, acceptable_contractions, 0.2) 37 | 38 | orig_len = strlen = len(line) 39 | #print(line) 40 | chop_begin = chop_end = False 41 | if strlen > maxlen - 2: 42 | c = random.randrange(3) 43 | if c == 0: 44 | pos = 0 45 | elif c == 1: 46 | pos = strlen - maxlen + 2 47 | else: 48 | pos = random.randrange(strlen - maxlen + 2) 49 | if pos > 0 and line[pos-1] != ' ': 50 | pos = pos + line[pos:].find(' ') + 1 51 | chop_begin = True if pos > 0 else False 52 | line = line[pos:] 53 | strlen = len(line) 54 | if strlen > maxlen - 2: 55 | chop_end = True 56 | end = maxlen - 2 57 | while line[end] != ' ' and end > 0: 58 | end -= 1 59 | line = line[:end] 60 | if len(line) < minlen: 61 | continue 62 | 63 | orig = line; 64 | #print (orig) 65 | #continue 66 | line = word_substitute(line, irregular_rules, 0.2) 67 | line = word_substitute(line, regular_verbs, 0.2) 68 | line = word_substitute(line, homonyms_rules, 0.2) 69 | line = word_substitute(line, prepositions_rules, 0.2) 70 | line = word_substitute(line, misc_rules, 0.2) 71 | line = word_substitute(line, comparison_rules, 0.2) 72 | line = word_delete(line, omitted_words, 0.02) 73 | line = word_double(line, omitted_words, 0.02) 74 | line = strip_plural(line, 0.2) 75 | line = add_plural(line, 0.02) 76 | line = strip_punctuation(line, 0.2); 77 | line = add_comma(line, 0.02); 78 | line = subword_substitute(line, subword_subst, 0.02) 79 | line = letter_deletion(line, 0.004) 80 | line = letter_doubling(line, 0.004) 81 | line = letter_swap(line, 0.004) 82 | line = letter_subst(line, 0.002) 83 | if len(text) % 1000000 == 0: 84 | print(len(text), file=sys.stderr) 85 | 86 | text.append((line, orig, chop_begin, chop_end, orig_len)) 87 | #if len(text) > 1000: 88 | # break 89 | print (line, '\t', orig) 90 | 91 | print("Encoding lines", file=sys.stderr) 92 | input_text = np.zeros((len(text), maxlen), 
dtype='uint8') 93 | output_text = np.zeros((len(text), maxlen), dtype='uint8') 94 | for i, entry in enumerate(text): 95 | line, orig, chop_begin, chop_end, orig_len = entry 96 | byte_line = encoding.encode_string(line, maxlen, 0) 97 | byte_orig = encoding.encode_string(orig, maxlen, 0) 98 | if chop_begin: 99 | byte_line[0] = 1 100 | byte_orig[0] = 1 101 | if chop_end: 102 | byte_orig[len(orig)+1] = 1 103 | if len(line)+1 < maxlen: 104 | byte_line[len(line)+1] = 1 105 | input_text[i,:] = byte_line 106 | output_text[i,:] = byte_orig 107 | if i % 1000000 == 0: 108 | print(i, file=sys.stderr) 109 | #print (orig_len, encoding.decode_string(byte_line), '\t', encoding.decode_string(byte_orig)) 110 | #print() 111 | 112 | h5f = h5py.File(sys.argv[1], 'w'); 113 | h5f.create_dataset('input', data=input_text) 114 | h5f.create_dataset('output', data=output_text) 115 | h5f.close() 116 | -------------------------------------------------------------------------------- /irregular.py: -------------------------------------------------------------------------------- 1 | irregular_verbs = [ 2 | ["be", "was", "were", "been"], 3 | ["bear", "bore", "borne", "born"], 4 | ["beat", "beat", "beaten"], 5 | ["become", "became", "become"], 6 | ["begin", "began", "begun"], 7 | ["bend", "bent", "bent"], 8 | ["bet", "bet", "bet"], 9 | ["bid", "bid", "bade", "bid", "bidden"], 10 | ["bind", "bound", "bound"], 11 | ["bite", "bit", "bitten"], 12 | ["bleed", "bled", "bled"], 13 | ["blow", "blew", "blown"], 14 | ["break", "broke", "broken"], 15 | ["breed", "bred", "bred"], 16 | ["bring", "brought", "brought"], 17 | ["broadcast", "broadcast", "broadcast"], 18 | ["build", "built", "built"], 19 | ["burn", "burnt", "burnt"], 20 | ["burst", "burst", "burst"], 21 | ["bust", "bust", "bust"], 22 | ["buy", "bought", "bought"], 23 | ["cast", "cast", "cast"], 24 | ["catch", "caught", "caught"], 25 | ["choose", "chose", "chosen"], 26 | ["cling", "clung", "clung"], 27 | ["come", "came", "come"], 28 | ["cost", "cost", "cost"], 29 | ["creep", "crept", "crept"], 30 | ["cut", "cut", "cut"], 31 | ["deal", "dealt", "dealt"], 32 | ["dig", "dug", "dug"], 33 | ["dive", "dived", "dove", "dived"], 34 | ["do", "does", "did", "done"], 35 | ["draw", "drew", "drawn"], 36 | ["dream", "dreamt", "dreamt"], 37 | ["drink", "drank", "drunk"], 38 | ["drive", "drove", "driven"], 39 | ["eat", "ate", "eaten"], 40 | ["fall", "fell", "fallen"], 41 | ["feed", "fed", "fed"], 42 | ["feel", "felt", "felt"], 43 | ["fight", "fought", "fought"], 44 | ["find", "found", "found"], 45 | ["flee", "fled", "fled"], 46 | ["fling", "flung", "flung"], 47 | ["fly", "flew", "flown"], 48 | ["forbid", "forbade", "forbad", "forbidden"], 49 | ["forecast", "forecast", "forecast"], 50 | ["forget", "forgot", "forgotten"], 51 | ["forsake", "forsook", "forsaken"], 52 | ["freeze", "froze", "frozen"], 53 | ["get", "got", "got", "gotten"], 54 | ["give", "gave", "given"], 55 | ["grind", "ground", "ground"], 56 | ["go", "goes", "went", "gone"], 57 | ["grow", "grew", "grown"], 58 | ["hang", "hung", "hung"], 59 | ["have", "has", "had", "had"], 60 | ["hear", "heard", "heard"], 61 | ["hide", "hid", "hidden"], 62 | ["hit", "hit", "hit"], 63 | ["hold", "held", "held"], 64 | ["hurt", "hurt", "hurt"], 65 | ["keep", "kept", "kept"], 66 | ["know", "knew", "known"], 67 | ["lay", "laid", "laid"], 68 | ["lead", "led", "led"], 69 | ["learn", "learnt", "learnt"], 70 | ["leave", "left", "left"], 71 | ["lend", "lent", "lent"], 72 | ["let", "let", "let"], 73 | ["lie", "lay", "lain"], 74 | ["light", "lit", "lit"], 75 | 
["lose", "lost", "lost"], 76 | ["make", "made", "made"], 77 | ["mean", "meant", "meant"], 78 | ["meet", "met", "met"], 79 | ["pay", "paid", "paid"], 80 | ["prove", "proved", "proven"], 81 | ["put", "put", "put"], 82 | ["quit", "quit", "quit"], 83 | ["read", "read", "read"], 84 | ["rid", "rid", "rid"], 85 | ["ride", "rode", "ridden"], 86 | ["ring", "rang", "rung"], 87 | ["rise", "rose", "risen"], 88 | ["run", "ran", "run"], 89 | ["say", "said", "said"], 90 | ["see", "saw", "seen"], 91 | ["seek", "sought", "sought"], 92 | ["sell", "sold", "sold"], 93 | ["send", "sent", "sent"], 94 | ["set", "set", "set"], 95 | ["sew", "sewed", "sewn"], 96 | ["shake", "shook", "shaken"], 97 | ["shear", "sheared", "shorn"], 98 | ["shed", "shed", "shed"], 99 | ["shine", "shone", "shone"], 100 | ["shoot", "shot", "shot"], 101 | ["show", "showed", "shown"], 102 | ["shut", "shut", "shut"], 103 | ["sing", "sang", "sung"], 104 | ["sink", "sank", "sunk"], 105 | ["sit", "sat", "sat"], 106 | ["slay", "slew", "slain"], 107 | ["sleep", "slept", "slept"], 108 | ["slide", "slid", "slid"], 109 | ["sling", "slung", "slung"], 110 | ["slink", "slunk", "slunk"], 111 | ["slit", "slit", "slit"], 112 | ["sow", "sowed", "sown"], 113 | ["speak", "spoke", "spoken"], 114 | ["speed", "sped", "sped"], 115 | ["spend", "spent", "spent"], 116 | ["spin", "spun", "spun"], 117 | ["spit", "spat", "spit", "spat", "spit"], 118 | ["split", "split", "split"], 119 | ["spread", "spread", "spread"], 120 | ["spring", "sprang", "sprung"], 121 | ["stand", "stood", "stood"], 122 | ["steal", "stole", "stolen"], 123 | ["stick", "stuck", "stuck"], 124 | ["sting", "stung", "stung"], 125 | ["stink", "stank", "stunk", "stunk"], 126 | ["stride", "strode", "stridden"], 127 | ["strike", "struck", "struck"], 128 | ["string", "strung", "strung"], 129 | ["strive", "strove", "striven"], 130 | ["swear", "swore", "sworn"], 131 | ["sweep", "swept", "swept"], 132 | ["swell", "swelled", "swollen"], 133 | ["swim", "swam", "swum"], 134 | ["swing", "swung", "swung"], 135 | ["take", "took", "taken"], 136 | ["teach", "taught", "taught"], 137 | ["tear", "tore", "torn"], 138 | ["tell", "told", "told"], 139 | ["think", "thought", "thought"], 140 | ["thrive", "throve", "thrived"], 141 | ["throw", "threw", "thrown"], 142 | ["thrust", "thrust", "thrust"], 143 | ["tread", "trod", "trodden", "trod"], 144 | ["understand", "understood", "understood"], 145 | ["wake", "woke", "woken"], 146 | ["wear", "wore", "worn"], 147 | ["weave", "wove", "woven"], 148 | ["weep", "wept", "wept"], 149 | ["wet", "wet", "wet"], 150 | ["win", "won", "won"], 151 | ["wind", "wound", "wound"], 152 | ["wring", "wrung", "wrung"], 153 | ["write", "wrote", "written"] 154 | ] 155 | -------------------------------------------------------------------------------- /attention.py: -------------------------------------------------------------------------------- 1 | from keras import backend as K 2 | from keras.engine.topology import Layer 3 | from keras.layers import activations, initializers, regularizers, constraints, InputSpec 4 | import numpy as np 5 | import math 6 | 7 | class Attention(Layer): 8 | """Just your regular densely-connected NN layer. 9 | 10 | # Arguments 11 | units: Positive integer, dimensionality of the output space. 12 | activation: Activation function to use 13 | (see [activations](../activations.md)). 14 | If you don't specify anything, no activation is applied 15 | (ie. "linear" activation: `a(x) = x`). 16 | use_bias: Boolean, whether the layer uses a bias vector. 
17 | kernel_initializer: Initializer for the `kernel` weights matrix 18 | (see [initializers](../initializers.md)). 19 | bias_initializer: Initializer for the bias vector 20 | (see [initializers](../initializers.md)). 21 | kernel_regularizer: Regularizer function applied to 22 | the `kernel` weights matrix 23 | (see [regularizer](../regularizers.md)). 24 | bias_regularizer: Regularizer function applied to the bias vector 25 | (see [regularizer](../regularizers.md)). 26 | activity_regularizer: Regularizer function applied to 27 | the output of the layer (its "activation"). 28 | (see [regularizer](../regularizers.md)). 29 | kernel_constraint: Constraint function applied to 30 | the `kernel` weights matrix 31 | (see [constraints](../constraints.md)). 32 | bias_constraint: Constraint function applied to the bias vector 33 | (see [constraints](../constraints.md)). 34 | 35 | # Input shape 36 | nD tensor with shape: `(batch_size, ..., input_dim)`. 37 | The most common situation would be 38 | a 2D input with shape `(batch_size, input_dim)`. 39 | 40 | # Output shape 41 | nD tensor with shape: `(batch_size, ..., units)`. 42 | For instance, for a 2D input with shape `(batch_size, input_dim)`, 43 | the output would have shape `(batch_size, units)`. 44 | """ 45 | 46 | def __init__(self, units, 47 | activation=None, 48 | use_bias=True, 49 | kernel_initializer='glorot_uniform', 50 | bias_initializer='zeros', 51 | kernel_regularizer=None, 52 | bias_regularizer=None, 53 | activity_regularizer=None, 54 | kernel_constraint=None, 55 | bias_constraint=None, 56 | **kwargs): 57 | if 'input_shape' not in kwargs and 'input_dim' in kwargs: 58 | kwargs['input_shape'] = (kwargs.pop('input_dim'),) 59 | super(Attention, self).__init__(**kwargs) 60 | self.units = units 61 | self.scaling = 1/math.sqrt(self.units) 62 | self.activation = activations.get(activation) 63 | self.use_bias = use_bias 64 | self.kernel_initializer = initializers.get(kernel_initializer) 65 | self.bias_initializer = initializers.get(bias_initializer) 66 | self.kernel_regularizer = regularizers.get(kernel_regularizer) 67 | self.bias_regularizer = regularizers.get(bias_regularizer) 68 | self.activity_regularizer = regularizers.get(activity_regularizer) 69 | self.kernel_constraint = constraints.get(kernel_constraint) 70 | self.bias_constraint = constraints.get(bias_constraint) 71 | self.input_spec = InputSpec(min_ndim=2) 72 | self.supports_masking = True 73 | self.input_spec = [InputSpec(min_ndim=3), InputSpec(min_ndim=3), InputSpec(min_ndim=3)] 74 | 75 | def build(self, input_shape): 76 | assert len(input_shape) >= 2 77 | query_dim = input_shape[0][-1] 78 | key_dim = input_shape[1][-1] 79 | value_dim = input_shape[2][-1] 80 | 81 | self.query_kernel = self.add_weight(shape=(query_dim, self.units), 82 | initializer=self.kernel_initializer, 83 | name='query_kernel', 84 | regularizer=self.kernel_regularizer, 85 | constraint=self.kernel_constraint) 86 | self.key_kernel = self.add_weight(shape=(key_dim, self.units), 87 | initializer=self.kernel_initializer, 88 | name='key_kernel', 89 | regularizer=self.kernel_regularizer, 90 | constraint=self.kernel_constraint) 91 | if self.use_bias: 92 | self.query_bias = self.add_weight(shape=(self.units,), 93 | initializer=self.bias_initializer, 94 | name='query_bias', 95 | regularizer=self.bias_regularizer, 96 | constraint=self.bias_constraint) 97 | self.key_bias = self.add_weight(shape=(self.units,), 98 | initializer=self.bias_initializer, 99 | name='key_bias', 100 | regularizer=self.bias_regularizer, 101 | 
constraint=self.bias_constraint) 102 | else: 103 | self.query_bias = None 104 | self.key_bias = None 105 | super(Attention, self).build(input_shape) 106 | 107 | def call(self, inputs): 108 | queries, keys, values = inputs 109 | q = K.dot(queries, self.query_kernel) 110 | k = K.dot(keys, self.key_kernel) 111 | if self.use_bias: 112 | q = K.bias_add(q, self.query_bias) 113 | k = K.bias_add(k, self.key_bias) 114 | if self.activation is not None: 115 | q = self.activation(q) 116 | weights = K.softmax(self.scaling*K.batch_dot(q, k, axes=[2,2])) 117 | output = K.batch_dot(weights, values) 118 | return output 119 | 120 | def compute_output_shape(self, input_shape): 121 | assert input_shape and len(input_shape) >= 2 122 | assert input_shape[-1] 123 | output_shape = list(input_shape[0]) 124 | output_shape[-1] = input_shape[2][-1] 125 | return tuple(output_shape) 126 | 127 | def get_config(self): 128 | config = { 129 | 'units': self.units, 130 | 'activation': activations.serialize(self.activation), 131 | 'use_bias': self.use_bias, 132 | 'kernel_initializer': initializers.serialize(self.kernel_initializer), 133 | 'bias_initializer': initializers.serialize(self.bias_initializer), 134 | 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 135 | 'bias_regularizer': regularizers.serialize(self.bias_regularizer), 136 | 'activity_regularizer': regularizers.serialize(self.activity_regularizer), 137 | 'kernel_constraint': constraints.serialize(self.kernel_constraint), 138 | 'bias_constraint': constraints.serialize(self.bias_constraint) 139 | } 140 | base_config = super(Attention, self).get_config() 141 | return dict(list(base_config.items()) + list(config.items())) 142 | 143 | -------------------------------------------------------------------------------- /multihead.py: -------------------------------------------------------------------------------- 1 | from keras import backend as K 2 | from keras.engine.topology import Layer 3 | from keras.layers import activations, initializers, regularizers, constraints, InputSpec 4 | import numpy as np 5 | import math 6 | 7 | class MultiHead(Layer): 8 | """Just your regular densely-connected NN layer. 9 | 10 | # Arguments 11 | units: Positive integer, dimensionality of the output space. 12 | activation: Activation function to use 13 | (see [activations](../activations.md)). 14 | If you don't specify anything, no activation is applied 15 | (ie. "linear" activation: `a(x) = x`). 16 | use_bias: Boolean, whether the layer uses a bias vector. 17 | kernel_initializer: Initializer for the `kernel` weights matrix 18 | (see [initializers](../initializers.md)). 19 | bias_initializer: Initializer for the bias vector 20 | (see [initializers](../initializers.md)). 21 | kernel_regularizer: Regularizer function applied to 22 | the `kernel` weights matrix 23 | (see [regularizer](../regularizers.md)). 24 | bias_regularizer: Regularizer function applied to the bias vector 25 | (see [regularizer](../regularizers.md)). 26 | activity_regularizer: Regularizer function applied to 27 | the output of the layer (its "activation"). 28 | (see [regularizer](../regularizers.md)). 29 | kernel_constraint: Constraint function applied to 30 | the `kernel` weights matrix 31 | (see [constraints](../constraints.md)). 32 | bias_constraint: Constraint function applied to the bias vector 33 | (see [constraints](../constraints.md)). 34 | 35 | # Input shape 36 | nD tensor with shape: `(batch_size, ..., input_dim)`. 
37 | The most common situation would be 38 | a 2D input with shape `(batch_size, input_dim)`. 39 | 40 | # Output shape 41 | nD tensor with shape: `(batch_size, ..., units)`. 42 | For instance, for a 2D input with shape `(batch_size, input_dim)`, 43 | the output would have shape `(batch_size, units)`. 44 | """ 45 | 46 | def __init__(self, units, 47 | activation=None, 48 | use_bias=True, 49 | kernel_initializer='glorot_uniform', 50 | bias_initializer='zeros', 51 | kernel_regularizer=None, 52 | bias_regularizer=None, 53 | activity_regularizer=None, 54 | kernel_constraint=None, 55 | bias_constraint=None, 56 | **kwargs): 57 | if 'input_shape' not in kwargs and 'input_dim' in kwargs: 58 | kwargs['input_shape'] = (kwargs.pop('input_dim'),) 59 | super(MultiHead, self).__init__(**kwargs) 60 | self.units = units 61 | self.heads = 8 62 | self.activation = activations.get(activation) 63 | self.use_bias = use_bias 64 | self.kernel_initializer = initializers.get(kernel_initializer) 65 | self.bias_initializer = initializers.get(bias_initializer) 66 | self.kernel_regularizer = regularizers.get(kernel_regularizer) 67 | self.bias_regularizer = regularizers.get(bias_regularizer) 68 | self.activity_regularizer = regularizers.get(activity_regularizer) 69 | self.kernel_constraint = constraints.get(kernel_constraint) 70 | self.bias_constraint = constraints.get(bias_constraint) 71 | self.input_spec = InputSpec(min_ndim=2) 72 | self.supports_masking = True 73 | self.input_spec = [InputSpec(min_ndim=3), InputSpec(min_ndim=3), InputSpec(min_ndim=3)] 74 | 75 | def build(self, input_shape): 76 | self.heads = input_shape[2][-1]//self.units 77 | self.scaling = 1/math.sqrt(self.units) 78 | assert len(input_shape) >= 2 79 | query_dim = input_shape[0][-1] 80 | key_dim = input_shape[1][-1] 81 | value_dim = input_shape[2][-1] 82 | 83 | self.query_kernel = self.add_weight(shape=(self.heads, query_dim, self.units), 84 | initializer=self.kernel_initializer, 85 | name='query_kernel', 86 | regularizer=self.kernel_regularizer, 87 | constraint=self.kernel_constraint) 88 | self.key_kernel = self.add_weight(shape=(self.heads, key_dim, self.units), 89 | initializer=self.kernel_initializer, 90 | name='key_kernel', 91 | regularizer=self.kernel_regularizer, 92 | constraint=self.kernel_constraint) 93 | self.value_kernel = self.add_weight(shape=(self.heads, key_dim, self.units), 94 | initializer=self.kernel_initializer, 95 | name='value_kernel', 96 | regularizer=self.kernel_regularizer, 97 | constraint=self.kernel_constraint) 98 | if self.use_bias: 99 | self.query_bias = self.add_weight(shape=(self.heads, self.units), 100 | initializer=self.bias_initializer, 101 | name='query_bias', 102 | regularizer=self.bias_regularizer, 103 | constraint=self.bias_constraint) 104 | self.key_bias = self.add_weight(shape=(self.heads, self.units), 105 | initializer=self.bias_initializer, 106 | name='key_bias', 107 | regularizer=self.bias_regularizer, 108 | constraint=self.bias_constraint) 109 | else: 110 | self.query_bias = None 111 | self.key_bias = None 112 | super(MultiHead, self).build(input_shape) 113 | 114 | def call(self, inputs): 115 | queries, keys, values = inputs 116 | out_list = [] 117 | query = K.dot(queries, self.query_kernel) 118 | key = K.dot(keys, self.key_kernel) 119 | value = K.dot(values, self.value_kernel) 120 | if self.use_bias: 121 | query = query + self.query_bias 122 | key = key + self.key_bias 123 | if self.activation is not None: 124 | query = self.scaling*self.activation(query) 125 | for i in range(self.heads): 126 | weights = 
K.softmax(K.batch_dot(query[:, :, i, :], key[:, :, i, :], axes=[2,2])) 127 | out = K.batch_dot(weights, value[:, :, i, :]) 128 | out_list.append(out) 129 | output = K.concatenate(out_list, axis=-1) 130 | return output 131 | 132 | def compute_output_shape(self, input_shape): 133 | assert input_shape and len(input_shape) >= 2 134 | assert input_shape[-1] 135 | output_shape = list(input_shape[0]) 136 | output_shape[-1] = input_shape[2][-1] 137 | return tuple(output_shape) 138 | 139 | def get_config(self): 140 | config = { 141 | 'units': self.units, 142 | 'activation': activations.serialize(self.activation), 143 | 'use_bias': self.use_bias, 144 | 'kernel_initializer': initializers.serialize(self.kernel_initializer), 145 | 'bias_initializer': initializers.serialize(self.bias_initializer), 146 | 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 147 | 'bias_regularizer': regularizers.serialize(self.bias_regularizer), 148 | 'activity_regularizer': regularizers.serialize(self.activity_regularizer), 149 | 'kernel_constraint': constraints.serialize(self.kernel_constraint), 150 | 'bias_constraint': constraints.serialize(self.bias_constraint) 151 | } 152 | base_config = super(MultiHead, self).get_config() 153 | return dict(list(base_config.items()) + list(config.items())) 154 | 155 | -------------------------------------------------------------------------------- /deepproof_model.py: -------------------------------------------------------------------------------- 1 | import math 2 | from keras.models import Model 3 | from keras.layers import Input, LSTM, CuDNNLSTM, Dense, Embedding, Reshape, Concatenate, Lambda, Conv1D, Multiply, Bidirectional, MaxPooling1D, Activation 4 | from keras import backend as K 5 | import numpy as np 6 | import h5py 7 | import sys 8 | import encoding 9 | from attention import Attention 10 | 11 | embed_dim = 64 12 | encoder_dim = 384 13 | latent_dim = 512 # Latent dimensionality of the encoding space. 14 | attn_dim = 128 15 | num_encoder_tokens = len(encoding.char_list) 16 | 17 | def create(use_gpu): 18 | # Define an input sequence and process it. 19 | encoder_inputs = Input(shape=(None, 1)) 20 | reshape1 = Reshape((-1, embed_dim)) 21 | reshape2 = Reshape((-1, embed_dim)) 22 | conv1a = Conv1D(latent_dim, 11, padding='same', activation='tanh') 23 | conv1b = Conv1D(latent_dim, 11, padding='same', activation='sigmoid') 24 | embed = Embedding(num_encoder_tokens, embed_dim) 25 | if use_gpu: 26 | encoder = CuDNNLSTM(encoder_dim, return_sequences=True) 27 | encoder2 = CuDNNLSTM(encoder_dim, return_sequences=True) 28 | else: 29 | encoder = LSTM(encoder_dim, recurrent_activation="sigmoid", return_sequences=True) 30 | encoder2 = LSTM(encoder_dim, recurrent_activation="sigmoid", return_sequences=True) 31 | encoder = Bidirectional(encoder, merge_mode='concat') 32 | encoder2 = Bidirectional(encoder2, merge_mode='concat') 33 | emb = reshape1(embed(encoder_inputs)); 34 | c1a = conv1a(emb) 35 | c1b = conv1b(emb) 36 | encoder_outputs = encoder(Multiply()([c1a, c1b])) 37 | 38 | encoder_outputs = encoder2(encoder_outputs) 39 | 40 | decoder_inputs = Input(shape=(None, 1)) 41 | # We set up our decoder to return full output sequences, 42 | # and to return internal states as well. We don't use the 43 | # return states in the training model, but we will use them in inference. 
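# The decoder stacks a language-model LSTM, scaled dot-product attention over the bidirectional encoder outputs, and a second LSTM followed by a softmax over the character set.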
44 | if use_gpu: 45 | language_lstm = CuDNNLSTM(latent_dim, return_sequences=True, return_state=True) 46 | decoder_lstm = CuDNNLSTM(latent_dim, return_sequences=True, return_state=True) 47 | else: 48 | language_lstm = LSTM(latent_dim, recurrent_activation="sigmoid", return_sequences=True, return_state=True) 49 | decoder_lstm = LSTM(latent_dim, recurrent_activation="sigmoid", return_sequences=True, return_state=True) 50 | 51 | dec_lstm_input = reshape1(embed(decoder_inputs)) 52 | 53 | language_outputs, _, _ = language_lstm(dec_lstm_input) 54 | 55 | attn = Attention(attn_dim, activation='tanh') 56 | attn_output = attn([language_outputs, encoder_outputs, encoder_outputs]) 57 | 58 | dec_lstm_input2 = Concatenate()([dec_lstm_input, language_outputs, attn_output]) 59 | 60 | decoder_outputs, _, _ = decoder_lstm(dec_lstm_input2) 61 | decoder_dense = Dense(num_encoder_tokens, activation='softmax') 62 | decoder_outputs = decoder_dense(decoder_outputs) 63 | 64 | model = Model([encoder_inputs, decoder_inputs], decoder_outputs) 65 | 66 | #The following is needed for inference (one at a time decoding) only 67 | encoder_model = Model(encoder_inputs, [encoder_outputs]) 68 | 69 | decoder_state_input_h = Input(shape=(latent_dim,)) 70 | decoder_state_input_c = Input(shape=(latent_dim,)) 71 | lang_state_input_h = Input(shape=(latent_dim,)) 72 | lang_state_input_c = Input(shape=(latent_dim,)) 73 | decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c, lang_state_input_h, lang_state_input_c] 74 | decoder_enc_inputs = Input(shape=(None, 2*encoder_dim)) 75 | tmp = reshape1(embed(decoder_inputs)) 76 | lang_outputs, lstate_h, lstate_c = language_lstm(tmp, initial_state=decoder_states_inputs[2:]) 77 | 78 | attn_output = attn([lang_outputs, decoder_enc_inputs, decoder_enc_inputs]) 79 | 80 | decoder_outputs, state_h, state_c = decoder_lstm( 81 | Concatenate()([tmp, lang_outputs, attn_output]), initial_state=decoder_states_inputs[0:2]) 82 | decoder_states = [state_h, state_c, lstate_h, lstate_c] 83 | decoder_outputs = decoder_dense(decoder_outputs) 84 | decoder_model = Model( 85 | [decoder_inputs, decoder_enc_inputs] + decoder_states_inputs, 86 | [decoder_outputs] + decoder_states) 87 | return (encoder_model, decoder_model, model) 88 | 89 | def decode_sequence(models, input_seq): 90 | [encoder_model, decoder_model] = models 91 | # Encode the input as state vectors. 92 | encoder_outputs = encoder_model.predict(input_seq[:,:,0:1]) 93 | state_h = state_c = lstate_h = lstate_c = np.zeros((1, latent_dim)) 94 | states_value = [state_h, state_c, lstate_h, lstate_c] 95 | 96 | # Generate empty target sequence of length 1. 97 | target_seq = np.zeros((1, 1, 1)) 98 | # Populate the first character of target sequence with the start character. 99 | target_seq[0, 0, :] = 0 100 | 101 | # Sampling loop for a batch of sequences 102 | # (to simplify, here we assume a batch of size 1). 103 | decoded_sentence = '' 104 | foo=0 105 | prob = 0 106 | while foo < input_seq.shape[1]: 107 | #target_seq[0, 0, 0] = input_seq[0, foo, 0] 108 | output_tokens, h, c, lh, lc = decoder_model.predict( 109 | [target_seq, encoder_outputs] + states_value) 110 | 111 | # Sample a token 112 | sampled_token_index = np.argmax(output_tokens[0, -1, :]) 113 | sampled_char = encoding.char_list[sampled_token_index] 114 | decoded_sentence += sampled_char 115 | prob += math.log(output_tokens[0, -1, sampled_token_index]) 116 | 117 | # Update the target sequence (of length 1). 
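# The character sampled at this step is fed back as the next decoder input (greedy decoding).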
118 | target_seq = np.zeros((1, 1, 1)) 119 | target_seq[0, 0, 0] = sampled_token_index 120 | 121 | # Update states 122 | states_value = [h, c, lh, lc] 123 | foo = foo+1 124 | print(prob) 125 | return decoded_sentence 126 | 127 | def beam_decode_sequence(models, input_seq): 128 | [encoder_model, decoder_model] = models 129 | # Encode the input as state vectors. 130 | B = 10 131 | encoder_outputs = encoder_model.predict(input_seq[:,:,0:1]) 132 | state_h = state_c = lstate_h = lstate_c = np.zeros((1, latent_dim)) 133 | in_nbest=[(0., '', np.array([[[0]]]), [state_h, state_c, lstate_h, lstate_c])] 134 | foo=0 135 | while foo < input_seq.shape[1]: 136 | out_nbest = [] 137 | for prob, decoded_sentence, target_seq, states_value in in_nbest: 138 | output_tokens, h, c, lh, lc = decoder_model.predict( 139 | [target_seq, encoder_outputs] + states_value) 140 | arg = np.argsort(output_tokens[0, -1, :]) 141 | # Sample a token 142 | # Update states 143 | states_value = [h, c, lh, lc] 144 | for i in range(B): 145 | sampled_token_index = arg[-1-i] 146 | sampled_char = encoding.char_list[sampled_token_index] 147 | # Update the target sequence (of length 1). 148 | target_seq = np.array([[[sampled_token_index]]]) 149 | new_prob = prob + math.log(output_tokens[0, -1, sampled_token_index]) 150 | candidate = (new_prob, decoded_sentence + sampled_char, target_seq, states_value) 151 | if len(out_nbest) < B: 152 | out_nbest.append(candidate) 153 | elif new_prob > out_nbest[-1][0]: 154 | for j in range(len(out_nbest)): 155 | if new_prob > out_nbest[j][0]: 156 | out_nbest = out_nbest[:j] + [candidate] + out_nbest[j+1:] 157 | break 158 | 159 | in_nbest = out_nbest 160 | foo = foo+1 161 | print(in_nbest[0][0]) 162 | return in_nbest[0][1] 163 | 164 | 165 | 166 | def decode_ground_truth(models, input_seq, output_seq): 167 | [encoder_model, decoder_model] = models 168 | # Encode the input as state vectors. 169 | encoder_outputs = encoder_model.predict(input_seq[:,:,0:1]) 170 | state_h = state_c = lstate_h = lstate_c = np.zeros((1, latent_dim)) 171 | states_value = [state_h, state_c, lstate_h, lstate_c] 172 | 173 | # Generate empty target sequence of length 1. 174 | target_seq = np.zeros((1, 1, 1)) 175 | # Populate the first character of target sequence with the start character. 176 | target_seq[0, 0, :] = 0 177 | 178 | # Sampling loop for a batch of sequences 179 | # (to simplify, here we assume a batch of size 1). 180 | decoded_sentence = '' 181 | foo=0 182 | prob = 0 183 | while foo < input_seq.shape[1]: 184 | #target_seq[0, 0, 0] = input_seq[0, foo, 0] 185 | output_tokens, h, c, lh, lc = decoder_model.predict( 186 | [target_seq, encoder_outputs] + states_value) 187 | 188 | # Sample a token 189 | sampled_token_index = output_seq[foo] 190 | sampled_char = encoding.char_list[sampled_token_index] 191 | decoded_sentence += sampled_char 192 | prob += math.log(output_tokens[0, -1, sampled_token_index]) 193 | 194 | # Update the target sequence (of length 1). 
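# Unlike decode_sequence(), the ground-truth character is fed back here, so prob accumulates the log-likelihood of the reference sentence.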
195 | target_seq = np.zeros((1, 1, 1)) 196 | target_seq[0, 0, 0] = sampled_token_index 197 | 198 | # Update states 199 | states_value = [h, c, lh, lc] 200 | foo = foo+1 201 | print(prob) 202 | return prob 203 | -------------------------------------------------------------------------------- /mistakes.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | import random 4 | import math 5 | import re 6 | from irregular import irregular_verbs 7 | from regular import regular_verbs 8 | 9 | def extend_cap(rules): 10 | cap = [] 11 | for r in rules: 12 | new = [] 13 | for word in r: 14 | tmp = word[0].upper() + word[1:] 15 | new = new + [tmp] 16 | cap = cap + [new] 17 | rules += cap 18 | 19 | acceptable_contractions = [["it is", "it's"], 20 | ["I am", "I'm"], 21 | ["you are", "you're"], 22 | ["he is", "he's"], 23 | ["she is", "she's"], 24 | ["we are", "we're"], 25 | ["they are", "they're"], 26 | ["cannot", "can't"], 27 | ["do not", "don't"], 28 | ["does not", "doesn't"], 29 | ["did not", "didn't"], 30 | ["should not", "shouldn't"], 31 | ["will not", "won't"] 32 | ] 33 | extend_cap(acceptable_contractions) 34 | 35 | homonyms_rules = [["there", "their", "they're"], 36 | ["to", "too", "two"], 37 | ["break", "brake"], 38 | ["its", "it's"], 39 | ["then", "than"], 40 | ["which", "witch"], 41 | ["here", "hear"], 42 | ["weather", "whether"], 43 | ["bear", "bare"], 44 | ["fore", "for", "four"], 45 | ["meet", "meat"], 46 | ["wear", "where", "ware"], 47 | ["week", "weak"], 48 | ["wait", "weight"], 49 | ["waste", "waist"], 50 | ["sweet", "suite", "sweat"], 51 | ["steel", "steal"], 52 | ["steak", "stake"], 53 | ["sun", "son"], 54 | ["no", "know"], 55 | ["mail", "male"], 56 | ["light", "lite"], 57 | ["hole", "whole"], 58 | ["maid", "made"], 59 | ["fair", "fare"], 60 | ["write", "right"], 61 | ["advise", "advice"] 62 | ] 63 | extend_cap(homonyms_rules) 64 | 65 | prepositions_rules = [["to", "at", "in", "for"], 66 | ["out", "off"], 67 | ["on", "over"], 68 | ["since", "for"], 69 | ["from", "than"], 70 | ["on", "at"] 71 | ] 72 | extend_cap(prepositions_rules) 73 | 74 | misc_rules = [["the", "a", "an"], 75 | ["you", "your", "you're", "yours"], 76 | ["I", "me", "my", "mine"], 77 | ["he", "him", "his"], 78 | ["she", "her", "hers"], 79 | ["this", "that"], 80 | ["excepted", "accepted"], 81 | ["affect", "effect"], 82 | ["affects", "effects"], 83 | ["affected", "effected"], 84 | ["your", "you're"], 85 | ["who", "that"], 86 | ["who", "whom", "whose", "who's"], 87 | ["in to", "into"], 88 | ["lose", "loose"], 89 | ["an", "and"], 90 | ["are", "our"], 91 | ["not", "now"], 92 | ["I", "i"], 93 | ["thing", "think"], 94 | ["complains", "complaints"], 95 | ["now", "know"], 96 | ["exit", "exist"], 97 | ["whether", "if"] 98 | ] 99 | extend_cap(misc_rules) 100 | 101 | comparison_rules = [["better", "good", "best"], 102 | ["worse", "bad", "worst"], 103 | ["more", "most"], 104 | ["slower", "slow", "slowest"], 105 | ["faster", "fast", "fastest"], 106 | ["larger", "large", "largest"], 107 | ["smaller", "small", "smallest"] 108 | ] 109 | extend_cap(comparison_rules) 110 | 111 | omitted_words = ["the", "a", "an", "to", "on", "of", "is", "that"] 112 | 113 | subword_subst = [["ea", "ee"], 114 | ["oo", "ou"], 115 | ["gth", "ght"], 116 | ["an", "en"], 117 | ["on", "un"], 118 | ["er", "ar"], 119 | ["'s", "s'"], 120 | ["n't", "n't not"] 121 | ] 122 | 123 | #these are adjacent on a querty keyboard 124 | #adjacent_list = "poiuytrewqasdfghjkl.,mnbvcxz" 125 | 126 | 
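# Pair each irregular verb with an (incorrect) regularized past form, e.g. "knowed", dropping duplicate trailing entries.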
irregular_rules = [] 127 | for verb in irregular_verbs: 128 | present = verb[0] 129 | if present[-1] == 'e': 130 | badpast = present + 'd' 131 | else: 132 | badpast = present + 'ed' 133 | if verb[-1] == verb[-2]: 134 | verb = verb[:-1] 135 | if verb[-1] == verb[-2]: 136 | verb = verb[:-1] 137 | verb = verb + [badpast] 138 | if verb[-1] == verb[-2]: 139 | verb = verb[:-1] 140 | irregular_rules = irregular_rules + [verb] 141 | #print(verb); 142 | extend_cap(irregular_rules) 143 | #print(irregular_rules) 144 | extend_cap(regular_verbs) 145 | 146 | def word_substitute(line, rules, prob): 147 | for group in rules: 148 | for word in group: 149 | word_len = len(word) 150 | where = 0 151 | while True: 152 | pos = line[where:].find(word) 153 | if pos < 0: 154 | break 155 | where = where + pos 156 | if (where > 0 and line[where-1] != ' ') or (where+word_len < len(line) and line[where+word_len] != ' '): 157 | where += 1 158 | continue 159 | if random.random() < prob: 160 | subst = random.choice(group) 161 | line = line[:where] + subst + line[(where+word_len):] 162 | where += word_len 163 | return line 164 | 165 | def word_delete(line, rules, prob): 166 | for word in rules: 167 | word_len = len(word) 168 | where = 0 169 | while True: 170 | pos = line[where:].find(word) 171 | if pos < 0: 172 | break 173 | where = where + pos 174 | if (where > 0 and line[where-1] != ' ') or (where+word_len < len(line) and line[where+word_len] != ' '): 175 | where += 1 176 | continue 177 | if random.random() < prob: 178 | if where > 0 and line[where-1] == ' ': 179 | line = line[:where-1] + line[(where+word_len):] 180 | else: 181 | line = line[:where] + line[(where+word_len+1):] 182 | where += word_len 183 | return line 184 | 185 | def word_double(line, rules, prob): 186 | for word in rules: 187 | word_len = len(word) 188 | where = 0 189 | while True: 190 | pos = line[where:].find(word) 191 | if pos < 0: 192 | break 193 | where = where + pos 194 | if (where > 0 and line[where-1] != ' ') or (where+word_len < len(line) and line[where+word_len] != ' '): 195 | where += 1 196 | continue 197 | if random.random() < prob: 198 | line = line[:where] + word + " " + line[(where):] 199 | where += 2*word_len+2 200 | return line 201 | 202 | def subword_substitute(line, rules, prob): 203 | for group in rules: 204 | for subword in group: 205 | word_len = len(subword) 206 | where = 0 207 | while True: 208 | pos = line[where:].find(subword) 209 | if pos < 0: 210 | break 211 | where = where + pos 212 | if random.random() < prob: 213 | subst = random.choice(group) 214 | line = line[:where] + subst + line[(where+word_len):] 215 | where += word_len 216 | return line 217 | 218 | def strip_plural(line, prob): 219 | where = 0 220 | while True: 221 | pos = re.search("[a-zA-Z]s[ ,.;:$]", line[where:]) 222 | if pos: 223 | pos = pos.start() 224 | else: 225 | break 226 | where += pos 227 | if random.random() < prob: 228 | line = line[:where+1] + line[where+2:] 229 | where += 2 230 | return line 231 | 232 | def add_plural(line, prob): 233 | where = 0 234 | while True: 235 | pos = re.search("[a-zA-Z][ ,.;:$]", line[where:]) 236 | if pos: 237 | pos = pos.start() 238 | else: 239 | break 240 | where += pos 241 | if random.random() < prob: 242 | line = line[:where+1] + 's' + line[where+1:] 243 | where += 4 244 | return line 245 | 246 | def strip_punctuation(line, prob): 247 | where = 0 248 | while True: 249 | pos = re.search("[,.;:]", line[where:]) 250 | if pos: 251 | pos = pos.start() 252 | else: 253 | break 254 | where += pos 255 | if random.random() < 
prob: 256 | line = line[:where] + line[where+1:] 257 | where += 2 258 | return line 259 | 260 | #FIXME: How do we treat other punctuation marks? 261 | def add_comma(line, prob): 262 | where = 0 263 | while True: 264 | pos = re.search("[a-zA-Z][ $]", line[where:]) 265 | if pos: 266 | pos = pos.start() 267 | else: 268 | break 269 | where += pos 270 | if random.random() < prob: 271 | line = line[:where+1] + ',' + line[where+1:] 272 | where += 4 273 | return line 274 | 275 | def letter_deletion(line, prob): 276 | line_len = len(line) 277 | pos = 0 278 | prob_1 = 1./prob 279 | while pos < line_len: 280 | uni = random.random() 281 | pos = pos - int(prob_1*math.log(.00001 + uni)) 282 | if pos >= line_len: 283 | break 284 | line = line[:pos] + line[(pos+1):] 285 | line_len = len(line) 286 | return line 287 | 288 | 289 | def letter_doubling(line, prob): 290 | line_len = len(line) 291 | pos = 0 292 | prob_1 = 1./prob 293 | while pos < line_len: 294 | uni = random.random() 295 | pos = pos - int(prob_1*math.log(.00001 + uni)) 296 | if pos >= line_len: 297 | break 298 | line = line[:pos] + line[pos] + line[pos:] 299 | line_len = len(line) 300 | return line 301 | 302 | def letter_swap(line, prob): 303 | line_len = len(line) 304 | pos = 1 305 | prob_1 = 1./prob 306 | while pos < line_len: 307 | uni = random.random() 308 | pos = pos - int(prob_1*math.log(.00001 + uni)) 309 | if pos >= line_len: 310 | break 311 | line = line[:(pos-1)] + line[pos] + line[pos-1] + line[(pos+1):] 312 | line_len = len(line) 313 | return line 314 | 315 | def letter_subst(line, prob): 316 | line_len = len(line) 317 | pos = 0 318 | prob_1 = 1./prob 319 | while pos < line_len: 320 | uni = random.random() 321 | pos = pos - int(prob_1*math.log(.00001 + uni)) 322 | if pos >= line_len: 323 | break 324 | line = line[:pos] + chr(32 + random.randrange(95)) + line[pos+1:] 325 | return line 326 | --------------------------------------------------------------------------------
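A typical workflow, inferred from the scripts above (file names are examples only):
./process_text.py train.h5 < corpus.txt
./grammar_train.py train.h5
./grammar_test.py < sentences.txt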