├── .gitignore ├── Aphantasia.ipynb ├── CLIP_VQGAN.ipynb ├── IllusTrip3D.ipynb ├── Illustra.ipynb ├── LICENSE ├── README.md ├── _out ├── Aphantasia.jpg ├── Aphantasia2.jpg ├── Aphantasia3.jpg ├── Aphantasia4.jpg ├── some_cute_image-FFT.jpg ├── some_cute_image-SIREN.jpg └── some_cute_image-VQGAN.jpg ├── aphantasia ├── __init__.py ├── image.py ├── interpol.py ├── progress_bar.py ├── transforms.py └── utils.py ├── clip_fft.py ├── cppn.py ├── depth ├── __init__.py ├── any2 │ ├── dinov2.py │ ├── dinov2_layers │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── block.py │ │ ├── drop_path.py │ │ ├── layer_scale.py │ │ ├── mlp.py │ │ ├── patch_embed.py │ │ └── swiglu_ffn.py │ ├── dpt.py │ ├── run.py │ └── util │ │ ├── blocks.py │ │ └── transform.py └── depth.py ├── illustra.py ├── illustrip.py ├── requirements.txt ├── setup.py └── shader_expo.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | _cudacache/ 4 | .pytest_cache 5 | *.egg-info 6 | thumbs.db 7 | _out 8 | 9 | .ipynb_checkpoints 10 | 11 | *.pkl 12 | *.pt 13 | *.avi 14 | *.mp4 15 | -------------------------------------------------------------------------------- /CLIP_VQGAN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "accelerator": "GPU", 6 | "colab": { 7 | "name": "Text2Image_VQGAN.ipynb", 8 | "private_outputs": true, 9 | "provenance": [] 10 | }, 11 | "kernelspec": { 12 | "display_name": "Python 3", 13 | "language": "python", 14 | "name": "python3" 15 | }, 16 | "language_info": { 17 | "codemirror_mode": { 18 | "name": "ipython", 19 | "version": 3 20 | }, 21 | "file_extension": ".py", 22 | "mimetype": "text/x-python", 23 | "name": "python", 24 | "nbconvert_exporter": "python", 25 | "pygments_lexer": "ipython3", 26 | "version": "3.7.9" 27 | } 28 | }, 29 | "cells": [ 30 | { 31 | "cell_type": "markdown", 32 | "metadata": { 33 | "id": "toWe1IoH7X35" 34 | }, 35 | "source": [ 36 | "# Text to Image tool\n", 37 | "\n", 38 | "Part of [Aphantasia](https://github.com/eps696/aphantasia) suite, made by Vadim Epstein [[eps696](https://github.com/eps696)] \n", 39 | "Based on [CLIP](https://github.com/openai/CLIP) + VQGAN from [Taming Transformers](https://github.com/CompVis/taming-transformers). \n", 40 | "thanks to [Ryan Murdock](https://twitter.com/advadnoun), [Jonathan Fly](https://twitter.com/jonathanfly), [Hannu Toyryla](https://twitter.com/htoyryla) for ideas.\n", 41 | "\n", 42 | "## Features\n", 43 | "* complex requests:\n", 44 | " * image and/or text as main prompts \n", 45 | " (composition similarity controlled with [LPIPS](https://github.com/richzhang/PerceptualSimilarity) loss)\n", 46 | " * separate text prompts for image style and to subtract (suppress) topics\n", 47 | " * criteria inversion (show \"the opposite\")\n", 48 | "\n", 49 | "* various VQGAN models (incl. newest Gumbel-F8)\n", 50 | "* various CLIP models\n", 51 | "* saving/loading VQGAN snapshots to resume processing" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "metadata": { 57 | "id": "QytcEMSKBtN-" 58 | }, 59 | "source": [ 60 | "**Run the cell below after each session restart**\n", 61 | "\n", 62 | "First select `VQGAN_model` for generation. \n", 63 | "`Gumbel` is probably the best, but eats more RAM (max resolution on Colab ~900x500). `F16-1024` can go up to ~1000x600. \n", 64 | "`resume` if you want to start from the saved snapshot." 
65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "metadata": { 70 | "id": "etzxXVZ_r-Nf", 71 | "cellView": "form" 72 | }, 73 | "source": [ 74 | "#@title General setup\n", 75 | "\n", 76 | "VQGAN_model = \"gumbel_f8-8192\" #@param ['gumbel_f8-8192', 'imagenet_f16-1024', 'imagenet_f16-16384']\n", 77 | "resume = False #@param {type:\"boolean\"}\n", 78 | "\n", 79 | "!pip install ftfy gputil ffpb\n", 80 | "\n", 81 | "import os\n", 82 | "import io\n", 83 | "import time\n", 84 | "from math import exp\n", 85 | "import random\n", 86 | "import imageio\n", 87 | "import numpy as np\n", 88 | "import PIL\n", 89 | "from collections import OrderedDict\n", 90 | "from base64 import b64encode\n", 91 | "\n", 92 | "import torch\n", 93 | "import torch.nn as nn\n", 94 | "import torch.nn.functional as F\n", 95 | "import torchvision\n", 96 | "from torch.autograd import Variable\n", 97 | "\n", 98 | "from IPython.display import HTML, Image, display, clear_output\n", 99 | "from IPython.core.interactiveshell import InteractiveShell\n", 100 | "InteractiveShell.ast_node_interactivity = \"all\"\n", 101 | "import ipywidgets as ipy\n", 102 | "from google.colab import output, files\n", 103 | "output.enable_custom_widget_manager()\n", 104 | "\n", 105 | "import warnings\n", 106 | "warnings.filterwarnings(\"ignore\")\n", 107 | "\n", 108 | "!pip install --no-deps git+https://github.com/openai/CLIP.git\n", 109 | "import clip\n", 110 | "!pip install --no-deps kornia kornia_rs\n", 111 | "import kornia\n", 112 | "!pip install --no-deps lpips\n", 113 | "import lpips\n", 114 | "\n", 115 | "%cd /content\n", 116 | "!pip install git+https://github.com/eps696/aphantasia\n", 117 | "from aphantasia.utils import slice_imgs, pad_up_to, basename, img_list, img_read, plot_text, txt_clean, old_torch\n", 118 | "from aphantasia import transforms\n", 119 | "from aphantasia.progress_bar import ProgressIPy as ProgressBar\n", 120 | "\n", 121 | "!pip install omegaconf>=2.0.0 einops>=0.3.0\n", 122 | "!pip3 install --no-deps torchmetrics lightning_utilities pytorch_lightning\n", 123 | "import pytorch_lightning as pl\n", 124 | "!git clone https://github.com/CompVis/taming-transformers\n", 125 | "!mv taming-transformers/* ./\n", 126 | "import yaml\n", 127 | "from omegaconf import OmegaConf\n", 128 | "from taming.modules.diffusionmodules.model import Decoder\n", 129 | "from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer\n", 130 | "from taming.modules.vqvae.quantize import GumbelQuantize\n", 131 | "\n", 132 | "class VQModel(pl.LightningModule):\n", 133 | " def __init__(self, ddconfig, n_embed, embed_dim, remap=None, sane_index_shape=False, **kwargs_ignore): # tell vector quantizer to return indices as bhw\n", 134 | " super().__init__()\n", 135 | " self.decoder = Decoder(**ddconfig)\n", 136 | " self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, remap=remap, sane_index_shape=sane_index_shape)\n", 137 | " def decode(self, quant):\n", 138 | " return self.decoder(quant)\n", 139 | "\n", 140 | "class GumbelVQ(VQModel):\n", 141 | " def __init__(self, ddconfig, n_embed, embed_dim, kl_weight=1e-8, remap=None, **kwargs_ignore):\n", 142 | " z_channels = ddconfig[\"z_channels\"]\n", 143 | " super().__init__(ddconfig, n_embed, embed_dim)\n", 144 | " self.quantize = GumbelQuantize(z_channels, embed_dim, n_embed=n_embed, kl_weight=kl_weight, temp_init=1.0, remap=remap)\n", 145 | "\n", 146 | "if not os.path.isdir('/content/models_TT'):\n", 147 | " !mkdir -p /content/models_TT\n", 148 | "def getm(url, path):\n", 149 | " if 
os.path.isfile(path) and os.stat(path).st_size > 0:\n", 150 | " print(' already exists', path, os.stat(path).st_size)\n", 151 | " else:\n", 152 | " !wget $url -O $path\n", 153 | "\n", 154 | "if VQGAN_model == \"gumbel_f8-8192\" and not os.path.isfile('/content/models_TT/gumbel_f8-8192.ckpt'):\n", 155 | " getm('https://heibox.uni-heidelberg.de/f/34a747d5765840b5a99d/?dl=1', '/content/models_TT/gumbel_f8-8192.ckpt')\n", 156 | " getm('https://heibox.uni-heidelberg.de/f/b24d14998a8d4f19a34f/?dl=1', '/content/models_TT/gumbel_f8-8192.yaml')\n", 157 | "elif VQGAN_model == \"imagenet_f16-1024\" and not os.path.isfile('/content/models_TT/imagenet_f16-1024.ckpt'):\n", 158 | " getm('https://heibox.uni-heidelberg.de/f/140747ba53464f49b476/?dl=1', '/content/models_TT/imagenet_f16-1024.ckpt')\n", 159 | " getm('https://heibox.uni-heidelberg.de/f/6ecf2af6c658432c8298/?dl=1', '/content/models_TT/imagenet_f16-1024.yaml')\n", 160 | "elif VQGAN_model == \"imagenet_f16-16384\" and not os.path.isfile('/content/models_TT/imagenet_f16-16384.ckpt'):\n", 161 | " getm('https://heibox.uni-heidelberg.de/f/867b05fc8c4841768640/?dl=1', '/content/models_TT/imagenet_f16-16384.ckpt')\n", 162 | " getm('https://heibox.uni-heidelberg.de/f/274fb24ed38341bfa753/?dl=1', '/content/models_TT/imagenet_f16-16384.yaml')\n", 163 | "\n", 164 | "clear_output()\n", 165 | "\n", 166 | "if resume:\n", 167 | " resumed = files.upload()\n", 168 | " params_pt = list(resumed.values())[0]\n", 169 | " params_pt = torch.load(io.BytesIO(params_pt))\n", 170 | "\n", 171 | "if VQGAN_model == \"gumbel_f8-8192\":\n", 172 | " scale_res = 8\n", 173 | "else:\n", 174 | " scale_res = 16\n", 175 | "\n", 176 | "def load_config(config_path):\n", 177 | " config = OmegaConf.load(config_path)\n", 178 | " return config\n", 179 | "\n", 180 | "def load_vqgan(config, ckpt_path=None):\n", 181 | " if VQGAN_model == \"gumbel_f8-8192\":\n", 182 | " model = GumbelVQ(**config.model.params)\n", 183 | " else:\n", 184 | " model = VQModel(**config.model.params)\n", 185 | " if ckpt_path is not None:\n", 186 | " sd = torch.load(ckpt_path, map_location=\"cpu\")[\"state_dict\"]\n", 187 | " missing, unexpected = model.load_state_dict(sd, strict=False)\n", 188 | " return model.eval()\n", 189 | "\n", 190 | "def vqgan_image(model, z):\n", 191 | " x = model.decode(z)\n", 192 | " x = (x+1.)/2.\n", 193 | " return x\n", 194 | "\n", 195 | "class latents(torch.nn.Module):\n", 196 | " def __init__(self, shape):\n", 197 | " super(latents, self).__init__()\n", 198 | " init_rnd = torch.zeros(shape).normal_(0.,4.)\n", 199 | " self.lats = torch.nn.Parameter(init_rnd.cuda())\n", 200 | " def forward(self):\n", 201 | " return self.lats\n", 202 | "\n", 203 | "config_vqgan = load_config(\"/content/models_TT/%s.yaml\" % VQGAN_model)\n", 204 | "model_vqgan = load_vqgan(config_vqgan, ckpt_path=\"/content/models_TT/%s.ckpt\" % VQGAN_model).cuda()\n", 205 | "\n", 206 | "def makevid(seq_dir, size=None):\n", 207 | " char_len = len(basename(img_list(seq_dir)[0]))\n", 208 | " out_sequence = seq_dir + '/%0{}d.jpg'.format(char_len)\n", 209 | " out_video = seq_dir + '.mp4'\n", 210 | " print('.. 
generating video ..')\n", 211 | " !ffmpeg -y -v warning -i $out_sequence -crf 20 $out_video\n", 212 | " data_url = \"data:video/mp4;base64,\" + b64encode(open(out_video,'rb').read()).decode()\n", 213 | " wh = '' if size is None else 'width=%d height=%d' % (size, size)\n", 214 | " return \"\"\"\"\"\" % (wh, data_url)\n", 215 | "\n", 216 | "# Hardware check\n", 217 | "!ln -sf /opt/bin/nvidia-smi /usr/bin/nvidia-smi\n", 218 | "import GPUtil as GPU\n", 219 | "gpu = GPU.getGPUs()[0] # XXX: only one GPU on Colab and isn’t guaranteed\n", 220 | "!nvidia-smi -L\n", 221 | "print(\"GPU RAM {0:.0f}MB | Free {1:.0f}MB)\".format(gpu.memoryTotal, gpu.memoryFree))\n", 222 | "print('\\nDone!')" 223 | ], 224 | "execution_count": null, 225 | "outputs": [] 226 | }, 227 | { 228 | "cell_type": "markdown", 229 | "metadata": { 230 | "id": "CbJ9K4Cq8MtB" 231 | }, 232 | "source": [ 233 | "Type some `text` and/or upload some image to start. \n", 234 | "Describe `style`, which you'd like to apply to the imagery. \n", 235 | "Put to `subtract` the topics, which you would like to avoid in the result. \n", 236 | "`invert` the whole criteria, if you want to see \"the totally opposite\". \n", 237 | "Mark `translate` to use Google translation." 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "metadata": { 243 | "id": "JUvpdy8BWGuM", 244 | "cellView": "form" 245 | }, 246 | "source": [ 247 | "#@title Input\n", 248 | "\n", 249 | "text = \"\" #@param {type:\"string\"}\n", 250 | "style = \"\" #@param {type:\"string\"}\n", 251 | "subtract = \"\" #@param {type:\"string\"}\n", 252 | "translate = False #@param {type:\"boolean\"}\n", 253 | "invert = False #@param {type:\"boolean\"}\n", 254 | "upload_image = False #@param {type:\"boolean\"}\n", 255 | "\n", 256 | "if translate:\n", 257 | " !pip3 install googletrans==3.1.0a0\n", 258 | " clear_output()\n", 259 | " from googletrans import Translator\n", 260 | " translator = Translator()\n", 261 | "\n", 262 | "if upload_image:\n", 263 | " uploaded = files.upload()\n", 264 | "\n", 265 | "workdir = '_out'\n", 266 | "tempdir = os.path.join(workdir, '%s-%s' % (txt_clean(text)[:50], txt_clean(style)[:50]))" 267 | ], 268 | "execution_count": null, 269 | "outputs": [] 270 | }, 271 | { 272 | "cell_type": "markdown", 273 | "metadata": { 274 | "id": "f3Sj0fxmtw6K" 275 | }, 276 | "source": [ 277 | "### Settings\n", 278 | "\n", 279 | "Select CLIP visual `model` (results do vary!). I prefer ViT for consistency (and it's the only native multi-language option). \n", 280 | "`align` option is about composition. `uniform` looks most adequate, `overscan` can make semi-seamless tileable texture. \n", 281 | "`aug_transform` applies some augmentations, inhibiting image fragmentation & \"graffiti\" printing (slower, yet recommended). \n", 282 | "`sync` value adds LPIPS loss between the output and input image (if there's one), allowing to \"redraw\" it with controlled similarity. \n", 283 | "Decrease `samples` or resolution if you face OOM. \n", 284 | "\n", 285 | "Generation video and final parameters snapshot are saved automatically. \n", 286 | "NB: Requests are cumulative (start near the end of the previous run). To start generation from scratch, re-run General setup." 
287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "metadata": { 292 | "id": "Nq0wA-wc-P-s", 293 | "cellView": "form" 294 | }, 295 | "source": [ 296 | "#@title Generate\n", 297 | "\n", 298 | "!rm -rf $tempdir\n", 299 | "os.makedirs(tempdir, exist_ok=True)\n", 300 | "\n", 301 | "sideX = 900 #@param {type:\"integer\"}\n", 302 | "sideY = 500#@param {type:\"integer\"}\n", 303 | "#@markdown > Config\n", 304 | "model = 'ViT-B/32' #@param ['ViT-B/16', 'ViT-B/32', 'RN101', 'RN50x16', 'RN50x4', 'RN50']\n", 305 | "align = 'uniform' #@param ['central', 'uniform', 'overscan']\n", 306 | "aug_transform = True #@param {type:\"boolean\"}\n", 307 | "sync = 0.4 #@param {type:\"number\"}\n", 308 | "#@markdown > Training\n", 309 | "steps = 200 #@param {type:\"integer\"}\n", 310 | "samples = 60 #@param {type:\"integer\"}\n", 311 | "learning_rate = 0.1 #@param {type:\"number\"}\n", 312 | "save_freq = 1 #@param {type:\"integer\"}\n", 313 | "\n", 314 | "if resume:\n", 315 | " if not isinstance(params_pt, dict):\n", 316 | " params_pt = OrderedDict({'lats': params_pt})\n", 317 | " ps = params_pt['lats'].shape\n", 318 | " size = [s*scale_res for s in ps[2:]]\n", 319 | " lats = latents(ps).cuda()\n", 320 | " _ = lats.load_state_dict(params_pt)\n", 321 | " print(' resumed with size', size)\n", 322 | "else:\n", 323 | " lats = latents([1, 256, sideY//scale_res, sideX//scale_res]).cuda()\n", 324 | "\n", 325 | "if len(subtract) > 0:\n", 326 | " samples = int(samples * 0.75)\n", 327 | "if sync > 0 and upload_image:\n", 328 | " samples = int(samples * 0.5)\n", 329 | "print(' using %d samples' % samples)\n", 330 | "\n", 331 | "model_clip, _ = clip.load(model, jit=old_torch())\n", 332 | "modsize = model_clip.visual.input_resolution\n", 333 | "xmem = {'ViT-B/16':0.25, 'RN50':0.5, 'RN50x4':0.16, 'RN50x16':0.06, 'RN101':0.33}\n", 334 | "if model in xmem.keys():\n", 335 | " samples = int(samples * xmem[model])\n", 336 | "\n", 337 | "def enc_text(txt):\n", 338 | " emb = model_clip.encode_text(clip.tokenize(txt).cuda())\n", 339 | " return emb.detach().clone()\n", 340 | "\n", 341 | "sign = 1. 
if invert else -1.\n", 342 | "if aug_transform:\n", 343 | " trform_f = transforms.transforms_fast\n", 344 | " samples = int(samples * 0.95)\n", 345 | "else:\n", 346 | " trform_f = transforms.normalize()\n", 347 | "\n", 348 | "if upload_image:\n", 349 | " in_img = list(uploaded.values())[0]\n", 350 | " print(' image:', list(uploaded)[0])\n", 351 | " img_in = torch.from_numpy(imageio.imread(in_img).astype(np.float32)/255.).unsqueeze(0).permute(0,3,1,2).cuda()[:,:3,:,:]\n", 352 | " in_sliced = slice_imgs([img_in], samples, modsize, transforms.normalize(), align)[0]\n", 353 | " img_enc = model_clip.encode_image(in_sliced).detach().clone()\n", 354 | " if sync > 0:\n", 355 | " align = 'overscan'\n", 356 | " sim_loss = lpips.LPIPS(net='vgg', verbose=False).cuda()\n", 357 | " sim_size = [sideY//4, sideX//4]\n", 358 | " img_in = F.interpolate(img_in, sim_size).float()\n", 359 | " # img_in = F.interpolate(img_in, (sideY, sideX)).float()\n", 360 | " else:\n", 361 | " del img_in\n", 362 | " del in_sliced; torch.cuda.empty_cache()\n", 363 | "\n", 364 | "if len(text) > 0:\n", 365 | " print(' text:', text)\n", 366 | " if translate:\n", 367 | " text = translator.translate(text, dest='en').text\n", 368 | " print(' translated to:', text)\n", 369 | " txt_enc = enc_text(text)\n", 370 | "\n", 371 | "if len(style) > 0:\n", 372 | " print(' style:', style)\n", 373 | " if translate:\n", 374 | " style = translator.translate(style, dest='en').text\n", 375 | " print(' translated to:', style)\n", 376 | " txt_enc2 = enc_text(style)\n", 377 | "\n", 378 | "if len(subtract) > 0:\n", 379 | " print(' without:', subtract)\n", 380 | " if translate:\n", 381 | " subtract = translator.translate(subtract, dest='en').text\n", 382 | " print(' translated to:', subtract)\n", 383 | " txt_enc0 = enc_text(subtract)\n", 384 | "\n", 385 | "optimizer = torch.optim.AdamW(lats.parameters(), learning_rate, weight_decay=0.01, amsgrad=True)\n", 386 | "\n", 387 | "def save_img(img, fname=None):\n", 388 | " img = np.array(img)[:,:,:]\n", 389 | " img = np.transpose(img, (1,2,0))\n", 390 | " img = np.clip(img*255, 0, 255).astype(np.uint8)\n", 391 | " if fname is not None:\n", 392 | " imageio.imsave(fname, np.array(img))\n", 393 | " imageio.imsave('result.jpg', np.array(img))\n", 394 | "\n", 395 | "def checkout(num):\n", 396 | " with torch.no_grad():\n", 397 | " img = vqgan_image(model_vqgan, lats()).cpu().numpy()[0]\n", 398 | " save_img(img, os.path.join(tempdir, '%04d.jpg' % num))\n", 399 | " outpic.clear_output()\n", 400 | " with outpic:\n", 401 | " display(Image('result.jpg'))\n", 402 | "\n", 403 | "def train(i):\n", 404 | " loss = 0\n", 405 | " img_out = vqgan_image(model_vqgan, lats())\n", 406 | " img_sliced = slice_imgs([img_out], samples, modsize, trform_f, align, macro=0.4)[0]\n", 407 | " out_enc = model_clip.encode_image(img_sliced)\n", 408 | "\n", 409 | " if len(text) > 0: # input text\n", 410 | " loss += sign * torch.cosine_similarity(txt_enc, out_enc, dim=-1).mean()\n", 411 | " if len(style) > 0: # input text - style\n", 412 | " loss += sign * 0.5 * torch.cosine_similarity(txt_enc2, out_enc, dim=-1).mean()\n", 413 | " if len(subtract) > 0: # subtract text\n", 414 | " loss += -sign * 0.5 * torch.cosine_similarity(txt_enc0, out_enc, dim=-1).mean()\n", 415 | " if upload_image:\n", 416 | " loss += sign * 0.5 * torch.cosine_similarity(img_enc, out_enc, dim=-1).mean()\n", 417 | " if sync > 0 and upload_image: # image composition sync\n", 418 | " prog_sync = (steps - i) / steps\n", 419 | " loss += prog_sync * sync * 
sim_loss(F.interpolate(img_out, sim_size).float(), img_in, normalize=True).squeeze()\n", 420 | " del img_out, img_sliced, out_enc; torch.cuda.empty_cache()\n", 421 | "\n", 422 | " optimizer.zero_grad()\n", 423 | " loss.backward()\n", 424 | " optimizer.step()\n", 425 | "\n", 426 | " if i % save_freq == 0:\n", 427 | " checkout(i // save_freq)\n", 428 | "\n", 429 | "outpic = ipy.Output()\n", 430 | "outpic\n", 431 | "\n", 432 | "pbar = ProgressBar(steps)\n", 433 | "for i in range(steps):\n", 434 | " train(i)\n", 435 | " _ = pbar.upd()\n", 436 | "\n", 437 | "HTML(makevid(tempdir))\n", 438 | "torch.save(lats.lats, tempdir + '.pt')\n", 439 | "files.download(tempdir + '.pt')\n", 440 | "files.download(tempdir + '.mp4')\n" 441 | ], 442 | "execution_count": null, 443 | "outputs": [] 444 | } 445 | ] 446 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Vadim Epstein 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Aphantasia 2 | 3 |

4 | 5 | This is a collection of text-to-image tools, evolved from the [artwork] of the same name. 6 | Based on [CLIP] model and [Lucent] library, with FFT/DWT/RGB parameterizers (no-GAN generation). 7 | *Updated: Old depth estimation method is replaced with [Depth Anything 2].* 8 | Tested on Python 3.7-3.11 with PyTorch from 1.7.1 to 2.3.1. 9 | 10 | *[Aphantasia] is the inability to visualize mental images, the deprivation of visual dreams. 11 | The image in the header is generated by the tool from this word.* 12 | 13 | **Please be kind to mention this project, if you employ it for your masterpieces** 14 | 15 | ## Features 16 | * generating massive detailed textures, a la deepdream 17 | * fullHD/4K resolutions and above 18 | * various CLIP models 19 | * continuous mode to process phrase lists (e.g. illustrating lyrics) 20 | * pan/zoom motion with smooth interpolation 21 | * direct RGB pixels optimization (very stable) 22 | * 3D look, based on [Depth Anything 2] 23 | * complex queries: 24 | * text and/or image as main prompts 25 | * separate text prompts for style and to subtract (avoid) topics 26 | * starting/resuming process from saved parameters or from an image 27 | 28 | Setup [CLIP] et cetera: 29 | ``` 30 | pip install -r requirements.txt 31 | pip install git+https://github.com/openai/CLIP.git 32 | ``` 33 | 34 | ## Operations 35 | 36 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/eps696/aphantasia/blob/master/Aphantasia.ipynb) 37 | 38 | * Generate an image from the text prompt (set the size as you wish): 39 | ``` 40 | python clip_fft.py -t "the text" --size 1280-720 41 | ``` 42 | * Reproduce an image: 43 | ``` 44 | python clip_fft.py -i theimage.jpg --sync 0.4 45 | ``` 46 | If `--sync X` argument > 0, [LPIPS] loss is added to keep the composition similar to the original image. 47 | 48 | You can combine both text and image prompts. 49 | For non-English languages use `--translate` (Google translation). 50 | 51 | * Set more specific query like this: 52 | ``` 53 | python clip_fft.py -t "topic sentence" -t2 "style description" -t0 "avoid this" --size 1280-720 54 | ``` 55 | * Other options: 56 | Text inputs understand syntax with weights, like `good prompt :1 | also good prompt :1 | bad prompt :-0.5`. 57 | `--model M` selects one of the released CLIP visual models: `ViT-B/32` (default), `ViT-B/16`, `RN50`, `RN50x4`, `RN50x16`, `RN101`. 58 | One can also set `--dualmod` to use `ViT-B/32` and `ViT-B/16` at once (preferrable). 59 | `--dwt` switches to DWT (wavelets) generator instead of FFT. There are few methods, chosen by `--wave X`, e.g. `db2`, `db3`, `coif1`, `coif2`, etc. 60 | `--align XX` option is about composition (or sampling distribution, to be more precise): `uniform` is maybe the most adequate; `overscan` can make semi-seamless tileable textures. 61 | `--steps N` sets iterations count. 100-200 is enough for a starter; 500-1000 would elaborate it more thoroughly. 62 | `--samples N` sets amount of the image cuts (samples), processed at one step. With more samples you can set fewer iterations for similar result (and vice versa). 200/200 is a good guess. NB: GPU memory is mostly eaten by this count (not resolution)! 63 | `--aest X` enforces overall cuteness by employing [aesthetic loss](https://github.com/LAION-AI/aesthetic-predictor). try various values (may be negative). 
64 | `--decay X` (compositional softness), `--colors X` (saturation) and `--contrast X` may be useful, especially for ResNet models (they tend to burn the colors). 65 | `--sharp X` may be useful to increase sharpness, if the image becomes "myopic" after increasing `decay`. It affects the other color parameters, so better tweak them all together! 66 | Current defaults are `--decay 1.5 --colors 1.8 --contrast 1.1 --sharp 0`. 67 | `--transform X` applies some augmentations, usually enhancing the result (but slower). There are a few choices; `fast` seems optimal. 68 | `--optimizer` can be `adam`, `adamw`, `adam_custom` or `adamw_custom`. Custom options are noisier but stable; pure `adam` is softer, but may tend towards colored blurring. 69 | `--invert` negates the whole criteria, if you fancy checking "the totally opposite". 70 | `--save_pt myfile.pt` will save FFT/DWT parameters, to resume the next query with `--resume myfile.pt`. One can also start/resume directly from an image file. 71 | `--opt_step N` tells it to save every Nth frame (useful with high iteration counts; default is 1). 72 | `--verbose` ('on' by default) enables some printouts and a realtime image preview. 73 | * Some experimental tricks with less definite effects: 74 | `--enforce X` adds more details by boosting similarity between two parallel samples. A good start is ~0.1. 75 | `--expand X` boosts diversity by enforcing difference between prev/next samples. A good start is ~0.3. 76 | `--noise X` adds some noise to the parameters, possibly making the composition less clogged (to a degree). 77 | `--macro X` (from 0 to 1) shifts generation towards bigger forms and a less dispersed composition. It should not be too close to 1, since the quality depends on the variety of samples. 78 | `--prog` sets a progressive learning rate (from 0.1x to 2x of the one set by `lrate`). It may boost macro-form creation in some cases (see more [here](https://github.com/eps696/aphantasia/issues/2)). 79 | `--lrate` controls the learning rate. The range is quite wide (tested at least within 0.001 to 10). 80 | 81 | ## Text-to-video [continuous mode] 82 | 83 | Here are two ways of making a video from text file(s), processing them line by line in one shot. 84 | 85 | ### Illustrip 86 | 87 | A newer method, interpolating topics as a constant flow with permanent pan/zoom motion and an optional 3D look. 88 | 89 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/eps696/aphantasia/blob/master/IllusTrip3D.ipynb) 90 | 91 | * Make a video from two text files, processing them line by line, rendering 100 frames per line: 92 | ``` 93 | python illustrip.py --in_txt mycontent.txt --in_txt2 mystyles.txt --size 1280-720 --steps 100 94 | ``` 95 | * Make a video from two phrases, with a total length of 500 frames: 96 | ``` 97 | python illustrip.py --in_txt "my super content" --in_txt2 "my super style" --size 1280-720 --steps 500 98 | ``` 99 | Prefixes (`-pre`), postfixes (`-post`) and "stop words" (`--in_txt0`) may be loaded as phrases or text files as well. 100 | All text inputs understand syntax with weights, like `good prompt :1 | also good prompt :1 | bad prompt :-0.5` (within one line). 101 | One can also use image(s) as references with the `--in_img` argument. Explore other arguments for more explicit control. 102 | This method works best with direct RGB pixel optimization, but can also be used with FFT parameterization: 103 | ``` 104 | python illustrip.py ... 
--gen FFT --smooth --align uniform --colors 1.8 --contrast 1.1 105 | ``` 106 | 107 | To add a 3D look, add `--depth 0.01` to the command. 108 | 109 | ### Illustra 110 | 111 | Generates separate images for every text line (with sequences and training videos, as in the single-image mode above), then renders the final video of duration `length` (in seconds) from those images, mixing them in FFT space. 112 | 113 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/eps696/aphantasia/blob/master/Illustra.ipynb) 114 | 115 | * Make a video from a text file, processing it line by line: 116 | ``` 117 | python illustra.py -t mysong.txt --size 1280-720 --length 155 118 | ``` 119 | There is a `--keep X` parameter, controlling how closely the next line/image generation follows the previous one. 0 means it is randomly initiated; the higher the value, the more strictly it keeps the original composition. Safe values are 1~2 (much higher numbers may cause the imagery to get stuck). 120 | 121 | * Make a video from a directory with saved *.pt snapshots (just interpolate them): 122 | ``` 123 | python interpol.py -i mydir --length 155 124 | ``` 125 | 126 | ## Other generators 127 | 128 | * VQGAN from [Taming Transformers](https://github.com/CompVis/taming-transformers) 129 | One of the best methods for colors/tones/details (especially with the newer Gumbel-F8 model); it has quite limited resolution though (~800x600 max on Colab). 130 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/eps696/aphantasia/blob/master/CLIP_VQGAN.ipynb) 131 |

132 | 133 | * CPPN + [export to HLSL shaders](https://github.com/wxs/cppn-to-glsl) 134 | One of the very first methods, with exports for TouchDesigner, vvvv, Shadertoy, etc. 135 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Kbbbwoet3igHPJ4KpNh8z3V-RxtstAcz) 136 | ``` 137 | python cppn.py -v -t "the text" --aest 0.5 138 | ``` 139 | 140 | * SIREN + [Fourier feature modulation](https://github.com/tancik/fourier-feature-networks) 141 | Another early method, not so interesting on its own. 142 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1L14q4To5rMK8q2E6whOibQBnPnVbRJ_7) 143 |

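Under the hood, all of these generators share the same CLIP-guidance loop: render an image from some learnable parameters, cut a batch of random crops, embed them with CLIP, and push those embeddings towards the encoded text prompt. Below is a minimal sketch of that loop using the FFT parameterizer, pieced together from the `aphantasia` helpers above; it is not the exact code of `clip_fft.py`, and the prompt, resolution and hyperparameters are arbitrary examples (style/subtract prompts and LPIPS sync are omitted).
```
import torch, clip
from aphantasia.image import fft_image, to_valid_rgb
from aphantasia.utils import slice_imgs
from aphantasia import transforms

model_clip, _ = clip.load('ViT-B/32')                      # any released CLIP visual model
txt_enc = model_clip.encode_text(clip.tokenize("the text").cuda()).detach()

params, image_f, _ = fft_image([1, 3, 720, 1280])          # learnable spectrum -> image function
image_f = to_valid_rgb(image_f, colors=1.8)                # color decorrelation + sigmoid
optimizer = torch.optim.Adam(params, 0.05)

for step in range(200):
    img = image_f(contrast=1.1)                            # render the current image
    crops = slice_imgs([img], 200, model_clip.visual.input_resolution,
                       transforms.normalize(), align='uniform')[0]
    img_enc = model_clip.encode_image(crops)               # embed the random crops
    loss = -torch.cosine_similarity(txt_enc, img_enc, dim=-1).mean()
    optimizer.zero_grad(); loss.backward(); optimizer.step()
```
Style and "subtract" prompts simply add further positive or negative cosine terms to the same loss, which is essentially what the `train()` cell of the VQGAN notebook does with its latents.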
144 | 145 | ## Credits 146 | 147 | Based on [CLIP] model by OpenAI ([paper]). 148 | FFT encoding is taken from [Lucent] library, 3D depth processing made by [deKxi]. 149 | 150 | Thanks to [Ryan Murdock], [Jonathan Fly], [Hannu Toyryla], [@eduwatch2], [torridgristle] for ideas. 151 | 152 |

153 | 154 | [artwork]: 155 | [Aphantasia]: 156 | [CLIP]: 157 | [SBERT]: 158 | [Lucent]: 159 | [Depth Anything 2]: 160 | [LPIPS]: 161 | [Taming Transformers]: 162 | [Ryan Murdock]: 163 | [Jonathan Fly]: 164 | [Hannu Toyryla]: 165 | [@eduwatch2]: 166 | [torridgristle]: 167 | [deKxi]: 168 | [paper]: 169 | -------------------------------------------------------------------------------- /_out/Aphantasia.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/aphantasia/8a415286d2891e92d865150d6e0e59fdfd32fb01/_out/Aphantasia.jpg -------------------------------------------------------------------------------- /_out/Aphantasia2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/aphantasia/8a415286d2891e92d865150d6e0e59fdfd32fb01/_out/Aphantasia2.jpg -------------------------------------------------------------------------------- /_out/Aphantasia3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/aphantasia/8a415286d2891e92d865150d6e0e59fdfd32fb01/_out/Aphantasia3.jpg -------------------------------------------------------------------------------- /_out/Aphantasia4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/aphantasia/8a415286d2891e92d865150d6e0e59fdfd32fb01/_out/Aphantasia4.jpg -------------------------------------------------------------------------------- /_out/some_cute_image-FFT.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/aphantasia/8a415286d2891e92d865150d6e0e59fdfd32fb01/_out/some_cute_image-FFT.jpg -------------------------------------------------------------------------------- /_out/some_cute_image-SIREN.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/aphantasia/8a415286d2891e92d865150d6e0e59fdfd32fb01/_out/some_cute_image-SIREN.jpg -------------------------------------------------------------------------------- /_out/some_cute_image-VQGAN.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/aphantasia/8a415286d2891e92d865150d6e0e59fdfd32fb01/_out/some_cute_image-VQGAN.jpg -------------------------------------------------------------------------------- /aphantasia/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/aphantasia/8a415286d2891e92d865150d6e0e59fdfd32fb01/aphantasia/__init__.py -------------------------------------------------------------------------------- /aphantasia/image.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from imageio import imread 4 | 5 | import pywt 6 | from pytorch_wavelets import DWTForward, DWTInverse 7 | # from pytorch_wavelets import DTCWTForward, DTCWTInverse 8 | 9 | import torch 10 | 11 | from aphantasia.utils import slice_imgs, derivat, sim_func, basename, img_list, img_read, plot_text, old_torch 12 | from aphantasia.transforms import normalize 13 | 14 | def to_valid_rgb(image_f, colors=1., decorrelate=True): 15 | color_correlation_svd_sqrt = torch.tensor([[0.26, 0.09, 0.02], [0.27, 0.00, -0.05], [0.27, -0.09, 0.03]]) 16 | color_correlation_svd_sqrt /= 
torch.tensor([colors, 1., 1.]) # saturate, empirical 17 | max_norm_svd_sqrt = color_correlation_svd_sqrt.norm(dim=0).max() 18 | color_correlation_normalized = color_correlation_svd_sqrt / max_norm_svd_sqrt 19 | colcorr_t = color_correlation_normalized.T.cuda() 20 | 21 | def _linear_decorrelate_color(image): 22 | return torch.einsum('nchw,cd->ndhw', image, colcorr_t) # edit by katherine crowson 23 | 24 | def inner(*args, **kwargs): 25 | image = image_f(*args, **kwargs) 26 | if decorrelate: 27 | image = _linear_decorrelate_color(image) 28 | return torch.sigmoid(image) 29 | return inner 30 | 31 | ### DWT [wavelets] 32 | 33 | def init_dwt(resume=None, shape=None, wave=None, colors=None): 34 | size = None 35 | wp_fake = pywt.WaveletPacket2D(data=np.zeros(shape[2:]), wavelet='db1', mode='symmetric') 36 | xfm = DWTForward(J=wp_fake.maxlevel, wave=wave, mode='symmetric').cuda() 37 | # xfm = DTCWTForward(J=lvl, biort='near_sym_b', qshift='qshift_b').cuda() # 4x more params, biort ['antonini','legall','near_sym_a','near_sym_b'] 38 | ifm = DWTInverse(wave=wave, mode='symmetric').cuda() # symmetric zero periodization 39 | # ifm = DTCWTInverse(biort='near_sym_b', qshift='qshift_b').cuda() # 4x more params, biort ['antonini','legall','near_sym_a','near_sym_b'] 40 | if resume is None: # random init 41 | Yl_in, Yh_in = xfm(torch.zeros(shape).cuda()) 42 | Ys = [torch.randn(*Y.shape).cuda() for Y in [Yl_in, *Yh_in]] 43 | elif isinstance(resume, str): 44 | if os.path.isfile(resume): 45 | if os.path.splitext(resume)[1].lower()[1:] in ['jpg','png','tif','bmp']: 46 | img_in = imread(resume) 47 | Ys = img2dwt(img_in, wave=wave, colors=colors) 48 | print(' loaded image', resume, img_in.shape, 'level', len(Ys)-1) 49 | size = img_in.shape[:2] 50 | wp_fake = pywt.WaveletPacket2D(data=np.zeros(size), wavelet='db1', mode='symmetric') 51 | xfm = DWTForward(J=wp_fake.maxlevel, wave=wave, mode='symmetric').cuda() 52 | else: 53 | Ys = torch.load(resume) 54 | Ys = [y.detach().cuda() for y in Ys] 55 | else: print(' Snapshot not found:', resume); exit() 56 | else: 57 | Ys = [y.cuda() for y in resume] 58 | # print('level', len(Ys)-1, 'low freq', Ys[0].cpu().numpy().shape) 59 | return Ys, xfm, ifm, size 60 | 61 | def dwt_image(shape, wave='coif2', sharp=0.3, colors=1., resume=None): 62 | Ys, _, ifm, size = init_dwt(resume, shape, wave, colors) 63 | Ys = [y.requires_grad_(True) for y in Ys] 64 | scale = dwt_scale(Ys, sharp) 65 | 66 | def inner(shift=None, contrast=1.): 67 | image = ifm((Ys[0], [Ys[i+1] * float(scale[i]) for i in range(len(Ys)-1)])) 68 | image = image * contrast / image.std() # keep contrast, empirical *1.33 69 | return image 70 | 71 | return Ys, inner, size 72 | 73 | def dwt_scale(Ys, sharp): 74 | scale = [] 75 | [h0,w0] = Ys[1].shape[3:5] 76 | for i in range(len(Ys)-1): 77 | [h,w] = Ys[i+1].shape[3:5] 78 | scale.append( ((h0*w0)/(h*w)) ** (1.-sharp) ) 79 | # print(i+1, Ys[i+1].shape) 80 | return scale 81 | 82 | def img2dwt(img_in, wave='coif2', sharp=0.3, colors=1.): 83 | image_t = un_rgb(img_in, colors=colors) 84 | with torch.no_grad(): 85 | wp_fake = pywt.WaveletPacket2D(data=np.zeros(image_t.shape[2:]), wavelet='db1', mode='zero') 86 | lvl = wp_fake.maxlevel 87 | # print(image_t.shape, lvl) 88 | xfm = DWTForward(J=lvl, wave=wave, mode='symmetric').cuda() 89 | Yl_in, Yh_in = xfm(image_t.cuda()) 90 | Ys = [Yl_in, *Yh_in] 91 | scale = dwt_scale(Ys, sharp) 92 | for i in range(len(Ys)-1): 93 | Ys[i+1] /= scale[i] 94 | return Ys 95 | 96 | ### FFT/RGB from Lucent library ### https://github.com/greentfrapp/lucent 
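# [added note] Quick summary of the parameterizers below (not part of the original comments):
# - pixel_image() optimizes an RGB tensor directly; when resuming from an image file, the pixels are
#   passed through un_rgb(), so that to_valid_rgb() (color decorrelation + sigmoid) maps the
#   parameters back to roughly the source image.
# - fft_image() optimizes a real/imaginary spectrum of shape [1, 3, H, W//2+1, 2]; each frequency is
#   scaled by ~1/f**decay_power (higher decay -> stronger low frequencies -> softer composition,
#   cf. --decay in the README), then the inverse real FFT renders the image, which is divided by its
#   std to keep contrast. img2fft() is the reverse path, used to resume from an existing image.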
97 | 98 | def pixel_image(shape, resume=None, sd=1., *noargs, **nokwargs): 99 | size = None 100 | if resume is None: 101 | image_t = torch.randn(*shape) * sd 102 | elif isinstance(resume, str): 103 | if os.path.isfile(resume): 104 | img_in = img_read(resume) 105 | image_t = 3.3 * un_rgb(img_in, colors=2.) 106 | size = img_in.shape[:2] 107 | print(resume, size) 108 | else: print(' Image not found:', resume); exit() 109 | else: 110 | if isinstance(resume, list): resume = resume[0] 111 | image_t = resume 112 | image_t = image_t.cuda().requires_grad_(True) 113 | 114 | def inner(shift=None, contrast=1., fixcontrast=False): # *noargs, **nokwargs 115 | if fixcontrast is True: # for resuming from image 116 | return image_t * contrast / 3.3 117 | else: 118 | return image_t * contrast / image_t.std() 119 | return [image_t], inner, size # lambda: image_t 120 | 121 | # From https://github.com/tensorflow/lucid/blob/master/lucid/optvis/param/spatial.py 122 | def rfft2d_freqs(h, w): 123 | """Computes 2D spectrum frequencies.""" 124 | fy = np.fft.fftfreq(h)[:, None] 125 | # when we have an odd input dimension we need to keep one additional frequency and later cut off 1 pixel 126 | w2 = (w+1)//2 if w%2 == 1 else w//2+1 127 | fx = np.fft.fftfreq(w)[:w2] 128 | return np.sqrt(fx * fx + fy * fy) 129 | 130 | def resume_fft(resume=None, shape=None, decay=None, colors=1.6, sd=0.01): 131 | size = None 132 | if resume is None: # random init 133 | params_shape = [*shape[:3], shape[3]//2+1, 2] # [1,3,512,257,2] for 512x512 (2 for imaginary and real components) 134 | params = 0.01 * torch.randn(*params_shape).cuda() 135 | elif isinstance(resume, str): 136 | if os.path.isfile(resume): 137 | if os.path.splitext(resume)[1].lower()[1:] in ['jpg','png','tif','bmp']: 138 | img_in = img_read(resume) 139 | params = img2fft(img_in, decay, colors) 140 | size = img_in.shape[:2] 141 | else: 142 | params = torch.load(resume) 143 | if isinstance(params, list): params = params[0] 144 | params = params.detach().cuda() 145 | params *= sd 146 | else: print(' Snapshot not found:', resume); exit() 147 | else: 148 | if isinstance(resume, list): resume = resume[0] 149 | params = resume.cuda() 150 | return params, size 151 | 152 | def fft_image(shape, sd=0.01, decay_power=1.0, resume=None): # decay ~ blur 153 | 154 | params, size = resume_fft(resume, shape, decay_power, sd=sd) 155 | spectrum_real_imag_t = params.requires_grad_(True) 156 | if size is not None: shape[2:] = size 157 | [h,w] = list(shape[2:]) 158 | 159 | freqs = rfft2d_freqs(h,w) 160 | scale = 1. 
/ np.maximum(freqs, 4./max(h,w)) ** decay_power 161 | scale *= np.sqrt(h*w) 162 | scale = torch.tensor(scale).float()[None, None, ..., None].cuda() 163 | 164 | def inner(shift=None, contrast=1., *noargs, **nokwargs): 165 | scaled_spectrum_t = scale * spectrum_real_imag_t 166 | if shift is not None: 167 | scaled_spectrum_t += scale * shift 168 | if old_torch(): 169 | image = torch.irfft(scaled_spectrum_t, 2, normalized=True, signal_sizes=(h, w)) 170 | else: 171 | if type(scaled_spectrum_t) is not torch.complex64: 172 | scaled_spectrum_t = torch.view_as_complex(scaled_spectrum_t) 173 | image = torch.fft.irfftn(scaled_spectrum_t, s=(h, w), norm='ortho') 174 | image = image * contrast / image.std() # keep contrast, empirical 175 | return image 176 | 177 | return [spectrum_real_imag_t], inner, size 178 | 179 | def inv_sigmoid(x): 180 | eps = 1.e-12 181 | x = torch.clamp(x.double(), eps, 1-eps) 182 | y = torch.log(x/(1-x)) 183 | return y.float() 184 | 185 | def un_rgb(image, colors=1.): 186 | color_correlation_svd_sqrt = torch.tensor([[0.26, 0.09, 0.02], [0.27, 0.00, -0.05], [0.27, -0.09, 0.03]]) 187 | color_correlation_svd_sqrt /= torch.tensor([colors, 1., 1.]) # saturate, empirical 188 | max_norm_svd_sqrt = color_correlation_svd_sqrt.norm(dim=0).max() 189 | color_correlation_normalized = color_correlation_svd_sqrt / max_norm_svd_sqrt 190 | colcorr_t = color_correlation_normalized.T.cuda() 191 | colcorr_t_inv = torch.linalg.inv(colcorr_t) 192 | 193 | if not isinstance(image, torch.Tensor): # numpy int array [0..255] 194 | image = torch.Tensor(image).cuda().permute(2,0,1).unsqueeze(0) / 255. 195 | # image = inv_sigmoid(image) 196 | image = normalize()(image) # experimental 197 | return torch.einsum('nchw,cd->ndhw', image, colcorr_t_inv) # edit by katherine crowson 198 | 199 | def un_spectrum(spectrum, decay_power): 200 | h = spectrum.shape[2] 201 | w = (spectrum.shape[3]-1)*2 202 | freqs = rfft2d_freqs(h, w) 203 | scale = 1.0 / np.maximum(freqs, 1.0 / max(w, h)) ** decay_power 204 | scale *= np.sqrt(w*h) 205 | scale = torch.tensor(scale).float()[None, None, ..., None].cuda() 206 | return spectrum / scale 207 | 208 | def img2fft(img_in, decay=1., colors=1.): 209 | image_t = un_rgb(img_in, colors=colors) 210 | h, w = image_t.shape[2], image_t.shape[3] 211 | 212 | with torch.no_grad(): 213 | if old_torch(): 214 | spectrum = torch.rfft(image_t, 2, normalized=True) # 1.7 215 | else: 216 | spectrum = torch.fft.rfftn(image_t, s=(h, w), dim=[2,3], norm='ortho') # 1.8 217 | spectrum = torch.view_as_real(spectrum) 218 | spectrum = un_spectrum(spectrum, decay_power=decay) 219 | spectrum *= 500000. # [sic!!!] 
220 | return spectrum 221 | -------------------------------------------------------------------------------- /aphantasia/interpol.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | warnings.filterwarnings("ignore") 4 | import argparse 5 | import math 6 | import numpy as np 7 | 8 | import torch 9 | 10 | from clip_fft import to_valid_rgb, fft_image 11 | from aphantasia.utils import basename, file_list, checkout 12 | try: # progress bar for notebooks 13 | get_ipython().__class__.__name__ 14 | from aphantasia.progress_bar import ProgressIPy as ProgressBar 15 | except: # normal console 16 | from aphantasia.progress_bar import ProgressBar 17 | 18 | def get_args(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('-i', '--in_dir', default='pt') 21 | parser.add_argument('-o', '--out_dir', default='_out') 22 | parser.add_argument('-l', '--length', default=None, type=int, help='Total length in sec') 23 | parser.add_argument('-s', '--steps', default=25, type=int, help='Override length') 24 | parser.add_argument( '--fps', default=25, type=int) 25 | parser.add_argument( '--contrast', default=1.1, type=float) 26 | parser.add_argument( '--colors', default=1.8, type=float) 27 | parser.add_argument('-d', '--decay', default=1.5, type=float) 28 | parser.add_argument('-v', '--verbose', default=True, type=bool) 29 | a = parser.parse_args() 30 | return a 31 | 32 | def read_pt(file): 33 | return torch.load(file)[0].cuda() 34 | 35 | def main(): 36 | a = get_args() 37 | tempdir = os.path.join(a.out_dir, 'a') 38 | os.makedirs(tempdir, exist_ok=True) 39 | 40 | ptfiles = file_list(a.in_dir, 'pt') 41 | 42 | ptest = torch.load(ptfiles[0]) 43 | if isinstance(ptest, list): ptest = ptest[0] 44 | shape = [*ptest.shape[:3], (ptest.shape[3]-1)*2] 45 | 46 | vsteps = a.lsteps if a.length is None else int(a.length * a.fps / count) 47 | pbar = ProgressBar(vsteps * len(ptfiles)) 48 | for px in range(len(ptfiles)): 49 | params1 = read_pt(ptfiles[px]) 50 | params2 = read_pt(ptfiles[(px+1) % len(ptfiles)]) 51 | 52 | params, image_f, _ = fft_image(shape, resume=params1, sd=1., decay_power=a.decay) 53 | image_f = to_valid_rgb(image_f, colors = a.colors) 54 | 55 | for i in range(vsteps): 56 | with torch.no_grad(): 57 | x = i/vsteps # math.sin(1.5708 * i/vsteps) 58 | img = image_f((params2 - params1) * x, contrast=a.contrast).cpu().numpy()[0] 59 | checkout(img, os.path.join(tempdir, '%05d.jpg' % (px * vsteps + i)), verbose=a.verbose) 60 | pbar.upd() 61 | 62 | os.system('ffmpeg -v warning -y -i %s/\%%05d.jpg "%s-pts.mp4"' % (tempdir, a.in_dir)) 63 | 64 | 65 | if __name__ == '__main__': 66 | main() 67 | -------------------------------------------------------------------------------- /aphantasia/progress_bar.py: -------------------------------------------------------------------------------- 1 | """ 2 | from progress_bar import ProgressBar 3 | 4 | pbar = ProgressBar(steps) 5 | pbar.upd() 6 | """ 7 | 8 | import os 9 | import sys 10 | import math 11 | os.system('') #enable VT100 Escape Sequence for WINDOWS 10 Ver. 
1607 12 | 13 | from shutil import get_terminal_size 14 | import time 15 | 16 | import ipywidgets as ipy 17 | import IPython 18 | class ProgressIPy(object): 19 | def __init__(self, task_num=10): 20 | self.pbar = ipy.IntProgress(min=0, max=task_num, bar_style='') # (value=0, min=0, max=max, step=1, description=description, bar_style='') 21 | self.labl = ipy.Label() 22 | IPython.display.display(ipy.HBox([self.pbar, self.labl])) 23 | self.task_num = task_num 24 | self.completed = 0 25 | self.start() 26 | 27 | def start(self, task_num=None): 28 | if task_num is not None: 29 | self.task_num = task_num 30 | if self.task_num > 0: 31 | self.labl.value = '0/{}'.format(self.task_num) 32 | else: 33 | self.labl.value = 'completed: 0, elapsed: 0s' 34 | self.start_time = time.time() 35 | 36 | def upd(self, *p, **kw): 37 | self.completed += 1 38 | elapsed = time.time() - self.start_time + 0.0000000000001 39 | fps = self.completed / elapsed if elapsed>0 else 0 40 | if self.task_num > 0: 41 | finaltime = time.asctime(time.localtime(self.start_time + self.task_num * elapsed / float(self.completed))) 42 | fin = ' end %s' % finaltime[11:16] 43 | percentage = self.completed / float(self.task_num) 44 | eta = int(elapsed * (1 - percentage) / percentage + 0.5) 45 | self.labl.value = '{}/{}, rate {:.3g}s, time {}s, left {}s, {}'.format(self.completed, self.task_num, 1./fps, shortime(elapsed), shortime(eta), fin) 46 | else: 47 | self.labl.value = 'completed {}, time {}s, {:.1f} steps/s'.format(self.completed, int(elapsed + 0.5), fps) 48 | self.pbar.value += 1 49 | if self.completed == self.task_num: self.pbar.bar_style = 'success' 50 | return self.completed 51 | 52 | 53 | class ProgressBar(object): 54 | '''A progress bar which can print the progress 55 | modified from https://github.com/hellock/cvbase/blob/master/cvbase/progress.py 56 | ''' 57 | def __init__(self, task_num=0, bar_width=50, start=True): 58 | self.task_num = task_num 59 | max_bar_width = self._get_max_bar_width() 60 | self.bar_width = (bar_width if bar_width <= max_bar_width else max_bar_width) 61 | self.completed = 0 62 | if start: 63 | self.start() 64 | 65 | def _get_max_bar_width(self): 66 | terminal_width, _ = get_terminal_size() 67 | max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50) 68 | if max_bar_width < 10: 69 | print('terminal is small ({}), make it bigger for proper visualization'.format(terminal_width)) 70 | max_bar_width = 10 71 | return max_bar_width 72 | 73 | def start(self, task_num=None): 74 | if task_num is not None: 75 | self.task_num = task_num 76 | if self.task_num > 0: 77 | sys.stdout.write('[{}] 0/{}, elapsed: 0s, ETA:\n{}\n'.format(' ' * self.bar_width, self.task_num, 'Start...')) 78 | else: 79 | sys.stdout.write('completed: 0, elapsed: 0s') 80 | sys.stdout.flush() 81 | self.start_time = time.time() 82 | 83 | def upd(self, msg=None): 84 | self.completed += 1 85 | elapsed = time.time() - self.start_time + 0.0000000000001 86 | fps = self.completed / elapsed if elapsed>0 else 0 87 | if self.task_num > 0: 88 | percentage = self.completed / float(self.task_num) 89 | eta = int(elapsed * (1 - percentage) / percentage + 0.5) 90 | finaltime = time.asctime(time.localtime(self.start_time + self.task_num * elapsed / float(self.completed))) 91 | fin_msg = ' %ss left, end %s' % (shortime(eta), finaltime[11:16]) 92 | if msg is not None: fin_msg += ' ' + str(msg) 93 | mark_width = int(self.bar_width * percentage) 94 | bar_chars = 'X' * mark_width + '-' * (self.bar_width - mark_width) # ▒ ▓ █ 95 | sys.stdout.write('\033[2A') # 
cursor up 2 lines 96 | sys.stdout.write('\033[J') # clean the output (remove extra chars since last display) 97 | try: 98 | sys.stdout.write('[{}] {}/{}, rate {:.3g}s, time {}s, left {}s \n{}\n'.format( 99 | bar_chars, self.completed, self.task_num, 1./fps, shortime(elapsed), shortime(eta), fin_msg)) 100 | except: 101 | sys.stdout.write('[{}] {}/{}, rate {:.3g}s, time {}s, left {}s \n{}\n'.format( 102 | bar_chars, self.completed, self.task_num, 1./fps, shortime(elapsed), shortime(eta), '<< unprintable >>')) 103 | else: 104 | sys.stdout.write('completed {}, time {}s, {:.1f} steps/s'.format(self.completed, int(elapsed + 0.5), fps)) 105 | sys.stdout.flush() 106 | 107 | def reset(self, count=None, newline=False): 108 | self.start_time = time.time() 109 | if count is not None: 110 | self.task_num = count 111 | if newline is True: 112 | sys.stdout.write('\n\n') 113 | 114 | def time_days(sec): 115 | return '%dd %d:%02d:%02d' % (sec/86400, (sec/3600)%24, (sec/60)%60, sec%60) 116 | def time_hrs(sec): 117 | return '%d:%02d:%02d' % (sec/3600, (sec/60)%60, sec%60) 118 | def shortime(sec): 119 | if sec < 60: 120 | time_short = '%d' % (sec) 121 | elif sec < 3600: 122 | time_short = '%d:%02d' % ((sec/60)%60, sec%60) 123 | elif sec < 86400: 124 | time_short = time_hrs(sec) 125 | else: 126 | time_short = time_days(sec) 127 | return time_short 128 | 129 | -------------------------------------------------------------------------------- /aphantasia/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Lucent Authors. All Rights Reserved. 2 | # http://www.apache.org/licenses/LICENSE-2.0 3 | 4 | import numpy as np 5 | import PIL 6 | import kornia 7 | import kornia.geometry.transform as K 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from torchvision import transforms as T 12 | 13 | from .utils import old_torch 14 | 15 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 16 | 17 | def random_elastic(): 18 | def inner(x): 19 | a = np.random.rand(2) 20 | k = np.random.randint(8,64) * 2 + 1 # 63 21 | s = k / (np.random.rand()+2.) 
# 2-3 times less than k 22 | # s = float(np.random.randint(8,64)) # 32 23 | noise = torch.zeros([1, 2, x.shape[2], x.shape[3]]).cuda() 24 | return K.elastic_transform2d(x, noise, (k,k), (s,s), tuple(a)) 25 | return inner 26 | 27 | def jitter(d): 28 | assert d > 1, "Jitter parameter d must be more than 1, currently {}".format(d) 29 | def inner(image_t): 30 | dx = np.random.choice(d) 31 | dy = np.random.choice(d) 32 | return K.translate(image_t, torch.tensor([[dx, dy]]).float().to(device)) 33 | return inner 34 | 35 | def pad(w, mode="reflect", constant_value=0.5): 36 | if mode != "constant": 37 | constant_value = 0 38 | def inner(image_t): 39 | return F.pad(image_t, [w] * 4, mode=mode, value=constant_value,) 40 | return inner 41 | 42 | def random_scale(scales): 43 | def inner(image_t): 44 | scale = np.random.choice(scales) 45 | shp = image_t.shape[2:] 46 | scale_shape = [_roundup(scale * d) for d in shp] 47 | pad_x = max(0, _roundup((shp[1] - scale_shape[1]) / 2)) 48 | pad_y = max(0, _roundup((shp[0] - scale_shape[0]) / 2)) 49 | upsample = torch.nn.Upsample(size=scale_shape, mode="bilinear", align_corners=True) 50 | return F.pad(upsample(image_t), [pad_y, pad_x] * 2) 51 | return inner 52 | 53 | def random_rotate(angles, units="degrees"): 54 | def inner(image_t): 55 | b, _, h, w = image_t.shape 56 | # kornia takes degrees 57 | alpha = _rads2angle(np.random.choice(angles), units) 58 | angle = torch.ones(b) * alpha 59 | # scale = torch.ones(b) 60 | scale = torch.ones(b, 2) 61 | center = torch.ones(b, 2) 62 | center[..., 0] = (image_t.shape[3] - 1) / 2 63 | center[..., 1] = (image_t.shape[2] - 1) / 2 64 | try: 65 | M = kornia.geometry.transform.get_rotation_matrix2d(center, angle, scale).to(device) 66 | rotated_image = kornia.geometry.transform.warp_affine(image_t.float(), M, dsize=(h, w)) 67 | except: 68 | M = kornia.get_rotation_matrix2d(center, angle, scale).to(device) 69 | rotated_image = kornia.warp_affine(image_t.float(), M, dsize=(h, w)) 70 | return rotated_image 71 | return inner 72 | 73 | def random_rotate_fast(angles): 74 | def inner(img): 75 | angle = float(np.random.choice(angles)) 76 | size = img.shape[-2:] 77 | if old_torch(): # 1.7.1 78 | img = T.functional.affine(img, angle, [0,0], 1, 0, fillcolor=0, resample=PIL.Image.BILINEAR) 79 | else: # 1.8+ 80 | img = T.functional.affine(img, angle, [0,0], 1, 0, fill=0, interpolation=T.InterpolationMode.BILINEAR) 81 | img = T.functional.center_crop(img, size) # on 1.8+ also pads 82 | return img 83 | return inner 84 | 85 | def compose(transforms): 86 | def inner(x): 87 | for transform in transforms: 88 | x = transform(x) 89 | return x 90 | return inner 91 | 92 | def _roundup(value): 93 | return np.ceil(value).astype(int) 94 | 95 | def _rads2angle(angle, units): 96 | if units.lower() == "degrees": 97 | return angle 98 | if units.lower() in ["radians", "rads", "rad"]: 99 | angle = angle * 180.0 / np.pi 100 | return angle 101 | 102 | def normalize(): 103 | # ImageNet normalization for torchvision models 104 | # see https://pytorch.org/docs/stable/torchvision/models.html 105 | # normal = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 106 | normal = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 107 | def inner(image_t): 108 | return torch.stack([normal(t) for t in image_t]) 109 | return inner 110 | 111 | def preprocess_inceptionv1(): 112 | # Original Tensorflow's InceptionV1 model takes in [-117, 138] 113 | # See 
https://github.com/tensorflow/lucid/blob/master/lucid/modelzoo/other_models/InceptionV1.py#L56 114 | # Thanks to ProGamerGov for this! 115 | return lambda x: x * 255 - 117 116 | 117 | # from lucent 118 | transforms_lucent = compose([ 119 | pad(12, mode="constant", constant_value=0.5), 120 | jitter(8), 121 | random_scale([1 + (i - 5) / 50.0 for i in range(11)]), 122 | random_rotate(list(range(-10, 11)) + 5 * [0]), 123 | jitter(4), 124 | ]) 125 | 126 | # from openai 127 | transforms_openai = compose([ 128 | pad(2, mode='constant', constant_value=.5), 129 | jitter(4), 130 | jitter(4), 131 | jitter(4), 132 | jitter(4), 133 | jitter(4), 134 | jitter(4), 135 | jitter(4), 136 | jitter(4), 137 | jitter(4), 138 | jitter(4), 139 | # random_scale([0.995**n for n in range(-5,80)] + [0.998**n for n in 2*list(range(20,40))]), 140 | random_rotate(list(range(-20,20))+list(range(-10,10))+list(range(-5,5))+5*[0]), 141 | jitter(2), 142 | # crop_or_pad_to(resolution, resolution) 143 | ]) 144 | 145 | # my compos 146 | 147 | transforms_elastic = compose([ 148 | pad(4, mode="constant", constant_value=0.5), 149 | T.RandomErasing(0.2), 150 | random_rotate(list(range(-30, 30)) + 20 * [0]), 151 | random_elastic(), 152 | jitter(8), 153 | normalize() 154 | ]) 155 | 156 | transforms_custom = compose([ 157 | pad(4, mode="constant", constant_value=0.5), 158 | # T.RandomPerspective(0.33, 0.2), 159 | # T.RandomErasing(0.2), 160 | random_rotate(list(range(-30, 30)) + 20 * [0]), 161 | jitter(8), 162 | normalize() 163 | ]) 164 | 165 | transforms_fast = compose([ 166 | T.RandomPerspective(0.33, 0.2), 167 | T.RandomErasing(0.2), 168 | random_rotate_fast(list(range(-30, 30)) + 20 * [0]), 169 | normalize() 170 | ]) 171 | 172 | -------------------------------------------------------------------------------- /aphantasia/utils.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import os 3 | import math 4 | import time 5 | from imageio import imread, imsave 6 | import cv2 7 | import numpy as np 8 | import collections 9 | import scipy 10 | from scipy.ndimage import gaussian_filter 11 | from scipy.interpolate import CubicSpline as CubSpline 12 | import matplotlib.pyplot as plt 13 | from kornia.filters.sobel import spatial_gradient 14 | 15 | import torch 16 | import torch.nn.functional as F 17 | 18 | def plot_text(txt, size=224): 19 | fig = plt.figure(figsize=(1,1), dpi=size) 20 | fontsize = size//len(txt) if len(txt) < 15 else 8 21 | plt.text(0.5, 0.5, txt, fontsize=fontsize, ha='center', va='center', wrap=True) 22 | plt.axis('off') 23 | fig.tight_layout(pad=0) 24 | fig.canvas.draw() 25 | img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 26 | img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 27 | return img 28 | 29 | def txt_clean(txt): 30 | return txt.translate(str.maketrans(dict.fromkeys(list("\n',.—|!?/:;\\"), ""))).replace(' ', '_').replace('"', '') 31 | 32 | def intrl(a, b, step=2): 33 | assert len(a) == len(b), ' diff lengths: %d %d' % (len(a), len(b)) 34 | assert step > 1 35 | nums = list(range(len(a)))[step::step] 36 | for num in nums: 37 | a[num] = b[num] 38 | return a 39 | 40 | def old_torch(): 41 | ver = [int(i) for i in torch.__version__.split('.')[:2]] 42 | return True if (ver[0] < 2 and ver[1] < 8) else False 43 | 44 | def basename(file): 45 | return os.path.splitext(os.path.basename(file))[0] 46 | 47 | def file_list(path, ext=None, subdir=None): 48 | if subdir is True: 49 | files = [os.path.join(dp, f) for dp, dn, fn in os.walk(path) for f 
in fn] 50 | else: 51 | files = [os.path.join(path, f) for f in os.listdir(path)] 52 | if ext is not None: 53 | if isinstance(ext, list): 54 | files = [f for f in files if os.path.splitext(f.lower())[1][1:] in ext] 55 | elif isinstance(ext, str): 56 | files = [f for f in files if f.endswith(ext)] 57 | else: 58 | print(' Unknown extension/type for file list!') 59 | return sorted([f for f in files if os.path.isfile(f)]) 60 | 61 | def img_list(path, subdir=None): 62 | if subdir is True: 63 | files = [os.path.join(dp, f) for dp, dn, fn in os.walk(path) for f in fn] 64 | else: 65 | files = [os.path.join(path, f) for f in os.listdir(path)] 66 | files = [f for f in files if os.path.splitext(f.lower())[1][1:] in ['jpg', 'jpeg', 'png', 'ppm', 'tif']] 67 | return sorted([f for f in files if os.path.isfile(f)]) 68 | 69 | def img_read(path): 70 | img = imread(path) 71 | # 8bit to 256bit 72 | if (img.ndim == 2) or (img.shape[2] == 1): 73 | img = np.dstack((img,img,img)) 74 | # rgba to rgb 75 | if img.shape[2] == 4: 76 | img = img[:,:,:3] 77 | return img 78 | 79 | def img_save(path, img, norm=True): 80 | if norm == True and not np.issubdtype(img.dtype.kind, np.integer): 81 | img = (img*255).astype(np.uint8) 82 | imsave(path, img) 83 | 84 | def cvshow(img): 85 | img = np.array(img) 86 | if img.shape[0] > 720 or img.shape[1] > 1280: 87 | x_ = 1280 / img.shape[1] 88 | y_ = 720 / img.shape[0] 89 | psize = tuple([int(s * min(x_, y_)) for s in img.shape[:2][::-1]]) 90 | img = cv2.resize(img, psize) 91 | cv2.imshow('t', img[:,:,::-1]) 92 | cv2.waitKey(1) 93 | 94 | def checkout(img, fname=None, verbose=False): 95 | img = np.transpose(np.array(img)[:,:,:], (1,2,0)) 96 | if verbose is True: 97 | cvshow(img) 98 | if fname is not None: 99 | img = np.clip(img*255, 0, 255).astype(np.uint8) 100 | imsave(fname, img) 101 | 102 | def save_cfg(args, dir='./', file='config.txt'): 103 | if dir != '': 104 | os.makedirs(dir, exist_ok=True) 105 | try: args = vars(args) 106 | except: pass 107 | if file is None: 108 | print_dict(args) 109 | else: 110 | with open(os.path.join(dir, file), 'w') as cfg_file: 111 | print_dict(args, cfg_file) 112 | 113 | def print_dict(dict, file=None, path="", indent=''): 114 | for k in sorted(dict.keys()): 115 | if isinstance(dict[k], collections.abc.Mapping): 116 | if file is None: 117 | print(indent + str(k)) 118 | else: 119 | file.write(indent + str(k) + ' \n') 120 | path = k if path=="" else path + "->" + k 121 | print_dict(dict[k], file, path, indent + ' ') 122 | else: 123 | if file is None: 124 | print('%s%s: %s' % (indent, str(k), str(dict[k]))) 125 | else: 126 | file.write('%s%s: %s \n' % (indent, str(k), str(dict[k]))) 127 | 128 | def minmax(x, torch=True): 129 | if torch: 130 | mn = torch.min(x).detach().cpu().numpy() 131 | mx = torch.max(x).detach().cpu().numpy() 132 | else: 133 | mn = np.min(x.detach().cpu().numpy()) 134 | mx = np.max(x.detach().cpu().numpy()) 135 | return (mn, mx) 136 | 137 | def triangle_blur(x, kernel_size=3, pow=1.0): 138 | padding = (kernel_size-1) // 2 139 | b,c,h,w = x.shape 140 | kernel = torch.linspace(-1,1,kernel_size+2)[1:-1].abs().neg().add(1).reshape(1,1,1,kernel_size).pow(pow).cuda() 141 | kernel = kernel / kernel.sum() 142 | x = x.reshape(b*c,1,h,w) 143 | x = F.pad(x, (padding,padding,padding,padding), mode='reflect') 144 | x = F.conv2d(x, kernel) 145 | x = F.conv2d(x, kernel.permute(0,1,3,2)) 146 | x = x.reshape(b,c,h,w) 147 | return x 148 | 149 | # Tiles an array around two points, allowing for pad lengths greater than the input length 150 | # NB: if 
symm=True, every second tile is mirrored = messed up in GAN 151 | # adapted from https://discuss.pytorch.org/t/symmetric-padding/19866/3 152 | def tile_pad(xt, padding, symm=False): 153 | h, w = xt.shape[-2:] 154 | left, right, top, bottom = padding 155 | 156 | def tile(x, minx, maxx): 157 | rng = maxx - minx 158 | if symm is True: # triangular reflection 159 | double_rng = 2*rng 160 | mod = np.fmod(x - minx, double_rng) 161 | normed_mod = np.where(mod < 0, mod+double_rng, mod) 162 | out = np.where(normed_mod >= rng, double_rng - normed_mod, normed_mod) + minx 163 | else: # repeating tiles 164 | mod = np.remainder(x - minx, rng) 165 | out = mod + minx 166 | return np.array(out, dtype=x.dtype) 167 | 168 | x_idx = np.arange(-left, w+right) 169 | y_idx = np.arange(-top, h+bottom) 170 | x_pad = tile(x_idx, -0.5, w-0.5) 171 | y_pad = tile(y_idx, -0.5, h-0.5) 172 | xx, yy = np.meshgrid(x_pad, y_pad) 173 | return xt[..., yy, xx] 174 | 175 | def pad_up_to(x, size, type='centr'): 176 | sh = x.shape[2:][::-1] 177 | if list(x.shape[2:]) == list(size): return x 178 | padding = [] 179 | for i, s in enumerate(size[::-1]): 180 | if 'side' in type.lower(): 181 | padding = padding + [0, s-sh[i]] 182 | else: # centr 183 | p0 = (s-sh[i]) // 2 184 | p1 = s-sh[i] - p0 185 | padding = padding + [p0,p1] 186 | y = tile_pad(x, padding, symm = ('symm' in type.lower())) 187 | return y 188 | 189 | def smoothstep(x, NN=1, xmin=0., xmax=1.): 190 | N = math.ceil(NN) 191 | x = np.clip((x - xmin) / (xmax - xmin), 0, 1) 192 | result = 0 193 | for n in range(0, N+1): 194 | result += scipy.special.comb(N+n, n) * scipy.special.comb(2*N+1, N-n) * (-x)**n 195 | result *= x**(N+1) 196 | if NN != N: result = (x + result) / 2 197 | return result 198 | 199 | def slerp(z1, z2, num_steps=None, x=None, smooth=0.5): 200 | z1_norm = z1.norm() 201 | z2_norm = z2.norm() 202 | z2_normal = z2 * (z1_norm / z2_norm) 203 | vectors = [] 204 | if num_steps is not None: 205 | xs = [step / (num_steps - 1) for step in range(num_steps)] 206 | else: 207 | xs = [x] 208 | if smooth > 0: xs = [smoothstep(x, smooth) for x in xs] 209 | for x in xs: 210 | interplain = z1 + (z2 - z1) * x 211 | interp = z1 + (z2_normal - z1) * x 212 | interp_norm = interp.norm() 213 | if interp_norm != 0: 214 | interpol_normal = interplain * (z1_norm / interp_norm) 215 | vectors.append(interpol_normal) 216 | return torch.cat(vectors) 217 | 218 | def slice_imgs(imgs, count, size=224, transform=None, align='uniform', macro=0.): 219 | def map(x, a, b): 220 | return x * (b-a) + a 221 | 222 | rnd_size = torch.rand(count) 223 | if align == 'central': # normal around center 224 | rnd_offx = torch.clip(torch.randn(count) * 0.2 + 0.5, 0., 1.) 225 | rnd_offy = torch.clip(torch.randn(count) * 0.2 + 0.5, 0., 1.) 
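    # 'central' (above) draws crop offsets from a clipped normal around the frame center (sigma 0.2);
    # the 'uniform' branch below samples offsets anywhere in the frame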
226 | else: # uniform 227 | rnd_offx = torch.rand(count) 228 | rnd_offy = torch.rand(count) 229 | 230 | sz = [img.shape[2:] for img in imgs] 231 | sz_max = [torch.min(torch.tensor(s)) for s in sz] 232 | if 'over' in align: # expand frame to sample outside 233 | if align == 'overmax': 234 | sz = [[2*s[0], 2*s[1]] for s in list(sz)] 235 | else: 236 | sz = [[int(1.5*s[0]), int(1.5*s[1])] for s in list(sz)] 237 | imgs = [pad_up_to(imgs[i], sz[i], type='centr') for i in range(len(imgs))] 238 | 239 | sliced = [] 240 | for i, img in enumerate(imgs): 241 | cuts = [] 242 | sz_max_i = sz_max[i] 243 | for c in range(count): 244 | sz_min_i = 0.9*sz_max[i] if torch.rand(1) < macro else size 245 | csize = map(rnd_size[c], sz_min_i, sz_max_i).int() 246 | offsetx = map(rnd_offx[c], 0, sz[i][1] - csize).int() 247 | offsety = map(rnd_offy[c], 0, sz[i][0] - csize).int() 248 | cut = img[:, :, offsety:offsety + csize, offsetx:offsetx + csize] 249 | cut = F.interpolate(cut, (size,size), mode='bicubic', align_corners=True) # bilinear 250 | if transform is not None: 251 | cut = transform(cut) 252 | cuts.append(cut) 253 | sliced.append(torch.cat(cuts, 0)) 254 | return sliced 255 | 256 | def derivat(img, mode='sobel'): 257 | if mode == 'scharr': 258 | # https://en.wikipedia.org/wiki/Sobel_operator#Alternative_operators 259 | k_scharr = torch.Tensor([[[-0.183,0.,0.183], [-0.634,0.,0.634], [-0.183,0.,0.183]], [[-0.183,-0.634,-0.183], [0.,0.,0.], [0.183,0.634,0.183]]]) 260 | k_scharr = k_scharr.unsqueeze(1).tile((1,3,1,1)).cuda() 261 | return 0.2 * torch.mean(torch.abs(F.conv2d(img, k_scharr))) 262 | elif mode == 'sobel': 263 | # https://kornia.readthedocs.io/en/latest/filters.html#edge-detection 264 | return torch.mean(torch.abs(spatial_gradient(img))) 265 | else: # trivial hack 266 | dx = torch.mean(torch.abs(img[:,:,:,1:] - img[:,:,:,:-1])) 267 | dy = torch.mean(torch.abs(img[:,:,1:,:] - img[:,:,:-1,:])) 268 | return 0.5 * (dx+dy) 269 | 270 | def dot_compare(v1, v2, cossim_pow=0): 271 | dot = (v1 * v2).sum() 272 | mag = torch.sqrt(torch.sum(v2**2)) 273 | cossim = dot/(1e-6 + mag) 274 | return dot * cossim ** cossim_pow 275 | 276 | def sim_func(v1, v2, type=None): 277 | if type is not None and 'mix' in type: # mixed 278 | coss = torch.cosine_similarity(v1, v2, dim=-1).mean() 279 | v1 = F.normalize(v1, dim=-1) 280 | v2 = F.normalize(v2, dim=-1) 281 | spher = torch.abs((v1 - v2).norm(dim=-1).div(2).arcsin().pow(2).mul(2)).mean() 282 | return coss - 0.25 * spher 283 | elif type is not None and 'spher' in type: # spherical 284 | # from https://colab.research.google.com/drive/1ED6_MYVXTApBHzQObUPaaMolgf9hZOOF 285 | v1 = F.normalize(v1, dim=-1) 286 | v2 = F.normalize(v2, dim=-1) 287 | # return 1 - torch.abs((v1 - v2).norm(dim=-1).div(2).arcsin().pow(2).mul(2)).mean() 288 | return (v1 - v2).norm(dim=-1).div(2).arcsin().pow(2).mul(2) 289 | elif type is not None and 'ang' in type: # angular 290 | # return 1 - torch.acos(torch.cosine_similarity(v1, v2, dim=-1).mean()) / np.pi 291 | return 1 - torch.acos(torch.cosine_similarity(v1, v2, dim=-1)).mean() / np.pi 292 | elif type is not None and 'dot' in type: # dot compare cossim from lucent inversion 293 | return dot_compare(v1, v2, cossim_pow=1) # decrease pow if nan (black output) 294 | else: 295 | return torch.cosine_similarity(v1, v2, dim=-1).mean() 296 | 297 | # = = = = = = = = = = = = = = = = = = = = = = = = = = = 298 | 299 | def get_z(shape, rnd, uniform=False): 300 | if uniform: 301 | return rnd.uniform(0., 1., shape) 302 | else: 303 | return rnd.randn(*shape) # *x 
unpacks tuple/list to sequence 304 | 305 | def smoothstep(x, NN=1., xmin=0., xmax=1.): 306 | N = math.ceil(NN) 307 | x = np.clip((x - xmin) / (xmax - xmin), 0, 1) 308 | result = 0 309 | for n in range(0, N+1): 310 | result += scipy.special.comb(N+n, n) * scipy.special.comb(2*N+1, N-n) * (-x)**n 311 | result *= x**(N+1) 312 | if NN != N: result = (x + result) / 2 313 | return result 314 | 315 | def lerp(z1, z2, num_steps, smooth=0.): 316 | vectors = [] 317 | xs = [step / (num_steps - 1) for step in range(num_steps)] 318 | if smooth > 0: xs = [smoothstep(x, smooth) for x in xs] 319 | for x in xs: 320 | interpol = z1 + (z2 - z1) * x 321 | vectors.append(interpol) 322 | return np.array(vectors) 323 | 324 | # interpolate on hypersphere 325 | def slerp_np(z1, z2, num_steps, smooth=0.): 326 | z1_norm = np.linalg.norm(z1) 327 | z2_norm = np.linalg.norm(z2) 328 | z2_normal = z2 * (z1_norm / z2_norm) 329 | vectors = [] 330 | xs = [step / (num_steps - 1) for step in range(num_steps)] 331 | if smooth > 0: xs = [smoothstep(x, smooth) for x in xs] 332 | for x in xs: 333 | interplain = z1 + (z2 - z1) * x 334 | interp = z1 + (z2_normal - z1) * x 335 | interp_norm = np.linalg.norm(interp) 336 | interpol_normal = interplain * (z1_norm / interp_norm) 337 | # interpol_normal = interp * (z1_norm / interp_norm) 338 | vectors.append(interpol_normal) 339 | return np.array(vectors) 340 | 341 | def cublerp(points, steps, fstep, looped=True): 342 | keys = np.array([i*fstep for i in range(steps)] + [steps*fstep]) 343 | last_pt_num = 0 if looped is True else -1 344 | points = np.concatenate((points, np.expand_dims(points[last_pt_num], 0))) 345 | cspline = CubSpline(keys, points) 346 | return cspline(range(steps*fstep+1)) 347 | 348 | # = = = = = = = = = = = = = = = = = = = = = = = = = = = 349 | 350 | def latent_anima(shape, frames, transit, key_latents=None, smooth=0.5, uniform=False, cubic=False, gauss=False, start_lat=None, seed=None, looped=True, verbose=False): 351 | if key_latents is None: 352 | transit = int(max(1, min(frames//2, transit))) 353 | steps = max(1, math.ceil(frames / transit)) 354 | log = ' timeline: %d steps by %d' % (steps, transit) 355 | 356 | if seed is None: 357 | seed = np.random.seed(int((time.time()%1) * 9999)) 358 | rnd = np.random.RandomState(seed) 359 | 360 | # make key points 361 | if key_latents is None: 362 | key_latents = np.array([get_z(shape, rnd, uniform) for i in range(steps)]) 363 | if start_lat is not None: 364 | key_latents[0] = start_lat 365 | 366 | latents = np.expand_dims(key_latents[0], 0) 367 | 368 | # populate lerp between key points 369 | if transit == 1: 370 | latents = key_latents 371 | else: 372 | if cubic: 373 | latents = cublerp(key_latents, steps, transit, looped) 374 | log += ', cubic' 375 | else: 376 | for i in range(steps): 377 | zA = key_latents[i] 378 | lat_num = (i+1)%steps if looped is True else min(i+1, steps-1) 379 | zB = key_latents[lat_num] 380 | if uniform is True: 381 | interps_z = lerp(zA, zB, transit, smooth=smooth) 382 | else: 383 | interps_z = slerp_np(zA, zB, transit, smooth=smooth) 384 | latents = np.concatenate((latents, interps_z)) 385 | latents = np.array(latents) 386 | 387 | if gauss: 388 | lats_post = gaussian_filter(latents, [transit, 0, 0], mode="wrap") 389 | lats_post = (lats_post / np.linalg.norm(lats_post, axis=-1, keepdims=True)) * math.sqrt(np.prod(shape)) 390 | log += ', gauss' 391 | latents = lats_post 392 | 393 | if verbose: print(log) 394 | if latents.shape[0] > frames: # extra frame 395 | latents = latents[1:] 396 | return latents 
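# Illustrative usage sketch for latent_anima() above (argument values are arbitrary, not taken from this repo's callers):
#   lats = latent_anima([1, 512], frames=100, transit=25, cubic=True, looped=True, verbose=True)
#   lats.shape  # -> (100, 1, 512): one latent per frame, a new key point every 25 frames, cubic-spline blended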
397 | 398 | # = = = = = = = = = = = = = = = = = = = = = = = = = = = 399 | 400 | # from https://github.com/LAION-AI/aesthetic-predictor 401 | from urllib.request import urlretrieve # pylint: disable=import-outside-toplevel 402 | def aesthetic_model(clip_model='ViT-B/32'): 403 | nf = 768 if clip_model == "ViT-L/14" else 512 if clip_model in ['ViT-B/16', 'ViT-B/32'] else None 404 | clip_model = clip_model.replace('/','_').replace('-','_').lower() 405 | path_to_model = 'sa_0_4_%s_linear.pth' % clip_model 406 | if not os.path.isfile(path_to_model): 407 | url_model = "https://github.com/LAION-AI/aesthetic-predictor/blob/main/sa_0_4_" + clip_model + "_linear.pth?raw=true" 408 | urlretrieve(url_model, path_to_model) 409 | if nf is None or not os.path.isfile(path_to_model): return None 410 | m = torch.nn.Linear(nf, 1) 411 | m.load_state_dict(torch.load(path_to_model)) 412 | m.eval().half() 413 | return m 414 | 415 | -------------------------------------------------------------------------------- /clip_fft.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | warnings.filterwarnings("ignore") 4 | import argparse 5 | import numpy as np 6 | from imageio import imread, imsave 7 | import shutil 8 | 9 | try: 10 | from googletrans import Translator 11 | googletrans_ok = True 12 | except: 13 | googletrans_ok = False 14 | 15 | import torch 16 | import torchvision 17 | import torch.nn.functional as F 18 | 19 | import clip 20 | os.environ['KMP_DUPLICATE_LIB_OK']='True' 21 | # from sentence_transformers import SentenceTransformer 22 | import lpips 23 | 24 | from aphantasia.image import to_valid_rgb, fft_image, dwt_image 25 | from aphantasia.utils import slice_imgs, derivat, sim_func, aesthetic_model, basename, img_list, img_read, plot_text, txt_clean, checkout, old_torch 26 | from aphantasia import transforms 27 | try: # progress bar for notebooks 28 | get_ipython().__class__.__name__ 29 | from aphantasia.progress_bar import ProgressIPy as ProgressBar 30 | except: # normal console 31 | from aphantasia.progress_bar import ProgressBar 32 | 33 | clip_models = ['ViT-B/16', 'ViT-B/32', 'RN101', 'RN50x16', 'RN50x4', 'RN50'] 34 | 35 | def get_args(): 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('-t', '--in_txt', default=None, help='input text') 38 | parser.add_argument('-t2', '--in_txt2', default=None, help='input text - style') 39 | parser.add_argument('-t0', '--in_txt0', default=None, help='input text to subtract') 40 | parser.add_argument('-i', '--in_img', default=None, help='input image') 41 | parser.add_argument('-wi', '--weight_img', default=0.5, type=float, help='weight for images') 42 | parser.add_argument( '--out_dir', default='_out') 43 | parser.add_argument('-s', '--size', default='1280-720', help='Output resolution') 44 | parser.add_argument('-r', '--resume', default=None, help='Path to saved FFT snapshots, to resume from') 45 | parser.add_argument('-ops', '--opt_step', default=1, type=int, help='How many optimizing steps per save step') 46 | parser.add_argument('-tr', '--translate', action='store_true', help='Translate text with Google Translate') 47 | # parser.add_argument('-ml', '--multilang', action='store_true', help='Use SBERT multilanguage model for text') 48 | parser.add_argument( '--save_pt', action='store_true', help='Save FFT snapshots for further use') 49 | parser.add_argument('-v', '--verbose', dest='verbose', action='store_true') 50 | parser.add_argument('-nv', '--no-verbose', dest='verbose', 
action='store_false') 51 | parser.set_defaults(verbose=True) 52 | # training 53 | parser.add_argument('-m', '--model', default='ViT-B/32', choices=clip_models, help='Select CLIP model to use') 54 | parser.add_argument( '--steps', default=200, type=int, help='Total iterations') 55 | parser.add_argument( '--samples', default=200, type=int, help='Samples to evaluate') 56 | parser.add_argument('-lr', '--lrate', default=0.05, type=float, help='Learning rate') 57 | parser.add_argument('-p', '--prog', action='store_true', help='Enable progressive lrate growth (up to double a.lrate)') 58 | parser.add_argument('-dm', '--dualmod', default=None, type=int, help='Every this step use another CLIP ViT model') 59 | # wavelet 60 | parser.add_argument( '--dwt', action='store_true', help='Use DWT instead of FFT') 61 | parser.add_argument('-w', '--wave', default='coif2', help='wavelets: db[1..], coif[1..], haar, dmey') 62 | # tweaks 63 | parser.add_argument('-a', '--align', default='uniform', choices=['central', 'uniform', 'overscan', 'overmax'], help='Sampling distribution') 64 | parser.add_argument('-tf', '--transform', default='fast', choices=['none', 'fast', 'custom', 'elastic'], help='augmenting transforms') 65 | parser.add_argument('-opt', '--optimizer', default='adam_custom', choices=['adam', 'adamw', 'adam_custom', 'adamw_custom'], help='Optimizer') 66 | parser.add_argument( '--contrast', default=1.1, type=float) 67 | parser.add_argument( '--colors', default=1.8, type=float) 68 | parser.add_argument( '--decay', default=1.5, type=float) 69 | parser.add_argument('-sh', '--sharp', default=0., type=float) 70 | parser.add_argument('-mm', '--macro', default=0.4, type=float, help='Endorse macro forms 0..1 ') 71 | parser.add_argument( '--aest', default=0., type=float, help='Enhance aesthetics') 72 | parser.add_argument('-e', '--enforce', default=0, type=float, help='Enforce details (by boosting similarity between two parallel samples)') 73 | parser.add_argument('-x', '--expand', default=0, type=float, help='Boosts diversity (by enforcing difference between prev/next samples)') 74 | parser.add_argument('-n', '--noise', default=0, type=float, help='Add noise to suppress accumulation') # < 0.05 ? 
75 | parser.add_argument('-c', '--sync', default=0, type=float, help='Sync output to input image') 76 | parser.add_argument( '--invert', action='store_true', help='Invert criteria') 77 | parser.add_argument( '--sim', default='mix', help='Similarity function (dot/angular/spherical/mixed; None = cossim)') 78 | a = parser.parse_args() 79 | 80 | if a.size is not None: a.size = [int(s) for s in a.size.split('-')][::-1] 81 | if len(a.size)==1: a.size = a.size * 2 82 | if (a.in_img is not None and a.sync != 0) or a.resume is not None: a.align = 'overscan' 83 | # if a.multilang is True: a.model = 'ViT-B/32' # sbert model is trained with ViT 84 | if a.translate is True and googletrans_ok is not True: 85 | print('\n Install googletrans module to enable translation!'); exit() 86 | if a.dualmod is not None: 87 | a.model = 'ViT-B/32' 88 | a.sim = 'cossim' 89 | 90 | return a 91 | 92 | def main(): 93 | a = get_args() 94 | 95 | shape = [1, 3, *a.size] 96 | if a.dwt is True: 97 | params, image_f, sz = dwt_image(shape, a.wave, 0.3, a.colors, a.resume) 98 | else: 99 | params, image_f, sz = fft_image(shape, 0.07, a.decay, a.resume) 100 | if sz is not None: a.size = sz 101 | image_f = to_valid_rgb(image_f, colors = a.colors) 102 | 103 | if a.prog is True: 104 | lr1 = a.lrate * 2 105 | lr0 = lr1 * 0.01 106 | else: 107 | lr0 = a.lrate 108 | if a.optimizer.lower() == 'adamw': 109 | optimizer = torch.optim.AdamW(params, lr0, weight_decay=0.01) 110 | elif a.optimizer.lower() == 'adamw_custom': 111 | optimizer = torch.optim.AdamW(params, lr0, weight_decay=0.01, betas=(.0,.999), amsgrad=True) 112 | elif a.optimizer.lower() == 'adam': 113 | optimizer = torch.optim.Adam(params, lr0) 114 | else: # adam_custom 115 | optimizer = torch.optim.Adam(params, lr0, betas=(.0,.999)) 116 | sign = 1. if a.invert is True else -1. 117 | 118 | # Load CLIP models 119 | model_clip, _ = clip.load(a.model, jit=old_torch()) 120 | try: 121 | a.modsize = model_clip.visual.input_resolution 122 | except: 123 | a.modsize = 288 if a.model == 'RN50x4' else 384 if a.model == 'RN50x16' else 224 124 | if a.verbose is True: print(' using model', a.model) 125 | xmem = {'ViT-B/16':0.25, 'RN50':0.5, 'RN50x4':0.16, 'RN50x16':0.06, 'RN101':0.33} 126 | if a.model in xmem.keys(): 127 | a.samples = int(a.samples * xmem[a.model]) 128 | 129 | # if a.multilang is True: 130 | # model_lang = SentenceTransformer('clip-ViT-B-32-multilingual-v1').cuda() 131 | 132 | if a.dualmod is not None: # second is vit-16 133 | model_clip2, _ = clip.load('ViT-B/16', jit=old_torch()) 134 | a.samples = int(a.samples * 0.23) 135 | dualmod_nums = list(range(a.steps))[a.dualmod::a.dualmod] 136 | print(' dual model every %d step' % a.dualmod) 137 | 138 | if a.aest != 0 and a.model in ['ViT-B/32', 'ViT-B/16', 'ViT-L/14']: 139 | aest = aesthetic_model(a.model).cuda() 140 | if a.dualmod is not None: 141 | aest2 = aesthetic_model('ViT-B/16').cuda() 142 | 143 | def enc_text(txt, model_clip=model_clip): 144 | embs = [] 145 | for subtxt in txt.split('|'): 146 | if ':' in subtxt: 147 | [subtxt, wt] = subtxt.split(':') 148 | wt = float(wt) 149 | else: wt = 1. 
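            # prompt syntax parsed above: "first subject|second subject:0.5" (sub-prompts split by '|', optional ':weight' suffix, default 1.)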
150 | emb = model_clip.encode_text(clip.tokenize(subtxt).cuda()) 151 | # if a.multilang is True: 152 | # emb = model_lang.encode([subtxt], convert_to_tensor=True, show_progress_bar=False) 153 | embs.append([emb.detach().clone(), wt]) 154 | return embs 155 | 156 | if a.enforce != 0: 157 | a.samples = int(a.samples * 0.5) 158 | if a.sync > 0: 159 | a.samples = int(a.samples * 0.5) 160 | 161 | if 'elastic' in a.transform: 162 | trform_f = transforms.transforms_elastic 163 | a.samples = int(a.samples * 0.95) 164 | elif 'custom' in a.transform: 165 | trform_f = transforms.transforms_custom 166 | a.samples = int(a.samples * 0.95) 167 | elif 'fast' in a.transform: 168 | trform_f = transforms.transforms_fast 169 | a.samples = int(a.samples * 0.95) 170 | else: 171 | trform_f = transforms.normalize() 172 | 173 | out_name = [] 174 | if a.in_txt is not None: 175 | if a.verbose is True: print(' topic text: ', a.in_txt) 176 | if a.translate: 177 | translator = Translator() 178 | a.in_txt = translator.translate(a.in_txt, dest='en').text 179 | if a.verbose is True: print(' translated to:', a.in_txt) 180 | txt_enc = enc_text(a.in_txt) 181 | out_name.append(txt_clean(a.in_txt).lower()[:40]) 182 | if a.dualmod is not None: 183 | txt_enc2 = enc_text(a.in_txt, model_clip2) 184 | 185 | if a.in_txt2 is not None: 186 | if a.verbose is True: print(' style text:', a.in_txt2) 187 | a.samples = int(a.samples * 0.75) 188 | if a.translate: 189 | translator = Translator() 190 | a.in_txt2 = translator.translate(a.in_txt2, dest='en').text 191 | if a.verbose is True: print(' translated to:', a.in_txt2) 192 | style_enc = enc_text(a.in_txt2) 193 | out_name.append(txt_clean(a.in_txt2).lower()[:40]) 194 | if a.dualmod is not None: 195 | style_enc2 = enc_text(a.in_txt2, model_clip2) 196 | 197 | if a.in_txt0 is not None: 198 | if a.verbose is True: print(' subtract text:', a.in_txt0) 199 | a.samples = int(a.samples * 0.75) 200 | if a.translate: 201 | translator = Translator() 202 | a.in_txt0 = translator.translate(a.in_txt0, dest='en').text 203 | if a.verbose is True: print(' translated to:', a.in_txt0) 204 | not_enc = enc_text(a.in_txt0) 205 | out_name.append('off-' + txt_clean(a.in_txt0).lower()[:40]) 206 | if a.dualmod is not None: 207 | not_enc2 = enc_text(a.in_txt0, model_clip2) 208 | 209 | # if a.multilang is True: del model_lang 210 | 211 | if a.in_img is not None and os.path.isfile(a.in_img): 212 | if a.verbose is True: print(' ref image:', basename(a.in_img)) 213 | img_in = torch.from_numpy(img_read(a.in_img)/255.).unsqueeze(0).permute(0,3,1,2).cuda() 214 | img_in = img_in[:,:3,:,:] # fix rgb channels 215 | in_sliced = slice_imgs([img_in], a.samples, a.modsize, transforms.normalize(), a.align)[0] 216 | img_enc = model_clip.encode_image(in_sliced).detach().clone() 217 | if a.dualmod is not None: 218 | img_enc2 = model_clip2.encode_image(in_sliced).detach().clone() 219 | if a.sync > 0: 220 | sim_loss = lpips.LPIPS(net='vgg', verbose=False).cuda() 221 | sim_size = [s//2 for s in a.size] 222 | img_in = F.interpolate(img_in, sim_size, mode='bicubic', align_corners=True).float() 223 | else: 224 | del img_in 225 | del in_sliced; torch.cuda.empty_cache() 226 | out_name.append(basename(a.in_img).replace(' ', '_')) 227 | 228 | if a.verbose is True: print(' samples:', a.samples) 229 | out_name = '-'.join(out_name) 230 | out_name += '-%s' % a.model.replace('/','').replace('-','') if a.dualmod is None else '-dm%d' % a.dualmod 231 | tempdir = os.path.join(a.out_dir, out_name) 232 | os.makedirs(tempdir, exist_ok=True) 233 | 234 | 
prev_enc = 0 235 | def train(i): 236 | loss = 0 237 | 238 | noise = a.noise * torch.rand(1, 1, *params[0].shape[2:4], 1).cuda() if a.noise > 0 else None 239 | img_out = image_f(noise) 240 | img_sliced = slice_imgs([img_out], a.samples, a.modsize, trform_f, a.align, a.macro)[0] 241 | 242 | if a.in_txt is not None: # input text 243 | txt_enc_ = txt_enc2 if a.dualmod is not None and i in dualmod_nums else txt_enc 244 | if a.in_txt2 is not None: 245 | style_enc_ = style_enc2 if a.dualmod is not None and i in dualmod_nums else style_enc 246 | if a.in_img is not None and os.path.isfile(a.in_img): 247 | img_enc_ = img_enc2 if a.dualmod is not None and i in dualmod_nums else img_enc 248 | if a.in_txt0 is not None: 249 | not_enc_ = not_enc2 if a.dualmod is not None and i in dualmod_nums else not_enc 250 | model_clip_ = model_clip2 if a.dualmod is not None and i in dualmod_nums else model_clip 251 | if a.aest != 0: 252 | aest_ = aest2 if a.dualmod is not None and i in dualmod_nums else aest 253 | 254 | out_enc = model_clip_.encode_image(img_sliced) 255 | if a.aest != 0 and aest_ is not None: 256 | loss -= 0.001 * a.aest * aest_(out_enc).mean() 257 | if a.in_txt is not None: # input text 258 | for enc, wt in txt_enc_: 259 | loss += sign * wt * sim_func(enc, out_enc, a.sim) 260 | if a.in_txt2 is not None: # input text - style 261 | for enc, wt in style_enc_: 262 | loss += sign * wt * sim_func(enc, out_enc, a.sim) 263 | if a.in_txt0 is not None: # subtract text 264 | for enc, wt in not_enc_: 265 | loss += -sign * wt * sim_func(enc, out_enc, a.sim) 266 | if a.in_img is not None and os.path.isfile(a.in_img): # input image 267 | loss += sign * a.weight_img * sim_func(img_enc_, out_enc, a.sim) 268 | if a.sync > 0 and a.in_img is not None and os.path.isfile(a.in_img): # image composition 269 | prog_sync = (a.steps // a.opt_step - i) / (a.steps // a.opt_step) 270 | loss += prog_sync * a.sync * sim_loss(F.interpolate(img_out, sim_size, mode='bicubic', align_corners=True).float(), img_in, normalize=True).squeeze() 271 | if a.sharp != 0 and a.dwt is not True: # scharr|sobel|default 272 | loss -= a.sharp * derivat(img_out, mode='naiv') 273 | # loss -= a.sharp * derivat(img_sliced, mode='scharr') 274 | if a.enforce != 0: 275 | img_sliced = slice_imgs([image_f(noise)], a.samples, a.modsize, trform_f, a.align, a.macro)[0] 276 | out_enc2 = model_clip_.encode_image(img_sliced) 277 | loss -= a.enforce * sim_func(out_enc, out_enc2, a.sim) 278 | del out_enc2; torch.cuda.empty_cache() 279 | if a.expand > 0: 280 | global prev_enc 281 | if i > 0: 282 | loss += a.expand * sim_func(out_enc, prev_enc, a.sim) 283 | prev_enc = out_enc.detach() # .clone() 284 | 285 | del img_out, img_sliced, out_enc; torch.cuda.empty_cache() 286 | assert not isinstance(loss, int), ' Loss not defined, check the inputs' 287 | 288 | if a.prog is True: 289 | lr_cur = lr0 + (i / a.steps) * (lr1 - lr0) 290 | for g in optimizer.param_groups: 291 | g['lr'] = lr_cur 292 | 293 | optimizer.zero_grad() 294 | loss.backward() 295 | optimizer.step() 296 | 297 | if i % a.opt_step == 0: 298 | with torch.no_grad(): 299 | img = image_f(contrast=a.contrast).cpu().numpy()[0] 300 | # empirical tone mapping 301 | if (a.sync > 0 and a.in_img is not None): 302 | img = img **1.3 303 | elif a.sharp != 0: 304 | img = img ** (1 + a.sharp/2.) 
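        # every opt_step-th iteration a frame is written to tempdir; the ffmpeg call after the training loop joins them into an mp4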
305 | checkout(img, os.path.join(tempdir, '%04d.jpg' % (i // a.opt_step)), verbose=a.verbose) 306 | pbar.upd() 307 | 308 | pbar = ProgressBar(a.steps // a.opt_step) 309 | for i in range(a.steps): 310 | train(i) 311 | 312 | os.system('ffmpeg -v warning -y -i %s/\%%04d.jpg "%s.mp4"' % (tempdir, os.path.join(a.out_dir, out_name))) 313 | shutil.copy(img_list(tempdir)[-1], os.path.join(a.out_dir, '%s-%d.jpg' % (out_name, a.steps))) 314 | if a.save_pt is True: 315 | torch.save(params, '%s.pt' % os.path.join(a.out_dir, out_name)) 316 | 317 | if __name__ == '__main__': 318 | main() 319 | -------------------------------------------------------------------------------- /cppn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | warnings.filterwarnings("ignore") 4 | import argparse 5 | import numpy as np 6 | import shutil 7 | import math 8 | from collections import OrderedDict 9 | 10 | try: 11 | from googletrans import Translator 12 | googletrans_ok = True 13 | except: 14 | googletrans_ok = False 15 | 16 | import torch 17 | import torchvision 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | import clip 22 | os.environ['KMP_DUPLICATE_LIB_OK']='True' 23 | 24 | from aphantasia.utils import slice_imgs, derivat, aesthetic_model, txt_clean, checkout, old_torch 25 | from aphantasia import transforms 26 | from shader_expo import cppn_to_shader 27 | 28 | from eps.progress_bar import ProgressBar 29 | from eps.data_load import basename, img_list, img_read, file_list, save_cfg 30 | 31 | clip_models = ['ViT-B/16', 'ViT-B/32', 'ViT-L/14', 'RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101'] 32 | 33 | def get_args(): 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument('-i', '--in_img', default=None, help='input image') 36 | parser.add_argument('-t', '--in_txt', default=None, help='input text') 37 | parser.add_argument('-t0', '--in_txt0', default=None, help='input text to subtract') 38 | parser.add_argument( '--out_dir', default='_out') 39 | parser.add_argument('-r', '--resume', default=None, help='Input CPPN model (NPY file) to resume from') 40 | parser.add_argument('-s', '--size', default='512-512', help='Output resolution') 41 | parser.add_argument( '--fstep', default=1, type=int, help='Saving step') 42 | parser.add_argument('-tr', '--translate', action='store_true') 43 | parser.add_argument('-v', '--verbose', action='store_true') 44 | parser.add_argument('-ex', '--export', action='store_true', help="Only export shaders from resumed snapshot") 45 | # networks 46 | parser.add_argument('-l', '--layers', default=10, type=int, help='CPPN layers') 47 | parser.add_argument('-nf', '--nf', default=24, type=int, help='num features') # 256 48 | parser.add_argument('-act', '--actfn', default='unbias', choices=['unbias', 'comp', 'relu'], help='activation function') 49 | parser.add_argument('-dec', '--decim', default=3, type=int, help='Decimal precision for export') 50 | # training 51 | parser.add_argument('-m', '--model', default='ViT-B/32', choices=clip_models, help='Select CLIP model to use') 52 | parser.add_argument('-dm', '--dualmod', default=None, type=int, help='Every this step use another CLIP ViT model') 53 | parser.add_argument( '--steps', default=200, type=int, help='Total iterations') 54 | parser.add_argument( '--samples', default=50, type=int, help='Samples to evaluate') 55 | parser.add_argument('-lr', '--lrate', default=0.003, type=float, help='Learning rate') 56 | parser.add_argument('-a', '--align', 
default='overscan', choices=['central', 'uniform', 'overscan'], help='Sampling distribution') 57 | parser.add_argument('-sh', '--sharp', default=0, type=float) 58 | parser.add_argument('-tf', '--transform', action='store_true', help='use augmenting transforms?') 59 | parser.add_argument('-mc', '--macro', default=0.4, type=float, help='Endorse macro forms 0..1; -1 = normal big') 60 | parser.add_argument( '--aest', default=0., type=float) 61 | a = parser.parse_args() 62 | if a.size is not None: a.size = [int(s) for s in a.size.split('-')][::-1] 63 | if len(a.size)==1: a.size = a.size * 2 64 | if a.translate is True and googletrans_ok is not True: 65 | print('\n Install googletrans module to enable translation!'); exit() 66 | if a.dualmod is not None: 67 | a.model = 'ViT-B/32' 68 | return a 69 | 70 | 71 | class ConvLayer(nn.Module): 72 | def __init__(self, nf_in, nf_out, act_fn='relu'): 73 | super().__init__() 74 | self.nf_in = nf_in 75 | self.conv = nn.Conv2d(nf_in, nf_out, 1, 1) 76 | if act_fn == 'comp': 77 | self.act_fn = self.composite_activation 78 | elif act_fn == 'unbias': 79 | self.act_fn = self.composite_activation_unbiased 80 | elif act_fn == 'relu': 81 | self.act_fn = self.relu_normalized 82 | else: # last layer (output) 83 | self.act_fn = torch.sigmoid 84 | with torch.no_grad(): # init 85 | self.conv.weight.normal_(0., math.sqrt(1./self.nf_in)) 86 | self.conv.bias.uniform_(-.5, .5) 87 | 88 | def composite_activation(self, x): 89 | x = torch.atan(x) 90 | return torch.cat([x/0.67, (x*x)/0.6], 1) 91 | def composite_activation_unbiased(self, x): 92 | x = torch.atan(x) 93 | return torch.cat([x/0.67, (x*x-0.45)/0.396], 1) 94 | def relu_normalized(self, x): 95 | x = F.relu(x) 96 | return (x-0.40)/0.58 97 | # https://colab.research.google.com/drive/1F1c2ouulmqys-GJBVBHn04I1UVWeexiB 98 | 99 | def forward(self, input): 100 | return self.act_fn(self.conv(input)) 101 | 102 | class CPPN(nn.Module): 103 | def __init__(self, nf_in=2, nf_hid=16, num_layers=9, nf_out=3, act_fn='unbias'): # unbias relu 104 | super().__init__() 105 | nf_hid_in = nf_hid if act_fn == 'relu' else nf_hid*2 106 | self.net = [] 107 | self.net.append(ConvLayer(nf_in, nf_hid, act_fn)) 108 | for i in range(num_layers-1): 109 | self.net.append(ConvLayer(nf_hid_in, nf_hid, act_fn)) 110 | self.net.append(ConvLayer(nf_hid_in, nf_out, 'sigmoid')) 111 | self.net = nn.Sequential(*self.net) 112 | 113 | def forward(self, coords): 114 | coords = coords.clone().detach().requires_grad_(True) # [1,3,h,w] 115 | output = self.net(coords.cuda()) 116 | return output 117 | 118 | def load_cppn(file, verbose=True): # actfn='unbias' 119 | params = np.load(file, allow_pickle=True) 120 | nf = params[0].shape[-1] 121 | num_layers = len(params) // 2 - 1 122 | act_fn = 'relu' if params[0].shape[-1] == params[2].shape[-2] else 'unbias' 123 | snet = CPPN(2, nf, num_layers, 3, act_fn=act_fn).cuda() 124 | if verbose is True: print(' loaded:', file) 125 | if verbose is True: print(' .. 
%d vars, %d layers, %d nf, act %s' % (len(params), num_layers, nf, act_fn)) 126 | keys = list(snet.state_dict().keys()) 127 | assert len(keys) == len(params) 128 | cppn_dict = OrderedDict({}) 129 | for lnum in range(0, len(keys), 2): 130 | cppn_dict[keys[lnum]] = np.transpose(torch.from_numpy(params[lnum]), (3,2,1,0)) 131 | cppn_dict[keys[lnum+1]] = torch.from_numpy(params[lnum+1]) 132 | snet.load_state_dict(cppn_dict) 133 | return snet 134 | 135 | def get_mgrid(sideX, sideY): 136 | tensors = [np.linspace(-1, 1, num=sideY), np.linspace(-1, 1, num=sideX)] 137 | mgrid = np.stack(np.meshgrid(*tensors), axis=-1) 138 | mgrid = np.transpose(mgrid, (2,0,1))[np.newaxis] 139 | return mgrid 140 | 141 | def export_gfx(model, out_name, mode, precision, size): 142 | shader = cppn_to_shader(model, mode=mode, verbose=False, fix_aspect=True, size=size, precision=precision) 143 | if mode == 'vvvv': out_path = out_name + '.tfx' 144 | elif mode == 'buffer': out_path = out_name + '.txt' 145 | else: out_path = out_name + '-%s.glsl' % mode 146 | with open(out_path, 'wt') as f: 147 | f.write(shader) 148 | return out_path 149 | 150 | def export_data(cppn_dict, out_name, size, decim=3, actfn='unbias', shaders=False, npy=True): 151 | if npy is True: arrays = [] 152 | if shaders is True: params = [] 153 | keys = list(cppn_dict.keys()) 154 | 155 | for lnum in range(0, len(keys), 2): 156 | w = cppn_dict[keys[lnum]].permute((3,2,1,0)).cpu().numpy() 157 | b = cppn_dict[keys[lnum+1]].cpu().numpy() 158 | if shaders is True: params.append({'weights': w, 'bias': b, 'activation': actfn}) 159 | if npy is True: arrays += [w,b] 160 | 161 | if npy is True: 162 | np.save(out_name + '.npy', np.array(arrays, object)) 163 | if shaders is True: 164 | export_gfx(params, out_name, 'td', decim, size) 165 | export_gfx(params, out_name, 'vvvv', decim, size) 166 | export_gfx(params, out_name, 'buffer', decim, size) 167 | export_gfx(params, out_name, 'bookofshaders', decim, size) 168 | export_gfx(params, out_name, 'shadertoy', decim, size) 169 | 170 | 171 | def main(): 172 | a = get_args() 173 | bx = 1. 174 | 175 | mgrid = get_mgrid(*a.size) 176 | mgrid = torch.from_numpy(mgrid.astype(np.float32)).cuda() 177 | 178 | # Load models 179 | if a.resume is not None and os.path.isfile(a.resume): 180 | snet = load_cppn(a.resume) 181 | else: 182 | snet = CPPN(mgrid.shape[1], a.nf, a.layers, 3, act_fn=a.actfn).cuda() 183 | print(' .. 
%d vars, %d layers, %d nf, act %s' % (len(snet.state_dict().keys()), a.layers, a.nf, a.actfn)) 184 | 185 | if a.export is True: 186 | print('exporting') 187 | export_data(snet.state_dict(), a.resume.replace('.npy', ''), a.size, a.decim, a.actfn, shaders=True, npy=False) 188 | img = snet(mgrid).detach().cpu().numpy()[0] 189 | checkout(img, a.resume.replace('.npy', '.jpg'), verbose=False) 190 | exit(0) 191 | 192 | model_clip, _ = clip.load(a.model, jit=old_torch()) 193 | try: 194 | a.modsize = model_clip.visual.input_resolution 195 | except: 196 | a.modsize = 288 if a.model == 'RN50x4' else 384 if a.model == 'RN50x16' else 448 if a.model == 'RN50x64' else 224 197 | xmem = {'ViT-B/16':0.25, 'ViT-L/14':0.11, 'RN50':0.5, 'RN50x4':0.16, 'RN50x16':0.06, 'RN50x64':0.04, 'RN101':0.33} 198 | if a.model in xmem.keys(): 199 | a.samples = int(a.samples * xmem[a.model]) 200 | 201 | if a.dualmod is not None: 202 | model_clip2, _ = clip.load('ViT-B/16', jit=old_torch()) 203 | a.samples = int(a.samples * 0.69) # second is vit-16 204 | dualmod_nums = list(range(a.steps))[a.dualmod::a.dualmod] 205 | print(' dual model every %d step' % a.dualmod) 206 | 207 | if a.aest != 0 and a.model in ['ViT-B/32', 'ViT-B/16', 'ViT-L/14']: 208 | aest = aesthetic_model(a.model).cuda() 209 | if a.dualmod is not None: 210 | aest2 = aesthetic_model('ViT-B/16').cuda() 211 | 212 | def enc_text(txt, model_clip=model_clip): 213 | if txt is None or len(txt)==0: return None 214 | emb = model_clip.encode_text(clip.tokenize(txt).cuda()[:,:77]) 215 | return emb.detach().clone() 216 | 217 | optimizer = torch.optim.Adam(snet.parameters(), a.lrate) # orig .00001, better 0.0001 218 | 219 | if a.transform is True: 220 | trform_f = transforms.trfm_fast 221 | a.samples = int(a.samples * 0.95) 222 | else: 223 | trform_f = transforms.normalize() 224 | 225 | out_name = [] 226 | if a.in_txt is not None: 227 | print(' ref text: ', basename(a.in_txt)) 228 | if a.translate: 229 | translator = Translator() 230 | a.in_txt = translator.translate(a.in_txt, dest='en').text 231 | print(' translated to:', a.in_txt) 232 | txt_enc = enc_text(a.in_txt) 233 | if a.dualmod is not None: 234 | txt_enc2 = enc_text(a.in_txt, model_clip2) 235 | out_name.append(txt_clean(a.in_txt)) 236 | 237 | if a.in_txt0 is not None: 238 | print(' no text: ', basename(a.in_txt0)) 239 | if a.translate: 240 | translator = Translator() 241 | a.in_txt0 = translator.translate(a.in_txt0, dest='en').text 242 | print(' translated to:', a.in_txt0) 243 | not_enc = enc_text(a.in_txt0) 244 | if a.dualmod is not None: 245 | not_enc2 = enc_text(a.in_txt0, model_clip2) 246 | 247 | img_enc = None 248 | if a.in_img is not None and os.path.isfile(a.in_img): 249 | print(' ref image:', basename(a.in_img)) 250 | img_in = torch.from_numpy(img_read(a.in_img)/255.).unsqueeze(0).permute(0,3,1,2).cuda() 251 | in_sliced = slice_imgs([img_in], a.samples, a.modsize, transforms.normalize(), a.align)[0] 252 | img_enc = model_clip.encode_image(in_sliced).detach().clone() 253 | if a.dualmod is not None: 254 | img_enc2 = model_clip2.encode_image(in_sliced).detach().clone() 255 | del img_in, in_sliced; torch.cuda.empty_cache() 256 | out_name.append(basename(a.in_img).replace(' ', '_')) 257 | 258 | # Prepare dirs 259 | sfx = '-l%d-n%d' % (a.layers, a.nf) 260 | if a.dualmod is not None: sfx += '-dm%d' % a.dualmod 261 | if a.aest != 0: sfx += '-ae%.2g' % a.aest 262 | workdir = os.path.join(a.out_dir, 'cppn') 263 | out_name = os.path.join(workdir, '-'.join(out_name) + sfx) 264 | tempdir = out_name 265 | 
os.makedirs(out_name, exist_ok=True) 266 | print(a.samples) 267 | 268 | def train(i, img_enc=None): 269 | loss = 0 270 | img_out = snet(mgrid) 271 | 272 | txt_enc_ = txt_enc2 if a.dualmod is not None and i in dualmod_nums else txt_enc 273 | if a.in_img is not None and os.path.isfile(a.in_img): 274 | img_enc_ = img_enc2 if a.dualmod is not None and i in dualmod_nums else img_enc 275 | if a.in_txt0 is not None: 276 | not_enc_ = not_enc2 if a.dualmod is not None and i in dualmod_nums else not_enc 277 | model_clip_ = model_clip2 if a.dualmod is not None and i in dualmod_nums else model_clip 278 | if a.aest != 0: 279 | aest_ = aest2 if a.dualmod is not None and i in dualmod_nums else aest 280 | 281 | imgs_sliced = slice_imgs([img_out], a.samples, a.modsize, trform_f, a.align, a.macro) 282 | out_enc = model_clip_.encode_image(imgs_sliced[-1]) 283 | if a.aest != 0 and aest_ is not None: 284 | loss -= 0.001 * a.aest * aest_(out_enc).mean() 285 | if a.in_txt is not None: 286 | loss -= torch.cosine_similarity(txt_enc_, out_enc, dim=-1).mean() 287 | if a.in_txt0 is not None: 288 | loss += 0.5 * torch.cosine_similarity(not_enc_, out_enc, dim=-1).mean() 289 | if a.in_img is not None and os.path.isfile(a.in_img): 290 | loss -= torch.cosine_similarity(img_enc_, out_enc, dim=-1).mean() 291 | if a.sharp != 0: # mode = scharr|sobel|default 292 | loss -= a.sharp * derivat(img_out, mode='sobel') 293 | del img_out, imgs_sliced, out_enc; torch.cuda.empty_cache() 294 | 295 | optimizer.zero_grad() 296 | loss.backward() 297 | optimizer.step() 298 | 299 | if i % a.fstep == 0: 300 | with torch.no_grad(): 301 | img = snet(mgrid).cpu().numpy()[0] 302 | fname = os.path.join(tempdir, '%04d' % (i // a.fstep)) 303 | checkout(img, fname + '.jpg', verbose=a.verbose) 304 | export_data(snet.state_dict(), fname, a.size, a.decim) 305 | return 306 | 307 | pbar = ProgressBar(a.steps) 308 | for i in range(a.steps): 309 | log = train(i, img_enc) 310 | pbar.upd(log) 311 | 312 | export_data(snet.state_dict(), out_name, a.size, a.decim, shaders=True) 313 | os.system('ffmpeg -v warning -y -i %s\%%04d.jpg -c:v mjpeg -pix_fmt yuvj444p -dst_range 1 -q:v 2 "%s.avi"' % (tempdir, out_name)) 314 | shutil.copy(img_list(tempdir)[-1], out_name + '-%d.jpg' % a.steps) 315 | # shutil.rmtree(tempdir) 316 | 317 | 318 | if __name__ == '__main__': 319 | main() 320 | -------------------------------------------------------------------------------- /depth/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/aphantasia/8a415286d2891e92d865150d6e0e59fdfd32fb01/depth/__init__.py -------------------------------------------------------------------------------- /depth/any2/dinov2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 9 | 10 | from functools import partial 11 | import math 12 | import logging 13 | from typing import Sequence, Tuple, Union, Callable 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.utils.checkpoint 18 | from torch.nn.init import trunc_normal_ 19 | 20 | from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block 21 | 22 | 23 | logger = logging.getLogger("dinov2") 24 | 25 | 26 | def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: 27 | if not depth_first and include_root: 28 | fn(module=module, name=name) 29 | for child_name, child_module in module.named_children(): 30 | child_name = ".".join((name, child_name)) if name else child_name 31 | named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) 32 | if depth_first and include_root: 33 | fn(module=module, name=name) 34 | return module 35 | 36 | 37 | class BlockChunk(nn.ModuleList): 38 | def forward(self, x): 39 | for b in self: 40 | x = b(x) 41 | return x 42 | 43 | 44 | class DinoVisionTransformer(nn.Module): 45 | def __init__( 46 | self, 47 | img_size=224, 48 | patch_size=16, 49 | in_chans=3, 50 | embed_dim=768, 51 | depth=12, 52 | num_heads=12, 53 | mlp_ratio=4.0, 54 | qkv_bias=True, 55 | ffn_bias=True, 56 | proj_bias=True, 57 | drop_path_rate=0.0, 58 | drop_path_uniform=False, 59 | init_values=None, # for layerscale: None or 0 => no layerscale 60 | embed_layer=PatchEmbed, 61 | act_layer=nn.GELU, 62 | block_fn=Block, 63 | ffn_layer="mlp", 64 | block_chunks=1, 65 | num_register_tokens=0, 66 | interpolate_antialias=False, 67 | interpolate_offset=0.1, 68 | ): 69 | """ 70 | Args: 71 | img_size (int, tuple): input image size 72 | patch_size (int, tuple): patch size 73 | in_chans (int): number of input channels 74 | embed_dim (int): embedding dimension 75 | depth (int): depth of transformer 76 | num_heads (int): number of attention heads 77 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 78 | qkv_bias (bool): enable bias for qkv if True 79 | proj_bias (bool): enable bias for proj in attn if True 80 | ffn_bias (bool): enable bias for ffn if True 81 | drop_path_rate (float): stochastic depth rate 82 | drop_path_uniform (bool): apply uniform drop rate across blocks 83 | weight_init (str): weight init scheme 84 | init_values (float): layer-scale init values 85 | embed_layer (nn.Module): patch embedding layer 86 | act_layer (nn.Module): MLP activation layer 87 | block_fn (nn.Module): transformer block class 88 | ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" 89 | block_chunks: (int) split block sequence into block_chunks units for FSDP wrap 90 | num_register_tokens: (int) number of extra cls tokens (so-called "registers") 91 | interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings 92 | interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings 93 | """ 94 | super().__init__() 95 | norm_layer = partial(nn.LayerNorm, eps=1e-6) 96 | 97 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 98 | self.num_tokens = 1 99 | self.n_blocks = depth 100 | self.num_heads = num_heads 101 | self.patch_size = patch_size 102 | self.num_register_tokens = 
num_register_tokens 103 | self.interpolate_antialias = interpolate_antialias 104 | self.interpolate_offset = interpolate_offset 105 | 106 | self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 107 | num_patches = self.patch_embed.num_patches 108 | 109 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 110 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) 111 | assert num_register_tokens >= 0 112 | self.register_tokens = ( 113 | nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None 114 | ) 115 | 116 | if drop_path_uniform is True: 117 | dpr = [drop_path_rate] * depth 118 | else: 119 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 120 | 121 | if ffn_layer == "mlp": 122 | logger.info("using MLP layer as FFN") 123 | ffn_layer = Mlp 124 | elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": 125 | logger.info("using SwiGLU layer as FFN") 126 | ffn_layer = SwiGLUFFNFused 127 | elif ffn_layer == "identity": 128 | logger.info("using Identity layer as FFN") 129 | 130 | def f(*args, **kwargs): 131 | return nn.Identity() 132 | 133 | ffn_layer = f 134 | else: 135 | raise NotImplementedError 136 | 137 | blocks_list = [ 138 | block_fn( 139 | dim=embed_dim, 140 | num_heads=num_heads, 141 | mlp_ratio=mlp_ratio, 142 | qkv_bias=qkv_bias, 143 | proj_bias=proj_bias, 144 | ffn_bias=ffn_bias, 145 | drop_path=dpr[i], 146 | norm_layer=norm_layer, 147 | act_layer=act_layer, 148 | ffn_layer=ffn_layer, 149 | init_values=init_values, 150 | ) 151 | for i in range(depth) 152 | ] 153 | if block_chunks > 0: 154 | self.chunked_blocks = True 155 | chunked_blocks = [] 156 | chunksize = depth // block_chunks 157 | for i in range(0, depth, chunksize): 158 | # this is to keep the block index consistent if we chunk the block list 159 | chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) 160 | self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) 161 | else: 162 | self.chunked_blocks = False 163 | self.blocks = nn.ModuleList(blocks_list) 164 | 165 | self.norm = norm_layer(embed_dim) 166 | self.head = nn.Identity() 167 | 168 | self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) 169 | 170 | self.init_weights() 171 | 172 | def init_weights(self): 173 | trunc_normal_(self.pos_embed, std=0.02) 174 | nn.init.normal_(self.cls_token, std=1e-6) 175 | if self.register_tokens is not None: 176 | nn.init.normal_(self.register_tokens, std=1e-6) 177 | named_apply(init_weights_vit_timm, self) 178 | 179 | def interpolate_pos_encoding(self, x, w, h): 180 | previous_dtype = x.dtype 181 | npatch = x.shape[1] - 1 182 | N = self.pos_embed.shape[1] - 1 183 | if npatch == N and w == h: 184 | return self.pos_embed 185 | pos_embed = self.pos_embed.float() 186 | class_pos_embed = pos_embed[:, 0] 187 | patch_pos_embed = pos_embed[:, 1:] 188 | dim = x.shape[-1] 189 | w0 = w // self.patch_size 190 | h0 = h // self.patch_size 191 | # we add a small number to avoid floating point error in the interpolation 192 | # see discussion at https://github.com/facebookresearch/dino/issues/8 193 | # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0 194 | w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset 195 | # w0, h0 = w0 + 0.1, h0 + 0.1 196 | 197 | sqrt_N = math.sqrt(N) 198 | sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N 199 | patch_pos_embed = nn.functional.interpolate( 200 | 
patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2), 201 | scale_factor=(sx, sy), 202 | # (int(w0), int(h0)), # to solve the upsampling shape issue 203 | mode="bicubic", 204 | antialias=self.interpolate_antialias 205 | ) 206 | 207 | assert int(w0) == patch_pos_embed.shape[-2] 208 | assert int(h0) == patch_pos_embed.shape[-1] 209 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 210 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) 211 | 212 | def prepare_tokens_with_masks(self, x, masks=None): 213 | B, nc, w, h = x.shape 214 | x = self.patch_embed(x) 215 | if masks is not None: 216 | x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) 217 | 218 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) 219 | x = x + self.interpolate_pos_encoding(x, w, h) 220 | 221 | if self.register_tokens is not None: 222 | x = torch.cat( 223 | ( 224 | x[:, :1], 225 | self.register_tokens.expand(x.shape[0], -1, -1), 226 | x[:, 1:], 227 | ), 228 | dim=1, 229 | ) 230 | 231 | return x 232 | 233 | def forward_features_list(self, x_list, masks_list): 234 | x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] 235 | for blk in self.blocks: 236 | x = blk(x) 237 | 238 | all_x = x 239 | output = [] 240 | for x, masks in zip(all_x, masks_list): 241 | x_norm = self.norm(x) 242 | output.append( 243 | { 244 | "x_norm_clstoken": x_norm[:, 0], 245 | "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], 246 | "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], 247 | "x_prenorm": x, 248 | "masks": masks, 249 | } 250 | ) 251 | return output 252 | 253 | def forward_features(self, x, masks=None): 254 | if isinstance(x, list): 255 | return self.forward_features_list(x, masks) 256 | 257 | x = self.prepare_tokens_with_masks(x, masks) 258 | 259 | for blk in self.blocks: 260 | x = blk(x) 261 | 262 | x_norm = self.norm(x) 263 | return { 264 | "x_norm_clstoken": x_norm[:, 0], 265 | "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], 266 | "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], 267 | "x_prenorm": x, 268 | "masks": masks, 269 | } 270 | 271 | def _get_intermediate_layers_not_chunked(self, x, n=1): 272 | x = self.prepare_tokens_with_masks(x) 273 | # If n is an int, take the n last blocks. If it's a list, take them 274 | output, total_block_len = [], len(self.blocks) 275 | blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n 276 | for i, blk in enumerate(self.blocks): 277 | x = blk(x) 278 | if i in blocks_to_take: 279 | output.append(x) 280 | assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" 281 | return output 282 | 283 | def _get_intermediate_layers_chunked(self, x, n=1): 284 | x = self.prepare_tokens_with_masks(x) 285 | output, i, total_block_len = [], 0, len(self.blocks[-1]) 286 | # If n is an int, take the n last blocks. 
If it's a list, take them 287 | blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n 288 | for block_chunk in self.blocks: 289 | for blk in block_chunk[i:]: # Passing the nn.Identity() 290 | x = blk(x) 291 | if i in blocks_to_take: 292 | output.append(x) 293 | i += 1 294 | assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" 295 | return output 296 | 297 | def get_intermediate_layers( 298 | self, 299 | x: torch.Tensor, 300 | n: Union[int, Sequence] = 1, # Layers or n last layers to take 301 | reshape: bool = False, 302 | return_class_token: bool = False, 303 | norm=True 304 | ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: 305 | if self.chunked_blocks: 306 | outputs = self._get_intermediate_layers_chunked(x, n) 307 | else: 308 | outputs = self._get_intermediate_layers_not_chunked(x, n) 309 | if norm: 310 | outputs = [self.norm(out) for out in outputs] 311 | class_tokens = [out[:, 0] for out in outputs] 312 | outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs] 313 | if reshape: 314 | B, _, w, h = x.shape 315 | outputs = [ 316 | out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() 317 | for out in outputs 318 | ] 319 | if return_class_token: 320 | return tuple(zip(outputs, class_tokens)) 321 | return tuple(outputs) 322 | 323 | def forward(self, *args, is_training=False, **kwargs): 324 | ret = self.forward_features(*args, **kwargs) 325 | if is_training: 326 | return ret 327 | else: 328 | return self.head(ret["x_norm_clstoken"]) 329 | 330 | 331 | def init_weights_vit_timm(module: nn.Module, name: str = ""): 332 | """ViT weight initialization, original timm impl (for reproducibility)""" 333 | if isinstance(module, nn.Linear): 334 | trunc_normal_(module.weight, std=0.02) 335 | if module.bias is not None: 336 | nn.init.zeros_(module.bias) 337 | 338 | 339 | def vit_small(patch_size=16, num_register_tokens=0, **kwargs): 340 | model = DinoVisionTransformer( 341 | patch_size=patch_size, 342 | embed_dim=384, 343 | depth=12, 344 | num_heads=6, 345 | mlp_ratio=4, 346 | block_fn=partial(Block, attn_class=MemEffAttention), 347 | num_register_tokens=num_register_tokens, 348 | **kwargs, 349 | ) 350 | return model 351 | 352 | 353 | def vit_base(patch_size=16, num_register_tokens=0, **kwargs): 354 | model = DinoVisionTransformer( 355 | patch_size=patch_size, 356 | embed_dim=768, 357 | depth=12, 358 | num_heads=12, 359 | mlp_ratio=4, 360 | block_fn=partial(Block, attn_class=MemEffAttention), 361 | num_register_tokens=num_register_tokens, 362 | **kwargs, 363 | ) 364 | return model 365 | 366 | 367 | def vit_large(patch_size=16, num_register_tokens=0, **kwargs): 368 | model = DinoVisionTransformer( 369 | patch_size=patch_size, 370 | embed_dim=1024, 371 | depth=24, 372 | num_heads=16, 373 | mlp_ratio=4, 374 | block_fn=partial(Block, attn_class=MemEffAttention), 375 | num_register_tokens=num_register_tokens, 376 | **kwargs, 377 | ) 378 | return model 379 | 380 | 381 | def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): 382 | """ 383 | Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 384 | """ 385 | model = DinoVisionTransformer( 386 | patch_size=patch_size, 387 | embed_dim=1536, 388 | depth=40, 389 | num_heads=24, 390 | mlp_ratio=4, 391 | block_fn=partial(Block, attn_class=MemEffAttention), 392 | num_register_tokens=num_register_tokens, 393 | **kwargs, 394 | ) 395 | return model 396 | 397 | 398 | def DINOv2(model_name): 399 | 
model_zoo = { 400 | "vits": vit_small, 401 | "vitb": vit_base, 402 | "vitl": vit_large, 403 | "vitg": vit_giant2 404 | } 405 | 406 | return model_zoo[model_name]( 407 | img_size=518, 408 | patch_size=14, 409 | init_values=1.0, 410 | ffn_layer="mlp" if model_name != "vitg" else "swiglufused", 411 | block_chunks=0, 412 | num_register_tokens=0, 413 | interpolate_antialias=False, 414 | interpolate_offset=0.1 415 | ) 416 | -------------------------------------------------------------------------------- /depth/any2/dinov2_layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .mlp import Mlp 8 | from .patch_embed import PatchEmbed 9 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused 10 | from .block import NestedTensorBlock 11 | from .attention import MemEffAttention 12 | -------------------------------------------------------------------------------- /depth/any2/dinov2_layers/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 10 | 11 | import logging 12 | 13 | from torch import Tensor 14 | from torch import nn 15 | 16 | 17 | logger = logging.getLogger("dinov2") 18 | 19 | 20 | try: 21 | from xformers.ops import memory_efficient_attention, unbind, fmha 22 | 23 | XFORMERS_AVAILABLE = True 24 | except ImportError: 25 | logger.warning("xFormers not available") 26 | XFORMERS_AVAILABLE = False 27 | 28 | 29 | class Attention(nn.Module): 30 | def __init__( 31 | self, 32 | dim: int, 33 | num_heads: int = 8, 34 | qkv_bias: bool = False, 35 | proj_bias: bool = True, 36 | attn_drop: float = 0.0, 37 | proj_drop: float = 0.0, 38 | ) -> None: 39 | super().__init__() 40 | self.num_heads = num_heads 41 | head_dim = dim // num_heads 42 | self.scale = head_dim**-0.5 43 | 44 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 45 | self.attn_drop = nn.Dropout(attn_drop) 46 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 47 | self.proj_drop = nn.Dropout(proj_drop) 48 | 49 | def forward(self, x: Tensor) -> Tensor: 50 | B, N, C = x.shape 51 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 52 | 53 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 54 | attn = q @ k.transpose(-2, -1) 55 | 56 | attn = attn.softmax(dim=-1) 57 | attn = self.attn_drop(attn) 58 | 59 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 60 | x = self.proj(x) 61 | x = self.proj_drop(x) 62 | return x 63 | 64 | 65 | class MemEffAttention(Attention): 66 | def forward(self, x: Tensor, attn_bias=None) -> Tensor: 67 | if not XFORMERS_AVAILABLE: 68 | assert attn_bias is None, "xFormers is required for nested tensors usage" 69 | return super().forward(x) 70 | 71 | B, N, C = x.shape 72 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 73 | 74 | q, k, v = unbind(qkv, 2) 75 | 76 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 77 | x = 
x.reshape([B, N, C]) 78 | 79 | x = self.proj(x) 80 | x = self.proj_drop(x) 81 | return x 82 | 83 | -------------------------------------------------------------------------------- /depth/any2/dinov2_layers/block.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 10 | 11 | import logging 12 | from typing import Callable, List, Any, Tuple, Dict 13 | 14 | import torch 15 | from torch import nn, Tensor 16 | 17 | from .attention import Attention, MemEffAttention 18 | from .drop_path import DropPath 19 | from .layer_scale import LayerScale 20 | from .mlp import Mlp 21 | 22 | 23 | logger = logging.getLogger("dinov2") 24 | 25 | 26 | try: 27 | from xformers.ops import fmha 28 | from xformers.ops import scaled_index_add, index_select_cat 29 | 30 | XFORMERS_AVAILABLE = True 31 | except ImportError: 32 | logger.warning("xFormers not available") 33 | XFORMERS_AVAILABLE = False 34 | 35 | 36 | class Block(nn.Module): 37 | def __init__( 38 | self, 39 | dim: int, 40 | num_heads: int, 41 | mlp_ratio: float = 4.0, 42 | qkv_bias: bool = False, 43 | proj_bias: bool = True, 44 | ffn_bias: bool = True, 45 | drop: float = 0.0, 46 | attn_drop: float = 0.0, 47 | init_values=None, 48 | drop_path: float = 0.0, 49 | act_layer: Callable[..., nn.Module] = nn.GELU, 50 | norm_layer: Callable[..., nn.Module] = nn.LayerNorm, 51 | attn_class: Callable[..., nn.Module] = Attention, 52 | ffn_layer: Callable[..., nn.Module] = Mlp, 53 | ) -> None: 54 | super().__init__() 55 | # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") 56 | self.norm1 = norm_layer(dim) 57 | self.attn = attn_class( 58 | dim, 59 | num_heads=num_heads, 60 | qkv_bias=qkv_bias, 61 | proj_bias=proj_bias, 62 | attn_drop=attn_drop, 63 | proj_drop=drop, 64 | ) 65 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 66 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 67 | 68 | self.norm2 = norm_layer(dim) 69 | mlp_hidden_dim = int(dim * mlp_ratio) 70 | self.mlp = ffn_layer( 71 | in_features=dim, 72 | hidden_features=mlp_hidden_dim, 73 | act_layer=act_layer, 74 | drop=drop, 75 | bias=ffn_bias, 76 | ) 77 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 78 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 79 | 80 | self.sample_drop_ratio = drop_path 81 | 82 | def forward(self, x: Tensor) -> Tensor: 83 | def attn_residual_func(x: Tensor) -> Tensor: 84 | return self.ls1(self.attn(self.norm1(x))) 85 | 86 | def ffn_residual_func(x: Tensor) -> Tensor: 87 | return self.ls2(self.mlp(self.norm2(x))) 88 | 89 | if self.training and self.sample_drop_ratio > 0.1: 90 | # the overhead is compensated only for a drop path rate larger than 0.1 91 | x = drop_add_residual_stochastic_depth( 92 | x, 93 | residual_func=attn_residual_func, 94 | sample_drop_ratio=self.sample_drop_ratio, 95 | ) 96 | x = drop_add_residual_stochastic_depth( 97 | x, 98 | residual_func=ffn_residual_func, 99 | sample_drop_ratio=self.sample_drop_ratio, 100 | ) 101 | elif self.training and self.sample_drop_ratio > 0.0: 102 | x 
= x + self.drop_path1(attn_residual_func(x)) 103 | x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 104 | else: 105 | x = x + attn_residual_func(x) 106 | x = x + ffn_residual_func(x) 107 | return x 108 | 109 | 110 | def drop_add_residual_stochastic_depth( 111 | x: Tensor, 112 | residual_func: Callable[[Tensor], Tensor], 113 | sample_drop_ratio: float = 0.0, 114 | ) -> Tensor: 115 | # 1) extract subset using permutation 116 | b, n, d = x.shape 117 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 118 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 119 | x_subset = x[brange] 120 | 121 | # 2) apply residual_func to get residual 122 | residual = residual_func(x_subset) 123 | 124 | x_flat = x.flatten(1) 125 | residual = residual.flatten(1) 126 | 127 | residual_scale_factor = b / sample_subset_size 128 | 129 | # 3) add the residual 130 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) 131 | return x_plus_residual.view_as(x) 132 | 133 | 134 | def get_branges_scales(x, sample_drop_ratio=0.0): 135 | b, n, d = x.shape 136 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 137 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 138 | residual_scale_factor = b / sample_subset_size 139 | return brange, residual_scale_factor 140 | 141 | 142 | def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): 143 | if scaling_vector is None: 144 | x_flat = x.flatten(1) 145 | residual = residual.flatten(1) 146 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) 147 | else: 148 | x_plus_residual = scaled_index_add( 149 | x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor 150 | ) 151 | return x_plus_residual 152 | 153 | 154 | attn_bias_cache: Dict[Tuple, Any] = {} 155 | 156 | 157 | def get_attn_bias_and_cat(x_list, branges=None): 158 | """ 159 | this will perform the index select, cat the tensors, and provide the attn_bias from cache 160 | """ 161 | batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] 162 | all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) 163 | if all_shapes not in attn_bias_cache.keys(): 164 | seqlens = [] 165 | for b, x in zip(batch_sizes, x_list): 166 | for _ in range(b): 167 | seqlens.append(x.shape[1]) 168 | attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) 169 | attn_bias._batch_sizes = batch_sizes 170 | attn_bias_cache[all_shapes] = attn_bias 171 | 172 | if branges is not None: 173 | cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) 174 | else: 175 | tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) 176 | cat_tensors = torch.cat(tensors_bs1, dim=1) 177 | 178 | return attn_bias_cache[all_shapes], cat_tensors 179 | 180 | 181 | def drop_add_residual_stochastic_depth_list( 182 | x_list: List[Tensor], 183 | residual_func: Callable[[Tensor, Any], Tensor], 184 | sample_drop_ratio: float = 0.0, 185 | scaling_vector=None, 186 | ) -> Tensor: 187 | # 1) generate random set of indices for dropping samples in the batch 188 | branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] 189 | branges = [s[0] for s in branges_scales] 190 | residual_scale_factors = [s[1] for s in branges_scales] 191 | 192 | # 2) get attention bias and index+concat the tensors 
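    # get_attn_bias_and_cat packs the selected sub-batches into one long sequence and returns a cached
    # fmha.BlockDiagonalMask, so attention is computed for all inputs at once without crossing sample boundaries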
193 | attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) 194 | 195 | # 3) apply residual_func to get residual, and split the result 196 | residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore 197 | 198 | outputs = [] 199 | for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): 200 | outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) 201 | return outputs 202 | 203 | 204 | class NestedTensorBlock(Block): 205 | def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: 206 | """ 207 | x_list contains a list of tensors to nest together and run 208 | """ 209 | assert isinstance(self.attn, MemEffAttention) 210 | 211 | if self.training and self.sample_drop_ratio > 0.0: 212 | 213 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 214 | return self.attn(self.norm1(x), attn_bias=attn_bias) 215 | 216 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 217 | return self.mlp(self.norm2(x)) 218 | 219 | x_list = drop_add_residual_stochastic_depth_list( 220 | x_list, 221 | residual_func=attn_residual_func, 222 | sample_drop_ratio=self.sample_drop_ratio, 223 | scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, 224 | ) 225 | x_list = drop_add_residual_stochastic_depth_list( 226 | x_list, 227 | residual_func=ffn_residual_func, 228 | sample_drop_ratio=self.sample_drop_ratio, 229 | scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, 230 | ) 231 | return x_list 232 | else: 233 | 234 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 235 | return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) 236 | 237 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 238 | return self.ls2(self.mlp(self.norm2(x))) 239 | 240 | attn_bias, x = get_attn_bias_and_cat(x_list) 241 | x = x + attn_residual_func(x, attn_bias=attn_bias) 242 | x = x + ffn_residual_func(x) 243 | return attn_bias.split(x) 244 | 245 | def forward(self, x_or_x_list): 246 | if isinstance(x_or_x_list, Tensor): 247 | return super().forward(x_or_x_list) 248 | elif isinstance(x_or_x_list, list): 249 | assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage" 250 | return self.forward_nested(x_or_x_list) 251 | else: 252 | raise AssertionError 253 | -------------------------------------------------------------------------------- /depth/any2/dinov2_layers/drop_path.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py 10 | 11 | 12 | from torch import nn 13 | 14 | 15 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 16 | if drop_prob == 0.0 or not training: 17 | return x 18 | keep_prob = 1 - drop_prob 19 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 20 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 21 | if keep_prob > 0.0: 22 | random_tensor.div_(keep_prob) 23 | output = x * random_tensor 24 | return output 25 | 26 | 27 | class DropPath(nn.Module): 28 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 29 | 30 | def __init__(self, drop_prob=None): 31 | super(DropPath, self).__init__() 32 | self.drop_prob = drop_prob 33 | 34 | def forward(self, x): 35 | return drop_path(x, self.drop_prob, self.training) 36 | -------------------------------------------------------------------------------- /depth/any2/dinov2_layers/layer_scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 8 | 9 | from typing import Union 10 | 11 | import torch 12 | from torch import Tensor 13 | from torch import nn 14 | 15 | 16 | class LayerScale(nn.Module): 17 | def __init__( 18 | self, 19 | dim: int, 20 | init_values: Union[float, Tensor] = 1e-5, 21 | inplace: bool = False, 22 | ) -> None: 23 | super().__init__() 24 | self.inplace = inplace 25 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 26 | 27 | def forward(self, x: Tensor) -> Tensor: 28 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 29 | -------------------------------------------------------------------------------- /depth/any2/dinov2_layers/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py 10 | 11 | 12 | from typing import Callable, Optional 13 | 14 | from torch import Tensor, nn 15 | 16 | 17 | class Mlp(nn.Module): 18 | def __init__( 19 | self, 20 | in_features: int, 21 | hidden_features: Optional[int] = None, 22 | out_features: Optional[int] = None, 23 | act_layer: Callable[..., nn.Module] = nn.GELU, 24 | drop: float = 0.0, 25 | bias: bool = True, 26 | ) -> None: 27 | super().__init__() 28 | out_features = out_features or in_features 29 | hidden_features = hidden_features or in_features 30 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 31 | self.act = act_layer() 32 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 33 | self.drop = nn.Dropout(drop) 34 | 35 | def forward(self, x: Tensor) -> Tensor: 36 | x = self.fc1(x) 37 | x = self.act(x) 38 | x = self.drop(x) 39 | x = self.fc2(x) 40 | x = self.drop(x) 41 | return x 42 | -------------------------------------------------------------------------------- /depth/any2/dinov2_layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 10 | 11 | from typing import Callable, Optional, Tuple, Union 12 | 13 | from torch import Tensor 14 | import torch.nn as nn 15 | 16 | 17 | def make_2tuple(x): 18 | if isinstance(x, tuple): 19 | assert len(x) == 2 20 | return x 21 | 22 | assert isinstance(x, int) 23 | return (x, x) 24 | 25 | 26 | class PatchEmbed(nn.Module): 27 | """ 28 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D) 29 | 30 | Args: 31 | img_size: Image size. 32 | patch_size: Patch token size. 33 | in_chans: Number of input image channels. 34 | embed_dim: Number of linear projection output channels. 35 | norm_layer: Normalization layer. 
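        flatten_embedding: If True (default), return tokens flattened to (B, N, D); otherwise keep the (B, H', W', D) spatial layout.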
36 | """ 37 | 38 | def __init__( 39 | self, 40 | img_size: Union[int, Tuple[int, int]] = 224, 41 | patch_size: Union[int, Tuple[int, int]] = 16, 42 | in_chans: int = 3, 43 | embed_dim: int = 768, 44 | norm_layer: Optional[Callable] = None, 45 | flatten_embedding: bool = True, 46 | ) -> None: 47 | super().__init__() 48 | 49 | image_HW = make_2tuple(img_size) 50 | patch_HW = make_2tuple(patch_size) 51 | patch_grid_size = ( 52 | image_HW[0] // patch_HW[0], 53 | image_HW[1] // patch_HW[1], 54 | ) 55 | 56 | self.img_size = image_HW 57 | self.patch_size = patch_HW 58 | self.patches_resolution = patch_grid_size 59 | self.num_patches = patch_grid_size[0] * patch_grid_size[1] 60 | 61 | self.in_chans = in_chans 62 | self.embed_dim = embed_dim 63 | 64 | self.flatten_embedding = flatten_embedding 65 | 66 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) 67 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 68 | 69 | def forward(self, x: Tensor) -> Tensor: 70 | _, _, H, W = x.shape 71 | patch_H, patch_W = self.patch_size 72 | 73 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" 74 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" 75 | 76 | x = self.proj(x) # B C H W 77 | H, W = x.size(2), x.size(3) 78 | x = x.flatten(2).transpose(1, 2) # B HW C 79 | x = self.norm(x) 80 | if not self.flatten_embedding: 81 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C 82 | return x 83 | 84 | def flops(self) -> float: 85 | Ho, Wo = self.patches_resolution 86 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 87 | if self.norm is not None: 88 | flops += Ho * Wo * self.embed_dim 89 | return flops 90 | -------------------------------------------------------------------------------- /depth/any2/dinov2_layers/swiglu_ffn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from typing import Callable, Optional 8 | 9 | from torch import Tensor, nn 10 | import torch.nn.functional as F 11 | 12 | 13 | class SwiGLUFFN(nn.Module): 14 | def __init__( 15 | self, 16 | in_features: int, 17 | hidden_features: Optional[int] = None, 18 | out_features: Optional[int] = None, 19 | act_layer: Callable[..., nn.Module] = None, 20 | drop: float = 0.0, 21 | bias: bool = True, 22 | ) -> None: 23 | super().__init__() 24 | out_features = out_features or in_features 25 | hidden_features = hidden_features or in_features 26 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) 27 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias) 28 | 29 | def forward(self, x: Tensor) -> Tensor: 30 | x12 = self.w12(x) 31 | x1, x2 = x12.chunk(2, dim=-1) 32 | hidden = F.silu(x1) * x2 33 | return self.w3(hidden) 34 | 35 | 36 | try: 37 | from xformers.ops import SwiGLU 38 | 39 | XFORMERS_AVAILABLE = True 40 | except ImportError: 41 | SwiGLU = SwiGLUFFN 42 | XFORMERS_AVAILABLE = False 43 | 44 | 45 | class SwiGLUFFNFused(SwiGLU): 46 | def __init__( 47 | self, 48 | in_features: int, 49 | hidden_features: Optional[int] = None, 50 | out_features: Optional[int] = None, 51 | act_layer: Callable[..., nn.Module] = None, 52 | drop: float = 0.0, 53 | bias: bool = True, 54 | ) -> None: 55 | out_features = out_features or in_features 56 | hidden_features = hidden_features or in_features 57 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 58 | super().__init__( 59 | in_features=in_features, 60 | hidden_features=hidden_features, 61 | out_features=out_features, 62 | bias=bias, 63 | ) 64 | -------------------------------------------------------------------------------- /depth/any2/dpt.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torchvision.transforms import Compose 6 | 7 | from .dinov2 import DINOv2 8 | from .util.blocks import FeatureFusionBlock, _make_scratch 9 | from .util.transform import Resize, NormalizeImage, PrepareForNet 10 | 11 | def _make_fusion_block(features, use_bn, size=None): 12 | return FeatureFusionBlock(features, nn.ReLU(False), deconv=False, bn=use_bn, expand=False, align_corners=True, size=size) 13 | 14 | class ConvBlock(nn.Module): 15 | def __init__(self, in_feature, out_feature): 16 | super().__init__() 17 | self.conv_block = nn.Sequential( 18 | nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1), 19 | nn.BatchNorm2d(out_feature), 20 | nn.ReLU(True) 21 | ) 22 | def forward(self, x): 23 | return self.conv_block(x) 24 | 25 | class DPTHead(nn.Module): 26 | def __init__(self, in_channels, features=256, use_bn=False, out_channels=[256, 512, 1024, 1024], use_clstoken=False): 27 | super(DPTHead, self).__init__() 28 | 29 | self.use_clstoken = use_clstoken 30 | 31 | self.projects = nn.ModuleList([ 32 | nn.Conv2d(in_channels=in_channels, out_channels=out_channel, kernel_size=1, stride=1, padding=0) for out_channel in out_channels 33 | ]) 34 | 35 | self.resize_layers = nn.ModuleList([ 36 | nn.ConvTranspose2d(in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0), 37 | nn.ConvTranspose2d(in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0), 38 | nn.Identity(), 39 | nn.Conv2d(in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1) 40 | ]) 41 | 42 | if use_clstoken: 43 | 
self.readout_projects = nn.ModuleList() 44 | for _ in range(len(self.projects)): 45 | self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU())) 46 | 47 | self.scratch = _make_scratch(out_channels, features, groups=1, expand=False) 48 | 49 | self.scratch.stem_transpose = None 50 | 51 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 52 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 53 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 54 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 55 | 56 | head_features_1 = features 57 | head_features_2 = 32 58 | 59 | self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1) 60 | self.scratch.output_conv2 = nn.Sequential( 61 | nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), 62 | nn.ReLU(True), 63 | nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0), 64 | nn.ReLU(True), 65 | nn.Identity(), 66 | ) 67 | 68 | def forward(self, out_features, patch_h, patch_w): 69 | out = [] 70 | for i, x in enumerate(out_features): 71 | if self.use_clstoken: 72 | x, cls_token = x[0], x[1] 73 | readout = cls_token.unsqueeze(1).expand_as(x) 74 | x = self.readout_projects[i](torch.cat((x, readout), -1)) 75 | else: 76 | x = x[0] 77 | x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) 78 | x = self.projects[i](x) 79 | x = self.resize_layers[i](x) 80 | out.append(x) 81 | 82 | layer_1, layer_2, layer_3, layer_4 = out 83 | 84 | layer_1_rn = self.scratch.layer1_rn(layer_1) 85 | layer_2_rn = self.scratch.layer2_rn(layer_2) 86 | layer_3_rn = self.scratch.layer3_rn(layer_3) 87 | layer_4_rn = self.scratch.layer4_rn(layer_4) 88 | 89 | path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) 90 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:]) 91 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:]) 92 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 93 | 94 | out = self.scratch.output_conv1(path_1) 95 | out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True) 96 | out = self.scratch.output_conv2(out) 97 | return out 98 | 99 | class DepthAnythingV2(nn.Module): 100 | def __init__(self, encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024], use_bn=False, use_clstoken=False): 101 | super(DepthAnythingV2, self).__init__() 102 | self.intermediate_layer_idx = { 103 | 'vits': [2, 5, 8, 11], 104 | 'vitb': [2, 5, 8, 11], 105 | 'vitl': [4, 11, 17, 23], 106 | 'vitg': [9, 19, 29, 39] 107 | } 108 | self.encoder = encoder 109 | self.pretrained = DINOv2(model_name=encoder) 110 | self.depth_head = DPTHead(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken) 111 | 112 | def forward(self, x): 113 | patch_h, patch_w = x.shape[-2] // 14, x.shape[-1] // 14 114 | features = self.pretrained.get_intermediate_layers(x, self.intermediate_layer_idx[self.encoder], return_class_token=True) 115 | depth = self.depth_head(features, patch_h, patch_w) 116 | depth = F.relu(depth) 117 | return depth # .squeeze(1) 118 | 119 | @torch.no_grad() 120 | def infer_image(self, image, input_size=518, bgr=True): 121 | image, (h, w) = self.image2tensor(image, input_size, bgr=bgr) 122 | depth = self.forward(image) 123 | depth = F.interpolate(depth[:, None], (h, w), mode="bilinear", align_corners=True)[0, 0] 124 | return 
depth.cpu().numpy() 125 | 126 | def image2tensor(self, image, input_size=518, bgr=True): 127 | transform = Compose([ 128 | Resize(width=input_size, height=input_size, resize_target=False, keep_aspect_ratio=True, ensure_multiple_of=14, 129 | resize_method='lower_bound', image_interpolation_method=cv2.INTER_CUBIC), 130 | NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 131 | PrepareForNet(), 132 | ]) 133 | h, w = image.shape[:2] 134 | if bgr: image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0 135 | image = transform({'image': image})['image'] 136 | image = torch.from_numpy(image).unsqueeze(0) 137 | DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' 138 | image = image.to(DEVICE) 139 | return image, (h, w) 140 | 141 | -------------------------------------------------------------------------------- /depth/any2/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import cv2 4 | import numpy as np 5 | # import matplotlib 6 | 7 | import torch 8 | 9 | from deptha2.dpt import DepthAnythingV2 10 | 11 | from eps import img_list, img_read, basename, progbar 12 | 13 | parser = argparse.ArgumentParser(description='Depth Anything V2') 14 | parser.add_argument('-i', '--input', default='_in', help='Input image or folder') 15 | parser.add_argument('-o', '--out_dir', default='_out') 16 | parser.add_argument('-md','--maindir', default='./', help='Main directory') 17 | parser.add_argument('--encoder', default='vitl', choices=['vits', 'vitb', 'vitl', 'vitg']) 18 | parser.add_argument('-sz', '--size', type=int, default=768) # 518 19 | parser.add_argument('--seed', default=None, type=int, help='Random seed') 20 | # parser.add_argument('--pre', action='store_true', help='display combined mix') 21 | parser.add_argument('-v', '--verbose', action='store_true') 22 | a = parser.parse_args() 23 | 24 | model_configs = { 25 | 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}, 26 | 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]}, 27 | 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}, 28 | 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]} 29 | } 30 | 31 | def main(): 32 | os.makedirs(a.out_dir, exist_ok=True) 33 | device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu') 34 | 35 | depth_anything = DepthAnythingV2(**model_configs[a.encoder]) 36 | depth_anything.load_state_dict(torch.load(os.path.join(a.maindir, 'models', f'depth_anything_v2_{a.encoder}.pth'), map_location='cpu')) 37 | depth_anything = depth_anything.to(device).eval() 38 | 39 | # cmap = matplotlib.colormaps.get_cmap('Spectral_r') 40 | 41 | paths = [a.input] if os.path.isfile(a.input) else img_list(a.input) 42 | pbar = progbar(len(paths)) 43 | for k, path in enumerate(paths): 44 | img_in = cv2.imread(path) 45 | 46 | depth = depth_anything.infer_image(img_in, a.size) 47 | 48 | depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0 49 | depth = depth.astype(np.uint8) 50 | depth = np.repeat(depth[..., np.newaxis], 3, axis=-1) 51 | # depth = (cmap(depth)[:, :, :3] * 255)[:, :, ::-1].astype(np.uint8) 52 | 53 | # if a.pre: 54 | # split_region = np.ones((img_in.shape[0], 50, 3), dtype=np.uint8) * 255 55 | # depth = cv2.hconcat([img_in, split_region, depth]) 56 | 57 | cv2.imwrite(os.path.join(a.out_dir, 
basename(path) + '.png'), depth) 58 | pbar.upd() 59 | 60 | 61 | if __name__ == '__main__': 62 | main() 63 | -------------------------------------------------------------------------------- /depth/any2/util/blocks.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def _make_scratch(in_shape, out_shape, groups=1, expand=False): 5 | scratch = nn.Module() 6 | 7 | out_shape1 = out_shape 8 | out_shape2 = out_shape 9 | out_shape3 = out_shape 10 | if len(in_shape) >= 4: 11 | out_shape4 = out_shape 12 | 13 | if expand: 14 | out_shape1 = out_shape 15 | out_shape2 = out_shape * 2 16 | out_shape3 = out_shape * 4 17 | if len(in_shape) >= 4: 18 | out_shape4 = out_shape * 8 19 | 20 | scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) 21 | scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) 22 | scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) 23 | if len(in_shape) >= 4: 24 | scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) 25 | 26 | return scratch 27 | 28 | 29 | class ResidualConvUnit(nn.Module): 30 | """Residual convolution module. 31 | """ 32 | 33 | def __init__(self, features, activation, bn): 34 | """Init. 35 | 36 | Args: 37 | features (int): number of features 38 | """ 39 | super().__init__() 40 | 41 | self.bn = bn 42 | 43 | self.groups=1 44 | 45 | self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) 46 | 47 | self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) 48 | 49 | if self.bn == True: 50 | self.bn1 = nn.BatchNorm2d(features) 51 | self.bn2 = nn.BatchNorm2d(features) 52 | 53 | self.activation = activation 54 | 55 | self.skip_add = nn.quantized.FloatFunctional() 56 | 57 | def forward(self, x): 58 | """Forward pass. 59 | 60 | Args: 61 | x (tensor): input 62 | 63 | Returns: 64 | tensor: output 65 | """ 66 | 67 | out = self.activation(x) 68 | out = self.conv1(out) 69 | if self.bn == True: 70 | out = self.bn1(out) 71 | 72 | out = self.activation(out) 73 | out = self.conv2(out) 74 | if self.bn == True: 75 | out = self.bn2(out) 76 | 77 | if self.groups > 1: 78 | out = self.conv_merge(out) 79 | 80 | return self.skip_add.add(out, x) 81 | 82 | 83 | class FeatureFusionBlock(nn.Module): 84 | """Feature fusion block. 85 | """ 86 | 87 | def __init__( 88 | self, 89 | features, 90 | activation, 91 | deconv=False, 92 | bn=False, 93 | expand=False, 94 | align_corners=True, 95 | size=None 96 | ): 97 | """Init. 98 | 99 | Args: 100 | features (int): number of features 101 | """ 102 | super(FeatureFusionBlock, self).__init__() 103 | 104 | self.deconv = deconv 105 | self.align_corners = align_corners 106 | 107 | self.groups=1 108 | 109 | self.expand = expand 110 | out_features = features 111 | if self.expand == True: 112 | out_features = features // 2 113 | 114 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) 115 | 116 | self.resConfUnit1 = ResidualConvUnit(features, activation, bn) 117 | self.resConfUnit2 = ResidualConvUnit(features, activation, bn) 118 | 119 | self.skip_add = nn.quantized.FloatFunctional() 120 | 121 | self.size=size 122 | 123 | def forward(self, *xs, size=None): 124 | """Forward pass. 
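        Args:
            *xs: one or two feature tensors; when two are given, the second is refined by resConfUnit1 and added to the first
            size: optional output size for the bilinear upsampling (otherwise self.size or scale_factor=2 is used)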
125 | 126 | Returns: 127 | tensor: output 128 | """ 129 | output = xs[0] 130 | 131 | if len(xs) == 2: 132 | res = self.resConfUnit1(xs[1]) 133 | output = self.skip_add.add(output, res) 134 | 135 | output = self.resConfUnit2(output) 136 | 137 | if (size is None) and (self.size is None): 138 | modifier = {"scale_factor": 2} 139 | elif size is None: 140 | modifier = {"size": self.size} 141 | else: 142 | modifier = {"size": size} 143 | 144 | output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners) 145 | 146 | output = self.out_conv(output) 147 | 148 | return output 149 | -------------------------------------------------------------------------------- /depth/any2/util/transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | 5 | class Resize(object): 6 | """Resize sample to given size (width, height). 7 | """ 8 | 9 | def __init__( 10 | self, 11 | width, 12 | height, 13 | resize_target=True, 14 | keep_aspect_ratio=False, 15 | ensure_multiple_of=1, 16 | resize_method="lower_bound", 17 | image_interpolation_method=cv2.INTER_AREA, 18 | ): 19 | """Init. 20 | 21 | Args: 22 | width (int): desired output width 23 | height (int): desired output height 24 | resize_target (bool, optional): 25 | True: Resize the full sample (image, mask, target). 26 | False: Resize image only. 27 | Defaults to True. 28 | keep_aspect_ratio (bool, optional): 29 | True: Keep the aspect ratio of the input sample. 30 | Output sample might not have the given width and height, and 31 | resize behaviour depends on the parameter 'resize_method'. 32 | Defaults to False. 33 | ensure_multiple_of (int, optional): 34 | Output width and height is constrained to be multiple of this parameter. 35 | Defaults to 1. 36 | resize_method (str, optional): 37 | "lower_bound": Output will be at least as large as the given size. 38 | "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) 39 | "minimal": Scale as least as possible. (Output size might be smaller than given size.) 40 | Defaults to "lower_bound". 
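            image_interpolation_method (int, optional):
                OpenCV interpolation flag used when resizing the image.
                Defaults to cv2.INTER_AREA.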
41 | """ 42 | self.__width = width 43 | self.__height = height 44 | 45 | self.__resize_target = resize_target 46 | self.__keep_aspect_ratio = keep_aspect_ratio 47 | self.__multiple_of = ensure_multiple_of 48 | self.__resize_method = resize_method 49 | self.__image_interpolation_method = image_interpolation_method 50 | 51 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None): 52 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) 53 | 54 | if max_val is not None and y > max_val: 55 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) 56 | 57 | if y < min_val: 58 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) 59 | 60 | return y 61 | 62 | def get_size(self, width, height): 63 | # determine new height and width 64 | scale_height = self.__height / height 65 | scale_width = self.__width / width 66 | 67 | if self.__keep_aspect_ratio: 68 | if self.__resize_method == "lower_bound": 69 | # scale such that output size is lower bound 70 | if scale_width > scale_height: 71 | # fit width 72 | scale_height = scale_width 73 | else: 74 | # fit height 75 | scale_width = scale_height 76 | elif self.__resize_method == "upper_bound": 77 | # scale such that output size is upper bound 78 | if scale_width < scale_height: 79 | # fit width 80 | scale_height = scale_width 81 | else: 82 | # fit height 83 | scale_width = scale_height 84 | elif self.__resize_method == "minimal": 85 | # scale as least as possbile 86 | if abs(1 - scale_width) < abs(1 - scale_height): 87 | # fit width 88 | scale_height = scale_width 89 | else: 90 | # fit height 91 | scale_width = scale_height 92 | else: 93 | raise ValueError(f"resize_method {self.__resize_method} not implemented") 94 | 95 | if self.__resize_method == "lower_bound": 96 | new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height) 97 | new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width) 98 | elif self.__resize_method == "upper_bound": 99 | new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height) 100 | new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width) 101 | elif self.__resize_method == "minimal": 102 | new_height = self.constrain_to_multiple_of(scale_height * height) 103 | new_width = self.constrain_to_multiple_of(scale_width * width) 104 | else: 105 | raise ValueError(f"resize_method {self.__resize_method} not implemented") 106 | 107 | return (new_width, new_height) 108 | 109 | def __call__(self, sample): 110 | width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0]) 111 | 112 | # resize sample 113 | sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method) 114 | 115 | if self.__resize_target: 116 | if "depth" in sample: 117 | sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST) 118 | 119 | if "mask" in sample: 120 | sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST) 121 | 122 | return sample 123 | 124 | 125 | class NormalizeImage(object): 126 | """Normlize image by given mean and std. 
127 | """ 128 | 129 | def __init__(self, mean, std): 130 | self.__mean = mean 131 | self.__std = std 132 | 133 | def __call__(self, sample): 134 | sample["image"] = (sample["image"] - self.__mean) / self.__std 135 | 136 | return sample 137 | 138 | 139 | class PrepareForNet(object): 140 | """Prepare sample for usage as network input. 141 | """ 142 | 143 | def __init__(self): 144 | pass 145 | 146 | def __call__(self, sample): 147 | image = np.transpose(sample["image"], (2, 0, 1)) 148 | sample["image"] = np.ascontiguousarray(image).astype(np.float32) 149 | 150 | if "depth" in sample: 151 | depth = sample["depth"].astype(np.float32) 152 | sample["depth"] = np.ascontiguousarray(depth) 153 | 154 | if "mask" in sample: 155 | sample["mask"] = sample["mask"].astype(np.float32) 156 | sample["mask"] = np.ascontiguousarray(sample["mask"]) 157 | 158 | return sample -------------------------------------------------------------------------------- /depth/depth.py: -------------------------------------------------------------------------------- 1 | ### original method & code was by https://twitter.com/deKxi 2 | 3 | import logging 4 | logging.getLogger('xformers').setLevel(logging.ERROR) # shutup triton, before torch! 5 | 6 | import os 7 | import sys 8 | import cv2 9 | from imageio import imsave 10 | import numpy as np 11 | import PIL 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torchvision import transforms as T 16 | 17 | from aphantasia.utils import triangle_blur 18 | from .any2.dpt import DepthAnythingV2 19 | 20 | class InferDepthAny: 21 | def __init__(self, modtype='B', device=torch.device('cuda')): 22 | modtype = 'Large' if modtype[0].lower()=='l' else 'Small' if modtype[0].lower()=='s' else 'Base' 23 | from transformers import AutoModelForDepthEstimation 24 | model = AutoModelForDepthEstimation.from_pretrained("depth-anything/Depth-Anything-V2-%s-hf" % modtype) 25 | self.model = model.cuda().eval() 26 | 27 | @torch.no_grad() 28 | def __call__(self, image): 29 | image = T.functional.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 30 | depth = self.model(pixel_values=image).predicted_depth.unsqueeze(0) 31 | return (depth - depth.min()) / (depth.max() - depth.min()) 32 | 33 | def save_img(img, fname=None): 34 | if fname is not None: 35 | img = np.array(img)[:,:,:] 36 | img = np.transpose(img, (1,2,0)) 37 | img = np.clip(img*255, 0, 255).astype(np.uint8) 38 | if img.shape[-1]==1: img = img[:,:,[0,0,0]] 39 | imsave(fname, np.array(img)) 40 | 41 | def resize(img, size): 42 | return F.interpolate(img, size, mode='bicubic', align_corners=True).float().cuda() 43 | 44 | def grid_warp(img, dtensor, H, W, strength, centre, midpoint, dlens=0.05): 45 | # Building the coordinates 46 | xx = torch.linspace(-1, 1, W) 47 | yy = torch.linspace(-1, 1, H) 48 | gy, gx = torch.meshgrid(yy, xx) 49 | 50 | # Apply depth warp 51 | grid = torch.stack([gx, gy], dim=-1).cuda() 52 | d = centre - grid 53 | d_sum = dtensor[0] 54 | # Adjust midpoint / move direction 55 | d_sum = d_sum - torch.max(d_sum) * midpoint 56 | grid_warped = grid + d * d_sum.unsqueeze(-1) * strength 57 | img = F.grid_sample(img, grid_warped.unsqueeze(0).float(), mode='bilinear', align_corners=True, padding_mode='reflection') 58 | 59 | # Apply simple lens distortion to stretch periphery (instead of sphere wrap) 60 | lens_distortion = torch.sqrt((d**2).sum(axis=-1)).cuda() 61 | grid_warped = grid + d * lens_distortion.unsqueeze(-1) * strength * dlens 62 | img = F.grid_sample(img, grid_warped.unsqueeze(0).float(), 
mode='bilinear', align_corners=True, padding_mode='reflection') 63 | 64 | return img 65 | 66 | def depthwarp(img_t, img, infer_any, strength=0, centre=[0,0], midpoint=0.5, save_path=None, save_num=0, dlens=0.05): 67 | _, _, H, W = img.shape # [1,3,720,1280] [0..1] 68 | 69 | res = 518 # 518 on lower dimension for DepthAny 70 | dim = [res, int(res*W/H)] if H < W else [int(res*H/W), res] 71 | dim = [x - x % 14 for x in dim] 72 | 73 | image = resize(torch.lerp(img, triangle_blur(img, 5, 2), 0.5), dim) # [1,3,518,910] [0..1] 74 | depth = infer_any(image) # [1,1,h,w] 75 | depth = depth * torch.flip(infer_any(torch.flip(image, [-1])), [-1]) # enhance depth with mirrored estimation 76 | depth = resize(depth, (H,W)) # [1,1,H,W] 77 | 78 | if save_path is not None: # Save depth map out, currently its as its own image but it could just be added as an alpha channel to main image 79 | out_depth = depth.detach().clone().cpu().squeeze(0) 80 | save_img(out_depth, os.path.join(save_path, '%05d.jpg' % save_num)) 81 | 82 | img = grid_warp(img_t, depth.squeeze(0), H, W, strength, torch.as_tensor(centre).cuda(), midpoint, dlens) 83 | 84 | return img 85 | 86 | -------------------------------------------------------------------------------- /illustra.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import os 3 | import time 4 | import warnings 5 | warnings.filterwarnings("ignore") 6 | import argparse 7 | import numpy as np 8 | import random 9 | import shutil 10 | 11 | import torch 12 | import torchvision 13 | import torch.nn.functional as F 14 | 15 | import clip 16 | os.environ['KMP_DUPLICATE_LIB_OK']='True' 17 | 18 | from aphantasia.image import to_valid_rgb, fft_image 19 | from aphantasia.utils import slice_imgs, derivat, checkout, basename, file_list, img_list, img_read, txt_clean, old_torch, save_cfg, sim_func, aesthetic_model 20 | from aphantasia import transforms 21 | try: # progress bar for notebooks 22 | get_ipython().__class__.__name__ 23 | from aphantasia.progress_bar import ProgressIPy as ProgressBar 24 | except: # normal console 25 | from aphantasia.progress_bar import ProgressBar 26 | 27 | clip_models = ['ViT-B/16', 'ViT-B/32', 'ViT-L/14', 'ViT-L/14@336px', 'RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101'] 28 | 29 | def get_args(): 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('-s', '--size', default='1280-720', help='Output resolution') 32 | parser.add_argument('-t', '--in_txt', default=None, help='input text or file - main topic') 33 | parser.add_argument('-t2', '--in_txt2', default=None, help='input text or file - style') 34 | parser.add_argument('-im', '--in_img', default=None, help='input image or directory with images') 35 | parser.add_argument('-r', '--resume', default=None, help='Resume from saved params') 36 | parser.add_argument( '--out_dir', default='_out/fft') 37 | parser.add_argument( '--save_step', default=1, type=int, help='Save every this step') 38 | parser.add_argument('-tr', '--translate', action='store_true', help='Translate with Google Translate') 39 | parser.add_argument('-v', '--verbose', dest='verbose', action='store_true') 40 | parser.add_argument('-nv', '--no-verbose', dest='verbose', action='store_false') 41 | parser.set_defaults(verbose=True) 42 | # training 43 | parser.add_argument('-m', '--model', default='ViT-B/32', choices=clip_models, help='Select CLIP model to use') 44 | parser.add_argument( '--steps', default=150, type=int, help='Iterations per input') 45 | parser.add_argument( '--samples', 
default=200, type=int, help='Samples to evaluate') 46 | parser.add_argument('-lr', '--lrate', default=0.05, type=float, help='Learning rate') 47 | parser.add_argument('-dm', '--dualmod', default=None, type=int, help='Every this step use another CLIP ViT model') 48 | # tweaks 49 | parser.add_argument('-opt', '--optimr', default='adam', choices=['adam', 'adamw'], help='Optimizer') 50 | parser.add_argument('-a', '--align', default='uniform', choices=['central', 'uniform', 'overscan', 'overmax'], help='Sampling distribution') 51 | parser.add_argument('-tf', '--transform', default='fast', choices=['none', 'custom', 'fast', 'elastic'], help='augmenting transforms') 52 | parser.add_argument( '--aest', default=1., type=float) 53 | parser.add_argument( '--contrast', default=1.1, type=float) 54 | parser.add_argument( '--colors', default=1.8, type=float) 55 | parser.add_argument('-d', '--decay', default=1.5, type=float) 56 | parser.add_argument('-sh', '--sharp', default=0, type=float) 57 | parser.add_argument('-mc', '--macro', default=0.4, type=float, help='Endorse macro forms 0..1 ') 58 | parser.add_argument('-e', '--enforce', default=0, type=float, help='Enhance consistency, boosts training') 59 | parser.add_argument('-n', '--noise', default=0, type=float, help='Add noise to decrease accumulation') 60 | parser.add_argument( '--sim', default='mix', help='Similarity function (dot/angular/spherical/mixed; None = cossim)') 61 | parser.add_argument( '--loop', action='store_true', help='Loop inputs [or keep the last one]') 62 | parser.add_argument( '--save_pt', action='store_true', help='save fft snapshots to pt file') 63 | # multi input 64 | parser.add_argument('-l', '--length', default=None, type=int, help='Override total length in sec') 65 | parser.add_argument( '--lsteps', default=25, type=int, help='Frames per step') 66 | parser.add_argument( '--fps', default=25, type=int) 67 | parser.add_argument( '--keep', default=1.5, type=float, help='Accumulate imagery: 0 random, 1+ ~prev') 68 | parser.add_argument( '--separate', action='store_true', help='process inputs separately') 69 | a = parser.parse_args() 70 | 71 | if a.size is not None: a.size = [int(s) for s in a.size.split('-')][::-1] 72 | if len(a.size)==1: a.size = a.size * 2 73 | if not a.separate: a.save_pt = True 74 | if a.dualmod is not None: 75 | a.model = 'ViT-B/32' 76 | a.sim = 'cossim' 77 | 78 | return a 79 | 80 | a = get_args() 81 | 82 | if a.translate is True: 83 | try: 84 | from googletrans import Translator 85 | except ImportError as e: 86 | print('\n Install googletrans module to enable translation!'); exit() 87 | 88 | def main(): 89 | bx = 1. 
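    # bx is a sample-budget multiplier: it is scaled down below for heavier CLIP models (see xmem), the dual-model mode,
    # augmenting transforms and the 'enforce' option, then applied as a.samples = int(bx * a.samples)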
90 | 91 | model_clip, _ = clip.load(a.model, jit=old_torch()) 92 | try: 93 | a.modsize = model_clip.visual.input_resolution 94 | except: 95 | a.modsize = 288 if a.model == 'RN50x4' else 384 if a.model == 'RN50x16' else 448 if a.model == 'RN50x64' else 336 if '336' in a.model else 224 96 | model_clip = model_clip.eval().cuda() 97 | xmem = {'ViT-B/16':0.25, 'ViT-L/14':0.04, 'RN50':0.5, 'RN50x4':0.16, 'RN50x16':0.06, 'RN50x64':0.01, 'RN101':0.33} 98 | if a.model in xmem.keys(): 99 | bx *= xmem[a.model] 100 | 101 | if a.dualmod is not None: 102 | model_clip2, _ = clip.load('ViT-B/16', jit=old_torch()) 103 | bx *= 0.23 # second is vit-16 104 | dualmod_nums = list(range(a.steps))[a.dualmod::a.dualmod] 105 | print(' dual model every %d step' % a.dualmod) 106 | 107 | if a.aest != 0 and a.model in ['ViT-B/32', 'ViT-B/16', 'ViT-L/14']: 108 | aest = aesthetic_model(a.model).cuda() 109 | if a.dualmod is not None: 110 | aest2 = aesthetic_model('ViT-B/16').cuda() 111 | 112 | if 'elastic' in a.transform: 113 | trform_f = transforms.transforms_elastic 114 | elif 'custom' in a.transform: 115 | trform_f = transforms.transforms_custom 116 | elif 'fast' in a.transform: 117 | trform_f = transforms.transforms_fast 118 | else: 119 | trform_f = transforms.normalize() 120 | bx *= 1.05 121 | bx *= 0.95 122 | if a.enforce != 0: 123 | bx *= 0.5 124 | a.samples = int(bx * a.samples) 125 | 126 | if a.translate: 127 | translator = Translator() 128 | 129 | def enc_text(txt, model_clip=model_clip): 130 | if txt is None or len(txt)==0: return None 131 | embs = [] 132 | for subtxt in txt.split('|'): 133 | if ':' in subtxt: 134 | [subtxt, wt] = subtxt.split(':') 135 | wt = float(wt) 136 | else: wt = 1. 137 | emb = model_clip.encode_text(clip.tokenize(subtxt).cuda()[:77]) 138 | # emb = emb / emb.norm(dim=-1, keepdim=True) 139 | embs.append([emb.detach().clone(), wt]) 140 | return embs 141 | 142 | def enc_image(img, model_clip=model_clip): 143 | emb = model_clip.encode_image(img) 144 | # emb = emb / emb.norm(dim=-1, keepdim=True) 145 | return emb 146 | 147 | def proc_image(img_file, model_clip=model_clip): 148 | img_t = torch.from_numpy(img_read(img_file)/255.).unsqueeze(0).permute(0,3,1,2).cuda()[:,:3,:,:] 149 | in_sliced = slice_imgs([img_t], a.samples, a.modsize, transforms.normalize(), a.align)[0] 150 | emb = enc_image(in_sliced, model_clip) 151 | return emb.detach().clone() 152 | 153 | def pick_(list_, num_): 154 | cnt = len(list_) 155 | if cnt == 0: return None 156 | num = num_ % cnt if a.loop is True else min(num_, cnt-1) 157 | return list_[num] 158 | 159 | def read_text(in_txt): 160 | if os.path.isfile(in_txt): 161 | with open(in_txt, 'r', encoding="utf-8") as f: 162 | lines = f.read().splitlines() 163 | texts = [] 164 | for tt in lines: 165 | if len(tt.strip()) == 0: texts.append('') 166 | elif tt.strip()[0] != '#': texts.append(tt.strip()) 167 | else: 168 | texts = [in_txt] 169 | return texts 170 | 171 | # Encode inputs 172 | count = 0 173 | texts = [] 174 | styles = [] 175 | img_paths = [] 176 | 177 | if a.in_img is not None and os.path.exists(a.in_img): 178 | if a.verbose is True: print(' ref image:', basename(a.in_img)) 179 | img_paths = img_list(a.in_img) if os.path.isdir(a.in_img) else [a.in_img] 180 | img_encs = [proc_image(image) for image in img_paths] 181 | if a.dualmod is not None: 182 | img_encs2 = [proc_image(image, model_clip2) for image in img_paths] 183 | count = max(count, len(img_encs)) 184 | 185 | if a.in_txt is not None: 186 | if a.verbose is True: print(' topic:', a.in_txt) 187 | texts = 
read_text(a.in_txt) 188 | if a.translate: 189 | texts = [translator.translate(txt, dest='en').text for txt in texts] 190 | # if a.verbose is True: print(' translated to:', texts) 191 | txt_encs = [enc_text(txt) for txt in texts] 192 | if a.dualmod is not None: 193 | txt_encs2 = [enc_text(txt, model_clip2) for txt in texts] 194 | count = max(count, len(txt_encs)) 195 | 196 | if a.in_txt2 is not None: 197 | if a.verbose is True: print(' style:', a.in_txt2) 198 | styles = read_text(a.in_txt2) 199 | if a.translate is True: 200 | styles = [tr.text for tr in translator.translate(styles)] 201 | # if a.verbose is True: print(' translated to:', styles) 202 | styl_encs = [enc_text(style) for style in styles] 203 | if a.dualmod is not None: 204 | styl_encs2 = [enc_text(style, model_clip2) for style in styles] 205 | count = max(count, len(styl_encs)) 206 | 207 | assert count > 0, "No inputs found!" 208 | 209 | if a.verbose is True: print(' samples:', a.samples) 210 | sfx = '' 211 | if a.dualmod is None: sfx += '-%s' % a.model.replace('/','').replace('-','') 212 | if a.enforce != 0: sfx += '-e%.2g' % a.enforce 213 | # if a.noise > 0: sfx += '-n%.2g' % a.noise 214 | # if a.aest != 0: sfx += '-ae%.2g' % a.aest 215 | 216 | def train(num, i): 217 | loss = 0 218 | noise = a.noise * (torch.rand(1, 1, *params[0].shape[2:4], 1)-0.5).cuda() if a.noise > 0 else None 219 | img_out = image_f(noise) 220 | img_sliced = slice_imgs([img_out], a.samples, a.modsize, trform_f, a.align, a.macro)[0] 221 | 222 | if a.in_txt is not None: 223 | txt_enc = pick_(txt_encs2, num) if a.dualmod is not None and i in dualmod_nums else pick_(txt_encs, num) 224 | if a.in_txt2 is not None: 225 | style_enc = pick_(styl_encs2, num) if a.dualmod is not None and i in dualmod_nums else pick_(styl_encs, num) 226 | if a.in_img is not None and os.path.isfile(a.in_img): 227 | img_enc = pick_(img_encs2, num) if a.dualmod is not None and i in dualmod_nums else pick_(img_encs, num) 228 | model_clip_ = model_clip2 if a.dualmod is not None and i in dualmod_nums else model_clip 229 | if a.aest != 0: 230 | aest_ = aest2 if a.dualmod is not None and i in dualmod_nums else aest 231 | 232 | out_enc = model_clip_.encode_image(img_sliced) 233 | if a.aest != 0 and aest_ is not None: 234 | loss -= 0.001 * a.aest * aest_(out_enc).mean() 235 | if a.in_txt is not None and txt_enc is not None: # input text - main topic 236 | for enc, wt in txt_enc: 237 | loss -= wt * sim_func(enc, out_enc, a.sim) 238 | if a.in_txt2 is not None and style_enc is not None: # input text - style 239 | for enc, wt in style_enc: 240 | loss -= wt * sim_func(enc, out_enc, a.sim) 241 | if a.in_img is not None and img_enc is not None: # input image 242 | loss -= sim_func(img_enc[:len(out_enc)], out_enc, a.sim) 243 | if a.sharp != 0: # scharr|sobel|naiv 244 | loss -= a.sharp * derivat(img_out, mode='naiv') 245 | if a.enforce != 0: 246 | img_sliced = slice_imgs([image_f(noise)], a.samples, a.modsize, trform_f, a.align, a.macro)[0] 247 | out_enc2 = model_clip_.encode_image(img_sliced) 248 | loss -= a.enforce * sim_func(out_enc, out_enc2, a.sim) 249 | del out_enc2 # torch.cuda.empty_cache() 250 | 251 | del img_out, img_sliced, out_enc 252 | assert not isinstance(loss, int), ' Loss not defined, check inputs' 253 | 254 | optimizer.zero_grad() 255 | loss.backward() 256 | optimizer.step() 257 | 258 | if i % a.save_step == 0: 259 | with torch.no_grad(): 260 | img = image_f(contrast=a.contrast).cpu().numpy()[0] 261 | checkout(img, os.path.join(tempdir, '%04d.jpg' % (i // a.save_step)), 
verbose=a.verbose) 262 | pbar.upd() 263 | del img 264 | 265 | 266 | try: 267 | for num in range(count): 268 | shape = [1, 3, *a.size] 269 | global params 270 | 271 | if num == 0 or a.separate is True: 272 | resume_cur = a.resume 273 | else: 274 | opt_state = optimizer.state_dict() 275 | param_ = params[0].detach() 276 | resume_cur = [a.keep * param_ / (param_.max() - param_.min())] 277 | 278 | params, image_f, sz = fft_image(shape, 0.08, a.decay, resume_cur) 279 | if sz is not None: a.size = sz 280 | image_f = to_valid_rgb(image_f, colors = a.colors) 281 | 282 | if a.optimr.lower() == 'adamw': 283 | optimizer = torch.optim.AdamW(params, a.lrate, weight_decay=0.01, betas=(.0,.999), amsgrad=True) 284 | else: 285 | optimizer = torch.optim.Adam(params, a.lrate, betas=(.0, .999)) 286 | if num > 0 and not a.separate: optimizer.load_state_dict(opt_state) 287 | 288 | out_names = [] 289 | if a.resume is not None and num == 0: out_names += [basename(a.resume)[:12]] 290 | if a.in_txt is not None: out_names += [txt_clean(pick_(texts, num))[:32]] 291 | if a.in_txt2 is not None: out_names += [txt_clean(pick_(styles, num))[:32]] 292 | out_name = '-'.join(out_names) + sfx 293 | if count > 1: out_name = '%04d-' % (num+1) + out_name 294 | print(out_name) 295 | workdir = a.out_dir 296 | tempdir = os.path.join(workdir, out_name) 297 | os.makedirs(tempdir, exist_ok=True) 298 | if num == 0: save_cfg(a, workdir, out_name + '.txt') 299 | 300 | pbar = ProgressBar(a.steps // a.save_step) 301 | for i in range(a.steps): 302 | train(num, i) 303 | 304 | file_out = os.path.join(workdir, '%s-%d.jpg' % (out_name, a.steps)) 305 | shutil.copy(img_list(tempdir)[-1], file_out) 306 | os.system('ffmpeg -v warning -y -i %s\%%04d.jpg "%s.mp4"' % (tempdir, os.path.join(workdir, out_name))) 307 | if a.save_pt is True: 308 | torch.save(params[0], '%s.pt' % os.path.join(workdir, out_name)) 309 | 310 | except KeyboardInterrupt: 311 | exit() 312 | 313 | if not a.separate: 314 | vsteps = a.lsteps if a.length is None else int(a.length * a.fps / count) 315 | tempdir = os.path.join(workdir, '_final') 316 | os.makedirs(tempdir, exist_ok=True) 317 | 318 | def read_pt(file): 319 | return torch.load(file).cuda() 320 | 321 | if a.verbose is True: print(' rendering complete piece') 322 | ptfiles = file_list(workdir, 'pt') 323 | pbar = ProgressBar(vsteps * len(ptfiles)) 324 | for px in range(len(ptfiles)): 325 | params1 = read_pt(ptfiles[px]) 326 | params2 = read_pt(ptfiles[(px+1) % len(ptfiles)]) 327 | 328 | params, image_f, sz_ = fft_image([1, 3, *a.size], resume=params1, sd=1., decay_power=a.decay) 329 | image_f = to_valid_rgb(image_f, colors = a.colors) 330 | 331 | for i in range(vsteps): 332 | with torch.no_grad(): 333 | x = i/vsteps # math.sin(1.5708 * i/vsteps) 334 | img = image_f((params2 - params1) * x, contrast=a.contrast).cpu().numpy()[0] 335 | checkout(img, os.path.join(tempdir, '%05d.jpg' % (px * vsteps + i)), verbose=a.verbose) 336 | pbar.upd() 337 | 338 | os.system('ffmpeg -v warning -y -i %s/\%%05d.jpg "%s.mp4"' % (tempdir, os.path.join(a.out_dir, basename(a.in_txt)))) 339 | 340 | 341 | if __name__ == '__main__': 342 | main() 343 | -------------------------------------------------------------------------------- /illustrip.py: -------------------------------------------------------------------------------- 1 | # coding: UTF-8 2 | import os 3 | os.environ['TF_CPP_MIN_LOG_LEVEL']='2' 4 | import warnings 5 | warnings.filterwarnings("ignore") 6 | import argparse 7 | import numpy as np 8 | import shutil 9 | import PIL 10 | import time 
11 | from imageio import imread, imsave 12 | 13 | try: 14 | from googletrans import Translator 15 | googletrans_ok = True 16 | except: 17 | googletrans_ok = False 18 | 19 | import torch 20 | import torchvision 21 | import torch.nn.functional as F 22 | from torchvision import transforms as T 23 | 24 | import clip 25 | os.environ['KMP_DUPLICATE_LIB_OK']='True' 26 | 27 | from aphantasia.image import to_valid_rgb, fft_image, resume_fft, pixel_image 28 | from aphantasia.utils import slice_imgs, derivat, sim_func, aesthetic_model, intrl, slerp, basename, file_list, img_list, img_read, pad_up_to, txt_clean, latent_anima, cvshow, checkout, save_cfg, old_torch 29 | from aphantasia import transforms 30 | from depth import depth 31 | try: # progress bar for notebooks 32 | get_ipython().__class__.__name__ 33 | from aphantasia.progress_bar import ProgressIPy as ProgressBar 34 | except: # normal console 35 | from aphantasia.progress_bar import ProgressBar 36 | 37 | clip_models = ['ViT-B/16', 'ViT-B/32', 'RN50', 'RN50x4', 'RN50x16', 'RN101'] 38 | 39 | def get_args(): 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument('-s', '--size', default='1280-720', help='Output resolution') 42 | parser.add_argument('-t', '--in_txt', default=None, help='Text string or file to process (main topic)') 43 | parser.add_argument('-pre', '--in_txt_pre', default=None, help='Prefix for input text') 44 | parser.add_argument('-post', '--in_txt_post', default=None, help='Postfix for input text') 45 | parser.add_argument('-t2', '--in_txt2', default=None, help='Text string or file to process (style)') 46 | parser.add_argument('-t0', '--in_txt0', default=None, help='input text to subtract') 47 | parser.add_argument('-im', '--in_img', default=None, help='input image or directory with images') 48 | parser.add_argument('-wi', '--weight_img', default=0.5, type=float, help='weight for images') 49 | parser.add_argument('-r', '--resume', default=None, help='Resume from saved params or from an image') 50 | parser.add_argument( '--out_dir', default='_out') 51 | parser.add_argument('-tr', '--translate', action='store_true', help='Translate with Google Translate') 52 | parser.add_argument( '--invert', action='store_true', help='Invert criteria') 53 | parser.add_argument('-v', '--verbose', dest='verbose', action='store_true') 54 | parser.add_argument('-nv', '--no-verbose', dest='verbose', action='store_false') 55 | parser.set_defaults(verbose=True) 56 | # training 57 | parser.add_argument( '--gen', default='RGB', help='Generation (optimization) method: FFT or RGB') 58 | parser.add_argument('-m', '--model', default='ViT-B/32', choices=clip_models, help='Select CLIP model to use') 59 | parser.add_argument( '--steps', default=300, type=int, help='Iterations (frames) per scene (text line)') 60 | parser.add_argument( '--samples', default=100, type=int, help='Samples to evaluate per frame') 61 | parser.add_argument('-lr', '--lrate', default=0.1, type=float, help='Learning rate') 62 | parser.add_argument('-dm', '--dualmod', default=None, type=int, help='Every this step use another CLIP ViT model') 63 | # motion 64 | parser.add_argument('-ops', '--opt_step', default=1, type=int, help='How many optimizing steps per save/transform step') 65 | parser.add_argument('-sm', '--smooth', action='store_true', help='Smoothen interframe jittering for FFT method') 66 | parser.add_argument('-it', '--interpol', default=True, help='Interpolate topics? 
(or change by cut)') 67 | parser.add_argument( '--fstep', default=100, type=int, help='How many frames before changing motion') 68 | parser.add_argument( '--scale', default=0.012, type=float) 69 | parser.add_argument( '--shift', default=10., type=float, help='in pixels') 70 | parser.add_argument( '--angle', default=0.8, type=float, help='in degrees') 71 | parser.add_argument( '--shear', default=0.4, type=float) 72 | parser.add_argument( '--anima', default=True, help='Animate motion') 73 | # depth 74 | parser.add_argument('-d', '--depth', default=0, type=float, help='Add depth with such strength, if > 0') 75 | parser.add_argument( '--depth_model', default='b', help='Depth Anything model: large, base or small') 76 | parser.add_argument( '--depth_dir', default=None, help='Directory to save depth, if not None') 77 | # tweaks 78 | parser.add_argument('-a', '--align', default='overscan', choices=['central', 'uniform', 'overscan', 'overmax'], help='Sampling distribution') 79 | parser.add_argument('-tf', '--transform', default='fast', choices=['none', 'fast', 'custom', 'elastic'], help='augmenting transforms') 80 | parser.add_argument('-opt', '--optimizer', default='adam_custom', choices=['adam', 'adam_custom', 'adamw', 'adamw_custom'], help='Optimizer') 81 | parser.add_argument( '--fixcontrast', action='store_true', help='Required for proper resuming from image') 82 | parser.add_argument( '--contrast', default=1.2, type=float) 83 | parser.add_argument( '--colors', default=2.3, type=float) 84 | parser.add_argument('-sh', '--sharp', default=0, type=float) 85 | parser.add_argument('-mc', '--macro', default=0.3, type=float, help='Endorse macro forms 0..1 ') 86 | parser.add_argument( '--aest', default=0., type=float, help='Enhance aesthetics') 87 | parser.add_argument('-e', '--enforce', default=0, type=float, help='Enforce details (by boosting similarity between two parallel samples)') 88 | parser.add_argument('-x', '--expand', default=0, type=float, help='Boosts diversity (by enforcing difference between prev/next samples)') 89 | parser.add_argument('-n', '--noise', default=2., type=float, help='Add noise to make composition sparse (FFT only)') # 0.04 90 | parser.add_argument( '--sim', default='mix', help='Similarity function (angular/spherical/mixed; None = cossim)') 91 | parser.add_argument( '--rem', default=None, help='Dummy text to add to project name') 92 | a = parser.parse_args() 93 | 94 | if a.size is not None: a.size = [int(s) for s in a.size.split('-')][::-1] 95 | if len(a.size)==1: a.size = a.size * 2 96 | a.gen = a.gen.upper() 97 | a.invert = -1. if a.invert is True else 1. 98 | 99 | # Overriding some parameters, depending on other settings 100 | if a.gen == 'RGB': 101 | a.smooth = False 102 | a.align = 'overscan' 103 | if a.resume is not None: a.fixcontrast = True 104 | if a.model == 'ViT-B/16': a.sim = 'cossim' 105 | 106 | if a.translate is True and googletrans_ok is not True: 107 | print('\n Install googletrans module to enable translation!'); exit() 108 | 109 | if a.dualmod is not None: 110 | a.model = 'ViT-B/32' 111 | a.sim = 'cossim' 112 | 113 | return a 114 | 115 | def depth_transform(img_t, _deptha, depthX=0, scale=1., shift=[0,0], colors=1, depth_dir=None, save_num=0): 116 | if not isinstance(depthX, float): depthX = float(depthX) 117 | if not isinstance(scale, float): scale = float(scale[0]) 118 | size = img_t.shape[-2:] 119 | # d X/Y define the origin point of the depth warp, effectively a "3D pan zoom", [-1..1] 120 | # plus = look ahead, minus = look aside 121 | dX = 100. 
* shift[0] / size[1] 122 | dY = 100. * shift[1] / size[0] 123 | # dZ = movement direction: 1 away (zoom out), 0 towards (zoom in), 0.5 stay 124 | dZ = 0.5 + 32. * (scale-1) 125 | def ttt(x): return x 126 | img = to_valid_rgb(ttt, colors = colors)(img_t) 127 | img = depth.depthwarp(img_t, img, _deptha, depthX, [dX,dY], dZ, save_path=depth_dir, save_num=save_num) 128 | return img 129 | 130 | def frame_transform(img, size, angle, shift, scale, shear): 131 | if old_torch(): # 1.7.1 132 | img = T.functional.affine(img, angle, tuple(shift), scale, shear, fillcolor=0, resample=PIL.Image.BILINEAR) 133 | img = T.functional.center_crop(img, size) 134 | img = pad_up_to(img, size) 135 | else: # 1.8+ 136 | img = T.functional.affine(img, angle, tuple(shift), scale, shear, fill=0, interpolation=T.InterpolationMode.BILINEAR) 137 | img = T.functional.center_crop(img, size) # on 1.8+ also pads 138 | return img 139 | 140 | def main(): 141 | a = get_args() 142 | 143 | # Load CLIP models 144 | model_clip, _ = clip.load(a.model, jit=old_torch()) 145 | try: 146 | a.modsize = model_clip.visual.input_resolution 147 | except: 148 | a.modsize = 288 if a.model == 'RN50x4' else 384 if a.model == 'RN50x16' else 224 149 | if a.verbose is True: print(' using model', a.model) 150 | xmem = {'ViT-B/16':0.25, 'RN50':0.5, 'RN50x4':0.16, 'RN50x16':0.06, 'RN101':0.33} 151 | if a.model in xmem.keys(): 152 | a.samples = int(a.samples * xmem[a.model]) 153 | 154 | if a.translate: 155 | translator = Translator() 156 | 157 | if a.dualmod is not None: # second is vit-16 158 | model_clip2, _ = clip.load('ViT-B/16', jit=old_torch()) 159 | a.samples = int(a.samples * 0.23) 160 | dualmod_nums = list(range(a.steps))[a.dualmod::a.dualmod] 161 | print(' dual model every %d step' % a.dualmod) 162 | 163 | if a.aest != 0 and a.model in ['ViT-B/32', 'ViT-B/16', 'ViT-L/14']: 164 | aest = aesthetic_model(a.model).cuda() 165 | if a.dualmod is not None: 166 | aest2 = aesthetic_model('ViT-B/16').cuda() 167 | 168 | if a.enforce != 0: 169 | a.samples = int(a.samples * 0.5) 170 | 171 | if 'elastic' in a.transform: 172 | trform_f = transforms.transforms_elastic 173 | a.samples = int(a.samples * 0.95) 174 | elif 'custom' in a.transform: 175 | trform_f = transforms.transforms_custom 176 | a.samples = int(a.samples * 0.95) 177 | elif 'fast' in a.transform: 178 | trform_f = transforms.transforms_fast 179 | a.samples = int(a.samples * 0.95) 180 | else: 181 | trform_f = transforms.normalize() 182 | 183 | def enc_text(txt, model_clip=model_clip): 184 | if txt is None or len(txt)==0: return None 185 | embs = [] 186 | for subtxt in txt.split('|'): 187 | if ':' in subtxt: 188 | [subtxt, wt] = subtxt.split(':') 189 | wt = float(wt) 190 | else: wt = 1. 
191 | emb = model_clip.encode_text(clip.tokenize(subtxt).cuda()[:77]) 192 | embs.append([emb.detach().clone(), wt]) 193 | return embs 194 | 195 | def enc_image(img_file, model_clip=model_clip): 196 | img_t = torch.from_numpy(img_read(img_file)/255.).unsqueeze(0).permute(0,3,1,2).cuda()[:,:3,:,:] 197 | in_sliced = slice_imgs([img_t], a.samples, a.modsize, transforms.normalize(), a.align)[0] 198 | emb = model_clip.encode_image(in_sliced) 199 | return emb.detach().clone() 200 | 201 | def read_text(in_txt): 202 | if os.path.isfile(in_txt): 203 | with open(in_txt, 'r', encoding="utf-8") as f: 204 | lines = f.read().splitlines() 205 | texts = [] 206 | for tt in lines: 207 | if len(tt.strip()) == 0: texts.append('') 208 | elif tt.strip()[0] != '#': texts.append(tt.strip()) 209 | else: 210 | texts = [in_txt] 211 | return texts 212 | 213 | # Encode inputs 214 | count = 0 215 | texts = [] 216 | styles = [] 217 | notexts = [] 218 | images = [] 219 | 220 | if a.in_txt is not None: 221 | texts = read_text(a.in_txt) 222 | if a.in_txt_pre is not None: 223 | pretexts = read_text(a.in_txt_pre) 224 | texts = [' | '.join([pick_(pretexts, n), texts[n]]).strip() for n in range(len(texts))] 225 | if a.in_txt_post is not None: 226 | postexts = read_text(a.in_txt_post) 227 | texts = [' | '.join([texts[n], pick_(postexts, n)]).strip() for n in range(len(texts))] 228 | if a.translate is True: 229 | texts = [tr.text for tr in translator.translate(texts)] 230 | # print(' texts trans', texts) 231 | key_txt_encs = [enc_text(txt) for txt in texts] 232 | if a.dualmod is not None: 233 | key_txt_encs2 = [enc_text(txt, model_clip2) for txt in texts] 234 | count = max(count, len(key_txt_encs)) 235 | 236 | if a.in_txt2 is not None: 237 | styles = read_text(a.in_txt2) 238 | if a.translate is True: 239 | styles = [tr.text for tr in translator.translate(styles)] 240 | # print(' styles trans', styles) 241 | key_styl_encs = [enc_text(style) for style in styles] 242 | if a.dualmod is not None: 243 | key_styl_encs2 = [enc_text(style, model_clip2) for style in styles] 244 | count = max(count, len(key_styl_encs)) 245 | 246 | if a.in_txt0 is not None: 247 | notexts = read_text(a.in_txt0) 248 | if a.translate is True: 249 | notexts = [tr.text for tr in translator.translate(notexts)] 250 | # print(' notexts trans', notexts) 251 | key_not_encs = [enc_text(notext) for notext in notexts] 252 | if a.dualmod is not None: 253 | key_not_encs2 = [enc_text(notext, model_clip2) for notext in notexts] 254 | count = max(count, len(key_not_encs)) 255 | 256 | if a.in_img is not None and os.path.exists(a.in_img): 257 | images = file_list(a.in_img) if os.path.isdir(a.in_img) else [a.in_img] 258 | key_img_encs = [enc_image(image) for image in images] 259 | if a.dualmod is not None: 260 | key_img_encs2 = [proc_image(image, model_clip2) for image in images] 261 | count = max(count, len(key_img_encs)) 262 | 263 | assert count > 0, "No inputs found!" 
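# --- illustrative note, not part of the repository ---
# read_text() above maps a text file to one scene per line (blank lines are kept as empty
# scenes, lines starting with '#' are skipped); enc_text() splits each line on '|' and
# accepts an optional ':weight' suffix per sub-prompt, e.g. "misty forest:1.0 | oil painting:0.5".
# A typical run over such a (hypothetical) scenes.txt, using the flags defined in get_args(),
# might look like:
#   python illustrip.py -t scenes.txt -t2 "oil painting" -s 1280-720 --steps 200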
264 | 265 | if a.verbose is True: print(' samples:', a.samples) 266 | 267 | global params_tmp 268 | shape = [1, 3, *a.size] 269 | 270 | if a.gen == 'RGB': 271 | params_tmp, _, sz = pixel_image(shape, a.resume) 272 | params_tmp = params_tmp[0].cuda().detach() 273 | else: 274 | params_tmp, sz = resume_fft(a.resume, shape, decay=1.5, sd=1) 275 | if sz is not None: a.size = sz 276 | 277 | if a.depth != 0: 278 | _deptha = depth.InferDepthAny(a.depth_model) 279 | if a.depth_dir is not None: 280 | os.makedirs(a.depth_dir, exist_ok=True) 281 | print(' depth dir:', a.depth_dir) 282 | 283 | steps = a.steps 284 | glob_steps = count * steps 285 | if glob_steps == a.fstep: a.fstep = glob_steps // 2 # otherwise no motion 286 | 287 | workname = basename(a.in_txt) if a.in_txt is not None else basename(a.in_img) 288 | workname = txt_clean(workname) 289 | workdir = os.path.join(a.out_dir, workname + '-%s' % a.gen.lower()) 290 | if a.rem is not None: workdir += '-%s' % a.rem 291 | if a.dualmod is not None: workdir += '-dm%d' % a.dualmod 292 | if 'RN' in a.model.upper(): workdir += '-%s' % a.model 293 | tempdir = os.path.join(workdir, 'ttt') 294 | os.makedirs(tempdir, exist_ok=True) 295 | save_cfg(a, workdir) 296 | if a.in_txt is not None and os.path.isfile(a.in_txt): 297 | shutil.copy(a.in_txt, os.path.join(workdir, os.path.basename(a.in_txt))) 298 | if a.in_txt2 is not None and os.path.isfile(a.in_txt2): 299 | shutil.copy(a.in_txt2, os.path.join(workdir, os.path.basename(a.in_txt2))) 300 | 301 | midp = 0.5 302 | if a.anima: 303 | if a.gen == 'RGB': # zoom in 304 | m_scale = latent_anima([1], glob_steps, a.fstep, uniform=True, cubic=True, start_lat=[-0.3], verbose=False) 305 | m_scale = 1 + (m_scale + 0.3) * a.scale 306 | else: 307 | m_scale = latent_anima([1], glob_steps, a.fstep, uniform=True, cubic=True, start_lat=[0.6], verbose=False) 308 | m_scale = 1 - (m_scale-0.6) * a.scale 309 | m_shift = latent_anima([2], glob_steps, a.fstep, uniform=True, cubic=True, start_lat=[midp,midp], verbose=False) 310 | m_angle = latent_anima([1], glob_steps, a.fstep, uniform=True, cubic=True, start_lat=[midp], verbose=False) 311 | m_shear = latent_anima([1], glob_steps, a.fstep, uniform=True, cubic=True, start_lat=[midp], verbose=False) 312 | m_shift = (midp-m_shift) * a.shift * abs(m_scale-1) / a.scale 313 | m_angle = (midp-m_angle) * a.angle * abs(m_scale-1) / a.scale 314 | m_shear = (midp-m_shear) * a.shear * abs(m_scale-1) / a.scale 315 | 316 | def get_encs(encs, num): 317 | cnt = len(encs) 318 | if cnt == 0: return [] 319 | enc_1 = encs[min(num, cnt-1)] 320 | enc_2 = encs[min(num+1, cnt-1)] 321 | if a.interpol is not True: return [enc_1] * steps 322 | enc_pairs = [] 323 | for i in range(steps): 324 | enc1_step = [] 325 | if enc_1 is not None: 326 | if isinstance(enc_1, list): 327 | for enc, wt in enc_1: 328 | enc1_step.append([enc, wt * (steps-i)/steps]) 329 | else: 330 | enc1_step.append(enc_1 * (steps-i)/steps) 331 | enc2_step = [] 332 | if enc_2 is not None: 333 | if isinstance(enc_2, list): 334 | for enc, wt in enc_2: 335 | enc2_step.append([enc, wt * i/steps]) 336 | else: 337 | enc2_step.append(enc_2 * (steps-i)/steps) 338 | enc_pairs.append(enc1_step + enc2_step) 339 | return enc_pairs 340 | 341 | prev_enc = 0 342 | def process(num): 343 | global params_tmp, opt_state, params, image_f, optimizer 344 | 345 | txt_encs = get_encs(key_txt_encs, num) 346 | styl_encs = get_encs(key_styl_encs, num) 347 | not_encs = get_encs(key_not_encs, num) 348 | img_encs = get_encs(key_img_encs, num) 349 | if a.dualmod is not None: 
350 | txt_encs2 = get_encs(key_txt_encs2, num) 351 | styl_encs2 = get_encs(key_styl_encs2, num) 352 | not_encs2 = get_encs(key_not_encs2, num) 353 | img_encs2 = get_encs(key_img_encs2, num) 354 | txt_encs = intrl(txt_encs, txt_encs2, a.dualmod) 355 | styl_encs = intrl(styl_encs, styl_encs2, a.dualmod) 356 | not_encs = intrl(not_encs, not_encs2, a.dualmod) 357 | img_encs = intrl(img_encs, img_encs2, a.dualmod) 358 | del txt_encs2, styl_encs2, not_encs2, img_encs2 359 | 360 | if a.verbose is True: 361 | if len(texts) > 0: print(' ref text: ', texts[min(num, len(texts)-1)][:80]) 362 | if len(styles) > 0: print(' ref style: ', styles[min(num, len(styles)-1)][:80]) 363 | if len(notexts) > 0: print(' ref avoid: ', notexts[min(num, len(notexts)-1)][:80]) 364 | if len(images) > 0: print(' ref image: ', basename(images[min(num, len(images)-1)])[:80]) 365 | 366 | pbar = ProgressBar(steps) 367 | for ii in range(steps): 368 | glob_step = num * steps + ii # save/transform 369 | 370 | txt_enc = txt_encs[ii % len(txt_encs)] if len(txt_encs) > 0 else None 371 | styl_enc = styl_encs[ii % len(styl_encs)] if len(styl_encs) > 0 else None 372 | not_enc = not_encs[ii % len(not_encs)] if len(not_encs) > 0 else None 373 | img_enc = img_encs[ii % len(img_encs)] if len(img_encs) > 0 else None 374 | 375 | model_clip_ = model_clip2 if a.dualmod is not None and ii in dualmod_nums else model_clip 376 | if a.aest != 0: 377 | aest_ = aest2 if a.dualmod is not None and ii in dualmod_nums else aest 378 | 379 | # MOTION: transform frame, reload params 380 | 381 | scale = m_scale[glob_step] if a.anima else 1 + a.scale 382 | shift = m_shift[glob_step] if a.anima else [0, a.shift] 383 | angle = m_angle[glob_step][0] if a.anima else a.angle 384 | shear = m_shear[glob_step][0] if a.anima else a.shear 385 | 386 | if a.gen == 'RGB': 387 | if a.depth > 0: 388 | params_tmp = depth_transform(params_tmp, _deptha, a.depth, scale, shift, a.colors, a.depth_dir, glob_step) 389 | params_tmp = frame_transform(params_tmp, a.size, angle, shift, scale, shear) 390 | params, image_f, _ = pixel_image([1, 3, *a.size], resume=params_tmp) 391 | img_tmp = None 392 | 393 | else: # FFT 394 | if old_torch(): # 1.7.1 395 | img_tmp = torch.irfft(params_tmp, 2, normalized=True, signal_sizes=a.size) 396 | if a.depth > 0: 397 | img_tmp = depth_transform(img_tmp, _deptha, a.depth, scale, shift, a.colors, a.depth_dir, glob_step) 398 | img_tmp = frame_transform(img_tmp, a.size, angle, shift, scale, shear) 399 | params_tmp = torch.rfft(img_tmp, 2, normalized=True) 400 | else: # 1.8+ 401 | if type(params_tmp) is not torch.complex64: 402 | params_tmp = torch.view_as_complex(params_tmp) 403 | img_tmp = torch.fft.irfftn(params_tmp, s=a.size, norm='ortho') 404 | if a.depth > 0: 405 | img_tmp = depth_transform(img_tmp, _deptha, a.depth, scale, shift, a.colors, a.depth_dir, glob_step) 406 | img_tmp = frame_transform(img_tmp, a.size, angle, shift, scale, shear) 407 | params_tmp = torch.fft.rfftn(img_tmp, s=a.size, dim=[2,3], norm='ortho') 408 | params_tmp = torch.view_as_real(params_tmp) 409 | params, image_f, _ = fft_image([1, 3, *a.size], sd=1, resume=params_tmp) 410 | 411 | if a.optimizer.lower() == 'adamw': 412 | optimizer = torch.optim.AdamW(params, a.lrate, weight_decay=0.01) 413 | elif a.optimizer.lower() == 'adamw_custom': 414 | optimizer = torch.optim.AdamW(params, a.lrate, weight_decay=0.01, betas=(.0,.999), amsgrad=True) 415 | elif a.optimizer.lower() == 'adam': 416 | optimizer = torch.optim.Adam(params, a.lrate) 417 | else: # adam_custom 418 | optimizer = 
torch.optim.Adam(params, a.lrate, betas=(.0,.999)) 419 | image_f = to_valid_rgb(image_f, colors = a.colors) 420 | del img_tmp 421 | 422 | if a.smooth is True and num + ii > 0: 423 | optimizer.load_state_dict(opt_state) 424 | 425 | ### optimization 426 | for ss in range(a.opt_step): 427 | loss = 0 428 | 429 | noise = a.noise * (torch.rand(1, 1, a.size[0], a.size[1]//2+1, 1)-0.5).cuda() if a.noise>0 else 0. 430 | img_out = image_f(noise, fixcontrast=a.fixcontrast) 431 | 432 | img_sliced = slice_imgs([img_out], a.samples, a.modsize, trform_f, a.align, a.macro)[0] 433 | out_enc = model_clip_.encode_image(img_sliced) 434 | 435 | if a.aest != 0 and a.model in ['ViT-B/32', 'ViT-B/16', 'ViT-L/14'] and aest_ is not None: 436 | loss -= 0.001 * a.aest * aest_(out_enc).mean() 437 | 438 | if a.gen == 'RGB': # empirical hack 439 | loss += abs(img_out.mean((2,3)) - 0.45).mean() # fix brightness 440 | loss += abs(img_out.std((2,3)) - 0.17).mean() # fix contrast 441 | 442 | if txt_enc is not None: 443 | for enc, wt in txt_enc: 444 | loss -= a.invert * wt * sim_func(enc, out_enc, a.sim) 445 | if styl_enc is not None: 446 | for enc, wt in styl_enc: 447 | loss -= wt * sim_func(enc, out_enc, a.sim) 448 | if not_enc is not None: # subtract text 449 | for enc, wt in not_enc: 450 | loss += wt * sim_func(enc, out_enc, a.sim) 451 | if img_enc is not None: 452 | for enc in img_enc: 453 | loss -= a.weight_img * sim_func(enc, out_enc, a.sim) 454 | if a.sharp != 0: # scharr|sobel|naive 455 | loss -= a.sharp * derivat(img_out, mode='naive') 456 | if a.enforce != 0: 457 | img_sliced = slice_imgs([image_f(noise, fixcontrast=a.fixcontrast)], a.samples, a.modsize, trform_f, a.align, a.macro)[0] 458 | out_enc2 = model_clip_.encode_image(img_sliced) 459 | loss -= a.enforce * sim_func(out_enc, out_enc2, a.sim) 460 | del out_enc2; torch.cuda.empty_cache() 461 | if a.expand > 0: 462 | global prev_enc 463 | if ii > 0: 464 | loss += a.expand * sim_func(prev_enc, out_enc, a.sim) 465 | prev_enc = out_enc.detach().clone() 466 | del img_out, img_sliced, out_enc; torch.cuda.empty_cache() 467 | 468 | optimizer.zero_grad() 469 | loss.backward() 470 | optimizer.step() 471 | 472 | ### save params & frame 473 | 474 | params_tmp = params[0].detach().clone() 475 | if a.smooth is True: 476 | opt_state = optimizer.state_dict() 477 | 478 | with torch.no_grad(): 479 | img_t = image_f(contrast=a.contrast, fixcontrast=a.fixcontrast)[0].permute(1,2,0) 480 | img_np = torch.clip(img_t*255, 0, 255).cpu().numpy().astype(np.uint8) 481 | imsave(os.path.join(tempdir, '%06d.jpg' % glob_step), img_np, quality=95) 482 | if a.verbose is True: cvshow(img_np) 483 | del img_t, img_np 484 | pbar.upd() 485 | 486 | params_tmp = params[0].detach().clone() 487 | 488 | glob_start = time.time() 489 | try: 490 | for i in range(count): 491 | process(i) 492 | except KeyboardInterrupt: 493 | pass 494 | 495 | os.system('ffmpeg -v warning -y -i %s/\%%06d.jpg "%s.mp4"' % (tempdir, os.path.join(workdir, workname))) 496 | 497 | 498 | if __name__ == '__main__': 499 | main() 500 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ftfy 2 | imageio 3 | ipywidgets 4 | regex 5 | tqdm 6 | # googletrans==3.1.0a0 7 | torch>=1.7.1 8 | torchvision>=0.8.2 9 | opencv-python 10 | # sentence_transformers 11 | transformers>=4.6.0 12 | kornia>=0.5.3 13 | lpips 14 | omegaconf>=2.0.0 15 | pytorch-lightning>=1.0.8 16 | einops 17 | PyWavelets>=1.1.1 18 | 
git+https://github.com/fbcotter/pytorch_wavelets 19 | 20 | matplotlib 21 | scipy -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='aphantasia', 5 | version='3.1.0', 6 | description='CLIP + FFT/DWT/RGB text-to-image tools', 7 | url='https://github.com/eps696/aphantasia', 8 | author='vadim epstein', 9 | packages=['aphantasia'], 10 | # packages=find_packages(), 11 | install_requires=[], 12 | classifiers=[], 13 | ) 14 | -------------------------------------------------------------------------------- /shader_expo.py: -------------------------------------------------------------------------------- 1 | # # CPPNs in GLSL 2 | # taken from https://github.com/wxs/cppn-to-glsl 3 | # Original code was for the NIPS Creativity Workshop submission 'Interactive CPPNs in GLSL' 4 | # modified Mordvintsev et al's CPPN notebook from https://github.com/tensorflow/lucid/blob/master/notebooks/differentiable-parameterizations/xy2rgb.ipynb 5 | # https://www.apache.org/licenses/LICENSE-2.0 6 | 7 | import numpy as np 8 | 9 | ### Code to convert to GLSL/HLSL 10 | 11 | def cppn_to_shader(layers, fn_name='cppn_fn', mode='shadertoy', verbose=False, fix_aspect=True, size=[1., 1.], precision=8): 12 | """ 13 | Generate shader code out of the list of dicts defining trained CPPN layers 14 | mode='vvvv': 15 | Exports TextureFX shader file for vvvv 16 | mode='buffer': 17 | Exports txt file with values for dynamicbuffer input in TextureFX shader for vvvv (and optionally shader itself) 18 | mode='td': 19 | Exports code compatible with TouchDesigner: can be dropped into a 'GLSL TOP' 20 | (see https://docs.derivative.ca/GLSL_TOP). TouchDesigner can be found at http://derivative.ca 21 | mode='shadertoy': 22 | Exports code compatible with the ShaderToy editor at http://shadertoy.com 23 | mode='bookofshaders': 24 | Exports code compatible with the Book Of Shaders editor here http://editor.thebookofshaders.com/ 25 | """ 26 | 27 | # Set True to export TFX template for dynamic buffer mode (just once) 28 | export_tfx = False 29 | 30 | # the xy2rgb cppn's internal size is the output of its first layer (pre-activation) 31 | # so will just inspect that to figure it out 32 | n_hidden = layers[0]['weights'].shape[-1] 33 | if n_hidden % 4 != 0: 34 | raise ValueError('Currently only support multiples of 4 for hidden layer size') 35 | modes = {'vvvv', 'buffer', 'td', 'shadertoy', 'bookofshaders'} 36 | if mode not in modes: 37 | raise ValueError('Mode {} not one of the supported modes: {}'.format(mode, modes)) 38 | 39 | if verbose and precision < 8: print(' .. 
precision', precision) 40 | fmt = '%' + '.%df' % precision 41 | 42 | global hlsl; hlsl = None 43 | 44 | if mode == 'buffer': 45 | global sbW; sbW = [] 46 | buffer = True 47 | else: buffer = False 48 | 49 | if mode in ['vvvv', 'buffer']: 50 | hlsl = True 51 | snippet = """ 52 | float2 R:TARGETSIZE; 53 | float4 """ 54 | for i in range(2, len(layers)-2): 55 | snippet += "in%d_, " % i 56 | snippet = snippet[:-2] + ';' 57 | if mode == 'buffer': 58 | snippet += '\nStructuredBuffer sbW;' 59 | snippet += """ 60 | #define mod(x,y) (x - y * floor(x/y)) 61 | #define N_HIDDEN {} 62 | float4 {}(float2 uv) {{ 63 | float4 bufA[N_HIDDEN/4]; 64 | float4 bufB[N_HIDDEN/2]; 65 | float4 tmp; 66 | bufB[0] = float4(uv.x, uv.y, 0., 0.); 67 | """.format(n_hidden, fn_name) 68 | elif mode == 'td': 69 | snippet = """ 70 | uniform float uIn0; 71 | uniform float uIn1; 72 | uniform float uIn2; 73 | uniform float uIn3; 74 | out vec4 fragColor; 75 | """ 76 | elif mode == 'shadertoy': 77 | snippet =""" 78 | #ifdef GL_ES 79 | precision lowp float; 80 | #endif 81 | """ 82 | elif mode == 'bookofshaders': 83 | snippet =""" 84 | #ifdef GL_ES 85 | precision lowp float; 86 | #endif 87 | uniform vec2 u_resolution; 88 | uniform vec2 u_mouse; 89 | uniform float u_time; 90 | """ 91 | 92 | if not mode in ['vvvv', 'buffer']: 93 | snippet += """ 94 | #define N_HIDDEN {} 95 | vec4 bufA[N_HIDDEN/4]; 96 | vec4 bufB[N_HIDDEN/2]; 97 | vec4 {}(vec2 coordinate, float in0, float in1, float in2, float in3) {{ 98 | vec4 tmp; 99 | bufB[0] = vec4(coordinate.x, coordinate.y, 0., 0.); 100 | """.format(n_hidden, fn_name) 101 | 102 | def vec(a): 103 | """Take a Python array of length 4 (or less) and output code for a GLSL vec4 or HLSL float4, possibly zero-padded at the end""" 104 | global hlsl, sbW 105 | if len(a) == 4: 106 | if hlsl is True: 107 | if 'sbW' in globals(): # check if sbW defined (working with structbuffer input instead of values) 108 | for i in range(4): 109 | sbW.append(a[i]) 110 | return 'sbW[%d]' % (len(sbW)//4-1) 111 | # return 'float4({})'.format(', '.join(str(x) for x in a)) 112 | return 'float4({})'.format(', '.join(fmt % x for x in a)) 113 | else: 114 | # return 'vec4({})'.format(', '.join(str(x) for x in a)) 115 | return 'vec4({})'.format(', '.join(fmt % x for x in a)) 116 | else: 117 | assert len(a) < 4 , 'Length must less than 4' 118 | return vec(np.concatenate([a, [0.]*(4-len(a))])) 119 | 120 | def mat(a): 121 | # Take a numpy matrix of 4 rows and 4 or fewer columns, and output GLSL or HLSL code for a mat4, 122 | # possibly with zeros padded in the last columns 123 | if a.shape[0] < 4: 124 | m2 = np.vstack([a, [[0.,0.,0.,0.]] * (4 - a.shape[0])]) 125 | return mat(m2) 126 | assert a.shape[0] == 4, 'Expected a of shape (4,n<=4). Got: {}.'.format(a.shape) 127 | global hlsl 128 | if hlsl is True: 129 | return 'float4x4({})'.format(', '.join(vec(row) for row in a)) 130 | else: 131 | return 'mat4({})'.format(', '.join(vec(row) for row in a)) 132 | 133 | for layer_i, layer_dict in enumerate(layers): 134 | weight = layer_dict['weights'] 135 | bias = layer_dict['bias'] 136 | activation = layer_dict['activation'] 137 | 138 | _, _, from_size, to_size = weight.shape 139 | if verbose: print('Processing layer {}. from_size={}, to_size={} .. 
shape {}'.format(layer_i, from_size, to_size, weight.shape)) 140 | snippet += '\n // layer {} \n'.format(layer_i) 141 | 142 | # First, compute the transformation from the last layer into bufA 143 | for to_index in range(max(1,to_size//4)): 144 | #Again, the max(1) is important here, because to_size is 3 for the last layer! 145 | if verbose: print(' generating output {} into bufA'.format(to_index)) 146 | snippet += 'bufA[{}] = {}'.format(to_index, vec(bias[to_index*4:to_index*4+4])) 147 | if verbose: print('bufA[{}] = {} . . .'.format(to_index, vec(bias[to_index*4:to_index*4+4]))) 148 | for from_index in range(max(1,from_size//4)): 149 | # the 'max' in the above loop gives us a special case for the first layer, where there are only two inputs. 150 | if mode in ['vvvv', 'buffer']: 151 | snippet += ' + mul(bufB[{}], {})'.format(from_index, mat(weight[0, 0, from_index*4:from_index*4+4, to_index*4:to_index*4+4])) 152 | # snippet += ' + mul({}, bufB[{}])'.format(mat(weight[0, 0, from_index*4:from_index*4+4, to_index*4:to_index*4+4]), from_index) 153 | else: 154 | snippet += ' + {} * bufB[{}]'.format(mat(weight[0, 0, from_index*4:from_index*4+4, to_index*4:to_index*4+4]), from_index) 155 | if mode in ['vvvv', 'buffer'] and layer_i > 1 and layer_i < len(layers)-2: 156 | suffix = ['x','y','z','w'] 157 | snippet += ' + in{}_.{}'.format(layer_i, suffix[to_index%4]) 158 | else: 159 | if layer_i == 3: 160 | snippet += ' + in{}'.format(to_index%4) 161 | snippet += ';\n' 162 | 163 | # print('export', layer_i, activation) 164 | if to_size != 3: 165 | if verbose: print(' Doing the activation into bufB') 166 | for to_index in range(to_size//4): 167 | if activation == 'comp': 168 | snippet += 'tmp = atan(bufA[{}]);\n'.format(to_index) 169 | snippet += 'bufB[{}] = tmp/0.67;\n'.format(to_index) 170 | snippet += 'bufB[{}] = (tmp*tmp) / 0.6;\n'.format(to_index + to_size//4) 171 | elif activation == 'unbias': 172 | snippet += 'tmp = atan(bufA[{}]);\n'.format(to_index) 173 | snippet += 'bufB[{}] = tmp/0.67;\n'.format(to_index) 174 | snippet += 'bufB[{}] = (tmp*tmp - 0.45) / 0.396;\n'.format(to_index + to_size//4) 175 | elif activation == 'relu': 176 | snippet += 'bufB[{}] = (max(bufA[{}], 0.) - 0.4) / 0.58;\n'.format(to_index, to_index) 177 | else: 178 | raise ValueError('Unknown activation: {}'.format(activation.__name__)) 179 | else: 180 | if verbose: print(' Sigmoiding the last layer') 181 | # sigmoid at the last layer 182 | sigmoider = lambda s: '1. / (1. + exp(-{}))'.format(s) 183 | if mode in ['vvvv', 'buffer']: 184 | snippet += '\n return float4(({}).rgb, 1.0);\n'.format(sigmoider('bufA[0]')) 185 | # snippet += '\n return float4((1. / (1. + exp(-bufA[0]))).xyz, 1.0);\n}' 186 | else: 187 | snippet += '\n return vec4(({}).xyz, 1.0);\n'.format(sigmoider('bufA[0]')) 188 | # snippet += '\n return vec4((1. / (1. 
+ exp(-bufA[0]))).xyz, 1.0);\n}' 189 | snippet += '}\n' 190 | 191 | if mode in ['vvvv', 'buffer']: 192 | snippet += """ 193 | float4 PS(float4 p:SV_Position, float2 uv:TEXCOORD0): SV_Target { 194 | uv = 2 * (uv - 0.5); 195 | """ 196 | if fix_aspect: 197 | snippet += """ 198 | uv *= R/R.y; 199 | """ 200 | snippet += """ 201 | return {}(2*uv); 202 | }} 203 | technique10 Process 204 | {{ pass P0 205 | {{ SetPixelShader(CompileShader(ps_4_0,PS())); }} 206 | }} 207 | """.format(fn_name) 208 | elif mode == 'td': 209 | snippet += """ 210 | void main() { 211 | // Normalized pixel coordinates (from 0 to 1) 212 | vec2 uv = vUV.xy; 213 | """ 214 | if fix_aspect: 215 | snippet += """ 216 | // TODO: don't know how to find the resolution of the GLSL Top output to fix aspect... 217 | """ 218 | snippet += """ 219 | // Shifted to the form expected by the CPPN 220 | uv.xy = vec2(1., -1.) * 2. * (uv.xy - vec2(0.5, 0.5)); 221 | uv.y /= {} / {}; 222 | // Output to screen 223 | fragColor = TDOutputSwizzle({}(uv.xy, uIn0, uIn1, uIn2, uIn3)); 224 | }} 225 | """.format(float(size[0]), float(size[1]), fn_name) 226 | elif mode == 'shadertoy': 227 | snippet += """ 228 | void mainImage( out vec4 fragColor, in vec2 fragCoord ) { 229 | // Normalized pixel coordinates (from 0 to 1) 230 | vec2 uv = fragCoord/iResolution.xy; 231 | vec2 mouseNorm = (iMouse.xy / iResolution.xy) - vec2(0.5, 0.5); 232 | """ 233 | if fix_aspect: 234 | snippet += """ 235 | uv.x *= iResolution.x / iResolution.y; 236 | uv.x -= ((iResolution.x / iResolution.y) - 1.) /2.; 237 | """ 238 | snippet += """ 239 | // Shifted to the form expected by the CPPN 240 | uv = vec2(1., -1.) * 1.5 * (uv - vec2(0.5, 0.5)); 241 | uv.y /= {} / {}; 242 | // Output to screen 243 | fragColor = {}(uv, 0.23*sin(iTime), 0.32*sin(0.69*iTime), 0.32*sin(0.44*iTime), 0.23*sin(1.23*iTime)); 244 | }} 245 | """.format(float(size[0]), float(size[1]), fn_name) 246 | elif mode=='bookofshaders': 247 | snippet += """ 248 | void main() { 249 | vec2 st = gl_FragCoord.xy/u_resolution.xy; 250 | """ 251 | if fix_aspect: 252 | snippet += """ 253 | st.x *= u_resolution.x/u_resolution.y; 254 | st.x -= ((u_resolution.x / u_resolution.y) - 1.) /2.; 255 | """ 256 | snippet += """ 257 | st = vec2(1., -1.) * 1.5 * (st - vec2(0.5, 0.5)); 258 | st.y /= {} / {}; 259 | gl_FragColor = {}(st, 0.23*sin(u_time), 0.32*sin(0.69*u_time), 0.32*sin(0.44*u_time), 0.23*sin(1.23*u_time)); 260 | }} 261 | """.format(float(size[0]), float(size[1]), fn_name) 262 | 263 | if buffer is True: 264 | # buffer = ','.join('%.8f'%x for x in sbW) 265 | buffer = ','.join(fmt % x for x in sbW) 266 | if export_tfx == True: 267 | with open('CPPN-%d-%d.tfx' % (len(layers)-1, n_hidden), 'w') as f: 268 | f.write(snippet) 269 | # print(' total values', len(sbW)) 270 | return buffer 271 | else: 272 | return snippet 273 | 274 | --------------------------------------------------------------------------------
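Usage sketch (not from the repository): cppn_to_shader() takes a list of layer dicts with 'weights', 'bias' and 'activation' keys (assumed here to come from the CPPN trainer in cppn.py) and returns the shader source as a string. A minimal self-contained call with hypothetical random weights standing in for a trained network:

    import numpy as np
    from shader_expo import cppn_to_shader  # run from the repo root

    # hypothetical stand-in for a trained CPPN: two hidden 'relu' layers of width 8
    rng = np.random.default_rng(0)
    layers = [
        {'weights': rng.normal(size=(1, 1, 2, 8)), 'bias': np.zeros(8), 'activation': 'relu'},
        {'weights': rng.normal(size=(1, 1, 8, 8)), 'bias': np.zeros(8), 'activation': 'relu'},
        # final layer outputs 3 channels (RGB); the generated shader applies the sigmoid itself
        {'weights': rng.normal(size=(1, 1, 8, 3)), 'bias': np.zeros(3), 'activation': 'relu'},
    ]
    glsl = cppn_to_shader(layers, fn_name='cppn_fn', mode='shadertoy')
    with open('cppn_shadertoy.glsl', 'w') as f:
        f.write(glsl)

The resulting code can be pasted directly into the ShaderToy editor; the other modes ('td', 'vvvv', 'buffer', 'bookofshaders') follow the same call pattern.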