├── .gitignore ├── .settings └── .gitignore ├── README.md ├── keras_spell.py └── license /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | 55 | # Sphinx documentation 56 | docs/_build/ 57 | 58 | # PyBuilder 59 | target/ 60 | 61 | #Ipython Notebook 62 | .ipynb_checkpoints 63 | /.project 64 | /.pydevproject 65 | -------------------------------------------------------------------------------- /.settings/.gitignore: -------------------------------------------------------------------------------- 1 | /org.eclipse.core.resources.prefs 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepSpell 2 | a Deep Learning based Speller 3 | 4 | See https://medium.com/@majortal/deep-spelling-9ffef96a24f6#.2c9pu8nlm 5 | 6 | 7 | Additional details: 8 | 9 | I used this AMI to train the system: 10 | https://aws.amazon.com/marketplace/pp/B06VSPXKDX 11 | On a p2.xlarge instance (currently at $0.9 per Hour) 12 | -------------------------------------------------------------------------------- /keras_spell.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | ''' 3 | Created on Nov 26, 2015 4 | 5 | @author: tal 6 | 7 | Based in part on: 8 | Learn math - https://github.com/fchollet/keras/blob/master/examples/addition_rnn.py 9 | 10 | See https://medium.com/@majortal/deep-spelling-9ffef96a24f6#.2c9pu8nlm 11 | ''' 12 | 13 | from __future__ import print_function, division, unicode_literals 14 | 15 | import os 16 | import errno 17 | from collections import Counter 18 | from hashlib import sha256 19 | import re 20 | import json 21 | import itertools 22 | import logging 23 | import requests 24 | import numpy as np 25 | from numpy.random import choice as random_choice, randint as random_randint, shuffle as random_shuffle, seed as random_seed, rand 26 | from numpy import zeros as np_zeros # pylint:disable=no-name-in-module 27 | 28 | from keras.models import Sequential, load_model 29 | from keras.layers import Activation, TimeDistributed, Dense, RepeatVector, Dropout, recurrent 30 | from keras.callbacks import Callback 31 | 32 | # Set a logger for the module 33 | LOGGER = logging.getLogger(__name__) # Every log will use the module name 34 | LOGGER.addHandler(logging.StreamHandler()) 35 | LOGGER.setLevel(logging.DEBUG) 36 | 37 | random_seed(123) # Reproducibility 38 | 39 | class Configuration(object): 40 | """Dump stuff here""" 41 | 42 | CONFIG = Configuration() 43 | 
#pylint:disable=attribute-defined-outside-init 44 | # Parameters for the model: 45 | CONFIG.input_layers = 2 46 | CONFIG.output_layers = 2 47 | CONFIG.amount_of_dropout = 0.2 48 | CONFIG.hidden_size = 500 49 | CONFIG.initialization = "he_normal" # : Gaussian initialization scaled by fan-in (He et al., 2014) 50 | CONFIG.number_of_chars = 100 51 | CONFIG.max_input_len = 60 52 | CONFIG.inverted = True 53 | 54 | # parameters for the training: 55 | CONFIG.batch_size = 100 # As the model changes in size, play with the batch size to best fit the process in memory 56 | CONFIG.epochs = 500 # due to mini-epochs. 57 | CONFIG.steps_per_epoch = 1000 # This is a mini-epoch. Using News 2013 an epoch would need to be ~60K. 58 | CONFIG.validation_steps = 10 59 | CONFIG.number_of_iterations = 10 60 | #pylint:enable=attribute-defined-outside-init 61 | 62 | DIGEST = sha256(json.dumps(CONFIG.__dict__, sort_keys=True)).hexdigest() 63 | 64 | # Parameters for the dataset 65 | MIN_INPUT_LEN = 5 66 | AMOUNT_OF_NOISE = 0.2 / CONFIG.max_input_len 67 | CHARS = list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ .") 68 | PADDING = "☕" 69 | 70 | DATA_FILES_PATH = "~/Downloads/data" 71 | DATA_FILES_FULL_PATH = os.path.expanduser(DATA_FILES_PATH) 72 | DATA_FILES_URL = "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2013.en.shuffled.gz" 73 | NEWS_FILE_NAME_COMPRESSED = os.path.join(DATA_FILES_FULL_PATH, "news.2013.en.shuffled.gz") # 1.1 GB 74 | NEWS_FILE_NAME_ENGLISH = "news.2013.en.shuffled" 75 | NEWS_FILE_NAME = os.path.join(DATA_FILES_FULL_PATH, NEWS_FILE_NAME_ENGLISH) 76 | NEWS_FILE_NAME_CLEAN = os.path.join(DATA_FILES_FULL_PATH, "news.2013.en.clean") 77 | NEWS_FILE_NAME_FILTERED = os.path.join(DATA_FILES_FULL_PATH, "news.2013.en.filtered") 78 | NEWS_FILE_NAME_SPLIT = os.path.join(DATA_FILES_FULL_PATH, "news.2013.en.split") 79 | NEWS_FILE_NAME_TRAIN = os.path.join(DATA_FILES_FULL_PATH, "news.2013.en.train") 80 | NEWS_FILE_NAME_VALIDATE = os.path.join(DATA_FILES_FULL_PATH, "news.2013.en.validate") 81 | CHAR_FREQUENCY_FILE_NAME = os.path.join(DATA_FILES_FULL_PATH, "char_frequency.json") 82 | SAVED_MODEL_FILE_NAME = os.path.join(DATA_FILES_FULL_PATH, "keras_spell_e{}.h5") # an HDF5 file 83 | 84 | # Some cleanup: 85 | NORMALIZE_WHITESPACE_REGEX = re.compile(r'[^\S\n]+', re.UNICODE) # match all whitespace except newlines 86 | RE_DASH_FILTER = re.compile(r'[\-\˗\֊\‐\‑\‒\–\—\⁻\₋\−\﹣\-]', re.UNICODE) 87 | RE_APOSTROPHE_FILTER = re.compile(r'&#39;|[ʼ՚＇‘’‛❛❜ߴߵ`‵´ˊˋ{}{}{}{}{}{}{}{}{}]'.format(unichr(768), unichr(769), unichr(832), 88 | unichr(833), unichr(2387), unichr(5151), 89 | unichr(5152), unichr(65344), unichr(8242)), 90 | re.UNICODE) 91 | RE_LEFT_PARENTH_FILTER = re.compile(r'[\(\[\{\⁽\₍\❨\❪\﹙\(]', re.UNICODE) 92 | RE_RIGHT_PARENTH_FILTER = re.compile(r'[\)\]\}\⁾\₎\❩\❫\﹚\)]', re.UNICODE) 93 | ALLOWED_CURRENCIES = """¥£₪$€฿₨""" 94 | ALLOWED_PUNCTUATION = """-!?/;"'%&<>.()[]{}@#:,|=*""" 95 | RE_BASIC_CLEANER = re.compile(r'[^\w\s{}{}]'.format(re.escape(ALLOWED_CURRENCIES), re.escape(ALLOWED_PUNCTUATION)), re.UNICODE) 96 | 97 | # pylint:disable=invalid-name 98 | 99 | def download_the_news_data(): 100 | """Download the news data""" 101 | LOGGER.info("Downloading") 102 | try: 103 | os.makedirs(os.path.dirname(NEWS_FILE_NAME_COMPRESSED)) 104 | except OSError as exception: 105 | if exception.errno != errno.EEXIST: 106 | raise 107 | with open(NEWS_FILE_NAME_COMPRESSED, "wb") as output_file: 108 | response = requests.get(DATA_FILES_URL, stream=True) 109 | total_length = response.headers.get('content-length')
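# NOTE (not in the original code): 'content-length' is not guaranteed to be present in the response headers;
# headers.get() then returns None and the int() conversion a few lines below raises TypeError. A guard could
# skip the progress display in that case, or fall back to the known size (~1.1 GB, per the comment above).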
110 | downloaded = percentage = 0 111 | print("»"*100) 112 | total_length = int(total_length) 113 | for data in response.iter_content(chunk_size=4096): 114 | downloaded += len(data) 115 | output_file.write(data) 116 | new_percentage = 100 * downloaded // total_length 117 | if new_percentage > percentage: 118 | print("☑", end="") 119 | percentage = new_percentage 120 | print() 121 | 122 | def uncompress_data(): 123 | """Uncompress the data files""" 124 | import gzip 125 | with gzip.open(NEWS_FILE_NAME_COMPRESSED, 'rb') as compressed_file: 126 | with open(NEWS_FILE_NAME_COMPRESSED[:-3], 'wb') as outfile: 127 | outfile.write(compressed_file.read()) 128 | 129 | def add_noise_to_string(a_string, amount_of_noise): 130 | """Add some artificial spelling mistakes to the string""" 131 | if rand() < amount_of_noise * len(a_string): 132 | # Replace a character with a random character 133 | random_char_position = random_randint(len(a_string)) 134 | a_string = a_string[:random_char_position] + random_choice(CHARS[:-1]) + a_string[random_char_position + 1:] 135 | if rand() < amount_of_noise * len(a_string): 136 | # Delete a character 137 | random_char_position = random_randint(len(a_string)) 138 | a_string = a_string[:random_char_position] + a_string[random_char_position + 1:] 139 | if len(a_string) < CONFIG.max_input_len and rand() < amount_of_noise * len(a_string): 140 | # Add a random character 141 | random_char_position = random_randint(len(a_string)) 142 | a_string = a_string[:random_char_position] + random_choice(CHARS[:-1]) + a_string[random_char_position:] 143 | if rand() < amount_of_noise * len(a_string): 144 | # Transpose 2 characters 145 | random_char_position = random_randint(len(a_string) - 1) 146 | a_string = (a_string[:random_char_position] + a_string[random_char_position + 1] + a_string[random_char_position] + 147 | a_string[random_char_position + 2:]) 148 | return a_string 149 | 150 | def _vectorize(questions, answers, ctable): 151 | """Vectorize the data as numpy arrays""" 152 | len_of_questions = len(questions) 153 | X = np_zeros((len_of_questions, CONFIG.max_input_len, ctable.size), dtype=np.bool) 154 | for i in xrange(len(questions)): 155 | sentence = questions.pop() 156 | for j, c in enumerate(sentence): 157 | try: 158 | X[i, j, ctable.char_indices[c]] = 1 159 | except KeyError: 160 | pass # Padding 161 | y = np_zeros((len_of_questions, CONFIG.max_input_len, ctable.size), dtype=np.bool) 162 | for i in xrange(len(answers)): 163 | sentence = answers.pop() 164 | for j, c in enumerate(sentence): 165 | try: 166 | y[i, j, ctable.char_indices[c]] = 1 167 | except KeyError: 168 | pass # Padding 169 | return X, y 170 | 171 | def slice_X(X, start=None, stop=None): 172 | """This takes an array-like, or a list of 173 | array-likes, and outputs: 174 | - X[start:stop] if X is an array-like 175 | - [x[start:stop] for x in X] if X in a list 176 | Can also work on list/array of indices: `slice_X(x, indices)` 177 | # Arguments 178 | start: can be an integer index (start index) 179 | or a list/array of indices 180 | stop: integer (stop index); should be None if 181 | `start` was a list. 
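# Example (illustrative, not part of the original docstring; assumes numpy arrays X1 and X2)
slice_X(np.arange(10), 2, 5)  # -> array([2, 3, 4])
slice_X([X1, X2], 2, 5)       # -> [X1[2:5], X2[2:5]] for a list of arrays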
182 | """ 183 | if isinstance(X, list): 184 | if hasattr(start, '__len__'): 185 | # hdf5 datasets only support list objects as indices 186 | if hasattr(start, 'shape'): 187 | start = start.tolist() 188 | return [x[start] for x in X] 189 | else: 190 | return [x[start:stop] for x in X] 191 | else: 192 | if hasattr(start, '__len__'): 193 | if hasattr(start, 'shape'): 194 | start = start.tolist() 195 | return X[start] 196 | else: 197 | return X[start:stop] 198 | 199 | def vectorize(questions, answers, chars=None): 200 | """Vectorize the questions and expected answers""" 201 | print('Vectorization...') 202 | chars = chars or CHARS 203 | ctable = CharacterTable(chars) 204 | X, y = _vectorize(questions, answers, ctable) 205 | # Explicitly set apart 10% for validation data that we never train over 206 | split_at = int(len(X) - len(X) / 10) 207 | (X_train, X_val) = (slice_X(X, 0, split_at), slice_X(X, split_at)) 208 | (y_train, y_val) = (y[:split_at], y[split_at:]) 209 | 210 | print(X_train.shape) 211 | print(y_train.shape) 212 | 213 | return X_train, X_val, y_train, y_val, CONFIG.max_input_len, ctable 214 | 215 | 216 | def generate_model(output_len, chars=None): 217 | """Generate the model""" 218 | print('Build model...') 219 | chars = chars or CHARS 220 | model = Sequential() 221 | # "Encode" the input sequence using an RNN, producing an output of hidden_size 222 | # note: in a situation where your input sequences have a variable length, 223 | # use input_shape=(None, nb_feature). 224 | for layer_number in range(CONFIG.input_layers): 225 | model.add(recurrent.LSTM(CONFIG.hidden_size, input_shape=(None, len(chars)), kernel_initializer=CONFIG.initialization, 226 | return_sequences=layer_number + 1 < CONFIG.input_layers)) 227 | model.add(Dropout(CONFIG.amount_of_dropout)) 228 | # For the decoder's input, we repeat the encoded input for each time step 229 | model.add(RepeatVector(output_len)) 230 | # The decoder RNN could be multiple layers stacked or a single layer 231 | for _ in range(CONFIG.output_layers): 232 | model.add(recurrent.LSTM(CONFIG.hidden_size, return_sequences=True, kernel_initializer=CONFIG.initialization)) 233 | model.add(Dropout(CONFIG.amount_of_dropout)) 234 | 235 | # For each of step of the output sequence, decide which character should be chosen 236 | model.add(TimeDistributed(Dense(len(chars), kernel_initializer=CONFIG.initialization))) 237 | model.add(Activation('softmax')) 238 | 239 | model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy']) 240 | return model 241 | 242 | 243 | class Colors(object): 244 | """For nicer printouts""" 245 | green = '\033[92m' 246 | red = '\033[91m' 247 | close = '\033[0m' 248 | 249 | 250 | class CharacterTable(object): 251 | """ 252 | Given a set of characters: 253 | + Encode them to a one hot integer representation 254 | + Decode the one hot integer representation to their character output 255 | + Decode a vector of probabilities to their character output 256 | """ 257 | def __init__(self, chars): 258 | self.chars = sorted(set(chars)) 259 | self.char_indices = dict((c, i) for i, c in enumerate(self.chars)) 260 | self.indices_char = dict((i, c) for i, c in enumerate(self.chars)) 261 | 262 | @property 263 | def size(self): 264 | """The number of chars""" 265 | return len(self.chars) 266 | 267 | def encode(self, C, maxlen): 268 | """Encode as one-hot""" 269 | X = np_zeros((maxlen, len(self.chars)), dtype=np.bool) # pylint:disable=no-member 270 | for i, c in enumerate(C): 271 | X[i, self.char_indices[c]] = 1 272 | return 
X 273 | 274 | def decode(self, X, calc_argmax=True): 275 | """Decode from one-hot""" 276 | if calc_argmax: 277 | X = X.argmax(axis=-1) 278 | return ''.join(self.indices_char[x] for x in X if x) 279 | 280 | def generator(file_name): 281 | """Returns a tuple (inputs, targets) 282 | All arrays should contain the same number of samples. 283 | The generator is expected to loop over its data indefinitely. 284 | An epoch finishes when samples_per_epoch samples have been seen by the model. 285 | """ 286 | ctable = CharacterTable(read_top_chars()) 287 | batch_of_answers = [] 288 | while True: 289 | with open(file_name) as answers: 290 | for answer in answers: 291 | batch_of_answers.append(answer.strip().decode('utf-8')) 292 | if len(batch_of_answers) == CONFIG.batch_size: 293 | random_shuffle(batch_of_answers) 294 | batch_of_questions = [] 295 | for answer_index, answer in enumerate(batch_of_answers): 296 | question, answer = generate_question(answer) 297 | batch_of_answers[answer_index] = answer 298 | assert len(answer) == CONFIG.max_input_len 299 | question = question[::-1] if CONFIG.inverted else question 300 | batch_of_questions.append(question) 301 | X, y = _vectorize(batch_of_questions, batch_of_answers, ctable) 302 | yield X, y 303 | batch_of_answers = [] 304 | 305 | def print_random_predictions(model, ctable, X_val, y_val): 306 | """Select 10 samples from the validation set at random so we can visualize errors""" 307 | print() 308 | for _ in range(10): 309 | ind = random_randint(0, len(X_val)) 310 | rowX, rowy = X_val[np.array([ind])], y_val[np.array([ind])] # pylint:disable=no-member 311 | preds = model.predict_classes(rowX, verbose=0) 312 | q = ctable.decode(rowX[0]) 313 | correct = ctable.decode(rowy[0]) 314 | guess = ctable.decode(preds[0], calc_argmax=False) 315 | if CONFIG.inverted: 316 | print('Q', q[::-1]) # inverted back! 317 | else: 318 | print('Q', q) 319 | print('A', correct) 320 | print(Colors.green + '☑' + Colors.close if correct == guess else Colors.red + '☒' + Colors.close, guess) 321 | print('---') 322 | print() 323 | 324 | 325 | class OnEpochEndCallback(Callback): 326 | """Execute this every end of epoch""" 327 | 328 | def on_epoch_end(self, epoch, logs=None): 329 | """On Epoch end - do some stats""" 330 | ctable = CharacterTable(read_top_chars()) 331 | X_val, y_val = next(generator(NEWS_FILE_NAME_VALIDATE)) 332 | print_random_predictions(self.model, ctable, X_val, y_val) 333 | self.model.save(SAVED_MODEL_FILE_NAME.format(epoch)) 334 | 335 | ON_EPOCH_END_CALLBACK = OnEpochEndCallback() 336 | 337 | def itarative_train(model): 338 | """ 339 | Iterative training of the model 340 | - To allow for finite RAM... 
341 | - To allow infinite training data as the training noise is injected in runtime 342 | """ 343 | model.fit_generator(generator(NEWS_FILE_NAME_TRAIN), steps_per_epoch=CONFIG.steps_per_epoch, 344 | epochs=CONFIG.epochs, 345 | verbose=1, callbacks=[ON_EPOCH_END_CALLBACK, ], validation_data=generator(NEWS_FILE_NAME_VALIDATE), 346 | validation_steps=CONFIG.validation_steps, 347 | class_weight=None, max_q_size=10, workers=1, 348 | pickle_safe=False, initial_epoch=0) 349 | 350 | 351 | def iterate_training(model, X_train, y_train, X_val, y_val, ctable): 352 | """Iterative Training""" 353 | # Train the model each generation and show predictions against the validation dataset 354 | for iteration in range(1, CONFIG.number_of_iterations): 355 | print() 356 | print('-' * 50) 357 | print('Iteration', iteration) 358 | model.fit(X_train, y_train, batch_size=CONFIG.batch_size, epochs=CONFIG.epochs, 359 | validation_data=(X_val, y_val)) 360 | print_random_predictions(model, ctable, X_val, y_val) 361 | 362 | def clean_text(text): 363 | """Clean the text - remove unwanted chars, fold punctuation etc.""" 364 | result = NORMALIZE_WHITESPACE_REGEX.sub(' ', text.strip()) 365 | result = RE_DASH_FILTER.sub('-', result) 366 | result = RE_APOSTROPHE_FILTER.sub("'", result) 367 | result = RE_LEFT_PARENTH_FILTER.sub("(", result) 368 | result = RE_RIGHT_PARENTH_FILTER.sub(")", result) 369 | result = RE_BASIC_CLEANER.sub('', result) 370 | return result 371 | 372 | def preprocesses_data_clean(): 373 | """Pre-process the data - step 1 - cleanup""" 374 | with open(NEWS_FILE_NAME_CLEAN, "wb") as clean_data: 375 | for line in open(NEWS_FILE_NAME): 376 | decoded_line = line.decode('utf-8') 377 | cleaned_line = clean_text(decoded_line) 378 | encoded_line = cleaned_line.encode("utf-8") 379 | clean_data.write(encoded_line + b"\n") 380 | 381 | def preprocesses_data_analyze_chars(): 382 | """Pre-process the data - step 2 - analyze the characters""" 383 | counter = Counter() 384 | LOGGER.info("Reading data:") 385 | for line in open(NEWS_FILE_NAME_CLEAN): 386 | decoded_line = line.decode('utf-8') 387 | counter.update(decoded_line) 388 | # data = open(NEWS_FILE_NAME_CLEAN).read().decode('utf-8') 389 | # LOGGER.info("Read.\nCounting characters:") 390 | # counter = Counter(data.replace("\n", "")) 391 | LOGGER.info("Done.\nWriting to file:") 392 | with open(CHAR_FREQUENCY_FILE_NAME, 'wb') as output_file: 393 | output_file.write(json.dumps(counter)) 394 | most_popular_chars = {key for key, _value in counter.most_common(CONFIG.number_of_chars)} 395 | LOGGER.info("The top %s chars are:", CONFIG.number_of_chars) 396 | LOGGER.info("".join(sorted(most_popular_chars))) 397 | 398 | def read_top_chars(): 399 | """Read the top chars we saved to file""" 400 | chars = json.loads(open(CHAR_FREQUENCY_FILE_NAME).read()) 401 | counter = Counter(chars) 402 | most_popular_chars = {key for key, _value in counter.most_common(CONFIG.number_of_chars)} 403 | return most_popular_chars 404 | 405 | def preprocesses_data_filter(): 406 | """Pre-process the data - step 3 - filter only sentences with the right chars""" 407 | most_popular_chars = read_top_chars() 408 | LOGGER.info("Reading and filtering data:") 409 | with open(NEWS_FILE_NAME_FILTERED, "wb") as output_file: 410 | for line in open(NEWS_FILE_NAME_CLEAN): 411 | decoded_line = line.decode('utf-8') 412 | if decoded_line and not bool(set(decoded_line) - most_popular_chars): 413 | output_file.write(line) 414 | LOGGER.info("Done.") 415 | 416 | def read_filtered_data(): 417 | """Read the filtered data 
corpus""" 418 | LOGGER.info("Reading filtered data:") 419 | lines = open(NEWS_FILE_NAME_FILTERED).read().decode('utf-8').split("\n") 420 | LOGGER.info("Read filtered data - %s lines", len(lines)) 421 | return lines 422 | 423 | def preprocesses_split_lines(): 424 | """Preprocess the text by splitting the lines between min-length and max_length 425 | I don't like this step: 426 | I think the start-of-sentence is important. 427 | I think the end-of-sentence is important. 428 | Sometimes the stripped down sub-sentence is missing crucial context. 429 | Important NGRAMs are cut (though given enough data, that might be moot). 430 | I do this to enable batch-learning by padding to a fixed length. 431 | """ 432 | LOGGER.info("Reading filtered data:") 433 | answers = set() 434 | with open(NEWS_FILE_NAME_SPLIT, "wb") as output_file: 435 | for _line in open(NEWS_FILE_NAME_FILTERED): 436 | line = _line.decode('utf-8') 437 | while len(line) > MIN_INPUT_LEN: 438 | if len(line) <= CONFIG.max_input_len: 439 | answer = line 440 | line = "" 441 | else: 442 | space_location = line.rfind(" ", MIN_INPUT_LEN, CONFIG.max_input_len - 1) 443 | if space_location > -1: 444 | answer = line[:space_location] 445 | line = line[len(answer) + 1:] 446 | else: 447 | space_location = line.rfind(" ") # no limits this time 448 | if space_location == -1: 449 | break # we are done with this line 450 | else: 451 | line = line[space_location + 1:] 452 | continue 453 | answers.add(answer) 454 | output_file.write(answer.encode('utf-8') + b"\n") 455 | 456 | def preprocesses_split_lines2(): 457 | """Preprocess the text by splitting the lines between min-length and max_length 458 | Alternative split. 459 | """ 460 | LOGGER.info("Reading filtered data:") 461 | answers = set() 462 | for encoded_line in open(NEWS_FILE_NAME_FILTERED): 463 | line = encoded_line.decode('utf-8') 464 | if CONFIG.max_input_len >= len(line) > MIN_INPUT_LEN: 465 | answers.add(line) 466 | LOGGER.info("There are %s 'answers' (sub-sentences)", len(answers)) 467 | LOGGER.info("Here are some examples:") 468 | for answer in itertools.islice(answers, 10): 469 | LOGGER.info(answer) 470 | with open(NEWS_FILE_NAME_SPLIT, "wb") as output_file: 471 | output_file.write("".join(answers).encode('utf-8')) 472 | 473 | def preprocesses_split_lines3(): 474 | """Preprocess the text by selecting only max n-grams 475 | Alternative split. 476 | """ 477 | LOGGER.info("Reading filtered data:") 478 | answers = set() 479 | for encoded_line in open(NEWS_FILE_NAME_FILTERED): 480 | line = encoded_line.decode('utf-8') 481 | if line.count(" ") < 5: 482 | answers.add(line) 483 | LOGGER.info("There are %s 'answers' (sub-sentences)", len(answers)) 484 | LOGGER.info("Here are some examples:") 485 | for answer in itertools.islice(answers, 10): 486 | LOGGER.info(answer) 487 | with open(NEWS_FILE_NAME_SPLIT, "wb") as output_file: 488 | output_file.write("".join(answers).encode('utf-8')) 489 | 490 | def preprocesses_split_lines4(): 491 | """Preprocess the text by selecting only sentences with most-common words AND not too long 492 | Alternative split. 
493 | """ 494 | LOGGER.info("Reading filtered data:") 495 | from gensim.models.word2vec import Word2Vec 496 | FILTERED_W2V = "fw2v.bin" 497 | model = Word2Vec.load_word2vec_format(FILTERED_W2V, binary=True) # C text format 498 | print(len(model.wv.index2word)) 499 | # answers = set() 500 | # for encoded_line in open(NEWS_FILE_NAME_FILTERED): 501 | # line = encoded_line.decode('utf-8') 502 | # if line.count(" ") < 5: 503 | # answers.add(line) 504 | # LOGGER.info("There are %s 'answers' (sub-sentences)", len(answers)) 505 | # LOGGER.info("Here are some examples:") 506 | # for answer in itertools.islice(answers, 10): 507 | # LOGGER.info(answer) 508 | # with open(NEWS_FILE_NAME_SPLIT, "wb") as output_file: 509 | # output_file.write("".join(answers).encode('utf-8')) 510 | 511 | def preprocess_partition_data(): 512 | """Set asside data for validation""" 513 | answers = open(NEWS_FILE_NAME_SPLIT).read().decode('utf-8').split("\n") 514 | print('shuffle', end=" ") 515 | random_shuffle(answers) 516 | print("Done") 517 | # Explicitly set apart 10% for validation data that we never train over 518 | split_at = len(answers) - len(answers) // 10 519 | with open(NEWS_FILE_NAME_TRAIN, "wb") as output_file: 520 | output_file.write("\n".join(answers[:split_at]).encode('utf-8')) 521 | with open(NEWS_FILE_NAME_VALIDATE, "wb") as output_file: 522 | output_file.write("\n".join(answers[split_at:]).encode('utf-8')) 523 | 524 | 525 | def generate_question(answer): 526 | """Generate a question by adding noise""" 527 | question = add_noise_to_string(answer, AMOUNT_OF_NOISE) 528 | # Add padding: 529 | question += PADDING * (CONFIG.max_input_len - len(question)) 530 | answer += PADDING * (CONFIG.max_input_len - len(answer)) 531 | return question, answer 532 | 533 | def generate_news_data(): 534 | """Generate some news data""" 535 | print ("Generating Data") 536 | answers = open(NEWS_FILE_NAME_SPLIT).read().decode('utf-8').split("\n") 537 | questions = [] 538 | print('shuffle', end=" ") 539 | random_shuffle(answers) 540 | print("Done") 541 | for answer_index, answer in enumerate(answers): 542 | question, answer = generate_question(answer) 543 | answers[answer_index] = answer 544 | assert len(answer) == CONFIG.max_input_len 545 | if random_randint(100000) == 8: # Show some progress 546 | print (len(answers)) 547 | print ("answer: '{}'".format(answer)) 548 | print ("question: '{}'".format(question)) 549 | print () 550 | question = question[::-1] if CONFIG.inverted else question 551 | questions.append(question) 552 | 553 | return questions, answers 554 | 555 | def train_speller_w_all_data(): 556 | """Train the speller if all data fits into RAM""" 557 | questions, answers = generate_news_data() 558 | chars_answer = set.union(*(set(answer) for answer in answers)) 559 | chars_question = set.union(*(set(question) for question in questions)) 560 | chars = list(set.union(chars_answer, chars_question)) 561 | X_train, X_val, y_train, y_val, y_maxlen, ctable = vectorize(questions, answers, chars) 562 | print ("y_maxlen, chars", y_maxlen, "".join(chars)) 563 | model = generate_model(y_maxlen, chars) 564 | iterate_training(model, X_train, y_train, X_val, y_val, ctable) 565 | 566 | def train_speller(from_file=None): 567 | """Train the speller""" 568 | if from_file: 569 | model = load_model(from_file) 570 | else: 571 | model = generate_model(CONFIG.max_input_len, chars=read_top_chars()) 572 | itarative_train(model) 573 | 574 | if __name__ == '__main__': 575 | # download_the_news_data() 576 | # uncompress_data() 577 | # 
preprocesses_data_clean() 578 | # preprocesses_data_analyze_chars() 579 | # preprocesses_data_filter() 580 | # preprocesses_split_lines() --- Choose this step or: 581 | # preprocesses_split_lines2() 582 | # preprocesses_split_lines4() 583 | # preprocess_partition_data() 584 | # train_speller(os.path.join(DATA_FILES_FULL_PATH, "keras_spell_e15.h5")) 585 | train_speller() 586 | -------------------------------------------------------------------------------- /license: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 Tal Weiss 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | --------------------------------------------------------------------------------