├── Chinese ├── c_detail.jpg ├── c_detail_sem.png ├── content.jpg ├── content_sem.png ├── s1_detail.jpg ├── s1_detail_sem.png ├── s2_detail.jpg ├── s2_detail_sem.png ├── style.jpg └── style_sem.png ├── GLStyleNet.py ├── LICENSE ├── README.md ├── artistic ├── c1.jpg ├── c2.png ├── c3.png ├── c4.png ├── s1.png ├── s2.png ├── s3.png └── s4.png ├── download_vgg19.sh ├── examples ├── Chinese.png ├── artistic.png ├── photo-realistic.png └── portraits.png ├── outputs └── ReadMe ├── photo-realistic ├── c1.png ├── c1_sem.png ├── c2.png ├── c2_sem.png ├── c3.png ├── c3_sem.png ├── c4.png ├── c4_sem.png ├── s1.png ├── s1_sem.png ├── s2.png ├── s2_sem.png ├── s3.png ├── s3_sem.png ├── s4.png └── s4_sem.png └── portrait ├── Freddie.jpg ├── Freddie_sem.png ├── Gogh.jpg ├── Gogh_sem.png ├── Mia.jpg ├── Mia_sem.png ├── Seth.jpg └── Seth_sem.png /Chinese/c_detail.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/Chinese/c_detail.jpg -------------------------------------------------------------------------------- /Chinese/c_detail_sem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/Chinese/c_detail_sem.png -------------------------------------------------------------------------------- /Chinese/content.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/Chinese/content.jpg -------------------------------------------------------------------------------- /Chinese/content_sem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/Chinese/content_sem.png 
-------------------------------------------------------------------------------- /Chinese/s1_detail.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/Chinese/s1_detail.jpg -------------------------------------------------------------------------------- /Chinese/s1_detail_sem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/Chinese/s1_detail_sem.png -------------------------------------------------------------------------------- /Chinese/s2_detail.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/Chinese/s2_detail.jpg -------------------------------------------------------------------------------- /Chinese/s2_detail_sem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/Chinese/s2_detail_sem.png -------------------------------------------------------------------------------- /Chinese/style.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/Chinese/style.jpg -------------------------------------------------------------------------------- /Chinese/style_sem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/Chinese/style_sem.png -------------------------------------------------------------------------------- /GLStyleNet.py: -------------------------------------------------------------------------------- 1 
| import tensorflow as tf 2 | import numpy as np 3 | import skimage.io 4 | import itertools 5 | import os 6 | import bz2 7 | import argparse 8 | import scipy 9 | import skimage.transform 10 | 11 | CONTENT_LAYERS = ['4_1'] 12 | LOCAL_STYLE_LAYERS = ['1_1','2_1','3_1','4_1'] 13 | GLOBAL_STYLE_LAYERS=['1_1','2_1','3_1','4_1'] 14 | 15 | 16 | def conv2d(input_tensor, kernel, bias): 17 | kernel = np.transpose(kernel, [2, 3, 1, 0]) 18 | x = tf.pad(input_tensor, [[0,0], [1,1], [1,1], [0,0]]) 19 | x = tf.nn.conv2d(x, tf.constant(kernel), (1,1,1,1), 'VALID') 20 | x = tf.nn.bias_add(x, tf.constant(bias)) 21 | return tf.nn.relu(x) 22 | 23 | def avg_pooling(input_tensor, size=2): 24 | return tf.nn.pool(input_tensor, [size, size], 'AVG', 'VALID', strides=[size, size]) 25 | 26 | def norm(arr): 27 | n, *shape = arr.shape 28 | lst = [] 29 | for i in range(n): 30 | v = arr[i, :].flatten() 31 | v /= np.sqrt(sum(v**2)) 32 | lst.append(np.reshape(v, shape)) 33 | return lst 34 | 35 | def build_base_net(input_tensor,input_map=None): 36 | vgg19_file = os.path.join(os.path.dirname(__file__), 'vgg19.pkl.bz2') 37 | assert os.path.exists(vgg19_file), ("Model file with pre-trained convolution layers not found. 
Download here: " 38 | +"https://github.com/alexjc/neural-doodle/releases/download/v0.0/vgg19_conv.pkl.bz2") 39 | 40 | data = np.load(bz2.open(vgg19_file, 'rb')) 41 | k = 0 42 | net = {} 43 | # network divided into two parts, main and map: main downsamples the image, map downsamples the semantic map 44 | net['img'] = input_tensor 45 | net['conv1_1'] = conv2d(net['img'], data[k], data[k+1]) 46 | k += 2 47 | net['conv1_2'] = conv2d(net['conv1_1'], data[k], data[k+1]) 48 | k += 2 49 | # average pooling without padding 50 | net['pool1'] = avg_pooling(net['conv1_2']) 51 | net['conv2_1'] = conv2d(net['pool1'], data[k], data[k+1]) 52 | k += 2 53 | net['conv2_2'] = conv2d(net['conv2_1'], data[k], data[k+1]) 54 | k += 2 55 | net['pool2'] = avg_pooling(net['conv2_2']) 56 | net['conv3_1'] = conv2d(net['pool2'], data[k], data[k+1]) 57 | k += 2 58 | net['conv3_2'] = conv2d(net['conv3_1'], data[k], data[k+1]) 59 | k += 2 60 | net['conv3_3'] = conv2d(net['conv3_2'], data[k], data[k+1]) 61 | k += 2 62 | net['conv3_4'] = conv2d(net['conv3_3'], data[k], data[k+1]) 63 | k += 2 64 | net['pool3'] = avg_pooling(net['conv3_4']) 65 | net['conv4_1'] = conv2d(net['pool3'], data[k], data[k+1]) 66 | k += 2 67 | net['conv4_2'] = conv2d(net['conv4_1'], data[k], data[k+1]) 68 | k += 2 69 | net['conv4_3'] = conv2d(net['conv4_2'], data[k], data[k+1]) 70 | k += 2 71 | net['conv4_4'] = conv2d(net['conv4_3'], data[k], data[k+1]) 72 | k += 2 73 | net['pool4'] = avg_pooling(net['conv4_4']) 74 | net['conv5_1'] = conv2d(net['pool4'], data[k], data[k+1]) 75 | k += 2 76 | net['conv5_2'] = conv2d(net['conv5_1'], data[k], data[k+1]) 77 | k += 2 78 | net['conv5_3'] = conv2d(net['conv5_2'], data[k], data[k+1]) 79 | k += 2 80 | net['conv5_4'] = conv2d(net['conv5_3'], data[k], data[k+1]) 81 | k += 2 82 | net['main'] = net['conv5_4'] 83 | 84 | net['map'] = input_map 85 | for j, i in itertools.product(range(5), range(4)): 86 | if j < 2 and i > 1: continue 87 | suffix = '%i_%i' % (j+1, i+1) 88 | 89 | if i == 0: 90 |
net['map%i'%(j+1)] = avg_pooling(net['map'], 2**j) 91 | net['sem'+suffix] = tf.concat([net['conv'+suffix], net['map%i'%(j+1)]], -1) 92 | return net 93 | 94 | 95 | def extract_target_data(content, content_mask, style, style_mask): 96 | pixel_mean = np.array([103.939, 116.779, 123.680], dtype=np.float32).reshape((1,1,1,3)) 97 | # local style patches extracting 98 | input_tensor = style-pixel_mean 99 | input_map= style_mask 100 | net = build_base_net(input_tensor, input_map) 101 | local_features = [net['sem'+layer] for layer in LOCAL_STYLE_LAYERS] 102 | # layer aggregation for local style 103 | LF=local_features[0] 104 | for i in range(1,len(LOCAL_STYLE_LAYERS)): 105 | lf=local_features[i] 106 | LF=tf.image.resize_images(LF,[lf.shape[1],lf.shape[2]],method=tf.image.ResizeMethod.BILINEAR) 107 | LF=tf.concat([LF,lf],3) 108 | 109 | dim = LF.shape[-1].value 110 | x = tf.extract_image_patches(LF, (1,3,3,1), (1,1,1,1), (1,1,1,1), 'VALID') 111 | patches=tf.reshape(x, (-1, 3, 3, dim)) 112 | 113 | # content features 114 | input_tensor = content-pixel_mean 115 | input_map= content_mask 116 | net = build_base_net(input_tensor, input_map) 117 | content_features = [net['conv'+layer] for layer in CONTENT_LAYERS] 118 | content_data=[] 119 | 120 | # global feature correlations based on fused features 121 | input_tensor = style-pixel_mean 122 | input_map= style_mask 123 | net = build_base_net(input_tensor, input_map) 124 | global_features = [net['conv'+layer] for layer in GLOBAL_STYLE_LAYERS] 125 | GF=global_features[0] 126 | for i in range(1,len(GLOBAL_STYLE_LAYERS)): 127 | gf=global_features[i] 128 | GF=tf.image.resize_images(GF,[gf.shape[1],gf.shape[2]],method=tf.image.ResizeMethod.BILINEAR) 129 | GF=tf.concat([GF,gf],3) 130 | 131 | N=int(GF.shape[3]) 132 | M=int(GF.shape[1]*GF.shape[2]) 133 | GF=tf.reshape(GF,(M,N)) 134 | GF_corr=tf.matmul(tf.transpose(GF),GF) 135 | 136 | with tf.Session() as sess: 137 | sess.run(tf.global_variables_initializer()) 138 | patches=patches.eval() 139 
| for c in content_features: 140 | content_data.append(c.eval()) 141 | global_data=GF_corr.eval() 142 | 143 | return content_data,patches,global_data 144 | 145 | 146 | def format_and_norm(arr, depth, sem_weight): 147 | n, *shape = arr.shape 148 | norm = np.zeros(shape+[n], dtype=arr.dtype) 149 | un_norm = np.zeros(shape+[n], dtype=arr.dtype) 150 | for i in range(n): 151 | t = arr[i, ...] 152 | un_norm[..., i] = t 153 | t1 = t[..., :depth] 154 | t1 = t1/np.sqrt(3*np.sum(t1**2)+1e-6) 155 | t2 = t[..., depth:] 156 | t2 = t2/np.sqrt(sem_weight*np.sum(t2**2)+1e-6) 157 | 158 | norm[..., i] = np.concatenate([t1,t2], -1) 159 | return norm, un_norm 160 | 161 | 162 | """GLStyleNet""" 163 | class Model(object): 164 | def __init__(self, args, content, style, style2, content_mask=None, style_mask=None): 165 | self.args = args 166 | if len(args.device)>3 and args.device[:3]=='gpu': 167 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device[3:] 168 | elif args.device=='cpu': 169 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 170 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 171 | self.pixel_mean = np.array([103.939, 116.779, 123.680], dtype=np.float32).reshape((1,1,1,3)) 172 | 173 | self.content = np.expand_dims(content, 0).astype(np.float32) 174 | self.style = np.expand_dims(style, 0).astype(np.float32) 175 | self.style2= np.expand_dims(style2, 0).astype(np.float32) 176 | 177 | if content_mask is not None: 178 | self.content_mask = np.expand_dims(content_mask, 0).astype(np.float32) 179 | else: 180 | self.content_mask = np.ones(self.content.shape[:-1]+(1,), np.float32) 181 | self.args.semantic_weight= 0.0 182 | if style_mask is not None: 183 | self.style_mask = np.expand_dims(style_mask, 0).astype(np.float32) 184 | else: 185 | self.style_mask = np.ones(self.style.shape[:-1]+(1,), np.float32) 186 | self.args.semantic_weight = 0.0 187 | assert self.content_mask.shape[-1] == self.style_mask.shape[-1] 188 | 189 | 190 | self.args.semantic_weight=100/self.args.semantic_weight if 
self.args.semantic_weight else 1E+8 191 | 192 | self.mask_depth = self.content_mask.shape[-1] 193 | # get target content features, local patches, global feature correlations 194 | self.content_data, self.local_data, self.global_data= extract_target_data(self.content, self.content_mask, self.style, self.style_mask) 195 | tf.reset_default_graph() 196 | 197 | if args.init=='style': 198 | input_tensor = tf.Variable(self.style2) 199 | elif args.init=='content': 200 | input_tensor = tf.Variable(self.content) 201 | else: 202 | input_tensor = tf.Variable(np.random.uniform(16, 240, self.content.shape).astype(np.float32)) 203 | 204 | input_map=tf.Variable(self.content_mask) 205 | self.net = build_base_net(input_tensor, input_map) 206 | 207 | self.content_features = [self.net['conv'+layer] for layer in CONTENT_LAYERS] 208 | self.local_features = [self.net['sem'+layer] for layer in LOCAL_STYLE_LAYERS] 209 | self.global_features = [self.net['conv'+layer] for layer in GLOBAL_STYLE_LAYERS] 210 | 211 | # local style layer aggregation 212 | LF=self.local_features[0] 213 | for i in range(1,len(LOCAL_STYLE_LAYERS)): 214 | lf=self.local_features[i] 215 | LF=tf.image.resize_images(LF,[lf.shape[1],lf.shape[2]],method=tf.image.ResizeMethod.BILINEAR) 216 | LF=tf.concat([LF,lf],3) 217 | 218 | # patch-matching,concatenate semantic maps 219 | self.local_loss = 0 220 | sem = LF 221 | patches = tf.extract_image_patches(sem, (1,3,3,1), (1,1,1,1), (1,1,1,1), 'VALID') 222 | patches = tf.reshape(patches, (-1, 3, 3, sem.shape[-1].value)) 223 | 224 | pow2 = patches**2 225 | p1 = tf.reduce_sum(pow2[..., :-self.mask_depth], [1,2,3]) 226 | p1 = tf.reshape(p1, [-1,1,1,1]) 227 | p1 = pow2[..., :-self.mask_depth]/(3*p1+1e-6) 228 | p2 = tf.reduce_sum(pow2[..., -self.mask_depth:], [1,2,3]) 229 | p2 = tf.reshape(p2, [-1,1,1,1]) 230 | p2 = pow2[..., -self.mask_depth:]/(self.args.semantic_weight*p2+1e-6) 231 | norm_patch = tf.concat([p1, p2], -1) 232 | norm_patch = tf.reshape(norm_patch, [-1, 
9*sem.shape[-1].value]) 233 | 234 | norm, un_norm = format_and_norm(self.local_data, -self.mask_depth, self.args.semantic_weight) 235 | norm = np.reshape(norm, [9*sem.shape[-1].value, -1]) 236 | sim = tf.matmul(norm_patch, norm) 237 | max_ind = tf.argmax(sim, axis=-1) 238 | target_patches = tf.gather(self.local_data, tf.reshape(max_ind, [-1])) 239 | 240 | # local style loss 241 | self.local_loss += tf.reduce_mean((patches[...,:-self.mask_depth]-target_patches[...,:-self.mask_depth])**2) 242 | self.local_loss *= args.local_weight 243 | 244 | # content loss 245 | self.content_loss = 0 246 | for c, t in zip(self.content_features, self.content_data) : 247 | self.content_loss += tf.reduce_mean((c-t)**2) 248 | self.content_loss *= args.content_weight 249 | 250 | # total variation regularization loss 251 | self.tv_loss = args.smoothness*(tf.reduce_mean(tf.abs(input_tensor[..., :-1,:]-input_tensor[..., 1:,:])) 252 | +tf.reduce_mean(tf.abs(input_tensor[..., :, :-1]-input_tensor[..., :,1:]))) 253 | 254 | # global style loss 255 | GF=self.global_features[0] 256 | for i in range(1,len(GLOBAL_STYLE_LAYERS)): 257 | gf=self.global_features[i] 258 | GF=tf.image.resize_images(GF,[gf.shape[1],gf.shape[2]],method=tf.image.ResizeMethod.BILINEAR) 259 | GF=tf.concat([GF,gf],3) 260 | 261 | N=int(GF.shape[3]) 262 | M=int(GF.shape[1]*GF.shape[2]) 263 | GF=tf.reshape(GF,(M,N)) 264 | GF_corr=tf.matmul(tf.transpose(GF),GF) 265 | 266 | self.global_loss = tf.reduce_sum(((GF_corr-self.global_data)**2)/((2*M*N)**2)) 267 | self.global_loss *= args.global_weight 268 | 269 | # total loss 270 | self.loss = self.local_loss + self.content_loss + self.tv_loss + self.global_loss 271 | self.grad = tf.gradients(self.loss, self.net['img']) 272 | tf.summary.scalar('loss', self.loss) 273 | self.merged = tf.summary.merge_all() 274 | self.summary_writer = tf.summary.FileWriter('./summary', tf.get_default_graph()) 275 | def evaluate(self): 276 | sess = tf.Session() 277 | def func(img): 278 | self.iter += 1 279 | 
current_img = img.reshape(self.content.shape).astype(np.float32) - self.pixel_mean 280 | 281 | feed_dict = {self.net['img']:current_img, self.net['map']:self.content_mask} 282 | loss = 0 283 | grads = 0 284 | local_loss = 0 285 | content_loss = 0 286 | tv_loss=0 287 | global_loss=0 288 | sess.run(tf.global_variables_initializer()) 289 | loss, grads, local_loss, content_loss, tv_loss, global_loss, summ= sess.run( 290 | [self.loss, self.grad, self.local_loss, self.content_loss, self.tv_loss, self.global_loss, self.merged], 291 | feed_dict=feed_dict) 292 | self.summary_writer.add_summary(summ, self.iter) 293 | if self.iter % 10 == 0: 294 | out = current_img + self.pixel_mean 295 | out = np.squeeze(out) 296 | out = np.clip(out, 0, 255).astype('uint8') 297 | skimage.io.imsave('outputs/%s-%d.jpg'%(self.args.output, self.iter), out) 298 | 299 | print('Epoch:%d,loss:%f,local loss:%f,global loss:%f,content loss:%f,tv loss: %f.'% 300 | (self.iter, loss, local_loss, global_loss, content_loss, tv_loss)) 301 | if np.isnan(grads).any(): 302 | raise OverflowError("Optimization diverged; try using a different device or parameters.") 303 | 304 | # Return the data in the right format for L-BFGS. 305 | return loss, np.array(grads).flatten().astype(np.float64) 306 | return func 307 | 308 | def run(self): 309 | args = self.args 310 | if args.init=='style': 311 | Xn = self.style2 312 | elif args.init=='content': 313 | Xn = self.content 314 | else: 315 | Xn = np.random.uniform(16, 240, self.content.shape).astype(np.float32) 316 | 317 | self.iter = 0 318 | # Optimization algorithm needs min and max bounds to prevent divergence. 319 | data_bounds = np.zeros((np.product(Xn.shape), 2), dtype=np.float64) 320 | data_bounds[:] = (0.0, 255.0) 321 | print ("GLStyleNet: Start") 322 | try: 323 | Xn, *_ = scipy.optimize.fmin_l_bfgs_b( 324 | self.evaluate(), 325 | Xn.flatten(), 326 | bounds=data_bounds, 327 | factr=0.0, pgtol=0.0, # Disable automatic termination, set low threshold. 
328 | m=5, # Maximum correlations kept in memory by algorithm. 329 | maxfun=args.iterations, # Limit number of calls to evaluate(). 330 | iprint=-1) # Handle our own logging of information. 331 | except OverflowError: 332 | print("The optimization diverged and NaNs were encountered.", 333 | " - Try using a different `--device` or change the parameters.", 334 | " - Make sure libraries are updated to work around platform bugs.") 335 | except KeyboardInterrupt: 336 | print("User canceled.") 337 | except Exception as e: 338 | print(e) 339 | 340 | print ("GLStyleNet: Completed!") 341 | 342 | self.summary_writer.close() 343 | 344 | 345 | def prepare_mask(content_mask, style_mask, n): 346 | from sklearn.cluster import KMeans 347 | x1 = content_mask.reshape((-1, content_mask.shape[-1])) 348 | x2 = style_mask.reshape((-1, style_mask.shape[-1])) 349 | kmeans = KMeans(n_clusters=n, random_state=0).fit(x1) 350 | y1 = kmeans.labels_ 351 | y2 = kmeans.predict(x2) 352 | y1 = y1.reshape(content_mask.shape[:-1]) 353 | y2 = y2.reshape(style_mask.shape[:-1]) 354 | diag = np.diag([1 for _ in range(n)]) 355 | return diag[y1].astype(np.float32), diag[y2].astype(np.float32) 356 | 357 | def main(): 358 | parser = argparse.ArgumentParser(description='GLStyleNet: transfer style of an image onto a content image.', 359 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 360 | add_arg = parser.add_argument 361 | 362 | add_arg('--content', default=None, type=str, help='Content image path.') 363 | add_arg('--content-mask', default=None, type=str, help='Content image semantic mask.') 364 | add_arg('--content-weight', default=10, type=float, help='Weight of content.') 365 | add_arg('--style', default=None, type=str, help='Style image path.') 366 | add_arg('--style-mask', default=None, type=str, help='Style image semantic map.') 367 | add_arg('--local-weight', default=100, type=float, help='Weight of local style loss.') 368 | add_arg('--semantic-weight', default=10, type=float, help='Weight 
of semantic map channel.') 369 | add_arg('--global-weight', default=0.1, type=float, help='Weight of global style loss.') 370 | add_arg('--output', default='output', type=str, help='Output image path.') 371 | add_arg('--smoothness', default=1E+0, type=float, help='Weight of image smoothing scheme.') 372 | 373 | add_arg('--init', default='content', type=str, help='Image path to initialize, "noise" or "content" or "style".') 374 | add_arg('--iterations', default=500, type=int, help='Number of iterations.') 375 | add_arg('--device', default='gpu', type=str, help='devices: "gpu"(default: all gpu) or "gpui"(e.g. gpu0) or "cpu" ') 376 | add_arg('--class-num', default=5, type=int, help='Count of semantic mask classes.') 377 | 378 | args = parser.parse_args() 379 | 380 | style = skimage.io.imread(args.style) 381 | if args.style_mask: 382 | style_mask = skimage.io.imread(args.style_mask) 383 | 384 | content = skimage.io.imread(args.content) 385 | if args.content_mask: 386 | content_mask = skimage.io.imread(args.content_mask) 387 | 388 | if style.shape[0]==content.shape[0] and style.shape[1]==content.shape[1]: 389 | style2=style 390 | else: 391 | style2=skimage.transform.resize(style,(content.shape[0],content.shape[1])) 392 | 393 | if args.content_mask and args.style_mask: 394 | content_mask, style_mask = prepare_mask(content_mask, style_mask, args.class_num) 395 | model = Model(args, content, style, style2, content_mask, style_mask) 396 | else: 397 | model = Model(args, content, style, style2) 398 | model.run() 399 | 400 | 401 | if __name__ == '__main__': 402 | main() 403 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Zhizhong Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the 
Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # GLStyleNet 3 | **[update 1/12/2022]** 4 | 5 | paper: [GLStyleNet: Exquisite Style Transfer Combining Global and Local Pyramid Features](https://ietresearch.onlinelibrary.wiley.com/doi/pdf/10.1049/iet-cvi.2019.0844), published in [IET Computer Vision 2020](https://digital-library.theiet.org/content/journals/iet-cvi). 6 | 7 | Arxiv paper: [GLStyleNet: Higher Quality Style Transfer Combining Global and Local Pyramid Features](https://arxiv.org/abs/1811.07260). 8 | 9 | ### Environment Required: 10 | - Python 3.6 11 | - TensorFlow 1.4.0 12 | - CUDA 8.0 13 | 14 | ### Getting Started: 15 | Step 1: clone this repo 16 | 17 | 18 | `git clone https://github.com/EndyWon/GLStyleNet` 19 | `cd GLStyleNet` 20 | 21 | 22 | Step 2: download pre-trained vgg19 model 23 | 24 | 25 | `bash download_vgg19.sh` 26 | 27 | 28 | Step 3: run style transfer 29 | 1. 
**Script Parameters** 30 | * `--content` : content image path 31 | * `--content-mask` : content image semantic mask 32 | * `--style` : style image path 33 | * `--style-mask` : style image semantic mask 34 | * `--content-weight` : weight of content, default=10 35 | * `--local-weight` : weight of local style loss 36 | * `--semantic-weight` : weight of semantic map constraint 37 | * `--global-weight` : weight of global style loss 38 | * `--output` : output image path 39 | * `--smoothness` : weight of image smoothing scheme 40 | * `--init` : image type to initialize, value='noise' or 'content' or 'style', default='content' 41 | * `--iterations` : number of iterations, default=500 42 | * `--device` : devices, value='gpu'(all available GPUs) or 'gpui'(e.g. gpu0) or 'cpu', default='gpu' 43 | * `--class-num` : count of semantic mask classes, default=5 44 | 45 | 2. **portrait style transfer** (an example) 46 | 47 | 48 | `python GLStyleNet.py --content portrait/Seth.jpg --content-mask portrait/Seth_sem.png --style portrait/Gogh.jpg --style-mask portrait/Gogh_sem.png --content-weight 10 --local-weight 500 --semantic-weight 10 --global-weight 1 --init style --device gpu` 49 | 50 | 51 | **!!!You can find all the iteration results in folder 'outputs'!!!** 52 | 53 | ![portraits](https://github.com/EndyWon/GLStyleNet/blob/master/examples/portraits.png) 54 | 55 | 3. **Chinese ancient painting style transfer** (an example) 56 | 57 | 58 | `python GLStyleNet.py --content Chinese/content.jpg --content-mask Chinese/content_sem.png --style Chinese/style.jpg --style-mask Chinese/style_sem.png --content-weight 10 --local-weight 500 --semantic-weight 2.5 --global-weight 0.5 --init content --device gpu` 59 | 60 | ![Chinese](https://github.com/EndyWon/GLStyleNet/blob/master/examples/Chinese.png) 61 | 62 | 4. 
**artistic and photo-realistic style transfer** 63 | 64 | #### artistic: 65 | 66 | ![artistic](https://github.com/EndyWon/GLStyleNet/blob/master/examples/artistic.png) 67 | 68 | #### photo-realistic: 69 | 70 | ![photo-realistic](https://github.com/EndyWon/GLStyleNet/blob/master/examples/photo-realistic.png) 71 | 72 | 73 | ## Citation: 74 | 75 | If you find this code useful for your research, please cite the paper: 76 | 77 | ``` 78 | @article{wang2020glstylenet, 79 | title={GLStyleNet: exquisite style transfer combining global and local pyramid features}, 80 | author={Wang, Zhizhong and Zhao, Lei and Lin, Sihuan and Mo, Qihang and Zhang, Huiming and Xing, Wei and Lu, Dongming}, 81 | journal={IET Computer Vision}, 82 | volume={14}, 83 | number={8}, 84 | pages={575--586}, 85 | year={2020}, 86 | publisher={IET} 87 | } 88 | ``` 89 | 90 | ## Acknowledgement: 91 | The code was written based on [Champandard's code](https://github.com/alexjc/neural-doodle). 92 | -------------------------------------------------------------------------------- /artistic/c1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/artistic/c1.jpg -------------------------------------------------------------------------------- /artistic/c2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/artistic/c2.png -------------------------------------------------------------------------------- /artistic/c3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/artistic/c3.png -------------------------------------------------------------------------------- /artistic/c4.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/artistic/c4.png -------------------------------------------------------------------------------- /artistic/s1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/artistic/s1.png -------------------------------------------------------------------------------- /artistic/s2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/artistic/s2.png -------------------------------------------------------------------------------- /artistic/s3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/artistic/s3.png -------------------------------------------------------------------------------- /artistic/s4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/artistic/s4.png -------------------------------------------------------------------------------- /download_vgg19.sh: -------------------------------------------------------------------------------- 1 | URL=https://github.com/EndyWon/GLStyleNet/releases/download/v1.0/vgg19.pkl.bz2 2 | FILE=./vgg19.pkl.bz2 3 | wget $URL -O $FILE 4 | -------------------------------------------------------------------------------- /examples/Chinese.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/examples/Chinese.png 
-------------------------------------------------------------------------------- /examples/artistic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/examples/artistic.png -------------------------------------------------------------------------------- /examples/photo-realistic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/examples/photo-realistic.png -------------------------------------------------------------------------------- /examples/portraits.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/examples/portraits.png -------------------------------------------------------------------------------- /outputs/ReadMe: -------------------------------------------------------------------------------- 1 | You can find all the iteration outputs in this folder after running 'GLStyleNet.py'; an intermediate result is saved every 10 iterations.
2 | -------------------------------------------------------------------------------- /photo-realistic/c1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/photo-realistic/c1.png -------------------------------------------------------------------------------- /photo-realistic/c1_sem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/photo-realistic/c1_sem.png -------------------------------------------------------------------------------- /photo-realistic/c2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/photo-realistic/c2.png -------------------------------------------------------------------------------- /photo-realistic/c2_sem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/photo-realistic/c2_sem.png -------------------------------------------------------------------------------- /photo-realistic/c3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/photo-realistic/c3.png -------------------------------------------------------------------------------- /photo-realistic/c3_sem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/photo-realistic/c3_sem.png -------------------------------------------------------------------------------- /photo-realistic/c4.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/photo-realistic/c4.png -------------------------------------------------------------------------------- /photo-realistic/c4_sem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/photo-realistic/c4_sem.png -------------------------------------------------------------------------------- /photo-realistic/s1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/photo-realistic/s1.png -------------------------------------------------------------------------------- /photo-realistic/s1_sem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/photo-realistic/s1_sem.png -------------------------------------------------------------------------------- /photo-realistic/s2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/photo-realistic/s2.png -------------------------------------------------------------------------------- /photo-realistic/s2_sem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/photo-realistic/s2_sem.png -------------------------------------------------------------------------------- /photo-realistic/s3.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/photo-realistic/s3.png -------------------------------------------------------------------------------- /photo-realistic/s3_sem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/photo-realistic/s3_sem.png -------------------------------------------------------------------------------- /photo-realistic/s4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/photo-realistic/s4.png -------------------------------------------------------------------------------- /photo-realistic/s4_sem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/photo-realistic/s4_sem.png -------------------------------------------------------------------------------- /portrait/Freddie.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/portrait/Freddie.jpg -------------------------------------------------------------------------------- /portrait/Freddie_sem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/portrait/Freddie_sem.png -------------------------------------------------------------------------------- /portrait/Gogh.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/portrait/Gogh.jpg 
-------------------------------------------------------------------------------- /portrait/Gogh_sem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/portrait/Gogh_sem.png -------------------------------------------------------------------------------- /portrait/Mia.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/portrait/Mia.jpg -------------------------------------------------------------------------------- /portrait/Mia_sem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/portrait/Mia_sem.png -------------------------------------------------------------------------------- /portrait/Seth.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/portrait/Seth.jpg -------------------------------------------------------------------------------- /portrait/Seth_sem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EndyWon/GLStyleNet/0ad94e582bb72f338e345d7c0d7f036c2ad6bd32/portrait/Seth_sem.png --------------------------------------------------------------------------------