├── .gitattributes ├── .gitignore ├── API.py ├── CelebAValid.npz ├── GANcheckpoints.py ├── IAN.py ├── IAN_simple.npz ├── IAN_simple.py ├── IANv1.npz ├── IANv1.py ├── LICENSE ├── NPE.py ├── README.md ├── discgen_utils.py ├── layers.py ├── mask_generator.py ├── metrics_logging.py ├── pics └── .gitignore ├── sample_IAN.py ├── train_IAN.py └── train_IAN_simple.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.npz filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .*.swp 3 | .cache 4 | -------------------------------------------------------------------------------- /API.py: -------------------------------------------------------------------------------- 1 | # Plat Interface for Convenience and Interoperability 2 | # Adopted from Plat by Tom White : https://github.com/dribnet/plat 3 | 4 | import theano 5 | import theano.tensor as T 6 | import lasagne 7 | import imp 8 | import GANcheckpoints 9 | 10 | # Generic class for using IAN style models with the NPE. 
11 | class IAN: 12 | def __init__(self, config_path,dnn): 13 | """ 14 | Initializate class give either a filename or a model 15 | Usually this method will load a model from disk and store internally, 16 | but model can also be provided directly instead (useful when training) 17 | """ 18 | config_module = imp.load_source('config',config_path) 19 | self.cfg = config_module.cfg 20 | self.weights_fname = str(config_path)[:-3]+'.npz' 21 | self.model = config_module.get_model(dnn=dnn) 22 | 23 | # Load weights 24 | print('Loading weights') 25 | params = list(set(lasagne.layers.get_all_params(self.model['l_out'],trainable=True)+\ 26 | lasagne.layers.get_all_params(self.model['l_discrim'],trainable=True)+\ 27 | [x for x in lasagne.layers.get_all_params(self.model['l_out'])+\ 28 | lasagne.layers.get_all_params(self.model['l_discrim'])\ 29 | if x.name[-4:]=='mean' or x.name[-7:]=='inv_std'])) 30 | GANcheckpoints.load_weights(self.weights_fname,params) 31 | 32 | # Shuffle weights if using IAF with MADE 33 | if 'l_IAF_mu' in self.model: 34 | print ('Shuffling MADE masks') 35 | self.model['l_IAF_mu'].reset("Once") 36 | self.model['l_IAF_ls'].reset("Once") 37 | 38 | print('Compiling Theano Functions') 39 | # Input Tensor 40 | self.X = T.TensorType('float32', [False]*4)('X') 41 | 42 | # Latent Vector 43 | self.Z = T.TensorType('float32', [False]*2)('Z') 44 | 45 | # X_hat(Z) 46 | self.X_hat = lasagne.layers.get_output(self.model['l_out'],{self.model['l_Z']:self.Z},deterministic=True) 47 | self.X_hat_fn = theano.function([self.Z],self.X_hat) 48 | 49 | # Z_hat(X) 50 | self.Z_hat=lasagne.layers.get_output(self.model['l_Z'],{self.model['l_in']:self.X},deterministic=True) 51 | self.Z_hat_fn = theano.function([self.X],self.Z_hat) 52 | 53 | # Imgrad Functions 54 | r1,r2 = T.scalar('r1',dtype='int32'),T.scalar('r2',dtype='int32') 55 | c1,c2 = T.scalar('c',dtype='int32'),T.scalar('c2',dtype='int32') 56 | RGB = T.tensor4('RGB',dtype='float32') 57 | 58 | # Image Gradient Function, evaluates 
the change in latents which would lighten the image in the local area 59 | self.calculate_lighten_gradient = theano.function([c1,r1,c2,r2,self.Z],T.grad(T.mean(self.X_hat[0,:,r1:r2,c1:c2]),self.Z)) 60 | 61 | # Image Color Gradient Function, evaluates the change in latents which would push the image towards the local desired RGB value 62 | # Consider changing this to only take in a smaller RGB array, rather than a full-sized, indexed RGB array. 63 | # Also consider using the L1 loss instead of L2 64 | self.calculate_RGB_gradient = theano.function([c1,r1,c2,r2,RGB,self.Z],T.grad(T.mean((T.sqr(-self.X_hat[0,:,r1:r2,c1:c2]+RGB[0,:,r1:r2,c1:c2]))),self.Z)) # may need a T.mean 65 | 66 | def imgrad(self,c1,r1,c2,r2,z): 67 | """ 68 | Calculate the change in latents which would lighten the local image patch. 69 | """ 70 | return self.calculate_lighten_gradient(c1,r1,c2,r2,z) 71 | 72 | def imgradRGB(self,c1,r1,c2,r2,RGB,z): 73 | """ 74 | Calculate the change in latents which would move the local image patch towards the RGB value of RGB. 
75 | """ 76 | return self.calculate_RGB_gradient(c1,r1,c2,r2,RGB,z) 77 | 78 | def encode_images(self, images): 79 | """ 80 | Encode images x => z 81 | images is an n x 3 x s x s numpy array where: 82 | n = number of images 83 | 3 = R G B channels 84 | s = size of image (eg: 64, 128, etc) 85 | pixels values for each channel are encoded [-1,1] 86 | returns an n x z numpy array where: 87 | n = len(images) 88 | z = dimension of latent space 89 | """ 90 | return self.Z_hat_fn(images) 91 | 92 | def get_zdim(self): 93 | """ 94 | Returns the integer dimension of the latent z space 95 | """ 96 | return self.cfg['num_latents'] 97 | 98 | def sample_at(self, z): 99 | """ 100 | Decode images z => x 101 | z is an n x z numpy array where: 102 | n = len(images) 103 | z = dimension of latent space 104 | return images as an n x 3 x s x s numpy array where: 105 | n = number of images 106 | 3 = R G B channels 107 | s = size of image (eg: 64, 128, etc) 108 | pixels values for each channel are encoded [-1,1] 109 | """ 110 | return self.X_hat_fn(z) -------------------------------------------------------------------------------- /CelebAValid.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ajbrock/Neural-Photo-Editor/d234cf1f80cf8c8f621f871dc704dc43e212201f/CelebAValid.npz -------------------------------------------------------------------------------- /GANcheckpoints.py: -------------------------------------------------------------------------------- 1 | 2 | import logging 3 | import cPickle as pickle 4 | import warnings 5 | import numpy as np 6 | 7 | from path import Path 8 | 9 | import lasagne 10 | 11 | def save_weights(fname, params, metadata=None): 12 | """ assumes all params have unique names. 
13 | """ 14 | # Includes batchnorm params now 15 | names = [par.name for par in params] 16 | if len(names) != len(set(names)): 17 | raise ValueError('need unique param names') 18 | param_dict = { param.name : param.get_value(borrow=False) 19 | for param in params } 20 | if metadata is not None: 21 | param_dict['metadata'] = pickle.dumps(metadata) 22 | logging.info('saving {} parameters to {}'.format(len(params), fname)) 23 | # try to avoid half-written files 24 | fname = Path(fname) 25 | if fname.exists(): 26 | tmp_fname = Path(fname.stripext() + '.tmp.npz') # TODO yes, this is a hack 27 | np.savez_compressed(str(tmp_fname), **param_dict) 28 | tmp_fname.rename(fname) 29 | else: 30 | np.savez_compressed(str(fname), **param_dict) 31 | 32 | 33 | def load_weights(fname, params): 34 | # params = lasagne.layers.get_all_params(l_out,trainable=True)+[log_sigma]+[x for x in lasagne.layers.get_all_params(l_out) if x.name[-4:]=='mean' or x.name[-7:]=='inv_std'] 35 | names = [ par.name for par in params ] 36 | if len(names)!=len(set(names)): 37 | raise ValueError('need unique param names') 38 | 39 | param_dict = np.load(fname) 40 | for param in params: 41 | if param.name in param_dict: 42 | stored_shape = np.asarray(param_dict[param.name].shape) 43 | param_shape = np.asarray(param.get_value().shape) 44 | if not np.all(stored_shape == param_shape): 45 | warn_msg = 'shape mismatch:' 46 | warn_msg += '{} stored:{} new:{}'.format(param.name, stored_shape, param_shape) 47 | warn_msg += ', skipping' 48 | warnings.warn(warn_msg) 49 | else: 50 | param.set_value(param_dict[param.name]) 51 | else: 52 | logging.warn('unable to load parameter {} from {}'.format(param.name, fname)) 53 | if 'metadata' in param_dict: 54 | metadata = pickle.loads(str(param_dict['metadata'])) 55 | else: 56 | metadata = {} 57 | return metadata 58 | -------------------------------------------------------------------------------- /IAN.py: 
-------------------------------------------------------------------------------- 1 | ## IAN with randomized IAF 2 | 3 | import lasagne 4 | import lasagne.layers 5 | import lasagne.layers.dnn 6 | from lasagne.layers import SliceLayer as SL 7 | from lasagne.layers import batch_norm as BN 8 | from lasagne.layers import ElemwiseSumLayer as ESL 9 | from lasagne.layers import ElemwiseMergeLayer as EML 10 | from lasagne.layers import NonlinearityLayer as NL 11 | from lasagne.layers import DenseLayer as DL 12 | from lasagne.layers import Upscale2DLayer 13 | from lasagne.init import Normal as initmethod 14 | from lasagne.init import Orthogonal 15 | from lasagne.nonlinearities import elu 16 | from lasagne.nonlinearities import rectify as relu 17 | from lasagne.nonlinearities import LeakyRectify as lrelu 18 | from lasagne.nonlinearities import sigmoid 19 | from lasagne.layers.dnn import Conv2DDNNLayer as C2D 20 | from lasagne.layers.dnn import Pool2DDNNLayer as P2D 21 | from lasagne.layers import TransposedConv2DLayer as TC2D 22 | from lasagne.layers import ConcatLayer as CL 23 | import numpy as np 24 | import theano.tensor as T 25 | import theano 26 | from theano.sandbox.cuda.basic_ops import (as_cuda_ndarray_variable, 27 | host_from_gpu, 28 | gpu_contiguous, HostFromGpu, 29 | gpu_alloc_empty) 30 | from theano.sandbox.cuda.dnn import GpuDnnConvDesc, GpuDnnConv, GpuDnnConvGradI, dnn_conv, dnn_pool 31 | from math import sqrt 32 | 33 | 34 | 35 | from layers import MDBLOCK, DeconvLayer, MinibatchLayer, beta_layer, MADE,IAFLayer, GaussianSampleLayer, MDCL 36 | 37 | from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams 38 | lr_schedule = { 0: 0.0002,25:0.0001,50:0.00005,75:0.00001} 39 | cfg = {'batch_size' : 16, 40 | 'learning_rate' : lr_schedule, 41 | 'optimizer' : 'Adam', 42 | 'beta1' : 0.5, 43 | 'update_ratio' : 1, 44 | 'decay_rate' : 0, 45 | 'reg' : 1e-5, 46 | 'momentum' : 0.9, 47 | 'shuffle' : True, 48 | 'dims' : (64,64), 49 | 'n_channels' : 3, 50 | 
'batches_per_chunk': 64, 51 | 'max_epochs' :80, 52 | 'checkpoint_every_nth' : 1, 53 | 'num_latents': 100, 54 | 'recon_weight': 3.0, 55 | 'feature_weight': 1.0, 56 | 'dg_weight': 1.0, 57 | 'dd_weight':1.0, 58 | 'agr_weight':1.0, 59 | 'ags_weight':1.0, 60 | 'n_shuffles' : 1, 61 | 'ortho' : 1e-3, 62 | } 63 | 64 | 65 | 66 | 67 | def get_model(interp=False): 68 | dims, n_channels = tuple(cfg['dims']), cfg['n_channels'] 69 | shape = (None, n_channels)+dims 70 | l_in = lasagne.layers.InputLayer(shape=shape) 71 | l_enc_conv1 = C2D( 72 | incoming = l_in, 73 | num_filters = 128, 74 | filter_size = [5,5], 75 | stride = [2,2], 76 | pad = (2,2), 77 | W = initmethod(0.02), 78 | nonlinearity = lrelu(0.2), 79 | name = 'enc_conv1' 80 | ) 81 | l_enc_conv2 = BN(C2D( 82 | incoming = l_enc_conv1, 83 | num_filters = 256, 84 | filter_size = [5,5], 85 | stride = [2,2], 86 | pad = (2,2), 87 | W = initmethod(0.02), 88 | nonlinearity = lrelu(0.2), 89 | name = 'enc_conv2' 90 | ),name = 'bnorm2') 91 | l_enc_conv3 = BN(C2D( 92 | incoming = l_enc_conv2, 93 | num_filters = 512, 94 | filter_size = [5,5], 95 | stride = [2,2], 96 | pad = (2,2), 97 | W = initmethod(0.02), 98 | nonlinearity = lrelu(0.2), 99 | name = 'enc_conv3' 100 | ),name = 'bnorm3') 101 | l_enc_conv4 = BN(C2D( 102 | incoming = l_enc_conv3, 103 | num_filters = 1024, 104 | filter_size = [5,5], 105 | stride = [2,2], 106 | pad = (2,2), 107 | W = initmethod(0.02), 108 | nonlinearity = lrelu(0.2), 109 | name = 'enc_conv4' 110 | ),name = 'bnorm4') 111 | 112 | 113 | print(lasagne.layers.get_output_shape(l_enc_conv4,(196,3,64,64))) 114 | l_enc_fc1 = BN(DL( 115 | incoming = l_enc_conv4, 116 | num_units = 1000, 117 | W = initmethod(0.02), 118 | nonlinearity = relu, 119 | name = 'enc_fc1' 120 | ), 121 | name = 'bnorm_enc_fc1') 122 | 123 | # Define latent values 124 | l_enc_mu,l_enc_logsigma = [BN(DL(incoming = l_enc_fc1,num_units=cfg['num_latents'],nonlinearity = None,name='enc_mu'),name='mu_bnorm'), 125 | BN(DL(incoming = 
l_enc_fc1,num_units=cfg['num_latents'],nonlinearity = None,name='enc_logsigma'),name='ls_bnorm')] 126 | l_Z_IAF = GaussianSampleLayer(l_enc_mu, l_enc_logsigma, name='l_Z_IAF') 127 | l_IAF_mu,l_IAF_logsigma = [MADE(l_Z_IAF,[cfg['num_latents']],'l_IAF_mu'),MADE(l_Z_IAF,[cfg['num_latents']],'l_IAF_ls')] 128 | l_Z = IAFLayer(l_Z_IAF,l_IAF_mu,l_IAF_logsigma,name='l_Z') 129 | l_dec_fc2 = DL( 130 | incoming = l_Z, 131 | num_units = 512*16, 132 | nonlinearity = lrelu(0.2), 133 | W=initmethod(0.02), 134 | name='l_dec_fc2') 135 | l_unflatten = lasagne.layers.ReshapeLayer( 136 | incoming = l_dec_fc2, 137 | shape = ([0],512,4,4), 138 | ) 139 | l_dec_conv1 = DeconvLayer( 140 | incoming = l_unflatten, 141 | num_filters = 512, 142 | filter_size = [5,5], 143 | stride = [2,2], 144 | crop = (2,2), 145 | W = initmethod(0.02), 146 | nonlinearity = None, 147 | name = 'dec_conv1' 148 | ) 149 | l_dec_conv2a = MDBLOCK(incoming=l_dec_conv1,num_filters=512,scales=[0,2],name='dec_conv2a',nonlinearity=lrelu(0.2)) 150 | l_dec_conv2 = DeconvLayer( 151 | incoming = l_dec_conv2a, 152 | num_filters = 256, 153 | filter_size = [5,5], 154 | stride = [2,2], 155 | crop = (2,2), 156 | W = initmethod(0.02), 157 | nonlinearity = None, 158 | name = 'dec_conv2' 159 | ) 160 | l_dec_conv3a = MDBLOCK(incoming=l_dec_conv2,num_filters=256,scales=[0,2,3],name='dec_conv3a',nonlinearity=lrelu(0.2)) 161 | l_dec_conv3 = DeconvLayer( 162 | incoming = l_dec_conv3a, 163 | num_filters = 128, 164 | filter_size = [5,5], 165 | stride = [2,2], 166 | crop = (2,2), 167 | W = initmethod(0.02), 168 | nonlinearity = None, 169 | name = 'dec_conv3' 170 | ) 171 | l_dec_conv4a = MDBLOCK(incoming=l_dec_conv3,num_filters=128,scales=[0,2,3],name='dec_conv4a',nonlinearity=lrelu(0.2)) 172 | l_dec_conv4 = BN(DeconvLayer( 173 | incoming = l_dec_conv4a, 174 | num_filters = 128, 175 | filter_size = [5,5], 176 | stride = [2,2], 177 | crop = (2,2), 178 | W = initmethod(0.02), 179 | nonlinearity = lrelu(0.2), 180 | name = 'dec_conv4' 181 | 
),name = 'bnorm_dc4') 182 | 183 | R = NL(MDCL(l_dec_conv4, 184 | num_filters=2, 185 | scales = [2,3,4], 186 | name = 'R'),sigmoid) 187 | G = NL(ESL([MDCL(l_dec_conv4, 188 | num_filters=2, 189 | scales = [2,3,4], 190 | name = 'G_a' 191 | ), 192 | MDCL(R, 193 | num_filters=2, 194 | scales = [2,3,4], 195 | name = 'G_b' 196 | )]),sigmoid) 197 | B = NL(ESL([MDCL(l_dec_conv4, 198 | num_filters=2, 199 | scales = [2,3,4], 200 | name = 'B_a' 201 | ), 202 | MDCL(CL([R,G]), 203 | num_filters=2, 204 | scales = [2,3,4], 205 | name = 'B_b' 206 | )]),sigmoid) 207 | l_out=CL([beta_layer(SL(R,slice(0,1),1),SL(R,slice(1,2),1)),beta_layer(SL(G,slice(0,1),1),SL(G,slice(1,2),1)),beta_layer(SL(B,slice(0,1),1),SL(B,slice(1,2),1))]) 208 | 209 | 210 | minibatch_discrim = MinibatchLayer(lasagne.layers.GlobalPoolLayer(l_enc_conv4), num_kernels=500,name='minibatch_discrim') 211 | l_discrim = DL(incoming = minibatch_discrim, 212 | num_units = 3, 213 | nonlinearity = lasagne.nonlinearities.softmax, 214 | b = None, 215 | W=initmethod(0.02), 216 | name = 'discrimi') 217 | 218 | 219 | return {'l_in':l_in, 220 | 'l_out':l_out, 221 | 'l_mu':l_enc_mu, 222 | 'l_ls':l_enc_logsigma, 223 | 'l_Z':l_Z, 224 | 'l_IAF_mu': l_IAF_mu, 225 | 'l_IAF_ls': l_IAF_logsigma, 226 | 'l_Z_IAF': l_Z_IAF, 227 | 'l_introspect':[l_enc_conv1, l_enc_conv2,l_enc_conv3,l_enc_conv4], 228 | 'l_discrim' : l_discrim} 229 | 230 | -------------------------------------------------------------------------------- /IAN_simple.npz: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:82e5fd3ff68b2c9095935c9db269e086e2dd27704b629853e1f03473e7059bd7 3 | size 205207893 4 | -------------------------------------------------------------------------------- /IAN_simple.py: -------------------------------------------------------------------------------- 1 | ### Simple IAN model for use with Neural Photo Editor 2 | # This model is a simplified version of the 
Introspective Adversarial Network that does not 3 | # make use of Multiscale Dilated Convolutional blocks, Ternary Adversarial Loss, or an 4 | # autoregressive RGB-Beta layer. It's designed to be sleeker and to run on laptop GPUs with <1GB of memory. 5 | 6 | import numpy as np 7 | 8 | import lasagne 9 | import lasagne.layers 10 | 11 | from lasagne.layers import SliceLayer as SL 12 | from lasagne.layers import batch_norm as BN 13 | from lasagne.layers import ElemwiseSumLayer as ESL 14 | from lasagne.layers import NonlinearityLayer as NL 15 | from lasagne.layers import DenseLayer as DL 16 | from lasagne.init import Normal as initmethod 17 | from lasagne.nonlinearities import elu 18 | from lasagne.nonlinearities import rectify as relu 19 | from lasagne.nonlinearities import LeakyRectify as lrelu 20 | 21 | from lasagne.layers import TransposedConv2DLayer as TC2D 22 | from lasagne.layers import ConcatLayer as CL 23 | 24 | import theano.tensor as T 25 | 26 | from math import sqrt 27 | 28 | 29 | from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams 30 | 31 | from layers import GaussianSampleLayer,MinibatchLayer 32 | lr_schedule = { 0: 0.0002} 33 | cfg = {'batch_size' : 128, 34 | 'learning_rate' : lr_schedule, 35 | 'optimizer' : 'Adam', 36 | 'beta1' : 0.5, 37 | 'update_ratio' : 1, 38 | 'decay_rate' : 0, 39 | 'reg' : 1e-5, 40 | 'momentum' : 0.9, 41 | 'shuffle' : True, 42 | 'dims' : (64,64), 43 | 'n_channels' : 3, 44 | 'n_classes' : 10, 45 | 'batches_per_chunk': 64, 46 | 'max_epochs' :250, 47 | 'checkpoint_every_nth' : 1, 48 | 'num_latents': 100, 49 | 'recon_weight': 3.0, 50 | 'feature_weight': 1.0, 51 | } 52 | 53 | 54 | 55 | 56 | def get_model(dnn=True): 57 | if dnn: 58 | import lasagne.layers.dnn 59 | from lasagne.layers.dnn import Conv2DDNNLayer as C2D 60 | from theano.sandbox.cuda.basic_ops import (as_cuda_ndarray_variable, 61 | host_from_gpu, 62 | gpu_contiguous, HostFromGpu, 63 | gpu_alloc_empty) 64 | from theano.sandbox.cuda.dnn import GpuDnnConvDesc, 
GpuDnnConv, GpuDnnConvGradI, dnn_conv, dnn_pool 65 | from layers import DeconvLayer 66 | else: 67 | import lasagne.layers 68 | from lasagne.layers import Conv2DLayer as C2D 69 | 70 | dims, n_channels, n_classes = tuple(cfg['dims']), cfg['n_channels'], cfg['n_classes'] 71 | shape = (None, n_channels)+dims 72 | l_in = lasagne.layers.InputLayer(shape=shape) 73 | l_enc_conv1 = C2D( 74 | incoming = l_in, 75 | num_filters = 128, 76 | filter_size = [5,5], 77 | stride = [2,2], 78 | pad = (2,2), 79 | W = initmethod(0.02), 80 | nonlinearity = lrelu(0.2), 81 | flip_filters=False, 82 | name = 'enc_conv1' 83 | ) 84 | l_enc_conv2 = BN(C2D( 85 | incoming = l_enc_conv1, 86 | num_filters = 256, 87 | filter_size = [5,5], 88 | stride = [2,2], 89 | pad = (2,2), 90 | W = initmethod(0.02), 91 | nonlinearity = lrelu(0.2), 92 | flip_filters=False, 93 | name = 'enc_conv2' 94 | ),name = 'bnorm2') 95 | l_enc_conv3 = BN(C2D( 96 | incoming = l_enc_conv2, 97 | num_filters = 512, 98 | filter_size = [5,5], 99 | stride = [2,2], 100 | pad = (2,2), 101 | W = initmethod(0.02), 102 | nonlinearity = lrelu(0.2), 103 | flip_filters=False, 104 | name = 'enc_conv3' 105 | ),name = 'bnorm3') 106 | l_enc_conv4 = BN(C2D( 107 | incoming = l_enc_conv3, 108 | num_filters = 1024, 109 | filter_size = [5,5], 110 | stride = [2,2], 111 | pad = (2,2), 112 | W = initmethod(0.02), 113 | nonlinearity = lrelu(0.2), 114 | flip_filters=False, 115 | name = 'enc_conv4' 116 | ),name = 'bnorm4') 117 | l_enc_fc1 = BN(DL( 118 | incoming = l_enc_conv4, 119 | num_units = 1000, 120 | W = initmethod(0.02), 121 | nonlinearity = elu, 122 | name = 'enc_fc1' 123 | ), 124 | name = 'bnorm_enc_fc1') 125 | l_enc_mu,l_enc_logsigma = [BN(DL(incoming = l_enc_fc1,num_units=cfg['num_latents'],nonlinearity = None,name='enc_mu'),name='mu_bnorm'), 126 | BN(DL(incoming = l_enc_fc1,num_units=cfg['num_latents'],nonlinearity = None,name='enc_logsigma'),name='ls_bnorm')] 127 | 128 | l_Z = GaussianSampleLayer(l_enc_mu, l_enc_logsigma, name='l_Z') 129 | 
l_dec_fc2 = BN(DL( 130 | incoming = l_Z, 131 | num_units = 1024*16, 132 | nonlinearity = relu, 133 | W=initmethod(0.02), 134 | name='l_dec_fc2'), 135 | name = 'bnorm_dec_fc2') 136 | l_unflatten = lasagne.layers.ReshapeLayer( 137 | incoming = l_dec_fc2, 138 | shape = ([0],1024,4,4), 139 | ) 140 | if dnn: 141 | l_dec_conv1 = BN(DeconvLayer( 142 | incoming = l_unflatten, 143 | num_filters = 512, 144 | filter_size = [5,5], 145 | stride = [2,2], 146 | crop = (2,2), 147 | W = initmethod(0.02), 148 | nonlinearity = relu, 149 | name = 'dec_conv1' 150 | ),name = 'bnorm_dc1') 151 | l_dec_conv2 = BN(DeconvLayer( 152 | incoming = l_dec_conv1, 153 | num_filters = 256, 154 | filter_size = [5,5], 155 | stride = [2,2], 156 | crop = (2,2), 157 | W = initmethod(0.02), 158 | nonlinearity = relu, 159 | name = 'dec_conv2' 160 | ),name = 'bnorm_dc2') 161 | l_dec_conv3 = BN(DeconvLayer( 162 | incoming = l_dec_conv2, 163 | num_filters = 128, 164 | filter_size = [5,5], 165 | stride = [2,2], 166 | crop = (2,2), 167 | W = initmethod(0.02), 168 | nonlinearity = relu, 169 | name = 'dec_conv3' 170 | ),name = 'bnorm_dc3') 171 | l_out = DeconvLayer( 172 | incoming = l_dec_conv3, 173 | num_filters = 3, 174 | filter_size = [5,5], 175 | stride = [2,2], 176 | crop = (2,2), 177 | W = initmethod(0.02), 178 | b = None, 179 | nonlinearity = lasagne.nonlinearities.tanh, 180 | name = 'dec_out' 181 | ) 182 | else: 183 | l_dec_conv1 = SL(SL(BN(TC2D( 184 | incoming = l_unflatten, 185 | num_filters = 512, 186 | filter_size = [5,5], 187 | stride = [2,2], 188 | crop = (1,1), 189 | W = initmethod(0.02), 190 | nonlinearity = relu, 191 | name = 'dec_conv1' 192 | ),name = 'bnorm_dc1'),indices=slice(1,None),axis=2),indices=slice(1,None),axis=3) 193 | l_dec_conv2 = SL(SL(BN(TC2D( 194 | incoming = l_dec_conv1, 195 | num_filters = 256, 196 | filter_size = [5,5], 197 | stride = [2,2], 198 | crop = (1,1), 199 | W = initmethod(0.02), 200 | nonlinearity = relu, 201 | name = 'dec_conv2' 202 | ),name = 
'bnorm_dc2'),indices=slice(1,None),axis=2),indices=slice(1,None),axis=3) 203 | l_dec_conv3 = SL(SL(BN(TC2D( 204 | incoming = l_dec_conv2, 205 | num_filters = 128, 206 | filter_size = [5,5], 207 | stride = [2,2], 208 | crop = (1,1), 209 | W = initmethod(0.02), 210 | nonlinearity = relu, 211 | name = 'dec_conv3' 212 | ),name = 'bnorm_dc3'),indices=slice(1,None),axis=2),indices=slice(1,None),axis=3) 213 | l_out = SL(SL(TC2D( 214 | incoming = l_dec_conv3, 215 | num_filters = 3, 216 | filter_size = [5,5], 217 | stride = [2,2], 218 | crop = (1,1), 219 | W = initmethod(0.02), 220 | b = None, 221 | nonlinearity = lasagne.nonlinearities.tanh, 222 | name = 'dec_out' 223 | ),indices=slice(1,None),axis=2),indices=slice(1,None),axis=3) 224 | # l_in,num_filters=1,filter_size=[5,5],stride=[2,2],crop=[1,1],W=dc.W,b=None,nonlinearity=None) 225 | minibatch_discrim = MinibatchLayer(lasagne.layers.GlobalPoolLayer(l_enc_conv4), num_kernels=500,name='minibatch_discrim') 226 | l_discrim = DL(incoming = minibatch_discrim, 227 | num_units = 1, 228 | nonlinearity = lasagne.nonlinearities.sigmoid, 229 | b = None, 230 | W=initmethod(), 231 | name = 'discrimi') 232 | 233 | 234 | 235 | return {'l_in':l_in, 236 | 'l_out':l_out, 237 | 'l_mu':l_enc_mu, 238 | 'l_ls':l_enc_logsigma, 239 | 'l_Z':l_Z, 240 | 'l_introspect':[l_enc_conv1, l_enc_conv2,l_enc_conv3,l_enc_conv4], 241 | 'l_discrim' : l_discrim} 242 | 243 | 244 | -------------------------------------------------------------------------------- /IANv1.npz: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:7e6f4eedd4e06792c739f8599bb74345fe211f3bc9d56163c5c79c3c356182b2 3 | size 205961207 4 | -------------------------------------------------------------------------------- /IANv1.py: -------------------------------------------------------------------------------- 1 | ## IAN with binary adversarial loss and no orthogonal regularization 2 | 3 | import 
lasagne 4 | import lasagne.layers 5 | import lasagne.layers.dnn 6 | from lasagne.layers import SliceLayer as SL 7 | from lasagne.layers import batch_norm as BN 8 | from lasagne.layers import ElemwiseSumLayer as ESL 9 | from lasagne.layers import ElemwiseMergeLayer as EML 10 | from lasagne.layers import NonlinearityLayer as NL 11 | from lasagne.layers import DenseLayer as DL 12 | from lasagne.layers import Upscale2DLayer 13 | from lasagne.init import Normal as initmethod 14 | from lasagne.init import Orthogonal 15 | from lasagne.nonlinearities import elu 16 | from lasagne.nonlinearities import rectify as relu 17 | from lasagne.nonlinearities import LeakyRectify as lrelu 18 | from lasagne.nonlinearities import sigmoid 19 | from lasagne.layers.dnn import Conv2DDNNLayer as C2D 20 | from lasagne.layers.dnn import Pool2DDNNLayer as P2D 21 | from lasagne.layers import TransposedConv2DLayer as TC2D 22 | from lasagne.layers import ConcatLayer as CL 23 | import numpy as np 24 | import theano.tensor as T 25 | import theano 26 | from theano.sandbox.cuda.basic_ops import (as_cuda_ndarray_variable, 27 | host_from_gpu, 28 | gpu_contiguous, HostFromGpu, 29 | gpu_alloc_empty) 30 | from theano.sandbox.cuda.dnn import GpuDnnConvDesc, GpuDnnConv, GpuDnnConvGradI, dnn_conv, dnn_pool 31 | from math import sqrt 32 | 33 | 34 | 35 | from layers import MDBLOCK, DeconvLayer, MinibatchLayer, beta_layer, MADE, IAFLayer, GaussianSampleLayer, MDCL 36 | 37 | from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams 38 | lr_schedule = { 0: 0.0002,25:0.0001,50:0.00005,75:0.00001} 39 | cfg = {'batch_size' : 16, 40 | 'learning_rate' : lr_schedule, 41 | 'optimizer' : 'Adam', 42 | 'beta1' : 0.5, 43 | 'update_ratio' : 1, 44 | 'decay_rate' : 0, 45 | 'reg' : 1e-5, 46 | 'momentum' : 0.9, 47 | 'shuffle' : True, 48 | 'dims' : (64,64), 49 | 'n_channels' : 3, 50 | 'batches_per_chunk': 64, 51 | 'max_epochs' :150, 52 | 'checkpoint_every_nth' : 1, 53 | 'num_latents': 100, 54 | 'recon_weight': 3.0, 55 
| 'feature_weight': 1.0, 56 | 'dg_weight': 1.0, 57 | 'dd_weight':1.0, 58 | 'agr_weight':1.0, 59 | 'ags_weight':1.0, 60 | 'n_shuffles' : 1, 61 | } 62 | 63 | def get_model(interp=False): 64 | dims, n_channels = tuple(cfg['dims']), cfg['n_channels'] 65 | shape = (None, n_channels)+dims 66 | l_in = lasagne.layers.InputLayer(shape=shape) 67 | l_enc_conv1 = C2D( 68 | incoming = l_in, 69 | num_filters = 128, 70 | filter_size = [5,5], 71 | stride = [2,2], 72 | pad = (2,2), 73 | W = initmethod(0.02), 74 | nonlinearity = lrelu(0.2), 75 | name = 'enc_conv1' 76 | ) 77 | l_enc_conv2 = BN(C2D( 78 | incoming = l_enc_conv1, 79 | num_filters = 256, 80 | filter_size = [5,5], 81 | stride = [2,2], 82 | pad = (2,2), 83 | W = initmethod(0.02), 84 | nonlinearity = lrelu(0.2), 85 | name = 'enc_conv2' 86 | ),name = 'bnorm2') 87 | l_enc_conv3 = BN(C2D( 88 | incoming = l_enc_conv2, 89 | num_filters = 512, 90 | filter_size = [5,5], 91 | stride = [2,2], 92 | pad = (2,2), 93 | W = initmethod(0.02), 94 | nonlinearity = lrelu(0.2), 95 | name = 'enc_conv3' 96 | ),name = 'bnorm3') 97 | l_enc_conv4 = BN(C2D( 98 | incoming = l_enc_conv3, 99 | num_filters = 1024, 100 | filter_size = [5,5], 101 | stride = [2,2], 102 | pad = (2,2), 103 | W = initmethod(0.02), 104 | nonlinearity = lrelu(0.2), 105 | name = 'enc_conv4' 106 | ),name = 'bnorm4') 107 | 108 | 109 | print(lasagne.layers.get_output_shape(l_enc_conv4,(196,3,64,64))) 110 | l_enc_fc1 = BN(DL( 111 | incoming = l_enc_conv4, 112 | num_units = 1000, 113 | W = initmethod(0.02), 114 | nonlinearity = relu, 115 | name = 'enc_fc1' 116 | ), 117 | name = 'bnorm_enc_fc1') 118 | 119 | # Define latent values 120 | l_enc_mu,l_enc_logsigma = [BN(DL(incoming = l_enc_fc1,num_units=cfg['num_latents'],nonlinearity = None,name='enc_mu'),name='mu_bnorm'), 121 | BN(DL(incoming = l_enc_fc1,num_units=cfg['num_latents'],nonlinearity = None,name='enc_logsigma'),name='ls_bnorm')] 122 | l_Z_IAF = GaussianSampleLayer(l_enc_mu, l_enc_logsigma, name='l_Z_IAF') 123 | 
l_IAF_mu,l_IAF_logsigma = [MADE(l_Z_IAF,[cfg['num_latents']],'l_IAF_mu'),MADE(l_Z_IAF,[cfg['num_latents']],'l_IAF_ls')] 124 | l_Z = IAFLayer(l_Z_IAF,l_IAF_mu,l_IAF_logsigma,name='l_Z') 125 | l_dec_fc2 = DL( 126 | incoming = l_Z, 127 | num_units = 1024*16, 128 | nonlinearity = None, 129 | W=initmethod(0.02), 130 | name='l_dec_fc2') 131 | l_unflatten = lasagne.layers.ReshapeLayer( 132 | incoming = l_dec_fc2, 133 | shape = ([0],1024,4,4), 134 | ) 135 | l_dec_conv1 = BN(DeconvLayer( 136 | incoming = l_unflatten, 137 | num_filters = 512, 138 | filter_size = [5,5], 139 | stride = [2,2], 140 | crop = (2,2), 141 | W = initmethod(0.02), 142 | nonlinearity = relu, 143 | name = 'dec_conv1' 144 | ),name = 'bnorm_dc1') 145 | l_dec_conv2 = BN(DeconvLayer( 146 | incoming = l_dec_conv1, 147 | num_filters = 256, 148 | filter_size = [5,5], 149 | stride = [2,2], 150 | crop = (2,2), 151 | W = initmethod(0.02), 152 | nonlinearity = relu, 153 | name = 'dec_conv2' 154 | ),name = 'bnorm_dc2') 155 | l_dec_conv3 = BN(DeconvLayer( 156 | incoming = l_dec_conv2, 157 | num_filters = 128, 158 | filter_size = [5,5], 159 | stride = [2,2], 160 | crop = (2,2), 161 | W = initmethod(0.02), 162 | nonlinearity = relu, 163 | name = 'dec_conv3' 164 | ),name = 'bnorm_dc3') 165 | 166 | l_dec_conv4 = BN(DeconvLayer( 167 | incoming = l_dec_conv3, 168 | num_filters = 64, 169 | filter_size = [5,5], 170 | stride = [2,2], 171 | crop = (2,2), 172 | W = initmethod(0.02), 173 | nonlinearity = relu, 174 | name = 'dec_conv4' 175 | ),name = 'bnorm_dc4') 176 | 177 | R = NL(MDCL(l_dec_conv4, 178 | num_filters=2, 179 | scales = [2,3,4], 180 | name = 'R'),sigmoid) 181 | G = NL(ESL([MDCL(l_dec_conv4, 182 | num_filters=2, 183 | scales = [2,3,4], 184 | name = 'G_a' 185 | ), 186 | MDCL(R, 187 | num_filters=2, 188 | scales = [2,3,4], 189 | name = 'G_b' 190 | )]),sigmoid) 191 | B = NL(ESL([MDCL(l_dec_conv4, 192 | num_filters=2, 193 | scales = [2,3,4], 194 | name = 'B_a' 195 | ), 196 | MDCL(CL([R,G]), 197 | num_filters=2, 198 | 
scales = [2,3,4], 199 | name = 'B_b' 200 | )]),sigmoid) 201 | l_out=CL([beta_layer(SL(R,slice(0,1),1),SL(R,slice(1,2),1)),beta_layer(SL(G,slice(0,1),1),SL(G,slice(1,2),1)),beta_layer(SL(B,slice(0,1),1),SL(B,slice(1,2),1))]) 202 | 203 | minibatch_discrim = MinibatchLayer(lasagne.layers.GlobalPoolLayer(l_enc_conv4), num_kernels=500,name='minibatch_discrim') 204 | l_discrim = DL(incoming = minibatch_discrim, 205 | num_units = 1, 206 | nonlinearity = lasagne.nonlinearities.sigmoid, 207 | b = None, 208 | W=initmethod(0.02), 209 | name = 'discrimi') 210 | 211 | 212 | return {'l_in':l_in, 213 | 'l_out':l_out, 214 | 'l_mu':l_enc_mu, 215 | 'l_ls':l_enc_logsigma, 216 | 'l_Z':l_Z, 217 | 'l_IAF_mu': l_IAF_mu, 218 | 'l_IAF_ls': l_IAF_logsigma, 219 | 'l_Z_IAF': l_Z_IAF, 220 | 'l_introspect':[l_enc_conv1, l_enc_conv2,l_enc_conv3,l_enc_conv4], 221 | 222 | 'l_discrim' : l_discrim} 223 | 224 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016 Andy Brock 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /NPE.py: -------------------------------------------------------------------------------- 1 | ### Neural Photo Editor 2 | # A Brock, 2016 3 | 4 | ### Imports 5 | 6 | from Tkinter import * # Note that I dislike the * on the Tkinter import, but all the tutorials seem to do that so I stuck with it. 7 | from tkColorChooser import askcolor # This produces an OS-dependent color selector. I like the windows one best, and can't stand the linux one. 8 | from collections import OrderedDict 9 | from PIL import Image, ImageTk 10 | import numpy as np 11 | import scipy.misc 12 | 13 | 14 | from API import IAN 15 | ### Step 1: Create theano functions 16 | 17 | # Initialize model 18 | model = IAN(config_path = 'IAN_simple.py', dnn = True) 19 | 20 | ### Prepare GUI functions 21 | print('Compiling remaining functions') 22 | 23 | # Create master 24 | master = Tk() 25 | master.title( "Neural Photo Editor" ) 26 | 27 | # RGB interpreter convenience function 28 | def rgb(r,g,b): 29 | return '#%02x%02x%02x' % (r,g,b) 30 | 31 | # Convert RGB to bi-directional RB scale. 
32 | def rb(i): 33 | # return rgb(int(i*int(i>0)),0, -int(i*int(i<0))) 34 | return rgb(255+max(int(i*int(i<0)),-255),255-min(abs(int(i)),255), 255-min(int(i*int(i>0)),255)) 35 | 36 | # Convenience functions to go from [0,255] to [-1,1] and [-1,1] to [0,255] 37 | def to_tanh(input): 38 | return 2.0*(input/255.0)-1.0 39 | 40 | def from_tanh(input): 41 | return 255.0*(input+1)/2.0 42 | 43 | # Ground truth image 44 | GIM=np.asarray(np.load('CelebAValid.npz')['arr_0'][420]) 45 | 46 | # Image for modification 47 | IM = GIM 48 | 49 | # Reconstruction 50 | RECON = IM 51 | 52 | # Error between reconstruction and current image 53 | ERROR = np.zeros(np.shape(IM),dtype=np.float32) 54 | 55 | # Change between Recon and Current 56 | DELTA = np.zeros(np.shape(IM),dtype=np.float32) 57 | 58 | # User-Painted Mask, currently not implemented. 59 | USER_MASK=np.mean(DELTA,axis=0) 60 | 61 | # Are we operating on a photo or a sample? 62 | SAMPLE_FLAG=0 63 | 64 | 65 | ### Latent Canvas Variables 66 | # Latent Square dimensions 67 | dim = [10,10] 68 | 69 | # Squared Latent Array 70 | Z = np.zeros((dim[0],dim[1]),dtype=np.float32) 71 | 72 | # Pixel-wise resolution for latent canvas 73 | res = 16 74 | 75 | # Array that holds the actual latent canvas 76 | r = np.zeros((res*dim[0],res*dim[1]),dtype=np.float32) 77 | 78 | # Painted rectangles for free-form latent painting 79 | painted_rects = [] 80 | 81 | # Actual latent rectangles 82 | rects = np.zeros((dim[0],dim[1]),dtype=int) 83 | 84 | ### Output Display Variables 85 | 86 | # RGB paintbrush array 87 | myRGB = np.zeros((1,3,64,64),dtype=np.float32); 88 | 89 | # Canvas width and height 90 | canvas_width = 400 91 | canvas_height = 400 92 | 93 | # border width 94 | bd =2 95 | # Brush color 96 | color = IntVar() 97 | color.set(0) 98 | 99 | # Brush size 100 | d = IntVar() 101 | d.set(12) 102 | 103 | # Selected Color 104 | mycol = (0,0,0) 105 | 106 | # Function to update display 107 | def update_photo(data=None,widget=None): 108 | global Z 109 | if 
data is None: # By default, assume we're updating with the current value of Z 110 | data = np.repeat(np.repeat(np.uint8(from_tanh(model.sample_at(np.float32([Z.flatten()]))[0])),4,1),4,2) 111 | else: 112 | data = np.repeat(np.repeat(np.uint8(data),4,1),4,2) 113 | 114 | if widget is None: 115 | widget = output 116 | # Reshape image to canvas 117 | mshape = (4*64,4*64,1) 118 | im = Image.fromarray(np.concatenate([np.reshape(data[0],mshape),np.reshape(data[1],mshape),np.reshape(data[2],mshape)],axis=2),mode='RGB') 119 | 120 | # Make sure photo is an object of the current widget so the garbage collector doesn't wreck it 121 | widget.photo = ImageTk.PhotoImage(image=im) 122 | widget.create_image(0,0,image=widget.photo,anchor=NW) 123 | widget.tag_raise(pixel_rect) 124 | 125 | # Function to update the latent canvas. 126 | def update_canvas(widget=None): 127 | global r, Z, res, rects, painted_rects 128 | if widget is None: 129 | widget = w 130 | # Update display values 131 | r = np.repeat(np.repeat(Z,r.shape[0]//Z.shape[0],0),r.shape[1]//Z.shape[1],1) 132 | 133 | # If we're letting freeform painting happen, delete the painted rectangles 134 | for p in painted_rects: 135 | w.delete(p) 136 | painted_rects = [] 137 | 138 | for i in range(Z.shape[0]): 139 | for j in range(Z.shape[1]): 140 | w.itemconfig(int(rects[i,j]),fill = rb(255*Z[i,j]),outline = rb(255*Z[i,j])) 141 | 142 | # Function to move the paintbrush 143 | def move_mouse( event ): 144 | global output 145 | # using a rectangle width equivalent to d/4 (so 1-16) 146 | 147 | # First, get location and extent of local patch 148 | x,y = event.x//4,event.y//4 149 | brush_width = ((d.get()//4)+1) 150 | 151 | # if x is near the left corner, then the minimum x is dependent on how close it is to the left 152 | xmin = max(min(x-brush_width//2,64 - brush_width),0) # This 64 may need to change if the canvas size changes 153 | xmax = xmin+brush_width 154 | 155 | ymin = max(min(y-brush_width//2,64 - brush_width),0) # This 64 may 
need to change if the canvas size changes 156 | ymax = ymin+brush_width 157 | 158 | # update output canvas 159 | output.coords(pixel_rect,4*xmin,4*ymin,4*xmax,4*ymax) 160 | output.tag_raise(pixel_rect) 161 | output.itemconfig(pixel_rect,outline = rgb(mycol[0],mycol[1],mycol[2])) 162 | 163 | ### Optional functions for the Neural Painter 164 | 165 | # Localized Gaussian Smoothing Kernel 166 | # Use this if you want changes to MASK to be more localized to the brush location in soe sense 167 | def gk(c1,r1,c2,r2): 168 | # First, create X and Y arrays indicating distance to the boundaries of the paintbrush 169 | # In this current context, im is the ordinal number of pixels (64 typically) 170 | sigma = 0.3 171 | im = 64 172 | x = np.repeat([np.concatenate([np.mgrid[-c1:0],np.zeros(c2-c1),np.mgrid[1:1+im-c2]])],im,axis=0) 173 | y = np.repeat(np.vstack(np.concatenate([np.mgrid[-r1:0],np.zeros(r2-r1),np.mgrid[1:1+im-r2]])),im,axis=1) 174 | g = np.exp(-(x**2/float(im)+y**2/float(im))/(2*sigma**2)) 175 | return np.repeat([g],3,axis=0) # remove the 3 if you want to apply this to mask rather than an RGB channel 176 | # This function reduces the likelihood of a change based on how close each individual pixel is to a maximal value. 177 | # Consider conditioning this based on the gK value and the requested color. I.E. instead of just a flat distance from 128, 178 | # have it be a difference from the expected color at a given location. This could also be used to "weight" the image towards staying the same. 179 | def upperlim(image): 180 | h=1 181 | return (1.0/((1.0/h)*np.abs(image-128)+1)) 182 | 183 | # Similar to upperlim, this function changes the value of the correction term if it's going to move pixels too close to a maximal value 184 | def dampen(input,correct): 185 | # The closer input+correct is to -1 or 1, the further it is from 0. 186 | # We're okay with almost all values (i.e. 
between 0 and 0.8) but as we approach 1 we want to slow the change 187 | thresh = 0.75 188 | m = (input+correct)>thresh 189 | return -input*m+correct*(1-m)+thresh*m 190 | 191 | ### Neural Painter Function 192 | def paint( event ): 193 | global Z, output, myRGB, IM, ERROR, RECON, USER_MASK, SAMPLE_FLAG 194 | 195 | # Move the paintbrush 196 | move_mouse(event) 197 | 198 | # Define a gradient descent step-size 199 | weight = 0.05 200 | 201 | # Get paintbrush location 202 | [x1,y1,x2,y2] = [coordinate//4 for coordinate in output.coords(pixel_rect)] 203 | 204 | # Get dIM/dZ that minimizes the difference between IM and RGB in the domain of the paintbrush 205 | temp = np.asarray(model.imgradRGB(x1,y1,x2,y2,np.float32(to_tanh(myRGB)),np.float32([Z.flatten()]))[0]) 206 | grad = temp.reshape((10,10))*(1+(x2-x1)) 207 | 208 | # Update Z 209 | Z -=weight*grad 210 | 211 | # If operating on a sample, update sample 212 | if SAMPLE_FLAG: 213 | update_canvas(w) 214 | update_photo(None,output) 215 | # Else, update photo 216 | else: 217 | # Difference between current image and reconstruction 218 | DELTA = model.sample_at(np.float32([Z.flatten()]))[0]-to_tanh(np.float32(RECON)) 219 | 220 | # Not-Yet-Implemented User Mask feature 221 | # USER_MASK[y1:y2,x1:x2]+=0.05 222 | 223 | # Get MASK 224 | MASK=scipy.ndimage.filters.gaussian_filter(np.min([np.mean(np.abs(DELTA),axis=0),np.ones((64,64))],axis=0),0.7) 225 | 226 | # Optionally dampen D 227 | # D = dampen(to_tanh(np.float32(RECON)),MASK*DELTA+(1-MASK)*ERROR) 228 | 229 | # Update image 230 | D = MASK*DELTA+(1-MASK)*ERROR 231 | IM = np.uint8(from_tanh(to_tanh(RECON)+D)) 232 | 233 | # Pass updates 234 | update_canvas(w) 235 | update_photo(IM,output) 236 | 237 | # Load an image and infer/reconstruct from it. Update this with a function to load your own images if you want to edit 238 | # non-celebA photos. 
239 | def infer(): 240 | global Z,w,GIM,IM,ERROR,RECON,DELTA,USER_MASK,SAMPLE_FLAG 241 | val = myentry.get() 242 | try: 243 | val = int(val) 244 | GIM = np.asarray(np.load('CelebAValid.npz')['arr_0'][val]) 245 | IM = GIM 246 | except ValueError: 247 | print "No input" 248 | val = 420 249 | GIM = np.asarray(np.load('CelebAValid.npz')['arr_0'][val]) 250 | IM = GIM 251 | # myentry.delete(0, END) # Optionally, clear entry after typing it in 252 | 253 | # Reset Delta 254 | DELTA = np.zeros(np.shape(IM),dtype=np.float32) 255 | 256 | # Infer and reshape latents. This can be done without an intermediate variable if desired 257 | s = model.encode_images(np.asarray([to_tanh(IM)],dtype=np.float32)) 258 | Z = np.reshape(s[0],np.shape(Z)) 259 | 260 | # Get reconstruction 261 | RECON = np.uint8(from_tanh(model.sample_at(np.float32([Z.flatten()]))[0])) 262 | 263 | # Get error 264 | ERROR = to_tanh(np.float32(IM)) - to_tanh(np.float32(RECON)) 265 | 266 | # Reset user mask 267 | USER_MASK*=0 268 | 269 | # Clear the sample flag 270 | SAMPLE_FLAG=0 271 | 272 | # Update photo 273 | update_photo(IM,output) 274 | update_canvas(w) 275 | 276 | # Paint directly into the latent space 277 | def paint_latents( event ): 278 | global r, Z, output,painted_rects,MASK,USER_MASK,RECON 279 | 280 | # Get extent of latent paintbrush 281 | x1, y1 = ( event.x - d.get() ), ( event.y - d.get() ) 282 | x2, y2 = ( event.x + d.get() ), ( event.y + d.get() ) 283 | 284 | selected_widget = event.widget 285 | 286 | # Paint in latent space and update Z 287 | painted_rects.append(event.widget.create_rectangle( x1, y1, x2, y2, fill = rb(color.get()),outline = rb(color.get()) )) 288 | r[max((y1-bd),0):min((y2-bd),r.shape[0]),max((x1-bd),0):min((x2-bd),r.shape[1])] = color.get()/255.0; 289 | Z = np.asarray([np.mean(o) for v in [np.hsplit(h,Z.shape[0])\ 290 | for h in np.vsplit((r),Z.shape[1])]\ 291 | for o in v]).reshape(Z.shape[0],Z.shape[1]) 292 | if SAMPLE_FLAG: 293 | update_photo(None,output) 294 | 
update_canvas(w) # Remove this if you wish to see a more free-form paintbrush 295 | else: 296 | DELTA = model.sample_at(np.float32([Z.flatten()]))[0]-to_tanh(np.float32(RECON)) 297 | MASK=scipy.ndimage.filters.gaussian_filter(np.min([np.mean(np.abs(DELTA),axis=0),np.ones((64,64))],axis=0),0.7) 298 | # D = dampen(to_tanh(np.float32(RECON)),MASK*DELTA+(1-MASK)*ERROR) 299 | D = MASK*DELTA+(1-MASK)*ERROR 300 | IM = np.uint8(from_tanh(to_tanh(RECON)+D)) 301 | update_canvas(w) # Remove this if you wish to see a more free-form paintbrush 302 | update_photo(IM,output) 303 | 304 | # Scroll to lighten or darken an image patch 305 | def scroll( event ): 306 | global r,Z,output 307 | # Optional alternate method to get a single X Y point 308 | # x,y = np.floor( ( event.x - (output.winfo_rootx() - master.winfo_rootx()) ) / 4), np.floor( ( event.y - (output.winfo_rooty() - master.winfo_rooty()) ) / 4) 309 | weight = 0.1 310 | [x1,y1,x2,y2] = [coordinate//4 for coordinate in output.coords(pixel_rect)] 311 | grad = np.reshape(model.imgrad(x1,y1,x2,y2,np.float32([Z.flatten()]))[0],Z.shape)*(1+(x2-x1)) 312 | Z+=np.sign(event.delta)*weight*grad 313 | update_canvas(w) 314 | update_photo(None,output) 315 | 316 | # Samples in the latent space 317 | def sample(): 318 | global Z, output,RECON,IM,ERROR,SAMPLE_FLAG 319 | Z = np.random.randn(Z.shape[0],Z.shape[1]) 320 | # Z = np.random.uniform(low=-1.0,high=1.0,size=(Z.shape[0],Z.shape[1])) # Optionally get uniform sample 321 | 322 | # Update reconstruction and error 323 | RECON = np.uint8(from_tanh(model.sample_at(np.float32([Z.flatten()]))[0])) 324 | ERROR = to_tanh(np.float32(IM)) - to_tanh(np.float32(RECON)) 325 | update_canvas(w) 326 | SAMPLE_FLAG=1 327 | update_photo(None,output) 328 | 329 | # Reset to ground-truth image 330 | def Reset(): 331 | global GIM,IM,Z, DELTA,RECON,ERROR,USER_MASK,SAMPLE_FLAG 332 | IM = GIM 333 | Z = np.reshape(model.encode_images(np.asarray([to_tanh(IM)],dtype=np.float32))[0],np.shape(Z)) 334 | DELTA = 
np.zeros(np.shape(IM),dtype=np.float32) 335 | RECON = np.uint8(from_tanh(model.sample_at(np.float32([Z.flatten()]))[0])) 336 | ERROR = to_tanh(np.float32(IM)) - to_tanh(np.float32(RECON)) 337 | USER_MASK*=0 338 | SAMPLE_FLAG=0 339 | update_canvas(w) 340 | update_photo(IM,output) 341 | 342 | def UpdateGIM(): 343 | global GIM,IM 344 | GIM = IM 345 | Reset()# Recalc the latent space for the new ground-truth image. 346 | 347 | # Change brush size 348 | def update_brush(event): 349 | brush.create_rectangle(0,0,25,25,fill=rgb(255,255,255),outline=rgb(255,255,255)) 350 | brush.create_rectangle( int(12.5-d.get()/4.0), int(12.5-d.get()/4.0), int(12.5+d.get()/4.0), int(12.5+d.get()/4.0), fill = rb(color.get()),outline = rb(color.get()) ) 351 | 352 | # assign color picker values to myRGB 353 | def getColor(): 354 | global myRGB, mycol 355 | col = askcolor(mycol) 356 | if col[0] is None: 357 | return # Dont change color if Cancel pressed. 358 | mycol = col[0] 359 | for i in xrange(3): myRGB[0,i,:,:] = mycol[i]; # assign 360 | 361 | # Optional function to "lock" latents so that gradients are always evaluated with respect to the locked Z 362 | # def lock(): 363 | # global Z,locked, Zlock, lockbutton 364 | # lockbutton.config(relief='raised' if locked else 'sunken') 365 | # Zlock = Z if not locked else Zlock 366 | # locked = not locked 367 | # lockbutton = Button(f, text="Lock", command=lock,relief='raised') 368 | # lockbutton.pack(side=LEFT) 369 | 370 | ### Prepare GUI 371 | master.bind("",scroll) 372 | 373 | # Prepare drawing canvas 374 | f=Frame(master) 375 | f.pack(side=TOP) 376 | output = Canvas(f,name='output',width=64*4,height=64*4) 377 | output.bind('',move_mouse) 378 | output.bind('', paint ) 379 | pixel_rect = output.create_rectangle(0,0,4,4,outline = 'yellow') 380 | output.pack() 381 | 382 | # Prepare latent canvas 383 | f = Frame(master,width=res*dim[0],height=dim[1]*10) 384 | f.pack(side=TOP) 385 | w = Canvas(f,name='canvas', width=res*dim[0],height=res*dim[1]) 386 | 
w.bind( "", paint_latents ) 387 | # Produce painted rectangles 388 | for i in range(Z.shape[0]): 389 | for j in range(Z.shape[1]): 390 | rects[i,j] = w.create_rectangle( j*res, i*res, (j+1)*res, (i+1)*res, fill = rb(255*Z[i,j]),outline = rb(255*Z[i,j]) ) 391 | # w.create_rectangle( 0,0,res*dim[0],res*dim[1], fill = rgb(255,255,255),outline=rgb(255,255,255)) # Optionally Initialize canvas to white 392 | w.pack() 393 | 394 | 395 | # Color gradient 396 | gradient = Canvas(master, width=400, height=20) 397 | gradient.pack(side=TOP) 398 | # gradient.grid(row=i+1) 399 | for j in range(-200,200): 400 | gradient.create_rectangle(j*255/200+200,0,j*255/200+201,20,fill = rb(j*255/200),outline=rb(j*255/200)) 401 | # Color scale slider 402 | f= Frame(master) 403 | Scale(master, from_=-255, to=255,length=canvas_width, variable = color,orient=HORIZONTAL,showvalue=0,command=update_brush).pack(side=TOP) 404 | 405 | # Buttons and brushes 406 | Button(f, text="Sample", command=sample).pack(side=LEFT) 407 | Button(f, text="Reset", command=Reset).pack(side=LEFT) 408 | Button(f, text="Update", command=UpdateGIM).pack(side=LEFT) 409 | brush = Canvas(f,width=25,height=25) 410 | Scale(f, from_=0, to=64,length=100,width=25, variable = d,orient=HORIZONTAL,showvalue=0,command=update_brush).pack(side=LEFT) # Brush diameter scale 411 | brush.pack(side=LEFT) 412 | inferbutton = Button(f, text="Infer", command=infer) 413 | inferbutton.pack(side=LEFT) 414 | colorbutton=Button(f,text='Col',command=getColor) 415 | colorbutton.pack(side=LEFT) 416 | myentry = Entry() 417 | myentry.pack(side=LEFT) 418 | f.pack(side=TOP) 419 | 420 | 421 | print('Running') 422 | # Reset and infer to kick it off 423 | Reset() 424 | infer() 425 | mainloop() 426 | 427 | 428 | 429 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural Photo Editor 2 | A simple interface for editing natural 
photos with generative neural networks. 3 | 4 | ![GUI1](http://i.imgur.com/dmmFOiG.gif) ![GUI2](http://i.imgur.com/mStg8nG.gif) ![GUI3](http://i.imgur.com/CqjTDFN.gif) 5 | 6 | This repository contains code for the paper "[Neural Photo Editing with Introspective Adversarial Networks](http://arxiv.org/abs/1609.07093)," and the [Associated Video](https://www.youtube.com/watch?v=FDELBFSeqQs). 7 | 8 | ## Installation 9 | To run the Neural Photo Editor, you will need: 10 | - Python, likely version 2.7. You may be able to use earlier versions of Python 2, but I'm pretty sure there are some incompatibilities with Python 3 in here. 11 | - [Theano](http://deeplearning.net/software/theano/), development version. 12 | - [lasagne](http://lasagne.readthedocs.io/en/latest/user/installation.html), development version. 13 | - I highly recommend [cuDNN](https://developer.nvidia.com/cudnn) as speed is key, but it is not a dependency. 14 | - numpy, scipy, PIL, Tkinter and tkColorChooser, but it is likely that your python distribution already has those. 15 | 16 | ## Running the NPE 17 | By default, the NPE runs on IAN_simple. This is a slimmed-down version of the IAN without MDC or RGB-Beta blocks, which runs without lag on a laptop GPU with ~1GB of memory (GT730M). 18 | 19 | If you're on a Windows machine, you will want to create a .theanorc file and at least set the flag floatX=float32 (Theano flag names are case-sensitive). 20 | 21 | If you're on a linux machine, you can just insert THEANO_FLAGS=floatX=float32 before the command line call. 22 | 23 | If you don't have cuDNN, simply change dnn = True to dnn = False in the line of NPE.py that creates the model (`model = IAN(config_path = 'IAN_simple.py', dnn = True)`). Note that I presently only have the non-cuDNN option working for IAN_simple. 24 | 25 | Then, run the command: 26 | 27 | ```sh 28 | python NPE.py 29 | ``` 30 | If you wish to use a different model, simply edit the config_path argument on that same line of NPE.py. 31 | 32 | You can make use of any model with an inference mechanism (VAE or ALI-based GAN).
33 | 34 | ## Commands 35 | - You can paint the image by picking a color and painting on the image, or paint in the latent space canvas (the red and blue tiles below the image). 36 | - The long horizontal slider controls the magnitude of the latent brush, and the smaller horizontal slider controls the size of both the latent and the main image brush. 37 | - You can select different entries from the subset of the celebA validation set (included in this repository as an .npz) by typing in a number from 0-999 in the bottom left box and hitting "infer." 38 | - Use the reset button to return to the ground truth image. 39 | - Press "Update" to update the ground-truth image and corresponding reconstruction with the current image. Use "Infer" to return to an original ground truth image from the dataset. 40 | - Use the sample button to generate a random latent vector and corresponding image. 41 | - Use the scroll wheel to lighten or darken an image patch (equivalent to using a pure white or pure black paintbrush). Note that this automatically returns you to sample mode, and may require hitting "infer" rather than "reset" to get back to photo editing. 42 | 43 | 44 | ## Training an IAN on celebA 45 | You will need [Fuel](https://github.com/mila-udem/fuel) along with the 64x64 version of celebA. See [here](https://github.com/vdumoulin/discgen) for instructions on downloading and preparing it. 46 | 47 | If you wish to train a model, the IAN.py file contains the model configuration, and the train_IAN.py file contains the training code, which can be run like this: 48 | 49 | ```sh 50 | python train_IAN.py IAN.py 51 | ``` 52 | 53 | By default, this code will save (and overwrite!) the weights to a .npz file with the same name as the config.py file (i.e. "IAN.py -> IAN.npz"), and will output a jsonl log of the training with metrics recorded after every chunk. 
54 | 55 | Use the --resume=True flag when calling to resume training a model--it will automatically pick up from the most recent epoch. 56 | 57 | ## Sampling the IAN 58 | 59 | You can generate a sample and reconstruction+interpolation grid with: 60 | 61 | ```sh 62 | python sample_IAN.py IAN.py 63 | ``` 64 | 65 | Note that you will need [matplotlib](http://matplotlib.org/) to do so. 66 | ## Known Issues/Bugs 67 | My MADE layer currently only accepts hidden unit sizes that are equal to the size of the latent vector, which will present itself as a BAD_PARAM error. 68 | 69 | Since the MADE really only acts as an autoregressive randomizer I'm not too worried about this, but it does bear looking into. 70 | 71 | I messed around with the keywords for get_model; you'll need to deal with these if you wish to run any model other than IAN_simple through the editor. 72 | 73 | Everything is presently just dumped into a single, unorganized directory. I'll be adding folders and cleaning things up soon. 74 | 75 | ## Notes 76 | Remainder of the IAN experiments (including SVHN) coming soon. 77 | 78 | I've integrated the plat interface which makes the NPE itself independent of framework, so you should be able to run it with Blocks, TensorFlow, PyTorch, PyCaffe, what have you, by modifying the IAN class provided in API.py. 79 | 80 | 81 | ## Acknowledgments 82 | This code contains lasagne layers and other goodies adopted from a number of places: 83 | - MADE wrapped from the implementation by M.
Germain et al: https://github.com/mgermain/MADE 84 | - Gaussian Sample layer from Tencia Lee's Recipe: https://github.com/Lasagne/Recipes/blob/master/examples/variational_autoencoder/variational_autoencoder.py 85 | - Minibatch Discrimination layer from OpenAI's Improved GAN Techniques: https://github.com/openai/improved-gan 86 | - Deconv Layer adapted from Radford's DCGAN: https://github.com/Newmu/dcgan_code 87 | - Image-Grid Plotter adopted from AlexMLamb's Discriminative Regularization: https://github.com/vdumoulin/discgen 88 | - Metrics_logging and checkpoints adopted from Daniel Maturana's VoxNet: https://github.com/dimatura/voxnet 89 | - Plat interface adopted from Tom White's plat: https://github.com/dribnet/plat 90 | -------------------------------------------------------------------------------- /discgen_utils.py: -------------------------------------------------------------------------------- 1 | # Plot Image Grid function imported from Discriminative Regularization for Generative Models by Lamb et al: 2 | # https://github.com/vdumoulin/discgen 3 | import six 4 | import matplotlib 5 | matplotlib.use('Agg') 6 | from matplotlib import cm, pyplot 7 | from mpl_toolkits.axes_grid1 import ImageGrid 8 | 9 | 10 | 11 | def plot_image_grid(images, num_rows, num_cols, save_path=None): 12 | """Plots images in a grid. 13 | 14 | Parameters 15 | ---------- 16 | images : numpy.ndarray 17 | Images to display, with shape 18 | ``(num_rows * num_cols, num_channels, height, width)``. 19 | num_rows : int 20 | Number of rows for the image grid. 21 | num_cols : int 22 | Number of columns for the image grid. 23 | save_path : str, optional 24 | Where to save the image grid. Defaults to ``None``, 25 | which calls ``pyplot.show()``; because this module selects the non-interactive 'Agg' backend on import, no window will actually appear.
26 | Images beyond ``num_rows * num_cols`` are ignored (``zip`` stops at the shorter sequence). 27 | """ 28 | figure = pyplot.figure() 29 | grid = ImageGrid(figure, 111, (num_rows, num_cols), axes_pad=0.1) 30 | 31 | for image, axis in zip(images, grid): 32 | axis.imshow(image.transpose(1, 2, 0), interpolation='nearest') # (channels, height, width) -> (height, width, channels) for imshow 33 | axis.set_yticklabels(['' for _ in range(image.shape[1])]) 34 | axis.set_xticklabels(['' for _ in range(image.shape[2])]) 35 | axis.axis('off') 36 | 37 | if save_path is None: 38 | pyplot.show() 39 | else: 40 | pyplot.savefig(save_path, transparent=True, bbox_inches='tight',dpi=212) 41 | pyplot.close() -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | ### Custom Lasagne Layers for Introspective Adversarial Networks 2 | # A Brock, 2016 3 | # 4 | # Layers that are not my own creation should be appropriately attributed here 5 | # MADE wrapped from the implementation by M. Germain et al: https://github.com/mgermain/MADE 6 | # Gaussian Sample layer from Tencia Lee's Recipe: https://github.com/Lasagne/Recipes/blob/master/examples/variational_autoencoder/variational_autoencoder.py 7 | # Minibatch Discrimination layer from OpenAI's Improved GAN Techniques: https://github.com/openai/improved-gan 8 | # Deconv Layer adapted from Radford's DCGAN: https://github.com/Newmu/dcgan_code 9 | 10 | from __future__ import division 11 | import numpy as np 12 | import theano 13 | import theano.tensor as T 14 | import lasagne 15 | import lasagne.layers 16 | 17 | from lasagne.layers import SliceLayer as SL 18 | from lasagne.layers import batch_norm as BN 19 | from lasagne.layers import ElemwiseSumLayer as ESL 20 | from lasagne.layers import NonlinearityLayer as NL 21 | from lasagne.layers import DenseLayer as DL 22 | from lasagne.init import Normal as initmethod 23 | from lasagne.nonlinearities import elu 24 | from lasagne.nonlinearities import rectify as relu 25 | from lasagne.nonlinearities import LeakyRectify as lrelu 26 |
from lasagne.layers import TransposedConv2DLayer as TC2D 28 | from lasagne.layers import ConcatLayer as CL 29 | 30 | 31 | from math import sqrt 32 | from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams 33 | from mask_generator import MaskGenerator 34 | 35 | # BatchReNorm Layer using cuDNN's BatchNorm 36 | # This layer implements BatchReNorm (https://arxiv.org/abs/1702.03275), 37 | # which modifies BatchNorm to include running-average statistics in addition to 38 | # per-batch statistics in a well-principled manner. The RMAX and DMAX parameters 39 | # are clip parameters which should be fscalars that you'll need to manage in the training 40 | # loop, as they follow an annealing schedule (an example of which is given in the paper). 41 | # I've been adjusting this schedule based on the total number of iterations relative 42 | # to the number given in the paper, so for a ~50,000 iteration training run, I anneal 43 | # RMAX between 1k and 5k iterations rather than 5k and 25k. 44 | 45 | # NOTE: Ideally you should not have to manage RMAX and DMAX separately, so 46 | # if someone wants to write a default_update similar to the one used for 47 | # running_average and running_inv_std, that would be excellent. 
48 | class BatchReNormDNNLayer(lasagne.layers.BatchNormLayer): 49 | 50 | def __init__(self, incoming, RMAX,DMAX,axes='auto', epsilon=1e-4, alpha=0.1, 51 | beta=lasagne.init.Constant(0), gamma=lasagne.init.Constant(1), 52 | mean=lasagne.init.Constant(0), inv_std=lasagne.init.Constant(1), **kwargs): 53 | super(BatchReNormDNNLayer, self).__init__( 54 | incoming, axes, epsilon, alpha, beta, gamma, mean, inv_std, 55 | **kwargs) 56 | all_but_second_axis = (0,) + tuple(range(2, len(self.input_shape))) 57 | 58 | self.RMAX,self.DMAX = RMAX,DMAX 59 | 60 | if self.axes not in ((0,), all_but_second_axis): 61 | raise ValueError("BatchNormDNNLayer only supports normalization " 62 | "across the first axis, or across all but the " 63 | "second axis, got axes=%r" % (axes,)) 64 | 65 | def get_output_for(self, input, deterministic=False, 66 | batch_norm_use_averages=None, 67 | batch_norm_update_averages=None, **kwargs): 68 | 69 | # Decide whether to use the stored averages or mini-batch statistics 70 | if batch_norm_use_averages is None: 71 | batch_norm_use_averages = deterministic 72 | use_averages = batch_norm_use_averages 73 | 74 | # Decide whether to update the stored averages 75 | if batch_norm_update_averages is None: 76 | batch_norm_update_averages = not deterministic 77 | update_averages = batch_norm_update_averages 78 | 79 | # prepare dimshuffle pattern inserting broadcastable axes as needed 80 | param_axes = iter(range(input.ndim - len(self.axes))) 81 | pattern = ['x' if input_axis in self.axes 82 | else next(param_axes) 83 | for input_axis in range(input.ndim)] 84 | # and prepare the converse pattern removing those broadcastable axes 85 | unpattern = [d for d in range(input.ndim) if d not in self.axes] 86 | 87 | # call cuDNN if needed, obtaining normalized outputs and statistics 88 | if not use_averages or update_averages: 89 | # cuDNN requires beta/gamma tensors; create them if needed 90 | shape = tuple(s for (d, s) in enumerate(input.shape) 91 | if d not in self.axes) 92 
| gamma = self.gamma or theano.tensor.ones(shape) 93 | beta = self.beta or theano.tensor.zeros(shape) 94 | mode = 'per-activation' if self.axes == (0,) else 'spatial' 95 | 96 | (normalized, 97 | input_mean, 98 | input_inv_std) = theano.sandbox.cuda.dnn.dnn_batch_normalization_train( 99 | input, gamma.dimshuffle(pattern), beta.dimshuffle(pattern), 100 | mode, self.epsilon) 101 | 102 | # normalize with stored averages, if needed 103 | if use_averages: 104 | mean = self.mean.dimshuffle(pattern) 105 | inv_std = self.inv_std.dimshuffle(pattern) 106 | gamma = 1 if self.gamma is None else self.gamma.dimshuffle(pattern) 107 | beta = 0 if self.beta is None else self.beta.dimshuffle(pattern) 108 | normalized = (input - mean) * (gamma * inv_std) + beta 109 | 110 | # update stored averages, if needed 111 | if update_averages: 112 | # Trick: To update the stored statistics, we create memory-aliased 113 | # clones of the stored statistics: 114 | running_mean = theano.clone(self.mean, share_inputs=False) 115 | running_inv_std = theano.clone(self.inv_std, share_inputs=False) 116 | # set a default update for them: 117 | running_mean.default_update = ((1 - self.alpha) * running_mean + 118 | self.alpha * input_mean.dimshuffle(unpattern)) 119 | running_inv_std.default_update = ((1 - self.alpha) * 120 | running_inv_std + 121 | self.alpha * input_inv_std.dimshuffle(unpattern)) 122 | # and make sure they end up in the graph without participating in 123 | # the computation (this way their default_update will be collected 124 | # and applied, but the computation will be optimized away): 125 | # dummy = running_mean + running_inv_std).dimshuffle(pattern) 126 | r = T.clip(running_inv_std.dimshuffle(pattern)/input_inv_std,1/self.RMAX,self.RMAX) 127 | d = T.clip( (input_mean-running_mean.dimshuffle(pattern))*running_inv_std.dimshuffle(pattern),-self.DMAX,self.DMAX) 128 | normalized = normalized * r + d 129 | 130 | return normalized 131 | 132 | 133 | 134 | # More Efficient MDCL layer 135 | # 
When seeking to construct an MDC block, drop this into a Conv2D layer instead; it's faster. 136 | # You can also easily re-parameterize this to a full-rank MDC block by dropping in the 137 | # line that creates baseW into the for loop such that a new W is sampled each time. 138 | def mdclW(num_filters,num_channels,filter_size,winit,name,scales): 139 | # Coefficient Initializer 140 | sinit = lasagne.init.Constant(1.0/(1+len(scales))) 141 | # Total filter size 142 | size = filter_size + (filter_size-1)*(scales[-1]-1) 143 | # Multiscale Dilated Filter 144 | W = T.zeros((num_filters,num_channels,size,size)) 145 | # Undilated Base Filter 146 | baseW = theano.shared(lasagne.utils.floatX(winit.sample((num_filters,num_channels,filter_size,filter_size))),name=name+'.W') 147 | for scale in enumerate(scales[::-1]): # enumerate backwards so that we place the main filter on top 148 | W = T.set_subtensor(W[:,:,scales[-1]-scale:size-scales[-1]+scale:scale,scales[-1]-scale:size-scales[-1]+scale:scale], 149 | baseW*theano.shared(lasagne.utils.floatX(sinit.sample(num_filters)), name+'.coeff_'+str(scale)).dimshuffle(0,'x','x','x')) 150 | return W 151 | 152 | # Subpixel Upsample Layer from (https://arxiv.org/abs/1609.05158) 153 | # This layer uses a set of r^2 set_subtensor calls to reorganize the tensor in a subpixel-layer upscaling style 154 | # as done in the ESPCN Magic ony paper for super-resolution. 155 | # r is the upscale factor. 156 | # c is the number of output channels. 
157 | class SubpixelLayer(lasagne.layers.Layer): 158 | def __init__(self, incoming,r,c, **kwargs): 159 | super(SubpixelLayer, self).__init__(incoming, **kwargs) 160 | self.r=r # Upscale factor 161 | self.c=c # number of output channels 162 | # NOTE(review): the input is expected to carry exactly c*r*r channels; this is not validated here. 163 | def get_output_shape_for(self, input_shape): 164 | return (input_shape[0],self.c,self.r*input_shape[2],self.r*input_shape[3]) 165 | # Output channel k at pixel (r*i + x, r*j + y) is taken from input channel k*r*r + r*x + y at pixel (i, j). 166 | def get_output_for(self, input, deterministic=False, **kwargs): 167 | out = T.zeros((input.shape[0],self.output_shape[1],self.output_shape[2],self.output_shape[3])) 168 | for x in xrange(self.r): # loop over the r*r sub-pixel offsets (x, y) 169 | for y in xrange(self.r): 170 | out=T.set_subtensor(out[:,:,x::self.r,y::self.r],input[:,self.r*x+y::self.r*self.r,:,:]) 171 | return out 172 | # Subpixel Upsample Layer using reshapes as in https://github.com/Tetrachrome/subpixel. This implementation appears to be 10x slower than 173 | # the set_subtensor implementation, presumably because of the extra reshapes after the splits.
class SubpixelLayer2(lasagne.layers.Layer):
    """Subpixel upsampling implemented with reshapes, transposes and splits.

    The input channels are split into self.c equal groups. _phase_shift turns
    each group into one output channel of spatial size (r*rows, r*cols). The
    output shape is (batch, c, r*rows, r*cols).
    """
    def __init__(self, incoming,r,c, **kwargs):
        super(SubpixelLayer2, self).__init__(incoming, **kwargs)
        self.r=r  # upscale factor
        self.c=c  # number of output channels


    def get_output_shape_for(self, input_shape):
        return (input_shape[0],self.c,self.r*input_shape[2],self.r*input_shape[3])

    def get_output_for(self, input, deterministic=False, **kwargs):
        def _phase_shift(input,r):
            # input is one channel group. The reshape below needs it to hold
            # exactly r*r channels. a and b are the input rows/cols, recovered
            # from the static output_shape, which therefore presumably must be
            # fully known. c is unused.
            bsize,c,a,b = input.shape[0],1,self.output_shape[2]//r,self.output_shape[3]//r
            X = T.reshape(input, (bsize,r,r,a,b))
            X = T.transpose(X, (0, 3,4,1,2))  # bsize, a, b, r, r
            X = T.split(x=X,splits_size=[1]*a,n_splits=a,axis=1)  # a x [bsize, 1, b, r, r]
            X = [T.reshape(x,(bsize,b,r,r))for x in X]  # a x [bsize, b, r, r]
            X = T.concatenate(X,axis=2)  # bsize, b, a*r, r
            X = T.split(x=X,splits_size =[1]*b,n_splits=b,axis=1)  # b x [bsize, 1, a*r, r]
            X = [T.reshape(x,(bsize,a*r,r))for x in X]  # b x [bsize, a*r, r]
            X = T.concatenate(X,axis=2)  # bsize, a*r, b*r
            # Re-insert a singleton channel axis.
            return X.dimshuffle(0,'x',1,2)
        # One equal-sized channel group per output channel.
        Xc = T.split(x=input,splits_size =[input.shape[1]//self.c]*self.c,n_splits=self.c,axis=1)
        return T.concatenate([_phase_shift(xc,self.r) for xc in Xc],axis=1)

# Multiscale Dilated Convolution Block
# This function (not a layer in and of itself, though you could make it one) returns a set of conv2d and dilatedconv2d layers combined with ESL.
# Each layer uses the same basic filter W, operating at a different dilation factor (or taken as the mean of W for the 1x1 conv).
# The channel-wise output of each layer is weighted by a set of coefficients, which are initialized to 1 / the total number of dilation scales,
# meaning that we're starting by taking an elementwise mean. These should be learnable parameters.

# NOTES: - I'm considering changing the variable names to be more descriptive, and look less like ridiculous academic code. It's on the to-do list.
#        - I keep the bias and nonlinearity out of the default definition for this layer, as I expect it to be batchnormed and nonlinearized in the model config.
def MDCL(incoming,num_filters,scales,name,dnn=True):
    """Build a multiscale dilated convolution (MDC) block.

    Args:
        incoming: input lasagne layer.
        num_filters: number of output filters.
        scales: dilation factors for the extra branches. A scale of 0 is a
            keyword for a 1x1 conv whose kernel is the spatial mean of W.
        name: prefix for the created layer and parameter names.
        dnn: if true use lasagne's cuDNN Conv2DDNNLayer, otherwise the
            module-level C2D.

    Returns:
        An ESL layer over all branches (no bias, no nonlinearity).
    """
    conv = _mdc_conv_class(dnn)
    W,sinit = _mdc_base_filter(incoming,num_filters,scales,name)

    # Primary Convolution Layer--No Dilation.
    # Note the broadcasting dimshuffle for the num_filter scalars.
    n = conv(incoming = incoming,
             num_filters = num_filters,
             filter_size = [3,3],
             stride = [1,1],
             pad = (1,1),
             W = W*_mdc_coeff(sinit,num_filters,name+'_coeff_base').dimshuffle(0,'x','x','x'),
             b = None,
             nonlinearity = None,
             name = name+'base')

    # Remaining layers, one per scale.
    nd = []
    for scale in scales:
        if scale==0:
            # 0 dilation isn't really defined; it is used here as a keyword
            # for the 1x1 conv over the spatial mean of W.
            nd.append(conv(incoming = incoming,
                           num_filters = num_filters,
                           filter_size = [1,1],
                           stride = [1,1],
                           pad = (0,0),
                           W = T.mean(W,axis=[2,3]).dimshuffle(0,1,'x','x')*_mdc_coeff(sinit,num_filters,name+'_coeff_1x1').dimshuffle(0,'x','x','x'),
                           b = None,
                           nonlinearity = None,
                           name = name+str(scale)))
        else:
            # Note the dimshuffles in this layer--these are critical as the current
            # DilatedConv2D implementation uses a backward pass, so its W is laid
            # out as (input channels, filters, rows, cols).
            nd.append(lasagne.layers.DilatedConv2DLayer(
                incoming = lasagne.layers.PadLayer(incoming = incoming, width=(scale,scale)),
                num_filters = num_filters,
                filter_size = [3,3],
                dilation=(scale,scale),
                W = W.dimshuffle(1,0,2,3)*_mdc_coeff(sinit,num_filters,name+'_coeff_'+str(scale)).dimshuffle('x',0,'x','x'),
                b = None,
                nonlinearity = None,
                name = name+str(scale)))
    return ESL(nd+[n])


def _mdc_conv_class(dnn):
    """Return the conv layer class used by MDCL, USL and DSL.

    Returns lasagne's cuDNN Conv2DDNNLayer when dnn is true, otherwise the
    module-level C2D. The class is resolved here, not by an
    ``import ... as C2D`` inside each block function. Such an import would make
    C2D local to the whole function and raise UnboundLocalError when dnn=False.
    """
    if dnn:
        from lasagne.layers.dnn import Conv2DDNNLayer
        return Conv2DDNNLayer
    return C2D


def _mdc_base_filter(incoming,num_filters,scales,name):
    """Create the shared 3x3 base filter W and the coefficient initializer.

    Returns:
        (W, sinit). W is a shared variable named name+'W' with shape
        (num_filters, incoming channels, 3, 3). sinit initializes every
        coefficient to 1/(1+len(scales)).
    """
    # W initialization method--this should also work as Orthogonal('relu'), but I have yet to validate that as thoroughly.
    winit = initmethod(0.02)
    # Initialization method for the coefficients
    sinit = lasagne.init.Constant(1.0/(1+len(scales)))
    # Number of incoming channels
    ni = lasagne.layers.get_output_shape(incoming)[1]
    # Weight parameter--the primary parameter for the block
    W = theano.shared(lasagne.utils.floatX(winit.sample((num_filters,ni,3,3))),name=name+'W')
    return W,sinit


def _mdc_coeff(sinit,num_filters,name):
    """Return a new shared per-filter coefficient vector named ``name``."""
    return theano.shared(lasagne.utils.floatX(sinit.sample(num_filters)), name)

# MDC-based Upsample Layer.
# This is a prototype I don't make use of extensively. It's operational but it doesn't seem to improve results yet.
def USL(incoming,num_filters,scales,name,dnn=True):
    """MDC-style 2x upsampling block.

    Every MDC branch runs on Upscale2DLayer(incoming, 2). One stride-2
    DeconvLayer on the original input is also added. Arguments are as for MDCL.
    """
    conv = _mdc_conv_class(dnn)
    W,sinit = _mdc_base_filter(incoming,num_filters,scales,name)

    # Primary Convolution Layer--No Dilation
    n = conv(incoming = Upscale2DLayer(incoming,2),
             num_filters = num_filters,
             filter_size = [3,3],
             stride = [1,1],
             pad = (1,1),
             W = W*_mdc_coeff(sinit,num_filters,name+'_coeff_base').dimshuffle(0,'x','x','x'),
             b = None,
             nonlinearity = None,
             name = name+'base')
    # Remaining layers
    nd = []
    for scale in scales:
        if scale==0:
            nd.append(conv(incoming = Upscale2DLayer(incoming,2),
                           num_filters = num_filters,
                           filter_size = [1,1],
                           stride = [1,1],
                           pad = (0,0),
                           W = T.mean(W,axis=[2,3]).dimshuffle(0,1,'x','x')*_mdc_coeff(sinit,num_filters,name+'_coeff_1x1').dimshuffle(0,'x','x','x'),
                           b = None,
                           nonlinearity = None,
                           name = name+'1x1'))
        else:
            nd.append(lasagne.layers.DilatedConv2DLayer(
                incoming = lasagne.layers.PadLayer(incoming = Upscale2DLayer(incoming,2), width=(scale,scale)),
                num_filters = num_filters,
                filter_size = [3,3],
                dilation=(scale,scale),
                W = W.dimshuffle(1,0,2,3)*_mdc_coeff(sinit,num_filters,name+'_coeff_'+str(scale)).dimshuffle('x',0,'x','x'),
                b = None,
                nonlinearity = None,
                name = name+str(scale)))

    # A single deconv layer is also included here. Like I said, it's a prototype!
    nd.append(DeconvLayer(incoming = incoming,
                          num_filters = num_filters,
                          filter_size = [3,3],
                          stride = [2,2],
                          crop = (1,1),
                          W = W.dimshuffle(1,0,2,3)*_mdc_coeff(sinit,num_filters,name+'_coeff_deconv').dimshuffle('x',0,'x','x'),
                          b = None,
                          nonlinearity = None,
                          name = name+'deconv'))

    return ESL(nd+[n])

# MDC-based Downsample Layer.
# This is a prototype I don't make use of extensively. It's operational and it seems like it works alright, but it's restrictively expensive
# and I am not PARALLELICUS, god of GPUs, so I don't have the memory to spare for it.
# Note that this layer does not currently support having a 0 scale like the others do, and just has a 1x1-stride2 conv by default.
def DSL(incoming,num_filters,scales,name,dnn=True):
    """MDC-style stride-2 downsampling block.

    The branches are:
      - a stride-2 3x3 conv;
      - per scale, an 'average_exc_pad' pool (pool_size=scale, stride 2)
        followed by a 3x3 conv;
      - a stride-2 1x1 conv over the spatial mean of W.

    Arguments are as for MDCL.
    """
    conv = _mdc_conv_class(dnn)
    W,sinit = _mdc_base_filter(incoming,num_filters,scales,name)

    # Main layer--3x3 conv with stride 2
    n = conv(incoming = incoming,
             num_filters = num_filters,
             filter_size = [3,3],
             stride = [2,2],
             pad = (1,1),
             W = W*_mdc_coeff(sinit,num_filters,name+'_coeff_base').dimshuffle(0,'x','x','x'),
             b = None,
             nonlinearity = None,
             name = name+'base')

    nd = []
    for i,scale in enumerate(scales):
        # Only the first pooling branch is unpadded.
        p = P2D(incoming = incoming,
                pool_size = scale,
                stride = 2,
                pad = (1,1) if i else (0,0),
                mode = 'average_exc_pad')

        nd.append(conv(incoming = p,
                       num_filters = num_filters,
                       filter_size = [3,3],
                       stride = (1,1),
                       pad = (1,1),
                       W = W*_mdc_coeff(sinit,num_filters,name+'_coeff_'+str(scale)).dimshuffle(0,'x','x','x'),
                       b = None,
                       nonlinearity = None,
                       name = name+str(scale)))

    # The 1x1 stride-2 branch is always present.
    nd.append(conv(incoming = incoming,
                   num_filters = num_filters,
                   filter_size = [1,1],
                   stride = [2,2],
                   pad = (0,0),
                   W = T.mean(W,axis=[2,3]).dimshuffle(0,1,'x','x')*_mdc_coeff(sinit,num_filters,name+'_coeff_1x1').dimshuffle(0,'x','x','x'),
                   b = None,
                   nonlinearity = None,
                   name = name+'1x1'))

    return ESL(nd+[n])

# Beta Distribution Layer
# This layer takes two input layers, alpha and beta, and returns alpha divided by their sum,
# which is equivalent to finding the expected value for a beta distribution.
# Note that this version of the layer rescales that value from [0, 1] to [-1, 1] (for non-negative inputs) for compatibility with tanh.
class beta_layer(lasagne.layers.MergeLayer):
    def __init__(self, alpha,beta, **kwargs):
        super(beta_layer, self).__init__([alpha,beta], **kwargs)

    def get_output_shape_for(self, input_shape):
        # input_shape lists both incoming shapes; the output has alpha's shape.
        return input_shape[0]

    def get_output_for(self, inputs, deterministic=False, **kwargs):
        alpha,beta = inputs
        # The 1e-8 guards against division by zero when alpha + beta == 0.
        return 2*(alpha/(alpha+beta+1e-8))-1

# Convenience Function to produce a residual pre-activation MDCL block
def MDBLOCK(incoming,num_filters,scales,name,nonlinearity):
    """Return NL(BN(incoming + MDCL(NL(BN(MDCL(NL(BN(incoming)))))))).

    Both MDCLs use the default dnn=True. The second one is named name+'2'.
    """
    pre_act = NL(BN(incoming,name=name+'bnorm0'),nonlinearity)
    hidden = NL(BN(MDCL(pre_act,num_filters,scales,name),name=name+'bnorm1'),nonlinearity)
    residual = MDCL(hidden,num_filters,scales,name+'2')
    return NL(BN(ESL([incoming,residual]),name=name+'bnorm2'),nonlinearity)

# Gaussian Sample Layer for VAE from Tencia Lee
class GaussianSampleLayer(lasagne.layers.MergeLayer):
    """Return mu + exp(logsigma) * noise, with noise from self.rng.normal.

    The noise is 2D (batch, units). Each size comes from the static input shape
    when known, otherwise from the symbolic shape. With deterministic=True the
    layer returns mu.
    """
    def __init__(self, mu, logsigma, rng=None, **kwargs):
        self.rng = rng if rng else RandomStreams(lasagne.random.get_rng().randint(1,2147462579))
        super(GaussianSampleLayer, self).__init__([mu, logsigma], **kwargs)

    def get_output_shape_for(self, input_shapes):
        return input_shapes[0]

    def get_output_for(self, inputs, deterministic=False, **kwargs):
        mu, logsigma = inputs
        if deterministic:
            return mu
        shape=(self.input_shapes[0][0] or inputs[0].shape[0],
               self.input_shapes[0][1] or inputs[0].shape[1])
        return mu + T.exp(logsigma) * self.rng.normal(shape)

# DeconvLayer adapted from Radford's DCGAN Implementation
class DeconvLayer(lasagne.layers.conv.BaseConvLayer):
    """Transposed convolution via cuDNN's GpuDnnConvGradI.

    W has shape (input channels, num_filters, rows, cols). Each spatial output
    size is the input size times the corresponding stride. ``crop`` is passed
    to cuDNN as the border mode. convolve() imports from theano.sandbox.cuda.
    """
    def __init__(self, incoming, num_filters, filter_size, stride=(1, 1),
                 crop=0, untie_biases=False,
                 W=initmethod(), b=lasagne.init.Constant(0.),
                 nonlinearity=lasagne.nonlinearities.rectify, flip_filters=False,
                 **kwargs):
        super(DeconvLayer, self).__init__(
            incoming, num_filters, filter_size, stride, crop, untie_biases,
            W, b, nonlinearity, flip_filters, n=2, **kwargs)
        # rename self.crop to self.pad
        self.crop = self.pad
        del self.pad

    def get_W_shape(self):
        num_input_channels = self.input_shape[1]
        # first two sizes are swapped compared to a forward convolution
        return (num_input_channels, self.num_filters) + self.filter_size

    def get_output_shape_for(self, input_shape):
        # Scale each spatial size by its stride (not a fixed factor of 2) so the
        # declared shape matches the buffer that convolve() allocates.
        # Unknown (None) sizes stay unknown.
        def scaled(dim, stride):
            return None if dim is None else dim*stride
        return ((input_shape[0], self.num_filters) +
                tuple(scaled(dim, stride) for dim, stride in zip(input_shape[2:], self.stride)))

    def convolve(self, input, **kwargs):
        # Messy to have these imports here, but seems to allow for switching DNN off.
        from theano.sandbox.cuda.basic_ops import gpu_contiguous, gpu_alloc_empty
        from theano.sandbox.cuda.dnn import GpuDnnConvDesc, GpuDnnConvGradI
        # Straight outta Radford: GpuDnnConvGradI writes into an empty buffer
        # of size (batch, num_filters, rows*stride, cols*stride).
        img = gpu_contiguous(input)
        kerns = gpu_contiguous(self.W)
        out = gpu_alloc_empty(img.shape[0], kerns.shape[1],
                              img.shape[2]*self.stride[0], img.shape[3]*self.stride[1])
        desc = GpuDnnConvDesc(border_mode=self.crop, subsample=self.stride,
                              conv_mode='conv')(out.shape, kerns.shape)
        conved = GpuDnnConvGradI()(kerns, img, out, desc)

        return conved

# Minibatch discrimination layer from OpenAI's improved GAN techniques
class MinibatchLayer(lasagne.layers.Layer):
    """Append minibatch-discrimination features to the flattened input.

    Each sample is projected through W to (num_kernels, dim_per_kernel). For
    every kernel, L1 distances to the other samples in the batch are computed;
    self-distances are pushed to 1e6 so they vanish after the exp. The layer
    appends sum(exp(-distance)) plus b. With init=True it instead records
    data-dependent initialization updates in self.init_updates.
    """
    def __init__(self, incoming, num_kernels, dim_per_kernel=5, theta=lasagne.init.Normal(0.05),
                 log_weight_scale=lasagne.init.Constant(0.), b=lasagne.init.Constant(-1.), **kwargs):
        super(MinibatchLayer, self).__init__(incoming, **kwargs)
        self.num_kernels = num_kernels
        num_inputs = int(np.prod(self.input_shape[1:]))
        self.theta = self.add_param(theta, (num_inputs, num_kernels, dim_per_kernel), name="theta")
        self.log_weight_scale = self.add_param(log_weight_scale, (num_kernels, dim_per_kernel), name="log_weight_scale")
        # Weight-normalized projection: theta scaled to norm exp(log_weight_scale) over the input axis.
        self.W = self.theta * (T.exp(self.log_weight_scale)/T.sqrt(T.sum(T.square(self.theta),axis=0))).dimshuffle('x',0,1)
        self.b = self.add_param(b, (num_kernels,), name="b")

    def get_output_shape_for(self, input_shape):
        return (input_shape[0], int(np.prod(input_shape[1:]))+self.num_kernels)

    def get_output_for(self, input, init=False, **kwargs):
        if input.ndim > 2:
            # if the input has more than two dimensions, flatten it into a
            # batch of feature vectors.
            input = input.flatten(2)

        activation = T.tensordot(input, self.W, [[1], [0]])
        # (batch, num_kernels, batch) L1 distances; the 1e6 * eye term masks self-pairs.
        abs_dif = (T.sum(abs(activation.dimshuffle(0,1,2,'x') - activation.dimshuffle('x',1,2,0)),axis=2)
                   + 1e6 * T.eye(input.shape[0]).dimshuffle(0,'x',1))

        if init:
            mean_min_abs_dif = 0.5 * T.mean(T.min(abs_dif, axis=2),axis=0)
            abs_dif /= mean_min_abs_dif.dimshuffle('x',0,'x')
            self.init_updates = [(self.log_weight_scale, self.log_weight_scale-T.log(mean_min_abs_dif).dimshuffle(0,'x'))]

        f = T.sum(T.exp(-abs_dif),axis=2)

        if init:
            mf = T.mean(f,axis=0)
            f -= mf.dimshuffle('x',0)
            self.init_updates.append((self.b, -mf))
        else:
            f += self.b.dimshuffle('x',0)

        return T.concatenate([input, f], axis=1)

# Convenience function to define an inception-style block
def InceptionLayer(incoming,param_dict,block_name):
    """Build an inception-style block of parallel branches.

    Args:
        incoming: input layer shared by every branch.
        param_dict: list with one dict per branch (e.g. built by pd()). Each
            dict maps a setting to a list with one entry per layer of the
            branch. Keys read: 'style', 'num_filters', 'filter_size', 'stride',
            'nonlinearity', 'bnorm', and where relevant 'pad', 'mode', 'dilation'.
            Styles: 'convolutional', 'pool', 'dilation'; anything else gives a
            dense layer.
        block_name: prefix for layer names.

    Returns:
        CL layer over the last layer of every branch.
    """
    branch = [0]*len(param_dict)
    # Loop across branches
    for i,params in enumerate(param_dict):
        for j,style in enumerate(params['style']): # Loop up branch
            layer_in = branch[i] if j else incoming
            layer_name = block_name+'_'+str(i)+'_'+str(j)
            if style=='convolutional':
                branch[i] = C2D(
                    incoming = layer_in,
                    num_filters = params['num_filters'][j],
                    filter_size = params['filter_size'][j],
                    pad = params['pad'][j] if 'pad' in params else None,
                    stride = params['stride'][j],
                    W = initmethod('relu'),
                    nonlinearity = params['nonlinearity'][j],
                    name = layer_name)
            elif style=='pool':
                branch[i] = NL(lasagne.layers.dnn.Pool2DDNNLayer(
                    incoming = layer_in,
                    pool_size = params['filter_size'][j],
                    mode = params['mode'][j],
                    stride = params['stride'][j],
                    # Missing 'pad' falls back to no padding instead of raising KeyError.
                    pad = params['pad'][j] if 'pad' in params else (0,0),
                    name = layer_name),
                    nonlinearity = params['nonlinearity'][j])
            elif style=='dilation':
                branch[i] = lasagne.layers.DilatedConv2DLayer(
                    incoming = lasagne.layers.PadLayer(incoming = layer_in,width = params['pad'][j]) if 'pad' in params else layer_in,
                    num_filters = params['num_filters'][j],
                    filter_size = params['filter_size'][j],
                    dilation = params['dilation'][j],
                    W = initmethod('relu'),
                    nonlinearity = params['nonlinearity'][j],
                    name = layer_name)
            else:
                branch[i] = DL(
                    incoming = layer_in,
                    num_units = params['num_filters'][j],
                    W = initmethod('relu'),
                    b = None,
                    nonlinearity = params['nonlinearity'][j],
                    name = layer_name)
            # Apply Batchnorm
            if params['bnorm'][j]:
                branch[i] = BN(branch[i],name = block_name+'_bnorm_'+str(i)+'_'+str(j))

    # Concatenate Sublayers
    return CL(incomings=branch,name=block_name)

# Convenience function to define an inception-style block with upscaling
def InceptionUpscaleLayer(incoming,param_dict,block_name):
    """Upscaling counterpart of InceptionLayer.

    A 'convolutional' layer is a TC2D using 'pad' as its crop. Any other style
    upscales by 'stride', then applies a stride-1 Pool2DDNNLayer and the
    nonlinearity. Arguments are as for InceptionLayer.
    """
    branch = [0]*len(param_dict)
    # Loop across branches
    for i,params in enumerate(param_dict):
        for j,style in enumerate(params['style']): # Loop up branch
            layer_in = branch[i] if j else incoming
            layer_name = block_name+'_'+str(i)+'_'+str(j)
            if style=='convolutional':
                branch[i] = TC2D(
                    incoming = layer_in,
                    num_filters = params['num_filters'][j],
                    filter_size = params['filter_size'][j],
                    crop = params['pad'][j] if 'pad' in params else None,
                    stride = params['stride'][j],
                    W = initmethod('relu'),
                    nonlinearity = params['nonlinearity'][j],
                    name = layer_name)
            else:
                branch[i] = NL(
                    incoming = lasagne.layers.dnn.Pool2DDNNLayer(
                        incoming = lasagne.layers.Upscale2DLayer(
                            incoming=layer_in,
                            scale_factor = params['stride'][j]),
                        pool_size = params['filter_size'][j],
                        stride = [1,1],
                        mode = params['mode'][j],
                        # Missing 'pad' falls back to no padding instead of raising KeyError.
                        pad = params['pad'][j] if 'pad' in params else (0,0),
                        name = layer_name),
                    nonlinearity = params['nonlinearity'][j])
            # Apply Batchnorm
            if params['bnorm'][j]:
                branch[i] = BN(branch[i],name = block_name+'_bnorm_'+str(i)+'_'+str(j))

    # Concatenate Sublayers
    return CL(incomings=branch,name=block_name)

# Convenience function to efficiently generate param dictionaries for use with InceptionLayer
def pd(num_layers=2,num_filters=32,filter_size=(3,3),pad=1,stride = (1,1),nonlinearity=elu,style='convolutional',bnorm=1,**kwargs):
    """Build the per-layer parameter dict for one InceptionLayer branch.

    Every argument except num_layers becomes a key. A list value is kept as
    given (one entry per layer); any other value is repeated num_layers times.
    Extra keyword arguments (e.g. mode='max', dilation=(2,2)) become top-level
    keys too, which is where InceptionLayer looks them up.
    """
    input_args = dict(num_filters=num_filters, filter_size=filter_size, pad=pad, stride=stride,
                      nonlinearity=nonlinearity, style=style, bnorm=bnorm)
    input_args.update(kwargs)
    return {key:entry if type(entry) is list else [entry]*num_layers for key,entry in input_args.items()}

# Possible Conv2DDNN convenience function. Remember to delete the C2D import at the top if you use this
# def C2D(incoming = None, num_filters = 32, filter_size= [3,3],pad = 'same',stride = [1,1], W = initmethod('relu'),nonlinearity = elu,name = None):
#     return lasagne.layers.dnn.Conv2DDNNLayer(incoming,num_filters,filter_size,stride,pad,False,W,None,nonlinearity,False)

# Shape-Preserving Gaussian Sample layer for latent vectors with spatial dimensions.
# This is a holdover from an "old" (i.e. I abandoned it last month) idea.
class GSL(lasagne.layers.MergeLayer):
    """Return mu + exp(logsigma) * noise, with noise shaped like logsigma.

    With deterministic=True the layer returns mu.
    """
    def __init__(self, mu, logsigma, rng=None, **kwargs):
        self.rng = rng if rng else RandomStreams(lasagne.random.get_rng().randint(1,2147462579))
        super(GSL, self).__init__([mu, logsigma], **kwargs)

    def get_output_shape_for(self, input_shape):
        # Output has mu's shape.
        return input_shape[0]

    def get_output_for(self, inputs, deterministic=False, **kwargs):
        mu, logsigma = inputs
        if deterministic:
            return mu
        return mu + T.exp(logsigma) * self.rng.normal(logsigma.shape)

# Convenience function to return list of sampled latent layers
def GL(mu,ls):
    """Return one GSL per (mu, logsigma) pair from the two parallel lists."""
    return([GSL(z_mu,z_ls) for z_mu,z_ls in zip(mu,ls)])

# Convenience function to return a residual layer. It's not really that much more convenient than ESL'ing,
# but I like being able to see when I'm using Residual connections as opposed to Elemwise-sums
def ResLayer(incoming, IB,nonlinearity):
    """Return NL(ESL([IB, incoming]), nonlinearity)."""
    return NL(ESL([IB,incoming]),nonlinearity)


# Inverse autoregressive flow layer
class IAFLayer(lasagne.layers.MergeLayer):
    """Return (z - mu) / exp(logsigma); the output has z's shape."""
    def __init__(self, z, mu, logsigma, **kwargs):
        super(IAFLayer, self).__init__([z,mu, logsigma], **kwargs)

    def get_output_shape_for(self, input_shapes):
        return input_shapes[0]

    def get_output_for(self, inputs, deterministic=False, **kwargs):
        z,mu, logsigma = inputs
        return (z - mu) / T.exp(logsigma)

