├── checkpoint_new2
    ├── checkpoint
    ├── checkpoint~
    ├── DCGAN.model-29002.meta
    ├── DCGAN.model-7002.meta
    └── DCGAN.model-7502.meta
├── README.md
├── simple-distributions.py
├── reverse_GAN_batch_accuracy.py
├── reverse_GAN_noisy_batch.py
├── reverse_GAN_untrained.py
├── train-dcgan.py
├── LICENSE
├── ops.py
├── utils.py
└── model.py

/checkpoint_new2/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "DCGAN.model-7002"
2 | all_model_checkpoint_paths: "DCGAN.model-7002"
3 |
--------------------------------------------------------------------------------
/checkpoint_new2/checkpoint~:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "DCGAN.model-7502"
2 | all_model_checkpoint_paths: "DCGAN.model-7502"
3 |
--------------------------------------------------------------------------------
/checkpoint_new2/DCGAN.model-29002.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SubarnaTripathi/ReverseGAN/HEAD/checkpoint_new2/DCGAN.model-29002.meta
--------------------------------------------------------------------------------
/checkpoint_new2/DCGAN.model-7002.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SubarnaTripathi/ReverseGAN/HEAD/checkpoint_new2/DCGAN.model-7002.meta
--------------------------------------------------------------------------------
/checkpoint_new2/DCGAN.model-7502.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SubarnaTripathi/ReverseGAN/HEAD/checkpoint_new2/DCGAN.model-7502.meta
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Precise Recovery of Latent Vectors from Generative Adversarial Networks
2 |
3 | This repository implements Zachary C. Lipton and Subarna Tripathi's paper *Precise Recovery of Latent Vectors from Generative Adversarial Networks*, ICLR 2017 workshop track ([arXiv](https://arxiv.org/abs/1702.04782)).
4 |
5 | Most of the code in this repository was written by modifying Brandon Amos's dcgan-completion.tensorflow project (https://github.com/bamos/dcgan-completion.tensorflow), which is MIT-licensed. Our modifications are also MIT-licensed.
6 | The `./checkpoint_new2` directory contains the pre-trained face model used in the paper.
7 |
8 | Yixing Lao's implementation in PyTorch: https://github.com/yxlao/pytorch-reverse-gan
9 |
--------------------------------------------------------------------------------
/simple-distributions.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import numpy as np
4 | from scipy.stats import multivariate_normal, norm
5 |
6 | import matplotlib as mpl
7 | mpl.use('Agg')
8 | import matplotlib.pyplot as plt
9 | plt.style.use('bmh')
10 | import matplotlib.mlab as mlab
11 |
12 | np.random.seed(0)
13 |
14 | X = np.arange(-3, 3, 0.001)
15 | Y = norm.pdf(X, 0, 1)
16 |
17 | fig = plt.figure()
18 | plt.plot(X, Y)
19 | plt.tight_layout()
20 | plt.savefig("normal-pdf.png")
21 |
22 | nSamples = 35
23 | X = np.random.normal(0, 1, nSamples)
24 | Y = np.zeros(nSamples)
25 | fig = plt.figure(figsize=(7,3))
26 | plt.scatter(X, Y, color='k')
27 | plt.xlim((-3,3))
28 | frame = plt.gca()
29 | frame.axes.get_yaxis().set_visible(False)
30 | plt.savefig("normal-samples.png")
31 | # Standard bivariate normal density on a grid (scipy; mlab.bivariate_normal was removed in Matplotlib 3.1).
32 | delta = 0.025
33 | x = np.arange(-3.0, 3.0, delta)
34 | y = np.arange(-3.0, 3.0, delta)
35 | X, Y = np.meshgrid(x, y)
36 | Z = multivariate_normal([0., 0.], [[1., 0.], [0., 1.]]).pdf(np.dstack((X, Y)))
37 |
38 | plt.figure()
39 | CS = plt.contour(X, Y, Z)
40 | plt.clabel(CS, inline=1, fontsize=10)
41 |
42 | nSamples = 200
43 | mean = [0, 0]
44 | cov = [[1,0], [0,1]]
45 | X, Y = np.random.multivariate_normal(mean, cov, nSamples).T
46 | plt.scatter(X, Y, color='k')
47 |
48 | plt.savefig("normal-2d.png")
49 |
--------------------------------------------------------------------------------
/reverse_GAN_batch_accuracy.py:
--------------------------------------------------------------------------------
1 | # Subarna Tripathi (http://acsweb.ucsd.edu/~stripath/research)
2 | # License: MIT
3 | # 2017-04-04
4 |
5 | from __future__ import absolute_import
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | import argparse
10 | import numpy as np
11 | import os
12 | import tensorflow as tf
13 | from tensorflow import flags
14 |
15 | from model import DCGAN
16 |
17 | import os
18 | import scipy.misc
19 | import numpy as np
20 |
21 | from model import DCGAN
22 | from utils import pp, visualize, to_json
23 |
24 |
25 | flags = tf.app.flags
26 |
27 | #flags.DEFINE_float("learning_rate", 0.0002, "Learning rate of for adam [0.0002]")
28 | #flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]")
29 | #flags.DEFINE_string("checkpoint_dir", "checkpoint_new2", "Directory name to save the checkpoints [checkpoint]")
30 | FLAGS = flags.FLAGS
31 |
32 | from PIL import Image
33 | from tensorflow.contrib.framework.python.framework import checkpoint_utils
34 |
35 | config = tf.ConfigProto()
36 | # config.gpu_options.allow_growth = True
37 | with tf.Session(config=config) as sess:
38 |     dcgan = DCGAN(sess,
39 |                   batch_size=1024, sample_size=1024,
40 |                   num_iters = 80000,
41 |                   LEARNING_RATE=1.)
42 |     dcgan.reverse_GAN_batch_all_prec(FLAGS)
43 |
--------------------------------------------------------------------------------
/reverse_GAN_noisy_batch.py:
--------------------------------------------------------------------------------
1 | # Subarna Tripathi (http://acsweb.ucsd.edu/~stripath/research)
2 | # License: MIT
3 | # 2017-04-04
4 |
5 | from __future__ import absolute_import
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | import argparse
10 | import numpy as np
11 | import os
12 | import tensorflow as tf
13 | from tensorflow import flags
14 |
15 | from model import DCGAN
16 |
17 | import os
18 | import scipy.misc
19 | import numpy as np
20 |
21 | from model import DCGAN
22 | from utils import pp, visualize, to_json
23 |
24 |
25 | flags = tf.app.flags
26 |
27 | #flags.DEFINE_float("learning_rate", 0.0002, "Learning rate of for adam [0.0002]")
28 | #flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]")
29 | #flags.DEFINE_string("checkpoint_dir", "checkpoint_new2", "Directory name to save the checkpoints [checkpoint]")
30 | FLAGS = flags.FLAGS
31 |
32 | from PIL import Image
33 | from tensorflow.contrib.framework.python.framework import checkpoint_utils
34 |
35 | config = tf.ConfigProto()
36 | # config.gpu_options.allow_growth = True
37 | with tf.Session(config=config) as sess:
38 |     dcgan = DCGAN(sess,
39 |                   batch_size=1024, sample_size=1024,
40 |                   num_iters=80000,
41 |                   LEARNING_RATE=1.,
42 |                   )
43 |     dcgan.ReverseBatchwithNoise(FLAGS)
44 |
--------------------------------------------------------------------------------
/reverse_GAN_untrained.py:
--------------------------------------------------------------------------------
1 | # Subarna Tripathi (http://acsweb.ucsd.edu/~stripath/research)
2 | # License: MIT
3 | # 2017-04-04
4 |
5 | from __future__ import absolute_import
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | import argparse
10 | import numpy as np
11 | import os
12 | import tensorflow as tf
13 | from tensorflow import flags
14 |
15 | from model import DCGAN
16 |
17 | import os
18 | import scipy.misc
19 | import numpy as np
20 |
21 | from model import DCGAN
22 | from utils import pp, visualize, to_json
23 |
24 |
25 | flags = tf.app.flags
26 | # flags.DEFINE_integer("epoch", 25, "Epoch to train [25]")
27 | #flags.DEFINE_float("learning_rate", 0.0002, "Learning rate of for adam [0.0002]")
28 | #flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]")
29 | FLAGS = flags.FLAGS
30 |
31 |
32 | copy_num = 1
33 | config = tf.ConfigProto()
34 | # config.gpu_options.allow_growth = True
35 | with tf.Session(config=config) as sess:
36 |     dcgan = DCGAN(sess,
37 |                   batch_size=1024, sample_size=1024,
38 |                   untrained_net = True,
39 |                   external_image=False,
40 |                   num_iters=100000,
41 |                   LEARNING_RATE=1.,
42 |                   #clipping=False,
43 |                   stochastic_clipping=True
44 |                   )
45 |     dcgan.reverse_GAN_batch_all_prec(FLAGS)
46 |
47 |
--------------------------------------------------------------------------------
/train-dcgan.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3.4
2 |
3 | # Original Version: Taehoon Kim (http://carpedm20.github.io)
4 | #   + Source: https://github.com/carpedm20/DCGAN-tensorflow/blob/e30539fb5e20d5a0fed40935853da97e9e55eee8/main.py
5 | #   + License: MIT
6 | # [2016-08-05] Modifications for Inpainting: Brandon Amos (http://bamos.github.io)
7 | #   + License: MIT
8 |
9 | import os
10 | import scipy.misc
11 | import numpy as np
12 |
13 | from model import DCGAN
14 | from utils import pp, visualize, to_json
15 |
16 | import tensorflow as tf
17 |
18 | flags = tf.app.flags
19 | flags.DEFINE_integer("epoch", 25, "Epoch to train [25]")
20 | flags.DEFINE_float("learning_rate", 0.0002, "Learning rate of for adam [0.0002]")
21 | flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]")
22 | flags.DEFINE_integer("train_size", np.inf, "The size of train images [np.inf]")
23 | flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")
24 | flags.DEFINE_integer("image_size", 64, "The size of image to use")
25 | #flags.DEFINE_string("dataset", "all_images/gan_train_img_align_celeba/img_align_celeba", "Dataset directory.") #"lfw-aligned-64"
26 | flags.DEFINE_string("dataset", "img_align_celeba_opencv_64", "Dataset directory.") # DCGAN.train() globs FLAGS.dataset; it must be defined
27 | flags.DEFINE_string("checkpoint_dir", "checkpoint_new", "Directory name to save the checkpoints [checkpoint]")
28 | flags.DEFINE_string("sample_dir", "samples_new", "Directory name to save the image samples [samples]")
29 | FLAGS = flags.FLAGS
30 |
31 | if not os.path.exists(FLAGS.checkpoint_dir):
32 |     os.makedirs(FLAGS.checkpoint_dir)
33 | if not os.path.exists(FLAGS.sample_dir):
34 |     os.makedirs(FLAGS.sample_dir)
35 | # NOTE(review): DCGAN.train() writes samples to ./samples/, not FLAGS.sample_dir -- confirm that directory exists.
36 | config = tf.ConfigProto()
37 | config.gpu_options.allow_growth = True
38 | with tf.Session(config=config) as sess:
39 |     dcgan = DCGAN(sess, image_size=FLAGS.image_size, batch_size=FLAGS.batch_size,
40 |                   is_crop=False, checkpoint_dir=FLAGS.checkpoint_dir)
41 |
42 |     dcgan.train(FLAGS)
43 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The code for recovering latent vectors with stochastic clipping, projected gradient and without clipping is built on top of
2 | Brandon Amos's DCGAN-completion-tensorflow project:
3 | https://github.com/bamos/dcgan-completion.tensorflow
4 |
5 | The above mentioned additions belong to
6 | Copyright (c) 2017 Subarna Tripathi
7 |
8 | In compliance with the license for DCGAN-tensorflow, we reproduce the license
9 | statement below, and release the code in this directory under the same license.
10 |
11 | # From Deep Completion
12 | Most of the code in this repository was written by modifying a duplicate of
13 | Taehoon Kim's DCGAN-tensorflow project:
14 | https://github.com/carpedm20/DCGAN-tensorflow
15 |
16 | The modifications are Copyright (c) 2016 Brandon Amos
17 |
18 | In compliance with the license for DCGAN-tensorflow, we reproduce the license
19 | statement below, and release the code in this directory under the same license.
20 |
21 | The MIT License (MIT)
22 |
23 | Copyright (c) 2016 Taehoon Kim
24 |
25 | Permission is hereby granted, free of charge, to any person obtaining a copy
26 | of this software and associated documentation files (the "Software"), to deal
27 | in the Software without restriction, including without limitation the rights
28 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
29 | copies of the Software, and to permit persons to whom the Software is
30 | furnished to do so, subject to the following conditions:
31 |
32 | The above copyright notice and this permission notice shall be included in all
33 | copies or substantial portions of the Software.
34 |
35 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
36 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
37 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
38 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
39 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
40 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
41 | SOFTWARE.
42 |
--------------------------------------------------------------------------------
/ops.py:
--------------------------------------------------------------------------------
1 | # Original Version: Taehoon Kim (http://carpedm20.github.io)
2 | #   + Source: https://github.com/carpedm20/DCGAN-tensorflow/blob/e30539fb5e20d5a0fed40935853da97e9e55eee8/ops.py
3 | #   + License: MIT
4 |
5 | import math
6 | import numpy as np
7 | import tensorflow as tf
8 |
9 | from tensorflow.python.framework import ops
10 |
11 | from utils import *
12 |
13 | class batch_norm(object):
14 |     """Batch normalization with moving-average statistics (adapted from http://stackoverflow.com/a/33950177)."""
15 |     def __init__(self, epsilon=1e-5, momentum = 0.9, name="batch_norm"):
16 |         with tf.variable_scope(name):
17 |             self.epsilon = epsilon
18 |             self.momentum = momentum
19 |
20 |             self.ema = tf.train.ExponentialMovingAverage(decay=self.momentum)
21 |             self.name = name
22 |
23 |     def __call__(self, x, train=True):
24 |         shape = x.get_shape().as_list()
25 |
26 |         if train:
27 |             with tf.variable_scope(self.name) as scope:
28 |                 self.beta = tf.get_variable("beta", [shape[-1]],
29 |                                     initializer=tf.constant_initializer(0.))
30 |                 self.gamma = tf.get_variable("gamma", [shape[-1]],
31 |                                     initializer=tf.random_normal_initializer(1., 0.02))
32 |
33 |                 batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments')
34 |                 ema_apply_op = self.ema.apply([batch_mean, batch_var])
35 |                 self.ema_mean, self.ema_var = self.ema.average(batch_mean), self.ema.average(batch_var)
36 |
37 |                 with tf.control_dependencies([ema_apply_op]):
38 |                     mean, var = tf.identity(batch_mean), tf.identity(batch_var)
39 |         else:
40 |             mean, var = self.ema_mean, self.ema_var
41 |             # NOTE: beta/gamma/ema_* only exist after an earlier call with train=True.
42 |
43 |         normed = tf.nn.batch_norm_with_global_normalization(
44 |                 x, mean, var, self.beta, self.gamma, self.epsilon, scale_after_normalization=True)
45 |
46 |         return normed
47 |
48 | def binary_cross_entropy(preds, targets, name=None):
49 |     """Computes binary cross entropy given `preds`.
50 |
51 |     For brevity, let `x = preds`, `z = targets`. The logistic loss is
52 |
53 |         loss(x, z) = - mean_i (z[i] * log(x[i] + eps) + (1 - z[i]) * log(1 - x[i] + eps))
54 |
55 |     Args:
56 |         preds: A `Tensor` of type `float32` or `float64` (probabilities, not logits).
57 |         targets: A `Tensor` of the same type and shape as `preds`.
58 |     """
59 |     eps = 1e-12
60 |     with ops.op_scope([preds, targets], name, "bce_loss") as name:
61 |         preds = ops.convert_to_tensor(preds, name="preds")
62 |         targets = ops.convert_to_tensor(targets, name="targets")
63 |         return tf.reduce_mean(-(targets * tf.log(preds + eps) +
64 |                               (1. - targets) * tf.log(1. - preds + eps)))
65 |
66 | def conv_cond_concat(x, y):
67 |     """Concatenate conditioning vector on feature map axis."""
68 |     x_shapes = x.get_shape()
69 |     y_shapes = y.get_shape()
70 |     return tf.concat(3, [x, y*tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])])
71 |
72 | def conv2d(input_, output_dim,
73 |            k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
74 |            name="conv2d"):
75 |     with tf.variable_scope(name):
76 |         w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim],
77 |                             initializer=tf.truncated_normal_initializer(stddev=stddev))
78 |         conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')
79 |
80 |         biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
81 |         # Add one learned bias per output channel.
82 |         conv = tf.nn.bias_add(conv, biases)
83 |
84 |         return conv
85 |
86 | def conv2d_transpose(input_, output_shape,
87 |                      k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
88 |                      name="conv2d_transpose", with_w=False):
89 |     with tf.variable_scope(name):
90 |         # filter : [height, width, output_channels, in_channels] -- width is k_w (previously k_h was used twice)
91 |         w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]],
92 |                             initializer=tf.random_normal_initializer(stddev=stddev))
93 |
94 |         try:
95 |             deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape,
96 |                                 strides=[1, d_h, d_w, 1])
97 |
98 |         # Support for versions of TensorFlow before 0.7.0
99 |         except AttributeError:
100 |             deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape,
101 |                                 strides=[1, d_h, d_w, 1])
102 |
103 |         biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0))
104 |         # Add one learned bias per output channel.
105 |         deconv = tf.nn.bias_add(deconv, biases)
106 |
107 |         if with_w:
108 |             return deconv, w, biases
109 |         else:
110 |             return deconv
111 |
112 | def lrelu(x, leak=0.2, name="lrelu"):
113 |     with tf.variable_scope(name):
114 |         f1 = 0.5 * (1 + leak)
115 |         f2 = 0.5 * (1 - leak)
116 |         return f1 * x + f2 * abs(x)  # == x for x >= 0, leak * x for x < 0
117 |
118 | def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False):
119 |     shape = input_.get_shape().as_list()
120 |     # input_ is expected to be 2-D (matmul below); shape[1] is its feature width.
121 |     with tf.variable_scope(scope or "Linear"):
122 |         matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32,
123 |                                  tf.random_normal_initializer(stddev=stddev))
124 |         bias = tf.get_variable("bias", [output_size],
125 |             initializer=tf.constant_initializer(bias_start))
126 |         if with_w:
127 |             return tf.matmul(input_, matrix) + bias, matrix, bias
128 |         else:
129 |             return tf.matmul(input_, matrix) + bias
130 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # Original Version: Taehoon Kim (http://carpedm20.github.io)
2 | #   + Source: https://github.com/carpedm20/DCGAN-tensorflow/blob/e30539fb5e20d5a0fed40935853da97e9e55eee8/utils.py
3 | #   + License: MIT
4 |
5 | """
6 | Some codes from https://github.com/Newmu/dcgan_code
7 | """
8 | from __future__ import division
9 | import math
10 | import json
11 | import random
12 | import pprint
13 | import scipy.misc
14 | import numpy as np
15 | from time import gmtime, strftime
16 |
17 | pp = pprint.PrettyPrinter()
18 |
19 | get_stddev = lambda x, k_h, k_w: 1/math.sqrt(k_w*k_h*x.get_shape()[-1])
20 |
21 | def get_image(image_path, image_size, is_crop=True):
22 |     return transform(imread(image_path), image_size, is_crop)
23 |
24 | def save_images(images, size, image_path):
25 |     return imsave(inverse_transform(images), size, image_path)
26 |
27 | def imread(path):
28 |     return scipy.misc.imread(path, mode='RGB').astype(np.float64)
29 |
30 | def merge_images(images, size):
31 |     return inverse_transform(images)  # `size` is unused; kept for interface compatibility
32 |
33 | def merge(images, size):
34 |     h, w = images.shape[1], images.shape[2]
35 |     img = np.zeros((int(h * size[0]), int(w * size[1]), 3))
36 |     for idx, image in enumerate(images):
37 |         i = idx % size[1]
38 |         j = idx // size[1]
39 |         img[j*h:j*h+h, i*w:i*w+w, :] = image
40 |
41 |     return img
42 |
43 | def imsave(images, size, path):
44 |     return scipy.misc.imsave(path, merge(images, size))
45 |
46 | def center_crop(x, crop_h, crop_w=None, resize_w=64):
47 |     if crop_w is None:
48 |         crop_w = crop_h
49 |     h, w = x.shape[:2]
50 |     j = int(round((h - crop_h)/2.))
51 |     i = int(round((w - crop_w)/2.))
52 |     return scipy.misc.imresize(x[j:j+crop_h, i:i+crop_w],
53 |                                [resize_w, resize_w])
54 | # Mirror-pad a 1-D signal, convolve with a normalized window, and return y[-101:-1].
55 | def smoothing(x, window_len=3, window='bartlett'):
56 |     if x.ndim != 1:
57 |         raise ValueError('smoothing only accepts 1D arrays.')
58 |
59 |     if x.size < window_len:
60 |         raise ValueError('Input vector needs to be bigger than window size')
61 |
62 |     s=np.r_[x[window_len-1:0:-1], x, x[-1:-window_len:-1]]
63 |     if window == 'flat':
64 |         w =np.ones(window_len,'d')
65 |     else:
66 |         w = getattr(np, window)(window_len)  # e.g. np.bartlett / np.hanning; avoids eval()
67 |
68 |     y=np.convolve(w/w.sum(), s, mode='valid')
69 |     return y[-101:-1]
70 |
71 |
72 | def transform(image, npx=64, is_crop=True):
73 |     # npx : # of pixels width/height of image
74 |     if is_crop:
75 |         cropped_image = center_crop(image, npx)
76 |     else:
77 |         cropped_image = image
78 |     return np.array(cropped_image)/127.5 - 1.
79 |
80 | def inverse_transform(images):
81 |     return (images+1.)/2.
82 | 83 | 84 | def to_json(output_path, *layers): 85 | with open(output_path, "w") as layer_f: 86 | lines = "" 87 | for w, b, bn in layers: 88 | layer_idx = w.name.split('/')[0].split('h')[1] 89 | 90 | B = b.eval() 91 | 92 | if "lin/" in w.name: 93 | W = w.eval() 94 | depth = W.shape[1] 95 | else: 96 | W = np.rollaxis(w.eval(), 2, 0) 97 | depth = W.shape[0] 98 | 99 | biases = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(B)]} 100 | if bn != None: 101 | gamma = bn.gamma.eval() 102 | beta = bn.beta.eval() 103 | 104 | gamma = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(gamma)]} 105 | beta = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(beta)]} 106 | else: 107 | gamma = {"sy": 1, "sx": 1, "depth": 0, "w": []} 108 | beta = {"sy": 1, "sx": 1, "depth": 0, "w": []} 109 | 110 | if "lin/" in w.name: 111 | fs = [] 112 | for w in W.T: 113 | fs.append({"sy": 1, "sx": 1, "depth": W.shape[0], "w": ['%.2f' % elem for elem in list(w)]}) 114 | 115 | lines += """ 116 | var layer_%s = { 117 | "layer_type": "fc", 118 | "sy": 1, "sx": 1, 119 | "out_sx": 1, "out_sy": 1, 120 | "stride": 1, "pad": 0, 121 | "out_depth": %s, "in_depth": %s, 122 | "biases": %s, 123 | "gamma": %s, 124 | "beta": %s, 125 | "filters": %s 126 | };""" % (layer_idx.split('_')[0], W.shape[1], W.shape[0], biases, gamma, beta, fs) 127 | else: 128 | fs = [] 129 | for w_ in W: 130 | fs.append({"sy": 5, "sx": 5, "depth": W.shape[3], "w": ['%.2f' % elem for elem in list(w_.flatten())]}) 131 | 132 | lines += """ 133 | var layer_%s = { 134 | "layer_type": "deconv", 135 | "sy": 5, "sx": 5, 136 | "out_sx": %s, "out_sy": %s, 137 | "stride": 2, "pad": 1, 138 | "out_depth": %s, "in_depth": %s, 139 | "biases": %s, 140 | "gamma": %s, 141 | "beta": %s, 142 | "filters": %s 143 | };""" % (layer_idx, 2**(int(layer_idx)+2), 2**(int(layer_idx)+2), 144 | W.shape[0], W.shape[3], biases, gamma, beta, fs) 145 | layer_f.write(" 
".join(lines.replace("'","").split())) 146 | 147 | def make_gif(images, fname, duration=2, true_image=False): 148 | import moviepy.editor as mpy 149 | 150 | def make_frame(t): 151 | try: 152 | x = images[int(len(images)/duration*t)] 153 | except: 154 | x = images[-1] 155 | 156 | if true_image: 157 | return x.astype(np.uint8) 158 | else: 159 | return ((x+1)/2*255).astype(np.uint8) 160 | 161 | clip = mpy.VideoClip(make_frame, duration=duration) 162 | clip.write_gif(fname, fps = len(images) / duration) 163 | 164 | def visualize(sess, dcgan, config, option): 165 | if option == 0: 166 | z_sample = np.random.uniform(-0.5, 0.5, size=(dcgan.batch_size, dcgan.z_dim)) #config.batch_size 167 | samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}) 168 | save_images(samples, [8, 8], './samples/test_%s.png' % strftime("%Y-%m-%d %H:%M:%S", gmtime())) 169 | elif option == 1: 170 | values = np.arange(0, 1, 1./config.batch_size) 171 | for idx in xrange(100): 172 | print(" [*] %d" % idx) 173 | z_sample = np.zeros([config.batch_size, dcgan.z_dim]) 174 | for kdx, z in enumerate(z_sample): 175 | z[idx] = values[kdx] 176 | 177 | samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}) 178 | save_images(samples, [8, 8], 'samples/test_arange_%s.png' % (idx)) #./samples/test_arange_%s.png 179 | elif option == 2: 180 | values = np.arange(0, 1, 1./config.batch_size) 181 | for idx in [random.randint(0, 99) for _ in xrange(100)]: 182 | print(" [*] %d" % idx) 183 | z = np.random.uniform(-0.2, 0.2, size=(dcgan.z_dim)) 184 | z_sample = np.tile(z, (config.batch_size, 1)) 185 | #z_sample = np.zeros([config.batch_size, dcgan.z_dim]) 186 | for kdx, z in enumerate(z_sample): 187 | z[idx] = values[kdx] 188 | 189 | samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}) 190 | make_gif(samples, './samples/test_gif_%s.gif' % (idx)) 191 | elif option == 3: 192 | values = np.arange(0, 1, 1./config.batch_size) 193 | for idx in xrange(100): 194 | print(" [*] %d" % idx) 195 | 
z_sample = np.zeros([config.batch_size, dcgan.z_dim]) 196 | for kdx, z in enumerate(z_sample): 197 | z[idx] = values[kdx] 198 | 199 | samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}) 200 | make_gif(samples, './samples/test_gif_%s.gif' % (idx)) 201 | elif option == 4: 202 | image_set = [] 203 | values = np.arange(0, 1, 1./config.batch_size) 204 | 205 | for idx in xrange(100): 206 | print(" [*] %d" % idx) 207 | z_sample = np.zeros([config.batch_size, dcgan.z_dim]) 208 | for kdx, z in enumerate(z_sample): z[idx] = values[kdx] 209 | 210 | image_set.append(sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})) 211 | make_gif(image_set[-1], './samples/test_gif_%s.gif' % (idx)) 212 | 213 | new_image_set = [merge(np.array([images[idx] for images in image_set]), [10, 10]) \ 214 | for idx in range(64) + range(63, -1, -1)] 215 | make_gif(new_image_set, './samples/test_gif_merged.gif', duration=8) 216 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # Original Version: Taehoon Kim (http://carpedm20.github.io) 2 | # + Source: https://github.com/carpedm20/DCGAN-tensorflow/blob/e30539fb5e20d5a0fed40935853da97e9e55eee8/model.py 3 | # + License: MIT 4 | # [2016-08-05] Modifications for Completion: Brandon Amos (http://bamos.github.io) 5 | # + License: MIT 6 | # [2017-04-04] Modifications for latent vector recovery: Subarna Tripathi (http://acsweb.ucsd.edu/~stripath/research) 7 | # + License: MIT 8 | 9 | from __future__ import division 10 | import os 11 | import time 12 | from glob import glob 13 | import tensorflow as tf 14 | from six.moves import xrange 15 | 16 | from ops import * 17 | from utils import * 18 | import scipy.fftpack as scifft 19 | from PIL import Image 20 | import pickle 21 | from os import listdir 22 | 23 | class DCGAN(object): 24 | def __init__(self, sess, image_size=64, is_crop=False, 25 | batch_size=64, 
sample_size=64, 26 | z_dim=100, gf_dim=64, df_dim=64, 27 | gfc_dim=1024, dfc_dim=1024, c_dim=3, 28 | checkpoint_dir=None, lam=0.1, 29 | external_image=False, 30 | untrained_net=False, 31 | robustness_expt_num = 500, 32 | num_iters = 80000, 33 | recon_thresh = 0.005, 34 | LEARNING_RATE = 1., 35 | clipping = True, 36 | stochastic_clipping = False, 37 | recover_training_images = False, 38 | save_visualizations=True 39 | ): 40 | """ 41 | 42 | Args: 43 | sess: TensorFlow session 44 | batch_size: The size of batch. Should be specified before training. 45 | z_dim: (optional) Dimension of dim for Z. [100] 46 | gf_dim: (optional) Dimension of gen filters in first conv layer. [64] 47 | df_dim: (optional) Dimension of discrim filters in first conv layer. [64] 48 | gfc_dim: (optional) Dimension of gen untis for for fully connected layer. [1024] 49 | dfc_dim: (optional) Dimension of discrim units for fully connected layer. [1024] 50 | c_dim: (optional) Dimension of image color. [3] 51 | """ 52 | self.sess = sess 53 | self.is_crop = is_crop 54 | self.batch_size = batch_size 55 | self.image_size = image_size 56 | self.sample_size = sample_size 57 | self.image_shape = [image_size, image_size, 3] 58 | 59 | self.z_dim = z_dim 60 | 61 | self.gf_dim = gf_dim 62 | self.df_dim = df_dim 63 | 64 | self.gfc_dim = gfc_dim 65 | self.dfc_dim = dfc_dim 66 | 67 | self.lam = lam 68 | 69 | self.c_dim = 3 70 | 71 | # batch normalization : deals with poor initialization helps gradient flow 72 | self.d_bn1 = batch_norm(name='d_bn1') 73 | self.d_bn2 = batch_norm(name='d_bn2') 74 | self.d_bn3 = batch_norm(name='d_bn3') 75 | 76 | self.g_bn0 = batch_norm(name='g_bn0') 77 | self.g_bn1 = batch_norm(name='g_bn1') 78 | self.g_bn2 = batch_norm(name='g_bn2') 79 | self.g_bn3 = batch_norm(name='g_bn3') 80 | 81 | self.checkpoint_dir = checkpoint_dir 82 | self.build_model() 83 | 84 | self.model_name = "DCGAN.model" 85 | self.from_external_image = external_image 86 | self.robustness_expt_num = robustness_expt_num 
87 | 88 | self.num_iters = num_iters 89 | self.recon_thresh = recon_thresh 90 | self.LEARNING_RATE = LEARNING_RATE 91 | self.stochastic_clipping = stochastic_clipping 92 | self.clipping = clipping 93 | self.save_visualizations = save_visualizations 94 | viz_rows = int(np.sqrt(self.sample_size*1.)) 95 | for counter in range(viz_rows, 1, -1): 96 | #print(counter, np.remainder(self.sample_size, counter)) 97 | if (self.sample_size % counter) == 0: 98 | self.viz_rows = counter 99 | self.viz_cols = int(self.sample_size/self.viz_rows) 100 | break 101 | 102 | self.untrained_net = untrained_net 103 | self.recover_training_images = recover_training_images 104 | 105 | def build_model(self): 106 | self.images = tf.placeholder( 107 | tf.float32, [None] + self.image_shape, name='real_images') 108 | self.sample_images= tf.placeholder( 109 | tf.float32, [None] + self.image_shape, name='sample_images') 110 | self.z = tf.placeholder(tf.float32, [None, self.z_dim], name='z') 111 | self.z_sum = tf.histogram_summary("z", self.z) 112 | 113 | self.G = self.generator(self.z) 114 | self.D, self.D_logits = self.discriminator(self.images) 115 | 116 | self.sampler = self.sampler(self.z) 117 | self.D_, self.D_logits_ = self.discriminator(self.G, reuse=True) 118 | 119 | self.d_sum = tf.histogram_summary("d", self.D) 120 | self.d__sum = tf.histogram_summary("d_", self.D_) 121 | self.G_sum = tf.image_summary("G", self.G) 122 | 123 | self.d_loss_real = tf.reduce_mean( 124 | tf.nn.sigmoid_cross_entropy_with_logits(self.D_logits, 125 | tf.ones_like(self.D))) 126 | self.d_loss_fake = tf.reduce_mean( 127 | tf.nn.sigmoid_cross_entropy_with_logits(self.D_logits_, 128 | tf.zeros_like(self.D_))) 129 | self.g_loss = tf.reduce_mean( 130 | tf.nn.sigmoid_cross_entropy_with_logits(self.D_logits_, 131 | tf.ones_like(self.D_))) 132 | 133 | self.d_loss_real_sum = tf.scalar_summary("d_loss_real", self.d_loss_real) 134 | self.d_loss_fake_sum = tf.scalar_summary("d_loss_fake", self.d_loss_fake) 135 | 136 | 
self.d_loss = self.d_loss_real + self.d_loss_fake 137 | 138 | self.g_loss_sum = tf.scalar_summary("g_loss", self.g_loss) 139 | self.d_loss_sum = tf.scalar_summary("d_loss", self.d_loss) 140 | 141 | t_vars = tf.trainable_variables() 142 | 143 | self.d_vars = [var for var in t_vars if 'd_' in var.name] 144 | self.g_vars = [var for var in t_vars if 'g_' in var.name] 145 | 146 | self.saver = tf.train.Saver(max_to_keep=1) 147 | 148 | # Completion. 149 | self.mask = tf.placeholder(tf.float32, [None] + self.image_shape, name='mask') 150 | self.contextual_loss = tf.reduce_sum( 151 | tf.contrib.layers.flatten( 152 | tf.abs(tf.mul(self.mask, self.G) - tf.mul(self.mask, self.images))), 1) 153 | self.perceptual_loss = self.g_loss 154 | self.complete_loss = self.contextual_loss + self.lam*self.perceptual_loss 155 | self.grad_complete_loss = tf.gradients(self.complete_loss, self.z) 156 | 157 | # Reverse. 158 | self.reverse_z = tf.placeholder(tf.float32, [None, self.z_dim], name='reverse_z') 159 | self.G_hat = self.generator_imposter(self.reverse_z) 160 | self.reverse_loss_sum = tf.reduce_sum(tf.pow(self.G - self.G_hat, 2), 0) 161 | self.reverse_loss = tf.reduce_mean(self.reverse_loss_sum) 162 | self.reverse_grads = tf.gradients(self.reverse_loss, self.reverse_z) 163 | # reverse 164 | self.restorer = tf.train.Saver() 165 | 166 | 167 | def train(self, config): 168 | data = glob(os.path.join(config.dataset, "*.png")) 169 | #np.random.shuffle(data) 170 | assert(len(data) > 0) 171 | 172 | d_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \ 173 | .minimize(self.d_loss, var_list=self.d_vars) 174 | g_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \ 175 | .minimize(self.g_loss, var_list=self.g_vars) 176 | tf.initialize_all_variables().run() 177 | 178 | self.g_sum = tf.merge_summary( 179 | [self.z_sum, self.d__sum, self.G_sum, self.d_loss_fake_sum, self.g_loss_sum]) 180 | self.d_sum = tf.merge_summary( 181 | [self.z_sum, self.d_sum, 
self.d_loss_real_sum, self.d_loss_sum]) 182 | self.writer = tf.train.SummaryWriter("./logs", self.sess.graph) 183 | 184 | sample_z = np.random.uniform(-1, 1, size=(self.sample_size , self.z_dim)) 185 | sample_files = data[0:self.sample_size] 186 | sample = [get_image(sample_file, self.image_size, is_crop=self.is_crop) for sample_file in sample_files] 187 | sample_images = np.array(sample).astype(np.float32) 188 | 189 | counter = 1 190 | start_time = time.time() 191 | 192 | if self.load(self.checkpoint_dir): 193 | print(""" 194 | 195 | ====== 196 | An existing model was found in the checkpoint directory. 197 | If you just cloned this repository, it's Brandon Amos' 198 | trained model for faces that's used in the post. 199 | If you want to train a new model from scratch, 200 | delete the checkpoint directory or specify a different 201 | --checkpoint_dir argument. 202 | ====== 203 | 204 | """) 205 | else: 206 | print(""" 207 | 208 | ====== 209 | An existing model was not found in the checkpoint directory. 210 | Initializing a new one. 
211 | ====== 212 | 213 | """) 214 | 215 | for epoch in xrange(config.epoch): 216 | data = glob(os.path.join(config.dataset, "*.png")) 217 | batch_idxs = min(len(data), config.train_size) // self.batch_size 218 | 219 | for idx in xrange(0, batch_idxs): 220 | batch_files = data[idx*config.batch_size:(idx+1)*config.batch_size] 221 | batch = [get_image(batch_file, self.image_size, is_crop=self.is_crop) 222 | for batch_file in batch_files] 223 | batch_images = np.array(batch).astype(np.float32) 224 | 225 | batch_z = np.random.uniform(-1, 1, [config.batch_size, self.z_dim]) \ 226 | .astype(np.float32) 227 | 228 | # Update D network 229 | _, summary_str = self.sess.run([d_optim, self.d_sum], 230 | feed_dict={ self.images: batch_images, self.z: batch_z }) 231 | self.writer.add_summary(summary_str, counter) 232 | 233 | # Update G network 234 | _, summary_str = self.sess.run([g_optim, self.g_sum], 235 | feed_dict={ self.z: batch_z }) 236 | self.writer.add_summary(summary_str, counter) 237 | 238 | # Run g_optim twice to make sure that d_loss does not go to zero (different from paper) 239 | _, summary_str = self.sess.run([g_optim, self.g_sum], 240 | feed_dict={ self.z: batch_z }) 241 | self.writer.add_summary(summary_str, counter) 242 | 243 | errD_fake = self.d_loss_fake.eval({self.z: batch_z}) 244 | errD_real = self.d_loss_real.eval({self.images: batch_images}) 245 | errG = self.g_loss.eval({self.z: batch_z}) 246 | 247 | counter += 1 248 | print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \ 249 | % (epoch, idx, batch_idxs, 250 | time.time() - start_time, errD_fake+errD_real, errG)) 251 | 252 | if np.mod(counter, 100) == 1: 253 | samples, d_loss, g_loss = self.sess.run( 254 | [self.sampler, self.d_loss, self.g_loss], 255 | feed_dict={self.z: sample_z, self.images: sample_images} 256 | ) 257 | save_images(samples, [8, 8], 258 | './samples/train_{:02d}_{:04d}.png'.format(epoch, idx)) 259 | print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss)) 260 | 
261 | if np.mod(counter, 500) == 2: 262 | self.save(config.checkpoint_dir, counter) 263 | 264 | 265 | 266 | def ReverseBatchwithNoise(self, config): 267 | z_recon_thresh = self.recon_thresh 268 | iterations = self.num_iters #100000 269 | LEARNING_RATE = self.LEARNING_RATE 270 | 271 | tf.initialize_all_variables().run() 272 | 273 | isLoaded = self.load_new(self.checkpoint_dir) 274 | assert(isLoaded) 275 | #visualize(self.sess, self, config, 0) 276 | 277 | configurations = ["no_clipping", "standard_clipping", "stochastic_clipping"] 278 | 279 | noise_levels = [1e-3, 5e-3, 1e-2, 5e-2, 1e-1, 5e-1, 1e+0] 280 | mean = 0 281 | 282 | sample_z = np.random.uniform(-1, 1, size=(self.sample_size, self.z_dim)) ## input 283 | s_zh = np.random.uniform(-1, 1, size=(self.sample_size, self.z_dim)) ## imposter 284 | 285 | orig_image = self.sess.run(self.G, feed_dict={self.z: sample_z}) 286 | imposter_image_base = self.sess.run(self.G_hat, feed_dict={self.reverse_z: s_zh}) #sample_zh 287 | 288 | rand_shape = (self.sample_size, 64, 64, 3) 289 | save_visualizations = True 290 | 291 | for config_id in range(len(configurations)): 292 | ## copy orig and imposter images every time in the loop 293 | imposter_image = np.copy(imposter_image_base) 294 | sample_zh = np.copy(s_zh) 295 | 296 | config_name = configurations[config_id] 297 | base_result_dir = "result_visualizations/noisy_batch/" + str(config_name) + "/" 298 | if not os.path.exists(base_result_dir): 299 | os.makedirs(base_result_dir) 300 | 301 | imgName = base_result_dir + "visualization_orig.jpg" 302 | save_images(orig_image, [self.viz_rows, self.viz_cols], imgName) 303 | 304 | imgName = base_result_dir + "visualization_imposter_init.jpg" 305 | save_images(imposter_image, [self.viz_rows, self.viz_cols], imgName) 306 | 307 | for level in np.arange(len(noise_levels)): 308 | sample_zh = np.copy(s_zh) 309 | learning_rate = LEARNING_RATE 310 | sigma = noise_levels[level]**.5 311 | gnoise = np.random.normal(mean, sigma, rand_shape) 312 | 
gnoise = gnoise.reshape(rand_shape) 313 | 314 | noisy_image = np.copy(orig_image) 315 | noisy_image = noisy_image + gnoise 316 | noisy_image = np.clip(noisy_image, -1, 1) 317 | 318 | z_loss = [] 319 | initial_z_loss = np.sum(np.power(sample_z - sample_zh, 2), 0) 320 | intial_z_loss = np.mean(initial_z_loss) 321 | z_loss.append(initial_z_loss) 322 | 323 | phi_loss = [] 324 | initial_phi_loss = np.sum(np.power(noisy_image - imposter_image,2), 0) 325 | initial_phi_loss = np.mean(initial_phi_loss) 326 | phi_loss.append(initial_phi_loss) 327 | 328 | pixel_loss = [] 329 | initial_pixel_loss = np.sum(np.power(orig_image - imposter_image,2), 0) 330 | initial_pixel_loss = np.mean(initial_pixel_loss) 331 | pixel_loss.append(initial_pixel_loss) 332 | 333 | result_dir = base_result_dir + str(level) + "/" 334 | if not os.path.exists(result_dir): 335 | os.makedirs(result_dir) 336 | 337 | imgName = result_dir + "visualization_noisy_orig.jpg" 338 | save_images(noisy_image, [self.viz_rows, self.viz_cols], imgName) 339 | 340 | z_primes_iters = [] 341 | for n_iter in np.arange(iterations): 342 | if n_iter > 0: 343 | curr_z_loss = np.mean(np.power(sample_zh - sample_z, 2), 1) 344 | curr_z_loss = np.sum(curr_z_loss) 345 | phi_loss.append(my_loss) 346 | z_loss.append(curr_z_loss) 347 | 348 | #z_primes_iters.append(sample_zh) 349 | 350 | ## pixel reconstrcution error 351 | curr_pixel_loss = np.mean(np.power(orig_image - imposter_image, 2), 1) 352 | curr_pixel_loss = np.sum(curr_pixel_loss) 353 | pixel_loss.append(curr_pixel_loss) 354 | 355 | if np.abs(curr_z_loss) < self.recon_thresh * self.sample_size: 356 | break 357 | 358 | fd = {self.G: noisy_image, ## this is phi(z) 359 | self.reverse_z: sample_zh 360 | } 361 | run = [self.reverse_loss, self.reverse_grads, self.G_hat] 362 | [my_loss, my_grads, imposter_image] = self.sess.run(run, feed_dict=fd) 363 | 364 | my_grads = np.asarray(my_grads[0]) 365 | #print(sample_zh.shape) 366 | #print (n_iter+1, my_loss) 367 | 368 | ## save after 
100-th iterations 369 | if n_iter == 99: 370 | imgName = result_dir + "visualization_imposter" + str(n_iter+1) + ".png" 371 | save_images(imposter_image, [self.viz_rows, self.viz_cols], imgName) 372 | ######## 373 | 374 | 375 | if (n_iter+1) % 300 == 0: 376 | print (n_iter, my_loss, np.mean(np.power(sample_zh - sample_z, 2))) 377 | 378 | if n_iter != 0 and save_visualizations: 379 | imgName = result_dir + "visualization_imposter" + str(n_iter+1) + ".png" 380 | save_images(imposter_image, [self.viz_rows, self.viz_cols], imgName) 381 | 382 | 383 | sample_zh = sample_zh - learning_rate * my_grads 384 | 385 | if config_name != "no_clipping": # do nothing for NO clipping case 386 | if config_name == "stochastic_clipping" : ## stochastic clipping 387 | for j in range(self.sample_size): 388 | edge1 = np.where(sample_zh[j] >= 1.)[0] #1 389 | edge2 = np.where(sample_zh[j] <= -1)[0] #1 390 | 391 | if edge1.shape[0] > 0: 392 | rand_el1 = np.random.uniform(-1, 1, size=(1, edge1.shape[0])) 393 | sample_zh[j,edge1] = rand_el1 394 | if edge2.shape[0] > 0: 395 | rand_el2 = np.random.uniform(-1, 1, size=(1, edge2.shape[0])) 396 | sample_zh[j,edge2] = rand_el2 397 | 398 | #if edge1.shape[0] > 0 or edge2.shape[0] > 0: 399 | #print (edge1.shape[0], edge2.shape[0]) 400 | else: ## standard clipping 401 | sample_zh = np.clip(sample_zh, -1, 1) 402 | 403 | ## always save the recovered image 404 | imgName = result_dir + "visualization_recovered.png" 405 | save_images(imposter_image, [self.viz_rows, self.viz_cols], imgName) 406 | ######## 407 | 408 | informtion = {'phi_loss': phi_loss, 409 | 'z_loss':z_loss, 410 | 'pixel_loss': pixel_loss, 411 | 'z': sample_z, 412 | 'z_prime': sample_zh, 413 | 'z_primes_iters': z_primes_iters 414 | } 415 | 416 | file_name = result_dir + "GAN_loss_batch_noisy.pickle" 417 | with open(file_name, 'wb') as handle: 418 | pickle.dump(informtion, handle, protocol=pickle.HIGHEST_PROTOCOL) 419 | 420 | ### write complements 421 | comple_z1 = sample_z - sample_zh #### 
G(z - recovered_z) 422 | comple_z2 = np.ones((self.sample_size, self.z_dim)) - sample_zh #### G(1 - recovered_z) 423 | comple_z3 = -sample_zh 424 | comple_z4 = -sample_z 425 | 426 | if config_name != "no_clipping": 427 | comple_z1 = np.clip(comple_z1, -1., 1.) 428 | comple_z2 = np.clip(comple_z2, -1., 1.) 429 | 430 | comple_image1 = self.sess.run(self.G, feed_dict={self.z: comple_z1}) 431 | comple_image2 = self.sess.run(self.G, feed_dict={self.z: comple_z2}) 432 | comple_image3 = self.sess.run(self.G, feed_dict={self.z: comple_z3}) 433 | comple_image4 = self.sess.run(self.G, feed_dict={self.z: comple_z4}) 434 | 435 | imgName1 = result_dir + "z_minus_zh.png" 436 | imgName2 = result_dir + "one_minus_zp.png" 437 | imgName3 = result_dir + "minus_zh.png" 438 | imgName4 = result_dir + "minus_z.png" 439 | 440 | save_images(comple_image1, [self.viz_rows, self.viz_cols], imgName1) 441 | save_images(comple_image2, [self.viz_rows, self.viz_cols], imgName2) 442 | save_images(comple_image3, [self.viz_rows, self.viz_cols], imgName3) 443 | save_images(comple_image4, [self.viz_rows, self.viz_cols], imgName4) 444 | ##################### 445 | 446 | 447 | #noise_recon_stats = {'stats': stats} 448 | #file_name = base_result_dir + "noisy_batch_recon_stats.pickle" 449 | #with open(file_name, 'wb') as handle: 450 | #pickle.dump(informtion, handle, protocol=pickle.HIGHEST_PROTOCOL) 451 | 452 | print ('GIR batch with noise complete') 453 | 454 | 455 | 456 | def reverse_GAN_batch_all_prec(self, config): 457 | tf.initialize_all_variables().run() 458 | 459 | base_result_dir = "result_visualizations/" 460 | 461 | if self.untrained_net: 462 | base_result_dir = "untrained_net/" + base_result_dir 463 | else: 464 | isLoaded = self.load_new(self.checkpoint_dir) 465 | assert(isLoaded) 466 | if not os.path.exists(base_result_dir): 467 | os.makedirs(base_result_dir) 468 | 469 | save_visualizations = self.save_visualizations 470 | 471 | precision_levels = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5] 472 | 473 | 
sample_z = np.random.uniform(-1, 1, size=(self.sample_size, self.z_dim)) ## input 474 | s_zh = np.random.uniform(-1, 1, size=(self.sample_size, self.z_dim)) ## imposter 475 | sample_zh = np.copy(s_zh) 476 | configurations = ["no_clipping", "standard_clipping", "stochastic_clipping"] 477 | 478 | #tf.get_variable_scope().reuse_variables() 479 | if not self.from_external_image: 480 | orig_image = self.sess.run(self.G, feed_dict={self.z: sample_z}) 481 | else : 482 | orig_image = np.zeros((self.sample_size, 64, 64, 3)) 483 | img_name = 'unseen_face6.png' #'unseen_face6.png' #'test_img.jpg' 484 | for j in np.arange(self.sample_size): 485 | orig_image[j] = get_image(img_name, 64) 486 | 487 | imposter_image_base = self.sess.run(self.G_hat, feed_dict={self.reverse_z: sample_zh}) #sample_zh 488 | #print(orig_image.shape, imposter_image.shape) 489 | 490 | for config_id in range(len(configurations)): 491 | imposter_image = np.copy(imposter_image_base) 492 | sample_zh = np.copy(s_zh) 493 | 494 | config_name = configurations[config_id] 495 | comment=config_name + "Experiment Started" 496 | print(comment) 497 | 498 | reverse_accuracy_matrix = np.zeros((self.sample_size, len(precision_levels))) 499 | base_result_dir = "result_visualizations/accuracy/" + str(config_name) + "/" 500 | if not os.path.exists(base_result_dir): 501 | os.makedirs(base_result_dir) 502 | 503 | if save_visualizations: 504 | imgName = base_result_dir + "visualization_orig.jpg" 505 | save_images(orig_image, [self.viz_rows, self.viz_cols], imgName) 506 | imgName = base_result_dir + "visualization_imposter_init.jpg" 507 | save_images(imposter_image, [self.viz_rows, self.viz_cols], imgName) 508 | 509 | #initial_phi_loss = np.mean(np.power(orig_image_pixel - imposter_image_pixel,2)) 510 | if self.from_external_image == 0: 511 | z_loss = [] 512 | initial_z_loss = np.sum(np.power(sample_z - sample_zh, 2), 0) 513 | initial_z_loss = np.mean(initial_z_loss) 514 | z_loss.append(initial_z_loss) 515 | else: 516 | 
recon_zs = [] 517 | 518 | phi_loss = [] 519 | initial_phi_loss = np.sum(np.power(orig_image - imposter_image,2), 0) # sum over batch 520 | initial_phi_loss = np.mean(initial_phi_loss) 521 | #print(str(0), initial_phi_loss, initial_z_loss) 522 | phi_loss.append(initial_phi_loss) 523 | 524 | learning_rate = self.LEARNING_RATE #1. 525 | iterations = self.num_iters 526 | for n_iter in np.arange(iterations): 527 | if n_iter % 50000 == 0 and n_iter > 0: #2000 528 | learning_rate /= 2. 529 | 530 | if n_iter > 0: 531 | phi_loss.append(my_loss) 532 | 533 | if self.from_external_image == 0: 534 | curr_z_loss = np.power(sample_zh - sample_z, 2) 535 | curr_z_loss = np.mean(curr_z_loss, 1) 536 | #curr_z_loss = np.mean((curr_z_loss**.5), 1) 537 | z_loss.append(np.mean(curr_z_loss)) 538 | for p in np.arange(len(precision_levels)): 539 | curr_z_loss_p = curr_z_loss #[:,p] 540 | prec_level = precision_levels[p] 541 | d = np.where(curr_z_loss_p < prec_level) 542 | reverse_accuracy_matrix[d, p] = 1 543 | 544 | if np.sum(reverse_accuracy_matrix) == self.sample_size * len(precision_levels): 545 | break 546 | 547 | fd = {self.G: orig_image, ## this is phi(z) 548 | self.reverse_z: sample_zh 549 | } 550 | run = [self.reverse_loss, self.reverse_grads, self.G_hat] 551 | [my_loss, my_grads, imposter_image] = self.sess.run(run, feed_dict=fd) 552 | 553 | my_grads = np.asarray(my_grads[0]) 554 | 555 | if n_iter % 200 == 0: 556 | if self.from_external_image == 0: 557 | print (n_iter, my_loss, np.mean(np.power(sample_zh - sample_z, 2))) 558 | else: 559 | print (n_iter, my_loss) 560 | if save_visualizations == 1 and n_iter != 0: 561 | imgName = base_result_dir + "visualization_imposter" + str(n_iter) + ".jpg" 562 | save_images(imposter_image, [self.viz_rows, self.viz_cols], imgName) 563 | 564 | sample_zh = sample_zh - learning_rate * my_grads 565 | 566 | if config_name != "no_clipping": # do nothing for NO clipping case 567 | if config_name == "stochastic_clipping" : ## stochastic clipping 568 | 
for j in range(self.sample_size): 569 | edge1 = np.where(sample_zh[j] >= 1.)[0] #1 570 | edge2 = np.where(sample_zh[j] <= -1)[0] #1 571 | 572 | if edge1.shape[0] > 0: 573 | rand_el1 = np.random.uniform(-1, 1, size=(1, edge1.shape[0])) 574 | sample_zh[j,edge1] = rand_el1 575 | if edge2.shape[0] > 0: 576 | rand_el2 = np.random.uniform(-1, 1, size=(1, edge2.shape[0])) 577 | sample_zh[j,edge2] = rand_el2 578 | 579 | #if edge1.shape[0] > 0 or edge2.shape[0] > 0: 580 | #print (edge1.shape[0], edge2.shape[0]) 581 | else: ## standard clipping 582 | sample_zh = np.clip(sample_zh, -1, 1) 583 | 584 | 585 | if self.from_external_image: 586 | for j in np.arange(self.sample_size): 587 | recon_zs.append(sample_zh[j]) 588 | recon_z_info = {'recon_zs':recon_zs} 589 | file_name = base_result_dir + "recon_zs.pickle" 590 | with open(file_name, 'wb') as handle: #'GAN_loss_4.pickle' 591 | pickle.dump(recon_z_info, handle, protocol=pickle.HIGHEST_PROTOCOL) 592 | 593 | if self.from_external_image == 0: 594 | informtion = {'phi_loss': phi_loss, 595 | 'z_loss':z_loss, 596 | 'sample_z': sample_z, 597 | 'sample_zh': sample_zh} 598 | else: 599 | informtion = {'phi_loss': phi_loss, 600 | 'sample_zh': sample_zh} 601 | 602 | file_name = base_result_dir + "GAN_loss.pickle" 603 | with open(file_name, 'wb') as handle: #'GAN_loss_4.pickle' 604 | pickle.dump(informtion, handle, protocol=pickle.HIGHEST_PROTOCOL) 605 | 606 | accuracy_informtion = {'accuracy_stats': reverse_accuracy_matrix} 607 | file_name = base_result_dir + "accuracy_stats.pickle" 608 | with open(file_name, 'wb') as handle: #'GAN_loss_4.pickle' 609 | pickle.dump(accuracy_informtion, handle, protocol=pickle.HIGHEST_PROTOCOL) 610 | 611 | 612 | print ('GIR accuracy experiment complete') 613 | 614 | 615 | 616 | def reverse_GAN_batch_all_prec_posterior(self, config): 617 | tf.initialize_all_variables().run() 618 | self.from_external_image = True 619 | get_training_images = self.recover_training_images 620 | 621 | if get_training_images: 622 
| if self.clipping: 623 | base_result_dir = "training_images/st_clip_" + str(self.stochastic_clipping) + str("/") 624 | else: 625 | base_result_dir = "training_images/Noclip/" 626 | else: 627 | if self.clipping: 628 | base_result_dir = "unseen_images/st_clip_" + str(self.stochastic_clipping) + str("/") 629 | else: 630 | base_result_dir = "unseen_images/Noclip/" 631 | 632 | if self.untrained_net: 633 | base_result_dir = "untrained_net/" + base_result_dir 634 | else: 635 | isLoaded = self.load_new(self.checkpoint_dir) 636 | assert(isLoaded) 637 | if not os.path.exists(base_result_dir): 638 | os.makedirs(base_result_dir) 639 | 640 | save_visualizations = 1 641 | 642 | precision_levels = [1e-5] #[1e-1, 1e-2, 1e-3, 1e-4, 1e-5] 643 | reverse_accuracy_matrix = np.zeros((self.sample_size, len(precision_levels))) 644 | 645 | sample_z = np.random.uniform(-1, 1, size=(self.sample_size, self.z_dim)) ## input 646 | s_zh = np.random.uniform(-1, 1, size=(self.sample_size, self.z_dim)) ## imposter 647 | sample_zh = np.copy(s_zh) 648 | 649 | #tf.get_variable_scope().reuse_variables() 650 | if self.from_external_image == 0: 651 | orig_image = self.sess.run(self.G, feed_dict={self.z: sample_z}) 652 | else : 653 | orig_image = np.zeros((self.sample_size, 64, 64, 3)) 654 | if get_training_images: 655 | start_num = np.random.randint(100, 5000) #(0, 10100) 656 | end_num = start_num + 9000 657 | img_idx=[np.random.randint(start_num, end_num) for p in range(0, self.sample_size)] 658 | print(img_idx) 659 | png_files = [f for f in listdir("all_images/gan_train_img_align_celeba/img_align_celeba/")] 660 | for j in np.arange(self.sample_size): 661 | q = img_idx[j] 662 | image_name = "all_images/gan_train_img_align_celeba/img_align_celeba/" + png_files[q] 663 | orig_image[j] = get_image(img_name, 64) 664 | #print(orig_image.shape) 665 | else: 666 | start_num = np.random.randint(0, 200) 667 | end_num = start_num + 3000 668 | img_idx=[np.random.randint(start_num, end_num) for p in range(0, 
self.sample_size)] 669 | png_files = [f for f in listdir("all_images/lfw/aligned")] 670 | print(img_idx) 671 | for j in np.arange(self.sample_size): 672 | q = img_idx[j] 673 | img_name = "all_images/lfw/aligned/" + png_files[q] 674 | orig_image[j] = get_image(img_name, 64) 675 | #print(orig_image.shape) 676 | 677 | #img_name = 'unseen_face6.png' #'unseen_face7.png' 678 | #for j in np.arange(self.sample_size): 679 | #orig_image[j] = get_image(img_name, 64) 680 | 681 | imposter_image = self.sess.run(self.G_hat, feed_dict={self.reverse_z: sample_zh}) #sample_zh 682 | #print(orig_image.shape, imposter_image.shape) 683 | 684 | if save_visualizations == 1: 685 | result_dir = base_result_dir 686 | if not os.path.exists(result_dir): 687 | os.makedirs(result_dir) 688 | 689 | imgName = result_dir + "visualization_orig.jpg" 690 | save_images(orig_image, [self.viz_rows, self.viz_cols], imgName) 691 | 692 | imgName = result_dir + "visualization_imposter_init.jpg" # "crazy_imgs_4/visualization_imposter_init.jpg" 693 | save_images(imposter_image, [self.viz_rows, self.viz_cols], imgName) 694 | 695 | 696 | #initial_phi_loss = np.mean(np.power(orig_image_pixel - imposter_image_pixel,2)) 697 | if self.from_external_image == 0: 698 | z_loss = [] 699 | initial_z_loss = np.sum(np.power(sample_z - sample_zh, 2), 0) 700 | initial_z_loss = np.mean(initial_z_loss) 701 | z_loss.append(initial_z_loss) 702 | else: 703 | recon_zs = [] 704 | 705 | phi_loss = [] 706 | initial_phi_loss = np.sum(np.power(orig_image - imposter_image,2), 0) # sum over batch 707 | initial_phi_loss = np.mean(initial_phi_loss) 708 | phi_loss.append(initial_phi_loss) 709 | 710 | learning_rate = self.LEARNING_RATE #1. 711 | iterations = self.num_iters 712 | for n_iter in np.arange(iterations): 713 | if n_iter % 50000 == 0 and n_iter > 0: #2000 714 | learning_rate /= 2. 
715 | 716 | if n_iter > 0: 717 | phi_loss.append(my_loss) 718 | 719 | if self.from_external_image == 0: 720 | curr_z_loss = np.power(sample_zh - sample_z, 2) 721 | curr_z_loss = np.mean(curr_z_loss, 1) 722 | #curr_z_loss = np.mean((curr_z_loss**.5), 1) 723 | 724 | z_loss.append(np.mean(curr_z_loss)) 725 | 726 | for p in np.arange(len(precision_levels)): 727 | current_z_loss_p = curr_z_loss[:,p] 728 | prec_level = precision_levels[p] 729 | d = np.where(curr_z_loss_p > p) 730 | reverse_accuracy_matrix[d, p] = 1 731 | 732 | if np.sum(reverse_accuracy_matrix) == self.sample_size * len(precision_levels): 733 | break 734 | 735 | 736 | fd = {self.G: orig_image, ## this is phi(z) 737 | self.reverse_z: sample_zh 738 | } 739 | run = [self.reverse_loss, self.reverse_grads, self.G_hat] 740 | [my_loss, my_grads, imposter_image] = self.sess.run(run, feed_dict=fd) 741 | 742 | my_grads = np.asarray(my_grads[0]) 743 | 744 | if n_iter % 200 == 0: 745 | if self.from_external_image == 0: 746 | print (n_iter, my_loss, np.mean(np.power(sample_zh - sample_z, 2))) 747 | else: 748 | print (n_iter, my_loss) 749 | if save_visualizations == 1 and n_iter != 0: 750 | imgName = result_dir + "visualization_imposter" + str(n_iter) + ".jpg" 751 | save_images(imposter_image, [self.viz_rows, self.viz_cols], imgName) 752 | 753 | sample_zh = sample_zh - learning_rate * my_grads 754 | 755 | if self.clipping: 756 | if self.stochastic_clipping: 757 | for j in range(self.sample_size): 758 | edge1 = np.where(sample_zh[j] >= 1.)[0] #1 759 | edge2 = np.where(sample_zh[j] <= -1)[0] #1 760 | 761 | if edge1.shape[0] > 0: 762 | rand_el1 = np.random.uniform(-1, 1, size=(1, edge1.shape[0])) 763 | sample_zh[j,edge1] = rand_el1 764 | if edge2.shape[0] > 0: 765 | rand_el2 = np.random.uniform(-1, 1, size=(1, edge2.shape[0])) 766 | sample_zh[j,edge2] = rand_el2 767 | 768 | #if edge1.shape[0] > 0 or edge2.shape[0] > 0: 769 | #print (edge1.shape[0], edge2.shape[0]) 770 | else: 771 | sample_zh = np.clip(sample_zh, -1, 1) 
772 | 773 | if self.from_external_image: 774 | ##### apply discriminator 775 | fd_orig = {self.images: orig_image} 776 | fd_imposter = {self.images: imposter_image} 777 | run = [self.D, self.D_logits] 778 | [D_score, D_logits] = self.sess.run(run, feed_dict=fd_orig) 779 | [D_score_imposter, D_logits_imposter] = self.sess.run(run, feed_dict=fd_imposter) 780 | 781 | print(D_score, D_logits) 782 | print(D_score_imposter, D_logits_imposter) 783 | 784 | for j in np.arange(self.sample_size): 785 | recon_zs.append(sample_zh[j]) 786 | recon_z_info = {'recon_zs':recon_zs, 787 | 'D_scores': D_score, 788 | 'D_logits' : D_logits, 789 | 'D_scores_imposter': D_score_imposter, 790 | 'D_logits_imposter': D_logits_imposter} 791 | file_name = base_result_dir + "recon_zs.pickle" 792 | with open(file_name, 'wb') as handle: #'GAN_loss_4.pickle' 793 | pickle.dump(recon_z_info, handle, protocol=pickle.HIGHEST_PROTOCOL) 794 | 795 | 796 | 797 | if self.from_external_image == 0: 798 | informtion = {'phi_loss': phi_loss, 799 | 'z_loss':z_loss} 800 | else: 801 | informtion = {'phi_loss': phi_loss} 802 | 803 | file_name = base_result_dir + "GAN_loss.pickle" 804 | with open(file_name, 'wb') as handle: #'GAN_loss_4.pickle' 805 | pickle.dump(informtion, handle, protocol=pickle.HIGHEST_PROTOCOL) 806 | 807 | accuracy_informtion = {'accuracy_stats': reverse_accuracy_matrix} 808 | file_name = base_result_dir + "accuracy_stats.pickle" 809 | with open(file_name, 'wb') as handle: #'GAN_loss_4.pickle' 810 | pickle.dump(accuracy_informtion, handle, protocol=pickle.HIGHEST_PROTOCOL) 811 | 812 | 813 | print ('Latent space recovery from training or unseen images complete') 814 | 815 | 816 | def complete(self, config): 817 | os.makedirs(os.path.join(config.outDir, 'hats_imgs'), exist_ok=True) 818 | os.makedirs(os.path.join(config.outDir, 'completed'), exist_ok=True) 819 | 820 | tf.initialize_all_variables().run() 821 | 822 | isLoaded = self.load(self.checkpoint_dir) 823 | assert(isLoaded) 824 | 825 | # data = 
glob(os.path.join(config.dataset, "*.png")) 826 | nImgs = len(config.imgs) 827 | 828 | batch_idxs = int(np.ceil(nImgs/self.batch_size)) 829 | if config.maskType == 'random': 830 | fraction_masked = 0.2 831 | mask = np.ones(self.image_shape) 832 | mask[np.random.random(self.image_shape[:2]) < fraction_masked] = 0.0 833 | elif config.maskType == 'center': 834 | scale = 0.25 835 | assert(scale <= 0.5) 836 | mask = np.ones(self.image_shape) 837 | sz = self.image_size 838 | l = int(self.image_size*scale) 839 | u = int(self.image_size*(1.0-scale)) 840 | mask[l:u, l:u, :] = 0.0 841 | elif config.maskType == 'left': 842 | mask = np.ones(self.image_shape) 843 | c = self.image_size // 2 844 | mask[:,:c,:] = 0.0 845 | elif config.maskType == 'full': 846 | mask = np.ones(self.image_shape) 847 | else: 848 | assert(False) 849 | 850 | for idx in xrange(0, batch_idxs): 851 | l = idx*self.batch_size 852 | u = min((idx+1)*self.batch_size, nImgs) 853 | batchSz = u-l 854 | batch_files = config.imgs[l:u] 855 | batch = [get_image(batch_file, self.image_size, is_crop=self.is_crop) 856 | for batch_file in batch_files] 857 | batch_images = np.array(batch).astype(np.float32) 858 | if batchSz < self.batch_size: 859 | print(batchSz) 860 | padSz = ((0, int(self.batch_size-batchSz)), (0,0), (0,0), (0,0)) 861 | batch_images = np.pad(batch_images, padSz, 'constant') 862 | batch_images = batch_images.astype(np.float32) 863 | 864 | batch_mask = np.resize(mask, [self.batch_size] + self.image_shape) 865 | zhats = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim)) 866 | v = 0 867 | 868 | nRows = np.ceil(batchSz/8) 869 | nCols = 8 870 | save_images(batch_images[:batchSz,:,:,:], [nRows,nCols], 871 | os.path.join(config.outDir, 'before.png')) 872 | masked_images = np.multiply(batch_images, batch_mask) 873 | save_images(masked_images[:batchSz,:,:,:], [nRows,nCols], 874 | os.path.join(config.outDir, 'masked.png')) 875 | 876 | for i in xrange(config.nIter): 877 | fd = { 878 | self.z: zhats, 879 | 
self.mask: batch_mask, 880 | self.images: batch_images, 881 | } 882 | run = [self.complete_loss, self.grad_complete_loss, self.G] 883 | loss, g, G_imgs = self.sess.run(run, feed_dict=fd) 884 | 885 | v_prev = np.copy(v) 886 | v = config.momentum*v - config.lr*g[0] 887 | zhats += -config.momentum * v_prev + (1+config.momentum)*v 888 | zhats = np.clip(zhats, -1, 1) 889 | 890 | if i % 50 == 0: 891 | print(i, np.mean(loss[0:batchSz])) 892 | imgName = os.path.join(config.outDir, 893 | 'hats_imgs/{:04d}.png'.format(i)) 894 | nRows = np.ceil(batchSz/8) 895 | nCols = 8 896 | save_images(G_imgs[:batchSz,:,:,:], [nRows,nCols], imgName) 897 | 898 | inv_masked_hat_images = np.multiply(G_imgs, 1.0-batch_mask) 899 | completeed = masked_images + inv_masked_hat_images 900 | imgName = os.path.join(config.outDir, 901 | 'completed/{:04d}.png'.format(i)) 902 | save_images(completeed[:batchSz,:,:,:], [nRows,nCols], imgName) 903 | 904 | 905 | def generator(self, z): 906 | self.z_, self.h0_w, self.h0_b = linear(z, self.gf_dim*8*4*4, 'g_h0_lin', with_w=True) 907 | 908 | self.h0 = tf.reshape(self.z_, [-1, 4, 4, self.gf_dim * 8]) 909 | h0 = tf.nn.relu(self.g_bn0(self.h0)) 910 | 911 | self.h1, self.h1_w, self.h1_b = conv2d_transpose(h0, 912 | [self.batch_size, 8, 8, self.gf_dim*4], name='g_h1', with_w=True) 913 | h1 = tf.nn.relu(self.g_bn1(self.h1)) 914 | 915 | h2, self.h2_w, self.h2_b = conv2d_transpose(h1, 916 | [self.batch_size, 16, 16, self.gf_dim*2], name='g_h2', with_w=True) 917 | h2 = tf.nn.relu(self.g_bn2(h2)) 918 | 919 | h3, self.h3_w, self.h3_b = conv2d_transpose(h2, 920 | [self.batch_size, 32, 32, self.gf_dim*1], name='g_h3', with_w=True) 921 | h3 = tf.nn.relu(self.g_bn3(h3)) 922 | 923 | h4, self.h4_w, self.h4_b = conv2d_transpose(h3, 924 | [self.batch_size, 64, 64, 3], name='g_h4', with_w=True) 925 | 926 | return tf.nn.tanh(h4) 927 | 928 | 929 | def discriminator(self, image, reuse=False): 930 | if reuse: 931 | tf.get_variable_scope().reuse_variables() 932 | 933 | h0 = 
lrelu(conv2d(image, self.df_dim, name='d_h0_conv')) 934 | h1 = lrelu(self.d_bn1(conv2d(h0, self.df_dim*2, name='d_h1_conv'))) 935 | h2 = lrelu(self.d_bn2(conv2d(h1, self.df_dim*4, name='d_h2_conv'))) 936 | h3 = lrelu(self.d_bn3(conv2d(h2, self.df_dim*8, name='d_h3_conv'))) 937 | h4 = linear(tf.reshape(h3, [-1, 8192]), 1, 'd_h3_lin') 938 | 939 | return tf.nn.sigmoid(h4), h4 940 | 941 | 942 | def generator_imposter(self, reverse_z): 943 | tf.get_variable_scope().reuse_variables() 944 | 945 | self.reverse_z_, self.h0_w, self.h0_b = linear(reverse_z, self.gf_dim*8*4*4, 'g_h0_lin', with_w=True) 946 | 947 | self.h0 = tf.reshape(self.reverse_z_, [-1, 4, 4, self.gf_dim * 8]) 948 | h0 = tf.nn.relu(self.g_bn0(self.h0)) 949 | 950 | self.h1, self.h1_w, self.h1_b = conv2d_transpose(h0, 951 | [self.batch_size, 8, 8, self.gf_dim*4], name='g_h1', with_w=True) 952 | h1 = tf.nn.relu(self.g_bn1(self.h1)) 953 | 954 | h2, self.h2_w, self.h2_b = conv2d_transpose(h1, 955 | [self.batch_size, 16, 16, self.gf_dim*2], name='g_h2', with_w=True) 956 | h2 = tf.nn.relu(self.g_bn2(h2)) 957 | 958 | h3, self.h3_w, self.h3_b = conv2d_transpose(h2, 959 | [self.batch_size, 32, 32, self.gf_dim*1], name='g_h3', with_w=True) 960 | h3 = tf.nn.relu(self.g_bn3(h3)) 961 | 962 | h4, self.h4_w, self.h4_b = conv2d_transpose(h3, 963 | [self.batch_size, 64, 64, 3], name='g_h4', with_w=True) 964 | 965 | return tf.nn.tanh(h4) 966 | 967 | 968 | def sampler(self, z, y=None): 969 | tf.get_variable_scope().reuse_variables() 970 | 971 | h0 = tf.reshape(linear(z, self.gf_dim*8*4*4, 'g_h0_lin'), 972 | [-1, 4, 4, self.gf_dim * 8]) 973 | h0 = tf.nn.relu(self.g_bn0(h0, train=False)) 974 | 975 | h1 = conv2d_transpose(h0, [self.batch_size, 8, 8, self.gf_dim*4], name='g_h1') 976 | h1 = tf.nn.relu(self.g_bn1(h1, train=False)) 977 | 978 | h2 = conv2d_transpose(h1, [self.batch_size, 16, 16, self.gf_dim*2], name='g_h2') 979 | h2 = tf.nn.relu(self.g_bn2(h2, train=False)) 980 | 981 | h3 = conv2d_transpose(h2, [self.batch_size, 32, 
32, self.gf_dim*1], name='g_h3') 982 | h3 = tf.nn.relu(self.g_bn3(h3, train=False)) 983 | 984 | h4 = conv2d_transpose(h3, [self.batch_size, 64, 64, 3], name='g_h4') 985 | 986 | return tf.nn.tanh(h4) 987 | 988 | def save(self, checkpoint_dir, step): 989 | if not os.path.exists(checkpoint_dir): 990 | os.makedirs(checkpoint_dir) 991 | 992 | self.saver.save(self.sess, 993 | os.path.join(checkpoint_dir, self.model_name), 994 | max_to_keep=100, 995 | global_step=step) 996 | 997 | def load_new(self, checkpoint_dir): 998 | print(" [*] Reading checkpoints...") 999 | 1000 | checkpoint_include_scopes=["g_h0_lin", "g_h1", "g_h2", "g_h3"] 1001 | inclusions = [scope.strip() for scope in checkpoint_include_scopes] 1002 | variables_to_restore = [] 1003 | for var in tf.all_variables(): 1004 | for inclusion in inclusions: 1005 | if var.op.name.startswith(inclusion): 1006 | variables_to_restore.append(var) 1007 | 1008 | self.restorer = tf.train.Saver(variables_to_restore) 1009 | 1010 | #ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 1011 | #if ckpt and ckpt.model_checkpoint_path: 1012 | #self.saver.restore(self.sess, ckpt.model_checkpoint_path) 1013 | #self.restorer.restore(self.sess, ckpt.model_checkpoint_path) 1014 | #return True 1015 | #else: 1016 | #return False 1017 | 1018 | 1019 | #checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir) 1020 | checkpoint_path = "checkpoint_new2/DCGAN.model-29002" #7002 1021 | self.restorer.restore(self.sess, checkpoint_path) 1022 | self.saver.restore(self.sess, checkpoint_path) 1023 | return True 1024 | 1025 | 1026 | 1027 | def load(self, checkpoint_dir): 1028 | print(" [*] Reading checkpoints...") 1029 | 1030 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 1031 | if ckpt and ckpt.model_checkpoint_path: 1032 | self.saver.restore(self.sess, ckpt.model_checkpoint_path) 1033 | #self.restorer.restore(self.sess, ckpt.model_checkpoint_path) 1034 | return True 1035 | else: 1036 | return False 1037 | 
--------------------------------------------------------------------------------