├── .gitignore ├── LICENSE ├── RNN_utils.py └── recurrent_keras.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016 Trung Tran 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the 
# method for generating text
def generate_text(model, length, vocab_size, ix_to_char):
    """Greedily sample `length` characters from a char-level model.

    Starts from a random character, then repeatedly feeds the one-hot
    encoded sequence so far back into the model and takes the argmax of
    the final timestep as the next character. Prints characters as they
    are produced and returns the full generated string.
    """
    # Seed the sequence with a single random character index.
    current = [np.random.randint(vocab_size)]
    generated = [ix_to_char[current[-1]]]
    one_hot = np.zeros((1, length, vocab_size))
    for step in range(length):
        # One-hot encode the most recently chosen character at this timestep.
        one_hot[0, step, current[-1]] = 1
        print(ix_to_char[current[-1]], end="")
        # Predict over the sequence so far; argmax per timestep, keep the last.
        current = np.argmax(model.predict(one_hot[:, :step + 1, :])[0], 1)
        generated.append(ix_to_char[current[-1]])
    return "".join(generated)
# method for preparing the training data
def load_data(data_dir, seq_length):
    """Read a text file and build one-hot encoded next-char training data.

    Args:
        data_dir: path to the training text file.
        seq_length: number of characters per training sequence.

    Returns:
        X: (num_sequences, seq_length, vocab_size) one-hot inputs.
        y: same shape; the input window shifted one character ahead
           (next-character prediction targets).
        VOCAB_SIZE: number of distinct characters in the corpus.
        ix_to_char: index -> character mapping for decoding predictions.
    """
    # Context manager so the file handle is closed promptly.
    with open(data_dir, 'r') as f:
        data = f.read()
    # Sort for a deterministic char<->index mapping across runs, so that
    # saved checkpoints remain decodable (a bare set() order varies per
    # process under hash randomization).
    chars = sorted(set(data))
    VOCAB_SIZE = len(chars)

    print('Data length: {} characters'.format(len(data)))
    print('Vocabulary size: {} characters'.format(VOCAB_SIZE))

    ix_to_char = {ix: char for ix, char in enumerate(chars)}
    char_to_ix = {char: ix for ix, char in enumerate(chars)}

    # Integer division (// not /): under Python 3 a float here breaks both
    # np.zeros() and range(). Using len(data) - 1 also guarantees the target
    # slice (shifted by one) is always full length; the original indexed out
    # of range when len(data) was an exact multiple of seq_length.
    num_sequences = (len(data) - 1) // seq_length
    X = np.zeros((num_sequences, seq_length, VOCAB_SIZE))
    y = np.zeros((num_sequences, seq_length, VOCAB_SIZE))
    for i in range(num_sequences):
        # Input: chars [i*L, (i+1)*L); target: the same window shifted by one.
        X_sequence = data[i * seq_length:(i + 1) * seq_length]
        for j, ch in enumerate(X_sequence):
            X[i, j, char_to_ix[ch]] = 1.

        y_sequence = data[i * seq_length + 1:(i + 1) * seq_length + 1]
        for j, ch in enumerate(y_sequence):
            y[i, j, char_to_ix[ch]] = 1.
    return X, y, VOCAB_SIZE, ix_to_char
from __future__ import print_function
import matplotlib.pyplot as plt
import numpy as np
import time
import csv
from keras.models import Sequential
from keras.layers.core import Dense, Activation, Dropout
from keras.layers.recurrent import LSTM, SimpleRNN
from keras.layers.wrappers import TimeDistributed
import argparse
from RNN_utils import *

# Parsing arguments for Network definition
ap = argparse.ArgumentParser()
ap.add_argument('-data_dir', default='./data/test.txt')
ap.add_argument('-batch_size', type=int, default=50)
ap.add_argument('-layer_num', type=int, default=2)
ap.add_argument('-seq_length', type=int, default=50)
ap.add_argument('-hidden_dim', type=int, default=500)
ap.add_argument('-generate_length', type=int, default=500)
ap.add_argument('-nb_epoch', type=int, default=20)
ap.add_argument('-mode', default='train')
ap.add_argument('-weights', default='')
args = vars(ap.parse_args())

DATA_DIR = args['data_dir']
BATCH_SIZE = args['batch_size']
HIDDEN_DIM = args['hidden_dim']
SEQ_LENGTH = args['seq_length']
WEIGHTS = args['weights']

GENERATE_LENGTH = args['generate_length']
LAYER_NUM = args['layer_num']

# Creating training data
X, y, VOCAB_SIZE, ix_to_char = load_data(DATA_DIR, SEQ_LENGTH)

# Creating and compiling the Network: stacked LSTMs with a per-timestep
# softmax over the vocabulary (character-level language model).
model = Sequential()
model.add(LSTM(HIDDEN_DIM, input_shape=(None, VOCAB_SIZE), return_sequences=True))
for i in range(LAYER_NUM - 1):
    model.add(LSTM(HIDDEN_DIM, return_sequences=True))
model.add(TimeDistributed(Dense(VOCAB_SIZE)))
model.add(Activation('softmax'))
model.compile(loss="categorical_crossentropy", optimizer="rmsprop")

# Generate some sample before training to know how bad it is!
generate_text(model, GENERATE_LENGTH, VOCAB_SIZE, ix_to_char)

if not WEIGHTS == '':
    model.load_weights(WEIGHTS)
    # Resume the epoch counter from the checkpoint filename, e.g.
    # 'checkpoint_layer_2_hidden_500_epoch_30.hdf5' -> 30.
    # NOTE(review): this parse assumes that exact naming scheme — a weights
    # file named differently will raise ValueError here.
    nb_epoch = int(WEIGHTS[WEIGHTS.rfind('_') + 1:WEIGHTS.find('.')])
else:
    nb_epoch = 0

# Training if there is no trained weights specified
if args['mode'] == 'train' or WEIGHTS == '':
    while True:
        print('\n\nEpoch: {}\n'.format(nb_epoch))
        model.fit(X, y, batch_size=BATCH_SIZE, verbose=1, nb_epoch=1)
        nb_epoch += 1
        generate_text(model, GENERATE_LENGTH, VOCAB_SIZE, ix_to_char)
        # Checkpoint every 10 epochs.
        if nb_epoch % 10 == 0:
            model.save_weights('checkpoint_layer_{}_hidden_{}_epoch_{}.hdf5'.format(LAYER_NUM, HIDDEN_DIM, nb_epoch))

# Else: generation only. The weights were already loaded above.
# BUG FIX: the original guard was `elif WEIGHTS == '':`, which is always
# consumed by the `or WEIGHTS == ''` branch above, so generation-only mode
# was unreachable dead code and the script fell through to "Nothing to do!".
else:
    generate_text(model, GENERATE_LENGTH, VOCAB_SIZE, ix_to_char)
    print('\n\n')