# Masked layer for MADE, adopted from M.Germain
class MaskedLayer(lasagne.layers.DenseLayer):
    """DenseLayer whose W is multiplied elementwise by a non-trainable mask.

    ``weights_mask`` starts as all ones. ``shuffle_update`` holds the
    (mask, new value) pair built from
    ``mask_generator.get_mask_layer_UPDATE(layerIdx)``, for use in a
    theano.function.
    """

    def __init__(self, incoming, num_units, mask_generator,layerIdx,W=lasagne.init.GlorotUniform(),
                 b=lasagne.init.Constant(0.), nonlinearity=lasagne.nonlinearities.rectify, **kwargs):
        super(MaskedLayer, self).__init__(incoming, num_units, W,b, nonlinearity,**kwargs)
        self.mask_generator = mask_generator
        num_inputs = int(np.prod(self.input_shape[1:]))
        self.weights_mask = self.add_param(spec = np.ones((num_inputs, num_units),dtype=np.float32),
                                           shape = (num_inputs, num_units),
                                           name='weights_mask',
                                           trainable=False,
                                           regularizable=False)
        self.layerIdx = layerIdx
        self.shuffle_update = [(self.weights_mask, mask_generator.get_mask_layer_UPDATE(self.layerIdx))]

    def get_output_for(self,input, **kwargs):
        if input.ndim > 2:
            input = input.flatten(2)
        activation = T.dot(input, self.W*self.weights_mask)
        if self.b is not None:
            activation = activation + self.b.dimshuffle('x', 0)
        return self.nonlinearity(activation)


