├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── VQGAN+CLIP_Video_with_Optical_Flow.ipynb ├── dream.py ├── samples ├── sample_cat.mp4 └── sample_policeman.mp4 └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "CLIP"] 2 | path = CLIP 3 | url = https://github.com/openai/CLIP 4 | [submodule "taming-transformers"] 5 | path = taming-transformers 6 | url = https://github.com/CompVis/taming-transformers.git 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 robobeebop 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VQGAN-CLIP-Video 2 | 3 |
4 | 5 | Open In Colab 6 | 7 | ![diagram](https://user-images.githubusercontent.com/73405777/147384444-ee19c596-b79f-4e09-9f8f-378d857f3d2f.png) 8 | 9 | 10 | 11 | https://user-images.githubusercontent.com/73405777/161429669-0fdd9506-828b-4971-84cc-ee91134bf0b9.mp4 12 | 13 |
14 | -------------------------------------------------------------------------------- /VQGAN+CLIP_Video_with_Optical_Flow.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "VQGAN+CLIP Video with Optical Flow", 7 | "private_outputs": true, 8 | "provenance": [], 9 | "collapsed_sections": [ 10 | "XYU-20jFsaW8" 11 | ] 12 | }, 13 | "kernelspec": { 14 | "name": "python3", 15 | "display_name": "Python 3" 16 | }, 17 | "language_info": { 18 | "name": "python" 19 | }, 20 | "accelerator": "GPU" 21 | }, 22 | "cells": [ 23 | { 24 | "cell_type": "markdown", 25 | "source": [ 26 | "# Setup" 27 | ], 28 | "metadata": { 29 | "id": "XYU-20jFsaW8" 30 | } 31 | }, 32 | { 33 | "cell_type": "code", 34 | "source": [ 35 | "# @title Prepare workspace\n", 36 | "! rm -r -f .config sample_data\n", 37 | "! git clone --recursive https://github.com/robobeebop/VQGAN-CLIP-Video.git ." 38 | ], 39 | "metadata": { 40 | "cellView": "form", 41 | "id": "cjDu-Hbtmrwb" 42 | }, 43 | "execution_count": null, 44 | "outputs": [] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "source": [ 49 | "# @title Mount Drive\n", 50 | "\n", 51 | "from google.colab import output\n", 52 | "from google.colab import drive\n", 53 | "drive.mount('/content/drive')" 54 | ], 55 | "metadata": { 56 | "cellView": "form", 57 | "id": "0mEllCWgnrf2" 58 | }, 59 | "execution_count": null, 60 | "outputs": [] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "source": [ 65 | "# @title Install Required Libraries\n", 66 | "! pip install lpips ftfy omegaconf einops pytorch-lightning\n", 67 | "! pip3 install transformers" 68 | ], 69 | "metadata": { 70 | "cellView": "form", 71 | "id": "lhwh9J2klrkf" 72 | }, 73 | "execution_count": null, 74 | "outputs": [] 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "source": [ 79 | "# RUN" 80 | ], 81 | "metadata": { 82 | "id": "0A_d725Ssdm3" 83 | } 84 | }, 85 | { 86 | "cell_type": "code", 87 | "source": [ 88 | "# @title Select a model\n", 89 | "from dream import Dream\n", 90 | "from os.path import exists\n", 91 | "!mkdir checkpoints\n", 92 | "ckpt_dir = \"/content/checkpoints/\"\n", 93 | "\n", 94 | "vqgan_model = \"imagenet_1024\" #@param [\"imagenet_1024\", \"imagenet_16384\", \"coco\", \"sflckr\"]\n", 95 | "\n", 96 | "vqgan_options = {\n", 97 | " \"imagenet_1024\": [\"1-7QlixzWxZAO8ktGFqvxrZ_JzapzI5hH\", \"1-8mSOBsutfkE95piiGf4ZuVX0zkAwzkn\"],\n", 98 | " \"imagenet_16384\": [\"1_1q5zxEBx17AyTALEhGqhSsS7tyCJ4fe\", \"1-0D4pbu7NHrvWzTfbw4hiA1Sno75Z2_C\"],\n", 99 | " \"coco\": [\"1-9gq1a4yGOKC3rDw-X9NBe5_JVcKcLPG\", \"1-CPBZXsCgCv-Z6Uy4Sf4lKeyqG_C5i-Y\"],\n", 100 | " \"sflckr\": [\"1iIgSRV4H6og3l2myXPRE043ULoPlqn8w\", \"1-1vMpPmB6QZhGzriXG9iI6WeFZLl7VP2\"],\n", 101 | "}\n", 102 | "\n", 103 | "yaml, ckpt = ckpt_dir + \"%s.yaml\"%vqgan_model, ckpt_dir + \"%s.ckpt\"%vqgan_model\n", 104 | "\n", 105 | "if not exists(ckpt) or not exists(yaml):\n", 106 | " yaml_id = vqgan_options[vqgan_model][0]\n", 107 | " ckpt_id = vqgan_options[vqgan_model][1]\n", 108 | "\n", 109 | " !gdown --id \"$yaml_id\" -O \"$yaml\"\n", 110 | " !gdown --id \"$ckpt_id\" -O \"$ckpt\"\n", 111 | "\n", 112 | "\n", 113 | "\n", 114 | "dream = Dream()\n", 115 | "dream.cook([yaml, ckpt])" 116 | ], 117 | "metadata": { 118 | "cellView": "form", 119 | "id": "gzz5pv6-nXLk" 120 | }, 121 | "execution_count": null, 122 | "outputs": [] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "source": [ 127 | "Run the cell above once when it's initial 
Google Colab run, or whenever you want to change the VQGAN model" 128 | ], 129 | "metadata": { 130 | "id": "TRXTBau3bWdf" 131 | } 132 | }, 133 | { 134 | "cell_type": "code", 135 | "source": [ 136 | "# @title Dream { display-mode: \"form\" }\n", 137 | "#@markdown You can see the output frames in the /content/output folder.\n", 138 | "\n", 139 | "from dream import cv2\n", 140 | "from dream import save_img\n", 141 | "from dream import reduce_res\n", 142 | "from dream import np\n", 143 | "from dream import get_opflow_image\n", 144 | "from dream import PIL\n", 145 | "from dream import trange\n", 146 | "from dream import glob\n", 147 | "\n", 148 | "#@markdown ---\n", 149 | "\n", 150 | "#@markdown Describe what the deepdream should look like. Each scene ( | ... | ) can be weighted: \"deepdream:3 | dogs:-1\"\n", 151 | "text_prompts = \"trending on artstation | lovecraftian horror | deepdream | vibrant colors | 4k | made by Edvard Munch\" #@param {type:\"string\"}\n", 152 | "\n", 153 | "#@markdown ---\n", 154 | "\n", 155 | "#@markdown Video paths.\n", 156 | "vid_path = '/content/samples/sample_policeman.mp4' #@param {type:\"string\"}\n", 157 | "output_vid_path = '/content/samples/deepdreamed_policeman.mp4' #@param {type:\"string\"}\n", 158 | "#@markdown ---\n", 159 | "\n", 160 | "#@markdown Play around with these settings; optimal values vary from video to video. Set both to 0 for a more chaotic experience.\n", 161 | "frame_weight = 10#@param {type:\"number\"}\n", 162 | "previous_frame_weight = 0.1#@param {type:\"number\"}\n", 163 | "#@markdown ---\n", 164 | "\n", 165 | "#@markdown Usual VQGAN+CLIP settings\n", 166 | "step_size = 0.15 #@param {type:\"slider\", min:0, max:1, step:0.05}\n", 167 | "iter_n = 5#@param {type:\"number\"}\n", 168 | "#@markdown ---\n", 169 | "\n", 170 | "#@markdown Dream more intensely on the first frame of the video.\n", 171 | "do_wait_first_frame = True #@param {type:\"boolean\"}\n", 172 | "wait_step_size = 0.15 #@param {type:\"slider\", min:0, max:1, step:0.05}\n", 173 | "wait_iter_n = 15#@param {type:\"number\"}\n", 174 | "#@markdown ---\n", 175 | "\n", 176 | "#@markdown Weights controlling how much the previous deepdreamed frame affects the current frame.\n", 177 | "blendflow = 0.6#@param {type:\"slider\", min:0, max:1, step:0.05}\n", 178 | "blendstatic = 0.6#@param {type:\"slider\", min:0, max:1, step:0.05}\n", 179 | "#@markdown ---\n", 180 | "#@markdown Video resolution and fps\n", 181 | "w = 1920#@param {type:\"number\"}\n", 182 | "h = 1080#@param {type:\"number\"}\n", 183 | "fps = 24#@param {type:\"number\"}\n", 184 | "#@markdown ---\n", 185 | "#@markdown Make a test run.\n", 186 | "is_test = False #@param {type:\"boolean\"}\n", 187 | "test_finish_at = 24#@param {type:\"number\"}\n", 188 | "#@markdown ---\n", 189 | "#@markdown Extract all frames from the video; can be set to False if the video didn't change.
\n", 190 | "video_to_frames = True #@param {type:\"boolean\"}\n", 191 | "\n", 192 | "!mkdir input\n", 193 | "!mkdir output\n", 194 | "!rm -r -f ./output/*.jpg\n", 195 | "\n", 196 | "\n", 197 | "if(video_to_frames):\n", 198 | " !rm -r -f ./input/*.jpg\n", 199 | " vidcap = cv2.VideoCapture(vid_path)\n", 200 | " success,image = vidcap.read()\n", 201 | " index = 1\n", 202 | " while success:\n", 203 | " cv2.imwrite(\"./input/%04d.jpg\" % index, image)\n", 204 | " success, image = vidcap.read()\n", 205 | " index += 1\n", 206 | "\n", 207 | "x, y = reduce_res((w, h))\n", 208 | "img_arr = sorted(glob('input/*.jpg'))\n", 209 | "\n", 210 | "np_img = np.float32(PIL.Image.open(img_arr[0]))\n", 211 | "np_img = cv2.resize(np_img, dsize=(x, y), interpolation=cv2.INTER_CUBIC)\n", 212 | "h, w, c = np_img.shape\n", 213 | "\n", 214 | "frame = None\n", 215 | "\n", 216 | "if do_wait_first_frame:\n", 217 | " frame = dream.deepdream(np_img, text_prompts, [x, y], iter_n=wait_iter_n, step_size=step_size, init_weight=frame_weight)\n", 218 | "else:\n", 219 | " frame = dream.deepdream(np_img, text_prompts, [x, y], iter_n=iter_n, step_size=step_size, init_weight=frame_weight)\n", 220 | "\n", 221 | "frame = cv2.resize(frame, dsize=(x, y), interpolation=cv2.INTER_CUBIC)\n", 222 | "save_img(frame, 'output/%04d.jpg'%0)\n", 223 | "\n", 224 | "img_range = trange(len(img_arr[:test_finish_at]), desc=\"Dreaming\") if is_test else trange(len(img_arr), desc=\"Dreaming\")\n", 225 | "prev_frame = None\n", 226 | "for i in img_range: \n", 227 | " if previous_frame_weight != 0:\n", 228 | " prev_frame = np.copy(frame)\n", 229 | "\n", 230 | " img = img_arr[i]\n", 231 | " np_prev_img = np_img\n", 232 | " np_img = np.float32(PIL.Image.open(img))\n", 233 | " np_img = cv2.resize(np_img, dsize=(x, y), interpolation=cv2.INTER_CUBIC)\n", 234 | " frame = cv2.resize(frame, dsize=(x, y), interpolation=cv2.INTER_CUBIC)\n", 235 | " \n", 236 | " frame_flow_masked, background_masked = get_opflow_image(np_prev_img, frame, np_img, blendflow, blendstatic)\n", 237 | " frame = frame_flow_masked + background_masked\n", 238 | " frame = dream.deepdream(frame, text_prompts, [x, y], iter_n=iter_n, init_weight=frame_weight, step_size=step_size, image_prompts=prev_frame, image_prompt_weight=previous_frame_weight)\n", 239 | " \n", 240 | " save_img(frame, 'output/%04d.jpg'%i)\n", 241 | "\n", 242 | "\n", 243 | "mp4_fourcc = cv2.VideoWriter_fourcc(*'MP4V')\n", 244 | "out = cv2.VideoWriter(output_vid_path, mp4_fourcc, fps, (w, h))\n", 245 | "filelist = sorted(glob('output/*.jpg'))\n", 246 | "\n", 247 | "for i in trange(len(filelist), desc=\"Generating Video\"):\n", 248 | " img = cv2.imread(filelist[i])\n", 249 | " img = cv2.resize(img, dsize=(w, h), interpolation=cv2.INTER_CUBIC)\n", 250 | " out.write(img)\n", 251 | "out.release()" 252 | ], 253 | "metadata": { 254 | "id": "O6GwujqCowd9" 255 | }, 256 | "execution_count": null, 257 | "outputs": [] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "source": [ 262 | "# @title Download Generated Video.\n", 263 | "# @markdown (doesn't work on Safari)\n", 264 | "from google.colab import files\n", 265 | "\n", 266 | "files.download(output_vid_path)" 267 | ], 268 | "metadata": { 269 | "cellView": "form", 270 | "id": "PXZmbS0IeKUn" 271 | }, 272 | "execution_count": null, 273 | "outputs": [] 274 | } 275 | ] 276 | } -------------------------------------------------------------------------------- /dream.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | 3 | class Dream: 
4 | def __init__(self, isHighVRAM=True) -> None: 5 | self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 6 | self.normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]) 7 | self.resLimit = 4.2e5 if isHighVRAM else 2.5e5 8 | 9 | def cook(self, vqgan_path, cut_n=32, cut_pow=1., prompts="", init_weight=4, clip_model='ViT-B/16'): 10 | self.vqgan_config = vqgan_path[0] 11 | self.vqgan_checkpoint = vqgan_path[1] 12 | self.model = load_vqgan_model(self.vqgan_config, self.vqgan_checkpoint).to(self.device) 13 | 14 | self.clip_model = clip_model 15 | self.perceptor = clip.load(self.clip_model, jit=False)[0].eval().requires_grad_(False).to(self.device) 16 | 17 | self.cut_size = self.perceptor.visual.input_resolution 18 | self.e_dim = self.model.quantize.e_dim 19 | self.f = 2**(self.model.decoder.num_resolutions - 1) 20 | self.make_cutouts = MakeCutouts(self.cut_size, cutn=cut_n, cut_pow=cut_pow) 21 | 22 | self.z_min = self.model.quantize.embedding.weight.min(dim=0).values[None, :, None, None] 23 | self.z_max = self.model.quantize.embedding.weight.max(dim=0).values[None, :, None, None] 24 | self.pMs = [] 25 | self.init_weight = init_weight 26 | prompts = prompts.split("|") 27 | 28 | for prompt in prompts: 29 | txt, weight, stop = parse_prompt(prompt) 30 | embed = self.perceptor.encode_text(clip.tokenize(txt).to(self.device)).float() 31 | self.pMs.append(Prompt(embed, weight, stop).to(self.device)) 32 | 33 | self.pMs.append(Prompt(embed, weight).to(self.device)) 34 | 35 | def deepdream(self, init_image, iter_n=25, step_size=0.3): 36 | 37 | pil_image = Image.fromarray((init_image * 1).astype(np.uint8)).convert('RGB') 38 | self.z, *_ = self.model.encode(TF.to_tensor(pil_image).to(self.device).unsqueeze(0) * 2 - 1) 39 | 40 | self.z_orig = self.z.clone() 41 | self.z.requires_grad_(True) 42 | self.opt = optim.Adam([self.z], lr=step_size) 43 | 44 | 45 | gen = torch.Generator().manual_seed(0) 46 | embed = torch.empty([1, self.perceptor.visual.output_dim]).normal_(generator=gen) 47 | 48 | try: 49 | for i in range(iter_n): 50 | out = self.train(i, iter_n) 51 | except KeyboardInterrupt: 52 | pass 53 | 54 | return np.float32(TF.to_pil_image(out[0].cpu())) 55 | 56 | def train(self, i, iter_n): 57 | torch.set_grad_enabled(True) 58 | self.opt.zero_grad() 59 | lossAll = self.ascend_txt() 60 | loss = sum(lossAll) 61 | loss.backward() 62 | torch.set_grad_enabled(False) 63 | self.opt.step() 64 | 65 | with torch.no_grad(): 66 | self.z.copy_(self.z.maximum(self.z_min).minimum(self.z_max)) 67 | 68 | if i == iter_n-1: 69 | return self.checkout() 70 | 71 | return None 72 | 73 | 74 | def ascend_txt(self): 75 | out = self.synth() 76 | iii = self.perceptor.encode_image(self.normalize(self.make_cutouts(out))).float() 77 | result = [] 78 | if self.init_weight: 79 | result.append(F.mse_loss(self.z, self.z_orig) * self.init_weight / 2) 80 | for prompt in self.pMs: 81 | result.append(prompt(iii)) 82 | return result 83 | 84 | def synth(self): 85 | z_q = vector_quantize(self.z.movedim(1, 3), self.model.quantize.embedding.weight).movedim(3, 1) 86 | return clamp_with_grad(self.model.decode(z_q).add(1).div(2), 0, 1) 87 | 88 | @torch.no_grad() 89 | def checkout(self): 90 | out = self.synth() 91 | return out -------------------------------------------------------------------------------- /samples/sample_cat.mp4: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/robobeebop/VQGAN-CLIP-Video/72414dee1764397654121c75dbb576b038da68e4/samples/sample_cat.mp4 -------------------------------------------------------------------------------- /samples/sample_policeman.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robobeebop/VQGAN-CLIP-Video/72414dee1764397654121c75dbb576b038da68e4/samples/sample_policeman.mp4 -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import io 3 | import cv2 4 | import sys 5 | import PIL 6 | import math 7 | import lpips 8 | import torch 9 | import requests 10 | import numpy as np 11 | 12 | sys.path.append('./CLIP') 13 | sys.path.append('./taming-transformers') 14 | 15 | from os import path 16 | from PIL import Image 17 | from glob import glob 18 | from CLIP import clip 19 | from pathlib import Path 20 | from IPython import display 21 | from torch import nn, optim 22 | from google.colab import output 23 | from omegaconf import OmegaConf 24 | from torchvision import transforms 25 | from torch.nn import functional as F 26 | from tqdm.notebook import tqdm, trange 27 | from taming.models import cond_transformer, vqgan 28 | from torchvision.transforms import functional as TF 29 | 30 | def reduce_res(res, max_res_value=4.5e5, max_res_scale=1.): # max limit aprx 700x700 = 49e4 31 | x1, y1 = res 32 | if x1 * y1 < max_res_value: 33 | return x1, y1 34 | x = (max_res_value**(1/2)) / (x1/y1)**(1/2) 35 | return int(max_res_scale*x1*x/y1), int(max_res_scale*x) 36 | 37 | def sinc(x): 38 | return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([])) 39 | 40 | 41 | def lanczos(x, a): 42 | cond = torch.logical_and(-a < x, x < a) 43 | out = torch.where(cond, sinc(x) * sinc(x/a), x.new_zeros([])) 44 | return out / out.sum() 45 | 46 | 47 | def ramp(ratio, width): 48 | n = math.ceil(width / ratio + 1) 49 | out = torch.empty([n]) 50 | cur = 0 51 | for i in range(out.shape[0]): 52 | out[i] = cur 53 | cur += ratio 54 | return torch.cat([-out[1:].flip([0]), out])[1:-1] 55 | 56 | 57 | def resample(input, size, align_corners=True): 58 | n, c, h, w = input.shape 59 | dh, dw = size 60 | 61 | input = input.view([n * c, 1, h, w]) 62 | 63 | if dh < h: 64 | kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype) 65 | pad_h = (kernel_h.shape[0] - 1) // 2 66 | input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect') 67 | input = F.conv2d(input, kernel_h[None, None, :, None]) 68 | 69 | if dw < w: 70 | kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype) 71 | pad_w = (kernel_w.shape[0] - 1) // 2 72 | input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect') 73 | input = F.conv2d(input, kernel_w[None, None, None, :]) 74 | 75 | input = input.view([n, c, h, w]) 76 | return F.interpolate(input, size, mode='bicubic', align_corners=align_corners) 77 | 78 | class ReplaceGrad(torch.autograd.Function): 79 | @staticmethod 80 | def forward(ctx, x_forward, x_backward): 81 | ctx.shape = x_backward.shape 82 | return x_forward 83 | 84 | @staticmethod 85 | def backward(ctx, grad_in): 86 | return None, grad_in.sum_to_size(ctx.shape) 87 | 88 | replace_grad = ReplaceGrad.apply 89 | 90 | class ClampWithGrad(torch.autograd.Function): 91 | @staticmethod 92 | def forward(ctx, input, min, max): 93 | ctx.min = min 94 | ctx.max = max 95 | ctx.save_for_backward(input) 96 | return input.clamp(min, max) 97 | 98 | 
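    # Backward keeps the incoming gradient for in-range values, and for out-of-range values
    # only when that gradient points back toward the [min, max] interval; gradients that would
    # push a value further outside the clamp are zeroed.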
@staticmethod 99 | def backward(ctx, grad_in): 100 | input, = ctx.saved_tensors 101 | return grad_in * (grad_in * (input - input.clamp(ctx.min, ctx.max)) >= 0), None, None 102 | 103 | 104 | def vector_quantize(x, codebook): 105 | d = x.pow(2).sum(dim=-1, keepdim=True) + codebook.pow(2).sum(dim=1) - 2 * x @ codebook.T 106 | indices = d.argmin(-1) 107 | x_q = F.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook 108 | return replace_grad(x_q, x) 109 | 110 | clamp_with_grad = ClampWithGrad.apply 111 | 112 | class Prompt(nn.Module): 113 | def __init__(self, embed, weight=1., stop=float('-inf')): 114 | super().__init__() 115 | self.register_buffer('embed', embed) 116 | self.register_buffer('weight', torch.as_tensor(weight)) 117 | self.register_buffer('stop', torch.as_tensor(stop)) 118 | 119 | def forward(self, input): 120 | input_normed = F.normalize(input.unsqueeze(1), dim=2) 121 | embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2) 122 | dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2) 123 | dists = dists * self.weight.sign() 124 | return self.weight.abs() * replace_grad(dists, torch.maximum(dists, self.stop)).mean() 125 | 126 | def fetch(url_or_path): 127 | if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'): 128 | r = requests.get(url_or_path) 129 | r.raise_for_status() 130 | fd = io.BytesIO() 131 | fd.write(r.content) 132 | fd.seek(0) 133 | return fd 134 | return open(url_or_path, 'rb') 135 | 136 | def parse_prompt(prompt): 137 | if prompt.startswith('http://') or prompt.startswith('https://'): 138 | vals = prompt.rsplit(':', 3) 139 | vals = [vals[0] + ':' + vals[1], *vals[2:]] 140 | else: 141 | vals = prompt.rsplit(':', 2) 142 | vals = vals + ['', '1', '-inf'][len(vals):] 143 | return vals[0], float(vals[1]), float(vals[2]) 144 | 145 | class MakeCutouts(nn.Module): 146 | def __init__(self, cut_size, cutn, cut_pow=1.): 147 | super().__init__() 148 | self.cut_size = cut_size 149 | self.cutn = cutn 150 | self.cut_pow = cut_pow 151 | 152 | def forward(self, input): 153 | sideY, sideX = input.shape[2:4] 154 | max_size = min(sideX, sideY) 155 | min_size = min(sideX, sideY, self.cut_size) 156 | cutouts = [] 157 | for _ in range(self.cutn): 158 | size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size) 159 | offsetx = torch.randint(0, sideX - size + 1, ()) 160 | offsety = torch.randint(0, sideY - size + 1, ()) 161 | cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size] 162 | cutouts.append(resample(cutout, (self.cut_size, self.cut_size))) 163 | return clamp_with_grad(torch.cat(cutouts, dim=0), 0, 1) 164 | 165 | 166 | def resize_image(image, out_size): 167 | ratio = image.size[0] / image.size[1] 168 | area = min(image.size[0] * image.size[1], out_size[0] * out_size[1]) 169 | size = round((area * ratio)**0.5), round((area / ratio)**0.5) 170 | return image.resize(size, Image.LANCZOS) 171 | 172 | 173 | def save_img(a, dir): 174 | PIL.Image.fromarray(np.uint8(np.clip(a, 0, 255))).save(dir) 175 | 176 | 177 | def load_vqgan_model(config_path, checkpoint_path): 178 | config = OmegaConf.load(config_path) 179 | if config.model.target == 'taming.models.vqgan.VQModel': 180 | model = vqgan.VQModel(**config.model.params) 181 | model.eval().requires_grad_(False) 182 | model.init_from_ckpt(checkpoint_path) 183 | elif config.model.target == 'taming.models.cond_transformer.Net2NetTransformer': 184 | parent_model = cond_transformer.Net2NetTransformer(**config.model.params) 185 | 
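        # Net2NetTransformer checkpoints bundle a transformer with a VQGAN first stage;
        # only the first-stage VQGAN (parent_model.first_stage_model) is kept below.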
parent_model.eval().requires_grad_(False) 186 | parent_model.init_from_ckpt(checkpoint_path) 187 | model = parent_model.first_stage_model 188 | else: 189 | raise ValueError(f'unknown model type: {config.model.target}') 190 | del model.loss 191 | return model 192 | 193 | class Network(torch.nn.Module): 194 | def __init__(self): 195 | super().__init__() 196 | 197 | class Preprocess(torch.nn.Module): 198 | def __init__(self): 199 | super().__init__() 200 | 201 | def forward(self, tenInput): 202 | tenBlue = (tenInput[:, 0:1, :, :] - 0.406) / 0.225 203 | tenGreen = (tenInput[:, 1:2, :, :] - 0.456) / 0.224 204 | tenRed = (tenInput[:, 2:3, :, :] - 0.485) / 0.229 205 | 206 | return torch.cat([ tenRed, tenGreen, tenBlue ], 1) 207 | 208 | class Basic(torch.nn.Module): 209 | def __init__(self, intLevel): 210 | super().__init__() 211 | 212 | self.netBasic = torch.nn.Sequential( 213 | torch.nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3), 214 | torch.nn.ReLU(inplace=False), 215 | torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3), 216 | torch.nn.ReLU(inplace=False), 217 | torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3), 218 | torch.nn.ReLU(inplace=False), 219 | torch.nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3), 220 | torch.nn.ReLU(inplace=False), 221 | torch.nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3) 222 | ) 223 | 224 | 225 | def forward(self, tenInput): 226 | return self.netBasic(tenInput) 227 | 228 | self.netPreprocess = Preprocess() 229 | self.netBasic = torch.nn.ModuleList([ Basic(intLevel) for intLevel in range(6) ]) 230 | self.load_state_dict({ strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.hub.load_state_dict_from_url(url='http://content.sniklaus.com/github/pytorch-spynet/network-' + arguments_strModel + '.pytorch', file_name='spynet-' + arguments_strModel).items() }) 231 | 232 | def forward(self, tenOne, tenTwo): 233 | tenFlow = [] 234 | 235 | tenOne = [ self.netPreprocess(tenOne) ] 236 | tenTwo = [ self.netPreprocess(tenTwo) ] 237 | 238 | for intLevel in range(5): 239 | if tenOne[0].shape[2] > 32 or tenOne[0].shape[3] > 32: 240 | tenOne.insert(0, torch.nn.functional.avg_pool2d(input=tenOne[0], kernel_size=2, stride=2, count_include_pad=False)) 241 | tenTwo.insert(0, torch.nn.functional.avg_pool2d(input=tenTwo[0], kernel_size=2, stride=2, count_include_pad=False)) 242 | 243 | tenFlow = tenOne[0].new_zeros([ tenOne[0].shape[0], 2, int(math.floor(tenOne[0].shape[2] / 2.0)), int(math.floor(tenOne[0].shape[3] / 2.0)) ]) 244 | 245 | for intLevel in range(len(tenOne)): 246 | tenUpsampled = torch.nn.functional.interpolate(input=tenFlow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0 247 | 248 | if tenUpsampled.shape[2] != tenOne[intLevel].shape[2]: tenUpsampled = torch.nn.functional.pad(input=tenUpsampled, pad=[ 0, 0, 0, 1 ], mode='replicate') 249 | if tenUpsampled.shape[3] != tenOne[intLevel].shape[3]: tenUpsampled = torch.nn.functional.pad(input=tenUpsampled, pad=[ 0, 1, 0, 0 ], mode='replicate') 250 | 251 | tenFlow = self.netBasic[intLevel](torch.cat([ tenOne[intLevel], backwarp(tenInput=tenTwo[intLevel], tenFlow=tenUpsampled), tenUpsampled ], 1)) + tenUpsampled 252 | 253 | 254 | return tenFlow 255 | 256 | torch.backends.cudnn.enabled = True 257 | arguments_strModel = 'sintel-final' # 'sintel-final', or 'sintel-clean', or 'chairs-final', or 'chairs-clean', or 'kitti-final' 258 | 
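# backwarp_tenGrid below caches one normalized sampling grid per flow shape; backwarp() adds
# the flow (rescaled to [-1, 1] coordinates) to that grid and uses grid_sample to warp the
# second image toward the first. estimate() resizes both inputs so their sides are multiples
# of 32 (as the SPyNet pyramid expects), runs the network, then resizes the flow back to the
# original resolution and rescales its x/y components accordingly.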
backwarp_tenGrid = {} 259 | 260 | def backwarp(tenInput, tenFlow): 261 | if str(tenFlow.shape) not in backwarp_tenGrid: 262 | tenHor = torch.linspace(-1.0 + (1.0 / tenFlow.shape[3]), 1.0 - (1.0 / tenFlow.shape[3]), tenFlow.shape[3]).view(1, 1, 1, -1).expand(-1, -1, tenFlow.shape[2], -1) 263 | tenVer = torch.linspace(-1.0 + (1.0 / tenFlow.shape[2]), 1.0 - (1.0 / tenFlow.shape[2]), tenFlow.shape[2]).view(1, 1, -1, 1).expand(-1, -1, -1, tenFlow.shape[3]) 264 | 265 | backwarp_tenGrid[str(tenFlow.shape)] = torch.cat([ tenHor, tenVer ], 1).cuda() 266 | 267 | tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1) 268 | 269 | return torch.nn.functional.grid_sample(input=tenInput, grid=(backwarp_tenGrid[str(tenFlow.shape)] + tenFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='border', align_corners=False) 270 | 271 | netNetwork = None 272 | def estimate(tenOne, tenTwo): 273 | global netNetwork 274 | 275 | if netNetwork is None: 276 | netNetwork = Network().cuda().eval() 277 | 278 | assert(tenOne.shape[1] == tenTwo.shape[1]) 279 | assert(tenOne.shape[2] == tenTwo.shape[2]) 280 | 281 | intWidth = tenOne.shape[2] 282 | intHeight = tenOne.shape[1] 283 | 284 | tenPreprocessedOne = tenOne.cuda().view(1, 3, intHeight, intWidth) 285 | tenPreprocessedTwo = tenTwo.cuda().view(1, 3, intHeight, intWidth) 286 | 287 | intPreprocessedWidth = int(math.floor(math.ceil(intWidth / 32.0) * 32.0)) 288 | intPreprocessedHeight = int(math.floor(math.ceil(intHeight / 32.0) * 32.0)) 289 | 290 | tenPreprocessedOne = torch.nn.functional.interpolate(input=tenPreprocessedOne, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False) 291 | tenPreprocessedTwo = torch.nn.functional.interpolate(input=tenPreprocessedTwo, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False) 292 | 293 | tenFlow = torch.nn.functional.interpolate(input=netNetwork(tenPreprocessedOne, tenPreprocessedTwo), size=(intHeight, intWidth), mode='bilinear', align_corners=False) 294 | 295 | tenFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth) 296 | tenFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight) 297 | 298 | return tenFlow[0, :, :, :].cpu() 299 | 300 | def calc_opflow(img1, img2): 301 | img1 = PIL.Image.fromarray(img1) 302 | img2 = PIL.Image.fromarray(img2) 303 | 304 | tenFirst = torch.FloatTensor( 305 | np.ascontiguousarray( 306 | np.array(img1)[:, :, ::-1].transpose(2, 0, 1).astype(np.float32) 307 | * (1.0 / 255.0) 308 | ) 309 | ) 310 | tenSecond = torch.FloatTensor( 311 | np.ascontiguousarray( 312 | np.array(img2)[:, :, ::-1].transpose(2, 0, 1).astype(np.float32) 313 | * (1.0 / 255.0) 314 | ) 315 | ) 316 | 317 | tenOutput = estimate(tenFirst, tenSecond) 318 | return tenOutput 319 | 320 | def get_opflow_image(np_prev_img, frame, np_img, blendflow, blendstatic, threshold=6, do_blur=True, blur_value=(5, 5)): 321 | np_prev_img = np.float32(np_prev_img) 322 | frame = np.float32(frame) 323 | np_img = np.float32(np_img) 324 | 325 | h, w, _ = np_prev_img.shape 326 | 327 | flow = calc_opflow(np.uint8(np_prev_img), np.uint8(np_img)) 328 | flow = np.transpose(np.float32(flow), (1, 2, 0)) 329 | inv_flow = flow 330 | flow = -flow 331 | flow[:, :, 0] += np.arange(w) 332 | flow[:, :, 1] += np.arange(h)[:, np.newaxis] 333 | 334 | framediff = (np_img*(1-blendflow) + frame*blendflow) - np_prev_img 335 | framediff = cv2.remap(framediff, flow, None, cv2.INTER_LINEAR) 336 | if 
do_blur: 337 | framediff = cv2.GaussianBlur(framediff, blur_value, 0) 338 | 339 | frame_flow = np_img + framediff 340 | 341 | magnitude, _ = cv2.cartToPolar(inv_flow[...,0], inv_flow[...,1]) 342 | norm_mag = cv2.normalize(magnitude, None, 0, 255, cv2.NORM_MINMAX) 343 | _, mask = cv2.threshold(norm_mag, threshold, 255, cv2.THRESH_BINARY) 344 | flow_mask = mask.astype(np.uint8).reshape((h, w, 1)) 345 | frame_flow_masked = cv2.bitwise_and(frame_flow, frame_flow, mask=flow_mask) 346 | 347 | background_blendimg = cv2.addWeighted(np_img, (1-blendstatic), frame, blendstatic, 0) 348 | background_masked = cv2.bitwise_and(background_blendimg, background_blendimg, mask=cv2.bitwise_not(flow_mask)) 349 | 350 | return frame_flow_masked, background_masked 351 | --------------------------------------------------------------------------------
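For reference, the following is a minimal sketch of how the helpers above combine for a single video frame, mirroring the loop in the Colab notebook. It is only a sketch under assumptions: utils.py must be importable in a Colab-style environment (it imports google.colab and IPython), a CUDA device must be available (estimate() moves tensors to the GPU and downloads the SPyNet weights on first use), and the frame file names are placeholders.

import numpy as np
import PIL.Image
from utils import get_opflow_image, reduce_res, save_img

# Cap the working resolution the same way the notebook does.
w, h = reduce_res((1920, 1080))

# Placeholder paths: previous original frame, current original frame, and the
# previously stylized output frame.
prev_np = np.float32(PIL.Image.open("input/0001.jpg").resize((w, h)))
cur_np = np.float32(PIL.Image.open("input/0002.jpg").resize((w, h)))
prev_out = np.float32(PIL.Image.open("output/0001.jpg").resize((w, h)))

# Warp the previous stylized frame along the estimated optical flow and blend it with the
# current frame: moving regions come back in frame_flow_masked, static regions in
# background_masked.
frame_flow_masked, background_masked = get_opflow_image(
    prev_np, prev_out, cur_np, blendflow=0.6, blendstatic=0.6)
init_image = frame_flow_masked + background_masked

# init_image is what the notebook then feeds into Dream.deepdream() for this frame.
save_img(init_image, "output/0002_init.jpg")  # hypothetical output name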