├── LICENSE ├── README.md ├── configs ├── ours_nonstyle_best.yaml ├── ours_nonstyle_best_colab.yaml ├── ours_style_best.yaml └── ours_style_best_colab.yaml ├── diffusion_core ├── __init__.py ├── configuration_utils.py ├── custom_forwards │ ├── __init__.py │ └── unet_sd.py ├── diffusion_models.py ├── diffusion_schedulers.py ├── diffusion_utils.py ├── guiders │ ├── __init__.py │ ├── guidance_editing.py │ ├── noise_rescales.py │ ├── opt_guiders.py │ └── scale_schedulers.py ├── inversion │ ├── __init__.py │ ├── negativ_p_inversion.py │ └── null_inversion.py ├── schedulers │ ├── __init__.py │ ├── opt_schedulers.py │ └── sample_schedulers.py └── utils │ ├── __init__.py │ ├── class_registry.py │ ├── grad_checkpoint.py │ ├── image_utils.py │ └── model_utils.py ├── docs ├── diagram.png └── teaser_image.png ├── example_images ├── face.png └── zebra.jpeg ├── example_notebooks └── guide_and_rescale.ipynb └── sd_env.yaml /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Guide-and-Rescale: Self-Guidance Mechanism for Effective Tuning-Free Real Image Editing (ECCV 2024) 2 | 3 | 4 | 5 | 6 | [![License](https://img.shields.io/github/license/AIRI-Institute/al_toolbox)](./LICENSE) 7 | 8 | >Despite recent advances in large-scale text-to-image generative models, manipulating real images with these models remains a challenging problem. 
The main limitations of existing editing methods are that they either fail to perform with consistent quality on a wide range of image edits, or require time-consuming hyperparameter tuning or fine-tuning of the diffusion model to preserve the image-specific appearance of the input image. Most of these approaches utilize source image information by caching intermediate features and inserting them into the generation process as is. However, such a technique produces feature misalignment in the model, which leads to inconsistent results. 9 | We propose a novel approach that is built upon a modified diffusion sampling process via a guidance mechanism. In this work, we explore a self-guidance technique to preserve the overall structure of the input image and the appearance of its local regions that should not be edited. In particular, we explicitly introduce layout-preserving energy functions that aim to preserve the local and global structure of the source image. Additionally, we propose a noise rescaling mechanism that preserves the noise distribution by balancing the norms of classifier-free guidance and our proposed guiders during generation. This leads to more consistent and better editing results. Such a guiding approach requires neither fine-tuning of the diffusion model nor an exact inversion process. As a result, the proposed method provides a fast and high-quality editing mechanism. 10 | In our experiments, we show through human evaluation and quantitative analysis that the proposed method produces the desired edits, is preferred by human raters, and achieves a better trade-off between editing quality and preservation of the original image. 11 | > 12 | 13 | ![image](docs/teaser_image.png) 14 | 15 | ## Setup 16 | 17 | This code uses a pre-trained [Stable Diffusion](https://huggingface.co/docs/diffusers/v0.25.1/en/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline) model from the [Diffusers](https://github.com/huggingface/diffusers#readme) library. We ran our code with Python 3.8.5, PyTorch 2.3.0, and Diffusers 0.17.1 on an NVIDIA A100 GPU with 40GB RAM. 18 | 19 | To set up the environment, run: 20 | ``` 21 | conda env create -f sd_env.yaml 22 | ``` 23 | A conda environment named `ldm` will be created for you to use. 24 | 25 | 26 | ## Quickstart 27 | 28 | We provide examples of applying our pipeline to real image editing in Colab. 29 | 30 | You can also try the Gradio demo in HF Spaces. 31 | 32 | We also provide [a jupyter notebook](example_notebooks/guide_and_rescale.ipynb) to try the Guide-and-Rescale pipeline on your own server; a short usage sketch is also given below, after the method diagram. 33 | 34 | ## Method Diagram 35 |

36 | ![Diagram](docs/diagram.png) 37 | 38 | 39 | 40 |
41 | Overall scheme of the proposed method Guide-and-Rescale. First, our method applies a classic DDIM inversion to the source real image. Then the method performs real image editing via the classical denoising process. At every denoising step, the noise term is modified by a guider that utilizes the latents $z_t$ from the current generation process and the time-aligned DDIM latents $z^*_t$.
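In code, an end-to-end editing run built from the pieces in this repository looks roughly as follows. This is only a sketch, not the authoritative usage: it assumes a CUDA GPU, a 512x512 source image, and that the YAML configs are loaded with OmegaConf (not stated in this README); the prompts and output file name are placeholders. See [the notebook](example_notebooks/guide_and_rescale.ipynb) for the complete example.

```
# Sketch of an editing run (assumptions: OmegaConf for configs, a CUDA device,
# a 512x512 source image; the prompts below are placeholders).
from omegaconf import OmegaConf
from PIL import Image

from diffusion_core import diffusion_models_registry, diffusion_schedulers_registry
from diffusion_core.guiders import GuidanceEditing

config = OmegaConf.load('configs/ours_nonstyle_best.yaml')

# Build the DDIM scheduler and the Stable Diffusion pipeline named in the config.
scheduler = diffusion_schedulers_registry[config.scheduler_type]()
model = diffusion_models_registry[config.model_name](scheduler).to('cuda')

editor = GuidanceEditing(model, config)

# One prompt describes the source image, the other describes the desired edit.
source = Image.open('example_images/face.png').convert('RGB').resize((512, 512))
edited = editor(source, inv_prompt='a photo of a person', trg_prompt='a photo of a smiling person')

# The editor returns a uint8 numpy array of shape (H, W, 3).
Image.fromarray(edited).save('edited.png')
```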

43 | 44 | 45 | ## Guiders 46 | 47 | In our work we propose specific **guiders**, i.e. guidance signals suitable for editing. The code for these guiders can be found in [diffusion_core/guiders/opt_guiders.py](diffusion_core/guiders/opt_guiders.py). 48 | 49 | Every guider is defined as a separate class that inherits from the parent class `BaseGuider`. A template for defining a new guider class looks as follows: 50 | 51 | ``` 52 | class SomeGuider(BaseGuider): 53 | patched: bool 54 | forward_hooks: list 55 | 56 | def [grad_fn or calc_energy](self, data_dict): 57 | ... 58 | 59 | def model_patch(self, model): 60 | ... 61 | 62 | def single_output_clear(self): 63 | ... 64 | ``` 65 | 66 | ### grad_fn or calc_energy 67 | 68 | The `BaseGuider` class contains a property `grad_guider`. This property is `True` when the guider does not require any backpropagation over its outputs to obtain the gradient w.r.t. the current latent (for example, as in classifier-free guidance). In this case, the child class contains a function `grad_fn`, where the gradient w.r.t. the current latent is computed in closed form. 69 | 70 | When the gradient has to be estimated with backpropagation and `grad_guider` is `False` (for example, when using the norm of the difference of attention maps for guidance), the child class contains a function `calc_energy`, where the desired energy function output is calculated. This output is further used for backpropagation. 71 | 72 | The `grad_fn` and `calc_energy` functions receive a dictionary (`data_dict`) as input. In this dictionary we store all objects (the diffusion model instance, prompts, current latent, etc.) that might be useful for the guiders in the current pipeline. 73 | 74 | ### model_patch and patched 75 | 76 | When the guider requires outputs of intermediate layers of the diffusion model to estimate the energy function/gradient, we define a function `model_patch` in this guider's class and set the property `patched` equal to `True`. We will further refer to such guiders as *patched guiders*. 77 | 78 | This function patches the desired layers of the diffusion model and retrieves the necessary output from these layers. This output is then stored in the `output` property of the guider class object. This way it can be accessed by the editing pipeline and stored in `data_dict` for further use in the `calc_energy`/`grad_fn` functions. 79 | 80 | ### forward_hooks 81 | 82 | In the editing pipeline we conduct 4 diffusion model forward passes: 83 | 84 | - unconditional, from the current latent $z_t$ 85 | - `cur_inv`: conditional on the initial prompt, from the current latent $z_t$ 86 | - `inv_inv`: conditional on the initial prompt, from the corresponding inversion latent $z^*_t$ 87 | - `cur_trg`: conditional on the prompt describing the editing result, from the current latent $z_t$ 88 | 89 | We store the unconditional prediction in `data_dict`, as well as the outputs of the `cur_inv` and `cur_trg` forward passes, for further use in classifier-free guidance. 90 | 91 | However, when the guider is patched, we also have its `output` to store in `data_dict`. In the `forward_hooks` property of the guider class we define the list of forward passes (from the set `cur_inv`, `inv_inv`, `cur_trg`) for which we need to store the `output`. 92 | 93 | After a specific forward pass is conducted, we can access the `output` of the guider and store it in `data_dict` if the forward pass is listed in `forward_hooks`. We store it with a key that specifies the current forward pass, as in the excerpt below.
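For reference, this is the relevant excerpt from `GuidanceEditing._construct_data_dict` in [diffusion_core/guiders/guidance_editing.py](diffusion_core/guiders/guidance_editing.py), lightly abridged (the exact indentation of the cleanup call may differ in the source):

```
# After the forward pass conditioned on the initial prompt from the current
# latent (`cur_inv`), each patched guider's captured output is stored under a
# key combining the guider name and the pass name, e.g. "self_attn_map_l2_cur_inv".
for g_name, (guider, _) in self.guiders.items():
    if hasattr(guider, 'model_patch'):
        if 'cur_inv' in guider.forward_hooks:
            data_dict.update({f"{g_name}_cur_inv": guider.output})
        guider.clear_outputs()
```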
94 | 95 | This way we can avoid storing unnecessary `output`s in `data_dict`, as well as distinguish `output`s from different forward passes by their keys. 96 | 97 | 98 | ### single_output_clear 99 | 100 | This is only relevant for patched guiders. 101 | 102 | When the data from the `output` property of the guider class object is stored in `data_dict`, we need to empty the `output` to avoid exceeding the memory limit. For this purpose we define a `single_output_clear` function. It returns an empty `output`, for example `None`, or an empty list `[]`. 103 | 104 | ## References & Acknowledgments 105 | 106 | The repository was started from [Prompt-to-Prompt](https://github.com/google/prompt-to-prompt/). 107 | 108 | ## Citation 109 | 110 | If you use this code for your research, please cite our paper: 111 | ``` 112 | @article{titov2024guideandrescale, 113 | title={Guide-and-Rescale: Self-Guidance Mechanism for Effective Tuning-Free Real Image Editing}, 114 | author={Vadim Titov and Madina Khalmatova and Alexandra Ivanova and Dmitry Vetrov and Aibek Alanov}, 115 | journal={arXiv preprint arXiv:2409.01322}, 116 | year={2024} 117 | } 118 | ``` 119 | -------------------------------------------------------------------------------- /configs/ours_nonstyle_best.yaml: -------------------------------------------------------------------------------- 1 | scheduler_type: ddim_50_eps 2 | inversion_type: dummy 3 | model_name: stable-diffusion-v1-4 4 | pipeline_type: ours 5 | start_latent: inversion 6 | verbose: false 7 | guiders: 8 | - name: cfg 9 | g_scale: 7.5 10 | kwargs: 11 | is_source_guidance: false 12 | - name: self_attn_map_l2_appearance 13 | g_scale: 1. 14 | kwargs: 15 | self_attn_gs: 300000. 16 | app_gs: 500. 17 | new_features: true 18 | total_first_steps: 30 19 | noise_rescaling_setup: 20 | type: range_other_on_cfg_norm 21 | init_setup: 22 | - 0.33 23 | - 3.0 24 | edit_types: 25 | - animal-2-animal 26 | - face-in-the-wild 27 | - person-in-the-wild -------------------------------------------------------------------------------- /configs/ours_nonstyle_best_colab.yaml: -------------------------------------------------------------------------------- 1 | scheduler_type: ddim_50_eps 2 | inversion_type: dummy 3 | model_name: stable-diffusion-v1-4 4 | pipeline_type: ours 5 | start_latent: inversion 6 | verbose: false 7 | guiders: 8 | - name: cfg 9 | g_scale: 7.5 10 | kwargs: 11 | is_source_guidance: false 12 | - name: self_attn_map_l2_appearance 13 | g_scale: 1. 14 | kwargs: 15 | self_attn_gs: 300000. 16 | app_gs: 500.
17 | new_features: true 18 | total_first_steps: 30 19 | noise_rescaling_setup: 20 | type: range_other_on_cfg_norm 21 | init_setup: 22 | - 0.33 23 | - 3.0 24 | edit_types: 25 | - animal-2-animal 26 | - face-in-the-wild 27 | - person-in-the-wild 28 | gradient_checkpointing: true 29 | self_attn_layers_num: 30 | - [2, 6] 31 | - [0, 1] 32 | - [0, 4] 33 | -------------------------------------------------------------------------------- /configs/ours_style_best.yaml: -------------------------------------------------------------------------------- 1 | scheduler_type: ddim_50_eps 2 | inversion_type: dummy 3 | model_name: stable-diffusion-v1-4 4 | pipeline_type: ours 5 | start_latent: inversion 6 | verbose: false 7 | edit_types: stylisation 8 | guiders: 9 | - name: cfg 10 | g_scale: 7.5 11 | kwargs: 12 | is_source_guidance: false 13 | - name: self_attn_map_l2 14 | g_scale: 15 | - 100000.0 16 | - 100000.0 17 | - 100000.0 18 | - 100000.0 19 | - 100000.0 20 | - 100000.0 21 | - 100000.0 22 | - 100000.0 23 | - 100000.0 24 | - 100000.0 25 | - 100000.0 26 | - 100000.0 27 | - 100000.0 28 | - 100000.0 29 | - 100000.0 30 | - 100000.0 31 | - 100000.0 32 | - 100000.0 33 | - 100000.0 34 | - 100000.0 35 | - 100000.0 36 | - 100000.0 37 | - 100000.0 38 | - 100000.0 39 | - 100000.0 40 | - 0.0 41 | - 0.0 42 | - 0.0 43 | - 0.0 44 | - 0.0 45 | - 0.0 46 | - 0.0 47 | - 0.0 48 | - 0.0 49 | - 0.0 50 | - 0.0 51 | - 0.0 52 | - 0.0 53 | - 0.0 54 | - 0.0 55 | - 0.0 56 | - 0.0 57 | - 0.0 58 | - 0.0 59 | - 0.0 60 | - 0.0 61 | - 0.0 62 | - 0.0 63 | - 0.0 64 | - 0.0 65 | kwargs: {} 66 | - name: features_map_l2 67 | g_scale: 68 | - 2.5 69 | - 2.5 70 | - 2.5 71 | - 2.5 72 | - 2.5 73 | - 2.5 74 | - 2.5 75 | - 2.5 76 | - 2.5 77 | - 2.5 78 | - 2.5 79 | - 2.5 80 | - 2.5 81 | - 2.5 82 | - 2.5 83 | - 2.5 84 | - 2.5 85 | - 2.5 86 | - 2.5 87 | - 2.5 88 | - 2.5 89 | - 2.5 90 | - 2.5 91 | - 2.5 92 | - 2.5 93 | - 0.0 94 | - 0.0 95 | - 0.0 96 | - 0.0 97 | - 0.0 98 | - 0.0 99 | - 0.0 100 | - 0.0 101 | - 0.0 102 | - 0.0 103 | - 0.0 104 | - 0.0 105 | - 0.0 106 | - 0.0 107 | - 0.0 108 | - 0.0 109 | - 0.0 110 | - 0.0 111 | - 0.0 112 | - 0.0 113 | - 0.0 114 | - 0.0 115 | - 0.0 116 | - 0.0 117 | - 0.0 118 | kwargs: {} 119 | noise_rescaling_setup: 120 | type: range_other_on_cfg_norm 121 | init_setup: 122 | - 1.5 123 | - 1.5 124 | -------------------------------------------------------------------------------- /configs/ours_style_best_colab.yaml: -------------------------------------------------------------------------------- 1 | scheduler_type: ddim_50_eps 2 | inversion_type: dummy 3 | model_name: stable-diffusion-v1-4 4 | pipeline_type: ours 5 | start_latent: inversion 6 | verbose: false 7 | edit_types: stylisation 8 | guiders: 9 | - name: cfg 10 | g_scale: 7.5 11 | kwargs: 12 | is_source_guidance: false 13 | - name: self_attn_map_l2 14 | g_scale: 15 | - 100000.0 16 | - 100000.0 17 | - 100000.0 18 | - 100000.0 19 | - 100000.0 20 | - 100000.0 21 | - 100000.0 22 | - 100000.0 23 | - 100000.0 24 | - 100000.0 25 | - 100000.0 26 | - 100000.0 27 | - 100000.0 28 | - 100000.0 29 | - 100000.0 30 | - 100000.0 31 | - 100000.0 32 | - 100000.0 33 | - 100000.0 34 | - 100000.0 35 | - 100000.0 36 | - 100000.0 37 | - 100000.0 38 | - 100000.0 39 | - 100000.0 40 | - 0.0 41 | - 0.0 42 | - 0.0 43 | - 0.0 44 | - 0.0 45 | - 0.0 46 | - 0.0 47 | - 0.0 48 | - 0.0 49 | - 0.0 50 | - 0.0 51 | - 0.0 52 | - 0.0 53 | - 0.0 54 | - 0.0 55 | - 0.0 56 | - 0.0 57 | - 0.0 58 | - 0.0 59 | - 0.0 60 | - 0.0 61 | - 0.0 62 | - 0.0 63 | - 0.0 64 | - 0.0 65 | kwargs: {} 66 | - name: 
features_map_l2 67 | g_scale: 68 | - 2.5 69 | - 2.5 70 | - 2.5 71 | - 2.5 72 | - 2.5 73 | - 2.5 74 | - 2.5 75 | - 2.5 76 | - 2.5 77 | - 2.5 78 | - 2.5 79 | - 2.5 80 | - 2.5 81 | - 2.5 82 | - 2.5 83 | - 2.5 84 | - 2.5 85 | - 2.5 86 | - 2.5 87 | - 2.5 88 | - 2.5 89 | - 2.5 90 | - 2.5 91 | - 2.5 92 | - 2.5 93 | - 0.0 94 | - 0.0 95 | - 0.0 96 | - 0.0 97 | - 0.0 98 | - 0.0 99 | - 0.0 100 | - 0.0 101 | - 0.0 102 | - 0.0 103 | - 0.0 104 | - 0.0 105 | - 0.0 106 | - 0.0 107 | - 0.0 108 | - 0.0 109 | - 0.0 110 | - 0.0 111 | - 0.0 112 | - 0.0 113 | - 0.0 114 | - 0.0 115 | - 0.0 116 | - 0.0 117 | - 0.0 118 | kwargs: {} 119 | noise_rescaling_setup: 120 | type: range_other_on_cfg_norm 121 | init_setup: 122 | - 1.5 123 | - 1.5 124 | gradient_checkpointing: true 125 | self_attn_layers_num: 126 | - [2, 6] 127 | - [0, 1] 128 | - [0, 4] 129 | -------------------------------------------------------------------------------- /diffusion_core/__init__.py: -------------------------------------------------------------------------------- 1 | from .diffusion_models import diffusion_models_registry 2 | from .diffusion_schedulers import diffusion_schedulers_registry 3 | -------------------------------------------------------------------------------- /diffusion_core/configuration_utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import inspect 3 | 4 | 5 | class MethodStorage(dict): 6 | @staticmethod 7 | def register(method_name): 8 | def decorator(method): 9 | method._method_name = method_name 10 | method._is_decorated = True 11 | return method 12 | return decorator 13 | 14 | def register_methods(self): 15 | self.registered_methods = {} 16 | for method_name, method in inspect.getmembers(self, predicate=inspect.ismethod): 17 | print(method_name) 18 | if getattr(method, '_is_decorated', False): 19 | self.registered_methods[method._method_name] = method 20 | 21 | def __getitem__(self, method_name): 22 | return self.registered_methods[method_name] 23 | -------------------------------------------------------------------------------- /diffusion_core/custom_forwards/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/Guide-and-Rescale/4a79013c641d20f57b206ad7f0bf2e9d01cd412c/diffusion_core/custom_forwards/__init__.py -------------------------------------------------------------------------------- /diffusion_core/custom_forwards/unet_sd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from typing import Optional, Tuple, Union, Dict, Any 4 | 5 | from diffusion_core.utils import checkpoint_forward 6 | 7 | 8 | @checkpoint_forward 9 | def unet_down_forward(downsample_block, sample, emb, encoder_hidden_states, attention_mask, cross_attention_kwargs): 10 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 11 | sample, res_samples = downsample_block( 12 | hidden_states=sample, 13 | temb=emb, 14 | encoder_hidden_states=encoder_hidden_states, 15 | attention_mask=attention_mask, 16 | cross_attention_kwargs=cross_attention_kwargs, 17 | ) 18 | else: 19 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 20 | return sample, res_samples 21 | 22 | 23 | @checkpoint_forward 24 | def unet_mid_forward(mid_block, sample, emb, encoder_hidden_states, attention_mask, cross_attention_kwargs): 25 | sample = mid_block( 26 | sample, 27 | emb, 28 | 
encoder_hidden_states=encoder_hidden_states, 29 | attention_mask=attention_mask, 30 | cross_attention_kwargs=cross_attention_kwargs, 31 | ) 32 | return sample 33 | 34 | 35 | @checkpoint_forward 36 | def unet_up_forward(upsample_block, sample, emb, res_samples, encoder_hidden_states, cross_attention_kwargs, 37 | upsample_size, attention_mask): 38 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 39 | sample = upsample_block( 40 | hidden_states=sample, 41 | temb=emb, 42 | res_hidden_states_tuple=res_samples, 43 | encoder_hidden_states=encoder_hidden_states, 44 | cross_attention_kwargs=cross_attention_kwargs, 45 | upsample_size=upsample_size, 46 | attention_mask=attention_mask, 47 | ) 48 | else: 49 | sample = upsample_block( 50 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size 51 | ) 52 | return sample 53 | 54 | 55 | def unet_forward( 56 | model, 57 | sample: torch.FloatTensor, 58 | timestep: Union[torch.Tensor, float, int], 59 | encoder_hidden_states: torch.Tensor, 60 | controlnet_cond=None, 61 | controlnet_conditioning_scale=1., 62 | class_labels: Optional[torch.Tensor] = None, 63 | timestep_cond: Optional[torch.Tensor] = None, 64 | attention_mask: Optional[torch.Tensor] = None, 65 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 66 | return_dict: bool = True, 67 | ): 68 | if controlnet_cond is not None: 69 | down_block_additional_residuals, mid_block_additional_residual = model.controlnet( 70 | sample, 71 | timestep, 72 | encoder_hidden_states=encoder_hidden_states, 73 | controlnet_cond=controlnet_cond, 74 | return_dict=False, 75 | ) 76 | 77 | down_block_additional_residuals = [ 78 | down_block_res_sample * controlnet_conditioning_scale 79 | for down_block_res_sample in down_block_additional_residuals 80 | ] 81 | mid_block_additional_residual *= controlnet_conditioning_scale 82 | else: 83 | down_block_additional_residuals = None 84 | mid_block_additional_residual = None 85 | 86 | default_overall_up_factor = 2 ** model.unet.num_upsamplers 87 | 88 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 89 | forward_upsample_size = False 90 | upsample_size = None 91 | 92 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 93 | logger.info("Forward upsample size to force interpolation output size.") 94 | forward_upsample_size = True 95 | 96 | # prepare attention_mask 97 | if attention_mask is not None: 98 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 99 | attention_mask = attention_mask.unsqueeze(1) 100 | 101 | # 0. center input if necessary 102 | if model.unet.config.center_input_sample: 103 | sample = 2 * sample - 1.0 104 | 105 | # 1. time 106 | timesteps = timestep 107 | if not torch.is_tensor(timesteps): 108 | # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can 109 | # This would be a good case for the `match` statement (Python 3.10+) 110 | is_mps = sample.device.type == "mps" 111 | if isinstance(timestep, float): 112 | dtype = torch.float32 if is_mps else torch.float64 113 | else: 114 | dtype = torch.int32 if is_mps else torch.int64 115 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 116 | elif len(timesteps.shape) == 0: 117 | timesteps = timesteps[None].to(sample.device) 118 | 119 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 120 | timesteps = timesteps.expand(sample.shape[0]) 121 | 122 | t_emb = model.unet.time_proj(timesteps) 123 | 124 | # timesteps does not contain any weights and will always return f32 tensors 125 | # but time_embedding might actually be running in fp16. so we need to cast here. 126 | # there might be better ways to encapsulate this. 127 | t_emb = t_emb.to(dtype=model.unet.dtype) 128 | 129 | emb = model.unet.time_embedding(t_emb, timestep_cond) 130 | 131 | if model.unet.class_embedding is not None: 132 | if class_labels is None: 133 | raise ValueError("class_labels should be provided when num_class_embeds > 0") 134 | 135 | if model.unet.config.class_embed_type == "timestep": 136 | class_labels = model.unet.time_proj(class_labels) 137 | 138 | class_emb = model.unet.class_embedding(class_labels).to(dtype=model.unet.dtype) 139 | emb = emb + class_emb 140 | 141 | # 2. pre-process 142 | sample = model.unet.conv_in(sample) 143 | 144 | # 3. down 145 | down_block_res_samples = (sample,) 146 | for downsample_block in model.unet.down_blocks: 147 | sample, res_samples = unet_down_forward(downsample_block, sample, emb, encoder_hidden_states, attention_mask, 148 | cross_attention_kwargs) 149 | 150 | down_block_res_samples += res_samples 151 | 152 | if down_block_additional_residuals is not None: 153 | new_down_block_res_samples = () 154 | 155 | for down_block_res_sample, down_block_additional_residual in zip( 156 | down_block_res_samples, down_block_additional_residuals 157 | ): 158 | down_block_res_sample = down_block_res_sample + down_block_additional_residual 159 | new_down_block_res_samples += (down_block_res_sample,) 160 | 161 | down_block_res_samples = new_down_block_res_samples 162 | 163 | # 4. mid 164 | if model.unet.mid_block is not None: 165 | sample = unet_mid_forward(model.unet.mid_block, sample, emb, encoder_hidden_states, attention_mask, 166 | cross_attention_kwargs) 167 | 168 | if mid_block_additional_residual is not None: 169 | sample = sample + mid_block_additional_residual 170 | 171 | # 5. up 172 | for i, upsample_block in enumerate(model.unet.up_blocks): 173 | is_final_block = i == len(model.unet.up_blocks) - 1 174 | 175 | res_samples = down_block_res_samples[-len(upsample_block.resnets):] 176 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 177 | 178 | # if we have not reached the final block and need to forward the 179 | # upsample size, we do it here 180 | if not is_final_block and forward_upsample_size: 181 | upsample_size = down_block_res_samples[-1].shape[2:] 182 | 183 | sample = unet_up_forward(upsample_block, sample, emb, res_samples, encoder_hidden_states, 184 | cross_attention_kwargs, upsample_size, attention_mask) 185 | 186 | # 6. 
post-process 187 | if model.unet.conv_norm_out: 188 | sample = model.unet.conv_norm_out(sample) 189 | sample = model.unet.conv_act(sample) 190 | sample = model.unet.conv_out(sample) 191 | 192 | # if not return_dict: 193 | # return (sample,) 194 | 195 | return sample 196 | -------------------------------------------------------------------------------- /diffusion_core/diffusion_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from diffusers.pipelines import StableDiffusionPipeline 4 | from .utils import ClassRegistry 5 | 6 | diffusion_models_registry = ClassRegistry() 7 | 8 | 9 | @diffusion_models_registry.add_to_registry("stable-diffusion-v1-4") 10 | def read_v14(scheduler): 11 | model_id = "CompVis/stable-diffusion-v1-4" 12 | model = StableDiffusionPipeline.from_pretrained( 13 | model_id, torch_dtype=torch.float32, scheduler=scheduler 14 | ) 15 | return model 16 | 17 | 18 | @diffusion_models_registry.add_to_registry("stable-diffusion-v1-5") 19 | def read_v15(scheduler): 20 | model_id = "runwayml/stable-diffusion-v1-5" 21 | model = StableDiffusionPipeline.from_pretrained( 22 | model_id, torch_dtype=torch.float32, scheduler=scheduler 23 | ) 24 | return model 25 | 26 | 27 | @diffusion_models_registry.add_to_registry("stable-diffusion-v2-1") 28 | def read_v21(scheduler): 29 | model_id = "stabilityai/stable-diffusion-2-1" 30 | model = StableDiffusionPipeline.from_pretrained( 31 | model_id, torch_dtype=torch.float32, scheduler=scheduler 32 | ) 33 | return model 34 | -------------------------------------------------------------------------------- /diffusion_core/diffusion_schedulers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from diffusers.pipelines import StableDiffusionPipeline 4 | 5 | from diffusion_core.schedulers import DDIMScheduler 6 | from .utils import ClassRegistry 7 | 8 | 9 | diffusion_schedulers_registry = ClassRegistry() 10 | 11 | 12 | @diffusion_schedulers_registry.add_to_registry("ddim_50_eps") 13 | def get_ddim_50_e(): 14 | scheduler = DDIMScheduler( 15 | beta_start=0.00085, 16 | beta_end=0.012, 17 | beta_schedule="scaled_linear", 18 | set_alpha_to_one=False, 19 | num_inference_steps=50, 20 | prediction_type='epsilon' 21 | ) 22 | return scheduler 23 | 24 | 25 | @diffusion_schedulers_registry.add_to_registry("ddim_50_v") 26 | def get_ddim_50_v(): 27 | scheduler = DDIMScheduler( 28 | beta_start=0.00085, 29 | beta_end=0.012, 30 | beta_schedule="scaled_linear", 31 | set_alpha_to_one=False, 32 | num_inference_steps=50, 33 | prediction_type='v_prediction' 34 | ) 35 | return scheduler 36 | 37 | 38 | @diffusion_schedulers_registry.add_to_registry("ddim_200_v") 39 | def get_ddim_200_v(): 40 | scheduler = DDIMScheduler( 41 | beta_start=0.00085, 42 | beta_end=0.012, 43 | beta_schedule="scaled_linear", 44 | set_alpha_to_one=False, 45 | num_inference_steps=200, 46 | prediction_type='v_prediction' 47 | ) 48 | return scheduler 49 | -------------------------------------------------------------------------------- /diffusion_core/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from PIL import Image 5 | 6 | 7 | @torch.no_grad() 8 | def latent2image(latents, model, return_type='np'): 9 | latents = latents.detach() / model.vae.config.scaling_factor 10 | image = model.vae.decode(latents)['sample'] 11 | if return_type == 'np': 12 | image = (image / 2 + 0.5).clamp(0, 1) 
13 | image = image.cpu().permute(0, 2, 3, 1).numpy() 14 | image = (image * 255).astype(np.uint8) 15 | return image 16 | 17 | 18 | @torch.no_grad() 19 | def image2latent(image, model): 20 | if type(image) is Image: 21 | image = np.array(image) 22 | if type(image) is torch.Tensor and image.dim() == 4: 23 | latents = image 24 | else: 25 | image = torch.from_numpy(image).float() / 127.5 - 1 26 | image = image.permute(2, 0, 1).unsqueeze(0).to(model.device).to(model.unet.dtype) 27 | latents = model.vae.encode(image)['latent_dist'].mean 28 | latents = latents * model.vae.config.scaling_factor 29 | return latents 30 | 31 | -------------------------------------------------------------------------------- /diffusion_core/guiders/__init__.py: -------------------------------------------------------------------------------- 1 | from .guidance_editing import GuidanceEditing -------------------------------------------------------------------------------- /diffusion_core/guiders/guidance_editing.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import OrderedDict 3 | from typing import Callable, Dict, Optional 4 | 5 | import torch 6 | import numpy as np 7 | import PIL 8 | import gc 9 | 10 | from tqdm.auto import trange, tqdm 11 | from diffusion_core.guiders.opt_guiders import opt_registry 12 | from diffusion_core.diffusion_utils import latent2image, image2latent 13 | from diffusion_core.custom_forwards.unet_sd import unet_forward 14 | from diffusion_core.guiders.noise_rescales import noise_rescales 15 | from diffusion_core.inversion import Inversion, NullInversion, NegativePromptInversion 16 | from diffusion_core.utils import toggle_grad, use_grad_checkpointing 17 | 18 | 19 | class GuidanceEditing: 20 | def __init__( 21 | self, 22 | model, 23 | config 24 | ): 25 | 26 | self.config = config 27 | self.model = model 28 | 29 | toggle_grad(self.model.unet, False) 30 | 31 | if config.get('gradient_checkpointing', False): 32 | use_grad_checkpointing(mode=True) 33 | else: 34 | use_grad_checkpointing(mode=False) 35 | 36 | self.guiders = { 37 | g_data.name: (opt_registry[g_data.name](**g_data.get('kwargs', {})), g_data.g_scale) 38 | for g_data in config.guiders 39 | } 40 | 41 | self._setup_inversion_engine() 42 | self.latents_stack = [] 43 | 44 | self.context = None 45 | 46 | self.noise_rescaler = noise_rescales[config.noise_rescaling_setup.type]( 47 | config.noise_rescaling_setup.init_setup, 48 | **config.noise_rescaling_setup.get('kwargs', {}) 49 | ) 50 | 51 | for guider_name, (guider, _) in self.guiders.items(): 52 | guider.clear_outputs() 53 | 54 | self.self_attn_layers_num = config.get('self_attn_layers_num', [6, 1, 9]) 55 | if type(self.self_attn_layers_num[0]) is int: 56 | for i in range(len(self.self_attn_layers_num)): 57 | self.self_attn_layers_num[i] = (0, self.self_attn_layers_num[i]) 58 | 59 | 60 | def _setup_inversion_engine(self): 61 | if self.config.inversion_type == 'ntinv': 62 | self.inversion_engine = NullInversion( 63 | self.model, 64 | self.model.scheduler.num_inference_steps, 65 | self.config.guiders[0]['g_scale'], 66 | forward_guidance_scale=1, 67 | verbose=self.config.verbose 68 | ) 69 | elif self.config.inversion_type == 'npinv': 70 | self.inversion_engine = NegativePromptInversion( 71 | self.model, 72 | self.model.scheduler.num_inference_steps, 73 | self.config.guiders[0]['g_scale'], 74 | forward_guidance_scale=1, 75 | verbose=self.config.verbose 76 | ) 77 | elif self.config.inversion_type == 'dummy': 78 | 
self.inversion_engine = Inversion( 79 | self.model, 80 | self.model.scheduler.num_inference_steps, 81 | self.config.guiders[0]['g_scale'], 82 | forward_guidance_scale=1, 83 | verbose=self.config.verbose 84 | ) 85 | else: 86 | raise ValueError('Incorrect InversionType') 87 | 88 | def __call__( 89 | self, 90 | image_gt: PIL.Image.Image, 91 | inv_prompt: str, 92 | trg_prompt: str, 93 | control_image: Optional[PIL.Image.Image] = None, 94 | verbose: bool = False 95 | ): 96 | self.train( 97 | image_gt, 98 | inv_prompt, 99 | trg_prompt, 100 | control_image, 101 | verbose 102 | ) 103 | 104 | return self.edit() 105 | 106 | def train( 107 | self, 108 | image_gt: PIL.Image.Image, 109 | inv_prompt: str, 110 | trg_prompt: str, 111 | control_image: Optional[PIL.Image.Image] = None, 112 | verbose: bool = False 113 | ): 114 | self.init_prompt(inv_prompt, trg_prompt) 115 | self.verbose = verbose 116 | 117 | image_gt = np.array(image_gt) 118 | if self.config.start_latent == 'inversion': 119 | _, self.inv_latents, self.uncond_embeddings = self.inversion_engine( 120 | image_gt, inv_prompt, 121 | verbose=self.verbose 122 | ) 123 | elif self.config.start_latent == 'random': 124 | self.inv_latents = self.sample_noised_latents( 125 | image2latent(image_gt, self.model) 126 | ) 127 | else: 128 | raise ValueError('Incorrect start latent type') 129 | 130 | for g_name, (guider, _) in self.guiders.items(): 131 | if hasattr(guider, 'model_patch'): 132 | guider.model_patch(self.model, self_attn_layers_num=self.self_attn_layers_num) 133 | 134 | self.start_latent = self.inv_latents[-1].clone() 135 | 136 | params = { 137 | 'model': self.model, 138 | 'inv_prompt': inv_prompt, 139 | 'trg_prompt': trg_prompt 140 | } 141 | for g_name, (guider, _) in self.guiders.items(): 142 | if hasattr(guider, 'train'): 143 | guider.train(params) 144 | 145 | for guider_name, (guider, _) in self.guiders.items(): 146 | guider.clear_outputs() 147 | 148 | def _construct_data_dict( 149 | self, latents, 150 | diffusion_iter, 151 | timestep 152 | ): 153 | uncond_emb, inv_prompt_emb, trg_prompt_emb = self.context.chunk(3) 154 | 155 | if self.uncond_embeddings is not None: 156 | uncond_emb = self.uncond_embeddings[diffusion_iter] 157 | 158 | data_dict = { 159 | 'latent': latents, 160 | 'inv_latent': self.inv_latents[-diffusion_iter - 1], 161 | 'timestep': timestep, 162 | 'model': self.model, 163 | 'uncond_emb': uncond_emb, 164 | 'trg_emb': trg_prompt_emb, 165 | 'inv_emb': inv_prompt_emb, 166 | 'diff_iter': diffusion_iter 167 | } 168 | 169 | with torch.no_grad(): 170 | uncond_unet = unet_forward( 171 | self.model, 172 | data_dict['latent'], 173 | data_dict['timestep'], 174 | data_dict['uncond_emb'], 175 | None 176 | ) 177 | 178 | for g_name, (guider, _) in self.guiders.items(): 179 | if hasattr(guider, 'model_patch'): 180 | guider.clear_outputs() 181 | 182 | with torch.no_grad(): 183 | inv_prompt_unet = unet_forward( 184 | self.model, 185 | data_dict['inv_latent'], 186 | data_dict['timestep'], 187 | data_dict['inv_emb'], 188 | None 189 | ) 190 | 191 | for g_name, (guider, _) in self.guiders.items(): 192 | if hasattr(guider, 'model_patch'): 193 | if 'inv_inv' in guider.forward_hooks: 194 | data_dict.update({f"{g_name}_inv_inv": guider.output}) 195 | guider.clear_outputs() 196 | 197 | data_dict['latent'].requires_grad = True 198 | 199 | src_prompt_unet = unet_forward( 200 | self.model, 201 | data_dict['latent'], 202 | data_dict['timestep'], 203 | data_dict['inv_emb'], 204 | None 205 | ) 206 | 207 | for g_name, (guider, _) in self.guiders.items(): 208 | 
if hasattr(guider, 'model_patch'): 209 | if 'cur_inv' in guider.forward_hooks: 210 | data_dict.update({f"{g_name}_cur_inv": guider.output}) 211 | guider.clear_outputs() 212 | 213 | trg_prompt_unet = unet_forward( 214 | self.model, 215 | data_dict['latent'], 216 | data_dict['timestep'], 217 | data_dict['trg_emb'], 218 | None 219 | ) 220 | 221 | for g_name, (guider, _) in self.guiders.items(): 222 | if hasattr(guider, 'model_patch'): 223 | if 'cur_trg' in guider.forward_hooks: 224 | data_dict.update({f"{g_name}_cur_trg": guider.output}) 225 | guider.clear_outputs() 226 | 227 | data_dict.update({ 228 | 'uncond_unet': uncond_unet, 229 | 'trg_prompt_unet': trg_prompt_unet, 230 | }) 231 | 232 | return data_dict 233 | 234 | def _get_noise(self, data_dict, diffusion_iter): 235 | backward_guiders_sum = 0. 236 | noises = { 237 | 'uncond': data_dict['uncond_unet'], 238 | } 239 | index = torch.where(self.model.scheduler.timesteps == data_dict['timestep'])[0].item() 240 | 241 | # self.noise_rescaler 242 | for name, (guider, g_scale) in self.guiders.items(): 243 | if guider.grad_guider: 244 | cur_noise_pred = self._get_scale(g_scale, diffusion_iter) * guider(data_dict) 245 | noises[name] = cur_noise_pred 246 | else: 247 | energy = self._get_scale(g_scale, diffusion_iter) * guider(data_dict) 248 | if not torch.allclose(energy, torch.tensor(0.)): 249 | backward_guiders_sum += energy 250 | 251 | if hasattr(backward_guiders_sum, 'backward'): 252 | backward_guiders_sum.backward() 253 | noises['other'] = data_dict['latent'].grad 254 | 255 | scales = self.noise_rescaler(noises, index) 256 | noise_pred = sum(scales[k] * noises[k] for k in noises) 257 | 258 | for g_name, (guider, _) in self.guiders.items(): 259 | if not guider.grad_guider: 260 | guider.clear_outputs() 261 | gc.collect() 262 | torch.cuda.empty_cache() 263 | 264 | return noise_pred 265 | 266 | @staticmethod 267 | def _get_scale(g_scale, diffusion_iter): 268 | if type(g_scale) is float: 269 | return g_scale 270 | else: 271 | return g_scale[diffusion_iter] 272 | 273 | @torch.no_grad() 274 | def _step(self, noise_pred, t, latents): 275 | latents = self.model.scheduler.step_backward(noise_pred, t, latents).prev_sample 276 | self.latents_stack.append(latents.detach()) 277 | return latents 278 | 279 | def edit(self): 280 | self.model.scheduler.set_timesteps(self.model.scheduler.num_inference_steps) 281 | latents = self.start_latent 282 | self.latents_stack = [] 283 | 284 | for i, timestep in tqdm( 285 | enumerate(self.model.scheduler.timesteps), 286 | total=self.model.scheduler.num_inference_steps, 287 | desc='Editing', 288 | disable=not self.verbose 289 | ): 290 | # 1. Construct dict 291 | data_dict = self._construct_data_dict(latents, i, timestep) 292 | 293 | # 2. Calculate guidance 294 | noise_pred = self._get_noise(data_dict, i) 295 | 296 | # 3. 
Scheduler step 297 | latents = self._step(noise_pred, timestep, latents) 298 | 299 | self._model_unpatch(self.model) 300 | return latent2image(latents, self.model)[0] 301 | 302 | @torch.no_grad() 303 | def init_prompt(self, inv_prompt: str, trg_prompt: str): 304 | trg_prompt_embed = self.get_prompt_embed(trg_prompt) 305 | inv_prompt_embed = self.get_prompt_embed(inv_prompt) 306 | uncond_embed = self.get_prompt_embed("") 307 | 308 | self.context = torch.cat([uncond_embed, inv_prompt_embed, trg_prompt_embed]) 309 | 310 | def get_prompt_embed(self, prompt: str): 311 | text_input = self.model.tokenizer( 312 | [prompt], 313 | padding="max_length", 314 | max_length=self.model.tokenizer.model_max_length, 315 | truncation=True, 316 | return_tensors="pt", 317 | ) 318 | text_embeddings = self.model.text_encoder( 319 | text_input.input_ids.to(self.model.device) 320 | )[0] 321 | 322 | return text_embeddings 323 | 324 | def sample_noised_latents(self, latent): 325 | all_latent = [latent.clone().detach()] 326 | latent = latent.clone().detach() 327 | for i in trange(self.model.scheduler.num_inference_steps, desc='Latent Sampling'): 328 | timestep = self.model.scheduler.timesteps[-i - 1] 329 | if i + 1 < len(self.model.scheduler.timesteps): 330 | next_timestep = self.model.scheduler.timesteps[- i - 2] 331 | else: 332 | next_timestep = 999 333 | 334 | alpha_prod_t = self.model.scheduler.alphas_cumprod[timestep] 335 | alpha_prod_t_next = self.model.scheduler.alphas_cumprod[next_timestep] 336 | 337 | alpha_slice = alpha_prod_t_next / alpha_prod_t 338 | 339 | latent = torch.sqrt(alpha_slice) * latent + torch.sqrt(1 - alpha_slice) * torch.randn_like(latent) 340 | all_latent.append(latent) 341 | return all_latent 342 | 343 | def _model_unpatch(self, model): 344 | def new_forward_info(self): 345 | def patched_forward( 346 | hidden_states, 347 | encoder_hidden_states=None, 348 | attention_mask=None, 349 | temb=None, 350 | ): 351 | residual = hidden_states 352 | 353 | if self.spatial_norm is not None: 354 | hidden_states = self.spatial_norm(hidden_states, temb) 355 | 356 | input_ndim = hidden_states.ndim 357 | 358 | if input_ndim == 4: 359 | batch_size, channel, height, width = hidden_states.shape 360 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 361 | 362 | batch_size, sequence_length, _ = ( 363 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 364 | ) 365 | attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size) 366 | 367 | if self.group_norm is not None: 368 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 369 | 370 | query = self.to_q(hidden_states) 371 | 372 | ## Injection 373 | is_self = encoder_hidden_states is None 374 | 375 | if encoder_hidden_states is None: 376 | encoder_hidden_states = hidden_states 377 | elif self.norm_cross: 378 | encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states) 379 | 380 | key = self.to_k(encoder_hidden_states) 381 | value = self.to_v(encoder_hidden_states) 382 | 383 | query = self.head_to_batch_dim(query) 384 | key = self.head_to_batch_dim(key) 385 | value = self.head_to_batch_dim(value) 386 | 387 | attention_probs = self.get_attention_scores(query, key, attention_mask) 388 | hidden_states = torch.bmm(attention_probs, value) 389 | hidden_states = self.batch_to_head_dim(hidden_states) 390 | 391 | # linear proj 392 | hidden_states = self.to_out[0](hidden_states) 393 | # dropout 394 | hidden_states = 
self.to_out[1](hidden_states) 395 | 396 | if input_ndim == 4: 397 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 398 | 399 | if self.residual_connection: 400 | hidden_states = hidden_states + residual 401 | 402 | hidden_states = hidden_states / self.rescale_output_factor 403 | 404 | return hidden_states 405 | 406 | return patched_forward 407 | 408 | def register_attn(module): 409 | if 'Attention' in module.__class__.__name__: 410 | module.forward = new_forward_info(module) 411 | elif hasattr(module, 'children'): 412 | for module_ in module.children(): 413 | register_attn(module_) 414 | 415 | def remove_hooks(module): 416 | if hasattr(module, "_forward_hooks"): 417 | module._forward_hooks: Dict[int, Callable] = OrderedDict() 418 | if hasattr(module, 'children'): 419 | for module_ in module.children(): 420 | remove_hooks(module_) 421 | 422 | register_attn(model.unet) 423 | remove_hooks(model.unet) 424 | -------------------------------------------------------------------------------- /diffusion_core/guiders/noise_rescales.py: -------------------------------------------------------------------------------- 1 | import enum 2 | import torch 3 | import numpy as np 4 | 5 | from diffusion_core.utils import ClassRegistry 6 | 7 | 8 | noise_rescales = ClassRegistry() 9 | 10 | 11 | class RescaleType(enum.Enum): 12 | UPPER = 0 13 | RANGE = 1 14 | 15 | 16 | class BaseNoiseRescaler: 17 | def __init__(self, noise_rescale_setup): 18 | if isinstance(noise_rescale_setup, float): 19 | self.upper_bound = noise_rescale_setup 20 | self.rescale_type = RescaleType.UPPER 21 | elif len(noise_rescale_setup) == 2: 22 | self.upper_bound, self.upper_bound = noise_rescale_setup 23 | self.rescale_type = RescaleType.RANGE 24 | else: 25 | raise TypeError('Incorrect noise_rescale_setup type: possible types are float, tuple(float, float)') 26 | 27 | def __call__(self, noises, index): 28 | if 'other' not in noises: 29 | return {k: 1. for k in noises} 30 | rescale_dict = self._rescale(noises, index) 31 | return rescale_dict 32 | 33 | def _rescale(self, noises, index): 34 | raise NotImplementedError('') 35 | 36 | 37 | @noise_rescales.add_to_registry('identity_rescaler') 38 | class IdentityRescaler: 39 | def __init__(self, *args, **kwargs): 40 | pass 41 | 42 | def __call__(self, noises, index): 43 | return {k: 1. for k in noises} 44 | 45 | 46 | @noise_rescales.add_to_registry('range_other_on_cfg_norm') 47 | class RangeNoiseRescaler(BaseNoiseRescaler): 48 | def __init__(self, noise_rescale_setup): 49 | super().__init__(noise_rescale_setup) 50 | assert len(noise_rescale_setup) == 2, 'incorrect len of noise_rescale_setup' 51 | self.lower_bound, self.upper_bound = noise_rescale_setup 52 | 53 | def _rescale(self, noises, index): 54 | cfg_noise_norm = torch.norm(noises['cfg']).item() 55 | other_noise_norm = torch.norm(noises['other']).item() 56 | 57 | ratio = other_noise_norm / cfg_noise_norm if cfg_noise_norm != 0 else 1. 58 | ratio_clipped = np.clip(ratio, self.lower_bound, self.upper_bound) 59 | if other_noise_norm != 0.: 60 | other_scale = ratio_clipped / ratio 61 | else: 62 | other_scale = 1. 
63 | 64 | answer = { 65 | 'cfg': 1., 66 | 'uncond': 1., 67 | 'other': other_scale 68 | } 69 | return answer 70 | -------------------------------------------------------------------------------- /diffusion_core/guiders/opt_guiders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional 3 | 4 | from diffusion_core.utils.class_registry import ClassRegistry 5 | from diffusion_core.guiders.scale_schedulers import last_steps, first_steps 6 | 7 | opt_registry = ClassRegistry() 8 | 9 | class BaseGuider: 10 | def __init__(self): 11 | self.clear_outputs() 12 | 13 | @property 14 | def grad_guider(self): 15 | return hasattr(self, 'grad_fn') 16 | 17 | def __call__(self, data_dict): 18 | if self.grad_guider: 19 | return self.grad_fn(data_dict) 20 | else: 21 | return self.calc_energy(data_dict) 22 | 23 | def clear_outputs(self): 24 | if not self.grad_guider: 25 | self.output = self.single_output_clear() 26 | 27 | def single_output_clear(self): 28 | raise NotImplementedError() 29 | 30 | 31 | @opt_registry.add_to_registry('cfg') 32 | class ClassifierFreeGuidance(BaseGuider): 33 | def __init__(self, is_source_guidance=False): 34 | self.is_source_guidance = is_source_guidance 35 | 36 | def grad_fn(self, data_dict): 37 | prompt_unet = data_dict['src_prompt_unet'] if self.is_source_guidance else data_dict['trg_prompt_unet'] 38 | return prompt_unet - data_dict['uncond_unet'] 39 | 40 | 41 | @opt_registry.add_to_registry('latents_diff') 42 | class LatentsDiffGuidance(BaseGuider): 43 | """ 44 | \| z_t* - z_t \|^2_2 45 | """ 46 | def grad_fn(self, data_dict): 47 | return 2 * (data_dict['latent'] - data_dict['inv_latent']) 48 | 49 | 50 | @opt_registry.add_to_registry('features_map_l2') 51 | class FeaturesMapL2EnergyGuider(BaseGuider): 52 | def __init__(self, block='up'): 53 | assert block in ['down', 'up', 'mid', 'whole'] 54 | self.block = block 55 | 56 | patched = True 57 | forward_hooks = ['cur_trg', 'inv_inv'] 58 | def calc_energy(self, data_dict): 59 | return torch.mean(torch.pow(data_dict['features_map_l2_cur_trg'] - data_dict['features_map_l2_inv_inv'], 2)) 60 | 61 | def model_patch(self, model, self_attn_layers_num=None): 62 | def hook_fn(module, input, output): 63 | self.output = output 64 | if self.block == 'mid': 65 | model.unet.mid_block.register_forward_hook(hook_fn) 66 | elif self.block == 'up': 67 | model.unet.up_blocks[1].resnets[1].register_forward_hook(hook_fn) 68 | elif self.block == 'down': 69 | model.unet.down_blocks[1].resnets[1].register_forward_hook(hook_fn) 70 | 71 | def single_output_clear(self): 72 | None 73 | 74 | 75 | @opt_registry.add_to_registry('self_attn_map_l2') 76 | class SelfAttnMapL2EnergyGuider(BaseGuider): 77 | patched = True 78 | forward_hooks = ['cur_inv', 'inv_inv'] 79 | def single_output_clear(self): 80 | return { 81 | "down_cross": [], "mid_cross": [], "up_cross": [], 82 | "down_self": [], "mid_self": [], "up_self": [] 83 | } 84 | 85 | def calc_energy(self, data_dict): 86 | result = 0. 
87 | for unet_place, data in data_dict['self_attn_map_l2_cur_inv'].items(): 88 | for elem_idx, elem in enumerate(data): 89 | result += torch.mean( 90 | torch.pow( 91 | elem - data_dict['self_attn_map_l2_inv_inv'][unet_place][elem_idx], 2 92 | ) 93 | ) 94 | self.single_output_clear() 95 | return result 96 | 97 | def model_patch(guider_self, model, self_attn_layers_num=None): 98 | def new_forward_info(self, place_unet): 99 | def patched_forward( 100 | hidden_states, 101 | encoder_hidden_states=None, 102 | attention_mask=None, 103 | temb=None, 104 | ): 105 | residual = hidden_states 106 | 107 | if self.spatial_norm is not None: 108 | hidden_states = self.spatial_norm(hidden_states, temb) 109 | 110 | input_ndim = hidden_states.ndim 111 | 112 | if input_ndim == 4: 113 | batch_size, channel, height, width = hidden_states.shape 114 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 115 | 116 | batch_size, sequence_length, _ = ( 117 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 118 | ) 119 | attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size) 120 | 121 | if self.group_norm is not None: 122 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 123 | 124 | query = self.to_q(hidden_states) 125 | 126 | ## Injection 127 | is_self = encoder_hidden_states is None 128 | 129 | if encoder_hidden_states is None: 130 | encoder_hidden_states = hidden_states 131 | elif self.norm_cross: 132 | encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states) 133 | 134 | key = self.to_k(encoder_hidden_states) 135 | value = self.to_v(encoder_hidden_states) 136 | 137 | query = self.head_to_batch_dim(query) 138 | key = self.head_to_batch_dim(key) 139 | value = self.head_to_batch_dim(value) 140 | 141 | attention_probs = self.get_attention_scores(query, key, attention_mask) 142 | if is_self: 143 | guider_self.output[f"{place_unet}_self"].append(attention_probs) 144 | hidden_states = torch.bmm(attention_probs, value) 145 | hidden_states = self.batch_to_head_dim(hidden_states) 146 | 147 | # linear proj 148 | hidden_states = self.to_out[0](hidden_states) 149 | # dropout 150 | hidden_states = self.to_out[1](hidden_states) 151 | 152 | if input_ndim == 4: 153 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 154 | 155 | if self.residual_connection: 156 | hidden_states = hidden_states + residual 157 | 158 | hidden_states = hidden_states / self.rescale_output_factor 159 | 160 | return hidden_states 161 | return patched_forward 162 | 163 | def register_attn(module, place_in_unet, layers_num, cur_layers_num=0): 164 | if 'Attention' in module.__class__.__name__: 165 | if 2 * layers_num[0] <= cur_layers_num < 2 * layers_num[1]: 166 | module.forward = new_forward_info(module, place_in_unet) 167 | return cur_layers_num + 1 168 | elif hasattr(module, 'children'): 169 | for module_ in module.children(): 170 | cur_layers_num = register_attn(module_, place_in_unet, layers_num, cur_layers_num) 171 | return cur_layers_num 172 | 173 | sub_nets = model.unet.named_children() 174 | for name, net in sub_nets: 175 | if "down" in name: 176 | register_attn(net, "down", self_attn_layers_num[0]) 177 | if "mid" in name: 178 | register_attn(net, "mid", self_attn_layers_num[1]) 179 | if "up" in name: 180 | register_attn(net, "up", self_attn_layers_num[2]) 181 | 182 | 183 | @opt_registry.add_to_registry('self_attn_map_l2_appearance') 184 | class 
SelfAttnMapL2withAppearanceEnergyGuider(BaseGuider): 185 | patched = True 186 | forward_hooks = ['cur_inv', 'inv_inv'] 187 | 188 | def __init__( 189 | self, self_attn_gs: float, app_gs: float, new_features: bool=False, 190 | total_last_steps: Optional[int] = None, total_first_steps: Optional[int] = None 191 | ): 192 | super().__init__() 193 | 194 | self.new_features = new_features 195 | 196 | if total_last_steps is not None: 197 | self.app_gs = last_steps(app_gs, total_last_steps) 198 | self.self_attn_gs = last_steps(self_attn_gs, total_last_steps) 199 | elif total_first_steps is not None: 200 | self.app_gs = first_steps(app_gs, total_first_steps) 201 | self.self_attn_gs = first_steps(self_attn_gs, total_first_steps) 202 | else: 203 | self.app_gs = app_gs 204 | self.self_attn_gs = self_attn_gs 205 | 206 | def single_output_clear(self): 207 | return { 208 | "down_self": [], 209 | "mid_self": [], 210 | "up_self": [], 211 | "features": None 212 | } 213 | 214 | def calc_energy(self, data_dict): 215 | self_attn_result = 0. 216 | unet_places = ['down_self', 'up_self', 'mid_self'] 217 | for unet_place in unet_places: 218 | data = data_dict['self_attn_map_l2_appearance_cur_inv'][unet_place] 219 | for elem_idx, elem in enumerate(data): 220 | self_attn_result += torch.mean( 221 | torch.pow( 222 | elem - data_dict['self_attn_map_l2_appearance_inv_inv'][unet_place][elem_idx], 2 223 | ) 224 | ) 225 | 226 | features_orig = data_dict['self_attn_map_l2_appearance_inv_inv']['features'] 227 | features_cur = data_dict['self_attn_map_l2_appearance_cur_inv']['features'] 228 | app_result = torch.mean(torch.abs(features_cur - features_orig)) 229 | 230 | self.single_output_clear() 231 | 232 | if type(self.app_gs) == float: 233 | _app_gs = self.app_gs 234 | else: 235 | _app_gs = self.app_gs[data_dict['diff_iter']] 236 | 237 | if type(self.self_attn_gs) == float: 238 | _self_attn_gs = self.self_attn_gs 239 | else: 240 | _self_attn_gs = self.self_attn_gs[data_dict['diff_iter']] 241 | 242 | return _self_attn_gs * self_attn_result + _app_gs * app_result 243 | 244 | def model_patch(guider_self, model, self_attn_layers_num=None): 245 | def new_forward_info(self, place_unet): 246 | def patched_forward( 247 | hidden_states, 248 | encoder_hidden_states=None, 249 | attention_mask=None, 250 | temb=None, 251 | ): 252 | residual = hidden_states 253 | 254 | if self.spatial_norm is not None: 255 | hidden_states = self.spatial_norm(hidden_states, temb) 256 | 257 | input_ndim = hidden_states.ndim 258 | 259 | if input_ndim == 4: 260 | batch_size, channel, height, width = hidden_states.shape 261 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 262 | 263 | batch_size, sequence_length, _ = ( 264 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 265 | ) 266 | attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size) 267 | 268 | if self.group_norm is not None: 269 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 270 | 271 | query = self.to_q(hidden_states) 272 | 273 | ## Injection 274 | is_self = encoder_hidden_states is None 275 | 276 | if encoder_hidden_states is None: 277 | encoder_hidden_states = hidden_states 278 | elif self.norm_cross: 279 | encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states) 280 | 281 | key = self.to_k(encoder_hidden_states) 282 | value = self.to_v(encoder_hidden_states) 283 | 284 | query = self.head_to_batch_dim(query) 285 | key = 
self.head_to_batch_dim(key) 286 | value = self.head_to_batch_dim(value) 287 | 288 | attention_probs = self.get_attention_scores(query, key, attention_mask) 289 | if is_self: 290 | guider_self.output[f"{place_unet}_self"].append(attention_probs) 291 | 292 | hidden_states = torch.bmm(attention_probs, value) 293 | 294 | hidden_states = self.batch_to_head_dim(hidden_states) 295 | 296 | # linear proj 297 | hidden_states = self.to_out[0](hidden_states) 298 | # dropout 299 | hidden_states = self.to_out[1](hidden_states) 300 | 301 | if input_ndim == 4: 302 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 303 | 304 | if self.residual_connection: 305 | hidden_states = hidden_states + residual 306 | 307 | hidden_states = hidden_states / self.rescale_output_factor 308 | 309 | return hidden_states 310 | return patched_forward 311 | 312 | def register_attn(module, place_in_unet, layers_num, cur_layers_num=0): 313 | if 'Attention' in module.__class__.__name__: 314 | if 2 * layers_num[0] <= cur_layers_num < 2 * layers_num[1]: 315 | module.forward = new_forward_info(module, place_in_unet) 316 | return cur_layers_num + 1 317 | elif hasattr(module, 'children'): 318 | for module_ in module.children(): 319 | cur_layers_num = register_attn(module_, place_in_unet, layers_num, cur_layers_num) 320 | return cur_layers_num 321 | 322 | sub_nets = model.unet.named_children() 323 | for name, net in sub_nets: 324 | if "down" in name: 325 | register_attn(net, "down", self_attn_layers_num[0]) 326 | if "mid" in name: 327 | register_attn(net, "mid", self_attn_layers_num[1]) 328 | if "up" in name: 329 | register_attn(net, "up", self_attn_layers_num[2]) 330 | 331 | def hook_fn(module, input, output): 332 | guider_self.output["features"] = output 333 | 334 | if guider_self.new_features: 335 | model.unet.up_blocks[-1].register_forward_hook(hook_fn) 336 | else: 337 | model.unet.conv_norm_out.register_forward_hook(hook_fn) 338 | -------------------------------------------------------------------------------- /diffusion_core/guiders/scale_schedulers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def first_steps(g_scale, steps): 5 | g_scales = np.ones(50) 6 | g_scales[:steps] *= g_scale 7 | g_scales[steps:] = 0. 8 | return g_scales.tolist() 9 | 10 | 11 | def last_steps(g_scale, steps): 12 | g_scales = np.ones(50) 13 | g_scales[-steps:] *= g_scale 14 | g_scales[:-steps] = 0. 
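# first_steps / last_steps build a per-timestep list of guidance scales of length 50
# (presumably matching the default 50 inference steps): the scale is applied only on
# the first / last `steps` iterations and is zeroed elsewhere.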
15 | return g_scales.tolist() 16 | 17 | -------------------------------------------------------------------------------- /diffusion_core/inversion/__init__.py: -------------------------------------------------------------------------------- 1 | from .null_inversion import NullInversion, Inversion 2 | from .negativ_p_inversion import NegativePromptInversion -------------------------------------------------------------------------------- /diffusion_core/inversion/negativ_p_inversion.py: -------------------------------------------------------------------------------- 1 | import PIL 2 | import torch 3 | 4 | from PIL import Image 5 | from typing import Optional, Union 6 | 7 | from .null_inversion import Inversion 8 | 9 | 10 | class NegativePromptInversion(Inversion): 11 | def negative_prompt_inversion(self, latents, verbose=False): 12 | uncond_embeddings, cond_embeddings = self.context.chunk(2) 13 | uncond_embeddings_list = [cond_embeddings.detach()] * self.infer_steps 14 | return uncond_embeddings_list 15 | 16 | def __call__( 17 | self, 18 | image_gt: PIL.Image.Image, 19 | prompt: Union[str, torch.Tensor], 20 | control_image: Optional[PIL.Image.Image] = None, 21 | verbose: bool = False 22 | ): 23 | image_rec, latents, _ = super().__call__(image_gt, prompt, control_image, verbose) 24 | 25 | if verbose: 26 | print("[Negative-Prompt inversion]") 27 | 28 | uncond_embeddings = self.negative_prompt_inversion( 29 | latents, 30 | verbose 31 | ) 32 | return image_rec, latents, uncond_embeddings 33 | -------------------------------------------------------------------------------- /diffusion_core/inversion/null_inversion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as nnf 4 | import PIL 5 | 6 | from typing import Optional, Union, List, Dict 7 | from tqdm.auto import tqdm, trange 8 | from torch.optim.adam import Adam 9 | 10 | from diffusion_core.diffusion_utils import latent2image, image2latent 11 | from diffusion_core.custom_forwards.unet_sd import unet_forward 12 | from diffusion_core.schedulers.opt_schedulers import opt_registry 13 | 14 | 15 | class Inversion: 16 | def __init__( 17 | self, model, 18 | inference_steps, 19 | inference_guidance_scale, 20 | forward_guidance_scale=1, 21 | verbose=False 22 | ): 23 | self.model = model 24 | self.tokenizer = self.model.tokenizer 25 | self.scheduler = model.scheduler 26 | self.scheduler.set_timesteps(inference_steps) 27 | 28 | self.prompt = None 29 | self.context = None 30 | self.controlnet_cond = None 31 | 32 | self.forward_guidance = forward_guidance_scale 33 | self.backward_guidance = inference_guidance_scale 34 | self.infer_steps = inference_steps 35 | self.half_mode = model.unet.dtype == torch.float16 36 | self.verbose = verbose 37 | 38 | @torch.no_grad() 39 | def init_controlnet_cond(self, control_image): 40 | if control_image is None: 41 | return 42 | 43 | controlnet_cond = self.model.prepare_image( 44 | control_image, 45 | 512, 46 | 512, 47 | 1 * 1, 48 | 1, 49 | self.model.controlnet.device, 50 | self.model.controlnet.dtype, 51 | ) 52 | 53 | self.controlnet_cond = controlnet_cond 54 | 55 | def get_noise_pred_single(self, latents, t, context): 56 | noise_pred = unet_forward( 57 | self.model, 58 | latents, 59 | t, 60 | context, 61 | self.controlnet_cond 62 | ) 63 | return noise_pred 64 | 65 | def get_noise_pred_guided(self, latents, t, guidance_scale, context=None): 66 | if context is None: 67 | context = self.context 68 | 69 | latents_input = 
torch.cat([latents] * 2) 70 | 71 | if self.controlnet_cond is not None: 72 | controlnet_cond = torch.cat([self.controlnet_cond] * 2) 73 | noise_pred = unet_forward( 74 | self.model, 75 | latents_input, 76 | t, 77 | encoder_hidden_states=context, 78 | controlnet_cond=controlnet_cond 79 | ) 80 | else: 81 | noise_pred = self.model.unet(latents_input, t, encoder_hidden_states=context)["sample"] 82 | 83 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) 84 | noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) 85 | return noise_pred 86 | 87 | @torch.no_grad() 88 | def init_prompt(self, prompt: Union[str, torch.Tensor]): 89 | uncond_input = self.model.tokenizer( 90 | [""], padding="max_length", max_length=self.model.tokenizer.model_max_length, 91 | return_tensors="pt" 92 | ) 93 | uncond_embeddings = self.model.text_encoder(uncond_input.input_ids.to(self.model.device))[0] 94 | if type(prompt) == str: 95 | text_input = self.model.tokenizer( 96 | [prompt], 97 | padding="max_length", 98 | max_length=self.model.tokenizer.model_max_length, 99 | truncation=True, 100 | return_tensors="pt", 101 | ) 102 | text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.model.device))[0] 103 | else: 104 | text_embeddings = prompt 105 | self.context = torch.cat([uncond_embeddings, text_embeddings]) 106 | self.prompt = prompt 107 | 108 | @torch.no_grad() 109 | def loop(self, latent): 110 | uncond_embeddings, cond_embeddings = self.context.chunk(2) 111 | all_latent = [latent] 112 | latent = latent.clone().detach() 113 | for i in trange(self.infer_steps, desc='Inversion', disable=not self.verbose): 114 | t = self.scheduler.timesteps[len(self.scheduler.timesteps) - i - 1] 115 | 116 | if not np.allclose(self.forward_guidance, 1.): 117 | noise_pred = self.get_noise_pred_guided( 118 | latent, t, self.forward_guidance, self.context 119 | ) 120 | else: 121 | noise_pred = self.get_noise_pred_single( 122 | latent, t, cond_embeddings 123 | ) 124 | 125 | latent = self.scheduler.step_forward(noise_pred, t, latent).prev_sample 126 | 127 | all_latent.append(latent) 128 | return all_latent 129 | 130 | @torch.no_grad() 131 | def inversion(self, image): 132 | latent = image2latent(image, self.model) 133 | image_rec = latent2image(latent, self.model) 134 | latents = self.loop(latent) 135 | return image_rec, latents 136 | 137 | def __call__( 138 | self, 139 | image_gt: PIL.Image.Image, 140 | prompt: Union[str, torch.Tensor], 141 | control_image: Optional[PIL.Image.Image] = None, 142 | verbose=False 143 | ): 144 | self.init_prompt(prompt) 145 | self.init_controlnet_cond(control_image) 146 | 147 | image_gt = np.array(image_gt) 148 | image_rec, latents = self.inversion(image_gt) 149 | 150 | return image_rec, latents, None 151 | 152 | 153 | class NullInversion(Inversion): 154 | def null_optimization(self, latents, verbose=False): 155 | self.scheduler.set_timesteps(self.infer_steps) 156 | 157 | uncond_embeddings, cond_embeddings = self.context.chunk(2) 158 | uncond_embeddings_list = [] 159 | latent_cur = latents[-1] 160 | bar = tqdm(total=self.opt_scheduler.max_inner_steps * self.infer_steps) 161 | 162 | for i in range(self.infer_steps): 163 | uncond_embeddings = uncond_embeddings.clone().detach() 164 | uncond_embeddings.requires_grad = True 165 | optimizer = Adam([uncond_embeddings], lr=1e-2 * (1. 
- i / 100.)) 166 | latent_prev = latents[len(latents) - i - 2] 167 | t = self.scheduler.timesteps[i] 168 | 169 | with torch.no_grad(): 170 | noise_pred_cond = self.get_noise_pred_single(latent_cur, t, cond_embeddings) 171 | 172 | for j in range(self.opt_scheduler.max_inner_steps): 173 | noise_pred_uncond = self.get_noise_pred_single(latent_cur, t, uncond_embeddings) 174 | noise_pred = noise_pred_uncond + self.backward_guidance * (noise_pred_cond - noise_pred_uncond) 175 | latents_prev_rec = self.scheduler.step_backward(noise_pred, t, latent_cur, first_time=(j == 0), last_time=False).prev_sample 176 | loss = nnf.mse_loss(latents_prev_rec, latent_prev) 177 | optimizer.zero_grad() 178 | loss.backward() 179 | optimizer.step() 180 | loss_item = loss.item() 181 | 182 | bar.update() 183 | if self.opt_scheduler(i, j, loss_item): 184 | break 185 | for j in range(j + 1, self.opt_scheduler.max_inner_steps): 186 | bar.update() 187 | uncond_embeddings_list.append(uncond_embeddings[:1].detach()) 188 | 189 | with torch.no_grad(): 190 | context = torch.cat([uncond_embeddings, cond_embeddings]) 191 | noise_pred = self.get_noise_pred_guided(latent_cur, t, self.backward_guidance, context) 192 | latent_cur = self.scheduler.step_backward(noise_pred, t, latent_cur, first_time=False, last_time=True).prev_sample 193 | 194 | bar.close() 195 | return uncond_embeddings_list 196 | 197 | def __call__( 198 | self, 199 | image_gt: PIL.Image.Image, 200 | prompt: Union[str, torch.Tensor], 201 | opt_scheduler_name: str = 'loss', 202 | opt_num_inner_steps: int = 10, 203 | opt_early_stop_epsilon: float = 1e-5, 204 | opt_plateau_prop: float = 1/5, 205 | control_image: Optional[PIL.Image.Image] = None, 206 | verbose: bool = False 207 | ): 208 | image_rec, latents, _ = super().__call__(image_gt, prompt, control_image, verbose) 209 | 210 | if verbose: 211 | print("Null-text optimization...") 212 | 213 | self.opt_scheduler = opt_registry[opt_scheduler_name]( 214 | self.infer_steps, opt_num_inner_steps, 215 | opt_early_stop_epsilon, opt_plateau_prop 216 | ) 217 | 218 | uncond_embeddings = self.null_optimization( 219 | latents, 220 | verbose 221 | ) 222 | return image_rec, latents, uncond_embeddings 223 | -------------------------------------------------------------------------------- /diffusion_core/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | from .sample_schedulers import DDIMScheduler -------------------------------------------------------------------------------- /diffusion_core/schedulers/opt_schedulers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from diffusion_core.utils.class_registry import ClassRegistry 3 | 4 | 5 | opt_registry = ClassRegistry() 6 | 7 | @opt_registry.add_to_registry('constant') 8 | class OptScheduler: 9 | def __init__(self, max_ddim_steps, max_inner_steps, early_stop_epsilon, plateau_prop): 10 | self.max_inner_steps = max_inner_steps 11 | self.max_ddim_steps = max_ddim_steps 12 | self.early_stop_epsilon = early_stop_epsilon 13 | self.plateau_prop = plateau_prop 14 | self.inner_steps_list = np.full(self.max_ddim_steps, self.max_inner_steps) 15 | 16 | def __call__(self, ddim_step, inner_step, loss=None): 17 | return inner_step + 1 >= self.inner_steps_list[ddim_step] 18 | 19 | 20 | @opt_registry.add_to_registry('loss') 21 | class LossOptScheduler(OptScheduler): 22 | def __call__(self, ddim_step, inner_step, loss): 23 | if loss < self.early_stop_epsilon + ddim_step * 2e-5: 24 | 
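# Early stop for the inner null-text optimisation at this DDIM step: the loss
# threshold is relaxed by 2e-5 per step, so later timesteps may terminate sooner.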
return True 25 | self.inner_steps_list[ddim_step] = inner_step + 1 26 | return False 27 | -------------------------------------------------------------------------------- /diffusion_core/schedulers/sample_schedulers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import List, Optional, Tuple, Union 4 | from dataclasses import dataclass 5 | from diffusers.utils.outputs import BaseOutput 6 | from diffusers.schedulers.scheduling_utils import SchedulerMixin 7 | from diffusers.configuration_utils import ConfigMixin, register_to_config 8 | 9 | 10 | @dataclass 11 | class SchedulerOutput(BaseOutput): 12 | prev_sample: torch.FloatTensor 13 | pred_original_sample: Optional[torch.FloatTensor] = None 14 | 15 | def rescale_zero_terminal_snr(betas): 16 | """ 17 | Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) 18 | 19 | 20 | Args: 21 | betas (`torch.FloatTensor`): 22 | the betas that the scheduler is being initialized with. 23 | 24 | Returns: 25 | `torch.FloatTensor`: rescaled betas with zero terminal SNR 26 | """ 27 | # Convert betas to alphas_bar_sqrt 28 | alphas = 1.0 - betas 29 | alphas_cumprod = torch.cumprod(alphas, dim=0) 30 | alphas_bar_sqrt = alphas_cumprod.sqrt() 31 | 32 | # Store old values. 33 | alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() 34 | alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() 35 | 36 | # Shift so the last timestep is zero. 37 | alphas_bar_sqrt -= alphas_bar_sqrt_T 38 | 39 | # Scale so the first timestep is back to the old value. 40 | alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) 41 | 42 | # Convert alphas_bar_sqrt to betas 43 | alphas_bar = alphas_bar_sqrt**2 # Revert sqrt 44 | alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod 45 | alphas = torch.cat([alphas_bar[0:1], alphas]) 46 | betas = 1 - alphas 47 | 48 | return betas 49 | 50 | 51 | class SampleScheduler(SchedulerMixin, ConfigMixin): 52 | @register_to_config 53 | def __init__( 54 | self, 55 | num_train_timesteps: int = 1000, 56 | num_inference_steps: int = 50, 57 | beta_start: float = 0.0001, 58 | beta_end: float = 0.02, 59 | beta_schedule: str = 'linear', 60 | rescale_betas_zero_snr: bool = False, 61 | timestep_spacing: str = 'leading', 62 | set_alpha_to_one: bool = False, 63 | prediction_type: str = 'epsilon' 64 | ): 65 | if beta_schedule == "linear": 66 | self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) 67 | elif beta_schedule == "scaled_linear": 68 | # this schedule is very specific to the latent diffusion model. 
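# "scaled_linear" interpolates linearly in sqrt(beta) space and squares the result,
# reproducing the beta schedule used to train Stable Diffusion.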
69 | self.betas = ( 70 | torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 71 | ) 72 | elif beta_schedule == "squaredcos_cap_v2": 73 | # Glide cosine schedule 74 | self.betas = betas_for_alpha_bar(num_train_timesteps) 75 | else: 76 | raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") 77 | 78 | 79 | self.alphas = 1.0 - self.betas 80 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) 81 | self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] 82 | 83 | if num_inference_steps > num_train_timesteps: 84 | raise ValueError( 85 | f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.train_timesteps`:" 86 | f" {num_train_timesteps} as the unet model trained with this scheduler can only handle" 87 | f" maximal {num_train_timesteps} timesteps." 88 | ) 89 | 90 | self.timestep_spacing = timestep_spacing 91 | 92 | self.num_train_timesteps = num_train_timesteps 93 | self.num_inference_steps = num_inference_steps 94 | 95 | def step_backward( 96 | self, 97 | model_output: torch.FloatTensor, 98 | timestep: int, 99 | sample: torch.FloatTensor, 100 | first_time: bool = True, 101 | last_time: bool = True, 102 | return_dict: bool = True 103 | ): 104 | raise NotImplementedError(f"step_backward does is not implemented for {self.__class__}") 105 | 106 | def step_forward( 107 | self, 108 | model_output: torch.FloatTensor, 109 | timestep: int, 110 | sample: torch.FloatTensor, 111 | return_dict: bool = True 112 | ): 113 | raise NotImplementedError(f"step_forward does is not implemented for {self.__class__}") 114 | 115 | def set_timesteps( 116 | self, 117 | num_inference_steps: int = None, 118 | device: Union[str, torch.device] = None 119 | ): 120 | raise NotImplementedError(f"step_forward does is not implemented for {self.__class__}") 121 | 122 | 123 | class DDIMScheduler(SampleScheduler): 124 | @register_to_config 125 | def __init__(self, **args): 126 | super().__init__(**args) 127 | self.set_timesteps(self.num_inference_steps) 128 | self.step = self.step_backward 129 | self.init_noise_sigma = 1. 130 | self.order = 1 131 | self.scale_model_input = lambda x, t: x 132 | 133 | def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): 134 | """ 135 | Sets the discrete timesteps used for the diffusion chain (to be run before inference). 136 | 137 | Args: 138 | num_inference_steps (`int`): 139 | The number of diffusion steps used when generating samples with a pre-trained model. 140 | """ 141 | 142 | if num_inference_steps > self.num_train_timesteps: 143 | raise ValueError( 144 | f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" 145 | f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" 146 | f" maximal {self.num_train_timesteps} timesteps." 147 | ) 148 | 149 | self.num_inference_steps = num_inference_steps 150 | 151 | # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
of https://arxiv.org/abs/2305.08891 152 | if self.timestep_spacing == "linspace": 153 | timesteps = ( 154 | np.linspace(0, self.num_train_timesteps - 1, num_inference_steps) 155 | .round()[::-1] 156 | .copy() 157 | .astype(np.int64) 158 | ) 159 | elif self.timestep_spacing == "leading": 160 | step_ratio = self.num_train_timesteps // self.num_inference_steps 161 | # creates integer timesteps by multiplying by ratio 162 | # casting to int to avoid issues when num_inference_step is power of 3 163 | timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) 164 | # timesteps += self.steps_offset 165 | elif self.timestep_spacing == "trailing": 166 | step_ratio = self.num_train_timesteps / self.num_inference_steps 167 | # creates integer timesteps by multiplying by ratio 168 | # casting to int to avoid issues when num_inference_step is power of 3 169 | timesteps = np.round(np.arange(self.num_train_timesteps, 0, -step_ratio)).astype(np.int64) 170 | timesteps -= 1 171 | else: 172 | raise ValueError( 173 | f"{self.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." 174 | ) 175 | 176 | self.timesteps = torch.from_numpy(timesteps).to(device) 177 | 178 | def step_backward( 179 | self, 180 | model_output: torch.FloatTensor, 181 | timestep: int, 182 | sample: torch.FloatTensor, 183 | return_dict: bool = True, 184 | **kwargs 185 | ): 186 | # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf 187 | # Ideally, read DDIM paper in-detail understanding 188 | 189 | # Notation ( -> 190 | # - pred_noise_t -> e_theta(x_t, t) 191 | # - pred_original_sample -> f_theta(x_t, t) or x_0 192 | # - std_dev_t -> sigma_t 193 | # - eta -> η 194 | # - pred_sample_direction -> "direction pointing to x_t" 195 | # - pred_prev_sample -> "x_t-1" 196 | 197 | # 1. get previous step value (=t-1) 198 | prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps 199 | 200 | # 2. compute alphas, betas 201 | alpha_prod_t = self.alphas_cumprod[timestep] 202 | alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod 203 | beta_prod_t = 1 - alpha_prod_t 204 | 205 | # 3. 
compute predicted original sample from predicted noise also called 206 | if self.config.prediction_type == "epsilon": 207 | # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 208 | pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) 209 | pred_epsilon = model_output 210 | elif self.config.prediction_type == "sample": 211 | pred_original_sample = model_output 212 | pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) 213 | elif self.config.prediction_type == "v_prediction": 214 | pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output 215 | pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample 216 | else: 217 | raise ValueError( 218 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" 219 | " `v_prediction`" 220 | ) 221 | 222 | if kwargs.get("ref_image", None) is not None and kwargs.get("recon_lr", 0.0) > 0.0: 223 | ref_image = kwargs.get("ref_image").expand_as(pred_original_sample) 224 | recon_lr = kwargs.get("recon_lr", 0.0) 225 | recon_mask = kwargs.get("recon_mask", None) 226 | if recon_mask is not None: 227 | recon_mask = recon_mask.expand_as(pred_original_sample).float() 228 | pred_original_sample = pred_original_sample - recon_lr * (pred_original_sample - ref_image) * recon_mask 229 | else: 230 | pred_original_sample = pred_original_sample - recon_lr * (pred_original_sample - ref_image) 231 | 232 | # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 233 | pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon 234 | 235 | # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 236 | prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction 237 | 238 | if not return_dict: 239 | return (prev_sample, pred_original_sample) 240 | 241 | return SchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) 242 | 243 | def step_forward( 244 | self, 245 | model_output: torch.FloatTensor, 246 | timestep: int, 247 | sample: torch.FloatTensor, 248 | return_dict: bool = True 249 | ): 250 | # 1. get previous step value (=t+1) 251 | prev_timestep = timestep + self.num_train_timesteps // self.num_inference_steps 252 | 253 | # 2. compute alphas, betas 254 | # change original implementation to exactly match noise levels for analogous forward process 255 | alpha_prod_t = self.alphas_cumprod[timestep] 256 | alpha_prod_t_prev = ( 257 | self.alphas_cumprod[prev_timestep] 258 | if prev_timestep < self.num_train_timesteps 259 | else self.alphas_cumprod[-1] 260 | ) 261 | 262 | beta_prod_t = 1 - alpha_prod_t 263 | 264 | # 3. 
compute predicted original sample from predicted noise also called 265 | # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 266 | if self.config.prediction_type == "epsilon": 267 | # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 268 | pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) 269 | pred_epsilon = model_output 270 | elif self.config.prediction_type == "sample": 271 | pred_original_sample = model_output 272 | pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) 273 | elif self.config.prediction_type == "v_prediction": 274 | pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output 275 | pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample 276 | else: 277 | raise ValueError( 278 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" 279 | " `v_prediction`" 280 | ) 281 | 282 | # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 283 | pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon 284 | 285 | # 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 286 | prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction 287 | 288 | if not return_dict: 289 | return (prev_sample, pred_original_sample) 290 | 291 | return SchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) 292 | 293 | def __repr__(self): 294 | return f"{self.__class__.__name__} {self.to_json_string()}" 295 | -------------------------------------------------------------------------------- /diffusion_core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .image_utils import load_512 2 | from .class_registry import ClassRegistry 3 | from .grad_checkpoint import checkpoint_forward, use_grad_checkpointing 4 | from .model_utils import use_deterministic, toggle_grad 5 | -------------------------------------------------------------------------------- /diffusion_core/utils/class_registry.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import typing 3 | import omegaconf 4 | import dataclasses 5 | import typing as tp 6 | 7 | 8 | class ClassRegistry: 9 | def __init__(self): 10 | self.classes = dict() 11 | self.args = dict() 12 | self.arg_keys = None 13 | 14 | def __getitem__(self, item): 15 | return self.classes[item] 16 | 17 | def make_dataclass_from_func(self, func, name, arg_keys): 18 | args = inspect.signature(func).parameters 19 | args = [ 20 | (k, typing.Any, omegaconf.MISSING) 21 | if v.default is inspect.Parameter.empty 22 | else (k, typing.Optional[typing.Any], None) 23 | if v.default is None 24 | else ( 25 | k, 26 | type(v.default), 27 | dataclasses.field(default=v.default), 28 | ) 29 | for k, v in args.items() 30 | ] 31 | args = [ 32 | arg 33 | for arg in args 34 | if (arg[0] != "self" and arg[0] != "args" and arg[0] != "kwargs") 35 | ] 36 | if arg_keys: 37 | self.arg_keys = arg_keys 38 | arg_classes = dict() 39 | for key in arg_keys: 40 | arg_classes[key] = dataclasses.make_dataclass(key, args) 41 | return dataclasses.make_dataclass( 42 | name, 43 | [ 44 | (k, v, dataclasses.field(default=v())) 45 | for k, v in arg_classes.items() 46 | ], 47 | ) 48 | return dataclasses.make_dataclass(name, args) 49 | 50 | def 
make_dataclass_from_classes(self): 51 | return dataclasses.make_dataclass( 52 | 'Name', 53 | [ 54 | (k, v, dataclasses.field(default=v())) 55 | for k, v in self.classes.items() 56 | ], 57 | ) 58 | 59 | def make_dataclass_from_args(self): 60 | return dataclasses.make_dataclass( 61 | 'Name', 62 | [ 63 | (k, v, dataclasses.field(default=v())) 64 | for k, v in self.args.items() 65 | ], 66 | ) 67 | 68 | def _add_single_obj(self, obj, name, arg_keys): 69 | self.classes[name] = obj 70 | if inspect.isfunction(obj): 71 | self.args[name] = self.make_dataclass_from_func( 72 | obj, name, arg_keys 73 | ) 74 | elif inspect.isclass(obj): 75 | self.args[name] = self.make_dataclass_from_func( 76 | obj.__init__, name, arg_keys 77 | ) 78 | 79 | def add_to_registry(self, names: tp.Union[str, tp.List[str]], arg_keys=None): 80 | if not isinstance(names, list): 81 | names = [names] 82 | 83 | def decorator(obj): 84 | for name in names: 85 | self._add_single_obj(obj, name, arg_keys) 86 | 87 | return obj 88 | return decorator 89 | 90 | def __contains__(self, name: str): 91 | return name in self.args.keys() 92 | 93 | def __repr__(self): 94 | return f"{list(self.args.keys())}" 95 | -------------------------------------------------------------------------------- /diffusion_core/utils/grad_checkpoint.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from inspect import signature 3 | from torch.utils import checkpoint 4 | 5 | use_mode: bool = True 6 | 7 | 8 | def use_grad_checkpointing(mode: bool=True): 9 | global use_mode 10 | use_mode = mode 11 | 12 | 13 | def checkpoint_forward(func): 14 | sig = signature(func) 15 | 16 | @functools.wraps(func) 17 | def wrapper(*args, **kwargs): 18 | if use_mode: 19 | bound = sig.bind(*args, **kwargs) 20 | bound.apply_defaults() 21 | new_args = bound.arguments.values() 22 | result = checkpoint.checkpoint(func, *new_args, use_reentrant=False) 23 | else: 24 | result = func(*args, **kwargs) 25 | return result 26 | 27 | return wrapper 28 | -------------------------------------------------------------------------------- /diffusion_core/utils/image_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from PIL import Image 4 | 5 | 6 | 7 | def load_512(image_path, left=0, right=0, top=0, bottom=0): 8 | if type(image_path) is str: 9 | image = np.array(Image.open(image_path).convert('RGB'))[:, :, :3] 10 | else: 11 | image = image_path 12 | h, w, c = image.shape 13 | left = min(left, w-1) 14 | right = min(right, w - left - 1) 15 | top = min(top, h - left - 1) 16 | bottom = min(bottom, h - top - 1) 17 | image = image[top:h-bottom, left:w-right] 18 | h, w, c = image.shape 19 | if h < w: 20 | offset = (w - h) // 2 21 | image = image[:, offset:offset + h] 22 | elif w < h: 23 | offset = (h - w) // 2 24 | image = image[offset:offset + w] 25 | image = np.array(Image.fromarray(image).resize((512, 512))) 26 | return image 27 | 28 | 29 | def to_im(torch_image, **kwargs): 30 | return transforms.ToPILImage()( 31 | make_grid(torch_image, **kwargs) 32 | ) 33 | -------------------------------------------------------------------------------- /diffusion_core/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def use_deterministic(): 5 | torch.backends.cudnn.benchmark = False 6 | torch.backends.cudnn.deterministic = True 7 | 8 | 9 | def toggle_grad(model, mode=True): 10 | for p in model.parameters(): 11 | 
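# Flip requires_grad for every parameter of the model, enabling (mode=True) or
# freezing (mode=False) gradient computation.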
p.requires_grad = mode 12 | -------------------------------------------------------------------------------- /docs/diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/Guide-and-Rescale/4a79013c641d20f57b206ad7f0bf2e9d01cd412c/docs/diagram.png -------------------------------------------------------------------------------- /docs/teaser_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/Guide-and-Rescale/4a79013c641d20f57b206ad7f0bf2e9d01cd412c/docs/teaser_image.png -------------------------------------------------------------------------------- /example_images/face.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/Guide-and-Rescale/4a79013c641d20f57b206ad7f0bf2e9d01cd412c/example_images/face.png -------------------------------------------------------------------------------- /example_images/zebra.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/Guide-and-Rescale/4a79013c641d20f57b206ad7f0bf2e9d01cd412c/example_images/zebra.jpeg -------------------------------------------------------------------------------- /sd_env.yaml: -------------------------------------------------------------------------------- 1 | name: ldm 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip=20.3 8 | - cudatoolkit=11.3 9 | - numpy=1.19.2 10 | - pip: 11 | - torch==2.3.0 12 | - torchvision==0.18.0 13 | - image-reward 14 | - albumentations==0.4.3 15 | - diffusers==0.17.1 16 | - matplotlib 17 | - opencv-python==4.1.2.30 18 | - pudb==2019.2 19 | - invisible-watermark 20 | - imageio==2.9.0 21 | - imageio-ffmpeg==0.4.2 22 | - pytorch-lightning 23 | - omegaconf==2.1.1 24 | - test-tube>=0.7.5 25 | - streamlit>=0.73.1 26 | - einops==0.3.0 27 | - torch-fidelity==0.3.0 28 | - transformers==4.29.0 29 | - torchmetrics==0.6.0 30 | - kornia==0.6 31 | - accelerate 32 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers 33 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip 34 | - ipykernel 35 | - ipywidgets --------------------------------------------------------------------------------
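A minimal end-to-end sketch of how the inversion utilities above fit together (hypothetical wiring: model is assumed to be the project's Stable Diffusion wrapper exposing .unet, .tokenizer, .text_encoder and .scheduler as used by Inversion, while the image path, prompt and guidance scale are example values):

from PIL import Image
from diffusion_core.utils import load_512, use_deterministic
from diffusion_core.inversion import NegativePromptInversion

use_deterministic()
image = load_512('example_images/face.png')  # crop to square and resize to 512x512 (numpy array)
inversion = NegativePromptInversion(model, inference_steps=50, inference_guidance_scale=7.5)
image_rec, latents, uncond_embeddings = inversion(Image.fromarray(image), 'a photo of a face', verbose=True)
# latents[-1] is the inverted noise latent; uncond_embeddings repeats the conditional
# text embedding for every step, as produced by negative_prompt_inversion above.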