├── src ├── matplotlibrc ├── path_toke.txt ├── plot_data.py ├── general_tools.py ├── RNN.py ├── preprocess.py └── RNN_tools.py ├── output ├── best_model.npz └── best_model_var.pkl ├── .gitattributes ├── README.md └── .gitignore /src/matplotlibrc: -------------------------------------------------------------------------------- 1 | backend : Qt4Agg 2 | -------------------------------------------------------------------------------- /output/best_model.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Faur/TIMIT/HEAD/output/best_model.npz -------------------------------------------------------------------------------- /output/best_model_var.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Faur/TIMIT/HEAD/output/best_model_var.pkl -------------------------------------------------------------------------------- /src/path_toke.txt: -------------------------------------------------------------------------------- 1 | \\home.cc.dtu.dk\s136232\Desktop\timit\timit 2 | 3 | \\\\home.cc.dtu.dk\\s136232\\Desktop\\timit\\timit 4 | 5 | 6 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | 4 | # Custom for Visual Studio 5 | *.cs diff=csharp 6 | 7 | # Standard to msysgit 8 | *.doc diff=astextplain 9 | *.DOC diff=astextplain 10 | *.docx diff=astextplain 11 | *.DOCX diff=astextplain 12 | *.dot diff=astextplain 13 | *.DOT diff=astextplain 14 | *.pdf diff=astextplain 15 | *.PDF diff=astextplain 16 | *.rtf diff=astextplain 17 | *.RTF diff=astextplain 18 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TIMIT 2 | Framewise phoneme classification on the TIMIT dataset using neural networks 3 | 4 | The recurrent neural network is strongly inspired by the work by Alex Graves: Supervised Sequence Labelling with Recurrent Neural Networks (2012) http://www.cs.toronto.edu/~graves/preprint.pdf 5 | 6 | The code was originally created for the course '01666 Project work - Bachelor of Mathematics and Technology' at The Techincal University of Denmark by Toke Faurby and Kristoffer Linder-steinlein. 7 | -------------------------------------------------------------------------------- /src/plot_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | 5 | from general_tools import * 6 | 7 | 8 | paths = path_reader('path_toke.txt') 9 | dataset_path = os.path.join(paths[0], 'std_preprocess_26_ch_DEBUG.pkl') 10 | 11 | X_train, y_train, X_val, y_val, X_test, y_test = load_dataset(dataset_path) 12 | 13 | 14 | plt.figure(1) 15 | plt.title('Preprocessed data visualization') 16 | for i in range(1,5): 17 | plt.subplot(2,2, i) 18 | plt.axis('off') 19 | plt.imshow(X_train[i].T) 20 | # plt.imshow(np.log(X_train[i].T)) 21 | # print(X_train[i].shape) 22 | 23 | 24 | 25 | plt.tight_layout() 26 | plt.show() 27 | -------------------------------------------------------------------------------- /src/general_tools.py: -------------------------------------------------------------------------------- 1 | from six.moves import cPickle 2 | 3 | 4 | def path_reader(filename): 5 | with open(filename) as f: 6 | path_list = f.read().splitlines() 7 | return path_list 8 | 9 | def load_dataset(file_path): 10 | with open(file_path, 'rb') as cPickle_file: 11 | [X_train, y_train, X_val, y_val, X_test, y_test] = cPickle.load(cPickle_file) 12 | if not X_train: 13 | print('WARNING: X_train is empty') 14 | if not y_train: 15 | print('WARNING: y_train is empty') 16 | if not X_val: 17 | print('WARNING: X_val is empty') 18 | if not y_val: 19 | print('WARNING: y_val is empty') 20 | if not X_test: 21 | print('WARNING: X_test is empty') 22 | if not y_test: 23 | print('WARNING: y_test is empty') 24 | return X_train, y_train, X_val, y_val, X_test, y_test 25 | 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /src/RNN.py: -------------------------------------------------------------------------------- 1 | print('\n\n * Imporing Libaries') 2 | import os 3 | import time; program_start_time = time.time() 4 | import sys 5 | from six.moves import cPickle 6 | 7 | import theano 8 | import theano.tensor as T 9 | import lasagne 10 | import lasagne.layers as L 11 | import numpy as np 12 | 13 | from general_tools import * 14 | from RNN_tools import * 15 | 16 | 17 | ##### SCRIPT META VARIABLES ##### 18 | print(' * Setting up ...') 19 | 20 | comput_confusion = False 21 | # TODO: ATM this is not implemented 22 | 23 | 24 | paths = path_reader('path_toke.txt') 25 | output_path = os.path.join('..', 'output') 26 | 27 | if 1: 28 | data_path = os.path.join(paths[0], 'std_preprocess_26_ch.pkl') 29 | model_load = os.path.join(output_path, 'best_model.npz') 30 | model_save = os.path.join(output_path, 'best_model') 31 | INPUT_SIZE = 26 32 | else: 33 | data_path = os.path.join(paths[0], 'std_preprocess_26_ch_DEBUG.pkl') 34 | model_load = os.path.join(output_path, 'best_model_DEBUG.npz') 35 | model_save = os.path.join(output_path, 'best_model_DEBUG') 36 | INPUT_SIZE = 26 37 | print('DEBUG MODE ACTIVE: Only a reduced dataset is used.') 38 | 39 | ##### SCRIPT VARIABLES ##### 40 | num_epochs = 20 41 | 42 | NUM_OUTPUT_UNITS= 61 43 | N_HIDDEN = 275 44 | 45 | LEARNING_RATE = 1e-5 46 | MOMENTUM = 0.9 47 | WEIGHT_INIT = 0.1 48 | batch_size = 1 49 | 50 | 51 | ##### IMPORTIN DATA ##### 52 | print('\tdata source: '+ data_path) 53 | print('\tmodel target: '+ model_save + '.npz') 54 | dataset = load_dataset(data_path) 55 | X_train, y_train, X_val, y_val, X_test, y_test = dataset 56 | 57 | 58 | ##### BUIDING MODEL ##### 59 | print(' * Building network ...') 60 | RNN_network = NeuralNetwork('RNN', batch_size=batch_size, input_size=INPUT_SIZE, n_hidden=N_HIDDEN, 61 | num_output_units=NUM_OUTPUT_UNITS, seed=int(time.time()), debug=False) 62 | 63 | RNN_network.load_model(model_load) 64 | 65 | ##### BUIDING FUNCTION ##### 66 | print(" * Compiling functions ...") 67 | RNN_network.build_functions(LEARNING_RATE=LEARNING_RATE, MOMENTUM=MOMENTUM, debug=False) 68 | 69 | 70 | ##### TRAINING ##### 71 | print(" * Training ...") 72 | RNN_network.train(dataset, model_save, num_epochs=num_epochs, 73 | batch_size=batch_size, comput_confusion=False, debug=False) 74 | 75 | 76 | print() 77 | print(" * Done") 78 | print() 79 | print('Total time: {:.3f}'.format(time.time() - program_start_time)) 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask instance folder 57 | instance/ 58 | 59 | # Scrapy stuff: 60 | .scrapy 61 | 62 | # Sphinx documentation 63 | docs/_build/ 64 | 65 | # PyBuilder 66 | target/ 67 | 68 | # IPython Notebook 69 | .ipynb_checkpoints 70 | 71 | # pyenv 72 | .python-version 73 | 74 | # celery beat schedule file 75 | celerybeat-schedule 76 | 77 | # dotenv 78 | .env 79 | 80 | # virtualenv 81 | venv/ 82 | ENV/ 83 | 84 | # Spyder project settings 85 | .spyderproject 86 | 87 | # Rope project settings 88 | .ropeproject 89 | 90 | # ========================= 91 | # Operating System Files 92 | # ========================= 93 | 94 | # OSX 95 | # ========================= 96 | 97 | .DS_Store 98 | .AppleDouble 99 | .LSOverride 100 | 101 | # Thumbnails 102 | ._* 103 | 104 | # Files that might appear in the root of a volume 105 | .DocumentRevisions-V100 106 | .fseventsd 107 | .Spotlight-V100 108 | .TemporaryItems 109 | .Trashes 110 | .VolumeIcon.icns 111 | 112 | # Directories potentially created on remote AFP share 113 | .AppleDB 114 | .AppleDesktop 115 | Network Trash Folder 116 | Temporary Items 117 | .apdisk 118 | 119 | # Windows 120 | # ========================= 121 | 122 | # Windows image file caches 123 | Thumbs.db 124 | ehthumbs.db 125 | 126 | # Folder config file 127 | Desktop.ini 128 | 129 | # Recycle Bin used on file shares 130 | $RECYCLE.BIN/ 131 | 132 | # Windows Installer files 133 | *.cab 134 | *.msi 135 | *.msm 136 | *.msp 137 | 138 | # Windows shortcuts 139 | *.lnk 140 | 141 | 142 | # Othe ignores 143 | src/old_code 144 | *.npz 145 | *.pkl -------------------------------------------------------------------------------- /src/preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import wave 3 | import timeit; program_start_time = timeit.default_timer() 4 | import random; random.seed(int(timeit.default_timer())) 5 | from six.moves import cPickle 6 | 7 | import numpy as np 8 | import scipy.io.wavfile as wav 9 | import matplotlib.pyplot as plt 10 | 11 | from general_tools import * 12 | import features 13 | # https://github.com/jameslyons/python_speech_features 14 | 15 | 16 | ##### SCRIPT META VARIABLES ##### 17 | VERBOSE = True 18 | DEBUG = True 19 | debug_size = 5 20 | # Convert only a reduced dataset 21 | visualize = False 22 | 23 | ##### SCRIPT VARIABLES ##### 24 | train_size = 3696 25 | val_size = 184 26 | test_size = 1344 27 | 28 | data_type = 'float32' 29 | 30 | paths = path_reader('path_toke.txt') 31 | train_source_path = os.path.join(paths[0], 'train') 32 | test_source_path = os.path.join(paths[0], 'test') 33 | target_path = os.path.join(paths[0], 'std_preprocess_26_ch') 34 | 35 | ##### SETUP ##### 36 | if VERBOSE: 37 | print('VERBOSE mode: \tACTIVE') 38 | else: 39 | print('VERBOSE mode: \tDEACTIVE') 40 | 41 | if DEBUG: 42 | print('DEBUG mode: \tACTIVE, only a small dataset will be preprocessed') 43 | target_path += '_DEBUG' 44 | else: 45 | print('DEBUG mode: \tDEACTIVE') 46 | 47 | 48 | 49 | 50 | phonemes = ["b", "bcl", "d", "dcl", "g", "gcl", "p", "pcl", "t", "tcl", "k", "kcl", "dx", "q", "jh", "ch", "s", "sh", "z", "zh", 51 | "f", "th", "v", "dh", "m", "n", "ng", "em", "en", "eng", "nx", "l", "r", "w", "y", 52 | "hh", "hv", "el", "iy", "ih", "eh", "ey", "ae", "aa", "aw", "ay", "ah", "ao", "oy", 53 | "ow", "uh", "uw", "ux", "er", "ax", "ix", "axr", "ax-h", "pau", "epi", "h#"] 54 | # 61 different phonemes 55 | 56 | def get_total_duration(file): 57 | """Get the length of the phoneme file, i.e. the 'time stamp' of the last phoneme""" 58 | for line in reversed(list(open(file))): 59 | [_, val, _] = line.split() 60 | return int(val) 61 | 62 | def find_phoneme (phoneme_idx): 63 | for i in range(len(phonemes)): 64 | if phoneme_idx == phonemes[i]: 65 | return i 66 | print("PHONEME NOT FOUND, NaN CREATED!") 67 | print("\t" + phoneme_idx + " wasn't found!") 68 | return -1 69 | 70 | def create_mfcc(method, filename): 71 | """Perform standard preprocessing, as described by Alex Graves (2012) 72 | http://www.cs.toronto.edu/~graves/preprint.pdf 73 | Output consists of 12 MFCC and 1 energy, as well as the first derivative of these. 74 | [1 energy, 12 MFCC, 1 diff(energy), 12 diff(MFCC) 75 | 76 | method is a dummy input!!""" 77 | 78 | (rate,sample) = wav.read(filename) 79 | 80 | mfcc = features.mfcc(sample, rate, winlen=0.025, winstep=0.01, numcep = 13, nfilt=26, 81 | preemph=0.97, appendEnergy=True) 82 | 83 | derivative = np.zeros(mfcc.shape) 84 | for i in range(1, mfcc.shape[0]-1): 85 | derivative[i, :] = mfcc[i+1, :] - mfcc[i-1, :] 86 | 87 | out = np.concatenate((mfcc, derivative), axis=1) 88 | 89 | return out, out.shape[0] 90 | 91 | def calc_norm_param(X, VERBOSE=False): 92 | """Assumes X to be a list of arrays (of differing sizes)""" 93 | total_len = 0 94 | mean_val = np.zeros(X[0].shape[1]) 95 | std_val = np.zeros(X[0].shape[1]) 96 | for obs in X: 97 | obs_len = obs.shape[0] 98 | mean_val += np.mean(obs,axis=0)*obs_len 99 | std_val += np.std(obs, axis=0)*obs_len 100 | total_len += obs_len 101 | 102 | mean_val /= total_len 103 | std_val /= total_len 104 | 105 | if VERBOSE: 106 | print(total_len) 107 | print(mean_val.shape) 108 | print(' {}'.format(mean_val)) 109 | print(std_val.shape) 110 | print(' {}'.format(std_val)) 111 | 112 | return mean_val, std_val, total_len 113 | 114 | def normalize(X, mean_val, std_val): 115 | for i in range(len(X)): 116 | X[i] = (X[i] - mean_val)/std_val 117 | return X 118 | 119 | def set_type(X, type): 120 | for i in range(len(X)): 121 | X[i] = X[i].astype(type) 122 | return X 123 | 124 | 125 | def preprocess_dataset(source_path, VERBOSE=False, visualize=False): 126 | """Preprocess data, ignoring compressed files and files starting with 'SA'""" 127 | i = 0 128 | X = [] 129 | Y = [] 130 | fig = [] 131 | num_plot = 4 132 | 133 | for dirName, subdirList, fileList in os.walk(source_path): 134 | for fname in fileList: 135 | if not fname.endswith('.PHN') or (fname.startswith("SA")): 136 | continue 137 | 138 | phn_fname = dirName + '\\' + fname 139 | wav_fname = dirName + '\\' + fname[0:-4] + '_.WAV' 140 | 141 | total_duration = get_total_duration(phn_fname) 142 | fr = open(phn_fname) 143 | 144 | 145 | if visualize: 146 | curr_fig = plt.figure(i) 147 | wav_file = wave.open(wav_fname, 'r') 148 | signal = wav_file.readframes(-1) 149 | signal = np.fromstring(signal, 'Int16') 150 | frame_rate = wav_file.getframerate() 151 | 152 | if wav_file.getnchannels() == 2: 153 | print('ONLY MONO FILES') 154 | 155 | x_axis = np.linspace(0, len(signal)/frame_rate, num=len(signal)) 156 | ax1 = plt.subplot(num_plot,1,1) 157 | # plt.title('Original wave data') 158 | plt.plot(x_axis, signal) 159 | ax1.set_xlim([0, len(signal)/frame_rate]) 160 | 161 | plt.ylabel('Original wave data') 162 | plt.tick_params( 163 | axis='both', # changes apply to the axis 164 | which='both', # both major and minor ticks are affected 165 | bottom='off', # ticks along the bottom 166 | top='off', # ticks along the top 167 | right='off', # ticks along the right 168 | left='off', # ticks along the left 169 | labelbottom='off', # labels along the bottom 170 | labelleft='off') # labels along the top 171 | 172 | # plt.gca().axes.get_xaxis().set_visible(False) 173 | 174 | 175 | X_val, total_frames = create_mfcc('DUMMY', wav_fname) 176 | total_frames = int(total_frames) 177 | 178 | X.append(X_val) 179 | if visualize: 180 | plt.subplot(num_plot,1,3) 181 | plt.imshow(X_val.T, interpolation='nearest', aspect='auto') 182 | # plt.axis('off') 183 | # plt.title('Preprocessed data') 184 | 185 | plt.ylabel('Preprocessed data') 186 | plt.tick_params( 187 | axis='both', # changes apply to the axis 188 | which='both', # both major and minor ticks are affected 189 | bottom='off', # ticks along the bottom 190 | top='off', # ticks along the top 191 | right='off', # ticks along the right 192 | left='off', # ticks along the left 193 | labelbottom='off', # labels along the bottom 194 | labelleft='off') # labels along the top 195 | 196 | 197 | y_val = np.zeros(total_frames) - 1 198 | start_ind = 0 199 | for line in fr: 200 | [start_time, end_time, phoneme] = line.rstrip('\n').split() 201 | start_time = int(start_time) 202 | end_time = int(end_time) 203 | 204 | phoneme_num = find_phoneme(phoneme) 205 | end_ind = np.round((end_time)/total_duration*total_frames) 206 | y_val[start_ind:end_ind] = phoneme_num 207 | 208 | start_ind = end_ind 209 | fr.close() 210 | 211 | if -1 in y_val: 212 | print('WARNING: -1 detected in TARGET') 213 | print(y_val) 214 | 215 | Y.append(y_val.astype('int32')) 216 | if visualize: 217 | plt.subplot(num_plot,1,2) 218 | plt.imshow((y_val.T, ), aspect='auto') 219 | 220 | plt.ylabel('Lables') 221 | plt.tick_params( 222 | axis='both', # changes apply to the axis 223 | which='both', # both major and minor ticks are affected 224 | bottom='off', # ticks along the bottom 225 | top='off', # ticks along the top 226 | right='off', # ticks along the right 227 | left='off', # ticks along the left 228 | labelbottom='off', # labels along the bottom 229 | labelleft='off') # labels along the top 230 | 231 | if visualize: 232 | plt.subplots_adjust(hspace=0.01) 233 | plt.tight_layout() 234 | # plt.show() 235 | fig.append(curr_fig) 236 | 237 | i+=1 238 | if VERBOSE: 239 | print() 240 | print('({}) create_target_vector: {}'.format(i, phn_fname[:-4])) 241 | print('type(X_val): \t\t {}'.format(type(X_val))) 242 | print('X_val.shape: \t\t {}'.format(X_val.shape)) 243 | print('type(X_val[0][0]):\t {}'.format(type(X_val[0][0]))) 244 | else: 245 | print(i, end=' ', flush=True) 246 | if i >= debug_size and DEBUG: 247 | break 248 | 249 | if i >= debug_size and DEBUG: 250 | break 251 | print() 252 | return X, Y, fig 253 | 254 | 255 | 256 | ##### PREPROCESSING ##### 257 | print() 258 | print('Creating Validation index ...') 259 | val_idx = random.sample(range(0, train_size), val_size) 260 | val_idx = [int(i) for i in val_idx] 261 | # ensure that the validation set isn't empty 262 | if DEBUG: 263 | val_idx[0] = 0 264 | val_idx[1] = 1 265 | 266 | 267 | print('Preprocessing data ...') 268 | print(' This will take a while') 269 | X_train_all, y_train_all, _ = preprocess_dataset(train_source_path, 270 | VERBOSE=False, visualize=False) 271 | X_test, y_test, test_figs = preprocess_dataset(test_source_path, 272 | VERBOSE=False, visualize=visualize) 273 | # figs = list(map(plt.figure, plt.get_fignums())) 274 | 275 | print(' Preprocessing changesomplete') 276 | 277 | if VERBOSE: 278 | print() 279 | print('Type and shape/len of X_train_all') 280 | print('type(X_train_all): {}'.format(type(X_train_all))) 281 | print('type(X_train_all[0]): {}'.format(type(X_train_all[0]))) 282 | print('type(X_train_all[0][0]): {}'.format(type(X_train_all[0][0]))) 283 | print('type(X_train_all[0][0][0]): {}'.format(type(X_train_all[0][0][0]))) 284 | 285 | 286 | print('Separating validation and training set ...') 287 | X_train = []; X_val = [] 288 | y_train = []; y_val = [] 289 | for i in range(len(X_train_all)): 290 | if i in val_idx: 291 | X_val.append(X_train_all[i]) 292 | y_val.append(y_train_all[i]) 293 | else: 294 | X_train.append(X_train_all[i]) 295 | y_train.append(y_train_all[i]) 296 | 297 | if VERBOSE: 298 | print() 299 | print('Length of train, val, test') 300 | print(len(X_train)) 301 | print(len(y_train)) 302 | 303 | print(len(X_val)) 304 | print(len(y_val)) 305 | 306 | print(len(X_test)) 307 | print(len(y_test)) 308 | 309 | if VERBOSE: 310 | print() 311 | print('Type of train') 312 | print(type(X_train)) 313 | print(type(y_train)) 314 | print(type(X_train[0]), X_train[0].shape) 315 | print(type(y_train[0]), y_train[0].shape) 316 | 317 | 318 | print() 319 | print('Normalizing data ...') 320 | print(' Each channel mean=0, sd=1 ...') 321 | 322 | mean_val, std_val, _ = calc_norm_param(X_train) 323 | 324 | X_train = normalize(X_train, mean_val, std_val) 325 | X_val = normalize(X_val, mean_val, std_val) 326 | X_test = normalize(X_test, mean_val, std_val) 327 | 328 | X_train = set_type(X_train, data_type) 329 | X_val = set_type(X_val, data_type) 330 | X_test = set_type(X_test, data_type) 331 | 332 | if visualize == True: 333 | for i in range(debug_size): 334 | plt.figure(i) 335 | plt.subplot(4,1,4) 336 | 337 | plt.imshow(X_test[i].T, interpolation='nearest', aspect='auto') 338 | # plt.axis('off') 339 | # plt.title('Preprocessed data') 340 | 341 | plt.ylabel('Normalized data') 342 | plt.tick_params( 343 | axis='both', # changes apply to the axis 344 | which='both', # both major and minor ticks are affected 345 | bottom='off', # ticks along the bottom 346 | top='off', # ticks along the top 347 | right='off', # ticks along the right 348 | left='off', # ticks along the left 349 | labelbottom='off', # labels along the bottom 350 | labelleft='off') # labels along the top 351 | 352 | plt.show() 353 | 354 | 355 | 356 | 357 | print('Saving data ...') 358 | print(' ', target_path) 359 | with open(target_path + '.pkl', 'wb') as cPickle_file: 360 | cPickle.dump( 361 | [X_train, y_train, X_val, y_val, X_test, y_test], 362 | cPickle_file, 363 | protocol=cPickle.HIGHEST_PROTOCOL) 364 | 365 | print('Preprocessing complete!') 366 | print() 367 | 368 | 369 | 370 | print('Total time: {:.3f}'.format(timeit.default_timer() - program_start_time)) 371 | 372 | 373 | 374 | 375 | 376 | 377 | 378 | 379 | 380 | 381 | -------------------------------------------------------------------------------- /src/RNN_tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import theano 3 | import theano.tensor as T 4 | import lasagne 5 | import lasagne.layers as L 6 | import time 7 | 8 | from six.moves import cPickle 9 | 10 | 11 | def iterate_minibatches(inputs, targets, batch_size, shuffle=False): 12 | """ 13 | Helper function that returns an iterator over the training data of a particular 14 | size, optionally in a random order. 15 | 16 | For big data sets you can load numpy arrays as memory-mapped files 17 | (numpy.load(..., mmap_mode='r')) 18 | 19 | This function a slight modification of: 20 | http://lasagne.readthedocs.org/en/latest/user/tutorial.html 21 | """ 22 | assert len(inputs) == len(targets) 23 | 24 | if shuffle: 25 | indices = np.arange(len(inputs)) 26 | np.random.shuffle(indices) 27 | 28 | for start_idx in range(0, len(inputs) - batch_size + 1, batch_size): 29 | if shuffle: 30 | excerpt = indices[start_idx:start_idx + batch_size] 31 | else: 32 | # excerpt = slice(start_idx, start_idx + batch_size) 33 | excerpt = range(start_idx, start_idx + batch_size, 1) 34 | 35 | input_iter = [inputs[i] for i in excerpt] 36 | target_iter= [targets[i] for i in excerpt] 37 | yield input_iter, target_iter 38 | # yield inputs[excerpt], targets[excerpt] 39 | 40 | 41 | class NeuralNetwork: 42 | network = None 43 | training_fn = None 44 | best_param = None 45 | best_error = 100 46 | curr_epoch, best_epoch = 0, 0 47 | 48 | network_train_info = [[], [], []] 49 | # [[Train], [val], [test]] 50 | 51 | def build_RNN(self, batch_size=1, input_size=26, n_hidden=275, num_output_units=61, 52 | weight_init=0.1, activation_fn=lasagne.nonlinearities.sigmoid, 53 | seed=int(time.time()), debug=False): 54 | np.random.seed(seed) 55 | # seed np for weight initialization 56 | 57 | l_in = L.InputLayer(shape=(batch_size, None, input_size)) 58 | # l_in = L.InputLayer(shape=(None, None, input_size)) 59 | # (batch_size, max_time_steps, n_features_1, n_features_2, ...) 60 | # Only stochastic gradient descent 61 | if debug: 62 | get_l_in = theano.function([l_in.input_var], L.get_output(l_in)) 63 | l_in_val = get_l_in(X) 64 | print('output size:', end='\t'); print(Y.shape) 65 | print('input size:', end='\t'); print(X[0].shape) 66 | print('l_in size:', end='\t'); print(l_in_val.shape) 67 | 68 | l_rnn = L.recurrent.RecurrentLayer( 69 | l_in, num_units=n_hidden, 70 | nonlinearity=activation_fn, 71 | W_in_to_hid=lasagne.init.Uniform(weight_init), 72 | W_hid_to_hid=lasagne.init.Uniform(weight_init), 73 | b=lasagne.init.Constant(0.), 74 | hid_init=lasagne.init.Constant(0.), 75 | learn_init=False) 76 | if debug: 77 | get_l_rnn = theano.function([l_in.input_var], L.get_output(l_rnn)) 78 | l_rnn_val = get_l_rnn(X) 79 | print('l_rnn size:', end='\t'); print(l_rnn_val.shape) 80 | 81 | 82 | l_reshape = L.ReshapeLayer(l_rnn, (-1, n_hidden)) 83 | if debug: 84 | get_l_reshape = theano.function([l_in.input_var], L.get_output(l_reshape)) 85 | l_reshape_val = get_l_reshape(X) 86 | print('l_reshape size:', end='\t'); print(l_reshape_val.shape) 87 | 88 | 89 | l_out = L.DenseLayer(l_reshape, num_units=num_output_units, 90 | nonlinearity=T.nnet.softmax) 91 | 92 | self.network = l_out 93 | 94 | 95 | def __init__(self, architecture, **kwargs): 96 | if architecture == 'RNN': 97 | self.build_RNN(**kwargs) 98 | else: 99 | print("ERROR: Invalid argument: The valid architecture arguments are: 'RNN'") 100 | 101 | 102 | # def save_network(self, network_name): 103 | # try: 104 | # f = open(network_name, 'wb') 105 | # save = { 106 | # 'model_param': *L.get_all_param_values(self.network), 107 | # 'train_labels': train_labels, 108 | # 'valid_dataset': valid_dataset, 109 | # 'valid_labels': valid_labels, 110 | # 'test_dataset': test_dataset, 111 | # 'test_labels': test_labels, 112 | # } 113 | # cPickle.dump(save, f, cPickle.HIGHEST_PROTOCOL) 114 | # f.close() 115 | # except Exception as e: 116 | # print('Unable to save data to {} : {}'.format(pickle_file, e)) 117 | # raise 118 | 119 | 120 | # def load_network(self, network_name): 121 | # try: 122 | # print() 123 | # except FileNotFoundError: 124 | # print('File: {} not found. Nothing loaded'.format(model_name)) 125 | 126 | def use_best_param(self): 127 | lasagne.layers.set_all_param_values(self.network, self.best_param) 128 | self.curr_epoch = self.best_epoch 129 | # Remove the network_train_info enries newer than self.best_epoch 130 | del self.network_train_info[0][self.best_epoch:] 131 | del self.network_train_info[1][self.best_epoch:] 132 | del self.network_train_info[2][self.best_epoch:] 133 | 134 | 135 | def load_model(self, model_name): 136 | if self.network is not None: 137 | try: 138 | with np.load(model_name) as f: 139 | param_values = [f['arr_%d' % i] for i in range(len(f.files))] 140 | # param_values[0] = param_values[0].astype('float32') 141 | param_values = [param_values[i].astype('float32') for i in range(len(param_values))] 142 | lasagne.layers.set_all_param_values(self.network, param_values) 143 | except FileNotFoundError: 144 | print('Model: {} not found. No weights loaded'.format(model_name)) 145 | else: 146 | print('You must build the network before loading the weights.') 147 | 148 | 149 | def save_model(self, model_name): 150 | np.savez(model_name, *L.get_all_param_values(self.network)) 151 | 152 | 153 | def build_functions(self, LEARNING_RATE=1e-5, MOMENTUM=0.9, debug=False): 154 | target_var = T.ivector('targets') 155 | 156 | # Get the first layer of the network 157 | l_in = L.get_all_layers(self.network)[0] 158 | 159 | network_output = L.get_output(self.network) 160 | 161 | # Retrieve all trainable parameters from the network 162 | all_params = L.get_all_params(self.network, trainable=True) 163 | 164 | # loss = T.mean(lasagne.objectives.categorical_crossentropy(network_output, target_var)) 165 | loss = T.sum(lasagne.objectives.categorical_crossentropy(network_output, target_var)) 166 | 167 | # use Stochastic Gradient Descent with nesterov momentum to update parameters 168 | updates = lasagne.updates.momentum(loss, all_params, 169 | learning_rate = LEARNING_RATE, 170 | momentum = MOMENTUM) 171 | 172 | # Function to determine the number of correct classifications 173 | accuracy = T.mean(T.eq(T.argmax(network_output, axis=1), target_var), 174 | dtype=theano.config.floatX) 175 | 176 | # Function to get the output of the network 177 | output_fn = theano.function([l_in.input_var], network_output, name='output_fn') 178 | if debug: 179 | l_out_val = output_fn(X) 180 | print('l_out size:', end='\t'); print(l_out_val.shape, end='\t'); 181 | print('min/max: [{:.2f},{:.2f}]'.format(l_out_val.min(), l_out_val.max())) 182 | 183 | argmax_fn = theano.function([l_in.input_var], [T.argmax(network_output, axis=1)], 184 | name='argmax_fn') 185 | if debug: 186 | print('argmax_fn') 187 | print(type(argmax_fn(X)[0])) 188 | print(argmax_fn(X)[0].shape) 189 | 190 | # Function implementing one step of gradient descent 191 | train_fn = theano.function([l_in.input_var, target_var], [loss, accuracy], 192 | updates=updates, name='train_fn') 193 | 194 | # Function calculating the loss and accuracy 195 | validate_fn = theano.function([l_in.input_var, target_var], [loss, accuracy], 196 | name='validate_fn') 197 | if debug: 198 | print(type(train_fn(X, Y))) 199 | # print('loss: {:.3f}'.format( float(train_fn(X, Y)))) 200 | # print('accuracy: {:.3f}'.format( float(validate_fn(X, Y)[1]) )) 201 | 202 | self.training_fn = output_fn, argmax_fn, train_fn, validate_fn 203 | 204 | 205 | def create_confusion(self, X, y, debug=False): 206 | argmax_fn = self.training_fn[1] 207 | 208 | y_pred = [] 209 | for X_obs in X: 210 | for x in argmax_fn(X_obs): 211 | for j in x: 212 | y_pred.append(j) 213 | 214 | y_actu = [] 215 | for Y in y: 216 | for y in Y: 217 | y_actu.append(y) 218 | 219 | conf_img = np.zeros([61, 61]) 220 | assert (len(y_pred) == len(y_actu)) 221 | 222 | for i in range(len(y_pred)): 223 | row_idx = y_actu[i] 224 | col_idx = y_pred[i] 225 | conf_img[row_idx, col_idx] += 1 226 | 227 | return conf_img, y_pred, y_actu 228 | 229 | def create_learning_curves(self): 230 | pass 231 | 232 | def visualize_training(self, learning_curves, confusion): 233 | pass 234 | 235 | def train(self, dataset, save_name='Best_model', num_epochs=100, batch_size=1, 236 | comput_confusion=False, debug=False): 237 | """Curently one batch_size=1 is supported""" 238 | 239 | 240 | X_train, y_train, X_val, y_val, X_test, y_test = dataset 241 | output_fn, argmax_fn, train_fn, validate_fn = self.training_fn 242 | 243 | if debug: 244 | print('X_train', end='\t\t') 245 | print(type(X_train), end='\t'); print(len(X_train)) 246 | print('X_train[0]', end='\t') 247 | print(type(X_train[0]), end='\t'); print(X_train[0].shape) 248 | print('X_train[0][0]', end='\t') 249 | print(type(X_train[0][0]), end='\t');print(X_train[0][0].shape) 250 | print('X_train[0][0][0]', end='\t') 251 | print(type(X_train[0][0][0]), end='\t');print(X_train[0][0][0].shape) 252 | 253 | print('y_train', end='\t\t') 254 | print(type(y_train), end='\t'); print(len(X_train)) 255 | print('y_train[0]', end='\t') 256 | print(type(y_train[0]), end='\t'); print(y_train[0].shape) 257 | print('y_train[0][0]', end='\t') 258 | print(type(y_train[0][0]), end='\t');print(y_train[0][0].shape) 259 | print() 260 | 261 | # Initiate some vectors used for tracking performance 262 | train_error = np.zeros([num_epochs]) 263 | train_accuracy = np.zeros([num_epochs]) 264 | train_batches = np.zeros([num_epochs]) 265 | validation_error = np.zeros([num_epochs]) 266 | validation_accuracy = np.zeros([num_epochs]) 267 | validation_batches = np.zeros([num_epochs]) 268 | test_error = np.zeros([num_epochs]) 269 | test_accuracy = np.zeros([num_epochs]) 270 | test_batches = np.zeros([num_epochs]) 271 | confusion_matrices = [] 272 | 273 | for epoch in range(num_epochs): 274 | self.curr_epoch += 1 275 | epoch_time = time.time() 276 | 277 | # Full pass over the training set 278 | for inputs, targets in iterate_minibatches(X_train, y_train, batch_size, shuffle=True): 279 | for i in range(len(inputs)): 280 | # TODO: this for loop should not excist 281 | 282 | if debug: 283 | print(type(inputs), type(targets)) 284 | print(type(inputs[i]), type(targets[i])) 285 | error, accuracy = train_fn([inputs[i]], targets[i]) 286 | 287 | train_error[epoch] += error 288 | train_accuracy[epoch] += accuracy 289 | train_batches[epoch] += 1 290 | 291 | # Full pass over the validation set 292 | for inputs, targets in iterate_minibatches(X_val, y_val, batch_size, shuffle=False): 293 | for i in range(len(inputs)): 294 | error, accuracy = validate_fn([inputs[i]], targets[i]) 295 | 296 | validation_error[epoch] += error 297 | validation_accuracy[epoch] += accuracy 298 | validation_batches[epoch] += 1 299 | 300 | # Full pass over the test set 301 | for inputs, targets in iterate_minibatches(X_test, y_test, batch_size, shuffle=False): 302 | for i in range(len(inputs)): 303 | error, accuracy = validate_fn([inputs[i]], targets[i]) 304 | 305 | test_error[epoch] += error 306 | test_accuracy[epoch] += accuracy 307 | test_batches[epoch] += 1 308 | 309 | 310 | # Print epoch summary 311 | train_epoch_error = (100 - train_accuracy[epoch] 312 | / train_batches[epoch] * 100) 313 | val_epoch_error = (100 - validation_accuracy[epoch] 314 | / validation_batches[epoch] * 100) 315 | test_epoch_error = (100 - test_accuracy[epoch] 316 | / test_batches[epoch] * 100) 317 | 318 | self.network_train_info[0].append(train_epoch_error) 319 | self.network_train_info[1].append(val_epoch_error) 320 | self.network_train_info[2].append(test_epoch_error) 321 | 322 | 323 | print("Epoch {} of {} took {:.3f}s.".format( 324 | epoch + 1, num_epochs, time.time() - epoch_time), end=' ') 325 | if val_epoch_error < self.best_error: 326 | self.best_error = val_epoch_error 327 | self.best_epoch = self.curr_epoch 328 | self.best_param = L.get_all_param_values(self.network) 329 | print("New best model found!") 330 | else: 331 | print() 332 | 333 | # print(" New best model found!", end=" ") 334 | # if save_name is not None: 335 | # print("Model saved as " + save_name + '.npz') 336 | # self.save_model(save_name + '.npz') 337 | # else: 338 | # print() 339 | 340 | 341 | print(" training loss:\t{:.6f}".format( 342 | train_error[epoch] / train_batches[epoch]), end='\t') 343 | print("train error:\t\t{:.6f} %".format(train_epoch_error)) 344 | 345 | print(" validation loss:\t{:.6f}".format( 346 | validation_error[epoch] / validation_batches[epoch]), end='\t') 347 | print("validation error:\t{:.6f} %".format(val_epoch_error )) 348 | 349 | print(" test loss:\t\t{:.6f}".format( 350 | test_error[epoch] / test_batches[epoch]), end='\t') 351 | print("test error:\t\t{:.6f} %".format(test_epoch_error)) 352 | 353 | # if comput_confusion: 354 | # confusion_matrices.append(create_confusion(X_val, y_val)[0]) 355 | # print(' Confusion matrix computed') 356 | print() 357 | 358 | 359 | 360 | # with open(save_name + '_var.pkl', 'wb') as cPickle_file: 361 | # cPickle.dump( 362 | # [network_train_info], 363 | # cPickle_file, 364 | # protocol=cPickle.HIGHEST_PROTOCOL) 365 | 366 | 367 | # if comput_confusion: 368 | # with open(save_name + '_conf.pkl', 'wb') as cPickle_file: 369 | # cPickle.dump( 370 | # [confusion_matrices], 371 | # cPickle_file, 372 | # protocol=cPickle.HIGHEST_PROTOCOL) 373 | 374 | 375 | --------------------------------------------------------------------------------