├── README.md
├── images
│   ├── city1.jpg
│   ├── city2.jpg
│   ├── folder.jpg
│   ├── food.jpg
│   ├── home.jpg
│   ├── method.jpg
│   ├── person1.jpg
│   ├── person2.jpg
│   └── use_cases.jpg
├── index.html
├── index_files
│   ├── pixl-bk.css
│   └── pixl-fonts.css
├── paper
│   ├── 06791-supp.pdf
│   ├── 06791.pdf
│   ├── shinjuku.jpg
│   └── white_box_cartoon_acm_style.pdf
├── test_code
│   ├── cartoonize.py
│   ├── guided_filter.py
│   ├── network.py
│   ├── saved_models
│   │   ├── checkpoint
│   │   ├── model-33999.data-00000-of-00001
│   │   └── model-33999.index
│   └── test_images
│       ├── actress2.jpg
│       ├── china6.jpg
│       ├── food16.jpg
│       ├── food6.jpg
│       ├── liuyifei4.jpg
│       ├── london1.jpg
│       ├── mountain4.jpg
│       ├── mountain5.jpg
│       ├── national_park1.jpg
│       ├── party5.jpg
│       └── party7.jpg
└── train_code
    ├── guided_filter.py
    ├── layers.py
    ├── loss.py
    ├── network.py
    ├── pretrain.py
    ├── selective_search
    │   ├── __init__.py
    │   ├── adaptive_color.py
    │   ├── batch_ss.py
    │   ├── core.py
    │   ├── measure.py
    │   ├── structure.py
    │   └── util.py
    ├── train.py
    └── utils.py
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | # [CVPR2020] Learning to Cartoonize Using White-box Cartoon Representations
6 | [project page](https://systemerrorwang.github.io/White-box-Cartoonization/) | [paper](https://github.com/SystemErrorWang/White-box-Cartoonization/blob/master/paper/06791.pdf) | [twitter](https://twitter.com/IlIIlIIIllIllII/status/1243108510423896065) | [zhihu](https://zhuanlan.zhihu.com/p/117422157) | [bilibili](https://www.bilibili.com/video/av56708333) | [facial model](https://github.com/SystemErrorWang/FacialCartoonization)
7 |
8 | - TensorFlow implementation of the CVPR 2020 paper “Learning to Cartoonize Using White-box Cartoon Representations”.
9 | - An improved method for facial images is now available:
10 | - https://github.com/SystemErrorWang/FacialCartoonization
11 |
12 |
13 |
14 |
15 | ## Use cases
16 |
17 | ### Scenery
18 |
19 |
20 |
21 | ### Food
22 |
23 |
24 | ### Indoor Scenes
25 |
26 |
27 | ### People
28 |
29 |
30 |
31 | ### More Images Are Shown In The Supplementary Materials
32 |
33 |
34 | ## Online demo
35 |
36 | - Some kind people have made an online demo for this project
37 | - Demo link: https://cartoonize-lkqov62dia-de.a.run.app/cartoonize
38 | - Code: https://github.com/experience-ml/cartoonize
39 | - Sample Demo: https://www.youtube.com/watch?v=GqduSLcmhto&feature=emb_title
40 |
41 | ## Prerequisites
42 |
43 | - Training code: Linux or Windows
44 | - NVIDIA GPU + CUDA + cuDNN for performance
45 | - Inference code: Linux, Windows, and macOS
46 |
47 |
48 | ## How To Use
49 |
50 | ### Installation
51 |
52 | - We assume you already have an NVIDIA GPU with CUDA and cuDNN installed
53 | - Install tensorflow-gpu; we tested 1.12.0 and 1.13.0rc0
54 | - Install scikit-image==0.14.5; other versions may cause problems (a quick environment check is sketched below)
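
A quick sanity check that the environment matches the versions listed above (a minimal sketch; the expected version strings are simply the ones from this section, and the GPU check uses the TF 1.x API):

```python
# Verify TensorFlow / scikit-image versions and GPU visibility (TF 1.x API).
import tensorflow as tf
import skimage

print(tf.__version__)              # expected: 1.12.0 or 1.13.0rc0
print(skimage.__version__)         # expected: 0.14.5
print(tf.test.is_gpu_available())  # True if CUDA/cuDNN are set up correctly
```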
55 |
56 |
57 | ### Inference with Pre-trained Model
58 |
59 | - Store test images in /test_code/test_images
60 | - Run /test_code/cartoonize.py
61 | - Results will be saved in /test_code/cartoonized_images (see the invocation sketch below)
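
If you prefer to drive inference from Python rather than running the script directly, the following minimal sketch mirrors the `__main__` block of `test_code/cartoonize.py` (run it from inside `test_code/`):

```python
# Programmatic inference with the shipped checkpoint (working directory: test_code/).
import os
from cartoonize import cartoonize

model_path = 'saved_models'          # pretrained checkpoint included in the repo
load_folder = 'test_images'          # input photos
save_folder = 'cartoonized_images'   # results are written here
if not os.path.exists(save_folder):
    os.mkdir(save_folder)
cartoonize(load_folder, save_folder, model_path)
```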
62 |
63 |
64 | ### Train
65 |
66 | - Place your training data in the corresponding folders in /dataset (a folder-check sketch follows this list)
67 | - Run pretrain.py; results will be saved in the /pretrain folder
68 | - Run train.py; results will be saved in the /train_cartoon folder
69 | - The code was cleaned up from a production environment and is untested
70 | - There may be minor problems, but they should be easy to resolve
71 | - The pretrained VGG_19 model can be found at the following URL:
72 | https://drive.google.com/file/d/1j0jDENjdwxCDb36meP6-u5xDBzmKBOjJ/view?usp=sharing
73 |
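Before launching training, you can verify the photo folders are in place. The two directory names below are the ones `pretrain.py` reads; any cartoon-image folders required by `train.py` are not shown in this repository listing, so treat those as something you add yourself:

```python
# Sanity check for the photo folders read by pretrain.py.
import os

for folder in ('dataset/photo_face', 'dataset/photo_scenery'):
    count = len(os.listdir(folder)) if os.path.isdir(folder) else 0
    print('{}: {} files'.format(folder, count))
```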
74 |
75 |
76 | ### Datasets
77 |
78 | - Due to copyright issues, we cannot provide the cartoon images used for training
79 | - However, these training datasets are easy to prepare
80 | - Scenery images are collected from films by Shinkai Makoto, Miyazaki Hayao, and Hosoda Mamoru
81 | - Clip the films into frames, then randomly crop and resize to 256x256 (see the sketch below)
82 | - Portrait images are from Kyoto Animation and P.A. Works
83 | - We use this repo (https://github.com/nagadomi/lbpcascade_animeface) to detect facial areas
84 | - Manual data cleaning will greatly increase the quality of both datasets
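
A minimal frame-extraction and cropping sketch for the scenery set, assuming OpenCV is available; the frame stride, output naming, and the `clip_and_crop` helper itself are illustrative and not part of this repository:

```python
# Illustrative helper: sample frames from a film, take a random square crop,
# and resize to 256x256 as described above.
import os
import cv2
import numpy as np

def clip_and_crop(video_path, out_dir, stride=30, size=256):
    os.makedirs(out_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    idx, saved = 0, 0
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        h, w = frame.shape[:2]
        if idx % stride == 0 and min(h, w) >= size:
            crop = np.random.randint(size, min(h, w) + 1)   # random square crop size
            y = np.random.randint(0, h - crop + 1)
            x = np.random.randint(0, w - crop + 1)
            patch = cv2.resize(frame[y:y + crop, x:x + crop], (size, size),
                               interpolation=cv2.INTER_AREA)
            cv2.imwrite(os.path.join(out_dir, '{:06d}.jpg'.format(saved)), patch)
            saved += 1
        idx += 1
    cap.release()
```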
85 |
86 | ## Acknowledgement
87 |
88 | We are grateful to Lvmin Zhang and Style2Paints Research for their help
89 |
90 | ## License
91 | - Copyright (C) Xinrui Wang. All rights reserved. Licensed under the CC BY-NC-SA 4.0
92 | - license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
93 | - Commercial use is prohibited; please retain this license if you clone this repo
94 |
95 | ## Citation
96 |
97 | If you use this code for your research, please cite our [paper](https://systemerrorwang.github.io/White-box-Cartoonization/):
98 |
99 | @InProceedings{Wang_2020_CVPR,
100 | author = {Wang, Xinrui and Yu, Jinze},
101 | title = {Learning to Cartoonize Using White-Box Cartoon Representations},
102 | booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
103 | month = {June},
104 | year = {2020}
105 | }
106 |
107 |
108 | # Chinese Community
109 |
110 | We have a group that is nominally for technical discussion but where we chat about everything except technology. If your first attempt to join fails, you can try again: 816096787.
111 |
--------------------------------------------------------------------------------
/images/city1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/images/city1.jpg
--------------------------------------------------------------------------------
/images/city2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/images/city2.jpg
--------------------------------------------------------------------------------
/images/folder.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/images/folder.jpg
--------------------------------------------------------------------------------
/images/food.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/images/food.jpg
--------------------------------------------------------------------------------
/images/home.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/images/home.jpg
--------------------------------------------------------------------------------
/images/method.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/images/method.jpg
--------------------------------------------------------------------------------
/images/person1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/images/person1.jpg
--------------------------------------------------------------------------------
/images/person2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/images/person2.jpg
--------------------------------------------------------------------------------
/images/use_cases.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/images/use_cases.jpg
--------------------------------------------------------------------------------
/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
22 |
29 |
30 |
31 |
Example of image cartoonization with our method: on the left is a frame from the animation "Garden of Words"; on the right is a real-world photo processed by our proposed method.
32 |
33 |
34 |
35 | This paper presents an approach for image cartoonization. By observing the cartoon painting behavior and consulting artists, we propose to separately identify three white-box representations from images: the surface representation that contains a smooth surface of cartoon images, the structure representation that refers to the sparse color-blocks and flatten global content in the celluloid style workflow, and the texture representation that reflects high-frequency texture, contours, and details in cartoon images. A Generative Adversarial Network (GAN) framework is used to learn the extracted representations and to cartoonize images.
36 |
37 | The learning objectives of our method are separately based on each extracted representation, making our framework controllable and adjustable. This enables our approach to meet artists’ requirements in different styles and diverse use cases. Qualitative comparisons and quantitative analyses, as well as user studies, have been conducted to validate the effectiveness of this approach, and our method outperforms previous methods in all comparisons. Finally, the ablation study demonstrates the influence of each component in our framework.
38 |
39 |
40 |
41 |
45 |
46 |
47 |
48 | - Source Code - Only inference code is available now; training code will be updated later.
49 |
50 | - Demo Video - Generated with an early version of our work, hosted on bilibili.com.
51 |
52 |
53 |
54 |
55 |
56 |
57 | Xinrui Wang and Jinze Yu
58 | "Learning to Cartoonize Using White-box Cartoon Representations."
59 | IEEE Conference on Computer Vision and Pattern Recognition, June 2020.
60 |
61 |
73 |
74 |
75 |
76 |
--------------------------------------------------------------------------------
/index_files/pixl-bk.css:
--------------------------------------------------------------------------------
1 | /* Un-break mobile pages. */
2 | * {
3 | text-size-adjust: none;
4 | -webkit-text-size-adjust: none;
5 | }
6 |
7 | /* Page background and default foreground colors */
8 | body {
9 | background: #111;
10 | color: #ddd;
11 | padding-bottom: 5vw;
12 | }
13 | @media screen and (max-width: 533px) {
14 | body {
15 | font-size: 3vw;
16 | }
17 | }
18 |
19 | /* Default link colors */
20 | a {
21 | color: #f38025;
22 | text-decoration: none;
23 | }
24 | .unlinked {
25 | color: #f38025;
26 | }
27 |
28 | /* "Crumbtrail" bar at top */
29 | .crumb {
30 | position: fixed;
31 | top: 0px;
32 | left: 0px;
33 | right: 0px;
34 | height: 22px;
35 | background: #f38025;
36 | border-bottom: 5px solid #000;
37 | font-size: 14px;
38 | font-weight: bold;
39 | color: #000;
40 | padding: 2px 10px;
41 | }
42 | @media screen and (max-width: 800px) {
43 | .crumb {
44 | height: 2.75vw;
45 | border-bottom: 0.625vw solid #000;
46 | font-size: 1.75vw;
47 | padding: 0.025vw 1.25vw;
48 | }
49 | }
50 | .crumb a {
51 | color: #000;
52 | }
53 |
54 | /* Right-aligned span */
55 | .right {
56 | float: right;
57 | }
58 |
59 | /* Page main title */
60 | .title {
61 | text-align: center;
62 | font-size: 56px;
63 | padding-top: 80px;
64 | padding-bottom: 56px;
65 | color: #f38025;
66 | }
67 | .title img {
68 | vertical-align: middle;
69 | height: 75px;
70 | padding-right: 10px;
71 | }
72 | @media screen and (max-width: 800px) {
73 | .title {
74 | font-size: 7vw;
75 | padding-top: 10vw;
76 | padding-bottom: 7vw;
77 | }
78 | .title img {
79 | height: 9.375vw;
80 | padding-right: 1.25vw;
81 | }
82 | }
83 | .logoX {
84 | color: #fff;
85 | }
86 |
87 | /* Centered tables or divs */
88 | .content, .linkgallery, .projlist, .abstract {
89 | width: 720px;
90 | margin: 0px auto;
91 | }
92 | .linkgallery, .projlist, .centered {
93 | text-align: center;
94 | }
95 | .linkgallery a img {
96 | width: 188px;
97 | filter: grayscale(20%) brightness(150%);
98 | -webkit-filter: grayscale(20%) brightness(150%);
99 | }
100 | .link {
101 | color: #f38025;
102 | font-size: 36px;
103 | }
104 | .projlist td {
105 | padding: 24px;
106 | }
107 | .projlist img {
108 | width: 240px;
109 | }
110 | @media screen and (max-width: 800px) {
111 | .content, .linkgallery, .projlist, .abstract {
112 | width: 90vw;
113 | }
114 | .linkgallery a img {
115 | width: 23.5vw;
116 | }
117 | .link {
118 | font-size: 4.5vw;
119 | }
120 | .projlist td {
121 | padding: 3vw;
122 | }
123 | .projlist img {
124 | width: 30vw;
125 | }
126 | }
127 |
128 | /* Paper listings */
129 | .paperheader {
130 | text-align: center;
131 | padding-top: 80px;
132 | padding-bottom: 56px;
133 | }
134 | @media screen and (max-width: 800px) {
135 | .paperheader {
136 | padding-top: 10vw;
137 | padding-bottom: 7vw;
138 | }
139 | }
140 | .papertitle {
141 | font-size: 150%;
142 | font-weight: bold;
143 | color: #f38025;
144 | }
145 | .pubinfo {
146 | }
147 | .authors {
148 | font-size: 125%;
149 | }
150 | .paperimg {
151 | text-align: center;
152 | }
153 | .paperimg img {
154 | max-width: 100%;
155 | }
156 | .longcaption {
157 | margin: 24px;
158 | font-style: italic;
159 | text-align: justify;
160 | }
161 | .shortcaption {
162 | margin: 24px;
163 | font-style: italic;
164 | text-align: center;
165 | }
166 | @media screen and (max-width: 800px) {
167 | .longcaption {
168 | margin: 3vw;
169 | }
170 | .shortcaption {
171 | margin: 3vw;
172 | }
173 | }
174 | .abstract {
175 | text-align: justify;
176 | }
177 |
178 | /* Links with fancy fading hover effect */
179 | a, .linkgallery a img {
180 | transition: .1s ease-in-out;
181 | -webkit-transition: .1s ease-in-out;
182 | }
183 | a:hover, a:active, .crumb a:hover, .crumb a:active,
184 | .linkgallery a:hover img, .linkgallery a:active img {
185 | color: #fff;
186 | filter: grayscale(100%) brightness(125%);
187 | -webkit-filter: grayscale(100%) brightness(125%);
188 | }
189 |
190 | /* Larger text */
191 | .larger {
192 | font-size: 125%;
193 | }
194 |
195 | /* Small header */
196 | .header {
197 | font-size: 125%;
198 | font-weight: bold;
199 | margin-top: 1em;
200 | }
201 |
202 | /* Text inputs */
203 | input[type=text], select, textarea {
204 | width: 100%;
205 | }
206 |
--------------------------------------------------------------------------------
/index_files/pixl-fonts.css:
--------------------------------------------------------------------------------
1 | /* Use Source Sans Pro fonts */
2 | @import url('https://fonts.googleapis.com/css?family=Source+Sans+Pro:400,600');
3 |
4 | * {
5 | font-family: 'Source Sans Pro', sans-serif;
6 | }
7 |
8 |
--------------------------------------------------------------------------------
/paper/06791-supp.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/paper/06791-supp.pdf
--------------------------------------------------------------------------------
/paper/06791.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/paper/06791.pdf
--------------------------------------------------------------------------------
/paper/shinjuku.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/paper/shinjuku.jpg
--------------------------------------------------------------------------------
/paper/white_box_cartoon_acm_style.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/paper/white_box_cartoon_acm_style.pdf
--------------------------------------------------------------------------------
/test_code/cartoonize.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | import tensorflow as tf
5 | import network
6 | import guided_filter
7 | from tqdm import tqdm
8 |
9 |
10 |
11 | def resize_crop(image):
12 | h, w, c = np.shape(image)
13 |     if min(h, w) > 720:  # downscale so the shorter side is at most 720
14 | if h > w:
15 | h, w = int(720*h/w), 720
16 | else:
17 | h, w = 720, int(720*w/h)
18 | image = cv2.resize(image, (w, h),
19 | interpolation=cv2.INTER_AREA)
20 |     h, w = (h//8)*8, (w//8)*8  # crop height and width down to multiples of 8 for the network
21 | image = image[:h, :w, :]
22 | return image
23 |
24 |
25 | def cartoonize(load_folder, save_folder, model_path):
26 | input_photo = tf.placeholder(tf.float32, [1, None, None, 3])
27 | network_out = network.unet_generator(input_photo)
28 | final_out = guided_filter.guided_filter(input_photo, network_out, r=1, eps=5e-3)
29 |
30 | all_vars = tf.trainable_variables()
31 | gene_vars = [var for var in all_vars if 'generator' in var.name]
32 | saver = tf.train.Saver(var_list=gene_vars)
33 |
34 | config = tf.ConfigProto()
35 | config.gpu_options.allow_growth = True
36 | sess = tf.Session(config=config)
37 |
38 | sess.run(tf.global_variables_initializer())
39 | saver.restore(sess, tf.train.latest_checkpoint(model_path))
40 | name_list = os.listdir(load_folder)
41 | for name in tqdm(name_list):
42 | try:
43 | load_path = os.path.join(load_folder, name)
44 | save_path = os.path.join(save_folder, name)
45 | image = cv2.imread(load_path)
46 | image = resize_crop(image)
47 | batch_image = image.astype(np.float32)/127.5 - 1
48 | batch_image = np.expand_dims(batch_image, axis=0)
49 | output = sess.run(final_out, feed_dict={input_photo: batch_image})
50 | output = (np.squeeze(output)+1)*127.5
51 | output = np.clip(output, 0, 255).astype(np.uint8)
52 | cv2.imwrite(save_path, output)
53 |         except Exception:
54 | print('cartoonize {} failed'.format(load_path))
55 |
56 |
57 |
58 |
59 | if __name__ == '__main__':
60 | model_path = 'saved_models'
61 | load_folder = 'test_images'
62 | save_folder = 'cartoonized_images'
63 | if not os.path.exists(save_folder):
64 | os.mkdir(save_folder)
65 | cartoonize(load_folder, save_folder, model_path)
66 |
67 |
68 |
--------------------------------------------------------------------------------
/test_code/guided_filter.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 |
5 |
6 |
7 | def tf_box_filter(x, r):
8 | k_size = int(2*r+1)
9 | ch = x.get_shape().as_list()[-1]
10 | weight = 1/(k_size**2)
11 | box_kernel = weight*np.ones((k_size, k_size, ch, 1))
12 | box_kernel = np.array(box_kernel).astype(np.float32)
13 | output = tf.nn.depthwise_conv2d(x, box_kernel, [1, 1, 1, 1], 'SAME')
14 | return output
15 |
16 |
17 |
18 | def guided_filter(x, y, r, eps=1e-2):
19 |
20 | x_shape = tf.shape(x)
21 | #y_shape = tf.shape(y)
22 |
23 |     N = tf_box_filter(tf.ones((1, x_shape[1], x_shape[2], 1), dtype=x.dtype), r)  # border-aware normalizer
24 |
25 | mean_x = tf_box_filter(x, r) / N
26 | mean_y = tf_box_filter(y, r) / N
27 | cov_xy = tf_box_filter(x * y, r) / N - mean_x * mean_y
28 | var_x = tf_box_filter(x * x, r) / N - mean_x * mean_x
29 |
30 |     A = cov_xy / (var_x + eps)  # coefficients of the local linear model: y ≈ A*x + b
31 | b = mean_y - A * mean_x
32 |
33 | mean_A = tf_box_filter(A, r) / N
34 | mean_b = tf_box_filter(b, r) / N
35 |
36 | output = mean_A * x + mean_b
37 |
38 | return output
39 |
40 |
41 |
42 | def fast_guided_filter(lr_x, lr_y, hr_x, r=1, eps=1e-8):
43 |
44 | #assert lr_x.shape.ndims == 4 and lr_y.shape.ndims == 4 and hr_x.shape.ndims == 4
45 |
46 | lr_x_shape = tf.shape(lr_x)
47 | #lr_y_shape = tf.shape(lr_y)
48 | hr_x_shape = tf.shape(hr_x)
49 |
50 | N = tf_box_filter(tf.ones((1, lr_x_shape[1], lr_x_shape[2], 1), dtype=lr_x.dtype), r)
51 |
52 | mean_x = tf_box_filter(lr_x, r) / N
53 | mean_y = tf_box_filter(lr_y, r) / N
54 | cov_xy = tf_box_filter(lr_x * lr_y, r) / N - mean_x * mean_y
55 | var_x = tf_box_filter(lr_x * lr_x, r) / N - mean_x * mean_x
56 |
57 | A = cov_xy / (var_x + eps)
58 | b = mean_y - A * mean_x
59 |
60 | mean_A = tf.image.resize_images(A, hr_x_shape[1: 3])
61 | mean_b = tf.image.resize_images(b, hr_x_shape[1: 3])
62 |
63 | output = mean_A * hr_x + mean_b
64 |
65 | return output
66 |
67 |
68 | if __name__ == '__main__':
69 | import cv2
70 | from tqdm import tqdm
71 |
72 | input_photo = tf.placeholder(tf.float32, [1, None, None, 3])
73 | #input_superpixel = tf.placeholder(tf.float32, [16, 256, 256, 3])
74 | output = guided_filter(input_photo, input_photo, 5, eps=1)
75 | image = cv2.imread('output_figure1/cartoon2.jpg')
76 | image = image/127.5 - 1
77 | image = np.expand_dims(image, axis=0)
78 |
79 | config = tf.ConfigProto()
80 | config.gpu_options.allow_growth = True
81 | sess = tf.Session(config=config)
82 | sess.run(tf.global_variables_initializer())
83 |
84 | out = sess.run(output, feed_dict={input_photo: image})
85 | out = (np.squeeze(out)+1)*127.5
86 | out = np.clip(out, 0, 255).astype(np.uint8)
87 | cv2.imwrite('output_figure1/cartoon2_filter.jpg', out)
88 |
--------------------------------------------------------------------------------
/test_code/network.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import tensorflow.contrib.slim as slim
4 |
5 |
6 |
7 | def resblock(inputs, out_channel=32, name='resblock'):
8 |
9 | with tf.variable_scope(name):
10 |
11 | x = slim.convolution2d(inputs, out_channel, [3, 3],
12 | activation_fn=None, scope='conv1')
13 | x = tf.nn.leaky_relu(x)
14 | x = slim.convolution2d(x, out_channel, [3, 3],
15 | activation_fn=None, scope='conv2')
16 |
17 | return x + inputs
18 |
19 |
20 |
21 |
22 | def unet_generator(inputs, channel=32, num_blocks=4, name='generator', reuse=False):
23 | with tf.variable_scope(name, reuse=reuse):
24 |
25 | x0 = slim.convolution2d(inputs, channel, [7, 7], activation_fn=None)
26 | x0 = tf.nn.leaky_relu(x0)
27 |
28 | x1 = slim.convolution2d(x0, channel, [3, 3], stride=2, activation_fn=None)
29 | x1 = tf.nn.leaky_relu(x1)
30 | x1 = slim.convolution2d(x1, channel*2, [3, 3], activation_fn=None)
31 | x1 = tf.nn.leaky_relu(x1)
32 |
33 | x2 = slim.convolution2d(x1, channel*2, [3, 3], stride=2, activation_fn=None)
34 | x2 = tf.nn.leaky_relu(x2)
35 | x2 = slim.convolution2d(x2, channel*4, [3, 3], activation_fn=None)
36 | x2 = tf.nn.leaky_relu(x2)
37 |
38 | for idx in range(num_blocks):
39 | x2 = resblock(x2, out_channel=channel*4, name='block_{}'.format(idx))
40 |
41 | x2 = slim.convolution2d(x2, channel*2, [3, 3], activation_fn=None)
42 | x2 = tf.nn.leaky_relu(x2)
43 |
44 | h1, w1 = tf.shape(x2)[1], tf.shape(x2)[2]
45 | x3 = tf.image.resize_bilinear(x2, (h1*2, w1*2))
46 | x3 = slim.convolution2d(x3+x1, channel*2, [3, 3], activation_fn=None)
47 | x3 = tf.nn.leaky_relu(x3)
48 | x3 = slim.convolution2d(x3, channel, [3, 3], activation_fn=None)
49 | x3 = tf.nn.leaky_relu(x3)
50 |
51 | h2, w2 = tf.shape(x3)[1], tf.shape(x3)[2]
52 | x4 = tf.image.resize_bilinear(x3, (h2*2, w2*2))
53 | x4 = slim.convolution2d(x4+x0, channel, [3, 3], activation_fn=None)
54 | x4 = tf.nn.leaky_relu(x4)
55 | x4 = slim.convolution2d(x4, 3, [7, 7], activation_fn=None)
56 |
57 | return x4
58 |
59 | if __name__ == '__main__':
60 |
61 |
62 | pass
--------------------------------------------------------------------------------
/test_code/saved_models/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "model-33999"
2 | all_model_checkpoint_paths: "model-33999"
3 | all_model_checkpoint_paths: "model-37499"
4 |
--------------------------------------------------------------------------------
/test_code/saved_models/model-33999.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/test_code/saved_models/model-33999.data-00000-of-00001
--------------------------------------------------------------------------------
/test_code/saved_models/model-33999.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/test_code/saved_models/model-33999.index
--------------------------------------------------------------------------------
/test_code/test_images/actress2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/test_code/test_images/actress2.jpg
--------------------------------------------------------------------------------
/test_code/test_images/china6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/test_code/test_images/china6.jpg
--------------------------------------------------------------------------------
/test_code/test_images/food16.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/test_code/test_images/food16.jpg
--------------------------------------------------------------------------------
/test_code/test_images/food6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/test_code/test_images/food6.jpg
--------------------------------------------------------------------------------
/test_code/test_images/liuyifei4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/test_code/test_images/liuyifei4.jpg
--------------------------------------------------------------------------------
/test_code/test_images/london1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/test_code/test_images/london1.jpg
--------------------------------------------------------------------------------
/test_code/test_images/mountain4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/test_code/test_images/mountain4.jpg
--------------------------------------------------------------------------------
/test_code/test_images/mountain5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/test_code/test_images/mountain5.jpg
--------------------------------------------------------------------------------
/test_code/test_images/national_park1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/test_code/test_images/national_park1.jpg
--------------------------------------------------------------------------------
/test_code/test_images/party5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/test_code/test_images/party5.jpg
--------------------------------------------------------------------------------
/test_code/test_images/party7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/test_code/test_images/party7.jpg
--------------------------------------------------------------------------------
/train_code/guided_filter.py:
--------------------------------------------------------------------------------
1 | '''
2 | CVPR 2020 submission, Paper ID 6791
3 | Source code for 'Learning to Cartoonize Using White-Box Cartoon Representations'
4 | '''
5 |
6 |
7 | import tensorflow as tf
8 | import numpy as np
9 |
10 |
11 | def tf_box_filter(x, r):
12 | ch = x.get_shape().as_list()[-1]
13 | weight = 1/((2*r+1)**2)
14 | box_kernel = weight*np.ones((2*r+1, 2*r+1, ch, 1))
15 | box_kernel = np.array(box_kernel).astype(np.float32)
16 | output = tf.nn.depthwise_conv2d(x, box_kernel, [1, 1, 1, 1], 'SAME')
17 | return output
18 |
19 |
20 |
21 | def guided_filter(x, y, r, eps=1e-2):
22 |
23 | x_shape = tf.shape(x)
24 | #y_shape = tf.shape(y)
25 |
26 | N = tf_box_filter(tf.ones((1, x_shape[1], x_shape[2], 1), dtype=x.dtype), r)
27 |
28 | mean_x = tf_box_filter(x, r) / N
29 | mean_y = tf_box_filter(y, r) / N
30 | cov_xy = tf_box_filter(x * y, r) / N - mean_x * mean_y
31 | var_x = tf_box_filter(x * x, r) / N - mean_x * mean_x
32 |
33 | A = cov_xy / (var_x + eps)
34 | b = mean_y - A * mean_x
35 |
36 | mean_A = tf_box_filter(A, r) / N
37 | mean_b = tf_box_filter(b, r) / N
38 |
39 | output = mean_A * x + mean_b
40 |
41 | return output
42 |
43 |
44 | if __name__ == '__main__':
45 | pass
46 |
--------------------------------------------------------------------------------
/train_code/layers.py:
--------------------------------------------------------------------------------
1 | '''
2 | CVPR 2020 submission, Paper ID 6791
3 | Source code for 'Learning to Cartoonize Using White-Box Cartoon Representations'
4 | '''
5 |
6 |
7 | import tensorflow as tf
8 | import numpy as np
9 | import tensorflow.contrib.slim as slim
10 |
11 |
12 |
13 | def adaptive_instance_norm(content, style, epsilon=1e-5):
14 |
15 | c_mean, c_var = tf.nn.moments(content, axes=[1, 2], keep_dims=True)
16 | s_mean, s_var = tf.nn.moments(style, axes=[1, 2], keep_dims=True)
17 | c_std, s_std = tf.sqrt(c_var + epsilon), tf.sqrt(s_var + epsilon)
18 |
19 | return s_std * (content - c_mean) / c_std + s_mean
20 |
21 |
22 |
23 | def spectral_norm(w, iteration=1):
24 | w_shape = w.shape.as_list()
25 | w = tf.reshape(w, [-1, w_shape[-1]])
26 |
27 | u = tf.get_variable("u", [1, w_shape[-1]],
28 | initializer=tf.random_normal_initializer(), trainable=False)
29 |
30 | u_hat = u
31 | v_hat = None
32 | for i in range(iteration):
33 | """
34 | power iteration
35 | Usually iteration = 1 will be enough
36 | """
37 | v_ = tf.matmul(u_hat, tf.transpose(w))
38 | v_hat = tf.nn.l2_normalize(v_)
39 |
40 | u_ = tf.matmul(v_hat, w)
41 | u_hat = tf.nn.l2_normalize(u_)
42 |
43 | u_hat = tf.stop_gradient(u_hat)
44 | v_hat = tf.stop_gradient(v_hat)
45 |
46 | sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))
47 |
48 | with tf.control_dependencies([u.assign(u_hat)]):
49 | w_norm = w / sigma
50 | w_norm = tf.reshape(w_norm, w_shape)
51 |
52 | return w_norm
53 |
54 |
55 | def conv_spectral_norm(x, channel, k_size, stride=1, name='conv_snorm'):
56 | with tf.variable_scope(name):
57 | w = tf.get_variable("kernel", shape=[k_size[0], k_size[1], x.get_shape()[-1], channel])
58 | b = tf.get_variable("bias", [channel], initializer=tf.constant_initializer(0.0))
59 |
60 | x = tf.nn.conv2d(input=x, filter=spectral_norm(w), strides=[1, stride, stride, 1], padding='SAME') + b
61 |
62 | return x
63 |
64 |
65 |
66 | def self_attention(inputs, name='attention', reuse=False):
67 | with tf.variable_scope(name, reuse=reuse):
68 | h, w = tf.shape(inputs)[1], tf.shape(inputs)[2]
69 | bs, _, _, ch = inputs.get_shape().as_list()
70 | f = slim.convolution2d(inputs, ch//8, [1, 1], activation_fn=None)
71 | g = slim.convolution2d(inputs, ch//8, [1, 1], activation_fn=None)
72 | s = slim.convolution2d(inputs, 1, [1, 1], activation_fn=None)
73 | f_flatten = tf.reshape(f, shape=[f.shape[0], -1, f.shape[-1]])
74 | g_flatten = tf.reshape(g, shape=[g.shape[0], -1, g.shape[-1]])
75 | beta = tf.matmul(f_flatten, g_flatten, transpose_b=True)
76 | beta = tf.nn.softmax(beta)
77 |
78 | s_flatten = tf.reshape(s, shape=[s.shape[0], -1, s.shape[-1]])
79 | att_map = tf.matmul(beta, s_flatten)
80 | att_map = tf.reshape(att_map, shape=[bs, h, w, 1])
81 | gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))
82 | output = att_map * gamma + inputs
83 |
84 | return att_map, output
85 |
86 |
87 |
88 | if __name__ == '__main__':
89 | pass
90 |
91 |
92 |
93 |
--------------------------------------------------------------------------------
/train_code/loss.py:
--------------------------------------------------------------------------------
1 | '''
2 | CVPR 2020 submission, Paper ID 6791
3 | Source code for 'Learning to Cartoonize Using White-Box Cartoon Representations'
4 | '''
5 |
6 |
7 | import numpy as np
8 | import scipy.stats as st
9 | import tensorflow as tf
10 |
11 |
12 |
13 | VGG_MEAN = [103.939, 116.779, 123.68]
14 |
15 |
16 | class Vgg19:
17 |
18 | def __init__(self, vgg19_npy_path=None):
19 |
20 | self.data_dict = np.load(vgg19_npy_path, encoding='latin1', allow_pickle=True).item()
21 | print('Finished loading vgg19.npy')
22 |
23 |
24 | def build_conv4_4(self, rgb, include_fc=False):
25 |
26 | rgb_scaled = (rgb+1) * 127.5
27 |
28 | blue, green, red = tf.split(axis=3, num_or_size_splits=3, value=rgb_scaled)
29 | bgr = tf.concat(axis=3, values=[blue - VGG_MEAN[0],
30 | green - VGG_MEAN[1], red - VGG_MEAN[2]])
31 |
32 | self.conv1_1 = self.conv_layer(bgr, "conv1_1")
33 | self.relu1_1 = tf.nn.relu(self.conv1_1)
34 | self.conv1_2 = self.conv_layer(self.relu1_1, "conv1_2")
35 | self.relu1_2 = tf.nn.relu(self.conv1_2)
36 | self.pool1 = self.max_pool(self.relu1_2, 'pool1')
37 |
38 | self.conv2_1 = self.conv_layer(self.pool1, "conv2_1")
39 | self.relu2_1 = tf.nn.relu(self.conv2_1)
40 | self.conv2_2 = self.conv_layer(self.relu2_1, "conv2_2")
41 | self.relu2_2 = tf.nn.relu(self.conv2_2)
42 | self.pool2 = self.max_pool(self.relu2_2, 'pool2')
43 |
44 | self.conv3_1 = self.conv_layer(self.pool2, "conv3_1")
45 | self.relu3_1 = tf.nn.relu(self.conv3_1)
46 | self.conv3_2 = self.conv_layer(self.relu3_1, "conv3_2")
47 | self.relu3_2 = tf.nn.relu(self.conv3_2)
48 | self.conv3_3 = self.conv_layer(self.relu3_2, "conv3_3")
49 | self.relu3_3 = tf.nn.relu(self.conv3_3)
50 | self.conv3_4 = self.conv_layer(self.relu3_3, "conv3_4")
51 | self.relu3_4 = tf.nn.relu(self.conv3_4)
52 | self.pool3 = self.max_pool(self.relu3_4, 'pool3')
53 |
54 | self.conv4_1 = self.conv_layer(self.pool3, "conv4_1")
55 | self.relu4_1 = tf.nn.relu(self.conv4_1)
56 | self.conv4_2 = self.conv_layer(self.relu4_1, "conv4_2")
57 | self.relu4_2 = tf.nn.relu(self.conv4_2)
58 | self.conv4_3 = self.conv_layer(self.relu4_2, "conv4_3")
59 | self.relu4_3 = tf.nn.relu(self.conv4_3)
60 | self.conv4_4 = self.conv_layer(self.relu4_3, "conv4_4")
61 | self.relu4_4 = tf.nn.relu(self.conv4_4)
62 | self.pool4 = self.max_pool(self.relu4_4, 'pool4')
63 |
64 | return self.conv4_4
65 |
66 | def max_pool(self, bottom, name):
67 | return tf.nn.max_pool(bottom, ksize=[1, 2, 2, 1],
68 | strides=[1, 2, 2, 1], padding='SAME', name=name)
69 |
70 | def conv_layer(self, bottom, name):
71 | with tf.variable_scope(name):
72 | filt = self.get_conv_filter(name)
73 |
74 | conv = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME')
75 |
76 | conv_biases = self.get_bias(name)
77 | bias = tf.nn.bias_add(conv, conv_biases)
78 |
79 | #relu = tf.nn.relu(bias)
80 | return bias
81 |
82 |
83 |
84 | def fc_layer(self, bottom, name):
85 | with tf.variable_scope(name):
86 | shape = bottom.get_shape().as_list()
87 | dim = 1
88 | for d in shape[1:]:
89 | dim *= d
90 | x = tf.reshape(bottom, [-1, dim])
91 |
92 | weights = self.get_fc_weight(name)
93 | biases = self.get_bias(name)
94 |
95 | # Fully connected layer. Note that the '+' operation automatically
96 | # broadcasts the biases.
97 | fc = tf.nn.bias_add(tf.matmul(x, weights), biases)
98 |
99 | return fc
100 |
101 | def get_conv_filter(self, name):
102 | return tf.constant(self.data_dict[name][0], name="filter")
103 |
104 | def get_bias(self, name):
105 | return tf.constant(self.data_dict[name][1], name="biases")
106 |
107 | def get_fc_weight(self, name):
108 | return tf.constant(self.data_dict[name][0], name="weights")
109 |
110 |
111 |
112 | def vggloss_4_4(image_a, image_b):
113 | vgg_model = Vgg19('vgg19_no_fc.npy')
114 | vgg_a = vgg_model.build_conv4_4(image_a)
115 | vgg_b = vgg_model.build_conv4_4(image_b)
116 | VGG_loss = tf.losses.absolute_difference(vgg_a, vgg_b)
117 | #VGG_loss = tf.nn.l2_loss(vgg_a - vgg_b)
118 | h, w, c= vgg_a.get_shape().as_list()[1:]
119 | VGG_loss = tf.reduce_mean(VGG_loss)/(h*w*c)
120 | return VGG_loss
121 |
122 |
123 |
124 | def wgan_loss(discriminator, real, fake, patch=True,
125 | channel=32, name='discriminator', lambda_=2):
126 | real_logits = discriminator(real, patch=patch, channel=channel, name=name, reuse=False)
127 | fake_logits = discriminator(fake, patch=patch, channel=channel, name=name, reuse=True)
128 |
129 | d_loss_real = - tf.reduce_mean(real_logits)
130 | d_loss_fake = tf.reduce_mean(fake_logits)
131 |
132 | d_loss = d_loss_real + d_loss_fake
133 | g_loss = - d_loss_fake
134 |
135 | """ Gradient Penalty """
136 | # This is borrowed from https://github.com/kodalinaveen3/DRAGAN/blob/master/DRAGAN.ipynb
137 | alpha = tf.random_uniform([tf.shape(real)[0], 1, 1, 1], minval=0.,maxval=1.)
138 | differences = fake - real # This is different from MAGAN
139 | interpolates = real + (alpha * differences)
140 | inter_logit = discriminator(interpolates, channel=channel, name=name, reuse=True)
141 | gradients = tf.gradients(inter_logit, [interpolates])[0]
142 | slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
143 | gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2)
144 | d_loss += lambda_ * gradient_penalty
145 |
146 | return d_loss, g_loss
147 |
148 |
149 | def gan_loss(discriminator, real, fake, scale=1,channel=32, patch=False, name='discriminator'):
150 |
151 | real_logit = discriminator(real, scale, channel, name=name, patch=patch, reuse=False)
152 | fake_logit = discriminator(fake, scale, channel, name=name, patch=patch, reuse=True)
153 |
154 | real_logit = tf.nn.sigmoid(real_logit)
155 | fake_logit = tf.nn.sigmoid(fake_logit)
156 |
157 | g_loss_blur = -tf.reduce_mean(tf.log(fake_logit))
158 | d_loss_blur = -tf.reduce_mean(tf.log(real_logit) + tf.log(1. - fake_logit))
159 |
160 | return d_loss_blur, g_loss_blur
161 |
162 |
163 |
164 | def lsgan_loss(discriminator, real, fake, scale=1,
165 | channel=32, patch=False, name='discriminator'):
166 |
167 | real_logit = discriminator(real, scale, channel, name=name, patch=patch, reuse=False)
168 | fake_logit = discriminator(fake, scale, channel, name=name, patch=patch, reuse=True)
169 |
170 | g_loss = tf.reduce_mean((fake_logit - 1)**2)
171 | d_loss = 0.5*(tf.reduce_mean((real_logit - 1)**2) + tf.reduce_mean(fake_logit**2))
172 |
173 | return d_loss, g_loss
174 |
175 |
176 |
177 | def total_variation_loss(image, k_size=1):
178 | h, w = image.get_shape().as_list()[1:3]
179 | tv_h = tf.reduce_mean((image[:, k_size:, :, :] - image[:, :h - k_size, :, :])**2)
180 | tv_w = tf.reduce_mean((image[:, :, k_size:, :] - image[:, :, :w - k_size, :])**2)
181 | tv_loss = (tv_h + tv_w)/(3*h*w)
182 | return tv_loss
183 |
184 |
185 |
186 |
187 | if __name__ == '__main__':
188 | pass
189 |
190 |
191 |
--------------------------------------------------------------------------------
/train_code/network.py:
--------------------------------------------------------------------------------
1 | '''
2 | CVPR 2020 submission, Paper ID 6791
3 | Source code for 'Learning to Cartoonize Using White-Box Cartoon Representations'
4 | '''
5 |
6 |
7 | import layers
8 | import tensorflow as tf
9 | import numpy as np
10 | import tensorflow.contrib.slim as slim
11 |
12 | from tqdm import tqdm
13 |
14 |
15 |
16 | def resblock(inputs, out_channel=32, name='resblock'):
17 |
18 | with tf.variable_scope(name):
19 |
20 | x = slim.convolution2d(inputs, out_channel, [3, 3],
21 | activation_fn=None, scope='conv1')
22 | x = tf.nn.leaky_relu(x)
23 | x = slim.convolution2d(x, out_channel, [3, 3],
24 | activation_fn=None, scope='conv2')
25 |
26 | return x + inputs
27 |
28 |
29 |
30 | def generator(inputs, channel=32, num_blocks=4, name='generator', reuse=False):
31 | with tf.variable_scope(name, reuse=reuse):
32 |
33 | x = slim.convolution2d(inputs, channel, [7, 7], activation_fn=None)
34 | x = tf.nn.leaky_relu(x)
35 |
36 | x = slim.convolution2d(x, channel*2, [3, 3], stride=2, activation_fn=None)
37 | x = slim.convolution2d(x, channel*2, [3, 3], activation_fn=None)
38 | x = tf.nn.leaky_relu(x)
39 |
40 | x = slim.convolution2d(x, channel*4, [3, 3], stride=2, activation_fn=None)
41 | x = slim.convolution2d(x, channel*4, [3, 3], activation_fn=None)
42 | x = tf.nn.leaky_relu(x)
43 |
44 | for idx in range(num_blocks):
45 | x = resblock(x, out_channel=channel*4, name='block_{}'.format(idx))
46 |
47 | x = slim.conv2d_transpose(x, channel*2, [3, 3], stride=2, activation_fn=None)
48 | x = slim.convolution2d(x, channel*2, [3, 3], activation_fn=None)
49 |
50 | x = tf.nn.leaky_relu(x)
51 |
52 | x = slim.conv2d_transpose(x, channel, [3, 3], stride=2, activation_fn=None)
53 | x = slim.convolution2d(x, channel, [3, 3], activation_fn=None)
54 | x = tf.nn.leaky_relu(x)
55 |
56 | x = slim.convolution2d(x, 3, [7, 7], activation_fn=None)
57 | #x = tf.clip_by_value(x, -0.999999, 0.999999)
58 |
59 | return x
60 |
61 |
62 | def unet_generator(inputs, channel=32, num_blocks=4, name='generator', reuse=False):
63 | with tf.variable_scope(name, reuse=reuse):
64 |
65 | x0 = slim.convolution2d(inputs, channel, [7, 7], activation_fn=None)
66 | x0 = tf.nn.leaky_relu(x0)
67 |
68 | x1 = slim.convolution2d(x0, channel, [3, 3], stride=2, activation_fn=None)
69 | x1 = tf.nn.leaky_relu(x1)
70 | x1 = slim.convolution2d(x1, channel*2, [3, 3], activation_fn=None)
71 | x1 = tf.nn.leaky_relu(x1)
72 |
73 | x2 = slim.convolution2d(x1, channel*2, [3, 3], stride=2, activation_fn=None)
74 | x2 = tf.nn.leaky_relu(x2)
75 | x2 = slim.convolution2d(x2, channel*4, [3, 3], activation_fn=None)
76 | x2 = tf.nn.leaky_relu(x2)
77 |
78 | for idx in range(num_blocks):
79 | x2 = resblock(x2, out_channel=channel*4, name='block_{}'.format(idx))
80 |
81 | x2 = slim.convolution2d(x2, channel*2, [3, 3], activation_fn=None)
82 | x2 = tf.nn.leaky_relu(x2)
83 |
84 | h1, w1 = tf.shape(x2)[1], tf.shape(x2)[2]
85 | x3 = tf.image.resize_bilinear(x2, (h1*2, w1*2))
86 | x3 = slim.convolution2d(x3+x1, channel*2, [3, 3], activation_fn=None)
87 | x3 = tf.nn.leaky_relu(x3)
88 | x3 = slim.convolution2d(x3, channel, [3, 3], activation_fn=None)
89 | x3 = tf.nn.leaky_relu(x3)
90 |
91 | h2, w2 = tf.shape(x3)[1], tf.shape(x3)[2]
92 | x4 = tf.image.resize_bilinear(x3, (h2*2, w2*2))
93 | x4 = slim.convolution2d(x4+x0, channel, [3, 3], activation_fn=None)
94 | x4 = tf.nn.leaky_relu(x4)
95 | x4 = slim.convolution2d(x4, 3, [7, 7], activation_fn=None)
96 | #x4 = tf.clip_by_value(x4, -1, 1)
97 | return x4
98 |
99 |
100 |
101 | def disc_bn(x, scale=1, channel=32, is_training=True,
102 | name='discriminator', patch=True, reuse=False):
103 |
104 | with tf.variable_scope(name, reuse=reuse):
105 |
106 | for idx in range(3):
107 | x = slim.convolution2d(x, channel*2**idx, [3, 3], stride=2, activation_fn=None)
108 | x = slim.batch_norm(x, is_training=is_training, center=True, scale=True)
109 | x = tf.nn.leaky_relu(x)
110 |
111 | x = slim.convolution2d(x, channel*2**idx, [3, 3], activation_fn=None)
112 | x = slim.batch_norm(x, is_training=is_training, center=True, scale=True)
113 | x = tf.nn.leaky_relu(x)
114 |
115 | if patch == True:
116 | x = slim.convolution2d(x, 1, [1, 1], activation_fn=None)
117 | else:
118 | x = tf.reduce_mean(x, axis=[1, 2])
119 | x = slim.fully_connected(x, 1, activation_fn=None)
120 |
121 | return x
122 |
123 |
124 |
125 |
126 | def disc_sn(x, scale=1, channel=32, patch=True, name='discriminator', reuse=False):
127 | with tf.variable_scope(name, reuse=reuse):
128 |
129 | for idx in range(3):
130 | x = layers.conv_spectral_norm(x, channel*2**idx, [3, 3],
131 | stride=2, name='conv{}_1'.format(idx))
132 | x = tf.nn.leaky_relu(x)
133 |
134 | x = layers.conv_spectral_norm(x, channel*2**idx, [3, 3],
135 | name='conv{}_2'.format(idx))
136 | x = tf.nn.leaky_relu(x)
137 |
138 |
139 | if patch == True:
140 |             x = layers.conv_spectral_norm(x, 1, [1, 1], name='conv_out')
141 |
142 | else:
143 | x = tf.reduce_mean(x, axis=[1, 2])
144 | x = slim.fully_connected(x, 1, activation_fn=None)
145 |
146 | return x
147 |
148 |
149 | def disc_ln(x, channel=32, is_training=True, name='discriminator', patch=True, reuse=False):
150 | with tf.variable_scope(name, reuse=reuse):
151 |
152 | for idx in range(3):
153 | x = slim.convolution2d(x, channel*2**idx, [3, 3], stride=2, activation_fn=None)
154 | x = tf.contrib.layers.layer_norm(x)
155 | x = tf.nn.leaky_relu(x)
156 |
157 | x = slim.convolution2d(x, channel*2**idx, [3, 3], activation_fn=None)
158 | x = tf.contrib.layers.layer_norm(x)
159 | x = tf.nn.leaky_relu(x)
160 |
161 | if patch == True:
162 | x = slim.convolution2d(x, 1, [1, 1], activation_fn=None)
163 | else:
164 | x = tf.reduce_mean(x, axis=[1, 2])
165 | x = slim.fully_connected(x, 1, activation_fn=None)
166 |
167 | return x
168 |
169 |
170 |
171 |
172 | if __name__ == '__main__':
173 | pass
174 |
175 |
--------------------------------------------------------------------------------
/train_code/pretrain.py:
--------------------------------------------------------------------------------
1 | '''
2 | Source code for CVPR 2020 paper
3 | 'Learning to Cartoonize Using White-Box Cartoon Representations'
4 | by Xinrui Wang and Jinze Yu
5 | '''
6 |
7 |
8 |
9 | import tensorflow as tf
10 | import tensorflow.contrib.slim as slim
11 |
12 | import utils
13 | import os
14 | import numpy as np
15 | import argparse
16 | import network
17 | from tqdm import tqdm
18 |
19 |
20 | os.environ["CUDA_VISIBLE_DEVICES"]="0"
21 |
22 |
23 | def arg_parser():
24 | parser = argparse.ArgumentParser()
25 | parser.add_argument("--patch_size", default = 256, type = int)
26 | parser.add_argument("--batch_size", default = 16, type = int)
27 | parser.add_argument("--total_iter", default = 50000, type = int)
28 | parser.add_argument("--adv_train_lr", default = 2e-4, type = float)
29 | parser.add_argument("--gpu_fraction", default = 0.5, type = float)
30 | parser.add_argument("--save_dir", default = 'pretrain')
31 |
32 | args = parser.parse_args()
33 |
34 | return args
35 |
36 |
37 |
38 | def train(args):
39 |
40 |
41 | input_photo = tf.placeholder(tf.float32, [args.batch_size,
42 | args.patch_size, args.patch_size, 3])
43 |
44 | output = network.unet_generator(input_photo)
45 |
46 | recon_loss = tf.reduce_mean(tf.losses.absolute_difference(input_photo, output))
47 |
48 | all_vars = tf.trainable_variables()
49 | gene_vars = [var for var in all_vars if 'gene' in var.name]
50 |
51 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
52 | with tf.control_dependencies(update_ops):
53 |
54 | optim = tf.train.AdamOptimizer(args.adv_train_lr, beta1=0.5, beta2=0.99)\
55 | .minimize(recon_loss, var_list=gene_vars)
56 |
57 |
58 | '''
59 | config = tf.ConfigProto()
60 | config.gpu_options.allow_growth = True
61 | sess = tf.Session(config=config)
62 | '''
63 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_fraction)
64 | sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
65 | saver = tf.train.Saver(var_list=gene_vars, max_to_keep=20)
66 |
67 | with tf.device('/device:GPU:0'):
68 |
69 | sess.run(tf.global_variables_initializer())
70 | face_photo_dir = 'dataset/photo_face'
71 | face_photo_list = utils.load_image_list(face_photo_dir)
72 | scenery_photo_dir = 'dataset/photo_scenery'
73 | scenery_photo_list = utils.load_image_list(scenery_photo_dir)
74 |
75 |
76 | for total_iter in tqdm(range(args.total_iter)):
77 |
78 | if np.mod(total_iter, 5) == 0:
79 | photo_batch = utils.next_batch(face_photo_list, args.batch_size)
80 | else:
81 | photo_batch = utils.next_batch(scenery_photo_list, args.batch_size)
82 |
83 | _, r_loss = sess.run([optim, recon_loss], feed_dict={input_photo: photo_batch})
84 |
85 | if np.mod(total_iter+1, 50) == 0:
86 |
87 | print('pretrain, iter: {}, recon_loss: {}'.format(total_iter, r_loss))
88 |             if np.mod(total_iter+1, 500) == 0:
89 |                 saver.save(sess, args.save_dir+'/save_models/model',
90 | write_meta_graph=False, global_step=total_iter)
91 |
92 | photo_face = utils.next_batch(face_photo_list, args.batch_size)
93 | photo_scenery = utils.next_batch(scenery_photo_list, args.batch_size)
94 |
95 | result_face = sess.run(output, feed_dict={input_photo: photo_face})
96 |
97 | result_scenery = sess.run(output, feed_dict={input_photo: photo_scenery})
98 |
99 | utils.write_batch_image(result_face, args.save_dir+'/images',
100 | str(total_iter)+'_face_result.jpg', 4)
101 | utils.write_batch_image(photo_face, args.save_dir+'/images',
102 | str(total_iter)+'_face_photo.jpg', 4)
103 | utils.write_batch_image(result_scenery, args.save_dir+'/images',
104 | str(total_iter)+'_scenery_result.jpg', 4)
105 | utils.write_batch_image(photo_scenery, args.save_dir+'/images',
106 | str(total_iter)+'_scenery_photo.jpg', 4)
107 |
108 |
109 |
110 |
111 |
112 | if __name__ == '__main__':
113 |
114 | args = arg_parser()
115 | train(args)
116 |
--------------------------------------------------------------------------------
/train_code/selective_search/__init__.py:
--------------------------------------------------------------------------------
1 | from .core import selective_search, box_filter
2 |
--------------------------------------------------------------------------------
/train_code/selective_search/adaptive_color.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def label2rgb(label_field, image, kind='avg', bg_label=-1, bg_color=(0, 0, 0)):
5 |
6 | #std_list = list()
7 | out = np.zeros_like(image)
8 | labels = np.unique(label_field)
9 | bg = (labels == bg_label)
10 | if bg.any():
11 | labels = labels[labels != bg_label]
12 | mask = (label_field == bg_label).nonzero()
13 | out[mask] = bg_color
14 | for label in labels:
15 | mask = (label_field == label).nonzero()
16 | #std = np.std(image[mask])
17 | #std_list.append(std)
18 | if kind == 'avg':
19 | color = image[mask].mean(axis=0)
20 | elif kind == 'median':
21 | color = np.median(image[mask], axis=0)
22 | elif kind == 'mix':
23 | std = np.std(image[mask])
24 | if std < 20:
25 | color = image[mask].mean(axis=0)
26 |             elif std < 40:
27 |                 mean = image[mask].mean(axis=0)
28 |                 median = np.median(image[mask], axis=0)
29 |                 color = 0.5*mean + 0.5*median
30 |             else:
31 | color = np.median(image[mask], axis=0)
32 | out[mask] = color
33 | return out
--------------------------------------------------------------------------------
/train_code/selective_search/batch_ss.py:
--------------------------------------------------------------------------------
1 | '''
2 | CVPR 2020 submission, Paper ID 6791
3 | Source code for 'Learning to Cartoonize Using White-Box Cartoon Representations'
4 | '''
5 |
6 |
7 | import numpy as np
8 | from adaptive_color import label2rgb
9 | from joblib import Parallel, delayed
10 | from skimage.segmentation import felzenszwalb
11 | from util import switch_color_space
12 | from structure import HierarchicalGrouping
13 |
14 |
15 | def color_ss_map(image, color_space='Lab', k=10,
16 | sim_strategy='CTSF', seg_num=200, power=1):
17 |
18 | img_seg = felzenszwalb(image, scale=k, sigma=0.8, min_size=100)
19 | img_cvtcolor = label2rgb(img_seg, image, kind='mix')
20 | img_cvtcolor = switch_color_space(img_cvtcolor, color_space)
21 | S = HierarchicalGrouping(img_cvtcolor, img_seg, sim_strategy)
22 | S.build_regions()
23 | S.build_region_pairs()
24 |
25 | # Start hierarchical grouping
26 |
27 | while S.num_regions() > seg_num:
28 |
29 | i,j = S.get_highest_similarity()
30 | S.merge_region(i,j)
31 | S.remove_similarities(i,j)
32 | S.calculate_similarity_for_new_region()
33 |
34 | image = label2rgb(S.img_seg, image, kind='mix')
35 | image = (image+1)/2
36 | image = image**power
37 | image = image/np.max(image)
38 | image = image*2 - 1
39 |
40 | return image
41 |
42 |
43 | def selective_adacolor(batch_image, seg_num=200, power=1):
44 | num_job = np.shape(batch_image)[0]
45 | batch_out = Parallel(n_jobs=num_job)(delayed(color_ss_map)\
46 |                          (image, seg_num=seg_num, power=power) for image in batch_image)
47 | return np.array(batch_out)
48 |
49 |
50 | if __name__ == '__main__':
51 | pass
--------------------------------------------------------------------------------
/train_code/selective_search/core.py:
--------------------------------------------------------------------------------
1 | from joblib import Parallel, delayed
2 | from skimage.segmentation import felzenszwalb
3 | from .util import oversegmentation, switch_color_space, load_strategy
4 | from .structure import HierarchicalGrouping
5 |
6 |
7 |
8 |
9 | def selective_search_one(img, color_space, k, sim_strategy):
10 | '''
11 | Selective Search using single diversification strategy
12 | Parameters
13 | ----------
14 | im_orig : ndarray
15 | Original image
16 | color_space : string
17 | Colour Spaces
18 | k : int
19 | Threshold parameter for starting regions
20 |     sim_strategy : string
21 | Combinations of similarity measures
22 |
23 | Returns
24 | -------
25 | boxes : list
26 | Bounding boxes of the regions
27 | priority: list
28 | Small priority number indicates higher position in the hierarchy
29 | '''
30 |
31 | # convert RGB image to target color space
32 | img = switch_color_space(img, color_space)
33 |
34 | # Generate starting locations
35 | img_seg = oversegmentation(img, k)
36 |
37 |     # Initialize hierarchical grouping
38 | S = HierarchicalGrouping(img, img_seg, sim_strategy)
39 |
40 | S.build_regions()
41 | S.build_region_pairs()
42 |
43 | # Start hierarchical grouping
44 | while not S.is_empty():
45 | i,j = S.get_highest_similarity()
46 |
47 | S.merge_region(i,j)
48 |
49 | S.remove_similarities(i,j)
50 |
51 | S.calculate_similarity_for_new_region()
52 |
53 | # convert the order by hierarchical priority
54 | boxes = [x['box'] for x in S.regions.values()][::-1]
55 |
56 | # drop duplicates by maintaining order
57 | boxes = list(dict.fromkeys(boxes))
58 |
59 | # generate priority for boxes
60 | priorities = list(range(1, len(boxes)+1))
61 |
62 | return boxes, priorities
63 |
64 |
65 | def selective_search(img, mode='single', random=False):
66 | """
67 | Selective Search in Python
68 | """
69 |
70 | # load selective search strategy
71 | strategy = load_strategy(mode)
72 |
73 |     # Execute selective search in parallel
74 | vault = Parallel(n_jobs=1)(delayed(selective_search_one)(img, color, k, sim) for (color, k, sim) in strategy)
75 |
76 | boxes = [x for x,_ in vault]
77 | priorities = [y for _, y in vault]
78 |
79 | boxes = [item for sublist in boxes for item in sublist]
80 | priorities = [item for sublist in priorities for item in sublist]
81 |
82 |     if random:
83 |         import random as _random  # pseudo-random sorting as in the paper; the 'random' argument shadows the module name
84 |         rand_list = [_random.random() for i in range(len(priorities))]
85 |         priorities = [p * r for p, r in zip(priorities, rand_list)]
86 |         boxes = [b for _, b in sorted(zip(priorities, boxes))]
87 |
88 | # drop duplicates by maintaining order
89 | boxes = list(dict.fromkeys(boxes))
90 |
91 | return boxes
92 |
93 | def box_filter(boxes, min_size=20, max_ratio=None, topN=None):
94 | proposal = []
95 |
96 | for box in boxes:
97 | # Calculate width and height of the box
98 | w, h = box[2] - box[0], box[3] - box[1]
99 |
100 | # Filter for size
101 | if w < min_size or h < min_size:
102 | continue
103 |
104 | # Filter for box ratio
105 | if max_ratio:
106 | if w / h > max_ratio or h / w > max_ratio:
107 | continue
108 |
109 | proposal.append(box)
110 |
111 | if topN:
112 | if topN <= len(proposal):
113 | return proposal[:topN]
114 | else:
115 | return proposal
116 | else:
117 | return proposal
118 |
119 |
120 |
121 |
122 |
--------------------------------------------------------------------------------
/train_code/selective_search/measure.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from skimage.feature import local_binary_pattern
3 |
4 | def _calculate_color_sim(ri, rj):
5 | """
6 | Calculate color similarity using histogram intersection
7 | """
8 | return sum([min(a, b) for a, b in zip(ri["color_hist"], rj["color_hist"])])
9 |
10 |
11 | def _calculate_texture_sim(ri, rj):
12 | """
13 | Calculate texture similarity using histogram intersection
14 | """
15 | return sum([min(a, b) for a, b in zip(ri["texture_hist"], rj["texture_hist"])])
16 |
17 |
18 | def _calculate_size_sim(ri, rj, imsize):
19 | """
20 | Size similarity encourages small regions to merge early, which prevents
21 | a single region from engulfing all other blobs one by one.
22 |
23 | size (ri, rj) = 1 − [size(ri) + size(rj)] / size(image)
24 | """
25 | return 1.0 - (ri['size'] + rj['size']) / imsize
26 |
27 |
28 | def _calculate_fill_sim(ri, rj, imsize):
29 | """
30 | Fill similarity measures how well ri and rj fit into each other.
31 | BBij is the bounding box around ri and rj.
32 |
33 | fill(ri, rj) = 1 − [size(BBij) − size(ri) − size(rj)] / size(image)
34 | """
35 |
36 | bbsize = (max(ri['box'][2], rj['box'][2]) - min(ri['box'][0], rj['box'][0])) * (max(ri['box'][3], rj['box'][3]) - min(ri['box'][1], rj['box'][1]))
37 |
38 | return 1.0 - (bbsize - ri['size'] - rj['size']) / imsize
39 |
40 |
41 | def calculate_color_hist(mask, img):
42 | """
43 | Calculate colour histogram for the region.
44 | The output will be an array of length BINS * n_color_channels.
45 | The number of channels varies with the
46 | colour space.
47 | """
48 |
49 | BINS = 25
50 | if len(img.shape) == 2:
51 | img = img.reshape(img.shape[0], img.shape[1], 1)
52 |
53 | channel_nums = img.shape[2]
54 | hist = np.array([])
55 |
56 | for channel in range(channel_nums):
57 | layer = img[:, :, channel][mask]
58 | hist = np.concatenate([hist] + [np.histogram(layer, BINS)[0]])
59 |
60 | # L1 normalize
61 | hist = hist / np.sum(hist)
62 |
63 | return hist
64 |
65 |
66 | def generate_lbp_image(img):
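'''
Compute a local binary pattern (LBP) image channel by channel
(8 neighbours, radius 1); used by calculate_texture_hist below.
'''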
67 |
68 | if len(img.shape) == 2:
69 | img = img.reshape(img.shape[0], img.shape[1], 1)
70 | channel_nums = img.shape[2]
71 |
72 | lbp_img = np.zeros(img.shape)
73 | for channel in range(channel_nums):
74 | layer = img[:, :, channel]
75 | lbp_img[:, :,channel] = local_binary_pattern(layer, 8, 1)
76 |
77 | return lbp_img
78 |
79 |
80 | def calculate_texture_hist(mask, lbp_img):
81 | """
82 | Use LBP for now, inspired by AlpacaDB's implementation.
83 | We plan to switch to Gaussian derivatives, as in the paper, in a
84 | future version.
85 | """
86 |
87 | BINS = 10
88 | channel_nums = lbp_img.shape[2]
89 | hist = np.array([])
90 |
91 | for channel in range(channel_nums):
92 | layer = lbp_img[:, :, channel][mask]
93 | hist = np.concatenate([hist] + [np.histogram(layer, BINS)[0]])
94 |
95 | # L1 normalize
96 | hist = hist / np.sum(hist)
97 |
98 | return hist
99 |
100 |
101 | def calculate_sim(ri, rj, imsize, sim_strategy):
102 | """
103 | Calculate similarity between region ri and rj using diverse
104 | combinations of similarity measures.
105 | C: color, T: texture, S: size, F: fill.
106 | """
107 | sim = 0
108 |
109 | if 'C' in sim_strategy:
110 | sim += _calculate_color_sim(ri, rj)
111 | if 'T' in sim_strategy:
112 | sim += _calculate_texture_sim(ri, rj)
113 | if 'S' in sim_strategy:
114 | sim += _calculate_size_sim(ri, rj, imsize)
115 | if 'F' in sim_strategy:
116 | sim += _calculate_fill_sim(ri, rj, imsize)
117 |
118 | return sim
119 |
--------------------------------------------------------------------------------
/train_code/selective_search/structure.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from skimage.segmentation import find_boundaries
3 | from skimage.segmentation import felzenszwalb
4 | from scipy.ndimage import find_objects
5 | import measure
6 |
7 |
8 | class HierarchicalGrouping(object):
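'''
Region hierarchy for selective search: starting from an over-segmentation,
repeatedly merge the most similar pair of neighbouring regions and update
the pairwise similarities until no region pairs are left.
'''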
9 | def __init__(self, img, img_seg, sim_strategy):
10 | self.img = img
11 | self.sim_strategy = sim_strategy
12 | self.img_seg = img_seg.copy()
13 | self.labels = np.unique(self.img_seg).tolist()
14 |
15 | def build_regions(self):
16 | self.regions = {}
17 | lbp_img = measure.generate_lbp_image(self.img)
18 | for label in self.labels:
19 | size = (self.img_seg == label).sum()
20 | region_slice = find_objects(self.img_seg==label)[0]
21 | box = tuple([region_slice[i].start for i in (1,0)] +
22 | [region_slice[i].stop for i in (1,0)])
23 |
24 | mask = self.img_seg == label
25 | color_hist = measure.calculate_color_hist(mask, self.img)
26 | texture_hist = measure.calculate_texture_hist(mask, lbp_img)
27 |
28 | self.regions[label] = {
29 | 'size': size,
30 | 'box': box,
31 | 'color_hist': color_hist,
32 | 'texture_hist': texture_hist
33 | }
34 |
35 |
36 | def build_region_pairs(self):
37 | self.s = {}
38 | for i in self.labels:
39 | neighbors = self._find_neighbors(i)
40 | for j in neighbors:
41 | if i < j:
42 | self.s[(i,j)] = measure.calculate_sim(self.regions[i],
43 | self.regions[j],
44 | self.img.size,
45 | self.sim_strategy)
46 |
47 |
48 | def _find_neighbors(self, label):
49 | """
50 | Parameters
51 | ----------
52 | label : int
53 | label of the region
54 | Returns
55 | -------
56 | neighbors : list
57 | list of labels of neighbors
58 | """
59 |
60 | boundary = find_boundaries(self.img_seg == label,
61 | mode='outer')
62 | neighbors = np.unique(self.img_seg[boundary]).tolist()
63 |
64 | return neighbors
65 |
66 | def get_highest_similarity(self):
67 | return sorted(self.s.items(), key=lambda i: i[1])[-1][0]
68 |
69 | def merge_region(self, i, j):
70 |
71 | # generate a unique label and put in the label list
72 | new_label = max(self.labels) + 1
73 | self.labels.append(new_label)
74 |
75 | # merge blobs and update blob set
76 | ri, rj = self.regions[i], self.regions[j]
77 |
78 | new_size = ri['size'] + rj['size']
79 | new_box = (min(ri['box'][0], rj['box'][0]),
80 | min(ri['box'][1], rj['box'][1]),
81 | max(ri['box'][2], rj['box'][2]),
82 | max(ri['box'][3], rj['box'][3]))
83 | value = {
84 | 'box': new_box,
85 | 'size': new_size,
86 | 'color_hist':
87 | (ri['color_hist'] * ri['size']
88 | + rj['color_hist'] * rj['size']) / new_size,
89 | 'texture_hist':
90 | (ri['texture_hist'] * ri['size']
91 | + rj['texture_hist'] * rj['size']) / new_size,
92 | }
93 |
94 | self.regions[new_label] = value
95 |
96 | # update segmentation mask
97 | self.img_seg[self.img_seg == i] = new_label
98 | self.img_seg[self.img_seg == j] = new_label
99 |
100 | def remove_similarities(self, i, j):
101 |
102 | # mark keys for region pairs to be removed
103 | key_to_delete = []
104 | for key in self.s.keys():
105 | if (i in key) or (j in key):
106 | key_to_delete.append(key)
107 |
108 | for key in key_to_delete:
109 | del self.s[key]
110 |
111 | # remove old labels in label list
112 | self.labels.remove(i)
113 | self.labels.remove(j)
114 |
115 | def calculate_similarity_for_new_region(self):
116 | i = max(self.labels)
117 | neighbors = self._find_neighbors(i)
118 |
119 | for j in neighbors:
120 | # i is larger than j, so use (j,i) instead
121 | self.s[(j,i)] = measure.calculate_sim(self.regions[i],
122 | self.regions[j],
123 | self.img.size,
124 | self.sim_strategy)
125 |
126 | def is_empty(self):
127 | return not self.s
128 |
129 |
130 | def num_regions(self):
131 | return len(self.s.keys())
132 |
--------------------------------------------------------------------------------
/train_code/selective_search/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from itertools import product
3 |
4 | from skimage.segmentation import felzenszwalb
5 | from skimage.color import rgb2hsv, rgb2lab, rgb2grey
6 |
7 |
8 | def oversegmentation(img, k):
9 | """
10 | Generating various starting regions using the method of
11 | Felzenszwalb.
12 | k effectively sets a scale of observation, in that
13 | a larger k causes a preference for larger components.
14 | sigma = 0.8, as used in the original paper.
15 | min_size = 100 follows Koen's MATLAB implementation.
16 | """
17 | img_seg = felzenszwalb(img, scale=k, sigma=0.8, min_size=100)
18 |
19 | return img_seg
20 |
21 |
22 | def switch_color_space(img, target):
23 | """
24 | RGB to target color space conversion.
25 | I: the intensity (grey scale), Lab, rgI: the rg channels of
26 | normalized RGB plus intensity, HSV, H: the Hue channel H from HSV
27 | """
28 |
29 | if target == 'HSV':
30 | return rgb2hsv(img)
31 |
32 | elif target == 'Lab':
33 | return rgb2lab(img)
34 |
35 | elif target == 'I':
36 | return rgb2grey(img)
37 |
38 | elif target == 'rgb':
39 | img = img / np.sum(img, axis=2, keepdims=True)  # normalize each pixel by its channel sum
40 | return img
41 |
42 | elif target == 'rgI':
43 | img = img / np.sum(img, axis=2, keepdims=True)
44 | img[:,:,2] = rgb2grey(img)
45 | return img
46 |
47 | elif target == 'H':
48 | return rgb2hsv(img)[:,:,0]
49 |
50 | else:
51 | raise ValueError("{} is not supported.".format(target))
52 |
53 | def load_strategy(mode):
54 | # TODO: Add mode sanity check
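# mode is either one of the preset names below ('single', 'lab', 'fast',
# 'quality') or a dict of the same form, e.g. (hypothetical):
#   load_strategy({'ks': [150], 'colors': ['Lab'], 'sims': ['TSF']})
# The return value iterates over every (color, k, sim) combination.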
55 |
56 | cfg = {
57 | "single": {
58 | "ks": [100],
59 | "colors": ["HSV"],
60 | "sims": ["CTSF"]
61 | },
62 | "lab": {
63 | "ks": [100],
64 | "colors": ["Lab"],
65 | "sims": ["CTSF"]
66 | },
67 | "fast": {
68 | "ks": [50, 100],
69 | "colors": ["HSV", "Lab"],
70 | "sims": ["CTSF", "TSF"]
71 | },
72 | "quality": {
73 | "ks": [50, 100, 150, 300],
74 | "colors": ["HSV", "Lab", "I", "rgI", "H"],
75 | "sims": ["CTSF", "TSF", "F", "S"]
76 | }
77 | }
78 |
79 | if isinstance(mode, dict):
80 | cfg['manual'] = mode
81 | mode = 'manual'
82 |
83 | colors, ks, sims = cfg[mode]['colors'], cfg[mode]['ks'], cfg[mode]['sims']
84 |
85 | return product(colors, ks, sims)
86 |
--------------------------------------------------------------------------------
/train_code/train.py:
--------------------------------------------------------------------------------
1 | '''
2 | Source code for CVPR 2020 paper
3 | 'Learning to Cartoonize Using White-Box Cartoon Representations'
4 | by Xinrui Wang and Jinze Yu
5 | '''
6 |
7 |
8 | import tensorflow as tf
9 | import tensorflow.contrib.slim as slim
10 |
11 | import utils
12 | import os
13 | import numpy as np
14 | import argparse
15 | import network
16 | import loss
17 |
18 | from tqdm import tqdm
19 | from guided_filter import guided_filter
20 |
21 | os.environ["CUDA_VISIBLE_DEVICES"]="0"
22 |
23 |
24 | def arg_parser():
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument("--patch_size", default = 256, type = int)
27 | parser.add_argument("--batch_size", default = 16, type = int)
28 | parser.add_argument("--total_iter", default = 100000, type = int)
29 | parser.add_argument("--adv_train_lr", default = 2e-4, type = float)
30 | parser.add_argument("--gpu_fraction", default = 0.5, type = float)
31 | parser.add_argument("--save_dir", default = 'train_cartoon', type = str)
32 | parser.add_argument("--use_enhance", default = False)
33 |
34 | args = parser.parse_args()
35 |
36 | return args
37 |
38 |
39 |
40 | def train(args):
41 |
42 |
43 | input_photo = tf.placeholder(tf.float32, [args.batch_size,
44 | args.patch_size, args.patch_size, 3])
45 | input_superpixel = tf.placeholder(tf.float32, [args.batch_size,
46 | args.patch_size, args.patch_size, 3])
47 | input_cartoon = tf.placeholder(tf.float32, [args.batch_size,
48 | args.patch_size, args.patch_size, 3])
49 |
50 | output = network.unet_generator(input_photo)
51 | output = guided_filter(input_photo, output, r=1)
52 |
53 |
54 | blur_fake = guided_filter(output, output, r=5, eps=2e-1)
55 | blur_cartoon = guided_filter(input_cartoon, input_cartoon, r=5, eps=2e-1)
56 |
57 | gray_fake, gray_cartoon = utils.color_shift(output, input_cartoon)
58 |
59 | d_loss_gray, g_loss_gray = loss.lsgan_loss(network.disc_sn, gray_cartoon, gray_fake,
60 | scale=1, patch=True, name='disc_gray')
61 | d_loss_blur, g_loss_blur = loss.lsgan_loss(network.disc_sn, blur_cartoon, blur_fake,
62 | scale=1, patch=True, name='disc_blur')
63 |
64 |
65 | vgg_model = loss.Vgg19('vgg19_no_fc.npy')
66 | vgg_photo = vgg_model.build_conv4_4(input_photo)
67 | vgg_output = vgg_model.build_conv4_4(output)
68 | vgg_superpixel = vgg_model.build_conv4_4(input_superpixel)
69 | h, w, c = vgg_photo.get_shape().as_list()[1:]
70 |
71 | photo_loss = tf.reduce_mean(tf.losses.absolute_difference(vgg_photo, vgg_output))/(h*w*c)
72 | superpixel_loss = tf.reduce_mean(tf.losses.absolute_difference\
73 | (vgg_superpixel, vgg_output))/(h*w*c)
74 | recon_loss = photo_loss + superpixel_loss
75 | tv_loss = loss.total_variation_loss(output)
76 |
77 | g_loss_total = 1e4*tv_loss + 1e-1*g_loss_blur + g_loss_gray + 2e2*recon_loss
78 | d_loss_total = d_loss_blur + d_loss_gray
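# Note: the discriminator branches roughly correspond to the paper's white-box
# representations: blurred images (surface) and random grayscale (texture);
# the superpixel reconstruction term above covers the structure representation.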
79 |
80 | all_vars = tf.trainable_variables()
81 | gene_vars = [var for var in all_vars if 'gene' in var.name]
82 | disc_vars = [var for var in all_vars if 'disc' in var.name]
83 |
84 |
85 | tf.summary.scalar('tv_loss', tv_loss)
86 | tf.summary.scalar('photo_loss', photo_loss)
87 | tf.summary.scalar('superpixel_loss', superpixel_loss)
88 | tf.summary.scalar('recon_loss', recon_loss)
89 | tf.summary.scalar('d_loss_gray', d_loss_gray)
90 | tf.summary.scalar('g_loss_gray', g_loss_gray)
91 | tf.summary.scalar('d_loss_blur', d_loss_blur)
92 | tf.summary.scalar('g_loss_blur', g_loss_blur)
93 | tf.summary.scalar('d_loss_total', d_loss_total)
94 | tf.summary.scalar('g_loss_total', g_loss_total)
95 |
96 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
97 | with tf.control_dependencies(update_ops):
98 |
99 | g_optim = tf.train.AdamOptimizer(args.adv_train_lr, beta1=0.5, beta2=0.99)\
100 | .minimize(g_loss_total, var_list=gene_vars)
101 |
102 | d_optim = tf.train.AdamOptimizer(args.adv_train_lr, beta1=0.5, beta2=0.99)\
103 | .minimize(d_loss_total, var_list=disc_vars)
104 |
105 | '''
106 | config = tf.ConfigProto()
107 | config.gpu_options.allow_growth = True
108 | sess = tf.Session(config=config)
109 | '''
110 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_fraction)
111 | sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
112 |
113 |
114 | train_writer = tf.summary.FileWriter(args.save_dir+'/train_log')
115 | summary_op = tf.summary.merge_all()
116 | saver = tf.train.Saver(var_list=gene_vars, max_to_keep=20)
117 |
118 | with tf.device('/device:GPU:0'):
119 |
120 | sess.run(tf.global_variables_initializer())
121 | saver.restore(sess, tf.train.latest_checkpoint('pretrain/saved_models'))
122 |
123 | face_photo_dir = 'dataset/photo_face'
124 | face_photo_list = utils.load_image_list(face_photo_dir)
125 | scenery_photo_dir = 'dataset/photo_scenery'
126 | scenery_photo_list = utils.load_image_list(scenery_photo_dir)
127 |
128 | face_cartoon_dir = 'dataset/cartoon_face'
129 | face_cartoon_list = utils.load_image_list(face_cartoon_dir)
130 | scenery_cartoon_dir = 'dataset/cartoon_scenery'
131 | scenery_cartoon_list = utils.load_image_list(scenery_cartoon_dir)
132 |
133 | for total_iter in tqdm(range(args.total_iter)):
134 |
135 | if np.mod(total_iter, 5) == 0:
136 | photo_batch = utils.next_batch(face_photo_list, args.batch_size)
137 | cartoon_batch = utils.next_batch(face_cartoon_list, args.batch_size)
138 | else:
139 | photo_batch = utils.next_batch(scenery_photo_list, args.batch_size)
140 | cartoon_batch = utils.next_batch(scenery_cartoon_list, args.batch_size)
141 |
142 | inter_out = sess.run(output, feed_dict={input_photo: photo_batch,
143 | input_superpixel: photo_batch,
144 | input_cartoon: cartoon_batch})
145 |
146 | '''
147 | Adaptive coloring has to be applied together with clip_by_value
148 | in the last layer of the generator network, which is not very stable.
149 | To stably reproduce our results, first use power=1.0
150 | and comment out the clip_by_value call in network.py.
151 | If that works, then try adaptive coloring with clip_by_value enabled.
152 | '''
153 | if args.use_enhance:
154 | superpixel_batch = utils.selective_adacolor(inter_out, power=1.2)
155 | else:
156 | superpixel_batch = utils.simple_superpixel(inter_out, seg_num=200)
157 |
158 | _, g_loss, r_loss = sess.run([g_optim, g_loss_total, recon_loss],
159 | feed_dict={input_photo: photo_batch,
160 | input_superpixel: superpixel_batch,
161 | input_cartoon: cartoon_batch})
162 |
163 | _, d_loss, train_info = sess.run([d_optim, d_loss_total, summary_op],
164 | feed_dict={input_photo: photo_batch,
165 | input_superpixel: superpixel_batch,
166 | input_cartoon: cartoon_batch})
167 |
168 |
169 | train_writer.add_summary(train_info, total_iter)
170 |
171 | if np.mod(total_iter+1, 50) == 0:
172 |
173 | print('Iter: {}, d_loss: {}, g_loss: {}, recon_loss: {}'.\
174 | format(total_iter, d_loss, g_loss, r_loss))
175 | if np.mod(total_iter+1, 500) == 0:
176 | saver.save(sess, args.save_dir+'/saved_models/model',
177 | write_meta_graph=False, global_step=total_iter)
178 |
179 | photo_face = utils.next_batch(face_photo_list, args.batch_size)
180 | cartoon_face = utils.next_batch(face_cartoon_list, args.batch_size)
181 | photo_scenery = utils.next_batch(scenery_photo_list, args.batch_size)
182 | cartoon_scenery = utils.next_batch(scenery_cartoon_list, args.batch_size)
183 |
184 | result_face = sess.run(output, feed_dict={input_photo: photo_face,
185 | input_superpixel: photo_face,
186 | input_cartoon: cartoon_face})
187 |
188 | result_scenery = sess.run(output, feed_dict={input_photo: photo_scenery,
189 | input_superpixel: photo_scenery,
190 | input_cartoon: cartoon_scenery})
191 |
192 | utils.write_batch_image(result_face, args.save_dir+'/images',
193 | str(total_iter)+'_face_result.jpg', 4)
194 | utils.write_batch_image(photo_face, args.save_dir+'/images',
195 | str(total_iter)+'_face_photo.jpg', 4)
196 |
197 | utils.write_batch_image(result_scenery, args.save_dir+'/images',
198 | str(total_iter)+'_scenery_result.jpg', 4)
199 | utils.write_batch_image(photo_scenery, args.save_dir+'/images',
200 | str(total_iter)+'_scenery_photo.jpg', 4)
201 |
202 |
203 | if __name__ == '__main__':
204 |
205 | args = arg_parser()
206 | train(args)
207 |
--------------------------------------------------------------------------------
/train_code/utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | Source code for CVPR 2020 paper
3 | 'Learning to Cartoonize Using White-Box Cartoon Representations'
4 | by Xinrui Wang and Jinze Yu
5 | '''
6 |
7 |
8 | from scipy.ndimage import filters
9 | from skimage import segmentation, color
10 | from joblib import Parallel, delayed
11 | from selective_search.util import switch_color_space
12 | from selective_search.structure import HierarchicalGrouping
13 |
14 | import os
15 | import cv2
16 | import numpy as np
17 | import scipy.stats as st
18 | import tensorflow as tf
19 |
20 |
21 |
22 | def color_shift(image1, image2, mode='uniform'):
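'''
Project both images to random grayscale using shared channel weights drawn
around the standard BGR luma coefficients (0.114, 0.587, 0.299); the outputs
feed the grayscale discriminator (disc_gray) in train.py.
'''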
23 | b1, g1, r1 = tf.split(image1, num_or_size_splits=3, axis=3)
24 | b2, g2, r2 = tf.split(image2, num_or_size_splits=3, axis=3)
25 | if mode == 'normal':
26 | b_weight = tf.random.normal(shape=[1], mean=0.114, stddev=0.1)
27 | g_weight = tf.random.normal(shape=[1], mean=0.587, stddev=0.1)
28 | r_weight = tf.random.normal(shape=[1], mean=0.299, stddev=0.1)
29 | elif mode == 'uniform':
30 | b_weight = tf.random.uniform(shape=[1], minval=0.014, maxval=0.214)
31 | g_weight = tf.random.uniform(shape=[1], minval=0.487, maxval=0.687)
32 | r_weight = tf.random.uniform(shape=[1], minval=0.199, maxval=0.399)
33 | output1 = (b_weight*b1+g_weight*g1+r_weight*r1)/(b_weight+g_weight+r_weight)
34 | output2 = (b_weight*b2+g_weight*g2+r_weight*r2)/(b_weight+g_weight+r_weight)
35 | return output1, output2
36 |
37 |
38 |
39 |
40 | def label2rgb(label_field, image, kind='mix', bg_label=-1, bg_color=(0, 0, 0)):
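'''
Paint every labelled region with a single representative colour:
'avg' uses the region mean, 'median' the per-channel median, and 'mix'
picks between mean, a mean/median blend, or median depending on the
colour standard deviation inside the region.
'''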
41 |
42 | #std_list = list()
43 | out = np.zeros_like(image)
44 | labels = np.unique(label_field)
45 | bg = (labels == bg_label)
46 | if bg.any():
47 | labels = labels[labels != bg_label]
48 | mask = (label_field == bg_label).nonzero()
49 | out[mask] = bg_color
50 | for label in labels:
51 | mask = (label_field == label).nonzero()
52 | #std = np.std(image[mask])
53 | #std_list.append(std)
54 | if kind == 'avg':
55 | color = image[mask].mean(axis=0)
56 | elif kind == 'median':
57 | color = np.median(image[mask], axis=0)
58 | elif kind == 'mix':
59 | std = np.std(image[mask])
60 | if std < 20:
61 | color = image[mask].mean(axis=0)
62 | elif std < 40:
63 | mean = image[mask].mean(axis=0)
64 | median = np.median(image[mask], axis=0)
65 | color = 0.5*mean + 0.5*median
66 | else:
67 | color = np.median(image[mask], axis=0)
68 | out[mask] = color
69 | return out
70 |
71 |
72 |
73 | def color_ss_map(image, seg_num=200, power=1,
74 | color_space='Lab', k=10, sim_strategy='CTSF'):
75 |
76 | img_seg = segmentation.felzenszwalb(image, scale=k, sigma=0.8, min_size=100)
77 | img_cvtcolor = label2rgb(img_seg, image, kind='mix')
78 | img_cvtcolor = switch_color_space(img_cvtcolor, color_space)
79 | S = HierarchicalGrouping(img_cvtcolor, img_seg, sim_strategy)
80 | S.build_regions()
81 | S.build_region_pairs()
82 |
83 | # Start hierarchical grouping
84 |
85 | while S.num_regions() > seg_num:
86 |
87 | i,j = S.get_highest_similarity()
88 | S.merge_region(i,j)
89 | S.remove_similarities(i,j)
90 | S.calculate_similarity_for_new_region()
91 |
92 | image = label2rgb(S.img_seg, image, kind='mix')
93 | image = (image+1)/2
94 | image = image**power
95 | image = image/np.max(image)
96 | image = image*2 - 1
97 |
98 | return image
99 |
100 |
101 | def selective_adacolor(batch_image, seg_num=200, power=1):
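# Apply color_ss_map to every image in the batch in parallel (one joblib
# worker per image); selected by the use_enhance option in train.py.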
102 | num_job = np.shape(batch_image)[0]
103 | batch_out = Parallel(n_jobs=num_job)(delayed(color_ss_map)\
104 | (image, seg_num, power) for image in batch_image)
105 | return np.array(batch_out)
106 |
107 |
108 |
109 | def simple_superpixel(batch_image, seg_num=200):
110 |
111 | def process_slic(image):
112 | seg_label = segmentation.slic(image, n_segments=seg_num, sigma=1,
113 | compactness=10, convert2lab=True)
114 | image = label2rgb(seg_label, image, kind='mix')  # local label2rgb above, which supports kind='mix'
115 | return image
116 |
117 | num_job = np.shape(batch_image)[0]
118 | batch_out = Parallel(n_jobs=num_job)(delayed(process_slic)\
119 | (image) for image in batch_image)
120 | return np.array(batch_out)
121 |
122 |
123 |
124 | def load_image_list(data_dir):
125 | name_list = list()
126 | for name in os.listdir(data_dir):
127 | name_list.append(os.path.join(data_dir, name))
128 | name_list.sort()
129 | return name_list
130 |
131 |
132 | def next_batch(filename_list, batch_size):
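'''
Randomly sample batch_size image paths, load them with OpenCV and
scale pixel values to [-1, 1].
'''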
133 | idx = np.arange(0 , len(filename_list))
134 | np.random.shuffle(idx)
135 | idx = idx[:batch_size]
136 | batch_data = []
137 | for i in range(batch_size):
138 | image = cv2.imread(filename_list[idx[i]])
139 | image = image.astype(np.float32)/127.5 - 1
140 | #image = image.astype(np.float32)/255.0
141 | batch_data.append(image)
142 |
143 | return np.asarray(batch_data)
144 |
145 |
146 |
147 | def write_batch_image(image, save_dir, name, n):
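'''
Tile n*n images from the batch into a single grid, map values from
[-1, 1] back to [0, 255], and save the grid as one image file.
'''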
148 |
149 | if not os.path.exists(save_dir):
150 | os.makedirs(save_dir)
151 |
152 | fused_dir = os.path.join(save_dir, name)
153 | fused_image = [0] * n
154 | for i in range(n):
155 | fused_image[i] = []
156 | for j in range(n):
157 | k = i * n + j
158 | image[k] = (image[k]+1) * 127.5
159 | #image[k] = image[k] - np.min(image[k])
160 | #image[k] = image[k]/np.max(image[k])
161 | #image[k] = image[k] * 255.0
162 | fused_image[i].append(image[k])
163 | fused_image[i] = np.hstack(fused_image[i])
164 | fused_image = np.vstack(fused_image)
165 | cv2.imwrite(fused_dir, fused_image.astype(np.uint8))
166 |
167 |
168 | if __name__ == '__main__':
169 | pass
170 |
--------------------------------------------------------------------------------