├── .gitignore
├── CartoonGAN.py
├── LICENSE
├── README.md
├── edge_smooth.py
├── main.py
├── ops.py
├── utils.py
└── vgg19.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/CartoonGAN.py:
--------------------------------------------------------------------------------
1 | from ops import *
2 | from utils import *
3 | from glob import glob
4 | import time
5 | from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch
6 | import numpy as np
7 |
class CartoonGAN(object):
    """CartoonGAN graph builder, trainer and tester (TensorFlow 1.x graph mode).

    Typical use (see ``main.py``): construct with a session and the parsed
    command-line arguments, call :meth:`build_model`, then :meth:`train` or
    :meth:`test`.

    Names such as ``tf``, ``os``, ``conv``, ``vgg_loss``, ``ImageData`` or
    ``save_images`` are expected to come from the star imports of ``ops`` and
    ``utils`` at the top of the file.
    """

    def __init__(self, sess, args):
        """Copy hyper-parameters from ``args`` and collect the training file lists.

        Args:
            sess: TensorFlow session used for all ``run`` / ``save`` / ``restore`` calls.
            args: parsed argument namespace (the attributes read below must exist).

        Side effects: creates the sample directory and prints a configuration
        summary to stdout.
        """
        self.model_name = 'CartoonGAN'
        self.sess = sess
        self.checkpoint_dir = args.checkpoint_dir
        self.result_dir = args.result_dir
        self.log_dir = args.log_dir
        self.dataset_name = args.dataset
        self.augment_flag = args.augment_flag

        self.epoch = args.epoch
        self.init_epoch = args.init_epoch  # args.epoch // 20
        self.iteration = args.iteration
        self.decay_flag = args.decay_flag
        self.decay_epoch = args.decay_epoch

        self.gan_type = args.gan_type

        self.batch_size = args.batch_size
        self.print_freq = args.print_freq
        self.save_freq = args.save_freq

        self.init_lr = args.lr
        self.ch = args.ch

        # Loss weights
        self.adv_weight = args.adv_weight
        self.vgg_weight = args.vgg_weight
        self.ld = args.ld

        # Generator
        self.n_res = args.n_res

        # Discriminator
        self.n_dis = args.n_dis
        self.n_critic = args.n_critic
        self.sn = args.sn

        self.img_size = args.img_size
        self.img_ch = args.img_ch

        # model_dir is a property built from the attributes assigned above,
        # so it must only be read after they are all set.
        self.sample_dir = os.path.join(args.sample_dir, self.model_dir)
        check_folder(self.sample_dir)

        self.trainA_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainA'))
        self.trainB_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainB'))
        self.trainB_smooth_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainB_smooth'))

        # Used as the shuffle buffer size for every input pipeline.
        self.dataset_num = max(len(self.trainA_dataset), len(self.trainB_dataset))

        print()

        print("##### Information #####")
        print("# gan type : ", self.gan_type)
        print("# dataset : ", self.dataset_name)
        print("# max dataset number : ", self.dataset_num)
        print("# batch_size : ", self.batch_size)
        print("# epoch : ", self.epoch)
        print("# init_epoch : ", self.init_epoch)
        print("# iteration per epoch : ", self.iteration)

        print()

        print("##### Generator #####")
        print("# residual blocks : ", self.n_res)

        print()

        print("##### Discriminator #####")
        print("# the number of discriminator layer : ", self.n_dis)
        print("# the number of critic : ", self.n_critic)
        print("# spectral normalization : ", self.sn)

        print()

    ##################################################################################
    # Generator
    ##################################################################################

    def generator(self, x_init, reuse=False, scope="generator"):
        """Build the generator: 7x7 stem, two stride-2 stages, ``n_res`` residual
        blocks, two transposed-conv stages and a 7x7 ``tanh`` output layer.

        Args:
            x_init: input image tensor.
            reuse: forwarded to ``tf.variable_scope`` to share existing variables.
            scope: variable scope name.

        Returns:
            Tensor with ``self.img_ch`` channels, squashed by ``tanh``.
        """
        channel = self.ch
        with tf.variable_scope(scope, reuse=reuse):
            x = conv(x_init, channel, kernel=7, stride=1, pad=3, pad_type='reflect', use_bias=False, scope='conv')
            x = instance_norm(x, scope='ins_norm')
            x = relu(x)

            # Down-Sampling: each stage halves resolution and doubles channels.
            for i in range(2):
                x = conv(x, channel*2, kernel=3, stride=2, pad=1, use_bias=False, scope='conv_s2_'+str(i))
                x = conv(x, channel*2, kernel=3, stride=1, pad=1, use_bias=False, scope='conv_s1_'+str(i))
                x = instance_norm(x, scope='ins_norm_'+str(i))
                x = relu(x)

                channel = channel * 2

            # Bottleneck
            for i in range(self.n_res):
                x = resblock(x, channel, use_bias=False, scope='resblock_' + str(i))

            # Up-Sampling: mirror of the down-sampling stages.
            for i in range(2):
                x = deconv(x, channel//2, kernel=3, stride=2, use_bias=False, scope='deconv_'+str(i))
                x = conv(x, channel//2, kernel=3, stride=1, pad=1, use_bias=False, scope='up_conv_'+str(i))
                x = instance_norm(x, scope='up_ins_norm_'+str(i))
                x = relu(x)

                channel = channel // 2

            x = conv(x, channels=self.img_ch, kernel=7, stride=1, pad=3, pad_type='reflect', use_bias=False, scope='G_logit')
            x = tanh(x)

            return x

    ##################################################################################
    # Discriminator
    ##################################################################################

    def discriminator(self, x_init, reuse=False, scope="discriminator"):
        """Build the convolutional discriminator and return its logit map.

        Args:
            x_init: image tensor to classify.
            reuse: forwarded to ``tf.variable_scope`` to share existing variables.
            scope: variable scope name.

        Returns:
            Single-channel logit tensor (no activation applied).
        """
        channel = self.ch // 2
        with tf.variable_scope(scope, reuse=reuse):
            x = conv(x_init, channel, kernel=3, stride=1, pad=1, use_bias=False, sn=self.sn, scope='conv_0')
            x = lrelu(x, 0.2)

            for i in range(1, self.n_dis):
                x = conv(x, channel * 2, kernel=3, stride=2, pad=1, use_bias=False, sn=self.sn, scope='conv_s2_' + str(i))
                x = lrelu(x, 0.2)

                x = conv(x, channel * 4, kernel=3, stride=1, pad=1, use_bias=False, sn=self.sn, scope='conv_s1_' + str(i))
                x = instance_norm(x, scope='ins_norm_' + str(i))
                x = lrelu(x, 0.2)

                channel = channel * 2

            x = conv(x, channel * 2, kernel=3, stride=1, pad=1, use_bias=False, sn=self.sn, scope='last_conv')
            x = instance_norm(x, scope='last_ins_norm')
            x = lrelu(x, 0.2)

            x = conv(x, channels=1, kernel=3, stride=1, pad=1, use_bias=False, sn=self.sn, scope='D_logit')

            return x

    ##################################################################################
    # Model
    ##################################################################################
    def gradient_panalty(self, real, fake, scope="discriminator"):
        """Return the gradient penalty term selected by ``self.gan_type``.

        For ``dragan`` the ``fake`` argument is replaced by a noisy copy of
        ``real``. The discriminator gradient is then evaluated at a random
        interpolation between ``real`` and ``fake``.

        Args:
            real: batch of real images.
            fake: batch of generated (or otherwise "negative") images.
            scope: variable scope of the discriminator to reuse.

        Returns:
            ``self.ld``-weighted scalar penalty tensor, or ``0`` when the gan
            type matches neither the ``lp`` nor the ``gp`` / ``dragan`` case.
        """
        if 'dragan' in self.gan_type:
            eps = tf.random_uniform(shape=tf.shape(real), minval=0., maxval=1.)
            _, x_var = tf.nn.moments(real, axes=[0, 1, 2, 3])
            x_std = tf.sqrt(x_var)  # magnitude of noise decides the size of local region

            fake = real + 0.5 * x_std * eps

        alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0., maxval=1.)
        interpolated = real + alpha * (fake - real)

        logit = self.discriminator(interpolated, reuse=True, scope=scope)

        grad = tf.gradients(logit, interpolated)[0]  # gradient of D(interpolated)
        grad_norm = tf.norm(flatten(grad), axis=1)  # l2 norm

        GP = 0
        # WGAN - LP: one-sided penalty, only norms above 1 are punished.
        if 'lp' in self.gan_type:
            GP = self.ld * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.)))

        elif 'gp' in self.gan_type or self.gan_type == 'dragan':
            GP = self.ld * tf.reduce_mean(tf.square(grad_norm - 1.))

        return GP

    def build_model(self):
        """Create input pipelines, networks, losses, optimizers and summaries.

        Populates, among others: ``self.lr`` (placeholder), ``self.real_A``,
        ``self.real_B``, ``self.real_B_smooth``, ``self.fake_B``,
        ``self.test_real_A`` / ``self.test_fake_B``, the three loss tensors,
        ``self.init_optim`` / ``self.G_optim`` / ``self.D_optim`` and the
        merged summary ops.
        """
        self.lr = tf.placeholder(tf.float32, name='learning_rate')

        # Input Image
        Image_Data_Class = ImageData(self.img_size, self.img_ch, self.augment_flag)

        trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
        trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)
        trainB_smooth = tf.data.Dataset.from_tensor_slices(self.trainB_smooth_dataset)

        gpu_device = '/gpu:0'

        def _batched(dataset):
            # Shared pipeline: shuffle + repeat, decode + batch, prefetch onto the GPU.
            return dataset.apply(shuffle_and_repeat(self.dataset_num)).apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)).apply(prefetch_to_device(gpu_device, self.batch_size))

        trainA = _batched(trainA)
        trainB = _batched(trainB)
        trainB_smooth = _batched(trainB_smooth)

        trainA_iterator = trainA.make_one_shot_iterator()
        trainB_iterator = trainB.make_one_shot_iterator()
        trainB_smooth_iterator = trainB_smooth.make_one_shot_iterator()

        self.real_A = trainA_iterator.get_next()
        self.real_B = trainB_iterator.get_next()
        self.real_B_smooth = trainB_smooth_iterator.get_next()

        self.test_real_A = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_real_A')

        # Define Generator, Discriminator
        self.fake_B = self.generator(self.real_A)

        real_B_logit = self.discriminator(self.real_B)
        fake_B_logit = self.discriminator(self.fake_B, reuse=True)
        real_B_smooth_logit = self.discriminator(self.real_B_smooth, reuse=True)

        # Define Loss
        if 'gp' in self.gan_type or 'lp' in self.gan_type or 'dragan' in self.gan_type:
            GP = self.gradient_panalty(real=self.real_B, fake=self.fake_B) + self.gradient_panalty(self.real_B, fake=self.real_B_smooth)
        else:
            GP = 0.0

        v_loss = self.vgg_weight * vgg_loss(self.real_A, self.fake_B)
        g_loss = self.adv_weight * generator_loss(self.gan_type, fake_B_logit)
        d_loss = self.adv_weight * discriminator_loss(self.gan_type, real_B_logit, fake_B_logit, real_B_smooth_logit) + GP

        self.Vgg_loss = v_loss
        self.Generator_loss = g_loss + v_loss
        self.Discriminator_loss = d_loss

        # Result Image
        self.test_fake_B = self.generator(self.test_real_A, reuse=True)

        # Training
        t_vars = tf.trainable_variables()
        G_vars = [var for var in t_vars if 'generator' in var.name]
        D_vars = [var for var in t_vars if 'discriminator' in var.name]

        # init_optim pre-trains the generator on the VGG loss alone.
        self.init_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Vgg_loss, var_list=G_vars)
        self.G_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Generator_loss, var_list=G_vars)
        self.D_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Discriminator_loss, var_list=D_vars)

        # Summary
        self.G_loss = tf.summary.scalar("Generator_loss", self.Generator_loss)
        self.D_loss = tf.summary.scalar("Discriminator_loss", self.Discriminator_loss)

        self.G_gan = tf.summary.scalar("G_gan", g_loss)
        self.G_vgg = tf.summary.scalar("G_vgg", v_loss)

        self.V_loss_merge = tf.summary.merge([self.G_vgg])
        self.G_loss_merge = tf.summary.merge([self.G_loss, self.G_gan, self.G_vgg])
        self.D_loss_merge = tf.summary.merge([self.D_loss])

    def train(self):
        """Run the training loop, resuming from the latest checkpoint if one exists.

        Epochs below ``self.init_epoch`` only run ``init_optim`` (VGG loss).
        Later epochs update the discriminator every step and the generator
        every ``self.n_critic`` steps. Samples are written every
        ``print_freq`` iterations and checkpoints every ``save_freq``
        iterations and at the end of each epoch.
        """
        # initialize all variables
        tf.global_variables_initializer().run()

        # saver to save model
        self.saver = tf.train.Saver()

        # summary writer
        self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph)

        # restore check-point if it exists
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            start_epoch = (int)(checkpoint_counter / self.iteration)
            start_batch_id = checkpoint_counter - start_epoch * self.iteration
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_epoch = 0
            start_batch_id = 0
            counter = 1
            print(" [!] Load failed...")

        # loop for epoch
        start_time = time.time()
        past_g_loss = -1.
        lr = self.init_lr
        # Most recent generator input/output batches. They stay None until a
        # step that actually runs the generator, which may not be the first
        # step when resuming with n_critic > 1.
        real_A_images, fake_B_images = None, None
        for epoch in range(start_epoch, self.epoch):
            # lr = self.init_lr if epoch < self.decay_epoch else self.init_lr * (self.epoch - epoch) / (self.epoch - self.decay_epoch)
            if self.decay_flag:
                # Halve the learning rate every decay_epoch epochs.
                lr = self.init_lr * pow(0.5, epoch // self.decay_epoch)

            for idx in range(start_batch_id, self.iteration):

                train_feed_dict = {
                    self.lr: lr
                }

                if epoch < self.init_epoch:
                    # Init G
                    real_A_images, fake_B_images, _, v_loss, summary_str = self.sess.run([self.real_A, self.fake_B,
                                                                                          self.init_optim,
                                                                                          self.Vgg_loss, self.V_loss_merge], feed_dict=train_feed_dict)
                    self.writer.add_summary(summary_str, counter)
                    print("Epoch: [%3d] [%5d/%5d] time: %4.4f v_loss: %.8f" % (epoch, idx, self.iteration, time.time() - start_time, v_loss))

                else:
                    # Update D
                    _, d_loss, summary_str = self.sess.run([self.D_optim, self.Discriminator_loss, self.D_loss_merge], feed_dict=train_feed_dict)
                    self.writer.add_summary(summary_str, counter)

                    # Update G
                    g_loss = None
                    if (counter - 1) % self.n_critic == 0:
                        real_A_images, fake_B_images, _, g_loss, summary_str = self.sess.run([self.real_A, self.fake_B,
                                                                                              self.G_optim,
                                                                                              self.Generator_loss, self.G_loss_merge], feed_dict=train_feed_dict)
                        self.writer.add_summary(summary_str, counter)
                        past_g_loss = g_loss

                    # On steps without a generator update, report the last known value.
                    if g_loss is None:
                        g_loss = past_g_loss
                    print("Epoch: [%3d] [%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" % (epoch, idx, self.iteration, time.time() - start_time, d_loss, g_loss))

                # display training status
                counter += 1

                # Skip the sample dump if no generator batch has been fetched yet.
                if np.mod(idx+1, self.print_freq) == 0 and fake_B_images is not None:
                    save_images(real_A_images, [self.batch_size, 1],
                                './{}/real_A_{:03d}_{:05d}.png'.format(self.sample_dir, epoch, idx+1))
                    save_images(fake_B_images, [self.batch_size, 1],
                                './{}/fake_B_{:03d}_{:05d}.png'.format(self.sample_dir, epoch, idx+1))

                if np.mod(idx + 1, self.save_freq) == 0:
                    self.save(self.checkpoint_dir, counter)

            # After an epoch, start_batch_id is set to zero
            # non-zero value is only for the first epoch after loading pre-trained model
            start_batch_id = 0

            # save model for final step
            self.save(self.checkpoint_dir, counter)

    @property
    def model_dir(self):
        """Directory name that encodes the model configuration.

        Used as the sub-folder for samples, logs, checkpoints and results.
        """
        n_res = str(self.n_res) + 'resblock'
        n_dis = str(self.n_dis) + 'dis'
        return "{}_{}_{}_{}_{}_{}_{}_{}_{}".format(self.model_name, self.dataset_name,
                                                   self.gan_type, n_res, n_dis,
                                                   self.n_critic, self.sn,
                                                   int(self.adv_weight), int(self.vgg_weight))

    def save(self, checkpoint_dir, step):
        """Write a checkpoint for ``step`` under ``checkpoint_dir/model_dir``.

        Requires ``self.saver`` to have been created (done in ``train``/``test``).
        """
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)

    def load(self, checkpoint_dir):
        """Restore the latest checkpoint from ``checkpoint_dir/model_dir``.

        Returns:
            ``(True, step)`` when a checkpoint was restored, where ``step`` is
            parsed from the trailing ``-<number>`` of the checkpoint name;
            ``(False, 0)`` otherwise.
        """
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)  # checkpoint file information

        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)  # first line
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            counter = int(ckpt_name.split('-')[-1])
            print(" [*] Success to read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0

    def test(self):
        """Translate every file in ``./dataset/<name>/testA`` and write an HTML index.

        Results are stored in ``result_dir/model_dir``. Note that
        ``self.result_dir`` is rebound to that sub-folder. If no checkpoint can
        be loaded, the loop still runs with the freshly initialised variables.
        """
        tf.global_variables_initializer().run()
        test_A_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testA'))

        self.saver = tf.train.Saver()
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        self.result_dir = os.path.join(self.result_dir, self.model_dir)
        check_folder(self.result_dir)

        if could_load:
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        # write html for visual comparison
        index_path = os.path.join(self.result_dir, 'index.html')
        image_cell = "<td><img src='%s' width='%d' height='%d'></td>"

        # The context manager closes the index file even if inference fails mid-way.
        with open(index_path, 'w') as index:
            index.write("<html><body><table>")
            index.write("<tr><th>name</th><th>input</th><th>output</th></tr>")

            for sample_file in test_A_files:  # A -> B
                print('Processing A image: ' + sample_file)
                sample_image = np.asarray(load_test_data(sample_file))
                image_path = os.path.join(self.result_dir, '{0}'.format(os.path.basename(sample_file)))

                fake_img = self.sess.run(self.test_fake_B, feed_dict={self.test_real_A: sample_image})
                save_images(fake_img, [1, 1], image_path)

                index.write("<tr>")
                index.write("<td>%s</td>" % os.path.basename(image_path))

                # Relative paths are rewritten so they resolve from the index file's folder.
                index.write(image_cell % (sample_file if os.path.isabs(sample_file) else (
                    '../..' + os.path.sep + sample_file), self.img_size, self.img_size))
                index.write(image_cell % (image_path if os.path.isabs(image_path) else (
                    '../..' + os.path.sep + image_path), self.img_size, self.img_size))
                index.write("</tr>")

            index.write("</table></body></html>")
415 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Junho Kim (1993.01.12)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CartoonGAN-Tensorflow
2 | Simple Tensorflow implementation of [CartoonGAN](http://openaccess.thecvf.com/content_cvpr_2018/papers/Chen_CartoonGAN_Generative_Adversarial_CVPR_2018_paper.pdf) (CVPR 2018)
3 |
4 | ## Pytorch version
5 | * [CartoonGAN-Pytorch](https://github.com/znxlwm/pytorch-CartoonGAN)
6 |
7 | ## Requirements
8 | * Tensorflow 1.8
9 | * Python 3.6
10 |
11 | ## Usage
12 | ### 1. Download vgg19
13 | * [vgg19.npy](https://mega.nz/#!xZ8glS6J!MAnE91ND_WyfZ_8mvkuSa2YcA7q-1ehfSm-Q1fxOvvs)
14 |
15 | ### 2. Do edge_smooth
16 | ```
17 | > python edge_smooth.py --dataset face2anime --img_size 256
18 | ```
19 |
20 | ```
21 | ├── dataset
22 |    └── YOUR_DATASET_NAME
23 |        ├── trainA
24 |            ├── xxx.jpg (name, format doesn't matter)
25 |            ├── yyy.png
26 |            └── ...
27 |        ├── trainB
28 |            ├── zzz.jpg
29 |            ├── www.png
30 |            └── ...
31 |        ├── trainB_smooth (After you run the above code, it will be created automatically)
32 |            ├── zzz.jpg
33 |            ├── www.png
34 |            └── ...
35 |        └── testA
36 |            ├── aaa.jpg
37 |            ├── bbb.png
38 |            └── ...
39 | ```
40 |
41 | ### 3. Train
42 | * python main.py --phase train --dataset face2anime --epoch 100 --init_epoch 1
43 |
44 | ### 4. Test
45 | * python main.py --phase test --dataset face2anime
46 |
47 | ## Author
48 | Junho Kim
49 |
--------------------------------------------------------------------------------
/edge_smooth.py:
--------------------------------------------------------------------------------
1 | from utils import check_folder
2 | import numpy as np
3 | import cv2, os, argparse
4 | from glob import glob
5 | from tqdm import tqdm
6 |
def parse_args():
    """Parse the command-line options of the edge-smoothing script.

    Returns:
        argparse.Namespace with ``dataset`` (str) and ``img_size`` (int).
    """
    cli = argparse.ArgumentParser(description="Edge smoothed")
    cli.add_argument('--dataset', type=str, default='hw', help='dataset_name')
    cli.add_argument('--img_size', type=int, default=256, help='The size of image')
    parsed = cli.parse_args()
    return parsed
14 |
def make_edge_smooth(dataset_name, img_size):
    """Create edge-blurred copies of every ``trainB`` image in ``trainB_smooth``.

    Each image is resized to ``img_size`` x ``img_size``. Canny edges are
    detected on its grayscale version and dilated, and every pixel inside the
    dilated edge mask is replaced by a 5x5 Gaussian-weighted average of its
    reflect-padded neighbourhood.

    Args:
        dataset_name: folder name under ``./dataset``.
        img_size: side length (pixels) of the square output images.

    Files that OpenCV cannot decode are skipped with a message rather than
    aborting the whole run.
    """
    save_dir = './dataset/{}/trainB_smooth'.format(dataset_name)
    check_folder(save_dir)

    file_list = glob('./dataset/{}/{}/*.*'.format(dataset_name, 'trainB'))

    kernel_size = 5
    pad = kernel_size // 2  # border needed so every pixel has a full window
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    gauss = cv2.getGaussianKernel(kernel_size, 0)
    gauss = gauss * gauss.transpose(1, 0)  # outer product -> 2-D kernel

    for f in tqdm(file_list):
        file_name = os.path.basename(f)

        bgr_img = cv2.imread(f)
        gray_img = cv2.imread(f, 0)
        # cv2.imread signals an unreadable file by returning None, not by
        # raising. The '*.*' glob can match non-image files.
        if bgr_img is None or gray_img is None:
            tqdm.write('skipping unreadable image: {}'.format(f))
            continue

        bgr_img = cv2.resize(bgr_img, (img_size, img_size))
        pad_img = np.pad(bgr_img, ((pad, pad), (pad, pad), (0, 0)), mode='reflect')
        gray_img = cv2.resize(gray_img, (img_size, img_size))

        edges = cv2.Canny(gray_img, 100, 200)
        dilation = cv2.dilate(edges, kernel)

        gauss_img = np.copy(bgr_img)
        rows, cols = np.where(dilation != 0)
        for r, c in zip(rows, cols):
            # pad_img is shifted by `pad`, so this window is centred on (r, c).
            window = pad_img[r:r + kernel_size, c:c + kernel_size]
            for ch in range(3):
                gauss_img[r, c, ch] = np.sum(np.multiply(window[:, :, ch], gauss))

        cv2.imwrite(os.path.join(save_dir, file_name), gauss_img)
50 |
51 | """main"""
def main():
    """Script entry point: read the CLI options and smooth the dataset edges."""
    # parse arguments
    options = parse_args()
    if options is None:
        exit()

    dataset_name, image_size = options.dataset, options.img_size
    make_edge_smooth(dataset_name, image_size)


if __name__ == '__main__':
    main()
63 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from CartoonGAN import CartoonGAN
2 | import argparse
3 | from utils import *
4 |
5 | """parsing and configuration"""
6 |
def parse_args():
    """Define and parse the CartoonGAN command-line options.

    Returns:
        Whatever ``check_args`` returns for the parsed namespace.
    """
    parser = argparse.ArgumentParser(description="Tensorflow implementation of CartoonGAN")
    add = parser.add_argument

    # Run mode and data
    add('--phase', type=str, default='train', help='train or test ?')
    add('--dataset', type=str, default='face2anime', help='dataset_name')

    # Schedule
    add('--epoch', type=int, default=100, help='The number of epochs to run')
    add('--init_epoch', type=int, default=1, help='The number of epochs for weight initialization')
    add('--iteration', type=int, default=10000, help='The number of training iterations')
    add('--batch_size', type=int, default=1, help='The size of batch size')
    add('--print_freq', type=int, default=1000, help='The number of image_print_freq')
    add('--save_freq', type=int, default=1000, help='The number of ckpt_save_freq')
    add('--decay_flag', type=str2bool, default=False, help='The decay_flag')
    add('--decay_epoch', type=int, default=10, help='decay epoch')

    # Optimisation and loss weights
    add('--lr', type=float, default=0.0001, help='The learning rate')
    add('--ld', type=float, default=10.0, help='The gradient penalty lambda')
    add('--adv_weight', type=float, default=1.0, help='Weight about GAN')
    add('--vgg_weight', type=float, default=10.0, help='Weight about VGG19')
    add('--gan_type', type=str, default='gan', help='[gan / lsgan / wgan-gp / wgan-lp / dragan / hinge')

    # Generator
    add('--ch', type=int, default=64, help='base channel number per layer')
    add('--n_res', type=int, default=8, help='The number of resblock')

    # Discriminator
    add('--n_dis', type=int, default=3, help='The number of discriminator layer')
    add('--n_critic', type=int, default=1, help='The number of critic')
    add('--sn', type=str2bool, default=False, help='using spectral norm')

    # Images
    add('--img_size', type=int, default=256, help='The size of image')
    add('--img_ch', type=int, default=3, help='The size of image channel')
    add('--augment_flag', type=str2bool, default=False, help='Image augmentation use or not')

    # Output locations
    add('--checkpoint_dir', type=str, default='checkpoint',
        help='Directory name to save the checkpoints')
    add('--result_dir', type=str, default='results',
        help='Directory name to save the generated images')
    add('--log_dir', type=str, default='logs',
        help='Directory name to save training logs')
    add('--sample_dir', type=str, default='samples',
        help='Directory name to save the samples on training')

    namespace = parser.parse_args()
    return check_args(namespace)
49 |
50 | """checking arguments"""
def check_args(args):
    """Create the output folders and validate the numeric arguments.

    Args:
        args: parsed argument namespace.

    Returns:
        ``args`` when every check passes, otherwise ``None`` (after printing
        one message per failed check) so that the caller's ``args is None``
        guard can stop the program.
    """
    # --checkpoint_dir, --result_dir, --log_dir, --sample_dir
    for folder in (args.checkpoint_dir, args.result_dir, args.log_dir, args.sample_dir):
        check_folder(folder)

    # Explicit comparisons instead of `assert` inside a bare `except`:
    # asserts are stripped under `python -O`, and the bare except hid every
    # other error as well.
    valid = True

    # --epoch
    if args.epoch < 1:
        print('number of epochs must be larger than or equal to one')
        valid = False

    # --batch_size
    if args.batch_size < 1:
        print('batch size must be larger than or equal to one')
        valid = False

    return args if valid else None
76 |
77 | """main"""
def main():
    """Program entry point: build the CartoonGAN graph, then train or test it."""
    # parse arguments
    args = parse_args()
    if args is None:
        exit()

    # open session
    session_config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=session_config) as sess:
        gan = CartoonGAN(sess, args)

        # build graph
        gan.build_model()

        # show network architecture
        show_all_variables()

        phase = args.phase
        if phase == 'train':
            gan.train()
            print(" [*] Training finished!")
        elif phase == 'test':
            gan.test()
            print(" [*] Test finished!")


if __name__ == '__main__':
    main()
104 |
--------------------------------------------------------------------------------
/ops.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorflow.contrib as tf_contrib
3 | from vgg19 import Vgg19
4 |
# Alternative initializers / regularizers that can be assigned below:
# Xavier : tf_contrib.layers.xavier_initializer()
# He : tf_contrib.layers.variance_scaling_initializer()
# Normal : tf.random_normal_initializer(mean=0.0, stddev=0.02)
# l2_decay : tf_contrib.layers.l2_regularizer(0.0001)


# Kernel initializer shared by conv() and deconv(): zero-mean normal, stddev 0.02.
weight_init = tf.random_normal_initializer(mean=0.0, stddev=0.02)
# Kernel regularizer shared by conv() and deconv(); None disables regularization.
weight_regularizer = None
13 |
14 | ##################################################################################
15 | # Layer
16 | ##################################################################################
17 |
def conv(x, channels, kernel=4, stride=2, pad=0, pad_type='zero', use_bias=True, sn=False, scope='conv_0'):
    """2-D convolution with explicit padding and optional spectral normalization.

    The input is padded by hand (zero or reflect), then convolved with VALID
    padding. When ``kernel - stride`` is odd, the bottom/right padding is
    ``kernel - stride - pad`` instead of ``pad``.

    Args:
        x: input tensor (indexed as batch, height, width, channels by the padding spec).
        channels: number of output channels.
        kernel: square kernel size.
        stride: stride along both spatial axes.
        pad: padding applied to the top/left (and bottom/right in the even case).
        pad_type: ``'zero'`` or ``'reflect'``; any other value skips padding.
        use_bias: whether to add a bias term.
        sn: when True, the kernel is passed through ``spectral_norm``.
        scope: variable scope name.

    Returns:
        The convolved tensor.
    """
    with tf.variable_scope(scope):
        pad_top = pad_left = pad
        if (kernel - stride) % 2 == 0:
            pad_bottom = pad_right = pad
        else:
            pad_bottom = kernel - stride - pad_top
            pad_right = kernel - stride - pad_left

        paddings = [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]
        if pad_type == 'zero':
            x = tf.pad(x, paddings)
        elif pad_type == 'reflect':
            x = tf.pad(x, paddings, mode='REFLECT')

        if not sn:
            x = tf.layers.conv2d(inputs=x, filters=channels,
                                 kernel_size=kernel, kernel_initializer=weight_init,
                                 kernel_regularizer=weight_regularizer,
                                 strides=stride, use_bias=use_bias)
        else:
            # Manual variables so the kernel can be spectrally normalized.
            w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels], initializer=weight_init,
                                regularizer=weight_regularizer)
            x = tf.nn.conv2d(input=x, filter=spectral_norm(w),
                             strides=[1, stride, stride, 1], padding='VALID')
            if use_bias:
                bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
                x = tf.nn.bias_add(x, bias)

        return x
54 |
def deconv(x, channels, kernel=4, stride=2, use_bias=True, sn=False, scope='deconv_0'):
    """Transposed 2-D convolution (SAME padding) that upsamples by ``stride``.

    Args:
        x: input tensor (indexed as batch, height, width, channels).
        channels: number of output channels.
        kernel: square kernel size.
        stride: upsampling factor along both spatial axes.
        use_bias: whether to add a bias term.
        sn: when True, the kernel is passed through ``spectral_norm``.
        scope: variable scope name.

    Returns:
        The upsampled tensor.
    """
    with tf.variable_scope(scope):
        if sn:
            # tf.nn.conv2d_transpose needs a fully specified output shape.
            # Dimensions that are unknown statically (e.g. a None batch size)
            # are taken from the runtime shape instead of being passed as None.
            static_shape = x.get_shape().as_list()
            dynamic_shape = tf.shape(x)
            batch, height, width = [
                static_shape[i] if static_shape[i] is not None else dynamic_shape[i]
                for i in range(3)
            ]
            output_shape = tf.stack([batch, height * stride, width * stride, channels])

            w = tf.get_variable("kernel", shape=[kernel, kernel, channels, x.get_shape()[-1]], initializer=weight_init, regularizer=weight_regularizer)
            x = tf.nn.conv2d_transpose(x, filter=spectral_norm(w), output_shape=output_shape, strides=[1, stride, stride, 1], padding='SAME')

            if use_bias:
                bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
                x = tf.nn.bias_add(x, bias)

        else:
            x = tf.layers.conv2d_transpose(inputs=x, filters=channels,
                                           kernel_size=kernel, kernel_initializer=weight_init, kernel_regularizer=weight_regularizer,
                                           strides=stride, padding='SAME', use_bias=use_bias)

        return x
73 |
74 |
75 | ##################################################################################
76 | # Residual-block
77 | ##################################################################################
78 |
def resblock(x_init, channels, use_bias=True, scope='resblock_0'):
    """Residual block: two reflect-padded 3x3 convs with instance norm, plus skip.

    Args:
        x_init: block input, also added back to the output.
        channels: number of channels of both convolutions.
        use_bias: forwarded to both ``conv`` calls.
        scope: variable scope name.

    Returns:
        ``x_init`` plus the output of the second normalized convolution.
    """
    with tf.variable_scope(scope):
        with tf.variable_scope('res1'):
            branch = conv(x_init, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias)
            branch = relu(instance_norm(branch))

        with tf.variable_scope('res2'):
            branch = conv(branch, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias)
            branch = instance_norm(branch)

        return branch + x_init
91 |
92 | ##################################################################################
93 | # Sampling
94 | ##################################################################################
95 |
def flatten(x):
    """Flatten `x` with `tf.layers.flatten` and return the result."""
    flat = tf.layers.flatten(x)
    return flat
98 |
99 | ##################################################################################
100 | # Activation function
101 | ##################################################################################
102 |
def lrelu(x, alpha=0.2):
    """Leaky ReLU of `x`; `alpha` is forwarded to `tf.nn.leaky_relu`."""
    activated = tf.nn.leaky_relu(x, alpha)
    return activated
105 |
106 |
def relu(x):
    """ReLU of `x` via `tf.nn.relu`."""
    activated = tf.nn.relu(x)
    return activated
109 |
110 |
def tanh(x):
    """Hyperbolic tangent of `x` via `tf.tanh`."""
    activated = tf.tanh(x)
    return activated
113 |
def sigmoid(x):
    """Logistic sigmoid of `x` via `tf.sigmoid`."""
    activated = tf.sigmoid(x)
    return activated
116 |
117 | ##################################################################################
118 | # Normalization function
119 | ##################################################################################
120 |
def instance_norm(x, scope='instance_norm'):
    """Apply `tf_contrib.layers.instance_norm` with learned centre and scale.

    Uses epsilon 1e-05; the layer is created under `scope`.
    """
    normalized = tf_contrib.layers.instance_norm(
        x, epsilon=1e-05, center=True, scale=True, scope=scope)
    return normalized
126 |
def layer_norm(x, scope='layer_norm'):
    """Apply `tf_contrib.layers.layer_norm` with learned centre and scale.

    The layer is created under `scope`.
    """
    normalized = tf_contrib.layers.layer_norm(
        x, center=True, scale=True, scope=scope)
    return normalized
131 |
def batch_norm(x, is_training=True, scope='batch_norm'):
    """Apply `tf_contrib.layers.batch_norm` with decay 0.9 and epsilon 1e-05.

    `updates_collections=None` is passed, so no update ops are added to a
    collection for the caller to run separately. `is_training` is forwarded
    to the layer unchanged.
    """
    normalized = tf_contrib.layers.batch_norm(
        x,
        decay=0.9, epsilon=1e-05,
        center=True, scale=True, updates_collections=None,
        is_training=is_training, scope=scope)
    return normalized
137 |
138 |
def spectral_norm(w, iteration=1):
    """Divide `w` by a power-iteration estimate of its largest singular value.

    `w` is viewed as a 2-D matrix whose columns correspond to its last axis.
    A non-trainable variable named "u" (shape [1, last_dim]) holds the
    running power-iteration vector; it is assigned the latest estimate
    whenever the returned tensor is evaluated.

    Args:
        w: weight tensor with a fully defined static shape.
        iteration: number of power-iteration steps (1 is usually enough).

    Returns:
        A tensor with the same shape as `w`, scaled by 1 / sigma.
    """
    original_shape = w.shape.as_list()
    last_dim = original_shape[-1]
    w_matrix = tf.reshape(w, [-1, last_dim])

    u = tf.get_variable("u", [1, last_dim], initializer=tf.truncated_normal_initializer(), trainable=False)

    right = u
    left = None
    for _ in range(iteration):
        # One power-iteration step: refine the left vector, then the right.
        left = l2_norm(tf.matmul(right, tf.transpose(w_matrix)))
        right = l2_norm(tf.matmul(left, w_matrix))

    sigma = tf.matmul(tf.matmul(left, w_matrix), tf.transpose(right))
    scaled = w_matrix / sigma

    # Tie the update of `u` to the reshape so it runs with the forward pass.
    with tf.control_dependencies([u.assign(right)]):
        scaled = tf.reshape(scaled, original_shape)

    return scaled
165 |
def l2_norm(v, eps=1e-12):
    """Divide `v` by its L2 norm (taken over all elements) plus `eps`."""
    sum_of_squares = tf.reduce_sum(v ** 2)
    magnitude = sum_of_squares ** 0.5
    return v / (magnitude + eps)
168 |
169 | ##################################################################################
170 | # Loss function
171 | ##################################################################################
172 |
def L1_loss(x, y):
    """Mean absolute difference between `x` and `y`."""
    difference = x - y
    return tf.reduce_mean(tf.abs(difference))
177 |
def discriminator_loss(loss_func, real, fake, real_blur):
    """Discriminator loss over three batches of discriminator outputs.

    `real` is pushed towards the "real" target, while both `fake` and
    `real_blur` are pushed towards the "fake" target. The three terms are
    summed.

    Args:
        loss_func: one of 'wgan-gp', 'wgan-lp', 'lsgan', 'gan', 'dragan',
            'hinge'. Any other value leaves every term at 0.
        real: discriminator outputs for the `real` batch.
        fake: discriminator outputs for the `fake` batch.
        real_blur: discriminator outputs for the `real_blur` batch.

    Returns:
        The summed loss (the integer 0 when `loss_func` is not recognised).
    """
    real_loss = fake_loss = real_blur_loss = 0

    if loss_func in ('wgan-gp', 'wgan-lp'):
        real_loss = -tf.reduce_mean(real)
        fake_loss = tf.reduce_mean(fake)
        real_blur_loss = tf.reduce_mean(real_blur)

    elif loss_func == 'lsgan':
        real_loss = tf.reduce_mean(tf.square(real - 1.0))
        fake_loss = tf.reduce_mean(tf.square(fake))
        real_blur_loss = tf.reduce_mean(tf.square(real_blur))

    elif loss_func in ('gan', 'dragan'):
        xent = tf.nn.sigmoid_cross_entropy_with_logits
        real_loss = tf.reduce_mean(xent(labels=tf.ones_like(real), logits=real))
        fake_loss = tf.reduce_mean(xent(labels=tf.zeros_like(fake), logits=fake))
        real_blur_loss = tf.reduce_mean(xent(labels=tf.zeros_like(real_blur), logits=real_blur))

    elif loss_func == 'hinge':
        real_loss = tf.reduce_mean(relu(1.0 - real))
        fake_loss = tf.reduce_mean(relu(1.0 + fake))
        real_blur_loss = tf.reduce_mean(relu(1.0 + real_blur))

    return real_loss + fake_loss + real_blur_loss
207 |
def generator_loss(loss_func, fake):
    """Adversarial loss for the generator given discriminator outputs `fake`.

    Args:
        loss_func: one of 'wgan-gp', 'wgan-lp', 'lsgan', 'gan', 'dragan',
            'hinge'. Any other value yields 0.
        fake: discriminator outputs for generated samples.

    Returns:
        The loss (the integer 0 when `loss_func` is not recognised).
    """
    if loss_func in ('wgan-gp', 'wgan-lp'):
        return -tf.reduce_mean(fake)

    if loss_func == 'lsgan':
        return tf.reduce_mean(tf.square(fake - 1.0))

    if loss_func in ('gan', 'dragan'):
        return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake), logits=fake))

    if loss_func == 'hinge':
        return -tf.reduce_mean(fake)

    return 0
