├── .gitignore
├── LICENSE
├── README.md
├── app.py
├── docs
│   └── img
│       ├── down_arrow.png
│       ├── left_arrow.png
│       ├── streamlit.jpg
│       └── title_figure.jpg
├── images
│   ├── content
│   │   ├── elefant.jpg
│   │   ├── golden_gate.jpg
│   │   ├── olive_trees_greece.jpg
│   │   ├── road.jpg
│   │   └── winxp.jpg
│   └── style
│       ├── derain_mountains_at_colloiure.jpg
│       ├── picasso_self_portrait.jpg
│       ├── picasso_weeping_woman.jpg
│       ├── van_gogh_red_cabbages_and_onions.jpg
│       ├── van_gogh_starry_night.jpg
│       ├── van_gogh_trees.jpg
│       └── van_gogh_van_gogh_road_cypress.jpg
├── model.py
├── networks.py
├── notebooks
│   ├── BrushstrokeStyleTransfer.ipynb
│   └── BrushstrokeStyleTransferDrawApp.ipynb
├── ops.py
├── requirements.txt
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .temp
3 | pretrained_weights
4 | *.swp
5 |
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 CompVis Heidelberg
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Rethinking Style Transfer: From Pixels to Parameterized Brushstrokes (CVPR 2021)
2 |
3 |

4 |
5 | ### [Project page](https://compvis.github.io/brushstroke-parameterized-style-transfer/) | [Paper](https://arxiv.org/abs/2103.17185) | [Colab](https://colab.research.google.com/drive/1J9B6_G2DSWmaBWw9Ot80W9t7O6pWu8Kw?usp=sharing) | [Colab for Drawing App](https://colab.research.google.com/drive/1ALNRoZgCj35uJ3Xvs24-QDwwtCb2lm3P?usp=sharing)
6 |
7 | Rethinking Style Transfer: From Pixels to Parameterized Brushstrokes.
8 | [Dmytro Kotovenko*](https://scholar.google.de/citations?user=T_U8yxwAAAAJ&hl=en), [Matthias Wright*](http://www.matthias-wright.com/), [Arthur Heimbrecht](http://www.aheimbrecht.de/), and [Björn Ommer](https://hci.iwr.uni-heidelberg.de/people/bommer).
9 | * denotes equal contribution
10 |
11 | ## Implementations
12 | We provide implementations in [TensorFlow 1](https://github.com/CompVis/brushstroke-parameterized-style-transfer/tree/tensorflow_v1) and [TensorFlow 2](https://github.com/CompVis/brushstroke-parameterized-style-transfer/tree/tensorflow_v2). To reproduce the results from the paper, we recommend the [TensorFlow 1](https://github.com/CompVis/brushstroke-parameterized-style-transfer/tree/tensorflow_v1) implementation.
13 |
14 | ## Installation
15 | 1. Clone this repository:
16 | ```sh
17 | > git clone https://github.com/CompVis/brushstroke-parameterized-style-transfer
18 | > cd brushstroke-parameterized-style-transfer
19 | ```
20 | 2. Install TensorFlow 1.14 (preferably with GPU support).
21 | If you are using [Conda](https://docs.conda.io/en/latest/index.html), the following commands create a new environment and install TensorFlow together with compatible CUDA and cuDNN versions.
22 | ```sh
23 | > conda create --name tf14 tensorflow-gpu==1.14
24 | > conda activate tf14
25 | ```
26 | 3. Install requirements:
27 | ```sh
28 | > pip install -r requirements.txt
29 | ```
30 |
31 | ## Basic Usage
32 | ```python
33 | from PIL import Image
34 | import model
35 |
36 | content_img = Image.open('images/content/golden_gate.jpg')
37 | style_img = Image.open('images/style/van_gogh_starry_night.jpg')
38 |
39 | stylized_img = model.stylize(content_img,
40 | style_img,
41 | num_strokes=5000,
42 | num_steps=100,
43 | content_weight=1.0,
44 | style_weight=3.0,
45 | num_steps_pixel=1000)
46 |
47 | stylized_img.save('images/stylized.jpg')
48 | ```
49 | or open [Colab](https://colab.research.google.com/drive/1J9B6_G2DSWmaBWw9Ot80W9t7O6pWu8Kw?usp=sharing).
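
The `resolution` argument of `model.stylize` (default 512) does not fix both canvas dimensions. As a minimal sketch of the sizing rule found in `BrushstrokeOptimizer.__init__` in `model.py`, the smaller side of the content image is set to `resolution` and the other side is scaled to preserve the aspect ratio. The helper name `canvas_size` is hypothetical and is not part of this repository.

```python
def canvas_size(width, height, resolution=512):
    """Return (new_width, new_height) with the smaller side set to `resolution`.

    Hypothetical helper that mirrors the sizing logic in model.py.
    """
    if height < width:
        # Landscape image: the height is the smaller side.
        new_height = resolution
        new_width = int((width / height) * new_height)
    else:
        # Portrait or square image: the width is the smaller (or equal) side.
        new_width = resolution
        new_height = int((height / width) * new_width)
    return new_width, new_height


landscape = canvas_size(1024, 768)
portrait = canvas_size(600, 900)
print(landscape, portrait)
```

The pixel optimization stage applies the same rule with `pixel_resolution` (default 1024).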
50 |
51 | ## Drawing App
52 | We created a [Streamlit](https://streamlit.io/) app where you can draw curves to control the flow of brushstrokes.
53 |
54 | 
55 |
56 | #### Run drawing app on your machine
57 | To run the app on your own machine:
58 | ```sh
59 | > CUDA_VISIBLE_DEVICES=0 streamlit run app.py
60 | ```
61 |
62 |
63 | You can also run the app on a remote server and forward the port to your local machine:
64 | [https://docs.streamlit.io/en/0.66.0/tutorial/run_streamlit_remotely.html](https://docs.streamlit.io/en/0.66.0/tutorial/run_streamlit_remotely.html)
65 |
66 |
67 | #### Run streamlit app from Colab
68 | If you don't have access to a GPU, you can start the drawing app from this [Colab](https://colab.research.google.com/drive/1ALNRoZgCj35uJ3Xvs24-QDwwtCb2lm3P?usp=sharing) instead.
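
Before optimization, the drawing app reduces every drawn curve to a list of positions and direction vectors (see `sample_vectors` in `app.py`). The sketch below illustrates the idea in plain Python under simplifying assumptions: points are `(x, y)` tuples normalized to `[0, 1]`, the subsampling step controlled by `freq` is omitted, and curves with fewer than two points are skipped. The function name `tangent_vectors` is hypothetical.

```python
def tangent_vectors(points, lookahead=5):
    """Approximate the direction of a drawn curve at each of its points.

    points: list of (x, y) tuples along the curve, in drawing order.
    Returns (vectors, positions), where vectors[i] points from positions[i]
    to the point `lookahead` steps further along the curve.
    """
    if len(points) < 2:
        # Assumption of this sketch: a single point carries no direction.
        return [], []

    # Never look past the end of a short curve.
    lookahead = min(lookahead, len(points) - 1)

    vectors, positions = [], []
    for i in range(len(points) - lookahead):
        x0, y0 = points[i]
        x1, y1 = points[i + lookahead]
        vectors.append((x1 - x0, y1 - y0))
        positions.append((x0, y0))
    return vectors, positions


# A horizontal stroke drawn from left to right.
points = [(i / 8, 0.5) for i in range(8)]
vectors, positions = tangent_vectors(points, lookahead=2)
print(vectors[0], positions[0])
```

A larger `lookahead` smooths out jitter in the hand-drawn curve at the cost of following tight turns less closely.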
69 |
70 | ## Other implementations
71 | [PyTorch implementation](https://github.com/justanhduc/brushstroke-parameterized-style-transfer) by [justanhduc](https://github.com/justanhduc).
72 |
73 | ## Citation
74 | ```
75 | @article{kotovenko_cvpr_2021,
76 | title={Rethinking Style Transfer: From Pixels to Parameterized Brushstrokes},
77 | author={Dmytro Kotovenko and Matthias Wright and Arthur Heimbrecht and Bj{\"o}rn Ommer},
78 | journal={CVPR},
79 | year={2021}
80 | }
81 | ```
82 |
83 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from typing import Sequence
4 | import streamlit as st
5 | from streamlit_drawable_canvas import st_canvas
6 | from stqdm import stqdm
7 | import os
8 | import base64
9 |
10 | from model import BrushstrokeOptimizer, PixelOptimizer
11 |
12 |
13 | def parse_paths(json_obj, height, width):
14 |     xs = []
15 |     ys = []
16 |     for segments in json_obj['path']:
17 |         if segments[0] == 'Q':
18 |             xs.append(segments[2] / width)
19 |             xs.append(segments[4] / width)
20 |             ys.append(segments[1] / height)
21 |             ys.append(segments[3] / height)
22 |     xs = np.array(xs)
23 |     ys = np.array(ys)
24 |     return np.stack((xs, ys), axis=1)
25 |
26 |
27 | def sample_vectors(points, lookahead=10, freq=10):
28 |     if points.shape[0] > 30:
29 |         idcs = np.arange(points.shape[0])[::freq]
30 |     else:
31 |         idcs = np.arange(points.shape[0])
32 |     vectors = []
33 |     positions = []
34 |
35 |     lookahead = min(lookahead, idcs.shape[0] - 1)
36 |     for i in range(idcs.shape[0] - lookahead):
37 |         vectors.append(points[idcs[i] + lookahead] - points[idcs[i]])
38 |         positions.append(points[idcs[i]])
39 |     return np.array(vectors), np.array(positions)
40 |
41 |
42 | def get_binary_file_downloader_html(bin_file, file_label='File'):
43 |     # Taken from: https://discuss.streamlit.io/t/how-to-download-file-in-streamlit/1806/27
44 |     with open(bin_file, 'rb') as f:
45 |         data = f.read()
46 |     bin_str = base64.b64encode(data).decode()
47 |     href = f'<a href="data:application/octet-stream;base64,{bin_str}" download="{os.path.basename(bin_file)}">Download {file_label}</a>'
48 |     return href
49 |
50 |
51 | def resize(img, size, interpolation=Image.BILINEAR):
52 |     # https://github.com/pytorch/vision/blob/master/torchvision/transforms/functional.py
53 |     if isinstance(size, int) or len(size) == 1:
54 |         if isinstance(size, Sequence):
55 |             size = size[0]
56 |         w, h = img.size
57 |         if (w <= h and w == size) or (h <= w and h == size):
58 |             return img
59 |         if w < h:
60 |             ow = size
61 |             oh = int(size * h / w)
62 |             return img.resize((ow, oh), interpolation)
63 |         else:
64 |             oh = size
65 |             ow = int(size * w / h)
66 |             return img.resize((ow, oh), interpolation)
67 |     else:
68 |         return img.resize(size[::-1], interpolation)
69 |
70 |
71 | st.sidebar.markdown("""
72 |
83 | """, unsafe_allow_html=True)
84 |
85 |
86 | # Sidebar
87 | ## Content and Style image
88 | st.sidebar.markdown('Content Image', unsafe_allow_html=True)
89 | selected_content_image = st.sidebar.selectbox('Select content image:', [None] + os.listdir('images/content'))
90 | st.sidebar.text('OR')
91 | uploaded_content_image = st.sidebar.file_uploader('Upload content image:', type=['png', 'jpg'])
92 |
93 | st.sidebar.markdown('Style Image', unsafe_allow_html=True)
94 | selected_style_image = st.sidebar.selectbox('Select style image:', [None] + os.listdir('images/style'))
95 | st.sidebar.text('OR')
96 | uploaded_style_image = st.sidebar.file_uploader('Upload style image:', type=['png', 'jpg'])
97 |
98 | ## Parameters
99 | st.sidebar.markdown('Options', unsafe_allow_html=True)
100 | num_steps_stroke = st.sidebar.slider('Brushstroke optimization steps:', 20, 100, 100)
101 | num_steps_pixel = st.sidebar.slider('Pixel optimization steps:', 100, 5000, 2000)
102 | num_strokes = st.sidebar.slider('Number of brushstrokes:', 100, 10000, 5000)
103 | content_weight = st.sidebar.slider('Content weight:', 1.0, 50.0, 1.0)
104 | style_weight = st.sidebar.slider('Style weight:', 1.0, 50.0, 3.0)
105 | draw_weight = st.sidebar.slider('Drawing weight', 50.0, 200.0, 100.0)
106 | draw_strength = st.sidebar.slider('Drawing strength (denoted L in the paper):', 50, 200, 100)
107 | stroke_width = st.sidebar.slider('Stroke width:', 0.01, 2.0, 0.1)
108 | stroke_length = st.sidebar.slider('Stroke length:', 0.1, 2.0, 1.1)
109 |
110 |
111 | #drawing_mode = st.sidebar.selectbox(
112 | # 'Drawing tool:', ('freedraw', 'line', 'rect', 'circle', 'transform')
113 | #)
114 | realtime_update = st.sidebar.checkbox('Update in realtime', True)
115 |
116 | # Main
117 | stroke_color = st.color_picker('Stroke color hex: ', '#ff0000')
118 |
119 |
120 | content_img = None
121 | if selected_content_image is not None: content_img = Image.open(os.path.join('images/content', selected_content_image))
122 | if uploaded_content_image is not None: content_img = Image.open(uploaded_content_image)
123 |
124 | style_img = None
125 | if selected_style_image is not None: style_img = Image.open(os.path.join('images/style', selected_style_image))
126 | if uploaded_style_image is not None: style_img = Image.open(uploaded_style_image)
127 |
128 | if content_img is None or style_img is None:
129 | st.image(Image.open('docs/img/left_arrow.png'))
130 | st.image(Image.open('docs/img/down_arrow.png'))
131 | #st.markdown('Select or upload content and style images...', unsafe_allow_html=True)
132 |
133 |
134 | col1, col2 = st.beta_columns(2)
135 |
136 | # Preview images
137 | if content_img is not None:
138 | content_thumb = resize(content_img, size=400)
139 | col1.header('Content image')
140 | col1.image(content_thumb, use_column_width=True)
141 | if style_img is not None:
142 | style_thumb = resize(style_img, size=400)
143 | col2.header('Style image')
144 | col2.image(style_thumb, use_column_width=True)
145 |
146 |
147 | if content_img is not None and style_img is not None:
148 | if not os.path.exists('.temp'):
149 | os.makedirs('.temp')
150 |
151 | content_img_name = content_img.filename
152 | content_img = content_img.convert('RGB')
153 | content_img.save(f'.temp/content_img.jpg')
154 | style_img_name = style_img.filename
155 | style_img = style_img.convert('RGB')
156 | style_img.save(f'.temp/style_img.jpg')
157 |
158 | height = content_img.size[1]
159 | width = content_img.size[0]
160 | factor = 1.0
161 | # resize the image such that the smaller side is 512 because otherwise the canvas drawer messes up
162 | if width > 512 or height > 512:
163 | if width < height:
164 | height = int(512 * (height / width))
165 | width = 512
166 | factor *= height / width
167 | else:
168 | width = int(512 * (width / height))
169 | height = 512
170 | factor *= width / height
171 |
172 |
173 | st.text('Now draw some curves on the canvas.')
174 | st.text('To draw a curve:')
175 | st.text('- hold down the left mouse button')
176 | st.text('- and move the mouse over the canvas.')
177 |
178 | # Create a canvas component
179 | canvas_result = st_canvas(
180 | fill_color='rgba(255, 165, 0, 0.3)', # Fixed fill color with some opacity
181 | stroke_width=3,
182 | stroke_color=stroke_color,
183 | background_color='' if content_img else '#eee',
184 | background_image=content_img,
185 | update_streamlit=realtime_update,
186 | height=height,
187 | width=width,
188 | drawing_mode='freedraw',
189 | #key='canvas',
190 | )
191 |
192 | if canvas_result.json_data is not None:
193 | if len(canvas_result.json_data['objects']) > 0:
194 | if st.button('Stylize'):
195 | vectors_all = []
196 | positions_all = []
197 | img_array = np.array(content_img)
198 | for i in range(len(canvas_result.json_data['objects'])):
199 | points = parse_paths(canvas_result.json_data['objects'][i], float(height), float(width))
200 |
201 | if points.shape[0] == 0:
202 | continue
203 |
204 | vectors, positions = sample_vectors(points, lookahead=5, freq=5)
205 |
206 | if vectors.ndim < 2 or positions.ndim < 2:
207 | continue
208 |
209 | vectors_all.append(vectors)
210 | positions_all.append(positions)
211 |
212 | for i in range(points.shape[0]):
213 | y = int(points[i, 0] * content_img.size[0])
214 | x = int(points[i, 1] * content_img.size[1])
215 | img_array[y-2:y+2, x-2:x+2] = np.array([255, 0, 0])
216 |
217 | vectors_all = np.concatenate(vectors_all, axis=0).astype(np.float32)
218 | positions_all = np.concatenate(positions_all, axis=0).astype(np.float32)
219 | np.save('.temp/vectors', vectors_all)
220 | np.save('.temp/positions', positions_all)
221 |
222 | content_img = Image.open('.temp/content_img.jpg')
223 | style_img = Image.open('.temp/style_img.jpg')
224 |
225 | st.text('Brushstroke optimization...')
226 | pbar = stqdm(range(num_steps_stroke))
227 | stroke_optim = BrushstrokeOptimizer(content_img,
228 | style_img,
229 | draw_curve_position_path='.temp/positions.npy',
230 | draw_curve_vector_path='.temp/vectors.npy',
231 | draw_strength=draw_strength,
232 | resolution=512,
233 | num_strokes=num_strokes,
234 | num_steps=num_steps_stroke,
235 | width_scale=stroke_width,
236 | length_scale=stroke_length,
237 | content_weight=content_weight,
238 | style_weight=style_weight,
239 | draw_weight=draw_weight,
240 | streamlit_pbar=pbar)
241 | canvas = stroke_optim.optimize()
242 |
243 | st.text('Pixel optimization...')
244 | pbar = stqdm(range(num_steps_pixel))
245 | pixel_optim = PixelOptimizer(canvas,
246 | style_img,
247 | resolution=1024,
248 | num_steps=num_steps_pixel,
249 | content_weight=1.0,
250 | style_weight=10000.0,
251 | streamlit_pbar=pbar)
252 | canvas = pixel_optim.optimize()
253 |
254 | st.text('Stylized image:')
255 | st.image(canvas.resize((width, height)))
256 |
257 | canvas.save('.temp/canvas.jpg')
258 | st.markdown(get_binary_file_downloader_html('.temp/canvas.jpg', 'stylized image in high resolution'), unsafe_allow_html=True)
259 |
260 |
261 |
--------------------------------------------------------------------------------
/docs/img/down_arrow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/brushstroke-parameterized-style-transfer/72f3cabe7092e754fe0e9072741b399aee9e6dff/docs/img/down_arrow.png
--------------------------------------------------------------------------------
/docs/img/left_arrow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/brushstroke-parameterized-style-transfer/72f3cabe7092e754fe0e9072741b399aee9e6dff/docs/img/left_arrow.png
--------------------------------------------------------------------------------
/docs/img/streamlit.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/brushstroke-parameterized-style-transfer/72f3cabe7092e754fe0e9072741b399aee9e6dff/docs/img/streamlit.jpg
--------------------------------------------------------------------------------
/docs/img/title_figure.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/brushstroke-parameterized-style-transfer/72f3cabe7092e754fe0e9072741b399aee9e6dff/docs/img/title_figure.jpg
--------------------------------------------------------------------------------
/images/content/elefant.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/brushstroke-parameterized-style-transfer/72f3cabe7092e754fe0e9072741b399aee9e6dff/images/content/elefant.jpg
--------------------------------------------------------------------------------
/images/content/golden_gate.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/brushstroke-parameterized-style-transfer/72f3cabe7092e754fe0e9072741b399aee9e6dff/images/content/golden_gate.jpg
--------------------------------------------------------------------------------
/images/content/olive_trees_greece.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/brushstroke-parameterized-style-transfer/72f3cabe7092e754fe0e9072741b399aee9e6dff/images/content/olive_trees_greece.jpg
--------------------------------------------------------------------------------
/images/content/road.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/brushstroke-parameterized-style-transfer/72f3cabe7092e754fe0e9072741b399aee9e6dff/images/content/road.jpg
--------------------------------------------------------------------------------
/images/content/winxp.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/brushstroke-parameterized-style-transfer/72f3cabe7092e754fe0e9072741b399aee9e6dff/images/content/winxp.jpg
--------------------------------------------------------------------------------
/images/style/derain_mountains_at_colloiure.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/brushstroke-parameterized-style-transfer/72f3cabe7092e754fe0e9072741b399aee9e6dff/images/style/derain_mountains_at_colloiure.jpg
--------------------------------------------------------------------------------
/images/style/picasso_self_portrait.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/brushstroke-parameterized-style-transfer/72f3cabe7092e754fe0e9072741b399aee9e6dff/images/style/picasso_self_portrait.jpg
--------------------------------------------------------------------------------
/images/style/picasso_weeping_woman.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/brushstroke-parameterized-style-transfer/72f3cabe7092e754fe0e9072741b399aee9e6dff/images/style/picasso_weeping_woman.jpg
--------------------------------------------------------------------------------
/images/style/van_gogh_red_cabbages_and_onions.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/brushstroke-parameterized-style-transfer/72f3cabe7092e754fe0e9072741b399aee9e6dff/images/style/van_gogh_red_cabbages_and_onions.jpg
--------------------------------------------------------------------------------
/images/style/van_gogh_starry_night.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/brushstroke-parameterized-style-transfer/72f3cabe7092e754fe0e9072741b399aee9e6dff/images/style/van_gogh_starry_night.jpg
--------------------------------------------------------------------------------
/images/style/van_gogh_trees.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/brushstroke-parameterized-style-transfer/72f3cabe7092e754fe0e9072741b399aee9e6dff/images/style/van_gogh_trees.jpg
--------------------------------------------------------------------------------
/images/style/van_gogh_van_gogh_road_cypress.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/brushstroke-parameterized-style-transfer/72f3cabe7092e754fe0e9072741b399aee9e6dff/images/style/van_gogh_van_gogh_road_cypress.jpg
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | with warnings.catch_warnings():
3 | warnings.filterwarnings('ignore', category=RuntimeWarning)
4 | warnings.filterwarnings('ignore', category=FutureWarning)
5 | import tensorflow as tf
6 | from tensorflow.core.protobuf import config_pb2
7 |
8 | import os
9 | import numpy as np
10 | from PIL import Image
11 | from tqdm import trange
12 |
13 | import networks
14 | import ops
15 | import utils
16 |
17 | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
18 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
19 |
20 |
21 | def stylize(content_img,
22 | style_img,
23 | # Brushstroke optimizer params
24 | resolution=512,
25 | num_strokes=5000,
26 | num_steps=100,
27 | S=10,
28 | K=20,
29 | canvas_color='gray',
30 | width_scale=0.1,
31 | length_scale=1.1,
32 | content_weight=1.0,
33 | style_weight=3.0,
34 | tv_weight=0.008,
35 | curviture_weight=4.0,
36 | # Pixel optimizer params
37 | pixel_resolution=1024,
38 | num_steps_pixel=2000
39 | ):
40 |
41 | stroke_optim = BrushstrokeOptimizer(content_img,
42 | style_img,
43 | resolution=resolution,
44 | num_strokes=num_strokes,
45 | num_steps=num_steps,
46 | S=S,
47 | K=K,
48 | canvas_color=canvas_color,
49 | width_scale=width_scale,
50 | length_scale=length_scale,
51 | content_weight=content_weight,
52 | style_weight=style_weight,
53 | tv_weight=tv_weight,
54 | curviture_weight=curviture_weight)
55 | print('Stroke optimization:')
56 | canvas = stroke_optim.optimize()
57 |
58 | pixel_optim = PixelOptimizer(canvas,
59 | style_img,
60 | resolution=pixel_resolution,
61 | num_steps=num_steps_pixel,
62 | content_weight=1.0,
63 | style_weight=10000.0)
64 |
65 | print('Pixel optimization:')
66 | canvas = pixel_optim.optimize()
67 | return canvas
68 |
69 |
70 | class BrushstrokeOptimizer:
71 |
72 | def __init__(self,
73 | content_img, # Content image (PIL.Image).
74 | style_img, # Style image (PIL.Image).
75 | draw_curve_position_path = None, # Set of points that represent the drawn curves, denoted as P_i in Sec. B of the paper (str).
76 | draw_curve_vector_path = None, # Set of tangent vectors for the points of the drawn curves, denoted as v_i in Sec. B of the paper (str).
77 | draw_strength = 100, # Strength of the influence of the drawn curves, denoted L in Sec. B of the paper (int).
78 | resolution = 512, # Resolution of the canvas (int).
79 | num_strokes = 5000, # Number of brushstrokes (int).
80 | num_steps = 100, # Number of optimization steps (int).
81 | S = 10, # Number of points to sample on each curve, see Sec. 4.2.1 of the paper (int).
82 | K = 20, # Number of brushstrokes to consider for each pixel, see Sec. C.2 of the paper (int).
83 | canvas_color = 'gray', # Color of the canvas (str).
84 | width_scale = 0.1, # Scale parameter for the brushstroke width (float).
85 | length_scale = 1.1, # Scale parameter for the brushstroke length (float).
86 | content_weight = 1.0, # Weight for the content loss (float).
87 | style_weight = 3.0, # Weight for the style loss (float).
88 | tv_weight = 0.008, # Weight for the total variation loss (float).
89 | draw_weight = 100.0, # Weight for the drawing projection loss (float).
90 | curviture_weight = 4.0, # Weight for the curvature loss (float).
91 | streamlit_pbar = None, # Progressbar for streamlit app (obj).
92 | dtype = 'float32' # Data type (str).
93 | ):
94 |
95 | self.draw_strength = draw_strength
96 | self.draw_weight = draw_weight
97 | self.resolution = resolution
98 | self.num_strokes = num_strokes
99 | self.num_steps = num_steps
100 | self.S = S
101 | self.K = K
102 | self.canvas_color = canvas_color
103 | self.width_scale = width_scale
104 | self.length_scale = length_scale
105 | self.content_weight = content_weight
106 | self.style_weight = style_weight
107 | self.tv_weight = tv_weight
108 | self.curviture_weight = curviture_weight
109 | self.streamlit_pbar = streamlit_pbar
110 | self.dtype = dtype
111 |
112 | # Set canvas size (set smaller side of content image to 'resolution' and scale other side accordingly)
113 | W, H = content_img.size
114 | if H < W:
115 | new_H = resolution
116 | new_W = int((W / H) * new_H)
117 | else:
118 | new_W = resolution
119 | new_H = int((H / W) * new_W)
120 |
121 | self.canvas_height = new_H
122 | self.canvas_width = new_W
123 |
124 | content_img = content_img.resize((self.canvas_width, self.canvas_height))
125 | style_img = style_img.resize((self.canvas_width, self.canvas_height))
126 |
127 | content_img = np.array(content_img).astype(self.dtype)
128 | style_img = np.array(style_img).astype(self.dtype)
129 |
130 | content_img /= 255.0
131 | style_img /= 255.0
132 |
133 | self.content_img_np = content_img
134 | self.style_img_np = style_img
135 |
136 | if draw_curve_position_path is not None and draw_curve_vector_path is not None:
137 | self.draw_curve_position_np = np.load(draw_curve_position_path)
138 | self.draw_curve_vector_np = np.load(draw_curve_vector_path)
139 | self.draw_curve_position_np[..., 0] *= self.canvas_width
140 | self.draw_curve_position_np[..., 1] *= self.canvas_height
141 |
142 | ckpt_path = utils.download_weights(url='https://www.dropbox.com/s/hv7b4eajrj7isyq/vgg_weights.pickle?dl=1',
143 | name='vgg_weights.pickle')
144 | self.vgg = networks.VGG(ckpt_path=ckpt_path)
145 |
146 | def optimize(self):
147 | self._initialize()
148 | self._render()
149 | self._losses()
150 | self._optimizer()
151 |
152 |
153 | with tf.Session() as sess:
154 | sess.run(tf.global_variables_initializer())
155 | steps = trange(self.num_steps, desc='', leave=True)
156 | for step in steps:
157 |
158 | I_, loss_dict_, params_dict_, _ = \
159 | sess.run(fetches=[self.I,
160 | self.loss_dict,
161 | self.params_dict,
162 | self.optim_step_with_constraints],
163 | options=config_pb2.RunOptions(report_tensor_allocations_upon_oom=True)
164 | )
165 |
166 | steps.set_description(f'content_loss: {loss_dict_["content"]:.6f}, style_loss: {loss_dict_["style"]:.6f}')
167 | #s = ''
168 | #for key in loss_dict_:
169 | # loss = loss_dict_[key]
170 | # s += key + f': {loss_dict_[key]:.4f}, '
171 | #steps.set_description(s[:-2])
172 | #print(s)
173 |
174 | steps.refresh()
175 | if self.streamlit_pbar is not None: self.streamlit_pbar.update(1)
176 | return Image.fromarray(np.array(np.clip(I_, 0, 1) * 255, dtype=np.uint8))
177 |
178 | def _initialize(self):
179 | location, s, e, c, width, color = utils.initialize_brushstrokes(self.content_img_np,
180 | self.num_strokes,
181 | self.canvas_height,
182 | self.canvas_width,
183 | self.length_scale,
184 | self.width_scale)
185 |
186 | self.curve_s = tf.Variable(name='curve_s', initial_value=s, dtype=self.dtype)
187 | self.curve_e = tf.Variable(name='curve_e', initial_value=e, dtype=self.dtype)
188 | self.curve_c = tf.Variable(name='curve_c', initial_value=c, dtype=self.dtype)
189 | self.color = tf.Variable(name='color', initial_value=color, dtype=self.dtype)
190 | self.location = tf.Variable(name='location', initial_value=location, dtype=self.dtype)
191 | self.width = tf.Variable(name='width', initial_value=width, dtype=self.dtype)
192 | self.content_img = tf.constant(name='content_img', value=self.content_img_np, dtype=self.dtype)
193 | self.style_img = tf.constant(name='style_img', value=self.style_img_np, dtype=self.dtype)
194 |
195 | if hasattr(self, 'draw_curve_position_np') and hasattr(self, 'draw_curve_vector_np'):
196 | self.draw_curve_position = tf.constant(name='draw_curve_position', value=self.draw_curve_position_np, dtype=self.dtype)
197 | self.draw_curve_vector = tf.constant(name='draw_curve_vector', value=self.draw_curve_vector_np, dtype=self.dtype)
198 |
199 | self.params_dict = {'location': self.location,
200 | 'curve_s': self.curve_s,
201 | 'curve_e': self.curve_e,
202 | 'curve_c': self.curve_c,
203 | 'width': self.width,
204 | 'color': self.color}
205 |
206 | def _render(self):
207 | curve_points = ops.sample_quadratic_bezier_curve(s=self.curve_s + self.location,
208 | e=self.curve_e + self.location,
209 | c=self.curve_c + self.location,
210 | num_points=self.S,
211 | dtype=self.dtype)
212 |
213 | self.I = ops.renderer(curve_points,
214 | self.location,
215 | self.color,
216 | self.width,
217 | self.canvas_height,
218 | self.canvas_width,
219 | self.K,
220 | canvas_color=self.canvas_color,
221 | dtype=self.dtype)
222 |
223 | def _losses(self):
224 | # resize images to save memory
225 | rendered_canvas_resized = \
226 | tf.image.resize_nearest_neighbor(images=ops.preprocess_img(self.I),
227 | size=(int(self.canvas_height // 2), int(self.canvas_width // 2)))
228 |
229 | content_img_resized = \
230 | tf.image.resize_nearest_neighbor(images=ops.preprocess_img(self.content_img),
231 | size=(int(self.canvas_height // 2), int(self.canvas_width // 2)))
232 |
233 | style_img_resized = \
234 | tf.image.resize_nearest_neighbor(images=ops.preprocess_img(self.style_img),
235 | size=(int(self.canvas_height // 2), int(self.canvas_width // 2)))
236 |
237 | self.loss_dict = {}
238 | self.loss_dict['content'] = ops.content_loss(self.vgg.extract_features(rendered_canvas_resized),
239 | self.vgg.extract_features(content_img_resized),
240 | #layers=['conv1_2', 'conv2_2', 'conv3_2', 'conv4_2', 'conv5_2'],
241 | layers=['conv4_2', 'conv5_2'],
242 | weights=[1, 1],
243 | scale_by_y=True)
244 | self.loss_dict['content'] *= self.content_weight
245 |
246 | self.loss_dict['style'] = ops.style_loss(self.vgg.extract_features(rendered_canvas_resized),
247 | self.vgg.extract_features(style_img_resized),
248 | layers=['conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv5_1'],
249 | weights=[1, 1, 1, 1, 1])
250 | self.loss_dict['style'] *= self.style_weight
251 |
252 | self.loss_dict['curviture'] = ops.curviture_loss(self.curve_s, self.curve_e, self.curve_c)
253 | self.loss_dict['curviture'] *= self.curviture_weight
254 |
255 | self.loss_dict['tv'] = ops.total_variation_loss(x_loc=self.location, s=self.curve_s, e=self.curve_e, K=10)
256 | self.loss_dict['tv'] *= self.tv_weight
257 |
258 | if hasattr(self, 'draw_curve_position') and hasattr(self, 'draw_curve_vector'):
259 | self.loss_dict['drawing'] = ops.draw_projection_loss(self.location,
260 | self.curve_s,
261 | self.curve_e,
262 | self.draw_curve_position,
263 | self.draw_curve_vector,
264 | self.draw_strength)
265 | self.loss_dict['drawing'] *= self.draw_weight
266 |
267 |
268 | def _optimizer(self):
269 | loss = tf.constant(0.0)
270 | for key in self.loss_dict:
271 | loss += self.loss_dict[key]
272 |
273 | step_ops = []
274 | optim_step = tf.train.AdamOptimizer(0.1).minimize(
275 | loss=loss,
276 | var_list=[self.location, self.curve_s, self.curve_e, self.curve_c, self.width])
277 | step_ops.append(optim_step)
278 | optim_step_color = tf.train.AdamOptimizer(0.01).minimize(
279 | loss=self.loss_dict['style'],
280 | var_list=self.color)
281 | step_ops.append(optim_step_color)
282 |
283 | # constrain parameters to their valid ranges
284 | with tf.control_dependencies(step_ops.copy()):
285 | step_ops.append(tf.assign(self.color, tf.clip_by_value(self.color, 0, 1)))
286 | coord_x, coord_y = tf.gather(self.location, axis=-1, indices=[0]), tf.gather(self.location, axis=-1, indices=[1])
287 | coord_clip = tf.concat([tf.clip_by_value(coord_x, 0, self.canvas_height), tf.clip_by_value(coord_y, 0, self.canvas_width)], axis=-1)
288 | step_ops.append(tf.assign(self.location, coord_clip))
289 | step_ops.append(tf.assign(self.width, tf.nn.relu(self.width)))
290 | self.optim_step_with_constraints = tf.group(*step_ops)
291 |
292 |
293 | class PixelOptimizer:
294 |
295 | def __init__(self,
296 | canvas, # Canvas (PIL.Image).
297 | style_img, # Style image (PIL.Image).
298 | resolution = 1024, # Resolution of the canvas.
299 | num_steps = 2000, # Number of optimization steps.
300 | content_weight = 1.0, # Weight for the content loss.
301 | style_weight = 10000.0, # Weight for the style loss.
302 | tv_weight = 0.0, # Weight for the total variation loss.
303 | streamlit_pbar = None, # Progressbar for streamlit app (obj).
304 | dtype = 'float32' # Data type.
305 | ):
306 |
307 | self.resolution = resolution
308 | self.num_steps = num_steps
309 | self.content_weight = content_weight
310 | self.style_weight = style_weight
311 | self.tv_weight = tv_weight
312 | self.streamlit_pbar = streamlit_pbar
313 | self.dtype = dtype
314 |
315 | # Set canvas size (set smaller side of content image to 'resolution' and scale other side accordingly)
316 | W, H = canvas.size
317 | if H < W:
318 | new_H = resolution
319 | new_W = int((W / H) * new_H)
320 | else:
321 | new_W = resolution
322 | new_H = int((H / W) * new_W)
323 |
324 | self.canvas_height = new_H
325 | self.canvas_width = new_W
326 |
327 | canvas = canvas.resize((self.canvas_width, self.canvas_height))
328 | style_img = style_img.resize((self.canvas_width, self.canvas_height))
329 |
330 | canvas = np.array(canvas).astype(self.dtype)
331 | style_img = np.array(style_img).astype(self.dtype)
332 |
333 | canvas /= 255.0
334 | style_img /= 255.0
335 |
336 | self.canvas_np = canvas
337 | self.content_img_np = canvas
338 | self.style_img_np = style_img
339 |
340 | ckpt_path = utils.download_weights(url='https://www.dropbox.com/s/hv7b4eajrj7isyq/vgg_weights.pickle?dl=1',
341 | name='vgg_weights.pickle')
342 | self.vgg = networks.VGG(ckpt_path=ckpt_path)
343 |
344 | def optimize(self):
345 | self._initialize()
346 | self._losses()
347 | self._optimizer()
348 |
349 | with tf.Session() as sess:
350 | sess.run(tf.global_variables_initializer())
351 | steps = trange(self.num_steps, desc='', leave=True)
352 | for step in steps:
353 | canvas_, loss_dict_, _ = \
354 | sess.run(fetches=[self.canvas,
355 | self.loss_dict,
356 | self.optim_step_with_constraints],
357 | options=config_pb2.RunOptions(report_tensor_allocations_upon_oom=True)
358 | )
359 |
360 | s = ''
361 | for key in loss_dict_:
362 | loss = loss_dict_[key]
363 | s += key + f': {loss:.6f}, '
364 |
365 | steps.set_description(s[:-2])
366 | steps.refresh()
367 | if self.streamlit_pbar is not None: self.streamlit_pbar.update(1)
368 | return Image.fromarray(np.array(np.clip(canvas_, 0, 1) * 255, dtype=np.uint8))
369 |
370 | def _initialize(self):
371 | self.canvas = tf.Variable(name='canvas', initial_value=self.canvas_np, dtype=self.dtype)
372 | self.content_img = tf.constant(name='content_img', value=self.content_img_np, dtype=self.dtype)
373 | self.style_img = tf.constant(name='style_img', value=self.style_img_np, dtype=self.dtype)
374 |
375 | def _losses(self):
376 | # preprocess images; the resize targets the full canvas size, so nothing is downscaled here
377 | rendered_canvas_resized = \
378 | tf.image.resize_nearest_neighbor(images=ops.preprocess_img(self.canvas),
379 | size=(int(self.canvas_height), int(self.canvas_width)))
380 |
381 | content_img_resized = \
382 | tf.image.resize_nearest_neighbor(images=ops.preprocess_img(self.content_img),
383 | size=(int(self.canvas_height), int(self.canvas_width)))
384 |
385 | style_img_resized = \
386 | tf.image.resize_nearest_neighbor(images=ops.preprocess_img(self.style_img),
387 | size=(int(self.canvas_height), int(self.canvas_width)))
388 |
389 | self.loss_dict = {}
390 | self.loss_dict['content'] = ops.content_loss(self.vgg.extract_features(rendered_canvas_resized),
391 | self.vgg.extract_features(content_img_resized),
392 | layers=['conv1_2_pool', 'conv2_2_pool', 'conv3_3_pool', 'conv4_3_pool', 'conv5_3_pool'],
393 | weights=[1, 1, 1, 1, 1])
394 | self.loss_dict['content'] *= self.content_weight
395 |
396 | self.loss_dict['style'] = ops.style_loss(self.vgg.extract_features(rendered_canvas_resized),
397 | self.vgg.extract_features(style_img_resized),
398 | layers=['conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv5_1'],
399 | weights=[1, 1, 1, 1, 1])
400 | self.loss_dict['style'] *= self.style_weight
401 |
402 | self.loss_dict['tv'] = ((tf.nn.l2_loss(self.canvas[1:, :, :] - self.canvas[:-1, :, :]) / self.canvas.shape.as_list()[0]) +
403 | (tf.nn.l2_loss(self.canvas[:, 1:, :] - self.canvas[:, :-1, :]) / self.canvas.shape.as_list()[1]))
404 | self.loss_dict['tv'] *= self.tv_weight
405 |
406 | def _optimizer(self):
407 | loss = tf.constant(0.0)
408 | for key in self.loss_dict:
409 | loss += self.loss_dict[key]
410 |
411 | step_ops = []
412 | optim_step = tf.train.AdamOptimizer(0.01).minimize(loss=loss, var_list=self.canvas)
413 | step_ops.append(optim_step)
414 |
415 | # constrain parameters to their valid ranges
416 | with tf.control_dependencies(step_ops.copy()):
417 | step_ops.append(tf.assign(self.canvas, tf.clip_by_value(self.canvas, 0, 1)))
418 |
419 | self.optim_step_with_constraints = tf.group(*step_ops)
420 |
421 |
422 |
423 |
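The canvas sizing rule in `PixelOptimizer.__init__` (shorter side set to `resolution`, longer side scaled to keep the aspect ratio) can be checked in isolation. The following is a minimal pure-Python sketch of that rule; the helper name `canvas_size` is hypothetical and does not exist in the repository.

```python
def canvas_size(width, height, resolution=1024):
    """Return (new_width, new_height) with the shorter side equal to `resolution`.

    Hypothetical helper mirroring the arithmetic in PixelOptimizer.__init__;
    `width` and `height` follow the PIL convention of `Image.size`.
    """
    if height < width:
        # Landscape: height is the shorter side.
        new_h = resolution
        new_w = int((width / height) * new_h)
    else:
        # Portrait or square: width is the shorter (or equal) side.
        new_w = resolution
        new_h = int((height / width) * new_w)
    return new_w, new_h


landscape = canvas_size(1920, 1080)
portrait = canvas_size(1080, 1920)
square = canvas_size(500, 500, resolution=256)
print(landscape, portrait, square)
```

Because `int()` truncates, the longer side can be off by up to one pixel from the exact aspect ratio.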
--------------------------------------------------------------------------------
/networks.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | with warnings.catch_warnings():
3 | warnings.filterwarnings("ignore", category=FutureWarning)
4 | import tensorflow as tf
5 | import pickle
6 |
7 |
8 | class VGG:
9 |
10 | def __init__(self, ckpt_path):
11 | self.param_dict = pickle.load(open(ckpt_path, 'rb'))
12 |
13 | def extract_features(self, x):
14 | features = {}
15 | x = self._conv2d_block(x, self.param_dict['block1']['conv1']['weight'], self.param_dict['block1']['conv1']['bias'])
16 | features['conv1_1'] = x
17 | x = self._conv2d_block(x, self.param_dict['block1']['conv2']['weight'], self.param_dict['block1']['conv2']['bias'])
18 | features['conv1_2'] = x
19 | x = tf.nn.max_pool2d(x, ksize=2, strides=2, padding='VALID')
20 | features['conv1_2_pool'] = x
21 |
22 | x = self._conv2d_block(x, self.param_dict['block2']['conv1']['weight'], self.param_dict['block2']['conv1']['bias'])
23 | features['conv2_1'] = x
24 | x = self._conv2d_block(x, self.param_dict['block2']['conv2']['weight'], self.param_dict['block2']['conv2']['bias'])
25 | features['conv2_2'] = x
26 | x = tf.nn.max_pool2d(x, ksize=2, strides=2, padding='VALID')
27 | features['conv2_2_pool'] = x
28 |
29 | x = self._conv2d_block(x, self.param_dict['block3']['conv1']['weight'], self.param_dict['block3']['conv1']['bias'])
30 | features['conv3_1'] = x
31 | x = self._conv2d_block(x, self.param_dict['block3']['conv2']['weight'], self.param_dict['block3']['conv2']['bias'])
32 | features['conv3_2'] = x
33 | x = self._conv2d_block(x, self.param_dict['block3']['conv3']['weight'], self.param_dict['block3']['conv3']['bias'])
34 | features['conv3_3'] = x
35 | x = tf.nn.max_pool2d(x, ksize=2, strides=2, padding='VALID')
36 | features['conv3_3_pool'] = x
37 |
38 | x = self._conv2d_block(x, self.param_dict['block4']['conv1']['weight'], self.param_dict['block4']['conv1']['bias'])
39 | features['conv4_1'] = x
40 | x = self._conv2d_block(x, self.param_dict['block4']['conv2']['weight'], self.param_dict['block4']['conv2']['bias'])
41 | features['conv4_2'] = x
42 | x = self._conv2d_block(x, self.param_dict['block4']['conv3']['weight'], self.param_dict['block4']['conv3']['bias'])
43 | features['conv4_3'] = x
44 | x = tf.nn.max_pool2d(x, ksize=2, strides=2, padding='VALID')
45 | features['conv4_3_pool'] = x
46 |
47 | x = self._conv2d_block(x, self.param_dict['block5']['conv1']['weight'], self.param_dict['block5']['conv1']['bias'])
48 | features['conv5_1'] = x
49 | x = self._conv2d_block(x, self.param_dict['block5']['conv2']['weight'], self.param_dict['block5']['conv2']['bias'])
50 | features['conv5_2'] = x
51 | x = self._conv2d_block(x, self.param_dict['block5']['conv3']['weight'], self.param_dict['block5']['conv3']['bias'])
52 | features['conv5_3'] = x
53 | x = tf.nn.max_pool2d(x, ksize=2, strides=2, padding='VALID')
54 | features['conv5_3_pool'] = x
55 | return features
56 |
57 | def _conv2d_block(self, x, kernel, bias):
58 | x = tf.nn.conv2d(x, filters=kernel, strides=[1, 1], padding=[[0, 0], [1, 1], [1, 1], [0, 0]])
59 | x = tf.nn.bias_add(x, bias=bias, data_format='N...C')
60 | x = tf.nn.relu(x)
61 | return x
62 |
63 |
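To see which spatial resolution each entry of the `features` dictionary has, the layer layout of `extract_features` can be traced without TensorFlow. The sketch below is an illustration under two assumptions that the listing does not state explicitly: the pickled kernels are 3x3 (so the explicit padding of 1 preserves the spatial size), and the input is large enough that every pooling step sees at least 2 pixels per side. The function name `vgg_feature_sizes` is hypothetical.

```python
def vgg_feature_sizes(height, width):
    """Map each feature key of VGG.extract_features to its (height, width).

    Assumes 3x3 kernels with padding 1 (size preserving) and 2x2 max pooling
    with stride 2 and VALID padding (floor division by 2).
    """
    convs_per_block = [2, 2, 3, 3, 3]  # conv layers in block1 .. block5
    sizes = {}
    h, w = height, width
    for block, n_convs in enumerate(convs_per_block, start=1):
        for conv in range(1, n_convs + 1):
            sizes[f'conv{block}_{conv}'] = (h, w)
        # VALID 2x2 pooling with stride 2 drops a trailing odd row/column.
        h, w = h // 2, w // 2
        sizes[f'conv{block}_{n_convs}_pool'] = (h, w)
    return sizes


sizes = vgg_feature_sizes(512, 768)
print(sizes['conv1_1'], sizes['conv4_2'], sizes['conv5_3_pool'])
```

Under these assumptions the style layers `conv1_1` to `conv5_1` span five resolutions, from full size down to 1/16 of the input.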
--------------------------------------------------------------------------------
/notebooks/BrushstrokeStyleTransfer.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "accelerator": "GPU",
6 | "colab": {
7 | "name": "BrushstrokeStyleTransfer.ipynb",
8 | "provenance": [],
9 | "collapsed_sections": []
10 | },
11 | "kernelspec": {
12 | "display_name": "Python 3",
13 | "name": "python3"
14 | },
15 | "language_info": {
16 | "name": "python"
17 | }
18 | },
19 | "cells": [
20 | {
21 | "cell_type": "code",
22 | "metadata": {
23 | "id": "N80TrV4juOLV"
24 | },
25 | "source": [
26 | "!pip uninstall -y albumentations\n",
27 | "!pip install -U scikit-learn\n",
28 | "!pip install scikit-image==0.17.2"
29 | ],
30 | "execution_count": null,
31 | "outputs": []
32 | },
33 | {
34 | "cell_type": "markdown",
35 | "metadata": {
36 | "id": "WfavNAsehO_-"
37 | },
38 | "source": [
39 | "Ignore the error messages..."
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "metadata": {
45 | "id": "d_gmUTe69NK2"
46 | },
47 | "source": [
48 | "!git clone https://github.com/CompVis/brushstroke-parameterized-style-transfer\n",
49 | "%cd brushstroke-parameterized-style-transfer"
50 | ],
51 | "execution_count": null,
52 | "outputs": []
53 | },
54 | {
55 | "cell_type": "code",
56 | "metadata": {
57 | "id": "Bg7ppjI48CuI"
58 | },
59 | "source": [
60 | "import os\n",
61 | "from google.colab import files\n",
62 | "import matplotlib.pyplot as plt\n",
63 | "import numpy as np\n",
64 | "from PIL import Image\n",
65 | "\n",
66 | "%tensorflow_version 1.x\n",
67 | "import tensorflow as tf\n",
68 | "tf.test.is_gpu_available()\n",
69 | "tf.__version__"
70 | ],
71 | "execution_count": null,
72 | "outputs": []
73 | },
74 | {
75 | "cell_type": "markdown",
76 | "metadata": {
77 | "id": "eHSa8q_ziWuc"
78 | },
79 | "source": [
80 | "Upload your own images"
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "metadata": {
86 | "id": "EcaQJvqfieUw"
87 | },
88 | "source": [
89 | "uploaded = files.upload()\n",
90 | "\n",
91 | "for fn in uploaded.keys():\n",
92 | " print('User uploaded file \"{name}\" with length {length} bytes'.format(\n",
93 | " name=fn, length=len(uploaded[fn])))"
94 | ],
95 | "execution_count": null,
96 | "outputs": []
97 | },
98 | {
99 | "cell_type": "markdown",
100 | "metadata": {
101 | "id": "zlKwIBl8jeFc"
102 | },
103 | "source": [
104 | "Or choose from the available images\n"
105 | ]
106 | },
107 | {
108 | "cell_type": "markdown",
109 | "metadata": {
110 | "id": "ZY5Xnu5_oCor"
111 | },
112 | "source": [
113 | "Content Images"
114 | ]
115 | },
116 | {
117 | "cell_type": "code",
118 | "metadata": {
119 | "id": "KbrqOU0XjLgQ"
120 | },
121 | "source": [
122 | "content_images = os.listdir('images/content')\n",
123 | "print(content_images)\n",
124 | "\n",
125 | "fig, ax = plt.subplots(nrows=1, ncols=len(content_images), figsize=(30, 5))\n",
126 | "for i in range(len(content_images)):\n",
127 | " img = Image.open(os.path.join('images/content/', content_images[i]))\n",
128 | " ax[i].imshow(np.array(img))\n",
129 | " ax[i].axis('off')\n",
130 | " ax[i].title.set_text(content_images[i])"
131 | ],
132 | "execution_count": null,
133 | "outputs": []
134 | },
135 | {
136 | "cell_type": "markdown",
137 | "metadata": {
138 | "id": "UxzLXLvtpkVJ"
139 | },
140 | "source": [
141 | "Style Images\n"
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "metadata": {
147 | "id": "SMeF9GzxoTIA"
148 | },
149 | "source": [
150 | "style_images = os.listdir('images/style')\n",
151 | "print(style_images)\n",
152 | "\n",
153 | "fig, ax = plt.subplots(nrows=1, ncols=len(style_images), figsize=(30, 5))\n",
154 | "for i in range(len(style_images)):\n",
155 | " img = Image.open(os.path.join('images/style/', style_images[i]))\n",
156 | " ax[i].imshow(np.array(img))\n",
157 | " ax[i].axis('off')\n",
158 | " ax[i].title.set_text(style_images[i])"
159 | ],
160 | "execution_count": null,
161 | "outputs": []
162 | },
163 | {
164 | "cell_type": "markdown",
165 | "metadata": {
166 | "id": "0AJJixL6rd5R"
167 | },
168 | "source": [
169 | "Load content and style images"
170 | ]
171 | },
172 | {
173 | "cell_type": "code",
174 | "metadata": {
175 | "id": "VRTmFFtjpq_W"
176 | },
177 | "source": [
178 | "# Note: for uploaded images, the prefix 'images/{content, style}' needs to be removed.\n",
179 | "# If you upload an image named 'uploaded_image.jpg', the path to use is also 'uploaded_image.jpg'.\n",
180 | "content_img = Image.open('images/content/golden_gate.jpg')\n",
181 | "style_img = Image.open('images/style/van_gogh_starry_night.jpg')\n",
182 | "\n",
183 | "fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))\n",
184 | "ax[0].imshow(np.array(content_img))\n",
185 | "ax[0].axis('off')\n",
186 | "ax[0].title.set_text('Content image')\n",
187 | "\n",
188 | "ax[1].imshow(np.array(style_img))\n",
189 | "ax[1].axis('off')\n",
190 | "ax[1].title.set_text('Style image')"
191 | ],
192 | "execution_count": null,
193 | "outputs": []
194 | },
195 | {
196 | "cell_type": "code",
197 | "metadata": {
198 | "id": "O5-H46Y0q_99"
199 | },
200 | "source": [
201 | "import model\n",
202 | "\n",
203 | "\n",
204 | "stylized_img = model.stylize(content_img, # Content image (PIL.Image).\n",
205 | " style_img, # Style image (PIL.Image).\n",
206 | " num_strokes=5000, # Number of brushstrokes (int). \n",
207 | " num_steps=100, # Number of stroke optimization steps (int).\n",
208 | " canvas_color='gray', # Color of the canvas (str). Options: 'gray', 'white', 'black', 'noise'\n",
209 | " width_scale=0.1, # Scale parameter for the brushstroke width (float).\n",
210 | " length_scale=1.1, # Scale parameter for the brushstroke length (float).\n",
211 | " content_weight=1.0, # Weight for the content loss (float).\n",
212 | " style_weight=3.0, # Weight for the style loss (float).\n",
213 | " tv_weight=0.008, # Weight for the total variation loss (float). \n",
214 | " pixel_resolution=1024, # Resolution of the canvas for pixel optimization (int).\n",
215 | " num_steps_pixel=2000) # Number of pixel optimization steps (int).\n",
216 | "\n",
217 | "stylized_img.save('stylized.jpg')\n",
218 | "display(stylized_img)"
219 | ],
220 | "execution_count": null,
221 | "outputs": []
222 | }
223 | ]
224 | }
--------------------------------------------------------------------------------
/notebooks/BrushstrokeStyleTransferDrawApp.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "BrushstrokeStyleTransferDrawApp.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": []
9 | },
10 | "kernelspec": {
11 | "name": "python3",
12 | "display_name": "Python 3"
13 | },
14 | "language_info": {
15 | "name": "python"
16 | },
17 | "accelerator": "GPU"
18 | },
19 | "cells": [
20 | {
21 | "cell_type": "code",
22 | "metadata": {
23 | "id": "Yhv1eMzsHm5E"
24 | },
25 | "source": [
26 | "!pip uninstall -y datascience\n",
27 | "!pip install colab-everything\n",
28 | "!pip install streamlit-drawable-canvas\n",
29 | "!pip install -U scikit-learn\n",
30 | "!pip install scikit-image==0.17.2\n",
31 | "!pip install stqdm"
32 | ],
33 | "execution_count": null,
34 | "outputs": []
35 | },
36 | {
37 | "cell_type": "markdown",
38 | "metadata": {
39 | "id": "KZv5oLILZwZo"
40 | },
41 | "source": [
42 | "Ignore the error messages..."
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "metadata": {
48 | "id": "elOYlKQ9Xz9X"
49 | },
50 | "source": [
51 | "!git clone https://github.com/CompVis/brushstroke-parameterized-style-transfer\n",
52 | "%cd brushstroke-parameterized-style-transfer"
53 | ],
54 | "execution_count": null,
55 | "outputs": []
56 | },
57 | {
58 | "cell_type": "markdown",
59 | "metadata": {
60 | "id": "vtKSiie6VW0V"
61 | },
62 | "source": [
63 | "Run app"
64 | ]
65 | },
66 | {
67 | "cell_type": "markdown",
68 | "metadata": {
69 | "id": "Iv_qmYLnVnLG"
70 | },
71 | "source": [
72 | "You have to click on the link that ends with \".ngrok.io\".\n",
73 | "\n",
74 | "
"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "metadata": {
80 | "id": "JPZzcLolSZHj"
81 | },
82 | "source": [
83 | "%tensorflow_version 1.x\n",
84 | "import tensorflow as tf\n",
85 | "tf.test.is_gpu_available()\n",
86 | "tf.__version__\n",
87 | "\n",
88 | "from colab_everything import ColabStreamlit\n",
89 | "ColabStreamlit('app.py')"
90 | ],
91 | "execution_count": null,
92 | "outputs": []
93 | },
94 | {
95 | "cell_type": "code",
96 | "metadata": {
97 | "id": "jPA8BG9NKvd6"
98 | },
99 | "source": [
100 | ""
101 | ],
102 | "execution_count": null,
103 | "outputs": []
104 | },
105 | {
106 | "cell_type": "code",
107 | "metadata": {
108 | "id": "T8OQDcflKxQN"
109 | },
110 | "source": [
111 | ""
112 | ],
113 | "execution_count": null,
114 | "outputs": []
115 | },
116 | {
117 | "cell_type": "code",
118 | "metadata": {
119 | "id": "jXjSTak_K1bv"
120 | },
121 | "source": [
122 | ""
123 | ],
124 | "execution_count": null,
125 | "outputs": []
126 | }
127 | ]
128 | }
--------------------------------------------------------------------------------
/ops.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | with warnings.catch_warnings():
4 | warnings.filterwarnings("ignore", category=FutureWarning)
5 | import tensorflow as tf
6 |
7 |
8 | #---------------------------------------------------------------------
9 | # Misc
10 | #---------------------------------------------------------------------
11 | def preprocess_img(x):
12 | x = 2 * x - 1
13 | x = tf.expand_dims(x, axis=0)
14 | return x
15 |
16 |
17 | def norm(x, axis=None, keepdims=None, eps=1e-8):
18 | """
19 | Numerically stable norm.
20 | """
21 | return tf.sqrt(tf.reduce_sum(tf.square(x), axis=axis, keepdims=keepdims) + eps)
22 | #return tf.reduce_sum(tf.square(x), axis=axis, keepdims=keepdims)
23 |
24 |
25 | #---------------------------------------------------------------------
26 | # Brushstrokes
27 | #---------------------------------------------------------------------
28 | def sample_quadratic_bezier_curve(s, c, e, num_points=20, dtype='float32'):
29 | """
30 | Samples points from the quadratic Bezier curves defined by the control points.
31 | The number of points sampled per curve is num_points.
32 |
33 | Args:
34 | s (tensor): Start point of each curve, shape [N, 2].
35 | c (tensor): Control point of each curve, shape [N, 2].
36 | e (tensor): End point of each curve, shape [N, 2].
37 | num_points (int): Number of points to sample on every curve.
38 |
39 | Returns:
40 | (tensor): Coordinates of the points on the Bezier curves, shape [N, num_points, 2]
41 | """
42 | N, _ = s.shape.as_list()
43 | t = tf.linspace(0., 1., num_points)
44 | t = tf.cast(t, dtype=dtype)
45 | t = tf.stack([t] * N, axis=0)
46 | s_x = tf.expand_dims(s[..., 0], axis=1)
47 | s_y = tf.expand_dims(s[..., 1], axis=1)
48 | e_x = tf.expand_dims(e[..., 0], axis=1)
49 | e_y = tf.expand_dims(e[..., 1], axis=1)
50 | c_x = tf.expand_dims(c[..., 0], axis=1)
51 | c_y = tf.expand_dims(c[..., 1], axis=1)
52 | x = c_x + (1. - t) ** 2 * (s_x - c_x) + t ** 2 * (e_x - c_x)
53 | y = c_y + (1. - t) ** 2 * (s_y - c_y) + t ** 2 * (e_y - c_y)
54 | return tf.stack([x, y], axis=-1)
55 |
56 |
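The expression used in `sample_quadratic_bezier_curve`, `c + (1 - t)^2 (s - c) + t^2 (e - c)`, is a rearrangement of the textbook Bernstein form `(1 - t)^2 s + 2 t (1 - t) c + t^2 e`. The following pure-Python sketch samples a single curve with the rearranged form and cross-checks it against the Bernstein form. Both function names are hypothetical, and the sketch assumes `num_points >= 2`.

```python
def sample_quadratic_bezier(s, c, e, num_points=20):
    """Sample points on one quadratic Bezier curve (single-curve sketch)."""
    points = []
    for i in range(num_points):
        t = i / (num_points - 1)  # evenly spaced in [0, 1], like tf.linspace
        # Rearranged form used in ops.py: c + (1-t)^2 (s-c) + t^2 (e-c)
        x = c[0] + (1 - t) ** 2 * (s[0] - c[0]) + t ** 2 * (e[0] - c[0])
        y = c[1] + (1 - t) ** 2 * (s[1] - c[1]) + t ** 2 * (e[1] - c[1])
        points.append((x, y))
    return points


def bernstein_point(s, c, e, t):
    """Textbook Bernstein form, used here only as a cross-check."""
    return tuple((1 - t) ** 2 * s[k] + 2 * t * (1 - t) * c[k] + t ** 2 * e[k]
                 for k in range(2))


start, control, end = (0.0, 0.0), (1.0, 2.0), (2.0, 0.0)
pts = sample_quadratic_bezier(start, control, end, num_points=5)
print(pts[0], pts[2], pts[-1])
```

The curve starts at `s`, ends at `e`, and is pulled toward (but does not pass through) the control point `c`.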
57 | def renderer(curve_points, locations, colors, widths, H, W, K, canvas_color='gray', dtype='float32'):
58 | """
59 | Renders the given brushstroke parameters onto a canvas.
60 | See Alg. 1 in https://arxiv.org/pdf/2103.17185.pdf.
61 |
62 | Args:
63 | curve_points (tensor): Points specifying the curves that will be rendered on the canvas, shape [N, S, 2].
64 | locations (tensor): Location of each curve, shape [N, 2].
65 | colors (tensor): Color of each curve, shape [N, 3].
66 | widths (tensor): Width of each curve, shape [N, 1].
67 | H (int): Height of the canvas.
68 | W (int): Width of the canvas.
69 | K (int): Number of brushstrokes to consider for each pixel, see Sec. C.2 of the paper (Arxiv version).
70 | canvas_color (str): Background color of the canvas. Options: 'gray', 'white', 'black', 'noise'.
71 | Returns:
72 | (tensor): The rendered canvas, shape [H, W, 3].
73 | """
74 | N, S, _ = curve_points.shape.as_list()
75 | # define coarse grid cell
76 | t_H = tf.linspace(0., float(H), int(H // 5))
77 | t_W = tf.linspace(0., float(W), int(W // 5))
78 | t_H = tf.cast(t_H, dtype=dtype)
79 | t_W = tf.cast(t_W, dtype=dtype)
80 | P_y, P_x = tf.meshgrid(t_W, t_H)
81 | P = tf.stack([P_x, P_y], axis=-1) # [H // 5, W // 5, 2]
82 | # Compute distances from every brushstroke center to every coarse grid cell
83 | #P_norms = tf.square(norm(P, axis=-1))
84 | #B_center_norms = tf.square(norm(locations, axis=-1))
85 | #P_dot_B_center = tf.einsum('xyf,Nf->xyN', P, locations)
86 | # [32, 32, N]
87 | #D_to_all_B_centers = tf.expand_dims(P_norms, axis=-1) + tf.expand_dims(tf.expand_dims(B_center_norms, axis=0), axis=0) - 2. * P_dot_B_center
88 |
89 | #####
90 | D_to_all_B_centers = tf.reduce_sum(tf.square(tf.expand_dims(P, axis=-2) - locations), axis=-1) # [H // C, W // C, N]
91 | #####
92 |
93 | # Find nearest brushstrokes' indices for every coarse grid cell
94 | _, idcs = tf.math.top_k(-D_to_all_B_centers, k=K) # [H // 5, W // 5, K]
95 | # Now create 3 tensors on the coarse grid, holding the sampled curve points, the colors,
96 | # and the widths of the K nearest brushstrokes.
97 | # [H // 5, W // 5, K, S, 2]
98 | canvas_with_nearest_Bs = tf.gather(params=curve_points,
99 | indices=idcs,
100 | batch_dims=0)
101 | # [H // 5, W // 5, K, 3]
102 | canvas_with_nearest_Bs_colors = tf.gather(params=colors,
103 | indices=idcs,
104 | batch_dims=0)
105 | # [H // 5, W // 5, K, 1]
106 | canvas_with_nearest_Bs_bs = tf.gather(params=widths,
107 | indices=idcs,
108 | batch_dims=0)
109 | # Resize those tensors to the full canvas size (not coarse grid)
110 | # First locations of points sampled from curves
111 | H_, W_, r1, r2, r3 = canvas_with_nearest_Bs.shape.as_list()
112 | canvas_with_nearest_Bs = tf.reshape(canvas_with_nearest_Bs, shape=(1, H_, W_, r1 * r2 * r3)) # [1, H // 5, W // 5, K * S * 2]
113 | canvas_with_nearest_Bs = tf.image.resize_nearest_neighbor(canvas_with_nearest_Bs, size=(H, W)) # [1, H, W, K * S * 2]
114 | canvas_with_nearest_Bs = tf.reshape(canvas_with_nearest_Bs, shape=(H, W, r1, r2, r3)) # [H, W, K, S, 2]
115 | # Now colors of curves
116 | H_, W_, r1, r2 = canvas_with_nearest_Bs_colors.shape.as_list()
117 | canvas_with_nearest_Bs_colors = tf.reshape(canvas_with_nearest_Bs_colors, shape=(1, H_, W_, r1 * r2)) # [1, H // 5, W // 5, K * 3]
118 | canvas_with_nearest_Bs_colors = tf.image.resize_nearest_neighbor(canvas_with_nearest_Bs_colors, size=(H, W)) # [1, H, W, K * 3]
119 | canvas_with_nearest_Bs_colors = tf.reshape(canvas_with_nearest_Bs_colors, shape=(H, W, r1, r2)) # [H, W, K, 3]
120 | # And with the brush size
121 | H_, W_, r1, r2 = canvas_with_nearest_Bs_bs.shape.as_list()
122 | canvas_with_nearest_Bs_bs = tf.reshape(canvas_with_nearest_Bs_bs, shape=(1, H_, W_, r1 * r2)) # [1, H // 5, W // 5, K]
123 | canvas_with_nearest_Bs_bs = tf.image.resize_nearest_neighbor(canvas_with_nearest_Bs_bs, size=(H, W)) # [1, H, W, K]
124 | canvas_with_nearest_Bs_bs = tf.reshape(canvas_with_nearest_Bs_bs, shape=(H, W, r1, r2)) # [H, W, K, 1]
125 | # Now create full-size canvas
126 | t_H = tf.linspace(0., float(H), H)
127 | t_W = tf.linspace(0., float(W), W)
128 | t_H = tf.cast(t_H, dtype=dtype)
129 | t_W = tf.cast(t_W, dtype=dtype)
130 | P_y, P_x = tf.meshgrid(t_W, t_H)
131 | P_full = tf.stack([P_x, P_y], axis=-1) # [H, W, 2]
132 | # Compute distance from every pixel on canvas to each (among nearest ones) line segment between points from curves
133 | canvas_with_nearest_Bs_a = tf.gather(canvas_with_nearest_Bs, axis=-2, indices=[i for i in range(S - 1)]) # start points of each line segment
134 | canvas_with_nearest_Bs_b = tf.gather(canvas_with_nearest_Bs, axis=-2, indices=[i for i in range(1, S)]) # end points of each line segment
135 | canvas_with_nearest_Bs_b_a = canvas_with_nearest_Bs_b - canvas_with_nearest_Bs_a # [H, W, K, S - 1, 2]
136 | P_full_canvas_with_nearest_Bs_a = tf.expand_dims(tf.expand_dims(P_full, axis=2), axis=2) - canvas_with_nearest_Bs_a # [H, W, K, S - 1, 2]
137 | # compute t value for which each pixel is closest to each line that goes through each line segment (among nearest ones)
138 | t = tf.reduce_sum(canvas_with_nearest_Bs_b_a * P_full_canvas_with_nearest_Bs_a, axis=-1) \
139 | / (tf.reduce_sum(tf.square(canvas_with_nearest_Bs_b_a), axis=-1) + 1e-8)
140 | # if t value is outside [0, 1], then the nearest point on the line does not lie on the segment, so clip values of t
141 | t = tf.clip_by_value(t, clip_value_min=0.0, clip_value_max=1.0)
142 | # compute closest points on each line segment - [H, W, K, S - 1, 2]
143 | closest_points_on_each_line_segment = canvas_with_nearest_Bs_a + tf.expand_dims(t, axis=-1) * canvas_with_nearest_Bs_b_a
144 | # compute the distance from every pixel to the closest point on each line segment - [H, W, K, S - 1]
145 | dist_to_closest_point_on_line_segment = \
146 | tf.reduce_sum(tf.square(tf.expand_dims(tf.expand_dims(P_full, axis=2), axis=2) - closest_points_on_each_line_segment), axis=-1)
147 | # and distance to the nearest bezier curve.
148 | D = tf.reduce_min(dist_to_closest_point_on_line_segment, axis=[-1, -2]) # [H, W]
149 | # Finally render curves on a canvas to obtain image.
150 | I_NNs_B_ranking = tf.nn.softmax(100000. * (1.0 / (1e-8 + tf.reduce_min(dist_to_closest_point_on_line_segment, axis=[-1]))), axis=-1) # [H, W, K]
151 | I_colors = tf.einsum('hwnf,hwn->hwf', canvas_with_nearest_Bs_colors, I_NNs_B_ranking) # [H, W, 3]
152 | bs = tf.einsum('hwnf,hwn->hwf', canvas_with_nearest_Bs_bs, I_NNs_B_ranking) # [H, W, 1]
153 | bs_mask = tf.math.sigmoid(bs - tf.expand_dims(D, axis=-1))
154 | if canvas_color == 'gray':
155 | canvas = tf.ones(shape=I_colors.shape, dtype=dtype) * 0.5
156 | elif canvas_color == 'white':
157 | canvas = tf.ones(shape=I_colors.shape, dtype=dtype)
158 | elif canvas_color == 'black':
159 | canvas = tf.zeros(shape=I_colors.shape, dtype=dtype)
160 | elif canvas_color == 'noise':
161 | canvas = tf.random.normal(shape=I_colors.shape, dtype=dtype) * 0.1
162 |
163 | I = I_colors * bs_mask + (1 - bs_mask) * canvas
164 | return I
165 |
166 |
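The core geometric step of `renderer` is the distance from a pixel to a line segment: project the pixel onto the line through the segment, clip the parameter `t` to `[0, 1]`, and measure the squared distance to the resulting point. The sketch below does this in pure Python for a single pixel and a single segment; the function names are hypothetical, and it omits the batching over pixels, strokes, and segments that the TensorFlow code performs.

```python
def closest_point_on_segment(p, a, b, eps=1e-8):
    """Closest point to p on the segment from a to b (2-D tuples)."""
    ab = (b[0] - a[0], b[1] - a[1])
    ap = (p[0] - a[0], p[1] - a[1])
    # Parameter of the orthogonal projection onto the infinite line;
    # eps guards against zero-length segments, as in ops.py.
    t = (ab[0] * ap[0] + ab[1] * ap[1]) / (ab[0] ** 2 + ab[1] ** 2 + eps)
    # Outside [0, 1] the projection misses the segment, so clip to an endpoint.
    t = min(max(t, 0.0), 1.0)
    return (a[0] + t * ab[0], a[1] + t * ab[1])


def squared_distance_to_segment(p, a, b):
    """Squared Euclidean distance from p to the segment (no square root)."""
    q = closest_point_on_segment(p, a, b)
    return (p[0] - q[0]) ** 2 + (p[1] - q[1]) ** 2


a, b = (0.0, 0.0), (2.0, 0.0)
d_above = squared_distance_to_segment((1.0, 1.0), a, b)   # projects inside the segment
d_right = squared_distance_to_segment((3.0, 0.0), a, b)   # clipped to endpoint b
d_point = squared_distance_to_segment((0.0, 3.0), a, a)   # degenerate segment
print(d_above, d_right, d_point)
```

Note that, like the listing, this works with squared distances throughout, so thresholds compared against it are in squared pixel units.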
167 | #---------------------------------------------------------------------
168 | # Losses
169 | #---------------------------------------------------------------------
170 | def content_loss(features_lhs, features_rhs, layers, weights, scale_by_y=False):
171 | """
172 | Computes the VGG perceptual loss.
173 |
174 | Args:
175 | features_lhs (dict of tensors): Dictionary of VGG activations.
176 | features_rhs (dict of tensors): Dictionary of VGG activations.
177 | layers (list of str): List specifying the layers to use.
178 | weights (list of floats): List specifying the weights for the used layers.
179 |
180 | Returns:
181 | VGG perceptual loss.
182 | """
183 |
184 | feat_lhs = [features_lhs[key] for key in layers]
185 | feat_rhs = [features_rhs[key] for key in layers]
186 |
187 | if scale_by_y:
188 | losses = [w * tf.reduce_mean(tf.square(xf - yf) * tf.minimum(yf, tf.sigmoid(yf))) for w, xf, yf in zip(weights, feat_lhs, feat_rhs)]
189 | else:
190 | losses = [w * tf.reduce_mean(tf.square(xf - yf)) for w, xf, yf in zip(weights, feat_lhs, feat_rhs)]
191 |
192 | loss = tf.add_n(losses)
193 | return loss
194 |
195 |
196 | def get_gram_matrices(features):
197 | """
198 | Computes the gram matrices for the given list of activations.
199 |
200 | Args:
201 | features (list of tensors): List of VGG activations, each of shape [B, H, W, C].
202 |
203 | Returns:
204 | List of gram matrices.
205 | """
206 | gram_matrices = []
207 | for feature in features:
208 | gram_matrix = tf.einsum('bhwf,bhwl->bfl', feature, feature)
209 | B, H, W, C = feature.shape.as_list()
210 | gram_matrix /= tf.cast(H * W * C, dtype=tf.float32)
211 | gram_matrices.append(gram_matrix)
212 | return gram_matrices
213 |
214 |
215 | def style_loss(features_lhs, features_rhs, layers, weights):
216 | """
217 | Computes the VGG gram matrix style loss.
218 |
219 | Args:
220 | features_lhs (dict of tensors): Dictionary of VGG activations.
221 | features_rhs (dict of tensors): Dictionary of VGG activations.
222 | layers (list of str): List specifying the layers to use.
223 | weights (list of floats): List specifying the weights for the used layers.
224 |
225 | Returns:
226 | VGG gram matrix style loss.
227 | """
228 | feat_lhs = [features_lhs[key] for key in layers]
229 | feat_rhs = [features_rhs[key] for key in layers]
230 | gram_matrices_lhs = get_gram_matrices(feat_lhs)
231 | gram_matrices_rhs = get_gram_matrices(feat_rhs)
232 | losses = [w * tf.reduce_sum(tf.square(gram_lhs - gram_rhs)) for w, gram_lhs, gram_rhs in zip(weights, gram_matrices_lhs, gram_matrices_rhs)]
233 | loss = tf.add_n(losses)
234 | return loss
235 |
236 |
237 | def get_nn_idxs(X, k, fetch_dist=False):
238 | """
239 | For a given tensor compute all the nearest neighbor indices to each element.
240 |
241 | Args:
242 | X (tensor): Tensor of shape [B, N, F].
243 | k (int): Number of nearest neighbors.
244 | fetch_dist (bool): Also return the distances.
245 |
246 | Returns:
247 | Tensor of shape [B, N, k].
248 |
249 | """
250 | r = tf.reduce_sum(X * X, 2, keepdims=True)
251 | D = r - 2 * tf.matmul(X, tf.transpose(X, perm=(0, 2, 1))) + tf.transpose(r, perm=(0, 2, 1))
252 | X_top_vals, X_top_idxs = tf.math.top_k(-D, k=k, sorted=True, name=None)
253 |
254 | if fetch_dist:
255 | return X_top_idxs, X_top_vals
256 | else:
257 | return X_top_idxs
258 |
259 |
260 | def total_variation_loss(x_loc, s, e, K=10):
261 |
262 | def projection(z):
263 | x = tf.gather(z, axis=-1, indices=[0])
264 | y = tf.gather(z, axis=-1, indices=[1])
265 | return tf.concat([tf.square(x), tf.square(y), x * y], axis=-1)
266 |
267 | se_vec = e - s
268 | se_vec_proj = projection(se_vec)
269 |
270 | x_nn_idcs = get_nn_idxs(tf.expand_dims(x_loc, axis=0), k=K)
271 |
272 | x_nn_idcs = tf.squeeze(x_nn_idcs, axis=0)
273 | x_sig_nns = tf.gather(se_vec, indices=x_nn_idcs, axis=0, batch_dims=0)
274 |
275 | dist_to_centroid = tf.reduce_mean(tf.reduce_sum(tf.square(projection(x_sig_nns) - tf.expand_dims(projection(se_vec), axis=-2)), axis=-1))
276 | return dist_to_centroid
277 |
278 |
279 | def draw_projection_loss(location, s, e, draw_curve_position, draw_curve_vector, draw_strength):
280 | dist = tf.reduce_sum(tf.square(tf.expand_dims(draw_curve_position, axis=1) - location), axis=-1)
281 | _, idcs = tf.math.top_k(-dist, k=draw_strength) # [num_points, K]
282 | se_vec = e - s
283 | strokes_vec_nn = tf.gather(se_vec, indices=idcs, axis=0) # [num_points, K, 2]
284 | strokes_vec_nn /= (norm(strokes_vec_nn, axis=-1, keepdims=True) + 1e-6)
285 | curves_vec = draw_curve_vector / (norm(draw_curve_vector, axis=-1, keepdims=True) + 1e-6)
286 | projection = tf.abs(tf.einsum('mki,mi->mk', strokes_vec_nn, curves_vec)) # [num_points, K]
287 | projection_loss = tf.reduce_mean(tf.square(1 - projection))
288 | return projection_loss
289 |
290 |
291 | def curviture_loss(s, e, c):
292 | v1 = s - c
293 | v2 = e - c
294 | dist_se = norm(e - s, axis=-1) + 1e-6
295 | return tf.reduce_mean(norm(v1 + v2, axis=-1) / dist_se)
296 |
297 |
--------------------------------------------------------------------------------
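The nearest-neighbor lookup in `get_nn_idxs` relies on the expansion ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2 and then takes the top-k of the negated distances. The following is a minimal pure-Python sketch of that idea for a single unbatched point set. The helper names `squared_distance_matrix` and `nearest_neighbor_indices` are illustrative and not part of the repository.

```python
def squared_distance_matrix(points):
    """Pairwise squared distances via ||a||^2 - 2 a.b + ||b||^2."""
    sq_norms = [sum(v * v for v in p) for p in points]
    n = len(points)
    return [
        [
            sq_norms[i]
            - 2 * sum(a * b for a, b in zip(points[i], points[j]))
            + sq_norms[j]
            for j in range(n)
        ]
        for i in range(n)
    ]


def nearest_neighbor_indices(points, k):
    """Indices of the k closest points per row, closest first.

    Each point is its own nearest neighbor (distance 0), mirroring
    what a top-k over the negated distance matrix returns.
    """
    dist = squared_distance_matrix(points)
    return [sorted(range(len(row)), key=lambda j: row[j])[:k] for row in dist]


# Hypothetical sample data: four 2-D points.
points = [(0.0, 0.0), (1.0, 0.0), (0.0, 3.0), (5.0, 5.0)]
neighbors = nearest_neighbor_indices(points, k=2)
print(neighbors)
```

Because every point appears in its own neighbor list, a caller that wants K true neighbors would presumably request K + 1 and drop the first column. The code above does not do that, and neither does `total_variation_loss`.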
/requirements.txt:
--------------------------------------------------------------------------------
1 | Pillow==7.2.0
2 | tqdm
3 | requests==2.24.0
4 | scikit-image==0.17.2
5 | numpy==1.19.1
6 | scikit-learn
7 | streamlit==0.82.0
8 | streamlit-drawable-canvas
9 | stqdm
10 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import requests
3 | import os
4 | from PIL import Image
5 | import numpy as np
6 | from skimage.segmentation import slic
7 | from scipy.spatial import ConvexHull
8 |
9 |
10 | #------------------------------------------------------------------
11 | # I/O
12 | #------------------------------------------------------------------
13 |
14 | def download_weights(url, name):
15 | """
16 | Downloads the checkpoint file specified by 'url'.
17 |
18 | Args:
19 | url (str): URL specifying the checkpoint file.
20 | name (str): Name under which the checkpoint file will be stored.
21 |
22 | Returns:
23 | (str): Path to the checkpoint file.
24 | """
25 | ckpt_dir = 'pretrained_weights'
26 | ckpt_file = os.path.join(ckpt_dir, name)
27 | if not os.path.exists(ckpt_file):
28 | print(f'Downloading: \"{url.rsplit("?", 1)[0]}\" to {ckpt_file}')
29 | if not os.path.exists(ckpt_dir):
30 | os.makedirs(ckpt_dir)
31 |
32 | response = requests.get(url, stream=True)
33 | total_size_in_bytes = int(response.headers.get('content-length', 0))
34 | progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
35 |
36 | # first create temp file, in case the download fails
37 | ckpt_file_temp = os.path.join(ckpt_dir, name + '.temp')
38 | with open(ckpt_file_temp, 'wb') as file:
39 | for data in response.iter_content(chunk_size=1024):
40 | progress_bar.update(len(data))
41 | file.write(data)
42 | progress_bar.close()
43 |
44 | if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
45 | print('An error occurred while downloading, please try again.')
46 | if os.path.exists(ckpt_file_temp):
47 | os.remove(ckpt_file_temp)
48 | else:
49 | # if download was successful, rename the temp file
50 | os.rename(ckpt_file_temp, ckpt_file)
51 | return ckpt_file
52 |
53 |
54 | #------------------------------------------------------------------
55 | # Brushstrokes
56 | #------------------------------------------------------------------
57 |
58 | def clusters_to_strokes(segments, img, H, W, sec_scale=0.001, width_scale=1):
59 | segments += np.abs(np.min(segments))
60 | num_clusters = np.max(segments)
61 | clusters_params = {'center': [],
62 | 's': [],
63 | 'e': [],
64 | 'bp1': [],
65 | 'bp2': [],
66 | 'num_pixels': [],
67 | 'stddev': [],
68 | 'width': [],
69 | 'color_rgb': []
70 | }
71 |
72 | for cluster_idx in range(num_clusters + 1):
73 | cluster_mask = segments==cluster_idx
74 | if np.sum(cluster_mask) < 5: continue
75 | cluster_mask_nonzeros = np.nonzero(cluster_mask)
76 |
77 | cluster_points = np.stack((cluster_mask_nonzeros[0], cluster_mask_nonzeros[1]), axis=-1)
78 | try:
79 | convex_hull = ConvexHull(cluster_points)
80 | except Exception:
81 | continue
82 |
83 | # find the two points (pixels) in the cluster that have the largest distance between them
84 | border_points = cluster_points[convex_hull.simplices.reshape(-1)]
85 | dist = np.sum((np.expand_dims(border_points, axis=1) - border_points)**2, axis=-1)
86 | max_idx_a, max_idx_b = np.nonzero(dist == np.max(dist))
87 | point_a = border_points[max_idx_a[0]]
88 | point_b = border_points[max_idx_b[0]]
89 | # compute the two points where the perpendicular bisector of the segment (point_a, point_b) intersects the convex hull
90 | v_ba = point_b - point_a
91 | v_orth = np.array([v_ba[1], -v_ba[0]])
92 | m = (point_a + point_b) / 2.0
93 | n = m + 0.5 * v_orth
94 | p = cluster_points[convex_hull.simplices][:, 0]
95 | q = cluster_points[convex_hull.simplices][:, 1]
96 | u = - ((m[..., 0] - n[..., 0]) * (m[..., 1] - p[..., 1]) - (m[..., 1] - n[..., 1]) * (m[..., 0] - p[..., 0])) \
97 | / ((m[..., 0] - n[..., 0]) * (p[..., 1] - q[..., 1]) - (m[..., 1] - n[..., 1]) * (p[..., 0] - q[..., 0]))
98 | intersec_idcs = np.logical_and(u >= 0, u <= 1)
99 | intersec_points = p + u.reshape(-1, 1) * (q - p)
100 | intersec_points = intersec_points[intersec_idcs]
101 |
102 | width = np.sum((intersec_points[0] - intersec_points[1])**2)
103 |
104 | if width == 0.0: continue
105 |
106 | clusters_params['s'].append(point_a / img.shape[:2])
107 | clusters_params['e'].append(point_b / img.shape[:2])
108 | clusters_params['bp1'].append(intersec_points[0] / img.shape[:2])
109 | clusters_params['bp2'].append(intersec_points[1] / img.shape[:2])
110 | clusters_params['width'].append(width)
111 |
112 | clusters_params['color_rgb'].append(np.mean(img[cluster_mask], axis=0))
113 | center_x = np.mean(cluster_mask_nonzeros[0]) / img.shape[0]
114 | center_y = np.mean(cluster_mask_nonzeros[1]) / img.shape[1]
115 | clusters_params['center'].append(np.array([center_x, center_y]))
116 | clusters_params['num_pixels'].append(np.sum(cluster_mask))
117 | clusters_params['stddev'].append(np.mean(np.std(img[cluster_mask], axis=0)))
118 |
119 | for key in clusters_params.keys():
120 | clusters_params[key] = np.array(clusters_params[key])
121 |
122 | N = clusters_params['center'].shape[0]
123 |
124 | stddev = clusters_params['stddev']
125 | rel_num_pixels = 5 * clusters_params['num_pixels'] / np.sqrt(H * W)
126 |
127 | location = clusters_params['center']
128 | num_pixels_per_cluster = clusters_params['num_pixels'].reshape(-1, 1)
129 | s = clusters_params['s']
130 | e = clusters_params['e']
131 | cluster_width = clusters_params['width']
132 |
133 | location[..., 0] *= H
134 | location[..., 1] *= W
135 | s[..., 0] *= H
136 | s[..., 1] *= W
137 | e[..., 0] *= H
138 | e[..., 1] *= W
139 |
140 | s -= location
141 | e -= location
142 |
143 | color = clusters_params['color_rgb']
144 |
145 | c = (s + e) / 2. + np.stack([np.random.uniform(low=-1, high=1, size=[N]),
146 | np.random.uniform(low=-1, high=1, size=[N])],
147 | axis=-1)
148 |
149 | sec_center = (s + e + c) / 3.
150 | s -= sec_center
151 | e -= sec_center
152 | c -= sec_center
153 |
154 | rel_num_pix_quant = np.quantile(rel_num_pixels, q=[0.3, 0.99])
155 | width_quant = np.quantile(cluster_width, q=[0.3, 0.99])
156 | rel_num_pixels = np.clip(rel_num_pixels, rel_num_pix_quant[0], rel_num_pix_quant[1])
157 | cluster_width = np.clip(cluster_width, width_quant[0], width_quant[1])
158 | width = width_scale * rel_num_pixels.reshape(-1, 1) * cluster_width.reshape(-1, 1)
159 | s, e, c = [x * sec_scale for x in [s, e, c]]
160 |
161 | location, s, e, c, width, color = [x.astype(np.float32) for x in [location, s, e, c, width, color]]
162 |
163 | return location, s, e, c, width, color
164 |
165 |
166 | def initialize_brushstrokes(content_img, num_strokes, canvas_height, canvas_width, sec_scale, width_scale, init='sp'):
167 |
168 | if init == 'random':
169 | # Brushstroke colors
170 | color = np.random.rand(num_strokes, 3)
171 |
172 | # Brushstroke widths
173 | width = np.random.rand(num_strokes, 1) * width_scale
174 |
175 | # Brushstroke locations
176 | location = np.stack([np.random.rand(num_strokes) * canvas_height, np.random.rand(num_strokes) * canvas_width], axis=-1)
177 |
178 | # Start point for the Bezier curves
179 | s = np.stack([np.random.uniform(low=-1, high=1, size=num_strokes) * canvas_height,
180 | np.random.uniform(low=-1, high=1, size=num_strokes) * canvas_width], axis=-1)
181 |
182 | # End point for the Bezier curves
183 | e = np.stack([np.random.uniform(low=-1, high=1, size=num_strokes) * canvas_height,
184 | np.random.uniform(low=-1, high=1, size=num_strokes) * canvas_width], axis=-1)
185 |
186 | # Control point for the Bezier curves
187 | c = np.stack([np.random.uniform(low=-1, high=1, size=num_strokes) * canvas_height,
188 | np.random.uniform(low=-1, high=1, size=num_strokes) * canvas_width], axis=-1)
189 |
190 | # Normalize control points
191 | sec_center = (s + e + c) / 3.0
192 | s, e, c = [x - sec_center for x in [s, e, c]]
193 | s, e, c = [x * sec_scale for x in [s, e, c]]
194 | else:
195 | segments = slic(content_img,
196 | n_segments=num_strokes,
197 | min_size_factor=0.02,
198 | max_size_factor=4.,
199 | compactness=2,
200 | sigma=1,
201 | start_label=0)
202 |
203 | location, s, e, c, width, color = clusters_to_strokes(segments,
204 | content_img,
205 | canvas_height,
206 | canvas_width,
207 | sec_scale=sec_scale,
208 | width_scale=width_scale)
209 |
210 | return location, s, e, c, width, color
211 |
--------------------------------------------------------------------------------
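Both initialization paths in `utils.py` normalize the curve points the same way: the centroid of the start, end and control points is subtracted, and the result is multiplied by `sec_scale`. The following is a minimal pure-Python sketch of that step for a single stroke. The function name `center_and_scale` and the sample coordinates are illustrative assumptions, not repository code.

```python
def center_and_scale(s, e, c, sec_scale):
    """Shift s, e, c so their centroid is the origin, then scale them."""
    centroid = [(a + b + d) / 3.0 for a, b, d in zip(s, e, c)]
    return tuple(
        [(value - mean) * sec_scale for value, mean in zip(point, centroid)]
        for point in (s, e, c)
    )


# Hypothetical stroke: start, end and control point as (row, col) offsets.
s, e, c = center_and_scale([0.0, 0.0], [6.0, 0.0], [3.0, 3.0], sec_scale=0.5)
print(s, e, c)

# After centering, the three points sum to zero in every coordinate.
residual = [a + b + d for a, b, d in zip(s, e, c)]
print(residual)
```

Centering keeps the stroke shape expressed relative to `location`, so moving a stroke and reshaping it stay decoupled.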