├── .gitignore ├── .gitlint ├── .pre-commit-config.yaml ├── .streamlit └── config.toml ├── CHANGELOG.md ├── LICENSE ├── README.md ├── app.py ├── assets └── .gitkeep ├── defaults.yaml ├── diffusion_app.py ├── diffusion_logic.py ├── docs ├── architecture.md ├── images │ ├── diffusion-example.jpg │ ├── four-seasons-20210808.jpg │ ├── gallery.jpg │ ├── translationx_example_trimmed.gif │ └── ui.jpeg ├── implementation-details.md ├── notes-and-observations.md └── tips-n-tricks.md ├── download-diffusion-weights.sh ├── download-weights.sh ├── environment.yml ├── gallery.py ├── logic.py ├── templates └── index.html └── vqgan_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Common Python excludes 2 | **/.ipynb_checkpoints/ 3 | **/__pycache__/ 4 | 5 | # VQGAN weights 6 | assets/*.ckpt 7 | assets/*.yaml 8 | 9 | # Outputs 10 | output* 11 | 12 | # Test data 13 | test-samples/ 14 | -------------------------------------------------------------------------------- /.gitlint: -------------------------------------------------------------------------------- 1 | # Config file for gitlint (https://jorisroovers.com/gitlint/) 2 | # Can be installed from pip, recommend to use pipx 3 | # Generate default .gitlint using gitlint generate-config 4 | # Can be used as a hook in pre-commit 5 | 6 | [general] 7 | # B6 requires a body message to be defined 8 | ignore=B6 9 | # Activate conventional commits 10 | contrib=contrib-title-conventional-commits 11 | 12 | [contrib-title-conventional-commits] 13 | # Specify allowed commit types. For details see: https://www.conventionalcommits.org/ 14 | types = feat, fix, chore, docs, test, refactor, update, merge, explore 15 | 16 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: "v2.3.0" 4 | hooks: 5 | - id: check-ast 6 | - id: check-json 7 | - id: check-yaml 8 | - repo: https://github.com/psf/black 9 | rev: "19.3b0" 10 | hooks: 11 | - id: black 12 | - repo: https://github.com/jorisroovers/gitlint 13 | rev: "018f42e" 14 | hooks: 15 | - id: gitlint 16 | # required to refer to .gitlint from where the pre-commit venv is 17 | args: [-C ../../../.gitlint, --msg-filename] 18 | -------------------------------------------------------------------------------- /.streamlit/config.toml: -------------------------------------------------------------------------------- 1 | [server] 2 | # Default is 200 MB 3 | maxUploadSize = 10 4 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## 1.1 - Dec 18, 2021 4 | 5 | ### What's new 6 | 7 | + Fixed gallery app adding new images to the last page instead of the first page when refreshing (7bf3b04) 8 | + Added output commit ID to metadata if GitPython is installed (19eeb30) 9 | + Added cutout augmentations to VQGAN-CLIP (b3a7ab1) and guided diffusion (9651bc1) 10 | + Fixed random seed not saved to output if unspecified in guided diffusion (8972a5f) 11 | + Added feature to enable generating scrolling/zooming images to VQGAN-CLIP (https://github.com/tnwei/vqgan-clip-app/pull/11) 12 | 13 | ### Transitioning to 1.1 from 1.0 14 | 15 | The Python environment defined in `environment.yml` has been updated to enable generating scrolling/zooming 
images. Although we only need to add opencv, conda's package resolution required updating Pytorch as well to find a compatible version of opencv. 16 | 17 | Therefore existing users need to do either of the following: 18 | 19 | ``` bash 20 | # Remove the current Python env and recreate 21 | conda env remove -n vqgan-clip-app 22 | conda env create -f environment.yml 23 | 24 | # Directly update the Python environment in-place 25 | conda activate vqgan-clip-app 26 | conda env update -f environment.yml --prune 27 | ``` 28 | 29 | ## 1.0 - Nov 21, 2021 30 | 31 | Since starting this repo in Jul 2021 as a personal project, I believe the codebase is now sufficiently feature-complete and stable to call it a 1.0 release. 32 | 33 | ### What's new 34 | 35 | + Added support for CLIP guided diffusion (https://github.com/tnwei/vqgan-clip-app/pull/8) 36 | + Improvements to gallery viewer: added pagination (https://github.com/tnwei/vqgan-clip-app/pull/3), improved webpage responsiveness (https://github.com/tnwei/vqgan-clip-app/issues/9) 37 | + Minor tweaks to VQGAN-CLIP app, added options to control MSE regularization (https://github.com/tnwei/vqgan-clip-app/pull/4) and TV loss (https://github.com/tnwei/vqgan-clip-app/5) 38 | + Reorganized documentation in README.md and `docs/` 39 | 40 | ### Transitioning to 1.0 for existing users 41 | 42 | Update to 1.0 by running `git pull` from your local copy of this repo. No breaking changes are expected, run results from older versions of the codebase should still show up in the gallery viewer. 43 | 44 | However, some new packages are needed to support CLIP guided diffusion. You can follow these steps below instead of setting up the Python environment from scratch: 45 | 46 | 1. In the repo directory, run `git clone https://github.com/crowsonkb/guided-diffusion` 47 | 2. `pip install ./guided-diffusion` 48 | 3. `pip install lpips` 49 | 4. Download diffusion model checkpoints in `download-diffusion-weights.sh` 50 | 51 | ## Before Oct 2, 2021 52 | 53 | VQGAN-CLIP app and basic gallery viewer implemented. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 tnwei 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VQGAN-CLIP web app & CLIP guided diffusion web app 2 | 3 | ![LGTM Grade](https://img.shields.io/lgtm/grade/python/github/tnwei/vqgan-clip-app) 4 | ![License](https://img.shields.io/github/license/tnwei/vqgan-clip-app) 5 | 6 | Link to repo: [tnwei/vqgan-clip-app](https://github.com/tnwei/vqgan-clip-app). 7 | 8 | ## Intro to VQGAN-CLIP 9 | 10 | VQGAN-CLIP has been in vogue for generating art using deep learning. Searching the `r/deepdream` subreddit for VQGAN-CLIP yields [quite a number of results](https://www.reddit.com/r/deepdream/search?q=vqgan+clip&restrict_sr=on). Basically, [VQGAN](https://github.com/CompVis/taming-transformers) can generate pretty high fidelity images, while [CLIP](https://github.com/openai/CLIP) can produce relevant captions for images. Combined, VQGAN-CLIP can take prompts from human input, and iterate to generate images that fit the prompts. 11 | 12 | Thanks to the generosity of creators sharing notebooks on Google Colab, the VQGAN-CLIP technique has seen widespread circulation. However, for regular usage across multiple sessions, I prefer a local setup that can be started up rapidly. Thus, this simple Streamlit app for generating VQGAN-CLIP images on a local environment. Screenshot of the UI as below: 13 | 14 | ![Screenshot of the UI](docs/images/ui.jpeg) 15 | 16 | Be advised that you need a beefy GPU with lots of VRAM to generate images large enough to be interesting. (Hello Quadro owners!). For reference, an RTX2060 can barely manage a 300x300 image. Otherwise you are best served using the notebooks on Colab. 17 | 18 | Reference is [this Colab notebook](https://colab.research.google.com/drive/1L8oL-vLJXVcRzCFbPwOoMkPKJ8-aYdPN?usp=sharing) originally by Katherine Crowson. The notebook can also be found in [this repo hosted by EleutherAI](https://github.com/EleutherAI/vqgan-clip). 19 | 20 | ## Intro to CLIP guided diffusion 21 | 22 | In mid 2021, Open AI released [Diffusion Models Beat GANS on Image Synthesis](arxiv.org/abs/2105.05233), with corresponding [source code and model checkpoints released on github](https://github.com/openai/guided-diffusion). The cadre of people that brought us VQGAN-CLIP worked their magic, and shared CLIP guided diffusion notebooks for public use. CLIP guided diffusion uses more GPU VRAM, runs slower, and has fixed output sizes depending on the trained model checkpoints, but is capable of producing more breathtaking images. 23 | 24 | Here's a few examples using the prompt _"Flowery fragrance intertwined with the freshness of the ocean breeze by Greg Rutkowski"_, run on the 512x512 HQ Uncond model: 25 | 26 | ![Example output for CLIP guided diffusion](docs/images/diffusion-example.jpg) 27 | 28 | The implementation of CLIP guided diffusion in this repo is based on notebooks from the same `EleutherAI/vqgan-clip` repo. 29 | 30 | ## Setup 31 | 32 | 1. Install the required Python libraries. Using `conda`, run `conda env create -f environment.yml` 33 | 2. Git clone this repo. 
After that, `cd` into the repo and run: 34 | + `git clone https://github.com/CompVis/taming-transformers` (Update to pip install if either of [these](https://github.com/CompVis/taming-transformers/pull/89) [two](https://github.com/CompVis/taming-transformers/pull/81) PRs are merged) 35 | + `git clone https://github.com/crowsonkb/guided-diffusion` (Update to pip install if [this PR](https://github.com/crowsonkb/guided-diffusion/pull/2) is merged) 36 | 3. Download the pretrained weights and config files using the provided links in the files listed below. Note that that all of the links are commented out by default. Recommend to download one by one, as some of the downloads can take a while. 37 | + For VQGAN-CLIP: `download-weights.sh`. You'll want to at least have both the ImageNet weights, which are used in the reference notebook. 38 | + For CLIP guided diffusion: `download-diffusion-weights.sh`. 39 | 40 | ## Usage 41 | 42 | + VQGAN-CLIP: `streamlit run app.py`, launches web app on `localhost:8501` if available 43 | + CLIP guided diffusion: `streamlit run diffusion_app.py`, launches web app on `localhost:8501` if available 44 | + Image gallery: `python gallery.py`, launches a gallery viewer on `localhost:5000`. More on this below. 45 | 46 | In the web app, select settings on the sidebar, key in the text prompt, and click run to generate images using VQGAN-CLIP. When done, the web app will display the output image as well as a video compilation showing progression of image generation. You can save them directly through the browser's right-click menu. 47 | 48 | A one-time download of additional pre-trained weights will occur before generating the first image. Might take a few minutes depending on your internet connection. 49 | 50 | If you have multiple GPUs, specify the GPU you want to use by adding `-- --gpu X`. An extra double dash is required to [bypass Streamlit argument parsing](https://github.com/streamlit/streamlit/issues/337). Example commands: 51 | 52 | ```bash 53 | # Use 2nd GPU 54 | streamlit run app.py -- --gpu 1 55 | 56 | # Use 3rd GPU 57 | streamlit run diffusion_app.py -- --gpu 2 58 | ``` 59 | 60 | See: [tips and tricks](docs/tips-n-tricks.md) 61 | 62 | ## Output and gallery viewer 63 | 64 | Each run's metadata and output is saved to the `output/` directory, organized into subfolders named using the timestamp when a run is launched, as well as a unique run ID. Example `output` dir: 65 | 66 | ``` bash 67 | $ tree output 68 | ├── 20210920T232927-vTf6Aot6 69 | │   ├── anim.mp4 70 | │   ├── details.json 71 | │   └── output.PNG 72 | └── 20210920T232935-9TJ9YusD 73 |    ├── anim.mp4 74 |    ├── details.json 75 |    └── output.PNG 76 | ``` 77 | 78 | The gallery viewer reads from `output/` and visualizes previous runs together with saved metadata. 79 | 80 | ![Screenshot of the gallery viewer](docs/images/gallery.jpg) 81 | 82 | If the details are too much, call `python gallery.py --kiosk` instead to only show the images and their prompts. 
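Since every run writes its settings to `details.json`, the `output/` folder can also be inspected programmatically. Below is a minimal sketch (not part of this repo) that lists previous runs by reading those metadata files; the exact keys differ slightly between VQGAN-CLIP and guided diffusion runs, hence the use of `.get()`.

``` python
# Minimal sketch: summarize previous runs from the metadata written by
# app.py / diffusion_app.py alongside each output image.
import json
from pathlib import Path

for details_path in sorted(Path("output").glob("*/details.json")):
    with open(details_path) as f:
        details = json.load(f)
    print(
        details.get("start_time"),
        details.get("run_id"),
        repr(details.get("text_input")),
        f"{details.get('num_steps')} steps",
    )
```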
83 | 84 | ## More details 85 | 86 | + [Architecture](docs/architecture.md) 87 | + [Implementation details](docs/implementation-details.md) 88 | + [Tips and tricks](docs/tips-n-tricks.md) 89 | + [Notes and observations](docs/notes-and-observations.md) -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script is organized like so: 3 | + `if __name__ == "__main__" sets up the Streamlit UI elements 4 | + `generate_image` houses interactions between UI and the CLIP image 5 | generation models 6 | + Core model code is abstracted in `logic.py` and imported in `generate_image` 7 | """ 8 | import streamlit as st 9 | from pathlib import Path 10 | import sys 11 | import datetime 12 | import shutil 13 | import torch 14 | import json 15 | import os 16 | import base64 17 | import traceback 18 | 19 | import argparse 20 | 21 | sys.path.append("./taming-transformers") 22 | 23 | from PIL import Image 24 | from typing import Optional, List 25 | from omegaconf import OmegaConf 26 | import imageio 27 | import numpy as np 28 | 29 | # Catch import issue, introduced in version 1.1 30 | # Deprecate in a few minor versions 31 | try: 32 | import cv2 33 | except ModuleNotFoundError: 34 | st.warning( 35 | "Version 1.1 onwards requires opencv. Please update your Python environment as defined in `environment.yml`" 36 | ) 37 | 38 | from logic import VQGANCLIPRun 39 | 40 | # Optional 41 | try: 42 | import git 43 | except ModuleNotFoundError: 44 | pass 45 | 46 | 47 | def generate_image( 48 | text_input: str = "the first day of the waters", 49 | vqgan_ckpt: str = "vqgan_imagenet_f16_16384", 50 | num_steps: int = 300, 51 | image_x: int = 300, 52 | image_y: int = 300, 53 | init_image: Optional[Image.Image] = None, 54 | image_prompts: List[Image.Image] = [], 55 | continue_prev_run: bool = False, 56 | seed: Optional[int] = None, 57 | mse_weight: float = 0, 58 | mse_weight_decay: float = 0, 59 | mse_weight_decay_steps: int = 0, 60 | tv_loss_weight: float = 1e-3, 61 | use_scrolling_zooming: bool = False, 62 | translation_x: int = 0, 63 | translation_y: int = 0, 64 | rotation_angle: float = 0, 65 | zoom_factor: float = 1, 66 | transform_interval: int = 10, 67 | use_cutout_augmentations: bool = True, 68 | device: Optional[torch.device] = None, 69 | ) -> None: 70 | 71 | ### Init ------------------------------------------------------------------- 72 | run = VQGANCLIPRun( 73 | text_input=text_input, 74 | vqgan_ckpt=vqgan_ckpt, 75 | num_steps=num_steps, 76 | image_x=image_x, 77 | image_y=image_y, 78 | seed=seed, 79 | init_image=init_image, 80 | image_prompts=image_prompts, 81 | continue_prev_run=continue_prev_run, 82 | mse_weight=mse_weight, 83 | mse_weight_decay=mse_weight_decay, 84 | mse_weight_decay_steps=mse_weight_decay_steps, 85 | tv_loss_weight=tv_loss_weight, 86 | use_scrolling_zooming=use_scrolling_zooming, 87 | translation_x=translation_x, 88 | translation_y=translation_y, 89 | rotation_angle=rotation_angle, 90 | zoom_factor=zoom_factor, 91 | transform_interval=transform_interval, 92 | use_cutout_augmentations=use_cutout_augmentations, 93 | device=device, 94 | ) 95 | 96 | ### Load model ------------------------------------------------------------- 97 | 98 | if continue_prev_run is True: 99 | run.load_model( 100 | prev_model=st.session_state["model"], 101 | prev_perceptor=st.session_state["perceptor"], 102 | ) 103 | prev_run_id = st.session_state["run_id"] 104 | 105 | else: 106 | # Remove the 
cache first! CUDA out of memory 107 | if "model" in st.session_state: 108 | del st.session_state["model"] 109 | 110 | if "perceptor" in st.session_state: 111 | del st.session_state["perceptor"] 112 | 113 | st.session_state["model"], st.session_state["perceptor"] = run.load_model() 114 | prev_run_id = None 115 | 116 | # Generate random run ID 117 | # Used to link runs linked w/ continue_prev_run 118 | # ref: https://stackoverflow.com/a/42703382/13095028 119 | # Use URL and filesystem safe version since we're using this as a folder name 120 | run_id = st.session_state["run_id"] = base64.urlsafe_b64encode( 121 | os.urandom(6) 122 | ).decode("ascii") 123 | 124 | run_start_dt = datetime.datetime.now() 125 | 126 | ### Model init ------------------------------------------------------------- 127 | if continue_prev_run is True: 128 | run.model_init(init_image=st.session_state["prev_im"]) 129 | elif init_image is not None: 130 | run.model_init(init_image=init_image) 131 | else: 132 | run.model_init() 133 | 134 | ### Iterate ---------------------------------------------------------------- 135 | step_counter = 0 136 | frames = [] 137 | 138 | try: 139 | # Try block catches st.script_runner.StopExecution, no need of a dedicated stop button 140 | # Reason is st.form is meant to be self-contained either within sidebar, or in main body 141 | # The way the form is implemented in this app splits the form across both regions 142 | # This is intended to prevent the model settings from crowding the main body 143 | # However, touching any button resets the app state, making it impossible to 144 | # implement a stop button that can still dump output 145 | # Thankfully there's a built-in stop button :) 146 | while True: 147 | # While loop to accomodate running predetermined steps or running indefinitely 148 | status_text.text(f"Running step {step_counter}") 149 | 150 | _, im = run.iterate() 151 | 152 | if num_steps > 0: # skip when num_steps = -1 153 | step_progress_bar.progress((step_counter + 1) / num_steps) 154 | else: 155 | step_progress_bar.progress(100) 156 | 157 | # At every step, display and save image 158 | im_display_slot.image(im, caption="Output image", output_format="PNG") 159 | st.session_state["prev_im"] = im 160 | 161 | # ref: https://stackoverflow.com/a/33117447/13095028 162 | # im_byte_arr = io.BytesIO() 163 | # im.save(im_byte_arr, format="JPEG") 164 | # frames.append(im_byte_arr.getvalue()) # read() 165 | frames.append(np.asarray(im)) 166 | 167 | step_counter += 1 168 | 169 | if (step_counter == num_steps) and num_steps > 0: 170 | break 171 | 172 | # Stitch into video using imageio 173 | writer = imageio.get_writer("temp.mp4", fps=24) 174 | for frame in frames: 175 | writer.append_data(frame) 176 | writer.close() 177 | 178 | # Save to output folder if run completed 179 | runoutputdir = outputdir / ( 180 | run_start_dt.strftime("%Y%m%dT%H%M%S") + "-" + run_id 181 | ) 182 | runoutputdir.mkdir() 183 | 184 | # Save final image 185 | im.save(runoutputdir / "output.PNG", format="PNG") 186 | 187 | # Save init image 188 | if init_image is not None: 189 | init_image.save(runoutputdir / "init-image.JPEG", format="JPEG") 190 | 191 | # Save image prompts 192 | for count, image_prompt in enumerate(image_prompts): 193 | image_prompt.save( 194 | runoutputdir / f"image-prompt-{count}.JPEG", format="JPEG" 195 | ) 196 | 197 | # Save animation 198 | shutil.copy("temp.mp4", runoutputdir / "anim.mp4") 199 | 200 | # Save metadata 201 | details = { 202 | "run_id": run_id, 203 | "num_steps": step_counter, 204 | 
"planned_num_steps": num_steps, 205 | "text_input": text_input, 206 | "init_image": False if init_image is None else True, 207 | "image_prompts": False if len(image_prompts) == 0 else True, 208 | "continue_prev_run": continue_prev_run, 209 | "prev_run_id": prev_run_id, 210 | "seed": run.seed, 211 | "Xdim": image_x, 212 | "ydim": image_y, 213 | "vqgan_ckpt": vqgan_ckpt, 214 | "start_time": run_start_dt.strftime("%Y%m%dT%H%M%S"), 215 | "end_time": datetime.datetime.now().strftime("%Y%m%dT%H%M%S"), 216 | "mse_weight": mse_weight, 217 | "mse_weight_decay": mse_weight_decay, 218 | "mse_weight_decay_steps": mse_weight_decay_steps, 219 | "tv_loss_weight": tv_loss_weight, 220 | } 221 | 222 | if use_scrolling_zooming: 223 | details.update( 224 | { 225 | "translation_x": translation_x, 226 | "translation_y": translation_y, 227 | "rotation_angle": rotation_angle, 228 | "zoom_factor": zoom_factor, 229 | "transform_interval": transform_interval, 230 | } 231 | ) 232 | if use_cutout_augmentations: 233 | details["use_cutout_augmentations"] = True 234 | 235 | if "git" in sys.modules: 236 | try: 237 | repo = git.Repo(search_parent_directories=True) 238 | commit_sha = repo.head.object.hexsha 239 | details["commit_sha"] = commit_sha[:6] 240 | except Exception as e: 241 | print("GitPython detected but not able to write commit SHA to file") 242 | print(f"raised Exception {e}") 243 | 244 | with open(runoutputdir / "details.json", "w") as f: 245 | json.dump(details, f, indent=4) 246 | 247 | status_text.text("Done!") # End of run 248 | 249 | except st.script_runner.StopException as e: 250 | # Dump output to dashboard 251 | print(f"Received Streamlit StopException") 252 | status_text.text("Execution interruped, dumping outputs ...") 253 | writer = imageio.get_writer("temp.mp4", fps=24) 254 | for frame in frames: 255 | writer.append_data(frame) 256 | writer.close() 257 | 258 | # TODO: Make the following DRY 259 | # Save to output folder if run completed 260 | runoutputdir = outputdir / ( 261 | run_start_dt.strftime("%Y%m%dT%H%M%S") + "-" + run_id 262 | ) 263 | runoutputdir.mkdir() 264 | 265 | # Save final image 266 | im.save(runoutputdir / "output.PNG", format="PNG") 267 | 268 | # Save init image 269 | if init_image is not None: 270 | init_image.save(runoutputdir / "init-image.JPEG", format="JPEG") 271 | 272 | # Save image prompts 273 | for count, image_prompt in enumerate(image_prompts): 274 | image_prompt.save( 275 | runoutputdir / f"image-prompt-{count}.JPEG", format="JPEG" 276 | ) 277 | 278 | # Save animation 279 | shutil.copy("temp.mp4", runoutputdir / "anim.mp4") 280 | 281 | # Save metadata 282 | details = { 283 | "run_id": run_id, 284 | "num_steps": step_counter, 285 | "planned_num_steps": num_steps, 286 | "text_input": text_input, 287 | "init_image": False if init_image is None else True, 288 | "image_prompts": False if len(image_prompts) == 0 else True, 289 | "continue_prev_run": continue_prev_run, 290 | "prev_run_id": prev_run_id, 291 | "seed": run.seed, 292 | "Xdim": image_x, 293 | "ydim": image_y, 294 | "vqgan_ckpt": vqgan_ckpt, 295 | "start_time": run_start_dt.strftime("%Y%m%dT%H%M%S"), 296 | "end_time": datetime.datetime.now().strftime("%Y%m%dT%H%M%S"), 297 | "mse_weight": mse_weight, 298 | "mse_weight_decay": mse_weight_decay, 299 | "mse_weight_decay_steps": mse_weight_decay_steps, 300 | "tv_loss_weight": tv_loss_weight, 301 | } 302 | 303 | if use_scrolling_zooming: 304 | details.update( 305 | { 306 | "translation_x": translation_x, 307 | "translation_y": translation_y, 308 | "rotation_angle": 
rotation_angle, 309 | "zoom_factor": zoom_factor, 310 | "transform_interval": transform_interval, 311 | } 312 | ) 313 | if use_cutout_augmentations: 314 | details["use_cutout_augmentations"] = True 315 | 316 | if "git" in sys.modules: 317 | try: 318 | repo = git.Repo(search_parent_directories=True) 319 | commit_sha = repo.head.object.hexsha 320 | details["commit_sha"] = commit_sha[:6] 321 | except Exception as e: 322 | print("GitPython detected but not able to write commit SHA to file") 323 | print(f"raised Exception {e}") 324 | 325 | with open(runoutputdir / "details.json", "w") as f: 326 | json.dump(details, f, indent=4) 327 | 328 | status_text.text("Done!") # End of run 329 | 330 | 331 | if __name__ == "__main__": 332 | 333 | # Argparse to capture GPU num 334 | parser = argparse.ArgumentParser() 335 | 336 | parser.add_argument( 337 | "--gpu", type=str, default=None, help="Specify GPU number. Defaults to None." 338 | ) 339 | args = parser.parse_args() 340 | 341 | # Select specific GPU if chosen 342 | if args.gpu is not None: 343 | for i in args.gpu.split(","): 344 | assert ( 345 | int(i) < torch.cuda.device_count() 346 | ), f"You specified --gpu {args.gpu} but torch.cuda.device_count() returned {torch.cuda.device_count()}" 347 | 348 | try: 349 | device = torch.device(f"cuda:{args.gpu}") 350 | except RuntimeError: 351 | print(traceback.format_exc()) 352 | else: 353 | device = None 354 | 355 | defaults = OmegaConf.load("defaults.yaml") 356 | outputdir = Path("output") 357 | if not outputdir.exists(): 358 | outputdir.mkdir() 359 | 360 | st.set_page_config(page_title="VQGAN-CLIP playground") 361 | st.title("VQGAN-CLIP playground") 362 | 363 | # Determine what weights are available in `assets/` 364 | weights_dir = Path("assets").resolve() 365 | available_weight_ckpts = list(weights_dir.glob("*.ckpt")) 366 | available_weight_configs = list(weights_dir.glob("*.yaml")) 367 | available_weights = [ 368 | i.stem 369 | for i in available_weight_ckpts 370 | if i.stem in [j.stem for j in available_weight_configs] 371 | ] 372 | 373 | # i.e. no weights found, ask user to download weights 374 | if len(available_weights) == 0: 375 | st.warning("No weights found in `assets/`, refer to `download-weights.sh`") 376 | st.stop() 377 | 378 | # Set vqgan_imagenet_f16_1024 as default if possible 379 | if "vqgan_imagenet_f16_1024" in available_weights: 380 | default_weight_index = available_weights.index("vqgan_imagenet_f16_1024") 381 | else: 382 | default_weight_index = 0 383 | 384 | # Start of input form 385 | with st.form("form-inputs"): 386 | # Only element not in the sidebar, but in the form 387 | text_input = st.text_input( 388 | "Text prompt", 389 | help="VQGAN-CLIP will generate an image that best fits the prompt", 390 | ) 391 | radio = st.sidebar.radio( 392 | "Model weights", 393 | available_weights, 394 | index=default_weight_index, 395 | help="Choose which weights to load, trained on different datasets. Make sure the weights and configs are downloaded to `assets/` as per the README!", 396 | ) 397 | num_steps = st.sidebar.number_input( 398 | "Num steps", 399 | value=defaults["num_steps"], 400 | min_value=-1, 401 | max_value=None, 402 | step=1, 403 | help="Specify -1 to run indefinitely. Use Streamlit's stop button in the top right corner to terminate execution. 
The exception is caught so the most recent output will be dumped to dashboard", 404 | ) 405 | 406 | image_x = st.sidebar.number_input( 407 | "Xdim", value=defaults["Xdim"], help="Width of output image, in pixels" 408 | ) 409 | image_y = st.sidebar.number_input( 410 | "ydim", value=defaults["ydim"], help="Height of output image, in pixels" 411 | ) 412 | set_seed = st.sidebar.checkbox( 413 | "Set seed", 414 | value=defaults["set_seed"], 415 | help="Check to set random seed for reproducibility. Will add option to specify seed", 416 | ) 417 | 418 | seed_widget = st.sidebar.empty() 419 | if set_seed is True: 420 | # Use text_input as number_input relies on JS 421 | # which can't natively handle large numbers 422 | # torch.seed() generates int w/ 19 or 20 chars! 423 | seed_str = seed_widget.text_input( 424 | "Seed", value=str(defaults["seed"]), help="Random seed to use" 425 | ) 426 | try: 427 | seed = int(seed_str) 428 | except ValueError as e: 429 | st.error("seed input needs to be int") 430 | else: 431 | seed = None 432 | 433 | use_custom_starting_image = st.sidebar.checkbox( 434 | "Use starting image", 435 | value=defaults["use_starting_image"], 436 | help="Check to add a starting image to the network", 437 | ) 438 | 439 | starting_image_widget = st.sidebar.empty() 440 | if use_custom_starting_image is True: 441 | init_image = starting_image_widget.file_uploader( 442 | "Upload starting image", 443 | type=["png", "jpeg", "jpg"], 444 | accept_multiple_files=False, 445 | help="Starting image for the network, will be resized to fit specified dimensions", 446 | ) 447 | # Convert from UploadedFile object to PIL Image 448 | if init_image is not None: 449 | init_image: Image.Image = Image.open(init_image).convert( 450 | "RGB" 451 | ) # just to be sure 452 | else: 453 | init_image = None 454 | 455 | use_image_prompts = st.sidebar.checkbox( 456 | "Add image prompt(s)", 457 | value=defaults["use_image_prompts"], 458 | help="Check to add image prompt(s), conditions the network similar to the text prompt", 459 | ) 460 | 461 | image_prompts_widget = st.sidebar.empty() 462 | if use_image_prompts is True: 463 | image_prompts = image_prompts_widget.file_uploader( 464 | "Upload image prompts(s)", 465 | type=["png", "jpeg", "jpg"], 466 | accept_multiple_files=True, 467 | help="Image prompt(s) for the network, will be resized to fit specified dimensions", 468 | ) 469 | # Convert from UploadedFile object to PIL Image 470 | if len(image_prompts) != 0: 471 | image_prompts = [Image.open(i).convert("RGB") for i in image_prompts] 472 | else: 473 | image_prompts = [] 474 | 475 | continue_prev_run = st.sidebar.checkbox( 476 | "Continue previous run", 477 | value=defaults["continue_prev_run"], 478 | help="Use existing image and existing weights for the next run. 
If yes, ignores 'Use starting image'", 479 | ) 480 | 481 | use_mse_reg = st.sidebar.checkbox( 482 | "Use MSE regularization", 483 | value=defaults["use_mse_regularization"], 484 | help="Check to add MSE regularization", 485 | ) 486 | mse_weight_widget = st.sidebar.empty() 487 | mse_weight_decay_widget = st.sidebar.empty() 488 | mse_weight_decay_steps = st.sidebar.empty() 489 | 490 | if use_mse_reg is True: 491 | mse_weight = mse_weight_widget.number_input( 492 | "MSE weight", 493 | value=defaults["mse_weight"], 494 | # min_value=0.0, # leave this out to allow creativity 495 | step=0.05, 496 | help="Set weights for MSE regularization", 497 | ) 498 | mse_weight_decay = mse_weight_decay_widget.number_input( 499 | "Decay MSE weight by ...", 500 | value=defaults["mse_weight_decay"], 501 | # min_value=0.0, # leave this out to allow creativity 502 | step=0.05, 503 | help="Subtracts MSE weight by this amount at every step change. MSE weight change stops at zero", 504 | ) 505 | mse_weight_decay_steps = mse_weight_decay_steps.number_input( 506 | "... every N steps", 507 | value=defaults["mse_weight_decay_steps"], 508 | min_value=0, 509 | step=1, 510 | help="Number of steps to subtract MSE weight. Leave zero for no weight decay", 511 | ) 512 | else: 513 | mse_weight = 0 514 | mse_weight_decay = 0 515 | mse_weight_decay_steps = 0 516 | 517 | use_tv_loss = st.sidebar.checkbox( 518 | "Use TV loss regularization", 519 | value=defaults["use_tv_loss_regularization"], 520 | help="Check to add MSE regularization", 521 | ) 522 | tv_loss_weight_widget = st.sidebar.empty() 523 | if use_tv_loss is True: 524 | tv_loss_weight = tv_loss_weight_widget.number_input( 525 | "TV loss weight", 526 | value=defaults["tv_loss_weight"], 527 | min_value=0.0, 528 | step=1e-4, 529 | help="Set weights for TV loss regularization, which encourages spatial smoothness. 
Ref: https://github.com/jcjohnson/neural-style/issues/302", 530 | format="%.1e", 531 | ) 532 | else: 533 | tv_loss_weight = 0 534 | 535 | use_scrolling_zooming = st.sidebar.checkbox( 536 | "Scrolling/zooming transforms", 537 | value=False, 538 | help="At fixed intervals, move the generated image up/down/left/right or zoom in/out", 539 | ) 540 | translation_x_widget = st.sidebar.empty() 541 | translation_y_widget = st.sidebar.empty() 542 | rotation_angle_widget = st.sidebar.empty() 543 | zoom_factor_widget = st.sidebar.empty() 544 | transform_interval_widget = st.sidebar.empty() 545 | if use_scrolling_zooming is True: 546 | translation_x = translation_x_widget.number_input( 547 | "Translation in X", value=0, min_value=0, step=1 548 | ) 549 | translation_y = translation_y_widget.number_input( 550 | "Translation in y", value=0, min_value=0, step=1 551 | ) 552 | rotation_angle = rotation_angle_widget.number_input( 553 | "Rotation angle (degrees)", 554 | value=0.0, 555 | min_value=0.0, 556 | max_value=360.0, 557 | step=0.05, 558 | format="%.2f", 559 | ) 560 | zoom_factor = zoom_factor_widget.number_input( 561 | "Zoom factor", 562 | value=1.0, 563 | min_value=0.1, 564 | max_value=10.0, 565 | step=0.02, 566 | format="%.2f", 567 | ) 568 | transform_interval = transform_interval_widget.number_input( 569 | "Iterations per frame", 570 | value=10, 571 | min_value=0, 572 | step=1, 573 | help="Note: Will multiply by num steps above!", 574 | ) 575 | else: 576 | translation_x = 0 577 | translation_y = 0 578 | rotation_angle = 0 579 | zoom_factor = 1 580 | transform_interval = 1 581 | 582 | use_cutout_augmentations = st.sidebar.checkbox( 583 | "Use cutout augmentations", 584 | value=True, 585 | help="Adds cutout augmentatinos in the image generation process. Uses up to additional 4 GiB of GPU memory. Greatly improves image quality. Toggled on by default.", 586 | ) 587 | 588 | submitted = st.form_submit_button("Run!") 589 | # End of form 590 | 591 | status_text = st.empty() 592 | status_text.text("Pending input prompt") 593 | step_progress_bar = st.progress(0) 594 | 595 | im_display_slot = st.empty() 596 | vid_display_slot = st.empty() 597 | debug_slot = st.empty() 598 | 599 | if "prev_im" in st.session_state: 600 | im_display_slot.image( 601 | st.session_state["prev_im"], caption="Output image", output_format="PNG" 602 | ) 603 | 604 | with st.expander("Expand for README"): 605 | with open("README.md", "r") as f: 606 | # Preprocess links to redirect to github 607 | # Thank you https://discuss.streamlit.io/u/asehmi, works like a charm! 
608 | # ref: https://discuss.streamlit.io/t/image-in-markdown/13274/8 609 | markdown_links = [str(i) for i in Path("docs/").glob("*.md")] 610 | images = [str(i) for i in Path("docs/images/").glob("*")] 611 | readme_lines = f.readlines() 612 | readme_buffer = [] 613 | 614 | for line in readme_lines: 615 | for md_link in markdown_links: 616 | if md_link in line: 617 | line = line.replace( 618 | md_link, 619 | "https://github.com/tnwei/vqgan-clip-app/tree/main/" 620 | + md_link, 621 | ) 622 | 623 | readme_buffer.append(line) 624 | for image in images: 625 | if image in line: 626 | st.markdown(" ".join(readme_buffer[:-1])) 627 | st.image( 628 | f"https://raw.githubusercontent.com/tnwei/vqgan-clip-app/main/{image}" 629 | ) 630 | readme_buffer.clear() 631 | st.markdown(" ".join(readme_buffer)) 632 | 633 | with st.expander("Expand for CHANGELOG"): 634 | with open("CHANGELOG.md", "r") as f: 635 | st.markdown(f.read()) 636 | 637 | if submitted: 638 | # debug_slot.write(st.session_state) # DEBUG 639 | status_text.text("Loading weights ...") 640 | generate_image( 641 | # Inputs 642 | text_input=text_input, 643 | vqgan_ckpt=radio, 644 | num_steps=num_steps, 645 | image_x=int(image_x), 646 | image_y=int(image_y), 647 | seed=int(seed) if set_seed is True else None, 648 | init_image=init_image, 649 | image_prompts=image_prompts, 650 | continue_prev_run=continue_prev_run, 651 | mse_weight=mse_weight, 652 | mse_weight_decay=mse_weight_decay, 653 | mse_weight_decay_steps=mse_weight_decay_steps, 654 | use_scrolling_zooming=use_scrolling_zooming, 655 | translation_x=translation_x, 656 | translation_y=translation_y, 657 | rotation_angle=rotation_angle, 658 | zoom_factor=zoom_factor, 659 | transform_interval=transform_interval, 660 | use_cutout_augmentations=use_cutout_augmentations, 661 | device=device, 662 | ) 663 | 664 | vid_display_slot.video("temp.mp4") 665 | # debug_slot.write(st.session_state) # DEBUG 666 | -------------------------------------------------------------------------------- /assets/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tnwei/vqgan-clip-app/82aff7993b5f64baacfe626e214398bad2d657d7/assets/.gitkeep -------------------------------------------------------------------------------- /defaults.yaml: -------------------------------------------------------------------------------- 1 | # Modify for different systems, e.g. 
larger default xdim/ydim for more powerful GPUs 2 | num_steps: 500 3 | Xdim: 640 4 | ydim: 480 5 | set_seed: false 6 | seed: 0 7 | use_starting_image: false 8 | use_image_prompts: false 9 | continue_prev_run: false 10 | mse_weight: 0.5 11 | mse_weight_decay: 0.1 12 | mse_weight_decay_steps: 50 13 | use_mse_regularization: false 14 | use_tv_loss_regularization: true 15 | tv_loss_weight: 1e-3 -------------------------------------------------------------------------------- /diffusion_app.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from pathlib import Path 3 | import sys 4 | import datetime 5 | import shutil 6 | import json 7 | import os 8 | import torch 9 | import traceback 10 | import base64 11 | from PIL import Image 12 | from typing import Optional 13 | import argparse 14 | 15 | sys.path.append("./taming-transformers") 16 | 17 | import imageio 18 | import numpy as np 19 | from diffusion_logic import CLIPGuidedDiffusion, DIFFUSION_METHODS_AND_WEIGHTS 20 | 21 | # Optional 22 | try: 23 | import git 24 | except ModuleNotFoundError: 25 | pass 26 | 27 | 28 | def generate_image( 29 | diffusion_weights: str, 30 | prompt: str, 31 | seed=0, 32 | num_steps=500, 33 | continue_prev_run=True, 34 | init_image: Optional[Image.Image] = None, 35 | skip_timesteps: int = 0, 36 | use_cutout_augmentations: bool = False, 37 | device: Optional[torch.device] = None, 38 | ) -> None: 39 | 40 | ### Init ------------------------------------------------------------------- 41 | run = CLIPGuidedDiffusion( 42 | prompt=prompt, 43 | ckpt=diffusion_weights, 44 | seed=seed, 45 | num_steps=num_steps, 46 | continue_prev_run=continue_prev_run, 47 | skip_timesteps=skip_timesteps, 48 | use_cutout_augmentations=use_cutout_augmentations, 49 | device=device, 50 | ) 51 | 52 | # Generate random run ID 53 | # Used to link runs linked w/ continue_prev_run 54 | # ref: https://stackoverflow.com/a/42703382/13095028 55 | # Use URL and filesystem safe version since we're using this as a folder name 56 | run_id = st.session_state["run_id"] = base64.urlsafe_b64encode( 57 | os.urandom(6) 58 | ).decode("ascii") 59 | 60 | if "loaded_wt" not in st.session_state: 61 | st.session_state["loaded_wt"] = None 62 | 63 | run_start_dt = datetime.datetime.now() 64 | 65 | ### Load model ------------------------------------------------------------- 66 | if ( 67 | continue_prev_run 68 | and ("model" in st.session_state) 69 | and ("clip_model" in st.session_state) 70 | and ("diffusion" in st.session_state) 71 | and st.session_state["loaded_wt"] == diffusion_weights 72 | ): 73 | run.load_model( 74 | prev_model=st.session_state["model"], 75 | prev_diffusion=st.session_state["diffusion"], 76 | prev_clip_model=st.session_state["clip_model"], 77 | ) 78 | else: 79 | ( 80 | st.session_state["model"], 81 | st.session_state["diffusion"], 82 | st.session_state["clip_model"], 83 | ) = run.load_model( 84 | model_file_loc="assets/" 85 | + DIFFUSION_METHODS_AND_WEIGHTS.get(diffusion_method) 86 | ) 87 | st.session_state["loaded_wt"] = diffusion_method 88 | 89 | ### Model init ------------------------------------------------------------- 90 | # if continue_prev_run is True: 91 | # run.model_init(init_image=st.session_state["prev_im"]) 92 | # elif init_image is not None: 93 | if init_image is not None: 94 | run.model_init(init_image=init_image) 95 | else: 96 | run.model_init() 97 | 98 | ### Iterate ---------------------------------------------------------------- 99 | step_counter = 0 + skip_timesteps 100 | 
frames = [] 101 | 102 | try: 103 | # Try block catches st.script_runner.StopExecution, no need of a dedicated stop button 104 | # Reason is st.form is meant to be self-contained either within sidebar, or in main body 105 | # The way the form is implemented in this app splits the form across both regions 106 | # This is intended to prevent the model settings from crowding the main body 107 | # However, touching any button resets the app state, making it impossible to 108 | # implement a stop button that can still dump output 109 | # Thankfully there's a built-in stop button :) 110 | while True: 111 | # While loop to accomodate running predetermined steps or running indefinitely 112 | status_text.text(f"Running step {step_counter}") 113 | 114 | ims = run.iterate() 115 | im = ims[0] 116 | 117 | if num_steps > 0: # skip when num_steps = -1 118 | step_progress_bar.progress((step_counter + 1) / num_steps) 119 | else: 120 | step_progress_bar.progress(100) 121 | 122 | # At every step, display and save image 123 | im_display_slot.image(im, caption="Output image", output_format="PNG") 124 | st.session_state["prev_im"] = im 125 | 126 | # ref: https://stackoverflow.com/a/33117447/13095028 127 | # im_byte_arr = io.BytesIO() 128 | # im.save(im_byte_arr, format="JPEG") 129 | # frames.append(im_byte_arr.getvalue()) # read() 130 | frames.append(np.asarray(im)) 131 | 132 | step_counter += 1 133 | 134 | if (step_counter == num_steps) and num_steps > 0: 135 | break 136 | 137 | # Stitch into video using imageio 138 | writer = imageio.get_writer("temp.mp4", fps=24) 139 | for frame in frames: 140 | writer.append_data(frame) 141 | writer.close() 142 | 143 | # Save to output folder if run completed 144 | runoutputdir = outputdir / ( 145 | run_start_dt.strftime("%Y%m%dT%H%M%S") + "-" + run_id 146 | ) 147 | runoutputdir.mkdir() 148 | 149 | im.save(runoutputdir / "output.PNG", format="PNG") 150 | shutil.copy("temp.mp4", runoutputdir / "anim.mp4") 151 | 152 | details = { 153 | "run_id": run_id, 154 | "diffusion_method": diffusion_method, 155 | "ckpt": DIFFUSION_METHODS_AND_WEIGHTS.get(diffusion_method), 156 | "num_steps": step_counter, 157 | "planned_num_steps": num_steps, 158 | "text_input": prompt, 159 | "continue_prev_run": continue_prev_run, 160 | "seed": seed, 161 | "Xdim": imsize, 162 | "ydim": imsize, 163 | "start_time": run_start_dt.strftime("%Y%m%dT%H%M%S"), 164 | "end_time": datetime.datetime.now().strftime("%Y%m%dT%H%M%S"), 165 | } 166 | 167 | if use_cutout_augmentations: 168 | details["use_cutout_augmentations"] = True 169 | 170 | if "git" in sys.modules: 171 | try: 172 | repo = git.Repo(search_parent_directories=True) 173 | commit_sha = repo.head.object.hexsha 174 | details["commit_sha"] = commit_sha[:6] 175 | except Exception as e: 176 | print("GitPython detected but not able to write commit SHA to file") 177 | print(f"raised Exception {e}") 178 | 179 | with open(runoutputdir / "details.json", "w") as f: 180 | json.dump(details, f, indent=4) 181 | 182 | status_text.text("Done!") # End of run 183 | 184 | except st.script_runner.StopException as e: 185 | # Dump output to dashboard 186 | print(f"Received Streamlit StopException") 187 | status_text.text("Execution interruped, dumping outputs ...") 188 | writer = imageio.get_writer("temp.mp4", fps=24) 189 | for frame in frames: 190 | writer.append_data(frame) 191 | writer.close() 192 | 193 | # Save to output folder if run completed 194 | runoutputdir = outputdir / ( 195 | run_start_dt.strftime("%Y%m%dT%H%M%S") + "-" + run_id 196 | ) 197 | runoutputdir.mkdir() 
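        # Note: this is the interrupted-run branch, so the image, animation and
        # metadata saved below reflect progress up to the point of interruption;
        # "num_steps" in details.json records the steps actually completed.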
198 | 199 | im.save(runoutputdir / "output.PNG", format="PNG") 200 | shutil.copy("temp.mp4", runoutputdir / "anim.mp4") 201 | 202 | details = { 203 | "run_id": run_id, 204 | "diffusion_method": diffusion_method, 205 | "ckpt": DIFFUSION_METHODS_AND_WEIGHTS.get(diffusion_method), 206 | "num_steps": step_counter, 207 | "planned_num_steps": num_steps, 208 | "text_input": prompt, 209 | "continue_prev_run": continue_prev_run, 210 | "seed": seed, 211 | "Xdim": imsize, 212 | "ydim": imsize, 213 | "start_time": run_start_dt.strftime("%Y%m%dT%H%M%S"), 214 | "end_time": datetime.datetime.now().strftime("%Y%m%dT%H%M%S"), 215 | } 216 | 217 | if use_cutout_augmentations: 218 | details["use_cutout_augmentations"] = True 219 | 220 | if "git" in sys.modules: 221 | try: 222 | repo = git.Repo(search_parent_directories=True) 223 | commit_sha = repo.head.object.hexsha 224 | details["commit_sha"] = commit_sha[:6] 225 | except Exception as e: 226 | print("GitPython detected but not able to write commit SHA to file") 227 | print(f"raised Exception {e}") 228 | 229 | with open(runoutputdir / "details.json", "w") as f: 230 | json.dump(details, f, indent=4) 231 | 232 | status_text.text("Done!") # End of run 233 | 234 | 235 | if __name__ == "__main__": 236 | # Argparse to capture GPU num 237 | parser = argparse.ArgumentParser() 238 | 239 | parser.add_argument( 240 | "--gpu", type=str, default=None, help="Specify GPU number. Defaults to None." 241 | ) 242 | args = parser.parse_args() 243 | 244 | # Select specific GPU if chosen 245 | if args.gpu is not None: 246 | for i in args.gpu.split(","): 247 | assert ( 248 | int(i) < torch.cuda.device_count() 249 | ), f"You specified --gpu {args.gpu} but torch.cuda.device_count() returned {torch.cuda.device_count()}" 250 | 251 | try: 252 | device = torch.device(f"cuda:{args.gpu}") 253 | except RuntimeError: 254 | print(traceback.format_exc()) 255 | else: 256 | device = None 257 | 258 | outputdir = Path("output") 259 | if not outputdir.exists(): 260 | outputdir.mkdir() 261 | 262 | st.set_page_config(page_title="CLIP guided diffusion playground") 263 | st.title("CLIP guided diffusion playground") 264 | 265 | # Determine what weights are available in `assets/` 266 | weights_dir = Path("assets").resolve() 267 | available_diffusion_weights = list(weights_dir.glob("*.pt")) 268 | available_diffusion_weights = [i.name for i in available_diffusion_weights] 269 | diffusion_weights_and_methods = { 270 | j: i for i, j in DIFFUSION_METHODS_AND_WEIGHTS.items() 271 | } 272 | available_diffusion_methods = [ 273 | diffusion_weights_and_methods.get(i) for i in available_diffusion_weights 274 | ] 275 | 276 | # i.e. no weights found, ask user to download weights 277 | if len(available_diffusion_methods) == 0: 278 | st.warning( 279 | "No weights found, download diffusion weights in `download-diffusion-weights.sh`. 
" 280 | ) 281 | st.stop() 282 | 283 | # Start of input form 284 | with st.form("form-inputs"): 285 | # Only element not in the sidebar, but in the form 286 | 287 | text_input = st.text_input( 288 | "Text prompt", 289 | help="CLIP-guided diffusion will generate an image that best fits the prompt", 290 | ) 291 | 292 | diffusion_method = st.sidebar.radio( 293 | "Method", 294 | available_diffusion_methods, 295 | index=0, 296 | help="Choose diffusion image generation method, corresponding to the notebooks in Eleuther's repo", 297 | ) 298 | 299 | if diffusion_method.startswith("256"): 300 | image_size_notice = st.sidebar.text("Image size: fixed to 256x256") 301 | imsize = 256 302 | elif diffusion_method.startswith("512"): 303 | image_size_notice = st.sidebar.text("Image size: fixed to 512x512") 304 | imsize = 512 305 | 306 | set_seed = st.sidebar.checkbox( 307 | "Set seed", 308 | value=0, 309 | help="Check to set random seed for reproducibility. Will add option to specify seed", 310 | ) 311 | num_steps = st.sidebar.number_input( 312 | "Num steps", 313 | value=1000, 314 | min_value=0, 315 | max_value=None, 316 | step=1, 317 | # help="Specify -1 to run indefinitely. Use Streamlit's stop button in the top right corner to terminate execution. The exception is caught so the most recent output will be dumped to dashboard", 318 | ) 319 | 320 | seed_widget = st.sidebar.empty() 321 | if set_seed is True: 322 | seed = seed_widget.number_input("Seed", value=0, help="Random seed to use") 323 | else: 324 | seed = None 325 | 326 | use_custom_reference_image = st.sidebar.checkbox( 327 | "Use reference image", 328 | value=False, 329 | help="Check to add a reference image. The network will attempt to match the generated image to the provided reference", 330 | ) 331 | 332 | reference_image_widget = st.sidebar.empty() 333 | skip_timesteps_widget = st.sidebar.empty() 334 | if use_custom_reference_image is True: 335 | reference_image = reference_image_widget.file_uploader( 336 | "Upload reference image", 337 | type=["png", "jpeg", "jpg"], 338 | accept_multiple_files=False, 339 | help="Reference image for the network, will be resized to fit specified dimensions", 340 | ) 341 | # Convert from UploadedFile object to PIL Image 342 | if reference_image is not None: 343 | reference_image: Image.Image = Image.open(reference_image).convert( 344 | "RGB" 345 | ) # just to be sure 346 | skip_timesteps = skip_timesteps_widget.number_input( 347 | "Skip timesteps (suggested 200-500)", 348 | value=200, 349 | help="Higher values make the output look more like the reference image", 350 | ) 351 | else: 352 | reference_image = None 353 | skip_timesteps = 0 354 | 355 | continue_prev_run = st.sidebar.checkbox( 356 | "Skip init if models are loaded", 357 | value=True, 358 | help="Skips lengthy model init", 359 | ) 360 | 361 | use_cutout_augmentations = st.sidebar.checkbox( 362 | "Use cutout augmentations", 363 | value=False, 364 | help="Adds cutout augmentations in the image generation process. Uses additional 1-2 GiB of GPU memory. Increases image quality, but probably not noticeable for guided diffusion since it's already pretty HQ and consumes a lot of VRAM, but feel free to experiment. Will significantly change image composition if toggled on vs toggled off. 
Toggled off by default.", 365 | ) 366 | 367 | submitted = st.form_submit_button("Run!") 368 | # End of form 369 | 370 | status_text = st.empty() 371 | status_text.text("Pending input prompt") 372 | step_progress_bar = st.progress(0) 373 | 374 | im_display_slot = st.empty() 375 | vid_display_slot = st.empty() 376 | debug_slot = st.empty() 377 | 378 | if "prev_im" in st.session_state: 379 | im_display_slot.image( 380 | st.session_state["prev_im"], caption="Output image", output_format="PNG" 381 | ) 382 | 383 | with st.expander("Expand for README"): 384 | with open("README.md", "r") as f: 385 | # Preprocess links to redirect to github 386 | # Thank you https://discuss.streamlit.io/u/asehmi, works like a charm! 387 | # ref: https://discuss.streamlit.io/t/image-in-markdown/13274/8 388 | markdown_links = [str(i) for i in Path("docs/").glob("*.md")] 389 | images = [str(i) for i in Path("docs/images/").glob("*")] 390 | readme_lines = f.readlines() 391 | readme_buffer = [] 392 | 393 | for line in readme_lines: 394 | for md_link in markdown_links: 395 | if md_link in line: 396 | line = line.replace( 397 | md_link, 398 | "https://github.com/tnwei/vqgan-clip-app/tree/main/" 399 | + md_link, 400 | ) 401 | 402 | readme_buffer.append(line) 403 | for image in images: 404 | if image in line: 405 | st.markdown(" ".join(readme_buffer[:-1])) 406 | st.image( 407 | f"https://raw.githubusercontent.com/tnwei/vqgan-clip-app/main/{image}" 408 | ) 409 | readme_buffer.clear() 410 | st.markdown(" ".join(readme_buffer)) 411 | 412 | with st.expander("Expand for CHANGELOG"): 413 | with open("CHANGELOG.md", "r") as f: 414 | st.markdown(f.read()) 415 | 416 | if submitted: 417 | # debug_slot.write(st.session_state) # DEBUG 418 | status_text.text("Loading weights ...") 419 | generate_image( 420 | diffusion_weights=diffusion_method, 421 | prompt=text_input, 422 | seed=seed, 423 | num_steps=num_steps, 424 | continue_prev_run=continue_prev_run, 425 | init_image=reference_image, 426 | skip_timesteps=skip_timesteps, 427 | use_cutout_augmentations=use_cutout_augmentations, 428 | device=device, 429 | ) 430 | vid_display_slot.video("temp.mp4") 431 | # debug_slot.write(st.session_state) # DEBUG 432 | -------------------------------------------------------------------------------- /diffusion_logic.py: -------------------------------------------------------------------------------- 1 | import clip 2 | import sys 3 | import torch 4 | from torchvision import transforms 5 | from torchvision.transforms import functional as TF 6 | from torch import nn 7 | from torch.nn import functional as F 8 | import lpips 9 | from PIL import Image 10 | import kornia.augmentation as K 11 | from typing import Optional 12 | 13 | sys.path.append("./guided-diffusion") 14 | 15 | from guided_diffusion.script_util import ( 16 | create_model_and_diffusion, 17 | model_and_diffusion_defaults, 18 | ) 19 | 20 | DIFFUSION_METHODS_AND_WEIGHTS = { 21 | # "CLIP Guided Diffusion 256x256", 22 | "256x256 HQ Uncond": "256x256_diffusion_uncond.pt", 23 | "512x512 HQ Cond": "512x512_diffusion.pt", 24 | "512x512 HQ Uncond": "512x512_diffusion_uncond_finetune_008100.pt", 25 | } 26 | 27 | 28 | def spherical_dist_loss(x, y): 29 | x = F.normalize(x, dim=-1) 30 | y = F.normalize(y, dim=-1) 31 | return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) 32 | 33 | 34 | def parse_prompt(prompt): 35 | vals = prompt.rsplit(":", 1) 36 | vals = vals + ["", "1"][len(vals) :] 37 | return vals[0], float(vals[1]) 38 | 39 | 40 | class MakeCutouts(nn.Module): 41 | def __init__(self, cut_size, 
cutn, cut_pow=1.0, noise_fac=None, augs=None): 42 | super().__init__() 43 | self.cut_size = cut_size 44 | self.cutn = cutn 45 | self.cut_pow = cut_pow 46 | self.noise_fac = noise_fac 47 | self.augs = augs 48 | 49 | def forward(self, input): 50 | sideY, sideX = input.shape[2:4] 51 | max_size = min(sideX, sideY) 52 | min_size = min(sideX, sideY, self.cut_size) 53 | cutouts = [] 54 | for _ in range(self.cutn): 55 | size = int( 56 | torch.rand([]) ** self.cut_pow * (max_size - min_size) + min_size 57 | ) 58 | offsetx = torch.randint(0, sideX - size + 1, ()) 59 | offsety = torch.randint(0, sideY - size + 1, ()) 60 | cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size] 61 | cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size)) 62 | 63 | if self.augs: 64 | batch = self.augs(torch.cat(cutouts, dim=0)) 65 | else: 66 | batch = torch.cat(cutouts, dim=0) 67 | 68 | if self.noise_fac: 69 | facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac) 70 | batch = batch + facs * torch.randn_like(batch) 71 | 72 | return batch 73 | 74 | 75 | def tv_loss(input): 76 | """L2 total variation loss, as in Mahendran et al.""" 77 | input = F.pad(input, (0, 1, 0, 1), "replicate") 78 | x_diff = input[..., :-1, 1:] - input[..., :-1, :-1] 79 | y_diff = input[..., 1:, :-1] - input[..., :-1, :-1] 80 | return (x_diff ** 2 + y_diff ** 2).mean([1, 2, 3]) 81 | 82 | 83 | def range_loss(input): 84 | return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3]) 85 | 86 | 87 | class CLIPGuidedDiffusion: 88 | def __init__( 89 | self, 90 | prompt: str, 91 | ckpt: str, 92 | batch_size: int = 1, 93 | clip_guidance_scale: float = 1000, 94 | seed: int = 0, 95 | num_steps: int = 1000, 96 | continue_prev_run: bool = True, 97 | skip_timesteps: int = 0, 98 | use_cutout_augmentations: bool = False, 99 | device: Optional[torch.device] = None, 100 | ) -> None: 101 | 102 | assert ckpt in DIFFUSION_METHODS_AND_WEIGHTS.keys() 103 | self.ckpt = ckpt 104 | print(self.ckpt) 105 | 106 | # Default config 107 | self.model_config = model_and_diffusion_defaults() 108 | self.model_config.update( 109 | { 110 | "attention_resolutions": "32, 16, 8", 111 | "class_cond": True if ckpt == "512x512 HQ Cond" else False, 112 | "diffusion_steps": num_steps, 113 | "rescale_timesteps": True, 114 | "timestep_respacing": str( 115 | num_steps 116 | ), # modify this to decrease timesteps 117 | "image_size": 512 if ckpt.startswith("512") else 256, 118 | "learn_sigma": True, 119 | "noise_schedule": "linear", 120 | "num_channels": 256, 121 | "num_head_channels": 64, 122 | "num_res_blocks": 2, 123 | "resblock_updown": True, 124 | "use_checkpoint": False, 125 | "use_fp16": True, 126 | "use_scale_shift_norm": True, 127 | } 128 | ) 129 | # Split text by "|" symbol 130 | self.prompts = [phrase.strip() for phrase in prompt.split("|")] 131 | if self.prompts == [""]: 132 | self.prompts = [] 133 | 134 | self.image_prompts = [] # TODO 135 | self.batch_size = batch_size 136 | 137 | # Controls how much the image should look like the prompt. 138 | self.clip_guidance_scale = clip_guidance_scale 139 | 140 | # Controls the smoothness of the final output. 141 | self.tv_scale = 150 # TODO add control widget 142 | 143 | # Controls how far out of range RGB values are allowed to be. 
144 | self.range_scale = 50 # TODO add control widget 145 | 146 | self.cutn = 32 # TODO add control widget 147 | self.cutn_batches = 2 # TODO add control widget 148 | self.cut_pow = 0.5 # TODO add control widget 149 | 150 | # Removed, repeat batches by triggering a new run 151 | # self.n_batches = 1 152 | 153 | # This enhances the effect of the init image, a good value is 1000. 154 | self.init_scale = 1000 # TODO add control widget 155 | 156 | # This needs to be between approx. 200 and 500 when using an init image. 157 | # Higher values make the output look more like the init. 158 | self.skip_timesteps = skip_timesteps # TODO add control widget 159 | 160 | self.seed = seed 161 | self.continue_prev_run = continue_prev_run 162 | 163 | self.use_cutout_augmentations = use_cutout_augmentations 164 | 165 | if device is None: 166 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 167 | else: 168 | self.device = device 169 | 170 | print("Using device:", self.device) 171 | 172 | def load_model( 173 | self, 174 | model_file_loc="assets/256x256_diffusion_uncond.pt", 175 | prev_model=None, 176 | prev_diffusion=None, 177 | prev_clip_model=None, 178 | ) -> None: 179 | if ( 180 | self.continue_prev_run is True 181 | and prev_model is not None 182 | and prev_diffusion is not None 183 | and prev_clip_model is not None 184 | ): 185 | self.model = prev_model 186 | self.diffusion = prev_diffusion 187 | self.clip_model = prev_clip_model 188 | 189 | self.clip_size = self.clip_model.visual.input_resolution 190 | self.normalize = transforms.Normalize( 191 | mean=[0.48145466, 0.4578275, 0.40821073], 192 | std=[0.26862954, 0.26130258, 0.27577711], 193 | ) 194 | 195 | else: 196 | self.model, self.diffusion = create_model_and_diffusion(**self.model_config) 197 | self.model.load_state_dict(torch.load(model_file_loc, map_location="cpu")) 198 | self.model.eval().requires_grad_(False).to(self.device) 199 | 200 | if self.ckpt == "512x512 HQ Cond": 201 | for name, param in self.model.named_parameters(): 202 | if "qkv" in name or "norm" in name or "proj" in name: 203 | param.requires_grad_() 204 | 205 | if self.model_config["use_fp16"]: 206 | self.model.convert_to_fp16() 207 | 208 | self.clip_model = ( 209 | clip.load("ViT-B/16", jit=False)[0] 210 | .eval() 211 | .requires_grad_(False) 212 | .to(self.device) 213 | ) 214 | 215 | self.clip_size = self.clip_model.visual.input_resolution 216 | self.normalize = transforms.Normalize( 217 | mean=[0.48145466, 0.4578275, 0.40821073], 218 | std=[0.26862954, 0.26130258, 0.27577711], 219 | ) 220 | 221 | return self.model, self.diffusion, self.clip_model 222 | 223 | def cond_fn_conditional(self, x, t, y=None): 224 | # From 512 HQ notebook using OpenAI's conditional 512x512 model 225 | # TODO: Merge with cond_fn's cutn_batches 226 | with torch.enable_grad(): 227 | x = x.detach().requires_grad_() 228 | n = x.shape[0] 229 | my_t = torch.ones([n], device=self.device, dtype=torch.long) * self.cur_t 230 | out = self.diffusion.p_mean_variance( 231 | self.model, x, my_t, clip_denoised=False, model_kwargs={"y": y} 232 | ) 233 | fac = self.diffusion.sqrt_one_minus_alphas_cumprod[self.cur_t] 234 | x_in = out["pred_xstart"] * fac + x * (1 - fac) 235 | clip_in = self.normalize(self.make_cutouts(x_in.add(1).div(2))) 236 | image_embeds = ( 237 | self.clip_model.encode_image(clip_in).float().view([self.cutn, n, -1]) 238 | ) 239 | dists = spherical_dist_loss(image_embeds, self.target_embeds.unsqueeze(0)) 240 | losses = dists.mean(0) 241 | tv_losses = tv_loss(x_in) 242 | loss 
= ( 243 | losses.sum() * self.clip_guidance_scale 244 | + tv_losses.sum() * self.tv_scale 245 | ) 246 | # TODO: Implement init image 247 | return -torch.autograd.grad(loss, x)[0] 248 | 249 | def cond_fn(self, x, t, out, y=None): 250 | n = x.shape[0] 251 | fac = self.diffusion.sqrt_one_minus_alphas_cumprod[self.cur_t] 252 | x_in = out["pred_xstart"] * fac + x * (1 - fac) 253 | x_in_grad = torch.zeros_like(x_in) 254 | for i in range(self.cutn_batches): 255 | clip_in = self.normalize(self.make_cutouts(x_in.add(1).div(2))) 256 | image_embeds = self.clip_model.encode_image(clip_in).float() 257 | dists = spherical_dist_loss( 258 | image_embeds.unsqueeze(1), self.target_embeds.unsqueeze(0) 259 | ) 260 | dists = dists.view([self.cutn, n, -1]) 261 | losses = dists.mul(self.weights).sum(2).mean(0) 262 | x_in_grad += ( 263 | torch.autograd.grad(losses.sum() * self.clip_guidance_scale, x_in)[0] 264 | / self.cutn_batches 265 | ) 266 | tv_losses = tv_loss(x_in) 267 | range_losses = range_loss(out["pred_xstart"]) 268 | loss = tv_losses.sum() * self.tv_scale + range_losses.sum() * self.range_scale 269 | if self.init is not None and self.init_scale: 270 | init_losses = self.lpips_model(x_in, self.init) 271 | loss = loss + init_losses.sum() * self.init_scale 272 | x_in_grad += torch.autograd.grad(loss, x_in)[0] 273 | grad = -torch.autograd.grad(x_in, x, x_in_grad)[0] 274 | return grad 275 | 276 | def model_init(self, init_image: Image.Image = None) -> None: 277 | if self.seed is not None: 278 | torch.manual_seed(self.seed) 279 | else: 280 | self.seed = torch.seed() # Trigger a seed, retrieve the utilized seed 281 | 282 | if self.use_cutout_augmentations: 283 | noise_fac = 0.1 284 | augs = nn.Sequential( 285 | K.RandomHorizontalFlip(p=0.5), 286 | K.RandomSharpness(0.3, p=0.4), 287 | K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode="border"), 288 | K.RandomPerspective(0.2, p=0.4), 289 | K.ColorJitter(hue=0.01, saturation=0.01, p=0.7), 290 | ) 291 | else: 292 | noise_fac = None 293 | augs = None 294 | 295 | self.make_cutouts = MakeCutouts( 296 | self.clip_size, self.cutn, self.cut_pow, noise_fac=noise_fac, augs=augs 297 | ) 298 | self.side_x = self.side_y = self.model_config["image_size"] 299 | 300 | self.target_embeds, self.weights = [], [] 301 | 302 | for prompt in self.prompts: 303 | txt, weight = parse_prompt(prompt) 304 | self.target_embeds.append( 305 | self.clip_model.encode_text(clip.tokenize(txt).to(self.device)).float() 306 | ) 307 | self.weights.append(weight) 308 | 309 | # TODO: Implement image prompt parsing 310 | # for prompt in self.image_prompts: 311 | # path, weight = parse_prompt(prompt) 312 | # img = Image.open(fetch(path)).convert('RGB') 313 | # img = TF.resize(img, min(side_x, side_y, *img.size), transforms.InterpolationMode.LANCZOS) 314 | # batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device)) 315 | # embed = clip_model.encode_image(normalize(batch)).float() 316 | # target_embeds.append(embed) 317 | # weights.extend([weight / cutn] * cutn) 318 | 319 | self.target_embeds = torch.cat(self.target_embeds) 320 | self.weights = torch.tensor(self.weights, device=self.device) 321 | if self.weights.sum().abs() < 1e-3: 322 | raise RuntimeError("The weights must not sum to 0.") 323 | self.weights /= self.weights.sum().abs() 324 | 325 | self.init = None 326 | if init_image is not None: 327 | self.init = init_image.resize((self.side_x, self.side_y), Image.LANCZOS) 328 | self.init = ( 329 | TF.to_tensor(self.init).to(self.device).unsqueeze(0).mul(2).sub(1) 330 | ) 331 | 332 | # 
LPIPS not required if init_image not used! 333 | if self.init is None: 334 | self.lpips_model = None 335 | else: 336 | self.lpips_model = lpips.LPIPS(net="vgg").to(self.device) 337 | 338 | if self.model_config["timestep_respacing"].startswith("ddim"): 339 | sample_fn = self.diffusion.ddim_sample_loop_progressive 340 | else: 341 | sample_fn = self.diffusion.p_sample_loop_progressive 342 | 343 | self.cur_t = self.diffusion.num_timesteps - self.skip_timesteps - 1 344 | 345 | if self.ckpt == "512x512 HQ Cond": 346 | print("Using conditional sampling fn") 347 | self.samples = sample_fn( 348 | self.model, 349 | (self.batch_size, 3, self.side_y, self.side_x), 350 | clip_denoised=False, 351 | model_kwargs={ 352 | "y": torch.zeros( 353 | [self.batch_size], device=self.device, dtype=torch.long 354 | ) 355 | }, 356 | cond_fn=self.cond_fn_conditional, 357 | progress=True, 358 | skip_timesteps=self.skip_timesteps, 359 | init_image=self.init, 360 | randomize_class=True, 361 | ) 362 | else: 363 | print("Using unconditional sampling fn") 364 | self.samples = sample_fn( 365 | self.model, 366 | (self.batch_size, 3, self.side_y, self.side_x), 367 | clip_denoised=False, 368 | model_kwargs={}, 369 | cond_fn=self.cond_fn, 370 | progress=True, 371 | skip_timesteps=self.skip_timesteps, 372 | init_image=self.init, 373 | randomize_class=True, 374 | cond_fn_with_grad=True, 375 | ) 376 | 377 | self.samplesgen = enumerate(self.samples) 378 | 379 | def iterate(self): 380 | self.cur_t -= 1 381 | _, sample = next(self.samplesgen) 382 | 383 | ims = [] 384 | for _, image in enumerate(sample["pred_xstart"]): 385 | im = TF.to_pil_image(image.add(1).div(2).clamp(0, 1)) 386 | ims.append(im) 387 | 388 | return ims 389 | -------------------------------------------------------------------------------- /docs/architecture.md: -------------------------------------------------------------------------------- 1 | # Architecture 2 | 3 | ## App structure 4 | 5 | There are three major components in this repo: 6 | 7 | + VQGAN-CLIP app 8 | + `app.py` houses the UI 9 | + `logic.py` stores underlying logic 10 | + `vqgan_utils.py` stores utility functions used by `logic.py` 11 | + CLIP guided diffusion app 12 | + `diffusion_app.py` houses the UI 13 | + `diffusion_logic.py` stores underlying logic 14 | + Gallery viewer 15 | + `gallery.py` houses the UI 16 | + `templates/index.html` houses the HTML page template 17 | 18 | To date, all three components are independent and do not have shared dependencies. 19 | 20 | The UI for both image generation apps are built using Streamlit ([docs](https://docs.streamlit.io/en/stable/index.html)), which makes it easy to throw together dashboards for ML projects in a short amount of time. The gallery viewer is a simple Flask dashboard. 21 | 22 | ## Customizing this repo 23 | 24 | Defaults settings for the app upon launch are specified in `defaults.yaml`, which can be adjusted as necessary. 25 | 26 | To use customized weights for VQGAN-CLIP, save both the `yaml` config and the `ckpt` weights in `assets/`, and ensure that they have the same filename, just with different postfixes (e.g. `mymodel.ckpt`, `mymodel.yaml`). It will then appear in the app interface for use. Refer to the [VQGAN repo](https://github.com/CompVis/taming-transformers) for training VQGAN on your own dataset. 27 | 28 | To modify the image generation logic, instantiate your own model class in `logic.py` / `diffusion_logic.py` and modify `app.py` / `diffusion_app.py` accordingly if changes to UI elements are needed. 
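For illustration, a minimal sketch of such a class is shown below. It subclasses the `Run` base class from `logic.py` and follows the call order described in its docstrings (`__init__` -> `load_model` -> `model_init` -> `iterate`); the `MyRun` name and the method bodies are placeholders, not part of the codebase.

``` python
from logic import Run


class MyRun(Run):
    def __init__(self, text_input: str, num_steps: int = 300):
        super().__init__()
        self.text_input = text_input
        self.num_steps = num_steps

    def load_model(self, prev_model=None):
        # Load model weights here, or reuse the ones passed in from a previous run
        self.model = prev_model
        return self.model

    def model_init(self, init_image=None):
        # One-time setup that needs the model in place (optimizer, prompts, ...)
        pass

    def iterate(self):
        # One generation step; return something the UI can display
        losses, im = [], None
        return losses, im
```

The existing `VQGANCLIPRun` class in `logic.py` follows this same four-method structure.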
-------------------------------------------------------------------------------- /docs/images/diffusion-example.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tnwei/vqgan-clip-app/82aff7993b5f64baacfe626e214398bad2d657d7/docs/images/diffusion-example.jpg -------------------------------------------------------------------------------- /docs/images/four-seasons-20210808.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tnwei/vqgan-clip-app/82aff7993b5f64baacfe626e214398bad2d657d7/docs/images/four-seasons-20210808.jpg -------------------------------------------------------------------------------- /docs/images/gallery.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tnwei/vqgan-clip-app/82aff7993b5f64baacfe626e214398bad2d657d7/docs/images/gallery.jpg -------------------------------------------------------------------------------- /docs/images/translationx_example_trimmed.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tnwei/vqgan-clip-app/82aff7993b5f64baacfe626e214398bad2d657d7/docs/images/translationx_example_trimmed.gif -------------------------------------------------------------------------------- /docs/images/ui.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tnwei/vqgan-clip-app/82aff7993b5f64baacfe626e214398bad2d657d7/docs/images/ui.jpeg -------------------------------------------------------------------------------- /docs/implementation-details.md: -------------------------------------------------------------------------------- 1 | # Implementation details 2 | 3 | ## VQGAN-CLIP 4 | 5 | Code for the VQGAN-CLIP app mostly came from the z+quantize_method notebook hosted in [EleutherAI/vqgan-clip](https://github.com/EleutherAI/vqgan-clip/tree/main/notebooks). The logic is mostly left unchanged, just refactored to work well with a GUI frontend. 6 | 7 | Prompt weighting was implemented after seeing it in one of the Colab variants floating around the internet. Browsing other notebooks led to adding loss functions like MSE regularization and TV loss for improved image quality. 8 | 9 | ## CLIP guided diffusion 10 | 11 | Code for the CLIP guided diffusion app mostly came from the HQ 512x512 Unconditional notebook, which also can be found in EleutherAI/vqgan-clip. Models from other guided diffusion notebooks are also implemented save for the non-HQ 256x256 version, which did not generate satisfactory results during testing. -------------------------------------------------------------------------------- /docs/notes-and-observations.md: -------------------------------------------------------------------------------- 1 | # Notes and observations 2 | 3 | ## Generated image size 4 | 5 | Generated image size is bound by GPU VRAM available. The reference notebook default to use 480x480. One of the notebooks in the thoughts section below uses 640x512. For reference, an RTX2060 can barely manage 300x300. You can use image upscaling tools such as [Waifu2X](https://github.com/nagadomi/waifu2x) or [Real ESRGAN](https://github.com/xinntao/Real-ESRGAN) to further upscale the generated image beyond VRAM limits. Just be aware that smaller generated images fundamentally contain less complexity than larger images. 
6 | 7 | ## GPU VRAM consumption 8 | 9 | Following are GPU VRAM consumption read from `nvidia-smi`, note that your mileage may vary. 10 | 11 | For VQGAN-CLIP, using the `vqgan_imagenet_f16_1024` model checkpoint: 12 | 13 | | Resolution| VRAM Consumption | 14 | | ----------| ---------------- | 15 | | 300 x 300 | 4,829 MiB | 16 | | 480 x 480 | 8,465 MiB | 17 | | 640 x 360 | 8,169 MiB | 18 | | 640 x 480 | 10,247 MiB | 19 | | 800 x 450 | 13,977 MiB | 20 | | 800 x 600 | 18,157 MiB | 21 | | 960 x 540 | 15,131 MiB | 22 | | 960 x 720 | 19,777 MiB | 23 | | 1024 x 576| 17,175 MiB | 24 | | 1024 x 768| 22,167 MiB | 25 | | 1280 x 720| 24,353 MiB | 26 | 27 | For CLIP guided diffusion, using various checkpoints: 28 | 29 | | Method | VRAM Consumption | 30 | | ------------------| ---------------- | 31 | | 256x256 HQ Uncond | 9,199 MiB | 32 | | 512x512 HQ Cond | 15,079 MiB | 33 | | 512x512 HQ Uncond | 15,087 MiB | 34 | 35 | ## CUDA out of memory error 36 | 37 | If you're getting a CUDA out of memory error on your first run, it is a sign that the image size is too large. If you were able to generate images of a particular size prior, then you might have a CUDA memory leak and need to restart the application. Still figuring out where the leak is from. 38 | 39 | ## VQGAN weights and art style 40 | 41 | In the download links are trained on different datasets. You might find them suitable for generating different art styles. The `sflickr` dataset is skewed towards generating landscape images while the `faceshq` dataset is skewed towards generating faces. If you have no art style preference, the ImageNet weights do remarkably well. In fact, VQGAN-CLIP can be conditioned to generate specific styles, thanks to the breadth of understanding supplied by the CLIP model (see tips section). 42 | 43 | Some weights have multiple versions, e.g. ImageNet 1024 and Image 16384. The number represents the codebook (latent space) dimensionality. For more info, refer to the [VQGAN repo](https://github.com/CompVis/taming-transformers), also linked in the intro above. 44 | 45 | ## How many steps should I let the models run? 46 | 47 | A good starting number to try is 1000 steps for VQGAN-CLIP, and 2000 steps for CLIP guided diffusion. 48 | 49 | There is no ground rule on how many steps to run to get a good image, images generated are also not guaranteed to be interesting. YMMV as the image generation process depends on how well CLIP is able to understand the given prompt(s). 50 | 51 | ## Reproducibility 52 | 53 | Replicating a run with the same configuration, model checkpoint and seed will allow recreating the same output albeit with very minor variations. There exists a tiny bit of stochasticity that can't be eliminated, due to how the underlying convolution operators are implemented in CUDA. The underlying CUDA operations used by cuDNN can vary depending on differences in hardware or plain noise (see [Pytorch docs on reproducibility](https://pytorch.org/docs/stable/notes/randomness.html)). Using fully deterministic algorithms isn't an option as some CUDA operations do not have deterministic counterparts (e.g. `upsample_bicubic2d_backward_cuda`) 54 | 55 | In practice, the variations are not easy to spot unless a side-by-side comparison is done. 
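For reference, the relevant settings from the PyTorch reproducibility docs look roughly like the sketch below. This is illustrative only and not wired into the apps; as noted above, fully deterministic mode cannot be used here.

``` python
import torch

torch.manual_seed(1234)  # both apps already accept a seed in their run settings
torch.backends.cudnn.benchmark = False  # keep cuDNN from auto-tuning kernels per run

# The next step would be full determinism, but it raises an error in these apps
# because ops such as upsample_bicubic2d_backward_cuda have no deterministic kernel:
# torch.use_deterministic_algorithms(True)
```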
-------------------------------------------------------------------------------- /docs/tips-n-tricks.md: -------------------------------------------------------------------------------- 1 | # Tips and tricks 2 | 3 | ## Prompt weighting 4 | 5 | The text prompts allow separating text prompts using the "|" symbol, with custom weightage for each. Example: `A beautiful sunny countryside scene of a Collie dog running towards a herd of sheep in a grassy field with a farmhouse in the background:100 | wild flowers:30 | roses:-30 | photorealistic:20 | V-ray:20 | ray tracing:20 | unreal engine:20`. 6 | 7 | Refer to [this Reddit post](https://www.reddit.com/r/bigsleep/comments/p15fis/tutorial_an_introduction_for_newbies_to_using_the/) for more info. 8 | 9 | ## Using long, descriptive prompts 10 | 11 | CLIP can handle long, descriptive prompts, as long as the prompt is understandable and not too specific. Example: [this Reddit post](https://www.reddit.com/r/MediaSynthesis/comments/oej9qc/gptneo_vqganclip/) 12 | 13 | ## Prompt engineering for stylization 14 | 15 | Art style of the images generated can be somewhat controlled via prompt engineering. 16 | 17 | + Adding mentions of "Unreal engine" in the prompt causes the generated output to resemble a 3D render in high resolution. Example: [this Tweet](https://twitter.com/arankomatsuzaki/status/1399471244760649729?s=20) 18 | + The art style of specific artists can be mimicked by adding their name to the prompt. [This blogpost](https://moultano.wordpress.com/2021/07/20/tour-of-the-sacred-library/) shows a number of images generated using the style of James Gurney. [This imgur album](https://imgur.com/a/Ha7lsYu) compares the output of similar prompts with different artist names appended. 19 | 20 | ## Multi-stage iteration 21 | 22 | The output image of the previous run can be carried over as the input image of the next run in the VQGAN-CLIP webapp, using the `continue_previous_run` option. This allows you to continue iterating on an image if you like how it has turned out thus far. Extending upon that feature enables multi-stage iteration, where the same image can be iterated upon using different prompts at different stages. 23 | 24 | For example, you can tell the network to generate "Cowboy singing in the sky", then continue the same image and weights using a different prompt, "Fish on an alien planet under the night sky". Because of how backprop works, the network will find the easiest way to change the previous image to fit the new prompt. Should be useful for preserving visual structure between images, and for smoothly transitioning from one scene to another. 25 | 26 | Here is an example where "Backyard in spring" is first generated, then iterated upon with prompts "Backyard in summer", "Backyard in autumn", and "Backyard in winter". Major visual elements in the initial image were inherited and utilized across multiple runs. 27 | 28 | ![Backyard in spring, summer, autumn and winter](images/four-seasons-20210808.jpg) 29 | 30 | A few things to take note: 31 | + If a new image size is specified, the existing output image will be cropped to size accordingly. 32 | + This is specifically possible for VQGAN-CLIP but not for CLIP guided diffusion. (Explain how both of them work) 33 | + Splitting a long run into multiple successive runs using the same prompt do not yield the same outcome due to the underlying stochasticity. This randomness can't be mitigated by setting the random seed alone. See the section on reproducibility in notes-and-observations.md. 
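Outside the webapp, the same multi-stage flow can be driven directly with the classes in `logic.py`. The sketch below is illustrative rather than part of the app: reusing the loaded model and passing the previous output image as the next init image is essentially what the `continue_previous_run` option does in the UI. Prompts and step counts are arbitrary.

``` python
from logic import VQGANCLIPRun

# Stage 1: generate "Backyard in spring" from a random init
run = VQGANCLIPRun(text_input="Backyard in spring", num_steps=500)
model, perceptor = run.load_model()
run.model_init()
for _ in range(500):
    losses, im = run.iterate()

# Stage 2: keep the weights, change the prompt, start from the stage 1 output
run = VQGANCLIPRun(
    text_input="Backyard in summer", num_steps=500, continue_prev_run=True
)
run.load_model(prev_model=model, prev_perceptor=perceptor)
run.model_init(init_image=im)
for _ in range(500):
    losses, im = run.iterate()
```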
34 | 35 | ## Scrolling and zooming 36 | 37 | Added scrolling and zooming from [this notebook](https://colab.research.google.com/github/chigozienri/VQGAN-CLIP-animations/blob/main/VQGAN-CLIP-animations.ipynb) by @chigozienri. 38 | 39 | ![Beautiful swirling wind, trending on ArtStation](images/translationx_example_trimmed.gif) 40 | 41 | More examples at [this imgur link](https://imgur.com/a/8pyUNCQ). 42 | 43 | -------------------------------------------------------------------------------- /download-diffusion-weights.sh: -------------------------------------------------------------------------------- 1 | # 256x256 2 | # curl -L -o assets/256x256_diffusion_uncond.pt -C - 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt' 3 | 4 | # 512x512 class conditional model 5 | # curl -L -o assets/512x512_diffusion.pt -C - 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/512x512_diffusion.pt' 6 | 7 | # 512x512 unconditional model 8 | # curl -L -o assets/512x512_diffusion_uncond_finetune_008100.pt --http1.1 'https://the-eye.eu/public/AI/models/512x512_diffusion_unconditional_ImageNet/512x512_diffusion_uncond_finetune_008100.pt' 9 | -------------------------------------------------------------------------------- /download-weights.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | # Uncomment and run script to download 4 | 5 | # ImageNet 1024 6 | # curl -L -o assets/vqgan_imagenet_f16_1024.yaml -C - 'https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1' 7 | # curl -L -o assets/vqgan_imagenet_f16_1024.ckpt -C - 'https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/files/?p=%2Fckpts%2Flast.ckpt&dl=1' 8 | 9 | 10 | # ImageNet 16384 11 | # curl -L -o assets/vqgan_imagenet_f16_16384.yaml -C - 'https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1' 12 | # curl -L -o assets/vqgan_imagenet_f16_16384.ckpt -C - 'https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/files/?p=%2Fckpts%2Flast.ckpt&dl=1' 13 | 14 | # COCO 15 | # curl -L -o assets/coco.yaml -C - 'https://dl.nmkd.de/ai/clip/coco/coco.yaml' #COCO 16 | # curl -L -o assets/coco.ckpt -C - 'https://dl.nmkd.de/ai/clip/coco/coco.ckpt' #COCO 17 | 18 | # Faces HQ 19 | # curl -L -o assets/faceshq.yaml -C - 'https://drive.google.com/uc?export=download&id=1fHwGx_hnBtC8nsq7hesJvs-Klv-P0gzT' 20 | # curl -L -o assets/faceshq.ckpt -C - 'https://app.koofr.net/content/links/a04deec9-0c59-4673-8b37-3d696fe63a5d/files/get/last.ckpt?path=%2F2020-11-13T21-41-45_faceshq_transformer%2Fcheckpoints%2Flast.ckpt' 21 | 22 | 23 | # WikiArt 16384 24 | # curl -L 'http://eaidata.bmk.sh/data/Wikiart_16384/wikiart_f16_16384_8145600.yaml' > assets/wikiart_16384.yaml 25 | # curl -L 'http://eaidata.bmk.sh/data/Wikiart_16384/wikiart_f16_16384_8145600.ckpt' > assets/wikiart_16384.ckpt 26 | 27 | # S-Flickr 28 | # curl -L -o assets/sflckr.yaml -C - 'https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fconfigs%2F2020-11-09T13-31-51-project.yaml&dl=1' 29 | # curl -L -o assets/sflckr.ckpt -C - 'https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fcheckpoints%2Flast.ckpt&dl=1' 30 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: vqgan-clip-app 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - 
pytorch::pytorch=1.10.0 8 | - pytorch::torchvision=0.11.1 9 | - cudatoolkit=10.2 10 | - omegaconf 11 | - pytorch-lightning 12 | - tqdm 13 | - regex 14 | - kornia 15 | - ftfy 16 | - pillow=7.1.2 17 | - python=3.8 # For compatibility 18 | - imageio-ffmpeg=0.2.0 # For compatibility 19 | - ipykernel 20 | - imageio 21 | - ipywidgets 22 | - streamlit 23 | - conda-forge::ffmpeg 24 | - pyyaml 25 | - flask 26 | - pip 27 | - gitpython 28 | - opencv 29 | - pip: 30 | # - stegano 31 | # - python-xmp-toolkit 32 | # - imgtag 33 | - einops 34 | - transformers 35 | - git+https://github.com/openai/CLIP 36 | # For guided diffusion 37 | - lpips 38 | - git+https://github.com//crowsonkb/guided-diffusion 39 | prefix: /home/tnwei/miniconda3/envs/vqgan-clip-app 40 | -------------------------------------------------------------------------------- /gallery.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from pathlib import Path 4 | from typing import Union 5 | 6 | import yaml 7 | from flask import Flask, render_template, request, send_from_directory 8 | from PIL import Image 9 | 10 | 11 | class RunResults: 12 | """ 13 | Store run output and metadata as a class for ease of use. 14 | """ 15 | 16 | def __init__(self, fdir: Union[str, Path]): 17 | self.fdir: str = Path(fdir) # exactly as given 18 | self.absfdir: Path = Path(fdir).resolve() # abs 19 | files_available = [i.name for i in self.fdir.glob("*")] 20 | 21 | # The most important info is the final image and the metadata 22 | # Leaving an option to delete videos to save space 23 | if ( 24 | "details.txt" not in files_available 25 | and "details.json" not in files_available 26 | ): 27 | raise ValueError( 28 | f"fdir passed has neither details.txt or details.json: {fdir}" 29 | ) 30 | 31 | if "output.PNG" not in files_available: 32 | raise ValueError(f"fdir passed contains no output.PNG: {fdir}") 33 | 34 | self.impath = (self.fdir / "output.PNG").as_posix() 35 | 36 | if "anim.mp4" in files_available: 37 | self.animpath = (self.fdir / "anim.mp4").as_posix() 38 | else: 39 | self.animpath = None 40 | print(f"fdir passed contains no anim.mp4: {fdir}") 41 | 42 | if "init-image.JPEG" in files_available: 43 | self.initimpath = (self.fdir / "init-image.JPEG").as_posix() 44 | else: 45 | self.initimpath = None 46 | 47 | self.impromptspath = [i.as_posix() for i in self.fdir.glob("image-prompt*")] 48 | if len(self.impromptspath) == 0: 49 | self.impromptspath = None 50 | 51 | if "details.txt" in files_available: 52 | self.detailspath = (self.fdir / "details.txt").as_posix() 53 | elif "details.json" in files_available: 54 | self.detailspath = (self.fdir / "details.json").as_posix() 55 | 56 | with open(self.detailspath, "r") as f: 57 | self.details = json.load(f) 58 | 59 | # Preserving the filepaths as I realize might be calling 60 | # them through Jinja in HTML instead of within the app 61 | # self.detailshtmlstr = markdown2.markdown( 62 | # "```" + json.dumps(self.details, indent=4) + "```", 63 | # extras=["fenced-code-blocks"] 64 | # ) 65 | self.detailshtmlstr = yaml.dump(self.details) 66 | 67 | # Replace with line feed and carriage return 68 | # ref: https://stackoverflow.com/a/39325879/13095028 69 | self.detailshtmlstr = self.detailshtmlstr.replace("\n", "
") 70 | # Edit: Using preformatted tag
, solved!
 71 | 
 72 |         self.im = Image.open(self.fdir / "output.PNG").convert(
 73 |             "RGB"
 74 |         )  # just to be sure the format is right
 75 | 
 76 | 
 77 | def update_runs(fdir, runs):
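    # Scan `fdir` for run folders that aren't in `runs` yet, wrap them as
    # RunResults and prepend them; folders that fail to load are skipped
    # with a printed message.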
 78 |     existing_run_folders = [i.absfdir.name for i in runs]
 79 |     # Load each run as ModelRun objects
 80 |     # Loading the latest ones first
 81 |     for i in sorted(fdir.iterdir()):
 82 |         # If is a folder and contains images and metadata
 83 |         if (
 84 |             i.is_dir()
 85 |             and (i / "details.json").exists()
 86 |             and (i / "output.PNG").exists()
 87 |             and i.name not in existing_run_folders
 88 |         ):
 89 |             try:
 90 |                 runs.insert(0, RunResults(i))
 91 |             except Exception as e:
 92 |                 print(f"Skipped {i} due to raised exception {e}")
 93 |     return runs
 94 | 
 95 | 
 96 | if __name__ == "__main__":
 97 |     # Select output dir
 98 |     parser = argparse.ArgumentParser()
 99 |     parser.add_argument(
100 |         "path",
101 |         help="Path to output folder, should contain subdirs of individual runs",
102 |         nargs="?",
103 |         default="./output/",
104 |     )
105 |     parser.add_argument(
106 |         "-n",
107 |         "--numitems",
108 |         help="Number of items per page",
109 |         default=24, type=int,  # multiple of three since the dashboard has three panels
110 |     )
111 |     parser.add_argument(
112 |         "--kiosk",
113 |         help="Omit showing run details on dashboard",
114 |         default=False,
115 |         action="store_true",
116 |     )
117 |     args = parser.parse_args()
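    # Example invocations (served on Flask's default port 5000):
    #   python gallery.py                       # browse ./output/
    #   python gallery.py path/to/runs -n 12    # custom folder, 12 items per page
    #   python gallery.py --kiosk               # hide run details in the dashboard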
118 | 
119 |     fdir = Path(args.path)
120 |     runs = update_runs(fdir, [])
121 | 
122 |     app = Flask(
123 |         __name__,
124 |         # Hack to allow easy access to images
125 |         # Else typically this stuff needs to be put in a static/ folder!
126 |         static_url_path="",
127 |         static_folder="",
128 |     )
129 | 
130 |     @app.route("/")
131 |     def home():
132 |         # startidx = request.args.get('startidx')
133 |         # endidx = request.args.get('endidx')
134 | 
135 |         # Pagenum starts at 1
136 |         page = request.args.get("page")
137 |         page = 1 if page is None else int(page)
138 | 
139 |         # startidx = 1 if startidx is None else int(startidx)
140 |         # endidx = args.numitems if endidx is None else int(endidx)
141 |         # print("startidx, endidx: ", startidx, endidx)
142 |         global runs
143 |         runs = update_runs(fdir, runs)  # Updates new results when refreshed
144 |         num_pages = (len(runs) // args.numitems) + 1
145 | 
146 |         page_labels = {}
147 |         for i in range(0, num_pages):
148 |             page_labels[i + 1] = dict(
149 |                 start=i * args.numitems + 1, end=(i + 1) * args.numitems
150 |             )
151 | 
152 |         return render_template(
153 |             "index.html",
154 |             runs=runs,
155 |             startidx=page_labels[page]["start"],
156 |             endidx=page_labels[page]["end"],
157 |             page=page,
158 |             fdir=fdir,
159 |             page_labels=page_labels,
160 |             kiosk=args.kiosk,
161 |         )
162 | 
163 |     @app.route("/findurl")
164 |     def findurl(path, filename):
165 |         return send_from_directory(path, filename)
166 | 
167 |     app.run(debug=False)
168 | 


--------------------------------------------------------------------------------
/logic.py:
--------------------------------------------------------------------------------
  1 | from typing import Optional, List, Tuple
  2 | from PIL import Image
  3 | import argparse
  4 | import clip
  5 | from vqgan_utils import (
  6 |     load_vqgan_model,
  7 |     MakeCutouts,
  8 |     parse_prompt,
  9 |     resize_image,
 10 |     Prompt,
 11 |     synth,
 12 |     checkin,
 13 |     TVLoss,
 14 | )
 15 | import torch
 16 | from torchvision.transforms import functional as TF
 17 | import torch.nn as nn
 18 | from torch.nn import functional as F
 19 | from torch import optim
 20 | from torchvision import transforms
 21 | import cv2
 22 | import numpy as np
 23 | import kornia.augmentation as K
 24 | 
 25 | 
 26 | class Run:
 27 |     """
 28 |     Subclass this to house your own implementation of CLIP-based image generation 
 29 |     models within the UI
 30 |     """
 31 | 
 32 |     def __init__(self):
 33 |         """
 34 |         Set up the run's config here
 35 |         """
 36 |         pass
 37 | 
 38 |     def load_model(self):
 39 |         """
 40 |         Load models here. Separated this from __init__ to allow loading model state
 41 |         from a previous run
 42 |         """
 43 |         pass
 44 | 
 45 |     def model_init(self):
 46 |         """
 47 |         Continue run setup, for items that require the models to be in place. 
 48 |         Call once after load_model
 49 |         """
 50 |         pass
 51 | 
 52 |     def iterate(self):
 53 |         """
 54 |         Place iteration logic here. Outputs results for human consumption at
 55 |         every step. 
 56 |         """
 57 |         pass
 58 | 
 59 | 
 60 | class VQGANCLIPRun(Run):
 61 |     def __init__(
 62 |         # Inputs
 63 |         self,
 64 |         text_input: str = "the first day of the waters",
 65 |         vqgan_ckpt: str = "vqgan_imagenet_f16_16384",
 66 |         num_steps: int = 300,
 67 |         image_x: int = 300,
 68 |         image_y: int = 300,
 69 |         init_image: Optional[Image.Image] = None,
 70 |         image_prompts: List[Image.Image] = [],
 71 |         continue_prev_run: bool = False,
 72 |         seed: Optional[int] = None,
 73 |         mse_weight=0.5,
 74 |         mse_weight_decay=0.1,
 75 |         mse_weight_decay_steps=50,
 76 |         tv_loss_weight=1e-3,
 77 |         use_cutout_augmentations: bool = True,
 78 |         # use_augs: bool = True,
 79 |         # noise_fac: float = 0.1,
 80 |         # use_noise: Optional[float] = None,
 81 |         # mse_withzeros=True,
 82 |         ## **kwargs,  # Use this to receive Streamlit objects ## Call from main UI
 83 |         use_scrolling_zooming: bool = False,
 84 |         translation_x: int = 0,
 85 |         translation_y: int = 0,
 86 |         rotation_angle: float = 0,
 87 |         zoom_factor: float = 1,
 88 |         transform_interval: int = 10,
 89 |         device: Optional[torch.device] = None,
 90 |     ) -> None:
 91 |         super().__init__()
 92 |         self.text_input = text_input
 93 |         self.vqgan_ckpt = vqgan_ckpt
 94 |         self.num_steps = num_steps
 95 |         self.image_x = image_x
 96 |         self.image_y = image_y
 97 |         self.init_image = init_image
 98 |         self.image_prompts = image_prompts
 99 |         self.continue_prev_run = continue_prev_run
100 |         self.seed = seed
101 | 
102 |         # Setup ------------------------------------------------------------------------------
103 |         # Split text by "|" symbol
104 |         texts = [phrase.strip() for phrase in text_input.split("|")]
105 |         if texts == [""]:
106 |             texts = []
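        # e.g. "Backyard in spring:100 | roses:-30" is split into
        # ["Backyard in spring:100", "roses:-30"]; the ":weight" suffixes are
        # parsed later by parse_prompt() in model_init()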
107 | 
108 |         # Leaving most of this untouched
109 |         self.args = argparse.Namespace(
110 |             prompts=texts,
111 |             image_prompts=image_prompts,
112 |             noise_prompt_seeds=[],
113 |             noise_prompt_weights=[],
114 |             size=[int(image_x), int(image_y)],
115 |             init_image=init_image,
116 |             init_weight=mse_weight,
117 |             # clip.available_models()
118 |             # ['RN50', 'RN101', 'RN50x4', 'ViT-B/32']
119 |             # Visual Transformer seems to be the smallest
120 |             clip_model="ViT-B/32",
121 |             vqgan_config=f"assets/{vqgan_ckpt}.yaml",
122 |             vqgan_checkpoint=f"assets/{vqgan_ckpt}.ckpt",
123 |             step_size=0.05,
124 |             cutn=64,
125 |             cut_pow=1.0,
126 |             display_freq=50,
127 |             seed=seed,
128 |         )
129 | 
130 |         if device is None:
131 |             self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
132 |         else:
133 |             self.device = device
134 | 
135 |         print("Using device:", device)
136 | 
137 |         self.iterate_counter = 0
138 |         # self.use_augs = use_augs
139 |         # self.noise_fac = noise_fac
140 |         # self.use_noise = use_noise
141 |         # self.mse_withzeros = mse_withzeros
142 |         self.init_mse_weight = mse_weight
143 |         self.mse_weight = mse_weight
144 |         self.mse_weight_decay = mse_weight_decay
145 |         self.mse_weight_decay_steps = mse_weight_decay_steps
146 | 
147 |         self.use_cutout_augmentations = use_cutout_augmentations
148 | 
149 |         # For TV loss
150 |         self.tv_loss_weight = tv_loss_weight
151 | 
152 |         self.use_scrolling_zooming = use_scrolling_zooming
153 |         self.translation_x = translation_x
154 |         self.translation_y = translation_y
155 |         self.rotation_angle = rotation_angle
156 |         self.zoom_factor = zoom_factor
157 |         self.transform_interval = transform_interval
158 | 
159 |     def load_model(
160 |         self, prev_model: nn.Module = None, prev_perceptor: nn.Module = None
161 |     ) -> Optional[Tuple[nn.Module, nn.Module]]:
162 |         if self.continue_prev_run is True:
163 |             self.model = prev_model
164 |             self.perceptor = prev_perceptor
165 |             return None
166 | 
167 |         else:
168 |             self.model = load_vqgan_model(
169 |                 self.args.vqgan_config, self.args.vqgan_checkpoint
170 |             ).to(self.device)
171 | 
172 |             self.perceptor = (
173 |                 clip.load(self.args.clip_model, jit=False)[0]
174 |                 .eval()
175 |                 .requires_grad_(False)
176 |                 .to(self.device)
177 |             )
178 | 
179 |             return self.model, self.perceptor
180 | 
181 |     def model_init(self, init_image: Image.Image = None) -> None:
182 |         cut_size = self.perceptor.visual.input_resolution
183 |         e_dim = self.model.quantize.e_dim
184 |         f = 2 ** (self.model.decoder.num_resolutions - 1)
185 | 
186 |         if self.use_cutout_augmentations:
187 |             noise_fac = 0.1
188 |             augs = nn.Sequential(
189 |                 K.RandomHorizontalFlip(p=0.5),
190 |                 K.RandomSharpness(0.3, p=0.4),
191 |                 K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode="border"),
192 |                 K.RandomPerspective(0.2, p=0.4),
193 |                 K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
194 |             )
195 |         else:
196 |             noise_fac = None
197 |             augs = None
198 | 
199 |         self.make_cutouts = MakeCutouts(
200 |             cut_size,
201 |             self.args.cutn,
202 |             cut_pow=self.args.cut_pow,
203 |             noise_fac=noise_fac,
204 |             augs=augs,
205 |         )
206 | 
207 |         n_toks = self.model.quantize.n_e
208 |         toksX, toksY = self.args.size[0] // f, self.args.size[1] // f
209 |         sideX, sideY = toksX * f, toksY * f
210 |         self.z_min = self.model.quantize.embedding.weight.min(dim=0).values[
211 |             None, :, None, None
212 |         ]
213 |         self.z_max = self.model.quantize.embedding.weight.max(dim=0).values[
214 |             None, :, None, None
215 |         ]
216 | 
217 |         if self.seed is not None:
218 |             torch.manual_seed(self.seed)
219 |         else:
220 |             self.seed = torch.seed()  # Trigger a seed, retrieve the utilized seed
221 | 
222 |         # Initialization order: continue_prev_im, init_image, then only random init
223 |         if init_image is not None:
224 |             init_image = init_image.resize((sideX, sideY), Image.LANCZOS)
225 |             self.z, *_ = self.model.encode(
226 |                 TF.to_tensor(init_image).to(self.device).unsqueeze(0) * 2 - 1
227 |             )
228 |         elif self.args.init_image:
229 |             pil_image = self.args.init_image
230 |             pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)
231 |             self.z, *_ = self.model.encode(
232 |                 TF.to_tensor(pil_image).to(self.device).unsqueeze(0) * 2 - 1
233 |             )
234 |         else:
235 |             one_hot = F.one_hot(
236 |                 torch.randint(n_toks, [toksY * toksX], device=self.device), n_toks
237 |             ).float()
238 |             self.z = one_hot @ self.model.quantize.embedding.weight
239 |             self.z = self.z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
240 |         self.z_orig = self.z.clone()
241 |         self.z.requires_grad_(True)
242 |         self.opt = optim.Adam([self.z], lr=self.args.step_size)
243 | 
244 |         self.normalize = transforms.Normalize(
245 |             mean=[0.48145466, 0.4578275, 0.40821073],
246 |             std=[0.26862954, 0.26130258, 0.27577711],
247 |         )
248 | 
249 |         self.pMs = []
250 | 
251 |         for prompt in self.args.prompts:
252 |             txt, weight, stop = parse_prompt(prompt)
253 |             embed = self.perceptor.encode_text(
254 |                 clip.tokenize(txt).to(self.device)
255 |             ).float()
256 |             self.pMs.append(Prompt(embed, weight, stop).to(self.device))
257 | 
258 |         for uploaded_image in self.args.image_prompts:
259 |             # path, weight, stop = parse_prompt(prompt)
260 |             # img = resize_image(Image.open(fetch(path)).convert("RGB"), (sideX, sideY))
261 |             img = resize_image(uploaded_image.convert("RGB"), (sideX, sideY))
262 |             batch = self.make_cutouts(TF.to_tensor(img).unsqueeze(0).to(self.device))
263 |             embed = self.perceptor.encode_image(self.normalize(batch)).float()
264 |             self.pMs.append(Prompt(embed).to(self.device))  # weight/stop aren't parsed for uploaded images; use Prompt defaults (1.0, -inf)
265 | 
266 |         for seed, weight in zip(
267 |             self.args.noise_prompt_seeds, self.args.noise_prompt_weights
268 |         ):
269 |             gen = torch.Generator().manual_seed(seed)
270 |             embed = torch.empty([1, self.perceptor.visual.output_dim]).normal_(
271 |                 generator=gen
272 |             )
273 |             self.pMs.append(Prompt(embed, weight).to(self.device))
274 | 
275 |     def _ascend_txt(self) -> List:
276 |         out = synth(self.model, self.z)
277 |         iii = self.perceptor.encode_image(
278 |             self.normalize(self.make_cutouts(out))
279 |         ).float()
280 | 
281 |         result = {}
282 | 
283 |         if self.args.init_weight:
284 |             result["mse_loss"] = F.mse_loss(self.z, self.z_orig) * self.mse_weight / 2
285 | 
286 |             # MSE regularization scheduler
287 |             with torch.no_grad():
288 |                 # if not the first step
289 |                 # and is time for step change
290 |                 # and both weight decay steps and magnitude are nonzero
291 |                 # and MSE isn't zero already
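                # Worked example with the defaults (mse_weight=0.5, decay=0.1,
                # decay_steps=50): 0.5 -> 0.4 at step 50 -> 0.3 at step 100 ->
                # ... -> 0.0 at step 250, after which MSE regularization is off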
292 |                 if (
293 |                     self.iterate_counter > 0
294 |                     and self.iterate_counter % self.mse_weight_decay_steps == 0
295 |                     and self.mse_weight_decay != 0
296 |                     and self.mse_weight_decay_steps != 0
297 |                     and self.mse_weight != 0
298 |                 ):
299 |                     self.mse_weight = self.mse_weight - self.mse_weight_decay
300 | 
301 |                     # Don't allow changing sign
302 |                     # Basically, caps MSE at zero if decreasing from positive
303 |                     # But, also prevents MSE from becoming positive if -MSE intended
304 |                     if self.init_mse_weight > 0:
305 |                         self.mse_weight = max(self.mse_weight, 0)
306 |                     else:
307 |                         self.mse_weight = min(self.mse_weight, 0)
308 | 
309 |                     print(f"updated mse weight: {self.mse_weight}")
310 | 
311 |         tv_loss_fn = TVLoss()
312 |         result["tv_loss"] = tv_loss_fn(self.z) * self.tv_loss_weight
313 | 
314 |         for count, prompt in enumerate(self.pMs):
315 |             result[f"prompt_loss_{count}"] = prompt(iii)
316 | 
317 |         return result
318 | 
319 |     def iterate(self) -> Tuple[List[float], Image.Image]:
320 |         if not self.use_scrolling_zooming:
321 |             # Forward prop
322 |             self.opt.zero_grad()
323 |             losses = self._ascend_txt()
324 | 
325 |             # Grab an image
326 |             im: Image.Image = checkin(self.model, self.z)
327 | 
328 |             # Backprop
329 |             loss = sum([j for i, j in losses.items()])
330 |             loss.backward()
331 |             self.opt.step()
332 |             with torch.no_grad():
333 |                 self.z.copy_(self.z.maximum(self.z_min).minimum(self.z_max))
334 | 
335 |             # Advance iteration counter
336 |             self.iterate_counter += 1
337 | 
338 |             print(
339 |                 f"Step {self.iterate_counter} losses: {[(i, j.item()) for i, j in losses.items()]}"
340 |             )
341 | 
342 |             # Output stuff useful for humans
343 |             return [(i, j.item()) for i, j in losses.items()], im
344 | 
345 |         else:
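            # Scrolling/zooming: warp the current output by the requested
            # translation / rotation / zoom (OpenCV), re-encode the warped frame
            # as the new latent z, then optimize it for transform_interval steps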
346 |             # Grab current image
347 |             im_before_transform: Image.Image = checkin(self.model, self.z)
348 | 
349 |             # Convert for use in OpenCV
350 |             imarr = np.array(im_before_transform)
351 |             imarr = cv2.cvtColor(imarr, cv2.COLOR_RGB2BGR)
352 | 
353 |             translation = np.float32(
354 |                 [[1, 0, self.translation_x], [0, 1, self.translation_y]]
355 |             )
356 | 
357 |             imcenter = (imarr.shape[1] // 2, imarr.shape[0] // 2)
358 |             rotation = cv2.getRotationMatrix2D(
359 |                 imcenter, angle=self.rotation_angle, scale=self.zoom_factor
360 |             )
361 | 
362 |             trans_mat = np.vstack([translation, [0, 0, 1]])
363 |             rot_mat = np.vstack([rotation, [0, 0, 1]])
364 |             transformation_matrix = np.matmul(rot_mat, trans_mat)
365 | 
366 |             outarr = cv2.warpPerspective(
367 |                 imarr,
368 |                 transformation_matrix,
369 |                 (imarr.shape[1], imarr.shape[0]),
370 |                 borderMode=cv2.BORDER_WRAP,
371 |             )
372 | 
373 |             transformed_im = Image.fromarray(cv2.cvtColor(outarr, cv2.COLOR_BGR2RGB))
374 | 
375 |             # Encode as z, reinit
376 |             self.z, *_ = self.model.encode(
377 |                 TF.to_tensor(transformed_im).to(self.device).unsqueeze(0) * 2 - 1
378 |             )
379 |             self.z.requires_grad_(True)
380 |             self.opt = optim.Adam([self.z], lr=self.args.step_size)
381 | 
382 |             for _ in range(self.transform_interval):
383 |                 # Forward prop
384 |                 self.opt.zero_grad()
385 |                 losses = self._ascend_txt()
386 | 
387 |                 # Grab an image
388 |                 im: Image.Image = checkin(self.model, self.z)
389 | 
390 |                 # Backprop
391 |                 loss = sum([j for i, j in losses.items()])
392 |                 loss.backward()
393 |                 self.opt.step()
394 |                 with torch.no_grad():
395 |                     self.z.copy_(self.z.maximum(self.z_min).minimum(self.z_max))
396 | 
397 |             # Advance iteration counter
398 |             self.iterate_counter += 1
399 | 
400 |             print(
401 |                 f"Step {self.iterate_counter} losses: {[(i, j.item()) for i, j in losses.items()]}"
402 |             )
403 | 
404 |             # Output stuff useful for humans
405 |             return [(i, j.item()) for i, j in losses.items()], im
406 | 


--------------------------------------------------------------------------------
/templates/index.html:
--------------------------------------------------------------------------------
  1 | 
  2 | 
  3 | 
  4 | 
  5 |   
  6 |   
  7 |   
  8 | 
  9 |   
 10 |   
 12 | 
 13 |   VQGAN-CLIP output gallery
 14 | 
 15 | 
 16 | 
 17 |   
18 |
19 |

VQGAN-CLIP output gallery

20 |
21 |
22 | 23 | Output folder: {{ fdir }} 24 |
25 | Total runs: {{ runs|length }} 26 |
27 | {% for pagenum, indices in page_labels.items() %} 28 |   29 | 30 | {% if pagenum == page %} 31 | {{indices['start']}}-{{indices['end']}} 32 | {% else %} 33 | {{indices['start']}}-{{indices['end']}} 34 | {% endif %} 35 | 36 |   37 | {% endfor %} 38 |
39 |
40 | 41 |
42 |
43 | 44 | {% for run in runs[startidx-1:endidx] %} 45 |
46 |
47 | Output Image 50 |
51 |
{{ run.details.text_input }}
52 | 53 | {% if not kiosk %} 54 |

55 | 56 |

    57 | {% for i, j in run.details.items() %} 58 | 59 |
  • {{ i ~ ': ' ~ j }}
  • 60 | {% endfor %} 61 |
62 | 63 |

64 | 65 |

66 | {% if run.animpath is not none %} 67 | Animation 68 |   69 | {% endif %} 70 | {% if run.initimpath is not none %} 71 | Init image 72 |   73 | {% endif %} 74 | {% if run.impromptspath is not none %} 75 | {% for i in range(run.impromptspath|length) %} 76 | Image prompt {{i+1}} 77 |   78 | {% endfor %} 79 | {% endif %} 80 |

81 | 82 | {% endif %} 83 | 84 |
85 |
86 |
87 | {% endfor %} 88 |
89 |
90 | 91 |
92 | 93 | {% for pagenum, indices in page_labels.items() %} 94 |   95 | 96 | {% if pagenum == page %} 97 | {{indices['start']}}-{{indices['end']}} 98 | {% else %} 99 | {{indices['start']}}-{{indices['end']}} 100 | {% endif %} 101 | 102 |   103 | {% endfor %} 104 | 105 |
106 |
107 | 108 | Code: tnwei/vqgan-clip-app 109 | 110 |
111 |
112 |
113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /vqgan_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | import io 6 | from omegaconf import OmegaConf 7 | from taming.models import cond_transformer, vqgan 8 | from PIL import Image 9 | from torchvision.transforms import functional as TF 10 | import requests 11 | import sys 12 | 13 | sys.path.append("./taming-transformers") 14 | 15 | 16 | def sinc(x): 17 | return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([])) 18 | 19 | 20 | def lanczos(x, a): 21 | cond = torch.logical_and(-a < x, x < a) 22 | out = torch.where(cond, sinc(x) * sinc(x / a), x.new_zeros([])) 23 | return out / out.sum() 24 | 25 | 26 | def ramp(ratio, width): 27 | n = math.ceil(width / ratio + 1) 28 | out = torch.empty([n]) 29 | cur = 0 30 | for i in range(out.shape[0]): 31 | out[i] = cur 32 | cur += ratio 33 | return torch.cat([-out[1:].flip([0]), out])[1:-1] 34 | 35 | 36 | def resample(input, size, align_corners=True): 37 | n, c, h, w = input.shape 38 | dh, dw = size 39 | 40 | input = input.view([n * c, 1, h, w]) 41 | 42 | if dh < h: 43 | kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype) 44 | pad_h = (kernel_h.shape[0] - 1) // 2 45 | input = F.pad(input, (0, 0, pad_h, pad_h), "reflect") 46 | input = F.conv2d(input, kernel_h[None, None, :, None]) 47 | 48 | if dw < w: 49 | kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype) 50 | pad_w = (kernel_w.shape[0] - 1) // 2 51 | input = F.pad(input, (pad_w, pad_w, 0, 0), "reflect") 52 | input = F.conv2d(input, kernel_w[None, None, None, :]) 53 | 54 | input = input.view([n, c, h, w]) 55 | return F.interpolate(input, size, mode="bicubic", align_corners=align_corners) 56 | 57 | 58 | class ReplaceGrad(torch.autograd.Function): 59 | @staticmethod 60 | def forward(ctx, x_forward, x_backward): 61 | ctx.shape = x_backward.shape 62 | return x_forward 63 | 64 | @staticmethod 65 | def backward(ctx, grad_in): 66 | return None, grad_in.sum_to_size(ctx.shape) 67 | 68 | 69 | replace_grad = ReplaceGrad.apply 70 | 71 | 72 | class ClampWithGrad(torch.autograd.Function): 73 | @staticmethod 74 | def forward(ctx, input, min, max): 75 | ctx.min = min 76 | ctx.max = max 77 | ctx.save_for_backward(input) 78 | return input.clamp(min, max) 79 | 80 | @staticmethod 81 | def backward(ctx, grad_in): 82 | input, = ctx.saved_tensors 83 | return ( 84 | grad_in * (grad_in * (input - input.clamp(ctx.min, ctx.max)) >= 0), 85 | None, 86 | None, 87 | ) 88 | 89 | 90 | clamp_with_grad = ClampWithGrad.apply 91 | 92 | 93 | def vector_quantize(x, codebook): 94 | d = ( 95 | x.pow(2).sum(dim=-1, keepdim=True) 96 | + codebook.pow(2).sum(dim=1) 97 | - 2 * x @ codebook.T 98 | ) 99 | indices = d.argmin(-1) 100 | x_q = F.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook 101 | return replace_grad(x_q, x) 102 | 103 | 104 | class Prompt(nn.Module): 105 | def __init__(self, embed, weight=1.0, stop=float("-inf")): 106 | super().__init__() 107 | self.register_buffer("embed", embed) 108 | self.register_buffer("weight", torch.as_tensor(weight)) 109 | self.register_buffer("stop", torch.as_tensor(stop)) 110 | 111 | def forward(self, input): 112 | input_normed = F.normalize(input.unsqueeze(1), dim=2) 113 | embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2) 114 | dists = 
input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2) 115 | dists = dists * self.weight.sign() 116 | return ( 117 | self.weight.abs() 118 | * replace_grad(dists, torch.maximum(dists, self.stop)).mean() 119 | ) 120 | 121 | 122 | def fetch(url_or_path): 123 | if str(url_or_path).startswith("http://") or str(url_or_path).startswith( 124 | "https://" 125 | ): 126 | r = requests.get(url_or_path) 127 | r.raise_for_status() 128 | fd = io.BytesIO() 129 | fd.write(r.content) 130 | fd.seek(0) 131 | return fd 132 | return open(url_or_path, "rb") 133 | 134 | 135 | def parse_prompt(prompt): 136 | if prompt.startswith("http://") or prompt.startswith("https://"): 137 | vals = prompt.rsplit(":", 3) 138 | vals = [vals[0] + ":" + vals[1], *vals[2:]] 139 | else: 140 | vals = prompt.rsplit(":", 2) 141 | vals = vals + ["", "1", "-inf"][len(vals) :] 142 | return vals[0], float(vals[1]), float(vals[2]) 143 | 144 | 145 | class MakeCutouts(nn.Module): 146 | def __init__(self, cut_size, cutn, cut_pow=1.0, noise_fac=None, augs=None): 147 | super().__init__() 148 | self.cut_size = cut_size 149 | self.cutn = cutn 150 | self.cut_pow = cut_pow 151 | self.noise_fac = noise_fac 152 | self.augs = augs 153 | 154 | def forward(self, input): 155 | sideY, sideX = input.shape[2:4] 156 | max_size = min(sideX, sideY) 157 | min_size = min(sideX, sideY, self.cut_size) 158 | cutouts = [] 159 | for _ in range(self.cutn): 160 | size = int( 161 | torch.rand([]) ** self.cut_pow * (max_size - min_size) + min_size 162 | ) 163 | offsetx = torch.randint(0, sideX - size + 1, ()) 164 | offsety = torch.randint(0, sideY - size + 1, ()) 165 | cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size] 166 | cutouts.append(resample(cutout, (self.cut_size, self.cut_size))) 167 | 168 | if self.augs: 169 | batch = self.augs(torch.cat(cutouts, dim=0)) 170 | else: 171 | batch = torch.cat(cutouts, dim=0) 172 | 173 | if self.noise_fac: 174 | facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac) 175 | batch = batch + facs * torch.randn_like(batch) 176 | 177 | return clamp_with_grad(batch, 0, 1) 178 | 179 | 180 | def load_vqgan_model(config_path, checkpoint_path): 181 | config = OmegaConf.load(config_path) 182 | if config.model.target == "taming.models.vqgan.VQModel": 183 | model = vqgan.VQModel(**config.model.params) 184 | model.eval().requires_grad_(False) 185 | model.init_from_ckpt(checkpoint_path) 186 | elif config.model.target == "taming.models.cond_transformer.Net2NetTransformer": 187 | parent_model = cond_transformer.Net2NetTransformer(**config.model.params) 188 | parent_model.eval().requires_grad_(False) 189 | parent_model.init_from_ckpt(checkpoint_path) 190 | model = parent_model.first_stage_model 191 | else: 192 | raise ValueError(f"unknown model type: {config.model.target}") 193 | del model.loss 194 | return model 195 | 196 | 197 | def resize_image(image, out_size): 198 | ratio = image.size[0] / image.size[1] 199 | area = min(image.size[0] * image.size[1], out_size[0] * out_size[1]) 200 | size = round((area * ratio) ** 0.5), round((area / ratio) ** 0.5) 201 | return image.resize(size, Image.LANCZOS) 202 | 203 | 204 | def synth(model, z): 205 | z_q = vector_quantize(z.movedim(1, 3), model.quantize.embedding.weight).movedim( 206 | 3, 1 207 | ) 208 | return clamp_with_grad(model.decode(z_q).add(1).div(2), 0, 1) 209 | 210 | 211 | @torch.no_grad() 212 | def checkin(model, z): 213 | # losses_str = ", ".join(f"{loss.item():g}" for loss in losses) 214 | # tqdm.write(f"i: {i}, loss: 
{sum(losses).item():g}, losses: {losses_str}") 215 | out = synth(model, z) 216 | im = TF.to_pil_image(out[0].cpu()) 217 | return im 218 | # display.display(display.Image('progress.png')) # ipynb only 219 | 220 | 221 | class TVLoss(nn.Module): 222 | def forward(self, input): 223 | input = F.pad(input, (0, 1, 0, 1), "replicate") 224 | x_diff = input[..., :-1, 1:] - input[..., :-1, :-1] 225 | y_diff = input[..., 1:, :-1] - input[..., :-1, :-1] 226 | diff = x_diff ** 2 + y_diff ** 2 + 1e-8 227 | return diff.mean(dim=1).sqrt().mean() 228 | --------------------------------------------------------------------------------