226 |
def vgg_loss(real, fake):
    """L1 distance between the VGG-19 conv4_4 pre-activation maps of two batches.

    A single `Vgg19` instance, loaded from 'vgg19.npy', is built once on
    `real` and once on `fake`; the `conv4_4_no_activation` output of each
    build is compared with `L1_loss`.
    """
    vgg = Vgg19('vgg19.npy')

    feature_maps = []
    for images in (real, fake):
        vgg.build(images)
        feature_maps.append(vgg.conv4_4_no_activation)

    real_feature_map, fake_feature_map = feature_maps
    return L1_loss(real_feature_map, fake_feature_map)
239 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.contrib import slim
3 | from scipy import misc
4 | import os, random
5 | import numpy as np
6 |
class ImageData:
    """Graph-mode loader that turns a JPEG filename into a normalised image."""

    def __init__(self, load_size, channels, augment_flag):
        """Store the loader configuration.

        Args:
            load_size: side length the decoded image is resized to.
            channels: channel count requested from the JPEG decoder.
            augment_flag: when truthy, random augmentation is applied to
                roughly half of the images.
        """
        self.load_size = load_size
        self.channels = channels
        self.augment_flag = augment_flag

    def image_processing(self, filename):
        """Read, decode, resize and scale one image to the range [-1, 1].

        Args:
            filename: path (string tensor) of a JPEG file.

        Returns:
            A float32 image tensor of size `load_size` x `load_size`.
        """
        raw = tf.read_file(filename)
        decoded = tf.image.decode_jpeg(raw, channels=self.channels)
        resized = tf.image.resize_images(decoded, [self.load_size, self.load_size])
        scaled = tf.cast(resized, tf.float32) / 127.5 - 1

        if not self.augment_flag:
            return scaled

        augment_size = self.load_size + (30 if self.load_size == 256 else 15)
        # The coin flip must be a graph op: a Python-level random.random() is
        # evaluated only once while the graph is traced, which would switch
        # augmentation on or off for the entire pipeline instead of per image.
        should_augment = tf.random_uniform([], 0.0, 1.0) > 0.5
        return tf.cond(should_augment,
                       lambda: augmentation(scaled, augment_size),
                       lambda: scaled)
