├── .gitignore ├── LICENSE ├── README.md ├── app.py ├── docs └── img │ ├── down_arrow.png │ ├── left_arrow.png │ ├── streamlit.jpg │ └── title_figure.jpg ├── images ├── content │ ├── elefant.jpg │ ├── golden_gate.jpg │ ├── olive_trees_greece.jpg │ ├── road.jpg │ └── winxp.jpg └── style │ ├── derain_mountains_at_colloiure.jpg │ ├── picasso_self_portrait.jpg │ ├── picasso_weeping_woman.jpg │ ├── van_gogh_red_cabbages_and_onions.jpg │ ├── van_gogh_starry_night.jpg │ ├── van_gogh_trees.jpg │ └── van_gogh_van_gogh_road_cypress.jpg ├── model.py ├── networks.py ├── notebooks ├── BrushstrokeStyleTransfer.ipynb └── BrushstrokeStyleTransferDrawApp.ipynb ├── ops.py ├── requirements.txt └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .temp 3 | pretrained_weights 4 | *.swp 5 | 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 CompVis Heidelberg 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Rethinking Style Transfer: From Pixels to Parameterized Brushstrokes (CVPR 2021) 2 | 3 |
<img src="docs/img/title_figure.jpg">
4 | 5 | ### [Project page](https://compvis.github.io/brushstroke-parameterized-style-transfer/) | [Paper](https://arxiv.org/abs/2103.17185) | [Colab](https://colab.research.google.com/drive/1J9B6_G2DSWmaBWw9Ot80W9t7O6pWu8Kw?usp=sharing) | [Colab for Drawing App](https://colab.research.google.com/drive/1ALNRoZgCj35uJ3Xvs24-QDwwtCb2lm3P?usp=sharing) 6 | 7 | Rethinking Style Transfer: From Pixels to Parameterized Brushstrokes. 8 | [Dmytro Kotovenko*](https://scholar.google.de/citations?user=T_U8yxwAAAAJ&hl=en), [Matthias Wright*](http://www.matthias-wright.com/), [Arthur Heimbrecht](http://www.aheimbrecht.de/), and [Björn Ommer](https://hci.iwr.uni-heidelberg.de/people/bommer).
9 | * denotes equal contribution
10 | 11 | ## Implementations 12 | We provide implementations in [Tensorflow 1](https://github.com/CompVis/brushstroke-parameterized-style-transfer/tree/tensorflow_v1) and [Tensorflow 2](https://github.com/CompVis/brushstroke-parameterized-style-transfer/tree/tensorflow_v2). In order to reproduce the results from the paper, we recommend the [Tensorflow 1](https://github.com/CompVis/brushstroke-parameterized-style-transfer/tree/tensorflow_v1) implementation. 13 | 14 | ## Installation 15 | 1. Clone this repository: 16 | ```sh 17 | > git clone https://github.com/CompVis/brushstroke-parameterized-style-transfer 18 | > cd brushstroke-parameterized-style-transfer 19 | ``` 20 | 2. Install Tensorflow 1.14 (preferably with GPU support). 21 | If you are using [Conda](https://docs.conda.io/en/latest/index.html), this command will create a new environment and install Tensorflow as well as compatible CUDA and cuDNN versions. 22 | ```sh 23 | > conda create --name tf14 tensorflow-gpu==1.14 24 | > conda activate tf14 25 | ``` 26 | 3. Install requirements: 27 | ```sh 28 | > pip install -r requirements.txt 29 | ``` 30 | 31 | ## Basic Usage 32 | ```python 33 | from PIL import Image 34 | import model 35 | 36 | content_img = Image.open('images/content/golden_gate.jpg') 37 | style_img = Image.open('images/style/van_gogh_starry_night.jpg') 38 | 39 | stylized_img = model.stylize(content_img, 40 | style_img, 41 | num_strokes=5000, 42 | num_steps=100, 43 | content_weight=1.0, 44 | style_weight=3.0, 45 | num_steps_pixel=1000) 46 | 47 | stylized_img.save('images/stylized.jpg') 48 | ``` 49 | or open [Colab](https://colab.research.google.com/drive/1J9B6_G2DSWmaBWw9Ot80W9t7O6pWu8Kw?usp=sharing). 50 | 51 | ## Drawing App 52 | We created a [Streamlit](https://streamlit.io/) app where you can draw curves to control the flow of brushstrokes. 53 | 54 |
<img src="docs/img/streamlit.jpg">
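Under the hood, `app.py` saves the drawn curves as normalized point positions and tangent vectors (`.npy` files) and feeds them to `model.BrushstrokeOptimizer`, which adds the drawing projection loss described in Sec. B of the paper. A minimal sketch of the equivalent call outside the app (the `positions.npy`/`vectors.npy` filenames are placeholders for files saved in the format the app produces):

```python
from PIL import Image
from model import BrushstrokeOptimizer

content_img = Image.open('images/content/golden_gate.jpg')
style_img = Image.open('images/style/van_gogh_starry_night.jpg')

stroke_optim = BrushstrokeOptimizer(content_img,
                                    style_img,
                                    draw_curve_position_path='positions.npy',  # points P_i on the drawn curves, normalized to [0, 1]
                                    draw_curve_vector_path='vectors.npy',      # tangent vectors v_i of the drawn curves
                                    draw_strength=100,                         # denoted L in Sec. B of the paper
                                    draw_weight=100.0,                         # weight of the drawing projection loss
                                    num_strokes=5000,
                                    num_steps=100)
canvas = stroke_optim.optimize()  # returns the stylized canvas as a PIL.Image
```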
55 | 56 | #### Run drawing app on your machine 57 | To run the app on your own machine: 58 | ```sh 59 | > CUDA_VISIBLE_DEVICES=0 streamlit run app.py 60 | ``` 61 | 62 | 63 | You can also run the app on a remote server and forward the port to your local machine: 64 | [https://docs.streamlit.io/en/0.66.0/tutorial/run_streamlit_remotely.html](https://docs.streamlit.io/en/0.66.0/tutorial/run_streamlit_remotely.html) 65 | 66 | 67 | #### Run streamlit app from Colab 68 | If you don't have access to GPUs we also created a [Colab](https://colab.research.google.com/drive/1ALNRoZgCj35uJ3Xvs24-QDwwtCb2lm3P?usp=sharing) from which you can start the drawing app. 69 | 70 | ## Other implementations 71 | [PyTorch implementation](https://github.com/justanhduc/brushstroke-parameterized-style-transfer) by [justanhduc](https://github.com/justanhduc). 72 | 73 | ## Citation 74 | ``` 75 | @article{kotovenko_cvpr_2021, 76 | title={Rethinking Style Transfer: From Pixels to Parameterized Brushstrokes}, 77 | author={Dmytro Kotovenko and Matthias Wright and Arthur Heimbrecht and Bj{\"o}rn Ommer}, 78 | journal={CVPR}, 79 | year={2021} 80 | } 81 | ``` 82 | 83 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from typing import Sequence 4 | import streamlit as st 5 | from streamlit_drawable_canvas import st_canvas 6 | from stqdm import stqdm 7 | import os 8 | import base64 9 | 10 | from model import BrushstrokeOptimizer, PixelOptimizer 11 | 12 | 13 | def parse_paths(json_obj, height, width): 14 | xs = [] 15 | ys = [] 16 | for segments in json_obj['path']: 17 | if segments[0] == 'Q': 18 | xs.append(segments[2] / width) 19 | xs.append(segments[4] / width) 20 | ys.append(segments[1] / height) 21 | ys.append(segments[3] / height) 22 | xs = np.array(xs) 23 | ys = np.array(ys) 24 | return np.stack((xs, ys), axis=1) 25 | 26 | 27 | def sample_vectors(points, lookahead=10, freq=10): 28 | if points.shape[0] > 30: 29 | idcs = np.arange(points.shape[0])[::freq] 30 | 31 | idcs = np.arange(points.shape[0]) 32 | vectors = [] 33 | positions = [] 34 | 35 | lookahead = min(lookahead, idcs.shape[0] - 1) 36 | for i in range(idcs.shape[0] - lookahead): 37 | vectors.append(points[idcs[i] + lookahead] - points[idcs[i]]) 38 | positions.append(points[idcs[i]]) 39 | return np.array(vectors), np.array(positions) 40 | 41 | 42 | def get_binary_file_downloader_html(bin_file, file_label='File'): 43 | # Taken from: https://discuss.streamlit.io/t/how-to-download-file-in-streamlit/1806/27 44 | with open(bin_file, 'rb') as f: 45 | data = f.read() 46 | bin_str = base64.b64encode(data).decode() 47 | href = f'Download {file_label}' 48 | return href 49 | 50 | 51 | def resize(img, size, interpolation=Image.BILINEAR): 52 | # https://github.com/pytorch/vision/blob/master/torchvision/transforms/functional.py 53 | if isinstance(size, int) or len(size) == 1: 54 | if isinstance(size, Sequence): 55 | size = size[0] 56 | w, h = img.size 57 | if (w <= h and w == size) or (h <= w and h == size): 58 | return img 59 | if w < h: 60 | ow = size 61 | oh = int(size * h / w) 62 | return img.resize((ow, oh), interpolation) 63 | else: 64 | oh = size 65 | ow = int(size * w / h) 66 | return img.resize((ow, oh), interpolation) 67 | else: 68 | return img.resize(size[::-1], interpolation) 69 | 70 | 71 | st.sidebar.markdown(""" 72 | 83 | """, unsafe_allow_html=True) 84 | 85 | 86 | # Sidebar 87 | ## Content and 
Style image 88 | st.sidebar.markdown('
<h2 align="center">Content Image</h2>
', unsafe_allow_html=True) 89 | selected_content_image = st.sidebar.selectbox('Select content image:', [None] + os.listdir('images/content')) 90 | st.sidebar.text('OR') 91 | uploaded_content_image = st.sidebar.file_uploader('Upload content image:', type=['png', 'jpg']) 92 | 93 | st.sidebar.markdown('
<h2 align="center">Style Image</h2>
', unsafe_allow_html=True) 94 | selected_style_image = st.sidebar.selectbox('Select style image:', [None] + os.listdir('images/style')) 95 | st.sidebar.text('OR') 96 | uploaded_style_image = st.sidebar.file_uploader('Upload style image:', type=['png', 'jpg']) 97 | 98 | ## Parameters 99 | st.sidebar.markdown('
<h2 align="center">Options</h2>
', unsafe_allow_html=True) 100 | num_steps_stroke = st.sidebar.slider('Brushstroke optimization steps:', 20, 100, 100) 101 | num_steps_pixel = st.sidebar.slider('Pixel optimization steps:', 100, 5000, 2000) 102 | num_strokes = st.sidebar.slider('Number of brushstrokes:', 100, 10000, 5000) 103 | content_weight = st.sidebar.slider('Content weight:', 1.0, 50.0, 1.0) 104 | style_weight = st.sidebar.slider('Style weight:', 1.0, 50.0, 3.0) 105 | draw_weight = st.sidebar.slider('Drawing weight', 50.0, 200.0, 100.0) 106 | draw_strength = st.sidebar.slider('Drawing strength (denoted L in the paper):', 50, 200, 100) 107 | stroke_width = st.sidebar.slider('Stroke width:', 0.01, 2.0, 0.1) 108 | stroke_length = st.sidebar.slider('Stroke length:', 0.1, 2.0, 1.1) 109 | 110 | 111 | #drawing_mode = st.sidebar.selectbox( 112 | # 'Drawing tool:', ('freedraw', 'line', 'rect', 'circle', 'transform') 113 | #) 114 | realtime_update = st.sidebar.checkbox('Update in realtime', True) 115 | 116 | # Main 117 | stroke_color = st.color_picker('Stroke color hex: ', '#ff0000') 118 | 119 | 120 | content_img = None 121 | if selected_content_image is not None: content_img = Image.open(os.path.join('images/content', selected_content_image)) 122 | if uploaded_content_image is not None: content_img = Image.open(uploaded_content_image) 123 | 124 | style_img = None 125 | if selected_style_image is not None: style_img = Image.open(os.path.join('images/style', selected_style_image)) 126 | if uploaded_style_image is not None: style_img = Image.open(uploaded_style_image) 127 | 128 | if content_img is None or style_img is None: 129 | st.image(Image.open('docs/img/left_arrow.png')) 130 | st.image(Image.open('docs/img/down_arrow.png')) 131 | #st.markdown('
Select or upload content and style images...
', unsafe_allow_html=True) 132 | 133 | 134 | col1, col2 = st.beta_columns(2) 135 | 136 | # Preview images 137 | if content_img is not None: 138 | content_thumb = resize(content_img, size=400) 139 | col1.header('Content image') 140 | col1.image(content_img, use_column_width=True) 141 | if style_img is not None: 142 | style_thumb = resize(style_img, size=400) 143 | col2.header('Style image') 144 | col2.image(style_thumb, use_column_width=True) 145 | 146 | 147 | if content_img is not None and style_img is not None: 148 | if not os.path.exists('.temp'): 149 | os.makedirs('.temp') 150 | 151 | content_img_name = content_img.filename 152 | content_img = content_img.convert('RGB') 153 | content_img.save(f'.temp/content_img.jpg') 154 | style_img_name = style_img.filename 155 | style_img = style_img.convert('RGB') 156 | style_img.save(f'.temp/style_img.jpg') 157 | 158 | height = content_img.size[1] 159 | width = content_img.size[0] 160 | factor = 1.0 161 | # resize image such that the largest side is 512 because else the canvas drawer messes up 162 | if width > 512 or height > 512: 163 | if width < height: 164 | height = int(512 * (height / width)) 165 | width = 512 166 | factor *= height / width 167 | else: 168 | width = int(512 * (width / height)) 169 | height = 512 170 | factor *= width / height 171 | 172 | 173 | st.text('Now draw some curves on the canvas.') 174 | st.text('To draw a curve:') 175 | st.text('- hold down the left mouse button') 176 | st.text('- and move the mouse over the canvas.') 177 | 178 | # Create a canvas component 179 | canvas_result = st_canvas( 180 | fill_color='rgba(255, 165, 0, 0.3)', # Fixed fill color with some opacity 181 | stroke_width=3, 182 | stroke_color=stroke_color, 183 | background_color='' if content_img else '#eee', 184 | background_image=content_img, 185 | update_streamlit=realtime_update, 186 | height=height, 187 | width=width, 188 | drawing_mode='freedraw', 189 | #key='canvas', 190 | ) 191 | 192 | if canvas_result.json_data is not None: 193 | if len(canvas_result.json_data['objects']) > 0: 194 | if st.button('Stylize'): 195 | vectors_all = [] 196 | positions_all = [] 197 | img_array = np.array(content_img) 198 | for i in range(len(canvas_result.json_data['objects'])): 199 | points = parse_paths(canvas_result.json_data['objects'][i], float(height), float(width)) 200 | 201 | if points.shape[0] == 0: 202 | continue 203 | 204 | vectors, positions = sample_vectors(points, lookahead=5, freq=5) 205 | 206 | if vectors.ndim < 2 or positions.ndim < 2: 207 | continue 208 | 209 | vectors_all.append(vectors) 210 | positions_all.append(positions) 211 | 212 | for i in range(points.shape[0]): 213 | y = int(points[i, 0] * content_img.size[0]) 214 | x = int(points[i, 1] * content_img.size[1]) 215 | img_array[y-2:y+2, x-2:x+2] = np.array([255, 0, 0]) 216 | 217 | vectors_all = np.concatenate(vectors_all, axis=0).astype(np.float32) 218 | positions_all = np.concatenate(positions_all, axis=0).astype(np.float32) 219 | np.save('.temp/vectors', vectors_all) 220 | np.save('.temp/positions', positions_all) 221 | 222 | content_img = Image.open('.temp/content_img.jpg') 223 | style_img = Image.open('.temp/style_img.jpg') 224 | 225 | st.text('Brushstroke optimization...') 226 | pbar = stqdm(range(num_steps_stroke)) 227 | stroke_optim = BrushstrokeOptimizer(content_img, 228 | style_img, 229 | draw_curve_position_path='.temp/positions.npy', 230 | draw_curve_vector_path='.temp/vectors.npy', 231 | draw_strength=draw_strength, 232 | resolution=512, 233 | num_strokes=num_strokes, 234 | 
num_steps=num_steps_stroke, 235 | width_scale=stroke_width, 236 | length_scale=stroke_length, 237 | content_weight=content_weight, 238 | style_weight=style_weight, 239 | draw_weight=draw_weight, 240 | streamlit_pbar=pbar) 241 | canvas = stroke_optim.optimize() 242 | 243 | st.text('Pixel optimization...') 244 | pbar = stqdm(range(num_steps_pixel)) 245 | pixel_optim = PixelOptimizer(canvas, 246 | style_img, 247 | resolution=1024, 248 | num_steps=num_steps_pixel, 249 | content_weight=1.0, 250 | style_weight=10000.0, 251 | streamlit_pbar=pbar) 252 | canvas = pixel_optim.optimize() 253 | 254 | st.text('Stylized image:') 255 | st.image(canvas.resize((width, height))) 256 | 257 | canvas.save('.temp/canvas.jpg') 258 | st.markdown(get_binary_file_downloader_html('.temp/canvas.jpg', 'stylized image in high resolution'), unsafe_allow_html=True) 259 | 260 | 261 | -------------------------------------------------------------------------------- /docs/img/down_arrow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/brushstroke-parameterized-style-transfer/72f3cabe7092e754fe0e9072741b399aee9e6dff/docs/img/down_arrow.png -------------------------------------------------------------------------------- /docs/img/left_arrow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/brushstroke-parameterized-style-transfer/72f3cabe7092e754fe0e9072741b399aee9e6dff/docs/img/left_arrow.png -------------------------------------------------------------------------------- /docs/img/streamlit.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/brushstroke-parameterized-style-transfer/72f3cabe7092e754fe0e9072741b399aee9e6dff/docs/img/streamlit.jpg -------------------------------------------------------------------------------- /docs/img/title_figure.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/brushstroke-parameterized-style-transfer/72f3cabe7092e754fe0e9072741b399aee9e6dff/docs/img/title_figure.jpg -------------------------------------------------------------------------------- /images/content/elefant.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/brushstroke-parameterized-style-transfer/72f3cabe7092e754fe0e9072741b399aee9e6dff/images/content/elefant.jpg -------------------------------------------------------------------------------- /images/content/golden_gate.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/brushstroke-parameterized-style-transfer/72f3cabe7092e754fe0e9072741b399aee9e6dff/images/content/golden_gate.jpg -------------------------------------------------------------------------------- /images/content/olive_trees_greece.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/brushstroke-parameterized-style-transfer/72f3cabe7092e754fe0e9072741b399aee9e6dff/images/content/olive_trees_greece.jpg -------------------------------------------------------------------------------- /images/content/road.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CompVis/brushstroke-parameterized-style-transfer/72f3cabe7092e754fe0e9072741b399aee9e6dff/images/content/road.jpg -------------------------------------------------------------------------------- /images/content/winxp.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/brushstroke-parameterized-style-transfer/72f3cabe7092e754fe0e9072741b399aee9e6dff/images/content/winxp.jpg -------------------------------------------------------------------------------- /images/style/derain_mountains_at_colloiure.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/brushstroke-parameterized-style-transfer/72f3cabe7092e754fe0e9072741b399aee9e6dff/images/style/derain_mountains_at_colloiure.jpg -------------------------------------------------------------------------------- /images/style/picasso_self_portrait.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/brushstroke-parameterized-style-transfer/72f3cabe7092e754fe0e9072741b399aee9e6dff/images/style/picasso_self_portrait.jpg -------------------------------------------------------------------------------- /images/style/picasso_weeping_woman.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/brushstroke-parameterized-style-transfer/72f3cabe7092e754fe0e9072741b399aee9e6dff/images/style/picasso_weeping_woman.jpg -------------------------------------------------------------------------------- /images/style/van_gogh_red_cabbages_and_onions.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/brushstroke-parameterized-style-transfer/72f3cabe7092e754fe0e9072741b399aee9e6dff/images/style/van_gogh_red_cabbages_and_onions.jpg -------------------------------------------------------------------------------- /images/style/van_gogh_starry_night.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/brushstroke-parameterized-style-transfer/72f3cabe7092e754fe0e9072741b399aee9e6dff/images/style/van_gogh_starry_night.jpg -------------------------------------------------------------------------------- /images/style/van_gogh_trees.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/brushstroke-parameterized-style-transfer/72f3cabe7092e754fe0e9072741b399aee9e6dff/images/style/van_gogh_trees.jpg -------------------------------------------------------------------------------- /images/style/van_gogh_van_gogh_road_cypress.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/brushstroke-parameterized-style-transfer/72f3cabe7092e754fe0e9072741b399aee9e6dff/images/style/van_gogh_van_gogh_road_cypress.jpg -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | with warnings.catch_warnings(): 3 | warnings.filterwarnings('ignore', category=RuntimeWarning) 4 | warnings.filterwarnings('ignore', category=FutureWarning) 5 | import tensorflow as tf 6 | from tensorflow.core.protobuf import config_pb2 7 | 8 | import os 9 | import 
numpy as np 10 | from PIL import Image 11 | from tqdm import trange 12 | 13 | import networks 14 | import ops 15 | import utils 16 | 17 | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) 18 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 19 | 20 | 21 | def stylize(content_img, 22 | style_img, 23 | # Brushstroke optimizer params 24 | resolution=512, 25 | num_strokes=5000, 26 | num_steps=100, 27 | S=10, 28 | K=20, 29 | canvas_color='gray', 30 | width_scale=0.1, 31 | length_scale=1.1, 32 | content_weight=1.0, 33 | style_weight=3.0, 34 | tv_weight=0.008, 35 | curviture_weight=4.0, 36 | # Pixel optimizer params 37 | pixel_resolution=1024, 38 | num_steps_pixel=2000 39 | ): 40 | 41 | stroke_optim = BrushstrokeOptimizer(content_img, 42 | style_img, 43 | resolution=resolution, 44 | num_strokes=num_strokes, 45 | num_steps=num_steps, 46 | S=S, 47 | K=K, 48 | canvas_color=canvas_color, 49 | width_scale=width_scale, 50 | length_scale=length_scale, 51 | content_weight=content_weight, 52 | style_weight=style_weight, 53 | tv_weight=tv_weight, 54 | curviture_weight=curviture_weight) 55 | print('Stroke optimization:') 56 | canvas = stroke_optim.optimize() 57 | 58 | pixel_optim = PixelOptimizer(canvas, 59 | style_img, 60 | resolution=pixel_resolution, 61 | num_steps=num_steps_pixel, 62 | content_weight=1.0, 63 | style_weight=10000.0) 64 | 65 | print('Pixel optimization:') 66 | canvas = pixel_optim.optimize() 67 | return canvas 68 | 69 | 70 | class BrushstrokeOptimizer: 71 | 72 | def __init__(self, 73 | content_img, # Content image (PIL.Image). 74 | style_img, # Style image (PIL.Image). 75 | draw_curve_position_path = None, # Set of points that represent the drawn curves, denoted as P_i in Sec. B of the paper (str). 76 | draw_curve_vector_path = None, # Set of tangent vectors for the points of the drawn curves, denoted as v_i in Sec. B of the paper (str). 77 | draw_strength = 100, # Strength of the influence of the drawn curves, denoted L in Sec. B of the paper (int). 78 | resolution = 512, # Resolution of the canvas (int). 79 | num_strokes = 5000, # Number of brushstrokes (int). 80 | num_steps = 100, # Number of optimization steps (int). 81 | S = 10, # Number of points to sample on each curve, see Sec. 4.2.1 of the paper (int). 82 | K = 20, # Number of brushstrokes to consider for each pixel, see Sec. C.2 of the paper (int). 83 | canvas_color = 'gray', # Color of the canvas (str). 84 | width_scale = 0.1, # Scale parameter for the brushstroke width (float). 85 | length_scale = 1.1, # Scale parameter for the brushstroke length (float). 86 | content_weight = 1.0, # Weight for the content loss (float). 87 | style_weight = 3.0, # Weight for the style loss (float). 88 | tv_weight = 0.008, # Weight for the total variation loss (float). 89 | draw_weight = 100.0, # Weight for the drawing projection loss (float) 90 | curviture_weight = 4.0, # Weight for the curviture loss (float). 91 | streamlit_pbar = None, # Progressbar for streamlit app (obj). 92 | dtype = 'float32' # Data type (str). 
93 | ): 94 | 95 | self.draw_strength = draw_strength 96 | self.draw_weight = draw_weight 97 | self.resolution = resolution 98 | self.num_strokes = num_strokes 99 | self.num_steps = num_steps 100 | self.S = S 101 | self.K = K 102 | self.canvas_color = canvas_color 103 | self.width_scale = width_scale 104 | self.length_scale = length_scale 105 | self.content_weight = content_weight 106 | self.style_weight = style_weight 107 | self.tv_weight = tv_weight 108 | self.curviture_weight = curviture_weight 109 | self.streamlit_pbar = streamlit_pbar 110 | self.dtype = dtype 111 | 112 | # Set canvas size (set smaller side of content image to 'resolution' and scale other side accordingly) 113 | W, H = content_img.size 114 | if H < W: 115 | new_H = resolution 116 | new_W = int((W / H) * new_H) 117 | else: 118 | new_W = resolution 119 | new_H = int((H / W) * new_W) 120 | 121 | self.canvas_height = new_H 122 | self.canvas_width = new_W 123 | 124 | content_img = content_img.resize((self.canvas_width, self.canvas_height)) 125 | style_img = style_img.resize((self.canvas_width, self.canvas_height)) 126 | 127 | content_img = np.array(content_img).astype(self.dtype) 128 | style_img = np.array(style_img).astype(self.dtype) 129 | 130 | content_img /= 255.0 131 | style_img /= 255.0 132 | 133 | self.content_img_np = content_img 134 | self.style_img_np = style_img 135 | 136 | if draw_curve_position_path is not None and draw_curve_vector_path is not None: 137 | self.draw_curve_position_np = np.load(draw_curve_position_path) 138 | self.draw_curve_vector_np = np.load(draw_curve_vector_path) 139 | self.draw_curve_position_np[..., 0] *= self.canvas_width 140 | self.draw_curve_position_np[..., 1] *= self.canvas_height 141 | 142 | ckpt_path = utils.download_weights(url='https://www.dropbox.com/s/hv7b4eajrj7isyq/vgg_weights.pickle?dl=1', 143 | name='vgg_weights.pickle') 144 | self.vgg = networks.VGG(ckpt_path=ckpt_path) 145 | 146 | def optimize(self): 147 | self._initialize() 148 | self._render() 149 | self._losses() 150 | self._optimizer() 151 | 152 | 153 | with tf.Session() as sess: 154 | sess.run(tf.global_variables_initializer()) 155 | steps = trange(self.num_steps, desc='', leave=True) 156 | for step in steps: 157 | 158 | I_, loss_dict_, params_dict_, _ = \ 159 | sess.run(fetches=[self.I, 160 | self.loss_dict, 161 | self.params_dict, 162 | self.optim_step_with_constraints], 163 | options=config_pb2.RunOptions(report_tensor_allocations_upon_oom=True) 164 | ) 165 | 166 | steps.set_description(f'content_loss: {loss_dict_["content"]:.6f}, style_loss: {loss_dict_["style"]:.6f}') 167 | #s = '' 168 | #for key in loss_dict_: 169 | # loss = loss_dict_[key] 170 | # s += key + f': {loss_dict_[key]:.4f}, ' 171 | #steps.set_description(s[:-2]) 172 | #print(s) 173 | 174 | steps.refresh() 175 | if self.streamlit_pbar is not None: self.streamlit_pbar.update(1) 176 | return Image.fromarray(np.array(np.clip(I_, 0, 1) * 255, dtype=np.uint8)) 177 | 178 | def _initialize(self): 179 | location, s, e, c, width, color = utils.initialize_brushstrokes(self.content_img_np, 180 | self.num_strokes, 181 | self.canvas_height, 182 | self.canvas_width, 183 | self.length_scale, 184 | self.width_scale) 185 | 186 | self.curve_s = tf.Variable(name='curve_s', initial_value=s, dtype=self.dtype) 187 | self.curve_e = tf.Variable(name='curve_e', initial_value=e, dtype=self.dtype) 188 | self.curve_c = tf.Variable(name='curve_c', initial_value=c, dtype=self.dtype) 189 | self.color = tf.Variable(name='color', initial_value=color, dtype=self.dtype) 190 | 
self.location = tf.Variable(name='location', initial_value=location, dtype=self.dtype) 191 | self.width = tf.Variable(name='width', initial_value=width, dtype=self.dtype) 192 | self.content_img = tf.constant(name='content_img', value=self.content_img_np, dtype=self.dtype) 193 | self.style_img = tf.constant(name='style_img', value=self.style_img_np, dtype=self.dtype) 194 | 195 | if hasattr(self, 'draw_curve_position_np') and hasattr(self, 'draw_curve_vector_np'): 196 | self.draw_curve_position = tf.constant(name='draw_curve_position', value=self.draw_curve_position_np, dtype=self.dtype) 197 | self.draw_curve_vector = tf.constant(name='draw_curve_vector', value=self.draw_curve_vector_np, dtype=self.dtype) 198 | 199 | self.params_dict = {'location': self.location, 200 | 'curve_s': self.curve_s, 201 | 'curve_e': self.curve_e, 202 | 'curve_c': self.curve_c, 203 | 'width': self.width, 204 | 'color': self.color} 205 | 206 | def _render(self): 207 | curve_points = ops.sample_quadratic_bezier_curve(s=self.curve_s + self.location, 208 | e=self.curve_e + self.location, 209 | c=self.curve_c + self.location, 210 | num_points=self.S, 211 | dtype=self.dtype) 212 | 213 | self.I = ops.renderer(curve_points, 214 | self.location, 215 | self.color, 216 | self.width, 217 | self.canvas_height, 218 | self.canvas_width, 219 | self.K, 220 | canvas_color=self.canvas_color, 221 | dtype=self.dtype) 222 | 223 | def _losses(self): 224 | # resize images to save memory 225 | rendered_canvas_resized = \ 226 | tf.image.resize_nearest_neighbor(images=ops.preprocess_img(self.I), 227 | size=(int(self.canvas_height // 2), int(self.canvas_width // 2))) 228 | 229 | content_img_resized = \ 230 | tf.image.resize_nearest_neighbor(images=ops.preprocess_img(self.content_img), 231 | size=(int(self.canvas_height // 2), int(self.canvas_width // 2))) 232 | 233 | style_img_resized = \ 234 | tf.image.resize_nearest_neighbor(images=ops.preprocess_img(self.style_img), 235 | size=(int(self.canvas_height // 2), int(self.canvas_width // 2))) 236 | 237 | self.loss_dict = {} 238 | self.loss_dict['content'] = ops.content_loss(self.vgg.extract_features(rendered_canvas_resized), 239 | self.vgg.extract_features(content_img_resized), 240 | #layers=['conv1_2', 'conv2_2', 'conv3_2', 'conv4_2', 'conv5_2'], 241 | layers=['conv4_2', 'conv5_2'], 242 | weights=[1, 1], 243 | scale_by_y=True) 244 | self.loss_dict['content'] *= self.content_weight 245 | 246 | self.loss_dict['style'] = ops.style_loss(self.vgg.extract_features(rendered_canvas_resized), 247 | self.vgg.extract_features(style_img_resized), 248 | layers=['conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv5_1'], 249 | weights=[1, 1, 1, 1, 1]) 250 | self.loss_dict['style'] *= self.style_weight 251 | 252 | self.loss_dict['curviture'] = ops.curviture_loss(self.curve_s, self.curve_e, self.curve_c) 253 | self.loss_dict['curviture'] *= self.curviture_weight 254 | 255 | self.loss_dict['tv'] = ops.total_variation_loss(x_loc=self.location, s=self.curve_s, e=self.curve_e, K=10) 256 | self.loss_dict['tv'] *= self.tv_weight 257 | 258 | if hasattr(self, 'draw_curve_position') and hasattr(self, 'draw_curve_vector'): 259 | self.loss_dict['drawing'] = ops.draw_projection_loss(self.location, 260 | self.curve_s, 261 | self.curve_e, 262 | self.draw_curve_position, 263 | self.draw_curve_vector, 264 | self.draw_strength) 265 | self.loss_dict['drawing'] *= self.draw_weight 266 | 267 | 268 | def _optimizer(self): 269 | loss = tf.constant(0.0) 270 | for key in self.loss_dict: 271 | loss += self.loss_dict[key] 272 | 273 | 
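# Note: two separate Adam optimizers are used below -- a faster one (lr 0.1)
# minimizes the total loss w.r.t. the stroke geometry (location, curve control
# points, width), while a slower one (lr 0.01) minimizes only the style loss
# w.r.t. color. The assign ops afterwards project the parameters back into
# their valid ranges after each step.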
step_ops = [] 274 | optim_step = tf.train.AdamOptimizer(0.1).minimize( 275 | loss=loss, 276 | var_list=[self.location, self.curve_s, self.curve_e, self.curve_c, self.width]) 277 | step_ops.append(optim_step) 278 | optim_step_color = tf.train.AdamOptimizer(0.01).minimize( 279 | loss=self.loss_dict['style'], 280 | var_list=self.color) 281 | step_ops.append(optim_step_color) 282 | 283 | # constraint parameters to certain range 284 | with tf.control_dependencies(step_ops.copy()): 285 | step_ops.append(tf.assign(self.color, tf.clip_by_value(self.color, 0, 1))) 286 | coord_x, coord_y = tf.gather(self.location, axis=-1, indices=[0]), tf.gather(self.location, axis=-1, indices=[1]) 287 | coord_clip = tf.concat([tf.clip_by_value(coord_x, 0, self.canvas_height), tf.clip_by_value(coord_y, 0, self.canvas_width)], axis=-1) 288 | step_ops.append(tf.assign(self.location, coord_clip)) 289 | step_ops.append(tf.assign(self.width, tf.nn.relu(self.width))) 290 | self.optim_step_with_constraints = tf.group(*step_ops) 291 | 292 | 293 | class PixelOptimizer: 294 | 295 | def __init__(self, 296 | canvas, # Canvas (PIL.Image). 297 | style_img, # Style image (PIL.Image). 298 | resolution = 1024, # Resolution of the canvas. 299 | num_steps = 2000, # Number of optimization steps. 300 | content_weight = 1.0, # Weight for the content loss. 301 | style_weight = 10000.0, # Weight for the style loss. 302 | tv_weight = 0.0, # Weight for the total variation loss. 303 | streamlit_pbar = None, # Progressbar for streamlit app (obj). 304 | dtype = 'float32' # Data type. 305 | ): 306 | 307 | self.resolution = resolution 308 | self.num_steps = num_steps 309 | self.content_weight = content_weight 310 | self.style_weight = style_weight 311 | self.tv_weight = tv_weight 312 | self.streamlit_pbar = streamlit_pbar 313 | self.dtype = dtype 314 | 315 | # Set canvas size (set smaller side of content image to 'resolution' and scale other side accordingly) 316 | W, H = canvas.size 317 | if H < W: 318 | new_H = resolution 319 | new_W = int((W / H) * new_H) 320 | else: 321 | new_W = resolution 322 | new_H = int((H / W) * new_W) 323 | 324 | self.canvas_height = new_H 325 | self.canvas_width = new_W 326 | 327 | canvas = canvas.resize((self.canvas_width, self.canvas_height)) 328 | style_img = style_img.resize((self.canvas_width, self.canvas_height)) 329 | 330 | canvas = np.array(canvas).astype(self.dtype) 331 | style_img = np.array(style_img).astype(self.dtype) 332 | 333 | canvas /= 255.0 334 | style_img /= 255.0 335 | 336 | self.canvas_np = canvas 337 | self.content_img_np = canvas 338 | self.style_img_np = style_img 339 | 340 | ckpt_path = utils.download_weights(url='https://www.dropbox.com/s/hv7b4eajrj7isyq/vgg_weights.pickle?dl=1', 341 | name='vgg_weights.pickle') 342 | self.vgg = networks.VGG(ckpt_path=ckpt_path) 343 | 344 | def optimize(self): 345 | self._initialize() 346 | self._losses() 347 | self._optimizer() 348 | 349 | with tf.Session() as sess: 350 | sess.run(tf.global_variables_initializer()) 351 | steps = trange(self.num_steps, desc='', leave=True) 352 | for step in steps: 353 | canvas_, loss_dict_, _ = \ 354 | sess.run(fetches=[self.canvas, 355 | self.loss_dict, 356 | self.optim_step_with_constraints], 357 | options=config_pb2.RunOptions(report_tensor_allocations_upon_oom=True) 358 | ) 359 | 360 | s = '' 361 | for key in loss_dict_: 362 | loss = loss_dict_[key] 363 | s += key + f': {loss_dict_[key]:.6f}, ' 364 | 365 | steps.set_description(s[:-2]) 366 | steps.refresh() 367 | if self.streamlit_pbar is not None: 
self.streamlit_pbar.update(1) 368 | return Image.fromarray(np.array(np.clip(canvas_, 0, 1) * 255, dtype=np.uint8)) 369 | 370 | def _initialize(self): 371 | self.canvas = tf.Variable(name='canvas', initial_value=self.canvas_np, dtype=self.dtype) 372 | self.content_img = tf.constant(name='content_img', value=self.content_img_np, dtype=self.dtype) 373 | self.style_img = tf.constant(name='style_img', value=self.style_img_np, dtype=self.dtype) 374 | 375 | def _losses(self): 376 | # resize images to save memory 377 | rendered_canvas_resized = \ 378 | tf.image.resize_nearest_neighbor(images=ops.preprocess_img(self.canvas), 379 | size=(int(self.canvas_height), int(self.canvas_width))) 380 | 381 | content_img_resized = \ 382 | tf.image.resize_nearest_neighbor(images=ops.preprocess_img(self.content_img), 383 | size=(int(self.canvas_height), int(self.canvas_width))) 384 | 385 | style_img_resized = \ 386 | tf.image.resize_nearest_neighbor(images=ops.preprocess_img(self.style_img), 387 | size=(int(self.canvas_height), int(self.canvas_width))) 388 | 389 | self.loss_dict = {} 390 | self.loss_dict['content'] = ops.content_loss(self.vgg.extract_features(rendered_canvas_resized), 391 | self.vgg.extract_features(content_img_resized), 392 | layers=['conv1_2_pool', 'conv2_2_pool', 'conv3_3_pool', 'conv4_3_pool', 'conv5_3_pool'], 393 | weights=[1, 1, 1, 1, 1]) 394 | self.loss_dict['content'] *= self.content_weight 395 | 396 | self.loss_dict['style'] = ops.style_loss(self.vgg.extract_features(rendered_canvas_resized), 397 | self.vgg.extract_features(style_img_resized), 398 | layers=['conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv5_1'], 399 | weights=[1, 1, 1, 1, 1]) 400 | self.loss_dict['style'] *= self.style_weight 401 | 402 | self.loss_dict['tv'] = ((tf.nn.l2_loss(self.canvas[1:, :, :] - self.canvas[:-1, :, :]) / self.canvas.shape.as_list()[0]) + 403 | (tf.nn.l2_loss(self.canvas[:, 1:, :] - self.canvas[:, :-1, :]) / self.canvas.shape.as_list()[1])) 404 | self.loss_dict['tv'] *= self.tv_weight 405 | 406 | def _optimizer(self): 407 | loss = tf.constant(0.0) 408 | for key in self.loss_dict: 409 | loss += self.loss_dict[key] 410 | 411 | step_ops = [] 412 | optim_step = tf.train.AdamOptimizer(0.01).minimize(loss=loss, var_list=self.canvas) 413 | step_ops.append(optim_step) 414 | 415 | # constraint parameters to certain range 416 | with tf.control_dependencies(step_ops.copy()): 417 | step_ops.append(tf.assign(self.canvas, tf.clip_by_value(self.canvas, 0, 1))) 418 | 419 | self.optim_step_with_constraints = tf.group(*step_ops) 420 | 421 | 422 | 423 | -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | with warnings.catch_warnings(): 3 | warnings.filterwarnings("ignore", category=FutureWarning) 4 | import tensorflow as tf 5 | import pickle 6 | 7 | 8 | class VGG: 9 | 10 | def __init__(self, ckpt_path): 11 | self.param_dict = pickle.load(open(ckpt_path, 'rb')) 12 | 13 | def extract_features(self, x): 14 | features = {} 15 | x = self._conv2d_block(x, self.param_dict['block1']['conv1']['weight'], self.param_dict['block1']['conv1']['bias']) 16 | features['conv1_1'] = x 17 | x = self._conv2d_block(x, self.param_dict['block1']['conv2']['weight'], self.param_dict['block1']['conv2']['bias']) 18 | features['conv1_2'] = x 19 | x = tf.nn.max_pool2d(x, ksize=2, strides=2, padding='VALID') 20 | features['conv1_2_pool'] = x 21 | 22 | x = self._conv2d_block(x, 
self.param_dict['block2']['conv1']['weight'], self.param_dict['block2']['conv1']['bias']) 23 | features['conv2_1'] = x 24 | x = self._conv2d_block(x, self.param_dict['block2']['conv2']['weight'], self.param_dict['block2']['conv2']['bias']) 25 | features['conv2_2'] = x 26 | x = tf.nn.max_pool2d(x, ksize=2, strides=2, padding='VALID') 27 | features['conv2_2_pool'] = x 28 | 29 | x = self._conv2d_block(x, self.param_dict['block3']['conv1']['weight'], self.param_dict['block3']['conv1']['bias']) 30 | features['conv3_1'] = x 31 | x = self._conv2d_block(x, self.param_dict['block3']['conv2']['weight'], self.param_dict['block3']['conv2']['bias']) 32 | features['conv3_2'] = x 33 | x = self._conv2d_block(x, self.param_dict['block3']['conv3']['weight'], self.param_dict['block3']['conv3']['bias']) 34 | features['conv3_3'] = x 35 | x = tf.nn.max_pool2d(x, ksize=2, strides=2, padding='VALID') 36 | features['conv3_3_pool'] = x 37 | 38 | x = self._conv2d_block(x, self.param_dict['block4']['conv1']['weight'], self.param_dict['block4']['conv1']['bias']) 39 | features['conv4_1'] = x 40 | x = self._conv2d_block(x, self.param_dict['block4']['conv2']['weight'], self.param_dict['block4']['conv2']['bias']) 41 | features['conv4_2'] = x 42 | x = self._conv2d_block(x, self.param_dict['block4']['conv3']['weight'], self.param_dict['block4']['conv3']['bias']) 43 | features['conv4_3'] = x 44 | x = tf.nn.max_pool2d(x, ksize=2, strides=2, padding='VALID') 45 | features['conv4_3_pool'] = x 46 | 47 | x = self._conv2d_block(x, self.param_dict['block5']['conv1']['weight'], self.param_dict['block5']['conv1']['bias']) 48 | features['conv5_1'] = x 49 | x = self._conv2d_block(x, self.param_dict['block5']['conv2']['weight'], self.param_dict['block5']['conv2']['bias']) 50 | features['conv5_2'] = x 51 | x = self._conv2d_block(x, self.param_dict['block5']['conv3']['weight'], self.param_dict['block5']['conv3']['bias']) 52 | features['conv5_3'] = x 53 | x = tf.nn.max_pool2d(x, ksize=2, strides=2, padding='VALID') 54 | features['conv5_3_pool'] = x 55 | return features 56 | 57 | def _conv2d_block(self, x, kernel, bias): 58 | x = tf.nn.conv2d(x, filters=kernel, strides=[1, 1], padding=[[0, 0], [1, 1], [1, 1], [0, 0]]) 59 | x = tf.nn.bias_add(x, bias=bias, data_format='N...C') 60 | x = tf.nn.relu(x) 61 | return x 62 | 63 | -------------------------------------------------------------------------------- /notebooks/BrushstrokeStyleTransfer.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "accelerator": "GPU", 6 | "colab": { 7 | "name": "BrushstrokeStyleTransfer.ipynb", 8 | "provenance": [], 9 | "collapsed_sections": [] 10 | }, 11 | "kernelspec": { 12 | "display_name": "Python 3", 13 | "name": "python3" 14 | }, 15 | "language_info": { 16 | "name": "python" 17 | } 18 | }, 19 | "cells": [ 20 | { 21 | "cell_type": "code", 22 | "metadata": { 23 | "id": "N80TrV4juOLV" 24 | }, 25 | "source": [ 26 | "!pip uninstall -y albumentations\n", 27 | "!pip install -U scikit-learn\n", 28 | "!pip install scikit-image==0.17.2" 29 | ], 30 | "execution_count": null, 31 | "outputs": [] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": { 36 | "id": "WfavNAsehO_-" 37 | }, 38 | "source": [ 39 | "
Ignore the error messages...
" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "metadata": { 45 | "id": "d_gmUTe69NK2" 46 | }, 47 | "source": [ 48 | "!git clone https://github.com/CompVis/brushstroke-parameterized-style-transfer\n", 49 | "%cd brushstroke-parameterized-style-transfer" 50 | ], 51 | "execution_count": null, 52 | "outputs": [] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "metadata": { 57 | "id": "Bg7ppjI48CuI" 58 | }, 59 | "source": [ 60 | "import os\n", 61 | "from google.colab import files\n", 62 | "import matplotlib.pyplot as plt\n", 63 | "import numpy as np\n", 64 | "from PIL import Image\n", 65 | "\n", 66 | "%tensorflow_version 1.x\n", 67 | "import tensorflow as tf\n", 68 | "tf.test.is_gpu_available()\n", 69 | "tf.__version__" 70 | ], 71 | "execution_count": null, 72 | "outputs": [] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": { 77 | "id": "eHSa8q_ziWuc" 78 | }, 79 | "source": [ 80 | "
### Upload your own images
" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "metadata": { 86 | "id": "EcaQJvqfieUw" 87 | }, 88 | "source": [ 89 | "uploaded = files.upload()\n", 90 | "\n", 91 | "for fn in uploaded.keys():\n", 92 | " print('User uploaded file \"{name}\" with length {length} bytes'.format(\n", 93 | " name=fn, length=len(uploaded[fn])))" 94 | ], 95 | "execution_count": null, 96 | "outputs": [] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": { 101 | "id": "zlKwIBl8jeFc" 102 | }, 103 | "source": [ 104 | "
### Or choose from the available images
\n" 105 | ] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "metadata": { 110 | "id": "ZY5Xnu5_oCor" 111 | }, 112 | "source": [ 113 | "
### Content Images
" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "metadata": { 119 | "id": "KbrqOU0XjLgQ" 120 | }, 121 | "source": [ 122 | "content_images = os.listdir('images/content')\n", 123 | "print(content_images)\n", 124 | "\n", 125 | "fig, ax = plt.subplots(nrows=1, ncols=len(content_images), figsize=(30, 5))\n", 126 | "for i in range(len(content_images)):\n", 127 | " img = Image.open(os.path.join('images/content/', content_images[i]))\n", 128 | " ax[i].imshow(np.array(img))\n", 129 | " ax[i].axis('off')\n", 130 | " ax[i].title.set_text(content_images[i])" 131 | ], 132 | "execution_count": null, 133 | "outputs": [] 134 | }, 135 | { 136 | "cell_type": "markdown", 137 | "metadata": { 138 | "id": "UxzLXLvtpkVJ" 139 | }, 140 | "source": [ 141 | "
### Style Images
\n" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "metadata": { 147 | "id": "SMeF9GzxoTIA" 148 | }, 149 | "source": [ 150 | "style_images = os.listdir('images/style')\n", 151 | "print(style_images)\n", 152 | "\n", 153 | "fig, ax = plt.subplots(nrows=1, ncols=len(style_images), figsize=(30, 5))\n", 154 | "for i in range(len(style_images)):\n", 155 | " img = Image.open(os.path.join('images/style/', style_images[i]))\n", 156 | " ax[i].imshow(np.array(img))\n", 157 | " ax[i].axis('off')\n", 158 | " ax[i].title.set_text(style_images[i])" 159 | ], 160 | "execution_count": null, 161 | "outputs": [] 162 | }, 163 | { 164 | "cell_type": "markdown", 165 | "metadata": { 166 | "id": "0AJJixL6rd5R" 167 | }, 168 | "source": [ 169 | "
### Load content and style images
" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "metadata": { 175 | "id": "VRTmFFtjpq_W" 176 | }, 177 | "source": [ 178 | "# Note: for uploaded images, the prefix 'images/{content, style}' needs to be removed.\n", 179 | "# If you upload an image named 'uploaded_image.jpg', the path to use is also 'uploaded_image.jpg'.\n", 180 | "content_img = Image.open('images/content/golden_gate.jpg')\n", 181 | "style_img = Image.open('images/style/van_gogh_starry_night.jpg')\n", 182 | "\n", 183 | "fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))\n", 184 | "ax[0].imshow(np.array(content_img))\n", 185 | "ax[0].axis('off')\n", 186 | "ax[0].title.set_text('Content image')\n", 187 | "\n", 188 | "ax[1].imshow(np.array(style_img))\n", 189 | "ax[1].axis('off')\n", 190 | "ax[1].title.set_text('Style image')" 191 | ], 192 | "execution_count": null, 193 | "outputs": [] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "metadata": { 198 | "id": "O5-H46Y0q_99" 199 | }, 200 | "source": [ 201 | "import model\n", 202 | "\n", 203 | "\n", 204 | "stylized_img = model.stylize(content_img, # Content image (PIL.Image).\n", 205 | " style_img, # Style image (PIL.Image).\n", 206 | " num_strokes=5000, # Number of brushstrokes (int). \n", 207 | " num_steps=100, # Number of stroke optimization steps (int).\n", 208 | " canvas_color='gray', # Color of the canvas (str). Options: 'gray', 'white', 'black', 'noise'\n", 209 | " width_scale=0.1, # Scale parameter for the brushstroke width (float).\n", 210 | " length_scale=1.1, # Scale parameter for the brushstroke length (float).\n", 211 | " content_weight=1.0, # Weight for the content loss (float).\n", 212 | " style_weight=3.0, # Weight for the style loss (float).\n", 213 | " tv_weight=0.008, # Weight for the total variation loss (float). \n", 214 | " pixel_resolution=1024, # Resolution of the canvas for pixel optimization (int).\n", 215 | " num_steps_pixel=2000) # Number of pixel optimization steps (int).\n", 216 | "\n", 217 | "stylized_img.save('stylized.jpg')\n", 218 | "display(stylized_img)" 219 | ], 220 | "execution_count": null, 221 | "outputs": [] 222 | } 223 | ] 224 | } -------------------------------------------------------------------------------- /notebooks/BrushstrokeStyleTransferDrawApp.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "BrushstrokeStyleTransferDrawApp.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | }, 17 | "accelerator": "GPU" 18 | }, 19 | "cells": [ 20 | { 21 | "cell_type": "code", 22 | "metadata": { 23 | "id": "Yhv1eMzsHm5E" 24 | }, 25 | "source": [ 26 | "!pip uninstall -y datascience\n", 27 | "!pip install colab-everything\n", 28 | "!pip install streamlit-drawable-canvas\n", 29 | "!pip install -U scikit-learn\n", 30 | "!pip install scikit-image==0.17.2\n", 31 | "!pip install stqdm" 32 | ], 33 | "execution_count": null, 34 | "outputs": [] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": { 39 | "id": "KZv5oLILZwZo" 40 | }, 41 | "source": [ 42 | "
Ignore the error messages...
" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "metadata": { 48 | "id": "elOYlKQ9Xz9X" 49 | }, 50 | "source": [ 51 | "!git clone https://github.com/CompVis/brushstroke-parameterized-style-transfer\n", 52 | "%cd brushstroke-parameterized-style-transfer" 53 | ], 54 | "execution_count": null, 55 | "outputs": [] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": { 60 | "id": "vtKSiie6VW0V" 61 | }, 62 | "source": [ 63 | "
### Run app
" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": { 69 | "id": "Iv_qmYLnVnLG" 70 | }, 71 | "source": [ 72 | "
You have to click on the link that ends with \".ngrok.io\".\n", 73 | "\n", 74 | "
" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "metadata": { 80 | "id": "JPZzcLolSZHj" 81 | }, 82 | "source": [ 83 | "%tensorflow_version 1.x\n", 84 | "import tensorflow as tf\n", 85 | "tf.test.is_gpu_available()\n", 86 | "tf.__version__\n", 87 | "\n", 88 | "from colab_everything import ColabStreamlit\n", 89 | "ColabStreamlit('app.py')" 90 | ], 91 | "execution_count": null, 92 | "outputs": [] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "metadata": { 97 | "id": "jPA8BG9NKvd6" 98 | }, 99 | "source": [ 100 | "" 101 | ], 102 | "execution_count": null, 103 | "outputs": [] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "metadata": { 108 | "id": "T8OQDcflKxQN" 109 | }, 110 | "source": [ 111 | "" 112 | ], 113 | "execution_count": null, 114 | "outputs": [] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "metadata": { 119 | "id": "jXjSTak_K1bv" 120 | }, 121 | "source": [ 122 | "" 123 | ], 124 | "execution_count": null, 125 | "outputs": [] 126 | } 127 | ] 128 | } -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | with warnings.catch_warnings(): 4 | warnings.filterwarnings("ignore", category=FutureWarning) 5 | import tensorflow as tf 6 | 7 | 8 | #--------------------------------------------------------------------- 9 | # Misc 10 | #--------------------------------------------------------------------- 11 | def preprocess_img(x): 12 | x = 2 * x - 1 13 | x = tf.expand_dims(x, axis=0) 14 | return x 15 | 16 | 17 | def norm(x, axis=None, keepdims=None, eps=1e-8): 18 | """ 19 | Numerically stable norm. 20 | """ 21 | return tf.sqrt(tf.reduce_sum(tf.square(x), axis=axis, keepdims=keepdims) + eps) 22 | #return tf.reduce_sum(tf.square(x), axis=axis, keepdims=keepdims) 23 | 24 | 25 | #--------------------------------------------------------------------- 26 | # Brushstrokes 27 | #--------------------------------------------------------------------- 28 | def sample_quadratic_bezier_curve(s, c, e, num_points=20, dtype='float32'): 29 | """ 30 | Samples points from the quadratic bezier curves defined by the control points. 31 | Number of points to sample is num. 32 | 33 | Args: 34 | s (tensor): Start point of each curve, shape [N, 2]. 35 | c (tensor): Control point of each curve, shape [N, 2]. 36 | e (tensor): End point of each curve, shape [N, 2]. 37 | num_points (int): Number of points to sample on every curve. 38 | 39 | Return: 40 | (tensor): Coordinates of the points on the Bezier curves, shape [N, num_points, 2] 41 | """ 42 | N, _ = s.shape.as_list() 43 | t = tf.linspace(0., 1., num_points) 44 | t = tf.cast(t, dtype=dtype) 45 | t = tf.stack([t] * N, axis=0) 46 | s_x = tf.expand_dims(s[..., 0], axis=1) 47 | s_y = tf.expand_dims(s[..., 1], axis=1) 48 | e_x = tf.expand_dims(e[..., 0], axis=1) 49 | e_y = tf.expand_dims(e[..., 1], axis=1) 50 | c_x = tf.expand_dims(c[..., 0], axis=1) 51 | c_y = tf.expand_dims(c[..., 1], axis=1) 52 | x = c_x + (1. - t) ** 2 * (s_x - c_x) + t ** 2 * (e_x - c_x) 53 | y = c_y + (1. - t) ** 2 * (s_y - c_y) + t ** 2 * (e_y - c_y) 54 | return tf.stack([x, y], axis=-1) 55 | 56 | 57 | def renderer(curve_points, locations, colors, widths, H, W, K, canvas_color='gray', dtype='float32'): 58 | """ 59 | Renders the given brushstroke parameters onto a canvas. 60 | See Alg. 1 in https://arxiv.org/pdf/2103.17185.pdf. 61 | 62 | Args: 63 | curve_points (tensor): Points specifying the curves that will be rendered on the canvas, shape [N, S, 2]. 
64 | locations (tensor): Location of each curve, shape [N, 2]. 65 | colors (tensor): Color of each curve, shape [N, 3]. 66 | widths (tensor): Width of each curve, shape [N, 1]. 67 | H (int): Height of the canvas. 68 | W (int): Width of the canvas. 69 | K (int): Number of brushstrokes to consider for each pixel, see Sec. C.2 of the paper (Arxiv version). 70 | canvas_color (str): Background color of the canvas. Options: 'gray', 'white', 'black', 'noise'. 71 | Returns: 72 | (tensor): The rendered canvas, shape [H, W, 3]. 73 | """ 74 | N, S, _ = curve_points.shape.as_list() 75 | # define coarse grid cell 76 | t_H = tf.linspace(0., float(H), int(H // 5)) 77 | t_W = tf.linspace(0., float(W), int(W // 5)) 78 | t_H = tf.cast(t_H, dtype=dtype) 79 | t_W = tf.cast(t_W, dtype=dtype) 80 | P_y, P_x = tf.meshgrid(t_W, t_H) 81 | P = tf.stack([P_x, P_y], axis=-1) # [32, 32, 2] 82 | # Compute now distances from every brushtroke center to every coarse grid cell 83 | #P_norms = tf.square(norm(P, axis=-1)) 84 | #B_center_norms = tf.square(norm(locations, axis=-1)) 85 | #P_dot_B_center = tf.einsum('xyf,Nf->xyN', P, locations) 86 | # [32, 32, N] 87 | #D_to_all_B_centers = tf.expand_dims(P_norms, axis=-1) + tf.expand_dims(tf.expand_dims(B_center_norms, axis=0), axis=0) - 2. * P_dot_B_center 88 | 89 | ##### 90 | D_to_all_B_centers = tf.reduce_sum(tf.square(tf.expand_dims(P, axis=-2) - locations), axis=-1) # [H // C, W // C, N] 91 | ##### 92 | 93 | # Find nearest brushstrokes' indices for every coarse grid cell 94 | _, idcs = tf.math.top_k(-D_to_all_B_centers, k=K) # [32, 32, K] 95 | # Now create 2 tensors (spatial size of a grid cell). One containing brushstroke locations, another containing 96 | # brushstroke colors. 97 | # [H // 10, W // 10, K, S, 2] 98 | canvas_with_nearest_Bs = tf.gather(params=curve_points, 99 | indices=idcs, 100 | batch_dims=0) 101 | # [H // 10, W // 10, K, 3] 102 | canvas_with_nearest_Bs_colors = tf.gather(params=colors, 103 | indices=idcs, 104 | batch_dims=0) 105 | # [H // 10, W // 10, K, 1] 106 | canvas_with_nearest_Bs_bs = tf.gather(params=widths, 107 | indices=idcs, 108 | batch_dims=0) 109 | # Resize those tensors to the full canvas size (not coarse grid) 110 | # First locations of points sampled from curves 111 | H_, W_, r1, r2, r3 = canvas_with_nearest_Bs.shape.as_list() 112 | canvas_with_nearest_Bs = tf.reshape(canvas_with_nearest_Bs, shape=(1, H_, W_, r1 * r2 * r3)) # [1, H // 10, W // 10, K * S * 2] 113 | canvas_with_nearest_Bs = tf.image.resize_nearest_neighbor(canvas_with_nearest_Bs, size=(H, W)) # [1, H, W, K * S * 2] 114 | canvas_with_nearest_Bs = tf.reshape(canvas_with_nearest_Bs, shape=(H, W, r1, r2, r3)) # [H, W, N, S, 2] 115 | # Now colors of curves 116 | H_, W_, r1, r2 = canvas_with_nearest_Bs_colors.shape.as_list() 117 | canvas_with_nearest_Bs_colors = tf.reshape(canvas_with_nearest_Bs_colors, shape=(1, H_, W_, r1 * r2)) # [1, H // 10, W // 10, K * 3] 118 | canvas_with_nearest_Bs_colors = tf.image.resize_nearest_neighbor(canvas_with_nearest_Bs_colors, size=(H, W)) # [1, H, W, K * 3] 119 | canvas_with_nearest_Bs_colors = tf.reshape(canvas_with_nearest_Bs_colors, shape=(H, W, r1, r2)) # [H, W, K, 3] 120 | # And with the brush size 121 | H_, W_, r1, r2 = canvas_with_nearest_Bs_bs.shape.as_list() 122 | canvas_with_nearest_Bs_bs = tf.reshape(canvas_with_nearest_Bs_bs, shape=(1, H_, W_, r1 * r2)) # [1, H // 10, W // 10, K] 123 | canvas_with_nearest_Bs_bs = tf.image.resize_nearest_neighbor(canvas_with_nearest_Bs_bs, size=(H, W)) # [1, H, W, K] 124 | canvas_with_nearest_Bs_bs = 
tf.reshape(canvas_with_nearest_Bs_bs, shape=(H, W, r1, r2)) # [H, W, K, 1] 125 | # Now create full-size canvas 126 | t_H = tf.linspace(0., float(H), H) 127 | t_W = tf.linspace(0., float(W), W) 128 | t_H = tf.cast(t_H, dtype=dtype) 129 | t_W = tf.cast(t_W, dtype=dtype) 130 | P_y, P_x = tf.meshgrid(t_W, t_H) 131 | P_full = tf.stack([P_x, P_y], axis=-1) # [H, W, 2] 132 | # Compute distance from every pixel on canvas to each (among nearest ones) line segment between points from curves 133 | canvas_with_nearest_Bs_a = tf.gather(canvas_with_nearest_Bs, axis=-2, indices=[i for i in range(S - 1)]) # start points of each line segment 134 | canvas_with_nearest_Bs_b = tf.gather(canvas_with_nearest_Bs, axis=-2, indices=[i for i in range(1, S)]) # end points of each line segments 135 | canvas_with_nearest_Bs_b_a = canvas_with_nearest_Bs_b - canvas_with_nearest_Bs_a # [H, W, N, S - 1, 2] 136 | P_full_canvas_with_nearest_Bs_a = tf.expand_dims(tf.expand_dims(P_full, axis=2), axis=2) - canvas_with_nearest_Bs_a # [H, W, K, S - 1, 2] 137 | # compute t value for which each pixel is closest to each line that goes through each line segment (among nearest ones) 138 | t = tf.reduce_sum(canvas_with_nearest_Bs_b_a * P_full_canvas_with_nearest_Bs_a, axis=-1) \ 139 | / (tf.reduce_sum(tf.square(canvas_with_nearest_Bs_b_a), axis=-1) + 1e-8) 140 | # if t value is outside [0, 1], then the nearest point on the line does not lie on the segment, so clip values of t 141 | t = tf.clip_by_value(t, clip_value_min=0.0, clip_value_max=1.0) 142 | # compute closest points on each line segment - [H, W, K, S - 1, 2] 143 | closest_points_on_each_line_segment = canvas_with_nearest_Bs_a + tf.expand_dims(t, axis=-1) * canvas_with_nearest_Bs_b_a 144 | # compute the distance from every pixel to the closest point on each line segment - [H, W, K, S - 1] 145 | dist_to_closest_point_on_line_segment = \ 146 | tf.reduce_sum(tf.square(tf.expand_dims(tf.expand_dims(P_full, axis=2), axis=2) - closest_points_on_each_line_segment), axis=-1) 147 | # and distance to the nearest bezier curve. 148 | D = tf.reduce_min(dist_to_closest_point_on_line_segment, axis=[-1, -2]) # [H, W] 149 | # Finally render curves on a canvas to obtain image. 150 | I_NNs_B_ranking = tf.nn.softmax(100000. * (1.0 / (1e-8 + tf.reduce_min(dist_to_closest_point_on_line_segment, axis=[-1]))), axis=-1) # [H, W, N] 151 | I_colors = tf.einsum('hwnf,hwn->hwf', canvas_with_nearest_Bs_colors, I_NNs_B_ranking) # [H, W, 3] 152 | bs = tf.einsum('hwnf,hwn->hwf', canvas_with_nearest_Bs_bs, I_NNs_B_ranking) # [H, W, 1] 153 | bs_mask = tf.math.sigmoid(bs - tf.expand_dims(D, axis=-1)) 154 | if canvas_color == 'gray': 155 | canvas = tf.ones(shape=I_colors.shape, dtype=dtype) * 0.5 156 | elif canvas_color == 'white': 157 | canvas = tf.ones(shape=I_colors.shape, dtype=dtype) 158 | elif canvas_color == 'black': 159 | canvas = tf.zeros(shape=I_colors.shape, dtype=dtype) 160 | elif canvas_color == 'noise': 161 | canvas = tf.random.normal(shape=I_colors.shape, dtype=dtype) * 0.1 162 | 163 | I = I_colors * bs_mask + (1 - bs_mask) * canvas 164 | return I 165 | 166 | 167 | #--------------------------------------------------------------------- 168 | # Losses 169 | #--------------------------------------------------------------------- 170 | def content_loss(features_lhs, features_rhs, layers, weights, scale_by_y=False): 171 | """ 172 | Computes the VGG perceptual loss. 173 | 174 | Args: 175 | features_lhs (dict of tensors): Dictionary of VGG activations. 

def get_gram_matrices(features):
    """
    Computes the gram matrices for the given list of activations.

    Args:
        features (list of tensors): List of VGG activations.

    Returns:
        List of gram matrices.
    """
    gram_matrices = []
    for feature in features:
        gram_matrix = tf.einsum('bhwf,bhwl->bfl', feature, feature)
        B, H, W, C = feature.shape.as_list()
        gram_matrix /= tf.cast(H * W * C, dtype=tf.float32)
        gram_matrices.append(gram_matrix)
    return gram_matrices
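
# Note (added for reference): for activations F of shape [B, H, W, C], the
# einsum in get_gram_matrices computes the normalized Gram matrix
#
#   G[b, i, j] = (1 / (H * W * C)) * sum_{h, w} F[b, h, w, i] * F[b, h, w, j]
#
# i.e. spatially averaged second-order feature statistics. These discard where
# features occur in the image, which is why matching them captures style
# rather than content.
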

def style_loss(features_lhs, features_rhs, layers, weights):
    """
    Computes the VGG gram matrix style loss.

    Args:
        features_lhs (dict of tensors): Dictionary of VGG activations.
        features_rhs (dict of tensors): Dictionary of VGG activations.
        layers (list of str): List specifying the layers to use.
        weights (list of floats): List specifying the weights for the used layers.

    Returns:
        VGG gram matrix style loss.
    """
    feat_lhs = [features_lhs[key] for key in layers]
    feat_rhs = [features_rhs[key] for key in layers]
    gram_matrices_lhs = get_gram_matrices(feat_lhs)
    gram_matrices_rhs = get_gram_matrices(feat_rhs)
    losses = [w * tf.reduce_sum(tf.square(gram_lhs - gram_rhs)) for w, gram_lhs, gram_rhs in zip(weights, gram_matrices_lhs, gram_matrices_rhs)]
    loss = tf.add_n(losses)
    return loss


def get_nn_idxs(X, k, fetch_dist=False):
    """
    Computes the indices of the k nearest neighbors of each element of a tensor.

    Args:
        X (tensor): Tensor of shape [B, N, F].
        k (int): Number of nearest neighbors.
        fetch_dist (bool): If True, also return the top-k values of the negated
            squared distances.

    Returns:
        Tensor of shape [B, N, k].
    """
    # Pairwise squared distances via the expansion ||x - y||^2 = ||x||^2 - 2<x, y> + ||y||^2.
    r = tf.reduce_sum(X * X, 2, keepdims=True)
    D = r - 2 * tf.matmul(X, tf.transpose(X, perm=(0, 2, 1))) + tf.transpose(r, perm=(0, 2, 1))
    X_top_vals, X_top_idxs = tf.math.top_k(-D, k=k, sorted=True, name=None)

    if fetch_dist:
        return X_top_idxs, X_top_vals
    else:
        return X_top_idxs


def total_variation_loss(x_loc, s, e, K=10):
    # Encourages neighboring brushstrokes to have similar orientations.

    def projection(z):
        # Maps a direction vector (x, y) to (x^2, y^2, xy), which is invariant
        # to the sign of the direction (a stroke and its reverse are equivalent).
        x = tf.gather(z, axis=-1, indices=[0])
        y = tf.gather(z, axis=-1, indices=[1])
        return tf.concat([tf.square(x), tf.square(y), x * y], axis=-1)

    se_vec = e - s
    se_vec_proj = projection(se_vec)

    x_nn_idcs = get_nn_idxs(tf.expand_dims(x_loc, axis=0), k=K)

    x_nn_idcs = tf.squeeze(x_nn_idcs, axis=0)
    x_sig_nns = tf.gather(se_vec, indices=x_nn_idcs, axis=0, batch_dims=0)

    dist_to_centroid = tf.reduce_mean(tf.reduce_sum(tf.square(projection(x_sig_nns) - tf.expand_dims(se_vec_proj, axis=-2)), axis=-1))
    return dist_to_centroid


def draw_projection_loss(location, s, e, draw_curve_position, draw_curve_vector, draw_strength):
    # Aligns the directions of the strokes nearest to each drawn curve point with
    # the direction of the drawn curve at that point (used by the drawing app).
    dist = tf.reduce_sum(tf.square(tf.expand_dims(draw_curve_position, axis=1) - location), axis=-1)
    _, idcs = tf.math.top_k(-dist, k=draw_strength)  # [num_points, K]
    se_vec = e - s
    strokes_vec_nn = tf.gather(se_vec, indices=idcs, axis=0)  # [num_points, K, 2]
    strokes_vec_nn /= (norm(strokes_vec_nn, axis=-1, keepdims=True) + 1e-6)
    curves_vec = draw_curve_vector / (norm(draw_curve_vector, axis=-1, keepdims=True) + 1e-6)
    projection = tf.abs(tf.einsum('mki,mi->mk', strokes_vec_nn, curves_vec))  # [num_points, K]
    projection_loss = tf.reduce_mean(tf.square(1 - projection))
    return projection_loss


def curviture_loss(s, e, c):
    # Penalizes strongly bent strokes: v1 + v2 = 2 * ((s + e) / 2 - c) points from
    # the Bezier control point c to the chord midpoint, so its length relative to
    # the chord length |e - s| measures how much the curve bends.
    v1 = s - c
    v2 = e - c
    dist_se = norm(e - s, axis=-1) + 1e-6
    return tf.reduce_mean(norm(v1 + v2, axis=-1) / dist_se)
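
# Usage sketch (added for illustration; the weighting factors and feature
# dictionaries are assumptions -- the actual loss combination lives in model.py):
#
#   loss = content_weight * content_loss(feats_rendered, feats_content, layers_c, weights_c) \
#        + style_weight * style_loss(feats_rendered, feats_style, layers_s, weights_s) \
#        + tv_weight * total_variation_loss(location, s, e, K=10) \
#        + curv_weight * curviture_loss(s, e, c)
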
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
Pillow==7.2.0
tqdm
requests==2.24.0
scikit-image==0.17.2
numpy==1.19.1
scikit-learn
streamlit==0.82.0
streamlit-drawable-canvas
stqdm
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
from tqdm import tqdm
import requests
import os
from PIL import Image
import numpy as np
from skimage.segmentation import slic
from scipy.spatial import ConvexHull


#------------------------------------------------------------------
# I/O
#------------------------------------------------------------------

def download_weights(url, name):
    """
    Downloads the checkpoint file specified by 'url'.

    Args:
        url (str): URL specifying the checkpoint file.
        name (str): Name under which the checkpoint file will be stored.

    Returns:
        (str): Path to the checkpoint file.
    """
    ckpt_dir = 'pretrained_weights'
    ckpt_file = os.path.join(ckpt_dir, name)
    if not os.path.exists(ckpt_file):
        print(f'Downloading: \"{url[:url.rfind("?")]}\" to {ckpt_file}')
        if not os.path.exists(ckpt_dir):
            os.makedirs(ckpt_dir)

        response = requests.get(url, stream=True)
        total_size_in_bytes = int(response.headers.get('content-length', 0))
        progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)

        # First write to a temp file, in case the download fails.
        ckpt_file_temp = os.path.join(ckpt_dir, name + '.temp')
        with open(ckpt_file_temp, 'wb') as file:
            for data in response.iter_content(chunk_size=1024):
                progress_bar.update(len(data))
                file.write(data)
        progress_bar.close()

        if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
            print('An error occurred while downloading, please try again.')
            if os.path.exists(ckpt_file_temp):
                os.remove(ckpt_file_temp)
        else:
            # If the download was successful, rename the temp file.
            os.rename(ckpt_file_temp, ckpt_file)
    return ckpt_file
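
# Usage sketch (added for illustration; the URL below is a placeholder, not a
# real checkpoint location):
#
#   ckpt_path = download_weights('https://example.com/checkpoint.pickle?dl=1',
#                                name='checkpoint.pickle')
#   # Subsequent calls return the cached file in 'pretrained_weights/' without
#   # re-downloading.
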

#------------------------------------------------------------------
# Brushstrokes
#------------------------------------------------------------------

def clusters_to_strokes(segments, img, H, W, sec_scale=0.001, width_scale=1):
    """
    Converts superpixel segments of the content image into initial brushstroke parameters.

    Args:
        segments (ndarray): Segmentation map, e.g. as returned by skimage's slic.
        img (ndarray): Content image, shape [H, W, 3].
        H (int): Height of the canvas.
        W (int): Width of the canvas.
        sec_scale (float): Scale factor for the Bezier start, end, and control points.
        width_scale (float): Scale factor for the brushstroke widths.

    Returns:
        Tuple (location, s, e, c, width, color) of arrays with one entry per stroke.
    """
    segments += np.abs(np.min(segments))
    num_clusters = np.max(segments)
    clusters_params = {'center': [],
                       's': [],
                       'e': [],
                       'bp1': [],
                       'bp2': [],
                       'num_pixels': [],
                       'stddev': [],
                       'width': [],
                       'color_rgb': []
                       }

    for cluster_idx in range(num_clusters + 1):
        cluster_mask = segments == cluster_idx
        if np.sum(cluster_mask) < 5: continue
        cluster_mask_nonzeros = np.nonzero(cluster_mask)

        cluster_points = np.stack((cluster_mask_nonzeros[0], cluster_mask_nonzeros[1]), axis=-1)
        try:
            convex_hull = ConvexHull(cluster_points)
        except Exception:
            continue

        # Find the two points (pixels) in the cluster that have the largest distance between them.
        border_points = cluster_points[convex_hull.simplices.reshape(-1)]
        dist = np.sum((np.expand_dims(border_points, axis=1) - border_points)**2, axis=-1)
        max_idx_a, max_idx_b = np.nonzero(dist == np.max(dist))
        point_a = border_points[max_idx_a[0]]
        point_b = border_points[max_idx_b[0]]
        # Compute the two points where the convex hull is crossed by the line that runs
        # orthogonal to the segment between point_a and point_b through its midpoint.
        v_ba = point_b - point_a
        v_orth = np.array([v_ba[1], -v_ba[0]])
        m = (point_a + point_b) / 2.0
        n = m + 0.5 * v_orth
        p = cluster_points[convex_hull.simplices][:, 0]
        q = cluster_points[convex_hull.simplices][:, 1]
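        # Explanatory note (added): each hull edge is parametrized as p + u * (q - p).
        # The expression below is the standard 2D line-line intersection solution for
        # that edge and the line through m and n; only edges with u in [0, 1] are
        # actually crossed by the orthogonal line.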
        u = - ((m[..., 0] - n[..., 0]) * (m[..., 1] - p[..., 1]) - (m[..., 1] - n[..., 1]) * (m[..., 0] - p[..., 0])) \
            / ((m[..., 0] - n[..., 0]) * (p[..., 1] - q[..., 1]) - (m[..., 1] - n[..., 1]) * (p[..., 0] - q[..., 0]))
        intersec_idcs = np.logical_and(u >= 0, u <= 1)
        intersec_points = p + u.reshape(-1, 1) * (q - p)
        intersec_points = intersec_points[intersec_idcs]

        # Guard against degenerate clusters with fewer than two intersection points.
        if len(intersec_points) < 2: continue

        # Squared distance between the two intersection points, used as a relative width measure.
        width = np.sum((intersec_points[0] - intersec_points[1])**2)

        if width == 0.0: continue

        clusters_params['s'].append(point_a / img.shape[:2])
        clusters_params['e'].append(point_b / img.shape[:2])
        clusters_params['bp1'].append(intersec_points[0] / img.shape[:2])
        clusters_params['bp2'].append(intersec_points[1] / img.shape[:2])
        clusters_params['width'].append(width)

        clusters_params['color_rgb'].append(np.mean(img[cluster_mask], axis=0))
        center_x = np.mean(cluster_mask_nonzeros[0]) / img.shape[0]
        center_y = np.mean(cluster_mask_nonzeros[1]) / img.shape[1]
        clusters_params['center'].append(np.array([center_x, center_y]))
        clusters_params['num_pixels'].append(np.sum(cluster_mask))
        clusters_params['stddev'].append(np.mean(np.std(img[cluster_mask], axis=0)))

    for key in clusters_params.keys():
        clusters_params[key] = np.array(clusters_params[key])

    N = clusters_params['center'].shape[0]

    stddev = clusters_params['stddev']
    rel_num_pixels = 5 * clusters_params['num_pixels'] / np.sqrt(H * W)

    location = clusters_params['center']
    num_pixels_per_cluster = clusters_params['num_pixels'].reshape(-1, 1)
    s = clusters_params['s']
    e = clusters_params['e']
    cluster_width = clusters_params['width']

    # Rescale the normalized coordinates to canvas size.
    location[..., 0] *= H
    location[..., 1] *= W
    s[..., 0] *= H
    s[..., 1] *= W
    e[..., 0] *= H
    e[..., 1] *= W

    # Express the start and end points relative to the stroke location.
    s -= location
    e -= location

    color = clusters_params['color_rgb']

    # Random control point so that the initial strokes are slightly curved.
    c = (s + e) / 2. + np.stack([np.random.uniform(low=-1, high=1, size=[N]),
                                 np.random.uniform(low=-1, high=1, size=[N])],
                                axis=-1)

    sec_center = (s + e + c) / 3.
    s -= sec_center
    e -= sec_center
    c -= sec_center

    # Clip outliers before computing the stroke widths.
    rel_num_pix_quant = np.quantile(rel_num_pixels, q=[0.3, 0.99])
    width_quant = np.quantile(cluster_width, q=[0.3, 0.99])
    rel_num_pixels = np.clip(rel_num_pixels, rel_num_pix_quant[0], rel_num_pix_quant[1])
    cluster_width = np.clip(cluster_width, width_quant[0], width_quant[1])
    width = width_scale * rel_num_pixels.reshape(-1, 1) * cluster_width.reshape(-1, 1)
    s, e, c = [x * sec_scale for x in [s, e, c]]

    location, s, e, c, width, color = [x.astype(np.float32) for x in [location, s, e, c, width, color]]

    return location, s, e, c, width, color


def initialize_brushstrokes(content_img, num_strokes, canvas_height, canvas_width, sec_scale, width_scale, init='sp'):
    """
    Initializes the brushstroke parameters, either randomly or from superpixels ('sp').
    """
    if init == 'random':
        # Brushstroke colors
        color = np.random.rand(num_strokes, 3)

        # Brushstroke widths
        width = np.random.rand(num_strokes, 1) * width_scale

        # Brushstroke locations
        location = np.stack([np.random.rand(num_strokes) * canvas_height, np.random.rand(num_strokes) * canvas_width], axis=-1)

        # Start points for the Bezier curves
        s = np.stack([np.random.uniform(low=-1, high=1, size=num_strokes) * canvas_height,
                      np.random.uniform(low=-1, high=1, size=num_strokes) * canvas_width], axis=-1)

        # End points for the Bezier curves
        e = np.stack([np.random.uniform(low=-1, high=1, size=num_strokes) * canvas_height,
                      np.random.uniform(low=-1, high=1, size=num_strokes) * canvas_width], axis=-1)

        # Control points for the Bezier curves
        c = np.stack([np.random.uniform(low=-1, high=1, size=num_strokes) * canvas_height,
                      np.random.uniform(low=-1, high=1, size=num_strokes) * canvas_width], axis=-1)

        # Center the start, end, and control points around their centroid and rescale them
        sec_center = (s + e + c) / 3.0
        s, e, c = [x - sec_center for x in [s, e, c]]
        s, e, c = [x * sec_scale for x in [s, e, c]]
    else:
        # Superpixel initialization: segment the content image and turn each
        # segment into one brushstroke.
        segments = slic(content_img,
                        n_segments=num_strokes,
                        min_size_factor=0.02,
                        max_size_factor=4.,
                        compactness=2,
                        sigma=1,
                        start_label=0)

        location, s, e, c, width, color = clusters_to_strokes(segments,
                                                              content_img,
                                                              canvas_height,
                                                              canvas_width,
                                                              sec_scale=sec_scale,
                                                              width_scale=width_scale)

    return location, s, e, c, width, color
--------------------------------------------------------------------------------