├── LICENSE
├── README.md
├── configs
│   ├── ours_nonstyle_best.yaml
│   ├── ours_nonstyle_best_colab.yaml
│   ├── ours_style_best.yaml
│   └── ours_style_best_colab.yaml
├── diffusion_core
│   ├── __init__.py
│   ├── configuration_utils.py
│   ├── custom_forwards
│   │   ├── __init__.py
│   │   └── unet_sd.py
│   ├── diffusion_models.py
│   ├── diffusion_schedulers.py
│   ├── diffusion_utils.py
│   ├── guiders
│   │   ├── __init__.py
│   │   ├── guidance_editing.py
│   │   ├── noise_rescales.py
│   │   ├── opt_guiders.py
│   │   └── scale_schedulers.py
│   ├── inversion
│   │   ├── __init__.py
│   │   ├── negativ_p_inversion.py
│   │   └── null_inversion.py
│   ├── schedulers
│   │   ├── __init__.py
│   │   ├── opt_schedulers.py
│   │   └── sample_schedulers.py
│   └── utils
│       ├── __init__.py
│       ├── class_registry.py
│       ├── grad_checkpoint.py
│       ├── image_utils.py
│       └── model_utils.py
├── docs
│   ├── diagram.png
│   └── teaser_image.png
├── example_images
│   ├── face.png
│   └── zebra.jpeg
├── example_notebooks
│   └── guide_and_rescale.ipynb
└── sd_env.yaml
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Guide-and-Rescale: Self-Guidance Mechanism for Effective Tuning-Free Real Image Editing (ECCV 2024)
2 |
3 |
4 |
5 |
  6 | [License: Apache 2.0](./LICENSE)
7 |
  8 | >Despite recent advances in large-scale text-to-image generative models, manipulating real images with these models remains a challenging problem. The main limitations of existing editing methods are that they either fail to perform with consistent quality on a wide range of image edits, or require time-consuming hyperparameter tuning or fine-tuning of the diffusion model to preserve the image-specific appearance of the input image. Most of these approaches utilize source image information by caching intermediate features and inserting them into the generation process as is. However, this technique produces feature misalignment in the model, which leads to inconsistent results.
  9 | We propose a novel approach that is built upon a modified diffusion sampling process via a guidance mechanism. In this work, we explore a self-guidance technique to preserve the overall structure of the input image and the appearance of its local regions that should not be edited. In particular, we explicitly introduce layout-preserving energy functions that are aimed at preserving the local and global structure of the source image. Additionally, we propose a noise rescaling mechanism that preserves the noise distribution by balancing the norms of classifier-free guidance and our proposed guiders during generation, which leads to more consistent and better editing results. This guiding approach does not require fine-tuning the diffusion model or an exact inversion process. As a result, the proposed method provides a fast and high-quality editing mechanism.
 10 | In our experiments, we show through human evaluation and quantitative analysis that the proposed method produces the desired edits, is preferred by human raters, and achieves a better trade-off between editing quality and preservation of the original image.
11 | >
12 |
 13 | ![teaser](docs/teaser_image.png)
14 |
15 | ## Setup
16 |
 17 | This code uses a pre-trained [Stable Diffusion](https://huggingface.co/docs/diffusers/v0.25.1/en/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline) model from the [Diffusers](https://github.com/huggingface/diffusers#readme) library. We ran our code with Python 3.8.5, PyTorch 2.3.0, and Diffusers 0.17.1 on an NVIDIA A100 GPU with 40 GB of memory.
18 |
 19 | In order to set up the environment, run:
20 | ```
21 | conda env create -f sd_env.yaml
22 | ```
 23 | This creates the conda environment `ldm`; activate it with `conda activate ldm` before running the code.
24 |
25 |
26 | ## Quickstart
27 |
 28 | We provide examples of applying our pipeline to real image editing in Colab.
 29 |
 30 | You can try the Gradio demo in HF Spaces.
31 |
 32 | We also provide [a Jupyter notebook](example_notebooks/guide_and_rescale.ipynb) to try the Guide-and-Rescale pipeline on your own server.
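
For reference, below is a minimal sketch of running the pipeline directly from Python. The use of `OmegaConf` for loading the YAML config, the device placement, and the example prompts are illustrative assumptions; the registries and the `GuidanceEditing` class come from this repository.

```
from PIL import Image
from omegaconf import OmegaConf  # assumption: any dict-like config with attribute access works

from diffusion_core import diffusion_models_registry, diffusion_schedulers_registry
from diffusion_core.guiders import GuidanceEditing

# build the scheduler and the Stable Diffusion pipeline from the registries
config = OmegaConf.load('configs/ours_nonstyle_best.yaml')
scheduler = diffusion_schedulers_registry[config.scheduler_type]()
model = diffusion_models_registry[config.model_name](scheduler).to('cuda')

editor = GuidanceEditing(model, config)

# edit a real image: DDIM-invert it with the source prompt, then denoise with guidance
image = Image.open('example_images/face.png').convert('RGB').resize((512, 512))
result = editor(image, inv_prompt='a photo of a person', trg_prompt='a photo of a smiling person')
Image.fromarray(result).save('edited.png')
```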
33 |
34 | ## Method Diagram
35 |
36 |
 37 | ![Method diagram](docs/diagram.png)
38 |
39 |
40 |
 41 | Overall scheme of the proposed method, Guide-and-Rescale. First, our method applies classic DDIM inversion to the source real image. Then it performs real image editing via the classical denoising process. At every denoising step, the noise term is modified by a guider that utilizes the latents $z_t$ from the current generation process and the time-aligned DDIM latents $z^*_t$.
42 |
43 |
44 |
45 | ## Guiders
46 |
47 | In our work we propose specific **guiders**, i.e. guidance signals suitable for editing. The code for these guiders can be found in [diffusion_core/guiders/opt_guiders.py](diffusion_core/guiders/opt_guiders.py).
48 |
 49 | Every guider is defined as a separate class that inherits from the parent class `BaseGuider`. A template for defining a new guider class looks as follows:
50 |
51 | ```
52 | class SomeGuider(BaseGuider):
53 | patched: bool
54 | forward_hooks: list
55 |
56 | def [grad_fn or calc_energy](self, data_dict):
57 | ...
58 |
59 | def model_patch(self, model):
60 | ...
61 |
62 | def single_output_clear(self):
63 | ...
64 | ```
65 |
66 | ### grad_fn or calc_energy
67 |
 68 | The `BaseGuider` class contains a property `grad_guider`. This property is `True` when the guider does not require any backpropagation over its outputs to retrieve the gradient w.r.t. the current latent (for example, as in classifier-free guidance). In this case, the child class defines a function `grad_fn`, in which the gradient w.r.t. the current latent is computed directly.
 69 |
 70 | When the gradient has to be estimated with backpropagation and `grad_guider` is `False` (for example, when using the norm of the difference of attention maps for guidance), the child class defines a function `calc_energy`, which computes the desired energy value. This output is then used for backpropagation.
 71 |
 72 | The `grad_fn` and `calc_energy` functions receive a dictionary (`data_dict`) as input. In this dictionary we store all objects (the diffusion model instance, prompts, current latent, etc.) that might be useful for the guiders in the current pipeline.
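
As a hypothetical illustration (these guiders are not part of the repository), a gradient-style and an energy-style guider could look as follows; the `data_dict` keys match the ones populated in [diffusion_core/guiders/guidance_editing.py](diffusion_core/guiders/guidance_editing.py):

```
import torch

from diffusion_core.guiders.opt_guiders import BaseGuider, opt_registry


@opt_registry.add_to_registry('example_latents_diff')
class ExampleLatentsDiffGuider(BaseGuider):
    # grad_fn is defined, so `grad_guider` is True: the gradient of ||z_t - z*_t||^2
    # w.r.t. the current latent is written out analytically, no backpropagation needed
    def grad_fn(self, data_dict):
        return 2 * (data_dict['latent'] - data_dict['inv_latent'])


@opt_registry.add_to_registry('example_noise_l2')
class ExampleNoiseL2Guider(BaseGuider):
    # calc_energy is defined instead, so `grad_guider` is False: the pipeline backpropagates
    # this scalar energy to obtain the gradient w.r.t. the current latent
    def calc_energy(self, data_dict):
        return torch.mean((data_dict['trg_prompt_unet'] - data_dict['uncond_unet']) ** 2)

    def single_output_clear(self):
        return None
```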
73 |
74 | ### model_patch and patched
75 |
 76 | When the guider requires outputs of intermediate layers of the diffusion model to estimate the energy function/gradient, we define a function `model_patch` in this guider's class and set the property `patched` to `True`. We will further refer to such guiders as *patched guiders*.
 77 |
 78 | This function patches the desired layers of the diffusion model and retrieves the necessary output from these layers. This output is then stored in the `output` property of the guider class object. This way it can be accessed by the editing pipeline and stored in `data_dict` for further use in the `calc_energy`/`grad_fn` functions.
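
As a sketch, a patched guider in the spirit of the `features_map_l2` guider from [diffusion_core/guiders/opt_guiders.py](diffusion_core/guiders/opt_guiders.py) could look like this (the guider name and the chosen UNet block are illustrative):

```
import torch

from diffusion_core.guiders.opt_guiders import BaseGuider, opt_registry


@opt_registry.add_to_registry('example_feature_l2')
class ExampleFeatureL2Guider(BaseGuider):
    patched = True                           # the pipeline calls model_patch() before editing
    forward_hooks = ['cur_trg', 'inv_inv']   # store `output` after these two forward passes

    def model_patch(self, model, self_attn_layers_num=None):
        # cache the output of one UNet block in self.output on every forward pass
        def hook_fn(module, input, output):
            self.output = output
        model.unet.up_blocks[1].resnets[1].register_forward_hook(hook_fn)

    def calc_energy(self, data_dict):
        # the pipeline stores the cached outputs under '<guider_name>_<forward_pass>' keys
        return torch.mean(
            (data_dict['example_feature_l2_cur_trg'] - data_dict['example_feature_l2_inv_inv']) ** 2
        )

    def single_output_clear(self):
        return None
```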
79 |
80 | ### forward_hooks
81 |
82 | In the editing pipeline we conduct 4 diffusion model forward passes:
83 |
84 | - unconditional, from the current latent $z_t$
85 | - `cur_inv`: conditional on the initial prompt, from the current latent $z_t$
86 | - `inv_inv`: conditional on the initial prompt, from the corresponding inversion latent $z^*_t$
87 | - `cur_trg`: conditional on the prompt describing the editing result, from the current latent $z_t$
88 |
 89 | We store the unconditional prediction in `data_dict`, as well as the outputs of the `cur_inv` and `cur_trg` forward passes, for further use in classifier-free guidance.
 90 |
 91 | However, when the guider is patched, we also have its `output` to store in `data_dict`. In the `forward_hooks` property of the guider class we define the list of forward passes (out of `cur_inv`, `inv_inv`, `cur_trg`) for which we need to store the `output`.
 92 |
 93 | After a specific forward pass is conducted, we can access the `output` of the guider and store it in `data_dict` if that forward pass is listed in `forward_hooks`. We store it under a key specifying the current forward pass.
 94 |
 95 | This way we can avoid storing unnecessary `output`s in `data_dict`, as well as distinguish `output`s from different forward passes by their keys.
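
For reference, the corresponding excerpt from `GuidanceEditing._construct_data_dict` in [diffusion_core/guiders/guidance_editing.py](diffusion_core/guiders/guidance_editing.py), shown here with added comments, handles the `cur_trg` pass as follows:

```
# after the forward pass conditioned on the target prompt from the current latent (`cur_trg`)
for g_name, (guider, _) in self.guiders.items():
    if hasattr(guider, 'model_patch'):            # the guider is patched
        if 'cur_trg' in guider.forward_hooks:     # and it requested this forward pass
            data_dict.update({f"{g_name}_cur_trg": guider.output})
        guider.clear_outputs()                    # always free the cached output
```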
96 |
97 |
98 | ### single_output_clear
99 |
100 | This is only relevant for patched guiders.
101 |
102 | When the data from the `output` property of the guider class object has been stored in `data_dict`, we need to empty the `output` to avoid exceeding the memory limit. For this purpose, we define a `single_output_clear` function. It returns an empty `output`, for example `None` or an empty list `[]`.
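
For example, the `self_attn_map_l2` guider in [diffusion_core/guiders/opt_guiders.py](diffusion_core/guiders/opt_guiders.py) resets its attention-map storage as follows:

```
def single_output_clear(self):
    # one empty list per UNet location; the patched attention layers append their maps here
    return {
        "down_cross": [], "mid_cross": [], "up_cross": [],
        "down_self": [], "mid_self": [], "up_self": []
    }
```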
103 |
104 | ## References & Acknowledgments
105 |
106 | The repository was started from [Prompt-to-Prompt](https://github.com/google/prompt-to-prompt/).
107 |
108 | ## Citation
109 |
110 | If you use this code for your research, please cite our paper:
111 | ```
112 | @article{titov2024guideandrescale,
113 | title={Guide-and-Rescale: Self-Guidance Mechanism for Effective Tuning-Free Real Image Editing},
114 | author={Vadim Titov and Madina Khalmatova and Alexandra Ivanova and Dmitry Vetrov and Aibek Alanov},
115 | journal={arXiv preprint arXiv:2409.01322},
116 | year={2024}
117 | }
118 | ```
119 |
--------------------------------------------------------------------------------
/configs/ours_nonstyle_best.yaml:
--------------------------------------------------------------------------------
1 | scheduler_type: ddim_50_eps
2 | inversion_type: dummy
3 | model_name: stable-diffusion-v1-4
4 | pipeline_type: ours
5 | start_latent: inversion
6 | verbose: false
7 | guiders:
8 | - name: cfg
9 | g_scale: 7.5
10 | kwargs:
11 | is_source_guidance: false
12 | - name: self_attn_map_l2_appearance
13 | g_scale: 1.
14 | kwargs:
15 | self_attn_gs: 300000.
16 | app_gs: 500.
17 | new_features: true
18 | total_first_steps: 30
19 | noise_rescaling_setup:
20 | type: range_other_on_cfg_norm
21 | init_setup:
22 | - 0.33
23 | - 3.0
24 | edit_types:
25 | - animal-2-animal
26 | - face-in-the-wild
27 | - person-in-the-wild
--------------------------------------------------------------------------------
/configs/ours_nonstyle_best_colab.yaml:
--------------------------------------------------------------------------------
1 | scheduler_type: ddim_50_eps
2 | inversion_type: dummy
3 | model_name: stable-diffusion-v1-4
4 | pipeline_type: ours
5 | start_latent: inversion
6 | verbose: false
7 | guiders:
8 | - name: cfg
9 | g_scale: 7.5
10 | kwargs:
11 | is_source_guidance: false
12 | - name: self_attn_map_l2_appearance
13 | g_scale: 1.
14 | kwargs:
15 | self_attn_gs: 300000.
16 | app_gs: 500.
17 | new_features: true
18 | total_first_steps: 30
19 | noise_rescaling_setup:
20 | type: range_other_on_cfg_norm
21 | init_setup:
22 | - 0.33
23 | - 3.0
24 | edit_types:
25 | - animal-2-animal
26 | - face-in-the-wild
27 | - person-in-the-wild
28 | gradient_checkpointing: true
29 | self_attn_layers_num:
30 | - [2, 6]
31 | - [0, 1]
32 | - [0, 4]
33 |
--------------------------------------------------------------------------------
/configs/ours_style_best.yaml:
--------------------------------------------------------------------------------
1 | scheduler_type: ddim_50_eps
2 | inversion_type: dummy
3 | model_name: stable-diffusion-v1-4
4 | pipeline_type: ours
5 | start_latent: inversion
6 | verbose: false
7 | edit_types: stylisation
8 | guiders:
9 | - name: cfg
10 | g_scale: 7.5
11 | kwargs:
12 | is_source_guidance: false
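# g_scale may be a single float or a per-step list (one value per diffusion step; 50 entries here for ddim_50_eps)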
13 | - name: self_attn_map_l2
14 | g_scale:
15 | - 100000.0
16 | - 100000.0
17 | - 100000.0
18 | - 100000.0
19 | - 100000.0
20 | - 100000.0
21 | - 100000.0
22 | - 100000.0
23 | - 100000.0
24 | - 100000.0
25 | - 100000.0
26 | - 100000.0
27 | - 100000.0
28 | - 100000.0
29 | - 100000.0
30 | - 100000.0
31 | - 100000.0
32 | - 100000.0
33 | - 100000.0
34 | - 100000.0
35 | - 100000.0
36 | - 100000.0
37 | - 100000.0
38 | - 100000.0
39 | - 100000.0
40 | - 0.0
41 | - 0.0
42 | - 0.0
43 | - 0.0
44 | - 0.0
45 | - 0.0
46 | - 0.0
47 | - 0.0
48 | - 0.0
49 | - 0.0
50 | - 0.0
51 | - 0.0
52 | - 0.0
53 | - 0.0
54 | - 0.0
55 | - 0.0
56 | - 0.0
57 | - 0.0
58 | - 0.0
59 | - 0.0
60 | - 0.0
61 | - 0.0
62 | - 0.0
63 | - 0.0
64 | - 0.0
65 | kwargs: {}
66 | - name: features_map_l2
67 | g_scale:
68 | - 2.5
69 | - 2.5
70 | - 2.5
71 | - 2.5
72 | - 2.5
73 | - 2.5
74 | - 2.5
75 | - 2.5
76 | - 2.5
77 | - 2.5
78 | - 2.5
79 | - 2.5
80 | - 2.5
81 | - 2.5
82 | - 2.5
83 | - 2.5
84 | - 2.5
85 | - 2.5
86 | - 2.5
87 | - 2.5
88 | - 2.5
89 | - 2.5
90 | - 2.5
91 | - 2.5
92 | - 2.5
93 | - 0.0
94 | - 0.0
95 | - 0.0
96 | - 0.0
97 | - 0.0
98 | - 0.0
99 | - 0.0
100 | - 0.0
101 | - 0.0
102 | - 0.0
103 | - 0.0
104 | - 0.0
105 | - 0.0
106 | - 0.0
107 | - 0.0
108 | - 0.0
109 | - 0.0
110 | - 0.0
111 | - 0.0
112 | - 0.0
113 | - 0.0
114 | - 0.0
115 | - 0.0
116 | - 0.0
117 | - 0.0
118 | kwargs: {}
119 | noise_rescaling_setup:
120 | type: range_other_on_cfg_norm
121 | init_setup:
122 | - 1.5
123 | - 1.5
124 |
--------------------------------------------------------------------------------
/configs/ours_style_best_colab.yaml:
--------------------------------------------------------------------------------
1 | scheduler_type: ddim_50_eps
2 | inversion_type: dummy
3 | model_name: stable-diffusion-v1-4
4 | pipeline_type: ours
5 | start_latent: inversion
6 | verbose: false
7 | edit_types: stylisation
8 | guiders:
9 | - name: cfg
10 | g_scale: 7.5
11 | kwargs:
12 | is_source_guidance: false
13 | - name: self_attn_map_l2
14 | g_scale:
15 | - 100000.0
16 | - 100000.0
17 | - 100000.0
18 | - 100000.0
19 | - 100000.0
20 | - 100000.0
21 | - 100000.0
22 | - 100000.0
23 | - 100000.0
24 | - 100000.0
25 | - 100000.0
26 | - 100000.0
27 | - 100000.0
28 | - 100000.0
29 | - 100000.0
30 | - 100000.0
31 | - 100000.0
32 | - 100000.0
33 | - 100000.0
34 | - 100000.0
35 | - 100000.0
36 | - 100000.0
37 | - 100000.0
38 | - 100000.0
39 | - 100000.0
40 | - 0.0
41 | - 0.0
42 | - 0.0
43 | - 0.0
44 | - 0.0
45 | - 0.0
46 | - 0.0
47 | - 0.0
48 | - 0.0
49 | - 0.0
50 | - 0.0
51 | - 0.0
52 | - 0.0
53 | - 0.0
54 | - 0.0
55 | - 0.0
56 | - 0.0
57 | - 0.0
58 | - 0.0
59 | - 0.0
60 | - 0.0
61 | - 0.0
62 | - 0.0
63 | - 0.0
64 | - 0.0
65 | kwargs: {}
66 | - name: features_map_l2
67 | g_scale:
68 | - 2.5
69 | - 2.5
70 | - 2.5
71 | - 2.5
72 | - 2.5
73 | - 2.5
74 | - 2.5
75 | - 2.5
76 | - 2.5
77 | - 2.5
78 | - 2.5
79 | - 2.5
80 | - 2.5
81 | - 2.5
82 | - 2.5
83 | - 2.5
84 | - 2.5
85 | - 2.5
86 | - 2.5
87 | - 2.5
88 | - 2.5
89 | - 2.5
90 | - 2.5
91 | - 2.5
92 | - 2.5
93 | - 0.0
94 | - 0.0
95 | - 0.0
96 | - 0.0
97 | - 0.0
98 | - 0.0
99 | - 0.0
100 | - 0.0
101 | - 0.0
102 | - 0.0
103 | - 0.0
104 | - 0.0
105 | - 0.0
106 | - 0.0
107 | - 0.0
108 | - 0.0
109 | - 0.0
110 | - 0.0
111 | - 0.0
112 | - 0.0
113 | - 0.0
114 | - 0.0
115 | - 0.0
116 | - 0.0
117 | - 0.0
118 | kwargs: {}
119 | noise_rescaling_setup:
120 | type: range_other_on_cfg_norm
121 | init_setup:
122 | - 1.5
123 | - 1.5
124 | gradient_checkpointing: true
125 | self_attn_layers_num:
126 | - [2, 6]
127 | - [0, 1]
128 | - [0, 4]
129 |
--------------------------------------------------------------------------------
/diffusion_core/__init__.py:
--------------------------------------------------------------------------------
1 | from .diffusion_models import diffusion_models_registry
2 | from .diffusion_schedulers import diffusion_schedulers_registry
3 |
--------------------------------------------------------------------------------
/diffusion_core/configuration_utils.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import inspect
3 |
4 |
5 | class MethodStorage(dict):
6 | @staticmethod
7 | def register(method_name):
8 | def decorator(method):
9 | method._method_name = method_name
10 | method._is_decorated = True
11 | return method
12 | return decorator
13 |
14 | def register_methods(self):
15 | self.registered_methods = {}
16 | for method_name, method in inspect.getmembers(self, predicate=inspect.ismethod):
17 |             # keep only the methods registered via the MethodStorage.register decorator
18 | if getattr(method, '_is_decorated', False):
19 | self.registered_methods[method._method_name] = method
20 |
21 | def __getitem__(self, method_name):
22 | return self.registered_methods[method_name]
23 |
--------------------------------------------------------------------------------
/diffusion_core/custom_forwards/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FusionBrainLab/Guide-and-Rescale/4a79013c641d20f57b206ad7f0bf2e9d01cd412c/diffusion_core/custom_forwards/__init__.py
--------------------------------------------------------------------------------
/diffusion_core/custom_forwards/unet_sd.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from typing import Optional, Tuple, Union, Dict, Any
4 |
5 | from diffusion_core.utils import checkpoint_forward
6 |
7 |
8 | @checkpoint_forward
9 | def unet_down_forward(downsample_block, sample, emb, encoder_hidden_states, attention_mask, cross_attention_kwargs):
10 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
11 | sample, res_samples = downsample_block(
12 | hidden_states=sample,
13 | temb=emb,
14 | encoder_hidden_states=encoder_hidden_states,
15 | attention_mask=attention_mask,
16 | cross_attention_kwargs=cross_attention_kwargs,
17 | )
18 | else:
19 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
20 | return sample, res_samples
21 |
22 |
23 | @checkpoint_forward
24 | def unet_mid_forward(mid_block, sample, emb, encoder_hidden_states, attention_mask, cross_attention_kwargs):
25 | sample = mid_block(
26 | sample,
27 | emb,
28 | encoder_hidden_states=encoder_hidden_states,
29 | attention_mask=attention_mask,
30 | cross_attention_kwargs=cross_attention_kwargs,
31 | )
32 | return sample
33 |
34 |
35 | @checkpoint_forward
36 | def unet_up_forward(upsample_block, sample, emb, res_samples, encoder_hidden_states, cross_attention_kwargs,
37 | upsample_size, attention_mask):
38 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
39 | sample = upsample_block(
40 | hidden_states=sample,
41 | temb=emb,
42 | res_hidden_states_tuple=res_samples,
43 | encoder_hidden_states=encoder_hidden_states,
44 | cross_attention_kwargs=cross_attention_kwargs,
45 | upsample_size=upsample_size,
46 | attention_mask=attention_mask,
47 | )
48 | else:
49 | sample = upsample_block(
50 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
51 | )
52 | return sample
53 |
54 |
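# unet_forward mirrors the diffusers UNet2DConditionModel forward pass (time embedding,
# down / mid / up blocks, post-processing), but routes every block through the
# checkpoint_forward-wrapped helpers above so that guidance gradients can be backpropagated
# with optional gradient checkpointing; ControlNet residuals are added when controlnet_cond is given.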
55 | def unet_forward(
56 | model,
57 | sample: torch.FloatTensor,
58 | timestep: Union[torch.Tensor, float, int],
59 | encoder_hidden_states: torch.Tensor,
60 | controlnet_cond=None,
61 | controlnet_conditioning_scale=1.,
62 | class_labels: Optional[torch.Tensor] = None,
63 | timestep_cond: Optional[torch.Tensor] = None,
64 | attention_mask: Optional[torch.Tensor] = None,
65 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
66 | return_dict: bool = True,
67 | ):
68 | if controlnet_cond is not None:
69 | down_block_additional_residuals, mid_block_additional_residual = model.controlnet(
70 | sample,
71 | timestep,
72 | encoder_hidden_states=encoder_hidden_states,
73 | controlnet_cond=controlnet_cond,
74 | return_dict=False,
75 | )
76 |
77 | down_block_additional_residuals = [
78 | down_block_res_sample * controlnet_conditioning_scale
79 | for down_block_res_sample in down_block_additional_residuals
80 | ]
81 | mid_block_additional_residual *= controlnet_conditioning_scale
82 | else:
83 | down_block_additional_residuals = None
84 | mid_block_additional_residual = None
85 |
86 | default_overall_up_factor = 2 ** model.unet.num_upsamplers
87 |
88 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
89 | forward_upsample_size = False
90 | upsample_size = None
91 |
92 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
 93 |         # sample size is not a multiple of the default upsample factor, so forward the upsample size
94 | forward_upsample_size = True
95 |
96 | # prepare attention_mask
97 | if attention_mask is not None:
98 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
99 | attention_mask = attention_mask.unsqueeze(1)
100 |
101 | # 0. center input if necessary
102 | if model.unet.config.center_input_sample:
103 | sample = 2 * sample - 1.0
104 |
105 | # 1. time
106 | timesteps = timestep
107 | if not torch.is_tensor(timesteps):
108 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
109 | # This would be a good case for the `match` statement (Python 3.10+)
110 | is_mps = sample.device.type == "mps"
111 | if isinstance(timestep, float):
112 | dtype = torch.float32 if is_mps else torch.float64
113 | else:
114 | dtype = torch.int32 if is_mps else torch.int64
115 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
116 | elif len(timesteps.shape) == 0:
117 | timesteps = timesteps[None].to(sample.device)
118 |
119 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
120 | timesteps = timesteps.expand(sample.shape[0])
121 |
122 | t_emb = model.unet.time_proj(timesteps)
123 |
124 | # timesteps does not contain any weights and will always return f32 tensors
125 | # but time_embedding might actually be running in fp16. so we need to cast here.
126 | # there might be better ways to encapsulate this.
127 | t_emb = t_emb.to(dtype=model.unet.dtype)
128 |
129 | emb = model.unet.time_embedding(t_emb, timestep_cond)
130 |
131 | if model.unet.class_embedding is not None:
132 | if class_labels is None:
133 | raise ValueError("class_labels should be provided when num_class_embeds > 0")
134 |
135 | if model.unet.config.class_embed_type == "timestep":
136 | class_labels = model.unet.time_proj(class_labels)
137 |
138 | class_emb = model.unet.class_embedding(class_labels).to(dtype=model.unet.dtype)
139 | emb = emb + class_emb
140 |
141 | # 2. pre-process
142 | sample = model.unet.conv_in(sample)
143 |
144 | # 3. down
145 | down_block_res_samples = (sample,)
146 | for downsample_block in model.unet.down_blocks:
147 | sample, res_samples = unet_down_forward(downsample_block, sample, emb, encoder_hidden_states, attention_mask,
148 | cross_attention_kwargs)
149 |
150 | down_block_res_samples += res_samples
151 |
152 | if down_block_additional_residuals is not None:
153 | new_down_block_res_samples = ()
154 |
155 | for down_block_res_sample, down_block_additional_residual in zip(
156 | down_block_res_samples, down_block_additional_residuals
157 | ):
158 | down_block_res_sample = down_block_res_sample + down_block_additional_residual
159 | new_down_block_res_samples += (down_block_res_sample,)
160 |
161 | down_block_res_samples = new_down_block_res_samples
162 |
163 | # 4. mid
164 | if model.unet.mid_block is not None:
165 | sample = unet_mid_forward(model.unet.mid_block, sample, emb, encoder_hidden_states, attention_mask,
166 | cross_attention_kwargs)
167 |
168 | if mid_block_additional_residual is not None:
169 | sample = sample + mid_block_additional_residual
170 |
171 | # 5. up
172 | for i, upsample_block in enumerate(model.unet.up_blocks):
173 | is_final_block = i == len(model.unet.up_blocks) - 1
174 |
175 | res_samples = down_block_res_samples[-len(upsample_block.resnets):]
176 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
177 |
178 | # if we have not reached the final block and need to forward the
179 | # upsample size, we do it here
180 | if not is_final_block and forward_upsample_size:
181 | upsample_size = down_block_res_samples[-1].shape[2:]
182 |
183 | sample = unet_up_forward(upsample_block, sample, emb, res_samples, encoder_hidden_states,
184 | cross_attention_kwargs, upsample_size, attention_mask)
185 |
186 | # 6. post-process
187 | if model.unet.conv_norm_out:
188 | sample = model.unet.conv_norm_out(sample)
189 | sample = model.unet.conv_act(sample)
190 | sample = model.unet.conv_out(sample)
191 |
192 | # if not return_dict:
193 | # return (sample,)
194 |
195 | return sample
196 |
--------------------------------------------------------------------------------
/diffusion_core/diffusion_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from diffusers.pipelines import StableDiffusionPipeline
4 | from .utils import ClassRegistry
5 |
6 | diffusion_models_registry = ClassRegistry()
7 |
8 |
9 | @diffusion_models_registry.add_to_registry("stable-diffusion-v1-4")
10 | def read_v14(scheduler):
11 | model_id = "CompVis/stable-diffusion-v1-4"
12 | model = StableDiffusionPipeline.from_pretrained(
13 | model_id, torch_dtype=torch.float32, scheduler=scheduler
14 | )
15 | return model
16 |
17 |
18 | @diffusion_models_registry.add_to_registry("stable-diffusion-v1-5")
19 | def read_v15(scheduler):
20 | model_id = "runwayml/stable-diffusion-v1-5"
21 | model = StableDiffusionPipeline.from_pretrained(
22 | model_id, torch_dtype=torch.float32, scheduler=scheduler
23 | )
24 | return model
25 |
26 |
27 | @diffusion_models_registry.add_to_registry("stable-diffusion-v2-1")
28 | def read_v21(scheduler):
29 | model_id = "stabilityai/stable-diffusion-2-1"
30 | model = StableDiffusionPipeline.from_pretrained(
31 | model_id, torch_dtype=torch.float32, scheduler=scheduler
32 | )
33 | return model
34 |
--------------------------------------------------------------------------------
/diffusion_core/diffusion_schedulers.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from diffusers.pipelines import StableDiffusionPipeline
4 |
5 | from diffusion_core.schedulers import DDIMScheduler
6 | from .utils import ClassRegistry
7 |
8 |
9 | diffusion_schedulers_registry = ClassRegistry()
10 |
11 |
12 | @diffusion_schedulers_registry.add_to_registry("ddim_50_eps")
13 | def get_ddim_50_e():
14 | scheduler = DDIMScheduler(
15 | beta_start=0.00085,
16 | beta_end=0.012,
17 | beta_schedule="scaled_linear",
18 | set_alpha_to_one=False,
19 | num_inference_steps=50,
20 | prediction_type='epsilon'
21 | )
22 | return scheduler
23 |
24 |
25 | @diffusion_schedulers_registry.add_to_registry("ddim_50_v")
26 | def get_ddim_50_v():
27 | scheduler = DDIMScheduler(
28 | beta_start=0.00085,
29 | beta_end=0.012,
30 | beta_schedule="scaled_linear",
31 | set_alpha_to_one=False,
32 | num_inference_steps=50,
33 | prediction_type='v_prediction'
34 | )
35 | return scheduler
36 |
37 |
38 | @diffusion_schedulers_registry.add_to_registry("ddim_200_v")
39 | def get_ddim_200_v():
40 | scheduler = DDIMScheduler(
41 | beta_start=0.00085,
42 | beta_end=0.012,
43 | beta_schedule="scaled_linear",
44 | set_alpha_to_one=False,
45 | num_inference_steps=200,
46 | prediction_type='v_prediction'
47 | )
48 | return scheduler
49 |
--------------------------------------------------------------------------------
/diffusion_core/diffusion_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | from PIL import Image
5 |
6 |
7 | @torch.no_grad()
8 | def latent2image(latents, model, return_type='np'):
9 | latents = latents.detach() / model.vae.config.scaling_factor
10 | image = model.vae.decode(latents)['sample']
11 | if return_type == 'np':
12 | image = (image / 2 + 0.5).clamp(0, 1)
13 | image = image.cpu().permute(0, 2, 3, 1).numpy()
14 | image = (image * 255).astype(np.uint8)
15 | return image
16 |
17 |
18 | @torch.no_grad()
19 | def image2latent(image, model):
20 |     if isinstance(image, Image.Image):
21 | image = np.array(image)
22 | if type(image) is torch.Tensor and image.dim() == 4:
23 | latents = image
24 | else:
25 | image = torch.from_numpy(image).float() / 127.5 - 1
26 | image = image.permute(2, 0, 1).unsqueeze(0).to(model.device).to(model.unet.dtype)
27 | latents = model.vae.encode(image)['latent_dist'].mean
28 | latents = latents * model.vae.config.scaling_factor
29 | return latents
30 |
31 |
--------------------------------------------------------------------------------
/diffusion_core/guiders/__init__.py:
--------------------------------------------------------------------------------
1 | from .guidance_editing import GuidanceEditing
--------------------------------------------------------------------------------
/diffusion_core/guiders/guidance_editing.py:
--------------------------------------------------------------------------------
1 | import json
2 | from collections import OrderedDict
3 | from typing import Callable, Dict, Optional
4 |
5 | import torch
6 | import numpy as np
7 | import PIL
8 | import gc
9 |
10 | from tqdm.auto import trange, tqdm
11 | from diffusion_core.guiders.opt_guiders import opt_registry
12 | from diffusion_core.diffusion_utils import latent2image, image2latent
13 | from diffusion_core.custom_forwards.unet_sd import unet_forward
14 | from diffusion_core.guiders.noise_rescales import noise_rescales
15 | from diffusion_core.inversion import Inversion, NullInversion, NegativePromptInversion
16 | from diffusion_core.utils import toggle_grad, use_grad_checkpointing
17 |
18 |
19 | class GuidanceEditing:
20 | def __init__(
21 | self,
22 | model,
23 | config
24 | ):
25 |
26 | self.config = config
27 | self.model = model
28 |
29 | toggle_grad(self.model.unet, False)
30 |
31 | if config.get('gradient_checkpointing', False):
32 | use_grad_checkpointing(mode=True)
33 | else:
34 | use_grad_checkpointing(mode=False)
35 |
36 | self.guiders = {
37 | g_data.name: (opt_registry[g_data.name](**g_data.get('kwargs', {})), g_data.g_scale)
38 | for g_data in config.guiders
39 | }
40 |
41 | self._setup_inversion_engine()
42 | self.latents_stack = []
43 |
44 | self.context = None
45 |
46 | self.noise_rescaler = noise_rescales[config.noise_rescaling_setup.type](
47 | config.noise_rescaling_setup.init_setup,
48 | **config.noise_rescaling_setup.get('kwargs', {})
49 | )
50 |
51 | for guider_name, (guider, _) in self.guiders.items():
52 | guider.clear_outputs()
53 |
54 | self.self_attn_layers_num = config.get('self_attn_layers_num', [6, 1, 9])
55 | if type(self.self_attn_layers_num[0]) is int:
56 | for i in range(len(self.self_attn_layers_num)):
57 | self.self_attn_layers_num[i] = (0, self.self_attn_layers_num[i])
58 |
59 |
60 | def _setup_inversion_engine(self):
61 | if self.config.inversion_type == 'ntinv':
62 | self.inversion_engine = NullInversion(
63 | self.model,
64 | self.model.scheduler.num_inference_steps,
65 | self.config.guiders[0]['g_scale'],
66 | forward_guidance_scale=1,
67 | verbose=self.config.verbose
68 | )
69 | elif self.config.inversion_type == 'npinv':
70 | self.inversion_engine = NegativePromptInversion(
71 | self.model,
72 | self.model.scheduler.num_inference_steps,
73 | self.config.guiders[0]['g_scale'],
74 | forward_guidance_scale=1,
75 | verbose=self.config.verbose
76 | )
77 | elif self.config.inversion_type == 'dummy':
78 | self.inversion_engine = Inversion(
79 | self.model,
80 | self.model.scheduler.num_inference_steps,
81 | self.config.guiders[0]['g_scale'],
82 | forward_guidance_scale=1,
83 | verbose=self.config.verbose
84 | )
85 | else:
86 | raise ValueError('Incorrect InversionType')
87 |
88 | def __call__(
89 | self,
90 | image_gt: PIL.Image.Image,
91 | inv_prompt: str,
92 | trg_prompt: str,
93 | control_image: Optional[PIL.Image.Image] = None,
94 | verbose: bool = False
95 | ):
96 | self.train(
97 | image_gt,
98 | inv_prompt,
99 | trg_prompt,
100 | control_image,
101 | verbose
102 | )
103 |
104 | return self.edit()
105 |
106 | def train(
107 | self,
108 | image_gt: PIL.Image.Image,
109 | inv_prompt: str,
110 | trg_prompt: str,
111 | control_image: Optional[PIL.Image.Image] = None,
112 | verbose: bool = False
113 | ):
114 | self.init_prompt(inv_prompt, trg_prompt)
115 | self.verbose = verbose
116 |
117 | image_gt = np.array(image_gt)
118 | if self.config.start_latent == 'inversion':
119 | _, self.inv_latents, self.uncond_embeddings = self.inversion_engine(
120 | image_gt, inv_prompt,
121 | verbose=self.verbose
122 | )
123 | elif self.config.start_latent == 'random':
124 | self.inv_latents = self.sample_noised_latents(
125 | image2latent(image_gt, self.model)
126 | )
127 | else:
128 | raise ValueError('Incorrect start latent type')
129 |
130 | for g_name, (guider, _) in self.guiders.items():
131 | if hasattr(guider, 'model_patch'):
132 | guider.model_patch(self.model, self_attn_layers_num=self.self_attn_layers_num)
133 |
134 | self.start_latent = self.inv_latents[-1].clone()
135 |
136 | params = {
137 | 'model': self.model,
138 | 'inv_prompt': inv_prompt,
139 | 'trg_prompt': trg_prompt
140 | }
141 | for g_name, (guider, _) in self.guiders.items():
142 | if hasattr(guider, 'train'):
143 | guider.train(params)
144 |
145 | for guider_name, (guider, _) in self.guiders.items():
146 | guider.clear_outputs()
147 |
148 | def _construct_data_dict(
149 | self, latents,
150 | diffusion_iter,
151 | timestep
152 | ):
153 | uncond_emb, inv_prompt_emb, trg_prompt_emb = self.context.chunk(3)
154 |
155 | if self.uncond_embeddings is not None:
156 | uncond_emb = self.uncond_embeddings[diffusion_iter]
157 |
158 | data_dict = {
159 | 'latent': latents,
160 | 'inv_latent': self.inv_latents[-diffusion_iter - 1],
161 | 'timestep': timestep,
162 | 'model': self.model,
163 | 'uncond_emb': uncond_emb,
164 | 'trg_emb': trg_prompt_emb,
165 | 'inv_emb': inv_prompt_emb,
166 | 'diff_iter': diffusion_iter
167 | }
168 |
169 | with torch.no_grad():
170 | uncond_unet = unet_forward(
171 | self.model,
172 | data_dict['latent'],
173 | data_dict['timestep'],
174 | data_dict['uncond_emb'],
175 | None
176 | )
177 |
178 | for g_name, (guider, _) in self.guiders.items():
179 | if hasattr(guider, 'model_patch'):
180 | guider.clear_outputs()
181 |
182 | with torch.no_grad():
183 | inv_prompt_unet = unet_forward(
184 | self.model,
185 | data_dict['inv_latent'],
186 | data_dict['timestep'],
187 | data_dict['inv_emb'],
188 | None
189 | )
190 |
191 | for g_name, (guider, _) in self.guiders.items():
192 | if hasattr(guider, 'model_patch'):
193 | if 'inv_inv' in guider.forward_hooks:
194 | data_dict.update({f"{g_name}_inv_inv": guider.output})
195 | guider.clear_outputs()
196 |
197 | data_dict['latent'].requires_grad = True
198 |
199 | src_prompt_unet = unet_forward(
200 | self.model,
201 | data_dict['latent'],
202 | data_dict['timestep'],
203 | data_dict['inv_emb'],
204 | None
205 | )
206 |
207 | for g_name, (guider, _) in self.guiders.items():
208 | if hasattr(guider, 'model_patch'):
209 | if 'cur_inv' in guider.forward_hooks:
210 | data_dict.update({f"{g_name}_cur_inv": guider.output})
211 | guider.clear_outputs()
212 |
213 | trg_prompt_unet = unet_forward(
214 | self.model,
215 | data_dict['latent'],
216 | data_dict['timestep'],
217 | data_dict['trg_emb'],
218 | None
219 | )
220 |
221 | for g_name, (guider, _) in self.guiders.items():
222 | if hasattr(guider, 'model_patch'):
223 | if 'cur_trg' in guider.forward_hooks:
224 | data_dict.update({f"{g_name}_cur_trg": guider.output})
225 | guider.clear_outputs()
226 |
227 |         data_dict.update({
228 |             'uncond_unet': uncond_unet,
229 |             'src_prompt_unet': src_prompt_unet, 'trg_prompt_unet': trg_prompt_unet,
230 |         })
231 |
232 | return data_dict
233 |
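    # Combine the guider outputs into a single noise prediction: grad-based guiders (e.g. cfg)
    # contribute noise terms directly, energy-based guiders are summed and backpropagated to
    # obtain a gradient w.r.t. the current latent ('other'), and the noise rescaler balances
    # the norm of that gradient against the cfg term before the weighted sum.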
234 | def _get_noise(self, data_dict, diffusion_iter):
235 | backward_guiders_sum = 0.
236 | noises = {
237 | 'uncond': data_dict['uncond_unet'],
238 | }
239 | index = torch.where(self.model.scheduler.timesteps == data_dict['timestep'])[0].item()
240 |
241 | # self.noise_rescaler
242 | for name, (guider, g_scale) in self.guiders.items():
243 | if guider.grad_guider:
244 | cur_noise_pred = self._get_scale(g_scale, diffusion_iter) * guider(data_dict)
245 | noises[name] = cur_noise_pred
246 | else:
247 | energy = self._get_scale(g_scale, diffusion_iter) * guider(data_dict)
248 | if not torch.allclose(energy, torch.tensor(0.)):
249 | backward_guiders_sum += energy
250 |
251 | if hasattr(backward_guiders_sum, 'backward'):
252 | backward_guiders_sum.backward()
253 | noises['other'] = data_dict['latent'].grad
254 |
255 | scales = self.noise_rescaler(noises, index)
256 | noise_pred = sum(scales[k] * noises[k] for k in noises)
257 |
258 | for g_name, (guider, _) in self.guiders.items():
259 | if not guider.grad_guider:
260 | guider.clear_outputs()
261 | gc.collect()
262 | torch.cuda.empty_cache()
263 |
264 | return noise_pred
265 |
266 | @staticmethod
267 | def _get_scale(g_scale, diffusion_iter):
268 | if type(g_scale) is float:
269 | return g_scale
270 | else:
271 | return g_scale[diffusion_iter]
272 |
273 | @torch.no_grad()
274 | def _step(self, noise_pred, t, latents):
275 | latents = self.model.scheduler.step_backward(noise_pred, t, latents).prev_sample
276 | self.latents_stack.append(latents.detach())
277 | return latents
278 |
279 | def edit(self):
280 | self.model.scheduler.set_timesteps(self.model.scheduler.num_inference_steps)
281 | latents = self.start_latent
282 | self.latents_stack = []
283 |
284 | for i, timestep in tqdm(
285 | enumerate(self.model.scheduler.timesteps),
286 | total=self.model.scheduler.num_inference_steps,
287 | desc='Editing',
288 | disable=not self.verbose
289 | ):
290 | # 1. Construct dict
291 | data_dict = self._construct_data_dict(latents, i, timestep)
292 |
293 | # 2. Calculate guidance
294 | noise_pred = self._get_noise(data_dict, i)
295 |
296 | # 3. Scheduler step
297 | latents = self._step(noise_pred, timestep, latents)
298 |
299 | self._model_unpatch(self.model)
300 | return latent2image(latents, self.model)[0]
301 |
302 | @torch.no_grad()
303 | def init_prompt(self, inv_prompt: str, trg_prompt: str):
304 | trg_prompt_embed = self.get_prompt_embed(trg_prompt)
305 | inv_prompt_embed = self.get_prompt_embed(inv_prompt)
306 | uncond_embed = self.get_prompt_embed("")
307 |
308 | self.context = torch.cat([uncond_embed, inv_prompt_embed, trg_prompt_embed])
309 |
310 | def get_prompt_embed(self, prompt: str):
311 | text_input = self.model.tokenizer(
312 | [prompt],
313 | padding="max_length",
314 | max_length=self.model.tokenizer.model_max_length,
315 | truncation=True,
316 | return_tensors="pt",
317 | )
318 | text_embeddings = self.model.text_encoder(
319 | text_input.input_ids.to(self.model.device)
320 | )[0]
321 |
322 | return text_embeddings
323 |
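    # Build the forward (noising) DDIM trajectory of the clean latent, used when
    # config.start_latent == 'random': at every step the latent is mixed with fresh Gaussian
    # noise according to the ratio of cumulative alphas between consecutive timesteps.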
324 | def sample_noised_latents(self, latent):
325 | all_latent = [latent.clone().detach()]
326 | latent = latent.clone().detach()
327 | for i in trange(self.model.scheduler.num_inference_steps, desc='Latent Sampling'):
328 | timestep = self.model.scheduler.timesteps[-i - 1]
329 | if i + 1 < len(self.model.scheduler.timesteps):
330 | next_timestep = self.model.scheduler.timesteps[- i - 2]
331 | else:
332 | next_timestep = 999
333 |
334 | alpha_prod_t = self.model.scheduler.alphas_cumprod[timestep]
335 | alpha_prod_t_next = self.model.scheduler.alphas_cumprod[next_timestep]
336 |
337 | alpha_slice = alpha_prod_t_next / alpha_prod_t
338 |
339 | latent = torch.sqrt(alpha_slice) * latent + torch.sqrt(1 - alpha_slice) * torch.randn_like(latent)
340 | all_latent.append(latent)
341 | return all_latent
342 |
343 | def _model_unpatch(self, model):
344 | def new_forward_info(self):
345 | def patched_forward(
346 | hidden_states,
347 | encoder_hidden_states=None,
348 | attention_mask=None,
349 | temb=None,
350 | ):
351 | residual = hidden_states
352 |
353 | if self.spatial_norm is not None:
354 | hidden_states = self.spatial_norm(hidden_states, temb)
355 |
356 | input_ndim = hidden_states.ndim
357 |
358 | if input_ndim == 4:
359 | batch_size, channel, height, width = hidden_states.shape
360 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
361 |
362 | batch_size, sequence_length, _ = (
363 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
364 | )
365 | attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)
366 |
367 | if self.group_norm is not None:
368 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
369 |
370 | query = self.to_q(hidden_states)
371 |
372 | ## Injection
373 | is_self = encoder_hidden_states is None
374 |
375 | if encoder_hidden_states is None:
376 | encoder_hidden_states = hidden_states
377 | elif self.norm_cross:
378 | encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)
379 |
380 | key = self.to_k(encoder_hidden_states)
381 | value = self.to_v(encoder_hidden_states)
382 |
383 | query = self.head_to_batch_dim(query)
384 | key = self.head_to_batch_dim(key)
385 | value = self.head_to_batch_dim(value)
386 |
387 | attention_probs = self.get_attention_scores(query, key, attention_mask)
388 | hidden_states = torch.bmm(attention_probs, value)
389 | hidden_states = self.batch_to_head_dim(hidden_states)
390 |
391 | # linear proj
392 | hidden_states = self.to_out[0](hidden_states)
393 | # dropout
394 | hidden_states = self.to_out[1](hidden_states)
395 |
396 | if input_ndim == 4:
397 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
398 |
399 | if self.residual_connection:
400 | hidden_states = hidden_states + residual
401 |
402 | hidden_states = hidden_states / self.rescale_output_factor
403 |
404 | return hidden_states
405 |
406 | return patched_forward
407 |
408 | def register_attn(module):
409 | if 'Attention' in module.__class__.__name__:
410 | module.forward = new_forward_info(module)
411 | elif hasattr(module, 'children'):
412 | for module_ in module.children():
413 | register_attn(module_)
414 |
415 | def remove_hooks(module):
416 | if hasattr(module, "_forward_hooks"):
417 | module._forward_hooks: Dict[int, Callable] = OrderedDict()
418 | if hasattr(module, 'children'):
419 | for module_ in module.children():
420 | remove_hooks(module_)
421 |
422 | register_attn(model.unet)
423 | remove_hooks(model.unet)
424 |
--------------------------------------------------------------------------------
/diffusion_core/guiders/noise_rescales.py:
--------------------------------------------------------------------------------
1 | import enum
2 | import torch
3 | import numpy as np
4 |
5 | from diffusion_core.utils import ClassRegistry
6 |
7 |
8 | noise_rescales = ClassRegistry()
9 |
10 |
11 | class RescaleType(enum.Enum):
12 | UPPER = 0
13 | RANGE = 1
14 |
15 |
16 | class BaseNoiseRescaler:
17 | def __init__(self, noise_rescale_setup):
18 | if isinstance(noise_rescale_setup, float):
19 | self.upper_bound = noise_rescale_setup
20 | self.rescale_type = RescaleType.UPPER
21 | elif len(noise_rescale_setup) == 2:
22 |             self.lower_bound, self.upper_bound = noise_rescale_setup
23 | self.rescale_type = RescaleType.RANGE
24 | else:
25 | raise TypeError('Incorrect noise_rescale_setup type: possible types are float, tuple(float, float)')
26 |
27 | def __call__(self, noises, index):
28 | if 'other' not in noises:
29 | return {k: 1. for k in noises}
30 | rescale_dict = self._rescale(noises, index)
31 | return rescale_dict
32 |
33 | def _rescale(self, noises, index):
34 | raise NotImplementedError('')
35 |
36 |
37 | @noise_rescales.add_to_registry('identity_rescaler')
38 | class IdentityRescaler:
39 | def __init__(self, *args, **kwargs):
40 | pass
41 |
42 | def __call__(self, noises, index):
43 | return {k: 1. for k in noises}
44 |
45 |
46 | @noise_rescales.add_to_registry('range_other_on_cfg_norm')
47 | class RangeNoiseRescaler(BaseNoiseRescaler):
48 | def __init__(self, noise_rescale_setup):
49 | super().__init__(noise_rescale_setup)
50 | assert len(noise_rescale_setup) == 2, 'incorrect len of noise_rescale_setup'
51 | self.lower_bound, self.upper_bound = noise_rescale_setup
52 |
53 | def _rescale(self, noises, index):
54 | cfg_noise_norm = torch.norm(noises['cfg']).item()
55 | other_noise_norm = torch.norm(noises['other']).item()
56 |
57 | ratio = other_noise_norm / cfg_noise_norm if cfg_noise_norm != 0 else 1.
58 | ratio_clipped = np.clip(ratio, self.lower_bound, self.upper_bound)
59 | if other_noise_norm != 0.:
60 | other_scale = ratio_clipped / ratio
61 | else:
62 | other_scale = 1.
63 |
64 | answer = {
65 | 'cfg': 1.,
66 | 'uncond': 1.,
67 | 'other': other_scale
68 | }
69 | return answer
70 |
--------------------------------------------------------------------------------
/diffusion_core/guiders/opt_guiders.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Optional
3 |
4 | from diffusion_core.utils.class_registry import ClassRegistry
5 | from diffusion_core.guiders.scale_schedulers import last_steps, first_steps
6 |
7 | opt_registry = ClassRegistry()
8 |
9 | class BaseGuider:
10 | def __init__(self):
11 | self.clear_outputs()
12 |
13 | @property
14 | def grad_guider(self):
15 | return hasattr(self, 'grad_fn')
16 |
17 | def __call__(self, data_dict):
18 | if self.grad_guider:
19 | return self.grad_fn(data_dict)
20 | else:
21 | return self.calc_energy(data_dict)
22 |
23 | def clear_outputs(self):
24 | if not self.grad_guider:
25 | self.output = self.single_output_clear()
26 |
27 | def single_output_clear(self):
28 | raise NotImplementedError()
29 |
30 |
31 | @opt_registry.add_to_registry('cfg')
32 | class ClassifierFreeGuidance(BaseGuider):
33 | def __init__(self, is_source_guidance=False):
34 | self.is_source_guidance = is_source_guidance
35 |
36 | def grad_fn(self, data_dict):
37 | prompt_unet = data_dict['src_prompt_unet'] if self.is_source_guidance else data_dict['trg_prompt_unet']
38 | return prompt_unet - data_dict['uncond_unet']
39 |
40 |
41 | @opt_registry.add_to_registry('latents_diff')
42 | class LatentsDiffGuidance(BaseGuider):
43 | """
44 | \| z_t* - z_t \|^2_2
45 | """
46 | def grad_fn(self, data_dict):
47 | return 2 * (data_dict['latent'] - data_dict['inv_latent'])
48 |
49 |
50 | @opt_registry.add_to_registry('features_map_l2')
51 | class FeaturesMapL2EnergyGuider(BaseGuider):
52 | def __init__(self, block='up'):
53 | assert block in ['down', 'up', 'mid', 'whole']
54 | self.block = block
55 |
56 | patched = True
57 | forward_hooks = ['cur_trg', 'inv_inv']
58 | def calc_energy(self, data_dict):
59 | return torch.mean(torch.pow(data_dict['features_map_l2_cur_trg'] - data_dict['features_map_l2_inv_inv'], 2))
60 |
61 | def model_patch(self, model, self_attn_layers_num=None):
62 | def hook_fn(module, input, output):
63 | self.output = output
64 | if self.block == 'mid':
65 | model.unet.mid_block.register_forward_hook(hook_fn)
66 | elif self.block == 'up':
67 | model.unet.up_blocks[1].resnets[1].register_forward_hook(hook_fn)
68 | elif self.block == 'down':
69 | model.unet.down_blocks[1].resnets[1].register_forward_hook(hook_fn)
70 |
71 | def single_output_clear(self):
72 | return None
73 |
74 |
75 | @opt_registry.add_to_registry('self_attn_map_l2')
76 | class SelfAttnMapL2EnergyGuider(BaseGuider):
77 | patched = True
78 | forward_hooks = ['cur_inv', 'inv_inv']
79 | def single_output_clear(self):
80 | return {
81 | "down_cross": [], "mid_cross": [], "up_cross": [],
82 | "down_self": [], "mid_self": [], "up_self": []
83 | }
84 |
85 | def calc_energy(self, data_dict):
86 | result = 0.
87 | for unet_place, data in data_dict['self_attn_map_l2_cur_inv'].items():
88 | for elem_idx, elem in enumerate(data):
89 | result += torch.mean(
90 | torch.pow(
91 | elem - data_dict['self_attn_map_l2_inv_inv'][unet_place][elem_idx], 2
92 | )
93 | )
94 | self.output = self.single_output_clear()
95 | return result
96 |
97 | def model_patch(guider_self, model, self_attn_layers_num=None):
98 | def new_forward_info(self, place_unet):
99 | def patched_forward(
100 | hidden_states,
101 | encoder_hidden_states=None,
102 | attention_mask=None,
103 | temb=None,
104 | ):
105 | residual = hidden_states
106 |
107 | if self.spatial_norm is not None:
108 | hidden_states = self.spatial_norm(hidden_states, temb)
109 |
110 | input_ndim = hidden_states.ndim
111 |
112 | if input_ndim == 4:
113 | batch_size, channel, height, width = hidden_states.shape
114 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
115 |
116 | batch_size, sequence_length, _ = (
117 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
118 | )
119 | attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)
120 |
121 | if self.group_norm is not None:
122 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
123 |
124 | query = self.to_q(hidden_states)
125 |
126 | ## Injection
127 | is_self = encoder_hidden_states is None
128 |
129 | if encoder_hidden_states is None:
130 | encoder_hidden_states = hidden_states
131 | elif self.norm_cross:
132 | encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)
133 |
134 | key = self.to_k(encoder_hidden_states)
135 | value = self.to_v(encoder_hidden_states)
136 |
137 | query = self.head_to_batch_dim(query)
138 | key = self.head_to_batch_dim(key)
139 | value = self.head_to_batch_dim(value)
140 |
141 | attention_probs = self.get_attention_scores(query, key, attention_mask)
142 | if is_self:
143 | guider_self.output[f"{place_unet}_self"].append(attention_probs)
144 | hidden_states = torch.bmm(attention_probs, value)
145 | hidden_states = self.batch_to_head_dim(hidden_states)
146 |
147 | # linear proj
148 | hidden_states = self.to_out[0](hidden_states)
149 | # dropout
150 | hidden_states = self.to_out[1](hidden_states)
151 |
152 | if input_ndim == 4:
153 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
154 |
155 | if self.residual_connection:
156 | hidden_states = hidden_states + residual
157 |
158 | hidden_states = hidden_states / self.rescale_output_factor
159 |
160 | return hidden_states
161 | return patched_forward
162 |
163 | def register_attn(module, place_in_unet, layers_num, cur_layers_num=0):
164 | if 'Attention' in module.__class__.__name__:
165 | if 2 * layers_num[0] <= cur_layers_num < 2 * layers_num[1]:
166 | module.forward = new_forward_info(module, place_in_unet)
167 | return cur_layers_num + 1
168 | elif hasattr(module, 'children'):
169 | for module_ in module.children():
170 | cur_layers_num = register_attn(module_, place_in_unet, layers_num, cur_layers_num)
171 | return cur_layers_num
172 |
173 | sub_nets = model.unet.named_children()
174 | for name, net in sub_nets:
175 | if "down" in name:
176 | register_attn(net, "down", self_attn_layers_num[0])
177 | if "mid" in name:
178 | register_attn(net, "mid", self_attn_layers_num[1])
179 | if "up" in name:
180 | register_attn(net, "up", self_attn_layers_num[2])
181 |
182 |
183 | @opt_registry.add_to_registry('self_attn_map_l2_appearance')
184 | class SelfAttnMapL2withAppearanceEnergyGuider(BaseGuider):
185 | patched = True
186 | forward_hooks = ['cur_inv', 'inv_inv']
187 |
188 | def __init__(
189 | self, self_attn_gs: float, app_gs: float, new_features: bool=False,
190 | total_last_steps: Optional[int] = None, total_first_steps: Optional[int] = None
191 | ):
192 | super().__init__()
193 |
194 | self.new_features = new_features
195 |
196 | if total_last_steps is not None:
197 | self.app_gs = last_steps(app_gs, total_last_steps)
198 | self.self_attn_gs = last_steps(self_attn_gs, total_last_steps)
199 | elif total_first_steps is not None:
200 | self.app_gs = first_steps(app_gs, total_first_steps)
201 | self.self_attn_gs = first_steps(self_attn_gs, total_first_steps)
202 | else:
203 | self.app_gs = app_gs
204 | self.self_attn_gs = self_attn_gs
205 |
206 | def single_output_clear(self):
207 | return {
208 | "down_self": [],
209 | "mid_self": [],
210 | "up_self": [],
211 | "features": None
212 | }
213 |
214 | def calc_energy(self, data_dict):
215 | self_attn_result = 0.
216 | unet_places = ['down_self', 'up_self', 'mid_self']
217 | for unet_place in unet_places:
218 | data = data_dict['self_attn_map_l2_appearance_cur_inv'][unet_place]
219 | for elem_idx, elem in enumerate(data):
220 | self_attn_result += torch.mean(
221 | torch.pow(
222 | elem - data_dict['self_attn_map_l2_appearance_inv_inv'][unet_place][elem_idx], 2
223 | )
224 | )
225 |
226 | features_orig = data_dict['self_attn_map_l2_appearance_inv_inv']['features']
227 | features_cur = data_dict['self_attn_map_l2_appearance_cur_inv']['features']
228 | app_result = torch.mean(torch.abs(features_cur - features_orig))
229 |
230 | self.output = self.single_output_clear()
231 |
232 | if type(self.app_gs) == float:
233 | _app_gs = self.app_gs
234 | else:
235 | _app_gs = self.app_gs[data_dict['diff_iter']]
236 |
237 | if type(self.self_attn_gs) == float:
238 | _self_attn_gs = self.self_attn_gs
239 | else:
240 | _self_attn_gs = self.self_attn_gs[data_dict['diff_iter']]
241 |
242 | return _self_attn_gs * self_attn_result + _app_gs * app_result
243 |
244 | def model_patch(guider_self, model, self_attn_layers_num=None):
245 | def new_forward_info(self, place_unet):
246 | def patched_forward(
247 | hidden_states,
248 | encoder_hidden_states=None,
249 | attention_mask=None,
250 | temb=None,
251 | ):
252 | residual = hidden_states
253 |
254 | if self.spatial_norm is not None:
255 | hidden_states = self.spatial_norm(hidden_states, temb)
256 |
257 | input_ndim = hidden_states.ndim
258 |
259 | if input_ndim == 4:
260 | batch_size, channel, height, width = hidden_states.shape
261 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
262 |
263 | batch_size, sequence_length, _ = (
264 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
265 | )
266 | attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)
267 |
268 | if self.group_norm is not None:
269 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
270 |
271 | query = self.to_q(hidden_states)
272 |
273 | ## Injection
274 | is_self = encoder_hidden_states is None
275 |
276 | if encoder_hidden_states is None:
277 | encoder_hidden_states = hidden_states
278 | elif self.norm_cross:
279 | encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)
280 |
281 | key = self.to_k(encoder_hidden_states)
282 | value = self.to_v(encoder_hidden_states)
283 |
284 | query = self.head_to_batch_dim(query)
285 | key = self.head_to_batch_dim(key)
286 | value = self.head_to_batch_dim(value)
287 |
288 | attention_probs = self.get_attention_scores(query, key, attention_mask)
289 | if is_self:
290 | guider_self.output[f"{place_unet}_self"].append(attention_probs)
291 |
292 | hidden_states = torch.bmm(attention_probs, value)
293 |
294 | hidden_states = self.batch_to_head_dim(hidden_states)
295 |
296 | # linear proj
297 | hidden_states = self.to_out[0](hidden_states)
298 | # dropout
299 | hidden_states = self.to_out[1](hidden_states)
300 |
301 | if input_ndim == 4:
302 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
303 |
304 | if self.residual_connection:
305 | hidden_states = hidden_states + residual
306 |
307 | hidden_states = hidden_states / self.rescale_output_factor
308 |
309 | return hidden_states
310 | return patched_forward
311 |
312 | def register_attn(module, place_in_unet, layers_num, cur_layers_num=0):
313 | if 'Attention' in module.__class__.__name__:
314 | if 2 * layers_num[0] <= cur_layers_num < 2 * layers_num[1]:
315 | module.forward = new_forward_info(module, place_in_unet)
316 | return cur_layers_num + 1
317 | elif hasattr(module, 'children'):
318 | for module_ in module.children():
319 | cur_layers_num = register_attn(module_, place_in_unet, layers_num, cur_layers_num)
320 | return cur_layers_num
321 |
322 | sub_nets = model.unet.named_children()
323 | for name, net in sub_nets:
324 | if "down" in name:
325 | register_attn(net, "down", self_attn_layers_num[0])
326 | if "mid" in name:
327 | register_attn(net, "mid", self_attn_layers_num[1])
328 | if "up" in name:
329 | register_attn(net, "up", self_attn_layers_num[2])
330 |
331 | def hook_fn(module, input, output):
332 | guider_self.output["features"] = output
333 |
334 | if guider_self.new_features:
335 | model.unet.up_blocks[-1].register_forward_hook(hook_fn)
336 | else:
337 | model.unet.conv_norm_out.register_forward_hook(hook_fn)
338 |
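A short dispatch sketch (illustration, not repository code): guiders are looked up in opt_registry, and BaseGuider.__call__ routes to grad_fn when the guider defines one, otherwise to calc_energy. The tensors below are placeholders for real UNet outputs.

import torch
from diffusion_core.guiders.opt_guiders import opt_registry

cfg_guider = opt_registry['cfg'](is_source_guidance=False)  # defines grad_fn
l2_guider = opt_registry['features_map_l2'](block='up')     # defines calc_energy

data_dict = {
    'trg_prompt_unet': torch.randn(1, 4, 64, 64),
    'uncond_unet': torch.randn(1, 4, 64, 64),
    'features_map_l2_cur_trg': torch.randn(1, 1280, 16, 16),  # placeholder feature maps
    'features_map_l2_inv_inv': torch.randn(1, 1280, 16, 16),
}

direction = cfg_guider(data_dict)  # noise-space direction: trg_prompt_unet - uncond_unet
energy = l2_guider(data_dict)      # scalar L2 energy to be backpropagated by the caller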
--------------------------------------------------------------------------------
/diffusion_core/guiders/scale_schedulers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def first_steps(g_scale, steps):
5 | g_scales = np.ones(50)
6 | g_scales[:steps] *= g_scale
7 | g_scales[steps:] = 0.
8 | return g_scales.tolist()
9 |
10 |
11 | def last_steps(g_scale, steps):
12 | g_scales = np.ones(50)
13 | g_scales[-steps:] *= g_scale
14 | g_scales[:-steps] = 0.
15 | return g_scales.tolist()
16 |
17 |
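For reference, a tiny check (not repository code) of the two schedules above; both return 50-entry lists, matching the 50 steps hard-coded in the functions.

from diffusion_core.guiders.scale_schedulers import first_steps, last_steps

fs = first_steps(7.5, steps=20)  # 7.5 on steps 0..19, then 0.0
ls = last_steps(7.5, steps=20)   # 0.0 on steps 0..29, then 7.5
assert len(fs) == len(ls) == 50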
--------------------------------------------------------------------------------
/diffusion_core/inversion/__init__.py:
--------------------------------------------------------------------------------
1 | from .null_inversion import NullInversion, Inversion
2 | from .negativ_p_inversion import NegativePromptInversion
--------------------------------------------------------------------------------
/diffusion_core/inversion/negativ_p_inversion.py:
--------------------------------------------------------------------------------
1 | import PIL
2 | import torch
3 |
4 | from PIL import Image
5 | from typing import Optional, Union
6 |
7 | from .null_inversion import Inversion
8 |
9 |
10 | class NegativePromptInversion(Inversion):
11 | def negative_prompt_inversion(self, latents, verbose=False):
12 | uncond_embeddings, cond_embeddings = self.context.chunk(2)
13 | uncond_embeddings_list = [cond_embeddings.detach()] * self.infer_steps
14 | return uncond_embeddings_list
15 |
16 | def __call__(
17 | self,
18 | image_gt: PIL.Image.Image,
19 | prompt: Union[str, torch.Tensor],
20 | control_image: Optional[PIL.Image.Image] = None,
21 | verbose: bool = False
22 | ):
23 | image_rec, latents, _ = super().__call__(image_gt, prompt, control_image, verbose)
24 |
25 | if verbose:
26 | print("[Negative-Prompt inversion]")
27 |
28 | uncond_embeddings = self.negative_prompt_inversion(
29 | latents,
30 | verbose
31 | )
32 | return image_rec, latents, uncond_embeddings
33 |
--------------------------------------------------------------------------------
/diffusion_core/inversion/null_inversion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn.functional as nnf
4 | import PIL
5 |
6 | from typing import Optional, Union, List, Dict
7 | from tqdm.auto import tqdm, trange
8 | from torch.optim.adam import Adam
9 |
10 | from diffusion_core.diffusion_utils import latent2image, image2latent
11 | from diffusion_core.custom_forwards.unet_sd import unet_forward
12 | from diffusion_core.schedulers.opt_schedulers import opt_registry
13 |
14 |
15 | class Inversion:
16 | def __init__(
17 | self, model,
18 | inference_steps,
19 | inference_guidance_scale,
20 | forward_guidance_scale=1,
21 | verbose=False
22 | ):
23 | self.model = model
24 | self.tokenizer = self.model.tokenizer
25 | self.scheduler = model.scheduler
26 | self.scheduler.set_timesteps(inference_steps)
27 |
28 | self.prompt = None
29 | self.context = None
30 | self.controlnet_cond = None
31 |
32 | self.forward_guidance = forward_guidance_scale
33 | self.backward_guidance = inference_guidance_scale
34 | self.infer_steps = inference_steps
35 | self.half_mode = model.unet.dtype == torch.float16
36 | self.verbose = verbose
37 |
38 | @torch.no_grad()
39 | def init_controlnet_cond(self, control_image):
40 | if control_image is None:
41 | return
42 |
43 | controlnet_cond = self.model.prepare_image(
44 | control_image,
45 | 512,
46 | 512,
47 | 1 * 1,
48 | 1,
49 | self.model.controlnet.device,
50 | self.model.controlnet.dtype,
51 | )
52 |
53 | self.controlnet_cond = controlnet_cond
54 |
55 | def get_noise_pred_single(self, latents, t, context):
56 | noise_pred = unet_forward(
57 | self.model,
58 | latents,
59 | t,
60 | context,
61 | self.controlnet_cond
62 | )
63 | return noise_pred
64 |
65 | def get_noise_pred_guided(self, latents, t, guidance_scale, context=None):
66 | if context is None:
67 | context = self.context
68 |
69 | latents_input = torch.cat([latents] * 2)
70 |
71 | if self.controlnet_cond is not None:
72 | controlnet_cond = torch.cat([self.controlnet_cond] * 2)
73 | noise_pred = unet_forward(
74 | self.model,
75 | latents_input,
76 | t,
77 | encoder_hidden_states=context,
78 | controlnet_cond=controlnet_cond
79 | )
80 | else:
81 | noise_pred = self.model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
82 |
83 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
84 | noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
85 | return noise_pred
86 |
87 | @torch.no_grad()
88 | def init_prompt(self, prompt: Union[str, torch.Tensor]):
89 | uncond_input = self.model.tokenizer(
90 | [""], padding="max_length", max_length=self.model.tokenizer.model_max_length,
91 | return_tensors="pt"
92 | )
93 | uncond_embeddings = self.model.text_encoder(uncond_input.input_ids.to(self.model.device))[0]
94 | if type(prompt) == str:
95 | text_input = self.model.tokenizer(
96 | [prompt],
97 | padding="max_length",
98 | max_length=self.model.tokenizer.model_max_length,
99 | truncation=True,
100 | return_tensors="pt",
101 | )
102 | text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.model.device))[0]
103 | else:
104 | text_embeddings = prompt
105 | self.context = torch.cat([uncond_embeddings, text_embeddings])
106 | self.prompt = prompt
107 |
108 | @torch.no_grad()
109 | def loop(self, latent):
110 | uncond_embeddings, cond_embeddings = self.context.chunk(2)
111 | all_latent = [latent]
112 | latent = latent.clone().detach()
113 | for i in trange(self.infer_steps, desc='Inversion', disable=not self.verbose):
114 | t = self.scheduler.timesteps[len(self.scheduler.timesteps) - i - 1]
115 |
116 | if not np.allclose(self.forward_guidance, 1.):
117 | noise_pred = self.get_noise_pred_guided(
118 | latent, t, self.forward_guidance, self.context
119 | )
120 | else:
121 | noise_pred = self.get_noise_pred_single(
122 | latent, t, cond_embeddings
123 | )
124 |
125 | latent = self.scheduler.step_forward(noise_pred, t, latent).prev_sample
126 |
127 | all_latent.append(latent)
128 | return all_latent
129 |
130 | @torch.no_grad()
131 | def inversion(self, image):
132 | latent = image2latent(image, self.model)
133 | image_rec = latent2image(latent, self.model)
134 | latents = self.loop(latent)
135 | return image_rec, latents
136 |
137 | def __call__(
138 | self,
139 | image_gt: PIL.Image.Image,
140 | prompt: Union[str, torch.Tensor],
141 | control_image: Optional[PIL.Image.Image] = None,
142 | verbose=False
143 | ):
144 | self.init_prompt(prompt)
145 | self.init_controlnet_cond(control_image)
146 |
147 | image_gt = np.array(image_gt)
148 | image_rec, latents = self.inversion(image_gt)
149 |
150 | return image_rec, latents, None
151 |
152 |
153 | class NullInversion(Inversion):
154 | def null_optimization(self, latents, verbose=False):
155 | self.scheduler.set_timesteps(self.infer_steps)
156 |
157 | uncond_embeddings, cond_embeddings = self.context.chunk(2)
158 | uncond_embeddings_list = []
159 | latent_cur = latents[-1]
160 | bar = tqdm(total=self.opt_scheduler.max_inner_steps * self.infer_steps)
161 |
162 | for i in range(self.infer_steps):
163 | uncond_embeddings = uncond_embeddings.clone().detach()
164 | uncond_embeddings.requires_grad = True
165 | optimizer = Adam([uncond_embeddings], lr=1e-2 * (1. - i / 100.))
166 | latent_prev = latents[len(latents) - i - 2]
167 | t = self.scheduler.timesteps[i]
168 |
169 | with torch.no_grad():
170 | noise_pred_cond = self.get_noise_pred_single(latent_cur, t, cond_embeddings)
171 |
172 | for j in range(self.opt_scheduler.max_inner_steps):
173 | noise_pred_uncond = self.get_noise_pred_single(latent_cur, t, uncond_embeddings)
174 | noise_pred = noise_pred_uncond + self.backward_guidance * (noise_pred_cond - noise_pred_uncond)
175 | latents_prev_rec = self.scheduler.step_backward(noise_pred, t, latent_cur, first_time=(j == 0), last_time=False).prev_sample
176 | loss = nnf.mse_loss(latents_prev_rec, latent_prev)
177 | optimizer.zero_grad()
178 | loss.backward()
179 | optimizer.step()
180 | loss_item = loss.item()
181 |
182 | bar.update()
183 | if self.opt_scheduler(i, j, loss_item):
184 | break
185 | for j in range(j + 1, self.opt_scheduler.max_inner_steps):
186 | bar.update()
187 | uncond_embeddings_list.append(uncond_embeddings[:1].detach())
188 |
189 | with torch.no_grad():
190 | context = torch.cat([uncond_embeddings, cond_embeddings])
191 | noise_pred = self.get_noise_pred_guided(latent_cur, t, self.backward_guidance, context)
192 | latent_cur = self.scheduler.step_backward(noise_pred, t, latent_cur, first_time=False, last_time=True).prev_sample
193 |
194 | bar.close()
195 | return uncond_embeddings_list
196 |
197 | def __call__(
198 | self,
199 | image_gt: PIL.Image.Image,
200 | prompt: Union[str, torch.Tensor],
201 | opt_scheduler_name: str = 'loss',
202 | opt_num_inner_steps: int = 10,
203 | opt_early_stop_epsilon: float = 1e-5,
204 | opt_plateau_prop: float = 1/5,
205 | control_image: Optional[PIL.Image.Image] = None,
206 | verbose: bool = False
207 | ):
208 | image_rec, latents, _ = super().__call__(image_gt, prompt, control_image, verbose)
209 |
210 | if verbose:
211 | print("Null-text optimization...")
212 |
213 | self.opt_scheduler = opt_registry[opt_scheduler_name](
214 | self.infer_steps, opt_num_inner_steps,
215 | opt_early_stop_epsilon, opt_plateau_prop
216 | )
217 |
218 | uncond_embeddings = self.null_optimization(
219 | latents,
220 | verbose
221 | )
222 | return image_rec, latents, uncond_embeddings
223 |
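A call-signature sketch (not repository code). The model argument is assumed to be the project's Stable Diffusion wrapper used elsewhere in the repo (exposing unet, tokenizer, text_encoder, a scheduler with step_forward/step_backward, and device); it is not constructed here.

from PIL import Image
from diffusion_core.inversion import NullInversion, NegativePromptInversion

image = Image.open('example_images/face.png').convert('RGB')

# model: assumed Stable Diffusion wrapper (see diffusion_models.py); not built in this sketch.
null_inv = NullInversion(model, inference_steps=50, inference_guidance_scale=7.5)
image_rec, latents, uncond_embeddings = null_inv(image, prompt='a photo of a face')

npi = NegativePromptInversion(model, inference_steps=50, inference_guidance_scale=7.5)
image_rec, latents, uncond_embeddings = npi(image, prompt='a photo of a face')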
--------------------------------------------------------------------------------
/diffusion_core/schedulers/__init__.py:
--------------------------------------------------------------------------------
1 | from .sample_schedulers import DDIMScheduler
--------------------------------------------------------------------------------
/diffusion_core/schedulers/opt_schedulers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from diffusion_core.utils.class_registry import ClassRegistry
3 |
4 |
5 | opt_registry = ClassRegistry()
6 |
7 | @opt_registry.add_to_registry('constant')
8 | class OptScheduler:
9 | def __init__(self, max_ddim_steps, max_inner_steps, early_stop_epsilon, plateau_prop):
10 | self.max_inner_steps = max_inner_steps
11 | self.max_ddim_steps = max_ddim_steps
12 | self.early_stop_epsilon = early_stop_epsilon
13 | self.plateau_prop = plateau_prop
14 | self.inner_steps_list = np.full(self.max_ddim_steps, self.max_inner_steps)
15 |
16 | def __call__(self, ddim_step, inner_step, loss=None):
17 | return inner_step + 1 >= self.inner_steps_list[ddim_step]
18 |
19 |
20 | @opt_registry.add_to_registry('loss')
21 | class LossOptScheduler(OptScheduler):
22 | def __call__(self, ddim_step, inner_step, loss):
23 | if loss < self.early_stop_epsilon + ddim_step * 2e-5:
24 | return True
25 | self.inner_steps_list[ddim_step] = inner_step + 1
26 | return False
27 |
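A small sketch (not repository code) of the early-stopping rule in LossOptScheduler: the inner null-text optimization loop stops once the loss falls below early_stop_epsilon + ddim_step * 2e-5.

from diffusion_core.schedulers.opt_schedulers import opt_registry

sched = opt_registry['loss'](
    max_ddim_steps=50, max_inner_steps=10,
    early_stop_epsilon=1e-5, plateau_prop=1 / 5,
)
assert sched(ddim_step=0, inner_step=3, loss=1e-6) is True   # below threshold: stop
assert sched(ddim_step=0, inner_step=3, loss=1e-3) is False  # keep optimizing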
--------------------------------------------------------------------------------
/diffusion_core/schedulers/sample_schedulers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from typing import List, Optional, Tuple, Union
4 | from dataclasses import dataclass
5 | from diffusers.utils.outputs import BaseOutput
6 | from diffusers.schedulers.scheduling_utils import SchedulerMixin
7 | from diffusers.configuration_utils import ConfigMixin, register_to_config
8 | from diffusers.schedulers.scheduling_ddim import betas_for_alpha_bar
9 |
10 | @dataclass
11 | class SchedulerOutput(BaseOutput):
12 | prev_sample: torch.FloatTensor
13 | pred_original_sample: Optional[torch.FloatTensor] = None
14 |
15 | def rescale_zero_terminal_snr(betas):
16 | """
17 | Rescales betas to have zero terminal SNR, based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).
18 |
19 |
20 | Args:
21 | betas (`torch.FloatTensor`):
22 | the betas that the scheduler is being initialized with.
23 |
24 | Returns:
25 | `torch.FloatTensor`: rescaled betas with zero terminal SNR
26 | """
27 | # Convert betas to alphas_bar_sqrt
28 | alphas = 1.0 - betas
29 | alphas_cumprod = torch.cumprod(alphas, dim=0)
30 | alphas_bar_sqrt = alphas_cumprod.sqrt()
31 |
32 | # Store old values.
33 | alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
34 | alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
35 |
36 | # Shift so the last timestep is zero.
37 | alphas_bar_sqrt -= alphas_bar_sqrt_T
38 |
39 | # Scale so the first timestep is back to the old value.
40 | alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
41 |
42 | # Convert alphas_bar_sqrt to betas
43 | alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
44 | alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
45 | alphas = torch.cat([alphas_bar[0:1], alphas])
46 | betas = 1 - alphas
47 |
48 | return betas
49 |
50 |
51 | class SampleScheduler(SchedulerMixin, ConfigMixin):
52 | @register_to_config
53 | def __init__(
54 | self,
55 | num_train_timesteps: int = 1000,
56 | num_inference_steps: int = 50,
57 | beta_start: float = 0.0001,
58 | beta_end: float = 0.02,
59 | beta_schedule: str = 'linear',
60 | rescale_betas_zero_snr: bool = False,
61 | timestep_spacing: str = 'leading',
62 | set_alpha_to_one: bool = False,
63 | prediction_type: str = 'epsilon'
64 | ):
65 | if beta_schedule == "linear":
66 | self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
67 | elif beta_schedule == "scaled_linear":
68 | # this schedule is very specific to the latent diffusion model.
69 | self.betas = (
70 | torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
71 | )
72 | elif beta_schedule == "squaredcos_cap_v2":
73 | # Glide cosine schedule
74 | self.betas = betas_for_alpha_bar(num_train_timesteps)
75 | else:
76 | raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
77 |
78 |
79 | self.alphas = 1.0 - self.betas
80 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
81 | self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
82 |
83 | if num_inference_steps > num_train_timesteps:
84 | raise ValueError(
85 | f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.train_timesteps`:"
86 | f" {num_train_timesteps} as the unet model trained with this scheduler can only handle"
87 | f" maximal {num_train_timesteps} timesteps."
88 | )
89 |
90 | self.timestep_spacing = timestep_spacing
91 |
92 | self.num_train_timesteps = num_train_timesteps
93 | self.num_inference_steps = num_inference_steps
94 |
95 | def step_backward(
96 | self,
97 | model_output: torch.FloatTensor,
98 | timestep: int,
99 | sample: torch.FloatTensor,
100 | first_time: bool = True,
101 | last_time: bool = True,
102 | return_dict: bool = True
103 | ):
104 | raise NotImplementedError(f"step_backward is not implemented for {self.__class__}")
105 |
106 | def step_forward(
107 | self,
108 | model_output: torch.FloatTensor,
109 | timestep: int,
110 | sample: torch.FloatTensor,
111 | return_dict: bool = True
112 | ):
113 | raise NotImplementedError(f"step_forward is not implemented for {self.__class__}")
114 |
115 | def set_timesteps(
116 | self,
117 | num_inference_steps: int = None,
118 | device: Union[str, torch.device] = None
119 | ):
120 | raise NotImplementedError(f"set_timesteps is not implemented for {self.__class__}")
121 |
122 |
123 | class DDIMScheduler(SampleScheduler):
124 | @register_to_config
125 | def __init__(self, **args):
126 | super().__init__(**args)
127 | self.set_timesteps(self.num_inference_steps)
128 | self.step = self.step_backward
129 | self.init_noise_sigma = 1.
130 | self.order = 1
131 | self.scale_model_input = lambda x, t: x
132 |
133 | def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
134 | """
135 | Sets the discrete timesteps used for the diffusion chain (to be run before inference).
136 |
137 | Args:
138 | num_inference_steps (`int`):
139 | The number of diffusion steps used when generating samples with a pre-trained model.
140 | """
141 |
142 | if num_inference_steps > self.num_train_timesteps:
143 | raise ValueError(
144 | f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
145 | f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle"
146 | f" maximal {self.num_train_timesteps} timesteps."
147 | )
148 |
149 | self.num_inference_steps = num_inference_steps
150 |
151 | # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
152 | if self.timestep_spacing == "linspace":
153 | timesteps = (
154 | np.linspace(0, self.num_train_timesteps - 1, num_inference_steps)
155 | .round()[::-1]
156 | .copy()
157 | .astype(np.int64)
158 | )
159 | elif self.timestep_spacing == "leading":
160 | step_ratio = self.num_train_timesteps // self.num_inference_steps
161 | # creates integer timesteps by multiplying by ratio
162 | # casting to int to avoid issues when num_inference_step is power of 3
163 | timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
164 | # timesteps += self.steps_offset
165 | elif self.timestep_spacing == "trailing":
166 | step_ratio = self.num_train_timesteps / self.num_inference_steps
167 | # creates integer timesteps by multiplying by ratio
168 | # casting to int to avoid issues when num_inference_step is power of 3
169 | timesteps = np.round(np.arange(self.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
170 | timesteps -= 1
171 | else:
172 | raise ValueError(
173 | f"{self.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'."
174 | )
175 |
176 | self.timesteps = torch.from_numpy(timesteps).to(device)
177 |
178 | def step_backward(
179 | self,
180 | model_output: torch.FloatTensor,
181 | timestep: int,
182 | sample: torch.FloatTensor,
183 | return_dict: bool = True,
184 | **kwargs
185 | ):
186 | # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
187 | # Ideally, read the DDIM paper in detail before going further.
188 |
189 | # Notation (<variable name> -> <name in paper>)
190 | # - pred_noise_t -> e_theta(x_t, t)
191 | # - pred_original_sample -> f_theta(x_t, t) or x_0
192 | # - std_dev_t -> sigma_t
193 | # - eta -> η
194 | # - pred_sample_direction -> "direction pointing to x_t"
195 | # - pred_prev_sample -> "x_t-1"
196 |
197 | # 1. get previous step value (=t-1)
198 | prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps
199 |
200 | # 2. compute alphas, betas
201 | alpha_prod_t = self.alphas_cumprod[timestep]
202 | alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
203 | beta_prod_t = 1 - alpha_prod_t
204 |
205 | # 3. compute predicted original sample from predicted noise also called
206 | if self.config.prediction_type == "epsilon":
207 | # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
208 | pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
209 | pred_epsilon = model_output
210 | elif self.config.prediction_type == "sample":
211 | pred_original_sample = model_output
212 | pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
213 | elif self.config.prediction_type == "v_prediction":
214 | pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
215 | pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
216 | else:
217 | raise ValueError(
218 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
219 | " `v_prediction`"
220 | )
221 |
222 | if kwargs.get("ref_image", None) is not None and kwargs.get("recon_lr", 0.0) > 0.0:
223 | ref_image = kwargs.get("ref_image").expand_as(pred_original_sample)
224 | recon_lr = kwargs.get("recon_lr", 0.0)
225 | recon_mask = kwargs.get("recon_mask", None)
226 | if recon_mask is not None:
227 | recon_mask = recon_mask.expand_as(pred_original_sample).float()
228 | pred_original_sample = pred_original_sample - recon_lr * (pred_original_sample - ref_image) * recon_mask
229 | else:
230 | pred_original_sample = pred_original_sample - recon_lr * (pred_original_sample - ref_image)
231 |
232 | # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
233 | pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon
234 |
235 | # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
236 | prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
237 |
238 | if not return_dict:
239 | return (prev_sample, pred_original_sample)
240 |
241 | return SchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
242 |
243 | def step_forward(
244 | self,
245 | model_output: torch.FloatTensor,
246 | timestep: int,
247 | sample: torch.FloatTensor,
248 | return_dict: bool = True
249 | ):
250 | # 1. get previous step value (=t+1)
251 | prev_timestep = timestep + self.num_train_timesteps // self.num_inference_steps
252 |
253 | # 2. compute alphas, betas
254 | # change original implementation to exactly match noise levels for analogous forward process
255 | alpha_prod_t = self.alphas_cumprod[timestep]
256 | alpha_prod_t_prev = (
257 | self.alphas_cumprod[prev_timestep]
258 | if prev_timestep < self.num_train_timesteps
259 | else self.alphas_cumprod[-1]
260 | )
261 |
262 | beta_prod_t = 1 - alpha_prod_t
263 |
264 | # 3. compute predicted original sample from predicted noise also called
265 | # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
266 | if self.config.prediction_type == "epsilon":
267 | # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
268 | pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
269 | pred_epsilon = model_output
270 | elif self.config.prediction_type == "sample":
271 | pred_original_sample = model_output
272 | pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
273 | elif self.config.prediction_type == "v_prediction":
274 | pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
275 | pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
276 | else:
277 | raise ValueError(
278 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
279 | " `v_prediction`"
280 | )
281 |
282 | # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
283 | pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon
284 |
285 | # 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
286 | prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
287 |
288 | if not return_dict:
289 | return (prev_sample, pred_original_sample)
290 |
291 | return SchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
292 |
293 | def __repr__(self):
294 | return f"{self.__class__.__name__} {self.to_json_string()}"
295 |
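A round-trip sketch (not repository code) with random tensors: step_forward performs the DDIM inversion step towards higher noise, step_backward the usual denoising step.

import torch
from diffusion_core.schedulers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000, num_inference_steps=50)

latent = torch.randn(1, 4, 64, 64)      # placeholder latent
noise_pred = torch.randn(1, 4, 64, 64)  # placeholder UNet output

t = scheduler.timesteps[0]                             # largest timestep
back = scheduler.step_backward(noise_pred, t, latent)  # denoise: x_t -> x_{t-1}
fwd = scheduler.step_forward(noise_pred, t, latent)    # invert: x_t -> x_{t+1}
print(back.prev_sample.shape, fwd.prev_sample.shape)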
--------------------------------------------------------------------------------
/diffusion_core/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .image_utils import load_512
2 | from .class_registry import ClassRegistry
3 | from .grad_checkpoint import checkpoint_forward, use_grad_checkpointing
4 | from .model_utils import use_deterministic, toggle_grad
5 |
--------------------------------------------------------------------------------
/diffusion_core/utils/class_registry.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import typing
3 | import omegaconf
4 | import dataclasses
5 | import typing as tp
6 |
7 |
8 | class ClassRegistry:
9 | def __init__(self):
10 | self.classes = dict()
11 | self.args = dict()
12 | self.arg_keys = None
13 |
14 | def __getitem__(self, item):
15 | return self.classes[item]
16 |
17 | def make_dataclass_from_func(self, func, name, arg_keys):
18 | args = inspect.signature(func).parameters
19 | args = [
20 | (k, typing.Any, omegaconf.MISSING)
21 | if v.default is inspect.Parameter.empty
22 | else (k, typing.Optional[typing.Any], None)
23 | if v.default is None
24 | else (
25 | k,
26 | type(v.default),
27 | dataclasses.field(default=v.default),
28 | )
29 | for k, v in args.items()
30 | ]
31 | args = [
32 | arg
33 | for arg in args
34 | if (arg[0] != "self" and arg[0] != "args" and arg[0] != "kwargs")
35 | ]
36 | if arg_keys:
37 | self.arg_keys = arg_keys
38 | arg_classes = dict()
39 | for key in arg_keys:
40 | arg_classes[key] = dataclasses.make_dataclass(key, args)
41 | return dataclasses.make_dataclass(
42 | name,
43 | [
44 | (k, v, dataclasses.field(default=v()))
45 | for k, v in arg_classes.items()
46 | ],
47 | )
48 | return dataclasses.make_dataclass(name, args)
49 |
50 | def make_dataclass_from_classes(self):
51 | return dataclasses.make_dataclass(
52 | 'Name',
53 | [
54 | (k, v, dataclasses.field(default=v()))
55 | for k, v in self.classes.items()
56 | ],
57 | )
58 |
59 | def make_dataclass_from_args(self):
60 | return dataclasses.make_dataclass(
61 | 'Name',
62 | [
63 | (k, v, dataclasses.field(default=v()))
64 | for k, v in self.args.items()
65 | ],
66 | )
67 |
68 | def _add_single_obj(self, obj, name, arg_keys):
69 | self.classes[name] = obj
70 | if inspect.isfunction(obj):
71 | self.args[name] = self.make_dataclass_from_func(
72 | obj, name, arg_keys
73 | )
74 | elif inspect.isclass(obj):
75 | self.args[name] = self.make_dataclass_from_func(
76 | obj.__init__, name, arg_keys
77 | )
78 |
79 | def add_to_registry(self, names: tp.Union[str, tp.List[str]], arg_keys=None):
80 | if not isinstance(names, list):
81 | names = [names]
82 |
83 | def decorator(obj):
84 | for name in names:
85 | self._add_single_obj(obj, name, arg_keys)
86 |
87 | return obj
88 | return decorator
89 |
90 | def __contains__(self, name: str):
91 | return name in self.args.keys()
92 |
93 | def __repr__(self):
94 | return f"{list(self.args.keys())}"
95 |
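A minimal sketch (not repository code) of how ClassRegistry is used throughout the project: register a class under a name, look it up, and build a config dataclass from its __init__ signature.

from diffusion_core.utils.class_registry import ClassRegistry

models = ClassRegistry()

@models.add_to_registry('toy')
class ToyModel:
    def __init__(self, width: int = 64, depth: int = 2):
        self.width, self.depth = width, depth

assert 'toy' in models          # __contains__ checks registered names
ToyConfig = models.args['toy']  # dataclass with fields width=64, depth=2
model = models['toy'](**vars(ToyConfig()))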
--------------------------------------------------------------------------------
/diffusion_core/utils/grad_checkpoint.py:
--------------------------------------------------------------------------------
1 | import functools
2 | from inspect import signature
3 | from torch.utils import checkpoint
4 |
5 | use_mode: bool = True
6 |
7 |
8 | def use_grad_checkpointing(mode: bool=True):
9 | global use_mode
10 | use_mode = mode
11 |
12 |
13 | def checkpoint_forward(func):
14 | sig = signature(func)
15 |
16 | @functools.wraps(func)
17 | def wrapper(*args, **kwargs):
18 | if use_mode:
19 | bound = sig.bind(*args, **kwargs)
20 | bound.apply_defaults()
21 | new_args = bound.arguments.values()
22 | result = checkpoint.checkpoint(func, *new_args, use_reentrant=False)
23 | else:
24 | result = func(*args, **kwargs)
25 | return result
26 |
27 | return wrapper
28 |
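A usage sketch (not repository code): use_grad_checkpointing toggles the module-level flag, and checkpoint_forward reroutes the decorated call through torch.utils.checkpoint so activations are recomputed on the backward pass.

import torch
from diffusion_core.utils.grad_checkpoint import checkpoint_forward, use_grad_checkpointing

linear = torch.nn.Linear(16, 16)

@checkpoint_forward
def heavy_forward(x):
    return linear(x).relu().sum()

use_grad_checkpointing(True)   # recompute activations during backward
loss = heavy_forward(torch.randn(4, 16, requires_grad=True))
loss.backward()

use_grad_checkpointing(False)  # back to plain eager execution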
--------------------------------------------------------------------------------
/diffusion_core/utils/image_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from PIL import Image
4 | from torchvision import transforms
5 | from torchvision.utils import make_grid
6 |
7 | def load_512(image_path, left=0, right=0, top=0, bottom=0):
8 | if type(image_path) is str:
9 | image = np.array(Image.open(image_path).convert('RGB'))[:, :, :3]
10 | else:
11 | image = image_path
12 | h, w, c = image.shape
13 | left = min(left, w-1)
14 | right = min(right, w - left - 1)
15 | top = min(top, h - 1)
16 | bottom = min(bottom, h - top - 1)
17 | image = image[top:h-bottom, left:w-right]
18 | h, w, c = image.shape
19 | if h < w:
20 | offset = (w - h) // 2
21 | image = image[:, offset:offset + h]
22 | elif w < h:
23 | offset = (h - w) // 2
24 | image = image[offset:offset + w]
25 | image = np.array(Image.fromarray(image).resize((512, 512)))
26 | return image
27 |
28 |
29 | def to_im(torch_image, **kwargs):
30 | return transforms.ToPILImage()(
31 | make_grid(torch_image, **kwargs)
32 | )
33 |
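A quick sketch (not repository code) of load_512: apply the optional crop offsets, center-crop to a square, then resize to 512x512; the result is an HWC uint8 numpy array.

from diffusion_core.utils.image_utils import load_512

image = load_512('example_images/zebra.jpeg')
assert image.shape == (512, 512, 3)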
--------------------------------------------------------------------------------
/diffusion_core/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def use_deterministic():
5 | torch.backends.cudnn.benchmark = False
6 | torch.backends.cudnn.deterministic = True
7 |
8 |
9 | def toggle_grad(model, mode=True):
10 | for p in model.parameters():
11 | p.requires_grad = mode
12 |
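A usage sketch (not repository code): use_deterministic pins cuDNN to deterministic kernels, and toggle_grad freezes or unfreezes all parameters of a module.

import torch
from diffusion_core.utils.model_utils import use_deterministic, toggle_grad

use_deterministic()

net = torch.nn.Linear(8, 8)
toggle_grad(net, False)  # freeze
assert not any(p.requires_grad for p in net.parameters())
toggle_grad(net, True)   # unfreeze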
--------------------------------------------------------------------------------
/docs/diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FusionBrainLab/Guide-and-Rescale/4a79013c641d20f57b206ad7f0bf2e9d01cd412c/docs/diagram.png
--------------------------------------------------------------------------------
/docs/teaser_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FusionBrainLab/Guide-and-Rescale/4a79013c641d20f57b206ad7f0bf2e9d01cd412c/docs/teaser_image.png
--------------------------------------------------------------------------------
/example_images/face.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FusionBrainLab/Guide-and-Rescale/4a79013c641d20f57b206ad7f0bf2e9d01cd412c/example_images/face.png
--------------------------------------------------------------------------------
/example_images/zebra.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FusionBrainLab/Guide-and-Rescale/4a79013c641d20f57b206ad7f0bf2e9d01cd412c/example_images/zebra.jpeg
--------------------------------------------------------------------------------
/sd_env.yaml:
--------------------------------------------------------------------------------
1 | name: ldm
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - python=3.8.5
7 | - pip=20.3
8 | - cudatoolkit=11.3
9 | - numpy=1.19.2
10 | - pip:
11 | - torch==2.3.0
12 | - torchvision==0.18.0
13 | - image-reward
14 | - albumentations==0.4.3
15 | - diffusers==0.17.1
16 | - matplotlib
17 | - opencv-python==4.1.2.30
18 | - pudb==2019.2
19 | - invisible-watermark
20 | - imageio==2.9.0
21 | - imageio-ffmpeg==0.4.2
22 | - pytorch-lightning
23 | - omegaconf==2.1.1
24 | - test-tube>=0.7.5
25 | - streamlit>=0.73.1
26 | - einops==0.3.0
27 | - torch-fidelity==0.3.0
28 | - transformers==4.29.0
29 | - torchmetrics==0.6.0
30 | - kornia==0.6
31 | - accelerate
32 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
33 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip
34 | - ipykernel
35 | - ipywidgets
--------------------------------------------------------------------------------