├── README.md ├── StyleGAN3+CLIP.ipynb ├── StyleGAN3+inversion+CLIP.ipynb ├── cog.yaml ├── download.sh └── predict.py /README.md: -------------------------------------------------------------------------------- 1 | # StyleGAN3 CLIP-based guidance 2 | 3 | ### StyleGAN3 + CLIP 4 | 5 | 6 | Open in Colab 9 | 10 | 11 | 12 | ### StyleGAN3 + inversion + CLIP 13 | 14 | 15 | Open in Colab 18 | 19 | 20 | --- 21 | 22 | This repo is a collection of Jupyter notebooks made to easily play with StyleGAN3[^1] and CLIP[^2] for a text-based guided image generation. 23 | 24 | Both notebooks are heavily based on [this notebook](https://colab.research.google.com/drive/1eYlenR1GHPZXt-YuvXabzO9wfh9CWY36#scrollTo=LQf7tzBQ8rn2), created by [nshepperd](https://twitter.com/nshepperd1) (thank you!). 25 | 26 | Special thanks too to [Katherine Crowson](https://twitter.com/RiversHaveWings) for coming up with many improved sampling tricks, as well as some of the code. 27 | 28 | [^1]: StyleGAN3 was created by NVIDIA. [Here](https://github.com/NVlabs/stylegan3) is the original repo. 29 | 30 | [^2]: CLIP (Contrastive Language-Image Pre-Training) is a multimodal model made by OpenAI. For more information head over [here](https://github.com/openai/CLIP). 31 | 32 | Feel free to suggest any changes! If anyone has any idea what license should this repo use, please let me know. 33 | -------------------------------------------------------------------------------- /StyleGAN3+CLIP.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "StyleGAN3+CLIP.ipynb", 7 | "private_outputs": true, 8 | "provenance": [], 9 | "collapsed_sections": [], 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | }, 16 | "language_info": { 17 | "name": "python" 18 | }, 19 | "accelerator": "GPU" 20 | }, 21 | "cells": [ 22 | { 23 | "cell_type": "markdown", 24 | "metadata": { 25 | "id": "view-in-github", 26 | "colab_type": "text" 27 | }, 28 | "source": [ 29 | "\"Open" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "metadata": { 35 | "id": "bJOj_BWi_JwR" 36 | }, 37 | "source": [ 38 | "# **StyleGAN3 + CLIP 🖼️**\n", 39 | "\n", 40 | "## Generate images from text prompts using NVIDIA's StyleGAN3 with CLIP guidance.\n", 41 | "\n", 42 | "Head over [here](https://github.com/ouhenio/StyleGAN3-CLIP-notebook) if you want to be up to date with the changes to this notebook and play with other alternatives.\n", 43 | "\n", 44 | "The original code was written by [nshepperd](https://twitter.com/nshepperd1)* (https://github.com/nshepperd), and later edited by [Eugenio Herrera](https://github.com/ouhenio).\n", 45 | "\n", 46 | "Thanks to [Katherine Crowson](https://twitter.com/RiversHaveWings) (https://github.com/crowsonkb) for coming up with many improved sampling tricks, as well as some of the code.\n", 47 | "\n", 48 | "----\n", 49 | "\n", 50 | "(*) nshepperd originally made [this notebook](https://colab.research.google.com/drive/1eYlenR1GHPZXt-YuvXabzO9wfh9CWY36#scrollTo=LQf7tzBQ8rn2).\n", 51 | "\n", 52 | "(**) The interface is inspired by [this notebook](https://colab.research.google.com/github/justinjohn0306/VQGAN-CLIP/blob/main/VQGAN%2BCLIP(Updated).ipynb), done by Jakeukalane and Avengium (Angel).\n", 53 | "\n", 54 | "(***) For more information about StyleGAN3, [visit the official repository](https://github.com/NVlabs/stylegan3)." 
55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "metadata": { 60 | "id": "rg_x6IdHv_2G", 61 | "cellView": "form" 62 | }, 63 | "source": [ 64 | "#@markdown #**Check GPU type** 🕵️\n", 65 | "#@markdown ### Factory reset runtime if you don't have the desired GPU.\n", 66 | "\n", 67 | "#@markdown ---\n", 68 | "\n", 69 | "\n", 70 | "\n", 71 | "\n", 72 | "#@markdown V100 = Excellent (*Available only for Colab Pro users*)\n", 73 | "\n", 74 | "#@markdown P100 = Very Good\n", 75 | "\n", 76 | "#@markdown T4 = Good (*preferred*)\n", 77 | "\n", 78 | "#@markdown K80 = Meh\n", 79 | "\n", 80 | "#@markdown P4 = (*Not Recommended*) \n", 81 | "\n", 82 | "#@markdown ---\n", 83 | "\n", 84 | "!nvidia-smi -L" 85 | ], 86 | "execution_count": null, 87 | "outputs": [] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "metadata": { 92 | "id": "5K38uyFrv5wo", 93 | "cellView": "form" 94 | }, 95 | "source": [ 96 | "#@markdown #**Install libraries** 🏗️\n", 97 | "# @markdown This cell will take a little while because it has to download several libraries.\n", 98 | "\n", 99 | "#@markdown ---\n", 100 | "\n", 101 | "!pip install --upgrade torch==1.9.1+cu111 torchvision==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html\n", 102 | "#!pip install --upgrade https://download.pytorch.org/whl/nightly/cu111/torch-1.11.0.dev20211012%2Bcu111-cp37-cp37m-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu111/torchvision-0.12.0.dev20211012%2Bcu111-cp37-cp37m-linux_x86_64.whl\n", 103 | "!git clone https://github.com/NVlabs/stylegan3\n", 104 | "!git clone https://github.com/openai/CLIP\n", 105 | "!pip install -e ./CLIP\n", 106 | "!pip install einops ninja\n", 107 | "\n", 108 | "import sys\n", 109 | "sys.path.append('./CLIP')\n", 110 | "sys.path.append('./stylegan3')\n", 111 | "\n", 112 | "import io\n", 113 | "import os, time, glob\n", 114 | "import pickle\n", 115 | "import shutil\n", 116 | "import numpy as np\n", 117 | "from PIL import Image\n", 118 | "import torch\n", 119 | "import torch.nn.functional as F\n", 120 | "import requests\n", 121 | "import torchvision.transforms as transforms\n", 122 | "import torchvision.transforms.functional as TF\n", 123 | "import clip\n", 124 | "import unicodedata\n", 125 | "import re\n", 126 | "from tqdm.notebook import tqdm\n", 127 | "from torchvision.transforms import Compose, Resize, ToTensor, Normalize\n", 128 | "from IPython.display import display\n", 129 | "from einops import rearrange\n", 130 | "from google.colab import files\n", 131 | "\n", 132 | "device = torch.device('cuda:0')\n", 133 | "print('Using device:', device, file=sys.stderr)" 134 | ], 135 | "execution_count": null, 136 | "outputs": [] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "metadata": { 141 | "id": "zVkNODOot_To", 142 | "cellView": "form" 143 | }, 144 | "source": [ 145 | "#@markdown #**Optional:** Save images in Google Drive 💾\n", 146 | "# @markdown Run this cell if you want to store the results inside Google Drive.\n", 147 | "\n", 148 | "# @markdown Copying the generated images to drive is faster to work with.\n", 149 | "\n", 150 | "# @markdown **Important**: you must have a folder named *samples* inside your drive, otherwise this may not work.\n", 151 | "\n", 152 | "#@markdown ---\n", 153 | "\n", 154 | "# Uncomment to copy generated images to drive, faster than downloading directly from colab in my experience.\n", 155 | "from google.colab import drive\n", 156 | "drive.mount('/content/drive')" 157 | ], 158 | "execution_count": null, 159 | "outputs": [] 160 | }, 161 | { 162 | "cell_type": "code", 163 
| "metadata": { 164 | "id": "1GOWJ_z-wgde", 165 | "cellView": "form" 166 | }, 167 | "source": [ 168 | "#@markdown #**Define necessary functions** 🛠️\n", 169 | "\n", 170 | "def fetch(url_or_path):\n", 171 | " if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):\n", 172 | " r = requests.get(url_or_path)\n", 173 | " r.raise_for_status()\n", 174 | " fd = io.BytesIO()\n", 175 | " fd.write(r.content)\n", 176 | " fd.seek(0)\n", 177 | " return fd\n", 178 | " return open(url_or_path, 'rb')\n", 179 | "\n", 180 | "def fetch_model(url_or_path):\n", 181 | " if \"drive.google\" in url_or_path:\n", 182 | " if \"18MOpwTMJsl_Z17q-wQVnaRLCUFZYSNkj\" in url_or_path: \n", 183 | " basename = \"wikiart-1024-stylegan3-t-17.2Mimg.pkl\"\n", 184 | " elif \"14UGDDOusZ9TMb-pOrF0PAjMGVWLSAii1\" in url_or_path:\n", 185 | " basename = \"lhq-256-stylegan3-t-25Mimg.pkl\"\n", 186 | " else:\n", 187 | " basename = os.path.basename(url_or_path)\n", 188 | " if os.path.exists(basename):\n", 189 | " return basename\n", 190 | " else:\n", 191 | " if \"drive.google\" not in url_or_path:\n", 192 | " !wget -c '{url_or_path}'\n", 193 | " else:\n", 194 | " path_id = url_or_path.split(\"id=\")[-1]\n", 195 | " !gdown --id '{path_id}'\n", 196 | " return basename\n", 197 | "\n", 198 | "def slugify(value, allow_unicode=False):\n", 199 | " \"\"\"\n", 200 | " Taken from https://github.com/django/django/blob/master/django/utils/text.py\n", 201 | " Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated\n", 202 | " dashes to single dashes. Remove characters that aren't alphanumerics,\n", 203 | " underscores, or hyphens. Convert to lowercase. Also strip leading and\n", 204 | " trailing whitespace, dashes, and underscores.\n", 205 | " \"\"\"\n", 206 | " value = str(value)\n", 207 | " if allow_unicode:\n", 208 | " value = unicodedata.normalize('NFKC', value)\n", 209 | " else:\n", 210 | " value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii')\n", 211 | " value = re.sub(r'[^\\w\\s-]', '', value.lower())\n", 212 | " return re.sub(r'[-\\s]+', '-', value).strip('-_')\n", 213 | "\n", 214 | "def norm1(prompt):\n", 215 | " \"Normalize to the unit sphere.\"\n", 216 | " return prompt / prompt.square().sum(dim=-1,keepdim=True).sqrt()\n", 217 | "\n", 218 | "def spherical_dist_loss(x, y):\n", 219 | " x = F.normalize(x, dim=-1)\n", 220 | " y = F.normalize(y, dim=-1)\n", 221 | " return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)\n", 222 | "\n", 223 | "def prompts_dist_loss(x, targets, loss):\n", 224 | " if len(targets) == 1: # Keeps consitent results vs previous method for single objective guidance \n", 225 | " return loss(x, targets[0])\n", 226 | " distances = [loss(x, target) for target in targets]\n", 227 | " return torch.stack(distances, dim=-1).sum(dim=-1) \n", 228 | "\n", 229 | "class MakeCutouts(torch.nn.Module):\n", 230 | " def __init__(self, cut_size, cutn, cut_pow=1.):\n", 231 | " super().__init__()\n", 232 | " self.cut_size = cut_size\n", 233 | " self.cutn = cutn\n", 234 | " self.cut_pow = cut_pow\n", 235 | "\n", 236 | " def forward(self, input):\n", 237 | " sideY, sideX = input.shape[2:4]\n", 238 | " max_size = min(sideX, sideY)\n", 239 | " min_size = min(sideX, sideY, self.cut_size)\n", 240 | " cutouts = []\n", 241 | " for _ in range(self.cutn):\n", 242 | " size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)\n", 243 | " offsetx = torch.randint(0, sideX - size + 1, ())\n", 244 | " offsety = torch.randint(0, sideY - size + 1, ())\n", 245 | 
" cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]\n", 246 | " cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))\n", 247 | " return torch.cat(cutouts)\n", 248 | "\n", 249 | "make_cutouts = MakeCutouts(224, 32, 0.5)\n", 250 | "\n", 251 | "def embed_image(image):\n", 252 | " n = image.shape[0]\n", 253 | " cutouts = make_cutouts(image)\n", 254 | " embeds = clip_model.embed_cutout(cutouts)\n", 255 | " embeds = rearrange(embeds, '(cc n) c -> cc n c', n=n)\n", 256 | " return embeds\n", 257 | "\n", 258 | "def embed_url(url):\n", 259 | " image = Image.open(fetch(url)).convert('RGB')\n", 260 | " return embed_image(TF.to_tensor(image).to(device).unsqueeze(0)).mean(0).squeeze(0)\n", 261 | "\n", 262 | "class CLIP(object):\n", 263 | " def __init__(self):\n", 264 | " clip_model = \"ViT-B/32\"\n", 265 | " self.model, _ = clip.load(clip_model)\n", 266 | " self.model = self.model.requires_grad_(False)\n", 267 | " self.normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],\n", 268 | " std=[0.26862954, 0.26130258, 0.27577711])\n", 269 | "\n", 270 | " @torch.no_grad()\n", 271 | " def embed_text(self, prompt):\n", 272 | " \"Normalized clip text embedding.\"\n", 273 | " return norm1(self.model.encode_text(clip.tokenize(prompt).to(device)).float())\n", 274 | "\n", 275 | " def embed_cutout(self, image):\n", 276 | " \"Normalized clip image embedding.\"\n", 277 | " return norm1(self.model.encode_image(self.normalize(image)))\n", 278 | " \n", 279 | "clip_model = CLIP()" 280 | ], 281 | "execution_count": null, 282 | "outputs": [] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "metadata": { 287 | "id": "4vpIq2vvzTtS", 288 | "cellView": "form" 289 | }, 290 | "source": [ 291 | "#@markdown #**Model selection** 🎭\n", 292 | "\n", 293 | "\n", 294 | "#@markdown There are 4 pre-trained options to play with:\n", 295 | "#@markdown - FFHQ: Trained with human faces.\n", 296 | "#@markdown - MetFaces: Trained with paintings/portraits of human faces.\n", 297 | "#@markdown - AFHQv2: Trained with animal faces.\n", 298 | "#@markdown - Cosplay: Trained by [l4rz](https://twitter.com/l4rz) with cosplayer's faces.\n", 299 | "#@markdown - Wikiart: Trained by [Justin Pinkney](https://www.justinpinkney.com/) with the Wikiart 1024 dataset.\n", 300 | "#@markdown - Landscapes: Trained by [Justin Pinkney](https://www.justinpinkney.com/) with the LHQ dataset.\n", 301 | "\n", 302 | "\n", 303 | "#@markdown **Run this cell again if you change the model**.\n", 304 | "\n", 305 | "#@markdown ---\n", 306 | "\n", 307 | "base_url = \"https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/\"\n", 308 | "\n", 309 | "Model = 'FFHQ' #@param [\"FFHQ\", \"MetFaces\", \"AFHQv2\", \"cosplay\", \"Wikiart\", \"Landscapes\"]\n", 310 | "\n", 311 | "#@markdown ---\n", 312 | "\n", 313 | "model_name = {\n", 314 | " \"FFHQ\": base_url + \"stylegan3-t-ffhqu-1024x1024.pkl\",\n", 315 | " \"MetFaces\": base_url + \"stylegan3-r-metfacesu-1024x1024.pkl\",\n", 316 | " \"AFHQv2\": base_url + \"stylegan3-t-afhqv2-512x512.pkl\",\n", 317 | " \"cosplay\": \"https://l4rz.net/cosplayface-snapshot-stylegan3t-008000.pkl\",\n", 318 | " \"Wikiart\": \"https://archive.org/download/wikiart-1024-stylegan3-t-17.2Mimg/wikiart-1024-stylegan3-t-17.2Mimg.pkl\",\n", 319 | " \"Landscapes\": \"https://archive.org/download/lhq-256-stylegan3-t-25Mimg/lhq-256-stylegan3-t-25Mimg.pkl\"\n", 320 | "}\n", 321 | "\n", 322 | "network_url = model_name[Model]\n", 323 | "\n", 324 | "with open(fetch_model(network_url), 'rb') as fp:\n", 
325 | " G = pickle.load(fp)['G_ema'].to(device)\n", 326 | "\n", 327 | "zs = torch.randn([10000, G.mapping.z_dim], device=device)\n", 328 | "w_stds = G.mapping(zs, None).std(0)" 329 | ], 330 | "execution_count": null, 331 | "outputs": [] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "metadata": { 336 | "id": "V_rq-N2m0Tlb", 337 | "cellView": "form" 338 | }, 339 | "source": [ 340 | "#@markdown #**Parameters** ✍️\n", 341 | "#@markdown `texts`: Enter here a prompt to guide the image generation. You can enter more than one prompt separated with\n", 342 | "#@markdown `|`, which will cause the guidance to focus on the different prompts at the same time, allowing to mix and play\n", 343 | "#@markdown with the generation process.\n", 344 | "\n", 345 | "#@markdown `steps`: Number of optimization steps. The more steps, the longer it will try to generate an image relevant to the prompt.\n", 346 | "\n", 347 | "#@markdown `seed`: Determines the randomness seed. Using the same seed and prompt should give you similar results at every run.\n", 348 | "#@markdown Use `-1` for a random seed.\n", 349 | "\n", 350 | "#@markdown ---\n", 351 | "\n", 352 | "texts = \"A portrait of Nicanor Parra \"#@param {type:\"string\"}\n", 353 | "steps = 300#@param {type:\"number\"}\n", 354 | "seed = 14#@param {type:\"number\"}\n", 355 | "\n", 356 | "#@markdown ---\n", 357 | "\n", 358 | "if seed == -1:\n", 359 | " seed = np.random.randint(0,9e9)\n", 360 | " print(f\"Your random seed is: {seed}\")\n", 361 | "\n", 362 | "texts = [frase.strip() for frase in texts.split(\"|\") if frase]\n", 363 | "\n", 364 | "targets = [clip_model.embed_text(text) for text in texts]" 365 | ], 366 | "execution_count": null, 367 | "outputs": [] 368 | }, 369 | { 370 | "cell_type": "code", 371 | "metadata": { 372 | "id": "QXoRP4SHzJ6i", 373 | "cellView": "form" 374 | }, 375 | "source": [ 376 | "#@markdown #**Run the model** 🚀\n", 377 | "\n", 378 | "# Actually do the run\n", 379 | "\n", 380 | "tf = Compose([\n", 381 | " Resize(224),\n", 382 | " lambda x: torch.clamp((x+1)/2,min=0,max=1),\n", 383 | " ])\n", 384 | "\n", 385 | "def run(timestring):\n", 386 | " torch.manual_seed(seed)\n", 387 | "\n", 388 | " # Init\n", 389 | " # Sample 32 inits and choose the one closest to prompt\n", 390 | "\n", 391 | " with torch.no_grad():\n", 392 | " qs = []\n", 393 | " losses = []\n", 394 | " for _ in range(8):\n", 395 | " q = (G.mapping(torch.randn([4,G.mapping.z_dim], device=device), None, truncation_psi=0.7) - G.mapping.w_avg) / w_stds\n", 396 | " images = G.synthesis(q * w_stds + G.mapping.w_avg)\n", 397 | " embeds = embed_image(images.add(1).div(2))\n", 398 | " loss = prompts_dist_loss(embeds, targets, spherical_dist_loss).mean(0)\n", 399 | " i = torch.argmin(loss)\n", 400 | " qs.append(q[i])\n", 401 | " losses.append(loss[i])\n", 402 | " qs = torch.stack(qs)\n", 403 | " losses = torch.stack(losses)\n", 404 | " # print(losses)\n", 405 | " # print(losses.shape, qs.shape)\n", 406 | " i = torch.argmin(losses)\n", 407 | " q = qs[i].unsqueeze(0).requires_grad_()\n", 408 | "\n", 409 | " # Sampling loop\n", 410 | " q_ema = q\n", 411 | " opt = torch.optim.AdamW([q], lr=0.03, betas=(0.0,0.999))\n", 412 | " loop = tqdm(range(steps))\n", 413 | " for i in loop:\n", 414 | " opt.zero_grad()\n", 415 | " w = q * w_stds\n", 416 | " image = G.synthesis(w + G.mapping.w_avg, noise_mode='const')\n", 417 | " embed = embed_image(image.add(1).div(2))\n", 418 | " loss = prompts_dist_loss(embed, targets, spherical_dist_loss).mean()\n", 419 | " loss.backward()\n", 420 | " opt.step()\n", 
421 | " loop.set_postfix(loss=loss.item(), q_magnitude=q.std().item())\n", 422 | "\n", 423 | " q_ema = q_ema * 0.9 + q * 0.1\n", 424 | " image = G.synthesis(q_ema * w_stds + G.mapping.w_avg, noise_mode='const')\n", 425 | "\n", 426 | " if i % 10 == 0:\n", 427 | " display(TF.to_pil_image(tf(image)[0]))\n", 428 | " print(f\"Image {i}/{steps} | Current loss: {loss}\")\n", 429 | " pil_image = TF.to_pil_image(image[0].add(1).div(2).clamp(0,1))\n", 430 | " os.makedirs(f'samples/{timestring}', exist_ok=True)\n", 431 | " pil_image.save(f'samples/{timestring}/{i:04}.jpg')\n", 432 | "\n", 433 | "try:\n", 434 | " timestring = time.strftime('%Y%m%d%H%M%S')\n", 435 | " run(timestring)\n", 436 | "except KeyboardInterrupt:\n", 437 | " pass" 438 | ], 439 | "execution_count": null, 440 | "outputs": [] 441 | }, 442 | { 443 | "cell_type": "code", 444 | "metadata": { 445 | "id": "JSnRMY8-_-iV", 446 | "cellView": "form" 447 | }, 448 | "source": [ 449 | "#@markdown #**Save images** 📷\n", 450 | "#@markdown A `.tar` file will be saved inside *samples* and automatically downloaded, unless you previously ran the Google Drive cell,\n", 451 | "#@markdown in which case it'll be saved inside your previously created drive *samples* folder.\n", 452 | "\n", 453 | "archive_name = \"optional\"#@param {type:\"string\"}\n", 454 | "\n", 455 | "archive_name = slugify(archive_name)\n", 456 | "\n", 457 | "if archive_name != \"optional\":\n", 458 | " fname = archive_name\n", 459 | " # os.rename(f'samples/{timestring}', f'samples/{fname}')\n", 460 | "else:\n", 461 | " fname = timestring\n", 462 | "# Save images as a tar archive\n", 463 | "!tar cf samples/{fname}.tar samples/{timestring}\n", 464 | "if os.path.isdir('drive/MyDrive/samples'):\n", 465 | " shutil.copyfile(f'samples/{fname}.tar', f'drive/MyDrive/samples/{fname}.tar')\n", 466 | "else:\n", 467 | " files.download(f'samples/{fname}.tar')" 468 | ], 469 | "execution_count": null, 470 | "outputs": [] 471 | }, 472 | { 473 | "cell_type": "code", 474 | "metadata": { 475 | "id": "Z9Yyt8y99jfv", 476 | "cellView": "form" 477 | }, 478 | "source": [ 479 | "#@markdown #**Generate video** 🎥\n", 480 | "\n", 481 | "#@markdown You can edit frame rate and stuff by double-clicking this tab.\n", 482 | "\n", 483 | "frames = os.listdir(f\"samples/{timestring}\")\n", 484 | "frames = len(list(filter(lambda filename: filename.endswith(\".jpg\"), frames))) #Get number of jpg generated\n", 485 | "\n", 486 | "init_frame = 1 #This is the frame where the video will start\n", 487 | "last_frame = frames #You can change i to the number of the last frame you want to generate. 
It will raise an error if that number of frames does not exist.\n", 488 | "\n", 489 | "min_fps = 10\n", 490 | "max_fps = 60\n", 491 | "\n", 492 | "total_frames = last_frame-init_frame\n", 493 | "\n", 494 | "#Desired video time in seconds\n", 495 | "video_length = 14 #@param {type:\"number\"}\n", 496 | "#Video filename\n", 497 | "video_name = \"\" #@param {type:\"string\"}\n", 498 | "video_name = slugify(video_name)\n", 499 | "\n", 500 | "if not video_name:\n", 501 | " video_name = \"video\"\n", 502 | "# frames = []\n", 503 | "# tqdm.write('Generating video...')\n", 504 | "# for i in range(init_frame,last_frame): #\n", 505 | "# filename = f\"samples/{timestring}/{i:04}.jpg\"\n", 506 | "# frames.append(Image.open(filename))\n", 507 | "\n", 508 | "fps = np.clip(total_frames/video_length,min_fps,max_fps)\n", 509 | "\n", 510 | "print(\"Generating video...\")\n", 511 | "!ffmpeg -r {fps} -i samples/{timestring}/%04d.jpg -c:v libx264 -vf fps={fps} -pix_fmt yuv420p samples/{video_name}.mp4 -frames:v {total_frames}\n", 512 | "\n", 513 | "# from subprocess import Popen, PIPE\n", 514 | "# p = Popen(['ffmpeg', '-y', '-f', 'image2pipe', '-vcodec', 'png', '-r', str(fps), '-i', '-', '-vcodec', 'libx264', '-r', str(fps), '-pix_fmt', 'yuv420p', '-crf', '17', '-preset', 'veryslow', f'samples/{video_name}.mp4'], stdin=PIPE)\n", 515 | "# for im in tqdm(frames):\n", 516 | "# im.save(p.stdin, 'PNG')\n", 517 | "# p.stdin.close()\n", 518 | "\n", 519 | "print(\"The video is ready\")" 520 | ], 521 | "execution_count": null, 522 | "outputs": [] 523 | }, 524 | { 525 | "cell_type": "code", 526 | "metadata": { 527 | "id": "uNpjDjR_-0dN", 528 | "cellView": "form" 529 | }, 530 | "source": [ 531 | "#@markdown #**Download video** 📀\n", 532 | "#@markdown If you're activated the download to GDrive option, the video will be save there. 
Don't worry about overwritting issues for colliding filenames, an id will be added to them to avoid this.\n", 533 | "\n", 534 | "#Video filename\n", 535 | "to_download_video_name = \"\" #@param {type:\"string\"}\n", 536 | "to_download_video_name = slugify(to_download_video_name)\n", 537 | "\n", 538 | "if not to_download_video_name:\n", 539 | " to_download_video_name = \"video\"\n", 540 | "\n", 541 | "\n", 542 | "from google.colab import files\n", 543 | "if os.path.isdir('drive/MyDrive/samples'):\n", 544 | " filelist = glob.glob(f'drive/MyDrive/samples/{to_download_video_name}*.mp4')\n", 545 | " video_count = len(filelist)\n", 546 | " if video_count:\n", 547 | " final_video_name = f\"{to_download_video_name}{video_count}\"\n", 548 | " else:\n", 549 | " final_video_name = to_download_video_name\n", 550 | " shutil.copyfile(f'samples/{video_name}.mp4', f'drive/MyDrive/samples/{final_video_name}.mp4')\n", 551 | "else:\n", 552 | " files.download(f\"samples/{to_download_video_name}.mp4\")" 553 | ], 554 | "execution_count": null, 555 | "outputs": [] 556 | }, 557 | { 558 | "cell_type": "code", 559 | "metadata": { 560 | "id": "so1yHofG7RxX" 561 | }, 562 | "source": [ 563 | "#@title Licensed under the MIT License { display-mode: \"form\" }\n", 564 | "\n", 565 | "# Copyright (c) 2021 nshepperd; Katherine Crowson\n", 566 | "\n", 567 | "# Permission is hereby granted, free of charge, to any person obtaining a copy\n", 568 | "# of this software and associated documentation files (the \"Software\"), to deal\n", 569 | "# in the Software without restriction, including without limitation the rights\n", 570 | "# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n", 571 | "# copies of the Software, and to permit persons to whom the Software is\n", 572 | "# furnished to do so, subject to the following conditions:\n", 573 | "\n", 574 | "# The above copyright notice and this permission notice shall be included in\n", 575 | "# all copies or substantial portions of the Software.\n", 576 | "\n", 577 | "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", 578 | "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", 579 | "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n", 580 | "# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", 581 | "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n", 582 | "# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n", 583 | "# THE SOFTWARE." 
584 | ], 585 | "execution_count": null, 586 | "outputs": [] 587 | } 588 | ] 589 | } -------------------------------------------------------------------------------- /StyleGAN3+inversion+CLIP.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "StyleGAN3+inversion+CLIP.ipynb", 7 | "private_outputs": true, 8 | "provenance": [], 9 | "collapsed_sections": [], 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | }, 16 | "language_info": { 17 | "name": "python" 18 | }, 19 | "accelerator": "GPU" 20 | }, 21 | "cells": [ 22 | { 23 | "cell_type": "markdown", 24 | "metadata": { 25 | "id": "view-in-github", 26 | "colab_type": "text" 27 | }, 28 | "source": [ 29 | "\"Open" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "metadata": { 35 | "id": "bJOj_BWi_JwR" 36 | }, 37 | "source": [ 38 | "# **StyleGAN3 + inversion + CLIP 🖼️**\n", 39 | "\n", 40 | "## Project images to the latent space and edit them with text prompts using StyleGAN3 and CLIP guidance.\n", 41 | "\n", 42 | "Head over [here](https://github.com/ouhenio/StyleGAN3-CLIP-notebook) if you want to be up to date with the changes to this notebook and play with other alternatives.\n", 43 | "\n", 44 | "The original code was written by [nshepperd](https://twitter.com/nshepperd1)* (https://github.com/nshepperd), and later edited by [Eugenio Herrera](https://github.com/ouhenio).\n", 45 | "\n", 46 | "Thanks to [Katherine Crowson](https://twitter.com/RiversHaveWings) (https://github.com/crowsonkb) for coming up with many improved sampling tricks, as well as some of the code.\n", 47 | "\n", 48 | "----\n", 49 | "\n", 50 | "(*) nshepperd originally made [this notebook](https://colab.research.google.com/drive/1eYlenR1GHPZXt-YuvXabzO9wfh9CWY36#scrollTo=LQf7tzBQ8rn2).\n", 51 | "\n", 52 | "(**) The interface is inspired by [this notebook](https://colab.research.google.com/github/justinjohn0306/VQGAN-CLIP/blob/main/VQGAN%2BCLIP(Updated).ipynb), done by Jakeukalane and Avengium (Angel).\n", 53 | "\n", 54 | "(***) For more information about StyleGAN3, [visit the official repository](https://github.com/NVlabs/stylegan3)." 
55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "metadata": { 60 | "id": "rg_x6IdHv_2G", 61 | "cellView": "form" 62 | }, 63 | "source": [ 64 | "#@markdown #**Check GPU type** 🕵️\n", 65 | "#@markdown ### Factory reset runtime if you don't have the desired GPU.\n", 66 | "\n", 67 | "#@markdown ---\n", 68 | "\n", 69 | "\n", 70 | "\n", 71 | "\n", 72 | "#@markdown V100 = Excellent (*Available only for Colab Pro users*)\n", 73 | "\n", 74 | "#@markdown P100 = Very Good\n", 75 | "\n", 76 | "#@markdown T4 = Good (*preferred*)\n", 77 | "\n", 78 | "#@markdown K80 = Meh\n", 79 | "\n", 80 | "#@markdown P4 = (*Not Recommended*) \n", 81 | "\n", 82 | "#@markdown ---\n", 83 | "\n", 84 | "!nvidia-smi -L" 85 | ], 86 | "execution_count": null, 87 | "outputs": [] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "metadata": { 92 | "id": "5K38uyFrv5wo", 93 | "cellView": "form" 94 | }, 95 | "source": [ 96 | "#@markdown #**Install libraries** 🏗️\n", 97 | "# @markdown This cell will take a little while because it has to download several libraries.\n", 98 | "\n", 99 | "#@markdown ---\n", 100 | "\n", 101 | "!pip install --upgrade torch==1.9.1+cu111 torchvision==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html\n", 102 | "#!pip install --upgrade https://download.pytorch.org/whl/nightly/cu111/torch-1.11.0.dev20211012%2Bcu111-cp37-cp37m-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu111/torchvision-0.12.0.dev20211012%2Bcu111-cp37-cp37m-linux_x86_64.whl\n", 103 | "!git clone https://github.com/NVlabs/stylegan3\n", 104 | "!git clone https://github.com/openai/CLIP\n", 105 | "!pip install -e ./CLIP\n", 106 | "!pip install einops ninja\n", 107 | "\n", 108 | "import sys\n", 109 | "sys.path.append('./CLIP')\n", 110 | "sys.path.append('./stylegan3')\n", 111 | "\n", 112 | "import io\n", 113 | "import os, time\n", 114 | "import pickle\n", 115 | "import shutil\n", 116 | "import numpy as np\n", 117 | "import torch\n", 118 | "import torch.nn.functional as F\n", 119 | "import requests\n", 120 | "import torchvision.transforms as transforms\n", 121 | "import torchvision.transforms.functional as TF\n", 122 | "import clip\n", 123 | "import copy\n", 124 | "import imageio\n", 125 | "import unicodedata\n", 126 | "import re\n", 127 | "from PIL import Image\n", 128 | "from tqdm.notebook import tqdm\n", 129 | "from torchvision.transforms import Compose, Resize, ToTensor, Normalize\n", 130 | "from IPython.display import display\n", 131 | "from einops import rearrange\n", 132 | "from google.colab import files\n", 133 | "from time import perf_counter\n", 134 | "from stylegan3.dnnlib.util import open_url\n", 135 | "\n", 136 | "\n", 137 | "device = torch.device('cuda:0')\n", 138 | "\n", 139 | "# Load VGG16 feature detector.\n", 140 | "url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'\n", 141 | "with open_url(url) as f:\n", 142 | " vgg16 = torch.jit.load(f).eval().to(device)\n", 143 | "print('Using device:', device, file=sys.stderr)" 144 | ], 145 | "execution_count": null, 146 | "outputs": [] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "metadata": { 151 | "id": "zVkNODOot_To", 152 | "cellView": "form" 153 | }, 154 | "source": [ 155 | "#@markdown #**Optional:** Save images in Google Drive 💾\n", 156 | "# @markdown Run this cell if you want to store the results inside Google Drive.\n", 157 | "\n", 158 | "# @markdown Copying the generated images to drive is faster to work with.\n", 159 | "\n", 160 | "# @markdown **Important**: you must have a folder named *samples* inside your 
drive, otherwise this may not work.\n", 161 | "\n", 162 | "#@markdown ---\n", 163 | "\n", 164 | "# Uncomment to copy generated images to drive, faster than downloading directly from colab in my experience.\n", 165 | "from google.colab import drive\n", 166 | "drive.mount('/content/drive')" 167 | ], 168 | "execution_count": null, 169 | "outputs": [] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "metadata": { 174 | "id": "1GOWJ_z-wgde", 175 | "cellView": "form" 176 | }, 177 | "source": [ 178 | "#@markdown #**Define necessary functions** 🛠️\n", 179 | "\n", 180 | "def fetch(url_or_path):\n", 181 | " if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):\n", 182 | " r = requests.get(url_or_path)\n", 183 | " r.raise_for_status()\n", 184 | " fd = io.BytesIO()\n", 185 | " fd.write(r.content)\n", 186 | " fd.seek(0)\n", 187 | " return fd\n", 188 | " return open(url_or_path, 'rb')\n", 189 | "\n", 190 | "def fetch_model(url_or_path):\n", 191 | " if \"drive.google\" in url_or_path:\n", 192 | " if \"18MOpwTMJsl_Z17q-wQVnaRLCUFZYSNkj\" in url_or_path: \n", 193 | " basename = \"wikiart-1024-stylegan3-t-17.2Mimg.pkl\"\n", 194 | " elif \"14UGDDOusZ9TMb-pOrF0PAjMGVWLSAii1\" in url_or_path:\n", 195 | " basename = \"lhq-256-stylegan3-t-25Mimg.pkl\"\n", 196 | " else:\n", 197 | " basename = os.path.basename(url_or_path)\n", 198 | " if os.path.exists(basename):\n", 199 | " return basename\n", 200 | " else:\n", 201 | " if \"drive.google\" not in url_or_path:\n", 202 | " !wget -c '{url_or_path}'\n", 203 | " else:\n", 204 | " path_id = url_or_path.split(\"id=\")[-1]\n", 205 | " !gdown --id '{path_id}'\n", 206 | " return basename\n", 207 | "\n", 208 | "def slugify(value, allow_unicode=False):\n", 209 | " \"\"\"\n", 210 | " Taken from https://github.com/django/django/blob/master/django/utils/text.py\n", 211 | " Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated\n", 212 | " dashes to single dashes. Remove characters that aren't alphanumerics,\n", 213 | " underscores, or hyphens. Convert to lowercase. 
Also strip leading and\n", 214 | " trailing whitespace, dashes, and underscores.\n", 215 | " \"\"\"\n", 216 | " value = str(value)\n", 217 | " if allow_unicode:\n", 218 | " value = unicodedata.normalize('NFKC', value)\n", 219 | " else:\n", 220 | " value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii')\n", 221 | " value = re.sub(r'[^\\w\\s-]', '', value.lower())\n", 222 | " return re.sub(r'[-\\s]+', '-', value).strip('-_')\n", 223 | "\n", 224 | "def norm1(prompt):\n", 225 | " \"Normalize to the unit sphere.\"\n", 226 | " return prompt / prompt.square().sum(dim=-1,keepdim=True).sqrt()\n", 227 | "\n", 228 | "def spherical_dist_loss(x, y):\n", 229 | " x = F.normalize(x, dim=-1)\n", 230 | " y = F.normalize(y, dim=-1)\n", 231 | " return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)\n", 232 | "\n", 233 | "class MakeCutouts(torch.nn.Module):\n", 234 | " def __init__(self, cut_size, cutn, cut_pow=1.):\n", 235 | " super().__init__()\n", 236 | " self.cut_size = cut_size\n", 237 | " self.cutn = cutn\n", 238 | " self.cut_pow = cut_pow\n", 239 | "\n", 240 | " def forward(self, input):\n", 241 | " sideY, sideX = input.shape[2:4]\n", 242 | " max_size = min(sideX, sideY)\n", 243 | " min_size = min(sideX, sideY, self.cut_size)\n", 244 | " cutouts = []\n", 245 | " for _ in range(self.cutn):\n", 246 | " size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)\n", 247 | " offsetx = torch.randint(0, sideX - size + 1, ())\n", 248 | " offsety = torch.randint(0, sideY - size + 1, ())\n", 249 | " cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]\n", 250 | " cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))\n", 251 | " return torch.cat(cutouts)\n", 252 | "\n", 253 | "make_cutouts = MakeCutouts(224, 32, 0.5)\n", 254 | "\n", 255 | "def embed_image(image):\n", 256 | " n = image.shape[0]\n", 257 | " cutouts = make_cutouts(image)\n", 258 | " embeds = clip_model.embed_cutout(cutouts)\n", 259 | " embeds = rearrange(embeds, '(cc n) c -> cc n c', n=n)\n", 260 | " return embeds\n", 261 | "\n", 262 | "def embed_url(url):\n", 263 | " image = Image.open(fetch(url)).convert('RGB')\n", 264 | " return embed_image(TF.to_tensor(image).to(device).unsqueeze(0)).mean(0).squeeze(0)\n", 265 | "\n", 266 | "class CLIP(object):\n", 267 | " def __init__(self):\n", 268 | " clip_model = \"ViT-B/32\"\n", 269 | " self.model, _ = clip.load(clip_model)\n", 270 | " self.model = self.model.requires_grad_(False)\n", 271 | " self.normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],\n", 272 | " std=[0.26862954, 0.26130258, 0.27577711])\n", 273 | "\n", 274 | " @torch.no_grad()\n", 275 | " def embed_text(self, prompt):\n", 276 | " \"Normalized clip text embedding.\"\n", 277 | " return norm1(self.model.encode_text(clip.tokenize(prompt).to(device)).float())\n", 278 | "\n", 279 | " def embed_cutout(self, image):\n", 280 | " \"Normalized clip image embedding.\"\n", 281 | " return norm1(self.model.encode_image(self.normalize(image)))\n", 282 | " \n", 283 | "clip_model = CLIP()\n", 284 | "\n", 285 | "# Projector\n", 286 | "\n", 287 | "def project(\n", 288 | " G,\n", 289 | " target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution\n", 290 | " *,\n", 291 | " num_steps = 1000,\n", 292 | " w_avg_samples = -1,\n", 293 | " initial_learning_rate = 0.1,\n", 294 | " initial_noise_factor = 0.05,\n", 295 | " lr_rampdown_length = 0.25,\n", 296 | " lr_rampup_length = 0.05,\n", 297 | " noise_ramp_length = 0.75,\n", 298 | " 
regularize_noise_weight = 1e5,\n", 299 | " verbose = False,\n", 300 | " device: torch.device\n", 301 | "):\n", 302 | "\n", 303 | " assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)\n", 304 | "\n", 305 | " def logprint(*args):\n", 306 | " if verbose:\n", 307 | " print(*args)\n", 308 | "\n", 309 | " G = copy.deepcopy(G).eval().requires_grad_(False).to(device) # type: ignore\n", 310 | "\n", 311 | " # Compute w stats.\n", 312 | " if w_avg_samples > 0:\n", 313 | " logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...')\n", 314 | " z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)\n", 315 | " else:\n", 316 | " seed = np.random.randint(0, 2**32 - 1)\n", 317 | " z_samples = np.random.RandomState(seed).randn(1, G.z_dim)\n", 318 | " w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None) # [N, L, C]\n", 319 | " w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32) # [N, 1, C]\n", 320 | " w_avg = np.mean(w_samples, axis=0, keepdims=True) # [1, 1, C]\n", 321 | " w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5\n", 322 | "\n", 323 | " # Setup noise inputs.\n", 324 | " noise_bufs = { name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name }\n", 325 | "\n", 326 | " # Features for target image.\n", 327 | " target_images = target.unsqueeze(0).to(device).to(torch.float32)\n", 328 | " if target_images.shape[2] > 256:\n", 329 | " target_images = F.interpolate(target_images, size=(256, 256), mode='area')\n", 330 | " target_features = vgg16(target_images, resize_images=False, return_lpips=True)\n", 331 | "\n", 332 | " w_opt = torch.tensor(w_avg, dtype=torch.float32, device=device, requires_grad=True) # pylint: disable=not-callable\n", 333 | " w_out = torch.zeros([num_steps] + list(w_opt.shape[1:]), dtype=torch.float32, device=device)\n", 334 | " optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), lr=initial_learning_rate)\n", 335 | "\n", 336 | " # Init noise.\n", 337 | " for buf in noise_bufs.values():\n", 338 | " buf[:] = torch.randn_like(buf)\n", 339 | " buf.requires_grad = True\n", 340 | "\n", 341 | " for step in range(num_steps):\n", 342 | " # Learning rate schedule.\n", 343 | " t = step / num_steps\n", 344 | " w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2\n", 345 | " lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)\n", 346 | " lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)\n", 347 | " lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)\n", 348 | " lr = initial_learning_rate * lr_ramp\n", 349 | " for param_group in optimizer.param_groups:\n", 350 | " param_group['lr'] = lr\n", 351 | "\n", 352 | " # Synth images from opt_w.\n", 353 | " w_noise = torch.randn_like(w_opt) * w_noise_scale\n", 354 | " ws = (w_opt + w_noise).repeat([1, G.mapping.num_ws, 1])\n", 355 | " synth_images = G.synthesis(ws, noise_mode='const')\n", 356 | "\n", 357 | " # Downsample image to 256x256 if it's larger than that. 
VGG was built for 224x224 images.\n", 358 | " synth_images = (synth_images + 1) * (255/2)\n", 359 | " if synth_images.shape[2] > 256:\n", 360 | " synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')\n", 361 | "\n", 362 | " # Features for synth images.\n", 363 | " synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)\n", 364 | " dist = (target_features - synth_features).square().sum()\n", 365 | "\n", 366 | " # Noise regularization.\n", 367 | " reg_loss = 0.0\n", 368 | " for v in noise_bufs.values():\n", 369 | " noise = v[None,None,:,:] # must be [1,1,H,W] for F.avg_pool2d()\n", 370 | " while True:\n", 371 | " reg_loss += (noise*torch.roll(noise, shifts=1, dims=3)).mean()**2\n", 372 | " reg_loss += (noise*torch.roll(noise, shifts=1, dims=2)).mean()**2\n", 373 | " if noise.shape[2] <= 8:\n", 374 | " break\n", 375 | " noise = F.avg_pool2d(noise, kernel_size=2)\n", 376 | " loss = dist + reg_loss * regularize_noise_weight\n", 377 | "\n", 378 | " # Step\n", 379 | " optimizer.zero_grad(set_to_none=True)\n", 380 | " loss.backward()\n", 381 | " optimizer.step()\n", 382 | " logprint(f'step {step+1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')\n", 383 | "\n", 384 | " # Save projected W for each optimization step.\n", 385 | " w_out[step] = w_opt.detach()[0]\n", 386 | "\n", 387 | " # Normalize noise.\n", 388 | " with torch.no_grad():\n", 389 | " for buf in noise_bufs.values():\n", 390 | " buf -= buf.mean()\n", 391 | " buf *= buf.square().mean().rsqrt()\n", 392 | "\n", 393 | " return w_out.repeat([1, G.mapping.num_ws, 1])\n", 394 | "\n", 395 | "def get_perceptual_loss(synth_image, target_features):\n", 396 | " # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.\n", 397 | " synth_image = (synth_image + 1) * (255/2)\n", 398 | " if synth_image.shape[2] > 256:\n", 399 | " synth_image = F.interpolate(synth_image, size=(256, 256), mode='area')\n", 400 | "\n", 401 | " # Features for synth images.\n", 402 | " synth_features = vgg16(synth_image, resize_images=False, return_lpips=True)\n", 403 | " return (target_features - synth_features).square().sum()\n", 404 | "\n", 405 | "def get_target_features(target):\n", 406 | " target_images = target.unsqueeze(0).to(device).to(torch.float32)\n", 407 | " if target_images.shape[2] > 256:\n", 408 | " target_images = F.interpolate(target_images, size=(256, 256), mode='area')\n", 409 | " return vgg16(target_images, resize_images=False, return_lpips=True)" 410 | ], 411 | "execution_count": null, 412 | "outputs": [] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "metadata": { 417 | "id": "4vpIq2vvzTtS", 418 | "cellView": "form" 419 | }, 420 | "source": [ 421 | "#@markdown #**Model selection** 🎭\n", 422 | "\n", 423 | "\n", 424 | "#@markdown There are 4 pre-trained options to play with:\n", 425 | "#@markdown - FFHQ: Trained with human faces.\n", 426 | "#@markdown - MetFaces: Trained with paintings/portraits of human faces.\n", 427 | "#@markdown - AFHQv2: Trained with animal faces.\n", 428 | "#@markdown - Cosplay: Trained by [l4rz](https://twitter.com/l4rz) with cosplayer's faces.\n", 429 | "#@markdown - Wikiart: Trained by [Justin Pinkney](https://www.justinpinkney.com/) with the Wikiart 1024 dataset.\n", 430 | "#@markdown - Landscapes: Trained by [Justin Pinkney](https://www.justinpinkney.com/) with the LHQ dataset.\n", 431 | "\n", 432 | "\n", 433 | "#@markdown **Run this cell again if you change the model**.\n", 434 | "\n", 435 | "#@markdown ---\n", 436 | "\n", 437 | "base_url = 
\"https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/\"\n", 438 | "\n", 439 | "Model = 'FFHQ' #@param [\"FFHQ\", \"MetFaces\", \"AFHQv2\", \"cosplay\", \"Wikiart\", \"Landscapes\"]\n", 440 | "\n", 441 | "#@markdown ---\n", 442 | "\n", 443 | "model_name = {\n", 444 | " \"FFHQ\": base_url + \"stylegan3-t-ffhqu-1024x1024.pkl\",\n", 445 | " \"MetFaces\": base_url + \"stylegan3-r-metfacesu-1024x1024.pkl\",\n", 446 | " \"AFHQv2\": base_url + \"stylegan3-t-afhqv2-512x512.pkl\",\n", 447 | " \"cosplay\": \"https://l4rz.net/cosplayface-snapshot-stylegan3t-008000.pkl\",\n", 448 | " \"Wikiart\": \"https://archive.org/download/wikiart-1024-stylegan3-t-17.2Mimg/wikiart-1024-stylegan3-t-17.2Mimg.pkl\",\n", 449 | " \"Landscapes\": \"https://archive.org/download/lhq-256-stylegan3-t-25Mimg/lhq-256-stylegan3-t-25Mimg.pkl\"\n", 450 | "}\n", 451 | "\n", 452 | "network_url = model_name[Model]\n", 453 | "\n", 454 | "with open(fetch_model(network_url), 'rb') as fp:\n", 455 | " G = pickle.load(fp)['G_ema'].to(device)\n", 456 | "\n", 457 | "zs = torch.randn([10000, G.mapping.z_dim], device=device)\n", 458 | "w_stds = G.mapping(zs, None).std(0)" 459 | ], 460 | "execution_count": null, 461 | "outputs": [] 462 | }, 463 | { 464 | "cell_type": "code", 465 | "metadata": { 466 | "id": "V_rq-N2m0Tlb", 467 | "cellView": "form" 468 | }, 469 | "source": [ 470 | "#@markdown #**Parameters** ✍️\n", 471 | "#@markdown ---\n", 472 | "\n", 473 | "target_image_filename = \"\"#@param {type:\"string\"}\n", 474 | "text = \"\"#@param {type:\"string\"}\n", 475 | "loss_ratio = 0.4#@param {type:\"number\"}\n", 476 | "steps = 800#@param {type:\"number\"}\n", 477 | "limit_step = 600#@param {type:\"number\"}\n", 478 | "seed = 14#@param {type:\"number\"}\n", 479 | "\n", 480 | "#@markdown Choose -1 for a random seed.\n", 481 | "\n", 482 | "#@markdown ---\n", 483 | "\n", 484 | "if seed == -1:\n", 485 | " seed = np.random.randint(0,9e9)\n", 486 | "\n", 487 | "target = clip_model.embed_text(text)\n", 488 | "\n", 489 | "target_pil = Image.open(target_image_filename).convert('RGB')\n", 490 | "w, h = target_pil.size\n", 491 | "s = min(w, h)\n", 492 | "target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))\n", 493 | "target_pil = target_pil.resize((G.img_resolution, G.img_resolution), Image.LANCZOS)\n", 494 | "target_uint8 = np.array(target_pil, dtype=np.uint8)\n", 495 | "target_tensor = torch.tensor(target_uint8.transpose([2, 0, 1]), device=device)" 496 | ], 497 | "execution_count": null, 498 | "outputs": [] 499 | }, 500 | { 501 | "cell_type": "code", 502 | "metadata": { 503 | "id": "jEuChPI65Kh-", 504 | "cellView": "form" 505 | }, 506 | "source": [ 507 | "#@markdown #**Run the model** 🚀\n", 508 | "\n", 509 | "# Actually do the run\n", 510 | "\n", 511 | "tf = Compose([\n", 512 | " Resize(224),\n", 513 | " lambda x: torch.clamp((x+1)/2,min=0,max=1),\n", 514 | "])\n", 515 | "\n", 516 | "\n", 517 | "def run(timestring, projection_target):\n", 518 | " torch.manual_seed(seed)\n", 519 | "\n", 520 | " target_features = get_target_features(projection_target)\n", 521 | "\n", 522 | " # Init\n", 523 | " # Sample 32 inits and choose the one closest to prompt\n", 524 | "\n", 525 | " with torch.no_grad():\n", 526 | " qs = []\n", 527 | " losses = []\n", 528 | " for _ in range(8):\n", 529 | " q = (G.mapping(torch.randn([4,G.mapping.z_dim], device=device), None, truncation_psi=0.7) - G.mapping.w_avg) / w_stds\n", 530 | " images = G.synthesis(q * w_stds + G.mapping.w_avg)\n", 531 | " loss = 
get_perceptual_loss(images, target_features)\n", 532 | " i = torch.argmin(loss)\n", 533 | " qs.append(q[i])\n", 534 | " losses.append(loss)\n", 535 | " qs = torch.stack(qs)\n", 536 | " losses = torch.stack(losses)\n", 537 | " i = torch.argmin(losses)\n", 538 | " q = qs[i].unsqueeze(0).requires_grad_()\n", 539 | "\n", 540 | " # Sampling loop\n", 541 | " q_ema = q\n", 542 | " opt = torch.optim.AdamW([q], lr=0.03, betas=(0.0,0.999))\n", 543 | " loop = tqdm(range(steps))\n", 544 | " for i in loop:\n", 545 | " opt.zero_grad()\n", 546 | " w = q * w_stds\n", 547 | " image = G.synthesis(w + G.mapping.w_avg, noise_mode='const')\n", 548 | " embed = embed_image(image.add(1).div(2))\n", 549 | " step_ratio = i / limit_step\n", 550 | " perceptual_loss = get_perceptual_loss(image, target_features)\n", 551 | " modulated_perceptual_loss = (\n", 552 | " max(loss_ratio, 1 - step_ratio)\n", 553 | " * get_perceptual_loss(image, target_features)\n", 554 | " )\n", 555 | " clip_loss = spherical_dist_loss(embed, target).mean()\n", 556 | " modulated_clip_loss = (\n", 557 | " min(1 - loss_ratio, step_ratio)\n", 558 | " * (step_ratio) * spherical_dist_loss(embed, target).mean()\n", 559 | " )\n", 560 | " loss = modulated_perceptual_loss + modulated_clip_loss\n", 561 | " loss.backward()\n", 562 | " opt.step()\n", 563 | " loop.set_postfix(loss=loss.item(), q_magnitude=q.std().item())\n", 564 | "\n", 565 | " q_ema = q_ema * 0.9 + q * 0.1\n", 566 | " image = G.synthesis(q_ema * w_stds + G.mapping.w_avg, noise_mode='const')\n", 567 | "\n", 568 | " if i % 10 == 0:\n", 569 | " display(TF.to_pil_image(tf(image)[0]))\n", 570 | " print(f\"image {i}/{steps} | projector loss: {perceptual_loss} | clip loss: {clip_loss} | modulated loss: {loss}\")\n", 571 | " pil_image = TF.to_pil_image(image[0].add(1).div(2).clamp(0,1))\n", 572 | " os.makedirs(f'samples/{timestring}', exist_ok=True)\n", 573 | " pil_image.save(f'samples/{timestring}/{i:04}.jpg')\n", 574 | "\n", 575 | "try:\n", 576 | " timestring = time.strftime('%Y%m%d%H%M%S')\n", 577 | " run(timestring, target_tensor)\n", 578 | "except KeyboardInterrupt:\n", 579 | " pass" 580 | ], 581 | "execution_count": null, 582 | "outputs": [] 583 | }, 584 | { 585 | "cell_type": "code", 586 | "metadata": { 587 | "id": "JSnRMY8-_-iV", 588 | "cellView": "form" 589 | }, 590 | "source": [ 591 | "#@markdown #**Save images** 📷\n", 592 | "#@markdown A `.tar` file will be saved inside *samples* and automatically downloaded, unless you previously ran the Google Drive cell,\n", 593 | "#@markdown in which case it'll be saved inside your previously created drive *samples* folder.\n", 594 | "\n", 595 | "archive_name = \"optional\"#@param {type:\"string\"}\n", 596 | "\n", 597 | "archive_name = slugify(archive_name)\n", 598 | "\n", 599 | "if archive_name != \"optional\":\n", 600 | " fname = archive_name\n", 601 | " # os.rename(f'samples/{timestring}', f'samples/{fname}')\n", 602 | "else:\n", 603 | " fname = timestring\n", 604 | "# Save images as a tar archive\n", 605 | "!tar cf samples/{fname}.tar samples/{timestring}\n", 606 | "if os.path.isdir('drive/MyDrive/samples'):\n", 607 | " shutil.copyfile(f'samples/{fname}.tar', f'drive/MyDrive/samples/{fname}.tar')\n", 608 | "else:\n", 609 | " files.download(f'samples/{fname}.tar')" 610 | ], 611 | "execution_count": null, 612 | "outputs": [] 613 | }, 614 | { 615 | "cell_type": "code", 616 | "metadata": { 617 | "id": "Z9Yyt8y99jfv", 618 | "cellView": "form" 619 | }, 620 | "source": [ 621 | "#@markdown #**Generate video** 🎥\n", 622 | "\n", 623 | "#@markdown You can 
edit the frame rate and other settings by double-clicking this cell.\n", 624 | "\n", 625 | "frames = os.listdir(f\"samples/{timestring}\")\n", 626 | "frames = len(list(filter(lambda filename: filename.endswith(\".jpg\"), frames))) #Get number of jpg generated\n", 627 | "\n", 628 | "init_frame = 1 #This is the frame where the video will start\n", 629 | "last_frame = frames #You can change this to the number of the last frame you want to include. It will raise an error if that frame does not exist.\n", 630 | "\n", 631 | "min_fps = 10\n", 632 | "max_fps = 60\n", 633 | "\n", 634 | "total_frames = last_frame-init_frame\n", 635 | "\n", 636 | "#Desired video time in seconds\n", 637 | "video_length = 14 #@param {type:\"number\"}\n", 638 | "#Video filename\n", 639 | "video_name = \"\" #@param {type:\"string\"}\n", 640 | "video_name = slugify(video_name)\n", 641 | "if not video_name: video_name = \"video\"\n", 642 | "# frames = []\n", 643 | "# tqdm.write('Generating video...')\n", 644 | "# for i in range(init_frame,last_frame): #\n", 645 | "# filename = f\"samples/{timestring}/{i:04}.jpg\"\n", 646 | "# frames.append(Image.open(filename))\n", 647 | "\n", 648 | "fps = np.clip(total_frames/video_length,min_fps,max_fps)\n", 649 | "\n", 650 | "!ffmpeg -r {fps} -i samples/{timestring}/%04d.jpg -c:v libx264 -vf fps={fps} -pix_fmt yuv420p -frames:v {total_frames} samples/{video_name}.mp4\n", 651 | "\n", 652 | "# from subprocess import Popen, PIPE\n", 653 | "# p = Popen(['ffmpeg', '-y', '-f', 'image2pipe', '-vcodec', 'png', '-r', str(fps), '-i', '-', '-vcodec', 'libx264', '-r', str(fps), '-pix_fmt', 'yuv420p', '-crf', '17', '-preset', 'veryslow', f'samples/{video_name}.mp4'], stdin=PIPE)\n", 654 | "# for im in tqdm(frames):\n", 655 | "# im.save(p.stdin, 'PNG')\n", 656 | "# p.stdin.close()\n", 657 | "\n", 658 | "# No need to wait here: the !ffmpeg call above runs synchronously.\n", 659 | "\n", 660 | "print(\"The video is ready\")" 661 | ], 662 | "execution_count": null, 663 | "outputs": [] 664 | }, 665 | { 666 | "cell_type": "code", 667 | "metadata": { 668 | "id": "uNpjDjR_-0dN", 669 | "cellView": "form" 670 | }, 671 | "source": [ 672 | "#@markdown #**Download video** 📀\n", 673 | "#@markdown If you've activated the Google Drive option, the video will be saved there. 
Don't worry about overwriting issues for colliding filenames; a numeric suffix will be appended to avoid this.\n", 674 | "import glob\n", 675 | "#Video filename\n", 676 | "to_download_video_name = \"\" #@param {type:\"string\"}\n", 677 | "to_download_video_name = slugify(to_download_video_name)\n", 678 | "if not to_download_video_name: to_download_video_name = \"video\"\n", 679 | "from google.colab import files\n", 680 | "if os.path.isdir('drive/MyDrive/samples'):\n", 681 | " filelist = glob.glob(f'drive/MyDrive/samples/{to_download_video_name}*.mp4')\n", 682 | " video_count = len(filelist)\n", 683 | " if video_count:\n", 684 | " final_video_name = f\"{to_download_video_name}{video_count}\"\n", 685 | " else:\n", 686 | " final_video_name = to_download_video_name\n", 687 | " shutil.copyfile(f'samples/{video_name}.mp4', f'drive/MyDrive/samples/{final_video_name}.mp4')\n", 688 | "else:\n", 689 | " files.download(f\"samples/{video_name}.mp4\")" 690 | ], 691 | "execution_count": null, 692 | "outputs": [] 693 | }, 694 | { 695 | "cell_type": "code", 696 | "metadata": { 697 | "id": "so1yHofG7RxX" 698 | }, 699 | "source": [ 700 | "#@title Licensed under the MIT License { display-mode: \"form\" }\n", 701 | "\n", 702 | "# Copyright (c) 2021 nshepperd; Katherine Crowson\n", 703 | "\n", 704 | "# Permission is hereby granted, free of charge, to any person obtaining a copy\n", 705 | "# of this software and associated documentation files (the \"Software\"), to deal\n", 706 | "# in the Software without restriction, including without limitation the rights\n", 707 | "# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n", 708 | "# copies of the Software, and to permit persons to whom the Software is\n", 709 | "# furnished to do so, subject to the following conditions:\n", 710 | "\n", 711 | "# The above copyright notice and this permission notice shall be included in\n", 712 | "# all copies or substantial portions of the Software.\n", 713 | "\n", 714 | "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", 715 | "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", 716 | "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n", 717 | "# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", 718 | "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n", 719 | "# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n", 720 | "# THE SOFTWARE."
721 | ], 722 | "execution_count": null, 723 | "outputs": [] 724 | } 725 | ] 726 | } -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | build: 2 | gpu: true 3 | cuda: 11.0 4 | python_version: "3.8" 5 | system_packages: 6 | - "libgl1-mesa-glx" 7 | - "libglib2.0-0" 8 | - "ninja-build" 9 | python_packages: 10 | - "ninja==1.10.2.2" 11 | - "torchvision==0.10.0" 12 | - "torch==1.9.0" 13 | - "numpy==1.19.4" 14 | - "tqdm==4.62.2" 15 | - "Pillow==8.3.2" 16 | - "einops==0.3.2" 17 | - "ftfy==6.0.3" 18 | - "ipython==7.19.0" 19 | - "scipy==1.7.1" 20 | - "regex==2021.10.8" 21 | run: 22 | - apt-get update && apt-get install -y software-properties-common 23 | - add-apt-repository ppa:ubuntu-toolchain-r/test 24 | - apt update -y && apt-get install ffmpeg -y 25 | - apt install g++-7 -y 26 | - update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-7 60 --slave /usr/bin/g++ g++ /usr/bin/g++-7 27 | - update-alternatives --config gcc 28 | 29 | predict: "predict.py:Predictor" 30 | -------------------------------------------------------------------------------- /download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | git clone https://github.com/NVlabs/stylegan3 4 | git clone https://github.com/openai/CLIP 5 | pip install -e ./CLIP -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.insert(0, 'CLIP') 4 | sys.path.insert(0, 'stylegan3') 5 | import tempfile 6 | from pathlib import Path 7 | from subprocess import call 8 | import io 9 | import pickle 10 | import shutil 11 | import numpy as np 12 | from PIL import Image 13 | import torch 14 | import torch.nn.functional as F 15 | import requests 16 | import torchvision.transforms as transforms 17 | import torchvision.transforms.functional as TF 18 | import clip 19 | from tqdm import tqdm 20 | from torchvision.transforms import Compose, Resize 21 | from einops import rearrange 22 | from subprocess import Popen, PIPE 23 | import cog 24 | 25 | 26 | class Predictor(cog.Predictor): 27 | def setup(self): 28 | self.device = torch.device('cuda:0') 29 | self.clip_model = CLIP() 30 | NVIDIA_MODEL_NAME = { 31 | "FFHQ": "stylegan3-t-ffhqu-1024x1024.pkl", 32 | "MetFaces": "stylegan3-r-metfacesu-1024x1024.pkl", 33 | "AFHQv2": "stylegan3-t-afhqv2-512x512.pkl" 34 | } 35 | BASE_URL = "https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/" 36 | 37 | self.models = {} 38 | self.w_stdss = {} 39 | for key, value in NVIDIA_MODEL_NAME.items(): 40 | network_url = BASE_URL + value 41 | with open(fetch_model(network_url), 'rb') as fp: 42 | G = pickle.load(fp)['G_ema'].to(self.device) 43 | self.models[key] = G 44 | zs = torch.randn([10000, G.mapping.z_dim], device=self.device) 45 | self.w_stdss[key] = G.mapping(zs, None).std(0) 46 | 47 | CUSTOM_MODEL_NAME = { 48 | "Cosplay": "https://l4rz.net/cosplayface-snapshot-stylegan3t-008000.pkl", 49 | "Wikiart": "https://archive.org/download/wikiart-1024-stylegan3-t-17.2Mimg/wikiart-1024-stylegan3-t-17.2Mimg.pkl", 50 | "Landscapes": "https://archive.org/download/lhq-256-stylegan3-t-25Mimg/lhq-256-stylegan3-t-25Mimg.pkl" 51 | } 52 | 53 | for key, network_url in CUSTOM_MODEL_NAME.items(): 54 | with open(fetch_model(network_url), 'rb') as fp: 55 | G = 
56 |             self.models[key] = G
57 |             zs = torch.randn([10000, G.mapping.z_dim], device=self.device)
58 |             self.w_stdss[key] = G.mapping(zs, None).std(0)
59 |
60 |     @cog.input(
61 |         "texts",
62 |         type=str,
63 |         help="Enter here a prompt to guide the image generation. You can enter more than one prompt separated with |, "
64 |              "which will cause the guidance to focus on the different prompts at the same time, allowing you to mix "
65 |              "and play with the generation process."
66 |     )
67 |     @cog.input(
68 |         "model_name",
69 |         type=str,
70 |         default='FFHQ',
71 |         options=['FFHQ', 'MetFaces', 'AFHQv2', 'Cosplay', 'Wikiart', 'Landscapes'],
72 |         help="""choose model: FFHQ: human faces,
73 |              MetFaces: human faces from works of art,
74 |              AFHQv2: animal faces,
75 |              Cosplay: cosplayer's faces (by l4rz),
76 |              Wikiart: Wikiart 1024 dataset (by Justin Pinkney),
77 |              Landscapes: landscape images (by Justin Pinkney)
78 |              """
79 |     )
80 |     @cog.input(
81 |         "output_type",
82 |         type=str,
83 |         default='mp4',
84 |         options=['png', 'mp4'],
85 |         help="choose whether to output the final image or a video"
86 |     )
87 |     @cog.input(
88 |         "steps",
89 |         type=int,
90 |         default=200,
91 |         min=1,
92 |         help="sampling steps; for the FFHQ and MetFaces models it is recommended to set <= 100 to avoid timeouts"
93 |     )
94 |     @cog.input(
95 |         "learning_rate",
96 |         type=float,
97 |         default=0.05,
98 |         help="learning rate"
99 |     )
100 |     @cog.input(
101 |         "video_length",
102 |         type=int,
103 |         default=10,
104 |         max=20,
105 |         min=1,
106 |         help="choose video length, valid if output is mp4"
107 |     )
108 |     @cog.input(
109 |         "seed",
110 |         type=int,
111 |         default=2,
112 |         help="set -1 for a random seed"
113 |     )
114 |     def predict(self, texts, model_name, steps, output_type, video_length, seed, learning_rate):
115 |         if os.path.isdir('samples'):
116 |             clean_folder('samples')
117 |         os.makedirs('samples', exist_ok=True)
118 |
119 |         G = self.models[model_name]
120 |         w_stds = self.w_stdss[model_name]
121 |         if not isinstance(seed, int):
122 |             seed = 2
123 |         if seed == -1:
124 |             seed = np.random.randint(0, 9e9)
125 |
126 |         texts = [frase.strip() for frase in texts.split("|") if frase]
127 |         targets = [self.clip_model.embed_text(text, self.device) for text in texts]
128 |
129 |         tf = Compose([
130 |             Resize(224),
131 |             lambda x: torch.clamp((x + 1) / 2, min=0, max=1),
132 |         ])
133 |
134 |         # do run
135 |         torch.manual_seed(seed)
136 |
137 |         # Init
138 |         # Sample 32 inits and choose the one closest to the prompt
139 |         with torch.no_grad():
140 |             qs = []
141 |             losses = []
142 |             for _ in range(8):
143 |                 q = (G.mapping(torch.randn([4, G.mapping.z_dim], device=self.device), None,
144 |                                truncation_psi=0.7) - G.mapping.w_avg) / w_stds
145 |                 images = G.synthesis(q * w_stds + G.mapping.w_avg)
146 |                 embeds = embed_image(images.add(1).div(2), self.clip_model)
147 |                 loss = prompts_dist_loss(embeds, targets, spherical_dist_loss).mean(0)
148 |                 i = torch.argmin(loss)
149 |                 qs.append(q[i])
150 |                 losses.append(loss[i])
151 |             qs = torch.stack(qs)
152 |             losses = torch.stack(losses)
153 |             i = torch.argmin(losses)
154 |             q = qs[i].unsqueeze(0).requires_grad_()
155 |
156 |         # Sampling loop
157 |         q_ema = q
158 |         opt = torch.optim.AdamW([q], lr=learning_rate, betas=(0.0, 0.999))
159 |         img_path = Path(tempfile.mkdtemp()) / "progress.png"
160 |
161 |         for i in range(steps):
162 |             opt.zero_grad()
163 |             w = q * w_stds
164 |             image = G.synthesis(w + G.mapping.w_avg, noise_mode='const')
165 |             embed = embed_image(image.add(1).div(2), self.clip_model)
166 |             loss = prompts_dist_loss(embed, targets, spherical_dist_loss).mean()
167 |             loss.backward()
168 |             opt.step()
169 |
170 |             q_ema = q_ema * 0.9 + q * 0.1
171 |             image = G.synthesis(q_ema * w_stds + G.mapping.w_avg, noise_mode='const')
172 |
173 |             if (i + 1) % 10 == 0:
174 |                 yield checkin(i, steps, loss, tf, image, img_path)
175 |
176 |             if output_type == 'mp4':
177 |                 pil_image = TF.to_pil_image(image[0].add(1).div(2).clamp(0, 1))
178 |                 pil_image.save(f'samples/{i:04}.png')
179 |
180 |         if output_type == 'png':
181 |             out_path_png = Path(tempfile.mkdtemp()) / "out.png"
182 |             yield checkin(None, steps, None, tf, image, out_path_png, output_type, video_length, final=True)
183 |         else:
184 |             out_path_mp4 = Path(tempfile.mkdtemp()) / "out.mp4"
185 |             yield checkin(None, steps, None, tf, image, out_path_mp4, output_type, video_length, final=True)
186 |
187 |
188 | def make_video(out_path, video_length):
189 |     frames = os.listdir('samples')
190 |     frames = len(list(filter(lambda filename: filename.endswith(".png"), frames)))  # Get the number of generated PNGs
191 |
192 |     init_frame = 1  # This is the frame where the video will start
193 |     last_frame = frames
194 |
195 |     min_fps = 10
196 |     max_fps = 30
197 |
198 |     total_frames = last_frame - init_frame
199 |
200 |     frames = []
201 |     tqdm.write('Generating video...')
202 |     for i in range(init_frame, last_frame):
203 |         filename = f"samples/{i:04}.png"
204 |         frames.append(Image.open(filename))
205 |
206 |     fps = np.clip(total_frames / video_length, min_fps, max_fps)
207 |
208 |     p = Popen(
209 |         ['ffmpeg', '-y', '-f', 'image2pipe', '-vcodec', 'png', '-r', str(fps), '-i', '-', '-vcodec', 'libx264',
210 |          '-r',
211 |          str(fps), '-pix_fmt', 'yuv420p', '-crf', '17', '-preset', 'veryslow', str(out_path)], stdin=PIPE)
212 |     for im in tqdm(frames):
213 |         im.save(p.stdin, 'PNG')
214 |     p.stdin.close()
215 |     p.wait()
216 |     tqdm.write("The video is ready")
217 |
218 |
219 | def checkin(i, steps, loss, tf, image, out_path, output_type=None, video_length=None, final=False):
220 |     if not final:
221 |         tqdm.write(f"Image {i + 1}/{steps} | Current loss: {loss}")
222 |         TF.to_pil_image(tf(image)[0]).save(str(out_path))
223 |     else:
224 |         if output_type == 'png':
225 |             TF.to_pil_image(image[0].add(1).div(2).clamp(0, 1)).save(str(out_path))
226 |         else:
227 |             make_video(str(out_path), video_length)
228 |     return out_path
229 |
230 |
231 | def fetch(url_or_path):
232 |     if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
233 |         r = requests.get(url_or_path)
234 |         r.raise_for_status()
235 |         fd = io.BytesIO()
236 |         fd.write(r.content)
237 |         fd.seek(0)
238 |         return fd
239 |     return open(url_or_path, 'rb')
240 |
241 |
242 | def fetch_model(url_or_path):
243 |     basename = os.path.basename(url_or_path)
244 |     if os.path.exists(basename):
245 |         return basename
246 |     else:
247 |         cmd = (
248 |             "wget "
249 |             + url_or_path
250 |         )
251 |         call(cmd, shell=True)
252 |         return basename
253 |
254 |
255 | def norm1(prompt):
256 |     """Normalize to the unit sphere."""
257 |     return prompt / prompt.square().sum(dim=-1, keepdim=True).sqrt()
258 |
259 |
260 | def spherical_dist_loss(x, y):
261 |     x = F.normalize(x, dim=-1)
262 |     y = F.normalize(y, dim=-1)
263 |     return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
264 |
265 |
266 | def prompts_dist_loss(x, targets, loss):
267 |     if len(targets) == 1:  # Keeps results consistent with the previous method for single-objective guidance
268 |         return loss(x, targets[0])
269 |     distances = [loss(x, target) for target in targets]
270 |     return torch.stack(distances, dim=-1).sum(dim=-1)
271 |
272 |
273 | class MakeCutouts(torch.nn.Module):
274 |     def __init__(self, cut_size, cutn, cut_pow=1.):
275 |         super().__init__()
276 |         self.cut_size = cut_size
277 |         self.cutn = cutn
278 |         self.cut_pow = cut_pow
279 |
280 |     def forward(self, input):
281 |         sideY, sideX = input.shape[2:4]
282 |         max_size = min(sideX, sideY)
283 |         min_size = min(sideX, sideY, self.cut_size)
284 |         cutouts = []
285 |         for _ in range(self.cutn):
286 |             size = int(torch.rand([]) ** self.cut_pow * (max_size - min_size) + min_size)
287 |             offsetx = torch.randint(0, sideX - size + 1, ())
288 |             offsety = torch.randint(0, sideY - size + 1, ())
289 |             cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
290 |             cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
291 |         return torch.cat(cutouts)
292 |
293 |
294 | def embed_image(image, clip_model):
295 |     n = image.shape[0]
296 |     make_cutouts = MakeCutouts(224, 32, 0.5)
297 |     cutouts = make_cutouts(image)
298 |     embeds = clip_model.embed_cutout(cutouts)
299 |     embeds = rearrange(embeds, '(cc n) c -> cc n c', n=n)
300 |     return embeds
301 |
302 |
303 | class CLIP(object):
304 |     def __init__(self):
305 |         clip_model = "ViT-B/32"
306 |         self.model, _ = clip.load(clip_model)
307 |         self.model = self.model.requires_grad_(False)
308 |         self.normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
309 |                                               std=[0.26862954, 0.26130258, 0.27577711])
310 |
311 |     @torch.no_grad()
312 |     def embed_text(self, prompt, device):
313 |         """Normalized clip text embedding."""
314 |         return norm1(self.model.encode_text(clip.tokenize(prompt).to(device)).float())
315 |
316 |     def embed_cutout(self, image):
317 |         """Normalized clip image embedding."""
318 |         return norm1(self.model.encode_image(self.normalize(image)))
319 |
320 |
321 | def clean_folder(folder):
322 |     for filename in os.listdir(folder):
323 |         file_path = os.path.join(folder, filename)
324 |         try:
325 |             if os.path.isfile(file_path) or os.path.islink(file_path):
326 |                 os.unlink(file_path)
327 |             elif os.path.isdir(file_path):
328 |                 shutil.rmtree(file_path)
329 |         except Exception as e:
330 |             print('Failed to delete %s. Reason: %s' % (file_path, e))
331 |
--------------------------------------------------------------------------------
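For readers skimming predict.py, the guidance objective may be easier to see in isolation. Below is a minimal, self-contained sketch (not part of the repository) of the loss that the AdamW loop in predict() minimizes: it restates the formula from spherical_dist_loss and applies it to random tensors standing in for the CLIP cutout embeddings and the prompt embedding, so it runs with only PyTorch installed. The helper name spherical_dist and the 512-dimensional random embeddings are illustrative assumptions, not repository code.

import torch
import torch.nn.functional as F

def spherical_dist(x, y):
    # Same formula as spherical_dist_loss() in predict.py: squared great-circle
    # distance between L2-normalized embedding vectors.
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)

# Random stand-ins: 32 "cutout" image embeddings and one text-prompt embedding.
# In predict.py these come from CLIP via embed_image() and CLIP.embed_text().
image_embeds = F.normalize(torch.randn(32, 512), dim=-1)
text_embed = F.normalize(torch.randn(1, 512), dim=-1)

loss = spherical_dist(image_embeds, text_embed).mean()
print(loss.item())  # the scalar that the optimization loop drives down

In the actual predictor, the gradient of this scalar with respect to the latent q is what AdamW follows, which is why the loss is computed on differentiable cutouts of the synthesized image rather than on the prompt side.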