# Stripped-Down Direct Input masked layer: Combine this with ESL and a masked layer to get a true DIML.
# Consider making this a simultaneous subclass of MaskedLayer and elemwise sum layer for cleanliness
# adopted from M.Germain
class DIML(lasagne.layers.DenseLayer):
    """Direct-input masked dense layer (stripped down).

    A DenseLayer whose weight matrix is multiplied elementwise by a
    non-trainable ``weights_mask`` parameter (all ones initially). The
    ``shuffle_update`` pair refreshes that mask from
    ``mask_generator.get_direct_input_mask_layer_UPDATE(layerIdx + 1)``.
    No nonlinearity is applied by default.
    """

    def __init__(self, incoming, num_units, mask_generator,layerIdx,W=lasagne.init.GlorotUniform(),
                 b=lasagne.init.Constant(0.), nonlinearity=None,**kwargs):
        super(DIML, self).__init__(incoming, num_units, W,b, nonlinearity,**kwargs)

        self.mask_generator = mask_generator
        self.layerIdx = layerIdx

        fan_in = int(np.prod(self.input_shape[1:]))
        mask_shape = (fan_in, num_units)
        self.weights_mask = self.add_param(spec=np.ones(mask_shape, dtype=np.float32),
                                           shape=mask_shape,
                                           name='weights_mask',
                                           trainable=False,
                                           regularizable=False)

        refreshed_mask = self.mask_generator.get_direct_input_mask_layer_UPDATE(self.layerIdx + 1)
        self.shuffle_update = [(self.weights_mask, refreshed_mask)]

    def get_output_for(self,input, **kwargs):
        flat = input.flatten(2) if input.ndim > 2 else input
        pre_activation = T.dot(flat, self.W*self.weights_mask)
        if self.b is not None:
            pre_activation = pre_activation + self.b.dimshuffle('x', 0)
        return self.nonlinearity(pre_activation)

