├── README.md
└── StyleGAN3+CLIP_v2.ipynb
/README.md:
--------------------------------------------------------------------------------
1 | # StyleGAN3-CLIP-ColabNB
2 | Google Colab notebook that combines NVIDIA's StyleGAN3 and OpenAI's CLIP for text-guided image generation.
3 |
10 | ---
11 |
12 | This notebook builds on work by Katherine Crowson ([Twitter](https://twitter.com/RiversHaveWings), [GitHub](https://github.com/crowsonkb)) and nshepperd ([Twitter](https://twitter.com/nshepperd1), [GitHub](https://github.com/nshepperd)).
13 |
14 | StyleGAN3 was created by NVIDIA. [Visit the repository here](https://github.com/NVlabs/stylegan3).
15 |
16 | CLIP (Contrastive Language-Image Pre-Training) is a model made by OpenAI. For more info, [visit the CLIP repository](https://github.com/openai/CLIP).
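
CLIP is what steers the generation: it embeds a text prompt and candidate images into the same space and scores how well they match. A minimal sketch of that idea (not taken from the notebook), assuming `torch` and the `clip` package are installed and using a hypothetical `photo.jpg`:

```python
import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# Hypothetical local image; replace with any photo you have on disk
image = preprocess(Image.open("photo.jpg")).unsqueeze(0).to(device)
texts = clip.tokenize(["a portrait photo of a man", "a mountain landscape"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(texts)

# Cosine similarity between the image and each prompt (higher = better match)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
print((image_features @ text_features.T).squeeze(0))
```

The notebook applies the same scoring in reverse: StyleGAN3's latent code is optimized so that the generated image's CLIP embedding moves closer to the prompt's embedding.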
17 |
--------------------------------------------------------------------------------
/StyleGAN3+CLIP_v2.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "StyleGAN3+CLIP_v2.ipynb",
7 | "private_outputs": true,
8 | "provenance": [],
9 | "collapsed_sections": []
10 | },
11 | "kernelspec": {
12 | "name": "python3",
13 | "display_name": "Python 3"
14 | },
15 | "language_info": {
16 | "name": "python"
17 | },
18 | "accelerator": "GPU"
19 | },
20 | "cells": [
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {
24 | "id": "bJOj_BWi_JwR"
25 | },
26 | "source": [
27 | "# **StyleGAN3 + CLIP 🖼️**\n",
28 | "\n",
29 | "## Generate images (mostly faces) from text prompts using NVIDIA's StyleGAN3 with CLIP guidance.\n",
30 | "\n",
31 | "Code written by [nshepperd](https://twitter.com/nshepperd1) (https://github.com/nshepperd).\n",
32 | "\n",
33 | "Modified by [justinjohn0306](https://github.com/justinjohn0306)\n",
34 | "\n",
35 | "Thanks to [Katherine Crowson](https://twitter.com/RiversHaveWings) (https://github.com/crowsonkb) for coming up with many improved sampling tricks, as well as some of the code.\n",
36 | "\n",
37 | "\n",
38 | "**Visit StyleGAN3** [here](https://github.com/NVlabs/stylegan3)."
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "metadata": {
44 | "id": "so1yHofG7RxX"
45 | },
46 | "source": [
47 | "#@title Licensed under the MIT License { display-mode: \"form\" }\n",
48 | "\n",
49 | "# Copyright (c) 2021 nshepperd; Katherine Crowson\n",
50 | "\n",
51 | "# Permission is hereby granted, free of charge, to any person obtaining a copy\n",
52 | "# of this software and associated documentation files (the \"Software\"), to deal\n",
53 | "# in the Software without restriction, including without limitation the rights\n",
54 | "# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n",
55 | "# copies of the Software, and to permit persons to whom the Software is\n",
56 | "# furnished to do so, subject to the following conditions:\n",
57 | "\n",
58 | "# The above copyright notice and this permission notice shall be included in\n",
59 | "# all copies or substantial portions of the Software.\n",
60 | "\n",
61 | "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n",
62 | "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n",
63 | "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n",
64 | "# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n",
65 | "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n",
66 | "# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n",
67 | "# THE SOFTWARE."
68 | ],
69 | "execution_count": null,
70 | "outputs": []
71 | },
72 | {
73 | "cell_type": "code",
74 | "metadata": {
75 | "id": "rg_x6IdHv_2G",
76 | "cellView": "form"
77 | },
78 | "source": [
79 | "#@markdown #**Check GPU** 🕵️\n",
80 | "\n",
81 | "#@markdown ---\n",
82 | "\n",
83 | "!nvidia-smi -L\n",
84 | "!nvcc --version"
85 | ],
86 | "execution_count": null,
87 | "outputs": []
88 | },
89 | {
90 | "cell_type": "code",
91 | "metadata": {
92 | "id": "5K38uyFrv5wo",
93 | "cellView": "form"
94 | },
95 | "source": [
96 | "#@markdown #**Install libraries** 🏗️\n",
97 | "# @markdown This cell will take a little while because it has to download several libraries.\n",
98 | "\n",
99 | "#@markdown ---\n",
100 | "\n",
101 | "!pip install --upgrade torch==1.9.1+cu111 torchvision==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html\n",
102 | "#!pip install --upgrade https://download.pytorch.org/whl/nightly/cu111/torch-1.11.0.dev20211012%2Bcu111-cp37-cp37m-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu111/torchvision-0.12.0.dev20211012%2Bcu111-cp37-cp37m-linux_x86_64.whl\n",
103 | "!git clone https://github.com/NVlabs/stylegan3\n",
104 | "!git clone https://github.com/openai/CLIP\n",
105 | "!pip install -e ./CLIP\n",
106 | "!pip install einops ninja\n",
107 | "\n",
108 | "import sys\n",
109 | "sys.path.append('./CLIP')\n",
110 | "sys.path.append('./stylegan3')\n",
111 | "\n",
112 | "import io\n",
113 | "import os, time\n",
114 | "import pickle\n",
115 | "import shutil\n",
116 | "import numpy as np\n",
117 | "from PIL import Image\n",
118 | "import torch\n",
119 | "import torch.nn.functional as F\n",
120 | "import requests\n",
121 | "import torchvision.transforms as transforms\n",
122 | "import torchvision.transforms.functional as TF\n",
123 | "import clip\n",
124 | "from tqdm.notebook import tqdm\n",
125 | "from torchvision.transforms import Compose, Resize, ToTensor, Normalize\n",
126 | "from IPython.display import display\n",
127 | "from einops import rearrange\n",
128 | "from google.colab import files\n",
129 | "\n",
130 | "device = torch.device('cuda:0')\n",
131 | "print('Using device:', device, file=sys.stderr)"
132 | ],
133 | "execution_count": null,
134 | "outputs": []
135 | },
136 | {
137 | "cell_type": "code",
138 | "metadata": {
139 | "id": "zVkNODOot_To",
140 | "cellView": "form"
141 | },
142 | "source": [
143 | "#@markdown #**Optional:** Save images in Google Drive 💾\n",
144 | "# @markdown Run this cell if you want to store the results inside Google Drive.\n",
145 | "\n",
146 | "# @markdown Copying the generated images to Drive is usually faster than downloading them directly from Colab.\n",
147 | "\n",
148 | "# @markdown **Important**: you must have a folder named *samples* inside your drive, otherwise this may not work.\n",
149 | "\n",
150 | "#@markdown ---\n",
151 | "\n",
152 | "# Mount Google Drive; copying generated images there is faster than downloading directly from Colab in my experience.\n",
153 | "from google.colab import drive\n",
154 | "drive.mount('/content/drive', force_remount=True, use_metadata_server=False)"
155 | ],
156 | "execution_count": null,
157 | "outputs": []
158 | },
159 | {
160 | "cell_type": "code",
161 | "metadata": {
162 | "id": "1GOWJ_z-wgde",
163 | "cellView": "form"
164 | },
165 | "source": [
166 | "#@markdown #**Define necessary functions** 🛠️\n",
167 | "\n",
168 | "def fetch(url_or_path):\n",
169 | " if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):\n",
170 | " r = requests.get(url_or_path)\n",
171 | " r.raise_for_status()\n",
172 | " fd = io.BytesIO()\n",
173 | " fd.write(r.content)\n",
174 | " fd.seek(0)\n",
175 | " return fd\n",
176 | " return open(url_or_path, 'rb')\n",
177 | "\n",
178 | "def fetch_model(url_or_path):\n",
179 | " basename = os.path.basename(url_or_path)\n",
180 | " if os.path.exists(basename):\n",
181 | " return basename\n",
182 | " else:\n",
183 | " !wget -c '{url_or_path}'\n",
184 | " return basename\n",
185 | "\n",
186 | "def norm1(prompt):\n",
187 | " \"Normalize to the unit sphere.\"\n",
188 | " return prompt / prompt.square().sum(dim=-1,keepdim=True).sqrt()\n",
189 | "\n",
190 | "def spherical_dist_loss(x, y):\n",
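"    \"Squared spherical distance between x and y after normalizing them to the unit sphere; used as the CLIP guidance loss.\"\n",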
191 | " x = F.normalize(x, dim=-1)\n",
192 | " y = F.normalize(y, dim=-1)\n",
193 | " return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)\n",
194 | "\n",
195 | "class MakeCutouts(torch.nn.Module):\n",
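"    \"Take cutn random square crops of the input (sizes biased by cut_pow) and resize each to cut_size for CLIP.\"\n",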
196 | " def __init__(self, cut_size, cutn, cut_pow=1.):\n",
197 | " super().__init__()\n",
198 | " self.cut_size = cut_size\n",
199 | " self.cutn = cutn\n",
200 | " self.cut_pow = cut_pow\n",
201 | "\n",
202 | " def forward(self, input):\n",
203 | " sideY, sideX = input.shape[2:4]\n",
204 | " max_size = min(sideX, sideY)\n",
205 | " min_size = min(sideX, sideY, self.cut_size)\n",
206 | " cutouts = []\n",
207 | " for _ in range(self.cutn):\n",
208 | " size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)\n",
209 | " offsetx = torch.randint(0, sideX - size + 1, ())\n",
210 | " offsety = torch.randint(0, sideY - size + 1, ())\n",
211 | " cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]\n",
212 | " cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))\n",
213 | " return torch.cat(cutouts)\n",
214 | "\n",
215 | "make_cutouts = MakeCutouts(224, 32, 0.5)\n",
216 | "\n",
217 | "def embed_image(image):\n",
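"    \"Embed cutn random cutouts of each image with CLIP; returns a (cutn, batch, dim) tensor.\"\n",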
218 | " n = image.shape[0]\n",
219 | " cutouts = make_cutouts(image)\n",
220 | " embeds = clip_model.embed_cutout(cutouts)\n",
221 | " embeds = rearrange(embeds, '(cc n) c -> cc n c', n=n)\n",
222 | " return embeds\n",
223 | "\n",
224 | "def embed_url(url):\n",
225 | " image = Image.open(fetch(url)).convert('RGB')\n",
226 | " return embed_image(TF.to_tensor(image).to(device).unsqueeze(0)).mean(0).squeeze(0)\n",
227 | "\n",
228 | "class CLIP(object):\n",
229 | " def __init__(self):\n",
230 | " clip_model = \"ViT-B/32\"\n",
231 | " self.model, _ = clip.load(clip_model)\n",
232 | " self.model = self.model.requires_grad_(False)\n",
233 | " self.normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],\n",
234 | " std=[0.26862954, 0.26130258, 0.27577711])\n",
235 | "\n",
236 | " @torch.no_grad()\n",
237 | " def embed_text(self, prompt):\n",
238 | " \"Normalized clip text embedding.\"\n",
239 | " return norm1(self.model.encode_text(clip.tokenize(prompt).to(device)).float())\n",
240 | "\n",
241 | " def embed_cutout(self, image):\n",
242 | " \"Normalized clip image embedding.\"\n",
243 | " return norm1(self.model.encode_image(self.normalize(image)))\n",
244 | " \n",
245 | "clip_model = CLIP()"
246 | ],
247 | "execution_count": null,
248 | "outputs": []
249 | },
250 | {
251 | "cell_type": "markdown",
252 | "metadata": {
253 | "id": "6VXU9WBvHbAX"
254 | },
255 | "source": [
256 | "Available models:\n",
257 | "* **metfaces** for paintings\n",
258 | "* **afhqv2** for animals\n",
259 | "* **ffhq** for photo faces\n",
260 | "* **Cosplay Faces** trained by [@l4rz](https://twitter.com/l4rz)\n",
261 | "* **Wikiart** trained by [Justin Pinkney](https://www.justinpinkney.com/) on the Wikiart 1024 dataset\n",
262 | "* **Landscapes** trained by [Justin Pinkney](https://www.justinpinkney.com/) on the LHQ dataset\n",
263 | "\n",
264 | "\n",
265 | "Modes:\n",
266 | "* **t** for translation\n",
267 | "* **r** for rotation\n",
268 | "\n",
269 | "e.g.: *stylegan3-t-metfaces-1024x1024.pkl* for paintings with translation at 1024x1024 resolution"
270 | ]
271 | },
272 | {
273 | "cell_type": "code",
274 | "metadata": {
275 | "id": "4vpIq2vvzTtS",
276 | "cellView": "form"
277 | },
278 | "source": [
279 | "#@markdown #**Model selection** 🎭\n",
280 | "\n",
281 | "#@markdown By default, the notebook downloads the FFHQ model.\n",
282 | "#@markdown **Run this cell again if you change the model**.\n",
283 | "\n",
284 | "#@markdown ---\n",
285 | "\n",
286 | "import numpy as np\n",
287 | "\n",
288 | "base_url = \"https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/\"\n",
289 | "\n",
290 | "#@markdown Choose a model:\n",
291 | "model_name = \"Landscapes\" #@param [\"Landscapes\", \"Wikiart\", \"stylegan2-cosplay-faces-512x512-px\", \"stylegan3-r-afhqv2-512x512.pkl\", \"stylegan3-r-ffhq-1024x1024.pkl\", \"stylegan3-r-ffhqu-1024x1024.pkl\",\"stylegan3-r-ffhqu-256x256.pkl\",\"stylegan3-r-metfaces-1024x1024.pkl\",\"stylegan3-r-metfacesu-1024x1024.pkl\",\"stylegan3-t-afhqv2-512x512.pkl\",\"stylegan3-t-ffhq-1024x1024.pkl\",\"stylegan3-t-ffhqu-1024x1024.pkl\",\"stylegan3-t-ffhqu-256x256.pkl\",\"stylegan3-t-metfaces-1024x1024.pkl\",\"stylegan3-t-metfacesu-1024x1024.pkl\"]\n",
292 | "network_url = base_url + model_name\n",
293 | "\n",
294 | "if model_name == \"stylegan2-cosplay-faces-512x512-px\":\n",
295 | " network_url = 'https://l4rz.net/cosplayface-snapshot-004000-18160-FID367.pkl'\n",
296 | "if model_name == \"Wikiart\":\n",
297 | " network_url = 'https://archive.org/download/wikiart-1024-stylegan3-t-17.2Mimg/wikiart-1024-stylegan3-t-17.2Mimg.pkl'\n",
298 | "if model_name == \"Landscapes\":\n",
299 | " network_url = 'https://archive.org/download/lhq-256-stylegan3-t-25Mimg/lhq-256-stylegan3-t-25Mimg.pkl' \n",
300 | "\n",
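"# Download the selected model if needed and load the EMA generator (G_ema) from the pickle\n",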
301 | "with open(fetch_model(network_url), 'rb') as fp:\n",
302 | " G = pickle.load(fp)['G_ema'].to(device)\n",
303 | "\n",
304 | "clip_model = CLIP()\n"
305 | ],
306 | "execution_count": null,
307 | "outputs": []
308 | },
309 | {
310 | "cell_type": "code",
311 | "metadata": {
312 | "id": "V_rq-N2m0Tlb",
313 | "cellView": "form"
314 | },
315 | "source": [
316 | "#@markdown #**Parameters** ✍️\n",
317 | "#@markdown ---\n",
318 | "\n",
319 | "#@markdown **Enter your text prompt here:**\n",
320 | "text = 'Eminem' #@param {type:\"string\"}\n",
321 | "target = clip_model.embed_text(text) \n",
322 | "\n",
323 | "#@markdown How many steps should it be doing?\n",
324 | "steps = 1500#@param \n",
325 | "\n",
326 | "target_images = \"em.jpg\"#@param {type:\"string\"}\n",
327 | "\n",
328 | "#@markdown Choose random seed (-1 for completely random)\n",
329 | "seed = -1#@param \n",
330 | "if seed == -1:\n",
331 | " seed = np.random.randint(0,2**32 - 1)\n",
332 | "\n",
333 | "#@markdown How often do you want to see the results?\n",
334 | "show_every_n_steps = 20 #@param\n",
335 | "#@markdown ___\n",
336 | "\n",
337 | "#@markdown **EXPERIMENTAL/W.I.P**: Do you want to mix 2 models? (try FFHQ with MetFaces) \n",
338 | "mix = \"No\" #@param [\"Yes\", \"No\"]\n",
339 | "model_1_name = \"stylegan3-r-metfaces-1024x1024.pkl\" #@param [\"Landscapes\", \"Wikiart\",\"stylegan3-r-ffhq-1024x1024.pkl\", \"stylegan3-r-ffhqu-1024x1024.pkl\",\"stylegan3-r-metfaces-1024x1024.pkl\",\"stylegan3-r-metfacesu-1024x1024.pkl\",\"stylegan3-t-ffhq-1024x1024.pkl\",\"stylegan3-t-ffhqu-1024x1024.pkl\",\"stylegan3-t-metfaces-1024x1024.pkl\",\"stylegan3-t-metfacesu-1024x1024.pkl\"] \n",
340 | "model_2_name = \"stylegan3-r-ffhq-1024x1024.pkl\" #@param [\"Landscapes\", \"Wikiart\", \"stylegan3-r-ffhq-1024x1024.pkl\", \"stylegan3-r-ffhqu-1024x1024.pkl\",\"stylegan3-r-metfaces-1024x1024.pkl\",\"stylegan3-r-metfacesu-1024x1024.pkl\",\"stylegan3-t-ffhq-1024x1024.pkl\",\"stylegan3-t-ffhqu-1024x1024.pkl\",\"stylegan3-t-metfaces-1024x1024.pkl\",\"stylegan3-t-metfacesu-1024x1024.pkl\"] \n",
341 | "proportion = 0.5 #@param {type:\"slider\", min:0, max:1, step:0.05}\n",
342 | "\n",
343 | "if mix == 'Yes':\n",
344 | " with open(fetch_model(base_url + model_1_name), 'rb') as fp:\n",
345 | " G1 = pickle.load(fp)['G_ema'].to(device)\n",
346 | "\n",
348 | " with open(fetch_model(base_url + model_2_name), 'rb') as fp:\n",
349 | " G2 = pickle.load(fp)['G_ema'].to(device)\n",
350 | "\n",
351 | " G = G2\n",
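"  # Blend the two generators by linearly interpolating their parameters\n",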
352 | " for p_out, p_in1, p_in2 in zip(G.parameters(), G1.parameters(), G2.parameters()):\n",
353 | "    p_out.data = p_in1.data * proportion + p_in2.data * (1 - proportion)\n",
354 | "\n",
355 | "zs = torch.randn([10000, G.mapping.z_dim], device=device)\n",
356 | "w_stds = G.mapping(zs, None).std(0)\n",
357 | "# w_stds is used to express the optimized latent in units of W-space standard deviations\n",
358 | "\n",
359 | "#@markdown Do you want to fix the coordinate grid? (False - wobbly effect)\n",
360 | "fix_coordinates = \"True\" #@param [\"True\", \"False\"]\n",
361 | "\n",
362 | "if fix_coordinates == \"True\":\n",
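"  # Bake the average latent's learned input transform into the synthesis input layer so the coordinate grid stays fixed\n",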
363 | " if model_name != \"stylegan2-cosplay-faces-512x512-px\":\n",
364 | " shift = G.synthesis.input.affine(G.mapping.w_avg.unsqueeze(0))\n",
365 | " G.synthesis.input.affine.bias.data.add_(shift.squeeze(0))\n",
366 | " G.synthesis.input.affine.weight.data.zero_()\n",
367 | "\n",
368 | "# target = embed_url(\"https://4.bp.blogspot.com/-uw859dFGsLc/Va5gt-bU9bI/AAAAAAAA4gM/dcaWzX0ZxdI/s1600/Lubjana+dragon+1.jpg\")\n",
369 | "# target = embed_url(\"https://irc.zlkj.in/uploads/e399d2fee2c6edd9/20210827165231_0_nexus%20of%20abandoned%20places.%20trending%20on%20ArtStation.png\")"
370 | ],
371 | "execution_count": null,
372 | "outputs": []
373 | },
374 | {
375 | "cell_type": "code",
376 | "metadata": {
377 | "id": "QXoRP4SHzJ6i",
378 | "cellView": "form"
379 | },
380 | "source": [
381 | "#@markdown #**Run the model** 🚀\n",
382 | "\n",
383 | "from IPython.display import display\n",
384 | "\n",
385 | "tf = Compose([\n",
386 | " Resize(224),\n",
387 | " lambda x: torch.clamp((x+1)/2,min=0,max=1),\n",
388 | " ])\n",
389 | "\n",
390 | "def run(seed, G):\n",
391 | " torch.manual_seed(seed)\n",
392 | " timestring = time.strftime('%Y%m%d%H%M%S')\n",
393 | " rand_z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.mapping.z_dim)).to(device)\n",
394 | " q = (G.mapping(rand_z, None, truncation_psi=0.2)) / w_stds\n",
395 | " q.requires_grad_()\n",
396 | " \n",
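"  # Initialization: sample a few batches of latents and keep the one whose image is already closest to the prompt\n",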
397 | " with torch.no_grad():\n",
398 | " qs = []\n",
399 | " losses = []\n",
400 | " for _ in range(8):\n",
401 | " q = (G.mapping(torch.randn([4,G.mapping.z_dim], device=device), None, truncation_psi=0.7) - G.mapping.w_avg) / w_stds\n",
402 | " images = G.synthesis(q * w_stds + G.mapping.w_avg)\n",
403 | " embeds = embed_image(images.add(1).div(2))\n",
404 | " loss = spherical_dist_loss(embeds, target).mean(0)\n",
405 | " i = torch.argmin(loss)\n",
406 | " qs.append(q[i])\n",
407 | " losses.append(loss[i])\n",
408 | " qs = torch.stack(qs)\n",
409 | " losses = torch.stack(losses)\n",
410 | " i = torch.argmin(losses)\n",
411 | " q = qs[i].unsqueeze(0).requires_grad_()\n",
412 | "\n",
413 | "  # Optimization loop: refine q with AdamW to minimize the CLIP loss to the prompt\n",
414 | " q_ema = q\n",
415 | " opt = torch.optim.AdamW([q], lr=0.03, betas=(0.0,0.999))\n",
416 | " loop = tqdm(range(steps))\n",
417 | " for i in loop:\n",
418 | " opt.zero_grad()\n",
419 | " w = q * w_stds\n",
420 | " image = G.synthesis(w + G.mapping.w_avg, noise_mode='const')\n",
421 | " embed = embed_image(image.add(1).div(2))\n",
422 | " loss = spherical_dist_loss(embed, target).mean()\n",
423 | " loss.backward()\n",
424 | " opt.step()\n",
425 | " loop.set_postfix(loss=loss.item(), q_magnitude=q.std().item())\n",
426 | "\n",
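"    # Keep an exponential moving average of q; previews and saved frames use the smoothed latent\n",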
427 | " q_ema = q_ema * 0.9 + q * 0.1\n",
428 | " image = G.synthesis(q_ema * w_stds + G.mapping.w_avg, noise_mode='const')\n",
429 | "\n",
430 | "    if i % show_every_n_steps == 0:\n",
431 | " display(TF.to_pil_image(tf(image)[0]))\n",
432 | " pil_image = TF.to_pil_image(image[0].add(1).div(2).clamp(0,1))\n",
433 | " os.makedirs(f'samples/{timestring}', exist_ok=True)\n",
434 | " pil_image.save(f'samples/{timestring}/{i:04}.jpg')\n",
435 | " \n",
436 | " # Save images as a tar archive\n",
437 | " !tar cf samples/{timestring}.tar samples/{timestring}\n",
438 | "\n",
439 | " return timestring\n",
440 | "\n",
441 | "timestring = run(seed, G)"
442 | ],
443 | "execution_count": null,
444 | "outputs": []
445 | },
446 | {
447 | "cell_type": "code",
448 | "metadata": {
449 | "id": "JSnRMY8-_-iV",
450 | "cellView": "form"
451 | },
452 | "source": [
453 | "#@markdown #**Save images** 📷\n",
454 | "#@markdown A `.tar` file will be saved inside *samples* and automatically downloaded, unless you previously ran the Google Drive cell,\n",
455 | "#@markdown in which case it'll be saved inside your previously created Drive *samples* folder.\n",
456 | "\n",
457 | "\n",
458 | "# Save images as a tar archive\n",
459 | "!tar cf samples/{timestring}.tar samples/{timestring}\n",
460 | "if os.path.isdir('drive/MyDrive/samples'):\n",
461 | "  shutil.copyfile(f'samples/{timestring}.tar', f'drive/MyDrive/samples/{timestring}.tar')\n",
462 | "else:\n",
463 | "  files.download(f'samples/{timestring}.tar')"
464 | ],
465 | "execution_count": null,
466 | "outputs": []
467 | },
468 | {
469 | "cell_type": "code",
470 | "metadata": {
471 | "id": "Z9Yyt8y99jfv",
472 | "cellView": "form"
473 | },
474 | "source": [
475 | "#@markdown #**Generate video** 🎥\n",
476 | "\n",
477 | "\n",
478 | "frames = os.listdir(f\"samples/{timestring}\")\n",
479 | "frames = len(list(filter(lambda filename: filename.endswith(\".jpg\"), frames))) #Get number of jpg generated\n",
480 | "\n",
481 | "init_frame = 1 # Frame the video starts from\n",
482 | "last_frame = frames # Change this to the last frame you want to include; an error is raised if that frame does not exist\n",
483 | "\n",
484 | "min_fps = 10\n",
485 | "max_fps = 30\n",
486 | "\n",
487 | "total_frames = last_frame-init_frame\n",
488 | "\n",
489 | "#Desired video time in seconds\n",
490 | "video_length = 15 #@param {type:\"number\"}\n",
491 | "\n",
492 | "frames = []\n",
493 | "tqdm.write('Generating video...')\n",
494 | "for i in range(init_frame,last_frame): #\n",
495 | " filename = f\"samples/{timestring}/{i:04}.jpg\"\n",
496 | " frames.append(Image.open(filename))\n",
497 | "\n",
498 | "fps = np.clip(total_frames/video_length,min_fps,max_fps)  # derived fps; overridden by the fixed value below\n",
499 | "fps = 30 #@param\n",
500 | "\n",
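"# Stream the frames to ffmpeg as PNGs over stdin and encode them to H.264\n",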
501 | "from subprocess import Popen, PIPE\n",
502 | "p = Popen(['ffmpeg', '-y', '-f', 'image2pipe', '-vcodec', 'png', '-r', str(fps), '-i', '-', '-vcodec', 'libx264', '-r', str(fps), '-pix_fmt', 'yuv420p', '-crf', '17', '-preset', 'veryslow', 'video.mp4'], stdin=PIPE)\n",
503 | "for im in tqdm(frames):\n",
504 | " im.save(p.stdin, 'PNG')\n",
505 | "p.stdin.close()\n",
506 | "\n",
507 | "print(\"The video is now being compressed, wait...\")\n",
508 | "p.wait()\n",
509 | "print(\"The video is ready\")"
510 | ],
511 | "execution_count": null,
512 | "outputs": []
513 | },
514 | {
515 | "cell_type": "code",
516 | "metadata": {
517 | "id": "uNpjDjR_-0dN",
518 | "cellView": "form"
519 | },
520 | "source": [
521 | "#@markdown #**Download video** 📀\n",
522 | "from google.colab import files\n",
523 | "files.download(\"video.mp4\")"
524 | ],
525 | "execution_count": null,
526 | "outputs": []
527 | },
528 | {
529 | "cell_type": "code",
530 | "metadata": {
531 | "cellView": "form",
532 | "id": "LTvMW3ssK9Lh"
533 | },
534 | "source": [
535 | "#@markdown #**View video in browser** 👀\n",
536 | "\n",
537 | "# @markdown This process may take a little longer.\n",
538 | "from IPython.display import HTML\n",
539 | "from base64 import b64encode\n",
540 | "mp4 = open('video.mp4','rb').read()\n",
541 | "data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
542 | "HTML(\"\"\"\n",
543 | "<video width=400 controls>\n",
544 | "  <source src=\"%s\" type=\"video/mp4\">\n",
545 | "</video>\n",
546 | "\"\"\" % data_url)"
547 | ],
548 | "execution_count": null,
549 | "outputs": []
550 | },
551 | {
552 | "cell_type": "markdown",
553 | "metadata": {
554 | "id": "s5KAVLgCNTz7"
555 | },
556 | "source": [
557 | "JavaScript to prevent idle timeout:\n",
558 | "\n",
559 | "Press F12, or Ctrl+Shift+I, or right-click on this page and choose Inspect.\n",
560 | "Then open the Console tab and paste in the following code.\n",
561 | "\n",
562 | "```javascript\n",
563 | "function ClickConnect(){\n",
564 | "console.log(\"Working\");\n",
565 | "document.querySelector(\"colab-toolbar-button#connect\").click()\n",
566 | "}\n",
567 | "setInterval(ClickConnect,60000)\n",
568 | "```"
569 | ]
570 | }
571 | ]
572 | }
--------------------------------------------------------------------------------