├── LICENSE ├── README.md ├── faces ├── README.md ├── load.py └── train_uncond_dcgan.py ├── imagenet └── load_pretrained.py ├── images ├── 50.png ├── 50_cropped.png ├── albums_128px.png ├── faces_128_filter_samples.png ├── faces_arithmetic.png ├── faces_arithmetic_collage.png ├── faces_arithmetic_collage_v2.png ├── googsearch_dcgan.png ├── googsearch_lapgan.png ├── interp_comparison.png ├── lsun_bedrooms_five_epoch_samples.png ├── lsun_bedrooms_five_epochs_interps.png ├── lsun_bedrooms_generator.png ├── lsun_bedrooms_one_epoch_interps.png ├── lsun_bedrooms_one_epoch_samples.png ├── lsun_bedrooms_real.png ├── lsun_bedrooms_window_drop_test.png ├── lsun_five_epochs_guided.png ├── lsun_five_epochs_max_act_l4.png ├── mnist_collage.png ├── mnist_cond_conv_gan_samples.png ├── mnist_cond_fc_gan_samples.png ├── mnist_real.png ├── random_weights.png ├── smile_arithmetic.png ├── turn_vector.png └── vecarithmetictwo.png ├── lib ├── __init__.py ├── activations.py ├── config.py ├── costs.py ├── cv2_utils.py ├── data_utils.py ├── inits.py ├── metrics.py ├── ops.py ├── rng.py ├── theano_utils.py ├── updates.py └── vis.py ├── mnist ├── README.md ├── load.py └── train_cond_dcgan.py ├── models ├── imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z │ ├── 30_discrim_params.jl │ ├── 30_discrim_params.jl_01.npy │ ├── 30_discrim_params.jl_02.npy │ ├── 30_discrim_params.jl_03.npy │ ├── 30_discrim_params.jl_04.npy │ ├── 30_discrim_params.jl_05.npy │ ├── 30_discrim_params.jl_06.npy │ ├── 30_discrim_params.jl_07.npy │ ├── 30_discrim_params.jl_08.npy │ ├── 30_discrim_params.jl_09.npy │ ├── 30_discrim_params.jl_10.npy │ ├── 30_discrim_params.jl_11.npy │ ├── 30_discrim_params.jl_12.npy │ ├── 30_discrim_params.jl_13.npy │ ├── 30_discrim_params.jl_14.npy │ ├── 30_discrim_params.jl_15.npy │ ├── 30_discrim_params.jl_16.npy │ ├── 30_discrim_params.jl_17.npy │ ├── 30_gen_params.jl │ ├── 30_gen_params.jl_01.npy │ ├── 30_gen_params.jl_02.npy │ ├── 30_gen_params.jl_03.npy │ ├── 30_gen_params.jl_04.npy │ 
├── 30_gen_params.jl_05.npy │ ├── 30_gen_params.jl_06.npy │ ├── 30_gen_params.jl_07.npy │ ├── 30_gen_params.jl_08.npy │ ├── 30_gen_params.jl_09.npy │ ├── 30_gen_params.jl_10.npy │ ├── 30_gen_params.jl_11.npy │ ├── 30_gen_params.jl_12.npy │ ├── 30_gen_params.jl_13.npy │ ├── 30_gen_params.jl_14.npy │ ├── 30_gen_params.jl_15.npy │ ├── 30_gen_params.jl_16.npy │ ├── 30_gen_params.jl_17.npy │ ├── 30_gen_params.jl_18.npy │ └── 30_gen_params.jl_19.npy └── svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb │ ├── 200_discrim_params.jl │ ├── 200_discrim_params.jl_01.npy │ ├── 200_discrim_params.jl_02.npy │ ├── 200_discrim_params.jl_03.npy │ ├── 200_discrim_params.jl_04.npy │ ├── 200_discrim_params.jl_05.npy │ ├── 200_discrim_params.jl_06.npy │ ├── 200_discrim_params.jl_07.npy │ ├── 200_discrim_params.jl_08.npy │ ├── 200_gen_params.jl │ ├── 200_gen_params.jl_01.npy │ ├── 200_gen_params.jl_02.npy │ ├── 200_gen_params.jl_03.npy │ ├── 200_gen_params.jl_04.npy │ ├── 200_gen_params.jl_05.npy │ ├── 200_gen_params.jl_06.npy │ ├── 200_gen_params.jl_07.npy │ ├── 200_gen_params.jl_08.npy │ ├── 200_gen_params.jl_09.npy │ └── 200_gen_params.jl_10.npy └── svhn ├── load.py └── svhn_semisup_analysis.py /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2015 Alec Radford 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks 2 | ## [Alec Radford](https://github.com/newmu), [Luke Metz](https://github.com/lukemetz), [Soumith Chintala](https://github.com/soumith) 3 | 4 | All images in this paper are generated by a neural network. They are NOT REAL. 5 | 6 | Full paper here: [http://arxiv.org/abs/1511.06434](http://arxiv.org/abs/1511.06434) 7 | 8 | ### Other implementations of DCGAN 9 | * [Torch](https://github.com/soumith/dcgan.torch) 10 | * [Chainer](https://github.com/mattya/chainer-DCGAN) 11 | * [TensorFlow](https://github.com/carpedm20/DCGAN-tensorflow) 12 | 13 | ## Summary of DCGAN 14 | We 15 | - stabilize Generative Adversarial networks with some architectural constraints 16 | - Replace any pooling layers with strided convolutions (discriminator) and fractional-strided 17 | convolutions (generator). 18 | - Use batchnorm in both the generator and the discriminator 19 | - Remove fully connected hidden layers for deeper architectures. Just use average pooling at the end. 20 | - Use ReLU activation in generator for all layers except for the output, which uses Tanh. 21 | - Use LeakyReLU activation in the discriminator for all layers.
22 | - use the discriminator as a pre-trained net for CIFAR-10 classification and show pretty decent results. 23 | - generate really cool bedroom images that look super real 24 | - To convince you that the network is not cheating: 25 | - show the interpolated latent space, where transitions are really smooth and every image in the latent space is a bedroom. 26 | - show bedrooms after one epoch of training (with a 0.0002 learning rate), come on, the network can't really memorize at this stage. 27 | - To explore the representations that the network learnt, 28 | - show deconvolution over the filters, to show that maximal activations occur at objects like windows and beds 29 | - figure out a way to identify and remove filters that draw windows in generation. 30 | - Now you can control the generator to not output certain objects. 31 | - Because we are tripping 32 | - Smiling woman - neutral woman + neutral man = Smiling man. Whuttttt! 33 | - man with glasses - man without glasses + woman without glasses = woman with glasses. Omg!!!! 34 | - learnt a latent space in a completely unsupervised fashion where ROTATIONS ARE LINEAR in this latent space. WHHHAAATT????!!!!!! 35 | - Figure 11, trained on imagenet has a plane with bird legs. so cooool. 36 | 37 | # Bedrooms after 5 epochs 38 | Generated bedrooms after five epochs of training. There appears to be evidence of visual 39 | under-fitting via repeated textures across multiple samples. 40 | ![](images/lsun_bedrooms_five_epoch_samples.png) 41 | 42 | # Bedrooms after 1 epoch 43 | Generated bedrooms after one training pass through the dataset. Theoretically, the model 44 | could learn to memorize training examples, but this is experimentally unlikely as we train with a 45 | small learning rate and minibatch SGD. We are aware of no prior empirical evidence demonstrating 46 | memorization with SGD and a small learning rate in only one epoch.
47 | ![](images/lsun_bedrooms_one_epoch_samples.png) 48 | 49 | # Walking from one point to another in bedroom latent space 50 | 51 | Interpolation between a series of 9 random points in Z shows that the space 52 | learned has smooth transitions, with every image in the space plausibly looking like a bedroom. In 53 | the 6th row, you see a room without a window slowly transforming into a room with a giant window. 54 | In the 10th row, you see what appears to be a TV slowly being transformed into a window. 55 | 56 | ![](images/lsun_bedrooms_five_epochs_interps.png) 57 | 58 | # Forgetting to draw windows 59 | 60 | Top row: unmodified samples from the model. Bottom row: the same samples generated 61 | with the “window” filters dropped out. Some windows are removed, others are transformed into objects 62 | with similar visual appearance such as doors and mirrors. Although visual quality decreased, overall 63 | scene composition stayed similar, suggesting the generator has done a good job disentangling scene 64 | representation from object representation. Extended experiments could be done to remove other 65 | objects from the image and modify the objects the generator draws. 66 | 67 | ![](images/lsun_bedrooms_window_drop_test.png) 68 | 69 | # Google image search from generations 70 | 71 | ![](images/googsearch_dcgan.png) 72 | 73 | 74 | # Arithmetic on faces 75 | 76 | ![](images/faces_arithmetic_collage.png) 77 | 78 | # Rotations are linear in latent space 79 | 80 | ![](images/turn_vector.png) 81 | 82 | # More faces 83 | 84 | ![](images/faces_128_filter_samples.png) 85 | 86 | # Album covers 87 | 88 | ![](images/albums_128px.png) 89 | 90 | # Imagenet generations 91 | 92 | ![](images/50.png) 93 | -------------------------------------------------------------------------------- /faces/README.md: -------------------------------------------------------------------------------- 1 | Modify data_dir in lib/config.py to point to directory with faces hdf5.
2 | 3 | *Currently this data file is not released due to size/data restrictions.* 4 | 5 | Run train_uncond_dcgan.py to train face model from paper. It will create a few folders and save training info, model parameters, and samples periodically. Should be ~ 12 hours/overnight. 6 | 7 | Libs you'll need installed/configured to run it: 8 | - theano 9 | - cudnn 10 | - fuel/h5py 11 | - sklearn 12 | - numpy 13 | - scipy 14 | - matplotlib 15 | - tqdm -------------------------------------------------------------------------------- /faces/load.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | 4 | import os 5 | from fuel.datasets.hdf5 import H5PYDataset 6 | from fuel.schemes import ShuffledScheme, SequentialScheme 7 | from fuel.streams import DataStream 8 | 9 | from lib.config import data_dir 10 | 11 | def faces(ntrain=None, nval=None, ntest=None, batch_size=128): 12 | path = os.path.join(data_dir, 'faces_364293_128px.hdf5') 13 | tr_data = H5PYDataset(path, which_sets=('train',)) 14 | te_data = H5PYDataset(path, which_sets=('test',)) 15 | 16 | if ntrain is None: 17 | ntrain = tr_data.num_examples 18 | if ntest is None: 19 | ntest = te_data.num_examples 20 | if nval is None: 21 | nval = te_data.num_examples 22 | 23 | tr_scheme = ShuffledScheme(examples=ntrain, batch_size=batch_size) 24 | tr_stream = DataStream(tr_data, iteration_scheme=tr_scheme) 25 | 26 | te_scheme = SequentialScheme(examples=ntest, batch_size=batch_size) 27 | te_stream = DataStream(te_data, iteration_scheme=te_scheme) 28 | 29 | val_scheme = SequentialScheme(examples=nval, batch_size=batch_size) 30 | val_stream = DataStream(tr_data, iteration_scheme=val_scheme) 31 | return tr_data, te_data, tr_stream, val_stream, te_stream -------------------------------------------------------------------------------- /faces/train_uncond_dcgan.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 
sys.path.append('..')

import os
import json
from time import time
import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt
from sklearn.externals import joblib

import theano
import theano.tensor as T
from theano.sandbox.cuda.dnn import dnn_conv

from lib import activations
from lib import updates
from lib import inits
from lib.vis import color_grid_vis
from lib.rng import py_rng, np_rng
from lib.ops import batchnorm, conv_cond_concat, deconv, dropout, l2normalize
from lib.metrics import nnc_score, nnd_score
from lib.theano_utils import floatX, sharedX
from lib.data_utils import OneHot, shuffle, iter_data, center_crop, patch

from load import faces

# Unconditional DCGAN training on the faces dataset.
# Builds a generator (dense projection + 4 deconvs) and a discriminator
# (4 convs + dense), alternates their Adam updates, and once per epoch logs
# nearest-neighbour distance scores, writes a sample grid, and checkpoints the
# parameters on selected epochs.


def transform(X):
    """Center-crop images to npx x npx and rescale them to NCHW floats in [-1, 1].

    The transpose(0, 3, 1, 2) means X is expected as NHWC. The /127.5 - 1
    rescale presumably assumes pixel values in [0, 255] (TODO confirm against
    the HDF5 file).
    """
    X = [center_crop(x, npx) for x in X]
    return floatX(X).transpose(0, 3, 1, 2)/127.5 - 1.