# Conditioning Masked Layer
# Currently not used.
# NOTE: commented-out, unfinished draft of a conditioning masked layer; it is
# not valid code as written and is kept for reference only.
# class CML(MaskedLayer):

#     def __init__(self, incoming, num_units, mask_generator,use_cond_mask=False,U=lasagne.init.GlorotUniform(),W=lasagne.init.GlorotUniform(),
#                  b=init.Constant(0.), nonlinearity=lasagne.nonlinearities.rectify, **kwargs):
#         super(CML, self).__init__(incoming, num_units, mask_generator,W,
#                                   b, nonlinearity,**kwargs)

#         self.use_cond_mask=use_cond_mask
#         if use_cond_mask:
#             self.U = self.add_param(spec = U,
#                                     shape = (num_inputs, num_units),
#                                     name='U',
#                                     trainable=True,
#                                     regularizable=False)theano.shared(value=self.weights_initialization((self.n_in, self.n_out)), name=self.name+'U', borrow=True)
#             self.add_param(self.U,name =
#     def get_output_for(self,input,**kwargs):
#         lin = self.lin_output = T.dot(input, self.W * self.weights_mask) + self.b
#         if self.use_cond_mask:
#             lin = lin+T.dot(T.ones_like(input), self.U * self.weights_mask)
#         return lin if self._activation is None else self._activation(lin)



