├── intro.jpg ├── Results ├── bag.jpg ├── digit.jpg ├── joint1.jpg ├── joint2.jpg ├── tran3.jpg └── tran4.jpg ├── ops.py ├── utils.py ├── main.py ├── README.md ├── module.py └── model.py /intro.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuangyuanHao/MIXGAN/HEAD/intro.jpg -------------------------------------------------------------------------------- /Results/bag.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuangyuanHao/MIXGAN/HEAD/Results/bag.jpg -------------------------------------------------------------------------------- /Results/digit.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuangyuanHao/MIXGAN/HEAD/Results/digit.jpg -------------------------------------------------------------------------------- /Results/joint1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuangyuanHao/MIXGAN/HEAD/Results/joint1.jpg -------------------------------------------------------------------------------- /Results/joint2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuangyuanHao/MIXGAN/HEAD/Results/joint2.jpg -------------------------------------------------------------------------------- /Results/tran3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuangyuanHao/MIXGAN/HEAD/Results/tran3.jpg -------------------------------------------------------------------------------- /Results/tran4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuangyuanHao/MIXGAN/HEAD/Results/tran4.jpg -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import tensorflow as tf 4 | import tensorflow.contrib.slim as slim 5 | from tensorflow.python.framework import ops 6 | 7 | 8 | def batch_norm(x, name="batch_norm"): 9 | return tf.contrib.layers.batch_norm(x, decay=0.9, updates_collections=None, epsilon=1e-5, scale=True, scope=name) 10 | 11 | def instance_norm(input, name="instance_norm"): 12 | with tf.variable_scope(name): 13 | depth = input.get_shape()[3] 14 | scale = tf.get_variable("scale", [depth], initializer=tf.random_normal_initializer(1.0, 0.02, dtype=tf.float32)) 15 | offset = tf.get_variable("offset", [depth], initializer=tf.constant_initializer(0.0)) 16 | mean, variance = tf.nn.moments(input, axes=[1,2], keep_dims=True) 17 | epsilon = 1e-5 18 | inv = tf.rsqrt(variance + epsilon) 19 | normalized = (input-mean)*inv 20 | return scale*normalized + offset 21 | 22 | def conv2d(input_, output_dim, ks=4, s=2, stddev=0.02, padding='SAME', name="conv2d"): 23 | with tf.variable_scope(name): 24 | return slim.conv2d(input_, output_dim, ks, s, padding=padding, activation_fn=None, 25 | weights_initializer=tf.truncated_normal_initializer(stddev=stddev), 26 | biases_initializer=None) 27 | 28 | def deconv2d(input_, output_dim, ks=4, s=2, stddev=0.02, name="deconv2d"): 29 | with tf.variable_scope(name): 30 | return slim.conv2d_transpose(input_, output_dim, ks, s, padding='SAME', activation_fn=None, 31 | weights_initializer=tf.truncated_normal_initializer(stddev=stddev), 32 | biases_initializer=None) 33 | 34 | def lrelu(x, leak=0.2, name="lrelu"): 35 | return tf.maximum(x, leak*x) 36 | 37 | def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False): 38 | 39 | with tf.variable_scope(scope or "Linear"): 40 | matrix = tf.get_variable("Matrix", [input_.get_shape()[-1], output_size], tf.float32, 41 | tf.random_normal_initializer(stddev=stddev)) 42 | bias = tf.get_variable("bias", [output_size], 43 | initializer=tf.constant_initializer(bias_start)) 44 | if with_w: 45 | return tf.matmul(input_, matrix) + bias, matrix, bias 46 | else: 47 | return tf.matmul(input_, matrix) + bias 48 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import math 3 | import json 4 | import random 5 | import pprint 6 | import scipy.misc 7 | import numpy as np 8 | import copy 9 | import os 10 | from scipy.io import loadmat as load 11 | import numpy as np 12 | import scipy 13 | from PIL import Image 14 | import cv2 15 | 16 | def load_data(array): 17 | n =array.shape[0] 18 | # print("n",n) 19 | size = array[0].shape 20 | # print(size) 21 | imgA = array[0].reshape(1,size[0],size[1],size[2]) 22 | # print(imgB.shape) 23 | for i in range(n-1): 24 | # print(array[i + 1][0].reshape(1, size[0], size[1], size[2]).shape) 25 | imgA = np.concatenate((imgA,array[i+1].reshape(1,size[0],size[1],size[2])),axis=0) 26 | # print(("[%d]"%(i)),imgB.shape) 27 | return imgA/127.5-1.0 28 | def rgb2gray(rgb): 29 | return np.dot(rgb[...,:3],[0.299, 0.587, 0.144]) 30 | # imgA = load_data(q).astype(np.float32).reshape(32,32,3) 31 | # imgA = rgb2gray(imgA) 32 | 33 | 34 | # print(cannyA) 35 | # print(np.max(imgA[1])) 36 | 37 | def load_label(array): 38 | n =array.shape[0] 39 | hot_code = np.zeros(10).reshape(1,10) 40 | hot_code[0][array[0][1]]=1 41 | labelA = hot_code 42 | # print(hot_code) 43 | for i in range(n-1): 44 | hot_code = np.zeros(10).reshape(1, 10) 45 | hot_code[0][array[i+1][1]] = 1 46 | labelA = np.concatenate((labelA,hot_code),axis=0) 47 | # print(("[%d]"%(i)),labelA.shape) 48 | return labelA 49 | # labelA= load_label(q) 50 | # print(labelA) 51 | 52 | # ____________________________________________________ 53 | 54 | def save_images(image, size, path): 55 | return imsave(inverse_transform(image), size, path) 56 | 57 | def imsave(image, size, path): 58 | return scipy.misc.imsave(path, merge(image, size), format='png') 59 | 60 | def merge(image, size): 61 | # print(type(image)) 62 | [n, h, w, c] = image.shape 63 | # print(n, h, w, c) 64 | image = image.reshape(n * h, w, c).astype(np.float) 65 | if c == 1: 66 | image = image.reshape(n * h, w) 67 | img = image[:h * size[0]] 68 | # print(img.shape) 69 | # print(size[1]) 70 | for i in range(size[1] - 1): 71 | img = np.concatenate((img, image[(i + 1) * h * size[0]:(i + 2) * h * size[0]]), axis=1) 72 | # print(img.shape) 73 | # print(img.shape) 74 | return img 75 | 76 | def inverse_transform(image): 77 | return (image+1.)/2. 78 | 79 | 80 | 81 | # if __name__ == '__main__': 82 | # 83 | # pass 84 | 85 | 86 | # python svhnmat.py -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import scipy.misc 4 | import numpy as np 5 | import tensorflow as tf 6 | from model import mixgan 7 | 8 | parser = argparse.ArgumentParser(description='') 9 | parser.add_argument('--dataset_dir', dest='dataset_dir', default='m2v', help='path of the dataset') 10 | parser.add_argument('--epoch', dest='epoch', type=int, default=300, help='# of epoch') 11 | parser.add_argument('--batch_size', dest='batch_size', type=int, default=64, help='# images in batch') 12 | parser.add_argument('--z_dim', dest='z_dim', type=int, default=2048, help='z dim') 13 | parser.add_argument('--train_size', dest='train_size', type=int, default=1e8, help='# images used to train') 14 | parser.add_argument('--load_size', dest='load_size', type=int, default=64, help='scale images to this size') 15 | parser.add_argument('--fine_size', dest='fine_size', type=int, default=64, help='then crop to this size') 16 | parser.add_argument('--ngf', dest='ngf', type=int, default=32, help='# of gen filters in first conv layer') 17 | parser.add_argument('--ndf', dest='ndf', type=int, default=64, help='# of discri filters in first conv layer') 18 | parser.add_argument('--input_nc', dest='input_nc', type=int, default=1, help='# of input image channels') 19 | parser.add_argument('--output_nc', dest='output_nc', type=int, default=3, help='# of output image channels') 20 | parser.add_argument('--niter', dest='niter', type=int, default=200, help='# of iter at starting learning rate') 21 | parser.add_argument('--lr', dest='lr', type=float, default=0.0002, help='initial learning rate for adam') 22 | parser.add_argument('--beta1', dest='beta1', type=float, default=0.5, help='momentum term of adam') 23 | parser.add_argument('--flip', dest='flip', type=bool, default=True, help='if flip the images for data argumentation') 24 | parser.add_argument('--which_direction', dest='which_direction', default='AtoB', help='AtoB or BtoA') 25 | parser.add_argument('--phase', dest='phase', default='train', help='train, test') 26 | parser.add_argument('--save_epoch_freq', dest='save_epoch_freq', type=int, default=50, help='save a model every save_epoch_freq epochs (does not overwrite previously saved models)') 27 | parser.add_argument('--save_latest_freq', dest='save_latest_freq', type=int, default=5000, help='save the latest model every latest_freq sgd iterations (overwrites the previous latest model)') 28 | parser.add_argument('--print_freq', dest='print_freq', type=int, default=50, help='print the debug information every print_freq iterations') 29 | parser.add_argument('--continue_train', dest='continue_train', type=bool, default=False, help='if continue training, load the latest model: 1: true, 0: false') 30 | parser.add_argument('--serial_batches', dest='serial_batches', type=bool, default=False, help='f 1, takes images in order to make batches, otherwise takes them randomly') 31 | parser.add_argument('--serial_batch_iter', dest='serial_batch_iter', type=bool, default=True, help='iter into serial image list') 32 | parser.add_argument('--checkpoint_dir', dest='checkpoint_dir', default='./checkpoint', help='models are saved here') 33 | parser.add_argument('--sample_dir', dest='sample_dir', default='./sample', help='samples are saved here') 34 | parser.add_argument('--log_dir', dest='log_dir', default='./logs', help='logs are saved here') 35 | parser.add_argument('--test_dir', dest='test_dir', default='./test', help='test samples are saved here') 36 | parser.add_argument('--L1_lambda', dest='L1_lambda', type=float, default=10.0, help='weight on L1 term in objective') 37 | parser.add_argument('--use_resnet', dest='use_resnet', type=bool, default=True, help='generation network using reidule block') 38 | parser.add_argument('--use_lsgan', dest='use_lsgan', type=bool, default=True, help='gan loss defined in lsgan') 39 | parser.add_argument('--max_size', dest='max_size', type=int, default=50, help='max size of image pool, 0 means do not use image pool') 40 | 41 | args = parser.parse_args() 42 | 43 | def main(_): 44 | if not os.path.exists(args.checkpoint_dir): 45 | os.makedirs(args.checkpoint_dir) 46 | if not os.path.exists(args.sample_dir): 47 | os.makedirs(args.sample_dir) 48 | if not os.path.exists(args.test_dir): 49 | os.makedirs(args.test_dir) 50 | 51 | with tf.Session() as sess: 52 | model = mixgan(sess, args) 53 | model.train(args) if args.phase == 'train' \ 54 | else model.test(args) 55 | 56 | if __name__ == '__main__': 57 | tf.app.run() 58 | #export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH 59 | # CUDA_VISIBLE_DEVICES=0 python main1.py local grus0 60 | 61 | # tensorboard --port=6031 --logdir=./logs1 62 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MIXGAN: Learning Concepts from Different Domains for Mixture Generation 2 | ## The paper was accepted by International Joint Conference on Artificial Intelligence, 2018. 3 | 4 | ## Introduction 5 | 6 | When you are looking at a red T-shirt in eBay, you could easily imagine how you would look like when you wear it: you know well the shape of your body, you have an image of this red T-shirt in your mind, and thus you can wear it in your imaginary world. However, can a learning machine do a job like this? This means that the machine should have the ability to learn from different domains (e.g., people and T-shirts) and extract some specific knowledge from them, respectively (e.g., people’s body shapes and T-shirts’ color style). Then, it is expected to join the specific kinds of knowledge and thereby generate a brand new domain (e.g., imagination on wearing the T-shirt). In realistic applications, this allows more flexible generation strategies than the conventional generation problem where only one source domain exists and the generated domain is expected to be the same as the source. For instance, as illustrated in the right part in Figure 1, a machine learns to generate images in a new domain (i.e., colorful handbags) with shape style of one domain (black-white handbags) and color style of another domain (colorful shoes) By this way, it helps bring new ideas and provides visualizations for the handbag designers who only have some raw ideas about designing handbags of new color styles. 7 | ![intro](https://github.com/GuangyuanHao/MIXGAN/raw/master/intro.jpg) 8 | Unfortunately, as illustrated in the left part in figure above, existing GAN-based methods, e.g., GAN, CoGAN, ACGAN and BEGAN, are restricted to generate new samples similar to the ones from source domains. The style transfer and image-to-image translation models, e.g., CycleGAN, pix2pix, and UNIT, can transfer a single image or samples in one domain to an image with style similar to another image or samples in another domain. However, these models are restricted in that they require certain input images as subjects to be translated, and thus cannot deal with the problems requiring going beyond the available subjects (e.g., designing new shape styles of handbags in the above situation). Therefore, the problem of learning from different domains for generating a new domain still remains an open issue. 9 | 10 | To explicitly learn different types of knowledge from different domains, in this work we focus on learning global structural information(e.g., shape of an object like a handbags) from one domain, and learning local structural information (e.g., color style another object like a shoes) from another. Learning global information is straightforward: a simple auto-encoder structure can be leveraged to capture and encode the global information, since the auto-encoder focuses on the whole images. As for learning local structural information, we propose to learn from small patches of the source images. By this way, our model has the ability to additionally focus on the specific local patterns that appear in patches, and thus can also be effective in capturing the local structural information. 11 | 12 | We evaluate our model on several tasks. The experimental results show that our model can learn to generate images in a new domain, e.g., generating digits whose shape style and color style are learned from two different domains, i.e., the MNIST dataset and the SVHN dataset. The main contributions are as follows: 13 | 14 | 1. We propose an unsupervised method to absorb different concepts, i.e., global structural information and local structural information, from different domains. 15 | 2. We build a model to learn a distribution to describe a new domain with global structural information of one domain and local structural information of another domain by focusing on learning global and local structural information at the scale of a whole image from one domain and at the scale of small patches from another domain respectively. 16 | 3. We show our model can successfully learn to generate mixed-style samples with shape style of one dataset and color style of another dataset in several tasks. 17 | 18 | ## Results 19 | ![digit](https://github.com/GuangyuanHao/MIXGAN/raw/master/Results/digit.jpg) 20 | ![bag](https://github.com/GuangyuanHao/MIXGAN/raw/master/Results/bag.jpg) 21 | ## Other Applications 22 | ### Learning a Joint Distribution of Two Domains 23 | Our model can also learn a Joint Distribution of Two Domains successfully as CoGAN did, when two datasets have similar global structural information. 24 | #### Results 25 | ![joint1](https://github.com/GuangyuanHao/MIXGAN/raw/master/Results/joint1.jpg) 26 | ![joint2](https://github.com/GuangyuanHao/MIXGAN/raw/master/Results/joint2.jpg) 27 | ### Image-to-Image-Translation 28 | Our model can also accomplish image-to-image translation tasks successfully as UNIT did, when two datasets have similar global structural information. 29 | #### Results 30 | ![tran3](https://github.com/GuangyuanHao/MIXGAN/raw/master/Results/tran3.jpg) 31 | ![tran4](https://github.com/GuangyuanHao/MIXGAN/raw/master/Results/tran4.jpg) 32 | -------------------------------------------------------------------------------- /module.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import tensorflow as tf 3 | from ops import * 4 | 5 | def discriminator(image, options, reuse=False, name="discriminator"): 6 | 7 | with tf.variable_scope(name): 8 | # image is 64 x 64 x input_c_dim 9 | if reuse: 10 | tf.get_variable_scope().reuse_variables() 11 | else: 12 | assert tf.get_variable_scope().reuse == False 13 | 14 | h0 = lrelu(conv2d(image, options.df_dim,ks=4,s=1, name='d_h0_conv')) 15 | # h0 is (64 x 64 x self.df_dim) 16 | h1 = lrelu(instance_norm(conv2d(h0, options.df_dim*2,ks=2,s=1, name='d_h1_conv'), 'd_bn1')) 17 | # h1 is (32 x 32 x self.df_dim*2) 18 | h2 = lrelu(instance_norm(conv2d(h1, options.df_dim*4,ks=1, s=1,name='d_h2_conv'), 'd_bn2')) 19 | # h2 is (32x 32 x self.df_dim*4) 20 | h4 = conv2d(h2, 1, s=1, name='d_h3_pred') 21 | # h4 is (32 x 32 x 1) 22 | return h4 23 | def dis_z(z, options, reuse=False, name="dis_z"): 24 | 25 | with tf.variable_scope(name): 26 | if reuse: 27 | tf.get_variable_scope().reuse_variables() 28 | else: 29 | assert tf.get_variable_scope().reuse == False 30 | z = linear(z, options.gf_dim *32*32) 31 | z_ = tf.reshape(z, [-1, 32, 32, options.gf_dim]) 32 | h0 = lrelu(conv2d(z_, options.df_dim, ks=4,s=2, name='d_h0_conv')) 33 | h1 = lrelu(instance_norm(conv2d(h0, options.df_dim*2,ks=4,s=2, name='d_h1_conv'), 'd_bn1')) 34 | h2 = lrelu(instance_norm(conv2d(h1, options.df_dim*4,ks=4,s=2, name='d_h2_conv'), 'd_bn2')) 35 | h3 = lrelu(instance_norm(conv2d(h2, options.df_dim*8, ks=4,s=2, name='d_h3_conv'), 'd_bn3')) 36 | h3=tf.reshape(h3, [-1, options.df_dim*8*4]) 37 | h4 = linear(h3, 1, 'd_h4_lin') 38 | return h4 39 | 40 | 41 | def encoder(image, options, reuse=False, name="encoder"): 42 | 43 | with tf.variable_scope(name): 44 | if reuse: 45 | tf.get_variable_scope().reuse_variables() 46 | else: 47 | assert tf.get_variable_scope().reuse == False 48 | # print("encoder",image.get_shape()) 49 | image_size = int(options.image_size) 50 | c0 = tf.pad(image, [[0, 0], [3, 3], [3, 3], [0, 0]], "REFLECT") 51 | c1 = tf.nn.relu(conv2d(c0, options.gf_dim, 7, 1, padding='VALID', name='g_e1_c')) 52 | c2 = tf.nn.relu(instance_norm(conv2d(c1, options.gf_dim * 2, 3, 2, name='g_e2_c'), 'g_e2_bn')) 53 | c3 = tf.nn.relu(instance_norm(conv2d(c2, options.gf_dim * 4, 3, 2, name='g_e3_c'), 'g_e3_bn')) 54 | c4 = tf.nn.relu(instance_norm(conv2d(c3, options.gf_dim * 4, 3, 2, name='g_e4_c'), 'g_e4_bn')) 55 | c5 = tf.nn.relu(instance_norm(conv2d(c4, options.gf_dim * 8, 3, 2, name='g_e5_c'), 'g_e5_bn')) 56 | d4 = tf.reshape(c5, [-1,options.gf_dim *8*int((image_size/16)**2)]) 57 | z1 = linear(d4, options.gf_dim * 4 * ((image_size/16)**2), scope='l1') 58 | z1 = tf.nn.relu(batch_norm(z1, 'g_z_bn1')) 59 | z2 = linear(z1, options.gf_dim * 4 * ((image_size/16)**2), scope='l2') 60 | z2 = tf.nn.relu(batch_norm(z2, 'g_z_bn2')) 61 | z3 = linear(z2, options.gf_dim * 2 * ((image_size/16)**2), scope='l3') 62 | z3 = tf.nn.relu(batch_norm(z3, 'g_z_bn3')) 63 | z4 = linear(z3, options.z_dim) 64 | return z4 65 | 66 | def decoder(z, options, reuse=False, name="generator"): 67 | 68 | with tf.variable_scope(name): 69 | if reuse: 70 | tf.get_variable_scope().reuse_variables() 71 | else: 72 | assert tf.get_variable_scope().reuse == False 73 | 74 | def residule_block(x, dim, ks=3, s=1, name='res'): 75 | p = int((ks - 1) / 2) 76 | y = tf.pad(x, [[0, 0], [p, p], [p, p], [0, 0]], "REFLECT") 77 | y = instance_norm(conv2d(y, dim, ks, s, padding='VALID', name=name+'_c1'), name+'_bn1') 78 | y = tf.pad(tf.nn.relu(y), [[0, 0], [p, p], [p, p], [0, 0]], "REFLECT") 79 | y = instance_norm(conv2d(y, dim, ks, s, padding='VALID', name=name+'_c2'), name+'_bn2') 80 | return y + x 81 | 82 | image_size = int(options.image_size) 83 | z1 = linear(z, options.gf_dim * 2 * ((image_size / 16) ** 2), scope='l1') 84 | z1 = tf.nn.relu(batch_norm(z1, 'g_z_bn1')) 85 | z2 = linear(z1, options.gf_dim * 4 * ((image_size / 16) ** 2), scope='l2') 86 | z2 = tf.nn.relu(batch_norm(z2, 'g_z_bn2')) 87 | z3 = linear(z2, options.gf_dim * 4 * ((image_size / 16) ** 2), scope='l3') 88 | z3 = tf.nn.relu(batch_norm(z3, 'g_z_bn3')) 89 | z4 = linear(z3, options.gf_dim * 8 * ((image_size / 16) ** 2), scope='l4') 90 | c0 = tf.reshape(z4, [-1, int(image_size / 16), int(image_size / 16), options.gf_dim * 8]) 91 | c0 = tf.nn.relu(instance_norm(c0, 'g_z_bn')) 92 | c1 = deconv2d(c0, options.gf_dim * 8, 3, 2, name='g_c1_dc') # 16 93 | c1 = tf.nn.relu(instance_norm(c1, 'g_c1_bn')) 94 | c2 = deconv2d(c1, options.gf_dim * 4, 3, 2, name='g_c2_dc') # 32 95 | c2 = tf.nn.relu(instance_norm(c2, 'g_c2_bn')) 96 | # define G network with 6 resnet blocks 97 | r1 = residule_block(c2, options.gf_dim*4, name='g_r1') 98 | r2 = residule_block(r1, options.gf_dim*4, name='g_r2') 99 | r3 = residule_block(r2, options.gf_dim*4, name='g_r3') 100 | r4 = residule_block(r3, options.gf_dim*4, name='g_r4') 101 | r5 = residule_block(r4, options.gf_dim*4, name='g_r5') 102 | r6 = residule_block(r5, options.gf_dim*4, name='g_r6') 103 | d1 = deconv2d(r6, options.gf_dim*2, 3, 2, name='g_d1_dc') 104 | d1 = tf.nn.relu(instance_norm(d1, 'g_d1_bn')) 105 | d2 = deconv2d(d1, options.gf_dim*1, 3, 2, name='g_d2_dc') 106 | d2 = tf.nn.relu(instance_norm(d2, 'g_d2_bn')) 107 | d2_ = tf.pad(d2, [[0, 0], [3, 3], [3, 3], [0, 0]], "REFLECT") 108 | pred = conv2d(d2_, options.input_c_dim, 7, 1, padding='VALID', name='g_pred_c') 109 | pred = tf.nn.tanh(pred, 'g_pred_bn') 110 | return pred, tf.concat([d2,pred],axis=3) 111 | def generator_a(dp, options, reuse=False, name="generator_y"): 112 | 113 | with tf.variable_scope(name): 114 | if reuse: 115 | tf.get_variable_scope().reuse_variables() 116 | else: 117 | assert tf.get_variable_scope().reuse == False 118 | d1 = deconv2d(dp, options.gf_dim * 2, 3, 1, name='g_d1_dc') 119 | d1 = tf.nn.relu(instance_norm(d1, 'g_d1_bn')) 120 | d2 = tf.concat([d1,dp],axis=3) 121 | d2 = tf.pad(d2, [[0, 0], [3, 3], [3, 3], [0, 0]], "REFLECT") 122 | pred = conv2d(d2, options.output_c_dim, 7, 1, padding='VALID', name='g_pred_c') 123 | pred = tf.nn.tanh(pred, 'g_pred_bn') 124 | return pred 125 | 126 | def abs_criterion(in_, target): 127 | return tf.reduce_mean(tf.abs(in_ - target)) 128 | 129 | def mae_criterion(in_, target): 130 | return tf.reduce_mean((in_-target)**2) 131 | 132 | def sce_criterion(logits, labels): 133 | return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels)) -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | import time 4 | from glob import glob 5 | import tensorflow as tf 6 | import numpy as np 7 | from six.moves import xrange 8 | from collections import namedtuple 9 | #export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH 10 | # CUDA_VISIVLE_DEVICES=1 python main.py 11 | import h5py 12 | from module import * 13 | from utils import * 14 | 15 | 16 | class mixgan(object): 17 | def __init__(self, sess, args): 18 | self.sess = sess 19 | self.batch_size = args.batch_size 20 | self.image_size = args.fine_size 21 | self.z_size = args.z_dim 22 | self.input_c_dim = args.input_nc 23 | self.output_c_dim = args.output_nc 24 | self.L1_lambda = args.L1_lambda 25 | self.dataset_dir = args.dataset_dir 26 | self.dis_z = dis_z 27 | self.encoder =encoder 28 | self.generator = generator_a 29 | self.discriminator = discriminator 30 | self.decoder = decoder 31 | if args.use_lsgan: 32 | self.criterionGAN = mae_criterion 33 | else: 34 | self.criterionGAN = sce_criterion 35 | 36 | OPTIONS = namedtuple('OPTIONS', 'batch_size image_size \ 37 | gf_dim df_dim output_c_dim input_c_dim z_dim') 38 | self.options = OPTIONS._make((args.batch_size, args.fine_size, 39 | args.ngf, args.ndf, args.output_nc, args.input_nc, args.z_dim)) 40 | 41 | self._build_model() 42 | self.saver = tf.train.Saver() 43 | 44 | def _build_model(self): 45 | self.real_A = tf.placeholder(tf.float32, 46 | [None, self.image_size, self.image_size, 47 | self.output_c_dim], 48 | name='real_images_A') 49 | self.real_B = tf.placeholder(tf.float32, 50 | [None, self.image_size, self.image_size, 51 | self.input_c_dim], 52 | name='real_images_B') 53 | 54 | self.z = tf.placeholder(tf.float32,[None, self.z_size], name='noise_z') 55 | self.fake_z = self.encoder(self.real_B, self.options, False, name="encoder") 56 | self.fake_BB, dpb = self.decoder(self.fake_z, self.options, False, name="decoder") 57 | self.test_BB, dpz = self.decoder(self.z, self.options, True, name="decoder") 58 | 59 | self.fake_A = self.generator(dpz, self.options, False, name="generator") 60 | 61 | self.DA_fake = self.discriminator(self.fake_A, self.options, False, name="dis_A") 62 | 63 | self.g_loss = self.criterionGAN(self.DA_fake, tf.ones_like(self.DA_fake)) 64 | self.DA_real = self.discriminator(self.real_A, self.options, True, name="dis_A") 65 | self.dA_loss_real = self.criterionGAN(self.DA_real, tf.ones_like(self.DA_real)) 66 | self.dA_loss_fake = self.criterionGAN(self.DA_fake, tf.zeros_like(self.DA_fake)) 67 | self.dA_loss = (self.dA_loss_real + self.dA_loss_fake) / 2 68 | 69 | self.g_sum = tf.summary.scalar("g_loss", self.g_loss) 70 | self.dA_loss_sum = tf.summary.scalar("dA_loss", self.dA_loss) 71 | self.dA_loss_real_sum = tf.summary.scalar("dA_loss_real", self.dA_loss_real) 72 | self.dA_loss_fake_sum = tf.summary.scalar("dA_loss_fake", self.dA_loss_fake) 73 | self.dA_sum = tf.summary.merge( 74 | [self.dA_loss_sum, self.dA_loss_real_sum, self.dA_loss_fake_sum] 75 | ) 76 | 77 | self.Dz_fake = self.dis_z(self.fake_z, self.options, reuse=False, name="dis_z") 78 | self.en_asg_loss = self.criterionGAN(self.Dz_fake, tf.ones_like(self.Dz_fake)) 79 | self.x_ae_loss = self.L1_lambda * abs_criterion(self.fake_BB, self.real_B) 80 | self.en_loss = self.en_asg_loss + self.x_ae_loss 81 | self.de_g_loss = self.x_ae_loss + self.g_loss 82 | 83 | # + abs_criterion(self.fake_x, self.real_x) 84 | 85 | self.Dz_real = self.dis_z(self.z, self.options, reuse=True, name="dis_z") 86 | self.dz_loss_real = self.criterionGAN(self.Dz_real, tf.ones_like(self.Dz_real)) 87 | self.dz_loss_fake = self.criterionGAN(self.Dz_fake, tf.zeros_like(self.Dz_fake)) 88 | self.dz_loss = (self.dz_loss_real + self.dz_loss_fake) / 2 89 | 90 | self.en_sum = tf.summary.scalar("en_loss", self.en_loss) 91 | self.de_g_sum = tf.summary.scalar("de_g_loss", self.de_g_loss) 92 | self.x_ae_sum = tf.summary.scalar("x_ae_loss", self.x_ae_loss) 93 | self.en_asg_sum = tf.summary.scalar("en_asg_loss", self.en_asg_loss) 94 | self.de_g_summary = tf.summary.merge( 95 | [self.de_g_sum, self.x_ae_sum, self.g_sum] 96 | ) 97 | self.en_summary = tf.summary.merge( 98 | [self.en_sum, self.x_ae_sum, self.en_asg_sum] 99 | ) 100 | 101 | self.dz_loss_sum = tf.summary.scalar("dz_loss", self.dz_loss) 102 | self.dz_loss_real_sum = tf.summary.scalar("dz_loss_real", self.dz_loss_real) 103 | self.dz_loss_fake_sum = tf.summary.scalar("dz_loss_fake", self.dz_loss_fake) 104 | self.dz_sum = tf.summary.merge( 105 | [self.dz_loss_sum, self.dz_loss_real_sum, self.dz_loss_fake_sum] 106 | ) 107 | 108 | t_vars = tf.trainable_variables() 109 | self.dz_vars = [var for var in t_vars if 'dis_z' in var.name] 110 | self.de_vars = [var for var in t_vars if 'decoder' in var.name] 111 | self.en_vars = [var for var in t_vars if 'encoder' in var.name] 112 | self.dA_vars = [var for var in t_vars if 'dis_A' in var.name] 113 | self.g_vars = [var for var in t_vars if 'generator' in var.name] 114 | for var in t_vars: print(var.name) 115 | 116 | def train(self, args): 117 | """Train mixgan""" 118 | self.dz_optim = tf.train.AdamOptimizer(args.lr/50, beta1=args.beta1) \ 119 | .minimize(self.dz_loss, var_list=self.dz_vars) 120 | self.en_optim = tf.train.AdamOptimizer(args.lr, beta1=args.beta1) \ 121 | .minimize(self.en_loss, var_list=self.en_vars) 122 | self.dA_optim = tf.train.AdamOptimizer(args.lr/10, beta1=args.beta1) \ 123 | .minimize(self.dA_loss, var_list=self.dA_vars) #1/2 124 | self.de_g_optim = tf.train.AdamOptimizer(args.lr, beta1=args.beta1) \ 125 | .minimize(self.g_loss, var_list=self.g_vars)#1/1 126 | # .minimize(self.de_g_loss, var_list=[self.g_vars,self.de_vars]) 127 | 128 | 129 | # .minimize(self.de_g_loss, var_list=self.de_vars) 130 | 131 | 132 | 133 | 134 | init_op = tf.global_variables_initializer() 135 | self.sess.run(init_op) 136 | self.writer = tf.summary.FileWriter(args.log_dir, self.sess.graph) 137 | 138 | counter = 1 139 | start_time = time.time() 140 | 141 | if self.load(args.checkpoint_dir): 142 | print(" [*] Load SUCCESS") 143 | else: 144 | print(" [!] Load failed...") 145 | 146 | for epoch in range(args.epoch): 147 | dataA = h5py.File('/home/guangyuan/Downloads/handbag_64.hdf5', 'r')['imgs'] 148 | dataB = h5py.File('/home/guangyuan/Downloads/shoes_64.hdf5', 'r')['imgs'] 149 | 150 | batch_idxs = min(len(dataA), len(dataB), args.train_size) // self.batch_size 151 | for idx in range(0, batch_idxs): 152 | batch_filesA = dataA[idx * self.batch_size: (idx + 1) * self.batch_size] 153 | batch_imagesA = load_data(batch_filesA).astype(np.float32) 154 | batch_filesB = dataB[idx * self.batch_size:(idx + 1) * self.batch_size] 155 | batch_imagesB = load_data(batch_filesB).astype(np.float32) 156 | batch_imagesB1, batch_imagesB2, batch_imagesB3 = np.split(batch_imagesB, 3, axis=3) 157 | batch_imagesB = batch_imagesB1 * 0.114 + batch_imagesB2 * 0.587 + batch_imagesB3 * 0.299 158 | batch_z = np.random.uniform(-1, 1, [self.batch_size, self.z_size]) \ 159 | .astype(np.float32) 160 | if epoch<100: 161 | # Update D network 162 | _, summary_str = self.sess.run([self.dz_optim, self.dz_sum], 163 | feed_dict={self.real_B: batch_imagesB, 164 | self.z: batch_z}) 165 | self.writer.add_summary(summary_str, counter) 166 | 167 | # Update G network 168 | _, summary_str = self.sess.run([self.en_optim, self.en_summary], 169 | feed_dict={self.real_B: batch_imagesB, 170 | self.z: batch_z}) 171 | self.writer.add_summary(summary_str, counter) 172 | if epoch > 99: 173 | # Update D network 174 | _, summary_str = self.sess.run([self.dA_optim, self.dA_sum], 175 | feed_dict={self.real_A: batch_imagesA, 176 | self.z: batch_z}) 177 | self.writer.add_summary(summary_str, counter) 178 | 179 | # Update G network 180 | _, summary_str = self.sess.run([self.de_g_optim, self.de_g_summary], 181 | feed_dict={self.real_A: batch_imagesA, 182 | self.real_B: batch_imagesB, 183 | self.z: batch_z}) 184 | self.writer.add_summary(summary_str, counter) 185 | 186 | counter += 1 187 | print(("Epoch: [%2d] [%4d/%4d] time: %4.4f" \ 188 | % (epoch, idx, batch_idxs, time.time() - start_time))) 189 | 190 | if np.mod(counter, 100) == 1: 191 | self.sample_model(args.sample_dir, epoch, idx) 192 | 193 | 194 | if np.mod(counter, 1000) == 2: 195 | self.save(args.checkpoint_dir, counter) 196 | 197 | def save(self, checkpoint_dir, step): 198 | model_name = "mixgan.model" 199 | model_dir = "%s_%s" % (self.dataset_dir, self.image_size) 200 | checkpoint_dir = os.path.join(checkpoint_dir, model_dir) 201 | 202 | if not os.path.exists(checkpoint_dir): 203 | os.makedirs(checkpoint_dir) 204 | 205 | self.saver.save(self.sess, 206 | os.path.join(checkpoint_dir, model_name), 207 | global_step=step) 208 | 209 | def load(self, checkpoint_dir): 210 | print(" [*] Reading checkpoint...") 211 | 212 | model_dir = "%s_%s" % (self.dataset_dir, self.image_size) 213 | checkpoint_dir = os.path.join(checkpoint_dir, model_dir) 214 | 215 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 216 | if ckpt and ckpt.model_checkpoint_path: 217 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 218 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name)) 219 | return True 220 | else: 221 | return False 222 | 223 | def sample_model(self, sample_dir, epoch, idx): 224 | 225 | 226 | # dataA = h5py.File('/home/guangyuan/Downloads/handbag_64.hdf5', 'r')['imgs'] 227 | dataB = h5py.File('/home/guangyuan/Downloads/shoes_64.hdf5', 'r')['imgs'] 228 | 229 | 230 | # batch_filesA = dataA[idx * self.batch_size: (idx + 1) * self.batch_size] 231 | # batch_imagesA = load_data(batch_filesA).astype(np.float32) 232 | batch_filesB = dataB[idx * self.batch_size:(idx + 1) * self.batch_size] 233 | batch_imagesB = load_data(batch_filesB).astype(np.float32) 234 | batch_imagesB1, batch_imagesB2, batch_imagesB3 = np.split(batch_imagesB, 3, axis=3) 235 | batch_imagesB = batch_imagesB1 * 0.114 + batch_imagesB2 * 0.587 + batch_imagesB3 * 0.299 236 | batch_z = np.random.uniform(-1, 1, [self.batch_size, self.z_size]) \ 237 | .astype(np.float32) 238 | [fake_A,fake_BB,test_BB] = self.sess.run([self.fake_A,self.fake_BB,self.test_BB], 239 | feed_dict={self.real_B: batch_imagesB, 240 | self.z: batch_z}) 241 | 242 | save_images(fake_A, [int(self.batch_size/8), 8], 243 | './{}/zA_{:02d}_{:04d}.jpg'.format(sample_dir, epoch, idx)) 244 | save_images(test_BB, [int(self.batch_size/8), 8], 245 | './{}/zBB_{:02d}_{:04d}.jpg'.format(sample_dir, epoch, idx)) 246 | save_images(fake_BB, [int(self.batch_size/8), 8], 247 | './{}/rBB_{:02d}_{:04d}.jpg'.format(sample_dir, epoch, idx)) 248 | 249 | def test(self, args): 250 | init_op = tf.global_variables_initializer() 251 | self.sess.run(init_op) 252 | if self.load(args.checkpoint_dir): 253 | print(" [*] Load SUCCESS") 254 | else: 255 | print(" [!] Load failed...") 256 | for k in range(100): 257 | batch_z = np.random.uniform(-1, 1, [self.batch_size, self.z_size]) \ 258 | .astype(np.float32) 259 | [fake_A] = self.sess.run([self.fake_A], 260 | feed_dict={self.z: batch_z}) 261 | save_images(fake_A, [int(np.sqrt(self.batch_size)), int(np.sqrt(self.batch_size))], 262 | './{}/test_G_{:2d}.jpg'.format(args.test_dir, k)) 263 | 264 | 265 | --------------------------------------------------------------------------------