├── .gitignore
├── .gitlint
├── .pre-commit-config.yaml
├── .streamlit
│   └── config.toml
├── CHANGELOG.md
├── LICENSE
├── README.md
├── app.py
├── assets
│   └── .gitkeep
├── defaults.yaml
├── diffusion_app.py
├── diffusion_logic.py
├── docs
│   ├── architecture.md
│   ├── images
│   │   ├── diffusion-example.jpg
│   │   ├── four-seasons-20210808.jpg
│   │   ├── gallery.jpg
│   │   ├── translationx_example_trimmed.gif
│   │   └── ui.jpeg
│   ├── implementation-details.md
│   ├── notes-and-observations.md
│   └── tips-n-tricks.md
├── download-diffusion-weights.sh
├── download-weights.sh
├── environment.yml
├── gallery.py
├── logic.py
├── templates
│   └── index.html
└── vqgan_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Common Python excludes
2 | **/.ipynb_checkpoints/
3 | **/__pycache__/
4 |
5 | # VQGAN weights
6 | assets/*.ckpt
7 | assets/*.yaml
8 |
9 | # Outputs
10 | output*
11 |
12 | # Test data
13 | test-samples/
14 |
--------------------------------------------------------------------------------
/.gitlint:
--------------------------------------------------------------------------------
1 | # Config file for gitlint (https://jorisroovers.com/gitlint/)
2 | # Can be installed from pip; pipx is recommended
3 | # Generate default .gitlint using gitlint generate-config
4 | # Can be used as a hook in pre-commit
5 |
6 | [general]
7 | # B6 requires a body message to be defined
8 | ignore=B6
9 | # Activate conventional commits
10 | contrib=contrib-title-conventional-commits
11 |
12 | [contrib-title-conventional-commits]
13 | # Specify allowed commit types. For details see: https://www.conventionalcommits.org/
14 | types = feat, fix, chore, docs, test, refactor, update, merge, explore
15 |
16 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: "v2.3.0"
4 | hooks:
5 | - id: check-ast
6 | - id: check-json
7 | - id: check-yaml
8 | - repo: https://github.com/psf/black
9 | rev: "19.3b0"
10 | hooks:
11 | - id: black
12 | - repo: https://github.com/jorisroovers/gitlint
13 | rev: "018f42e"
14 | hooks:
15 | - id: gitlint
16 | # required to refer to .gitlint from where the pre-commit venv is
17 | args: [-C ../../../.gitlint, --msg-filename]
18 |
--------------------------------------------------------------------------------
/.streamlit/config.toml:
--------------------------------------------------------------------------------
1 | [server]
2 | # Default is 200 MB
3 | maxUploadSize = 10
4 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | ## 1.1 - Dec 18, 2021
4 |
5 | ### What's new
6 |
7 | + Fixed gallery app adding new images to the last page instead of the first page when refreshing (7bf3b04)
8 | + Added output commit ID to metadata if GitPython is installed (19eeb30)
9 | + Added cutout augmentations to VQGAN-CLIP (b3a7ab1) and guided diffusion (9651bc1)
10 | + Fixed random seed not saved to output if unspecified in guided diffusion (8972a5f)
11 | + Added feature to enable generating scrolling/zooming images to VQGAN-CLIP (https://github.com/tnwei/vqgan-clip-app/pull/11)
12 |
13 | ### Transitioning to 1.1 from 1.0
14 |
15 | The Python environment defined in `environment.yml` has been updated to enable generating scrolling/zooming images. Although we only needed to add opencv, conda's package resolution required updating PyTorch as well to find a compatible version of opencv.
16 |
17 | Therefore existing users need to do either of the following:
18 |
19 | ``` bash
20 | # Option 1: Remove the current Python env and recreate it
21 | conda env remove -n vqgan-clip-app
22 | conda env create -f environment.yml
23 |
24 | # Option 2: Directly update the Python environment in-place
25 | conda activate vqgan-clip-app
26 | conda env update -f environment.yml --prune
27 | ```
28 |
29 | ## 1.0 - Nov 21, 2021
30 |
31 | Since starting this repo in Jul 2021 as a personal project, I believe the codebase is now sufficiently feature-complete and stable to call it a 1.0 release.
32 |
33 | ### What's new
34 |
35 | + Added support for CLIP guided diffusion (https://github.com/tnwei/vqgan-clip-app/pull/8)
36 | + Improvements to gallery viewer: added pagination (https://github.com/tnwei/vqgan-clip-app/pull/3), improved webpage responsiveness (https://github.com/tnwei/vqgan-clip-app/issues/9)
37 | + Minor tweaks to VQGAN-CLIP app, added options to control MSE regularization (https://github.com/tnwei/vqgan-clip-app/pull/4) and TV loss (https://github.com/tnwei/vqgan-clip-app/pull/5)
38 | + Reorganized documentation in README.md and `docs/`
39 |
40 | ### Transitioning to 1.0 for existing users
41 |
42 | Update to 1.0 by running `git pull` from your local copy of this repo. No breaking changes are expected; run results from older versions of the codebase should still show up in the gallery viewer.
43 |
44 | However, some new packages are needed to support CLIP guided diffusion. Instead of setting up the Python environment from scratch, you can follow the steps below (consolidated into a shell sketch after the list):
45 |
46 | 1. In the repo directory, run `git clone https://github.com/crowsonkb/guided-diffusion`
47 | 2. `pip install ./guided-diffusion`
48 | 3. `pip install lpips`
49 | 4. Download diffusion model checkpoints in `download-diffusion-weights.sh`
50 |
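For convenience, here is how the same steps might look as a single shell session (a sketch only; it assumes the `vqgan-clip-app` conda environment is active and that the checkpoint links you want are uncommented in the download script):

```bash
# From the repo directory:
git clone https://github.com/crowsonkb/guided-diffusion
pip install ./guided-diffusion
pip install lpips

# Then fetch the diffusion model checkpoints
bash download-diffusion-weights.sh
```
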
51 | ## Before Oct 2, 2021
52 |
53 | VQGAN-CLIP app and basic gallery viewer implemented.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 tnwei
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VQGAN-CLIP web app & CLIP guided diffusion web app
2 |
3 | 
4 | 
5 |
6 | Link to repo: [tnwei/vqgan-clip-app](https://github.com/tnwei/vqgan-clip-app).
7 |
8 | ## Intro to VQGAN-CLIP
9 |
10 | VQGAN-CLIP has been in vogue for generating art using deep learning. Searching the `r/deepdream` subreddit for VQGAN-CLIP yields [quite a number of results](https://www.reddit.com/r/deepdream/search?q=vqgan+clip&restrict_sr=on). Basically, [VQGAN](https://github.com/CompVis/taming-transformers) can generate pretty high fidelity images, while [CLIP](https://github.com/openai/CLIP) can produce relevant captions for images. Combined, VQGAN-CLIP can take prompts from human input, and iterate to generate images that fit the prompts.
11 |
12 | Thanks to the generosity of creators sharing notebooks on Google Colab, the VQGAN-CLIP technique has seen widespread circulation. However, for regular usage across multiple sessions, I prefer a local setup that can be started up rapidly. Hence this simple Streamlit app for generating VQGAN-CLIP images in a local environment. A screenshot of the UI is shown below:
13 |
14 | 
15 |
16 | Be advised that you need a beefy GPU with lots of VRAM to generate images large enough to be interesting (hello Quadro owners!). For reference, an RTX 2060 can barely manage a 300x300 image. Otherwise you are best served using the notebooks on Colab.
17 |
18 | The reference implementation is [this Colab notebook](https://colab.research.google.com/drive/1L8oL-vLJXVcRzCFbPwOoMkPKJ8-aYdPN?usp=sharing) originally by Katherine Crowson. The notebook can also be found in [this repo hosted by EleutherAI](https://github.com/EleutherAI/vqgan-clip).
19 |
20 | ## Intro to CLIP guided diffusion
21 |
22 | In mid-2021, OpenAI released [Diffusion Models Beat GANs on Image Synthesis](https://arxiv.org/abs/2105.05233), with corresponding [source code and model checkpoints released on GitHub](https://github.com/openai/guided-diffusion). The cadre of people that brought us VQGAN-CLIP worked their magic and shared CLIP guided diffusion notebooks for public use. CLIP guided diffusion uses more GPU VRAM, runs slower, and has fixed output sizes depending on the trained model checkpoints, but is capable of producing more breathtaking images.
23 |
24 | Here are a few examples using the prompt _"Flowery fragrance intertwined with the freshness of the ocean breeze by Greg Rutkowski"_, run on the 512x512 HQ Uncond model:
25 |
26 | 
27 |
28 | The implementation of CLIP guided diffusion in this repo is based on notebooks from the same `EleutherAI/vqgan-clip` repo.
29 |
30 | ## Setup
31 |
32 | 1. Install the required Python libraries. Using `conda`, run `conda env create -f environment.yml`
33 | 2. Git clone this repo. After that, `cd` into the repo and run:
34 | + `git clone https://github.com/CompVis/taming-transformers` (Update to pip install if either of [these](https://github.com/CompVis/taming-transformers/pull/89) [two](https://github.com/CompVis/taming-transformers/pull/81) PRs is merged)
35 | + `git clone https://github.com/crowsonkb/guided-diffusion` (Update to pip install if [this PR](https://github.com/crowsonkb/guided-diffusion/pull/2) is merged)
36 | 3. Download the pretrained weights and config files using the links provided in the files listed below. Note that all of the links are commented out by default; it's recommended to download them one at a time, as some of the downloads can take a while. A consolidated command sketch follows this list.
37 | + For VQGAN-CLIP: `download-weights.sh`. You'll want at least the two ImageNet weights, which are used in the reference notebook.
38 | + For CLIP guided diffusion: `download-diffusion-weights.sh`.
39 |
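For reference, here is how the full setup might look as one shell session (a sketch only; it assumes conda is installed, the repo is already cloned and is the working directory, and the weight links you want are uncommented in the download scripts):

```bash
# Create and activate the Python environment
conda env create -f environment.yml
conda activate vqgan-clip-app

# Clone the model repos into the repo directory
git clone https://github.com/CompVis/taming-transformers
git clone https://github.com/crowsonkb/guided-diffusion

# Fetch pretrained weights (uncomment the links you want first)
bash download-weights.sh
bash download-diffusion-weights.sh
```
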
40 | ## Usage
41 |
42 | + VQGAN-CLIP: `streamlit run app.py`, launches web app on `localhost:8501` if available
43 | + CLIP guided diffusion: `streamlit run diffusion_app.py`, launches web app on `localhost:8501` if available
44 | + Image gallery: `python gallery.py`, launches a gallery viewer on `localhost:5000`. More on this below.
45 |
46 | In the web app, select settings on the sidebar, key in the text prompt, and click run to generate images using VQGAN-CLIP. When done, the web app will display the output image as well as a video compilation showing progression of image generation. You can save them directly through the browser's right-click menu.
47 |
48 | A one-time download of additional pre-trained weights will occur before generating the first image. This might take a few minutes depending on your internet connection.
49 |
50 | If you have multiple GPUs, specify the GPU you want to use by adding `-- --gpu X`. An extra double dash is required to [bypass Streamlit argument parsing](https://github.com/streamlit/streamlit/issues/337). Example commands:
51 |
52 | ```bash
53 | # Use 2nd GPU
54 | streamlit run app.py -- --gpu 1
55 |
56 | # Use 3rd GPU
57 | streamlit run diffusion_app.py -- --gpu 2
58 | ```
59 |
60 | See: [tips and tricks](docs/tips-n-tricks.md)
61 |
62 | ## Output and gallery viewer
63 |
64 | Each run's metadata and output are saved to the `output/` directory, organized into subfolders named with the timestamp of when the run was launched plus a unique run ID. Example `output` dir:
65 |
66 | ``` bash
67 | $ tree output
68 | ├── 20210920T232927-vTf6Aot6
69 | │ ├── anim.mp4
70 | │ ├── details.json
71 | │ └── output.PNG
72 | └── 20210920T232935-9TJ9YusD
73 | ├── anim.mp4
74 | ├── details.json
75 | └── output.PNG
76 | ```
77 |
78 | The gallery viewer reads from `output/` and visualizes previous runs together with saved metadata.
79 |
80 | 
81 |
82 | If the details are too much, call `python gallery.py --kiosk` instead to show only the images and their prompts, as in the example below.
83 |
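For reference, launching the viewer and peeking at a run's saved settings from a shell might look like this (the run folder name is taken from the example tree above):

```bash
# Launch the gallery viewer on localhost:5000,
# or use kiosk mode to show only images and prompts
python gallery.py
python gallery.py --kiosk

# Each run folder also stores its settings in details.json
cat output/20210920T232927-vTf6Aot6/details.json
```
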
84 | ## More details
85 |
86 | + [Architecture](docs/architecture.md)
87 | + [Implementation details](docs/implementation-details.md)
88 | + [Tips and tricks](docs/tips-n-tricks.md)
89 | + [Notes and observations](docs/notes-and-observations.md)
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | """
2 | This script is organized like so:
3 | + `if __name__ == "__main__"` sets up the Streamlit UI elements
4 | + `generate_image` houses interactions between UI and the CLIP image
5 | generation models
6 | + Core model code is abstracted in `logic.py` and imported in `generate_image`
7 | """
8 | import streamlit as st
9 | from pathlib import Path
10 | import sys
11 | import datetime
12 | import shutil
13 | import torch
14 | import json
15 | import os
16 | import base64
17 | import traceback
18 |
19 | import argparse
20 |
21 | sys.path.append("./taming-transformers")
22 |
23 | from PIL import Image
24 | from typing import Optional, List
25 | from omegaconf import OmegaConf
26 | import imageio
27 | import numpy as np
28 |
29 | # Catch import issue, introduced in version 1.1
30 | # Deprecate in a few minor versions
31 | try:
32 | import cv2
33 | except ModuleNotFoundError:
34 | st.warning(
35 | "Version 1.1 onwards requires opencv. Please update your Python environment as defined in `environment.yml`"
36 | )
37 |
38 | from logic import VQGANCLIPRun
39 |
40 | # Optional
41 | try:
42 | import git
43 | except ModuleNotFoundError:
44 | pass
45 |
46 |
47 | def generate_image(
48 | text_input: str = "the first day of the waters",
49 | vqgan_ckpt: str = "vqgan_imagenet_f16_16384",
50 | num_steps: int = 300,
51 | image_x: int = 300,
52 | image_y: int = 300,
53 | init_image: Optional[Image.Image] = None,
54 | image_prompts: List[Image.Image] = [],
55 | continue_prev_run: bool = False,
56 | seed: Optional[int] = None,
57 | mse_weight: float = 0,
58 | mse_weight_decay: float = 0,
59 | mse_weight_decay_steps: int = 0,
60 | tv_loss_weight: float = 1e-3,
61 | use_scrolling_zooming: bool = False,
62 | translation_x: int = 0,
63 | translation_y: int = 0,
64 | rotation_angle: float = 0,
65 | zoom_factor: float = 1,
66 | transform_interval: int = 10,
67 | use_cutout_augmentations: bool = True,
68 | device: Optional[torch.device] = None,
69 | ) -> None:
70 |
71 | ### Init -------------------------------------------------------------------
72 | run = VQGANCLIPRun(
73 | text_input=text_input,
74 | vqgan_ckpt=vqgan_ckpt,
75 | num_steps=num_steps,
76 | image_x=image_x,
77 | image_y=image_y,
78 | seed=seed,
79 | init_image=init_image,
80 | image_prompts=image_prompts,
81 | continue_prev_run=continue_prev_run,
82 | mse_weight=mse_weight,
83 | mse_weight_decay=mse_weight_decay,
84 | mse_weight_decay_steps=mse_weight_decay_steps,
85 | tv_loss_weight=tv_loss_weight,
86 | use_scrolling_zooming=use_scrolling_zooming,
87 | translation_x=translation_x,
88 | translation_y=translation_y,
89 | rotation_angle=rotation_angle,
90 | zoom_factor=zoom_factor,
91 | transform_interval=transform_interval,
92 | use_cutout_augmentations=use_cutout_augmentations,
93 | device=device,
94 | )
95 |
96 | ### Load model -------------------------------------------------------------
97 |
98 | if continue_prev_run is True:
99 | run.load_model(
100 | prev_model=st.session_state["model"],
101 | prev_perceptor=st.session_state["perceptor"],
102 | )
103 | prev_run_id = st.session_state["run_id"]
104 |
105 | else:
106 | # Remove the cache first! CUDA out of memory
107 | if "model" in st.session_state:
108 | del st.session_state["model"]
109 |
110 | if "perceptor" in st.session_state:
111 | del st.session_state["perceptor"]
112 |
113 | st.session_state["model"], st.session_state["perceptor"] = run.load_model()
114 | prev_run_id = None
115 |
116 | # Generate random run ID
117 | # Used to link runs chained together w/ continue_prev_run
118 | # ref: https://stackoverflow.com/a/42703382/13095028
119 | # Use URL and filesystem safe version since we're using this as a folder name
120 | run_id = st.session_state["run_id"] = base64.urlsafe_b64encode(
121 | os.urandom(6)
122 | ).decode("ascii")
123 |
124 | run_start_dt = datetime.datetime.now()
125 |
126 | ### Model init -------------------------------------------------------------
127 | if continue_prev_run is True:
128 | run.model_init(init_image=st.session_state["prev_im"])
129 | elif init_image is not None:
130 | run.model_init(init_image=init_image)
131 | else:
132 | run.model_init()
133 |
134 | ### Iterate ----------------------------------------------------------------
135 | step_counter = 0
136 | frames = []
137 |
138 | try:
139 | # Try block catches st.script_runner.StopException, no need for a dedicated stop button
140 | # Reason is st.form is meant to be self-contained either within sidebar, or in main body
141 | # The way the form is implemented in this app splits the form across both regions
142 | # This is intended to prevent the model settings from crowding the main body
143 | # However, touching any button resets the app state, making it impossible to
144 | # implement a stop button that can still dump output
145 | # Thankfully there's a built-in stop button :)
146 | while True:
147 | # While loop to accommodate running predetermined steps or running indefinitely
148 | status_text.text(f"Running step {step_counter}")
149 |
150 | _, im = run.iterate()
151 |
152 | if num_steps > 0: # skip when num_steps = -1
153 | step_progress_bar.progress((step_counter + 1) / num_steps)
154 | else:
155 | step_progress_bar.progress(100)
156 |
157 | # At every step, display and save image
158 | im_display_slot.image(im, caption="Output image", output_format="PNG")
159 | st.session_state["prev_im"] = im
160 |
161 | # ref: https://stackoverflow.com/a/33117447/13095028
162 | # im_byte_arr = io.BytesIO()
163 | # im.save(im_byte_arr, format="JPEG")
164 | # frames.append(im_byte_arr.getvalue()) # read()
165 | frames.append(np.asarray(im))
166 |
167 | step_counter += 1
168 |
169 | if (step_counter == num_steps) and num_steps > 0:
170 | break
171 |
172 | # Stitch into video using imageio
173 | writer = imageio.get_writer("temp.mp4", fps=24)
174 | for frame in frames:
175 | writer.append_data(frame)
176 | writer.close()
177 |
178 | # Save to output folder if run completed
179 | runoutputdir = outputdir / (
180 | run_start_dt.strftime("%Y%m%dT%H%M%S") + "-" + run_id
181 | )
182 | runoutputdir.mkdir()
183 |
184 | # Save final image
185 | im.save(runoutputdir / "output.PNG", format="PNG")
186 |
187 | # Save init image
188 | if init_image is not None:
189 | init_image.save(runoutputdir / "init-image.JPEG", format="JPEG")
190 |
191 | # Save image prompts
192 | for count, image_prompt in enumerate(image_prompts):
193 | image_prompt.save(
194 | runoutputdir / f"image-prompt-{count}.JPEG", format="JPEG"
195 | )
196 |
197 | # Save animation
198 | shutil.copy("temp.mp4", runoutputdir / "anim.mp4")
199 |
200 | # Save metadata
201 | details = {
202 | "run_id": run_id,
203 | "num_steps": step_counter,
204 | "planned_num_steps": num_steps,
205 | "text_input": text_input,
206 | "init_image": False if init_image is None else True,
207 | "image_prompts": False if len(image_prompts) == 0 else True,
208 | "continue_prev_run": continue_prev_run,
209 | "prev_run_id": prev_run_id,
210 | "seed": run.seed,
211 | "Xdim": image_x,
212 | "ydim": image_y,
213 | "vqgan_ckpt": vqgan_ckpt,
214 | "start_time": run_start_dt.strftime("%Y%m%dT%H%M%S"),
215 | "end_time": datetime.datetime.now().strftime("%Y%m%dT%H%M%S"),
216 | "mse_weight": mse_weight,
217 | "mse_weight_decay": mse_weight_decay,
218 | "mse_weight_decay_steps": mse_weight_decay_steps,
219 | "tv_loss_weight": tv_loss_weight,
220 | }
221 |
222 | if use_scrolling_zooming:
223 | details.update(
224 | {
225 | "translation_x": translation_x,
226 | "translation_y": translation_y,
227 | "rotation_angle": rotation_angle,
228 | "zoom_factor": zoom_factor,
229 | "transform_interval": transform_interval,
230 | }
231 | )
232 | if use_cutout_augmentations:
233 | details["use_cutout_augmentations"] = True
234 |
235 | if "git" in sys.modules:
236 | try:
237 | repo = git.Repo(search_parent_directories=True)
238 | commit_sha = repo.head.object.hexsha
239 | details["commit_sha"] = commit_sha[:6]
240 | except Exception as e:
241 | print("GitPython detected but not able to write commit SHA to file")
242 | print(f"raised Exception {e}")
243 |
244 | with open(runoutputdir / "details.json", "w") as f:
245 | json.dump(details, f, indent=4)
246 |
247 | status_text.text("Done!") # End of run
248 |
249 | except st.script_runner.StopException as e:
250 | # Dump output to dashboard
251 | print(f"Received Streamlit StopException")
252 | status_text.text("Execution interruped, dumping outputs ...")
253 | writer = imageio.get_writer("temp.mp4", fps=24)
254 | for frame in frames:
255 | writer.append_data(frame)
256 | writer.close()
257 |
258 | # TODO: Make the following DRY
259 | # Save to output folder if run completed
260 | runoutputdir = outputdir / (
261 | run_start_dt.strftime("%Y%m%dT%H%M%S") + "-" + run_id
262 | )
263 | runoutputdir.mkdir()
264 |
265 | # Save final image
266 | im.save(runoutputdir / "output.PNG", format="PNG")
267 |
268 | # Save init image
269 | if init_image is not None:
270 | init_image.save(runoutputdir / "init-image.JPEG", format="JPEG")
271 |
272 | # Save image prompts
273 | for count, image_prompt in enumerate(image_prompts):
274 | image_prompt.save(
275 | runoutputdir / f"image-prompt-{count}.JPEG", format="JPEG"
276 | )
277 |
278 | # Save animation
279 | shutil.copy("temp.mp4", runoutputdir / "anim.mp4")
280 |
281 | # Save metadata
282 | details = {
283 | "run_id": run_id,
284 | "num_steps": step_counter,
285 | "planned_num_steps": num_steps,
286 | "text_input": text_input,
287 | "init_image": False if init_image is None else True,
288 | "image_prompts": False if len(image_prompts) == 0 else True,
289 | "continue_prev_run": continue_prev_run,
290 | "prev_run_id": prev_run_id,
291 | "seed": run.seed,
292 | "Xdim": image_x,
293 | "ydim": image_y,
294 | "vqgan_ckpt": vqgan_ckpt,
295 | "start_time": run_start_dt.strftime("%Y%m%dT%H%M%S"),
296 | "end_time": datetime.datetime.now().strftime("%Y%m%dT%H%M%S"),
297 | "mse_weight": mse_weight,
298 | "mse_weight_decay": mse_weight_decay,
299 | "mse_weight_decay_steps": mse_weight_decay_steps,
300 | "tv_loss_weight": tv_loss_weight,
301 | }
302 |
303 | if use_scrolling_zooming:
304 | details.update(
305 | {
306 | "translation_x": translation_x,
307 | "translation_y": translation_y,
308 | "rotation_angle": rotation_angle,
309 | "zoom_factor": zoom_factor,
310 | "transform_interval": transform_interval,
311 | }
312 | )
313 | if use_cutout_augmentations:
314 | details["use_cutout_augmentations"] = True
315 |
316 | if "git" in sys.modules:
317 | try:
318 | repo = git.Repo(search_parent_directories=True)
319 | commit_sha = repo.head.object.hexsha
320 | details["commit_sha"] = commit_sha[:6]
321 | except Exception as e:
322 | print("GitPython detected but not able to write commit SHA to file")
323 | print(f"raised Exception {e}")
324 |
325 | with open(runoutputdir / "details.json", "w") as f:
326 | json.dump(details, f, indent=4)
327 |
328 | status_text.text("Done!") # End of run
329 |
330 |
331 | if __name__ == "__main__":
332 |
333 | # Argparse to capture GPU num
334 | parser = argparse.ArgumentParser()
335 |
336 | parser.add_argument(
337 | "--gpu", type=str, default=None, help="Specify GPU number. Defaults to None."
338 | )
339 | args = parser.parse_args()
340 |
341 | # Select specific GPU if chosen
342 | if args.gpu is not None:
343 | for i in args.gpu.split(","):
344 | assert (
345 | int(i) < torch.cuda.device_count()
346 | ), f"You specified --gpu {args.gpu} but torch.cuda.device_count() returned {torch.cuda.device_count()}"
347 |
348 | try:
349 | device = torch.device(f"cuda:{args.gpu}")
350 | except RuntimeError:
351 | print(traceback.format_exc())
352 | else:
353 | device = None
354 |
355 | defaults = OmegaConf.load("defaults.yaml")
356 | outputdir = Path("output")
357 | if not outputdir.exists():
358 | outputdir.mkdir()
359 |
360 | st.set_page_config(page_title="VQGAN-CLIP playground")
361 | st.title("VQGAN-CLIP playground")
362 |
363 | # Determine what weights are available in `assets/`
364 | weights_dir = Path("assets").resolve()
365 | available_weight_ckpts = list(weights_dir.glob("*.ckpt"))
366 | available_weight_configs = list(weights_dir.glob("*.yaml"))
367 | available_weights = [
368 | i.stem
369 | for i in available_weight_ckpts
370 | if i.stem in [j.stem for j in available_weight_configs]
371 | ]
372 |
373 | # i.e. no weights found, ask user to download weights
374 | if len(available_weights) == 0:
375 | st.warning("No weights found in `assets/`, refer to `download-weights.sh`")
376 | st.stop()
377 |
378 | # Set vqgan_imagenet_f16_1024 as default if possible
379 | if "vqgan_imagenet_f16_1024" in available_weights:
380 | default_weight_index = available_weights.index("vqgan_imagenet_f16_1024")
381 | else:
382 | default_weight_index = 0
383 |
384 | # Start of input form
385 | with st.form("form-inputs"):
386 | # Only element not in the sidebar, but in the form
387 | text_input = st.text_input(
388 | "Text prompt",
389 | help="VQGAN-CLIP will generate an image that best fits the prompt",
390 | )
391 | radio = st.sidebar.radio(
392 | "Model weights",
393 | available_weights,
394 | index=default_weight_index,
395 | help="Choose which weights to load, trained on different datasets. Make sure the weights and configs are downloaded to `assets/` as per the README!",
396 | )
397 | num_steps = st.sidebar.number_input(
398 | "Num steps",
399 | value=defaults["num_steps"],
400 | min_value=-1,
401 | max_value=None,
402 | step=1,
403 | help="Specify -1 to run indefinitely. Use Streamlit's stop button in the top right corner to terminate execution. The exception is caught so the most recent output will be dumped to dashboard",
404 | )
405 |
406 | image_x = st.sidebar.number_input(
407 | "Xdim", value=defaults["Xdim"], help="Width of output image, in pixels"
408 | )
409 | image_y = st.sidebar.number_input(
410 | "ydim", value=defaults["ydim"], help="Height of output image, in pixels"
411 | )
412 | set_seed = st.sidebar.checkbox(
413 | "Set seed",
414 | value=defaults["set_seed"],
415 | help="Check to set random seed for reproducibility. Will add option to specify seed",
416 | )
417 |
418 | seed_widget = st.sidebar.empty()
419 | if set_seed is True:
420 | # Use text_input as number_input relies on JS
421 | # which can't natively handle large numbers
422 | # torch.seed() generates int w/ 19 or 20 chars!
423 | seed_str = seed_widget.text_input(
424 | "Seed", value=str(defaults["seed"]), help="Random seed to use"
425 | )
426 | try:
427 | seed = int(seed_str)
428 | except ValueError as e:
429 | st.error("seed input needs to be int")
430 | else:
431 | seed = None
432 |
433 | use_custom_starting_image = st.sidebar.checkbox(
434 | "Use starting image",
435 | value=defaults["use_starting_image"],
436 | help="Check to add a starting image to the network",
437 | )
438 |
439 | starting_image_widget = st.sidebar.empty()
440 | if use_custom_starting_image is True:
441 | init_image = starting_image_widget.file_uploader(
442 | "Upload starting image",
443 | type=["png", "jpeg", "jpg"],
444 | accept_multiple_files=False,
445 | help="Starting image for the network, will be resized to fit specified dimensions",
446 | )
447 | # Convert from UploadedFile object to PIL Image
448 | if init_image is not None:
449 | init_image: Image.Image = Image.open(init_image).convert(
450 | "RGB"
451 | ) # just to be sure
452 | else:
453 | init_image = None
454 |
455 | use_image_prompts = st.sidebar.checkbox(
456 | "Add image prompt(s)",
457 | value=defaults["use_image_prompts"],
458 | help="Check to add image prompt(s), conditions the network similar to the text prompt",
459 | )
460 |
461 | image_prompts_widget = st.sidebar.empty()
462 | if use_image_prompts is True:
463 | image_prompts = image_prompts_widget.file_uploader(
464 | "Upload image prompts(s)",
465 | type=["png", "jpeg", "jpg"],
466 | accept_multiple_files=True,
467 | help="Image prompt(s) for the network, will be resized to fit specified dimensions",
468 | )
469 | # Convert from UploadedFile object to PIL Image
470 | if len(image_prompts) != 0:
471 | image_prompts = [Image.open(i).convert("RGB") for i in image_prompts]
472 | else:
473 | image_prompts = []
474 |
475 | continue_prev_run = st.sidebar.checkbox(
476 | "Continue previous run",
477 | value=defaults["continue_prev_run"],
478 | help="Use existing image and existing weights for the next run. If yes, ignores 'Use starting image'",
479 | )
480 |
481 | use_mse_reg = st.sidebar.checkbox(
482 | "Use MSE regularization",
483 | value=defaults["use_mse_regularization"],
484 | help="Check to add MSE regularization",
485 | )
486 | mse_weight_widget = st.sidebar.empty()
487 | mse_weight_decay_widget = st.sidebar.empty()
488 | mse_weight_decay_steps = st.sidebar.empty()
489 |
490 | if use_mse_reg is True:
491 | mse_weight = mse_weight_widget.number_input(
492 | "MSE weight",
493 | value=defaults["mse_weight"],
494 | # min_value=0.0, # leave this out to allow creativity
495 | step=0.05,
496 | help="Set weights for MSE regularization",
497 | )
498 | mse_weight_decay = mse_weight_decay_widget.number_input(
499 | "Decay MSE weight by ...",
500 | value=defaults["mse_weight_decay"],
501 | # min_value=0.0, # leave this out to allow creativity
502 | step=0.05,
503 | help="Subtracts MSE weight by this amount at every step change. MSE weight change stops at zero",
504 | )
505 | mse_weight_decay_steps = mse_weight_decay_steps.number_input(
506 | "... every N steps",
507 | value=defaults["mse_weight_decay_steps"],
508 | min_value=0,
509 | step=1,
510 | help="Number of steps to subtract MSE weight. Leave zero for no weight decay",
511 | )
512 | else:
513 | mse_weight = 0
514 | mse_weight_decay = 0
515 | mse_weight_decay_steps = 0
516 |
517 | use_tv_loss = st.sidebar.checkbox(
518 | "Use TV loss regularization",
519 | value=defaults["use_tv_loss_regularization"],
520 | help="Check to add MSE regularization",
521 | )
522 | tv_loss_weight_widget = st.sidebar.empty()
523 | if use_tv_loss is True:
524 | tv_loss_weight = tv_loss_weight_widget.number_input(
525 | "TV loss weight",
526 | value=defaults["tv_loss_weight"],
527 | min_value=0.0,
528 | step=1e-4,
529 | help="Set weights for TV loss regularization, which encourages spatial smoothness. Ref: https://github.com/jcjohnson/neural-style/issues/302",
530 | format="%.1e",
531 | )
532 | else:
533 | tv_loss_weight = 0
534 |
535 | use_scrolling_zooming = st.sidebar.checkbox(
536 | "Scrolling/zooming transforms",
537 | value=False,
538 | help="At fixed intervals, move the generated image up/down/left/right or zoom in/out",
539 | )
540 | translation_x_widget = st.sidebar.empty()
541 | translation_y_widget = st.sidebar.empty()
542 | rotation_angle_widget = st.sidebar.empty()
543 | zoom_factor_widget = st.sidebar.empty()
544 | transform_interval_widget = st.sidebar.empty()
545 | if use_scrolling_zooming is True:
546 | translation_x = translation_x_widget.number_input(
547 | "Translation in X", value=0, min_value=0, step=1
548 | )
549 | translation_y = translation_y_widget.number_input(
550 | "Translation in y", value=0, min_value=0, step=1
551 | )
552 | rotation_angle = rotation_angle_widget.number_input(
553 | "Rotation angle (degrees)",
554 | value=0.0,
555 | min_value=0.0,
556 | max_value=360.0,
557 | step=0.05,
558 | format="%.2f",
559 | )
560 | zoom_factor = zoom_factor_widget.number_input(
561 | "Zoom factor",
562 | value=1.0,
563 | min_value=0.1,
564 | max_value=10.0,
565 | step=0.02,
566 | format="%.2f",
567 | )
568 | transform_interval = transform_interval_widget.number_input(
569 | "Iterations per frame",
570 | value=10,
571 | min_value=0,
572 | step=1,
573 | help="Note: Will multiply by num steps above!",
574 | )
575 | else:
576 | translation_x = 0
577 | translation_y = 0
578 | rotation_angle = 0
579 | zoom_factor = 1
580 | transform_interval = 1
581 |
582 | use_cutout_augmentations = st.sidebar.checkbox(
583 | "Use cutout augmentations",
584 | value=True,
585 | help="Adds cutout augmentatinos in the image generation process. Uses up to additional 4 GiB of GPU memory. Greatly improves image quality. Toggled on by default.",
586 | )
587 |
588 | submitted = st.form_submit_button("Run!")
589 | # End of form
590 |
591 | status_text = st.empty()
592 | status_text.text("Pending input prompt")
593 | step_progress_bar = st.progress(0)
594 |
595 | im_display_slot = st.empty()
596 | vid_display_slot = st.empty()
597 | debug_slot = st.empty()
598 |
599 | if "prev_im" in st.session_state:
600 | im_display_slot.image(
601 | st.session_state["prev_im"], caption="Output image", output_format="PNG"
602 | )
603 |
604 | with st.expander("Expand for README"):
605 | with open("README.md", "r") as f:
606 | # Preprocess links to redirect to github
607 | # Thank you https://discuss.streamlit.io/u/asehmi, works like a charm!
608 | # ref: https://discuss.streamlit.io/t/image-in-markdown/13274/8
609 | markdown_links = [str(i) for i in Path("docs/").glob("*.md")]
610 | images = [str(i) for i in Path("docs/images/").glob("*")]
611 | readme_lines = f.readlines()
612 | readme_buffer = []
613 |
614 | for line in readme_lines:
615 | for md_link in markdown_links:
616 | if md_link in line:
617 | line = line.replace(
618 | md_link,
619 | "https://github.com/tnwei/vqgan-clip-app/tree/main/"
620 | + md_link,
621 | )
622 |
623 | readme_buffer.append(line)
624 | for image in images:
625 | if image in line:
626 | st.markdown(" ".join(readme_buffer[:-1]))
627 | st.image(
628 | f"https://raw.githubusercontent.com/tnwei/vqgan-clip-app/main/{image}"
629 | )
630 | readme_buffer.clear()
631 | st.markdown(" ".join(readme_buffer))
632 |
633 | with st.expander("Expand for CHANGELOG"):
634 | with open("CHANGELOG.md", "r") as f:
635 | st.markdown(f.read())
636 |
637 | if submitted:
638 | # debug_slot.write(st.session_state) # DEBUG
639 | status_text.text("Loading weights ...")
640 | generate_image(
641 | # Inputs
642 | text_input=text_input,
643 | vqgan_ckpt=radio,
644 | num_steps=num_steps,
645 | image_x=int(image_x),
646 | image_y=int(image_y),
647 | seed=int(seed) if set_seed is True else None,
648 | init_image=init_image,
649 | image_prompts=image_prompts,
650 | continue_prev_run=continue_prev_run,
651 | mse_weight=mse_weight,
652 | mse_weight_decay=mse_weight_decay,
653 | mse_weight_decay_steps=mse_weight_decay_steps,
654 | use_scrolling_zooming=use_scrolling_zooming,
655 | translation_x=translation_x,
656 | translation_y=translation_y,
657 | rotation_angle=rotation_angle,
658 | zoom_factor=zoom_factor,
659 | transform_interval=transform_interval,
660 | use_cutout_augmentations=use_cutout_augmentations,
661 | device=device,
662 | )
663 |
664 | vid_display_slot.video("temp.mp4")
665 | # debug_slot.write(st.session_state) # DEBUG
666 |
--------------------------------------------------------------------------------
/assets/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tnwei/vqgan-clip-app/82aff7993b5f64baacfe626e214398bad2d657d7/assets/.gitkeep
--------------------------------------------------------------------------------
/defaults.yaml:
--------------------------------------------------------------------------------
1 | # Modify for different systems, e.g. larger default xdim/ydim for more powerful GPUs
2 | num_steps: 500
3 | Xdim: 640
4 | ydim: 480
5 | set_seed: false
6 | seed: 0
7 | use_starting_image: false
8 | use_image_prompts: false
9 | continue_prev_run: false
10 | mse_weight: 0.5
11 | mse_weight_decay: 0.1
12 | mse_weight_decay_steps: 50
13 | use_mse_regularization: false
14 | use_tv_loss_regularization: true
15 | tv_loss_weight: 1e-3
--------------------------------------------------------------------------------
/diffusion_app.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | from pathlib import Path
3 | import sys
4 | import datetime
5 | import shutil
6 | import json
7 | import os
8 | import torch
9 | import traceback
10 | import base64
11 | from PIL import Image
12 | from typing import Optional
13 | import argparse
14 |
15 | sys.path.append("./taming-transformers")
16 |
17 | import imageio
18 | import numpy as np
19 | from diffusion_logic import CLIPGuidedDiffusion, DIFFUSION_METHODS_AND_WEIGHTS
20 |
21 | # Optional
22 | try:
23 | import git
24 | except ModuleNotFoundError:
25 | pass
26 |
27 |
28 | def generate_image(
29 | diffusion_weights: str,
30 | prompt: str,
31 | seed=0,
32 | num_steps=500,
33 | continue_prev_run=True,
34 | init_image: Optional[Image.Image] = None,
35 | skip_timesteps: int = 0,
36 | use_cutout_augmentations: bool = False,
37 | device: Optional[torch.device] = None,
38 | ) -> None:
39 |
40 | ### Init -------------------------------------------------------------------
41 | run = CLIPGuidedDiffusion(
42 | prompt=prompt,
43 | ckpt=diffusion_weights,
44 | seed=seed,
45 | num_steps=num_steps,
46 | continue_prev_run=continue_prev_run,
47 | skip_timesteps=skip_timesteps,
48 | use_cutout_augmentations=use_cutout_augmentations,
49 | device=device,
50 | )
51 |
52 | # Generate random run ID
53 | # Used to link runs chained together w/ continue_prev_run
54 | # ref: https://stackoverflow.com/a/42703382/13095028
55 | # Use URL and filesystem safe version since we're using this as a folder name
56 | run_id = st.session_state["run_id"] = base64.urlsafe_b64encode(
57 | os.urandom(6)
58 | ).decode("ascii")
59 |
60 | if "loaded_wt" not in st.session_state:
61 | st.session_state["loaded_wt"] = None
62 |
63 | run_start_dt = datetime.datetime.now()
64 |
65 | ### Load model -------------------------------------------------------------
66 | if (
67 | continue_prev_run
68 | and ("model" in st.session_state)
69 | and ("clip_model" in st.session_state)
70 | and ("diffusion" in st.session_state)
71 | and st.session_state["loaded_wt"] == diffusion_weights
72 | ):
73 | run.load_model(
74 | prev_model=st.session_state["model"],
75 | prev_diffusion=st.session_state["diffusion"],
76 | prev_clip_model=st.session_state["clip_model"],
77 | )
78 | else:
79 | (
80 | st.session_state["model"],
81 | st.session_state["diffusion"],
82 | st.session_state["clip_model"],
83 | ) = run.load_model(
84 | model_file_loc="assets/"
85 | + DIFFUSION_METHODS_AND_WEIGHTS.get(diffusion_method)
86 | )
87 | st.session_state["loaded_wt"] = diffusion_method
88 |
89 | ### Model init -------------------------------------------------------------
90 | # if continue_prev_run is True:
91 | # run.model_init(init_image=st.session_state["prev_im"])
92 | # elif init_image is not None:
93 | if init_image is not None:
94 | run.model_init(init_image=init_image)
95 | else:
96 | run.model_init()
97 |
98 | ### Iterate ----------------------------------------------------------------
99 | step_counter = 0 + skip_timesteps
100 | frames = []
101 |
102 | try:
103 | # Try block catches st.script_runner.StopException, no need for a dedicated stop button
104 | # Reason is st.form is meant to be self-contained either within sidebar, or in main body
105 | # The way the form is implemented in this app splits the form across both regions
106 | # This is intended to prevent the model settings from crowding the main body
107 | # However, touching any button resets the app state, making it impossible to
108 | # implement a stop button that can still dump output
109 | # Thankfully there's a built-in stop button :)
110 | while True:
111 | # While loop to accommodate running predetermined steps or running indefinitely
112 | status_text.text(f"Running step {step_counter}")
113 |
114 | ims = run.iterate()
115 | im = ims[0]
116 |
117 | if num_steps > 0: # skip when num_steps = -1
118 | step_progress_bar.progress((step_counter + 1) / num_steps)
119 | else:
120 | step_progress_bar.progress(100)
121 |
122 | # At every step, display and save image
123 | im_display_slot.image(im, caption="Output image", output_format="PNG")
124 | st.session_state["prev_im"] = im
125 |
126 | # ref: https://stackoverflow.com/a/33117447/13095028
127 | # im_byte_arr = io.BytesIO()
128 | # im.save(im_byte_arr, format="JPEG")
129 | # frames.append(im_byte_arr.getvalue()) # read()
130 | frames.append(np.asarray(im))
131 |
132 | step_counter += 1
133 |
134 | if (step_counter == num_steps) and num_steps > 0:
135 | break
136 |
137 | # Stitch into video using imageio
138 | writer = imageio.get_writer("temp.mp4", fps=24)
139 | for frame in frames:
140 | writer.append_data(frame)
141 | writer.close()
142 |
143 | # Save to output folder if run completed
144 | runoutputdir = outputdir / (
145 | run_start_dt.strftime("%Y%m%dT%H%M%S") + "-" + run_id
146 | )
147 | runoutputdir.mkdir()
148 |
149 | im.save(runoutputdir / "output.PNG", format="PNG")
150 | shutil.copy("temp.mp4", runoutputdir / "anim.mp4")
151 |
152 | details = {
153 | "run_id": run_id,
154 | "diffusion_method": diffusion_method,
155 | "ckpt": DIFFUSION_METHODS_AND_WEIGHTS.get(diffusion_method),
156 | "num_steps": step_counter,
157 | "planned_num_steps": num_steps,
158 | "text_input": prompt,
159 | "continue_prev_run": continue_prev_run,
160 | "seed": seed,
161 | "Xdim": imsize,
162 | "ydim": imsize,
163 | "start_time": run_start_dt.strftime("%Y%m%dT%H%M%S"),
164 | "end_time": datetime.datetime.now().strftime("%Y%m%dT%H%M%S"),
165 | }
166 |
167 | if use_cutout_augmentations:
168 | details["use_cutout_augmentations"] = True
169 |
170 | if "git" in sys.modules:
171 | try:
172 | repo = git.Repo(search_parent_directories=True)
173 | commit_sha = repo.head.object.hexsha
174 | details["commit_sha"] = commit_sha[:6]
175 | except Exception as e:
176 | print("GitPython detected but not able to write commit SHA to file")
177 | print(f"raised Exception {e}")
178 |
179 | with open(runoutputdir / "details.json", "w") as f:
180 | json.dump(details, f, indent=4)
181 |
182 | status_text.text("Done!") # End of run
183 |
184 | except st.script_runner.StopException as e:
185 | # Dump output to dashboard
186 | print(f"Received Streamlit StopException")
187 | status_text.text("Execution interruped, dumping outputs ...")
188 | writer = imageio.get_writer("temp.mp4", fps=24)
189 | for frame in frames:
190 | writer.append_data(frame)
191 | writer.close()
192 |
193 | # Save to output folder if run completed
194 | runoutputdir = outputdir / (
195 | run_start_dt.strftime("%Y%m%dT%H%M%S") + "-" + run_id
196 | )
197 | runoutputdir.mkdir()
198 |
199 | im.save(runoutputdir / "output.PNG", format="PNG")
200 | shutil.copy("temp.mp4", runoutputdir / "anim.mp4")
201 |
202 | details = {
203 | "run_id": run_id,
204 | "diffusion_method": diffusion_method,
205 | "ckpt": DIFFUSION_METHODS_AND_WEIGHTS.get(diffusion_method),
206 | "num_steps": step_counter,
207 | "planned_num_steps": num_steps,
208 | "text_input": prompt,
209 | "continue_prev_run": continue_prev_run,
210 | "seed": seed,
211 | "Xdim": imsize,
212 | "ydim": imsize,
213 | "start_time": run_start_dt.strftime("%Y%m%dT%H%M%S"),
214 | "end_time": datetime.datetime.now().strftime("%Y%m%dT%H%M%S"),
215 | }
216 |
217 | if use_cutout_augmentations:
218 | details["use_cutout_augmentations"] = True
219 |
220 | if "git" in sys.modules:
221 | try:
222 | repo = git.Repo(search_parent_directories=True)
223 | commit_sha = repo.head.object.hexsha
224 | details["commit_sha"] = commit_sha[:6]
225 | except Exception as e:
226 | print("GitPython detected but not able to write commit SHA to file")
227 | print(f"raised Exception {e}")
228 |
229 | with open(runoutputdir / "details.json", "w") as f:
230 | json.dump(details, f, indent=4)
231 |
232 | status_text.text("Done!") # End of run
233 |
234 |
235 | if __name__ == "__main__":
236 | # Argparse to capture GPU num
237 | parser = argparse.ArgumentParser()
238 |
239 | parser.add_argument(
240 | "--gpu", type=str, default=None, help="Specify GPU number. Defaults to None."
241 | )
242 | args = parser.parse_args()
243 |
244 | # Select specific GPU if chosen
245 | if args.gpu is not None:
246 | for i in args.gpu.split(","):
247 | assert (
248 | int(i) < torch.cuda.device_count()
249 | ), f"You specified --gpu {args.gpu} but torch.cuda.device_count() returned {torch.cuda.device_count()}"
250 |
251 | try:
252 | device = torch.device(f"cuda:{args.gpu}")
253 | except RuntimeError:
254 | print(traceback.format_exc())
255 | else:
256 | device = None
257 |
258 | outputdir = Path("output")
259 | if not outputdir.exists():
260 | outputdir.mkdir()
261 |
262 | st.set_page_config(page_title="CLIP guided diffusion playground")
263 | st.title("CLIP guided diffusion playground")
264 |
265 | # Determine what weights are available in `assets/`
266 | weights_dir = Path("assets").resolve()
267 | available_diffusion_weights = list(weights_dir.glob("*.pt"))
268 | available_diffusion_weights = [i.name for i in available_diffusion_weights]
269 | diffusion_weights_and_methods = {
270 | j: i for i, j in DIFFUSION_METHODS_AND_WEIGHTS.items()
271 | }
272 | available_diffusion_methods = [
273 | diffusion_weights_and_methods.get(i) for i in available_diffusion_weights
274 | ]
275 |
276 | # i.e. no weights found, ask user to download weights
277 | if len(available_diffusion_methods) == 0:
278 | st.warning(
279 | "No weights found, download diffusion weights in `download-diffusion-weights.sh`. "
280 | )
281 | st.stop()
282 |
283 | # Start of input form
284 | with st.form("form-inputs"):
285 | # Only element not in the sidebar, but in the form
286 |
287 | text_input = st.text_input(
288 | "Text prompt",
289 | help="CLIP-guided diffusion will generate an image that best fits the prompt",
290 | )
291 |
292 | diffusion_method = st.sidebar.radio(
293 | "Method",
294 | available_diffusion_methods,
295 | index=0,
296 | help="Choose diffusion image generation method, corresponding to the notebooks in Eleuther's repo",
297 | )
298 |
299 | if diffusion_method.startswith("256"):
300 | image_size_notice = st.sidebar.text("Image size: fixed to 256x256")
301 | imsize = 256
302 | elif diffusion_method.startswith("512"):
303 | image_size_notice = st.sidebar.text("Image size: fixed to 512x512")
304 | imsize = 512
305 |
306 | set_seed = st.sidebar.checkbox(
307 | "Set seed",
308 | value=False,
309 | help="Check to set a random seed for reproducibility. Reveals an input to specify the seed",
310 | )
311 | num_steps = st.sidebar.number_input(
312 | "Num steps",
313 | value=1000,
314 | min_value=0,
315 | max_value=None,
316 | step=1,
317 | # help="Specify -1 to run indefinitely. Use Streamlit's stop button in the top right corner to terminate execution. The exception is caught so the most recent output will be dumped to dashboard",
318 | )
319 |
320 | seed_widget = st.sidebar.empty()
321 | if set_seed is True:
322 | seed = seed_widget.number_input("Seed", value=0, help="Random seed to use")
323 | else:
324 | seed = None
325 |
326 | use_custom_reference_image = st.sidebar.checkbox(
327 | "Use reference image",
328 | value=False,
329 | help="Check to add a reference image. The network will attempt to match the generated image to the provided reference",
330 | )
331 |
332 | reference_image_widget = st.sidebar.empty()
333 | skip_timesteps_widget = st.sidebar.empty()
334 | if use_custom_reference_image is True:
335 | reference_image = reference_image_widget.file_uploader(
336 | "Upload reference image",
337 | type=["png", "jpeg", "jpg"],
338 | accept_multiple_files=False,
339 | help="Reference image for the network, will be resized to fit specified dimensions",
340 | )
341 | # Convert from UploadedFile object to PIL Image
342 | if reference_image is not None:
343 | reference_image: Image.Image = Image.open(reference_image).convert(
344 | "RGB"
345 | ) # just to be sure
346 | skip_timesteps = skip_timesteps_widget.number_input(
347 | "Skip timesteps (suggested 200-500)",
348 | value=200,
349 | help="Higher values make the output look more like the reference image",
350 | )
351 | else:
352 | reference_image = None
353 | skip_timesteps = 0
354 |
355 | continue_prev_run = st.sidebar.checkbox(
356 | "Skip init if models are loaded",
357 | value=True,
358 | help="Skips lengthy model init",
359 | )
360 |
361 | use_cutout_augmentations = st.sidebar.checkbox(
362 | "Use cutout augmentations",
363 | value=False,
364 | help="Adds cutout augmentations in the image generation process. Uses additional 1-2 GiB of GPU memory. Increases image quality, but probably not noticeable for guided diffusion since it's already pretty HQ and consumes a lot of VRAM, but feel free to experiment. Will significantly change image composition if toggled on vs toggled off. Toggled off by default.",
365 | )
366 |
367 | submitted = st.form_submit_button("Run!")
368 | # End of form
369 |
370 | status_text = st.empty()
371 | status_text.text("Pending input prompt")
372 | step_progress_bar = st.progress(0)
373 |
374 | im_display_slot = st.empty()
375 | vid_display_slot = st.empty()
376 | debug_slot = st.empty()
377 |
378 | if "prev_im" in st.session_state:
379 | im_display_slot.image(
380 | st.session_state["prev_im"], caption="Output image", output_format="PNG"
381 | )
382 |
383 | with st.expander("Expand for README"):
384 | with open("README.md", "r") as f:
385 | # Preprocess links to redirect to github
386 | # Thank you https://discuss.streamlit.io/u/asehmi, works like a charm!
387 | # ref: https://discuss.streamlit.io/t/image-in-markdown/13274/8
388 | markdown_links = [str(i) for i in Path("docs/").glob("*.md")]
389 | images = [str(i) for i in Path("docs/images/").glob("*")]
390 | readme_lines = f.readlines()
391 | readme_buffer = []
392 |
393 | for line in readme_lines:
394 | for md_link in markdown_links:
395 | if md_link in line:
396 | line = line.replace(
397 | md_link,
398 | "https://github.com/tnwei/vqgan-clip-app/tree/main/"
399 | + md_link,
400 | )
401 |
402 | readme_buffer.append(line)
403 | for image in images:
404 | if image in line:
405 | st.markdown(" ".join(readme_buffer[:-1]))
406 | st.image(
407 | f"https://raw.githubusercontent.com/tnwei/vqgan-clip-app/main/{image}"
408 | )
409 | readme_buffer.clear()
410 | st.markdown(" ".join(readme_buffer))
411 |
412 | with st.expander("Expand for CHANGELOG"):
413 | with open("CHANGELOG.md", "r") as f:
414 | st.markdown(f.read())
415 |
416 | if submitted:
417 | # debug_slot.write(st.session_state) # DEBUG
418 | status_text.text("Loading weights ...")
419 | generate_image(
420 | diffusion_weights=diffusion_method,
421 | prompt=text_input,
422 | seed=seed,
423 | num_steps=num_steps,
424 | continue_prev_run=continue_prev_run,
425 | init_image=reference_image,
426 | skip_timesteps=skip_timesteps,
427 | use_cutout_augmentations=use_cutout_augmentations,
428 | device=device,
429 | )
430 | vid_display_slot.video("temp.mp4")
431 | # debug_slot.write(st.session_state) # DEBUG
432 |
--------------------------------------------------------------------------------
/diffusion_logic.py:
--------------------------------------------------------------------------------
1 | import clip
2 | import sys
3 | import torch
4 | from torchvision import transforms
5 | from torchvision.transforms import functional as TF
6 | from torch import nn
7 | from torch.nn import functional as F
8 | import lpips
9 | from PIL import Image
10 | import kornia.augmentation as K
11 | from typing import Optional
12 |
13 | sys.path.append("./guided-diffusion")
14 |
15 | from guided_diffusion.script_util import (
16 | create_model_and_diffusion,
17 | model_and_diffusion_defaults,
18 | )
19 |
20 | DIFFUSION_METHODS_AND_WEIGHTS = {
21 | # "CLIP Guided Diffusion 256x256",
22 | "256x256 HQ Uncond": "256x256_diffusion_uncond.pt",
23 | "512x512 HQ Cond": "512x512_diffusion.pt",
24 | "512x512 HQ Uncond": "512x512_diffusion_uncond_finetune_008100.pt",
25 | }
26 |
27 |
28 | def spherical_dist_loss(x, y):
29 | x = F.normalize(x, dim=-1)
30 | y = F.normalize(y, dim=-1)
31 | return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
32 |
33 |
34 | def parse_prompt(prompt):
35 | vals = prompt.rsplit(":", 1)
36 | vals = vals + ["", "1"][len(vals) :]
37 | return vals[0], float(vals[1])
38 |
39 |
40 | class MakeCutouts(nn.Module):
41 | def __init__(self, cut_size, cutn, cut_pow=1.0, noise_fac=None, augs=None):
42 | super().__init__()
43 | self.cut_size = cut_size
44 | self.cutn = cutn
45 | self.cut_pow = cut_pow
46 | self.noise_fac = noise_fac
47 | self.augs = augs
48 |
49 | def forward(self, input):
50 | sideY, sideX = input.shape[2:4]
51 | max_size = min(sideX, sideY)
52 | min_size = min(sideX, sideY, self.cut_size)
53 | cutouts = []
54 | for _ in range(self.cutn):
55 | size = int(
56 | torch.rand([]) ** self.cut_pow * (max_size - min_size) + min_size
57 | )
58 | offsetx = torch.randint(0, sideX - size + 1, ())
59 | offsety = torch.randint(0, sideY - size + 1, ())
60 | cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
61 | cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
62 |
63 | if self.augs:
64 | batch = self.augs(torch.cat(cutouts, dim=0))
65 | else:
66 | batch = torch.cat(cutouts, dim=0)
67 |
68 | if self.noise_fac:
69 | facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)
70 | batch = batch + facs * torch.randn_like(batch)
71 |
72 | return batch
73 |
74 |
75 | def tv_loss(input):
76 | """L2 total variation loss, as in Mahendran et al."""
77 | input = F.pad(input, (0, 1, 0, 1), "replicate")
78 | x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]
79 | y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]
80 | return (x_diff ** 2 + y_diff ** 2).mean([1, 2, 3])
81 |
82 |
83 | def range_loss(input):
84 | return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3])
85 |
86 |
87 | class CLIPGuidedDiffusion:
88 | def __init__(
89 | self,
90 | prompt: str,
91 | ckpt: str,
92 | batch_size: int = 1,
93 | clip_guidance_scale: float = 1000,
94 | seed: int = 0,
95 | num_steps: int = 1000,
96 | continue_prev_run: bool = True,
97 | skip_timesteps: int = 0,
98 | use_cutout_augmentations: bool = False,
99 | device: Optional[torch.device] = None,
100 | ) -> None:
101 |
102 | assert ckpt in DIFFUSION_METHODS_AND_WEIGHTS.keys()
103 | self.ckpt = ckpt
104 | print(self.ckpt)
105 |
106 | # Default config
107 | self.model_config = model_and_diffusion_defaults()
108 | self.model_config.update(
109 | {
110 | "attention_resolutions": "32, 16, 8",
111 | "class_cond": True if ckpt == "512x512 HQ Cond" else False,
112 | "diffusion_steps": num_steps,
113 | "rescale_timesteps": True,
114 | "timestep_respacing": str(
115 | num_steps
116 | ), # modify this to decrease timesteps
117 | "image_size": 512 if ckpt.startswith("512") else 256,
118 | "learn_sigma": True,
119 | "noise_schedule": "linear",
120 | "num_channels": 256,
121 | "num_head_channels": 64,
122 | "num_res_blocks": 2,
123 | "resblock_updown": True,
124 | "use_checkpoint": False,
125 | "use_fp16": True,
126 | "use_scale_shift_norm": True,
127 | }
128 | )
129 | # Split text by "|" symbol
130 | self.prompts = [phrase.strip() for phrase in prompt.split("|")]
131 | if self.prompts == [""]:
132 | self.prompts = []
133 |
134 | self.image_prompts = [] # TODO
135 | self.batch_size = batch_size
136 |
137 | # Controls how much the image should look like the prompt.
138 | self.clip_guidance_scale = clip_guidance_scale
139 |
140 | # Controls the smoothness of the final output.
141 | self.tv_scale = 150 # TODO add control widget
142 |
143 | # Controls how far out of range RGB values are allowed to be.
144 | self.range_scale = 50 # TODO add control widget
145 |
146 | self.cutn = 32 # TODO add control widget
147 | self.cutn_batches = 2 # TODO add control widget
148 | self.cut_pow = 0.5 # TODO add control widget
149 |
150 | # Removed, repeat batches by triggering a new run
151 | # self.n_batches = 1
152 |
153 | # This enhances the effect of the init image, a good value is 1000.
154 | self.init_scale = 1000 # TODO add control widget
155 |
156 | # This needs to be between approx. 200 and 500 when using an init image.
157 | # Higher values make the output look more like the init.
158 | self.skip_timesteps = skip_timesteps # TODO add control widget
159 |
160 | self.seed = seed
161 | self.continue_prev_run = continue_prev_run
162 |
163 | self.use_cutout_augmentations = use_cutout_augmentations
164 |
165 | if device is None:
166 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
167 | else:
168 | self.device = device
169 |
170 | print("Using device:", self.device)
171 |
172 | def load_model(
173 | self,
174 | model_file_loc="assets/256x256_diffusion_uncond.pt",
175 | prev_model=None,
176 | prev_diffusion=None,
177 | prev_clip_model=None,
178 | ) -> None:
179 | if (
180 | self.continue_prev_run is True
181 | and prev_model is not None
182 | and prev_diffusion is not None
183 | and prev_clip_model is not None
184 | ):
185 | self.model = prev_model
186 | self.diffusion = prev_diffusion
187 | self.clip_model = prev_clip_model
188 |
189 | self.clip_size = self.clip_model.visual.input_resolution
190 | self.normalize = transforms.Normalize(
191 | mean=[0.48145466, 0.4578275, 0.40821073],
192 | std=[0.26862954, 0.26130258, 0.27577711],
193 | )
194 |
195 | else:
196 | self.model, self.diffusion = create_model_and_diffusion(**self.model_config)
197 | self.model.load_state_dict(torch.load(model_file_loc, map_location="cpu"))
198 | self.model.eval().requires_grad_(False).to(self.device)
199 |
200 | if self.ckpt == "512x512 HQ Cond":
201 | for name, param in self.model.named_parameters():
202 | if "qkv" in name or "norm" in name or "proj" in name:
203 | param.requires_grad_()
204 |
205 | if self.model_config["use_fp16"]:
206 | self.model.convert_to_fp16()
207 |
208 | self.clip_model = (
209 | clip.load("ViT-B/16", jit=False)[0]
210 | .eval()
211 | .requires_grad_(False)
212 | .to(self.device)
213 | )
214 |
215 | self.clip_size = self.clip_model.visual.input_resolution
216 | self.normalize = transforms.Normalize(
217 | mean=[0.48145466, 0.4578275, 0.40821073],
218 | std=[0.26862954, 0.26130258, 0.27577711],
219 | )
220 |
221 | return self.model, self.diffusion, self.clip_model
222 |
223 | def cond_fn_conditional(self, x, t, y=None):
224 | # From 512 HQ notebook using OpenAI's conditional 512x512 model
225 | # TODO: Merge with cond_fn's cutn_batches
226 | with torch.enable_grad():
227 | x = x.detach().requires_grad_()
228 | n = x.shape[0]
229 | my_t = torch.ones([n], device=self.device, dtype=torch.long) * self.cur_t
230 | out = self.diffusion.p_mean_variance(
231 | self.model, x, my_t, clip_denoised=False, model_kwargs={"y": y}
232 | )
233 | fac = self.diffusion.sqrt_one_minus_alphas_cumprod[self.cur_t]
234 | x_in = out["pred_xstart"] * fac + x * (1 - fac)
235 | clip_in = self.normalize(self.make_cutouts(x_in.add(1).div(2)))
236 | image_embeds = (
237 | self.clip_model.encode_image(clip_in).float().view([self.cutn, n, -1])
238 | )
239 | dists = spherical_dist_loss(image_embeds, self.target_embeds.unsqueeze(0))
240 | losses = dists.mean(0)
241 | tv_losses = tv_loss(x_in)
242 | loss = (
243 | losses.sum() * self.clip_guidance_scale
244 | + tv_losses.sum() * self.tv_scale
245 | )
246 | # TODO: Implement init image
247 | return -torch.autograd.grad(loss, x)[0]
248 |
249 | def cond_fn(self, x, t, out, y=None):
250 | n = x.shape[0]
251 | fac = self.diffusion.sqrt_one_minus_alphas_cumprod[self.cur_t]
252 | x_in = out["pred_xstart"] * fac + x * (1 - fac)
253 | x_in_grad = torch.zeros_like(x_in)
254 | for i in range(self.cutn_batches):
255 | clip_in = self.normalize(self.make_cutouts(x_in.add(1).div(2)))
256 | image_embeds = self.clip_model.encode_image(clip_in).float()
257 | dists = spherical_dist_loss(
258 | image_embeds.unsqueeze(1), self.target_embeds.unsqueeze(0)
259 | )
260 | dists = dists.view([self.cutn, n, -1])
261 | losses = dists.mul(self.weights).sum(2).mean(0)
262 | x_in_grad += (
263 | torch.autograd.grad(losses.sum() * self.clip_guidance_scale, x_in)[0]
264 | / self.cutn_batches
265 | )
266 | tv_losses = tv_loss(x_in)
267 | range_losses = range_loss(out["pred_xstart"])
268 | loss = tv_losses.sum() * self.tv_scale + range_losses.sum() * self.range_scale
269 | if self.init is not None and self.init_scale:
270 | init_losses = self.lpips_model(x_in, self.init)
271 | loss = loss + init_losses.sum() * self.init_scale
272 | x_in_grad += torch.autograd.grad(loss, x_in)[0]
273 | grad = -torch.autograd.grad(x_in, x, x_in_grad)[0]
274 | return grad
275 |
276 | def model_init(self, init_image: Image.Image = None) -> None:
277 | if self.seed is not None:
278 | torch.manual_seed(self.seed)
279 | else:
280 | self.seed = torch.seed() # Trigger a seed, retrieve the utilized seed
281 |
282 | if self.use_cutout_augmentations:
283 | noise_fac = 0.1
284 | augs = nn.Sequential(
285 | K.RandomHorizontalFlip(p=0.5),
286 | K.RandomSharpness(0.3, p=0.4),
287 | K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode="border"),
288 | K.RandomPerspective(0.2, p=0.4),
289 | K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
290 | )
291 | else:
292 | noise_fac = None
293 | augs = None
294 |
295 | self.make_cutouts = MakeCutouts(
296 | self.clip_size, self.cutn, self.cut_pow, noise_fac=noise_fac, augs=augs
297 | )
298 | self.side_x = self.side_y = self.model_config["image_size"]
299 |
300 | self.target_embeds, self.weights = [], []
301 |
302 | for prompt in self.prompts:
303 | txt, weight = parse_prompt(prompt)
304 | self.target_embeds.append(
305 | self.clip_model.encode_text(clip.tokenize(txt).to(self.device)).float()
306 | )
307 | self.weights.append(weight)
308 |
309 | # TODO: Implement image prompt parsing
310 | # for prompt in self.image_prompts:
311 | # path, weight = parse_prompt(prompt)
312 | # img = Image.open(fetch(path)).convert('RGB')
313 | # img = TF.resize(img, min(side_x, side_y, *img.size), transforms.InterpolationMode.LANCZOS)
314 | # batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
315 | # embed = clip_model.encode_image(normalize(batch)).float()
316 | # target_embeds.append(embed)
317 | # weights.extend([weight / cutn] * cutn)
318 |
319 | self.target_embeds = torch.cat(self.target_embeds)
320 | self.weights = torch.tensor(self.weights, device=self.device)
321 | if self.weights.sum().abs() < 1e-3:
322 | raise RuntimeError("The weights must not sum to 0.")
323 | self.weights /= self.weights.sum().abs()
324 |
325 | self.init = None
326 | if init_image is not None:
327 | self.init = init_image.resize((self.side_x, self.side_y), Image.LANCZOS)
328 | self.init = (
329 | TF.to_tensor(self.init).to(self.device).unsqueeze(0).mul(2).sub(1)
330 | )
331 |
332 | # LPIPS not required if init_image not used!
333 | if self.init is None:
334 | self.lpips_model = None
335 | else:
336 | self.lpips_model = lpips.LPIPS(net="vgg").to(self.device)
337 |
338 | if self.model_config["timestep_respacing"].startswith("ddim"):
339 | sample_fn = self.diffusion.ddim_sample_loop_progressive
340 | else:
341 | sample_fn = self.diffusion.p_sample_loop_progressive
342 |
343 | self.cur_t = self.diffusion.num_timesteps - self.skip_timesteps - 1
344 |
345 | if self.ckpt == "512x512 HQ Cond":
346 | print("Using conditional sampling fn")
347 | self.samples = sample_fn(
348 | self.model,
349 | (self.batch_size, 3, self.side_y, self.side_x),
350 | clip_denoised=False,
351 | model_kwargs={
352 | "y": torch.zeros(
353 | [self.batch_size], device=self.device, dtype=torch.long
354 | )
355 | },
356 | cond_fn=self.cond_fn_conditional,
357 | progress=True,
358 | skip_timesteps=self.skip_timesteps,
359 | init_image=self.init,
360 | randomize_class=True,
361 | )
362 | else:
363 | print("Using unconditional sampling fn")
364 | self.samples = sample_fn(
365 | self.model,
366 | (self.batch_size, 3, self.side_y, self.side_x),
367 | clip_denoised=False,
368 | model_kwargs={},
369 | cond_fn=self.cond_fn,
370 | progress=True,
371 | skip_timesteps=self.skip_timesteps,
372 | init_image=self.init,
373 | randomize_class=True,
374 | cond_fn_with_grad=True,
375 | )
376 |
377 | self.samplesgen = enumerate(self.samples)
378 |
379 | def iterate(self):
380 | self.cur_t -= 1
381 | _, sample = next(self.samplesgen)
382 |
383 | ims = []
384 | for _, image in enumerate(sample["pred_xstart"]):
385 | im = TF.to_pil_image(image.add(1).div(2).clamp(0, 1))
386 | ims.append(im)
387 |
388 | return ims
389 |
--------------------------------------------------------------------------------
/docs/architecture.md:
--------------------------------------------------------------------------------
1 | # Architecture
2 |
3 | ## App structure
4 |
5 | There are three major components in this repo:
6 |
7 | + VQGAN-CLIP app
8 | + `app.py` houses the UI
9 | + `logic.py` stores underlying logic
10 | + `vqgan_utils.py` stores utility functions used by `logic.py`
11 | + CLIP guided diffusion app
12 | + `diffusion_app.py` houses the UI
13 | + `diffusion_logic.py` stores underlying logic
14 | + Gallery viewer
15 | + `gallery.py` houses the UI
16 | + `templates/index.html` houses the HTML page template
17 |
18 | To date, all three components are independent and do not have shared dependencies.
19 |
20 | The UIs for both image generation apps are built using Streamlit ([docs](https://docs.streamlit.io/en/stable/index.html)), which makes it easy to throw together dashboards for ML projects in a short amount of time. The gallery viewer is a simple Flask dashboard.
21 |
22 | ## Customizing this repo
23 |
24 | Default settings for the app upon launch are specified in `defaults.yaml`, which can be adjusted as necessary.
25 |
26 | To use customized weights for VQGAN-CLIP, save both the `yaml` config and the `ckpt` weights in `assets/`, and ensure that they share the same filename with different extensions (e.g. `mymodel.ckpt`, `mymodel.yaml`). The model will then appear in the app interface for use. Refer to the [VQGAN repo](https://github.com/CompVis/taming-transformers) for training VQGAN on your own dataset.
27 |
28 | To modify the image generation logic, instantiate your own model class in `logic.py` / `diffusion_logic.py` and modify `app.py` / `diffusion_app.py` accordingly if changes to UI elements are needed.
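
For orientation, a minimal sketch of what such a class needs to expose is shown below. It mirrors the method layout of the `Run` base class in `logic.py` (`__init__`, `load_model`, `model_init`, `iterate`); the class name, constructor arguments and placeholder bodies are illustrative assumptions, not code from this repo.

```python
from typing import List, Tuple

from PIL import Image


class MyCustomRun:
    """Skeleton mirroring the Run interface expected by the Streamlit UI."""

    def __init__(self, text_input: str, num_steps: int = 300) -> None:
        # Store the run's configuration here
        self.text_input = text_input
        self.num_steps = num_steps
        self.iterate_counter = 0

    def load_model(self) -> None:
        # Load model weights here; kept separate from __init__ so that a
        # previous run's models can be reused when continuing a run
        pass

    def model_init(self) -> None:
        # Set up optimizers, cutouts, prompt embeddings, etc. once the
        # models are in place
        pass

    def iterate(self) -> Tuple[List[float], Image.Image]:
        # Run one optimization step; return losses and the current image
        # so the UI can display progress
        self.iterate_counter += 1
        return [], Image.new("RGB", (256, 256))
```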
--------------------------------------------------------------------------------
/docs/images/diffusion-example.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tnwei/vqgan-clip-app/82aff7993b5f64baacfe626e214398bad2d657d7/docs/images/diffusion-example.jpg
--------------------------------------------------------------------------------
/docs/images/four-seasons-20210808.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tnwei/vqgan-clip-app/82aff7993b5f64baacfe626e214398bad2d657d7/docs/images/four-seasons-20210808.jpg
--------------------------------------------------------------------------------
/docs/images/gallery.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tnwei/vqgan-clip-app/82aff7993b5f64baacfe626e214398bad2d657d7/docs/images/gallery.jpg
--------------------------------------------------------------------------------
/docs/images/translationx_example_trimmed.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tnwei/vqgan-clip-app/82aff7993b5f64baacfe626e214398bad2d657d7/docs/images/translationx_example_trimmed.gif
--------------------------------------------------------------------------------
/docs/images/ui.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tnwei/vqgan-clip-app/82aff7993b5f64baacfe626e214398bad2d657d7/docs/images/ui.jpeg
--------------------------------------------------------------------------------
/docs/implementation-details.md:
--------------------------------------------------------------------------------
1 | # Implementation details
2 |
3 | ## VQGAN-CLIP
4 |
5 | Code for the VQGAN-CLIP app mostly came from the z+quantize_method notebook hosted in [EleutherAI/vqgan-clip](https://github.com/EleutherAI/vqgan-clip/tree/main/notebooks). The logic is mostly left unchanged, just refactored to work well with a GUI frontend.
6 |
7 | Prompt weighting was implemented after seeing it in one of the Colab variants floating around the internet. Browsing other notebooks led to adding loss functions like MSE regularization and TV loss for improved image quality.
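
As a rough illustration of how these terms enter the objective (mirroring `_ascend_txt()` in `logic.py`), the sketch below combines an MSE penalty against the initial latent with an L2 total variation penalty before adding the CLIP prompt losses. The function name and default weights are assumptions for illustration only.

```python
import torch.nn.functional as F


def regularized_loss(z, z_orig, prompt_losses, mse_weight=0.5, tv_weight=1e-3):
    # MSE regularization pulls the optimized latent back towards its initial value
    mse_loss = F.mse_loss(z, z_orig) * mse_weight / 2

    # L2 total variation penalizes high-frequency noise in the latent
    zp = F.pad(z, (0, 1, 0, 1), "replicate")
    x_diff = zp[..., :-1, 1:] - zp[..., :-1, :-1]
    y_diff = zp[..., 1:, :-1] - zp[..., :-1, :-1]
    tv_loss = (x_diff ** 2 + y_diff ** 2).mean() * tv_weight

    # CLIP prompt losses are summed on top of the regularizers
    return mse_loss + tv_loss + sum(prompt_losses)
```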
8 |
9 | ## CLIP guided diffusion
10 |
11 | Code for the CLIP guided diffusion app mostly came from the HQ 512x512 Unconditional notebook, which can also be found in EleutherAI/vqgan-clip. Models from other guided diffusion notebooks are also implemented, save for the non-HQ 256x256 version, which did not generate satisfactory results during testing.
--------------------------------------------------------------------------------
/docs/notes-and-observations.md:
--------------------------------------------------------------------------------
1 | # Notes and observations
2 |
3 | ## Generated image size
4 |
5 | Generated image size is bound by the GPU VRAM available. The reference notebook defaults to 480x480. One of the notebooks in the thoughts section below uses 640x512. For reference, an RTX 2060 can barely manage 300x300. You can use image upscaling tools such as [Waifu2X](https://github.com/nagadomi/waifu2x) or [Real ESRGAN](https://github.com/xinntao/Real-ESRGAN) to further upscale the generated image beyond VRAM limits. Just be aware that smaller generated images fundamentally contain less complexity than larger images.
6 |
7 | ## GPU VRAM consumption
8 |
9 | The following GPU VRAM consumption figures were read from `nvidia-smi`; note that your mileage may vary.
10 |
11 | For VQGAN-CLIP, using the `vqgan_imagenet_f16_1024` model checkpoint:
12 |
13 | | Resolution| VRAM Consumption |
14 | | ----------| ---------------- |
15 | | 300 x 300 | 4,829 MiB |
16 | | 480 x 480 | 8,465 MiB |
17 | | 640 x 360 | 8,169 MiB |
18 | | 640 x 480 | 10,247 MiB |
19 | | 800 x 450 | 13,977 MiB |
20 | | 800 x 600 | 18,157 MiB |
21 | | 960 x 540 | 15,131 MiB |
22 | | 960 x 720 | 19,777 MiB |
23 | | 1024 x 576| 17,175 MiB |
24 | | 1024 x 768| 22,167 MiB |
25 | | 1280 x 720| 24,353 MiB |
26 |
27 | For CLIP guided diffusion, using various checkpoints:
28 |
29 | | Method | VRAM Consumption |
30 | | ------------------| ---------------- |
31 | | 256x256 HQ Uncond | 9,199 MiB |
32 | | 512x512 HQ Cond | 15,079 MiB |
33 | | 512x512 HQ Uncond | 15,087 MiB |
34 |
35 | ## CUDA out of memory error
36 |
37 | If you're getting a CUDA out of memory error on your first run, it is a sign that the image size is too large. If you were able to generate images of a particular size previously, then you might have a CUDA memory leak and need to restart the application. We're still figuring out where the leak comes from.
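
If you want to check whether PyTorch is still holding on to VRAM before restarting, the snippet below uses standard PyTorch utilities; this is a general-purpose check, not something the app runs itself.

```python
import torch

# VRAM currently allocated by tensors vs. reserved by PyTorch's caching allocator
print(f"allocated: {torch.cuda.memory_allocated() / 1024 ** 2:.0f} MiB")
print(f"reserved:  {torch.cuda.memory_reserved() / 1024 ** 2:.0f} MiB")

# Returns cached, unused blocks to the driver; it will not fix a genuine leak
torch.cuda.empty_cache()
```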
38 |
39 | ## VQGAN weights and art style
40 |
41 | The VQGAN weights in the download links are trained on different datasets. You might find them suitable for generating different art styles. The `sflickr` dataset is skewed towards generating landscape images, while the `faceshq` dataset is skewed towards generating faces. If you have no art style preference, the ImageNet weights do remarkably well. In fact, VQGAN-CLIP can be conditioned to generate specific styles, thanks to the breadth of understanding supplied by the CLIP model (see tips section).
42 |
43 | Some weights have multiple versions, e.g. ImageNet 1024 and ImageNet 16384. The number refers to the codebook size, i.e. the number of discrete latent codes. For more info, refer to the [VQGAN repo](https://github.com/CompVis/taming-transformers), also linked in the intro above.
44 |
45 | ## How many steps should I let the models run?
46 |
47 | A good starting number to try is 1000 steps for VQGAN-CLIP, and 2000 steps for CLIP guided diffusion.
48 |
49 | There is no hard rule on how many steps to run to get a good image, and the images generated are not guaranteed to be interesting. YMMV, as the image generation process depends on how well CLIP is able to understand the given prompt(s).
50 |
51 | ## Reproducibility
52 |
53 | Replicating a run with the same configuration, model checkpoint and seed will recreate nearly the same output, albeit with very minor variations. There exists a tiny bit of stochasticity that can't be eliminated, due to how the underlying convolution operators are implemented in CUDA. The underlying CUDA operations used by cuDNN can vary depending on differences in hardware or plain noise (see [Pytorch docs on reproducibility](https://pytorch.org/docs/stable/notes/randomness.html)). Using fully deterministic algorithms isn't an option, as some CUDA operations do not have deterministic counterparts (e.g. `upsample_bicubic2d_backward_cuda`).
54 |
55 | In practice, the variations are not easy to spot unless a side-by-side comparison is done.
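
For completeness, a minimal sketch of how a run is seeded, assuming PyTorch (the apps call `torch.manual_seed` the same way in `model_init`, and record the seed actually used). The commented-out line is PyTorch's strict determinism switch, which is left off for the reason above; the seed value is just an example.

```python
import torch

seed = 42  # example value; if unspecified, the apps trigger and record a random seed
torch.manual_seed(seed)

# Strict determinism raises an error for ops without deterministic CUDA
# implementations (e.g. upsample_bicubic2d_backward_cuda), so it stays off.
# torch.use_deterministic_algorithms(True)
```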
--------------------------------------------------------------------------------
/docs/tips-n-tricks.md:
--------------------------------------------------------------------------------
1 | # Tips and tricks
2 |
3 | ## Prompt weighting
4 |
5 | The text prompt field supports multiple prompts separated by the "|" symbol, with a custom weight for each. Example: `A beautiful sunny countryside scene of a Collie dog running towards a herd of sheep in a grassy field with a farmhouse in the background:100 | wild flowers:30 | roses:-30 | photorealistic:20 | V-ray:20 | ray tracing:20 | unreal engine:20`.
6 |
7 | Refer to [this Reddit post](https://www.reddit.com/r/bigsleep/comments/p15fis/tutorial_an_introduction_for_newbies_to_using_the/) for more info.
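
To make the weighting concrete, the snippet below mirrors how the VQGAN-CLIP app splits the text box on "|" and parses each phrase with `parse_prompt()` from `vqgan_utils.py` (the URL-handling branch is omitted here). The third returned value is a stop threshold that defaults to `-inf`.

```python
# Mirrors parse_prompt() in vqgan_utils.py, minus the URL-handling branch
def parse_prompt(prompt):
    vals = prompt.rsplit(":", 2)
    vals = vals + ["", "1", "-inf"][len(vals):]
    return vals[0], float(vals[1]), float(vals[2])


text_input = "wild flowers:30 | roses:-30 | photorealistic"
for phrase in text_input.split("|"):
    print(parse_prompt(phrase.strip()))
# ('wild flowers', 30.0, -inf)
# ('roses', -30.0, -inf)
# ('photorealistic', 1.0, -inf)
```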
8 |
9 | ## Using long, descriptive prompts
10 |
11 | CLIP can handle long, descriptive prompts, as long as the prompt is understandable and not too specific. Example: [this Reddit post](https://www.reddit.com/r/MediaSynthesis/comments/oej9qc/gptneo_vqganclip/)
12 |
13 | ## Prompt engineering for stylization
14 |
15 | The art style of the generated images can be somewhat controlled via prompt engineering.
16 |
17 | + Adding mentions of "Unreal engine" in the prompt causes the generated output to resemble a 3D render in high resolution. Example: [this Tweet](https://twitter.com/arankomatsuzaki/status/1399471244760649729?s=20)
18 | + The art style of specific artists can be mimicked by adding their name to the prompt. [This blogpost](https://moultano.wordpress.com/2021/07/20/tour-of-the-sacred-library/) shows a number of images generated using the style of James Gurney. [This imgur album](https://imgur.com/a/Ha7lsYu) compares the output of similar prompts with different artist names appended.
19 |
20 | ## Multi-stage iteration
21 |
22 | The output image of the previous run can be carried over as the input image of the next run in the VQGAN-CLIP webapp, using the `continue_previous_run` option. This allows you to continue iterating on an image if you like how it has turned out thus far. Extending upon that feature enables multi-stage iteration, where the same image can be iterated upon using different prompts at different stages.
23 |
24 | For example, you can tell the network to generate "Cowboy singing in the sky", then continue the same image and weights using a different prompt, "Fish on an alien planet under the night sky". Because of how backprop works, the network will find the easiest way to change the previous image to fit the new prompt. This should be useful for preserving visual structure between images, and for smoothly transitioning from one scene to another.
25 |
26 | Here is an example where "Backyard in spring" is first generated, then iterated upon with prompts "Backyard in summer", "Backyard in autumn", and "Backyard in winter". Major visual elements in the initial image were inherited and utilized across multiple runs.
27 |
28 | 
29 |
30 | A few things to take note of:
31 | + If a new image size is specified, the existing output image will be cropped to size accordingly.
32 | + This is specifically possible for VQGAN-CLIP but not for CLIP guided diffusion: VQGAN-CLIP iteratively optimizes a latent image that can be carried over into the next run, while CLIP guided diffusion samples a fresh denoising trajectory from noise on each run.
33 | + Splitting a long run into multiple successive runs using the same prompt does not yield the same outcome as a single long run, due to the underlying stochasticity. This randomness can't be mitigated by setting the random seed alone. See the section on reproducibility in notes-and-observations.md.
34 |
35 | ## Scrolling and zooming
36 |
37 | Scrolling and zooming were added based on [this notebook](https://colab.research.google.com/github/chigozienri/VQGAN-CLIP-animations/blob/main/VQGAN-CLIP-animations.ipynb) by @chigozienri.
38 |
39 | 
40 |
41 | More examples at [this imgur link](https://imgur.com/a/8pyUNCQ).
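
For the curious, below is a rough sketch of the per-frame transform (translation, rotation, zoom) applied between optimization intervals, assuming OpenCV and NumPy; see `VQGANCLIPRun.iterate()` in `logic.py` for the actual implementation. The function name and constants are illustrative.

```python
import cv2
import numpy as np
from PIL import Image


def scroll_zoom(im: Image.Image, tx=5, ty=0, angle=0.0, zoom=1.01) -> Image.Image:
    arr = cv2.cvtColor(np.array(im), cv2.COLOR_RGB2BGR)

    # Compose translation and rotation/zoom into a single 3x3 warp matrix
    translation = np.vstack([np.float32([[1, 0, tx], [0, 1, ty]]), [0, 0, 1]])
    center = (arr.shape[1] // 2, arr.shape[0] // 2)
    rotation = np.vstack([cv2.getRotationMatrix2D(center, angle, zoom), [0, 0, 1]])
    matrix = np.matmul(rotation, translation)

    # BORDER_WRAP fills the revealed edge with pixels from the opposite side
    out = cv2.warpPerspective(
        arr, matrix, (arr.shape[1], arr.shape[0]), borderMode=cv2.BORDER_WRAP
    )
    return Image.fromarray(cv2.cvtColor(out, cv2.COLOR_BGR2RGB))
```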
42 |
43 |
--------------------------------------------------------------------------------
/download-diffusion-weights.sh:
--------------------------------------------------------------------------------
1 | # 256x256
2 | # curl -L -o assets/256x256_diffusion_uncond.pt -C - 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt'
3 |
4 | # 512x512 class conditional model
5 | # curl -L -o assets/512x512_diffusion.pt -C - 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/512x512_diffusion.pt'
6 |
7 | # 512x512 unconditional model
8 | # curl -L -o assets/512x512_diffusion_uncond_finetune_008100.pt --http1.1 'https://the-eye.eu/public/AI/models/512x512_diffusion_unconditional_ImageNet/512x512_diffusion_uncond_finetune_008100.pt'
9 |
--------------------------------------------------------------------------------
/download-weights.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 |
3 | # Uncomment and run script to download
4 |
5 | # ImageNet 1024
6 | # curl -L -o assets/vqgan_imagenet_f16_1024.yaml -C - 'https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1'
7 | # curl -L -o assets/vqgan_imagenet_f16_1024.ckpt -C - 'https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/files/?p=%2Fckpts%2Flast.ckpt&dl=1'
8 |
9 |
10 | # ImageNet 16384
11 | # curl -L -o assets/vqgan_imagenet_f16_16384.yaml -C - 'https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1'
12 | # curl -L -o assets/vqgan_imagenet_f16_16384.ckpt -C - 'https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/files/?p=%2Fckpts%2Flast.ckpt&dl=1'
13 |
14 | # COCO
15 | # curl -L -o assets/coco.yaml -C - 'https://dl.nmkd.de/ai/clip/coco/coco.yaml' #COCO
16 | # curl -L -o assets/coco.ckpt -C - 'https://dl.nmkd.de/ai/clip/coco/coco.ckpt' #COCO
17 |
18 | # Faces HQ
19 | # curl -L -o assets/faceshq.yaml -C - 'https://drive.google.com/uc?export=download&id=1fHwGx_hnBtC8nsq7hesJvs-Klv-P0gzT'
20 | # curl -L -o assets/faceshq.ckpt -C - 'https://app.koofr.net/content/links/a04deec9-0c59-4673-8b37-3d696fe63a5d/files/get/last.ckpt?path=%2F2020-11-13T21-41-45_faceshq_transformer%2Fcheckpoints%2Flast.ckpt'
21 |
22 |
23 | # WikiArt 16384
24 | # curl -L 'http://eaidata.bmk.sh/data/Wikiart_16384/wikiart_f16_16384_8145600.yaml' > assets/wikiart_16384.yaml
25 | # curl -L 'http://eaidata.bmk.sh/data/Wikiart_16384/wikiart_f16_16384_8145600.ckpt' > assets/wikiart_16384.ckpt
26 |
27 | # S-Flickr
28 | # curl -L -o assets/sflckr.yaml -C - 'https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fconfigs%2F2020-11-09T13-31-51-project.yaml&dl=1'
29 | # curl -L -o assets/sflckr.ckpt -C - 'https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fcheckpoints%2Flast.ckpt&dl=1'
30 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: vqgan-clip-app
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - pytorch::pytorch=1.10.0
8 | - pytorch::torchvision=0.11.1
9 | - cudatoolkit=10.2
10 | - omegaconf
11 | - pytorch-lightning
12 | - tqdm
13 | - regex
14 | - kornia
15 | - ftfy
16 | - pillow=7.1.2
17 | - python=3.8 # For compatibility
18 | - imageio-ffmpeg=0.2.0 # For compatibility
19 | - ipykernel
20 | - imageio
21 | - ipywidgets
22 | - streamlit
23 | - conda-forge::ffmpeg
24 | - pyyaml
25 | - flask
26 | - pip
27 | - gitpython
28 | - opencv
29 | - pip:
30 | # - stegano
31 | # - python-xmp-toolkit
32 | # - imgtag
33 | - einops
34 | - transformers
35 | - git+https://github.com/openai/CLIP
36 | # For guided diffusion
37 | - lpips
38 | - git+https://github.com//crowsonkb/guided-diffusion
39 | prefix: /home/tnwei/miniconda3/envs/vqgan-clip-app
40 |
--------------------------------------------------------------------------------
/gallery.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | from pathlib import Path
4 | from typing import Union
5 |
6 | import yaml
7 | from flask import Flask, render_template, request, send_from_directory
8 | from PIL import Image
9 |
10 |
11 | class RunResults:
12 | """
13 | Store run output and metadata as a class for ease of use.
14 | """
15 |
16 | def __init__(self, fdir: Union[str, Path]):
17 |         self.fdir: Path = Path(fdir)  # exactly as given
18 | self.absfdir: Path = Path(fdir).resolve() # abs
19 | files_available = [i.name for i in self.fdir.glob("*")]
20 |
21 | # The most important info is the final image and the metadata
22 | # Leaving an option to delete videos to save space
23 | if (
24 | "details.txt" not in files_available
25 | and "details.json" not in files_available
26 | ):
27 | raise ValueError(
28 | f"fdir passed has neither details.txt or details.json: {fdir}"
29 | )
30 |
31 | if "output.PNG" not in files_available:
32 | raise ValueError(f"fdir passed contains no output.PNG: {fdir}")
33 |
34 | self.impath = (self.fdir / "output.PNG").as_posix()
35 |
36 | if "anim.mp4" in files_available:
37 | self.animpath = (self.fdir / "anim.mp4").as_posix()
38 | else:
39 | self.animpath = None
40 | print(f"fdir passed contains no anim.mp4: {fdir}")
41 |
42 | if "init-image.JPEG" in files_available:
43 | self.initimpath = (self.fdir / "init-image.JPEG").as_posix()
44 | else:
45 | self.initimpath = None
46 |
47 | self.impromptspath = [i.as_posix() for i in self.fdir.glob("image-prompt*")]
48 | if len(self.impromptspath) == 0:
49 | self.impromptspath = None
50 |
51 | if "details.txt" in files_available:
52 | self.detailspath = (self.fdir / "details.txt").as_posix()
53 | elif "details.json" in files_available:
54 | self.detailspath = (self.fdir / "details.json").as_posix()
55 |
56 | with open(self.detailspath, "r") as f:
57 | self.details = json.load(f)
58 |
59 | # Preserving the filepaths as I realize might be calling
60 | # them through Jinja in HTML instead of within the app
61 | # self.detailshtmlstr = markdown2.markdown(
62 | # "```" + json.dumps(self.details, indent=4) + "```",
63 | # extras=["fenced-code-blocks"]
64 | # )
65 | self.detailshtmlstr = yaml.dump(self.details)
66 |
67 | # Replace with line feed and carriage return
68 | # ref: https://stackoverflow.com/a/39325879/13095028
69 |         self.detailshtmlstr = self.detailshtmlstr.replace("\n", "<br>")
70 | # Edit: Using preformatted tag
, solved! 71 | 72 | self.im = Image.open(self.fdir / "output.PNG").convert( 73 | "RGB" 74 | ) # just to be sure the format is right 75 | 76 | 77 | def update_runs(fdir, runs): 78 | existing_run_folders = [i.absfdir.name for i in runs] 79 | # Load each run as ModelRun objects 80 | # Loading the latest ones first 81 | for i in sorted(fdir.iterdir()): 82 | # If is a folder and contains images and metadata 83 | if ( 84 | i.is_dir() 85 | and (i / "details.json").exists() 86 | and (i / "output.PNG").exists() 87 | and i.name not in existing_run_folders 88 | ): 89 | try: 90 | runs.insert(0, RunResults(i)) 91 | except Exception as e: 92 | print(f"Skipped {i} due to raised exception {e}") 93 | return runs 94 | 95 | 96 | if __name__ == "__main__": 97 | # Select output dir 98 | parser = argparse.ArgumentParser() 99 | parser.add_argument( 100 | "path", 101 | help="Path to output folder, should contain subdirs of individual runs", 102 | nargs="?", 103 | default="./output/", 104 | ) 105 | parser.add_argument( 106 | "-n", 107 | "--numitems", 108 | help="Number of items per page", 109 | default=24, # multiple of three since the dashboard has three panels 110 | ) 111 | parser.add_argument( 112 | "--kiosk", 113 | help="Omit showing run details on dashboard", 114 | default=False, 115 | action="store_true", 116 | ) 117 | args = parser.parse_args() 118 | 119 | fdir = Path(args.path) 120 | runs = update_runs(fdir, []) 121 | 122 | app = Flask( 123 | __name__, 124 | # Hack to allow easy access to images 125 | # Else typically this stuff needs to be put in a static/ folder! 126 | static_url_path="", 127 | static_folder="", 128 | ) 129 | 130 | @app.route("/") 131 | def home(): 132 | # startidx = request.args.get('startidx') 133 | # endidx = request.args.get('endidx') 134 | 135 | # Pagenum starts at 1 136 | page = request.args.get("page") 137 | page = 1 if page is None else int(page) 138 | 139 | # startidx = 1 if startidx is None else int(startidx) 140 | # endidx = args.numitems if endidx is None else int(endidx) 141 | # print("startidx, endidx: ", startidx, endidx) 142 | global runs 143 | runs = update_runs(fdir, runs) # Updates new results when refreshed 144 | num_pages = (len(runs) // args.numitems) + 1 145 | 146 | page_labels = {} 147 | for i in range(0, num_pages): 148 | page_labels[i + 1] = dict( 149 | start=i * args.numitems + 1, end=(i + 1) * args.numitems 150 | ) 151 | 152 | return render_template( 153 | "index.html", 154 | runs=runs, 155 | startidx=page_labels[page]["start"], 156 | endidx=page_labels[page]["end"], 157 | page=page, 158 | fdir=fdir, 159 | page_labels=page_labels, 160 | kiosk=args.kiosk, 161 | ) 162 | 163 | @app.route("/findurl") 164 | def findurl(path, filename): 165 | return send_from_directory(path, filename) 166 | 167 | app.run(debug=False) 168 | -------------------------------------------------------------------------------- /logic.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Tuple 2 | from PIL import Image 3 | import argparse 4 | import clip 5 | from vqgan_utils import ( 6 | load_vqgan_model, 7 | MakeCutouts, 8 | parse_prompt, 9 | resize_image, 10 | Prompt, 11 | synth, 12 | checkin, 13 | TVLoss, 14 | ) 15 | import torch 16 | from torchvision.transforms import functional as TF 17 | import torch.nn as nn 18 | from torch.nn import functional as F 19 | from torch import optim 20 | from torchvision import transforms 21 | import cv2 22 | import numpy as np 23 | import kornia.augmentation as K 24 | 25 | 26 | class 
Run: 27 | """ 28 | Subclass this to house your own implementation of CLIP-based image generation 29 | models within the UI 30 | """ 31 | 32 | def __init__(self): 33 | """ 34 | Set up the run's config here 35 | """ 36 | pass 37 | 38 | def load_model(self): 39 | """ 40 | Load models here. Separated this from __init__ to allow loading model state 41 | from a previous run 42 | """ 43 | pass 44 | 45 | def model_init(self): 46 | """ 47 | Continue run setup, for items that require the models to be in=place. 48 | Call once after load_model 49 | """ 50 | pass 51 | 52 | def iterate(self): 53 | """ 54 | Place iteration logic here. Outputs results for human consumption at 55 | every step. 56 | """ 57 | pass 58 | 59 | 60 | class VQGANCLIPRun(Run): 61 | def __init__( 62 | # Inputs 63 | self, 64 | text_input: str = "the first day of the waters", 65 | vqgan_ckpt: str = "vqgan_imagenet_f16_16384", 66 | num_steps: int = 300, 67 | image_x: int = 300, 68 | image_y: int = 300, 69 | init_image: Optional[Image.Image] = None, 70 | image_prompts: List[Image.Image] = [], 71 | continue_prev_run: bool = False, 72 | seed: Optional[int] = None, 73 | mse_weight=0.5, 74 | mse_weight_decay=0.1, 75 | mse_weight_decay_steps=50, 76 | tv_loss_weight=1e-3, 77 | use_cutout_augmentations: bool = True, 78 | # use_augs: bool = True, 79 | # noise_fac: float = 0.1, 80 | # use_noise: Optional[float] = None, 81 | # mse_withzeros=True, 82 | ## **kwargs, # Use this to receive Streamlit objects ## Call from main UI 83 | use_scrolling_zooming: bool = False, 84 | translation_x: int = 0, 85 | translation_y: int = 0, 86 | rotation_angle: float = 0, 87 | zoom_factor: float = 1, 88 | transform_interval: int = 10, 89 | device: Optional[torch.device] = None, 90 | ) -> None: 91 | super().__init__() 92 | self.text_input = text_input 93 | self.vqgan_ckpt = vqgan_ckpt 94 | self.num_steps = num_steps 95 | self.image_x = image_x 96 | self.image_y = image_y 97 | self.init_image = init_image 98 | self.image_prompts = image_prompts 99 | self.continue_prev_run = continue_prev_run 100 | self.seed = seed 101 | 102 | # Setup ------------------------------------------------------------------------------ 103 | # Split text by "|" symbol 104 | texts = [phrase.strip() for phrase in text_input.split("|")] 105 | if texts == [""]: 106 | texts = [] 107 | 108 | # Leaving most of this untouched 109 | self.args = argparse.Namespace( 110 | prompts=texts, 111 | image_prompts=image_prompts, 112 | noise_prompt_seeds=[], 113 | noise_prompt_weights=[], 114 | size=[int(image_x), int(image_y)], 115 | init_image=init_image, 116 | init_weight=mse_weight, 117 | # clip.available_models() 118 | # ['RN50', 'RN101', 'RN50x4', 'ViT-B/32'] 119 | # Visual Transformer seems to be the smallest 120 | clip_model="ViT-B/32", 121 | vqgan_config=f"assets/{vqgan_ckpt}.yaml", 122 | vqgan_checkpoint=f"assets/{vqgan_ckpt}.ckpt", 123 | step_size=0.05, 124 | cutn=64, 125 | cut_pow=1.0, 126 | display_freq=50, 127 | seed=seed, 128 | ) 129 | 130 | if device is None: 131 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 132 | else: 133 | self.device = device 134 | 135 | print("Using device:", device) 136 | 137 | self.iterate_counter = 0 138 | # self.use_augs = use_augs 139 | # self.noise_fac = noise_fac 140 | # self.use_noise = use_noise 141 | # self.mse_withzeros = mse_withzeros 142 | self.init_mse_weight = mse_weight 143 | self.mse_weight = mse_weight 144 | self.mse_weight_decay = mse_weight_decay 145 | self.mse_weight_decay_steps = mse_weight_decay_steps 146 | 147 | 
self.use_cutout_augmentations = use_cutout_augmentations 148 | 149 | # For TV loss 150 | self.tv_loss_weight = tv_loss_weight 151 | 152 | self.use_scrolling_zooming = use_scrolling_zooming 153 | self.translation_x = translation_x 154 | self.translation_y = translation_y 155 | self.rotation_angle = rotation_angle 156 | self.zoom_factor = zoom_factor 157 | self.transform_interval = transform_interval 158 | 159 | def load_model( 160 | self, prev_model: nn.Module = None, prev_perceptor: nn.Module = None 161 | ) -> Optional[Tuple[nn.Module, nn.Module]]: 162 | if self.continue_prev_run is True: 163 | self.model = prev_model 164 | self.perceptor = prev_perceptor 165 | return None 166 | 167 | else: 168 | self.model = load_vqgan_model( 169 | self.args.vqgan_config, self.args.vqgan_checkpoint 170 | ).to(self.device) 171 | 172 | self.perceptor = ( 173 | clip.load(self.args.clip_model, jit=False)[0] 174 | .eval() 175 | .requires_grad_(False) 176 | .to(self.device) 177 | ) 178 | 179 | return self.model, self.perceptor 180 | 181 | def model_init(self, init_image: Image.Image = None) -> None: 182 | cut_size = self.perceptor.visual.input_resolution 183 | e_dim = self.model.quantize.e_dim 184 | f = 2 ** (self.model.decoder.num_resolutions - 1) 185 | 186 | if self.use_cutout_augmentations: 187 | noise_fac = 0.1 188 | augs = nn.Sequential( 189 | K.RandomHorizontalFlip(p=0.5), 190 | K.RandomSharpness(0.3, p=0.4), 191 | K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode="border"), 192 | K.RandomPerspective(0.2, p=0.4), 193 | K.ColorJitter(hue=0.01, saturation=0.01, p=0.7), 194 | ) 195 | else: 196 | noise_fac = None 197 | augs = None 198 | 199 | self.make_cutouts = MakeCutouts( 200 | cut_size, 201 | self.args.cutn, 202 | cut_pow=self.args.cut_pow, 203 | noise_fac=noise_fac, 204 | augs=augs, 205 | ) 206 | 207 | n_toks = self.model.quantize.n_e 208 | toksX, toksY = self.args.size[0] // f, self.args.size[1] // f 209 | sideX, sideY = toksX * f, toksY * f 210 | self.z_min = self.model.quantize.embedding.weight.min(dim=0).values[ 211 | None, :, None, None 212 | ] 213 | self.z_max = self.model.quantize.embedding.weight.max(dim=0).values[ 214 | None, :, None, None 215 | ] 216 | 217 | if self.seed is not None: 218 | torch.manual_seed(self.seed) 219 | else: 220 | self.seed = torch.seed() # Trigger a seed, retrieve the utilized seed 221 | 222 | # Initialization order: continue_prev_im, init_image, then only random init 223 | if init_image is not None: 224 | init_image = init_image.resize((sideX, sideY), Image.LANCZOS) 225 | self.z, *_ = self.model.encode( 226 | TF.to_tensor(init_image).to(self.device).unsqueeze(0) * 2 - 1 227 | ) 228 | elif self.args.init_image: 229 | pil_image = self.args.init_image 230 | pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS) 231 | self.z, *_ = self.model.encode( 232 | TF.to_tensor(pil_image).to(self.device).unsqueeze(0) * 2 - 1 233 | ) 234 | else: 235 | one_hot = F.one_hot( 236 | torch.randint(n_toks, [toksY * toksX], device=self.device), n_toks 237 | ).float() 238 | self.z = one_hot @ self.model.quantize.embedding.weight 239 | self.z = self.z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2) 240 | self.z_orig = self.z.clone() 241 | self.z.requires_grad_(True) 242 | self.opt = optim.Adam([self.z], lr=self.args.step_size) 243 | 244 | self.normalize = transforms.Normalize( 245 | mean=[0.48145466, 0.4578275, 0.40821073], 246 | std=[0.26862954, 0.26130258, 0.27577711], 247 | ) 248 | 249 | self.pMs = [] 250 | 251 | for prompt in self.args.prompts: 252 | txt, weight, stop = 
parse_prompt(prompt) 253 | embed = self.perceptor.encode_text( 254 | clip.tokenize(txt).to(self.device) 255 | ).float() 256 | self.pMs.append(Prompt(embed, weight, stop).to(self.device)) 257 | 258 | for uploaded_image in self.args.image_prompts: 259 | # path, weight, stop = parse_prompt(prompt) 260 | # img = resize_image(Image.open(fetch(path)).convert("RGB"), (sideX, sideY)) 261 | img = resize_image(uploaded_image.convert("RGB"), (sideX, sideY)) 262 | batch = self.make_cutouts(TF.to_tensor(img).unsqueeze(0).to(self.device)) 263 | embed = self.perceptor.encode_image(self.normalize(batch)).float() 264 | self.pMs.append(Prompt(embed, weight, stop).to(self.device)) 265 | 266 | for seed, weight in zip( 267 | self.args.noise_prompt_seeds, self.args.noise_prompt_weights 268 | ): 269 | gen = torch.Generator().manual_seed(seed) 270 | embed = torch.empty([1, self.perceptor.visual.output_dim]).normal_( 271 | generator=gen 272 | ) 273 | self.pMs.append(Prompt(embed, weight).to(self.device)) 274 | 275 | def _ascend_txt(self) -> List: 276 | out = synth(self.model, self.z) 277 | iii = self.perceptor.encode_image( 278 | self.normalize(self.make_cutouts(out)) 279 | ).float() 280 | 281 | result = {} 282 | 283 | if self.args.init_weight: 284 | result["mse_loss"] = F.mse_loss(self.z, self.z_orig) * self.mse_weight / 2 285 | 286 | # MSE regularization scheduler 287 | with torch.no_grad(): 288 | # if not the first step 289 | # and is time for step change 290 | # and both weight decay steps and magnitude are nonzero 291 | # and MSE isn't zero already 292 | if ( 293 | self.iterate_counter > 0 294 | and self.iterate_counter % self.mse_weight_decay_steps == 0 295 | and self.mse_weight_decay != 0 296 | and self.mse_weight_decay_steps != 0 297 | and self.mse_weight != 0 298 | ): 299 | self.mse_weight = self.mse_weight - self.mse_weight_decay 300 | 301 | # Don't allow changing sign 302 | # Basically, caps MSE at zero if decreasing from positive 303 | # But, also prevents MSE from becoming positive if -MSE intended 304 | if self.init_mse_weight > 0: 305 | self.mse_weight = max(self.mse_weight, 0) 306 | else: 307 | self.mse_weight = min(self.mse_weight, 0) 308 | 309 | print(f"updated mse weight: {self.mse_weight}") 310 | 311 | tv_loss_fn = TVLoss() 312 | result["tv_loss"] = tv_loss_fn(self.z) * self.tv_loss_weight 313 | 314 | for count, prompt in enumerate(self.pMs): 315 | result[f"prompt_loss_{count}"] = prompt(iii) 316 | 317 | return result 318 | 319 | def iterate(self) -> Tuple[List[float], Image.Image]: 320 | if not self.use_scrolling_zooming: 321 | # Forward prop 322 | self.opt.zero_grad() 323 | losses = self._ascend_txt() 324 | 325 | # Grab an image 326 | im: Image.Image = checkin(self.model, self.z) 327 | 328 | # Backprop 329 | loss = sum([j for i, j in losses.items()]) 330 | loss.backward() 331 | self.opt.step() 332 | with torch.no_grad(): 333 | self.z.copy_(self.z.maximum(self.z_min).minimum(self.z_max)) 334 | 335 | # Advance iteration counter 336 | self.iterate_counter += 1 337 | 338 | print( 339 | f"Step {self.iterate_counter} losses: {[(i, j.item()) for i, j in losses.items()]}" 340 | ) 341 | 342 | # Output stuff useful for humans 343 | return [(i, j.item()) for i, j in losses.items()], im 344 | 345 | else: 346 | # Grab current image 347 | im_before_transform: Image.Image = checkin(self.model, self.z) 348 | 349 | # Convert for use in OpenCV 350 | imarr = np.array(im_before_transform) 351 | imarr = cv2.cvtColor(imarr, cv2.COLOR_RGB2BGR) 352 | 353 | translation = np.float32( 354 | [[1, 0, self.translation_x], 
[0, 1, self.translation_y]] 355 | ) 356 | 357 | imcenter = (imarr.shape[1] // 2, imarr.shape[0] // 2) 358 | rotation = cv2.getRotationMatrix2D( 359 | imcenter, angle=self.rotation_angle, scale=self.zoom_factor 360 | ) 361 | 362 | trans_mat = np.vstack([translation, [0, 0, 1]]) 363 | rot_mat = np.vstack([rotation, [0, 0, 1]]) 364 | transformation_matrix = np.matmul(rot_mat, trans_mat) 365 | 366 | outarr = cv2.warpPerspective( 367 | imarr, 368 | transformation_matrix, 369 | (imarr.shape[1], imarr.shape[0]), 370 | borderMode=cv2.BORDER_WRAP, 371 | ) 372 | 373 | transformed_im = Image.fromarray(cv2.cvtColor(outarr, cv2.COLOR_BGR2RGB)) 374 | 375 | # Encode as z, reinit 376 | self.z, *_ = self.model.encode( 377 | TF.to_tensor(transformed_im).to(self.device).unsqueeze(0) * 2 - 1 378 | ) 379 | self.z.requires_grad_(True) 380 | self.opt = optim.Adam([self.z], lr=self.args.step_size) 381 | 382 | for _ in range(self.transform_interval): 383 | # Forward prop 384 | self.opt.zero_grad() 385 | losses = self._ascend_txt() 386 | 387 | # Grab an image 388 | im: Image.Image = checkin(self.model, self.z) 389 | 390 | # Backprop 391 | loss = sum([j for i, j in losses.items()]) 392 | loss.backward() 393 | self.opt.step() 394 | with torch.no_grad(): 395 | self.z.copy_(self.z.maximum(self.z_min).minimum(self.z_max)) 396 | 397 | # Advance iteration counter 398 | self.iterate_counter += 1 399 | 400 | print( 401 | f"Step {self.iterate_counter} losses: {[(i, j.item()) for i, j in losses.items()]}" 402 | ) 403 | 404 | # Output stuff useful for humans 405 | return [(i, j.item()) for i, j in losses.items()], im 406 | -------------------------------------------------------------------------------- /templates/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 12 | 13 |VQGAN-CLIP output gallery 14 | 15 | 16 | 17 |18 |113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /vqgan_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | import io 6 | from omegaconf import OmegaConf 7 | from taming.models import cond_transformer, vqgan 8 | from PIL import Image 9 | from torchvision.transforms import functional as TF 10 | import requests 11 | import sys 12 | 13 | sys.path.append("./taming-transformers") 14 | 15 | 16 | def sinc(x): 17 | return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([])) 18 | 19 | 20 | def lanczos(x, a): 21 | cond = torch.logical_and(-a < x, x < a) 22 | out = torch.where(cond, sinc(x) * sinc(x / a), x.new_zeros([])) 23 | return out / out.sum() 24 | 25 | 26 | def ramp(ratio, width): 27 | n = math.ceil(width / ratio + 1) 28 | out = torch.empty([n]) 29 | cur = 0 30 | for i in range(out.shape[0]): 31 | out[i] = cur 32 | cur += ratio 33 | return torch.cat([-out[1:].flip([0]), out])[1:-1] 34 | 35 | 36 | def resample(input, size, align_corners=True): 37 | n, c, h, w = input.shape 38 | dh, dw = size 39 | 40 | input = input.view([n * c, 1, h, w]) 41 | 42 | if dh < h: 43 | kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype) 44 | pad_h = (kernel_h.shape[0] - 1) // 2 45 | input = F.pad(input, (0, 0, pad_h, pad_h), "reflect") 46 | input = F.conv2d(input, kernel_h[None, None, :, None]) 47 | 48 | if dw < w: 49 | kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype) 50 | pad_w = (kernel_w.shape[0] - 1) // 2 
51 | input = F.pad(input, (pad_w, pad_w, 0, 0), "reflect") 52 | input = F.conv2d(input, kernel_w[None, None, None, :]) 53 | 54 | input = input.view([n, c, h, w]) 55 | return F.interpolate(input, size, mode="bicubic", align_corners=align_corners) 56 | 57 | 58 | class ReplaceGrad(torch.autograd.Function): 59 | @staticmethod 60 | def forward(ctx, x_forward, x_backward): 61 | ctx.shape = x_backward.shape 62 | return x_forward 63 | 64 | @staticmethod 65 | def backward(ctx, grad_in): 66 | return None, grad_in.sum_to_size(ctx.shape) 67 | 68 | 69 | replace_grad = ReplaceGrad.apply 70 | 71 | 72 | class ClampWithGrad(torch.autograd.Function): 73 | @staticmethod 74 | def forward(ctx, input, min, max): 75 | ctx.min = min 76 | ctx.max = max 77 | ctx.save_for_backward(input) 78 | return input.clamp(min, max) 79 | 80 | @staticmethod 81 | def backward(ctx, grad_in): 82 | input, = ctx.saved_tensors 83 | return ( 84 | grad_in * (grad_in * (input - input.clamp(ctx.min, ctx.max)) >= 0), 85 | None, 86 | None, 87 | ) 88 | 89 | 90 | clamp_with_grad = ClampWithGrad.apply 91 | 92 | 93 | def vector_quantize(x, codebook): 94 | d = ( 95 | x.pow(2).sum(dim=-1, keepdim=True) 96 | + codebook.pow(2).sum(dim=1) 97 | - 2 * x @ codebook.T 98 | ) 99 | indices = d.argmin(-1) 100 | x_q = F.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook 101 | return replace_grad(x_q, x) 102 | 103 | 104 | class Prompt(nn.Module): 105 | def __init__(self, embed, weight=1.0, stop=float("-inf")): 106 | super().__init__() 107 | self.register_buffer("embed", embed) 108 | self.register_buffer("weight", torch.as_tensor(weight)) 109 | self.register_buffer("stop", torch.as_tensor(stop)) 110 | 111 | def forward(self, input): 112 | input_normed = F.normalize(input.unsqueeze(1), dim=2) 113 | embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2) 114 | dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2) 115 | dists = dists * self.weight.sign() 116 | return ( 117 | self.weight.abs() 118 | * replace_grad(dists, torch.maximum(dists, self.stop)).mean() 119 | ) 120 | 121 | 122 | def fetch(url_or_path): 123 | if str(url_or_path).startswith("http://") or str(url_or_path).startswith( 124 | "https://" 125 | ): 126 | r = requests.get(url_or_path) 127 | r.raise_for_status() 128 | fd = io.BytesIO() 129 | fd.write(r.content) 130 | fd.seek(0) 131 | return fd 132 | return open(url_or_path, "rb") 133 | 134 | 135 | def parse_prompt(prompt): 136 | if prompt.startswith("http://") or prompt.startswith("https://"): 137 | vals = prompt.rsplit(":", 3) 138 | vals = [vals[0] + ":" + vals[1], *vals[2:]] 139 | else: 140 | vals = prompt.rsplit(":", 2) 141 | vals = vals + ["", "1", "-inf"][len(vals) :] 142 | return vals[0], float(vals[1]), float(vals[2]) 143 | 144 | 145 | class MakeCutouts(nn.Module): 146 | def __init__(self, cut_size, cutn, cut_pow=1.0, noise_fac=None, augs=None): 147 | super().__init__() 148 | self.cut_size = cut_size 149 | self.cutn = cutn 150 | self.cut_pow = cut_pow 151 | self.noise_fac = noise_fac 152 | self.augs = augs 153 | 154 | def forward(self, input): 155 | sideY, sideX = input.shape[2:4] 156 | max_size = min(sideX, sideY) 157 | min_size = min(sideX, sideY, self.cut_size) 158 | cutouts = [] 159 | for _ in range(self.cutn): 160 | size = int( 161 | torch.rand([]) ** self.cut_pow * (max_size - min_size) + min_size 162 | ) 163 | offsetx = torch.randint(0, sideX - size + 1, ()) 164 | offsety = torch.randint(0, sideY - size + 1, ()) 165 | cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size] 166 | 
cutouts.append(resample(cutout, (self.cut_size, self.cut_size))) 167 | 168 | if self.augs: 169 | batch = self.augs(torch.cat(cutouts, dim=0)) 170 | else: 171 | batch = torch.cat(cutouts, dim=0) 172 | 173 | if self.noise_fac: 174 | facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac) 175 | batch = batch + facs * torch.randn_like(batch) 176 | 177 | return clamp_with_grad(batch, 0, 1) 178 | 179 | 180 | def load_vqgan_model(config_path, checkpoint_path): 181 | config = OmegaConf.load(config_path) 182 | if config.model.target == "taming.models.vqgan.VQModel": 183 | model = vqgan.VQModel(**config.model.params) 184 | model.eval().requires_grad_(False) 185 | model.init_from_ckpt(checkpoint_path) 186 | elif config.model.target == "taming.models.cond_transformer.Net2NetTransformer": 187 | parent_model = cond_transformer.Net2NetTransformer(**config.model.params) 188 | parent_model.eval().requires_grad_(False) 189 | parent_model.init_from_ckpt(checkpoint_path) 190 | model = parent_model.first_stage_model 191 | else: 192 | raise ValueError(f"unknown model type: {config.model.target}") 193 | del model.loss 194 | return model 195 | 196 | 197 | def resize_image(image, out_size): 198 | ratio = image.size[0] / image.size[1] 199 | area = min(image.size[0] * image.size[1], out_size[0] * out_size[1]) 200 | size = round((area * ratio) ** 0.5), round((area / ratio) ** 0.5) 201 | return image.resize(size, Image.LANCZOS) 202 | 203 | 204 | def synth(model, z): 205 | z_q = vector_quantize(z.movedim(1, 3), model.quantize.embedding.weight).movedim( 206 | 3, 1 207 | ) 208 | return clamp_with_grad(model.decode(z_q).add(1).div(2), 0, 1) 209 | 210 | 211 | @torch.no_grad() 212 | def checkin(model, z): 213 | # losses_str = ", ".join(f"{loss.item():g}" for loss in losses) 214 | # tqdm.write(f"i: {i}, loss: {sum(losses).item():g}, losses: {losses_str}") 215 | out = synth(model, z) 216 | im = TF.to_pil_image(out[0].cpu()) 217 | return im 218 | # display.display(display.Image('progress.png')) # ipynb only 219 | 220 | 221 | class TVLoss(nn.Module): 222 | def forward(self, input): 223 | input = F.pad(input, (0, 1, 0, 1), "replicate") 224 | x_diff = input[..., :-1, 1:] - input[..., :-1, :-1] 225 | y_diff = input[..., 1:, :-1] - input[..., :-1, :-1] 226 | diff = x_diff ** 2 + y_diff ** 2 + 1e-8 227 | return diff.mean(dim=1).sqrt().mean() 228 | --------------------------------------------------------------------------------19 |21 |VQGAN-CLIP output gallery
20 |22 | 23 | Output folder: {{ fdir }} 24 |40 | 41 |
25 | Total runs: {{ runs|length }} 26 |
27 | {% for pagenum, indices in page_labels.items() %} 28 | 29 | 30 | {% if pagenum == page %} 31 | {{indices['start']}}-{{indices['end']}} 32 | {% else %} 33 | {{indices['start']}}-{{indices['end']}} 34 | {% endif %} 35 | 36 | 37 | {% endfor %} 38 | 39 |
42 |43 | 44 | {% for run in runs[startidx-1:endidx] %} 45 |89 |46 |87 | {% endfor %} 88 |47 |86 |50 |
51 |85 |{{ run.details.text_input }}
52 | 53 | {% if not kiosk %} 54 |55 | 56 |
64 | 65 |57 | {% for i, j in run.details.items() %} 58 | 59 |
62 | 63 |- {{ i ~ ': ' ~ j }}
60 | {% endfor %} 61 |66 | {% if run.animpath is not none %} 67 | Animation 68 | 69 | {% endif %} 70 | {% if run.initimpath is not none %} 71 | Init image 72 | 73 | {% endif %} 74 | {% if run.impromptspath is not none %} 75 | {% for i in range(run.impromptspath|length) %} 76 | Image prompt {{i+1}} 77 | 78 | {% endfor %} 79 | {% endif %} 80 |
81 | 82 | {% endif %} 83 | 84 |
90 | 91 |92 | 93 | {% for pagenum, indices in page_labels.items() %} 94 | 95 | 96 | {% if pagenum == page %} 97 | {{indices['start']}}-{{indices['end']}} 98 | {% else %} 99 | {{indices['start']}}-{{indices['end']}} 100 | {% endif %} 101 | 102 | 103 | {% endfor %} 104 | 105 |106 |107 | 108 | Code: tnwei/vqgan-clip-app 109 | 110 |111 |
112 |