# MADE layer, adopted from M.Germain (https://github.com/mgermain/MADE)
class MADE(lasagne.layers.Layer):
    """Masked autoregressive network applied to the incoming latent layer z.

    Internally builds a chain of MaskedLayers (widths from hidden_sizes), then
    an output MaskedLayer plus a DIML direct connection from z. The layer output
    is the ESL of those two output layers, evaluated with z replaced by this
    layer's input. Masks come from a MaskGenerator and are refreshed through
    shuffle() / reset().

    Args:
        z: incoming lasagne layer. Its output_shape[1] sizes the masks and the
            output layers, so it is presumably (batch, latent units); confirm
            against callers.
        hidden_sizes: non-empty list of hidden layer widths.
        name: prefix for the sub-layer names.
        nonlinearity: nonlinearity of the hidden MaskedLayers.
        output_nonlinearity: nonlinearity of both output layers.
    """
    def __init__(self,z,hidden_sizes,name,nonlinearity=lasagne.nonlinearities.rectify,output_nonlinearity=None, **kwargs):
        # self.rng = rng if rng else RandomStreams(lasagne.random.get_rng().randint(1234))
        super(MADE, self).__init__(z, **kwargs)

        # Incoming latents
        self.z = z

        # List defining hidden units in each layer
        self.hidden_sizes = hidden_sizes

        # Layer name for saving parameters.
        self.name = name

        # nonlinearity
        self.nonlinearity = nonlinearity

        # Output nonlinearity
        self.output_nonlinearity = output_nonlinearity

        # Control parameters from original MADE. Only mask_distribution is used
        # below (it becomes the MaskGenerator's l); the other three locals are
        # not read anywhere in this class.
        mask_distribution=0
        use_cond_mask = False
        direct_input_connect = "Output"
        direct_output_connect = False
        self.shuffled_once = False

        # Mask generator
        self.mask_generator = MaskGenerator(lasagne.layers.get_output_shape(z)[1], hidden_sizes, mask_distribution)

        # Build the MADE
        # TODO: Consider making this more compact by directly writing to the layers list
        # NOTE(review): lasagne.layers.Layer.__init__ (called above) stores the
        # incoming layer z as self.input_layer. Lasagne's graph traversal
        # (get_output / get_all_params) appears to follow that attribute.
        # Overwriting it here would make lasagne feed this MaskedLayer's output,
        # rather than z, into get_output_for. The {self.z: input} substitution
        # there would then put hidden activations in place of z. Confirm whether
        # this is intended. Renaming the attribute changes what an already-trained
        # model computes.
        self.input_layer = MaskedLayer(incoming = z,
                                       num_units = hidden_sizes[0],
                                       mask_generator = self.mask_generator,
                                       layerIdx = 0,
                                       W = lasagne.init.Orthogonal('relu'),
                                       nonlinearity=self.nonlinearity,
                                       name = self.name+'_input')

        self.layers = [self.input_layer]

        # Remaining hidden layers, each fed by the previous one.
        for i in range(1, len(hidden_sizes)):

            self.layers += [MaskedLayer(incoming = self.layers[-1],
                                        num_units = hidden_sizes[i],
                                        mask_generator = self.mask_generator,
                                        layerIdx = i,
                                        W = lasagne.init.Orthogonal('relu'),
                                        nonlinearity=self.nonlinearity,
                                        name = self.name+'_layer_'+str(i))]

        outputLayerIdx = len(self.layers)

        # Output layer: a masked path from the last hidden layer plus a direct
        # (DIML) path from z, both as wide as z.
        self.layers += [MaskedLayer(incoming = self.layers[-1],
                                    num_units = lasagne.layers.get_output_shape(z)[1],
                                    mask_generator = self.mask_generator,
                                    layerIdx = outputLayerIdx,
                                    W = lasagne.init.Orthogonal('relu'),
                                    nonlinearity = self.output_nonlinearity,
                                    name = self.name+'_output_W'),
                        DIML(incoming = z,
                             num_units = lasagne.layers.get_output_shape(z)[1],
                             mask_generator = self.mask_generator,
                             layerIdx = outputLayerIdx,
                             W = lasagne.init.Orthogonal('relu'),
                             nonlinearity = self.output_nonlinearity,
                             name = self.name+'_output_D')]



        # One compiled function that writes fresh masks into every sub-layer.
        masks_updates = [layer_mask_update for l in self.layers for layer_mask_update in l.shuffle_update]
        self.update_masks = theano.function(name='update_masks',
                                            inputs=[],
                                            updates=masks_updates)
        # Make the true output layer by ESL'ing the DIML and masked layer
        self.final_layer= ESL([self.layers[-2],self.layers[-1]])
        # self.output_layer = self.layers[-1]
        # params = [p for p in l.get_params(trainable=True) for l in self.layers]
        # print(params)

    def get_output_for(self, input, deterministic=False, **kwargs):
        # Evaluate the internal sub-graph with z replaced by this layer's input.
        return lasagne.layers.get_output(self.final_layer,{self.z:input})

    def get_params(self, unwrap_shared=True, **tags):
        """Return the params of every internal sub-layer, filtered by tags.

        NOTE(review): unwrap_shared is accepted but ignored.
        """
        params = []
        for l in self.layers:
            for p in l.get_params(**tags):
                params.append(p)
        return(params)
        # params = [p for p in l.get_params(trainable=True) for l in self.layers]
        # return params
        # return [p for p in lay.get_params(unwrap_shared,**tags) for lay in self.layers]
        # return lasagne.layers.get_all_params(self.final_layer,trainable=True)

    def shuffle(self, shuffling_type):
        """Resample the MADE masks and push them into the sub-layers.

        "Once" shuffles both ordering and connectivity, but only on the first
        call; it also sets shuffled_once. After that, "Once" just re-runs
        update_masks. "Ordering", "Connectivity" or "Full" choose what is
        resampled before the masks are updated.
        """
        if shuffling_type == "Once" and self.shuffled_once is False:
            self.mask_generator.shuffle_ordering()
            self.mask_generator.sample_connectivity()
            self.update_masks()
            self.shuffled_once = True
            return

        if shuffling_type in ["Ordering", "Full"]:
            self.mask_generator.shuffle_ordering()
        if shuffling_type in ["Connectivity", "Full"]:
            self.mask_generator.sample_connectivity()
        self.update_masks()

    def reset(self, shuffling_type, last_shuffle=0):
        """Reset the mask generator, do one "Full" shuffle, then last_shuffle more of shuffling_type."""
        self.mask_generator.reset()

        # Always do a first shuffle so that the natural order does not give us an edge
        self.shuffle("Full")

        # Set the mask to the requested shuffle
        for i in range(last_shuffle):
            self.shuffle(shuffling_type)
-------------------------------------------------------------------------------- /mask_generator.py: --------------------------------------------------------------------------------
## Mask generator from MADE: https://github.com/mgermain/MADE

import copy
import theano
import theano.tensor as T
import numpy as np
from theano.sandbox.rng_mrg import MRG_RandomStreams # Limited but works on GPU
from theano.tensor.shared_randomstreams import RandomStreams
# from theano.gpuarray.dnn import GpuDnnSoftmax as mysoftmax

def mysoftmax(x):
    """Softmax over all elements of x, shifted by x.max() for numerical stability.

    NOTE(review): not called anywhere in this module as shown.
    """
    e_x = T.exp(x - x.max())
    return e_x / e_x.sum()

class MaskGenerator(object):
    """Sample MADE connectivity values and build the matching weight masks.

    Every layer has a shared connectivity vector:
      - the input layer holds ordering + 1;
      - each hidden layer holds one sampled value per unit;
      - the output layer uses the ordering shared variable itself.
    Unit i of one layer may feed unit j of another when
    connectivity_in[i] <= connectivity_out[j] (see _get_mask).

    Args:
        input_size: number of inputs (and outputs).
        hidden_sizes: list of hidden layer widths.
        l: scale applied to arange(...) in _get_p.
        random_seed: seed for both random streams.
    """

    def __init__(self, input_size, hidden_sizes, l, random_seed=1234):
        self._random_seed = random_seed
        self._mrng = MRG_RandomStreams(seed=random_seed)
        self._rng = RandomStreams(seed=random_seed)

        self._hidden_sizes = hidden_sizes
        self._input_size = input_size
        self._l = l

        # Current input ordering, stored as floatX values 0 .. input_size-1.
        self.ordering = theano.shared(value=np.arange(input_size, dtype=theano.config.floatX), name='ordering', borrow=False)

        # Initial layer connectivity
        self.layers_connectivity = [theano.shared(value=(self.ordering + 1).eval(), name='layer_connectivity_input', borrow=False)]
        for i in range(len(self._hidden_sizes)):
            self.layers_connectivity += [theano.shared(value=np.zeros((self._hidden_sizes[i]), dtype=theano.config.floatX), name='layer_connectivity_hidden{0}'.format(i), borrow=False)]
        # The output connectivity is the ordering shared variable itself (not a copy).
        self.layers_connectivity += [self.ordering]

        ## Theano functions
        # shuffle_ordering(): permute the ordering and refresh the input connectivity.
        new_ordering = self._rng.shuffle_row_elements(self.ordering)
        self.shuffle_ordering = theano.function(name='shuffle_ordering',
                                                inputs=[],
                                                updates=[(self.ordering, new_ordering), (self.layers_connectivity[0], new_ordering + 1)])

        # Built with an explicit loop: for layerIdx > 0,
        # _get_hidden_layer_connectivity reads self.layers_connectivity_updates[layerIdx-1].
        # That entry must therefore already exist, which is not true while a
        # list comprehension assigning this attribute is still running.
        self.layers_connectivity_updates = []
        for i in range(len(self._hidden_sizes)):
            self.layers_connectivity_updates += [self._get_hidden_layer_connectivity(i)]
        # self.layers_connectivity_updates = [self._get_hidden_layer_connectivity(i) for i in range(len(self._hidden_sizes))] # does not work, see above
        # sample_connectivity(): resample every hidden layer's connectivity vector.
        self.sample_connectivity = theano.function(name='sample_connectivity',
                                                   inputs=[],
                                                   updates=[(self.layers_connectivity[i+1], self.layers_connectivity_updates[i]) for i in range(len(self._hidden_sizes))])

        # Save random initial state
        self._initial_mrng_rstate = copy.deepcopy(self._mrng.rstate)
        self._initial_mrng_state_updates = [state_update[0].get_value() for state_update in self._mrng.state_updates]

        # Ensuring valid initial connectivity
        self.sample_connectivity()

    def reset(self):
        """Restore the natural ordering and both random streams, then resample connectivity."""
        # Set Original ordering
        self.ordering.set_value(np.arange(self._input_size, dtype=theano.config.floatX))

        # Reset RandomStreams
        self._rng.seed(self._random_seed)

        # Initial layer connectivity
        self.layers_connectivity[0].set_value((self.ordering + 1).eval())
        for i in range(1, len(self.layers_connectivity)-1):
            self.layers_connectivity[i].set_value(np.zeros((self._hidden_sizes[i-1]), dtype=theano.config.floatX))
        # layers_connectivity[-1] is self.ordering, so this re-sets it to its own value.
        self.layers_connectivity[-1].set_value(self.ordering.get_value())

        # Reset MRG_RandomStreams (GPU)
        self._mrng.rstate = self._initial_mrng_rstate
        for state, value in zip(self._mrng.state_updates, self._initial_mrng_state_updates):
            state[0].set_value(value)

        self.sample_connectivity()

    def _get_p(self, start_choice):
        """Symbolic sampling weights for a hidden unit's connectivity value.

        Returns a vector of length input_size - 1. Entries before index
        start_choice - 1 are zero; from there on the entries are
        self._l * arange(start_choice, input_size), with 1 added at index
        start_choice - 1.
        """
        start_choice_idx = (start_choice-1).astype('int32')
        # NOTE(review): mysoftmax above is never applied to the arange term.
        # MADE (layers.py) passes l == 0, so every entry except
        # start_choice_idx is 0. The multinomial draw therefore always picks
        # start_choice_idx, i.e. each hidden unit gets connectivity start_choice.
        # Confirm this is intended before changing it: a fix changes the sampled masks.
        p_vals = T.concatenate([T.zeros((start_choice_idx,)), (self._l * T.arange(start_choice, self._input_size, dtype=theano.config.floatX))])
        # Hack: the multinomial op has no safeguard against numerical imprecision,
        # so 1 is added at start_choice_idx, which becomes the last entry once
        # _get_hidden_layer_connectivity reverses the vector.
        p_vals = T.inc_subtensor(p_vals[start_choice_idx], 1.)
        return p_vals

    def _get_hidden_layer_connectivity(self, layerIdx):
        """Symbolic expression sampling one connectivity value per hidden unit.

        Sampling starts at the minimum connectivity of the previous layer: the
        current input connectivity for layerIdx 0, otherwise the previous
        layer's sampled expression.
        """
        layer_size = self._hidden_sizes[layerIdx]
        if layerIdx == 0:
            p_vals = self._get_p(T.min(self.layers_connectivity[layerIdx]))
        else:
            p_vals = self._get_p(T.min(self.layers_connectivity_updates[layerIdx-1]))

        # #Implementations of np.choose in theano GPU
        # return T.nonzero(self._mrng.multinomial(pvals=[self._p_vals] * layer_size, dtype=theano.config.floatX))[1].astype(dtype=theano.config.floatX)
        # return T.argmax(self._mrng.multinomial(pvals=[self._p_vals] * layer_size, dtype=theano.config.floatX), axis=1)
        # Each row is a one-hot multinomial draw over the reversed p_vals.
        # cumsum followed by sum turns the one-hot position back into
        # (index into p_vals) + 1.
        return T.sum(T.cumsum(self._mrng.multinomial(pvals=T.tile(p_vals[::-1][None, :], (layer_size, 1)), dtype=theano.config.floatX), axis=1), axis=1)

    def _get_mask(self, layerIdxIn, layerIdxOut):
        """Mask[i, j] = 1 where connectivity_in[i] <= connectivity_out[j], else 0 (floatX)."""
        return (self.layers_connectivity[layerIdxIn][:, None] <= self.layers_connectivity[layerIdxOut][None, :]).astype(theano.config.floatX)

    def get_mask_layer_UPDATE(self, layerIdx):
        """Mask from layer layerIdx to layer layerIdx + 1."""
        return self._get_mask(layerIdx, layerIdx + 1)

    def get_direct_input_mask_layer_UPDATE(self, layerIdx):
        """Mask from the input layer (index 0) directly to layer layerIdx."""
        return self._get_mask(0, layerIdx)

    def get_direct_output_mask_layer_UPDATE(self, layerIdx):
        """Mask from layer layerIdx directly to the output layer (index -1)."""
        return self._get_mask(layerIdx, -1)
