├── .gitignore ├── LICENSE ├── README.md ├── assets └── fig1.png ├── grounding_input └── condition_null_generator.py ├── images ├── color1.pth ├── color2.pth ├── dummy.pth ├── fire.png ├── jeep_depth.png ├── jeep_sketch.png ├── nature.png └── partial_sketch.png ├── inference.py ├── ldm ├── lr_scheduler.py ├── models │ ├── autoencoder.py │ └── diffusion │ │ ├── __init__.py │ │ ├── classifier.py │ │ ├── ddim.py │ │ ├── ddpm.py │ │ ├── ldm.py │ │ ├── plms.py │ │ └── plmsg.py ├── modules │ ├── attention.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── condition_net.py │ │ ├── model.py │ │ ├── multimodal_openaimodel.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── __init__.py │ │ └── modules.py │ ├── losses │ │ ├── __init__.py │ │ ├── contperceptual.py │ │ └── vqperceptual.py │ ├── multimodal_attention.py │ └── x_transformer.py └── util.py ├── requirements.txt └── visualization ├── draw_utils.py ├── extract_utils.py └── image_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.png 2 | *.jpg 3 | *.pth 4 | *.ckpt 5 | *.npy 6 | OUTPUT/ 7 | wandb/ 8 | inference/ 9 | sampling/ 10 | 11 | # IntelliJ project files 12 | .idea 13 | *.iml 14 | out 15 | gen 16 | 17 | ### Vim template 18 | [._]*.s[a-w][a-z] 19 | [._]s[a-w][a-z] 20 | *.un~ 21 | Session.vim 22 | .netrwhist 23 | *~ 24 | 25 | ### IPythonNotebook template 26 | # Temporary data 27 | .ipynb_checkpoints/ 28 | 29 | ### Python template 30 | # Byte-compiled / optimized / DLL files 31 | __pycache__/ 32 | *.py[cod] 33 | *$py.class 34 | 35 | # C extensions 36 | *.so 37 | 38 | # Distribution / packaging 39 | .Python 40 | env/ 41 | build/ 42 | develop-eggs/ 43 | dist/ 44 | downloads/ 45 | eggs/ 46 | .eggs/ 47 | #lib/ 48 | #lib64/ 49 | parts/ 50 | sdist/ 51 | var/ 52 | *.egg-info/ 53 | .installed.cfg 54 | *.egg 55 | 56 | # PyInstaller 57 | # Usually these files are written by a python script from a template 58 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 59 | *.manifest 60 | *.spec 61 | 62 | # Installer logs 63 | pip-log.txt 64 | pip-delete-this-directory.txt 65 | 66 | # Unit test / coverage reports 67 | htmlcov/ 68 | .tox/ 69 | .coverage 70 | .coverage.* 71 | .cache 72 | nosetests.xml 73 | coverage.xml 74 | *,cover 75 | 76 | # Translations 77 | *.mo 78 | *.pot 79 | 80 | # Django stuff: 81 | *.log 82 | 83 | # Sphinx documentation 84 | docs/_build/ 85 | 86 | # PyBuilder 87 | target/ 88 | 89 | *.ipynb 90 | *.params 91 | # *.json 92 | .vscode/ 93 | *.code-workspace/ 94 | 95 | lib/pycocotools/_mask.c 96 | lib/nms/cpu_nms.c 97 | 98 | DATA/ 99 | logs/ 100 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DiffBlender: Scalable and Composable Multimodal Text-to-Image Diffusion Models 🔥 2 | 3 | 4 | 5 | 6 | 7 | 8 | - **DiffBlender** successfully synthesizes complex combinations of input modalities. It enables flexible manipulation of conditions, providing the customized generation aligned with user preferences. 9 | - We designed its structure to intuitively extend to additional modalities while achieving a low training cost through a partial update of hypernetworks. 10 | 11 |
<p align="center">
12 |     <img src="assets/fig1.png" alt="teaser">
13 | </p>
14 | 15 | ## 🗓️ TODOs 16 | 17 | - [x] Project page is open: [link](https://sungnyun.github.io/diffblender/) 18 | - [x] DiffBlender model: code & checkpoint 19 | - [x] Release inference code 20 | - [ ] Release training code & pipeline 21 | - [ ] Gradio UI 22 | 23 | ## 🚀 Getting Started 24 | Install the necessary packages with: 25 | ```sh 26 | $ pip install -r requirements.txt 27 | ``` 28 | 29 | Download DiffBlender model checkpoint from this [Huggingface model](https://huggingface.co/sungnyun/diffblender), and place it under `./diffblender_checkpoints/`. 30 | Also, prepare the SD model from this [link](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original) (we used CompVis/sd-v1-4.ckpt). 31 | 32 | ## ⚡️ Try Multimodal T2I Generation with DiffBlender 33 | ```sh 34 | $ python inference.py --ckpt_path=./diffblender_checkpoints/{CKPT_NAME}.pth \ 35 | --official_ckpt_path=/path/to/sd-v1-4.ckpt \ 36 | --save_name={SAVE_NAME} 37 | ``` 38 | 39 | Results will be saved under `./inference/{SAVE_NAME}/`, in the format as {conditions + generated image}. 40 | 41 | 42 | 43 | ## BibTeX 44 | ``` 45 | @article{kim2023diffblender, 46 | title={DiffBlender: Scalable and Composable Multimodal Text-to-Image Diffusion Models}, 47 | author={Kim, Sungnyun and Lee, Junsoo and Hong, Kibeom and Kim, Daesik and Ahn, Namhyuk}, 48 | journal={arXiv preprint arXiv:2305.15194}, 49 | year={2023} 50 | } 51 | ``` 52 | -------------------------------------------------------------------------------- /assets/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sungnyun/diffblender/b3481d17d0e13d89d45dcbb9a250e89459904031/assets/fig1.png -------------------------------------------------------------------------------- /grounding_input/condition_null_generator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch as th 3 | 4 | 5 | 6 | class BoxConditionInput: 7 | def __init__(self): 8 | self.set = False 9 | 10 | def prepare(self, batch): 11 | self.set = True 12 | 13 | boxes = batch["values"] 14 | masks = batch["masks"] 15 | text_embeddings = batch["text_embeddings"] 16 | 17 | self.batch_size, self.max_box, self.embedding_len = text_embeddings.shape 18 | self.device = text_embeddings.device 19 | self.dtype = text_embeddings.dtype 20 | 21 | # return {"values": boxes, "masks": masks, "text_embeddings": text_embeddings} 22 | 23 | def get_null_input(self, batch, batch_size=None, device=None, dtype=None): 24 | assert self.set, "not set yet, cannot call this funcion" 25 | batch_size = self.batch_size if batch_size is None else batch_size 26 | device = self.device if device is None else device 27 | dtype = self.dtype if dtype is None else dtype 28 | 29 | boxes = th.zeros(batch_size, self.max_box, 4).type(dtype).to(device) 30 | masks = th.zeros(batch_size, self.max_box).type(dtype).to(device) 31 | text_embeddings = th.zeros(batch_size, self.max_box, self.embedding_len).type(dtype).to(device) 32 | 33 | batch["values"] = boxes 34 | batch["masks"] = masks 35 | batch["text_embeddings"] = text_embeddings 36 | 37 | return batch 38 | 39 | 40 | class KeypointConditionInput: 41 | def __init__(self): 42 | self.set = False 43 | 44 | def prepare(self, batch): 45 | self.set = True 46 | 47 | points = batch["values"] 48 | masks = batch["masks"] 49 | 50 | self.batch_size, self.max_persons_per_image, _ = points.shape 51 | self.max_persons_per_image = int(self.max_persons_per_image / 17) 52 | self.device = points.device 
53 | self.dtype = points.dtype 54 | 55 | # return {"values": points, "masks": masks} 56 | 57 | def get_null_input(self, batch, batch_size=None, device=None, dtype=None): 58 | assert self.set, "not set yet, cannot call this funcion" 59 | batch_size = self.batch_size if batch_size is None else batch_size 60 | device = self.device if device is None else device 61 | dtype = self.dtype if dtype is None else dtype 62 | 63 | points = th.zeros(batch_size, self.max_persons_per_image*17, 2).to(device) 64 | masks = th.zeros(batch_size, self.max_persons_per_image*17).to(device) 65 | 66 | batch["values"] = points 67 | batch["masks"] = masks 68 | 69 | return batch 70 | 71 | 72 | class NSPVectorConditionInput: 73 | def __init__(self): 74 | self.set = False 75 | 76 | def prepare(self, batch): 77 | self.set = True 78 | 79 | vectors = batch["values"] 80 | masks = batch["masks"] 81 | 82 | self.batch_size, self.in_dim = vectors.shape 83 | self.device = vectors.device 84 | self.dtype = vectors.dtype 85 | 86 | # return {"values": vectors, "masks": masks} 87 | 88 | def get_null_input(self, batch, batch_size=None, device=None, dtype=None): 89 | assert self.set, "not set yet, cannot call this funcion" 90 | batch_size = self.batch_size if batch_size is None else batch_size 91 | device = self.device if device is None else device 92 | dtype = self.dtype if dtype is None else dtype 93 | 94 | vectors = th.zeros(batch_size, self.in_dim).to(device) 95 | masks = th.zeros(batch_size, 1).to(device) 96 | 97 | batch["values"] = vectors 98 | batch["masks"] = masks 99 | 100 | return batch 101 | 102 | 103 | class ImageConditionInput: 104 | def __init__(self): 105 | self.set = False 106 | 107 | def prepare(self, batch): 108 | self.set = True 109 | 110 | image = batch["values"] 111 | self.batch_size, self.in_channel, self.H, self.W = image.shape 112 | self.device = image.device 113 | self.dtype = image.dtype 114 | 115 | # return {"values": image} 116 | 117 | def get_null_input(self, batch, batch_size=None, device=None, dtype=None): 118 | assert self.set, "not set yet, cannot call this funcion" 119 | batch_size = self.batch_size if batch_size is None else batch_size 120 | device = self.device if device is None else device 121 | dtype = self.dtype if dtype is None else dtype 122 | 123 | image = th.zeros(batch_size, self.in_channel, self.H, self.W).to(device) 124 | masks = th.zeros(batch_size, 1).to(device) 125 | 126 | batch["values"] = image 127 | batch["masks"] = masks 128 | 129 | return batch 130 | 131 | -------------------------------------------------------------------------------- /images/color1.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sungnyun/diffblender/b3481d17d0e13d89d45dcbb9a250e89459904031/images/color1.pth -------------------------------------------------------------------------------- /images/color2.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sungnyun/diffblender/b3481d17d0e13d89d45dcbb9a250e89459904031/images/color2.pth -------------------------------------------------------------------------------- /images/dummy.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sungnyun/diffblender/b3481d17d0e13d89d45dcbb9a250e89459904031/images/dummy.pth -------------------------------------------------------------------------------- /images/fire.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/sungnyun/diffblender/b3481d17d0e13d89d45dcbb9a250e89459904031/images/fire.png -------------------------------------------------------------------------------- /images/jeep_depth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sungnyun/diffblender/b3481d17d0e13d89d45dcbb9a250e89459904031/images/jeep_depth.png -------------------------------------------------------------------------------- /images/jeep_sketch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sungnyun/diffblender/b3481d17d0e13d89d45dcbb9a250e89459904031/images/jeep_sketch.png -------------------------------------------------------------------------------- /images/nature.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sungnyun/diffblender/b3481d17d0e13d89d45dcbb9a250e89459904031/images/nature.png -------------------------------------------------------------------------------- /images/partial_sketch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sungnyun/diffblender/b3481d17d0e13d89d45dcbb9a250e89459904031/images/partial_sketch.png -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import random 4 | from PIL import Image 5 | from omegaconf import OmegaConf 6 | from copy import deepcopy 7 | from functools import partial 8 | 9 | import torch 10 | import torchvision 11 | import torchvision.transforms as tf 12 | from torch.utils.data import DataLoader 13 | import pytorch_lightning 14 | 15 | from ldm.util import instantiate_from_config 16 | from ldm.models.diffusion.plms import PLMSSampler 17 | from transformers import CLIPProcessor, CLIPModel 18 | from visualization.extract_utils import get_sketch, get_depth, get_box, get_keypoint, get_color_palette, get_clip_feature 19 | import visualization.image_utils as iutils 20 | from visualization.draw_utils import * 21 | 22 | 23 | device = "cuda" 24 | 25 | 26 | def read_official_ckpt(ckpt_path, no_model=False): 27 | state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"] 28 | out = {} 29 | out["model"] = {} 30 | out["text_encoder"] = {} 31 | out["autoencoder"] = {} 32 | out["unexpected"] = {} 33 | out["diffusion"] = {} 34 | 35 | for k,v in state_dict.items(): 36 | if k.startswith('model.diffusion_model'): 37 | if no_model: 38 | continue 39 | out["model"][k.replace("model.diffusion_model.", "")] = v 40 | elif k.startswith('cond_stage_model'): 41 | out["text_encoder"][k.replace("cond_stage_model.", "")] = v 42 | elif k.startswith('first_stage_model'): 43 | out["autoencoder"][k.replace("first_stage_model.", "")] = v 44 | elif k in ["model_ema.decay", "model_ema.num_updates"]: 45 | out["unexpected"][k] = v 46 | else: 47 | out["diffusion"][k] = v 48 | 49 | if no_model: 50 | del state_dict 51 | return out 52 | 53 | def batch_to_device(batch, device): 54 | for k in batch: 55 | if isinstance(batch[k], torch.Tensor): 56 | batch[k] = batch[k].to(device) 57 | 58 | elif isinstance(batch[k], dict): 59 | for k_2 in batch[k]: 60 | if isinstance(batch[k][k_2], torch.Tensor): 61 | batch[k][k_2] = batch[k][k_2].to(device) 62 | return batch 
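# Usage sketch (illustrative; the checkpoint path and tensor shapes below are placeholders, not from this repo):
#   sd_parts = read_official_ckpt("/path/to/sd-v1-4.ckpt")
#   # -> dict with keys "model", "text_encoder", "autoencoder", "diffusion", "unexpected"
#   batch = {"caption": ["a jeep"], "sketch": {"values": torch.zeros(1, 1, 512, 512), "masks": torch.tensor([[1.]])}}
#   batch = batch_to_device(batch, device)   # moves tensors (also those one level inside dicts) to the GPU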
63 | 64 | def load_ckpt(ckpt_path, official_ckpt_path='./sd-v1-4.ckpt'): 65 | 66 | saved_ckpt = torch.load(ckpt_path) 67 | config = saved_ckpt["config_dict"]["_content"] 68 | 69 | model = instantiate_from_config(config['model']).to(device).eval() 70 | autoencoder = instantiate_from_config(config['autoencoder']).to(device).eval() 71 | text_encoder = instantiate_from_config(config['text_encoder']).to(device).eval() 72 | diffusion = instantiate_from_config(config['diffusion']).to(device) 73 | 74 | # do not need to load official_ckpt for self.model here, since we will load from our ckpt 75 | missing, unexpected = model.load_state_dict( saved_ckpt['model'], strict=False ) 76 | assert missing == [] 77 | # print('unexpected keys:', unexpected) 78 | 79 | official_ckpt = read_official_ckpt(official_ckpt_path) 80 | autoencoder.load_state_dict( official_ckpt["autoencoder"] ) 81 | text_encoder.load_state_dict( official_ckpt["text_encoder"] ) 82 | diffusion.load_state_dict( official_ckpt["diffusion"] ) 83 | 84 | if model.use_autoencoder_kl: 85 | for mode, input_type in zip(model.input_modalities, model.input_types): 86 | if input_type == "image": 87 | model.condition_nets[mode].autoencoder = deepcopy(autoencoder) 88 | model.condition_nets[mode].set = True 89 | 90 | return model, autoencoder, text_encoder, diffusion, config 91 | 92 | def set_alpha_scale(model, alpha_scale): 93 | from ldm.modules.multimodal_attention import GatedCrossAttentionDense, GatedSelfAttentionDense 94 | from ldm.modules.diffusionmodules.multimodal_openaimodel import UNetModel 95 | alpha_scale_sp, alpha_scale_nsp, alpha_scale_image = alpha_scale 96 | for name, module in model.named_modules(): 97 | if type(module) == GatedCrossAttentionDense or type(module) == GatedSelfAttentionDense: 98 | if '.sp_fuser' in name: 99 | module.scale = alpha_scale_sp 100 | elif '.nsp_fuser' in name: 101 | module.scale = alpha_scale_nsp 102 | elif type(module) == UNetModel: 103 | module.scales = [alpha_scale_image] * 4 104 | 105 | def alpha_generator(length, config): 106 | """ 107 | length is the total number of timesteps needed for sampling. 108 | type should be a list of four values: the first three are stage fractions that sum to 1, and the last is the alpha scale. 109 | 110 | It means the percentage of three stages: 111 | alpha=scale stage 112 | linear decay stage 113 | alpha=0 stage. 114 | 115 | For example, if length=100 and type=[0.8, 0.1, 0.1, _scale_], 116 | then for the first 80 steps alpha will be _scale_, then it linearly decays to 0 over the next 10 steps, 117 | and the last 10 steps are 0.
118 | """ 119 | 120 | alpha_schedule_sp = config['alpha_type_sp'] 121 | alpha_schedule_nsp = config['alpha_type_nsp'] 122 | alpha_schedule_image = config['alpha_type_image'] 123 | 124 | alphas_ = list() 125 | for alpha_schedule in [alpha_schedule_sp, alpha_schedule_nsp, alpha_schedule_image]: 126 | 127 | assert len(alpha_schedule)==4 128 | assert alpha_schedule[0] + alpha_schedule[1] + alpha_schedule[2] == 1 129 | 130 | stage0_length = int(alpha_schedule[0]*length) 131 | stage1_length = int(alpha_schedule[1]*length) 132 | stage2_length = length - stage0_length - stage1_length 133 | 134 | if stage1_length != 0: 135 | decay_alphas = alpha_schedule[3] * np.arange(start=0, stop=1, step=1/stage1_length)[::-1] 136 | decay_alphas = list(decay_alphas) 137 | else: 138 | decay_alphas = [] 139 | 140 | alphas = [alpha_schedule[3]]*stage0_length + decay_alphas + [0]*stage2_length 141 | 142 | assert len(alphas) == length 143 | alphas_.append(alphas) 144 | 145 | return list(zip(*alphas_)) 146 | 147 | def preprocess(prompt="", 148 | sketch=None, 149 | depth=None, 150 | phrases=None, 151 | locations=None, 152 | keypoints=None, 153 | color=None, 154 | reference=None): 155 | batch = dict() 156 | null_conditions = [] 157 | 158 | batch = torch.load('./images/dummy.pth', map_location='cpu') # dummy var 159 | batch["caption"] = [prompt] 160 | 161 | if sketch is not None: 162 | sketch_tensor = get_sketch(sketch) 163 | selected_sketch = dict() 164 | selected_sketch["values"] = sketch_tensor.unsqueeze(0) 165 | selected_sketch["masks"] = torch.tensor([[1.]]) 166 | batch["sketch"] = selected_sketch 167 | else: 168 | null_conditions.append("sketch") 169 | 170 | if depth is not None: 171 | depth_tensor = get_depth(depth) 172 | selected_depth = dict() 173 | selected_depth["values"] = depth_tensor.unsqueeze(0) 174 | selected_depth["masks"] = torch.tensor([[1.]]) 175 | batch["depth"] = selected_depth 176 | else: 177 | null_conditions.append("depth") 178 | 179 | if locations is not None and phrases is not None: 180 | version = "openai/clip-vit-large-patch14" 181 | clip_model = CLIPModel.from_pretrained(version).cuda() 182 | clip_processor = CLIPProcessor.from_pretrained(version) 183 | boxes, masks, text_embeddings = get_box(locations, phrases, clip_model, clip_processor) 184 | selected_box = dict() 185 | selected_box["values"] = boxes.unsqueeze(0) 186 | selected_box["masks"] = masks.unsqueeze(0) 187 | selected_box["text_embeddings"] = text_embeddings.unsqueeze(0) 188 | batch["box"] = selected_box 189 | else: 190 | null_conditions.append("box") 191 | 192 | if keypoints is not None: 193 | points, masks = get_keypoint(keypoints) 194 | selected_keypoint = dict() 195 | selected_keypoint["values"] = points.unsqueeze(0) 196 | selected_keypoint["masks"] = masks.unsqueeze(0) 197 | batch["keypoint"] = selected_keypoint 198 | else: 199 | null_conditions.append("keypoint") 200 | 201 | if color is not None: 202 | selected_color_palette = dict() 203 | # color_palette = get_color_palette(color) # for .png file 204 | # selected_color_palette["values"] = torch.tensor(color_palette, dtype=torch.float32).unsqueeze(0) 205 | selected_color_palette["values"] = torch.load(color).unsqueeze(0) 206 | selected_color_palette["masks"] = torch.tensor([[1.]]) 207 | batch["color_palette"] = selected_color_palette 208 | else: 209 | null_conditions.append("color_palette") 210 | 211 | if reference is not None: 212 | version = "openai/clip-vit-large-patch14" 213 | clip_model = CLIPModel.from_pretrained(version).cuda() 214 | clip_processor = 
CLIPProcessor.from_pretrained(version) 215 | clip_features = get_clip_feature(reference, clip_model, clip_processor) 216 | selected_image_embedding = dict() 217 | selected_image_embedding["values"] = torch.tensor(clip_features).unsqueeze(0) 218 | selected_image_embedding["masks"] = torch.tensor([[1.]]) 219 | batch["image_embedding"] = selected_image_embedding 220 | batch["image_embedding"]["image"] = tf.ToTensor()(tf.Resize((512,512))(Image.open(reference).convert('RGB'))).unsqueeze(0) 221 | else: 222 | batch["image_embedding"]["image"] = -torch.ones_like(batch["image"]) 223 | null_conditions.append("image_embedding") 224 | 225 | return batch, null_conditions 226 | 227 | 228 | 229 | @torch.no_grad() 230 | def run(selected_batch_, 231 | config, 232 | model, 233 | autoencoder, 234 | text_encoder, 235 | diffusion, 236 | condition_null_generator_dict, 237 | idx, 238 | NULL_CONDITION, 239 | SAVE_NAME, 240 | seed): 241 | 242 | #### Starting noise fixed #### 243 | torch.manual_seed(seed) 244 | torch.cuda.manual_seed(seed) 245 | starting_noise = torch.randn(1, 4, 64, 64).to(device) 246 | 247 | selected_batch = deepcopy(selected_batch_) 248 | uc_batch = deepcopy(selected_batch_) 249 | for mode in condition_null_generator_dict: 250 | if mode in NULL_CONDITION: 251 | condition_null_generator = condition_null_generator_dict[mode] 252 | condition_null_generator.prepare(selected_batch[mode]) 253 | selected_batch[mode] = condition_null_generator.get_null_input(selected_batch[mode]) 254 | else: 255 | if mode in ["sketch", "depth"]: 256 | continue 257 | condition_null_generator = condition_null_generator_dict[mode] 258 | condition_null_generator.prepare(uc_batch[mode]) 259 | uc_batch[mode] = condition_null_generator.get_null_input(uc_batch[mode]) 260 | 261 | selected_batch = batch_to_device(selected_batch, device) 262 | uc_batch = batch_to_device(uc_batch, device) 263 | 264 | torch.cuda.empty_cache() 265 | 266 | batch_here = config['batch_size'] 267 | context = text_encoder.encode(selected_batch["caption"]) 268 | # you can set negative prompts here 269 | # uc = text_encoder.encode(batch_here*["longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"]) 270 | uc = text_encoder.encode(batch_here*[""]) 271 | 272 | # plms sampling 273 | alpha_generator_func = partial(alpha_generator, config=config) 274 | sampler = PLMSSampler(diffusion, model, alpha_generator_func=alpha_generator_func, set_alpha_scale=set_alpha_scale) 275 | steps = 50 276 | shape = (batch_here, model.in_channels, model.image_size, model.image_size) 277 | 278 | input_dict = dict(x = starting_noise, 279 | timesteps = None, 280 | context = context, 281 | inpainting_extra_input = None, 282 | condition = selected_batch ) 283 | 284 | uc_dict = dict(context = uc, 285 | condition = uc_batch ) 286 | 287 | samples = sampler.sample(S=steps, shape=shape, input=input_dict, uc_dict=uc_dict, guidance_scale=config['guidance_scale']) 288 | pred_image = autoencoder.decode(samples) 289 | 290 | image_dict = [ 291 | # {"tensors": selected_batch["image"], "n_in_row": 1, "pp_type": iutils.PP_RGB}, 292 | {"tensors": draw_sketch_with_batch_to_tensor(selected_batch), "n_in_row": 1, "pp_type": iutils.PP_SEGM}, 293 | {"tensors": draw_depth_with_batch_to_tensor(selected_batch), "n_in_row": 1, "pp_type": iutils.PP_SEGM}, 294 | {"tensors": draw_boxes_with_batch_to_tensor(selected_batch), "n_in_row": 1, "pp_type": iutils.PP_SEGM}, 295 | {"tensors": draw_keypoints_with_batch_to_tensor(selected_batch), 
"n_in_row": 1, "pp_type": iutils.PP_SEGM}, 296 | {"tensors": draw_image_embedding_with_batch_to_tensor(selected_batch), "n_in_row": 1, "pp_type": iutils.PP_RGB}, 297 | {"tensors": draw_color_palettes_with_batch_to_tensor(selected_batch), "n_in_row": 1, "pp_type": iutils.PP_SEGM}, # range 0~1 298 | {"tensors": pred_image, "n_in_row": 1, "pp_type": iutils.PP_RGB}, 299 | ] 300 | os.makedirs(os.path.join("inference", SAVE_NAME), exist_ok=True) 301 | iutils.save_images_from_dict( 302 | image_dict, dir_path=os.path.join("inference", SAVE_NAME), file_name="sampled_{:4d}".format(idx), 303 | n_instance=config['batch_size'], is_save=True, return_images=False 304 | ) 305 | save_path = os.path.join("inference", SAVE_NAME, 'captions.txt') 306 | with open(save_path, "a") as f: 307 | f.write( 'idx ' + str(idx) + ':\n' ) 308 | for cap in selected_batch['caption']: 309 | f.write( cap + '\n' ) 310 | f.write( '\n' ) 311 | print("Save images and its corresponding captions.. done") 312 | 313 | return pred_image.detach().cpu() 314 | 315 | 316 | 317 | if __name__ == "__main__": 318 | 319 | parser = argparse.ArgumentParser() 320 | parser.add_argument("--ckpt_path", type=str, default="./diffblender_checkpoints/checkpoint_latest.pth", help="pretrained checkpoint path") 321 | parser.add_argument("--official_ckpt_path", type=str, default="/path/to/sd-v1-4.ckpt", help="official SD path") 322 | parser.add_argument("--save_name", type=str, default="SAVE_NAME", help="") 323 | 324 | parser.add_argument("--alpha_type_sp", nargs='+', type=float, default=[0.3, 0.0, 0.7, 1.0], help="alpha scheduling type for spatial cond.") 325 | parser.add_argument("--alpha_type_nsp", nargs='+', type=float, default=[0.3, 0.0, 0.7, 1.0], help="alpha scheduling type for non-spatial cond.") 326 | parser.add_argument("--alpha_type_image", nargs='+', type=float, default=[1.0, 0.0, 0.0, 0.7], help="alpha scheduling type for image-form cond.") 327 | parser.add_argument("--guidance_scale", type=float, default=5.0, help="classifier-free guidance scale") 328 | 329 | args = parser.parse_args() 330 | 331 | 332 | model, autoencoder, text_encoder, diffusion, config = load_ckpt(ckpt_path=args.ckpt_path, official_ckpt_path=args.official_ckpt_path) 333 | condition_null_generator_dict = dict() 334 | for mode in config['condition_null_generator']['input_modalities']: 335 | condition_null_generator_dict[mode] = instantiate_from_config(config['condition_null_generator'][mode]) 336 | 337 | # replace config 338 | config['batch_size'] = 1 339 | config['alpha_type_sp'] = args.alpha_type_sp 340 | config['alpha_type_nsp'] = args.alpha_type_nsp 341 | config['alpha_type_image'] = args.alpha_type_image 342 | config['guidance_scale'] = args.guidance_scale 343 | 344 | kwargs_dict = dict( 345 | config=config, 346 | model=model, 347 | autoencoder=autoencoder, 348 | text_encoder=text_encoder, 349 | diffusion=diffusion, 350 | condition_null_generator_dict=condition_null_generator_dict, 351 | SAVE_NAME=args.save_name, 352 | ) 353 | 354 | 355 | meta_list = [ # change 356 | 357 | dict( 358 | prompt = "jeep", 359 | sketch = "images/jeep_sketch.png", 360 | depth = "images/jeep_depth.png", 361 | color = "images/color1.pth", # can also use image file via get_color_palette func 362 | reference = "images/fire.png", 363 | ), 364 | 365 | dict( 366 | prompt = "swimming rabbits", 367 | phrases = ["rabbit", "rabbit", "rabbit"], 368 | locations = [ [0.3500, 0.5000, 1.0000, 0.9500], [0.2000, 0.2500, 0.6000, 0.5500], [0.0500, 0.0500, 0.4000, 0.3000] ], 369 | color = "images/color2.pth", 370 | 
), 371 | 372 | dict( 373 | prompt = "jumping astronaut", 374 | sketch = "images/partial_sketch.png", 375 | phrases = ["astronaut"], 376 | locations = [[0.1158, 0.1053, 0.5140, 0.6111]], 377 | keypoints = [ 378 | [ [0.2767, 0.2025], 379 | [0.2617, 0.1875], 380 | [0.2917, 0.1875], 381 | [0.0000, 0.0000], 382 | [0.3117, 0.1800], 383 | [0.2192, 0.2375], 384 | [0.3392, 0.2425], 385 | [0.1942, 0.2850], 386 | [0.3967, 0.3075], 387 | [0.1667, 0.3475], 388 | [0.4142, 0.3675], 389 | [0.2592, 0.3775], 390 | [0.3242, 0.3700], 391 | [0.2717, 0.4425], 392 | [0.3992, 0.4375], 393 | [0.2367, 0.5550], 394 | [0.4067, 0.5225], ] 395 | ], 396 | reference = "images/nature.png", 397 | ), 398 | 399 | ] 400 | 401 | seed_list = [40, 10, 20] # change 402 | 403 | for idx, (meta, seed) in enumerate(zip(meta_list, seed_list)): 404 | batch, null_conditions = preprocess(**meta) 405 | kwargs_dict['idx'] = idx 406 | kwargs_dict['seed'] = seed 407 | kwargs_dict['NULL_CONDITION'] = null_conditions 408 | pred_image = run(batch, **kwargs_dict) 409 | -------------------------------------------------------------------------------- /ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 
50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /ldm/models/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | #import pytorch_lightning as pl 4 | import torch.nn.functional as F 5 | from contextlib import contextmanager 6 | 7 | # from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer 8 | 9 | from ldm.modules.diffusionmodules.model import Encoder, Decoder 10 | from ldm.modules.distributions.distributions import DiagonalGaussianDistribution 11 | 12 | from ldm.util import instantiate_from_config 13 | 14 | 15 | 16 | 17 | class AutoencoderKL(nn.Module): 18 | def __init__(self, 19 | ddconfig, 20 | embed_dim, 21 | scale_factor=1 22 | ): 23 | super().__init__() 24 | self.encoder = Encoder(**ddconfig) 25 | self.decoder = Decoder(**ddconfig) 26 | assert ddconfig["double_z"] 27 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) 28 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 29 | self.embed_dim = embed_dim 30 | self.scale_factor = scale_factor 31 | 32 | 33 | 34 | def encode(self, x): 35 | h = self.encoder(x) 36 | moments = self.quant_conv(h) 37 | posterior = DiagonalGaussianDistribution(moments) 38 | return posterior.sample() * self.scale_factor 39 | 40 | def decode(self, z): 41 | z = 1. 
/ self.scale_factor * z 42 | z = self.post_quant_conv(z) 43 | dec = self.decoder(z) 44 | return dec 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sungnyun/diffblender/b3481d17d0e13d89d45dcbb9a250e89459904031/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /ldm/models/diffusion/classifier.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pytorch_lightning as pl 4 | from omegaconf import OmegaConf 5 | from torch.nn import functional as F 6 | from torch.optim import AdamW 7 | from torch.optim.lr_scheduler import LambdaLR 8 | from copy import deepcopy 9 | from einops import rearrange 10 | from glob import glob 11 | from natsort import natsorted 12 | 13 | from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel 14 | from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config 15 | 16 | __models__ = { 17 | 'class_label': EncoderUNetModel, 18 | 'segmentation': UNetModel 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | class NoisyLatentImageClassifier(pl.LightningModule): 29 | 30 | def __init__(self, 31 | diffusion_path, 32 | num_classes, 33 | ckpt_path=None, 34 | pool='attention', 35 | label_key=None, 36 | diffusion_ckpt_path=None, 37 | scheduler_config=None, 38 | weight_decay=1.e-2, 39 | log_steps=10, 40 | monitor='val/loss', 41 | *args, 42 | **kwargs): 43 | super().__init__(*args, **kwargs) 44 | self.num_classes = num_classes 45 | # get latest config of diffusion model 46 | diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] 47 | self.diffusion_config = OmegaConf.load(diffusion_config).model 48 | self.diffusion_config.params.ckpt_path = diffusion_ckpt_path 49 | self.load_diffusion() 50 | 51 | self.monitor = monitor 52 | self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 53 | self.log_time_interval = self.diffusion_model.num_timesteps // log_steps 54 | self.log_steps = log_steps 55 | 56 | self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ 57 | else self.diffusion_model.cond_stage_key 58 | 59 | assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' 60 | 61 | if self.label_key not in __models__: 62 | raise NotImplementedError() 63 | 64 | self.load_classifier(ckpt_path, pool) 65 | 66 | self.scheduler_config = scheduler_config 67 | self.use_scheduler = self.scheduler_config is not None 68 | self.weight_decay = weight_decay 69 | 70 | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): 71 | sd = torch.load(path, map_location="cpu") 72 | if "state_dict" in list(sd.keys()): 73 | sd = sd["state_dict"] 74 | keys = list(sd.keys()) 75 | for k in keys: 76 | for ik in ignore_keys: 77 | if k.startswith(ik): 78 | print("Deleting key {} from state_dict.".format(k)) 79 | del sd[k] 80 | missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( 81 | sd, strict=False) 82 | print(f"Restored from {path} with {len(missing)} missing and 
{len(unexpected)} unexpected keys") 83 | if len(missing) > 0: 84 | print(f"Missing Keys: {missing}") 85 | if len(unexpected) > 0: 86 | print(f"Unexpected Keys: {unexpected}") 87 | 88 | def load_diffusion(self): 89 | model = instantiate_from_config(self.diffusion_config) 90 | self.diffusion_model = model.eval() 91 | self.diffusion_model.train = disabled_train 92 | for param in self.diffusion_model.parameters(): 93 | param.requires_grad = False 94 | 95 | def load_classifier(self, ckpt_path, pool): 96 | model_config = deepcopy(self.diffusion_config.params.unet_config.params) 97 | model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels 98 | model_config.out_channels = self.num_classes 99 | if self.label_key == 'class_label': 100 | model_config.pool = pool 101 | 102 | self.model = __models__[self.label_key](**model_config) 103 | if ckpt_path is not None: 104 | print('#####################################################################') 105 | print(f'load from ckpt "{ckpt_path}"') 106 | print('#####################################################################') 107 | self.init_from_ckpt(ckpt_path) 108 | 109 | @torch.no_grad() 110 | def get_x_noisy(self, x, t, noise=None): 111 | noise = default(noise, lambda: torch.randn_like(x)) 112 | continuous_sqrt_alpha_cumprod = None 113 | if self.diffusion_model.use_continuous_noise: 114 | continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) 115 | # todo: make sure t+1 is correct here 116 | 117 | return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, 118 | continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) 119 | 120 | def forward(self, x_noisy, t, *args, **kwargs): 121 | return self.model(x_noisy, t) 122 | 123 | @torch.no_grad() 124 | def get_input(self, batch, k): 125 | x = batch[k] 126 | if len(x.shape) == 3: 127 | x = x[..., None] 128 | x = rearrange(x, 'b h w c -> b c h w') 129 | x = x.to(memory_format=torch.contiguous_format).float() 130 | return x 131 | 132 | @torch.no_grad() 133 | def get_conditioning(self, batch, k=None): 134 | if k is None: 135 | k = self.label_key 136 | assert k is not None, 'Needs to provide label key' 137 | 138 | targets = batch[k].to(self.device) 139 | 140 | if self.label_key == 'segmentation': 141 | targets = rearrange(targets, 'b h w c -> b c h w') 142 | for down in range(self.numd): 143 | h, w = targets.shape[-2:] 144 | targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') 145 | 146 | # targets = rearrange(targets,'b c h w -> b h w c') 147 | 148 | return targets 149 | 150 | def compute_top_k(self, logits, labels, k, reduction="mean"): 151 | _, top_ks = torch.topk(logits, k, dim=1) 152 | if reduction == "mean": 153 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() 154 | elif reduction == "none": 155 | return (top_ks == labels[:, None]).float().sum(dim=-1) 156 | 157 | def on_train_epoch_start(self): 158 | # save some memory 159 | self.diffusion_model.model.to('cpu') 160 | 161 | @torch.no_grad() 162 | def write_logs(self, loss, logits, targets): 163 | log_prefix = 'train' if self.training else 'val' 164 | log = {} 165 | log[f"{log_prefix}/loss"] = loss.mean() 166 | log[f"{log_prefix}/acc@1"] = self.compute_top_k( 167 | logits, targets, k=1, reduction="mean" 168 | ) 169 | log[f"{log_prefix}/acc@5"] = self.compute_top_k( 170 | logits, targets, k=5, reduction="mean" 171 | ) 172 | 173 | self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) 174 | 
self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) 175 | self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) 176 | lr = self.optimizers().param_groups[0]['lr'] 177 | self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) 178 | 179 | def shared_step(self, batch, t=None): 180 | x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) 181 | targets = self.get_conditioning(batch) 182 | if targets.dim() == 4: 183 | targets = targets.argmax(dim=1) 184 | if t is None: 185 | t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() 186 | else: 187 | t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() 188 | x_noisy = self.get_x_noisy(x, t) 189 | logits = self(x_noisy, t) 190 | 191 | loss = F.cross_entropy(logits, targets, reduction='none') 192 | 193 | self.write_logs(loss.detach(), logits.detach(), targets.detach()) 194 | 195 | loss = loss.mean() 196 | return loss, logits, x_noisy, targets 197 | 198 | def training_step(self, batch, batch_idx): 199 | loss, *_ = self.shared_step(batch) 200 | return loss 201 | 202 | def reset_noise_accs(self): 203 | self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in 204 | range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} 205 | 206 | def on_validation_start(self): 207 | self.reset_noise_accs() 208 | 209 | @torch.no_grad() 210 | def validation_step(self, batch, batch_idx): 211 | loss, *_ = self.shared_step(batch) 212 | 213 | for t in self.noisy_acc: 214 | _, logits, _, targets = self.shared_step(batch, t) 215 | self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) 216 | self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) 217 | 218 | return loss 219 | 220 | def configure_optimizers(self): 221 | optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) 222 | 223 | if self.use_scheduler: 224 | scheduler = instantiate_from_config(self.scheduler_config) 225 | 226 | print("Setting up LambdaLR scheduler...") 227 | scheduler = [ 228 | { 229 | 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), 230 | 'interval': 'step', 231 | 'frequency': 1 232 | }] 233 | return [optimizer], scheduler 234 | 235 | return optimizer 236 | 237 | @torch.no_grad() 238 | def log_images(self, batch, N=8, *args, **kwargs): 239 | log = dict() 240 | x = self.get_input(batch, self.diffusion_model.first_stage_key) 241 | log['inputs'] = x 242 | 243 | y = self.get_conditioning(batch) 244 | 245 | if self.label_key == 'class_label': 246 | y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) 247 | log['labels'] = y 248 | 249 | if ismap(y): 250 | log['labels'] = self.diffusion_model.to_rgb(y) 251 | 252 | for step in range(self.log_steps): 253 | current_time = step * self.log_time_interval 254 | 255 | _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) 256 | 257 | log[f'inputs@t{current_time}'] = x_noisy 258 | 259 | pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) 260 | pred = rearrange(pred, 'b h w c -> b c h w') 261 | 262 | log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) 263 | 264 | for key in log: 265 | log[key] = log[key][:N] 266 | 267 | return log 268 | -------------------------------------------------------------------------------- /ldm/models/diffusion/ddim.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from tqdm import tqdm 4 | from functools import partial 5 | 6 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like 7 | 8 | 9 | class DDIMSampler(object): 10 | def __init__(self, diffusion, model, schedule="linear", alpha_generator_func=None, set_alpha_scale=None): 11 | super().__init__() 12 | self.diffusion = diffusion 13 | self.model = model 14 | self.device = diffusion.betas.device 15 | self.ddpm_num_timesteps = diffusion.num_timesteps 16 | self.schedule = schedule 17 | self.alpha_generator_func = alpha_generator_func 18 | self.set_alpha_scale = set_alpha_scale 19 | 20 | 21 | def register_buffer(self, name, attr): 22 | if type(attr) == torch.Tensor: 23 | attr = attr.to(self.device) 24 | setattr(self, name, attr) 25 | 26 | 27 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.): 28 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 29 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=False) 30 | alphas_cumprod = self.diffusion.alphas_cumprod 31 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 32 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device) 33 | 34 | self.register_buffer('betas', to_torch(self.diffusion.betas)) 35 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 36 | self.register_buffer('alphas_cumprod_prev', to_torch(self.diffusion.alphas_cumprod_prev)) 37 | 38 | # calculations for diffusion q(x_t | x_{t-1}) and others 39 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 40 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 41 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 42 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 43 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 44 | 45 | # ddim sampling parameters 46 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 47 | ddim_timesteps=self.ddim_timesteps, 48 | eta=ddim_eta,verbose=False) 49 | self.register_buffer('ddim_sigmas', ddim_sigmas) 50 | self.register_buffer('ddim_alphas', ddim_alphas) 51 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 52 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. 
- ddim_alphas)) 53 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 54 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 55 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 56 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 57 | 58 | 59 | @torch.no_grad() 60 | def sample(self, S, shape, input, uc_dict=None, guidance_scale=1, mask=None, x0=None): 61 | self.make_schedule(ddim_num_steps=S) 62 | return self.ddim_sampling(shape, input, uc_dict, guidance_scale, mask=mask, x0=x0) 63 | 64 | 65 | @torch.no_grad() 66 | def ddim_sampling(self, shape, input, uc_dict=None, guidance_scale=1, mask=None, x0=None): 67 | b = shape[0] 68 | 69 | img = input["x"] 70 | if img == None: 71 | img = torch.randn(shape, device=self.device) 72 | input["x"] = img 73 | 74 | 75 | time_range = np.flip(self.ddim_timesteps) 76 | total_steps = self.ddim_timesteps.shape[0] 77 | 78 | #iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) 79 | iterator = time_range 80 | 81 | if self.alpha_generator_func != None: 82 | alphas = self.alpha_generator_func(len(iterator)) 83 | 84 | 85 | for i, step in enumerate(iterator): 86 | 87 | # set alpha 88 | if self.alpha_generator_func != None: 89 | self.set_alpha_scale(self.model, alphas[i]) 90 | 91 | # run 92 | index = total_steps - i - 1 93 | input["timesteps"] = torch.full((b,), step, device=self.device, dtype=torch.long) 94 | 95 | if mask is not None: 96 | assert x0 is not None 97 | img_orig = self.diffusion.q_sample( x0, input["timesteps"] ) 98 | img = img_orig * mask + (1. - mask) * img 99 | input["x"] = img 100 | 101 | img, pred_x0 = self.p_sample_ddim(input, index=index, uc_dict=uc_dict, guidance_scale=guidance_scale) 102 | input["x"] = img 103 | 104 | return img 105 | 106 | 107 | @torch.no_grad() 108 | def p_sample_ddim(self, input, index, uc_dict=None, guidance_scale=1): 109 | 110 | 111 | e_t = self.model(input) 112 | if uc_dict is not None and guidance_scale != 1: 113 | unconditional_input = dict( 114 | x=input["x"], 115 | timesteps=input["timesteps"], 116 | context=uc_dict['context'], 117 | inpainting_extra_input=input["inpainting_extra_input"], 118 | condition=uc_dict['condition'], 119 | ) 120 | e_t_uncond = self.model( unconditional_input ) 121 | e_t = e_t_uncond + guidance_scale * (e_t - e_t_uncond) 122 | 123 | # select parameters corresponding to the currently considered timestep 124 | b = input["x"].shape[0] 125 | a_t = torch.full((b, 1, 1, 1), self.ddim_alphas[index], device=self.device) 126 | a_prev = torch.full((b, 1, 1, 1), self.ddim_alphas_prev[index], device=self.device) 127 | sigma_t = torch.full((b, 1, 1, 1), self.ddim_sigmas[index], device=self.device) 128 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), self.ddim_sqrt_one_minus_alphas[index],device=self.device) 129 | 130 | # current prediction for x_0 131 | pred_x0 = (input["x"] - sqrt_one_minus_at * e_t) / a_t.sqrt() 132 | 133 | # direction pointing to x_t 134 | dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t 135 | noise = sigma_t * torch.randn_like( input["x"] ) 136 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 137 | 138 | return x_prev, pred_x0 139 | -------------------------------------------------------------------------------- /ldm/models/diffusion/ddpm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from functools import partial 5 | from ldm.modules.diffusionmodules.util import make_beta_schedule 6 | 7 | 8 | 9 | 10 | 11 | class DDPM(nn.Module): 12 | def __init__(self, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 13 | super().__init__() 14 | 15 | self.v_posterior = 0 16 | self.register_schedule(beta_schedule, timesteps, linear_start, linear_end, cosine_s) 17 | 18 | 19 | def register_schedule(self, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 20 | 21 | betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) 22 | alphas = 1. - betas 23 | alphas_cumprod = np.cumprod(alphas, axis=0) 24 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 25 | 26 | timesteps, = betas.shape 27 | self.num_timesteps = int(timesteps) 28 | self.linear_start = linear_start 29 | self.linear_end = linear_end 30 | assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' 31 | 32 | to_torch = partial(torch.tensor, dtype=torch.float32) 33 | 34 | self.register_buffer('betas', to_torch(betas)) 35 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 36 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) 37 | 38 | # calculations for diffusion q(x_t | x_{t-1}) and others 39 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) 40 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) 41 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) 42 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) 43 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) 44 | 45 | # calculations for posterior q(x_{t-1} | x_t, x_0) 46 | posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( 1. - alphas_cumprod) + self.v_posterior * betas 47 | # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) 48 | 49 | self.register_buffer('posterior_variance', to_torch(posterior_variance)) 50 | 51 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 52 | self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) 53 | self.register_buffer('posterior_mean_coef1', to_torch( betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) 54 | self.register_buffer('posterior_mean_coef2', to_torch( (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. 
- alphas_cumprod))) 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | -------------------------------------------------------------------------------- /ldm/models/diffusion/ldm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from tqdm import tqdm 5 | from ldm.util import default 6 | from ldm.modules.diffusionmodules.util import extract_into_tensor 7 | from .ddpm import DDPM 8 | 9 | 10 | 11 | class LatentDiffusion(DDPM): 12 | def __init__(self, *args, **kwargs): 13 | super().__init__(*args, **kwargs) 14 | # hardcoded 15 | self.clip_denoised = False 16 | 17 | 18 | 19 | def q_sample(self, x_start, t, noise=None): 20 | noise = default(noise, lambda: torch.randn_like(x_start)) 21 | return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 22 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) 23 | 24 | 25 | "Does not support DDPM sampling anymore. Only do DDIM or PLMS" 26 | 27 | # = = = = = = = = = = = = Below is for sampling = = = = = = = = = = = = # 28 | 29 | # def predict_start_from_noise(self, x_t, t, noise): 30 | # return ( extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - 31 | # extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise ) 32 | 33 | # def q_posterior(self, x_start, x_t, t): 34 | # posterior_mean = ( 35 | # extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + 36 | # extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t 37 | # ) 38 | # posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) 39 | # posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) 40 | # return posterior_mean, posterior_variance, posterior_log_variance_clipped 41 | 42 | 43 | # def p_mean_variance(self, model, x, c, t): 44 | 45 | # model_out = model(x, t, c) 46 | # x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) 47 | 48 | # if self.clip_denoised: 49 | # x_recon.clamp_(-1., 1.) 
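        # (editor's note, added for reference) The commented-out q_posterior / p_mean_variance
        # helpers in this file implement the standard DDPM posterior q(x_{t-1} | x_t, x_0):
        #     mean(t) = posterior_mean_coef1[t] * x_0 + posterior_mean_coef2[t] * x_t
        #     posterior_mean_coef1 = beta_t * sqrt(alphabar_{t-1}) / (1 - alphabar_t)
        #     posterior_mean_coef2 = (1 - alphabar_{t-1}) * sqrt(alpha_t) / (1 - alphabar_t)
        # i.e. exactly the buffers registered in DDPM.register_schedule (ddpm.py); only the
        # DDIM/PLMS samplers are used for sampling in this repo.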
50 | 51 | # model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) 52 | # return model_mean, posterior_variance, posterior_log_variance, x_recon 53 | 54 | 55 | # @torch.no_grad() 56 | # def p_sample(self, model, x, c, t): 57 | # b, *_, device = *x.shape, x.device 58 | # model_mean, _, model_log_variance, x0 = self.p_mean_variance(model, x=x, c=c, t=t, ) 59 | # noise = torch.randn_like(x) 60 | 61 | # # no noise when t == 0 62 | # nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) 63 | 64 | # return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 65 | 66 | 67 | # @torch.no_grad() 68 | # def p_sample_loop(self, model, shape, c): 69 | # device = self.betas.device 70 | # b = shape[0] 71 | # img = torch.randn(shape, device=device) 72 | 73 | # iterator = tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps) 74 | # for i in iterator: 75 | # ts = torch.full((b,), i, device=device, dtype=torch.long) 76 | # img, x0 = self.p_sample(model, img, c, ts) 77 | 78 | # return img 79 | 80 | 81 | # @torch.no_grad() 82 | # def sample(self, model, shape, c, uc=None, guidance_scale=None): 83 | # return self.p_sample_loop(model, shape, c) 84 | 85 | 86 | 87 | 88 | 89 | -------------------------------------------------------------------------------- /ldm/models/diffusion/plms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from tqdm import tqdm 4 | from functools import partial 5 | from copy import deepcopy 6 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like 7 | 8 | 9 | class PLMSSampler(object): 10 | def __init__(self, diffusion, model, schedule="linear", alpha_generator_func=None, set_alpha_scale=None): 11 | super().__init__() 12 | self.diffusion = diffusion 13 | self.model = model 14 | self.device = diffusion.betas.device 15 | self.ddpm_num_timesteps = diffusion.num_timesteps 16 | self.schedule = schedule 17 | self.alpha_generator_func = alpha_generator_func 18 | self.set_alpha_scale = set_alpha_scale 19 | 20 | def register_buffer(self, name, attr): 21 | if type(attr) == torch.Tensor: 22 | attr = attr.to(self.device) 23 | setattr(self, name, attr) 24 | 25 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=False): 26 | if ddim_eta != 0: 27 | raise ValueError('ddim_eta must be 0 for PLMS') 28 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 29 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 30 | alphas_cumprod = self.diffusion.alphas_cumprod 31 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 32 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device) 33 | 34 | self.register_buffer('betas', to_torch(self.diffusion.betas)) 35 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 36 | self.register_buffer('alphas_cumprod_prev', to_torch(self.diffusion.alphas_cumprod_prev)) 37 | 38 | # calculations for diffusion q(x_t | x_{t-1}) and others 39 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 40 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 41 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. 
- alphas_cumprod.cpu()))) 42 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 43 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 44 | 45 | # ddim sampling parameters 46 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 47 | ddim_timesteps=self.ddim_timesteps, 48 | eta=ddim_eta,verbose=verbose) 49 | self.register_buffer('ddim_sigmas', ddim_sigmas) 50 | self.register_buffer('ddim_alphas', ddim_alphas) 51 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 52 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) 53 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 54 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 55 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 56 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 57 | 58 | 59 | @torch.no_grad() 60 | def sample(self, S, shape, input, uc_dict=None, guidance_scale=1, mask=None, x0=None): 61 | self.make_schedule(ddim_num_steps=S) 62 | return self.plms_sampling(shape, input, uc_dict, guidance_scale, mask=mask, x0=x0) 63 | 64 | 65 | @torch.no_grad() 66 | def plms_sampling(self, shape, input, uc_dict=None, guidance_scale=1, mask=None, x0=None): 67 | 68 | b = shape[0] 69 | 70 | img = input["x"] 71 | if img == None: 72 | img = torch.randn(shape, device=self.device) 73 | input["x"] = img 74 | 75 | time_range = np.flip(self.ddim_timesteps) 76 | total_steps = self.ddim_timesteps.shape[0] 77 | 78 | old_eps = [] 79 | 80 | if self.alpha_generator_func != None: 81 | alphas = self.alpha_generator_func(len(time_range)) 82 | 83 | for i, step in enumerate(time_range): 84 | print(f"PLMS sampling step : ({step}/{len(time_range)})") 85 | # set alpha 86 | if self.alpha_generator_func != None: 87 | self.set_alpha_scale(self.model, alphas[i]) 88 | 89 | # run 90 | index = total_steps - i - 1 91 | ts = torch.full((b,), step, device=self.device, dtype=torch.long) 92 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=self.device, dtype=torch.long) 93 | 94 | if mask is not None: 95 | assert x0 is not None 96 | img_orig = self.diffusion.q_sample(x0, ts) 97 | img = img_orig * mask + (1. 
- mask) * img 98 | input["x"] = img 99 | 100 | img, pred_x0, e_t = self.p_sample_plms(input, ts, index=index, uc_dict=uc_dict, guidance_scale=guidance_scale, old_eps=old_eps, t_next=ts_next) 101 | input["x"] = img 102 | old_eps.append(e_t) 103 | if len(old_eps) >= 4: 104 | old_eps.pop(0) 105 | 106 | return img 107 | 108 | 109 | @torch.no_grad() 110 | def p_sample_plms(self, input, t, index, guidance_scale=1., uc_dict=None, old_eps=None, t_next=None): 111 | x = deepcopy(input["x"]) 112 | b = x.shape[0] 113 | 114 | def get_model_output(input): 115 | e_t = self.model(input) 116 | if uc_dict is not None and guidance_scale != 1: 117 | unconditional_input = dict( 118 | x=input["x"], 119 | timesteps=input["timesteps"], 120 | context=uc_dict['context'], 121 | inpainting_extra_input=input["inpainting_extra_input"], 122 | condition=uc_dict['condition'], 123 | ) 124 | e_t_uncond = self.model(unconditional_input) 125 | e_t = e_t_uncond + guidance_scale * (e_t - e_t_uncond) 126 | return e_t 127 | 128 | def get_x_prev_and_pred_x0(e_t, index): 129 | # select parameters corresponding to the currently considered timestep 130 | a_t = torch.full((b, 1, 1, 1), self.ddim_alphas[index], device=self.device) 131 | a_prev = torch.full((b, 1, 1, 1), self.ddim_alphas_prev[index], device=self.device) 132 | sigma_t = torch.full((b, 1, 1, 1), self.ddim_sigmas[index], device=self.device) 133 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), self.ddim_sqrt_one_minus_alphas[index],device=self.device) 134 | 135 | # current prediction for x_0 136 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 137 | 138 | # direction pointing to x_t 139 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t 140 | noise = sigma_t * torch.randn_like(x) 141 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 142 | return x_prev, pred_x0 143 | 144 | input["timesteps"] = t 145 | e_t = get_model_output(input) 146 | if len(old_eps) == 0: 147 | # Pseudo Improved Euler (2nd order) 148 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) 149 | input["x"] = x_prev 150 | input["timesteps"] = t_next 151 | e_t_next = get_model_output(input) 152 | e_t_prime = (e_t + e_t_next) / 2 153 | elif len(old_eps) == 1: 154 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth) 155 | e_t_prime = (3 * e_t - old_eps[-1]) / 2 156 | elif len(old_eps) == 2: 157 | # 3nd order Pseudo Linear Multistep (Adams-Bashforth) 158 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 159 | elif len(old_eps) >= 3: 160 | # 4nd order Pseudo Linear Multistep (Adams-Bashforth) 161 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 162 | 163 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) 164 | 165 | return x_prev, pred_x0, e_t 166 | -------------------------------------------------------------------------------- /ldm/models/diffusion/plmsg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from tqdm import tqdm 4 | from functools import partial 5 | from copy import deepcopy 6 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like 7 | 8 | 9 | class PLMSGSampler(object): 10 | def __init__(self, diffusion, model, schedule="linear", alpha_generator_func=None, set_alpha_scale=None): 11 | super().__init__() 12 | self.diffusion = diffusion 13 | self.model = model 14 | self.device = diffusion.betas.device 15 | self.ddpm_num_timesteps = diffusion.num_timesteps 16 | self.schedule = schedule 17 | 
self.alpha_generator_func = alpha_generator_func 18 | self.set_alpha_scale = set_alpha_scale 19 | 20 | def register_buffer(self, name, attr): 21 | if type(attr) == torch.Tensor: 22 | attr = attr.to(self.device) 23 | setattr(self, name, attr) 24 | 25 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=False): 26 | if ddim_eta != 0: 27 | raise ValueError('ddim_eta must be 0 for PLMS') 28 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 29 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 30 | alphas_cumprod = self.diffusion.alphas_cumprod 31 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 32 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device) 33 | 34 | self.register_buffer('betas', to_torch(self.diffusion.betas)) 35 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 36 | self.register_buffer('alphas_cumprod_prev', to_torch(self.diffusion.alphas_cumprod_prev)) 37 | 38 | # calculations for diffusion q(x_t | x_{t-1}) and others 39 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 40 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 41 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 42 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 43 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 44 | 45 | # ddim sampling parameters 46 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 47 | ddim_timesteps=self.ddim_timesteps, 48 | eta=ddim_eta,verbose=verbose) 49 | self.register_buffer('ddim_sigmas', ddim_sigmas) 50 | self.register_buffer('ddim_alphas', ddim_alphas) 51 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 52 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. 
- ddim_alphas)) 53 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 54 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 55 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 56 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 57 | 58 | 59 | @torch.no_grad() 60 | def sample(self, S, shape, input, uc_dict=None, guidance_scale=1, guidance_mode_scale=[1], mask=None, x0=None): 61 | self.make_schedule(ddim_num_steps=S) 62 | return self.plms_sampling(shape, input, uc_dict, guidance_scale, guidance_mode_scale, mask=mask, x0=x0) 63 | 64 | 65 | @torch.no_grad() 66 | def plms_sampling(self, shape, input, uc_dict=None, guidance_scale=1, guidance_mode_scale=[1], mask=None, x0=None): 67 | 68 | b = shape[0] 69 | 70 | img = input["x"] 71 | if img == None: 72 | img = torch.randn(shape, device=self.device) 73 | input["x"] = img 74 | 75 | time_range = np.flip(self.ddim_timesteps) 76 | total_steps = self.ddim_timesteps.shape[0] 77 | 78 | old_eps = [] 79 | 80 | if self.alpha_generator_func != None: 81 | alphas = self.alpha_generator_func(len(time_range)) 82 | 83 | for i, step in enumerate(time_range): 84 | print(f"PLMS sampling step : ({step}/{len(time_range)})") 85 | # set alpha 86 | if self.alpha_generator_func != None: 87 | self.set_alpha_scale(self.model, alphas[i]) 88 | 89 | # run 90 | index = total_steps - i - 1 91 | ts = torch.full((b,), step, device=self.device, dtype=torch.long) 92 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=self.device, dtype=torch.long) 93 | 94 | if mask is not None: 95 | assert x0 is not None 96 | img_orig = self.diffusion.q_sample(x0, ts) 97 | img = img_orig * mask + (1. - mask) * img 98 | input["x"] = img 99 | 100 | img, pred_x0, e_t = self.p_sample_plms(input, ts, index=index, uc_dict=uc_dict, guidance_scale=guidance_scale, guidance_mode_scale=guidance_mode_scale, old_eps=old_eps, t_next=ts_next) 101 | input["x"] = img 102 | old_eps.append(e_t) 103 | if len(old_eps) >= 4: 104 | old_eps.pop(0) 105 | 106 | return img 107 | 108 | 109 | @torch.no_grad() 110 | def p_sample_plms(self, input, t, index, guidance_scale=1., guidance_mode_scale=[1], uc_dict=None, old_eps=None, t_next=None): 111 | x = deepcopy(input["x"]) 112 | b = x.shape[0] 113 | 114 | def get_model_output(input): 115 | e_t = self.model(input) 116 | if uc_dict is not None and guidance_scale != 1: 117 | unconditional_input = dict( 118 | x=input["x"], 119 | timesteps=input["timesteps"], 120 | context=uc_dict['context'], 121 | inpainting_extra_input=input["inpainting_extra_input"], 122 | condition=uc_dict['condition'], 123 | ) 124 | unconditional_mode_input = dict( 125 | x=input["x"], 126 | timesteps=input["timesteps"], 127 | context=input['context'], 128 | inpainting_extra_input=input["inpainting_extra_input"], 129 | condition=uc_dict['mode_condition'], 130 | ) 131 | 132 | e_t_uncond = self.model(unconditional_input) 133 | e_t_uncond_mode = self.model(unconditional_mode_input) 134 | mode_guidance = [] 135 | for i, scale in enumerate(guidance_mode_scale): 136 | mode_guidance.append(scale * (e_t[i] - e_t_uncond_mode[i])) 137 | mode_guidance = torch.stack(mode_guidance) 138 | e_t = e_t_uncond + guidance_scale * (e_t - e_t_uncond) + mode_guidance 139 | return e_t 140 | 141 | def get_x_prev_and_pred_x0(e_t, index): 142 | # select parameters corresponding to the currently considered timestep 143 | a_t = torch.full((b, 1, 1, 1), self.ddim_alphas[index], device=self.device) 144 | a_prev = torch.full((b, 1, 1, 1), 
self.ddim_alphas_prev[index], device=self.device) 145 | sigma_t = torch.full((b, 1, 1, 1), self.ddim_sigmas[index], device=self.device) 146 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), self.ddim_sqrt_one_minus_alphas[index],device=self.device) 147 | 148 | # current prediction for x_0 149 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 150 | 151 | # direction pointing to x_t 152 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t 153 | noise = sigma_t * torch.randn_like(x) 154 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 155 | return x_prev, pred_x0 156 | 157 | input["timesteps"] = t 158 | e_t = get_model_output(input) 159 | if len(old_eps) == 0: 160 | # Pseudo Improved Euler (2nd order) 161 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) 162 | input["x"] = x_prev 163 | input["timesteps"] = t_next 164 | e_t_next = get_model_output(input) 165 | e_t_prime = (e_t + e_t_next) / 2 166 | elif len(old_eps) == 1: 167 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth) 168 | e_t_prime = (3 * e_t - old_eps[-1]) / 2 169 | elif len(old_eps) == 2: 170 | # 3nd order Pseudo Linear Multistep (Adams-Bashforth) 171 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 172 | elif len(old_eps) >= 3: 173 | # 4nd order Pseudo Linear Multistep (Adams-Bashforth) 174 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 175 | 176 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) 177 | 178 | return x_prev, pred_x0, e_t 179 | -------------------------------------------------------------------------------- /ldm/modules/attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | 8 | # from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder 9 | from torch.utils import checkpoint 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def uniq(arr): 16 | return{el: True for el in arr}.keys() 17 | 18 | 19 | def default(val, d): 20 | if exists(val): 21 | return val 22 | return d() if isfunction(d) else d 23 | 24 | 25 | def max_neg_value(t): 26 | return -torch.finfo(t.dtype).max 27 | 28 | 29 | def init_(tensor): 30 | dim = tensor.shape[-1] 31 | std = 1 / math.sqrt(dim) 32 | tensor.uniform_(-std, std) 33 | return tensor 34 | 35 | 36 | # feedforward 37 | class GEGLU(nn.Module): 38 | def __init__(self, dim_in, dim_out): 39 | super().__init__() 40 | self.proj = nn.Linear(dim_in, dim_out * 2) 41 | 42 | def forward(self, x): 43 | x, gate = self.proj(x).chunk(2, dim=-1) 44 | return x * F.gelu(gate) 45 | 46 | 47 | class FeedForward(nn.Module): 48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 49 | super().__init__() 50 | inner_dim = int(dim * mult) 51 | dim_out = default(dim_out, dim) 52 | project_in = nn.Sequential( 53 | nn.Linear(dim, inner_dim), 54 | nn.GELU() 55 | ) if not glu else GEGLU(dim, inner_dim) 56 | 57 | self.net = nn.Sequential( 58 | project_in, 59 | nn.Dropout(dropout), 60 | nn.Linear(inner_dim, dim_out) 61 | ) 62 | 63 | def forward(self, x): 64 | return self.net(x) 65 | 66 | 67 | def zero_module(module): 68 | """ 69 | Zero out the parameters of a module and return it. 
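    Example (editor's addition): zero_module(nn.Conv2d(4, 4, 1)) returns the conv with all
    weights and biases set to zero, so layers wrapped this way (e.g. SpatialTransformer.proj_out)
    initially contribute nothing to their residual connection.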
70 | """ 71 | for p in module.parameters(): 72 | p.detach().zero_() 73 | return module 74 | 75 | 76 | def Normalize(in_channels): 77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 78 | 79 | 80 | class LinearAttention(nn.Module): 81 | def __init__(self, dim, heads=4, dim_head=32): 82 | super().__init__() 83 | self.heads = heads 84 | hidden_dim = dim_head * heads 85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 87 | 88 | def forward(self, x): 89 | b, c, h, w = x.shape 90 | qkv = self.to_qkv(x) 91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 92 | k = k.softmax(dim=-1) 93 | context = torch.einsum('bhdn,bhen->bhde', k, v) 94 | out = torch.einsum('bhde,bhdn->bhen', context, q) 95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 96 | return self.to_out(out) 97 | 98 | 99 | 100 | 101 | 102 | class CrossAttention(nn.Module): 103 | def __init__(self, query_dim, key_dim, value_dim, heads=8, dim_head=64, dropout=0): 104 | super().__init__() 105 | inner_dim = dim_head * heads 106 | self.scale = dim_head ** -0.5 107 | self.heads = heads 108 | 109 | 110 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 111 | self.to_k = nn.Linear(key_dim, inner_dim, bias=False) 112 | self.to_v = nn.Linear(value_dim, inner_dim, bias=False) 113 | 114 | 115 | self.to_out = nn.Sequential( nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) ) 116 | 117 | 118 | def fill_inf_from_mask(self, sim, mask): 119 | if mask is not None: 120 | B,M = mask.shape 121 | mask = mask.unsqueeze(1).repeat(1,self.heads,1).reshape(B*self.heads,1,-1) 122 | max_neg_value = -torch.finfo(sim.dtype).max 123 | sim.masked_fill_(~mask, max_neg_value) 124 | return sim 125 | 126 | 127 | def forward(self, x, key, value, mask=None): 128 | 129 | q = self.to_q(x) # B*N*(H*C) 130 | k = self.to_k(key) # B*M*(H*C) 131 | v = self.to_v(value) # B*M*(H*C) 132 | 133 | B, N, HC = q.shape 134 | _, M, _ = key.shape 135 | H = self.heads 136 | C = HC // H 137 | 138 | q = q.view(B,N,H,C).permute(0,2,1,3).reshape(B*H,N,C) # (B*H)*N*C 139 | k = k.view(B,M,H,C).permute(0,2,1,3).reshape(B*H,M,C) # (B*H)*M*C 140 | v = v.view(B,M,H,C).permute(0,2,1,3).reshape(B*H,M,C) # (B*H)*M*C 141 | 142 | sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale # (B*H)*N*M 143 | self.fill_inf_from_mask(sim, mask) 144 | attn = sim.softmax(dim=-1) # (B*H)*N*M 145 | 146 | out = torch.einsum('b i j, b j d -> b i d', attn, v) # (B*H)*N*C 147 | out = out.view(B,H,N,C).permute(0,2,1,3).reshape(B,N,(H*C)) # B*N*(H*C) 148 | 149 | return self.to_out(out) 150 | 151 | 152 | 153 | 154 | class SelfAttention(nn.Module): 155 | def __init__(self, query_dim, heads=8, dim_head=64, dropout=0.): 156 | super().__init__() 157 | inner_dim = dim_head * heads 158 | self.scale = dim_head ** -0.5 159 | self.heads = heads 160 | 161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 162 | self.to_k = nn.Linear(query_dim, inner_dim, bias=False) 163 | self.to_v = nn.Linear(query_dim, inner_dim, bias=False) 164 | 165 | self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) ) 166 | 167 | def forward(self, x): 168 | q = self.to_q(x) # B*N*(H*C) 169 | k = self.to_k(x) # B*N*(H*C) 170 | v = self.to_v(x) # B*N*(H*C) 171 | 172 | B, N, HC = q.shape 173 | H = self.heads 174 | C = HC // H 175 | 176 | q = q.view(B,N,H,C).permute(0,2,1,3).reshape(B*H,N,C) # (B*H)*N*C 177 | k = 
k.view(B,N,H,C).permute(0,2,1,3).reshape(B*H,N,C) # (B*H)*N*C 178 | v = v.view(B,N,H,C).permute(0,2,1,3).reshape(B*H,N,C) # (B*H)*N*C 179 | 180 | sim = torch.einsum('b i c, b j c -> b i j', q, k) * self.scale # (B*H)*N*N 181 | attn = sim.softmax(dim=-1) # (B*H)*N*N 182 | 183 | out = torch.einsum('b i j, b j c -> b i c', attn, v) # (B*H)*N*C 184 | out = out.view(B,H,N,C).permute(0,2,1,3).reshape(B,N,(H*C)) # B*N*(H*C) 185 | 186 | return self.to_out(out) 187 | 188 | 189 | 190 | class GatedCrossAttentionDense(nn.Module): 191 | def __init__(self, query_dim, key_dim, value_dim, n_heads, d_head): 192 | super().__init__() 193 | 194 | self.attn = CrossAttention(query_dim=query_dim, key_dim=key_dim, value_dim=value_dim, heads=n_heads, dim_head=d_head) 195 | self.ff = FeedForward(query_dim, glu=True) 196 | 197 | self.norm1 = nn.LayerNorm(query_dim) 198 | self.norm2 = nn.LayerNorm(query_dim) 199 | 200 | self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)) ) 201 | self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)) ) 202 | 203 | # this can be useful: we can externally change magnitude of tanh(alpha) 204 | # for example, when it is set to 0, then the entire model is same as original one 205 | self.scale = 1 206 | 207 | def forward(self, x, objs): 208 | 209 | x = x + self.scale*torch.tanh(self.alpha_attn) * self.attn( self.norm1(x), objs, objs) 210 | x = x + self.scale*torch.tanh(self.alpha_dense) * self.ff( self.norm2(x) ) 211 | 212 | return x 213 | 214 | 215 | class GatedSelfAttentionDense(nn.Module): 216 | def __init__(self, query_dim, context_dim, n_heads, d_head): 217 | super().__init__() 218 | 219 | # we need a linear projection since we need cat visual feature and obj feature 220 | self.linear = nn.Linear(context_dim, query_dim) 221 | 222 | self.attn = SelfAttention(query_dim=query_dim, heads=n_heads, dim_head=d_head) 223 | self.ff = FeedForward(query_dim, glu=True) 224 | 225 | self.norm1 = nn.LayerNorm(query_dim) 226 | self.norm2 = nn.LayerNorm(query_dim) 227 | 228 | self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)) ) 229 | self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)) ) 230 | 231 | # this can be useful: we can externally change magnitude of tanh(alpha) 232 | # for example, when it is set to 0, then the entire model is same as original one 233 | self.scale = 1 234 | 235 | 236 | def forward(self, x, objs): 237 | 238 | N_visual = x.shape[1] 239 | objs = self.linear(objs) 240 | 241 | x = x + self.scale*torch.tanh(self.alpha_attn) * self.attn( self.norm1(torch.cat([x,objs],dim=1)) )[:,0:N_visual,:] 242 | x = x + self.scale*torch.tanh(self.alpha_dense) * self.ff( self.norm2(x) ) 243 | 244 | return x 245 | 246 | 247 | class BasicTransformerBlock(nn.Module): 248 | def __init__(self, query_dim, key_dim, value_dim, n_heads, d_head, fuser_type, use_checkpoint=True): 249 | super().__init__() 250 | self.attn1 = SelfAttention(query_dim=query_dim, heads=n_heads, dim_head=d_head) 251 | self.ff = FeedForward(query_dim, glu=True) 252 | self.attn2 = CrossAttention(query_dim=query_dim, key_dim=key_dim, value_dim=value_dim, heads=n_heads, dim_head=d_head) 253 | self.norm1 = nn.LayerNorm(query_dim) 254 | self.norm2 = nn.LayerNorm(query_dim) 255 | self.norm3 = nn.LayerNorm(query_dim) 256 | self.use_checkpoint = use_checkpoint 257 | 258 | if fuser_type == "gatedSA": 259 | # note key_dim here actually is context_dim 260 | self.fuser = GatedSelfAttentionDense(query_dim, key_dim, n_heads, d_head) 261 | elif fuser_type == "gatedCA": 262 | 
self.fuser = GatedCrossAttentionDense(query_dim, key_dim, value_dim, n_heads, d_head) 263 | else: 264 | assert False 265 | 266 | 267 | def forward(self, x, context, objs): 268 | # return checkpoint(self._forward, (x, context, objs), self.parameters(), self.use_checkpoint) 269 | if self.use_checkpoint and x.requires_grad: 270 | return checkpoint.checkpoint(self._forward, x, context, objs) 271 | else: 272 | return self._forward(x, context, objs) 273 | 274 | def _forward(self, x, context, objs): 275 | x = self.attn1( self.norm1(x) ) + x 276 | x = self.fuser(x, objs) # identity mapping in the beginning 277 | x = self.attn2(self.norm2(x), context, context) + x 278 | x = self.ff(self.norm3(x)) + x 279 | return x 280 | 281 | 282 | class SpatialTransformer(nn.Module): 283 | def __init__(self, in_channels, key_dim, value_dim, n_heads, d_head, depth=1, fuser_type=None, use_checkpoint=True): 284 | super().__init__() 285 | self.in_channels = in_channels 286 | query_dim = n_heads * d_head 287 | self.norm = Normalize(in_channels) 288 | 289 | 290 | self.proj_in = nn.Conv2d(in_channels, 291 | query_dim, 292 | kernel_size=1, 293 | stride=1, 294 | padding=0) 295 | 296 | self.transformer_blocks = nn.ModuleList( 297 | [BasicTransformerBlock(query_dim, key_dim, value_dim, n_heads, d_head, fuser_type, use_checkpoint=use_checkpoint) 298 | for d in range(depth)] 299 | ) 300 | 301 | self.proj_out = zero_module(nn.Conv2d(query_dim, 302 | in_channels, 303 | kernel_size=1, 304 | stride=1, 305 | padding=0)) 306 | 307 | def forward(self, x, context, objs): 308 | b, c, h, w = x.shape 309 | x_in = x 310 | x = self.norm(x) 311 | x = self.proj_in(x) 312 | x = rearrange(x, 'b c h w -> b (h w) c') 313 | for block in self.transformer_blocks: 314 | x = block(x, context, objs) 315 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 316 | x = self.proj_out(x) 317 | return x + x_in -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sungnyun/diffblender/b3481d17d0e13d89d45dcbb9a250e89459904031/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/condition_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from ldm.modules.diffusionmodules.util import FourierEmbedder 5 | 6 | from ldm.modules.diffusionmodules.util import ( 7 | conv_nd, 8 | zero_module, 9 | normalization, 10 | ) 11 | 12 | 13 | def conv_nd(dims, *args, **kwargs): 14 | """ 15 | Create a 1D, 2D, or 3D convolution module. 16 | """ 17 | if dims == 1: 18 | return nn.Conv1d(*args, **kwargs) 19 | elif dims == 2: 20 | return nn.Conv2d(*args, **kwargs) 21 | elif dims == 3: 22 | return nn.Conv3d(*args, **kwargs) 23 | raise ValueError(f"unsupported dimensions: {dims}") 24 | 25 | 26 | def avg_pool_nd(dims, *args, **kwargs): 27 | """ 28 | Create a 1D, 2D, or 3D average pooling module. 29 | """ 30 | if dims == 1: 31 | return nn.AvgPool1d(*args, **kwargs) 32 | elif dims == 2: 33 | return nn.AvgPool2d(*args, **kwargs) 34 | elif dims == 3: 35 | return nn.AvgPool3d(*args, **kwargs) 36 | raise ValueError(f"unsupported dimensions: {dims}") 37 | 38 | 39 | class Downsample(nn.Module): 40 | """ 41 | A downsampling layer with an optional convolution. 
42 | :param channels: channels in the inputs and outputs. 43 | :param use_conv: a bool determining if a convolution is applied. 44 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 45 | downsampling occurs in the inner-two dimensions. 46 | """ 47 | 48 | def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): 49 | super().__init__() 50 | self.channels = channels 51 | self.out_channels = out_channels or channels 52 | self.use_conv = use_conv 53 | self.dims = dims 54 | stride = 2 if dims != 3 else (1, 2, 2) 55 | if use_conv: 56 | self.op = conv_nd( 57 | dims, self.channels, self.out_channels, 3, stride=stride, padding=padding 58 | ) 59 | else: 60 | assert self.channels == self.out_channels 61 | self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) 62 | 63 | def forward(self, x): 64 | assert x.shape[1] == self.channels 65 | return self.op(x) 66 | 67 | 68 | class ResnetBlock(nn.Module): 69 | def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True): 70 | super().__init__() 71 | ps = ksize // 2 72 | if in_c != out_c or sk == False: 73 | self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps) 74 | else: 75 | self.in_conv = None 76 | 77 | self.norm = normalization(out_c) 78 | 79 | self.body = nn.Sequential( 80 | conv_nd(2, out_c, out_c, 3, padding=1), 81 | nn.SiLU(), 82 | conv_nd(2, out_c, out_c, ksize, padding=0), 83 | ) 84 | if sk == False: 85 | self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps) 86 | else: 87 | self.skep = None 88 | 89 | self.down = down 90 | if self.down == True: 91 | self.down_opt = Downsample(in_c, use_conv=use_conv) 92 | 93 | def forward(self, x): 94 | if self.down == True: 95 | x = self.down_opt(x) 96 | if self.in_conv is not None: # edit 97 | x = self.in_conv(x) 98 | 99 | x = self.norm(x) 100 | h = self.body(x) 101 | if self.skep is not None: 102 | return h + self.skep(x) 103 | else: 104 | return h + x 105 | 106 | 107 | class ImageConditionNet(nn.Module): 108 | def __init__(self, autoencoder=None, channels=[320, 640, 1280, 1280], nums_rb=3, cin=4, ksize=3, sk=False, use_conv=True): 109 | super(ImageConditionNet, self).__init__() 110 | self.autoencoder = autoencoder 111 | self.set = False 112 | 113 | self.channels = channels 114 | self.nums_rb = nums_rb 115 | 116 | self.in_layers = nn.Sequential( 117 | nn.InstanceNorm2d(cin, affine=True), 118 | conv_nd(2, cin, channels[0], 3, padding=1), 119 | ) 120 | 121 | self.body = [] 122 | self.out_convs = [] 123 | for i in range(len(channels)): 124 | for j in range(nums_rb): 125 | if (i != 0) and (j == 0): 126 | self.body.append( 127 | ResnetBlock(channels[i - 1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv)) 128 | else: 129 | self.body.append( 130 | ResnetBlock(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv)) 131 | self.out_convs.append( 132 | zero_module( 133 | conv_nd(2, channels[i], channels[i], 1, padding=0, bias=False) 134 | ) 135 | ) 136 | 137 | self.body = nn.ModuleList(self.body) 138 | self.out_convs = nn.ModuleList(self.out_convs) 139 | 140 | def forward(self, input_dict, h=None): 141 | assert self.set 142 | 143 | x, masks = input_dict['values'], input_dict['masks'] 144 | 145 | with torch.no_grad(): 146 | x = self.autoencoder.encode(x) 147 | 148 | # extract features 149 | features = [] 150 | 151 | x = self.in_layers(x) 152 | 153 | for i in range(len(self.channels)): 154 | for j in range(self.nums_rb): 155 | idx = i * self.nums_rb + j 156 | x = self.body[idx](x) 157 | x = self.out_convs[i](x) 158 | features.append(x) 159 
| 160 | 161 | return features 162 | 163 | 164 | class AutoencoderKLWrapper(nn.Module): 165 | def __init__(self, autoencoder=None): 166 | super().__init__() 167 | self.autoencoder = autoencoder 168 | self.set = False 169 | 170 | def forward(self, input_dict): 171 | assert self.set 172 | x = input_dict['values'] 173 | with torch.no_grad(): 174 | x = self.autoencoder.encode(x) 175 | return x 176 | 177 | 178 | class NSPVectorConditionNet(nn.Module): 179 | def __init__(self, in_dim, out_dim, norm=False, fourier_freqs=0, temperature=100, scale=1.0): 180 | super().__init__() 181 | self.in_dim = in_dim 182 | self.out_dim = out_dim 183 | self.norm = norm 184 | self.fourier_freqs = fourier_freqs 185 | 186 | if fourier_freqs > 0: 187 | self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs, temperature=temperature) 188 | self.in_dim *= (fourier_freqs * 2) 189 | if self.norm: 190 | self.linears = nn.Sequential( 191 | nn.Linear(self.in_dim, 512), 192 | nn.LayerNorm(512), 193 | nn.SiLU(), 194 | nn.Linear(512, 512), 195 | nn.LayerNorm(512), 196 | nn.SiLU(), 197 | nn.Linear(512, out_dim), 198 | ) 199 | else: 200 | self.linears = nn.Sequential( 201 | nn.Linear(self.in_dim, 512), 202 | nn.SiLU(), 203 | nn.Linear(512, 512), 204 | nn.SiLU(), 205 | nn.Linear(512, out_dim), 206 | ) 207 | 208 | self.null_features = torch.nn.Parameter(torch.zeros([self.in_dim])) 209 | self.scale = scale 210 | 211 | def forward(self, input_dict): 212 | vectors, masks = input_dict['values'], input_dict['masks'] # vectors: B*C, masks: B*1 213 | if self.fourier_freqs > 0: 214 | vectors = self.fourier_embedder(vectors * self.scale) 215 | objs = masks * vectors + (1-masks) * self.null_features.view(1,-1) 216 | return self.linears(objs).unsqueeze(1) # B*1*C 217 | 218 | 219 | class BoxConditionNet(nn.Module): 220 | def __init__(self, in_dim, out_dim, fourier_freqs=8): 221 | super().__init__() 222 | self.in_dim = in_dim 223 | self.out_dim = out_dim 224 | 225 | self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs) 226 | self.position_dim = fourier_freqs*2*4 # 2 is sin&cos, 4 is xyxy 227 | 228 | self.linears = nn.Sequential( 229 | nn.Linear(self.in_dim + self.position_dim, 512), 230 | nn.SiLU(), 231 | nn.Linear(512, 512), 232 | nn.SiLU(), 233 | nn.Linear(512, out_dim), 234 | ) 235 | 236 | self.null_text_feature = torch.nn.Parameter(torch.zeros([self.in_dim])) 237 | self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim])) 238 | 239 | 240 | def forward(self, input_dict): 241 | boxes, masks, text_embeddings = input_dict['values'], input_dict['masks'], input_dict['text_embeddings'] 242 | 243 | B, N, _ = boxes.shape 244 | masks = masks.unsqueeze(-1) 245 | 246 | # embedding position (it may includes padding as placeholder) 247 | xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C 248 | 249 | # learnable null embedding 250 | text_null = self.null_text_feature.view(1,1,-1) 251 | xyxy_null = self.null_position_feature.view(1,1,-1) 252 | 253 | # replace padding with learnable null embedding 254 | text_embeddings = text_embeddings*masks + (1-masks)*text_null 255 | xyxy_embedding = xyxy_embedding*masks + (1-masks)*xyxy_null 256 | 257 | objs = self.linears( torch.cat([text_embeddings, xyxy_embedding], dim=-1) ) 258 | assert objs.shape == torch.Size([B,N,self.out_dim]) 259 | return objs 260 | 261 | 262 | class KeypointConditionNet(nn.Module): 263 | def __init__(self, max_persons_per_image, out_dim, fourier_freqs=8): 264 | super().__init__() 265 | self.max_persons_per_image = max_persons_per_image 
266 | self.out_dim = out_dim 267 | 268 | self.person_embeddings = torch.nn.Parameter(torch.zeros([max_persons_per_image,out_dim])) 269 | self.keypoint_embeddings = torch.nn.Parameter(torch.zeros([17,out_dim])) 270 | 271 | 272 | self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs) 273 | self.position_dim = fourier_freqs*2*2 # 2 is sin&cos, 2 is xy 274 | 275 | self.linears = nn.Sequential( 276 | nn.Linear(self.out_dim + self.position_dim, 512), 277 | nn.SiLU(), 278 | nn.Linear(512, 512), 279 | nn.SiLU(), 280 | nn.Linear(512, out_dim), 281 | ) 282 | 283 | self.null_person_feature = torch.nn.Parameter(torch.zeros([self.out_dim])) 284 | self.null_xy_feature = torch.nn.Parameter(torch.zeros([self.position_dim])) 285 | 286 | 287 | def forward(self, input_dict): 288 | points, masks = input_dict['values'], input_dict['masks'] 289 | 290 | masks = masks.unsqueeze(-1) 291 | N = points.shape[0] 292 | 293 | person_embeddings = self.person_embeddings.unsqueeze(1).repeat(1,17,1).reshape(self.max_persons_per_image*17, self.out_dim) 294 | keypoint_embeddings = torch.cat([self.keypoint_embeddings]*self.max_persons_per_image, dim=0) 295 | person_embeddings = person_embeddings + keypoint_embeddings # (num_person*17) * C 296 | person_embeddings = person_embeddings.unsqueeze(0).repeat(N,1,1) 297 | 298 | # embedding position (it may includes padding as placeholder) 299 | xy_embedding = self.fourier_embedder(points) # B*N*2 --> B*N*C 300 | 301 | 302 | # learnable null embedding 303 | person_null = self.null_person_feature.view(1,1,-1) 304 | xy_null = self.null_xy_feature.view(1,1,-1) 305 | 306 | # replace padding with learnable null embedding 307 | person_embeddings = person_embeddings*masks + (1-masks)*person_null 308 | xy_embedding = xy_embedding*masks + (1-masks)*xy_null 309 | 310 | objs = self.linears( torch.cat([person_embeddings, xy_embedding], dim=-1) ) 311 | 312 | return objs 313 | 314 | -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/multimodal_openaimodel.py: -------------------------------------------------------------------------------- 1 | """ 2 | Ack: Our code is highly relied on GLIGEN (https://github.com/gligen/GLIGEN). 3 | """ 4 | from abc import abstractmethod 5 | from functools import partial 6 | import math 7 | 8 | import numpy as np 9 | import random 10 | import torch as th 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from ldm.modules.diffusionmodules.util import ( 15 | conv_nd, 16 | linear, 17 | avg_pool_nd, 18 | zero_module, 19 | normalization, 20 | timestep_embedding, 21 | ) 22 | from ldm.modules.multimodal_attention import SpatialTransformer 23 | # from .positionnet import PositionNet 24 | from torch.utils import checkpoint 25 | from ldm.util import instantiate_from_config 26 | 27 | 28 | class TimestepBlock(nn.Module): 29 | """ 30 | Any module where forward() takes timestep embeddings as a second argument. 31 | """ 32 | 33 | @abstractmethod 34 | def forward(self, x, emb): 35 | """ 36 | Apply the module to `x` given `emb` timestep embeddings. 37 | """ 38 | 39 | 40 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock): 41 | """ 42 | A sequential module that passes timestep embeddings to the children that 43 | support it as an extra input. 
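    (editor's note) In this multimodal variant, SpatialTransformer children additionally
    receive the text context together with the spatial (`sp_objs`) and non-spatial
    (`nsp_objs`) condition tokens, as dispatched in forward() below.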
44 | """ 45 | 46 | def forward(self, x, emb, context, sp_objs, nsp_objs): 47 | for layer in self: 48 | if isinstance(layer, TimestepBlock): 49 | x = layer(x, emb) 50 | elif isinstance(layer, SpatialTransformer): 51 | x = layer(x, context, sp_objs, nsp_objs) 52 | else: 53 | x = layer(x) 54 | return x 55 | 56 | 57 | class Upsample(nn.Module): 58 | """ 59 | An upsampling layer with an optional convolution. 60 | :param channels: channels in the inputs and outputs. 61 | :param use_conv: a bool determining if a convolution is applied. 62 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 63 | upsampling occurs in the inner-two dimensions. 64 | """ 65 | 66 | def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): 67 | super().__init__() 68 | self.channels = channels 69 | self.out_channels = out_channels or channels 70 | self.use_conv = use_conv 71 | self.dims = dims 72 | if use_conv: 73 | self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) 74 | 75 | def forward(self, x): 76 | assert x.shape[1] == self.channels 77 | if self.dims == 3: 78 | x = F.interpolate( 79 | x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" 80 | ) 81 | else: 82 | x = F.interpolate(x, scale_factor=2, mode="nearest") 83 | if self.use_conv: 84 | x = self.conv(x) 85 | return x 86 | 87 | 88 | 89 | 90 | class Downsample(nn.Module): 91 | """ 92 | A downsampling layer with an optional convolution. 93 | :param channels: channels in the inputs and outputs. 94 | :param use_conv: a bool determining if a convolution is applied. 95 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 96 | downsampling occurs in the inner-two dimensions. 97 | """ 98 | 99 | def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): 100 | super().__init__() 101 | self.channels = channels 102 | self.out_channels = out_channels or channels 103 | self.use_conv = use_conv 104 | self.dims = dims 105 | stride = 2 if dims != 3 else (1, 2, 2) 106 | if use_conv: 107 | self.op = conv_nd( 108 | dims, self.channels, self.out_channels, 3, stride=stride, padding=padding 109 | ) 110 | else: 111 | assert self.channels == self.out_channels 112 | self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) 113 | 114 | def forward(self, x): 115 | assert x.shape[1] == self.channels 116 | return self.op(x) 117 | 118 | 119 | class ResBlock(TimestepBlock): 120 | """ 121 | A residual block that can optionally change the number of channels. 122 | :param channels: the number of input channels. 123 | :param emb_channels: the number of timestep embedding channels. 124 | :param dropout: the rate of dropout. 125 | :param out_channels: if specified, the number of out channels. 126 | :param use_conv: if True and out_channels is specified, use a spatial 127 | convolution instead of a smaller 1x1 convolution to change the 128 | channels in the skip connection. 129 | :param dims: determines if the signal is 1D, 2D, or 3D. 130 | :param use_checkpoint: if True, use gradient checkpointing on this module. 131 | :param up: if True, use this block for upsampling. 132 | :param down: if True, use this block for downsampling. 
133 | """ 134 | 135 | def __init__( 136 | self, 137 | channels, 138 | emb_channels, 139 | dropout, 140 | out_channels=None, 141 | use_conv=False, 142 | use_scale_shift_norm=False, 143 | dims=2, 144 | use_checkpoint=False, 145 | up=False, 146 | down=False, 147 | ): 148 | super().__init__() 149 | self.channels = channels 150 | self.emb_channels = emb_channels 151 | self.dropout = dropout 152 | self.out_channels = out_channels or channels 153 | self.use_conv = use_conv 154 | self.use_checkpoint = use_checkpoint 155 | self.use_scale_shift_norm = use_scale_shift_norm 156 | 157 | self.in_layers = nn.Sequential( 158 | normalization(channels), 159 | nn.SiLU(), 160 | conv_nd(dims, channels, self.out_channels, 3, padding=1), 161 | ) 162 | 163 | self.updown = up or down 164 | 165 | if up: 166 | self.h_upd = Upsample(channels, False, dims) 167 | self.x_upd = Upsample(channels, False, dims) 168 | elif down: 169 | self.h_upd = Downsample(channels, False, dims) 170 | self.x_upd = Downsample(channels, False, dims) 171 | else: 172 | self.h_upd = self.x_upd = nn.Identity() 173 | 174 | self.emb_layers = nn.Sequential( 175 | nn.SiLU(), 176 | linear( 177 | emb_channels, 178 | 2 * self.out_channels if use_scale_shift_norm else self.out_channels, 179 | ), 180 | ) 181 | self.out_layers = nn.Sequential( 182 | normalization(self.out_channels), 183 | nn.SiLU(), 184 | nn.Dropout(p=dropout), 185 | zero_module( 186 | conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) 187 | ), 188 | ) 189 | 190 | if self.out_channels == channels: 191 | self.skip_connection = nn.Identity() 192 | elif use_conv: 193 | self.skip_connection = conv_nd( 194 | dims, channels, self.out_channels, 3, padding=1 195 | ) 196 | else: 197 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) 198 | 199 | def forward(self, x, emb): 200 | """ 201 | Apply the block to a Tensor, conditioned on a timestep embedding. 202 | :param x: an [N x C x ...] Tensor of features. 203 | :param emb: an [N x emb_channels] Tensor of timestep embeddings. 204 | :return: an [N x C x ...] Tensor of outputs. 
205 | """ 206 | # return checkpoint( 207 | # self._forward, (x, emb), self.parameters(), self.use_checkpoint 208 | # ) 209 | if self.use_checkpoint and x.requires_grad: 210 | return checkpoint.checkpoint(self._forward, x, emb ) 211 | else: 212 | return self._forward(x, emb) 213 | 214 | 215 | def _forward(self, x, emb): 216 | if self.updown: 217 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] 218 | h = in_rest(x) 219 | h = self.h_upd(h) 220 | x = self.x_upd(x) 221 | h = in_conv(h) 222 | else: 223 | h = self.in_layers(x) 224 | emb_out = self.emb_layers(emb).type(h.dtype) 225 | while len(emb_out.shape) < len(h.shape): 226 | emb_out = emb_out[..., None] 227 | if self.use_scale_shift_norm: 228 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:] 229 | scale, shift = th.chunk(emb_out, 2, dim=1) 230 | h = out_norm(h) * (1 + scale) + shift 231 | h = out_rest(h) 232 | else: 233 | h = h + emb_out 234 | h = self.out_layers(h) 235 | return self.skip_connection(x) + h 236 | 237 | 238 | 239 | 240 | class UNetModel(nn.Module): 241 | def __init__( 242 | self, 243 | image_size, 244 | in_channels, 245 | model_channels, 246 | out_channels, 247 | num_res_blocks, 248 | attention_resolutions, 249 | dropout=0, 250 | channel_mult=(1, 2, 4, 8), 251 | conv_resample=True, 252 | dims=2, 253 | use_checkpoint=False, 254 | num_heads=8, 255 | use_scale_shift_norm=False, 256 | transformer_depth=1, 257 | context_dim=None, 258 | fuser_type = None, 259 | inpaint_mode = False, 260 | grounding_tokenizer = None, 261 | init_alpha_pre_input_conv=0.1, 262 | use_autoencoder_kl=False, 263 | image_cond_injection_type=None, 264 | input_modalities=[], 265 | input_types=[], 266 | freeze_modules=[], 267 | ): 268 | super().__init__() 269 | 270 | self.image_size = image_size 271 | self.in_channels = in_channels 272 | self.model_channels = model_channels 273 | self.out_channels = out_channels 274 | self.num_res_blocks = num_res_blocks 275 | self.attention_resolutions = attention_resolutions 276 | self.dropout = dropout 277 | self.channel_mult = channel_mult 278 | self.conv_resample = conv_resample 279 | self.use_checkpoint = use_checkpoint 280 | self.num_heads = num_heads 281 | self.context_dim = context_dim 282 | self.fuser_type = fuser_type 283 | self.inpaint_mode = inpaint_mode 284 | self.freeze_modules = freeze_modules 285 | assert fuser_type in ["gatedSA", "gatedCA", "gatedSA-gatedCA", "gatedCA-gatedSA"] 286 | 287 | self.input_modalities = input_modalities 288 | self.input_types = input_types 289 | self.use_autoencoder_kl = use_autoencoder_kl 290 | self.image_cond_injection_type = image_cond_injection_type 291 | assert self.image_cond_injection_type is not None 292 | 293 | 294 | time_embed_dim = model_channels * 4 295 | self.time_embed = nn.Sequential( 296 | linear(model_channels, time_embed_dim), 297 | nn.SiLU(), 298 | linear(time_embed_dim, time_embed_dim), 299 | ) 300 | 301 | num_image_condition = self.input_types.count("image") 302 | num_sp_condition = self.input_types.count("sp_vector") 303 | use_sp = num_sp_condition > 0 304 | num_nsp_condition = self.input_types.count("nsp_vector") 305 | use_nsp = num_nsp_condition > 0 306 | 307 | if num_image_condition >= 1: 308 | pass 309 | 310 | if inpaint_mode: 311 | # The new added channels are: masked image (encoded image) and mask, which is 4+1 312 | self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, in_channels+in_channels+1, model_channels, 3, padding=1))]) 313 | else: 314 | """ Enlarged mode""" 315 | # self.input_blocks = 
nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, (num_image_condition+1)*in_channels, model_channels, 3, padding=1))]) 316 | """ Non-enlarged mode""" 317 | self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]) 318 | 319 | 320 | input_block_chans = [model_channels] 321 | ch = model_channels 322 | ds = 1 323 | 324 | # = = = = = = = = = = = = = = = = = = = = Down Branch = = = = = = = = = = = = = = = = = = = = # 325 | for level, mult in enumerate(channel_mult): 326 | for _ in range(num_res_blocks): 327 | layers = [ ResBlock(ch, 328 | time_embed_dim, 329 | dropout, 330 | out_channels=mult * model_channels, 331 | dims=dims, 332 | use_checkpoint=use_checkpoint, 333 | use_scale_shift_norm=use_scale_shift_norm,) ] 334 | 335 | ch = mult * model_channels 336 | if ds in attention_resolutions: 337 | dim_head = ch // num_heads 338 | layers.append(SpatialTransformer(ch, key_dim=context_dim, value_dim=context_dim, n_heads=num_heads, d_head=dim_head, depth=transformer_depth, fuser_type=fuser_type, use_checkpoint=use_checkpoint, use_sp=use_sp, use_nsp=use_nsp)) 339 | 340 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 341 | input_block_chans.append(ch) 342 | 343 | if level != len(channel_mult) - 1: # will not go to this downsample branch in the last feature 344 | out_ch = ch 345 | self.input_blocks.append( TimestepEmbedSequential( Downsample(ch, conv_resample, dims=dims, out_channels=out_ch ) ) ) 346 | ch = out_ch 347 | input_block_chans.append(ch) 348 | ds *= 2 349 | dim_head = ch // num_heads 350 | 351 | 352 | # = = = = = = = = = = = = = = = = = = = = BottleNeck = = = = = = = = = = = = = = = = = = = = # 353 | 354 | self.middle_block = TimestepEmbedSequential( 355 | ResBlock(ch, 356 | time_embed_dim, 357 | dropout, 358 | dims=dims, 359 | use_checkpoint=use_checkpoint, 360 | use_scale_shift_norm=use_scale_shift_norm), 361 | SpatialTransformer(ch, key_dim=context_dim, value_dim=context_dim, n_heads=num_heads, d_head=dim_head, depth=transformer_depth, fuser_type=fuser_type, use_checkpoint=use_checkpoint, use_sp=use_sp, use_nsp=use_nsp), 362 | ResBlock(ch, 363 | time_embed_dim, 364 | dropout, 365 | dims=dims, 366 | use_checkpoint=use_checkpoint, 367 | use_scale_shift_norm=use_scale_shift_norm)) 368 | 369 | 370 | 371 | # = = = = = = = = = = = = = = = = = = = = Up Branch = = = = = = = = = = = = = = = = = = = = # 372 | 373 | 374 | self.output_blocks = nn.ModuleList([]) 375 | for level, mult in list(enumerate(channel_mult))[::-1]: 376 | for i in range(num_res_blocks + 1): 377 | ich = input_block_chans.pop() 378 | layers = [ ResBlock(ch + ich, 379 | time_embed_dim, 380 | dropout, 381 | out_channels=model_channels * mult, 382 | dims=dims, 383 | use_checkpoint=use_checkpoint, 384 | use_scale_shift_norm=use_scale_shift_norm) ] 385 | ch = model_channels * mult 386 | 387 | if ds in attention_resolutions: 388 | dim_head = ch // num_heads 389 | layers.append( SpatialTransformer(ch, key_dim=context_dim, value_dim=context_dim, n_heads=num_heads, d_head=dim_head, depth=transformer_depth, fuser_type=fuser_type, use_checkpoint=use_checkpoint, use_sp=use_sp, use_nsp=use_nsp) ) 390 | if level and i == num_res_blocks: 391 | out_ch = ch 392 | layers.append( Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) ) 393 | ds //= 2 394 | 395 | self.output_blocks.append(TimestepEmbedSequential(*layers)) 396 | 397 | 398 | 399 | self.out = nn.Sequential( 400 | normalization(ch), 401 | nn.SiLU(), 402 | zero_module(conv_nd(dims, model_channels, 
out_channels, 3, padding=1)), 403 | ) 404 | 405 | 406 | # = = = = = = = = = = = = = = = = = = = = Multimodal Condition Networks = = = = = = = = = = = = = = = = = = = = # 407 | 408 | self.condition_nets = nn.ModuleDict() 409 | for mode in self.input_modalities: 410 | self.condition_nets[mode] = instantiate_from_config(grounding_tokenizer["tokenizer_{}".format(mode)]) 411 | 412 | self.scales = [1.0]*4 413 | 414 | 415 | def forward(self, input_dict): 416 | condition = input_dict["condition"] 417 | 418 | # aggregate objs by each type of mode 419 | im_objs, sp_objs, nsp_objs = [], [], [] 420 | for mode, input_type in zip(self.input_modalities, self.input_types): 421 | assert mode in condition 422 | 423 | objs = self.condition_nets[mode](condition[mode]) 424 | 425 | if input_type == "image": 426 | im_objs.append(objs) # B*C*H*W 427 | elif input_type == "sp_vector": 428 | sp_objs.append(objs) # B*N*C 429 | elif input_type == "nsp_vector": 430 | nsp_objs.append(objs) # B*1*C 431 | else: 432 | raise NotImplementedError 433 | 434 | # aggregate image form conditions 435 | im_objs = [th.stack(arr, dim=0).sum(0) for arr in zip(*im_objs)] if len(im_objs) > 0 else None 436 | 437 | sp_objs = th.cat(sp_objs, dim=1) if len(sp_objs)>0 else None 438 | nsp_objs = th.cat(nsp_objs, dim=1) if len(nsp_objs)>0 else None 439 | 440 | # Time embedding 441 | t_emb = timestep_embedding(input_dict["timesteps"], self.model_channels, repeat_only=False) 442 | emb = self.time_embed(t_emb) 443 | 444 | # input tensor 445 | h = input_dict["x"] 446 | 447 | if self.inpaint_mode: 448 | h = th.cat( [h, input_dict["inpainting_extra_input"]], dim=1 ) 449 | 450 | # Text input 451 | context = input_dict["context"] 452 | 453 | # Start forwarding 454 | hs = [] 455 | adapter_idx = 0 456 | for i, module in enumerate(self.input_blocks): 457 | if self.image_cond_injection_type == "enc" and (i+1) % 3 == 0 and im_objs is not None: 458 | h = module(h, emb, context, sp_objs, nsp_objs) 459 | h = h + self.scales[adapter_idx] * im_objs[adapter_idx] 460 | adapter_idx += 1 461 | else: 462 | h = module(h, emb, context, sp_objs, nsp_objs) 463 | hs.append(h) 464 | 465 | h = self.middle_block(h, emb, context, sp_objs, nsp_objs) 466 | 467 | adapter_idx = 0 468 | for i, module in enumerate(self.output_blocks): 469 | if self.image_cond_injection_type == "dec" and i % 3 == 0 and im_objs is not None: 470 | enc_h = hs.pop() + self.scales[adapter_idx] * im_objs.pop() 471 | adapter_idx += 1 472 | else: 473 | enc_h = hs.pop() 474 | h = th.cat([h, enc_h], dim=1) 475 | 476 | h = module(h, emb, context, sp_objs, nsp_objs) 477 | 478 | return self.out(h) 479 | 480 | 481 | -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | from einops import repeat 7 | 8 | from ldm.util import instantiate_from_config 9 | 10 | 11 | 12 | class FourierEmbedder(): 13 | def __init__(self, num_freqs=64, temperature=100): 14 | 15 | self.num_freqs = num_freqs 16 | self.temperature = temperature 17 | self.freq_bands = temperature ** ( torch.arange(num_freqs) / num_freqs ) 18 | 19 | @ torch.no_grad() 20 | def __call__(self, x, cat_dim=-1): 21 | "x: arbitrary shape of tensor. 
dim: cat dim" 22 | out = [] 23 | for freq in self.freq_bands: 24 | out.append( torch.sin( freq*x ) ) 25 | out.append( torch.cos( freq*x ) ) 26 | return torch.cat(out, cat_dim) 27 | 28 | 29 | 30 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 31 | if schedule == "linear": 32 | betas = ( 33 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 34 | ) 35 | 36 | elif schedule == "cosine": 37 | timesteps = ( 38 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 39 | ) 40 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 41 | alphas = torch.cos(alphas).pow(2) 42 | alphas = alphas / alphas[0] 43 | betas = 1 - alphas[1:] / alphas[:-1] 44 | betas = np.clip(betas, a_min=0, a_max=0.999) 45 | 46 | elif schedule == "sqrt_linear": 47 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 48 | elif schedule == "sqrt": 49 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 50 | else: 51 | raise ValueError(f"schedule '{schedule}' unknown.") 52 | return betas.numpy() 53 | 54 | 55 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 56 | if ddim_discr_method == 'uniform': 57 | c = num_ddpm_timesteps // num_ddim_timesteps 58 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 59 | elif ddim_discr_method == 'quad': 60 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 61 | else: 62 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 63 | 64 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 65 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 66 | steps_out = ddim_timesteps + 1 67 | if verbose: 68 | print(f'Selected timesteps for ddim sampler: {steps_out}') 69 | return steps_out 70 | 71 | 72 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 73 | # select alphas for computing the variance schedule 74 | alphas = alphacums[ddim_timesteps] 75 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 76 | 77 | # according the the formula provided in https://arxiv.org/abs/2010.02502 78 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 79 | if verbose: 80 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 81 | print(f'For the chosen value of eta, which is {eta}, ' 82 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 83 | return sigmas, alphas, alphas_prev 84 | 85 | 86 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 87 | """ 88 | Create a beta schedule that discretizes the given alpha_t_bar function, 89 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 90 | :param num_diffusion_timesteps: the number of betas to produce. 91 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 92 | produces the cumulative product of (1-beta) up to that 93 | part of the diffusion process. 94 | :param max_beta: the maximum beta to use; use values lower than 1 to 95 | prevent singularities. 
96 | """ 97 | betas = [] 98 | for i in range(num_diffusion_timesteps): 99 | t1 = i / num_diffusion_timesteps 100 | t2 = (i + 1) / num_diffusion_timesteps 101 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 102 | return np.array(betas) 103 | 104 | 105 | def extract_into_tensor(a, t, x_shape): 106 | b, *_ = t.shape 107 | out = a.gather(-1, t) 108 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 109 | 110 | 111 | def checkpoint(func, inputs, params, flag): 112 | """ 113 | Evaluate a function without caching intermediate activations, allowing for 114 | reduced memory at the expense of extra compute in the backward pass. 115 | :param func: the function to evaluate. 116 | :param inputs: the argument sequence to pass to `func`. 117 | :param params: a sequence of parameters `func` depends on but does not 118 | explicitly take as arguments. 119 | :param flag: if False, disable gradient checkpointing. 120 | """ 121 | if flag: 122 | args = tuple(inputs) + tuple(params) 123 | return CheckpointFunction.apply(func, len(inputs), *args) 124 | else: 125 | return func(*inputs) 126 | 127 | 128 | class CheckpointFunction(torch.autograd.Function): 129 | @staticmethod 130 | def forward(ctx, run_function, length, *args): 131 | ctx.run_function = run_function 132 | ctx.input_tensors = list(args[:length]) 133 | ctx.input_params = list(args[length:]) 134 | 135 | with torch.no_grad(): 136 | output_tensors = ctx.run_function(*ctx.input_tensors) 137 | return output_tensors 138 | 139 | @staticmethod 140 | def backward(ctx, *output_grads): 141 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 142 | with torch.enable_grad(): 143 | # Fixes a bug where the first op in run_function modifies the 144 | # Tensor storage in place, which is not allowed for detach()'d 145 | # Tensors. 146 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 147 | output_tensors = ctx.run_function(*shallow_copies) 148 | input_grads = torch.autograd.grad( 149 | output_tensors, 150 | ctx.input_tensors + ctx.input_params, 151 | output_grads, 152 | allow_unused=True, 153 | ) 154 | del ctx.input_tensors 155 | del ctx.input_params 156 | del output_tensors 157 | return (None, None) + input_grads 158 | 159 | 160 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 161 | """ 162 | Create sinusoidal timestep embeddings. 163 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 164 | These may be fractional. 165 | :param dim: the dimension of the output. 166 | :param max_period: controls the minimum frequency of the embeddings. 167 | :return: an [N x dim] Tensor of positional embeddings. 168 | """ 169 | if not repeat_only: 170 | half = dim // 2 171 | freqs = torch.exp( 172 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 173 | ).to(device=timesteps.device) 174 | args = timesteps[:, None].float() * freqs[None] 175 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 176 | if dim % 2: 177 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 178 | else: 179 | embedding = repeat(timesteps, 'b -> b d', d=dim) 180 | return embedding 181 | 182 | 183 | def zero_module(module): 184 | """ 185 | Zero out the parameters of a module and return it. 186 | """ 187 | for p in module.parameters(): 188 | p.detach().zero_() 189 | return module 190 | 191 | 192 | def scale_module(module, scale): 193 | """ 194 | Scale the parameters of a module and return it. 
195 | """ 196 | for p in module.parameters(): 197 | p.detach().mul_(scale) 198 | return module 199 | 200 | 201 | def mean_flat(tensor): 202 | """ 203 | Take the mean over all non-batch dimensions. 204 | """ 205 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 206 | 207 | 208 | def normalization(channels): 209 | """ 210 | Make a standard normalization layer. 211 | :param channels: number of input channels. 212 | :return: an nn.Module for normalization. 213 | """ 214 | return GroupNorm32(32, channels) 215 | 216 | 217 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 218 | class SiLU(nn.Module): 219 | def forward(self, x): 220 | return x * torch.sigmoid(x) 221 | 222 | 223 | class GroupNorm32(nn.GroupNorm): 224 | def forward(self, x): 225 | return super().forward(x.float()).type(x.dtype) 226 | #return super().forward(x).type(x.dtype) 227 | 228 | def conv_nd(dims, *args, **kwargs): 229 | """ 230 | Create a 1D, 2D, or 3D convolution module. 231 | """ 232 | if dims == 1: 233 | return nn.Conv1d(*args, **kwargs) 234 | elif dims == 2: 235 | return nn.Conv2d(*args, **kwargs) 236 | elif dims == 3: 237 | return nn.Conv3d(*args, **kwargs) 238 | raise ValueError(f"unsupported dimensions: {dims}") 239 | 240 | 241 | def linear(*args, **kwargs): 242 | """ 243 | Create a linear module. 244 | """ 245 | return nn.Linear(*args, **kwargs) 246 | 247 | 248 | def avg_pool_nd(dims, *args, **kwargs): 249 | """ 250 | Create a 1D, 2D, or 3D average pooling module. 251 | """ 252 | if dims == 1: 253 | return nn.AvgPool1d(*args, **kwargs) 254 | elif dims == 2: 255 | return nn.AvgPool2d(*args, **kwargs) 256 | elif dims == 3: 257 | return nn.AvgPool3d(*args, **kwargs) 258 | raise ValueError(f"unsupported dimensions: {dims}") 259 | 260 | 261 | class HybridConditioner(nn.Module): 262 | 263 | def __init__(self, c_concat_config, c_crossattn_config): 264 | super().__init__() 265 | self.concat_conditioner = instantiate_from_config(c_concat_config) 266 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 267 | 268 | def forward(self, c_concat, c_crossattn): 269 | c_concat = self.concat_conditioner(c_concat) 270 | c_crossattn = self.crossattn_conditioner(c_crossattn) 271 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 272 | 273 | 274 | def noise_like(shape, device, repeat=False): 275 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 276 | noise = lambda: torch.randn(shape, device=device) 277 | return repeat_noise() if repeat else noise() -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sungnyun/diffblender/b3481d17d0e13d89d45dcbb9a250e89459904031/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class 
DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 
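    # The comprehension below wraps any scalar log-variance with torch.tensor(...).to(tensor),
    # so both logvar1 and logvar2 share the reference tensor's dtype/device before torch.exp
    # and the elementwise KL terms are evaluated.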
81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 
74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sungnyun/diffblender/b3481d17d0e13d89d45dcbb9a250e89459904031/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /ldm/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | import clip 5 | from einops import rearrange, repeat 6 | from transformers import CLIPTokenizer, CLIPTextModel 7 | import kornia 8 | 9 | from ldm.modules.x_transformer import Encoder, TransformerWrapper 10 | 11 | 12 | class AbstractEncoder(nn.Module): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def encode(self, *args, **kwargs): 17 | raise NotImplementedError 18 | 19 | 20 | 21 | class ClassEmbedder(nn.Module): 22 | def __init__(self, embed_dim, n_classes=1000, key='class'): 23 | super().__init__() 24 | self.key = key 25 | self.embedding = nn.Embedding(n_classes, embed_dim) 26 | 27 | def forward(self, batch, key=None): 28 | if key is None: 29 | key = self.key 30 | # this is for use in crossattn 31 | c = batch[key][:, None] 32 | c = self.embedding(c) 33 | return c 34 | 35 | 36 | class TransformerEmbedder(AbstractEncoder): 37 | """Some transformer encoder layers""" 38 | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): 39 | super().__init__() 40 | self.device = device 41 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 42 | attn_layers=Encoder(dim=n_embed, depth=n_layer)) 43 | 44 | def forward(self, tokens): 45 | tokens = tokens.to(self.device) # meh 46 | z = self.transformer(tokens, return_embeddings=True) 47 | return z 48 | 49 | def encode(self, x): 50 | return self(x) 51 | 52 | 53 | class BERTTokenizer(AbstractEncoder): 54 | """ Uses a pretrained BERT tokenizer by huggingface. 
Vocab size: 30522 (?)""" 55 | def __init__(self, device="cuda", vq_interface=True, max_length=77): 56 | super().__init__() 57 | from transformers import BertTokenizerFast # TODO: add to reuquirements 58 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") 59 | self.device = device 60 | self.vq_interface = vq_interface 61 | self.max_length = max_length 62 | 63 | def forward(self, text): 64 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 65 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt", 66 | return_offsets_mapping=True) 67 | tokens = batch_encoding["input_ids"].to(self.device) 68 | offset_mapping = batch_encoding["offset_mapping"] 69 | return tokens, offset_mapping 70 | 71 | @torch.no_grad() 72 | def encode(self, text): 73 | tokens = self(text) 74 | if not self.vq_interface: 75 | return tokens 76 | return None, None, [None, None, tokens] 77 | 78 | def decode(self, text): 79 | return text 80 | 81 | 82 | class BERTEmbedder(AbstractEncoder): 83 | """Uses the BERT tokenizr model and add some transformer encoder layers""" 84 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, 85 | device="cuda",use_tokenizer=True, embedding_dropout=0.0): 86 | super().__init__() 87 | self.use_tknz_fn = use_tokenizer 88 | if self.use_tknz_fn: 89 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) 90 | self.device = device 91 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 92 | attn_layers=Encoder(dim=n_embed, depth=n_layer), 93 | emb_dropout=embedding_dropout) 94 | 95 | def forward(self, text, return_offset_mapping=False): 96 | if self.use_tknz_fn: 97 | tokens, offset_mapping = self.tknz_fn(text)#.to(self.device) 98 | else: 99 | assert False 100 | tokens = text 101 | z = self.transformer(tokens, return_embeddings=True) 102 | 103 | if return_offset_mapping: 104 | return z, offset_mapping 105 | else: 106 | return z 107 | 108 | def encode(self, text, return_offset_mapping=False): 109 | # output of length 77 110 | return self(text, return_offset_mapping) 111 | 112 | 113 | class SpatialRescaler(nn.Module): 114 | def __init__(self, 115 | n_stages=1, 116 | method='bilinear', 117 | multiplier=0.5, 118 | in_channels=3, 119 | out_channels=None, 120 | bias=False): 121 | super().__init__() 122 | self.n_stages = n_stages 123 | assert self.n_stages >= 0 124 | assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] 125 | self.multiplier = multiplier 126 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method) 127 | self.remap_output = out_channels is not None 128 | if self.remap_output: 129 | print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') 130 | self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) 131 | 132 | def forward(self,x): 133 | for stage in range(self.n_stages): 134 | x = self.interpolator(x, scale_factor=self.multiplier) 135 | 136 | 137 | if self.remap_output: 138 | x = self.channel_mapper(x) 139 | return x 140 | 141 | def encode(self, x): 142 | return self(x) 143 | 144 | class FrozenCLIPEmbedder(AbstractEncoder): 145 | """Uses the CLIP transformer encoder for text (from Hugging Face)""" 146 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): 147 | super().__init__() 148 | self.tokenizer = CLIPTokenizer.from_pretrained(version) 149 | self.transformer = 
CLIPTextModel.from_pretrained(version) 150 | self.device = device 151 | self.max_length = max_length 152 | self.freeze() 153 | 154 | def freeze(self): 155 | self.transformer = self.transformer.eval() 156 | for param in self.parameters(): 157 | param.requires_grad = False 158 | 159 | def forward(self, text, return_pooler_output=False): 160 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 161 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 162 | tokens = batch_encoding["input_ids"].to(self.device) 163 | outputs = self.transformer(input_ids=tokens) 164 | 165 | z = outputs.last_hidden_state 166 | 167 | if not return_pooler_output: 168 | return z 169 | else: 170 | return z, outputs.pooler_output 171 | 172 | def encode(self, text, return_pooler_output=False): 173 | return self(text, return_pooler_output) 174 | 175 | 176 | class FrozenCLIPTextEmbedder(nn.Module): 177 | """ 178 | Uses the CLIP transformer encoder for text. 179 | """ 180 | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): 181 | super().__init__() 182 | self.model, _ = clip.load(version, jit=False, device="cpu") 183 | self.device = device 184 | self.max_length = max_length 185 | self.n_repeat = n_repeat 186 | self.normalize = normalize 187 | 188 | def freeze(self): 189 | self.model = self.model.eval() 190 | for param in self.parameters(): 191 | param.requires_grad = False 192 | 193 | def forward(self, text): 194 | tokens = clip.tokenize(text).to(self.device) 195 | z = self.model.encode_text(tokens) 196 | if self.normalize: 197 | z = z / torch.linalg.norm(z, dim=1, keepdim=True) 198 | return z 199 | 200 | def encode(self, text): 201 | z = self(text) 202 | if z.ndim==2: 203 | z = z[:, None, :] 204 | z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) 205 | return z 206 | 207 | 208 | class FrozenClipImageEmbedder(nn.Module): 209 | """ 210 | Uses the CLIP image encoder. 211 | """ 212 | def __init__( 213 | self, 214 | model, 215 | jit=False, 216 | device='cuda' if torch.cuda.is_available() else 'cpu', 217 | antialias=False, 218 | ): 219 | super().__init__() 220 | self.model, _ = clip.load(name=model, device=device, jit=jit) 221 | 222 | self.antialias = antialias 223 | 224 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 225 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 226 | 227 | def preprocess(self, x): 228 | # normalize to [0,1] 229 | x = kornia.geometry.resize(x, (224, 224), 230 | interpolation='bicubic',align_corners=True, 231 | antialias=self.antialias) 232 | x = (x + 1.) / 2. 
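        # `forward` below assumes inputs in [-1, 1]; the line above maps them to [0, 1]
        # so that the CLIP mean/std normalization applied next sees the expected range.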
233 | # renormalize according to clip 234 | x = kornia.enhance.normalize(x, self.mean, self.std) 235 | return x 236 | 237 | def forward(self, x): 238 | # x is assumed to be in range [-1,1] 239 | return self.model.encode_image(self.preprocess(x)) 240 | 241 | 242 | if __name__ == "__main__": 243 | from ldm.util import count_params 244 | model = FrozenCLIPEmbedder() 245 | count_params(model, verbose=True) 246 | -------------------------------------------------------------------------------- /ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /ldm/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge"): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | 32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 33 | if last_layer is not None: 34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 36 | else: 37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 39 | 40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 42 | d_weight = d_weight * self.discriminator_weight 43 | return d_weight 44 | 45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 46 | global_step, last_layer=None, cond=None, split="train", 47 | weights=None): 48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 49 | if self.perceptual_weight > 0: 50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 51 | rec_loss = rec_loss + self.perceptual_weight * p_loss 52 | 53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 54 | weighted_nll_loss = nll_loss 55 | if weights is not None: 56 | weighted_nll_loss = weights*nll_loss 57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 59 | kl_loss = posteriors.kl() 60 | kl_loss = torch.sum(kl_loss) / 
kl_loss.shape[0] 61 | 62 | # now the GAN part 63 | if optimizer_idx == 0: 64 | # generator update 65 | if cond is None: 66 | assert not self.disc_conditional 67 | logits_fake = self.discriminator(reconstructions.contiguous()) 68 | else: 69 | assert self.disc_conditional 70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 71 | g_loss = -torch.mean(logits_fake) 72 | 73 | if self.disc_factor > 0.0: 74 | try: 75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 76 | except RuntimeError: 77 | assert not self.training 78 | d_weight = torch.tensor(0.0) 79 | else: 80 | d_weight = torch.tensor(0.0) 81 | 82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 84 | 85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 86 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 87 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 88 | "{}/d_weight".format(split): d_weight.detach(), 89 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 90 | "{}/g_loss".format(split): g_loss.detach().mean(), 91 | } 92 | return loss, log 93 | 94 | if optimizer_idx == 1: 95 | # second pass for discriminator update 96 | if cond is None: 97 | logits_real = self.discriminator(inputs.contiguous().detach()) 98 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 99 | else: 100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 102 | 103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 105 | 106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 107 | "{}/logits_real".format(split): logits_real.detach().mean(), 108 | "{}/logits_fake".format(split): logits_fake.detach().mean() 109 | } 110 | return d_loss, log 111 | 112 | -------------------------------------------------------------------------------- /ldm/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from einops import repeat 5 | 6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | from taming.modules.losses.lpips import LPIPS 8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss 9 | 10 | 11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): 12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] 13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) 14 | loss_fake = torch.mean(F.relu(1. 
+ logits_fake), dim=[1,2,3]) 15 | loss_real = (weights * loss_real).sum() / weights.sum() 16 | loss_fake = (weights * loss_fake).sum() / weights.sum() 17 | d_loss = 0.5 * (loss_real + loss_fake) 18 | return d_loss 19 | 20 | def adopt_weight(weight, global_step, threshold=0, value=0.): 21 | if global_step < threshold: 22 | weight = value 23 | return weight 24 | 25 | 26 | def measure_perplexity(predicted_indices, n_embed): 27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 28 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally 29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) 30 | avg_probs = encodings.mean(0) 31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 32 | cluster_use = torch.sum(avg_probs > 0) 33 | return perplexity, cluster_use 34 | 35 | def l1(x, y): 36 | return torch.abs(x-y) 37 | 38 | 39 | def l2(x, y): 40 | return torch.pow((x-y), 2) 41 | 42 | 43 | class VQLPIPSWithDiscriminator(nn.Module): 44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", 48 | pixel_loss="l1"): 49 | super().__init__() 50 | assert disc_loss in ["hinge", "vanilla"] 51 | assert perceptual_loss in ["lpips", "clips", "dists"] 52 | assert pixel_loss in ["l1", "l2"] 53 | self.codebook_weight = codebook_weight 54 | self.pixel_weight = pixelloss_weight 55 | if perceptual_loss == "lpips": 56 | print(f"{self.__class__.__name__}: Running with LPIPS.") 57 | self.perceptual_loss = LPIPS().eval() 58 | else: 59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") 60 | self.perceptual_weight = perceptual_weight 61 | 62 | if pixel_loss == "l1": 63 | self.pixel_loss = l1 64 | else: 65 | self.pixel_loss = l2 66 | 67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 68 | n_layers=disc_num_layers, 69 | use_actnorm=use_actnorm, 70 | ndf=disc_ndf 71 | ).apply(weights_init) 72 | self.discriminator_iter_start = disc_start 73 | if disc_loss == "hinge": 74 | self.disc_loss = hinge_d_loss 75 | elif disc_loss == "vanilla": 76 | self.disc_loss = vanilla_d_loss 77 | else: 78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 80 | self.disc_factor = disc_factor 81 | self.discriminator_weight = disc_weight 82 | self.disc_conditional = disc_conditional 83 | self.n_classes = n_classes 84 | 85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 86 | if last_layer is not None: 87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 89 | else: 90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 92 | 93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 95 | d_weight = d_weight * self.discriminator_weight 96 | return d_weight 97 | 98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None): 100 | if not exists(codebook_loss): 101 
| codebook_loss = torch.tensor([0.]).to(inputs.device) 102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) 104 | if self.perceptual_weight > 0: 105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 106 | rec_loss = rec_loss + self.perceptual_weight * p_loss 107 | else: 108 | p_loss = torch.tensor([0.0]) 109 | 110 | nll_loss = rec_loss 111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 112 | nll_loss = torch.mean(nll_loss) 113 | 114 | # now the GAN part 115 | if optimizer_idx == 0: 116 | # generator update 117 | if cond is None: 118 | assert not self.disc_conditional 119 | logits_fake = self.discriminator(reconstructions.contiguous()) 120 | else: 121 | assert self.disc_conditional 122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 123 | g_loss = -torch.mean(logits_fake) 124 | 125 | try: 126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 127 | except RuntimeError: 128 | assert not self.training 129 | d_weight = torch.tensor(0.0) 130 | 131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 133 | 134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 136 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 137 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 138 | "{}/p_loss".format(split): p_loss.detach().mean(), 139 | "{}/d_weight".format(split): d_weight.detach(), 140 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 141 | "{}/g_loss".format(split): g_loss.detach().mean(), 142 | } 143 | if predicted_indices is not None: 144 | assert self.n_classes is not None 145 | with torch.no_grad(): 146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) 147 | log[f"{split}/perplexity"] = perplexity 148 | log[f"{split}/cluster_usage"] = cluster_usage 149 | return loss, log 150 | 151 | if optimizer_idx == 1: 152 | # second pass for discriminator update 153 | if cond is None: 154 | logits_real = self.discriminator(inputs.contiguous().detach()) 155 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 156 | else: 157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 159 | 160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 162 | 163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 164 | "{}/logits_real".format(split): logits_real.detach().mean(), 165 | "{}/logits_fake".format(split): logits_fake.detach().mean() 166 | } 167 | return d_loss, log 168 | -------------------------------------------------------------------------------- /ldm/modules/multimodal_attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | 8 | # from ldm.modules.diffusionmodules.util import 
checkpoint, FourierEmbedder 9 | from torch.utils import checkpoint 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def uniq(arr): 16 | return{el: True for el in arr}.keys() 17 | 18 | 19 | def default(val, d): 20 | if exists(val): 21 | return val 22 | return d() if isfunction(d) else d 23 | 24 | 25 | def max_neg_value(t): 26 | return -torch.finfo(t.dtype).max 27 | 28 | 29 | def init_(tensor): 30 | dim = tensor.shape[-1] 31 | std = 1 / math.sqrt(dim) 32 | tensor.uniform_(-std, std) 33 | return tensor 34 | 35 | 36 | # feedforward 37 | class GEGLU(nn.Module): 38 | def __init__(self, dim_in, dim_out): 39 | super().__init__() 40 | self.proj = nn.Linear(dim_in, dim_out * 2) 41 | 42 | def forward(self, x): 43 | x, gate = self.proj(x).chunk(2, dim=-1) 44 | return x * F.gelu(gate) 45 | 46 | 47 | class FeedForward(nn.Module): 48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 49 | super().__init__() 50 | inner_dim = int(dim * mult) 51 | dim_out = default(dim_out, dim) 52 | project_in = nn.Sequential( 53 | nn.Linear(dim, inner_dim), 54 | nn.GELU() 55 | ) if not glu else GEGLU(dim, inner_dim) 56 | 57 | self.net = nn.Sequential( 58 | project_in, 59 | nn.Dropout(dropout), 60 | nn.Linear(inner_dim, dim_out) 61 | ) 62 | 63 | def forward(self, x): 64 | return self.net(x) 65 | 66 | 67 | def zero_module(module): 68 | """ 69 | Zero out the parameters of a module and return it. 70 | """ 71 | for p in module.parameters(): 72 | p.detach().zero_() 73 | return module 74 | 75 | 76 | def Normalize(in_channels): 77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 78 | 79 | 80 | class LinearAttention(nn.Module): 81 | def __init__(self, dim, heads=4, dim_head=32): 82 | super().__init__() 83 | self.heads = heads 84 | hidden_dim = dim_head * heads 85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 87 | 88 | def forward(self, x): 89 | b, c, h, w = x.shape 90 | qkv = self.to_qkv(x) 91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 92 | k = k.softmax(dim=-1) 93 | context = torch.einsum('bhdn,bhen->bhde', k, v) 94 | out = torch.einsum('bhde,bhdn->bhen', context, q) 95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 96 | return self.to_out(out) 97 | 98 | 99 | 100 | 101 | 102 | class CrossAttention(nn.Module): 103 | def __init__(self, query_dim, key_dim, value_dim, heads=8, dim_head=64, dropout=0): 104 | super().__init__() 105 | inner_dim = dim_head * heads 106 | self.scale = dim_head ** -0.5 107 | self.heads = heads 108 | 109 | 110 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 111 | self.to_k = nn.Linear(key_dim, inner_dim, bias=False) 112 | self.to_v = nn.Linear(value_dim, inner_dim, bias=False) 113 | 114 | 115 | self.to_out = nn.Sequential( nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) ) 116 | 117 | 118 | def fill_inf_from_mask(self, sim, mask): 119 | if mask is not None: 120 | B,M = mask.shape 121 | mask = mask.unsqueeze(1).repeat(1,self.heads,1).reshape(B*self.heads,1,-1) 122 | max_neg_value = -torch.finfo(sim.dtype).max 123 | sim.masked_fill_(~mask, max_neg_value) 124 | return sim 125 | 126 | 127 | def forward(self, x, key, value, mask=None): 128 | 129 | q = self.to_q(x) # B*N*(H*C) 130 | k = self.to_k(key) # B*M*(H*C) 131 | v = self.to_v(value) # B*M*(H*C) 132 | 133 | B, N, HC = q.shape 134 | _, M, _ = key.shape 135 | H = self.heads 136 | C = HC // H 137 | 
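        # Split the heads: the projections of shape B*N*(H*C) (queries) and B*M*(H*C)
        # (keys/values) are reshaped below to (B*H)*N*C and (B*H)*M*C, so the einsum
        # computes attention independently per head; e.g. with B=2, N=64, H=8, C=40,
        # q goes from (2, 64, 320) to (16, 64, 40).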
138 | q = q.view(B,N,H,C).permute(0,2,1,3).reshape(B*H,N,C) # (B*H)*N*C 139 | k = k.view(B,M,H,C).permute(0,2,1,3).reshape(B*H,M,C) # (B*H)*M*C 140 | v = v.view(B,M,H,C).permute(0,2,1,3).reshape(B*H,M,C) # (B*H)*M*C 141 | 142 | sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale # (B*H)*N*M 143 | self.fill_inf_from_mask(sim, mask) 144 | attn = sim.softmax(dim=-1) # (B*H)*N*M 145 | 146 | out = torch.einsum('b i j, b j d -> b i d', attn, v) # (B*H)*N*C 147 | out = out.view(B,H,N,C).permute(0,2,1,3).reshape(B,N,(H*C)) # B*N*(H*C) 148 | 149 | return self.to_out(out) 150 | 151 | 152 | 153 | 154 | class SelfAttention(nn.Module): 155 | def __init__(self, query_dim, heads=8, dim_head=64, dropout=0.): 156 | super().__init__() 157 | inner_dim = dim_head * heads 158 | self.scale = dim_head ** -0.5 159 | self.heads = heads 160 | 161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 162 | self.to_k = nn.Linear(query_dim, inner_dim, bias=False) 163 | self.to_v = nn.Linear(query_dim, inner_dim, bias=False) 164 | 165 | self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) ) 166 | 167 | def forward(self, x): 168 | q = self.to_q(x) # B*N*(H*C) 169 | k = self.to_k(x) # B*N*(H*C) 170 | v = self.to_v(x) # B*N*(H*C) 171 | 172 | B, N, HC = q.shape 173 | H = self.heads 174 | C = HC // H 175 | 176 | q = q.view(B,N,H,C).permute(0,2,1,3).reshape(B*H,N,C) # (B*H)*N*C 177 | k = k.view(B,N,H,C).permute(0,2,1,3).reshape(B*H,N,C) # (B*H)*N*C 178 | v = v.view(B,N,H,C).permute(0,2,1,3).reshape(B*H,N,C) # (B*H)*N*C 179 | 180 | sim = torch.einsum('b i c, b j c -> b i j', q, k) * self.scale # (B*H)*N*N 181 | attn = sim.softmax(dim=-1) # (B*H)*N*N 182 | 183 | out = torch.einsum('b i j, b j c -> b i c', attn, v) # (B*H)*N*C 184 | out = out.view(B,H,N,C).permute(0,2,1,3).reshape(B,N,(H*C)) # B*N*(H*C) 185 | 186 | return self.to_out(out) 187 | 188 | 189 | 190 | class GatedCrossAttentionDense(nn.Module): 191 | def __init__(self, query_dim, key_dim, value_dim, n_heads, d_head): 192 | super().__init__() 193 | 194 | self.attn = CrossAttention(query_dim=query_dim, key_dim=key_dim, value_dim=value_dim, heads=n_heads, dim_head=d_head) 195 | self.ff = FeedForward(query_dim, glu=True) 196 | 197 | self.norm1 = nn.LayerNorm(query_dim) 198 | self.norm2 = nn.LayerNorm(query_dim) 199 | 200 | self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)) ) 201 | self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)) ) 202 | 203 | # this can be useful: we can externally change magnitude of tanh(alpha) 204 | # for example, when it is set to 0, then the entire model is same as original one 205 | self.scale = 1 206 | 207 | def forward(self, x, objs): 208 | 209 | x = x + self.scale*torch.tanh(self.alpha_attn) * self.attn( self.norm1(x), objs, objs) 210 | x = x + self.scale*torch.tanh(self.alpha_dense) * self.ff( self.norm2(x) ) 211 | 212 | return x 213 | 214 | 215 | class GatedSelfAttentionDense(nn.Module): 216 | def __init__(self, query_dim, context_dim, n_heads, d_head): 217 | super().__init__() 218 | 219 | # we need a linear projection since we need cat visual feature and obj feature 220 | self.linear = nn.Linear(context_dim, query_dim) 221 | 222 | self.attn = SelfAttention(query_dim=query_dim, heads=n_heads, dim_head=d_head) 223 | self.ff = FeedForward(query_dim, glu=True) 224 | 225 | self.norm1 = nn.LayerNorm(query_dim) 226 | self.norm2 = nn.LayerNorm(query_dim) 227 | 228 | self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)) ) 229 | 
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)) ) 230 | 231 | # this can be useful: we can externally change magnitude of tanh(alpha) 232 | # for example, when it is set to 0, then the entire model is same as original one 233 | self.scale = 1 234 | 235 | 236 | def forward(self, x, objs): 237 | 238 | N_visual = x.shape[1] 239 | objs = self.linear(objs) 240 | 241 | x = x + self.scale*torch.tanh(self.alpha_attn) * self.attn( self.norm1(torch.cat([x,objs],dim=1)) )[:,0:N_visual,:] 242 | x = x + self.scale*torch.tanh(self.alpha_dense) * self.ff( self.norm2(x) ) 243 | 244 | return x 245 | 246 | 247 | class BasicTransformerBlock(nn.Module): 248 | def __init__(self, query_dim, key_dim, value_dim, n_heads, d_head, fuser_type, use_checkpoint=True, use_sp=True, use_nsp=True): 249 | super().__init__() 250 | self.attn1 = SelfAttention(query_dim=query_dim, heads=n_heads, dim_head=d_head) 251 | self.ff = FeedForward(query_dim, glu=True) 252 | self.attn2 = CrossAttention(query_dim=query_dim, key_dim=key_dim, value_dim=value_dim, heads=n_heads, dim_head=d_head) 253 | self.norm1 = nn.LayerNorm(query_dim) 254 | self.norm2 = nn.LayerNorm(query_dim) 255 | self.norm3 = nn.LayerNorm(query_dim) 256 | self.use_checkpoint = use_checkpoint 257 | 258 | # note key_dim and value_dim here actually are context_dim 259 | if fuser_type == "gatedSA": 260 | if use_sp: 261 | self.sp_fuser = GatedSelfAttentionDense(query_dim, key_dim, n_heads, d_head) 262 | if use_nsp: 263 | self.nsp_fuser = GatedSelfAttentionDense(key_dim, key_dim, n_heads, d_head) # match with text dim 264 | elif fuser_type == "gatedCA": 265 | if use_sp: 266 | self.sp_fuser = GatedCrossAttentionDense(query_dim, key_dim, value_dim, n_heads, d_head) 267 | if use_nsp: 268 | self.nsp_fuser = GatedCrossAttentionDense(key_dim, key_dim, value_dim, n_heads, d_head) # match with text dim 269 | elif fuser_type == "gatedSA-gatedCA": 270 | if use_sp: 271 | self.sp_fuser = GatedSelfAttentionDense(query_dim, key_dim, n_heads, d_head) 272 | if use_nsp: 273 | self.nsp_fuser = GatedCrossAttentionDense(key_dim, key_dim, value_dim, n_heads, d_head) # match with text dim 274 | elif fuser_type == "gatedCA-gatedSA": 275 | if use_sp: 276 | self.sp_fuser = GatedCrossAttentionDense(query_dim, key_dim, value_dim, n_heads, d_head) 277 | if use_nsp: 278 | self.nsp_fuser = GatedSelfAttentionDense(key_dim, key_dim, n_heads, d_head) # match with text dim 279 | else: 280 | assert False 281 | 282 | 283 | def forward(self, x, context, sp_objs, nsp_objs): 284 | # return checkpoint(self._forward, (x, context, objs), self.parameters(), self.use_checkpoint) 285 | if self.use_checkpoint and x.requires_grad: 286 | return checkpoint.checkpoint(self._forward, x, context, sp_objs, nsp_objs) 287 | else: 288 | return self._forward(x, context, sp_objs, nsp_objs) 289 | 290 | def _forward(self, x, context, sp_objs, nsp_objs): 291 | x = self.attn1(self.norm1(x)) + x 292 | if sp_objs is not None: 293 | x = self.sp_fuser(x, sp_objs) # identity mapping in the beginning 294 | if nsp_objs is not None: 295 | context = self.nsp_fuser(context, nsp_objs) # identity mapping in the beginning 296 | x = self.attn2(self.norm2(x), context, context) + x 297 | x = self.ff(self.norm3(x)) + x 298 | return x 299 | 300 | 301 | class SpatialTransformer(nn.Module): 302 | def __init__(self, in_channels, key_dim, value_dim, n_heads, d_head, depth=1, fuser_type=None, use_checkpoint=True, use_sp=True, use_nsp=True): 303 | super().__init__() 304 | self.in_channels = in_channels 305 | query_dim = n_heads * d_head 
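        # Feature maps are flattened to (h*w) tokens of width query_dim = n_heads * d_head,
        # projected in/out with 1x1 convolutions; since `proj_out` is zero-initialized and
        # `forward` returns x + x_in, this block starts out as an identity mapping.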
306 | self.norm = Normalize(in_channels) 307 | 308 | 309 | self.proj_in = nn.Conv2d(in_channels, 310 | query_dim, 311 | kernel_size=1, 312 | stride=1, 313 | padding=0) 314 | 315 | self.transformer_blocks = nn.ModuleList( 316 | [BasicTransformerBlock(query_dim, key_dim, value_dim, n_heads, d_head, fuser_type, use_checkpoint=use_checkpoint, use_sp=use_sp, use_nsp=use_nsp) 317 | for d in range(depth)] 318 | ) 319 | 320 | self.proj_out = zero_module(nn.Conv2d(query_dim, 321 | in_channels, 322 | kernel_size=1, 323 | stride=1, 324 | padding=0)) 325 | 326 | def forward(self, x, context, sp_objs, nsp_objs): 327 | b, c, h, w = x.shape 328 | x_in = x 329 | x = self.norm(x) 330 | x = self.proj_in(x) 331 | x = rearrange(x, 'b c h w -> b (h w) c') 332 | for block in self.transformer_blocks: 333 | x = block(x, context, sp_objs, nsp_objs) 334 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 335 | x = self.proj_out(x) 336 | return x + x_in 337 | -------------------------------------------------------------------------------- /ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | import numpy as np 5 | 6 | from inspect import isfunction 7 | from PIL import Image, ImageDraw, ImageFont 8 | 9 | 10 | def log_txt_as_img(wh, xc, size=10): 11 | # wh a tuple of (width, height) 12 | # xc a list of captions to plot 13 | b = len(xc) 14 | txts = list() 15 | for bi in range(b): 16 | txt = Image.new("RGB", wh, color="white") 17 | draw = ImageDraw.Draw(txt) 18 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 19 | nc = int(40 * (wh[0] / 256)) 20 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 21 | 22 | try: 23 | draw.text((0, 0), lines, fill="black", font=font) 24 | except UnicodeEncodeError: 25 | print("Cant encode string for logging. Skipping.") 26 | 27 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 28 | txts.append(txt) 29 | txts = np.stack(txts) 30 | txts = torch.tensor(txts) 31 | return txts 32 | 33 | 34 | def ismap(x): 35 | if not isinstance(x, torch.Tensor): 36 | return False 37 | return (len(x.shape) == 4) and (x.shape[1] > 3) 38 | 39 | 40 | def isimage(x): 41 | if not isinstance(x,torch.Tensor): 42 | return False 43 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 44 | 45 | 46 | def exists(x): 47 | return x is not None 48 | 49 | 50 | def default(val, d): 51 | if exists(val): 52 | return val 53 | return d() if isfunction(d) else d 54 | 55 | 56 | def mean_flat(tensor): 57 | """ 58 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 59 | Take the mean over all non-batch dimensions. 
60 | """ 61 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 62 | 63 | 64 | def count_params(model, verbose=False): 65 | total_params = sum(p.numel() for p in model.parameters()) 66 | if verbose: 67 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 68 | return total_params 69 | 70 | 71 | def instantiate_from_config(config): 72 | if not "target" in config: 73 | if config == '__is_first_stage__': 74 | return None 75 | elif config == "__is_unconditional__": 76 | return None 77 | raise KeyError("Expected key `target` to instantiate.") 78 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 79 | 80 | 81 | def get_obj_from_str(string, reload=False): 82 | module, cls = string.rsplit(".", 1) 83 | if reload: 84 | module_imp = importlib.import_module(module) 85 | importlib.reload(module_imp) 86 | return getattr(importlib.import_module(module, package=None), cls) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.13.0 2 | torchvision==0.14.0 3 | albumentations==0.4.3 4 | opencv-python 5 | pudb==2019.2 6 | imageio==2.9.0 7 | imageio-ffmpeg==0.4.2 8 | pytorch-lightning==1.4.2 9 | omegaconf==2.1.1 10 | test-tube>=0.7.5 11 | streamlit>=0.73.1 12 | einops==0.3.0 13 | torch-fidelity==0.3.0 14 | git+https://github.com/openai/CLIP.git 15 | protobuf~=3.20.1 16 | torchmetrics==0.6.0 17 | transformers==4.19.2 18 | kornia==0.5.8 19 | 20 | wandb 21 | scikit-image 22 | 23 | # uninstall 24 | # pip uninstall -y torchtext 25 | -------------------------------------------------------------------------------- /visualization/draw_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from copy import deepcopy 3 | import visualization.image_utils as iutils 4 | 5 | 6 | 7 | def draw_sketch_with_batch_to_tensor(batch): 8 | if "sketch" not in batch or batch["sketch"] is None: 9 | return torch.zeros_like(batch["image"]) 10 | else: 11 | return batch["sketch"]["values"] 12 | 13 | def draw_depth_with_batch_to_tensor(batch): 14 | if "depth" not in batch or batch["depth"] is None: 15 | return torch.zeros_like(batch["image"]) 16 | else: 17 | return batch["depth"]["values"] 18 | 19 | def draw_boxes_with_batch_to_tensor(batch): 20 | if "box" not in batch or batch["box"] is None: 21 | return torch.zeros_like(batch["image"]) 22 | 23 | batch_size = batch["image"].size(0) 24 | box_drawing = [] 25 | for i in range(batch_size): 26 | if "box" in batch: 27 | info_dict = {"image": batch["image"][i], "boxes": batch["box"]["values"][i]} 28 | boxed_img = iutils.vis_boxes(info_dict) 29 | else: 30 | boxed_img = torch.randn_like(batch["image"][i]) 31 | box_drawing.append(boxed_img) 32 | box_tensor = torch.stack(box_drawing) 33 | return box_tensor 34 | 35 | def draw_keypoints_with_batch_to_tensor(batch): 36 | if "keypoint" not in batch or batch["keypoint"] is None: 37 | return torch.zeros_like(batch["image"]) 38 | 39 | batch_size = batch["image"].size(0) 40 | keypoint_drawing = [] 41 | for i in range(batch_size): 42 | if "keypoint" in batch: 43 | info_dict = {"image": batch["image"][i], "points": batch["keypoint"]["values"][i]} 44 | keypointed_img = iutils.vis_keypoints(info_dict) 45 | else: 46 | keypointed_img = torch.randn_like(batch["image"][i]) 47 | keypoint_drawing.append(keypointed_img) 48 | keypoint_tensor = torch.stack(keypoint_drawing) 49 | return keypoint_tensor 50 | 51 | def 
draw_color_palettes_with_batch_to_tensor(batch): 52 | if "color_palette" not in batch or batch["color_palette"] is None: 53 | return torch.zeros_like(batch["image"]) 54 | try: 55 | batch_size = batch["image"].size(0) 56 | color_palette_drawing = [] 57 | for i in range(batch_size): 58 | if "color_palette" in batch: 59 | color_hist = deepcopy(batch["color_palette"]["values"][i]) 60 | color_palette = iutils.vis_color_palette(color_hist, batch["image"].shape[-1]) 61 | else: 62 | color_palette = torch.randn_like(batch["image"][i]) 63 | color_palette_drawing.append(color_palette) 64 | color_palette_tensor = torch.stack(color_palette_drawing) 65 | return color_palette_tensor 66 | except: 67 | print(f">> Exception occured in draw_color_palettes_with_batch_to_tensor(batch)..") 68 | return torch.zeros_like(batch["image"]) 69 | 70 | def draw_image_embedding_with_batch_to_tensor(batch): 71 | if "image_embedding" not in batch or batch["image_embedding"] is None: 72 | return -torch.ones_like(batch["image"]) 73 | else: 74 | if "image" in batch["image_embedding"]: 75 | return batch["image_embedding"]["image"] 76 | else: return -torch.ones_like(batch["image"]) 77 | 78 | 79 | -------------------------------------------------------------------------------- /visualization/extract_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from PIL import Image 5 | from transformers import CLIPProcessor, CLIPModel 6 | 7 | from skimage.color import hsv2rgb, rgb2lab, lab2rgb 8 | from skimage.io import imread, imsave 9 | from sklearn.metrics import euclidean_distances 10 | 11 | 12 | @torch.no_grad() 13 | def get_clip_feature(img_path, model, processor): 14 | # clip_features = dict() 15 | image = Image.open(img_path).convert("RGB") 16 | inputs = processor(images=[image], return_tensors="pt", padding=True) 17 | inputs['pixel_values'] = inputs['pixel_values'].cuda() 18 | inputs['input_ids'] = torch.tensor([[0,1,2,3]]).cuda() 19 | outputs = model(**inputs) 20 | feature = outputs.image_embeds 21 | return feature.squeeze().cpu().numpy() # dim: [768,] 22 | 23 | 24 | 25 | 26 | 27 | '---------------------------------------- color converter ----------------------------------------' 28 | 29 | def rgb2hex(rgb_number): 30 | """ 31 | Args: 32 | - rgb_number (sequence of float) 33 | Returns: 34 | - hex_number (string) 35 | """ 36 | # print(rgb_number, [np.round(val*255) for val in rgb_number]) 37 | return '#{:02x}{:02x}{:02x}'.format(*tuple([int(np.round(val * 255)) for val in rgb_number])) 38 | 39 | 40 | def hex2rgb(hexcolor_str): 41 | """ 42 | Args: 43 | - hexcolor_str (string): e.g. '#ffffff' or '33cc00' 44 | Returns: 45 | - rgb_color (sequence of floats): e.g. (0.2, 0.3, 0) 46 | """ 47 | color = hexcolor_str.strip('#') 48 | # rgb = lambda x: round(int(x, 16) / 255., 5) 49 | return tuple(round(int(color[i:i+2], 16) / 255., 5) for i in (0, 2, 4)) 50 | 51 | 52 | 53 | '---------------------------------------- color palette histogram ----------------------------------------' 54 | 55 | def histogram_colors_smoothed(lab_array, palette, sigma=10, 56 | plot_filename=None, direct=True): 57 | """ 58 | Returns a palette histogram of colors in the image, smoothed with 59 | a Gaussian. Can smooth directly per-pixel, or after computing a strict 60 | histogram. 61 | Parameters 62 | ---------- 63 | lab_array : (N,3) ndarray 64 | The L*a*b color of each of N pixels. 65 | palette : rayleigh.Palette 66 | Containing K colors. 
67 | sigma : float 68 | Variance of the smoothing Gaussian. 69 | direct : bool, optional 70 | If True, constructs a smoothed histogram directly from pixels. 71 | If False, constructs a nearest-color histogram and then smoothes it. 72 | Returns 73 | ------- 74 | color_hist : (K,) ndarray 75 | """ 76 | if direct: 77 | color_hist_smooth = histogram_colors_with_smoothing( 78 | lab_array, palette, sigma) 79 | else: 80 | color_hist_strict = histogram_colors_strict(lab_array, palette) 81 | color_hist_smooth = smooth_histogram(color_hist_strict, palette, sigma) 82 | if plot_filename is not None: 83 | plot_histogram(color_hist_smooth, palette, plot_filename) 84 | return color_hist_smooth 85 | 86 | def smooth_histogram(color_hist, palette, sigma=10): 87 | """ 88 | Smooth the given palette histogram with a Gaussian of variance sigma. 89 | Parameters 90 | ---------- 91 | color_hist : (K,) ndarray 92 | palette : rayleigh.Palette 93 | containing K colors. 94 | Returns 95 | ------- 96 | color_hist_smooth : (K,) ndarray 97 | """ 98 | n = 2. * sigma ** 2 99 | weights = np.exp(-palette.distances / n) 100 | norm_weights = weights / weights.sum(1)[:, np.newaxis] 101 | color_hist_smooth = (norm_weights * color_hist).sum(1) 102 | color_hist_smooth[color_hist_smooth < 1e-5] = 0 103 | return color_hist_smooth 104 | 105 | def histogram_colors_with_smoothing(lab_array, palette, sigma=10): 106 | """ 107 | Assign colors in the image to nearby colors in the palette, weighted by 108 | distance in Lab color space. 109 | Parameters 110 | ---------- 111 | lab_array (N,3) ndarray: 112 | N is the number of data points, columns are L, a, b values. 113 | palette : rayleigh.Palette 114 | containing K colors. 115 | sigma : float 116 | (0,1] value to control the steepness of exponential falloff. 117 | To see the effect: 118 | >>> from pylab import * 119 | >>> ds = linspace(0,5000) # squared distance 120 | >>> sigma=10; plot(ds, exp(-ds/(2*sigma**2)), label='$\sigma=%.1f$'%sigma) 121 | >>> sigma=20; plot(ds, exp(-ds/(2*sigma**2)), label='$\sigma=%.1f$'%sigma) 122 | >>> sigma=40; plot(ds, exp(-ds/(2*sigma**2)), label='$\sigma=%.1f$'%sigma) 123 | >>> ylim([0,1]); legend(); 124 | >>> xlabel('Squared distance'); ylabel('Weight'); 125 | >>> title('Exponential smoothing') 126 | >>> #plt.savefig('exponential_smoothing.png', dpi=300) 127 | sigma=20 seems reasonable: hits 0 around squared distance of 4000. 128 | Returns: 129 | color_hist : (K,) ndarray 130 | the normalized, smooth histogram of colors. 131 | """ 132 | dist = euclidean_distances(palette.lab_array, lab_array, squared=True).T 133 | n = 2. * sigma ** 2 134 | weights = np.exp(-dist / n) 135 | 136 | # normalize by sum: if a color is equally well represented by several colors 137 | # it should not contribute much to the overall histogram 138 | normalizing = weights.sum(1) 139 | normalizing[normalizing == 0] = 1e16 140 | normalized_weights = weights / normalizing[:, np.newaxis] 141 | 142 | color_hist = normalized_weights.sum(0) 143 | color_hist /= lab_array.shape[0] 144 | color_hist[color_hist < 1e-5] = 0 145 | return color_hist 146 | 147 | def histogram_colors_strict(lab_array, palette, plot_filename=None): 148 | """ 149 | Return a palette histogram of colors in the image. 150 | Parameters 151 | ---------- 152 | lab_array : (N,3) ndarray 153 | The L*a*b color of each of N pixels. 154 | palette : rayleigh.Palette 155 | Containing K colors. 156 | plot_filename : string, optional 157 | If given, save histogram to this filename. 
158 | Returns 159 | ------- 160 | color_hist : (K,) ndarray 161 | """ 162 | # This is the fastest way that I've found. 163 | # >>> %%timeit -n 200 from sklearn.metrics import euclidean_distances 164 | # >>> euclidean_distances(palette, lab_array, squared=True) 165 | dist = euclidean_distances(palette.lab_array, lab_array, squared=True).T 166 | min_ind = np.argmin(dist, axis=1) 167 | num_colors = palette.lab_array.shape[0] 168 | num_pixels = lab_array.shape[0] 169 | color_hist = 1. * np.bincount(min_ind, minlength=num_colors) / num_pixels 170 | if plot_filename is not None: 171 | plot_histogram(color_hist, palette, plot_filename) 172 | return color_hist 173 | 174 | def plot_histogram(color_hist, palette, plot_filename=None): 175 | """ 176 | Return Figure containing the color palette histogram. 177 | Args: 178 | - color_hist (K, ndarray) 179 | - palette (Palette) 180 | - plot_filename (string) [default=None]: 181 | Save histogram to this file, if given. 182 | Returns: 183 | - fig (Figure) 184 | """ 185 | fig = plt.figure(figsize=(5, 3), dpi=150) 186 | ax = fig.add_subplot(111) 187 | ax.bar( 188 | range(len(color_hist)), color_hist, 189 | color=palette.hex_list) 190 | ax.set_ylim((0, 0.1)) 191 | ax.xaxis.set_ticks([]) 192 | ax.set_xlim((0, len(palette.hex_list))) 193 | if plot_filename: 194 | fig.savefig(plot_filename, dpi=150, facecolor='none') 195 | return fig 196 | 197 | 198 | '---------------------------------------- histogram to image ----------------------------------------' 199 | # only for visualization 200 | # for extracting color palette, color histogram is enough. 201 | 202 | def color_hist_to_palette_image(color_hist, palette, percentile=90, 203 | width=200, height=50, filename=None): 204 | """ 205 | Output the main colors in the histogram to a "palette image." 206 | Parameters 207 | ---------- 208 | color_hist : (K,) ndarray 209 | palette : rayleigh.Palette 210 | percentile : int, optional: 211 | Output only colors above this percentile of prevalence in the histogram. 212 | filename : string, optional: 213 | If given, save the resulting image to file. 214 | Returns 215 | ------- 216 | rgb_image : ndarray 217 | """ 218 | ind = np.argsort(-color_hist) 219 | ind = ind[color_hist[ind] > np.percentile(color_hist, percentile)] 220 | hex_list = np.take(palette.hex_list, ind) 221 | values = color_hist[ind] 222 | rgb_image = palette_query_to_rgb_image(dict(zip(hex_list, values))) 223 | if filename: 224 | imsave(filename, rgb_image) 225 | return rgb_image 226 | 227 | 228 | def palette_query_to_rgb_image(palette_query, width=200, height=50): 229 | """ 230 | Convert a list of hex colors and their values to an RGB image of given 231 | width and height. 232 | Args: 233 | - palette_query (dict): 234 | a dictionary of hex colors to unnormalized values, 235 | e.g. {'#ffffff': 20, '#33cc00': 0.4}. 
236 | """ 237 | hex_list, values = zip(*palette_query.items()) 238 | values = np.array(values) 239 | values /= values.sum() 240 | nums = np.array(values * width, dtype=int) 241 | rgb_arrays = (np.tile(np.array(hex2rgb(x)), (num, 1)) 242 | for x, num in zip(hex_list, nums)) 243 | rgb_array = np.vstack(rgb_arrays) 244 | rgb_image = rgb_array[np.newaxis, :, :] 245 | rgb_image = np.tile(rgb_image, (height, 1, 1)) 246 | return rgb_image 247 | 248 | 249 | 250 | 251 | '---------------------------------------- Color Palette ----------------------------------------' 252 | 253 | class Palette(object): 254 | """ 255 | Create a color palette (codebook) in the form of a 2D grid of colors, 256 | as described in the parameters list below. 257 | Further, the rightmost column has num_hues gradations from black to white. 258 | Parameters 259 | ---------- 260 | num_hues : int 261 | number of colors with full lightness and saturation, in the middle 262 | sat_range : int 263 | number of rows above middle row that show 264 | the same hues with decreasing saturation. 265 | light_range : int 266 | number of rows below middle row that show 267 | the same hues with decreasing lightness. 268 | Returns 269 | ------- 270 | palette: rayleigh.Palette 271 | """ 272 | 273 | def __init__(self, num_hues=8, sat_range=2, light_range=2): 274 | height = 1 + sat_range + (2 * light_range - 1) 275 | # generate num_hues+1 hues, but don't take the last one: 276 | # hues are on a circle, and we would be oversampling the origin 277 | hues = np.tile(np.linspace(0, 1, num_hues + 1)[:-1], (height, 1)) 278 | if num_hues == 8: 279 | hues = np.tile(np.array( 280 | [0., 0.10, 0.15, 0.28, 0.51, 0.58, 0.77, 0.85]), (height, 1)) 281 | if num_hues == 9: 282 | hues = np.tile(np.array( 283 | [0., 0.10, 0.15, 0.28, 0.49, 0.54, 0.60, 0.7, 0.87]), (height, 1)) 284 | if num_hues == 10: 285 | hues = np.tile(np.array( 286 | [0., 0.10, 0.15, 0.28, 0.49, 0.54, 0.60, 0.66, 0.76, 0.87]), (height, 1)) 287 | elif num_hues == 11: 288 | hues = np.tile(np.array( 289 | [0.0, 0.0833, 0.166, 0.25, 290 | 0.333, 0.5, 0.56333, 291 | 0.666, 0.73, 0.803, 292 | 0.916]), (height, 1)) 293 | 294 | sats = np.hstack(( 295 | np.linspace(0, 1, sat_range + 2)[1:-1], 296 | 1, 297 | [1] * (light_range), 298 | [.4] * (light_range - 1), 299 | )) 300 | lights = np.hstack(( 301 | [1] * sat_range, 302 | 1, 303 | np.linspace(1, 0.2, light_range + 2)[1:-1], 304 | np.linspace(1, 0.2, light_range + 2)[1:-2], 305 | )) 306 | 307 | sats = np.tile(np.atleast_2d(sats).T, (1, num_hues)) 308 | lights = np.tile(np.atleast_2d(lights).T, (1, num_hues)) 309 | colors = hsv2rgb(np.dstack((hues, sats, lights))) 310 | grays = np.tile( 311 | np.linspace(1, 0, height)[:, np.newaxis, np.newaxis], (1, 1, 3)) 312 | 313 | self.rgb_image = np.hstack((colors, grays)) 314 | 315 | # Make a nice histogram ordering of the hues and grays 316 | h, w, d = colors.shape 317 | color_array = colors.T.reshape((d, w * h)).T 318 | h, w, d = grays.shape 319 | gray_array = grays.T.reshape((d, w * h)).T 320 | 321 | self.rgb_array = np.vstack((color_array, gray_array)) 322 | self.lab_array = rgb2lab(self.rgb_array[None, :, :]).squeeze() 323 | self.hex_list = [rgb2hex(row) for row in self.rgb_array] 324 | #assert(np.all(self.rgb_array == self.rgb_array[None, :, :].squeeze())) 325 | 326 | self.distances = euclidean_distances(self.lab_array, squared=True) 327 | 328 | 329 | '---------------------------------------- Image Wrapper for lab_array ----------------------------------------' 330 | 331 | class ColorImage(object): 332 | """ 
333 | Read the image at the URL in RGB format, downsample if needed, 334 | and convert to Lab colorspace. 335 | Store original dimensions, resize_factor, and the filename of the image. 336 | Image dimensions will be resized independently such that neither width nor 337 | height exceed the maximum allowed dimension MAX_DIMENSION. 338 | Parameters 339 | ---------- 340 | url : string 341 | URL or file path of the image to load. 342 | id : string, optional 343 | Name or some other id of the image. For example, the Flickr ID. 344 | """ 345 | 346 | MAX_DIMENSION = 240 + 1 347 | 348 | def __init__(self, url, _id=None): 349 | self.id = _id 350 | self.url = url 351 | img = imread(url) 352 | 353 | # Handle grayscale and RGBA images. 354 | # TODO: Should be smarter here in the future, but for now simply remove 355 | # the alpha channel if present. 356 | if img.ndim == 2: 357 | img = np.tile(img[:, :, np.newaxis], (1, 1, 3)) 358 | elif img.ndim == 4: 359 | img = img[:, :, :3] 360 | elif img.ndim == 3 and img.shape[2] == 4: 361 | img = img[:, :, :3] 362 | 363 | # Downsample for speed. 364 | # 365 | # NOTE: I can't find a good method to resize properly in Python! 366 | # scipy.misc.imresize uses PIL, which needs 8bit data. 367 | # Anyway, this is faster and almost as good. 368 | # 369 | # >>> def d(dim, max_dim): return arange(0, dim, dim / max_dim + 1).shape 370 | # >>> plot(range(1200), [d(x, 200) for x in range(1200)]) 371 | h, w, d = tuple(img.shape) 372 | self.orig_h, self.orig_w, self.orig_d = tuple(img.shape) 373 | h_stride = h // self.MAX_DIMENSION + 1 374 | w_stride = w // self.MAX_DIMENSION + 1 375 | img = img[::h_stride, ::w_stride, :] 376 | 377 | # Convert to L*a*b colors. 378 | h, w, d = img.shape 379 | self.h, self.w, self.d = img.shape 380 | self.lab_array = rgb2lab(img).reshape((h * w, d)) 381 | 382 | def as_dict(self): 383 | """ 384 | Return relevant info about self in a dict. 385 | """ 386 | return {'id': self.id, 'url': self.url, 387 | 'resized_width': self.w, 'resized_height': self.h, 388 | 'width': self.orig_w, 'height': self.orig_h} 389 | 390 | def output_quantized_to_palette(self, palette, filename): 391 | """ 392 | Save to filename a version of the image with all colors quantized 393 | to the nearest color in the given palette. 394 | Parameters 395 | ---------- 396 | palette : rayleigh.Palette 397 | Containing K colors. 398 | filename : string 399 | Where image will be written. 
400 | """ 401 | dist = euclidean_distances( 402 | palette.lab_array, self.lab_array, squared=True).T 403 | min_ind = np.argmin(dist, axis=1) 404 | quantized_lab_array = palette.lab_array[min_ind, :] 405 | img = lab2rgb(quantized_lab_array.reshape((self.h, self.w, self.d))) 406 | imsave(filename, img) 407 | 408 | 409 | def get_color_palette(img_path): 410 | palette = Palette(num_hues=11, sat_range=5, light_range=5) 411 | assert len(palette.hex_list) == 180 412 | query_img = ColorImage(url=img_path) 413 | color_hist = histogram_colors_smoothed(query_img.lab_array, palette, sigma=10, direct=False) 414 | return color_hist 415 | 416 | 417 | @torch.no_grad() 418 | def get_box(locations, phrases, model, processor): 419 | boxes = torch.zeros(30, 4) 420 | masks = torch.zeros(30) 421 | text_embeddings = torch.zeros(30, 768) 422 | 423 | text_features = [] 424 | image_features = [] 425 | for phrase in phrases: 426 | inputs = processor(text=phrase, return_tensors="pt", padding=True) 427 | inputs['input_ids'] = inputs['input_ids'].cuda() 428 | inputs['pixel_values'] = torch.ones(1,3,224,224).cuda() # placeholder 429 | inputs['attention_mask'] = inputs['attention_mask'].cuda() 430 | outputs = model(**inputs) 431 | feature = outputs.text_model_output.pooler_output 432 | text_features.append(feature.squeeze()) 433 | 434 | for idx, (box, text_feature) in enumerate(zip(locations, text_features)): 435 | boxes[idx] = torch.tensor(box) 436 | masks[idx] = 1 437 | text_embeddings[idx] = text_feature 438 | 439 | return boxes, masks, text_embeddings 440 | 441 | 442 | def get_keypoint(locations): 443 | points = torch.zeros(8*17,2) 444 | idx = 0 445 | for this_person_kp in locations: 446 | for kp in this_person_kp: 447 | points[idx,0] = kp[0] 448 | points[idx,1] = kp[1] 449 | idx += 1 450 | 451 | # derive masks from points 452 | masks = (points.mean(dim=1)!=0) * 1 453 | masks = masks.float() 454 | return points, masks 455 | 456 | 457 | def transform_image_from_pil_to_numpy(pil_image): 458 | np_image, trans_info = center_crop_arr(pil_image, image_size=512) 459 | 460 | if np_image.ndim == 2: # when we load the image with "L" option 461 | np_image = np.expand_dims(np_image, axis=2) 462 | return np_image, trans_info 463 | 464 | def flip_image_from_numpy_to_numpy(np_image, trans_info, is_flip=False): 465 | if is_flip: 466 | np_image = np_image[:, ::-1] 467 | trans_info["performed_flip"] = True 468 | return np_image, trans_info 469 | else: 470 | return np_image, trans_info 471 | 472 | def convert_image_from_numpy_to_tensor(np_image, type_mask=False): 473 | if type_mask: 474 | """ 475 | value range : (0, 1), for sketch, segm, mask, 476 | """ 477 | np_image = np_image.astype(np.float32) / 255.0 478 | else: 479 | """ 480 | value range : (-1, 1), for rgb 481 | """ 482 | np_image = np_image.astype(np.float32) / 127.5 - 1 483 | np_image = np.transpose(np_image, [2,0,1]) 484 | return torch.tensor(np_image) 485 | 486 | def invert_image_from_numpy_to_numpy(np_image): 487 | return 255.0 - np_image 488 | 489 | def center_crop_arr(pil_image, image_size): 490 | # We are not on a new enough PIL to support the `reducing_gap` 491 | # argument, which uses BOX downsampling at powers of two first. 492 | # Thus, we do it by hand to improve downsample quality. 
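    # --- Added explanatory comments (not in the original source) ---
    # center_crop_arr works in three steps: (1) repeatedly halve the image with
    # BOX resampling while its short side is still >= 2*image_size, (2) apply one
    # bicubic resize so the short side becomes exactly image_size, and (3) take a
    # centered image_size x image_size crop. Note that `performed_scale` is
    # recorded relative to the ORIGINAL size, i.e. image_size / min(WW, HH),
    # not to the pre-halved size.
    # Worked example with image_size=512: a 1600x1200 input is halved once to
    # 800x600 (short side 600 < 1024 stops the loop), bicubic-resized to 683x512,
    # then cropped with crop_y=0 and crop_x=(683-512)//2=85; performed_scale=512/1200.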
493 | WW, HH = pil_image.size 494 | 495 | while min(*pil_image.size) >= 2 * image_size: 496 | pil_image = pil_image.resize( 497 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 498 | ) 499 | 500 | scale = image_size / min(*pil_image.size) 501 | 502 | pil_image = pil_image.resize( 503 | tuple(round(x * scale) for x in pil_image.size), resample=Image.Resampling.BICUBIC 504 | ) 505 | 506 | # at this point, the min of pil_image side is desired image_size 507 | performed_scale = image_size / min(WW, HH) 508 | 509 | arr = np.array(pil_image) 510 | crop_y = (arr.shape[0] - image_size) // 2 511 | crop_x = (arr.shape[1] - image_size) // 2 512 | 513 | info = {"performed_scale":performed_scale, 'crop_y':crop_y, 'crop_x':crop_x, "WW":WW, 'HH':HH, "performed_flip": False} 514 | 515 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size], info 516 | 517 | 518 | def get_sketch(img_path): 519 | pil_sketch = Image.open(img_path).convert('RGB') 520 | np_sketch,_ = transform_image_from_pil_to_numpy(pil_sketch) 521 | np_sketch = invert_image_from_numpy_to_numpy(np_sketch) 522 | sketch_tensor = convert_image_from_numpy_to_tensor(np_sketch, type_mask=True) # mask type: range 0 to 1 523 | return sketch_tensor 524 | 525 | 526 | def get_depth(img_path): 527 | pil_depth = Image.open(img_path).convert('RGB') 528 | np_depth,_ = transform_image_from_pil_to_numpy(pil_depth) 529 | depth_tensor = convert_image_from_numpy_to_tensor(np_depth, type_mask=True) # mask type: range 0 to 1 530 | return depth_tensor 531 | 532 | -------------------------------------------------------------------------------- /visualization/image_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from os.path import join 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torchvision.utils import save_image, make_grid 8 | import torchvision 9 | 10 | from PIL import Image, ImageDraw 11 | from skimage.color import hsv2rgb, rgb2lab, lab2rgb 12 | 13 | 14 | PP_RGB = 1 15 | PP_SEGM = 2 16 | 17 | def convert_zero2one(image_tensor): 18 | return (image_tensor.detach().cpu()+1)/2.0 19 | 20 | def postprocess(image_tensors, pp_type): 21 | """ 22 | image_tensors: 23 | (B x C=3 x H x W) 24 | torch.tensors 25 | pp_type: 26 | int 27 | """ 28 | if pp_type == PP_RGB: 29 | return convert_zero2one(image_tensors) 30 | elif pp_type == PP_SEGM: 31 | if image_tensors.size(1) == 3: 32 | return image_tensors.detach().cpu() 33 | else: 34 | if image_tensors.ndim == 4: 35 | return image_tensors.detach().cpu().repeat(1, 3, 1, 1) 36 | elif image_tensors.ndim == 5: 37 | return image_tensors.detach().cpu().repeat(1, 1, 3, 1, 1) 38 | else: 39 | raise NotImplementedError 40 | else: 41 | raise NotImplementedError 42 | 43 | 44 | def get_scale_factor(image_tensors, feature_tensors): 45 | """ 46 | image_tensors: 47 | (B x C=3 x H x W) 48 | feature_tensors: 49 | (B x C x h x w) 50 | """ 51 | B, _, H, W = image_tensors.size() 52 | _, _, h, w = feature_tensors.size() 53 | scale_factor = int(H / h) 54 | return scale_factor 55 | 56 | def do_scale(image_tensors, feature_tensors): 57 | """ 58 | image_tensors: 59 | (B x C=3 x H x W) 60 | feature_tensors: 61 | (B x C x h x w) 62 | """ 63 | scale_factor = get_scale_factor(image_tensors, feature_tensors) 64 | scaled_tensor = F.interpolate(feature_tensors, scale_factor=scale_factor) 65 | return scaled_tensor 66 | 67 | def save_images_from_dict( 68 | image_dict, dir_path, file_name, n_instance, 69 | is_save=False, 
save_per_instance=False, return_images=False, 70 | save_per_instance_idxs=[]): 71 | """ 72 | image_dict: 73 | [ 74 | { 75 | "tensors": tensors:tensor, 76 | "n_in_row": int, 77 | "pp_type": int 78 | }, 79 | ... 80 | """ 81 | 82 | n_row = 0 83 | 84 | for each_item in image_dict: 85 | tensors = each_item["tensors"] 86 | bs = tensors.size(0) 87 | n_instance = min(bs, n_instance) 88 | 89 | n_in_row = each_item["n_in_row"] 90 | n_row += n_in_row 91 | 92 | pp_type = each_item["pp_type"] 93 | post_tensor = postprocess(tensors, pp_type=pp_type) 94 | each_item["tensors"] = torch.clamp(post_tensor, min=0, max=1) 95 | 96 | if save_per_instance: 97 | for i in range(n_instance): 98 | image_list = [] 99 | for each_item in image_dict: 100 | if each_item["n_in_row"] == 1: 101 | image_list.append(each_item["tensors"][i].unsqueeze(0)) 102 | else: 103 | for j in range(each_item["n_in_row"]): 104 | image_list.append(each_item["tensors"][i, j].unsqueeze(0)) 105 | images = torch.cat(image_list, dim=0) 106 | if len(save_per_instance_idxs) > 0: 107 | save_path = join(dir_path, f"{file_name}_{save_per_instance_idxs[i]}.png") 108 | else: 109 | save_path = join(dir_path, f"{file_name}_{i}.png") 110 | if is_save: 111 | save_image(images, save_path, padding=0, pad_value=0.5, nrow=n_row) 112 | else: 113 | save_path = join(dir_path, f"{file_name}.png") 114 | image_list = [] 115 | for i in range(n_instance): 116 | for each_item in image_dict: 117 | if each_item["n_in_row"] == 1: 118 | image_list.append(each_item["tensors"][i].unsqueeze(0)) 119 | else: 120 | for j in range(each_item["n_in_row"]): 121 | image_list.append(each_item["tensors"][i, j].unsqueeze(0)) 122 | 123 | images = torch.cat(image_list, dim=0) 124 | concated_image = make_grid(images, padding=2, pad_value=0.5, nrow=n_row) 125 | if is_save: 126 | save_image(concated_image, save_path, nrow=1) 127 | 128 | if return_images: 129 | return concated_image 130 | 131 | ### =========================================================================== ### 132 | """ 133 | functions to visualize boxes 134 | """ 135 | def draw_box(img, boxes): 136 | colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"] 137 | draw = ImageDraw.Draw(img) 138 | for bid, box in enumerate(boxes): 139 | draw.rectangle([box[0], box[1], box[2], box[3]], outline=colors[bid % len(colors)], width=4) 140 | # draw.rectangle([box[0], box[1], box[2], box[3]], outline ="red", width=2) # x0 y0 x1 y1 141 | return img 142 | 143 | def vis_boxes(info_dict): 144 | 145 | device = info_dict["image"].device 146 | # img = torchvision.transforms.functional.to_pil_image( info_dict["image"]*0.5+0.5 ) 147 | canvas = torchvision.transforms.functional.to_pil_image( torch.zeros_like(info_dict["image"]) ) # unused 148 | W, H = canvas.size 149 | 150 | boxes = [] 151 | for box in info_dict["boxes"]: 152 | x0,y0,x1,y1 = box 153 | boxes.append( [float(x0*W), float(y0*H), float(x1*W), float(y1*H)] ) 154 | canvas = draw_box(canvas, boxes) 155 | 156 | return torchvision.transforms.functional.to_tensor(canvas).to(device) 157 | ### =========================================================================== ### 158 | 159 | ### =========================================================================== ### 160 | """ 161 | functions to visualize keypoints 162 | """ 163 | def draw_points(img, points): 164 | colors = ["red", "yellow", "blue", "green", "orange", "brown", "cyan", "purple", "deeppink", "coral", "gold", "darkblue", "khaki", "lightgreen", "snow", "yellowgreen", "lime"] 165 | colors = colors * 100 166 | 
draw = ImageDraw.Draw(img) 167 | 168 | r = 3 169 | for point, color in zip(points, colors): 170 | if point[0] == point[1] == 0: 171 | pass 172 | else: 173 | x, y = float(point[0]), float(point[1]) 174 | draw.ellipse( [ (x-r,y-r), (x+r,y+r) ], fill=color ) 175 | # draw.rectangle([box[0], box[1], box[2], box[3]], outline ="red", width=2) # x0 y0 x1 y1 176 | return img 177 | 178 | def vis_keypoints(info_dict): 179 | 180 | device = info_dict["image"].device 181 | # img = torchvision.transforms.functional.to_pil_image( info_dict["image"]*0.5+0.5 ) 182 | _, H, W = info_dict["image"].size() 183 | canvas = torchvision.transforms.functional.to_pil_image( torch.ones_like(info_dict["image"])*0.2 ) 184 | assert W==H 185 | img = draw_points( canvas, info_dict["points"]*W ) 186 | 187 | return torchvision.transforms.functional.to_tensor(img).to(device) 188 | ### =========================================================================== ### 189 | 190 | 191 | 192 | ### ================================ Color palette-related functions ================================== ### 193 | 194 | 195 | class Palette(object): 196 | """ 197 | Create a color palette (codebook) in the form of a 2D grid of colors, 198 | as described in the parameters list below. 199 | Further, the rightmost column has num_hues gradations from black to white. 200 | Parameters 201 | ---------- 202 | num_hues : int 203 | number of colors with full lightness and saturation, in the middle 204 | sat_range : int 205 | number of rows above middle row that show 206 | the same hues with decreasing saturation. 207 | light_range : int 208 | number of rows below middle row that show 209 | the same hues with decreasing lightness. 210 | Returns 211 | ------- 212 | palette: rayleigh.Palette 213 | """ 214 | 215 | def __init__(self, num_hues=8, sat_range=2, light_range=2): 216 | height = 1 + sat_range + (2 * light_range - 1) 217 | # generate num_hues+1 hues, but don't take the last one: 218 | # hues are on a circle, and we would be oversampling the origin 219 | hues = np.tile(np.linspace(0, 1, num_hues + 1)[:-1], (height, 1)) 220 | if num_hues == 8: 221 | hues = np.tile(np.array( 222 | [0., 0.10, 0.15, 0.28, 0.51, 0.58, 0.77, 0.85]), (height, 1)) 223 | if num_hues == 9: 224 | hues = np.tile(np.array( 225 | [0., 0.10, 0.15, 0.28, 0.49, 0.54, 0.60, 0.7, 0.87]), (height, 1)) 226 | if num_hues == 10: 227 | hues = np.tile(np.array( 228 | [0., 0.10, 0.15, 0.28, 0.49, 0.54, 0.60, 0.66, 0.76, 0.87]), (height, 1)) 229 | elif num_hues == 11: 230 | hues = np.tile(np.array( 231 | [0.0, 0.0833, 0.166, 0.25, 232 | 0.333, 0.5, 0.56333, 233 | 0.666, 0.73, 0.803, 234 | 0.916]), (height, 1)) 235 | 236 | sats = np.hstack(( 237 | np.linspace(0, 1, sat_range + 2)[1:-1], 238 | 1, 239 | [1] * (light_range), 240 | [.4] * (light_range - 1), 241 | )) 242 | lights = np.hstack(( 243 | [1] * sat_range, 244 | 1, 245 | np.linspace(1, 0.2, light_range + 2)[1:-1], 246 | np.linspace(1, 0.2, light_range + 2)[1:-2], 247 | )) 248 | 249 | sats = np.tile(np.atleast_2d(sats).T, (1, num_hues)) 250 | lights = np.tile(np.atleast_2d(lights).T, (1, num_hues)) 251 | colors = hsv2rgb(np.dstack((hues, sats, lights))) 252 | grays = np.tile( 253 | np.linspace(1, 0, height)[:, np.newaxis, np.newaxis], (1, 1, 3)) 254 | 255 | self.rgb_image = np.hstack((colors, grays)) 256 | 257 | # Make a nice histogram ordering of the hues and grays 258 | h, w, d = colors.shape 259 | color_array = colors.T.reshape((d, w * h)).T 260 | h, w, d = grays.shape 261 | gray_array = grays.T.reshape((d, w * h)).T 262 | 263 | 
self.rgb_array = np.vstack((color_array, gray_array)) 264 | self.lab_array = rgb2lab(self.rgb_array[None, :, :]).squeeze() 265 | self.hex_list = [rgb2hex(row) for row in self.rgb_array] 266 | 267 | def output(self, dirname): 268 | """ 269 | Output an image of the palette, josn list of the hex 270 | colors, and an HTML color picker for it. 271 | Parameters 272 | ---------- 273 | dirname : string 274 | directory for the files to be output 275 | """ 276 | pass # we do not need this for visualization 277 | 278 | def color_hist_to_palette_image(color_hist, palette, percentile=90, 279 | width=200, height=50, filename=None): 280 | """ 281 | Output the main colors in the histogram to a "palette image." 282 | Parameters 283 | ---------- 284 | color_hist : (K,) ndarray 285 | palette : rayleigh.Palette 286 | percentile : int, optional: 287 | Output only colors above this percentile of prevalence in the histogram. 288 | filename : string, optional: 289 | If given, save the resulting image to file. 290 | Returns 291 | ------- 292 | rgb_image : ndarray 293 | """ 294 | ind = np.argsort(-color_hist) 295 | ind = ind[color_hist[ind] > np.percentile(color_hist, percentile)] 296 | hex_list = np.take(palette.hex_list, ind) 297 | values = color_hist[ind] 298 | rgb_image = palette_query_to_rgb_image(dict(zip(hex_list, values)), width, height) 299 | if filename: 300 | imsave(filename, rgb_image) 301 | return rgb_image 302 | 303 | def palette_query_to_rgb_image(palette_query, width=200, height=50): 304 | """ 305 | Convert a list of hex colors and their values to an RGB image of given 306 | width and height. 307 | Args: 308 | - palette_query (dict): 309 | a dictionary of hex colors to unnormalized values, 310 | e.g. {'#ffffff': 20, '#33cc00': 0.4}. 311 | """ 312 | hex_list, values = zip(*palette_query.items()) 313 | values = np.array(values) 314 | values /= values.sum() 315 | nums = np.array(values * width, dtype=int) 316 | rgb_arrays = (np.tile(np.array(hex2rgb(x)), (num, 1)) 317 | for x, num in zip(hex_list, nums)) 318 | rgb_array = np.vstack(list(rgb_arrays)) 319 | rgb_image = rgb_array[np.newaxis, :, :] 320 | rgb_image = np.tile(rgb_image, (height, 1, 1)) 321 | return rgb_image 322 | 323 | def rgb2hex(rgb_number): 324 | """ 325 | Args: 326 | - rgb_number (sequence of float) 327 | Returns: 328 | - hex_number (string) 329 | """ 330 | return '#{:02x}{:02x}{:02x}'.format(*tuple([int(np.round(val * 255)) for val in rgb_number])) 331 | 332 | def hex2rgb(hexcolor_str): 333 | """ 334 | Args: 335 | - hexcolor_str (string): e.g. '#ffffff' or '33cc00' 336 | Returns: 337 | - rgb_color (sequence of floats): e.g. (0.2, 0.3, 0) 338 | """ 339 | color = hexcolor_str.strip('#') 340 | return tuple(round(int(color[i:i+2], 16) / 255., 5) for i in (0, 2, 4)) 341 | 342 | def vis_color_palette(color_hist, shape): 343 | if color_hist.sum() == 0: 344 | color_hist[-1] = 1.0 345 | palette = Palette(num_hues=11, sat_range=5, light_range=5) 346 | color_palette = color_hist_to_palette_image(color_hist.cpu().numpy(), palette, percentile=90) 347 | color_palette = torch.tensor(color_palette.transpose(2,0,1)).unsqueeze(0) 348 | color_palette = F.interpolate(color_palette, size=(shape, shape), mode='nearest') 349 | return color_palette.squeeze(0) 350 | 351 | --------------------------------------------------------------------------------
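The snippet below is a minimal usage sketch, not part of the repository: it shows how the color-palette utilities above fit together, extracting the 180-bin smoothed palette histogram of an image with visualization/extract_utils.py and rendering it as an image tensor with visualization/image_utils.py. It assumes the repo root is on PYTHONPATH, that the listed sample images (images/nature.png, images/partial_sketch.png) are available, that scikit-learn is installed (extract_utils imports sklearn.metrics.euclidean_distances, which is not pinned in requirements.txt), and that the installed Pillow provides Image.Resampling (>= 9.1), since center_crop_arr uses Image.Resampling.BICUBIC.

# Hedged usage sketch (assumptions noted above); the image paths are illustrative.
import torch

from visualization.extract_utils import get_color_palette, get_sketch
from visualization.image_utils import vis_color_palette

# 180-bin smoothed palette histogram (numpy array) of the query image.
color_hist = torch.tensor(get_color_palette("images/nature.png"), dtype=torch.float32)

# Render the dominant bins (above the 90th percentile) as a (3, 512, 512) strip
# of colors in [0, 1], suitable for logging next to generated samples.
palette_vis = vis_color_palette(color_hist, 512)

# Sketch conditions are loaded the same way: center-cropped to 512x512, inverted,
# and scaled to [0, 1] as a (3, 512, 512) tensor.
sketch = get_sketch("images/partial_sketch.png")

print(palette_vis.shape, sketch.shape)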
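Similarly, a minimal sketch of the nested batch layout consumed by visualization/draw_utils.py: each grounding condition sits under its own key with the tensor stored at "values". This layout is inferred from how the draw functions index the dict; the keys, batch size, and shapes below are illustrative, with boxes and keypoints in normalized coordinates.

# Hedged sketch of the batch dict expected by draw_utils (shapes are illustrative).
import torch

from visualization.draw_utils import (
    draw_boxes_with_batch_to_tensor,
    draw_color_palettes_with_batch_to_tensor,
    draw_keypoints_with_batch_to_tensor,
)

B, H = 2, 512
boxes = torch.zeros(B, 30, 4)                  # up to 30 boxes per sample, normalized x0 y0 x1 y1
boxes[:, 0] = torch.tensor([0.1, 0.2, 0.6, 0.9])

batch = {
    "image": torch.zeros(B, 3, H, H),                     # RGB images in [-1, 1]
    "box": {"values": boxes},
    "keypoint": {"values": torch.zeros(B, 8 * 17, 2)},    # (x, y) per joint; all-zero rows are skipped
    "color_palette": {"values": torch.zeros(B, 180)},     # all-zero histograms fall back to the black bin
}

# Each call returns a (B, 3, H, H) tensor that can be tiled into a logging grid.
box_vis = draw_boxes_with_batch_to_tensor(batch)
kp_vis = draw_keypoints_with_batch_to_tensor(batch)
palette_vis = draw_color_palettes_with_batch_to_tensor(batch)
print(box_vis.shape, kp_vis.shape, palette_vis.shape)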