├── LICENSE ├── README.md ├── assets ├── GoodDrag_demo.gif ├── cat_2 │ ├── image_with_new_points.png │ ├── image_with_points.jpg │ ├── original.jpg │ └── trajectory.gif ├── chess_1 │ ├── image_with_new_points.png │ ├── image_with_points.jpg │ ├── original.jpg │ └── trajectory.gif ├── furniture_0 │ ├── image_with_new_points.png │ ├── image_with_points.jpg │ ├── original.jpg │ └── trajectory.gif ├── gooddrag_icon.png ├── human_6 │ ├── image_with_new_points.png │ ├── image_with_points.jpg │ ├── original.jpg │ └── trajectory.gif ├── leopard │ ├── image_with_new_points.png │ ├── image_with_points.jpg │ ├── original.jpg │ └── trajectory.gif └── rabbit │ ├── image_with_new_points.png │ ├── image_with_points.jpg │ ├── original.jpg │ └── trajectory.gif ├── bench_gooddrag.py ├── dataset ├── cat_2 │ ├── image_with_points.jpg │ ├── mask.png │ ├── original.jpg │ └── points.json └── furniture_0 │ ├── image_with_points.jpg │ ├── mask.png │ ├── original.jpg │ └── points.json ├── environment.yaml ├── evaluation ├── GScore.ipynb └── compute_DAI.py ├── gooddrag_ui.py ├── pipeline.py ├── requirements.txt ├── utils ├── __pycache__ │ ├── attn_utils.cpython-310.pyc │ ├── drag_utils.cpython-310.pyc │ ├── lora_utils.cpython-310.pyc │ └── ui_utils.cpython-310.pyc ├── attn_utils.py ├── drag_utils.py ├── lora_utils.py └── ui_utils.py ├── webui.bat └── webui.sh /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# GoodDrag: Towards Good Practices for Drag Editing with Diffusion Models

Zewei Zhang · Huan Liu · Jun Chen · Xiangyu Xu

arXiv Preprint · Google Colab · Download Dataset
68 | 69 | ## 📢 Latest Updates 70 | - **2024.04.17** - Updated DAI (Dragging Accuracy Index) and GScore (Gemini Score) evaluation methods. Please check the evaluation file. GScore is modified from [Generative AI](https://github.com/GoogleCloudPlatform/generative-ai/tree/main). 71 | - **2025.03.01** - Our paper is accepted by ICLR2025! 72 | 73 | ## 1. Getting Started with GoodDrag 74 | 75 | Before getting started, please make sure your system is equipped with a CUDA-compatible GPU and Python 3.9 or higher. We provide three methods to directly run GoodDrag: 76 | ### 1️⃣ Automated Script for Effortless Setup 77 | 78 | - **Windows Users:** Double-click **webui.bat** to automatically set up your environment and launch the GoodDrag web UI. 79 | - **Linux Users:** Run **webui.sh** for a similar one-step setup and launch process. 80 | 81 | ### 2️⃣ Manual Installation via pip 82 | 1. Install the necessary dependencies: 83 | 84 | ```bash 85 | pip install -r requirements.txt 86 | ``` 87 | 2. Launch the GoodDrag web UI: 88 | 89 | ```bash 90 | python gooddrag_ui.py 91 | ``` 92 | 93 | ### 3️⃣ Quick Start with Colab 94 | For a quick and easy start, access GoodDrag directly through Google Colab. Click the badge below to open a pre-configured notebook that will guide you through using GoodDrag in the Colab environment: google colab logo 95 | 96 | ### Runtime and Memory Requirements 97 | GoodDrag's efficiency depends on the image size and editing complexity. For a 512x512 image on an A100 GPU: the LoRA phase requires ~17 seconds, and drag editing takes around 1 minute. GPU memory requirement is below 13GB. 98 | 99 | ## 2. Parameter Description 100 | 101 | We have predefined a set of parameters in the GoodDrag WebUI. Here are a few that you might consider adjusting: 102 | | Parameter Name | Description | 103 | | -------------- | ------------------------------------------------------------ | 104 | | Learning Rate | Influences the speed of drag editing. Higher values lead to faster editing but may result in lower quality or instability. It is recommended to keep this value below 0.05. | 105 | | Prompt | The text prompt for the diffusion model. It is suggested to leave this empty. | 106 | | End time step | Specifies the length of the time step during the denoise phase of the diffusion model for drag editing. If good results are obtained early in the generated video, consider reducing this value. Conversely, if the drag editing is insufficient, increase it slightly. It is recommended to keep this value below 12. | 107 | | Lambda | Controls the consistency of the non-dragged regions with the original image. A higher value keeps the area outside the mask more in line with the original image. | 108 | 109 | ## 3. Acknowledgments 110 | Part of the code was based on [DragDiffusion](https://github.com/Yujun-Shi/DragDiffusion) and [DragGAN](https://github.com/XingangPan/DragGAN). Thanks for the great work! 111 | 112 | ## 4. 
BibTeX 113 | ```bibtex 114 | @inproceedings{zhang2025gooddrag, 115 | title={GoodDrag: Towards Good Practices for Drag Editing with Diffusion Models}, 116 | author={Zewei Zhang and Huan Liu and Jun Chen and Xiangyu Xu}, 117 | booktitle={The Thirteenth International Conference on Learning Representations}, 118 | year={2025}, 119 | } 120 | ``` 121 | -------------------------------------------------------------------------------- /assets/GoodDrag_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/GoodDrag_demo.gif -------------------------------------------------------------------------------- /assets/cat_2/image_with_new_points.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/cat_2/image_with_new_points.png -------------------------------------------------------------------------------- /assets/cat_2/image_with_points.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/cat_2/image_with_points.jpg -------------------------------------------------------------------------------- /assets/cat_2/original.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/cat_2/original.jpg -------------------------------------------------------------------------------- /assets/cat_2/trajectory.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/cat_2/trajectory.gif -------------------------------------------------------------------------------- /assets/chess_1/image_with_new_points.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/chess_1/image_with_new_points.png -------------------------------------------------------------------------------- /assets/chess_1/image_with_points.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/chess_1/image_with_points.jpg -------------------------------------------------------------------------------- /assets/chess_1/original.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/chess_1/original.jpg -------------------------------------------------------------------------------- /assets/chess_1/trajectory.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/chess_1/trajectory.gif -------------------------------------------------------------------------------- /assets/furniture_0/image_with_new_points.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/furniture_0/image_with_new_points.png -------------------------------------------------------------------------------- /assets/furniture_0/image_with_points.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/furniture_0/image_with_points.jpg -------------------------------------------------------------------------------- /assets/furniture_0/original.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/furniture_0/original.jpg -------------------------------------------------------------------------------- /assets/furniture_0/trajectory.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/furniture_0/trajectory.gif -------------------------------------------------------------------------------- /assets/gooddrag_icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/gooddrag_icon.png -------------------------------------------------------------------------------- /assets/human_6/image_with_new_points.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/human_6/image_with_new_points.png -------------------------------------------------------------------------------- /assets/human_6/image_with_points.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/human_6/image_with_points.jpg -------------------------------------------------------------------------------- /assets/human_6/original.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/human_6/original.jpg -------------------------------------------------------------------------------- /assets/human_6/trajectory.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/human_6/trajectory.gif -------------------------------------------------------------------------------- /assets/leopard/image_with_new_points.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/leopard/image_with_new_points.png -------------------------------------------------------------------------------- /assets/leopard/image_with_points.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/leopard/image_with_points.jpg -------------------------------------------------------------------------------- /assets/leopard/original.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/leopard/original.jpg -------------------------------------------------------------------------------- /assets/leopard/trajectory.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/leopard/trajectory.gif -------------------------------------------------------------------------------- /assets/rabbit/image_with_new_points.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/rabbit/image_with_new_points.png -------------------------------------------------------------------------------- /assets/rabbit/image_with_points.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/rabbit/image_with_points.jpg -------------------------------------------------------------------------------- /assets/rabbit/original.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/rabbit/original.jpg -------------------------------------------------------------------------------- /assets/rabbit/trajectory.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/rabbit/trajectory.gif -------------------------------------------------------------------------------- /bench_gooddrag.py: -------------------------------------------------------------------------------- 1 | # ************************************************************************* 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ************************************************************************* 14 | 15 | 16 | import os 17 | import sys 18 | import cv2 19 | import numpy as np 20 | import json 21 | from pathlib import Path 22 | from PIL import Image 23 | from utils.ui_utils import run_gooddrag, train_lora_interface, show_cur_points, create_video 24 | 25 | 26 | def benchmark_dataset(dataset_folder): 27 | dataset_path = Path(dataset_folder) 28 | subfolders = [f for f in dataset_path.iterdir() if f.is_dir() and f.name != '.ipynb_checkpoints'] 29 | 30 | for subfolder in subfolders: 31 | print(f'Benchmarking {subfolder.name}...') 32 | try: 33 | bench_one_image(subfolder) 34 | except Exception as e: 35 | print(f'An error occured while benchmarking {subfolder.name}: {e}.') 36 | 37 | 38 | def load_data(folder): 39 | """Load the original image, mask, and points from the specified folder.""" 40 | folder_path = Path(folder) 41 | 42 | # Load original image 43 | original_image_path = folder_path / 'original.jpg' 44 | original_image = Image.open(original_image_path) 45 | original_image = np.array(original_image) 46 | 47 | # Load mask 48 | mask_path = folder_path / 'mask.png' 49 | mask = Image.open(mask_path) 50 | mask = np.array(mask) 51 | if len(mask.shape) == 3: 52 | mask = mask[:, :, 0] 53 | 54 | # Load points 55 | points_path = folder_path / 'points.json' 56 | with open(points_path, 'r') as f: 57 | points_data = json.load(f) 58 | points = points_data['points'] 59 | 60 | image_points_path = folder_path / 'image_with_points.jpg' 61 | image_with_points = Image.open(image_points_path) 62 | image_with_points = np.array(image_with_points) 63 | 64 | return original_image, mask, points, image_with_points 65 | 66 | 67 | def bench_one_image(folder): 68 | """ 69 | Test the saved data by running the drag model. 70 | 71 | Args: 72 | folder: The folder where the original image, mask, and points are saved. 73 | """ 74 | original_image, mask, points, image_with_points = load_data(folder) 75 | model_path = 'runwayml/stable-diffusion-v1-5' 76 | 77 | lora_path = f'./lora_data/{folder.parts[-1]}' 78 | 79 | print(f'Training Lora.') 80 | train_lora_interface(original_image=original_image, prompt='', model_path=model_path, 81 | vae_path='stabilityai/sd-vae-ft-mse', 82 | lora_path=lora_path, lora_step=70, lora_lr=0.0005, lora_batch_size=4, lora_rank=16, 83 | use_gradio_progress=False) 84 | print(f'Training Lora Done! 
Begin dragging.') 85 | 86 | return_intermediate_images = True 87 | 88 | result_dir = f'./bench_result/{Path(folder).parts[-1]}' 89 | os.makedirs(result_dir, exist_ok=True) 90 | 91 | output_image, new_points = run_gooddrag( 92 | source_image=original_image, 93 | image_with_clicks=image_with_points, 94 | mask=mask, 95 | prompt='', 96 | points=points, 97 | inversion_strength=0.75, 98 | lam=0.1, 99 | latent_lr=0.02, 100 | model_path=model_path, 101 | vae_path='stabilityai/sd-vae-ft-mse', 102 | lora_path=lora_path, 103 | drag_end_step=7, 104 | track_per_step=10, 105 | save_intermedia=False, 106 | compare_mode=False, 107 | r1=4, 108 | r2=12, 109 | d=4, 110 | max_drag_per_track=3, 111 | drag_loss_threshold=0, 112 | once_drag=False, 113 | max_track_no_change=5, 114 | return_intermediate_images=return_intermediate_images, 115 | result_save_path=result_dir 116 | ) 117 | 118 | print(f'Drag finished!') 119 | output_image = cv2.cvtColor(output_image, cv2.COLOR_RGB2BGR) 120 | output_image_path = os.path.join(result_dir, 'output_image.png') 121 | cv2.imwrite(output_image_path, output_image) 122 | 123 | img_with_new_points = show_cur_points(np.ascontiguousarray(output_image), new_points, bgr=True) 124 | new_points_image_path = os.path.join(result_dir, 'image_with_new_points.png') 125 | cv2.imwrite(new_points_image_path, img_with_new_points) 126 | 127 | points_path = os.path.join(result_dir, f'new_points.json') 128 | with open(points_path, 'w') as f: 129 | json.dump({'points': new_points}, f) 130 | 131 | if return_intermediate_images: 132 | create_video(result_dir, folder) 133 | 134 | 135 | def main(dataset_folder): 136 | benchmark_dataset(dataset_folder) 137 | 138 | 139 | if __name__ == '__main__': 140 | dataset = sys.argv[1] 141 | main(dataset) 142 | -------------------------------------------------------------------------------- /dataset/cat_2/image_with_points.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/dataset/cat_2/image_with_points.jpg -------------------------------------------------------------------------------- /dataset/cat_2/mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/dataset/cat_2/mask.png -------------------------------------------------------------------------------- /dataset/cat_2/original.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/dataset/cat_2/original.jpg -------------------------------------------------------------------------------- /dataset/cat_2/points.json: -------------------------------------------------------------------------------- 1 | {"points": [[147, 67], [202, 46], [410, 88], [336, 54]]} -------------------------------------------------------------------------------- /dataset/furniture_0/image_with_points.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/dataset/furniture_0/image_with_points.jpg -------------------------------------------------------------------------------- /dataset/furniture_0/mask.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/dataset/furniture_0/mask.png -------------------------------------------------------------------------------- /dataset/furniture_0/original.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/dataset/furniture_0/original.jpg -------------------------------------------------------------------------------- /dataset/furniture_0/points.json: -------------------------------------------------------------------------------- 1 | {"points": [[234, 133], [248, 71], [320, 133], [328, 70], [405, 137], [418, 67]]} -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: GoodDrag 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip=22.3.1 8 | - cudatoolkit=11.7 9 | - pip: 10 | - torch==2.0.0 11 | - accelerate 12 | - torchvision==0.15.1 13 | - gradio==3.50.2 14 | - pydantic==2.0.2 15 | - albumentations==1.3.0 16 | - opencv-contrib-python==4.3.0.36 17 | - imageio==2.9.0 18 | - imageio-ffmpeg==0.4.2 19 | - pytorch-lightning==1.5.0 20 | - omegaconf==2.3.0 21 | - test-tube>=0.7.5 22 | - streamlit==1.12.1 23 | - einops==0.6.0 24 | - transformers==4.27.0 25 | - webdataset==0.2.5 26 | - kornia==0.6 27 | - open_clip_torch==2.16.0 28 | - invisible-watermark>=0.1.5 29 | - streamlit-drawable-canvas==0.8.0 30 | - torchmetrics==0.6.0 31 | - timm==0.6.12 32 | - addict==2.4.0 33 | - yapf==0.32.0 34 | - prettytable==3.6.0 35 | - safetensors==0.2.7 36 | - basicsr==1.4.2 37 | - accelerate==0.17.0 38 | - decord==0.6.0 39 | - diffusers==0.20.0 40 | - moviepy==1.0.3 41 | - opencv_python==4.7.0.68 42 | - Pillow==9.4.0 43 | - scikit_image==0.19.3 44 | - scipy==1.10.1 45 | - tensorboardX==2.6 46 | - tqdm==4.64.1 47 | - numpy==1.24.1 48 | - PySoundFile 49 | -------------------------------------------------------------------------------- /evaluation/GScore.ipynb: -------------------------------------------------------------------------------- 1 | {"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"uxCkB_DXTHzf"},"outputs":[],"source":["# Copyright 2023 Google LLC\n","#\n","# Licensed under the Apache License, Version 2.0 (the \"License\");\n","# you may not use this file except in compliance with the License.\n","# You may obtain a copy of the License at\n","#\n","# https://www.apache.org/licenses/LICENSE-2.0\n","#\n","# Unless required by applicable law or agreed to in writing, software\n","# distributed under the License is distributed on an \"AS IS\" BASIS,\n","# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n","# See the License for the specific language governing permissions and\n","# limitations under the License."]},{"cell_type":"markdown","metadata":{"id":"Hny4I-ODTIS6"},"source":["# Visual Question Answering (VQA) with Imagen on Vertex AI\n","\n","\n"," \n"," \n"," \n","
Run in Colab\n"," View on GitHub\n"," Open in Vertex AI Workbench\n","
\n"]},{"cell_type":"markdown","metadata":{"id":"-nLS57E2TO5y"},"source":["## Overview\n","\n","[Imagen on Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/image/overview) (image Generative AI) offers a variety of features:\n","- Image generation\n","- Image editing\n","- Visual captioning\n","- Visual question answering\n","\n","This notebook focuses on **visual question answering** only.\n","\n","[Visual question answering (VQA) with Imagen](https://cloud.google.com/vertex-ai/docs/generative-ai/image/visual-question-answering) can understand the content of an image and answer questions about it. The model takes in an image and a question as input, and then using the image as context to produce one or more answers to the question.\n","\n","The visual question answering (VQA) can be used for a variety of use cases, including:\n","- assisting the visually impaired to gain more information about the images\n","- answering customer questions about products or services in the image\n","- creating interactive learning environment and providing interactive learning experiences"]},{"cell_type":"markdown","metadata":{"id":"iXsvgIuwTPZw"},"source":["### Objectives\n","\n","In this notebook, you will learn how to use the Vertex AI Python SDK to:\n","\n","- Answering questions about images using the Imagen's visual question answering features\n","\n","- Experiment with different parameters, such as:\n"," - number of answers to be provided by the model"]},{"cell_type":"markdown","metadata":{"id":"skXAu__iqks_"},"source":["### Costs\n","\n","This tutorial uses billable components of Google Cloud:\n","- Vertex AI (Imagen)\n","\n","Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing) and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage."]},{"cell_type":"markdown","metadata":{"id":"mvKl-BtQTRiQ"},"source":["## Getting Started"]},{"cell_type":"markdown","metadata":{"id":"PwFMpIMrTV_4"},"source":["### Install Vertex AI SDK, other packages and their dependencies"]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":21358,"status":"ok","timestamp":1704745393930,"user":{"displayName":"Zhang Zewei","userId":"07162348264595431723"},"user_tz":300},"id":"WYUu8VMdJs3V","colab":{"base_uri":"https://localhost:8080/"},"outputId":"16a31034-4e81-4920-912b-d2065e9eb9a3"},"outputs":[{"output_type":"stream","name":"stdout","text":["\u001b[33m WARNING: The script tb-gcp-uploader is installed in '/root/.local/bin' which is not on PATH.\n"," Consider adding this directory to PATH or, if you prefer to suppress this warning, use --no-warn-script-location.\u001b[0m\u001b[33m\n","\u001b[0m"]}],"source":["%pip install --upgrade --user google-cloud-aiplatform>=1.29.0"]},{"cell_type":"markdown","metadata":{"id":"R5Xep4W9lq-Z"},"source":["### Restart current runtime\n","\n","To use the newly installed packages in this Jupyter runtime, you must restart the runtime. 
You can do this by running the cell below, which will restart the current kernel."]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":3,"status":"ok","timestamp":1704745393930,"user":{"displayName":"Zhang Zewei","userId":"07162348264595431723"},"user_tz":300},"id":"XRvKdaPDTznN","outputId":"88bab8b5-a4f0-4462-bd3b-9e5b88bb59be"},"outputs":[{"output_type":"execute_result","data":{"text/plain":["{'status': 'ok', 'restart': True}"]},"metadata":{},"execution_count":2}],"source":["# Restart kernel after installs so that your environment can access the new packages\n","import IPython\n","import time\n","\n","app = IPython.Application.instance()\n","app.kernel.do_shutdown(True)"]},{"cell_type":"markdown","metadata":{"id":"SbmM4z7FOBpM"},"source":["
\n","⚠️ The kernel is going to restart. Please wait until it is finished before continuing to the next step. ⚠️\n","
\n","\n"]},{"cell_type":"markdown","metadata":{"id":"opUxT_k5TdgP"},"source":["### Authenticate your notebook environment (Colab only)\n","\n","If you are running this notebook on Google Colab, you will need to authenticate your environment. To do this, run the new cell below. This step is not required if you are using [Vertex AI Workbench](https://cloud.google.com/vertex-ai-workbench)."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"vbNgv4q1T2Mi"},"outputs":[],"source":["import sys\n","\n","if 'google.colab' in sys.modules:\n","\n"," # Authenticate user to Google Cloud\n"," from google.colab import auth\n"," auth.authenticate_user()"]},{"cell_type":"markdown","metadata":{"id":"ybBXSukZkgjg"},"source":["### Define Google Cloud project information (Colab only)\n","\n","If you are running this notebook on Google Colab, you need to define Google Cloud project information to be used. In the following cell, you will define the information, import Vertex AI package, and initialize it. This step is also not required if you are using [Vertex AI Workbench](https://cloud.google.com/vertex-ai-workbench)."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"MkVX07nOC90T"},"outputs":[],"source":["if 'google.colab' in sys.modules:\n","\n"," # Define project information\n"," PROJECT_ID = \"\" # @param {type:\"string\"}\n"," LOCATION = \"\" # @param {type:\"string\"}\n","\n"," # Initialize Vertex AI\n"," import vertexai\n"," vertexai.init(project=PROJECT_ID, location=LOCATION)"]},{"cell_type":"markdown","metadata":{"id":"uQrBo92-W_IL"},"source":["# Strat Here"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"hRnt-Xs9W-ag"},"outputs":[],"source":["!pip install --upgrade google-cloud-aiplatform"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"GZpOZbzKY79M"},"outputs":[],"source":["from google.colab import drive\n","drive.mount('/content/drive')"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"Vqer9SUFXH80"},"outputs":[],"source":["import base64\n","import os\n","import vertexai\n","from vertexai.preview.generative_models import GenerativeModel, Part\n","\n","def encode_image(image_path):\n"," with open(image_path, \"rb\") as image_file:\n"," return base64.b64encode(image_file.read()).decode('utf-8')\n","\n","\n","def generate(prompt, original_img, ours_img, dragdif_img, sde_img):\n"," model = GenerativeModel(\"gemini-pro-vision\")\n"," responses = model.generate_content(\n"," [prompt, original_img, ours_img, dragdif_img, sde_img],\n"," generation_config={\n"," \"max_output_tokens\": 2048,\n"," \"temperature\": 0.4,\n"," \"top_p\": 1,\n"," \"top_k\": 32\n"," },\n"," )\n","\n"," return responses.text\n","\n","\n","def write_2_txt(evaluation: str, path: str):\n"," # Extract directory from the path\n"," directory = os.path.dirname(path)\n","\n"," # Check if the directory exists, and create it if it doesn't\n"," if not os.path.exists(directory):\n"," os.makedirs(directory)\n","\n"," # Now, write to the file\n"," with open(path, 'w') as file:\n"," file.write(evaluation)\n"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"6bS1XftdXUI0"},"outputs":[],"source":["from pathlib import Path\n","from vertexai.preview.generative_models import Part\n","from tqdm import tqdm\n","dataset_dir = Path('/content/drive/MyDrive/dataset')\n","gooddrag_dir = Path('/content/drive/MyDrive/gooddrag')\n","dragdiffusion_dir = Path('/content/drive/MyDrive/dragdiffusion')\n","sde_dir = Path('/content/drive/MyDrive/final_drag/sde')\n","result_dir = 
Path('/content/drive/MyDrive/GScore')\n","\n","prompt = prompt = '''Conduct a detailed evaluation of three modified images, labeled 'A', 'B', and 'C', in comparison to an original image (Image 1). Image 1 serves as the baseline and will not be evaluated. Focus on assessing the quality of 'A' (Image 2), 'B' (Image 3), and 'C' (Image 4), particularly in terms of their naturalness and the presence or absence of artifacts. Examine how well each algorithm preserves the integrity of the original image while introducing modifications. Look for any signs of distortions, unnatural colors, pixelation, or other visual inconsistencies. Rate each image on a scale from 1 to 10, where 10 represents excellent quality with seamless modifications, and 1 indicates poor quality with significant and noticeable artifacts. Provide a comprehensive analysis for each rating, highlighting specific aspects of the image that influenced your evaluation. Answers must be in English.'''\n","for i in range(10):\n"," result_dir = Path(f'/content/drive/MyDrive/GScore/result_{i}')\n"," for item in tqdm(dataset_dir.iterdir(), desc='Evaluating:'):\n"," img_name = item.name\n"," original_image = Part.from_data(data=base64.b64decode(encode_image(dataset_dir / img_name/ 'original.jpg')), mime_type=\"image/jpeg\")\n"," gooddrag_image = Part.from_data(data=base64.b64decode(encode_image(gooddrag_dir / img_name/ 'output_image.png')), mime_type=\"image/png\")\n"," dragdiffusion_image = Part.from_data(data=base64.b64decode(encode_image(dragdiffusion_dir / img_name/ 'output_image.png')), mime_type=\"image/png\")\n"," sde_image = Part.from_data(data=base64.b64decode(encode_image(sde_dir / f'{img_name}.png')), mime_type=\"image/png\")\n"," evaluation_file_path = result_dir / f'{img_name}.txt'\n"," cur_evaluation = generate(prompt, original_image, gooddrag_image, dragdiffusion_image, sde_image)\n"," write_2_txt(cur_evaluation, evaluation_file_path)\n"]}],"metadata":{"colab":{"provenance":[{"file_id":"1SjGPEFeIdOOvJ4zVRpJagxjip0QUPCMp","timestamp":1713362599717},{"file_id":"https://github.com/GoogleCloudPlatform/generative-ai/blob/main/vision/getting-started/visual_question_answering.ipynb","timestamp":1702608367588}]},"environment":{"kernel":"python3","name":"tf2-gpu.2-11.m110","type":"gcloud","uri":"gcr.io/deeplearning-platform-release/tf2-gpu.2-11:m110"},"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.12"}},"nbformat":4,"nbformat_minor":0} -------------------------------------------------------------------------------- /evaluation/compute_DAI.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | from PIL import Image 4 | from scipy.ndimage import gaussian_filter 5 | from pathlib import Path 6 | import logging 7 | from torchvision import transforms 8 | 9 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 10 | 11 | 12 | def load_image(path): 13 | """ Load an image from the given path. """ 14 | return np.array(Image.open(path)) 15 | 16 | 17 | def get_patch(image, center, radius): 18 | """ Extract a patch from the image centered at 'center' with the given 'radius'. 
""" 19 | x, y = center 20 | return image[y - radius:y + radius + 1, x - radius:x + radius + 1] 21 | 22 | 23 | def calculate_difference(patch1, patch2): 24 | """ Calculate the L2 norm (Euclidean distance) between two patches. """ 25 | difference = patch1 - patch2 26 | squared_difference = np.square(difference) 27 | l2_distance = np.sum(squared_difference) 28 | 29 | return l2_distance 30 | 31 | 32 | def compute_dai(original_image, result_image, points, radius): 33 | """ Compute the Drag Accuracy Index (DAI) for the given images and points. """ 34 | dai = 0 35 | for start, target in points: 36 | original_patch = get_patch(original_image, start, radius) 37 | result_patch = get_patch(result_image, target, radius) 38 | dai += calculate_difference(original_patch, result_patch) 39 | dai /= len(points) 40 | dai /= cal_patch_size(radius) 41 | return dai / len(points) 42 | 43 | 44 | def get_points(points_dir): 45 | with open(points_dir, 'r') as file: 46 | points_data = json.load(file) 47 | points = points_data['points'] 48 | 49 | # Assuming pairs of points: [start, target, start, target, ...] 50 | point_pairs = [(points[i], points[i + 1]) for i in range(0, len(points), 2)] 51 | return point_pairs 52 | 53 | 54 | def cal_patch_size(radius: int): 55 | return (1 + 2 * radius) ** 2 56 | 57 | 58 | def compute_average_dai(radius, dataset_path, original_dataset_path=None): 59 | """ Compute the average DAI for a given dataset. """ 60 | dataset_dir = Path(dataset_path) 61 | original_dataset_dir = Path(original_dataset_path) if original_dataset_path else dataset_dir 62 | total_dai, num_folders = 0, 0 63 | transform = transforms.Compose([ 64 | transforms.ToTensor(), 65 | ]) 66 | 67 | for item in dataset_dir.iterdir(): 68 | if item.is_dir() or (item.is_file() and original_dataset_path): 69 | folder_name = item.stem if item.is_file() else item.name 70 | original_image_path = original_dataset_dir / folder_name / 'original.jpg' 71 | result_image_path = item if item.is_file() else item / 'output_image.png' 72 | points_json_path = original_dataset_dir / folder_name / 'points.json' 73 | 74 | if original_image_path.exists() and result_image_path.exists() and points_json_path.exists(): 75 | original_image = load_image(str(original_image_path)) 76 | result_image = load_image(str(result_image_path)) 77 | point_pairs = get_points(str(points_json_path)) 78 | 79 | original_image = transform(original_image).permute(1, 2, 0).numpy() 80 | result_image = transform(result_image).permute(1, 2, 0).numpy() 81 | dai = compute_dai(original_image, result_image, point_pairs, radius) 82 | total_dai += dai 83 | num_folders += 1 84 | else: 85 | logging.warning(f"Missing files in {folder_name}") 86 | 87 | if num_folders > 0: 88 | average_dai = total_dai / num_folders 89 | logging.info(f'Average DAI for {dataset_dir} with r3 {radius} is {average_dai:.4f}. 
Total {num_folders} images.') 90 | else: 91 | logging.warning("No valid folders found for DAI calculation.") 92 | 93 | 94 | def main(): 95 | gamma = [1, 5, 10, 20] 96 | result_folder = './bench_result' 97 | data_folder = './dataset' 98 | for r in gamma: 99 | compute_average_dai(r, result_folder, data_folder) 100 | 101 | 102 | if __name__ == '__main__': 103 | main() 104 | -------------------------------------------------------------------------------- /gooddrag_ui.py: -------------------------------------------------------------------------------- 1 | # ************************************************************************* 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ************************************************************************* 14 | 15 | import os 16 | import gradio as gr 17 | from utils.ui_utils import ( 18 | get_points, undo_points, show_cur_points, 19 | clear_all, store_img, train_lora_interface, run_gooddrag, save_image_mask_points, save_drag_result, 20 | save_intermediate_images, create_video 21 | ) 22 | 23 | LENGTH = 512 24 | 25 | 26 | def create_markdown_section(): 27 | gr.Markdown(""" 28 | # GoodDrag ✨ 29 | 30 | 👋 Welcome to GoodDrag! Follow these steps to easily manipulate your images: 31 | 32 | 1. **Upload Image:** 📤 Either drag and drop an image or click to upload in the **Draw Mask** box. 33 | 2. **Prepare for Training:** 🛠️ 34 | - Set the path for the LoRA algorithm for your image. 35 | - Click the **Train LoRA** button to initiate the training process. 36 | 3. **Draw and Click:** ✏️ 37 | - Use the **Draw Mask** box to create a mask on your image. 38 | - Next, go to the **Click Points** box. Here, you can add multiple pairs of points by clicking on the desired locations. 39 | 4. **Save Current Data (Optional):** 💾 40 | - If you wish to save the current state (including the image, mask, points, and the composite image with mask and points), specify the data path. 41 | - Click **Save Current Data** to store these elements. 42 | 5. **Run Drag Process:** ▶️ 43 | - Click the **Run** button to process the image based on the drawn mask and points. 44 | 6. **Save the Results (Optional):** 🏁 45 | - Specify a path to save the final dragged image, the new points, and an image showing the new points. 46 | - Click **Save Result** to download these items. 47 | 7. **Save Intermediate Images (Optional):** 📸 48 | - For those interested in viewing the drag process step-by-step, check the **Save Intermediate Images** option under the **Get Intermediate Images** section. 49 | - To obtain a video of the drag process, ensure all optional steps above have been completed, then click the **Get Video** button. 50 | 51 | Enjoy creating with GoodDrag! 
🌟 52 | 53 | """) 54 | 55 | 56 | def create_base_model_config_ui(): 57 | with gr.Tab("Diffusion Model"): 58 | with gr.Row(): 59 | local_models_dir = 'local_pretrained_models' 60 | os.makedirs(local_models_dir, exist_ok=True) 61 | local_models_choice = \ 62 | [os.path.join(local_models_dir, d) for d in os.listdir(local_models_dir) if 63 | os.path.isdir(os.path.join(local_models_dir, d))] 64 | model_path = gr.Dropdown(value="runwayml/stable-diffusion-v1-5", 65 | label="Diffusion Model Path", 66 | choices=[ 67 | "runwayml/stable-diffusion-v1-5", 68 | "stabilityai/stable-diffusion-2-1-base", 69 | "stabilityai/stable-diffusion-xl-base-1.0", 70 | ] + local_models_choice 71 | ) 72 | vae_path = gr.Dropdown(value="stabilityai/sd-vae-ft-mse", 73 | label="VAE choice", 74 | choices=["stabilityai/sd-vae-ft-mse", 75 | "default"] + local_models_choice 76 | ) 77 | 78 | return model_path, vae_path 79 | 80 | 81 | def create_lora_parameters_ui(): 82 | with gr.Tab("LoRA Parameters"): 83 | with gr.Row(): 84 | lora_step = gr.Number(value=70, label="LoRA training steps", precision=0) 85 | lora_lr = gr.Number(value=0.0005, label="LoRA learning rate") 86 | lora_batch_size = gr.Number(value=4, label="LoRA batch size", precision=0) 87 | lora_rank = gr.Number(value=16, label="LoRA rank", precision=0) 88 | 89 | return lora_step, lora_lr, lora_batch_size, lora_rank 90 | 91 | 92 | def create_real_image_editing_ui(): 93 | with gr.Row(): 94 | with gr.Column(): 95 | gr.Markdown("

📤 Draw Mask
") 96 | canvas = gr.Image(type="numpy", tool="sketch", label="Draw your mask on the image", 97 | show_label=True, height=LENGTH, width=LENGTH) # for mask painting 98 | with gr.Row(): 99 | train_lora_button = gr.Button("Train LoRA") 100 | lora_path = gr.Textbox(value=f"./lora_data/test", label="LoRA Path", 101 | placeholder="Enter path for LoRA data") 102 | 103 | with gr.Row(): 104 | lora_status_bar = gr.Textbox(label="LoRA Training Status", interactive=False) 105 | 106 | with gr.Column(): 107 | gr.Markdown("

✏️ Click Points
") 108 | input_image = gr.Image(type="numpy", label="Click on the image to mark points", 109 | show_label=True, height=LENGTH, width=LENGTH) # for points clicking 110 | with gr.Row(): 111 | undo_button = gr.Button("Undo Point") 112 | save_button = gr.Button('Save Current Data') 113 | data_dir = gr.Textbox(value='./dataset/test', label="Data Directory", 114 | placeholder="Enter directory path for mask and points") 115 | 116 | with gr.Column(): 117 | gr.Markdown("

🖼️ Editing Result
") 118 | output_image = gr.Image(type="numpy", label="View the editing results here", 119 | show_label=True, height=LENGTH, width=LENGTH) 120 | with gr.Row(): 121 | run_button = gr.Button("Run") 122 | clear_all_button = gr.Button("Clear All") 123 | save_result = gr.Button("Save Result") 124 | show_points = gr.Button("Show Points") 125 | result_save_path = gr.Textbox(value='./result/test', label="Result Folder", 126 | placeholder="Enter path to save the results") 127 | 128 | return canvas, train_lora_button, lora_path, lora_status_bar, input_image, undo_button, save_button, data_dir, \ 129 | output_image, run_button, clear_all_button, show_points, result_save_path, save_result 130 | 131 | 132 | def create_drag_parameters_ui(): 133 | with gr.Tab("Drag Parameters"): 134 | with gr.Row(): 135 | latent_lr = gr.Number(value=0.02, label="Learning rate") 136 | prompt = gr.Textbox(label="Prompt") 137 | drag_end_step = gr.Number(value=7, label="End time step", precision=0) 138 | drag_per_step = gr.Number(value=10, label="Point tracking number per each step", precision=0) 139 | 140 | return latent_lr, prompt, drag_end_step, drag_per_step 141 | 142 | 143 | def create_advance_parameters_ui(): 144 | with gr.Tab("Advanced Parameters"): 145 | with gr.Row(): 146 | r1 = gr.Number(value=4, label="Motion supervision feature path size", precision=0) 147 | r2 = gr.Number(value=12, label="Point tracking feature patch size", precision=0) 148 | drag_distance = gr.Number(value=4, label="The distance for motion supervision", precision=0) 149 | feature_idx = gr.Number(value=3, label="The index of the features [0,3]", precision=0) 150 | max_drag_per_track = gr.Number(value=3, 151 | label="Motion supervision times for each point tracking", 152 | precision=0) 153 | 154 | with gr.Row(): 155 | lam = gr.Number(value=0.2, label="Lambda", info="Regularization strength on unmasked areas") 156 | inversion_strength = gr.Slider(0, 1.0, 157 | value=0.75, 158 | label="Inversion strength") 159 | max_track_no_change = gr.Number(value=10, label="Early stop", 160 | info="The maximum number of times points is unchanged.") 161 | 162 | return (r1, r2, drag_distance, feature_idx, max_drag_per_track, lam, 163 | inversion_strength, max_track_no_change) 164 | 165 | 166 | def create_intermediate_save_ui(): 167 | with gr.Tab("Get Intermediate Images"): 168 | with gr.Row(): 169 | save_intermediates_images = gr.Checkbox(label='Save intermediate images') 170 | get_mp4 = gr.Button("Get video") 171 | 172 | return save_intermediates_images, get_mp4 173 | 174 | 175 | def attach_canvas_event(canvas: gr.State, original_image: gr.State, 176 | selected_points: gr.State, input_image, mask): 177 | canvas.edit( 178 | store_img, 179 | [canvas], 180 | [original_image, selected_points, input_image, mask] 181 | ) 182 | 183 | 184 | def attach_input_image_event(input_image, selected_points): 185 | input_image.select( 186 | get_points, 187 | [input_image, selected_points], 188 | [input_image] 189 | ) 190 | 191 | 192 | def attach_undo_button_event(undo_button, original_image, selected_points, mask, input_image): 193 | undo_button.click( 194 | undo_points, 195 | [original_image, mask], 196 | [input_image, selected_points] 197 | ) 198 | 199 | 200 | def attach_train_lora_button_event(train_lora_button, original_image, prompt, 201 | model_path, vae_path, lora_path, 202 | lora_step, lora_lr, lora_batch_size, lora_rank, 203 | lora_status_bar): 204 | train_lora_button.click( 205 | train_lora_interface, 206 | [original_image, prompt, model_path, vae_path, lora_path, 207 | 
lora_step, lora_lr, lora_batch_size, lora_rank], 208 | [lora_status_bar] 209 | ) 210 | 211 | 212 | def attach_run_button_event(run_button, original_image, input_image, mask, prompt, 213 | selected_points, inversion_strength, lam, latent_lr, 214 | model_path, vae_path, lora_path, 215 | drag_end_step, drag_per_step, 216 | output_image, r1, r2, d, feature_idx, new_points, 217 | max_drag_per_track, max_track_no_change, 218 | result_save_path, save_intermediates_images): 219 | run_button.click( 220 | run_gooddrag, 221 | [original_image, input_image, mask, prompt, selected_points, 222 | inversion_strength, lam, latent_lr, model_path, vae_path, 223 | lora_path, drag_end_step, drag_per_step, r1, r2, d, 224 | max_drag_per_track, max_track_no_change, feature_idx, result_save_path, save_intermediates_images], 225 | [output_image, new_points] 226 | ) 227 | 228 | 229 | def attach_show_points_event(show_points, output_image, selected_points): 230 | show_points.click( 231 | show_cur_points, 232 | [output_image, selected_points], 233 | [output_image] 234 | ) 235 | 236 | 237 | def attach_clear_all_button_event(clear_all_button, canvas, input_image, 238 | output_image, selected_points, original_image, mask): 239 | clear_all_button.click( 240 | clear_all, 241 | [gr.Number(value=LENGTH, visible=False, precision=0)], 242 | [canvas, input_image, output_image, selected_points, original_image, mask] 243 | ) 244 | 245 | 246 | def attach_save_button_event(save_button, mask, selected_points, input_image, save_dir): 247 | """ 248 | Attaches an event to the save button to trigger the save function. 249 | """ 250 | save_button.click( 251 | save_image_mask_points, 252 | inputs=[mask, selected_points, input_image, save_dir], 253 | outputs=[] 254 | ) 255 | 256 | 257 | def attach_save_result_event(save_result, output_image, new_points, result_path): 258 | """ 259 | Attaches an event to the save button to trigger the save function. 
260 | """ 261 | save_result.click( 262 | save_drag_result, 263 | inputs=[output_image, new_points, result_path], 264 | outputs=[] 265 | ) 266 | 267 | 268 | def attach_video_event(get_mp4_button, result_save_path, data_dir): 269 | get_mp4_button.click( 270 | create_video, 271 | inputs=[result_save_path, data_dir] 272 | ) 273 | 274 | 275 | def main(): 276 | with gr.Blocks() as demo: 277 | mask = gr.State(value=None) 278 | selected_points = gr.State([]) 279 | new_points = gr.State([]) 280 | original_image = gr.State(value=None) 281 | create_markdown_section() 282 | intermediate_images = gr.State([]) 283 | 284 | canvas, train_lora_button, lora_path, lora_status_bar, input_image, undo_button, save_button, data_dir, \ 285 | output_image, run_button, clear_all_button, show_points, result_save_path, \ 286 | save_result = create_real_image_editing_ui() 287 | 288 | latent_lr, prompt, drag_end_step, drag_per_step = create_drag_parameters_ui() 289 | 290 | model_path, vae_path = create_base_model_config_ui() 291 | lora_step, lora_lr, lora_batch_size, lora_rank = create_lora_parameters_ui() 292 | r1, r2, d, feature_idx, max_drag_per_track, lam, inversion_strength, max_track_no_change = \ 293 | create_advance_parameters_ui() 294 | save_intermediates_images, get_mp4_button = create_intermediate_save_ui() 295 | 296 | attach_canvas_event(canvas, original_image, selected_points, input_image, mask) 297 | attach_input_image_event(input_image, selected_points) 298 | attach_undo_button_event(undo_button, original_image, selected_points, mask, input_image) 299 | attach_train_lora_button_event(train_lora_button, original_image, prompt, model_path, vae_path, lora_path, 300 | lora_step, lora_lr, lora_batch_size, lora_rank, lora_status_bar) 301 | attach_run_button_event(run_button, original_image, input_image, mask, prompt, selected_points, 302 | inversion_strength, lam, latent_lr, model_path, vae_path, lora_path, 303 | drag_end_step, drag_per_step, output_image, 304 | r1, r2, d, feature_idx, new_points, max_drag_per_track, 305 | max_track_no_change, result_save_path, save_intermediates_images) 306 | attach_show_points_event(show_points, output_image, new_points) 307 | attach_clear_all_button_event(clear_all_button, canvas, input_image, output_image, selected_points, 308 | original_image, mask) 309 | attach_save_button_event(save_button, mask, selected_points, input_image, data_dir) 310 | attach_save_result_event(save_result, output_image, new_points, result_save_path) 311 | attach_video_event(get_mp4_button, result_save_path, data_dir) 312 | 313 | demo.queue().launch(share=True, debug=True) 314 | 315 | 316 | if __name__ == '__main__': 317 | main() 318 | -------------------------------------------------------------------------------- /pipeline.py: -------------------------------------------------------------------------------- 1 | # ************************************************************************* 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ************************************************************************* 14 | 15 | 16 | import torch 17 | import numpy as np 18 | import copy 19 | 20 | import torch.nn.functional as F 21 | from einops import rearrange 22 | from tqdm import tqdm 23 | from PIL import Image 24 | from typing import Any, Dict, List, Optional, Tuple, Union 25 | 26 | from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline 27 | from utils.drag_utils import point_tracking, check_handle_reach_target, interpolate_feature_patch 28 | from utils.attn_utils import register_attention_editor_diffusers, MutualSelfAttentionControl 29 | from diffusers import DDIMScheduler, AutoencoderKL 30 | from pytorch_lightning import seed_everything 31 | from accelerate import Accelerator 32 | 33 | 34 | # from diffusers.models.attention_processor import LoRAAttnProcessor2_0 35 | 36 | 37 | # override unet forward 38 | # The only difference from diffusers: 39 | # return intermediate UNet features of all UpSample blocks 40 | def override_forward(self): 41 | def forward( 42 | sample: torch.FloatTensor, 43 | timestep: Union[torch.Tensor, float, int], 44 | encoder_hidden_states: torch.Tensor, 45 | class_labels: Optional[torch.Tensor] = None, 46 | timestep_cond: Optional[torch.Tensor] = None, 47 | attention_mask: Optional[torch.Tensor] = None, 48 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 49 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, 50 | mid_block_additional_residual: Optional[torch.Tensor] = None, 51 | return_intermediates: bool = False, 52 | last_up_block_idx: int = None, 53 | ): 54 | # By default samples have to be AT least a multiple of the overall upsampling factor. 55 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). 56 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 57 | # on the fly if necessary. 58 | default_overall_up_factor = 2 ** self.num_upsamplers 59 | 60 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 61 | forward_upsample_size = False 62 | upsample_size = None 63 | 64 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 65 | forward_upsample_size = True 66 | 67 | # prepare attention_mask 68 | if attention_mask is not None: 69 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 70 | attention_mask = attention_mask.unsqueeze(1) 71 | 72 | # 0. center input if necessary 73 | if self.config.center_input_sample: 74 | sample = 2 * sample - 1.0 75 | 76 | # 1. time 77 | timesteps = timestep 78 | if not torch.is_tensor(timesteps): 79 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 80 | # This would be a good case for the `match` statement (Python 3.10+) 81 | is_mps = sample.device.type == "mps" 82 | if isinstance(timestep, float): 83 | dtype = torch.float32 if is_mps else torch.float64 84 | else: 85 | dtype = torch.int32 if is_mps else torch.int64 86 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 87 | elif len(timesteps.shape) == 0: 88 | timesteps = timesteps[None].to(sample.device) 89 | 90 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 91 | timesteps = timesteps.expand(sample.shape[0]) 92 | 93 | t_emb = self.time_proj(timesteps) 94 | 95 | # `Timesteps` does not contain any weights and will always return f32 tensors 96 | # but time_embedding might actually be running in fp16. so we need to cast here. 
97 | # there might be better ways to encapsulate this. 98 | t_emb = t_emb.to(dtype=self.dtype) 99 | 100 | emb = self.time_embedding(t_emb, timestep_cond) 101 | 102 | if self.class_embedding is not None: 103 | if class_labels is None: 104 | raise ValueError("class_labels should be provided when num_class_embeds > 0") 105 | 106 | if self.config.class_embed_type == "timestep": 107 | class_labels = self.time_proj(class_labels) 108 | 109 | # `Timesteps` does not contain any weights and will always return f32 tensors 110 | # there might be better ways to encapsulate this. 111 | class_labels = class_labels.to(dtype=sample.dtype) 112 | 113 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) 114 | 115 | if self.config.class_embeddings_concat: 116 | emb = torch.cat([emb, class_emb], dim=-1) 117 | else: 118 | emb = emb + class_emb 119 | 120 | if self.config.addition_embed_type == "text": 121 | aug_emb = self.add_embedding(encoder_hidden_states) 122 | emb = emb + aug_emb 123 | 124 | if self.time_embed_act is not None: 125 | emb = self.time_embed_act(emb) 126 | 127 | if self.encoder_hid_proj is not None: 128 | encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) 129 | 130 | # 2. pre-process 131 | sample = self.conv_in(sample) 132 | 133 | # 3. down 134 | down_block_res_samples = (sample,) 135 | for downsample_block in self.down_blocks: 136 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 137 | sample, res_samples = downsample_block( 138 | hidden_states=sample, 139 | temb=emb, 140 | encoder_hidden_states=encoder_hidden_states, 141 | attention_mask=attention_mask, 142 | cross_attention_kwargs=cross_attention_kwargs, 143 | ) 144 | else: 145 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 146 | 147 | down_block_res_samples += res_samples 148 | 149 | if down_block_additional_residuals is not None: 150 | new_down_block_res_samples = () 151 | 152 | for down_block_res_sample, down_block_additional_residual in zip( 153 | down_block_res_samples, down_block_additional_residuals 154 | ): 155 | down_block_res_sample = down_block_res_sample + down_block_additional_residual 156 | new_down_block_res_samples += (down_block_res_sample,) 157 | 158 | down_block_res_samples = new_down_block_res_samples 159 | 160 | # 4. mid 161 | if self.mid_block is not None: 162 | sample = self.mid_block( 163 | sample, 164 | emb, 165 | encoder_hidden_states=encoder_hidden_states, 166 | attention_mask=attention_mask, 167 | cross_attention_kwargs=cross_attention_kwargs, 168 | ) 169 | 170 | if mid_block_additional_residual is not None: 171 | sample = sample + mid_block_additional_residual 172 | 173 | # 5. 
up 174 | # only difference from diffusers: 175 | # save the intermediate features of unet upsample blocks 176 | # the 0-th element is the mid-block output 177 | all_intermediate_features = [sample] 178 | for i, upsample_block in enumerate(self.up_blocks): 179 | is_final_block = i == len(self.up_blocks) - 1 180 | 181 | res_samples = down_block_res_samples[-len(upsample_block.resnets):] 182 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 183 | 184 | # if we have not reached the final block and need to forward the 185 | # upsample size, we do it here 186 | if not is_final_block and forward_upsample_size: 187 | upsample_size = down_block_res_samples[-1].shape[2:] 188 | 189 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 190 | sample = upsample_block( 191 | hidden_states=sample, 192 | temb=emb, 193 | res_hidden_states_tuple=res_samples, 194 | encoder_hidden_states=encoder_hidden_states, 195 | cross_attention_kwargs=cross_attention_kwargs, 196 | upsample_size=upsample_size, 197 | attention_mask=attention_mask, 198 | ) 199 | else: 200 | sample = upsample_block( 201 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size 202 | ) 203 | all_intermediate_features.append(sample) 204 | # return early to save computation time if needed 205 | if last_up_block_idx is not None and i == last_up_block_idx: 206 | return all_intermediate_features 207 | 208 | # 6. post-process 209 | if self.conv_norm_out: 210 | sample = self.conv_norm_out(sample) 211 | sample = self.conv_act(sample) 212 | sample = self.conv_out(sample) 213 | 214 | # only difference from diffusers, return intermediate results 215 | if return_intermediates: 216 | return sample, all_intermediate_features 217 | else: 218 | return sample 219 | 220 | return forward 221 | 222 | 223 | class GoodDragger: 224 | def __init__(self, device, model_path: str, prompt: str, 225 | full_height: int, full_width: int, 226 | inversion_strength: float, 227 | r1: int = 4, r2: int = 12, beta: int = 4, 228 | drag_end_step: int = 10, track_per_denoise: int = 10, 229 | lam: float = 0.2, latent_lr: float = 0.01, 230 | n_inference_step: int = 50, guidance_scale: float = 1.0, feature_idx: int = 3, 231 | compare_mode: bool = False, 232 | vae_path: str = "default", lora_path: str = '', seed: int = 42, 233 | max_drag_per_track: int = 10, drag_loss_threshold: float = 4.0, once_drag: bool = False, 234 | max_track_no_change: int = 10): 235 | self.device = device 236 | self.vae_path = vae_path 237 | self.lora_path = lora_path 238 | scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, 239 | beta_schedule="scaled_linear", clip_sample=False, 240 | set_alpha_to_one=False, steps_offset=1) 241 | 242 | is_sdxl = 'xl' in model_path 243 | self.is_sdxl = is_sdxl 244 | if is_sdxl: 245 | self.model = StableDiffusionXLPipeline.from_pretrained(model_path, scheduler=scheduler).to(self.device) 246 | self.model.unet.config.addition_embed_type = None 247 | else: 248 | self.model = StableDiffusionPipeline.from_pretrained(model_path, scheduler=scheduler).to(self.device) 249 | self.modify_unet_forward() 250 | if vae_path != "default": 251 | self.model.vae = AutoencoderKL.from_pretrained( 252 | vae_path 253 | ).to(self.device, self.model.vae.dtype) 254 | 255 | self.set_lora() 256 | 257 | self.model.vae.requires_grad_(False) 258 | self.model.text_encoder.requires_grad_(False) 259 | 260 | seed_everything(seed) 261 | 262 | self.prompt = prompt 263 | self.full_height = 
full_height 264 | self.full_width = full_width 265 | self.sup_res_h = int(0.5 * full_height) 266 | self.sup_res_w = int(0.5 * full_width) 267 | 268 | self.n_inference_step = n_inference_step 269 | self.n_actual_inference_step = round(inversion_strength * self.n_inference_step) 270 | self.guidance_scale = guidance_scale 271 | 272 | self.unet_feature_idx = [feature_idx] 273 | 274 | self.r_1 = r1 275 | self.r_2 = r2 276 | self.lam = lam 277 | self.beta = beta 278 | 279 | self.lr = latent_lr 280 | self.compare_mode = compare_mode 281 | 282 | self.t2 = drag_end_step 283 | self.track_per_denoise = track_per_denoise 284 | self.total_drag = int(track_per_denoise * self.t2) 285 | 286 | self.model.scheduler.set_timesteps(self.n_inference_step) 287 | 288 | self.do_drag = True 289 | self.drag_count = 0 290 | self.max_drag_per_track = max_drag_per_track 291 | 292 | self.drag_loss_threshold = drag_loss_threshold * ((2 * self.r_1) ** 2) 293 | self.once_drag = once_drag 294 | self.no_change_track_num = 0 295 | self.max_no_change_track_num = max_track_no_change 296 | 297 | def set_lora(self): 298 | if self.lora_path == "": 299 | print("applying default parameters") 300 | self.model.unet.set_default_attn_processor() 301 | else: 302 | print("applying lora: " + self.lora_path) 303 | self.model.unet.load_attn_procs(self.lora_path) 304 | 305 | def modify_unet_forward(self): 306 | self.model.unet.forward = override_forward(self.model.unet) 307 | 308 | def get_handle_target_points(self, points): 309 | handle_points = [] 310 | target_points = [] 311 | 312 | for idx, point in enumerate(points): 313 | cur_point = torch.tensor( 314 | [point[1] / self.full_height * self.sup_res_h, point[0] / self.full_width * self.sup_res_w]) 315 | cur_point = torch.round(cur_point) 316 | if idx % 2 == 0: 317 | handle_points.append(cur_point) 318 | else: 319 | target_points.append(cur_point) 320 | print(f'handle points: {handle_points}') 321 | print(f'target points: {target_points}') 322 | return handle_points, target_points 323 | 324 | def inv_step( 325 | self, 326 | model_output: torch.FloatTensor, 327 | timestep: int, 328 | x: torch.FloatTensor, 329 | verbose=False 330 | ): 331 | """ 332 | Inverse sampling for DDIM Inversion 333 | """ 334 | if verbose: 335 | print("timestep: ", timestep) 336 | next_step = timestep 337 | timestep = min( 338 | timestep - self.model.scheduler.config.num_train_timesteps // self.model.scheduler.num_inference_steps, 999) 339 | alpha_prod_t = self.model.scheduler.alphas_cumprod[ 340 | timestep] if timestep >= 0 else self.model.scheduler.final_alpha_cumprod 341 | alpha_prod_t_next = self.model.scheduler.alphas_cumprod[next_step] 342 | beta_prod_t = 1 - alpha_prod_t 343 | pred_x0 = (x - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 344 | pred_dir = (1 - alpha_prod_t_next) ** 0.5 * model_output 345 | x_next = alpha_prod_t_next ** 0.5 * pred_x0 + pred_dir 346 | return x_next, pred_x0 347 | 348 | @torch.no_grad() 349 | def image2latent(self, image): 350 | if type(image) is Image: 351 | image = np.array(image) 352 | image = torch.from_numpy(image).float() / 127.5 - 1 353 | image = image.permute(2, 0, 1).unsqueeze(0).to(self.device) 354 | 355 | latents = self.model.vae.encode(image)['latent_dist'].mean 356 | latents = latents * 0.18215 357 | return latents 358 | 359 | @torch.no_grad() 360 | def latent2image(self, latents, return_type='np'): 361 | latents = 1 / 0.18215 * latents.detach() 362 | image = self.model.vae.decode(latents)['sample'] 363 | if return_type == 'np': 364 | image = (image / 2 + 
0.5).clamp(0, 1) 365 | image = image.cpu().permute(0, 2, 3, 1).numpy()[0] 366 | image = (image * 255).astype(np.uint8) 367 | elif return_type == "pt": 368 | image = (image / 2 + 0.5).clamp(0, 1) 369 | 370 | return image 371 | 372 | @torch.no_grad() 373 | def get_text_embeddings(self, prompt): 374 | text_input = self.model.tokenizer( 375 | prompt, 376 | padding="max_length", 377 | max_length=77, 378 | return_tensors="pt" 379 | ) 380 | text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.device))[0] 381 | return text_embeddings 382 | 383 | def forward_unet_features(self, z, t, encoder_hidden_states): 384 | unet_output, all_intermediate_features = self.model.unet( 385 | z, 386 | t, 387 | encoder_hidden_states=encoder_hidden_states, 388 | return_intermediates=True 389 | ) 390 | 391 | all_return_features = [] 392 | for idx in self.unet_feature_idx: 393 | feat = all_intermediate_features[idx] 394 | feat = F.interpolate(feat, (self.sup_res_h, self.sup_res_w), mode='bilinear') 395 | all_return_features.append(feat) 396 | return_features = torch.cat(all_return_features, dim=1) 397 | 398 | del all_intermediate_features 399 | torch.cuda.empty_cache() 400 | 401 | return unet_output, return_features 402 | 403 | @torch.no_grad() 404 | def invert( 405 | self, 406 | image: torch.Tensor, 407 | prompt, 408 | return_intermediates=False, 409 | ): 410 | """ 411 | invert a real image into noise map with determinisc DDIM inversion 412 | """ 413 | batch_size = image.shape[0] 414 | if isinstance(prompt, list): 415 | if batch_size == 1: 416 | image = image.expand(len(prompt), -1, -1, -1) 417 | elif isinstance(prompt, str): 418 | if batch_size > 1: 419 | prompt = [prompt] * batch_size 420 | 421 | if self.is_sdxl: 422 | text_embeddings, _, _, _ = self.model.encode_prompt(prompt) 423 | else: 424 | text_embeddings = self.get_text_embeddings(prompt) 425 | 426 | latents = self.image2latent(image) 427 | 428 | if self.guidance_scale > 1.: 429 | unconditional_embeddings = self.get_text_embeddings([''] * batch_size) 430 | text_embeddings = torch.cat([unconditional_embeddings, text_embeddings], dim=0) 431 | 432 | print("Valid timesteps: ", self.model.scheduler.timesteps) 433 | latents_list = [latents] 434 | pred_x0_list = [latents] 435 | for i, t in enumerate(tqdm(reversed(self.model.scheduler.timesteps), desc="DDIM Inversion")): 436 | if self.n_actual_inference_step is not None and i >= self.n_actual_inference_step: 437 | continue 438 | 439 | if self.guidance_scale > 1.: 440 | model_inputs = torch.cat([latents] * 2) 441 | else: 442 | model_inputs = latents 443 | 444 | t_ = self.model.scheduler.timesteps[-(i + 2)] 445 | 446 | noise_pred = self.model.unet(model_inputs, t, encoder_hidden_states=text_embeddings) 447 | if self.guidance_scale > 1.: 448 | noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0) 449 | noise_pred = noise_pred_uncon + self.guidance_scale * (noise_pred_con - noise_pred_uncon) 450 | 451 | latents, pred_x0 = self.inv_step(noise_pred, t, latents) 452 | latents_list.append(latents) 453 | pred_x0_list.append(pred_x0) 454 | 455 | if return_intermediates: 456 | return latents, latents_list 457 | return latents 458 | 459 | def get_original_features(self, init_code, text_embeddings): 460 | timesteps = self.model.scheduler.timesteps 461 | strat_time_step_idx = self.n_inference_step - self.n_actual_inference_step 462 | original_step_output = {} 463 | features = {} 464 | cur_latents = init_code.detach().clone() 465 | with torch.no_grad(): 466 | for i, t in 
enumerate(tqdm(timesteps[strat_time_step_idx:], 467 | desc="Denosing for mask features")): 468 | if i <= self.t2: 469 | model_inputs = cur_latents 470 | noise_pred, F0 = self.forward_unet_features(model_inputs, t, encoder_hidden_states=text_embeddings) 471 | cur_latents = self.model.scheduler.step(noise_pred, t, model_inputs, return_dict=False)[0] 472 | original_step_output[t.item()] = cur_latents.cpu() 473 | features[t.item()] = F0.cpu() 474 | 475 | del noise_pred, cur_latents, F0 476 | torch.cuda.empty_cache() 477 | 478 | return original_step_output, features 479 | 480 | def get_noise_features(self, input_latents, t, text_embeddings): 481 | unet_output, F1 = self.forward_unet_features(input_latents, t, encoder_hidden_states=text_embeddings) 482 | return unet_output, F1 483 | 484 | def cal_motion_supervision_loss(self, handle_points, target_points, F1, x_prev_updated, original_prev, 485 | interp_mask, original_features, original_points, alpha=None): 486 | drag_loss = 0.0 487 | for i_ in range(len(handle_points)): 488 | pi, ti = handle_points[i_], target_points[i_] 489 | norm_dis = (ti - pi).norm() 490 | if norm_dis < 2.: 491 | continue 492 | 493 | di = (ti - pi) / (ti - pi).norm() * min(self.beta, norm_dis) 494 | 495 | original_features.requires_grad_(True) 496 | pi = original_points[i_] 497 | f0_patch = original_features[:, :, int(pi[0]) - self.r_1:int(pi[0]) + self.r_1 + 1, 498 | int(pi[1]) - self.r_1:int(pi[1]) + self.r_1 + 1].detach() 499 | 500 | pi = handle_points[i_] 501 | f1_patch = interpolate_feature_patch(F1, pi[0] + di[0], pi[1] + di[1], self.r_1) 502 | drag_loss += ((2 * self.r_1) ** 2) * F.l1_loss(f0_patch, f1_patch) 503 | 504 | print(f'Loss from drag: {drag_loss}') 505 | loss = drag_loss + self.lam * ((x_prev_updated - original_prev) 506 | * (1.0 - interp_mask)).abs().sum() 507 | print('Loss total=%f' % loss) 508 | return loss, drag_loss 509 | 510 | def track_step(self, original_feature, original_feature_, F1, F1_, handle_points, handle_points_init): 511 | if self.compare_mode: 512 | handle_points = point_tracking(original_feature, 513 | F1, handle_points, handle_points_init, self.r_2) 514 | else: 515 | handle_points = point_tracking(original_feature_, 516 | F1_, handle_points, handle_points_init, self.r_2) 517 | return handle_points 518 | 519 | def compare_tensor_lists(self, lst1, lst2): 520 | if len(lst1) != len(lst2): 521 | return False 522 | return all(torch.equal(t1, t2) for t1, t2 in zip(lst1, lst2)) 523 | 524 | def gooddrag_step(self, init_code, t, t_, text_embeddings, handle_points, target_points, 525 | features, handle_points_init, original_step_output, interp_mask): 526 | drag_latents = init_code.clone().detach() 527 | drag_latents.requires_grad_(True) 528 | 529 | first_drag = True 530 | need_track = False 531 | track_num = 0 532 | cur_drag_per_track = 0 533 | self.compare_mode = True 534 | accelerator = Accelerator( 535 | gradient_accumulation_steps=1, 536 | mixed_precision='fp16' 537 | ) 538 | 539 | optimizer = torch.optim.Adam([drag_latents], lr=self.lr) 540 | 541 | drag_latents, self.model.unet, optimizer = accelerator.prepare(drag_latents, self.model.unet, optimizer) 542 | while track_num < self.track_per_denoise: 543 | optimizer.zero_grad() 544 | unet_output, F1 = self.forward_unet_features(drag_latents, t, text_embeddings) 545 | x_prev_updated = self.model.scheduler.step(unet_output, t, drag_latents, return_dict=False)[0] 546 | 547 | if (need_track or first_drag) and (not self.compare_mode): 548 | with torch.no_grad(): 549 | _, F1_ = 
self.forward_unet_features(x_prev_updated, t_, text_embeddings) 550 | 551 | if first_drag: 552 | first_drag = False 553 | if self.compare_mode: 554 | handle_points = point_tracking(features[t.item()].cuda(), 555 | F1, handle_points, handle_points_init, self.r_2) 556 | else: 557 | handle_points = point_tracking(features[t_.item()].cuda(), 558 | F1_, handle_points, handle_points_init, self.r_2) 559 | 560 | print(f'After denoise new handle points: {handle_points}, drag count: {self.drag_count}') 561 | 562 | # break if all handle points have reached the targets 563 | if check_handle_reach_target(handle_points, target_points): 564 | self.do_drag = False 565 | print('Reached the target points') 566 | break 567 | 568 | if self.no_change_track_num == self.max_no_change_track_num: 569 | self.do_drag = False 570 | print('Early stop.') 571 | break 572 | 573 | del unet_output 574 | if need_track and (not self.compare_mode): 575 | del _ 576 | torch.cuda.empty_cache() 577 | 578 | loss, drag_loss = self.cal_motion_supervision_loss(handle_points, target_points, F1, x_prev_updated, 579 | original_step_output[t.item()].cuda(), interp_mask, 580 | original_features=features[t.item()].cuda(), 581 | original_points=handle_points_init) 582 | 583 | accelerator.backward(loss) 584 | 585 | optimizer.step() 586 | 587 | cur_drag_per_track += 1 588 | need_track = (cur_drag_per_track == self.max_drag_per_track) or ( 589 | drag_loss <= self.drag_loss_threshold) or self.once_drag 590 | if need_track: 591 | track_num += 1 592 | handle_points_cur = copy.deepcopy(handle_points) 593 | if self.compare_mode: 594 | handle_points = point_tracking(features[t.item()].cuda(), 595 | F1, handle_points, handle_points_init, self.r_2) 596 | else: 597 | handle_points = point_tracking(features[t_.item()].cuda(), 598 | F1_, handle_points, handle_points_init, self.r_2) 599 | 600 | if self.compare_tensor_lists(handle_points, handle_points_cur): 601 | self.no_change_track_num += 1 602 | print(f'{self.no_change_track_num} times handle points no changes.') 603 | else: 604 | self.no_change_track_num = 0 605 | 606 | self.drag_count += 1 607 | cur_drag_per_track = 0 608 | print(f'New handle points: {handle_points}, drag count: {self.drag_count}') 609 | 610 | init_code = drag_latents.clone().detach() 611 | init_code.requires_grad_(False) 612 | del optimizer, drag_latents 613 | torch.cuda.empty_cache() 614 | 615 | return init_code, handle_points 616 | 617 | def prepare_mask(self, mask): 618 | mask = torch.from_numpy(mask).float() / 255. 
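        # The steps below binarize the user-drawn mask (any nonzero pixel becomes 1.0),
        # reshape it to a 1 x 1 x H x W tensor on the GPU, and resample it with
        # nearest-neighbour interpolation to the half-resolution supervision grid
        # (sup_res_h, sup_res_w); e.g. a 512x512 input mask becomes a 1x1x256x256 binary map.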
619 | mask[mask > 0.0] = 1.0 620 | mask = rearrange(mask, "h w -> 1 1 h w").cuda() 621 | mask = F.interpolate(mask, (self.sup_res_h, self.sup_res_w), mode="nearest") 622 | return mask 623 | 624 | def set_latent_masactrl(self): 625 | editor = MutualSelfAttentionControl(start_step=0, 626 | start_layer=10, 627 | total_steps=self.n_inference_step, 628 | guidance_scale=self.guidance_scale) 629 | if self.lora_path == "": 630 | register_attention_editor_diffusers(self.model, editor, attn_processor='attn_proc') 631 | else: 632 | register_attention_editor_diffusers(self.model, editor, attn_processor='lora_attn_proc') 633 | 634 | def get_intermediate_images(self, intermediate_images, intermediate_images_original, intermediate_images_t_idx, 635 | valid_timestep, text_embeddings): 636 | for i in range(len(intermediate_images)-1): 637 | current_original_code = intermediate_images_original[i].to(self.device) 638 | current_init_code = intermediate_images[i].to(self.device) 639 | 640 | self.set_latent_masactrl() 641 | 642 | for inter_i, inter_t in enumerate(valid_timestep[intermediate_images_t_idx[i] + 1:]): 643 | with torch.no_grad(): 644 | noise_pred_all = self.model.unet(torch.cat([current_original_code, current_init_code]), inter_t, 645 | encoder_hidden_states=torch.cat( 646 | [text_embeddings, text_embeddings])) 647 | noise_pred = noise_pred_all[1] 648 | noise_pred_original = noise_pred_all[0] 649 | current_init_code = \ 650 | self.model.scheduler.step(noise_pred, inter_t, current_init_code, return_dict=False)[0] 651 | current_original_code = \ 652 | self.model.scheduler.step(noise_pred_original, inter_t, current_original_code, 653 | return_dict=False)[0] 654 | intermediate_images[i] = self.latent2image(current_init_code, return_type="pt").cpu() 655 | intermediate_images.pop() 656 | return intermediate_images 657 | 658 | def good_drag(self, 659 | source_image, 660 | points, 661 | mask, 662 | return_intermediate_images=False, 663 | return_intermediate_features=False 664 | ): 665 | init_code = self.invert(source_image, self.prompt) 666 | original_init = init_code.detach().clone() 667 | if self.is_sdxl: 668 | text_embeddings, _, _, _ = self.model.encode_prompt(self.prompt) 669 | text_embeddings = text_embeddings.detach() 670 | else: 671 | text_embeddings = self.get_text_embeddings(self.prompt).detach() 672 | 673 | self.model.text_encoder.to('cpu') 674 | self.model.vae.encoder.to('cpu') 675 | 676 | timesteps = self.model.scheduler.timesteps 677 | start_time_step_idx = self.n_inference_step - self.n_actual_inference_step 678 | 679 | handle_points, target_points = self.get_handle_target_points(points) 680 | original_step_output, features = self.get_original_features(init_code, text_embeddings) 681 | 682 | handle_points_init = copy.deepcopy(handle_points) 683 | mask = self.prepare_mask(mask) 684 | interp_mask = F.interpolate(mask, (init_code.shape[2], init_code.shape[3]), mode='nearest') 685 | 686 | intermediate_features = [init_code.detach().clone().cpu()] if return_intermediate_features else [] 687 | valid_timestep = timesteps[start_time_step_idx:] 688 | set_mutual = True 689 | 690 | intermediate_images, intermediate_images_original, intermediate_images_t_idx = [], [], [] 691 | 692 | did_drag = False 693 | for i, t in enumerate(tqdm(valid_timestep, 694 | desc="Drag and Denoise")): 695 | if i < self.t2 and self.do_drag and (self.no_change_track_num != self.max_no_change_track_num): 696 | t_ = valid_timestep[i + 1] 697 | init_code, handle_points = self.gooddrag_step(init_code, t, t_, text_embeddings, 
handle_points, 698 | target_points, features, handle_points_init, 699 | original_step_output, interp_mask) 700 | did_drag = True 701 | else: 702 | if set_mutual: 703 | set_mutual = False 704 | self.set_latent_masactrl() 705 | 706 | with torch.no_grad(): 707 | noise_pred_all = self.model.unet(torch.cat([original_init, init_code]), t, 708 | encoder_hidden_states=torch.cat([text_embeddings, text_embeddings])) 709 | noise_pred = noise_pred_all[1] 710 | noise_pred_original = noise_pred_all[0] 711 | init_code = self.model.scheduler.step(noise_pred, t, init_code, return_dict=False)[0] 712 | original_init = self.model.scheduler.step(noise_pred_original, t, original_init, return_dict=False)[0] 713 | 714 | if did_drag and return_intermediate_images: 715 | current_init_code = init_code.detach().clone() 716 | current_original_code = original_init.detach().clone() 717 | 718 | intermediate_images.append(current_init_code.cpu()) 719 | intermediate_images_original.append(current_original_code.cpu()) 720 | intermediate_images_t_idx.append(i) 721 | did_drag = False 722 | if return_intermediate_features: 723 | intermediate_features.append(init_code.detach().clone().cpu()) 724 | 725 | if return_intermediate_images: 726 | intermediate_images = self.get_intermediate_images(intermediate_images, intermediate_images_original, 727 | intermediate_images_t_idx, valid_timestep, text_embeddings) 728 | 729 | image = self.latent2image(init_code, return_type="pt") 730 | return image, intermediate_features, handle_points, intermediate_images 731 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --find-links https://download.pytorch.org/whl/torch_stable.html 2 | torch==2.0.1+cu118 3 | scipy>=1.11.1 4 | transformers==4.27.0 5 | diffusers==0.20.0 6 | gradio==3.50.2 7 | einops==0.6.1 8 | pytorch_lightning==2.0.8 9 | accelerate==0.17.0 10 | opencv-python==4.8.0.76 11 | torchvision==0.15.2+cu118 12 | -------------------------------------------------------------------------------- /utils/__pycache__/attn_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/utils/__pycache__/attn_utils.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/drag_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/utils/__pycache__/drag_utils.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/lora_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/utils/__pycache__/lora_utils.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/ui_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/utils/__pycache__/ui_utils.cpython-310.pyc -------------------------------------------------------------------------------- /utils/attn_utils.py: 
-------------------------------------------------------------------------------- 1 | # ************************************************************************* 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ************************************************************************* 14 | 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | 20 | from einops import rearrange, repeat 21 | 22 | 23 | class AttentionBase: 24 | 25 | def __init__(self): 26 | self.cur_step = 0 27 | self.num_att_layers = -1 28 | self.cur_att_layer = 0 29 | 30 | def after_step(self): 31 | pass 32 | 33 | def __call__(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs): 34 | out = self.forward(q, k, v, is_cross, place_in_unet, num_heads, **kwargs) 35 | self.cur_att_layer += 1 36 | if self.cur_att_layer == self.num_att_layers: 37 | self.cur_att_layer = 0 38 | self.cur_step += 1 39 | # after step 40 | self.after_step() 41 | return out 42 | 43 | def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs): 44 | out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) 45 | out = rearrange(out, 'b h n d -> b n (h d)') 46 | return out 47 | 48 | def reset(self): 49 | self.cur_step = 0 50 | self.cur_att_layer = 0 51 | 52 | 53 | class MutualSelfAttentionControl(AttentionBase): 54 | 55 | def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, guidance_scale=7.5): 56 | """ 57 | Mutual self-attention control for Stable-Diffusion model 58 | Args: 59 | start_step: the step to start mutual self-attention control 60 | start_layer: the layer to start mutual self-attention control 61 | layer_idx: list of the layers to apply mutual self-attention control 62 | step_idx: list the steps to apply mutual self-attention control 63 | total_steps: the total number of steps 64 | """ 65 | super().__init__() 66 | self.total_steps = total_steps 67 | self.start_step = start_step 68 | self.start_layer = start_layer 69 | self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, 16)) 70 | self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps)) 71 | # store the guidance scale to decide whether there are unconditional branch 72 | self.guidance_scale = guidance_scale 73 | print("step_idx: ", self.step_idx) 74 | print("layer_idx: ", self.layer_idx) 75 | 76 | def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs): 77 | """ 78 | Attention forward function 79 | """ 80 | if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx: 81 | return super().forward(q, k, v, is_cross, place_in_unet, num_heads, **kwargs) 82 | 83 | if self.guidance_scale > 1.0: 84 | qu, qc = q[0:2], q[2:4] 85 | ku, kc = k[0:2], k[2:4] 86 | vu, vc = v[0:2], v[2:4] 87 | 88 | # merge queries of source and target branch into one so we can use torch API 89 | qu = torch.cat([qu[0:1], qu[1:2]], dim=2) 90 | qc = 
torch.cat([qc[0:1], qc[1:2]], dim=2) 91 | 92 | out_u = F.scaled_dot_product_attention(qu, ku[0:1], vu[0:1], attn_mask=None, dropout_p=0.0, is_causal=False) 93 | out_u = torch.cat(out_u.chunk(2, dim=2), dim=0) # split the queries into source and target batch 94 | out_u = rearrange(out_u, 'b h n d -> b n (h d)') 95 | 96 | out_c = F.scaled_dot_product_attention(qc, kc[0:1], vc[0:1], attn_mask=None, dropout_p=0.0, is_causal=False) 97 | out_c = torch.cat(out_c.chunk(2, dim=2), dim=0) # split the queries into source and target batch 98 | out_c = rearrange(out_c, 'b h n d -> b n (h d)') 99 | 100 | out = torch.cat([out_u, out_c], dim=0) 101 | else: 102 | q = torch.cat([q[0:1], q[1:2]], dim=2) 103 | out = F.scaled_dot_product_attention(q, k[0:1], v[0:1], attn_mask=None, dropout_p=0.0, is_causal=False) 104 | out = torch.cat(out.chunk(2, dim=2), dim=0) # split the queries into source and target batch 105 | out = rearrange(out, 'b h n d -> b n (h d)') 106 | return out 107 | 108 | # forward function for default attention processor 109 | # modified from __call__ function of AttnProcessor in diffusers 110 | def override_attn_proc_forward(attn, editor, place_in_unet): 111 | def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None): 112 | """ 113 | The attention is similar to the original implementation of LDM CrossAttention class 114 | except adding some modifications on the attention 115 | """ 116 | if encoder_hidden_states is not None: 117 | context = encoder_hidden_states 118 | if attention_mask is not None: 119 | mask = attention_mask 120 | 121 | to_out = attn.to_out 122 | if isinstance(to_out, nn.modules.container.ModuleList): 123 | to_out = attn.to_out[0] 124 | else: 125 | to_out = attn.to_out 126 | 127 | h = attn.heads 128 | q = attn.to_q(x) 129 | is_cross = context is not None 130 | context = context if is_cross else x 131 | k = attn.to_k(context) 132 | v = attn.to_v(context) 133 | 134 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) 135 | 136 | # the only difference 137 | out = editor( 138 | q, k, v, is_cross, place_in_unet, 139 | attn.heads, scale=attn.scale) 140 | 141 | return to_out(out) 142 | 143 | return forward 144 | 145 | # forward function for lora attention processor 146 | # modified from __call__ function of LoRAAttnProcessor2_0 in diffusers v0.17.1 147 | def override_lora_attn_proc_forward(attn, editor, place_in_unet): 148 | def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, lora_scale=1.0): 149 | residual = hidden_states 150 | input_ndim = hidden_states.ndim 151 | is_cross = encoder_hidden_states is not None 152 | 153 | if input_ndim == 4: 154 | batch_size, channel, height, width = hidden_states.shape 155 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 156 | 157 | batch_size, sequence_length, _ = ( 158 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 159 | ) 160 | 161 | if attention_mask is not None: 162 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 163 | # scaled_dot_product_attention expects attention_mask shape to be 164 | # (batch, heads, source_length, target_length) 165 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 166 | 167 | if attn.group_norm is not None: 168 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 169 | 170 | query = attn.to_q(hidden_states) + lora_scale * 
attn.processor.to_q_lora(hidden_states) 171 | 172 | if encoder_hidden_states is None: 173 | encoder_hidden_states = hidden_states 174 | elif attn.norm_cross: 175 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 176 | 177 | key = attn.to_k(encoder_hidden_states) + lora_scale * attn.processor.to_k_lora(encoder_hidden_states) 178 | value = attn.to_v(encoder_hidden_states) + lora_scale * attn.processor.to_v_lora(encoder_hidden_states) 179 | 180 | query, key, value = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=attn.heads), (query, key, value)) 181 | 182 | # the only difference 183 | hidden_states = editor( 184 | query, key, value, is_cross, place_in_unet, 185 | attn.heads, scale=attn.scale) 186 | 187 | # linear proj 188 | hidden_states = attn.to_out[0](hidden_states) + lora_scale * attn.processor.to_out_lora(hidden_states) 189 | # dropout 190 | hidden_states = attn.to_out[1](hidden_states) 191 | 192 | if input_ndim == 4: 193 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 194 | 195 | if attn.residual_connection: 196 | hidden_states = hidden_states + residual 197 | 198 | hidden_states = hidden_states / attn.rescale_output_factor 199 | 200 | return hidden_states 201 | 202 | return forward 203 | 204 | def register_attention_editor_diffusers(model, editor: AttentionBase, attn_processor='attn_proc'): 205 | """ 206 | Register a attention editor to Diffuser Pipeline, refer from [Prompt-to-Prompt] 207 | """ 208 | def register_editor(net, count, place_in_unet): 209 | for name, subnet in net.named_children(): 210 | if net.__class__.__name__ == 'Attention': # spatial Transformer layer 211 | if attn_processor == 'attn_proc': 212 | net.forward = override_attn_proc_forward(net, editor, place_in_unet) 213 | elif attn_processor == 'lora_attn_proc': 214 | net.forward = override_lora_attn_proc_forward(net, editor, place_in_unet) 215 | else: 216 | raise NotImplementedError("not implemented") 217 | return count + 1 218 | elif hasattr(net, 'children'): 219 | count = register_editor(subnet, count, place_in_unet) 220 | return count 221 | 222 | cross_att_count = 0 223 | for net_name, net in model.unet.named_children(): 224 | if "down" in net_name: 225 | cross_att_count += register_editor(net, 0, "down") 226 | elif "mid" in net_name: 227 | cross_att_count += register_editor(net, 0, "mid") 228 | elif "up" in net_name: 229 | cross_att_count += register_editor(net, 0, "up") 230 | editor.num_att_layers = cross_att_count 231 | -------------------------------------------------------------------------------- /utils/drag_utils.py: -------------------------------------------------------------------------------- 1 | # ************************************************************************* 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ************************************************************************* 14 | 15 | 16 | import torch 17 | from typing import List 18 | 19 | 20 | def calculate_l1_distance(tensor1: torch.Tensor, tensor2: torch.Tensor) -> torch.Tensor: 21 | """Calculate the L1 (Manhattan) distance between two tensors.""" 22 | return torch.sum(torch.abs(tensor1 - tensor2), dim=1) 23 | 24 | 25 | def calculate_l2_distance(tensor1: torch.Tensor, tensor2: torch.Tensor) -> torch.Tensor: 26 | """Calculate the L2 (Euclidean) distance between two tensors.""" 27 | return torch.sqrt(torch.sum((tensor1 - tensor2) ** 2, dim=1)) 28 | 29 | 30 | def calculate_cosine_similarity(tensor1: torch.Tensor, tensor2: torch.Tensor) -> torch.Tensor: 31 | """Calculate the Cosine Similarity between two tensors.""" 32 | numerator = torch.sum(tensor1 * tensor2, dim=1) 33 | denominator = torch.sqrt(torch.sum(tensor1 ** 2, dim=1)) * torch.sqrt(torch.sum(tensor2 ** 2, dim=1)) 34 | return numerator / denominator 35 | 36 | 37 | def get_neighboring_patch(tensor: torch.Tensor, center: tuple, radius: int) -> torch.Tensor: 38 | """Get a neighboring patch from a tensor centered at a specific point.""" 39 | r1, r2 = int(center[0]) - radius, int(center[0]) + radius + 1 40 | c1, c2 = int(center[1]) - radius, int(center[1]) + radius + 1 41 | return tensor[:, :, r1:r2, c1:c2] 42 | 43 | 44 | def update_handle_points(handle_points: torch.Tensor, all_dist: torch.Tensor, r2) -> torch.Tensor: 45 | """Update handle points based on computed distances.""" 46 | row, col = divmod(all_dist.argmin().item(), all_dist.shape[-1]) 47 | updated_point = torch.tensor([ 48 | handle_points[0] - r2 + row, 49 | handle_points[1] - r2 + col 50 | ]) 51 | return updated_point 52 | 53 | 54 | def point_tracking(F0: torch.Tensor, F1: torch.Tensor, handle_points: List[torch.Tensor], 55 | handle_points_init: List[torch.Tensor], r2, distance_type: str = 'l1') -> List[torch.Tensor]: 56 | """Track points between F0 and F1 tensors.""" 57 | with torch.no_grad(): 58 | for i in range(len(handle_points)): 59 | pi0, pi = handle_points_init[i], handle_points[i] 60 | f0 = F0[:, :, int(pi0[0]), int(pi0[1])] 61 | f0_expanded = f0.unsqueeze(dim=-1).unsqueeze(dim=-1) 62 | 63 | F1_neighbor = get_neighboring_patch(F1, pi, r2) 64 | 65 | # Switch case for different distance functions 66 | if distance_type == 'l1': 67 | all_dist = calculate_l1_distance(f0_expanded, F1_neighbor) 68 | elif distance_type == 'l2': 69 | all_dist = calculate_l2_distance(f0_expanded, F1_neighbor) 70 | elif distance_type == 'cosine': 71 | all_dist = -calculate_cosine_similarity(f0_expanded, F1_neighbor) # Negative for minimization 72 | 73 | all_dist = all_dist.squeeze(dim=0) 74 | handle_points[i] = update_handle_points(pi, all_dist, r2) 75 | 76 | return handle_points 77 | 78 | 79 | def interpolate_feature_patch(feat: torch.Tensor, y: float, x: float, r: int) -> torch.Tensor: 80 | """Obtain the bilinear interpolated feature patch.""" 81 | x0, y0 = torch.floor(x).long(), torch.floor(y).long() 82 | x1, y1 = x0 + 1, y0 + 1 83 | 84 | weights = torch.tensor([(x1 - x) * (y1 - y), (x1 - x) * (y - y0), (x - x0) * (y1 - y), (x - x0) * (y - y0)]) 85 | weights = weights.to(feat.device) 86 | 87 | patches = torch.stack([ 88 | feat[:, :, y0 - r:y0 + r + 1, x0 - r:x0 + r + 1], 89 | feat[:, :, y1 - r:y1 + r + 1, x0 - r:x0 + r + 1], 90 | feat[:, :, y0 - r:y0 + r + 1, x1 - r:x1 + r + 1], 91 | feat[:, :, y1 - r:y1 + r + 1, x1 - r:x1 + r + 1] 92 | ]) 93 | 94 | return torch.sum(weights.view(-1, 1, 1, 1, 1) * patches, dim=0) 95 | 96 | 97 | def 
check_handle_reach_target(handle_points: list, target_points: list) -> bool: 98 | """Check if handle points are close to target points.""" 99 | all_dists = torch.tensor([(p - q).norm().item() for p, q in zip(handle_points, target_points)]) 100 | return (all_dists < 2.0).all().item() 101 | -------------------------------------------------------------------------------- /utils/lora_utils.py: -------------------------------------------------------------------------------- 1 | # ************************************************************************* 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ************************************************************************* 14 | 15 | 16 | from PIL import Image 17 | import os 18 | import numpy as np 19 | from einops import rearrange 20 | import torch 21 | import torch.nn.functional as F 22 | from torchvision import transforms 23 | from accelerate import Accelerator 24 | from accelerate.utils import set_seed 25 | from PIL import Image 26 | from tqdm import tqdm 27 | 28 | from transformers import AutoTokenizer, PretrainedConfig 29 | 30 | import diffusers 31 | from diffusers import ( 32 | AutoencoderKL, 33 | DDPMScheduler, 34 | DiffusionPipeline, 35 | DPMSolverMultistepScheduler, 36 | StableDiffusionPipeline, 37 | UNet2DConditionModel, 38 | ) 39 | from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin 40 | from diffusers.models.attention_processor import ( 41 | AttnAddedKVProcessor, 42 | AttnAddedKVProcessor2_0, 43 | LoRAAttnAddedKVProcessor, 44 | LoRAAttnProcessor, 45 | LoRAAttnProcessor2_0, 46 | SlicedAttnAddedKVProcessor, 47 | ) 48 | from diffusers.optimization import get_scheduler 49 | from diffusers.utils import check_min_version 50 | from diffusers.utils.import_utils import is_xformers_available 51 | from diffusers import StableDiffusionXLPipeline 52 | 53 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
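# (requirements.txt pins diffusers==0.20.0, which satisfies the 0.17.0 minimum checked below.)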
54 | check_min_version("0.17.0") 55 | 56 | 57 | def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): 58 | text_encoder_config = PretrainedConfig.from_pretrained( 59 | pretrained_model_name_or_path, 60 | subfolder="text_encoder", 61 | revision=revision, 62 | ) 63 | model_class = text_encoder_config.architectures[0] 64 | 65 | if model_class == "CLIPTextModel": 66 | from transformers import CLIPTextModel 67 | 68 | return CLIPTextModel 69 | elif model_class == "RobertaSeriesModelWithTransformation": 70 | from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation 71 | 72 | return RobertaSeriesModelWithTransformation 73 | elif model_class == "T5EncoderModel": 74 | from transformers import T5EncoderModel 75 | 76 | return T5EncoderModel 77 | else: 78 | raise ValueError(f"{model_class} is not supported.") 79 | 80 | 81 | def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None): 82 | if tokenizer_max_length is not None: 83 | max_length = tokenizer_max_length 84 | else: 85 | max_length = tokenizer.model_max_length 86 | 87 | text_inputs = tokenizer( 88 | prompt, 89 | truncation=True, 90 | padding="max_length", 91 | max_length=max_length, 92 | return_tensors="pt", 93 | ) 94 | 95 | return text_inputs 96 | 97 | 98 | def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=False): 99 | text_input_ids = input_ids.to(text_encoder.device) 100 | 101 | if text_encoder_use_attention_mask: 102 | attention_mask = attention_mask.to(text_encoder.device) 103 | else: 104 | attention_mask = None 105 | 106 | prompt_embeds = text_encoder( 107 | text_input_ids, 108 | attention_mask=attention_mask, 109 | ) 110 | prompt_embeds = prompt_embeds[0] 111 | 112 | return prompt_embeds 113 | 114 | 115 | # model_path: path of the model 116 | # image: input image, have not been pre-processed 117 | # save_lora_path: the path to save the lora 118 | # prompt: the user input prompt 119 | # lora_step: number of lora training step 120 | # lora_lr: learning rate of lora training 121 | # lora_rank: the rank of lora 122 | # save_interval: the frequency of saving lora checkpoints 123 | def train_lora(image, 124 | prompt, 125 | model_path, 126 | vae_path, 127 | save_lora_path, 128 | lora_step, 129 | lora_lr, 130 | lora_batch_size, 131 | lora_rank, 132 | progress, 133 | use_gradio_progress=True, 134 | save_interval=-1): 135 | # initialize accelerator 136 | accelerator = Accelerator( 137 | gradient_accumulation_steps=1, 138 | mixed_precision='fp16' 139 | ) 140 | set_seed(0) 141 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 142 | 143 | is_sdxl = 'xl' in model_path 144 | if is_sdxl: 145 | model = StableDiffusionXLPipeline.from_pretrained(model_path).to(device) 146 | tokenizer = model.tokenizer 147 | noise_scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler") 148 | text_encoder = model.text_encoder 149 | vae = model.vae 150 | unet = model.unet 151 | unet.config.addition_embed_type = None 152 | else: 153 | # Load the tokenizer 154 | tokenizer = AutoTokenizer.from_pretrained( 155 | model_path, 156 | subfolder="tokenizer", 157 | revision=None, 158 | use_fast=False, 159 | ) 160 | # initialize the model 161 | noise_scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler") 162 | text_encoder_cls = import_model_class_from_model_name_or_path(model_path, revision=None) 163 | text_encoder = text_encoder_cls.from_pretrained( 164 | model_path, 
subfolder="text_encoder", revision=None 165 | ) 166 | if vae_path == "default": 167 | vae = AutoencoderKL.from_pretrained( 168 | model_path, subfolder="vae", revision=None 169 | ) 170 | else: 171 | vae = AutoencoderKL.from_pretrained(vae_path) 172 | 173 | unet = UNet2DConditionModel.from_pretrained( 174 | model_path, subfolder="unet", revision=None 175 | ) 176 | 177 | # set device and dtype 178 | 179 | vae.requires_grad_(False) 180 | text_encoder.requires_grad_(False) 181 | unet.requires_grad_(False) 182 | 183 | unet.to(device, dtype=torch.float16) 184 | vae.to(device, dtype=torch.float16) 185 | text_encoder.to(device, dtype=torch.float16) 186 | 187 | # initialize UNet LoRA 188 | unet_lora_attn_procs = {} 189 | for name, attn_processor in unet.attn_processors.items(): 190 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim 191 | if name.startswith("mid_block"): 192 | hidden_size = unet.config.block_out_channels[-1] 193 | elif name.startswith("up_blocks"): 194 | block_id = int(name[len("up_blocks.")]) 195 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] 196 | elif name.startswith("down_blocks"): 197 | block_id = int(name[len("down_blocks.")]) 198 | hidden_size = unet.config.block_out_channels[block_id] 199 | else: 200 | raise NotImplementedError("name must start with up_blocks, mid_blocks, or down_blocks") 201 | 202 | if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)): 203 | lora_attn_processor_class = LoRAAttnAddedKVProcessor 204 | else: 205 | lora_attn_processor_class = ( 206 | LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor 207 | ) 208 | unet_lora_attn_procs[name] = lora_attn_processor_class( 209 | hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank 210 | ) 211 | 212 | unet.set_attn_processor(unet_lora_attn_procs) 213 | unet_lora_layers = AttnProcsLayers(unet.attn_processors) 214 | 215 | # Optimizer creation 216 | params_to_optimize = (unet_lora_layers.parameters()) 217 | optimizer = torch.optim.AdamW( 218 | params_to_optimize, 219 | lr=lora_lr, 220 | betas=(0.9, 0.999), 221 | weight_decay=1e-2, 222 | eps=1e-08, 223 | ) 224 | 225 | lr_scheduler = get_scheduler( 226 | "constant", 227 | optimizer=optimizer, 228 | num_warmup_steps=0, 229 | num_training_steps=lora_step, 230 | num_cycles=1, 231 | power=1.0, 232 | ) 233 | 234 | # prepare accelerator 235 | unet_lora_layers = accelerator.prepare_model(unet_lora_layers) 236 | optimizer = accelerator.prepare_optimizer(optimizer) 237 | lr_scheduler = accelerator.prepare_scheduler(lr_scheduler) 238 | 239 | # initialize text embeddings 240 | with torch.no_grad(): 241 | if is_sdxl: 242 | text_embedding, _, _, _ = model.encode_prompt(prompt) 243 | text_embedding = text_embedding.repeat(lora_batch_size, 1, 1).half() 244 | else: 245 | text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None) 246 | text_embedding = encode_prompt( 247 | text_encoder, 248 | text_inputs.input_ids, 249 | text_inputs.attention_mask, 250 | text_encoder_use_attention_mask=False 251 | ) 252 | text_embedding = text_embedding.repeat(lora_batch_size, 1, 1) 253 | 254 | # initialize latent distribution 255 | image_transforms = transforms.Compose( 256 | [ 257 | transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR), 258 | transforms.RandomCrop(512), 259 | transforms.ToTensor(), 260 | transforms.Normalize([0.5], [0.5]), 261 | ] 262 | ) 263 | 264 | 
unet.train() 265 | image_batch = [] 266 | for _ in range(lora_batch_size): 267 | image_transformed = image_transforms(Image.fromarray(image)).to(device, dtype=torch.float16) 268 | image_transformed = image_transformed.unsqueeze(dim=0) 269 | image_batch.append(image_transformed) 270 | 271 | # repeat the image_transformed to enable multi-batch training 272 | image_batch = torch.cat(image_batch, dim=0) 273 | 274 | latents_dist = vae.encode(image_batch).latent_dist 275 | 276 | if use_gradio_progress: 277 | progress_bar = progress.tqdm(range(lora_step), desc="training LoRA") 278 | else: 279 | progress_bar = tqdm(range(lora_step), desc="training LoRA") 280 | for step in progress_bar: 281 | # unet.train() 282 | # image_batch = [] 283 | # for _ in range(lora_batch_size): 284 | # image_transformed = image_transforms(Image.fromarray(image)).to(device, dtype=torch.float16) 285 | # image_transformed = image_transformed.unsqueeze(dim=0) 286 | # image_batch.append(image_transformed) 287 | 288 | # repeat the image_transformed to enable multi-batch training 289 | # image_batch = torch.cat(image_batch, dim=0) 290 | 291 | # latents_dist = vae.encode(image_batch).latent_dist 292 | model_input = latents_dist.sample() * vae.config.scaling_factor 293 | # Sample noise that we'll add to the latents 294 | noise = torch.randn_like(model_input) 295 | bsz, channels, height, width = model_input.shape 296 | # Sample a random timestep for each image 297 | timesteps = torch.randint( 298 | 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device 299 | ) 300 | timesteps = timesteps.long() 301 | 302 | # Add noise to the model input according to the noise magnitude at each timestep 303 | # (this is the forward diffusion process) 304 | noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) 305 | 306 | # Predict the noise residual 307 | model_pred = unet(noisy_model_input, timesteps, text_embedding).sample 308 | 309 | # Get the target for loss depending on the prediction type 310 | if noise_scheduler.config.prediction_type == "epsilon": 311 | target = noise 312 | elif noise_scheduler.config.prediction_type == "v_prediction": 313 | target = noise_scheduler.get_velocity(model_input, noise, timesteps) 314 | else: 315 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 316 | 317 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 318 | accelerator.backward(loss) 319 | optimizer.step() 320 | lr_scheduler.step() 321 | optimizer.zero_grad() 322 | 323 | if save_interval > 0 and (step + 1) % save_interval == 0: 324 | save_lora_path_intermediate = os.path.join(save_lora_path, str(step + 1)) 325 | if not os.path.isdir(save_lora_path_intermediate): 326 | os.mkdir(save_lora_path_intermediate) 327 | # unet = unet.to(torch.float32) 328 | # unwrap_model is used to remove all special modules added when doing distributed training 329 | # so here, there is no need to call unwrap_model 330 | # unet_lora_layers = accelerator.unwrap_model(unet_lora_layers) 331 | LoraLoaderMixin.save_lora_weights( 332 | save_directory=save_lora_path_intermediate, 333 | unet_lora_layers=unet_lora_layers, 334 | text_encoder_lora_layers=None, 335 | ) 336 | # unet = unet.to(torch.float16) 337 | 338 | # save the trained lora 339 | # unet = unet.to(torch.float32) 340 | # unwrap_model is used to remove all special modules added when doing distributed training 341 | # so here, there is no need to call unwrap_model 342 | # unet_lora_layers = 
accelerator.unwrap_model(unet_lora_layers) 343 | LoraLoaderMixin.save_lora_weights( 344 | save_directory=save_lora_path, 345 | unet_lora_layers=unet_lora_layers, 346 | text_encoder_lora_layers=None, 347 | ) 348 | 349 | return 350 | -------------------------------------------------------------------------------- /utils/ui_utils.py: -------------------------------------------------------------------------------- 1 | # ************************************************************************* 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ************************************************************************* 14 | 15 | 16 | import os 17 | import shutil 18 | import json 19 | from pathlib import Path 20 | from typing import List, Tuple 21 | 22 | import cv2 23 | import numpy as np 24 | import gradio as gr 25 | from copy import deepcopy 26 | from einops import rearrange 27 | from types import SimpleNamespace 28 | 29 | import datetime 30 | import PIL 31 | from PIL import Image 32 | from PIL.ImageOps import exif_transpose 33 | import torch 34 | import torch.nn.functional as F 35 | 36 | from diffusers import DDIMScheduler, AutoencoderKL 37 | from pipeline import GoodDragger 38 | 39 | from torchvision.utils import save_image 40 | from pytorch_lightning import seed_everything 41 | 42 | from .lora_utils import train_lora 43 | 44 | 45 | # -------------- general UI functionality -------------- 46 | def clear_all(length=512): 47 | return gr.Image.update(value=None, height=length, width=length), \ 48 | gr.Image.update(value=None, height=length, width=length), \ 49 | gr.Image.update(value=None, height=length, width=length), \ 50 | [], None, None 51 | 52 | 53 | def mask_image(image, 54 | mask, 55 | color=[255, 0, 0], 56 | alpha=0.5): 57 | """ Overlay mask on image for visualization purpose. 58 | Args: 59 | image (H, W, 3) or (H, W): input image 60 | mask (H, W): mask to be overlaid 61 | color: the color of overlaid mask 62 | alpha: the transparency of the mask 63 | """ 64 | out = deepcopy(image) 65 | img = deepcopy(image) 66 | img[mask == 1] = color 67 | out = cv2.addWeighted(img, alpha, out, 1 - alpha, 0, out) 68 | return out 69 | 70 | 71 | def store_img(img, length=512): 72 | image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255. 
73 | height, width, _ = image.shape 74 | image = Image.fromarray(image) 75 | image = exif_transpose(image) 76 | image = image.resize((length, int(length * height / width)), PIL.Image.BILINEAR) 77 | mask = cv2.resize(mask, (length, int(length * height / width)), interpolation=cv2.INTER_NEAREST) 78 | image = np.array(image) 79 | 80 | if mask.sum() > 0: 81 | mask = np.uint8(mask > 0) 82 | masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3) 83 | else: 84 | masked_img = image.copy() 85 | # when new image is uploaded, `selected_points` should be empty 86 | return image, [], masked_img, mask 87 | 88 | 89 | # user clicks the image to select points; show the selected points on the image 90 | def get_points(img, 91 | sel_pix, 92 | evt: gr.SelectData): 93 | # collect the selected point 94 | sel_pix.append(evt.index) 95 | # draw points 96 | points = [] 97 | for idx, point in enumerate(sel_pix): 98 | if idx % 2 == 0: 99 | # draw a red circle at the handle point 100 | cv2.circle(img, tuple(point), 5, (255, 0, 0), -1) 101 | else: 102 | # draw a blue circle at the target point 103 | cv2.circle(img, tuple(point), 5, (0, 0, 255), -1) 104 | points.append(tuple(point)) 105 | # draw an arrow from handle point to target point 106 | if len(points) == 2: 107 | cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 4, tipLength=0.5) 108 | points = [] 109 | return img if isinstance(img, np.ndarray) else np.array(img) 110 | 111 | 112 | def show_cur_points(img, 113 | sel_pix, 114 | bgr=False): 115 | # draw points 116 | points = [] 117 | for idx, point in enumerate(sel_pix): 118 | if idx % 2 == 0: 119 | # draw a red circle at the handle point 120 | red = (255, 0, 0) if not bgr else (0, 0, 255) 121 | cv2.circle(img, tuple(point), 5, red, -1) 122 | else: 123 | # draw a blue circle at the target point 124 | blue = (0, 0, 255) if not bgr else (255, 0, 0) 125 | cv2.circle(img, tuple(point), 5, blue, -1) 126 | points.append(tuple(point)) 127 | # draw an arrow from handle point to target point 128 | if len(points) == 2: 129 | cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 4, tipLength=0.5) 130 | points = [] 131 | return img if isinstance(img, np.ndarray) else np.array(img) 132 | 133 | 134 | # clear all handle/target points 135 | def undo_points(original_image, 136 | mask): 137 | if mask.sum() > 0: 138 | mask = np.uint8(mask > 0) 139 | masked_img = mask_image(original_image, 1 - mask, color=[0, 0, 0], alpha=0.3) 140 | else: 141 | masked_img = original_image.copy() 142 | return masked_img, [] 143 | 144 | 145 | def clear_folder(folder_path): 146 | # Check if the folder exists 147 | if os.path.exists(folder_path): 148 | # Iterate over all the files and directories in the folder 149 | for filename in os.listdir(folder_path): 150 | file_path = os.path.join(folder_path, filename) 151 | 152 | # Check if it's a file or a directory 153 | if os.path.isfile(file_path) or os.path.islink(file_path): 154 | os.unlink(file_path) # Remove the file or link 155 | elif os.path.isdir(file_path): 156 | shutil.rmtree(file_path) # Remove the directory and all its contents 157 | 158 | 159 | def train_lora_interface(original_image, 160 | prompt, 161 | model_path, 162 | vae_path, 163 | lora_path, 164 | lora_step, 165 | lora_lr, 166 | lora_batch_size, 167 | lora_rank, 168 | progress=gr.Progress(), 169 | use_gradio_progress=True): 170 | if not os.path.exists(lora_path): 171 | os.makedirs(lora_path) 172 | 173 | clear_folder(lora_path) 174 | 175 | train_lora( 176 | original_image, 177 | prompt, 178 | model_path, 179 | vae_path, 
180 | lora_path, 181 | lora_step, 182 | lora_lr, 183 | lora_batch_size, 184 | lora_rank, 185 | progress, 186 | use_gradio_progress) 187 | return "Training LoRA Done!" 188 | 189 | 190 | def preprocess_image(image, 191 | device): 192 | image = torch.from_numpy(image).float() / 127.5 - 1 # [-1, 1] 193 | image = rearrange(image, "h w c -> 1 c h w") 194 | image = image.to(device) 195 | return image 196 | 197 | 198 | def save_images_with_pillow(images, base_filename='image'): 199 | for index, img in enumerate(images): 200 | # Convert array to Image object and save 201 | img_pil = Image.fromarray(img) 202 | folder_path = './save' 203 | filename = os.path.join(folder_path, "{}_{}.png".format(base_filename, index)) 204 | img_pil.save(filename) 205 | print(f"Saved: {filename}") 206 | 207 | 208 | def get_original_points(handle_points: List[torch.Tensor], 209 | full_h: int, 210 | full_w: int, 211 | sup_res_w, 212 | sup_res_h, 213 | ) -> List[torch.Tensor]: 214 | """ 215 | Convert handle points from the internal tracking coordinate space back to their original UI coordinates. 216 | 217 | Args: 218 | handle_points: List of handle points in internal (sup_res) coordinates. 219 | full_h: Original height of the UI canvas. 220 | full_w: Original width of the UI canvas. 221 | sup_res_w: Width of the internal tracking coordinate space (typically half the canvas width). 222 | sup_res_h: Height of the internal tracking coordinate space (typically half the canvas height). 223 | 224 | Returns: 225 | original_handle_points: List of handle points in original UI coordinates. 226 | """ 227 | original_handle_points = [] 228 | 229 | for cur_point in handle_points: 230 | original_point = torch.round( 231 | torch.tensor([cur_point[1] * full_w / sup_res_w, cur_point[0] * full_h / sup_res_h])) 232 | original_handle_points.append(original_point) 233 | 234 | return original_handle_points 235 | 236 | 237 | def save_image_mask_points(mask, points, image_with_points, output_dir='./saved_data'): 238 | """ 239 | Saves the mask and points to the specified directory. 240 | 241 | Args: 242 | mask: The mask data as a numpy array. 243 | points: The list of points collected from the user interaction. 244 | image_with_points: The image with points clicked by the user. 245 | output_dir: The directory in which to save the data. 
246 | """ 247 | os.makedirs(output_dir, exist_ok=True) 248 | 249 | # Save mask 250 | mask_path = os.path.join(output_dir, f"mask.png") 251 | Image.fromarray(mask.astype(np.uint8) * 255).save(mask_path) 252 | 253 | # Save points 254 | points_path = os.path.join(output_dir, f"points.json") 255 | with open(points_path, 'w') as f: 256 | json.dump({'points': points}, f) 257 | 258 | image_with_points_path = os.path.join(output_dir, "image_with_points.jpg") 259 | Image.fromarray(image_with_points).save(image_with_points_path) 260 | 261 | return 262 | 263 | 264 | def save_drag_result(output_image, new_points, result_path): 265 | output_image = cv2.cvtColor(output_image, cv2.COLOR_RGB2BGR) 266 | 267 | result_dir = f'{result_path}' 268 | os.makedirs(result_dir, exist_ok=True) 269 | output_image_path = os.path.join(result_dir, 'output_image.png') 270 | cv2.imwrite(output_image_path, output_image) 271 | 272 | img_with_new_points = show_cur_points(np.ascontiguousarray(output_image), new_points, bgr=True) 273 | new_points_image_path = os.path.join(result_dir, 'image_with_new_points.png') 274 | cv2.imwrite(new_points_image_path, img_with_new_points) 275 | 276 | points_path = os.path.join(result_dir, f'new_points.json') 277 | with open(points_path, 'w') as f: 278 | json.dump({'points': new_points}, f) 279 | 280 | 281 | def save_intermediate_images(intermediate_images, result_dir): 282 | for i in range(len(intermediate_images)): 283 | intermediate_images[i] = cv2.cvtColor(intermediate_images[i], cv2.COLOR_RGB2BGR) 284 | intermediate_images_path = os.path.join(result_dir, f'output_image_{i}.png') 285 | cv2.imwrite(intermediate_images_path, intermediate_images[i]) 286 | 287 | 288 | def create_video(image_folder, data_folder, fps=2, first_frame_duration=2, last_frame_extra_duration=2): 289 | """ 290 | Creates an MP4 video from a sequence of images using OpenCV. 
291 | """ 292 | img_folder = Path(image_folder) 293 | img_num = len(list(img_folder.glob('*.png'))) 294 | 295 | # Path to the original image with points 296 | data_folder = Path(data_folder) 297 | original_path = data_folder / 'image_with_points.jpg' 298 | output_path = img_folder / 'dragging.mp4' 299 | # Collect all image paths 300 | img_files = [original_path] 301 | 302 | # Load the first image to determine the size 303 | frame = cv2.imread(str(img_files[0])) 304 | height, width, layers = frame.shape 305 | size = (int(width), int(height)) 306 | 307 | # Define the codec and create VideoWriter object 308 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') # 'mp4v' for .mp4 format 309 | video = cv2.VideoWriter(str(output_path), fourcc, int(fps), size) 310 | 311 | for _ in range(int(fps * first_frame_duration)): 312 | video.write(frame) 313 | 314 | # Add images to video 315 | for i in range(img_num - 2): 316 | video.write(cv2.imread(str(img_folder / f'output_image_{i}.png'))) 317 | 318 | last_frame = cv2.imread(str(img_folder / 'output_image.png')) 319 | for _ in range(int(fps * last_frame_extra_duration)): 320 | video.write(last_frame) 321 | 322 | video.release() 323 | 324 | 325 | def run_gooddrag(source_image, 326 | image_with_clicks, 327 | mask, 328 | prompt, 329 | points, 330 | inversion_strength, 331 | lam, 332 | latent_lr, 333 | model_path, 334 | vae_path, 335 | lora_path, 336 | drag_end_step, 337 | track_per_step, 338 | r1, 339 | r2, 340 | d, 341 | max_drag_per_track, 342 | max_track_no_change, 343 | feature_idx=3, 344 | result_save_path='', 345 | return_intermediate_images=True, 346 | drag_loss_threshold=0, 347 | save_intermedia=False, 348 | compare_mode=False, 349 | once_drag=False, 350 | ): 351 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 352 | height, width = source_image.shape[:2] 353 | n_inference_step = 50 354 | guidance_scale = 1.0 355 | seed = 42 356 | dragger = GoodDragger(device, model_path, prompt, height, width, inversion_strength, r1, r2, d, 357 | drag_end_step, track_per_step, lam, latent_lr, 358 | n_inference_step, guidance_scale, feature_idx, compare_mode, vae_path, lora_path, seed, 359 | max_drag_per_track, drag_loss_threshold, once_drag, max_track_no_change) 360 | 361 | source_image = preprocess_image(source_image, device) 362 | 363 | gen_image, intermediate_features, new_points_handle, intermediate_images = \ 364 | dragger.good_drag(source_image, points, 365 | mask, 366 | return_intermediate_images=return_intermediate_images) 367 | 368 | new_points_handle = get_original_points(new_points_handle, height, width, dragger.sup_res_w, dragger.sup_res_h) 369 | if save_intermedia: 370 | drag_image = [dragger.latent2image(i.cuda()) for i in intermediate_features] 371 | save_images_with_pillow(drag_image, base_filename='drag_image') 372 | 373 | gen_image = F.interpolate(gen_image, (height, width), mode='bilinear') 374 | 375 | out_image = gen_image.cpu().permute(0, 2, 3, 1).numpy()[0] 376 | out_image = (out_image * 255).astype(np.uint8) 377 | 378 | new_points = [] 379 | for i in range(len(new_points_handle)): 380 | new_cur_handle_points = new_points_handle[i].numpy().tolist() 381 | new_cur_handle_points = [int(point) for point in new_cur_handle_points] 382 | new_points.append(new_cur_handle_points) 383 | new_points.append(points[i * 2 + 1]) 384 | 385 | print(f'points {points}') 386 | print(f'new points {new_points}') 387 | 388 | if return_intermediate_images: 389 | os.makedirs(result_save_path, exist_ok=True) 390 | for i in 
range(len(intermediate_images)): 391 | intermediate_images[i] = F.interpolate(intermediate_images[i], (height, width), mode='bilinear') 392 | intermediate_images[i] = intermediate_images[i].cpu().permute(0, 2, 3, 1).numpy()[0] 393 | intermediate_images[i] = (intermediate_images[i] * 255).astype(np.uint8) 394 | 395 | for i in range(len(intermediate_images)): 396 | intermediate_images[i] = cv2.cvtColor(intermediate_images[i], cv2.COLOR_RGB2BGR) 397 | intermediate_images_path = os.path.join(result_save_path, f'output_image_{i}.png') 398 | cv2.imwrite(intermediate_images_path, intermediate_images[i]) 399 | 400 | return out_image, new_points 401 | -------------------------------------------------------------------------------- /webui.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | REM Temporary file for modified requirements 4 | set TEMP_REQ_FILE=temp_requirements.txt 5 | 6 | REM Detect CUDA version using nvcc 7 | for /f "delims=" %%i in ('nvcc --version ^| findstr /i "release"') do set "CUDA_VER_FULL=%%i" 8 | 9 | REM Extract the version number, assuming it's in the format "Cuda compilation tools, release 11.8, V11.8.89" 10 | for /f "tokens=5 delims=, " %%a in ("%CUDA_VER_FULL%") do set "CUDA_VER=%%a" 11 | for /f "tokens=1 delims=vV" %%b in ("%CUDA_VER%") do set "CUDA_VER=%%b" 12 | for /f "tokens=1-2 delims=." %%c in ("%CUDA_VER%") do set "CUDA_MAJOR=%%c" & set "CUDA_MINOR=%%d" 13 | 14 | REM Concatenate major and minor version numbers to form the CUDA tag 15 | set "CUDA_TAG=cu%CUDA_MAJOR%%CUDA_MINOR%" 16 | 17 | echo Detected CUDA Tag: %CUDA_TAG% 18 | 19 | REM Modify the torch and torchvision lines in requirements.txt to include the CUDA version 20 | for /F "tokens=*" %%A in (requirements.txt) do ( 21 | echo %%A | findstr /I "torch==" >nul 22 | if errorlevel 1 ( 23 | echo %%A | findstr /I "torchvision==" >nul 24 | if errorlevel 1 ( 25 | echo %%A >> "%TEMP_REQ_FILE%" 26 | ) else ( 27 | echo torchvision==0.15.2+%CUDA_TAG% >> "%TEMP_REQ_FILE%" 28 | ) 29 | ) else ( 30 | echo torch==2.0.1+%CUDA_TAG%>> "%TEMP_REQ_FILE%" 31 | ) 32 | ) 33 | 34 | REM Replace the original requirements file with the modified one 35 | move /Y "%TEMP_REQ_FILE%" requirements.txt 36 | 37 | 38 | REM Define the virtual environment directory 39 | set VENV_DIR=GoodDrag 40 | 41 | REM Check if Python is installed 42 | python --version > nul 2>&1 43 | if %errorlevel% neq 0 ( 44 | echo Python is not installed. Please install it first. 45 | pause 46 | exit /b 47 | ) 48 | 49 | 50 | 51 | REM Create a virtual environment if it doesn't exist 52 | if not exist "%VENV_DIR%" ( 53 | echo Creating virtual environment... 54 | python -m venv %VENV_DIR% 55 | ) 56 | 57 | REM Activate the virtual environment 58 | call %VENV_DIR%\Scripts\activate.bat 59 | 60 | REM Install dependencies (uncomment and modify the next line if you have any dependencies) 61 | pip install -r requirements.txt 62 | 63 | REM Run the Python script 64 | echo Starting gooddrag_ui.py... 65 | python gooddrag_ui.py 66 | 67 | REM Deactivate the virtual environment on script exit 68 | call %VENV_DIR%\Scripts\deactivate.bat 69 | 70 | echo Script finished. Press any key to exit. 
71 | pause > nul 72 | -------------------------------------------------------------------------------- /webui.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Temporary file for modified requirements 4 | TEMP_REQ_FILE="temp_requirements.txt" 5 | 6 | # Detect CUDA version using nvcc 7 | CUDA_VER_FULL=$(nvcc --version | grep "release" | sed -n 's/.*release \([0-9]*\)\.\([0-9]*\),.*/\1\2/p') 8 | 9 | # Set the CUDA tag 10 | CUDA_TAG="cu$CUDA_VER_FULL" 11 | 12 | echo "Detected CUDA Tag: $CUDA_TAG" 13 | 14 | # Modify the torch and torchvision lines in requirements.txt to include the CUDA version 15 | while IFS= read -r line; do 16 | if [[ "$line" == torch==* ]]; then 17 | echo "torch==2.0.1+$CUDA_TAG" >> "$TEMP_REQ_FILE" 18 | elif [[ "$line" == torchvision==* ]]; then 19 | echo "torchvision==0.15.2+$CUDA_TAG" >> "$TEMP_REQ_FILE" 20 | else 21 | echo "$line" >> "$TEMP_REQ_FILE" 22 | fi 23 | done < requirements.txt 24 | 25 | # Replace the original requirements file with the modified one 26 | mv "$TEMP_REQ_FILE" requirements.txt 27 | 28 | # Define the virtual environment directory 29 | VENV_DIR="GoodDrag" 30 | 31 | # Check if Python 3 is installed 32 | if ! command -v python3 &> /dev/null; then 33 | echo "Python 3 is not installed. Please install it first." 34 | exit 1 35 | fi 36 | 37 | # Create a virtual environment if it doesn't exist 38 | if [ ! -d "$VENV_DIR" ]; then 39 | echo "Creating virtual environment..." 40 | python3 -m venv "$VENV_DIR" 41 | fi 42 | 43 | # Activate the virtual environment 44 | source "$VENV_DIR/bin/activate" 45 | 46 | # Install dependencies 47 | echo "Installing dependencies..." 48 | pip install -r requirements.txt 49 | 50 | # Run the Python script 51 | echo "Starting gooddrag_ui.py..." 52 | python3 gooddrag_ui.py 53 | 54 | # Deactivate the virtual environment on script exit 55 | deactivate 56 | 57 | echo "Script finished." --------------------------------------------------------------------------------
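Note on programmatic use: the launcher scripts above only install dependencies and start the Gradio UI (gooddrag_ui.py). The sketch below shows how the same utilities could be driven directly from Python, first fitting a LoRA with train_lora_interface and then running the drag edit with run_gooddrag, both defined in utils/ui_utils.py. It is a minimal illustration, not part of the repository: the base model id, prompt, point coordinates, and all hyperparameter values are assumptions chosen for demonstration, and it should be run from the repository root so that pipeline.py and the utils package are importable.

# --- programmatic usage sketch (illustrative; all values below are assumptions) ---
import numpy as np
from PIL import Image

from utils.ui_utils import train_lora_interface, run_gooddrag

# Load an input image and its editable-region mask from the bundled dataset folder.
image = np.array(Image.open("dataset/cat_2/original.jpg").convert("RGB"))
mask = np.array(Image.open("dataset/cat_2/mask.png"))
if mask.ndim == 3:                      # collapse RGB masks to a single channel
    mask = mask[..., 0]
mask = (mask > 0).astype(np.uint8)      # 1 = region allowed to change

# Points alternate handle, target, handle, target, ... in (x, y) pixel coordinates.
points = [[220, 180], [260, 180]]       # assumed coordinates, for illustration only

prompt = "a photo of a cat"             # assumed prompt
model_path = "runwayml/stable-diffusion-v1-5"  # assumed base model
lora_path = "./lora_tmp"

# 1) Fit a LoRA on the single input image (the same call the UI's "Train LoRA" button makes).
#    The step count, learning rate, batch size, and rank here are assumed values.
train_lora_interface(image, prompt, model_path, "default", lora_path,
                     80, 5e-4, 4, 4,
                     progress=None, use_gradio_progress=False)

# 2) Run the drag edit. The scalar settings are assumptions in the spirit of the UI defaults,
#    not values taken from this repository.
out_image, new_points = run_gooddrag(
    source_image=image, image_with_clicks=image.copy(), mask=mask,
    prompt=prompt, points=points,
    inversion_strength=0.75, lam=0.2, latent_lr=0.02,
    model_path=model_path, vae_path="default", lora_path=lora_path,
    drag_end_step=7, track_per_step=10, r1=4, r2=12, d=4,
    max_drag_per_track=3, max_track_no_change=5,
    return_intermediate_images=False)

Image.fromarray(out_image).save("gooddrag_output.png")
print("updated points:", new_points)

If intermediate frames are wanted, passing return_intermediate_images=True together with a result_save_path makes run_gooddrag write output_image_{i}.png files, which create_video in utils/ui_utils.py can later assemble into dragging.mp4.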