├── .eslintrc.json ├── .gitignore ├── .tern-project ├── LICENSE ├── README.md ├── setup.py └── style_transfer ├── __init__.py ├── cli.py ├── sRGB Profile.icc ├── sqrtm.py ├── style_transfer.py ├── web_interface.py └── web_static ├── index.html ├── jquery-3.5.1.min.js ├── main.css ├── main.js └── normalize.css /.eslintrc.json: -------------------------------------------------------------------------------- 1 | { 2 | "env": { 3 | "browser": true, 4 | "es6": true, 5 | "jquery": true 6 | }, 7 | "extends": "eslint:recommended", 8 | "parserOptions": { 9 | "ecmaVersion": 2015 10 | }, 11 | "rules": { 12 | "indent": [ 13 | "error", 14 | 4 15 | ], 16 | "linebreak-style": [ 17 | "error", 18 | "unix" 19 | ], 20 | "quotes": [ 21 | "error", 22 | "double" 23 | ], 24 | "semi": [ 25 | "error", 26 | "always" 27 | ], 28 | "no-console": "off" 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | venv* 2 | __pycache__ 3 | .ipynb_checkpoints 4 | out.* 5 | trace.json 6 | *.egg-info 7 | -------------------------------------------------------------------------------- /.tern-project: -------------------------------------------------------------------------------- 1 | { 2 | "ecmaVersion": "6", 3 | "libs": [ 4 | "browser", 5 | "jquery" 6 | ], 7 | "loadEagerly": [], 8 | "dontLoad": [ 9 | "node_modules/**" 10 | ], 11 | "plugins": { 12 | "doc_comment": true 13 | } 14 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021 Katherine Crowson 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, 
merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # style-transfer-pytorch 2 | 3 | An implementation of neural style transfer ([A Neural Algorithm of Artistic Style](https://arxiv.org/abs/1508.06576)) in PyTorch, supporting CPUs and Nvidia GPUs. It does automatic multi-scale (coarse-to-fine) stylization to produce high-quality, high-resolution stylizations, even up to print resolution if the GPUs have sufficient memory. If two GPUs are available, they can both be used to increase the maximum resolution. (Using two GPUs is not faster than using one.) 4 | 5 | The algorithm has been modified from that in the literature by: 6 | 7 | - Using the PyTorch pre-trained VGG-19 weights instead of the original VGG-19 weights 8 | 9 | - Changing the padding mode of the first layer of VGG-19 to 'replicate', to reduce edge artifacts 10 | 11 | - When using average or L2 pooling, scaling the result by an empirically derived factor to ensure that the magnitude of the result stays the same on average (Gatys et al.
(2015) did not do this) 12 | 13 | - Using [Wasserstein-2 style loss](https://wandb.ai/johnowhitaker/style_loss_showdown/reports/An-Explanation-of-Style-Transfer-with-a-Showdown-of-Different-Techniques--VmlldzozMDIzNjg0#style-loss-#3:-%22vincent's-loss%22) 14 | 15 | - Taking an exponential moving average over the iterates to reduce iterate noise (each new scale is initialized with the previous scale's averaged iterate) 16 | 17 | - Warm-starting the Adam optimizer with scaled-up versions of its first and second moment buffers at the beginning of each new scale, to prevent noise from being added to the iterates at the beginning of each scale 18 | 19 | - Using non-equal weights for the style layers to improve visual quality 20 | 21 | - Stylizing the image at progressively larger scales, each greater by a factor of sqrt(2) (an improvement on the multi-scale scheme given in Gatys et al. (2016)) 22 | 23 | ## Example outputs (images omitted) 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | ## Installation 36 | 37 | [Python](https://www.python.org/downloads/) 3.6+ is required. 38 | 39 | [PyTorch](https://pytorch.org) is required: follow [their installation instructions](https://pytorch.org/get-started/locally/) before proceeding. If you do not have an Nvidia GPU, select None for CUDA. On Linux, you can find out your CUDA version using the `nvidia-smi` command. PyTorch packages built for CUDA versions lower than yours will work, but select the highest one you can. 40 | 41 | To install `style-transfer-pytorch`, first clone the repository, then run the command: 42 | 43 | ```sh 44 | pip install -e PATH_TO_REPO 45 | ``` 46 | 47 | This will install the `style_transfer` CLI tool. `style_transfer` uses a pre-trained VGG-19 model (Simonyan et al.), which is 548MB in size, and will download it the first time it is run.
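The last item in the list of modifications above (progressively larger scales, each greater by a factor of sqrt(2)) can be illustrated with a short plain-Python sketch. It reproduces the logic of the `gen_scales` helper in `style_transfer/style_transfer.py`: start at the final scale, divide by sqrt(2) at each step, and stop once the rounded scale drops below the minimum scale. The values 128 and 512 are the `min_scale` and `end_scale` defaults of `StyleTransfer.stylize`. This is an illustration only and is not a separate public API of the package.

```python
def gen_scales(start, end):
    """Return the scales (max image dimension, in pixels) of a coarse-to-fine
    schedule in which each scale is about sqrt(2) times the previous one.

    Works down from `end`, dividing by sqrt(2) at each step, and stops once
    the rounded scale falls below `start`.
    """
    scales = set()
    i = 0
    scale = end
    while scale >= start:
        scales.add(scale)
        i += 1
        # end / 2**(i/2) is end / sqrt(2)**i, rounded to whole pixels
        scale = round(end / pow(2, i / 2))
    return sorted(scales)


scales = gen_scales(128, 512)
print(scales)  # [128, 181, 256, 362, 512]

# Each step roughly doubles the pixel count of a square image.
for s in scales:
    print(f"{s}x{s}: {s * s} pixels")
```

Each step roughly doubles the pixel count, and memory use and runtime grow about linearly in the number of pixels. As a result, the last scale or two tend to dominate the total cost of a run.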
48 | 49 | If you have a supported GPU and `style_transfer` is using the CPU, try using the argument `--devices cuda:0` to force it to try the first CUDA GPU. If that GPU cannot be used, this should print an informative error message. 50 | 51 | ## Colab 52 | 53 | You can try `style_transfer` without installing it locally by using the [official Colab](https://colab.research.google.com/drive/1Tmuwmncao5E3D-5tTIVQjRy2YQ8JdpoB?usp=sharing). 54 | 55 | ## Basic usage 56 | 57 | ```sh 58 | style_transfer CONTENT_IMAGE STYLE_IMAGE [STYLE_IMAGE ...] [-o OUTPUT_IMAGE] 59 | ``` 60 | 61 | Input images will be converted to sRGB when loaded, and output images have the sRGB colorspace. If the output image is a TIFF file, it will be written with 16 bits per channel. Alpha channels in the inputs will be ignored. 62 | 63 | `style_transfer` has many optional arguments: run it with the `--help` argument to see a full list. Particularly notable ones include: 64 | 65 | - `--web` enables a simple web interface that lets you watch the program's progress while it runs. The interface runs on port 8080 by default, but you can change it with `--port`. If you just want to view the current image and refresh it manually, you can go to `/image`. 66 | 67 | - `--devices` manually sets the PyTorch device names. It can be set to `cpu` to force it to run on the CPU on a machine with a supported GPU, or to e.g. `cuda:1` (zero-indexed) to select the second CUDA GPU. Two GPUs can be specified, for instance `--devices cuda:0 cuda:1`. If it is omitted, `style_transfer` will automatically use the first visible CUDA GPU, falling back to the CPU. 68 | 69 | - `-s` (`--end-scale`) sets the maximum image dimension (the larger of height and width) of the output. A large image (e.g. 2896x2172) can take around fifteen minutes to generate on an RTX 3090 and will require nearly all of its 24GB of memory.
Since both memory usage and runtime increase linearly in the number of pixels (quadratically in the value of the `--end-scale` parameter), users with less GPU memory or who do not want to wait very long are encouraged to use smaller resolutions. The default is 512. 70 | 71 | - `-sw` (`--style-weights`) specifies factors for the weighted average of multiple styles if there is more than one style image specified. These factors are automatically normalized to sum to 1. If omitted, the styles will be blended equally. 72 | 73 | - `-cw` (`--content-weight`) sets the degree to which features from the content image are included in the output image. The default is 0.015. 74 | 75 | - `-tw` (`--tv-weight`) sets the strength of the smoothness prior. The default is 2. 76 | 77 | ## References 78 | 79 | 1. L. Gatys, A. Ecker, M. Bethge (2015), "[A Neural Algorithm of Artistic Style](https://arxiv.org/abs/1508.06576)" 80 | 81 | 1. L. Gatys, A. Ecker, M. Bethge, A. Hertzmann, E. Shechtman (2016), "[Controlling Perceptual Factors in Neural Style Transfer](https://arxiv.org/abs/1611.07865)" 82 | 83 | 1. J. Johnson, A. Alahi, L. Fei-Fei (2016), "[Perceptual Losses for Real-Time Style Transfer and Super-Resolution](https://arxiv.org/abs/1603.08155)" 84 | 85 | 1. A. Mahendran, A. Vedaldi (2014), "[Understanding Deep Image Representations by Inverting Them](https://arxiv.org/abs/1412.0035)" 86 | 87 | 1. D. Kingma, J. Ba (2014), "[Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980)" 88 | 89 | 1. K. Simonyan, A. 
Zisserman (2014), "[Very Deep Convolutional Networks for Large-Scale Image Recognition](https://arxiv.org/abs/1409.1556)" 90 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | setuptools.setup( 4 | name='style-transfer-pytorch', 5 | version='0.1', 6 | description='Neural style transfer in PyTorch.', 7 | # TODO: add long description 8 | long_description='Neural style transfer in PyTorch.', 9 | url='https://github.com/crowsonkb/style-transfer-pytorch', 10 | author='Katherine Crowson', 11 | author_email='crowsonkb@gmail.com', 12 | license='MIT', 13 | packages=['style_transfer'], 14 | entry_points={ 15 | 'console_scripts': ['style_transfer=style_transfer.cli:main'], 16 | }, 17 | package_data={'style_transfer': ['*.icc', 'web_static/*']}, 18 | install_requires=['aiohttp>=3.7.2', 19 | 'dataclasses>=0.8;python_version<"3.7"', 20 | 'numpy>=1.19.2', 21 | 'Pillow>=8.0.0', 22 | 'tifffile>=2020.9.3', 23 | 'torch>=1.7.1', 24 | 'torchvision>=0.8.2', 25 | 'tqdm>=4.46.0'], 26 | python_requires=">=3.6", 27 | # TODO: Add classifiers 28 | classifiers=[], 29 | ) 30 | -------------------------------------------------------------------------------- /style_transfer/__init__.py: -------------------------------------------------------------------------------- 1 | """Neural style transfer (https://arxiv.org/abs/1508.06576) in PyTorch.""" 2 | 3 | from pathlib import Path 4 | 5 | srgb_profile = (Path(__file__).resolve().parent / 'sRGB Profile.icc').read_bytes() 6 | del Path 7 | 8 | from .style_transfer import STIterate, StyleTransfer 9 | from .web_interface import WebInterface 10 | -------------------------------------------------------------------------------- /style_transfer/cli.py: -------------------------------------------------------------------------------- 1 | """Neural style transfer (https://arxiv.org/abs/1508.06576) in PyTorch.""" 
2 | 3 | import argparse 4 | import atexit 5 | from dataclasses import asdict 6 | import io 7 | import json 8 | from pathlib import Path 9 | import platform 10 | import sys 11 | import webbrowser 12 | 13 | import numpy as np 14 | from PIL import Image, ImageCms 15 | from tifffile import TIFF, TiffWriter 16 | import torch 17 | import torch.multiprocessing as mp 18 | from tqdm import tqdm 19 | 20 | from . import srgb_profile, StyleTransfer, WebInterface 21 | 22 | 23 | def prof_to_prof(image, src_prof, dst_prof, **kwargs): 24 | src_prof = io.BytesIO(src_prof) 25 | dst_prof = io.BytesIO(dst_prof) 26 | return ImageCms.profileToProfile(image, src_prof, dst_prof, **kwargs) 27 | 28 | 29 | def load_image(path, proof_prof=None): 30 | src_prof = dst_prof = srgb_profile 31 | try: 32 | image = Image.open(path) 33 | if 'icc_profile' in image.info: 34 | src_prof = image.info['icc_profile'] 35 | else: 36 | image = image.convert('RGB') 37 | if proof_prof is None: 38 | if src_prof == dst_prof: 39 | return image.convert('RGB') 40 | return prof_to_prof(image, src_prof, dst_prof, outputMode='RGB') 41 | proof_prof = Path(proof_prof).read_bytes() 42 | cmyk = prof_to_prof(image, src_prof, proof_prof, outputMode='CMYK') 43 | return prof_to_prof(cmyk, proof_prof, dst_prof, outputMode='RGB') 44 | except OSError as err: 45 | print_error(err) 46 | sys.exit(1) 47 | 48 | 49 | def save_pil(path, image): 50 | try: 51 | kwargs = {'icc_profile': srgb_profile} 52 | if path.suffix.lower() in {'.jpg', '.jpeg'}: 53 | kwargs['quality'] = 95 54 | kwargs['subsampling'] = 0 55 | elif path.suffix.lower() == '.webp': 56 | kwargs['quality'] = 95 57 | image.save(path, **kwargs) 58 | except (OSError, ValueError) as err: 59 | print_error(err) 60 | sys.exit(1) 61 | 62 | 63 | def save_tiff(path, image): 64 | tag = ('InterColorProfile', TIFF.DATATYPES.BYTE, len(srgb_profile), srgb_profile, False) 65 | try: 66 | with TiffWriter(path) as writer: 67 | writer.save(image, photometric='rgb', resolution=(72, 72), 
extratags=[tag]) 68 | except OSError as err: 69 | print_error(err) 70 | sys.exit(1) 71 | 72 | 73 | def save_image(path, image): 74 | path = Path(path) 75 | tqdm.write(f'Writing image to {path}.') 76 | if isinstance(image, Image.Image): 77 | save_pil(path, image) 78 | elif isinstance(image, np.ndarray) and path.suffix.lower() in {'.tif', '.tiff'}: 79 | save_tiff(path, image) 80 | else: 81 | raise ValueError('Unsupported combination of image type and extension') 82 | 83 | 84 | def get_safe_scale(w, h, dim): 85 | """Given a w x h content image and that a dim x dim square does not 86 | exceed GPU memory, compute a safe end_scale for that content image.""" 87 | return int(pow(w / h if w > h else h / w, 1/2) * dim) 88 | 89 | 90 | def setup_exceptions(): 91 | try: 92 | from IPython.core.ultratb import FormattedTB 93 | sys.excepthook = FormattedTB(mode='Plain', color_scheme='Neutral') 94 | except ImportError: 95 | pass 96 | 97 | 98 | def fix_start_method(): 99 | if platform.system() == 'Darwin': 100 | mp.set_start_method('spawn') 101 | 102 | 103 | def print_error(err): 104 | print('\033[31m{}:\033[0m {}'.format(type(err).__name__, err), file=sys.stderr) 105 | 106 | 107 | class Callback: 108 | def __init__(self, st, args, image_type='pil', web_interface=None): 109 | self.st = st 110 | self.args = args 111 | self.image_type = image_type 112 | self.web_interface = web_interface 113 | self.iterates = [] 114 | self.progress = None 115 | 116 | def __call__(self, iterate): 117 | self.iterates.append(asdict(iterate)) 118 | if iterate.i == 1: 119 | self.progress = tqdm(total=iterate.i_max, dynamic_ncols=True) 120 | msg = 'Size: {}x{}, iteration: {}, loss: {:g}' 121 | tqdm.write(msg.format(iterate.w, iterate.h, iterate.i, iterate.loss)) 122 | self.progress.update() 123 | if self.web_interface is not None: 124 | self.web_interface.put_iterate(iterate, self.st.get_image_tensor()) 125 | if iterate.i == iterate.i_max: 126 | self.progress.close() 127 | if max(iterate.w, iterate.h) != 
self.args.end_scale: 128 | save_image(self.args.output, self.st.get_image(self.image_type)) 129 | else: 130 | if self.web_interface is not None: 131 | self.web_interface.put_done() 132 | elif iterate.i % self.args.save_every == 0: 133 | save_image(self.args.output, self.st.get_image(self.image_type)) 134 | 135 | def close(self): 136 | if self.progress is not None: 137 | self.progress.close() 138 | 139 | def get_trace(self): 140 | return {'args': self.args.__dict__, 'iterates': self.iterates} 141 | 142 | 143 | def main(): 144 | setup_exceptions() 145 | fix_start_method() 146 | 147 | p = argparse.ArgumentParser(description=__doc__, 148 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 149 | 150 | def arg_info(arg): 151 | defaults = StyleTransfer.stylize.__kwdefaults__ 152 | default_types = StyleTransfer.stylize.__annotations__ 153 | return {'default': defaults[arg], 'type': default_types[arg]} 154 | 155 | p.add_argument('content', type=str, help='the content image') 156 | p.add_argument('styles', type=str, nargs='+', metavar='style', help='the style images') 157 | p.add_argument('--output', '-o', type=str, default='out.png', 158 | help='the output image') 159 | p.add_argument('--style-weights', '-sw', type=float, nargs='+', default=None, 160 | metavar='STYLE_WEIGHT', help='the relative weights for each style image') 161 | p.add_argument('--devices', type=str, default=[], nargs='+', 162 | help='the device names to use (omit for auto)') 163 | p.add_argument('--random-seed', '-r', type=int, default=0, 164 | help='the random seed') 165 | p.add_argument('--content-weight', '-cw', **arg_info('content_weight'), 166 | help='the content weight') 167 | p.add_argument('--tv-weight', '-tw', **arg_info('tv_weight'), 168 | help='the smoothing weight') 169 | p.add_argument('--optimizer', **arg_info('optimizer'), 170 | choices=['adam', 'lbfgs'], 171 | help='the optimizer to use') 172 | p.add_argument('--min-scale', '-ms', **arg_info('min_scale'), 173 | help='the minimum 
scale (max image dim), in pixels') 174 | p.add_argument('--end-scale', '-s', type=str, default='512', 175 | help='the final scale (max image dim), in pixels') 176 | p.add_argument('--iterations', '-i', **arg_info('iterations'), 177 | help='the number of iterations per scale') 178 | p.add_argument('--initial-iterations', '-ii', **arg_info('initial_iterations'), 179 | help='the number of iterations on the first scale') 180 | p.add_argument('--save-every', type=int, default=50, 181 | help='save the image every SAVE_EVERY iterations') 182 | p.add_argument('--step-size', '-ss', **arg_info('step_size'), 183 | help='the step size (learning rate) for Adam') 184 | p.add_argument('--avg-decay', '-ad', **arg_info('avg_decay'), 185 | help='the EMA decay rate for iterate averaging') 186 | p.add_argument('--init', **arg_info('init'), 187 | choices=['content', 'gray', 'uniform', 'normal', 'style_stats'], 188 | help='the initial image') 189 | p.add_argument('--style-scale-fac', **arg_info('style_scale_fac'), 190 | help='the relative scale of the style to the content') 191 | p.add_argument('--style-size', **arg_info('style_size'), 192 | help='the fixed scale of the style at different content scales') 193 | p.add_argument('--pooling', type=str, default='max', choices=['max', 'average', 'l2'], 194 | help='the model\'s pooling mode') 195 | p.add_argument('--proof', type=str, default=None, 196 | help='the ICC color profile (CMYK) for soft proofing the content and styles') 197 | p.add_argument('--web', default=False, action='store_true', help='enable the web interface') 198 | p.add_argument('--host', type=str, default='0.0.0.0', 199 | help='the host the web interface binds to') 200 | p.add_argument('--port', type=int, default=8080, 201 | help='the port the web interface binds to') 202 | p.add_argument('--browser', type=str, default='', nargs='?', 203 | help='open a web browser (specify the browser if not system default)') 204 | 205 | args = p.parse_args() 206 | 207 | content_img = 
load_image(args.content, args.proof) 208 | style_imgs = [load_image(img, args.proof) for img in args.styles] 209 | 210 | image_type = 'pil' 211 | if Path(args.output).suffix.lower() in {'.tif', '.tiff'}: 212 | image_type = 'np_uint16' 213 | 214 | devices = [torch.device(device) for device in args.devices] 215 | if not devices: 216 | devices = [torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')] 217 | if len(set(device.type for device in devices)) != 1: 218 | print('Devices must all be the same type.') 219 | sys.exit(1) 220 | if not 1 <= len(devices) <= 2: 221 | print('Only 1 or 2 devices are supported.') 222 | sys.exit(1) 223 | print('Using devices:', ' '.join(str(device) for device in devices)) 224 | 225 | if devices[0].type == 'cpu': 226 | print('CPU threads:', torch.get_num_threads()) 227 | if devices[0].type == 'cuda': 228 | for i, device in enumerate(devices): 229 | props = torch.cuda.get_device_properties(device) 230 | print(f'GPU {i} type: {props.name} (compute {props.major}.{props.minor})') 231 | print(f'GPU {i} RAM:', round(props.total_memory / 1024 / 1024), 'MB') 232 | 233 | end_scale = int(args.end_scale.rstrip('+')) 234 | if args.end_scale.endswith('+'): 235 | end_scale = get_safe_scale(*content_img.size, end_scale) 236 | args.end_scale = end_scale 237 | 238 | web_interface = None 239 | if args.web: 240 | web_interface = WebInterface(args.host, args.port) 241 | atexit.register(web_interface.close) 242 | 243 | for device in devices: 244 | torch.tensor(0).to(device) 245 | torch.manual_seed(args.random_seed) 246 | 247 | print('Loading model...') 248 | st = StyleTransfer(devices=devices, pooling=args.pooling) 249 | callback = Callback(st, args, image_type=image_type, web_interface=web_interface) 250 | atexit.register(callback.close) 251 | 252 | url = f'http://{args.host}:{args.port}/' 253 | if args.web: 254 | if args.browser: 255 | webbrowser.get(args.browser).open(url) 256 | elif args.browser is None: 257 | webbrowser.open(url) 258 | 259 | 
defaults = StyleTransfer.stylize.__kwdefaults__ 260 | st_kwargs = {k: v for k, v in args.__dict__.items() if k in defaults} 261 | try: 262 | st.stylize(content_img, style_imgs, **st_kwargs, callback=callback) 263 | except KeyboardInterrupt: 264 | pass 265 | 266 | output_image = st.get_image(image_type) 267 | if output_image is not None: 268 | save_image(args.output, output_image) 269 | with open('trace.json', 'w') as fp: 270 | json.dump(callback.get_trace(), fp, indent=4) 271 | 272 | 273 | if __name__ == '__main__': 274 | main() 275 | -------------------------------------------------------------------------------- /style_transfer/sRGB Profile.icc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/crowsonkb/style-transfer-pytorch/e7e2c7134e3937be05ff9f5fcc0873fe5ceb6060/style_transfer/sRGB Profile.icc -------------------------------------------------------------------------------- /style_transfer/sqrtm.py: -------------------------------------------------------------------------------- 1 | """Matrix square roots with backward passes. 2 | 3 | Cleaned up from https://github.com/msubhransu/matrix-sqrt. 
4 | """ 5 | 6 | import torch 7 | 8 | 9 | def sqrtm_ns(a, num_iters=10): 10 | if a.ndim < 2: 11 | raise RuntimeError('tensor of matrices must have at least 2 dimensions') 12 | if a.shape[-2] != a.shape[-1]: 13 | raise RuntimeError('tensor must be batches of square matrices') 14 | if num_iters < 0: 15 | raise RuntimeError('num_iters must not be negative') 16 | norm_a = a.pow(2).sum(dim=[-2, -1], keepdim=True).sqrt() 17 | y = a / norm_a 18 | eye = torch.eye(a.shape[-1], device=a.device, dtype=a.dtype) * 3 19 | z = torch.eye(a.shape[-1], device=a.device, dtype=a.dtype) 20 | z = z.repeat([*a.shape[:-2], 1, 1]) 21 | for i in range(num_iters): 22 | t = (eye - z @ y) / 2 23 | y = y @ t 24 | z = t @ z 25 | return y * norm_a.sqrt() 26 | 27 | 28 | class _MatrixSquareRootNSLyap(torch.autograd.Function): 29 | @staticmethod 30 | def forward(ctx, a, num_iters, num_iters_backward): 31 | z = sqrtm_ns(a, num_iters) 32 | ctx.save_for_backward(z, torch.tensor(num_iters_backward)) 33 | return z 34 | 35 | @staticmethod 36 | def backward(ctx, grad_output): 37 | z, num_iters = ctx.saved_tensors 38 | norm_z = z.pow(2).sum(dim=[-2, -1], keepdim=True).sqrt() 39 | a = z / norm_z 40 | eye = torch.eye(z.shape[-1], device=z.device, dtype=z.dtype) * 3 41 | q = grad_output / norm_z 42 | for i in range(num_iters): 43 | eye_a_a = eye - a @ a 44 | q = (q @ eye_a_a - a.transpose(-2, -1) @ (a.transpose(-2, -1) @ q - q @ a)) / 2 45 | if i < num_iters - 1: 46 | a = a @ eye_a_a / 2 47 | return q / 2, None, None 48 | 49 | 50 | def sqrtm_ns_lyap(a, num_iters=10, num_iters_backward=None): 51 | if num_iters_backward is None: 52 | num_iters_backward = num_iters 53 | if num_iters_backward < 0: 54 | raise RuntimeError('num_iters_backward must not be negative') 55 | return _MatrixSquareRootNSLyap.apply(a, num_iters, num_iters_backward) 56 | 57 | 58 | class _MatrixSquareRootEig(torch.autograd.Function): 59 | @staticmethod 60 | def forward(ctx, a): 61 | vals, vecs = torch.linalg.eigh(a)
ctx.save_for_backward(vals, vecs) 63 | return vecs @ vals.abs().sqrt().diag_embed() @ vecs.transpose(-2, -1) 64 | 65 | @staticmethod 66 | def backward(ctx, grad_output): 67 | vals, vecs = ctx.saved_tensors 68 | d = vals.abs().sqrt().unsqueeze(-1).repeat_interleave(vals.shape[-1], -1) 69 | vecs_t = vecs.transpose(-2, -1) 70 | return vecs @ (vecs_t @ grad_output @ vecs / (d + d.transpose(-2, -1))) @ vecs_t 71 | 72 | 73 | def sqrtm_eig(a): 74 | if a.ndim < 2: 75 | raise RuntimeError('tensor of matrices must have at least 2 dimensions') 76 | if a.shape[-2] != a.shape[-1]: 77 | raise RuntimeError('tensor must be batches of square matrices') 78 | return _MatrixSquareRootEig.apply(a) 79 | -------------------------------------------------------------------------------- /style_transfer/style_transfer.py: -------------------------------------------------------------------------------- 1 | """Neural style transfer (https://arxiv.org/abs/1508.06576) in PyTorch.""" 2 | 3 | import copy 4 | from dataclasses import dataclass 5 | from functools import partial 6 | import time 7 | import warnings 8 | 9 | import numpy as np 10 | from PIL import Image 11 | import torch 12 | from torch import optim, nn 13 | from torch.nn import functional as F 14 | from torchvision import models, transforms 15 | from torchvision.transforms import functional as TF 16 | 17 | from . import sqrtm 18 | 19 | 20 | class VGGFeatures(nn.Module): 21 | poolings = {'max': nn.MaxPool2d, 'average': nn.AvgPool2d, 'l2': partial(nn.LPPool2d, 2)} 22 | pooling_scales = {'max': 1., 'average': 2., 'l2': 0.78} 23 | 24 | def __init__(self, layers, pooling='max'): 25 | super().__init__() 26 | self.layers = sorted(set(layers)) 27 | 28 | # The PyTorch pre-trained VGG-19 expects sRGB inputs in the range [0, 1] which are then 29 | # normalized according to this transform, unlike Simonyan et al.'s original model. 
30 | self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 31 | std=[0.229, 0.224, 0.225]) 32 | 33 | # The PyTorch pre-trained VGG-19 has different parameters from Simonyan et al.'s original 34 | # model. 35 | self.model = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features[:self.layers[-1] + 1] 36 | self.devices = [torch.device('cpu')] * len(self.model) 37 | 38 | # Reduces edge artifacts. 39 | self.model[0] = self._change_padding_mode(self.model[0], 'replicate') 40 | 41 | pool_scale = self.pooling_scales[pooling] 42 | for i, layer in enumerate(self.model): 43 | if pooling != 'max' and isinstance(layer, nn.MaxPool2d): 44 | # Changing the pooling type from max results in the scale of activations 45 | # changing, so rescale them. Gatys et al. (2015) do not do this. 46 | self.model[i] = Scale(self.poolings[pooling](2), pool_scale) 47 | 48 | self.model.eval() 49 | self.model.requires_grad_(False) 50 | 51 | @staticmethod 52 | def _change_padding_mode(conv, padding_mode): 53 | new_conv = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size, 54 | stride=conv.stride, padding=conv.padding, 55 | padding_mode=padding_mode) 56 | with torch.no_grad(): 57 | new_conv.weight.copy_(conv.weight) 58 | new_conv.bias.copy_(conv.bias) 59 | return new_conv 60 | 61 | @staticmethod 62 | def _get_min_size(layers): 63 | last_layer = max(layers) 64 | min_size = 1 65 | for layer in [4, 9, 18, 27, 36]: 66 | if last_layer < layer: 67 | break 68 | min_size *= 2 69 | return min_size 70 | 71 | def distribute_layers(self, devices): 72 | for i, layer in enumerate(self.model): 73 | if i in devices: 74 | device = torch.device(devices[i]) 75 | self.model[i] = layer.to(device) 76 | self.devices[i] = device 77 | 78 | def forward(self, input, layers=None): 79 | layers = self.layers if layers is None else sorted(set(layers)) 80 | h, w = input.shape[2:4] 81 | min_size = self._get_min_size(layers) 82 | if min(h, w) < min_size: 83 | raise ValueError(f'Input is {h}x{w} but 
must be at least {min_size}x{min_size}') 84 | feats = {'input': input} 85 | input = self.normalize(input) 86 | for i in range(max(layers) + 1): 87 | input = self.model[i](input.to(self.devices[i])) 88 | if i in layers: 89 | feats[i] = input 90 | return feats 91 | 92 | 93 | class ScaledMSELoss(nn.Module): 94 | """Computes MSE scaled such that its gradient L1 norm is approximately 1. 95 | This differs from Gatys et al. (2015) and Johnson et al.""" 96 | 97 | def __init__(self, eps=1e-8): 98 | super().__init__() 99 | self.register_buffer('eps', torch.tensor(eps)) 100 | 101 | def extra_repr(self): 102 | return f'eps={self.eps:g}' 103 | 104 | def forward(self, input, target): 105 | diff = input - target 106 | return diff.pow(2).sum() / diff.abs().sum().add(self.eps) 107 | 108 | 109 | class ContentLoss(nn.Module): 110 | def __init__(self, target, eps=1e-8): 111 | super().__init__() 112 | self.register_buffer('target', target) 113 | self.loss = ScaledMSELoss(eps=eps) 114 | 115 | def forward(self, input): 116 | return self.loss(input, self.target) 117 | 118 | 119 | class ContentLossMSE(nn.Module): 120 | def __init__(self, target): 121 | super().__init__() 122 | self.register_buffer('target', target) 123 | self.loss = nn.MSELoss() 124 | 125 | def forward(self, input): 126 | return self.loss(input, self.target) 127 | 128 | 129 | class StyleLoss(nn.Module): 130 | def __init__(self, target, eps=1e-8): 131 | super().__init__() 132 | self.register_buffer('target', target) 133 | self.loss = ScaledMSELoss(eps=eps) 134 | 135 | @staticmethod 136 | def get_target(target): 137 | mat = target.flatten(-2) 138 | # The Gram matrix normalization differs from Gatys et al. (2015) and Johnson et al.
139 | return mat @ mat.transpose(-2, -1) / mat.shape[-1] 140 | 141 | def forward(self, input): 142 | return self.loss(self.get_target(input), self.target) 143 | 144 | 145 | def eye_like(x): 146 | return torch.eye(x.shape[-2], x.shape[-1], dtype=x.dtype, device=x.device).expand_as(x) 147 | 148 | 149 | class StyleLossW2(nn.Module): 150 | """Wasserstein-2 style loss.""" 151 | 152 | def __init__(self, target, eps=1e-4): 153 | super().__init__() 154 | self.sqrtm = partial(sqrtm.sqrtm_ns_lyap, num_iters=12) 155 | mean, srm = target 156 | cov = self.srm_to_cov(mean, srm) + eye_like(srm) * eps 157 | self.register_buffer('mean', mean) 158 | self.register_buffer('cov', cov) 159 | self.register_buffer('cov_sqrt', self.sqrtm(cov)) 160 | self.register_buffer('eps', mean.new_tensor(eps)) 161 | 162 | @staticmethod 163 | def get_target(target): 164 | """Compute the mean and second raw moment of the target activations. 165 | Unlike the covariance matrix, these are valid to combine linearly.""" 166 | mean = target.mean([-2, -1]) 167 | srm = torch.einsum('...chw,...dhw->...cd', target, target) / (target.shape[-2] * target.shape[-1]) 168 | return mean, srm 169 | 170 | @staticmethod 171 | def srm_to_cov(mean, srm): 172 | """Compute the covariance matrix from the mean and second raw moment.""" 173 | return srm - torch.einsum('...c,...d->...cd', mean, mean) 174 | 175 | def forward(self, input): 176 | mean, srm = self.get_target(input) 177 | cov = self.srm_to_cov(mean, srm) + eye_like(srm) * self.eps 178 | mean_diff = torch.mean((mean - self.mean) ** 2) 179 | sqrt_term = self.sqrtm(self.cov_sqrt @ cov @ self.cov_sqrt) 180 | cov_diff = torch.diagonal(self.cov + cov - 2 * sqrt_term, dim1=-2, dim2=-1).mean() 181 | return mean_diff + cov_diff 182 | 183 | 184 | class TVLoss(nn.Module): 185 | """L2 total variation loss (nine point stencil).""" 186 | 187 | def forward(self, input): 188 | input = F.pad(input, (1, 1, 1, 1), 'replicate') 189 | s1, s2 = slice(1, -1), slice(2, None) 190 | s3, s4 = 
slice(None, -1), slice(1, None) 191 | d1 = (input[..., s1, s2] - input[..., s1, s1]).pow(2).mean() / 3 192 | d2 = (input[..., s2, s1] - input[..., s1, s1]).pow(2).mean() / 3 193 | d3 = (input[..., s4, s4] - input[..., s3, s3]).pow(2).mean() / 12 194 | d4 = (input[..., s4, s3] - input[..., s3, s4]).pow(2).mean() / 12 195 | return 2 * (d1 + d2 + d3 + d4) 196 | 197 | 198 | class SumLoss(nn.ModuleList): 199 | def __init__(self, losses, verbose=False): 200 | super().__init__(losses) 201 | self.verbose = verbose 202 | 203 | def forward(self, *args, **kwargs): 204 | losses = [loss(*args, **kwargs) for loss in self] 205 | if self.verbose: 206 | for i, loss in enumerate(losses): 207 | print(f'({i}): {loss.item():g}') 208 | return sum(loss.to(losses[-1].device) for loss in losses) 209 | 210 | 211 | class Scale(nn.Module): 212 | def __init__(self, module, scale): 213 | super().__init__() 214 | self.module = module 215 | self.register_buffer('scale', torch.tensor(scale)) 216 | 217 | def extra_repr(self): 218 | return f'(scale): {self.scale.item():g}' 219 | 220 | def forward(self, *args, **kwargs): 221 | return self.module(*args, **kwargs) * self.scale 222 | 223 | 224 | class LayerApply(nn.Module): 225 | def __init__(self, module, layer): 226 | super().__init__() 227 | self.module = module 228 | self.layer = layer 229 | 230 | def extra_repr(self): 231 | return f'(layer): {self.layer!r}' 232 | 233 | def forward(self, input): 234 | return self.module(input[self.layer]) 235 | 236 | 237 | class EMA(nn.Module): 238 | """A bias-corrected exponential moving average, as in Kingma et al. 
(Adam).""" 239 | 240 | def __init__(self, input, decay): 241 | super().__init__() 242 | self.register_buffer('value', torch.zeros_like(input)) 243 | self.register_buffer('decay', torch.tensor(decay)) 244 | self.register_buffer('accum', torch.tensor(1.)) 245 | self.update(input) 246 | 247 | def get(self): 248 | return self.value / (1 - self.accum) 249 | 250 | def update(self, input): 251 | self.accum *= self.decay 252 | self.value *= self.decay 253 | self.value += (1 - self.decay) * input 254 | 255 | 256 | def size_to_fit(size, max_dim, scale_up=False): 257 | w, h = size 258 | if not scale_up and max(h, w) <= max_dim: 259 | return w, h 260 | new_w, new_h = max_dim, max_dim 261 | if h > w: 262 | new_w = round(max_dim * w / h) 263 | else: 264 | new_h = round(max_dim * h / w) 265 | return new_w, new_h 266 | 267 | 268 | def gen_scales(start, end): 269 | scale = end 270 | i = 0 271 | scales = set() 272 | while scale >= start: 273 | scales.add(scale) 274 | i += 1 275 | scale = round(end / pow(2, i/2)) 276 | return sorted(scales) 277 | 278 | 279 | def interpolate(*args, **kwargs): 280 | with warnings.catch_warnings(): 281 | warnings.simplefilter('ignore', UserWarning) 282 | return F.interpolate(*args, **kwargs) 283 | 284 | 285 | def scale_adam(state, shape): 286 | """Prepares a state dict to warm-start the Adam optimizer at a new scale.""" 287 | state = copy.deepcopy(state) 288 | for group in state['state'].values(): 289 | exp_avg, exp_avg_sq = group['exp_avg'], group['exp_avg_sq'] 290 | group['exp_avg'] = interpolate(exp_avg, shape, mode='bicubic') 291 | group['exp_avg_sq'] = interpolate(exp_avg_sq, shape, mode='bilinear').relu_() 292 | if 'max_exp_avg_sq' in group: 293 | max_exp_avg_sq = group['max_exp_avg_sq'] 294 | group['max_exp_avg_sq'] = interpolate(max_exp_avg_sq, shape, mode='bilinear').relu_() 295 | return state 296 | 297 | 298 | @dataclass 299 | class STIterate: 300 | w: int 301 | h: int 302 | i: int 303 | i_max: int 304 | loss: float 305 | time: float 306 | 
gpu_ram: int 307 | 308 | 309 | class StyleTransfer: 310 | def __init__(self, devices=['cpu'], pooling='max'): 311 | self.devices = [torch.device(device) for device in devices] 312 | self.image = None 313 | self.average = None 314 | 315 | # The default content and style layers follow Gatys et al. (2015). 316 | self.content_layers = [22] 317 | self.style_layers = [1, 6, 11, 20, 29] 318 | 319 | # The weighting of the style layers differs from Gatys et al. (2015) and Johnson et al. 320 | style_weights = [256, 64, 16, 4, 1] 321 | weight_sum = sum(abs(w) for w in style_weights) 322 | self.style_weights = [w / weight_sum for w in style_weights] 323 | 324 | self.model = VGGFeatures(self.style_layers + self.content_layers, pooling=pooling) 325 | 326 | if len(self.devices) == 1: 327 | device_plan = {0: self.devices[0]} 328 | elif len(self.devices) == 2: 329 | device_plan = {0: self.devices[0], 5: self.devices[1]} 330 | else: 331 | raise ValueError('Only 1 or 2 devices are supported.') 332 | 333 | self.model.distribute_layers(device_plan) 334 | 335 | def get_image_tensor(self): 336 | return self.average.get().detach()[0].clamp(0, 1) 337 | 338 | def get_image(self, image_type='pil'): 339 | if self.average is not None: 340 | image = self.get_image_tensor() 341 | if image_type.lower() == 'pil': 342 | return TF.to_pil_image(image) 343 | elif image_type.lower() == 'np_uint16': 344 | arr = image.cpu().movedim(0, 2).numpy() 345 | return np.uint16(np.round(arr * 65535)) 346 | else: 347 | raise ValueError("image_type must be 'pil' or 'np_uint16'") 348 | 349 | def stylize(self, content_image, style_images, *, 350 | style_weights=None, 351 | content_weight: float = 0.015, 352 | tv_weight: float = 2., 353 | optimizer: str = 'adam', 354 | min_scale: int = 128, 355 | end_scale: int = 512, 356 | iterations: int = 500, 357 | initial_iterations: int = 1000, 358 | step_size: float = 0.02, 359 | avg_decay: float = 0.99, 360 | init: str = 'content', 361 | style_scale_fac: float = 1., 362 | 
style_size: int = None, 363 | callback=None): 364 | 365 | min_scale = min(min_scale, end_scale) 366 | content_weights = [content_weight / len(self.content_layers)] * len(self.content_layers) 367 | 368 | if style_weights is None: 369 | style_weights = [1 / len(style_images)] * len(style_images) 370 | else: 371 | weight_sum = sum(abs(w) for w in style_weights) 372 | style_weights = [weight / weight_sum for weight in style_weights] 373 | if len(style_images) != len(style_weights): 374 | raise ValueError('style_images and style_weights must have the same length') 375 | 376 | tv_loss = Scale(LayerApply(TVLoss(), 'input'), tv_weight) 377 | 378 | scales = gen_scales(min_scale, end_scale) 379 | 380 | cw, ch = size_to_fit(content_image.size, scales[0], scale_up=True) 381 | if init == 'content': 382 | self.image = TF.to_tensor(content_image.resize((cw, ch), Image.BICUBIC))[None] 383 | elif init == 'gray': 384 | self.image = torch.rand([1, 3, ch, cw]) / 255 + 0.5 385 | elif init == 'uniform': 386 | self.image = torch.rand([1, 3, ch, cw]) 387 | elif init == 'normal': 388 | self.image = torch.empty([1, 3, ch, cw]) 389 | nn.init.trunc_normal_(self.image, mean=0.5, std=0.25, a=0, b=1) 390 | elif init == 'style_stats': 391 | means, variances = [], [] 392 | for i, image in enumerate(style_images): 393 | my_image = TF.to_tensor(image) 394 | means.append(my_image.mean(dim=(1, 2)) * style_weights[i]) 395 | variances.append(my_image.var(dim=(1, 2)) * style_weights[i]) 396 | means = sum(means) 397 | variances = sum(variances) 398 | channels = [] 399 | for mean, variance in zip(means, variances): 400 | channel = torch.empty([1, 1, ch, cw]) 401 | nn.init.trunc_normal_(channel, mean=mean, std=variance.sqrt(), a=0, b=1) 402 | channels.append(channel) 403 | self.image = torch.cat(channels, dim=1) 404 | else: 405 | raise ValueError("init must be one of 'content', 'gray', 'uniform', 'normal', 'style_stats'") 406 | self.image = self.image.to(self.devices[0]) 407 | 408 | opt = None 409 | 410 | # Stylize
the image at successively finer scales, each greater by a factor of sqrt(2). 411 | # This differs from the scheme given in Gatys et al. (2016). 412 | for scale in scales: 413 | if self.devices[0].type == 'cuda': 414 | torch.cuda.empty_cache() 415 | 416 | cw, ch = size_to_fit(content_image.size, scale, scale_up=True) 417 | content = TF.to_tensor(content_image.resize((cw, ch), Image.BICUBIC))[None] 418 | content = content.to(self.devices[0]) 419 | 420 | self.image = interpolate(self.image.detach(), (ch, cw), mode='bicubic').clamp(0, 1) 421 | self.average = EMA(self.image, avg_decay) 422 | self.image.requires_grad_() 423 | 424 | print(f'Processing content image ({cw}x{ch})...') 425 | content_feats = self.model(content, layers=self.content_layers) 426 | content_losses = [] 427 | for layer, weight in zip(self.content_layers, content_weights): 428 | target = content_feats[layer] 429 | content_losses.append(Scale(LayerApply(ContentLossMSE(target), layer), weight)) 430 | 431 | style_targets, style_losses = {}, [] 432 | for i, image in enumerate(style_images): 433 | if style_size is None: 434 | sw, sh = size_to_fit(image.size, round(scale * style_scale_fac)) 435 | else: 436 | sw, sh = size_to_fit(image.size, style_size) 437 | style = TF.to_tensor(image.resize((sw, sh), Image.BICUBIC))[None] 438 | style = style.to(self.devices[0]) 439 | print(f'Processing style image ({sw}x{sh})...') 440 | style_feats = self.model(style, layers=self.style_layers) 441 | # Take the weighted average of the style targets (means and second raw moments) of multiple style images.
442 | for layer in self.style_layers: 443 | target_mean, target_srm = StyleLossW2.get_target(style_feats[layer]) 444 | target_mean *= style_weights[i] 445 | target_srm *= style_weights[i] 446 | if layer not in style_targets: 447 | style_targets[layer] = target_mean, target_srm 448 | else: 449 | style_targets[layer][0].add_(target_mean) 450 | style_targets[layer][1].add_(target_srm) 451 | for layer, weight in zip(self.style_layers, self.style_weights): 452 | target = style_targets[layer] 453 | style_losses.append(Scale(LayerApply(StyleLossW2(target), layer), weight)) 454 | 455 | crit = SumLoss([*content_losses, *style_losses, tv_loss]) 456 | 457 | if optimizer == 'adam': 458 | opt2 = optim.Adam([self.image], lr=step_size, betas=(0.9, 0.99)) 459 | # Warm-start the Adam optimizer if this is not the first scale. 460 | if scale != scales[0]: 461 | opt_state = scale_adam(opt.state_dict(), (ch, cw)) 462 | opt2.load_state_dict(opt_state) 463 | opt = opt2 464 | elif optimizer == 'lbfgs': 465 | opt = optim.LBFGS([self.image], max_iter=1, history_size=10) 466 | else: 467 | raise ValueError("optimizer must be one of 'adam', 'lbfgs'") 468 | 469 | if self.devices[0].type == 'cuda': 470 | torch.cuda.empty_cache() 471 | 472 | def closure(): 473 | feats = self.model(self.image) 474 | loss = crit(feats) 475 | loss.backward() 476 | return loss 477 | 478 | actual_its = initial_iterations if scale == scales[0] else iterations 479 | for i in range(1, actual_its + 1): 480 | opt.zero_grad() 481 | loss = opt.step(closure) 482 | # Enforce box constraints, but not for L-BFGS: clamping outside the optimizer would make its stored curvature pairs inconsistent with the actual parameter updates.
483 | if optimizer != 'lbfgs': 484 | with torch.no_grad(): 485 | self.image.clamp_(0, 1) 486 | self.average.update(self.image) 487 | if callback is not None: 488 | gpu_ram = 0 489 | for device in self.devices: 490 | if device.type == 'cuda': 491 | gpu_ram = max(gpu_ram, torch.cuda.max_memory_allocated(device)) 492 | callback(STIterate(w=cw, h=ch, i=i, i_max=actual_its, loss=loss.item(), 493 | time=time.time(), gpu_ram=gpu_ram)) 494 | 495 | # Initialize each new scale with the previous scale's averaged iterate. 496 | with torch.no_grad(): 497 | self.image.copy_(self.average.get()) 498 | 499 | return self.get_image() 500 | -------------------------------------------------------------------------------- /style_transfer/web_interface.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from dataclasses import dataclass, is_dataclass 3 | import io 4 | import json 5 | from pathlib import Path 6 | 7 | from aiohttp import web 8 | import torch 9 | import torch.multiprocessing as mp 10 | from torchvision.transforms import functional as TF 11 | 12 | from . 
import srgb_profile, STIterate 13 | 14 | 15 | @dataclass 16 | class WIIterate: 17 | iterate: STIterate 18 | image: torch.Tensor 19 | 20 | 21 | @dataclass 22 | class WIDone: 23 | pass 24 | 25 | 26 | @dataclass 27 | class WIStop: 28 | pass 29 | 30 | 31 | class DCJSONEncoder(json.JSONEncoder): 32 | def default(self, obj): 33 | if is_dataclass(obj): 34 | dct = dict(obj.__dict__) 35 | dct['_type'] = type(obj).__name__ 36 | return dct 37 | return super().default(obj) 38 | 39 | 40 | class WebInterface: 41 | def __init__(self, host, port): 42 | self.host = host 43 | self.port = port 44 | self.q = mp.Queue() 45 | self.encoder = DCJSONEncoder() 46 | self.image = None 47 | self.loop = None 48 | self.runner = None 49 | self.wss = [] 50 | 51 | self.app = web.Application() 52 | self.static_path = Path(__file__).resolve().parent / 'web_static' 53 | self.app.router.add_routes([web.get('/', self.handle_index), 54 | web.get('/image', self.handle_image), 55 | web.get('/websocket', self.handle_websocket), 56 | web.static('/', self.static_path)]) 57 | 58 | print(f'Starting web interface at http://{self.host}:{self.port}/') 59 | self.process = mp.Process(target=self.run) 60 | self.process.start() 61 | 62 | async def run_app(self): 63 | self.runner = web.AppRunner(self.app) 64 | await self.runner.setup() 65 | site = web.TCPSite(self.runner, self.host, self.port, shutdown_timeout=5) 66 | await site.start() 67 | while True: 68 | await asyncio.sleep(3600) 69 | 70 | async def process_events(self): 71 | while True: 72 | f = self.loop.run_in_executor(None, self.q.get) 73 | await f 74 | event = f.result() 75 | if isinstance(event, WIIterate): 76 | self.image = event.image 77 | await self.send_websocket_message(event.iterate) 78 | elif isinstance(event, WIDone): 79 | await self.send_websocket_message(event) 80 | if self.wss: 81 | print('Waiting for web clients to finish...') 82 | await asyncio.sleep(5) 83 | elif isinstance(event, WIStop): 84 | for ws in list(self.wss): 85 | await ws.close() 86 | if
self.runner is not None: 87 | await self.runner.cleanup() 88 | self.loop.stop() 89 | return 90 | 91 | def compress_image(self): 92 | buf = io.BytesIO() 93 | TF.to_pil_image(self.image).save(buf, format='jpeg', icc_profile=srgb_profile, 94 | quality=95, subsampling=0) 95 | return buf.getvalue() 96 | 97 | async def handle_image(self, request): 98 | if self.image is None: 99 | raise web.HTTPNotFound() 100 | f = self.loop.run_in_executor(None, self.compress_image) 101 | await f 102 | return web.Response(body=f.result(), content_type='image/jpeg') 103 | 104 | async def handle_index(self, request): 105 | body = (self.static_path / 'index.html').read_bytes() 106 | return web.Response(body=body, content_type='text/html') 107 | 108 | async def handle_websocket(self, request): 109 | ws = web.WebSocketResponse() 110 | await ws.prepare(request) 111 | self.wss.append(ws) 112 | async for _ in ws: 113 | pass 114 | try: 115 | self.wss.remove(ws) 116 | except ValueError: 117 | pass 118 | return ws 119 | 120 | async def send_websocket_message(self, msg): 121 | for ws in list(self.wss): 122 | try: 123 | await ws.send_json(msg, dumps=self.encoder.encode) 124 | except ConnectionError: 125 | try: 126 | self.wss.remove(ws) 127 | except ValueError: 128 | pass 129 | 130 | def put_iterate(self, iterate, image): 131 | self.q.put_nowait(WIIterate(iterate, image.cpu())) 132 | 133 | def put_done(self): 134 | self.q.put(WIDone()) 135 | 136 | def close(self): 137 | self.q.put(WIStop()) 138 | self.process.join(12) 139 | 140 | def run(self): 141 | self.loop = asyncio.get_event_loop() 142 | asyncio.ensure_future(self.run_app()) 143 | asyncio.ensure_future(self.process_events()) 144 | try: 145 | self.loop.run_forever() 146 | except KeyboardInterrupt: 147 | self.q.put(WIStop()) 148 | self.loop.run_forever() 149 | -------------------------------------------------------------------------------- /style_transfer/web_static/index.html:
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Style transfer 6 | 7 | 8 | 9 | 10 | 11 | 12 |

Style transfer

13 | 14 |

15 | Size: 0x0, 16 | iteration: 0/0, 17 | loss: none
18 | 19 | Iterations per second: 0
20 | 23 | Connecting to the backend... 24 |
25 | 26 | 27 | -------------------------------------------------------------------------------- /style_transfer/web_static/jquery-3.5.1.min.js: -------------------------------------------------------------------------------- 1 | /*! jQuery v3.5.1 | (c) JS Foundation and other contributors | jquery.org/license */ 2 | !function(e,t){"use strict";"object"==typeof module&&"object"==typeof module.exports?module.exports=e.document?t(e,!0):function(e){if(!e.document)throw new Error("jQuery requires a window with a document");return t(e)}:t(e)}("undefined"!=typeof window?window:this,function(C,e){"use strict";var t=[],r=Object.getPrototypeOf,s=t.slice,g=t.flat?function(e){return t.flat.call(e)}:function(e){return t.concat.apply([],e)},u=t.push,i=t.indexOf,n={},o=n.toString,v=n.hasOwnProperty,a=v.toString,l=a.call(Object),y={},m=function(e){return"function"==typeof e&&"number"!=typeof e.nodeType},x=function(e){return null!=e&&e===e.window},E=C.document,c={type:!0,src:!0,nonce:!0,noModule:!0};function b(e,t,n){var r,i,o=(n=n||E).createElement("script");if(o.text=e,t)for(r in c)(i=t[r]||t.getAttribute&&t.getAttribute(r))&&o.setAttribute(r,i);n.head.appendChild(o).parentNode.removeChild(o)}function w(e){return null==e?e+"":"object"==typeof e||"function"==typeof e?n[o.call(e)]||"object":typeof e}var f="3.5.1",S=function(e,t){return new S.fn.init(e,t)};function p(e){var t=!!e&&"length"in e&&e.length,n=w(e);return!m(e)&&!x(e)&&("array"===n||0===t||"number"==typeof t&&0+~]|"+M+")"+M+"*"),U=new RegExp(M+"|>"),X=new RegExp(F),V=new RegExp("^"+I+"$"),G={ID:new RegExp("^#("+I+")"),CLASS:new RegExp("^\\.("+I+")"),TAG:new RegExp("^("+I+"|[*])"),ATTR:new RegExp("^"+W),PSEUDO:new RegExp("^"+F),CHILD:new RegExp("^:(only|first|last|nth|nth-last)-(child|of-type)(?:\\("+M+"*(even|odd|(([+-]|)(\\d*)n|)"+M+"*(?:([+-]|)"+M+"*(\\d+)|))"+M+"*\\)|)","i"),bool:new RegExp("^(?:"+R+")$","i"),needsContext:new 
RegExp("^"+M+"*[>+~]|:(even|odd|eq|gt|lt|nth|first|last)(?:\\("+M+"*((?:-\\d)?\\d*)"+M+"*\\)|)(?=[^-]|$)","i")},Y=/HTML$/i,Q=/^(?:input|select|textarea|button)$/i,J=/^h\d$/i,K=/^[^{]+\{\s*\[native \w/,Z=/^(?:#([\w-]+)|(\w+)|\.([\w-]+))$/,ee=/[+~]/,te=new RegExp("\\\\[\\da-fA-F]{1,6}"+M+"?|\\\\([^\\r\\n\\f])","g"),ne=function(e,t){var n="0x"+e.slice(1)-65536;return t||(n<0?String.fromCharCode(n+65536):String.fromCharCode(n>>10|55296,1023&n|56320))},re=/([\0-\x1f\x7f]|^-?\d)|^-$|[^\0-\x1f\x7f-\uFFFF\w-]/g,ie=function(e,t){return t?"\0"===e?"\ufffd":e.slice(0,-1)+"\\"+e.charCodeAt(e.length-1).toString(16)+" ":"\\"+e},oe=function(){T()},ae=be(function(e){return!0===e.disabled&&"fieldset"===e.nodeName.toLowerCase()},{dir:"parentNode",next:"legend"});try{H.apply(t=O.call(p.childNodes),p.childNodes),t[p.childNodes.length].nodeType}catch(e){H={apply:t.length?function(e,t){L.apply(e,O.call(t))}:function(e,t){var n=e.length,r=0;while(e[n++]=t[r++]);e.length=n-1}}}function se(t,e,n,r){var i,o,a,s,u,l,c,f=e&&e.ownerDocument,p=e?e.nodeType:9;if(n=n||[],"string"!=typeof t||!t||1!==p&&9!==p&&11!==p)return n;if(!r&&(T(e),e=e||C,E)){if(11!==p&&(u=Z.exec(t)))if(i=u[1]){if(9===p){if(!(a=e.getElementById(i)))return n;if(a.id===i)return n.push(a),n}else if(f&&(a=f.getElementById(i))&&y(e,a)&&a.id===i)return n.push(a),n}else{if(u[2])return H.apply(n,e.getElementsByTagName(t)),n;if((i=u[3])&&d.getElementsByClassName&&e.getElementsByClassName)return H.apply(n,e.getElementsByClassName(i)),n}if(d.qsa&&!N[t+" "]&&(!v||!v.test(t))&&(1!==p||"object"!==e.nodeName.toLowerCase())){if(c=t,f=e,1===p&&(U.test(t)||z.test(t))){(f=ee.test(t)&&ye(e.parentNode)||e)===e&&d.scope||((s=e.getAttribute("id"))?s=s.replace(re,ie):e.setAttribute("id",s=S)),o=(l=h(t)).length;while(o--)l[o]=(s?"#"+s:":scope")+" "+xe(l[o]);c=l.join(",")}try{return H.apply(n,f.querySelectorAll(c)),n}catch(e){N(t,!0)}finally{s===S&&e.removeAttribute("id")}}}return g(t.replace($,"$1"),e,n,r)}function ue(){var r=[];return function 
e(t,n){return r.push(t+" ")>b.cacheLength&&delete e[r.shift()],e[t+" "]=n}}function le(e){return e[S]=!0,e}function ce(e){var t=C.createElement("fieldset");try{return!!e(t)}catch(e){return!1}finally{t.parentNode&&t.parentNode.removeChild(t),t=null}}function fe(e,t){var n=e.split("|"),r=n.length;while(r--)b.attrHandle[n[r]]=t}function pe(e,t){var n=t&&e,r=n&&1===e.nodeType&&1===t.nodeType&&e.sourceIndex-t.sourceIndex;if(r)return r;if(n)while(n=n.nextSibling)if(n===t)return-1;return e?1:-1}function de(t){return function(e){return"input"===e.nodeName.toLowerCase()&&e.type===t}}function he(n){return function(e){var t=e.nodeName.toLowerCase();return("input"===t||"button"===t)&&e.type===n}}function ge(t){return function(e){return"form"in e?e.parentNode&&!1===e.disabled?"label"in e?"label"in e.parentNode?e.parentNode.disabled===t:e.disabled===t:e.isDisabled===t||e.isDisabled!==!t&&ae(e)===t:e.disabled===t:"label"in e&&e.disabled===t}}function ve(a){return le(function(o){return o=+o,le(function(e,t){var n,r=a([],e.length,o),i=r.length;while(i--)e[n=r[i]]&&(e[n]=!(t[n]=e[n]))})})}function ye(e){return e&&"undefined"!=typeof e.getElementsByTagName&&e}for(e in d=se.support={},i=se.isXML=function(e){var t=e.namespaceURI,n=(e.ownerDocument||e).documentElement;return!Y.test(t||n&&n.nodeName||"HTML")},T=se.setDocument=function(e){var t,n,r=e?e.ownerDocument||e:p;return r!=C&&9===r.nodeType&&r.documentElement&&(a=(C=r).documentElement,E=!i(C),p!=C&&(n=C.defaultView)&&n.top!==n&&(n.addEventListener?n.addEventListener("unload",oe,!1):n.attachEvent&&n.attachEvent("onunload",oe)),d.scope=ce(function(e){return a.appendChild(e).appendChild(C.createElement("div")),"undefined"!=typeof e.querySelectorAll&&!e.querySelectorAll(":scope fieldset div").length}),d.attributes=ce(function(e){return e.className="i",!e.getAttribute("className")}),d.getElementsByTagName=ce(function(e){return 
e.appendChild(C.createComment("")),!e.getElementsByTagName("*").length}),d.getElementsByClassName=K.test(C.getElementsByClassName),d.getById=ce(function(e){return a.appendChild(e).id=S,!C.getElementsByName||!C.getElementsByName(S).length}),d.getById?(b.filter.ID=function(e){var t=e.replace(te,ne);return function(e){return e.getAttribute("id")===t}},b.find.ID=function(e,t){if("undefined"!=typeof t.getElementById&&E){var n=t.getElementById(e);return n?[n]:[]}}):(b.filter.ID=function(e){var n=e.replace(te,ne);return function(e){var t="undefined"!=typeof e.getAttributeNode&&e.getAttributeNode("id");return t&&t.value===n}},b.find.ID=function(e,t){if("undefined"!=typeof t.getElementById&&E){var n,r,i,o=t.getElementById(e);if(o){if((n=o.getAttributeNode("id"))&&n.value===e)return[o];i=t.getElementsByName(e),r=0;while(o=i[r++])if((n=o.getAttributeNode("id"))&&n.value===e)return[o]}return[]}}),b.find.TAG=d.getElementsByTagName?function(e,t){return"undefined"!=typeof t.getElementsByTagName?t.getElementsByTagName(e):d.qsa?t.querySelectorAll(e):void 0}:function(e,t){var n,r=[],i=0,o=t.getElementsByTagName(e);if("*"===e){while(n=o[i++])1===n.nodeType&&r.push(n);return r}return o},b.find.CLASS=d.getElementsByClassName&&function(e,t){if("undefined"!=typeof t.getElementsByClassName&&E)return t.getElementsByClassName(e)},s=[],v=[],(d.qsa=K.test(C.querySelectorAll))&&(ce(function(e){var 
t;a.appendChild(e).innerHTML="",e.querySelectorAll("[msallowcapture^='']").length&&v.push("[*^$]="+M+"*(?:''|\"\")"),e.querySelectorAll("[selected]").length||v.push("\\["+M+"*(?:value|"+R+")"),e.querySelectorAll("[id~="+S+"-]").length||v.push("~="),(t=C.createElement("input")).setAttribute("name",""),e.appendChild(t),e.querySelectorAll("[name='']").length||v.push("\\["+M+"*name"+M+"*="+M+"*(?:''|\"\")"),e.querySelectorAll(":checked").length||v.push(":checked"),e.querySelectorAll("a#"+S+"+*").length||v.push(".#.+[+~]"),e.querySelectorAll("\\\f"),v.push("[\\r\\n\\f]")}),ce(function(e){e.innerHTML="";var t=C.createElement("input");t.setAttribute("type","hidden"),e.appendChild(t).setAttribute("name","D"),e.querySelectorAll("[name=d]").length&&v.push("name"+M+"*[*^$|!~]?="),2!==e.querySelectorAll(":enabled").length&&v.push(":enabled",":disabled"),a.appendChild(e).disabled=!0,2!==e.querySelectorAll(":disabled").length&&v.push(":enabled",":disabled"),e.querySelectorAll("*,:x"),v.push(",.*:")})),(d.matchesSelector=K.test(c=a.matches||a.webkitMatchesSelector||a.mozMatchesSelector||a.oMatchesSelector||a.msMatchesSelector))&&ce(function(e){d.disconnectedMatch=c.call(e,"*"),c.call(e,"[s!='']:x"),s.push("!=",F)}),v=v.length&&new RegExp(v.join("|")),s=s.length&&new RegExp(s.join("|")),t=K.test(a.compareDocumentPosition),y=t||K.test(a.contains)?function(e,t){var n=9===e.nodeType?e.documentElement:e,r=t&&t.parentNode;return e===r||!(!r||1!==r.nodeType||!(n.contains?n.contains(r):e.compareDocumentPosition&&16&e.compareDocumentPosition(r)))}:function(e,t){if(t)while(t=t.parentNode)if(t===e)return!0;return!1},D=t?function(e,t){if(e===t)return l=!0,0;var n=!e.compareDocumentPosition-!t.compareDocumentPosition;return n||(1&(n=(e.ownerDocument||e)==(t.ownerDocument||t)?e.compareDocumentPosition(t):1)||!d.sortDetached&&t.compareDocumentPosition(e)===n?e==C||e.ownerDocument==p&&y(p,e)?-1:t==C||t.ownerDocument==p&&y(p,t)?1:u?P(u,e)-P(u,t):0:4&n?-1:1)}:function(e,t){if(e===t)return 
l=!0,0;var n,r=0,i=e.parentNode,o=t.parentNode,a=[e],s=[t];if(!i||!o)return e==C?-1:t==C?1:i?-1:o?1:u?P(u,e)-P(u,t):0;if(i===o)return pe(e,t);n=e;while(n=n.parentNode)a.unshift(n);n=t;while(n=n.parentNode)s.unshift(n);while(a[r]===s[r])r++;return r?pe(a[r],s[r]):a[r]==p?-1:s[r]==p?1:0}),C},se.matches=function(e,t){return se(e,null,null,t)},se.matchesSelector=function(e,t){if(T(e),d.matchesSelector&&E&&!N[t+" "]&&(!s||!s.test(t))&&(!v||!v.test(t)))try{var n=c.call(e,t);if(n||d.disconnectedMatch||e.document&&11!==e.document.nodeType)return n}catch(e){N(t,!0)}return 0":{dir:"parentNode",first:!0}," ":{dir:"parentNode"},"+":{dir:"previousSibling",first:!0},"~":{dir:"previousSibling"}},preFilter:{ATTR:function(e){return e[1]=e[1].replace(te,ne),e[3]=(e[3]||e[4]||e[5]||"").replace(te,ne),"~="===e[2]&&(e[3]=" "+e[3]+" "),e.slice(0,4)},CHILD:function(e){return e[1]=e[1].toLowerCase(),"nth"===e[1].slice(0,3)?(e[3]||se.error(e[0]),e[4]=+(e[4]?e[5]+(e[6]||1):2*("even"===e[3]||"odd"===e[3])),e[5]=+(e[7]+e[8]||"odd"===e[3])):e[3]&&se.error(e[0]),e},PSEUDO:function(e){var t,n=!e[6]&&e[2];return G.CHILD.test(e[0])?null:(e[3]?e[2]=e[4]||e[5]||"":n&&X.test(n)&&(t=h(n,!0))&&(t=n.indexOf(")",n.length-t)-n.length)&&(e[0]=e[0].slice(0,t),e[2]=n.slice(0,t)),e.slice(0,3))}},filter:{TAG:function(e){var t=e.replace(te,ne).toLowerCase();return"*"===e?function(){return!0}:function(e){return e.nodeName&&e.nodeName.toLowerCase()===t}},CLASS:function(e){var t=m[e+" "];return t||(t=new RegExp("(^|"+M+")"+e+"("+M+"|$)"))&&m(e,function(e){return t.test("string"==typeof e.className&&e.className||"undefined"!=typeof e.getAttribute&&e.getAttribute("class")||"")})},ATTR:function(n,r,i){return function(e){var t=se.attr(e,n);return null==t?"!="===r:!r||(t+="","="===r?t===i:"!="===r?t!==i:"^="===r?i&&0===t.indexOf(i):"*="===r?i&&-1:\x20\t\r\n\f]*)[\x20\t\r\n\f]*\/?>(?:<\/\1>|)$/i;function D(e,n,r){return m(n)?S.grep(e,function(e,t){return!!n.call(e,t,e)!==r}):n.nodeType?S.grep(e,function(e){return 
e===n!==r}):"string"!=typeof n?S.grep(e,function(e){return-1)[^>]*|#([\w-]+))$/;(S.fn.init=function(e,t,n){var r,i;if(!e)return this;if(n=n||j,"string"==typeof e){if(!(r="<"===e[0]&&">"===e[e.length-1]&&3<=e.length?[null,e,null]:q.exec(e))||!r[1]&&t)return!t||t.jquery?(t||n).find(e):this.constructor(t).find(e);if(r[1]){if(t=t instanceof S?t[0]:t,S.merge(this,S.parseHTML(r[1],t&&t.nodeType?t.ownerDocument||t:E,!0)),N.test(r[1])&&S.isPlainObject(t))for(r in t)m(this[r])?this[r](t[r]):this.attr(r,t[r]);return this}return(i=E.getElementById(r[2]))&&(this[0]=i,this.length=1),this}return e.nodeType?(this[0]=e,this.length=1,this):m(e)?void 0!==n.ready?n.ready(e):e(S):S.makeArray(e,this)}).prototype=S.fn,j=S(E);var L=/^(?:parents|prev(?:Until|All))/,H={children:!0,contents:!0,next:!0,prev:!0};function O(e,t){while((e=e[t])&&1!==e.nodeType);return e}S.fn.extend({has:function(e){var t=S(e,this),n=t.length;return this.filter(function(){for(var e=0;e\x20\t\r\n\f]*)/i,he=/^$|^module$|\/(?:java|ecma)script/i;ce=E.createDocumentFragment().appendChild(E.createElement("div")),(fe=E.createElement("input")).setAttribute("type","radio"),fe.setAttribute("checked","checked"),fe.setAttribute("name","t"),ce.appendChild(fe),y.checkClone=ce.cloneNode(!0).cloneNode(!0).lastChild.checked,ce.innerHTML="",y.noCloneChecked=!!ce.cloneNode(!0).lastChild.defaultValue,ce.innerHTML="",y.option=!!ce.lastChild;var ge={thead:[1,"","
"],col:[2,"","
"],tr:[2,"","
"],td:[3,"","
"],_default:[0,"",""]};function ve(e,t){var n;return n="undefined"!=typeof e.getElementsByTagName?e.getElementsByTagName(t||"*"):"undefined"!=typeof e.querySelectorAll?e.querySelectorAll(t||"*"):[],void 0===t||t&&A(e,t)?S.merge([e],n):n}function ye(e,t){for(var n=0,r=e.length;n",""]);var me=/<|&#?\w+;/;function xe(e,t,n,r,i){for(var o,a,s,u,l,c,f=t.createDocumentFragment(),p=[],d=0,h=e.length;d\s*$/g;function qe(e,t){return A(e,"table")&&A(11!==t.nodeType?t:t.firstChild,"tr")&&S(e).children("tbody")[0]||e}function Le(e){return e.type=(null!==e.getAttribute("type"))+"/"+e.type,e}function He(e){return"true/"===(e.type||"").slice(0,5)?e.type=e.type.slice(5):e.removeAttribute("type"),e}function Oe(e,t){var n,r,i,o,a,s;if(1===t.nodeType){if(Y.hasData(e)&&(s=Y.get(e).events))for(i in Y.remove(t,"handle events"),s)for(n=0,r=s[i].length;n").attr(n.scriptAttrs||{}).prop({charset:n.scriptCharset,src:n.url}).on("load error",i=function(e){r.remove(),i=null,e&&t("error"===e.type?404:200,e.type)}),E.head.appendChild(r[0])},abort:function(){i&&i()}}});var Ut,Xt=[],Vt=/(=)\?(?=&|$)|\?\?/;S.ajaxSetup({jsonp:"callback",jsonpCallback:function(){var e=Xt.pop()||S.expando+"_"+Ct.guid++;return this[e]=!0,e}}),S.ajaxPrefilter("json jsonp",function(e,t,n){var r,i,o,a=!1!==e.jsonp&&(Vt.test(e.url)?"url":"string"==typeof e.data&&0===(e.contentType||"").indexOf("application/x-www-form-urlencoded")&&Vt.test(e.data)&&"data");if(a||"jsonp"===e.dataTypes[0])return r=e.jsonpCallback=m(e.jsonpCallback)?e.jsonpCallback():e.jsonpCallback,a?e[a]=e[a].replace(Vt,"$1"+r):!1!==e.jsonp&&(e.url+=(Et.test(e.url)?"&":"?")+e.jsonp+"="+r),e.converters["script json"]=function(){return o||S.error(r+" was not called"),o[0]},e.dataTypes[0]="json",i=C[r],C[r]=function(){o=arguments},n.always(function(){void 0===i?S(C).removeProp(r):C[r]=i,e[r]&&(e.jsonpCallback=t.jsonpCallback,Xt.push(r)),o&&m(i)&&i(o[0]),o=i=void 0}),"script"}),y.createHTMLDocument=((Ut=E.implementation.createHTMLDocument("").body).innerHTML="

",2===Ut.childNodes.length),S.parseHTML=function(e,t,n){return"string"!=typeof e?[]:("boolean"==typeof t&&(n=t,t=!1),t||(y.createHTMLDocument?((r=(t=E.implementation.createHTMLDocument("")).createElement("base")).href=E.location.href,t.head.appendChild(r)):t=E),o=!n&&[],(i=N.exec(e))?[t.createElement(i[1])]:(i=xe([e],t,o),o&&o.length&&S(o).remove(),S.merge([],i.childNodes)));var r,i,o},S.fn.load=function(e,t,n){var r,i,o,a=this,s=e.indexOf(" ");return-1").append(S.parseHTML(e)).find(r):e)}).always(n&&function(e,t){a.each(function(){n.apply(this,o||[e.responseText,t,e])})}),this},S.expr.pseudos.animated=function(t){return S.grep(S.timers,function(e){return t===e.elem}).length},S.offset={setOffset:function(e,t,n){var r,i,o,a,s,u,l=S.css(e,"position"),c=S(e),f={};"static"===l&&(e.style.position="relative"),s=c.offset(),o=S.css(e,"top"),u=S.css(e,"left"),("absolute"===l||"fixed"===l)&&-1<(o+u).indexOf("auto")?(a=(r=c.position()).top,i=r.left):(a=parseFloat(o)||0,i=parseFloat(u)||0),m(t)&&(t=t.call(e,n,S.extend({},s))),null!=t.top&&(f.top=t.top-s.top+a),null!=t.left&&(f.left=t.left-s.left+i),"using"in t?t.using.call(e,f):("number"==typeof f.top&&(f.top+="px"),"number"==typeof f.left&&(f.left+="px"),c.css(f))}},S.fn.extend({offset:function(t){if(arguments.length)return void 0===t?this:this.each(function(e){S.offset.setOffset(this,t,e)});var e,n,r=this[0];return r?r.getClientRects().length?(e=r.getBoundingClientRect(),n=r.ownerDocument.defaultView,{top:e.top+n.pageYOffset,left:e.left+n.pageXOffset}):{top:0,left:0}:void 0},position:function(){if(this[0]){var 
e,t,n,r=this[0],i={top:0,left:0};if("fixed"===S.css(r,"position"))t=r.getBoundingClientRect();else{t=this.offset(),n=r.ownerDocument,e=r.offsetParent||n.documentElement;while(e&&(e===n.body||e===n.documentElement)&&"static"===S.css(e,"position"))e=e.parentNode;e&&e!==r&&1===e.nodeType&&((i=S(e).offset()).top+=S.css(e,"borderTopWidth",!0),i.left+=S.css(e,"borderLeftWidth",!0))}return{top:t.top-i.top-S.css(r,"marginTop",!0),left:t.left-i.left-S.css(r,"marginLeft",!0)}}},offsetParent:function(){return this.map(function(){var e=this.offsetParent;while(e&&"static"===S.css(e,"position"))e=e.offsetParent;return e||re})}}),S.each({scrollLeft:"pageXOffset",scrollTop:"pageYOffset"},function(t,i){var o="pageYOffset"===i;S.fn[t]=function(e){return $(this,function(e,t,n){var r;if(x(e)?r=e:9===e.nodeType&&(r=e.defaultView),void 0===n)return r?r[i]:e[t];r?r.scrollTo(o?r.pageXOffset:n,o?n:r.pageYOffset):e[t]=n},t,e,arguments.length)}}),S.each(["top","left"],function(e,n){S.cssHooks[n]=$e(y.pixelPosition,function(e,t){if(t)return t=Be(e,n),Me.test(t)?S(e).position()[n]+"px":t})}),S.each({Height:"height",Width:"width"},function(a,s){S.each({padding:"inner"+a,content:s,"":"outer"+a},function(r,o){S.fn[o]=function(e,t){var n=arguments.length&&(r||"boolean"!=typeof e),i=r||(!0===e||!0===t?"margin":"border");return $(this,function(e,t,n){var r;return x(e)?0===o.indexOf("outer")?e["inner"+a]:e.document.documentElement["client"+a]:9===e.nodeType?(r=e.documentElement,Math.max(e.body["scroll"+a],r["scroll"+a],e.body["offset"+a],r["offset"+a],r["client"+a])):void 0===n?S.css(e,t,i):S.style(e,t,n,i)},s,n?e:void 0,n)}})}),S.each(["ajaxStart","ajaxStop","ajaxComplete","ajaxError","ajaxSuccess","ajaxSend"],function(e,t){S.fn[t]=function(e){return this.on(t,e)}}),S.fn.extend({bind:function(e,t,n){return this.on(e,null,t,n)},unbind:function(e,t){return this.off(e,null,t)},delegate:function(e,t,n,r){return this.on(t,e,n,r)},undelegate:function(e,t,n){return 
1===arguments.length?this.off(e,"**"):this.off(t,e||"**",n)},hover:function(e,t){return this.mouseenter(e).mouseleave(t||e)}}),S.each("blur focus focusin focusout resize scroll click dblclick mousedown mouseup mousemove mouseover mouseout mouseenter mouseleave change select submit keydown keypress keyup contextmenu".split(" "),function(e,n){S.fn[n]=function(e,t){return 0 { 16 | return this.value / (1 - Math.pow(this.decay, this.t)); 17 | }; 18 | 19 | this.update = (time) => { 20 | this.t += 1; 21 | if (this.t) { 22 | this.value *= this.decay; 23 | this.value += (1 - this.decay) * (time - this.last_time); 24 | } 25 | this.last_time = time; 26 | return this.get(); 27 | }; 28 | } 29 | 30 | 31 | function reloadImage() { 32 | if (canLoadImage) { 33 | canLoadImage = false; 34 | let width = $("#image").attr("width"); 35 | let height = $("#image").attr("height"); 36 | $("#image").attr("id", "backup-image"); 37 | let img = $(""); 38 | img.attr("src", "image?time=" + Date.now()); 39 | img.attr("width", width); 40 | img.attr("height", height); 41 | img.on("load", onImageLoad); 42 | img.on("error", onImageError); 43 | $("#backup-image").after(img); 44 | } 45 | } 46 | 47 | 48 | function onImageLoad() { 49 | $("#backup-image").remove(); 50 | $("#image").css("display", ""); 51 | setTimeout(() => {canLoadImage = true;}, 100); 52 | } 53 | 54 | 55 | function onImageError() { 56 | $("#image").remove(); 57 | $("#backup-image").css("display", ""); 58 | ws.close(); 59 | } 60 | 61 | 62 | function wsConnect() { 63 | let protocol = window.location.protocol.replace("http", "ws"); 64 | ws = new WebSocket(protocol + "//" + window.location.host + "/websocket"); 65 | 66 | ws.onopen = () => { 67 | $("#status").text("Waiting for the first iteration..."); 68 | }; 69 | 70 | ws.onclose = () => { 71 | if (!done) { 72 | $("#status").text("Lost the connection to the backend."); 73 | $("#status").css("display", ""); 74 | } 75 | }; 76 | 77 | ws.onerror = ws.onclose; 78 | 79 | ws.onmessage = (e) => { 80 | 
let msg = JSON.parse(e.data); 81 | let dpr = Math.min(window.devicePixelRatio, 2); 82 | switch (msg._type) { 83 | case "STIterate": 84 | $("#image").attr("width", msg.w / dpr); 85 | $("#image").attr("height", msg.h / dpr); 86 | $("#w").text(msg.w); 87 | $("#h").text(msg.h); 88 | $("#i").text(msg.i); 89 | $("#i-max").text(msg.i_max); 90 | $("#loss").text(msg.loss.toFixed(6)); 91 | if (!average || msg.i == 1) { 92 | average = new AverageTime(0.9); 93 | $("#ips").text("0"); 94 | } 95 | average.update(msg.time); 96 | if (average.t > 0) { 97 | $("#ips").text((1 / average.get()).toFixed(2)); 98 | } 99 | if (msg.gpu_ram) { 100 | $("#gpu-ram").text((msg.gpu_ram / 1024 / 1024).toFixed()); 101 | $("#gpu-wrap").css("display", ""); 102 | } 103 | $("#status").css("display", "none"); 104 | reloadImage(); 105 | break; 106 | case "WIDone": 107 | $("#status").text("Iteration finished."); 108 | $("#status").css("display", ""); 109 | done = true; 110 | $("#image").off(); 111 | canLoadImage = true; 112 | reloadImage(); 113 | ws.close(); 114 | break; 115 | default: 116 | console.log(msg); 117 | } 118 | }; 119 | } 120 | 121 | $(document).ready(() => { 122 | wsConnect(); 123 | }); 124 | -------------------------------------------------------------------------------- /style_transfer/web_static/normalize.css: -------------------------------------------------------------------------------- 1 | /*! normalize.css v8.0.1 | MIT License | github.com/necolas/normalize.css */ 2 | 3 | /* Document 4 | ========================================================================== */ 5 | 6 | /** 7 | * 1. Correct the line height in all browsers. 8 | * 2. Prevent adjustments of font size after orientation changes in iOS. 9 | */ 10 | 11 | html { 12 | line-height: 1.15; /* 1 */ 13 | -webkit-text-size-adjust: 100%; /* 2 */ 14 | } 15 | 16 | /* Sections 17 | ========================================================================== */ 18 | 19 | /** 20 | * Remove the margin in all browsers. 
21 | */ 22 | 23 | body { 24 | margin: 0; 25 | } 26 | 27 | /** 28 | * Render the `main` element consistently in IE. 29 | */ 30 | 31 | main { 32 | display: block; 33 | } 34 | 35 | /** 36 | * Correct the font size and margin on `h1` elements within `section` and 37 | * `article` contexts in Chrome, Firefox, and Safari. 38 | */ 39 | 40 | h1 { 41 | font-size: 2em; 42 | margin: 0.67em 0; 43 | } 44 | 45 | /* Grouping content 46 | ========================================================================== */ 47 | 48 | /** 49 | * 1. Add the correct box sizing in Firefox. 50 | * 2. Show the overflow in Edge and IE. 51 | */ 52 | 53 | hr { 54 | box-sizing: content-box; /* 1 */ 55 | height: 0; /* 1 */ 56 | overflow: visible; /* 2 */ 57 | } 58 | 59 | /** 60 | * 1. Correct the inheritance and scaling of font size in all browsers. 61 | * 2. Correct the odd `em` font sizing in all browsers. 62 | */ 63 | 64 | pre { 65 | font-family: monospace, monospace; /* 1 */ 66 | font-size: 1em; /* 2 */ 67 | } 68 | 69 | /* Text-level semantics 70 | ========================================================================== */ 71 | 72 | /** 73 | * Remove the gray background on active links in IE 10. 74 | */ 75 | 76 | a { 77 | background-color: transparent; 78 | } 79 | 80 | /** 81 | * 1. Remove the bottom border in Chrome 57- 82 | * 2. Add the correct text decoration in Chrome, Edge, IE, Opera, and Safari. 83 | */ 84 | 85 | abbr[title] { 86 | border-bottom: none; /* 1 */ 87 | text-decoration: underline; /* 2 */ 88 | text-decoration: underline dotted; /* 2 */ 89 | } 90 | 91 | /** 92 | * Add the correct font weight in Chrome, Edge, and Safari. 93 | */ 94 | 95 | b, 96 | strong { 97 | font-weight: bolder; 98 | } 99 | 100 | /** 101 | * 1. Correct the inheritance and scaling of font size in all browsers. 102 | * 2. Correct the odd `em` font sizing in all browsers. 
103 | */ 104 | 105 | code, 106 | kbd, 107 | samp { 108 | font-family: monospace, monospace; /* 1 */ 109 | font-size: 1em; /* 2 */ 110 | } 111 | 112 | /** 113 | * Add the correct font size in all browsers. 114 | */ 115 | 116 | small { 117 | font-size: 80%; 118 | } 119 | 120 | /** 121 | * Prevent `sub` and `sup` elements from affecting the line height in 122 | * all browsers. 123 | */ 124 | 125 | sub, 126 | sup { 127 | font-size: 75%; 128 | line-height: 0; 129 | position: relative; 130 | vertical-align: baseline; 131 | } 132 | 133 | sub { 134 | bottom: -0.25em; 135 | } 136 | 137 | sup { 138 | top: -0.5em; 139 | } 140 | 141 | /* Embedded content 142 | ========================================================================== */ 143 | 144 | /** 145 | * Remove the border on images inside links in IE 10. 146 | */ 147 | 148 | img { 149 | border-style: none; 150 | } 151 | 152 | /* Forms 153 | ========================================================================== */ 154 | 155 | /** 156 | * 1. Change the font styles in all browsers. 157 | * 2. Remove the margin in Firefox and Safari. 158 | */ 159 | 160 | button, 161 | input, 162 | optgroup, 163 | select, 164 | textarea { 165 | font-family: inherit; /* 1 */ 166 | font-size: 100%; /* 1 */ 167 | line-height: 1.15; /* 1 */ 168 | margin: 0; /* 2 */ 169 | } 170 | 171 | /** 172 | * Show the overflow in IE. 173 | * 1. Show the overflow in Edge. 174 | */ 175 | 176 | button, 177 | input { /* 1 */ 178 | overflow: visible; 179 | } 180 | 181 | /** 182 | * Remove the inheritance of text transform in Edge, Firefox, and IE. 183 | * 1. Remove the inheritance of text transform in Firefox. 184 | */ 185 | 186 | button, 187 | select { /* 1 */ 188 | text-transform: none; 189 | } 190 | 191 | /** 192 | * Correct the inability to style clickable types in iOS and Safari. 
193 | */ 194 | 195 | button, 196 | [type="button"], 197 | [type="reset"], 198 | [type="submit"] { 199 | -webkit-appearance: button; 200 | } 201 | 202 | /** 203 | * Remove the inner border and padding in Firefox. 204 | */ 205 | 206 | button::-moz-focus-inner, 207 | [type="button"]::-moz-focus-inner, 208 | [type="reset"]::-moz-focus-inner, 209 | [type="submit"]::-moz-focus-inner { 210 | border-style: none; 211 | padding: 0; 212 | } 213 | 214 | /** 215 | * Restore the focus styles unset by the previous rule. 216 | */ 217 | 218 | button:-moz-focusring, 219 | [type="button"]:-moz-focusring, 220 | [type="reset"]:-moz-focusring, 221 | [type="submit"]:-moz-focusring { 222 | outline: 1px dotted ButtonText; 223 | } 224 | 225 | /** 226 | * Correct the padding in Firefox. 227 | */ 228 | 229 | fieldset { 230 | padding: 0.35em 0.75em 0.625em; 231 | } 232 | 233 | /** 234 | * 1. Correct the text wrapping in Edge and IE. 235 | * 2. Correct the color inheritance from `fieldset` elements in IE. 236 | * 3. Remove the padding so developers are not caught out when they zero out 237 | * `fieldset` elements in all browsers. 238 | */ 239 | 240 | legend { 241 | box-sizing: border-box; /* 1 */ 242 | color: inherit; /* 2 */ 243 | display: table; /* 1 */ 244 | max-width: 100%; /* 1 */ 245 | padding: 0; /* 3 */ 246 | white-space: normal; /* 1 */ 247 | } 248 | 249 | /** 250 | * Add the correct vertical alignment in Chrome, Firefox, and Opera. 251 | */ 252 | 253 | progress { 254 | vertical-align: baseline; 255 | } 256 | 257 | /** 258 | * Remove the default vertical scrollbar in IE 10+. 259 | */ 260 | 261 | textarea { 262 | overflow: auto; 263 | } 264 | 265 | /** 266 | * 1. Add the correct box sizing in IE 10. 267 | * 2. Remove the padding in IE 10. 268 | */ 269 | 270 | [type="checkbox"], 271 | [type="radio"] { 272 | box-sizing: border-box; /* 1 */ 273 | padding: 0; /* 2 */ 274 | } 275 | 276 | /** 277 | * Correct the cursor style of increment and decrement buttons in Chrome. 
278 | */ 279 | 280 | [type="number"]::-webkit-inner-spin-button, 281 | [type="number"]::-webkit-outer-spin-button { 282 | height: auto; 283 | } 284 | 285 | /** 286 | * 1. Correct the odd appearance in Chrome and Safari. 287 | * 2. Correct the outline style in Safari. 288 | */ 289 | 290 | [type="search"] { 291 | -webkit-appearance: textfield; /* 1 */ 292 | outline-offset: -2px; /* 2 */ 293 | } 294 | 295 | /** 296 | * Remove the inner padding in Chrome and Safari on macOS. 297 | */ 298 | 299 | [type="search"]::-webkit-search-decoration { 300 | -webkit-appearance: none; 301 | } 302 | 303 | /** 304 | * 1. Correct the inability to style clickable types in iOS and Safari. 305 | * 2. Change font properties to `inherit` in Safari. 306 | */ 307 | 308 | ::-webkit-file-upload-button { 309 | -webkit-appearance: button; /* 1 */ 310 | font: inherit; /* 2 */ 311 | } 312 | 313 | /* Interactive 314 | ========================================================================== */ 315 | 316 | /* 317 | * Add the correct display in Edge, IE 10+, and Firefox. 318 | */ 319 | 320 | details { 321 | display: block; 322 | } 323 | 324 | /* 325 | * Add the correct display in all browsers. 326 | */ 327 | 328 | summary { 329 | display: list-item; 330 | } 331 | 332 | /* Misc 333 | ========================================================================== */ 334 | 335 | /** 336 | * Add the correct display in IE 10+. 337 | */ 338 | 339 | template { 340 | display: none; 341 | } 342 | 343 | /** 344 | * Add the correct display in IE 10. 345 | */ 346 | 347 | [hidden] { 348 | display: none; 349 | } 350 | --------------------------------------------------------------------------------