27 |
28 |
def load_test_data(image_path, size=256):
    """Load one RGB image, resize it to `size` x `size` and scale to [-1, 1].

    Returns an array with a leading batch axis of length 1.
    """
    pixels = misc.imread(image_path, mode='RGB')
    resized = misc.imresize(pixels, [size, size])
    batch = np.expand_dims(resized, axis=0)
    return preprocessing(batch)
36 |
def preprocessing(x):
    """Map pixel values from [0, 255] to [-1, 1]."""
    return x / 127.5 - 1
40 |
def augmentation(image, augment_size):
    """Randomly flip, enlarge to `augment_size`, then crop back to the input shape.

    One Python-side seed is drawn per call and shared by the flip and crop ops.
    """
    seed = random.randint(0, 2 ** 31 - 1)
    target_shape = tf.shape(image)
    flipped = tf.image.random_flip_left_right(image, seed=seed)
    enlarged = tf.image.resize_images(flipped, [augment_size, augment_size])
    return tf.random_crop(enlarged, target_shape, seed=seed)
48 |
def save_images(images, size, image_path):
    """Rescale `images` with `inverse_transform` and write them as one grid image.

    `size` and `image_path` are forwarded to `imsave`.
    """
    rescaled = inverse_transform(images)
    return imsave(rescaled, size, image_path)