def inverse_transform(X):
    """Undo `transform` for display: reshape to (N, nc, npx, npx), go to NHWC, map [-1, 1] -> [0, 1]."""
    X = (X.reshape(-1, nc, npx, npx).transpose(0, 2, 3, 1)+1.)/2.
    return X

k = 1             # # of discrim updates for each gen update
l2 = 1e-5         # l2 weight decay
nvis = 196        # # of samples to visualize during training
b1 = 0.5          # momentum term of adam
nc = 3            # # of channels in image
nbatch = 128      # # of examples in batch
npx = 64          # # of pixels width/height of images
nz = 100          # # of dim for Z
ngf = 128         # # of gen filters in first conv layer
ndf = 128         # # of discrim filters in first conv layer
nx = npx*npx*nc   # # of dimensions in X
niter = 25        # # of iter at starting learning rate
niter_decay = 0   # # of iter to linearly decay learning rate to zero
lr = 0.0002       # initial learning rate for adam
ntrain = 350000   # # of examples to train on

tr_data, te_data, tr_stream, val_stream, te_stream = faces(ntrain=ntrain)

# The first 10k training images serve both as the real-image visualisation
# and as the reference set for the nearest-neighbour distance scores.
tr_handle = tr_data.open()
vaX, = tr_data.get_data(tr_handle, slice(0, 10000))
vaX = transform(vaX)

desc = 'uncond_dcgan'
model_dir = 'models/%s'%desc
samples_dir = 'samples/%s'%desc
for out_dir in ('logs/', model_dir, samples_dir):
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

relu = activations.Rectify()
sigmoid = activations.Sigmoid()
lrelu = activations.LeakyRectify()
tanh = activations.Tanh()
bce = T.nnet.binary_crossentropy

gifn = inits.Normal(scale=0.02)
difn = inits.Normal(scale=0.02)
gain_ifn = inits.Normal(loc=1., scale=0.02)
bias_ifn = inits.Constant(c=0.)

# Generator parameters: dense projection to (ngf*8, 4, 4), then 5x5 deconvs
# halving the filter count down to nc output channels.
gw = gifn((nz, ngf*8*4*4), 'gw')
gg = gain_ifn((ngf*8*4*4), 'gg')
gb = bias_ifn((ngf*8*4*4), 'gb')
gw2 = gifn((ngf*8, ngf*4, 5, 5), 'gw2')
gg2 = gain_ifn((ngf*4), 'gg2')
gb2 = bias_ifn((ngf*4), 'gb2')
gw3 = gifn((ngf*4, ngf*2, 5, 5), 'gw3')
gg3 = gain_ifn((ngf*2), 'gg3')
gb3 = bias_ifn((ngf*2), 'gb3')
gw4 = gifn((ngf*2, ngf, 5, 5), 'gw4')
gg4 = gain_ifn((ngf), 'gg4')
gb4 = bias_ifn((ngf), 'gb4')
gwx = gifn((ngf, nc, 5, 5), 'gwx')

# Discriminator parameters: 5x5 convs doubling the filter count, then a dense
# layer from the flattened (ndf*8, 4, 4) features to one logit.
dw = difn((ndf, nc, 5, 5), 'dw')
dw2 = difn((ndf*2, ndf, 5, 5), 'dw2')
dg2 = gain_ifn((ndf*2), 'dg2')
db2 = bias_ifn((ndf*2), 'db2')
dw3 = difn((ndf*4, ndf*2, 5, 5), 'dw3')
dg3 = gain_ifn((ndf*4), 'dg3')
db3 = bias_ifn((ndf*4), 'db3')
dw4 = difn((ndf*8, ndf*4, 5, 5), 'dw4')
dg4 = gain_ifn((ndf*8), 'dg4')
db4 = bias_ifn((ndf*8), 'db4')
dwy = difn((ndf*8*4*4, 1), 'dwy')

gen_params = [gw, gg, gb, gw2, gg2, gb2, gw3, gg3, gb3, gw4, gg4, gb4, gwx]
discrim_params = [dw, dw2, dg2, db2, dw3, dg3, db3, dw4, dg4, db4, dwy]

def gen(Z, w, g, b, w2, g2, b2, w3, g3, b3, w4, g4, b4, wx):
    """Generator: project Z, reshape to (N, ngf*8, 4, 4), apply four stride-2 deconvs; tanh output."""
    h = relu(batchnorm(T.dot(Z, w), g=g, b=b))
    h = h.reshape((h.shape[0], ngf*8, 4, 4))
    h2 = relu(batchnorm(deconv(h, w2, subsample=(2, 2), border_mode=(2, 2)), g=g2, b=b2))
    h3 = relu(batchnorm(deconv(h2, w3, subsample=(2, 2), border_mode=(2, 2)), g=g3, b=b3))
    h4 = relu(batchnorm(deconv(h3, w4, subsample=(2, 2), border_mode=(2, 2)), g=g4, b=b4))
    x = tanh(deconv(h4, wx, subsample=(2, 2), border_mode=(2, 2)))
    return x

def discrim(X, w, w2, g2, b2, w3, g3, b3, w4, g4, b4, wy):
    """Discriminator: four stride-2 convs with leaky ReLU (no batchnorm on the first), sigmoid output."""
    h = lrelu(dnn_conv(X, w, subsample=(2, 2), border_mode=(2, 2)))
    h2 = lrelu(batchnorm(dnn_conv(h, w2, subsample=(2, 2), border_mode=(2, 2)), g=g2, b=b2))
    h3 = lrelu(batchnorm(dnn_conv(h2, w3, subsample=(2, 2), border_mode=(2, 2)), g=g3, b=b3))
    h4 = lrelu(batchnorm(dnn_conv(h3, w4, subsample=(2, 2), border_mode=(2, 2)), g=g4, b=b4))
    h4 = T.flatten(h4, 2)
    y = sigmoid(T.dot(h4, wy))
    return y

X = T.tensor4()
Z = T.matrix()

gX = gen(Z, *gen_params)

p_real = discrim(X, *discrim_params)
p_gen = discrim(gX, *discrim_params)

# Standard GAN losses: D labels real as 1 and generated as 0; G tries to make
# D label its samples as 1.
d_cost_real = bce(p_real, T.ones(p_real.shape)).mean()
d_cost_gen = bce(p_gen, T.zeros(p_gen.shape)).mean()
g_cost_d = bce(p_gen, T.ones(p_gen.shape)).mean()

d_cost = d_cost_real + d_cost_gen
g_cost = g_cost_d

cost = [g_cost, d_cost, g_cost_d, d_cost_real, d_cost_gen]

# A shared learning rate lets the epoch loop anneal it without recompiling.
lrt = sharedX(lr)
d_updater = updates.Adam(lr=lrt, b1=b1, regularizer=updates.Regularizer(l2=l2))
g_updater = updates.Adam(lr=lrt, b1=b1, regularizer=updates.Regularizer(l2=l2))
d_updates = d_updater(discrim_params, d_cost)
g_updates = g_updater(gen_params, g_cost)

print('COMPILING')
t = time()
_train_g = theano.function([X, Z], cost, updates=g_updates)
_train_d = theano.function([X, Z], cost, updates=d_updates)
_gen = theano.function([Z], gX)
print('%.2f seconds to compile theano functions'%(time()-t))

# random.sample needs a real sequence; numpy arrays are rejected by newer
# Pythons (3.11+). range() works on both Python 2 and 3 and yields the same
# indices for the same RNG state.
vis_idxs = py_rng.sample(range(len(vaX)), nvis)
vaX_vis = inverse_transform(vaX[vis_idxs])
color_grid_vis(vaX_vis, (14, 14), 'samples/%s_etl_test.png'%desc)

# Fixed latent codes so the per-epoch sample grids are comparable.
sample_zmb = floatX(np_rng.uniform(-1., 1., size=(nvis, nz)))

