├── images
│   ├── result.gif
│   ├── weights.png
│   └── activations.png
├── LICENSE
├── make_gif.py
├── README.md
├── .gitignore
├── main.py
├── ops.py
├── download.py
├── utils.py
└── model_mnist.py

/images/result.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangqianhui/Conditional-GAN/HEAD/images/result.gif
--------------------------------------------------------------------------------
/images/weights.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangqianhui/Conditional-GAN/HEAD/images/weights.png
--------------------------------------------------------------------------------
/images/activations.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangqianhui/Conditional-GAN/HEAD/images/activations.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2016 jichao zhang
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/make_gif.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from PIL import Image, ImageSequence
4 | from images2gif import writeGif
5 | from utils import get_image
6 | from utils import read_image_list
7 | 
8 | def getShapeForData(filenames):
9 | 
10 |     array = [Image.open(batch_file) for batch_file in filenames]
11 |     #return sub_image_mean(array , IMG_CHANNEL)
12 | 
13 |     return array
14 | 
15 | ## Get the list of images found under image_path
16 | def GetImage(image_path):
17 | 
18 |     # Read the image file list and sort it by creation time
19 |     list_file = read_image_list(image_path)
20 |     list_file.sort(compare)
21 | 
22 |     image_array = getShapeForData(list_file)
23 | 
24 |     return image_array
25 | 
26 | def compare(x , y):
27 |     stat_x = os.stat(x)
28 |     stat_y = os.stat(y)
29 |     if stat_x.st_ctime < stat_y.st_ctime:
30 |         return -1
31 |     elif stat_x.st_ctime > stat_y.st_ctime:
32 |         return 1
33 |     else:
34 |         return 0
35 | 
36 | def make_gif(images):
37 |     writeGif('result.gif' , images , duration=0.5)
38 | 
39 | #Run
40 | image_path = './gif_images/'
41 | image_array = GetImage(image_path)
42 | 
43 | make_gif(image_array)
44 | 
45 | 
46 | 
47 | 
48 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Conditional-GANs
2 | Test code for Conditional Generative Adversarial Nets, implemented in TensorFlow.
3 | 
4 | ## INTRODUCTION
5 | 
6 | A TensorFlow implementation of [Conditional Generative Adversarial Nets](https://arxiv.org/abs/1411.1784). That paper appears to be the first to introduce conditional GANs, but its authors did not release source code. This code differs from the paper in that the GAN here is built on convolutional networks, following [DCGAN](https://github.com/carpedm20/DCGAN-tensorflow).
7 | 
8 | ## Prerequisites
9 | 
10 | - tensorflow >=1.0
11 | 
12 | - python 2.7
13 | 
14 | - opencv 2.4.8
15 | 
16 | - scipy 0.13
17 | 
18 | ## Usage
19 | 
20 | Download MNIST:
21 | 
22 |     $ python download.py mnist
23 | 
24 | Train:
25 | 
26 |     $ python main.py --op 0
27 | 
28 | Test:
29 | 
30 |     $ python main.py --op 1
31 | 
32 | Visualization:
33 | 
34 |     $ python main.py --op 2
35 | 
36 | GIF:
37 | 
38 |     $ python make_gif.py
39 | 
40 | ## Result on MNIST
41 | 
42 | ![](images/result.gif)
43 | 
44 | 
45 | ## Visualization
46 | 
47 | The visualization of the weights:
48 | 
49 | ![](images/weights.png)
50 | 
51 | The visualization of the activations:
52 | 
53 | ![](images/activations.png)
54 | 
55 | 
56 | ## Reference code
57 | 
58 | [DCGAN](https://github.com/carpedm20/DCGAN-tensorflow)
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | 
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from model_mnist import CGAN 2 | import tensorflow as tf 3 | from utils import Mnist 4 | import os 5 | 6 | flags = tf.app.flags 7 | 8 | flags.DEFINE_string("sample_dir" , "samples_for_test" , "the dir of sample images") 9 | flags.DEFINE_integer("output_size", 28 , "the size of generate image") 10 | flags.DEFINE_float("learn_rate", 0.0002, "the learning rate for gan") 11 | flags.DEFINE_integer("batch_size", 64, "the batch number") 12 | flags.DEFINE_integer("z_dim", 100, "the dimension of noise z") 13 | flags.DEFINE_integer("y_dim", 10, "the dimension of condition y") 14 | flags.DEFINE_string("log_dir" , "/tmp/tensorflow_mnist" , "the path of tensorflow's log") 15 | flags.DEFINE_string("model_path" , "model/model.ckpt" , "the path of model") 16 | flags.DEFINE_string("visua_path" , "visualization" , "the path of visuzation images") 17 | flags.DEFINE_integer("op" , 0, "0: train ; 1:test ; 2:visualize") 18 | 19 | FLAGS = flags.FLAGS 20 | # 21 | if not os.path.exists(FLAGS.sample_dir): 22 | os.makedirs(FLAGS.sample_dir) 23 | if not os.path.exists(FLAGS.log_dir): 24 | os.makedirs(FLAGS.log_dir) 25 | if not os.path.exists(FLAGS.model_path): 26 | os.makedirs(FLAGS.model_path) 27 | if not os.path.exists(FLAGS.visua_path): 28 | os.makedirs(FLAGS.visua_path) 29 | 30 | def main(_): 31 | 32 | mn_object = Mnist() 33 | 34 | cg = CGAN(data_ob = mn_object, sample_dir = FLAGS.sample_dir, output_size=FLAGS.output_size, learn_rate=FLAGS.learn_rate 35 | , batch_size=FLAGS.batch_size, z_dim=FLAGS.z_dim, y_dim=FLAGS.y_dim, log_dir=FLAGS.log_dir 36 | , model_path=FLAGS.model_path, visua_path=FLAGS.visua_path) 37 | 38 | cg.build_model() 39 | 40 | if FLAGS.op == 0: 41 | 42 | cg.train() 43 | 44 | elif FLAGS.op == 1: 45 | 46 | cg.test() 47 | 48 | else: 49 | 50 | cg.visual() 51 | 52 | if __name__ == '__main__': 53 | tf.app.run() 54 | -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib.layers.python.layers import batch_norm, variance_scaling_initializer 3 | 4 | #the implements of leakyRelu 5 | def lrelu(x , alpha = 0.2 , name="LeakyReLU"): 6 | return tf.maximum(x , alpha*x) 7 | 8 | def conv2d(input_, output_dim, 9 | k_h=3, k_w=3, d_h=2, d_w=2, 10 | name="conv2d"): 11 | with tf.variable_scope(name): 12 | 13 | w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], 
output_dim], 14 | initializer= variance_scaling_initializer()) 15 | conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME') 16 | 17 | biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0)) 18 | conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) 19 | 20 | return conv, w 21 | 22 | def de_conv(input_, output_shape, 23 | k_h=3, k_w=3, d_h=2, d_w=2, stddev=0.02, name="deconv2d", 24 | with_w=False, initializer = variance_scaling_initializer()): 25 | 26 | with tf.variable_scope(name): 27 | # filter : [height, width, output_channels, in_channels] 28 | w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]], 29 | initializer = initializer) 30 | try: 31 | deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape, 32 | strides=[1, d_h, d_w, 1]) 33 | # Support for verisons of TensorFlow before 0.7.0 34 | except AttributeError: 35 | deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape, 36 | strides=[1, d_h, d_w, 1]) 37 | 38 | biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0)) 39 | deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape()) 40 | 41 | if with_w: 42 | return deconv, w, biases 43 | else: 44 | return deconv 45 | 46 | def fully_connect(input_, output_size, scope=None, with_w=False, 47 | initializer = variance_scaling_initializer()): 48 | 49 | shape = input_.get_shape().as_list() 50 | 51 | with tf.variable_scope(scope or "Linear"): 52 | 53 | matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32, 54 | initializer = initializer) 55 | bias = tf.get_variable("bias", [output_size], initializer=tf.constant_initializer(0.0)) 56 | if with_w: 57 | return tf.matmul(input_, matrix) + bias, matrix, bias 58 | else: 59 | return tf.matmul(input_, matrix) + bias 60 | 61 | def conv_cond_concat(x, y): 62 | """Concatenate conditioning vector on feature map axis.""" 63 | x_shapes = x.get_shape() 64 | y_shapes = y.get_shape() 65 | 66 | return tf.concat([x , y*tf.ones([x_shapes[0], x_shapes[1], x_shapes[2] , y_shapes[3]])], 3) 67 | 68 | def batch_normal(input , scope="scope" , reuse=False): 69 | return batch_norm(input , epsilon=1e-5, decay=0.9 , scale=True, scope=scope , reuse = reuse , updates_collections=None) 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | -------------------------------------------------------------------------------- /download.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import sys 4 | import json 5 | import zipfile 6 | import argparse 7 | import subprocess 8 | from six.moves import urllib 9 | 10 | parser = argparse.ArgumentParser(description='Download dataset for DCGAN.') 11 | parser.add_argument('datasets', metavar='N', type=str, nargs='+', choices=['celebA', 'lsun', 'mnist'], 12 | help='name of dataset to download [celebA, lsun, mnist]' , default='mnist') 13 | 14 | def download(url, dirpath): 15 | 16 | filename = url.split('/')[-1] 17 | filepath = os.path.join(dirpath, filename) 18 | u = urllib.request.urlopen(url) 19 | f = open(filepath, 'wb') 20 | filesize = int(u.headers["Content-Length"]) 21 | print("Downloading: %s Bytes: %s" % (filename, filesize)) 22 | 23 | downloaded = 0 24 | block_sz = 8192 25 | status_width = 70 26 | while True: 27 | buf = u.read(block_sz) 28 | if not buf: 29 | print('') 30 | break 31 | else: 32 | print('', end='\r') 33 | downloaded += len(buf) 34 | f.write(buf) 35 | status = (("[%-" + 
str(status_width + 1) + "s] %3.2f%%") % 36 | ('=' * int(float(downloaded) / filesize * status_width) + '>', downloaded * 100. / filesize)) 37 | print(status, end='') 38 | sys.stdout.flush() 39 | f.close() 40 | return filepath 41 | 42 | def unzip(filepath): 43 | print("Extracting: " + filepath) 44 | dirpath = os.path.dirname(filepath) 45 | with zipfile.ZipFile(filepath) as zf: 46 | zf.extractall(dirpath) 47 | os.remove(filepath) 48 | 49 | def download_celeb_a(dirpath): 50 | data_dir = 'celebA' 51 | if os.path.exists(os.path.join(dirpath, data_dir)): 52 | print('Found Celeb-A - skip') 53 | return 54 | url = 'https://www.dropbox.com/sh/8oqt9vytwxb3s4r/AADIKlz8PR9zr6Y20qbkunrba/Img/img_align_celeba.zip?dl=1&pv=1' 55 | filepath = download(url, dirpath) 56 | zip_dir = '' 57 | with zipfile.ZipFile(filepath) as zf: 58 | zip_dir = zf.namelist()[0] 59 | zf.extractall(dirpath) 60 | os.remove(filepath) 61 | os.rename(os.path.join(dirpath, zip_dir), os.path.join(dirpath, data_dir)) 62 | 63 | def _list_categories(tag): 64 | url = 'http://lsun.cs.princeton.edu/htbin/list.cgi?tag=' + tag 65 | f = urllib.request.urlopen(url) 66 | return json.loads(f.read()) 67 | 68 | def _download_lsun(out_dir, category, set_name, tag): 69 | url = 'http://lsun.cs.princeton.edu/htbin/download.cgi?tag={tag}' \ 70 | '&category={category}&set={set_name}'.format(**locals()) 71 | print(url) 72 | if set_name == 'test': 73 | out_name = 'test_lmdb.zip' 74 | else: 75 | out_name = '{category}_{set_name}_lmdb.zip'.format(**locals()) 76 | out_path = os.path.join(out_dir, out_name) 77 | cmd = ['curl', url, '-o', out_path] 78 | print('Downloading', category, set_name, 'set') 79 | subprocess.call(cmd) 80 | 81 | def download_lsun(dirpath): 82 | data_dir = os.path.join(dirpath, 'lsun') 83 | if os.path.exists(data_dir): 84 | print('Found LSUN - skip') 85 | return 86 | else: 87 | os.mkdir(data_dir) 88 | 89 | tag = 'latest' 90 | #categories = _list_categories(tag) 91 | categories = ['bedroom'] 92 | 93 | for category in categories: 94 | _download_lsun(data_dir, category, 'train', tag) 95 | _download_lsun(data_dir, category, 'val', tag) 96 | _download_lsun(data_dir, '', 'test', tag) 97 | 98 | def download_mnist(dirpath): 99 | data_dir = os.path.join(dirpath, 'mnist') 100 | if os.path.exists(data_dir): 101 | print('Found MNIST - skip') 102 | return 103 | else: 104 | os.mkdir(data_dir) 105 | url_base = 'http://yann.lecun.com/exdb/mnist/' 106 | file_names = ['train-images-idx3-ubyte.gz','train-labels-idx1-ubyte.gz','t10k-images-idx3-ubyte.gz','t10k-labels-idx1-ubyte.gz'] 107 | for file_name in file_names: 108 | url = (url_base+file_name).format(**locals()) 109 | print(url) 110 | out_path = os.path.join(data_dir,file_name) 111 | cmd = ['curl', url, '-o', out_path] 112 | print('Downloading ', file_name) 113 | subprocess.call(cmd) 114 | cmd = ['gzip', '-d', out_path] 115 | print('Decompressing ', file_name) 116 | subprocess.call(cmd) 117 | 118 | def prepare_data_dir(path = './data'): 119 | if not os.path.exists(path): 120 | os.mkdir(path) 121 | 122 | if __name__ == '__main__': 123 | args = parser.parse_args() 124 | prepare_data_dir() 125 | 126 | if 'celebA' in args.datasets: 127 | download_celeb_a('./data') 128 | if 'lsun' in args.datasets: 129 | download_lsun('./data') 130 | if 'mnist' in args.datasets: 131 | download_mnist('./data') -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import 
scipy 4 | import scipy.misc 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | class Mnist(object): 9 | 10 | def __init__(self): 11 | 12 | self.dataname = "Mnist" 13 | self.dims = 28*28 14 | self.shape = [28 , 28 , 1] 15 | self.image_size = 28 16 | self.data, self.data_y = self.load_mnist() 17 | 18 | def load_mnist(self): 19 | 20 | data_dir = os.path.join("./data", "mnist") 21 | fd = open(os.path.join(data_dir, 'train-images-idx3-ubyte')) 22 | loaded = np.fromfile(file=fd , dtype=np.uint8) 23 | trX = loaded[16:].reshape((60000, 28 , 28 , 1)).astype(np.float) 24 | 25 | fd = open(os.path.join(data_dir, 'train-labels-idx1-ubyte')) 26 | loaded = np.fromfile(file=fd, dtype=np.uint8) 27 | trY = loaded[8:].reshape((60000)).astype(np.float) 28 | 29 | fd = open(os.path.join(data_dir, 't10k-images-idx3-ubyte')) 30 | loaded = np.fromfile(file=fd, dtype=np.uint8) 31 | teX = loaded[16:].reshape((10000, 28 , 28 , 1)).astype(np.float) 32 | 33 | fd = open(os.path.join(data_dir, 't10k-labels-idx1-ubyte')) 34 | loaded = np.fromfile(file=fd, dtype=np.uint8) 35 | teY = loaded[8:].reshape((10000)).astype(np.float) 36 | 37 | trY = np.asarray(trY) 38 | teY = np.asarray(teY) 39 | 40 | X = np.concatenate((trX, teX), axis=0) 41 | y = np.concatenate((trY, teY), axis=0) 42 | 43 | seed = 547 44 | 45 | np.random.seed(seed) 46 | np.random.shuffle(X) 47 | np.random.seed(seed) 48 | np.random.shuffle(y) 49 | 50 | #convert label to one-hot 51 | 52 | y_vec = np.zeros((len(y), 10), dtype=np.float) 53 | for i, label in enumerate(y): 54 | y_vec[i, int(y[i])] = 1.0 55 | 56 | return X / 255., y_vec 57 | 58 | def getNext_batch(self, iter_num=0, batch_size=64): 59 | 60 | ro_num = len(self.data) / batch_size - 1 61 | 62 | if iter_num % ro_num == 0: 63 | 64 | length = len(self.data) 65 | perm = np.arange(length) 66 | np.random.shuffle(perm) 67 | self.data = np.array(self.data) 68 | self.data = self.data[perm] 69 | self.data_y = np.array(self.data_y) 70 | self.data_y = self.data_y[perm] 71 | 72 | return self.data[int(iter_num % ro_num) * batch_size: int(iter_num% ro_num + 1) * batch_size] \ 73 | , self.data_y[int(iter_num % ro_num) * batch_size: int(iter_num%ro_num + 1) * batch_size] 74 | 75 | 76 | def get_image(image_path , is_grayscale = False): 77 | return np.array(inverse_transform(imread(image_path, is_grayscale))) 78 | 79 | 80 | def save_images(images , size , image_path): 81 | return imsave(inverse_transform(images) , size , image_path) 82 | 83 | def imread(path, is_grayscale = False): 84 | if (is_grayscale): 85 | return scipy.misc.imread(path, flatten = True).astype(np.float) 86 | else: 87 | return scipy.misc.imread(path).astype(np.float) 88 | 89 | def imsave(images , size , path): 90 | return scipy.misc.imsave(path , merge(images , size)) 91 | 92 | def merge(images , size): 93 | h , w = images.shape[1] , images.shape[2] 94 | img = np.zeros((h*size[0] , w*size[1] , 3)) 95 | for idx , image in enumerate(images): 96 | i = idx % size[1] 97 | j = idx // size[1] 98 | img[j*h:j*h +h , i*w : i*w+w , :] = image 99 | 100 | return img 101 | 102 | def inverse_transform(image): 103 | return (image + 1.)/2. 
104 | 105 | def read_image_list(category): 106 | filenames = [] 107 | print("list file") 108 | list = os.listdir(category) 109 | 110 | for file in list: 111 | filenames.append(category + "/" + file) 112 | 113 | print("list file ending!") 114 | 115 | return filenames 116 | 117 | ##from caffe 118 | def vis_square(visu_path , data , type): 119 | """Take an array of shape (n, height, width) or (n, height, width , 3) 120 | and visualize each (height, width) thing in a grid of size approx. sqrt(n) by sqrt(n)""" 121 | 122 | # normalize data for display 123 | data = (data - data.min()) / (data.max() - data.min()) 124 | 125 | # force the number of filters to be square 126 | n = int(np.ceil(np.sqrt(data.shape[0]))) 127 | 128 | padding = (((0, n ** 2 - data.shape[0]) , 129 | (0, 1), (0, 1)) # add some space between filters 130 | + ((0, 0),) * (data.ndim - 3)) # don't pad the last dimension (if there is one) 131 | data = np.pad(data , padding, mode='constant' , constant_values=1) # pad with ones (white) 132 | 133 | # tilethe filters into an im age 134 | data = data.reshape((n , n) + data.shape[1:]).transpose((0 , 2 , 1 , 3) + tuple(range(4 , data.ndim + 1))) 135 | 136 | data = data.reshape((n * data.shape[1] , n * data.shape[3]) + data.shape[4:]) 137 | 138 | plt.imshow(data[:,:,0]) 139 | plt.axis('off') 140 | 141 | if type: 142 | plt.savefig('./{}/weights.png'.format(visu_path) , format='png') 143 | else: 144 | plt.savefig('./{}/activation.png'.format(visu_path) , format='png') 145 | 146 | 147 | def sample_label(): 148 | num = 64 149 | label_vector = np.zeros((num , 10), dtype=np.float) 150 | for i in range(0 , num): 151 | label_vector[i , int(i/8)] = 1.0 152 | return label_vector 153 | -------------------------------------------------------------------------------- /model_mnist.py: -------------------------------------------------------------------------------- 1 | from utils import save_images, vis_square,sample_label 2 | from tensorflow.contrib.layers.python.layers import xavier_initializer 3 | import cv2 4 | from ops import conv2d, lrelu, de_conv, fully_connect, conv_cond_concat, batch_normal 5 | import tensorflow as tf 6 | import numpy as np 7 | 8 | class CGAN(object): 9 | 10 | # build model 11 | def __init__(self, data_ob, sample_dir, output_size, learn_rate, batch_size, z_dim, y_dim, log_dir 12 | , model_path, visua_path): 13 | 14 | self.data_ob = data_ob 15 | self.sample_dir = sample_dir 16 | self.output_size = output_size 17 | self.learn_rate = learn_rate 18 | self.batch_size = batch_size 19 | self.z_dim = z_dim 20 | self.y_dim = y_dim 21 | self.log_dir = log_dir 22 | self.model_path = model_path 23 | self.vi_path = visua_path 24 | self.channel = self.data_ob.shape[2] 25 | self.images = tf.placeholder(tf.float32, [batch_size, self.output_size, self.output_size, self.channel]) 26 | self.z = tf.placeholder(tf.float32, [self.batch_size, self.z_dim]) 27 | self.y = tf.placeholder(tf.float32, [self.batch_size, self.y_dim]) 28 | 29 | def build_model(self): 30 | 31 | self.fake_images = self.gern_net(self.z, self.y) 32 | G_image = tf.summary.image("G_out", self.fake_images) 33 | ##the loss of gerenate network 34 | D_pro, D_logits = self.dis_net(self.images, self.y, False) 35 | D_pro_sum = tf.summary.histogram("D_pro", D_pro) 36 | 37 | G_pro, G_logits = self.dis_net(self.fake_images, self.y, True) 38 | G_pro_sum = tf.summary.histogram("G_pro", G_pro) 39 | 40 | D_fake_loss = tf.reduce_mean( 41 | tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(G_pro), logits=G_logits)) 42 | 43 | D_real_loss 
= tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(D_pro), logits=D_logits)) 44 | G_fake_loss = tf.reduce_mean( 45 | tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(G_pro), logits=G_logits)) 46 | 47 | self.D_loss = D_real_loss + D_fake_loss 48 | self.G_loss = G_fake_loss 49 | 50 | loss_sum = tf.summary.scalar("D_loss", self.D_loss) 51 | G_loss_sum = tf.summary.scalar("G_loss", self.G_loss) 52 | 53 | self.merged_summary_op_d = tf.summary.merge([loss_sum, D_pro_sum]) 54 | self.merged_summary_op_g = tf.summary.merge([G_loss_sum, G_pro_sum, G_image]) 55 | 56 | t_vars = tf.trainable_variables() 57 | self.d_var = [var for var in t_vars if 'dis' in var.name] 58 | self.g_var = [var for var in t_vars if 'gen' in var.name] 59 | 60 | self.saver = tf.train.Saver() 61 | 62 | def train(self): 63 | 64 | opti_D = tf.train.AdamOptimizer(learning_rate=self.learn_rate, beta1=0.5).minimize(self.D_loss, var_list=self.d_var) 65 | opti_G = tf.train.AdamOptimizer(learning_rate=self.learn_rate, beta1=0.5).minimize(self.G_loss, 66 | var_list=self.g_var) 67 | init = tf.global_variables_initializer() 68 | 69 | config = tf.ConfigProto() 70 | config.gpu_options.allow_growth = True 71 | 72 | with tf.Session(config=config) as sess: 73 | 74 | sess.run(init) 75 | summary_writer = tf.summary.FileWriter(self.log_dir, graph=sess.graph) 76 | 77 | step = 0 78 | while step <= 10000: 79 | 80 | realbatch_array, real_labels = self.data_ob.getNext_batch(step) 81 | 82 | # Get the z 83 | batch_z = np.random.uniform(-1, 1, size=[self.batch_size, self.z_dim]) 84 | # batch_z = np.random.normal(0 , 0.2 , size=[batch_size , sample_size]) 85 | 86 | _, summary_str = sess.run([opti_D, self.merged_summary_op_d], 87 | feed_dict={self.images: realbatch_array, self.z: batch_z, self.y: real_labels}) 88 | summary_writer.add_summary(summary_str, step) 89 | 90 | _, summary_str = sess.run([opti_G, self.merged_summary_op_g], 91 | feed_dict={self.z: batch_z, self.y: real_labels}) 92 | summary_writer.add_summary(summary_str, step) 93 | 94 | if step % 50 == 0: 95 | 96 | D_loss = sess.run(self.D_loss, feed_dict={self.images: realbatch_array, self.z: batch_z, self.y: real_labels}) 97 | fake_loss = sess.run(self.G_loss, feed_dict={self.z: batch_z, self.y: real_labels}) 98 | print("Step %d: D: loss = %.7f G: loss=%.7f " % (step, D_loss, fake_loss)) 99 | 100 | if np.mod(step, 50) == 1 and step != 0: 101 | 102 | sample_images = sess.run(self.fake_images, feed_dict={self.z: batch_z, self.y: sample_label()}) 103 | save_images(sample_images, [8, 8], 104 | './{}/train_{:04d}.png'.format(self.sample_dir, step)) 105 | 106 | self.saver.save(sess, self.model_path) 107 | 108 | step = step + 1 109 | 110 | save_path = self.saver.save(sess, self.model_path) 111 | print ("Model saved in file: %s" % save_path) 112 | 113 | def test(self): 114 | 115 | init = tf.initialize_all_variables() 116 | 117 | with tf.Session() as sess: 118 | sess.run(init) 119 | 120 | self.saver.restore(sess, self.model_path) 121 | sample_z = np.random.uniform(1, -1, size=[self.batch_size, self.z_dim]) 122 | 123 | output = sess.run(self.fake_images, feed_dict={self.z: sample_z, self.y: sample_label()}) 124 | 125 | save_images(output, [8, 8], './{}/test{:02d}_{:04d}.png'.format(self.sample_dir, 0, 0)) 126 | 127 | image = cv2.imread('./{}/test{:02d}_{:04d}.png'.format(self.sample_dir, 0, 0), 0) 128 | 129 | cv2.imshow("test", image) 130 | 131 | cv2.waitKey(-1) 132 | 133 | print("Test finish!") 134 | 135 | def visual(self): 136 | 137 | init = 
tf.initialize_all_variables() 138 | with tf.Session() as sess: 139 | sess.run(init) 140 | 141 | self.saver.restore(sess, self.model_path) 142 | 143 | realbatch_array, real_labels = self.data_ob.getNext_batch(0) 144 | batch_z = np.random.uniform(-1, 1, size=[self.batch_size, self.z_dim]) 145 | # visualize the weights 1 or you can change weights_2 . 146 | conv_weights = sess.run([tf.get_collection('weight_2')]) 147 | vis_square(self.vi_path, conv_weights[0][0].transpose(3, 0, 1, 2), type=1) 148 | 149 | # visualize the activation 1 150 | ac = sess.run([tf.get_collection('ac_2')], 151 | feed_dict={self.images: realbatch_array[:64], self.z: batch_z, self.y: sample_label()}) 152 | 153 | vis_square(self.vi_path, ac[0][0].transpose(3, 1, 2, 0), type=0) 154 | 155 | print("the visualization finish!") 156 | 157 | def gern_net(self, z, y): 158 | 159 | with tf.variable_scope('generator') as scope: 160 | 161 | yb = tf.reshape(y, shape=[self.batch_size, 1, 1, self.y_dim]) 162 | z = tf.concat([z, y], 1) 163 | c1, c2 = int( self.output_size / 4), int(self.output_size / 2 ) 164 | 165 | # 10 stand for the num of labels 166 | d1 = tf.nn.relu(batch_normal(fully_connect(z, output_size=1024, scope='gen_fully'), scope='gen_bn1')) 167 | 168 | d1 = tf.concat([d1, y], 1) 169 | 170 | d2 = tf.nn.relu(batch_normal(fully_connect(d1, output_size=7*7*2*64, scope='gen_fully2'), scope='gen_bn2')) 171 | 172 | d2 = tf.reshape(d2, [self.batch_size, c1, c1, 64 * 2]) 173 | d2 = conv_cond_concat(d2, yb) 174 | 175 | d3 = tf.nn.relu(batch_normal(de_conv(d2, output_shape=[self.batch_size, c2, c2, 128], name='gen_deconv1'), scope='gen_bn3')) 176 | 177 | d3 = conv_cond_concat(d3, yb) 178 | 179 | d4 = de_conv(d3, output_shape=[self.batch_size, self.output_size, self.output_size, self.channel], 180 | name='gen_deconv2', initializer = xavier_initializer()) 181 | 182 | return tf.nn.sigmoid(d4) 183 | 184 | def dis_net(self, images, y, reuse=False): 185 | 186 | with tf.variable_scope("discriminator") as scope: 187 | 188 | if reuse == True: 189 | scope.reuse_variables() 190 | 191 | # mnist data's shape is (28 , 28 , 1) 192 | yb = tf.reshape(y, shape=[self.batch_size, 1, 1, self.y_dim]) 193 | # concat 194 | concat_data = conv_cond_concat(images, yb) 195 | 196 | conv1, w1 = conv2d(concat_data, output_dim=10, name='dis_conv1') 197 | tf.add_to_collection('weight_1', w1) 198 | 199 | conv1 = lrelu(conv1) 200 | conv1 = conv_cond_concat(conv1, yb) 201 | tf.add_to_collection('ac_1', conv1) 202 | 203 | 204 | conv2, w2 = conv2d(conv1, output_dim=64, name='dis_conv2') 205 | tf.add_to_collection('weight_2', w2) 206 | 207 | conv2 = lrelu(batch_normal(conv2, scope='dis_bn1')) 208 | tf.add_to_collection('ac_2', conv2) 209 | 210 | conv2 = tf.reshape(conv2, [self.batch_size, -1]) 211 | conv2 = tf.concat([conv2, y], 1) 212 | 213 | f1 = lrelu(batch_normal(fully_connect(conv2, output_size=1024, scope='dis_fully1'), scope='dis_bn2', reuse=reuse)) 214 | f1 = tf.concat([f1, y], 1) 215 | 216 | out = fully_connect(f1, output_size=1, scope='dis_fully2', initializer = xavier_initializer()) 217 | 218 | return tf.nn.sigmoid(out), out 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | --------------------------------------------------------------------------------
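A note on the conditioning mechanism used throughout `ops.py` and `model_mnist.py`: the one-hot label vector `y` is reshaped to `(batch, 1, 1, y_dim)`, tiled over the spatial grid of a feature map, and concatenated on the channel axis (`conv_cond_concat`), so both the generator and the discriminator see the class label at each stage of the network. The sketch below reproduces that broadcasting step in plain NumPy so it can be checked without TensorFlow; the function name and the batch of random "images" are illustrative assumptions, not part of the repository code.

    import numpy as np

    def conv_cond_concat_np(x, yb):
        # Mirror ops.conv_cond_concat: tile the (N, 1, 1, y_dim) label tensor
        # over the (N, H, W, C) feature map and concatenate on the channel axis.
        n, h, w, _ = x.shape
        y_tiled = yb * np.ones((n, h, w, yb.shape[3]), dtype=x.dtype)
        return np.concatenate([x, y_tiled], axis=3)

    # Illustrative shapes: a batch of 4 single-channel 28x28 "images"
    # conditioned on 10-class one-hot labels (hypothetical label values).
    x = np.random.rand(4, 28, 28, 1).astype(np.float32)
    labels = np.eye(10, dtype=np.float32)[[3, 1, 4, 1]]   # one-hot rows
    yb = labels.reshape(4, 1, 1, 10)                       # (N, 1, 1, y_dim)

    out = conv_cond_concat_np(x, yb)
    print(out.shape)   # (4, 28, 28, 11): 1 image channel + 10 label channels

The same label layout drives `sample_label()` in `utils.py`, which builds the 64 one-hot vectors (eight per digit, digits 0 through 7) used for the sample grids written during training and testing.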