-------------------------------------------------------------------------------- /metrics_logging.py: -------------------------------------------------------------------------------- 1 | 2 | import time 3 | import json 4 | import logging 5 | 6 | from path import Path 7 | 8 | class MetricsLogger(object): 9 | 10 | def __init__(self, fname, reinitialize=False): 11 | self.fname = Path(fname) 12 | self.reinitialize = reinitialize 13 | if self.fname.exists(): 14 | if self.reinitialize: 15 | logging.warn('{} exists, deleting'.format(self.fname)) 16 | self.fname.remove() 17 | 18 | def log(self, record=None, **kwargs): 19 | """ 20 | Assumption: no newlines in the input.
21 | """ 22 | if record is None: 23 | record = {} 24 | record.update(kwargs) 25 | record['_stamp'] = time.time() 26 | with open(self.fname, 'ab') as f: 27 | f.write(json.dumps(record, ensure_ascii=True)+'\n') 28 | 29 | 30 | def read_records(fname): 31 | """ convenience for reading back. """ 32 | skipped = 0 33 | with open(fname, 'rb') as f: 34 | for line in f: 35 | if not line.endswith('\n'): 36 | skipped += 1 37 | continue 38 | yield json.loads(line.strip()) 39 | if skipped > 0: 40 | logging.warn('skipped {} lines'.format(skipped)) 41 | -------------------------------------------------------------------------------- /pics/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /sample_IAN.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import imp 4 | import time 5 | import logging 6 | import itertools 7 | import os 8 | 9 | import numpy as np 10 | from path import Path 11 | import theano 12 | import theano.tensor as T 13 | from theano.tensor.opt import register_canonicalize 14 | from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams 15 | import lasagne 16 | from lasagne.layers import SliceLayer as SL 17 | 18 | 19 | import GANcheckpoints 20 | from collections import OrderedDict 21 | import matplotlib 22 | matplotlib.use('Agg') 23 | import matplotlib.pyplot as plt 24 | from fuel.datasets import CelebA 25 | from discgen_utils import plot_image_grid 26 | 27 | 28 | 29 | ## Utilities: 30 | # to_tanh: transforms an array in the range [0,255] to the range [-1,1] 31 | # from_tanh: transforms an array in the range [-1,1] to the range[0,255] 32 | def to_tanh(input): 33 | return 2.0*(input/255.0)-1.0 34 | # return input/255.0 35 | 36 | def from_tanh(input): 37 | return 255.0*(input+1)/2.0 38 | # return 255.0*input 39 | 40 | 41 | ### Make Training Functions 
Method 42 | # This function defines and compiles the computational graphs that define the training, validation, and test functions. 43 | 44 | def make_training_functions(cfg,model): 45 | 46 | # Define input tensors 47 | # Tensor axes are batch-channel-dim1-dim2 48 | 49 | # Image Input 50 | X = T.TensorType('float32', [False]*4)('X') 51 | 52 | # Latent Input, for providing latent values from the main function 53 | Z = T.TensorType('float32', [False]*2)('Z') # Latents 54 | 55 | # Input layer 56 | l_in = model['l_in'] 57 | 58 | # Output layer 59 | l_out = model['l_out'] 60 | 61 | # Latent Layer 62 | l_Z = model['l_Z'] 63 | 64 | # IAF latent layer: 65 | l_Z_IAF = model['l_Z_IAF'] 66 | 67 | # Means 68 | l_mu = model['l_mu'] 69 | 70 | # Log-sigmas 71 | l_ls = model['l_ls'] 72 | 73 | # IAF Means 74 | l_IAF_mu = model['l_IAF_mu'] 75 | 76 | # IAF logsigmas 77 | l_IAF_ls = model['l_IAF_ls'] 78 | 79 | # Introspective loss layers 80 | l_introspect = model['l_introspect'] 81 | 82 | # Adversarial Discriminator 83 | l_discrim = model['l_discrim'] 84 | 85 | # Sample function 86 | sample = theano.function([Z],lasagne.layers.get_output(l_out,{l_Z_IAF:Z},deterministic=True),on_unused_input='warn') 87 | 88 | sampleZ= theano.function([Z],lasagne.layers.get_output(l_out,{l_Z:Z},deterministic=True),on_unused_input='warn') 89 | 90 | # Inference Function--Infer non-IAF_latents given an input X 91 | Zfn = theano.function([X],lasagne.layers.get_output(l_Z_IAF,{l_in:X},deterministic=True),on_unused_input='warn') 92 | 93 | # IAF function--Infer IAF latents given a latent input Z 94 | Z_IAF_fn = theano.function([Z],lasagne.layers.get_output(l_Z,{l_Z_IAF:Z},deterministic=True),on_unused_input='warn') 95 | 96 | 97 | # Dictionary of Theano Functions 98 | # tfuncs = {'update_iter':update_iter, 99 | tfuncs = {'sample': sample, 100 | 'sampleZ': sampleZ, 101 | 'Zfn' : Zfn, 102 | 'Z_IAF_fn': Z_IAF_fn 103 | } 104 | 105 | # Dictionary of Theano Variables 106 | tvars = {'X' : X, 107 | 'Z' : Z} 108 | 109 | 
return tfuncs, tvars, model 110 | 111 | # Data Loading Function 112 | # 113 | # This function interfaces with a Fuel dataset and returns numpy arrays containing the requested data 114 | def data_loader(cfg,set,offset=0,shuffle=False,seed=42): 115 | 116 | # Define chunk size 117 | chunk_size = cfg['batch_size']*cfg['batches_per_chunk'] 118 | 119 | np.random.seed(seed) 120 | index = np.random.permutation(set.num_examples-offset) if shuffle else np.asarray(range(set.num_examples-offset)) 121 | 122 | # Open Dataset 123 | set.open() 124 | 125 | 126 | # Loop across all data 127 | for i in xrange(set.num_examples//chunk_size): 128 | yield to_tanh(np.float32(set.get_data(request = list(index[range(offset+chunk_size*i,offset+chunk_size*(i+1))]))[0])) 129 | 130 | # Close dataset 131 | set.close(state=None) 132 | 133 | 134 | # Main Function 135 | def main(args): 136 | 137 | # Load Config Module from source file 138 | config_module = imp.load_source('config', args.config_path) 139 | 140 | # Get configuration parameters 141 | cfg = config_module.cfg 142 | 143 | # Define name of npz file to which the model parameters will be saved 144 | weights_fname = str(args.config_path)[:-3]+'.npz' 145 | 146 | model = config_module.get_model(interp=False) 147 | print('Compiling theano functions...') 148 | 149 | # Compile functions 150 | tfuncs, tvars,model = make_training_functions(cfg,model) 151 | 152 | # Test set for interpolations 153 | test_set = CelebA('64',('test',),sources=('features',)) 154 | 155 | # Loop across epochs 156 | offset = True 157 | params = list(set(lasagne.layers.get_all_params(model['l_out'],trainable=True)+\ 158 | lasagne.layers.get_all_params(model['l_discrim'],trainable=True)+\ 159 | [x for x in lasagne.layers.get_all_params(model['l_out'])+\ 160 | lasagne.layers.get_all_params(model['l_discrim']) if x.name[-4:]=='mean' or x.name[-7:]=='inv_std'])) 161 | metadata = GANcheckpoints.load_weights(weights_fname, params) 162 | epoch = args.epoch if args.epoch>0 else 
metadata['epoch'] if 'epoch' in metadata else 0 163 | print('loading weights, epoch is '+str(epoch)) 164 | 165 | model['l_IAF_mu'].reset("Once") 166 | model['l_IAF_ls'].reset("Once") 167 | 168 | # Open Test Set 169 | test_set.open() 170 | 171 | np.random.seed(epoch*42+5) 172 | # Generate Random Samples, averaging latent vectors across masks 173 | samples = np.uint8(from_tanh(tfuncs['sample'](np.random.randn(27,cfg['num_latents']).astype(np.float32)))) 174 | 175 | 176 | np.random.seed(epoch*42+5) 177 | # Get Reconstruction/Interpolation Endpoints 178 | endpoints = np.uint8(test_set.get_data(request = list(np.random.choice(test_set.num_examples,6,replace=False)))[0]) 179 | 180 | # Get reconstruction latents 181 | Ze = np.asarray(tfuncs['Zfn'](to_tanh(np.float32(endpoints)))) 182 | 183 | # Get Interpolant Latents 184 | Z = np.asarray([Ze[2 * i, :] * (1 - j) + Ze[2 * i + 1, :] * j for i in range(3) for j in [x/6.0 for x in range(7)]],dtype=np.float32) 185 | 186 | # Get all images 187 | images = np.append(samples,np.concatenate([np.insert(endpoints[2*i:2*(i+1),:,:,:],1,np.uint8(from_tanh(tfuncs['sample'](Z[7*i:7*(i+1),:]))),axis=0) for i in range(3)],axis=0),axis=0) 188 | 189 | 190 | # Plot images 191 | plot_image_grid(images,6,9,'pics/'+str(args.config_path)[:-3]+'_sample'+str(epoch)+'.png') 192 | 193 | # Close test set 194 | test_set.close(state=None) 195 | 196 | 197 | if __name__=='__main__': 198 | parser = argparse.ArgumentParser() 199 | parser.add_argument('config_path', type=Path, help='config .py file') 200 | parser.add_argument('--epoch',type=int,default=0) 201 | args = parser.parse_args() 202 | main(args) 203 | -------------------------------------------------------------------------------- /train_IAN.py: -------------------------------------------------------------------------------- 1 | ### Introspective Adversarial Network Training Function 2 | # A Brock, 2016 3 | 4 | import argparse 5 | import imp 6 | import time 7 | import logging 8 | import itertools 9 | 
import os 10 | import string 11 | 12 | import numpy as np 13 | from path import Path 14 | import theano 15 | import theano.tensor as T 16 | from theano.tensor.opt import register_canonicalize 17 | from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams 18 | import lasagne 19 | from lasagne.layers import SliceLayer as SL 20 | 21 | import metrics_logging 22 | import GANcheckpoints 23 | from collections import OrderedDict 24 | import matplotlib 25 | matplotlib.use('Agg') 26 | import matplotlib.pyplot as plt 27 | # from fuel.datasets import CelebA 28 | from discgen_utils import plot_image_grid 29 | 30 | 31 | 32 | ## Utilities: 33 | # to_tanh: transforms an array in the range [0,255] to the range [-1,1] 34 | # from_tanh: transforms an array in the range [-1,1] to the range[0,255] 35 | def to_tanh(input): 36 | return 2.0*(input/255.0)-1.0 37 | # return input/255.0 38 | 39 | def from_tanh(input): 40 | return 255.0*(input+1)/2.0 41 | # return 255.0*input 42 | 43 | 44 | ### Make Training Functions Method 45 | # This function defines and compiles the computational graphs that define the training, validation, and test functions. 
46 | 47 | def make_training_functions(cfg,model): 48 | 49 | # Define input tensors 50 | # Tensor axes are batch-channel-dim1-dim2 51 | 52 | # Image Input 53 | X = T.TensorType('float32', [False]*4)('X') 54 | 55 | # Latent Input, for providing latent values from the main function 56 | Z = T.TensorType('float32', [False]*2)('Z') # Latents 57 | 58 | # Classification Tensor, only used if including a supervised or class-conditional task 59 | y = T.TensorType('float32', [False]*2)('y') 60 | 61 | # Ternary classification values 62 | p1,p2,p3 = T.TensorType('int32',[False]*2)('p1'),T.TensorType('int32',[False]*2)('p2'),T.TensorType('int32',[False]*2)('p3') 63 | 64 | # Shared and Utility Variables 65 | X_shared = lasagne.utils.shared_empty(4, dtype='float32') 66 | y_shared = lasagne.utils.shared_empty(2, dtype='float32') 67 | Z_shared = lasagne.utils.shared_empty(2, dtype='float32') 68 | p1_shared = lasagne.utils.shared_empty(2, dtype='int32') 69 | p2_shared = lasagne.utils.shared_empty(2, dtype='int32') 70 | p3_shared = lasagne.utils.shared_empty(2, dtype='int32') 71 | pi = np.cast[theano.config.floatX](np.pi) 72 | 73 | 74 | # Input layer 75 | l_in = model['l_in'] 76 | 77 | # Output layer 78 | l_out = model['l_out'] 79 | 80 | # Latent Layer 81 | l_Z = model['l_Z'] 82 | 83 | # IAF latent layer: 84 | l_Z_IAF = model['l_Z_IAF'] 85 | 86 | # Means 87 | l_mu = model['l_mu'] 88 | 89 | # Log-sigmas 90 | l_ls = model['l_ls'] 91 | 92 | # IAF Means 93 | l_IAF_mu = model['l_IAF_mu'] 94 | 95 | # IAF logsigmas 96 | l_IAF_ls = model['l_IAF_ls'] 97 | 98 | # Introspective loss layers 99 | l_introspect = model['l_introspect'] 100 | 101 | # Adversarial Discriminator 102 | l_discrim = model['l_discrim'] 103 | 104 | # Batch Indexing Parameters 105 | batch_index = T.iscalar('batch_index') 106 | batch_slice = slice(batch_index*cfg['batch_size'], (batch_index+1)*cfg['batch_size']) 107 | 108 | # Define RNG 109 | rng = RandomStreams(lasagne.random.get_rng().randint(1,69)) 110 | 111 | 
############################################################################### 112 | # Step 1: Compute full forward pass, save the outputs of all relevant layers? # 113 | ############################################################################### 114 | 115 | # Build the main computational graph 116 | outputs = lasagne.layers.get_output([l_out]+[l_mu]+[l_ls]+[l_discrim]+[l_IAF_mu]+[l_IAF_ls]+l_introspect,{l_in:X}) 117 | # outputs = lasagne.layers.get_output([l_out]+[l_mu]+[l_ls]+[l_discrim]+l_introspect,{l_in:X}) 118 | # Reconstruction 119 | X_hat = outputs[0] 120 | 121 | # Latent means 122 | Z_mu = outputs[1] 123 | 124 | # Latent log-sigmas 125 | Z_ls = outputs[2] 126 | 127 | # Discriminator Output 128 | p_X = outputs[3] 129 | 130 | # Latent IAF mus 131 | Z_IAF_mu = outputs[4] 132 | 133 | # Latent IAF logsigma 134 | Z_IAF_ls = outputs[5] 135 | 136 | # Output of the encoder layers (selected for introspection) as a function of the input image 137 | g_X = outputs[6:] 138 | 139 | # Build the second half of the computational graph 140 | out_hat = lasagne.layers.get_output([l_discrim]+l_introspect,{l_in:X_hat}) 141 | 142 | # Discriminator Output given Reconstruction 143 | p_X_hat = out_hat[0] 144 | 145 | # Output of the encoder layers (selected for introspection) as a function of the reconstruction 146 | g_X_hat = out_hat[1:] 147 | 148 | # Discriminator output given random samples 149 | p_X_gen = lasagne.layers.get_output(l_discrim,{l_in:lasagne.layers.get_output(l_out,{l_Z_IAF:Z})}) 150 | 151 | 152 | ################################# 153 | # Step 2: Define loss functions # 154 | ################################# 155 | 156 | # Orthogonal normalization for all parameters 157 | # Define orthonormal residual 158 | def ortho_res(z): 159 | s = 0 160 | for x in z: 161 | if x.name[-1] is 'W' and x.ndim==4: 162 | y = T.batched_tensordot(x,x.dimshuffle(0,1,3,2),[[1,3],[1,2]]) 163 | y-=T.eye(x.shape[2],x.shape[3]).dimshuffle('x',0,1).repeat(x.shape[0],0) 164 | 
s+=T.sum(T.abs_(y)) 165 | return(s) 166 | 167 | 168 | # Define Pixel-wise reconstruction loss 169 | pixel_loss = T.mean(2*T.abs_(X_hat-X+1e-8)) 170 | 171 | # KL Divergence between latents and Standard Normal prior 172 | kl_div = -0.5 * T.mean(1 + 2*Z_ls - T.sqr(Z_mu) - T.exp(2 * Z_ls)) 173 | # kl_div = -T.maximum(0.5, T.mean(0.5 * (1 + 2*Z_ls - T.sqr(Z_mu) - T.exp(2 * Z_ls)) + Z_IAF_ls)) 174 | 175 | 176 | 177 | ########################## 178 | # Step 3: Define Updates # 179 | ########################## 180 | 181 | # Get Parameters 182 | 183 | # All network parameters, including log_sigma 184 | params = lasagne.layers.get_all_params(l_out,trainable=True) 185 | 186 | # Encoder Parameters 187 | encoder_params = lasagne.layers.get_all_params(l_discrim,trainable=True) 188 | 189 | # MADE parameters, along with a thing to prevent the IAF params from being trained 190 | Z_params = [p for p in lasagne.layers.get_all_params(l_Z_IAF,trainable=True) if p not in lasagne.layers.get_all_params(l_discrim,trainable=True)] 191 | print(Z_params) 192 | 193 | # Decoder Params 194 | decoder_params = [p for p in lasagne.layers.get_all_params(l_out,trainable=True) if p not in lasagne.layers.get_all_params(l_Z,trainable=True)] 195 | 196 | # Define learning rate, with provisions made for annealing schedule 197 | if isinstance(cfg['learning_rate'], dict): 198 | learning_rate = theano.shared(np.float32(cfg['learning_rate'][0])) 199 | else: 200 | learning_rate = theano.shared(np.float32(cfg['learning_rate'])) 201 | 202 | 203 | 204 | # Adversarial Stuff 205 | 206 | 207 | 208 | print('Calculating Adversarial Loss and Grads...') 209 | # Regularization terms 210 | 211 | l2_Z = cfg['reg']*lasagne.regularization.apply_penalty([p for p in lasagne.layers.get_all_params(l_Z_IAF,trainable=True,regularizable=True)\ 212 | if p not in lasagne.layers.get_all_params(l_discrim,trainable=True)], 213 | lasagne.regularization.l2) 214 | if 'ortho' in cfg: 215 | print('Applying orthogonal regularization...') 216 | 
l2_discrim = cfg['ortho']*lasagne.regularization.apply_penalty(lasagne.layers.get_all_params(l_Z,trainable=True,regularizable=True)\ 217 | +l_discrim.get_params(trainable=True,regularizable=True), 218 | ortho_res) 219 | 220 | l2_gen = cfg['ortho']*lasagne.regularization.apply_penalty([p for p in lasagne.layers.get_all_params(l_out,trainable=True,regularizable=True) if p not in encoder_params], 221 | ortho_res) 222 | 223 | 224 | # Adversarial Loss for Discriminator 225 | 226 | # Discriminator loss for reconstructed and generated samples 227 | # print(p_X_hat.shape[0]) 228 | discrim_g_loss = T.mean(T.nnet.categorical_crossentropy(p_X_hat,p2)) + T.mean(T.nnet.categorical_crossentropy(p_X_gen,p3)) 229 | 230 | # 231 | 232 | 233 | # Discriminator loss 234 | discrim_d_loss = T.mean(T.nnet.categorical_crossentropy(p_X, p1)) 235 | 236 | adversarial_discrim_loss = cfg['dg_weight']*discrim_g_loss+cfg['dd_weight']*discrim_d_loss 237 | 238 | 239 | # Discriminator Accuracy 240 | discrim_accuracy = (T.mean(T.eq(T.argmax(p_X,axis=1),T.argmax(p1,axis=1)))+T.mean(T.eq(T.argmax(p_X_hat,axis=1),T.argmax(p2,axis=1)))+T.mean(T.eq(T.argmax(p_X_gen,axis=1),T.argmax(p3,axis=1))))/3.0 241 | 242 | 243 | # Feature Reconstruction Loss for Generator 244 | feature_loss = T.cast(T.mean([T.mean(lasagne.objectives.squared_error(g_X[i],g_X_hat[i])) for i in xrange(len(g_X_hat))]),'float32') 245 | 246 | # Adversarial loss for Generator 247 | gen_recon_loss = T.mean(T.nnet.categorical_crossentropy(p_X_hat,p1)) 248 | gen_sample_loss = T.mean(T.nnet.categorical_crossentropy(p_X_gen,p1)) 249 | 250 | adversarial_gen_loss = cfg['agr_weight']*gen_recon_loss+cfg['ags_weight']*gen_sample_loss 251 | 252 | # Updates for discriminator 253 | discrim_updates = lasagne.updates.adam(T.grad(adversarial_discrim_loss+l2_discrim,encoder_params,consider_constant=[X_hat]),encoder_params,learning_rate,beta1=cfg['beta1']) 254 | 255 | # Updates for Generator 256 | gen_updates = lasagne.updates.adam(adversarial_gen_loss+\ 257 
| cfg['recon_weight']*pixel_loss+\ 258 | cfg['feature_weight']*feature_loss+\ 259 | l2_gen,decoder_params,learning_rate,beta1=cfg['beta1']) 260 | 261 | # Optional Inference mini-network updates--only updated based on reconstructions? 262 | # Z_gen_updates = lasagne.updates.adam(adversarial_gen_loss+cfg['feature_weight']*feature_loss+cfg['recon_weight']*pixel_loss+kl_div,Z_params,learning_rate=learning_rate,beta1=cfg['beta1']) 263 | # Z_gen_updates = lasagne.updates.adam(adversarial_gen_loss+cfg['feature_weight']*feature_loss+cfg['recon_weight']*pixel_loss+kl_div,Z_params,learning_rate=learning_rate,beta1=cfg['beta1']) 264 | 265 | # Z_discrim_updates = lasagne.updates.adam(adversarial_gen_losscfg['feature_weight']*feature_loss+cfg['recon_weight']*pixel_loss+kl_div,Z_params,learning_rate=learning_rate,beta1=cfg['beta1']) 266 | Z_gen_updates = lasagne.updates.adam(cfg['feature_weight']*feature_loss+\ 267 | cfg['recon_weight']*pixel_loss+\ 268 | adversarial_gen_loss+\ 269 | kl_div+\ 270 | l2_Z, 271 | Z_params, 272 | learning_rate=learning_rate, 273 | beta1=cfg['beta1']) 274 | for ud in Z_gen_updates: 275 | gen_updates[ud] = Z_gen_updates[ud] 276 | discrim_updates[ud] = Z_gen_updates[ud] 277 | 278 | # Pixel-Wise MSE for reporting 279 | error_rate = T.cast( T.mean( T.sqr(X_hat-X)), 'float32' ) 280 | 281 | 282 | # Sample function 283 | sample = theano.function([Z],lasagne.layers.get_output(l_out,{l_Z_IAF:Z},deterministic=True),on_unused_input='warn') 284 | 285 | # Inference Function--Infer non-IAF_latents given an input X 286 | Zfn = theano.function([X],lasagne.layers.get_output(l_Z_IAF,{l_in:X},deterministic=True),on_unused_input='warn') 287 | 288 | 289 | 290 | # gen dictionary 291 | gd = OrderedDict() 292 | gd['gen_recon_loss'] = gen_recon_loss 293 | gd['gen_sample_loss'] = gen_sample_loss 294 | gd['pixel_loss'] = pixel_loss 295 | gd['feature_loss'] = feature_loss 296 | gd['pixel_acc'] = 1-error_rate 297 | 298 | # discrim dictionary 299 | dd = OrderedDict() 300 | 
dd['discrim_g_loss'] = discrim_g_loss 301 | dd['discrim_d_loss'] = discrim_d_loss 302 | dd['discrim_acc'] = discrim_accuracy 303 | dd['pixel_loss'] = pixel_loss 304 | dd['pixel_acc'] = 1-error_rate 305 | 306 | 307 | update_gen = theano.function([batch_index],[gd[i] for i in gd],#[adversarial_gen_loss,pixel_loss,1-error_rate], 308 | updates=gen_updates, 309 | givens = {X: X_shared[batch_slice], 310 | y: y_shared[batch_slice], 311 | Z: Z_shared[batch_slice], 312 | p1:p1_shared[batch_slice], 313 | p2:p2_shared[batch_slice], 314 | p3:p3_shared[batch_slice]}, 315 | on_unused_input = 'warn') 316 | 317 | update_discrim = theano.function([batch_index],[dd[i] for i in dd],#[discrim_g_loss,discrim_d_loss,discrim_accuracy,pixel_loss,1-error_rate], 318 | updates=discrim_updates, 319 | givens = {X: X_shared[batch_slice], 320 | y: y_shared[batch_slice], 321 | Z: Z_shared[batch_slice], 322 | p1:p1_shared[batch_slice], 323 | p2:p2_shared[batch_slice], 324 | p3:p3_shared[batch_slice]}, 325 | on_unused_input = 'warn') 326 | 327 | # Dictionary of Theano Functions 328 | # tfuncs = {'update_iter':update_iter, 329 | tfuncs = {'update_gen': update_gen, 330 | 'update_discrim': update_discrim, 331 | 'sample': sample, 332 | 'Zfn' : Zfn, 333 | } 334 | 335 | # Dictionary of Theano Variables 336 | tvars = {'X' : X, 337 | 'y' : y, 338 | 'Z' : Z, 339 | 'X_shared' : X_shared, 340 | 'y_shared' : y_shared, 341 | 'Z_shared' : Z_shared, 342 | 'p1' : p1_shared, 343 | 'p2' : p2_shared, 344 | 'p3' : p3_shared, 345 | 'batch_slice' : batch_slice, 346 | 'batch_index' : batch_index, 347 | 'learning_rate' : learning_rate, 348 | 'gd' : gd, 349 | 'dd': dd 350 | } 351 | 352 | return tfuncs, tvars, model 353 | 354 | # Data Loading Function 355 | # 356 | # This function interfaces with a Fuel dataset and returns numpy arrays containing the requested data 357 | def data_loader(cfg,set,offset=0,shuffle=False,seed=42): 358 | 359 | # Define chunk size 360 | chunk_size = cfg['batch_size']*cfg['batches_per_chunk'] 361 
| 362 | np.random.seed(seed) 363 | index = np.random.permutation(set.num_examples-offset) if shuffle else np.asarray(range(set.num_examples-offset)) 364 | 365 | # Open Dataset 366 | set.open() 367 | 368 | 369 | # Loop across all data 370 | for i in xrange(set.num_examples//chunk_size): 371 | yield to_tanh(np.float32(set.get_data(request = list(index[range(offset+chunk_size*i,offset+chunk_size*(i+1))]))[0])) 372 | 373 | # Close dataset 374 | set.close(state=None) 375 | 376 | 377 | # Main Function 378 | def main(args): 379 | 380 | # Load Config Module from source file 381 | config_module = imp.load_source('config', args.config_path) 382 | 383 | # Get configuration parameters 384 | cfg = config_module.cfg 385 | 386 | # Define name of npz file to which the model parameters will be saved 387 | weights_fname = str(args.config_path)[:-3]+'.npz' 388 | 389 | # Define the name of the jsonl file to which the training log will be saved 390 | metrics_fname = weights_fname[:-4]+'METRICS.jsonl' 391 | 392 | # Prepare logs 393 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s| %(message)s') 394 | logging.info('Metrics will be saved to {}'.format(metrics_fname)) 395 | mlog = metrics_logging.MetricsLogger(metrics_fname, reinitialize=(not args.resume)) 396 | model = config_module.get_model(interp=False) 397 | 398 | logging.info('Compiling theano functions...') 399 | 400 | # Compile functions 401 | tfuncs, tvars,model = make_training_functions(cfg,model) 402 | 403 | # Shuffle Initial masks 404 | model['l_IAF_mu'].shuffle("Once") 405 | model['l_IAF_ls'].shuffle("Once") 406 | logging.info('Training...') 407 | 408 | # Iteration Counter, indicates total number of minibatches processed 409 | itr = 0 410 | 411 | # Best validation accuracy variable 412 | best_acc = 0 413 | 414 | # Test set for interpolations 415 | test_set = CelebA('64',('test',),sources=('features',)) 416 | 417 | # Loop across epochs 418 | offset = True 419 | params = 
list(set(lasagne.layers.get_all_params(model['l_out'],trainable=True)+\ 420 | lasagne.layers.get_all_params(model['l_discrim'],trainable=True)+\ 421 | [x for x in lasagne.layers.get_all_params(model['l_out'])+\ 422 | lasagne.layers.get_all_params(model['l_discrim'])if x.name[-4:]=='mean' or x.name[-7:]=='inv_std'])) 423 | if os.path.isfile(weights_fname) and args.resume: 424 | metadata = GANcheckpoints.load_weights(weights_fname, params) 425 | min_epoch = metadata['epoch']+1 if 'epoch' in metadata else 0 426 | new_lr = metadata['learning_rate'] if 'learning_rate' in metadata else cfg['lr_schedule'][0] 427 | tvars['learning_rate'].set_value(np.float32(new_lr)) 428 | print('loading weights, epoch is '+str(min_epoch),'lr is '+str(new_lr)+'.') 429 | else: 430 | min_epoch = 0 431 | 432 | 433 | # Ratio of gen updates to discrim updates 434 | update_ratio = cfg['update_ratio'] 435 | n_shuffles = 0 436 | for epoch in xrange(min_epoch,cfg['max_epochs']): 437 | offset = not offset 438 | 439 | # Get generator for data 440 | loader = data_loader(cfg, 441 | CelebA('64',('train',),sources=('features',)), 442 | offset=offset*cfg['batch_size']//2,shuffle=cfg['shuffle'], 443 | seed=epoch) # Does this need to happen every epoch? 
444 | 445 | # Update Learning Rate, either with annealing schedule or decay rate 446 | if isinstance(cfg['learning_rate'], dict) and epoch > 0: 447 | if any(x==epoch for x in cfg['learning_rate'].keys()): 448 | lr = np.float32(tvars['learning_rate'].get_value()) 449 | new_lr = cfg['learning_rate'][epoch] 450 | logging.info('Changing learning rate from {} to {}'.format(lr, new_lr)) 451 | tvars['learning_rate'].set_value(np.float32(new_lr)) 452 | if cfg['decay_rate'] and epoch > 0: 453 | lr = np.float32(tvars['learning_rate'].get_value()) 454 | new_lr = lr*(1-cfg['decay_rate']) 455 | logging.info('Changing learning rate from {} to {}'.format(lr, new_lr)) 456 | tvars['learning_rate'].set_value(np.float32(new_lr)) 457 | 458 | # Number of Chunks 459 | iter_counter = 0 460 | 461 | # Epoch-Wise Metrics 462 | # vloss_e, floss_e, closs_e, a_g_loss_e, a_d_loss_e, d_kl_e, c_acc_e, acc_e = 0, 0, 0, 0, 0, 0, 0, 0 463 | 464 | # Loop across all chunks 465 | for x_shared in loader: 466 | 467 | # Increment Chunk Counter 468 | iter_counter+=1 469 | 470 | # Figure out number of batches 471 | num_batches = len(x_shared)//cfg['batch_size'] 472 | 473 | # Shuffle chunk 474 | # np.random.seed(42*epoch) 475 | index = np.random.permutation(len(x_shared)) 476 | 477 | # Load data onto GPU 478 | tvars['X_shared'].set_value(x_shared[index], borrow=True) 479 | tvars['Z_shared'].set_value(np.float32(np.random.randn(len(x_shared),cfg['num_latents'])),borrow=True) 480 | 481 | # Ternary adversarial objectives 482 | tvars['p1'].set_value(np.asarray([[1,0,0]]*len(x_shared),dtype=np.int32)) 483 | tvars['p2'].set_value(np.asarray([[0,1,0]]*len(x_shared),dtype=np.int32)) 484 | tvars['p3'].set_value(np.asarray([[0,0,1]]*len(x_shared),dtype=np.int32)) 485 | # Chunk Metrics 486 | metrics = OrderedDict() 487 | for gkey in tvars['gd']: 488 | metrics[gkey] = [] 489 | for dkey in tvars['dd']: 490 | metrics[dkey] = [] 491 | 492 | # Loop across all batches in chunk 493 | for bi in xrange(num_batches): 494 | 495 | 
496 | # Train and record metrics 497 | if itr % (update_ratio+1)==0: 498 | gen_out = tfuncs['update_gen'](bi) 499 | for key,entry in zip(tvars['gd'],gen_out): 500 | metrics[key].append(entry) 501 | else: 502 | d_out = tfuncs['update_discrim'](bi) 503 | for key,entry in zip(tvars['dd'],d_out): 504 | metrics[key].append(entry) 505 | 506 | 507 | 508 | 509 | itr += 1 510 | 511 | for key in metrics: 512 | metrics[key] = float(np.mean(metrics[key])) 513 | 514 | # Chunk-wise metrics 515 | if (iter_counter-1) % 50 ==0: 516 | title = 'epoch itr ' 517 | form = [] 518 | for item in metrics: 519 | title = title+' '+str(item) 520 | form.append(len(str(item))) 521 | 522 | logging.info(title) 523 | log_output = '%4d '%epoch + '%6d '%itr 524 | for f,item in zip(form,metrics): 525 | e = '%'+str(f)+'.4f' 526 | log_output = log_output+' '+e%metrics[item] 527 | logging.info(log_output) 528 | # logging.info('epoch: {:4d}, itr: {:8d}, ag_loss: {:7.4f}, adg_loss: {:7.4f}, add_loss: {:7.4f}, acc: {:5.3f}, ploss: {:7.4f}, pacc: {:5.3f}'.format(epoch,itr,agloss,adgloss,addloss,accuracy,ploss,pixel_accuracy)) 529 | mlog.log(epoch=epoch,itr=itr,metrics=metrics) 530 | # Log Chunk Metrics 531 | 532 | 533 | 534 | # If we see improvement, save weights and produce output images 535 | # if cfg['reconstruct'] or cfg['introspect']: 536 | if not (epoch%cfg['checkpoint_every_nth']): 537 | 538 | # Open Test Set 539 | test_set.open() 540 | 541 | np.random.seed(epoch*42+5) 542 | # Generate Random Samples, averaging latent vectors across masks 543 | samples = np.uint8(from_tanh(tfuncs['sample'](np.random.randn(27,cfg['num_latents']).astype(np.float32)))) 544 | 545 | 546 | np.random.seed(epoch*42+5) 547 | # Get Reconstruction/Interpolation Endpoints 548 | endpoints = np.uint8(test_set.get_data(request = list(np.random.choice(test_set.num_examples,6,replace=False)))[0]) 549 | 550 | # Get reconstruction latents 551 | Ze = np.asarray(tfuncs['Zfn'](to_tanh(np.float32(endpoints)))) 552 | 553 | # Get Interpolant 
Latents 554 | Z = np.asarray([Ze[2 * i, :] * (1 - j) + Ze[2 * i + 1, :] * j for i in range(3) for j in [x/6.0 for x in range(7)]],dtype=np.float32) 555 | 556 | # Get all images 557 | images = np.append(samples,np.concatenate([np.insert(endpoints[2*i:2*(i+1),:,:,:],1,np.uint8(from_tanh(tfuncs['sample'](Z[7*i:7*(i+1),:]))),axis=0) for i in range(3)],axis=0),axis=0) 558 | 559 | 560 | # Plot images 561 | plot_image_grid(images,6,9,'pics/'+str(args.config_path)[:-3]+'_'+str(epoch)+'.png') 562 | 563 | # Close test set 564 | test_set.close(state=None) 565 | 566 | # Save weights 567 | params = list(set(lasagne.layers.get_all_params(model['l_out'],trainable=True)+\ 568 | lasagne.layers.get_all_params(model['l_discrim'],trainable=True)+\ 569 | [x for x in lasagne.layers.get_all_params(model['l_out'])+\ 570 | lasagne.layers.get_all_params(model['l_discrim'])if x.name[-4:]=='mean' or x.name[-7:]=='inv_std'])) 571 | GANcheckpoints.save_weights(weights_fname, params,{'epoch':epoch,'itr': itr, 'ts': time.time(),'learning_rate':np.float32(tvars['learning_rate'].get_value())}) 572 | 573 | logging.info('training done') 574 | 575 | 576 | if __name__=='__main__': 577 | parser = argparse.ArgumentParser() 578 | parser.add_argument('config_path', type=Path, help='config .py file') 579 | parser.add_argument('--resume',type=bool,default=False) 580 | args = parser.parse_args() 581 | main(args) 582 | -------------------------------------------------------------------------------- /train_IAN_simple.py: -------------------------------------------------------------------------------- 1 | ### 2 | # Hierarchical Adversarial Introspective Autoencoder Main training Function 3 | # 4 | # A Brock 2016 5 | # TO DO: 6 | # 1. Cleanup 7 | # 2. Split validation/test sets appropriately, do early stopping on validation set, do reconstructions on test set 8 | # 13. Learn to render to HTML with VTK, and do the 2d-manifold thing. Consider rendering directly to a webpage? 9 | 10 | # 16. 
Read and UNDERSTAND the Batch Norm and Improved GAN Training papers. Really no excuse for not absolutely understanding these; consider checking WeightNorm (by Salimans?) too. 11 | 12 | # 19. Add code to automatically test if a saved npz of the weights already exists 13 | 14 | 15 | 16 | # 27. Figure out how to incorporate log likelihood into introspective reconstruction error, along with new log_sigma_thetas 17 | 18 | # 32. Do we want the classifier to take in the LATENTS or the layer just before the latents? 19 | # 33. Do we want to propagate the log_sigma through the decoder network somehow to make log-likelihood work better? I'd prefer to use some entropy measure, honestly 20 | 21 | # 36. Add support for class-conditional generation 22 | 23 | # 38. Test regime: number of latent variables vs. accuracy 24 | 25 | # 40. Multi-GPU? Split a batch across two GPUs? Send the autoencoding batch to one GPU, send the adversarial batch to another 26 | # 42. Get a validation set together 27 | # 28 | # Consider hacking the adversarial BCE by changing from zeros/ones to -1's and 2's. 29 | 30 | # Consider only using adversarial gen loss on random samples 31 | 32 | # Consider the effects of batch normalization on an X_hat batch split into half-recon and half-generated, 33 | # specifically: 1. Does batch-norm depend on the 0-axis of the tensor at all, such that we need to shuffle that tensor? I don't think it does. 34 | # 2. Should we be splitting and separately batch norming the recon slice and the generated slice? It wouldn't increase 35 | # computational complexity too much, I don't think, though it depends on what layer we split the bnorm in--does this split 36 | # persist at all intermediate layers, or just on the output? How exactly does this split work anyhow? 37 | 38 | # Consider replacing inception modules with inception-style context aggregation modules using dilated convolutions, or concatenating them 39 | # to get multiple feature maps? i.e.
including dilated convs as one of the possible pieces of an inception module 40 | 41 | # Consider the information path from the intermediate layers to the hierarchical latents--can we reduce the number of parameters 42 | # and improve expressiveness by replacing FC layers with convolutional/inception-style/context aggregation layers, and if so 43 | # what exactly are we doing by doing this? How does replacing the FC layer with a set of conv layers differ from just FC'ing to a 44 | # deeper layer? Can we do this in the output of the hierarchical layer as well, replacing it with an inception upscale layer? 45 | 46 | # Consider replacing Nesterov momentum with Adam or Adamax, using our own implementation that doesn't toss preexisting grads 47 | # Consider doing everything as SGD in the middle, then applying Adam or Nesterov momentum at the very end to the entire Updates dict 48 | # (i.e. producing an "apply_adam" method a la Lasagne's implementation of Nesterov momentum) 49 | 50 | # Consider adding in provisions to completely abandon the pixel-wise reconstruction, and only use introspection+adversarial loss 51 | # Consider returning more meaningful adversarial loss metrics so that we can observe adversarial performance more accurately. 52 | # Consider weighting adversarial loss so that the discriminator doesn't just learn to be dumb and let the generator walk all over it, 53 | # i.e. running into a super-local optimum such that it always outputs "True," thereby always being right when examining a real image and 54 | # letting the generator always get its full win when examining a generated image. 55 | # Learn the difference between "i.e." and "e.g." 56 | 57 | 58 | # Some sort of attention-esque mechanism...maybe give each latent a particular receptive field? Or weight pixel-wise reconstruction 59 | # by each pixel's distance from center, forcing it to produce outer details better 60 | # Is the data shuffled? How is it ordered? Do we need to shuffle it ourselves?
61 | 62 | # Is a 500-dim hierarchical latent representation really equivalent to a 500-dim non-hierarchical rep if 63 | # it necessarily increases the number of layer parameters? 64 | 65 | # Add in text to our generated images, indicating configuration parameters, epoch number, and accuracy, for easy reference 66 | 67 | # Scale the adversarial gradients by the reconstruction loss so that one objective does not overtake the other 68 | 69 | 70 | # Develop a GUI that allows you to explore 3D latent space 71 | # Maybe even do that for hierarchical latent space where the lower-level inputs are set to zero. We can run inference on 32*32*32 in real time, 72 | # we can damn well do it in 3*(2*32)*(2*32) = 12*32*32 73 | # Version Notes: 74 | # 75 | # 76 | # Introspection notes: 77 | # Get more improvement/size of latent space; seems to also generalize better (the avg training error is closer to the validation error) 78 | 79 | # Figure out a way to shuffle back and forth such that we reconstruct/adversarialize on alternating halves of the dataset. 80 | # Add option to choose between Nesterov momentum and Adam 81 | 82 | # Consider adding in log_sigma parameters for the introspective losses as well 83 | 84 | # 85 | 86 | ## Hierarchical Adversarial Autoencoder 87 | # 88 | # Consider throwing in an "output pictures and save only if validation accuracy improves" 89 | # Version 3+: Changing from (0,1) to (-1,1) and using Tanh in place of sigmoids 90 | 91 | 92 | 93 | ## Notes: orthogonal regularization 94 | # -Only applied to weight vectors! Make sure this only happens on convolution filters or square matrices! 95 | # Instead of eye, do we need to scale by eye/num_filts?
import argparse
import imp
import time
import logging
import itertools

import numpy as np
from path import Path
import theano
import theano.tensor as T
from theano.tensor.opt import register_canonicalize
from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
import lasagne
from lasagne.layers import SliceLayer as SL

import voxnet
import CAcheckpoints
import GANcheckpoints
from collections import OrderedDict
import matplotlib
# The non-interactive Agg backend must be selected before pyplot is imported.
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from fuel.datasets import CelebA
from discgen_utils import plot_image_grid



## Utilities:
# to_tanh / from_tanh convert between pixel intensities in [0, 255] and the
# [-1, 1] range produced by a tanh output layer; each is the other's inverse.

def to_tanh(input):
    """Map pixel intensities from [0, 255] to [-1, 1].

    Plain element-wise arithmetic, so it accepts Python scalars as well as
    numpy arrays. No clipping is applied to out-of-range inputs.
    """
    unit_range = input / 255.0      # [0, 255] -> [0, 1]
    return 2.0 * unit_range - 1.0   # [0, 1]   -> [-1, 1]


def from_tanh(input):
    """Map values from [-1, 1] back to pixel intensities in [0, 255].

    Inverse of to_tanh. The result is neither rounded nor clipped, so it stays
    floating point; any integer conversion is left to the caller.
    """
    shifted = input + 1             # [-1, 1] -> [0, 2]
    return 255.0 * shifted / 2.0    # [0, 2]  -> [0, 255]


### Make Training Functions Method
# This function defines and compiles the computational graphs that define the training, validation, and test functions.
137 | 138 | def make_training_functions(cfg,model): 139 | 140 | # Define input tensors 141 | # Tensor axes are batch-channel-dim1-dim2 142 | 143 | # Image Input 144 | X = T.TensorType('float32', [False]*4)('X') 145 | 146 | # Latent Input, for providing latent values from the main function 147 | Z = T.TensorType('float32', [False]*2)('Z') # Latents 148 | 149 | # Classification Tensor, only used if including a supervised or class-conditional task 150 | y = T.TensorType('float32', [False]*2)('y') 151 | 152 | # Shared and Utility Variables 153 | X_shared = lasagne.utils.shared_empty(4, dtype='float32') 154 | y_shared = lasagne.utils.shared_empty(2, dtype='float32') 155 | Z_shared = lasagne.utils.shared_empty(2, dtype='float32') 156 | pi = np.cast[theano.config.floatX](np.pi) 157 | 158 | 159 | # Input layer 160 | l_in = model['l_in'] 161 | 162 | # Output layer 163 | l_out = model['l_out'] 164 | 165 | # Latent Layer 166 | l_latents = model['l_latents'] 167 | 168 | # Means 169 | l_mu = model['l_mu'] 170 | 171 | # Log-sigmas 172 | l_ls = model['l_ls'] 173 | 174 | # Introspective loss layers 175 | l_introspect = model['l_introspect'] 176 | 177 | # Adversarial Discriminator 178 | l_discrim = model['l_discrim'] 179 | 180 | # Classifier 181 | l_classifier = model['l_classifier'] 182 | 183 | # Class-conditional latents 184 | l_cc = model['l_cc'] 185 | 186 | # Decoder Layers, including final output layer. Consider calling this something difference to indicate that it is 187 | # actually a list of layers, not just a single layer. 
188 | l_decoder = lasagne.layers.get_all_layers(l_out)[len(lasagne.layers.get_all_layers(l_latents)):] 189 | 190 | 191 | # Batch Indexing Parameters 192 | batch_index = T.iscalar('batch_index') 193 | batch_slice = slice(batch_index*cfg['batch_size'], (batch_index+1)*cfg['batch_size']) 194 | 195 | # Define RNG 196 | rng = RandomStreams(lasagne.random.get_rng().randint(1,69)) 197 | 198 | ############################################################################### 199 | # Step 1: Compute full forward pass, save the outputs of all relevant layers? # 200 | ############################################################################### 201 | 202 | # Build the main computational graph 203 | # if cfg['reconstruct'] or cfg['introspect']: 204 | # if cfg['adversarial']: 205 | # outputs = lasagne.layers.get_output([l_out]+[l_mu]+[l_ls]+\ 206 | # [l_classifier]+[l_discrim]+[SL(l,indices=slice(0,cfg['batch_size']//2),axis=0)\ 207 | # for l in l_introspect],\ 208 | # {l_in:X, 209 | # model['l_cc']:y, 210 | # model['l_Z_rand']:rng.normal(cfg['batch_size']//2,cfg['num_latents'])}) # Consider swapping l_classifier in for l_latents. may need to tab this properly 211 | # else: 212 | # outputs = lasagne.layers.get_output([l_out]+[lasagne.layers.ConcatLayer([lasagne.layers.flatten(mu) for mu in l_mu])]+[lasagne.layers.ConcatLayer([lasagne.layers.flatten(ls) for ls in l_ls])]+\ 213 | # [l_classifier]+[l_discrim]+l_introspect,\ 214 | # {l_in:X, 215 | # model['l_cc']:y}) # Consider swapping l_classifier in for l_latents. 
may need to tab this properly 216 | # elif cfg['adversarial']: 217 | outputs = lasagne.layers.get_output([l_out]+[l_mu]+[l_ls]+\ 218 | [l_classifier]+[l_discrim]+l_introspect,\ 219 | {l_in:X, 220 | model['l_cc']:y, 221 | model['l_Z_rand']:Z}) 222 | 223 | # Reconstruction 224 | X_hat = outputs[0] 225 | 226 | # Latent means 227 | Z_mu = outputs[1] 228 | 229 | # Latent log-sigmas 230 | Z_ls = outputs[2] 231 | 232 | # Classification Predictions 233 | y_hat = outputs[3] 234 | 235 | # Discriminator Output 236 | p_X = outputs[4] 237 | 238 | # Output of the encoder layers (selected for introspection) as a function of the input image 239 | g_X = outputs[5:] 240 | 241 | # Build the second half of the computational graph 242 | # if cfg['adversarial']: 243 | # out_hat = lasagne.layers.get_output([l_discrim]+[SL(l,indices=slice(0,cfg['batch_size']//2),axis=0) for l in l_introspect],{l_in:X_hat}) 244 | # else: 245 | # out_hat = lasagne.layers.get_output([l_discrim]+l_introspect,{l_in:X_hat}) 246 | out_hat = lasagne.layers.get_output([l_discrim]+l_introspect,{l_in:X_hat}) 247 | # Discriminator Output given Reconstruction/samples 248 | p_X_hat = out_hat[0] 249 | 250 | # Output of the encoder layers (selected for introspection) as a function of the reconstruction 251 | g_X_hat = out_hat[1:] 252 | 253 | p_X_gen = lasagne.layers.get_output(l_discrim,{l_in:lasagne.layers.get_output(l_out,{model['l_latents']:Z})}) 254 | 255 | # Build the testing computational graph 256 | out_d = lasagne.layers.get_output([l_out,l_classifier]+[SL(lat,indices=slice(0,cfg['batch_size']//2),axis=0) for lat in l_latents], 257 | {l_in:X, 258 | model['l_cc']:y, 259 | model['l_Z_rand']:rng.normal((cfg['batch_size']//2,cfg['num_latents']))}, 260 | deterministic=True) if cfg['adversarial'] and cfg['reconstruct'] else\ 261 | lasagne.layers.get_output([l_out, 262 | l_classifier]+l_latents, 263 | {l_in:X, 264 | model['l_cc']:y}, 265 | deterministic=True) if cfg['hierarchical'] else\ 266 | 
lasagne.layers.get_output([l_out, 267 | l_classifier,l_latents], 268 | {model['l_Z_rand']:rng.normal((cfg['batch_size'],cfg['num_latents']))}, 269 | deterministic=True) 270 | X_hat_deterministic = out_d[0] 271 | y_hat_deterministic = out_d[1] 272 | latent_values = out_d[2:] if cfg['hierarchical'] else None 273 | 274 | ################################# 275 | # Step 2: Define loss functions # 276 | ################################# 277 | 278 | # Orthogonal normalization for all parameters 279 | # Define orthonormal residual 280 | def ortho_res(z): 281 | s = 0 282 | for x in z: 283 | if x.name[4:8] is 'conv': 284 | y = T.batched_tensordot(x,x.dimshuffle(0,1,3,2),[[1,3],[1,2]]) 285 | y-=T.eye(x.shape[2],x.shape[3]).dimshuffle('x',0,1).repeat(x.shape[0],0) 286 | s+=T.sum(T.abs(y)) 287 | return(s) 288 | # return sum([T.sum( T.abs_(T.batched_tensordot(x,x.dimshuffle(0,1,3,2),[[1,3],[1,2]])-T.eye(x.shape[2],x.shape[3]).dimshuffle('x',0,1).repeat(x.shape[0],0))) for x in z]) 289 | 290 | # return T.sum(T.batched_dot(x,x.dimshuffle(1,0,2))-T.identity_like(x)) 291 | 292 | # y = T.batched_dot(x,x.T) 293 | # return T.sum(y-T.identity_like(x)) 294 | # y = [T.batched_dot(x,x.T) for x in z] 295 | # return T.sum([T.sum(x-T.identity_like(x)) for x in y]) 296 | l2_all = lasagne.regularization.regularize_network_params(l_out, 297 | lasagne.regularization.l2,tags={'regularizable':True,'trainable':True}) 298 | 299 | # Create log_sigma parameter 300 | log_sigma_theta = lasagne.utils.create_param(spec=np.zeros((3, 64, 64)),shape=(3,64,64), name='log_sigma_theta') 301 | log_sigma = log_sigma_theta.dimshuffle('x', 0, 1, 2) 302 | 303 | # Define Pixel-wise reconstruction loss 304 | # if cfg['reconstruct'] or cfg['introspect']: 305 | # pixel_loss = 0.5 * T.mean(T.log(2 * pi) + 2 * log_sigma + T.sqr(lasagne.nonlinearities.sigmoid(X_hat[:cfg['batch_size']//2,:,:,:])\ 306 | # - X[:cfg['batch_size']//2,:,:,:]) / T.exp(2 * log_sigma)) if cfg['adversarial']\ 307 | # else 0.5 * T.mean(T.log(2 * pi) + 
2 * log_sigma + T.sqr(lasagne.nonlinearities.sigmoid(X_hat) - X) / (T.exp(2 * log_sigma)*(T.abs_(X-0.5)+0.5))) 308 | # else: 309 | # pixel_loss = None 310 | # pixel_loss = 0.5 * T.mean(T.log(2 * pi) + 2 * log_sigma + T.sqr(X_hat - X) / (T.exp(2 * log_sigma))) 311 | # pixel_loss = T.mean(2.0*T.log(0.5*( T.exp(100*(X_hat-X)) + T.exp(-100*(X_hat-X)) ) ) ) 312 | pixel_loss = T.mean(2*T.abs_(X_hat-X+1e-8)) 313 | # KL Divergence between latents and Standard Normal prior 314 | kl_div = -0.5 * T.mean(1 + 2*Z_ls - T.sqr(Z_mu) - T.exp(2 * Z_ls)) 315 | # kl_div = -0.5 * T.mean(1 + 2*Z_ls[:cfg['batch_size']//2,:] - T.sqr(Z_mu[:cfg['batch_size']//2,:]) - T.exp(2 * Z_ls[:cfg['batch_size']//2,:])) if cfg['adversarial'] and cfg['reconstruct']\ 316 | # else -0.5 * T.mean(1 + 2*Z_ls - T.sqr(Z_mu) - T.exp(2 * Z_ls)) if cfg['reconstruct'] or cfg['introspect']\ 317 | # else None 318 | 319 | # Classification objective losses 320 | if cfg['discriminative']: 321 | print('Calculating Classification Loss and Grads...') 322 | # Calculate Classifier Loss 323 | classifier_loss = T.cast(T.mean(T.nnet.categorical_crossentropy(T.nnet.softmax(y_hat), y)), 'float32') 324 | 325 | # Classifier Training Accuracy 326 | classifier_error_rate = T.cast( T.mean( T.neq(T.argmax(y_hat,axis=1), T.argmax(y,axis=1)) ), 'float32' ) 327 | 328 | # Classifier Validation/Test Accuracy 329 | classifier_test_error_rate = T.cast( T.mean( T.neq(T.argmax(y_hat_deterministic,axis=1), T.argmax(y,axis=1))), 'float32' ) 330 | 331 | # Combined losses 332 | reg_pixel_loss = pixel_loss + cfg['reg']*l2_all +classifier_loss+kl_div if cfg['kl_div'] else pixel_loss + cfg['reg']*l2_all +classifier_loss 333 | 334 | else: 335 | classifier_loss = None 336 | classifier_error_rate = None 337 | classifier_test_error_rate = None 338 | reg_pixel_loss = pixel_loss + cfg['reg']*l2_all+kl_div if cfg['kl_div'] and cfg['reconstruct'] else pixel_loss + cfg['reg']*l2_all if cfg['reconstruct'] else None 339 | 340 | ########################## 341 | 
# Step 3: Define Updates # 342 | ########################## 343 | 344 | # Get Parameters 345 | 346 | # All network parameters, including log_sigma 347 | params = lasagne.layers.get_all_params(l_out,trainable=True)+[log_sigma_theta] 348 | 349 | # Encoder Parameters 350 | encoder_params = lasagne.layers.get_all_params(l_latents,trainable=True)+l_discrim.get_params(trainable=True) 351 | 352 | # Decoder Params--consider including log_sigma_theta in this, too? 353 | decoder_params = [p for p in lasagne.layers.get_all_params(l_out,trainable=True) if p not in encoder_params] 354 | # decoder_params = lasagne.layers.get_all_params(l_out,trainable=True) 355 | Z_params = [p for p in lasagne.layers.get_all_params(l_latents,trainable=True) if p not in lasagne.layers.get_all_params(l_discrim,trainable=True)] 356 | 357 | 358 | # Define learning rate, with provisions made for annealing schedule 359 | if isinstance(cfg['learning_rate'], dict): 360 | learning_rate = theano.shared(np.float32(cfg['learning_rate'][0])) 361 | else: 362 | learning_rate = theano.shared(np.float32(cfg['learning_rate'])) 363 | 364 | # Prepare the pixel-wise reconstruction updates 365 | # if cfg['reconstruct']: 366 | # print('Calculating Pixel-wise Loss and Grads...') 367 | # updates = lasagne.updates.adam(reg_pixel_loss,params,learning_rate) 368 | # elif cfg['kl_div']: 369 | # updates = lasagne.updates.adam(kl_div,encoder_params,learning_rate) 370 | # else: 371 | # updates = OrderedDict() 372 | # updates=lasagne.updates.adam( 373 | 374 | # grads = T.grad(cost = reg_pixel_loss, wrt = params) # Optionally calculate gradients directly 375 | 376 | 377 | # Adversarial Stuff 378 | if cfg['adversarial']: 379 | 380 | 381 | print('Calculating Adversarial Loss and Grads...') 382 | # Regularizations 383 | # l2_discrim = lasagne.regularization.regularize_network_params(l_discrim, 384 | # lasagne.regularization.l2,tags={'regularizable':True,'trainable':True}) 385 | l2_discrim = 
lasagne.regularization.apply_penalty(encoder_params,lasagne.regularization.l2) 386 | l2_gen = lasagne.regularization.apply_penalty(decoder_params,lasagne.regularization.l2) 387 | # l2_gen = lasagne.regularization.regularize_network_params(l_out, 388 | # lasagne.regularization.l2,tags={'regularizable':True,'trainable':True}) 389 | 390 | 391 | # Adversarial Loss for Discriminator 392 | # adversarial_discrim_loss = T.mean(T.nnet.binary_crossentropy(T.clip( p_X_hat , 1e-7, 1.0 - 1e-7), T.zeros(p_X_hat.shape)))\ 393 | # + T.mean(T.nnet.binary_crossentropy(T.clip( p_X , 1e-7, 1.0 - 1e-7), T.ones(p_X.shape)))+cfg['reg']*l2_discrim 394 | feature_loss = T.cast(T.mean([T.mean(lasagne.objectives.squared_error(g_X[i],g_X_hat[i])) for i in xrange(len(g_X_hat))]),'float32') 395 | discrim_g_loss = T.mean(T.nnet.binary_crossentropy(T.clip( p_X_hat , 1e-7, 1.0 - 1e-7), T.zeros(p_X_hat.shape)))+\ 396 | T.mean(T.nnet.binary_crossentropy(T.clip( p_X_gen , 1e-7, 1.0 - 1e-7), T.zeros(p_X_gen.shape))) 397 | discrim_d_loss = T.mean(T.nnet.binary_crossentropy(T.clip( p_X , 1e-7, 1.0 - 1e-7), T.ones(p_X.shape))) 398 | adversarial_discrim_loss = discrim_g_loss+discrim_d_loss+cfg['reg']*l2_discrim#+kl_div+cfg['recon_weight']*pixel_loss 399 | 400 | 401 | # Discriminator Accuracy 402 | discrim_accuracy = (T.mean(T.ge(p_X,0.5))+T.mean(T.lt(p_X_hat,0.5)))/2 403 | 404 | # Adversarial Loss for Generator 405 | adversarial_gen_loss = T.mean(T.nnet.binary_crossentropy(T.clip( p_X_hat , 1e-7, 1.0 - 1e-7), T.ones(p_X_hat.shape)))+\ 406 | T.mean(T.nnet.binary_crossentropy(T.clip( p_X_gen , 1e-7, 1.0 - 1e-7), T.ones(p_X_gen.shape)))+\ 407 | cfg['reg']*l2_gen 408 | # adversarial_gen_loss = T.mean(-T.log(T.clip( p_X_hat , 1e-7, 1.0 - 1e-7)/(1-T.clip( p_X_hat , 1e-7, 1.0 - 1e-7)))) + cfg['reg']*l2_gen 409 | # Total Adversarial Loss 410 | # adversarial_loss = adversarial_discrim_loss+adversarial_gen_loss 411 | 412 | # Optional: Expressions not to backpropagate through, for use with "consider_constant" in 
T.grad 413 | # block = list(itertools.chain.from_iterable([i.get_params() for i in lasagne.layers.get_all_layers(model['l_dec_fc1'])[len(lasagne.layers.get_all_layers(model['l_latents']))+1:-1]])) 414 | 415 | # Adversarial Gradients for Discriminator 416 | # adversarial_discrim_grads = T.grad(cost = adversarial_discrim_loss, wrt = lasagne.layers.get_all_params(l_discrim,trainable=True))#, consider_constant = [X_hat]) 417 | 418 | # Adversarial Gradients for Generator 419 | # adversarial_gen_grads = T.grad(cost = adversarial_gen_loss, wrt = lasagne.layers.get_all_params(l_out,trainable=True)) 420 | 421 | # Prepare Adversarial Updates with Adam 422 | # updates = lasagne.updates.adam(adversarial_discrim_grads+adversarial_gen_grads, 423 | # lasagne.layers.get_all_params(l_discrim,trainable=True)+decoder_params,learning_rate,beta1=cfg['beta1']) if cfg['optimizer']=='Adam'\ 424 | # else lasagne.updates.nesterov_momentum(adversarial_discrim_grads+adversarial_gen_grads, 425 | # lasagne.layers.get_all_params(l_discrim,trainable=True)+decoder_params,learning_rate,cfg['momentum']) 426 | discrim_updates = lasagne.updates.adam(adversarial_discrim_loss,encoder_params,learning_rate,beta1=cfg['beta1']) 427 | discrim_to_latent_updates = lasagne.updates.adam(cfg['feature_weight']*feature_loss+cfg['recon_weight']*pixel_loss+kl_div, 428 | Z_params, 429 | learning_rate=learning_rate,beta1=cfg['beta1']) 430 | for ud in discrim_to_latent_updates: 431 | discrim_updates[ud] = discrim_to_latent_updates[ud] 432 | 433 | gen_updates = lasagne.updates.adam(adversarial_gen_loss+cfg['recon_weight']*pixel_loss+cfg['feature_weight']*feature_loss,decoder_params,learning_rate,beta1=cfg['beta1']) 434 | 435 | # for param in lasagne.layers.get_all_params(model['l_mu'],trainable=True)[-3:]+lasagne.layers.get_all_params(model['l_ls'],trainable=True)[-3:]: 436 | # gen_updates[param] = 
lasagne.updates.adam(cfg['feature_weight']*feature_loss+cfg['recon_weight']*pixel_loss+kl_div,[param],learning_rate,beta1=cfg['beta1']) 437 | for ud in discrim_to_latent_updates: 438 | gen_updates[ud] = discrim_to_latent_updates[ud] 439 | # if cfg['reconstruct'] or cfg[: 440 | # for param in adversarial_updates: 441 | # updates[param] = updates[param] + adversarial_updates[param] - param if param in updates else adversarial_updates[param] 442 | # Prepare Adversarial Updates with Nesterov Momentum 443 | # for param,grad in zip(lasagne.layers.get_all_params(l_discrim,trainable=True)+decoder_params,adversarial_discrim_grads+adversarial_gen_grads): 444 | # value = param.get_value(borrow=True) 445 | # velocity = theano.shared(np.zeros(value.shape, dtype=value.dtype), 446 | # broadcastable=param.broadcastable) 447 | # x = cfg['momentum'] * velocity - learning_rate*grad 448 | # updates[velocity] = x 449 | # updates[param] = updates[param] + x*cfg['momentum'] - learning_rate*grad if param in updates else param + x*cfg['momentum'] - learning_rate*grad 450 | else: 451 | adversarial_gen_loss = None 452 | adversarial_discrim_loss = None 453 | 454 | if cfg['introspect']: 455 | print('Calculating Introspective Loss and Grads...') 456 | 457 | # Introspective Loss Term 458 | # Optionally include term to scale losses such that deeper layers are considered more important than more shallow layers 459 | feature_loss = T.cast(T.mean([T.mean(lasagne.objectives.squared_error(g_X[i],g_X_hat[i])) for i in xrange(len(g_X_hat))]),'float32') 460 | 461 | # Introspective Gradients 462 | feature_grads = lasagne.updates.get_or_compute_grads(feature_loss,decoder_params) 463 | 464 | # Prepare Introspective Updates with Adam 465 | feature_updates = lasagne.updates.adam(feature_grads,decoder_params,learning_rate) if cfg['optimizer']=='Adam'\ 466 | else lasagne.updates.nesterov_momentum(feature_grads,decoder_params,learning_rate,cfg['momentum']) 467 | 468 | for param in feature_updates: 469 | 
updates[param] = updates[param] + feature_updates[param] - param if param in updates else feature_updates[param] 470 | 471 | # Prepare Introspective Updates with Nesterov Momentum 472 | # for param,grad in zip(decoder_params,feature_grads): 473 | # value = param.get_value(borrow=True) 474 | # velocity = theano.shared(np.zeros(value.shape, dtype=value.dtype), 475 | # broadcastable=param.broadcastable) 476 | # x = cfg['momentum'] * velocity - learning_rate*grad 477 | # updates[velocity] = x 478 | # updates[param] = updates[param] + x*cfg['momentum'] - learning_rate*grad if param in updates else param + x*cfg['momentum'] - learning_rate * grad 479 | 480 | 481 | else: 482 | feature_loss = None 483 | 484 | # Pixel-wise Training MSE 485 | # error_rate = T.cast( T.mean( T.sqr(lasagne.nonlinearities.sigmoid(X_hat[:cfg['batch_size']//2,:,:,:])-X[:cfg['batch_size']//2,:,:,:])), 'float32' ) if cfg['adversarial'] and cfg['reconstruct']\ 486 | # else T.cast( T.mean( T.sqr(lasagne.nonlinearities.sigmoid(X_hat)-X)), 'float32' ) if cfg['reconstruct'] or cfg['introspect']\ 487 | # else None 488 | error_rate = T.cast( T.mean( T.sqr(X_hat-X)), 'float32' ) 489 | # Pixel-wise Test MSE 490 | test_error_rate = T.cast( T.mean( T.sqr(lasagne.nonlinearities.sigmoid(X_hat_deterministic[:cfg['batch_size']//2,:,:,:])-X[:cfg['batch_size']//2,:,:,:])), 'float32' ) if cfg['adversarial'] and cfg['reconstruct']\ 491 | else T.cast( T.mean( T.sqr(lasagne.nonlinearities.sigmoid(X_hat_deterministic)-X)), 'float32' ) if cfg['reconstruct'] or cfg['introspect']\ 492 | else None 493 | 494 | # Sample function 495 | if cfg['hierarchical']: 496 | sample = theano.function([Z],lasagne.nonlinearities.sigmoid(lasagne.layers.get_output(model['l_out'],\ 497 | {l_Z:T.reshape(Z[:,a[0]:a[1]],(T.shape(Z)[0],)+dim) for l_Z,a,dim in zip(model['l_latents'], zip(np.append(0,cfg['latent_indices'][:-1]),cfg['latent_indices']),cfg['latent_dims'])},deterministic=True)),on_unused_input='warn') 498 | else: 499 | sample = 
theano.function([Z],lasagne.layers.get_output(model['l_out'],{model['l_latents']:Z},deterministic=True)) 500 | # Inference Function--Infer latents given an image 501 | # Zfn = theano.function([X],T.concatenate([T.flatten(l,2) for l in latent_values],axis=1),on_unused_input='warn') if cfg['reconstruct'] or cfg['introspect'] else None 502 | Zfn = theano.function([X],lasagne.layers.get_output(model['l_latents'],{model['l_in']:X},deterministic=True),on_unused_input='warn') 503 | 504 | 505 | # Outputs for Update Function 506 | update_outs = [x for x in [pixel_loss, 507 | feature_loss, 508 | classifier_loss, 509 | adversarial_gen_loss, 510 | adversarial_discrim_loss, 511 | kl_div, 512 | classifier_error_rate, 513 | error_rate, 514 | ] if x is not None] 515 | 516 | # Define Update Function 517 | # update_iter = theano.function([batch_index],update_outs, 518 | # updates=updates, givens={ 519 | # X: X_shared[batch_slice], 520 | # y: y_shared[batch_slice], 521 | # Z: Z_shared[batch_slice], 522 | # },on_unused_input='warn' ) 523 | 524 | 525 | update_gen = theano.function([batch_index],[adversarial_gen_loss,pixel_loss,1-error_rate], 526 | updates=gen_updates, 527 | givens = {X: X_shared[batch_slice], y: y_shared[batch_slice],Z: Z_shared[batch_slice]}, 528 | on_unused_input = 'warn') 529 | 530 | update_discrim = theano.function([batch_index],[discrim_g_loss,discrim_d_loss,discrim_accuracy,pixel_loss,1-error_rate], 531 | updates=discrim_updates, 532 | givens = {X: X_shared[batch_slice], y: y_shared[batch_slice],Z: Z_shared[batch_slice]}, 533 | on_unused_input = 'warn') 534 | # outputs for test/validation function 535 | test_outs = [x for x in [test_error_rate, 536 | classifier_test_error_rate] if x is not None] 537 | 538 | # Define test/validation function 539 | test_error_fn = theano.function([batch_index], 540 | test_outs, givens={ 541 | X: X_shared[batch_slice], 542 | y: y_shared[batch_slice] 543 | },on_unused_input='warn' ) 544 | 545 | # Dictionary of Theano Functions 546 | 
# tfuncs = {'update_iter':update_iter, 547 | tfuncs = {'update_gen': update_gen, 548 | 'update_discrim': update_discrim, 549 | 'test_function':test_error_fn, 550 | 'sample': sample, 551 | 'Zfn' : Zfn 552 | } 553 | 554 | # Dictionary of Theano Variables 555 | tvars = {'X' : X, 556 | 'y' : y, 557 | 'Z' : Z, 558 | 'X_shared' : X_shared, 559 | 'y_shared' : y_shared, 560 | 'Z_shared' : Z_shared, 561 | 'batch_slice' : batch_slice, 562 | 'batch_index' : batch_index, 563 | 'learning_rate' : learning_rate, 564 | 'log_sigma': log_sigma_theta 565 | } 566 | 567 | return tfuncs, tvars, model 568 | 569 | # Data Loading Function 570 | # 571 | # This function interfaces with a Fuel dataset and returns numpy arrays containing the requested data 572 | def data_loader(cfg,set,offset=0,shuffle=False,seed=42): 573 | 574 | # Define chunk size 575 | chunk_size = cfg['batch_size']*cfg['batches_per_chunk'] 576 | 577 | np.random.seed(seed) 578 | index = np.random.permutation(set.num_examples-offset) if shuffle else np.asarray(range(set.num_examples-offset)) 579 | 580 | # Open Dataset 581 | set.open() 582 | 583 | 584 | # Loop across all data 585 | for i in xrange(set.num_examples//chunk_size): 586 | yield to_tanh(np.float32(set.get_data(request = list(index[range(offset+chunk_size*i,offset+chunk_size*(i+1))]))[0])) 587 | 588 | # Close dataset 589 | set.close(state=None) 590 | 591 | 592 | # Main Function 593 | def main(args): 594 | 595 | # Load Config Module from source file 596 | config_module = imp.load_source('config', args.config_path) 597 | 598 | # Get configuration parameters 599 | cfg = config_module.cfg 600 | 601 | # Define name of npz file to which the model parameters will be saved 602 | weights_fname = str(args.config_path)[:-3]+'.npz' 603 | 604 | # Define the name of the jsonl file to which the training log will be saved 605 | metrics_fname = weights_fname[:-4]+'METRICS.jsonl' 606 | 607 | # Prepare logs 608 | logging.basicConfig(level=logging.INFO, format='%(asctime)s 
%(levelname)s| %(message)s') 609 | logging.info('Metrics will be saved to {}'.format(metrics_fname)) 610 | mlog = voxnet.metrics_logging.MetricsLogger(metrics_fname, reinitialize=True) 611 | model = config_module.get_model(interp=False) 612 | 613 | logging.info('Compiling theano functions...') 614 | 615 | # Compile functions 616 | tfuncs, tvars,model = make_training_functions(cfg,model) 617 | 618 | logging.info('Training...') 619 | 620 | # Iteration Counter, indicates total number of minibatches processed 621 | itr = 0 622 | 623 | # Best validation accuracy variable 624 | best_acc = 0 625 | 626 | # Test set for interpolations 627 | test_set = CelebA('64',('test',),sources=('features',)) 628 | 629 | # Loop across epochs 630 | offset = True 631 | params = list(set(lasagne.layers.get_all_params(model['l_out'],trainable=True)+[tvars['log_sigma']]+lasagne.layers.get_all_params(model['l_discrim'],trainable=True)+[x for x in lasagne.layers.get_all_params(model['l_out']) if x.name[-4:]=='mean' or x.name[-7:]=='inv_std'])) 632 | 633 | # Ratio of gen updates to discrim updates 634 | update_ratio = cfg['update_ratio'] 635 | for epoch in xrange(cfg['max_epochs']): 636 | offset = not offset 637 | 638 | # Get generator for data 639 | loader = data_loader(cfg, 640 | CelebA('64',('train',),sources=('features',)), 641 | offset=offset*cfg['batch_size']//2,shuffle=cfg['shuffle'], 642 | seed=epoch) # Does this need to happen every epoch? 
643 | 644 | # Update Learning Rate, either with annealing schedule or decay rate 645 | if isinstance(cfg['learning_rate'], dict) and epoch > 0: 646 | if any(x==epoch for x in cfg['learning_rate'].keys()): 647 | lr = np.float32(tvars['learning_rate'].get_value()) 648 | new_lr = cfg['learning_rate'][epoch] 649 | logging.info('Changing learning rate from {} to {}'.format(lr, new_lr)) 650 | tvars['learning_rate'].set_value(np.float32(new_lr)) 651 | if cfg['decay_rate'] and epoch > 0: 652 | lr = np.float32(tvars['learning_rate'].get_value()) 653 | new_lr = lr*(1-cfg['decay_rate']) 654 | logging.info('Changing learning rate from {} to {}'.format(lr, new_lr)) 655 | tvars['learning_rate'].set_value(np.float32(new_lr)) 656 | 657 | # Number of Chunks 658 | iter_counter = 0 659 | 660 | # Epoch-Wise Metrics 661 | # vloss_e, floss_e, closs_e, a_g_loss_e, a_d_loss_e, d_kl_e, c_acc_e, acc_e = 0, 0, 0, 0, 0, 0, 0, 0 662 | 663 | # Loop across all chunks 664 | for x_shared in loader: 665 | 666 | # Increment Chunk Counter 667 | iter_counter+=1 668 | 669 | # Figure out number of batches 670 | num_batches = len(x_shared)//cfg['batch_size'] 671 | 672 | # Shuffle chunk 673 | # np.random.seed(42*epoch) 674 | index = np.random.permutation(len(x_shared)) 675 | 676 | # Load data onto GPU 677 | tvars['X_shared'].set_value(x_shared[index], borrow=True) 678 | tvars['Z_shared'].set_value(np.float32(np.random.randn(len(x_shared),cfg['num_latents'])),borrow=True) 679 | 680 | # Chunk Metrics 681 | # voxel_lvs,feature_lvs,class_lvs,a_g_lvs,a_d_lvs, kl_divs,class_accs,accs = [],[],[],[],[],[],[],[] 682 | a_g_lvs,a_dg_lvs,a_dd_lvs,discrim_acc,pixel_lvs,pixel_accs = [],[],[],[],[],[] 683 | # Loop across all batches in chunk 684 | for bi in xrange(num_batches): 685 | 686 | if itr % (update_ratio+1)==0: 687 | [a_gen_loss,pixel_loss,pixel_acc] = tfuncs['update_gen'](bi) 688 | a_g_lvs.append(a_gen_loss) 689 | pixel_lvs.append(pixel_loss) 690 | pixel_accs.append(pixel_acc) 691 | else: 692 | 
[a_discrim_gloss,a_discrim_dloss,discrim_accuracy,pixel_loss,pixel_acc] = tfuncs['update_discrim'](bi) 693 | a_dg_lvs.append(a_discrim_gloss) 694 | a_dd_lvs.append(a_discrim_dloss) 695 | discrim_acc.append(discrim_accuracy) 696 | pixel_lvs.append(pixel_loss) 697 | pixel_accs.append(pixel_acc) 698 | 699 | itr += 1 700 | 701 | # if not a_dg_lvs: 702 | # a_dg_lvs,a_dd_lvs,discrim_acc = 0,0,0 703 | # if not a_g_lvs: 704 | # a_g_lvs = 0 705 | 706 | # Train! 707 | # results = tfuncs['update_iter'](bi) 708 | 709 | # Assign results 710 | # TODO: Clean up the assignment so that the variable things are just on the end of the assignment and 711 | # this can be done in one or two lines 712 | # voxel_loss = results[0] if cfg['reconstruct'] or cfg['introspect'] else 0 713 | # feature_loss = results[(cfg['reconstruct'] or cfg['introspect'])] if cfg['introspect'] else 0 714 | # classifier_loss = results[cfg['introspect']+(cfg['reconstruct'] or cfg['introspect'])] if cfg['discriminative'] else 0 715 | # a_gen_loss = results[cfg['introspect']+cfg['discriminative']+(cfg['reconstruct'] or cfg['introspect'])] if cfg['adversarial'] else 0 716 | # a_discrim_loss = results[1+cfg['introspect']+cfg['discriminative']+(cfg['reconstruct'] or cfg['introspect'])] if cfg['adversarial'] else 0 717 | # kl_div = results[cfg['introspect']+cfg['discriminative']+2*cfg['adversarial']+(cfg['reconstruct'] or cfg['introspect'])] if cfg['kl_div'] else 0 718 | # class_acc = results[cfg['introspect']+cfg['discriminative']+2*cfg['adversarial']+cfg['kl_div']+(cfg['reconstruct'] or cfg['introspect'])] if cfg['discriminative'] else 0 719 | # acc = results[cfg['introspect']+2*cfg['discriminative']+2*cfg['adversarial']+cfg['kl_div']+(cfg['reconstruct'] or cfg['introspect'])] if cfg['reconstruct'] or cfg['introspect'] else 0 720 | # voxel_lvs.append(voxel_loss) 721 | # feature_lvs.append(feature_loss) 722 | # class_lvs.append(classifier_loss) 723 | 724 | 725 | # kl_divs.append(kl_div) 726 | # 
class_accs.append(class_acc) 727 | # accs.append(acc) 728 | 729 | 730 | [agloss,adgloss,addloss,accuracy] = [float(np.mean(a_g_lvs)),float(np.mean(a_dg_lvs)),float(np.mean(a_dd_lvs)),float(np.mean(discrim_acc))] 731 | [ploss,pixel_accuracy] = [float(np.mean(pixel_lvs)),float(np.mean(pixel_accs))] 732 | # Chunk-wise metrics 733 | # [vloss, floss,closs, agloss, adloss, d_kl,c_acc,acc] = [float(np.mean(voxel_lvs)), float(np.mean(feature_lvs)), 734 | # float(np.mean(class_lvs)), float(np.mean(a_g_lvs)),float(np.mean(a_d_lvs)),float(np.mean(kl_divs)), 735 | # 1.0-float(np.mean(class_accs)), 1.0-float(np.mean(accs))] 736 | # Epoch-wise metrics 737 | # vloss_e, floss_e, closs_e, a_g_loss_e, a_d_loss_e, d_kl_e, c_acc_e, acc_e = [vloss_e+vloss, floss_e+floss, closs_e+closs, a_g_loss_e+agloss, a_d_loss_e+adloss, d_kl_e+d_kl, c_acc_e+c_acc, acc_e+acc] 738 | 739 | # Report Chunk Metrics 740 | # logging.info('epoch: {:4d}, itr: {:8d}, p_loss: {:8.5f}, f_loss: {:8.5f}, a_g_loss: {:8.5f}, a_d_loss: {:8.5f}, D_kl: {:8.5f}, acc: {:6.5f}'.format(epoch, itr, vloss, floss, 741 | # agloss, adloss, d_kl, acc)) 742 | logging.info('epoch: {:4d}, itr: {:8d}, ag_loss: {:7.4f}, adg_loss: {:7.4f}, add_loss: {:7.4f}, acc: {:5.3f}, ploss: {:7.4f}, pacc: {:5.3f}'.format(epoch,itr,agloss,adgloss,addloss,accuracy,ploss,pixel_accuracy)) 743 | mlog.log(epoch=epoch, itr=itr, agloss=agloss,adgloss = adgloss,addloss=addloss,discrim_accuracy=accuracy,ploss=ploss,pixel_accuracy=pixel_accuracy) 744 | # Log Chunk Metrics 745 | # mlog.log(epoch=epoch, itr=itr, vloss=vloss,floss=floss, agloss=agloss,adloss = adloss, acc=acc,d_kl=d_kl,c_acc=c_acc) 746 | 747 | # Average Epoch-wise Metrics 748 | # vloss_e, floss_e, closs_e, a_g_loss_e, a_d_loss_e, d_kl_e, c_acc_e, acc_e = [vloss_e/iter_counter, floss_e/iter_counter, 749 | # closs_e/iter_counter, a_g_loss_e/iter_counter, a_d_loss_e/iter_counter, d_kl_e/iter_counter, 750 | # c_acc_e/iter_counter, acc_e/iter_counter] 751 | # Report Epoch-wise metrics 752 | # 
logging.info('Training metrics, Epoch {}, p_loss: {}, f_loss: {}, a_g_loss: {}, a_d_loss: {}, c_loss: {}, D_kl: {}, class_acc: {}, acc: {}'.format(epoch, vloss_e, floss_e,a_g_loss_e, a_d_loss_e, closs_e,d_kl_e,c_acc_e,acc_e)) 753 | 754 | # Log Epoch-wise metrics 755 | # mlog.log(epoch=epoch, vloss_e=vloss_e, floss_e=floss_e, a_g_loss_e = a_g_loss_e,a_d_loss_e=a_d_loss_e,closs_e=closs_e, d_kl_e=d_kl_e, c_acc_e=c_acc_e, acc_e=acc_e) 756 | 757 | 758 | 759 | if cfg['reconstruct'] or cfg['introspect']: 760 | logging.info('Examining performance on validation set') 761 | 762 | # Validation Metrics 763 | test_error,test_class_error = [],[], 764 | 765 | # Prepare Test Loader 766 | for o in xrange(2): 767 | test_loader = data_loader(cfg,CelebA('64',('valid',),sources=('features',)),offset=o*cfg['batch_size']//2) 768 | 769 | # Loop Across Chunks 770 | for x_shared in test_loader: 771 | 772 | # Figure Out Number of Batches 773 | num_batches = len(x_shared)//cfg['batch_size'] 774 | 775 | # Load chunk onto GPU 776 | tvars['X_shared'].set_value(x_shared, borrow=True) 777 | 778 | # Loop Across Batches 779 | for bi in xrange(num_batches): 780 | 781 | # Test! 
782 | test_results = tfuncs['test_function'](bi) # Get the test 783 | 784 | # Assign results 785 | batch_test_error=test_results[0] 786 | batch_test_class_error = test_results[1] if cfg['discriminative'] else 0 787 | test_error.append(batch_test_error) 788 | test_class_error.append(batch_test_class_error) 789 | 790 | 791 | 792 | # Average Results 793 | t_error = 1-float(np.mean(test_error)) 794 | t_class_error = 1-float(np.mean(test_class_error)) 795 | 796 | # Report Validation Results 797 | logging.info('Epoch {} Test Accuracy: {}, Classification Test Accuracy: {}, Best_acc: {} '.format(epoch, t_error,t_class_error,best_acc)) 798 | 799 | # Log Validation Results 800 | mlog.log(test_error=t_error,t_class_error = t_class_error) 801 | 802 | 803 | # If we see improvement, save weights and produce output images 804 | # if cfg['reconstruct'] or cfg['introspect']: 805 | if not (epoch%cfg['checkpoint_every_nth']): 806 | 807 | 808 | # Update Best-yet accuracy 809 | # if t_error > best_acc: 810 | # best_acc = t_error 811 | 812 | # Save Weights 813 | 814 | 815 | # Open Test Set 816 | test_set.open() 817 | 818 | np.random.seed(epoch*42+5) 819 | # Generate Random Samples 820 | samples = np.uint8(from_tanh(tfuncs['sample'](np.random.randn(27,cfg['num_latents']).astype(np.float32)))) 821 | 822 | 823 | np.random.seed(epoch*42+5) 824 | # Get Reconstruction/Interpolation Endpoints 825 | endpoints = np.uint8(test_set.get_data(request = list(np.random.choice(test_set.num_examples,6,replace=False)))[0]) 826 | 827 | # Get reconstruction latents 828 | Ze = np.asarray(tfuncs['Zfn'](to_tanh(np.float32(endpoints)))) 829 | 830 | print(np.shape(Ze)) 831 | 832 | # Get Interpolant Latents 833 | Z = np.asarray([Ze[2 * i, :] * (1 - j) + Ze[2 * i + 1, :] * j for i in range(3) for j in [x/6.0 for x in range(7)]],dtype=np.float32) 834 | 835 | # Get all images 836 | images = 
np.append(samples,np.concatenate([np.insert(endpoints[2*i:2*(i+1),:,:,:],1,np.uint8(from_tanh(tfuncs['sample'](Z[7*i:7*(i+1),:]))),axis=0) for i in range(3)],axis=0),axis=0) 837 | 838 | 839 | # Plot images 840 | plot_image_grid(images,6,9,'pics/'+str(args.config_path)[:-3]+'_'+str(epoch)+'.png') 841 | 842 | # Close test set 843 | test_set.close(state=None) 844 | params = list(set(lasagne.layers.get_all_params(model['l_out'],trainable=True)+[tvars['log_sigma']]+lasagne.layers.get_all_params(model['l_discrim'],trainable=True)+[x for x in lasagne.layers.get_all_params(model['l_out'])+lasagne.layers.get_all_params(model['l_discrim']) if x.name[-4:]=='mean' or x.name[-7:]=='inv_std'])) 845 | GANcheckpoints.save_weights(weights_fname, params,{'itr': itr, 'ts': time.time(),'learning_rate':np.float32(tvars['learning_rate'].get_value())}) 846 | # Save weights 847 | # CAcheckpoints.save_weights(weights_fname, model['l_out'],tvars['log_sigma'], 848 | # {'itr': itr, 'ts': time.time(),'best_acc':best_acc,'learning_rate':np.float32(tvars['learning_rate'].get_value())}) 849 | # elif not (epoch%cfg['checkpoint_every_nth']): 850 | # logging.info('Checkpoint: Saving weights and generating samples') 851 | # np.random.seed(epoch*42+5) 852 | # samples = np.uint8(from_tanh(tfuncs['sample'](np.random.randn(54,cfg['num_latents']).astype(np.float32)))) 853 | # Plot images 854 | # plot_image_grid(samples,6,9,'pics/'+str(args.config_path)[:-3]+'_'+str(epoch)+'.png') 855 | # params = list(set(lasagne.layers.get_all_params(model['l_out'],trainable=True)+[tvars['log_sigma']]+lasagne.layers.get_all_params(model['l_discrim'],trainable=True)+[x for x in lasagne.layers.get_all_params(model['l_out'])+lasagne.layers.get_all_params(model['l_discrim']) if x.name[-4:]=='mean' or x.name[-7:]=='inv_std'])) 856 | # GANcheckpoints.save_weights(weights_fname, params,{'itr': itr, 'ts': time.time(),'learning_rate':np.float32(tvars['learning_rate'].get_value())}) 857 | logging.info('training done') 858 | 859 | 
860 | if __name__=='__main__': 861 | parser = argparse.ArgumentParser() 862 | parser.add_argument('config_path', type=Path, help='config .py file') 863 | args = parser.parse_args() 864 | main(args) 865 | --------------------------------------------------------------------------------