def gen_samples(n, nbatch=128):
    """Draw n (n > 0) generator samples with z ~ U(-1, 1), in minibatches of nbatch.

    Returns the concatenated generator outputs, n rows in total.
    """
    samples = []
    n_gen = 0
    for i in range(n//nbatch):
        zmb = floatX(np_rng.uniform(-1., 1., size=(nbatch, nz)))
        xmb = _gen(zmb)
        samples.append(xmb)
        n_gen += len(xmb)
    n_left = n-n_gen
    # Only run a remainder batch when there is one, so the compiled graph is
    # never fed an empty (0, nz) input.
    if n_left > 0:
        zmb = floatX(np_rng.uniform(-1., 1., size=(n_left, nz)))
        samples.append(_gen(zmb))
    return np.concatenate(samples, axis=0)

# One JSON object per line; text mode because json.dumps returns str.
f_log = open('logs/%s.ndjson'%desc, 'w')
log_fields = [
    'n_epochs',
    'n_updates',
    'n_examples',
    'n_seconds',
    '1k_va_nnd',
    '10k_va_nnd',
    '100k_va_nnd',
    'g_cost',
    'd_cost',
]

# Flatten the reference images for the nearest-neighbour distance metric.
vaX = vaX.reshape(len(vaX), -1)

print(desc.upper())
n_updates = 0
n_epochs = 0
n_examples = 0
t = time()
# niter epochs at the initial learning rate, then niter_decay epochs during
# which it is linearly annealed to zero.
for epoch in range(niter + niter_decay):
    for imb, in tqdm(tr_stream.get_epoch_iterator(), total=ntrain//nbatch):
        imb = transform(imb)
        zmb = floatX(np_rng.uniform(-1., 1., size=(len(imb), nz)))
        # Every (k+1)-th minibatch updates G; the other k update D.
        if n_updates % (k+1) == 0:
            cost = _train_g(imb, zmb)
        else:
            cost = _train_d(imb, zmb)
        n_updates += 1
        n_examples += len(imb)
    g_cost = float(cost[0])
    d_cost = float(cost[1])
    gX = gen_samples(100000)
    gX = gX.reshape(len(gX), -1)
    va_nnd_1k = nnd_score(gX[:1000], vaX, metric='euclidean')
    va_nnd_10k = nnd_score(gX[:10000], vaX, metric='euclidean')
    va_nnd_100k = nnd_score(gX[:100000], vaX, metric='euclidean')
    log = [n_epochs, n_updates, n_examples, time()-t, va_nnd_1k, va_nnd_10k, va_nnd_100k, g_cost, d_cost]
    print('%.0f %.2f %.2f %.2f %.4f %.4f'%(epoch, va_nnd_1k, va_nnd_10k, va_nnd_100k, g_cost, d_cost))
    f_log.write(json.dumps(dict(zip(log_fields, log)))+'\n')
    f_log.flush()

    samples = np.asarray(_gen(sample_zmb))
    color_grid_vis(inverse_transform(samples), (14, 14), 'samples/%s/%d.png'%(desc, n_epochs))
    n_epochs += 1
    # Guard niter_decay > 0: with no decay phase there is nothing to anneal
    # and lr/niter_decay would divide by zero.
    if niter_decay > 0 and n_epochs > niter:
        lrt.set_value(floatX(lrt.get_value() - lr/niter_decay))
    if n_epochs in [1, 2, 3, 4, 5, 10, 15, 20, 25]:
        joblib.dump([p.get_value() for p in gen_params], 'models/%s/%d_gen_params.jl'%(desc, n_epochs))
        joblib.dump([p.get_value() for p in discrim_params], 'models/%s/%d_discrim_params.jl'%(desc, n_epochs))
f_log.close()
-------------------------------------------------------------------------------- /imagenet/load_pretrained.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | 4 | import numpy as np 5 | import theano 6 | import theano.tensor as T 7 | from theano.sandbox.cuda.dnn import dnn_conv 8 | 9 | from lib import costs 10 | from lib import inits 11 | from lib import updates 12 | from lib import activations 13 | from lib.vis import color_grid_vis 14 | from lib.rng import py_rng, np_rng 15 | from lib.ops import batchnorm, conv_cond_concat, deconv, dropout, l2normalize 16 | from lib.metrics import nnc_score, nnd_score 17 | from lib.theano_utils import floatX, sharedX, intX 18 | from lib.data_utils import OneHot, shuffle, iter_data, center_crop, patch 19 | 20 | from sklearn.externals import joblib 21 | 22 | """ 23 | This example loads the 32x32 imagenet model used in the paper, 24 | generates 400 random samples, and sorts them according to the 25 | discriminator's probability of being real and renders them to 26 | the file samples.png 27 | """ 28 | 29 | nz = 256 30 | nc = 3 31 | npx = 32 32 | ngf = 128 33 | ndf = 128 34 | 35 | relu = activations.Rectify() 36 | sigmoid = activations.Sigmoid() 37 | lrelu = activations.LeakyRectify() 38 | tanh = activations.Tanh() 39 | 40 | model_path = '../models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/' 41 | gen_params = [sharedX(p) for p in joblib.load(model_path+'30_gen_params.jl')] 42 | discrim_params = [sharedX(p) for p in joblib.load(model_path+'30_discrim_params.jl')] 43 | 44 | def gen(Z, w, g, b, w2, g2, b2, w3, g3, b3, w4, g4, b4, w5, g5, b5, w6, g6, b6, wx): 45 | h = relu(batchnorm(T.dot(Z, w), g=g, b=b)) 46 | h = h.reshape((h.shape[0], ngf*4, 4, 4)) 47 | h2 = relu(batchnorm(deconv(h, w2, subsample=(2, 2), border_mode=(1, 1)), g=g2, b=b2)) 48 | h3 = relu(batchnorm(deconv(h2, w3, subsample=(1, 1), border_mode=(1, 1)), g=g3, b=b3)) 49 | h4 = 
relu(batchnorm(deconv(h3, w4, subsample=(2, 2), border_mode=(1, 1)), g=g4, b=b4)) 50 | h5 = relu(batchnorm(deconv(h4, w5, subsample=(1, 1), border_mode=(1, 1)), g=g5, b=b5)) 51 | h6 = relu(batchnorm(deconv(h5, w6, subsample=(2, 2), border_mode=(1, 1)), g=g6, b=b6)) 52 | x = tanh(deconv(h6, wx, subsample=(1, 1), border_mode=(1, 1))) 53 | return x 54 | 55 | def discrim(X, w, w2, g2, b2, w3, g3, b3, w4, g4, b4, w5, g5, b5, w6, g6, b6, wy): 56 | h = lrelu(dnn_conv(X, w, subsample=(1, 1), border_mode=(1, 1))) 57 | h2 = lrelu(batchnorm(dnn_conv(h, w2, subsample=(2, 2), border_mode=(1, 1)), g=g2, b=b2)) 58 | h3 = lrelu(batchnorm(dnn_conv(h2, w3, subsample=(1, 1), border_mode=(1, 1)), g=g3, b=b3)) 59 | h4 = lrelu(batchnorm(dnn_conv(h3, w4, subsample=(2, 2), border_mode=(1, 1)), g=g4, b=b4)) 60 | h5 = lrelu(batchnorm(dnn_conv(h4, w5, subsample=(1, 1), border_mode=(1, 1)), g=g5, b=b5)) 61 | h6 = lrelu(batchnorm(dnn_conv(h5, w6, subsample=(2, 2), border_mode=(1, 1)), g=g6, b=b6)) 62 | h6 = T.flatten(h6, 2) 63 | y = sigmoid(T.dot(h6, wy)) 64 | return y 65 | 66 | def inverse_transform(X): 67 | X = (X.reshape(-1, nc, npx, npx).transpose(0, 2, 3, 1)+1.)/2. 
68 | return X 69 | 70 | Z = T.matrix() 71 | X = T.tensor4() 72 | 73 | gX = gen(Z, *gen_params) 74 | dX = discrim(X, *discrim_params) 75 | 76 | _gen = theano.function([Z], gX) 77 | _discrim = theano.function([X], dX) 78 | 79 | sample_zmb = floatX(np_rng.uniform(-1., 1., size=(400, 256))) 80 | samples = _gen(sample_zmb) 81 | scores = _discrim(samples) 82 | sort = np.argsort(scores.flatten())[::-1] 83 | samples = samples[sort] 84 | color_grid_vis(inverse_transform(samples), (20, 20), 'samples.png') 85 | -------------------------------------------------------------------------------- /images/50.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/images/50.png -------------------------------------------------------------------------------- /images/50_cropped.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/images/50_cropped.png -------------------------------------------------------------------------------- /images/albums_128px.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/images/albums_128px.png -------------------------------------------------------------------------------- /images/faces_128_filter_samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/images/faces_128_filter_samples.png -------------------------------------------------------------------------------- /images/faces_arithmetic.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/images/faces_arithmetic.png -------------------------------------------------------------------------------- /images/faces_arithmetic_collage.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/images/faces_arithmetic_collage.png -------------------------------------------------------------------------------- /images/faces_arithmetic_collage_v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/images/faces_arithmetic_collage_v2.png -------------------------------------------------------------------------------- /images/googsearch_dcgan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/images/googsearch_dcgan.png -------------------------------------------------------------------------------- /images/googsearch_lapgan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/images/googsearch_lapgan.png -------------------------------------------------------------------------------- /images/interp_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/images/interp_comparison.png -------------------------------------------------------------------------------- /images/lsun_bedrooms_five_epoch_samples.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/images/lsun_bedrooms_five_epoch_samples.png -------------------------------------------------------------------------------- /images/lsun_bedrooms_five_epochs_interps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/images/lsun_bedrooms_five_epochs_interps.png -------------------------------------------------------------------------------- /images/lsun_bedrooms_generator.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/images/lsun_bedrooms_generator.png -------------------------------------------------------------------------------- /images/lsun_bedrooms_one_epoch_interps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/images/lsun_bedrooms_one_epoch_interps.png -------------------------------------------------------------------------------- /images/lsun_bedrooms_one_epoch_samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/images/lsun_bedrooms_one_epoch_samples.png -------------------------------------------------------------------------------- /images/lsun_bedrooms_real.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/images/lsun_bedrooms_real.png -------------------------------------------------------------------------------- /images/lsun_bedrooms_window_drop_test.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/images/lsun_bedrooms_window_drop_test.png -------------------------------------------------------------------------------- /images/lsun_five_epochs_guided.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/images/lsun_five_epochs_guided.png -------------------------------------------------------------------------------- /images/lsun_five_epochs_max_act_l4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/images/lsun_five_epochs_max_act_l4.png -------------------------------------------------------------------------------- /images/mnist_collage.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/images/mnist_collage.png -------------------------------------------------------------------------------- /images/mnist_cond_conv_gan_samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/images/mnist_cond_conv_gan_samples.png -------------------------------------------------------------------------------- /images/mnist_cond_fc_gan_samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/images/mnist_cond_fc_gan_samples.png -------------------------------------------------------------------------------- /images/mnist_real.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/images/mnist_real.png -------------------------------------------------------------------------------- /images/random_weights.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/images/random_weights.png -------------------------------------------------------------------------------- /images/smile_arithmetic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/images/smile_arithmetic.png -------------------------------------------------------------------------------- /images/turn_vector.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/images/turn_vector.png -------------------------------------------------------------------------------- /images/vecarithmetictwo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/images/vecarithmetictwo.png -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/lib/__init__.py -------------------------------------------------------------------------------- /lib/activations.py: -------------------------------------------------------------------------------- 1 | import theano 2 | import theano.tensor as T 3 | 4 | class Softmax(object): 5 | 6 | def 
__init__(self): 7 | pass 8 | 9 | def __call__(self, x): 10 | e_x = T.exp(x - x.max(axis=1).dimshuffle(0, 'x')) 11 | return e_x / e_x.sum(axis=1).dimshuffle(0, 'x') 12 | 13 | class ConvSoftmax(object): 14 | 15 | def __init__(self): 16 | pass 17 | 18 | def __call__(self, x): 19 | e_x = T.exp(x - x.max(axis=1, keepdims=True)) 20 | return e_x / e_x.sum(axis=1, keepdims=True) 21 | 22 | class Maxout(object): 23 | 24 | def __init__(self, n_pool=2): 25 | self.n_pool = n_pool 26 | 27 | def __call__(self, x): 28 | if x.ndim == 2: 29 | x = T.max([x[:, n::self.n_pool] for n in range(self.n_pool)], axis=0) 30 | elif x.ndim == 4: 31 | x = T.max([x[:, n::self.n_pool, :, :] for n in range(self.n_pool)], axis=0) 32 | else: 33 | raise NotImplementedError 34 | return x 35 | 36 | class Rectify(object): 37 | 38 | def __init__(self): 39 | pass 40 | 41 | def __call__(self, x): 42 | return (x + abs(x)) / 2.0 43 | 44 | class ClippedRectify(object): 45 | 46 | def __init__(self, clip=10.): 47 | self.clip = clip 48 | 49 | def __call__(self, x): 50 | return T.clip((x + abs(x)) / 2.0, 0., self.clip) 51 | 52 | class LeakyRectify(object): 53 | 54 | def __init__(self, leak=0.2): 55 | self.leak = leak 56 | 57 | def __call__(self, x): 58 | f1 = 0.5 * (1 + self.leak) 59 | f2 = 0.5 * (1 - self.leak) 60 | return f1 * x + f2 * abs(x) 61 | 62 | class Prelu(object): 63 | 64 | def __init__(self): 65 | pass 66 | 67 | def __call__(self, x, leak): 68 | if x.ndim == 4: 69 | leak = leak.dimshuffle('x', 0, 'x', 'x') 70 | f1 = 0.5 * (1 + leak) 71 | f2 = 0.5 * (1 - leak) 72 | return f1 * x + f2 * abs(x) 73 | 74 | class Tanh(object): 75 | 76 | def __init__(self): 77 | pass 78 | 79 | def __call__(self, x): 80 | return T.tanh(x) 81 | 82 | class Sigmoid(object): 83 | 84 | def __init__(self): 85 | pass 86 | 87 | def __call__(self, x): 88 | return T.nnet.sigmoid(x) 89 | 90 | class Linear(object): 91 | 92 | def __init__(self): 93 | pass 94 | 95 | def __call__(self, x): 96 | return x 97 | 98 | class HardSigmoid(object): 
99 | 100 | def __init__(self): 101 | pass 102 | 103 | def __call__(self, X): 104 | return T.clip(X + 0.5, 0., 1.) 105 | 106 | class TRec(object): 107 | 108 | def __init__(self, t=1): 109 | self.t = t 110 | 111 | def __call__(self, X): 112 | return X*(X > self.t) 113 | 114 | class HardTanh(object): 115 | 116 | def __init__(self): 117 | pass 118 | 119 | def __call__(self, X): 120 | return T.clip(X, -1., 1.) -------------------------------------------------------------------------------- /lib/config.py: -------------------------------------------------------------------------------- 1 | data_dir = '/home/indico/datasets/iclr2016' -------------------------------------------------------------------------------- /lib/costs.py: -------------------------------------------------------------------------------- 1 | import theano 2 | import theano.tensor as T 3 | 4 | def CategoricalCrossEntropy(y_true, y_pred): 5 | return T.nnet.categorical_crossentropy(y_pred, y_true).mean() 6 | 7 | def BinaryCrossEntropy(y_true, y_pred): 8 | return T.nnet.binary_crossentropy(y_pred, y_true).mean() 9 | 10 | def MeanSquaredError(y_true, y_pred): 11 | return T.sqr(y_pred - y_true).mean() 12 | 13 | def MeanAbsoluteError(y_true, y_pred): 14 | return T.abs_(y_pred - y_true).mean() 15 | 16 | def SquaredHinge(y_true, y_pred): 17 | return T.sqr(T.maximum(1. - y_true * y_pred, 0.)).mean() 18 | 19 | def Hinge(y_true, y_pred): 20 | return T.maximum(1. - y_true * y_pred, 0.).mean() 21 | 22 | cce = CCE = CategoricalCrossEntropy 23 | bce = BCE = BinaryCrossEntropy 24 | mse = MSE = MeanSquaredError 25 | mae = MAE = MeanAbsoluteError 26 | -------------------------------------------------------------------------------- /lib/cv2_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | def min_resize(x, size, interpolation=cv2.INTER_LINEAR): 4 | """ 5 | Resize an image so that it is size along the minimum spatial dimension. 
6 | """ 7 | w, h = map(float, x.shape[:2]) 8 | if min([w, h]) != size: 9 | if w <= h: 10 | x = cv2.resize(x, (int(round((h/w)*size)), int(size)), interpolation=interpolation) 11 | else: 12 | x = cv2.resize(x, (int(size), int(round((w/h)*size))), interpolation=interpolation) 13 | return x -------------------------------------------------------------------------------- /lib/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn import utils as skutils 3 | 4 | from rng import np_rng, py_rng 5 | 6 | def center_crop(x, ph, pw=None): 7 | if pw is None: 8 | pw = ph 9 | h, w = x.shape[:2] 10 | j = int(round((h - ph)/2.)) 11 | i = int(round((w - pw)/2.)) 12 | return x[j:j+ph, i:i+pw] 13 | 14 | def patch(x, ph, pw=None): 15 | if pw is None: 16 | pw = ph 17 | h, w = x.shape[:2] 18 | j = py_rng.randint(0, h-ph) 19 | i = py_rng.randint(0, w-pw) 20 | x = x[j:j+ph, i:i+pw] 21 | return x 22 | 23 | def list_shuffle(*data): 24 | idxs = np_rng.permutation(np.arange(len(data[0]))) 25 | if len(data) == 1: 26 | return [data[0][idx] for idx in idxs] 27 | else: 28 | return [[d[idx] for idx in idxs] for d in data] 29 | 30 | def shuffle(*arrays, **options): 31 | if isinstance(arrays[0][0], basestring): 32 | return list_shuffle(*arrays) 33 | else: 34 | return skutils.shuffle(*arrays, random_state=np_rng) 35 | 36 | def OneHot(X, n=None, negative_class=0.): 37 | X = np.asarray(X).flatten() 38 | if n is None: 39 | n = np.max(X) + 1 40 | Xoh = np.ones((len(X), n)) * negative_class 41 | Xoh[np.arange(len(X)), X] = 1. 
42 | return Xoh 43 | 44 | def iter_data(*data, **kwargs): 45 | size = kwargs.get('size', 128) 46 | try: 47 | n = len(data[0]) 48 | except: 49 | n = data[0].shape[0] 50 | batches = n / size 51 | if n % size != 0: 52 | batches += 1 53 | 54 | for b in range(batches): 55 | start = b * size 56 | end = (b + 1) * size 57 | if end > n: 58 | end = n 59 | if len(data) == 1: 60 | yield data[0][start:end] 61 | else: 62 | yield tuple([d[start:end] for d in data]) -------------------------------------------------------------------------------- /lib/inits.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | import theano 5 | import theano.tensor as T 6 | 7 | from theano_utils import sharedX, floatX, intX 8 | from rng import np_rng 9 | 10 | class Uniform(object): 11 | def __init__(self, scale=0.05): 12 | self.scale = 0.05 13 | 14 | def __call__(self, shape, name=None): 15 | return sharedX(np_rng.uniform(low=-self.scale, high=self.scale, size=shape), name=name) 16 | 17 | class Normal(object): 18 | def __init__(self, loc=0., scale=0.05): 19 | self.scale = scale 20 | self.loc = loc 21 | 22 | def __call__(self, shape, name=None): 23 | return sharedX(np_rng.normal(loc=self.loc, scale=self.scale, size=shape), name=name) 24 | 25 | class Orthogonal(object): 26 | """ benanne lasagne ortho init (faster than qr approach)""" 27 | def __init__(self, scale=1.1): 28 | self.scale = scale 29 | 30 | def __call__(self, shape, name=None): 31 | print 'called orthogonal init with shape', shape 32 | flat_shape = (shape[0], np.prod(shape[1:])) 33 | a = np_rng.normal(0.0, 1.0, flat_shape) 34 | u, _, v = np.linalg.svd(a, full_matrices=False) 35 | q = u if u.shape == flat_shape else v # pick the one with the correct shape 36 | q = q.reshape(shape) 37 | return sharedX(self.scale * q[:shape[0], :shape[1]], name=name) 38 | 39 | class Frob(object): 40 | 41 | def __init__(self): 42 | pass 43 | 44 | def __call__(self, shape, name=None): 
45 | r = np_rng.normal(loc=0, scale=0.01, size=shape) 46 | r = r/np.sqrt(np.sum(r**2))*np.sqrt(shape[1]) 47 | return sharedX(r, name=name) 48 | 49 | class Constant(object): 50 | 51 | def __init__(self, c=0.): 52 | self.c = c 53 | 54 | def __call__(self, shape, name=None): 55 | return sharedX(np.ones(shape) * self.c, name=name) 56 | 57 | class ConvIdentity(object): 58 | 59 | def __init__(self, scale=1.): 60 | self.scale = scale 61 | 62 | def __call__(self, shape, name=None): 63 | w = np.zeros(shape) 64 | ycenter = shape[2]//2 65 | xcenter = shape[3]//2 66 | 67 | if shape[0] == shape[1]: 68 | o_idxs = np.arange(shape[0]) 69 | i_idxs = np.arange(shape[1]) 70 | elif shape[1] < shape[0]: 71 | o_idxs = np.arange(shape[0]) 72 | i_idxs = np.random.permutation(np.tile(np.arange(shape[1]), shape[0]/shape[1]+1))[:shape[0]] 73 | w[o_idxs, i_idxs, ycenter, xcenter] = self.scale 74 | return sharedX(w, name=name) 75 | 76 | class Identity(object): 77 | 78 | def __init__(self, scale=0.25): 79 | self.scale = scale 80 | 81 | def __call__(self, shape, name=None): 82 | if shape[0] != shape[1]: 83 | w = np.zeros(shape) 84 | o_idxs = np.arange(shape[0]) 85 | i_idxs = np.random.permutation(np.tile(np.arange(shape[1]), shape[0]/shape[1]+1))[:shape[0]] 86 | w[o_idxs, i_idxs] = self.scale 87 | else: 88 | w = np.identity(shape[0]) * self.scale 89 | return sharedX(w, name=name) 90 | 91 | class ReluInit(object): 92 | 93 | def __init__(self): 94 | pass 95 | 96 | def __call__(self, shape, name=None): 97 | if len(shape) == 2: 98 | scale = np.sqrt(2./shape[0]) 99 | elif len(shape) == 4: 100 | scale = np.sqrt(2./np.prod(shape[1:])) 101 | else: 102 | raise NotImplementedError 103 | return sharedX(np_rng.normal(size=shape, scale=scale), name=name) -------------------------------------------------------------------------------- /lib/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import theano 4 | import 
theano.tensor as T 5 | import gc 6 | import time 7 | 8 | from theano_utils import floatX 9 | from ops import euclidean, cosine 10 | 11 | from sklearn import metrics 12 | from sklearn.linear_model import LogisticRegression as LR 13 | 14 | def cv_reg_lr(trX, trY, vaX, vaY, Cs=[0.01, 0.05, 0.1, 0.5, 1., 5., 10., 50., 100.]): 15 | tr_accs = [] 16 | va_accs = [] 17 | models = [] 18 | for C in Cs: 19 | model = LR(C=C) 20 | model.fit(trX, trY) 21 | tr_pred = model.predict(trX) 22 | va_pred = model.predict(vaX) 23 | tr_acc = metrics.accuracy_score(trY, tr_pred) 24 | va_acc = metrics.accuracy_score(vaY, va_pred) 25 | print '%.4f %.4f %.4f'%(C, tr_acc, va_acc) 26 | tr_accs.append(tr_acc) 27 | va_accs.append(va_acc) 28 | models.append(model) 29 | best = np.argmax(va_accs) 30 | print 'best model C: %.4f tr_acc: %.4f va_acc: %.4f'%(Cs[best], tr_accs[best], va_accs[best]) 31 | return models[best] 32 | 33 | def gpu_nnc_predict(trX, trY, teX, metric='cosine', batch_size=4096): 34 | if metric == 'cosine': 35 | metric_fn = cosine_dist 36 | else: 37 | metric_fn = euclid_dist 38 | idxs = [] 39 | for i in range(0, len(teX), batch_size): 40 | mb_dists = [] 41 | mb_idxs = [] 42 | for j in range(0, len(trX), batch_size): 43 | dist = metric_fn(floatX(teX[i:i+batch_size]), floatX(trX[j:j+batch_size])) 44 | if metric == 'cosine': 45 | mb_dists.append(np.max(dist, axis=1)) 46 | mb_idxs.append(j+np.argmax(dist, axis=1)) 47 | else: 48 | mb_dists.append(np.min(dist, axis=1)) 49 | mb_idxs.append(j+np.argmin(dist, axis=1)) 50 | mb_idxs = np.asarray(mb_idxs) 51 | mb_dists = np.asarray(mb_dists) 52 | if metric == 'cosine': 53 | i = mb_idxs[np.argmax(mb_dists, axis=0), np.arange(mb_idxs.shape[1])] 54 | else: 55 | i = mb_idxs[np.argmin(mb_dists, axis=0), np.arange(mb_idxs.shape[1])] 56 | idxs.append(i) 57 | idxs = np.concatenate(idxs, axis=0) 58 | nearest = trY[idxs] 59 | return nearest 60 | 61 | def gpu_nnd_score(trX, teX, metric='cosine', batch_size=4096): 62 | if metric == 'cosine': 63 | metric_fn 
= cosine_dist 64 | else: 65 | metric_fn = euclid_dist 66 | dists = [] 67 | for i in range(0, len(teX), batch_size): 68 | mb_dists = [] 69 | for j in range(0, len(trX), batch_size): 70 | dist = metric_fn(floatX(teX[i:i+batch_size]), floatX(trX[j:j+batch_size])) 71 | if metric == 'cosine': 72 | mb_dists.append(np.max(dist, axis=1)) 73 | else: 74 | mb_dists.append(np.min(dist, axis=1)) 75 | mb_dists = np.asarray(mb_dists) 76 | if metric == 'cosine': 77 | d = np.max(mb_dists, axis=0) 78 | else: 79 | d = np.min(mb_dists, axis=0) 80 | dists.append(d) 81 | dists = np.concatenate(dists, axis=0) 82 | return float(np.mean(dists)) 83 | 84 | A = T.matrix() 85 | B = T.matrix() 86 | 87 | ed = euclidean(A, B) 88 | cd = cosine(A, B) 89 | 90 | cosine_dist = theano.function([A, B], cd) 91 | euclid_dist = theano.function([A, B], ed) 92 | 93 | def nnc_score(trX, trY, teX, teY, metric='euclidean'): 94 | pred = gpu_nnc_predict(trX, trY, teX, metric=metric) 95 | acc = metrics.accuracy_score(teY, pred) 96 | return acc*100. 
97 | 98 | def nnd_score(trX, teX, metric='euclidean'): 99 | return gpu_nnd_score(trX, teX, metric=metric) 100 | -------------------------------------------------------------------------------- /lib/ops.py: -------------------------------------------------------------------------------- 1 | import theano 2 | import theano.tensor as T 3 | from theano.sandbox.cuda.basic_ops import (as_cuda_ndarray_variable, 4 | host_from_gpu, 5 | gpu_contiguous, HostFromGpu, 6 | gpu_alloc_empty) 7 | from theano.sandbox.cuda.dnn import GpuDnnConvDesc, GpuDnnConv, GpuDnnConvGradI, dnn_conv, dnn_pool 8 | from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams 9 | 10 | from rng import t_rng 11 | 12 | t_rng = RandomStreams() 13 | 14 | def l2normalize(x, axis=1, e=1e-8, keepdims=True): 15 | return x/l2norm(x, axis=axis, e=e, keepdims=keepdims) 16 | 17 | def l2norm(x, axis=1, e=1e-8, keepdims=True): 18 | return T.sqrt(T.sum(T.sqr(x), axis=axis, keepdims=keepdims) + e) 19 | 20 | def cosine(x, y): 21 | d = T.dot(x, y.T) 22 | d /= l2norm(x).dimshuffle(0, 'x') 23 | d /= l2norm(y).dimshuffle('x', 0) 24 | return d 25 | 26 | def euclidean(x, y, e=1e-8): 27 | xx = T.sqr(T.sqrt((x*x).sum(axis=1) + e)) 28 | yy = T.sqr(T.sqrt((y*y).sum(axis=1) + e)) 29 | dist = T.dot(x, y.T) 30 | dist *= -2 31 | dist += xx.dimshuffle(0, 'x') 32 | dist += yy.dimshuffle('x', 0) 33 | dist = T.sqrt(dist) 34 | return dist 35 | 36 | def dropout(X, p=0.): 37 | """ 38 | dropout using activation scaling to avoid test time weight rescaling 39 | """ 40 | if p > 0: 41 | retain_prob = 1 - p 42 | X *= t_rng.binomial(X.shape, p=retain_prob, dtype=theano.config.floatX) 43 | X /= retain_prob 44 | return X 45 | 46 | def conv_cond_concat(x, y): 47 | """ 48 | concatenate conditioning vector on feature map axis 49 | """ 50 | return T.concatenate([x, y*T.ones((x.shape[0], y.shape[1], x.shape[2], x.shape[3]))], axis=1) 51 | 52 | def batchnorm(X, g=None, b=None, u=None, s=None, a=1., e=1e-8): 53 | """ 54 | batchnorm with support 
for not using scale and shift parameters 55 | as well as inference values (u and s) and partial batchnorm (via a) 56 | will detect and use convolutional or fully connected version 57 | """ 58 | if X.ndim == 4: 59 | if u is not None and s is not None: 60 | b_u = u.dimshuffle('x', 0, 'x', 'x') 61 | b_s = s.dimshuffle('x', 0, 'x', 'x') 62 | else: 63 | b_u = T.mean(X, axis=[0, 2, 3]).dimshuffle('x', 0, 'x', 'x') 64 | b_s = T.mean(T.sqr(X - b_u), axis=[0, 2, 3]).dimshuffle('x', 0, 'x', 'x') 65 | if a != 1: 66 | b_u = (1. - a)*0. + a*b_u 67 | b_s = (1. - a)*1. + a*b_s 68 | X = (X - b_u) / T.sqrt(b_s + e) 69 | if g is not None and b is not None: 70 | X = X*g.dimshuffle('x', 0, 'x', 'x') + b.dimshuffle('x', 0, 'x', 'x') 71 | elif X.ndim == 2: 72 | if u is None and s is None: 73 | u = T.mean(X, axis=0) 74 | s = T.mean(T.sqr(X - u), axis=0) 75 | if a != 1: 76 | u = (1. - a)*0. + a*u 77 | s = (1. - a)*1. + a*s 78 | X = (X - u) / T.sqrt(s + e) 79 | if g is not None and b is not None: 80 | X = X*g + b 81 | else: 82 | raise NotImplementedError 83 | return X 84 | 85 | def deconv(X, w, subsample=(1, 1), border_mode=(0, 0), conv_mode='conv'): 86 | """ 87 | sets up dummy convolutional forward pass and uses its grad as deconv 88 | currently only tested/working with same padding 89 | """ 90 | img = gpu_contiguous(X) 91 | kerns = gpu_contiguous(w) 92 | desc = GpuDnnConvDesc(border_mode=border_mode, subsample=subsample, 93 | conv_mode=conv_mode)(gpu_alloc_empty(img.shape[0], kerns.shape[1], img.shape[2]*subsample[0], img.shape[3]*subsample[1]).shape, kerns.shape) 94 | out = gpu_alloc_empty(img.shape[0], kerns.shape[1], img.shape[2]*subsample[0], img.shape[3]*subsample[1]) 95 | d_img = GpuDnnConvGradI()(kerns, img, out, desc) 96 | return d_img -------------------------------------------------------------------------------- /lib/rng.py: -------------------------------------------------------------------------------- 1 | from numpy.random import RandomState 2 | from theano.sandbox.rng_mrg 
import MRG_RandomStreams as RandomStreams 3 | from random import Random 4 | 5 | seed = 42 6 | 7 | py_rng = Random(seed) 8 | np_rng = RandomState(seed) 9 | t_rng = RandomStreams(seed) 10 | 11 | def set_seed(n): 12 | global seed, py_rng, np_rng, t_rng 13 | 14 | seed = n 15 | py_rng = Random(seed) 16 | np_rng = RandomState(seed) 17 | t_rng = RandomStreams(seed) 18 | -------------------------------------------------------------------------------- /lib/theano_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import theano 3 | 4 | def intX(X): 5 | return np.asarray(X, dtype=np.int32) 6 | 7 | def floatX(X): 8 | return np.asarray(X, dtype=theano.config.floatX) 9 | 10 | def sharedX(X, dtype=theano.config.floatX, name=None): 11 | return theano.shared(np.asarray(X, dtype=dtype), name=name) 12 | 13 | def shared0s(shape, dtype=theano.config.floatX, name=None): 14 | return sharedX(np.zeros(shape), dtype=dtype, name=name) 15 | 16 | def sharedNs(shape, n, dtype=theano.config.floatX, name=None): 17 | return sharedX(np.ones(shape)*n, dtype=dtype, name=name) -------------------------------------------------------------------------------- /lib/updates.py: -------------------------------------------------------------------------------- 1 | import theano 2 | import theano.tensor as T 3 | import numpy as np 4 | 5 | from theano_utils import shared0s, floatX, sharedX 6 | from ops import l2norm 7 | 8 | def clip_norm(g, c, n): 9 | if c > 0: 10 | g = T.switch(T.ge(n, c), g*c/n, g) 11 | return g 12 | 13 | def clip_norms(gs, c): 14 | norm = T.sqrt(sum([T.sum(g**2) for g in gs])) 15 | return [clip_norm(g, c, norm) for g in gs] 16 | 17 | class Regularizer(object): 18 | 19 | def __init__(self, l1=0., l2=0., maxnorm=0., l2norm=False, frobnorm=False): 20 | self.__dict__.update(locals()) 21 | 22 | def max_norm(self, p, maxnorm): 23 | if maxnorm > 0: 24 | norms = T.sqrt(T.sum(T.sqr(p), axis=0)) 25 | desired = T.clip(norms, 0, maxnorm) 
26 | p = p * (desired/ (1e-7 + norms)) 27 | return p 28 | 29 | def l2_norm(self, p): 30 | return p/l2norm(p, axis=0) 31 | 32 | def frob_norm(self, p, nrows): 33 | return (p/T.sqrt(T.sum(T.sqr(p))))*T.sqrt(nrows) 34 | 35 | def gradient_regularize(self, p, g): 36 | g += p * self.l2 37 | g += T.sgn(p) * self.l1 38 | return g 39 | 40 | def weight_regularize(self, p): 41 | p = self.max_norm(p, self.maxnorm) 42 | if self.l2norm: 43 | p = self.l2_norm(p) 44 | if self.frobnorm > 0: 45 | p = self.frob_norm(p, self.frobnorm) 46 | return p 47 | 48 | 49 | class Update(object): 50 | 51 | def __init__(self, regularizer=Regularizer(), clipnorm=0.): 52 | self.__dict__.update(locals()) 53 | 54 | def __call__(self, params, grads): 55 | raise NotImplementedError 56 | 57 | class SGD(Update): 58 | 59 | def __init__(self, lr=0.01, *args, **kwargs): 60 | Update.__init__(self, *args, **kwargs) 61 | self.__dict__.update(locals()) 62 | 63 | def __call__(self, params, cost): 64 | updates = [] 65 | grads = T.grad(cost, params) 66 | grads = clip_norms(grads, self.clipnorm) 67 | for p,g in zip(params,grads): 68 | g = self.regularizer.gradient_regularize(p, g) 69 | updated_p = p - self.lr * g 70 | updated_p = self.regularizer.weight_regularize(updated_p) 71 | updates.append((p, updated_p)) 72 | return updates 73 | 74 | class Momentum(Update): 75 | 76 | def __init__(self, lr=0.01, momentum=0.9, *args, **kwargs): 77 | Update.__init__(self, *args, **kwargs) 78 | self.__dict__.update(locals()) 79 | 80 | def __call__(self, params, cost): 81 | updates = [] 82 | grads = T.grad(cost, params) 83 | grads = clip_norms(grads, self.clipnorm) 84 | for p,g in zip(params,grads): 85 | g = self.regularizer.gradient_regularize(p, g) 86 | m = theano.shared(p.get_value() * 0.) 
87 | v = (self.momentum * m) - (self.lr * g) 88 | updates.append((m, v)) 89 | 90 | updated_p = p + v 91 | updated_p = self.regularizer.weight_regularize(updated_p) 92 | updates.append((p, updated_p)) 93 | return updates 94 | 95 | 96 | class NAG(Update): 97 | 98 | def __init__(self, lr=0.01, momentum=0.9, *args, **kwargs): 99 | Update.__init__(self, *args, **kwargs) 100 | self.__dict__.update(locals()) 101 | 102 | def __call__(self, params, cost): 103 | updates = [] 104 | grads = T.grad(cost, params) 105 | grads = clip_norms(grads, self.clipnorm) 106 | for p, g in zip(params, grads): 107 | g = self.regularizer.gradient_regularize(p, g) 108 | m = theano.shared(p.get_value() * 0.) 109 | v = (self.momentum * m) - (self.lr * g) 110 | 111 | updated_p = p + self.momentum * v - self.lr * g 112 | updated_p = self.regularizer.weight_regularize(updated_p) 113 | updates.append((m,v)) 114 | updates.append((p, updated_p)) 115 | return updates 116 | 117 | 118 | class RMSprop(Update): 119 | 120 | def __init__(self, lr=0.001, rho=0.9, epsilon=1e-6, *args, **kwargs): 121 | Update.__init__(self, *args, **kwargs) 122 | self.__dict__.update(locals()) 123 | 124 | def __call__(self, params, cost): 125 | updates = [] 126 | grads = T.grad(cost, params) 127 | grads = clip_norms(grads, self.clipnorm) 128 | for p,g in zip(params,grads): 129 | g = self.regularizer.gradient_regularize(p, g) 130 | acc = theano.shared(p.get_value() * 0.) 
131 | acc_new = self.rho * acc + (1 - self.rho) * g ** 2 132 | updates.append((acc, acc_new)) 133 | 134 | updated_p = p - self.lr * (g / T.sqrt(acc_new + self.epsilon)) 135 | updated_p = self.regularizer.weight_regularize(updated_p) 136 | updates.append((p, updated_p)) 137 | return updates 138 | 139 | 140 | class Adam(Update): 141 | 142 | def __init__(self, lr=0.001, b1=0.9, b2=0.999, e=1e-8, l=1-1e-8, *args, **kwargs): 143 | Update.__init__(self, *args, **kwargs) 144 | self.__dict__.update(locals()) 145 | 146 | def __call__(self, params, cost): 147 | updates = [] 148 | grads = T.grad(cost, params) 149 | grads = clip_norms(grads, self.clipnorm) 150 | t = theano.shared(floatX(1.)) 151 | b1_t = self.b1*self.l**(t-1) 152 | 153 | for p, g in zip(params, grads): 154 | g = self.regularizer.gradient_regularize(p, g) 155 | m = theano.shared(p.get_value() * 0.) 156 | v = theano.shared(p.get_value() * 0.) 157 | 158 | m_t = b1_t*m + (1 - b1_t)*g 159 | v_t = self.b2*v + (1 - self.b2)*g**2 160 | m_c = m_t / (1-self.b1**t) 161 | v_c = v_t / (1-self.b2**t) 162 | p_t = p - (self.lr * m_c) / (T.sqrt(v_c) + self.e) 163 | p_t = self.regularizer.weight_regularize(p_t) 164 | updates.append((m, m_t)) 165 | updates.append((v, v_t)) 166 | updates.append((p, p_t) ) 167 | updates.append((t, t + 1.)) 168 | return updates 169 | 170 | 171 | class Adagrad(Update): 172 | 173 | def __init__(self, lr=0.01, epsilon=1e-6, *args, **kwargs): 174 | Update.__init__(self, *args, **kwargs) 175 | self.__dict__.update(locals()) 176 | 177 | def __call__(self, params, cost): 178 | updates = [] 179 | grads = T.grad(cost, params) 180 | grads = clip_norms(grads, self.clipnorm) 181 | for p,g in zip(params,grads): 182 | g = self.regularizer.gradient_regularize(p, g) 183 | acc = theano.shared(p.get_value() * 0.) 
184 | acc_t = acc + g ** 2 185 | updates.append((acc, acc_t)) 186 | 187 | p_t = p - (self.lr / T.sqrt(acc_t + self.epsilon)) * g 188 | p_t = self.regularizer.weight_regularize(p_t) 189 | updates.append((p, p_t)) 190 | return updates 191 | 192 | 193 | class Adadelta(Update): 194 | 195 | def __init__(self, lr=0.5, rho=0.95, epsilon=1e-6, *args, **kwargs): 196 | Update.__init__(self, *args, **kwargs) 197 | self.__dict__.update(locals()) 198 | 199 | def __call__(self, params, cost): 200 | updates = [] 201 | grads = T.grad(cost, params) 202 | grads = clip_norms(grads, self.clipnorm) 203 | for p,g in zip(params,grads): 204 | g = self.regularizer.gradient_regularize(p, g) 205 | 206 | acc = theano.shared(p.get_value() * 0.) 207 | acc_delta = theano.shared(p.get_value() * 0.) 208 | acc_new = self.rho * acc + (1 - self.rho) * g ** 2 209 | updates.append((acc,acc_new)) 210 | 211 | update = g * T.sqrt(acc_delta + self.epsilon) / T.sqrt(acc_new + self.epsilon) 212 | updated_p = p - self.lr * update 213 | updated_p = self.regularizer.weight_regularize(updated_p) 214 | updates.append((p, updated_p)) 215 | 216 | acc_delta_new = self.rho * acc_delta + (1 - self.rho) * update ** 2 217 | updates.append((acc_delta,acc_delta_new)) 218 | return updates 219 | 220 | 221 | class NoUpdate(Update): 222 | 223 | def __init__(self, lr=0.01, momentum=0.9, *args, **kwargs): 224 | Update.__init__(self, *args, **kwargs) 225 | self.__dict__.update(locals()) 226 | 227 | def __call__(self, params, cost): 228 | updates = [] 229 | for p in params: 230 | updates.append((p, p)) 231 | return updates 232 | -------------------------------------------------------------------------------- /lib/vis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.misc import imsave 3 | 4 | def grayscale_grid_vis(X, (nh, nw), save_path=None): 5 | h, w = X[0].shape[:2] 6 | img = np.zeros((h*nh, w*nw)) 7 | for n, x in enumerate(X): 8 | j = n/nw 9 | i = n%nw 
10 | img[j*h:j*h+h, i*w:i*w+w] = x 11 | if save_path is not None: 12 | imsave(save_path, img) 13 | return img 14 | 15 | def color_grid_vis(X, (nh, nw), save_path=None): 16 | h, w = X[0].shape[:2] 17 | img = np.zeros((h*nh, w*nw, 3)) 18 | for n, x in enumerate(X): 19 | j = n/nw 20 | i = n%nw 21 | img[j*h:j*h+h, i*w:i*w+w, :] = x 22 | if save_path is not None: 23 | imsave(save_path, img) 24 | return img 25 | 26 | def grayscale_weight_grid_vis(w, (nh, nw), save_path=None): 27 | w = (w+w.min())/(w.max()-w.min()) 28 | return grayscale_grid_vis(w, (nh, nw), save_path=save_path) -------------------------------------------------------------------------------- /mnist/README.md: -------------------------------------------------------------------------------- 1 | Modify data_dir in lib/config.py to point to directory with mnist files. 2 | 3 | Run train_cond_dcgan.py to train mnist model from appendix. It will create a few folders and save training info, model parameters, and samples periodically. Should take ~ an hour to run on a good GPU. 
4 | 5 | Libs you'll need installed/configured to run it: 6 | - theano 7 | - cudnn 8 | - sklearn 9 | - numpy 10 | - scipy 11 | - matplotlib 12 | - tqdm -------------------------------------------------------------------------------- /mnist/load.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | 4 | import numpy as np 5 | import os 6 | from time import time 7 | from collections import Counter 8 | import random 9 | from matplotlib import pyplot as plt 10 | 11 | from lib.data_utils import shuffle 12 | from lib.config import data_dir 13 | 14 | def mnist(): 15 | fd = open(os.path.join(data_dir,'train-images.idx3-ubyte')) 16 | loaded = np.fromfile(file=fd,dtype=np.uint8) 17 | trX = loaded[16:].reshape((60000,28*28)).astype(float) 18 | 19 | fd = open(os.path.join(data_dir,'train-labels.idx1-ubyte')) 20 | loaded = np.fromfile(file=fd,dtype=np.uint8) 21 | trY = loaded[8:].reshape((60000)) 22 | 23 | fd = open(os.path.join(data_dir,'t10k-images.idx3-ubyte')) 24 | loaded = np.fromfile(file=fd,dtype=np.uint8) 25 | teX = loaded[16:].reshape((10000,28*28)).astype(float) 26 | 27 | fd = open(os.path.join(data_dir,'t10k-labels.idx1-ubyte')) 28 | loaded = np.fromfile(file=fd,dtype=np.uint8) 29 | teY = loaded[8:].reshape((10000)) 30 | 31 | trY = np.asarray(trY) 32 | teY = np.asarray(teY) 33 | 34 | return trX, teX, trY, teY 35 | 36 | def mnist_with_valid_set(): 37 | trX, teX, trY, teY = mnist() 38 | 39 | trX, trY = shuffle(trX, trY) 40 | vaX = trX[50000:] 41 | vaY = trY[50000:] 42 | trX = trX[:50000] 43 | trY = trY[:50000] 44 | 45 | return trX, vaX, teX, trY, vaY, teY -------------------------------------------------------------------------------- /mnist/train_cond_dcgan.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | 4 | import os 5 | import json 6 | from time import time 7 | import numpy as np 8 | from tqdm import tqdm 9 | 
from matplotlib import pyplot as plt 10 | from sklearn.externals import joblib 11 | 12 | import theano 13 | import theano.tensor as T 14 | from theano.sandbox.cuda.dnn import dnn_conv 15 | 16 | from lib import activations 17 | from lib import updates 18 | from lib import inits 19 | from lib.vis import grayscale_grid_vis 20 | from lib.rng import py_rng, np_rng 21 | from lib.ops import batchnorm, conv_cond_concat, deconv, dropout 22 | from lib.theano_utils import floatX, sharedX 23 | from lib.data_utils import OneHot, shuffle, iter_data 24 | from lib.metrics import nnc_score, nnd_score 25 | 26 | from load import mnist_with_valid_set 27 | 28 | trX, vaX, teX, trY, vaY, teY = mnist_with_valid_set() 29 | 30 | vaX = floatX(vaX)/255. # same 1/255 scaling as transform(), but kept flat for the nearest-neighbour metrics 31 | 32 | k = 1 # # of discrim updates for each gen update 33 | l2 = 2.5e-5 # l2 weight decay 34 | b1 = 0.5 # momentum term of adam 35 | nc = 1 # # of channels in image 36 | ny = 10 # # of classes 37 | nbatch = 128 # # of examples in batch 38 | npx = 28 # # of pixels width/height of images 39 | nz = 100 # # of dim for Z 40 | ngfc = 1024 # # of gen units for fully connected layers 41 | ndfc = 1024 # # of discrim units for fully connected layers 42 | ngf = 64 # # of gen filters in first conv layer 43 | ndf = 64 # # of discrim filters in first conv layer 44 | nx = npx*npx*nc # # of dimensions in X 45 | niter = 100 # # of iter at starting learning rate 46 | niter_decay = 100 # # of iter to linearly decay learning rate to zero 47 | lr = 0.0002 # initial learning rate for adam 48 | ntrain, nval, ntest = len(trX), len(vaX), len(teX) 49 | 50 | def transform(X): # scale pixels by 1/255 and reshape to (batch, nc, npx, npx) for the conv nets 51 | return (floatX(X)/255.).reshape(-1, nc, npx, npx) 52 | 53 | def inverse_transform(X): # reshape to (batch, npx, npx) for grayscale_grid_vis; values are not rescaled 54 | X = X.reshape(-1, npx, npx) 55 | return X 56 | 57 | desc = 'cond_dcgan' 58 | model_dir = 'models/%s'%desc 59 | samples_dir = 'samples/%s'%desc 60 | if not os.path.exists('logs/'): 61 
| os.makedirs('logs/') 62 | if not os.path.exists(model_dir): 63 | os.makedirs(model_dir) 64 | if not
os.path.exists(samples_dir): 65 | os.makedirs(samples_dir) 66 | 67 | relu = activations.Rectify() 68 | sigmoid = activations.Sigmoid() 69 | lrelu = activations.LeakyRectify() 70 | bce = T.nnet.binary_crossentropy 71 | 72 | gifn = inits.Normal(scale=0.02) 73 | difn = inits.Normal(scale=0.02) 74 | 75 | gw = gifn((nz+ny, ngfc), 'gw') 76 | gw2 = gifn((ngfc+ny, ngf*2*7*7), 'gw2') 77 | gw3 = gifn((ngf*2+ny, ngf, 5, 5), 'gw3') 78 | gwx = gifn((ngf+ny, nc, 5, 5), 'gwx') 79 | 80 | dw = difn((ndf, nc+ny, 5, 5), 'dw') 81 | dw2 = difn((ndf*2, ndf+ny, 5, 5), 'dw2') 82 | dw3 = difn((ndf*2*7*7+ny, ndfc), 'dw3') 83 | dwy = difn((ndfc+ny, 1), 'dwy') 84 | 85 | gen_params = [gw, gw2, gw3, gwx] 86 | discrim_params = [dw, dw2, dw3, dwy] 87 | 88 | def gen(Z, Y, w, w2, w3, wx): # conditional generator: the label matrix Y is concatenated onto the input of every layer 89 | yb = Y.dimshuffle(0, 1, 'x', 'x') # labels reshaped to (batch, ny, 1, 1) for conv_cond_concat 90 | Z = T.concatenate([Z, Y], axis=1) 91 | h = relu(batchnorm(T.dot(Z, w))) 92 | h = T.concatenate([h, Y], axis=1) 93 | h2 = relu(batchnorm(T.dot(h, w2))) 94 | h2 = h2.reshape((h2.shape[0], ngf*2, 7, 7)) 95 | h2 = conv_cond_concat(h2, yb) 96 | h3 = relu(batchnorm(deconv(h2, w3, subsample=(2, 2), border_mode=(2, 2)))) 97 | h3 = conv_cond_concat(h3, yb) 98 | x = sigmoid(deconv(h3, wx, subsample=(2, 2), border_mode=(2, 2))) # sigmoid keeps outputs in [0, 1], the range transform() maps pixels to 99 | return x 100 | 101 | def discrim(X, Y, w, w2, w3, wy): # conditional discriminator: Y is concatenated onto the input of every layer; returns a sigmoid score per example 102 | yb = Y.dimshuffle(0, 1, 'x', 'x') 103 | X = conv_cond_concat(X, yb) 104 | h = lrelu(dnn_conv(X, w, subsample=(2, 2), border_mode=(2, 2))) 105 | h = conv_cond_concat(h, yb) 106 | h2 = lrelu(batchnorm(dnn_conv(h, w2, subsample=(2, 2), border_mode=(2, 2)))) 107 | h2 = T.flatten(h2, 2) 108 | h2 = T.concatenate([h2, Y], axis=1) 109 | h3 = lrelu(batchnorm(T.dot(h2, w3))) 110 | h3 = T.concatenate([h3, Y], axis=1) 111 | y = 
sigmoid(T.dot(h3, wy)) 112 | return y 113 | 114 | X = T.tensor4() 115 | Z = T.matrix() 116 | Y = T.matrix() 117 | 118 | gX = gen(Z, Y, *gen_params) 119 | 120 | p_real = discrim(X, Y, *discrim_params) # both discriminator calls share discrim_params 121 | p_gen = discrim(gX, Y, *discrim_params) 122 | 123 | d_cost_real =
bce(p_real, T.ones(p_real.shape)).mean() # D on real data, target 1 124 | d_cost_gen = bce(p_gen, T.zeros(p_gen.shape)).mean() # D on generated data, target 0 125 | g_cost_d = bce(p_gen, T.ones(p_gen.shape)).mean() # G loss: generated data scored against target 1 (non-saturating form) 126 | 127 | d_cost = d_cost_real + d_cost_gen 128 | g_cost = g_cost_d 129 | 130 | cost = [g_cost, d_cost, g_cost_d, d_cost_real, d_cost_gen] # returned by both compiled training functions 131 | 132 | lrt = sharedX(lr) 133 | d_updater = updates.Adam(lr=lrt, b1=b1, regularizer=updates.Regularizer(l2=l2)) 134 | g_updater = updates.Adam(lr=lrt, b1=b1, regularizer=updates.Regularizer(l2=l2)) 135 | d_updates = d_updater(discrim_params, d_cost) 136 | g_updates = g_updater(gen_params, g_cost) 137 | updates = d_updates + g_updates # NOTE(review): rebinds the imported `updates` module name; this combined list is not used afterwards 138 | 139 | print 'COMPILING' 140 | t = time() 141 | _train_g = theano.function([X, Z, Y], cost, updates=g_updates) # generator step: applies only g_updates 142 | _train_d = theano.function([X, Z, Y], cost, updates=d_updates) # discriminator step: applies only d_updates 143 | _gen = theano.function([Z, Y], gX) 144 | print '%.2f seconds to compile theano functions'%(time()-t) 145 | 146 | tr_idxs = np.arange(len(trX)) 147 | trX_vis = np.asarray([[trX[i] for i in py_rng.sample(tr_idxs[trY==y], 20)] for y in range(10)]).reshape(200, -1) # 20 random training images for each of the 10 classes (200 total) 148 | trX_vis = inverse_transform(transform(trX_vis)) 149 | grayscale_grid_vis(trX_vis, (10, 20), 'samples/%s_etl_test.png'%desc) 150 | 151 | sample_zmb = floatX(np_rng.uniform(-1., 1., size=(200, nz))) # fixed noise reused for the per-epoch sample grids 152 | sample_ymb = floatX(OneHot(np.asarray([[i for _ in range(20)] for i in range(10)]).flatten(), ny)) # matching fixed labels: 20 of each class, grouped by class 153 | 154 | def gen_samples(n, nbatch=128): # draw n generator samples with uniformly 
random labels; returns (images, integer labels) 155 | samples = [] 156 | labels = [] 157 | n_gen = 0 158 | for i in range(n/nbatch): # full batches (Python 2 integer division) 159 | ymb = floatX(OneHot(np_rng.randint(0, 10, nbatch), ny)) 160 | zmb = floatX(np_rng.uniform(-1., 1., size=(nbatch, nz))) 161 | xmb = _gen(zmb, ymb) 162 | samples.append(xmb) 163 | labels.append(np.argmax(ymb, axis=1)) 164 | n_gen += len(xmb) 165 | n_left = n-n_gen # final partial batch; NOTE(review): empty when n is a multiple of nbatch -- confirm _gen accepts a zero-size batch 166 | ymb = floatX(OneHot(np_rng.randint(0, 10, n_left), ny)) 167 | zmb = floatX(np_rng.uniform(-1., 1., size=(n_left, nz))) 168 | xmb = _gen(zmb, ymb) 169 | samples.append(xmb) 170 |
labels.append(np.argmax(ymb, axis=1)) 171 | return np.concatenate(samples, axis=0), np.concatenate(labels, axis=0) 172 | 173 | f_log = open('logs/%s.ndjson'%desc, 'wb') # training log: one JSON object per line 174 | log_fields = [ 175 | 'n_epochs', 176 | 'n_updates', 177 | 'n_examples', 178 | 'n_seconds', 179 | '1k_va_nnc_acc', 180 | '10k_va_nnc_acc', 181 | '100k_va_nnc_acc', 182 | '1k_va_nnd', 183 | '10k_va_nnd', 184 | '100k_va_nnd', 185 | 'g_cost', 186 | 'd_cost', 187 | ] 188 | 189 | print desc.upper() 190 | n_updates = 0 # NOTE(review): n_updates is re-initialised on line 193 and n_check is never read 191 | n_check = 0 192 | n_epochs = 0 193 | n_updates = 0 194 | n_examples = 0 195 | t = time() 196 | for epoch in range(1, niter+niter_decay+1): 197 | trX, trY = shuffle(trX, trY) 198 | for imb, ymb in tqdm(iter_data(trX, trY, size=nbatch), total=ntrain/nbatch): 199 | imb = transform(imb) 200 | ymb = floatX(OneHot(ymb, ny)) 201 | zmb = floatX(np_rng.uniform(-1., 1., size=(len(imb), nz))) 202 | if n_updates % (k+1) == 0: # every (k+1)-th update trains G; the other k train D 203 | cost = _train_g(imb, zmb, ymb) 204 | else: 205 | cost = _train_d(imb, zmb, ymb) 206 | n_updates += 1 207 | n_examples += len(imb) 208 | if (epoch-1) % 5 == 0: # on epochs 1, 6, 11, ...: score 100k conditional samples against the validation set 209 | g_cost = float(cost[0]) 210 | d_cost = float(cost[1]) 211 | gX, gY = gen_samples(100000) 212 | gX = gX.reshape(len(gX), -1) # flatten each sample to match the flat vaX 213 | va_nnc_acc_1k = nnc_score(gX[:1000], gY[:1000], vaX, vaY, metric='euclidean') 214 | va_nnc_acc_10k = nnc_score(gX[:10000], gY[:10000], vaX, vaY, metric='euclidean') 215 | va_nnc_acc_100k = nnc_score(gX[:100000], gY[:100000], vaX, vaY, metric='euclidean') 216 | va_nnd_1k = nnd_score(gX[:1000], vaX, metric='euclidean') 217 | va_nnd_10k = nnd_score(gX[:10000], vaX, metric='euclidean') 218 | va_nnd_100k = nnd_score(gX[:100000], vaX, metric='euclidean') 219 | log = [n_epochs, n_updates, 
n_examples, time()-t, va_nnc_acc_1k, va_nnc_acc_10k, va_nnc_acc_100k, va_nnd_1k, va_nnd_10k, va_nnd_100k, g_cost, d_cost] 220 | print '%.0f %.2f %.2f %.2f %.4f %.4f %.4f %.4f %.4f'%(epoch, va_nnc_acc_1k, va_nnc_acc_10k, va_nnc_acc_100k, va_nnd_1k, va_nnd_10k, va_nnd_100k, g_cost, d_cost) 221 |
f_log.write(json.dumps(dict(zip(log_fields, log)))+'\n') 222 | f_log.flush() 223 | 224 | samples = np.asarray(_gen(sample_zmb, sample_ymb)) 225 | grayscale_grid_vis(inverse_transform(samples), (10, 20), 'samples/%s/%d.png'%(desc, n_epochs)) 226 | n_epochs += 1 227 | if n_epochs > niter: # after niter epochs, lower the shared learning rate by lr/niter_decay per epoch (linear decay to zero) 228 | lrt.set_value(floatX(lrt.get_value() - lr/niter_decay)) 229 | if n_epochs in [1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 150, 200]: # parameter checkpoints at selected epochs 230 | joblib.dump([p.get_value() for p in gen_params], 'models/%s/%d_gen_params.jl'%(desc, n_epochs)) 231 | joblib.dump([p.get_value() for p in discrim_params], 'models/%s/%d_discrim_params.jl'%(desc, n_epochs)) -------------------------------------------------------------------------------- /models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl -------------------------------------------------------------------------------- /models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl_01.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl_01.npy -------------------------------------------------------------------------------- /models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl_02.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl_02.npy --------------------------------------------------------------------------------
/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl_03.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl_03.npy -------------------------------------------------------------------------------- /models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl_04.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl_04.npy -------------------------------------------------------------------------------- /models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl_05.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl_05.npy -------------------------------------------------------------------------------- /models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl_06.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl_06.npy -------------------------------------------------------------------------------- /models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl_07.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl_07.npy -------------------------------------------------------------------------------- /models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl_08.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl_08.npy -------------------------------------------------------------------------------- /models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl_09.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl_09.npy -------------------------------------------------------------------------------- /models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl_10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl_10.npy -------------------------------------------------------------------------------- /models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl_11.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl_11.npy -------------------------------------------------------------------------------- 
/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl_12.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl_12.npy -------------------------------------------------------------------------------- /models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl_13.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl_13.npy -------------------------------------------------------------------------------- /models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl_14.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl_14.npy -------------------------------------------------------------------------------- /models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl_15.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl_15.npy -------------------------------------------------------------------------------- /models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl_16.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl_16.npy -------------------------------------------------------------------------------- /models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl_17.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_discrim_params.jl_17.npy -------------------------------------------------------------------------------- /models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl -------------------------------------------------------------------------------- /models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_01.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_01.npy -------------------------------------------------------------------------------- /models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_02.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_02.npy -------------------------------------------------------------------------------- 
/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_03.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_03.npy -------------------------------------------------------------------------------- /models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_04.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_04.npy -------------------------------------------------------------------------------- /models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_05.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_05.npy -------------------------------------------------------------------------------- /models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_06.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_06.npy -------------------------------------------------------------------------------- /models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_07.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_07.npy -------------------------------------------------------------------------------- /models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_08.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_08.npy -------------------------------------------------------------------------------- /models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_09.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_09.npy -------------------------------------------------------------------------------- /models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_10.npy -------------------------------------------------------------------------------- /models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_11.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_11.npy -------------------------------------------------------------------------------- 
/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_12.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_12.npy -------------------------------------------------------------------------------- /models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_13.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_13.npy -------------------------------------------------------------------------------- /models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_14.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_14.npy -------------------------------------------------------------------------------- /models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_15.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_15.npy -------------------------------------------------------------------------------- /models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_16.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_16.npy -------------------------------------------------------------------------------- /models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_17.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_17.npy -------------------------------------------------------------------------------- /models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_18.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_18.npy -------------------------------------------------------------------------------- /models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_19.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/30_gen_params.jl_19.npy -------------------------------------------------------------------------------- /models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_discrim_params.jl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_discrim_params.jl -------------------------------------------------------------------------------- 
/models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_discrim_params.jl_01.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_discrim_params.jl_01.npy -------------------------------------------------------------------------------- /models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_discrim_params.jl_02.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_discrim_params.jl_02.npy -------------------------------------------------------------------------------- /models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_discrim_params.jl_03.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_discrim_params.jl_03.npy -------------------------------------------------------------------------------- /models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_discrim_params.jl_04.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_discrim_params.jl_04.npy -------------------------------------------------------------------------------- /models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_discrim_params.jl_05.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_discrim_params.jl_05.npy -------------------------------------------------------------------------------- /models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_discrim_params.jl_06.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_discrim_params.jl_06.npy -------------------------------------------------------------------------------- /models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_discrim_params.jl_07.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_discrim_params.jl_07.npy -------------------------------------------------------------------------------- /models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_discrim_params.jl_08.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_discrim_params.jl_08.npy -------------------------------------------------------------------------------- /models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_gen_params.jl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_gen_params.jl -------------------------------------------------------------------------------- 
/models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_gen_params.jl_01.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_gen_params.jl_01.npy -------------------------------------------------------------------------------- /models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_gen_params.jl_02.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_gen_params.jl_02.npy -------------------------------------------------------------------------------- /models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_gen_params.jl_03.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_gen_params.jl_03.npy -------------------------------------------------------------------------------- /models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_gen_params.jl_04.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_gen_params.jl_04.npy -------------------------------------------------------------------------------- /models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_gen_params.jl_05.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_gen_params.jl_05.npy -------------------------------------------------------------------------------- /models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_gen_params.jl_06.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_gen_params.jl_06.npy -------------------------------------------------------------------------------- /models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_gen_params.jl_07.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_gen_params.jl_07.npy -------------------------------------------------------------------------------- /models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_gen_params.jl_08.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_gen_params.jl_08.npy -------------------------------------------------------------------------------- /models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_gen_params.jl_09.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_gen_params.jl_09.npy -------------------------------------------------------------------------------- 
# /models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_gen_params.jl_10.npy:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Newmu/dcgan_code/ee12b2d15a3856794b8dae77d1eb263c67c36e47/models/svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb/200_gen_params.jl_10.npy
# --------------------------------------------------------------------------------
# /svhn/load.py:
# --------------------------------------------------------------------------------
"""Loaders for the SVHN 32x32 cropped-digit .mat files found in ``data_dir``."""
import sys
sys.path.append('..')

import os
import numpy as np
from scipy.io import loadmat

from lib.data_utils import shuffle
from lib.config import data_dir


def _load_split(filename):
    """Load one SVHN split file from ``data_dir`` and return ``(X, y)``.

    ``X`` is ``data['X']`` with its axes permuted by ``(3, 2, 0, 1)``; assuming
    the stored array is (height, width, channel, example) -- the layout of the
    official .mat release, TODO confirm -- this yields (example, channel,
    height, width).  ``y`` is the flattened label vector minus one: per the
    SVHN dataset description labels are stored as 1..10 with digit 0 encoded
    as 10, so after the shift labels are 0..9 and digit 0 maps to 9.
    """
    data = loadmat(os.path.join(data_dir, filename))
    X = data['X'].transpose(3, 2, 0, 1)
    y = data['y'].flatten() - 1
    return X, y


def svhn(extra=False):
    """Load the SVHN train and test splits, and optionally the extra split.

    Returns ``(trX, teX, trY, teY)``, or ``(trX, exX, teX, trY, exY, teY)``
    when ``extra`` is true.
    """
    trX, trY = _load_split('train_32x32.mat')
    teX, teY = _load_split('test_32x32.mat')
    if not extra:
        return trX, teX, trY, teY
    exX, exY = _load_split('extra_32x32.mat')
    return trX, exX, teX, trY, exY, teY


def svhn_with_valid_set(extra=False, nvalid=10000):
    """Load SVHN and carve a validation set out of the shuffled train split.

    The train split is shuffled together with its labels and its first
    ``nvalid`` examples become the validation set (default 10000, the value
    previously hard-coded).  When ``extra`` is true, the remaining train
    examples are concatenated with the extra split and shuffled again, and an
    indicator array ``trS`` is returned as well: 1 for rows that came from the
    original train split, 0 for rows from the extra split.

    Returns ``(trX, vaX, teX, trY, vaY, teY)``, or
    ``(trX, vaX, teX, trY, vaY, teY, trS)`` when ``extra`` is true.
    """
    if extra:
        trX, exX, teX, trY, exY, teY = svhn(extra=True)
    else:
        trX, teX, trY, teY = svhn(extra=False)
    trX, trY = shuffle(trX, trY)
    vaX, vaY = trX[:nvalid], trY[:nvalid]
    trX, trY = trX[nvalid:], trY[nvalid:]
    if not extra:
        return trX, vaX, teX, trY, vaY, teY
    # Source marker, built before the concatenation so it lines up row-for-row.
    trS = np.concatenate([np.ones(len(trY), dtype=int),
                          np.zeros(len(exY), dtype=int)])
    trX = np.concatenate([trX, exX], axis=0)
    trY = np.concatenate([trY, exY], axis=0)
    trX, trY, trS = shuffle(trX, trY, trS)
    return trX, vaX, teX, trY, vaY, teY, trS
# --------------------------------------------------------------------------------
# /svhn/svhn_semisup_analysis.py:
# --------------------------------------------------------------------------------
"""Semi-supervised SVHN evaluation of a pretrained DCGAN discriminator.

Discriminator conv-layer activations are used as features for a linear SVM
trained on 100 labelled examples per class; the SVM's C is selected on a
held-out validation split and the resulting train/valid/test error rates are
reported.
"""
import sys
sys.path.append('..')

import os
import json
from time import time
import numpy as np
from tqdm import tqdm
try:
    from sklearn.externals import joblib
except ImportError:
    # sklearn.externals.joblib was deprecated in scikit-learn 0.21 and removed
    # in 0.23; fall back to the standalone package it used to vendor.
    import joblib

from sklearn import metrics
from sklearn.linear_model import LogisticRegression as LR
from sklearn.svm import LinearSVC as LSVC

import theano
import theano.tensor as T
from theano.sandbox.cuda.dnn import dnn_conv, dnn_pool

from lib import activations
from lib import updates
from lib import inits
from lib.vis import color_grid_vis
from lib.rng import py_rng, np_rng
from lib.ops import batchnorm, conv_cond_concat, deconv, dropout
from lib.theano_utils import floatX, sharedX
from lib.data_utils import OneHot, shuffle, iter_data
from lib.metrics import nnc_score, nnd_score
from lib.costs import MSE,CCE

from load import svhn_with_valid_set

relu = activations.Rectify()
lrelu = activations.LeakyRectify(leak=0.2)
sigmoid = activations.Sigmoid()

# Train (minus a 10000-example validation split), validation and test sets.
trX, vaX, teX, trY, vaY, teY = svhn_with_valid_set(extra=False)

# x / 127.5 - 1 maps 0 -> -1 and 255 -> 1 (assumes 8-bit pixel values).
vaX = floatX(vaX)/127.5-1.
trX = floatX(trX)/127.5-1.
teX = floatX(teX)/127.5-1.
desc = 'svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb'
epoch = 200
# Discriminator weights saved at `epoch`. The list order must match the
# (w, w2, g2, b2, w3, g3, b3, wy) parameters of bnorm_statistics/model,
# because both are called with *params below.
params = [sharedX(p) for p in joblib.load('../models/%s/%d_discrim_params.jl' % (desc, epoch))]
print(desc.upper())
print('epoch %d' % epoch)


def mean_and_var(X):
    """Return symbolic per-channel mean and (biased) variance of a 4-D tensor.

    Both reductions run over axes 0, 2 and 3, keeping axis 1 -- the channel
    axis of dnn_conv's (batch, channel, row, col) output.
    """
    u = T.mean(X, axis=[0, 2, 3])
    s = T.mean(T.sqr(X - u.dimshuffle('x', 0, 'x', 'x')), axis=[0, 2, 3])
    return u, s


def bnorm_statistics(X, w, w2, g2, b2, w3, g3, b3, wy):
    """Build the graph for training-mode batchnorm statistics of the discriminator.

    Returns ``([h2_u, h3_u], [h2_s, h3_s])``: per-channel means and variances
    of the second and third conv layers' pre-normalization outputs, with h2
    normalized by its own batch statistics before feeding the third conv.
    ``g3``, ``b3`` and ``wy`` are unused; they are accepted so the function
    can be called as ``bnorm_statistics(X, *params)``.
    """
    h = lrelu(dnn_conv(X, w, subsample=(2, 2), border_mode=(2, 2)))

    h2 = dnn_conv(h, w2, subsample=(2, 2), border_mode=(2, 2))
    h2_u, h2_s = mean_and_var(h2)
    h2 = lrelu(batchnorm(h2, g=g2, b=b2))

    h3 = dnn_conv(h2, w3, subsample=(2, 2), border_mode=(2, 2))
    h3_u, h3_s = mean_and_var(h3)

    return [h2_u, h3_u], [h2_s, h3_s]


def infer_bnorm_stats(X, nbatch=128):
    """Estimate inference-time batchnorm statistics by averaging over X.

    Runs the compiled ``_bnorm_stats`` on minibatches of ``nbatch`` examples.
    It averages the per-batch channel means and variances, weighting each
    batch by its size so a short final batch does not count as much as a full
    one.

    Returns ``(U, S)``: lists of per-channel means and variances, one array
    per batch-normalized layer, in the order ``_bnorm_stats`` emits them.

    Raises ValueError if X is empty.
    """
    U, S = None, None
    n = 0
    for xmb in iter_data(X, size=nbatch):
        stats = _bnorm_stats(floatX(xmb))
        # _bnorm_stats returns all layer means followed by all layer variances.
        nlayers = len(stats) // 2
        umb, smb = stats[:nlayers], stats[nlayers:]
        if U is None:
            # Size accumulators from the first batch rather than hard-coding
            # each layer's channel count.
            U = [np.zeros_like(u) for u in umb]
            S = [np.zeros_like(s) for s in smb]
        nmb = len(xmb)
        for i, u in enumerate(umb):
            U[i] += nmb * u
        for i, s in enumerate(smb):
            S[i] += nmb * s
        n += nmb
    if n == 0:
        raise ValueError('infer_bnorm_stats needs at least one example')
    U = [u / n for u in U]
    S = [s / n for s in S]
    return U, S


def model(X,
          h2_u, h3_u,
          h2_s, h3_s,
          w, w2, g2, b2, w3, g3, b3, wy
          ):
    """Build the symbolic feature vector for inputs X.

    Layers 2 and 3 are batch-normalized with the fixed statistics passed in,
    not with batch statistics. Each of the three conv activations is
    max-pooled (4x4, 2x2 and 1x1 windows respectively), flattened and
    concatenated. The pool sizes presumably bring all three maps to the same
    spatial size -- confirm against the filter shapes. ``wy`` is unused.
    Returns a one-element list ``[features]``.
    """
    h = lrelu(dnn_conv(X, w, subsample=(2, 2), border_mode=(2, 2)))
    h2 = lrelu(batchnorm(dnn_conv(h, w2, subsample=(2, 2), border_mode=(2, 2)), g=g2, b=b2, u=h2_u, s=h2_s))
    h3 = lrelu(batchnorm(dnn_conv(h2, w3, subsample=(2, 2), border_mode=(2, 2)), g=g3, b=b3, u=h3_u, s=h3_s))
    h = T.flatten(dnn_pool(h, (4, 4), (4, 4), mode='max'), 2)
    h2 = T.flatten(dnn_pool(h2, (2, 2), (2, 2), mode='max'), 2)
    h3 = T.flatten(dnn_pool(h3, (1, 1), (1, 1), mode='max'), 2)
    f = T.concatenate([h, h2, h3], axis=1)
    return [f]


# Symbolic input shared by the statistics pass and the feature extractor.
X = T.tensor4()

h_us, h_ss = bnorm_statistics(X, *params)
_bnorm_stats = theano.function([X], h_us + h_ss)

trU, trS = infer_bnorm_stats(trX)

HUs = [sharedX(u) for u in trU]
HSs = [sharedX(s) for s in trS]

targs = [X] + HUs + HSs + params
f = model(*targs)
_features = theano.function([X], f)


def features(X, nbatch=128):
    """Return the discriminator feature matrix for X, computed in minibatches."""
    Xfs = []
    for xmb in iter_data(X, size=nbatch):
        # _features returns a list of outputs (a single one here).
        Xfs.extend(_features(floatX(xmb)))
    return np.concatenate(Xfs, axis=0)


# Candidate SVM regularization constants; the best one is picked on vaX.
cs = [0.0001, 0.0002, 0.0005, 0.001, 0.002, 0.005, 0.01]
vaXt = features(vaX)

# Per-class indices into trX; these do not change between trials.
idxs = np.arange(len(trX))
classes_idxs = [idxs[trY == y] for y in range(10)]

# Note: the *_accs lists hold error rates in percent, i.e. 100 * (1 - accuracy).
mean_va_accs = []
for c in cs:
    tr_accs = []
    va_accs = []
    for _ in tqdm(range(10), leave=False, ncols=80):
        # 100 labelled examples per class. list(...) lets random.sample accept
        # the population on Python 3, where numpy arrays are not registered as
        # Sequences; the draw is unchanged. (Assumes py_rng is a random.Random.)
        sampled_idxs = [py_rng.sample(list(class_idxs), 100) for class_idxs in classes_idxs]
        sampled_idxs = np.asarray(sampled_idxs).flatten()

        trXt = features(trX[sampled_idxs])

        # Named clf so it does not shadow the module-level model() function.
        clf = LSVC(C=c)
        clf.fit(trXt, trY[sampled_idxs])
        tr_pred = clf.predict(trXt)
        va_pred = clf.predict(vaXt)
        tr_acc = metrics.accuracy_score(trY[sampled_idxs], tr_pred)
        va_acc = metrics.accuracy_score(vaY, va_pred)
        tr_accs.append(100*(1-tr_acc))
        va_accs.append(100*(1-va_acc))
    mean_va_accs.append(np.mean(va_accs))
    # The +/- term is 1.96 standard deviations across trials.
    print('c: %.4f train: %.4f %.4f valid: %.4f %.4f' % (c, np.mean(tr_accs), np.std(tr_accs)*1.96, np.mean(va_accs), np.std(va_accs)*1.96))
# Lowest mean validation error wins.
best_va_idx = np.argmin(mean_va_accs)
best_va_c = cs[best_va_idx]
print('best c: %.4f' % best_va_c)
teXt = features(teX)

tr_accs = []
va_accs = []
te_accs = []
# Final evaluation: 100 random 1000-example labelled subsets with the selected
# C; errors are in percent, i.e. 100 * (1 - accuracy).
# Per-class indices into trX; these do not change between trials.
idxs = np.arange(len(trX))
classes_idxs = [idxs[trY == y] for y in range(10)]
for _ in tqdm(range(100), leave=False, ncols=80):
    # list(...) lets random.sample accept the population on Python 3, where
    # numpy arrays are not registered as Sequences; the draw is unchanged.
    # (Assumes py_rng is a random.Random.)
    sampled_idxs = [py_rng.sample(list(class_idxs), 100) for class_idxs in classes_idxs]
    sampled_idxs = np.asarray(sampled_idxs).flatten()

    trXt = features(trX[sampled_idxs])

    # Named clf so it does not shadow the module-level model() function.
    clf = LSVC(C=best_va_c)
    clf.fit(trXt, trY[sampled_idxs])
    tr_pred = clf.predict(trXt)
    va_pred = clf.predict(vaXt)
    te_pred = clf.predict(teXt)
    tr_acc = metrics.accuracy_score(trY[sampled_idxs], tr_pred)
    va_acc = metrics.accuracy_score(vaY, va_pred)
    te_acc = metrics.accuracy_score(teY, te_pred)
    tr_accs.append(100*(1-tr_acc))
    va_accs.append(100*(1-va_acc))
    te_accs.append(100*(1-te_acc))
# The +/- term is 1.96 standard deviations across trials.
print('train: %.4f %.4f valid: %.4f %.4f test: %.4f %.4f' % (np.mean(tr_accs), np.std(tr_accs)*1.96, np.mean(va_accs), np.std(va_accs)*1.96, np.mean(te_accs), np.std(te_accs)*1.96))
# --------------------------------------------------------------------------------