├── .gitignore
├── README.md
├── images
│   ├── crop.py
│   ├── style
│   │   ├── 0_udnie.jpg
│   │   ├── 10_Yellow_sunset.jpg
│   │   ├── 11_Three_Fishing_Boats.jpg
│   │   ├── 12_The_Annunciation_of_the_Virgin_Deal.jpg
│   │   ├── 13_Edith_with_Striped_Dress.jpeg
│   │   ├── 14_Colors_from_a_Distance.jpg
│   │   ├── 15_Sunrise.jpg
│   │   ├── 1_la_muse.jpg
│   │   ├── 2_rain_princess.jpg
│   │   ├── 3_the_scream.jpg
│   │   ├── 4_the_shipwreck_of_the_minotaur.jpg
│   │   ├── 5_wave.jpg
│   │   ├── 6_composition.jpg
│   │   ├── 7_guernica.jpg
│   │   ├── 8_starry_night.jpg
│   │   └── 9_White_Zig_Zags.jpg
│   ├── style_crop
│   │   ├── 0_udnie.jpg
│   │   ├── 10_Yellow_sunset.jpg
│   │   ├── 11_Three_Fishing_Boats.jpg
│   │   ├── 12_The_Annunciation_of_the_Virgin_Deal.jpg
│   │   ├── 13_Edith_with_Striped_Dress.jpeg
│   │   ├── 14_Colors_from_a_Distance.jpg
│   │   ├── 15_Sunrise.jpg
│   │   ├── 1_la_muse.jpg
│   │   ├── 2_rain_princess.jpg
│   │   ├── 3_the_scream.jpg
│   │   ├── 4_the_shipwreck_of_the_minotaur.jpg
│   │   ├── 5_wave.jpg
│   │   ├── 6_composition.jpg
│   │   ├── 7_guernica.jpg
│   │   ├── 8_starry_night.jpg
│   │   └── 9_White_Zig_Zags.jpg
│   └── test
│       ├── chicago.jpg
│       ├── stata.jpg
│       └── tubingen.jpg
├── main.py
├── result
│   ├── conditional_instance_norm.jpg
│   ├── result.jpg
│   ├── style.jpg
│   ├── style01_01.gif
│   ├── style02_01.gif
│   ├── style03_01.gif
│   ├── style04_01.gif
│   ├── style05_01.gif
│   ├── style06_01.gif
│   ├── style07_01.gif
│   ├── style08_01.gif
│   ├── style09_01.gif
│   ├── style10_01.gif
│   ├── style11_01.gif
│   ├── style12_01.gif
│   ├── style13_01.gif
│   ├── tubingen_10.jpg
│   ├── tubingen_11.jpg
│   ├── tubingen_9.jpg
│   ├── tubingen_9_10.jpg
│   └── tubingen_9_10_11.jpg
├── src
│   ├── __init__.py
│   ├── functions.py
│   ├── layers.py
│   ├── multi_style_transfer.py
│   ├── op.py
│   └── vgg19.py
├── test_style.sh
└── train_style.sh


/.gitignore:
--------------------------------------------------------------------------------
src/*.pyc

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Fast Multi(Interpolative) Style Transfer
Implementation of Google Brain's [A Learned Representation For Artistic Style](https://arxiv.org/pdf/1610.07629v2.pdf) in TensorFlow.
You can mix various styles using just one model, and it is still fast!


Figure 1. Multi-style transfer with a single model. The center image mixes 4 styles.

This paper builds on [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](https://arxiv.org/abs/1603.08155)
and [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
Those methods are fast and give good results, but each trained model can produce only one style.


## Usage
We recommend downloading the project files (model, VGG weights, images, etc.) from [OneDrive](https://1drv.ms/f/s!ArFpOdlDcjqQga8fwL0m4VQGmgKSfg) / [Dropbox](https://www.dropbox.com/sh/b3by1ipmr0v821y/AABJ4gadaOk9RRsqsOTC336Xa?dl=0), and downloading [COCO](http://mscoco.org/dataset/#download) into your data folder. Example command lines are given below and in train_style.sh and test_style.sh.

#### Project folder tree
    Working Directory
    ├── MST
    │   ├── models
    │   │   ├── checkpoint
    │   │   ├── xxx.index
    │   │   ├── xxx.data
    │   │   └── xxxx.meta
    │   ├── test_result
    │   └── train_result
    ├── images
    │   ├── style
    │   ├── style_crop
    │   ├── test
    │   └── crop.py
    ├── src
    │   ├── vgg19.mat
    │   └── ...
    ├── main.py
    ├── test_style.sh
    └── train_style.sh



#### Style Control Weight (SCW)
`-scw, --style_control_weights` is the style control argument. Its 16 values are the weights of style1, style2, ..., style16, in that order. Pass them as separate, unquoted numbers (main.py declares the argument with `nargs=16`).

For a single style:

    style1     -scw 1 0 0 ... 0 0 0
    style16    -scw 0 0 0 ... 0 0 1

For a mix of styles (the weights are divided by their sum, so only their ratios matter):

    0.5 * style1 + 0.5 * style2                   -scw 0.5 0.5 0 ... 0 0 0     or  -scw 1 1 0 ... 0 0 0
    (2 * style1 + 3 * style2 + 4 * style3) / 9    -scw 0.2 0.3 0.4 ... 0 0 0   or  -scw 2 3 4 ... 0 0 0
    1/16 * (style1 + ... + style16)               -scw 0.0625 ... 0.0625       or  -scw 1 1 1 ... 1 1 1


### Train
#### From Scratch

    python main.py -f 1 -gn 0 -p MST -n 10 -b 16 \
        -tsd images/test -sti images/style_crop/0_udnie.jpg \
        -ctd /mnt/cloud/Data/COCO/train2014 \
        -scw 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0

This trains the weights, biases, gamma_0, and beta_0. It needs 40,000 iterations (10 epochs).



#### Fine-Tuned (after training from scratch, or after downloading a trained model)

    python main.py -f 1 -gn 0 -p MST -n 1 -b 16 \
        -tsd images/test -sti images/style_crop/1_la_muse.jpg \
        -ctd /mnt/cloud/Data/COCO/train2014 \
        -scw 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0

Only gamma_i and beta_i are trained. This needs only 4,000 iterations (1 epoch, 1/10 of training from scratch).

You can see the images gradually change to the new style.
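
In `src/layers.py`, each style's scale and shift are created under a variable scope named `<i>_style` inside every `cond_in` scope, where `i` is the 0-based position in `-scw` (position 0 with 0_udnie in the from-scratch example, position 1 with 1_la_muse here). Fine-tuning style i can therefore be expressed as optimizing only the variables in that scope. The pure-Python sketch below shows such a name filter; the helper `style_var_names` and the example variable names are hypothetical, and `src/multi_style_transfer.py` (not shown) may select variables differently. In TensorFlow 1.x, a list filtered this way from `tf.trainable_variables()` could be passed as `var_list` to `Optimizer.minimize`.

```python
def style_var_names(var_names, style_idx):
    """Return the variable names that belong to style `style_idx`.

    Hypothetical helper. It matches the whole scope component
    '<idx>_style', so style 1 does not also pick up style 11.
    """
    target = '{}_style'.format(style_idx)
    return [name for name in var_names if target in name.split('/')]


# Hypothetical names following the scopes used in src/layers.py
names = [
    'conv1/conv1:0',
    'conv1/cond_in/1_style/shift:0',
    'conv1/cond_in/1_style/scale:0',
    'conv1/cond_in/11_style/shift:0',
]
print(style_var_names(names, 1))
```

Matching on a whole path component rather than a substring is the important detail: a plain `'1_style' in name` test would also select the variables of style 11.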




#### Number of styles
To train a 32-style model, edit main.py and extend the -scw vector from 16 to 32 values:

1. `parser.add_argument("-scw", "--style_control_weights", type=float, nargs=16)` → `nargs=32`
2. `-scw w1 w2 ... w16` → `-scw w1 w2 ... w32`


### Test
#### Single style

    ex) style9
    python main.py -f 0 -gn 0 -p MST \
        -tsd images/test \
        -scw 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0




    ex) style10
    python main.py -f 0 -gn 0 -p MST \
        -tsd images/test \
        -scw 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0
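
Typing 16 (or 32) positional values by hand is error-prone. The hypothetical helper below, which is not part of this repository, builds the `-scw` values from a mapping of 1-based style positions (style1 ... style16, as in the Style Control Weight section) to weights; `num_styles` is an assumption that should match `nargs` in main.py.

```python
def scw_args(weights, num_styles=16):
    """Build the value string for -scw from {position: weight}.

    Hypothetical helper. Positions are 1-based (style1 ... styleN),
    following the Style Control Weight section of this README.
    """
    values = [0.0] * num_styles
    for position, weight in weights.items():
        if not 1 <= position <= num_styles:
            raise ValueError('style position out of range: %d' % position)
        values[position - 1] = float(weight)
    # '%g' prints 1.0 as '1' and 0.5 as '0.5'
    return ' '.join('%g' % v for v in values)


print('python main.py -f 0 -gn 0 -p MST -tsd images/test -scw ' + scw_args({9: 1}))
```

For a 32-style model, call it with `num_styles=32` after changing `nargs` as described in the Number of styles section.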




#### Multi Style

    ex) 0.5*style9 + 0.5*style10
    python main.py -f 0 -gn 0 -p MST \
        -tsd images/test \
        -scw 0 0 0 0 0 0 0 0 0 0.5 0.5 0 0 0 0 0



    ex) 0.33*style9 + 0.33*style10 + 0.33*style11
    python main.py -f 0 -gn 0 -p MST \
        -tsd images/test \
        -scw 0 0 0 0 0 0 0 0 0 0.33 0.33 0.33 0 0 0 0



## Implementation Details
#### Conditional instance normalization

The key idea of this paper is conditional instance normalization.
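
Reading `conditional_instance_norm` in `src/layers.py`, the layer can be written as follows for a feature map x. Here μ and σ² are the per-image, per-channel mean and variance over the spatial axes (`tf.nn.moments(net, [1,2])`), ε = 10⁻³ as in the code, γ_i and β_i are the scale and shift of style i, and w_i are the `-scw` weights:

```math
\mathrm{CIN}(x) = \gamma \, \frac{x - \mu(x)}{\sqrt{\sigma^2(x) + \epsilon}} + \beta,
\qquad
\gamma = \frac{\sum_i w_i \, \gamma_i}{\sum_i w_i},
\qquad
\beta = \frac{\sum_i w_i \, \beta_i}{\sum_i w_i}
```

With a one-hot weight vector this reduces to the single-style form with γ_s and β_s.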



Instance normalization is similar to batch normalization, but it computes the mean (mu) and variance (sigma^2) per image and per channel, and it does not accumulate running statistics.
Conditional instance normalization keeps N scales (gamma) and N shifts (beta), where N is the number of style images.
This means that to add a new style, you only need to train a new gamma and a new beta.
See the code below (an excerpt from `src/layers.py`):

    def conditional_instance_norm(net, style_control=None, name='cond_in'):
        ...
        shift = []
        scale = []

        for i in range(len(style_control)):
            with tf.variable_scope('{0}'.format(i) + '_style'):
                shift.append(tf.get_variable('shift', shape=var_shape, initializer=tf.constant_initializer(0.)))
                scale.append(tf.get_variable('scale', shape=var_shape, initializer=tf.constant_initializer(1.)))
        ...

        # Weighted average of the per-style parameters, normalized by the sum of the weights
        idx = [i for i, x in enumerate(style_control) if not x == 0]
        style_scale = reduce(tf.add, [scale[i]*style_control[i] for i in idx]) / sum(style_control)
        style_shift = reduce(tf.add, [shift[i]*style_control[i] for i in idx]) / sum(style_control)
        output = style_scale * normalized + style_shift


#### Upsampling
The paper upsamples with "Image_resize-Conv" (resize, then convolve), but this implementation uses ["Deconv-Pooling"](https://arxiv.org/abs/1611.04994):

    def mst_net(...):
        ...
        x = conv_tranpose_layer(x, 64, 3, 2, style_control=style_control, name='up_conv1')
        x = pooling(x)
        x = conv_tranpose_layer(x, 32, 3, 2, style_control=style_control, name='up_conv2')
        x = pooling(x)
        ...
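
Because the styles are mixed through γ and β rather than through images, one property of a single conditional instance norm layer is easy to check: for a fixed input, the output is affine in (γ, β), so normalizing with the mixed parameters gives exactly the weighted average of the single-style outputs of that layer. The NumPy sketch below mirrors `conditional_instance_norm` for one NHWC feature map; it is an illustration rather than the TensorFlow code, and the shapes and random parameters are made up. The equality holds per layer only; after the ReLUs and later layers, the image produced by a mixed `-scw` need not equal the average of the single-style images.

```python
import numpy as np


def conditional_instance_norm(x, scales, shifts, weights, eps=1e-3):
    """NumPy mirror of conditional_instance_norm in src/layers.py.

    x:       feature map of shape (batch, rows, cols, channels)
    scales:  per-style gamma vectors, each of shape (channels,)
    shifts:  per-style beta vectors, each of shape (channels,)
    weights: style control weights (the -scw values)
    """
    # Per-image, per-channel statistics over the spatial axes
    mu = x.mean(axis=(1, 2), keepdims=True)
    var = x.var(axis=(1, 2), keepdims=True)
    normalized = (x - mu) / np.sqrt(var + eps)

    # Weighted average of gamma and beta; weights are divided by their sum
    total = float(sum(weights))
    gamma = sum(w * g for w, g in zip(weights, scales)) / total
    beta = sum(w * b for w, b in zip(weights, shifts)) / total
    return gamma * normalized + beta


rng = np.random.RandomState(0)
x = rng.randn(2, 8, 8, 4)
scales = [rng.rand(4) + 0.5 for _ in range(3)]
shifts = [rng.randn(4) for _ in range(3)]
weights = [0.2, 0.3, 0.4]

mixed = conditional_instance_norm(x, scales, shifts, weights)
singles = [conditional_instance_norm(x, scales, shifts, np.eye(3)[k]) for k in range(3)]
average = sum(w * s for w, s in zip(weights, singles)) / sum(weights)
print(np.allclose(mixed, average))  # True: the layer is affine in (gamma, beta)
```

The same sketch also shows why `-scw 0.2 0.3 0.4` and `-scw 2 3 4` are interchangeable: scaling all weights by a constant leaves γ and β unchanged.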


## Requirements
- TensorFlow 1.0.0
- Python 2.7.12, Pillow 3.4.2, scipy 0.18.1, numpy 1.11.2

## Attributions/Thanks
This project borrowed some code from [Lengstrom's fast-style-transfer](https://github.com/lengstrom/fast-style-transfer).
Google Brain's own code is [here](https://github.com/tensorflow/magenta) (it requires some installation).

--------------------------------------------------------------------------------
/images/crop.py:
--------------------------------------------------------------------------------
import os
import scipy.misc as cv2  # scipy.misc image helpers, aliased as cv2 (removed in newer SciPy; see Requirements)
import numpy as np

dataset = 'style'
saveset = 'style_crop'
size = 512

for fn in os.listdir(dataset):
    print fn
    img = cv2.imread(dataset + '/' + fn)
    # np.shape returns (rows, cols, channels): here w holds the rows and h the columns
    w,h,c = np.shape(img)
    print w,h

    # Resize so that the shorter side equals `size`, keeping the aspect ratio
    if w >= h:
        ratio = float(h)/float(w)
        resize_factor = (int(size/ratio), size)
        img_resize = cv2.imresize(img, resize_factor)
    else:
        ratio = float(w)/float(h)
        resize_factor = (size, int(size/ratio))
        img_resize = cv2.imresize(img, resize_factor)

    # Center-crop a size x size patch
    w,h,c = np.shape(img_resize)
    crop_w = int((w-size) * 0.5)
    crop_h = int((h-size) * 0.5)
    # cv2.imsave(saveset + '/' + 'resize_' + fn, img_resize)

    print crop_h, crop_w
    img_crop = img_resize[crop_w:crop_w+size,crop_h:crop_h+size,:]
    cv2.imsave(saveset + '/' + fn, img_crop)


--------------------------------------------------------------------------------
/images/style/0_udnie.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style/0_udnie.jpg
--------------------------------------------------------------------------------
/images/style/10_Yellow_sunset.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style/10_Yellow_sunset.jpg
--------------------------------------------------------------------------------
/images/style/11_Three_Fishing_Boats.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style/11_Three_Fishing_Boats.jpg
--------------------------------------------------------------------------------
/images/style/12_The_Annunciation_of_the_Virgin_Deal.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style/12_The_Annunciation_of_the_Virgin_Deal.jpg
--------------------------------------------------------------------------------
/images/style/13_Edith_with_Striped_Dress.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style/13_Edith_with_Striped_Dress.jpeg
--------------------------------------------------------------------------------
/images/style/14_Colors_from_a_Distance.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style/14_Colors_from_a_Distance.jpg
--------------------------------------------------------------------------------
/images/style/15_Sunrise.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style/15_Sunrise.jpg
--------------------------------------------------------------------------------
/images/style/1_la_muse.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style/1_la_muse.jpg
--------------------------------------------------------------------------------
/images/style/2_rain_princess.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style/2_rain_princess.jpg
--------------------------------------------------------------------------------
/images/style/3_the_scream.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style/3_the_scream.jpg
--------------------------------------------------------------------------------
/images/style/4_the_shipwreck_of_the_minotaur.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style/4_the_shipwreck_of_the_minotaur.jpg
--------------------------------------------------------------------------------
/images/style/5_wave.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style/5_wave.jpg
--------------------------------------------------------------------------------
/images/style/6_composition.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style/6_composition.jpg
--------------------------------------------------------------------------------
/images/style/7_guernica.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style/7_guernica.jpg
--------------------------------------------------------------------------------
/images/style/8_starry_night.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style/8_starry_night.jpg
--------------------------------------------------------------------------------
/images/style/9_White_Zig_Zags.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style/9_White_Zig_Zags.jpg
--------------------------------------------------------------------------------
/images/style_crop/0_udnie.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style_crop/0_udnie.jpg
--------------------------------------------------------------------------------
/images/style_crop/10_Yellow_sunset.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style_crop/10_Yellow_sunset.jpg
--------------------------------------------------------------------------------
/images/style_crop/11_Three_Fishing_Boats.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style_crop/11_Three_Fishing_Boats.jpg
--------------------------------------------------------------------------------
/images/style_crop/12_The_Annunciation_of_the_Virgin_Deal.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style_crop/12_The_Annunciation_of_the_Virgin_Deal.jpg
--------------------------------------------------------------------------------
/images/style_crop/13_Edith_with_Striped_Dress.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style_crop/13_Edith_with_Striped_Dress.jpeg
--------------------------------------------------------------------------------
/images/style_crop/14_Colors_from_a_Distance.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style_crop/14_Colors_from_a_Distance.jpg
--------------------------------------------------------------------------------
/images/style_crop/15_Sunrise.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style_crop/15_Sunrise.jpg
--------------------------------------------------------------------------------
/images/style_crop/1_la_muse.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style_crop/1_la_muse.jpg
--------------------------------------------------------------------------------
/images/style_crop/2_rain_princess.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style_crop/2_rain_princess.jpg
--------------------------------------------------------------------------------
/images/style_crop/3_the_scream.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style_crop/3_the_scream.jpg
--------------------------------------------------------------------------------
/images/style_crop/4_the_shipwreck_of_the_minotaur.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style_crop/4_the_shipwreck_of_the_minotaur.jpg
--------------------------------------------------------------------------------
/images/style_crop/5_wave.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style_crop/5_wave.jpg
--------------------------------------------------------------------------------
/images/style_crop/6_composition.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style_crop/6_composition.jpg
--------------------------------------------------------------------------------
/images/style_crop/7_guernica.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style_crop/7_guernica.jpg
--------------------------------------------------------------------------------
/images/style_crop/8_starry_night.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style_crop/8_starry_night.jpg
--------------------------------------------------------------------------------
/images/style_crop/9_White_Zig_Zags.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style_crop/9_White_Zig_Zags.jpg
--------------------------------------------------------------------------------
/images/test/chicago.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/test/chicago.jpg
--------------------------------------------------------------------------------
/images/test/stata.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/test/stata.jpg
--------------------------------------------------------------------------------
/images/test/tubingen.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/test/tubingen.jpg
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import argparse
import distutils.util
import tensorflow as tf
import src.multi_style_transfer as mst


def main():
    parser = argparse.ArgumentParser()

    # Train ############################################################################################################
    parser.add_argument("-f", "--flag", type=distutils.util.strtobool, default='True')
    parser.add_argument("-gn", "--gpu_number", type=int, default=0)
    parser.add_argument("-p", "--project", type=str, default="mst")

    ## Train images
    parser.add_argument("-ctd", "--content_dataset", type=str, default="/mnt/cloud/Data/COCO/all")
    parser.add_argument("-cts", "--content_data_size", type=int, default=256)
    parser.add_argument("-sti", "--style_image", type=str, default="images/style/0_udnie.jpg")

    ## Train Iteration
    parser.add_argument("-n", "--niter", type=int, default=2)
    parser.add_argument("-ns", "--nsnapshot", type=int, default=100)
    parser.add_argument("-mx", "--max_to_keep", type=int, default=10)

    ## Train Parameter
    parser.add_argument("-b", "--batch_size", type=int, default=16)
    parser.add_argument("-lr", "--learning_rate", type=float, default=1e-3)
    parser.add_argument("-m", "--momentum", type=float, default=0.9)
    parser.add_argument("-m2", "--momentum2", type=float, default=0.999)

    ## loss weight
    parser.add_argument("-lc", "--content_loss_weights", type=float, default=1.5e0)
    parser.add_argument("-ls", "--style_loss_weights", type=float, default=1e2)
    parser.add_argument("-lt", "--tv_loss_weight", type=float, default=2e2)

    # Test #############################################################################################################
    parser.add_argument("-tsd", "--test_dataset", type=str, default="images/test")
    parser.add_argument("-scw", "--style_control_weights", type=float, nargs=16)

    args = parser.parse_args()
    gpu_number = args.gpu_number
    train_flag = args.flag

    with tf.device('/gpu:{}'.format(gpu_number)):
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.85)
        config = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)

        with tf.Session(config=config) as sess:
            ## Make Model
            model = mst.mst(args, sess)

            ## TRAIN / TEST
            if train_flag:
                model.train(train_flag)
            else:
                model.test(train_flag)


if __name__ == '__main__':
    main()


--------------------------------------------------------------------------------
/result/conditional_instance_norm.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/conditional_instance_norm.jpg
--------------------------------------------------------------------------------
/result/result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/result.jpg
--------------------------------------------------------------------------------
/result/style.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/style.jpg
--------------------------------------------------------------------------------
/result/style01_01.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/style01_01.gif
--------------------------------------------------------------------------------
/result/style02_01.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/style02_01.gif
--------------------------------------------------------------------------------
/result/style03_01.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/style03_01.gif
--------------------------------------------------------------------------------
/result/style04_01.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/style04_01.gif
--------------------------------------------------------------------------------
/result/style05_01.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/style05_01.gif
--------------------------------------------------------------------------------
/result/style06_01.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/style06_01.gif
--------------------------------------------------------------------------------
/result/style07_01.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/style07_01.gif
--------------------------------------------------------------------------------
/result/style08_01.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/style08_01.gif
--------------------------------------------------------------------------------
/result/style09_01.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/style09_01.gif
--------------------------------------------------------------------------------
/result/style10_01.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/style10_01.gif
--------------------------------------------------------------------------------
/result/style11_01.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/style11_01.gif
--------------------------------------------------------------------------------
/result/style12_01.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/style12_01.gif
--------------------------------------------------------------------------------
/result/style13_01.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/style13_01.gif
--------------------------------------------------------------------------------
/result/tubingen_10.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/tubingen_10.jpg
--------------------------------------------------------------------------------
/result/tubingen_11.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/tubingen_11.jpg
--------------------------------------------------------------------------------
/result/tubingen_9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/tubingen_9.jpg
--------------------------------------------------------------------------------
/result/tubingen_9_10.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/tubingen_9_10.jpg
--------------------------------------------------------------------------------
/result/tubingen_9_10_11.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/tubingen_9_10_11.jpg
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/src/__init__.py
--------------------------------------------------------------------------------
/src/functions.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import scipy.misc as scm
from glob import glob


def get_image(img_path, size=None):
    img = scm.imread(img_path, mode='RGB')
    h, w, c = np.shape(img)

    img = img[:h, :w, ::-1]  # rgb to bgr
    if size:
        img = scm.imresize(img, (size, size))
    return img


def inverse_image(img):
    img[img > 255] = 255
    img[img < 0] = 0
    img = img[:, :, ::-1]  # bgr to rgb
    return img


def make_project_dir(project_dir):
    if not os.path.exists(project_dir):
        os.makedirs(project_dir)
        os.makedirs(os.path.join(project_dir,'models'))
        os.makedirs(os.path.join(project_dir,'train_result'))
        os.makedirs(os.path.join(project_dir,'test_result'))


def data_loader(dataset):
    print 'images Load ....'
    data_path = dataset

    # Cache the list of image paths as <dataset>.npy so later runs skip the glob
    if os.path.exists(data_path + '.npy'):
        data = np.load(data_path + '.npy')
    else:
        data = glob(os.path.join(data_path, "*.*"))
        np.save(data_path + '.npy', data)
    print 'images Load Done'

    return data

--------------------------------------------------------------------------------
/src/layers.py:
--------------------------------------------------------------------------------
from functools import reduce  # built-in in Python 2; explicit import keeps it working on Python 3
import tensorflow as tf


def conv_layer(net, num_filters, filter_size, strides, style_control=None, relu=True, name='conv'):
    with tf.variable_scope(name):
        b,w,h,c = net.get_shape().as_list()
        weights_shape = [filter_size, filter_size, c, num_filters]
        weights_init = tf.get_variable(name, shape=weights_shape, initializer=tf.truncated_normal_initializer(stddev=.01))
        strides_shape = [1, strides, strides, 1]

        p = (filter_size - 1) // 2  # integer division so tf.pad gets an int on Python 2 and 3
        if strides == 1:
            net = tf.pad(net, [[0, 0], [p, p], [p, p], [0, 0]], "REFLECT")
            net = tf.nn.conv2d(net, weights_init, strides_shape, padding="VALID")
        else:
            net = tf.nn.conv2d(net, weights_init, strides_shape, padding="SAME")

        net = conditional_instance_norm(net, style_control=style_control)
        if relu:
            net = tf.nn.relu(net)

        return net


def conv_tranpose_layer(net, num_filters, filter_size, strides, style_control=None, name='conv_t'):
    with tf.variable_scope(name):
        b, w, h, c = net.get_shape().as_list()
        weights_shape = [filter_size, filter_size, num_filters, c]
        weights_init = tf.get_variable(name, shape=weights_shape, initializer=tf.truncated_normal_initializer(stddev=.01))

        batch_size, rows, cols, in_channels = [i.value for i in net.get_shape()]
        new_rows, new_cols = int(rows * strides), int(cols * strides)
        # new_shape = #tf.pack([tf.shape(net)[0], new_rows, new_cols, num_filters])

        new_shape = [batch_size, new_rows, new_cols, num_filters]
        tf_shape = tf.stack(new_shape)
        strides_shape = [1,strides,strides,1]

        p = (filter_size - 1) // 2
        if strides == 1:
            net = tf.pad(net, [[0, 0], [p, p], [p, p], [0, 0]], "REFLECT")
            net = tf.nn.conv2d_transpose(net, weights_init, tf_shape, strides_shape, padding="VALID")
        else:
            net = tf.nn.conv2d_transpose(net, weights_init, tf_shape, strides_shape, padding="SAME")
        net = conditional_instance_norm(net, style_control=style_control)

        return tf.nn.relu(net)


def residual_block(net, filter_size=3, style_control=None, name='res'):
    with tf.variable_scope(name+'_a'):
        tmp = conv_layer(net, 128, filter_size, 1, style_control=style_control)
    with tf.variable_scope(name+'_b'):
        output = net + conv_layer(tmp, 128, filter_size, 1, style_control=style_control, relu=False)
    return output


def conditional_instance_norm(net, style_control=None, name='cond_in'):
    with tf.variable_scope(name):
        batch, rows, cols, channels = [i.value for i in net.get_shape()]
        # Per-image, per-channel mean and variance over the spatial axes
        mu, sigma_sq = tf.nn.moments(net, [1,2], keep_dims=True)

        var_shape = [channels]
        shift = []
        scale = []

        # One (shift, scale) pair per style, under the scope '<i>_style'
        for i in range(len(style_control)):
            with tf.variable_scope('{0}'.format(i) + '_style'):
                shift.append(tf.get_variable('shift', shape=var_shape, initializer=tf.constant_initializer(0.)))
                scale.append(tf.get_variable('scale', shape=var_shape, initializer=tf.constant_initializer(1.)))
        epsilon = 1e-3
        normalized = (net-mu)/(sigma_sq + epsilon)**(.5)

        # Mix the parameters of the styles with non-zero weight, normalized by the sum of the weights
        idx = [i for i, x in enumerate(style_control) if not x == 0]

        style_scale = reduce(tf.add, [scale[i]*style_control[i] for i in idx]) / sum(style_control)
        style_shift = reduce(tf.add, [shift[i]*style_control[i] for i in idx]) / sum(style_control)
        output = style_scale * normalized + style_shift

        return output


def instance_norm(net, train=True, name='in'):
    with
tf.variable_scope(name): 85 | batch, rows, cols, channels = [i.value for i in net.get_shape()] 86 | var_shape = [channels] 87 | mu, sigma_sq = tf.nn.moments(net, [1,2], keep_dims=True) 88 | shift = tf.get_variable('shift', shape=var_shape, initializer=tf.constant_initializer(0.)) 89 | scale = tf.get_variable('scale', shape=var_shape, initializer=tf.constant_initializer(1.)) 90 | epsilon = 1e-3 91 | normalized = (net-mu)/(sigma_sq + epsilon)**(.5) 92 | return scale * normalized + shift 93 | 94 | def pooling(input): 95 | return tf.nn.avg_pool(input, ksize=(1, 2, 2, 1), strides=(1, 1, 1, 1), padding='SAME') 96 | 97 | 98 | def total_variation(preds): 99 | # total variation denoising 100 | b,w,h,c = preds.get_shape().as_list() 101 | y_tv = tf.nn.l2_loss(preds[:,1:,:,:] - preds[:,:w-1,:,:]) 102 | x_tv = tf.nn.l2_loss(preds[:,:,1:,:] - preds[:,:,:h-1,:]) 103 | tv_loss = 2*(x_tv + y_tv)/b/w/h/c 104 | return tv_loss 105 | 106 | 107 | def euclidean_loss(input_, target_): 108 | b,w,h,c = input_.get_shape().as_list() 109 | return 2 * tf.nn.l2_loss(input_- target_) / b/w/h/c 110 | 111 | 112 | def gram_matrix(net): 113 | b,h,w,c = net.get_shape().as_list() 114 | feats = tf.reshape(net, (b, h*w, c)) 115 | feats_T = tf.transpose(feats, perm=[0,2,1]) 116 | grams = tf.matmul(feats_T, feats) / h/w/c 117 | return grams 118 | 119 | 120 | def style_loss(input_, style_): 121 | b,h,w,c = input_.get_shape().as_list() 122 | input_gram = gram_matrix(input_) 123 | style_gram = gram_matrix(style_) 124 | return 2 * tf.nn.l2_loss(input_gram - style_gram)/b/c/c 125 | 126 | -------------------------------------------------------------------------------- /src/multi_style_transfer.py: -------------------------------------------------------------------------------- 1 | import scipy.io 2 | import numpy as np 3 | import src.vgg19 as vgg 4 | 5 | from src.op import op 6 | from src.layers import * 7 | from src.functions import * 8 | 9 | 10 | class mst(op): 11 | def __init__(self,args, sess): 12 | 
op.__init__(self, args, sess) 13 | 14 | def mst_net(self, x, style_control=None, reuse=False): 15 | with tf.variable_scope(tf.get_variable_scope(), reuse=reuse): 16 | b,h,w,c = x.get_shape().as_list() 17 | 18 | x = conv_layer(x, 32, 9, 1, style_control=style_control, name='conv1') 19 | x = conv_layer(x, 64, 3, 2, style_control=style_control, name='conv2') 20 | x = conv_layer(x, 128, 3, 2, style_control=style_control, name='conv3') 21 | x = residual_block(x, 3, style_control=style_control, name='res1') 22 | x = residual_block(x, 3, style_control=style_control, name='res2') 23 | x = residual_block(x, 3, style_control=style_control, name='res3') 24 | x = residual_block(x, 3, style_control=style_control, name='res4') 25 | x = residual_block(x, 3, style_control=style_control, name='res5') 26 | x = conv_tranpose_layer(x, 64, 3, 2, style_control=style_control, name='up_conv1') 27 | x = pooling(x) 28 | x = conv_tranpose_layer(x, 32, 3, 2, style_control=style_control, name='up_conv2') 29 | x = pooling(x) 30 | x = conv_layer(x, 3, 9, 1, relu=False, style_control=style_control, name='output') 31 | preds = tf.nn.tanh(x) * 150 + 255./2 32 | return preds 33 | 34 | 35 | def build_model(self): 36 | # Set content / style input 37 | # content_input 38 | b = self.batch_size; h = self.content_data_size; w = self.content_data_size; 39 | self.content_input = tf.placeholder(tf.float32, shape=[b, h, w, 3], name='content_input') 40 | 41 | # style_input 42 | style_img = get_image(self.style_image) 43 | style_idx = [i for i, x in enumerate(self.style_control) if not x == 0][0] 44 | print 'style_idx : {}'.format(style_idx) 45 | style_input = tf.constant((style_img[np.newaxis, ...]), dtype=tf.float32) 46 | 47 | # MST_output (Pastiche) 48 | MST_output = self.mst_net(self.content_input, style_control=self.style_control) 49 | 50 | # VGG network 51 | weights = scipy.io.loadmat('src/vgg19.mat') 52 | vgg_mean = tf.constant(np.array([103.939, 116.779, 123.68]).reshape((1, 1, 1, 3)), dtype='float32') 
53 | 54 | content_feats = vgg.net(self.content_input - vgg_mean, weights) 55 | style_feats = vgg.net(style_input - vgg_mean, weights) 56 | MST_feats = vgg.net(MST_output - vgg_mean, weights) 57 | 58 | c_loss = self.content_weights * euclidean_loss(MST_feats[-1], content_feats[-1]) 59 | s_loss = self.style_weights * sum([style_loss(MST_feats[i], style_feats[i]) for i in range(5)]) 60 | tv_loss = self.tv_weight * total_variation(MST_output) 61 | loss = c_loss + s_loss + tv_loss 62 | 63 | t_vars = tf.trainable_variables() 64 | vars = [var for var in t_vars if '{0}'.format(style_idx) + '_style' in var.name] 65 | 66 | if style_idx == 0: 67 | train_opt = tf.train.AdamOptimizer(self.learning_rate, self.momentum).minimize(loss) 68 | else: 69 | train_opt = tf.train.AdamOptimizer(self.learning_rate, self.momentum).minimize(loss, var_list=vars) 70 | 71 | self.optimize = [train_opt, loss, c_loss, s_loss, tv_loss] 72 | self.saver = tf.train.Saver(var_list=t_vars, max_to_keep=(self.max_to_keep)) 73 | self.sess.run(tf.global_variables_initializer()) 74 | 75 | 76 | def train(self,Train_flag): 77 | op.train(self,Train_flag) 78 | 79 | def test(self, Train_flag): 80 | op.test(self,Train_flag) 81 | 82 | def save(self): 83 | op.save(self) 84 | 85 | def load(self): 86 | op.load(self) 87 | -------------------------------------------------------------------------------- /src/op.py: -------------------------------------------------------------------------------- 1 | import time 2 | import tensorflow as tf 3 | from src.functions import * 4 | 5 | 6 | class op(object): 7 | 8 | def __init__(self, args, sess): 9 | self.sess = sess 10 | 11 | ## Train 12 | self.gpu_number = args.gpu_number 13 | self.project_name = args.project 14 | 15 | ## Train images 16 | self.content_dataset = args.content_dataset ## test2015 17 | self.content_data_size = args.content_data_size 18 | self.style_image = args.style_image 19 | 20 | ## Train Iteration 21 | self.niter = args.niter 22 | self.niter_snapshot = 
args.nsnapshot 23 | self.max_to_keep = args.max_to_keep 24 | 25 | ## Train Parameter 26 | self.batch_size = args.batch_size 27 | self.learning_rate = args.learning_rate 28 | self.momentum = args.momentum 29 | self.momentum2 = args.momentum2 30 | 31 | self.content_weights = args.content_loss_weights 32 | self.style_weights = args.style_loss_weights 33 | self.tv_weight = args.tv_loss_weight 34 | 35 | ## Result Dir & File 36 | self.project_dir = '{0}/'.format(self.project_name) 37 | make_project_dir(self.project_dir) 38 | self.ckpt_dir = os.path.join(self.project_dir, 'models') 39 | 40 | ## Test 41 | self.test_dataset = args.test_dataset 42 | self.style_control = args.style_control_weights 43 | 44 | ## build model 45 | self.build_model() 46 | 47 | 48 | def train(self,Train_flag): 49 | data = data_loader(self.content_dataset) 50 | print 'Shuffle ....' 51 | random_order = np.random.permutation(len(data)) 52 | data = [data[i] for i in random_order[:10000*self.batch_size]] 53 | print 'Shuffle Done' 54 | 55 | start_time = time.time() 56 | count = 0 57 | 58 | try: 59 | self.load() 60 | print 'Weight Load !!' 
61 | except: 62 | self.sess.run(tf.global_variables_initializer()) 63 | 64 | for epoch in xrange(self.niter): 65 | batch_idxs = len(data) // self.batch_size 66 | 67 | for idx in xrange(0, batch_idxs): 68 | count += 1 69 | 70 | batch_files = data[idx * self.batch_size: (idx + 1) * self.batch_size] 71 | batch_label = [(get_image(batch_file, self.content_data_size)) for batch_file in batch_files] 72 | 73 | feeds = {self.content_input: batch_label} 74 | 75 | _, loss_all, loss_c, loss_s, loss_tv = self.sess.run(self.optimize, feed_dict=feeds) 76 | train_time = time.time() - start_time 77 | print("Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.4f, loss_c: %.4f, loss_s: %.4f, loss_tv: %.4f" 78 | % (epoch, idx, batch_idxs, train_time, loss_all, loss_c, loss_s, loss_tv)) 79 | 80 | ## Test during Training 81 | if count % self.niter_snapshot == (self.niter_snapshot-1): 82 | self.count = count 83 | self.save() 84 | self.test(Train_flag) 85 | 86 | 87 | def test(self, Train_flag=True): 88 | for fn in os.listdir(self.test_dataset): 89 | 90 | ## Read RGB Image 91 | im_input = get_image(self.test_dataset + '/' + fn) 92 | im_input_4d = im_input[np.newaxis, ...] 
93 | im_b, im_h, im_w, im_c = np.shape(im_input_4d) 94 | 95 | ## Run Model 96 | img = tf.placeholder(tf.float32, [im_b, im_h, im_w, im_c], name='img') 97 | 98 | self.test_recon = self.mst_net(img, style_control=self.style_control, reuse=True) 99 | self.load() 100 | 101 | im_output = self.sess.run(self.test_recon, feed_dict={img : im_input_4d}) 102 | im_output = inverse_image(im_output[0]) 103 | 104 | style_idx = ['{0}_{1}'.format(i, x) for i, x in enumerate(self.style_control) if not x == 0] 105 | 106 | ## Image Show & Save 107 | style_name = os.path.split(self.style_image)[-1].split('.')[0] 108 | if Train_flag: 109 | train_output_dir = os.path.join(self.project_dir, 'train_result', style_name) 110 | if not os.path.exists(train_output_dir): 111 | os.makedirs(train_output_dir) 112 | filename = fn[:-4] + '_' + str(style_idx) + '_' + str(self.count) + '_output.bmp' 113 | scm.imsave(os.path.join(train_output_dir, filename), im_output) 114 | else: 115 | test_output_dir = os.path.join(self.project_dir, 'test_result') 116 | filename = fn[:-4] + '_' + str(style_idx) + '_output.bmp' 117 | scm.imsave(os.path.join(test_output_dir, filename), im_output) 118 | 119 | print filename 120 | 121 | 122 | def save(self): 123 | style_name = os.path.basename(self.style_image)[:-4] 124 | self.model_name = "{0}_{1}.model".format(self.project_name, style_name) 125 | 126 | if not os.path.exists(self.ckpt_dir): 127 | os.makedirs(self.ckpt_dir) 128 | self.saver.save(self.sess, os.path.join(self.ckpt_dir, self.model_name), global_step=self.count) 129 | 130 | 131 | def load(self): 132 | ckpt = tf.train.get_checkpoint_state(self.ckpt_dir) 133 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 134 | self.saver.restore(self.sess, os.path.join(self.ckpt_dir, ckpt_name)) -------------------------------------------------------------------------------- /src/vgg19.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as 
np 3 | import scipy.io 4 | 5 | 6 | MEAN_PIXEL = np.array([ 123.68 , 116.779, 103.939]) 7 | 8 | def net(input_image, data): 9 | layers = ( 10 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 11 | 12 | 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 13 | 14 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 15 | 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 16 | 17 | 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 18 | 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 19 | 20 | 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 21 | 'relu5_3', 'conv5_4', 'relu5_4' 22 | ) 23 | 24 | #data = scipy.io.loadmat(data_path) 25 | #mean = data['normalization'][0][0][0] 26 | #mean_pixel = np.mean(mean, axis=(0, 1)) 27 | weights = data['layers'][0] 28 | 29 | net = {} 30 | current = input_image 31 | net['input'] = input_image 32 | for i, name in enumerate(layers): 33 | kind = name[:4] 34 | if kind == 'conv': 35 | kernels, bias = weights[i][0][0][0][0] 36 | # matconvnet: weights are [width, height, in_channels, out_channels] 37 | # tensorflow: weights are [height, width, in_channels, out_channels] 38 | kernels = np.transpose(kernels, (1, 0, 2, 3)) 39 | bias = bias.reshape(-1) 40 | current = _conv_layer(current, kernels, bias) 41 | elif kind == 'relu': 42 | current = tf.nn.relu(current) 43 | elif kind == 'pool': 44 | current = _pool_layer(current) 45 | net[name] = current 46 | 47 | #assert len(net) == len(layers) 48 | return [net['relu1_1'], net['relu2_1'], net['relu3_1'], net['relu4_1'], net['relu5_1'], net['relu4_2']] 49 | 50 | 51 | def _conv_layer(input, weights, bias): 52 | conv = tf.nn.conv2d(input, tf.constant(weights), strides=(1, 1, 1, 1), 53 | padding='SAME') 54 | return tf.nn.bias_add(conv, bias) 55 | 56 | 57 | def _pool_layer(input): 58 | # return tf.nn.max_pool(input, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1), 59 | return tf.nn.avg_pool(input, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1), 60 | padding='SAME') 61 | 
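`conditional_instance_norm` in src/layers.py stores one `scale`/`shift` pair per style. At run time it averages those pairs, weighted by `style_control`: zero weights are skipped and the result is divided by `sum(style_control)`. The sketch below reproduces that arithmetic in plain Python 3 for a single image so it can be checked without TensorFlow. The repository itself targets Python 2 and TensorFlow 1.x. The names `blend_params` and `cond_instance_norm` and the flattened pixel-list layout are illustrative assumptions, not part of this repository. Because the weights are normalised by their sum, a vector such as `5 5 0 ... 0` (used in test_style.sh) gives the same 50/50 blend as `1 1 0 ... 0`.

```python
import math
from typing import List, Sequence


def blend_params(per_style: Sequence[Sequence[float]],
                 weights: Sequence[float]) -> List[float]:
    """Weighted average of per-style parameter vectors (one vector per style).

    Styles with weight 0 are skipped and the sum is divided by sum(weights),
    mirroring reduce(tf.add, ...) / sum(style_control) in layers.py.
    """
    total = sum(weights)
    if total == 0:
        raise ValueError("at least one style weight must be non-zero")
    channels = len(per_style[0])
    acc = [0.0] * channels
    for params, w in zip(per_style, weights):
        if w == 0:
            continue
        for c in range(channels):
            acc[c] += params[c] * w
    return [v / total for v in acc]


def cond_instance_norm(pixels: Sequence[Sequence[float]],
                       scales: Sequence[Sequence[float]],
                       shifts: Sequence[Sequence[float]],
                       weights: Sequence[float],
                       epsilon: float = 1e-3) -> List[List[float]]:
    """Instance-normalise one image, then apply the blended scale/shift.

    pixels: H*W entries, each a list of C channel values (one image, flattened).
    Mean and population variance are taken per channel over all pixels,
    as tf.nn.moments(net, [1, 2]) does for each image in the batch.
    """
    n = len(pixels)
    channels = len(pixels[0])
    mean = [sum(p[c] for p in pixels) / n for c in range(channels)]
    var = [sum((p[c] - mean[c]) ** 2 for p in pixels) / n for c in range(channels)]
    scale = blend_params(scales, weights)
    shift = blend_params(shifts, weights)
    return [
        [scale[c] * (p[c] - mean[c]) / math.sqrt(var[c] + epsilon) + shift[c]
         for c in range(channels)]
        for p in pixels
    ]


if __name__ == "__main__":
    # Two styles, one channel: style 0 has (scale=1, shift=0), style 1 has (scale=3, shift=2).
    print(blend_params([[1.0], [3.0]], [5, 5]))  # [2.0]
    print(cond_instance_norm([[1.0], [3.0]], [[1.0], [3.0]], [[0.0], [2.0]], [5, 5]))
```

The scale and shift are blended before they are applied, not after normalising once per style. Because the normalisation step does not depend on the style, the two orders give the same result, and blending first needs only one multiply-add per activation.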
-------------------------------------------------------------------------------- /test_style.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python main.py -f 0 -gn 0 -p MST -tsd images/test -scw 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 3 | python main.py -f 0 -gn 0 -p MST -tsd images/test -scw 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 4 | python main.py -f 0 -gn 0 -p MST -tsd images/test -scw 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 5 | python main.py -f 0 -gn 0 -p MST -tsd images/test -scw 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 6 | 7 | python main.py -f 0 -gn 0 -p MST -tsd images/test -scw 5 5 0 0 0 0 0 0 0 0 0 0 0 0 0 0 8 | python main.py -f 0 -gn 0 -p MST -tsd images/test -scw 5 0 5 0 0 0 0 0 0 0 0 0 0 0 0 0 9 | python main.py -f 0 -gn 0 -p MST -tsd images/test -scw 0 5 0 5 0 0 0 0 0 0 0 0 0 0 0 0 10 | python main.py -f 0 -gn 0 -p MST -tsd images/test -scw 0 0 5 5 0 0 0 0 0 0 0 0 0 0 0 0 11 | python main.py -f 0 -gn 0 -p MST -tsd images/test -scw 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 12 | 13 | -------------------------------------------------------------------------------- /train_style.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python main.py -f 1 -gn 0 -p MST -n 5 -b 16 -tsd images/test -scw 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 -sti images/style_crop/0_udnie.jpg 3 | python main.py -f 1 -gn 0 -p MST -n 1 -b 16 -tsd images/test -scw 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 -sti images/style_crop/1_la_muse.jpg 4 | python main.py -f 1 -gn 0 -p MST -n 1 -b 16 -tsd images/test -scw 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 -sti images/style_crop/2_rain_princess.jpg 5 | python main.py -f 1 -gn 0 -p MST -n 1 -b 16 -tsd images/test -scw 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 -sti images/style_crop/3_the_scream.jpg 6 | python main.py -f 1 -gn 0 -p MST -n 1 -b 16 -tsd images/test -scw 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 -sti images/style_crop/4_the_shipwreck_of_the_minotaur.jpg 7 | python main.py -f 1 -gn 0 -p 
MST -n 1 -b 16 -tsd images/test -scw 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 -sti images/style_crop/5_wave.jpg 8 | python main.py -f 1 -gn 0 -p MST -n 1 -b 16 -tsd images/test -scw 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 -sti images/style_crop/6_composition.jpg 9 | python main.py -f 1 -gn 0 -p MST -n 1 -b 16 -tsd images/test -scw 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 -sti images/style_crop/7_guernica.jpg 10 | python main.py -f 1 -gn 0 -p MST -n 1 -b 16 -tsd images/test -scw 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 -sti images/style_crop/8_starry_night.jpg 11 | python main.py -f 1 -gn 0 -p MST -n 1 -b 16 -tsd images/test -scw 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 -sti images/style_crop/9_White_Zig_Zags.jpg 12 | python main.py -f 1 -gn 0 -p MST -n 1 -b 16 -tsd images/test -scw 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 -sti images/style_crop/10_Yellow_sunset.jpg 13 | python main.py -f 1 -gn 0 -p MST -n 1 -b 16 -tsd images/test -scw 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 -sti images/style_crop/11_Three_Fishing_Boats.jpg 14 | python main.py -f 1 -gn 0 -p MST -n 1 -b 16 -tsd images/test -scw 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 -sti images/style_crop/12_The_Annunciation_of_the_Virgin_Deal.jpg 15 | python main.py -f 1 -gn 0 -p MST -n 1 -b 16 -tsd images/test -scw 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 -sti images/style_crop/13_Edith_with_Striped_Dress.jpeg 16 | python main.py -f 1 -gn 0 -p MST -n 1 -b 16 -tsd images/test -scw 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 -sti images/style_crop/14_Colors_from_a_Distance.jpg 17 | python main.py -f 1 -gn 0 -p MST -n 1 -b 16 -tsd images/test -scw 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 -sti images/style_crop/15_Sunrise.jpg --------------------------------------------------------------------------------
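In `build_model` (src/multi_style_transfer.py), the run for style index 0 optimises every trainable variable. Every later run in train_style.sh passes a one-hot `-scw` vector, and the optimiser is limited to variables whose name contains `'{idx}_style'`. Judging from the scopes in src/layers.py, those variables are named like `conv1/cond_in/1_style/shift:0`. A plain substring test for `1_style` therefore also matches `11_style`, `2_style` matches `12_style`, and so on. In this TF 1.x graph that is probably harmless, because a one-hot run never uses the other styles' parameters, so they get no gradient. A stricter filter still states the intent more clearly. The sketch below is a hypothetical helper written for Python 3. It works on plain name strings rather than `tf.Variable` objects and matches the scope component exactly.

```python
from typing import List, Sequence


def select_style_vars(var_names: Sequence[str], style_idx: int) -> List[str]:
    """Hypothetical helper: names of variables belonging to style `style_idx`.

    Matches the '<idx>_style' scope component exactly, so style 1 does not
    also pick up '11_style' variables.
    """
    target = "{0}_style".format(style_idx)
    return [name for name in var_names if target in name.split("/")]


def substring_select(var_names: Sequence[str], style_idx: int) -> List[str]:
    """The substring test used in build_model, reproduced for comparison."""
    target = "{0}_style".format(style_idx)
    return [name for name in var_names if target in name]


if __name__ == "__main__":
    names = [
        "conv1/conv1:0",                   # shared conv weights (no style scope)
        "conv1/cond_in/1_style/shift:0",
        "conv1/cond_in/1_style/scale:0",
        "conv1/cond_in/11_style/shift:0",
        "conv1/cond_in/11_style/scale:0",
    ]
    print(substring_select(names, 1))   # includes the 11_style entries
    print(select_style_vars(names, 1))  # only the 1_style entries
```

Splitting on `/` relies on TensorFlow's `scope/subscope/name:0` naming. If the scopes were renamed, a regular expression anchored on `/` would be an equivalent alternative.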