├── Mse_regulized_Modified_VQGANCLIP_zquantize_public.ipynb ├── PixelDrawer.ipynb ├── Pixray_Panorama_Demo.ipynb ├── README.md ├── VQGAN+CLIP(Updated).ipynb ├── VQGAN+CLIP_(Zooming)_(z+quantize_method_with_addons).ipynb └── VQGAN+CLIP_(z+quantize_method_with_augmentations,_user_friendly_interface).ipynb /Mse_regulized_Modified_VQGANCLIP_zquantize_public.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "accelerator": "GPU", 6 | "colab": { 7 | "name": "Mse_regulized_Modified_VQGANCLIP_zquantize_public.ipynb", 8 | "provenance": [], 9 | "collapsed_sections": [] 10 | }, 11 | "kernelspec": { 12 | "display_name": "Python 3", 13 | "name": "python3" 14 | }, 15 | "language_info": { 16 | "name": "python" 17 | } 18 | }, 19 | "cells": [ 20 | { 21 | "cell_type": "markdown", 22 | "metadata": { 23 | "id": "CppIQlPhhwhs" 24 | }, 25 | "source": [ 26 | "# Generates images from text prompts with VQGAN and CLIP (Mse regulized zquantize method).\n", 27 | "\n", 28 | "By jbustter https://twitter.com/jbusted1 .\n", 29 | "Based on a notebook by Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings)\n", 30 | "\n", 31 | "\n", 32 | "*Modified by: Justin John*\n" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "metadata": { 38 | "id": "sCuopwEiNdAR", 39 | "cellView": "form" 40 | }, 41 | "source": [ 42 | "#@markdown #**Check GPU type**\n", 43 | "#@markdown ### Factory reset runtime if you don't have the desired GPU.\n", 44 | "\n", 45 | "#@markdown ---\n", 46 | "\n", 47 | "\n", 48 | "\n", 49 | "\n", 50 | "#@markdown V100 = Excellent (*Available only for Colab Pro Users*)\n", 51 | "\n", 52 | "#@markdown P100 = Very Good\n", 53 | "\n", 54 | "#@markdown T4 = Good\n", 55 | "\n", 56 | "#@markdown K80 = Meh\n", 57 | "\n", 58 | "#@markdown P4 = Aight\n", 59 | "\n", 60 | "#@markdown ---\n", 61 | "\n", 62 | "!nvidia-smi -L" 63 | ], 64 | "execution_count": null, 65 | "outputs": [] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "metadata": { 70 | "cellView": "form", 71 | "id": "7dViyQpDNhgP" 72 | }, 73 | "source": [ 74 | "#@markdown #**Anti-Disconnect for Google Colab**\n", 75 | "#@markdown ## Run this to stop it from disconnecting automatically \n", 76 | "#@markdown **(It will anyhow disconnect after 6 - 12 hrs for using the free version of Colab.)**\n", 77 | "#@markdown *(Colab Pro users will get about 24 hrs usage time)*\n", 78 | "#@markdown ---\n", 79 | "\n", 80 | "import IPython\n", 81 | "js_code = '''\n", 82 | "function ClickConnect(){\n", 83 | "console.log(\"Working\");\n", 84 | "document.querySelector(\"colab-toolbar-button#connect\").click()\n", 85 | "}\n", 86 | "setInterval(ClickConnect,60000)\n", 87 | "'''\n", 88 | "display(IPython.display.Javascript(js_code))" 89 | ], 90 | "execution_count": null, 91 | "outputs": [] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "metadata": { 96 | "id": "wSfISAhyPmyp", 97 | "cellView": "form" 98 | }, 99 | "source": [ 100 | "#@markdown #**Installation of libraries**\n", 101 | "# @markdown This cell will take a little while because it has to download several libraries\n", 102 | "\n", 103 | "#@markdown ---\n", 104 | "\n", 105 | "!git clone https://github.com/openai/CLIP\n", 106 | "!git clone https://github.com/CompVis/taming-transformers\n", 107 | "!pip install ftfy regex tqdm omegaconf pytorch-lightning\n", 108 | "!pip install kornia\n", 109 | "!pip install einops" 110 | ], 111 | "execution_count": null, 112 | "outputs": [] 113 | }, 114 | { 115 | 
"cell_type": "code", 116 | "metadata": { 117 | "cellView": "form", 118 | "id": "JPKRgKREkQUY" 119 | }, 120 | "source": [ 121 | "#@markdown #**Selection of models to download**\n", 122 | "#@markdown ---\n", 123 | "#@markdown OpenImages and ImageNet models\n", 124 | "\n", 125 | "#@markdown ---\n", 126 | "\n", 127 | "imagenet_16384 = True #@param {type:\"boolean\"}\n", 128 | "\n", 129 | "openimages_8192 = True #@param {type:\"boolean\"}\n", 130 | "\n", 131 | "\n", 132 | "if imagenet_16384:\n", 133 | " !curl -L -o vqgan_imagenet_f16_16384.yaml -C - 'https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1' #ImageNet 16384\n", 134 | " !curl -L -o vqgan_imagenet_f16_16384.ckpt -C - 'https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/files/?p=%2Fckpts%2Flast.ckpt&dl=1' #ImageNet 16384\n", 135 | "if openimages_8192:\n", 136 | " !curl -L -o vqgan_openimages_f16_8192.yaml -C - 'https://dl.nmkd.de/ai/clip/vqgan/8k-2021-06/vqgan-f8-8192.ckpt' #ImageNet 16384\n", 137 | " !curl -L -o vqgan_openimages_f16_8192.ckpt -C - 'https://dl.nmkd.de/ai/clip/vqgan/8k-2021-06/vqgan-f8-8192.yaml' #ImageNet 16384\n" 138 | ], 139 | "execution_count": null, 140 | "outputs": [] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "metadata": { 145 | "id": "EXMSuW2EQWsd", 146 | "cellView": "form" 147 | }, 148 | "source": [ 149 | "#@markdown #**Loading libraries and definitions**\n", 150 | "\n", 151 | "import argparse\n", 152 | "import math\n", 153 | "from pathlib import Path\n", 154 | "import sys\n", 155 | "\n", 156 | "sys.path.append('./taming-transformers')\n", 157 | "\n", 158 | "from IPython import display\n", 159 | "from omegaconf import OmegaConf\n", 160 | "from PIL import Image\n", 161 | "from taming.models import cond_transformer, vqgan\n", 162 | "import torch\n", 163 | "from torch import nn, optim\n", 164 | "from torch.nn import functional as F\n", 165 | "from torchvision import transforms\n", 166 | "from torchvision.transforms import functional as TF\n", 167 | "from tqdm.notebook import tqdm\n", 168 | "import numpy as np\n", 169 | "\n", 170 | "from CLIP import clip\n", 171 | "\n", 172 | "import kornia.augmentation as K\n", 173 | "\n", 174 | "def noise_gen(shape):\n", 175 | " n, c, h, w = shape\n", 176 | " noise = torch.zeros([n, c, 1, 1])\n", 177 | " for i in reversed(range(5)):\n", 178 | " h_cur, w_cur = h // 2**i, w // 2**i\n", 179 | " noise = F.interpolate(noise, (h_cur, w_cur), mode='bicubic', align_corners=False)\n", 180 | " noise += torch.randn([n, c, h_cur, w_cur]) / 5\n", 181 | " return noise\n", 182 | "\n", 183 | "\n", 184 | "def sinc(x):\n", 185 | " return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))\n", 186 | "\n", 187 | "\n", 188 | "def lanczos(x, a):\n", 189 | " cond = torch.logical_and(-a < x, x < a)\n", 190 | " out = torch.where(cond, sinc(x) * sinc(x/a), x.new_zeros([]))\n", 191 | " return out / out.sum()\n", 192 | "\n", 193 | "\n", 194 | "def ramp(ratio, width):\n", 195 | " n = math.ceil(width / ratio + 1)\n", 196 | " out = torch.empty([n])\n", 197 | " cur = 0\n", 198 | " for i in range(out.shape[0]):\n", 199 | " out[i] = cur\n", 200 | " cur += ratio\n", 201 | " return torch.cat([-out[1:].flip([0]), out])[1:-1]\n", 202 | "\n", 203 | "\n", 204 | "def resample(input, size, align_corners=True):\n", 205 | " n, c, h, w = input.shape\n", 206 | " dh, dw = size\n", 207 | "\n", 208 | " input = input.view([n * c, 1, h, w])\n", 209 | "\n", 210 | " if dh < h:\n", 211 | " kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, 
input.dtype)\n", 212 | " pad_h = (kernel_h.shape[0] - 1) // 2\n", 213 | " input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')\n", 214 | " input = F.conv2d(input, kernel_h[None, None, :, None])\n", 215 | "\n", 216 | " if dw < w:\n", 217 | " kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)\n", 218 | " pad_w = (kernel_w.shape[0] - 1) // 2\n", 219 | " input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')\n", 220 | " input = F.conv2d(input, kernel_w[None, None, None, :])\n", 221 | "\n", 222 | " input = input.view([n, c, h, w])\n", 223 | " return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)\n", 224 | " \n", 225 | "\n", 226 | "# def replace_grad(fake, real):\n", 227 | "# return fake.detach() - real.detach() + real\n", 228 | "\n", 229 | "\n", 230 | "class ReplaceGrad(torch.autograd.Function):\n", 231 | " @staticmethod\n", 232 | " def forward(ctx, x_forward, x_backward):\n", 233 | " ctx.shape = x_backward.shape\n", 234 | " return x_forward\n", 235 | "\n", 236 | " @staticmethod\n", 237 | " def backward(ctx, grad_in):\n", 238 | " return None, grad_in.sum_to_size(ctx.shape)\n", 239 | "\n", 240 | "\n", 241 | "class ClampWithGrad(torch.autograd.Function):\n", 242 | " @staticmethod\n", 243 | " def forward(ctx, input, min, max):\n", 244 | " ctx.min = min\n", 245 | " ctx.max = max\n", 246 | " ctx.save_for_backward(input)\n", 247 | " return input.clamp(min, max)\n", 248 | "\n", 249 | " @staticmethod\n", 250 | " def backward(ctx, grad_in):\n", 251 | " input, = ctx.saved_tensors\n", 252 | " return grad_in * (grad_in * (input - input.clamp(ctx.min, ctx.max)) >= 0), None, None\n", 253 | "\n", 254 | "replace_grad = ReplaceGrad.apply\n", 255 | "\n", 256 | "clamp_with_grad = ClampWithGrad.apply\n", 257 | "# clamp_with_grad = torch.clamp\n", 258 | "\n", 259 | "def vector_quantize(x, codebook):\n", 260 | " d = x.pow(2).sum(dim=-1, keepdim=True) + codebook.pow(2).sum(dim=1) - 2 * x @ codebook.T\n", 261 | " indices = d.argmin(-1)\n", 262 | " x_q = F.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook\n", 263 | " return replace_grad(x_q, x)\n", 264 | "\n", 265 | "\n", 266 | "class Prompt(nn.Module):\n", 267 | " def __init__(self, embed, weight=1., stop=float('-inf')):\n", 268 | " super().__init__()\n", 269 | " self.register_buffer('embed', embed)\n", 270 | " self.register_buffer('weight', torch.as_tensor(weight))\n", 271 | " self.register_buffer('stop', torch.as_tensor(stop))\n", 272 | "\n", 273 | " def forward(self, input):\n", 274 | " \n", 275 | " input_normed = F.normalize(input.unsqueeze(1), dim=2)#(input / input.norm(dim=-1, keepdim=True)).unsqueeze(1)# \n", 276 | " embed_normed = F.normalize((self.embed).unsqueeze(0), dim=2)#(self.embed / self.embed.norm(dim=-1, keepdim=True)).unsqueeze(0)#\n", 277 | "\n", 278 | " dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)\n", 279 | " dists = dists * self.weight.sign()\n", 280 | " return self.weight.abs() * replace_grad(dists, torch.maximum(dists, self.stop)).mean()\n", 281 | "\n", 282 | "\n", 283 | "def parse_prompt(prompt):\n", 284 | " vals = prompt.rsplit(':', 2)\n", 285 | " vals = vals + ['', '1', '-inf'][len(vals):]\n", 286 | " return vals[0], float(vals[1]), float(vals[2])\n", 287 | "\n", 288 | "def one_sided_clip_loss(input, target, labels=None, logit_scale=100):\n", 289 | " input_normed = F.normalize(input, dim=-1)\n", 290 | " target_normed = F.normalize(target, dim=-1)\n", 291 | " logits = input_normed @ target_normed.T * logit_scale\n", 292 | " if labels is None:\n", 293 | " 
labels = torch.arange(len(input), device=logits.device)\n", 294 | " return F.cross_entropy(logits, labels)\n", 295 | "\n", 296 | "class MakeCutouts(nn.Module):\n", 297 | " def __init__(self, cut_size, cutn, cut_pow=1.):\n", 298 | " super().__init__()\n", 299 | " self.cut_size = cut_size\n", 300 | " self.cutn = cutn\n", 301 | " self.cut_pow = cut_pow\n", 302 | "\n", 303 | " self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))\n", 304 | " self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))\n", 305 | "\n", 306 | " def set_cut_pow(self, cut_pow):\n", 307 | " self.cut_pow = cut_pow\n", 308 | "\n", 309 | " def forward(self, input):\n", 310 | " sideY, sideX = input.shape[2:4]\n", 311 | " max_size = min(sideX, sideY)\n", 312 | " min_size = min(sideX, sideY, self.cut_size)\n", 313 | " cutouts = []\n", 314 | " cutouts_full = []\n", 315 | " \n", 316 | " min_size_width = min(sideX, sideY)\n", 317 | " lower_bound = float(self.cut_size/min_size_width)\n", 318 | " \n", 319 | " for ii in range(self.cutn):\n", 320 | " \n", 321 | " \n", 322 | " # size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)\n", 323 | " size = int(min_size_width*torch.zeros(1,).normal_(mean=.8, std=.3).clip(lower_bound, 1.)) # replace .5 with a result for 224 the default large size is .95\n", 324 | " # size = int(min_size_width*torch.zeros(1,).normal_(mean=.9, std=.3).clip(lower_bound, .95)) # replace .5 with a result for 224 the default large size is .95\n", 325 | "\n", 326 | " offsetx = torch.randint(0, sideX - size + 1, ())\n", 327 | " offsety = torch.randint(0, sideY - size + 1, ())\n", 328 | " cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]\n", 329 | " cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))\n", 330 | "\n", 331 | " \n", 332 | " cutouts = torch.cat(cutouts, dim=0)\n", 333 | "\n", 334 | " # if args.use_augs:\n", 335 | " # cutouts = augs(cutouts)\n", 336 | "\n", 337 | " # if args.noise_fac:\n", 338 | " # facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(0, args.noise_fac)\n", 339 | " # cutouts = cutouts + facs * torch.randn_like(cutouts)\n", 340 | " \n", 341 | "\n", 342 | " return clamp_with_grad(cutouts, 0, 1)\n", 343 | "\n", 344 | "\n", 345 | "def load_vqgan_model(config_path, checkpoint_path):\n", 346 | " config = OmegaConf.load(config_path)\n", 347 | " if config.model.target == 'taming.models.vqgan.VQModel':\n", 348 | " model = vqgan.VQModel(**config.model.params)\n", 349 | " model.eval().requires_grad_(False)\n", 350 | " model.init_from_ckpt(checkpoint_path)\n", 351 | " elif config.model.target == 'taming.models.cond_transformer.Net2NetTransformer':\n", 352 | " parent_model = cond_transformer.Net2NetTransformer(**config.model.params)\n", 353 | " parent_model.eval().requires_grad_(False)\n", 354 | " parent_model.init_from_ckpt(checkpoint_path)\n", 355 | " model = parent_model.first_stage_model\n", 356 | " elif config.model.target == 'taming.models.vqgan.GumbelVQ':\n", 357 | " model = vqgan.GumbelVQ(**config.model.params)\n", 358 | " model.eval().requires_grad_(False)\n", 359 | " model.init_from_ckpt(checkpoint_path)\n", 360 | " else:\n", 361 | " raise ValueError(f'unknown model type: {config.model.target}')\n", 362 | " del model.loss\n", 363 | " return model\n", 364 | "\n", 365 | "def resize_image(image, out_size):\n", 366 | " ratio = image.size[0] / image.size[1]\n", 367 | " area = min(image.size[0] * image.size[1], out_size[0] * out_size[1])\n", 368 | " size = round((area * ratio)**0.5), round((area / 
ratio)**0.5)\n", 369 | " return image.resize(size, Image.LANCZOS)\n", 370 | "\n", 371 | "class TVLoss(nn.Module):\n", 372 | " def forward(self, input):\n", 373 | " input = F.pad(input, (0, 1, 0, 1), 'replicate')\n", 374 | " x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]\n", 375 | " y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]\n", 376 | " diff = x_diff**2 + y_diff**2 + 1e-8\n", 377 | " return diff.mean(dim=1).sqrt().mean()\n", 378 | "\n", 379 | "class GaussianBlur2d(nn.Module):\n", 380 | " def __init__(self, sigma, window=0, mode='reflect', value=0):\n", 381 | " super().__init__()\n", 382 | " self.mode = mode\n", 383 | " self.value = value\n", 384 | " if not window:\n", 385 | " window = max(math.ceil((sigma * 6 + 1) / 2) * 2 - 1, 3)\n", 386 | " if sigma:\n", 387 | " kernel = torch.exp(-(torch.arange(window) - window // 2)**2 / 2 / sigma**2)\n", 388 | " kernel /= kernel.sum()\n", 389 | " else:\n", 390 | " kernel = torch.ones([1])\n", 391 | " self.register_buffer('kernel', kernel)\n", 392 | "\n", 393 | " def forward(self, input):\n", 394 | " n, c, h, w = input.shape\n", 395 | " input = input.view([n * c, 1, h, w])\n", 396 | " start_pad = (self.kernel.shape[0] - 1) // 2\n", 397 | " end_pad = self.kernel.shape[0] // 2\n", 398 | " input = F.pad(input, (start_pad, end_pad, start_pad, end_pad), self.mode, self.value)\n", 399 | " input = F.conv2d(input, self.kernel[None, None, None, :])\n", 400 | " input = F.conv2d(input, self.kernel[None, None, :, None])\n", 401 | " return input.view([n, c, h, w])\n", 402 | "\n", 403 | "class EMATensor(nn.Module):\n", 404 | " \"\"\"implmeneted by Katherine Crowson\"\"\"\n", 405 | " def __init__(self, tensor, decay):\n", 406 | " super().__init__()\n", 407 | " self.tensor = nn.Parameter(tensor)\n", 408 | " self.register_buffer('biased', torch.zeros_like(tensor))\n", 409 | " self.register_buffer('average', torch.zeros_like(tensor))\n", 410 | " self.decay = decay\n", 411 | " self.register_buffer('accum', torch.tensor(1.))\n", 412 | " self.update()\n", 413 | " \n", 414 | " @torch.no_grad()\n", 415 | " def update(self):\n", 416 | " if not self.training:\n", 417 | " raise RuntimeError('update() should only be called during training')\n", 418 | "\n", 419 | " self.accum *= self.decay\n", 420 | " self.biased.mul_(self.decay)\n", 421 | " self.biased.add_((1 - self.decay) * self.tensor)\n", 422 | " self.average.copy_(self.biased)\n", 423 | " self.average.div_(1 - self.accum)\n", 424 | "\n", 425 | " def forward(self):\n", 426 | " if self.training:\n", 427 | " return self.tensor\n", 428 | " return self.average\n", 429 | "\n", 430 | "%mkdir /content/vids" 431 | ], 432 | "execution_count": null, 433 | "outputs": [] 434 | }, 435 | { 436 | "cell_type": "markdown", 437 | "metadata": { 438 | "id": "WN4OtaLbHBN6" 439 | }, 440 | "source": [ 441 | "## **Arguments** " 442 | ] 443 | }, 444 | { 445 | "cell_type": "code", 446 | "metadata": { 447 | "id": "tLw9p5Rzacso", 448 | "cellView": "form" 449 | }, 450 | "source": [ 451 | "#@markdown #**Double-click here and input**\n", 452 | "#@markdown like prompts, iterations and other stuff \n", 453 | "#@markdown *(This is a W.I.P space and will be simplified further soon)*\n", 454 | "#@markdown ---\n", 455 | "\n", 456 | "args = argparse.Namespace(\n", 457 | " \n", 458 | " prompts=[\"a photo-realistic and beautiful painting of an old man sitting on a chair next to a giant iceberg and looking at a green lush landscape.\"],\n", 459 | " size=[640, 512], \n", 460 | " init_image= None,\n", 461 | " init_weight= 0.5,\n", 462 | "\n", 463 | " # 
clip model settings\n", 464 | " clip_model='ViT-B/32',\n", 465 | " vqgan_config='vqgan_imagenet_f16_16384.yaml', \n", 466 | " vqgan_checkpoint='vqgan_imagenet_f16_16384.ckpt',\n", 467 | " step_size=0.1,\n", 468 | " \n", 469 | " # cutouts / crops\n", 470 | " cutn=64,\n", 471 | " cut_pow=1,\n", 472 | "\n", 473 | " # display\n", 474 | " display_freq=25,\n", 475 | " seed=158758,\n", 476 | " use_augs = True,\n", 477 | " noise_fac= 0.1,\n", 478 | " ema_val = 0.99,\n", 479 | "\n", 480 | " record_generation=True,\n", 481 | "\n", 482 | " # noise and other constraints\n", 483 | " use_noise = None,\n", 484 | " constraint_regions = False,#\n", 485 | " \n", 486 | " \n", 487 | " # add noise to embedding\n", 488 | " noise_prompt_weights = None,\n", 489 | " noise_prompt_seeds = [14575],#\n", 490 | "\n", 491 | " # mse settings\n", 492 | " mse_withzeros = True,\n", 493 | " mse_decay_rate = 50,\n", 494 | " mse_epoches = 5,\n", 495 | " mse_quantize = True,\n", 496 | "\n", 497 | " # end itteration\n", 498 | " max_itter = -1,\n", 499 | ")\n", 500 | "\n", 501 | "mse_decay = 0\n", 502 | "if args.init_weight:\n", 503 | " mse_decay = args.init_weight / args.mse_epoches\n", 504 | "\n", 505 | "# \n", 506 | "augs = nn.Sequential(\n", 507 | " \n", 508 | " K.RandomHorizontalFlip(p=0.5),\n", 509 | " K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode='border'), # padding_mode=2\n", 510 | " K.RandomPerspective(0.2,p=0.4, ),\n", 511 | " K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),\n", 512 | "\n", 513 | " )\n", 514 | "\n", 515 | "noise = noise_gen([1, 3, args.size[0], args.size[1]])\n", 516 | "image = TF.to_pil_image(noise.div(5).add(0.5).clamp(0, 1)[0])\n", 517 | "image.save('init3.png')" 518 | ], 519 | "execution_count": null, 520 | "outputs": [] 521 | }, 522 | { 523 | "cell_type": "markdown", 524 | "metadata": { 525 | "id": "crRQdV3jPSvw" 526 | }, 527 | "source": [ 528 | "# **Constraints**" 529 | ] 530 | }, 531 | { 532 | "cell_type": "code", 533 | "metadata": { 534 | "cellView": "form", 535 | "id": "qSUp-6M0m-Dz" 536 | }, 537 | "source": [ 538 | "#@markdown #*Double-click here and edit me if you like*\n", 539 | "#@markdown ---\n", 540 | "\n", 541 | "from PIL import Image, ImageDraw\n", 542 | "\n", 543 | "if args.constraint_regions and args.init_image:\n", 544 | " \n", 545 | " device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", 546 | "\n", 547 | " toksX, toksY = args.size[0] // 16, args.size[1] // 16\n", 548 | "\n", 549 | " pil_image = Image.open(args.init_image).convert('RGB')\n", 550 | " pil_image = pil_image.resize((toksX * 16, toksY * 16), Image.LANCZOS)\n", 551 | "\n", 552 | " width, height = pil_image.size\n", 553 | "\n", 554 | " d = ImageDraw.Draw(pil_image)\n", 555 | " for i in range(0,width,16):\n", 556 | " d.text((i+4,0), f\"{int(i/16)}\", fill=(50,200,100))\n", 557 | " for i in range(0,height,16):\n", 558 | " d.text((4,i), f\"{int(i/16)}\", fill=(50,200,100))\n", 559 | "\n", 560 | " pil_image = TF.to_tensor(pil_image)\n", 561 | "\n", 562 | " print(pil_image.shape)\n", 563 | " for i in range(pil_image.shape[1]):\n", 564 | " for j in range(pil_image.shape[2]):\n", 565 | " if i%16 == 0 or j%16 ==0:\n", 566 | " pil_image[:,i,j] = 0\n", 567 | "\n", 568 | " # select region\n", 569 | " c_h = [16,32]\n", 570 | " c_w = [0,40]\n", 571 | "\n", 572 | " c_hf = [i*16 for i in c_h]\n", 573 | " c_wf = [i*16 for i in c_w]\n", 574 | "\n", 575 | " pil_image[0,c_hf[0]:c_hf[1],c_wf[0]:c_wf[1]] = 0\n", 576 | "\n", 577 | " TF.to_pil_image(pil_image.cpu()).save('progress.png')\n", 578 | " 
display.display(display.Image('progress.png'))\n", 579 | "\n", 580 | " z_mask = torch.zeros([1, 256, int(height/16), int(width/16)]).to(device)\n", 581 | " z_mask[:,:,c_h[0]:c_h[1],c_w[0]:c_w[1]] = 1" 582 | ], 583 | "execution_count": null, 584 | "outputs": [] 585 | }, 586 | { 587 | "cell_type": "markdown", 588 | "metadata": { 589 | "id": "jmuK7MTroHHE" 590 | }, 591 | "source": [ 592 | "#**Final Steps**" 593 | ] 594 | }, 595 | { 596 | "cell_type": "code", 597 | "metadata": { 598 | "id": "g7EDme5RYCrt", 599 | "cellView": "form" 600 | }, 601 | "source": [ 602 | "#@markdown #**Fire up the AI**\n", 603 | "\n", 604 | "#@markdown ---\n", 605 | "\n", 606 | "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", 607 | "\n", 608 | "print('Using device:', device)\n", 609 | "print('using prompts: ', args.prompts)\n", 610 | "\n", 611 | "tv_loss = TVLoss() \n", 612 | "\n", 613 | "model = load_vqgan_model(args.vqgan_config, args.vqgan_checkpoint).to(device)\n", 614 | "perceptor = clip.load(args.clip_model, jit=False)[0].eval().requires_grad_(False).to(device)\n", 615 | "mse_weight = args.init_weight\n", 616 | "\n", 617 | "cut_size = perceptor.visual.input_resolution\n", 618 | "# e_dim = model.quantize.e_dim\n", 619 | "\n", 620 | "if args.vqgan_checkpoint == 'vqgan_openimages_f16_8192.ckpt':\n", 621 | " e_dim = 256\n", 622 | " n_toks = model.quantize.n_embed\n", 623 | " z_min = model.quantize.embed.weight.min(dim=0).values[None, :, None, None]\n", 624 | " z_max = model.quantize.embed.weight.max(dim=0).values[None, :, None, None]\n", 625 | "else:\n", 626 | " e_dim = model.quantize.e_dim\n", 627 | " n_toks = model.quantize.n_e\n", 628 | " z_min = model.quantize.embedding.weight.min(dim=0).values[None, :, None, None]\n", 629 | " z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None]\n", 630 | "\n", 631 | "\n", 632 | "make_cutouts = MakeCutouts(cut_size, args.cutn, cut_pow=args.cut_pow)\n", 633 | "\n", 634 | "f = 2**(model.decoder.num_resolutions - 1)\n", 635 | "toksX, toksY = args.size[0] // f, args.size[1] // f\n", 636 | "\n", 637 | "if args.seed is not None:\n", 638 | " torch.manual_seed(args.seed)\n", 639 | "\n", 640 | "if args.init_image:\n", 641 | " pil_image = Image.open(args.init_image).convert('RGB')\n", 642 | " pil_image = pil_image.resize((toksX * 16, toksY * 16), Image.LANCZOS)\n", 643 | " pil_image = TF.to_tensor(pil_image)\n", 644 | " if args.use_noise:\n", 645 | " pil_image = pil_image + args.use_noise * torch.randn_like(pil_image) \n", 646 | " z, *_ = model.encode(pil_image.to(device).unsqueeze(0) * 2 - 1)\n", 647 | "\n", 648 | "else:\n", 649 | " \n", 650 | " one_hot = F.one_hot(torch.randint(n_toks, [toksY * toksX], device=device), n_toks).float()\n", 651 | "\n", 652 | " if args.vqgan_checkpoint == 'vqgan_openimages_f16_8192.ckpt':\n", 653 | " z = one_hot @ model.quantize.embed.weight\n", 654 | " else:\n", 655 | " z = one_hot @ model.quantize.embedding.weight\n", 656 | " z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)\n", 657 | "\n", 658 | "\n", 659 | "z = EMATensor(z, args.ema_val)\n", 660 | "\n", 661 | "if args.mse_withzeros and not args.init_image:\n", 662 | " z_orig = torch.zeros_like(z.tensor)\n", 663 | "else:\n", 664 | " z_orig = z.tensor.clone()\n", 665 | "\n", 666 | "\n", 667 | "opt = optim.Adam(z.parameters(), lr=args.step_size, weight_decay=0.00000000)\n", 668 | "\n", 669 | "normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],\n", 670 | " std=[0.26862954, 0.26130258, 0.27577711])\n", 671 | "\n", 672 | "pMs 
= []\n", 673 | "\n", 674 | "if args.noise_prompt_weights and args.noise_prompt_seeds:\n", 675 | " for seed, weight in zip(args.noise_prompt_seeds, args.noise_prompt_weights):\n", 676 | " gen = torch.Generator().manual_seed(seed)\n", 677 | " embed = torch.empty([1, perceptor.visual.output_dim]).normal_(generator=gen)\n", 678 | " pMs.append(Prompt(embed, weight).to(device))\n", 679 | "\n", 680 | "for prompt in args.prompts:\n", 681 | " txt, weight, stop = parse_prompt(prompt)\n", 682 | " embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()\n", 683 | " pMs.append(Prompt(embed, weight, stop).to(device))\n", 684 | " # pMs[0].embed = pMs[0].embed + Prompt(embed, weight, stop).embed.to(device)\n", 685 | "\n", 686 | "\n", 687 | "def synth(z, quantize=True):\n", 688 | " if args.constraint_regions:\n", 689 | " z = replace_grad(z, z * z_mask)\n", 690 | "\n", 691 | " if quantize:\n", 692 | " if args.vqgan_checkpoint == 'vqgan_openimages_f16_8192.ckpt':\n", 693 | " z_q = vector_quantize(z.movedim(1, 3), model.quantize.embed.weight).movedim(3, 1)\n", 694 | " else:\n", 695 | " z_q = vector_quantize(z.movedim(1, 3), model.quantize.embedding.weight).movedim(3, 1)\n", 696 | "\n", 697 | " else:\n", 698 | " z_q = z.model\n", 699 | "\n", 700 | " return clamp_with_grad(model.decode(z_q).add(1).div(2), 0, 1)\n", 701 | "\n", 702 | "@torch.no_grad()\n", 703 | "def checkin(i, losses):\n", 704 | " losses_str = ', '.join(f'{loss.item():g}' for loss in losses)\n", 705 | " tqdm.write(f'i: {i}, loss: {sum(losses).item():g}, losses: {losses_str}')\n", 706 | " out = synth(z.average, True)\n", 707 | "\n", 708 | " TF.to_pil_image(out[0].cpu()).save('progress.png') \n", 709 | " display.display(display.Image('progress.png')) \n", 710 | "\n", 711 | "\n", 712 | "def ascend_txt():\n", 713 | " global mse_weight\n", 714 | "\n", 715 | " out = synth(z.tensor)\n", 716 | "\n", 717 | " if args.record_generation:\n", 718 | " with torch.no_grad():\n", 719 | " global vid_index\n", 720 | " out_a = synth(z.average, True)\n", 721 | " TF.to_pil_image(out_a[0].cpu()).save(f'/content/vids/{vid_index}.png')\n", 722 | " vid_index += 1\n", 723 | "\n", 724 | " cutouts = make_cutouts(out)\n", 725 | "\n", 726 | " if args.use_augs:\n", 727 | " cutouts = augs(cutouts)\n", 728 | "\n", 729 | " if args.noise_fac:\n", 730 | " facs = cutouts.new_empty([args.cutn, 1, 1, 1]).uniform_(0, args.noise_fac)\n", 731 | " cutouts = cutouts + facs * torch.randn_like(cutouts)\n", 732 | "\n", 733 | " iii = perceptor.encode_image(normalize(cutouts)).float()\n", 734 | "\n", 735 | " result = []\n", 736 | "\n", 737 | " if args.init_weight:\n", 738 | " \n", 739 | " global z_orig\n", 740 | " \n", 741 | " result.append(F.mse_loss(z.tensor, z_orig) * mse_weight / 2)\n", 742 | " # result.append(F.mse_loss(z, z_orig) * ((1/torch.tensor((i)*2 + 1))*mse_weight) / 2)\n", 743 | "\n", 744 | " with torch.no_grad():\n", 745 | " if i > 0 and i%args.mse_decay_rate==0 and i <= args.mse_decay_rate*args.mse_epoches:\n", 746 | "\n", 747 | " if args.mse_quantize:\n", 748 | " z_orig = vector_quantize(z.average.movedim(1, 3), model.quantize.embedding.weight).movedim(3, 1)#z.average\n", 749 | " else:\n", 750 | " z_orig = z.average.clone()\n", 751 | "\n", 752 | " if mse_weight - mse_decay > 0 and mse_weight - mse_decay >= mse_decay:\n", 753 | " mse_weight = mse_weight - mse_decay\n", 754 | " print(f\"updated mse weight: {mse_weight}\")\n", 755 | " else:\n", 756 | " mse_weight = 0\n", 757 | " print(f\"updated mse weight: {mse_weight}\")\n", 758 | "\n", 759 | " for prompt in pMs:\n", 
760 | " result.append(prompt(iii))\n", 761 | "\n", 762 | " return result\n", 763 | "\n", 764 | "vid_index = 0\n", 765 | "def train(i):\n", 766 | " \n", 767 | " opt.zero_grad()\n", 768 | " lossAll = ascend_txt()\n", 769 | "\n", 770 | " if i % args.display_freq == 0:\n", 771 | " checkin(i, lossAll)\n", 772 | " \n", 773 | " loss = sum(lossAll)\n", 774 | "\n", 775 | " loss.backward()\n", 776 | " opt.step()\n", 777 | " z.update()\n", 778 | "\n", 779 | "i = 0\n", 780 | "try:\n", 781 | " with tqdm() as pbar:\n", 782 | " while True and i != args.max_itter:\n", 783 | "\n", 784 | " train(i)\n", 785 | "\n", 786 | " if i > 0 and i%args.mse_decay_rate==0 and i <= args.mse_decay_rate * args.mse_epoches:\n", 787 | " z = EMATensor(z.average, args.ema_val)\n", 788 | " opt = optim.Adam(z.parameters(), lr=args.step_size, weight_decay=0.00000000)\n", 789 | "\n", 790 | " i += 1\n", 791 | " pbar.update()\n", 792 | "\n", 793 | "except KeyboardInterrupt:\n", 794 | " pass\n" 795 | ], 796 | "execution_count": null, 797 | "outputs": [] 798 | }, 799 | { 800 | "cell_type": "markdown", 801 | "metadata": { 802 | "id": "CDUaCaRnUKMZ" 803 | }, 804 | "source": [ 805 | "# **Generate video**" 806 | ] 807 | }, 808 | { 809 | "cell_type": "code", 810 | "metadata": { 811 | "id": "DT3hKb5gJUPq", 812 | "cellView": "form" 813 | }, 814 | "source": [ 815 | "#@markdown #*Double-click here and edit me*\n", 816 | "\n", 817 | "%cd vids\n", 818 | "\n", 819 | "images = \"%d.png\"\n", 820 | "video = \"/content/old_man_iceberg.mp4\"\n", 821 | "!ffmpeg -r 30 -i $images -crf 20 -s 640x512 -pix_fmt yuv420p $video\n", 822 | "\n", 823 | "%cd .." 824 | ], 825 | "execution_count": null, 826 | "outputs": [] 827 | }, 828 | { 829 | "cell_type": "markdown", 830 | "metadata": { 831 | "id": "UiZMW3kAUD1f" 832 | }, 833 | "source": [ 834 | "**Delete all frames from folder**" 835 | ] 836 | }, 837 | { 838 | "cell_type": "code", 839 | "metadata": { 840 | "cellView": "form", 841 | "id": "OXMWkXo7okMA" 842 | }, 843 | "source": [ 844 | "#@markdown Run this tab if you wanna clear all the genarated frames images\n", 845 | "\n", 846 | "\n", 847 | "%cd vids\n", 848 | "%rm *.png\n", 849 | "%cd .." 850 | ], 851 | "execution_count": null, 852 | "outputs": [] 853 | } 854 | ] 855 | } -------------------------------------------------------------------------------- /PixelDrawer.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "accelerator": "GPU", 6 | "colab": { 7 | "name": "PixelDrawer.ipynb", 8 | "provenance": [], 9 | "collapsed_sections": [], 10 | "machine_shape": "hm" 11 | }, 12 | "kernelspec": { 13 | "display_name": "Python 3 (ipykernel)", 14 | "language": "python", 15 | "name": "python3" 16 | }, 17 | "language_info": { 18 | "codemirror_mode": { 19 | "name": "ipython", 20 | "version": 3 21 | }, 22 | "file_extension": ".py", 23 | "mimetype": "text/x-python", 24 | "name": "python", 25 | "nbconvert_exporter": "python", 26 | "pygments_lexer": "ipython3", 27 | "version": "3.8.10" 28 | } 29 | }, 30 | "cells": [ 31 | { 32 | "cell_type": "markdown", 33 | "metadata": { 34 | "id": "sK-2G1dm-5RG" 35 | }, 36 | "source": [ 37 | "# Generate images from text phrases with VQGAN and CLIP (PixelDrawer)\n", 38 | "Notebook by Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings). 
\n", 39 | "\n", 40 | "The original BigGAN + CLIP method was made by https://twitter.com/advadnoun.\n", 41 | "\n", 42 | "Special thanks to [@dribnet's clipit repo](https://github.com/dribnet/clipit)\n", 43 | "\n", 44 | "Modifications by: Justin John\n", 45 | "\n", 46 | "TODO: gen vid" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "metadata": { 52 | "cellView": "form", 53 | "id": "iUa-GDBrDNyo" 54 | }, 55 | "source": [ 56 | "#@markdown ###Licensed under the MIT License\n", 57 | "#@markdown ---\n", 58 | "\n", 59 | "# Copyright (c) 2021 Katherine Crowson\n", 60 | "\n", 61 | "# Permission is hereby granted, free of charge, to any person obtaining a copy\n", 62 | "# of this software and associated documentation files (the \"Software\"), to deal\n", 63 | "# in the Software without restriction, including without limitation the rights\n", 64 | "# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n", 65 | "# copies of the Software, and to permit persons to whom the Software is\n", 66 | "# furnished to do so, subject to the following conditions:\n", 67 | "\n", 68 | "# The above copyright notice and this permission notice shall be included in\n", 69 | "# all copies or substantial portions of the Software.\n", 70 | "\n", 71 | "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", 72 | "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", 73 | "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n", 74 | "# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", 75 | "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n", 76 | "# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n", 77 | "# THE SOFTWARE." 78 | ], 79 | "execution_count": null, 80 | "outputs": [] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "metadata": { 85 | "id": "KLYwf1WCDtCv", 86 | "cellView": "form" 87 | }, 88 | "source": [ 89 | "#@markdown ##**Check GPU type**\n", 90 | "#@markdown ### Factory reset runtime if you don't have the desired GPU.\n", 91 | "\n", 92 | "#@markdown ---\n", 93 | "\n", 94 | "\n", 95 | "\n", 96 | "\n", 97 | "#@markdown V100 = Excellent (*Available only for Colab Pro users*)\n", 98 | "\n", 99 | "#@markdown P100 = Very Good (*Available only for Colab Pro users*)\n", 100 | "\n", 101 | "#@markdown T4 = Good (*preferred*)(*Available only for Colab Pro users*)\n", 102 | "\n", 103 | "#@markdown K80 = Meh (*Not tested yet*)\n", 104 | "\n", 105 | "#@markdown P4 = (*Not Recommended*) \n", 106 | "\n", 107 | "#@markdown ---\n", 108 | "\n", 109 | "!nvidia-smi -L" 110 | ], 111 | "execution_count": null, 112 | "outputs": [] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "metadata": { 117 | "id": "qQOvOhnKQ-Tu", 118 | "cellView": "form" 119 | }, 120 | "source": [ 121 | "#@title ##**Setup**\n", 122 | "\n", 123 | "#@markdown Please execute this cell by pressing the _Play_ button \n", 124 | "#@markdown on the left. For setup,\n", 125 | "#@markdown **you need to run this cell,\n", 126 | "#@markdown then choose Runtime -> Restart Runtime from the menu,\n", 127 | "#@markdown and then run the cell again**. 
It should remind you to\n", 128 | "#@markdown do this after the first run.\n", 129 | "\n", 130 | "#@markdown Setup can take 5-10 minutes, but once it is complete it usually does not need to be repeated\n", 131 | "#@markdown until you close the window.\n", 132 | "\n", 133 | "#@markdown **Note**: This installs the software on the Colab \n", 134 | "#@markdown notebook in the cloud and not on your computer.\n", 135 | "\n", 136 | "#@markdown ---\n", 137 | "\n", 138 | "# https://stackoverflow.com/a/56727659/1010653\n", 139 | "\n", 140 | "from google.colab import output\n", 141 | "\n", 142 | "nvidia_output = !nvidia-smi --query-gpu=memory.total --format=noheader,nounits,csv\n", 143 | "gpu_memory = int(nvidia_output[0])\n", 144 | "if gpu_memory < 14000:\n", 145 | " output.eval_js('new Audio(\"https://upload.wikimedia.org/wikipedia/commons/0/05/Beep-09.ogg\").play()')\n", 146 | " warning_string = f\"--> GPU check: ONLY {gpu_memory} MiB available: WARNING, THIS IS PROBABLY NOT ENOUGH <--\"\n", 147 | " print(warning_string)\n", 148 | " output.eval_js('alert(\"Warning - low GPU (see message)\")')\n", 149 | "else:\n", 150 | " print(f\"GPU check: {gpu_memory} MiB available: this should be fine\")\n", 151 | "\n", 152 | "from IPython.utils import io\n", 153 | "with io.capture_output() as captured:\n", 154 | " !git clone https://github.com/openai/CLIP\n", 155 | " # !pip install taming-transformers\n", 156 | " !git clone https://github.com/CompVis/taming-transformers.git\n", 157 | " !rm -Rf pixray\n", 158 | " !git clone https://github.com/dribnet/pixray\n", 159 | " !pip install ftfy regex tqdm omegaconf pytorch-lightning\n", 160 | " !pip install kornia\n", 161 | " !pip install imageio-ffmpeg \n", 162 | " !pip install einops\n", 163 | " !pip install torch-optimizer\n", 164 | " !pip install easydict\n", 165 | " !pip install braceexpand\n", 166 | " !pip install git+https://github.com/pvigier/perlin-numpy\n", 167 | "\n", 168 | " # ClipDraw deps\n", 169 | " !pip install svgwrite\n", 170 | " !pip install svgpathtools\n", 171 | " !pip install cssutils\n", 172 | " !pip install numba\n", 173 | " !pip install torch-tools\n", 174 | " !pip install visdom\n", 175 | "\n", 176 | " !git clone https://github.com/BachiLi/diffvg\n", 177 | " %cd diffvg\n", 178 | " # !ls\n", 179 | " !git submodule update --init --recursive\n", 180 | " !python setup.py install\n", 181 | " %cd ..\n", 182 | "\n", 183 | "output.clear()\n", 184 | "import sys\n", 185 | "sys.path.append(\"pixray\")\n", 186 | "\n", 187 | "result_msg = \"setup complete\"\n", 188 | "import IPython\n", 189 | "import os\n", 190 | "if not os.path.isfile(\"first_init_complete\"):\n", 191 | " # put stuff in here that should only happen once\n", 192 | " !mkdir -p models\n", 193 | " os.mknod(\"first_init_complete\")\n", 194 | " result_msg = \"Please choose Runtime -> Restart Runtime from the menu, and then run Setup again\"\n", 195 | "\n", 196 | "js_code = f'''\n", 197 | "document.querySelector(\"#output-area\").appendChild(document.createTextNode(\"{result_msg}\"));\n", 198 | "'''\n", 199 | "js_code += '''\n", 200 | "for (rule of document.styleSheets[0].cssRules){\n", 201 | " if (rule.selectorText=='body') break\n", 202 | "}\n", 203 | "rule.style.fontSize = '30px'\n", 204 | "'''\n", 205 | "display(IPython.display.Javascript(js_code))" 206 | ], 207 | "execution_count": null, 208 | "outputs": [] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "metadata": { 213 | "id": "XziodsCqVC2A", 214 | "cellView": "form" 215 | }, 216 | "source": [ 217 | "#@title ##**Settings**\n", 218 | 
"\n", 219 | "#@markdown Enter a description of what you want to draw - I usually add #pixelart to the prompt.\n", 220 | "#@markdown You can switch between other models if you don't want to use the one you have (PixelDraw).\n", 221 | "#@markdown Like VQGAN or CLIPDraw.
\n", 222 | "\n", 223 | "prompts = \"MF Doom. #pixelart\" #@param {type:\"string\"}\n", 224 | "\n", 225 | "init_image = \"\" #@param {type:\"string\"}\n", 226 | "\n", 227 | "aspect = \"widescreen\" #@param [\"widescreen\", \"square\"]\n", 228 | "\n", 229 | "use = \"pixeldraw\" #@param [\"vqgan\", \"pixeldraw\", \"clipdraw\"]\n", 230 | "\n", 231 | "#@markdown When you have the settings you want, press the play button on the left.\n", 232 | "#@markdown The system will save these and start generating images below.\n", 233 | "\n", 234 | "#@markdown When that is done you can change these\n", 235 | "#@markdown settings and see if you get different results. Or if you get\n", 236 | "#@markdown impatient, just select \"Runtime -> Interrupt Execution\".\n", 237 | "#@markdown Note that the first time you run it may take a bit longer\n", 238 | "#@markdown as nessary files are downloaded.\n", 239 | "\n", 240 | "\n", 241 | "#@markdown\n", 242 | "#@markdown *Advanced: you can also edit this cell and add add additional\n", 243 | "#@markdown settings, combining settings from different notebooks.*\n", 244 | "\n", 245 | "\n", 246 | "# Simple setup\n", 247 | "import pixray\n", 248 | "\n", 249 | "# these are good settings for pixeldraw\n", 250 | "pixray.reset_settings()\n", 251 | "pixray.add_settings(prompts=prompts, aspect=aspect)\n", 252 | "pixray.add_settings(quality=\"better\", scale=2.5)\n", 253 | "pixray.add_settings(drawer=use)\n", 254 | "pixray.add_settings(display_clear=True)\n", 255 | "pixray.add_settings(iterations=2000, display_every=50)\n", 256 | "\n", 257 | "#### YOU CAN ADD YOUR OWN CUSTOM SETTING HERE ####\n", 258 | "# this is the example of how to run longer with less frequent display\n", 259 | "# pixray.add_settings(iterations=500, display_every=50)\n", 260 | "\n", 261 | "settings = pixray.apply_settings()\n", 262 | "pixray.do_init(settings)\n", 263 | "pixray.do_run(settings)\n", 264 | "#@markdown ---" 265 | ], 266 | "execution_count": null, 267 | "outputs": [] 268 | }, 269 | { 270 | "cell_type": "markdown", 271 | "metadata": { 272 | "id": "SBhzz5NCNgX8" 273 | }, 274 | "source": [ 275 | "JS to prevent idle timeout:\n", 276 | "\n", 277 | "Press F12 OR CTRL + SHIFT + I OR right click on this website -> inspect.\n", 278 | "Then click on the console tab and paste in the following code.\n", 279 | "\n", 280 | "```javascript\n", 281 | "function ClickConnect(){\n", 282 | "console.log(\"Working\");\n", 283 | "document.querySelector(\"colab-toolbar-button#connect\").click()\n", 284 | "}\n", 285 | "setInterval(ClickConnect,60000)\n", 286 | "```" 287 | ] 288 | } 289 | ] 290 | } -------------------------------------------------------------------------------- /Pixray_Panorama_Demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "accelerator": "GPU", 6 | "colab": { 7 | "name": "Pixray Panorama Demo", 8 | "provenance": [], 9 | "collapsed_sections": [], 10 | "machine_shape": "hm" 11 | }, 12 | "kernelspec": { 13 | "display_name": "Python 3 (ipykernel)", 14 | "language": "python", 15 | "name": "python3" 16 | }, 17 | "language_info": { 18 | "codemirror_mode": { 19 | "name": "ipython", 20 | "version": 3 21 | }, 22 | "file_extension": ".py", 23 | "mimetype": "text/x-python", 24 | "name": "python", 25 | "nbconvert_exporter": "python", 26 | "pygments_lexer": "ipython3", 27 | "version": "3.8.10" 28 | } 29 | }, 30 | "cells": [ 31 | { 32 | "cell_type": "markdown", 33 | "metadata": { 34 | "id": "JyURfzEun4q5" 
35 | }, 36 | "source": [ 37 | "# **Pixray Panorama demo**\n", 38 | "\n", 39 | "Special thanks to [@dribnet's pixray repo](https://github.com/dribnet/pixray)\n", 40 | "and [@altsoph](http://altsoph.com/)\n" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "metadata": { 46 | "id": "qQOvOhnKQ-Tu", 47 | "cellView": "form" 48 | }, 49 | "source": [ 50 | "#@title **Setup**\n", 51 | "\n", 52 | "#@markdown Please execute this cell by pressing the _Play_ button \n", 53 | "#@markdown on the left. For setup,\n", 54 | "#@markdown **you need to run this cell,\n", 55 | "#@markdown then choose Runtime -> Restart Runtime from the menu,\n", 56 | "#@markdown and then run the cell again**. It should remind you to\n", 57 | "#@markdown do this after the first run.\n", 58 | "\n", 59 | "#@markdown Setup can take 5-10 minutes, but once it is complete it usually does not need to be repeated\n", 60 | "#@markdown until you close the window.\n", 61 | "\n", 62 | "#@markdown **Note**: This installs the software on the Colab \n", 63 | "#@markdown notebook in the cloud and not on your computer.\n", 64 | "\n", 65 | "# https://stackoverflow.com/a/56727659/1010653\n", 66 | "\n", 67 | "# Add a gpu check\n", 68 | "# (this can get better over time)\n", 69 | "from google.colab import output\n", 70 | "\n", 71 | "nvidia_output = !nvidia-smi --query-gpu=memory.total --format=noheader,nounits,csv\n", 72 | "gpu_memory = int(nvidia_output[0])\n", 73 | "if gpu_memory < 14000:\n", 74 | " output.eval_js('new Audio(\"https://upload.wikimedia.org/wikipedia/commons/0/05/Beep-09.ogg\").play()')\n", 75 | " warning_string = f\"--> GPU check: ONLY {gpu_memory} MiB available: WARNING, THIS IS PROBABLY NOT ENOUGH <--\"\n", 76 | " print(warning_string)\n", 77 | " output.eval_js('alert(\"Warning - low GPU (see message)\")')\n", 78 | "else:\n", 79 | " print(f\"GPU check: {gpu_memory} MiB available: this should be fine\")\n", 80 | "\n", 81 | "from IPython.utils import io\n", 82 | "with io.capture_output() as captured:\n", 83 | " # On 2021/10/08, Colab updated its default PyTorch installation to a version that causes\n", 84 | " # problems with diffvg. 
So, first thing, let's roll back to the older version:\n", 85 | " !pip install torch==1.9.0+cu102 torchvision==0.10.0+cu102 -f https://download.pytorch.org/whl/torch/ -f https://download.pytorch.org/whl/torchvision/\n", 86 | "\n", 87 | " !git clone https://github.com/openai/CLIP\n", 88 | " # !pip install taming-transformers\n", 89 | " !git clone https://github.com/CompVis/taming-transformers.git\n", 90 | " !rm -Rf pixray\n", 91 | " !git clone https://github.com/dribnet/pixray\n", 92 | " !pip install ftfy regex tqdm omegaconf pytorch-lightning\n", 93 | " !pip install kornia\n", 94 | " !pip install imageio-ffmpeg \n", 95 | " !pip install einops\n", 96 | " !pip install torch-optimizer\n", 97 | " !pip install easydict\n", 98 | " !pip install braceexpand\n", 99 | " !pip install git+https://github.com/pvigier/perlin-numpy\n", 100 | "\n", 101 | " # ClipDraw deps\n", 102 | " !pip install svgwrite\n", 103 | " !pip install svgpathtools\n", 104 | " !pip install cssutils\n", 105 | " !pip install numba\n", 106 | " !pip install torch-tools\n", 107 | " !pip install visdom\n", 108 | "\n", 109 | " !git clone https://github.com/BachiLi/diffvg\n", 110 | " %cd diffvg\n", 111 | " # !ls\n", 112 | " !git submodule update --init --recursive\n", 113 | " !python setup.py install\n", 114 | " %cd ..\n", 115 | "\n", 116 | "output.clear()\n", 117 | "import sys\n", 118 | "import shutil\n", 119 | "import numpy as np\n", 120 | "from PIL import ImageFile, Image, PngImagePlugin\n", 121 | "sys.path.append(\"pixray\")\n", 122 | "\n", 123 | "def make_seed_img(from_fn, to_fn='cur_seed.png', delete_pixels = 20, shift_pixels = 240):\n", 124 | " in_img = Image.open(from_fn)\n", 125 | " seed_img_array = np.random.rand(in_img.size[1],in_img.size[0],3) * 255\n", 126 | " seed_img = Image.fromarray(seed_img_array.astype('uint8')).convert('RGB')\n", 127 | " # mask_img = Image.new(mode=\"RGB\", size=(s_x, s_y), color=(0, 0, 0))\n", 128 | " seed_img.paste(in_img.crop((shift_pixels,0,in_img.size[0]-delete_pixels,in_img.size[1])), (0,0))\n", 129 | " seed_img.save(to_fn)\n", 130 | " # seed_img.crop((shift_pixels,0,seed_img.size[0]-delete_pixels,seed_img.size[1]))\n", 131 | "\n", 132 | "result_msg = \"setup complete\"\n", 133 | "import IPython\n", 134 | "import os\n", 135 | "if not os.path.isfile(\"first_init_complete\"):\n", 136 | " # put stuff in here that should only happen once\n", 137 | " !mkdir -p models\n", 138 | " os.mknod(\"first_init_complete\")\n", 139 | " result_msg = \"Please choose Runtime -> Restart Runtime from the menu, and then run Setup again\"\n", 140 | "\n", 141 | "js_code = f'''\n", 142 | "document.querySelector(\"#output-area\").appendChild(document.createTextNode(\"{result_msg}\"));\n", 143 | "'''\n", 144 | "js_code += '''\n", 145 | "for (rule of document.styleSheets[0].cssRules){\n", 146 | " if (rule.selectorText=='body') break\n", 147 | "}\n", 148 | "rule.style.fontSize = '30px'\n", 149 | "'''\n", 150 | "display(IPython.display.Javascript(js_code))" 151 | ], 152 | "execution_count": null, 153 | "outputs": [] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "metadata": { 158 | "id": "5Ule9Ee4S78b", 159 | "cellView": "form" 160 | }, 161 | "source": [ 162 | "#@title **First Frame Settings**\n", 163 | "\n", 164 | "#@markdown Enter a description of what you want to draw - I usually add #pixelart to the prompt.\n", 165 | "#@markdown If PixelDraw is not used, it will use VQGAN instead.\n", 166 | "#@markdown
\n", 167 | "\n", 168 | "prompts = \"Scary Monsters And Nice Sprites. #pixelart #8bit\" #@param {type:\"string\"}\n", 169 | "\n", 170 | "aspect = \"widescreen\" ##param [\"widescreen\", \"square\"]\n", 171 | "\n", 172 | "do_pixel = True #@param {type:\"boolean\"}\n", 173 | "\n", 174 | "#@markdown Specify the desired palette (\"\" for default), here's a few examples:\n", 175 | "#@markdown * red (16 color black to red ramp)\n", 176 | "#@markdown * rust\\8 (8 color black to rust ramp)\n", 177 | "#@markdown * black->red->white (16 color black/red/white ramp)\n", 178 | "#@markdown * [#000000, #ff0000, #ffff00, #000080] (four colors)\n", 179 | "#@markdown * red->yellow;[black] (16 colors from ramp and also black)\n", 180 | "#@markdown * Named colors can be anything in this lookup table\n", 181 | "\n", 182 | "use_palette = \"[#000000, #071008, #0e2011, #153019, #1c4022, #23502a, #2a6033, #31703b, #388044, #3f8f4c, #469f54, #4daf5d, #54bf65, #5bcf6e, #62df76, #69ef7f];black->white\" #@param {type:\"string\"}\n", 183 | "#@markdown Use this flag to encourage smoothess:\n", 184 | "smoothness = True #@param {type:\"boolean\"} \n", 185 | "\n", 186 | "#@markdown Use this flag to encourage color saturation (use it against color fading):\n", 187 | "saturation = True #@param {type:\"boolean\"} \n", 188 | "\n", 189 | "#@markdown When you have the settings you want, press the play button on the left.\n", 190 | "#@markdown The system will save these and start generating images below.\n", 191 | "\n", 192 | "#@markdown When that is done you can change these\n", 193 | "#@markdown settings and see if you get different results. Or if you get\n", 194 | "#@markdown impatient, just select \"Runtime -> Interrupt Execution\".\n", 195 | "#@markdown Note that the first time you run it may take a bit longer\n", 196 | "#@markdown as nessary files are downloaded.\n", 197 | "\n", 198 | "\n", 199 | "#@markdown\n", 200 | "#@markdown **Advanced: you can also edit this cell and add add additional settings, combining settings from different notebooks.**\n", 201 | "\n", 202 | "\n", 203 | "\n", 204 | "# Simple setup\n", 205 | "import pixray\n", 206 | "\n", 207 | "# these are good settings for pixeldraw\n", 208 | "pixray.reset_settings()\n", 209 | "pixray.add_settings(prompts=prompts, aspect=aspect)\n", 210 | "pixray.add_settings(quality=\"better\", scale=2.5)\n", 211 | "pixray.add_settings(display_clear=True)\n", 212 | "\n", 213 | "if do_pixel:\n", 214 | " pixray.add_settings(drawer=\"pixel\")\n", 215 | "\n", 216 | "# palette = None\n", 217 | "if use_palette and use_palette!='None':\n", 218 | " pixray.add_settings(target_palette=use_palette)\n", 219 | "\n", 220 | "if smoothness and smoothness!='None':\n", 221 | " pixray.add_settings(smoothness=2.0, smoothness_type='log')\n", 222 | "\n", 223 | "if saturation:\n", 224 | " pixray.add_settings(saturation=1.0)\n", 225 | "\n", 226 | "pixray.add_settings(noise_prompt_seeds=[1,2,3])\n", 227 | "\n", 228 | "#### YOU CAN ADD YOUR OWN CUSTOM SETTING HERE ####\n", 229 | "# this is the example of how to run longer with less frequent display\n", 230 | "# pixray.add_settings(iterations=500, display_every=50)\n", 231 | "\n", 232 | "settings = pixray.apply_settings()\n", 233 | "pixray.do_init(settings)\n", 234 | "pixray.do_run(settings)\n", 235 | "\n", 236 | "shutil.copy('./output.png', './very_first_frame.png')" 237 | ], 238 | "execution_count": null, 239 | "outputs": [] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "metadata": { 244 | "id": "_LxkC8GEZA-q", 245 | "cellView": "form" 246 | }, 247 
| "source": [ 248 | "#@title **Make sure you like the first frame, then start this cell to generate all other frames:**\n", 249 | "\n", 250 | "shutil.copy('./very_first_frame.png', f'./frame_{0:03d}.png')\n", 251 | "\n", 252 | "delete_from_right_pixels = 20 #@param\n", 253 | "shift_pixels = 240 #@param\n", 254 | "half_frames2generate = 5 #@param\n", 255 | "\n", 256 | "\n", 257 | "for frame in range(half_frames2generate):\n", 258 | " shutil.copy('./output.png', f'./frame_{frame:03d}.png')\n", 259 | " make_seed_img(f'./frame_{frame:03d}.png')\n", 260 | " pixray.reset_settings()\n", 261 | " pixray.add_settings(prompts=prompts, aspect=aspect)\n", 262 | " pixray.add_settings(quality=\"better\", scale=2.5)\n", 263 | " pixray.add_settings(display_clear=True)\n", 264 | " pixray.add_settings(init_image='./cur_seed.png')\n", 265 | " if do_pixel:\n", 266 | " pixray.add_settings(drawer=\"pixel\")\n", 267 | " # palette = None\n", 268 | " if use_palette and use_palette!='None':\n", 269 | " pixray.add_settings(target_palette=use_palette)\n", 270 | " if smoothness and smoothness!='None':\n", 271 | " pixray.add_settings(smoothness=2.0, smoothness_type='log')\n", 272 | " if saturation:\n", 273 | " pixray.add_settings(saturation=1.0)\n", 274 | " pixray.add_settings(noise_prompt_seeds=[1,2,3]) \n", 275 | "\n", 276 | " settings = pixray.apply_settings()\n", 277 | " pixray.do_init(settings)\n", 278 | " pixray.do_run(settings)\n", 279 | "\n", 280 | " # shutil.copy('./output.png', f'./frame_{frame+1:03d}.png')\n" 281 | ], 282 | "execution_count": null, 283 | "outputs": [] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "metadata": { 288 | "id": "G_LR5QpoZBB2", 289 | "cellView": "form" 290 | }, 291 | "source": [ 292 | "#@title **Blend frames into the single image**\n", 293 | "\n", 294 | "from glob import glob\n", 295 | "# these values are okay for the widescreen aspect only\n", 296 | "sidex = 500\n", 297 | "sidey = 280\n", 298 | "\n", 299 | "\n", 300 | "frames = []\n", 301 | "for fn in glob('frame_*.png'):\n", 302 | " frames.append( fn )\n", 303 | "\n", 304 | "pano_img = Image.new(mode=\"RGB\", \n", 305 | " size=(sidex+shift_pixels*len(frames)-(sidex-shift_pixels), \n", 306 | " sidey+0*len(frames)), color=(255, 255, 255))\n", 307 | "\n", 308 | "for idx, fr in enumerate(frames):\n", 309 | " fr_img = Image.open(fr)\n", 310 | " fr_img = fr_img.convert('RGB')\n", 311 | " if not idx:\n", 312 | " pano_img.paste( fr_img , (idx*shift_pixels, 0))\n", 313 | " else:\n", 314 | " pano_img.paste( fr_img.crop((shift_pixels,0,sidex,sidey)), ((idx+1)*shift_pixels, 0))\n", 315 | " for x in range(shift_pixels):\n", 316 | " w = x/shift_pixels\n", 317 | " for y in range(sidey):\n", 318 | " int_c = np.array(pano_img.getpixel((idx*shift_pixels+x,y)))*(1-w)+\\\n", 319 | " np.array( fr_img.getpixel((x,y)) ) *(w)\n", 320 | " int_c = int_c.astype(int).tolist()\n", 321 | " pano_img.putpixel((idx*shift_pixels+x,y),tuple(int_c))\n", 322 | "pano_img.save('pano.png')\n", 323 | "pano_img" 324 | ], 325 | "execution_count": null, 326 | "outputs": [] 327 | }, 328 | { 329 | "cell_type": "code", 330 | "metadata": { 331 | "id": "TCJu5oZIVTIm", 332 | "cellView": "form" 333 | }, 334 | "source": [ 335 | "#@title **Compile and view the resulting video**\n", 336 | "\n", 337 | "from tqdm.notebook import tqdm\n", 338 | "from subprocess import Popen, PIPE\n", 339 | "\n", 340 | "out_width = 500 #@param\n", 341 | "speed = 3 #@param\n", 342 | "fps = 25 #@param\n", 343 | "\n", 344 | "pano_img = Image.open('pano.png')\n", 345 | "pano_img.size\n", 346 | "\n", 347 
| "\n", 348 | "frames = []\n", 349 | "for idx, x in enumerate( range(0,pano_img.size[0]-out_width,speed) ):\n", 350 | " frames.append( pano_img.crop( (x,0,x+out_width,pano_img.size[1]) ) )\n", 351 | "\n", 352 | "\n", 353 | "p = Popen(['ffmpeg',\n", 354 | " '-y',\n", 355 | " '-f', 'image2pipe',\n", 356 | " '-vcodec', 'png',\n", 357 | " '-r', str(fps),\n", 358 | " '-i',\n", 359 | " '-',\n", 360 | " '-vcodec', 'libx264',\n", 361 | " '-r', str(fps),\n", 362 | " '-pix_fmt', 'yuv420p',\n", 363 | " '-crf', '17',\n", 364 | " '-preset', 'veryslow',\n", 365 | " 'pano.mp4'], stdin=PIPE)\n", 366 | "\n", 367 | "for im in tqdm(frames):\n", 368 | " im.save(p.stdin, 'PNG')\n", 369 | "p.stdin.close()\n", 370 | "p.wait()\n", 371 | "\n", 372 | "from IPython.display import HTML\n", 373 | "from base64 import b64encode\n", 374 | "mp4 = open('pano.mp4','rb').read()\n", 375 | "data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n", 376 | "HTML(\"\"\"\n", 377 | "\n", 380 | "\"\"\" % data_url)" 381 | ], 382 | "execution_count": null, 383 | "outputs": [] 384 | } 385 | ] 386 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VQGAN+CLIP and other image generation system 2 | VQGAN+CLIP Colab Notebook with user-friendly interface. 3 | 4 | 5 | **Latest Notebook**: Open In Colab 6 | 7 | **Mse regulized zquantize Notebook**: Open In Colab 8 | 9 | **Zooming (Latest release with few addons)(W.I.P)**: Open In Colab 10 | 11 | 12 | 13 | **PixelDrawer**: Open In Colab 14 | 15 | 16 | **Pixray Panorama Demo**: Open In Colab 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | # Citations 62 | 63 | ```bibtex 64 | @misc{unpublished2021clip, 65 | title = {CLIP: Connecting Text and Images}, 66 | author = {Alec Radford, Ilya Sutskever, Jong Wook Kim, Gretchen Krueger, Sandhini Agarwal}, 67 | year = {2021} 68 | } 69 | ``` 70 | ```bibtex 71 | @misc{esser2020taming, 72 | title={Taming Transformers for High-Resolution Image Synthesis}, 73 | author={Patrick Esser and Robin Rombach and Björn Ommer}, 74 | year={2020}, 75 | eprint={2012.09841}, 76 | archivePrefix={arXiv}, 77 | primaryClass={cs.CV} 78 | } 79 | ``` 80 | Katherine Crowson - https://github.com/crowsonkb 81 | 82 | Public Domain images from Open Access Images at the Art Institute of Chicago - https://www.artic.edu/open-access/open-access-images 83 | -------------------------------------------------------------------------------- /VQGAN+CLIP(Updated).ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "VQGAN+CLIP(Updated).ipynb", 7 | "private_outputs": true, 8 | "provenance": [], 9 | "collapsed_sections": [], 10 | "machine_shape": "hm" 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | }, 16 | "language_info": { 17 | "name": "python" 18 | }, 19 | "accelerator": "GPU" 20 | }, 21 | "cells": [ 22 | { 23 | "cell_type": "markdown", 24 | "metadata": { 25 | "id": "CppIQlPhhwhs" 26 | }, 27 | "source": [ 28 | "# Generate images from text phrases with VQGAN and CLIP (z + quantize method with augmentations).\n", 29 | "\n", 30 | "Notebook by Katherine Crowson (https://github.com/crowsonkb, 
https://twitter.com/RiversHaveWings). The original BigGAN + CLIP method was made by https://twitter.com/advadnoun. Translated, with added explanations and modifications, by Eleiber # 8347, and the user-friendly interface was made thanks to Abulafia # 3734.\n", 31 | "\n", 32 | "For a detailed tutorial on how to use it, I recommend [visiting this article](https://yourcriatures.miraheze.org/wiki/Help:Create_images_with_VQGAN+CLIP), made by Jakeukalane # 2767 and Avengium (Angel) # 3715.\n", 33 | "\n", 34 | "Modified by: Justin John\n" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "metadata": { 40 | "id": "VA1PHoJrRiK9", 41 | "cellView": "form" 42 | }, 43 | "source": [ 44 | "#@markdown #**Licensed under the MIT License (*Double-click me to read the license agreement*)**\n", 45 | "#@markdown ---\n", 46 | "\n", 47 | "# Copyright (c) 2021 Katherine Crowson\n", 48 | "\n", 49 | "# Permission is hereby granted, free of charge, to any person obtaining a copy\n", 50 | "# of this software and associated documentation files (the \"Software\"), to deal\n", 51 | "# in the Software without restriction, including without limitation the rights\n", 52 | "# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n", 53 | "# copies of the Software, and to permit persons to whom the Software is\n", 54 | "# furnished to do so, subject to the following conditions:\n", 55 | "\n", 56 | "# The above copyright notice and this permission notice shall be included in\n", 57 | "# all copies or substantial portions of the Software.\n", 58 | "\n", 59 | "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", 60 | "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", 61 | "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n", 62 | "# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", 63 | "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n", 64 | "# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n", 65 | "# THE SOFTWARE."
66 | ], 67 | "execution_count": null, 68 | "outputs": [] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "metadata": { 73 | "id": "eq0-E5mjSpmP", 74 | "cellView": "form" 75 | }, 76 | "source": [ 77 | "#@markdown #**Check GPU type**\n", 78 | "#@markdown ### Factory reset runtime if you don't have the desired GPU.\n", 79 | "\n", 80 | "#@markdown ---\n", 81 | "\n", 82 | "\n", 83 | "\n", 84 | "\n", 85 | "#@markdown V100 = Excellent (*Available only for Colab Pro users*)\n", 86 | "\n", 87 | "#@markdown P100 = Very Good\n", 88 | "\n", 89 | "#@markdown T4 = Good (*preferred*)\n", 90 | "\n", 91 | "#@markdown K80 = Meh\n", 92 | "\n", 93 | "#@markdown P4 = (*Not Recommended*) \n", 94 | "\n", 95 | "#@markdown ---\n", 96 | "\n", 97 | "!nvidia-smi -L" 98 | ], 99 | "execution_count": null, 100 | "outputs": [] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "metadata": { 105 | "id": "IO09yGQNSmSd", 106 | "cellView": "form" 107 | }, 108 | "source": [ 109 | "#@markdown #**Anti-Disconnect for Google Colab**\n", 110 | "#@markdown ## Run this to stop it from disconnecting automatically \n", 111 | "#@markdown **(disconnects anyhow after 6 - 12 hrs for using the free version of Colab.)**\n", 112 | "#@markdown *(Pro users will get about 24 hrs usage time[depends])*\n", 113 | "#@markdown ---\n", 114 | "\n", 115 | "import IPython\n", 116 | "js_code = '''\n", 117 | "function ClickConnect(){\n", 118 | "console.log(\"Working\");\n", 119 | "document.querySelector(\"colab-toolbar-button#connect\").click()\n", 120 | "}\n", 121 | "setInterval(ClickConnect,60000)\n", 122 | "'''\n", 123 | "display(IPython.display.Javascript(js_code))" 124 | ], 125 | "execution_count": null, 126 | "outputs": [] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "metadata": { 131 | "id": "wSfISAhyPmyp", 132 | "cellView": "form" 133 | }, 134 | "source": [ 135 | "#@markdown #**Installation of libraries**\n", 136 | "# @markdown This cell will take a little while because it has to download several libraries\n", 137 | "\n", 138 | "#@markdown ---\n", 139 | " \n", 140 | "print(\"Installing CLIP...\")\n", 141 | "!git clone https://github.com/openai/CLIP &> /dev/null\n", 142 | " \n", 143 | "print(\"Installing Python Libraries for AI...\")\n", 144 | "!git clone https://github.com/CompVis/taming-transformers &> /dev/null\n", 145 | "!pip install transformers &> /dev/null\n", 146 | "!pip install ftfy regex tqdm omegaconf pytorch-lightning &> /dev/null\n", 147 | "!pip install kornia &> /dev/null\n", 148 | "!pip install einops &> /dev/null\n", 149 | "!pip install wget &> /dev/null\n", 150 | " \n", 151 | "print(\"Installing libraries for metadata management...\")\n", 152 | "!pip install stegano &> /dev/null\n", 153 | "!apt install exempi &> /dev/null\n", 154 | "!pip install python-xmp-toolkit &> /dev/null\n", 155 | "!pip install imgtag &> /dev/null\n", 156 | "!pip install pillow==7.1.2 &> /dev/null\n", 157 | " \n", 158 | "print(\"Installing Python libraries for creating videos...\")\n", 159 | "!pip install imageio-ffmpeg &> /dev/null\n", 160 | "!mkdir steps\n", 161 | "print(\"Installation completed.\")" 162 | ], 163 | "execution_count": null, 164 | "outputs": [] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "metadata": { 169 | "id": "FhhdWrSxQhwg", 170 | "cellView": "form" 171 | }, 172 | "source": [ 173 | "#@markdown #**Selection of models to download**\n", 174 | "#@markdown ---\n", 175 | "\n", 176 | "#@markdown By default, the notebook downloads Model 16384 from ImageNet. 
There are others that are not downloaded by default, since it would be in vain if you are not going to use them, so if you want to use them, simply select the models to download.\n", 177 | "\n", 178 | "#@markdown ---\n", 179 | "\n", 180 | "imagenet_1024 = False #@param {type:\"boolean\"}\n", 181 | "imagenet_16384 = True #@param {type:\"boolean\"}\n", 182 | "gumbel_8192 = False #@param {type:\"boolean\"}\n", 183 | "coco = False #@param {type:\"boolean\"}\n", 184 | "faceshq = False #@param {type:\"boolean\"}\n", 185 | "wikiart_1024 = False #@param {type:\"boolean\"}\n", 186 | "wikiart_16384 = False #@param {type:\"boolean\"}\n", 187 | "sflckr = False #@param {type:\"boolean\"}\n", 188 | "ade20k = False #@param {type:\"boolean\"}\n", 189 | "ffhq = False #@param {type:\"boolean\"}\n", 190 | "celebahq = False #@param {type:\"boolean\"}\n", 191 | "\n", 192 | "if imagenet_1024:\n", 193 | " !curl -L -o vqgan_imagenet_f16_1024.yaml -C - 'https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1' #ImageNet 1024\n", 194 | " !curl -L -o vqgan_imagenet_f16_1024.ckpt -C - 'https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/files/?p=%2Fckpts%2Flast.ckpt&dl=1' #ImageNet 1024\n", 195 | "if imagenet_16384:\n", 196 | " !curl -L -o vqgan_imagenet_f16_16384.yaml -C - 'https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1' #ImageNet 16384\n", 197 | " !curl -L -o vqgan_imagenet_f16_16384.ckpt -C - 'https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/files/?p=%2Fckpts%2Flast.ckpt&dl=1' #ImageNet 16384\n", 198 | "if gumbel_8192:\n", 199 | " !curl -L -o gumbel_8192.yaml -C - 'https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1' #Gumbel 8192\n", 200 | " !curl -L -o gumbel_8192.ckpt -C - 'https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/files/?p=%2Fckpts%2Flast.ckpt&dl=1' #Gumbel 8192\n", 201 | "if coco:\n", 202 | " !curl -L -o coco.yaml -C - 'https://dl.nmkd.de/ai/clip/coco/coco.yaml' #COCO\n", 203 | " !curl -L -o coco.ckpt -C - 'https://dl.nmkd.de/ai/clip/coco/coco.ckpt' #COCO\n", 204 | "if faceshq:\n", 205 | " !curl -L -o faceshq.yaml -C - 'https://drive.google.com/uc?export=download&id=1fHwGx_hnBtC8nsq7hesJvs-Klv-P0gzT' #FacesHQ\n", 206 | " !curl -L -o faceshq.ckpt -C - 'https://app.koofr.net/content/links/a04deec9-0c59-4673-8b37-3d696fe63a5d/files/get/last.ckpt?path=%2F2020-11-13T21-41-45_faceshq_transformer%2Fcheckpoints%2Flast.ckpt' #FacesHQ\n", 207 | "if wikiart_1024: \n", 208 | " !curl -L -o wikiart_1024.yaml -C - 'http://mirror.io.community/blob/vqgan/wikiart.yaml' #WikiArt 1024\n", 209 | " !curl -L -o wikiart_1024.ckpt -C - 'http://mirror.io.community/blob/vqgan/wikiart.ckpt' #WikiArt 1024\n", 210 | "if wikiart_16384: \n", 211 | " !curl -L -o wikiart_16384.yaml -C - 'http://eaidata.bmk.sh/data/Wikiart_16384/wikiart_f16_16384_8145600.yaml' #WikiArt 16384\n", 212 | " !curl -L -o wikiart_16384.ckpt -C - 'http://eaidata.bmk.sh/data/Wikiart_16384/wikiart_f16_16384_8145600.ckpt' #WikiArt 16384\n", 213 | "if sflckr:\n", 214 | " !curl -L -o sflckr.yaml -C - 'https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fconfigs%2F2020-11-09T13-31-51-project.yaml&dl=1' #S-FLCKR\n", 215 | " !curl -L -o sflckr.ckpt -C - 'https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fcheckpoints%2Flast.ckpt&dl=1' #S-FLCKR\n", 216 | "if ade20k:\n", 217 | " !curl -L -o ade20k.yaml -C - 'https://static.miraheze.org/intercriaturaswiki/b/bf/Ade20k.txt' #ADE20K\n", 218 | " !curl -L 
-o ade20k.ckpt -C - 'https://app.koofr.net/content/links/0f65c2cd-7102-4550-a2bd-07fd383aac9e/files/get/last.ckpt?path=%2F2020-11-20T21-45-44_ade20k_transformer%2Fcheckpoints%2Flast.ckpt' #ADE20K\n", 219 | "if ffhq:\n", 220 | " !curl -L -o ffhq.yaml -C - 'https://app.koofr.net/content/links/0fc005bf-3dca-4079-9d40-cdf38d42cd7a/files/get/2021-04-23T18-19-01-project.yaml?path=%2F2021-04-23T18-19-01_ffhq_transformer%2Fconfigs%2F2021-04-23T18-19-01-project.yaml&force' #FFHQ\n", 221 | " !curl -L -o ffhq.ckpt -C - 'https://app.koofr.net/content/links/0fc005bf-3dca-4079-9d40-cdf38d42cd7a/files/get/last.ckpt?path=%2F2021-04-23T18-19-01_ffhq_transformer%2Fcheckpoints%2Flast.ckpt&force' #FFHQ\n", 222 | "if celebahq:\n", 223 | " !curl -L -o celebahq.yaml -C - 'https://app.koofr.net/content/links/6dddf083-40c8-470a-9360-a9dab2a94e96/files/get/2021-04-23T18-11-19-project.yaml?path=%2F2021-04-23T18-11-19_celebahq_transformer%2Fconfigs%2F2021-04-23T18-11-19-project.yaml&force' #CelebA-HQ\n", 224 | " !curl -L -o celebahq.ckpt -C - 'https://app.koofr.net/content/links/6dddf083-40c8-470a-9360-a9dab2a94e96/files/get/last.ckpt?path=%2F2021-04-23T18-11-19_celebahq_transformer%2Fcheckpoints%2Flast.ckpt&force' #CelebA-HQ" 225 | ], 226 | "execution_count": null, 227 | "outputs": [] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "metadata": { 232 | "id": "EXMSuW2EQWsd", 233 | "cellView": "form" 234 | }, 235 | "source": [ 236 | "# @title Loading libraries and definitions\n", 237 | " \n", 238 | "import argparse\n", 239 | "import math\n", 240 | "from pathlib import Path\n", 241 | "import sys\n", 242 | " \n", 243 | "sys.path.append('./taming-transformers')\n", 244 | "from IPython import display\n", 245 | "from base64 import b64encode\n", 246 | "from omegaconf import OmegaConf\n", 247 | "from PIL import Image\n", 248 | "from taming.models import cond_transformer, vqgan\n", 249 | "import torch\n", 250 | "from torch import nn, optim\n", 251 | "from torch.nn import functional as F\n", 252 | "from torchvision import transforms\n", 253 | "from torchvision.transforms import functional as TF\n", 254 | "from tqdm.notebook import tqdm\n", 255 | " \n", 256 | "from CLIP import clip\n", 257 | "import kornia.augmentation as K\n", 258 | "import numpy as np\n", 259 | "import imageio\n", 260 | "from PIL import ImageFile, Image\n", 261 | "from imgtag import ImgTag # metadatos \n", 262 | "from libxmp import * # metadatos\n", 263 | "import libxmp # metadatos\n", 264 | "from stegano import lsb\n", 265 | "import json\n", 266 | "ImageFile.LOAD_TRUNCATED_IMAGES = True\n", 267 | " \n", 268 | "def sinc(x):\n", 269 | " return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))\n", 270 | " \n", 271 | " \n", 272 | "def lanczos(x, a):\n", 273 | " cond = torch.logical_and(-a < x, x < a)\n", 274 | " out = torch.where(cond, sinc(x) * sinc(x/a), x.new_zeros([]))\n", 275 | " return out / out.sum()\n", 276 | " \n", 277 | " \n", 278 | "def ramp(ratio, width):\n", 279 | " n = math.ceil(width / ratio + 1)\n", 280 | " out = torch.empty([n])\n", 281 | " cur = 0\n", 282 | " for i in range(out.shape[0]):\n", 283 | " out[i] = cur\n", 284 | " cur += ratio\n", 285 | " return torch.cat([-out[1:].flip([0]), out])[1:-1]\n", 286 | " \n", 287 | " \n", 288 | "def resample(input, size, align_corners=True):\n", 289 | " n, c, h, w = input.shape\n", 290 | " dh, dw = size\n", 291 | " \n", 292 | " input = input.view([n * c, 1, h, w])\n", 293 | " \n", 294 | " if dh < h:\n", 295 | " kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, 
input.dtype)\n", 296 | " pad_h = (kernel_h.shape[0] - 1) // 2\n", 297 | " input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')\n", 298 | " input = F.conv2d(input, kernel_h[None, None, :, None])\n", 299 | " \n", 300 | " if dw < w:\n", 301 | " kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)\n", 302 | " pad_w = (kernel_w.shape[0] - 1) // 2\n", 303 | " input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')\n", 304 | " input = F.conv2d(input, kernel_w[None, None, None, :])\n", 305 | " \n", 306 | " input = input.view([n, c, h, w])\n", 307 | " return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)\n", 308 | " \n", 309 | " \n", 310 | "class ReplaceGrad(torch.autograd.Function):\n", 311 | " @staticmethod\n", 312 | " def forward(ctx, x_forward, x_backward):\n", 313 | " ctx.shape = x_backward.shape\n", 314 | " return x_forward\n", 315 | " \n", 316 | " @staticmethod\n", 317 | " def backward(ctx, grad_in):\n", 318 | " return None, grad_in.sum_to_size(ctx.shape)\n", 319 | " \n", 320 | " \n", 321 | "replace_grad = ReplaceGrad.apply\n", 322 | " \n", 323 | " \n", 324 | "class ClampWithGrad(torch.autograd.Function):\n", 325 | " @staticmethod\n", 326 | " def forward(ctx, input, min, max):\n", 327 | " ctx.min = min\n", 328 | " ctx.max = max\n", 329 | " ctx.save_for_backward(input)\n", 330 | " return input.clamp(min, max)\n", 331 | " \n", 332 | " @staticmethod\n", 333 | " def backward(ctx, grad_in):\n", 334 | " input, = ctx.saved_tensors\n", 335 | " return grad_in * (grad_in * (input - input.clamp(ctx.min, ctx.max)) >= 0), None, None\n", 336 | " \n", 337 | " \n", 338 | "clamp_with_grad = ClampWithGrad.apply\n", 339 | " \n", 340 | " \n", 341 | "def vector_quantize(x, codebook):\n", 342 | " d = x.pow(2).sum(dim=-1, keepdim=True) + codebook.pow(2).sum(dim=1) - 2 * x @ codebook.T\n", 343 | " indices = d.argmin(-1)\n", 344 | " x_q = F.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook\n", 345 | " return replace_grad(x_q, x)\n", 346 | " \n", 347 | " \n", 348 | "class Prompt(nn.Module):\n", 349 | " def __init__(self, embed, weight=1., stop=float('-inf')):\n", 350 | " super().__init__()\n", 351 | " self.register_buffer('embed', embed)\n", 352 | " self.register_buffer('weight', torch.as_tensor(weight))\n", 353 | " self.register_buffer('stop', torch.as_tensor(stop))\n", 354 | " \n", 355 | " def forward(self, input):\n", 356 | " input_normed = F.normalize(input.unsqueeze(1), dim=2)\n", 357 | " embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)\n", 358 | " dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)\n", 359 | " dists = dists * self.weight.sign()\n", 360 | " return self.weight.abs() * replace_grad(dists, torch.maximum(dists, self.stop)).mean()\n", 361 | " \n", 362 | " \n", 363 | "def parse_prompt(prompt):\n", 364 | " vals = prompt.rsplit(':', 2)\n", 365 | " vals = vals + ['', '1', '-inf'][len(vals):]\n", 366 | " return vals[0], float(vals[1]), float(vals[2])\n", 367 | " \n", 368 | " \n", 369 | "class MakeCutouts(nn.Module):\n", 370 | " def __init__(self, cut_size, cutn, cut_pow=1.):\n", 371 | " super().__init__()\n", 372 | " self.cut_size = cut_size\n", 373 | " self.cutn = cutn\n", 374 | " self.cut_pow = cut_pow\n", 375 | " self.augs = nn.Sequential(\n", 376 | " K.RandomHorizontalFlip(p=0.5),\n", 377 | " # K.RandomSolarize(0.01, 0.01, p=0.7),\n", 378 | " K.RandomSharpness(0.3,p=0.4),\n", 379 | " K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode='border'),\n", 380 | " K.RandomPerspective(0.2,p=0.4),\n", 381 | " 
K.ColorJitter(hue=0.01, saturation=0.01, p=0.7))\n", 382 | " self.noise_fac = 0.1\n", 383 | " \n", 384 | " \n", 385 | " def forward(self, input):\n", 386 | " sideY, sideX = input.shape[2:4]\n", 387 | " max_size = min(sideX, sideY)\n", 388 | " min_size = min(sideX, sideY, self.cut_size)\n", 389 | " cutouts = []\n", 390 | " for _ in range(self.cutn):\n", 391 | " size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)\n", 392 | " offsetx = torch.randint(0, sideX - size + 1, ())\n", 393 | " offsety = torch.randint(0, sideY - size + 1, ())\n", 394 | " cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]\n", 395 | " cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))\n", 396 | " batch = self.augs(torch.cat(cutouts, dim=0))\n", 397 | " if self.noise_fac:\n", 398 | " facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)\n", 399 | " batch = batch + facs * torch.randn_like(batch)\n", 400 | " return batch\n", 401 | " \n", 402 | " \n", 403 | "def load_vqgan_model(config_path, checkpoint_path):\n", 404 | " config = OmegaConf.load(config_path)\n", 405 | " if config.model.target == 'taming.models.vqgan.VQModel':\n", 406 | " model = vqgan.VQModel(**config.model.params)\n", 407 | " model.eval().requires_grad_(False)\n", 408 | " model.init_from_ckpt(checkpoint_path)\n", 409 | " elif config.model.target == 'taming.models.cond_transformer.Net2NetTransformer':\n", 410 | " parent_model = cond_transformer.Net2NetTransformer(**config.model.params)\n", 411 | " parent_model.eval().requires_grad_(False)\n", 412 | " parent_model.init_from_ckpt(checkpoint_path)\n", 413 | " model = parent_model.first_stage_model\n", 414 | " elif config.model.target == 'taming.models.vqgan.GumbelVQ':\n", 415 | " model = vqgan.GumbelVQ(**config.model.params)\n", 416 | " print(config.model.params)\n", 417 | " model.eval().requires_grad_(False)\n", 418 | " model.init_from_ckpt(checkpoint_path)\n", 419 | " else:\n", 420 | " raise ValueError(f'unknown model type: {config.model.target}')\n", 421 | " del model.loss\n", 422 | " return model\n", 423 | " \n", 424 | " \n", 425 | "def resize_image(image, out_size):\n", 426 | " ratio = image.size[0] / image.size[1]\n", 427 | " area = min(image.size[0] * image.size[1], out_size[0] * out_size[1])\n", 428 | " size = round((area * ratio)**0.5), round((area / ratio)**0.5)\n", 429 | " return image.resize(size, Image.LANCZOS)\n", 430 | "\n", 431 | "def download_img(img_url):\n", 432 | " try:\n", 433 | " return wget.download(img_url,out=\"input.jpg\")\n", 434 | " except:\n", 435 | " return\n" 436 | ], 437 | "execution_count": null, 438 | "outputs": [] 439 | }, 440 | { 441 | "cell_type": "markdown", 442 | "metadata": { 443 | "id": "1tthw0YaispD" 444 | }, 445 | "source": [ 446 | "## Tools for execution:\n", 447 | "Mainly what you will have to modify will be `texts:`, there you can place the text (s) you want to generate (separated with `|`). It is a list because you can put more than one text, and so the AI ​​tries to 'mix' the images, giving the same priority to both texts.\n", 448 | "\n", 449 | "To use an initial image to the model, you just have to upload a file to the Colab environment (in the section on the left), and then modify `init_image:` putting the exact name of the file. Example: `sample.png`\n", 450 | "\n", 451 | "You can also modify the model by changing the lines that say `model:`. Currently 1024, 16384, WikiArt, S-FLCKR and COCO-Stuff are available. 
To activate them you have to have downloaded them first, and then you can simply select it.\n", 452 | "\n", 453 | "You can also use `target_images`, which is basically putting one or more images on it that the AI ​​will take as a \"target\", fulfilling the same function as putting text on it. To put more than one you have to use `|` as a separator." 454 | ] 455 | }, 456 | { 457 | "cell_type": "code", 458 | "metadata": { 459 | "id": "ZdlpRFL8UAlW", 460 | "cellView": "form" 461 | }, 462 | "source": [ 463 | "#@markdown #**Parameters**\n", 464 | "#@markdown ---\n", 465 | "texts = \"a fantasy world.\" #@param {type:\"string\"}\n", 466 | "width = 300#@param {type:\"number\"}\n", 467 | "height = 300#@param {type:\"number\"}\n", 468 | "model = \"vqgan_imagenet_f16_16384\" #@param [\"vqgan_imagenet_f16_16384\", \"vqgan_imagenet_f16_1024\", \"wikiart_1024\", \"wikiart_16384\", \"coco\", \"faceshq\", \"sflckr\", \"ade20k\", \"ffhq\", \"celebahq\", \"gumbel_8192\"]\n", 469 | "images_interval = 50#@param {type:\"number\"}\n", 470 | "init_image = \"\"#@param {type:\"string\"}\n", 471 | "target_images = \"\"#@param {type:\"string\"}\n", 472 | "seed = -1#@param {type:\"number\"}\n", 473 | "max_iterations = -1#@param {type:\"number\"}\n", 474 | "input_images = \"\"\n", 475 | "\n", 476 | "model_names={\"vqgan_imagenet_f16_16384\": 'ImageNet 16384',\"vqgan_imagenet_f16_1024\":\"ImageNet 1024\", \n", 477 | " \"wikiart_1024\":\"WikiArt 1024\", \"wikiart_16384\":\"WikiArt 16384\", \"coco\":\"COCO-Stuff\", \"faceshq\":\"FacesHQ\", \"sflckr\":\"S-FLCKR\", \"ade20k\":\"ADE20K\", \"ffhq\":\"FFHQ\", \"celebahq\":\"CelebA-HQ\", \"gumbel_8192\": \"Gumbel 8192\"}\n", 478 | "name_model = model_names[model] \n", 479 | "\n", 480 | "if model == \"gumbel_8192\":\n", 481 | " is_gumbel = True\n", 482 | "else:\n", 483 | " is_gumbel = False\n", 484 | "\n", 485 | "if seed == -1:\n", 486 | " seed = None\n", 487 | "if init_image == \"None\":\n", 488 | " init_image = None\n", 489 | "elif init_image and init_image.lower().startswith(\"http\"):\n", 490 | " init_image = download_img(init_image)\n", 491 | "\n", 492 | "\n", 493 | "if target_images == \"None\" or not target_images:\n", 494 | " target_images = []\n", 495 | "else:\n", 496 | " target_images = target_images.split(\"|\")\n", 497 | " target_images = [image.strip() for image in target_images]\n", 498 | "\n", 499 | "if init_image or target_images != []:\n", 500 | " input_images = True\n", 501 | "\n", 502 | "texts = [frase.strip() for frase in texts.split(\"|\")]\n", 503 | "if texts == ['']:\n", 504 | " texts = []\n", 505 | "\n", 506 | "\n", 507 | "args = argparse.Namespace(\n", 508 | " prompts=texts,\n", 509 | " image_prompts=target_images,\n", 510 | " noise_prompt_seeds=[],\n", 511 | " noise_prompt_weights=[],\n", 512 | " size=[width, height],\n", 513 | " init_image=init_image,\n", 514 | " init_weight=0.,\n", 515 | " clip_model='ViT-B/32',\n", 516 | " vqgan_config=f'{model}.yaml',\n", 517 | " vqgan_checkpoint=f'{model}.ckpt',\n", 518 | " step_size=0.1,\n", 519 | " cutn=64,\n", 520 | " cut_pow=1.,\n", 521 | " display_freq=images_interval,\n", 522 | " seed=seed,\n", 523 | ")" 524 | ], 525 | "execution_count": null, 526 | "outputs": [] 527 | }, 528 | { 529 | "cell_type": "code", 530 | "metadata": { 531 | "id": "g7EDme5RYCrt", 532 | "cellView": "form" 533 | }, 534 | "source": [ 535 | "#@markdown #**Fire up the AI**\n", 536 | "\n", 537 | "#@markdown ---\n", 538 | "\n", 539 | "\n", 540 | "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", 541 | "print('Using 
device:', device)\n", 542 | "if texts:\n", 543 | " print('Using texts:', texts)\n", 544 | "if target_images:\n", 545 | " print('Using image prompts:', target_images)\n", 546 | "if args.seed is None:\n", 547 | " seed = torch.seed()\n", 548 | "else:\n", 549 | " seed = args.seed\n", 550 | "torch.manual_seed(seed)\n", 551 | "print('Using seed:', seed)\n", 552 | "\n", 553 | "model = load_vqgan_model(args.vqgan_config, args.vqgan_checkpoint).to(device)\n", 554 | "perceptor = clip.load(args.clip_model, jit=False)[0].eval().requires_grad_(False).to(device)\n", 555 | "\n", 556 | "cut_size = perceptor.visual.input_resolution\n", 557 | "if is_gumbel:\n", 558 | " e_dim = model.quantize.embedding_dim\n", 559 | "else:\n", 560 | " e_dim = model.quantize.e_dim\n", 561 | "\n", 562 | "f = 2**(model.decoder.num_resolutions - 1)\n", 563 | "make_cutouts = MakeCutouts(cut_size, args.cutn, cut_pow=args.cut_pow)\n", 564 | "if is_gumbel:\n", 565 | " n_toks = model.quantize.n_embed\n", 566 | "else:\n", 567 | " n_toks = model.quantize.n_e\n", 568 | "\n", 569 | "toksX, toksY = args.size[0] // f, args.size[1] // f\n", 570 | "sideX, sideY = toksX * f, toksY * f\n", 571 | "if is_gumbel:\n", 572 | " z_min = model.quantize.embed.weight.min(dim=0).values[None, :, None, None]\n", 573 | " z_max = model.quantize.embed.weight.max(dim=0).values[None, :, None, None]\n", 574 | "else:\n", 575 | " z_min = model.quantize.embedding.weight.min(dim=0).values[None, :, None, None]\n", 576 | " z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None]\n", 577 | "\n", 578 | "if args.init_image:\n", 579 | " pil_image = Image.open(args.init_image).convert('RGB')\n", 580 | " pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)\n", 581 | " z, *_ = model.encode(TF.to_tensor(pil_image).to(device).unsqueeze(0) * 2 - 1)\n", 582 | "else:\n", 583 | " one_hot = F.one_hot(torch.randint(n_toks, [toksY * toksX], device=device), n_toks).float()\n", 584 | " if is_gumbel:\n", 585 | " z = one_hot @ model.quantize.embed.weight\n", 586 | " else:\n", 587 | " z = one_hot @ model.quantize.embedding.weight\n", 588 | " z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)\n", 589 | "z_orig = z.clone()\n", 590 | "z.requires_grad_(True)\n", 591 | "opt = optim.Adam([z], lr=args.step_size)\n", 592 | "\n", 593 | "normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],\n", 594 | " std=[0.26862954, 0.26130258, 0.27577711])\n", 595 | "\n", 596 | "pMs = []\n", 597 | "\n", 598 | "for prompt in args.prompts:\n", 599 | " txt, weight, stop = parse_prompt(prompt)\n", 600 | " embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()\n", 601 | " pMs.append(Prompt(embed, weight, stop).to(device))\n", 602 | "\n", 603 | "for prompt in args.image_prompts:\n", 604 | " path, weight, stop = parse_prompt(prompt)\n", 605 | " img = resize_image(Image.open(path).convert('RGB'), (sideX, sideY))\n", 606 | " batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))\n", 607 | " embed = perceptor.encode_image(normalize(batch)).float()\n", 608 | " pMs.append(Prompt(embed, weight, stop).to(device))\n", 609 | "\n", 610 | "for seed, weight in zip(args.noise_prompt_seeds, args.noise_prompt_weights):\n", 611 | " gen = torch.Generator().manual_seed(seed)\n", 612 | " embed = torch.empty([1, perceptor.visual.output_dim]).normal_(generator=gen)\n", 613 | " pMs.append(Prompt(embed, weight).to(device))\n", 614 | "\n", 615 | "def synth(z):\n", 616 | " if is_gumbel:\n", 617 | " z_q = vector_quantize(z.movedim(1, 3), 
model.quantize.embed.weight).movedim(3, 1)\n", 618 | " else:\n", 619 | " z_q = vector_quantize(z.movedim(1, 3), model.quantize.embedding.weight).movedim(3, 1)\n", 620 | " \n", 621 | " return clamp_with_grad(model.decode(z_q).add(1).div(2), 0, 1)\n", 622 | "\n", 623 | "def add_xmp_data(nombrefichero):\n", 624 | " imagen = ImgTag(filename=nombrefichero)\n", 625 | " imagen.xmp.append_array_item(libxmp.consts.XMP_NS_DC, 'creator', 'VQGAN+CLIP', {\"prop_array_is_ordered\":True, \"prop_value_is_array\":True})\n", 626 | " if args.prompts:\n", 627 | " imagen.xmp.append_array_item(libxmp.consts.XMP_NS_DC, 'title', \" | \".join(args.prompts), {\"prop_array_is_ordered\":True, \"prop_value_is_array\":True})\n", 628 | " else:\n", 629 | " imagen.xmp.append_array_item(libxmp.consts.XMP_NS_DC, 'title', 'None', {\"prop_array_is_ordered\":True, \"prop_value_is_array\":True})\n", 630 | " imagen.xmp.append_array_item(libxmp.consts.XMP_NS_DC, 'i', str(i), {\"prop_array_is_ordered\":True, \"prop_value_is_array\":True})\n", 631 | " imagen.xmp.append_array_item(libxmp.consts.XMP_NS_DC, 'model', name_model, {\"prop_array_is_ordered\":True, \"prop_value_is_array\":True})\n", 632 | " imagen.xmp.append_array_item(libxmp.consts.XMP_NS_DC, 'seed',str(seed) , {\"prop_array_is_ordered\":True, \"prop_value_is_array\":True})\n", 633 | " imagen.xmp.append_array_item(libxmp.consts.XMP_NS_DC, 'input_images',str(input_images) , {\"prop_array_is_ordered\":True, \"prop_value_is_array\":True})\n", 634 | " #for frases in args.prompts:\n", 635 | " # imagen.xmp.append_array_item(libxmp.consts.XMP_NS_DC, 'Prompt' ,frases, {\"prop_array_is_ordered\":True, \"prop_value_is_array\":True})\n", 636 | " imagen.close()\n", 637 | "\n", 638 | "def add_stegano_data(filename):\n", 639 | " data = {\n", 640 | " \"title\": \" | \".join(args.prompts) if args.prompts else None,\n", 641 | " \"notebook\": \"VQGAN+CLIP\",\n", 642 | " \"i\": i,\n", 643 | " \"model\": name_model,\n", 644 | " \"seed\": str(seed),\n", 645 | " \"input_images\": input_images\n", 646 | " }\n", 647 | " lsb.hide(filename, json.dumps(data)).save(filename)\n", 648 | "\n", 649 | "@torch.no_grad()\n", 650 | "def checkin(i, losses):\n", 651 | " losses_str = ', '.join(f'{loss.item():g}' for loss in losses)\n", 652 | " tqdm.write(f'i: {i}, loss: {sum(losses).item():g}, losses: {losses_str}')\n", 653 | " out = synth(z)\n", 654 | " TF.to_pil_image(out[0].cpu()).save('progress.png')\n", 655 | " add_stegano_data('progress.png')\n", 656 | " add_xmp_data('progress.png')\n", 657 | " display.display(display.Image('progress.png'))\n", 658 | "\n", 659 | "def ascend_txt():\n", 660 | " global i\n", 661 | " out = synth(z)\n", 662 | " iii = perceptor.encode_image(normalize(make_cutouts(out))).float()\n", 663 | "\n", 664 | " result = []\n", 665 | "\n", 666 | " if args.init_weight:\n", 667 | " result.append(F.mse_loss(z, z_orig) * args.init_weight / 2)\n", 668 | "\n", 669 | " for prompt in pMs:\n", 670 | " result.append(prompt(iii))\n", 671 | " img = np.array(out.mul(255).clamp(0, 255)[0].cpu().detach().numpy().astype(np.uint8))[:,:,:]\n", 672 | " img = np.transpose(img, (1, 2, 0))\n", 673 | " filename = f\"steps/{i:04}.png\"\n", 674 | " imageio.imwrite(filename, np.array(img))\n", 675 | " add_stegano_data(filename)\n", 676 | " add_xmp_data(filename)\n", 677 | " return result\n", 678 | "\n", 679 | "def train(i):\n", 680 | " opt.zero_grad()\n", 681 | " lossAll = ascend_txt()\n", 682 | " if i % args.display_freq == 0:\n", 683 | " checkin(i, lossAll)\n", 684 | " loss = sum(lossAll)\n", 685 | " 
loss.backward()\n", 686 | "    opt.step()\n", 687 | "    with torch.no_grad():\n", 688 | "        z.copy_(z.maximum(z_min).minimum(z_max))\n", 689 | "\n", 690 | "i = 0\n", 691 | "try:\n", 692 | "    with tqdm() as pbar:\n", 693 | "        while True:\n", 694 | "            train(i)\n", 695 | "            if i == max_iterations:\n", 696 | "                break\n", 697 | "            i += 1\n", 698 | "            pbar.update()\n", 699 | "except KeyboardInterrupt:\n", 700 | "    pass" 701 | ], 702 | "execution_count": null, 703 | "outputs": [] 704 | }, 705 | { 706 | "cell_type": "code", 707 | "metadata": { 708 | "id": "_htVEROPRuje", 709 | "cellView": "form" 710 | }, 711 | "source": [ 712 | "#@markdown **Generate a video with the result (You can edit the frame rate and other settings by double-clicking this cell)**\n", 713 | "init_frame = 1 #This is the frame where the video will start\n", 714 | "last_frame = i #You can change i to the number of the last frame you want to generate. It will raise an error if that number of frames does not exist.\n", 715 | "\n", 716 | "min_fps = 10\n", 717 | "max_fps = 30\n", 718 | "\n", 719 | "total_frames = last_frame-init_frame\n", 720 | "\n", 721 | "length = 15 #Desired video time in seconds\n", 722 | "\n", 723 | "frames = []\n", 724 | "tqdm.write('Generating video...')\n", 725 | "for i in range(init_frame,last_frame): #\n", 726 | "    filename = f\"steps/{i:04}.png\"\n", 727 | "    frames.append(Image.open(filename))\n", 728 | "\n", 729 | "#fps = last_frame/10\n", 730 | "fps = np.clip(total_frames/length,min_fps,max_fps)\n", 731 | "\n", 732 | "from subprocess import Popen, PIPE\n", 733 | "p = Popen(['ffmpeg', '-y', '-f', 'image2pipe', '-vcodec', 'png', '-r', str(fps), '-i', '-', '-vcodec', 'libx264', '-r', str(fps), '-pix_fmt', 'yuv420p', '-crf', '17', '-preset', 'veryslow', 'video.mp4'], stdin=PIPE)\n", 734 | "for im in tqdm(frames):\n", 735 | "    im.save(p.stdin, 'PNG')\n", 736 | "p.stdin.close()\n", 737 | "\n", 738 | "print(\"The video is now being compressed, wait...\")\n", 739 | "p.wait()\n", 740 | "print(\"The video is ready\")" 741 | ], 742 | "execution_count": null, 743 | "outputs": [] 744 | }, 745 | { 746 | "cell_type": "code", 747 | "metadata": { 748 | "id": "M8Oomx6_Ry74", 749 | "cellView": "form" 750 | }, 751 | "source": [ 752 | "#@markdown **View video in browser**\n", 753 | "\n", 754 | "# @markdown *This process may take a little longer. If you don't want to wait, download it by executing the next cell instead of using this cell.*\n", 755 | "mp4 = open('video.mp4','rb').read()\n", 756 | "data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n", 757 | "display.HTML(\"\"\"\n", 758 | "\n", 761 | "\"\"\" % data_url)" 762 | ], 763 | "execution_count": null, 764 | "outputs": [] 765 | }, 766 | { 767 | "cell_type": "code", 768 | "metadata": { 769 | "id": "JqFEaEChR20v", 770 | "cellView": "form" 771 | }, 772 | "source": [ 773 | "#@markdown #**Download the result video**\n", 774 | "from google.colab import files\n", 775 | "files.download(\"video.mp4\")" 776 | ], 777 | "execution_count": null, 778 | "outputs": [] 779 | }, 780 | { 781 | "cell_type": "markdown", 782 | "metadata": { 783 | "id": "PpgclmwdQHMg" 784 | }, 785 | "source": [ 786 | "**Delete all frames from folder**" 787 | ] 788 | }, 789 | { 790 | "cell_type": "code", 791 | "metadata": { 792 | "id": "ZalrMVZ6QJPh", 793 | "cellView": "form" 794 | }, 795 | "source": [ 796 | "#@markdown Run this cell if you want to clear all the generated frame images\n", 797 | "\n", 798 | "\n", 799 | "%cd /content/steps\n", 800 | "%rm *.png\n", 801 | "%cd .."
802 | ], 803 | "execution_count": null, 804 | "outputs": [] 805 | }, 806 | { 807 | "cell_type": "markdown", 808 | "metadata": { 809 | "id": "-o8F1NCNTK2u" 810 | }, 811 | "source": [ 812 | "JS to prevent idle timeout:\n", 813 | "\n", 814 | "Press F12 OR CTRL + SHIFT + I OR right click on this website -> inspect.\n", 815 | "Then click on the console tab and paste in the following code.\n", 816 | "\n", 817 | "```javascript\n", 818 | "function ClickConnect(){\n", 819 | "console.log(\"Working\");\n", 820 | "document.querySelector(\"colab-toolbar-button#connect\").click()\n", 821 | "}\n", 822 | "setInterval(ClickConnect,60000)\n", 823 | "```" 824 | ] 825 | } 826 | ] 827 | } -------------------------------------------------------------------------------- /VQGAN+CLIP_(Zooming)_(z+quantize_method_with_addons).ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "VQGAN+CLIP_(Zooming)_(z+quantize_method_with_addons).ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | } 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "markdown", 21 | "metadata": { 22 | "id": "csJoi-Ppfgpq" 23 | }, 24 | "source": [ 25 | "# Generate images from text phrases with VQGAN and CLIP (z + quantize method), with animation and keyframes\n", 26 | "\n", 27 | "Notebook by Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings). The original BigGAN + CLIP method was made by https://twitter.com/advadnoun. Translated into Spanish and added explanations, and modifications by Eleiber#8347, and the friendly interface was made thanks to Abulafia#3734. Translated back into English, and zoom, pan, rotation, and keyframes features by Chigozie Nri (https://github.com/chigozienri, https://twitter.com/chigozienri)\n", 28 | "If you encounter problems using it, you are welcome to ask me to fix it at https://twitter.com/chigozienri\n", 29 | "\n", 30 | "Slight modifications by : Justin John\n" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": { 36 | "id": "wguhQ-XdIgmf" 37 | }, 38 | "source": [ 39 | "*ToDo: Add more models*" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": { 45 | "id": "cCfJMQjpfwqh" 46 | }, 47 | "source": [ 48 | "# How to use this notebook\n", 49 | "\n", 50 | "This is an example of a Jupyter Notebook, running in Google Colab\n", 51 | "\n", 52 | "It runs Python code in your browser. It's not hard to use, even if you haven't run code before.\n", 53 | "\n", 54 | "First, in the menu bar, click Runtime>Change Runtime Type, and ensure that under \"Hardware Accelerator\" it says \"GPU\". If not, choose \"GPU\" from the drop-down menu, and click Save.\n", 55 | "\n", 56 | "Then, run each of the cells in the notebook, one by one. Make sure to run all of them in order! Click in the cell, and press Shift-Enter on your keyboard. This will run the code in the cell, and then move to the next cell.\n", 57 | "\n", 58 | "Follow the instructions in each cell, and you'll have an AI image in no time!" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": { 64 | "id": "oPmHkif0gLro" 65 | }, 66 | "source": [ 67 | "# Load Google Drive\n", 68 | "\n", 69 | "Long-running colab notebooks might halt, and discard all progress. 
For this reason, it's useful (although optional) to save the images as they are produced in your personal google drive. Run the cell below to load google drive, click the link, sign in, paste the code generated into the prompt, and press enter." 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "metadata": { 75 | "id": "ij62RNZMgO7d", 76 | "cellView": "form" 77 | }, 78 | "source": [ 79 | "#@markdown #**Run me** \n", 80 | "from google.colab import drive\n", 81 | "drive.mount('/content/gdrive')\n", 82 | "\n", 83 | "working_dir = '/content/gdrive/MyDrive/vqgan'" 84 | ], 85 | "execution_count": null, 86 | "outputs": [] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": { 91 | "id": "_Zvdw9xBgSUy" 92 | }, 93 | "source": [ 94 | "If you choose not to use google drive, uncomment the cell below and run it instead." 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "metadata": { 100 | "id": "MP65HghcgV2A", 101 | "cellView": "form" 102 | }, 103 | "source": [ 104 | "#@markdown **Double-click here and uncomment this code if you don't want to use Google Drive**\n", 105 | "# working_dir = '/content'" 106 | ], 107 | "execution_count": null, 108 | "outputs": [] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "metadata": { 113 | "id": "oA9isL44gb5V", 114 | "cellView": "form" 115 | }, 116 | "source": [ 117 | "# @title **Licensed under the MIT License**\n", 118 | "\n", 119 | "# Copyright (c) 2021 Katherine Crowson\n", 120 | "\n", 121 | "# Permission is hereby granted, free of charge, to any person obtaining a copy\n", 122 | "# of this software and associated documentation files (the \"Software\"), to deal\n", 123 | "# in the Software without restriction, including without limitation the rights\n", 124 | "# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n", 125 | "# copies of the Software, and to permit persons to whom the Software is\n", 126 | "# furnished to do so, subject to the following conditions:\n", 127 | "\n", 128 | "# The above copyright notice and this permission notice shall be included in\n", 129 | "# all copies or substantial portions of the Software.\n", 130 | "\n", 131 | "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", 132 | "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", 133 | "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n", 134 | "# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", 135 | "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n", 136 | "# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n", 137 | "# THE SOFTWARE." 
138 | ], 139 | "execution_count": null, 140 | "outputs": [] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "metadata": { 145 | "id": "oQ3nrdjzgerS", 146 | "cellView": "form" 147 | }, 148 | "source": [ 149 | "#@markdown #**Check GPU type**\n", 150 | "#@markdown ### Factory reset runtime if you don't have the desired GPU.\n", 151 | "\n", 152 | "#@markdown ---\n", 153 | "\n", 154 | "\n", 155 | "\n", 156 | "\n", 157 | "#@markdown V100 = Excellent (*Available only for Colab Pro users*)\n", 158 | "\n", 159 | "#@markdown P100 = Very Good (*Available only for Colab Pro users*)\n", 160 | "\n", 161 | "#@markdown T4 = Good (*Available only for Colab Pro users*)\n", 162 | "\n", 163 | "#@markdown K80 = (*Untested*)\n", 164 | "\n", 165 | "#@markdown P4 = (*Not Recommended*) \n", 166 | "\n", 167 | "#@markdown ---\n", 168 | "\n", 169 | "!nvidia-smi -L" 170 | ], 171 | "execution_count": null, 172 | "outputs": [] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "metadata": { 177 | "id": "ngVYOAZEghkK", 178 | "cellView": "form" 179 | }, 180 | "source": [ 181 | "# @title #**Library installation**\n", 182 | "# @markdown This cell will take a while because you have to download multiple libraries\n", 183 | "\n", 184 | "print(\"Downloading CLIP...\")\n", 185 | "!git clone https://github.com/openai/CLIP &> /dev/null\n", 186 | " \n", 187 | "print(\"Downloading Python AI libraries...\")\n", 188 | "!git clone https://github.com/CompVis/taming-transformers &> /dev/null\n", 189 | "!pip install ftfy regex tqdm omegaconf pytorch-lightning &> /dev/null\n", 190 | "!pip install kornia &> /dev/null\n", 191 | "!pip install einops &> /dev/null\n", 192 | " \n", 193 | "print(\"Installing libraries for handling metadata...\")\n", 194 | "!pip install stegano &> /dev/null\n", 195 | "!apt install exempi &> /dev/null\n", 196 | "!pip install python-xmp-toolkit &> /dev/null\n", 197 | "!pip install imgtag &> /dev/null\n", 198 | "!pip install pillow==7.1.2 &> /dev/null\n", 199 | " \n", 200 | "print(\"Installing Python video creation libraries...\")\n", 201 | "!pip install imageio-ffmpeg &> /dev/null\n", 202 | "path = f'{working_dir}/steps'\n", 203 | "!mkdir --parents {path}\n", 204 | "print(\"Installation finished.\")" 205 | ], 206 | "execution_count": null, 207 | "outputs": [] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "metadata": { 212 | "id": "xkgbrCZqgkFw", 213 | "cellView": "form" 214 | }, 215 | "source": [ 216 | "#@title #**Selection of models to download**\n", 217 | "#@markdown By default, the notebook downloads Model 16384 from ImageNet. 
There are others such as ImageNet 1024, COCO-Stuff, WikiArt 1024, WikiArt 16384, FacesHQ or S-FLCKR, which are not downloaded by default, since it would be in vain if you are not going to use them, so if you want to use them, simply select the models to download.\n", 218 | "\n", 219 | "imagenet_1024 = False #@param {type:\"boolean\"}\n", 220 | "imagenet_16384 = True #@param {type:\"boolean\"}\n", 221 | "coco = False #@param {type:\"boolean\"}\n", 222 | "faceshq = False #@param {type:\"boolean\"}\n", 223 | "wikiart_16384 = False #@param {type:\"boolean\"}\n", 224 | "sflckr = False #@param {type:\"boolean\"}\n", 225 | "\n", 226 | "if imagenet_1024:\n", 227 | " !curl -L -o vqgan_imagenet_f16_1024.yaml -C - 'https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1' #ImageNet 1024\n", 228 | " !curl -L -o vqgan_imagenet_f16_1024.ckpt -C - 'https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/files/?p=%2Fckpts%2Flast.ckpt&dl=1' #ImageNet 1024\n", 229 | "if imagenet_16384:\n", 230 | " !curl -L -o vqgan_imagenet_f16_16384.yaml -C - 'https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1' #ImageNet 16384\n", 231 | " !curl -L -o vqgan_imagenet_f16_16384.ckpt -C - 'https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/files/?p=%2Fckpts%2Flast.ckpt&dl=1' #ImageNet 16384\n", 232 | "if coco:\n", 233 | " !curl -L -o coco.yaml -C - 'https://dl.nmkd.de/ai/clip/coco/coco.yaml' #COCO\n", 234 | " !curl -L -o coco.ckpt -C - 'https://dl.nmkd.de/ai/clip/coco/coco.ckpt' #COCO\n", 235 | "if faceshq:\n", 236 | " !curl -L -o faceshq.yaml -C - 'https://drive.google.com/uc?export=download&id=1fHwGx_hnBtC8nsq7hesJvs-Klv-P0gzT' #FacesHQ\n", 237 | " !curl -L -o faceshq.ckpt -C - 'https://app.koofr.net/content/links/a04deec9-0c59-4673-8b37-3d696fe63a5d/files/get/last.ckpt?path=%2F2020-11-13T21-41-45_faceshq_transformer%2Fcheckpoints%2Flast.ckpt' #FacesHQ\n", 238 | "if wikiart_16384:\n", 239 | " !curl -L -o wikiart_16384.yaml -C - 'http://eaidata.bmk.sh/data/Wikiart_16384/wikiart_f16_16384_8145600.yaml' #WikiArt 16384\n", 240 | " !curl -L -o wikiart_16384.ckpt -C - 'http://eaidata.bmk.sh/data/Wikiart_16384/wikiart_f16_16384_8145600.ckpt' #WikiArt 16384\n", 241 | "if sflckr:\n", 242 | " !curl -L -o sflckr.yaml -C - 'https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fconfigs%2F2020-11-09T13-31-51-project.yaml&dl=1' #S-FLCKR\n", 243 | " !curl -L -o sflckr.ckpt -C - 'https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fcheckpoints%2Flast.ckpt&dl=1' #S-FLCKR" 244 | ], 245 | "execution_count": null, 246 | "outputs": [] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "metadata": { 251 | "id": "rXx8O5Fvgm6v", 252 | "cellView": "form" 253 | }, 254 | "source": [ 255 | "# @title #**Loading of libraries and definitions**\n", 256 | " \n", 257 | "import argparse\n", 258 | "import math\n", 259 | "from pathlib import Path\n", 260 | "import sys\n", 261 | "import os\n", 262 | "import cv2\n", 263 | "import pandas as pd\n", 264 | "import numpy as np\n", 265 | "import subprocess\n", 266 | " \n", 267 | "sys.path.append('./taming-transformers')\n", 268 | "\n", 269 | "# Some models include transformers, others need explicit pip install\n", 270 | "try:\n", 271 | " import transformers\n", 272 | "except Exception:\n", 273 | " !pip install transformers\n", 274 | " import transformers\n", 275 | "\n", 276 | "from IPython import display\n", 277 | "from base64 import b64encode\n", 278 | "from omegaconf import OmegaConf\n", 279 | "from PIL 
import Image\n", 280 | "from taming.models import cond_transformer, vqgan\n", 281 | "import torch\n", 282 | "from torch import nn, optim\n", 283 | "from torch.nn import functional as F\n", 284 | "from torchvision import transforms\n", 285 | "from torchvision.transforms import functional as TF\n", 286 | "from tqdm.notebook import tqdm\n", 287 | " \n", 288 | "from CLIP import clip\n", 289 | "import kornia.augmentation as K\n", 290 | "import numpy as np\n", 291 | "import imageio\n", 292 | "from PIL import ImageFile, Image\n", 293 | "from imgtag import ImgTag # metadata \n", 294 | "from libxmp import * # metadata\n", 295 | "import libxmp # metadata\n", 296 | "from stegano import lsb\n", 297 | "import json\n", 298 | "ImageFile.LOAD_TRUNCATED_IMAGES = True\n", 299 | " \n", 300 | "def sinc(x):\n", 301 | " return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))\n", 302 | " \n", 303 | " \n", 304 | "def lanczos(x, a):\n", 305 | " cond = torch.logical_and(-a < x, x < a)\n", 306 | " out = torch.where(cond, sinc(x) * sinc(x/a), x.new_zeros([]))\n", 307 | " return out / out.sum()\n", 308 | " \n", 309 | " \n", 310 | "def ramp(ratio, width):\n", 311 | " n = math.ceil(width / ratio + 1)\n", 312 | " out = torch.empty([n])\n", 313 | " cur = 0\n", 314 | " for i in range(out.shape[0]):\n", 315 | " out[i] = cur\n", 316 | " cur += ratio\n", 317 | " return torch.cat([-out[1:].flip([0]), out])[1:-1]\n", 318 | " \n", 319 | " \n", 320 | "def resample(input, size, align_corners=True):\n", 321 | " n, c, h, w = input.shape\n", 322 | " dh, dw = size\n", 323 | " \n", 324 | " input = input.view([n * c, 1, h, w])\n", 325 | " \n", 326 | " if dh < h:\n", 327 | " kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)\n", 328 | " pad_h = (kernel_h.shape[0] - 1) // 2\n", 329 | " input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')\n", 330 | " input = F.conv2d(input, kernel_h[None, None, :, None])\n", 331 | " \n", 332 | " if dw < w:\n", 333 | " kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)\n", 334 | " pad_w = (kernel_w.shape[0] - 1) // 2\n", 335 | " input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')\n", 336 | " input = F.conv2d(input, kernel_w[None, None, None, :])\n", 337 | " \n", 338 | " input = input.view([n, c, h, w])\n", 339 | " return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)\n", 340 | " \n", 341 | " \n", 342 | "class ReplaceGrad(torch.autograd.Function):\n", 343 | " @staticmethod\n", 344 | " def forward(ctx, x_forward, x_backward):\n", 345 | " ctx.shape = x_backward.shape\n", 346 | " return x_forward\n", 347 | " \n", 348 | " @staticmethod\n", 349 | " def backward(ctx, grad_in):\n", 350 | " return None, grad_in.sum_to_size(ctx.shape)\n", 351 | " \n", 352 | " \n", 353 | "replace_grad = ReplaceGrad.apply\n", 354 | " \n", 355 | " \n", 356 | "class ClampWithGrad(torch.autograd.Function):\n", 357 | " @staticmethod\n", 358 | " def forward(ctx, input, min, max):\n", 359 | " ctx.min = min\n", 360 | " ctx.max = max\n", 361 | " ctx.save_for_backward(input)\n", 362 | " return input.clamp(min, max)\n", 363 | " \n", 364 | " @staticmethod\n", 365 | " def backward(ctx, grad_in):\n", 366 | " input, = ctx.saved_tensors\n", 367 | " return grad_in * (grad_in * (input - input.clamp(ctx.min, ctx.max)) >= 0), None, None\n", 368 | " \n", 369 | " \n", 370 | "clamp_with_grad = ClampWithGrad.apply\n", 371 | " \n", 372 | " \n", 373 | "def vector_quantize(x, codebook):\n", 374 | " d = x.pow(2).sum(dim=-1, keepdim=True) + codebook.pow(2).sum(dim=1) - 2 * 
x @ codebook.T\n", 375 | " indices = d.argmin(-1)\n", 376 | " x_q = F.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook\n", 377 | " return replace_grad(x_q, x)\n", 378 | " \n", 379 | " \n", 380 | "class Prompt(nn.Module):\n", 381 | " def __init__(self, embed, weight=1., stop=float('-inf')):\n", 382 | " super().__init__()\n", 383 | " self.register_buffer('embed', embed)\n", 384 | " self.register_buffer('weight', torch.as_tensor(weight))\n", 385 | " self.register_buffer('stop', torch.as_tensor(stop))\n", 386 | " \n", 387 | " def forward(self, input):\n", 388 | " input_normed = F.normalize(input.unsqueeze(1), dim=2)\n", 389 | " embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)\n", 390 | " dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)\n", 391 | " dists = dists * self.weight.sign()\n", 392 | " return self.weight.abs() * replace_grad(dists, torch.maximum(dists, self.stop)).mean()\n", 393 | " \n", 394 | " \n", 395 | "def parse_prompt(prompt):\n", 396 | " vals = prompt.rsplit(':', 2)\n", 397 | " vals = vals + ['', '1', '-inf'][len(vals):]\n", 398 | " return vals[0], float(vals[1]), float(vals[2])\n", 399 | " \n", 400 | " \n", 401 | "class MakeCutouts(nn.Module):\n", 402 | " def __init__(self, cut_size, cutn, cut_pow=1.):\n", 403 | " super().__init__()\n", 404 | " self.cut_size = cut_size\n", 405 | " self.cutn = cutn\n", 406 | " self.cut_pow = cut_pow\n", 407 | " self.augs = nn.Sequential(\n", 408 | " K.RandomHorizontalFlip(p=0.5),\n", 409 | " # K.RandomSolarize(0.01, 0.01, p=0.7),\n", 410 | " K.RandomSharpness(0.3,p=0.4),\n", 411 | " K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode='border'),\n", 412 | " K.RandomPerspective(0.2,p=0.4),\n", 413 | " K.ColorJitter(hue=0.01, saturation=0.01, p=0.7))\n", 414 | " self.noise_fac = 0.1\n", 415 | " \n", 416 | " \n", 417 | " def forward(self, input):\n", 418 | " sideY, sideX = input.shape[2:4]\n", 419 | " max_size = min(sideX, sideY)\n", 420 | " min_size = min(sideX, sideY, self.cut_size)\n", 421 | " cutouts = []\n", 422 | " for _ in range(self.cutn):\n", 423 | " size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)\n", 424 | " offsetx = torch.randint(0, sideX - size + 1, ())\n", 425 | " offsety = torch.randint(0, sideY - size + 1, ())\n", 426 | " cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]\n", 427 | " cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))\n", 428 | " batch = self.augs(torch.cat(cutouts, dim=0))\n", 429 | " if self.noise_fac:\n", 430 | " facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)\n", 431 | " batch = batch + facs * torch.randn_like(batch)\n", 432 | " return batch\n", 433 | " \n", 434 | " \n", 435 | "def load_vqgan_model(config_path, checkpoint_path):\n", 436 | " config = OmegaConf.load(config_path)\n", 437 | " if config.model.target == 'taming.models.vqgan.VQModel':\n", 438 | " model = vqgan.VQModel(**config.model.params)\n", 439 | " model.eval().requires_grad_(False)\n", 440 | " model.init_from_ckpt(checkpoint_path)\n", 441 | " elif config.model.target == 'taming.models.cond_transformer.Net2NetTransformer':\n", 442 | " parent_model = cond_transformer.Net2NetTransformer(**config.model.params)\n", 443 | " parent_model.eval().requires_grad_(False)\n", 444 | " parent_model.init_from_ckpt(checkpoint_path)\n", 445 | " model = parent_model.first_stage_model\n", 446 | " else:\n", 447 | " raise ValueError(f'unknown model type: {config.model.target}')\n", 448 | " del model.loss\n", 449 | " return 
model\n", 450 | " \n", 451 | " \n", 452 | "def resize_image(image, out_size):\n", 453 | " ratio = image.size[0] / image.size[1]\n", 454 | " area = min(image.size[0] * image.size[1], out_size[0] * out_size[1])\n", 455 | " size = round((area * ratio)**0.5), round((area / ratio)**0.5)\n", 456 | " return image.resize(size, Image.LANCZOS)" 457 | ], 458 | "execution_count": null, 459 | "outputs": [] 460 | }, 461 | { 462 | "cell_type": "markdown", 463 | "metadata": { 464 | "id": "4INwpxFKgvbo" 465 | }, 466 | "source": [ 467 | "## **Instructions for setting parameters:**\n", 468 | "\n", 469 | "| Parameter | Usage |\n", 470 | "|---|---|\n", 471 | "| `key_frames` | Whether to use key frames to change the parameters over the course of the run |\n", 472 | "| `text_prompts` | Text prompts, separated by \"\\|\" |\n", 473 | "| `width` | Width of the output, in pixels. This will be rounded down to a multiple of 16 |\n", 474 | "| `height` | Height of the output, in pixels. This will be rounded down to a multiple of 16 |\n", 475 | "| `model` | Choice of model, must be downloaded above |\n", 476 | "| `interval` | How often to display the frame in the notebook (doesn't affect the actual output) |\n", 477 | "| `initial_image` | Image to start with (relative path to file) |\n", 478 | "| `target_images` | Image prompts to target, separated by \"|\" (relative path to files) |\n", 479 | "| `seed` | Random seed, if set to a positive integer the run will be repeatable (get the same output for the same input each time, if set to -1 a random seed will be used. |\n", 480 | "| `max_frames` | Number of frames for the animation |\n", 481 | "| `angle` | Angle in degrees to rotate clockwise between each frame |\n", 482 | "| `zoom` | Factor to zoom in each frame, 1 is no zoom, less than 1 is zoom out, more than 1 is zoom in (negative is uninteresting, just adds an extra 180 rotation beyond that in angle) |\n", 483 | "| `translation_x` | Number of pixels to shift right each frame |\n", 484 | "| `translation_y` | Number of pixels to shift down each frame |\n", 485 | "| `iterations_per_frame` | Number of times to run the VQGAN+CLIP method each frame |\n", 486 | "| `save_all_iterations` | Debugging, set False in normal operation |\n", 487 | "\n", 488 | "---------\n", 489 | "\n", 490 | "Transformations (zoom, rotation, and translation)\n", 491 | "\n", 492 | "On each frame, the network restarts, is fed a version of the output zoomed in by `zoom` as the initial image, rotated clockwise by `angle` degrees, translated horizontally by `translation_x` pixels, and translated vertically by `translation_y` pixels. Then it runs `iterations_per_frame` iterations of the VQGAN+CLIP method. 0 `iterations_per_frame` is supported, to help test out the transformations without changing the image.\n", 493 | "\n", 494 | "For `iterations_per_frame = 1` (recommended for more abstract effects), the resulting images will not have much to do with the prompts, but at least one prompt is still required.\n", 495 | "\n", 496 | "In normal use, only the last iteration of each frame will be saved, but for trouble-shooting you can set `save_all_iterations` to True, and every iteration of each frame will be saved.\n", 497 | "\n", 498 | "----------------\n", 499 | "\n", 500 | "Mainly what you will have to modify will be `text_prompts`: there you can place the prompt(s) you want to generate (separated with |). It is a list because you can put more than one text, and so the AI tries to 'mix' the images, giving the same priority to both texts. 
You can also assign weights, to bias the priority towards one prompt or another, or negative weights, to remove an element (for example, a colour).\n", 501 | "\n", 502 | "Example of weights with decimals:\n", 503 | "\n", 504 | "Text: rubber:0.5 | rainbow:0.5\n", 505 | "\n", 506 | "To use an initial image with the model, you just have to upload a file to the Colab environment (in the section on the left), and then set `initial_image` to the exact name of the file. Example: sample.png\n", 507 | "\n", 508 | "You can also change the model by changing the line that says `model`. Currently 1024, 16384, WikiArt, S-FLCKR and COCO-Stuff are available. To activate them, you must have downloaded them first; then you can simply select one.\n", 509 | "\n", 510 | "You can also use `target_images`, which lets you give the AI one or more images to take as a \"target\", fulfilling the same function as a text prompt. To put more than one you have to use | as a separator.\n", 511 | "\n", 512 | "------------\n", 513 | "\n", 514 | "Key Frames\n", 515 | "\n", 516 | "If `key_frames` is set to True, you are able to change the parameters over the course of the run.\n", 517 | "To do this, put the parameters in the following format:\n", 518 | "10:(0.5), 20: (1.0), 35: (-1.0)\n", 519 | "\n", 520 | "This means at frame 10, the value should be 0.5, at frame 20 the value should be 1.0, and at frame 35 the value should be -1.0. The value at each other frame will be linearly interpolated (that is, before frame 10, the value will be 0.5, between frame 10 and 20 the value will increase frame-by-frame from 0.5 to 1.0, between frame 20 and 35 the value will decrease frame-by-frame from 1.0 to -1.0, and after frame 35 the value will be -1.0).\n", 521 | "\n", 522 | "This also works for text_prompts, e.g. 10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)\n", 523 | "will start with an Apple value of 1; once it hits frame 10 it will start decreasing in Apple and increasing in Orange until it hits frame 20. Note that Peach will have a value of 1 the whole time.\n", 524 | "\n", 525 | "If `key_frames` is set to True, all of the parameters which can be key-framed must be entered in this format."
526 | ] 527 | }, 528 | { 529 | "cell_type": "code", 530 | "metadata": { 531 | "id": "9SqvStcBgr_0", 532 | "cellView": "form" 533 | }, 534 | "source": [ 535 | "#@title **Parameters**\n", 536 | "\n", 537 | "key_frames = True #@param {type:\"boolean\"}\n", 538 | "text_prompts = \"10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)\" #@param {type:\"string\"}\n", 539 | "width = 400#@param {type:\"number\"}\n", 540 | "height = 400#@param {type:\"number\"}\n", 541 | "model = \"vqgan_imagenet_f16_16384\" #@param [\"vqgan_imagenet_f16_16384\", \"vqgan_imagenet_f16_1024\", \"wikiart_16384\", \"coco\", \"faceshq\", \"sflckr\"]\n", 542 | "interval = 1#@param {type:\"number\"}\n", 543 | "initial_image = \"\"#@param {type:\"string\"}\n", 544 | "target_images = \"\"#@param {type:\"string\"}\n", 545 | "seed = 1#@param {type:\"number\"}\n", 546 | "max_frames = 50#@param {type:\"number\"}\n", 547 | "angle = \"10: (0), 30: (10), 50: (0)\"#@param {type:\"string\"}\n", 548 | "zoom = \"10: (1), 30: (1.2), 50: (1)\"#@param {type:\"string\"}\n", 549 | "translation_x = \"0: (0)\"#@param {type:\"string\"}\n", 550 | "translation_y = \"0: (0)\"#@param {type:\"string\"}\n", 551 | "iterations_per_frame = \"0: (10)\"#@param {type:\"string\"}\n", 552 | "save_all_iterations = False#@param {type:\"boolean\"}\n", 553 | "\n", 554 | "\n", 555 | "if initial_image != \"\":\n", 556 | " print(\n", 557 | " \"WARNING: You have specified an initial image. Note that the image resolution \"\n", 558 | " \"will be inherited from this image, not whatever width and height you specified. \"\n", 559 | " \"If the initial image resolution is too high, this can result in out of memory errors.\"\n", 560 | " )\n", 561 | "elif width * height > 160000:\n", 562 | " print(\n", 563 | " \"WARNING: The width and height you have specified may be too high, in which case \"\n", 564 | " \"you will encounter out of memory errors either at the image generation stage or the \"\n", 565 | " \"video synthesis stage. 
If so, try reducing the resolution\"\n", 566 | " )\n", 567 | "model_names={\n", 568 | " \"vqgan_imagenet_f16_16384\": 'ImageNet 16384',\n", 569 | " \"vqgan_imagenet_f16_1024\":\"ImageNet 1024\", \n", 570 | " \"wikiart_1024\":\"WikiArt 1024\",\n", 571 | " \"wikiart_16384\":\"WikiArt 16384\",\n", 572 | " \"coco\":\"COCO-Stuff\",\n", 573 | " \"faceshq\":\"FacesHQ\",\n", 574 | " \"sflckr\":\"S-FLCKR\"\n", 575 | "}\n", 576 | "model_name = model_names[model]\n", 577 | "\n", 578 | "if seed == -1:\n", 579 | " seed = None\n", 580 | "\n", 581 | "def parse_key_frames(string, prompt_parser=None):\n", 582 | " import re\n", 583 | " pattern = r'((?P[0-9]+):[\\s]*[\\(](?P[\\S\\s]*?)[\\)])'\n", 584 | " frames = dict()\n", 585 | " for match_object in re.finditer(pattern, string):\n", 586 | " frame = int(match_object.groupdict()['frame'])\n", 587 | " param = match_object.groupdict()['param']\n", 588 | " if prompt_parser:\n", 589 | " frames[frame] = prompt_parser(param)\n", 590 | " else:\n", 591 | " frames[frame] = param\n", 592 | "\n", 593 | " if frames == {} and len(string) != 0:\n", 594 | " raise RuntimeError('Key Frame string not correctly formatted')\n", 595 | " return frames\n", 596 | "\n", 597 | "def get_inbetweens(key_frames, integer=False):\n", 598 | " key_frame_series = pd.Series([np.nan for a in range(max_frames)])\n", 599 | " for i, value in key_frames.items():\n", 600 | " key_frame_series[i] = value\n", 601 | " key_frame_series = key_frame_series.astype(float)\n", 602 | " key_frame_series = key_frame_series.interpolate(limit_direction='both')\n", 603 | " if integer:\n", 604 | " return key_frame_series.astype(int)\n", 605 | " return key_frame_series\n", 606 | "\n", 607 | "def split_key_frame_text_prompts(frames):\n", 608 | " prompt_dict = dict()\n", 609 | " for i, parameters in frames.items():\n", 610 | " prompts = parameters.split('|')\n", 611 | " for prompt in prompts:\n", 612 | " string, value = prompt.split(':')\n", 613 | " string = string.strip()\n", 614 | " value = float(value.strip())\n", 615 | " if string in prompt_dict:\n", 616 | " prompt_dict[string][i] = value\n", 617 | " else:\n", 618 | " prompt_dict[string] = {i: value}\n", 619 | " prompt_series_dict = dict()\n", 620 | " for prompt, values in prompt_dict.items():\n", 621 | " value_string = (\n", 622 | " ', '.join([f'{value}: ({values[value]})' for value in values])\n", 623 | " )\n", 624 | " prompt_series = get_inbetweens(parse_key_frames(value_string))\n", 625 | " prompt_series_dict[prompt] = prompt_series\n", 626 | " prompt_list = []\n", 627 | " for i in range(max_frames):\n", 628 | " prompt_list.append(\n", 629 | " ' | '.join(\n", 630 | " [f'{prompt}: {prompt_series_dict[prompt][i]}'\n", 631 | " for prompt in prompt_series_dict]\n", 632 | " )\n", 633 | " )\n", 634 | " return prompt_list\n", 635 | "\n", 636 | "if key_frames:\n", 637 | " try:\n", 638 | " text_prompts_series = split_key_frame_text_prompts(\n", 639 | " parse_key_frames(text_prompts)\n", 640 | " )\n", 641 | " except RuntimeError as e:\n", 642 | " print(\n", 643 | " \"WARNING: You have selected to use key frames, but you have not \"\n", 644 | " \"formatted `text_prompts` correctly for key frames.\\n\"\n", 645 | " \"Attempting to interpret `text_prompts` as \"\n", 646 | " f'\"0: ({text_prompts}:1)\"\\n'\n", 647 | " \"Please read the instructions to find out how to use key frames \"\n", 648 | " \"correctly.\\n\"\n", 649 | " )\n", 650 | " text_prompts = f\"0: ({text_prompts}:1)\"\n", 651 | " text_prompts_series = split_key_frame_text_prompts(\n", 652 | " 
parse_key_frames(text_prompts)\n", 653 | " )\n", 654 | "\n", 655 | " try:\n", 656 | " target_images_series = split_key_frame_text_prompts(\n", 657 | " parse_key_frames(target_images)\n", 658 | " )\n", 659 | " except RuntimeError as e:\n", 660 | " print(\n", 661 | " \"WARNING: You have selected to use key frames, but you have not \"\n", 662 | " \"formatted `target_images` correctly for key frames.\\n\"\n", 663 | " \"Attempting to interpret `target_images` as \"\n", 664 | " f'\"0: ({target_images}:1)\"\\n'\n", 665 | " \"Please read the instructions to find out how to use key frames \"\n", 666 | " \"correctly.\\n\"\n", 667 | " )\n", 668 | " target_images = f\"0: ({target_images}:1)\"\n", 669 | " target_images_series = split_key_frame_text_prompts(\n", 670 | " parse_key_frames(target_images)\n", 671 | " )\n", 672 | "\n", 673 | " try:\n", 674 | " angle_series = get_inbetweens(parse_key_frames(angle))\n", 675 | " except RuntimeError as e:\n", 676 | " print(\n", 677 | " \"WARNING: You have selected to use key frames, but you have not \"\n", 678 | " \"formatted `angle` correctly for key frames.\\n\"\n", 679 | " \"Attempting to interpret `angle` as \"\n", 680 | " f'\"0: ({angle})\"\\n'\n", 681 | " \"Please read the instructions to find out how to use key frames \"\n", 682 | " \"correctly.\\n\"\n", 683 | " )\n", 684 | " angle = f\"0: ({angle})\"\n", 685 | " angle_series = get_inbetweens(parse_key_frames(angle))\n", 686 | "\n", 687 | " try:\n", 688 | " zoom_series = get_inbetweens(parse_key_frames(zoom))\n", 689 | " except RuntimeError as e:\n", 690 | " print(\n", 691 | " \"WARNING: You have selected to use key frames, but you have not \"\n", 692 | " \"formatted `zoom` correctly for key frames.\\n\"\n", 693 | " \"Attempting to interpret `zoom` as \"\n", 694 | " f'\"0: ({zoom})\"\\n'\n", 695 | " \"Please read the instructions to find out how to use key frames \"\n", 696 | " \"correctly.\\n\"\n", 697 | " )\n", 698 | " zoom = f\"0: ({zoom})\"\n", 699 | " zoom_series = get_inbetweens(parse_key_frames(zoom))\n", 700 | "\n", 701 | " try:\n", 702 | " translation_x_series = get_inbetweens(parse_key_frames(translation_x))\n", 703 | " except RuntimeError as e:\n", 704 | " print(\n", 705 | " \"WARNING: You have selected to use key frames, but you have not \"\n", 706 | " \"formatted `translation_x` correctly for key frames.\\n\"\n", 707 | " \"Attempting to interpret `translation_x` as \"\n", 708 | " f'\"0: ({translation_x})\"\\n'\n", 709 | " \"Please read the instructions to find out how to use key frames \"\n", 710 | " \"correctly.\\n\"\n", 711 | " )\n", 712 | " translation_x = f\"0: ({translation_x})\"\n", 713 | " translation_x_series = get_inbetweens(parse_key_frames(translation_x))\n", 714 | "\n", 715 | " try:\n", 716 | " translation_y_series = get_inbetweens(parse_key_frames(translation_y))\n", 717 | " except RuntimeError as e:\n", 718 | " print(\n", 719 | " \"WARNING: You have selected to use key frames, but you have not \"\n", 720 | " \"formatted `translation_y` correctly for key frames.\\n\"\n", 721 | " \"Attempting to interpret `translation_y` as \"\n", 722 | " f'\"0: ({translation_y})\"\\n'\n", 723 | " \"Please read the instructions to find out how to use key frames \"\n", 724 | " \"correctly.\\n\"\n", 725 | " )\n", 726 | " translation_y = f\"0: ({translation_y})\"\n", 727 | " translation_y_series = get_inbetweens(parse_key_frames(translation_y))\n", 728 | "\n", 729 | " try:\n", 730 | " iterations_per_frame_series = get_inbetweens(\n", 731 | " parse_key_frames(iterations_per_frame), integer=True\n", 
732 | " )\n", 733 | " except RuntimeError as e:\n", 734 | " print(\n", 735 | " \"WARNING: You have selected to use key frames, but you have not \"\n", 736 | " \"formatted `iterations_per_frame` correctly for key frames.\\n\"\n", 737 | " \"Attempting to interpret `iterations_per_frame` as \"\n", 738 | " f'\"0: ({iterations_per_frame})\"\\n'\n", 739 | " \"Please read the instructions to find out how to use key frames \"\n", 740 | " \"correctly.\\n\"\n", 741 | " )\n", 742 | " iterations_per_frame = f\"0: ({iterations_per_frame})\"\n", 743 | " \n", 744 | " iterations_per_frame_series = get_inbetweens(\n", 745 | " parse_key_frames(iterations_per_frame), integer=True\n", 746 | " )\n", 747 | "else:\n", 748 | " text_prompts = [phrase.strip() for phrase in text_prompts.split(\"|\")]\n", 749 | " if text_prompts == ['']:\n", 750 | " text_prompts = []\n", 751 | " if target_images == \"None\" or not target_images:\n", 752 | " target_images = []\n", 753 | " else:\n", 754 | " target_images = target_images.split(\"|\")\n", 755 | " target_images = [image.strip() for image in target_images]\n", 756 | "\n", 757 | " angle = float(angle)\n", 758 | " zoom = float(zoom)\n", 759 | " translation_x = float(translation_x)\n", 760 | " translation_y = float(translation_y)\n", 761 | " iterations_per_frame = int(iterations_per_frame)\n", 762 | "\n", 763 | "args = argparse.Namespace(\n", 764 | " prompts=text_prompts,\n", 765 | " image_prompts=target_images,\n", 766 | " noise_prompt_seeds=[],\n", 767 | " noise_prompt_weights=[],\n", 768 | " size=[width, height],\n", 769 | " init_weight=0.,\n", 770 | " clip_model='ViT-B/32',\n", 771 | " vqgan_config=f'{model}.yaml',\n", 772 | " vqgan_checkpoint=f'{model}.ckpt',\n", 773 | " step_size=0.1,\n", 774 | " cutn=64,\n", 775 | " cut_pow=1.,\n", 776 | " display_freq=interval,\n", 777 | " seed=seed,\n", 778 | ")" 779 | ], 780 | "execution_count": null, 781 | "outputs": [] 782 | }, 783 | { 784 | "cell_type": "code", 785 | "metadata": { 786 | "id": "R1cLyxLcg8RA", 787 | "cellView": "form" 788 | }, 789 | "source": [ 790 | "#@markdown **This cell deletes any frames already in the steps directory. 
Make sure you have saved any frames you want to keep from previous runs**\n", 791 | "\n", 792 | "path = f'{working_dir}/steps'\n", 793 | "!rm -r {path}\n", 794 | "!mkdir --parents {path}" 795 | ], 796 | "execution_count": null, 797 | "outputs": [] 798 | }, 799 | { 800 | "cell_type": "code", 801 | "metadata": { 802 | "id": "M2cDGzStHUna", 803 | "cellView": "form" 804 | }, 805 | "source": [ 806 | "#@markdown **Delete/clear the generated video**\n", 807 | "\n", 808 | "if key_frames:\n", 809 | " # key frame filename would be too long\n", 810 | " filename = \"video.mp4\"\n", 811 | "else:\n", 812 | " filename = f\"{'_'.join(text_prompts).replace(' ', '')}.mp4\"\n", 813 | "filepath = f'{working_dir}/{filename}'" 814 | ], 815 | "execution_count": null, 816 | "outputs": [] 817 | }, 818 | { 819 | "cell_type": "code", 820 | "metadata": { 821 | "id": "owUFbInCHW9W", 822 | "cellView": "form" 823 | }, 824 | "source": [ 825 | "#@title #**Fire up the A.I**\n", 826 | "\n", 827 | "# Delete memory from previous runs\n", 828 | "!nvidia-smi -caa\n", 829 | "for var in ['device', 'model', 'perceptor', 'z']:\n", 830 | " try:\n", 831 | " del globals()[var]\n", 832 | " except:\n", 833 | " pass\n", 834 | "\n", 835 | "try:\n", 836 | " import gc\n", 837 | " gc.collect()\n", 838 | "except:\n", 839 | " pass\n", 840 | "\n", 841 | "try:\n", 842 | " torch.cuda.empty_cache()\n", 843 | "except:\n", 844 | " pass\n", 845 | "\n", 846 | "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", 847 | "print('Using device:', device)\n", 848 | "if not key_frames:\n", 849 | " if text_prompts:\n", 850 | " print('Using text prompts:', text_prompts)\n", 851 | " if target_images:\n", 852 | " print('Using image prompts:', target_images)\n", 853 | "if args.seed is None:\n", 854 | " seed = torch.seed()\n", 855 | "else:\n", 856 | " seed = args.seed\n", 857 | "torch.manual_seed(seed)\n", 858 | "print('Using seed:', seed)\n", 859 | " \n", 860 | "model = load_vqgan_model(args.vqgan_config, args.vqgan_checkpoint).to(device)\n", 861 | "perceptor = clip.load(args.clip_model, jit=False)[0].eval().requires_grad_(False).to(device)\n", 862 | " \n", 863 | "cut_size = perceptor.visual.input_resolution\n", 864 | "e_dim = model.quantize.e_dim\n", 865 | "f = 2**(model.decoder.num_resolutions - 1)\n", 866 | "make_cutouts = MakeCutouts(cut_size, args.cutn, cut_pow=args.cut_pow)\n", 867 | "n_toks = model.quantize.n_e\n", 868 | "toksX, toksY = args.size[0] // f, args.size[1] // f\n", 869 | "sideX, sideY = toksX * f, toksY * f\n", 870 | "z_min = model.quantize.embedding.weight.min(dim=0).values[None, :, None, None]\n", 871 | "z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None]\n", 872 | "stop_on_next_loop = False # Make sure GPU memory doesn't get corrupted from cancelling the run mid-way through, allow a full frame to complete\n", 873 | "\n", 874 | "def read_image_workaround(path):\n", 875 | " \"\"\"OpenCV reads images as BGR, Pillow saves them as RGB. 
Work around\n", 876 | " this incompatibility to avoid colour inversions.\"\"\"\n", 877 | " im_tmp = cv2.imread(path)\n", 878 | " return cv2.cvtColor(im_tmp, cv2.COLOR_BGR2RGB)\n", 879 | "\n", 880 | "for i in range(max_frames):\n", 881 | " if stop_on_next_loop:\n", 882 | " break\n", 883 | " if key_frames:\n", 884 | " text_prompts = text_prompts_series[i]\n", 885 | " text_prompts = [phrase.strip() for phrase in text_prompts.split(\"|\")]\n", 886 | " if text_prompts == ['']:\n", 887 | " text_prompts = []\n", 888 | " args.prompts = text_prompts\n", 889 | "\n", 890 | " target_images = target_images_series[i]\n", 891 | "\n", 892 | " if target_images == \"None\" or not target_images:\n", 893 | " target_images = []\n", 894 | " else:\n", 895 | " target_images = target_images.split(\"|\")\n", 896 | " target_images = [image.strip() for image in target_images]\n", 897 | " args.image_prompts = target_images\n", 898 | "\n", 899 | " angle = angle_series[i]\n", 900 | " zoom = zoom_series[i]\n", 901 | " translation_x = translation_x_series[i]\n", 902 | " translation_y = translation_y_series[i]\n", 903 | " iterations_per_frame = iterations_per_frame_series[i]\n", 904 | " print(\n", 905 | " f'text_prompts: {text_prompts}'\n", 906 | " f'angle: {angle}',\n", 907 | " f'zoom: {zoom}',\n", 908 | " f'translation_x: {translation_x}',\n", 909 | " f'translation_y: {translation_y}',\n", 910 | " f'iterations_per_frame: {iterations_per_frame}'\n", 911 | " )\n", 912 | " try:\n", 913 | " if i == 0 and initial_image != \"\":\n", 914 | " img_0 = read_image_workaround(initial_image)\n", 915 | " z, *_ = model.encode(TF.to_tensor(img_0).to(device).unsqueeze(0) * 2 - 1)\n", 916 | " elif i == 0 and not os.path.isfile(f'{working_dir}/steps/{i:04d}.png'):\n", 917 | " one_hot = F.one_hot(\n", 918 | " torch.randint(n_toks, [toksY * toksX], device=device), n_toks\n", 919 | " ).float()\n", 920 | " z = one_hot @ model.quantize.embedding.weight\n", 921 | " z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)\n", 922 | " else:\n", 923 | " if save_all_iterations:\n", 924 | " img_0 = read_image_workaround(\n", 925 | " f'{working_dir}/steps/{i:04d}_{iterations_per_frame}.png')\n", 926 | " else:\n", 927 | " img_0 = read_image_workaround(f'{working_dir}/steps/{i:04d}.png')\n", 928 | "\n", 929 | " center = (1*img_0.shape[1]//2, 1*img_0.shape[0]//2)\n", 930 | " trans_mat = np.float32(\n", 931 | " [[1, 0, translation_x],\n", 932 | " [0, 1, translation_y]]\n", 933 | " )\n", 934 | " rot_mat = cv2.getRotationMatrix2D( center, angle, zoom )\n", 935 | "\n", 936 | " trans_mat = np.vstack([trans_mat, [0,0,1]])\n", 937 | " rot_mat = np.vstack([rot_mat, [0,0,1]])\n", 938 | " transformation_matrix = np.matmul(rot_mat, trans_mat)\n", 939 | "\n", 940 | " img_0 = cv2.warpPerspective(\n", 941 | " img_0,\n", 942 | " transformation_matrix,\n", 943 | " (img_0.shape[1], img_0.shape[0]),\n", 944 | " borderMode=cv2.BORDER_WRAP\n", 945 | " )\n", 946 | " z, *_ = model.encode(TF.to_tensor(img_0).to(device).unsqueeze(0) * 2 - 1)\n", 947 | " i += 1\n", 948 | "\n", 949 | " z_orig = z.clone()\n", 950 | " z.requires_grad_(True)\n", 951 | " opt = optim.Adam([z], lr=args.step_size)\n", 952 | "\n", 953 | " normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],\n", 954 | " std=[0.26862954, 0.26130258, 0.27577711])\n", 955 | "\n", 956 | " pMs = []\n", 957 | "\n", 958 | " for prompt in args.prompts:\n", 959 | " txt, weight, stop = parse_prompt(prompt)\n", 960 | " embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()\n", 961 | " 
pMs.append(Prompt(embed, weight, stop).to(device))\n", 962 | "\n", 963 | " for prompt in args.image_prompts:\n", 964 | " path, weight, stop = parse_prompt(prompt)\n", 965 | " img = resize_image(Image.open(path).convert('RGB'), (sideX, sideY))\n", 966 | " batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))\n", 967 | " embed = perceptor.encode_image(normalize(batch)).float()\n", 968 | " pMs.append(Prompt(embed, weight, stop).to(device))\n", 969 | "\n", 970 | " for seed, weight in zip(args.noise_prompt_seeds, args.noise_prompt_weights):\n", 971 | " gen = torch.Generator().manual_seed(seed)\n", 972 | " embed = torch.empty([1, perceptor.visual.output_dim]).normal_(generator=gen)\n", 973 | " pMs.append(Prompt(embed, weight).to(device))\n", 974 | "\n", 975 | " def synth(z):\n", 976 | " z_q = vector_quantize(z.movedim(1, 3), model.quantize.embedding.weight).movedim(3, 1)\n", 977 | " return clamp_with_grad(model.decode(z_q).add(1).div(2), 0, 1)\n", 978 | "\n", 979 | " def add_xmp_data(filename):\n", 980 | " imagen = ImgTag(filename=filename)\n", 981 | " imagen.xmp.append_array_item(libxmp.consts.XMP_NS_DC, 'creator', 'VQGAN+CLIP', {\"prop_array_is_ordered\":True, \"prop_value_is_array\":True})\n", 982 | " if args.prompts:\n", 983 | " imagen.xmp.append_array_item(libxmp.consts.XMP_NS_DC, 'title', \" | \".join(args.prompts), {\"prop_array_is_ordered\":True, \"prop_value_is_array\":True})\n", 984 | " else:\n", 985 | " imagen.xmp.append_array_item(libxmp.consts.XMP_NS_DC, 'title', 'None', {\"prop_array_is_ordered\":True, \"prop_value_is_array\":True})\n", 986 | " imagen.xmp.append_array_item(libxmp.consts.XMP_NS_DC, 'i', str(i), {\"prop_array_is_ordered\":True, \"prop_value_is_array\":True})\n", 987 | " imagen.xmp.append_array_item(libxmp.consts.XMP_NS_DC, 'model', model_name, {\"prop_array_is_ordered\":True, \"prop_value_is_array\":True})\n", 988 | " imagen.xmp.append_array_item(libxmp.consts.XMP_NS_DC, 'seed',str(seed) , {\"prop_array_is_ordered\":True, \"prop_value_is_array\":True})\n", 989 | " imagen.close()\n", 990 | "\n", 991 | " def add_stegano_data(filename):\n", 992 | " data = {\n", 993 | " \"title\": \" | \".join(args.prompts) if args.prompts else None,\n", 994 | " \"notebook\": \"VQGAN+CLIP\",\n", 995 | " \"i\": i,\n", 996 | " \"model\": model_name,\n", 997 | " \"seed\": str(seed),\n", 998 | " }\n", 999 | " lsb.hide(filename, json.dumps(data)).save(filename)\n", 1000 | "\n", 1001 | " @torch.no_grad()\n", 1002 | " def checkin(i, losses):\n", 1003 | " losses_str = ', '.join(f'{loss.item():g}' for loss in losses)\n", 1004 | " tqdm.write(f'i: {i}, loss: {sum(losses).item():g}, losses: {losses_str}')\n", 1005 | " out = synth(z)\n", 1006 | " TF.to_pil_image(out[0].cpu()).save('progress.png')\n", 1007 | " add_stegano_data('progress.png')\n", 1008 | " add_xmp_data('progress.png')\n", 1009 | " display.display(display.Image('progress.png'))\n", 1010 | "\n", 1011 | " def save_output(i, img, suffix=None):\n", 1012 | " filename = \\\n", 1013 | " f\"{working_dir}/steps/{i:04}{'_' + suffix if suffix else ''}.png\"\n", 1014 | " imageio.imwrite(filename, np.array(img))\n", 1015 | " add_stegano_data(filename)\n", 1016 | " add_xmp_data(filename)\n", 1017 | "\n", 1018 | " def ascend_txt(i, save=True, suffix=None):\n", 1019 | " out = synth(z)\n", 1020 | " iii = perceptor.encode_image(normalize(make_cutouts(out))).float()\n", 1021 | "\n", 1022 | " result = []\n", 1023 | "\n", 1024 | " if args.init_weight:\n", 1025 | " result.append(F.mse_loss(z, z_orig) * args.init_weight / 2)\n", 1026 | "\n", 1027 | 
" for prompt in pMs:\n", 1028 | " result.append(prompt(iii))\n", 1029 | " img = np.array(out.mul(255).clamp(0, 255)[0].cpu().detach().numpy().astype(np.uint8))[:,:,:]\n", 1030 | " img = np.transpose(img, (1, 2, 0))\n", 1031 | " if save:\n", 1032 | " save_output(i, img, suffix=suffix)\n", 1033 | " return result\n", 1034 | "\n", 1035 | " def train(i, save=True, suffix=None):\n", 1036 | " opt.zero_grad()\n", 1037 | " lossAll = ascend_txt(i, save=save, suffix=suffix)\n", 1038 | " if i % args.display_freq == 0 and save:\n", 1039 | " checkin(i, lossAll)\n", 1040 | " loss = sum(lossAll)\n", 1041 | " loss.backward()\n", 1042 | " opt.step()\n", 1043 | " with torch.no_grad():\n", 1044 | " z.copy_(z.maximum(z_min).minimum(z_max))\n", 1045 | "\n", 1046 | " with tqdm() as pbar:\n", 1047 | " if iterations_per_frame == 0:\n", 1048 | " save_output(i, img_0)\n", 1049 | " j = 1\n", 1050 | " while True:\n", 1051 | " suffix = (str(j) if save_all_iterations else None)\n", 1052 | " if j >= iterations_per_frame:\n", 1053 | " train(i, save=True, suffix=suffix)\n", 1054 | " break\n", 1055 | " if save_all_iterations:\n", 1056 | " train(i, save=True, suffix=suffix)\n", 1057 | " else:\n", 1058 | " train(i, save=False, suffix=suffix)\n", 1059 | " j += 1\n", 1060 | " pbar.update()\n", 1061 | " except KeyboardInterrupt:\n", 1062 | " stop_on_next_loop = True\n", 1063 | " pass" 1064 | ], 1065 | "execution_count": null, 1066 | "outputs": [] 1067 | }, 1068 | { 1069 | "cell_type": "markdown", 1070 | "metadata": { 1071 | "id": "E25qfzqvHfDN" 1072 | }, 1073 | "source": [ 1074 | "# **Optional: SRCNN for increasing resolution**" 1075 | ] 1076 | }, 1077 | { 1078 | "cell_type": "code", 1079 | "metadata": { 1080 | "id": "E-rTtjYHHhQ2", 1081 | "cellView": "form" 1082 | }, 1083 | "source": [ 1084 | "#@markdown ## **Install SRCNN** \n", 1085 | "\n", 1086 | "!git clone https://github.com/Mirwaisse/SRCNN.git\n", 1087 | "!curl https://raw.githubusercontent.com/justinjohn0306/SRCNN/master/models/model_2x.pth -o model_2x.pth\n", 1088 | "!curl https://raw.githubusercontent.com/justinjohn0306/SRCNN/master/models/model_3x.pth -o model_3x.pth\n", 1089 | "!curl https://raw.githubusercontent.com/justinjohn0306/SRCNN/master/models/model_4x.pth -o model_4x.pth" 1090 | ], 1091 | "execution_count": null, 1092 | "outputs": [] 1093 | }, 1094 | { 1095 | "cell_type": "code", 1096 | "metadata": { 1097 | "id": "lp0_TdGHHjtM", 1098 | "cellView": "form" 1099 | }, 1100 | "source": [ 1101 | "#@markdown ## **Increase Resolution**\n", 1102 | "\n", 1103 | "# import subprocess in case this cell is run without the above cells\n", 1104 | "import subprocess\n", 1105 | "# Set zoomed = True if this cell is run\n", 1106 | "zoomed = True\n", 1107 | "\n", 1108 | "init_frame = 1#@param {type:\"number\"}\n", 1109 | "last_frame = i#@param {type:\"number\"}\n", 1110 | "\n", 1111 | "for i in range(init_frame, last_frame): #\n", 1112 | " filename = f\"{i:04}.png\"\n", 1113 | " cmd = [\n", 1114 | " 'python',\n", 1115 | " '/content/SRCNN/run.py',\n", 1116 | " '--zoom_factor',\n", 1117 | " '2', # Note if you increase this, you also need to change the model.\n", 1118 | " '--model',\n", 1119 | " '/content/model_2x.pth', # 2x, 3x and 4x are available from the repo above\n", 1120 | " '--image',\n", 1121 | " filename,\n", 1122 | " '--cuda'\n", 1123 | " ]\n", 1124 | " print(f'Upscaling frame {i}')\n", 1125 | "\n", 1126 | " process = subprocess.Popen(cmd, cwd=f'{working_dir}/steps/')\n", 1127 | " stdout, stderr = process.communicate()\n", 1128 | " if process.returncode != 0:\n", 1129 | 
" print(stderr)\n", 1130 | " print(\n", 1131 | " \"You may be able to avoid this error by backing up the frames,\"\n", 1132 | " \"restarting the notebook, and running only the video synthesis cells,\"\n", 1133 | " \"or by decreasing the resolution of the image generation steps. \"\n", 1134 | " \"If you restart the notebook, you will have to define the `filepath` manually\"\n", 1135 | " \"by adding `filepath = 'PATH_TO_THE_VIDEO'` to the beginning of this cell. \"\n", 1136 | " \"If these steps do not work, please post the traceback in the github.\"\n", 1137 | " )\n", 1138 | " raise RuntimeError(stderr)" 1139 | ], 1140 | "execution_count": null, 1141 | "outputs": [] 1142 | }, 1143 | { 1144 | "cell_type": "markdown", 1145 | "metadata": { 1146 | "id": "zio2FUEtHmCH" 1147 | }, 1148 | "source": [ 1149 | "## **Make a video of the results**\n", 1150 | "\n", 1151 | "To generate a video with the frames, run the cell below. You can modify the number of FPS, the initial frame, the last frame, etc." 1152 | ] 1153 | }, 1154 | { 1155 | "cell_type": "code", 1156 | "metadata": { 1157 | "id": "rWFJWlcMHoHp", 1158 | "cellView": "form" 1159 | }, 1160 | "source": [ 1161 | "# @title ### **Create video**\n", 1162 | "\n", 1163 | "# import subprocess in case this cell is run without the above cells\n", 1164 | "import subprocess\n", 1165 | "\n", 1166 | "init_frame = 1#@param {type:\"number\"} This is the frame where the video will start\n", 1167 | "last_frame = i#@param {type:\"number\"} You can change i to the number of the last frame you want to generate. It will raise an error if that number of frames does not exist.\n", 1168 | "fps = 12#@param {type:\"number\"}\n", 1169 | "\n", 1170 | "frames = []\n", 1171 | "# tqdm.write('Generating video...')\n", 1172 | "try:\n", 1173 | " zoomed\n", 1174 | "except NameError:\n", 1175 | " image_path = f'{working_dir}/steps/%04d.png'\n", 1176 | "else:\n", 1177 | " image_path = f'{working_dir}/steps/zoomed_%04d.png'\n", 1178 | "\n", 1179 | "cmd = [\n", 1180 | " 'ffmpeg',\n", 1181 | " '-y',\n", 1182 | " '-vcodec',\n", 1183 | " 'png',\n", 1184 | " '-r',\n", 1185 | " str(fps),\n", 1186 | " '-start_number',\n", 1187 | " str(init_frame),\n", 1188 | " '-i',\n", 1189 | " image_path,\n", 1190 | " '-c:v',\n", 1191 | " 'libx264',\n", 1192 | " '-vf',\n", 1193 | " f'fps={fps}',\n", 1194 | " '-pix_fmt',\n", 1195 | " 'yuv420p',\n", 1196 | " '-crf',\n", 1197 | " '17',\n", 1198 | " '-preset',\n", 1199 | " 'veryslow',\n", 1200 | " filepath\n", 1201 | "]\n", 1202 | "\n", 1203 | "process = subprocess.Popen(cmd, cwd=f'{working_dir}/steps/', stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n", 1204 | "stdout, stderr = process.communicate()\n", 1205 | "if process.returncode != 0:\n", 1206 | " print(stderr)\n", 1207 | " print(\n", 1208 | " \"You may be able to avoid this error by backing up the frames,\"\n", 1209 | " \"restarting the notebook, and running only the video synthesis cells,\"\n", 1210 | " \"or by decreasing the resolution of the image generation steps. \"\n", 1211 | " \"If you restart the notebook, you will have to define the `filepath` manually\"\n", 1212 | " \"by adding `filepath = 'PATH_TO_THE_VIDEO'` to the beginning of this cell. 
\"\n", 1213 | " \"If these steps do not work, please post the traceback in the github.\"\n", 1214 | " )\n", 1215 | " raise RuntimeError(stderr)\n", 1216 | "else:\n", 1217 | " print(\"The video is ready\")" 1218 | ], 1219 | "execution_count": null, 1220 | "outputs": [] 1221 | }, 1222 | { 1223 | "cell_type": "code", 1224 | "metadata": { 1225 | "id": "TBUF3NxeHrOl", 1226 | "cellView": "form" 1227 | }, 1228 | "source": [ 1229 | "# @title **See video in the browser**\n", 1230 | "# @markdown This process may take a little longer. If you don't want to wait, download it by executing the next cell instead of using this cell.\n", 1231 | "mp4 = open(filepath,'rb').read()\n", 1232 | "data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n", 1233 | "display.HTML(\"\"\"\n", 1234 | "\n", 1237 | "\"\"\" % data_url)" 1238 | ], 1239 | "execution_count": null, 1240 | "outputs": [] 1241 | }, 1242 | { 1243 | "cell_type": "code", 1244 | "metadata": { 1245 | "id": "w4w4HfRpHt8e", 1246 | "cellView": "form" 1247 | }, 1248 | "source": [ 1249 | "# @title **Download video**\n", 1250 | "from google.colab import files\n", 1251 | "files.download(filepath)" 1252 | ], 1253 | "execution_count": null, 1254 | "outputs": [] 1255 | }, 1256 | { 1257 | "cell_type": "markdown", 1258 | "metadata": { 1259 | "id": "-UvfjxCGHwz2" 1260 | }, 1261 | "source": [ 1262 | "# Optional: Super-Slomo for smoothing movement\n", 1263 | "\n", 1264 | "This step might run out of memory if you run it right after the steps above. If it does, restart the notebook, upload a saved copy of the video from the previous step (or get it from google drive) and define the variable `filepath` with the path to the video before running the cells below again" 1265 | ] 1266 | }, 1267 | { 1268 | "cell_type": "code", 1269 | "metadata": { 1270 | "id": "q9cCek0HHzDw", 1271 | "cellView": "form" 1272 | }, 1273 | "source": [ 1274 | "# @title **Download Super-Slomo model**\n", 1275 | "!git clone -q --depth 1 https://github.com/avinashpaliwal/Super-SloMo.git\n", 1276 | "from os.path import exists\n", 1277 | "def download_from_google_drive(file_id, file_name):\n", 1278 | " # download a file from the Google Drive link\n", 1279 | " !rm -f ./cookie\n", 1280 | " !curl -c ./cookie -s -L \"https://drive.google.com/uc?export=download&id={file_id}\" > /dev/null\n", 1281 | " confirm_text = !awk '/download/ {print $NF}' ./cookie\n", 1282 | " confirm_text = confirm_text[0]\n", 1283 | " !curl -Lb ./cookie \"https://drive.google.com/uc?export=download&confirm={confirm_text}&id={file_id}\" -o {file_name}\n", 1284 | " \n", 1285 | "pretrained_model = 'SuperSloMo.ckpt'\n", 1286 | "if not exists(pretrained_model):\n", 1287 | " download_from_google_drive('1IvobLDbRiBgZr3ryCRrWL8xDbMZ-KnpF', pretrained_model)" 1288 | ], 1289 | "execution_count": null, 1290 | "outputs": [] 1291 | }, 1292 | { 1293 | "cell_type": "code", 1294 | "metadata": { 1295 | "id": "ybgyF2flcmME", 1296 | "cellView": "form" 1297 | }, 1298 | "source": [ 1299 | "# import subprocess in case this cell is run without the above cells\n", 1300 | "import subprocess\n", 1301 | "\n", 1302 | "SLOW_MOTION_FACTOR = 3#@param {type:\"number\"}\n", 1303 | "TARGET_FPS = 12#@param {type:\"number\"}\n", 1304 | "\n", 1305 | "cmd1 = [\n", 1306 | " 'python',\n", 1307 | " 'Super-SloMo/video_to_slomo.py',\n", 1308 | " '--checkpoint',\n", 1309 | " pretrained_model,\n", 1310 | " '--video',\n", 1311 | " filepath,\n", 1312 | " '--sf',\n", 1313 | " str(SLOW_MOTION_FACTOR),\n", 1314 | " '--fps',\n", 1315 | " str(TARGET_FPS),\n", 1316 | " 
'--output',\n", 1317 | " f'{filepath}-slomo.mkv',\n", 1318 | "]\n", 1319 | "process = subprocess.Popen(cmd1, cwd=f'/content', stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n", 1320 | "stdout, stderr = process.communicate()\n", 1321 | "if process.returncode != 0:\n", 1322 | " raise RuntimeError(stderr)\n", 1323 | "\n", 1324 | "cmd2 = [\n", 1325 | " 'ffmpeg',\n", 1326 | " '-i',\n", 1327 | " f'{filepath}-slomo.mkv',\n", 1328 | " '-pix_fmt',\n", 1329 | " 'yuv420p',\n", 1330 | " '-crf',\n", 1331 | " '17',\n", 1332 | " '-preset',\n", 1333 | " 'veryslow',\n", 1334 | " f'{filepath}-slomo.mp4',\n", 1335 | "]\n", 1336 | "\n", 1337 | "process = subprocess.Popen(cmd2, cwd=f'/content', stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n", 1338 | "stdout, stderr = process.communicate()\n", 1339 | "if process.returncode != 0:\n", 1340 | " raise RuntimeError(stderr)\n", 1341 | " print(stderr)\n", 1342 | " print(\n", 1343 | " \"You may be able to avoid this error by backing up the frames,\"\n", 1344 | " \"restarting the notebook, and running only the video synthesis cells,\"\n", 1345 | " \"or by decreasing the resolution of the image generation steps. \"\n", 1346 | " \"If you restart the notebook, you will have to define the `filepath` manually\"\n", 1347 | " \"by adding `filepath = 'PATH_TO_THE_VIDEO'` to the beginning of this cell. \"\n", 1348 | " \"If these steps do not work, please post the traceback in the github.\"\n", 1349 | " )\n" 1350 | ], 1351 | "execution_count": null, 1352 | "outputs": [] 1353 | }, 1354 | { 1355 | "cell_type": "code", 1356 | "metadata": { 1357 | "id": "FwQCcrggH1oY", 1358 | "cellView": "form" 1359 | }, 1360 | "source": [ 1361 | "# @title **See video in the browser**\n", 1362 | "# @markdown This process may take a little longer. If you don't want to wait, download it by executing the next cell instead of using this cell.\n", 1363 | "from base64 import b64encode\n", 1364 | "from IPython import display\n", 1365 | "mp4 = open(f'{filepath}-slomo.mp4','rb').read()\n", 1366 | "data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n", 1367 | "display.HTML(\"\"\"\n", 1368 | "\n", 1371 | "\"\"\" % data_url)" 1372 | ], 1373 | "execution_count": null, 1374 | "outputs": [] 1375 | }, 1376 | { 1377 | "cell_type": "code", 1378 | "metadata": { 1379 | "id": "WfJrASI3ctpP", 1380 | "cellView": "form" 1381 | }, 1382 | "source": [ 1383 | "# @title **Download video**\n", 1384 | "from google.colab import files\n", 1385 | "files.download(f'{filepath}-slomo.mp4')" 1386 | ], 1387 | "execution_count": null, 1388 | "outputs": [] 1389 | }, 1390 | { 1391 | "cell_type": "markdown", 1392 | "metadata": { 1393 | "id": "iAM3C4vJH5b7" 1394 | }, 1395 | "source": [ 1396 | "JS to prevent idle timeout:\n", 1397 | "\n", 1398 | "Press F12 OR CTRL + SHIFT + I OR right click on this website -> inspect.\n", 1399 | "Then click on the console tab and paste in the following code.\n", 1400 | "\n", 1401 | "```javascript\n", 1402 | "function ClickConnect(){\n", 1403 | "console.log(\"Working\");\n", 1404 | "document.querySelector(\"colab-toolbar-button#connect\").click()\n", 1405 | "}\n", 1406 | "setInterval(ClickConnect,60000)\n", 1407 | "```" 1408 | ] 1409 | } 1410 | ] 1411 | } -------------------------------------------------------------------------------- /VQGAN+CLIP_(z+quantize_method_with_augmentations,_user_friendly_interface).ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 5, 4 | "metadata": { 5 | "kernelspec": { 6 | 
"display_name": "Python 3", 7 | "language": "python", 8 | "name": "python3" 9 | }, 10 | "language_info": { 11 | "codemirror_mode": { 12 | "name": "ipython", 13 | "version": 3 14 | }, 15 | "file_extension": ".py", 16 | "mimetype": "text/x-python", 17 | "name": "python", 18 | "nbconvert_exporter": "python", 19 | "pygments_lexer": "ipython3", 20 | "version": "3.8.8" 21 | }, 22 | "colab": { 23 | "name": "VQGAN+CLIP_(z+quantize_method_with_augmentations,_user_friendly_interface).ipynb", 24 | "provenance": [], 25 | "collapsed_sections": [] 26 | }, 27 | "accelerator": "GPU" 28 | }, 29 | "cells": [ 30 | { 31 | "cell_type": "markdown", 32 | "metadata": { 33 | "id": "2061ea2b-37f6-44d5-8012-58bd9abf6548" 34 | }, 35 | "source": [ 36 | "**Generate images from text phrases with VQGAN and CLIP** (*z+quantize method with augmentations*)." 37 | ], 38 | "id": "2061ea2b-37f6-44d5-8012-58bd9abf6548" 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": { 43 | "id": "c3d7a8be-73ce-4cee-be70-e21c1210a7a6" 44 | }, 45 | "source": [ 46 | "Original Notebook made by Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings).\n", 47 | "\n", 48 | "Modified by: Justin John" 49 | ], 50 | "id": "c3d7a8be-73ce-4cee-be70-e21c1210a7a6" 51 | }, 52 | { 53 | "cell_type": "code", 54 | "metadata": { 55 | "id": "c2505191-1756-4e5d-b9a8-33af798ad879", 56 | "cellView": "form" 57 | }, 58 | "source": [ 59 | "#@markdown #**Licensed under the MIT License (*Double-click me to read the license agreement*)**\n", 60 | "#@markdown ---\n", 61 | "\n", 62 | "# Copyright (c) 2021 Katherine Crowson\n", 63 | "\n", 64 | "# Permission is hereby granted, free of charge, to any person obtaining a copy\n", 65 | "# of this software and associated documentation files (the \"Software\"), to deal\n", 66 | "# in the Software without restriction, including without limitation the rights\n", 67 | "# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n", 68 | "# copies of the Software, and to permit persons to whom the Software is\n", 69 | "# furnished to do so, subject to the following conditions:\n", 70 | "\n", 71 | "# The above copyright notice and this permission notice shall be included in\n", 72 | "# all copies or substantial portions of the Software.\n", 73 | "\n", 74 | "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", 75 | "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", 76 | "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n", 77 | "# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", 78 | "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n", 79 | "# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n", 80 | "# THE SOFTWARE." 
81 | ], 82 | "id": "c2505191-1756-4e5d-b9a8-33af798ad879", 83 | "execution_count": null, 84 | "outputs": [] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "metadata": { 89 | "id": "809824ee-8f52-4986-a9d3-514b14df72b7", 90 | "cellView": "form" 91 | }, 92 | "source": [ 93 | "#@markdown #**Check GPU type**\n", 94 | "#@markdown ### Factory reset runtime if you don't have the desired GPU.\n", 95 | "\n", 96 | "#@markdown ---\n", 97 | "\n", 98 | "\n", 99 | "\n", 100 | "\n", 101 | "#@markdown V100 = Excellent (*Available only for Colab Pro Users*)\n", 102 | "\n", 103 | "#@markdown P100 = Very Good\n", 104 | "\n", 105 | "#@markdown T4 = Good\n", 106 | "\n", 107 | "#@markdown K80 = Meh\n", 108 | "\n", 109 | "#@markdown P4 = (Not Recommended) *for heavy A.I Models like COCO, WikiArt 1024, WikiArt 16384, FacesHQ or S-FLCKR*\n", 110 | "\n", 111 | "#@markdown ---\n", 112 | "\n", 113 | "!nvidia-smi -L" 114 | ], 115 | "id": "809824ee-8f52-4986-a9d3-514b14df72b7", 116 | "execution_count": null, 117 | "outputs": [] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "metadata": { 122 | "id": "XHyPd4oxVp_l", 123 | "cellView": "form" 124 | }, 125 | "source": [ 126 | "#@markdown #**Anti-Disconnect for Google Colab**\n", 127 | "#@markdown ## Run this to stop it from disconnecting automatically \n", 128 | "#@markdown **(It will anyhow disconnect after 6 - 12 hrs for using the free version of Colab.)**\n", 129 | "#@markdown *(Colab Pro users will get about 24 hrs usage time)*\n", 130 | "#@markdown ---\n", 131 | "\n", 132 | "import IPython\n", 133 | "js_code = '''\n", 134 | "function ClickConnect(){\n", 135 | "console.log(\"Working\");\n", 136 | "document.querySelector(\"colab-toolbar-button#connect\").click()\n", 137 | "}\n", 138 | "setInterval(ClickConnect,60000)\n", 139 | "'''\n", 140 | "display(IPython.display.Javascript(js_code))" 141 | ], 142 | "id": "XHyPd4oxVp_l", 143 | "execution_count": null, 144 | "outputs": [] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "metadata": { 149 | "id": "45cac47a-fbac-48bd-8f70-95d5099a12e9", 150 | "cellView": "form" 151 | }, 152 | "source": [ 153 | "#@markdown #**Installation of libraries**\n", 154 | "# @markdown This cell will take a little while because it has to download several libraries\n", 155 | "\n", 156 | "#@markdown ---\n", 157 | "\n", 158 | "!git clone https://github.com/openai/CLIP\n", 159 | "!pip install taming-transformers\n", 160 | "!git clone https://github.com/CompVis/taming-transformers.git\n", 161 | "!pip install ftfy regex tqdm omegaconf pytorch-lightning\n", 162 | "!pip install kornia\n", 163 | "!pip install imageio-ffmpeg \n", 164 | "!pip install einops \n", 165 | "!mkdir steps" 166 | ], 167 | "id": "45cac47a-fbac-48bd-8f70-95d5099a12e9", 168 | "execution_count": null, 169 | "outputs": [] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "metadata": { 174 | "id": "af151e0b-8b57-4f8c-92dd-947f122db8c2", 175 | "cellView": "form" 176 | }, 177 | "source": [ 178 | "#@markdown #**Selection of models to download**\n", 179 | "#@markdown ---\n", 180 | "#@markdown **By default, the notebook downloads the 1024 and 16384 models from ImageNet. 
There are others like COCO-Stuff, WikiArt 1024, WikiArt 16384, FacesHQ or S-FLCKR, which are heavy, and if you are not going to use them it would be pointless to download them, so if you want to use them, simply select the models to download.**\n", 181 | "\n", 182 | "#@markdown ---\n", 183 | "\n", 184 | "imagenet_1024 = True #@param {type:\"boolean\"}\n", 185 | "imagenet_16384 = True #@param {type:\"boolean\"}\n", 186 | "coco = False #@param {type:\"boolean\"}\n", 187 | "faceshq = False #@param {type:\"boolean\"}\n", 188 | "wikiart_1024 = False #@param {type:\"boolean\"}\n", 189 | "wikiart_16384 = False #@param {type:\"boolean\"}\n", 190 | "sflckr = False #@param {type:\"boolean\"}\n", 191 | "openimages_8192 = False #@param {type:\"boolean\"}\n", 192 | "\n", 193 | "if imagenet_1024:\n", 194 | " !curl -L -o vqgan_imagenet_f16_1024.yaml -C - 'http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_1024.yaml' #ImageNet 1024\n", 195 | " !curl -L -o vqgan_imagenet_f16_1024.ckpt -C - 'http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_1024.ckpt' #ImageNet 1024\n", 196 | "if imagenet_16384:\n", 197 | " !curl -L -o vqgan_imagenet_f16_16384.yaml -C - 'http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_16384.yaml' #ImageNet 16384\n", 198 | " !curl -L -o vqgan_imagenet_f16_16384.ckpt -C - 'http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_16384.ckpt' #ImageNet 16384\n", 199 | "if openimages_8192:\n", 200 | " !curl -L -o vqgan_openimages_f16_8192.yaml -C - 'https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1' #ImageNet 16384\n", 201 | " !curl -L -o vqgan_openimages_f16_8192.ckpt -C - 'https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/files/?p=%2Fckpts%2Flast.ckpt&dl=1' #ImageNet 16384\n", 202 | "\n", 203 | "if coco:\n", 204 | " !curl -L -o coco.yaml -C - 'https://dl.nmkd.de/ai/clip/coco/coco.yaml' #COCO\n", 205 | " !curl -L -o coco.ckpt -C - 'https://dl.nmkd.de/ai/clip/coco/coco.ckpt' #COCO\n", 206 | "if faceshq:\n", 207 | " !curl -L -o faceshq.yaml -C - 'https://drive.google.com/uc?export=download&id=1fHwGx_hnBtC8nsq7hesJvs-Klv-P0gzT' #FacesHQ\n", 208 | " !curl -L -o faceshq.ckpt -C - 'https://app.koofr.net/content/links/a04deec9-0c59-4673-8b37-3d696fe63a5d/files/get/last.ckpt?path=%2F2020-11-13T21-41-45_faceshq_transformer%2Fcheckpoints%2Flast.ckpt' #FacesHQ\n", 209 | "if wikiart_1024: \n", 210 | " !curl -L -o wikiart_1024.yaml -C - 'http://mirror.io.community/blob/vqgan/wikiart.yaml' #WikiArt 1024\n", 211 | " !curl -L -o wikiart_1024.ckpt -C - 'http://mirror.io.community/blob/vqgan/wikiart.ckpt' #WikiArt 1024\n", 212 | "if wikiart_16384: \n", 213 | " !curl -L -o wikiart_16384.yaml -C - 'http://mirror.io.community/blob/vqgan/wikiart_16384.yaml' #WikiArt 16384\n", 214 | " !curl -L -o wikiart_16384.ckpt -C - 'http://mirror.io.community/blob/vqgan/wikiart_16384.ckpt' #WikiArt 16384\n", 215 | "if sflckr:\n", 216 | " !curl -L -o sflckr.yaml -C - 'https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fconfigs%2F2020-11-09T13-31-51-project.yaml&dl=1' #S-FLCKR\n", 217 | " !curl -L -o sflckr.ckpt -C - 'https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fcheckpoints%2Flast.ckpt&dl=1' #S-FLCKR" 218 | ], 219 | "id": "af151e0b-8b57-4f8c-92dd-947f122db8c2", 220 | "execution_count": null, 221 | "outputs": [] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "metadata": { 226 | "id": "66d3b419-955d-44f7-ad87-904ee3c16ff8", 227 | "cellView": "form" 228 | }, 229 | "source": [ 230 | "#@markdown #**Loading libraries and 
definitions**\n", 231 | "\n", 232 | "import argparse\n", 233 | "import math\n", 234 | "from pathlib import Path\n", 235 | "import sys\n", 236 | "\n", 237 | "sys.path.insert(1, '/content/taming-transformers')\n", 238 | "from IPython import display\n", 239 | "from base64 import b64encode\n", 240 | "from omegaconf import OmegaConf\n", 241 | "from PIL import Image\n", 242 | "from taming.models import cond_transformer, vqgan\n", 243 | "import taming.modules \n", 244 | "import torch\n", 245 | "from torch import nn, optim\n", 246 | "from torch.nn import functional as F\n", 247 | "from torchvision import transforms\n", 248 | "from torchvision.transforms import functional as TF\n", 249 | "from tqdm.notebook import tqdm\n", 250 | "\n", 251 | "from CLIP import clip\n", 252 | "import kornia.augmentation as K\n", 253 | "import numpy as np\n", 254 | "import imageio\n", 255 | "from PIL import ImageFile, Image\n", 256 | "ImageFile.LOAD_TRUNCATED_IMAGES = True\n", 257 | "\n", 258 | "def sinc(x):\n", 259 | " return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))\n", 260 | "\n", 261 | "\n", 262 | "def lanczos(x, a):\n", 263 | " cond = torch.logical_and(-a < x, x < a)\n", 264 | " out = torch.where(cond, sinc(x) * sinc(x/a), x.new_zeros([]))\n", 265 | " return out / out.sum()\n", 266 | "\n", 267 | "\n", 268 | "def ramp(ratio, width):\n", 269 | " n = math.ceil(width / ratio + 1)\n", 270 | " out = torch.empty([n])\n", 271 | " cur = 0\n", 272 | " for i in range(out.shape[0]):\n", 273 | " out[i] = cur\n", 274 | " cur += ratio\n", 275 | " return torch.cat([-out[1:].flip([0]), out])[1:-1]\n", 276 | "\n", 277 | "\n", 278 | "def resample(input, size, align_corners=True):\n", 279 | " n, c, h, w = input.shape\n", 280 | " dh, dw = size\n", 281 | "\n", 282 | " input = input.view([n * c, 1, h, w])\n", 283 | "\n", 284 | " if dh < h:\n", 285 | " kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)\n", 286 | " pad_h = (kernel_h.shape[0] - 1) // 2\n", 287 | " input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')\n", 288 | " input = F.conv2d(input, kernel_h[None, None, :, None])\n", 289 | "\n", 290 | " if dw < w:\n", 291 | " kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)\n", 292 | " pad_w = (kernel_w.shape[0] - 1) // 2\n", 293 | " input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')\n", 294 | " input = F.conv2d(input, kernel_w[None, None, None, :])\n", 295 | "\n", 296 | " input = input.view([n, c, h, w])\n", 297 | " return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)\n", 298 | "\n", 299 | "\n", 300 | "class ReplaceGrad(torch.autograd.Function):\n", 301 | " @staticmethod\n", 302 | " def forward(ctx, x_forward, x_backward):\n", 303 | " ctx.shape = x_backward.shape\n", 304 | " return x_forward\n", 305 | "\n", 306 | " @staticmethod\n", 307 | " def backward(ctx, grad_in):\n", 308 | " return None, grad_in.sum_to_size(ctx.shape)\n", 309 | "\n", 310 | "\n", 311 | "replace_grad = ReplaceGrad.apply\n", 312 | "\n", 313 | "\n", 314 | "class ClampWithGrad(torch.autograd.Function):\n", 315 | " @staticmethod\n", 316 | " def forward(ctx, input, min, max):\n", 317 | " ctx.min = min\n", 318 | " ctx.max = max\n", 319 | " ctx.save_for_backward(input)\n", 320 | " return input.clamp(min, max)\n", 321 | "\n", 322 | " @staticmethod\n", 323 | " def backward(ctx, grad_in):\n", 324 | " input, = ctx.saved_tensors\n", 325 | " return grad_in * (grad_in * (input - input.clamp(ctx.min, ctx.max)) >= 0), None, None\n", 326 | "\n", 327 | "\n", 328 | "clamp_with_grad = 
ClampWithGrad.apply\n", 329 | "\n", 330 | "\n", 331 | "def vector_quantize(x, codebook):\n", 332 | " d = x.pow(2).sum(dim=-1, keepdim=True) + codebook.pow(2).sum(dim=1) - 2 * x @ codebook.T\n", 333 | " indices = d.argmin(-1)\n", 334 | " x_q = F.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook\n", 335 | " return replace_grad(x_q, x)\n", 336 | "\n", 337 | "\n", 338 | "class Prompt(nn.Module):\n", 339 | " def __init__(self, embed, weight=1., stop=float('-inf')):\n", 340 | " super().__init__()\n", 341 | " self.register_buffer('embed', embed)\n", 342 | " self.register_buffer('weight', torch.as_tensor(weight))\n", 343 | " self.register_buffer('stop', torch.as_tensor(stop))\n", 344 | "\n", 345 | " def forward(self, input):\n", 346 | " input_normed = F.normalize(input.unsqueeze(1), dim=2)\n", 347 | " embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)\n", 348 | " dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)\n", 349 | " dists = dists * self.weight.sign()\n", 350 | " return self.weight.abs() * replace_grad(dists, torch.maximum(dists, self.stop)).mean()\n", 351 | "\n", 352 | "\n", 353 | "def parse_prompt(prompt):\n", 354 | " vals = prompt.rsplit(':', 2)\n", 355 | " vals = vals + ['', '1', '-inf'][len(vals):]\n", 356 | " return vals[0], float(vals[1]), float(vals[2])\n", 357 | "\n", 358 | "\n", 359 | "class MakeCutouts(nn.Module):\n", 360 | " def __init__(self, cut_size, cutn, cut_pow=1.):\n", 361 | " super().__init__()\n", 362 | " self.cut_size = cut_size\n", 363 | " self.cutn = cutn\n", 364 | " self.cut_pow = cut_pow\n", 365 | "\n", 366 | " self.augs = nn.Sequential(\n", 367 | " # K.RandomHorizontalFlip(p=0.5),\n", 368 | " # K.RandomVerticalFlip(p=0.5),\n", 369 | " # K.RandomSolarize(0.01, 0.01, p=0.7),\n", 370 | " # K.RandomSharpness(0.3,p=0.4),\n", 371 | " # K.RandomResizedCrop(size=(self.cut_size,self.cut_size), scale=(0.1,1), ratio=(0.75,1.333), cropping_mode='resample', p=0.5),\n", 372 | " # K.RandomCrop(size=(self.cut_size,self.cut_size), p=0.5),\n", 373 | " K.RandomAffine(degrees=15, translate=0.1, p=0.7, padding_mode='border'),\n", 374 | " K.RandomPerspective(0.7,p=0.7),\n", 375 | " K.ColorJitter(hue=0.1, saturation=0.1, p=0.7),\n", 376 | " K.RandomErasing((.1, .4), (.3, 1/.3), same_on_batch=True, p=0.7),\n", 377 | " \n", 378 | ")\n", 379 | " self.noise_fac = 0.1\n", 380 | " self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))\n", 381 | " self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))\n", 382 | "\n", 383 | " def forward(self, input):\n", 384 | " sideY, sideX = input.shape[2:4]\n", 385 | " max_size = min(sideX, sideY)\n", 386 | " min_size = min(sideX, sideY, self.cut_size)\n", 387 | " cutouts = []\n", 388 | " \n", 389 | " for _ in range(self.cutn):\n", 390 | "\n", 391 | " # size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)\n", 392 | " # offsetx = torch.randint(0, sideX - size + 1, ())\n", 393 | " # offsety = torch.randint(0, sideY - size + 1, ())\n", 394 | " # cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]\n", 395 | " # cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))\n", 396 | "\n", 397 | " # cutout = transforms.Resize(size=(self.cut_size, self.cut_size))(input)\n", 398 | " \n", 399 | " cutout = (self.av_pool(input) + self.max_pool(input))/2\n", 400 | " cutouts.append(cutout)\n", 401 | " batch = self.augs(torch.cat(cutouts, dim=0))\n", 402 | " if self.noise_fac:\n", 403 | " facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, 
self.noise_fac)\n", 404 | " batch = batch + facs * torch.randn_like(batch)\n", 405 | " return batch\n", 406 | "\n", 407 | "\n", 408 | "def load_vqgan_model(config_path, checkpoint_path):\n", 409 | " config = OmegaConf.load(config_path)\n", 410 | " if config.model.target == 'taming.models.vqgan.VQModel':\n", 411 | " model = vqgan.VQModel(**config.model.params)\n", 412 | " model.eval().requires_grad_(False)\n", 413 | " model.init_from_ckpt(checkpoint_path)\n", 414 | " elif config.model.target == 'taming.models.vqgan.GumbelVQ':\n", 415 | " model = vqgan.GumbelVQ(**config.model.params)\n", 416 | " model.eval().requires_grad_(False)\n", 417 | " model.init_from_ckpt(checkpoint_path)\n", 418 | " elif config.model.target == 'taming.models.cond_transformer.Net2NetTransformer':\n", 419 | " parent_model = cond_transformer.Net2NetTransformer(**config.model.params)\n", 420 | " parent_model.eval().requires_grad_(False)\n", 421 | " parent_model.init_from_ckpt(checkpoint_path)\n", 422 | " model = parent_model.first_stage_model\n", 423 | " else:\n", 424 | " raise ValueError(f'unknown model type: {config.model.target}')\n", 425 | " del model.loss\n", 426 | " return model\n", 427 | "\n", 428 | "\n", 429 | "def resize_image(image, out_size):\n", 430 | " ratio = image.size[0] / image.size[1]\n", 431 | " area = min(image.size[0] * image.size[1], out_size[0] * out_size[1])\n", 432 | " size = round((area * ratio)**0.5), round((area / ratio)**0.5)\n", 433 | " return image.resize(size, Image.LANCZOS)" 434 | ], 435 | "id": "66d3b419-955d-44f7-ad87-904ee3c16ff8", 436 | "execution_count": null, 437 | "outputs": [] 438 | }, 439 | { 440 | "cell_type": "markdown", 441 | "metadata": { 442 | "id": "f80696fb-8535-488b-b83e-a3e4928d6e6a" 443 | }, 444 | "source": [ 445 | "## Settings for this run:\n", 446 | "Mainly what you will have to modify will be `texts:`, there you can place the text or texts you want to generate (separated with `|`). It is a list because you can put more than one text, and so the AI ​​tries to 'mix' the images, giving the same priority to both texts.\n", 447 | "\n", 448 | "To use an initial image to the model, you just have to upload a file to the Colab environment (in the section on the left), and then modify `init_image:` putting the exact name of the file. Example: `sample.png`\n", 449 | "\n", 450 | "You can also modify the model by changing the lines that say `model:`. Currently ImageNet 1024, ImageNet 16384, WikiArt 1024, WikiArt 16384, S-FLCKR, COCO-Stuff and Open Images are available. To activate them you have to have downloaded them first, and then you can simply select it.\n", 451 | "\n", 452 | "You can also use `target_images`, which is basically putting one or more images on it that the AI ​​will take as a \"target\", fulfilling the same function as putting text on it. 
To put more than one you have to use `|` as a separator.\n", 453 | "\n", 454 | "**You can change the width and the height of the image that will be generated according to your liking.**" 455 | ], 456 | "id": "f80696fb-8535-488b-b83e-a3e4928d6e6a" 457 | }, 458 | { 459 | "cell_type": "code", 460 | "metadata": { 461 | "id": "fccf05b3-2e0a-46a1-a377-607d151377ac", 462 | "cellView": "form" 463 | }, 464 | "source": [ 465 | "#@markdown #**Parameters**\n", 466 | "#@markdown ---\n", 467 | "\n", 468 | "texts = \"\" #@param {type:\"string\"}\n", 469 | "width = 300#@param {type:\"number\"}\n", 470 | "height = 300#@param {type:\"number\"}\n", 471 | "model = \"vqgan_imagenet_f16_16384\" #@param [\"vqgan_imagenet_f16_16384\", \"vqgan_imagenet_f16_1024\", \"vqgan_openimages_f16_8192\", \"wikiart_1024\", \"wikiart_16384\", \"coco\", \"faceshq\", \"sflckr\"]\n", 472 | "images_interval = 50#@param {type:\"number\"}\n", 473 | "init_image = \"\"#@param {type:\"string\"}\n", 474 | "target_images = \"\"#@param {type:\"string\"}\n", 475 | "seed = -1#@param {type:\"number\"}\n", 476 | "max_iterations = -1#@param {type:\"number\"}\n", 477 | "\n", 478 | "model_names={\"vqgan_imagenet_f16_16384\": 'ImageNet 16384',\"vqgan_imagenet_f16_1024\":\"ImageNet 1024\", 'vqgan_openimages_f16_8192':'OpenImages 8912',\n", 479 | " \"wikiart_1024\":\"WikiArt 1024\", \"wikiart_16384\":\"WikiArt 16384\", \"coco\":\"COCO-Stuff\", \"faceshq\":\"FacesHQ\", \"sflckr\":\"S-FLCKR\"}\n", 480 | "name_model = model_names[model] \n", 481 | "\n", 482 | "if seed == -1:\n", 483 | " seed = None\n", 484 | "if init_image == \"None\":\n", 485 | " init_image = None\n", 486 | "if target_images == \"None\" or not target_images:\n", 487 | " target_images = []\n", 488 | "else:\n", 489 | " target_images = target_images.split(\"|\")\n", 490 | " target_images = [image.strip() for image in target_images]\n", 491 | "\n", 492 | "texts = [phrase.strip() for phrase in texts.split(\"|\")]\n", 493 | "if texts == ['']:\n", 494 | " texts = []\n", 495 | "\n", 496 | "\n", 497 | "args = argparse.Namespace(\n", 498 | " prompts=texts,\n", 499 | " image_prompts=target_images,\n", 500 | " noise_prompt_seeds=[],\n", 501 | " noise_prompt_weights=[],\n", 502 | " size=[width, height],\n", 503 | " init_image=init_image,\n", 504 | " init_weight=0.,\n", 505 | " clip_model='ViT-B/32',\n", 506 | " vqgan_config=f'{model}.yaml',\n", 507 | " vqgan_checkpoint=f'{model}.ckpt',\n", 508 | " step_size=0.1,\n", 509 | " cutn=32,\n", 510 | " cut_pow=1.,\n", 511 | " display_freq=images_interval,\n", 512 | " seed=seed,\n", 513 | ")" 514 | ], 515 | "id": "fccf05b3-2e0a-46a1-a377-607d151377ac", 516 | "execution_count": null, 517 | "outputs": [] 518 | }, 519 | { 520 | "cell_type": "code", 521 | "metadata": { 522 | "id": "a09f280f-bc39-4e33-b35c-510a38fbc7bf", 523 | "cellView": "form" 524 | }, 525 | "source": [ 526 | "#@markdown #**Fire up the AI**\n", 527 | "\n", 528 | "#@markdown ---\n", 529 | "from urllib.request import urlopen\n", 530 | "\n", 531 | "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", 532 | "print('Using device:', device)\n", 533 | "if texts:\n", 534 | " print('Using texts:', texts)\n", 535 | "if target_images:\n", 536 | " print('Using image prompts:', target_images)\n", 537 | "if args.seed is None:\n", 538 | " seed = torch.seed()\n", 539 | "else:\n", 540 | " seed = args.seed\n", 541 | "torch.manual_seed(seed)\n", 542 | "print('Using seed:', seed)\n", 543 | "\n", 544 | "model = load_vqgan_model(args.vqgan_config, args.vqgan_checkpoint).to(device)\n", 545 
| "perceptor = clip.load(args.clip_model, jit=False)[0].eval().requires_grad_(False).to(device)\n", 546 | "# clock=deepcopy(perceptor.visual.positional_embedding.data)\n", 547 | "# perceptor.visual.positional_embedding.data = clock/clock.max()\n", 548 | "# perceptor.visual.positional_embedding.data=clamp_with_grad(clock,0,1)\n", 549 | "\n", 550 | "cut_size = perceptor.visual.input_resolution\n", 551 | "\n", 552 | "f = 2**(model.decoder.num_resolutions - 1)\n", 553 | "make_cutouts = MakeCutouts(cut_size, args.cutn, cut_pow=args.cut_pow)\n", 554 | "\n", 555 | "toksX, toksY = args.size[0] // f, args.size[1] // f\n", 556 | "sideX, sideY = toksX * f, toksY * f\n", 557 | "\n", 558 | "if args.vqgan_checkpoint == 'vqgan_openimages_f16_8192.ckpt':\n", 559 | " e_dim = 256\n", 560 | " n_toks = model.quantize.n_embed\n", 561 | " z_min = model.quantize.embed.weight.min(dim=0).values[None, :, None, None]\n", 562 | " z_max = model.quantize.embed.weight.max(dim=0).values[None, :, None, None]\n", 563 | "else:\n", 564 | " e_dim = model.quantize.e_dim\n", 565 | " n_toks = model.quantize.n_e\n", 566 | " z_min = model.quantize.embedding.weight.min(dim=0).values[None, :, None, None]\n", 567 | " z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None]\n", 568 | "# z_min = model.quantize.embedding.weight.min(dim=0).values[None, :, None, None]\n", 569 | "# z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None]\n", 570 | "\n", 571 | "# normalize_imagenet = transforms.Normalize(mean=[0.485, 0.456, 0.406],\n", 572 | "# std=[0.229, 0.224, 0.225])\n", 573 | "\n", 574 | "if args.init_image:\n", 575 | " if 'http' in args.init_image:\n", 576 | " img = Image.open(urlopen(args.init_image))\n", 577 | " else:\n", 578 | " img = Image.open(args.init_image)\n", 579 | " pil_image = img.convert('RGB')\n", 580 | " pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)\n", 581 | " pil_tensor = TF.to_tensor(pil_image)\n", 582 | " z, *_ = model.encode(pil_tensor.to(device).unsqueeze(0) * 2 - 1)\n", 583 | "else:\n", 584 | " one_hot = F.one_hot(torch.randint(n_toks, [toksY * toksX], device=device), n_toks).float()\n", 585 | " # z = one_hot @ model.quantize.embedding.weight\n", 586 | " if args.vqgan_checkpoint == 'vqgan_openimages_f16_8192.ckpt':\n", 587 | " z = one_hot @ model.quantize.embed.weight\n", 588 | " else:\n", 589 | " z = one_hot @ model.quantize.embedding.weight\n", 590 | " z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2) \n", 591 | " z = torch.rand_like(z)*2\n", 592 | "z_orig = z.clone()\n", 593 | "z.requires_grad_(True)\n", 594 | "opt = optim.Adam([z], lr=args.step_size)\n", 595 | "\n", 596 | "normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],\n", 597 | " std=[0.26862954, 0.26130258, 0.27577711])\n", 598 | "\n", 599 | "\n", 600 | "\n", 601 | "pMs = []\n", 602 | "\n", 603 | "for prompt in args.prompts:\n", 604 | " txt, weight, stop = parse_prompt(prompt)\n", 605 | " embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()\n", 606 | " pMs.append(Prompt(embed, weight, stop).to(device))\n", 607 | "\n", 608 | "for prompt in args.image_prompts:\n", 609 | " path, weight, stop = parse_prompt(prompt)\n", 610 | " img = Image.open(path)\n", 611 | " pil_image = img.convert('RGB')\n", 612 | " img = resize_image(pil_image, (sideX, sideY))\n", 613 | " batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))\n", 614 | " embed = perceptor.encode_image(normalize(batch)).float()\n", 615 | " pMs.append(Prompt(embed, weight, 
stop).to(device))\n", 616 | "\n", 617 | "for seed, weight in zip(args.noise_prompt_seeds, args.noise_prompt_weights):\n", 618 | " gen = torch.Generator().manual_seed(seed)\n", 619 | " embed = torch.empty([1, perceptor.visual.output_dim]).normal_(generator=gen)\n", 620 | " pMs.append(Prompt(embed, weight).to(device))\n", 621 | "\n", 622 | "def synth(z):\n", 623 | " if args.vqgan_checkpoint == 'vqgan_openimages_f16_8192.ckpt':\n", 624 | " z_q = vector_quantize(z.movedim(1, 3), model.quantize.embed.weight).movedim(3, 1)\n", 625 | " else:\n", 626 | " z_q = vector_quantize(z.movedim(1, 3), model.quantize.embedding.weight).movedim(3, 1)\n", 627 | " return clamp_with_grad(model.decode(z_q).add(1).div(2), 0, 1)\n", 628 | "\n", 629 | "@torch.no_grad()\n", 630 | "def checkin(i, losses):\n", 631 | " losses_str = ', '.join(f'{loss.item():g}' for loss in losses)\n", 632 | " tqdm.write(f'i: {i}, loss: {sum(losses).item():g}, losses: {losses_str}')\n", 633 | " out = synth(z)\n", 634 | " TF.to_pil_image(out[0].cpu()).save('progress.png')\n", 635 | " display.display(display.Image('progress.png'))\n", 636 | "\n", 637 | "def ascend_txt():\n", 638 | " global i\n", 639 | " out = synth(z)\n", 640 | " iii = perceptor.encode_image(normalize(make_cutouts(out))).float()\n", 641 | " \n", 642 | " result = []\n", 643 | "\n", 644 | " if args.init_weight:\n", 645 | " # result.append(F.mse_loss(z, z_orig) * args.init_weight / 2)\n", 646 | " result.append(F.mse_loss(z, torch.zeros_like(z_orig)) * ((1/torch.tensor(i*2 + 1))*args.init_weight) / 2)\n", 647 | " for prompt in pMs:\n", 648 | " result.append(prompt(iii))\n", 649 | " img = np.array(out.mul(255).clamp(0, 255)[0].cpu().detach().numpy().astype(np.uint8))[:,:,:]\n", 650 | " img = np.transpose(img, (1, 2, 0))\n", 651 | " imageio.imwrite('./steps/' + str(i) + '.png', np.array(img))\n", 652 | "\n", 653 | " return result\n", 654 | "\n", 655 | "def train(i):\n", 656 | " opt.zero_grad()\n", 657 | " lossAll = ascend_txt()\n", 658 | " if i % args.display_freq == 0:\n", 659 | " checkin(i, lossAll)\n", 660 | " \n", 661 | " loss = sum(lossAll)\n", 662 | " loss.backward()\n", 663 | " opt.step()\n", 664 | " with torch.no_grad():\n", 665 | " z.copy_(z.maximum(z_min).minimum(z_max))\n", 666 | "\n", 667 | "i = 0\n", 668 | "try:\n", 669 | " with tqdm() as pbar:\n", 670 | " while True:\n", 671 | " train(i)\n", 672 | " if i == max_iterations:\n", 673 | " break\n", 674 | " i += 1\n", 675 | " pbar.update()\n", 676 | "except KeyboardInterrupt:\n", 677 | " pass" 678 | ], 679 | "id": "a09f280f-bc39-4e33-b35c-510a38fbc7bf", 680 | "execution_count": null, 681 | "outputs": [] 682 | }, 683 | { 684 | "cell_type": "code", 685 | "metadata": { 686 | "id": "c053240e-8a05-448d-9977-8308179b6d09", 687 | "cellView": "form" 688 | }, 689 | "source": [ 690 | "#@markdown **Generate a video with the result (You can edit frame rate and stuff by double-clicking this tab)**\n", 691 | "init_frame = 1 #This is the frame where the video will start\n", 692 | "last_frame = i #You can change i to the number of the last frame you want to generate. 
It will raise an error if that number of frames does not exist.\n", 693 | "\n", 694 | "min_fps = 10\n", 695 | "max_fps = 60\n", 696 | "\n", 697 | "total_frames = last_frame-init_frame\n", 698 | "\n", 699 | "length = 15 #Desired time of the video in seconds\n", 700 | "\n", 701 | "frames = []\n", 702 | "tqdm.write('Generating video...')\n", 703 | "for i in range(init_frame,last_frame):\n", 704 | " frames.append(Image.open(\"./steps/\"+ str(i) +'.png'))\n", 705 | "\n", 706 | "#fps = last_frame/10\n", 707 | "fps = np.clip(total_frames/length,min_fps,max_fps)\n", 708 | "\n", 709 | "from subprocess import Popen, PIPE\n", 710 | "p = Popen(['ffmpeg', '-y', '-f', 'image2pipe', '-vcodec', 'png', '-r', str(fps), '-i', '-', '-vcodec', 'libx264', '-r', str(fps), '-pix_fmt', 'yuv420p', '-crf', '17', '-preset', 'veryslow', 'video.mp4'], stdin=PIPE)\n", 711 | "for im in tqdm(frames):\n", 712 | " im.save(p.stdin, 'PNG')\n", 713 | "p.stdin.close()\n", 714 | "\n", 715 | "print(\"The video is now being compressed, please wait...\")\n", 716 | "p.wait()\n", 717 | "\n", 718 | "print(\"The video is ready\")" 719 | ], 720 | "id": "c053240e-8a05-448d-9977-8308179b6d09", 721 | "execution_count": null, 722 | "outputs": [] 723 | }, 724 | { 725 | "cell_type": "code", 726 | "metadata": { 727 | "id": "jcSE6X9Knp3g", 728 | "cellView": "form" 729 | }, 730 | "source": [ 731 | "#@markdown **View video in browser**\n", 732 | "\n", 733 | "# @markdown *This may take a while. If you don't want to wait, download the video by executing the next cell instead of using this one.*\n", 734 | "mp4 = open('video.mp4','rb').read()\n", 735 | "data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n", 736 | "display.HTML(\"\"\"\n", 737 | "<video width=400 controls>\n", 738 | "  <source src=\"%s\" type=\"video/mp4\">\n", 739 | "</video>\n", 740 | "\"\"\" % data_url)" 741 | ], 742 | "id": "jcSE6X9Knp3g", 743 | "execution_count": null, 744 | "outputs": [] 745 | }, 746 | { 747 | "cell_type": "code", 748 | "metadata": { 749 | "id": "fZpYoLNWn10M", 750 | "cellView": "form" 751 | }, 752 | "source": [ 753 | "#@markdown #**Download the result video**\n", 754 | "from google.colab import files\n", 755 | "files.download(\"video.mp4\")" 756 | ], 757 | "id": "fZpYoLNWn10M", 758 | "execution_count": null, 759 | "outputs": [] 760 | }, 761 | { 762 | "cell_type": "markdown", 763 | "metadata": { 764 | "id": "DUmwey4oV4zH" 765 | }, 766 | "source": [ 767 | "JS to prevent idle timeout:\n", 768 | "\n", 769 | "Press F12, or CTRL + SHIFT + I, or right-click this page and choose Inspect.\n", 770 | "Then open the Console tab and paste in the following code.\n", 771 | "\n", 772 | "```javascript\n", 773 | "function ClickConnect(){\n", 774 | "console.log(\"Working\");\n", 775 | "document.querySelector(\"colab-toolbar-button#connect\").click()\n", 776 | "}\n", 777 | "setInterval(ClickConnect,60000)\n", 778 | "```" 779 | ], 780 | "id": "DUmwey4oV4zH" 781 | } 782 | ] 783 | } --------------------------------------------------------------------------------
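
The Parameters cell above splits the pipe-separated `texts` and `target_images` strings into lists of trimmed phrases before packing them into the `argparse.Namespace`. A minimal stand-alone sketch of that splitting behaviour (the helper name `split_prompts` is illustrative and not part of the notebook):

```python
# Sketch of the '|'-separated prompt handling in the Parameters cell.
# An empty field becomes an empty list, so the later prompt loops are simply skipped.

def split_prompts(raw):
    """Split a '|'-separated prompt string into trimmed phrases ('' -> [])."""
    phrases = [phrase.strip() for phrase in raw.split("|")]
    return [] if phrases == [""] else phrases

print(split_prompts("a fantasy castle | trending on artstation"))
# ['a fantasy castle', 'trending on artstation']
print(split_prompts(""))
# []
```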
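
The requested `width` and `height` are snapped down to multiples of the VQGAN downsampling factor `f = 2**(model.decoder.num_resolutions - 1)` before optimisation, so the rendered image can be slightly smaller than asked for. A quick sketch of that arithmetic, assuming `f = 16` as with the f16 checkpoints listed in the Parameters cell:

```python
# Sketch of the latent-grid sizing in the "Fire up the AI" cell.
# f = 16 is assumed here; the notebook derives it from model.decoder.num_resolutions.
F = 16

def effective_size(width, height, f=F):
    """Return (tokens_x, tokens_y, pixels_x, pixels_y) for a requested size."""
    toks_x, toks_y = width // f, height // f
    return toks_x, toks_y, toks_x * f, toks_y * f

print(effective_size(300, 300))  # (18, 18, 288, 288): a 300x300 request renders at 288x288
print(effective_size(512, 448))  # (32, 28, 512, 448): exact multiples of 16 are kept as-is
```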
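
When `init_weight` is non-zero (the `args` above set it to `0.`, so the term is inactive unless you change it), `ascend_txt` adds an MSE penalty on `z` whose coefficient decays as `init_weight / (2 * (2*i + 1))` over the iterations. A small sketch of that schedule using plain floats instead of tensors:

```python
# Sketch of the decaying coefficient on the MSE regularisation term in ascend_txt():
#   F.mse_loss(z, zeros) * (1 / (i*2 + 1)) * init_weight / 2
# It starts at init_weight / 2 at i = 0 and falls off roughly like 1 / i afterwards.

def mse_weight(i, init_weight):
    return init_weight / (2 * (2 * i + 1))

for i in (0, 1, 10, 100, 500):
    print(i, round(mse_weight(i, init_weight=1.0), 5))
# 0 0.5
# 1 0.16667
# 10 0.02381
# 100 0.00249
# 500 0.0005
```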
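
Finally, the video cell picks the frame rate as `np.clip(total_frames / length, min_fps, max_fps)`, so the output only lasts exactly `length` seconds when the unclamped rate falls between 10 and 60 fps; otherwise the duration stretches or shrinks. A worked sketch of that trade-off (the frame counts are illustrative):

```python
# Sketch of the fps / duration trade-off in the video-generation cell.
import numpy as np

def video_fps(total_frames, length=15, min_fps=10, max_fps=60):
    fps = float(np.clip(total_frames / length, min_fps, max_fps))
    return fps, total_frames / fps  # (frames per second, resulting duration in seconds)

print(video_fps(450))   # (30.0, 15.0)  -> exactly the requested 15 s
print(video_fps(1500))  # (60.0, 25.0)  -> capped at 60 fps, so the video runs 25 s
print(video_fps(90))    # (10.0, 9.0)   -> floored at 10 fps, so the video runs 9 s
```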