51 |
def inverse_transform(images):
    """Map values from [-1, 1] to [0, 1]."""
    shifted = images + 1.
    return shifted / 2
54 |
55 |
def imsave(images, size, path):
    """Tile `images` into a grid with `merge` and write it to `path`."""
    grid = merge(images, size)
    return misc.imsave(path, grid)
58 |
def merge(images, size):
    """Tile a batch of images into one grid, filled row by row.

    Args:
        images: array indexed as [n, height, width, ...].
        size: (rows, cols) of the grid.

    Returns:
        A float array of shape (height * rows, width * cols, 3).
    """
    tile_h, tile_w = images.shape[1], images.shape[2]
    rows, cols = size[0], size[1]
    canvas = np.zeros((tile_h * rows, tile_w * cols, 3))
    for idx, tile in enumerate(images):
        row, col = divmod(idx, cols)
        top, left = tile_h * row, tile_w * col
        canvas[top:top + tile_h, left:left + tile_w, :] = tile

    return canvas
68 |
def show_all_variables():
    """Print a summary of all trainable variables via slim's model analyzer."""
    slim.model_analyzer.analyze_vars(tf.trainable_variables(), print_info=True)
72 |
def check_folder(log_dir):
    """Create `log_dir` (and any missing parents) if absent and return the path.

    Args:
        log_dir: directory path to ensure.

    Returns:
        `log_dir`, unchanged.
    """
    if not os.path.exists(log_dir):
        # exist_ok guards against another process creating the directory
        # between the existence check and the makedirs call.
        os.makedirs(log_dir, exist_ok=True)
    return log_dir
