├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── fig1.png
├── grounding_input
│   └── condition_null_generator.py
├── images
│   ├── color1.pth
│   ├── color2.pth
│   ├── dummy.pth
│   ├── fire.png
│   ├── jeep_depth.png
│   ├── jeep_sketch.png
│   ├── nature.png
│   └── partial_sketch.png
├── inference.py
├── ldm
│   ├── lr_scheduler.py
│   ├── models
│   │   ├── autoencoder.py
│   │   └── diffusion
│   │       ├── __init__.py
│   │       ├── classifier.py
│   │       ├── ddim.py
│   │       ├── ddpm.py
│   │       ├── ldm.py
│   │       ├── plms.py
│   │       └── plmsg.py
│   ├── modules
│   │   ├── attention.py
│   │   ├── diffusionmodules
│   │   │   ├── __init__.py
│   │   │   ├── condition_net.py
│   │   │   ├── model.py
│   │   │   ├── multimodal_openaimodel.py
│   │   │   └── util.py
│   │   ├── distributions
│   │   │   ├── __init__.py
│   │   │   └── distributions.py
│   │   ├── ema.py
│   │   ├── encoders
│   │   │   ├── __init__.py
│   │   │   └── modules.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── contperceptual.py
│   │   │   └── vqperceptual.py
│   │   ├── multimodal_attention.py
│   │   └── x_transformer.py
│   └── util.py
├── requirements.txt
└── visualization
    ├── draw_utils.py
    ├── extract_utils.py
    └── image_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.png
2 | *.jpg
3 | *.pth
4 | *.ckpt
5 | *.npy
6 | OUTPUT/
7 | wandb/
8 | inference/
9 | sampling/
10 |
11 | # IntelliJ project files
12 | .idea
13 | *.iml
14 | out
15 | gen
16 |
17 | ### Vim template
18 | [._]*.s[a-w][a-z]
19 | [._]s[a-w][a-z]
20 | *.un~
21 | Session.vim
22 | .netrwhist
23 | *~
24 |
25 | ### IPythonNotebook template
26 | # Temporary data
27 | .ipynb_checkpoints/
28 |
29 | ### Python template
30 | # Byte-compiled / optimized / DLL files
31 | __pycache__/
32 | *.py[cod]
33 | *$py.class
34 |
35 | # C extensions
36 | *.so
37 |
38 | # Distribution / packaging
39 | .Python
40 | env/
41 | build/
42 | develop-eggs/
43 | dist/
44 | downloads/
45 | eggs/
46 | .eggs/
47 | #lib/
48 | #lib64/
49 | parts/
50 | sdist/
51 | var/
52 | *.egg-info/
53 | .installed.cfg
54 | *.egg
55 |
56 | # PyInstaller
57 | # Usually these files are written by a python script from a template
58 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
59 | *.manifest
60 | *.spec
61 |
62 | # Installer logs
63 | pip-log.txt
64 | pip-delete-this-directory.txt
65 |
66 | # Unit test / coverage reports
67 | htmlcov/
68 | .tox/
69 | .coverage
70 | .coverage.*
71 | .cache
72 | nosetests.xml
73 | coverage.xml
74 | *,cover
75 |
76 | # Translations
77 | *.mo
78 | *.pot
79 |
80 | # Django stuff:
81 | *.log
82 |
83 | # Sphinx documentation
84 | docs/_build/
85 |
86 | # PyBuilder
87 | target/
88 |
89 | *.ipynb
90 | *.params
91 | # *.json
92 | .vscode/
93 | *.code-workspace/
94 |
95 | lib/pycocotools/_mask.c
96 | lib/nms/cpu_nms.c
97 |
98 | DATA/
99 | logs/
100 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DiffBlender: Scalable and Composable Multimodal Text-to-Image Diffusion Models 🔥
2 |
3 |
4 |
5 |
6 |
7 |
8 | - **DiffBlender** successfully synthesizes complex combinations of input modalities. It enables flexible manipulation of conditions, providing customized generation aligned with user preferences.
9 | - We designed its structure to extend intuitively to additional modalities while keeping the training cost low through a partial update of hypernetworks.
10 |
11 |
12 |
13 |
14 |
15 | ## 🗓️ TODOs
16 |
17 | - [x] Project page is open: [link](https://sungnyun.github.io/diffblender/)
18 | - [x] DiffBlender model: code & checkpoint
19 | - [x] Release inference code
20 | - [ ] Release training code & pipeline
21 | - [ ] Gradio UI
22 |
23 | ## 🚀 Getting Started
24 | Install the necessary packages with:
25 | ```sh
26 | $ pip install -r requirements.txt
27 | ```
28 |
29 | Download the DiffBlender model checkpoint from this [Huggingface model](https://huggingface.co/sungnyun/diffblender) and place it under `./diffblender_checkpoints/`.
30 | Also, prepare the Stable Diffusion model from this [link](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original) (we used CompVis/sd-v1-4.ckpt).
31 |
32 | ## ⚡️ Try Multimodal T2I Generation with DiffBlender
33 | ```sh
34 | $ python inference.py --ckpt_path=./diffblender_checkpoints/{CKPT_NAME}.pth \
35 | --official_ckpt_path=/path/to/sd-v1-4.ckpt \
36 | --save_name={SAVE_NAME}
37 | ```
38 |
39 | Results will be saved under `./inference/{SAVE_NAME}/`, formatted as {conditions + generated image}.
40 |
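To generate with your own conditions, edit `meta_list` (and `seed_list`) at the bottom of `inference.py`. A minimal sketch, reusing the sample files shipped under `images/`, is shown below; the keys mirror the arguments of the `preprocess` function in `inference.py`, every condition is optional, and the prompt here is only illustrative:

```python
meta_list = [
    dict(
        prompt    = "jeep",                    # text prompt
        sketch    = "images/jeep_sketch.png",  # sketch map
        depth     = "images/jeep_depth.png",   # depth map
        color     = "images/color1.pth",       # color palette tensor (or use get_color_palette for a .png)
        reference = "images/fire.png",         # reference image for the CLIP image embedding
        # phrases / locations / keypoints can be added for box and keypoint conditions
    ),
]
seed_list = [40]  # one seed per entry in meta_list
```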
41 |
42 |
43 | ## BibTeX
44 | ```
45 | @article{kim2023diffblender,
46 | title={DiffBlender: Scalable and Composable Multimodal Text-to-Image Diffusion Models},
47 | author={Kim, Sungnyun and Lee, Junsoo and Hong, Kibeom and Kim, Daesik and Ahn, Namhyuk},
48 | journal={arXiv preprint arXiv:2305.15194},
49 | year={2023}
50 | }
51 | ```
52 |
--------------------------------------------------------------------------------
/assets/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sungnyun/diffblender/b3481d17d0e13d89d45dcbb9a250e89459904031/assets/fig1.png
--------------------------------------------------------------------------------
/grounding_input/condition_null_generator.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch as th
3 |
4 |
5 |
6 | class BoxConditionInput:
7 | def __init__(self):
8 | self.set = False
9 |
10 | def prepare(self, batch):
11 | self.set = True
12 |
13 | boxes = batch["values"]
14 | masks = batch["masks"]
15 | text_embeddings = batch["text_embeddings"]
16 |
17 | self.batch_size, self.max_box, self.embedding_len = text_embeddings.shape
18 | self.device = text_embeddings.device
19 | self.dtype = text_embeddings.dtype
20 |
21 | # return {"values": boxes, "masks": masks, "text_embeddings": text_embeddings}
22 |
23 | def get_null_input(self, batch, batch_size=None, device=None, dtype=None):
24 |         assert self.set, "not set yet, cannot call this function"
25 | batch_size = self.batch_size if batch_size is None else batch_size
26 | device = self.device if device is None else device
27 | dtype = self.dtype if dtype is None else dtype
28 |
29 | boxes = th.zeros(batch_size, self.max_box, 4).type(dtype).to(device)
30 | masks = th.zeros(batch_size, self.max_box).type(dtype).to(device)
31 | text_embeddings = th.zeros(batch_size, self.max_box, self.embedding_len).type(dtype).to(device)
32 |
33 | batch["values"] = boxes
34 | batch["masks"] = masks
35 | batch["text_embeddings"] = text_embeddings
36 |
37 | return batch
38 |
39 |
40 | class KeypointConditionInput:
41 | def __init__(self):
42 | self.set = False
43 |
44 | def prepare(self, batch):
45 | self.set = True
46 |
47 | points = batch["values"]
48 | masks = batch["masks"]
49 |
50 | self.batch_size, self.max_persons_per_image, _ = points.shape
51 | self.max_persons_per_image = int(self.max_persons_per_image / 17)
52 | self.device = points.device
53 | self.dtype = points.dtype
54 |
55 | # return {"values": points, "masks": masks}
56 |
57 | def get_null_input(self, batch, batch_size=None, device=None, dtype=None):
58 |         assert self.set, "not set yet, cannot call this function"
59 | batch_size = self.batch_size if batch_size is None else batch_size
60 | device = self.device if device is None else device
61 | dtype = self.dtype if dtype is None else dtype
62 |
63 | points = th.zeros(batch_size, self.max_persons_per_image*17, 2).to(device)
64 | masks = th.zeros(batch_size, self.max_persons_per_image*17).to(device)
65 |
66 | batch["values"] = points
67 | batch["masks"] = masks
68 |
69 | return batch
70 |
71 |
72 | class NSPVectorConditionInput:
73 | def __init__(self):
74 | self.set = False
75 |
76 | def prepare(self, batch):
77 | self.set = True
78 |
79 | vectors = batch["values"]
80 | masks = batch["masks"]
81 |
82 | self.batch_size, self.in_dim = vectors.shape
83 | self.device = vectors.device
84 | self.dtype = vectors.dtype
85 |
86 | # return {"values": vectors, "masks": masks}
87 |
88 | def get_null_input(self, batch, batch_size=None, device=None, dtype=None):
89 |         assert self.set, "not set yet, cannot call this function"
90 | batch_size = self.batch_size if batch_size is None else batch_size
91 | device = self.device if device is None else device
92 | dtype = self.dtype if dtype is None else dtype
93 |
94 | vectors = th.zeros(batch_size, self.in_dim).to(device)
95 | masks = th.zeros(batch_size, 1).to(device)
96 |
97 | batch["values"] = vectors
98 | batch["masks"] = masks
99 |
100 | return batch
101 |
102 |
103 | class ImageConditionInput:
104 | def __init__(self):
105 | self.set = False
106 |
107 | def prepare(self, batch):
108 | self.set = True
109 |
110 | image = batch["values"]
111 | self.batch_size, self.in_channel, self.H, self.W = image.shape
112 | self.device = image.device
113 | self.dtype = image.dtype
114 |
115 | # return {"values": image}
116 |
117 | def get_null_input(self, batch, batch_size=None, device=None, dtype=None):
118 |         assert self.set, "not set yet, cannot call this function"
119 | batch_size = self.batch_size if batch_size is None else batch_size
120 | device = self.device if device is None else device
121 | dtype = self.dtype if dtype is None else dtype
122 |
123 | image = th.zeros(batch_size, self.in_channel, self.H, self.W).to(device)
124 | masks = th.zeros(batch_size, 1).to(device)
125 |
126 | batch["values"] = image
127 | batch["masks"] = masks
128 |
129 | return batch
130 |
131 |
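# Illustrative usage (not part of the original file): inference.py nullifies a
# modality by first calling prepare() on the real batch entry, which records its
# shapes/device/dtype, and then replacing the entry with the null input, e.g.
#
#   null_gen = BoxConditionInput()
#   null_gen.prepare(batch["box"])
#   batch["box"] = null_gen.get_null_input(batch["box"])  # zeroed boxes, masks, text embeddings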
--------------------------------------------------------------------------------
/images/color1.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sungnyun/diffblender/b3481d17d0e13d89d45dcbb9a250e89459904031/images/color1.pth
--------------------------------------------------------------------------------
/images/color2.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sungnyun/diffblender/b3481d17d0e13d89d45dcbb9a250e89459904031/images/color2.pth
--------------------------------------------------------------------------------
/images/dummy.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sungnyun/diffblender/b3481d17d0e13d89d45dcbb9a250e89459904031/images/dummy.pth
--------------------------------------------------------------------------------
/images/fire.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sungnyun/diffblender/b3481d17d0e13d89d45dcbb9a250e89459904031/images/fire.png
--------------------------------------------------------------------------------
/images/jeep_depth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sungnyun/diffblender/b3481d17d0e13d89d45dcbb9a250e89459904031/images/jeep_depth.png
--------------------------------------------------------------------------------
/images/jeep_sketch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sungnyun/diffblender/b3481d17d0e13d89d45dcbb9a250e89459904031/images/jeep_sketch.png
--------------------------------------------------------------------------------
/images/nature.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sungnyun/diffblender/b3481d17d0e13d89d45dcbb9a250e89459904031/images/nature.png
--------------------------------------------------------------------------------
/images/partial_sketch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sungnyun/diffblender/b3481d17d0e13d89d45dcbb9a250e89459904031/images/partial_sketch.png
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import random
4 | from PIL import Image
5 | from omegaconf import OmegaConf
6 | from copy import deepcopy
7 | from functools import partial
8 |
9 | import torch
10 | import torchvision
11 | import torchvision.transforms as tf
12 | from torch.utils.data import DataLoader
13 | import pytorch_lightning
14 |
15 | from ldm.util import instantiate_from_config
16 | from ldm.models.diffusion.plms import PLMSSampler
17 | from transformers import CLIPProcessor, CLIPModel
18 | from visualization.extract_utils import get_sketch, get_depth, get_box, get_keypoint, get_color_palette, get_clip_feature
19 | import visualization.image_utils as iutils
20 | from visualization.draw_utils import *
21 |
22 |
23 | device = "cuda"
24 |
25 |
26 | def read_official_ckpt(ckpt_path, no_model=False):
27 | state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
28 | out = {}
29 | out["model"] = {}
30 | out["text_encoder"] = {}
31 | out["autoencoder"] = {}
32 | out["unexpected"] = {}
33 | out["diffusion"] = {}
34 |
35 | for k,v in state_dict.items():
36 | if k.startswith('model.diffusion_model'):
37 | if no_model:
38 | continue
39 | out["model"][k.replace("model.diffusion_model.", "")] = v
40 | elif k.startswith('cond_stage_model'):
41 | out["text_encoder"][k.replace("cond_stage_model.", "")] = v
42 | elif k.startswith('first_stage_model'):
43 | out["autoencoder"][k.replace("first_stage_model.", "")] = v
44 | elif k in ["model_ema.decay", "model_ema.num_updates"]:
45 | out["unexpected"][k] = v
46 | else:
47 | out["diffusion"][k] = v
48 |
49 | if no_model:
50 | del state_dict
51 | return out
52 |
53 | def batch_to_device(batch, device):
54 | for k in batch:
55 | if isinstance(batch[k], torch.Tensor):
56 | batch[k] = batch[k].to(device)
57 |
58 | elif isinstance(batch[k], dict):
59 | for k_2 in batch[k]:
60 | if isinstance(batch[k][k_2], torch.Tensor):
61 | batch[k][k_2] = batch[k][k_2].to(device)
62 | return batch
63 |
64 | def load_ckpt(ckpt_path, official_ckpt_path='./sd-v1-4.ckpt'):
65 |
66 | saved_ckpt = torch.load(ckpt_path)
67 | config = saved_ckpt["config_dict"]["_content"]
68 |
69 | model = instantiate_from_config(config['model']).to(device).eval()
70 | autoencoder = instantiate_from_config(config['autoencoder']).to(device).eval()
71 | text_encoder = instantiate_from_config(config['text_encoder']).to(device).eval()
72 | diffusion = instantiate_from_config(config['diffusion']).to(device)
73 |
74 |     # do not need to load the official ckpt for self.model here, since we will load it from our ckpt
75 | missing, unexpected = model.load_state_dict( saved_ckpt['model'], strict=False )
76 | assert missing == []
77 | # print('unexpected keys:', unexpected)
78 |
79 | official_ckpt = read_official_ckpt(official_ckpt_path)
80 | autoencoder.load_state_dict( official_ckpt["autoencoder"] )
81 | text_encoder.load_state_dict( official_ckpt["text_encoder"] )
82 | diffusion.load_state_dict( official_ckpt["diffusion"] )
83 |
84 | if model.use_autoencoder_kl:
85 | for mode, input_type in zip(model.input_modalities, model.input_types):
86 | if input_type == "image":
87 | model.condition_nets[mode].autoencoder = deepcopy(autoencoder)
88 | model.condition_nets[mode].set = True
89 |
90 | return model, autoencoder, text_encoder, diffusion, config
91 |
92 | def set_alpha_scale(model, alpha_scale):
93 | from ldm.modules.multimodal_attention import GatedCrossAttentionDense, GatedSelfAttentionDense
94 | from ldm.modules.diffusionmodules.multimodal_openaimodel import UNetModel
95 | alpha_scale_sp, alpha_scale_nsp, alpha_scale_image = alpha_scale
96 | for name, module in model.named_modules():
97 | if type(module) == GatedCrossAttentionDense or type(module) == GatedSelfAttentionDense:
98 | if '.sp_fuser' in name:
99 | module.scale = alpha_scale_sp
100 | elif '.nsp_fuser' in name:
101 | module.scale = alpha_scale_nsp
102 | elif type(module) == UNetModel:
103 | module.scales = [alpha_scale_image] * 4
104 |
105 | def alpha_generator(length, config):
106 | """
107 |     length is the total number of timesteps needed for sampling.
108 |     Each alpha schedule is a list of four values [p0, p1, p2, scale], where p0 + p1 + p2 == 1.
109 | 
110 |     The first three values give the percentage of three stages:
111 |     alpha=scale stage,
112 |     linear decay stage,
113 |     alpha=0 stage.
114 | 
115 |     For example, if length=100 and the schedule is [0.8, 0.1, 0.1, scale],
116 |     then alpha is `scale` for the first 80 steps, linearly decays to 0 over the next 10 steps,
117 |     and is 0 for the last 10 steps.
118 | """
119 |
120 | alpha_schedule_sp = config['alpha_type_sp']
121 | alpha_schedule_nsp = config['alpha_type_nsp']
122 | alpha_schedule_image = config['alpha_type_image']
123 |
124 | alphas_ = list()
125 | for alpha_schedule in [alpha_schedule_sp, alpha_schedule_nsp, alpha_schedule_image]:
126 |
127 | assert len(alpha_schedule)==4
128 | assert alpha_schedule[0] + alpha_schedule[1] + alpha_schedule[2] == 1
129 |
130 | stage0_length = int(alpha_schedule[0]*length)
131 | stage1_length = int(alpha_schedule[1]*length)
132 | stage2_length = length - stage0_length - stage1_length
133 |
134 | if stage1_length != 0:
135 | decay_alphas = alpha_schedule[3] * np.arange(start=0, stop=1, step=1/stage1_length)[::-1]
136 | decay_alphas = list(decay_alphas)
137 | else:
138 | decay_alphas = []
139 |
140 | alphas = [alpha_schedule[3]]*stage0_length + decay_alphas + [0]*stage2_length
141 |
142 | assert len(alphas) == length
143 | alphas_.append(alphas)
144 |
145 | return list(zip(*alphas_))
146 |
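# Illustrative note (an assumption based on the defaults in this file, not part of
# the original source): with config['alpha_type_sp'] = [0.3, 0.0, 0.7, 1.0] and
# length=50 sampling steps, the spatial schedule is 15 steps at alpha=1.0 followed
# by 35 steps at alpha=0 (no linear-decay stage since the second value is 0).
# The function returns `length` tuples of (alpha_sp, alpha_nsp, alpha_image),
# one per step, which the sampler feeds to set_alpha_scale() before each denoising step.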
147 | def preprocess(prompt="",
148 | sketch=None,
149 | depth=None,
150 | phrases=None,
151 | locations=None,
152 | keypoints=None,
153 | color=None,
154 | reference=None):
155 | batch = dict()
156 | null_conditions = []
157 |
158 | batch = torch.load('./images/dummy.pth', map_location='cpu') # dummy var
159 | batch["caption"] = [prompt]
160 |
161 | if sketch is not None:
162 | sketch_tensor = get_sketch(sketch)
163 | selected_sketch = dict()
164 | selected_sketch["values"] = sketch_tensor.unsqueeze(0)
165 | selected_sketch["masks"] = torch.tensor([[1.]])
166 | batch["sketch"] = selected_sketch
167 | else:
168 | null_conditions.append("sketch")
169 |
170 | if depth is not None:
171 | depth_tensor = get_depth(depth)
172 | selected_depth = dict()
173 | selected_depth["values"] = depth_tensor.unsqueeze(0)
174 | selected_depth["masks"] = torch.tensor([[1.]])
175 | batch["depth"] = selected_depth
176 | else:
177 | null_conditions.append("depth")
178 |
179 | if locations is not None and phrases is not None:
180 | version = "openai/clip-vit-large-patch14"
181 | clip_model = CLIPModel.from_pretrained(version).cuda()
182 | clip_processor = CLIPProcessor.from_pretrained(version)
183 | boxes, masks, text_embeddings = get_box(locations, phrases, clip_model, clip_processor)
184 | selected_box = dict()
185 | selected_box["values"] = boxes.unsqueeze(0)
186 | selected_box["masks"] = masks.unsqueeze(0)
187 | selected_box["text_embeddings"] = text_embeddings.unsqueeze(0)
188 | batch["box"] = selected_box
189 | else:
190 | null_conditions.append("box")
191 |
192 | if keypoints is not None:
193 | points, masks = get_keypoint(keypoints)
194 | selected_keypoint = dict()
195 | selected_keypoint["values"] = points.unsqueeze(0)
196 | selected_keypoint["masks"] = masks.unsqueeze(0)
197 | batch["keypoint"] = selected_keypoint
198 | else:
199 | null_conditions.append("keypoint")
200 |
201 | if color is not None:
202 | selected_color_palette = dict()
203 | # color_palette = get_color_palette(color) # for .png file
204 | # selected_color_palette["values"] = torch.tensor(color_palette, dtype=torch.float32).unsqueeze(0)
205 | selected_color_palette["values"] = torch.load(color).unsqueeze(0)
206 | selected_color_palette["masks"] = torch.tensor([[1.]])
207 | batch["color_palette"] = selected_color_palette
208 | else:
209 | null_conditions.append("color_palette")
210 |
211 | if reference is not None:
212 | version = "openai/clip-vit-large-patch14"
213 | clip_model = CLIPModel.from_pretrained(version).cuda()
214 | clip_processor = CLIPProcessor.from_pretrained(version)
215 | clip_features = get_clip_feature(reference, clip_model, clip_processor)
216 | selected_image_embedding = dict()
217 | selected_image_embedding["values"] = torch.tensor(clip_features).unsqueeze(0)
218 | selected_image_embedding["masks"] = torch.tensor([[1.]])
219 | batch["image_embedding"] = selected_image_embedding
220 | batch["image_embedding"]["image"] = tf.ToTensor()(tf.Resize((512,512))(Image.open(reference).convert('RGB'))).unsqueeze(0)
221 | else:
222 | batch["image_embedding"]["image"] = -torch.ones_like(batch["image"])
223 | null_conditions.append("image_embedding")
224 |
225 | return batch, null_conditions
226 |
227 |
228 |
229 | @torch.no_grad()
230 | def run(selected_batch_,
231 | config,
232 | model,
233 | autoencoder,
234 | text_encoder,
235 | diffusion,
236 | condition_null_generator_dict,
237 | idx,
238 | NULL_CONDITION,
239 | SAVE_NAME,
240 | seed):
241 |
242 | #### Starting noise fixed ####
243 | torch.manual_seed(seed)
244 | torch.cuda.manual_seed(seed)
245 | starting_noise = torch.randn(1, 4, 64, 64).to(device)
246 |
247 | selected_batch = deepcopy(selected_batch_)
248 | uc_batch = deepcopy(selected_batch_)
249 | for mode in condition_null_generator_dict:
250 | if mode in NULL_CONDITION:
251 | condition_null_generator = condition_null_generator_dict[mode]
252 | condition_null_generator.prepare(selected_batch[mode])
253 | selected_batch[mode] = condition_null_generator.get_null_input(selected_batch[mode])
254 | else:
255 | if mode in ["sketch", "depth"]:
256 | continue
257 | condition_null_generator = condition_null_generator_dict[mode]
258 | condition_null_generator.prepare(uc_batch[mode])
259 | uc_batch[mode] = condition_null_generator.get_null_input(uc_batch[mode])
260 |
261 | selected_batch = batch_to_device(selected_batch, device)
262 | uc_batch = batch_to_device(uc_batch, device)
263 |
264 | torch.cuda.empty_cache()
265 |
266 | batch_here = config['batch_size']
267 | context = text_encoder.encode(selected_batch["caption"])
268 | # you can set negative prompts here
269 | # uc = text_encoder.encode(batch_here*["longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"])
270 | uc = text_encoder.encode(batch_here*[""])
271 |
272 | # plms sampling
273 | alpha_generator_func = partial(alpha_generator, config=config)
274 | sampler = PLMSSampler(diffusion, model, alpha_generator_func=alpha_generator_func, set_alpha_scale=set_alpha_scale)
275 | steps = 50
276 | shape = (batch_here, model.in_channels, model.image_size, model.image_size)
277 |
278 | input_dict = dict(x = starting_noise,
279 | timesteps = None,
280 | context = context,
281 | inpainting_extra_input = None,
282 | condition = selected_batch )
283 |
284 | uc_dict = dict(context = uc,
285 | condition = uc_batch )
286 |
287 | samples = sampler.sample(S=steps, shape=shape, input=input_dict, uc_dict=uc_dict, guidance_scale=config['guidance_scale'])
288 | pred_image = autoencoder.decode(samples)
289 |
290 | image_dict = [
291 | # {"tensors": selected_batch["image"], "n_in_row": 1, "pp_type": iutils.PP_RGB},
292 | {"tensors": draw_sketch_with_batch_to_tensor(selected_batch), "n_in_row": 1, "pp_type": iutils.PP_SEGM},
293 | {"tensors": draw_depth_with_batch_to_tensor(selected_batch), "n_in_row": 1, "pp_type": iutils.PP_SEGM},
294 | {"tensors": draw_boxes_with_batch_to_tensor(selected_batch), "n_in_row": 1, "pp_type": iutils.PP_SEGM},
295 | {"tensors": draw_keypoints_with_batch_to_tensor(selected_batch), "n_in_row": 1, "pp_type": iutils.PP_SEGM},
296 | {"tensors": draw_image_embedding_with_batch_to_tensor(selected_batch), "n_in_row": 1, "pp_type": iutils.PP_RGB},
297 | {"tensors": draw_color_palettes_with_batch_to_tensor(selected_batch), "n_in_row": 1, "pp_type": iutils.PP_SEGM}, # range 0~1
298 | {"tensors": pred_image, "n_in_row": 1, "pp_type": iutils.PP_RGB},
299 | ]
300 | os.makedirs(os.path.join("inference", SAVE_NAME), exist_ok=True)
301 | iutils.save_images_from_dict(
302 | image_dict, dir_path=os.path.join("inference", SAVE_NAME), file_name="sampled_{:4d}".format(idx),
303 | n_instance=config['batch_size'], is_save=True, return_images=False
304 | )
305 | save_path = os.path.join("inference", SAVE_NAME, 'captions.txt')
306 | with open(save_path, "a") as f:
307 | f.write( 'idx ' + str(idx) + ':\n' )
308 | for cap in selected_batch['caption']:
309 | f.write( cap + '\n' )
310 | f.write( '\n' )
311 |     print("Saving images and their corresponding captions... done")
312 |
313 | return pred_image.detach().cpu()
314 |
315 |
316 |
317 | if __name__ == "__main__":
318 |
319 | parser = argparse.ArgumentParser()
320 | parser.add_argument("--ckpt_path", type=str, default="./diffblender_checkpoints/checkpoint_latest.pth", help="pretrained checkpoint path")
321 | parser.add_argument("--official_ckpt_path", type=str, default="/path/to/sd-v1-4.ckpt", help="official SD path")
322 | parser.add_argument("--save_name", type=str, default="SAVE_NAME", help="")
323 |
324 | parser.add_argument("--alpha_type_sp", nargs='+', type=float, default=[0.3, 0.0, 0.7, 1.0], help="alpha scheduling type for spatial cond.")
325 | parser.add_argument("--alpha_type_nsp", nargs='+', type=float, default=[0.3, 0.0, 0.7, 1.0], help="alpha scheduling type for non-spatial cond.")
326 | parser.add_argument("--alpha_type_image", nargs='+', type=float, default=[1.0, 0.0, 0.0, 0.7], help="alpha scheduling type for image-form cond.")
327 | parser.add_argument("--guidance_scale", type=float, default=5.0, help="classifier-free guidance scale")
328 |
329 | args = parser.parse_args()
330 |
331 |
332 | model, autoencoder, text_encoder, diffusion, config = load_ckpt(ckpt_path=args.ckpt_path, official_ckpt_path=args.official_ckpt_path)
333 | condition_null_generator_dict = dict()
334 | for mode in config['condition_null_generator']['input_modalities']:
335 | condition_null_generator_dict[mode] = instantiate_from_config(config['condition_null_generator'][mode])
336 |
337 | # replace config
338 | config['batch_size'] = 1
339 | config['alpha_type_sp'] = args.alpha_type_sp
340 | config['alpha_type_nsp'] = args.alpha_type_nsp
341 | config['alpha_type_image'] = args.alpha_type_image
342 | config['guidance_scale'] = args.guidance_scale
343 |
344 | kwargs_dict = dict(
345 | config=config,
346 | model=model,
347 | autoencoder=autoencoder,
348 | text_encoder=text_encoder,
349 | diffusion=diffusion,
350 | condition_null_generator_dict=condition_null_generator_dict,
351 | SAVE_NAME=args.save_name,
352 | )
353 |
354 |
355 |     meta_list = [ # change: edit these entries to set your own conditions
356 |
357 | dict(
358 | prompt = "jeep",
359 | sketch = "images/jeep_sketch.png",
360 | depth = "images/jeep_depth.png",
361 | color = "images/color1.pth", # can also use image file via get_color_palette func
362 | reference = "images/fire.png",
363 | ),
364 |
365 | dict(
366 | prompt = "swimming rabbits",
367 | phrases = ["rabbit", "rabbit", "rabbit"],
368 | locations = [ [0.3500, 0.5000, 1.0000, 0.9500], [0.2000, 0.2500, 0.6000, 0.5500], [0.0500, 0.0500, 0.4000, 0.3000] ],
369 | color = "images/color2.pth",
370 | ),
371 |
372 | dict(
373 | prompt = "jumping astronaut",
374 | sketch = "images/partial_sketch.png",
375 | phrases = ["astronaut"],
376 | locations = [[0.1158, 0.1053, 0.5140, 0.6111]],
377 | keypoints = [
378 | [ [0.2767, 0.2025],
379 | [0.2617, 0.1875],
380 | [0.2917, 0.1875],
381 | [0.0000, 0.0000],
382 | [0.3117, 0.1800],
383 | [0.2192, 0.2375],
384 | [0.3392, 0.2425],
385 | [0.1942, 0.2850],
386 | [0.3967, 0.3075],
387 | [0.1667, 0.3475],
388 | [0.4142, 0.3675],
389 | [0.2592, 0.3775],
390 | [0.3242, 0.3700],
391 | [0.2717, 0.4425],
392 | [0.3992, 0.4375],
393 | [0.2367, 0.5550],
394 | [0.4067, 0.5225], ]
395 | ],
396 | reference = "images/nature.png",
397 | ),
398 |
399 | ]
400 |
401 |     seed_list = [40, 10, 20] # change: one seed per entry in meta_list
402 |
403 | for idx, (meta, seed) in enumerate(zip(meta_list, seed_list)):
404 | batch, null_conditions = preprocess(**meta)
405 | kwargs_dict['idx'] = idx
406 | kwargs_dict['seed'] = seed
407 | kwargs_dict['NULL_CONDITION'] = null_conditions
408 | pred_image = run(batch, **kwargs_dict)
409 |
--------------------------------------------------------------------------------
/ldm/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LambdaWarmUpCosineScheduler:
5 | """
6 | note: use with a base_lr of 1.0
7 | """
8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9 | self.lr_warm_up_steps = warm_up_steps
10 | self.lr_start = lr_start
11 | self.lr_min = lr_min
12 | self.lr_max = lr_max
13 | self.lr_max_decay_steps = max_decay_steps
14 | self.last_lr = 0.
15 | self.verbosity_interval = verbosity_interval
16 |
17 | def schedule(self, n, **kwargs):
18 | if self.verbosity_interval > 0:
19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20 | if n < self.lr_warm_up_steps:
21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22 | self.last_lr = lr
23 | return lr
24 | else:
25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26 | t = min(t, 1.0)
27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28 | 1 + np.cos(t * np.pi))
29 | self.last_lr = lr
30 | return lr
31 |
32 | def __call__(self, n, **kwargs):
33 | return self.schedule(n,**kwargs)
34 |
35 |
36 | class LambdaWarmUpCosineScheduler2:
37 | """
38 | supports repeated iterations, configurable via lists
39 | note: use with a base_lr of 1.0.
40 | """
41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43 | self.lr_warm_up_steps = warm_up_steps
44 | self.f_start = f_start
45 | self.f_min = f_min
46 | self.f_max = f_max
47 | self.cycle_lengths = cycle_lengths
48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
49 | self.last_f = 0.
50 | self.verbosity_interval = verbosity_interval
51 |
52 | def find_in_interval(self, n):
53 | interval = 0
54 | for cl in self.cum_cycles[1:]:
55 | if n <= cl:
56 | return interval
57 | interval += 1
58 |
59 | def schedule(self, n, **kwargs):
60 | cycle = self.find_in_interval(n)
61 | n = n - self.cum_cycles[cycle]
62 | if self.verbosity_interval > 0:
63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
64 | f"current cycle {cycle}")
65 | if n < self.lr_warm_up_steps[cycle]:
66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
67 | self.last_f = f
68 | return f
69 | else:
70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
71 | t = min(t, 1.0)
72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
73 | 1 + np.cos(t * np.pi))
74 | self.last_f = f
75 | return f
76 |
77 | def __call__(self, n, **kwargs):
78 | return self.schedule(n, **kwargs)
79 |
80 |
81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
82 |
83 | def schedule(self, n, **kwargs):
84 | cycle = self.find_in_interval(n)
85 | n = n - self.cum_cycles[cycle]
86 | if self.verbosity_interval > 0:
87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
88 | f"current cycle {cycle}")
89 |
90 | if n < self.lr_warm_up_steps[cycle]:
91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
92 | self.last_f = f
93 | return f
94 | else:
95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
96 | self.last_f = f
97 | return f
98 |
99 |
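# Usage sketch (illustrative, not part of the original file; the parameter values
# are placeholders): these schedulers return a multiplier on the base learning rate
# (hence "use with a base_lr of 1.0") and are typically wrapped in
# torch.optim.lr_scheduler.LambdaLR, e.g.
#
#   sched = LambdaLinearScheduler(warm_up_steps=[1000], f_min=[1.0], f_max=[1.0],
#                                 f_start=[1.e-6], cycle_lengths=[10000000])
#   lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=sched.schedule)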
--------------------------------------------------------------------------------
/ldm/models/autoencoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | #import pytorch_lightning as pl
4 | import torch.nn.functional as F
5 | from contextlib import contextmanager
6 |
7 | # from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
8 |
9 | from ldm.modules.diffusionmodules.model import Encoder, Decoder
10 | from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
11 |
12 | from ldm.util import instantiate_from_config
13 |
14 |
15 |
16 |
17 | class AutoencoderKL(nn.Module):
18 | def __init__(self,
19 | ddconfig,
20 | embed_dim,
21 | scale_factor=1
22 | ):
23 | super().__init__()
24 | self.encoder = Encoder(**ddconfig)
25 | self.decoder = Decoder(**ddconfig)
26 | assert ddconfig["double_z"]
27 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
28 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
29 | self.embed_dim = embed_dim
30 | self.scale_factor = scale_factor
31 |
32 |
33 |
34 | def encode(self, x):
35 | h = self.encoder(x)
36 | moments = self.quant_conv(h)
37 | posterior = DiagonalGaussianDistribution(moments)
38 | return posterior.sample() * self.scale_factor
39 |
40 | def decode(self, z):
41 | z = 1. / self.scale_factor * z
42 | z = self.post_quant_conv(z)
43 | dec = self.decoder(z)
44 | return dec
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
--------------------------------------------------------------------------------
/ldm/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sungnyun/diffblender/b3481d17d0e13d89d45dcbb9a250e89459904031/ldm/models/diffusion/__init__.py
--------------------------------------------------------------------------------
/ldm/models/diffusion/classifier.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import pytorch_lightning as pl
4 | from omegaconf import OmegaConf
5 | from torch.nn import functional as F
6 | from torch.optim import AdamW
7 | from torch.optim.lr_scheduler import LambdaLR
8 | from copy import deepcopy
9 | from einops import rearrange
10 | from glob import glob
11 | from natsort import natsorted
12 |
13 | from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
14 | from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
15 |
16 | __models__ = {
17 | 'class_label': EncoderUNetModel,
18 | 'segmentation': UNetModel
19 | }
20 |
21 |
22 | def disabled_train(self, mode=True):
23 | """Overwrite model.train with this function to make sure train/eval mode
24 | does not change anymore."""
25 | return self
26 |
27 |
28 | class NoisyLatentImageClassifier(pl.LightningModule):
29 |
30 | def __init__(self,
31 | diffusion_path,
32 | num_classes,
33 | ckpt_path=None,
34 | pool='attention',
35 | label_key=None,
36 | diffusion_ckpt_path=None,
37 | scheduler_config=None,
38 | weight_decay=1.e-2,
39 | log_steps=10,
40 | monitor='val/loss',
41 | *args,
42 | **kwargs):
43 | super().__init__(*args, **kwargs)
44 | self.num_classes = num_classes
45 | # get latest config of diffusion model
46 | diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
47 | self.diffusion_config = OmegaConf.load(diffusion_config).model
48 | self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
49 | self.load_diffusion()
50 |
51 | self.monitor = monitor
52 | self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
53 | self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
54 | self.log_steps = log_steps
55 |
56 | self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
57 | else self.diffusion_model.cond_stage_key
58 |
59 | assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
60 |
61 | if self.label_key not in __models__:
62 | raise NotImplementedError()
63 |
64 | self.load_classifier(ckpt_path, pool)
65 |
66 | self.scheduler_config = scheduler_config
67 | self.use_scheduler = self.scheduler_config is not None
68 | self.weight_decay = weight_decay
69 |
70 | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
71 | sd = torch.load(path, map_location="cpu")
72 | if "state_dict" in list(sd.keys()):
73 | sd = sd["state_dict"]
74 | keys = list(sd.keys())
75 | for k in keys:
76 | for ik in ignore_keys:
77 | if k.startswith(ik):
78 | print("Deleting key {} from state_dict.".format(k))
79 | del sd[k]
80 | missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
81 | sd, strict=False)
82 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
83 | if len(missing) > 0:
84 | print(f"Missing Keys: {missing}")
85 | if len(unexpected) > 0:
86 | print(f"Unexpected Keys: {unexpected}")
87 |
88 | def load_diffusion(self):
89 | model = instantiate_from_config(self.diffusion_config)
90 | self.diffusion_model = model.eval()
91 | self.diffusion_model.train = disabled_train
92 | for param in self.diffusion_model.parameters():
93 | param.requires_grad = False
94 |
95 | def load_classifier(self, ckpt_path, pool):
96 | model_config = deepcopy(self.diffusion_config.params.unet_config.params)
97 | model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
98 | model_config.out_channels = self.num_classes
99 | if self.label_key == 'class_label':
100 | model_config.pool = pool
101 |
102 | self.model = __models__[self.label_key](**model_config)
103 | if ckpt_path is not None:
104 | print('#####################################################################')
105 | print(f'load from ckpt "{ckpt_path}"')
106 | print('#####################################################################')
107 | self.init_from_ckpt(ckpt_path)
108 |
109 | @torch.no_grad()
110 | def get_x_noisy(self, x, t, noise=None):
111 | noise = default(noise, lambda: torch.randn_like(x))
112 | continuous_sqrt_alpha_cumprod = None
113 | if self.diffusion_model.use_continuous_noise:
114 | continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
115 | # todo: make sure t+1 is correct here
116 |
117 | return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
118 | continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
119 |
120 | def forward(self, x_noisy, t, *args, **kwargs):
121 | return self.model(x_noisy, t)
122 |
123 | @torch.no_grad()
124 | def get_input(self, batch, k):
125 | x = batch[k]
126 | if len(x.shape) == 3:
127 | x = x[..., None]
128 | x = rearrange(x, 'b h w c -> b c h w')
129 | x = x.to(memory_format=torch.contiguous_format).float()
130 | return x
131 |
132 | @torch.no_grad()
133 | def get_conditioning(self, batch, k=None):
134 | if k is None:
135 | k = self.label_key
136 | assert k is not None, 'Needs to provide label key'
137 |
138 | targets = batch[k].to(self.device)
139 |
140 | if self.label_key == 'segmentation':
141 | targets = rearrange(targets, 'b h w c -> b c h w')
142 | for down in range(self.numd):
143 | h, w = targets.shape[-2:]
144 | targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
145 |
146 | # targets = rearrange(targets,'b c h w -> b h w c')
147 |
148 | return targets
149 |
150 | def compute_top_k(self, logits, labels, k, reduction="mean"):
151 | _, top_ks = torch.topk(logits, k, dim=1)
152 | if reduction == "mean":
153 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
154 | elif reduction == "none":
155 | return (top_ks == labels[:, None]).float().sum(dim=-1)
156 |
157 | def on_train_epoch_start(self):
158 | # save some memory
159 | self.diffusion_model.model.to('cpu')
160 |
161 | @torch.no_grad()
162 | def write_logs(self, loss, logits, targets):
163 | log_prefix = 'train' if self.training else 'val'
164 | log = {}
165 | log[f"{log_prefix}/loss"] = loss.mean()
166 | log[f"{log_prefix}/acc@1"] = self.compute_top_k(
167 | logits, targets, k=1, reduction="mean"
168 | )
169 | log[f"{log_prefix}/acc@5"] = self.compute_top_k(
170 | logits, targets, k=5, reduction="mean"
171 | )
172 |
173 | self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
174 | self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
175 | self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
176 | lr = self.optimizers().param_groups[0]['lr']
177 | self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
178 |
179 | def shared_step(self, batch, t=None):
180 | x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
181 | targets = self.get_conditioning(batch)
182 | if targets.dim() == 4:
183 | targets = targets.argmax(dim=1)
184 | if t is None:
185 | t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
186 | else:
187 | t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
188 | x_noisy = self.get_x_noisy(x, t)
189 | logits = self(x_noisy, t)
190 |
191 | loss = F.cross_entropy(logits, targets, reduction='none')
192 |
193 | self.write_logs(loss.detach(), logits.detach(), targets.detach())
194 |
195 | loss = loss.mean()
196 | return loss, logits, x_noisy, targets
197 |
198 | def training_step(self, batch, batch_idx):
199 | loss, *_ = self.shared_step(batch)
200 | return loss
201 |
202 | def reset_noise_accs(self):
203 | self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
204 | range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
205 |
206 | def on_validation_start(self):
207 | self.reset_noise_accs()
208 |
209 | @torch.no_grad()
210 | def validation_step(self, batch, batch_idx):
211 | loss, *_ = self.shared_step(batch)
212 |
213 | for t in self.noisy_acc:
214 | _, logits, _, targets = self.shared_step(batch, t)
215 | self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
216 | self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
217 |
218 | return loss
219 |
220 | def configure_optimizers(self):
221 | optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
222 |
223 | if self.use_scheduler:
224 | scheduler = instantiate_from_config(self.scheduler_config)
225 |
226 | print("Setting up LambdaLR scheduler...")
227 | scheduler = [
228 | {
229 | 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
230 | 'interval': 'step',
231 | 'frequency': 1
232 | }]
233 | return [optimizer], scheduler
234 |
235 | return optimizer
236 |
237 | @torch.no_grad()
238 | def log_images(self, batch, N=8, *args, **kwargs):
239 | log = dict()
240 | x = self.get_input(batch, self.diffusion_model.first_stage_key)
241 | log['inputs'] = x
242 |
243 | y = self.get_conditioning(batch)
244 |
245 | if self.label_key == 'class_label':
246 | y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
247 | log['labels'] = y
248 |
249 | if ismap(y):
250 | log['labels'] = self.diffusion_model.to_rgb(y)
251 |
252 | for step in range(self.log_steps):
253 | current_time = step * self.log_time_interval
254 |
255 | _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
256 |
257 | log[f'inputs@t{current_time}'] = x_noisy
258 |
259 | pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
260 | pred = rearrange(pred, 'b h w c -> b c h w')
261 |
262 | log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
263 |
264 | for key in log:
265 | log[key] = log[key][:N]
266 |
267 | return log
268 |
--------------------------------------------------------------------------------
/ldm/models/diffusion/ddim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from tqdm import tqdm
4 | from functools import partial
5 |
6 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
7 |
8 |
9 | class DDIMSampler(object):
10 | def __init__(self, diffusion, model, schedule="linear", alpha_generator_func=None, set_alpha_scale=None):
11 | super().__init__()
12 | self.diffusion = diffusion
13 | self.model = model
14 | self.device = diffusion.betas.device
15 | self.ddpm_num_timesteps = diffusion.num_timesteps
16 | self.schedule = schedule
17 | self.alpha_generator_func = alpha_generator_func
18 | self.set_alpha_scale = set_alpha_scale
19 |
20 |
21 | def register_buffer(self, name, attr):
22 | if type(attr) == torch.Tensor:
23 | attr = attr.to(self.device)
24 | setattr(self, name, attr)
25 |
26 |
27 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.):
28 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
29 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=False)
30 | alphas_cumprod = self.diffusion.alphas_cumprod
31 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
32 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device)
33 |
34 | self.register_buffer('betas', to_torch(self.diffusion.betas))
35 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
36 | self.register_buffer('alphas_cumprod_prev', to_torch(self.diffusion.alphas_cumprod_prev))
37 |
38 | # calculations for diffusion q(x_t | x_{t-1}) and others
39 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
40 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
41 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
42 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
43 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
44 |
45 | # ddim sampling parameters
46 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
47 | ddim_timesteps=self.ddim_timesteps,
48 | eta=ddim_eta,verbose=False)
49 | self.register_buffer('ddim_sigmas', ddim_sigmas)
50 | self.register_buffer('ddim_alphas', ddim_alphas)
51 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
52 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
53 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
54 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
55 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
56 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
57 |
58 |
59 | @torch.no_grad()
60 | def sample(self, S, shape, input, uc_dict=None, guidance_scale=1, mask=None, x0=None):
61 | self.make_schedule(ddim_num_steps=S)
62 | return self.ddim_sampling(shape, input, uc_dict, guidance_scale, mask=mask, x0=x0)
63 |
64 |
65 | @torch.no_grad()
66 | def ddim_sampling(self, shape, input, uc_dict=None, guidance_scale=1, mask=None, x0=None):
67 | b = shape[0]
68 |
69 | img = input["x"]
70 |         if img is None:
71 | img = torch.randn(shape, device=self.device)
72 | input["x"] = img
73 |
74 |
75 | time_range = np.flip(self.ddim_timesteps)
76 | total_steps = self.ddim_timesteps.shape[0]
77 |
78 | #iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
79 | iterator = time_range
80 |
81 | if self.alpha_generator_func != None:
82 | alphas = self.alpha_generator_func(len(iterator))
83 |
84 |
85 | for i, step in enumerate(iterator):
86 |
87 | # set alpha
88 | if self.alpha_generator_func != None:
89 | self.set_alpha_scale(self.model, alphas[i])
90 |
91 | # run
92 | index = total_steps - i - 1
93 | input["timesteps"] = torch.full((b,), step, device=self.device, dtype=torch.long)
94 |
95 | if mask is not None:
96 | assert x0 is not None
97 | img_orig = self.diffusion.q_sample( x0, input["timesteps"] )
98 | img = img_orig * mask + (1. - mask) * img
99 | input["x"] = img
100 |
101 | img, pred_x0 = self.p_sample_ddim(input, index=index, uc_dict=uc_dict, guidance_scale=guidance_scale)
102 | input["x"] = img
103 |
104 | return img
105 |
106 |
107 | @torch.no_grad()
108 | def p_sample_ddim(self, input, index, uc_dict=None, guidance_scale=1):
109 |
110 |
111 | e_t = self.model(input)
112 | if uc_dict is not None and guidance_scale != 1:
113 | unconditional_input = dict(
114 | x=input["x"],
115 | timesteps=input["timesteps"],
116 | context=uc_dict['context'],
117 | inpainting_extra_input=input["inpainting_extra_input"],
118 | condition=uc_dict['condition'],
119 | )
120 | e_t_uncond = self.model( unconditional_input )
121 | e_t = e_t_uncond + guidance_scale * (e_t - e_t_uncond)
122 |
123 | # select parameters corresponding to the currently considered timestep
124 | b = input["x"].shape[0]
125 | a_t = torch.full((b, 1, 1, 1), self.ddim_alphas[index], device=self.device)
126 | a_prev = torch.full((b, 1, 1, 1), self.ddim_alphas_prev[index], device=self.device)
127 | sigma_t = torch.full((b, 1, 1, 1), self.ddim_sigmas[index], device=self.device)
128 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), self.ddim_sqrt_one_minus_alphas[index],device=self.device)
129 |
130 | # current prediction for x_0
131 | pred_x0 = (input["x"] - sqrt_one_minus_at * e_t) / a_t.sqrt()
132 |
133 | # direction pointing to x_t
134 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
135 | noise = sigma_t * torch.randn_like( input["x"] )
136 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
137 |
138 | return x_prev, pred_x0
139 |
--------------------------------------------------------------------------------
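A minimal usage sketch for the sampler above (illustration only; the dummy epsilon-predictor, latent shape, and context dimensions are assumptions, not repository defaults). It reuses the DDPM schedule holder defined in ddpm.py below and the dict-style input that p_sample_ddim expects.

    # Hedged sketch: wiring DDIMSampler to a DDPM schedule with a stand-in model.
    import torch
    from ldm.models.diffusion.ddpm import DDPM
    from ldm.models.diffusion.ddim import DDIMSampler

    class DummyEps(torch.nn.Module):
        # The real UNet consumes the same dict (x, timesteps, context, condition, ...)
        # and returns a noise prediction shaped like x.
        def forward(self, input):
            return torch.zeros_like(input["x"])

    diffusion = DDPM()                           # registers betas / alphas_cumprod buffers
    sampler = DDIMSampler(diffusion, DummyEps())

    shape = (1, 4, 64, 64)                       # assumed latent shape
    input = {
        "x": None,                               # None -> the sampler draws Gaussian noise
        "timesteps": None,                       # filled in per step by ddim_sampling
        "context": torch.zeros(1, 77, 768),      # e.g. text embeddings (assumed dims)
        "inpainting_extra_input": None,
        "condition": None,                       # multimodal condition dict in the real model
    }
    latent = sampler.sample(S=50, shape=shape, input=input, guidance_scale=1)
    print(latent.shape)                          # torch.Size([1, 4, 64, 64])

With uc_dict=None and guidance_scale=1, the classifier-free guidance branch in p_sample_ddim is skipped entirely.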
/ldm/models/diffusion/ddpm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from functools import partial
5 | from ldm.modules.diffusionmodules.util import make_beta_schedule
6 |
7 |
8 |
9 |
10 |
11 | class DDPM(nn.Module):
12 | def __init__(self, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
13 | super().__init__()
14 |
15 | self.v_posterior = 0
16 | self.register_schedule(beta_schedule, timesteps, linear_start, linear_end, cosine_s)
17 |
18 |
19 | def register_schedule(self, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
20 |
21 | betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
22 | alphas = 1. - betas
23 | alphas_cumprod = np.cumprod(alphas, axis=0)
24 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
25 |
26 | timesteps, = betas.shape
27 | self.num_timesteps = int(timesteps)
28 | self.linear_start = linear_start
29 | self.linear_end = linear_end
30 | assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
31 |
32 | to_torch = partial(torch.tensor, dtype=torch.float32)
33 |
34 | self.register_buffer('betas', to_torch(betas))
35 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
36 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
37 |
38 | # calculations for diffusion q(x_t | x_{t-1}) and others
39 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
40 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
41 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
42 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
43 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
44 |
45 | # calculations for posterior q(x_{t-1} | x_t, x_0)
46 | posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( 1. - alphas_cumprod) + self.v_posterior * betas
47 | # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
48 |
49 | self.register_buffer('posterior_variance', to_torch(posterior_variance))
50 |
51 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
52 | self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
53 | self.register_buffer('posterior_mean_coef1', to_torch( betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
54 | self.register_buffer('posterior_mean_coef2', to_torch( (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
--------------------------------------------------------------------------------
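The comment next to posterior_variance above states the identity 1 / (1 / (1 - alphabar_{t-1}) + alpha_t / beta_t) = beta_t * (1 - alphabar_{t-1}) / (1 - alphabar_t). A short standalone check of that claim for the v_posterior = 0 case (illustration only; the schedule values are assumed):

    # Hedged numeric check of the posterior-variance identity quoted in ddpm.py.
    import numpy as np

    betas = np.linspace(1e-4, 2e-2, 1000)
    alphas = 1.0 - betas
    alphas_cumprod = np.cumprod(alphas)
    alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

    posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

    t = np.arange(1, 1000)  # skip t = 0, where alphabar_{t-1} = 1 makes the closed form degenerate
    closed_form = 1.0 / (1.0 / (1.0 - alphas_cumprod_prev[t]) + alphas[t] / betas[t])
    print(np.allclose(posterior_variance[t], closed_form))  # True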
/ldm/models/diffusion/ldm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from tqdm import tqdm
5 | from ldm.util import default
6 | from ldm.modules.diffusionmodules.util import extract_into_tensor
7 | from .ddpm import DDPM
8 |
9 |
10 |
11 | class LatentDiffusion(DDPM):
12 | def __init__(self, *args, **kwargs):
13 | super().__init__(*args, **kwargs)
14 | # hardcoded
15 | self.clip_denoised = False
16 |
17 |
18 |
19 | def q_sample(self, x_start, t, noise=None):
20 | noise = default(noise, lambda: torch.randn_like(x_start))
21 | return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
22 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
23 |
24 |
25 | "Does not support DDPM sampling anymore. Only do DDIM or PLMS"
26 |
27 | # = = = = = = = = = = = = Below is for sampling = = = = = = = = = = = = #
28 |
29 | # def predict_start_from_noise(self, x_t, t, noise):
30 | # return ( extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
31 | # extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise )
32 |
33 | # def q_posterior(self, x_start, x_t, t):
34 | # posterior_mean = (
35 | # extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
36 | # extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
37 | # )
38 | # posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
39 | # posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
40 | # return posterior_mean, posterior_variance, posterior_log_variance_clipped
41 |
42 |
43 | # def p_mean_variance(self, model, x, c, t):
44 |
45 | # model_out = model(x, t, c)
46 | # x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
47 |
48 | # if self.clip_denoised:
49 | # x_recon.clamp_(-1., 1.)
50 |
51 | # model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
52 | # return model_mean, posterior_variance, posterior_log_variance, x_recon
53 |
54 |
55 | # @torch.no_grad()
56 | # def p_sample(self, model, x, c, t):
57 | # b, *_, device = *x.shape, x.device
58 | # model_mean, _, model_log_variance, x0 = self.p_mean_variance(model, x=x, c=c, t=t, )
59 | # noise = torch.randn_like(x)
60 |
61 | # # no noise when t == 0
62 | # nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
63 |
64 | # return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
65 |
66 |
67 | # @torch.no_grad()
68 | # def p_sample_loop(self, model, shape, c):
69 | # device = self.betas.device
70 | # b = shape[0]
71 | # img = torch.randn(shape, device=device)
72 |
73 | # iterator = tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps)
74 | # for i in iterator:
75 | # ts = torch.full((b,), i, device=device, dtype=torch.long)
76 | # img, x0 = self.p_sample(model, img, c, ts)
77 |
78 | # return img
79 |
80 |
81 | # @torch.no_grad()
82 | # def sample(self, model, shape, c, uc=None, guidance_scale=None):
83 | # return self.p_sample_loop(model, shape, c)
84 |
85 |
86 |
87 |
88 |
89 |
--------------------------------------------------------------------------------
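A small sketch of what q_sample above computes, x_t = sqrt(alphabar_t) * x_0 + sqrt(1 - alphabar_t) * eps: since sqrt(alphabar_t)^2 + sqrt(1 - alphabar_t)^2 = 1, a unit-variance latent stays roughly unit-variance at every timestep (shapes below are assumptions for illustration only):

    # Hedged sketch: forward diffusion with the default schedule registered by DDPM.
    import torch
    from ldm.models.diffusion.ldm import LatentDiffusion

    ld = LatentDiffusion()                        # inherits the schedule buffers from DDPM
    x0 = torch.randn(8, 4, 32, 32)                # assumed latent shape
    t = torch.full((8,), 500, dtype=torch.long)   # mid-chain timestep
    xt = ld.q_sample(x0, t)
    print(round(xt.std().item(), 2))              # ~1.0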
/ldm/models/diffusion/plms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from tqdm import tqdm
4 | from functools import partial
5 | from copy import deepcopy
6 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
7 |
8 |
9 | class PLMSSampler(object):
10 | def __init__(self, diffusion, model, schedule="linear", alpha_generator_func=None, set_alpha_scale=None):
11 | super().__init__()
12 | self.diffusion = diffusion
13 | self.model = model
14 | self.device = diffusion.betas.device
15 | self.ddpm_num_timesteps = diffusion.num_timesteps
16 | self.schedule = schedule
17 | self.alpha_generator_func = alpha_generator_func
18 | self.set_alpha_scale = set_alpha_scale
19 |
20 | def register_buffer(self, name, attr):
21 | if type(attr) == torch.Tensor:
22 | attr = attr.to(self.device)
23 | setattr(self, name, attr)
24 |
25 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=False):
26 | if ddim_eta != 0:
27 | raise ValueError('ddim_eta must be 0 for PLMS')
28 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
29 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
30 | alphas_cumprod = self.diffusion.alphas_cumprod
31 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
32 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device)
33 |
34 | self.register_buffer('betas', to_torch(self.diffusion.betas))
35 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
36 | self.register_buffer('alphas_cumprod_prev', to_torch(self.diffusion.alphas_cumprod_prev))
37 |
38 | # calculations for diffusion q(x_t | x_{t-1}) and others
39 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
40 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
41 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
42 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
43 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
44 |
45 | # ddim sampling parameters
46 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
47 | ddim_timesteps=self.ddim_timesteps,
48 | eta=ddim_eta,verbose=verbose)
49 | self.register_buffer('ddim_sigmas', ddim_sigmas)
50 | self.register_buffer('ddim_alphas', ddim_alphas)
51 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
52 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
53 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
54 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
55 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
56 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
57 |
58 |
59 | @torch.no_grad()
60 | def sample(self, S, shape, input, uc_dict=None, guidance_scale=1, mask=None, x0=None):
61 | self.make_schedule(ddim_num_steps=S)
62 | return self.plms_sampling(shape, input, uc_dict, guidance_scale, mask=mask, x0=x0)
63 |
64 |
65 | @torch.no_grad()
66 | def plms_sampling(self, shape, input, uc_dict=None, guidance_scale=1, mask=None, x0=None):
67 |
68 | b = shape[0]
69 |
70 | img = input["x"]
71 |         if img is None:
72 | img = torch.randn(shape, device=self.device)
73 | input["x"] = img
74 |
75 | time_range = np.flip(self.ddim_timesteps)
76 | total_steps = self.ddim_timesteps.shape[0]
77 |
78 | old_eps = []
79 |
80 | if self.alpha_generator_func != None:
81 | alphas = self.alpha_generator_func(len(time_range))
82 |
83 | for i, step in enumerate(time_range):
84 | print(f"PLMS sampling step : ({step}/{len(time_range)})")
85 | # set alpha
86 | if self.alpha_generator_func != None:
87 | self.set_alpha_scale(self.model, alphas[i])
88 |
89 | # run
90 | index = total_steps - i - 1
91 | ts = torch.full((b,), step, device=self.device, dtype=torch.long)
92 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=self.device, dtype=torch.long)
93 |
94 | if mask is not None:
95 | assert x0 is not None
96 | img_orig = self.diffusion.q_sample(x0, ts)
97 | img = img_orig * mask + (1. - mask) * img
98 | input["x"] = img
99 |
100 | img, pred_x0, e_t = self.p_sample_plms(input, ts, index=index, uc_dict=uc_dict, guidance_scale=guidance_scale, old_eps=old_eps, t_next=ts_next)
101 | input["x"] = img
102 | old_eps.append(e_t)
103 | if len(old_eps) >= 4:
104 | old_eps.pop(0)
105 |
106 | return img
107 |
108 |
109 | @torch.no_grad()
110 | def p_sample_plms(self, input, t, index, guidance_scale=1., uc_dict=None, old_eps=None, t_next=None):
111 | x = deepcopy(input["x"])
112 | b = x.shape[0]
113 |
114 | def get_model_output(input):
115 | e_t = self.model(input)
116 | if uc_dict is not None and guidance_scale != 1:
117 | unconditional_input = dict(
118 | x=input["x"],
119 | timesteps=input["timesteps"],
120 | context=uc_dict['context'],
121 | inpainting_extra_input=input["inpainting_extra_input"],
122 | condition=uc_dict['condition'],
123 | )
124 | e_t_uncond = self.model(unconditional_input)
125 | e_t = e_t_uncond + guidance_scale * (e_t - e_t_uncond)
126 | return e_t
127 |
128 | def get_x_prev_and_pred_x0(e_t, index):
129 | # select parameters corresponding to the currently considered timestep
130 | a_t = torch.full((b, 1, 1, 1), self.ddim_alphas[index], device=self.device)
131 | a_prev = torch.full((b, 1, 1, 1), self.ddim_alphas_prev[index], device=self.device)
132 | sigma_t = torch.full((b, 1, 1, 1), self.ddim_sigmas[index], device=self.device)
133 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), self.ddim_sqrt_one_minus_alphas[index],device=self.device)
134 |
135 | # current prediction for x_0
136 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
137 |
138 | # direction pointing to x_t
139 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
140 | noise = sigma_t * torch.randn_like(x)
141 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
142 | return x_prev, pred_x0
143 |
144 | input["timesteps"] = t
145 | e_t = get_model_output(input)
146 | if len(old_eps) == 0:
147 | # Pseudo Improved Euler (2nd order)
148 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
149 | input["x"] = x_prev
150 | input["timesteps"] = t_next
151 | e_t_next = get_model_output(input)
152 | e_t_prime = (e_t + e_t_next) / 2
153 | elif len(old_eps) == 1:
154 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
155 | e_t_prime = (3 * e_t - old_eps[-1]) / 2
156 | elif len(old_eps) == 2:
157 |             # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
158 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
159 | elif len(old_eps) >= 3:
160 |             # 4th order Pseudo Linear Multistep (Adams-Bashforth)
161 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
162 |
163 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
164 |
165 | return x_prev, pred_x0, e_t
166 |
--------------------------------------------------------------------------------
/ldm/models/diffusion/plmsg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from tqdm import tqdm
4 | from functools import partial
5 | from copy import deepcopy
6 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
7 |
8 |
9 | class PLMSGSampler(object):
10 | def __init__(self, diffusion, model, schedule="linear", alpha_generator_func=None, set_alpha_scale=None):
11 | super().__init__()
12 | self.diffusion = diffusion
13 | self.model = model
14 | self.device = diffusion.betas.device
15 | self.ddpm_num_timesteps = diffusion.num_timesteps
16 | self.schedule = schedule
17 | self.alpha_generator_func = alpha_generator_func
18 | self.set_alpha_scale = set_alpha_scale
19 |
20 | def register_buffer(self, name, attr):
21 | if type(attr) == torch.Tensor:
22 | attr = attr.to(self.device)
23 | setattr(self, name, attr)
24 |
25 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=False):
26 | if ddim_eta != 0:
27 | raise ValueError('ddim_eta must be 0 for PLMS')
28 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
29 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
30 | alphas_cumprod = self.diffusion.alphas_cumprod
31 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
32 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device)
33 |
34 | self.register_buffer('betas', to_torch(self.diffusion.betas))
35 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
36 | self.register_buffer('alphas_cumprod_prev', to_torch(self.diffusion.alphas_cumprod_prev))
37 |
38 | # calculations for diffusion q(x_t | x_{t-1}) and others
39 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
40 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
41 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
42 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
43 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
44 |
45 | # ddim sampling parameters
46 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
47 | ddim_timesteps=self.ddim_timesteps,
48 | eta=ddim_eta,verbose=verbose)
49 | self.register_buffer('ddim_sigmas', ddim_sigmas)
50 | self.register_buffer('ddim_alphas', ddim_alphas)
51 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
52 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
53 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
54 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
55 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
56 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
57 |
58 |
59 | @torch.no_grad()
60 | def sample(self, S, shape, input, uc_dict=None, guidance_scale=1, guidance_mode_scale=[1], mask=None, x0=None):
61 | self.make_schedule(ddim_num_steps=S)
62 | return self.plms_sampling(shape, input, uc_dict, guidance_scale, guidance_mode_scale, mask=mask, x0=x0)
63 |
64 |
65 | @torch.no_grad()
66 | def plms_sampling(self, shape, input, uc_dict=None, guidance_scale=1, guidance_mode_scale=[1], mask=None, x0=None):
67 |
68 | b = shape[0]
69 |
70 | img = input["x"]
71 |         if img is None:
72 | img = torch.randn(shape, device=self.device)
73 | input["x"] = img
74 |
75 | time_range = np.flip(self.ddim_timesteps)
76 | total_steps = self.ddim_timesteps.shape[0]
77 |
78 | old_eps = []
79 |
80 | if self.alpha_generator_func != None:
81 | alphas = self.alpha_generator_func(len(time_range))
82 |
83 | for i, step in enumerate(time_range):
84 | print(f"PLMS sampling step : ({step}/{len(time_range)})")
85 | # set alpha
86 | if self.alpha_generator_func != None:
87 | self.set_alpha_scale(self.model, alphas[i])
88 |
89 | # run
90 | index = total_steps - i - 1
91 | ts = torch.full((b,), step, device=self.device, dtype=torch.long)
92 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=self.device, dtype=torch.long)
93 |
94 | if mask is not None:
95 | assert x0 is not None
96 | img_orig = self.diffusion.q_sample(x0, ts)
97 | img = img_orig * mask + (1. - mask) * img
98 | input["x"] = img
99 |
100 | img, pred_x0, e_t = self.p_sample_plms(input, ts, index=index, uc_dict=uc_dict, guidance_scale=guidance_scale, guidance_mode_scale=guidance_mode_scale, old_eps=old_eps, t_next=ts_next)
101 | input["x"] = img
102 | old_eps.append(e_t)
103 | if len(old_eps) >= 4:
104 | old_eps.pop(0)
105 |
106 | return img
107 |
108 |
109 | @torch.no_grad()
110 | def p_sample_plms(self, input, t, index, guidance_scale=1., guidance_mode_scale=[1], uc_dict=None, old_eps=None, t_next=None):
111 | x = deepcopy(input["x"])
112 | b = x.shape[0]
113 |
114 | def get_model_output(input):
115 | e_t = self.model(input)
116 | if uc_dict is not None and guidance_scale != 1:
117 | unconditional_input = dict(
118 | x=input["x"],
119 | timesteps=input["timesteps"],
120 | context=uc_dict['context'],
121 | inpainting_extra_input=input["inpainting_extra_input"],
122 | condition=uc_dict['condition'],
123 | )
124 | unconditional_mode_input = dict(
125 | x=input["x"],
126 | timesteps=input["timesteps"],
127 | context=input['context'],
128 | inpainting_extra_input=input["inpainting_extra_input"],
129 | condition=uc_dict['mode_condition'],
130 | )
131 |
132 | e_t_uncond = self.model(unconditional_input)
133 | e_t_uncond_mode = self.model(unconditional_mode_input)
134 | mode_guidance = []
135 | for i, scale in enumerate(guidance_mode_scale):
136 | mode_guidance.append(scale * (e_t[i] - e_t_uncond_mode[i]))
137 | mode_guidance = torch.stack(mode_guidance)
138 | e_t = e_t_uncond + guidance_scale * (e_t - e_t_uncond) + mode_guidance
139 | return e_t
140 |
141 | def get_x_prev_and_pred_x0(e_t, index):
142 | # select parameters corresponding to the currently considered timestep
143 | a_t = torch.full((b, 1, 1, 1), self.ddim_alphas[index], device=self.device)
144 | a_prev = torch.full((b, 1, 1, 1), self.ddim_alphas_prev[index], device=self.device)
145 | sigma_t = torch.full((b, 1, 1, 1), self.ddim_sigmas[index], device=self.device)
146 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), self.ddim_sqrt_one_minus_alphas[index],device=self.device)
147 |
148 | # current prediction for x_0
149 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
150 |
151 | # direction pointing to x_t
152 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
153 | noise = sigma_t * torch.randn_like(x)
154 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
155 | return x_prev, pred_x0
156 |
157 | input["timesteps"] = t
158 | e_t = get_model_output(input)
159 | if len(old_eps) == 0:
160 | # Pseudo Improved Euler (2nd order)
161 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
162 | input["x"] = x_prev
163 | input["timesteps"] = t_next
164 | e_t_next = get_model_output(input)
165 | e_t_prime = (e_t + e_t_next) / 2
166 | elif len(old_eps) == 1:
167 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
168 | e_t_prime = (3 * e_t - old_eps[-1]) / 2
169 | elif len(old_eps) == 2:
170 |             # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
171 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
172 | elif len(old_eps) >= 3:
173 |             # 4th order Pseudo Linear Multistep (Adams-Bashforth)
174 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
175 |
176 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
177 |
178 | return x_prev, pred_x0, e_t
179 |
--------------------------------------------------------------------------------
/ldm/modules/attention.py:
--------------------------------------------------------------------------------
1 | from inspect import isfunction
2 | import math
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn, einsum
6 | from einops import rearrange, repeat
7 |
8 | # from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder
9 | from torch.utils import checkpoint
10 |
11 | def exists(val):
12 | return val is not None
13 |
14 |
15 | def uniq(arr):
16 |     return {el: True for el in arr}.keys()
17 |
18 |
19 | def default(val, d):
20 | if exists(val):
21 | return val
22 | return d() if isfunction(d) else d
23 |
24 |
25 | def max_neg_value(t):
26 | return -torch.finfo(t.dtype).max
27 |
28 |
29 | def init_(tensor):
30 | dim = tensor.shape[-1]
31 | std = 1 / math.sqrt(dim)
32 | tensor.uniform_(-std, std)
33 | return tensor
34 |
35 |
36 | # feedforward
37 | class GEGLU(nn.Module):
38 | def __init__(self, dim_in, dim_out):
39 | super().__init__()
40 | self.proj = nn.Linear(dim_in, dim_out * 2)
41 |
42 | def forward(self, x):
43 | x, gate = self.proj(x).chunk(2, dim=-1)
44 | return x * F.gelu(gate)
45 |
46 |
47 | class FeedForward(nn.Module):
48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
49 | super().__init__()
50 | inner_dim = int(dim * mult)
51 | dim_out = default(dim_out, dim)
52 | project_in = nn.Sequential(
53 | nn.Linear(dim, inner_dim),
54 | nn.GELU()
55 | ) if not glu else GEGLU(dim, inner_dim)
56 |
57 | self.net = nn.Sequential(
58 | project_in,
59 | nn.Dropout(dropout),
60 | nn.Linear(inner_dim, dim_out)
61 | )
62 |
63 | def forward(self, x):
64 | return self.net(x)
65 |
66 |
67 | def zero_module(module):
68 | """
69 | Zero out the parameters of a module and return it.
70 | """
71 | for p in module.parameters():
72 | p.detach().zero_()
73 | return module
74 |
75 |
76 | def Normalize(in_channels):
77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
78 |
79 |
80 | class LinearAttention(nn.Module):
81 | def __init__(self, dim, heads=4, dim_head=32):
82 | super().__init__()
83 | self.heads = heads
84 | hidden_dim = dim_head * heads
85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1)
87 |
88 | def forward(self, x):
89 | b, c, h, w = x.shape
90 | qkv = self.to_qkv(x)
91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
92 | k = k.softmax(dim=-1)
93 | context = torch.einsum('bhdn,bhen->bhde', k, v)
94 | out = torch.einsum('bhde,bhdn->bhen', context, q)
95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
96 | return self.to_out(out)
97 |
98 |
99 |
100 |
101 |
102 | class CrossAttention(nn.Module):
103 | def __init__(self, query_dim, key_dim, value_dim, heads=8, dim_head=64, dropout=0):
104 | super().__init__()
105 | inner_dim = dim_head * heads
106 | self.scale = dim_head ** -0.5
107 | self.heads = heads
108 |
109 |
110 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
111 | self.to_k = nn.Linear(key_dim, inner_dim, bias=False)
112 | self.to_v = nn.Linear(value_dim, inner_dim, bias=False)
113 |
114 |
115 | self.to_out = nn.Sequential( nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) )
116 |
117 |
118 | def fill_inf_from_mask(self, sim, mask):
119 | if mask is not None:
120 | B,M = mask.shape
121 | mask = mask.unsqueeze(1).repeat(1,self.heads,1).reshape(B*self.heads,1,-1)
122 | max_neg_value = -torch.finfo(sim.dtype).max
123 | sim.masked_fill_(~mask, max_neg_value)
124 | return sim
125 |
126 |
127 | def forward(self, x, key, value, mask=None):
128 |
129 | q = self.to_q(x) # B*N*(H*C)
130 | k = self.to_k(key) # B*M*(H*C)
131 | v = self.to_v(value) # B*M*(H*C)
132 |
133 | B, N, HC = q.shape
134 | _, M, _ = key.shape
135 | H = self.heads
136 | C = HC // H
137 |
138 | q = q.view(B,N,H,C).permute(0,2,1,3).reshape(B*H,N,C) # (B*H)*N*C
139 | k = k.view(B,M,H,C).permute(0,2,1,3).reshape(B*H,M,C) # (B*H)*M*C
140 | v = v.view(B,M,H,C).permute(0,2,1,3).reshape(B*H,M,C) # (B*H)*M*C
141 |
142 | sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale # (B*H)*N*M
143 | self.fill_inf_from_mask(sim, mask)
144 | attn = sim.softmax(dim=-1) # (B*H)*N*M
145 |
146 | out = torch.einsum('b i j, b j d -> b i d', attn, v) # (B*H)*N*C
147 | out = out.view(B,H,N,C).permute(0,2,1,3).reshape(B,N,(H*C)) # B*N*(H*C)
148 |
149 | return self.to_out(out)
150 |
151 |
152 |
153 |
154 | class SelfAttention(nn.Module):
155 | def __init__(self, query_dim, heads=8, dim_head=64, dropout=0.):
156 | super().__init__()
157 | inner_dim = dim_head * heads
158 | self.scale = dim_head ** -0.5
159 | self.heads = heads
160 |
161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
162 | self.to_k = nn.Linear(query_dim, inner_dim, bias=False)
163 | self.to_v = nn.Linear(query_dim, inner_dim, bias=False)
164 |
165 | self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) )
166 |
167 | def forward(self, x):
168 | q = self.to_q(x) # B*N*(H*C)
169 | k = self.to_k(x) # B*N*(H*C)
170 | v = self.to_v(x) # B*N*(H*C)
171 |
172 | B, N, HC = q.shape
173 | H = self.heads
174 | C = HC // H
175 |
176 | q = q.view(B,N,H,C).permute(0,2,1,3).reshape(B*H,N,C) # (B*H)*N*C
177 | k = k.view(B,N,H,C).permute(0,2,1,3).reshape(B*H,N,C) # (B*H)*N*C
178 | v = v.view(B,N,H,C).permute(0,2,1,3).reshape(B*H,N,C) # (B*H)*N*C
179 |
180 | sim = torch.einsum('b i c, b j c -> b i j', q, k) * self.scale # (B*H)*N*N
181 | attn = sim.softmax(dim=-1) # (B*H)*N*N
182 |
183 | out = torch.einsum('b i j, b j c -> b i c', attn, v) # (B*H)*N*C
184 | out = out.view(B,H,N,C).permute(0,2,1,3).reshape(B,N,(H*C)) # B*N*(H*C)
185 |
186 | return self.to_out(out)
187 |
188 |
189 |
190 | class GatedCrossAttentionDense(nn.Module):
191 | def __init__(self, query_dim, key_dim, value_dim, n_heads, d_head):
192 | super().__init__()
193 |
194 | self.attn = CrossAttention(query_dim=query_dim, key_dim=key_dim, value_dim=value_dim, heads=n_heads, dim_head=d_head)
195 | self.ff = FeedForward(query_dim, glu=True)
196 |
197 | self.norm1 = nn.LayerNorm(query_dim)
198 | self.norm2 = nn.LayerNorm(query_dim)
199 |
200 | self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)) )
201 | self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)) )
202 |
203 |         # this can be useful: we can externally change the magnitude of tanh(alpha)
204 |         # for example, when it is set to 0, the entire model is the same as the original one
205 | self.scale = 1
206 |
207 | def forward(self, x, objs):
208 |
209 | x = x + self.scale*torch.tanh(self.alpha_attn) * self.attn( self.norm1(x), objs, objs)
210 | x = x + self.scale*torch.tanh(self.alpha_dense) * self.ff( self.norm2(x) )
211 |
212 | return x
213 |
214 |
215 | class GatedSelfAttentionDense(nn.Module):
216 | def __init__(self, query_dim, context_dim, n_heads, d_head):
217 | super().__init__()
218 |
219 |         # we need a linear projection since we concatenate the visual features and object features
220 | self.linear = nn.Linear(context_dim, query_dim)
221 |
222 | self.attn = SelfAttention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
223 | self.ff = FeedForward(query_dim, glu=True)
224 |
225 | self.norm1 = nn.LayerNorm(query_dim)
226 | self.norm2 = nn.LayerNorm(query_dim)
227 |
228 | self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)) )
229 | self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)) )
230 |
231 |         # this can be useful: we can externally change the magnitude of tanh(alpha)
232 |         # for example, when it is set to 0, the entire model is the same as the original one
233 | self.scale = 1
234 |
235 |
236 | def forward(self, x, objs):
237 |
238 | N_visual = x.shape[1]
239 | objs = self.linear(objs)
240 |
241 | x = x + self.scale*torch.tanh(self.alpha_attn) * self.attn( self.norm1(torch.cat([x,objs],dim=1)) )[:,0:N_visual,:]
242 | x = x + self.scale*torch.tanh(self.alpha_dense) * self.ff( self.norm2(x) )
243 |
244 | return x
245 |
246 |
247 | class BasicTransformerBlock(nn.Module):
248 | def __init__(self, query_dim, key_dim, value_dim, n_heads, d_head, fuser_type, use_checkpoint=True):
249 | super().__init__()
250 | self.attn1 = SelfAttention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
251 | self.ff = FeedForward(query_dim, glu=True)
252 | self.attn2 = CrossAttention(query_dim=query_dim, key_dim=key_dim, value_dim=value_dim, heads=n_heads, dim_head=d_head)
253 | self.norm1 = nn.LayerNorm(query_dim)
254 | self.norm2 = nn.LayerNorm(query_dim)
255 | self.norm3 = nn.LayerNorm(query_dim)
256 | self.use_checkpoint = use_checkpoint
257 |
258 | if fuser_type == "gatedSA":
259 |             # note: key_dim here is actually context_dim
260 | self.fuser = GatedSelfAttentionDense(query_dim, key_dim, n_heads, d_head)
261 | elif fuser_type == "gatedCA":
262 | self.fuser = GatedCrossAttentionDense(query_dim, key_dim, value_dim, n_heads, d_head)
263 | else:
264 | assert False
265 |
266 |
267 | def forward(self, x, context, objs):
268 | # return checkpoint(self._forward, (x, context, objs), self.parameters(), self.use_checkpoint)
269 | if self.use_checkpoint and x.requires_grad:
270 | return checkpoint.checkpoint(self._forward, x, context, objs)
271 | else:
272 | return self._forward(x, context, objs)
273 |
274 | def _forward(self, x, context, objs):
275 | x = self.attn1( self.norm1(x) ) + x
276 | x = self.fuser(x, objs) # identity mapping in the beginning
277 | x = self.attn2(self.norm2(x), context, context) + x
278 | x = self.ff(self.norm3(x)) + x
279 | return x
280 |
281 |
282 | class SpatialTransformer(nn.Module):
283 | def __init__(self, in_channels, key_dim, value_dim, n_heads, d_head, depth=1, fuser_type=None, use_checkpoint=True):
284 | super().__init__()
285 | self.in_channels = in_channels
286 | query_dim = n_heads * d_head
287 | self.norm = Normalize(in_channels)
288 |
289 |
290 | self.proj_in = nn.Conv2d(in_channels,
291 | query_dim,
292 | kernel_size=1,
293 | stride=1,
294 | padding=0)
295 |
296 | self.transformer_blocks = nn.ModuleList(
297 | [BasicTransformerBlock(query_dim, key_dim, value_dim, n_heads, d_head, fuser_type, use_checkpoint=use_checkpoint)
298 | for d in range(depth)]
299 | )
300 |
301 | self.proj_out = zero_module(nn.Conv2d(query_dim,
302 | in_channels,
303 | kernel_size=1,
304 | stride=1,
305 | padding=0))
306 |
307 | def forward(self, x, context, objs):
308 | b, c, h, w = x.shape
309 | x_in = x
310 | x = self.norm(x)
311 | x = self.proj_in(x)
312 | x = rearrange(x, 'b c h w -> b (h w) c')
313 | for block in self.transformer_blocks:
314 | x = block(x, context, objs)
315 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
316 | x = self.proj_out(x)
317 | return x + x_in
--------------------------------------------------------------------------------
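A shape-level sketch of the attention blocks above (dimensions are assumptions chosen for illustration, not repository configs):

    # Hedged sketch: CrossAttention maps B*N*query_dim queries against B*M context tokens;
    # GatedSelfAttentionDense is an identity mapping at init because alpha_attn/alpha_dense
    # start at zero (tanh(0) = 0).
    import torch
    from ldm.modules.attention import CrossAttention, GatedSelfAttentionDense

    attn = CrossAttention(query_dim=320, key_dim=768, value_dim=768, heads=8, dim_head=40)
    x = torch.randn(2, 64 * 64, 320)              # visual tokens (assumed 64x64 grid)
    ctx = torch.randn(2, 77, 768)                 # e.g. text tokens (assumed length 77)
    print(attn(x, ctx, ctx).shape)                # torch.Size([2, 4096, 320])

    fuser = GatedSelfAttentionDense(query_dim=320, context_dim=768, n_heads=8, d_head=40)
    objs = torch.randn(2, 30, 768)                # grounding tokens (assumed count)
    print(fuser(x, objs).shape)                   # torch.Size([2, 4096, 320])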
/ldm/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sungnyun/diffblender/b3481d17d0e13d89d45dcbb9a250e89459904031/ldm/modules/diffusionmodules/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/condition_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from ldm.modules.diffusionmodules.util import FourierEmbedder
5 |
6 | from ldm.modules.diffusionmodules.util import (
7 | conv_nd,
8 | zero_module,
9 | normalization,
10 | )
11 |
12 |
13 | def conv_nd(dims, *args, **kwargs):
14 | """
15 | Create a 1D, 2D, or 3D convolution module.
16 | """
17 | if dims == 1:
18 | return nn.Conv1d(*args, **kwargs)
19 | elif dims == 2:
20 | return nn.Conv2d(*args, **kwargs)
21 | elif dims == 3:
22 | return nn.Conv3d(*args, **kwargs)
23 | raise ValueError(f"unsupported dimensions: {dims}")
24 |
25 |
26 | def avg_pool_nd(dims, *args, **kwargs):
27 | """
28 | Create a 1D, 2D, or 3D average pooling module.
29 | """
30 | if dims == 1:
31 | return nn.AvgPool1d(*args, **kwargs)
32 | elif dims == 2:
33 | return nn.AvgPool2d(*args, **kwargs)
34 | elif dims == 3:
35 | return nn.AvgPool3d(*args, **kwargs)
36 | raise ValueError(f"unsupported dimensions: {dims}")
37 |
38 |
39 | class Downsample(nn.Module):
40 | """
41 | A downsampling layer with an optional convolution.
42 | :param channels: channels in the inputs and outputs.
43 | :param use_conv: a bool determining if a convolution is applied.
44 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
45 | downsampling occurs in the inner-two dimensions.
46 | """
47 |
48 | def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
49 | super().__init__()
50 | self.channels = channels
51 | self.out_channels = out_channels or channels
52 | self.use_conv = use_conv
53 | self.dims = dims
54 | stride = 2 if dims != 3 else (1, 2, 2)
55 | if use_conv:
56 | self.op = conv_nd(
57 | dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
58 | )
59 | else:
60 | assert self.channels == self.out_channels
61 | self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
62 |
63 | def forward(self, x):
64 | assert x.shape[1] == self.channels
65 | return self.op(x)
66 |
67 |
68 | class ResnetBlock(nn.Module):
69 | def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
70 | super().__init__()
71 | ps = ksize // 2
72 | if in_c != out_c or sk == False:
73 | self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
74 | else:
75 | self.in_conv = None
76 |
77 | self.norm = normalization(out_c)
78 |
79 | self.body = nn.Sequential(
80 | conv_nd(2, out_c, out_c, 3, padding=1),
81 | nn.SiLU(),
82 |             conv_nd(2, out_c, out_c, ksize, padding=ps),  # padding=ps keeps H, W so the residual additions below line up
83 | )
84 | if sk == False:
85 | self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps)
86 | else:
87 | self.skep = None
88 |
89 | self.down = down
90 | if self.down == True:
91 | self.down_opt = Downsample(in_c, use_conv=use_conv)
92 |
93 | def forward(self, x):
94 | if self.down == True:
95 | x = self.down_opt(x)
96 | if self.in_conv is not None: # edit
97 | x = self.in_conv(x)
98 |
99 | x = self.norm(x)
100 | h = self.body(x)
101 | if self.skep is not None:
102 | return h + self.skep(x)
103 | else:
104 | return h + x
105 |
106 |
107 | class ImageConditionNet(nn.Module):
108 | def __init__(self, autoencoder=None, channels=[320, 640, 1280, 1280], nums_rb=3, cin=4, ksize=3, sk=False, use_conv=True):
109 | super(ImageConditionNet, self).__init__()
110 | self.autoencoder = autoencoder
111 | self.set = False
112 |
113 | self.channels = channels
114 | self.nums_rb = nums_rb
115 |
116 | self.in_layers = nn.Sequential(
117 | nn.InstanceNorm2d(cin, affine=True),
118 | conv_nd(2, cin, channels[0], 3, padding=1),
119 | )
120 |
121 | self.body = []
122 | self.out_convs = []
123 | for i in range(len(channels)):
124 | for j in range(nums_rb):
125 | if (i != 0) and (j == 0):
126 | self.body.append(
127 | ResnetBlock(channels[i - 1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv))
128 | else:
129 | self.body.append(
130 | ResnetBlock(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
131 | self.out_convs.append(
132 | zero_module(
133 | conv_nd(2, channels[i], channels[i], 1, padding=0, bias=False)
134 | )
135 | )
136 |
137 | self.body = nn.ModuleList(self.body)
138 | self.out_convs = nn.ModuleList(self.out_convs)
139 |
140 | def forward(self, input_dict, h=None):
141 | assert self.set
142 |
143 | x, masks = input_dict['values'], input_dict['masks']
144 |
145 | with torch.no_grad():
146 | x = self.autoencoder.encode(x)
147 |
148 | # extract features
149 | features = []
150 |
151 | x = self.in_layers(x)
152 |
153 | for i in range(len(self.channels)):
154 | for j in range(self.nums_rb):
155 | idx = i * self.nums_rb + j
156 | x = self.body[idx](x)
157 | x = self.out_convs[i](x)
158 | features.append(x)
159 |
160 |
161 | return features
162 |
163 |
164 | class AutoencoderKLWrapper(nn.Module):
165 | def __init__(self, autoencoder=None):
166 | super().__init__()
167 | self.autoencoder = autoencoder
168 | self.set = False
169 |
170 | def forward(self, input_dict):
171 | assert self.set
172 | x = input_dict['values']
173 | with torch.no_grad():
174 | x = self.autoencoder.encode(x)
175 | return x
176 |
177 |
178 | class NSPVectorConditionNet(nn.Module):
179 | def __init__(self, in_dim, out_dim, norm=False, fourier_freqs=0, temperature=100, scale=1.0):
180 | super().__init__()
181 | self.in_dim = in_dim
182 | self.out_dim = out_dim
183 | self.norm = norm
184 | self.fourier_freqs = fourier_freqs
185 |
186 | if fourier_freqs > 0:
187 | self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs, temperature=temperature)
188 | self.in_dim *= (fourier_freqs * 2)
189 | if self.norm:
190 | self.linears = nn.Sequential(
191 | nn.Linear(self.in_dim, 512),
192 | nn.LayerNorm(512),
193 | nn.SiLU(),
194 | nn.Linear(512, 512),
195 | nn.LayerNorm(512),
196 | nn.SiLU(),
197 | nn.Linear(512, out_dim),
198 | )
199 | else:
200 | self.linears = nn.Sequential(
201 | nn.Linear(self.in_dim, 512),
202 | nn.SiLU(),
203 | nn.Linear(512, 512),
204 | nn.SiLU(),
205 | nn.Linear(512, out_dim),
206 | )
207 |
208 | self.null_features = torch.nn.Parameter(torch.zeros([self.in_dim]))
209 | self.scale = scale
210 |
211 | def forward(self, input_dict):
212 | vectors, masks = input_dict['values'], input_dict['masks'] # vectors: B*C, masks: B*1
213 | if self.fourier_freqs > 0:
214 | vectors = self.fourier_embedder(vectors * self.scale)
215 | objs = masks * vectors + (1-masks) * self.null_features.view(1,-1)
216 | return self.linears(objs).unsqueeze(1) # B*1*C
217 |
218 |
219 | class BoxConditionNet(nn.Module):
220 | def __init__(self, in_dim, out_dim, fourier_freqs=8):
221 | super().__init__()
222 | self.in_dim = in_dim
223 | self.out_dim = out_dim
224 |
225 | self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
226 | self.position_dim = fourier_freqs*2*4 # 2 is sin&cos, 4 is xyxy
227 |
228 | self.linears = nn.Sequential(
229 | nn.Linear(self.in_dim + self.position_dim, 512),
230 | nn.SiLU(),
231 | nn.Linear(512, 512),
232 | nn.SiLU(),
233 | nn.Linear(512, out_dim),
234 | )
235 |
236 | self.null_text_feature = torch.nn.Parameter(torch.zeros([self.in_dim]))
237 | self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
238 |
239 |
240 | def forward(self, input_dict):
241 | boxes, masks, text_embeddings = input_dict['values'], input_dict['masks'], input_dict['text_embeddings']
242 |
243 | B, N, _ = boxes.shape
244 | masks = masks.unsqueeze(-1)
245 |
246 |         # embed positions (the input may include padding as a placeholder)
247 | xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C
248 |
249 | # learnable null embedding
250 | text_null = self.null_text_feature.view(1,1,-1)
251 | xyxy_null = self.null_position_feature.view(1,1,-1)
252 |
253 | # replace padding with learnable null embedding
254 | text_embeddings = text_embeddings*masks + (1-masks)*text_null
255 | xyxy_embedding = xyxy_embedding*masks + (1-masks)*xyxy_null
256 |
257 | objs = self.linears( torch.cat([text_embeddings, xyxy_embedding], dim=-1) )
258 | assert objs.shape == torch.Size([B,N,self.out_dim])
259 | return objs
260 |
261 |
262 | class KeypointConditionNet(nn.Module):
263 | def __init__(self, max_persons_per_image, out_dim, fourier_freqs=8):
264 | super().__init__()
265 | self.max_persons_per_image = max_persons_per_image
266 | self.out_dim = out_dim
267 |
268 | self.person_embeddings = torch.nn.Parameter(torch.zeros([max_persons_per_image,out_dim]))
269 | self.keypoint_embeddings = torch.nn.Parameter(torch.zeros([17,out_dim]))
270 |
271 |
272 | self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
273 | self.position_dim = fourier_freqs*2*2 # 2 is sin&cos, 2 is xy
274 |
275 | self.linears = nn.Sequential(
276 | nn.Linear(self.out_dim + self.position_dim, 512),
277 | nn.SiLU(),
278 | nn.Linear(512, 512),
279 | nn.SiLU(),
280 | nn.Linear(512, out_dim),
281 | )
282 |
283 | self.null_person_feature = torch.nn.Parameter(torch.zeros([self.out_dim]))
284 | self.null_xy_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
285 |
286 |
287 | def forward(self, input_dict):
288 | points, masks = input_dict['values'], input_dict['masks']
289 |
290 | masks = masks.unsqueeze(-1)
291 | N = points.shape[0]
292 |
293 | person_embeddings = self.person_embeddings.unsqueeze(1).repeat(1,17,1).reshape(self.max_persons_per_image*17, self.out_dim)
294 | keypoint_embeddings = torch.cat([self.keypoint_embeddings]*self.max_persons_per_image, dim=0)
295 | person_embeddings = person_embeddings + keypoint_embeddings # (num_person*17) * C
296 | person_embeddings = person_embeddings.unsqueeze(0).repeat(N,1,1)
297 |
298 |         # embed positions (the input may include padding as a placeholder)
299 | xy_embedding = self.fourier_embedder(points) # B*N*2 --> B*N*C
300 |
301 |
302 | # learnable null embedding
303 | person_null = self.null_person_feature.view(1,1,-1)
304 | xy_null = self.null_xy_feature.view(1,1,-1)
305 |
306 | # replace padding with learnable null embedding
307 | person_embeddings = person_embeddings*masks + (1-masks)*person_null
308 | xy_embedding = xy_embedding*masks + (1-masks)*xy_null
309 |
310 | objs = self.linears( torch.cat([person_embeddings, xy_embedding], dim=-1) )
311 |
312 | return objs
313 |
314 |
--------------------------------------------------------------------------------
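A dummy-input sketch for two of the condition nets above (all dimensions are assumptions for illustration, not the repository's configs): both consume a dict with 'values' and 'masks' and emit token sequences for the cross-attention fusers.

    # Hedged sketch: vector- and box-type condition encoders with toy tensors.
    import torch
    from ldm.modules.diffusionmodules.condition_net import NSPVectorConditionNet, BoxConditionNet

    vec_net = NSPVectorConditionNet(in_dim=16, out_dim=768)    # in_dim is an assumed value
    vec_input = {"values": torch.randn(2, 16), "masks": torch.ones(2, 1)}
    print(vec_net(vec_input).shape)                            # torch.Size([2, 1, 768])

    box_net = BoxConditionNet(in_dim=768, out_dim=768)
    box_input = {
        "values": torch.rand(2, 30, 4),                        # xyxy boxes, padded to 30 per image
        "masks": torch.ones(2, 30),
        "text_embeddings": torch.randn(2, 30, 768),            # per-box text features
    }
    print(box_net(box_input).shape)                            # torch.Size([2, 30, 768])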
/ldm/modules/diffusionmodules/multimodal_openaimodel.py:
--------------------------------------------------------------------------------
1 | """
2 | Ack: Our code relies heavily on GLIGEN (https://github.com/gligen/GLIGEN).
3 | """
4 | from abc import abstractmethod
5 | from functools import partial
6 | import math
7 |
8 | import numpy as np
9 | import random
10 | import torch as th
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 |
14 | from ldm.modules.diffusionmodules.util import (
15 | conv_nd,
16 | linear,
17 | avg_pool_nd,
18 | zero_module,
19 | normalization,
20 | timestep_embedding,
21 | )
22 | from ldm.modules.multimodal_attention import SpatialTransformer
23 | # from .positionnet import PositionNet
24 | from torch.utils import checkpoint
25 | from ldm.util import instantiate_from_config
26 |
27 |
28 | class TimestepBlock(nn.Module):
29 | """
30 | Any module where forward() takes timestep embeddings as a second argument.
31 | """
32 |
33 | @abstractmethod
34 | def forward(self, x, emb):
35 | """
36 | Apply the module to `x` given `emb` timestep embeddings.
37 | """
38 |
39 |
40 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
41 | """
42 | A sequential module that passes timestep embeddings to the children that
43 | support it as an extra input.
44 | """
45 |
46 | def forward(self, x, emb, context, sp_objs, nsp_objs):
47 | for layer in self:
48 | if isinstance(layer, TimestepBlock):
49 | x = layer(x, emb)
50 | elif isinstance(layer, SpatialTransformer):
51 | x = layer(x, context, sp_objs, nsp_objs)
52 | else:
53 | x = layer(x)
54 | return x
55 |
56 |
57 | class Upsample(nn.Module):
58 | """
59 | An upsampling layer with an optional convolution.
60 | :param channels: channels in the inputs and outputs.
61 | :param use_conv: a bool determining if a convolution is applied.
62 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
63 | upsampling occurs in the inner-two dimensions.
64 | """
65 |
66 | def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
67 | super().__init__()
68 | self.channels = channels
69 | self.out_channels = out_channels or channels
70 | self.use_conv = use_conv
71 | self.dims = dims
72 | if use_conv:
73 | self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
74 |
75 | def forward(self, x):
76 | assert x.shape[1] == self.channels
77 | if self.dims == 3:
78 | x = F.interpolate(
79 | x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
80 | )
81 | else:
82 | x = F.interpolate(x, scale_factor=2, mode="nearest")
83 | if self.use_conv:
84 | x = self.conv(x)
85 | return x
86 |
87 |
88 |
89 |
90 | class Downsample(nn.Module):
91 | """
92 | A downsampling layer with an optional convolution.
93 | :param channels: channels in the inputs and outputs.
94 | :param use_conv: a bool determining if a convolution is applied.
95 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
96 | downsampling occurs in the inner-two dimensions.
97 | """
98 |
99 | def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
100 | super().__init__()
101 | self.channels = channels
102 | self.out_channels = out_channels or channels
103 | self.use_conv = use_conv
104 | self.dims = dims
105 | stride = 2 if dims != 3 else (1, 2, 2)
106 | if use_conv:
107 | self.op = conv_nd(
108 | dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
109 | )
110 | else:
111 | assert self.channels == self.out_channels
112 | self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
113 |
114 | def forward(self, x):
115 | assert x.shape[1] == self.channels
116 | return self.op(x)
117 |
118 |
119 | class ResBlock(TimestepBlock):
120 | """
121 | A residual block that can optionally change the number of channels.
122 | :param channels: the number of input channels.
123 | :param emb_channels: the number of timestep embedding channels.
124 | :param dropout: the rate of dropout.
125 | :param out_channels: if specified, the number of out channels.
126 | :param use_conv: if True and out_channels is specified, use a spatial
127 | convolution instead of a smaller 1x1 convolution to change the
128 | channels in the skip connection.
129 | :param dims: determines if the signal is 1D, 2D, or 3D.
130 | :param use_checkpoint: if True, use gradient checkpointing on this module.
131 | :param up: if True, use this block for upsampling.
132 | :param down: if True, use this block for downsampling.
133 | """
134 |
135 | def __init__(
136 | self,
137 | channels,
138 | emb_channels,
139 | dropout,
140 | out_channels=None,
141 | use_conv=False,
142 | use_scale_shift_norm=False,
143 | dims=2,
144 | use_checkpoint=False,
145 | up=False,
146 | down=False,
147 | ):
148 | super().__init__()
149 | self.channels = channels
150 | self.emb_channels = emb_channels
151 | self.dropout = dropout
152 | self.out_channels = out_channels or channels
153 | self.use_conv = use_conv
154 | self.use_checkpoint = use_checkpoint
155 | self.use_scale_shift_norm = use_scale_shift_norm
156 |
157 | self.in_layers = nn.Sequential(
158 | normalization(channels),
159 | nn.SiLU(),
160 | conv_nd(dims, channels, self.out_channels, 3, padding=1),
161 | )
162 |
163 | self.updown = up or down
164 |
165 | if up:
166 | self.h_upd = Upsample(channels, False, dims)
167 | self.x_upd = Upsample(channels, False, dims)
168 | elif down:
169 | self.h_upd = Downsample(channels, False, dims)
170 | self.x_upd = Downsample(channels, False, dims)
171 | else:
172 | self.h_upd = self.x_upd = nn.Identity()
173 |
174 | self.emb_layers = nn.Sequential(
175 | nn.SiLU(),
176 | linear(
177 | emb_channels,
178 | 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
179 | ),
180 | )
181 | self.out_layers = nn.Sequential(
182 | normalization(self.out_channels),
183 | nn.SiLU(),
184 | nn.Dropout(p=dropout),
185 | zero_module(
186 | conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
187 | ),
188 | )
189 |
190 | if self.out_channels == channels:
191 | self.skip_connection = nn.Identity()
192 | elif use_conv:
193 | self.skip_connection = conv_nd(
194 | dims, channels, self.out_channels, 3, padding=1
195 | )
196 | else:
197 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
198 |
199 | def forward(self, x, emb):
200 | """
201 | Apply the block to a Tensor, conditioned on a timestep embedding.
202 | :param x: an [N x C x ...] Tensor of features.
203 | :param emb: an [N x emb_channels] Tensor of timestep embeddings.
204 | :return: an [N x C x ...] Tensor of outputs.
205 | """
206 | # return checkpoint(
207 | # self._forward, (x, emb), self.parameters(), self.use_checkpoint
208 | # )
209 | if self.use_checkpoint and x.requires_grad:
210 | return checkpoint.checkpoint(self._forward, x, emb )
211 | else:
212 | return self._forward(x, emb)
213 |
214 |
215 | def _forward(self, x, emb):
216 | if self.updown:
217 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
218 | h = in_rest(x)
219 | h = self.h_upd(h)
220 | x = self.x_upd(x)
221 | h = in_conv(h)
222 | else:
223 | h = self.in_layers(x)
224 | emb_out = self.emb_layers(emb).type(h.dtype)
225 | while len(emb_out.shape) < len(h.shape):
226 | emb_out = emb_out[..., None]
227 | if self.use_scale_shift_norm:
228 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
229 | scale, shift = th.chunk(emb_out, 2, dim=1)
230 | h = out_norm(h) * (1 + scale) + shift
231 | h = out_rest(h)
232 | else:
233 | h = h + emb_out
234 | h = self.out_layers(h)
235 | return self.skip_connection(x) + h
236 |
237 |
238 |
239 |
240 | class UNetModel(nn.Module):
241 | def __init__(
242 | self,
243 | image_size,
244 | in_channels,
245 | model_channels,
246 | out_channels,
247 | num_res_blocks,
248 | attention_resolutions,
249 | dropout=0,
250 | channel_mult=(1, 2, 4, 8),
251 | conv_resample=True,
252 | dims=2,
253 | use_checkpoint=False,
254 | num_heads=8,
255 | use_scale_shift_norm=False,
256 | transformer_depth=1,
257 | context_dim=None,
258 | fuser_type = None,
259 | inpaint_mode = False,
260 | grounding_tokenizer = None,
261 | init_alpha_pre_input_conv=0.1,
262 | use_autoencoder_kl=False,
263 | image_cond_injection_type=None,
264 | input_modalities=[],
265 | input_types=[],
266 | freeze_modules=[],
267 | ):
268 | super().__init__()
269 |
270 | self.image_size = image_size
271 | self.in_channels = in_channels
272 | self.model_channels = model_channels
273 | self.out_channels = out_channels
274 | self.num_res_blocks = num_res_blocks
275 | self.attention_resolutions = attention_resolutions
276 | self.dropout = dropout
277 | self.channel_mult = channel_mult
278 | self.conv_resample = conv_resample
279 | self.use_checkpoint = use_checkpoint
280 | self.num_heads = num_heads
281 | self.context_dim = context_dim
282 | self.fuser_type = fuser_type
283 | self.inpaint_mode = inpaint_mode
284 | self.freeze_modules = freeze_modules
285 | assert fuser_type in ["gatedSA", "gatedCA", "gatedSA-gatedCA", "gatedCA-gatedSA"]
286 |
287 | self.input_modalities = input_modalities
288 | self.input_types = input_types
289 | self.use_autoencoder_kl = use_autoencoder_kl
290 | self.image_cond_injection_type = image_cond_injection_type
291 | assert self.image_cond_injection_type is not None
292 |
293 |
294 | time_embed_dim = model_channels * 4
295 | self.time_embed = nn.Sequential(
296 | linear(model_channels, time_embed_dim),
297 | nn.SiLU(),
298 | linear(time_embed_dim, time_embed_dim),
299 | )
300 |
301 | num_image_condition = self.input_types.count("image")
302 | num_sp_condition = self.input_types.count("sp_vector")
303 | use_sp = num_sp_condition > 0
304 | num_nsp_condition = self.input_types.count("nsp_vector")
305 | use_nsp = num_nsp_condition > 0
306 |
307 | if num_image_condition >= 1:
308 | pass
309 |
310 | if inpaint_mode:
311 |             # The newly added channels are the masked (encoded) image and the mask, i.e. 4+1 extra channels
312 | self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, in_channels+in_channels+1, model_channels, 3, padding=1))])
313 | else:
314 | """ Enlarged mode"""
315 | # self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, (num_image_condition+1)*in_channels, model_channels, 3, padding=1))])
316 | """ Non-enlarged mode"""
317 | self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))])
318 |
319 |
320 | input_block_chans = [model_channels]
321 | ch = model_channels
322 | ds = 1
323 |
324 | # = = = = = = = = = = = = = = = = = = = = Down Branch = = = = = = = = = = = = = = = = = = = = #
325 | for level, mult in enumerate(channel_mult):
326 | for _ in range(num_res_blocks):
327 | layers = [ ResBlock(ch,
328 | time_embed_dim,
329 | dropout,
330 | out_channels=mult * model_channels,
331 | dims=dims,
332 | use_checkpoint=use_checkpoint,
333 | use_scale_shift_norm=use_scale_shift_norm,) ]
334 |
335 | ch = mult * model_channels
336 | if ds in attention_resolutions:
337 | dim_head = ch // num_heads
338 | layers.append(SpatialTransformer(ch, key_dim=context_dim, value_dim=context_dim, n_heads=num_heads, d_head=dim_head, depth=transformer_depth, fuser_type=fuser_type, use_checkpoint=use_checkpoint, use_sp=use_sp, use_nsp=use_nsp))
339 |
340 | self.input_blocks.append(TimestepEmbedSequential(*layers))
341 | input_block_chans.append(ch)
342 |
343 |             if level != len(channel_mult) - 1: # no downsampling after the last resolution level
344 | out_ch = ch
345 | self.input_blocks.append( TimestepEmbedSequential( Downsample(ch, conv_resample, dims=dims, out_channels=out_ch ) ) )
346 | ch = out_ch
347 | input_block_chans.append(ch)
348 | ds *= 2
349 | dim_head = ch // num_heads
350 |
351 |
352 | # = = = = = = = = = = = = = = = = = = = = BottleNeck = = = = = = = = = = = = = = = = = = = = #
353 |
354 | self.middle_block = TimestepEmbedSequential(
355 | ResBlock(ch,
356 | time_embed_dim,
357 | dropout,
358 | dims=dims,
359 | use_checkpoint=use_checkpoint,
360 | use_scale_shift_norm=use_scale_shift_norm),
361 | SpatialTransformer(ch, key_dim=context_dim, value_dim=context_dim, n_heads=num_heads, d_head=dim_head, depth=transformer_depth, fuser_type=fuser_type, use_checkpoint=use_checkpoint, use_sp=use_sp, use_nsp=use_nsp),
362 | ResBlock(ch,
363 | time_embed_dim,
364 | dropout,
365 | dims=dims,
366 | use_checkpoint=use_checkpoint,
367 | use_scale_shift_norm=use_scale_shift_norm))
368 |
369 |
370 |
371 | # = = = = = = = = = = = = = = = = = = = = Up Branch = = = = = = = = = = = = = = = = = = = = #
372 |
373 |
374 | self.output_blocks = nn.ModuleList([])
375 | for level, mult in list(enumerate(channel_mult))[::-1]:
376 | for i in range(num_res_blocks + 1):
377 | ich = input_block_chans.pop()
378 | layers = [ ResBlock(ch + ich,
379 | time_embed_dim,
380 | dropout,
381 | out_channels=model_channels * mult,
382 | dims=dims,
383 | use_checkpoint=use_checkpoint,
384 | use_scale_shift_norm=use_scale_shift_norm) ]
385 | ch = model_channels * mult
386 |
387 | if ds in attention_resolutions:
388 | dim_head = ch // num_heads
389 | layers.append( SpatialTransformer(ch, key_dim=context_dim, value_dim=context_dim, n_heads=num_heads, d_head=dim_head, depth=transformer_depth, fuser_type=fuser_type, use_checkpoint=use_checkpoint, use_sp=use_sp, use_nsp=use_nsp) )
390 | if level and i == num_res_blocks:
391 | out_ch = ch
392 | layers.append( Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) )
393 | ds //= 2
394 |
395 | self.output_blocks.append(TimestepEmbedSequential(*layers))
396 |
397 |
398 |
399 | self.out = nn.Sequential(
400 | normalization(ch),
401 | nn.SiLU(),
402 | zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
403 | )
404 |
405 |
406 | # = = = = = = = = = = = = = = = = = = = = Multimodal Condition Networks = = = = = = = = = = = = = = = = = = = = #
407 |
408 | self.condition_nets = nn.ModuleDict()
409 | for mode in self.input_modalities:
410 | self.condition_nets[mode] = instantiate_from_config(grounding_tokenizer["tokenizer_{}".format(mode)])
411 |
412 | self.scales = [1.0]*4
413 |
414 |
415 | def forward(self, input_dict):
416 | condition = input_dict["condition"]
417 |
418 |         # aggregate the condition tokens by input type
419 | im_objs, sp_objs, nsp_objs = [], [], []
420 | for mode, input_type in zip(self.input_modalities, self.input_types):
421 | assert mode in condition
422 |
423 | objs = self.condition_nets[mode](condition[mode])
424 |
425 | if input_type == "image":
426 | im_objs.append(objs) # B*C*H*W
427 | elif input_type == "sp_vector":
428 | sp_objs.append(objs) # B*N*C
429 | elif input_type == "nsp_vector":
430 | nsp_objs.append(objs) # B*1*C
431 | else:
432 | raise NotImplementedError
433 |
434 |         # aggregate image-form conditions: sum the per-scale feature maps across image modalities
435 | im_objs = [th.stack(arr, dim=0).sum(0) for arr in zip(*im_objs)] if len(im_objs) > 0 else None
436 |
437 | sp_objs = th.cat(sp_objs, dim=1) if len(sp_objs)>0 else None
438 | nsp_objs = th.cat(nsp_objs, dim=1) if len(nsp_objs)>0 else None
439 |
440 | # Time embedding
441 | t_emb = timestep_embedding(input_dict["timesteps"], self.model_channels, repeat_only=False)
442 | emb = self.time_embed(t_emb)
443 |
444 | # input tensor
445 | h = input_dict["x"]
446 |
447 | if self.inpaint_mode:
448 | h = th.cat( [h, input_dict["inpainting_extra_input"]], dim=1 )
449 |
450 | # Text input
451 | context = input_dict["context"]
452 |
453 | # Start forwarding
454 | hs = []
455 | adapter_idx = 0
456 | for i, module in enumerate(self.input_blocks):
457 | if self.image_cond_injection_type == "enc" and (i+1) % 3 == 0 and im_objs is not None:
458 | h = module(h, emb, context, sp_objs, nsp_objs)
459 | h = h + self.scales[adapter_idx] * im_objs[adapter_idx]
460 | adapter_idx += 1
461 | else:
462 | h = module(h, emb, context, sp_objs, nsp_objs)
463 | hs.append(h)
464 |
465 | h = self.middle_block(h, emb, context, sp_objs, nsp_objs)
466 |
467 | adapter_idx = 0
468 | for i, module in enumerate(self.output_blocks):
469 | if self.image_cond_injection_type == "dec" and i % 3 == 0 and im_objs is not None:
470 | enc_h = hs.pop() + self.scales[adapter_idx] * im_objs.pop()
471 | adapter_idx += 1
472 | else:
473 | enc_h = hs.pop()
474 | h = th.cat([h, enc_h], dim=1)
475 |
476 | h = module(h, emb, context, sp_objs, nsp_objs)
477 |
478 | return self.out(h)
479 |
480 |
481 |
--------------------------------------------------------------------------------
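A minimal sketch (not part of the repository) of the `input_dict` that `UNetModel.forward` above consumes, inferred from the keys it reads; the modality name, tensor shapes, and the commented call are illustrative assumptions rather than project defaults, and the keys of `condition` must match the `input_modalities` the model was built with.

import torch as th

B, C, H, W = 2, 4, 64, 64                          # assumed SD-style 4-channel, 64x64 latents
input_dict = {
    "x": th.randn(B, C, H, W),                     # noisy latents at the current timestep
    "timesteps": th.randint(0, 1000, (B,)),        # one diffusion timestep index per sample
    "context": th.randn(B, 77, 768),               # text token embeddings (e.g. CLIP last_hidden_state)
    "condition": {
        # one entry per name in input_modalities; the matching condition_net turns it into
        # an image-form feature ("image"), B*N*C tokens ("sp_vector"), or a B*1*C token ("nsp_vector")
        "sketch": th.randn(B, 1, 512, 512),        # hypothetical image-form modality
    },
    # "inpainting_extra_input": ...                # only when inpaint_mode=True
}
# eps = unet(input_dict)                           # -> tensor with out_channels channels, same H/W as x
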
/ldm/modules/diffusionmodules/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 | from einops import repeat
7 |
8 | from ldm.util import instantiate_from_config
9 |
10 |
11 |
12 | class FourierEmbedder():
13 | def __init__(self, num_freqs=64, temperature=100):
14 |
15 | self.num_freqs = num_freqs
16 | self.temperature = temperature
17 | self.freq_bands = temperature ** ( torch.arange(num_freqs) / num_freqs )
18 |
19 | @ torch.no_grad()
20 | def __call__(self, x, cat_dim=-1):
21 |         "x: tensor of arbitrary shape; cat_dim: dimension along which the sin/cos features are concatenated"
22 | out = []
23 | for freq in self.freq_bands:
24 | out.append( torch.sin( freq*x ) )
25 | out.append( torch.cos( freq*x ) )
26 | return torch.cat(out, cat_dim)
27 |
28 |
29 |
30 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
31 | if schedule == "linear":
32 | betas = (
33 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
34 | )
35 |
36 | elif schedule == "cosine":
37 | timesteps = (
38 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
39 | )
40 | alphas = timesteps / (1 + cosine_s) * np.pi / 2
41 | alphas = torch.cos(alphas).pow(2)
42 | alphas = alphas / alphas[0]
43 | betas = 1 - alphas[1:] / alphas[:-1]
44 | betas = np.clip(betas, a_min=0, a_max=0.999)
45 |
46 | elif schedule == "sqrt_linear":
47 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
48 | elif schedule == "sqrt":
49 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
50 | else:
51 | raise ValueError(f"schedule '{schedule}' unknown.")
52 | return betas.numpy()
53 |
54 |
55 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
56 | if ddim_discr_method == 'uniform':
57 | c = num_ddpm_timesteps // num_ddim_timesteps
58 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
59 | elif ddim_discr_method == 'quad':
60 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
61 | else:
62 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
63 |
64 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps
65 | # add one to get the final alpha values right (the ones from first scale to data during sampling)
66 | steps_out = ddim_timesteps + 1
67 | if verbose:
68 | print(f'Selected timesteps for ddim sampler: {steps_out}')
69 | return steps_out
70 |
71 |
72 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
73 | # select alphas for computing the variance schedule
74 | alphas = alphacums[ddim_timesteps]
75 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
76 |
77 |     # according to the formula provided in https://arxiv.org/abs/2010.02502
78 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
79 | if verbose:
80 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
81 | print(f'For the chosen value of eta, which is {eta}, '
82 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
83 | return sigmas, alphas, alphas_prev
84 |
85 |
86 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
87 | """
88 | Create a beta schedule that discretizes the given alpha_t_bar function,
89 | which defines the cumulative product of (1-beta) over time from t = [0,1].
90 | :param num_diffusion_timesteps: the number of betas to produce.
91 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
92 | produces the cumulative product of (1-beta) up to that
93 | part of the diffusion process.
94 | :param max_beta: the maximum beta to use; use values lower than 1 to
95 | prevent singularities.
96 | """
97 | betas = []
98 | for i in range(num_diffusion_timesteps):
99 | t1 = i / num_diffusion_timesteps
100 | t2 = (i + 1) / num_diffusion_timesteps
101 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
102 | return np.array(betas)
103 |
104 |
105 | def extract_into_tensor(a, t, x_shape):
106 | b, *_ = t.shape
107 | out = a.gather(-1, t)
108 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
109 |
110 |
111 | def checkpoint(func, inputs, params, flag):
112 | """
113 | Evaluate a function without caching intermediate activations, allowing for
114 | reduced memory at the expense of extra compute in the backward pass.
115 | :param func: the function to evaluate.
116 | :param inputs: the argument sequence to pass to `func`.
117 | :param params: a sequence of parameters `func` depends on but does not
118 | explicitly take as arguments.
119 | :param flag: if False, disable gradient checkpointing.
120 | """
121 | if flag:
122 | args = tuple(inputs) + tuple(params)
123 | return CheckpointFunction.apply(func, len(inputs), *args)
124 | else:
125 | return func(*inputs)
126 |
127 |
128 | class CheckpointFunction(torch.autograd.Function):
129 | @staticmethod
130 | def forward(ctx, run_function, length, *args):
131 | ctx.run_function = run_function
132 | ctx.input_tensors = list(args[:length])
133 | ctx.input_params = list(args[length:])
134 |
135 | with torch.no_grad():
136 | output_tensors = ctx.run_function(*ctx.input_tensors)
137 | return output_tensors
138 |
139 | @staticmethod
140 | def backward(ctx, *output_grads):
141 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
142 | with torch.enable_grad():
143 | # Fixes a bug where the first op in run_function modifies the
144 | # Tensor storage in place, which is not allowed for detach()'d
145 | # Tensors.
146 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
147 | output_tensors = ctx.run_function(*shallow_copies)
148 | input_grads = torch.autograd.grad(
149 | output_tensors,
150 | ctx.input_tensors + ctx.input_params,
151 | output_grads,
152 | allow_unused=True,
153 | )
154 | del ctx.input_tensors
155 | del ctx.input_params
156 | del output_tensors
157 | return (None, None) + input_grads
158 |
159 |
160 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
161 | """
162 | Create sinusoidal timestep embeddings.
163 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
164 | These may be fractional.
165 | :param dim: the dimension of the output.
166 | :param max_period: controls the minimum frequency of the embeddings.
167 | :return: an [N x dim] Tensor of positional embeddings.
168 | """
169 | if not repeat_only:
170 | half = dim // 2
171 | freqs = torch.exp(
172 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
173 | ).to(device=timesteps.device)
174 | args = timesteps[:, None].float() * freqs[None]
175 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
176 | if dim % 2:
177 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
178 | else:
179 | embedding = repeat(timesteps, 'b -> b d', d=dim)
180 | return embedding
181 |
182 |
183 | def zero_module(module):
184 | """
185 | Zero out the parameters of a module and return it.
186 | """
187 | for p in module.parameters():
188 | p.detach().zero_()
189 | return module
190 |
191 |
192 | def scale_module(module, scale):
193 | """
194 | Scale the parameters of a module and return it.
195 | """
196 | for p in module.parameters():
197 | p.detach().mul_(scale)
198 | return module
199 |
200 |
201 | def mean_flat(tensor):
202 | """
203 | Take the mean over all non-batch dimensions.
204 | """
205 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
206 |
207 |
208 | def normalization(channels):
209 | """
210 | Make a standard normalization layer.
211 | :param channels: number of input channels.
212 | :return: an nn.Module for normalization.
213 | """
214 | return GroupNorm32(32, channels)
215 |
216 |
217 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
218 | class SiLU(nn.Module):
219 | def forward(self, x):
220 | return x * torch.sigmoid(x)
221 |
222 |
223 | class GroupNorm32(nn.GroupNorm):
224 | def forward(self, x):
225 | return super().forward(x.float()).type(x.dtype)
226 | #return super().forward(x).type(x.dtype)
227 |
228 | def conv_nd(dims, *args, **kwargs):
229 | """
230 | Create a 1D, 2D, or 3D convolution module.
231 | """
232 | if dims == 1:
233 | return nn.Conv1d(*args, **kwargs)
234 | elif dims == 2:
235 | return nn.Conv2d(*args, **kwargs)
236 | elif dims == 3:
237 | return nn.Conv3d(*args, **kwargs)
238 | raise ValueError(f"unsupported dimensions: {dims}")
239 |
240 |
241 | def linear(*args, **kwargs):
242 | """
243 | Create a linear module.
244 | """
245 | return nn.Linear(*args, **kwargs)
246 |
247 |
248 | def avg_pool_nd(dims, *args, **kwargs):
249 | """
250 | Create a 1D, 2D, or 3D average pooling module.
251 | """
252 | if dims == 1:
253 | return nn.AvgPool1d(*args, **kwargs)
254 | elif dims == 2:
255 | return nn.AvgPool2d(*args, **kwargs)
256 | elif dims == 3:
257 | return nn.AvgPool3d(*args, **kwargs)
258 | raise ValueError(f"unsupported dimensions: {dims}")
259 |
260 |
261 | class HybridConditioner(nn.Module):
262 |
263 | def __init__(self, c_concat_config, c_crossattn_config):
264 | super().__init__()
265 | self.concat_conditioner = instantiate_from_config(c_concat_config)
266 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
267 |
268 | def forward(self, c_concat, c_crossattn):
269 | c_concat = self.concat_conditioner(c_concat)
270 | c_crossattn = self.crossattn_conditioner(c_crossattn)
271 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
272 |
273 |
274 | def noise_like(shape, device, repeat=False):
275 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
276 | noise = lambda: torch.randn(shape, device=device)
277 | return repeat_noise() if repeat else noise()
--------------------------------------------------------------------------------
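A small sanity-check sketch for three of the helpers above; it assumes the repository root is on PYTHONPATH so that `ldm` is importable, and the shapes in the comments follow directly from the definitions.

import torch
from ldm.modules.diffusionmodules.util import FourierEmbedder, timestep_embedding, make_beta_schedule

fourier = FourierEmbedder(num_freqs=8)
boxes = torch.rand(2, 30, 4)                       # e.g. normalized box coordinates
print(fourier(boxes).shape)                        # torch.Size([2, 30, 64]): 4 coords * 8 freqs * (sin, cos)

t = torch.tensor([0, 500, 999])
print(timestep_embedding(t, dim=320).shape)        # torch.Size([3, 320])

betas = make_beta_schedule("linear", n_timestep=1000)
print(betas.shape, betas[0], betas[-1])            # (1000,), ~1e-4, ~2e-2
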
/ldm/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sungnyun/diffblender/b3481d17d0e13d89d45dcbb9a250e89459904031/ldm/modules/distributions/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
45 | + self.var - 1.0 - self.logvar,
46 | dim=[1, 2, 3])
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
51 | dim=[1, 2, 3])
52 |
53 | def nll(self, sample, dims=[1,2,3]):
54 | if self.deterministic:
55 | return torch.Tensor([0.])
56 | logtwopi = np.log(2.0 * np.pi)
57 | return 0.5 * torch.sum(
58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 | dim=dims)
60 |
61 | def mode(self):
62 | return self.mean
63 |
64 |
65 | def normal_kl(mean1, logvar1, mean2, logvar2):
66 | """
67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68 | Compute the KL divergence between two gaussians.
69 | Shapes are automatically broadcasted, so batches can be compared to
70 | scalars, among other use cases.
71 | """
72 | tensor = None
73 | for obj in (mean1, logvar1, mean2, logvar2):
74 | if isinstance(obj, torch.Tensor):
75 | tensor = obj
76 | break
77 | assert tensor is not None, "at least one argument must be a Tensor"
78 |
79 | # Force variances to be Tensors. Broadcasting helps convert scalars to
80 | # Tensors, but it does not work for torch.exp().
81 | logvar1, logvar2 = [
82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83 | for x in (logvar1, logvar2)
84 | ]
85 |
86 | return 0.5 * (
87 | -1.0
88 | + logvar2
89 | - logvar1
90 | + torch.exp(logvar1 - logvar2)
91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92 | )
93 |
--------------------------------------------------------------------------------
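A brief usage sketch for DiagonalGaussianDistribution as it is typically applied to an encoder output whose channels hold the mean and log-variance; the 8-channel tensor below is an illustrative stand-in, and the repository is assumed to be on PYTHONPATH.

import torch
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

moments = torch.randn(2, 8, 32, 32)        # 8 = mean (4 latent channels) + logvar (4 latent channels)
posterior = DiagonalGaussianDistribution(moments)

z = posterior.sample()                     # stochastic latent, shape (2, 4, 32, 32)
z_mode = posterior.mode()                  # deterministic latent (the mean)
kl = posterior.kl()                        # per-sample KL to a standard normal, shape (2,)
print(z.shape, z_mode.shape, kl.shape)
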
/ldm/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError('Decay must be between 0 and 1')
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
14 | else torch.tensor(-1,dtype=torch.int))
15 |
16 | for name, p in model.named_parameters():
17 | if p.requires_grad:
18 | #remove as '.'-character is not allowed in buffers
19 | s_name = name.replace('.','')
20 | self.m_name2s_name.update({name:s_name})
21 | self.register_buffer(s_name,p.clone().detach().data)
22 |
23 | self.collected_params = []
24 |
25 | def forward(self,model):
26 | decay = self.decay
27 |
28 | if self.num_updates >= 0:
29 | self.num_updates += 1
30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
31 |
32 | one_minus_decay = 1.0 - decay
33 |
34 | with torch.no_grad():
35 | m_param = dict(model.named_parameters())
36 | shadow_params = dict(self.named_buffers())
37 |
38 | for key in m_param:
39 | if m_param[key].requires_grad:
40 | sname = self.m_name2s_name[key]
41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
43 | else:
44 |                     assert key not in self.m_name2s_name
45 |
46 | def copy_to(self, model):
47 | m_param = dict(model.named_parameters())
48 | shadow_params = dict(self.named_buffers())
49 | for key in m_param:
50 | if m_param[key].requires_grad:
51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
52 | else:
53 |                 assert key not in self.m_name2s_name
54 |
55 | def store(self, parameters):
56 | """
57 | Save the current parameters for restoring later.
58 | Args:
59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
60 | temporarily stored.
61 | """
62 | self.collected_params = [param.clone() for param in parameters]
63 |
64 | def restore(self, parameters):
65 | """
66 | Restore the parameters stored with the `store` method.
67 | Useful to validate the model with EMA parameters without affecting the
68 | original optimization process. Store the parameters before the
69 | `copy_to` method. After validation (or model saving), use this to
70 | restore the former parameters.
71 | Args:
72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
73 | updated with the stored parameters.
74 | """
75 | for c_param, param in zip(self.collected_params, parameters):
76 | param.data.copy_(c_param.data)
77 |
--------------------------------------------------------------------------------
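A minimal sketch of the usual LitEma workflow (update the shadow weights after each optimizer step, then temporarily swap them in for evaluation); the toy linear model and the fake "training step" are illustrative only.

import torch
from torch import nn
from ldm.modules.ema import LitEma

model = nn.Linear(4, 4)
ema = LitEma(model, decay=0.999)

for _ in range(3):
    model.weight.data.add_(0.01)         # stand-in for an optimizer update
    ema(model)                           # update the shadow (EMA) parameters

ema.store(model.parameters())            # remember the current training weights
ema.copy_to(model)                       # evaluate/save with the EMA weights ...
ema.restore(model.parameters())          # ... then restore the training weights
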
/ldm/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sungnyun/diffblender/b3481d17d0e13d89d45dcbb9a250e89459904031/ldm/modules/encoders/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/encoders/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from functools import partial
4 | import clip
5 | from einops import rearrange, repeat
6 | from transformers import CLIPTokenizer, CLIPTextModel
7 | import kornia
8 |
9 | from ldm.modules.x_transformer import Encoder, TransformerWrapper
10 |
11 |
12 | class AbstractEncoder(nn.Module):
13 | def __init__(self):
14 | super().__init__()
15 |
16 | def encode(self, *args, **kwargs):
17 | raise NotImplementedError
18 |
19 |
20 |
21 | class ClassEmbedder(nn.Module):
22 | def __init__(self, embed_dim, n_classes=1000, key='class'):
23 | super().__init__()
24 | self.key = key
25 | self.embedding = nn.Embedding(n_classes, embed_dim)
26 |
27 | def forward(self, batch, key=None):
28 | if key is None:
29 | key = self.key
30 | # this is for use in crossattn
31 | c = batch[key][:, None]
32 | c = self.embedding(c)
33 | return c
34 |
35 |
36 | class TransformerEmbedder(AbstractEncoder):
37 | """Some transformer encoder layers"""
38 | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
39 | super().__init__()
40 | self.device = device
41 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
42 | attn_layers=Encoder(dim=n_embed, depth=n_layer))
43 |
44 | def forward(self, tokens):
45 | tokens = tokens.to(self.device) # meh
46 | z = self.transformer(tokens, return_embeddings=True)
47 | return z
48 |
49 | def encode(self, x):
50 | return self(x)
51 |
52 |
53 | class BERTTokenizer(AbstractEncoder):
54 | """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
55 | def __init__(self, device="cuda", vq_interface=True, max_length=77):
56 | super().__init__()
57 |         from transformers import BertTokenizerFast  # TODO: add to requirements
58 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
59 | self.device = device
60 | self.vq_interface = vq_interface
61 | self.max_length = max_length
62 |
63 | def forward(self, text):
64 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
65 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt",
66 | return_offsets_mapping=True)
67 | tokens = batch_encoding["input_ids"].to(self.device)
68 | offset_mapping = batch_encoding["offset_mapping"]
69 | return tokens, offset_mapping
70 |
71 | @torch.no_grad()
72 | def encode(self, text):
73 | tokens = self(text)
74 | if not self.vq_interface:
75 | return tokens
76 | return None, None, [None, None, tokens]
77 |
78 | def decode(self, text):
79 | return text
80 |
81 |
82 | class BERTEmbedder(AbstractEncoder):
83 |     """Uses the BERT tokenizer model and adds some transformer encoder layers"""
84 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
85 | device="cuda",use_tokenizer=True, embedding_dropout=0.0):
86 | super().__init__()
87 | self.use_tknz_fn = use_tokenizer
88 | if self.use_tknz_fn:
89 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
90 | self.device = device
91 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
92 | attn_layers=Encoder(dim=n_embed, depth=n_layer),
93 | emb_dropout=embedding_dropout)
94 |
95 | def forward(self, text, return_offset_mapping=False):
96 | if self.use_tknz_fn:
97 | tokens, offset_mapping = self.tknz_fn(text)#.to(self.device)
98 | else:
99 | assert False
100 | tokens = text
101 | z = self.transformer(tokens, return_embeddings=True)
102 |
103 | if return_offset_mapping:
104 | return z, offset_mapping
105 | else:
106 | return z
107 |
108 | def encode(self, text, return_offset_mapping=False):
109 | # output of length 77
110 | return self(text, return_offset_mapping)
111 |
112 |
113 | class SpatialRescaler(nn.Module):
114 | def __init__(self,
115 | n_stages=1,
116 | method='bilinear',
117 | multiplier=0.5,
118 | in_channels=3,
119 | out_channels=None,
120 | bias=False):
121 | super().__init__()
122 | self.n_stages = n_stages
123 | assert self.n_stages >= 0
124 | assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
125 | self.multiplier = multiplier
126 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
127 | self.remap_output = out_channels is not None
128 | if self.remap_output:
129 | print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
130 | self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
131 |
132 | def forward(self,x):
133 | for stage in range(self.n_stages):
134 | x = self.interpolator(x, scale_factor=self.multiplier)
135 |
136 |
137 | if self.remap_output:
138 | x = self.channel_mapper(x)
139 | return x
140 |
141 | def encode(self, x):
142 | return self(x)
143 |
144 | class FrozenCLIPEmbedder(AbstractEncoder):
145 | """Uses the CLIP transformer encoder for text (from Hugging Face)"""
146 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
147 | super().__init__()
148 | self.tokenizer = CLIPTokenizer.from_pretrained(version)
149 | self.transformer = CLIPTextModel.from_pretrained(version)
150 | self.device = device
151 | self.max_length = max_length
152 | self.freeze()
153 |
154 | def freeze(self):
155 | self.transformer = self.transformer.eval()
156 | for param in self.parameters():
157 | param.requires_grad = False
158 |
159 | def forward(self, text, return_pooler_output=False):
160 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
161 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
162 | tokens = batch_encoding["input_ids"].to(self.device)
163 | outputs = self.transformer(input_ids=tokens)
164 |
165 | z = outputs.last_hidden_state
166 |
167 | if not return_pooler_output:
168 | return z
169 | else:
170 | return z, outputs.pooler_output
171 |
172 | def encode(self, text, return_pooler_output=False):
173 | return self(text, return_pooler_output)
174 |
175 |
176 | class FrozenCLIPTextEmbedder(nn.Module):
177 | """
178 | Uses the CLIP transformer encoder for text.
179 | """
180 | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
181 | super().__init__()
182 | self.model, _ = clip.load(version, jit=False, device="cpu")
183 | self.device = device
184 | self.max_length = max_length
185 | self.n_repeat = n_repeat
186 | self.normalize = normalize
187 |
188 | def freeze(self):
189 | self.model = self.model.eval()
190 | for param in self.parameters():
191 | param.requires_grad = False
192 |
193 | def forward(self, text):
194 | tokens = clip.tokenize(text).to(self.device)
195 | z = self.model.encode_text(tokens)
196 | if self.normalize:
197 | z = z / torch.linalg.norm(z, dim=1, keepdim=True)
198 | return z
199 |
200 | def encode(self, text):
201 | z = self(text)
202 | if z.ndim==2:
203 | z = z[:, None, :]
204 | z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
205 | return z
206 |
207 |
208 | class FrozenClipImageEmbedder(nn.Module):
209 | """
210 | Uses the CLIP image encoder.
211 | """
212 | def __init__(
213 | self,
214 | model,
215 | jit=False,
216 | device='cuda' if torch.cuda.is_available() else 'cpu',
217 | antialias=False,
218 | ):
219 | super().__init__()
220 | self.model, _ = clip.load(name=model, device=device, jit=jit)
221 |
222 | self.antialias = antialias
223 |
224 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
225 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
226 |
227 | def preprocess(self, x):
228 | # normalize to [0,1]
229 | x = kornia.geometry.resize(x, (224, 224),
230 | interpolation='bicubic',align_corners=True,
231 | antialias=self.antialias)
232 | x = (x + 1.) / 2.
233 | # renormalize according to clip
234 | x = kornia.enhance.normalize(x, self.mean, self.std)
235 | return x
236 |
237 | def forward(self, x):
238 | # x is assumed to be in range [-1,1]
239 | return self.model.encode_image(self.preprocess(x))
240 |
241 |
242 | if __name__ == "__main__":
243 | from ldm.util import count_params
244 | model = FrozenCLIPEmbedder()
245 | count_params(model, verbose=True)
246 |
--------------------------------------------------------------------------------
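A short usage sketch for FrozenCLIPEmbedder; it assumes the repository and its requirements (transformers, clip, kornia) are installed and downloads the Hugging Face CLIP text model on first use. The 77x768 output shape follows from the default version and max_length; the prompts are illustrative.

import torch
from ldm.modules.encoders.modules import FrozenCLIPEmbedder

device = "cuda" if torch.cuda.is_available() else "cpu"
embedder = FrozenCLIPEmbedder(device=device).to(device)

with torch.no_grad():
    z = embedder.encode(["a jeep driving through fire", "a quiet nature scene"])
print(z.shape)                           # torch.Size([2, 77, 768]) for clip-vit-large-patch14
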
/ldm/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
--------------------------------------------------------------------------------
/ldm/modules/losses/contperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
5 |
6 |
7 | class LPIPSWithDiscriminator(nn.Module):
8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
11 | disc_loss="hinge"):
12 |
13 | super().__init__()
14 | assert disc_loss in ["hinge", "vanilla"]
15 | self.kl_weight = kl_weight
16 | self.pixel_weight = pixelloss_weight
17 | self.perceptual_loss = LPIPS().eval()
18 | self.perceptual_weight = perceptual_weight
19 | # output log variance
20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
21 |
22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
23 | n_layers=disc_num_layers,
24 | use_actnorm=use_actnorm
25 | ).apply(weights_init)
26 | self.discriminator_iter_start = disc_start
27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
28 | self.disc_factor = disc_factor
29 | self.discriminator_weight = disc_weight
30 | self.disc_conditional = disc_conditional
31 |
32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
33 | if last_layer is not None:
34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
36 | else:
37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
39 |
40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
42 | d_weight = d_weight * self.discriminator_weight
43 | return d_weight
44 |
45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
46 | global_step, last_layer=None, cond=None, split="train",
47 | weights=None):
48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
49 | if self.perceptual_weight > 0:
50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
51 | rec_loss = rec_loss + self.perceptual_weight * p_loss
52 |
53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
54 | weighted_nll_loss = nll_loss
55 | if weights is not None:
56 | weighted_nll_loss = weights*nll_loss
57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
59 | kl_loss = posteriors.kl()
60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
61 |
62 | # now the GAN part
63 | if optimizer_idx == 0:
64 | # generator update
65 | if cond is None:
66 | assert not self.disc_conditional
67 | logits_fake = self.discriminator(reconstructions.contiguous())
68 | else:
69 | assert self.disc_conditional
70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
71 | g_loss = -torch.mean(logits_fake)
72 |
73 | if self.disc_factor > 0.0:
74 | try:
75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
76 | except RuntimeError:
77 | assert not self.training
78 | d_weight = torch.tensor(0.0)
79 | else:
80 | d_weight = torch.tensor(0.0)
81 |
82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
84 |
85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
86 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
87 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
88 | "{}/d_weight".format(split): d_weight.detach(),
89 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
90 | "{}/g_loss".format(split): g_loss.detach().mean(),
91 | }
92 | return loss, log
93 |
94 | if optimizer_idx == 1:
95 | # second pass for discriminator update
96 | if cond is None:
97 | logits_real = self.discriminator(inputs.contiguous().detach())
98 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
99 | else:
100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
102 |
103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
105 |
106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
107 | "{}/logits_real".format(split): logits_real.detach().mean(),
108 | "{}/logits_fake".format(split): logits_fake.detach().mean()
109 | }
110 | return d_loss, log
111 |
112 |
--------------------------------------------------------------------------------
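A rough sketch of how LPIPSWithDiscriminator is driven from an autoencoder training loop with two optimizers (generator pass, then discriminator pass). It assumes the taming package is installed; the hyperparameters and the training_step wrapper are illustrative, and inputs, reconstructions, posterior, and last_layer would come from an AutoencoderKL-style model.

from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator

loss_fn = LPIPSWithDiscriminator(disc_start=50001, kl_weight=1e-6, disc_weight=0.5)

def training_step(inputs, reconstructions, posterior, last_layer, global_step, optimizer_idx):
    # optimizer_idx 0 -> reconstruction/KL/generator loss, optimizer_idx 1 -> discriminator loss
    loss, log = loss_fn(
        inputs, reconstructions, posterior, optimizer_idx,
        global_step, last_layer=last_layer, split="train",
    )
    return loss, log
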
/ldm/modules/losses/vqperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from einops import repeat
5 |
6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7 | from taming.modules.losses.lpips import LPIPS
8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
9 | from ldm.util import exists  # 'exists' is used in VQLPIPSWithDiscriminator.forward but was not imported
10 |
11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
15 | loss_real = (weights * loss_real).sum() / weights.sum()
16 | loss_fake = (weights * loss_fake).sum() / weights.sum()
17 | d_loss = 0.5 * (loss_real + loss_fake)
18 | return d_loss
19 |
20 | def adopt_weight(weight, global_step, threshold=0, value=0.):
21 | if global_step < threshold:
22 | weight = value
23 | return weight
24 |
25 |
26 | def measure_perplexity(predicted_indices, n_embed):
27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
28 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
30 | avg_probs = encodings.mean(0)
31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
32 | cluster_use = torch.sum(avg_probs > 0)
33 | return perplexity, cluster_use
34 |
35 | def l1(x, y):
36 | return torch.abs(x-y)
37 |
38 |
39 | def l2(x, y):
40 | return torch.pow((x-y), 2)
41 |
42 |
43 | class VQLPIPSWithDiscriminator(nn.Module):
44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
48 | pixel_loss="l1"):
49 | super().__init__()
50 | assert disc_loss in ["hinge", "vanilla"]
51 | assert perceptual_loss in ["lpips", "clips", "dists"]
52 | assert pixel_loss in ["l1", "l2"]
53 | self.codebook_weight = codebook_weight
54 | self.pixel_weight = pixelloss_weight
55 | if perceptual_loss == "lpips":
56 | print(f"{self.__class__.__name__}: Running with LPIPS.")
57 | self.perceptual_loss = LPIPS().eval()
58 | else:
59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
60 | self.perceptual_weight = perceptual_weight
61 |
62 | if pixel_loss == "l1":
63 | self.pixel_loss = l1
64 | else:
65 | self.pixel_loss = l2
66 |
67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
68 | n_layers=disc_num_layers,
69 | use_actnorm=use_actnorm,
70 | ndf=disc_ndf
71 | ).apply(weights_init)
72 | self.discriminator_iter_start = disc_start
73 | if disc_loss == "hinge":
74 | self.disc_loss = hinge_d_loss
75 | elif disc_loss == "vanilla":
76 | self.disc_loss = vanilla_d_loss
77 | else:
78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
80 | self.disc_factor = disc_factor
81 | self.discriminator_weight = disc_weight
82 | self.disc_conditional = disc_conditional
83 | self.n_classes = n_classes
84 |
85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
86 | if last_layer is not None:
87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
89 | else:
90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
92 |
93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
95 | d_weight = d_weight * self.discriminator_weight
96 | return d_weight
97 |
98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
100 | if not exists(codebook_loss):
101 | codebook_loss = torch.tensor([0.]).to(inputs.device)
102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
104 | if self.perceptual_weight > 0:
105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
106 | rec_loss = rec_loss + self.perceptual_weight * p_loss
107 | else:
108 | p_loss = torch.tensor([0.0])
109 |
110 | nll_loss = rec_loss
111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
112 | nll_loss = torch.mean(nll_loss)
113 |
114 | # now the GAN part
115 | if optimizer_idx == 0:
116 | # generator update
117 | if cond is None:
118 | assert not self.disc_conditional
119 | logits_fake = self.discriminator(reconstructions.contiguous())
120 | else:
121 | assert self.disc_conditional
122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
123 | g_loss = -torch.mean(logits_fake)
124 |
125 | try:
126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
127 | except RuntimeError:
128 | assert not self.training
129 | d_weight = torch.tensor(0.0)
130 |
131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
133 |
134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(),
136 | "{}/nll_loss".format(split): nll_loss.detach().mean(),
137 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
138 | "{}/p_loss".format(split): p_loss.detach().mean(),
139 | "{}/d_weight".format(split): d_weight.detach(),
140 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
141 | "{}/g_loss".format(split): g_loss.detach().mean(),
142 | }
143 | if predicted_indices is not None:
144 | assert self.n_classes is not None
145 | with torch.no_grad():
146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
147 | log[f"{split}/perplexity"] = perplexity
148 | log[f"{split}/cluster_usage"] = cluster_usage
149 | return loss, log
150 |
151 | if optimizer_idx == 1:
152 | # second pass for discriminator update
153 | if cond is None:
154 | logits_real = self.discriminator(inputs.contiguous().detach())
155 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
156 | else:
157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
159 |
160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
162 |
163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
164 | "{}/logits_real".format(split): logits_real.detach().mean(),
165 | "{}/logits_fake".format(split): logits_fake.detach().mean()
166 | }
167 | return d_loss, log
168 |
--------------------------------------------------------------------------------
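Two tiny checks for the standalone helpers above, assuming the repository and taming are importable: adopt_weight keeps the GAN term at zero until the threshold step, and measure_perplexity reports how evenly a batch of code indices uses the codebook. The index tensor is fake data for illustration.

import torch
from ldm.modules.losses.vqperceptual import adopt_weight, measure_perplexity

print(adopt_weight(1.0, global_step=100, threshold=500))   # 0.0 (discriminator not active yet)
print(adopt_weight(1.0, global_step=900, threshold=500))   # 1.0

indices = torch.randint(0, 1024, (4096,))                  # fake codebook assignments
perplexity, cluster_use = measure_perplexity(indices, n_embed=1024)
print(perplexity.item(), cluster_use.item())               # perplexity approaches 1024 for uniform usage
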
/ldm/modules/multimodal_attention.py:
--------------------------------------------------------------------------------
1 | from inspect import isfunction
2 | import math
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn, einsum
6 | from einops import rearrange, repeat
7 |
8 | # from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder
9 | from torch.utils import checkpoint
10 |
11 | def exists(val):
12 | return val is not None
13 |
14 |
15 | def uniq(arr):
16 |     return {el: True for el in arr}.keys()
17 |
18 |
19 | def default(val, d):
20 | if exists(val):
21 | return val
22 | return d() if isfunction(d) else d
23 |
24 |
25 | def max_neg_value(t):
26 | return -torch.finfo(t.dtype).max
27 |
28 |
29 | def init_(tensor):
30 | dim = tensor.shape[-1]
31 | std = 1 / math.sqrt(dim)
32 | tensor.uniform_(-std, std)
33 | return tensor
34 |
35 |
36 | # feedforward
37 | class GEGLU(nn.Module):
38 | def __init__(self, dim_in, dim_out):
39 | super().__init__()
40 | self.proj = nn.Linear(dim_in, dim_out * 2)
41 |
42 | def forward(self, x):
43 | x, gate = self.proj(x).chunk(2, dim=-1)
44 | return x * F.gelu(gate)
45 |
46 |
47 | class FeedForward(nn.Module):
48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
49 | super().__init__()
50 | inner_dim = int(dim * mult)
51 | dim_out = default(dim_out, dim)
52 | project_in = nn.Sequential(
53 | nn.Linear(dim, inner_dim),
54 | nn.GELU()
55 | ) if not glu else GEGLU(dim, inner_dim)
56 |
57 | self.net = nn.Sequential(
58 | project_in,
59 | nn.Dropout(dropout),
60 | nn.Linear(inner_dim, dim_out)
61 | )
62 |
63 | def forward(self, x):
64 | return self.net(x)
65 |
66 |
67 | def zero_module(module):
68 | """
69 | Zero out the parameters of a module and return it.
70 | """
71 | for p in module.parameters():
72 | p.detach().zero_()
73 | return module
74 |
75 |
76 | def Normalize(in_channels):
77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
78 |
79 |
80 | class LinearAttention(nn.Module):
81 | def __init__(self, dim, heads=4, dim_head=32):
82 | super().__init__()
83 | self.heads = heads
84 | hidden_dim = dim_head * heads
85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1)
87 |
88 | def forward(self, x):
89 | b, c, h, w = x.shape
90 | qkv = self.to_qkv(x)
91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
92 | k = k.softmax(dim=-1)
93 | context = torch.einsum('bhdn,bhen->bhde', k, v)
94 | out = torch.einsum('bhde,bhdn->bhen', context, q)
95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
96 | return self.to_out(out)
97 |
98 |
99 |
100 |
101 |
102 | class CrossAttention(nn.Module):
103 | def __init__(self, query_dim, key_dim, value_dim, heads=8, dim_head=64, dropout=0):
104 | super().__init__()
105 | inner_dim = dim_head * heads
106 | self.scale = dim_head ** -0.5
107 | self.heads = heads
108 |
109 |
110 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
111 | self.to_k = nn.Linear(key_dim, inner_dim, bias=False)
112 | self.to_v = nn.Linear(value_dim, inner_dim, bias=False)
113 |
114 |
115 | self.to_out = nn.Sequential( nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) )
116 |
117 |
118 | def fill_inf_from_mask(self, sim, mask):
119 | if mask is not None:
120 | B,M = mask.shape
121 | mask = mask.unsqueeze(1).repeat(1,self.heads,1).reshape(B*self.heads,1,-1)
122 | max_neg_value = -torch.finfo(sim.dtype).max
123 | sim.masked_fill_(~mask, max_neg_value)
124 | return sim
125 |
126 |
127 | def forward(self, x, key, value, mask=None):
128 |
129 | q = self.to_q(x) # B*N*(H*C)
130 | k = self.to_k(key) # B*M*(H*C)
131 | v = self.to_v(value) # B*M*(H*C)
132 |
133 | B, N, HC = q.shape
134 | _, M, _ = key.shape
135 | H = self.heads
136 | C = HC // H
137 |
138 | q = q.view(B,N,H,C).permute(0,2,1,3).reshape(B*H,N,C) # (B*H)*N*C
139 | k = k.view(B,M,H,C).permute(0,2,1,3).reshape(B*H,M,C) # (B*H)*M*C
140 | v = v.view(B,M,H,C).permute(0,2,1,3).reshape(B*H,M,C) # (B*H)*M*C
141 |
142 | sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale # (B*H)*N*M
143 | self.fill_inf_from_mask(sim, mask)
144 | attn = sim.softmax(dim=-1) # (B*H)*N*M
145 |
146 | out = torch.einsum('b i j, b j d -> b i d', attn, v) # (B*H)*N*C
147 | out = out.view(B,H,N,C).permute(0,2,1,3).reshape(B,N,(H*C)) # B*N*(H*C)
148 |
149 | return self.to_out(out)
150 |
151 |
152 |
153 |
154 | class SelfAttention(nn.Module):
155 | def __init__(self, query_dim, heads=8, dim_head=64, dropout=0.):
156 | super().__init__()
157 | inner_dim = dim_head * heads
158 | self.scale = dim_head ** -0.5
159 | self.heads = heads
160 |
161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
162 | self.to_k = nn.Linear(query_dim, inner_dim, bias=False)
163 | self.to_v = nn.Linear(query_dim, inner_dim, bias=False)
164 |
165 | self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) )
166 |
167 | def forward(self, x):
168 | q = self.to_q(x) # B*N*(H*C)
169 | k = self.to_k(x) # B*N*(H*C)
170 | v = self.to_v(x) # B*N*(H*C)
171 |
172 | B, N, HC = q.shape
173 | H = self.heads
174 | C = HC // H
175 |
176 | q = q.view(B,N,H,C).permute(0,2,1,3).reshape(B*H,N,C) # (B*H)*N*C
177 | k = k.view(B,N,H,C).permute(0,2,1,3).reshape(B*H,N,C) # (B*H)*N*C
178 | v = v.view(B,N,H,C).permute(0,2,1,3).reshape(B*H,N,C) # (B*H)*N*C
179 |
180 | sim = torch.einsum('b i c, b j c -> b i j', q, k) * self.scale # (B*H)*N*N
181 | attn = sim.softmax(dim=-1) # (B*H)*N*N
182 |
183 | out = torch.einsum('b i j, b j c -> b i c', attn, v) # (B*H)*N*C
184 | out = out.view(B,H,N,C).permute(0,2,1,3).reshape(B,N,(H*C)) # B*N*(H*C)
185 |
186 | return self.to_out(out)
187 |
188 |
189 |
190 | class GatedCrossAttentionDense(nn.Module):
191 | def __init__(self, query_dim, key_dim, value_dim, n_heads, d_head):
192 | super().__init__()
193 |
194 | self.attn = CrossAttention(query_dim=query_dim, key_dim=key_dim, value_dim=value_dim, heads=n_heads, dim_head=d_head)
195 | self.ff = FeedForward(query_dim, glu=True)
196 |
197 | self.norm1 = nn.LayerNorm(query_dim)
198 | self.norm2 = nn.LayerNorm(query_dim)
199 |
200 | self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)) )
201 | self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)) )
202 |
203 |         # this can be useful: we can externally change the magnitude of tanh(alpha);
204 |         # for example, when it is set to 0, the entire model behaves the same as the original one
205 | self.scale = 1
206 |
207 | def forward(self, x, objs):
208 |
209 | x = x + self.scale*torch.tanh(self.alpha_attn) * self.attn( self.norm1(x), objs, objs)
210 | x = x + self.scale*torch.tanh(self.alpha_dense) * self.ff( self.norm2(x) )
211 |
212 | return x
213 |
214 |
215 | class GatedSelfAttentionDense(nn.Module):
216 | def __init__(self, query_dim, context_dim, n_heads, d_head):
217 | super().__init__()
218 |
219 |         # we need a linear projection since we concatenate the visual features and the object features
220 | self.linear = nn.Linear(context_dim, query_dim)
221 |
222 | self.attn = SelfAttention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
223 | self.ff = FeedForward(query_dim, glu=True)
224 |
225 | self.norm1 = nn.LayerNorm(query_dim)
226 | self.norm2 = nn.LayerNorm(query_dim)
227 |
228 | self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)) )
229 | self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)) )
230 |
231 |         # this can be useful: we can externally change the magnitude of tanh(alpha);
232 |         # for example, when it is set to 0, the entire model behaves the same as the original one
233 | self.scale = 1
234 |
235 |
236 | def forward(self, x, objs):
237 |
238 | N_visual = x.shape[1]
239 | objs = self.linear(objs)
240 |
241 | x = x + self.scale*torch.tanh(self.alpha_attn) * self.attn( self.norm1(torch.cat([x,objs],dim=1)) )[:,0:N_visual,:]
242 | x = x + self.scale*torch.tanh(self.alpha_dense) * self.ff( self.norm2(x) )
243 |
244 | return x
245 |
246 |
247 | class BasicTransformerBlock(nn.Module):
248 | def __init__(self, query_dim, key_dim, value_dim, n_heads, d_head, fuser_type, use_checkpoint=True, use_sp=True, use_nsp=True):
249 | super().__init__()
250 | self.attn1 = SelfAttention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
251 | self.ff = FeedForward(query_dim, glu=True)
252 | self.attn2 = CrossAttention(query_dim=query_dim, key_dim=key_dim, value_dim=value_dim, heads=n_heads, dim_head=d_head)
253 | self.norm1 = nn.LayerNorm(query_dim)
254 | self.norm2 = nn.LayerNorm(query_dim)
255 | self.norm3 = nn.LayerNorm(query_dim)
256 | self.use_checkpoint = use_checkpoint
257 |
258 |         # note that key_dim and value_dim here are actually the context_dim
259 | if fuser_type == "gatedSA":
260 | if use_sp:
261 | self.sp_fuser = GatedSelfAttentionDense(query_dim, key_dim, n_heads, d_head)
262 | if use_nsp:
263 | self.nsp_fuser = GatedSelfAttentionDense(key_dim, key_dim, n_heads, d_head) # match with text dim
264 | elif fuser_type == "gatedCA":
265 | if use_sp:
266 | self.sp_fuser = GatedCrossAttentionDense(query_dim, key_dim, value_dim, n_heads, d_head)
267 | if use_nsp:
268 | self.nsp_fuser = GatedCrossAttentionDense(key_dim, key_dim, value_dim, n_heads, d_head) # match with text dim
269 | elif fuser_type == "gatedSA-gatedCA":
270 | if use_sp:
271 | self.sp_fuser = GatedSelfAttentionDense(query_dim, key_dim, n_heads, d_head)
272 | if use_nsp:
273 | self.nsp_fuser = GatedCrossAttentionDense(key_dim, key_dim, value_dim, n_heads, d_head) # match with text dim
274 | elif fuser_type == "gatedCA-gatedSA":
275 | if use_sp:
276 | self.sp_fuser = GatedCrossAttentionDense(query_dim, key_dim, value_dim, n_heads, d_head)
277 | if use_nsp:
278 | self.nsp_fuser = GatedSelfAttentionDense(key_dim, key_dim, n_heads, d_head) # match with text dim
279 | else:
280 | assert False
281 |
282 |
283 | def forward(self, x, context, sp_objs, nsp_objs):
284 | # return checkpoint(self._forward, (x, context, objs), self.parameters(), self.use_checkpoint)
285 | if self.use_checkpoint and x.requires_grad:
286 | return checkpoint.checkpoint(self._forward, x, context, sp_objs, nsp_objs)
287 | else:
288 | return self._forward(x, context, sp_objs, nsp_objs)
289 |
290 | def _forward(self, x, context, sp_objs, nsp_objs):
291 | x = self.attn1(self.norm1(x)) + x
292 | if sp_objs is not None:
293 | x = self.sp_fuser(x, sp_objs) # identity mapping in the beginning
294 | if nsp_objs is not None:
295 | context = self.nsp_fuser(context, nsp_objs) # identity mapping in the beginning
296 | x = self.attn2(self.norm2(x), context, context) + x
297 | x = self.ff(self.norm3(x)) + x
298 | return x
299 |
300 |
301 | class SpatialTransformer(nn.Module):
302 | def __init__(self, in_channels, key_dim, value_dim, n_heads, d_head, depth=1, fuser_type=None, use_checkpoint=True, use_sp=True, use_nsp=True):
303 | super().__init__()
304 | self.in_channels = in_channels
305 | query_dim = n_heads * d_head
306 | self.norm = Normalize(in_channels)
307 |
308 |
309 | self.proj_in = nn.Conv2d(in_channels,
310 | query_dim,
311 | kernel_size=1,
312 | stride=1,
313 | padding=0)
314 |
315 | self.transformer_blocks = nn.ModuleList(
316 | [BasicTransformerBlock(query_dim, key_dim, value_dim, n_heads, d_head, fuser_type, use_checkpoint=use_checkpoint, use_sp=use_sp, use_nsp=use_nsp)
317 | for d in range(depth)]
318 | )
319 |
320 | self.proj_out = zero_module(nn.Conv2d(query_dim,
321 | in_channels,
322 | kernel_size=1,
323 | stride=1,
324 | padding=0))
325 |
326 | def forward(self, x, context, sp_objs, nsp_objs):
327 | b, c, h, w = x.shape
328 | x_in = x
329 | x = self.norm(x)
330 | x = self.proj_in(x)
331 | x = rearrange(x, 'b c h w -> b (h w) c')
332 | for block in self.transformer_blocks:
333 | x = block(x, context, sp_objs, nsp_objs)
334 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
335 | x = self.proj_out(x)
336 | return x + x_in
337 |
--------------------------------------------------------------------------------
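A minimal forward-pass sketch for the SpatialTransformer defined above. The import path, the choice of fuser_type, and all tensor shapes are illustrative assumptions (CLIP-style 768-d context tokens, GLIGEN-style grounding tokens); the block flattens a (B, C, H, W) feature map into tokens, fuses the spatial (sp) and non-spatial (nsp) grounding tokens through the gated fusers, and returns a tensor of the input shape via the residual connection.

import torch
# assumption: the classes above live in ldm/modules/multimodal_attention.py
from ldm.modules.multimodal_attention import SpatialTransformer

# inner token dim = n_heads * d_head = 320; proj_in / proj_out map in_channels <-> inner dim
st = SpatialTransformer(in_channels=320, key_dim=768, value_dim=768,
                        n_heads=8, d_head=40, depth=1,
                        fuser_type="gatedSA", use_checkpoint=False)

x = torch.randn(2, 320, 16, 16)      # U-Net feature map
context = torch.randn(2, 77, 768)    # text tokens (key/value of the cross-attention)
sp_objs = torch.randn(2, 30, 768)    # spatial grounding tokens (e.g. boxes); assumed shape
nsp_objs = torch.randn(2, 4, 768)    # non-spatial grounding tokens (e.g. style); assumed shape

out = st(x, context, sp_objs, nsp_objs)
assert out.shape == x.shape          # (2, 320, 16, 16)
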
/ldm/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | import torch
4 | import numpy as np
5 |
6 | from inspect import isfunction
7 | from PIL import Image, ImageDraw, ImageFont
8 |
9 |
10 | def log_txt_as_img(wh, xc, size=10):
11 | # wh a tuple of (width, height)
12 | # xc a list of captions to plot
13 | b = len(xc)
14 | txts = list()
15 | for bi in range(b):
16 | txt = Image.new("RGB", wh, color="white")
17 | draw = ImageDraw.Draw(txt)
18 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
19 | nc = int(40 * (wh[0] / 256))
20 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
21 |
22 | try:
23 | draw.text((0, 0), lines, fill="black", font=font)
24 | except UnicodeEncodeError:
25 |             print("Can't encode string for logging. Skipping.")
26 |
27 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
28 | txts.append(txt)
29 | txts = np.stack(txts)
30 | txts = torch.tensor(txts)
31 | return txts
32 |
33 |
34 | def ismap(x):
35 | if not isinstance(x, torch.Tensor):
36 | return False
37 | return (len(x.shape) == 4) and (x.shape[1] > 3)
38 |
39 |
40 | def isimage(x):
41 | if not isinstance(x,torch.Tensor):
42 | return False
43 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
44 |
45 |
46 | def exists(x):
47 | return x is not None
48 |
49 |
50 | def default(val, d):
51 | if exists(val):
52 | return val
53 | return d() if isfunction(d) else d
54 |
55 |
56 | def mean_flat(tensor):
57 | """
58 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
59 | Take the mean over all non-batch dimensions.
60 | """
61 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
62 |
63 |
64 | def count_params(model, verbose=False):
65 | total_params = sum(p.numel() for p in model.parameters())
66 | if verbose:
67 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
68 | return total_params
69 |
70 |
71 | def instantiate_from_config(config):
72 |     if "target" not in config:
73 | if config == '__is_first_stage__':
74 | return None
75 | elif config == "__is_unconditional__":
76 | return None
77 | raise KeyError("Expected key `target` to instantiate.")
78 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
79 |
80 |
81 | def get_obj_from_str(string, reload=False):
82 | module, cls = string.rsplit(".", 1)
83 | if reload:
84 | module_imp = importlib.import_module(module)
85 | importlib.reload(module_imp)
86 | return getattr(importlib.import_module(module, package=None), cls)
--------------------------------------------------------------------------------
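A small, hypothetical usage sketch for instantiate_from_config above: any importable dotted path works as "target", and "params" is forwarded as keyword arguments; this is the same mechanism the OmegaConf configs use to build models and schedulers.

from ldm.util import instantiate_from_config, count_params

# hypothetical config dict; in practice this comes from an OmegaConf YAML file
config = {
    "target": "torch.nn.Linear",
    "params": {"in_features": 512, "out_features": 512},
}

layer = instantiate_from_config(config)  # equivalent to torch.nn.Linear(in_features=512, out_features=512)
count_params(layer, verbose=True)        # prints "Linear has 0.26 M params."
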
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.13.0
2 | torchvision==0.14.0
3 | albumentations==0.4.3
4 | opencv-python
5 | pudb==2019.2
6 | imageio==2.9.0
7 | imageio-ffmpeg==0.4.2
8 | pytorch-lightning==1.4.2
9 | omegaconf==2.1.1
10 | test-tube>=0.7.5
11 | streamlit>=0.73.1
12 | einops==0.3.0
13 | torch-fidelity==0.3.0
14 | git+https://github.com/openai/CLIP.git
15 | protobuf~=3.20.1
16 | torchmetrics==0.6.0
17 | transformers==4.19.2
18 | kornia==0.5.8
19 | matplotlib      # used by visualization/extract_utils.py (plot_histogram)
20 | wandb
21 | scikit-image
22 | scikit-learn    # used by visualization/extract_utils.py (euclidean_distances)
23 | # uninstall
24 | # pip uninstall -y torchtext
25 |
--------------------------------------------------------------------------------
/visualization/draw_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from copy import deepcopy
3 | import visualization.image_utils as iutils
4 |
5 |
6 |
7 | def draw_sketch_with_batch_to_tensor(batch):
8 | if "sketch" not in batch or batch["sketch"] is None:
9 | return torch.zeros_like(batch["image"])
10 | else:
11 | return batch["sketch"]["values"]
12 |
13 | def draw_depth_with_batch_to_tensor(batch):
14 | if "depth" not in batch or batch["depth"] is None:
15 | return torch.zeros_like(batch["image"])
16 | else:
17 | return batch["depth"]["values"]
18 |
19 | def draw_boxes_with_batch_to_tensor(batch):
20 | if "box" not in batch or batch["box"] is None:
21 | return torch.zeros_like(batch["image"])
22 |
23 | batch_size = batch["image"].size(0)
24 | box_drawing = []
25 | for i in range(batch_size):
26 | if "box" in batch:
27 | info_dict = {"image": batch["image"][i], "boxes": batch["box"]["values"][i]}
28 | boxed_img = iutils.vis_boxes(info_dict)
29 | else:
30 | boxed_img = torch.randn_like(batch["image"][i])
31 | box_drawing.append(boxed_img)
32 | box_tensor = torch.stack(box_drawing)
33 | return box_tensor
34 |
35 | def draw_keypoints_with_batch_to_tensor(batch):
36 | if "keypoint" not in batch or batch["keypoint"] is None:
37 | return torch.zeros_like(batch["image"])
38 |
39 | batch_size = batch["image"].size(0)
40 | keypoint_drawing = []
41 | for i in range(batch_size):
42 | if "keypoint" in batch:
43 | info_dict = {"image": batch["image"][i], "points": batch["keypoint"]["values"][i]}
44 | keypointed_img = iutils.vis_keypoints(info_dict)
45 | else:
46 | keypointed_img = torch.randn_like(batch["image"][i])
47 | keypoint_drawing.append(keypointed_img)
48 | keypoint_tensor = torch.stack(keypoint_drawing)
49 | return keypoint_tensor
50 |
51 | def draw_color_palettes_with_batch_to_tensor(batch):
52 | if "color_palette" not in batch or batch["color_palette"] is None:
53 | return torch.zeros_like(batch["image"])
54 | try:
55 | batch_size = batch["image"].size(0)
56 | color_palette_drawing = []
57 | for i in range(batch_size):
58 | if "color_palette" in batch:
59 | color_hist = deepcopy(batch["color_palette"]["values"][i])
60 | color_palette = iutils.vis_color_palette(color_hist, batch["image"].shape[-1])
61 | else:
62 | color_palette = torch.randn_like(batch["image"][i])
63 | color_palette_drawing.append(color_palette)
64 | color_palette_tensor = torch.stack(color_palette_drawing)
65 | return color_palette_tensor
66 |     except Exception as e:
67 |         print(f">> Exception occurred in draw_color_palettes_with_batch_to_tensor(batch): {e}")
68 | return torch.zeros_like(batch["image"])
69 |
70 | def draw_image_embedding_with_batch_to_tensor(batch):
71 | if "image_embedding" not in batch or batch["image_embedding"] is None:
72 | return -torch.ones_like(batch["image"])
73 | else:
74 | if "image" in batch["image_embedding"]:
75 | return batch["image_embedding"]["image"]
76 | else: return -torch.ones_like(batch["image"])
77 |
78 |
79 |
--------------------------------------------------------------------------------
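A hedged example of driving the helpers in visualization/draw_utils.py above. The batch layout (nested "values" tensors, normalized xyxy boxes, a 180-bin color histogram matching the palette built in extract_utils) is inferred from this file and may differ from the actual dataloader.

import torch
from visualization.draw_utils import (
    draw_boxes_with_batch_to_tensor,
    draw_color_palettes_with_batch_to_tensor,
)

# minimal, hypothetical batch: one 3x256x256 image, two normalized xyxy boxes,
# and a 180-bin color histogram
batch = {
    "image": torch.zeros(1, 3, 256, 256),
    "box": {"values": torch.tensor([[[0.10, 0.10, 0.45, 0.60],
                                     [0.50, 0.20, 0.90, 0.80]]])},
    "color_palette": {"values": torch.rand(1, 180)},
}

box_vis = draw_boxes_with_batch_to_tensor(batch)               # (1, 3, 256, 256)
palette_vis = draw_color_palettes_with_batch_to_tensor(batch)  # (1, 3, 256, 256)
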
/visualization/extract_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | from PIL import Image
5 | from transformers import CLIPProcessor, CLIPModel
6 |
7 | from skimage.color import hsv2rgb, rgb2lab, lab2rgb
8 | from skimage.io import imread, imsave
9 | from sklearn.metrics import euclidean_distances
10 | import matplotlib.pyplot as plt  # used by plot_histogram below
11 |
12 | @torch.no_grad()
13 | def get_clip_feature(img_path, model, processor):
14 | # clip_features = dict()
15 | image = Image.open(img_path).convert("RGB")
16 | inputs = processor(images=[image], return_tensors="pt", padding=True)
17 | inputs['pixel_values'] = inputs['pixel_values'].cuda()
18 | inputs['input_ids'] = torch.tensor([[0,1,2,3]]).cuda()
19 | outputs = model(**inputs)
20 | feature = outputs.image_embeds
21 | return feature.squeeze().cpu().numpy() # dim: [768,]
22 |
23 |
24 |
25 |
26 |
27 | '---------------------------------------- color converter ----------------------------------------'
28 |
29 | def rgb2hex(rgb_number):
30 | """
31 | Args:
32 | - rgb_number (sequence of float)
33 | Returns:
34 | - hex_number (string)
35 | """
36 | # print(rgb_number, [np.round(val*255) for val in rgb_number])
37 | return '#{:02x}{:02x}{:02x}'.format(*tuple([int(np.round(val * 255)) for val in rgb_number]))
38 |
39 |
40 | def hex2rgb(hexcolor_str):
41 | """
42 | Args:
43 | - hexcolor_str (string): e.g. '#ffffff' or '33cc00'
44 | Returns:
45 | - rgb_color (sequence of floats): e.g. (0.2, 0.3, 0)
46 | """
47 | color = hexcolor_str.strip('#')
48 | # rgb = lambda x: round(int(x, 16) / 255., 5)
49 | return tuple(round(int(color[i:i+2], 16) / 255., 5) for i in (0, 2, 4))
50 |
51 |
52 |
53 | '---------------------------------------- color palette histogram ----------------------------------------'
54 |
55 | def histogram_colors_smoothed(lab_array, palette, sigma=10,
56 | plot_filename=None, direct=True):
57 | """
58 | Returns a palette histogram of colors in the image, smoothed with
59 | a Gaussian. Can smooth directly per-pixel, or after computing a strict
60 | histogram.
61 | Parameters
62 | ----------
63 | lab_array : (N,3) ndarray
64 | The L*a*b color of each of N pixels.
65 | palette : rayleigh.Palette
66 | Containing K colors.
67 | sigma : float
68 | Variance of the smoothing Gaussian.
69 | direct : bool, optional
70 | If True, constructs a smoothed histogram directly from pixels.
71 | If False, constructs a nearest-color histogram and then smoothes it.
72 | Returns
73 | -------
74 | color_hist : (K,) ndarray
75 | """
76 | if direct:
77 | color_hist_smooth = histogram_colors_with_smoothing(
78 | lab_array, palette, sigma)
79 | else:
80 | color_hist_strict = histogram_colors_strict(lab_array, palette)
81 | color_hist_smooth = smooth_histogram(color_hist_strict, palette, sigma)
82 | if plot_filename is not None:
83 | plot_histogram(color_hist_smooth, palette, plot_filename)
84 | return color_hist_smooth
85 |
86 | def smooth_histogram(color_hist, palette, sigma=10):
87 | """
88 | Smooth the given palette histogram with a Gaussian of variance sigma.
89 | Parameters
90 | ----------
91 | color_hist : (K,) ndarray
92 | palette : rayleigh.Palette
93 | containing K colors.
94 | Returns
95 | -------
96 | color_hist_smooth : (K,) ndarray
97 | """
98 | n = 2. * sigma ** 2
99 | weights = np.exp(-palette.distances / n)
100 | norm_weights = weights / weights.sum(1)[:, np.newaxis]
101 | color_hist_smooth = (norm_weights * color_hist).sum(1)
102 | color_hist_smooth[color_hist_smooth < 1e-5] = 0
103 | return color_hist_smooth
104 |
105 | def histogram_colors_with_smoothing(lab_array, palette, sigma=10):
106 | """
107 | Assign colors in the image to nearby colors in the palette, weighted by
108 | distance in Lab color space.
109 | Parameters
110 | ----------
111 | lab_array (N,3) ndarray:
112 | N is the number of data points, columns are L, a, b values.
113 | palette : rayleigh.Palette
114 | containing K colors.
115 | sigma : float
116 | (0,1] value to control the steepness of exponential falloff.
117 | To see the effect:
118 | >>> from pylab import *
119 | >>> ds = linspace(0,5000) # squared distance
120 | >>> sigma=10; plot(ds, exp(-ds/(2*sigma**2)), label='$\sigma=%.1f$'%sigma)
121 | >>> sigma=20; plot(ds, exp(-ds/(2*sigma**2)), label='$\sigma=%.1f$'%sigma)
122 | >>> sigma=40; plot(ds, exp(-ds/(2*sigma**2)), label='$\sigma=%.1f$'%sigma)
123 | >>> ylim([0,1]); legend();
124 | >>> xlabel('Squared distance'); ylabel('Weight');
125 | >>> title('Exponential smoothing')
126 | >>> #plt.savefig('exponential_smoothing.png', dpi=300)
127 | sigma=20 seems reasonable: hits 0 around squared distance of 4000.
128 | Returns:
129 | color_hist : (K,) ndarray
130 | the normalized, smooth histogram of colors.
131 | """
132 | dist = euclidean_distances(palette.lab_array, lab_array, squared=True).T
133 | n = 2. * sigma ** 2
134 | weights = np.exp(-dist / n)
135 |
136 | # normalize by sum: if a color is equally well represented by several colors
137 | # it should not contribute much to the overall histogram
138 | normalizing = weights.sum(1)
139 | normalizing[normalizing == 0] = 1e16
140 | normalized_weights = weights / normalizing[:, np.newaxis]
141 |
142 | color_hist = normalized_weights.sum(0)
143 | color_hist /= lab_array.shape[0]
144 | color_hist[color_hist < 1e-5] = 0
145 | return color_hist
146 |
147 | def histogram_colors_strict(lab_array, palette, plot_filename=None):
148 | """
149 | Return a palette histogram of colors in the image.
150 | Parameters
151 | ----------
152 | lab_array : (N,3) ndarray
153 | The L*a*b color of each of N pixels.
154 | palette : rayleigh.Palette
155 | Containing K colors.
156 | plot_filename : string, optional
157 | If given, save histogram to this filename.
158 | Returns
159 | -------
160 | color_hist : (K,) ndarray
161 | """
162 | # This is the fastest way that I've found.
163 | # >>> %%timeit -n 200 from sklearn.metrics import euclidean_distances
164 | # >>> euclidean_distances(palette, lab_array, squared=True)
165 | dist = euclidean_distances(palette.lab_array, lab_array, squared=True).T
166 | min_ind = np.argmin(dist, axis=1)
167 | num_colors = palette.lab_array.shape[0]
168 | num_pixels = lab_array.shape[0]
169 | color_hist = 1. * np.bincount(min_ind, minlength=num_colors) / num_pixels
170 | if plot_filename is not None:
171 | plot_histogram(color_hist, palette, plot_filename)
172 | return color_hist
173 |
174 | def plot_histogram(color_hist, palette, plot_filename=None):
175 | """
176 | Return Figure containing the color palette histogram.
177 | Args:
178 | - color_hist (K, ndarray)
179 | - palette (Palette)
180 | - plot_filename (string) [default=None]:
181 | Save histogram to this file, if given.
182 | Returns:
183 | - fig (Figure)
184 | """
185 | fig = plt.figure(figsize=(5, 3), dpi=150)
186 | ax = fig.add_subplot(111)
187 | ax.bar(
188 | range(len(color_hist)), color_hist,
189 | color=palette.hex_list)
190 | ax.set_ylim((0, 0.1))
191 | ax.xaxis.set_ticks([])
192 | ax.set_xlim((0, len(palette.hex_list)))
193 | if plot_filename:
194 | fig.savefig(plot_filename, dpi=150, facecolor='none')
195 | return fig
196 |
197 |
198 | '---------------------------------------- histogram to image ----------------------------------------'
199 | # only for visualization
200 | # for extracting color palette, color histogram is enough.
201 |
202 | def color_hist_to_palette_image(color_hist, palette, percentile=90,
203 | width=200, height=50, filename=None):
204 | """
205 | Output the main colors in the histogram to a "palette image."
206 | Parameters
207 | ----------
208 | color_hist : (K,) ndarray
209 | palette : rayleigh.Palette
210 | percentile : int, optional:
211 | Output only colors above this percentile of prevalence in the histogram.
212 | filename : string, optional:
213 | If given, save the resulting image to file.
214 | Returns
215 | -------
216 | rgb_image : ndarray
217 | """
218 | ind = np.argsort(-color_hist)
219 | ind = ind[color_hist[ind] > np.percentile(color_hist, percentile)]
220 | hex_list = np.take(palette.hex_list, ind)
221 | values = color_hist[ind]
222 |     rgb_image = palette_query_to_rgb_image(dict(zip(hex_list, values)), width, height)
223 | if filename:
224 | imsave(filename, rgb_image)
225 | return rgb_image
226 |
227 |
228 | def palette_query_to_rgb_image(palette_query, width=200, height=50):
229 | """
230 | Convert a list of hex colors and their values to an RGB image of given
231 | width and height.
232 | Args:
233 | - palette_query (dict):
234 | a dictionary of hex colors to unnormalized values,
235 | e.g. {'#ffffff': 20, '#33cc00': 0.4}.
236 | """
237 | hex_list, values = zip(*palette_query.items())
238 | values = np.array(values)
239 | values /= values.sum()
240 | nums = np.array(values * width, dtype=int)
241 | rgb_arrays = (np.tile(np.array(hex2rgb(x)), (num, 1))
242 | for x, num in zip(hex_list, nums))
243 |     rgb_array = np.vstack(list(rgb_arrays))
244 | rgb_image = rgb_array[np.newaxis, :, :]
245 | rgb_image = np.tile(rgb_image, (height, 1, 1))
246 | return rgb_image
247 |
248 |
249 |
250 |
251 | '---------------------------------------- Color Palette ----------------------------------------'
252 |
253 | class Palette(object):
254 | """
255 | Create a color palette (codebook) in the form of a 2D grid of colors,
256 | as described in the parameters list below.
257 | Further, the rightmost column has num_hues gradations from black to white.
258 | Parameters
259 | ----------
260 | num_hues : int
261 | number of colors with full lightness and saturation, in the middle
262 | sat_range : int
263 | number of rows above middle row that show
264 | the same hues with decreasing saturation.
265 | light_range : int
266 | number of rows below middle row that show
267 | the same hues with decreasing lightness.
268 | Returns
269 | -------
270 | palette: rayleigh.Palette
271 | """
272 |
273 | def __init__(self, num_hues=8, sat_range=2, light_range=2):
274 | height = 1 + sat_range + (2 * light_range - 1)
275 | # generate num_hues+1 hues, but don't take the last one:
276 | # hues are on a circle, and we would be oversampling the origin
277 | hues = np.tile(np.linspace(0, 1, num_hues + 1)[:-1], (height, 1))
278 | if num_hues == 8:
279 | hues = np.tile(np.array(
280 | [0., 0.10, 0.15, 0.28, 0.51, 0.58, 0.77, 0.85]), (height, 1))
281 | if num_hues == 9:
282 | hues = np.tile(np.array(
283 | [0., 0.10, 0.15, 0.28, 0.49, 0.54, 0.60, 0.7, 0.87]), (height, 1))
284 | if num_hues == 10:
285 | hues = np.tile(np.array(
286 | [0., 0.10, 0.15, 0.28, 0.49, 0.54, 0.60, 0.66, 0.76, 0.87]), (height, 1))
287 | elif num_hues == 11:
288 | hues = np.tile(np.array(
289 | [0.0, 0.0833, 0.166, 0.25,
290 | 0.333, 0.5, 0.56333,
291 | 0.666, 0.73, 0.803,
292 | 0.916]), (height, 1))
293 |
294 | sats = np.hstack((
295 | np.linspace(0, 1, sat_range + 2)[1:-1],
296 | 1,
297 | [1] * (light_range),
298 | [.4] * (light_range - 1),
299 | ))
300 | lights = np.hstack((
301 | [1] * sat_range,
302 | 1,
303 | np.linspace(1, 0.2, light_range + 2)[1:-1],
304 | np.linspace(1, 0.2, light_range + 2)[1:-2],
305 | ))
306 |
307 | sats = np.tile(np.atleast_2d(sats).T, (1, num_hues))
308 | lights = np.tile(np.atleast_2d(lights).T, (1, num_hues))
309 | colors = hsv2rgb(np.dstack((hues, sats, lights)))
310 | grays = np.tile(
311 | np.linspace(1, 0, height)[:, np.newaxis, np.newaxis], (1, 1, 3))
312 |
313 | self.rgb_image = np.hstack((colors, grays))
314 |
315 | # Make a nice histogram ordering of the hues and grays
316 | h, w, d = colors.shape
317 | color_array = colors.T.reshape((d, w * h)).T
318 | h, w, d = grays.shape
319 | gray_array = grays.T.reshape((d, w * h)).T
320 |
321 | self.rgb_array = np.vstack((color_array, gray_array))
322 | self.lab_array = rgb2lab(self.rgb_array[None, :, :]).squeeze()
323 | self.hex_list = [rgb2hex(row) for row in self.rgb_array]
324 | #assert(np.all(self.rgb_array == self.rgb_array[None, :, :].squeeze()))
325 |
326 | self.distances = euclidean_distances(self.lab_array, squared=True)
327 |
328 |
329 | '---------------------------------------- Image Wrapper for lab_array ----------------------------------------'
330 |
331 | class ColorImage(object):
332 | """
333 | Read the image at the URL in RGB format, downsample if needed,
334 | and convert to Lab colorspace.
335 | Store original dimensions, resize_factor, and the filename of the image.
336 | Image dimensions will be resized independently such that neither width nor
337 | height exceed the maximum allowed dimension MAX_DIMENSION.
338 | Parameters
339 | ----------
340 | url : string
341 | URL or file path of the image to load.
342 | id : string, optional
343 | Name or some other id of the image. For example, the Flickr ID.
344 | """
345 |
346 | MAX_DIMENSION = 240 + 1
347 |
348 | def __init__(self, url, _id=None):
349 | self.id = _id
350 | self.url = url
351 | img = imread(url)
352 |
353 | # Handle grayscale and RGBA images.
354 | # TODO: Should be smarter here in the future, but for now simply remove
355 | # the alpha channel if present.
356 | if img.ndim == 2:
357 | img = np.tile(img[:, :, np.newaxis], (1, 1, 3))
358 | elif img.ndim == 4:
359 | img = img[:, :, :3]
360 | elif img.ndim == 3 and img.shape[2] == 4:
361 | img = img[:, :, :3]
362 |
363 | # Downsample for speed.
364 | #
365 | # NOTE: I can't find a good method to resize properly in Python!
366 | # scipy.misc.imresize uses PIL, which needs 8bit data.
367 | # Anyway, this is faster and almost as good.
368 | #
369 | # >>> def d(dim, max_dim): return arange(0, dim, dim / max_dim + 1).shape
370 | # >>> plot(range(1200), [d(x, 200) for x in range(1200)])
371 | h, w, d = tuple(img.shape)
372 | self.orig_h, self.orig_w, self.orig_d = tuple(img.shape)
373 | h_stride = h // self.MAX_DIMENSION + 1
374 | w_stride = w // self.MAX_DIMENSION + 1
375 | img = img[::h_stride, ::w_stride, :]
376 |
377 | # Convert to L*a*b colors.
378 | h, w, d = img.shape
379 | self.h, self.w, self.d = img.shape
380 | self.lab_array = rgb2lab(img).reshape((h * w, d))
381 |
382 | def as_dict(self):
383 | """
384 | Return relevant info about self in a dict.
385 | """
386 | return {'id': self.id, 'url': self.url,
387 | 'resized_width': self.w, 'resized_height': self.h,
388 | 'width': self.orig_w, 'height': self.orig_h}
389 |
390 | def output_quantized_to_palette(self, palette, filename):
391 | """
392 | Save to filename a version of the image with all colors quantized
393 | to the nearest color in the given palette.
394 | Parameters
395 | ----------
396 | palette : rayleigh.Palette
397 | Containing K colors.
398 | filename : string
399 | Where image will be written.
400 | """
401 | dist = euclidean_distances(
402 | palette.lab_array, self.lab_array, squared=True).T
403 | min_ind = np.argmin(dist, axis=1)
404 | quantized_lab_array = palette.lab_array[min_ind, :]
405 | img = lab2rgb(quantized_lab_array.reshape((self.h, self.w, self.d)))
406 | imsave(filename, img)
407 |
408 |
409 | def get_color_palette(img_path):
410 | palette = Palette(num_hues=11, sat_range=5, light_range=5)
411 | assert len(palette.hex_list) == 180
412 | query_img = ColorImage(url=img_path)
413 | color_hist = histogram_colors_smoothed(query_img.lab_array, palette, sigma=10, direct=False)
414 | return color_hist
415 |
416 |
417 | @torch.no_grad()
418 | def get_box(locations, phrases, model, processor):
419 | boxes = torch.zeros(30, 4)
420 | masks = torch.zeros(30)
421 | text_embeddings = torch.zeros(30, 768)
422 |
423 | text_features = []
424 | image_features = []
425 | for phrase in phrases:
426 | inputs = processor(text=phrase, return_tensors="pt", padding=True)
427 | inputs['input_ids'] = inputs['input_ids'].cuda()
428 | inputs['pixel_values'] = torch.ones(1,3,224,224).cuda() # placeholder
429 | inputs['attention_mask'] = inputs['attention_mask'].cuda()
430 | outputs = model(**inputs)
431 | feature = outputs.text_model_output.pooler_output
432 | text_features.append(feature.squeeze())
433 |
434 | for idx, (box, text_feature) in enumerate(zip(locations, text_features)):
435 | boxes[idx] = torch.tensor(box)
436 | masks[idx] = 1
437 | text_embeddings[idx] = text_feature
438 |
439 | return boxes, masks, text_embeddings
440 |
441 |
442 | def get_keypoint(locations):
443 | points = torch.zeros(8*17,2)
444 | idx = 0
445 | for this_person_kp in locations:
446 | for kp in this_person_kp:
447 | points[idx,0] = kp[0]
448 | points[idx,1] = kp[1]
449 | idx += 1
450 |
451 | # derive masks from points
452 | masks = (points.mean(dim=1)!=0) * 1
453 | masks = masks.float()
454 | return points, masks
455 |
456 |
457 | def transform_image_from_pil_to_numpy(pil_image):
458 | np_image, trans_info = center_crop_arr(pil_image, image_size=512)
459 |
460 | if np_image.ndim == 2: # when we load the image with "L" option
461 | np_image = np.expand_dims(np_image, axis=2)
462 | return np_image, trans_info
463 |
464 | def flip_image_from_numpy_to_numpy(np_image, trans_info, is_flip=False):
465 | if is_flip:
466 | np_image = np_image[:, ::-1]
467 | trans_info["performed_flip"] = True
468 | return np_image, trans_info
469 | else:
470 | return np_image, trans_info
471 |
472 | def convert_image_from_numpy_to_tensor(np_image, type_mask=False):
473 | if type_mask:
474 | """
475 | value range : (0, 1), for sketch, segm, mask,
476 | """
477 | np_image = np_image.astype(np.float32) / 255.0
478 | else:
479 | """
480 | value range : (-1, 1), for rgb
481 | """
482 | np_image = np_image.astype(np.float32) / 127.5 - 1
483 | np_image = np.transpose(np_image, [2,0,1])
484 | return torch.tensor(np_image)
485 |
486 | def invert_image_from_numpy_to_numpy(np_image):
487 | return 255.0 - np_image
488 |
489 | def center_crop_arr(pil_image, image_size):
490 | # We are not on a new enough PIL to support the `reducing_gap`
491 | # argument, which uses BOX downsampling at powers of two first.
492 | # Thus, we do it by hand to improve downsample quality.
493 | WW, HH = pil_image.size
494 |
495 | while min(*pil_image.size) >= 2 * image_size:
496 | pil_image = pil_image.resize(
497 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
498 | )
499 |
500 | scale = image_size / min(*pil_image.size)
501 |
502 | pil_image = pil_image.resize(
503 | tuple(round(x * scale) for x in pil_image.size), resample=Image.Resampling.BICUBIC
504 | )
505 |
506 | # at this point, the min of pil_image side is desired image_size
507 | performed_scale = image_size / min(WW, HH)
508 |
509 | arr = np.array(pil_image)
510 | crop_y = (arr.shape[0] - image_size) // 2
511 | crop_x = (arr.shape[1] - image_size) // 2
512 |
513 | info = {"performed_scale":performed_scale, 'crop_y':crop_y, 'crop_x':crop_x, "WW":WW, 'HH':HH, "performed_flip": False}
514 |
515 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size], info
516 |
517 |
518 | def get_sketch(img_path):
519 | pil_sketch = Image.open(img_path).convert('RGB')
520 | np_sketch,_ = transform_image_from_pil_to_numpy(pil_sketch)
521 | np_sketch = invert_image_from_numpy_to_numpy(np_sketch)
522 | sketch_tensor = convert_image_from_numpy_to_tensor(np_sketch, type_mask=True) # mask type: range 0 to 1
523 | return sketch_tensor
524 |
525 |
526 | def get_depth(img_path):
527 | pil_depth = Image.open(img_path).convert('RGB')
528 | np_depth,_ = transform_image_from_pil_to_numpy(pil_depth)
529 | depth_tensor = convert_image_from_numpy_to_tensor(np_depth, type_mask=True) # mask type: range 0 to 1
530 | return depth_tensor
531 |
532 |
--------------------------------------------------------------------------------
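A hedged sketch of the condition-extraction flow provided by visualization/extract_utils.py above: a 768-d CLIP image embedding, a 180-bin color-palette histogram, sketch/depth tensors, and padded grounding boxes with per-phrase CLIP text features. The checkpoint name and file paths are placeholders, and a GPU is assumed because get_clip_feature and get_box move their inputs to CUDA.

from transformers import CLIPModel, CLIPProcessor
from visualization.extract_utils import (
    get_clip_feature, get_color_palette, get_sketch, get_depth, get_box)

version = "openai/clip-vit-large-patch14"   # assumed checkpoint; its projection dim is 768
model = CLIPModel.from_pretrained(version).cuda().eval()
processor = CLIPProcessor.from_pretrained(version)

ref_path = "path/to/reference.png"                    # placeholder paths
emb = get_clip_feature(ref_path, model, processor)    # (768,) numpy image embedding
hist = get_color_palette(ref_path)                    # (180,) smoothed palette histogram
sketch = get_sketch("path/to/sketch.png")             # (3, 512, 512) tensor in [0, 1]
depth = get_depth("path/to/depth.png")                # (3, 512, 512) tensor in [0, 1]

# normalized xyxy boxes plus a CLIP text feature per phrase, padded to 30 slots with a mask
boxes, masks, text_emb = get_box(
    locations=[[0.1, 0.2, 0.6, 0.9]], phrases=["a red jeep"],
    model=model, processor=processor)                 # (30, 4), (30,), (30, 768)
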
/visualization/image_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | from os.path import join
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from torchvision.utils import save_image, make_grid
8 | import torchvision
9 |
10 | from PIL import Image, ImageDraw
11 | from skimage.color import hsv2rgb, rgb2lab, lab2rgb
12 | from skimage.io import imsave  # used when color_hist_to_palette_image saves to file
13 |
14 | PP_RGB = 1
15 | PP_SEGM = 2
16 |
17 | def convert_zero2one(image_tensor):
18 | return (image_tensor.detach().cpu()+1)/2.0
19 |
20 | def postprocess(image_tensors, pp_type):
21 | """
22 | image_tensors:
23 | (B x C=3 x H x W)
24 | torch.tensors
25 | pp_type:
26 | int
27 | """
28 | if pp_type == PP_RGB:
29 | return convert_zero2one(image_tensors)
30 | elif pp_type == PP_SEGM:
31 | if image_tensors.size(1) == 3:
32 | return image_tensors.detach().cpu()
33 | else:
34 | if image_tensors.ndim == 4:
35 | return image_tensors.detach().cpu().repeat(1, 3, 1, 1)
36 | elif image_tensors.ndim == 5:
37 | return image_tensors.detach().cpu().repeat(1, 1, 3, 1, 1)
38 | else:
39 | raise NotImplementedError
40 | else:
41 | raise NotImplementedError
42 |
43 |
44 | def get_scale_factor(image_tensors, feature_tensors):
45 | """
46 | image_tensors:
47 | (B x C=3 x H x W)
48 | feature_tensors:
49 | (B x C x h x w)
50 | """
51 | B, _, H, W = image_tensors.size()
52 | _, _, h, w = feature_tensors.size()
53 | scale_factor = int(H / h)
54 | return scale_factor
55 |
56 | def do_scale(image_tensors, feature_tensors):
57 | """
58 | image_tensors:
59 | (B x C=3 x H x W)
60 | feature_tensors:
61 | (B x C x h x w)
62 | """
63 | scale_factor = get_scale_factor(image_tensors, feature_tensors)
64 | scaled_tensor = F.interpolate(feature_tensors, scale_factor=scale_factor)
65 | return scaled_tensor
66 |
67 | def save_images_from_dict(
68 | image_dict, dir_path, file_name, n_instance,
69 | is_save=False, save_per_instance=False, return_images=False,
70 | save_per_instance_idxs=[]):
71 | """
72 | image_dict:
73 | [
74 | {
75 | "tensors": tensors:tensor,
76 | "n_in_row": int,
77 | "pp_type": int
78 | },
79 | ...
80 | """
81 |
82 | n_row = 0
83 |
84 | for each_item in image_dict:
85 | tensors = each_item["tensors"]
86 | bs = tensors.size(0)
87 | n_instance = min(bs, n_instance)
88 |
89 | n_in_row = each_item["n_in_row"]
90 | n_row += n_in_row
91 |
92 | pp_type = each_item["pp_type"]
93 | post_tensor = postprocess(tensors, pp_type=pp_type)
94 | each_item["tensors"] = torch.clamp(post_tensor, min=0, max=1)
95 |
96 | if save_per_instance:
97 | for i in range(n_instance):
98 | image_list = []
99 | for each_item in image_dict:
100 | if each_item["n_in_row"] == 1:
101 | image_list.append(each_item["tensors"][i].unsqueeze(0))
102 | else:
103 | for j in range(each_item["n_in_row"]):
104 | image_list.append(each_item["tensors"][i, j].unsqueeze(0))
105 | images = torch.cat(image_list, dim=0)
106 | if len(save_per_instance_idxs) > 0:
107 | save_path = join(dir_path, f"{file_name}_{save_per_instance_idxs[i]}.png")
108 | else:
109 | save_path = join(dir_path, f"{file_name}_{i}.png")
110 | if is_save:
111 | save_image(images, save_path, padding=0, pad_value=0.5, nrow=n_row)
112 | else:
113 | save_path = join(dir_path, f"{file_name}.png")
114 | image_list = []
115 | for i in range(n_instance):
116 | for each_item in image_dict:
117 | if each_item["n_in_row"] == 1:
118 | image_list.append(each_item["tensors"][i].unsqueeze(0))
119 | else:
120 | for j in range(each_item["n_in_row"]):
121 | image_list.append(each_item["tensors"][i, j].unsqueeze(0))
122 |
123 | images = torch.cat(image_list, dim=0)
124 | concated_image = make_grid(images, padding=2, pad_value=0.5, nrow=n_row)
125 | if is_save:
126 | save_image(concated_image, save_path, nrow=1)
127 |
128 |     if return_images and not save_per_instance:
129 |         return concated_image  # the grid is only assembled in the non-per-instance branch
130 |
131 | ### =========================================================================== ###
132 | """
133 | functions to visualize boxes
134 | """
135 | def draw_box(img, boxes):
136 | colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"]
137 | draw = ImageDraw.Draw(img)
138 | for bid, box in enumerate(boxes):
139 | draw.rectangle([box[0], box[1], box[2], box[3]], outline=colors[bid % len(colors)], width=4)
140 | # draw.rectangle([box[0], box[1], box[2], box[3]], outline ="red", width=2) # x0 y0 x1 y1
141 | return img
142 |
143 | def vis_boxes(info_dict):
144 |
145 | device = info_dict["image"].device
146 | # img = torchvision.transforms.functional.to_pil_image( info_dict["image"]*0.5+0.5 )
147 |     canvas = torchvision.transforms.functional.to_pil_image( torch.zeros_like(info_dict["image"]) )  # boxes are drawn on a blank canvas; the image content itself is not used
148 | W, H = canvas.size
149 |
150 | boxes = []
151 | for box in info_dict["boxes"]:
152 | x0,y0,x1,y1 = box
153 | boxes.append( [float(x0*W), float(y0*H), float(x1*W), float(y1*H)] )
154 | canvas = draw_box(canvas, boxes)
155 |
156 | return torchvision.transforms.functional.to_tensor(canvas).to(device)
157 | ### =========================================================================== ###
158 |
159 | ### =========================================================================== ###
160 | """
161 | functions to visualize keypoints
162 | """
163 | def draw_points(img, points):
164 | colors = ["red", "yellow", "blue", "green", "orange", "brown", "cyan", "purple", "deeppink", "coral", "gold", "darkblue", "khaki", "lightgreen", "snow", "yellowgreen", "lime"]
165 | colors = colors * 100
166 | draw = ImageDraw.Draw(img)
167 |
168 | r = 3
169 | for point, color in zip(points, colors):
170 | if point[0] == point[1] == 0:
171 | pass
172 | else:
173 | x, y = float(point[0]), float(point[1])
174 | draw.ellipse( [ (x-r,y-r), (x+r,y+r) ], fill=color )
175 | # draw.rectangle([box[0], box[1], box[2], box[3]], outline ="red", width=2) # x0 y0 x1 y1
176 | return img
177 |
178 | def vis_keypoints(info_dict):
179 |
180 | device = info_dict["image"].device
181 | # img = torchvision.transforms.functional.to_pil_image( info_dict["image"]*0.5+0.5 )
182 | _, H, W = info_dict["image"].size()
183 | canvas = torchvision.transforms.functional.to_pil_image( torch.ones_like(info_dict["image"])*0.2 )
184 | assert W==H
185 | img = draw_points( canvas, info_dict["points"]*W )
186 |
187 | return torchvision.transforms.functional.to_tensor(img).to(device)
188 | ### =========================================================================== ###
189 |
190 |
191 |
192 | ### ================================ Color palette-related functions ================================== ###
193 |
194 |
195 | class Palette(object):
196 | """
197 | Create a color palette (codebook) in the form of a 2D grid of colors,
198 | as described in the parameters list below.
199 | Further, the rightmost column has num_hues gradations from black to white.
200 | Parameters
201 | ----------
202 | num_hues : int
203 | number of colors with full lightness and saturation, in the middle
204 | sat_range : int
205 | number of rows above middle row that show
206 | the same hues with decreasing saturation.
207 | light_range : int
208 | number of rows below middle row that show
209 | the same hues with decreasing lightness.
210 | Returns
211 | -------
212 | palette: rayleigh.Palette
213 | """
214 |
215 | def __init__(self, num_hues=8, sat_range=2, light_range=2):
216 | height = 1 + sat_range + (2 * light_range - 1)
217 | # generate num_hues+1 hues, but don't take the last one:
218 | # hues are on a circle, and we would be oversampling the origin
219 | hues = np.tile(np.linspace(0, 1, num_hues + 1)[:-1], (height, 1))
220 | if num_hues == 8:
221 | hues = np.tile(np.array(
222 | [0., 0.10, 0.15, 0.28, 0.51, 0.58, 0.77, 0.85]), (height, 1))
223 | if num_hues == 9:
224 | hues = np.tile(np.array(
225 | [0., 0.10, 0.15, 0.28, 0.49, 0.54, 0.60, 0.7, 0.87]), (height, 1))
226 | if num_hues == 10:
227 | hues = np.tile(np.array(
228 | [0., 0.10, 0.15, 0.28, 0.49, 0.54, 0.60, 0.66, 0.76, 0.87]), (height, 1))
229 | elif num_hues == 11:
230 | hues = np.tile(np.array(
231 | [0.0, 0.0833, 0.166, 0.25,
232 | 0.333, 0.5, 0.56333,
233 | 0.666, 0.73, 0.803,
234 | 0.916]), (height, 1))
235 |
236 | sats = np.hstack((
237 | np.linspace(0, 1, sat_range + 2)[1:-1],
238 | 1,
239 | [1] * (light_range),
240 | [.4] * (light_range - 1),
241 | ))
242 | lights = np.hstack((
243 | [1] * sat_range,
244 | 1,
245 | np.linspace(1, 0.2, light_range + 2)[1:-1],
246 | np.linspace(1, 0.2, light_range + 2)[1:-2],
247 | ))
248 |
249 | sats = np.tile(np.atleast_2d(sats).T, (1, num_hues))
250 | lights = np.tile(np.atleast_2d(lights).T, (1, num_hues))
251 | colors = hsv2rgb(np.dstack((hues, sats, lights)))
252 | grays = np.tile(
253 | np.linspace(1, 0, height)[:, np.newaxis, np.newaxis], (1, 1, 3))
254 |
255 | self.rgb_image = np.hstack((colors, grays))
256 |
257 | # Make a nice histogram ordering of the hues and grays
258 | h, w, d = colors.shape
259 | color_array = colors.T.reshape((d, w * h)).T
260 | h, w, d = grays.shape
261 | gray_array = grays.T.reshape((d, w * h)).T
262 |
263 | self.rgb_array = np.vstack((color_array, gray_array))
264 | self.lab_array = rgb2lab(self.rgb_array[None, :, :]).squeeze()
265 | self.hex_list = [rgb2hex(row) for row in self.rgb_array]
266 |
267 | def output(self, dirname):
268 | """
269 |         Output an image of the palette, a JSON list of the hex
270 | colors, and an HTML color picker for it.
271 | Parameters
272 | ----------
273 | dirname : string
274 | directory for the files to be output
275 | """
276 | pass # we do not need this for visualization
277 |
278 | def color_hist_to_palette_image(color_hist, palette, percentile=90,
279 | width=200, height=50, filename=None):
280 | """
281 | Output the main colors in the histogram to a "palette image."
282 | Parameters
283 | ----------
284 | color_hist : (K,) ndarray
285 | palette : rayleigh.Palette
286 | percentile : int, optional:
287 | Output only colors above this percentile of prevalence in the histogram.
288 | filename : string, optional:
289 | If given, save the resulting image to file.
290 | Returns
291 | -------
292 | rgb_image : ndarray
293 | """
294 | ind = np.argsort(-color_hist)
295 | ind = ind[color_hist[ind] > np.percentile(color_hist, percentile)]
296 | hex_list = np.take(palette.hex_list, ind)
297 | values = color_hist[ind]
298 | rgb_image = palette_query_to_rgb_image(dict(zip(hex_list, values)), width, height)
299 | if filename:
300 | imsave(filename, rgb_image)
301 | return rgb_image
302 |
303 | def palette_query_to_rgb_image(palette_query, width=200, height=50):
304 | """
305 | Convert a list of hex colors and their values to an RGB image of given
306 | width and height.
307 | Args:
308 | - palette_query (dict):
309 | a dictionary of hex colors to unnormalized values,
310 | e.g. {'#ffffff': 20, '#33cc00': 0.4}.
311 | """
312 | hex_list, values = zip(*palette_query.items())
313 | values = np.array(values)
314 | values /= values.sum()
315 | nums = np.array(values * width, dtype=int)
316 | rgb_arrays = (np.tile(np.array(hex2rgb(x)), (num, 1))
317 | for x, num in zip(hex_list, nums))
318 | rgb_array = np.vstack(list(rgb_arrays))
319 | rgb_image = rgb_array[np.newaxis, :, :]
320 | rgb_image = np.tile(rgb_image, (height, 1, 1))
321 | return rgb_image
322 |
323 | def rgb2hex(rgb_number):
324 | """
325 | Args:
326 | - rgb_number (sequence of float)
327 | Returns:
328 | - hex_number (string)
329 | """
330 | return '#{:02x}{:02x}{:02x}'.format(*tuple([int(np.round(val * 255)) for val in rgb_number]))
331 |
332 | def hex2rgb(hexcolor_str):
333 | """
334 | Args:
335 | - hexcolor_str (string): e.g. '#ffffff' or '33cc00'
336 | Returns:
337 | - rgb_color (sequence of floats): e.g. (0.2, 0.3, 0)
338 | """
339 | color = hexcolor_str.strip('#')
340 | return tuple(round(int(color[i:i+2], 16) / 255., 5) for i in (0, 2, 4))
341 |
342 | def vis_color_palette(color_hist, shape):
343 | if color_hist.sum() == 0:
344 | color_hist[-1] = 1.0
345 | palette = Palette(num_hues=11, sat_range=5, light_range=5)
346 | color_palette = color_hist_to_palette_image(color_hist.cpu().numpy(), palette, percentile=90)
347 | color_palette = torch.tensor(color_palette.transpose(2,0,1)).unsqueeze(0)
348 | color_palette = F.interpolate(color_palette, size=(shape, shape), mode='nearest')
349 | return color_palette.squeeze(0)
350 |
351 |
--------------------------------------------------------------------------------
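Finally, a hedged sketch of save_images_from_dict above: each entry pairs a tensor batch with a pp_type (PP_RGB rescales [-1, 1] outputs to [0, 1], PP_SEGM repeats single-channel maps to three channels) and an n_in_row count that sets the grid width. The output directory and tensors below are made up for illustration.

import os
import torch
import visualization.image_utils as iutils

os.makedirs("OUTPUT/vis", exist_ok=True)                # hypothetical output directory

rgb = torch.rand(4, 3, 128, 128) * 2 - 1                # pretend model outputs in [-1, 1]
segm = torch.randint(0, 2, (4, 1, 128, 128)).float()    # binary masks in [0, 1]

image_dict = [
    {"tensors": rgb,  "n_in_row": 1, "pp_type": iutils.PP_RGB},   # rescaled to [0, 1]
    {"tensors": segm, "n_in_row": 1, "pp_type": iutils.PP_SEGM},  # repeated to 3 channels
]

grid = iutils.save_images_from_dict(
    image_dict, dir_path="OUTPUT/vis", file_name="demo",
    n_instance=4, is_save=True, return_images=True)
# grid: a single (3, H, W) tensor with n_row = 2 images per grid row; also saved to OUTPUT/vis/demo.png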