├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── VQGAN+CLIP_Video_with_Optical_Flow.ipynb
├── dream.py
├── samples
│   ├── sample_cat.mp4
│   └── sample_policeman.mp4
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "CLIP"]
2 | path = CLIP
3 | url = https://github.com/openai/CLIP
4 | [submodule "taming-transformers"]
5 | path = taming-transformers
6 | url = https://github.com/CompVis/taming-transformers.git
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 robobeebop
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VQGAN-CLIP-Video
2 |
3 |
14 |
--------------------------------------------------------------------------------
/VQGAN+CLIP_Video_with_Optical_Flow.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "VQGAN+CLIP Video with Optical Flow",
7 | "private_outputs": true,
8 | "provenance": [],
9 | "collapsed_sections": [
10 | "XYU-20jFsaW8"
11 | ]
12 | },
13 | "kernelspec": {
14 | "name": "python3",
15 | "display_name": "Python 3"
16 | },
17 | "language_info": {
18 | "name": "python"
19 | },
20 | "accelerator": "GPU"
21 | },
22 | "cells": [
23 | {
24 | "cell_type": "markdown",
25 | "source": [
26 | "# Setup"
27 | ],
28 | "metadata": {
29 | "id": "XYU-20jFsaW8"
30 | }
31 | },
32 | {
33 | "cell_type": "code",
34 | "source": [
35 | "# @title Prepare workspace\n",
36 | "! rm -r -f .config sample_data\n",
37 | "! git clone --recursive https://github.com/robobeebop/VQGAN-CLIP-Video.git ."
38 | ],
39 | "metadata": {
40 | "cellView": "form",
41 | "id": "cjDu-Hbtmrwb"
42 | },
43 | "execution_count": null,
44 | "outputs": []
45 | },
46 | {
47 | "cell_type": "code",
48 | "source": [
49 | "# @title Mount Drive\n",
50 | "\n",
51 | "from google.colab import output\n",
52 | "from google.colab import drive\n",
53 | "drive.mount('/content/drive')"
54 | ],
55 | "metadata": {
56 | "cellView": "form",
57 | "id": "0mEllCWgnrf2"
58 | },
59 | "execution_count": null,
60 | "outputs": []
61 | },
62 | {
63 | "cell_type": "code",
64 | "source": [
65 | "# @title Install Required Libraries\n",
66 | "! pip install lpips ftfy omegaconf einops pytorch-lightning\n",
67 | "! pip3 install transformers"
68 | ],
69 | "metadata": {
70 | "cellView": "form",
71 | "id": "lhwh9J2klrkf"
72 | },
73 | "execution_count": null,
74 | "outputs": []
75 | },
76 | {
77 | "cell_type": "markdown",
78 | "source": [
79 | "# RUN"
80 | ],
81 | "metadata": {
82 | "id": "0A_d725Ssdm3"
83 | }
84 | },
85 | {
86 | "cell_type": "code",
87 | "source": [
88 | "# @title Select a model\n",
89 | "from dream import Dream\n",
90 | "from os.path import exists\n",
91 | "!mkdir checkpoints\n",
92 | "ckpt_dir = \"/content/checkpoints/\"\n",
93 | "\n",
94 | "vqgan_model = \"imagenet_1024\" #@param [\"imagenet_1024\", \"imagenet_16384\", \"coco\", \"sflckr\"]\n",
95 | "\n",
96 | "vqgan_options = {\n",
97 | " \"imagenet_1024\": [\"1-7QlixzWxZAO8ktGFqvxrZ_JzapzI5hH\", \"1-8mSOBsutfkE95piiGf4ZuVX0zkAwzkn\"],\n",
98 | " \"imagenet_16384\": [\"1_1q5zxEBx17AyTALEhGqhSsS7tyCJ4fe\", \"1-0D4pbu7NHrvWzTfbw4hiA1Sno75Z2_C\"],\n",
99 | " \"coco\": [\"1-9gq1a4yGOKC3rDw-X9NBe5_JVcKcLPG\", \"1-CPBZXsCgCv-Z6Uy4Sf4lKeyqG_C5i-Y\"],\n",
100 | " \"sflckr\": [\"1iIgSRV4H6og3l2myXPRE043ULoPlqn8w\", \"1-1vMpPmB6QZhGzriXG9iI6WeFZLl7VP2\"],\n",
101 | "}\n",
102 | "\n",
103 | "yaml, ckpt = ckpt_dir + \"%s.yaml\"%vqgan_model, ckpt_dir + \"%s.ckpt\"%vqgan_model\n",
104 | "\n",
105 | "if not exists(ckpt) or not exists(yaml):\n",
106 | " yaml_id = vqgan_options[vqgan_model][0]\n",
107 | " ckpt_id = vqgan_options[vqgan_model][1]\n",
108 | "\n",
109 | " !gdown --id \"$yaml_id\" -O \"$yaml\"\n",
110 | " !gdown --id \"$ckpt_id\" -O \"$ckpt\"\n",
111 | "\n",
112 | "\n",
113 | "\n",
114 | "dream = Dream()\n",
115 | "dream.cook([yaml, ckpt])"
116 | ],
117 | "metadata": {
118 | "cellView": "form",
119 | "id": "gzz5pv6-nXLk"
120 | },
121 | "execution_count": null,
122 | "outputs": []
123 | },
124 | {
125 | "cell_type": "markdown",
126 | "source": [
127 |         "Run the cell above once on the initial Google Colab run, or whenever you want to change the VQGAN model."
128 | ],
129 | "metadata": {
130 | "id": "TRXTBau3bWdf"
131 | }
132 | },
133 | {
134 | "cell_type": "code",
135 | "source": [
136 | "# @title Dream { display-mode: \"form\" }\n",
137 |         "#@markdown You can see the output frames in the /content/output folder.\n",
138 | "\n",
139 | "from dream import cv2\n",
140 | "from dream import save_img\n",
141 | "from dream import reduce_res\n",
142 | "from dream import np\n",
143 | "from dream import get_opflow_image\n",
144 | "from dream import PIL\n",
145 | "from dream import trange\n",
146 | "from dream import glob\n",
147 | "\n",
148 | "#@markdown ---\n",
149 | "\n",
150 |         "#@markdown Describe what the deep dream should look like. Each scene ( | ... | ) can be weighted: \"deepdream:3 | dogs:-1\".\n",
151 | "text_prompts = \"trending on artstation | lovecraftian horror | deepdream | vibrant colors | 4k | made by Edvard Munch\" #@param {type:\"string\"}\n",
152 | "\n",
153 | "#@markdown ---\n",
154 | "\n",
155 | "#@markdown Video paths.\n",
156 | "vid_path = '/content/samples/sample_policeman.mp4' #@param {type:\"string\"}\n",
157 | "output_vid_path = '/content/samples/deepdreamed_policeman.mp4' #@param {type:\"string\"}\n",
158 | "#@markdown ---\n",
159 | "\n",
160 |         "#@markdown Play around with these settings; the optimal values vary from video to video. Set both to 0 for a more chaotic experience.\n",
161 | "frame_weight = 10#@param {type:\"number\"}\n",
162 | "previous_frame_weight = 0.1#@param {type:\"number\"}\n",
163 | "#@markdown ---\n",
164 | "\n",
165 | "#@markdown Usual VQGAN+CLIP settings\n",
166 | "step_size = 0.15 #@param {type:\"slider\", min:0, max:1, step:0.05}\n",
167 | "iter_n = 5#@param {type:\"number\"}\n",
168 | "#@markdown ---\n",
169 | "\n",
170 |         "#@markdown Dream more intensely on the first frame of the video.\n",
171 | "do_wait_first_frame = True #@param {type:\"boolean\"}\n",
172 | "wait_step_size = 0.15 #@param {type:\"slider\", min:0, max:1, step:0.05}\n",
173 | "wait_iter_n = 15#@param {type:\"number\"}\n",
174 | "#@markdown ---\n",
175 | "\n",
176 |         "#@markdown Weights controlling how much the previous deep-dreamed frame affects the current frame.\n",
177 | "blendflow = 0.6#@param {type:\"slider\", min:0, max:1, step:0.05}\n",
178 | "blendstatic = 0.6#@param {type:\"slider\", min:0, max:1, step:0.05}\n",
179 | "#@markdown ---\n",
180 | "#@markdown Video resolution and fps\n",
181 | "w = 1920#@param {type:\"number\"}\n",
182 | "h = 1080#@param {type:\"number\"}\n",
183 | "fps = 24#@param {type:\"number\"}\n",
184 | "#@markdown ---\n",
185 | "#@markdown Make a test run.\n",
186 | "is_test = False #@param {type:\"boolean\"}\n",
187 | "test_finish_at = 24#@param {type:\"number\"}\n",
188 | "#@markdown ---\n",
189 |         "#@markdown Extract all frames from the video; can be set to False if the video hasn't changed.\n",
190 | "video_to_frames = True #@param {type:\"boolean\"}\n",
191 | "\n",
192 | "!mkdir input\n",
193 | "!mkdir output\n",
194 | "!rm -r -f ./output/*.jpg\n",
195 | "\n",
196 | "\n",
197 | "if(video_to_frames):\n",
198 | " !rm -r -f ./input/*.jpg\n",
199 | " vidcap = cv2.VideoCapture(vid_path)\n",
200 | " success,image = vidcap.read()\n",
201 | " index = 1\n",
202 | " while success:\n",
203 | " cv2.imwrite(\"./input/%04d.jpg\" % index, image)\n",
204 | " success, image = vidcap.read()\n",
205 | " index += 1\n",
206 | "\n",
207 | "x, y = reduce_res((w, h))\n",
208 | "img_arr = sorted(glob('input/*.jpg'))\n",
209 | "\n",
210 | "np_img = np.float32(PIL.Image.open(img_arr[0]))\n",
211 | "np_img = cv2.resize(np_img, dsize=(x, y), interpolation=cv2.INTER_CUBIC)\n",
212 | "h, w, c = np_img.shape\n",
213 | "\n",
214 | "frame = None\n",
215 | "\n",
216 | "if do_wait_first_frame:\n",
217 |         "    frame = dream.deepdream(np_img, text_prompts, [x, y], iter_n=wait_iter_n, step_size=wait_step_size, init_weight=frame_weight)\n",
218 | "else:\n",
219 | " frame = dream.deepdream(np_img, text_prompts, [x, y], iter_n=iter_n, step_size=step_size, init_weight=frame_weight)\n",
220 | "\n",
221 | "frame = cv2.resize(frame, dsize=(x, y), interpolation=cv2.INTER_CUBIC)\n",
222 | "save_img(frame, 'output/%04d.jpg'%0)\n",
223 | "\n",
224 | "img_range = trange(len(img_arr[:test_finish_at]), desc=\"Dreaming\") if is_test else trange(len(img_arr), desc=\"Dreaming\")\n",
225 | "prev_frame = None\n",
226 | "for i in img_range: \n",
227 | " if previous_frame_weight != 0:\n",
228 | " prev_frame = np.copy(frame)\n",
229 | "\n",
230 | " img = img_arr[i]\n",
231 | " np_prev_img = np_img\n",
232 | " np_img = np.float32(PIL.Image.open(img))\n",
233 | " np_img = cv2.resize(np_img, dsize=(x, y), interpolation=cv2.INTER_CUBIC)\n",
234 | " frame = cv2.resize(frame, dsize=(x, y), interpolation=cv2.INTER_CUBIC)\n",
235 | " \n",
236 | " frame_flow_masked, background_masked = get_opflow_image(np_prev_img, frame, np_img, blendflow, blendstatic)\n",
237 | " frame = frame_flow_masked + background_masked\n",
238 | " frame = dream.deepdream(frame, text_prompts, [x, y], iter_n=iter_n, init_weight=frame_weight, step_size=step_size, image_prompts=prev_frame, image_prompt_weight=previous_frame_weight)\n",
239 | " \n",
240 | " save_img(frame, 'output/%04d.jpg'%i)\n",
241 | "\n",
242 | "\n",
243 | "mp4_fourcc = cv2.VideoWriter_fourcc(*'MP4V')\n",
244 | "out = cv2.VideoWriter(output_vid_path, mp4_fourcc, fps, (w, h))\n",
245 | "filelist = sorted(glob('output/*.jpg'))\n",
246 | "\n",
247 | "for i in trange(len(filelist), desc=\"Generating Video\"):\n",
248 | " img = cv2.imread(filelist[i])\n",
249 | " img = cv2.resize(img, dsize=(w, h), interpolation=cv2.INTER_CUBIC)\n",
250 | " out.write(img)\n",
251 | "out.release()"
252 | ],
253 | "metadata": {
254 | "id": "O6GwujqCowd9"
255 | },
256 | "execution_count": null,
257 | "outputs": []
258 | },
259 | {
260 | "cell_type": "code",
261 | "source": [
262 | "# @title Download Generated Video.\n",
263 | "# @markdown (doesn't work on Safari)\n",
264 | "from google.colab import files\n",
265 | "\n",
266 | "files.download(output_vid_path)"
267 | ],
268 | "metadata": {
269 | "cellView": "form",
270 | "id": "PXZmbS0IeKUn"
271 | },
272 | "execution_count": null,
273 | "outputs": []
274 | }
275 | ]
276 | }
--------------------------------------------------------------------------------
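The weighted-prompt syntax described in the Dream cell above ("deepdream:3 | dogs:-1") is handled by Dream.cook, which splits the string on "|", and by parse_prompt in utils.py, which splits each scene on ":" into text, weight, and an optional stop value. A minimal standalone sketch of that parsing path (the URL branch of parse_prompt, used for image-prompt paths, is omitted here):

    # Mirrors the non-URL branch of utils.parse_prompt and the split in Dream.cook.
    def parse_prompt(prompt):
        vals = prompt.rsplit(':', 2)
        vals = vals + ['', '1', '-inf'][len(vals):]  # pad missing weight/stop with defaults
        return vals[0], float(vals[1]), float(vals[2])

    text_prompts = "deepdream:3 | dogs:-1 | vibrant colors"
    for scene in text_prompts.split("|"):
        txt, weight, stop = parse_prompt(scene)
        print(repr(txt.strip()), weight, stop)
    # 'deepdream' 3.0 -inf
    # 'dogs' -1.0 -inf
    # 'vibrant colors' 1.0 -inf
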
/dream.py:
--------------------------------------------------------------------------------
1 | from utils import *
2 |
3 | class Dream:
4 | def __init__(self, isHighVRAM=True) -> None:
5 | self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
6 | self.normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
7 | self.resLimit = 4.2e5 if isHighVRAM else 2.5e5
8 |
9 | def cook(self, vqgan_path, cut_n=32, cut_pow=1., prompts="", init_weight=4, clip_model='ViT-B/16'):
10 | self.vqgan_config = vqgan_path[0]
11 | self.vqgan_checkpoint = vqgan_path[1]
12 | self.model = load_vqgan_model(self.vqgan_config, self.vqgan_checkpoint).to(self.device)
13 |
14 | self.clip_model = clip_model
15 | self.perceptor = clip.load(self.clip_model, jit=False)[0].eval().requires_grad_(False).to(self.device)
16 |
17 | self.cut_size = self.perceptor.visual.input_resolution
18 | self.e_dim = self.model.quantize.e_dim
19 | self.f = 2**(self.model.decoder.num_resolutions - 1)
20 | self.make_cutouts = MakeCutouts(self.cut_size, cutn=cut_n, cut_pow=cut_pow)
21 |
22 | self.z_min = self.model.quantize.embedding.weight.min(dim=0).values[None, :, None, None]
23 | self.z_max = self.model.quantize.embedding.weight.max(dim=0).values[None, :, None, None]
24 | self.pMs = []
25 | self.init_weight = init_weight
26 | prompts = prompts.split("|")
27 |
28 | for prompt in prompts:
29 | txt, weight, stop = parse_prompt(prompt)
30 | embed = self.perceptor.encode_text(clip.tokenize(txt).to(self.device)).float()
31 | self.pMs.append(Prompt(embed, weight, stop).to(self.device))
32 |
33 |         self.pMs.append(Prompt(embed, weight).to(self.device))  # appends the last parsed prompt a second time
34 |
35 | def deepdream(self, init_image, iter_n=25, step_size=0.3):
36 |
37 | pil_image = Image.fromarray((init_image * 1).astype(np.uint8)).convert('RGB')
38 | self.z, *_ = self.model.encode(TF.to_tensor(pil_image).to(self.device).unsqueeze(0) * 2 - 1)
39 |
40 | self.z_orig = self.z.clone()
41 | self.z.requires_grad_(True)
42 | self.opt = optim.Adam([self.z], lr=step_size)
43 |
44 |
45 | gen = torch.Generator().manual_seed(0)
46 | embed = torch.empty([1, self.perceptor.visual.output_dim]).normal_(generator=gen)
47 |
48 | try:
49 | for i in range(iter_n):
50 | out = self.train(i, iter_n)
51 | except KeyboardInterrupt:
52 | pass
53 |
54 | return np.float32(TF.to_pil_image(out[0].cpu()))
55 |
56 | def train(self, i, iter_n):
57 | torch.set_grad_enabled(True)
58 | self.opt.zero_grad()
59 | lossAll = self.ascend_txt()
60 | loss = sum(lossAll)
61 | loss.backward()
62 | torch.set_grad_enabled(False)
63 | self.opt.step()
64 |
65 | with torch.no_grad():
66 | self.z.copy_(self.z.maximum(self.z_min).minimum(self.z_max))
67 |
68 | if i == iter_n-1:
69 | return self.checkout()
70 |
71 | return None
72 |
73 |
74 | def ascend_txt(self):
75 | out = self.synth()
76 | iii = self.perceptor.encode_image(self.normalize(self.make_cutouts(out))).float()
77 | result = []
78 | if self.init_weight:
79 | result.append(F.mse_loss(self.z, self.z_orig) * self.init_weight / 2)
80 | for prompt in self.pMs:
81 | result.append(prompt(iii))
82 | return result
83 |
84 | def synth(self):
85 | z_q = vector_quantize(self.z.movedim(1, 3), self.model.quantize.embedding.weight).movedim(3, 1)
86 | return clamp_with_grad(self.model.decode(z_q).add(1).div(2), 0, 1)
87 |
88 | @torch.no_grad()
89 | def checkout(self):
90 | out = self.synth()
91 | return out
--------------------------------------------------------------------------------
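For reference, a minimal sketch of how the Dream class above might be driven directly, using the signatures as they appear in this file (cook takes a [config, checkpoint] pair plus the prompt string; deepdream takes a float32 H x W x 3 array in 0..255 and returns one). It assumes the environment the notebook prepares (submodules cloned, packages installed, a CUDA device); the checkpoint and image paths are placeholders:

    import cv2
    import numpy as np
    import PIL.Image

    from dream import Dream
    from utils import reduce_res, save_img

    dream = Dream(isHighVRAM=True)
    # Load the VQGAN config/checkpoint pair and encode the text prompts once.
    dream.cook(
        ["checkpoints/imagenet_1024.yaml", "checkpoints/imagenet_1024.ckpt"],  # placeholder paths
        prompts="deepdream:3 | vibrant colors",
        init_weight=4,
    )

    # Placeholder input frame; reduce_res keeps the pixel count inside the VRAM budget.
    np_img = np.float32(PIL.Image.open("frame_0001.jpg").convert("RGB"))
    x, y = reduce_res((np_img.shape[1], np_img.shape[0]))
    np_img = cv2.resize(np_img, dsize=(x, y), interpolation=cv2.INTER_CUBIC)

    out = dream.deepdream(np_img, iter_n=10, step_size=0.2)
    save_img(out, "dreamed_0001.jpg")
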
/samples/sample_cat.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/robobeebop/VQGAN-CLIP-Video/72414dee1764397654121c75dbb576b038da68e4/samples/sample_cat.mp4
--------------------------------------------------------------------------------
/samples/sample_policeman.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/robobeebop/VQGAN-CLIP-Video/72414dee1764397654121c75dbb576b038da68e4/samples/sample_policeman.mp4
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import io
3 | import cv2
4 | import sys
5 | import PIL
6 | import math
7 | import lpips
8 | import torch
9 | import requests
10 | import numpy as np
11 |
12 | sys.path.append('./CLIP')
13 | sys.path.append('./taming-transformers')
14 |
15 | from os import path
16 | from PIL import Image
17 | from glob import glob
18 | from CLIP import clip
19 | from pathlib import Path
20 | from IPython import display
21 | from torch import nn, optim
22 | from google.colab import output
23 | from omegaconf import OmegaConf
24 | from torchvision import transforms
25 | from torch.nn import functional as F
26 | from tqdm.notebook import tqdm, trange
27 | from taming.models import cond_transformer, vqgan
28 | from torchvision.transforms import functional as TF
29 |
30 | def reduce_res(res, max_res_value=4.5e5, max_res_scale=1.):  # keep w*h under max_res_value (default 4.5e5 px, approx 670x670)
31 | x1, y1 = res
32 | if x1 * y1 < max_res_value:
33 | return x1, y1
34 | x = (max_res_value**(1/2)) / (x1/y1)**(1/2)
35 | return int(max_res_scale*x1*x/y1), int(max_res_scale*x)
36 |
37 | def sinc(x):
38 | return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))
39 |
40 |
41 | def lanczos(x, a):
42 | cond = torch.logical_and(-a < x, x < a)
43 | out = torch.where(cond, sinc(x) * sinc(x/a), x.new_zeros([]))
44 | return out / out.sum()
45 |
46 |
47 | def ramp(ratio, width):
48 | n = math.ceil(width / ratio + 1)
49 | out = torch.empty([n])
50 | cur = 0
51 | for i in range(out.shape[0]):
52 | out[i] = cur
53 | cur += ratio
54 | return torch.cat([-out[1:].flip([0]), out])[1:-1]
55 |
56 |
57 | def resample(input, size, align_corners=True):
58 | n, c, h, w = input.shape
59 | dh, dw = size
60 |
61 | input = input.view([n * c, 1, h, w])
62 |
63 | if dh < h:
64 | kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)
65 | pad_h = (kernel_h.shape[0] - 1) // 2
66 | input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')
67 | input = F.conv2d(input, kernel_h[None, None, :, None])
68 |
69 | if dw < w:
70 | kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)
71 | pad_w = (kernel_w.shape[0] - 1) // 2
72 | input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')
73 | input = F.conv2d(input, kernel_w[None, None, None, :])
74 |
75 | input = input.view([n, c, h, w])
76 | return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)
77 |
78 | class ReplaceGrad(torch.autograd.Function):
79 | @staticmethod
80 | def forward(ctx, x_forward, x_backward):
81 | ctx.shape = x_backward.shape
82 | return x_forward
83 |
84 | @staticmethod
85 | def backward(ctx, grad_in):
86 | return None, grad_in.sum_to_size(ctx.shape)
87 |
88 | replace_grad = ReplaceGrad.apply
89 |
90 | class ClampWithGrad(torch.autograd.Function):
91 | @staticmethod
92 | def forward(ctx, input, min, max):
93 | ctx.min = min
94 | ctx.max = max
95 | ctx.save_for_backward(input)
96 | return input.clamp(min, max)
97 |
98 | @staticmethod
99 | def backward(ctx, grad_in):
100 | input, = ctx.saved_tensors
101 | return grad_in * (grad_in * (input - input.clamp(ctx.min, ctx.max)) >= 0), None, None
102 |
103 |
104 | def vector_quantize(x, codebook):
105 | d = x.pow(2).sum(dim=-1, keepdim=True) + codebook.pow(2).sum(dim=1) - 2 * x @ codebook.T
106 | indices = d.argmin(-1)
107 | x_q = F.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook
108 | return replace_grad(x_q, x)
109 |
110 | clamp_with_grad = ClampWithGrad.apply
111 |
112 | class Prompt(nn.Module):
113 | def __init__(self, embed, weight=1., stop=float('-inf')):
114 | super().__init__()
115 | self.register_buffer('embed', embed)
116 | self.register_buffer('weight', torch.as_tensor(weight))
117 | self.register_buffer('stop', torch.as_tensor(stop))
118 |
119 | def forward(self, input):
120 | input_normed = F.normalize(input.unsqueeze(1), dim=2)
121 | embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)
122 | dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
123 | dists = dists * self.weight.sign()
124 | return self.weight.abs() * replace_grad(dists, torch.maximum(dists, self.stop)).mean()
125 |
126 | def fetch(url_or_path):
127 | if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
128 | r = requests.get(url_or_path)
129 | r.raise_for_status()
130 | fd = io.BytesIO()
131 | fd.write(r.content)
132 | fd.seek(0)
133 | return fd
134 | return open(url_or_path, 'rb')
135 |
136 | def parse_prompt(prompt):
137 | if prompt.startswith('http://') or prompt.startswith('https://'):
138 | vals = prompt.rsplit(':', 3)
139 | vals = [vals[0] + ':' + vals[1], *vals[2:]]
140 | else:
141 | vals = prompt.rsplit(':', 2)
142 | vals = vals + ['', '1', '-inf'][len(vals):]
143 | return vals[0], float(vals[1]), float(vals[2])
144 |
145 | class MakeCutouts(nn.Module):
146 | def __init__(self, cut_size, cutn, cut_pow=1.):
147 | super().__init__()
148 | self.cut_size = cut_size
149 | self.cutn = cutn
150 | self.cut_pow = cut_pow
151 |
152 | def forward(self, input):
153 | sideY, sideX = input.shape[2:4]
154 | max_size = min(sideX, sideY)
155 | min_size = min(sideX, sideY, self.cut_size)
156 | cutouts = []
157 | for _ in range(self.cutn):
158 | size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
159 | offsetx = torch.randint(0, sideX - size + 1, ())
160 | offsety = torch.randint(0, sideY - size + 1, ())
161 | cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
162 | cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
163 | return clamp_with_grad(torch.cat(cutouts, dim=0), 0, 1)
164 |
165 |
166 | def resize_image(image, out_size):
167 | ratio = image.size[0] / image.size[1]
168 | area = min(image.size[0] * image.size[1], out_size[0] * out_size[1])
169 | size = round((area * ratio)**0.5), round((area / ratio)**0.5)
170 | return image.resize(size, Image.LANCZOS)
171 |
172 |
173 | def save_img(a, dir):
174 | PIL.Image.fromarray(np.uint8(np.clip(a, 0, 255))).save(dir)
175 |
176 |
177 | def load_vqgan_model(config_path, checkpoint_path):
178 | config = OmegaConf.load(config_path)
179 | if config.model.target == 'taming.models.vqgan.VQModel':
180 | model = vqgan.VQModel(**config.model.params)
181 | model.eval().requires_grad_(False)
182 | model.init_from_ckpt(checkpoint_path)
183 | elif config.model.target == 'taming.models.cond_transformer.Net2NetTransformer':
184 | parent_model = cond_transformer.Net2NetTransformer(**config.model.params)
185 | parent_model.eval().requires_grad_(False)
186 | parent_model.init_from_ckpt(checkpoint_path)
187 | model = parent_model.first_stage_model
188 | else:
189 | raise ValueError(f'unknown model type: {config.model.target}')
190 | del model.loss
191 | return model
192 |
193 | class Network(torch.nn.Module):
194 | def __init__(self):
195 | super().__init__()
196 |
197 | class Preprocess(torch.nn.Module):
198 | def __init__(self):
199 | super().__init__()
200 |
201 | def forward(self, tenInput):
202 | tenBlue = (tenInput[:, 0:1, :, :] - 0.406) / 0.225
203 | tenGreen = (tenInput[:, 1:2, :, :] - 0.456) / 0.224
204 | tenRed = (tenInput[:, 2:3, :, :] - 0.485) / 0.229
205 |
206 | return torch.cat([ tenRed, tenGreen, tenBlue ], 1)
207 |
208 | class Basic(torch.nn.Module):
209 | def __init__(self, intLevel):
210 | super().__init__()
211 |
212 | self.netBasic = torch.nn.Sequential(
213 | torch.nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3),
214 | torch.nn.ReLU(inplace=False),
215 | torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3),
216 | torch.nn.ReLU(inplace=False),
217 | torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3),
218 | torch.nn.ReLU(inplace=False),
219 | torch.nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3),
220 | torch.nn.ReLU(inplace=False),
221 | torch.nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3)
222 | )
223 |
224 |
225 | def forward(self, tenInput):
226 | return self.netBasic(tenInput)
227 |
228 | self.netPreprocess = Preprocess()
229 | self.netBasic = torch.nn.ModuleList([ Basic(intLevel) for intLevel in range(6) ])
230 | self.load_state_dict({ strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.hub.load_state_dict_from_url(url='http://content.sniklaus.com/github/pytorch-spynet/network-' + arguments_strModel + '.pytorch', file_name='spynet-' + arguments_strModel).items() })
231 |
232 | def forward(self, tenOne, tenTwo):
233 | tenFlow = []
234 |
235 | tenOne = [ self.netPreprocess(tenOne) ]
236 | tenTwo = [ self.netPreprocess(tenTwo) ]
237 |
238 | for intLevel in range(5):
239 | if tenOne[0].shape[2] > 32 or tenOne[0].shape[3] > 32:
240 | tenOne.insert(0, torch.nn.functional.avg_pool2d(input=tenOne[0], kernel_size=2, stride=2, count_include_pad=False))
241 | tenTwo.insert(0, torch.nn.functional.avg_pool2d(input=tenTwo[0], kernel_size=2, stride=2, count_include_pad=False))
242 |
243 | tenFlow = tenOne[0].new_zeros([ tenOne[0].shape[0], 2, int(math.floor(tenOne[0].shape[2] / 2.0)), int(math.floor(tenOne[0].shape[3] / 2.0)) ])
244 |
245 | for intLevel in range(len(tenOne)):
246 | tenUpsampled = torch.nn.functional.interpolate(input=tenFlow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0
247 |
248 | if tenUpsampled.shape[2] != tenOne[intLevel].shape[2]: tenUpsampled = torch.nn.functional.pad(input=tenUpsampled, pad=[ 0, 0, 0, 1 ], mode='replicate')
249 | if tenUpsampled.shape[3] != tenOne[intLevel].shape[3]: tenUpsampled = torch.nn.functional.pad(input=tenUpsampled, pad=[ 0, 1, 0, 0 ], mode='replicate')
250 |
251 | tenFlow = self.netBasic[intLevel](torch.cat([ tenOne[intLevel], backwarp(tenInput=tenTwo[intLevel], tenFlow=tenUpsampled), tenUpsampled ], 1)) + tenUpsampled
252 |
253 |
254 | return tenFlow
255 |
256 | torch.backends.cudnn.enabled = True
257 | arguments_strModel = 'sintel-final' # 'sintel-final', or 'sintel-clean', or 'chairs-final', or 'chairs-clean', or 'kitti-final'
258 | backwarp_tenGrid = {}
259 |
260 | def backwarp(tenInput, tenFlow):
261 | if str(tenFlow.shape) not in backwarp_tenGrid:
262 | tenHor = torch.linspace(-1.0 + (1.0 / tenFlow.shape[3]), 1.0 - (1.0 / tenFlow.shape[3]), tenFlow.shape[3]).view(1, 1, 1, -1).expand(-1, -1, tenFlow.shape[2], -1)
263 | tenVer = torch.linspace(-1.0 + (1.0 / tenFlow.shape[2]), 1.0 - (1.0 / tenFlow.shape[2]), tenFlow.shape[2]).view(1, 1, -1, 1).expand(-1, -1, -1, tenFlow.shape[3])
264 |
265 | backwarp_tenGrid[str(tenFlow.shape)] = torch.cat([ tenHor, tenVer ], 1).cuda()
266 |
267 | tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1)
268 |
269 | return torch.nn.functional.grid_sample(input=tenInput, grid=(backwarp_tenGrid[str(tenFlow.shape)] + tenFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='border', align_corners=False)
270 |
271 | netNetwork = None
272 | def estimate(tenOne, tenTwo):
273 | global netNetwork
274 |
275 | if netNetwork is None:
276 | netNetwork = Network().cuda().eval()
277 |
278 | assert(tenOne.shape[1] == tenTwo.shape[1])
279 | assert(tenOne.shape[2] == tenTwo.shape[2])
280 |
281 | intWidth = tenOne.shape[2]
282 | intHeight = tenOne.shape[1]
283 |
284 | tenPreprocessedOne = tenOne.cuda().view(1, 3, intHeight, intWidth)
285 | tenPreprocessedTwo = tenTwo.cuda().view(1, 3, intHeight, intWidth)
286 |
287 | intPreprocessedWidth = int(math.floor(math.ceil(intWidth / 32.0) * 32.0))
288 | intPreprocessedHeight = int(math.floor(math.ceil(intHeight / 32.0) * 32.0))
289 |
290 | tenPreprocessedOne = torch.nn.functional.interpolate(input=tenPreprocessedOne, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False)
291 | tenPreprocessedTwo = torch.nn.functional.interpolate(input=tenPreprocessedTwo, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False)
292 |
293 | tenFlow = torch.nn.functional.interpolate(input=netNetwork(tenPreprocessedOne, tenPreprocessedTwo), size=(intHeight, intWidth), mode='bilinear', align_corners=False)
294 |
295 | tenFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth)
296 | tenFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight)
297 |
298 | return tenFlow[0, :, :, :].cpu()
299 |
300 | def calc_opflow(img1, img2):
301 | img1 = PIL.Image.fromarray(img1)
302 | img2 = PIL.Image.fromarray(img2)
303 |
304 | tenFirst = torch.FloatTensor(
305 | np.ascontiguousarray(
306 | np.array(img1)[:, :, ::-1].transpose(2, 0, 1).astype(np.float32)
307 | * (1.0 / 255.0)
308 | )
309 | )
310 | tenSecond = torch.FloatTensor(
311 | np.ascontiguousarray(
312 | np.array(img2)[:, :, ::-1].transpose(2, 0, 1).astype(np.float32)
313 | * (1.0 / 255.0)
314 | )
315 | )
316 |
317 | tenOutput = estimate(tenFirst, tenSecond)
318 | return tenOutput
319 |
320 | def get_opflow_image(np_prev_img, frame, np_img, blendflow, blendstatic, threshold=6, do_blur=True, blur_value=(5, 5)):
321 | np_prev_img = np.float32(np_prev_img)
322 | frame = np.float32(frame)
323 | np_img = np.float32(np_img)
324 |
325 | h, w, _ = np_prev_img.shape
326 |
327 | flow = calc_opflow(np.uint8(np_prev_img), np.uint8(np_img))
328 | flow = np.transpose(np.float32(flow), (1, 2, 0))
329 | inv_flow = flow
330 | flow = -flow
331 | flow[:, :, 0] += np.arange(w)
332 | flow[:, :, 1] += np.arange(h)[:, np.newaxis]
333 |
334 | framediff = (np_img*(1-blendflow) + frame*blendflow) - np_prev_img
335 | framediff = cv2.remap(framediff, flow, None, cv2.INTER_LINEAR)
336 | if do_blur:
337 | framediff = cv2.GaussianBlur(framediff, blur_value, 0)
338 |
339 | frame_flow = np_img + framediff
340 |
341 | magnitude, _ = cv2.cartToPolar(inv_flow[...,0], inv_flow[...,1])
342 | norm_mag = cv2.normalize(magnitude, None, 0, 255, cv2.NORM_MINMAX)
343 | _, mask = cv2.threshold(norm_mag, threshold, 255, cv2.THRESH_BINARY)
344 | flow_mask = mask.astype(np.uint8).reshape((h, w, 1))
345 | frame_flow_masked = cv2.bitwise_and(frame_flow, frame_flow, mask=flow_mask)
346 |
347 | background_blendimg = cv2.addWeighted(np_img, (1-blendstatic), frame, blendstatic, 0)
348 | background_masked = cv2.bitwise_and(background_blendimg, background_blendimg, mask=cv2.bitwise_not(flow_mask))
349 |
350 | return frame_flow_masked, background_masked
351 |
--------------------------------------------------------------------------------
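The blendflow / blendstatic sliders in the notebook feed get_opflow_image above: SPyNet flow is estimated between two consecutive input frames, the previous dreamed frame is warped along that flow and blended into the moving regions, while static regions get a plain alpha blend; the notebook then sums the two masked images to form the init image for the next deepdream call. A minimal sketch of that step for one frame pair, assuming the Colab environment, a CUDA device, and placeholder file names:

    import cv2
    import numpy as np
    import PIL.Image

    from utils import get_opflow_image, save_img

    # Two consecutive input frames and the dreamed version of the first one (placeholder paths).
    np_prev_img  = np.float32(PIL.Image.open("input/0001.jpg"))
    np_img       = np.float32(PIL.Image.open("input/0002.jpg"))
    dreamed_prev = np.float32(PIL.Image.open("output/0001.jpg"))

    # All three arrays must share the same H x W; the notebook resizes everything to the reduce_res size.
    h, w, _ = np_img.shape
    np_prev_img  = cv2.resize(np_prev_img,  dsize=(w, h), interpolation=cv2.INTER_CUBIC)
    dreamed_prev = cv2.resize(dreamed_prev, dsize=(w, h), interpolation=cv2.INTER_CUBIC)

    # Moving regions keep the warped previous dream; static regions get a plain alpha blend.
    frame_flow_masked, background_masked = get_opflow_image(
        np_prev_img, dreamed_prev, np_img, blendflow=0.6, blendstatic=0.6)

    init_for_next_frame = frame_flow_masked + background_masked
    save_img(init_for_next_frame, "blended_0002.jpg")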