├── .gitignore
├── README.md
├── images
│   ├── crop.py
│   ├── style
│   │   ├── 0_udnie.jpg
│   │   ├── 10_Yellow_sunset.jpg
│   │   ├── 11_Three_Fishing_Boats.jpg
│   │   ├── 12_The_Annunciation_of_the_Virgin_Deal.jpg
│   │   ├── 13_Edith_with_Striped_Dress.jpeg
│   │   ├── 14_Colors_from_a_Distance.jpg
│   │   ├── 15_Sunrise.jpg
│   │   ├── 1_la_muse.jpg
│   │   ├── 2_rain_princess.jpg
│   │   ├── 3_the_scream.jpg
│   │   ├── 4_the_shipwreck_of_the_minotaur.jpg
│   │   ├── 5_wave.jpg
│   │   ├── 6_composition.jpg
│   │   ├── 7_guernica.jpg
│   │   ├── 8_starry_night.jpg
│   │   └── 9_White_Zig_Zags.jpg
│   ├── style_crop
│   │   ├── 0_udnie.jpg
│   │   ├── 10_Yellow_sunset.jpg
│   │   ├── 11_Three_Fishing_Boats.jpg
│   │   ├── 12_The_Annunciation_of_the_Virgin_Deal.jpg
│   │   ├── 13_Edith_with_Striped_Dress.jpeg
│   │   ├── 14_Colors_from_a_Distance.jpg
│   │   ├── 15_Sunrise.jpg
│   │   ├── 1_la_muse.jpg
│   │   ├── 2_rain_princess.jpg
│   │   ├── 3_the_scream.jpg
│   │   ├── 4_the_shipwreck_of_the_minotaur.jpg
│   │   ├── 5_wave.jpg
│   │   ├── 6_composition.jpg
│   │   ├── 7_guernica.jpg
│   │   ├── 8_starry_night.jpg
│   │   └── 9_White_Zig_Zags.jpg
│   └── test
│       ├── chicago.jpg
│       ├── stata.jpg
│       └── tubingen.jpg
├── main.py
├── result
│   ├── conditional_instance_norm.jpg
│   ├── result.jpg
│   ├── style.jpg
│   ├── style01_01.gif
│   ├── style02_01.gif
│   ├── style03_01.gif
│   ├── style04_01.gif
│   ├── style05_01.gif
│   ├── style06_01.gif
│   ├── style07_01.gif
│   ├── style08_01.gif
│   ├── style09_01.gif
│   ├── style10_01.gif
│   ├── style11_01.gif
│   ├── style12_01.gif
│   ├── style13_01.gif
│   ├── tubingen_10.jpg
│   ├── tubingen_11.jpg
│   ├── tubingen_9.jpg
│   ├── tubingen_9_10.jpg
│   └── tubingen_9_10_11.jpg
├── src
│   ├── __init__.py
│   ├── functions.py
│   ├── layers.py
│   ├── multi_style_transfer.py
│   ├── op.py
│   └── vgg19.py
├── test_style.sh
└── train_style.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | src/*.pyc
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Fast Multi(Interpolative) Style Transfer
2 | Implementation of Google Brain's [A Learned Representation For Artistic Style](https://arxiv.org/pdf/1610.07629v2.pdf) in TensorFlow.
3 | You can mix various types of style images using just one model, and it is still fast!
4 |
5 |
6 |
7 |
8 | Figure 1. Multi-style transfer using a single model. The center image mixes 4 styles.
9 |
10 | This paper builds on [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](https://arxiv.org/abs/1603.08155)
11 | and [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
12 | Those methods are fast and give good results, but each model can only produce a single style.
13 |
14 |
15 | ## Usage
16 | We recommend downloading the project files (model, VGG, images, etc.) from [OneDrive](https://1drv.ms/f/s!ArFpOdlDcjqQga8fwL0m4VQGmgKSfg) / [Dropbox](https://www.dropbox.com/sh/b3by1ipmr0v821y/AABJ4gadaOk9RRsqsOTC336Xa?dl=0), and downloading [COCO](http://mscoco.org/dataset/#download) into your data folder. Example command lines are given below and in train_style.sh and test_style.sh.
17 |
18 | #### Project folder tree
19 | Working Directory
20 | ├── MST
21 | │ ├── models
22 | │ │ ├── checkpoint
23 | │ │ ├── xxx.index
24 | │ │ ├── xxx.data
25 | │ │ └── xxx.meta
26 | │ ├── test_result
27 | │ └── train_result
28 | ├── images
29 | │ ├── style
30 | │ ├── style_crop
31 | │ ├── test
32 | │ └── crop.py
33 | ├── src
34 | │ ├── vgg19.mat
35 | │ └── ...
36 | ├── main.py
37 | ├── test_style.sh
38 | └── train_style.sh
39 |
40 |
41 |
42 | #### Style Control Weight (SCW)
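43 | "-scw, --style_control_weights" is the style control argument. Its 16 values (written "0 0 0 ... 0 0 0" below) are the weights of "style1 style2 ... style16", in that order. They are passed as 16 space-separated numbers, and the model divides them by their sum, so only their ratios matter.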
44 |
45 | For a single style:
46 |
47 | style1 -scw "1 0 0 ... 0 0 0"
48 | style16 -scw "0 0 0 ... 0 0 1"
49 |
50 | For multiple styles:
51 |
52 | 0.5 * style1 + 0.5 * style2 -scw "0.5 0.5 0 ... 0 0 0" or "1 1 0 ... 0 0 0"
53 | 0.2 * style1 + 0.3 * style2 + 0.5 * style3 -scw "0.2 0.3 0.5 ... 0 0 0" or "2 3 5 ... 0 0 0"
54 | 1/16 * (style1 ~ style16) -scw "0.0625 0.0625 ... 0.0625 0.0625" or "1 1 1 ... 1 1 1"
55 |
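As a sanity check on the table above, the sketch below builds the 16 `-scw` values from a `{style_number: weight}` mapping (style numbers 1-based, as in the table) and computes the per-style fractions the model effectively uses after dividing by the sum of the weights. `make_scw` and `mixing_fractions` are hypothetical helpers, not part of this repo.

```python
NUM_STYLES = 16  # must match nargs of -scw in main.py


def make_scw(weights, num_styles=NUM_STYLES):
    """Build the -scw values from a {style_number: weight} dict (1-based)."""
    scw = [0.0] * num_styles
    for style_number, weight in weights.items():
        if not 1 <= style_number <= num_styles:
            raise ValueError("style number out of range: %d" % style_number)
        scw[style_number - 1] = float(weight)
    return scw


def mixing_fractions(scw):
    """Divide the weights by their sum, as conditional_instance_norm does."""
    total = sum(scw)
    if total == 0:
        raise ValueError("at least one style weight must be non-zero")
    return [w / total for w in scw]


if __name__ == "__main__":
    scw = make_scw({1: 2, 2: 3, 3: 4})
    print("-scw " + " ".join("%g" % w for w in scw))
    print(mixing_fractions(scw)[:3])
```

The first line printed is `-scw 2 3 4 0 0 0 0 0 0 0 0 0 0 0 0 0`, and the fractions are 2/9, 3/9 and 4/9. Weights that do not sum to 1 (for example "0.2 0.3 0.4") are rescaled the same way, so they also mix as 2/9, 3/9 and 4/9.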
56 |
57 | ### Train
58 | #### From Scratch.
59 |
60 | python main.py -f 1 -gn 0 -p MST -n 10 -b 16 \
61 | -tsd images/test -sti images/style_crop/0_udnie.jpg \
62 | -ctd /mnt/cloud/Data/COCO/train2014 \
63 | -scw 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
64 |
65 | This trains the weights, biases, gamma_0 and beta_0, and needs 40,000 iterations (10 epochs).
66 |
67 |
68 |
69 |
70 | #### Fine-tuning (after training from scratch or downloading a trained model)
71 |
72 | python main.py -f 1 -gn 0 -p MST -n 1 -b 16 \
73 | -tsd images/test -sti images/style_crop/1_la_muse.jpg \
74 | -ctd /mnt/cloud/Data/COCO/train2014 \
75 | -scw 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0
76 |
77 | This trains only gamma_i and beta_i, and needs just 4,000 iterations (1 epoch, 1/10 of training from scratch).
78 |
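Restricting training to gamma_i and beta_i amounts to handing the optimizer only those variables; in TensorFlow 1.x this is typically done through the `var_list` argument of `Optimizer.minimize`. The sketch below covers just the filtering step on variable names. It assumes the `'{i}_style/shift'` and `'{i}_style/scale'` names created by `conditional_instance_norm` (see Implementation Details); `style_param_names` and the example names are hypothetical, and this is not necessarily how the repo implements fine-tuning. Note that `i` is 0-based there, so the second `-scw` slot (used with 1_la_muse.jpg above) maps to the `1_style` scope.

```python
import re


def style_param_names(var_names, style_idx):
    """Return the shift/scale variable names that belong to one style.

    style_idx is the 0-based index used in the '{i}_style' variable scope.
    """
    # Anchor the scope at a path boundary so that style 1 does not also
    # match '11_style', '21_style', ...
    pattern = re.compile(r"(^|/)%d_style/(shift|scale)(:\d+)?$" % style_idx)
    return [name for name in var_names if pattern.search(name)]


if __name__ == "__main__":
    names = [
        "up_conv1/up_conv1:0",  # convolution weights (hypothetical name)
        "up_conv1/1_style/shift:0",
        "up_conv1/1_style/scale:0",
        "up_conv1/11_style/shift:0",
    ]
    print(style_param_names(names, 1))
```

In a TF1 graph, the input list would come from `[v.name for v in tf.trainable_variables()]`.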
79 | You can see the images gradually change to the new style.
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 | #### Number of styles
98 | If you want to train a 32-style model, change the number of -scw values from 16 to 32, both in main.py and on the command line:
99 |
100 | 1. parser.add_argument("-scw", "--style_control_weights", type=float, nargs=16 --> 32)
101 | 2. -scw 1 2 3 ... 14 15 16 --> -scw 1 2 3 ... 30 31 32
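Rather than editing the literal in two places, the count can come from a single constant. This is only a sketch: `NUM_STYLES` and `build_parser` are hypothetical and not part of main.py. With `nargs` set this way, argparse rejects any command line that does not supply exactly that many weights.

```python
import argparse

NUM_STYLES = 32  # hypothetical single source of truth; main.py hard-codes 16


def build_parser(num_styles=NUM_STYLES):
    parser = argparse.ArgumentParser()
    # nargs must equal the number of styles the model was built with
    parser.add_argument("-scw", "--style_control_weights", type=float, nargs=num_styles)
    return parser


if __name__ == "__main__":
    values = ["0"] * NUM_STYLES
    values[-1] = "1"  # style32 only
    args = build_parser().parse_args(["-scw"] + values)
    print(len(args.style_control_weights))  # 32
```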
102 |
103 |
104 | ### Test
105 | #### Single style
106 |
107 | ex) style9
108 | python main.py -f 0 -gn 0 -p MST \
109 | -tsd images/test \
110 | -scw 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0
111 |
112 |
113 |
114 |
115 |
116 |
117 | ex) style10
118 | python main.py -f 0 -gn 0 -p MST \
119 | -tsd images/test \
120 | -scw 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0
121 |
122 |
123 |
124 |
125 |
126 |
127 | #### Multi Style
128 |
129 | ex) 0.5*style9 + 0.5*style10
130 | python main.py -f 0 -gn 0 -p MST \
131 | -tsd images/test \
132 | -scw 0 0 0 0 0 0 0 0 0 0.5 0.5 0 0 0 0 0
133 |
134 |
135 |
136 |
137 |
138 | ex) 0.33*style9 + 0.33*style10 + 0.33*style11
139 | python main.py -f 0 -gn 0 -p MST \
140 | -tsd images/test \
141 | -scw 0 0 0 0 0 0 0 0 0 0.33 0.33 0.33 0 0 0 0
142 |
143 |
144 |
145 |
146 |
147 | ## Implementation Details
148 | #### Conditional instance normalization
149 |
150 | The key idea of this paper is conditional instance normalization.
151 |
152 |
153 |
154 |
155 |
156 | Instance normalization is similar to batch normalization, but it computes the mean (mu) and variance (sigma^2) per image and per channel, and does not accumulate running statistics.
157 | Conditional instance normalization has N scales (gamma) and N shifts (beta), where N is the number of style images.
158 | This means that when you add a new style, you only need to train a new gamma and a new beta.
159 | See the code below.
160 |
161 | def conditional_instance_norm(net, style_control=None):
162 |     ...
163 |     shift = []
164 |     scale = []
165 |
166 |     for i in range(len(style_control)):  # one (shift, scale) pair per style
167 |         with tf.variable_scope('{0}'.format(i) + '_style'):
168 |             shift.append(tf.get_variable('shift', shape=var_shape, initializer=tf.constant_initializer(0.)))
169 |             scale.append(tf.get_variable('scale', shape=var_shape, initializer=tf.constant_initializer(1.)))
170 |     ...
171 |
172 |     idx = [i for i, x in enumerate(style_control) if not x == 0]  # styles with a non-zero weight
173 |     style_scale = reduce(tf.add, [scale[i]*style_control[i] for i in idx]) / sum(style_control)
174 |     style_shift = reduce(tf.add, [shift[i]*style_control[i] for i in idx]) / sum(style_control)
175 |     output = style_scale * normalized + style_shift  # weighted-average gamma and beta
176 |     return output
177 |
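To make the computation concrete, here is a framework-free NumPy sketch of the same idea. It assumes NHWC input and `eps=1e-5`, and it is an illustration rather than the repo's TensorFlow code: each image and channel is normalized over its spatial dimensions, then scaled and shifted by the `-scw`-weighted average of the per-style gammas and betas, matching the division by `sum(style_control)` above. Skipping zero-weight styles, as the `idx` list does, gives the same result, since their terms are zero.

```python
import numpy as np


def conditional_instance_norm(x, gammas, betas, style_weights, eps=1e-5):
    """NumPy sketch of conditional instance normalization.

    x:             activations, shape (batch, height, width, channels)
    gammas, betas: per-style scale and shift, shape (num_styles, channels)
    style_weights: one mixing weight per style, e.g. the -scw values
    """
    # Instance norm: mean/variance per image and per channel, over H and W only.
    mu = x.mean(axis=(1, 2), keepdims=True)
    var = x.var(axis=(1, 2), keepdims=True)
    normalized = (x - mu) / np.sqrt(var + eps)

    # Blend the per-style parameters, dividing by the sum of the weights.
    w = np.asarray(style_weights, dtype=np.float64)
    w = w / w.sum()
    gamma = w.dot(gammas)  # shape (channels,), broadcasts over the last axis
    beta = w.dot(betas)
    return gamma * normalized + beta


if __name__ == "__main__":
    rng = np.random.RandomState(0)
    x = rng.randn(2, 8, 8, 3) * 5.0 + 7.0
    gammas = np.array([[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]])
    betas = np.array([[0.0, 0.0, 0.0], [2.0, 2.0, 2.0]])
    y = conditional_instance_norm(x, gammas, betas, [1, 1])  # 50/50 mix
    print(y.mean(axis=(1, 2)).round(3))  # about 1.0 everywhere (mean of betas)
    print(y.std(axis=(1, 2)).round(3))   # about 2.0 everywhere (mean of gammas)
```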
178 | #### Upsampling
179 | The paper's upsampling method is "Image_resize-Conv", but I use ["Deconv-Pooling"](https://arxiv.org/abs/1611.04994) instead.
180 |
181 | def mst_net(x, style_control=None):
182 |     ...
183 |     x = conv_tranpose_layer(x, 64, 3, 2, style_control=style_control, name='up_conv1')
184 |     x = pooling(x)
185 |     x = conv_tranpose_layer(x, 32, 3, 2, style_control=style_control, name='up_conv2')
186 |     x = pooling(x)
187 |     ...
188 |
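For comparison, the paper's resize-then-convolve upsampling enlarges the feature map with a fixed interpolation and only then applies an ordinary stride-1 convolution, instead of learning the upsampling inside a transposed convolution. The resize step is sketched below in NumPy, assuming nearest-neighbor interpolation and NHWC layout; `nearest_upsample` is a hypothetical helper, and the convolution that would follow it is omitted.

```python
import numpy as np


def nearest_upsample(x, factor=2):
    """Nearest-neighbor resize of an NHWC array by an integer factor."""
    # Repeat every row, then every column, `factor` times.
    return x.repeat(factor, axis=1).repeat(factor, axis=2)


if __name__ == "__main__":
    x = np.arange(4).reshape(1, 2, 2, 1)
    print(nearest_upsample(x)[0, :, :, 0])
```

Each input pixel becomes a `factor x factor` block, so the 2x2 example grows to 4x4 without any learned weights.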
189 |
190 | ## Requirements
191 | - TensorFlow 1.0.0
192 | - Python 2.7.12, Pillow 3.4.2, scipy 0.18.1, numpy 1.11.2
193 |
194 | ## Attributions/Thanks
195 | This project borrows some code from [lengstrom's fast-style-transfer](https://github.com/lengstrom/fast-style-transfer).
196 | Google Brain's code is available [here](https://github.com/tensorflow/magenta) (requires some installation).
197 |
--------------------------------------------------------------------------------
/images/crop.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import scipy.misc as scm  # imread/imresize/imsave: needs the old SciPy + Pillow listed in Requirements
4 |
5 | dataset = 'style'       # folder with the original style images
6 | saveset = 'style_crop'  # output folder for the square crops
7 | size = 512              # side length of the crop, in pixels
8 |
9 | for fn in os.listdir(dataset):
10 |     print fn
11 |     img = scm.imread(os.path.join(dataset, fn), mode='RGB')  # force 3 channels
12 |     h, w, c = np.shape(img)  # numpy images are (rows, cols, channels)
13 |     print h, w
14 |
15 |     # Resize so the shorter side equals `size`, keeping the aspect ratio.
16 |     # imresize takes the target size as (height, width).
17 |     if h >= w:
18 |         resize_factor = (int(size * float(h) / w), size)
19 |     else:
20 |         resize_factor = (size, int(size * float(w) / h))
21 |     img_resize = scm.imresize(img, resize_factor)
22 |     # scm.imsave(os.path.join(saveset, 'resize_' + fn), img_resize)
23 |
24 |     # Center-crop a size x size square.
25 |     h, w, c = np.shape(img_resize)
26 |     crop_h = int((h - size) * 0.5)
27 |     crop_w = int((w - size) * 0.5)
28 |     print crop_h, crop_w
29 |     img_crop = img_resize[crop_h:crop_h + size, crop_w:crop_w + size, :]
30 |     scm.imsave(os.path.join(saveset, fn), img_crop)
31 |
32 |
33 |
34 |
--------------------------------------------------------------------------------
/images/style/0_udnie.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style/0_udnie.jpg
--------------------------------------------------------------------------------
/images/style/10_Yellow_sunset.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style/10_Yellow_sunset.jpg
--------------------------------------------------------------------------------
/images/style/11_Three_Fishing_Boats.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style/11_Three_Fishing_Boats.jpg
--------------------------------------------------------------------------------
/images/style/12_The_Annunciation_of_the_Virgin_Deal.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style/12_The_Annunciation_of_the_Virgin_Deal.jpg
--------------------------------------------------------------------------------
/images/style/13_Edith_with_Striped_Dress.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style/13_Edith_with_Striped_Dress.jpeg
--------------------------------------------------------------------------------
/images/style/14_Colors_from_a_Distance.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style/14_Colors_from_a_Distance.jpg
--------------------------------------------------------------------------------
/images/style/15_Sunrise.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style/15_Sunrise.jpg
--------------------------------------------------------------------------------
/images/style/1_la_muse.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style/1_la_muse.jpg
--------------------------------------------------------------------------------
/images/style/2_rain_princess.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style/2_rain_princess.jpg
--------------------------------------------------------------------------------
/images/style/3_the_scream.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style/3_the_scream.jpg
--------------------------------------------------------------------------------
/images/style/4_the_shipwreck_of_the_minotaur.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style/4_the_shipwreck_of_the_minotaur.jpg
--------------------------------------------------------------------------------
/images/style/5_wave.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style/5_wave.jpg
--------------------------------------------------------------------------------
/images/style/6_composition.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style/6_composition.jpg
--------------------------------------------------------------------------------
/images/style/7_guernica.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style/7_guernica.jpg
--------------------------------------------------------------------------------
/images/style/8_starry_night.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style/8_starry_night.jpg
--------------------------------------------------------------------------------
/images/style/9_White_Zig_Zags.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style/9_White_Zig_Zags.jpg
--------------------------------------------------------------------------------
/images/style_crop/0_udnie.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style_crop/0_udnie.jpg
--------------------------------------------------------------------------------
/images/style_crop/10_Yellow_sunset.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style_crop/10_Yellow_sunset.jpg
--------------------------------------------------------------------------------
/images/style_crop/11_Three_Fishing_Boats.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style_crop/11_Three_Fishing_Boats.jpg
--------------------------------------------------------------------------------
/images/style_crop/12_The_Annunciation_of_the_Virgin_Deal.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style_crop/12_The_Annunciation_of_the_Virgin_Deal.jpg
--------------------------------------------------------------------------------
/images/style_crop/13_Edith_with_Striped_Dress.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style_crop/13_Edith_with_Striped_Dress.jpeg
--------------------------------------------------------------------------------
/images/style_crop/14_Colors_from_a_Distance.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style_crop/14_Colors_from_a_Distance.jpg
--------------------------------------------------------------------------------
/images/style_crop/15_Sunrise.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style_crop/15_Sunrise.jpg
--------------------------------------------------------------------------------
/images/style_crop/1_la_muse.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style_crop/1_la_muse.jpg
--------------------------------------------------------------------------------
/images/style_crop/2_rain_princess.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style_crop/2_rain_princess.jpg
--------------------------------------------------------------------------------
/images/style_crop/3_the_scream.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style_crop/3_the_scream.jpg
--------------------------------------------------------------------------------
/images/style_crop/4_the_shipwreck_of_the_minotaur.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style_crop/4_the_shipwreck_of_the_minotaur.jpg
--------------------------------------------------------------------------------
/images/style_crop/5_wave.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style_crop/5_wave.jpg
--------------------------------------------------------------------------------
/images/style_crop/6_composition.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style_crop/6_composition.jpg
--------------------------------------------------------------------------------
/images/style_crop/7_guernica.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style_crop/7_guernica.jpg
--------------------------------------------------------------------------------
/images/style_crop/8_starry_night.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style_crop/8_starry_night.jpg
--------------------------------------------------------------------------------
/images/style_crop/9_White_Zig_Zags.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/style_crop/9_White_Zig_Zags.jpg
--------------------------------------------------------------------------------
/images/test/chicago.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/test/chicago.jpg
--------------------------------------------------------------------------------
/images/test/stata.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/test/stata.jpg
--------------------------------------------------------------------------------
/images/test/tubingen.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/images/test/tubingen.jpg
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import distutils.util
3 | import tensorflow as tf
4 | import src.multi_style_transfer as mst
5 |
6 |
7 | def main():
8 | parser = argparse.ArgumentParser()
9 |
10 | # Train ############################################################################################################
11 | parser.add_argument("-f", "--flag", type=distutils.util.strtobool, default='True')  # 1: train, 0: test
12 | parser.add_argument("-gn", "--gpu_number", type=int, default=0)
13 | parser.add_argument("-p", "--project", type=str, default="mst")
14 |
15 | ## Train images
16 | parser.add_argument("-ctd", "--content_dataset", type=str, default="/mnt/cloud/Data/COCO/all")
17 | parser.add_argument("-cts", "--content_data_size", type=int, default=256)
18 | parser.add_argument("-sti", "--style_image", type=str, default="images/style/0_udnie.jpg")
19 |
20 | ## Train Iteration
21 | parser.add_argument("-n", "--niter", type=int, default=2)  # number of training epochs
22 | parser.add_argument("-ns", "--nsnapshot", type=int, default=100)
23 | parser.add_argument("-mx", "--max_to_keep", type=int, default=10)
24 |
25 | ## Train Parameter
26 | parser.add_argument("-b", "--batch_size", type=int, default=16)
27 | parser.add_argument("-lr", "--learning_rate", type=float, default=1e-3)
28 | parser.add_argument("-m", "--momentum", type=float, default=0.9)
29 | parser.add_argument("-m2", "--momentum2", type=float, default=0.999)
30 |
31 | ## Loss weights
32 | parser.add_argument("-lc", "--content_loss_weights", type=float, default=1.5e0)
33 | parser.add_argument("-ls", "--style_loss_weights", type=float, default=1e2)
34 | parser.add_argument("-lt", "--tv_loss_weight", type=float, default=2e2)
35 |
36 | # Test #############################################################################################################
37 | parser.add_argument("-tsd", "--test_dataset", type=str, default="images/test")
38 | parser.add_argument("-scw", "--style_control_weights", type=float, nargs=16)  # one weight per style (see README)
39 |
40 | args = parser.parse_args()
41 | gpu_number = args.gpu_number
42 | train_flag = args.flag
43 |
44 | with tf.device('/gpu:{}'.format(gpu_number)):
45 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.85)
46 | config = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)
47 |
48 | with tf.Session(config=config) as sess:
49 | ## Make Model
50 | model = mst.mst(args, sess)
51 |
52 | ## TRAIN / TEST
53 | if train_flag:
54 | model.train(train_flag)
55 | else:
56 | model.test(train_flag)
57 |
58 |
59 | if __name__ == '__main__':
60 | main()
61 |
62 |
--------------------------------------------------------------------------------
/result/conditional_instance_norm.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/conditional_instance_norm.jpg
--------------------------------------------------------------------------------
/result/result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/result.jpg
--------------------------------------------------------------------------------
/result/style.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/style.jpg
--------------------------------------------------------------------------------
/result/style01_01.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/style01_01.gif
--------------------------------------------------------------------------------
/result/style02_01.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/style02_01.gif
--------------------------------------------------------------------------------
/result/style03_01.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/style03_01.gif
--------------------------------------------------------------------------------
/result/style04_01.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/style04_01.gif
--------------------------------------------------------------------------------
/result/style05_01.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/style05_01.gif
--------------------------------------------------------------------------------
/result/style06_01.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/style06_01.gif
--------------------------------------------------------------------------------
/result/style07_01.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/style07_01.gif
--------------------------------------------------------------------------------
/result/style08_01.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/style08_01.gif
--------------------------------------------------------------------------------
/result/style09_01.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/style09_01.gif
--------------------------------------------------------------------------------
/result/style10_01.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/style10_01.gif
--------------------------------------------------------------------------------
/result/style11_01.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/style11_01.gif
--------------------------------------------------------------------------------
/result/style12_01.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/style12_01.gif
--------------------------------------------------------------------------------
/result/style13_01.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/style13_01.gif
--------------------------------------------------------------------------------
/result/tubingen_10.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/tubingen_10.jpg
--------------------------------------------------------------------------------
/result/tubingen_11.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/tubingen_11.jpg
--------------------------------------------------------------------------------
/result/tubingen_9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/tubingen_9.jpg
--------------------------------------------------------------------------------
/result/tubingen_9_10.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/tubingen_9_10.jpg
--------------------------------------------------------------------------------
/result/tubingen_9_10_11.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/result/tubingen_9_10_11.jpg
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/Fast_Multi_Style_Transfer-tensorflow/dc74be72f9a924eaaf482188277b2bd94b64b1ff/src/__init__.py
--------------------------------------------------------------------------------
/src/functions.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import scipy.misc as scm
4 | from glob import glob
5 |
6 |
7 | def get_image(img_path, size=None):
8 |     img = scm.imread(img_path, mode='RGB')
9 |     h, w, c = np.shape(img)
10 |
11 |     img = img[:h, :w, ::-1]  # rgb to bgr
12 |     if size:
13 |         img = scm.imresize(img, (size, size))
14 |     return img
15 |
16 |
17 | def inverse_image(img):
18 |     img = np.clip(img, 0, 255)  # clamp without modifying the caller's array
19 |     img = img.astype(np.uint8)  # uint8, so scm.imsave does not min/max-rescale the image
20 |     img = img[:, :, ::-1]  # bgr to rgb
21 |     return img
22 |
23 |
24 | def make_project_dir(project_dir):
25 |     if not os.path.exists(project_dir):
26 |         os.makedirs(project_dir)
27 |         os.makedirs(os.path.join(project_dir, 'models'))
28 |         os.makedirs(os.path.join(project_dir, 'train_result'))
29 |         os.makedirs(os.path.join(project_dir, 'test_result'))
30 |
31 |
32 | def data_loader(dataset):
33 |     print('images Load ....')
34 |     data_path = dataset
35 |
36 |     if os.path.exists(data_path + '.npy'):
37 |         data = np.load(data_path + '.npy')
38 |     else:
39 |         data = glob(os.path.join(data_path, "*.*"))
40 |         np.save(data_path + '.npy', data)
41 |     print('images Load Done')
42 |
43 |     return data
44 |
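`get_image` reverses the channel axis (RGB to BGR) when an image is read. `inverse_image` clamps the network output and reverses the channels back. The clamp matters because `mst_net` ends with `tanh(x) * 150 + 255./2`, which spans roughly -22.5 to 277.5.

Below is a minimal NumPy sketch of that round trip, assuming an H×W×3 array. `to_bgr` and `to_rgb_uint8` are hypothetical helper names and do not exist in this repository.

```python
import numpy as np


def to_bgr(rgb):
    """Reverse the channel axis, RGB -> BGR, as get_image does after reading."""
    return rgb[:, :, ::-1]


def to_rgb_uint8(bgr):
    """Clamp to [0, 255], cast to uint8 and reverse the channels back to RGB."""
    return np.clip(bgr, 0, 255).astype(np.uint8)[:, :, ::-1]


rgb = np.zeros((2, 2, 3), dtype=np.uint8)
rgb[..., 0] = 200                        # pure red in RGB order
bgr = to_bgr(rgb)
print(bgr[0, 0].tolist())                # [0, 0, 200]

out = bgr.astype(np.float32)
out[0, 0, 2] = 277.0                     # overshoot, as tanh * 150 + 127.5 can produce
print(to_rgb_uint8(out)[0, 0].tolist())  # [255, 0, 0]
```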
--------------------------------------------------------------------------------
/src/layers.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def conv_layer(net, num_filters, filter_size, strides, style_control=None, relu=True, name='conv'):
5 |     with tf.variable_scope(name):
6 |         b, h, w, c = net.get_shape().as_list()
7 |         weights_shape = [filter_size, filter_size, c, num_filters]
8 |         weights_init = tf.get_variable(name, shape=weights_shape, initializer=tf.truncated_normal_initializer(stddev=.01))
9 |         strides_shape = [1, strides, strides, 1]
10 |
11 |         p = (filter_size - 1) // 2  # integer division: tf.pad needs int paddings
12 |         if strides == 1:
13 |             net = tf.pad(net, [[0, 0], [p, p], [p, p], [0, 0]], "REFLECT")
14 |             net = tf.nn.conv2d(net, weights_init, strides_shape, padding="VALID")
15 |         else:
16 |             net = tf.nn.conv2d(net, weights_init, strides_shape, padding="SAME")
17 |
18 |         net = conditional_instance_norm(net, style_control=style_control)
19 |         if relu:
20 |             net = tf.nn.relu(net)
21 |
22 |         return net
23 |
24 |
25 | def conv_tranpose_layer(net, num_filters, filter_size, strides, style_control=None, name='conv_t'):
26 |     with tf.variable_scope(name):
27 |         b, h, w, c = net.get_shape().as_list()
28 |         weights_shape = [filter_size, filter_size, num_filters, c]
29 |         weights_init = tf.get_variable(name, shape=weights_shape, initializer=tf.truncated_normal_initializer(stddev=.01))
30 |
31 |         batch_size, rows, cols, in_channels = [i.value for i in net.get_shape()]
32 |         new_rows, new_cols = int(rows * strides), int(cols * strides)
33 |         # the output shape is static: the spatial size grows by a factor of `strides`
34 |
35 |         new_shape = [batch_size, new_rows, new_cols, num_filters]
36 |         tf_shape = tf.stack(new_shape)
37 |         strides_shape = [1, strides, strides, 1]
38 |
39 |         p = (filter_size - 1) // 2
40 |         if strides == 1:
41 |             net = tf.pad(net, [[0, 0], [p, p], [p, p], [0, 0]], "REFLECT")
42 |             net = tf.nn.conv2d_transpose(net, weights_init, tf_shape, strides_shape, padding="VALID")
43 |         else:
44 |             net = tf.nn.conv2d_transpose(net, weights_init, tf_shape, strides_shape, padding="SAME")
45 |         net = conditional_instance_norm(net, style_control=style_control)
46 |
47 |         return tf.nn.relu(net)
48 |
49 |
50 | def residual_block(net, filter_size=3, style_control=None, name='res'):
51 |     with tf.variable_scope(name + '_a'):
52 |         tmp = conv_layer(net, 128, filter_size, 1, style_control=style_control)
53 |     with tf.variable_scope(name + '_b'):
54 |         output = net + conv_layer(tmp, 128, filter_size, 1, style_control=style_control, relu=False)
55 |     return output
56 |
57 |
58 | def conditional_instance_norm(net, style_control=None, name='cond_in'):
59 |     with tf.variable_scope(name):
60 |         batch, rows, cols, channels = [i.value for i in net.get_shape()]
61 |         mu, sigma_sq = tf.nn.moments(net, [1, 2], keep_dims=True)
62 |
63 |         var_shape = [channels]
64 |         shift = []
65 |         scale = []
66 |
67 |         for i in range(len(style_control)):
68 |             with tf.variable_scope('{0}'.format(i) + '_style'):
69 |                 shift.append(tf.get_variable('shift', shape=var_shape, initializer=tf.constant_initializer(0.)))
70 |                 scale.append(tf.get_variable('scale', shape=var_shape, initializer=tf.constant_initializer(1.)))
71 |         epsilon = 1e-3
72 |         normalized = (net - mu) / (sigma_sq + epsilon)**(.5)
73 |
74 |         idx = [i for i, x in enumerate(style_control) if not x == 0]
75 |         # weighted average of the active styles' scale/shift; tf.add_n replaces the Python 2-only builtin reduce
76 |         style_scale = tf.add_n([scale[i] * style_control[i] for i in idx]) / sum(style_control)
77 |         style_shift = tf.add_n([shift[i] * style_control[i] for i in idx]) / sum(style_control)
78 |         output = style_scale * normalized + style_shift
79 |
80 |         return output
81 |
82 |
83 | def instance_norm(net, train=True, name='in'):
84 |     with tf.variable_scope(name):
85 |         batch, rows, cols, channels = [i.value for i in net.get_shape()]
86 |         var_shape = [channels]
87 |         mu, sigma_sq = tf.nn.moments(net, [1, 2], keep_dims=True)
88 |         shift = tf.get_variable('shift', shape=var_shape, initializer=tf.constant_initializer(0.))
89 |         scale = tf.get_variable('scale', shape=var_shape, initializer=tf.constant_initializer(1.))
90 |         epsilon = 1e-3
91 |         normalized = (net - mu) / (sigma_sq + epsilon)**(.5)
92 |         return scale * normalized + shift
93 |
94 | def pooling(input):
95 |     return tf.nn.avg_pool(input, ksize=(1, 2, 2, 1), strides=(1, 1, 1, 1), padding='SAME')
96 |
97 |
98 | def total_variation(preds):
99 |     # total variation loss: squared differences between vertically and horizontally adjacent pixels
100 |     b, h, w, c = preds.get_shape().as_list()
101 |     y_tv = tf.nn.l2_loss(preds[:, 1:, :, :] - preds[:, :h - 1, :, :])
102 |     x_tv = tf.nn.l2_loss(preds[:, :, 1:, :] - preds[:, :, :w - 1, :])
103 |     tv_loss = 2 * (x_tv + y_tv) / b / h / w / c
104 |     return tv_loss
105 |
106 |
107 | def euclidean_loss(input_, target_):
108 |     b, h, w, c = input_.get_shape().as_list()
109 |     return 2 * tf.nn.l2_loss(input_ - target_) / b / h / w / c
110 |
111 |
112 | def gram_matrix(net):
113 |     b, h, w, c = net.get_shape().as_list()
114 |     feats = tf.reshape(net, (b, h * w, c))
115 |     feats_T = tf.transpose(feats, perm=[0, 2, 1])
116 |     grams = tf.matmul(feats_T, feats) / h / w / c
117 |     return grams
118 |
119 |
120 | def style_loss(input_, style_):
121 |     b, h, w, c = input_.get_shape().as_list()
122 |     input_gram = gram_matrix(input_)
123 |     style_gram = gram_matrix(style_)
124 |     return 2 * tf.nn.l2_loss(input_gram - style_gram) / b / c / c
125 |
126 |
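`conditional_instance_norm` works in two steps:

- It normalizes each image and channel over the spatial axes.
- It then applies a scale and shift. Each is a weighted average of the per-style `<i>_style/scale` and `<i>_style/shift` variables, weighted by `style_control` and divided by its sum.

The NumPy sketch below restates that computation outside TensorFlow. It assumes NHWC arrays and one (scale, shift) row per style; `conditional_instance_norm_np` is a hypothetical name. Zero weights add nothing to the weighted sum, so the sketch can skip the explicit `idx` filter.

```python
import numpy as np


def conditional_instance_norm_np(x, scales, shifts, style_control, eps=1e-3):
    """Conditional instance norm on an NHWC array.

    x:             (B, H, W, C) activations
    scales/shifts: (N, C) per-style gamma and beta
    style_control: N non-negative weights, e.g. [1, 0] or [5, 5]
    """
    mu = x.mean(axis=(1, 2), keepdims=True)   # per image, per channel
    var = x.var(axis=(1, 2), keepdims=True)   # population variance, like tf.nn.moments
    normalized = (x - mu) / np.sqrt(var + eps)

    w = np.asarray(style_control, dtype=np.float64)
    gamma = (w[:, None] * scales).sum(axis=0) / w.sum()  # weighted average over styles
    beta = (w[:, None] * shifts).sum(axis=0) / w.sum()
    return gamma * normalized + beta


rng = np.random.RandomState(0)
x = rng.normal(size=(1, 4, 4, 2))
scales = np.array([[1.0, 1.0], [3.0, 3.0]])  # style 0 and style 1 gammas
shifts = np.array([[0.0, 0.0], [2.0, 2.0]])  # style 0 and style 1 betas

y_half = conditional_instance_norm_np(x, scales, shifts, [5, 5])    # 50/50 blend: gamma 2, beta 1
y_style0 = conditional_instance_norm_np(x, scales, shifts, [1, 0])  # style 0 only
```

Because the weights are divided by their sum, `[5, 5]` and `[1, 1]` produce identical outputs.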
--------------------------------------------------------------------------------
/src/multi_style_transfer.py:
--------------------------------------------------------------------------------
1 | import scipy.io
2 | import numpy as np
3 | import src.vgg19 as vgg
4 | import tensorflow as tf
5 | from src.op import op
6 | from src.layers import *
7 | from src.functions import *
8 |
9 |
10 | class mst(op):
11 |     def __init__(self, args, sess):
12 |         op.__init__(self, args, sess)
13 |
14 |     def mst_net(self, x, style_control=None, reuse=False):
15 |         with tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
16 |             b, h, w, c = x.get_shape().as_list()
17 |
18 |             x = conv_layer(x, 32, 9, 1, style_control=style_control, name='conv1')
19 |             x = conv_layer(x, 64, 3, 2, style_control=style_control, name='conv2')
20 |             x = conv_layer(x, 128, 3, 2, style_control=style_control, name='conv3')
21 |             x = residual_block(x, 3, style_control=style_control, name='res1')
22 |             x = residual_block(x, 3, style_control=style_control, name='res2')
23 |             x = residual_block(x, 3, style_control=style_control, name='res3')
24 |             x = residual_block(x, 3, style_control=style_control, name='res4')
25 |             x = residual_block(x, 3, style_control=style_control, name='res5')
26 |             x = conv_tranpose_layer(x, 64, 3, 2, style_control=style_control, name='up_conv1')
27 |             x = pooling(x)
28 |             x = conv_tranpose_layer(x, 32, 3, 2, style_control=style_control, name='up_conv2')
29 |             x = pooling(x)
30 |             x = conv_layer(x, 3, 9, 1, relu=False, style_control=style_control, name='output')
31 |             preds = tf.nn.tanh(x) * 150 + 255./2
32 |             return preds
33 |
34 |
35 |     def build_model(self):
36 |         # Set content / style input
37 |         # content_input
38 |         b, h, w = self.batch_size, self.content_data_size, self.content_data_size
39 |         self.content_input = tf.placeholder(tf.float32, shape=[b, h, w, 3], name='content_input')
40 |
41 |         # style_input
42 |         style_img = get_image(self.style_image)
43 |         style_idx = [i for i, x in enumerate(self.style_control) if not x == 0][0]  # first style with a non-zero weight
44 |         print('style_idx : {}'.format(style_idx))
45 |         style_input = tf.constant((style_img[np.newaxis, ...]), dtype=tf.float32)
46 |
47 |         # MST_output (Pastiche)
48 |         MST_output = self.mst_net(self.content_input, style_control=self.style_control)
49 |
50 |         # VGG network
51 |         weights = scipy.io.loadmat('src/vgg19.mat')
52 |         vgg_mean = tf.constant(np.array([103.939, 116.779, 123.68]).reshape((1, 1, 1, 3)), dtype='float32')
53 |
54 |         content_feats = vgg.net(self.content_input - vgg_mean, weights)
55 |         style_feats = vgg.net(style_input - vgg_mean, weights)
56 |         MST_feats = vgg.net(MST_output - vgg_mean, weights)
57 |
58 |         c_loss = self.content_weights * euclidean_loss(MST_feats[-1], content_feats[-1])
59 |         s_loss = self.style_weights * sum([style_loss(MST_feats[i], style_feats[i]) for i in range(5)])
60 |         tv_loss = self.tv_weight * total_variation(MST_output)
61 |         loss = c_loss + s_loss + tv_loss
62 |         # style 0 trains the whole network; any later style trains only its own '<i>_style' scale/shift
63 |         t_vars = tf.trainable_variables()
64 |         style_vars = [var for var in t_vars if '/{0}_style/'.format(style_idx) in var.name]  # '/1_style/' cannot match '11_style'
65 |
66 |         if style_idx == 0:
67 |             train_opt = tf.train.AdamOptimizer(self.learning_rate, self.momentum).minimize(loss)
68 |         else:
69 |             train_opt = tf.train.AdamOptimizer(self.learning_rate, self.momentum).minimize(loss, var_list=style_vars)
70 |
71 |         self.optimize = [train_opt, loss, c_loss, s_loss, tv_loss]
72 |         self.saver = tf.train.Saver(var_list=t_vars, max_to_keep=self.max_to_keep)
73 |         self.sess.run(tf.global_variables_initializer())
74 |
75 |
76 |     def train(self, Train_flag):
77 |         op.train(self, Train_flag)
78 |
79 |     def test(self, Train_flag):
80 |         op.test(self, Train_flag)
81 |
82 |     def save(self):
83 |         op.save(self)
84 |
85 |     def load(self):
86 |         op.load(self)
87 |
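`build_model` chooses what to optimize from the first non-zero entry of `style_control`:

- Index 0 trains every trainable variable.
- Any other index restricts the optimizer to variables whose names contain that style's `<i>_style` scope.

Below is a plain-Python sketch of that selection over example variable names. The names follow the `conv1/cond_in/<i>_style/...` scopes that `conv_layer` and `conditional_instance_norm` create, but they are illustrative, and `vars_to_train` is a hypothetical helper. Matching the whole segment `/<i>_style/` keeps style 1 from also selecting style 11.

```python
def vars_to_train(var_names, style_control):
    """Names the optimizer should update for a given style_control vector."""
    style_idx = [i for i, x in enumerate(style_control) if x != 0][0]
    if style_idx == 0:
        return list(var_names)  # first style: the whole network is trained
    token = '/{0}_style/'.format(style_idx)  # full scope segment, so '1' does not match '11'
    return [name for name in var_names if token in name]


names = [
    'conv1/conv1:0',                   # shared convolution kernel
    'conv1/cond_in/0_style/shift:0',
    'conv1/cond_in/1_style/scale:0',
    'conv1/cond_in/1_style/shift:0',
    'conv1/cond_in/11_style/shift:0',
]
one_hot = [0] * 16
one_hot[1] = 1
print(vars_to_train(names, one_hot))
# ['conv1/cond_in/1_style/scale:0', 'conv1/cond_in/1_style/shift:0']
```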
--------------------------------------------------------------------------------
/src/op.py:
--------------------------------------------------------------------------------
1 | import time
2 | import tensorflow as tf
3 | from src.functions import *
4 |
5 |
6 | class op(object):
7 |
8 |     def __init__(self, args, sess):
9 |         self.sess = sess
10 |
11 |         ## Train
12 |         self.gpu_number = args.gpu_number
13 |         self.project_name = args.project
14 |
15 |         ## Train images
16 |         self.content_dataset = args.content_dataset  ## test2015
17 |         self.content_data_size = args.content_data_size
18 |         self.style_image = args.style_image
19 |
20 |         ## Train Iteration
21 |         self.niter = args.niter
22 |         self.niter_snapshot = args.nsnapshot
23 |         self.max_to_keep = args.max_to_keep
24 |
25 |         ## Train Parameter
26 |         self.batch_size = args.batch_size
27 |         self.learning_rate = args.learning_rate
28 |         self.momentum = args.momentum
29 |         self.momentum2 = args.momentum2
30 |
31 |         self.content_weights = args.content_loss_weights
32 |         self.style_weights = args.style_loss_weights
33 |         self.tv_weight = args.tv_loss_weight
34 |
35 |         ## Result Dir & File
36 |         self.project_dir = '{0}/'.format(self.project_name)
37 |         make_project_dir(self.project_dir)
38 |         self.ckpt_dir = os.path.join(self.project_dir, 'models')
39 |
40 |         ## Test
41 |         self.test_dataset = args.test_dataset
42 |         self.style_control = args.style_control_weights
43 |
44 |         ## build model
45 |         self.build_model()
46 |
47 |
48 |     def train(self, Train_flag):
49 |         data = data_loader(self.content_dataset)
50 |         print('Shuffle ....')
51 |         random_order = np.random.permutation(len(data))
52 |         data = [data[i] for i in random_order[:10000 * self.batch_size]]  # at most 10000 batches per epoch
53 |         print('Shuffle Done')
54 |
55 |         start_time = time.time()
56 |         count = 0
57 |
58 |         try:
59 |             self.load()
60 |             print('Weight Load !!')
61 |         except Exception:  # no usable checkpoint yet: start from freshly initialized weights
62 |             self.sess.run(tf.global_variables_initializer())
63 |
64 |         for epoch in range(self.niter):
65 |             batch_idxs = len(data) // self.batch_size
66 |
67 |             for idx in range(0, batch_idxs):
68 |                 count += 1
69 |
70 |                 batch_files = data[idx * self.batch_size: (idx + 1) * self.batch_size]
71 |                 batch_label = [get_image(batch_file, self.content_data_size) for batch_file in batch_files]
72 |
73 |                 feeds = {self.content_input: batch_label}
74 |
75 |                 _, loss_all, loss_c, loss_s, loss_tv = self.sess.run(self.optimize, feed_dict=feeds)
76 |                 train_time = time.time() - start_time
77 |                 print("Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.4f, loss_c: %.4f, loss_s: %.4f, loss_tv: %.4f"
78 |                       % (epoch, idx, batch_idxs, train_time, loss_all, loss_c, loss_s, loss_tv))
79 |
80 |                 ## Test during Training: fires at count = niter_snapshot - 1, 2 * niter_snapshot - 1, ...
81 |                 if count % self.niter_snapshot == (self.niter_snapshot - 1):
82 |                     self.count = count
83 |                     self.save()
84 |                     self.test(Train_flag)
85 |
86 |
87 |     def test(self, Train_flag=True):
88 |         for fn in os.listdir(self.test_dataset):
89 |
90 |             ## Read RGB Image
91 |             im_input = get_image(os.path.join(self.test_dataset, fn))
92 |             im_input_4d = im_input[np.newaxis, ...]
93 |             im_b, im_h, im_w, im_c = np.shape(im_input_4d)
94 |
95 |             ## Run Model (the layers need static shapes, so a placeholder is built per image size)
96 |             img = tf.placeholder(tf.float32, [im_b, im_h, im_w, im_c], name='img')
97 |
98 |             self.test_recon = self.mst_net(img, style_control=self.style_control, reuse=True)
99 |             self.load()
100 |
101 |             im_output = self.sess.run(self.test_recon, feed_dict={img: im_input_4d})
102 |             im_output = inverse_image(im_output[0])
103 |
104 |             style_idx = ['{0}_{1}'.format(i, x) for i, x in enumerate(self.style_control) if not x == 0]
105 |
106 |             ## Image Show & Save
107 |             style_name = os.path.split(self.style_image)[-1].split('.')[0]
108 |             if Train_flag:
109 |                 train_output_dir = os.path.join(self.project_dir, 'train_result', style_name)
110 |                 if not os.path.exists(train_output_dir):
111 |                     os.makedirs(train_output_dir)
112 |                 filename = os.path.splitext(fn)[0] + '_' + str(style_idx) + '_' + str(self.count) + '_output.bmp'
113 |                 scm.imsave(os.path.join(train_output_dir, filename), im_output)
114 |             else:
115 |                 test_output_dir = os.path.join(self.project_dir, 'test_result')
116 |                 filename = os.path.splitext(fn)[0] + '_' + str(style_idx) + '_output.bmp'
117 |                 scm.imsave(os.path.join(test_output_dir, filename), im_output)
118 |
119 |             print(filename)
120 |
121 |
122 |     def save(self):
123 |         style_name = os.path.splitext(os.path.basename(self.style_image))[0]  # also correct for '.jpeg'
124 |         self.model_name = "{0}_{1}.model".format(self.project_name, style_name)
125 |
126 |         if not os.path.exists(self.ckpt_dir):
127 |             os.makedirs(self.ckpt_dir)
128 |         self.saver.save(self.sess, os.path.join(self.ckpt_dir, self.model_name), global_step=self.count)
129 |
130 |
131 |     def load(self):
132 |         ckpt = tf.train.get_checkpoint_state(self.ckpt_dir)
133 |         if ckpt is None or not ckpt.model_checkpoint_path:
134 |             raise IOError('No checkpoint found in {0}'.format(self.ckpt_dir))
135 |         ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
136 |         self.saver.restore(self.sess, os.path.join(self.ckpt_dir, ckpt_name))
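`op.train` does three pieces of bookkeeping:

- It keeps at most `10000 * batch_size` shuffled file names.
- It walks them in whole batches, skipping a trailing partial batch.
- It saves a checkpoint and runs `test()` whenever `count % niter_snapshot == niter_snapshot - 1`. `count` is 1 on the first batch.

Below is a NumPy sketch of that bookkeeping without the TensorFlow parts. `plan_batches` and `snapshot_steps` are hypothetical names.

```python
import numpy as np


def plan_batches(files, batch_size, max_batches=10000, seed=None):
    """Shuffle, keep at most max_batches * batch_size files, split into full batches."""
    order = np.random.RandomState(seed).permutation(len(files))
    kept = [files[i] for i in order[:max_batches * batch_size]]
    n_batches = len(kept) // batch_size  # a trailing partial batch is dropped
    return [kept[k * batch_size:(k + 1) * batch_size] for k in range(n_batches)]


def snapshot_steps(total_steps, niter_snapshot):
    """Step counts (1-based) at which op.train calls save() and test()."""
    return [c for c in range(1, total_steps + 1)
            if c % niter_snapshot == niter_snapshot - 1]


batches = plan_batches(['img_%d.jpg' % i for i in range(10)], batch_size=3, seed=0)
print(len(batches))            # 3
print(snapshot_steps(25, 10))  # [9, 19]
```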
--------------------------------------------------------------------------------
/src/vgg19.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import scipy.io
4 |
5 |
6 | MEAN_PIXEL = np.array([123.68, 116.779, 103.939])
7 |
8 | def net(input_image, data):
9 |     layers = (
10 |         'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
11 |
12 |         'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
13 |
14 |         'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3',
15 |         'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
16 |
17 |         'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3',
18 |         'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
19 |
20 |         'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3',
21 |         'relu5_3', 'conv5_4', 'relu5_4'
22 |     )
23 |
24 |     #data = scipy.io.loadmat(data_path)
25 |     #mean = data['normalization'][0][0][0]
26 |     #mean_pixel = np.mean(mean, axis=(0, 1))
27 |     weights = data['layers'][0]
28 |
29 |     net = {}
30 |     current = input_image
31 |     net['input'] = input_image
32 |     for i, name in enumerate(layers):
33 |         kind = name[:4]
34 |         if kind == 'conv':
35 |             kernels, bias = weights[i][0][0][0][0]
36 |             # matconvnet: weights are [width, height, in_channels, out_channels]
37 |             # tensorflow: weights are [height, width, in_channels, out_channels]
38 |             kernels = np.transpose(kernels, (1, 0, 2, 3))
39 |             bias = bias.reshape(-1)
40 |             current = _conv_layer(current, kernels, bias)
41 |         elif kind == 'relu':
42 |             current = tf.nn.relu(current)
43 |         elif kind == 'pool':
44 |             current = _pool_layer(current)
45 |         net[name] = current
46 |
47 |     #assert len(net) == len(layers)
48 |     return [net['relu1_1'], net['relu2_1'], net['relu3_1'], net['relu4_1'], net['relu5_1'], net['relu4_2']]
49 |
50 |
51 | def _conv_layer(input, weights, bias):
52 |     conv = tf.nn.conv2d(input, tf.constant(weights), strides=(1, 1, 1, 1),
53 |                         padding='SAME')
54 |     return tf.nn.bias_add(conv, bias)
55 |
56 |
57 | def _pool_layer(input):
58 |     # return tf.nn.max_pool(input, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1),
59 |     return tf.nn.avg_pool(input, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1),
60 |                           padding='SAME')
61 |
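`net()` returns the relu1_1 through relu5_1 activations, which `build_model` feeds to `style_loss`, followed by relu4_2, which it feeds to `euclidean_loss`.

The NumPy sketch below restates `gram_matrix` and `style_loss` from src/layers.py for one such NHWC feature map. It relies on the fact that `tf.nn.l2_loss(t)` equals `sum(t ** 2) / 2`. The `_np` names are hypothetical. As in the TensorFlow version, the single-image style batch broadcasts against the batch being trained.

```python
import numpy as np


def gram_matrix_np(feats):
    """(B, H, W, C) features -> (B, C, C) Gram matrices scaled by 1 / (H * W * C)."""
    b, h, w, c = feats.shape
    f = feats.reshape(b, h * w, c)
    return np.matmul(f.transpose(0, 2, 1), f) / (h * w * c)


def style_loss_np(x_feats, style_feats):
    """2 * l2_loss(G_x - G_s) / (B * C * C), with l2_loss(t) = sum(t ** 2) / 2."""
    b, _, _, c = x_feats.shape
    diff = gram_matrix_np(x_feats) - gram_matrix_np(style_feats)  # a style batch of 1 broadcasts
    return 2 * (np.sum(diff ** 2) / 2) / (b * c * c)


rng = np.random.RandomState(0)
feats = rng.rand(2, 4, 4, 3)  # e.g. two pastiches at one VGG layer
style = rng.rand(1, 4, 4, 3)  # the single style image at the same layer
loss = style_loss_np(feats, style)
```

`build_model` sums this quantity over the five relu*_1 layers before multiplying by the style weight.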
--------------------------------------------------------------------------------
/test_style.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python main.py -f 0 -gn 0 -p MST -tsd images/test -scw 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
3 | python main.py -f 0 -gn 0 -p MST -tsd images/test -scw 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0
4 | python main.py -f 0 -gn 0 -p MST -tsd images/test -scw 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0
5 | python main.py -f 0 -gn 0 -p MST -tsd images/test -scw 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0
6 |
7 | python main.py -f 0 -gn 0 -p MST -tsd images/test -scw 5 5 0 0 0 0 0 0 0 0 0 0 0 0 0 0
8 | python main.py -f 0 -gn 0 -p MST -tsd images/test -scw 5 0 5 0 0 0 0 0 0 0 0 0 0 0 0 0
9 | python main.py -f 0 -gn 0 -p MST -tsd images/test -scw 0 5 0 5 0 0 0 0 0 0 0 0 0 0 0 0
10 | python main.py -f 0 -gn 0 -p MST -tsd images/test -scw 0 0 5 5 0 0 0 0 0 0 0 0 0 0 0 0
11 | python main.py -f 0 -gn 0 -p MST -tsd images/test -scw 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0
12 |
13 |
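Each `-scw` list carries one weight per trained style (16 here). `conditional_instance_norm` divides the weighted sum of the per-style scale and shift by `sum(style_control)`, so only the ratios matter:

- `5 5 0 ...` and `1 1 0 ...` give the same 50/50 blend.
- `1 1 1 1 0 ...` weights four styles at 25% each.

Below is a small sketch of the effective coefficients; `blend_coefficients` is a hypothetical helper, not part of the repository.

```python
def blend_coefficients(scw):
    """Map a -scw weight list to the normalized weight of each active style."""
    total = float(sum(scw))
    if total == 0:
        raise ValueError('at least one style weight must be non-zero')
    return {i: w / total for i, w in enumerate(scw) if w != 0}


print(blend_coefficients([5, 5] + [0] * 14))        # {0: 0.5, 1: 0.5}
print(blend_coefficients([1, 1, 1, 1] + [0] * 12))  # {0: 0.25, 1: 0.25, 2: 0.25, 3: 0.25}
```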
--------------------------------------------------------------------------------
/train_style.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python main.py -f 1 -gn 0 -p MST -n 5 -b 16 -tsd images/test -scw 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 -sti images/style_crop/0_udnie.jpg
3 | python main.py -f 1 -gn 0 -p MST -n 1 -b 16 -tsd images/test -scw 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 -sti images/style_crop/1_la_muse.jpg
4 | python main.py -f 1 -gn 0 -p MST -n 1 -b 16 -tsd images/test -scw 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 -sti images/style_crop/2_rain_princess.jpg
5 | python main.py -f 1 -gn 0 -p MST -n 1 -b 16 -tsd images/test -scw 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 -sti images/style_crop/3_the_scream.jpg
6 | python main.py -f 1 -gn 0 -p MST -n 1 -b 16 -tsd images/test -scw 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 -sti images/style_crop/4_the_shipwreck_of_the_minotaur.jpg
7 | python main.py -f 1 -gn 0 -p MST -n 1 -b 16 -tsd images/test -scw 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 -sti images/style_crop/5_wave.jpg
8 | python main.py -f 1 -gn 0 -p MST -n 1 -b 16 -tsd images/test -scw 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 -sti images/style_crop/6_composition.jpg
9 | python main.py -f 1 -gn 0 -p MST -n 1 -b 16 -tsd images/test -scw 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 -sti images/style_crop/7_guernica.jpg
10 | python main.py -f 1 -gn 0 -p MST -n 1 -b 16 -tsd images/test -scw 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 -sti images/style_crop/8_starry_night.jpg
11 | python main.py -f 1 -gn 0 -p MST -n 1 -b 16 -tsd images/test -scw 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 -sti images/style_crop/9_White_Zig_Zags.jpg
12 | python main.py -f 1 -gn 0 -p MST -n 1 -b 16 -tsd images/test -scw 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 -sti images/style_crop/10_Yellow_sunset.jpg
13 | python main.py -f 1 -gn 0 -p MST -n 1 -b 16 -tsd images/test -scw 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 -sti images/style_crop/11_Three_Fishing_Boats.jpg
14 | python main.py -f 1 -gn 0 -p MST -n 1 -b 16 -tsd images/test -scw 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 -sti images/style_crop/12_The_Annunciation_of_the_Virgin_Deal.jpg
15 | python main.py -f 1 -gn 0 -p MST -n 1 -b 16 -tsd images/test -scw 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 -sti images/style_crop/13_Edith_with_Striped_Dress.jpeg
16 | python main.py -f 1 -gn 0 -p MST -n 1 -b 16 -tsd images/test -scw 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 -sti images/style_crop/14_Colors_from_a_Distance.jpg
17 | python main.py -f 1 -gn 0 -p MST -n 1 -b 16 -tsd images/test -scw 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 -sti images/style_crop/15_Sunrise.jpg
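train_style.sh runs main.py once per style with a one-hot `-scw` vector:

- The first run (style 0, 5 epochs) trains the whole network, because `build_model` optimizes every trainable variable when the active style index is 0.
- Each later run (1 epoch) optimizes only that style's scale/shift variables.

Below is a sketch that regenerates those lines from the file names in images/style_crop, assuming the `<index>_<name>` naming used there. `train_commands` is a hypothetical helper, and the flags are copied verbatim from the script.

```python
def train_commands(style_files, n_styles=16, batch_size=16, first_epochs=5,
                   style_dir='images/style_crop'):
    """Rebuild train_style.sh lines from '<index>_<name>' style file names."""
    def index_of(fname):
        return int(fname.split('_', 1)[0])  # leading integer is the style index

    commands = []
    for fname in sorted(style_files, key=index_of):
        idx = index_of(fname)
        scw = ['0'] * n_styles
        scw[idx] = '1'  # one-hot: only this style's scale/shift is active
        epochs = first_epochs if idx == 0 else 1  # style 0 trains the full network
        commands.append(
            'python main.py -f 1 -gn 0 -p MST -n {0} -b {1} -tsd images/test '
            '-scw {2} -sti {3}/{4}'.format(epochs, batch_size, ' '.join(scw), style_dir, fname))
    return commands


cmds = train_commands(['1_la_muse.jpg', '0_udnie.jpg', '13_Edith_with_Striped_Dress.jpeg'])
for line in cmds:
    print(line)
```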
--------------------------------------------------------------------------------