├── README.md ├── images ├── city1.jpg ├── city2.jpg ├── folder.jpg ├── food.jpg ├── home.jpg ├── method.jpg ├── person1.jpg ├── person2.jpg └── use_cases.jpg ├── index.html ├── index_files ├── pixl-bk.css └── pixl-fonts.css ├── paper ├── 06791-supp.pdf ├── 06791.pdf ├── shinjuku.jpg └── white_box_cartoon_acm_style.pdf ├── test_code ├── cartoonize.py ├── guided_filter.py ├── network.py ├── saved_models │ ├── checkpoint │ ├── model-33999.data-00000-of-00001 │ └── model-33999.index └── test_images │ ├── actress2.jpg │ ├── china6.jpg │ ├── food16.jpg │ ├── food6.jpg │ ├── liuyifei4.jpg │ ├── london1.jpg │ ├── mountain4.jpg │ ├── mountain5.jpg │ ├── national_park1.jpg │ ├── party5.jpg │ └── party7.jpg └── train_code ├── guided_filter.py ├── layers.py ├── loss.py ├── network.py ├── pretrain.py ├── selective_search ├── __init__.py ├── adaptive_color.py ├── batch_ss.py ├── core.py ├── measure.py ├── structure.py └── util.py ├── train.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 |


4 | 5 | # [CVPR2020] Learning to Cartoonize Using White-box Cartoon Representations 6 | [project page](https://systemerrorwang.github.io/White-box-Cartoonization/) | [paper](https://github.com/SystemErrorWang/White-box-Cartoonization/blob/master/paper/06791.pdf) | [twitter](https://twitter.com/IlIIlIIIllIllII/status/1243108510423896065) | [zhihu](https://zhuanlan.zhihu.com/p/117422157) | [bilibili](https://www.bilibili.com/video/av56708333) | [facial model](https://github.com/SystemErrorWang/FacialCartoonization) 7 | 8 | - Tensorflow implementation of the CVPR2020 paper “Learning to Cartoonize Using White-box Cartoon Representations”. 9 | - An improved method for facial images is now available: 10 | - https://github.com/SystemErrorWang/FacialCartoonization 11 | 12 | 13 | 14 | 15 | ## Use cases 16 | 17 | ### Scenery 18 | 19 | 20 | 21 | ### Food 22 | 23 | 24 | ### Indoor Scenes 25 | 26 | 27 | ### People 28 | 29 | 30 | 31 | ### More Images Are Shown In The Supplementary Materials 32 | 33 | 34 | ## Online demo 35 | 36 | - Some kind people have made an online demo for this project 37 | - Demo link: https://cartoonize-lkqov62dia-de.a.run.app/cartoonize 38 | - Code: https://github.com/experience-ml/cartoonize 39 | - Sample demo: https://www.youtube.com/watch?v=GqduSLcmhto&feature=emb_title 40 | 41 | ## Prerequisites 42 | 43 | - Training code: Linux or Windows 44 | - NVIDIA GPU + CUDA CuDNN for performance 45 | - Inference code: Linux, Windows and macOS 46 | 47 | 48 | ## How To Use 49 | 50 | ### Installation 51 | 52 | - We assume you already have an NVIDIA GPU with CUDA and CuDNN installed 53 | - Install tensorflow-gpu; we tested 1.12.0 and 1.13.0rc0 54 | - Install scikit-image==0.14.5; other versions may cause problems 55 | 56 | 57 | ### Inference with Pre-trained Model 58 | 59 | - Store test images in /test_code/test_images 60 | - Run /test_code/cartoonize.py 61 | - Results will be saved in /test_code/cartoonized_images (a condensed single-image sketch is given further below, after the image listings) 62 | 63 | 64 | ### Train 65 | 66 | - Place your training data in the corresponding folders under /dataset 67 | - Run pretrain.py; results will be saved in the /pretrain folder 68 | - Run train.py; results will be saved in the /train_cartoon folder 69 | - The code was cleaned up from a production environment and is untested 70 | - There may be minor problems, but they should be easy to resolve 71 | - The pretrained VGG_19 model can be found at the following URL: 72 | https://drive.google.com/file/d/1j0jDENjdwxCDb36meP6-u5xDBzmKBOjJ/view?usp=sharing 73 | 74 | 75 | 76 | ### Datasets 77 | 78 | - Due to copyright issues, we cannot provide the cartoon images used for training 79 | - However, these training datasets are easy to prepare 80 | - Scenery images are collected from films by Shinkai Makoto, Miyazaki Hayao and Hosoda Mamoru 81 | - Clip the films into frames, then randomly crop and resize to 256x256 82 | - Portrait images are from Kyoto Animation and P.A. Works 83 | - We use this repo (https://github.com/nagadomi/lbpcascade_animeface) to detect facial areas 84 | - Manual data cleaning will greatly increase the quality of both datasets 85 | 86 | ## Acknowledgement 87 | 88 | We are grateful for the help from Lvmin Zhang and Style2Paints Research 89 | 90 | ## License 91 | - Copyright (C) Xinrui Wang. All rights reserved. Licensed under the CC BY-NC-SA 4.0 92 | - license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
93 | - Commercial use is prohibited; please retain this license if you clone this repo 94 | 95 | ## Citation 96 | 97 | If you use this code for your research, please cite our [paper](https://systemerrorwang.github.io/White-box-Cartoonization/): 98 | 99 | @InProceedings{Wang_2020_CVPR, 100 | author = {Wang, Xinrui and Yu, Jinze}, 101 | title = {Learning to Cartoonize Using White-Box Cartoon Representations}, 102 | booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 103 | month = {June}, 104 | year = {2020} 105 | } 106 | 107 | 108 | # Chinese Community 109 | 110 | We have a group that is mainly for technical discussion, though in practice we chat about everything except technology. If you fail to join the group on the first try, you can try again: 816096787. 111 | -------------------------------------------------------------------------------- /images/city1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/images/city1.jpg -------------------------------------------------------------------------------- /images/city2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/images/city2.jpg -------------------------------------------------------------------------------- /images/folder.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/images/folder.jpg -------------------------------------------------------------------------------- /images/food.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/images/food.jpg -------------------------------------------------------------------------------- /images/home.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/images/home.jpg -------------------------------------------------------------------------------- /images/method.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/images/method.jpg -------------------------------------------------------------------------------- /images/person1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/images/person1.jpg -------------------------------------------------------------------------------- /images/person2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/images/person2.jpg -------------------------------------------------------------------------------- /images/use_cases.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/images/use_cases.jpg 
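The inference steps in the README above are implemented in /test_code/cartoonize.py (included in full later in this file). Below is a condensed, single-image sketch of the same flow; the function name cartoonize_single is illustrative only, and the sketch assumes TensorFlow 1.x, opencv-python, and that network.py, guided_filter.py and saved_models/ from /test_code are available in the working directory.

```python
# Condensed single-image inference sketch (illustrative; mirrors test_code/cartoonize.py).
# Assumptions: TF 1.x, opencv-python, and test_code/network.py, test_code/guided_filter.py,
# test_code/saved_models/ on the path. The name cartoonize_single is hypothetical.
import cv2
import numpy as np
import tensorflow as tf

import network         # test_code/network.py
import guided_filter   # test_code/guided_filter.py


def cartoonize_single(image_path, model_dir='saved_models'):
    input_photo = tf.placeholder(tf.float32, [1, None, None, 3])
    net_out = network.unet_generator(input_photo)
    final_out = guided_filter.guided_filter(input_photo, net_out, r=1, eps=5e-3)

    # Restore only the generator weights, as cartoonize.py does
    gene_vars = [v for v in tf.trainable_variables() if 'generator' in v.name]
    saver = tf.train.Saver(var_list=gene_vars)

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, tf.train.latest_checkpoint(model_dir))

    image = cv2.imread(image_path)                         # BGR, uint8
    h, w = (image.shape[0] // 8) * 8, (image.shape[1] // 8) * 8
    image = image[:h, :w, :]                               # crop so H and W are multiples of 8
    batch = np.expand_dims(image.astype(np.float32) / 127.5 - 1, axis=0)

    output = sess.run(final_out, feed_dict={input_photo: batch})
    output = (np.squeeze(output) + 1) * 127.5
    return np.clip(output, 0, 255).astype(np.uint8)        # BGR, ready for cv2.imwrite
```

Something like cv2.imwrite('cartoon.jpg', cartoonize_single('test_images/party5.jpg')) should then reproduce, for one file, what the batch script does for the whole folder; note that for large photos the reference script additionally downscales the shorter side to 720 before cropping.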
-------------------------------------------------------------------------------- /index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | Learning to Cartoonize Using White-box Cartoon Representations 8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 | Style2Paints Research → 16 | [Wang et al. 2020] 17 | 18 |
19 | 20 | 21 |
22 |
23 |
Learning to Cartoonize Using White-box Cartoon Representations
24 |
25 |
Computer Vision and Pattern Recognition (CVPR), June 2020
26 |
27 |
Xinrui Wang and Jinze Yu
28 |
29 | 30 |
31 |
Example of image cartoonization with our method: on the left is a frame from the animation "The Garden of Words", and on the right is a real-world photo processed by our proposed method.
32 |
Abstract
33 |

34 |

35 | This paper presents an approach for image cartoonization. By observing the cartoon painting behavior and consulting artists, we propose to separately identify three white-box representations from images: the surface representation that contains a smooth surface of cartoon images, the structure representation that refers to the sparse color-blocks and flattened global content in the celluloid style workflow, and the texture representation that reflects high-frequency texture, contours, and details in cartoon images. A Generative Adversarial Network (GAN) framework is used to learn the extracted representations and to cartoonize images. 36 | 
37 | The learning objectives of our method are separately based on each extracted representation, making our framework controllable and adjustable. This enables our approach to meet artists’ requirements in different styles and diverse use cases. Qualitative comparisons and quantitative analyses, as well as user studies, have been conducted to validate the effectiveness of this approach, and our method outperforms previous methods in all comparisons. Finally, the ablation study demonstrates the influence of each component in our framework. 38 | 39 | 
40 |
Files
41 | 45 |
See Also
46 | 54 | 55 |
Citation
56 |

57 | Xinrui Wang and Jinze Yu
58 | "Learning to Cartoonize Using White-box Cartoon Representations."
59 | IEEE Conference on Computer Vision and Pattern Recognition, June 2020. 60 | 61 | 73 | 74 | 75 | 76 | -------------------------------------------------------------------------------- /index_files/pixl-bk.css: -------------------------------------------------------------------------------- 1 | /* Un-break mobile pages. */ 2 | * { 3 | text-size-adjust: none; 4 | -webkit-text-size-adjust: none; 5 | } 6 | 7 | /* Page background and default foreground colors */ 8 | body { 9 | background: #111; 10 | color: #ddd; 11 | padding-bottom: 5vw; 12 | } 13 | @media screen and (max-width: 533px) { 14 | body { 15 | font-size: 3vw; 16 | } 17 | } 18 | 19 | /* Default link colors */ 20 | a { 21 | color: #f38025; 22 | text-decoration: none; 23 | } 24 | .unlinked { 25 | color: #f38025; 26 | } 27 | 28 | /* "Crumbtrail" bar at top - meant to be used as '

' */ 29 | .crumb { 30 | position: fixed; 31 | top: 0px; 32 | left: 0px; 33 | right: 0px; 34 | height: 22px; 35 | background: #f38025; 36 | border-bottom: 5px solid #000; 37 | font-size: 14px; 38 | font-weight: bold; 39 | color: #000; 40 | padding: 2px 10px; 41 | } 42 | @media screen and (max-width: 800px) { 43 | .crumb { 44 | height: 2.75vw; 45 | border-bottom: 0.625vw solid #000; 46 | font-size: 1.75vw; 47 | padding: 0.025vw 1.25vw; 48 | } 49 | } 50 | .crumb a { 51 | color: #000; 52 | } 53 | 54 | /* Right-aligned span */ 55 | .right { 56 | float: right; 57 | } 58 | 59 | /* Page main title */ 60 | .title { 61 | text-align: center; 62 | font-size: 56px; 63 | padding-top: 80px; 64 | padding-bottom: 56px; 65 | color: #f38025; 66 | } 67 | .title img { 68 | vertical-align: middle; 69 | height: 75px; 70 | padding-right: 10px; 71 | } 72 | @media screen and (max-width: 800px) { 73 | .title { 74 | font-size: 7vw; 75 | padding-top: 10vw; 76 | padding-bottom: 7vw; 77 | } 78 | .title img { 79 | height: 9.375vw; 80 | padding-right: 1.25vw; 81 | } 82 | } 83 | .logoX { 84 | color: #fff; 85 | } 86 | 87 | /* Centered tables or divs */ 88 | .content, .linkgallery, .projlist, .abstract { 89 | width: 720px; 90 | margin: 0px auto; 91 | } 92 | .linkgallery, .projlist, .centered { 93 | text-align: center; 94 | } 95 | .linkgallery a img { 96 | width: 188; 97 | filter: grayscale(20%) brightness(150%); 98 | -webkit-filter: grayscale(20%) brightness(150%); 99 | } 100 | .link { 101 | color: #f38025; 102 | font-size: 36px; 103 | } 104 | .projlist td { 105 | padding: 24px; 106 | } 107 | .projlist img { 108 | width: 240px; 109 | } 110 | @media screen and (max-width: 800px) { 111 | .content, .linkgallery, .projlist, .abstract { 112 | width: 90vw; 113 | } 114 | .linkgallery a img { 115 | width: 23.5vw; 116 | } 117 | .link { 118 | font-size: 4.5vw; 119 | } 120 | .projlist td { 121 | padding: 3vw; 122 | } 123 | .projlist img { 124 | width: 30vw; 125 | } 126 | } 127 | 128 | /* Paper listings */ 129 | .paperheader { 130 | text-align: center; 131 | padding-top: 80px; 132 | padding-bottom: 56px; 133 | } 134 | @media screen and (max-width: 800px) { 135 | .paperheader { 136 | padding-top: 10vw; 137 | padding-bottom: 7vw; 138 | } 139 | } 140 | .papertitle { 141 | font-size: 150%; 142 | font-weight: bold; 143 | color: #f38025; 144 | } 145 | .pubinfo { 146 | } 147 | .authors { 148 | font-size: 125%; 149 | } 150 | .paperimg { 151 | text-align: center; 152 | } 153 | .paperimg img { 154 | max-width: 100%; 155 | } 156 | .longcaption { 157 | margin: 24px; 158 | font-style: italic; 159 | text-align: justify; 160 | } 161 | .shortcaption { 162 | margin: 24px; 163 | font-style: italic; 164 | text-align: center; 165 | } 166 | @media screen and (max-width: 800px) { 167 | .longcaption { 168 | margin: 3vw; 169 | } 170 | .shortcaption { 171 | margin: 3vw; 172 | } 173 | } 174 | .abstract { 175 | text-align: justify; 176 | } 177 | 178 | /* Links with fancy fading hover effect */ 179 | a, .linkgallery a img { 180 | transition: .1s ease-in-out; 181 | -webkit-transition: .1s ease-in-out; 182 | } 183 | a:hover, a:active, .crumb a:hover, .crumb a:active, 184 | .linkgallery a:hover img, .linkgallery a:active img { 185 | color: #fff; 186 | filter: grayscale(100%) brightness(125%); 187 | -webkit-filter: grayscale(100%) brightness(125%); 188 | } 189 | 190 | /* Larger text */ 191 | .larger { 192 | font-size: 125%; 193 | } 194 | 195 | /* Small header */ 196 | .header { 197 | font-size: 125%; 198 | font-weight: bold; 199 | margin-top: 1em; 200 | } 201 | 202 | 
/* Text inputs */ 203 | input[type=text], select, textarea { 204 | width: 100%; 205 | } 206 | -------------------------------------------------------------------------------- /index_files/pixl-fonts.css: -------------------------------------------------------------------------------- 1 | /* Use Source Sans Pro fonts */ 2 | @import url('https://fonts.googleapis.com/css?family=Source+Sans+Pro:400,600'); 3 | 4 | * { 5 | font-family: 'Source Sans Pro', sans-serif; 6 | } 7 | 8 | -------------------------------------------------------------------------------- /paper/06791-supp.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/paper/06791-supp.pdf -------------------------------------------------------------------------------- /paper/06791.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/paper/06791.pdf -------------------------------------------------------------------------------- /paper/shinjuku.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/paper/shinjuku.jpg -------------------------------------------------------------------------------- /paper/white_box_cartoon_acm_style.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/paper/white_box_cartoon_acm_style.pdf -------------------------------------------------------------------------------- /test_code/cartoonize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import tensorflow as tf 5 | import network 6 | import guided_filter 7 | from tqdm import tqdm 8 | 9 | 10 | 11 | def resize_crop(image): 12 | h, w, c = np.shape(image) 13 | if min(h, w) > 720: 14 | if h > w: 15 | h, w = int(720*h/w), 720 16 | else: 17 | h, w = 720, int(720*w/h) 18 | image = cv2.resize(image, (w, h), 19 | interpolation=cv2.INTER_AREA) 20 | h, w = (h//8)*8, (w//8)*8 21 | image = image[:h, :w, :] 22 | return image 23 | 24 | 25 | def cartoonize(load_folder, save_folder, model_path): 26 | input_photo = tf.placeholder(tf.float32, [1, None, None, 3]) 27 | network_out = network.unet_generator(input_photo) 28 | final_out = guided_filter.guided_filter(input_photo, network_out, r=1, eps=5e-3) 29 | 30 | all_vars = tf.trainable_variables() 31 | gene_vars = [var for var in all_vars if 'generator' in var.name] 32 | saver = tf.train.Saver(var_list=gene_vars) 33 | 34 | config = tf.ConfigProto() 35 | config.gpu_options.allow_growth = True 36 | sess = tf.Session(config=config) 37 | 38 | sess.run(tf.global_variables_initializer()) 39 | saver.restore(sess, tf.train.latest_checkpoint(model_path)) 40 | name_list = os.listdir(load_folder) 41 | for name in tqdm(name_list): 42 | try: 43 | load_path = os.path.join(load_folder, name) 44 | save_path = os.path.join(save_folder, name) 45 | image = cv2.imread(load_path) 46 | image = resize_crop(image) 47 | batch_image = image.astype(np.float32)/127.5 - 1 48 | batch_image = np.expand_dims(batch_image, axis=0) 49 | output = sess.run(final_out, 
feed_dict={input_photo: batch_image}) 50 | output = (np.squeeze(output)+1)*127.5 51 | output = np.clip(output, 0, 255).astype(np.uint8) 52 | cv2.imwrite(save_path, output) 53 | except: 54 | print('cartoonize {} failed'.format(load_path)) 55 | 56 | 57 | 58 | 59 | if __name__ == '__main__': 60 | model_path = 'saved_models' 61 | load_folder = 'test_images' 62 | save_folder = 'cartoonized_images' 63 | if not os.path.exists(save_folder): 64 | os.mkdir(save_folder) 65 | cartoonize(load_folder, save_folder, model_path) 66 | 67 | 68 | -------------------------------------------------------------------------------- /test_code/guided_filter.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | 5 | 6 | 7 | def tf_box_filter(x, r): 8 | k_size = int(2*r+1) 9 | ch = x.get_shape().as_list()[-1] 10 | weight = 1/(k_size**2) 11 | box_kernel = weight*np.ones((k_size, k_size, ch, 1)) 12 | box_kernel = np.array(box_kernel).astype(np.float32) 13 | output = tf.nn.depthwise_conv2d(x, box_kernel, [1, 1, 1, 1], 'SAME') 14 | return output 15 | 16 | 17 | 18 | def guided_filter(x, y, r, eps=1e-2): 19 | 20 | x_shape = tf.shape(x) 21 | #y_shape = tf.shape(y) 22 | 23 | N = tf_box_filter(tf.ones((1, x_shape[1], x_shape[2], 1), dtype=x.dtype), r) 24 | 25 | mean_x = tf_box_filter(x, r) / N 26 | mean_y = tf_box_filter(y, r) / N 27 | cov_xy = tf_box_filter(x * y, r) / N - mean_x * mean_y 28 | var_x = tf_box_filter(x * x, r) / N - mean_x * mean_x 29 | 30 | A = cov_xy / (var_x + eps) 31 | b = mean_y - A * mean_x 32 | 33 | mean_A = tf_box_filter(A, r) / N 34 | mean_b = tf_box_filter(b, r) / N 35 | 36 | output = mean_A * x + mean_b 37 | 38 | return output 39 | 40 | 41 | 42 | def fast_guided_filter(lr_x, lr_y, hr_x, r=1, eps=1e-8): 43 | 44 | #assert lr_x.shape.ndims == 4 and lr_y.shape.ndims == 4 and hr_x.shape.ndims == 4 45 | 46 | lr_x_shape = tf.shape(lr_x) 47 | #lr_y_shape = tf.shape(lr_y) 48 | hr_x_shape = tf.shape(hr_x) 49 | 50 | N = tf_box_filter(tf.ones((1, lr_x_shape[1], lr_x_shape[2], 1), dtype=lr_x.dtype), r) 51 | 52 | mean_x = tf_box_filter(lr_x, r) / N 53 | mean_y = tf_box_filter(lr_y, r) / N 54 | cov_xy = tf_box_filter(lr_x * lr_y, r) / N - mean_x * mean_y 55 | var_x = tf_box_filter(lr_x * lr_x, r) / N - mean_x * mean_x 56 | 57 | A = cov_xy / (var_x + eps) 58 | b = mean_y - A * mean_x 59 | 60 | mean_A = tf.image.resize_images(A, hr_x_shape[1: 3]) 61 | mean_b = tf.image.resize_images(b, hr_x_shape[1: 3]) 62 | 63 | output = mean_A * hr_x + mean_b 64 | 65 | return output 66 | 67 | 68 | if __name__ == '__main__': 69 | import cv2 70 | from tqdm import tqdm 71 | 72 | input_photo = tf.placeholder(tf.float32, [1, None, None, 3]) 73 | #input_superpixel = tf.placeholder(tf.float32, [16, 256, 256, 3]) 74 | output = guided_filter(input_photo, input_photo, 5, eps=1) 75 | image = cv2.imread('output_figure1/cartoon2.jpg') 76 | image = image/127.5 - 1 77 | image = np.expand_dims(image, axis=0) 78 | 79 | config = tf.ConfigProto() 80 | config.gpu_options.allow_growth = True 81 | sess = tf.Session(config=config) 82 | sess.run(tf.global_variables_initializer()) 83 | 84 | out = sess.run(output, feed_dict={input_photo: image}) 85 | out = (np.squeeze(out)+1)*127.5 86 | out = np.clip(out, 0, 255).astype(np.uint8) 87 | cv2.imwrite('output_figure1/cartoon2_filter.jpg', out) 88 | -------------------------------------------------------------------------------- /test_code/network.py: 
-------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import tensorflow.contrib.slim as slim 4 | 5 | 6 | 7 | def resblock(inputs, out_channel=32, name='resblock'): 8 | 9 | with tf.variable_scope(name): 10 | 11 | x = slim.convolution2d(inputs, out_channel, [3, 3], 12 | activation_fn=None, scope='conv1') 13 | x = tf.nn.leaky_relu(x) 14 | x = slim.convolution2d(x, out_channel, [3, 3], 15 | activation_fn=None, scope='conv2') 16 | 17 | return x + inputs 18 | 19 | 20 | 21 | 22 | def unet_generator(inputs, channel=32, num_blocks=4, name='generator', reuse=False): 23 | with tf.variable_scope(name, reuse=reuse): 24 | 25 | x0 = slim.convolution2d(inputs, channel, [7, 7], activation_fn=None) 26 | x0 = tf.nn.leaky_relu(x0) 27 | 28 | x1 = slim.convolution2d(x0, channel, [3, 3], stride=2, activation_fn=None) 29 | x1 = tf.nn.leaky_relu(x1) 30 | x1 = slim.convolution2d(x1, channel*2, [3, 3], activation_fn=None) 31 | x1 = tf.nn.leaky_relu(x1) 32 | 33 | x2 = slim.convolution2d(x1, channel*2, [3, 3], stride=2, activation_fn=None) 34 | x2 = tf.nn.leaky_relu(x2) 35 | x2 = slim.convolution2d(x2, channel*4, [3, 3], activation_fn=None) 36 | x2 = tf.nn.leaky_relu(x2) 37 | 38 | for idx in range(num_blocks): 39 | x2 = resblock(x2, out_channel=channel*4, name='block_{}'.format(idx)) 40 | 41 | x2 = slim.convolution2d(x2, channel*2, [3, 3], activation_fn=None) 42 | x2 = tf.nn.leaky_relu(x2) 43 | 44 | h1, w1 = tf.shape(x2)[1], tf.shape(x2)[2] 45 | x3 = tf.image.resize_bilinear(x2, (h1*2, w1*2)) 46 | x3 = slim.convolution2d(x3+x1, channel*2, [3, 3], activation_fn=None) 47 | x3 = tf.nn.leaky_relu(x3) 48 | x3 = slim.convolution2d(x3, channel, [3, 3], activation_fn=None) 49 | x3 = tf.nn.leaky_relu(x3) 50 | 51 | h2, w2 = tf.shape(x3)[1], tf.shape(x3)[2] 52 | x4 = tf.image.resize_bilinear(x3, (h2*2, w2*2)) 53 | x4 = slim.convolution2d(x4+x0, channel, [3, 3], activation_fn=None) 54 | x4 = tf.nn.leaky_relu(x4) 55 | x4 = slim.convolution2d(x4, 3, [7, 7], activation_fn=None) 56 | 57 | return x4 58 | 59 | if __name__ == '__main__': 60 | 61 | 62 | pass -------------------------------------------------------------------------------- /test_code/saved_models/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "model-33999" 2 | all_model_checkpoint_paths: "model-33999" 3 | all_model_checkpoint_paths: "model-37499" 4 | -------------------------------------------------------------------------------- /test_code/saved_models/model-33999.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/test_code/saved_models/model-33999.data-00000-of-00001 -------------------------------------------------------------------------------- /test_code/saved_models/model-33999.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/test_code/saved_models/model-33999.index -------------------------------------------------------------------------------- /test_code/test_images/actress2.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/test_code/test_images/actress2.jpg -------------------------------------------------------------------------------- /test_code/test_images/china6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/test_code/test_images/china6.jpg -------------------------------------------------------------------------------- /test_code/test_images/food16.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/test_code/test_images/food16.jpg -------------------------------------------------------------------------------- /test_code/test_images/food6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/test_code/test_images/food6.jpg -------------------------------------------------------------------------------- /test_code/test_images/liuyifei4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/test_code/test_images/liuyifei4.jpg -------------------------------------------------------------------------------- /test_code/test_images/london1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/test_code/test_images/london1.jpg -------------------------------------------------------------------------------- /test_code/test_images/mountain4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/test_code/test_images/mountain4.jpg -------------------------------------------------------------------------------- /test_code/test_images/mountain5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/test_code/test_images/mountain5.jpg -------------------------------------------------------------------------------- /test_code/test_images/national_park1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/test_code/test_images/national_park1.jpg -------------------------------------------------------------------------------- /test_code/test_images/party5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/test_code/test_images/party5.jpg -------------------------------------------------------------------------------- /test_code/test_images/party7.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SystemErrorWang/White-box-Cartoonization/4a1a071cc119f1f48681030581c8134d97cf3d1e/test_code/test_images/party7.jpg -------------------------------------------------------------------------------- /train_code/guided_filter.py: -------------------------------------------------------------------------------- 1 | ''' 2 | CVPR 2020 submission, Paper ID 6791 3 | Source code for 'Learning to Cartoonize Using White-Box Cartoon Representations' 4 | ''' 5 | 6 | 7 | import tensorflow as tf 8 | import numpy as np 9 | 10 | 11 | def tf_box_filter(x, r): 12 | ch = x.get_shape().as_list()[-1] 13 | weight = 1/((2*r+1)**2) 14 | box_kernel = weight*np.ones((2*r+1, 2*r+1, ch, 1)) 15 | box_kernel = np.array(box_kernel).astype(np.float32) 16 | output = tf.nn.depthwise_conv2d(x, box_kernel, [1, 1, 1, 1], 'SAME') 17 | return output 18 | 19 | 20 | 21 | def guided_filter(x, y, r, eps=1e-2): 22 | 23 | x_shape = tf.shape(x) 24 | #y_shape = tf.shape(y) 25 | 26 | N = tf_box_filter(tf.ones((1, x_shape[1], x_shape[2], 1), dtype=x.dtype), r) 27 | 28 | mean_x = tf_box_filter(x, r) / N 29 | mean_y = tf_box_filter(y, r) / N 30 | cov_xy = tf_box_filter(x * y, r) / N - mean_x * mean_y 31 | var_x = tf_box_filter(x * x, r) / N - mean_x * mean_x 32 | 33 | A = cov_xy / (var_x + eps) 34 | b = mean_y - A * mean_x 35 | 36 | mean_A = tf_box_filter(A, r) / N 37 | mean_b = tf_box_filter(b, r) / N 38 | 39 | output = mean_A * x + mean_b 40 | 41 | return output 42 | 43 | 44 | if __name__ == '__main__': 45 | pass 46 | -------------------------------------------------------------------------------- /train_code/layers.py: -------------------------------------------------------------------------------- 1 | ''' 2 | CVPR 2020 submission, Paper ID 6791 3 | Source code for 'Learning to Cartoonize Using White-Box Cartoon Representations' 4 | ''' 5 | 6 | 7 | import tensorflow as tf 8 | import numpy as np 9 | import tensorflow.contrib.slim as slim 10 | 11 | 12 | 13 | def adaptive_instance_norm(content, style, epsilon=1e-5): 14 | 15 | c_mean, c_var = tf.nn.moments(content, axes=[1, 2], keep_dims=True) 16 | s_mean, s_var = tf.nn.moments(style, axes=[1, 2], keep_dims=True) 17 | c_std, s_std = tf.sqrt(c_var + epsilon), tf.sqrt(s_var + epsilon) 18 | 19 | return s_std * (content - c_mean) / c_std + s_mean 20 | 21 | 22 | 23 | def spectral_norm(w, iteration=1): 24 | w_shape = w.shape.as_list() 25 | w = tf.reshape(w, [-1, w_shape[-1]]) 26 | 27 | u = tf.get_variable("u", [1, w_shape[-1]], 28 | initializer=tf.random_normal_initializer(), trainable=False) 29 | 30 | u_hat = u 31 | v_hat = None 32 | for i in range(iteration): 33 | """ 34 | power iteration 35 | Usually iteration = 1 will be enough 36 | """ 37 | v_ = tf.matmul(u_hat, tf.transpose(w)) 38 | v_hat = tf.nn.l2_normalize(v_) 39 | 40 | u_ = tf.matmul(v_hat, w) 41 | u_hat = tf.nn.l2_normalize(u_) 42 | 43 | u_hat = tf.stop_gradient(u_hat) 44 | v_hat = tf.stop_gradient(v_hat) 45 | 46 | sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat)) 47 | 48 | with tf.control_dependencies([u.assign(u_hat)]): 49 | w_norm = w / sigma 50 | w_norm = tf.reshape(w_norm, w_shape) 51 | 52 | return w_norm 53 | 54 | 55 | def conv_spectral_norm(x, channel, k_size, stride=1, name='conv_snorm'): 56 | with tf.variable_scope(name): 57 | w = tf.get_variable("kernel", shape=[k_size[0], k_size[1], x.get_shape()[-1], channel]) 58 | b = tf.get_variable("bias", [channel], initializer=tf.constant_initializer(0.0)) 59 | 60 | x = tf.nn.conv2d(input=x, filter=spectral_norm(w), strides=[1, stride, stride, 
1], padding='SAME') + b 61 | 62 | return x 63 | 64 | 65 | 66 | def self_attention(inputs, name='attention', reuse=False): 67 | with tf.variable_scope(name, reuse=reuse): 68 | h, w = tf.shape(inputs)[1], tf.shape(inputs)[2] 69 | bs, _, _, ch = inputs.get_shape().as_list() 70 | f = slim.convolution2d(inputs, ch//8, [1, 1], activation_fn=None) 71 | g = slim.convolution2d(inputs, ch//8, [1, 1], activation_fn=None) 72 | s = slim.convolution2d(inputs, 1, [1, 1], activation_fn=None) 73 | f_flatten = tf.reshape(f, shape=[f.shape[0], -1, f.shape[-1]]) 74 | g_flatten = tf.reshape(g, shape=[g.shape[0], -1, g.shape[-1]]) 75 | beta = tf.matmul(f_flatten, g_flatten, transpose_b=True) 76 | beta = tf.nn.softmax(beta) 77 | 78 | s_flatten = tf.reshape(s, shape=[s.shape[0], -1, s.shape[-1]]) 79 | att_map = tf.matmul(beta, s_flatten) 80 | att_map = tf.reshape(att_map, shape=[bs, h, w, 1]) 81 | gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0)) 82 | output = att_map * gamma + inputs 83 | 84 | return att_map, output 85 | 86 | 87 | 88 | if __name__ == '__main__': 89 | pass 90 | 91 | 92 | 93 | -------------------------------------------------------------------------------- /train_code/loss.py: -------------------------------------------------------------------------------- 1 | ''' 2 | CVPR 2020 submission, Paper ID 6791 3 | Source code for 'Learning to Cartoonize Using White-Box Cartoon Representations' 4 | ''' 5 | 6 | 7 | import numpy as np 8 | import scipy.stats as st 9 | import tensorflow as tf 10 | 11 | 12 | 13 | VGG_MEAN = [103.939, 116.779, 123.68] 14 | 15 | 16 | class Vgg19: 17 | 18 | def __init__(self, vgg19_npy_path=None): 19 | 20 | self.data_dict = np.load(vgg19_npy_path, encoding='latin1', allow_pickle=True).item() 21 | print('Finished loading vgg19.npy') 22 | 23 | 24 | def build_conv4_4(self, rgb, include_fc=False): 25 | 26 | rgb_scaled = (rgb+1) * 127.5 27 | 28 | blue, green, red = tf.split(axis=3, num_or_size_splits=3, value=rgb_scaled) 29 | bgr = tf.concat(axis=3, values=[blue - VGG_MEAN[0], 30 | green - VGG_MEAN[1], red - VGG_MEAN[2]]) 31 | 32 | self.conv1_1 = self.conv_layer(bgr, "conv1_1") 33 | self.relu1_1 = tf.nn.relu(self.conv1_1) 34 | self.conv1_2 = self.conv_layer(self.relu1_1, "conv1_2") 35 | self.relu1_2 = tf.nn.relu(self.conv1_2) 36 | self.pool1 = self.max_pool(self.relu1_2, 'pool1') 37 | 38 | self.conv2_1 = self.conv_layer(self.pool1, "conv2_1") 39 | self.relu2_1 = tf.nn.relu(self.conv2_1) 40 | self.conv2_2 = self.conv_layer(self.relu2_1, "conv2_2") 41 | self.relu2_2 = tf.nn.relu(self.conv2_2) 42 | self.pool2 = self.max_pool(self.relu2_2, 'pool2') 43 | 44 | self.conv3_1 = self.conv_layer(self.pool2, "conv3_1") 45 | self.relu3_1 = tf.nn.relu(self.conv3_1) 46 | self.conv3_2 = self.conv_layer(self.relu3_1, "conv3_2") 47 | self.relu3_2 = tf.nn.relu(self.conv3_2) 48 | self.conv3_3 = self.conv_layer(self.relu3_2, "conv3_3") 49 | self.relu3_3 = tf.nn.relu(self.conv3_3) 50 | self.conv3_4 = self.conv_layer(self.relu3_3, "conv3_4") 51 | self.relu3_4 = tf.nn.relu(self.conv3_4) 52 | self.pool3 = self.max_pool(self.relu3_4, 'pool3') 53 | 54 | self.conv4_1 = self.conv_layer(self.pool3, "conv4_1") 55 | self.relu4_1 = tf.nn.relu(self.conv4_1) 56 | self.conv4_2 = self.conv_layer(self.relu4_1, "conv4_2") 57 | self.relu4_2 = tf.nn.relu(self.conv4_2) 58 | self.conv4_3 = self.conv_layer(self.relu4_2, "conv4_3") 59 | self.relu4_3 = tf.nn.relu(self.conv4_3) 60 | self.conv4_4 = self.conv_layer(self.relu4_3, "conv4_4") 61 | self.relu4_4 = tf.nn.relu(self.conv4_4) 62 | self.pool4 = 
self.max_pool(self.relu4_4, 'pool4') 63 | 64 | return self.conv4_4 65 | 66 | def max_pool(self, bottom, name): 67 | return tf.nn.max_pool(bottom, ksize=[1, 2, 2, 1], 68 | strides=[1, 2, 2, 1], padding='SAME', name=name) 69 | 70 | def conv_layer(self, bottom, name): 71 | with tf.variable_scope(name): 72 | filt = self.get_conv_filter(name) 73 | 74 | conv = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME') 75 | 76 | conv_biases = self.get_bias(name) 77 | bias = tf.nn.bias_add(conv, conv_biases) 78 | 79 | #relu = tf.nn.relu(bias) 80 | return bias 81 | 82 | 83 | 84 | def fc_layer(self, bottom, name): 85 | with tf.variable_scope(name): 86 | shape = bottom.get_shape().as_list() 87 | dim = 1 88 | for d in shape[1:]: 89 | dim *= d 90 | x = tf.reshape(bottom, [-1, dim]) 91 | 92 | weights = self.get_fc_weight(name) 93 | biases = self.get_bias(name) 94 | 95 | # Fully connected layer. Note that the '+' operation automatically 96 | # broadcasts the biases. 97 | fc = tf.nn.bias_add(tf.matmul(x, weights), biases) 98 | 99 | return fc 100 | 101 | def get_conv_filter(self, name): 102 | return tf.constant(self.data_dict[name][0], name="filter") 103 | 104 | def get_bias(self, name): 105 | return tf.constant(self.data_dict[name][1], name="biases") 106 | 107 | def get_fc_weight(self, name): 108 | return tf.constant(self.data_dict[name][0], name="weights") 109 | 110 | 111 | 112 | def vggloss_4_4(image_a, image_b): 113 | vgg_model = Vgg19('vgg19_no_fc.npy') 114 | vgg_a = vgg_model.build_conv4_4(image_a) 115 | vgg_b = vgg_model.build_conv4_4(image_b) 116 | VGG_loss = tf.losses.absolute_difference(vgg_a, vgg_b) 117 | #VGG_loss = tf.nn.l2_loss(vgg_a - vgg_b) 118 | h, w, c= vgg_a.get_shape().as_list()[1:] 119 | VGG_loss = tf.reduce_mean(VGG_loss)/(h*w*c) 120 | return VGG_loss 121 | 122 | 123 | 124 | def wgan_loss(discriminator, real, fake, patch=True, 125 | channel=32, name='discriminator', lambda_=2): 126 | real_logits = discriminator(real, patch=patch, channel=channel, name=name, reuse=False) 127 | fake_logits = discriminator(fake, patch=patch, channel=channel, name=name, reuse=True) 128 | 129 | d_loss_real = - tf.reduce_mean(real_logits) 130 | d_loss_fake = tf.reduce_mean(fake_logits) 131 | 132 | d_loss = d_loss_real + d_loss_fake 133 | g_loss = - d_loss_fake 134 | 135 | """ Gradient Penalty """ 136 | # This is borrowed from https://github.com/kodalinaveen3/DRAGAN/blob/master/DRAGAN.ipynb 137 | alpha = tf.random_uniform([tf.shape(real)[0], 1, 1, 1], minval=0.,maxval=1.) 138 | differences = fake - real # This is different from MAGAN 139 | interpolates = real + (alpha * differences) 140 | inter_logit = discriminator(interpolates, channel=channel, name=name, reuse=True) 141 | gradients = tf.gradients(inter_logit, [interpolates])[0] 142 | slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1])) 143 | gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2) 144 | d_loss += lambda_ * gradient_penalty 145 | 146 | return d_loss, g_loss 147 | 148 | 149 | def gan_loss(discriminator, real, fake, scale=1,channel=32, patch=False, name='discriminator'): 150 | 151 | real_logit = discriminator(real, scale, channel, name=name, patch=patch, reuse=False) 152 | fake_logit = discriminator(fake, scale, channel, name=name, patch=patch, reuse=True) 153 | 154 | real_logit = tf.nn.sigmoid(real_logit) 155 | fake_logit = tf.nn.sigmoid(fake_logit) 156 | 157 | g_loss_blur = -tf.reduce_mean(tf.log(fake_logit)) 158 | d_loss_blur = -tf.reduce_mean(tf.log(real_logit) + tf.log(1. 
- fake_logit)) 159 | 160 | return d_loss_blur, g_loss_blur 161 | 162 | 163 | 164 | def lsgan_loss(discriminator, real, fake, scale=1, 165 | channel=32, patch=False, name='discriminator'): 166 | 167 | real_logit = discriminator(real, scale, channel, name=name, patch=patch, reuse=False) 168 | fake_logit = discriminator(fake, scale, channel, name=name, patch=patch, reuse=True) 169 | 170 | g_loss = tf.reduce_mean((fake_logit - 1)**2) 171 | d_loss = 0.5*(tf.reduce_mean((real_logit - 1)**2) + tf.reduce_mean(fake_logit**2)) 172 | 173 | return d_loss, g_loss 174 | 175 | 176 | 177 | def total_variation_loss(image, k_size=1): 178 | h, w = image.get_shape().as_list()[1:3] 179 | tv_h = tf.reduce_mean((image[:, k_size:, :, :] - image[:, :h - k_size, :, :])**2) 180 | tv_w = tf.reduce_mean((image[:, :, k_size:, :] - image[:, :, :w - k_size, :])**2) 181 | tv_loss = (tv_h + tv_w)/(3*h*w) 182 | return tv_loss 183 | 184 | 185 | 186 | 187 | if __name__ == '__main__': 188 | pass 189 | 190 | 191 | -------------------------------------------------------------------------------- /train_code/network.py: -------------------------------------------------------------------------------- 1 | ''' 2 | CVPR 2020 submission, Paper ID 6791 3 | Source code for 'Learning to Cartoonize Using White-Box Cartoon Representations' 4 | ''' 5 | 6 | 7 | import layers 8 | import tensorflow as tf 9 | import numpy as np 10 | import tensorflow.contrib.slim as slim 11 | 12 | from tqdm import tqdm 13 | 14 | 15 | 16 | def resblock(inputs, out_channel=32, name='resblock'): 17 | 18 | with tf.variable_scope(name): 19 | 20 | x = slim.convolution2d(inputs, out_channel, [3, 3], 21 | activation_fn=None, scope='conv1') 22 | x = tf.nn.leaky_relu(x) 23 | x = slim.convolution2d(x, out_channel, [3, 3], 24 | activation_fn=None, scope='conv2') 25 | 26 | return x + inputs 27 | 28 | 29 | 30 | def generator(inputs, channel=32, num_blocks=4, name='generator', reuse=False): 31 | with tf.variable_scope(name, reuse=reuse): 32 | 33 | x = slim.convolution2d(inputs, channel, [7, 7], activation_fn=None) 34 | x = tf.nn.leaky_relu(x) 35 | 36 | x = slim.convolution2d(x, channel*2, [3, 3], stride=2, activation_fn=None) 37 | x = slim.convolution2d(x, channel*2, [3, 3], activation_fn=None) 38 | x = tf.nn.leaky_relu(x) 39 | 40 | x = slim.convolution2d(x, channel*4, [3, 3], stride=2, activation_fn=None) 41 | x = slim.convolution2d(x, channel*4, [3, 3], activation_fn=None) 42 | x = tf.nn.leaky_relu(x) 43 | 44 | for idx in range(num_blocks): 45 | x = resblock(x, out_channel=channel*4, name='block_{}'.format(idx)) 46 | 47 | x = slim.conv2d_transpose(x, channel*2, [3, 3], stride=2, activation_fn=None) 48 | x = slim.convolution2d(x, channel*2, [3, 3], activation_fn=None) 49 | 50 | x = tf.nn.leaky_relu(x) 51 | 52 | x = slim.conv2d_transpose(x, channel, [3, 3], stride=2, activation_fn=None) 53 | x = slim.convolution2d(x, channel, [3, 3], activation_fn=None) 54 | x = tf.nn.leaky_relu(x) 55 | 56 | x = slim.convolution2d(x, 3, [7, 7], activation_fn=None) 57 | #x = tf.clip_by_value(x, -0.999999, 0.999999) 58 | 59 | return x 60 | 61 | 62 | def unet_generator(inputs, channel=32, num_blocks=4, name='generator', reuse=False): 63 | with tf.variable_scope(name, reuse=reuse): 64 | 65 | x0 = slim.convolution2d(inputs, channel, [7, 7], activation_fn=None) 66 | x0 = tf.nn.leaky_relu(x0) 67 | 68 | x1 = slim.convolution2d(x0, channel, [3, 3], stride=2, activation_fn=None) 69 | x1 = tf.nn.leaky_relu(x1) 70 | x1 = slim.convolution2d(x1, channel*2, [3, 3], activation_fn=None) 71 | x1 = 
tf.nn.leaky_relu(x1) 72 | 73 | x2 = slim.convolution2d(x1, channel*2, [3, 3], stride=2, activation_fn=None) 74 | x2 = tf.nn.leaky_relu(x2) 75 | x2 = slim.convolution2d(x2, channel*4, [3, 3], activation_fn=None) 76 | x2 = tf.nn.leaky_relu(x2) 77 | 78 | for idx in range(num_blocks): 79 | x2 = resblock(x2, out_channel=channel*4, name='block_{}'.format(idx)) 80 | 81 | x2 = slim.convolution2d(x2, channel*2, [3, 3], activation_fn=None) 82 | x2 = tf.nn.leaky_relu(x2) 83 | 84 | h1, w1 = tf.shape(x2)[1], tf.shape(x2)[2] 85 | x3 = tf.image.resize_bilinear(x2, (h1*2, w1*2)) 86 | x3 = slim.convolution2d(x3+x1, channel*2, [3, 3], activation_fn=None) 87 | x3 = tf.nn.leaky_relu(x3) 88 | x3 = slim.convolution2d(x3, channel, [3, 3], activation_fn=None) 89 | x3 = tf.nn.leaky_relu(x3) 90 | 91 | h2, w2 = tf.shape(x3)[1], tf.shape(x3)[2] 92 | x4 = tf.image.resize_bilinear(x3, (h2*2, w2*2)) 93 | x4 = slim.convolution2d(x4+x0, channel, [3, 3], activation_fn=None) 94 | x4 = tf.nn.leaky_relu(x4) 95 | x4 = slim.convolution2d(x4, 3, [7, 7], activation_fn=None) 96 | #x4 = tf.clip_by_value(x4, -1, 1) 97 | return x4 98 | 99 | 100 | 101 | def disc_bn(x, scale=1, channel=32, is_training=True, 102 | name='discriminator', patch=True, reuse=False): 103 | 104 | with tf.variable_scope(name, reuse=reuse): 105 | 106 | for idx in range(3): 107 | x = slim.convolution2d(x, channel*2**idx, [3, 3], stride=2, activation_fn=None) 108 | x = slim.batch_norm(x, is_training=is_training, center=True, scale=True) 109 | x = tf.nn.leaky_relu(x) 110 | 111 | x = slim.convolution2d(x, channel*2**idx, [3, 3], activation_fn=None) 112 | x = slim.batch_norm(x, is_training=is_training, center=True, scale=True) 113 | x = tf.nn.leaky_relu(x) 114 | 115 | if patch == True: 116 | x = slim.convolution2d(x, 1, [1, 1], activation_fn=None) 117 | else: 118 | x = tf.reduce_mean(x, axis=[1, 2]) 119 | x = slim.fully_connected(x, 1, activation_fn=None) 120 | 121 | return x 122 | 123 | 124 | 125 | 126 | def disc_sn(x, scale=1, channel=32, patch=True, name='discriminator', reuse=False): 127 | with tf.variable_scope(name, reuse=reuse): 128 | 129 | for idx in range(3): 130 | x = layers.conv_spectral_norm(x, channel*2**idx, [3, 3], 131 | stride=2, name='conv{}_1'.format(idx)) 132 | x = tf.nn.leaky_relu(x) 133 | 134 | x = layers.conv_spectral_norm(x, channel*2**idx, [3, 3], 135 | name='conv{}_2'.format(idx)) 136 | x = tf.nn.leaky_relu(x) 137 | 138 | 139 | if patch == True: 140 | x = layers.conv_spectral_norm(x, 1, [1, 1], name='conv_out'.format(idx)) 141 | 142 | else: 143 | x = tf.reduce_mean(x, axis=[1, 2]) 144 | x = slim.fully_connected(x, 1, activation_fn=None) 145 | 146 | return x 147 | 148 | 149 | def disc_ln(x, channel=32, is_training=True, name='discriminator', patch=True, reuse=False): 150 | with tf.variable_scope(name, reuse=reuse): 151 | 152 | for idx in range(3): 153 | x = slim.convolution2d(x, channel*2**idx, [3, 3], stride=2, activation_fn=None) 154 | x = tf.contrib.layers.layer_norm(x) 155 | x = tf.nn.leaky_relu(x) 156 | 157 | x = slim.convolution2d(x, channel*2**idx, [3, 3], activation_fn=None) 158 | x = tf.contrib.layers.layer_norm(x) 159 | x = tf.nn.leaky_relu(x) 160 | 161 | if patch == True: 162 | x = slim.convolution2d(x, 1, [1, 1], activation_fn=None) 163 | else: 164 | x = tf.reduce_mean(x, axis=[1, 2]) 165 | x = slim.fully_connected(x, 1, activation_fn=None) 166 | 167 | return x 168 | 169 | 170 | 171 | 172 | if __name__ == '__main__': 173 | pass 174 | 175 | -------------------------------------------------------------------------------- 
/train_code/pretrain.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Source code for CVPR 2020 paper 3 | 'Learning to Cartoonize Using White-Box Cartoon Representations' 4 | by Xinrui Wang and Jinze yu 5 | ''' 6 | 7 | 8 | 9 | import tensorflow as tf 10 | import tensorflow.contrib.slim as slim 11 | 12 | import utils 13 | import os 14 | import numpy as np 15 | import argparse 16 | import network 17 | from tqdm import tqdm 18 | 19 | 20 | os.environ["CUDA_VISIBLE_DEVICES"]="0" 21 | 22 | 23 | def arg_parser(): 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("--patch_size", default = 256, type = int) 26 | parser.add_argument("--batch_size", default = 16, type = int) 27 | parser.add_argument("--total_iter", default = 50000, type = int) 28 | parser.add_argument("--adv_train_lr", default = 2e-4, type = float) 29 | parser.add_argument("--gpu_fraction", default = 0.5, type = float) 30 | parser.add_argument("--save_dir", default = 'pretrain') 31 | 32 | args = parser.parse_args() 33 | 34 | return args 35 | 36 | 37 | 38 | def train(args): 39 | 40 | 41 | input_photo = tf.placeholder(tf.float32, [args.batch_size, 42 | args.patch_size, args.patch_size, 3]) 43 | 44 | output = network.unet_generator(input_photo) 45 | 46 | recon_loss = tf.reduce_mean(tf.losses.absolute_difference(input_photo, output)) 47 | 48 | all_vars = tf.trainable_variables() 49 | gene_vars = [var for var in all_vars if 'gene' in var.name] 50 | 51 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 52 | with tf.control_dependencies(update_ops): 53 | 54 | optim = tf.train.AdamOptimizer(args.adv_train_lr, beta1=0.5, beta2=0.99)\ 55 | .minimize(recon_loss, var_list=gene_vars) 56 | 57 | 58 | ''' 59 | config = tf.ConfigProto() 60 | config.gpu_options.allow_growth = True 61 | sess = tf.Session(config=config) 62 | ''' 63 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_fraction) 64 | sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) 65 | saver = tf.train.Saver(var_list=gene_vars, max_to_keep=20) 66 | 67 | with tf.device('/device:GPU:0'): 68 | 69 | sess.run(tf.global_variables_initializer()) 70 | face_photo_dir = 'dataset/photo_face' 71 | face_photo_list = utils.load_image_list(face_photo_dir) 72 | scenery_photo_dir = 'dataset/photo_scenery' 73 | scenery_photo_list = utils.load_image_list(scenery_photo_dir) 74 | 75 | 76 | for total_iter in tqdm(range(args.total_iter)): 77 | 78 | if np.mod(total_iter, 5) == 0: 79 | photo_batch = utils.next_batch(face_photo_list, args.batch_size) 80 | else: 81 | photo_batch = utils.next_batch(scenery_photo_list, args.batch_size) 82 | 83 | _, r_loss = sess.run([optim, recon_loss], feed_dict={input_photo: photo_batch}) 84 | 85 | if np.mod(total_iter+1, 50) == 0: 86 | 87 | print('pretrain, iter: {}, recon_loss: {}'.format(total_iter, r_loss)) 88 | if np.mod(total_iter+1, 500 ) == 0: 89 | saver.save(sess, args.save_dir+'save_models/model', 90 | write_meta_graph=False, global_step=total_iter) 91 | 92 | photo_face = utils.next_batch(face_photo_list, args.batch_size) 93 | photo_scenery = utils.next_batch(scenery_photo_list, args.batch_size) 94 | 95 | result_face = sess.run(output, feed_dict={input_photo: photo_face}) 96 | 97 | result_scenery = sess.run(output, feed_dict={input_photo: photo_scenery}) 98 | 99 | utils.write_batch_image(result_face, args.save_dir+'/images', 100 | str(total_iter)+'_face_result.jpg', 4) 101 | utils.write_batch_image(photo_face, args.save_dir+'/images', 102 | 
str(total_iter)+'_face_photo.jpg', 4) 103 | utils.write_batch_image(result_scenery, args.save_dir+'/images', 104 | str(total_iter)+'_scenery_result.jpg', 4) 105 | utils.write_batch_image(photo_scenery, args.save_dir+'/images', 106 | str(total_iter)+'_scenery_photo.jpg', 4) 107 | 108 | 109 | 110 | 111 | 112 | if __name__ == '__main__': 113 | 114 | args = arg_parser() 115 | train(args) 116 | -------------------------------------------------------------------------------- /train_code/selective_search/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import selective_search, box_filter 2 | -------------------------------------------------------------------------------- /train_code/selective_search/adaptive_color.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def label2rgb(label_field, image, kind='avg', bg_label=-1, bg_color=(0, 0, 0)): 5 | 6 | #std_list = list() 7 | out = np.zeros_like(image) 8 | labels = np.unique(label_field) 9 | bg = (labels == bg_label) 10 | if bg.any(): 11 | labels = labels[labels != bg_label] 12 | mask = (label_field == bg_label).nonzero() 13 | out[mask] = bg_color 14 | for label in labels: 15 | mask = (label_field == label).nonzero() 16 | #std = np.std(image[mask]) 17 | #std_list.append(std) 18 | if kind == 'avg': 19 | color = image[mask].mean(axis=0) 20 | elif kind == 'median': 21 | color = np.median(image[mask], axis=0) 22 | elif kind == 'mix': 23 | std = np.std(image[mask]) 24 | if std < 20: 25 | color = image[mask].mean(axis=0) 26 | elif 20 < std < 40: 27 | mean = image[mask].mean(axis=0) 28 | median = np.median(image[mask], axis=0) 29 | color = 0.5*mean + 0.5*median 30 | elif 40 < std: 31 | color = np.median(image[mask], axis=0) 32 | out[mask] = color 33 | return out -------------------------------------------------------------------------------- /train_code/selective_search/batch_ss.py: -------------------------------------------------------------------------------- 1 | ''' 2 | CVPR 2020 submission, Paper ID 6791 3 | Source code for 'Learning to Cartoonize Using White-Box Cartoon Representations' 4 | ''' 5 | 6 | 7 | import numpy as np 8 | from adaptive_color import label2rgb 9 | from joblib import Parallel, delayed 10 | from skimage.segmentation import felzenszwalb 11 | from util import switch_color_space 12 | from structure import HierarchicalGrouping 13 | 14 | 15 | def color_ss_map(image, color_space='Lab', k=10, 16 | sim_strategy='CTSF', seg_num=200, power=1): 17 | 18 | img_seg = felzenszwalb(image, scale=k, sigma=0.8, min_size=100) 19 | img_cvtcolor = label2rgb(img_seg, image, kind='mix') 20 | img_cvtcolor = switch_color_space(img_cvtcolor, color_space) 21 | S = HierarchicalGrouping(img_cvtcolor, img_seg, sim_strategy) 22 | S.build_regions() 23 | S.build_region_pairs() 24 | 25 | # Start hierarchical grouping 26 | 27 | while S.num_regions() > seg_num: 28 | 29 | i,j = S.get_highest_similarity() 30 | S.merge_region(i,j) 31 | S.remove_similarities(i,j) 32 | S.calculate_similarity_for_new_region() 33 | 34 | image = label2rgb(S.img_seg, image, kind='mix') 35 | image = (image+1)/2 36 | image = image**power 37 | image = image/np.max(image) 38 | image = image*2 - 1 39 | 40 | return image 41 | 42 | 43 | def selective_adacolor(batch_image, seg_num=200, power=1): 44 | num_job = np.shape(batch_image)[0] 45 | batch_out = Parallel(n_jobs=num_job)(delayed(color_ss_map)\ 46 | (image, seg_num, power) for image in batch_image) 47 | return 
np.array(batch_out) 48 | 49 | 50 | if __name__ == '__main__': 51 | pass -------------------------------------------------------------------------------- /train_code/selective_search/core.py: -------------------------------------------------------------------------------- 1 | from joblib import Parallel, delayed 2 | from skimage.segmentation import felzenszwalb 3 | from util import oversegmentation, switch_color_space, load_strategy 4 | from structure import HierarchicalGrouping 5 | 6 | 7 | 8 | 9 | def selective_search_one(img, color_space, k, sim_strategy): 10 | ''' 11 | Selective Search using single diversification strategy 12 | Parameters 13 | ---------- 14 | im_orig : ndarray 15 | Original image 16 | color_space : string 17 | Colour Spaces 18 | k : int 19 | Threshold parameter for starting regions 20 | sim_stategy : string 21 | Combinations of similarity measures 22 | 23 | Returns 24 | ------- 25 | boxes : list 26 | Bounding boxes of the regions 27 | priority: list 28 | Small priority number indicates higher position in the hierarchy 29 | ''' 30 | 31 | # convert RGB image to target color space 32 | img = switch_color_space(img, color_space) 33 | 34 | # Generate starting locations 35 | img_seg = oversegmentation(img, k) 36 | 37 | # Initialze hierarchical grouping 38 | S = HierarchicalGrouping(img, img_seg, sim_strategy) 39 | 40 | S.build_regions() 41 | S.build_region_pairs() 42 | 43 | # Start hierarchical grouping 44 | while not S.is_empty(): 45 | i,j = S.get_highest_similarity() 46 | 47 | S.merge_region(i,j) 48 | 49 | S.remove_similarities(i,j) 50 | 51 | S.calculate_similarity_for_new_region() 52 | 53 | # convert the order by hierarchical priority 54 | boxes = [x['box'] for x in S.regions.values()][::-1] 55 | 56 | # drop duplicates by maintaining order 57 | boxes = list(dict.fromkeys(boxes)) 58 | 59 | # generate priority for boxes 60 | priorities = list(range(1, len(boxes)+1)) 61 | 62 | return boxes, priorities 63 | 64 | 65 | def selective_search(img, mode='single', random=False): 66 | """ 67 | Selective Search in Python 68 | """ 69 | 70 | # load selective search strategy 71 | strategy = load_strategy(mode) 72 | 73 | # Excecute selective search in parallel 74 | vault = Parallel(n_jobs=1)(delayed(selective_search_one)(img, color, k, sim) for (color, k, sim) in strategy) 75 | 76 | boxes = [x for x,_ in vault] 77 | priorities = [y for _, y in vault] 78 | 79 | boxes = [item for sublist in boxes for item in sublist] 80 | priorities = [item for sublist in priorities for item in sublist] 81 | 82 | if random: 83 | # Do pseudo random sorting as in paper 84 | rand_list = [random() for i in range(len(priorities))] 85 | priorities = [p * r for p, r in zip(priorities, rand_list)] 86 | boxes = [b for _, b in sorted(zip(priorities, boxes))] 87 | 88 | # drop duplicates by maintaining order 89 | boxes = list(dict.fromkeys(boxes)) 90 | 91 | return boxes 92 | 93 | def box_filter(boxes, min_size=20, max_ratio=None, topN=None): 94 | proposal = [] 95 | 96 | for box in boxes: 97 | # Calculate width and height of the box 98 | w, h = box[2] - box[0], box[3] - box[1] 99 | 100 | # Filter for size 101 | if w < min_size or h < min_size: 102 | continue 103 | 104 | # Filter for box ratio 105 | if max_ratio: 106 | if w / h > max_ratio or h / w > max_ratio: 107 | continue 108 | 109 | proposal.append(box) 110 | 111 | if topN: 112 | if topN <= len(proposal): 113 | return proposal[:topN] 114 | else: 115 | return proposal 116 | else: 117 | return proposal 118 | 119 | 120 | 121 | 122 | 
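The box_filter helper defined at the end of core.py above prunes region proposals by size and aspect ratio before they are used downstream. The following toy sketch shows the intended behavior; the boxes are invented (x1, y1, x2, y2) tuples rather than real selective-search output, and the import assumes core.py's own dependencies (joblib, scikit-image, and the sibling util.py / structure.py modules) resolve, e.g. when run from inside train_code/selective_search.

```python
# Toy usage of box_filter from core.py above. The boxes are made-up
# (x1, y1, x2, y2) tuples, not real selective-search output.
from core import box_filter  # assumes core.py's own imports resolve in this directory

boxes = [
    (0, 0, 10, 10),    # 10x10  -> dropped: smaller than min_size
    (0, 0, 120, 30),   # 120x30 -> dropped: aspect ratio 4 exceeds max_ratio
    (5, 5, 105, 85),   # 100x80 -> kept
]
kept = box_filter(boxes, min_size=20, max_ratio=3, topN=10)
print(kept)            # expected: [(5, 5, 105, 85)]
```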
-------------------------------------------------------------------------------- /train_code/selective_search/measure.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from skimage.feature import local_binary_pattern 3 | 4 | def _calculate_color_sim(ri, rj): 5 | """ 6 | Calculate color similarity using histogram intersection 7 | """ 8 | return sum([min(a, b) for a, b in zip(ri["color_hist"], rj["color_hist"])]) 9 | 10 | 11 | def _calculate_texture_sim(ri, rj): 12 | """ 13 | Calculate texture similarity using histogram intersection 14 | """ 15 | return sum([min(a, b) for a, b in zip(ri["texture_hist"], rj["texture_hist"])]) 16 | 17 | 18 | def _calculate_size_sim(ri, rj, imsize): 19 | """ 20 | Size similarity boosts joint between small regions, which prevents 21 | a single region from engulfing other blobs one by one. 22 | 23 | size (ri, rj) = 1 − [size(ri) + size(rj)] / size(image) 24 | """ 25 | return 1.0 - (ri['size'] + rj['size']) / imsize 26 | 27 | 28 | def _calculate_fill_sim(ri, rj, imsize): 29 | """ 30 | Fill similarity measures how well ri and rj fit into each other. 31 | BBij is the bounding box around ri and rj. 32 | 33 | fill(ri, rj) = 1 − [size(BBij) − size(ri) − size(ri)] / size(image) 34 | """ 35 | 36 | bbsize = (max(ri['box'][2], rj['box'][2]) - min(ri['box'][0], rj['box'][0])) * (max(ri['box'][3], rj['box'][3]) - min(ri['box'][1], rj['box'][1])) 37 | 38 | return 1.0 - (bbsize - ri['size'] - rj['size']) / imsize 39 | 40 | 41 | def calculate_color_hist(mask, img): 42 | """ 43 | Calculate colour histogram for the region. 44 | The output will be an array with n_BINS * n_color_channels. 45 | The number of channel is varied because of different 46 | colour spaces. 47 | """ 48 | 49 | BINS = 25 50 | if len(img.shape) == 2: 51 | img = img.reshape(img.shape[0], img.shape[1], 1) 52 | 53 | channel_nums = img.shape[2] 54 | hist = np.array([]) 55 | 56 | for channel in range(channel_nums): 57 | layer = img[:, :, channel][mask] 58 | hist = np.concatenate([hist] + [np.histogram(layer, BINS)[0]]) 59 | 60 | # L1 normalize 61 | hist = hist / np.sum(hist) 62 | 63 | return hist 64 | 65 | 66 | def generate_lbp_image(img): 67 | 68 | if len(img.shape) == 2: 69 | img = img.reshape(img.shape[0], img.shape[1], 1) 70 | channel_nums = img.shape[2] 71 | 72 | lbp_img = np.zeros(img.shape) 73 | for channel in range(channel_nums): 74 | layer = img[:, :, channel] 75 | lbp_img[:, :,channel] = local_binary_pattern(layer, 8, 1) 76 | 77 | return lbp_img 78 | 79 | 80 | def calculate_texture_hist(mask, lbp_img): 81 | """ 82 | Use LBP for now, enlightened by AlpacaDB's implementation. 83 | Plan to switch to Gaussian derivatives as the paper in future 84 | version. 85 | """ 86 | 87 | BINS = 10 88 | channel_nums = lbp_img.shape[2] 89 | hist = np.array([]) 90 | 91 | for channel in range(channel_nums): 92 | layer = lbp_img[:, :, channel][mask] 93 | hist = np.concatenate([hist] + [np.histogram(layer, BINS)[0]]) 94 | 95 | # L1 normalize 96 | hist = hist / np.sum(hist) 97 | 98 | return hist 99 | 100 | 101 | def calculate_sim(ri, rj, imsize, sim_strategy): 102 | """ 103 | Calculate similarity between region ri and rj using diverse 104 | combinations of similarity measures. 105 | C: color, T: texture, S: size, F: fill. 
106 |     """
107 |     sim = 0
108 | 
109 |     if 'C' in sim_strategy:
110 |         sim += _calculate_color_sim(ri, rj)
111 |     if 'T' in sim_strategy:
112 |         sim += _calculate_texture_sim(ri, rj)
113 |     if 'S' in sim_strategy:
114 |         sim += _calculate_size_sim(ri, rj, imsize)
115 |     if 'F' in sim_strategy:
116 |         sim += _calculate_fill_sim(ri, rj, imsize)
117 | 
118 |     return sim
119 | 
--------------------------------------------------------------------------------
/train_code/selective_search/structure.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from skimage.segmentation import find_boundaries
3 | from skimage.segmentation import felzenszwalb
4 | from scipy.ndimage import find_objects
5 | import measure
6 | 
7 | 
8 | class HierarchicalGrouping(object):
9 |     def __init__(self, img, img_seg, sim_strategy):
10 |         self.img = img
11 |         self.sim_strategy = sim_strategy
12 |         self.img_seg = img_seg.copy()
13 |         self.labels = np.unique(self.img_seg).tolist()
14 | 
15 |     def build_regions(self):
16 |         self.regions = {}
17 |         lbp_img = measure.generate_lbp_image(self.img)
18 |         for label in self.labels:
19 |             size = (self.img_seg == label).sum()
20 |             region_slice = find_objects(self.img_seg==label)[0]
21 |             box = tuple([region_slice[i].start for i in (1,0)] +
22 |                         [region_slice[i].stop for i in (1,0)])
23 | 
24 |             mask = self.img_seg == label
25 |             color_hist = measure.calculate_color_hist(mask, self.img)
26 |             texture_hist = measure.calculate_texture_hist(mask, lbp_img)
27 | 
28 |             self.regions[label] = {
29 |                 'size': size,
30 |                 'box': box,
31 |                 'color_hist': color_hist,
32 |                 'texture_hist': texture_hist
33 |             }
34 | 
35 | 
36 |     def build_region_pairs(self):
37 |         self.s = {}
38 |         for i in self.labels:
39 |             neighbors = self._find_neighbors(i)
40 |             for j in neighbors:
41 |                 if i < j:
42 |                     self.s[(i,j)] = measure.calculate_sim(self.regions[i],
43 |                                                           self.regions[j],
44 |                                                           self.img.size,
45 |                                                           self.sim_strategy)
46 | 
47 | 
48 |     def _find_neighbors(self, label):
49 |         """
50 |         Parameters
51 |         ----------
52 |         label : int
53 |             label of the region
54 |         Returns
55 |         -------
56 |         neighbors : list
57 |             list of labels of neighbors
58 |         """
59 | 
60 |         boundary = find_boundaries(self.img_seg == label,
61 |                                    mode='outer')
62 |         neighbors = np.unique(self.img_seg[boundary]).tolist()
63 | 
64 |         return neighbors
65 | 
66 |     def get_highest_similarity(self):
67 |         return sorted(self.s.items(), key=lambda i: i[1])[-1][0]
68 | 
69 |     def merge_region(self, i, j):
70 | 
71 |         # generate a unique label and put in the label list
72 |         new_label = max(self.labels) + 1
73 |         self.labels.append(new_label)
74 | 
75 |         # merge blobs and update blob set
76 |         ri, rj = self.regions[i], self.regions[j]
77 | 
78 |         new_size = ri['size'] + rj['size']
79 |         new_box = (min(ri['box'][0], rj['box'][0]),
80 |                    min(ri['box'][1], rj['box'][1]),
81 |                    max(ri['box'][2], rj['box'][2]),
82 |                    max(ri['box'][3], rj['box'][3]))
83 |         value = {
84 |             'box': new_box,
85 |             'size': new_size,
86 |             'color_hist':
87 |                 (ri['color_hist'] * ri['size']
88 |                 + rj['color_hist'] * rj['size']) / new_size,
89 |             'texture_hist':
90 |                 (ri['texture_hist'] * ri['size']
91 |                 + rj['texture_hist'] * rj['size']) / new_size,
92 |         }
93 | 
94 |         self.regions[new_label] = value
95 | 
96 |         # update segmentation mask
97 |         self.img_seg[self.img_seg == i] = new_label
98 |         self.img_seg[self.img_seg == j] = new_label
99 | 
100 |     def remove_similarities(self, i, j):
101 | 
102 |         # mark keys for region pairs to be removed
103 |         key_to_delete = []
104 |         for key in self.s.keys():
105 |             if (i in key) or (j in key):
106 |                 key_to_delete.append(key)
107 | 
108 |         for key in key_to_delete:
109 |             del self.s[key]
110 | 
111 |         # remove old labels in label list
112 |         self.labels.remove(i)
113 |         self.labels.remove(j)
114 | 
115 |     def calculate_similarity_for_new_region(self):
116 |         i = max(self.labels)
117 |         neighbors = self._find_neighbors(i)
118 | 
119 |         for j in neighbors:
120 |             # i is larger than j, so use (j,i) instead
121 |             self.s[(j,i)] = measure.calculate_sim(self.regions[i],
122 |                                                   self.regions[j],
123 |                                                   self.img.size,
124 |                                                   self.sim_strategy)
125 | 
126 |     def is_empty(self):
127 |         return not self.s
128 | 
129 | 
130 |     def num_regions(self):
131 |         return len(self.s.keys())
132 | 
--------------------------------------------------------------------------------
/train_code/selective_search/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from itertools import product
3 | 
4 | from skimage.segmentation import felzenszwalb
5 | from skimage.color import rgb2hsv, rgb2lab, rgb2grey
6 | 
7 | 
8 | def oversegmentation(img, k):
9 |     """
10 |     Generate various starting regions using the method of
11 |     Felzenszwalb.
12 |     k effectively sets a scale of observation, in that
13 |     a larger k causes a preference for larger components.
14 |     sigma = 0.8, as used in the original paper.
15 |     min_size = 100 follows Keon's Matlab implementation.
16 |     """
17 |     img_seg = felzenszwalb(img, scale=k, sigma=0.8, min_size=100)
18 | 
19 |     return img_seg
20 | 
21 | 
22 | def switch_color_space(img, target):
23 |     """
24 |     RGB to target color space conversion.
25 |     I: the intensity (grey scale), Lab, rgI: the rg channels of
26 |     normalized RGB plus intensity, HSV, H: the Hue channel H from HSV
27 |     """
28 | 
29 |     if target == 'HSV':
30 |         return rgb2hsv(img)
31 | 
32 |     elif target == 'Lab':
33 |         return rgb2lab(img)
34 | 
35 |     elif target == 'I':
36 |         return rgb2grey(img)
37 | 
38 |     elif target == 'rgb':
39 |         img = img / np.sum(img, axis=0)
40 |         return img
41 | 
42 |     elif target == 'rgI':
43 |         img = img / np.sum(img, axis=0)
44 |         img[:,:,2] = rgb2grey(img)
45 |         return img
46 | 
47 |     elif target == 'H':
48 |         return rgb2hsv(img)[:,:,0]
49 | 
50 |     else:
51 |         raise ValueError("{} is not supported.".format(target))
52 | 
53 | def load_strategy(mode):
54 |     # TODO: Add mode sanity check
55 | 
56 |     cfg = {
57 |         "single": {
58 |             "ks": [100],
59 |             "colors": ["HSV"],
60 |             "sims": ["CTSF"]
61 |         },
62 |         "lab": {
63 |             "ks": [100],
64 |             "colors": ["Lab"],
65 |             "sims": ["CTSF"]
66 |         },
67 |         "fast": {
68 |             "ks": [50, 100],
69 |             "colors": ["HSV", "Lab"],
70 |             "sims": ["CTSF", "TSF"]
71 |         },
72 |         "quality": {
73 |             "ks": [50, 100, 150, 300],
74 |             "colors": ["HSV", "Lab", "I", "rgI", "H"],
75 |             "sims": ["CTSF", "TSF", "F", "S"]
76 |         }
77 |     }
78 | 
79 |     if isinstance(mode, dict):
80 |         cfg['manual'] = mode
81 |         mode = 'manual'
82 | 
83 |     colors, ks, sims = cfg[mode]['colors'], cfg[mode]['ks'], cfg[mode]['sims']
84 | 
85 |     return product(colors, ks, sims)
86 | 
--------------------------------------------------------------------------------
/train_code/train.py:
--------------------------------------------------------------------------------
1 | '''
2 | Source code for CVPR 2020 paper
3 | 'Learning to Cartoonize Using White-Box Cartoon Representations'
4 | by Xinrui Wang and Jinze Yu
5 | '''
6 | 
7 | 
8 | import tensorflow as tf
9 | import tensorflow.contrib.slim as slim
10 | 
11 | import utils
12 | import os
13 | import numpy as np
14 | import argparse
15 | import network
16 | import loss
17 | 
18 | from tqdm import tqdm
19 | from guided_filter import guided_filter
20 | 
21 | os.environ["CUDA_VISIBLE_DEVICES"]="0"
22 | 
23 | 
24 | def arg_parser():
25 |     parser = argparse.ArgumentParser()
26 |     parser.add_argument("--patch_size", default = 256, type = int)
27 |     parser.add_argument("--batch_size", default = 16, type = int)
28 |     parser.add_argument("--total_iter", default = 100000, type = int)
29 |     parser.add_argument("--adv_train_lr", default = 2e-4, type = float)
30 |     parser.add_argument("--gpu_fraction", default = 0.5, type = float)
31 |     parser.add_argument("--save_dir", default = 'train_cartoon', type = str)
32 |     parser.add_argument("--use_enhance", default = False)
33 | 
34 |     args = parser.parse_args()
35 | 
36 |     return args
37 | 
38 | 
39 | 
40 | def train(args):
41 | 
42 | 
43 |     input_photo = tf.placeholder(tf.float32, [args.batch_size,
44 |                                 args.patch_size, args.patch_size, 3])
45 |     input_superpixel = tf.placeholder(tf.float32, [args.batch_size,
46 |                                 args.patch_size, args.patch_size, 3])
47 |     input_cartoon = tf.placeholder(tf.float32, [args.batch_size,
48 |                                 args.patch_size, args.patch_size, 3])
49 | 
50 |     output = network.unet_generator(input_photo)
51 |     output = guided_filter(input_photo, output, r=1)
52 | 
53 | 
54 |     blur_fake = guided_filter(output, output, r=5, eps=2e-1)
55 |     blur_cartoon = guided_filter(input_cartoon, input_cartoon, r=5, eps=2e-1)
56 | 
57 |     gray_fake, gray_cartoon = utils.color_shift(output, input_cartoon)
58 | 
59 |     d_loss_gray, g_loss_gray = loss.lsgan_loss(network.disc_sn, gray_cartoon, gray_fake,
60 |                                                scale=1, patch=True, name='disc_gray')
61 |     d_loss_blur, g_loss_blur = loss.lsgan_loss(network.disc_sn, blur_cartoon, blur_fake,
62 |                                                scale=1, patch=True, name='disc_blur')
63 | 
64 | 
65 |     vgg_model = loss.Vgg19('vgg19_no_fc.npy')
66 |     vgg_photo = vgg_model.build_conv4_4(input_photo)
67 |     vgg_output = vgg_model.build_conv4_4(output)
68 |     vgg_superpixel = vgg_model.build_conv4_4(input_superpixel)
69 |     h, w, c = vgg_photo.get_shape().as_list()[1:]
70 | 
71 |     photo_loss = tf.reduce_mean(tf.losses.absolute_difference(vgg_photo, vgg_output))/(h*w*c)
72 |     superpixel_loss = tf.reduce_mean(tf.losses.absolute_difference\
73 |                                      (vgg_superpixel, vgg_output))/(h*w*c)
74 |     recon_loss = photo_loss + superpixel_loss
75 |     tv_loss = loss.total_variation_loss(output)
76 | 
77 |     g_loss_total = 1e4*tv_loss + 1e-1*g_loss_blur + g_loss_gray + 2e2*recon_loss
78 |     d_loss_total = d_loss_blur + d_loss_gray
79 | 
80 |     all_vars = tf.trainable_variables()
81 |     gene_vars = [var for var in all_vars if 'gene' in var.name]
82 |     disc_vars = [var for var in all_vars if 'disc' in var.name]
83 | 
84 | 
85 |     tf.summary.scalar('tv_loss', tv_loss)
86 |     tf.summary.scalar('photo_loss', photo_loss)
87 |     tf.summary.scalar('superpixel_loss', superpixel_loss)
88 |     tf.summary.scalar('recon_loss', recon_loss)
89 |     tf.summary.scalar('d_loss_gray', d_loss_gray)
90 |     tf.summary.scalar('g_loss_gray', g_loss_gray)
91 |     tf.summary.scalar('d_loss_blur', d_loss_blur)
92 |     tf.summary.scalar('g_loss_blur', g_loss_blur)
93 |     tf.summary.scalar('d_loss_total', d_loss_total)
94 |     tf.summary.scalar('g_loss_total', g_loss_total)
95 | 
96 |     update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
97 |     with tf.control_dependencies(update_ops):
98 | 
99 |         g_optim = tf.train.AdamOptimizer(args.adv_train_lr, beta1=0.5, beta2=0.99)\
100 |                           .minimize(g_loss_total, var_list=gene_vars)
101 | 
102 |         d_optim = tf.train.AdamOptimizer(args.adv_train_lr, beta1=0.5, beta2=0.99)\
103 |                           .minimize(d_loss_total, var_list=disc_vars)
104 | 
105 |     '''
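    (Alternative session setup kept for reference: allow_growth grows GPU memory on demand instead of reserving the fixed per-process fraction configured just below.)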
106 |     config = tf.ConfigProto()
107 |     config.gpu_options.allow_growth = True
108 |     sess = tf.Session(config=config)
109 |     '''
110 |     gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_fraction)
111 |     sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
112 | 
113 | 
114 |     train_writer = tf.summary.FileWriter(args.save_dir+'/train_log')
115 |     summary_op = tf.summary.merge_all()
116 |     saver = tf.train.Saver(var_list=gene_vars, max_to_keep=20)
117 | 
118 |     with tf.device('/device:GPU:0'):
119 | 
120 |         sess.run(tf.global_variables_initializer())
121 |         saver.restore(sess, tf.train.latest_checkpoint('pretrain/saved_models'))
122 | 
123 |         face_photo_dir = 'dataset/photo_face'
124 |         face_photo_list = utils.load_image_list(face_photo_dir)
125 |         scenery_photo_dir = 'dataset/photo_scenery'
126 |         scenery_photo_list = utils.load_image_list(scenery_photo_dir)
127 | 
128 |         face_cartoon_dir = 'dataset/cartoon_face'
129 |         face_cartoon_list = utils.load_image_list(face_cartoon_dir)
130 |         scenery_cartoon_dir = 'dataset/cartoon_scenery'
131 |         scenery_cartoon_list = utils.load_image_list(scenery_cartoon_dir)
132 | 
133 |         for total_iter in tqdm(range(args.total_iter)):
134 | 
135 |             if np.mod(total_iter, 5) == 0:
136 |                 photo_batch = utils.next_batch(face_photo_list, args.batch_size)
137 |                 cartoon_batch = utils.next_batch(face_cartoon_list, args.batch_size)
138 |             else:
139 |                 photo_batch = utils.next_batch(scenery_photo_list, args.batch_size)
140 |                 cartoon_batch = utils.next_batch(scenery_cartoon_list, args.batch_size)
141 | 
142 |             inter_out = sess.run(output, feed_dict={input_photo: photo_batch,
143 |                                                     input_superpixel: photo_batch,
144 |                                                     input_cartoon: cartoon_batch})
145 | 
146 |             '''
147 |             adaptive coloring has to be applied with the clip_by_value
148 |             in the last layer of the generator network, which is not very stable.
149 |             To stably reproduce our results, please use power=1.0
150 |             and comment out the clip_by_value function in network.py first.
151 |             If this works, then try adaptive coloring with clip_by_value.
152 |             '''
153 |             if args.use_enhance:
154 |                 superpixel_batch = utils.selective_adacolor(inter_out, power=1.2)
155 |             else:
156 |                 superpixel_batch = utils.simple_superpixel(inter_out, seg_num=200)
157 | 
158 |             _, g_loss, r_loss = sess.run([g_optim, g_loss_total, recon_loss],
159 |                                          feed_dict={input_photo: photo_batch,
160 |                                                     input_superpixel: superpixel_batch,
161 |                                                     input_cartoon: cartoon_batch})
162 | 
163 |             _, d_loss, train_info = sess.run([d_optim, d_loss_total, summary_op],
164 |                                              feed_dict={input_photo: photo_batch,
165 |                                                         input_superpixel: superpixel_batch,
166 |                                                         input_cartoon: cartoon_batch})
167 | 
168 | 
169 |             train_writer.add_summary(train_info, total_iter)
170 | 
171 |             if np.mod(total_iter+1, 50) == 0:
172 | 
173 |                 print('Iter: {}, d_loss: {}, g_loss: {}, recon_loss: {}'.\
174 |                       format(total_iter, d_loss, g_loss, r_loss))
175 |             if np.mod(total_iter+1, 500) == 0:
176 |                 saver.save(sess, args.save_dir+'/saved_models/model',
177 |                            write_meta_graph=False, global_step=total_iter)
178 | 
179 |                 photo_face = utils.next_batch(face_photo_list, args.batch_size)
180 |                 cartoon_face = utils.next_batch(face_cartoon_list, args.batch_size)
181 |                 photo_scenery = utils.next_batch(scenery_photo_list, args.batch_size)
182 |                 cartoon_scenery = utils.next_batch(scenery_cartoon_list, args.batch_size)
183 | 
184 |                 result_face = sess.run(output, feed_dict={input_photo: photo_face,
185 |                                                           input_superpixel: photo_face,
186 |                                                           input_cartoon: cartoon_face})
187 | 
188 |                 result_scenery = sess.run(output, feed_dict={input_photo: photo_scenery,
189 |                                                              input_superpixel: photo_scenery,
190 |                                                              input_cartoon: cartoon_scenery})
191 | 
192 |                 utils.write_batch_image(result_face, args.save_dir+'/images',
193 |                                         str(total_iter)+'_face_result.jpg', 4)
194 |                 utils.write_batch_image(photo_face, args.save_dir+'/images',
195 |                                         str(total_iter)+'_face_photo.jpg', 4)
196 | 
197 |                 utils.write_batch_image(result_scenery, args.save_dir+'/images',
198 |                                         str(total_iter)+'_scenery_result.jpg', 4)
199 |                 utils.write_batch_image(photo_scenery, args.save_dir+'/images',
200 |                                         str(total_iter)+'_scenery_photo.jpg', 4)
201 | 
202 | 
203 | if __name__ == '__main__':
204 | 
205 |     args = arg_parser()
206 |     train(args)
207 | 
--------------------------------------------------------------------------------
/train_code/utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | Source code for CVPR 2020 paper
3 | 'Learning to Cartoonize Using White-Box Cartoon Representations'
4 | by Xinrui Wang and Jinze Yu
5 | '''
6 | 
7 | 
8 | from scipy.ndimage import filters
9 | from skimage import segmentation, color
10 | from joblib import Parallel, delayed
11 | from selective_search.util import switch_color_space
12 | from selective_search.structure import HierarchicalGrouping
13 | 
14 | import os
15 | import cv2
16 | import numpy as np
17 | import scipy.stats as st
18 | import tensorflow as tf
19 | 
20 | 
21 | 
22 | def color_shift(image1, image2, mode='uniform'):
23 |     b1, g1, r1 = tf.split(image1, num_or_size_splits=3, axis=3)
24 |     b2, g2, r2 = tf.split(image2, num_or_size_splits=3, axis=3)
25 |     if mode == 'normal':
26 |         b_weight = tf.random.normal(shape=[1], mean=0.114, stddev=0.1)
27 |         g_weight = tf.random.normal(shape=[1], mean=0.587, stddev=0.1)
28 |         r_weight = tf.random.normal(shape=[1], mean=0.299, stddev=0.1)
29 |     elif mode == 'uniform':
30 |         b_weight = tf.random.uniform(shape=[1], minval=0.014, maxval=0.214)
31 |         g_weight = tf.random.uniform(shape=[1], minval=0.487, maxval=0.687)
32 |         r_weight = tf.random.uniform(shape=[1], minval=0.199, maxval=0.399)
33 |     output1 = (b_weight*b1+g_weight*g1+r_weight*r1)/(b_weight+g_weight+r_weight)
34 |     output2 = (b_weight*b2+g_weight*g2+r_weight*r2)/(b_weight+g_weight+r_weight)
35 |     return output1, output2
36 | 
37 | 
38 | 
39 | 
40 | def label2rgb(label_field, image, kind='mix', bg_label=-1, bg_color=(0, 0, 0)):
41 | 
42 |     #std_list = list()
43 |     out = np.zeros_like(image)
44 |     labels = np.unique(label_field)
45 |     bg = (labels == bg_label)
46 |     if bg.any():
47 |         labels = labels[labels != bg_label]
48 |         mask = (label_field == bg_label).nonzero()
49 |         out[mask] = bg_color
50 |     for label in labels:
51 |         mask = (label_field == label).nonzero()
52 |         #std = np.std(image[mask])
53 |         #std_list.append(std)
54 |         if kind == 'avg':
55 |             color = image[mask].mean(axis=0)
56 |         elif kind == 'median':
57 |             color = np.median(image[mask], axis=0)
58 |         elif kind == 'mix':
59 |             std = np.std(image[mask])
60 |             if std < 20:
61 |                 color = image[mask].mean(axis=0)
62 |             elif std < 40:
63 |                 mean = image[mask].mean(axis=0)
64 |                 median = np.median(image[mask], axis=0)
65 |                 color = 0.5*mean + 0.5*median
66 |             else:
67 |                 color = np.median(image[mask], axis=0)
68 |         out[mask] = color
69 |     return out
70 | 
71 | 
72 | 
73 | def color_ss_map(image, seg_num=200, power=1,
74 |                  color_space='Lab', k=10, sim_strategy='CTSF'):
75 | 
76 |     img_seg = segmentation.felzenszwalb(image, scale=k, sigma=0.8, min_size=100)
77 |     img_cvtcolor = label2rgb(img_seg, image, kind='mix')
78 |     img_cvtcolor = switch_color_space(img_cvtcolor, color_space)
79 |     S = HierarchicalGrouping(img_cvtcolor, img_seg, sim_strategy)
80 |     S.build_regions()
81 |     S.build_region_pairs()
82 | 
83 |     # Start hierarchical grouping
84 | 
85 |     while S.num_regions() > seg_num:
86 | 
87 |         i,j = S.get_highest_similarity()
88 |         S.merge_region(i,j)
89 |         S.remove_similarities(i,j)
90 |         S.calculate_similarity_for_new_region()
91 | 
92 |     image = label2rgb(S.img_seg, image, kind='mix')
93 |     image = (image+1)/2
94 |     image = image**power
95 |     image = image/np.max(image)
96 |     image = image*2 - 1
97 | 
98 |     return image
99 | 
100 | 
101 | def selective_adacolor(batch_image, seg_num=200, power=1):
102 |     num_job = np.shape(batch_image)[0]
103 |     batch_out = Parallel(n_jobs=num_job)(delayed(color_ss_map)\
104 |                          (image, seg_num, power) for image in batch_image)
105 |     return np.array(batch_out)
106 | 
107 | 
108 | 
109 | def simple_superpixel(batch_image, seg_num=200):
110 | 
111 |     def process_slic(image):
112 |         seg_label = segmentation.slic(image, n_segments=seg_num, sigma=1,
113 |                                       compactness=10, convert2lab=True)
114 |         image = label2rgb(seg_label, image, kind='mix')
115 |         return image
116 | 
117 |     num_job = np.shape(batch_image)[0]
118 |     batch_out = Parallel(n_jobs=num_job)(delayed(process_slic)\
119 |                          (image) for image in batch_image)
120 |     return np.array(batch_out)
121 | 
122 | 
123 | 
124 | def load_image_list(data_dir):
125 |     name_list = list()
126 |     for name in os.listdir(data_dir):
127 |         name_list.append(os.path.join(data_dir, name))
128 |     name_list.sort()
129 |     return name_list
130 | 
131 | 
132 | def next_batch(filename_list, batch_size):
133 |     idx = np.arange(0, len(filename_list))
134 |     np.random.shuffle(idx)
135 |     idx = idx[:batch_size]
136 |     batch_data = []
137 |     for i in range(batch_size):
138 |         image = cv2.imread(filename_list[idx[i]])
139 |         image = image.astype(np.float32)/127.5 - 1
140 |         #image = image.astype(np.float32)/255.0
141 |         batch_data.append(image)
142 | 
143 |     return np.asarray(batch_data)
144 | 
145 | 
146 | 
147 | def write_batch_image(image, save_dir, name, n):
148 | 
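    # Tiles the first n*n images of the batch into an n-by-n mosaic and writes it to save_dir/name as a single image.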
149 |     if not os.path.exists(save_dir):
150 |         os.makedirs(save_dir)
151 | 
152 |     fused_dir = os.path.join(save_dir, name)
153 |     fused_image = [0] * n
154 |     for i in range(n):
155 |         fused_image[i] = []
156 |         for j in range(n):
157 |             k = i * n + j
158 |             image[k] = (image[k]+1) * 127.5
159 |             #image[k] = image[k] - np.min(image[k])
160 |             #image[k] = image[k]/np.max(image[k])
161 |             #image[k] = image[k] * 255.0
162 |             fused_image[i].append(image[k])
163 |         fused_image[i] = np.hstack(fused_image[i])
164 |     fused_image = np.vstack(fused_image)
165 |     cv2.imwrite(fused_dir, fused_image.astype(np.uint8))
166 | 
167 | 
168 | if __name__ == '__main__':
169 |     pass
170 | 
--------------------------------------------------------------------------------
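Note: a minimal sketch of exercising the two superpixel routines from train_code/utils.py on their own, mirroring how train.py chooses between them via --use_enhance. The image path and batch construction are illustrative assumptions, not part of the repository:

```python
# Illustrative only; assumes train_code/ is the working directory or on PYTHONPATH.
import cv2
import numpy as np
import utils

img = cv2.imread('sample.jpg')
img = cv2.resize(img, (256, 256)).astype(np.float32) / 127.5 - 1   # scale to [-1, 1], as in next_batch
batch = np.expand_dims(img, axis=0)

flat = utils.simple_superpixel(batch, seg_num=200)      # SLIC-based flattening (default path)
adaptive = utils.selective_adacolor(batch, power=1.2)   # selective-search-based adaptive coloring (--use_enhance)

cv2.imwrite('flat.jpg', ((flat[0] + 1) * 127.5).astype(np.uint8))
cv2.imwrite('adaptive.jpg', ((adaptive[0] + 1) * 127.5).astype(np.uint8))
```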