├── LICENSE.md
├── README.md
├── keras
    ├── README.md
    ├── cifar10_cnn.py
    └── weightnorm.py
├── lasagne
    ├── README.md
    ├── nn.py
    └── train.py
└── tensorflow
    ├── README.md
    └── nn.py
/LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [year] [fullname] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **Status:** Archive (code is provided as-is, no updates expected) 2 | 3 | 4 | # Weight Normalization 5 | 6 | This repo contains example code for [Weight Normalization](https://arxiv.org/abs/1602.07868), as described in the following 7 | paper: 8 | 9 | **Weight Normalization: A Simple Reparameterization to Accelerate Training of Deep Neural Networks**, by 10 | Tim Salimans and Diederik P. Kingma. 11 | 12 | - The folder 'lasagne' contains code using the Lasagne package for Theano. This code was used to run the CIFAR-10 experiments in the paper. 13 | - The folder 'tensorflow' contains a single nn.py file with a direct implementation copied from our [PixelCNN++](https://github.com/openai/pixel-cnn) repository. 14 | - The folder 'keras' contains example code for use with the Keras package. 15 | 16 | ## Citation 17 | 18 | If you find this code useful, please cite us in your work: 19 | 20 | ``` 21 | @inproceedings{Salimans2016WeightNorm, 22 | title={Weight Normalization: A Simple Reparameterization to Accelerate Training of Deep Neural Networks}, 23 | author={Tim Salimans and Diederik P. Kingma}, 24 | booktitle={Neural Information Processing Systems 2016}, 25 | year={2016} 26 | } 27 | ``` 28 | -------------------------------------------------------------------------------- /keras/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Weight Normalization using Keras 3 | 4 | Example code for using Weight Normalization with [Keras](https://keras.io). 5 | 6 | ```cifar10_cnn.py``` contains the standard CIFAR-10 example from Keras, with lines 64 and 69 edited to include weight normalization and data dependent initialization.
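The two edits amount to swapping in the weight-normalized optimizer from ```weightnorm.py``` and running a data-based initialization pass on a small batch before training. A minimal sketch of those edits, as they appear in ```cifar10_cnn.py``` (where ```model``` is the compiled Keras model and ```X_train``` the CIFAR-10 training images from that script):

```python
# the edit at line 64: use the weight-normalized SGD optimizer instead of keras.optimizers.SGD
from weightnorm import SGDWithWeightnorm
sgd_wn = SGDWithWeightnorm(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd_wn, metrics=['accuracy'])

# the edit at line 69: data-based initialization of the weight-normalized parameters,
# using the first 100 training images
from weightnorm import data_based_init
data_based_init(model, X_train[:100])
```

```weightnorm.py``` also provides an ```AdamWithWeightnorm``` optimizer that can be dropped in the same way in place of Keras' Adam.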
-------------------------------------------------------------------------------- /keras/cifar10_cnn.py: -------------------------------------------------------------------------------- 1 | ''' 2 | CIFAR-10 example from https://github.com/fchollet/keras/blob/master/examples/cifar10_cnn.py 3 | Now with weight normalization. Lines 64 and 69 contain the changes w.r.t. original. 4 | ''' 5 | 6 | from __future__ import print_function 7 | from keras.datasets import cifar10 8 | from keras.preprocessing.image import ImageDataGenerator 9 | from keras.models import Sequential 10 | from keras.layers import Dense, Dropout, Activation, Flatten 11 | from keras.layers import Convolution2D, MaxPooling2D 12 | from keras.utils import np_utils 13 | 14 | batch_size = 32 15 | nb_classes = 10 16 | nb_epoch = 200 17 | data_augmentation = True 18 | 19 | # input image dimensions 20 | img_rows, img_cols = 32, 32 21 | # the CIFAR10 images are RGB 22 | img_channels = 3 23 | 24 | # the data, shuffled and split between train and test sets 25 | (X_train, y_train), (X_test, y_test) = cifar10.load_data() 26 | print('X_train shape:', X_train.shape) 27 | print(X_train.shape[0], 'train samples') 28 | print(X_test.shape[0], 'test samples') 29 | X_train = X_train.astype('float32') 30 | X_test = X_test.astype('float32') 31 | X_train /= 255 32 | X_test /= 255 33 | 34 | # convert class vectors to binary class matrices 35 | Y_train = np_utils.to_categorical(y_train, nb_classes) 36 | Y_test = np_utils.to_categorical(y_test, nb_classes) 37 | 38 | model = Sequential() 39 | 40 | model.add(Convolution2D(32, 3, 3, border_mode='same', 41 | input_shape=X_train.shape[1:])) 42 | model.add(Activation('relu')) 43 | model.add(Convolution2D(32, 3, 3)) 44 | model.add(Activation('relu')) 45 | model.add(MaxPooling2D(pool_size=(2, 2))) 46 | model.add(Dropout(0.25)) 47 | 48 | model.add(Convolution2D(64, 3, 3, border_mode='same')) 49 | model.add(Activation('relu')) 50 | model.add(Convolution2D(64, 3, 3)) 51 | model.add(Activation('relu')) 52 | model.add(MaxPooling2D(pool_size=(2, 2))) 53 | model.add(Dropout(0.25)) 54 | 55 | model.add(Flatten()) 56 | model.add(Dense(512)) 57 | model.add(Activation('relu')) 58 | model.add(Dropout(0.5)) 59 | model.add(Dense(nb_classes)) 60 | model.add(Activation('softmax')) 61 | 62 | # let's train the model using SGD + momentum (how original). 
EDIT: now with weight normalization, so slightly more original ;-) 63 | from weightnorm import SGDWithWeightnorm 64 | sgd_wn = SGDWithWeightnorm(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True) 65 | model.compile(loss='categorical_crossentropy',optimizer=sgd_wn,metrics=['accuracy']) 66 | 67 | # data based initialization of parameters 68 | from weightnorm import data_based_init 69 | data_based_init(model, X_train[:100]) 70 | 71 | 72 | if not data_augmentation: 73 | print('Not using data augmentation.') 74 | model.fit(X_train, Y_train, 75 | batch_size=batch_size, 76 | nb_epoch=nb_epoch, 77 | validation_data=(X_test, Y_test), 78 | shuffle=True) 79 | else: 80 | print('Using real-time data augmentation.') 81 | 82 | # this will do preprocessing and realtime data augmentation 83 | datagen = ImageDataGenerator( 84 | featurewise_center=False, # set input mean to 0 over the dataset 85 | samplewise_center=False, # set each sample mean to 0 86 | featurewise_std_normalization=False, # divide inputs by std of the dataset 87 | samplewise_std_normalization=False, # divide each input by its std 88 | zca_whitening=False, # apply ZCA whitening 89 | rotation_range=0, # randomly rotate images in the range (degrees, 0 to 180) 90 | width_shift_range=0.1, # randomly shift images horizontally (fraction of total width) 91 | height_shift_range=0.1, # randomly shift images vertically (fraction of total height) 92 | horizontal_flip=True, # randomly flip images 93 | vertical_flip=False) # randomly flip images 94 | 95 | # compute quantities required for featurewise normalization 96 | # (std, mean, and principal components if ZCA whitening is applied) 97 | datagen.fit(X_train) 98 | 99 | # fit the model on the batches generated by datagen.flow() 100 | model.fit_generator(datagen.flow(X_train, Y_train, 101 | batch_size=batch_size), 102 | samples_per_epoch=X_train.shape[0], 103 | nb_epoch=nb_epoch, 104 | validation_data=(X_test, Y_test)) 105 | -------------------------------------------------------------------------------- /keras/weightnorm.py: -------------------------------------------------------------------------------- 1 | from keras import backend as K 2 | from keras.optimizers import SGD,Adam 3 | import tensorflow as tf 4 | 5 | # adapted from keras.optimizers.SGD 6 | class SGDWithWeightnorm(SGD): 7 | def get_updates(self, params, constraints, loss): 8 | grads = self.get_gradients(loss, params) 9 | self.updates = [] 10 | 11 | lr = self.lr 12 | if self.initial_decay > 0: 13 | lr *= (1. / (1. 
+ self.decay * self.iterations)) 14 | self.updates .append(K.update_add(self.iterations, 1)) 15 | 16 | # momentum 17 | shapes = [K.get_variable_shape(p) for p in params] 18 | moments = [K.zeros(shape) for shape in shapes] 19 | self.weights = [self.iterations] + moments 20 | for p, g, m in zip(params, grads, moments): 21 | 22 | # if a weight tensor (len > 1) use weight normalized parameterization 23 | ps = K.get_variable_shape(p) 24 | if len(ps) > 1: 25 | 26 | # get weight normalization parameters 27 | V, V_norm, V_scaler, g_param, grad_g, grad_V = get_weightnorm_params_and_grads(p, g) 28 | 29 | # momentum container for the 'g' parameter 30 | V_scaler_shape = K.get_variable_shape(V_scaler) 31 | m_g = K.zeros(V_scaler_shape) 32 | 33 | # update g parameters 34 | v_g = self.momentum * m_g - lr * grad_g # velocity 35 | self.updates.append(K.update(m_g, v_g)) 36 | if self.nesterov: 37 | new_g_param = g_param + self.momentum * v_g - lr * grad_g 38 | else: 39 | new_g_param = g_param + v_g 40 | 41 | # update V parameters 42 | v_v = self.momentum * m - lr * grad_V # velocity 43 | self.updates.append(K.update(m, v_v)) 44 | if self.nesterov: 45 | new_V_param = V + self.momentum * v_v - lr * grad_V 46 | else: 47 | new_V_param = V + v_v 48 | 49 | # if there are constraints we apply them to V, not W 50 | if p in constraints: 51 | c = constraints[p] 52 | new_V_param = c(new_V_param) 53 | 54 | # wn param updates --> W updates 55 | add_weightnorm_param_updates(self.updates, new_V_param, new_g_param, p, V_scaler) 56 | 57 | else: # normal SGD with momentum 58 | v = self.momentum * m - lr * g # velocity 59 | self.updates.append(K.update(m, v)) 60 | 61 | if self.nesterov: 62 | new_p = p + self.momentum * v - lr * g 63 | else: 64 | new_p = p + v 65 | 66 | # apply constraints 67 | if p in constraints: 68 | c = constraints[p] 69 | new_p = c(new_p) 70 | 71 | self.updates.append(K.update(p, new_p)) 72 | return self.updates 73 | 74 | # adapted from keras.optimizers.Adam 75 | class AdamWithWeightnorm(Adam): 76 | def get_updates(self, params, constraints, loss): 77 | grads = self.get_gradients(loss, params) 78 | self.updates = [K.update_add(self.iterations, 1)] 79 | 80 | lr = self.lr 81 | if self.initial_decay > 0: 82 | lr *= (1. / (1. + self.decay * self.iterations)) 83 | 84 | t = self.iterations + 1 85 | lr_t = lr * K.sqrt(1. - K.pow(self.beta_2, t)) / (1. - K.pow(self.beta_1, t)) 86 | 87 | shapes = [K.get_variable_shape(p) for p in params] 88 | ms = [K.zeros(shape) for shape in shapes] 89 | vs = [K.zeros(shape) for shape in shapes] 90 | self.weights = [self.iterations] + ms + vs 91 | 92 | for p, g, m, v in zip(params, grads, ms, vs): 93 | 94 | # if a weight tensor (len > 1) use weight normalized parameterization 95 | # this is the only part changed w.r.t. keras.optimizers.Adam 96 | ps = K.get_variable_shape(p) 97 | if len(ps)>1: 98 | 99 | # get weight normalization parameters 100 | V, V_norm, V_scaler, g_param, grad_g, grad_V = get_weightnorm_params_and_grads(p, g) 101 | 102 | # Adam containers for the 'g' parameter 103 | V_scaler_shape = K.get_variable_shape(V_scaler) 104 | m_g = K.zeros(V_scaler_shape) 105 | v_g = K.zeros(V_scaler_shape) 106 | 107 | # update g parameters 108 | m_g_t = (self.beta_1 * m_g) + (1. - self.beta_1) * grad_g 109 | v_g_t = (self.beta_2 * v_g) + (1. 
- self.beta_2) * K.square(grad_g) 110 | new_g_param = g_param - lr_t * m_g_t / (K.sqrt(v_g_t) + self.epsilon) 111 | self.updates.append(K.update(m_g, m_g_t)) 112 | self.updates.append(K.update(v_g, v_g_t)) 113 | 114 | # update V parameters 115 | m_t = (self.beta_1 * m) + (1. - self.beta_1) * grad_V 116 | v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(grad_V) 117 | new_V_param = V - lr_t * m_t / (K.sqrt(v_t) + self.epsilon) 118 | self.updates.append(K.update(m, m_t)) 119 | self.updates.append(K.update(v, v_t)) 120 | 121 | # if there are constraints we apply them to V, not W 122 | if p in constraints: 123 | c = constraints[p] 124 | new_V_param = c(new_V_param) 125 | 126 | # wn param updates --> W updates 127 | add_weightnorm_param_updates(self.updates, new_V_param, new_g_param, p, V_scaler) 128 | 129 | else: # do optimization normally 130 | m_t = (self.beta_1 * m) + (1. - self.beta_1) * g 131 | v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g) 132 | p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon) 133 | 134 | self.updates.append(K.update(m, m_t)) 135 | self.updates.append(K.update(v, v_t)) 136 | 137 | new_p = p_t 138 | # apply constraints 139 | if p in constraints: 140 | c = constraints[p] 141 | new_p = c(new_p) 142 | self.updates.append(K.update(p, new_p)) 143 | return self.updates 144 | 145 | 146 | def get_weightnorm_params_and_grads(p, g): 147 | ps = K.get_variable_shape(p) 148 | 149 | # construct weight scaler: V_scaler = g/||V|| 150 | V_scaler_shape = (ps[-1],) # assumes we're using tensorflow! 151 | V_scaler = K.ones(V_scaler_shape) # init to ones, so effective parameters don't change 152 | 153 | # get V parameters = ||V||/g * W 154 | norm_axes = [i for i in range(len(ps) - 1)] 155 | V = p / tf.reshape(V_scaler, [1] * len(norm_axes) + [-1]) 156 | 157 | # split V_scaler into ||V|| and g parameters 158 | V_norm = tf.sqrt(tf.reduce_sum(tf.square(V), norm_axes)) 159 | g_param = V_scaler * V_norm 160 | 161 | # get grad in V,g parameters 162 | grad_g = tf.reduce_sum(g * V, norm_axes) / V_norm 163 | grad_V = tf.reshape(V_scaler, [1] * len(norm_axes) + [-1]) * \ 164 | (g - tf.reshape(grad_g / V_norm, [1] * len(norm_axes) + [-1]) * V) 165 | 166 | return V, V_norm, V_scaler, g_param, grad_g, grad_V 167 | 168 | 169 | def add_weightnorm_param_updates(updates, new_V_param, new_g_param, W, V_scaler): 170 | ps = K.get_variable_shape(new_V_param) 171 | norm_axes = [i for i in range(len(ps) - 1)] 172 | 173 | # update W and V_scaler 174 | new_V_norm = tf.sqrt(tf.reduce_sum(tf.square(new_V_param), norm_axes)) 175 | new_V_scaler = new_g_param / new_V_norm 176 | new_W = tf.reshape(new_V_scaler, [1] * len(norm_axes) + [-1]) * new_V_param 177 | updates.append(K.update(W, new_W)) 178 | updates.append(K.update(V_scaler, new_V_scaler)) 179 | 180 | 181 | # data based initialization for a given Keras model 182 | def data_based_init(model, input): 183 | 184 | # input can be dict, numpy array, or list of numpy arrays 185 | if type(input) is dict: 186 | feed_dict = input 187 | elif type(input) is list: 188 | feed_dict = {tf_inp: np_inp for tf_inp,np_inp in zip(model.inputs,input)} 189 | else: 190 | feed_dict = {model.inputs[0]: input} 191 | 192 | # add learning phase if required 193 | if model.uses_learning_phase and K.learning_phase() not in feed_dict: 194 | feed_dict.update({K.learning_phase(): 1}) 195 | 196 | # get all layer name, output, weight, bias tuples 197 | layer_output_weight_bias = [] 198 | for l in model.layers: 199 | if hasattr(l, 'W') and hasattr(l, 'b'): 200 | assert(l.built) 201 
| layer_output_weight_bias.append( (l.name,l.get_output_at(0),l.W,l.b) ) # if more than one node, only use the first 202 | 203 | # iterate over our list and do data dependent init 204 | sess = K.get_session() 205 | for l,o,W,b in layer_output_weight_bias: 206 | print('Performing data dependent initialization for layer ' + l) 207 | m,v = tf.nn.moments(o, [i for i in range(len(o.get_shape())-1)]) 208 | s = tf.sqrt(v + 1e-10) 209 | updates = tf.group(W.assign(W/tf.reshape(s,[1]*(len(W.get_shape())-1)+[-1])), b.assign((b-m)/s)) 210 | sess.run(updates, feed_dict) 211 | -------------------------------------------------------------------------------- /lasagne/README.md: -------------------------------------------------------------------------------- 1 | # Weight Normalization using Lasagne 2 | Lasagne code for "Weight Normalization: A Simple Reparameterization to Accelerate Training of Deep Neural Networks": http://arxiv.org/abs/1602.07868 3 | -------------------------------------------------------------------------------- /lasagne/nn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import theano as th 3 | import theano.tensor as T 4 | from scipy import linalg 5 | import lasagne 6 | 7 | class ZCA(object): 8 | def __init__(self, regularization=1e-5, x=None): 9 | self.regularization = regularization 10 | if x is not None: 11 | self.fit(x) 12 | 13 | def fit(self, x): 14 | s = x.shape 15 | x = x.copy().reshape((s[0],np.prod(s[1:]))) 16 | m = np.mean(x, axis=0) 17 | x -= m 18 | sigma = np.dot(x.T,x) / x.shape[0] 19 | U, S, V = linalg.svd(sigma) 20 | tmp = np.dot(U, np.diag(1./np.sqrt(S+self.regularization))) 21 | tmp2 = np.dot(U, np.diag(np.sqrt(S+self.regularization))) 22 | self.ZCA_mat = th.shared(np.dot(tmp, U.T).astype(th.config.floatX)) 23 | self.inv_ZCA_mat = th.shared(np.dot(tmp2, U.T).astype(th.config.floatX)) 24 | self.mean = th.shared(m.astype(th.config.floatX)) 25 | 26 | def apply(self, x): 27 | s = x.shape 28 | if isinstance(x, np.ndarray): 29 | return np.dot(x.reshape((s[0],np.prod(s[1:]))) - self.mean.get_value(), self.ZCA_mat.get_value()).reshape(s) 30 | elif isinstance(x, T.TensorVariable): 31 | return T.dot(x.flatten(2) - self.mean.dimshuffle('x',0), self.ZCA_mat).reshape(s) 32 | else: 33 | raise NotImplementedError("Whitening only implemented for numpy arrays or Theano TensorVariables") 34 | 35 | def invert(self, x): 36 | s = x.shape 37 | if isinstance(x, np.ndarray): 38 | return (np.dot(x.reshape((s[0],np.prod(s[1:]))), self.inv_ZCA_mat.get_value()) + self.mean.get_value()).reshape(s) 39 | elif isinstance(x, T.TensorVariable): 40 | return (T.dot(x.flatten(2), self.inv_ZCA_mat) + self.mean.dimshuffle('x',0)).reshape(s) 41 | else: 42 | raise NotImplementedError("Whitening only implemented for numpy arrays or Theano TensorVariables") 43 | 44 | # T.nnet.relu has some issues with very large inputs, this is more stable 45 | def relu(x): 46 | return T.maximum(x, 0) 47 | 48 | def lrelu(x, a=0.1): 49 | return T.maximum(x, a*x) 50 | 51 | def log_sum_exp(x, axis=1): 52 | m = T.max(x, axis=axis) 53 | return m+T.log(T.sum(T.exp(x-m.dimshuffle(0,'x')), axis=axis)) 54 | 55 | def adamax_updates(params, cost, lr=0.001, mom1=0.9, mom2=0.999): 56 | updates = [] 57 | grads = T.grad(cost, params) 58 | for p, g in zip(params, grads): 59 | mg = th.shared(np.cast[th.config.floatX](p.get_value() * 0.)) 60 | v = th.shared(np.cast[th.config.floatX](p.get_value() * 0.)) 61 | if mom1>0: 62 | v_t = mom1*v + (1. 
- mom1)*g 63 | updates.append((v,v_t)) 64 | else: 65 | v_t = g 66 | mg_t = T.maximum(mom2*mg, abs(g)) 67 | g_t = v_t / (mg_t + 1e-6) 68 | p_t = p - lr * g_t 69 | updates.append((mg, mg_t)) 70 | updates.append((p, p_t)) 71 | return updates 72 | 73 | def adam_updates(params, cost, lr=0.001, mom1=0.9, mom2=0.999): 74 | updates = [] 75 | grads = T.grad(cost, params) 76 | t = th.shared(np.cast[th.config.floatX](1.)) 77 | for p, g in zip(params, grads): 78 | v = th.shared(np.cast[th.config.floatX](p.get_value() * 0.)) 79 | mg = th.shared(np.cast[th.config.floatX](p.get_value() * 0.)) 80 | v_t = mom1*v + (1. - mom1)*g 81 | mg_t = mom2*mg + (1. - mom2)*T.square(g) 82 | v_hat = v_t / (1. - mom1 ** t) 83 | mg_hat = mg_t / (1. - mom2 ** t) 84 | g_t = v_hat / T.sqrt(mg_hat + 1e-8) 85 | p_t = p - lr * g_t 86 | updates.append((v, v_t)) 87 | updates.append((mg, mg_t)) 88 | updates.append((p, p_t)) 89 | updates.append((t, t+1)) 90 | return updates 91 | 92 | def softmax_loss(p_true, output_before_softmax): 93 | output_before_softmax -= T.max(output_before_softmax, axis=1, keepdims=True) 94 | if p_true.ndim==2: 95 | return T.mean(T.log(T.sum(T.exp(output_before_softmax),axis=1)) - T.sum(p_true*output_before_softmax, axis=1)) 96 | else: 97 | return T.mean(T.log(T.sum(T.exp(output_before_softmax),axis=1)) - output_before_softmax[T.arange(p_true.shape[0]),p_true]) 98 | 99 | class BatchNormLayer(lasagne.layers.Layer): 100 | def __init__(self, incoming, b=lasagne.init.Constant(0.), g=lasagne.init.Constant(1.), 101 | W=lasagne.init.Normal(0.05), nonlinearity=relu, **kwargs): 102 | super(BatchNormLayer, self).__init__(incoming, **kwargs) 103 | self.nonlinearity = nonlinearity 104 | k = self.input_shape[1] 105 | if b is not None: 106 | self.b = self.add_param(b, (k,), name="b", regularizable=False) 107 | if g is not None: 108 | self.g = self.add_param(g, (k,), name="g") 109 | self.avg_batch_mean = self.add_param(lasagne.init.Constant(0.), (k,), name="avg_batch_mean", regularizable=False, trainable=False) 110 | self.avg_batch_var = self.add_param(lasagne.init.Constant(1.), (k,), name="avg_batch_var", regularizable=False, trainable=False) 111 | incoming.W.set_value(W.sample(incoming.W.get_value().shape)) 112 | if len(self.input_shape)==4: 113 | self.axes_to_sum = (0,2,3) 114 | self.dimshuffle_args = ['x',0,'x','x'] 115 | else: 116 | self.axes_to_sum = 0 117 | self.dimshuffle_args = ['x',0] 118 | 119 | def get_output_for(self, input, deterministic=False, **kwargs): 120 | if deterministic: 121 | norm_features = (input-self.avg_batch_mean.dimshuffle(*self.dimshuffle_args)) / T.sqrt(1e-6 + self.avg_batch_var).dimshuffle(*self.dimshuffle_args) 122 | else: 123 | batch_mean = T.mean(input,axis=self.axes_to_sum).flatten() 124 | centered_input = input-batch_mean.dimshuffle(*self.dimshuffle_args) 125 | batch_var = T.mean(T.square(centered_input),axis=self.axes_to_sum).flatten() 126 | batch_stdv = T.sqrt(1e-6 + batch_var) 127 | norm_features = centered_input / batch_stdv.dimshuffle(*self.dimshuffle_args) 128 | 129 | # BN updates 130 | new_m = 0.9*self.avg_batch_mean + 0.1*batch_mean 131 | new_v = 0.9*self.avg_batch_var + T.cast((0.1*input.shape[0])/(input.shape[0]-1.), th.config.floatX)*batch_var 132 | self.bn_updates = [(self.avg_batch_mean, new_m), (self.avg_batch_var, new_v)] 133 | 134 | if hasattr(self, 'g'): 135 | activation = norm_features*self.g.dimshuffle(*self.dimshuffle_args) 136 | else: 137 | activation = norm_features 138 | if hasattr(self, 'b'): 139 | activation += self.b.dimshuffle(*self.dimshuffle_args) 140 | 141 
| return self.nonlinearity(activation) 142 | 143 | def batch_norm(layer, b=lasagne.init.Constant(0.), g=lasagne.init.Constant(1.), **kwargs): 144 | """ 145 | adapted from https://gist.github.com/f0k/f1a6bd3c8585c400c190 146 | """ 147 | nonlinearity = getattr(layer, 'nonlinearity', None) 148 | if nonlinearity is not None: 149 | layer.nonlinearity = lasagne.nonlinearities.identity 150 | if hasattr(layer, 'b'): 151 | del layer.params[layer.b] 152 | layer.b = None 153 | return BatchNormLayer(layer, b, g, nonlinearity=nonlinearity, **kwargs) 154 | 155 | class GlobalAvgLayer(lasagne.layers.Layer): 156 | def __init__(self, incoming, **kwargs): 157 | super(GlobalAvgLayer, self).__init__(incoming, **kwargs) 158 | def get_output_for(self, input, **kwargs): 159 | return T.mean(input, axis=(2,3)) 160 | def get_output_shape_for(self, input_shape): 161 | return input_shape[:2] 162 | 163 | class MeanOnlyBNLayer(lasagne.layers.Layer): 164 | def __init__(self, incoming, b=lasagne.init.Constant(0.), g=lasagne.init.Constant(1.), 165 | W=lasagne.init.Normal(0.05), nonlinearity=relu, **kwargs): 166 | super(MeanOnlyBNLayer, self).__init__(incoming, **kwargs) 167 | self.nonlinearity = nonlinearity 168 | k = self.input_shape[1] 169 | if b is not None: 170 | self.b = self.add_param(b, (k,), name="b", regularizable=False) 171 | if g is not None: 172 | self.g = self.add_param(g, (k,), name="g") 173 | self.avg_batch_mean = self.add_param(lasagne.init.Constant(0.), (k,), name="avg_batch_mean", regularizable=False, trainable=False) 174 | if len(self.input_shape)==4: 175 | self.axes_to_sum = (0,2,3) 176 | self.dimshuffle_args = ['x',0,'x','x'] 177 | else: 178 | self.axes_to_sum = 0 179 | self.dimshuffle_args = ['x',0] 180 | 181 | # scale weights in layer below 182 | incoming.W_param = incoming.W 183 | incoming.W_param.set_value(W.sample(incoming.W_param.get_value().shape)) 184 | if incoming.W_param.ndim==4: 185 | W_axes_to_sum = (1,2,3) 186 | W_dimshuffle_args = [0,'x','x','x'] 187 | else: 188 | W_axes_to_sum = 0 189 | W_dimshuffle_args = ['x',0] 190 | if g is not None: 191 | incoming.W = incoming.W_param * (self.g/T.sqrt(T.sum(T.square(incoming.W_param),axis=W_axes_to_sum))).dimshuffle(*W_dimshuffle_args) 192 | else: 193 | incoming.W = incoming.W_param / T.sqrt(T.sum(T.square(incoming.W_param),axis=W_axes_to_sum,keepdims=True)) 194 | 195 | def get_output_for(self, input, deterministic=False, init=False, **kwargs): 196 | if deterministic: 197 | activation = input - self.avg_batch_mean.dimshuffle(*self.dimshuffle_args) 198 | else: 199 | m = T.mean(input,axis=self.axes_to_sum) 200 | activation = input - m.dimshuffle(*self.dimshuffle_args) 201 | self.bn_updates = [(self.avg_batch_mean, 0.9*self.avg_batch_mean + 0.1*m)] 202 | if init: 203 | stdv = T.sqrt(T.mean(T.square(activation),axis=self.axes_to_sum)) 204 | activation /= stdv.dimshuffle(*self.dimshuffle_args) 205 | self.init_updates = [(self.g, self.g/stdv)] 206 | if hasattr(self, 'b'): 207 | activation += self.b.dimshuffle(*self.dimshuffle_args) 208 | 209 | return self.nonlinearity(activation) 210 | 211 | def mean_only_bn(layer, **kwargs): 212 | nonlinearity = getattr(layer, 'nonlinearity', None) 213 | if nonlinearity is not None: 214 | layer.nonlinearity = lasagne.nonlinearities.identity 215 | if hasattr(layer, 'b'): 216 | del layer.params[layer.b] 217 | layer.b = None 218 | return MeanOnlyBNLayer(layer, nonlinearity=nonlinearity, **kwargs) 219 | 220 | class WeightNormLayer(lasagne.layers.Layer): 221 | def __init__(self, incoming, b=lasagne.init.Constant(0.), 
g=lasagne.init.Constant(1.), 222 | W=lasagne.init.Normal(0.05), nonlinearity=relu, **kwargs): 223 | super(WeightNormLayer, self).__init__(incoming, **kwargs) 224 | self.nonlinearity = nonlinearity 225 | k = self.input_shape[1] 226 | if b is not None: 227 | self.b = self.add_param(b, (k,), name="b", regularizable=False) 228 | if g is not None: 229 | self.g = self.add_param(g, (k,), name="g") 230 | if len(self.input_shape)==4: 231 | self.axes_to_sum = (0,2,3) 232 | self.dimshuffle_args = ['x',0,'x','x'] 233 | else: 234 | self.axes_to_sum = 0 235 | self.dimshuffle_args = ['x',0] 236 | 237 | # scale weights in layer below 238 | incoming.W_param = incoming.W 239 | incoming.W_param.set_value(W.sample(incoming.W_param.get_value().shape)) 240 | if incoming.W_param.ndim==4: 241 | W_axes_to_sum = (1,2,3) 242 | W_dimshuffle_args = [0,'x','x','x'] 243 | else: 244 | W_axes_to_sum = 0 245 | W_dimshuffle_args = ['x',0] 246 | if g is not None: 247 | incoming.W = incoming.W_param * (self.g/T.sqrt(T.sum(T.square(incoming.W_param),axis=W_axes_to_sum))).dimshuffle(*W_dimshuffle_args) 248 | else: 249 | incoming.W = incoming.W_param / T.sqrt(T.sum(T.square(incoming.W_param),axis=W_axes_to_sum,keepdims=True)) 250 | 251 | def get_output_for(self, input, init=False, **kwargs): 252 | if init: 253 | m = T.mean(input, self.axes_to_sum) 254 | input -= m.dimshuffle(*self.dimshuffle_args) 255 | stdv = T.sqrt(T.mean(T.square(input),axis=self.axes_to_sum)) 256 | input /= stdv.dimshuffle(*self.dimshuffle_args) 257 | self.init_updates = [(self.b, -m/stdv), (self.g, self.g/stdv)] 258 | elif hasattr(self,'b'): 259 | input += self.b.dimshuffle(*self.dimshuffle_args) 260 | 261 | return self.nonlinearity(input) 262 | 263 | def weight_norm(layer, **kwargs): 264 | nonlinearity = getattr(layer, 'nonlinearity', None) 265 | if nonlinearity is not None: 266 | layer.nonlinearity = lasagne.nonlinearities.identity 267 | if hasattr(layer, 'b'): 268 | del layer.params[layer.b] 269 | layer.b = None 270 | return WeightNormLayer(layer, nonlinearity=nonlinearity, **kwargs) 271 | 272 | class InitLayer(lasagne.layers.Layer): 273 | def __init__(self, incoming, b=lasagne.init.Constant(0.), g=lasagne.init.Constant(1.), nonlinearity=relu, **kwargs): 274 | super(InitLayer, self).__init__(incoming, **kwargs) 275 | self.nonlinearity = nonlinearity 276 | k = self.input_shape[1] 277 | if b is not None: 278 | self.b = self.add_param(b, (k,), name="b", regularizable=False) 279 | if g is not None: 280 | self.g = self.add_param(g, (k,), name="g", regularizable=False, trainable=False) 281 | if len(self.input_shape)==4: 282 | self.axes_to_sum = (0,2,3) 283 | self.dimshuffle_args = ['x',0,'x','x'] 284 | else: 285 | self.axes_to_sum = 0 286 | self.dimshuffle_args = ['x',0] 287 | 288 | # scale weights in layer below 289 | incoming.W_param = incoming.W 290 | if incoming.W_param.ndim==4: 291 | W_dimshuffle_args = [0,'x','x','x'] 292 | else: 293 | W_dimshuffle_args = ['x',0] 294 | incoming.W = self.g.dimshuffle(*W_dimshuffle_args) * incoming.W_param 295 | 296 | def get_output_for(self, input, init=False, **kwargs): 297 | if init: 298 | m = T.mean(input, self.axes_to_sum) 299 | input -= m.dimshuffle(*self.dimshuffle_args) 300 | stdv = T.sqrt(T.mean(T.square(input),axis=self.axes_to_sum)) 301 | input /= stdv.dimshuffle(*self.dimshuffle_args) 302 | self.init_updates = [(self.b, -m/stdv), (self.g, self.g/stdv)] 303 | elif hasattr(self,'b'): 304 | input += self.b.dimshuffle(*self.dimshuffle_args) 305 | 306 | return self.nonlinearity(input) 307 | 308 | def no_norm(layer, 
**kwargs): 309 | nonlinearity = getattr(layer, 'nonlinearity', None) 310 | if nonlinearity is not None: 311 | layer.nonlinearity = lasagne.nonlinearities.identity 312 | if hasattr(layer, 'b'): 313 | del layer.params[layer.b] 314 | layer.b = None 315 | return InitLayer(layer, nonlinearity=nonlinearity, **kwargs) 316 | -------------------------------------------------------------------------------- /lasagne/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cPickle 3 | import time 4 | import os 5 | import logging 6 | import numpy as np 7 | import theano as th 8 | import theano.tensor as T 9 | from theano.sandbox.rng_mrg import MRG_RandomStreams 10 | import lasagne 11 | import lasagne.layers as ll 12 | from lasagne.layers import dnn 13 | import nn 14 | logging.basicConfig(level=logging.INFO) 15 | 16 | # settings 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--seed', default=1, type=int) 19 | parser.add_argument('--batch_size', default=100, type=int) 20 | parser.add_argument('--norm_type', default='no_norm', type=str) 21 | parser.add_argument('--learning_rate', default=0.001, type=float) 22 | args = parser.parse_args() 23 | logging.info(args) 24 | 25 | # fixed random seeds 26 | rng = np.random.RandomState(args.seed) 27 | theano_rng = MRG_RandomStreams(rng.randint(2 ** 15)) 28 | lasagne.random.set_rng(np.random.RandomState(rng.randint(2 ** 15))) 29 | 30 | # setup output 31 | time_str = time.strftime("%m-%d-%H-%M", time.gmtime()) 32 | exp_dir = args.norm_type + "_" + time_str + "_" + "{}".format(args.learning_rate).replace(".", "p") 33 | try: 34 | os.stat(exp_dir) 35 | except: 36 | os.makedirs(exp_dir) 37 | logging.info("OPENING " + exp_dir + '/results.csv') 38 | results_file = open(exp_dir + '/results.csv', 'w') 39 | results_file.write('epoch, time, train_error, test_error\n') 40 | results_file.flush() 41 | 42 | # load CIFAR-10 data 43 | def unpickle(file): 44 | fo = open(file, 'rb') 45 | d = cPickle.load(fo) 46 | fo.close() 47 | return {'x': np.cast[th.config.floatX]((-127.5 + d['data'].reshape((10000,3,32,32)))/128.), 'y': np.array(d['labels']).astype(np.uint8)} 48 | 49 | train_data = [unpickle('/home/ubuntu/data/cifar-10-python/cifar-10-batches-py/data_batch_' + str(i)) for i in range(1,6)] 50 | trainx = np.concatenate([d['x'] for d in train_data],axis=0) 51 | trainy = np.concatenate([d['y'] for d in train_data]) 52 | test_data = unpickle('/home/ubuntu/data/cifar-10-python/cifar-10-batches-py/test_batch') 53 | testx = test_data['x'] 54 | testy = test_data['y'] 55 | nr_batches_train = int(trainx.shape[0]/args.batch_size) 56 | nr_batches_test = int(testx.shape[0]/args.batch_size) 57 | 58 | # whitening 59 | whitener = nn.ZCA(x=trainx) 60 | trainx_white = whitener.apply(trainx) 61 | testx_white = whitener.apply(testx) 62 | 63 | # specify model 64 | if args.norm_type=='weight_norm': 65 | normalizer = lambda l: nn.weight_norm(l) 66 | elif args.norm_type=='batch_norm': 67 | normalizer = lambda l: nn.batch_norm(l) 68 | elif args.norm_type=='mean_only_bn': 69 | normalizer = lambda l: nn.mean_only_bn(l) 70 | elif args.norm_type=='no_norm': 71 | normalizer = lambda l: nn.no_norm(l) 72 | else: 73 | raise NotImplementedError('incorrect norm type') 74 | 75 | layers = [ll.InputLayer(shape=(None, 3, 32, 32))] 76 | layers.append(ll.GaussianNoiseLayer(layers[-1], sigma=0.15)) 77 | layers.append(normalizer(dnn.Conv2DDNNLayer(layers[-1], 96, (3,3), pad=1, nonlinearity=nn.lrelu))) 78 | 
layers.append(normalizer(dnn.Conv2DDNNLayer(layers[-1], 96, (3,3), pad=1, nonlinearity=nn.lrelu))) 79 | layers.append(normalizer(dnn.Conv2DDNNLayer(layers[-1], 96, (3,3), pad=1, nonlinearity=nn.lrelu))) 80 | layers.append(ll.MaxPool2DLayer(layers[-1], 2)) 81 | layers.append(ll.DropoutLayer(layers[-1], p=0.5)) 82 | layers.append(normalizer(dnn.Conv2DDNNLayer(layers[-1], 192, (3,3), pad=1, nonlinearity=nn.lrelu))) 83 | layers.append(normalizer(dnn.Conv2DDNNLayer(layers[-1], 192, (3,3), pad=1, nonlinearity=nn.lrelu))) 84 | layers.append(normalizer(dnn.Conv2DDNNLayer(layers[-1], 192, (3,3), pad=1, nonlinearity=nn.lrelu))) 85 | layers.append(ll.MaxPool2DLayer(layers[-1], 2)) 86 | layers.append(ll.DropoutLayer(layers[-1], p=0.5)) 87 | layers.append(normalizer(dnn.Conv2DDNNLayer(layers[-1], 192, (3,3), pad=0, nonlinearity=nn.lrelu))) 88 | layers.append(normalizer(ll.NINLayer(layers[-1], num_units=192, nonlinearity=nn.lrelu))) 89 | layers.append(normalizer(ll.NINLayer(layers[-1], num_units=192, nonlinearity=nn.lrelu))) 90 | layers.append(nn.GlobalAvgLayer(layers[-1])) 91 | layers.append(normalizer(ll.DenseLayer(layers[-1], num_units=10, nonlinearity=None))) 92 | 93 | # initialization 94 | x = T.tensor4() 95 | temp = ll.get_output(layers[-1], x, init=True) 96 | init_updates = [u for l in layers for u in getattr(l,'init_updates',[])] 97 | 98 | # discriminative cost & updates 99 | output_before_softmax = ll.get_output(layers[-1], x) 100 | bn_updates = [u for l in layers for u in getattr(l,'bn_updates',[])] 101 | y = T.ivector() 102 | cost = nn.softmax_loss(y, output_before_softmax) 103 | train_err = T.mean(T.neq(T.argmax(output_before_softmax,axis=1),y)) 104 | params = ll.get_all_params(layers, trainable=True) 105 | lr = T.scalar() 106 | mom1 = T.scalar() 107 | param_updates = nn.adam_updates(params, cost, lr=lr, mom1=mom1) 108 | 109 | test_output_before_softmax = ll.get_output(layers[-1], x, deterministic=True) 110 | test_err = T.mean(T.neq(T.argmax(test_output_before_softmax,axis=1),y)) 111 | 112 | # compile Theano functions 113 | train_batch = th.function(inputs=[x,y,lr,mom1], outputs=train_err, updates=param_updates+bn_updates) 114 | test_batch = th.function(inputs=[x,y], outputs=test_err) 115 | initfun = th.function(inputs=[x], outputs=None, updates=init_updates, on_unused_input='ignore') 116 | 117 | # //////////// perform training ////////////// 118 | begin_all = time.time() 119 | for epoch in range(200): 120 | begin_epoch = time.time() 121 | lr = np.cast[th.config.floatX](args.learning_rate * np.minimum(2. - epoch/100., 1.)) 122 | if epoch<100: 123 | mom1 = 0.9 124 | else: 125 | mom1 = 0.5 126 | 127 | # permute the training data 128 | inds = rng.permutation(trainx_white.shape[0]) 129 | trainx_white = trainx_white[inds] 130 | trainy = trainy[inds] 131 | 132 | # init params if first epoch 133 | if epoch==0: 134 | initfun(trainx_white[:500]) 135 | 136 | # train 137 | train_err = 0. 138 | for t in range(nr_batches_train): 139 | train_err += train_batch(trainx_white[t*args.batch_size:(t+1)*args.batch_size], 140 | trainy[t*args.batch_size:(t+1)*args.batch_size],lr,mom1) 141 | train_err /= nr_batches_train 142 | 143 | # test 144 | test_err = 0. 
145 | for t in range(nr_batches_test): 146 | test_err += test_batch(testx_white[t*args.batch_size:(t+1)*args.batch_size],testy[t*args.batch_size:(t+1)*args.batch_size]) 147 | test_err /= nr_batches_test 148 | 149 | logging.info('Iteration %d, time = %ds, train_err = %.4f, test_err = %.4f' % (epoch, time.time()-begin_epoch, train_err, test_err)) 150 | results_file.write('%d, %d, %.4f, %.4f\n' % (epoch, time.time()-begin_all, train_err, test_err)) 151 | results_file.flush() 152 | 153 | -------------------------------------------------------------------------------- /tensorflow/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Direct implementation of Weight Normalization in Tensorflow 3 | 4 | The ```nn.py``` file contains an example of a direct implementation of weight normalization and data dependent initialization in Tensorflow. For use, see our [PixelCNN++](https://github.com/openai/pixel-cnn) repository. -------------------------------------------------------------------------------- /tensorflow/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various tensorflow utilities 3 | """ 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | from tensorflow.contrib.framework.python.ops import add_arg_scope 8 | 9 | def int_shape(x): 10 | return list(map(int, x.get_shape())) 11 | 12 | def concat_elu(x): 13 | """ like concatenated ReLU (http://arxiv.org/abs/1603.05201), but then with ELU """ 14 | axis = len(x.get_shape())-1 15 | return tf.nn.elu(tf.concat(axis, [x, -x])) 16 | 17 | def log_sum_exp(x): 18 | """ numerically stable log_sum_exp implementation that prevents overflow """ 19 | axis = len(x.get_shape())-1 20 | m = tf.reduce_max(x, axis) 21 | m2 = tf.reduce_max(x, axis, keep_dims=True) 22 | return m + tf.log(tf.reduce_sum(tf.exp(x-m2), axis)) 23 | 24 | def log_prob_from_logits(x): 25 | """ numerically stable log_softmax implementation that prevents overflow """ 26 | axis = len(x.get_shape())-1 27 | m = tf.reduce_max(x, axis, keep_dims=True) 28 | return x - m - tf.log(tf.reduce_sum(tf.exp(x-m), axis, keep_dims=True)) 29 | 30 | def discretized_mix_logistic_loss(x,l,sum_all=True): 31 | """ log-likelihood for mixture of discretized logistics, assumes the data has been rescaled to [-1,1] interval """ 32 | xs = int_shape(x) # true image (i.e. labels) to regress to, e.g. (B,32,32,3) 33 | ls = int_shape(l) # predicted distribution, e.g. (B,32,32,100) 34 | nr_mix = int(ls[-1] / 10) # here and below: unpacking the params of the mixture of logistics 35 | logit_probs = l[:,:,:,:nr_mix] 36 | l = tf.reshape(l[:,:,:,nr_mix:], xs + [nr_mix*3]) 37 | means = l[:,:,:,:,:nr_mix] 38 | log_scales = tf.maximum(l[:,:,:,:,nr_mix:2*nr_mix], -7.) 39 | coeffs = tf.nn.tanh(l[:,:,:,:,2*nr_mix:3*nr_mix]) 40 | x = tf.reshape(x, xs + [1]) + tf.zeros(xs + [nr_mix]) # here and below: getting the means and adjusting them based on preceding sub-pixels 41 | m2 = tf.reshape(means[:,:,:,1,:] + coeffs[:, :, :, 0, :] * x[:, :, :, 0, :], [xs[0],xs[1],xs[2],1,nr_mix]) 42 | m3 = tf.reshape(means[:, :, :, 2, :] + coeffs[:, :, :, 1, :] * x[:, :, :, 0, :] + coeffs[:, :, :, 2, :] * x[:, :, :, 1, :], [xs[0],xs[1],xs[2],1,nr_mix]) 43 | means = tf.concat(3,[tf.reshape(means[:,:,:,0,:], [xs[0],xs[1],xs[2],1,nr_mix]), m2, m3]) 44 | centered_x = x - means 45 | inv_stdv = tf.exp(-log_scales) 46 | plus_in = inv_stdv * (centered_x + 1./255.) 47 | cdf_plus = tf.nn.sigmoid(plus_in) 48 | min_in = inv_stdv * (centered_x - 1./255.) 
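# note: with pixel values rescaled from {0,...,255} to [-1,1], adjacent discrete values are 2/255 apart,
# so x +/- 1./255. marks the edges of the bin around the observed sub-pixel value; the difference of the
# two sigmoids (cdf_plus minus cdf_min, i.e. cdf_delta below) is the probability the logistic assigns to that bin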
49 | cdf_min = tf.nn.sigmoid(min_in) 50 | log_cdf_plus = plus_in - tf.nn.softplus(plus_in) # log probability for edge case of 0 (before scaling) 51 | log_one_minus_cdf_min = -tf.nn.softplus(min_in) # log probability for edge case of 255 (before scaling) 52 | cdf_delta = cdf_plus - cdf_min # probability for all other cases 53 | mid_in = inv_stdv * centered_x 54 | log_pdf_mid = mid_in - log_scales - 2.*tf.nn.softplus(mid_in) # log probability in the center of the bin, to be used in extreme cases (not actually used in our code) 55 | 56 | # now select the right output: left edge case, right edge case, normal case, extremely low prob case (doesn't actually happen for us) 57 | 58 | # this is what we are really doing, but using the robust version below for extreme cases in other applications and to avoid NaN issue with tf.select() 59 | # log_probs = tf.select(x < -0.999, log_cdf_plus, tf.select(x > 0.999, log_one_minus_cdf_min, tf.log(cdf_delta))) 60 | 61 | # robust version, that still works if probabilities are below 1e-5 (which never happens in our code) 62 | # tensorflow backpropagates through tf.select() by multiplying with zero instead of selecting: this requires use to use some ugly tricks to avoid potential NaNs 63 | # the 1e-12 in tf.maximum(cdf_delta, 1e-12) is never actually used as output, it's purely there to get around the tf.select() gradient issue 64 | # if the probability on a sub-pixel is below 1e-5, we use an approximation based on the assumption that the log-density is constant in the bin of the observed sub-pixel value 65 | log_probs = tf.select(x < -0.999, log_cdf_plus, tf.select(x > 0.999, log_one_minus_cdf_min, tf.select(cdf_delta > 1e-5, tf.log(tf.maximum(cdf_delta, 1e-12)), log_pdf_mid - np.log(127.5)))) 66 | 67 | log_probs = tf.reduce_sum(log_probs,3) + log_prob_from_logits(logit_probs) 68 | if sum_all: 69 | return -tf.reduce_sum(log_sum_exp(log_probs)) 70 | else: 71 | return -tf.reduce_sum(log_sum_exp(log_probs),[1,2]) 72 | 73 | def sample_from_discretized_mix_logistic(l,nr_mix): 74 | ls = int_shape(l) 75 | xs = ls[:-1] + [3] 76 | # unpack parameters 77 | logit_probs = l[:, :, :, :nr_mix] 78 | l = tf.reshape(l[:, :, :, nr_mix:], xs + [nr_mix*3]) 79 | # sample mixture indicator from softmax 80 | sel = tf.one_hot(tf.argmax(logit_probs - tf.log(-tf.log(tf.random_uniform(logit_probs.get_shape(), minval=1e-5, maxval=1. - 1e-5))), 3), depth=nr_mix, dtype=tf.float32) 81 | sel = tf.reshape(sel, xs[:-1] + [1,nr_mix]) 82 | # select logistic parameters 83 | means = tf.reduce_sum(l[:,:,:,:,:nr_mix]*sel,4) 84 | log_scales = tf.maximum(tf.reduce_sum(l[:,:,:,:,nr_mix:2*nr_mix]*sel,4), -7.) 85 | coeffs = tf.reduce_sum(tf.nn.tanh(l[:,:,:,:,2*nr_mix:3*nr_mix])*sel,4) 86 | # sample from logistic & clip to interval 87 | # we don't actually round to the nearest 8bit value when sampling 88 | u = tf.random_uniform(means.get_shape(), minval=1e-5, maxval=1. - 1e-5) 89 | x = means + tf.exp(log_scales)*(tf.log(u) - tf.log(1. - u)) 90 | x0 = tf.minimum(tf.maximum(x[:,:,:,0], -1.), 1.) 91 | x1 = tf.minimum(tf.maximum(x[:,:,:,1] + coeffs[:,:,:,0]*x0, -1.), 1.) 92 | x2 = tf.minimum(tf.maximum(x[:,:,:,2] + coeffs[:,:,:,1]*x0 + coeffs[:,:,:,2]*x1, -1.), 1.) 
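# note: as in the loss above, the second and third channels are shifted by a learned linear function
# of the previously sampled sub-pixels (via coeffs) before clipping to [-1,1]; the three channels are
# then stacked back into a (batch, height, width, 3) image tensor below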
93 | return tf.concat(3,[tf.reshape(x0,xs[:-1]+[1]), tf.reshape(x1,xs[:-1]+[1]), tf.reshape(x2,xs[:-1]+[1])]) 94 | 95 | def get_var_maybe_avg(var_name, ema, **kwargs): 96 | ''' utility for retrieving polyak averaged params ''' 97 | v = tf.get_variable(var_name, **kwargs) 98 | if ema is not None: 99 | v = ema.average(v) 100 | return v 101 | 102 | def get_vars_maybe_avg(var_names, ema, **kwargs): 103 | ''' utility for retrieving polyak averaged params ''' 104 | vars = [] 105 | for vn in var_names: 106 | vars.append(get_var_maybe_avg(vn, ema, **kwargs)) 107 | return vars 108 | 109 | def adam_updates(params, cost_or_grads, lr=0.001, mom1=0.9, mom2=0.999): 110 | ''' Adam optimizer ''' 111 | updates = [] 112 | if type(cost_or_grads) is not list: 113 | grads = tf.gradients(cost_or_grads, params) 114 | else: 115 | grads = cost_or_grads 116 | t = tf.Variable(1., 'adam_t') 117 | for p, g in zip(params, grads): 118 | mg = tf.Variable(tf.zeros(p.get_shape()), p.name + '_adam_mg') 119 | if mom1>0: 120 | v = tf.Variable(tf.zeros(p.get_shape()), p.name + '_adam_v') 121 | v_t = mom1*v + (1. - mom1)*g 122 | v_hat = v_t / (1. - tf.pow(mom1,t)) 123 | updates.append(v.assign(v_t)) 124 | else: 125 | v_hat = g 126 | mg_t = mom2*mg + (1. - mom2)*tf.square(g) 127 | mg_hat = mg_t / (1. - tf.pow(mom2,t)) 128 | g_t = v_hat / tf.sqrt(mg_hat + 1e-8) 129 | p_t = p - lr * g_t 130 | updates.append(mg.assign(mg_t)) 131 | updates.append(p.assign(p_t)) 132 | updates.append(t.assign_add(1)) 133 | return tf.group(*updates) 134 | 135 | def get_name(layer_name, counters): 136 | ''' utlity for keeping track of layer names ''' 137 | if not layer_name in counters: 138 | counters[layer_name] = 0 139 | name = layer_name + '_' + str(counters[layer_name]) 140 | counters[layer_name] += 1 141 | return name 142 | 143 | @add_arg_scope 144 | def dense(x, num_units, nonlinearity=None, init_scale=1., counters={}, init=False, ema=None, **kwargs): 145 | ''' fully connected layer ''' 146 | name = get_name('dense', counters) 147 | with tf.variable_scope(name): 148 | if init: 149 | # data based initialization of parameters 150 | V = tf.get_variable('V', [int(x.get_shape()[1]),num_units], tf.float32, tf.random_normal_initializer(0, 0.05), trainable=True) 151 | V_norm = tf.nn.l2_normalize(V.initialized_value(), [0]) 152 | x_init = tf.matmul(x, V_norm) 153 | m_init, v_init = tf.nn.moments(x_init, [0]) 154 | scale_init = init_scale/tf.sqrt(v_init + 1e-10) 155 | g = tf.get_variable('g', dtype=tf.float32, initializer=scale_init, trainable=True) 156 | b = tf.get_variable('b', dtype=tf.float32, initializer=-m_init*scale_init, trainable=True) 157 | x_init = tf.reshape(scale_init,[1,num_units])*(x_init-tf.reshape(m_init,[1,num_units])) 158 | if nonlinearity is not None: 159 | x_init = nonlinearity(x_init) 160 | return x_init 161 | 162 | else: 163 | V,g,b = get_vars_maybe_avg(['V','g','b'], ema) 164 | tf.assert_variables_initialized([V,g,b]) 165 | 166 | # use weight normalization (Salimans & Kingma, 2016) 167 | x = tf.matmul(x, V) 168 | scaler = g/tf.sqrt(tf.reduce_sum(tf.square(V),[0])) 169 | x = tf.reshape(scaler,[1,num_units])*x + tf.reshape(b,[1,num_units]) 170 | 171 | # apply nonlinearity 172 | if nonlinearity is not None: 173 | x = nonlinearity(x) 174 | return x 175 | 176 | @add_arg_scope 177 | def conv2d(x, num_filters, filter_size=[3,3], stride=[1,1], pad='SAME', nonlinearity=None, init_scale=1., counters={}, init=False, ema=None, **kwargs): 178 | ''' convolutional layer ''' 179 | name = get_name('conv2d', counters) 180 | with 
tf.variable_scope(name): 181 | if init: 182 | # data based initialization of parameters 183 | V = tf.get_variable('V', filter_size+[int(x.get_shape()[-1]),num_filters], tf.float32, tf.random_normal_initializer(0, 0.05), trainable=True) 184 | V_norm = tf.nn.l2_normalize(V.initialized_value(), [0,1,2]) 185 | x_init = tf.nn.conv2d(x, V_norm, [1]+stride+[1], pad) 186 | m_init, v_init = tf.nn.moments(x_init, [0,1,2]) 187 | scale_init = init_scale/tf.sqrt(v_init + 1e-8) 188 | g = tf.get_variable('g', dtype=tf.float32, initializer=scale_init, trainable=True) 189 | b = tf.get_variable('b', dtype=tf.float32, initializer=-m_init*scale_init, trainable=True) 190 | x_init = tf.reshape(scale_init,[1,1,1,num_filters])*(x_init-tf.reshape(m_init,[1,1,1,num_filters])) 191 | if nonlinearity is not None: 192 | x_init = nonlinearity(x_init) 193 | return x_init 194 | 195 | else: 196 | V, g, b = get_vars_maybe_avg(['V', 'g', 'b'], ema) 197 | tf.assert_variables_initialized([V,g,b]) 198 | 199 | # use weight normalization (Salimans & Kingma, 2016) 200 | W = tf.reshape(g,[1,1,1,num_filters])*tf.nn.l2_normalize(V,[0,1,2]) 201 | 202 | # calculate convolutional layer output 203 | x = tf.nn.bias_add(tf.nn.conv2d(x, W, [1]+stride+[1], pad), b) 204 | 205 | # apply nonlinearity 206 | if nonlinearity is not None: 207 | x = nonlinearity(x) 208 | return x 209 | 210 | @add_arg_scope 211 | def deconv2d(x, num_filters, filter_size=[3,3], stride=[1,1], pad='SAME', nonlinearity=None, init_scale=1., counters={}, init=False, ema=None, **kwargs): 212 | ''' transposed convolutional layer ''' 213 | name = get_name('deconv2d', counters) 214 | xs = int_shape(x) 215 | if pad=='SAME': 216 | target_shape = [xs[0], xs[1]*stride[0], xs[2]*stride[1], num_filters] 217 | else: 218 | target_shape = [xs[0], xs[1]*stride[0] + filter_size[0]-1, xs[2]*stride[1] + filter_size[1]-1, num_filters] 219 | with tf.variable_scope(name): 220 | if init: 221 | # data based initialization of parameters 222 | V = tf.get_variable('V', filter_size+[num_filters,int(x.get_shape()[-1])], tf.float32, tf.random_normal_initializer(0, 0.05), trainable=True) 223 | V_norm = tf.nn.l2_normalize(V.initialized_value(), [0,1,3]) 224 | x_init = tf.nn.conv2d_transpose(x, V_norm, target_shape, [1]+stride+[1], padding=pad) 225 | m_init, v_init = tf.nn.moments(x_init, [0,1,2]) 226 | scale_init = init_scale/tf.sqrt(v_init + 1e-8) 227 | g = tf.get_variable('g', dtype=tf.float32, initializer=scale_init, trainable=True) 228 | b = tf.get_variable('b', dtype=tf.float32, initializer=-m_init*scale_init, trainable=True) 229 | x_init = tf.reshape(scale_init,[1,1,1,num_filters])*(x_init-tf.reshape(m_init,[1,1,1,num_filters])) 230 | if nonlinearity is not None: 231 | x_init = nonlinearity(x_init) 232 | return x_init 233 | 234 | else: 235 | V, g, b = get_vars_maybe_avg(['V', 'g', 'b'], ema) 236 | tf.assert_variables_initialized([V,g,b]) 237 | 238 | # use weight normalization (Salimans & Kingma, 2016) 239 | W = tf.reshape(g,[1,1,num_filters,1])*tf.nn.l2_normalize(V,[0,1,3]) 240 | 241 | # calculate convolutional layer output 242 | x = tf.nn.conv2d_transpose(x, W, target_shape, [1]+stride+[1], padding=pad) 243 | x = tf.nn.bias_add(x, b) 244 | 245 | # apply nonlinearity 246 | if nonlinearity is not None: 247 | x = nonlinearity(x) 248 | return x 249 | 250 | @add_arg_scope 251 | def nin(x, num_units, **kwargs): 252 | """ a network in network layer (1x1 CONV) """ 253 | s = int_shape(x) 254 | x = tf.reshape(x, [np.prod(s[:-1]),s[-1]]) 255 | x = dense(x, num_units, **kwargs) 256 | return tf.reshape(x, 
s[:-1]+[num_units]) 257 | 258 | ''' meta-layer consisting of multiple base layers ''' 259 | 260 | @add_arg_scope 261 | def gated_resnet(x, a=None, h=None, nonlinearity=concat_elu, conv=conv2d, init=False, counters={}, ema=None, dropout_p=0., **kwargs): 262 | xs = int_shape(x) 263 | num_filters = xs[-1] 264 | 265 | c1 = conv(nonlinearity(x), num_filters) 266 | if a is not None: # add short-cut connection if auxiliary input 'a' is given 267 | c1 += nin(nonlinearity(a), num_filters) 268 | c1 = nonlinearity(c1) 269 | if dropout_p > 0: 270 | c1 = tf.nn.dropout(c1, keep_prob=1. - dropout_p) 271 | c2 = conv(c1, num_filters * 2, init_scale=0.1) 272 | 273 | # add projection of h vector if included: conditional generation 274 | if h is not None: 275 | with tf.variable_scope(get_name('conditional_weights', counters)): 276 | hw = get_var_maybe_avg('hw', ema, shape=[int_shape(h)[-1], 2 * num_filters], dtype=tf.float32, 277 | initializer=tf.random_normal_initializer(0, 0.05), trainable=True) 278 | if init: 279 | hw = hw.initialized_value() 280 | c2 += tf.reshape(tf.matmul(h, hw), [xs[0], 1, 1, 2 * num_filters]) 281 | 282 | a, b = tf.split(3, 2, c2) 283 | c3 = a * tf.nn.sigmoid(b) 284 | return x + c3 285 | 286 | ''' utilities for shifting the image around, efficient alternative to masking convolutions ''' 287 | 288 | def down_shift(x): 289 | xs = int_shape(x) 290 | return tf.concat(1,[tf.zeros([xs[0],1,xs[2],xs[3]]), x[:,:xs[1]-1,:,:]]) 291 | 292 | def right_shift(x): 293 | xs = int_shape(x) 294 | return tf.concat(2,[tf.zeros([xs[0],xs[1],1,xs[3]]), x[:,:,:xs[2]-1,:]]) 295 | 296 | @add_arg_scope 297 | def down_shifted_conv2d(x, num_filters, filter_size=[2,3], stride=[1,1], **kwargs): 298 | x = tf.pad(x, [[0,0],[filter_size[0]-1,0], [int((filter_size[1]-1)/2),int((filter_size[1]-1)/2)],[0,0]]) 299 | return conv2d(x, num_filters, filter_size=filter_size, pad='VALID', stride=stride, **kwargs) 300 | 301 | @add_arg_scope 302 | def down_shifted_deconv2d(x, num_filters, filter_size=[2,3], stride=[1,1], **kwargs): 303 | x = deconv2d(x, num_filters, filter_size=filter_size, pad='VALID', stride=stride, **kwargs) 304 | xs = int_shape(x) 305 | return x[:,:(xs[1]-filter_size[0]+1),int((filter_size[1]-1)/2):(xs[2]-int((filter_size[1]-1)/2)),:] 306 | 307 | @add_arg_scope 308 | def down_right_shifted_conv2d(x, num_filters, filter_size=[2,2], stride=[1,1], **kwargs): 309 | x = tf.pad(x, [[0,0],[filter_size[0]-1, 0], [filter_size[1]-1, 0],[0,0]]) 310 | return conv2d(x, num_filters, filter_size=filter_size, pad='VALID', stride=stride, **kwargs) 311 | 312 | @add_arg_scope 313 | def down_right_shifted_deconv2d(x, num_filters, filter_size=[2,2], stride=[1,1], **kwargs): 314 | x = deconv2d(x, num_filters, filter_size=filter_size, pad='VALID', stride=stride, **kwargs) 315 | xs = int_shape(x) 316 | return x[:,:(xs[1]-filter_size[0]+1):,:(xs[2]-filter_size[1]+1),:] 317 | --------------------------------------------------------------------------------
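The ```dense```, ```conv2d``` and ```deconv2d``` layers in ```tensorflow/nn.py``` above implement weight normalization together with data-dependent initialization. As a rough usage sketch (not part of this repo, and assuming the TF 1.x-era graph API these files target): the graph is built twice through the same variable scope, once with ```init=True``` so that the data-dependent initializers are created, and once with ```reuse``` so that training and inference pick up the resulting weight-normalized variables. The scope name, placeholder shape, and random stand-in batch below are illustrative assumptions; see the [PixelCNN++](https://github.com/openai/pixel-cnn) repository referenced in ```tensorflow/README.md``` for the full pattern.

```python
import numpy as np
import tensorflow as tf
import nn  # tensorflow/nn.py, assumed importable from the working directory

x = tf.placeholder(tf.float32, shape=[100, 32, 32, 3])  # illustrative batch shape

# pass 1: init=True creates V, g, b with initializers that depend on the data fed into x
with tf.variable_scope('model'):
    h_init = nn.conv2d(x, 64, nonlinearity=tf.nn.relu, init=True, counters={})

# pass 2: rebuild the same layer with reuse, so it uses the weight-normalized variables above
with tf.variable_scope('model', reuse=True):
    h = nn.conv2d(x, 64, nonlinearity=tf.nn.relu, init=False, counters={})

with tf.Session() as sess:
    batch = np.random.rand(100, 32, 32, 3).astype('float32')  # stand-in for a real data batch
    # running the variable initializers while feeding a batch performs the data-dependent init
    sess.run(tf.global_variables_initializer(), {x: batch})
    out = sess.run(h, {x: batch})
```

Passing a fresh ```counters={}``` to each pass keeps the auto-generated layer names (```conv2d_0```, ...) aligned between the two passes, which is what makes the variable reuse line up.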