├── .gitignore ├── CartoonGAN.py ├── LICENSE ├── README.md ├── edge_smooth.py ├── main.py ├── ops.py ├── utils.py └── vgg19.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /CartoonGAN.py: -------------------------------------------------------------------------------- 1 | from ops import * 2 | from utils import * 3 | from glob import glob 4 | import time 5 | from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch 6 | import numpy as np 7 | 8 | class CartoonGAN(object) : 9 | def __init__(self, sess, args): 10 | self.model_name = 'CartoonGAN' 11 | self.sess = sess 12 | self.checkpoint_dir = args.checkpoint_dir 13 | self.result_dir = args.result_dir 14 | self.log_dir = args.log_dir 15 | self.dataset_name = args.dataset 16 | self.augment_flag = args.augment_flag 17 | 18 | self.epoch = args.epoch 19 | self.init_epoch = args.init_epoch # args.epoch // 20 20 | self.iteration = args.iteration 21 | self.decay_flag = args.decay_flag 22 | self.decay_epoch = args.decay_epoch 23 | 24 | self.gan_type = args.gan_type 25 | 26 | self.batch_size = args.batch_size 27 | self.print_freq = args.print_freq 28 | self.save_freq = args.save_freq 29 | 30 | self.init_lr = args.lr 31 | self.ch = args.ch 32 | 33 | """ Weight """ 34 | self.adv_weight = args.adv_weight 35 | self.vgg_weight = args.vgg_weight 36 | self.ld = args.ld 37 | 38 | """ Generator """ 39 | self.n_res = args.n_res 40 | 41 | """ Discriminator """ 42 | self.n_dis = args.n_dis 43 | self.n_critic = args.n_critic 44 | self.sn = args.sn 45 | 46 | self.img_size = args.img_size 47 | self.img_ch = args.img_ch 48 | 49 | 50 | self.sample_dir = os.path.join(args.sample_dir, self.model_dir) 51 | check_folder(self.sample_dir) 52 | 53 | self.trainA_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainA')) 54 | self.trainB_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainB')) 55 | self.trainB_smooth_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainB_smooth')) 56 | 57 | self.dataset_num = max(len(self.trainA_dataset), len(self.trainB_dataset)) 58 | 59 | print() 60 | 61 | print("##### Information #####") 62 | print("# gan type : ", self.gan_type) 63 | print("# dataset : ", self.dataset_name) 64 | print("# max dataset number : ", self.dataset_num) 65 | print("# batch_size : ", self.batch_size) 66 | print("# epoch : ", self.epoch) 67 | print("# init_epoch : ", self.init_epoch) 68 | print("# iteration per epoch : ", self.iteration) 69 | 70 | print() 71 | 72 | print("##### Generator #####") 73 | print("# residual blocks : ", self.n_res) 74 | 75 | print() 76 | 77 | print("##### Discriminator #####") 78 | print("# the number of discriminator layer : ", self.n_dis) 79 | print("# the number of critic : ", self.n_critic) 80 | print("# spectral normalization : ", self.sn) 81 | 82 | print() 83 | 84 | ################################################################################## 85 | # Generator 86 | ################################################################################## 87 | 88 | def generator(self, x_init, reuse=False, scope="generator"): 89 | channel = self.ch 90 | with tf.variable_scope(scope, reuse=reuse) : 91 | x = conv(x_init, channel, kernel=7, stride=1, pad=3, pad_type='reflect', use_bias=False, scope='conv') 92 | x = instance_norm(x, scope='ins_norm') 93 | x = relu(x) 94 | 95 | # Down-Sampling 96 | for i in range(2) : 97 | x = conv(x, channel*2, kernel=3, stride=2, pad=1, use_bias=False, scope='conv_s2_'+str(i)) 98 | x = conv(x, channel*2, kernel=3, stride=1, pad=1, use_bias=False, scope='conv_s1_'+str(i)) 99 | x = instance_norm(x, scope='ins_norm_'+str(i)) 100 | x = relu(x) 101 | 102 | channel = channel * 2 103 | 104 | # Bottleneck 105 | for i in range(self.n_res): 106 | x = resblock(x, channel, use_bias=False, scope='resblock_' + str(i)) 107 | 108 | # Up-Sampling 109 | for i in range(2) : 110 | x = deconv(x, channel//2, kernel=3, stride=2, use_bias=False, scope='deconv_'+str(i)) 111 | x = conv(x, channel//2, kernel=3, stride=1, pad=1, use_bias=False, scope='up_conv_'+str(i)) 112 | x = instance_norm(x, scope='up_ins_norm_'+str(i)) 113 | x = relu(x) 114 | 115 | channel = channel // 2 116 | 117 | 118 | x = conv(x, channels=self.img_ch, kernel=7, stride=1, pad=3, pad_type='reflect', use_bias=False, scope='G_logit') 119 | x = tanh(x) 120 | 121 | return x 122 | 123 | ################################################################################## 124 | # Discriminator 125 | ################################################################################## 126 | 127 | def discriminator(self, x_init, reuse=False, scope="discriminator"): 128 | channel = self.ch // 2 129 | with tf.variable_scope(scope, reuse=reuse): 130 | x = conv(x_init, channel, kernel=3, stride=1, pad=1, use_bias=False, sn=self.sn, scope='conv_0') 131 | x = lrelu(x, 0.2) 132 | 133 | for i in range(1, self.n_dis): 134 | x = conv(x, channel * 2, kernel=3, stride=2, pad=1, use_bias=False, sn=self.sn, scope='conv_s2_' + str(i)) 135 | x = lrelu(x, 0.2) 136 | 137 | x = conv(x, channel * 4, kernel=3, stride=1, pad=1, use_bias=False, sn=self.sn, scope='conv_s1_' + str(i)) 138 | x = instance_norm(x, scope='ins_norm_' + str(i)) 139 | x = lrelu(x, 0.2) 140 | 141 | channel = channel * 2 142 | 143 | x = conv(x, channel * 2, kernel=3, stride=1, pad=1, use_bias=False, sn=self.sn, scope='last_conv') 144 | x = instance_norm(x, scope='last_ins_norm') 145 | x = lrelu(x, 0.2) 146 | 147 | x = conv(x, channels=1, kernel=3, stride=1, pad=1, use_bias=False, sn=self.sn, scope='D_logit') 148 | 149 | return x 150 | 151 | ################################################################################## 152 | # Model 153 | ################################################################################## 154 | def gradient_panalty(self, real, fake, scope="discriminator"): 155 | if self.gan_type.__contains__('dragan') : 156 | eps = tf.random_uniform(shape=tf.shape(real), minval=0., maxval=1.) 157 | _, x_var = tf.nn.moments(real, axes=[0, 1, 2, 3]) 158 | x_std = tf.sqrt(x_var) # magnitude of noise decides the size of local region 159 | 160 | fake = real + 0.5 * x_std * eps 161 | 162 | alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0., maxval=1.) 163 | interpolated = real + alpha * (fake - real) 164 | 165 | logit = self.discriminator(interpolated, reuse=True, scope=scope) 166 | 167 | 168 | grad = tf.gradients(logit, interpolated)[0] # gradient of D(interpolated) 169 | grad_norm = tf.norm(flatten(grad), axis=1) # l2 norm 170 | 171 | GP = 0 172 | # WGAN - LP 173 | if self.gan_type.__contains__('lp'): 174 | GP = self.ld * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.))) 175 | 176 | elif self.gan_type.__contains__('gp') or self.gan_type == 'dragan' : 177 | GP = self.ld * tf.reduce_mean(tf.square(grad_norm - 1.)) 178 | 179 | 180 | return GP 181 | 182 | def build_model(self): 183 | self.lr = tf.placeholder(tf.float32, name='learning_rate') 184 | 185 | 186 | """ Input Image""" 187 | Image_Data_Class = ImageData(self.img_size, self.img_ch, self.augment_flag) 188 | 189 | trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset) 190 | trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset) 191 | trainB_smooth = tf.data.Dataset.from_tensor_slices(self.trainB_smooth_dataset) 192 | 193 | gpu_device = '/gpu:0' 194 | trainA = trainA.apply(shuffle_and_repeat(self.dataset_num)).apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)).apply(prefetch_to_device(gpu_device, self.batch_size)) 195 | trainB = trainB.apply(shuffle_and_repeat(self.dataset_num)).apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)).apply(prefetch_to_device(gpu_device, self.batch_size)) 196 | trainB_smooth = trainB_smooth.apply(shuffle_and_repeat(self.dataset_num)).apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)).apply(prefetch_to_device(gpu_device, self.batch_size)) 197 | 198 | trainA_iterator = trainA.make_one_shot_iterator() 199 | trainB_iterator = trainB.make_one_shot_iterator() 200 | trainB_smooth_iterator = trainB_smooth.make_one_shot_iterator() 201 | 202 | 203 | self.real_A = trainA_iterator.get_next() 204 | self.real_B = trainB_iterator.get_next() 205 | self.real_B_smooth = trainB_smooth_iterator.get_next() 206 | 207 | self.test_real_A = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_real_A') 208 | 209 | 210 | """ Define Generator, Discriminator """ 211 | self.fake_B = self.generator(self.real_A) 212 | 213 | real_B_logit = self.discriminator(self.real_B) 214 | fake_B_logit = self.discriminator(self.fake_B, reuse=True) 215 | real_B_smooth_logit = self.discriminator(self.real_B_smooth, reuse=True) 216 | 217 | 218 | """ Define Loss """ 219 | if self.gan_type.__contains__('gp') or self.gan_type.__contains__('lp') or self.gan_type.__contains__('dragan') : 220 | GP = self.gradient_panalty(real=self.real_B, fake=self.fake_B) + self.gradient_panalty(self.real_B, fake=self.real_B_smooth) 221 | else : 222 | GP = 0.0 223 | 224 | v_loss = self.vgg_weight * vgg_loss(self.real_A, self.fake_B) 225 | g_loss = self.adv_weight * generator_loss(self.gan_type, fake_B_logit) 226 | d_loss = self.adv_weight * discriminator_loss(self.gan_type, real_B_logit, fake_B_logit, real_B_smooth_logit) + GP 227 | 228 | self.Vgg_loss = v_loss 229 | self.Generator_loss = g_loss + v_loss 230 | self.Discriminator_loss = d_loss 231 | 232 | 233 | """ Result Image """ 234 | self.test_fake_B = self.generator(self.test_real_A, reuse=True) 235 | 236 | """ Training """ 237 | t_vars = tf.trainable_variables() 238 | G_vars = [var for var in t_vars if 'generator' in var.name] 239 | D_vars = [var for var in t_vars if 'discriminator' in var.name] 240 | 241 | self.init_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Vgg_loss, var_list=G_vars) 242 | self.G_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Generator_loss, var_list=G_vars) 243 | self.D_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Discriminator_loss, var_list=D_vars) 244 | 245 | 246 | """" Summary """ 247 | self.G_loss = tf.summary.scalar("Generator_loss", self.Generator_loss) 248 | self.D_loss = tf.summary.scalar("Discriminator_loss", self.Discriminator_loss) 249 | 250 | self.G_gan = tf.summary.scalar("G_gan", g_loss) 251 | self.G_vgg = tf.summary.scalar("G_vgg", v_loss) 252 | 253 | self.V_loss_merge = tf.summary.merge([self.G_vgg]) 254 | self.G_loss_merge = tf.summary.merge([self.G_loss, self.G_gan, self.G_vgg]) 255 | self.D_loss_merge = tf.summary.merge([self.D_loss]) 256 | 257 | 258 | def train(self): 259 | # initialize all variables 260 | tf.global_variables_initializer().run() 261 | 262 | # saver to save model 263 | self.saver = tf.train.Saver() 264 | 265 | # summary writer 266 | self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph) 267 | 268 | 269 | # restore check-point if it exits 270 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 271 | if could_load: 272 | start_epoch = (int)(checkpoint_counter / self.iteration) 273 | start_batch_id = checkpoint_counter - start_epoch * self.iteration 274 | counter = checkpoint_counter 275 | print(" [*] Load SUCCESS") 276 | else: 277 | start_epoch = 0 278 | start_batch_id = 0 279 | counter = 1 280 | print(" [!] Load failed...") 281 | 282 | # loop for epoch 283 | start_time = time.time() 284 | past_g_loss = -1. 285 | lr = self.init_lr 286 | for epoch in range(start_epoch, self.epoch): 287 | # lr = self.init_lr if epoch < self.decay_epoch else self.init_lr * (self.epoch - epoch) / (self.epoch - self.decay_epoch) 288 | if self.decay_flag : 289 | lr = self.init_lr * pow(0.5, epoch // self.decay_epoch) 290 | 291 | for idx in range(start_batch_id, self.iteration): 292 | 293 | train_feed_dict = { 294 | self.lr : lr 295 | } 296 | 297 | if epoch < self.init_epoch : 298 | # Init G 299 | real_A_images, fake_B_images, _, v_loss, summary_str = self.sess.run([self.real_A, self.fake_B, 300 | self.init_optim, 301 | self.Vgg_loss, self.V_loss_merge], feed_dict = train_feed_dict) 302 | self.writer.add_summary(summary_str, counter) 303 | print("Epoch: [%3d] [%5d/%5d] time: %4.4f v_loss: %.8f" % (epoch, idx, self.iteration, time.time() - start_time, v_loss)) 304 | 305 | else : 306 | # Update D 307 | _, d_loss, summary_str = self.sess.run([self.D_optim, self.Discriminator_loss, self.D_loss_merge], feed_dict = train_feed_dict) 308 | self.writer.add_summary(summary_str, counter) 309 | 310 | # Update G 311 | g_loss = None 312 | if (counter - 1) % self.n_critic == 0 : 313 | real_A_images, fake_B_images, _, g_loss, summary_str = self.sess.run([self.real_A, self.fake_B, 314 | self.G_optim, 315 | self.Generator_loss, self.G_loss_merge], feed_dict = train_feed_dict) 316 | self.writer.add_summary(summary_str, counter) 317 | past_g_loss = g_loss 318 | 319 | if g_loss == None: 320 | g_loss = past_g_loss 321 | print("Epoch: [%3d] [%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" % (epoch, idx, self.iteration, time.time() - start_time, d_loss, g_loss)) 322 | 323 | # display training status 324 | counter += 1 325 | 326 | 327 | if np.mod(idx+1, self.print_freq) == 0 : 328 | save_images(real_A_images, [self.batch_size, 1], 329 | './{}/real_A_{:03d}_{:05d}.png'.format(self.sample_dir, epoch, idx+1)) 330 | save_images(fake_B_images, [self.batch_size, 1], 331 | './{}/fake_B_{:03d}_{:05d}.png'.format(self.sample_dir, epoch, idx+1)) 332 | 333 | if np.mod(idx + 1, self.save_freq) == 0: 334 | self.save(self.checkpoint_dir, counter) 335 | 336 | 337 | 338 | # After an epoch, start_batch_id is set to zero 339 | # non-zero value is only for the first epoch after loading pre-trained model 340 | start_batch_id = 0 341 | 342 | # save model for final step 343 | self.save(self.checkpoint_dir, counter) 344 | 345 | @property 346 | def model_dir(self): 347 | n_res = str(self.n_res) + 'resblock' 348 | n_dis = str(self.n_dis) + 'dis' 349 | return "{}_{}_{}_{}_{}_{}_{}_{}_{}".format(self.model_name, self.dataset_name, 350 | self.gan_type, n_res, n_dis, 351 | self.n_critic, self.sn, 352 | int(self.adv_weight), int(self.vgg_weight)) 353 | 354 | def save(self, checkpoint_dir, step): 355 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) 356 | 357 | if not os.path.exists(checkpoint_dir): 358 | os.makedirs(checkpoint_dir) 359 | 360 | self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step) 361 | 362 | def load(self, checkpoint_dir): 363 | print(" [*] Reading checkpoints...") 364 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) 365 | 366 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) # checkpoint file information 367 | 368 | if ckpt and ckpt.model_checkpoint_path: 369 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) # first line 370 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name)) 371 | counter = int(ckpt_name.split('-')[-1]) 372 | print(" [*] Success to read {}".format(ckpt_name)) 373 | return True, counter 374 | else: 375 | print(" [*] Failed to find a checkpoint") 376 | return False, 0 377 | 378 | def test(self): 379 | tf.global_variables_initializer().run() 380 | test_A_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testA')) 381 | 382 | self.saver = tf.train.Saver() 383 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 384 | self.result_dir = os.path.join(self.result_dir, self.model_dir) 385 | check_folder(self.result_dir) 386 | 387 | if could_load : 388 | print(" [*] Load SUCCESS") 389 | else : 390 | print(" [!] Load failed...") 391 | 392 | # write html for visual comparison 393 | index_path = os.path.join(self.result_dir, 'index.html') 394 | index = open(index_path, 'w') 395 | index.write("") 396 | index.write("") 397 | 398 | for sample_file in test_A_files : # A -> B 399 | print('Processing A image: ' + sample_file) 400 | sample_image = np.asarray(load_test_data(sample_file)) 401 | image_path = os.path.join(self.result_dir,'{0}'.format(os.path.basename(sample_file))) 402 | 403 | fake_img = self.sess.run(self.test_fake_B, feed_dict = {self.test_real_A : sample_image}) 404 | save_images(fake_img, [1, 1], image_path) 405 | 406 | index.write("" % os.path.basename(image_path)) 407 | 408 | index.write("" % (sample_file if os.path.isabs(sample_file) else ( 409 | '../..' + os.path.sep + sample_file), self.img_size, self.img_size)) 410 | index.write("" % (image_path if os.path.isabs(image_path) else ( 411 | '../..' + os.path.sep + image_path), self.img_size, self.img_size)) 412 | index.write("") 413 | 414 | index.close() 415 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Junho Kim (1993.01.12) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CartoonGAN-Tensorflow 2 | Simple Tensorflow implementation of [CartoonGAN](http://openaccess.thecvf.com/content_cvpr_2018/papers/Chen_CartoonGAN_Generative_Adversarial_CVPR_2018_paper.pdf) (CVPR 2018) 3 | 4 | ## Pytorch version 5 | * [CartoonGAN-Pytorch](https://github.com/znxlwm/pytorch-CartoonGAN) 6 | 7 | ## Requirements 8 | * Tensorflow 1.8 9 | * Python 3.6 10 | 11 | ## Usage 12 | ### 1. Download vgg19 13 | * [vgg19.npy](https://mega.nz/#!xZ8glS6J!MAnE91ND_WyfZ_8mvkuSa2YcA7q-1ehfSm-Q1fxOvvs) 14 | 15 | ### 2. Do edge_smooth 16 | ``` 17 | > python edge_smooth.py --dataset face2anime --img_size 256 18 | ``` 19 | 20 | ``` 21 | ├── dataset 22 |    └── YOUR_DATASET_NAME 23 |    ├── trainA 24 |           ├── xxx.jpg (name, format doesn't matter) 25 | ├── yyy.png 26 | └── ... 27 |    ├── trainB 28 | ├── zzz.jpg 29 | ├── www.png 30 | └── ... 31 |    ├── trainB_smooth (After you run the above code, it will be created automatically) 32 |    ├── zzz.jpg 33 | ├── www.png 34 | └── ... 35 |    └── testA 36 | ├── aaa.jpg 37 | ├── bbb.png 38 | └── ... 39 | ``` 40 | 41 | ### 3. Train 42 | * python main.py --phase train --dataset face2anime --epoch 100 --init_epoch 1 43 | 44 | ### 4. Test 45 | * python main.py --phase test --dataset face2anime 46 | 47 | ## Author 48 | Junho Kim 49 | -------------------------------------------------------------------------------- /edge_smooth.py: -------------------------------------------------------------------------------- 1 | from utils import check_folder 2 | import numpy as np 3 | import cv2, os, argparse 4 | from glob import glob 5 | from tqdm import tqdm 6 | 7 | def parse_args(): 8 | desc = "Edge smoothed" 9 | parser = argparse.ArgumentParser(description=desc) 10 | parser.add_argument('--dataset', type=str, default='hw', help='dataset_name') 11 | parser.add_argument('--img_size', type=int, default=256, help='The size of image') 12 | 13 | return parser.parse_args() 14 | 15 | def make_edge_smooth(dataset_name, img_size) : 16 | check_folder('./dataset/{}/{}'.format(dataset_name, 'trainB_smooth')) 17 | 18 | file_list = glob('./dataset/{}/{}/*.*'.format(dataset_name, 'trainB')) 19 | save_dir = './dataset/{}/trainB_smooth'.format(dataset_name) 20 | 21 | kernel_size = 5 22 | kernel = np.ones((kernel_size, kernel_size), np.uint8) 23 | gauss = cv2.getGaussianKernel(kernel_size, 0) 24 | gauss = gauss * gauss.transpose(1, 0) 25 | 26 | for f in tqdm(file_list) : 27 | file_name = os.path.basename(f) 28 | 29 | bgr_img = cv2.imread(f) 30 | gray_img = cv2.imread(f, 0) 31 | 32 | bgr_img = cv2.resize(bgr_img, (img_size, img_size)) 33 | pad_img = np.pad(bgr_img, ((2, 2), (2, 2), (0, 0)), mode='reflect') 34 | gray_img = cv2.resize(gray_img, (img_size, img_size)) 35 | 36 | edges = cv2.Canny(gray_img, 100, 200) 37 | dilation = cv2.dilate(edges, kernel) 38 | 39 | gauss_img = np.copy(bgr_img) 40 | idx = np.where(dilation != 0) 41 | for i in range(np.sum(dilation != 0)): 42 | gauss_img[idx[0][i], idx[1][i], 0] = np.sum( 43 | np.multiply(pad_img[idx[0][i]:idx[0][i] + kernel_size, idx[1][i]:idx[1][i] + kernel_size, 0], gauss)) 44 | gauss_img[idx[0][i], idx[1][i], 1] = np.sum( 45 | np.multiply(pad_img[idx[0][i]:idx[0][i] + kernel_size, idx[1][i]:idx[1][i] + kernel_size, 1], gauss)) 46 | gauss_img[idx[0][i], idx[1][i], 2] = np.sum( 47 | np.multiply(pad_img[idx[0][i]:idx[0][i] + kernel_size, idx[1][i]:idx[1][i] + kernel_size, 2], gauss)) 48 | 49 | cv2.imwrite(os.path.join(save_dir, file_name), gauss_img) 50 | 51 | """main""" 52 | def main(): 53 | # parse arguments 54 | args = parse_args() 55 | if args is None: 56 | exit() 57 | 58 | make_edge_smooth(args.dataset, args.img_size) 59 | 60 | 61 | if __name__ == '__main__': 62 | main() 63 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from CartoonGAN import CartoonGAN 2 | import argparse 3 | from utils import * 4 | 5 | """parsing and configuration""" 6 | 7 | def parse_args(): 8 | desc = "Tensorflow implementation of CartoonGAN" 9 | parser = argparse.ArgumentParser(description=desc) 10 | parser.add_argument('--phase', type=str, default='train', help='train or test ?') 11 | parser.add_argument('--dataset', type=str, default='face2anime', help='dataset_name') 12 | 13 | parser.add_argument('--epoch', type=int, default=100, help='The number of epochs to run') 14 | parser.add_argument('--init_epoch', type=int, default=1, help='The number of epochs for weight initialization') 15 | parser.add_argument('--iteration', type=int, default=10000, help='The number of training iterations') 16 | parser.add_argument('--batch_size', type=int, default=1, help='The size of batch size') 17 | parser.add_argument('--print_freq', type=int, default=1000, help='The number of image_print_freq') 18 | parser.add_argument('--save_freq', type=int, default=1000, help='The number of ckpt_save_freq') 19 | parser.add_argument('--decay_flag', type=str2bool, default=False, help='The decay_flag') 20 | parser.add_argument('--decay_epoch', type=int, default=10, help='decay epoch') 21 | 22 | parser.add_argument('--lr', type=float, default=0.0001, help='The learning rate') 23 | parser.add_argument('--ld', type=float, default=10.0, help='The gradient penalty lambda') 24 | parser.add_argument('--adv_weight', type=float, default=1.0, help='Weight about GAN') 25 | parser.add_argument('--vgg_weight', type=float, default=10.0, help='Weight about VGG19') 26 | parser.add_argument('--gan_type', type=str, default='gan', help='[gan / lsgan / wgan-gp / wgan-lp / dragan / hinge') 27 | 28 | parser.add_argument('--ch', type=int, default=64, help='base channel number per layer') 29 | parser.add_argument('--n_res', type=int, default=8, help='The number of resblock') 30 | 31 | parser.add_argument('--n_dis', type=int, default=3, help='The number of discriminator layer') 32 | parser.add_argument('--n_critic', type=int, default=1, help='The number of critic') 33 | parser.add_argument('--sn', type=str2bool, default=False, help='using spectral norm') 34 | 35 | parser.add_argument('--img_size', type=int, default=256, help='The size of image') 36 | parser.add_argument('--img_ch', type=int, default=3, help='The size of image channel') 37 | parser.add_argument('--augment_flag', type=str2bool, default=False, help='Image augmentation use or not') 38 | 39 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint', 40 | help='Directory name to save the checkpoints') 41 | parser.add_argument('--result_dir', type=str, default='results', 42 | help='Directory name to save the generated images') 43 | parser.add_argument('--log_dir', type=str, default='logs', 44 | help='Directory name to save training logs') 45 | parser.add_argument('--sample_dir', type=str, default='samples', 46 | help='Directory name to save the samples on training') 47 | 48 | return check_args(parser.parse_args()) 49 | 50 | """checking arguments""" 51 | def check_args(args): 52 | # --checkpoint_dir 53 | check_folder(args.checkpoint_dir) 54 | 55 | # --result_dir 56 | check_folder(args.result_dir) 57 | 58 | # --result_dir 59 | check_folder(args.log_dir) 60 | 61 | # --sample_dir 62 | check_folder(args.sample_dir) 63 | 64 | # --epoch 65 | try: 66 | assert args.epoch >= 1 67 | except: 68 | print('number of epochs must be larger than or equal to one') 69 | 70 | # --batch_size 71 | try: 72 | assert args.batch_size >= 1 73 | except: 74 | print('batch size must be larger than or equal to one') 75 | return args 76 | 77 | """main""" 78 | def main(): 79 | # parse arguments 80 | args = parse_args() 81 | if args is None: 82 | exit() 83 | 84 | # open session 85 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess: 86 | gan = CartoonGAN(sess, args) 87 | 88 | # build graph 89 | gan.build_model() 90 | 91 | # show network architecture 92 | show_all_variables() 93 | 94 | if args.phase == 'train' : 95 | gan.train() 96 | print(" [*] Training finished!") 97 | 98 | if args.phase == 'test' : 99 | gan.test() 100 | print(" [*] Test finished!") 101 | 102 | if __name__ == '__main__': 103 | main() 104 | -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib as tf_contrib 3 | from vgg19 import Vgg19 4 | 5 | # Xavier : tf_contrib.layers.xavier_initializer() 6 | # He : tf_contrib.layers.variance_scaling_initializer() 7 | # Normal : tf.random_normal_initializer(mean=0.0, stddev=0.02) 8 | # l2_decay : tf_contrib.layers.l2_regularizer(0.0001) 9 | 10 | 11 | weight_init = tf.random_normal_initializer(mean=0.0, stddev=0.02) 12 | weight_regularizer = None 13 | 14 | ################################################################################## 15 | # Layer 16 | ################################################################################## 17 | 18 | def conv(x, channels, kernel=4, stride=2, pad=0, pad_type='zero', use_bias=True, sn=False, scope='conv_0'): 19 | with tf.variable_scope(scope): 20 | if (kernel - stride) % 2 == 0 : 21 | pad_top = pad 22 | pad_bottom = pad 23 | pad_left = pad 24 | pad_right = pad 25 | 26 | else : 27 | pad_top = pad 28 | pad_bottom = kernel - stride - pad_top 29 | pad_left = pad 30 | pad_right = kernel - stride - pad_left 31 | 32 | if pad_type == 'zero' : 33 | x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]) 34 | if pad_type == 'reflect' : 35 | x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]], mode='REFLECT') 36 | 37 | if sn : 38 | w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels], initializer=weight_init, 39 | regularizer=weight_regularizer) 40 | x = tf.nn.conv2d(input=x, filter=spectral_norm(w), 41 | strides=[1, stride, stride, 1], padding='VALID') 42 | if use_bias : 43 | bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0)) 44 | x = tf.nn.bias_add(x, bias) 45 | 46 | else : 47 | x = tf.layers.conv2d(inputs=x, filters=channels, 48 | kernel_size=kernel, kernel_initializer=weight_init, 49 | kernel_regularizer=weight_regularizer, 50 | strides=stride, use_bias=use_bias) 51 | 52 | 53 | return x 54 | 55 | def deconv(x, channels, kernel=4, stride=2, use_bias=True, sn=False, scope='deconv_0'): 56 | with tf.variable_scope(scope): 57 | x_shape = x.get_shape().as_list() 58 | output_shape = [x_shape[0], x_shape[1]*stride, x_shape[2]*stride, channels] 59 | if sn : 60 | w = tf.get_variable("kernel", shape=[kernel, kernel, channels, x.get_shape()[-1]], initializer=weight_init, regularizer=weight_regularizer) 61 | x = tf.nn.conv2d_transpose(x, filter=spectral_norm(w), output_shape=output_shape, strides=[1, stride, stride, 1], padding='SAME') 62 | 63 | if use_bias : 64 | bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0)) 65 | x = tf.nn.bias_add(x, bias) 66 | 67 | else : 68 | x = tf.layers.conv2d_transpose(inputs=x, filters=channels, 69 | kernel_size=kernel, kernel_initializer=weight_init, kernel_regularizer=weight_regularizer, 70 | strides=stride, padding='SAME', use_bias=use_bias) 71 | 72 | return x 73 | 74 | 75 | ################################################################################## 76 | # Residual-block 77 | ################################################################################## 78 | 79 | def resblock(x_init, channels, use_bias=True, scope='resblock_0'): 80 | with tf.variable_scope(scope): 81 | with tf.variable_scope('res1'): 82 | x = conv(x_init, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias) 83 | x = instance_norm(x) 84 | x = relu(x) 85 | 86 | with tf.variable_scope('res2'): 87 | x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias) 88 | x = instance_norm(x) 89 | 90 | return x + x_init 91 | 92 | ################################################################################## 93 | # Sampling 94 | ################################################################################## 95 | 96 | def flatten(x) : 97 | return tf.layers.flatten(x) 98 | 99 | ################################################################################## 100 | # Activation function 101 | ################################################################################## 102 | 103 | def lrelu(x, alpha=0.2): 104 | return tf.nn.leaky_relu(x, alpha) 105 | 106 | 107 | def relu(x): 108 | return tf.nn.relu(x) 109 | 110 | 111 | def tanh(x): 112 | return tf.tanh(x) 113 | 114 | def sigmoid(x) : 115 | return tf.sigmoid(x) 116 | 117 | ################################################################################## 118 | # Normalization function 119 | ################################################################################## 120 | 121 | def instance_norm(x, scope='instance_norm'): 122 | return tf_contrib.layers.instance_norm(x, 123 | epsilon=1e-05, 124 | center=True, scale=True, 125 | scope=scope) 126 | 127 | def layer_norm(x, scope='layer_norm') : 128 | return tf_contrib.layers.layer_norm(x, 129 | center=True, scale=True, 130 | scope=scope) 131 | 132 | def batch_norm(x, is_training=True, scope='batch_norm'): 133 | return tf_contrib.layers.batch_norm(x, 134 | decay=0.9, epsilon=1e-05, 135 | center=True, scale=True, updates_collections=None, 136 | is_training=is_training, scope=scope) 137 | 138 | 139 | def spectral_norm(w, iteration=1): 140 | w_shape = w.shape.as_list() 141 | w = tf.reshape(w, [-1, w_shape[-1]]) 142 | 143 | u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.truncated_normal_initializer(), trainable=False) 144 | 145 | u_hat = u 146 | v_hat = None 147 | for i in range(iteration): 148 | """ 149 | power iteration 150 | Usually iteration = 1 will be enough 151 | """ 152 | v_ = tf.matmul(u_hat, tf.transpose(w)) 153 | v_hat = l2_norm(v_) 154 | 155 | u_ = tf.matmul(v_hat, w) 156 | u_hat = l2_norm(u_) 157 | 158 | sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat)) 159 | w_norm = w / sigma 160 | 161 | with tf.control_dependencies([u.assign(u_hat)]): 162 | w_norm = tf.reshape(w_norm, w_shape) 163 | 164 | return w_norm 165 | 166 | def l2_norm(v, eps=1e-12): 167 | return v / (tf.reduce_sum(v ** 2) ** 0.5 + eps) 168 | 169 | ################################################################################## 170 | # Loss function 171 | ################################################################################## 172 | 173 | def L1_loss(x, y): 174 | loss = tf.reduce_mean(tf.abs(x - y)) 175 | 176 | return loss 177 | 178 | def discriminator_loss(loss_func, real, fake, real_blur): 179 | real_loss = 0 180 | fake_loss = 0 181 | real_blur_loss = 0 182 | 183 | 184 | if loss_func == 'wgan-gp' or loss_func == 'wgan-lp': 185 | real_loss = -tf.reduce_mean(real) 186 | fake_loss = tf.reduce_mean(fake) 187 | real_blur_loss = tf.reduce_mean(real_blur) 188 | 189 | if loss_func == 'lsgan' : 190 | real_loss = tf.reduce_mean(tf.square(real - 1.0)) 191 | fake_loss = tf.reduce_mean(tf.square(fake)) 192 | real_blur_loss = tf.reduce_mean(tf.square(real_blur)) 193 | 194 | if loss_func == 'gan' or loss_func == 'dragan' : 195 | real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real), logits=real)) 196 | fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake), logits=fake)) 197 | real_blur_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(real_blur), logits=real_blur)) 198 | 199 | if loss_func == 'hinge': 200 | real_loss = tf.reduce_mean(relu(1.0 - real)) 201 | fake_loss = tf.reduce_mean(relu(1.0 + fake)) 202 | real_blur_loss = tf.reduce_mean(relu(1.0 + real_blur)) 203 | 204 | loss = real_loss + fake_loss + real_blur_loss 205 | 206 | return loss 207 | 208 | def generator_loss(loss_func, fake): 209 | fake_loss = 0 210 | 211 | if loss_func == 'wgan-gp' or loss_func == 'wgan-lp': 212 | fake_loss = -tf.reduce_mean(fake) 213 | 214 | if loss_func == 'lsgan' : 215 | fake_loss = tf.reduce_mean(tf.square(fake - 1.0)) 216 | 217 | if loss_func == 'gan' or loss_func == 'dragan': 218 | fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake), logits=fake)) 219 | 220 | if loss_func == 'hinge': 221 | fake_loss = -tf.reduce_mean(fake) 222 | 223 | loss = fake_loss 224 | 225 | return loss 226 | 227 | def vgg_loss(real, fake): 228 | vgg = Vgg19('vgg19.npy') 229 | 230 | vgg.build(real) 231 | real_feature_map = vgg.conv4_4_no_activation 232 | 233 | vgg.build(fake) 234 | fake_feature_map = vgg.conv4_4_no_activation 235 | 236 | loss = L1_loss(real_feature_map, fake_feature_map) 237 | 238 | return loss 239 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib import slim 3 | from scipy import misc 4 | import os, random 5 | import numpy as np 6 | 7 | class ImageData: 8 | 9 | def __init__(self, load_size, channels, augment_flag): 10 | self.load_size = load_size 11 | self.channels = channels 12 | self.augment_flag = augment_flag 13 | 14 | def image_processing(self, filename): 15 | x = tf.read_file(filename) 16 | x_decode = tf.image.decode_jpeg(x, channels=self.channels) 17 | img = tf.image.resize_images(x_decode, [self.load_size, self.load_size]) 18 | img = tf.cast(img, tf.float32) / 127.5 - 1 19 | 20 | if self.augment_flag : 21 | augment_size = self.load_size + (30 if self.load_size == 256 else 15) 22 | p = random.random() 23 | if p > 0.5: 24 | img = augmentation(img, augment_size) 25 | 26 | return img 27 | 28 | 29 | def load_test_data(image_path, size=256): 30 | img = misc.imread(image_path, mode='RGB') 31 | img = misc.imresize(img, [size, size]) 32 | img = np.expand_dims(img, axis=0) 33 | img = preprocessing(img) 34 | 35 | return img 36 | 37 | def preprocessing(x): 38 | x = x/127.5 - 1 # -1 ~ 1 39 | return x 40 | 41 | def augmentation(image, augment_size): 42 | seed = random.randint(0, 2 ** 31 - 1) 43 | ori_image_shape = tf.shape(image) 44 | image = tf.image.random_flip_left_right(image, seed=seed) 45 | image = tf.image.resize_images(image, [augment_size, augment_size]) 46 | image = tf.random_crop(image, ori_image_shape, seed=seed) 47 | return image 48 | 49 | def save_images(images, size, image_path): 50 | return imsave(inverse_transform(images), size, image_path) 51 | 52 | def inverse_transform(images): 53 | return (images+1.) / 2 54 | 55 | 56 | def imsave(images, size, path): 57 | return misc.imsave(path, merge(images, size)) 58 | 59 | def merge(images, size): 60 | h, w = images.shape[1], images.shape[2] 61 | img = np.zeros((h * size[0], w * size[1], 3)) 62 | for idx, image in enumerate(images): 63 | i = idx % size[1] 64 | j = idx // size[1] 65 | img[h*j:h*(j+1), w*i:w*(i+1), :] = image 66 | 67 | return img 68 | 69 | def show_all_variables(): 70 | model_vars = tf.trainable_variables() 71 | slim.model_analyzer.analyze_vars(model_vars, print_info=True) 72 | 73 | def check_folder(log_dir): 74 | if not os.path.exists(log_dir): 75 | os.makedirs(log_dir) 76 | return log_dir 77 | 78 | def str2bool(x): 79 | return x.lower() in ('true') 80 | 81 | 82 | -------------------------------------------------------------------------------- /vgg19.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | 4 | import numpy as np 5 | import time 6 | import inspect 7 | 8 | VGG_MEAN = [103.939, 116.779, 123.68] 9 | 10 | 11 | class Vgg19: 12 | 13 | def __init__(self, vgg19_npy_path=None): 14 | if vgg19_npy_path is None: 15 | path = inspect.getfile(Vgg19) 16 | path = os.path.abspath(os.path.join(path, os.pardir)) 17 | path = os.path.join(path, "vgg19.npy") 18 | vgg19_npy_path = path 19 | print(vgg19_npy_path) 20 | 21 | self.data_dict = np.load(vgg19_npy_path, encoding='latin1').item() 22 | print("npy file loaded") 23 | 24 | def build(self, rgb, include_fc=False): 25 | """ 26 | load variable from npy to build the VGG 27 | input format: bgr image with shape [batch_size, h, w, 3] 28 | scale: (-1, 1) 29 | """ 30 | 31 | start_time = time.time() 32 | rgb_scaled = ((rgb + 1) / 2) * 255.0 # [-1, 1] ~ [0, 255] 33 | 34 | red, green, blue = tf.split(axis=3, num_or_size_splits=3, value=rgb_scaled) 35 | bgr = tf.concat(axis=3, values=[blue - VGG_MEAN[0], 36 | green - VGG_MEAN[1], 37 | red - VGG_MEAN[2]]) 38 | 39 | self.conv1_1 = self.conv_layer(bgr, "conv1_1") 40 | self.conv1_2 = self.conv_layer(self.conv1_1, "conv1_2") 41 | self.pool1 = self.max_pool(self.conv1_2, 'pool1') 42 | 43 | self.conv2_1 = self.conv_layer(self.pool1, "conv2_1") 44 | self.conv2_2 = self.conv_layer(self.conv2_1, "conv2_2") 45 | self.pool2 = self.max_pool(self.conv2_2, 'pool2') 46 | 47 | self.conv3_1 = self.conv_layer(self.pool2, "conv3_1") 48 | self.conv3_2 = self.conv_layer(self.conv3_1, "conv3_2") 49 | self.conv3_3 = self.conv_layer(self.conv3_2, "conv3_3") 50 | self.conv3_4 = self.conv_layer(self.conv3_3, "conv3_4") 51 | self.pool3 = self.max_pool(self.conv3_4, 'pool3') 52 | 53 | self.conv4_1 = self.conv_layer(self.pool3, "conv4_1") 54 | self.conv4_2 = self.conv_layer(self.conv4_1, "conv4_2") 55 | self.conv4_3 = self.conv_layer(self.conv4_2, "conv4_3") 56 | 57 | self.conv4_4_no_activation = self.no_activation_conv_layer(self.conv4_3, "conv4_4") 58 | 59 | self.conv4_4 = self.conv_layer(self.conv4_3, "conv4_4") 60 | self.pool4 = self.max_pool(self.conv4_4, 'pool4') 61 | 62 | self.conv5_1 = self.conv_layer(self.pool4, "conv5_1") 63 | self.conv5_2 = self.conv_layer(self.conv5_1, "conv5_2") 64 | self.conv5_3 = self.conv_layer(self.conv5_2, "conv5_3") 65 | self.conv5_4 = self.conv_layer(self.conv5_3, "conv5_4") 66 | self.pool5 = self.max_pool(self.conv5_4, 'pool5') 67 | 68 | if include_fc: 69 | self.fc6 = self.fc_layer(self.pool5, "fc6") 70 | assert self.fc6.get_shape().as_list()[1:] == [4096] 71 | self.relu6 = tf.nn.relu(self.fc6) 72 | 73 | self.fc7 = self.fc_layer(self.relu6, "fc7") 74 | self.relu7 = tf.nn.relu(self.fc7) 75 | 76 | self.fc8 = self.fc_layer(self.relu7, "fc8") 77 | 78 | self.prob = tf.nn.softmax(self.fc8, name="prob") 79 | 80 | self.data_dict = None 81 | 82 | print(("Finished building vgg19: %ds" % (time.time() - start_time))) 83 | 84 | def avg_pool(self, bottom, name): 85 | return tf.nn.avg_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name) 86 | 87 | def max_pool(self, bottom, name): 88 | return tf.nn.max_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name) 89 | 90 | def conv_layer(self, bottom, name): 91 | with tf.variable_scope(name): 92 | filt = self.get_conv_filter(name) 93 | 94 | conv = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME') 95 | 96 | conv_biases = self.get_bias(name) 97 | bias = tf.nn.bias_add(conv, conv_biases) 98 | 99 | relu = tf.nn.relu(bias) 100 | return relu 101 | 102 | def no_activation_conv_layer(self, bottom, name): 103 | with tf.variable_scope(name): 104 | filt = self.get_conv_filter(name) 105 | 106 | conv = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME') 107 | 108 | conv_biases = self.get_bias(name) 109 | x = tf.nn.bias_add(conv, conv_biases) 110 | 111 | 112 | return x 113 | 114 | def fc_layer(self, bottom, name): 115 | with tf.variable_scope(name): 116 | shape = bottom.get_shape().as_list() 117 | dim = 1 118 | for d in shape[1:]: 119 | dim *= d 120 | x = tf.reshape(bottom, [-1, dim]) 121 | 122 | weights = self.get_fc_weight(name) 123 | biases = self.get_bias(name) 124 | 125 | # Fully connected layer. Note that the '+' operation automatically 126 | # broadcasts the biases. 127 | fc = tf.nn.bias_add(tf.matmul(x, weights), biases) 128 | 129 | return fc 130 | 131 | def get_conv_filter(self, name): 132 | return tf.constant(self.data_dict[name][0], name="filter") 133 | 134 | def get_bias(self, name): 135 | return tf.constant(self.data_dict[name][1], name="biases") 136 | 137 | def get_fc_weight(self, name): 138 | return tf.constant(self.data_dict[name][0], name="weights") --------------------------------------------------------------------------------
nameinputoutput
%s