├── .gitignore
├── README.md
├── dataset.py
├── model.py
├── ops.py
├── original.jpg
└── train.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.pyc
datasets/
model/
output_imgs/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# pix2pix-keras-tensorflow

Keras and TensorFlow hybrid implementation of [Image-to-Image Translation Using Conditional Adversarial Networks](https://arxiv.org/pdf/1611.07004v1.pdf), which learns a mapping from input images to output images.
This implementation follows the original paper and the authors' implementation as closely as possible.

The examples from the paper:
![examples](original.jpg)


## Setup

### Prerequisites

- Software
  - python2.7
  - tensorflow==0.12.0
  - keras==1.2.0
  - numpy==1.11.3
  - scipy==0.18.1
  - matplotlib==1.5.3
  - progressbar2==3.12.0

- Hardware
  - NVIDIA GPU (highly recommended)

### Install

- Clone this repo to your PC.

```bash
$ git clone https://github.com/makora9143/pix2pix-keras-tensorflow.git
$ cd pix2pix-keras-tensorflow
```

### Usage (WIP)

- To train the model, run the command below. (Training takes a few hours.)
- [dataset] = facades / cityscapes / maps / edges2shoes / edges2handbags
```bash
$ python train.py -d [dataset]
```
- Generated sample images are saved in the `output_imgs` directory.
If you want to generate some images, run this command:

```bash
$ python test.py
```
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import os
import threading
from glob import glob

import numpy as np
import tensorflow as tf

from ops import *


def download(dataset_name):
    datasets_dir = './datasets/'
    mkdir(datasets_dir)
    URL = 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/%s.tar.gz' % (dataset_name)
    TAR_FILE = './datasets/%s.tar.gz' % (dataset_name)
    TARGET_DIR = './datasets/%s/' % (dataset_name)
    os.system('wget -N %s -O %s' % (URL, TAR_FILE))
    mkdir(TARGET_DIR)
    os.system('tar -zxf %s -C ./datasets/' % (TAR_FILE))
    os.remove(TAR_FILE)


class Dataset(object):
    def __init__(self, dataset, is_test=False, batch_size=4, crop_width=256, thread_num=1):
        self.batch_size = batch_size
        self.thread_num = thread_num
        print "Batch size: %d, Thread num: %d" % (batch_size, thread_num)
        datasetDir = './datasets/{}'.format(dataset)
        if not os.path.isdir(datasetDir):
            download(dataset)
        dataDir = datasetDir + '/train'
        data = glob(dataDir + '/*.jpg')
        # The pix2pix datasets name their images 1.jpg, 2.jpg, ...; cap the
        # in-memory dataset at 400 images.
        self.data_size = min(400, len(data))
        self.data_indice = range(self.data_size - 1)
        self.dataDir = dataDir
        self.is_test = is_test
        self.dataset = []
        for i in range(1, self.data_size):
            img, label = load_image(self.dataDir + '/%d.jpg' % i)
            self.dataset.append((img, label))
        print "load dataset done"
        print 'data size: %d' % len(self.dataset)
        self.img_shape = list(self.dataset[0][0].shape)
        self.label_shape = list(self.dataset[0][1].shape)
        self.fine_size = self.img_shape[0]
        self.crop_width = self.fine_size
        self.load_size = self.fine_size + 30

        # Feed batches through a FIFO queue so preprocessing (resize, random
        # crop, flip) runs in background threads while the GPU trains.
        self.img_data = tf.placeholder(tf.float32, shape=[None] + self.img_shape)
        self.label_data = tf.placeholder(tf.float32, shape=[None] + self.label_shape)
        self.queue = tf.FIFOQueue(shapes=[self.label_shape, self.img_shape],
                                  dtypes=[tf.float32, tf.float32],
                                  capacity=2000)
        self.enqueue_ops = self.queue.enqueue_many([self.label_data, self.img_data])

    def batch_iterator(self):
        while True:
            shuffle_indices = np.random.permutation(self.data_indice)
            for i in range(len(self.data_indice) / self.batch_size):
                img_batch = []
                label_batch = []
                for j in range(i * self.batch_size, (i + 1) * self.batch_size):
                    label = self.dataset[shuffle_indices[j]][1]
                    img = self.dataset[shuffle_indices[j]][0]
                    img, label = img_preprocess(img, label, self.fine_size, self.load_size)
                    label_batch.append(label)
                    img_batch.append(img)
                yield np.array(label_batch), np.array(img_batch)

    def get_inputs(self):
        labels, imgs = self.queue.dequeue_many(self.batch_size)
        return labels, imgs

    def thread_main(self, sess):
        for labels, imgs in self.batch_iterator():
            sess.run(self.enqueue_ops, feed_dict={self.label_data: labels, self.img_data: imgs})

    def start_threads(self, sess):
        threads = []
        for n in range(self.thread_num):
            t = threading.Thread(target=self.thread_main, args=(sess,))
            t.daemon = True
            t.start()
            threads.append(t)
        return threads

    def get_size(self):
        return self.data_size

    def get_shape(self):
        return self.img_shape, self.label_shape
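
A minimal sketch of how this queue-based pipeline is meant to be consumed (train.py below does the same thing at scale); the `facades` choice is just an example:

```python
import tensorflow as tf
from dataset import Dataset

sess = tf.Session()
data = Dataset(dataset='facades', batch_size=1, thread_num=1)  # downloads on first use

# Symbolic dequeue: each sess.run pulls one preprocessed (label, image)
# batch that the background threads pushed into the FIFOQueue.
train_X, train_y = data.get_inputs()
data.start_threads(sess)

labels, imgs = sess.run([train_X, train_y])
print labels.shape, imgs.shape  # (1, 256, 256, 3) each for the facades set
```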
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import numpy as np

from keras.models import Sequential
from keras import initializations
from keras.layers import Activation, Flatten, Dropout, merge
from keras.layers.convolutional import Convolution2D, Deconvolution2D
from keras.layers.normalization import BatchNormalization
from keras.layers.advanced_activations import LeakyReLU


def my_init(shape, name=None):
    # Gaussian weight init with stddev 0.02, as in the paper (DCGAN-style).
    return initializations.normal(shape, scale=0.02, name=name)


def create_netD(image_width, image_height, input_channel, ndf, filter_size):
    #########################################
    # Discriminator
    #
    # PatchGAN (Sequential)
    #
    # C64-C128-C256-C512
    #
    #########################################
    patchGAN = Sequential()

    # C64 256=>128
    patchGAN.add(Convolution2D(ndf, filter_size, filter_size,
                               subsample=(2, 2),
                               border_mode='same',
                               init=my_init,
                               input_shape=(image_width, image_height, input_channel)))
    patchGAN.add(LeakyReLU(alpha=0.2))

    # C128 128=>64
    patchGAN.add(Convolution2D(ndf * 2, filter_size, filter_size,
                               subsample=(2, 2),
                               init=my_init,
                               border_mode='same',
                               ))
    patchGAN.add(BatchNormalization())
    patchGAN.add(LeakyReLU(alpha=0.2))

    # C256 64=>32
    patchGAN.add(Convolution2D(ndf * 4, filter_size, filter_size,
                               subsample=(2, 2),
                               init=my_init,
                               border_mode='same',
                               ))
    patchGAN.add(BatchNormalization())
    patchGAN.add(LeakyReLU(alpha=0.2))

    # C512 32=>32 (stride 1)
    patchGAN.add(Convolution2D(ndf * 8, filter_size, filter_size,
                               subsample=(1, 1),
                               init=my_init,
                               border_mode='same',
                               ))
    patchGAN.add(BatchNormalization())
    patchGAN.add(LeakyReLU(alpha=0.2))

    # 1-channel sigmoid output: one real/fake probability per patch
    patchGAN.add(Convolution2D(1, filter_size, filter_size,
                               subsample=(1, 1),
                               init=my_init,
                               border_mode='same',
                               ))
    patchGAN.add(Activation('sigmoid'))
    patchGAN.add(Flatten())
    return patchGAN
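
# Note: with 256x256 inputs and the default 4x4 filters, the stride pattern
# above (2, 2, 2, 1, 1) shrinks the feature maps 256->128->64->32->32->32, so
# the flattened sigmoid output is a 32x32 grid of real/fake probabilities
# rather than a single scalar; each unit has a 70x70 receptive field on the
# input (the 70x70 PatchGAN of the paper).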

def create_netG(train_X, tmp_x, ngf, filter_size, image_width, image_height, input_channel, output_channel, batch_size):
    encoder_decoder = []

    # encoder
    # C64 256=>128
    enc_conv0 = Convolution2D(ngf, filter_size, filter_size,
                              subsample=(2, 2),
                              border_mode='same',
                              init=my_init,
                              input_shape=(image_width, image_height, input_channel)
                              )
    enc_output0 = enc_conv0(train_X)
    enc_output0_ = enc_conv0(tmp_x)

    encoder_decoder.append(enc_conv0)

    # C128 128=>64
    enc_conv1 = Convolution2D(ngf * 2, filter_size, filter_size,
                              subsample=(2, 2),
                              init=my_init,
                              border_mode='same'
                              )
    enc_bn1 = BatchNormalization(epsilon=1e-5, momentum=0.9)
    leaky_relu1 = LeakyReLU(alpha=0.2)

    encoder_decoder += [enc_conv1, enc_bn1]

    enc_output1 = enc_bn1(enc_conv1(leaky_relu1(enc_output0)))
    enc_output1_ = enc_bn1(enc_conv1(leaky_relu1(enc_output0_)))

    # C256 64=>32
    enc_conv2 = Convolution2D(ngf * 4, filter_size, filter_size,
                              subsample=(2, 2),
                              init=my_init,
                              border_mode='same'
                              )
    enc_bn2 = BatchNormalization(epsilon=1e-5, momentum=0.9)
    leaky_relu2 = LeakyReLU(alpha=0.2)

    encoder_decoder += [enc_conv2, enc_bn2]

    enc_output2 = enc_bn2(enc_conv2(leaky_relu2(enc_output1)))
    enc_output2_ = enc_bn2(enc_conv2(leaky_relu2(enc_output1_)))

    # C512 32=>16
    enc_conv3 = Convolution2D(ngf * 8, filter_size, filter_size,
                              subsample=(2, 2),
                              init=my_init,
                              border_mode='same'
                              )
    enc_bn3 = BatchNormalization(epsilon=1e-5, momentum=0.9)
    leaky_relu3 = LeakyReLU(alpha=0.2)

    encoder_decoder += [enc_conv3, enc_bn3]
    enc_output3 = enc_bn3(enc_conv3(leaky_relu3(enc_output2)))
    enc_output3_ = enc_bn3(enc_conv3(leaky_relu3(enc_output2_)))

    # C512 16=>8
    enc_conv4 = Convolution2D(ngf * 8, filter_size, filter_size,
                              subsample=(2, 2),
                              init=my_init,
                              border_mode='same'
                              )
    enc_bn4 = BatchNormalization(epsilon=1e-5, momentum=0.9)
    leaky_relu4 = LeakyReLU(alpha=0.2)

    encoder_decoder += [enc_conv4, enc_bn4]
    enc_output4 = enc_bn4(enc_conv4(leaky_relu4(enc_output3)))
    enc_output4_ = enc_bn4(enc_conv4(leaky_relu4(enc_output3_)))

    # C512 8=>4
    enc_conv5 = Convolution2D(ngf * 8, filter_size, filter_size,
                              subsample=(2, 2),
                              init=my_init,
                              border_mode='same'
                              )
    enc_bn5 = BatchNormalization(epsilon=1e-5, momentum=0.9)
    leaky_relu5 = LeakyReLU(alpha=0.2)

    encoder_decoder += [enc_conv5, enc_bn5]
    enc_output5 = enc_bn5(enc_conv5(leaky_relu5(enc_output4)))
    enc_output5_ = enc_bn5(enc_conv5(leaky_relu5(enc_output4_)))

    # C512 4=>2
    enc_conv6 = Convolution2D(ngf * 8, filter_size, filter_size,
                              subsample=(2, 2),
                              init=my_init,
                              border_mode='same'
                              )
    enc_bn6 = BatchNormalization(epsilon=1e-5, momentum=0.9)
    leaky_relu6 = LeakyReLU(alpha=0.2)

    encoder_decoder += [enc_conv6, enc_bn6]

    enc_output6 = enc_bn6(enc_conv6(leaky_relu6(enc_output5)))
    enc_output6_ = enc_bn6(enc_conv6(leaky_relu6(enc_output5_)))

    # C512 2=>1
    enc_conv7 = Convolution2D(ngf * 8, filter_size, filter_size,
                              subsample=(2, 2),
                              init=my_init,
                              border_mode='same'
                              )
    enc_bn7 = BatchNormalization(epsilon=1e-5, momentum=0.9)
    leaky_relu7 = LeakyReLU(alpha=0.2)

    encoder_decoder += [enc_conv7, enc_bn7]

    enc_output7 = enc_bn7(enc_conv7(leaky_relu7(enc_output6)))
    enc_output7_ = enc_bn7(enc_conv7(leaky_relu7(enc_output6_)))
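
    # Every layer above is applied to two tensors: enc_output* flows from the
    # queue-fed training batch (train_X), while enc_output*_ flows from the
    # tmp_x placeholder used when sampling images. Calling the same layer
    # objects on both paths makes the two graphs share weights.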

    # decoder
    # CD512 1=>2
    dec_conv0 = Deconvolution2D(ngf * 8, filter_size, filter_size,
                                output_shape=(batch_size, int(np.ceil(image_width / 128.)), int(np.ceil(image_height / 128.)), ngf * 8),
                                subsample=(2, 2),
                                init=my_init,
                                border_mode='same'
                                )
    dec_bn0 = BatchNormalization(epsilon=1e-5, momentum=0.9)
    dropout0 = Dropout(0.5)
    relu0 = Activation('relu')

    encoder_decoder += [dec_conv0, dec_bn0]
    dec_output0 = dropout0(dec_bn0(dec_conv0(relu0(enc_output7))))
    dec_output0 = merge([dec_output0, enc_output6], mode='concat')

    dec_output0_ = dropout0(dec_bn0(dec_conv0(relu0(enc_output7_))))
    dec_output0_ = merge([dec_output0_, enc_output6_], mode='concat')

    # CD512 2=>4
    dec_conv1 = Deconvolution2D(ngf * 8, filter_size, filter_size,
                                output_shape=(batch_size, int(np.ceil(image_width / 64.)), int(np.ceil(image_height / 64.)), ngf * 8),
                                subsample=(2, 2),
                                init=my_init,
                                border_mode='same'
                                )
    dec_bn1 = BatchNormalization(epsilon=1e-5, momentum=0.9)
    dropout1 = Dropout(0.5)
    relu1 = Activation('relu')

    encoder_decoder += [dec_conv1, dec_bn1]
    dec_output1 = dropout1(dec_bn1(dec_conv1(relu1(dec_output0))))
    dec_output1 = merge([dec_output1, enc_output5], mode='concat')

    dec_output1_ = dropout1(dec_bn1(dec_conv1(relu1(dec_output0_))))
    dec_output1_ = merge([dec_output1_, enc_output5_], mode='concat')

    # CD512 4=>8
    dec_conv2 = Deconvolution2D(ngf * 8, filter_size, filter_size,
                                output_shape=(batch_size, int(np.ceil(image_width / 32.)), int(np.ceil(image_height / 32.)), ngf * 8),
                                subsample=(2, 2),
                                init=my_init,
                                border_mode='same'
                                )
    dec_bn2 = BatchNormalization(epsilon=1e-5, momentum=0.9)
    dropout2 = Dropout(0.5)
    relu2 = Activation('relu')

    encoder_decoder += [dec_conv2, dec_bn2]
    dec_output2 = dropout2(dec_bn2(dec_conv2(relu2(dec_output1))))
    dec_output2 = merge([dec_output2, enc_output4], mode='concat')

    dec_output2_ = dropout2(dec_bn2(dec_conv2(relu2(dec_output1_))))
    dec_output2_ = merge([dec_output2_, enc_output4_], mode='concat')

    # C512 8=>16
    dec_conv3 = Deconvolution2D(ngf * 8, filter_size, filter_size,
                                output_shape=(batch_size, int(np.ceil(image_width / 16.)), int(np.ceil(image_height / 16.)), ngf * 8),
                                subsample=(2, 2),
                                init=my_init,
                                border_mode='same'
                                )
    dec_bn3 = BatchNormalization(epsilon=1e-5, momentum=0.9)
    relu3 = Activation('relu')

    encoder_decoder += [dec_conv3, dec_bn3]
    dec_output3 = dec_bn3(dec_conv3(relu3(dec_output2)))
    dec_output3 = merge([dec_output3, enc_output3], mode='concat')

    dec_output3_ = dec_bn3(dec_conv3(relu3(dec_output2_)))
    dec_output3_ = merge([dec_output3_, enc_output3_], mode='concat')

    # C256 16=>32
    dec_conv4 = Deconvolution2D(ngf * 4, filter_size, filter_size,
                                output_shape=(batch_size, int(np.ceil(image_width / 8.)), int(np.ceil(image_height / 8.)), ngf * 4),
                                subsample=(2, 2),
                                init=my_init,
                                border_mode='same'
                                )
    dec_bn4 = BatchNormalization(epsilon=1e-5, momentum=0.9)
    relu4 = Activation('relu')

    encoder_decoder += [dec_conv4, dec_bn4]
    dec_output4 = dec_bn4(dec_conv4(relu4(dec_output3)))
    dec_output4 = merge([dec_output4, enc_output2], mode='concat')

    dec_output4_ = dec_bn4(dec_conv4(relu4(dec_output3_)))
    dec_output4_ = merge([dec_output4_, enc_output2_], mode='concat')

    # C128 32=>64
    dec_conv5 = Deconvolution2D(ngf * 2, filter_size, filter_size,
                                output_shape=(batch_size, int(np.ceil(image_width / 4.)), int(np.ceil(image_height / 4.)), ngf * 2),
                                subsample=(2, 2),
                                init=my_init,
                                border_mode='same'
                                )
    dec_bn5 = BatchNormalization(epsilon=1e-5, momentum=0.9)
    relu5 = Activation('relu')
    encoder_decoder += [dec_conv5, dec_bn5]
    dec_output5 = dec_bn5(dec_conv5(relu5(dec_output4)))
    dec_output5 = merge([dec_output5, enc_output1], mode='concat')

    dec_output5_ = dec_bn5(dec_conv5(relu5(dec_output4_)))
    dec_output5_ = merge([dec_output5_, enc_output1_], mode='concat')
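
    # Each merge() concatenates a decoder activation with its mirrored encoder
    # activation along the channel axis (the U-Net skip connections); the
    # 1024-channel widths in the CD512-CD1024-... decoder spec quoted in
    # train.py come from this doubling, even though each Deconvolution2D here
    # has at most ngf * 8 = 512 filters.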

    # C64 64=>128
    dec_conv6 = Deconvolution2D(ngf, filter_size, filter_size,
                                output_shape=(batch_size, int(np.ceil(image_width / 2.)), int(np.ceil(image_height / 2.)), ngf),
                                subsample=(2, 2),
                                init=my_init,
                                border_mode='same'
                                )
    dec_bn6 = BatchNormalization(epsilon=1e-5, momentum=0.9)
    relu6 = Activation('relu')

    encoder_decoder += [dec_conv6, dec_bn6]
    dec_output6 = dec_bn6(dec_conv6(relu6(dec_output5)))
    dec_output6 = merge([dec_output6, enc_output0], mode='concat')

    dec_output6_ = dec_bn6(dec_conv6(relu6(dec_output5_)))
    dec_output6_ = merge([dec_output6_, enc_output0_], mode='concat')

    # C3 128=>256, last layer with tanh
    dec_conv7 = Deconvolution2D(output_channel, filter_size, filter_size,
                                output_shape=(batch_size, image_width, image_height, output_channel),
                                subsample=(2, 2),
                                init=my_init,
                                border_mode='same'
                                )
    relu7 = Activation('relu')
    dec_tanh = Activation('tanh')
    encoder_decoder += [dec_conv7]

    dec_output = dec_tanh(dec_conv7(relu7(dec_output6)))
    generated_img = dec_tanh(dec_conv7(relu7(dec_output6_)))
    return dec_output, generated_img, encoder_decoder
--------------------------------------------------------------------------------
/ops.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import os

import numpy as np
import scipy.misc
import tensorflow as tf
import matplotlib.pyplot as plt


def mkdir(dirpath):
    if not os.path.isdir(dirpath):
        os.mkdir(dirpath)
    return


def imread(path, is_grayscale=False):
    if (is_grayscale):
        return scipy.misc.imread(path, flatten=True).astype(np.float)
    return scipy.misc.imread(path).astype(np.float)


def load_image(image_path):
    # Each dataset image stores the paired images side by side in a single
    # jpg; split it into the left and right halves.
    input_img = imread(image_path)
    w = int(input_img.shape[1])
    w2 = int(w / 2)
    img_A = input_img[:, 0:w2]
    img_B = input_img[:, w2:w]
    return img_A, img_B


def show_image(img):
    plt.imshow(np.asarray(np.clip((img + 1.) * 127.5, 0., 255.), dtype=np.uint8))
    plt.show()


def save_image(img, filedir, i):
    plt.imsave(filedir + '/epoch-%d.jpg' % i, np.asarray(np.clip((img + 1.) * 127.5, 0., 255.), dtype=np.uint8))


def img_preprocess(img, label, fine_size, load_size, is_test=False):
    if is_test:
        img = scipy.misc.imresize(img, [fine_size, fine_size])
        label = scipy.misc.imresize(label, [fine_size, fine_size])
    else:
        # Resize up to load_size, then take a random fine_size crop and a
        # random horizontal flip (the jitter augmentation from the paper).
        img = scipy.misc.imresize(img, [load_size, load_size])
        label = scipy.misc.imresize(label, [load_size, load_size])

        h1 = int(np.ceil(np.random.uniform(1e-2, load_size - fine_size)))
        w1 = int(np.ceil(np.random.uniform(1e-2, load_size - fine_size)))
        img = img[h1: h1 + fine_size, w1: w1 + fine_size]
        label = label[h1: h1 + fine_size, w1: w1 + fine_size]
        if np.random.random() > 0.5:
            img = np.fliplr(img)
            label = np.fliplr(label)
    img = img_shift(img)
    label = img_shift(label)
    return img, label


def img_shift(img):
    return img / 127.5 - 1.
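
# img_shift maps pixel values from [0, 255] to [-1, 1] to match the
# generator's tanh output; show_image/save_image invert it with
# (img + 1.) * 127.5. A sketch of the full per-image preprocessing path,
# the same steps Dataset performs (fine_size 256 gives load_size 286):
#
#   img, label = load_image('./datasets/facades/train/1.jpg')
#   img, label = img_preprocess(img, label, 256, 286)
#   show_image(img)   # pixel values are now in [-1, 1]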

def concat(x, y):
    # TF 0.12 argument order: concatenate along axis 3 (channels).
    return tf.concat(3, [x, y])
--------------------------------------------------------------------------------
/original.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/makora9143/pix2pix-keras-tensorflow/4b7d2192607448659ba7b2c0b638d395dcd23ef4/original.jpg
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
# coding: utf-8


import threading
import argparse
import progressbar
import time

from progressbar import Bar, ETA, Percentage, ProgressBar, SimpleProgress

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

from keras import backend as K

from dataset import Dataset
from ops import *
from model import create_netD, create_netG, my_init


np.random.seed(1234)
tf.set_random_seed(1234)

sess = tf.Session()
K.set_session(sess)


# Parameters
parser = argparse.ArgumentParser(description='Training Pix2Pix Model')
parser.add_argument('--dataset', '-d', default='facades', help='Dataset: facades / cityscapes / maps / edges2shoes / edges2handbags')
parser.add_argument('--out', '-o', default='./output_imgs', help='Directory path for generated images')
parser.add_argument('--batchsize', '-b', type=int, default=1, help='Number of images in each mini-batch')
parser.add_argument('--learningrate', '-l', type=float, default=0.0002, help='Learning rate')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1; momentum term of Adam')
parser.add_argument('--epoch', '-e', type=int, default=200, help='Number of epochs')
parser.add_argument('--thread', '-t', type=int, default=1, help='Number of data-loading threads')
parser.add_argument('--filter', '-f', type=int, default=4, help='Kernel/filter size')

args = parser.parse_args()

mkdir(args.out)

ngf = 64
ndf = 64
batch_size = args.batchsize
nb_epochs = args.epoch

data = Dataset(dataset=args.dataset, batch_size=batch_size, thread_num=args.thread)

train_X, train_y = data.get_inputs()

test_img, test_label = load_image('./datasets/%s/val/%d.jpg' % (args.dataset, 1))
test_img = img_shift(test_img)
test_label = img_shift(test_label)

img_shape, label_shape = data.get_shape()

image_width = img_shape[0]
image_height = img_shape[1]
input_channel = label_shape[2]
output_channel = img_shape[2]

##############################################
# Generator
# U-NET
#
# CD512-CD1024-CD1024-C1024-C1024-C512-C256-C128
#
##############################################

tmp_x = tf.placeholder(tf.float32, [batch_size, image_width, image_height, input_channel])

D = create_netD(image_width, image_height, input_channel + output_channel, ndf, args.filter)
dec_output, generated_img, encoder_decoder = create_netG(train_X, tmp_x, ngf, args.filter, image_width, image_height, input_channel, output_channel, batch_size)

# ## Objective function

loss_d = tf.reduce_mean(tf.log(D(concat(train_X, train_y)) + 1e-12)) + tf.reduce_mean(tf.log(1 - D(concat(train_X, dec_output)) + 1e-12))

loss_g_1 = tf.reduce_mean(tf.log(1 - D(concat(train_X, dec_output)) + 1e-12))
loss_g_2 = tf.reduce_mean(tf.abs(train_y - dec_output))
loss_g = loss_g_1 + 100. * loss_g_2
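
# These losses implement the paper's objective:
#   L_cGAN(G, D) = E[log D(x, y)] + E[log(1 - D(x, G(x)))]
#   G* = arg min_G max_D L_cGAN(G, D) + lambda * E[||y - G(x)||_1]
# with lambda = 100 as in the paper; the 1e-12 terms guard the logs against
# numerical underflow.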


# ## Optimizer

train_d = tf.train.AdamOptimizer(args.learningrate, beta1=args.beta1).minimize(-loss_d, var_list=D.trainable_weights)
train_g = tf.train.AdamOptimizer(args.learningrate, beta1=args.beta1).minimize(loss_g, var_list=[op for l in map(lambda x: x.trainable_weights, encoder_decoder) for op in l])


# ## Initialize

sess.run(tf.global_variables_initializer())
tf.train.start_queue_runners(sess=sess)
data.start_threads(sess)
saver = tf.train.Saver()
mkdir('./model')
mkdir('./model/{}'.format(args.dataset))
mkdir(args.out + '/' + args.dataset)

# # Training

print 'start training'
widgets = ['Train: ', Percentage(), '(', SimpleProgress(), ') ', Bar(marker='#', left='[', right=']'), ' ', ETA()]

for i in range(nb_epochs):
    ave_d = []
    ave_g = []

    pbar = ProgressBar(widgets=widgets, maxval=data.get_size() - 1)
    pbar.start()

    for j in range(data.get_size() - 1):
        # Alternate one discriminator update (gradient ascent on loss_d via
        # minimizing -loss_d) with one generator update.
        sess.run(train_d, feed_dict={K.learning_phase(): 1})
        sess.run(train_g, feed_dict={K.learning_phase(): 1})

        loss_d_val = sess.run(loss_d, feed_dict={K.learning_phase(): 1})
        ave_d.append(loss_d_val)
        ave_g.append(sess.run(loss_g, feed_dict={K.learning_phase(): 1}))
        time.sleep(0.001)
        pbar.update(j)
    pbar.finish()

    print "Epoch %d/%d - dis_loss: %g - gen_loss: %g" % (i + 1, nb_epochs, np.mean(ave_d), np.mean(ave_g))
    # The paper applies dropout at test time too, hence learning_phase 1
    # when sampling from the generator.
    generated_image = sess.run(generated_img, feed_dict={tmp_x: [test_label], K.learning_phase(): 1})
    save_image(generated_image[0], args.out + '/' + args.dataset, i + 1)
    saver.save(sess, './model/{}/model.ckpt'.format(args.dataset), global_step=i + 1)
--------------------------------------------------------------------------------
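
The `python test.py` step in the README is still WIP and no test.py ships in the tree. The sketch below is one hypothetical shape for it, restoring a checkpoint written by train.py; the script body, including the graph-rebuild trick and the flag name, is an assumption, not shipped code:

```python
# test.py -- hypothetical sketch, not part of this repository.
# Rebuilds the graphs in the same order as train.py (so Keras 1.x assigns
# matching layer/variable names), restores the latest checkpoint, and
# translates one validation image.
import argparse

import tensorflow as tf
from keras import backend as K

from model import create_netD, create_netG
from ops import load_image, img_shift, save_image, mkdir

parser = argparse.ArgumentParser(description='Sampling from a trained Pix2Pix model')
parser.add_argument('--dataset', '-d', default='facades')
args = parser.parse_args()

sess = tf.Session()
K.set_session(sess)

test_img, test_label = load_image('./datasets/%s/val/1.jpg' % args.dataset)
test_label = img_shift(test_label)

h, w, in_ch = test_label.shape
out_ch = test_img.shape[2]
tmp_x = tf.placeholder(tf.float32, [1, h, w, in_ch])

D = create_netD(h, w, in_ch + out_ch, 64, 4)  # built first, mirroring train.py
# Feed the placeholder to both generator inputs; only the sampling branch is used.
_, generated_img, _ = create_netG(tmp_x, tmp_x, 64, 4, h, w, in_ch, out_ch, 1)

saver = tf.train.Saver()
saver.restore(sess, tf.train.latest_checkpoint('./model/%s' % args.dataset))

out = sess.run(generated_img, feed_dict={tmp_x: [test_label], K.learning_phase(): 1})
mkdir('./output_imgs')
save_image(out[0], './output_imgs', 0)
```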