77 |
def str2bool(x):
    """Return True only when `x` is the string 'true' (case-insensitive).

    The comparison is an exact match. A membership test against the bare
    string 'true' would be a substring test and would accept '', 't',
    'rue', and so on.
    """
    return x.lower() == 'true'
80 |
81 |
82 |
--------------------------------------------------------------------------------
/vgg19.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tensorflow as tf
3 |
4 | import numpy as np
5 | import time
6 | import inspect
7 |
8 | VGG_MEAN = [103.939, 116.779, 123.68]
9 |
10 |
class Vgg19:
    """VGG-19 forward graph whose weights come from an ``.npy`` file.

    The file is expected to hold a pickled dict mapping a layer name to a
    pair ``(weights, biases)``; weights are embedded in the graph as
    constants, so nothing here is trainable.
    """

    def __init__(self, vgg19_npy_path=None):
        """Load the weight dictionary.

        Args:
            vgg19_npy_path: path of the weight file. When None, "vgg19.npy"
                next to the file defining this class is used.
        """
        if vgg19_npy_path is None:
            path = inspect.getfile(Vgg19)
            path = os.path.abspath(os.path.join(path, os.pardir))
            path = os.path.join(path, "vgg19.npy")
            vgg19_npy_path = path
            print(vgg19_npy_path)

        # The file stores a pickled dict, which NumPy >= 1.16.3 refuses to
        # load unless allow_pickle=True is passed.
        # NOTE(review): unpickling executes arbitrary code -- only load weight
        # files from a trusted source.
        self.data_dict = np.load(vgg19_npy_path, encoding='latin1', allow_pickle=True).item()
        print("npy file loaded")

    def build(self, rgb, include_fc=False):
        """Build the VGG-19 graph on `rgb` and expose each layer as an attribute.

        Args:
            rgb: RGB image batch with shape [batch_size, h, w, 3] and values
                scaled to (-1, 1). It is rescaled to [0, 255], reordered to
                BGR and mean-centred with `VGG_MEAN` before the first layer.
            include_fc: also build fc6/fc7/fc8 and the softmax `prob`.
        """
        start_time = time.time()
        rgb_scaled = ((rgb + 1) / 2) * 255.0  # [-1, 1] ~ [0, 255]

        red, green, blue = tf.split(axis=3, num_or_size_splits=3, value=rgb_scaled)
        bgr = tf.concat(axis=3, values=[blue - VGG_MEAN[0],
                                        green - VGG_MEAN[1],
                                        red - VGG_MEAN[2]])

        self.conv1_1 = self.conv_layer(bgr, "conv1_1")
        self.conv1_2 = self.conv_layer(self.conv1_1, "conv1_2")
        self.pool1 = self.max_pool(self.conv1_2, 'pool1')

        self.conv2_1 = self.conv_layer(self.pool1, "conv2_1")
        self.conv2_2 = self.conv_layer(self.conv2_1, "conv2_2")
        self.pool2 = self.max_pool(self.conv2_2, 'pool2')

        self.conv3_1 = self.conv_layer(self.pool2, "conv3_1")
        self.conv3_2 = self.conv_layer(self.conv3_1, "conv3_2")
        self.conv3_3 = self.conv_layer(self.conv3_2, "conv3_3")
        self.conv3_4 = self.conv_layer(self.conv3_3, "conv3_4")
        self.pool3 = self.max_pool(self.conv3_4, 'pool3')

        self.conv4_1 = self.conv_layer(self.pool3, "conv4_1")
        self.conv4_2 = self.conv_layer(self.conv4_1, "conv4_2")
        self.conv4_3 = self.conv_layer(self.conv4_2, "conv4_3")

        # Pre-ReLU conv4_4 output, exposed separately from the activated one.
        self.conv4_4_no_activation = self.no_activation_conv_layer(self.conv4_3, "conv4_4")

        self.conv4_4 = self.conv_layer(self.conv4_3, "conv4_4")
        self.pool4 = self.max_pool(self.conv4_4, 'pool4')

        self.conv5_1 = self.conv_layer(self.pool4, "conv5_1")
        self.conv5_2 = self.conv_layer(self.conv5_1, "conv5_2")
        self.conv5_3 = self.conv_layer(self.conv5_2, "conv5_3")
        self.conv5_4 = self.conv_layer(self.conv5_3, "conv5_4")
        self.pool5 = self.max_pool(self.conv5_4, 'pool5')

        if include_fc:
            self.fc6 = self.fc_layer(self.pool5, "fc6")
            assert self.fc6.get_shape().as_list()[1:] == [4096]
            self.relu6 = tf.nn.relu(self.fc6)

            self.fc7 = self.fc_layer(self.relu6, "fc7")
            self.relu7 = tf.nn.relu(self.fc7)

            self.fc8 = self.fc_layer(self.relu7, "fc8")

            self.prob = tf.nn.softmax(self.fc8, name="prob")

            # Weights are released only after a full build; a feature-only
            # build keeps them so `build` can be called again on this instance.
            self.data_dict = None

        print(("Finished building vgg19: %ds" % (time.time() - start_time)))

    def avg_pool(self, bottom, name):
        """2x2 average pooling with stride 2 and 'SAME' padding."""
        return tf.nn.avg_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name)

    def max_pool(self, bottom, name):
        """2x2 max pooling with stride 2 and 'SAME' padding."""
        return tf.nn.max_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name)

    def _conv_bias(self, bottom, name):
        """Stride-1 'SAME' convolution plus bias using the stored weights of `name`."""
        filt = self.get_conv_filter(name)
        conv = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME')
        conv_biases = self.get_bias(name)
        return tf.nn.bias_add(conv, conv_biases)

    def conv_layer(self, bottom, name):
        """Convolution + bias + ReLU for layer `name`."""
        with tf.variable_scope(name):
            return tf.nn.relu(self._conv_bias(bottom, name))

    def no_activation_conv_layer(self, bottom, name):
        """Convolution + bias for layer `name`, without the ReLU."""
        with tf.variable_scope(name):
            return self._conv_bias(bottom, name)

    def fc_layer(self, bottom, name):
        """Fully connected layer `name` applied to `bottom` flattened per example."""
        with tf.variable_scope(name):
            shape = bottom.get_shape().as_list()
            dim = 1
            for d in shape[1:]:
                dim *= d
            x = tf.reshape(bottom, [-1, dim])

            weights = self.get_fc_weight(name)
            biases = self.get_bias(name)

            # bias_add broadcasts the bias vector over the batch axis.
            fc = tf.nn.bias_add(tf.matmul(x, weights), biases)

            return fc

    def get_conv_filter(self, name):
        """Return the convolution kernel of layer `name` as a constant."""
        return tf.constant(self.data_dict[name][0], name="filter")

    def get_bias(self, name):
        """Return the bias vector of layer `name` as a constant."""
        return tf.constant(self.data_dict[name][1], name="biases")

    def get_fc_weight(self, name):
        """Return the fully connected weight matrix of layer `name` as a constant."""
        return tf.constant(self.data_dict[name][0], name="weights")
--------------------------------------------------------------------------------