├── .gitignore ├── DrawingInterface.py ├── LICENSE ├── README.md ├── USE ├── clipdrawer.py ├── clipit.py ├── demos ├── Init_Image.ipynb ├── Moar_Settings.ipynb ├── PixelDrawer.ipynb ├── PixelDrawer_Init_Image.ipynb ├── Pixray_Swirl_Demo.ipynb ├── README.md ├── Start_Here.ipynb ├── Swap_Model.ipynb └── palette_enforcement.ipynb ├── download_models.sh ├── generate.py ├── models └── .gitignore ├── opt_tester.sh ├── pixeldrawer.py ├── random.sh ├── requirements.txt ├── samples ├── A_painting_of_an_apple_in_a_fruitbowl.png ├── Apple_weird.png ├── Bedroom.png ├── Cartoon.png ├── Cartoon2.png ├── Cartoon3.png ├── DemonBiscuits.png ├── Football.png ├── Fractal_Landscape3.png ├── Games_5.png ├── VanGogh.jpg ├── pencil_sketch_2.png ├── samples.txt ├── vvg_picasso.png ├── vvg_psychedelic.png ├── vvg_sketch.png └── zoom.gif ├── video_styler.sh ├── vqgan.py ├── vqgan.yml └── zoom.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | 4 | # Libraries and Models 5 | CLIP/ 6 | taming-transformers/ 7 | taming/ 8 | checkpoints/ 9 | 10 | # Editor 11 | .vscode/ 12 | 13 | # Operations 14 | outputs/ 15 | steps/ 16 | 17 | # Files 18 | output.png 19 | steps.mp4 -------------------------------------------------------------------------------- /DrawingInterface.py: -------------------------------------------------------------------------------- 1 | class DrawingInterface: 2 | model = None 3 | def load_model(self, config, checkpoint): 4 | pass 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | for use in commercial products also refer to USE file 2 | 3 | clipit is open sourced for all free and educational uses 4 | but is currently covered by different licenses for different sections 5 | 6 | this file will be updated to reflect the licenses that apply to various dependencies. 7 | 8 | MIT: VQGAN and Clip guided core 9 | ------------------------------- 10 | VQGAN: MIT License 11 | Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer 12 | 13 | CLIP guided core: MIT License 14 | Copyright (c) 2021 Katherine Crowson 15 | 16 | Permission is hereby granted, free of charge, to any person obtaining a copy 17 | of this software and associated documentation files (the "Software"), to deal 18 | in the Software without restriction, including without limitation the rights 19 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 20 | copies of the Software, and to permit persons to whom the Software is 21 | furnished to do so, subject to the following conditions: 22 | 23 | The above copyright notice and this permission notice shall be included in all 24 | copies or substantial portions of the Software. 25 | 26 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 27 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 28 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 29 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 30 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 31 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 32 | SOFTWARE.
33 | 34 | diffvg: Apache License 2.0 35 | 36 | ClipDraw: No License (contact author: https://github.com/kvfrans/clipdraw/issues/3 ) 37 | 38 | PixelDraw: No License (contact author: @dribnet) 39 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # clipit 2 | 3 | ### Note: Updates now in [pixray](https://github.com/dribnet/pixray) project. 4 | 5 | This repo is now used to host Google Colab notebooks which demonstrate various pixray capabilities. Issues and pull requests for this repo should be specific to the notebooks, as the Python library here is now out of date and only remains to support notebooks out in the wild. 6 | 7 | This version was originally a fork of @nerdyrodent's VQGAN-CLIP code, which itself was based on the notebooks of @RiversHaveWings and @advadnoun. 8 | 9 | To get started with pixray, check out [THE DEMO NOTEBOOKS](demos/README.md) - especially the super simple "Start Here" Colab. 10 | 11 | 12 | # Citations 13 | 14 | ```bibtex 15 | @misc{unpublished2021clip, 16 | title = {CLIP: Connecting Text and Images}, 17 | author = {Alec Radford, Ilya Sutskever, Jong Wook Kim, Gretchen Krueger, Sandhini Agarwal}, 18 | year = {2021} 19 | } 20 | ``` 21 | ```bibtex 22 | @misc{esser2020taming, 23 | title={Taming Transformers for High-Resolution Image Synthesis}, 24 | author={Patrick Esser and Robin Rombach and Björn Ommer}, 25 | year={2020}, 26 | eprint={2012.09841}, 27 | archivePrefix={arXiv}, 28 | primaryClass={cs.CV} 29 | } 30 | ``` 31 | Katherine Crowson - https://github.com/crowsonkb 32 | Adverb - https://twitter.com/advadnoun 33 | 34 | -------------------------------------------------------------------------------- /USE: -------------------------------------------------------------------------------- 1 | royalties: pixray.dribnet.eth, pixray.dribnet.tez 2 | author: dribnet 3 | terms: negotiable 4 | -------------------------------------------------------------------------------- /clipdrawer.py: -------------------------------------------------------------------------------- 1 | # this is derived from ClipDraw code 2 | # CLIPDraw: Exploring Text-to-Drawing Synthesis through Language-Image Encoders 3 | # Kevin Frans, L.B.
Soros, Olaf Witkowski 4 | # https://arxiv.org/abs/2106.14843 5 | 6 | from DrawingInterface import DrawingInterface 7 | 8 | import pydiffvg 9 | import torch 10 | import skimage 11 | import skimage.io 12 | import random 13 | import ttools.modules 14 | import argparse 15 | import math 16 | import torchvision 17 | import torchvision.transforms as transforms 18 | import numpy as np 19 | import PIL.Image 20 | 21 | pydiffvg.set_print_timing(False) 22 | 23 | class ClipDrawer(DrawingInterface): 24 | num_paths = 256 25 | max_width = 50 26 | 27 | def __init__(self, width, height, num_paths): 28 | super(DrawingInterface, self).__init__() 29 | 30 | self.canvas_width = width 31 | self.canvas_height = height 32 | self.num_paths = num_paths 33 | 34 | def load_model(self, config_path, checkpoint_path, device): 35 | # gamma = 1.0 36 | 37 | # Use GPU if available 38 | pydiffvg.set_use_gpu(torch.cuda.is_available()) 39 | device = torch.device('cuda') 40 | pydiffvg.set_device(device) 41 | 42 | canvas_width, canvas_height = self.canvas_width, self.canvas_height 43 | num_paths = self.num_paths 44 | max_width = canvas_height / 10 45 | 46 | # Initialize Random Curves 47 | shapes = [] 48 | shape_groups = [] 49 | for i in range(num_paths): 50 | num_segments = random.randint(1, 3) 51 | num_control_points = torch.zeros(num_segments, dtype = torch.int32) + 2 52 | points = [] 53 | p0 = (random.random(), random.random()) 54 | points.append(p0) 55 | for j in range(num_segments): 56 | radius = 0.1 57 | p1 = (p0[0] + radius * (random.random() - 0.5), p0[1] + radius * (random.random() - 0.5)) 58 | p2 = (p1[0] + radius * (random.random() - 0.5), p1[1] + radius * (random.random() - 0.5)) 59 | p3 = (p2[0] + radius * (random.random() - 0.5), p2[1] + radius * (random.random() - 0.5)) 60 | points.append(p1) 61 | points.append(p2) 62 | points.append(p3) 63 | p0 = p3 64 | points = torch.tensor(points) 65 | points[:, 0] *= canvas_width 66 | points[:, 1] *= canvas_height 67 | path = pydiffvg.Path(num_control_points = num_control_points, points = points, stroke_width = torch.tensor(max_width/10), is_closed = False) 68 | shapes.append(path) 69 | path_group = pydiffvg.ShapeGroup(shape_ids = torch.tensor([len(shapes) - 1]), fill_color = None, stroke_color = torch.tensor([random.random(), random.random(), random.random(), random.random()])) 70 | shape_groups.append(path_group) 71 | 72 | # Just some diffvg setup 73 | scene_args = pydiffvg.RenderFunction.serialize_scene(\ 74 | canvas_width, canvas_height, shapes, shape_groups) 75 | render = pydiffvg.RenderFunction.apply 76 | img = render(canvas_width, canvas_height, 2, 2, 0, None, *scene_args) 77 | 78 | points_vars = [] 79 | stroke_width_vars = [] 80 | color_vars = [] 81 | for path in shapes: 82 | path.points.requires_grad = True 83 | points_vars.append(path.points) 84 | path.stroke_width.requires_grad = True 85 | stroke_width_vars.append(path.stroke_width) 86 | for group in shape_groups: 87 | group.stroke_color.requires_grad = True 88 | color_vars.append(group.stroke_color) 89 | 90 | # Optimizers 91 | points_optim = torch.optim.Adam(points_vars, lr=1.0) 92 | width_optim = torch.optim.Adam(stroke_width_vars, lr=0.1) 93 | color_optim = torch.optim.Adam(color_vars, lr=0.01) 94 | 95 | self.img = img 96 | self.shapes = shapes 97 | self.shape_groups = shape_groups 98 | self.max_width = max_width 99 | self.canvas_width = canvas_width 100 | self.canvas_height = canvas_height 101 | self.opts = [points_optim, width_optim, color_optim] 102 | 103 | def get_opts(self): 104 | return self.opts 105 | 106 | 
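    # Usage sketch (illustrative only, mirroring the calls clipit.py makes into
    # this drawer; compute_losses is a hypothetical stand-in for the CLIP losses
    # assembled in clipit.ascend_txt):
    #
    #   drawer = ClipDrawer(width, height, num_paths)
    #   drawer.load_model(config_path, checkpoint_path, device)  # builds curves + optimizers
    #   opts = drawer.get_opts()               # Adam optimizers for points, widths, colors
    #   for it in range(iterations):
    #       for opt in opts: opt.zero_grad()
    #       img = drawer.synth(it)             # differentiable render, NCHW in [0, 1]
    #       loss = compute_losses(img)
    #       loss.backward()
    #       for opt in opts: opt.step()
    #       drawer.clip_z()                    # clamp stroke widths and colors
    #   drawer.to_image().save("output.png")   # final PIL image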
def rand_init(self, toksX, toksY): 107 | # TODO 108 | pass 109 | 110 | def init_from_tensor(self, init_tensor): 111 | # TODO 112 | pass 113 | 114 | def reapply_from_tensor(self, new_tensor): 115 | # TODO 116 | pass 117 | 118 | def get_z_from_tensor(self, ref_tensor): 119 | return None 120 | 121 | def get_num_resolutions(self): 122 | # TODO 123 | return 5 124 | 125 | def synth(self, cur_iteration): 126 | render = pydiffvg.RenderFunction.apply 127 | scene_args = pydiffvg.RenderFunction.serialize_scene(\ 128 | self.canvas_width, self.canvas_height, self.shapes, self.shape_groups) 129 | img = render(self.canvas_width, self.canvas_height, 2, 2, cur_iteration, None, *scene_args) 130 | img = img[:, :, 3:4] * img[:, :, :3] + torch.ones(img.shape[0], img.shape[1], 3, device = pydiffvg.get_device()) * (1 - img[:, :, 3:4]) 131 | img = img[:, :, :3] 132 | img = img.unsqueeze(0) 133 | img = img.permute(0, 3, 1, 2) # NHWC -> NCHW 134 | self.img = img 135 | return img 136 | 137 | @torch.no_grad() 138 | def to_image(self): 139 | img = self.img.detach().cpu().numpy()[0] 140 | img = np.transpose(img, (1, 2, 0)) 141 | img = np.clip(img, 0, 1) 142 | img = np.uint8(img * 254) 143 | # img = np.repeat(img, 4, axis=0) 144 | # img = np.repeat(img, 4, axis=1) 145 | pimg = PIL.Image.fromarray(img, mode="RGB") 146 | return pimg 147 | 148 | def clip_z(self): 149 | with torch.no_grad(): 150 | for path in self.shapes: 151 | path.stroke_width.data.clamp_(1.0, self.max_width) 152 | for group in self.shape_groups: 153 | group.stroke_color.data.clamp_(0.0, 1.0) 154 | 155 | def get_z(self): 156 | return None 157 | 158 | def get_z_copy(self): 159 | return None 160 | 161 | ### EXTERNAL INTERFACE 162 | ### load_vqgan_model 163 | 164 | if __name__ == '__main__': 165 | main() 166 | -------------------------------------------------------------------------------- /clipit.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | from urllib.request import urlopen 4 | import sys 5 | import os 6 | import json 7 | import subprocess 8 | import glob 9 | from braceexpand import braceexpand 10 | from types import SimpleNamespace 11 | 12 | import os.path 13 | 14 | from omegaconf import OmegaConf 15 | 16 | import torch 17 | from torch import nn, optim 18 | from torch.nn import functional as F 19 | from torchvision import transforms 20 | from torchvision.transforms import functional as TF 21 | torch.backends.cudnn.benchmark = False # NR: True is a bit faster, but can lead to OOM. False is more deterministic. 
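# Reproducibility sketch (illustrative only): the related knobs used elsewhere in
# this file are the --seed argument (passed to torch.manual_seed further down) and
# the -d/--deterministic flag (which sets torch.backends.cudnn.deterministic in
# process_args). A manual setup would look roughly like:
#
#   torch.manual_seed(12345)                   # seed value is illustrative
#   torch.backends.cudnn.deterministic = True
#   torch.backends.cudnn.benchmark = False     # already the default above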
22 | #torch.use_deterministic_algorithms(True) # NR: grid_sampler_2d_backward_cuda does not have a deterministic implementation 23 | 24 | from torch_optimizer import DiffGrad, AdamP, RAdam 25 | from perlin_numpy import generate_fractal_noise_2d 26 | 27 | from CLIP import clip 28 | import kornia 29 | import kornia.augmentation as K 30 | import numpy as np 31 | import imageio 32 | 33 | from PIL import ImageFile, Image, PngImagePlugin 34 | ImageFile.LOAD_TRUNCATED_IMAGES = True 35 | 36 | # or 'border' 37 | global_padding_mode = 'reflection' 38 | global_aspect_width = 1 39 | global_spot_file = None 40 | 41 | from vqgan import VqganDrawer 42 | try: 43 | from clipdrawer import ClipDrawer 44 | except ImportError: 45 | pass 46 | # print('clipdrawer not imported') 47 | try: 48 | from pixeldrawer import PixelDrawer 49 | except ImportError: 50 | pass 51 | # print('pixeldrawer not imported') 52 | 53 | try: 54 | import matplotlib.colors 55 | except ImportError: 56 | # only needed for palette stuff 57 | pass 58 | 59 | # print("warning: running unreleased future version") 60 | 61 | # https://stackoverflow.com/a/39662359 62 | def isnotebook(): 63 | try: 64 | shell = get_ipython().__class__.__name__ 65 | if shell == 'ZMQInteractiveShell': 66 | return True # Jupyter notebook or qtconsole 67 | elif shell == 'Shell': 68 | return True # Seems to be what co-lab does 69 | elif shell == 'TerminalInteractiveShell': 70 | return False # Terminal running IPython 71 | else: 72 | return False # Other type (?) 73 | except NameError: 74 | return False # Probably standard Python interpreter 75 | 76 | IS_NOTEBOOK = isnotebook() 77 | 78 | if IS_NOTEBOOK: 79 | from IPython import display 80 | from tqdm.notebook import tqdm 81 | from IPython.display import clear_output 82 | else: 83 | from tqdm import tqdm 84 | 85 | # file helpers 86 | def real_glob(rglob): 87 | glob_list = braceexpand(rglob) 88 | files = [] 89 | for g in glob_list: 90 | files = files + glob.glob(g) 91 | return sorted(files) 92 | 93 | # Functions and classes 94 | def sinc(x): 95 | return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([])) 96 | 97 | 98 | def lanczos(x, a): 99 | cond = torch.logical_and(-a < x, x < a) 100 | out = torch.where(cond, sinc(x) * sinc(x/a), x.new_zeros([])) 101 | return out / out.sum() 102 | 103 | 104 | def ramp(ratio, width): 105 | n = math.ceil(width / ratio + 1) 106 | out = torch.empty([n]) 107 | cur = 0 108 | for i in range(out.shape[0]): 109 | out[i] = cur 110 | cur += ratio 111 | return torch.cat([-out[1:].flip([0]), out])[1:-1] 112 | 113 | 114 | # NR: Testing with different intital images 115 | def old_random_noise_image(w,h): 116 | random_image = Image.fromarray(np.random.randint(0,255,(w,h,3),dtype=np.dtype('uint8'))) 117 | return random_image 118 | 119 | def NormalizeData(data): 120 | return (data - np.min(data)) / (np.max(data) - np.min(data)) 121 | 122 | # https://stats.stackexchange.com/a/289477 123 | def contrast_noise(n): 124 | n = 0.9998 * n + 0.0001 125 | n1 = (n / (1-n)) 126 | n2 = np.power(n1, -2) 127 | n3 = 1 / (1 + n2) 128 | return n3 129 | 130 | def random_noise_image(w,h): 131 | # scale up roughly as power of 2 132 | if (w>1024 or h>1024): 133 | side, octp = 2048, 7 134 | elif (w>512 or h>512): 135 | side, octp = 1024, 6 136 | elif (w>256 or h>256): 137 | side, octp = 512, 5 138 | else: 139 | side, octp = 256, 4 140 | 141 | nr = NormalizeData(generate_fractal_noise_2d((side, side), (32, 32), octp)) 142 | ng = NormalizeData(generate_fractal_noise_2d((side, side), (32, 32), octp)) 143 | nb 
= NormalizeData(generate_fractal_noise_2d((side, side), (32, 32), octp)) 144 | stack = np.dstack((contrast_noise(nr),contrast_noise(ng),contrast_noise(nb))) 145 | substack = stack[:h, :w, :] 146 | im = Image.fromarray((255.9 * stack).astype('uint8')) 147 | return im 148 | 149 | # testing 150 | def gradient_2d(start, stop, width, height, is_horizontal): 151 | if is_horizontal: 152 | return np.tile(np.linspace(start, stop, width), (height, 1)) 153 | else: 154 | return np.tile(np.linspace(start, stop, height), (width, 1)).T 155 | 156 | 157 | def gradient_3d(width, height, start_list, stop_list, is_horizontal_list): 158 | result = np.zeros((height, width, len(start_list)), dtype=float) 159 | 160 | for i, (start, stop, is_horizontal) in enumerate(zip(start_list, stop_list, is_horizontal_list)): 161 | result[:, :, i] = gradient_2d(start, stop, width, height, is_horizontal) 162 | 163 | return result 164 | 165 | 166 | def random_gradient_image(w,h): 167 | array = gradient_3d(w, h, (0, 0, np.random.randint(0,255)), (np.random.randint(1,255), np.random.randint(2,255), np.random.randint(3,128)), (True, False, False)) 168 | random_image = Image.fromarray(np.uint8(array)) 169 | return random_image 170 | 171 | 172 | class ReplaceGrad(torch.autograd.Function): 173 | @staticmethod 174 | def forward(ctx, x_forward, x_backward): 175 | ctx.shape = x_backward.shape 176 | return x_forward 177 | 178 | @staticmethod 179 | def backward(ctx, grad_in): 180 | return None, grad_in.sum_to_size(ctx.shape) 181 | 182 | replace_grad = ReplaceGrad.apply 183 | 184 | 185 | def spherical_dist_loss(x, y): 186 | x = F.normalize(x, dim=-1) 187 | y = F.normalize(y, dim=-1) 188 | return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) 189 | 190 | 191 | class Prompt(nn.Module): 192 | def __init__(self, embed, weight=1., stop=float('-inf')): 193 | super().__init__() 194 | self.register_buffer('embed', embed) 195 | self.register_buffer('weight', torch.as_tensor(weight)) 196 | self.register_buffer('stop', torch.as_tensor(stop)) 197 | 198 | def forward(self, input): 199 | input_normed = F.normalize(input.unsqueeze(1), dim=2) 200 | embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2) 201 | dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2) 202 | dists = dists * self.weight.sign() 203 | return self.weight.abs() * replace_grad(dists, torch.maximum(dists, self.stop)).mean() 204 | 205 | 206 | def parse_prompt(prompt): 207 | vals = prompt.rsplit(':', 2) 208 | vals = vals + ['', '1', '-inf'][len(vals):] 209 | # print(f"parsed vals is {vals}") 210 | return vals[0], float(vals[1]), float(vals[2]) 211 | 212 | 213 | from typing import cast, Dict, List, Optional, Tuple, Union 214 | 215 | # override class to get padding_mode 216 | class MyRandomPerspective(K.RandomPerspective): 217 | def apply_transform( 218 | self, input: torch.Tensor, params: Dict[str, torch.Tensor], transform: Optional[torch.Tensor] = None 219 | ) -> torch.Tensor: 220 | _, _, height, width = input.shape 221 | transform = cast(torch.Tensor, transform) 222 | return kornia.geometry.warp_perspective( 223 | input, transform, (height, width), 224 | mode=self.resample.name.lower(), align_corners=self.align_corners, padding_mode=global_padding_mode 225 | ) 226 | 227 | 228 | cached_spot_indexes = {} 229 | def fetch_spot_indexes(sideX, sideY): 230 | global global_spot_file 231 | 232 | # make sure image is loaded if we need it 233 | cache_key = (sideX, sideY) 234 | 235 | if cache_key not in cached_spot_indexes: 236 | if global_spot_file is not None: 
237 | mask_image = Image.open(global_spot_file) 238 | elif global_aspect_width != 1: 239 | mask_image = Image.open("inputs/spot_wide.png") 240 | else: 241 | mask_image = Image.open("inputs/spot_square.png") 242 | # this is a one channel mask 243 | mask_image = mask_image.convert('RGB') 244 | mask_image = mask_image.resize((sideX, sideY), Image.LANCZOS) 245 | mask_image_tensor = TF.to_tensor(mask_image) 246 | # print("ONE CHANNEL ", mask_image_tensor.shape) 247 | mask_indexes = mask_image_tensor.ge(0.5).to(device) 248 | # print("GE ", mask_indexes.shape) 249 | # sys.exit(0) 250 | mask_indexes_off = mask_image_tensor.lt(0.5).to(device) 251 | cached_spot_indexes[cache_key] = [mask_indexes, mask_indexes_off] 252 | 253 | return cached_spot_indexes[cache_key] 254 | 255 | # n = torch.ones((3,5,5)) 256 | # f = generate.fetch_spot_indexes(5, 5) 257 | # f[0].shape = [60,3] 258 | 259 | class MakeCutouts(nn.Module): 260 | def __init__(self, cut_size, cutn, cut_pow=1.): 261 | global global_aspect_width 262 | 263 | super().__init__() 264 | self.cut_size = cut_size 265 | self.cutn = cutn 266 | self.cutn_zoom = int(2*cutn/3) 267 | self.cut_pow = cut_pow 268 | self.transforms = None 269 | 270 | augmentations = [] 271 | if global_aspect_width != 1: 272 | augmentations.append(K.RandomCrop(size=(self.cut_size,self.cut_size), p=1.0, cropping_mode="resample", return_transform=True)) 273 | augmentations.append(MyRandomPerspective(distortion_scale=0.40, p=0.7, return_transform=True)) 274 | augmentations.append(K.RandomResizedCrop(size=(self.cut_size,self.cut_size), scale=(0.1,0.75), ratio=(0.85,1.2), cropping_mode='resample', p=0.7, return_transform=True)) 275 | augmentations.append(K.ColorJitter(hue=0.1, saturation=0.1, p=0.8, return_transform=True)) 276 | self.augs_zoom = nn.Sequential(*augmentations) 277 | 278 | augmentations = [] 279 | if global_aspect_width == 1: 280 | n_s = 0.95 281 | n_t = (1-n_s)/2 282 | augmentations.append(K.RandomAffine(degrees=0, translate=(n_t, n_t), scale=(n_s, n_s), p=1.0, return_transform=True)) 283 | elif global_aspect_width > 1: 284 | n_s = 1/global_aspect_width 285 | n_t = (1-n_s)/2 286 | augmentations.append(K.RandomAffine(degrees=0, translate=(0, n_t), scale=(0.9*n_s, n_s), p=1.0, return_transform=True)) 287 | else: 288 | n_s = global_aspect_width 289 | n_t = (1-n_s)/2 290 | augmentations.append(K.RandomAffine(degrees=0, translate=(n_t, 0), scale=(0.9*n_s, n_s), p=1.0, return_transform=True)) 291 | 292 | # augmentations.append(K.CenterCrop(size=(self.cut_size,self.cut_size), p=1.0, cropping_mode="resample", return_transform=True)) 293 | augmentations.append(K.CenterCrop(size=self.cut_size, cropping_mode='resample', p=1.0, return_transform=True)) 294 | augmentations.append(K.RandomPerspective(distortion_scale=0.20, p=0.7, return_transform=True)) 295 | augmentations.append(K.ColorJitter(hue=0.1, saturation=0.1, p=0.8, return_transform=True)) 296 | self.augs_wide = nn.Sequential(*augmentations) 297 | 298 | self.noise_fac = 0.1 299 | 300 | # Pooling 301 | self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size)) 302 | self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size)) 303 | 304 | def forward(self, input, spot=None): 305 | global global_aspect_width, cur_iteration 306 | sideY, sideX = input.shape[2:4] 307 | max_size = min(sideX, sideY) 308 | min_size = min(sideX, sideY, self.cut_size) 309 | cutouts = [] 310 | mask_indexes = None 311 | 312 | if spot is not None: 313 | spot_indexes = fetch_spot_indexes(self.cut_size, self.cut_size) 314 | if spot == 0: 315 
| mask_indexes = spot_indexes[1] 316 | else: 317 | mask_indexes = spot_indexes[0] 318 | # print("Mask indexes ", mask_indexes) 319 | 320 | for _ in range(self.cutn): 321 | # Pooling 322 | cutout = (self.av_pool(input) + self.max_pool(input))/2 323 | 324 | if mask_indexes is not None: 325 | cutout[0][mask_indexes] = 0.5 326 | 327 | if global_aspect_width != 1: 328 | if global_aspect_width > 1: 329 | cutout = kornia.geometry.transform.rescale(cutout, (1, global_aspect_width)) 330 | else: 331 | cutout = kornia.geometry.transform.rescale(cutout, (1/global_aspect_width, 1)) 332 | 333 | # if cur_iteration % 50 == 0 and _ == 0: 334 | # print(cutout.shape) 335 | # TF.to_pil_image(cutout[0].cpu()).save(f"cutout_im_{cur_iteration:02d}_{spot}.png") 336 | 337 | cutouts.append(cutout) 338 | 339 | if self.transforms is not None: 340 | # print("Cached transforms available") 341 | batch1 = kornia.geometry.transform.warp_perspective(torch.cat(cutouts[:self.cutn_zoom], dim=0), self.transforms[:self.cutn_zoom], 342 | (self.cut_size, self.cut_size), padding_mode=global_padding_mode) 343 | batch2 = kornia.geometry.transform.warp_perspective(torch.cat(cutouts[self.cutn_zoom:], dim=0), self.transforms[self.cutn_zoom:], 344 | (self.cut_size, self.cut_size), padding_mode='zeros') 345 | batch = torch.cat([batch1, batch2]) 346 | # if cur_iteration < 2: 347 | # for j in range(4): 348 | # TF.to_pil_image(batch[j].cpu()).save(f"cached_im_{cur_iteration:02d}_{j:02d}_{spot}.png") 349 | # j_wide = j + self.cutn_zoom 350 | # TF.to_pil_image(batch[j_wide].cpu()).save(f"cached_im_{cur_iteration:02d}_{j_wide:02d}_{spot}.png") 351 | else: 352 | batch1, transforms1 = self.augs_zoom(torch.cat(cutouts[:self.cutn_zoom], dim=0)) 353 | batch2, transforms2 = self.augs_wide(torch.cat(cutouts[self.cutn_zoom:], dim=0)) 354 | # print(batch1.shape, batch2.shape) 355 | batch = torch.cat([batch1, batch2]) 356 | # print(batch.shape) 357 | self.transforms = torch.cat([transforms1, transforms2]) 358 | ## batch, self.transforms = self.augs(torch.cat(cutouts, dim=0)) 359 | # if cur_iteration < 2: 360 | # for j in range(4): 361 | # TF.to_pil_image(batch[j].cpu()).save(f"live_im_{cur_iteration:02d}_{j:02d}_{spot}.png") 362 | # j_wide = j + self.cutn_zoom 363 | # TF.to_pil_image(batch[j_wide].cpu()).save(f"live_im_{cur_iteration:02d}_{j_wide:02d}_{spot}.png") 364 | 365 | # print(batch.shape, self.transforms.shape) 366 | 367 | if self.noise_fac: 368 | facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac) 369 | batch = batch + facs * torch.randn_like(batch) 370 | return batch 371 | 372 | 373 | def resize_image(image, out_size): 374 | ratio = image.size[0] / image.size[1] 375 | area = min(image.size[0] * image.size[1], out_size[0] * out_size[1]) 376 | size = round((area * ratio)**0.5), round((area / ratio)**0.5) 377 | return image.resize(size, Image.LANCZOS) 378 | 379 | def do_init(args): 380 | global opts, perceptors, normalize, cutoutsTable, cutoutSizeTable 381 | global z_orig, z_targets, z_labels, init_image_tensor, target_image_tensor 382 | global gside_X, gside_Y, overlay_image_rgba 383 | global pmsTable, pmsImageTable, pImages, device, spotPmsTable, spotOffPmsTable 384 | global drawer 385 | 386 | # Do it (init that is) 387 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 388 | 389 | if args.use_clipdraw: 390 | drawer = ClipDrawer(args.size[0], args.size[1], args.strokes) 391 | elif args.use_pixeldraw: 392 | if args.pixel_size is not None: 393 | drawer = PixelDrawer(args.size[0], args.size[1], 
args.do_mono, args.pixel_size, scale=args.pixel_scale) 394 | elif global_aspect_width == 1: 395 | drawer = PixelDrawer(args.size[0], args.size[1], args.do_mono, [40, 40], scale=args.pixel_scale) 396 | else: 397 | drawer = PixelDrawer(args.size[0], args.size[1], args.do_mono, scale=args.pixel_scale) 398 | else: 399 | drawer = VqganDrawer(args.vqgan_model) 400 | drawer.load_model(args.vqgan_config, args.vqgan_checkpoint, device) 401 | num_resolutions = drawer.get_num_resolutions() 402 | # print("-----------> NUMR ", num_resolutions) 403 | 404 | jit = True if float(torch.__version__[:3]) < 1.8 else False 405 | f = 2**(num_resolutions - 1) 406 | 407 | toksX, toksY = args.size[0] // f, args.size[1] // f 408 | sideX, sideY = toksX * f, toksY * f 409 | 410 | # save sideX, sideY in globals (need if using overlay) 411 | gside_X = sideX 412 | gside_Y = sideY 413 | 414 | for clip_model in args.clip_models: 415 | perceptor = clip.load(clip_model, jit=jit)[0].eval().requires_grad_(False).to(device) 416 | perceptors[clip_model] = perceptor 417 | 418 | cut_size = perceptor.visual.input_resolution 419 | cutoutSizeTable[clip_model] = cut_size 420 | if not cut_size in cutoutsTable: 421 | make_cutouts = MakeCutouts(cut_size, args.num_cuts, cut_pow=args.cut_pow) 422 | cutoutsTable[cut_size] = make_cutouts 423 | 424 | init_image_tensor = None 425 | target_image_tensor = None 426 | 427 | # Image initialisation 428 | if args.init_image or args.init_noise: 429 | # setup init image wih pil 430 | # first - always start with noise or blank 431 | if args.init_noise == 'pixels': 432 | img = random_noise_image(args.size[0], args.size[1]) 433 | elif args.init_noise == 'gradient': 434 | img = random_gradient_image(args.size[0], args.size[1]) 435 | elif args.init_noise == 'snow': 436 | img = old_random_noise_image(args.size[0], args.size[1]) 437 | else: 438 | img = Image.new(mode="RGB", size=(args.size[0], args.size[1]), color=(255, 255, 255)) 439 | starting_image = img.convert('RGB') 440 | starting_image = starting_image.resize((sideX, sideY), Image.LANCZOS) 441 | 442 | if args.init_image: 443 | # now we might overlay an init image (init_image also can be recycled as overlay) 444 | if 'http' in args.init_image: 445 | init_image = Image.open(urlopen(args.init_image)) 446 | else: 447 | init_image = Image.open(args.init_image) 448 | # this version is needed potentially for the loss function 449 | init_image_rgb = init_image.convert('RGB') 450 | init_image_rgb = init_image_rgb.resize((sideX, sideY), Image.LANCZOS) 451 | init_image_tensor = TF.to_tensor(init_image_rgb) 452 | init_image_tensor = init_image_tensor.to(device).unsqueeze(0) 453 | 454 | # this version gets overlaid on the background (noise) 455 | init_image_rgba = init_image.convert('RGBA') 456 | init_image_rgba = init_image_rgba.resize((sideX, sideY), Image.LANCZOS) 457 | top_image = init_image_rgba.copy() 458 | if args.init_image_alpha and args.init_image_alpha >= 0: 459 | top_image.putalpha(args.init_image_alpha) 460 | starting_image.paste(top_image, (0, 0), top_image) 461 | 462 | starting_image.save("starting_image.png") 463 | starting_tensor = TF.to_tensor(starting_image) 464 | init_tensor = starting_tensor.to(device).unsqueeze(0) * 2 - 1 465 | drawer.init_from_tensor(init_tensor) 466 | 467 | else: 468 | # untested 469 | drawer.rand_init(toksX, toksY) 470 | 471 | if args.overlay_every: 472 | if args.overlay_image: 473 | if 'http' in args.overlay_image: 474 | overlay_image = Image.open(urlopen(args.overlay_image)) 475 | else: 476 | overlay_image = 
Image.open(args.overlay_image) 477 | overlay_image_rgba = overlay_image.convert('RGBA') 478 | overlay_image_rgba = overlay_image_rgba.resize((sideX, sideY), Image.LANCZOS) 479 | else: 480 | overlay_image_rgba = init_image_rgba 481 | if args.overlay_alpha: 482 | overlay_image_rgba.putalpha(args.overlay_alpha) 483 | overlay_image_rgba.save('overlay_image.png') 484 | 485 | if args.target_images is not None: 486 | z_targets = [] 487 | filelist = real_glob(args.target_images) 488 | for target_image in filelist: 489 | target_image = Image.open(target_image) 490 | target_image_rgb = target_image.convert('RGB') 491 | target_image_rgb = target_image_rgb.resize((sideX, sideY), Image.LANCZOS) 492 | target_image_tensor_local = TF.to_tensor(target_image_rgb) 493 | target_image_tensor = target_image_tensor_local.to(device).unsqueeze(0) * 2 - 1 494 | z_target = drawer.get_z_from_tensor(target_image_tensor) 495 | z_targets.append(z_target) 496 | 497 | if args.image_labels is not None: 498 | z_labels = [] 499 | filelist = real_glob(args.image_labels) 500 | cur_labels = [] 501 | for image_label in filelist: 502 | image_label = Image.open(image_label) 503 | image_label_rgb = image_label.convert('RGB') 504 | image_label_rgb = image_label_rgb.resize((sideX, sideY), Image.LANCZOS) 505 | image_label_rgb_tensor = TF.to_tensor(image_label_rgb) 506 | image_label_rgb_tensor = image_label_rgb_tensor.to(device).unsqueeze(0) * 2 - 1 507 | z_label = drawer.get_z_from_tensor(image_label_rgb_tensor) 508 | cur_labels.append(z_label) 509 | image_embeddings = torch.stack(cur_labels) 510 | print("Processing labels: ", image_embeddings.shape) 511 | image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True) 512 | image_embeddings = image_embeddings.mean(dim=0) 513 | image_embeddings /= image_embeddings.norm() 514 | z_labels.append(image_embeddings.unsqueeze(0)) 515 | 516 | z_orig = drawer.get_z_copy() 517 | 518 | pmsTable = {} 519 | pmsImageTable = {} 520 | spotPmsTable = {} 521 | spotOffPmsTable = {} 522 | for clip_model in args.clip_models: 523 | pmsTable[clip_model] = [] 524 | pmsImageTable[clip_model] = [] 525 | spotPmsTable[clip_model] = [] 526 | spotOffPmsTable[clip_model] = [] 527 | normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], 528 | std=[0.26862954, 0.26130258, 0.27577711]) 529 | 530 | # CLIP tokenize/encode 531 | # NR: Weights / blending 532 | for prompt in args.prompts: 533 | for clip_model in args.clip_models: 534 | pMs = pmsTable[clip_model] 535 | perceptor = perceptors[clip_model] 536 | txt, weight, stop = parse_prompt(prompt) 537 | embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float() 538 | pMs.append(Prompt(embed, weight, stop).to(device)) 539 | 540 | for prompt in args.spot_prompts: 541 | for clip_model in args.clip_models: 542 | pMs = spotPmsTable[clip_model] 543 | perceptor = perceptors[clip_model] 544 | txt, weight, stop = parse_prompt(prompt) 545 | embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float() 546 | pMs.append(Prompt(embed, weight, stop).to(device)) 547 | 548 | for prompt in args.spot_prompts_off: 549 | for clip_model in args.clip_models: 550 | pMs = spotOffPmsTable[clip_model] 551 | perceptor = perceptors[clip_model] 552 | txt, weight, stop = parse_prompt(prompt) 553 | embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float() 554 | pMs.append(Prompt(embed, weight, stop).to(device)) 555 | 556 | for label in args.labels: 557 | for clip_model in args.clip_models: 558 | pMs = pmsTable[clip_model] 559 | perceptor = 
perceptors[clip_model] 560 | txt, weight, stop = parse_prompt(label) 561 | texts = [template.format(txt) for template in imagenet_templates] #format with class 562 | print(f"Tokenizing all of {texts}") 563 | texts = clip.tokenize(texts).to(device) #tokenize 564 | class_embeddings = perceptor.encode_text(texts) #embed with text encoder 565 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) 566 | class_embedding = class_embeddings.mean(dim=0) 567 | class_embedding /= class_embedding.norm() 568 | pMs.append(Prompt(class_embedding.unsqueeze(0), weight, stop).to(device)) 569 | 570 | for clip_model in args.clip_models: 571 | pImages = pmsImageTable[clip_model] 572 | for prompt in args.image_prompts: 573 | path, weight, stop = parse_prompt(prompt) 574 | img = Image.open(path) 575 | pil_image = img.convert('RGB') 576 | img = resize_image(pil_image, (sideX, sideY)) 577 | pImages.append(TF.to_tensor(img).unsqueeze(0).to(device)) 578 | 579 | for seed, weight in zip(args.noise_prompt_seeds, args.noise_prompt_weights): 580 | gen = torch.Generator().manual_seed(seed) 581 | embed = torch.empty([1, perceptor.visual.output_dim]).normal_(generator=gen) 582 | pMs.append(Prompt(embed, weight).to(device)) 583 | 584 | opts = drawer.get_opts() 585 | if opts == None: 586 | # legacy 587 | 588 | # Set the optimiser 589 | z = drawer.get_z(); 590 | if args.optimiser == "Adam": 591 | opt = optim.Adam([z], lr=args.learning_rate) # LR=0.1 592 | elif args.optimiser == "AdamW": 593 | opt = optim.AdamW([z], lr=args.learning_rate) # LR=0.2 594 | elif args.optimiser == "Adagrad": 595 | opt = optim.Adagrad([z], lr=args.learning_rate) # LR=0.5+ 596 | elif args.optimiser == "Adamax": 597 | opt = optim.Adamax([z], lr=args.learning_rate) # LR=0.5+? 598 | elif args.optimiser == "DiffGrad": 599 | opt = DiffGrad([z], lr=args.learning_rate) # LR=2+? 600 | elif args.optimiser == "AdamP": 601 | opt = AdamP([z], lr=args.learning_rate) # LR=2+? 602 | elif args.optimiser == "RAdam": 603 | opt = RAdam([z], lr=args.learning_rate) # LR=2+? 
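        # Sketch (illustrative only, not part of the original flow): the if/elif
        # chain above could equivalently be written as a lookup table, using only
        # classes already imported at the top of this file:
        #
        #   OPTIMISERS = {"Adam": optim.Adam, "AdamW": optim.AdamW,
        #                 "Adagrad": optim.Adagrad, "Adamax": optim.Adamax,
        #                 "DiffGrad": DiffGrad, "AdamP": AdamP, "RAdam": RAdam}
        #   opt = OPTIMISERS[args.optimiser]([z], lr=args.learning_rate)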
604 | 605 | opts = [opt] 606 | 607 | # Output for the user 608 | print('Using device:', device) 609 | print('Optimising using:', args.optimiser) 610 | 611 | if args.prompts: 612 | print('Using text prompts:', args.prompts) 613 | if args.spot_prompts: 614 | print('Using spot prompts:', args.spot_prompts) 615 | if args.spot_prompts_off: 616 | print('Using spot off prompts:', args.spot_prompts_off) 617 | if args.image_prompts: 618 | print('Using image prompts:', args.image_prompts) 619 | if args.init_image: 620 | print('Using initial image:', args.init_image) 621 | if args.noise_prompt_weights: 622 | print('Noise prompt weights:', args.noise_prompt_weights) 623 | 624 | 625 | if args.seed is None: 626 | seed = torch.seed() 627 | else: 628 | seed = args.seed 629 | torch.manual_seed(seed) 630 | print('Using seed:', seed) 631 | 632 | 633 | # dreaded globals (for now) 634 | z_orig = None 635 | z_targets = None 636 | z_labels = None 637 | opts = None 638 | drawer = None 639 | perceptors = {} 640 | normalize = None 641 | cutoutsTable = {} 642 | cutoutSizeTable = {} 643 | init_image_tensor = None 644 | target_image_tensor = None 645 | pmsTable = None 646 | spotPmsTable = None 647 | spotOffPmsTable = None 648 | pmsImageTable = None 649 | gside_X=None 650 | gside_Y=None 651 | overlay_image_rgba=None 652 | device=None 653 | cur_iteration=None 654 | cur_anim_index=None 655 | anim_output_files=[] 656 | anim_cur_zs=[] 657 | anim_next_zs=[] 658 | 659 | def make_gif(args, iter): 660 | gif_output = os.path.join(args.animation_dir, "anim.gif") 661 | if os.path.exists(gif_output): 662 | os.remove(gif_output) 663 | cmd = ['ffmpeg', '-framerate', '10', '-pattern_type', 'glob', 664 | '-i', f"{args.animation_dir}/*.png", '-loop', '0', gif_output] 665 | try: 666 | output = subprocess.check_output(cmd) 667 | except subprocess.CalledProcessError as cpe: 668 | output = cpe.output 669 | print("Ignoring non-zero exit: ", output) 670 | 671 | return gif_output 672 | 673 | # !ffmpeg \ 674 | # -framerate 10 -pattern_type glob \ 675 | # -i '{animation_output}/*_*.png' \ 676 | # -loop 0 {animation_output}/final.gif 677 | 678 | @torch.no_grad() 679 | def checkin(args, iter, losses): 680 | global drawer 681 | losses_str = ', '.join(f'{loss.item():g}' for loss in losses) 682 | writestr = f'iter: {iter}, loss: {sum(losses).item():g}, losses: {losses_str}' 683 | if args.animation_dir is not None: 684 | writestr = f'anim: {cur_anim_index}/{len(anim_output_files)} {writestr}' 685 | tqdm.write(writestr) 686 | info = PngImagePlugin.PngInfo() 687 | info.add_text('comment', f'{args.prompts}') 688 | img = drawer.to_image() 689 | if cur_anim_index is None: 690 | outfile = args.output 691 | else: 692 | outfile = anim_output_files[cur_anim_index] 693 | img.save(outfile, pnginfo=info) 694 | if cur_anim_index == len(anim_output_files) - 1: 695 | # save gif 696 | gif_output = make_gif(args, iter) 697 | if IS_NOTEBOOK and iter % args.display_every == 0: 698 | clear_output() 699 | display.display(display.Image(open(gif_output,'rb').read())) 700 | if IS_NOTEBOOK and iter % args.display_every == 0: 701 | if cur_anim_index is None or iter == 0: 702 | display.display(display.Image(outfile)) 703 | 704 | def ascend_txt(args): 705 | global cur_iteration, cur_anim_index, perceptors, normalize, cutoutsTable, cutoutSizeTable 706 | global z_orig, z_targets, z_labels, init_image_tensor, target_image_tensor, drawer 707 | global pmsTable, pmsImageTable, spotPmsTable, spotOffPmsTable, global_padding_mode 708 | 709 | out = drawer.synth(cur_iteration); 710 | 711 
| result = [] 712 | 713 | if (cur_iteration%2 == 0): 714 | global_padding_mode = 'reflection' 715 | else: 716 | global_padding_mode = 'border' 717 | 718 | cur_cutouts = {} 719 | cur_spot_cutouts = {} 720 | cur_spot_off_cutouts = {} 721 | for cutoutSize in cutoutsTable: 722 | make_cutouts = cutoutsTable[cutoutSize] 723 | cur_cutouts[cutoutSize] = make_cutouts(out) 724 | 725 | if args.spot_prompts: 726 | for cutoutSize in cutoutsTable: 727 | cur_spot_cutouts[cutoutSize] = make_cutouts(out, spot=1) 728 | 729 | if args.spot_prompts_off: 730 | for cutoutSize in cutoutsTable: 731 | cur_spot_off_cutouts[cutoutSize] = make_cutouts(out, spot=0) 732 | 733 | for clip_model in args.clip_models: 734 | perceptor = perceptors[clip_model] 735 | cutoutSize = cutoutSizeTable[clip_model] 736 | transient_pMs = [] 737 | 738 | if args.spot_prompts: 739 | iii_s = perceptor.encode_image(normalize( cur_spot_cutouts[cutoutSize] )).float() 740 | spotPms = spotPmsTable[clip_model] 741 | for prompt in spotPms: 742 | result.append(prompt(iii_s)) 743 | 744 | if args.spot_prompts_off: 745 | iii_so = perceptor.encode_image(normalize( cur_spot_off_cutouts[cutoutSize] )).float() 746 | spotOffPms = spotOffPmsTable[clip_model] 747 | for prompt in spotOffPms: 748 | result.append(prompt(iii_so)) 749 | 750 | pMs = pmsTable[clip_model] 751 | iii = perceptor.encode_image(normalize( cur_cutouts[cutoutSize] )).float() 752 | for prompt in pMs: 753 | result.append(prompt(iii)) 754 | 755 | # If there are image prompts we make cutouts for those each time 756 | # so that they line up with the current cutouts from augmentation 757 | make_cutouts = cutoutsTable[cutoutSize] 758 | pImages = pmsImageTable[clip_model] 759 | for timg in pImages: 760 | # note: this caches and reuses the transforms - a bit of a hack but it works 761 | 762 | if args.image_prompt_shuffle: 763 | # print("Disabling cached transforms") 764 | make_cutouts.transforms = None 765 | 766 | # print("Building throwaway image prompts") 767 | # new way builds throwaway Prompts 768 | batch = make_cutouts(timg) 769 | embed = perceptor.encode_image(normalize(batch)).float() 770 | if args.image_prompt_weight is not None: 771 | transient_pMs.append(Prompt(embed, args.image_prompt_weight).to(device)) 772 | else: 773 | transient_pMs.append(Prompt(embed).to(device)) 774 | 775 | for prompt in transient_pMs: 776 | result.append(prompt(iii)) 777 | 778 | if args.enforce_palette_annealing and args.target_palette: 779 | target_palette = torch.FloatTensor(args.target_palette).requires_grad_(False).to(device) 780 | _pixels = cur_cutouts[cutoutSize].permute(0,2,3,1).reshape(-1,3) 781 | palette_dists = torch.cdist(target_palette, _pixels, p=2) 782 | best_guesses = palette_dists.argmin(axis=0) 783 | diffs = _pixels - target_palette[best_guesses] 784 | palette_loss = torch.mean( torch.norm( diffs, 2, dim=1 ) )*cur_cutouts[cutoutSize].shape[0] 785 | result.append( palette_loss*cur_iteration/args.enforce_palette_annealing ) 786 | 787 | if args.enforce_smoothness and args.enforce_smoothness_type: 788 | _pixels = cur_cutouts[cutoutSize].permute(0,2,3,1).reshape(-1,cur_cutouts[cutoutSize].shape[2],3) 789 | gyr, gxr = torch.gradient(_pixels[:,:,0]) 790 | gyg, gxg = torch.gradient(_pixels[:,:,1]) 791 | gyb, gxb = torch.gradient(_pixels[:,:,2]) 792 | sharpness = torch.sqrt(gyr**2 + gxr**2+ gyg**2 + gxg**2 + gyb**2 + gxb**2) 793 | if args.enforce_smoothness_type=='clipped': 794 | sharpness = torch.clamp( sharpness, max=0.5 ) 795 | elif args.enforce_smoothness_type=='log': 796 | sharpness = torch.log( 
torch.ones_like(sharpness)+sharpness ) 797 | sharpness = torch.mean( sharpness ) 798 | 799 | result.append( sharpness*cur_iteration/args.enforce_smoothness ) 800 | 801 | if args.enforce_saturation: 802 | # based on the old "percepted colourfulness" heuristic from Hasler and Süsstrunk’s 2003 paper 803 | # https://www.researchgate.net/publication/243135534_Measuring_Colourfulness_in_Natural_Images 804 | _pixels = cur_cutouts[cutoutSize].permute(0,2,3,1).reshape(-1,3) 805 | rg = _pixels[:,0]-_pixels[:,1] 806 | yb = 0.5*(_pixels[:,0]+_pixels[:,1])-_pixels[:,2] 807 | rg_std, rg_mean = torch.std_mean(rg) 808 | yb_std, yb_mean = torch.std_mean(yb) 809 | std_rggb = torch.sqrt(rg_std**2 + yb_std**2) 810 | mean_rggb = torch.sqrt(rg_mean**2 + yb_mean**2) 811 | colorfullness = std_rggb+.3*mean_rggb 812 | 813 | result.append( -colorfullness*cur_iteration/args.enforce_saturation ) 814 | 815 | for cutoutSize in cutoutsTable: 816 | # clear the transform "cache" 817 | make_cutouts = cutoutsTable[cutoutSize] 818 | make_cutouts.transforms = None 819 | 820 | # main init_weight uses spherical loss 821 | if args.target_images is not None and args.target_image_weight > 0: 822 | if cur_anim_index is None: 823 | cur_z_targets = z_targets 824 | else: 825 | cur_z_targets = [ z_targets[cur_anim_index] ] 826 | for z_target in cur_z_targets: 827 | f = drawer.get_z().reshape(1,-1) 828 | f2 = z_target.reshape(1,-1) 829 | cur_loss = spherical_dist_loss(f, f2) * args.target_image_weight 830 | result.append(cur_loss) 831 | 832 | if args.target_weight_pix: 833 | if target_image_tensor is None: 834 | print("OOPS TIT is 0") 835 | else: 836 | cur_loss = F.l1_loss(out, target_image_tensor) * args.target_weight_pix 837 | result.append(cur_loss) 838 | 839 | if args.image_labels is not None: 840 | for z_label in z_labels: 841 | f = drawer.get_z().reshape(1,-1) 842 | f2 = z_label.reshape(1,-1) 843 | cur_loss = spherical_dist_loss(f, f2) * args.image_label_weight 844 | result.append(cur_loss) 845 | 846 | # main init_weight uses spherical loss 847 | if args.init_weight: 848 | f = drawer.get_z().reshape(1,-1) 849 | f2 = z_orig.reshape(1,-1) 850 | cur_loss = spherical_dist_loss(f, f2) * args.init_weight 851 | result.append(cur_loss) 852 | 853 | # these three init_weight variants offer mse_loss, mse_loss in pixel space, and cos loss 854 | if args.init_weight_dist: 855 | cur_loss = F.mse_loss(z, z_orig) * args.init_weight_dist / 2 856 | result.append(cur_loss) 857 | 858 | if args.init_weight_pix: 859 | if init_image_tensor is None: 860 | print("OOPS IIT is 0") 861 | else: 862 | cur_loss = F.l1_loss(out, init_image_tensor) * args.init_weight_pix / 2 863 | result.append(cur_loss) 864 | 865 | if args.init_weight_cos: 866 | f = drawer.get_z().reshape(1,-1) 867 | f2 = z_orig.reshape(1,-1) 868 | y = torch.ones_like(f[0]) 869 | cur_loss = F.cosine_embedding_loss(f, f2, y) * args.init_weight_cos 870 | result.append(cur_loss) 871 | 872 | if args.make_video: 873 | img = np.array(out.mul(255).clamp(0, 255)[0].cpu().detach().numpy().astype(np.uint8))[:,:,:] 874 | img = np.transpose(img, (1, 2, 0)) 875 | imageio.imwrite(f'./steps/frame_{cur_iteration:04d}.png', np.array(img)) 876 | 877 | return result 878 | 879 | def re_average_z(args): 880 | global gside_X, gside_Y 881 | global device, drawer 882 | 883 | # old_z = z.clone() 884 | cur_z_image = drawer.to_image() 885 | cur_z_image = cur_z_image.convert('RGB') 886 | if overlay_image_rgba: 887 | # print("applying overlay image") 888 | cur_z_image.paste(overlay_image_rgba, (0, 0), overlay_image_rgba) 889 
| cur_z_image.save("overlaid.png") 890 | cur_z_image = cur_z_image.resize((gside_X, gside_Y), Image.LANCZOS) 891 | drawer.reapply_from_tensor(TF.to_tensor(cur_z_image).to(device).unsqueeze(0) * 2 - 1) 892 | 893 | # torch.autograd.set_detect_anomaly(True) 894 | 895 | def train(args, cur_it): 896 | global drawer; 897 | for opt in opts: 898 | # opt.zero_grad(set_to_none=True) 899 | opt.zero_grad() 900 | 901 | for i in range(args.batches): 902 | lossAll = ascend_txt(args) 903 | 904 | if i == 0 and cur_it % args.save_every == 0: 905 | checkin(args, cur_it, lossAll) 906 | 907 | loss = sum(lossAll) 908 | loss.backward() 909 | 910 | for opt in opts: 911 | opt.step() 912 | 913 | if args.overlay_every and cur_it != 0 and \ 914 | (cur_it % (args.overlay_every + args.overlay_offset)) == 0: 915 | re_average_z(args) 916 | 917 | drawer.clip_z() 918 | 919 | imagenet_templates = [ 920 | "itap of a {}.", 921 | "a bad photo of the {}.", 922 | "a origami {}.", 923 | "a photo of the large {}.", 924 | "a {} in a video game.", 925 | "art of the {}.", 926 | "a photo of the small {}.", 927 | ] 928 | 929 | def do_run(args): 930 | global cur_iteration, cur_anim_index 931 | global anim_cur_zs, anim_next_zs, anim_output_files 932 | 933 | cur_iteration = 0 934 | 935 | if args.animation_dir is not None: 936 | # we already have z_targets. setup some sort of global ring 937 | # we need something like 938 | # copies of all the current z's (they can all start off all as copies) 939 | # a list of all the output filenames 940 | # 941 | if not os.path.exists(args.animation_dir): 942 | os.mkdir(args.animation_dir) 943 | filelist = real_glob(args.target_images) 944 | num_anim_frames = len(filelist) 945 | for target_image in filelist: 946 | basename = os.path.basename(target_image) 947 | target_output = os.path.join(args.animation_dir, basename) 948 | anim_output_files.append(target_output) 949 | for i in range(num_anim_frames): 950 | cur_z = drawer.get_z_copy() 951 | anim_cur_zs.append(cur_z) 952 | anim_next_zs.append(None) 953 | 954 | step_iteration = 0 955 | 956 | with tqdm() as pbar: 957 | while True: 958 | cur_images = [] 959 | for i in range(num_anim_frames): 960 | # do merge frames here from cur->next when we are ready to be fancy 961 | cur_anim_index = i 962 | # anim_cur_zs[cur_anim_index] = anim_next_zs[cur_anim_index] 963 | cur_iteration = step_iteration 964 | drawer.set_z(anim_cur_zs[cur_anim_index]) 965 | for j in range(args.save_every): 966 | train(args, cur_iteration) 967 | cur_iteration += 1 968 | pbar.update() 969 | # anim_next_zs[cur_anim_index] = drawer.get_z_copy() 970 | cur_images.append(drawer.to_image()) 971 | step_iteration = step_iteration + args.save_every 972 | if step_iteration >= args.iterations: 973 | break 974 | # compute the next round of cur_zs here from all the next_zs 975 | for i in range(num_anim_frames): 976 | prev_i = (i + num_anim_frames - 1) % num_anim_frames 977 | base_image = cur_images[i].copy() 978 | prev_image = cur_images[prev_i].copy().convert('RGBA') 979 | prev_image.putalpha(args.animation_alpha) 980 | base_image.paste(prev_image, (0, 0), prev_image) 981 | # base_image.save(f"overlaid_{i:02d}.png") 982 | drawer.reapply_from_tensor(TF.to_tensor(base_image).to(device).unsqueeze(0) * 2 - 1) 983 | anim_cur_zs[i] = drawer.get_z_copy() 984 | else: 985 | try: 986 | with tqdm() as pbar: 987 | while True: 988 | try: 989 | train(args, cur_iteration) 990 | if cur_iteration == args.iterations: 991 | break 992 | cur_iteration += 1 993 | pbar.update() 994 | except RuntimeError as e: 995 | 
print("Oops: runtime error: ", e) 996 | print("Try reducing --num-cuts to save memory") 997 | raise e 998 | except KeyboardInterrupt: 999 | pass 1000 | 1001 | if args.make_video: 1002 | do_video(args) 1003 | 1004 | def do_video(args): 1005 | global cur_iteration 1006 | 1007 | # Video generation 1008 | init_frame = 1 # This is the frame where the video will start 1009 | last_frame = cur_iteration # You can change to the number of the last frame you want to generate. It will raise an error if that number of frames does not exist. 1010 | 1011 | min_fps = 10 1012 | max_fps = 60 1013 | 1014 | total_frames = last_frame-init_frame 1015 | 1016 | length = 15 # Desired time of the video in seconds 1017 | 1018 | frames = [] 1019 | tqdm.write('Generating video...') 1020 | for i in range(init_frame,last_frame): # 1021 | frames.append(Image.open(f'./steps/frame_{i:04d}.png')) 1022 | 1023 | #fps = last_frame/10 1024 | fps = np.clip(total_frames/length,min_fps,max_fps) 1025 | 1026 | from subprocess import Popen, PIPE 1027 | import re 1028 | output_file = re.compile('\.png$').sub('.mp4', args.output) 1029 | p = Popen(['ffmpeg', 1030 | '-y', 1031 | '-f', 'image2pipe', 1032 | '-vcodec', 'png', 1033 | '-r', str(fps), 1034 | '-i', 1035 | '-', 1036 | '-vcodec', 'libx264', 1037 | '-r', str(fps), 1038 | '-pix_fmt', 'yuv420p', 1039 | '-crf', '17', 1040 | '-preset', 'veryslow', 1041 | '-metadata', f'comment={args.prompts}', 1042 | output_file], stdin=PIPE) 1043 | for im in tqdm(frames): 1044 | im.save(p.stdin, 'PNG') 1045 | p.stdin.close() 1046 | p.wait() 1047 | 1048 | # this dictionary is used for settings in the notebook 1049 | global_clipit_settings = {} 1050 | 1051 | def setup_parser(): 1052 | # Create the parser 1053 | vq_parser = argparse.ArgumentParser(description='Image generation using VQGAN+CLIP') 1054 | 1055 | # Add the arguments 1056 | vq_parser.add_argument("-p", "--prompts", type=str, help="Text prompts", default=[], dest='prompts') 1057 | vq_parser.add_argument("-sp", "--spot", type=str, help="Spot Text prompts", default=[], dest='spot_prompts') 1058 | vq_parser.add_argument("-spo", "--spot_off", type=str, help="Spot off Text prompts", default=[], dest='spot_prompts_off') 1059 | vq_parser.add_argument("-spf", "--spot_file", type=str, help="Custom spot file", default=None, dest='spot_file') 1060 | vq_parser.add_argument("-l", "--labels", type=str, help="ImageNet labels", default=[], dest='labels') 1061 | vq_parser.add_argument("-ip", "--image_prompts", type=str, help="Image prompts", default=[], dest='image_prompts') 1062 | vq_parser.add_argument("-ipw", "--image_prompt_weight", type=float, help="Weight for image prompt", default=None, dest='image_prompt_weight') 1063 | vq_parser.add_argument("-ips", "--image_prompt_shuffle", type=bool, help="Shuffle image prompts", default=False, dest='image_prompt_shuffle') 1064 | vq_parser.add_argument("-il", "--image_labels", type=str, help="Image prompts", default=None, dest='image_labels') 1065 | vq_parser.add_argument("-ilw", "--image_label_weight", type=float, help="Weight for image prompt", default=1.0, dest='image_label_weight') 1066 | vq_parser.add_argument("-i", "--iterations", type=int, help="Number of iterations", default=None, dest='iterations') 1067 | vq_parser.add_argument("-se", "--save_every", type=int, help="Save image iterations", default=10, dest='save_every') 1068 | vq_parser.add_argument("-de", "--display_every", type=int, help="Display image iterations", default=20, dest='display_every') 1069 | vq_parser.add_argument("-ove", "--overlay_every", 
type=int, help="Overlay image iterations", default=None, dest='overlay_every') 1070 | vq_parser.add_argument("-ovo", "--overlay_offset", type=int, help="Overlay image iteration offset", default=0, dest='overlay_offset') 1071 | vq_parser.add_argument("-ovi", "--overlay_image", type=str, help="Overlay image (if not init)", default=None, dest='overlay_image') 1072 | vq_parser.add_argument("-qua", "--quality", type=str, help="draft, normal, best", default="normal", dest='quality') 1073 | vq_parser.add_argument("-asp", "--aspect", type=str, help="widescreen, square", default="widescreen", dest='aspect') 1074 | vq_parser.add_argument("-ezs", "--ezsize", type=str, help="small, medium, large", default=None, dest='ezsize') 1075 | vq_parser.add_argument("-sca", "--scale", type=float, help="scale (instead of ezsize)", default=None, dest='scale') 1076 | vq_parser.add_argument("-ova", "--overlay_alpha", type=int, help="Overlay alpha (0-255)", default=None, dest='overlay_alpha') 1077 | vq_parser.add_argument("-s", "--size", nargs=2, type=int, help="Image size (width height)", default=None, dest='size') 1078 | vq_parser.add_argument("-ps", "--pixel_size", nargs=2, type=int, help="Pixel size (width height)", default=None, dest='pixel_size') 1079 | vq_parser.add_argument("-psc", "--pixel_scale", type=float, help="Pixel scale", default=None, dest='pixel_scale') 1080 | vq_parser.add_argument("-ii", "--init_image", type=str, help="Initial image", default=None, dest='init_image') 1081 | vq_parser.add_argument("-iia", "--init_image_alpha", type=int, help="Init image alpha (0-255)", default=200, dest='init_image_alpha') 1082 | vq_parser.add_argument("-in", "--init_noise", type=str, help="Initial noise image (pixels or gradient)", default="pixels", dest='init_noise') 1083 | vq_parser.add_argument("-ti", "--target_images", type=str, help="Target images", default=None, dest='target_images') 1084 | vq_parser.add_argument("-tiw", "--target_image_weight", type=float, help="Target images weight", default=1.0, dest='target_image_weight') 1085 | vq_parser.add_argument("-twp", "--target_weight_pix", type=float, help="Target weight pix loss", default=0., dest='target_weight_pix') 1086 | vq_parser.add_argument("-anim", "--animation_dir", type=str, help="Animation output dir", default=None, dest='animation_dir') 1087 | vq_parser.add_argument("-ana", "--animation_alpha", type=int, help="Forward blend for consistency", default=128, dest='animation_alpha') 1088 | vq_parser.add_argument("-iw", "--init_weight", type=float, help="Initial weight (main=spherical)", default=None, dest='init_weight') 1089 | vq_parser.add_argument("-iwd", "--init_weight_dist", type=float, help="Initial weight dist loss", default=0., dest='init_weight_dist') 1090 | vq_parser.add_argument("-iwc", "--init_weight_cos", type=float, help="Initial weight cos loss", default=0., dest='init_weight_cos') 1091 | vq_parser.add_argument("-iwp", "--init_weight_pix", type=float, help="Initial weight pix loss", default=0., dest='init_weight_pix') 1092 | vq_parser.add_argument("-m", "--clip_models", type=str, help="CLIP model", default=None, dest='clip_models') 1093 | vq_parser.add_argument("-vqgan", "--vqgan_model", type=str, help="VQGAN model", default='imagenet_f16_16384', dest='vqgan_model') 1094 | vq_parser.add_argument("-conf", "--vqgan_config", type=str, help="VQGAN config", default=None, dest='vqgan_config') 1095 | vq_parser.add_argument("-ckpt", "--vqgan_checkpoint", type=str, help="VQGAN checkpoint", default=None, dest='vqgan_checkpoint') 1096 | 
vq_parser.add_argument("-nps", "--noise_prompt_seeds", nargs="*", type=int, help="Noise prompt seeds", default=[], dest='noise_prompt_seeds') 1097 | vq_parser.add_argument("-npw", "--noise_prompt_weights", nargs="*", type=float, help="Noise prompt weights", default=[], dest='noise_prompt_weights') 1098 | vq_parser.add_argument("-lr", "--learning_rate", type=float, help="Learning rate", default=0.2, dest='learning_rate') 1099 | vq_parser.add_argument("-cuts", "--num_cuts", type=int, help="Number of cuts", default=None, dest='num_cuts') 1100 | vq_parser.add_argument("-bats", "--batches", type=int, help="How many batches of cuts", default=1, dest='batches') 1101 | vq_parser.add_argument("-cutp", "--cut_power", type=float, help="Cut power", default=1., dest='cut_pow') 1102 | vq_parser.add_argument("-sd", "--seed", type=int, help="Seed", default=None, dest='seed') 1103 | vq_parser.add_argument("-opt", "--optimiser", type=str, help="Optimiser (Adam, AdamW, Adagrad, Adamax, DiffGrad, AdamP or RAdam)", default='Adam', dest='optimiser') 1104 | vq_parser.add_argument("-o", "--output", type=str, help="Output file", default="output.png", dest='output') 1105 | vq_parser.add_argument("-vid", "--video", type=bool, help="Create video frames?", default=False, dest='make_video') 1106 | vq_parser.add_argument("-d", "--deterministic", type=bool, help="Enable cudnn.deterministic?", default=False, dest='cudnn_determinism') 1107 | vq_parser.add_argument("-cd", "--use_clipdraw", type=bool, help="Use clipdraw", default=False, dest='use_clipdraw') 1108 | vq_parser.add_argument("-st", "--strokes", type=int, help="clipdraw strokes", default=1024, dest='strokes') 1109 | vq_parser.add_argument("-pd", "--use_pixeldraw", type=bool, help="Use pixeldraw", default=False, dest='use_pixeldraw') 1110 | vq_parser.add_argument("-mo", "--do_mono", type=bool, help="Monochromatic", default=False, dest='do_mono') 1111 | vq_parser.add_argument("-epw", "--enforce_palette_annealing", type=int, help="enforce palette annealing, 0 -- skip", default=5000, dest='enforce_palette_annealing') 1112 | vq_parser.add_argument("-tp", "--target_palette", type=str, help="target palette", default=None, dest='target_palette') 1113 | vq_parser.add_argument("-esw", "--enforce_smoothness", type=int, help="enforce smoothness, 0 -- skip", default=0, dest='enforce_smoothness') 1114 | vq_parser.add_argument("-est", "--enforce_smoothness_type", type=str, help="enforce smoothness type: default/clipped/log", default='default', dest='enforce_smoothness_type') 1115 | vq_parser.add_argument("-ecw", "--enforce_saturation", type=int, help="enforce saturation, 0 -- skip", default=0, dest='enforce_saturation') 1116 | 1117 | return vq_parser 1118 | 1119 | square_size = [144, 144] 1120 | widescreen_size = [200, 112] # at the small size this becomes 192,112 1121 | 1122 | ####### PALETTE SECTION ########## 1123 | 1124 | # canonical interpolation function, like https://p5js.org/reference/#/p5/map 1125 | def map_number(n, start1, stop1, start2, stop2): 1126 | return ((n-start1)/(stop1-start1))*(stop2-start2)+start2; 1127 | 1128 | # here are examples of what can be parsed 1129 | # white (16 color black to white ramp) 1130 | # red (16 color black to red ramp) 1131 | # rust\8 (8 color black to rust ramp) 1132 | # red->rust (16 color red to rust ramp) 1133 | # red->#ff0000 (16 color red to yellow ramp) 1134 | # red->#ff0000\20 (20 color red to yellow ramp) 1135 | # black->red->white (16 color black/red/white ramp) 1136 | # [black, red, #ff0000] (three colors) 1137 | # 
red->white;blue->yellow (32 colors across two ramps of 16) 1138 | # red;blue;yellow (48 colors from combining 3 ramps) 1139 | # red\8;blue->yellow\8 (16 colors from combining 2 ramps) 1140 | # red->yellow;[black] (16 colors from ramp and also black) 1141 | # 1142 | # TODO: maybe foo.jpg, foo.json, foo.png, foo.asc 1143 | def get_single_rgb(s): 1144 | palette_lookups = { 1145 | "pixel_green": [0.44, 1.00, 0.53], 1146 | "pixel_orange": [1.00, 0.80, 0.20], 1147 | "pixel_blue": [0.44, 0.53, 1.00], 1148 | "pixel_red": [1.00, 0.53, 0.44], 1149 | "pixel_grayscale": [1.00, 1.00, 1.00], 1150 | } 1151 | if s in palette_lookups: 1152 | rgb = palette_lookups[s] 1153 | elif s[:4] == "mat:": 1154 | rgb = matplotlib.colors.to_rgb(s[4:]) 1155 | elif matplotlib.colors.is_color_like(f"xkcd:{s}"): 1156 | rgb = matplotlib.colors.to_rgb(f"xkcd:{s}") 1157 | else: 1158 | rgb = matplotlib.colors.to_rgb(s) 1159 | return rgb 1160 | 1161 | def expand_colors(colors, num_steps): 1162 | index_episilon = 1e-6; 1163 | pal = [] 1164 | num_colors = len(colors) 1165 | for n in range(num_steps): 1166 | cur_float_index = map_number(n, 0, num_steps-1, 0, num_colors-1) 1167 | cur_int_index = int(cur_float_index) 1168 | cur_float_offset = cur_float_index - cur_int_index 1169 | if(cur_float_offset < index_episilon or (1.0-cur_float_offset) < index_episilon): 1170 | # debug print(n, "->", cur_int_index) 1171 | pal.append(colors[cur_int_index]) 1172 | else: 1173 | # debug print(n, num_steps, num_colors, cur_float_index, cur_int_index, cur_float_offset) 1174 | rgb1 = colors[cur_int_index] 1175 | rgb2 = colors[cur_int_index+1] 1176 | r = map_number(cur_float_offset, 0, 1, rgb1[0], rgb2[0]) 1177 | g = map_number(cur_float_offset, 0, 1, rgb1[1], rgb2[1]) 1178 | b = map_number(cur_float_offset, 0, 1, rgb1[2], rgb2[2]) 1179 | pal.append([r, g, b]) 1180 | return pal 1181 | 1182 | def get_rgb_range(s): 1183 | # get the list that defines the range 1184 | if s.find('->') > 0: 1185 | parts = s.split('->') 1186 | else: 1187 | parts = ["black", s] 1188 | 1189 | # look for a number of parts at the end 1190 | if parts[-1].find('\\') > 0: 1191 | colname, steps = parts[-1].split('\\') 1192 | parts[-1] = colname 1193 | num_steps = int(steps) 1194 | else: 1195 | num_steps = 16 1196 | 1197 | colors = [get_single_rgb(s) for s in parts] 1198 | #debug print("We have colors: ", colors) 1199 | 1200 | pal = expand_colors(colors, num_steps) 1201 | return pal 1202 | 1203 | def palette_from_section(s): 1204 | s = s.strip() 1205 | if s[0] == '[': 1206 | # look for a number of parts at the end 1207 | if s.find('\\') > 0: 1208 | col_list, steps = s.split('\\') 1209 | s = col_list 1210 | num_steps = int(steps) 1211 | else: 1212 | num_steps = None 1213 | 1214 | chunks = s[1:-1].split(",") 1215 | # chunks = [s.strip().tolower() for c in chunks] 1216 | pal = [get_single_rgb(c.strip()) for c in chunks] 1217 | 1218 | if num_steps is not None: 1219 | pal = expand_colors(pal, num_steps) 1220 | 1221 | return pal 1222 | else: 1223 | return get_rgb_range(s) 1224 | 1225 | def palette_from_string(s): 1226 | s = s.strip() 1227 | pal = [] 1228 | chunks = s.split(';') 1229 | for c in chunks: 1230 | pal = pal + palette_from_section(c) 1231 | return pal 1232 | 1233 | def process_args(vq_parser, namespace=None): 1234 | global global_aspect_width 1235 | global cur_iteration, cur_anim_index, anim_output_files, anim_cur_zs, anim_next_zs; 1236 | global global_spot_file 1237 | 1238 | if namespace == None: 1239 | # command line: use ARGV to get args 1240 | args = vq_parser.parse_args() 
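# (annotation, not part of the original clipit.py) The branch below is the path used
# by apply_settings() and by the demo notebooks: parse_args(args=[], ...) parses an
# empty argument list, so sys.argv is ignored and argparse only fills in whatever
# attributes are missing from the supplied namespace, using the defaults declared
# in setup_parser().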
1241 | else: 1242 | # notebook, ignore ARGV and use dictionary instead 1243 | args = vq_parser.parse_args(args=[], namespace=namespace) 1244 | 1245 | if args.cudnn_determinism: 1246 | torch.backends.cudnn.deterministic = True 1247 | 1248 | quality_to_clip_models_table = { 1249 | 'draft': 'ViT-B/32', 1250 | 'normal': 'ViT-B/32,ViT-B/16', 1251 | 'better': 'RN50,ViT-B/32,ViT-B/16', 1252 | 'best': 'RN50x4,ViT-B/32,ViT-B/16' 1253 | } 1254 | quality_to_iterations_table = { 1255 | 'draft': 200, 1256 | 'normal': 350, 1257 | 'better': 500, 1258 | 'best': 500 1259 | } 1260 | quality_to_scale_table = { 1261 | 'draft': 1, 1262 | 'normal': 2, 1263 | 'better': 3, 1264 | 'best': 4 1265 | } 1266 | # this should be replaced with logic that does somethings 1267 | # smart based on available memory (eg: size, num_models, etc) 1268 | quality_to_num_cuts_table = { 1269 | 'draft': 40, 1270 | 'normal': 40, 1271 | 'better': 40, 1272 | 'best': 40 1273 | } 1274 | 1275 | if args.quality not in quality_to_clip_models_table: 1276 | print("Qualitfy setting not understood, aborting -> ", args.quality) 1277 | exit(1) 1278 | 1279 | if args.clip_models is None: 1280 | args.clip_models = quality_to_clip_models_table[args.quality] 1281 | if args.iterations is None: 1282 | args.iterations = quality_to_iterations_table[args.quality] 1283 | if args.num_cuts is None: 1284 | args.num_cuts = quality_to_num_cuts_table[args.quality] 1285 | if args.ezsize is None and args.scale is None: 1286 | args.scale = quality_to_scale_table[args.quality] 1287 | 1288 | size_to_scale_table = { 1289 | 'small': 1, 1290 | 'medium': 2, 1291 | 'large': 4 1292 | } 1293 | aspect_to_size_table = { 1294 | 'square': [150, 150], 1295 | 'widescreen': [200, 112] 1296 | } 1297 | 1298 | if args.size is not None: 1299 | global_aspect_width = args.size[0] / args.size[1] 1300 | elif args.aspect == "widescreen": 1301 | global_aspect_width = 16/9 1302 | else: 1303 | global_aspect_width = 1 1304 | 1305 | # determine size if not set 1306 | if args.size is None: 1307 | size_scale = args.scale 1308 | if size_scale is None: 1309 | if args.ezsize in size_to_scale_table: 1310 | size_scale = size_to_scale_table[args.ezsize] 1311 | else: 1312 | print("EZ Size not understood, aborting -> ", args.ezsize) 1313 | exit(1) 1314 | if args.aspect in aspect_to_size_table: 1315 | base_size = aspect_to_size_table[args.aspect] 1316 | base_width = int(size_scale * base_size[0]) 1317 | base_height = int(size_scale * base_size[1]) 1318 | args.size = [base_width, base_height] 1319 | else: 1320 | print("aspect not understood, aborting -> ", args.aspect) 1321 | exit(1) 1322 | 1323 | if args.init_noise.lower() == "none": 1324 | args.init_noise = None 1325 | 1326 | # Split text prompts using the pipe character 1327 | if args.prompts: 1328 | args.prompts = [phrase.strip() for phrase in args.prompts.split("|")] 1329 | 1330 | # Split text prompts using the pipe character 1331 | if args.spot_prompts: 1332 | args.spot_prompts = [phrase.strip() for phrase in args.spot_prompts.split("|")] 1333 | 1334 | # Split text prompts using the pipe character 1335 | if args.spot_prompts_off: 1336 | args.spot_prompts_off = [phrase.strip() for phrase in args.spot_prompts_off.split("|")] 1337 | 1338 | # Split text labels using the pipe character 1339 | if args.labels: 1340 | args.labels = [phrase.strip() for phrase in args.labels.split("|")] 1341 | 1342 | # Split target images using the pipe character 1343 | if args.image_prompts: 1344 | args.image_prompts = args.image_prompts.split("|") 1345 | args.image_prompts = 
[image.strip() for image in args.image_prompts] 1346 | 1347 | if args.target_palette is not None: 1348 | args.target_palette = palette_from_string(args.target_palette) 1349 | 1350 | if args.overlay_every is not None and args.overlay_every <= 0: 1351 | args.overlay_every = None 1352 | 1353 | clip_models = args.clip_models.split(",") 1354 | args.clip_models = [model.strip() for model in clip_models] 1355 | 1356 | # Make video steps directory 1357 | if args.make_video: 1358 | if not os.path.exists('steps'): 1359 | os.mkdir('steps') 1360 | 1361 | # reset global animation variables 1362 | cur_iteration=None 1363 | cur_anim_index=None 1364 | anim_output_files=[] 1365 | anim_cur_zs=[] 1366 | anim_next_zs=[] 1367 | 1368 | global_spot_file = args.spot_file 1369 | 1370 | return args 1371 | 1372 | def reset_settings(): 1373 | global global_clipit_settings 1374 | global_clipit_settings = {} 1375 | 1376 | def add_settings(**kwargs): 1377 | global global_clipit_settings 1378 | for k, v in kwargs.items(): 1379 | if v is None: 1380 | # just remove the key if it is there 1381 | global_clipit_settings.pop(k, None) 1382 | else: 1383 | global_clipit_settings[k] = v 1384 | 1385 | def apply_settings(): 1386 | global global_clipit_settings 1387 | settingsDict = None 1388 | vq_parser = setup_parser() 1389 | 1390 | if len(global_clipit_settings) > 0: 1391 | # check for any bogus entries in the settings 1392 | dests = [d.dest for d in vq_parser._actions] 1393 | for k in global_clipit_settings: 1394 | if not k in dests: 1395 | raise ValueError(f"Requested setting not found, aborting: {k}={global_clipit_settings[k]}") 1396 | 1397 | # convert dictionary to easyDict 1398 | # which can be used as an argparse namespace instead 1399 | # settingsDict = easydict.EasyDict(global_clipit_settings) 1400 | settingsDict = SimpleNamespace(**global_clipit_settings) 1401 | 1402 | settings = process_args(vq_parser, settingsDict) 1403 | return settings 1404 | 1405 | def main(): 1406 | settings = apply_settings() 1407 | do_init(settings) 1408 | do_run(settings) 1409 | 1410 | if __name__ == '__main__': 1411 | main() -------------------------------------------------------------------------------- /demos/Moar_Settings.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "VQGAN+CLIP (with overlays).ipynb\"", 7 | "private_outputs": true, 8 | "provenance": [], 9 | "collapsed_sections": [], 10 | "machine_shape": "hm", 11 | "include_colab_link": true 12 | }, 13 | "kernelspec": { 14 | "name": "python3", 15 | "display_name": "Python 3" 16 | }, 17 | "language_info": { 18 | "name": "python" 19 | }, 20 | "accelerator": "GPU" 21 | }, 22 | "cells": [ 23 | { 24 | "cell_type": "markdown", 25 | "metadata": { 26 | "id": "view-in-github", 27 | "colab_type": "text" 28 | }, 29 | "source": [ 30 | "\"Open" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": { 36 | "id": "CppIQlPhhwhs" 37 | }, 38 | "source": [ 39 | "# Generate images from text prompts with VQGAN and CLIP (z+quantize method).\n", 40 | "\n", 41 | "Originally made by Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings). 
The original BigGAN+CLIP method was by https://twitter.com/advadnoun.\n", 42 | " Added some explanations and modifications by Eleiber#8347, pooling trick by Crimeacs#8222 (https://twitter.com/EarthML1) and the GUI was made with the help of Abulafia#3734.\n", 43 | "\n", 44 | " This notebook supports [@dribnet's clipit repo](https://github.com/dribnet/clipit) which is a fork of nerdyrodent's command line version with some features such as overlay added.\n" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "metadata": { 50 | "cellView": "form", 51 | "id": "-nnPDoDFCp6n" 52 | }, 53 | "source": [ 54 | "# @title Licensed under the MIT License\n", 55 | "\n", 56 | "# Copyright (c) 2021 Katherine Crowson\n", 57 | "\n", 58 | "# Permission is hereby granted, free of charge, to any person obtaining a copy\n", 59 | "# of this software and associated documentation files (the \"Software\"), to deal\n", 60 | "# in the Software without restriction, including without limitation the rights\n", 61 | "# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n", 62 | "# copies of the Software, and to permit persons to whom the Software is\n", 63 | "# furnished to do so, subject to the following conditions:\n", 64 | "\n", 65 | "# The above copyright notice and this permission notice shall be included in\n", 66 | "# all copies or substantial portions of the Software.\n", 67 | "\n", 68 | "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", 69 | "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", 70 | "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n", 71 | "# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", 72 | "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n", 73 | "# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n", 74 | "# THE SOFTWARE.\n" 75 | ], 76 | "execution_count": null, 77 | "outputs": [] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "metadata": { 82 | "id": "TkUfzT60ZZ9q" 83 | }, 84 | "source": [ 85 | "!nvidia-smi" 86 | ], 87 | "execution_count": null, 88 | "outputs": [] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "metadata": { 93 | "id": "VA1PHoJrRiK9" 94 | }, 95 | "source": [ 96 | " # On 2021/10/08, Colab updated its default PyTorch installation to a version that causes\n", 97 | " # problems with diffvg. 
So, first thing, let's roll back to the older version:\n", 98 | " !pip install torch==1.9.0+cu102 torchvision==0.10.0+cu102 -f https://download.pytorch.org/whl/torch/ -f https://download.pytorch.org/whl/torchvision/\n", 99 | "\n", 100 | "!git clone https://github.com/openai/CLIP\n", 101 | "# !pip install taming-transformers\n", 102 | "!git clone https://github.com/CompVis/taming-transformers.git\n", 103 | "!rm -Rf clipit\n", 104 | "!git clone https://github.com/dribnet/clipit\n", 105 | "!pip install ftfy regex tqdm omegaconf pytorch-lightning\n", 106 | "!pip install kornia==0.6.1\n", 107 | "!pip install imageio-ffmpeg \n", 108 | "!pip install einops\n", 109 | "!pip install torch-optimizer\n", 110 | "!pip install easydict\n", 111 | "!pip install braceexpand\n", 112 | "!pip install git+https://github.com/pvigier/perlin-numpy\n", 113 | "!mkdir steps\n", 114 | "!wget https://user-images.githubusercontent.com/945979/126260797-adc60317-9518-40de-8700-b1f93e81e0ec.png -O this_is_fine.png\n", 115 | "!wget https://user-images.githubusercontent.com/945979/126415385-d70ff2b0-f021-4238-9621-6180d33b242c.jpg -O perfume.jpg" 116 | ], 117 | "execution_count": null, 118 | "outputs": [] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "metadata": { 123 | "id": "nTg77tNuF7Og" 124 | }, 125 | "source": [ 126 | "By default, the notebook downloads the 1024 and 16384 models from ImageNet. There are others like COCO-Stuff, WikiArt or S-FLCKR, which are heavy, and if you are not going to use them it would be useless to download them, so if you want to use them, simply remove the numerals at the beginning of the lines depending on the model you want (the model name is at the end of the lines)." 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "metadata": { 132 | "id": "FhhdWrSxQhwg", 133 | "cellView": "form" 134 | }, 135 | "source": [ 136 | "#@title Selection of models to download\n", 137 | "#@markdown By default, the notebook downloads the 1024 and 16384 models from ImageNet. 
There are others like COCO-Stuff, WikiArt 1024, WikiArt 16384, FacesHQ or S-FLCKR, which are heavy, and if you are not going to use them it would be pointless to download them, so if you want to use them, simply select the models to download.\n", 138 | "\n", 139 | "imagenet_1024 = False #@param {type:\"boolean\"}\n", 140 | "imagenet_16384 = True #@param {type:\"boolean\"}\n", 141 | "coco = False #@param {type:\"boolean\"}\n", 142 | "faceshq = False #@param {type:\"boolean\"}\n", 143 | "wikiart_1024 = False #@param {type:\"boolean\"}\n", 144 | "wikiart_16384 = False #@param {type:\"boolean\"}\n", 145 | "sflckr = False #@param {type:\"boolean\"}\n", 146 | "openimages_8192 = False #@param {type:\"boolean\"}\n", 147 | "\n", 148 | "if imagenet_1024:\n", 149 | " !curl -L -o vqgan_imagenet_f16_1024.yaml -C - 'http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_1024.yaml' #ImageNet 1024\n", 150 | " !curl -L -o vqgan_imagenet_f16_1024.ckpt -C - 'http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_1024.ckpt' #ImageNet 1024\n", 151 | "if imagenet_16384:\n", 152 | " !curl -L -o vqgan_imagenet_f16_16384.yaml -C - 'http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_16384.yaml' #ImageNet 16384\n", 153 | " !curl -L -o vqgan_imagenet_f16_16384.ckpt -C - 'http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_16384.ckpt' #ImageNet 16384\n", 154 | "if openimages_8192:\n", 155 | " !curl -L -o vqgan_openimages_f16_8192.yaml -C - 'https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1' #ImageNet 16384\n", 156 | " !curl -L -o vqgan_openimages_f16_8192.ckpt -C - 'https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/files/?p=%2Fckpts%2Flast.ckpt&dl=1' #ImageNet 16384\n", 157 | "\n", 158 | "if coco:\n", 159 | " !curl -L -o coco.yaml -C - 'https://dl.nmkd.de/ai/clip/coco/coco.yaml' #COCO\n", 160 | " !curl -L -o coco.ckpt -C - 'https://dl.nmkd.de/ai/clip/coco/coco.ckpt' #COCO\n", 161 | "if faceshq:\n", 162 | " !curl -L -o faceshq.yaml -C - 'https://drive.google.com/uc?export=download&id=1fHwGx_hnBtC8nsq7hesJvs-Klv-P0gzT' #FacesHQ\n", 163 | " !curl -L -o faceshq.ckpt -C - 'https://app.koofr.net/content/links/a04deec9-0c59-4673-8b37-3d696fe63a5d/files/get/last.ckpt?path=%2F2020-11-13T21-41-45_faceshq_transformer%2Fcheckpoints%2Flast.ckpt' #FacesHQ\n", 164 | "if wikiart_1024: \n", 165 | " !curl -L -o wikiart_1024.yaml -C - 'http://mirror.io.community/blob/vqgan/wikiart.yaml' #WikiArt 1024\n", 166 | " !curl -L -o wikiart_1024.ckpt -C - 'http://mirror.io.community/blob/vqgan/wikiart.ckpt' #WikiArt 1024\n", 167 | "if wikiart_16384: \n", 168 | " !curl -L -o wikiart_16384.yaml -C - 'http://mirror.io.community/blob/vqgan/wikiart_16384.yaml' #WikiArt 16384\n", 169 | " !curl -L -o wikiart_16384.ckpt -C - 'http://mirror.io.community/blob/vqgan/wikiart_16384.ckpt' #WikiArt 16384\n", 170 | "if sflckr:\n", 171 | " !curl -L -o sflckr.yaml -C - 'https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fconfigs%2F2020-11-09T13-31-51-project.yaml&dl=1' #S-FLCKR\n", 172 | " !curl -L -o sflckr.ckpt -C - 'https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fcheckpoints%2Flast.ckpt&dl=1' #S-FLCKR" 173 | ], 174 | "execution_count": null, 175 | "outputs": [] 176 | }, 177 | { 178 | "cell_type": "markdown", 179 | "metadata": { 180 | "id": "1tthw0YaispD" 181 | }, 182 | "source": [ 183 | "## Settings for this run:\n", 184 | "Mainly what you will have to modify will be `texts:`, there you can place the text or texts you want to generate (separated with `|`). 
It is a list because you can put more than one text, and so the AI ​​tries to 'mix' the images, giving the same priority to both texts.\n", 185 | "\n", 186 | "To use an initial image to the model, you just have to upload a file to the Colab environment (in the section on the left), and then modify `init_image:` putting the exact name of the file. Example: `sample.png`\n", 187 | "\n", 188 | "You can also modify the model by changing the lines that say `model:`. Currently ImageNet 1024, ImageNet 16384, WikiArt 1024, WikiArt 16384, S-FLCKR and COCO-Stuff are available. To activate them you have to have downloaded them first, and then you can simply select it.\n", 189 | "\n", 190 | "You can also use `target_images`, which is basically putting one or more images on it that the AI ​​will take as a \"target\", fulfilling the same function as putting text on it. To put more than one you have to use `|` as a separator." 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "metadata": { 196 | "id": "ZdlpRFL8UAlW", 197 | "cellView": "form" 198 | }, 199 | "source": [ 200 | "#@title Parameters\n", 201 | "prompts = \"this is fine\" #@param {type:\"string\"}\n", 202 | "image_prompts = \"\" #@param {type:\"string\"}\n", 203 | "init_image = \"this_is_fine.png\" #@param {type:\"string\"}\n", 204 | "overlay_every = 20 #@param {type:\"number\"}\n", 205 | "model = \"vqgan_imagenet_f16_16384\" #@param [\"vqgan_imagenet_f16_16384\", \"vqgan_imagenet_f16_1024\", \"vqgan_openimages_f16_8192\", \"wikiart_1024\", \"wikiart_16384\", \"coco\", \"faceshq\", \"sflckr\"]\n", 206 | "seed = 42 #@param {type:\"number\"}\n", 207 | "display_freq = 50 #@param {type:\"number\"}\n", 208 | "max_iterations = 400 #@param {type:\"number\"}\n", 209 | "width = 256 #@param {type:\"number\"}\n", 210 | "height = 256 #@param {type:\"number\"}\n", 211 | "\n", 212 | "############# SETUP WITH THESE SETTINGS\n", 213 | "\n", 214 | "model_names={\"vqgan_imagenet_f16_16384\": 'ImageNet 16384',\"vqgan_imagenet_f16_1024\":\"ImageNet 1024\", 'vqgan_openimages_f16_8192':'OpenImages 8912',\n", 215 | " \"wikiart_1024\":\"WikiArt 1024\", \"wikiart_16384\":\"WikiArt 16384\", \"coco\":\"COCO-Stuff\", \"faceshq\":\"FacesHQ\", \"sflckr\":\"S-FLCKR\"}\n", 216 | "name_model = model_names[model] \n", 217 | "\n", 218 | "if seed == -1:\n", 219 | " seed = None\n", 220 | "if overlay_every == \"None\":\n", 221 | " overlay_every = None\n", 222 | "if init_image == \"None\":\n", 223 | " init_image = None\n", 224 | "if image_prompts == \"None\" or not image_prompts:\n", 225 | " image_prompts = []\n", 226 | "\n", 227 | "# Simple setup\n", 228 | "from clipit import generate\n", 229 | "import easydict\n", 230 | "\n", 231 | "args = easydict.EasyDict({\n", 232 | " \"prompts\": prompts,\n", 233 | " \"image_prompts\": image_prompts,\n", 234 | " \"init_image\": init_image,\n", 235 | " \"overlay_every\": overlay_every,\n", 236 | " \"vqgan_config\": f'{model}.yaml',\n", 237 | " \"vqgan_checkpoint\": f'{model}.ckpt',\n", 238 | " \"seed\": seed,\n", 239 | " \"display_freq\": display_freq,\n", 240 | " \"max_iterations\": max_iterations,\n", 241 | " \"size\": [width, height],\n", 242 | " \"init_noise\": \"pixels\"\n", 243 | "})\n", 244 | "\n", 245 | "vq_parser = generate.setup_parser()\n", 246 | "settings = generate.process_args(vq_parser, namespace=args)\n", 247 | "generate.do_init(settings)" 248 | ], 249 | "execution_count": null, 250 | "outputs": [] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "metadata": { 255 | "id": "JmCtyfJD3DRW" 256 | }, 257 | "source": [ 258 | 
"\n", 259 | "from IPython import display\n", 260 | "generate.do_run(settings)\n", 261 | "if settings.overlay_every is not None:\n", 262 | " print(\"Final version with overlay\")\n", 263 | " display.display(display.Image(\"overlaid.png\"))" 264 | ], 265 | "execution_count": null, 266 | "outputs": [] 267 | }, 268 | { 269 | "cell_type": "markdown", 270 | "metadata": { 271 | "id": "eOjp_zwsoWUn" 272 | }, 273 | "source": [ 274 | "## Another run" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "metadata": { 280 | "id": "6x6UYID4Kdnb", 281 | "cellView": "form" 282 | }, 283 | "source": [ 284 | "#@title Parameters\n", 285 | "prompts = \"photo of perfume\" #@param {type:\"string\"}\n", 286 | "image_prompts = \"\" #@param {type:\"string\"}\n", 287 | "init_image = \"perfume.jpg\" #@param {type:\"string\"}\n", 288 | "init_image_alpha = 200 #@param {type:\"number\"}\n", 289 | "init_weight = 0.5 #@param {type:\"number\"}\n", 290 | "overlay_every = 0 #@param {type:\"number\"}\n", 291 | "model = \"vqgan_imagenet_f16_16384\" #@param [\"vqgan_imagenet_f16_16384\", \"vqgan_imagenet_f16_1024\", \"vqgan_openimages_f16_8192\", \"wikiart_1024\", \"wikiart_16384\", \"coco\", \"faceshq\", \"sflckr\"]\n", 292 | "seed = 42 #@param {type:\"number\"}\n", 293 | "display_freq = 50 #@param {type:\"number\"}\n", 294 | "max_iterations = 200 #@param {type:\"number\"}\n", 295 | "width = 256 #@param {type:\"number\"}\n", 296 | "height = 256 #@param {type:\"number\"}\n", 297 | "\n", 298 | "############# SETUP WITH THESE SETTINGS\n", 299 | "\n", 300 | "model_names={\"vqgan_imagenet_f16_16384\": 'ImageNet 16384',\"vqgan_imagenet_f16_1024\":\"ImageNet 1024\", 'vqgan_openimages_f16_8192':'OpenImages 8912',\n", 301 | " \"wikiart_1024\":\"WikiArt 1024\", \"wikiart_16384\":\"WikiArt 16384\", \"coco\":\"COCO-Stuff\", \"faceshq\":\"FacesHQ\", \"sflckr\":\"S-FLCKR\"}\n", 302 | "name_model = model_names[model] \n", 303 | "\n", 304 | "if seed == -1:\n", 305 | " seed = None\n", 306 | "if overlay_every == \"None\":\n", 307 | " overlay_every = None\n", 308 | "if init_image == \"None\":\n", 309 | " init_image = None\n", 310 | "if image_prompts == \"None\" or not image_prompts:\n", 311 | " image_prompts = []\n", 312 | "\n", 313 | "# Simple setup\n", 314 | "from clipit import generate\n", 315 | "import easydict\n", 316 | "\n", 317 | "args = easydict.EasyDict({\n", 318 | " \"prompts\": prompts,\n", 319 | " \"image_prompts\": image_prompts,\n", 320 | " \"init_image\": init_image,\n", 321 | " \"init_image_alpha\": init_image_alpha,\n", 322 | " \"init_weight\": init_weight,\n", 323 | " \"overlay_every\": overlay_every,\n", 324 | " \"vqgan_config\": f'{model}.yaml',\n", 325 | " \"vqgan_checkpoint\": f'{model}.ckpt',\n", 326 | " \"seed\": seed,\n", 327 | " \"display_freq\": display_freq,\n", 328 | " \"max_iterations\": max_iterations,\n", 329 | " \"size\": [width, height],\n", 330 | " \"init_noise\": \"pixels\"\n", 331 | "})\n", 332 | "\n", 333 | "vq_parser = generate.setup_parser()\n", 334 | "settings = generate.process_args(vq_parser, namespace=args)\n", 335 | "generate.do_init(settings)" 336 | ], 337 | "execution_count": null, 338 | "outputs": [] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "metadata": { 343 | "id": "iyOOFoNxqOt0" 344 | }, 345 | "source": [ 346 | "\n", 347 | "from IPython import display\n", 348 | "generate.do_run(settings)\n", 349 | "if settings.overlay_every is not None:\n", 350 | " print(\"Final version with overlay\")\n", 351 | " display.display(display.Image(\"overlaid.png\"))" 352 | ], 353 | "execution_count": 
null, 354 | "outputs": [] 355 | } 356 | ] 357 | } 358 | -------------------------------------------------------------------------------- /demos/README.md: -------------------------------------------------------------------------------- 1 | # A growing set of Notebooks showing examples of how things work. 2 | 3 | | Demo | Colab Link | 4 | | ------------- | ------------- | 5 | | Start Here | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dribnet/clipit/blob/master/demos/Start_Here.ipynb) | 6 | | Swap Model | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dribnet/clipit/blob/master/demos/Swap_Model.ipynb) | 7 | | Pixel Art | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dribnet/clipit/blob/master/demos/PixelDrawer.ipynb) | 8 | | Palette | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dribnet/clipit/blob/master/demos/palette_enforcement.ipynb) | 9 | | Init Image | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dribnet/clipit/blob/master/demos/PixelDrawer_Init_Image.ipynb) | 10 | | Pixel Swirl | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dribnet/clipit/blob/master/demos/Pixray_Swirl_Demo.ipynb) | 11 | | Image Prompt | Coming soon | 12 | | Target Image | Coming soon | 13 | | Color Mapper | Coming soon | 14 | -------------------------------------------------------------------------------- /download_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | LOC=checkpoints 3 | mkdir -p "$LOC" 4 | 5 | # Which models to download? 6 | IMAGENET_1024=false 7 | IMAGENET_16384=true 8 | GUMBEL=false 9 | WIKIART_1024=false 10 | WIKIART_16384=false 11 | # Not yet working: 12 | COCO=false 13 | FACESHQ=false 14 | SFLCKR=false 15 | 16 | 17 | if [ "$IMAGENET_1024" = true ] ; then 18 | # imagenet_1024 - 958 MB: 19 | # Alternative URLs 20 | # https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1 #ImageNet 1024 21 | # https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/files/?p=%2Fckpts%2Flast.ckpt&dl=1' #ImageNet 1024 22 | 23 | if [ ! -f "$LOC"/vqgan_imagenet_f16_1024.yaml ]; then 24 | curl -L -o "$LOC"/vqgan_imagenet_f16_1024.yaml -C - 'http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_1024.yaml' #ImageNet 1024 25 | fi 26 | if [ ! -f "$LOC"/vqgan_imagenet_f16_1024.ckpt ]; then 27 | curl -L -o "$LOC"/vqgan_imagenet_f16_1024.ckpt -C - 'http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_1024.ckpt' #ImageNet 1024 28 | fi 29 | fi 30 | 31 | if [ "$IMAGENET_16384" = true ] ; then 32 | # imagenet_16384 - 980 MB: 33 | # Alternative URLs 34 | # https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1 #ImageNet 16384 35 | # https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/files/?p=%2Fckpts%2Flast.ckpt&dl=1 #ImageNet 16384 36 | if [ ! -f "$LOC"/vqgan_imagenet_f16_16384.yaml ]; then 37 | curl -L -o "$LOC"/vqgan_imagenet_f16_16384.yaml -C - 'http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_16384.yaml' #ImageNet 16384 38 | fi 39 | if [ ! 
-f "$LOC"/vqgan_imagenet_f16_16384.ckpt ]; then 40 | curl -L -o "$LOC"/vqgan_imagenet_f16_16384.ckpt -C - 'http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_16384.ckpt' #ImageNet 16384 41 | fi 42 | fi 43 | 44 | if [ "$GUMBEL" = true ] ; then 45 | # vqgan_gumbel_f8_8192 (was openimages_f16_8192) - 376 MB: 46 | if [ ! -f "$LOC"/vqgan_gumbel_f8_8192.yaml ]; then 47 | curl -L -o "$LOC"/vqgan_gumbel_f8_8192.yaml -C - 'https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1' 48 | fi 49 | if [ ! -f "$LOC"/vqgan_gumbel_f8_8192.ckpt ]; then 50 | curl -L -o "$LOC"/vqgan_gumbel_f8_8192.ckpt -C - 'https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/files/?p=%2Fckpts%2Flast.ckpt&dl=1' 51 | fi 52 | fi 53 | 54 | if [ "$COCO" = true ] ; then 55 | # coco - 8.4 GB: 56 | if [ ! -f "$LOC"/coco.yaml ]; then 57 | curl -L -o "$LOC"/coco.yaml -C - 'https://dl.nmkd.de/ai/clip/coco/coco.yaml' #COCO 58 | fi 59 | if [ ! -f "$LOC"/coco.ckpt ]; then 60 | curl -L -o "$LOC"/coco.ckpt -C - 'https://dl.nmkd.de/ai/clip/coco/coco.ckpt' #COCO 61 | fi 62 | fi 63 | 64 | if [ "$FACESHQ" = true ] ; then 65 | # faceshq: 66 | if [ ! -f "$LOC"/faceshq.yaml ]; then 67 | curl -L -o "$LOC"/faceshq.yaml -C - 'https://drive.google.com/uc?export=download&id=1fHwGx_hnBtC8nsq7hesJvs-Klv-P0gzT' #FacesHQ 68 | fi 69 | if [ ! -f "$LOC"/faceshq.ckpt ]; then 70 | curl -L -o "$LOC"/faceshq.ckpt -C - 'https://app.koofr.net/content/links/a04deec9-0c59-4673-8b37-3d696fe63a5d/files/get/last.ckpt?path=%2F2020-11-13T21-41-45_faceshq_transformer%2Fcheckpoints%2Flast.ckpt' #FacesHQ 71 | fi 72 | fi 73 | 74 | if [ "$WIKIART_1024" = true ] ; then 75 | # wikiart_1024 - 958 MB: 76 | if [ ! -f "$LOC"/wikiart_1024.yaml ]; then 77 | curl -L -o "$LOC"/wikiart_1024.yaml -C - 'http://mirror.io.community/blob/vqgan/wikiart.yaml' #WikiArt 1024 78 | fi 79 | if [ ! -f "$LOC"/wikiart_1024.ckpt ]; then 80 | curl -L -o "$LOC"/wikiart_1024.ckpt -C - 'http://mirror.io.community/blob/vqgan/wikiart.ckpt' #WikiArt 1024 81 | fi 82 | fi 83 | 84 | if [ "$WIKIART_16384" = true ] ; then 85 | #wikiart_16384 - 1 GB: 86 | if [ ! -f "$LOC"/wikiart_16384.yaml ]; then 87 | curl -L -o "$LOC"/wikiart_16384.yaml -C - 'http://mirror.io.community/blob/vqgan/wikiart_16384.yaml' #WikiArt 16384 88 | fi 89 | if [ ! -f "$LOC"/wikiart_16384.ckpt ]; then 90 | curl -L -o "$LOC"/wikiart_16384.ckpt -C - 'http://mirror.io.community/blob/vqgan/wikiart_16384.ckpt' #WikiArt 16384 91 | fi 92 | fi 93 | 94 | if [ "$SFLCKR" = true ] ; then 95 | # sflckr: 96 | if [ ! -f "$LOC"/sflckr.yaml ]; then 97 | curl -L -o "$LOC"/sflckr.yaml -C - 'https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fconfigs%2F2020-11-09T13-31-51-project.yaml&dl=1' #S-FLCKR 98 | fi 99 | if [ ! 
-f "$LOC"/sflckr.ckpt ]; then 100 | curl -L -o "$LOC"/sflckr.ckpt -C - 'https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fcheckpoints%2Flast.ckpt&dl=1' #S-FLCKR 101 | fi 102 | fi 103 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | # Originally made by Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings) 2 | # The original BigGAN+CLIP method was by https://twitter.com/advadnoun 3 | 4 | import argparse 5 | import math 6 | from urllib.request import urlopen 7 | import sys 8 | import os 9 | import subprocess 10 | import glob 11 | from braceexpand import braceexpand 12 | from types import SimpleNamespace 13 | 14 | # pip install taming-transformers work with Gumbel, but does works with coco etc 15 | # appending the path works with Gumbel, but gives ModuleNotFoundError: No module named 'transformers' for coco etc 16 | sys.path.append('taming-transformers') 17 | import os.path 18 | 19 | from omegaconf import OmegaConf 20 | from taming.models import cond_transformer, vqgan 21 | 22 | import torch 23 | from torch import nn, optim 24 | from torch.nn import functional as F 25 | from torchvision import transforms 26 | from torchvision.transforms import functional as TF 27 | torch.backends.cudnn.benchmark = False # NR: True is a bit faster, but can lead to OOM. False is more deterministic. 28 | #torch.use_deterministic_algorithms(True) # NR: grid_sampler_2d_backward_cuda does not have a deterministic implementation 29 | 30 | from torch_optimizer import DiffGrad, AdamP, RAdam 31 | from perlin_numpy import generate_fractal_noise_2d 32 | 33 | from CLIP import clip 34 | import kornia 35 | import kornia.augmentation as K 36 | import numpy as np 37 | import imageio 38 | 39 | from PIL import ImageFile, Image, PngImagePlugin 40 | ImageFile.LOAD_TRUNCATED_IMAGES = True 41 | 42 | # or 'border' 43 | global_padding_mode = 'reflection' 44 | global_aspect_width = 1 45 | 46 | vqgan_config_table = { 47 | "imagenet_f16_1024": 'http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_1024.yaml', 48 | "imagenet_f16_16384": 'http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_16384.yaml', 49 | "openimages_f16_8192": 'https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1', 50 | "coco": 'https://dl.nmkd.de/ai/clip/coco/coco.yaml', 51 | "faceshq": 'https://drive.google.com/uc?export=download&id=1fHwGx_hnBtC8nsq7hesJvs-Klv-P0gzT', 52 | "wikiart_1024": 'http://mirror.io.community/blob/vqgan/wikiart.yaml', 53 | "wikiart_16384": 'http://mirror.io.community/blob/vqgan/wikiart_16384.yaml', 54 | "sflckr": 'https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fconfigs%2F2020-11-09T13-31-51-project.yaml&dl=1', 55 | } 56 | vqgan_checkpoint_table = { 57 | "imagenet_f16_1024": 'http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_1024.ckpt', 58 | "imagenet_f16_16384": 'http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_16384.ckpt', 59 | "openimages_f16_8192": 'https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/files/?p=%2Fckpts%2Flast.ckpt&dl=1', 60 | "coco": 'https://dl.nmkd.de/ai/clip/coco/coco.ckpt', 61 | "faceshq": 'https://app.koofr.net/content/links/a04deec9-0c59-4673-8b37-3d696fe63a5d/files/get/last.ckpt?path=%2F2020-11-13T21-41-45_faceshq_transformer%2Fcheckpoints%2Flast.ckpt', 62 | "wikiart_1024": 'http://mirror.io.community/blob/vqgan/wikiart.ckpt', 63 | "wikiart_16384": 
'http://mirror.io.community/blob/vqgan/wikiart_16384.ckpt', 64 | "sflckr": 'https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fcheckpoints%2Flast.ckpt&dl=1' 65 | } 66 | 67 | # https://stackoverflow.com/a/39662359 68 | def isnotebook(): 69 | try: 70 | shell = get_ipython().__class__.__name__ 71 | if shell == 'ZMQInteractiveShell': 72 | return True # Jupyter notebook or qtconsole 73 | elif shell == 'Shell': 74 | return True # Seems to be what co-lab does 75 | elif shell == 'TerminalInteractiveShell': 76 | return False # Terminal running IPython 77 | else: 78 | return False # Other type (?) 79 | except NameError: 80 | return False # Probably standard Python interpreter 81 | 82 | IS_NOTEBOOK = isnotebook() 83 | 84 | if IS_NOTEBOOK: 85 | from IPython import display 86 | from tqdm.notebook import tqdm 87 | else: 88 | from tqdm import tqdm 89 | 90 | # file helpers 91 | def real_glob(rglob): 92 | glob_list = braceexpand(rglob) 93 | files = [] 94 | for g in glob_list: 95 | files = files + glob.glob(g) 96 | return sorted(files) 97 | 98 | # Functions and classes 99 | def sinc(x): 100 | return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([])) 101 | 102 | 103 | def lanczos(x, a): 104 | cond = torch.logical_and(-a < x, x < a) 105 | out = torch.where(cond, sinc(x) * sinc(x/a), x.new_zeros([])) 106 | return out / out.sum() 107 | 108 | 109 | def ramp(ratio, width): 110 | n = math.ceil(width / ratio + 1) 111 | out = torch.empty([n]) 112 | cur = 0 113 | for i in range(out.shape[0]): 114 | out[i] = cur 115 | cur += ratio 116 | return torch.cat([-out[1:].flip([0]), out])[1:-1] 117 | 118 | 119 | # NR: Testing with different intital images 120 | def old_random_noise_image(w,h): 121 | random_image = Image.fromarray(np.random.randint(0,255,(w,h,3),dtype=np.dtype('uint8'))) 122 | return random_image 123 | 124 | def NormalizeData(data): 125 | return (data - np.min(data)) / (np.max(data) - np.min(data)) 126 | 127 | def random_noise_image(w,h): 128 | # scale up roughly as power of 2 129 | if (w>1024 or h>1024): 130 | side, octp = 2048, 7 131 | elif (w>512 or h>512): 132 | side, octp = 1024, 6 133 | elif (w>256 or h>256): 134 | side, octp = 512, 5 135 | else: 136 | side, octp = 256, 4 137 | 138 | nr = NormalizeData(generate_fractal_noise_2d((side, side), (32, 32), octp)) 139 | ng = NormalizeData(generate_fractal_noise_2d((side, side), (32, 32), octp)) 140 | nb = NormalizeData(generate_fractal_noise_2d((side, side), (32, 32), octp)) 141 | stack = np.dstack((nr,ng,nb)) 142 | substack = stack[:h, :w, :] 143 | im = Image.fromarray((255.9 * stack).astype('uint8')) 144 | return im 145 | 146 | # testing 147 | def gradient_2d(start, stop, width, height, is_horizontal): 148 | if is_horizontal: 149 | return np.tile(np.linspace(start, stop, width), (height, 1)) 150 | else: 151 | return np.tile(np.linspace(start, stop, height), (width, 1)).T 152 | 153 | 154 | def gradient_3d(width, height, start_list, stop_list, is_horizontal_list): 155 | result = np.zeros((height, width, len(start_list)), dtype=float) 156 | 157 | for i, (start, stop, is_horizontal) in enumerate(zip(start_list, stop_list, is_horizontal_list)): 158 | result[:, :, i] = gradient_2d(start, stop, width, height, is_horizontal) 159 | 160 | return result 161 | 162 | 163 | def random_gradient_image(w,h): 164 | array = gradient_3d(w, h, (0, 0, np.random.randint(0,255)), (np.random.randint(1,255), np.random.randint(2,255), np.random.randint(3,128)), (True, False, False)) 165 | random_image = Image.fromarray(np.uint8(array)) 166 | 
return random_image 167 | 168 | 169 | 170 | # Not used? 171 | def resample(input, size, align_corners=True): 172 | n, c, h, w = input.shape 173 | dh, dw = size 174 | 175 | input = input.view([n * c, 1, h, w]) 176 | 177 | if dh < h: 178 | kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype) 179 | pad_h = (kernel_h.shape[0] - 1) // 2 180 | input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect') 181 | input = F.conv2d(input, kernel_h[None, None, :, None]) 182 | 183 | if dw < w: 184 | kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype) 185 | pad_w = (kernel_w.shape[0] - 1) // 2 186 | input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect') 187 | input = F.conv2d(input, kernel_w[None, None, None, :]) 188 | 189 | input = input.view([n, c, h, w]) 190 | return F.interpolate(input, size, mode='bicubic', align_corners=align_corners) 191 | 192 | 193 | class ReplaceGrad(torch.autograd.Function): 194 | @staticmethod 195 | def forward(ctx, x_forward, x_backward): 196 | ctx.shape = x_backward.shape 197 | return x_forward 198 | 199 | @staticmethod 200 | def backward(ctx, grad_in): 201 | return None, grad_in.sum_to_size(ctx.shape) 202 | 203 | replace_grad = ReplaceGrad.apply 204 | 205 | 206 | class ClampWithGrad(torch.autograd.Function): 207 | @staticmethod 208 | def forward(ctx, input, min, max): 209 | ctx.min = min 210 | ctx.max = max 211 | ctx.save_for_backward(input) 212 | return input.clamp(min, max) 213 | 214 | @staticmethod 215 | def backward(ctx, grad_in): 216 | input, = ctx.saved_tensors 217 | return grad_in * (grad_in * (input - input.clamp(ctx.min, ctx.max)) >= 0), None, None 218 | 219 | clamp_with_grad = ClampWithGrad.apply 220 | 221 | 222 | def vector_quantize(x, codebook): 223 | d = x.pow(2).sum(dim=-1, keepdim=True) + codebook.pow(2).sum(dim=1) - 2 * x @ codebook.T 224 | indices = d.argmin(-1) 225 | x_q = F.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook 226 | return replace_grad(x_q, x) 227 | 228 | 229 | def spherical_dist_loss(x, y): 230 | x = F.normalize(x, dim=-1) 231 | y = F.normalize(y, dim=-1) 232 | return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) 233 | 234 | 235 | class Prompt(nn.Module): 236 | def __init__(self, embed, weight=1., stop=float('-inf')): 237 | super().__init__() 238 | self.register_buffer('embed', embed) 239 | self.register_buffer('weight', torch.as_tensor(weight)) 240 | self.register_buffer('stop', torch.as_tensor(stop)) 241 | 242 | def forward(self, input): 243 | input_normed = F.normalize(input.unsqueeze(1), dim=2) 244 | embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2) 245 | dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2) 246 | dists = dists * self.weight.sign() 247 | return self.weight.abs() * replace_grad(dists, torch.maximum(dists, self.stop)).mean() 248 | 249 | 250 | def parse_prompt(prompt): 251 | vals = prompt.rsplit(':', 2) 252 | vals = vals + ['', '1', '-inf'][len(vals):] 253 | # print(f"parsed vals is {vals}") 254 | return vals[0], float(vals[1]), float(vals[2]) 255 | 256 | 257 | from typing import cast, Dict, List, Optional, Tuple, Union 258 | 259 | # override class to get padding_mode 260 | class MyRandomPerspective(K.RandomPerspective): 261 | def apply_transform( 262 | self, input: torch.Tensor, params: Dict[str, torch.Tensor], transform: Optional[torch.Tensor] = None 263 | ) -> torch.Tensor: 264 | _, _, height, width = input.shape 265 | transform = cast(torch.Tensor, transform) 266 | return kornia.geometry.warp_perspective( 267 | input, transform, (height, width), 
268 | mode=self.resample.name.lower(), align_corners=self.align_corners, padding_mode=global_padding_mode 269 | ) 270 | 271 | 272 | cached_spot_indexes = {} 273 | def fetch_spot_indexes(sideX, sideY): 274 | # make sure image is loaded if we need it 275 | cache_key = (sideX, sideY) 276 | 277 | if cache_key not in cached_spot_indexes: 278 | if global_aspect_width != 1: 279 | mask_image = Image.open("inputs/spot_wide.png") 280 | else: 281 | mask_image = Image.open("inputs/spot_square.png") 282 | # this is a one channel mask 283 | mask_image = mask_image.convert('RGB') 284 | mask_image = mask_image.resize((sideX, sideY), Image.LANCZOS) 285 | mask_image_tensor = TF.to_tensor(mask_image) 286 | # print("ONE CHANNEL ", mask_image_tensor.shape) 287 | mask_indexes = mask_image_tensor.ge(0.5).to(device) 288 | # print("GE ", mask_indexes.shape) 289 | # sys.exit(0) 290 | mask_indexes_off = mask_image_tensor.lt(0.5).to(device) 291 | cached_spot_indexes[cache_key] = [mask_indexes, mask_indexes_off] 292 | 293 | return cached_spot_indexes[cache_key] 294 | 295 | # n = torch.ones((3,5,5)) 296 | # f = generate.fetch_spot_indexes(5, 5) 297 | # f[0].shape = [60,3] 298 | 299 | class MakeCutouts(nn.Module): 300 | def __init__(self, cut_size, cutn, cut_pow=1.): 301 | global global_aspect_width 302 | 303 | super().__init__() 304 | self.cut_size = cut_size 305 | self.cutn = cutn 306 | self.cut_pow = cut_pow 307 | self.transforms = None 308 | 309 | augmentations = [] 310 | if global_aspect_width != 1: 311 | augmentations.append(K.RandomCrop(size=(self.cut_size,self.cut_size), p=1.0, return_transform=True)) 312 | augmentations.append(MyRandomPerspective(distortion_scale=0.40, p=0.7, return_transform=True)) 313 | augmentations.append(K.RandomResizedCrop(size=(self.cut_size,self.cut_size), scale=(0.15,0.80), ratio=(0.75,1.333), cropping_mode='resample', p=0.7, return_transform=True)) 314 | augmentations.append(K.ColorJitter(hue=0.1, saturation=0.1, p=0.8, return_transform=True)) 315 | self.augs = nn.Sequential(*augmentations) 316 | 317 | # self.augs = nn.Sequential( 318 | # # K.RandomHorizontalFlip(p=0.5), # NR: add augmentation options 319 | # # K.RandomVerticalFlip(p=0.5), 320 | # # K.RandomSolarize(0.01, 0.01, p=0.7), 321 | # # K.RandomSharpness(0.3,p=0.4), 322 | # # K.RandomResizedCrop(size=(self.cut_size,self.cut_size), scale=(0.1,1), ratio=(0.75,1.333), cropping_mode='resample', p=0.5, return_transform=True), 323 | # K.RandomCrop(size=(self.cut_size,self.cut_size), p=1.0), 324 | 325 | # # K.RandomAffine(degrees=15, translate=0.1, p=0.7, padding_mode='border', return_transform=True), 326 | 327 | # # MyRandomPerspective(distortion_scale=0.40, p=0.7, return_transform=True), 328 | # # K.RandomResizedCrop(size=(self.cut_size,self.cut_size), scale=(0.15,0.80), ratio=(0.75,1.333), cropping_mode='resample', p=0.7, return_transform=True), 329 | # K.ColorJitter(hue=0.1, saturation=0.1, p=0.8, return_transform=True), 330 | 331 | # # K.RandomErasing((.1, .4), (.3, 1/.3), same_on_batch=True, p=0.7, return_transform=True), 332 | # ) 333 | 334 | self.noise_fac = 0.1 335 | 336 | # Pooling 337 | self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size)) 338 | self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size)) 339 | 340 | def forward(self, input, spot=None): 341 | global i, global_aspect_width 342 | sideY, sideX = input.shape[2:4] 343 | max_size = min(sideX, sideY) 344 | min_size = min(sideX, sideY, self.cut_size) 345 | cutouts = [] 346 | mask_indexes = None 347 | 348 | if spot is not None: 349 | 
spot_indexes = fetch_spot_indexes(self.cut_size, self.cut_size) 350 | if spot == 0: 351 | mask_indexes = spot_indexes[1] 352 | else: 353 | mask_indexes = spot_indexes[0] 354 | # print("Mask indexes ", mask_indexes) 355 | 356 | for _ in range(self.cutn): 357 | # size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size) 358 | # offsetx = torch.randint(0, sideX - size + 1, ()) 359 | # offsety = torch.randint(0, sideY - size + 1, ()) 360 | # cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size] 361 | # cutouts.append(resample(cutout, (self.cut_size, self.cut_size))) 362 | # cutout = transforms.Resize(size=(self.cut_size, self.cut_size))(input) 363 | 364 | # Pooling 365 | cutout = (self.av_pool(input) + self.max_pool(input))/2 366 | 367 | if mask_indexes is not None: 368 | cutout[0][mask_indexes] = 0.5 369 | 370 | if global_aspect_width != 1: 371 | cutout = kornia.geometry.transform.rescale(cutout, (1, 16/9)) 372 | 373 | # if i % 50 == 0 and _ == 0: 374 | # print(cutout.shape) 375 | # TF.to_pil_image(cutout[0].cpu()).save(f"cutout_im_{i:02d}_{spot}.png") 376 | 377 | cutouts.append(cutout) 378 | 379 | if self.transforms is not None: 380 | # print("Cached transforms available, but I'm not smart enough to use them") 381 | # print(cutouts.shape) 382 | # print(torch.cat(cutouts, dim=0).shape) 383 | # print(self.transforms.shape) 384 | # batch = kornia.geometry.transform.warp_affine(torch.cat(cutouts, dim=0), self.transforms, (sideY, sideX)) 385 | # batch = self.transforms @ torch.cat(cutouts, dim=0) 386 | batch = kornia.geometry.transform.warp_perspective(torch.cat(cutouts, dim=0), self.transforms, 387 | (self.cut_size, self.cut_size), padding_mode=global_padding_mode) 388 | # if i < 4: 389 | # for j in range(4): 390 | # TF.to_pil_image(batch[j].cpu()).save(f"cached_im_{i:02d}_{j:02d}_{spot}.png") 391 | else: 392 | batch, self.transforms = self.augs(torch.cat(cutouts, dim=0)) 393 | # if i < 4: 394 | # for j in range(4): 395 | # TF.to_pil_image(batch[j].cpu()).save(f"live_im_{i:02d}_{j:02d}_{spot}.png") 396 | 397 | # print(batch.shape, self.transforms.shape) 398 | 399 | if self.noise_fac: 400 | facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac) 401 | batch = batch + facs * torch.randn_like(batch) 402 | return batch 403 | 404 | 405 | def load_vqgan_model(config_path, checkpoint_path): 406 | global gumbel 407 | gumbel = False 408 | config = OmegaConf.load(config_path) 409 | if config.model.target == 'taming.models.vqgan.VQModel': 410 | model = vqgan.VQModel(**config.model.params) 411 | model.eval().requires_grad_(False) 412 | model.init_from_ckpt(checkpoint_path) 413 | elif config.model.target == 'taming.models.vqgan.GumbelVQ': 414 | model = vqgan.GumbelVQ(**config.model.params) 415 | model.eval().requires_grad_(False) 416 | model.init_from_ckpt(checkpoint_path) 417 | gumbel = True 418 | elif config.model.target == 'taming.models.cond_transformer.Net2NetTransformer': 419 | parent_model = cond_transformer.Net2NetTransformer(**config.model.params) 420 | parent_model.eval().requires_grad_(False) 421 | parent_model.init_from_ckpt(checkpoint_path) 422 | model = parent_model.first_stage_model 423 | else: 424 | raise ValueError(f'unknown model type: {config.model.target}') 425 | del model.loss 426 | return model 427 | 428 | def resize_image(image, out_size): 429 | ratio = image.size[0] / image.size[1] 430 | area = min(image.size[0] * image.size[1], out_size[0] * out_size[1]) 431 | size = round((area * ratio)**0.5), round((area / ratio)**0.5) 432 | return 
image.resize(size, Image.LANCZOS) 433 | 434 | def wget_file(url, out): 435 | try: 436 | output = subprocess.check_output(['wget', '-O', out, url]) 437 | except subprocess.CalledProcessError as cpe: 438 | output = e.output 439 | print("Ignoring non-zero exit: ", output) 440 | 441 | def do_init(args): 442 | global model, opt, perceptors, normalize, cutoutsTable, cutoutSizeTable 443 | global z, z_orig, z_targets, z_labels, z_min, z_max, init_image_tensor 444 | global gside_X, gside_Y, overlay_image_rgba 445 | global pmsTable, pImages, device, spotPmsTable, spotOffPmsTable 446 | 447 | # Do it (init that is) 448 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 449 | if args.vqgan_config is not None: 450 | vqgan_config = args.vqgan_config 451 | vqgan_checkpoint = args.vqgan_checkpoint 452 | else: 453 | # the "vqgan_model" option also downloads if necessary 454 | vqgan_config = f'models/vqgan_{args.vqgan_model}.yaml' 455 | vqgan_checkpoint = f'models/vqgan_{args.vqgan_model}.ckpt' 456 | if not os.path.exists(vqgan_config): 457 | wget_file(vqgan_config_table[args.vqgan_model], vqgan_config) 458 | if not os.path.exists(vqgan_checkpoint): 459 | wget_file(vqgan_checkpoint_table[args.vqgan_model], vqgan_checkpoint) 460 | 461 | model = load_vqgan_model(vqgan_config, vqgan_checkpoint).to(device) 462 | jit = True if float(torch.__version__[:3]) < 1.8 else False 463 | f = 2**(model.decoder.num_resolutions - 1) 464 | 465 | for clip_model in args.clip_models: 466 | perceptor = clip.load(clip_model, jit=jit)[0].eval().requires_grad_(False).to(device) 467 | perceptors[clip_model] = perceptor 468 | 469 | # TODO: is one cut_size enought? I hope so. 470 | cut_size = perceptor.visual.input_resolution 471 | cutoutSizeTable[clip_model] = cut_size 472 | if not cut_size in cutoutsTable: 473 | make_cutouts = MakeCutouts(cut_size, args.num_cuts, cut_pow=args.cut_pow) 474 | cutoutsTable[cut_size] = make_cutouts 475 | 476 | toksX, toksY = args.size[0] // f, args.size[1] // f 477 | sideX, sideY = toksX * f, toksY * f 478 | 479 | if gumbel: 480 | e_dim = 256 481 | n_toks = model.quantize.n_embed 482 | z_min = model.quantize.embed.weight.min(dim=0).values[None, :, None, None] 483 | z_max = model.quantize.embed.weight.max(dim=0).values[None, :, None, None] 484 | else: 485 | e_dim = model.quantize.e_dim 486 | n_toks = model.quantize.n_e 487 | z_min = model.quantize.embedding.weight.min(dim=0).values[None, :, None, None] 488 | z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None] 489 | 490 | # z_min = model.quantize.embedding.weight.min(dim=0).values[None, :, None, None] 491 | # z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None] 492 | 493 | # normalize_imagenet = transforms.Normalize(mean=[0.485, 0.456, 0.406], 494 | # std=[0.229, 0.224, 0.225]) 495 | 496 | # save sideX, sideY in globals (need if using overlay) 497 | gside_X = sideX 498 | gside_Y = sideY 499 | 500 | init_image_tensor = None 501 | 502 | # Image initialisation 503 | if args.init_image or args.init_noise: 504 | # setup init image wih pil 505 | # first - always start with noise or blank 506 | if args.init_noise == 'pixels': 507 | img = random_noise_image(args.size[0], args.size[1]) 508 | elif args.init_noise == 'gradient': 509 | img = random_gradient_image(args.size[0], args.size[1]) 510 | else: 511 | img = Image.new(mode="RGB", size=(args.size[0], args.size[1]), color=(255, 255, 255)) 512 | starting_image = img.convert('RGB') 513 | starting_image = starting_image.resize((sideX, sideY), 
Image.LANCZOS) 514 | 515 | if args.init_image: 516 | # now we might overlay an init image (init_image also can be recycled as overlay) 517 | if 'http' in args.init_image: 518 | init_image = Image.open(urlopen(args.init_image)) 519 | else: 520 | init_image = Image.open(args.init_image) 521 | # this version is needed potentially for the loss function 522 | init_image_rgb = init_image.convert('RGB') 523 | init_image_rgb = init_image_rgb.resize((sideX, sideY), Image.LANCZOS) 524 | init_image_tensor = TF.to_tensor(init_image_rgb) 525 | init_image_tensor = init_image_tensor.to(device).unsqueeze(0) 526 | 527 | # this version gets overlaid on the background (noise) 528 | init_image_rgba = init_image.convert('RGBA') 529 | init_image_rgba = init_image_rgba.resize((sideX, sideY), Image.LANCZOS) 530 | top_image = init_image_rgba.copy() 531 | if args.init_image_alpha and args.init_image_alpha >= 0: 532 | top_image.putalpha(args.init_image_alpha) 533 | starting_image.paste(top_image, (0, 0), top_image) 534 | 535 | starting_image.save("starting_image.png") 536 | starting_tensor = TF.to_tensor(starting_image) 537 | z, *_ = model.encode(starting_tensor.to(device).unsqueeze(0) * 2 - 1) 538 | 539 | else: 540 | # legacy init 541 | one_hot = F.one_hot(torch.randint(n_toks, [toksY * toksX], device=device), n_toks).float() 542 | # z = one_hot @ model.quantize.embedding.weight 543 | if gumbel: 544 | z = one_hot @ model.quantize.embed.weight 545 | else: 546 | z = one_hot @ model.quantize.embedding.weight 547 | 548 | z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2) 549 | 550 | if args.overlay_every: 551 | if args.overlay_image: 552 | if 'http' in args.overlay_image: 553 | overlay_image = Image.open(urlopen(args.overlay_image)) 554 | else: 555 | overlay_image = Image.open(args.overlay_image) 556 | overlay_image_rgba = overlay_image.convert('RGBA') 557 | overlay_image_rgba = overlay_image_rgba.resize((sideX, sideY), Image.LANCZOS) 558 | else: 559 | overlay_image_rgba = init_image_rgba 560 | if args.overlay_alpha: 561 | overlay_image_rgba.putalpha(args.overlay_alpha) 562 | overlay_image_rgba.save('overlay_image.png') 563 | 564 | if args.target_images is not None: 565 | z_targets = [] 566 | filelist = real_glob(args.target_images) 567 | for target_image in filelist: 568 | target_image = Image.open(target_image) 569 | target_image_rgb = target_image.convert('RGB') 570 | target_image_rgb = target_image_rgb.resize((sideX, sideY), Image.LANCZOS) 571 | target_image_tensor = TF.to_tensor(target_image_rgb) 572 | target_image_tensor = target_image_tensor.to(device).unsqueeze(0) * 2 - 1 573 | z_target, *_ = model.encode(target_image_tensor) 574 | z_targets.append(z_target) 575 | 576 | if args.image_labels is not None: 577 | z_labels = [] 578 | filelist = real_glob(args.image_labels) 579 | cur_labels = [] 580 | for image_label in filelist: 581 | image_label = Image.open(image_label) 582 | image_label_rgb = image_label.convert('RGB') 583 | image_label_rgb = image_label_rgb.resize((sideX, sideY), Image.LANCZOS) 584 | image_label_rgb_tensor = TF.to_tensor(image_label_rgb) 585 | image_label_rgb_tensor = image_label_rgb_tensor.to(device).unsqueeze(0) * 2 - 1 586 | z_label, *_ = model.encode(image_label_rgb_tensor) 587 | cur_labels.append(z_label) 588 | image_embeddings = torch.stack(cur_labels) 589 | print("Processing labels: ", image_embeddings.shape) 590 | image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True) 591 | image_embeddings = image_embeddings.mean(dim=0) 592 | image_embeddings /= image_embeddings.norm() 
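# (annotation, not part of the original generate.py) The stack of per-image latents
# above is L2-normalised, averaged, and re-normalised into a single target vector;
# this is the same ensembling trick used for text labels further below. ascend_txt()
# later pulls z toward this target with spherical_dist_loss, weighted by
# args.image_label_weight.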
593 | z_labels.append(image_embeddings.unsqueeze(0)) 594 | 595 | z_orig = z.clone() 596 | z.requires_grad_(True) 597 | 598 | pmsTable = {} 599 | spotPmsTable = {} 600 | spotOffPmsTable = {} 601 | for clip_model in args.clip_models: 602 | pmsTable[clip_model] = [] 603 | spotPmsTable[clip_model] = [] 604 | spotOffPmsTable[clip_model] = [] 605 | pImages = [] 606 | normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], 607 | std=[0.26862954, 0.26130258, 0.27577711]) 608 | 609 | # CLIP tokenize/encode 610 | # NR: Weights / blending 611 | for prompt in args.prompts: 612 | for clip_model in args.clip_models: 613 | pMs = pmsTable[clip_model] 614 | perceptor = perceptors[clip_model] 615 | txt, weight, stop = parse_prompt(prompt) 616 | embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float() 617 | pMs.append(Prompt(embed, weight, stop).to(device)) 618 | 619 | for prompt in args.spot_prompts: 620 | for clip_model in args.clip_models: 621 | pMs = spotPmsTable[clip_model] 622 | perceptor = perceptors[clip_model] 623 | txt, weight, stop = parse_prompt(prompt) 624 | embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float() 625 | pMs.append(Prompt(embed, weight, stop).to(device)) 626 | 627 | for prompt in args.spot_prompts_off: 628 | for clip_model in args.clip_models: 629 | pMs = spotOffPmsTable[clip_model] 630 | perceptor = perceptors[clip_model] 631 | txt, weight, stop = parse_prompt(prompt) 632 | embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float() 633 | pMs.append(Prompt(embed, weight, stop).to(device)) 634 | 635 | for label in args.labels: 636 | for clip_model in args.clip_models: 637 | pMs = pmsTable[clip_model] 638 | perceptor = perceptors[clip_model] 639 | txt, weight, stop = parse_prompt(label) 640 | texts = [template.format(txt) for template in imagenet_templates] #format with class 641 | print(f"Tokenizing all of {texts}") 642 | texts = clip.tokenize(texts).to(device) #tokenize 643 | class_embeddings = perceptor.encode_text(texts) #embed with text encoder 644 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) 645 | class_embedding = class_embeddings.mean(dim=0) 646 | class_embedding /= class_embedding.norm() 647 | pMs.append(Prompt(class_embedding.unsqueeze(0), weight, stop).to(device)) 648 | 649 | for prompt in args.image_prompts: 650 | path, weight, stop = parse_prompt(prompt) 651 | img = Image.open(path) 652 | pil_image = img.convert('RGB') 653 | img = resize_image(pil_image, (sideX, sideY)) 654 | pImages.append(TF.to_tensor(img).unsqueeze(0).to(device)) 655 | # batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device)) 656 | # embed = perceptor.encode_image(normalize(batch)).float() 657 | # pMs.append(Prompt(embed, weight, stop).to(device)) 658 | 659 | for seed, weight in zip(args.noise_prompt_seeds, args.noise_prompt_weights): 660 | gen = torch.Generator().manual_seed(seed) 661 | embed = torch.empty([1, perceptor.visual.output_dim]).normal_(generator=gen) 662 | pMs.append(Prompt(embed, weight).to(device)) 663 | 664 | 665 | # Set the optimiser 666 | if args.optimiser == "Adam": 667 | opt = optim.Adam([z], lr=args.step_size) # LR=0.1 668 | elif args.optimiser == "AdamW": 669 | opt = optim.AdamW([z], lr=args.step_size) # LR=0.2 670 | elif args.optimiser == "Adagrad": 671 | opt = optim.Adagrad([z], lr=args.step_size) # LR=0.5+ 672 | elif args.optimiser == "Adamax": 673 | opt = optim.Adamax([z], lr=args.step_size) # LR=0.5+? 
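# (annotation, not part of the original generate.py) The remaining optimisers below
# come from the torch_optimizer package imported at the top of this file; the trailing
# "# LR=" comments appear to record rough learning rates that work for each choice.
# A minimal sketch of the same dispatch written as a lookup table:
#   OPTIMISERS = {"Adam": optim.Adam, "AdamW": optim.AdamW, "Adagrad": optim.Adagrad,
#                 "Adamax": optim.Adamax, "DiffGrad": DiffGrad, "AdamP": AdamP, "RAdam": RAdam}
#   opt = OPTIMISERS[args.optimiser]([z], lr=args.step_size)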
674 | elif args.optimiser == "DiffGrad": 675 | opt = DiffGrad([z], lr=args.step_size) # LR=2+? 676 | elif args.optimiser == "AdamP": 677 | opt = AdamP([z], lr=args.step_size) # LR=2+? 678 | elif args.optimiser == "RAdam": 679 | opt = RAdam([z], lr=args.step_size) # LR=2+? 680 | 681 | 682 | # Output for the user 683 | print('Using device:', device) 684 | print('Optimising using:', args.optimiser) 685 | 686 | if args.prompts: 687 | print('Using text prompts:', args.prompts) 688 | if args.spot_prompts: 689 | print('Using spot prompts:', args.spot_prompts) 690 | if args.spot_prompts_off: 691 | print('Using spot off prompts:', args.spot_prompts_off) 692 | if args.image_prompts: 693 | print('Using image prompts:', args.image_prompts) 694 | if args.init_image: 695 | print('Using initial image:', args.init_image) 696 | if args.noise_prompt_weights: 697 | print('Noise prompt weights:', args.noise_prompt_weights) 698 | 699 | 700 | if args.seed is None: 701 | seed = torch.seed() 702 | else: 703 | seed = args.seed 704 | torch.manual_seed(seed) 705 | print('Using seed:', seed) 706 | 707 | 708 | def synth(z): 709 | if gumbel: 710 | z_q = vector_quantize(z.movedim(1, 3), model.quantize.embed.weight).movedim(3, 1) # Vector quantize 711 | else: 712 | z_q = vector_quantize(z.movedim(1, 3), model.quantize.embedding.weight).movedim(3, 1) 713 | return clamp_with_grad(model.decode(z_q).add(1).div(2), 0, 1) 714 | 715 | # dreaded globals (for now) 716 | z = None 717 | z_orig = None 718 | z_targets = None 719 | z_labels = None 720 | z_min = None 721 | z_max = None 722 | opt = None 723 | model = None 724 | perceptors = {} 725 | normalize = None 726 | cutoutsTable = {} 727 | cutoutSizeTable = {} 728 | init_image_tensor = None 729 | pmsTable = None 730 | spotPmsTable = None 731 | spotOffPmsTable = None 732 | pImages = None 733 | gside_X=None 734 | gside_Y=None 735 | overlay_image_rgba=None 736 | device=None 737 | # OK, THIS ONE IS AWFUL 738 | i=None 739 | 740 | @torch.no_grad() 741 | def z_to_pil(): 742 | global z 743 | out = synth(z) 744 | return TF.to_pil_image(out[0].cpu()) 745 | 746 | @torch.no_grad() 747 | def checkin(args, i, losses): 748 | losses_str = ', '.join(f'{loss.item():g}' for loss in losses) 749 | tqdm.write(f'i: {i}, loss: {sum(losses).item():g}, losses: {losses_str}') 750 | info = PngImagePlugin.PngInfo() 751 | info.add_text('comment', f'{args.prompts}') 752 | img = z_to_pil() 753 | img.save(args.output, pnginfo=info) 754 | if IS_NOTEBOOK: 755 | display.display(display.Image(args.output)) 756 | 757 | def ascend_txt(args): 758 | global i, perceptors, normalize, cutoutsTable, cutoutSizeTable 759 | global z, z_orig, z_targets, z_labels, init_image_tensor 760 | global pmsTable, spotPmsTable, spotOffPmsTable, global_padding_mode 761 | 762 | out = synth(z) 763 | 764 | result = [] 765 | 766 | if (i%2 == 0): 767 | global_padding_mode = 'reflection' 768 | else: 769 | global_padding_mode = 'border' 770 | 771 | cur_cutouts = {} 772 | cur_spot_cutouts = {} 773 | cur_spot_off_cutouts = {} 774 | for cutoutSize in cutoutsTable: 775 | make_cutouts = cutoutsTable[cutoutSize] 776 | cur_cutouts[cutoutSize] = make_cutouts(out) 777 | 778 | if args.spot_prompts: 779 | for cutoutSize in cutoutsTable: 780 | cur_spot_cutouts[cutoutSize] = make_cutouts(out, spot=1) 781 | 782 | if args.spot_prompts_off: 783 | for cutoutSize in cutoutsTable: 784 | cur_spot_off_cutouts[cutoutSize] = make_cutouts(out, spot=0) 785 | 786 | for clip_model in args.clip_models: 787 | perceptor = perceptors[clip_model] 788 | cutoutSize = 
cutoutSizeTable[clip_model] 789 | transient_pMs = [] 790 | 791 | if args.spot_prompts: 792 | iii_s = perceptor.encode_image(normalize( cur_spot_cutouts[cutoutSize] )).float() 793 | spotPms = spotPmsTable[clip_model] 794 | for prompt in spotPms: 795 | result.append(prompt(iii_s)) 796 | 797 | if args.spot_prompts_off: 798 | iii_so = perceptor.encode_image(normalize( cur_spot_off_cutouts[cutoutSize] )).float() 799 | spotOffPms = spotOffPmsTable[clip_model] 800 | for prompt in spotOffPms: 801 | result.append(prompt(iii_so)) 802 | 803 | pMs = pmsTable[clip_model] 804 | iii = perceptor.encode_image(normalize( cur_cutouts[cutoutSize] )).float() 805 | for prompt in pMs: 806 | result.append(prompt(iii)) 807 | 808 | # If there are image prompts we make cutouts for those each time 809 | # so that they line up with the current cutouts from augmentation 810 | make_cutouts = cutoutsTable[cutoutSize] 811 | for timg in pImages: 812 | # note: this caches and reuses the transforms - a bit of a hack but it works 813 | 814 | if args.image_prompt_shuffle: 815 | # print("Disabling cached transforms") 816 | make_cutouts.transforms = None 817 | 818 | # new way builds throwaway Prompts 819 | batch = make_cutouts(timg) 820 | embed = perceptor.encode_image(normalize(batch)).float() 821 | if args.image_prompt_weight is not None: 822 | transient_pMs.append(Prompt(embed, args.image_prompt_weight).to(device)) 823 | else: 824 | transient_pMs.append(Prompt(embed).to(device)) 825 | 826 | for prompt in transient_pMs: 827 | result.append(prompt(iii)) 828 | 829 | for cutoutSize in cutoutsTable: 830 | # clear the transform "cache" 831 | make_cutouts = cutoutsTable[cutoutSize] 832 | make_cutouts.transforms = None 833 | 834 | # main init_weight uses spherical loss 835 | if args.target_images is not None: 836 | for z_target in z_targets: 837 | f = z.reshape(1,-1) 838 | f2 = z_target.reshape(1,-1) 839 | cur_loss = spherical_dist_loss(f, f2) * args.target_image_weight 840 | result.append(cur_loss) 841 | 842 | if args.image_labels is not None: 843 | for z_label in z_labels: 844 | f = z.reshape(1,-1) 845 | f2 = z_label.reshape(1,-1) 846 | cur_loss = spherical_dist_loss(f, f2) * args.image_label_weight 847 | result.append(cur_loss) 848 | 849 | # main init_weight uses spherical loss 850 | if args.init_weight: 851 | f = z.reshape(1,-1) 852 | f2 = z_orig.reshape(1,-1) 853 | cur_loss = spherical_dist_loss(f, f2) * args.init_weight 854 | result.append(cur_loss) 855 | 856 | # these three init_weight variants offer mse_loss, mse_loss in pixel space, and cos loss 857 | if args.init_weight_dist: 858 | cur_loss = F.mse_loss(z, z_orig) * args.init_weight_dist / 2 859 | result.append(cur_loss) 860 | 861 | if args.init_weight_pix: 862 | if init_image_tensor is None: 863 | print("OOPS IIT is 0") 864 | else: 865 | # TF.to_pil_image(out[0].cpu()).save(f"out_1.png") 866 | # TF.to_pil_image(init_image_tensor[0].cpu()).save(f"init_1.png") 867 | # print(out.shape) 868 | # print(init_image_tensor.shape) 869 | # print(out[0][0]) 870 | # print(init_image_tensor[0][0]) 871 | cur_loss = F.l1_loss(out, init_image_tensor) * args.init_weight_pix / 2 872 | result.append(cur_loss) 873 | 874 | if args.init_weight_cos: 875 | f = z.reshape(1,-1) 876 | f2 = z_orig.reshape(1,-1) 877 | y = torch.ones_like(f[0]) 878 | cur_loss = F.cosine_embedding_loss(f, f2, y) * args.init_weight_cos 879 | result.append(cur_loss) 880 | 881 | if args.make_video: 882 | img = np.array(out.mul(255).clamp(0, 255)[0].cpu().detach().numpy().astype(np.uint8))[:,:,:] 883 | img = 
np.transpose(img, (1, 2, 0)) 884 | imageio.imwrite(f'./steps/frame_{i:04d}.png', np.array(img)) 885 | 886 | return result 887 | 888 | def re_average_z(args): 889 | global z, gside_X, gside_Y 890 | global model, device 891 | 892 | # old_z = z.clone() 893 | cur_z_image = z_to_pil() 894 | cur_z_image = cur_z_image.convert('RGB') 895 | if overlay_image_rgba: 896 | # print("applying overlay image") 897 | cur_z_image.paste(overlay_image_rgba, (0, 0), overlay_image_rgba) 898 | cur_z_image.save("overlaid.png") 899 | cur_z_image = cur_z_image.resize((gside_X, gside_Y), Image.LANCZOS) 900 | new_z, *_ = model.encode(TF.to_tensor(cur_z_image).to(device).unsqueeze(0) * 2 - 1) 901 | # t_dist = F.pairwise_distance(new_z, old_z) 902 | with torch.no_grad(): 903 | z.copy_(new_z) 904 | # with torch.no_grad(): 905 | # z.copy_(z.maximum(z_min).minimum(z_max)) 906 | 907 | # torch.autograd.set_detect_anomaly(True) 908 | 909 | def train(args, i): 910 | global z, z_min, z_max 911 | opt.zero_grad(set_to_none=True) 912 | lossAll = ascend_txt(args) 913 | 914 | if i % args.display_freq == 0: 915 | checkin(args, i, lossAll) 916 | 917 | loss = sum(lossAll) 918 | loss.backward() 919 | opt.step() 920 | 921 | if args.overlay_every and i != 0 and \ 922 | (i % (args.overlay_every + args.overlay_offset)) == 0: 923 | re_average_z(args) 924 | 925 | with torch.no_grad(): 926 | z.copy_(z.maximum(z_min).minimum(z_max)) 927 | 928 | imagenet_templates = [ 929 | "itap of a {}.", 930 | "a bad photo of the {}.", 931 | "a origami {}.", 932 | "a photo of the large {}.", 933 | "a {} in a video game.", 934 | "art of the {}.", 935 | "a photo of the small {}.", 936 | ] 937 | 938 | def do_run(args): 939 | global i 940 | 941 | i = 0 942 | try: 943 | with tqdm() as pbar: 944 | while True: 945 | try: 946 | train(args, i) 947 | if i == args.iterations: 948 | break 949 | i += 1 950 | pbar.update() 951 | except RuntimeError as e: 952 | print("Oops: runtime error: ", e) 953 | print("Try reducing --num-cuts to save memory") 954 | raise e 955 | except KeyboardInterrupt: 956 | pass 957 | 958 | if args.make_video: 959 | do_video(args) 960 | 961 | def do_video(args): 962 | global i 963 | 964 | # Video generation 965 | init_frame = 1 # This is the frame where the video will start 966 | last_frame = i # You can change i to the number of the last frame you want to generate. It will raise an error if that number of frames does not exist. 
967 | 968 | min_fps = 10 969 | max_fps = 60 970 | 971 | total_frames = last_frame-init_frame 972 | 973 | length = 15 # Desired time of the video in seconds 974 | 975 | frames = [] 976 | tqdm.write('Generating video...') 977 | for i in range(init_frame,last_frame): # 978 | frames.append(Image.open(f'./steps/frame_{i:04d}.png')) 979 | 980 | #fps = last_frame/10 981 | fps = np.clip(total_frames/length,min_fps,max_fps) 982 | 983 | from subprocess import Popen, PIPE 984 | import re 985 | output_file = re.compile(r'\.png$').sub('.mp4', args.output) 986 | p = Popen(['ffmpeg', 987 | '-y', 988 | '-f', 'image2pipe', 989 | '-vcodec', 'png', 990 | '-r', str(fps), 991 | '-i', 992 | '-', 993 | '-vcodec', 'libx264', 994 | '-r', str(fps), 995 | '-pix_fmt', 'yuv420p', 996 | '-crf', '17', 997 | '-preset', 'veryslow', 998 | '-metadata', f'comment={args.prompts}', 999 | output_file], stdin=PIPE) 1000 | for im in tqdm(frames): 1001 | im.save(p.stdin, 'PNG') 1002 | p.stdin.close() 1003 | p.wait() 1004 | 1005 | # this dictionary is used for settings in the notebook 1006 | global_clipit_settings = {} 1007 | 1008 | def setup_parser(): 1009 | # Create the parser 1010 | vq_parser = argparse.ArgumentParser(description='Image generation using VQGAN+CLIP') 1011 | 1012 | # Add the arguments 1013 | vq_parser.add_argument("-p", "--prompts", type=str, help="Text prompts", default=[], dest='prompts') 1014 | vq_parser.add_argument("-sp", "--spot", type=str, help="Spot Text prompts", default=[], dest='spot_prompts') 1015 | vq_parser.add_argument("-spo", "--spot_off", type=str, help="Spot off Text prompts", default=[], dest='spot_prompts_off') 1016 | vq_parser.add_argument("-l", "--labels", type=str, help="ImageNet labels", default=[], dest='labels') 1017 | vq_parser.add_argument("-ip", "--image_prompts", type=str, help="Image prompts", default=[], dest='image_prompts') 1018 | vq_parser.add_argument("-ipw", "--image_prompt_weight", type=float, help="Weight for image prompt", default=None, dest='image_prompt_weight') 1019 | vq_parser.add_argument("-ips", "--image_prompt_shuffle", type=bool, help="Shuffle image prompts", default=False, dest='image_prompt_shuffle') 1020 | vq_parser.add_argument("-il", "--image_labels", type=str, help="Image labels", default=None, dest='image_labels') 1021 | vq_parser.add_argument("-ilw", "--image_label_weight", type=float, help="Weight for image labels", default=1.0, dest='image_label_weight') 1022 | vq_parser.add_argument("-i", "--iterations", type=int, help="Number of iterations", default=None, dest='iterations') 1023 | vq_parser.add_argument("-se", "--save_every", type=int, help="Save image iterations", default=50, dest='display_freq') 1024 | vq_parser.add_argument("-ove", "--overlay_every", type=int, help="Overlay image iterations", default=None, dest='overlay_every') 1025 | vq_parser.add_argument("-ovo", "--overlay_offset", type=int, help="Overlay image iteration offset", default=0, dest='overlay_offset') 1026 | vq_parser.add_argument("-ovi", "--overlay_image", type=str, help="Overlay image (if not init)", default=None, dest='overlay_image') 1027 | vq_parser.add_argument("-qua", "--quality", type=str, help="draft, normal, better, best", default="normal", dest='quality') 1028 | vq_parser.add_argument("-asp", "--aspect", type=str, help="widescreen, square", default="widescreen", dest='aspect') 1029 | vq_parser.add_argument("-ezs", "--ezsize", type=str, help="small, medium, large", default=None, dest='ezsize') 1030 | vq_parser.add_argument("-sca", "--scale", type=float, help="scale (instead of 
ezsize)", default=None, dest='scale') 1031 | vq_parser.add_argument("-ova", "--overlay_alpha", type=int, help="Overlay alpha (0-255)", default=None, dest='overlay_alpha') 1032 | vq_parser.add_argument("-s", "--size", nargs=2, type=int, help="Image size (width height)", default=None, dest='size') 1033 | vq_parser.add_argument("-ii", "--init_image", type=str, help="Initial image", default=None, dest='init_image') 1034 | vq_parser.add_argument("-iia", "--init_image_alpha", type=int, help="Init image alpha (0-255)", default=200, dest='init_image_alpha') 1035 | vq_parser.add_argument("-in", "--init_noise", type=str, help="Initial noise image (pixels or gradient)", default="pixels", dest='init_noise') 1036 | vq_parser.add_argument("-ti", "--target_images", type=str, help="Target images", default=None, dest='target_images') 1037 | vq_parser.add_argument("-tiw", "--target_image_weight", type=float, help="Target images weight", default=1.0, dest='target_image_weight') 1038 | vq_parser.add_argument("-iw", "--init_weight", type=float, help="Initial weight (main=spherical)", default=None, dest='init_weight') 1039 | vq_parser.add_argument("-iwd", "--init_weight_dist", type=float, help="Initial weight dist loss", default=0., dest='init_weight_dist') 1040 | vq_parser.add_argument("-iwc", "--init_weight_cos", type=float, help="Initial weight cos loss", default=0., dest='init_weight_cos') 1041 | vq_parser.add_argument("-iwp", "--init_weight_pix", type=float, help="Initial weight pix loss", default=0., dest='init_weight_pix') 1042 | vq_parser.add_argument("-m", "--clip_models", type=str, help="CLIP model", default=None, dest='clip_models') 1043 | vq_parser.add_argument("-vqgan", "--vqgan_model", type=str, help="VQGAN model", default='imagenet_f16_16384', dest='vqgan_model') 1044 | vq_parser.add_argument("-conf", "--vqgan_config", type=str, help="VQGAN config", default=None, dest='vqgan_config') 1045 | vq_parser.add_argument("-ckpt", "--vqgan_checkpoint", type=str, help="VQGAN checkpoint", default=None, dest='vqgan_checkpoint') 1046 | vq_parser.add_argument("-nps", "--noise_prompt_seeds", nargs="*", type=int, help="Noise prompt seeds", default=[], dest='noise_prompt_seeds') 1047 | vq_parser.add_argument("-npw", "--noise_prompt_weights", nargs="*", type=float, help="Noise prompt weights", default=[], dest='noise_prompt_weights') 1048 | vq_parser.add_argument("-lr", "--learning_rate", type=float, help="Learning rate", default=0.2, dest='step_size') 1049 | vq_parser.add_argument("-cuts", "--num_cuts", type=int, help="Number of cuts", default=None, dest='num_cuts') 1050 | vq_parser.add_argument("-cutp", "--cut_power", type=float, help="Cut power", default=1., dest='cut_pow') 1051 | vq_parser.add_argument("-sd", "--seed", type=int, help="Seed", default=None, dest='seed') 1052 | vq_parser.add_argument("-opt", "--optimiser", type=str, help="Optimiser (Adam, AdamW, Adagrad, Adamax, DiffGrad, AdamP or RAdam)", default='Adam', dest='optimiser') 1053 | vq_parser.add_argument("-o", "--output", type=str, help="Output file", default="output.png", dest='output') 1054 | vq_parser.add_argument("-vid", "--video", type=bool, help="Create video frames?", default=False, dest='make_video') 1055 | vq_parser.add_argument("-d", "--deterministic", type=bool, help="Enable cudnn.deterministic?", default=False, dest='cudnn_determinism') 1056 | 1057 | return vq_parser 1058 | 1059 | square_size = [144, 144] 1060 | widescreen_size = [200, 112] # at the small size this becomes 192,112 1061 | 1062 | def process_args(vq_parser, 
namespace=None): 1063 | global global_aspect_width 1064 | 1065 | if namespace is None: 1066 | # command line: use ARGV to get args 1067 | args = vq_parser.parse_args() 1068 | else: 1069 | # notebook, ignore ARGV and use dictionary instead 1070 | args = vq_parser.parse_args(args=[], namespace=namespace) 1071 | 1072 | if args.cudnn_determinism: 1073 | torch.backends.cudnn.deterministic = True 1074 | 1075 | quality_to_clip_models_table = { 1076 | 'draft': 'ViT-B/32', 1077 | 'normal': 'ViT-B/32,ViT-B/16', 1078 | 'better': 'RN50,ViT-B/32,ViT-B/16', 1079 | 'best': 'RN50x4,ViT-B/32,ViT-B/16' 1080 | } 1081 | quality_to_iterations_table = { 1082 | 'draft': 200, 1083 | 'normal': 350, 1084 | 'better': 500, 1085 | 'best': 500 1086 | } 1087 | quality_to_scale_table = { 1088 | 'draft': 1, 1089 | 'normal': 2, 1090 | 'better': 3, 1091 | 'best': 4 1092 | } 1093 | # this should be replaced with logic that does something 1094 | # smart based on available memory (eg: size, num_models, etc) 1095 | quality_to_num_cuts_table = { 1096 | 'draft': 40, 1097 | 'normal': 40, 1098 | 'better': 40, 1099 | 'best': 40 1100 | } 1101 | 1102 | if args.quality not in quality_to_clip_models_table: 1103 | print("Quality setting not understood, aborting -> ", args.quality) 1104 | exit(1) 1105 | 1106 | if args.clip_models is None: 1107 | args.clip_models = quality_to_clip_models_table[args.quality] 1108 | if args.iterations is None: 1109 | args.iterations = quality_to_iterations_table[args.quality] 1110 | if args.num_cuts is None: 1111 | args.num_cuts = quality_to_num_cuts_table[args.quality] 1112 | if args.ezsize is None and args.scale is None: 1113 | args.scale = quality_to_scale_table[args.quality] 1114 | 1115 | size_to_scale_table = { 1116 | 'small': 1, 1117 | 'medium': 2, 1118 | 'large': 4 1119 | } 1120 | aspect_to_size_table = { 1121 | 'square': [150, 150], 1122 | 'widescreen': [200, 112] 1123 | } 1124 | 1125 | # determine size if not set 1126 | if args.size is None: 1127 | size_scale = args.scale 1128 | if size_scale is None: 1129 | if args.ezsize in size_to_scale_table: 1130 | size_scale = size_to_scale_table[args.ezsize] 1131 | else: 1132 | print("EZ Size not understood, aborting -> ", args.ezsize) 1133 | exit(1) 1134 | if args.aspect in aspect_to_size_table: 1135 | base_size = aspect_to_size_table[args.aspect] 1136 | base_width = int(size_scale * base_size[0]) 1137 | base_height = int(size_scale * base_size[1]) 1138 | args.size = [base_width, base_height] 1139 | else: 1140 | print("aspect not understood, aborting -> ", args.aspect) 1141 | exit(1) 1142 | 1143 | if args.aspect == "widescreen": 1144 | global_aspect_width = 16/9 1145 | 1146 | if args.init_noise.lower() == "none": 1147 | args.init_noise = None 1148 | 1149 | # Split text prompts using the pipe character 1150 | if args.prompts: 1151 | args.prompts = [phrase.strip() for phrase in args.prompts.split("|")] 1152 | 1153 | # Split text prompts using the pipe character 1154 | if args.spot_prompts: 1155 | args.spot_prompts = [phrase.strip() for phrase in args.spot_prompts.split("|")] 1156 | 1157 | # Split text prompts using the pipe character 1158 | if args.spot_prompts_off: 1159 | args.spot_prompts_off = [phrase.strip() for phrase in args.spot_prompts_off.split("|")] 1160 | 1161 | # Split text labels using the pipe character 1162 | if args.labels: 1163 | args.labels = [phrase.strip() for phrase in args.labels.split("|")] 1164 | 1165 | # Split image prompts using the pipe character 1166 | if args.image_prompts: 1167 | args.image_prompts = 
args.image_prompts.split("|") 1168 | args.image_prompts = [image.strip() for image in args.image_prompts] 1169 | 1170 | # legacy "spread mode" removed 1171 | # if args.init_weight is not None: 1172 | # args.init_weight_pix = args.init_weight 1173 | # args.init_weight_cos = args.init_weight 1174 | # args.init_weight_dist = args.init_weight 1175 | 1176 | if args.overlay_every is not None and args.overlay_every <= 0: 1177 | args.overlay_every = None 1178 | 1179 | clip_models = args.clip_models.split(",") 1180 | args.clip_models = [model.strip() for model in clip_models] 1181 | 1182 | # Make video steps directory 1183 | if args.make_video: 1184 | if not os.path.exists('steps'): 1185 | os.mkdir('steps') 1186 | 1187 | return args 1188 | 1189 | def reset_settings(): 1190 | global global_clipit_settings 1191 | global_clipit_settings = {} 1192 | 1193 | def add_settings(**kwargs): 1194 | global global_clipit_settings 1195 | for k, v in kwargs.items(): 1196 | if v is None: 1197 | # just remove the key if it is there 1198 | global_clipit_settings.pop(k, None) 1199 | else: 1200 | global_clipit_settings[k] = v 1201 | 1202 | def apply_settings(): 1203 | global global_clipit_settings 1204 | settingsDict = None 1205 | vq_parser = setup_parser() 1206 | 1207 | if len(global_clipit_settings) > 0: 1208 | # check for any bogus entries in the settings 1209 | dests = [d.dest for d in vq_parser._actions] 1210 | for k in global_clipit_settings: 1211 | if not k in dests: 1212 | raise ValueError(f"Requested setting not found, aborting: {k}={global_clipit_settings[k]}") 1213 | 1214 | # convert dictionary to easyDict 1215 | # which can be used as an argparse namespace instead 1216 | # settingsDict = easydict.EasyDict(global_clipit_settings) 1217 | settingsDict = SimpleNamespace(**global_clipit_settings) 1218 | 1219 | settings = process_args(vq_parser, settingsDict) 1220 | return settings 1221 | 1222 | def main(): 1223 | settings = apply_settings() 1224 | do_init(settings) 1225 | do_run(settings) 1226 | 1227 | if __name__ == '__main__': 1228 | main() 1229 | -------------------------------------------------------------------------------- /models/.gitignore: -------------------------------------------------------------------------------- 1 | *.yaml 2 | *.ckpt -------------------------------------------------------------------------------- /opt_tester.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Using each optimiser, generate images using a range of learning rates 4 | # Produce a labelled montage to easily view the results 5 | 6 | TEXT="A painting in the style of Paul Gauguin" 7 | OUT_DIR="/home/nerdy/github/VQGAN-CLIP/Saves/OptimiserTesting-60it-Noise-NPW-1" 8 | ITERATIONS=60 9 | SAVE_EVERY=60 10 | HEIGHT=256 11 | WIDTH=256 12 | SEED=`shuf -i 1-9999999999 -n 1` # Keep the same seed each epoch for more deterministic runs 13 | 14 | # Main 15 | ################# 16 | 17 | export CUBLAS_WORKSPACE_CONFIG=:4096:8 18 | mkdir -p "$OUT_DIR" 19 | 20 | function do_optimiser_test () { 21 | OPTIMISER="$1" 22 | LR="$2" 23 | STEP="$3" 24 | NPW="$4" 25 | for i in {1..10} 26 | do 27 | PADDED_COUNT=$(printf "%03d" "$COUNT") 28 | echo "Loop for $OPTIMISER - $LR" 29 | python generate.py -p "$TEXT" -in pixels -o "$OUT_DIR"/"$PADDED_COUNT"-"$OPTIMISER"-"$LR"-"$NPW".png -opt "$OPTIMISER" -lr "$LR" -i "$ITERATIONS" -se "$SAVE_EVERY" -s "$HEIGHT" "$WIDTH" --seed "$SEED" -d True -iw 1 -nps 666 -npw "$NPW" -d True 30 | LR=$(echo $LR + $STEP | bc) 31 | ((COUNT++)) 32 | done 33 | } 34 
| 35 | # Test optimisers 36 | COUNT=0 37 | do_optimiser_test "Adam" .1 .1 1 38 | COUNT=10 39 | do_optimiser_test "AdamW" .1 .1 1 40 | COUNT=20 41 | do_optimiser_test "Adamax" .1 .1 1 42 | COUNT=30 43 | do_optimiser_test "Adagrad" .1 .25 1 44 | COUNT=40 45 | do_optimiser_test "AdamP" .1 .25 1 46 | COUNT=50 47 | do_optimiser_test "RAdam" .1 .25 1 48 | COUNT=60 49 | do_optimiser_test "DiffGrad" .1 .25 1 50 | 51 | # Make montage 52 | mogrify -font Liberation-Sans -fill white -undercolor '#00000080' -pointsize 14 -gravity NorthEast -annotate +10+10 %t "$OUT_DIR"/*.png 53 | montage "$OUT_DIR"/*.png -geometry 256x256+1+1 -tile 10x7 collage.jpg 54 | -------------------------------------------------------------------------------- /pixeldrawer.py: -------------------------------------------------------------------------------- 1 | from DrawingInterface import DrawingInterface 2 | 3 | import pydiffvg 4 | import torch 5 | import skimage 6 | import skimage.io 7 | import random 8 | import ttools.modules 9 | import argparse 10 | import math 11 | import torchvision 12 | import torchvision.transforms as transforms 13 | import numpy as np 14 | import PIL.Image 15 | 16 | pydiffvg.set_print_timing(False) 17 | 18 | class PixelDrawer(DrawingInterface): 19 | num_rows = 45 20 | num_cols = 80 21 | do_mono = False 22 | pixels = [] 23 | 24 | def __init__(self, width, height, do_mono, shape=None, scale=None): 25 | super(DrawingInterface, self).__init__() 26 | 27 | self.canvas_width = width 28 | self.canvas_height = height 29 | self.do_mono = do_mono 30 | if shape is not None: 31 | self.num_cols, self.num_rows = shape 32 | if scale is not None and scale > 0: 33 | self.num_cols = int(self.num_cols / scale) 34 | self.num_rows = int(self.num_rows / scale) 35 | 36 | 37 | def load_model(self, config_path, checkpoint_path, device): 38 | # gamma = 1.0 39 | 40 | # Use GPU if available 41 | pydiffvg.set_use_gpu(torch.cuda.is_available()) 42 | pydiffvg.set_device(device) 43 | self.device = device 44 | 45 | canvas_width, canvas_height = self.canvas_width, self.canvas_height 46 | num_rows, num_cols = self.num_rows, self.num_cols 47 | cell_width = canvas_width / num_cols 48 | cell_height = canvas_height / num_rows 49 | 50 | # Initialize Random Pixels 51 | shapes = [] 52 | shape_groups = [] 53 | colors = [] 54 | for r in range(num_rows): 55 | cur_y = r * cell_height 56 | for c in range(num_cols): 57 | cur_x = c * cell_width 58 | if self.do_mono: 59 | mono_color = random.random() 60 | cell_color = torch.tensor([mono_color, mono_color, mono_color, 1.0]) 61 | else: 62 | cell_color = torch.tensor([random.random(), random.random(), random.random(), 1.0]) 63 | colors.append(cell_color) 64 | p0 = [cur_x, cur_y] 65 | p1 = [cur_x+cell_width, cur_y+cell_height] 66 | path = pydiffvg.Rect(p_min=torch.tensor(p0), p_max=torch.tensor(p1)) 67 | shapes.append(path) 68 | path_group = pydiffvg.ShapeGroup(shape_ids = torch.tensor([len(shapes) - 1]), stroke_color = None, fill_color = cell_color) 69 | shape_groups.append(path_group) 70 | 71 | # Just some diffvg setup 72 | scene_args = pydiffvg.RenderFunction.serialize_scene(\ 73 | canvas_width, canvas_height, shapes, shape_groups) 74 | render = pydiffvg.RenderFunction.apply 75 | img = render(canvas_width, canvas_height, 2, 2, 0, None, *scene_args) 76 | 77 | color_vars = [] 78 | for group in shape_groups: 79 | group.fill_color.requires_grad = True 80 | color_vars.append(group.fill_color) 81 | 82 | # Optimizers 83 | # points_optim = torch.optim.Adam(points_vars, lr=1.0) 84 | # width_optim = 
torch.optim.Adam(stroke_width_vars, lr=0.1) 85 | color_optim = torch.optim.Adam(color_vars, lr=0.02) 86 | 87 | self.img = img 88 | self.shapes = shapes 89 | self.shape_groups = shape_groups 90 | self.opts = [color_optim] 91 | 92 | def get_opts(self): 93 | return self.opts 94 | 95 | def rand_init(self, toksX, toksY): 96 | # TODO 97 | pass 98 | 99 | def init_from_tensor(self, init_tensor): 100 | # TODO 101 | pass 102 | 103 | def reapply_from_tensor(self, new_tensor): 104 | # TODO 105 | pass 106 | 107 | def get_z_from_tensor(self, ref_tensor): 108 | return None 109 | 110 | def get_num_resolutions(self): 111 | # TODO 112 | return 5 113 | 114 | def synth(self, cur_iteration): 115 | render = pydiffvg.RenderFunction.apply 116 | scene_args = pydiffvg.RenderFunction.serialize_scene(\ 117 | self.canvas_width, self.canvas_height, self.shapes, self.shape_groups) 118 | img = render(self.canvas_width, self.canvas_height, 2, 2, cur_iteration, None, *scene_args) 119 | img = img[:, :, 3:4] * img[:, :, :3] + torch.ones(img.shape[0], img.shape[1], 3, device = self.device) * (1 - img[:, :, 3:4]) 120 | img = img[:, :, :3] 121 | img = img.unsqueeze(0) 122 | img = img.permute(0, 3, 1, 2) # NHWC -> NCHW 123 | self.img = img 124 | return img 125 | 126 | @torch.no_grad() 127 | def to_image(self): 128 | img = self.img.detach().cpu().numpy()[0] 129 | if self.do_mono: 130 | img = img[1] # take the green channel (they should all be the same) 131 | s = img.shape 132 | # threshold is an approximate gaussian from [0,1] 133 | random_bates = np.average(np.random.uniform(size=(5, s[0], s[1])), axis=0) 134 | # pimg = PIL.Image.fromarray(np.uint8(random_bates*255), mode="L") 135 | # pimg.save("bates_debug.png") 136 | img = np.where(img > random_bates, 1, 0) 137 | img = np.uint8(img * 255) 138 | pimg = PIL.Image.fromarray(img, mode="L") 139 | else: 140 | img = np.transpose(img, (1, 2, 0)) 141 | img = np.clip(img, 0, 1) 142 | img = np.uint8(img * 254) 143 | pimg = PIL.Image.fromarray(img, mode="RGB") 144 | return pimg 145 | 146 | def clip_z(self): 147 | with torch.no_grad(): 148 | for group in self.shape_groups: 149 | group.fill_color.data[:3].clamp_(0.0, 1.0) 150 | group.fill_color.data[3].clamp_(1.0, 1.0) 151 | if self.do_mono: 152 | avg_amount = torch.mean(group.fill_color.data[:3]) 153 | group.fill_color.data[:3] = avg_amount 154 | 155 | def get_z(self): 156 | return None 157 | 158 | def get_z_copy(self): 159 | return None 160 | -------------------------------------------------------------------------------- /random.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | text_one=("A painting of a" "A pencil art sketch of a" "An illustration of a" "A photograph of a") 4 | text_two=("spinning" "dreaming" "watering" "loving" "eating" "drinking" "sleeping" "repeating" "surreal" "psychedelic") 5 | text_three=("fish" "egg" "peacock" "watermelon" "pickle" "horse" "dog" "house" "kitchen" "bedroom" "door" "table" "lamp" "dresser" "watch" "logo" "icon" "tree" 6 | "grass" "flower" "plant" "shrub" "bloom" "screwdriver" "spanner" "figurine" "statue" "graveyard" "hotel" "bus" "train" "car" "lamp" "computer" "monitor") 7 | styles=("Art Nouveau" "Camille Pissarro" "Michelangelo Caravaggio" "Claude Monet" "Edgar Degas" "Edvard Munch" "Fauvism" "Futurism" "Impressionism" 8 | "Picasso" "Pop Art" "Modern art" "Surreal Art" "Sandro Botticelli" "oil paints" "watercolours" "weird bananas" "strange colours") 9 | 10 | pickword() { 11 | local array=("$@") 12 | ARRAY_RANGE=$((${#array[@]}-1)) 13 | 
RANDOM_ENTRY=`shuf -i 0-$ARRAY_RANGE -n 1` 14 | UPDATE=${array[$RANDOM_ENTRY]} 15 | } 16 | 17 | 18 | # Generate some images 19 | for number in {1..50} 20 | do 21 | # Make some random text 22 | pickword "${text_one[@]}" 23 | TEXT=$UPDATE 24 | pickword "${text_two[@]}" 25 | TEXT+=" "$UPDATE 26 | pickword "${text_three[@]}" 27 | TEXT+=" "$UPDATE 28 | pickword "${text_three[@]}" 29 | TEXT+=" and a "$UPDATE 30 | pickword "${styles[@]}" 31 | TEXT+=" in the style of "$UPDATE 32 | pickword "${styles[@]}" 33 | TEXT+=" and "$UPDATE 34 | 35 | python generate.py -p "$TEXT" -o "$number".png 36 | done 37 | 38 | 39 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.13.0 2 | aiohttp==3.7.4.post0 3 | antlr4-python3-runtime==4.8 4 | async-timeout==3.0.1 5 | attrs==21.2.0 6 | backcall==0.2.0 7 | cachetools==4.2.2 8 | certifi==2021.5.30 9 | chardet==4.0.0 10 | decorator==5.0.9 11 | einops==0.3.0 12 | fsspec==2021.6.1 13 | ftfy==6.0.3 14 | future==0.18.2 15 | google-auth==1.32.0 16 | google-auth-oauthlib==0.4.4 17 | grpcio==1.38.1 18 | idna==2.10 19 | imageio==2.9.0 20 | imageio-ffmpeg==0.4.4 21 | ipython==7.25.0 22 | ipython-genutils==0.2.0 23 | jedi==0.18.0 24 | kornia==0.5.4 25 | Markdown==3.3.4 26 | matplotlib-inline==0.1.2 27 | multidict==5.1.0 28 | numpy==1.21.0 29 | oauthlib==3.1.1 30 | omegaconf==2.1.0 31 | packaging==20.9 32 | parso==0.8.2 33 | pexpect==4.8.0 34 | pickleshare==0.7.5 35 | Pillow==8.2.0 36 | prompt-toolkit==3.0.19 37 | protobuf==3.17.3 38 | ptyprocess==0.7.0 39 | pyasn1==0.4.8 40 | pyasn1-modules==0.2.8 41 | pyDeprecate==0.3.0 42 | Pygments==2.9.0 43 | pyparsing==2.4.7 44 | pytorch-lightning==1.3.7.post0 45 | PyYAML==5.4.1 46 | regex==2021.4.4 47 | requests==2.25.1 48 | requests-oauthlib==1.3.0 49 | rsa==4.7.2 50 | six==1.16.0 51 | tensorboard==2.4.1 52 | tensorboard-plugin-wit==1.8.0 53 | torch==1.9.0+cu111 54 | torchaudio==0.9.0 55 | torchmetrics==0.3.2 56 | torchvision==0.10.0+cu111 57 | tqdm==4.61.1 58 | traitlets==5.0.5 59 | typing-extensions==3.10.0.0 60 | urllib3==1.26.6 61 | wcwidth==0.2.5 62 | Werkzeug==2.0.1 63 | yarl==1.6.3 64 | -------------------------------------------------------------------------------- /samples/A_painting_of_an_apple_in_a_fruitbowl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dribnet/clipit/c8c557ef038cfe1cbe418dcffae015df50974aaf/samples/A_painting_of_an_apple_in_a_fruitbowl.png -------------------------------------------------------------------------------- /samples/Apple_weird.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dribnet/clipit/c8c557ef038cfe1cbe418dcffae015df50974aaf/samples/Apple_weird.png -------------------------------------------------------------------------------- /samples/Bedroom.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dribnet/clipit/c8c557ef038cfe1cbe418dcffae015df50974aaf/samples/Bedroom.png -------------------------------------------------------------------------------- /samples/Cartoon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dribnet/clipit/c8c557ef038cfe1cbe418dcffae015df50974aaf/samples/Cartoon.png -------------------------------------------------------------------------------- 
/samples/Cartoon2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dribnet/clipit/c8c557ef038cfe1cbe418dcffae015df50974aaf/samples/Cartoon2.png -------------------------------------------------------------------------------- /samples/Cartoon3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dribnet/clipit/c8c557ef038cfe1cbe418dcffae015df50974aaf/samples/Cartoon3.png -------------------------------------------------------------------------------- /samples/DemonBiscuits.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dribnet/clipit/c8c557ef038cfe1cbe418dcffae015df50974aaf/samples/DemonBiscuits.png -------------------------------------------------------------------------------- /samples/Football.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dribnet/clipit/c8c557ef038cfe1cbe418dcffae015df50974aaf/samples/Football.png -------------------------------------------------------------------------------- /samples/Fractal_Landscape3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dribnet/clipit/c8c557ef038cfe1cbe418dcffae015df50974aaf/samples/Fractal_Landscape3.png -------------------------------------------------------------------------------- /samples/Games_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dribnet/clipit/c8c557ef038cfe1cbe418dcffae015df50974aaf/samples/Games_5.png -------------------------------------------------------------------------------- /samples/VanGogh.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dribnet/clipit/c8c557ef038cfe1cbe418dcffae015df50974aaf/samples/VanGogh.jpg -------------------------------------------------------------------------------- /samples/pencil_sketch_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dribnet/clipit/c8c557ef038cfe1cbe418dcffae015df50974aaf/samples/pencil_sketch_2.png -------------------------------------------------------------------------------- /samples/samples.txt: -------------------------------------------------------------------------------- 1 | Example files for documentation 2 | -------------------------------------------------------------------------------- /samples/vvg_picasso.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dribnet/clipit/c8c557ef038cfe1cbe418dcffae015df50974aaf/samples/vvg_picasso.png -------------------------------------------------------------------------------- /samples/vvg_psychedelic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dribnet/clipit/c8c557ef038cfe1cbe418dcffae015df50974aaf/samples/vvg_psychedelic.png -------------------------------------------------------------------------------- /samples/vvg_sketch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dribnet/clipit/c8c557ef038cfe1cbe418dcffae015df50974aaf/samples/vvg_sketch.png -------------------------------------------------------------------------------- 
/samples/zoom.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dribnet/clipit/c8c557ef038cfe1cbe418dcffae015df50974aaf/samples/zoom.gif -------------------------------------------------------------------------------- /video_styler.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Video styler - Use all images in a directory and style them 3 | # video_styler.sh video.mp4 4 | 5 | # Style text 6 | TEXT="Oil painting of a woman in the foreground | pencil art landscape background" 7 | 8 | ## Input and output frame directories 9 | FRAMES_IN="/home/nerdy/github/VQGAN-CLIP/VideoFrames" 10 | FRAMES_OUT="/home/nerdy/github/VQGAN-CLIP/Saves/VideoStyleTesting" 11 | 12 | ## Output image size 13 | HEIGHT=640 14 | WIDTH=360 15 | 16 | ## Iterations 17 | ITERATIONS=25 18 | SAVE_EVERY=$ITERATIONS 19 | 20 | ## Optimiser & Learning rate 21 | OPTIMISER=Adagrad # Adam, AdamW, Adagrad, Adamax 22 | LR=0.2 23 | 24 | # Fixed seed 25 | SEED=`shuf -i 1-9999999999 -n 1` # Keep the same seed each frame for more deterministic runs 26 | 27 | # MAIN 28 | ############################ 29 | mkdir -p "$FRAMES_IN" 30 | mkdir -p "$FRAMES_OUT" 31 | 32 | # For cuDNN determinism 33 | export CUBLAS_WORKSPACE_CONFIG=:4096:8 34 | 35 | # Extract video into frames 36 | ffmpeg -y -i "$1" -q:v 2 "$FRAMES_IN"/frame-%04d.jpg 37 | 38 | # Style all the frames 39 | ls "$FRAMES_IN" | while read file; do 40 | # Set the output filename 41 | FILENAME="$FRAMES_OUT"/"$file"-"out".jpg 42 | 43 | # And imagine! 44 | echo "Input frame: $file" 45 | echo "Style text: $TEXT" 46 | echo "Output file: $FILENAME" 47 | 48 | python generate.py -p "$TEXT" -ii "$FRAMES_IN"/"$file" -o "$FILENAME" -opt "$OPTIMISER" -lr "$LR" -i "$ITERATIONS" -se "$SAVE_EVERY" -s "$HEIGHT" "$WIDTH" -sd "$SEED" -d True 49 | done 50 | 51 | ffmpeg -y -i "$FRAMES_OUT"/frame-%04d.jpg-out.jpg -b:v 8M -c:v h264_nvenc -pix_fmt yuv420p -strict -2 -filter:v "minterpolate='mi_mode=mci:mc_mode=aobmc:vsbmc=1:fps=60'" style_video.mp4 52 | -------------------------------------------------------------------------------- /vqgan.py: -------------------------------------------------------------------------------- 1 | # Originally made by Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings) 2 | # The original BigGAN+CLIP method was by https://twitter.com/advadnoun 3 | 4 | from DrawingInterface import DrawingInterface 5 | 6 | import sys 7 | import subprocess 8 | sys.path.append('taming-transformers') 9 | import os.path 10 | import torch 11 | from torch.nn import functional as F 12 | from torchvision.transforms import functional as TF 13 | 14 | from omegaconf import OmegaConf 15 | from taming.models import cond_transformer, vqgan 16 | 17 | vqgan_config_table = { 18 | "imagenet_f16_1024": 'http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_1024.yaml', 19 | "imagenet_f16_16384": 'https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1', 20 | "imagenet_f16_16384m": 'http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_16384.yaml', 21 | "openimages_f16_8192": 'https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1', 22 | "coco": 'https://dl.nmkd.de/ai/clip/coco/coco.yaml', 23 | "faceshq": 'https://drive.google.com/uc?export=download&id=1fHwGx_hnBtC8nsq7hesJvs-Klv-P0gzT', 24 | "wikiart_1024": 'http://mirror.io.community/blob/vqgan/wikiart.yaml', 25 | 
"wikiart_16384": 'http://eaidata.bmk.sh/data/Wikiart_16384/wikiart_f16_16384_8145600.yaml', 26 | "wikiart_16384m": 'http://mirror.io.community/blob/vqgan/wikiart_16384.yaml', 27 | "sflckr": 'https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fconfigs%2F2020-11-09T13-31-51-project.yaml&dl=1', 28 | } 29 | vqgan_checkpoint_table = { 30 | "imagenet_f16_1024": 'http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_1024.ckpt', 31 | "imagenet_f16_16384": 'https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/files/?p=%2Fckpts%2Flast.ckpt&dl=1', 32 | "imagenet_f16_16384m": 'http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_16384.ckpt', 33 | "openimages_f16_8192": 'https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/files/?p=%2Fckpts%2Flast.ckpt&dl=1', 34 | "coco": 'https://dl.nmkd.de/ai/clip/coco/coco.ckpt', 35 | "faceshq": 'https://app.koofr.net/content/links/a04deec9-0c59-4673-8b37-3d696fe63a5d/files/get/last.ckpt?path=%2F2020-11-13T21-41-45_faceshq_transformer%2Fcheckpoints%2Flast.ckpt', 36 | "wikiart_1024": 'http://mirror.io.community/blob/vqgan/wikiart.ckpt', 37 | "wikiart_16384": 'http://eaidata.bmk.sh/data/Wikiart_16384/wikiart_f16_16384_8145600.ckpt', 38 | "wikiart_16384m": 'http://mirror.io.community/blob/vqgan/wikiart_16384.ckpt', 39 | "sflckr": 'https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fcheckpoints%2Flast.ckpt&dl=1' 40 | } 41 | 42 | def wget_file(url, out): 43 | try: 44 | output = subprocess.check_output(['wget', '-O', out, url]) 45 | except subprocess.CalledProcessError as cpe: 46 | output = cpe.output 47 | print("Ignoring non-zero exit: ", output) 48 | 49 | class ReplaceGrad(torch.autograd.Function): 50 | @staticmethod 51 | def forward(ctx, x_forward, x_backward): 52 | ctx.shape = x_backward.shape 53 | return x_forward 54 | 55 | @staticmethod 56 | def backward(ctx, grad_in): 57 | return None, grad_in.sum_to_size(ctx.shape) 58 | 59 | replace_grad = ReplaceGrad.apply 60 | 61 | def vector_quantize(x, codebook): 62 | d = x.pow(2).sum(dim=-1, keepdim=True) + codebook.pow(2).sum(dim=1) - 2 * x @ codebook.T 63 | indices = d.argmin(-1) 64 | x_q = F.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook 65 | return replace_grad(x_q, x) 66 | 67 | class ClampWithGrad(torch.autograd.Function): 68 | @staticmethod 69 | def forward(ctx, input, min, max): 70 | ctx.min = min 71 | ctx.max = max 72 | ctx.save_for_backward(input) 73 | return input.clamp(min, max) 74 | 75 | @staticmethod 76 | def backward(ctx, grad_in): 77 | input, = ctx.saved_tensors 78 | return grad_in * (grad_in * (input - input.clamp(ctx.min, ctx.max)) >= 0), None, None 79 | 80 | clamp_with_grad = ClampWithGrad.apply 81 | 82 | class VqganDrawer(DrawingInterface): 83 | def __init__(self, vqgan_model): 84 | super(DrawingInterface, self).__init__() 85 | self.vqgan_model = vqgan_model 86 | 87 | def load_model(self, config_path, checkpoint_path, device): 88 | gumbel = False 89 | 90 | if config_path is None: 91 | config_path = f'models/vqgan_{self.vqgan_model}.yaml' 92 | 93 | if checkpoint_path is None: 94 | checkpoint_path = f'models/vqgan_{self.vqgan_model}.ckpt' 95 | 96 | if not os.path.exists(config_path): 97 | wget_file(vqgan_config_table[self.vqgan_model], config_path) 98 | if not os.path.exists(checkpoint_path): 99 | wget_file(vqgan_checkpoint_table[self.vqgan_model], checkpoint_path) 100 | 101 | config = OmegaConf.load(config_path) 102 | if config.model.target == 'taming.models.vqgan.VQModel': 103 | model = vqgan.VQModel(**config.model.params) 104 | 
model.eval().requires_grad_(False) 105 | model.init_from_ckpt(checkpoint_path) 106 | elif config.model.target == 'taming.models.vqgan.GumbelVQ': 107 | model = vqgan.GumbelVQ(**config.model.params) 108 | model.eval().requires_grad_(False) 109 | model.init_from_ckpt(checkpoint_path) 110 | gumbel = True 111 | elif config.model.target == 'taming.models.cond_transformer.Net2NetTransformer': 112 | parent_model = cond_transformer.Net2NetTransformer(**config.model.params) 113 | parent_model.eval().requires_grad_(False) 114 | parent_model.init_from_ckpt(checkpoint_path) 115 | model = parent_model.first_stage_model 116 | else: 117 | raise ValueError(f'unknown model type: {config.model.target}') 118 | del model.loss 119 | 120 | # model, gumbel = load_vqgan_model(vqgan_config, vqgan_checkpoint) 121 | self.model = model.to(device) 122 | self.gumbel = gumbel 123 | self.device = device 124 | 125 | if gumbel: 126 | self.e_dim = 256 127 | self.n_toks = model.quantize.n_embed 128 | self.z_min = model.quantize.embed.weight.min(dim=0).values[None, :, None, None] 129 | self.z_max = model.quantize.embed.weight.max(dim=0).values[None, :, None, None] 130 | else: 131 | self.e_dim = model.quantize.e_dim 132 | self.n_toks = model.quantize.n_e 133 | self.z_min = model.quantize.embedding.weight.min(dim=0).values[None, :, None, None] 134 | self.z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None] 135 | 136 | def get_opts(self): 137 | return None 138 | 139 | def rand_init(self, toksX, toksY): 140 | # legacy init 141 | one_hot = F.one_hot(torch.randint(self.n_toks, [toksY * toksX], device=self.device), self.n_toks).float() 142 | if self.gumbel: 143 | self.z = one_hot @ self.model.quantize.embed.weight 144 | else: 145 | self.z = one_hot @ self.model.quantize.embedding.weight 146 | 147 | self.z = self.z.view([-1, toksY, toksX, self.e_dim]).permute(0, 3, 1, 2) 148 | self.z.requires_grad_(True) 149 | 150 | def init_from_tensor(self, init_tensor): 151 | self.z, *_ = self.model.encode(init_tensor) 152 | self.z.requires_grad_(True) 153 | 154 | def reapply_from_tensor(self, new_tensor): 155 | new_z, *_ = self.model.encode(new_tensor) 156 | with torch.no_grad(): 157 | self.z.copy_(new_z) 158 | 159 | def get_z_from_tensor(self, ref_tensor): 160 | z_ref, *_ = self.model.encode(ref_tensor) 161 | return z_ref 162 | 163 | def get_num_resolutions(self): 164 | return self.model.decoder.num_resolutions 165 | 166 | def synth(self, cur_iteration): 167 | if self.gumbel: 168 | z_q = vector_quantize(self.z.movedim(1, 3), self.model.quantize.embed.weight).movedim(3, 1) # Vector quantize 169 | else: 170 | z_q = vector_quantize(self.z.movedim(1, 3), self.model.quantize.embedding.weight).movedim(3, 1) 171 | return clamp_with_grad(self.model.decode(z_q).add(1).div(2), 0, 1) 172 | 173 | @torch.no_grad() 174 | def to_image(self): 175 | out = self.synth(None) 176 | return TF.to_pil_image(out[0].cpu()) 177 | 178 | def clip_z(self): 179 | with torch.no_grad(): 180 | self.z.copy_(self.z.maximum(self.z_min).minimum(self.z_max)) 181 | 182 | def get_z(self): 183 | return self.z 184 | 185 | def set_z(self, new_z): 186 | with torch.no_grad(): 187 | return self.z.copy_(new_z) 188 | 189 | def get_z_copy(self): 190 | return self.z.clone() 191 | # return model, gumbel 192 | 193 | ### EXTERNAL INTERFACE 194 | ### load_vqgan_model 195 | 196 | if __name__ == '__main__': 197 | main() 198 | -------------------------------------------------------------------------------- /vqgan.yml: 
-------------------------------------------------------------------------------- 1 | name: vqgan 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=conda_forge 8 | - _openmp_mutex=4.5=1_gnu 9 | - ca-certificates=2021.5.30=ha878542_0 10 | - certifi=2021.5.30=py39hf3d152e_0 11 | - ld_impl_linux-64=2.35.1=hea4e1c9_2 12 | - libffi=3.3=h58526e2_2 13 | - libgcc-ng=9.3.0=h2828fa1_19 14 | - libgomp=9.3.0=h2828fa1_19 15 | - libstdcxx-ng=9.3.0=h6de172a_19 16 | - ncurses=6.2=h58526e2_4 17 | - openssl=1.1.1k=h7f98852_0 18 | - pip=21.1.3=pyhd8ed1ab_0 19 | - python=3.9.5=h49503c6_0_cpython 20 | - python_abi=3.9=2_cp39 21 | - readline=8.1=h46c0cb4_0 22 | - setuptools=49.6.0=py39hf3d152e_3 23 | - sqlite=3.36.0=h9cd32fc_0 24 | - tk=8.6.10=h21135ba_1 25 | - tzdata=2021a=he74cb21_0 26 | - wheel=0.36.2=pyhd3deb0d_0 27 | - xz=5.2.5=h516909a_1 28 | - zlib=1.2.11=h516909a_1010 29 | - pip: 30 | - absl-py==0.13.0 31 | - aiohttp==3.7.4.post0 32 | - antlr4-python3-runtime==4.8 33 | - async-timeout==3.0.1 34 | - attrs==21.2.0 35 | - backcall==0.2.0 36 | - cachetools==4.2.2 37 | - chardet==4.0.0 38 | - decorator==5.0.9 39 | - einops==0.3.0 40 | - fsspec==2021.6.1 41 | - ftfy==6.0.3 42 | - future==0.18.2 43 | - google-auth==1.32.0 44 | - google-auth-oauthlib==0.4.4 45 | - grpcio==1.38.1 46 | - idna==2.10 47 | - imageio==2.9.0 48 | - imageio-ffmpeg==0.4.4 49 | - ipython==7.25.0 50 | - ipython-genutils==0.2.0 51 | - jedi==0.18.0 52 | - kornia==0.5.4 53 | - markdown==3.3.4 54 | - matplotlib-inline==0.1.2 55 | - multidict==5.1.0 56 | - numpy==1.21.0 57 | - oauthlib==3.1.1 58 | - omegaconf==2.1.0 59 | - packaging==20.9 60 | - parso==0.8.2 61 | - pexpect==4.8.0 62 | - pickleshare==0.7.5 63 | - pillow==8.2.0 64 | - prompt-toolkit==3.0.19 65 | - protobuf==3.17.3 66 | - ptyprocess==0.7.0 67 | - pyasn1==0.4.8 68 | - pyasn1-modules==0.2.8 69 | - pydeprecate==0.3.0 70 | - pygments==2.9.0 71 | - pyparsing==2.4.7 72 | - pytorch-lightning==1.3.7.post0 73 | - pyyaml==5.4.1 74 | - regex==2021.4.4 75 | - requests==2.25.1 76 | - requests-oauthlib==1.3.0 77 | - rsa==4.7.2 78 | - six==1.16.0 79 | - tensorboard==2.4.1 80 | - tensorboard-plugin-wit==1.8.0 81 | - torch==1.9.0+cu111 82 | - torchaudio==0.9.0 83 | - torchmetrics==0.3.2 84 | - torchvision==0.10.0+cu111 85 | - tqdm==4.61.1 86 | - traitlets==5.0.5 87 | - typing-extensions==3.10.0.0 88 | - urllib3==1.26.6 89 | - wcwidth==0.2.5 90 | - werkzeug==2.0.1 91 | - yarl==1.6.3 92 | prefix: /home/nerdy/anaconda3/envs/vqgan 93 | -------------------------------------------------------------------------------- /zoom.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Example "Zoom" movie generation 3 | # e.g. 
./zoom.sh "A painting of zooming in to a surreal, alien world" Zoom.png 180 4 | 5 | TEXT="$1" 6 | FILENAME="$2" 7 | MAX_EPOCHS=$3 8 | 9 | LR=0.1 10 | OPTIMISER=Adam 11 | MAX_ITERATIONS=25 12 | SEED=`shuf -i 1-9999999999 -n 1` # Keep the same seed each epoch for more deterministic runs 13 | 14 | # Extract 15 | FILENAME_NO_EXT=${FILENAME%.*} 16 | FILE_EXTENSION=${FILENAME##*.} 17 | 18 | # Initial run 19 | python generate.py -p="$TEXT" -opt="$OPTIMISER" -lr=$LR -i=$MAX_ITERATIONS -se=$MAX_ITERATIONS --seed=$SEED -o="$FILENAME" 20 | cp "$FILENAME" "$FILENAME_NO_EXT"-0000."$FILE_EXTENSION" 21 | convert "$FILENAME" -distort SRT 1.01,0 -gravity center "$FILENAME" # Zoom 22 | convert "$FILENAME" -distort SRT 1 -gravity center "$FILENAME" # Rotate 23 | 24 | # Feedback image loop 25 | for (( i=1; i<=$MAX_EPOCHS; i++ )) 26 | do 27 | padded_count=$(printf "%04d" "$i") 28 | python generate.py -p="$TEXT" -opt="$OPTIMISER" -lr=$LR -i=$MAX_ITERATIONS -se=$MAX_ITERATIONS --seed=$SEED -ii="$FILENAME" -o="$FILENAME" 29 | cp "$FILENAME" "$FILENAME_NO_EXT"-"$padded_count"."$FILE_EXTENSION" 30 | convert "$FILENAME" -distort SRT 1.01,0 -gravity center "$FILENAME" # Zoom 31 | convert "$FILENAME" -distort SRT 1 -gravity center "$FILENAME" # Rotate 32 | done 33 | 34 | # Make video - Nvidia GPU expected 35 | ffmpeg -y -i "$FILENAME_NO_EXT"-%04d."$FILE_EXTENSION" -b:v 8M -c:v h264_nvenc -pix_fmt yuv420p -strict -2 -filter:v "minterpolate='mi_mode=mci:mc_mode=aobmc:vsbmc=1:fps=60'" video.mp4 36 | --------------------------------------------------------------------------------