├── LICENSE
├── README.md
├── deep_learning_tips_from_the_road.pdf
└── graphics_code
    ├── cvae
        ├── cvae.py
        ├── cvae_code.gif
        ├── cvae_reconstruction.png
        ├── cvae_style.png
        ├── flying_cvae.py
        └── kdl_template.py
    ├── kdl_template.py
    └── vae
        ├── flying_vae.py
        ├── kdl_template.py
        ├── vae.py
        ├── vae_code.gif
        └── vae_reconstruction.png

/LICENSE:
--------------------------------------------------------------------------------
Copyright (c) 2015, Kyle Kastner
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of SciPy2015 nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
PDF version of slides included.
Slides also uploaded to https://speakerdeck.com/kastnerkyle

Video of talk here:
https://www.youtube.com/watch?v=TBBtOeY2Q78

To run the code in the graphics directory, you will need Theano.
I have not run this on CPU yet, but it runs pretty quickly on GPU.
To run the code, simply go to ``graphics_code/vae`` or ``graphics_code/cvae``

``THEANO_FLAGS="floatX=float32,device=gpu,mode=FAST_RUN" python vae.py``

or

``THEANO_FLAGS="floatX=float32,device=gpu,mode=FAST_RUN" python cvae.py``

will start training the model. After training,

``THEANO_FLAGS="floatX=float32,device=gpu,mode=FAST_RUN" python flying_vae.py serialized_vae.pkl``

or

``THEANO_FLAGS="floatX=float32,device=gpu,mode=FAST_RUN" python flying_cvae.py serialized_cvae.pkl``

will generate plots for the saved model.
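
The train-then-plot sequence above can also be scripted. This is only a minimal sketch, assuming it is launched from the repository root and that `THEANO_FLAGS` is passed through the environment exactly as in the shell commands. `build_command` and `run_step` are hypothetical helper names and are not part of this repository.

```python
import os
import subprocess
import sys

# Same flags as the shell commands in this README.
THEANO_FLAGS = "floatX=float32,device=gpu,mode=FAST_RUN"


def build_command(script, *args):
    """Return the argv list for running `script` with the current Python."""
    return [sys.executable, script] + list(args)


def run_step(model_dir, script, *args, dry_run=True):
    """Hypothetical helper: run one script inside `model_dir`.

    With dry_run=True (the default in this sketch) nothing is executed and
    the command list is returned, so the example is safe to try anywhere.
    """
    cmd = build_command(script, *args)
    if dry_run:
        return cmd
    env = dict(os.environ, THEANO_FLAGS=THEANO_FLAGS)
    return subprocess.call(cmd, cwd=model_dir, env=env)


# Train first, then plot from the pickle file named in the README.
train_cmd = run_step("graphics_code/vae", "vae.py")
plot_cmd = run_step("graphics_code/vae", "flying_vae.py", "serialized_vae.pkl")
print(train_cmd[1:], plot_cmd[1:])
```

Passing `dry_run=False` would actually start training, which needs Theano and, with the flags above, a GPU.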

Variational Autoencoder
=======================

This variational autoencoder follows the general procedure described in
[Auto-Encoding Variational Bayes, Kingma and Welling](http://arxiv.org/abs/1312.6114)

Another paper describes a similar concept, [Stochastic Backpropagation and Approximate Inference in Deep Generative Models, Rezende, Mohamed, and Wierstra](http://arxiv.org/abs/1401.4082).

VAE Code Walking
----------------
![walking_code](graphics_code/vae/vae_code.gif)

VAE Reconstruction
------------------
![reconstruct](graphics_code/vae/vae_reconstruction.png)


Conditional Variational Autoencoder
===================================
This conditional variational autoencoder follows a similar procedure to that described in
[Semi-supervised Learning with Deep Generative Models, Kingma, Rezende, Mohamed, and Welling](http://arxiv.org/abs/1406.5298).

Conditional VAE Code Walking With Conditional Control
-----------------------------------------------------
![walking_code](graphics_code/cvae/cvae_code.gif)

Holding Style (Z) Fixed and Changing Conditional y
--------------------------------------------------
![reconstruct](graphics_code/cvae/cvae_style.png)

Conditional VAE Reconstruction and Prediction
---------------------------------------------
![reconstruct](graphics_code/cvae/cvae_reconstruction.png)


Linked content
==============
sklearn-theano, a scikit-learn compatible library for using pretrained networks http://sklearn-theano.github.io/

My research code https://github.com/kastnerkyle/santa_barbaria

Neural network tutorial by @NewMu / Alec Radford https://github.com/Newmu/Theano-Tutorials

Theano Deep Learning Tutorials http://deeplearning.net/tutorial/

--------------------------------------------------------------------------------
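The README sections above point to the papers rather than spelling out the objective. As a rough illustration, the NumPy sketch below shows the three pieces a VAE cost of this kind is usually built from: a reparameterized Gaussian sample, the closed-form KL term against a unit Gaussian prior, and a Bernoulli reconstruction loss. The function names echo helpers called in the scripts, but the bodies are assumptions written for this example, not the repository's Theano implementations. It also assumes `log_sigma` is the log of the standard deviation.

```python
import numpy as np


def gaussian_log_sample(mu, log_sigma, random_state):
    """Reparameterized draw z = mu + sigma * eps, with eps ~ N(0, I).

    Assumption: log_sigma is log(standard deviation), not log(variance).
    """
    eps = random_state.randn(*mu.shape)
    return mu + np.exp(log_sigma) * eps


def gaussian_log_kl(mu, log_sigma):
    """KL(N(mu, sigma^2) || N(0, I)) per example, summed over code dims."""
    return -0.5 * np.sum(1. + 2. * log_sigma - mu ** 2
                         - np.exp(2. * log_sigma), axis=1)


def binary_crossentropy_nll(pred, target, eps=1e-7):
    """Bernoulli negative log-likelihood per example (illustrative only)."""
    pred = np.clip(pred, eps, 1. - eps)
    return -np.sum(target * np.log(pred)
                   + (1. - target) * np.log(1. - pred), axis=1)


random_state = np.random.RandomState(1999)
mu = np.zeros((4, 3))
log_sigma = np.zeros((4, 3))
z = gaussian_log_sample(mu, log_sigma, random_state)
kl = gaussian_log_kl(mu, log_sigma)
nll = binary_crossentropy_nll(np.full((1, 2), 0.5), np.array([[1., 0.]]))
# Quantity to minimize: reconstruction term plus KL term, averaged.
cost = nll.mean() + kl.mean()
```

With `mu` and `log_sigma` at zero the approximate posterior equals the prior, so the KL term vanishes and only the reconstruction term contributes to `cost`.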
/deep_learning_tips_from_the_road.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kastnerkyle/SciPy2015/e0870cb9b3cfeca60c9ee2e2b6b805f266aac5c3/deep_learning_tips_from_the_road.pdf
--------------------------------------------------------------------------------
/graphics_code/cvae/cvae.py:
--------------------------------------------------------------------------------
from kdl_template import *

train, valid, test = fetch_binarized_mnist()
X = train[0].astype(theano.config.floatX)
y = convert_to_one_hot(train[1], n_classes=10)

# graph holds information necessary to build layers from parents
graph = OrderedDict()
X_sym, y_sym = add_datasets_to_graph([X, y], ["X", "y"], graph)
# random state so script is deterministic
random_state = np.random.RandomState(1999)

minibatch_size = 100
n_code = 100
n_targets = 10
n_enc_layer = [200, 200, 200, 200]
n_dec_layer = [200, 200, 200]
width = 28
height = 28
n_input = width * height

# q(y | x)
y_l1_enc = softplus_layer([X_sym], graph, 'y_l1_enc', n_enc_layer[0],
                          random_state)
y_l2_enc = softmax_layer([y_l1_enc], graph, 'y_l2_enc', n_enc_layer[1],
                         random_state)
y_pred = softmax_layer([y_l2_enc], graph, 'y_pred_enc', n_targets,
                       random_state)

# partial q(z | x)
x_l1_enc = softplus_layer([X_sym], graph, 'x_l1_enc', n_enc_layer[0],
                          random_state)
x_l2_enc = softplus_layer([x_l1_enc], graph, 'x_l2_enc', n_enc_layer[1],
                          random_state)


# combined q(y | x) and partial q(z | x) for q(z | y, x)
l3_enc = softplus_layer([x_l2_enc, y_pred], graph, 'l3_enc', n_enc_layer[2],
                        random_state)
l4_enc = softplus_layer([l3_enc], graph, 'l4_enc', n_enc_layer[3],
                        random_state)
code_mu = linear_layer([l4_enc], graph, 'code_mu', n_code, random_state)
code_log_sigma = linear_layer([l4_enc], graph, 'code_log_sigma', n_code,
                              random_state)
samp = gaussian_log_sample_layer([code_mu], [code_log_sigma], 'samp',
                                 random_state)

# decode path aka p for labeled data
l1_dec = softplus_layer([samp, y_sym], graph, 'l1_dec', n_dec_layer[0],
                        random_state)
l2_dec = softplus_layer([l1_dec], graph, 'l2_dec', n_dec_layer[1], random_state)
l3_dec = softplus_layer([l2_dec], graph, 'l3_dec', n_dec_layer[2], random_state)
out = sigmoid_layer([l3_dec], graph, 'out', n_input, random_state)

# Components of the cost
nll = binary_crossentropy_nll(out, X_sym).mean()
ent = binary_entropy(y_pred).mean()
kl = gaussian_log_kl([code_mu], [code_log_sigma], 'kl').mean()

# Junk when in unlabeled mode
err = categorical_crossentropy_nll(y_pred, y_sym).mean()

# log p(x) = -nll so swap sign
# want to minimize cost in optimization so multiply by -1
base_cost = -1 * (-nll - kl)

# -log q(y | x) is nll already
alpha = .1
cost = base_cost + alpha * err

params, grads = get_params_and_grads(graph, cost)
learning_rate = 0.0003
opt = adam(params)
updates = opt.updates(params, grads, learning_rate)

# Checkpointing
save_path = "serialized_cvae.pkl"
if not os.path.exists(save_path):
    fit_function = theano.function([X_sym, y_sym], [nll, kl, nll + kl],
                                   updates=updates)
    predict_function = theano.function([X_sym], [y_pred])
    encode_function = theano.function([X_sym, y_sym], [code_mu, code_log_sigma],
                                      on_unused_input='warn')
    # Need both due to tensor.switch, but only one should ever be used
    decode_function = theano.function([samp, y_sym], [out])
    checkpoint_dict = {}
    checkpoint_dict["fit_function"] = fit_function
    checkpoint_dict["predict_function"] = predict_function
    checkpoint_dict["encode_function"] = encode_function
    checkpoint_dict["decode_function"] = decode_function
    previous_epoch_results = None
else:
    checkpoint_dict = load_checkpoint(save_path)
    fit_function = checkpoint_dict["fit_function"]
    predict_function = checkpoint_dict["predict_function"]
    encode_function = checkpoint_dict["encode_function"]
    decode_function = checkpoint_dict["decode_function"]
    previous_epoch_results = checkpoint_dict["previous_epoch_results"]


def status_func(status_number, epoch_number, epoch_results):
    checkpoint_status_func(save_path, checkpoint_dict, epoch_results)

epoch_results = iterate_function(fit_function, [X, y], minibatch_size,
                                 list_of_output_names=["nll", "kl", "cost"],
                                 n_epochs=2000,
                                 status_func=status_func,
                                 previous_epoch_results=previous_epoch_results,
                                 shuffle=True,
                                 random_state=random_state)

--------------------------------------------------------------------------------
/graphics_code/cvae/cvae_code.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kastnerkyle/SciPy2015/e0870cb9b3cfeca60c9ee2e2b6b805f266aac5c3/graphics_code/cvae/cvae_code.gif
--------------------------------------------------------------------------------
/graphics_code/cvae/cvae_reconstruction.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kastnerkyle/SciPy2015/e0870cb9b3cfeca60c9ee2e2b6b805f266aac5c3/graphics_code/cvae/cvae_reconstruction.png
--------------------------------------------------------------------------------
/graphics_code/cvae/cvae_style.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kastnerkyle/SciPy2015/e0870cb9b3cfeca60c9ee2e2b6b805f266aac5c3/graphics_code/cvae/cvae_style.png
--------------------------------------------------------------------------------
/graphics_code/cvae/flying_cvae.py:
--------------------------------------------------------------------------------
from kdl_template import *
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("saved_functions_file",
                    help="Saved pickle file from vae training")
parser.add_argument("--seed", "-s",
                    help="random seed for path calculation",
                    action="store", default=1979, type=int)

args = parser.parse_args()
if not os.path.exists(args.saved_functions_file):
    raise ValueError("Please provide a valid path for saved pickle file!")

checkpoint_dict = load_checkpoint(args.saved_functions_file)
encode_function = checkpoint_dict["encode_function"]
decode_function = checkpoint_dict["decode_function"]
predict_function = checkpoint_dict["predict_function"]

random_state = np.random.RandomState(args.seed)
train, valid, test = fetch_binarized_mnist()
# visualize against validation so we aren't cheating
X = valid[0].astype(theano.config.floatX)
y = valid[1].astype("int32")

# number of samples
n_plot_samples = 6
n_classes = 10
# MNIST dimensions
width = 28
height = 28
# Get random data samples
ind = np.arange(len(X))
random_state.shuffle(ind)
sample_X = X[ind[:n_plot_samples]]
sample_y = convert_to_one_hot(y[ind[:n_plot_samples]], n_classes=n_classes)


def gen_samples(X, y):
    mu, log_sig = encode_function(X, y)
    # No noise at test time - repeat y twice because y_pred is needed for Theano
    # But it is not used unless y_sym is all -1
    out, = decode_function(mu + np.exp(log_sig), y)
    return out

# VAE specific plotting
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

all_pred_y, = predict_function(X)
all_pred_y = np.argmax(all_pred_y, axis=1)
accuracy = np.mean(all_pred_y.ravel() == y.ravel())

f, axarr = plt.subplots(n_plot_samples, 2)
n_correct_to_show = n_plot_samples // 2
n_incorrect_to_show = n_plot_samples - n_correct_to_show

correct_ind = np.where(all_pred_y == y)[0]
incorrect_ind = np.where(all_pred_y != y)[0]
random_state.shuffle(correct_ind)
random_state.shuffle(incorrect_ind)
c = correct_ind[:n_correct_to_show]
i = incorrect_ind[:n_incorrect_to_show]

X_corr = X[c]
X_incorr = X[i]
X_stack = np.vstack((X_corr, X_incorr))
y_corr = convert_to_one_hot(y[c], n_classes=10)
y_incorr = convert_to_one_hot(y[i], n_classes=10)
y_stack = np.vstack((y_corr, y_incorr))

generated_X = gen_samples(X_stack, y_stack)
predicted_y = convert_to_one_hot(np.hstack((all_pred_y[c], all_pred_y[i])),
                                 n_classes=10)

for n, (X_i, y_i, sx_i, sy_i) in enumerate(zip(X_stack, y_stack,
                                               generated_X, predicted_y)):
    axarr[n, 0].matshow(X_i.reshape(width, height), cmap="gray")
    axarr[n, 1].matshow(sx_i.reshape(width, height), cmap="gray")
    axarr[n, 0].axis('off')
    axarr[n, 1].axis('off')

    y_a = np.argmax(y_i)
    sy_a = np.argmax(sy_i)
    axarr[n, 0].text(0, 7, str(y_a), color='green')
    if y_a == sy_a:
        axarr[n, 1].text(0, 7, str(sy_a), color='green')
    else:
        axarr[n, 1].text(0, 7, str(sy_a), color='red')

f.suptitle("Validation accuracy: %s" % str(accuracy))
plt.savefig('cvae_reconstruction.png')
plt.close()

# Style plotting
f, axarr = plt.subplots(n_plot_samples, n_classes + 1)
for n, (X_i, y_i) in enumerate(zip(sample_X, sample_y)):
    true_rec = gen_samples(X_i[None], y_i[None])
    fixed_mu, fixed_sigma = encode_function(X_i[None], y_i[None])
    axarr[n, 0].matshow(X_i.reshape(width, height), cmap="gray")
    axarr[n, 0].axis('off')
    all_mu = fixed_mu * np.ones((n_classes, fixed_mu.shape[1])).astype(
        "float32")
    all_sigma = fixed_sigma * np.ones((n_classes,
                                       fixed_sigma.shape[1])).astype(
        "float32")
    all_classes = np.eye(n_classes).astype('int32')
    all_recs, = decode_function(all_mu + np.exp(all_sigma), all_classes)
    for j in range(1, n_classes + 1):
        axarr[n, j].matshow(all_recs[j - 1].reshape(width, height), cmap="gray")
        axarr[n, j].axis('off')
f.suptitle("Style variation by changing conditional")
plt.savefig('cvae_style.png')
plt.close()

# Calculate noisy linear path between points in space
mus, log_sigmas = encode_function(sample_X, sample_y)
n_steps = 20
mu_path = interpolate_between_points(mus, n_steps=n_steps)
log_sigma_path = interpolate_between_points(log_sigmas, n_steps=n_steps)

# Noisy path across space from one point to another
path_X = mu_path + np.exp(log_sigma_path)
path_y = np.zeros((len(path_X), n_classes), dtype="int32")

for i in range(n_plot_samples):
    path_y[i * n_steps:(i + 1) * n_steps] = sample_y[i]

# Have to pass another argument for y_pred
# But it is not used unless y_sym is all -1
out, = decode_function(path_X, path_y)
text_y = [str(np.argmax(path_y[i])) for i in range(len(path_y))]
color_y = ["white"] * len(text_y)
make_gif(out, "cvae_code.gif", width, height, list_text_per_frame=text_y,
         list_text_per_frame_color=color_y, delay=1, grayscale=True)

--------------------------------------------------------------------------------
/graphics_code/cvae/kdl_template.py:
--------------------------------------------------------------------------------
../kdl_template.py
--------------------------------------------------------------------------------
/graphics_code/kdl_template.py:
--------------------------------------------------------------------------------
# Author: Kyle Kastner
# License: BSD 3-clause
# Ideas from Junyoung Chung and Kyunghyun Cho
# The latest version of this template will always
# live in:
# https://github.com/kastnerkyle/santa_barbaria
# See https://github.com/jych/cle for a library in this style
import numpy as np
from scipy import linalg
from scipy.io import loadmat
from functools import reduce
import numbers
import random
import theano
import zipfile
import gzip
import os
import glob
import sys
import subprocess
try:
    import cPickle as pickle
except ImportError:
    import pickle
from theano import tensor
from theano.compat.python2x import OrderedDict
from theano.sandbox.rng_mrg import MRG_RandomStreams
from collections import defaultdict


class sgd(object):
    """
    Vanilla SGD
    """
    def __init__(self, params):
        pass

    def updates(self, params, grads, learning_rate):
        updates = []
        for n, (param, grad) in enumerate(zip(params, grads)):
            updates.append((param, param - learning_rate * grad))
        return updates


class sgd_nesterov(object):
    """
    SGD with nesterov momentum

    Based on example from Yann D.
    """
    def __init__(self, params):
        self.memory_ = [theano.shared(np.zeros_like(p.get_value()))
                        for p in params]

    def updates(self, params, grads, learning_rate, momentum):
        updates = []
        for n, (param, grad) in enumerate(zip(params, grads)):
            memory = self.memory_[n]
            update = momentum * memory - learning_rate * grad
            update2 = momentum * momentum * memory - (
                1 + momentum) * learning_rate * grad
            updates.append((memory, update))
            updates.append((param, param + update2))
        return updates


class rmsprop(object):
    """
    RMSProp with nesterov momentum and gradient rescaling
    """
    def __init__(self, params):
        self.running_square_ = [theano.shared(np.zeros_like(p.get_value()))
                                for p in params]
        self.running_avg_ = [theano.shared(np.zeros_like(p.get_value()))
                             for p in params]
        self.memory_ = [theano.shared(np.zeros_like(p.get_value()))
                        for p in params]

    def updates(self, params, grads, learning_rate, momentum, rescale=5.):
        grad_norm = tensor.sqrt(sum(map(lambda x: tensor.sqr(x).sum(), grads)))
        not_finite = tensor.or_(tensor.isnan(grad_norm),
                                tensor.isinf(grad_norm))
        grad_norm = tensor.sqrt(grad_norm)
        scaling_num = rescale
        scaling_den = tensor.maximum(rescale, grad_norm)
        # Magic constants
        combination_coeff = 0.9
        minimum_grad = 1E-4
        updates = []
        for n, (param, grad) in enumerate(zip(params, grads)):
            grad = tensor.switch(not_finite, 0.1 * param,
                                 grad * (scaling_num / scaling_den))
            old_square = self.running_square_[n]
            new_square = combination_coeff * old_square + (
                1. - combination_coeff) * tensor.sqr(grad)
            old_avg = self.running_avg_[n]
            new_avg = combination_coeff * old_avg + (
                1.
                - combination_coeff) * grad
            rms_grad = tensor.sqrt(new_square - new_avg ** 2)
            rms_grad = tensor.maximum(rms_grad, minimum_grad)
            memory = self.memory_[n]
            update = momentum * memory - learning_rate * grad / rms_grad
            update2 = momentum * momentum * memory - (
                1 + momentum) * learning_rate * grad / rms_grad
            updates.append((old_square, new_square))
            updates.append((old_avg, new_avg))
            updates.append((memory, update))
            updates.append((param, param + update2))
        return updates


class adagrad(object):
    """
    Adagrad optimizer
    """
    def __init__(self, params):
        self.memory_ = [theano.shared(np.zeros_like(p.get_value()))
                        for p in params]

    def updates(self, params, grads, learning_rate, eps=1E-8):
        updates = []
        for n, (param, grad) in enumerate(zip(params, grads)):
            memory = self.memory_[n]
            m_t = memory + grad ** 2
            g_t = grad / (eps + tensor.sqrt(m_t))
            p_t = param - learning_rate * g_t
            updates.append((memory, m_t))
            updates.append((param, p_t))
        return updates


class adam(object):
    """
    Adam optimizer

    Based on implementation from @NewMu / Alec Radford
    """
    def __init__(self, params):
        self.memory_ = [theano.shared(np.zeros_like(p.get_value()))
                        for p in params]
        self.velocity_ = [theano.shared(np.zeros_like(p.get_value()))
                          for p in params]
        self.itr_ = theano.shared(np.array(0.).astype(theano.config.floatX))

    def updates(self, params, grads, learning_rate, b1=0.1, b2=0.001, eps=1E-8):
        updates = []
        itr = self.itr_
        i_t = itr + 1.
        fix1 = 1. - (1. - b1) ** i_t
        fix2 = 1. - (1. - b2) ** i_t
        lr_t = learning_rate * (tensor.sqrt(fix2) / fix1)
        for n, (param, grad) in enumerate(zip(params, grads)):
            memory = self.memory_[n]
            velocity = self.velocity_[n]
            m_t = (b1 * grad) + ((1.
                                  - b1) * memory)
            v_t = (b2 * tensor.sqr(grad)) + ((1. - b2) * velocity)
            g_t = m_t / (tensor.sqrt(v_t) + eps)
            p_t = param - (lr_t * g_t)
            updates.append((memory, m_t))
            updates.append((velocity, v_t))
            updates.append((param, p_t))
        updates.append((itr, i_t))
        return updates


def get_dataset_dir(dataset_name, data_dir=None, folder=None, create_dir=True):
    """ Get dataset directory path """
    if not data_dir:
        data_dir = os.getenv("SANTA_BARBARIA_DATA", os.path.join(
            os.path.expanduser("~"), "santa_barbaria_data"))
    if folder is None:
        data_dir = os.path.join(data_dir, dataset_name)
    else:
        data_dir = os.path.join(data_dir, folder)
    if not os.path.exists(data_dir) and create_dir:
        os.makedirs(data_dir)
    return data_dir


def download(url, server_fname, local_fname=None, progress_update_percentage=5):
    """
    An internet download utility modified from
    http://stackoverflow.com/questions/22676/
    how-do-i-download-a-file-over-http-using-python/22776#22776
    """
    try:
        import urllib
        urllib.urlretrieve('http://google.com')
    except AttributeError:
        import urllib.request as urllib
    u = urllib.urlopen(url)
    if local_fname is None:
        local_fname = server_fname
    full_path = local_fname
    meta = u.info()
    with open(full_path, 'wb') as f:
        try:
            file_size = int(meta.get("Content-Length"))
        except TypeError:
            print("WARNING: Cannot get file size, displaying bytes instead!")
            file_size = 100
        print("Downloading: %s Bytes: %s" % (server_fname, file_size))
        file_size_dl = 0
        block_sz = int(1E7)
        p = 0
        while True:
            buffer = u.read(block_sz)
            if not buffer:
                break
            file_size_dl += len(buffer)
            f.write(buffer)
            if (file_size_dl * 100.
                    / file_size) > p:
                status = r"%10d [%3.2f%%]" % (file_size_dl, file_size_dl *
                                              100. / file_size)
                print(status)
                p += progress_update_percentage


def make_character_level_from_text(text):
    """ Create mapping and inverse mappings for text -> one_hot_char """
    all_chars = reduce(lambda x, y: set(x) | set(y), text, set())
    mapper = {k: n + 2 for n, k in enumerate(list(all_chars))}
    # 1 is EOS
    mapper["EOS"] = 1
    # 0 is UNK/MASK - unused here but needed in general
    mapper["UNK"] = 0
    inverse_mapper = {v: k for k, v in mapper.items()}

    def mapper_func(text_line):
        return [mapper[c] for c in text_line] + [mapper["EOS"]]

    def inverse_mapper_func(symbol_line):
        return "".join([inverse_mapper[s] for s in symbol_line
                        if s != mapper["EOS"]])

    # Remove blank lines
    cleaned = [mapper_func(t) for t in text if t != ""]
    return cleaned, mapper_func, inverse_mapper_func, mapper


def check_fetch_uci_words():
    """ Check for UCI vocabulary """
    url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/bag-of-words/'
    partial_path = get_dataset_dir("uci_words")
    full_path = os.path.join(partial_path, "uci_words.zip")
    if not os.path.exists(partial_path):
        os.makedirs(partial_path)
    if not os.path.exists(full_path):
        # Download all 5 vocabularies and zip them into a file
        all_vocabs = ['vocab.enron.txt', 'vocab.kos.txt', 'vocab.nips.txt',
                      'vocab.nytimes.txt', 'vocab.pubmed.txt']
        for vocab in all_vocabs:
            dl_url = url + vocab
            download(dl_url, os.path.join(partial_path, vocab),
                     progress_update_percentage=1)

        def zipdir(path, zipf):
            # zipf is zipfile handle
            for root, dirs, files in os.walk(path):
                for f in files:
                    if "vocab" in f:
                        zipf.write(os.path.join(root, f))

        zipf = zipfile.ZipFile(full_path, 'w')
        zipdir(partial_path, zipf)
        zipf.close()
    return full_path


def fetch_uci_words():
    """ Returns UCI vocabulary text. """
    data_path = check_fetch_uci_words()
    all_data = []
    with zipfile.ZipFile(data_path, "r") as f:
        for name in f.namelist():
            if ".txt" not in name:
                # Skip README
                continue
            data = f.read(name)
            data = data.split("\n")
            data = [l.strip() for l in data if l != ""]
            all_data.extend(data)
    return list(set(all_data))


def check_fetch_lovecraft():
    """ Check for lovecraft data """
    url = 'https://dl.dropboxusercontent.com/u/15378192/lovecraft_fiction.zip'
    partial_path = get_dataset_dir("lovecraft")
    full_path = os.path.join(partial_path, "lovecraft_fiction.zip")
    if not os.path.exists(partial_path):
        os.makedirs(partial_path)
    if not os.path.exists(full_path):
        download(url, full_path, progress_update_percentage=1)
    return full_path


def fetch_lovecraft():
    """ Returns lovecraft text.
    """
    data_path = check_fetch_lovecraft()
    all_data = []
    with zipfile.ZipFile(data_path, "r") as f:
        for name in f.namelist():
            if ".txt" not in name:
                # Skip README
                continue
            data = f.read(name)
            data = data.split("\n")
            data = [l.strip() for l in data if l != ""]
            all_data.extend(data)
    return all_data


def check_fetch_tfd():
    """ Check that tfd faces are downloaded """
    partial_path = get_dataset_dir("tfd")
    full_path = os.path.join(partial_path, "TFD_48x48.mat")
    if not os.path.exists(partial_path):
        os.makedirs(partial_path)
    if not os.path.exists(full_path):
        raise ValueError("Put TFD_48x48 in %s" % str(partial_path))
    return full_path


def fetch_tfd():
    """ Returns flattened 48x48 TFD faces with pixel values in [0 - 1] """
    data_path = check_fetch_tfd()
    matfile = loadmat(data_path)
    all_data = matfile['images'].reshape(len(matfile['images']), -1) / 255.
    return all_data


def check_fetch_frey():
    """ Check that frey faces are downloaded """
    url = 'http://www.cs.nyu.edu/~roweis/data/frey_rawface.mat'
    partial_path = get_dataset_dir("frey")
    full_path = os.path.join(partial_path, "frey_rawface.mat")
    if not os.path.exists(partial_path):
        os.makedirs(partial_path)
    if not os.path.exists(full_path):
        download(url, full_path, progress_update_percentage=1)
    return full_path


def fetch_frey():
    """ Returns flattened 20x28 frey faces with pixel values in [0 - 1] """
    data_path = check_fetch_frey()
    matfile = loadmat(data_path)
    all_data = (matfile['ff'] / 255.).T
    return all_data


def check_fetch_mnist():
    """ Check that mnist is downloaded. May need fixing for py3 compat """
    # py3k version is available at mnist_py3k.pkl.gz ...
    # might need to fix
    url = 'http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz'
    partial_path = get_dataset_dir("mnist")
    full_path = os.path.join(partial_path, "mnist.pkl.gz")
    if not os.path.exists(partial_path):
        os.makedirs(partial_path)
    if not os.path.exists(full_path):
        download(url, full_path, progress_update_percentage=1)
    return full_path


def fetch_mnist():
    """ Returns mnist digits with pixel values in [0 - 1] """
    data_path = check_fetch_mnist()
    f = gzip.open(data_path, 'rb')
    try:
        train_set, valid_set, test_set = pickle.load(f, encoding="latin1")
    except TypeError:
        train_set, valid_set, test_set = pickle.load(f)
    f.close()
    return train_set, valid_set, test_set


def check_fetch_binarized_mnist():
    raise ValueError("Binarized MNIST has no labels!")
    url = "https://github.com/mgermain/MADE/releases/download/ICML2015/binarized_mnist.npz"
    partial_path = get_dataset_dir("binarized_mnist")
    fname = "binarized_mnist.npz"
    full_path = os.path.join(partial_path, fname)
    if not os.path.exists(partial_path):
        os.makedirs(partial_path)
    if not os.path.exists(full_path):
        download(url, full_path, progress_update_percentage=1)
    """
    # Personal version
    url = "https://dl.dropboxusercontent.com/u/15378192/binarized_mnist_%s.npy"
    fname = "binarized_mnist_%s.npy"
    for s in ["train", "valid", "test"]:
        full_path = os.path.join(partial_path, fname % s)
        if not os.path.exists(partial_path):
            os.makedirs(partial_path)
        if not os.path.exists(full_path):
            download(url % s, full_path, progress_update_percentage=1)
    """
    return partial_path


def fetch_binarized_mnist():
    """ Get binarized version of MNIST data """
    train_set, valid_set, test_set = fetch_mnist()
    train_X = train_set[0]
    train_y = train_set[1]
    valid_X = valid_set[0]
    valid_y = valid_set[1]
    test_X = test_set[0]
    test_y = test_set[1]

    random_state = np.random.RandomState(1999)

    def get_sampled(arr):
        # make sure that a pixel can always be turned off
        return random_state.binomial(1, arr * 255 / 256., size=arr.shape)

    train_X = get_sampled(train_X)
    valid_X = get_sampled(valid_X)
    test_X = get_sampled(test_X)

    train_set = (train_X, train_y)
    valid_set = (valid_X, valid_y)
    test_set = (test_X, test_y)

    """
    # Old version for true binarized mnist
    data_path = check_fetch_binarized_mnist()
    fpath = os.path.join(data_path, "binarized_mnist.npz")

    arr = np.load(fpath)
    train_x = arr['train_data']
    valid_x = arr['valid_data']
    test_x = arr['test_data']
    train, valid, test = fetch_mnist()
    train_y = train[1]
    valid_y = valid[1]
    test_y = test[1]
    train_set = (train_x, train_y)
    valid_set = (valid_x, valid_y)
    test_set = (test_x, test_y)
    """
    return train_set, valid_set, test_set


def make_gif(arr, gif_name, plot_width, plot_height, list_text_per_frame=None,
             list_text_per_frame_color=None,
             delay=1, grayscale=False,
             loop=False, turn_on_agg=True):
    """ Make a gif from a series of pngs using matplotlib matshow """
    if turn_on_agg:
        import matplotlib
        matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    # Plot temporaries for making gif
    # use random code to try and avoid deleting surprise files...
    random_code = random.randrange(2 ** 32)
    pre = str(random_code)
    for n, arr_i in enumerate(arr):
        plt.matshow(arr_i.reshape(plot_width, plot_height), cmap="gray")
        plt.axis('off')
        if list_text_per_frame is not None:
            text = list_text_per_frame[n]
            if list_text_per_frame_color is not None:
                color = list_text_per_frame_color[n]
            else:
                color = "white"
            plt.text(0, plot_height, text, color=color,
                     fontsize=2 * plot_height)
        # This looks ridiculous but should count the number of digit places
        # also protects against multiple runs
        # plus 1 is to maintain proper ordering
        plotpath = '__%s_giftmp_%s.png' % (str(n).zfill(len(
            str(len(arr))) + 1), pre)
        plt.savefig(plotpath)
        plt.close()

    # make gif
    assert delay >= 1
    gif_delay = int(delay)
    basestr = "convert __*giftmp_%s.png -delay %s " % (pre, str(gif_delay))
    # NOTE: ImageMagick reads "-loop 0" as repeat forever and "-loop 1"
    # as a single pass through the frames
    if loop:
        basestr += "-loop 1 "
    else:
        basestr += "-loop 0 "
    if grayscale:
        basestr += "-depth 8 -type Grayscale -depth 8 "
    basestr += "-resize %sx%s " % (str(int(5 * plot_width)),
                                   str(int(5 * plot_height)))
    basestr += gif_name
    print("Attempting gif")
    print(basestr)
    subprocess.call(basestr, shell=True)
    filelist = glob.glob("__*giftmp_%s.png" % pre)
    for f in filelist:
        os.remove(f)


def concatenate(tensor_list, name, axis=0, force_cast_to_float=True):
    """
    Wrapper to `theano.tensor.concatenate`.
    """
    if force_cast_to_float:
        tensor_list = cast_to_float(tensor_list)
    out = tensor.concatenate(tensor_list, axis=axis)
    conc_dim = int(sum([calc_expected_dim(inp)
                        for inp in tensor_list]))
    # This may be hosed...
need to figure out how to generalize 508 | shape = list(expression_shape(tensor_list[0])) 509 | shape[axis] = conc_dim 510 | new_shape = tuple(shape) 511 | tag_expression(out, name, new_shape) 512 | return out 513 | 514 | 515 | def theano_repeat(arr, n_repeat, stretch=False): 516 | """ 517 | Create repeats of 2D array using broadcasting. 518 | Shape[0] incorrect after this node! 519 | """ 520 | if arr.dtype not in ["float32", "float64"]: 521 | arr = tensor.cast(arr, "int32") 522 | if stretch: 523 | arg1 = arr.dimshuffle((0, 'x', 1)) 524 | arg2 = tensor.alloc(1., 1, n_repeat, arr.shape[1]) 525 | arg2 = tensor.cast(arg2, arr.dtype) 526 | cloned = (arg1 * arg2).reshape((n_repeat * arr.shape[0], arr.shape[1])) 527 | else: 528 | arg1 = arr.dimshuffle(('x', 0, 1)) 529 | arg2 = tensor.alloc(1., n_repeat, 1, arr.shape[1]) 530 | arg2 = tensor.cast(arg2, arr.dtype) 531 | cloned = (arg1 * arg2).reshape((n_repeat * arr.shape[0], arr.shape[1])) 532 | shape = expression_shape(arr) 533 | name = expression_name(arr) 534 | # Stretched shapes are *WRONG* 535 | tag_expression(cloned, name + "_stretched", (shape[0], shape[1])) 536 | return cloned 537 | 538 | 539 | def cast_to_float(list_of_inputs): 540 | """ A cast that preserves name and shape info after cast """ 541 | input_names = [inp.name for inp in list_of_inputs] 542 | cast_inputs = [tensor.cast(inp, theano.config.floatX) 543 | for inp in list_of_inputs] 544 | for n, inp in enumerate(cast_inputs): 545 | cast_inputs[n].name = input_names[n] 546 | return cast_inputs 547 | 548 | 549 | def interpolate_between_points(arr, n_steps=50): 550 | """ Helper function for drawing line between points in space """ 551 | assert len(arr) > 2 552 | assert n_steps > 1 553 | path = [path_between_points(start, stop, n_steps=n_steps) 554 | for start, stop in zip(arr[:-1], arr[1:])] 555 | path = np.vstack(path) 556 | return path 557 | 558 | 559 | def path_between_points(start, stop, n_steps=100, dtype=theano.config.floatX): 560 | """ Helper function 
for making a line between points in ND space """
    assert n_steps > 1
    step_vector = 1. / (n_steps - 1) * (stop - start)
    steps = np.arange(0, n_steps)[:, None] * np.ones((n_steps, len(stop)))
    steps = steps * step_vector + start
    return steps.astype(dtype)


def minibatch_indices(itr, minibatch_size):
    """ Generate indices for slicing 2D and 3D arrays in minibatches"""
    is_three_d = False
    if type(itr) is np.ndarray:
        if len(itr.shape) == 3:
            is_three_d = True
    elif not isinstance(itr[0], numbers.Real):
        # Assume 3D list of list of list
        # iterable of iterable of iterable, feature dim must be consistent
        is_three_d = True

    if is_three_d and type(itr) is np.ndarray:
        # 3D arrays are sliced along axis 1, so both the starts and the
        # final stop index come from that axis
        n_samples = itr.shape[1]
    else:
        # 2D arrays and (multi-)lists are sliced along the first axis
        n_samples = len(itr)
    boundaries = np.arange(0, n_samples, minibatch_size)
    boundaries = np.asarray(list(boundaries) + [n_samples])
    start_indices = boundaries[:-1]
    end_indices = boundaries[1:]
    # list() so the pairs can be shuffled in place and reused every epoch
    # (zip is a one-shot iterator on Python 3)
    return list(zip(start_indices, end_indices))


def convert_to_one_hot(itr, n_classes, dtype="int32"):
    """ Convert 2D or 3D iterators to one_hot. Primarily for text.
""" 599 | is_three_d = False 600 | if type(itr) is np.ndarray: 601 | if len(itr.shape) == 3: 602 | is_three_d = True 603 | elif not isinstance(itr[0], numbers.Real): 604 | # Assume 3D list of list of list 605 | # iterable of iterable of iterable, feature dim must be consistent 606 | is_three_d = True 607 | 608 | if is_three_d: 609 | lengths = [len(i) for i in itr] 610 | one_hot = np.zeros((max(lengths), len(itr), n_classes), dtype=dtype) 611 | for n in range(len(itr)): 612 | one_hot[np.arange(lengths[n]), n, itr[n]] = 1 613 | else: 614 | one_hot = np.zeros((len(itr), n_classes), dtype=dtype) 615 | one_hot[np.arange(len(itr)), itr] = 1 616 | return one_hot 617 | 618 | 619 | def save_checkpoint(save_path, items_dict): 620 | """ Simple wrapper for checkpoint dictionaries """ 621 | old_recursion_limit = sys.getrecursionlimit() 622 | sys.setrecursionlimit(40000) 623 | with open(save_path, mode="wb") as f: 624 | pickle.dump(items_dict, f) 625 | sys.setrecursionlimit(old_recursion_limit) 626 | 627 | 628 | def load_checkpoint(save_path): 629 | """ Simple pickle wrapper for checkpoint dictionaries """ 630 | old_recursion_limit = sys.getrecursionlimit() 631 | sys.setrecursionlimit(40000) 632 | with open(save_path, mode="rb") as f: 633 | items_dict = pickle.load(f) 634 | sys.setrecursionlimit(old_recursion_limit) 635 | return items_dict 636 | 637 | 638 | def print_status_func(epoch_results): 639 | """ Print the last results from a results dictionary """ 640 | n_epochs_seen = max([len(l) for l in epoch_results.values()]) 641 | last_results = {k: v[-1] for k, v in epoch_results.items()} 642 | print("Epoch %i: %s" % (n_epochs_seen, last_results)) 643 | 644 | 645 | def checkpoint_status_func(save_path, checkpoint_dict, epoch_results): 646 | """ Saves a checkpoint dict """ 647 | checkpoint_dict["previous_epoch_results"] = epoch_results 648 | save_checkpoint(save_path, checkpoint_dict) 649 | print_status_func(epoch_results) 650 | 651 | 652 | def 
early_stopping_status_func(valid_cost, save_path, checkpoint_dict,
                               epoch_results):
    """
    Adds valid_cost to epoch_results and saves model if best valid
    Assumes epoch_results is a defaultdict(list)

    Example usage for early stopping on validation set:

    def status_func(status_number, epoch_number, epoch_results):
        valid_results = iterate_function(
            cost_function, [X_clean_valid, y_clean_valid], minibatch_size,
            list_of_output_names=["valid_cost"],
            list_of_minibatch_functions=[text_minibatcher], n_epochs=1,
            shuffle=False)
        early_stopping_status_func(valid_results["valid_cost"][-1],
                                   save_path, checkpoint_dict, epoch_results)

    status_func can then be fed to iterate_function for training with early
    stopping.
    """
    # Quick trick to avoid 0 length list
    old = min(epoch_results["valid_cost"] + [np.inf])
    epoch_results["valid_cost"].append(valid_cost)
    new = min(epoch_results["valid_cost"])
    if new < old:
        print("Saving checkpoint based on validation score")
        checkpoint_status_func(save_path, checkpoint_dict, epoch_results)
    else:
        print_status_func(epoch_results)


def even_slice(arr, size):
    """ Force array to be even by slicing off the end """
    extent = -(len(arr) % size)
    if extent == 0:
        extent = None
    return arr[:extent]


def make_minibatch(arg, start, stop):
    """ Does not handle off-size minibatches """
    if len(arg.shape) == 3:
        return [arg[:, start:stop]]
    else:
        return [arg[start:stop]]


def gen_text_minibatch_func(one_hot_size):
    """
    Returns a function that will turn a text minibatch into one_hot form.

    For use with iterate_function list_of_minibatch_functions argument.

    Example:
    n_chars = 84
    text_minibatcher = gen_text_minibatch_func(n_chars)
    valid_results = iterate_function(
        cost_function, [X_clean_valid, y_clean_valid], minibatch_size,
        list_of_output_names=["valid_cost"],
        list_of_minibatch_functions=[text_minibatcher], n_epochs=1,
        shuffle=False)
    """
    def apply(arg, start, stop):
        sli = arg[start:stop]
        expanded = convert_to_one_hot(sli, one_hot_size)
        lengths = [len(s) for s in sli]
        mask = np.zeros((max(lengths), len(sli)), dtype=theano.config.floatX)
        for n, l in enumerate(lengths):
            mask[np.arange(l), n] = 1.
        return expanded, mask
    return apply


def iterate_function(func, list_of_minibatch_args, minibatch_size,
                     list_of_non_minibatch_args=None,
                     list_of_minibatch_functions=[make_minibatch],
                     list_of_output_names=None,
                     n_epochs=1000, n_status=50, status_func=None,
                     previous_epoch_results=None,
                     shuffle=False, random_state=None):
    """
    Minibatch arguments should come first.

    Constant arguments which should not be iterated can be passed as
    list_of_non_minibatch_args.

    If list_of_minibatch_functions is length 1, it will be replicated to the
    length of list_of_minibatch_args, applying the same function to every
    minibatch argument. Otherwise, it should be the same length as
    list_of_minibatch_args.

    list_of_output_names simply names the output of the passed in function.
    Should be the same length as the number of outputs from the function.

    status_func is a function run periodically (based on n_status),
    which allows for validation, early stopping, checkpointing, etc.

    previous_epoch_results allows for continuing from saved checkpoints

    shuffle and random_state are used to determine if minibatches are run
    in sequence or selected randomly each epoch.
752 | 753 | By far the craziest function in this file. 754 | 755 | Example validation function: 756 | n_chars = 84 757 | text_minibatcher = gen_text_minibatch_func(n_chars) 758 | 759 | cost_function returns one value, the cost for that minibatch 760 | 761 | valid_results = iterate_function( 762 | cost_function, [X_clean_valid, y_clean_valid], minibatch_size, 763 | list_of_output_names=["valid_cost"], 764 | list_of_minibatch_functions=[text_minibatcher], n_epochs=1, 765 | shuffle=False) 766 | 767 | Example training loop: 768 | 769 | fit_function returns 3 values, nll, kl and the total cost 770 | 771 | epoch_results = iterate_function(fit_function, [X, y], minibatch_size, 772 | list_of_output_names=["nll", "kl", "cost"], 773 | n_epochs=2000, 774 | status_func=status_func, 775 | previous_epoch_results=previous_epoch_results, 776 | shuffle=True, 777 | random_state=random_state) 778 | """ 779 | if previous_epoch_results is None: 780 | epoch_results = defaultdict(list) 781 | else: 782 | epoch_results = previous_epoch_results 783 | # Input checking and setup 784 | if shuffle: 785 | assert random_state is not None 786 | status_points = list(range(n_epochs)) 787 | if len(status_points) >= n_status: 788 | intermediate_points = status_points[::n_epochs // n_status] 789 | status_points = intermediate_points + [status_points[-1]] 790 | else: 791 | status_points = range(len(status_points)) 792 | 793 | for arg in list_of_minibatch_args: 794 | assert len(arg) == len(list_of_minibatch_args[0]) 795 | 796 | indices = minibatch_indices(list_of_minibatch_args[0], minibatch_size) 797 | if len(list_of_minibatch_args[0]) % minibatch_size != 0: 798 | print ("length of dataset should be evenly divisible by " 799 | "minibatch_size.") 800 | if len(list_of_minibatch_functions) == 1: 801 | list_of_minibatch_functions = list_of_minibatch_functions * len( 802 | list_of_minibatch_args) 803 | else: 804 | assert len(list_of_minibatch_functions) == len(list_of_minibatch_args) 805 | # Function loop 
    for e in range(n_epochs):
        results = defaultdict(list)
        if shuffle:
            random_state.shuffle(indices)
        for i, j in indices:
            minibatch_args = []
            for n, arg in enumerate(list_of_minibatch_args):
                minibatch_args += list_of_minibatch_functions[n](arg, i, j)
            if list_of_non_minibatch_args is not None:
                all_args = minibatch_args + list_of_non_minibatch_args
            else:
                all_args = minibatch_args
            minibatch_results = func(*all_args)
            if type(minibatch_results) is not list:
                minibatch_results = [minibatch_results]
            for n, k in enumerate(minibatch_results):
                if list_of_output_names is not None:
                    assert len(list_of_output_names) == len(minibatch_results)
                    results[list_of_output_names[n]].append(
                        minibatch_results[n])
                else:
                    results[n].append(minibatch_results[n])
        avg_output = {r: np.mean(results[r]) for r in results.keys()}
        for k in avg_output.keys():
            epoch_results[k].append(avg_output[k])
        if e in status_points:
            if status_func is not None:
                epoch_number = e
                status_number = np.searchsorted(status_points, e)
                status_func(status_number, epoch_number, epoch_results)
    return epoch_results


def as_shared(arr, name=None):
    """ Quick wrapper for theano.shared """
    # Pass the name through only when one was given
    # (the two branches were previously swapped, so names were dropped)
    if name is not None:
        return theano.shared(value=arr, name=name, borrow=True)
    else:
        return theano.shared(value=arr, borrow=True)


def np_zeros(shape):
    """ Builds a numpy variable filled with zeros """
    return np.zeros(shape).astype(theano.config.floatX)


def np_rand(shape, random_state):
    # Make sure bounds aren't the same
    return random_state.uniform(low=-0.08, high=0.08, size=shape).astype(
        theano.config.floatX)


def np_randn(shape, random_state):
    """ Builds a numpy variable filled with random normal values """
    return (0.01 *
random_state.randn(*shape)).astype(theano.config.floatX)


def np_tanh_fan(shape, random_state):
    # The . after the 6 is critical! shape has dtype int...
    bound = np.sqrt(6. / np.sum(shape))
    return random_state.uniform(low=-bound, high=bound,
                                size=shape).astype(theano.config.floatX)


def np_sigmoid_fan(shape, random_state):
    return 4 * np_tanh_fan(shape, random_state)


def np_ortho(shape, random_state):
    """ Builds a numpy variable filled with orthonormal random values """
    g = random_state.randn(*shape)
    # U from the SVD is square (shape[0] x shape[0]), so the result only
    # matches `shape` when the requested matrix is square
    o_g = linalg.svd(g)[0]
    return o_g.astype(theano.config.floatX)


def names_in_graph(list_of_names, graph):
    """ Return true if all names are in the graph """
    return all([name in graph.keys() for name in list_of_names])


def add_arrays_to_graph(list_of_arrays, list_of_names, graph, strict=True):
    assert len(list_of_arrays) == len(list_of_names)
    arrays_added = []
    for array, name in zip(list_of_arrays, list_of_names):
        if name in graph.keys() and strict:
            raise ValueError("Name %s already found in graph!" % name)
        shared_array = as_shared(array, name=name)
        graph[name] = shared_array
        arrays_added.append(shared_array)
    # Mirror add_datasets_to_graph and hand back what was added
    return arrays_added


def make_shapename(name, shape):
    if len(shape) == 1:
        # vector, primarily init hidden state for RNN
        return name + "_kdl_" + str(shape[0]) + "x"
    else:
        return name + "_kdl_" + "x".join(map(str, list(shape)))


def parse_shapename(shapename):
    try:
        # Bracket for scan
        shape = shapename.split("_kdl_")[1].split("[")[0].split("x")
    except AttributeError:
        raise AttributeError("Unable to parse shapename. Has the expression "
                             "been tagged with a shape by tag_expression?
" 912 | " input shapename was %s" % shapename) 913 | if "[" in shapename.split("_kdl_")[1]: 914 | # inside scan 915 | shape = shape[1:] 916 | name = shapename.split("_kdl_")[0] 917 | # More cleaning to handle scan 918 | shape = tuple([int(s) for s in shape if s != '']) 919 | return name, shape 920 | 921 | 922 | def add_datasets_to_graph(list_of_datasets, list_of_names, graph, strict=True, 923 | list_of_test_values=None): 924 | assert len(list_of_datasets) == len(list_of_names) 925 | datasets_added = [] 926 | for n, (dataset, name) in enumerate(zip(list_of_datasets, list_of_names)): 927 | if dataset.dtype != "int32": 928 | if len(dataset.shape) == 1: 929 | sym = tensor.vector() 930 | elif len(dataset.shape) == 2: 931 | sym = tensor.matrix() 932 | elif len(dataset.shape) == 3: 933 | sym = tensor.tensor3() 934 | else: 935 | raise ValueError("dataset %s has unsupported shape" % name) 936 | elif dataset.dtype == "int32": 937 | if len(dataset.shape) == 1: 938 | sym = tensor.ivector() 939 | elif len(dataset.shape) == 2: 940 | sym = tensor.imatrix() 941 | elif len(dataset.shape) == 3: 942 | sym = tensor.itensor3() 943 | else: 944 | raise ValueError("dataset %s has unsupported shape" % name) 945 | else: 946 | raise ValueError("dataset %s has unsupported dtype %s" % ( 947 | name, dataset.dtype)) 948 | if list_of_test_values is not None: 949 | sym.tag.test_value = list_of_test_values[n] 950 | tag_expression(sym, name, dataset.shape) 951 | datasets_added.append(sym) 952 | graph["__datasets_added__"] = datasets_added 953 | return datasets_added 954 | 955 | 956 | def tag_expression(expression, name, shape): 957 | expression.name = make_shapename(name, shape) 958 | 959 | 960 | def expression_name(expression): 961 | return parse_shapename(expression.name)[0] 962 | 963 | 964 | def expression_shape(expression): 965 | return parse_shapename(expression.name)[1] 966 | 967 | 968 | def calc_expected_dim(expression): 969 | # super intertwined with add_datasets_to_graph 970 | # Expect 
variables representing datasets in graph!!!
    # Function graph madness
    # Shape format is HxWxZ
    shape = expression_shape(expression)
    dim = shape[-1]
    return dim


def fetch_from_graph(list_of_names, graph):
    """ Returns a list of shared variables from the graph """
    if "__datasets_added__" not in graph.keys():
        # Check for dataset in graph
        raise AttributeError("No dataset in graph! Make sure to add "
                             "the dataset using add_datasets_to_graph")
    return [graph[name] for name in list_of_names]


def get_params_and_grads(graph, cost):
    grads = []
    params = []
    for k, p in graph.items():
        if k[:2] == "__":
            # skip private tags
            continue
        print("Computing grad w.r.t %s" % k)
        grad = tensor.grad(cost, p)
        params.append(p)
        grads.append(grad)
    return params, grads


def binary_crossentropy_nll(predicted_values, true_values):
    """ Returns negative log-likelihood compared to binary true_values """
    return (-true_values * tensor.log(predicted_values) - (
        1 - true_values) * tensor.log(1 - predicted_values)).sum(axis=-1)


def binary_entropy(values):
    return (-values * tensor.log(values)).sum(axis=-1)


def categorical_crossentropy_nll(predicted_values, true_values):
    """ Returns negative log-likelihood compared to one hot category labels """
    indices = tensor.argmax(true_values, axis=-1)
    rows = tensor.arange(true_values.shape[0])
    if predicted_values.ndim < 3:
        return -tensor.log(predicted_values)[rows, indices]
    elif predicted_values.ndim == 3:
        d0 = true_values.shape[0]
        d1 = true_values.shape[1]
        pred = predicted_values.reshape((d0 * d1, -1))
        ind = indices.reshape((d0 * d1,))
        s = tensor.arange(pred.shape[0])
        correct = -tensor.log(pred)[s, ind]
        return correct.reshape((d0, d1,))
    else:
        raise AttributeError("Tensor dim not supported")


def abs_error_nll(predicted_values, true_values):
    return tensor.abs_(predicted_values - true_values).sum(axis=-1)


def squared_error_nll(predicted_values, true_values):
    return tensor.sqr(predicted_values - true_values).sum(axis=-1)


def gaussian_error_nll(mu_values, sigma_values, true_values):
    """ sigma should come from a softplus layer """
    # -log N(x; mu, sigma^2) = 0.5 * ((x - mu)^2 / sigma^2 + log(2 pi sigma^2))
    # the 0.5 applies to both terms
    nll = 0.5 * ((mu_values - true_values) ** 2 / sigma_values ** 2
                 + tensor.log(2 * np.pi * sigma_values ** 2))
    return nll


def log_gaussian_error_nll(mu_values, log_sigma_values, true_values):
    """ log_sigma should come from a linear layer """
    # Same expression as gaussian_error_nll with sigma = exp(log_sigma)
    nll = 0.5 * ((mu_values - true_values) ** 2 / tensor.exp(
        log_sigma_values) ** 2 + tensor.log(2 * np.pi) + 2 * log_sigma_values)
    return nll


def masked_cost(cost, mask):
    return cost * mask


def softplus(X):
    return tensor.nnet.softplus(X) + 1E-4


def relu(X):
    # Rectifier: pass positive values, zero out the rest (threshold is 0)
    return X * (X > 0)


def linear(X):
    return X


def softmax(X):
    # should work for both 2D and 3D
    e_X = tensor.exp(X - X.max(axis=-1, keepdims=True))
    out = e_X / e_X.sum(axis=-1, keepdims=True)
    return out


def dropout(X, random_state, on_off_switch, p=0.):
    if p > 0:
        theano_seed = random_state.randint(-2147462579, 2147462579)
        # Super edge case...
        if theano_seed == 0:
            print("WARNING: prior layer got 0 seed.
Reseeding...") 1080 | theano_seed = random_state.randint(-2**32, 2**32) 1081 | theano_rng = MRG_RandomStreams(seed=theano_seed) 1082 | retain_prob = 1 - p 1083 | if X.ndim == 2: 1084 | X *= theano_rng.binomial( 1085 | X.shape, p=retain_prob, 1086 | dtype=theano.config.floatX) ** on_off_switch 1087 | X /= retain_prob 1088 | elif X.ndim == 3: 1089 | # Dropout for recurrent - don't drop over time! 1090 | X *= theano_rng.binomial(( 1091 | X.shape[1], X.shape[2]), p=retain_prob, 1092 | dtype=theano.config.floatX) ** on_off_switch 1093 | X /= retain_prob 1094 | else: 1095 | raise ValueError("Unsupported tensor with ndim %s" % str(X.ndim)) 1096 | return X 1097 | 1098 | 1099 | def dropout_layer(list_of_inputs, name, on_off_switch, dropout_prob=0.5, 1100 | random_state=None): 1101 | theano_seed = random_state.randint(-2147462579, 2147462579) 1102 | # Super edge case... 1103 | if theano_seed == 0: 1104 | print("WARNING: prior layer got 0 seed. Reseeding...") 1105 | theano_seed = random_state.randint(-2**32, 2**32) 1106 | conc_input = concatenate(list_of_inputs, name, axis=-1) 1107 | shape = expression_shape(conc_input) 1108 | dropped = dropout(conc_input, random_state, on_off_switch, p=dropout_prob) 1109 | tag_expression(dropped, name, shape) 1110 | return dropped 1111 | 1112 | 1113 | def projection_layer(list_of_inputs, graph, name, proj_dim=None, 1114 | random_state=None, strict=True, init_func=np_tanh_fan, 1115 | func=linear): 1116 | W_name = name + '_W' 1117 | b_name = name + '_b' 1118 | list_of_names = [W_name, b_name] 1119 | if not names_in_graph(list_of_names, graph): 1120 | assert proj_dim is not None 1121 | assert random_state is not None 1122 | conc_input_dim = int(sum([calc_expected_dim(inp) 1123 | for inp in list_of_inputs])) 1124 | np_W = init_func((conc_input_dim, proj_dim), random_state) 1125 | np_b = np_zeros((proj_dim,)) 1126 | add_arrays_to_graph([np_W, np_b], list_of_names, graph, 1127 | strict=strict) 1128 | else: 1129 | if strict: 1130 | raise 
AttributeError( 1131 | "Name %s already found in graph with strict mode!" % name) 1132 | W, b = fetch_from_graph(list_of_names, graph) 1133 | conc_input = concatenate(list_of_inputs, name, axis=-1) 1134 | output = tensor.dot(conc_input, W) + b 1135 | if func is not None: 1136 | final = func(output) 1137 | else: 1138 | final = output 1139 | shape = list(expression_shape(conc_input)) 1140 | # Projection is on last axis 1141 | shape[-1] = proj_dim 1142 | new_shape = tuple(shape) 1143 | tag_expression(final, name, new_shape) 1144 | return final 1145 | 1146 | 1147 | def linear_layer(list_of_inputs, graph, name, proj_dim=None, random_state=None, 1148 | strict=True, init_func=np_tanh_fan): 1149 | return projection_layer( 1150 | list_of_inputs=list_of_inputs, graph=graph, name=name, 1151 | proj_dim=proj_dim, random_state=random_state, 1152 | strict=strict, init_func=init_func, func=linear) 1153 | 1154 | 1155 | def sigmoid_layer(list_of_inputs, graph, name, proj_dim=None, random_state=None, 1156 | strict=True, init_func=np_sigmoid_fan): 1157 | return projection_layer( 1158 | list_of_inputs=list_of_inputs, graph=graph, name=name, 1159 | proj_dim=proj_dim, random_state=random_state, 1160 | strict=strict, init_func=init_func, func=tensor.nnet.sigmoid) 1161 | 1162 | 1163 | def tanh_layer(list_of_inputs, graph, name, proj_dim=None, random_state=None, 1164 | strict=True, init_func=np_tanh_fan): 1165 | return projection_layer( 1166 | list_of_inputs=list_of_inputs, graph=graph, name=name, 1167 | proj_dim=proj_dim, random_state=random_state, 1168 | strict=strict, init_func=init_func, func=tensor.tanh) 1169 | 1170 | 1171 | def softplus_layer(list_of_inputs, graph, name, proj_dim=None, 1172 | random_state=None, strict=True, 1173 | init_func=np_tanh_fan): 1174 | return projection_layer( 1175 | list_of_inputs=list_of_inputs, graph=graph, name=name, 1176 | proj_dim=proj_dim, random_state=random_state, 1177 | strict=strict, init_func=init_func, func=softplus) 1178 | 1179 | 1180 | def 
exp_layer(list_of_inputs, graph, name, proj_dim=None, random_state=None, 1181 | strict=True, init_func=np_tanh_fan): 1182 | return projection_layer( 1183 | list_of_inputs=list_of_inputs, graph=graph, name=name, 1184 | proj_dim=proj_dim, random_state=random_state, 1185 | strict=strict, init_func=init_func, func=tensor.exp) 1186 | 1187 | 1188 | def relu_layer(list_of_inputs, graph, name, proj_dim=None, random_state=None, 1189 | strict=True, init_func=np_tanh_fan): 1190 | return projection_layer( 1191 | list_of_inputs=list_of_inputs, graph=graph, name=name, 1192 | proj_dim=proj_dim, random_state=random_state, 1193 | strict=strict, init_func=init_func, func=relu) 1194 | 1195 | 1196 | def softmax_layer(list_of_inputs, graph, name, proj_dim=None, random_state=None, 1197 | strict=True, init_func=np_tanh_fan): 1198 | return projection_layer( 1199 | list_of_inputs=list_of_inputs, graph=graph, name=name, 1200 | proj_dim=proj_dim, random_state=random_state, 1201 | strict=strict, init_func=init_func, func=softmax) 1202 | 1203 | 1204 | def softmax_sample_layer(list_of_multinomial_inputs, name, random_state=None): 1205 | theano_seed = random_state.randint(-2147462579, 2147462579) 1206 | # Super edge case... 1207 | if theano_seed == 0: 1208 | print("WARNING: prior layer got 0 seed. 
Reseeding...") 1209 | theano_seed = random_state.randint(-2**32, 2**32) 1210 | theano_rng = MRG_RandomStreams(seed=theano_seed) 1211 | conc_multinomial = concatenate(list_of_multinomial_inputs, name, axis=1) 1212 | shape = expression_shape(conc_multinomial) 1213 | conc_multinomial /= len(list_of_multinomial_inputs) 1214 | tag_expression(conc_multinomial, name, shape) 1215 | samp = theano_rng.multinomial(pvals=conc_multinomial, 1216 | dtype="int32") 1217 | tag_expression(samp, name, (shape[0], shape[1])) 1218 | return samp 1219 | 1220 | 1221 | def gaussian_sample_layer(list_of_mu_inputs, list_of_sigma_inputs, 1222 | name, random_state=None): 1223 | theano_seed = random_state.randint(-2147462579, 2147462579) 1224 | # Super edge case... 1225 | if theano_seed == 0: 1226 | print("WARNING: prior layer got 0 seed. Reseeding...") 1227 | theano_seed = random_state.randint(-2**32, 2**32) 1228 | theano_rng = MRG_RandomStreams(seed=theano_seed) 1229 | conc_mu = concatenate(list_of_mu_inputs, name, axis=1) 1230 | conc_sigma = concatenate(list_of_sigma_inputs, name, axis=1) 1231 | e = theano_rng.normal(size=(conc_sigma.shape[0], 1232 | conc_sigma.shape[1]), 1233 | dtype=conc_sigma.dtype) 1234 | samp = conc_mu + conc_sigma * e 1235 | shape = expression_shape(conc_sigma) 1236 | tag_expression(samp, name, shape) 1237 | return samp 1238 | 1239 | 1240 | def gaussian_log_sample_layer(list_of_mu_inputs, list_of_log_sigma_inputs, 1241 | name, random_state=None): 1242 | """ log_sigma_inputs should be from a linear_layer """ 1243 | theano_seed = random_state.randint(-2147462579, 2147462579) 1244 | # Super edge case... 1245 | if theano_seed == 0: 1246 | print("WARNING: prior layer got 0 seed. 
Reseeding...")
        theano_seed = random_state.randint(-2**32, 2**32)
    theano_rng = MRG_RandomStreams(seed=theano_seed)
    conc_mu = concatenate(list_of_mu_inputs, name, axis=1)
    conc_log_sigma = concatenate(list_of_log_sigma_inputs, name, axis=1)
    e = theano_rng.normal(size=(conc_log_sigma.shape[0],
                                conc_log_sigma.shape[1]),
                          dtype=conc_log_sigma.dtype)

    # the 0.5 factor treats the input as a log-variance, so
    # exp(0.5 * input) is the standard deviation that scales the noise
    samp = conc_mu + tensor.exp(0.5 * conc_log_sigma) * e
    shape = expression_shape(conc_log_sigma)
    tag_expression(samp, name, shape)
    return samp


def gaussian_kl(list_of_mu_inputs, list_of_sigma_inputs, name):
    conc_mu = concatenate(list_of_mu_inputs, name)
    conc_sigma = concatenate(list_of_sigma_inputs, name)
    # KL between N(mu, sigma^2) and N(0, 1), summed over axis 1
    kl = 0.5 * tensor.sum(-2 * tensor.log(conc_sigma) + conc_mu ** 2
                          + conc_sigma ** 2 - 1, axis=1)
    return kl


def gaussian_log_kl(list_of_mu_inputs, list_of_log_sigma_inputs, name):
    """ log_sigma_inputs should come from linear layer"""
    conc_mu = concatenate(list_of_mu_inputs, name)
    # halve the input (treated as a log-variance) to get log(sigma)
    conc_log_sigma = 0.5 * concatenate(list_of_log_sigma_inputs, name)
    kl = 0.5 * tensor.sum(-2 * conc_log_sigma + conc_mu ** 2
                          + tensor.exp(conc_log_sigma) ** 2 - 1, axis=1)
    return kl


def switch_wrap(switch_func, if_true_var, if_false_var, name):
    switched = tensor.switch(switch_func, if_true_var, if_false_var)
    shape = expression_shape(if_true_var)
    assert shape == expression_shape(if_false_var)
    tag_expression(switched, name, shape)
    return switched

--------------------------------------------------------------------------------
/graphics_code/vae/flying_vae.py:
--------------------------------------------------------------------------------
from kdl_template import *
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("saved_functions_file",
                    help="Saved pickle file from vae training")
parser.add_argument("--seed", "-s",
                    help="random seed for path calculation",
                    action="store", default=1979, type=int)

args = parser.parse_args()
if not os.path.exists(args.saved_functions_file):
    raise ValueError("Please provide a valid path for saved pickle file!")

checkpoint_dict = load_checkpoint(args.saved_functions_file)
encode_function = checkpoint_dict["encode_function"]
decode_function = checkpoint_dict["decode_function"]

random_state = np.random.RandomState(args.seed)
train, valid, test = fetch_binarized_mnist()
# visualize against validation so we aren't cheating
X = valid[0].astype(theano.config.floatX)

# number of samples
n_plot_samples = 5
# MNIST dimensions
width = 28
height = 28
# Get random data samples
ind = np.arange(len(X))
random_state.shuffle(ind)
sample_X = X[ind[:n_plot_samples]]


def gen_samples(arr):
    mu, log_sig = encode_function(arr)
    # No sampled noise at test time; exp(log_sig) is added to the mean
    # as a deterministic offset
    out, = decode_function(mu + np.exp(log_sig))
    return out

# VAE specific plotting
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
samples = gen_samples(sample_X)
f, axarr = plt.subplots(n_plot_samples, 2)
for n, (X_i, s_i) in enumerate(zip(sample_X, samples)):
    axarr[n, 0].matshow(X_i.reshape(width, height), cmap="gray")
    axarr[n, 1].matshow(s_i.reshape(width, height), cmap="gray")
    axarr[n, 0].axis('off')
    axarr[n, 1].axis('off')
plt.savefig('vae_reconstruction.png')
plt.close()

# Calculate linear path between points in space
mus, log_sigmas = encode_function(sample_X)
mu_path = interpolate_between_points(mus)
log_sigma_path = interpolate_between_points(log_sigmas)

# Path across space from one point to another
path = mu_path + np.exp(log_sigma_path)
out, = decode_function(path)
make_gif(out, "vae_code.gif", width, height, delay=1, grayscale=True)

--------------------------------------------------------------------------------
/graphics_code/vae/kdl_template.py:
--------------------------------------------------------------------------------
../kdl_template.py
--------------------------------------------------------------------------------
/graphics_code/vae/vae.py:
--------------------------------------------------------------------------------
from kdl_template import *

train, valid, test = fetch_binarized_mnist()
X = train[0].astype(theano.config.floatX)

# graph holds information necessary to build layers from parents
graph = OrderedDict()
X_sym, = add_datasets_to_graph([X], ["X"], graph)
# random state so script is deterministic
random_state = np.random.RandomState(1999)

minibatch_size = 100
n_code = 100
n_enc_layer = [200, 200]
n_dec_layer = [200, 200]
width = 28
height = 28
n_input = width * height

# encode path aka q
l1_enc = softplus_layer([X_sym], graph, 'l1_enc', n_enc_layer[0], random_state)
l2_enc = softplus_layer([l1_enc], graph, 'l2_enc', n_enc_layer[1],
                        random_state)
code_mu = linear_layer([l2_enc], graph, 'code_mu', n_code, random_state)
code_log_sigma = linear_layer([l2_enc], graph, 'code_log_sigma', n_code,
                              random_state)
kl = gaussian_log_kl([code_mu], [code_log_sigma], 'kl').mean()
samp = gaussian_log_sample_layer([code_mu], [code_log_sigma], 'samp',
                                 random_state)

# decode path aka p
l1_dec = softplus_layer([samp], graph, 'l1_dec', n_dec_layer[0], random_state)
l2_dec = softplus_layer([l1_dec], graph, 'l2_dec', n_dec_layer[1], random_state)
out = sigmoid_layer([l2_dec], graph, 'out', n_input, random_state)

nll = binary_crossentropy_nll(out, X_sym).mean()
# log p(x) = -nll so swap sign
# want to minimize cost in optimization so multiply by -1
cost = -1 * (-nll - kl)
params, grads = get_params_and_grads(graph, cost)

learning_rate = 0.0003
opt = adam(params)
updates = opt.updates(params, grads, learning_rate)

# Checkpointing
save_path = "serialized_vae.pkl"
if not os.path.exists(save_path):
    # third output is nll + kl, the negative of the variational lower bound
    fit_function = theano.function([X_sym], [nll, kl, nll + kl],
                                   updates=updates)
    encode_function = theano.function([X_sym], [code_mu, code_log_sigma])
    decode_function = theano.function([samp], [out])
    checkpoint_dict = {}
    checkpoint_dict["fit_function"] = fit_function
    checkpoint_dict["encode_function"] = encode_function
    checkpoint_dict["decode_function"] = decode_function
    previous_epoch_results = None
else:
    checkpoint_dict = load_checkpoint(save_path)
    fit_function = checkpoint_dict["fit_function"]
    encode_function = checkpoint_dict["encode_function"]
    decode_function = checkpoint_dict["decode_function"]
    previous_epoch_results = checkpoint_dict["previous_epoch_results"]


def status_func(status_number, epoch_number, epoch_results):
    checkpoint_status_func(save_path, checkpoint_dict, epoch_results)

epoch_results = iterate_function(fit_function, [X], minibatch_size,
                                 list_of_output_names=["nll", "kl",
                                                       "lower_bound"],
                                 n_epochs=5000,
                                 status_func=status_func,
                                 previous_epoch_results=previous_epoch_results,
                                 shuffle=True,
                                 random_state=random_state)

--------------------------------------------------------------------------------
/graphics_code/vae/vae_code.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kastnerkyle/SciPy2015/e0870cb9b3cfeca60c9ee2e2b6b805f266aac5c3/graphics_code/vae/vae_code.gif
--------------------------------------------------------------------------------
/graphics_code/vae/vae_reconstruction.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kastnerkyle/SciPy2015/e0870cb9b3cfeca60c9ee2e2b6b805f266aac5c3/graphics_code/vae/vae_reconstruction.png
--------------------------------------------------------------------------------
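As a rough, standalone illustration of the arithmetic behind ``gaussian_log_kl`` and ``gaussian_log_sample_layer`` as used in ``vae.py``, the sketch below repeats the same computation in plain Python for a single code vector, with no Theano required. It assumes, as the ``0.5 *`` factors in those helpers suggest, that the encoder's "log sigma" output is treated as a log-variance. The names ``kl_to_standard_normal`` and ``reparameterized_sample`` are hypothetical and are not part of this repository.

```python
import math
import random


def kl_to_standard_normal(mu, log_var):
    """KL(N(mu, diag(exp(log_var))) || N(0, I)) for one code vector.

    Hypothetical helper: same expression as gaussian_log_kl once the
    0.5 factor is folded in, i.e. -log_var + mu^2 + exp(log_var) - 1.
    """
    return 0.5 * sum(-lv + m ** 2 + math.exp(lv) - 1.0
                     for m, lv in zip(mu, log_var))


def reparameterized_sample(mu, log_var, rng):
    """Draw z = mu + sigma * eps with eps ~ N(0, 1), sigma = exp(0.5 * log_var).

    Hypothetical helper: the randomness lives only in eps, so z stays a
    deterministic function of mu and log_var.
    """
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mu, log_var)]


rng = random.Random(1999)
mu = [0.0, 1.0]
log_var = [0.0, 0.0]  # unit variance in both dimensions

kl = kl_to_standard_normal(mu, log_var)
z = reparameterized_sample(mu, log_var, rng)
print("kl:", kl)
print("z:", z)
```

With unit variance only the mean term contributes, so the KL here is 0.5 * (0 + 1) = 0.5. A code that matches the prior exactly (zero mean, zero log-variance) gives a KL of 0.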