├── LICENSE
├── README.md
├── assets
│   ├── GoodDrag_demo.gif
│   ├── cat_2
│   │   ├── image_with_new_points.png
│   │   ├── image_with_points.jpg
│   │   ├── original.jpg
│   │   └── trajectory.gif
│   ├── chess_1
│   │   ├── image_with_new_points.png
│   │   ├── image_with_points.jpg
│   │   ├── original.jpg
│   │   └── trajectory.gif
│   ├── furniture_0
│   │   ├── image_with_new_points.png
│   │   ├── image_with_points.jpg
│   │   ├── original.jpg
│   │   └── trajectory.gif
│   ├── gooddrag_icon.png
│   ├── human_6
│   │   ├── image_with_new_points.png
│   │   ├── image_with_points.jpg
│   │   ├── original.jpg
│   │   └── trajectory.gif
│   ├── leopard
│   │   ├── image_with_new_points.png
│   │   ├── image_with_points.jpg
│   │   ├── original.jpg
│   │   └── trajectory.gif
│   └── rabbit
│       ├── image_with_new_points.png
│       ├── image_with_points.jpg
│       ├── original.jpg
│       └── trajectory.gif
├── bench_gooddrag.py
├── dataset
│   ├── cat_2
│   │   ├── image_with_points.jpg
│   │   ├── mask.png
│   │   ├── original.jpg
│   │   └── points.json
│   └── furniture_0
│       ├── image_with_points.jpg
│       ├── mask.png
│       ├── original.jpg
│       └── points.json
├── environment.yaml
├── evaluation
│   ├── GScore.ipynb
│   └── compute_DAI.py
├── gooddrag_ui.py
├── pipeline.py
├── requirements.txt
├── utils
│   ├── __pycache__
│   │   ├── attn_utils.cpython-310.pyc
│   │   ├── drag_utils.cpython-310.pyc
│   │   ├── lora_utils.cpython-310.pyc
│   │   └── ui_utils.cpython-310.pyc
│   ├── attn_utils.py
│   ├── drag_utils.py
│   ├── lora_utils.py
│   └── ui_utils.py
├── webui.bat
└── webui.sh
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # GoodDrag: Towards Good Practices for Drag Editing with Diffusion Models
3 |
4 | Zewei Zhang, Huan Liu, Jun Chen, Xiangyu Xu
5 |
69 | ## 📢 Latest Updates
70 | - **2024.04.17** - Updated the DAI (Dragging Accuracy Index) and GScore (Gemini Score) evaluation methods; please check the `evaluation` folder (a brief DAI sketch follows this list). GScore is modified from [Generative AI](https://github.com/GoogleCloudPlatform/generative-ai/tree/main).
71 | - **2025.03.01** - Our paper is accepted at ICLR 2025!
72 |
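DAI is computed by `evaluation/compute_DAI.py` (run it from the repository root; by default it expects drag results under `./bench_result` and the originals plus `points.json` under `./dataset`), while GScore is produced by the `evaluation/GScore.ipynb` notebook via Vertex AI's Gemini model. As a rough illustration of the DAI measurement only, the sketch below distills the script's `get_patch`/`calculate_difference`/`cal_patch_size` logic for a single source/target pair; the actual script additionally averages over all point pairs and over several patch radii.

```python
# Illustrative sketch of the per-point DAI measurement in evaluation/compute_DAI.py:
# compare the patch around a source point in the original image with the patch
# around its target point in the drag result, normalised by the patch area.
import numpy as np

def dai_for_pair(original, result, start, target, radius):
    # The script first converts images to float in [0, 1] (transforms.ToTensor());
    # mimic that here, assuming uint8 inputs.
    original = original.astype(np.float32) / 255.0
    result = result.astype(np.float32) / 255.0
    (x0, y0), (x1, y1) = start, target
    patch_src = original[y0 - radius:y0 + radius + 1, x0 - radius:x0 + radius + 1]
    patch_dst = result[y1 - radius:y1 + radius + 1, x1 - radius:x1 + radius + 1]
    # Summed squared difference divided by the (2r + 1)^2 patch area.
    return float(np.sum(np.square(patch_src - patch_dst))) / (2 * radius + 1) ** 2
```
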
73 | ## 1. Getting Started with GoodDrag
74 |
75 | Before getting started, please make sure your system is equipped with a CUDA-compatible GPU and Python 3.9 or higher. We provide three ways to run GoodDrag directly:
76 | ### 1️⃣ Automated Script for Effortless Setup
77 |
78 | - **Windows Users:** Double-click **webui.bat** to automatically set up your environment and launch the GoodDrag web UI.
79 | - **Linux Users:** Run **webui.sh** for a similar one-step setup and launch process.
80 |
81 | ### 2️⃣ Manual Installation via pip
82 | 1. Install the necessary dependencies:
83 |
84 | ```bash
85 | pip install -r requirements.txt
86 | ```
87 | 2. Launch the GoodDrag web UI:
88 |
89 | ```bash
90 | python gooddrag_ui.py
91 | ```
92 |
93 | ### 3️⃣ Quick Start with Colab
94 | For a quick and easy start, access GoodDrag directly through Google Colab. Click the badge below to open a pre-configured notebook that will guide you through using GoodDrag in the Colab environment:
95 |
96 | ### Runtime and Memory Requirements
97 | GoodDrag's runtime depends on the image size and editing complexity. For a 512×512 image on an A100 GPU, the LoRA phase takes about 17 seconds and drag editing takes around 1 minute; GPU memory usage stays below 13 GB.
98 |
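To run the full pipeline (LoRA training followed by dragging) over every example in a dataset folder, the repository also provides `bench_gooddrag.py`. The bundled `./dataset` folder already has the expected layout (`original.jpg`, `mask.png`, `points.json`, and `image_with_points.jpg` per example), and running `python bench_gooddrag.py ./dataset` is equivalent to the minimal sketch below.

```python
# Equivalent to: python bench_gooddrag.py ./dataset
# For each ./dataset/<name>/ this trains a LoRA, runs the drag, and writes
# output_image.png, image_with_new_points.png and new_points.json to ./bench_result/<name>/.
from bench_gooddrag import benchmark_dataset

benchmark_dataset('./dataset')
```
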
99 | ## 2. Parameter Description
100 |
101 | We have predefined a set of parameters in the GoodDrag WebUI. Here are a few that you might consider adjusting:
102 | | Parameter Name | Description |
103 | | -------------- | ------------------------------------------------------------ |
104 | | Learning Rate | Influences the speed of drag editing. Higher values lead to faster editing but may result in lower quality or instability. It is recommended to keep this value below 0.05. |
105 | | Prompt | The text prompt for the diffusion model. It is suggested to leave this empty. |
106 | | End time step | Specifies how many denoising time steps of the diffusion model are used for drag editing. If good results appear early in the generated video, consider reducing this value; conversely, if the drag editing is insufficient, increase it slightly. It is recommended to keep this value below 12. |
107 | | Lambda | Controls the consistency of the non-dragged regions with the original image. A higher value keeps the area outside the mask more in line with the original image. |
108 |
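For programmatic use, the table above maps onto arguments of `run_gooddrag` in `utils/ui_utils.py`: Learning Rate is `latent_lr`, Prompt is `prompt`, End time step is `drag_end_step`, and Lambda is `lam`. The sketch below mirrors the call in `bench_gooddrag.py` on the bundled `dataset/cat_2` example; the LoRA and result paths are illustrative, a LoRA must be trained first (via the **Train LoRA** button or `train_lora_interface`, as `bench_gooddrag.py` does), and the remaining arguments roughly correspond to the Advanced Parameters tab of the WebUI.

```python
# Sketch mirroring the run_gooddrag call in bench_gooddrag.py on the bundled cat_2 example.
# Assumes a LoRA has already been trained at lora_path.
import json
import os

import numpy as np
from PIL import Image

from utils.ui_utils import run_gooddrag

example = 'dataset/cat_2'
original_image = np.array(Image.open(f'{example}/original.jpg'))
image_with_points = np.array(Image.open(f'{example}/image_with_points.jpg'))
mask = np.array(Image.open(f'{example}/mask.png'))
if mask.ndim == 3:
    mask = mask[:, :, 0]
with open(f'{example}/points.json') as f:
    points = json.load(f)['points']   # alternating source/target points (see evaluation/compute_DAI.py)

result_dir = './bench_result/cat_2'   # illustrative output path
os.makedirs(result_dir, exist_ok=True)

output_image, new_points = run_gooddrag(
    source_image=original_image,
    image_with_clicks=image_with_points,
    mask=mask,
    prompt='',                                  # Prompt (usually left empty)
    points=points,
    inversion_strength=0.75,
    lam=0.1,                                    # Lambda
    latent_lr=0.02,                             # Learning Rate
    model_path='runwayml/stable-diffusion-v1-5',
    vae_path='stabilityai/sd-vae-ft-mse',
    lora_path='./lora_data/cat_2',              # path of the previously trained LoRA
    drag_end_step=7,                            # End time step
    track_per_step=10,
    save_intermedia=False,
    compare_mode=False,
    r1=4, r2=12, d=4,
    max_drag_per_track=3,
    drag_loss_threshold=0,
    once_drag=False,
    max_track_no_change=5,
    return_intermediate_images=False,
    result_save_path=result_dir
)
```
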
109 | ## 3. Acknowledgments
110 | Part of the code was based on [DragDiffusion](https://github.com/Yujun-Shi/DragDiffusion) and [DragGAN](https://github.com/XingangPan/DragGAN). Thanks for the great work!
111 |
112 | ## 4. BibTeX
113 | ```bibtex
114 | @inproceedings{zhang2025gooddrag,
115 | title={GoodDrag: Towards Good Practices for Drag Editing with Diffusion Models},
116 | author={Zewei Zhang and Huan Liu and Jun Chen and Xiangyu Xu},
117 | booktitle={The Thirteenth International Conference on Learning Representations},
118 | year={2025},
119 | }
120 | ```
121 |
--------------------------------------------------------------------------------
/assets/GoodDrag_demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/GoodDrag_demo.gif
--------------------------------------------------------------------------------
/assets/cat_2/image_with_new_points.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/cat_2/image_with_new_points.png
--------------------------------------------------------------------------------
/assets/cat_2/image_with_points.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/cat_2/image_with_points.jpg
--------------------------------------------------------------------------------
/assets/cat_2/original.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/cat_2/original.jpg
--------------------------------------------------------------------------------
/assets/cat_2/trajectory.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/cat_2/trajectory.gif
--------------------------------------------------------------------------------
/assets/chess_1/image_with_new_points.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/chess_1/image_with_new_points.png
--------------------------------------------------------------------------------
/assets/chess_1/image_with_points.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/chess_1/image_with_points.jpg
--------------------------------------------------------------------------------
/assets/chess_1/original.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/chess_1/original.jpg
--------------------------------------------------------------------------------
/assets/chess_1/trajectory.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/chess_1/trajectory.gif
--------------------------------------------------------------------------------
/assets/furniture_0/image_with_new_points.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/furniture_0/image_with_new_points.png
--------------------------------------------------------------------------------
/assets/furniture_0/image_with_points.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/furniture_0/image_with_points.jpg
--------------------------------------------------------------------------------
/assets/furniture_0/original.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/furniture_0/original.jpg
--------------------------------------------------------------------------------
/assets/furniture_0/trajectory.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/furniture_0/trajectory.gif
--------------------------------------------------------------------------------
/assets/gooddrag_icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/gooddrag_icon.png
--------------------------------------------------------------------------------
/assets/human_6/image_with_new_points.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/human_6/image_with_new_points.png
--------------------------------------------------------------------------------
/assets/human_6/image_with_points.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/human_6/image_with_points.jpg
--------------------------------------------------------------------------------
/assets/human_6/original.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/human_6/original.jpg
--------------------------------------------------------------------------------
/assets/human_6/trajectory.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/human_6/trajectory.gif
--------------------------------------------------------------------------------
/assets/leopard/image_with_new_points.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/leopard/image_with_new_points.png
--------------------------------------------------------------------------------
/assets/leopard/image_with_points.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/leopard/image_with_points.jpg
--------------------------------------------------------------------------------
/assets/leopard/original.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/leopard/original.jpg
--------------------------------------------------------------------------------
/assets/leopard/trajectory.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/leopard/trajectory.gif
--------------------------------------------------------------------------------
/assets/rabbit/image_with_new_points.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/rabbit/image_with_new_points.png
--------------------------------------------------------------------------------
/assets/rabbit/image_with_points.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/rabbit/image_with_points.jpg
--------------------------------------------------------------------------------
/assets/rabbit/original.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/rabbit/original.jpg
--------------------------------------------------------------------------------
/assets/rabbit/trajectory.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/assets/rabbit/trajectory.gif
--------------------------------------------------------------------------------
/bench_gooddrag.py:
--------------------------------------------------------------------------------
1 | # *************************************************************************
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # *************************************************************************
14 |
15 |
16 | import os
17 | import sys
18 | import cv2
19 | import numpy as np
20 | import json
21 | from pathlib import Path
22 | from PIL import Image
23 | from utils.ui_utils import run_gooddrag, train_lora_interface, show_cur_points, create_video
24 |
25 |
26 | def benchmark_dataset(dataset_folder):
27 | dataset_path = Path(dataset_folder)
28 | subfolders = [f for f in dataset_path.iterdir() if f.is_dir() and f.name != '.ipynb_checkpoints']
29 |
30 | for subfolder in subfolders:
31 | print(f'Benchmarking {subfolder.name}...')
32 | try:
33 | bench_one_image(subfolder)
34 | except Exception as e:
35 | print(f'An error occurred while benchmarking {subfolder.name}: {e}.')
36 |
37 |
38 | def load_data(folder):
39 | """Load the original image, mask, and points from the specified folder."""
40 | folder_path = Path(folder)
41 |
42 | # Load original image
43 | original_image_path = folder_path / 'original.jpg'
44 | original_image = Image.open(original_image_path)
45 | original_image = np.array(original_image)
46 |
47 | # Load mask
48 | mask_path = folder_path / 'mask.png'
49 | mask = Image.open(mask_path)
50 | mask = np.array(mask)
51 | if len(mask.shape) == 3:
52 | mask = mask[:, :, 0]
53 |
54 | # Load points
55 | points_path = folder_path / 'points.json'
56 | with open(points_path, 'r') as f:
57 | points_data = json.load(f)
58 | points = points_data['points']
59 |
60 | image_points_path = folder_path / 'image_with_points.jpg'
61 | image_with_points = Image.open(image_points_path)
62 | image_with_points = np.array(image_with_points)
63 |
64 | return original_image, mask, points, image_with_points
65 |
66 |
67 | def bench_one_image(folder):
68 | """
69 | Test the saved data by running the drag model.
70 |
71 | Args:
72 | folder: The folder where the original image, mask, and points are saved.
73 | """
74 | original_image, mask, points, image_with_points = load_data(folder)
75 | model_path = 'runwayml/stable-diffusion-v1-5'
76 |
77 | lora_path = f'./lora_data/{folder.parts[-1]}'
78 |
79 | print('Training LoRA.')
80 | train_lora_interface(original_image=original_image, prompt='', model_path=model_path,
81 | vae_path='stabilityai/sd-vae-ft-mse',
82 | lora_path=lora_path, lora_step=70, lora_lr=0.0005, lora_batch_size=4, lora_rank=16,
83 | use_gradio_progress=False)
84 | print('Training LoRA done! Begin dragging.')
85 |
86 | return_intermediate_images = True
87 |
88 | result_dir = f'./bench_result/{Path(folder).parts[-1]}'
89 | os.makedirs(result_dir, exist_ok=True)
90 |
91 | output_image, new_points = run_gooddrag(
92 | source_image=original_image,
93 | image_with_clicks=image_with_points,
94 | mask=mask,
95 | prompt='',
96 | points=points,
97 | inversion_strength=0.75,
98 | lam=0.1,
99 | latent_lr=0.02,
100 | model_path=model_path,
101 | vae_path='stabilityai/sd-vae-ft-mse',
102 | lora_path=lora_path,
103 | drag_end_step=7,
104 | track_per_step=10,
105 | save_intermedia=False,
106 | compare_mode=False,
107 | r1=4,
108 | r2=12,
109 | d=4,
110 | max_drag_per_track=3,
111 | drag_loss_threshold=0,
112 | once_drag=False,
113 | max_track_no_change=5,
114 | return_intermediate_images=return_intermediate_images,
115 | result_save_path=result_dir
116 | )
117 |
118 | print(f'Drag finished!')
119 | output_image = cv2.cvtColor(output_image, cv2.COLOR_RGB2BGR)
120 | output_image_path = os.path.join(result_dir, 'output_image.png')
121 | cv2.imwrite(output_image_path, output_image)
122 |
123 | img_with_new_points = show_cur_points(np.ascontiguousarray(output_image), new_points, bgr=True)
124 | new_points_image_path = os.path.join(result_dir, 'image_with_new_points.png')
125 | cv2.imwrite(new_points_image_path, img_with_new_points)
126 |
127 | points_path = os.path.join(result_dir, f'new_points.json')
128 | with open(points_path, 'w') as f:
129 | json.dump({'points': new_points}, f)
130 |
131 | if return_intermediate_images:
132 | create_video(result_dir, folder)
133 |
134 |
135 | def main(dataset_folder):
136 | benchmark_dataset(dataset_folder)
137 |
138 |
139 | if __name__ == '__main__':
140 | dataset = sys.argv[1]
141 | main(dataset)
142 |
--------------------------------------------------------------------------------
/dataset/cat_2/image_with_points.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/dataset/cat_2/image_with_points.jpg
--------------------------------------------------------------------------------
/dataset/cat_2/mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/dataset/cat_2/mask.png
--------------------------------------------------------------------------------
/dataset/cat_2/original.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/dataset/cat_2/original.jpg
--------------------------------------------------------------------------------
/dataset/cat_2/points.json:
--------------------------------------------------------------------------------
1 | {"points": [[147, 67], [202, 46], [410, 88], [336, 54]]}
--------------------------------------------------------------------------------
/dataset/furniture_0/image_with_points.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/dataset/furniture_0/image_with_points.jpg
--------------------------------------------------------------------------------
/dataset/furniture_0/mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/dataset/furniture_0/mask.png
--------------------------------------------------------------------------------
/dataset/furniture_0/original.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/dataset/furniture_0/original.jpg
--------------------------------------------------------------------------------
/dataset/furniture_0/points.json:
--------------------------------------------------------------------------------
1 | {"points": [[234, 133], [248, 71], [320, 133], [328, 70], [405, 137], [418, 67]]}
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: GoodDrag
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - python=3.8.5
7 | - pip=22.3.1
8 | - cudatoolkit=11.7
9 | - pip:
10 | - torch==2.0.0
11 | - accelerate
12 | - torchvision==0.15.1
13 | - gradio==3.50.2
14 | - pydantic==2.0.2
15 | - albumentations==1.3.0
16 | - opencv-contrib-python==4.3.0.36
17 | - imageio==2.9.0
18 | - imageio-ffmpeg==0.4.2
19 | - pytorch-lightning==1.5.0
20 | - omegaconf==2.3.0
21 | - test-tube>=0.7.5
22 | - streamlit==1.12.1
23 | - einops==0.6.0
24 | - transformers==4.27.0
25 | - webdataset==0.2.5
26 | - kornia==0.6
27 | - open_clip_torch==2.16.0
28 | - invisible-watermark>=0.1.5
29 | - streamlit-drawable-canvas==0.8.0
30 | - torchmetrics==0.6.0
31 | - timm==0.6.12
32 | - addict==2.4.0
33 | - yapf==0.32.0
34 | - prettytable==3.6.0
35 | - safetensors==0.2.7
36 | - basicsr==1.4.2
37 | - accelerate==0.17.0
38 | - decord==0.6.0
39 | - diffusers==0.20.0
40 | - moviepy==1.0.3
41 | - opencv_python==4.7.0.68
42 | - Pillow==9.4.0
43 | - scikit_image==0.19.3
44 | - scipy==1.10.1
45 | - tensorboardX==2.6
46 | - tqdm==4.64.1
47 | - numpy==1.24.1
48 | - PySoundFile
49 |
--------------------------------------------------------------------------------
/evaluation/GScore.ipynb:
--------------------------------------------------------------------------------
1 | {"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"uxCkB_DXTHzf"},"outputs":[],"source":["# Copyright 2023 Google LLC\n","#\n","# Licensed under the Apache License, Version 2.0 (the \"License\");\n","# you may not use this file except in compliance with the License.\n","# You may obtain a copy of the License at\n","#\n","# https://www.apache.org/licenses/LICENSE-2.0\n","#\n","# Unless required by applicable law or agreed to in writing, software\n","# distributed under the License is distributed on an \"AS IS\" BASIS,\n","# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n","# See the License for the specific language governing permissions and\n","# limitations under the License."]},{"cell_type":"markdown","metadata":{"id":"Hny4I-ODTIS6"},"source":["# Visual Question Answering (VQA) with Imagen on Vertex AI\n","\n","\n"," \n"," \n","  Run in Colab\n"," \n"," | \n"," \n"," \n","  View on GitHub\n"," \n"," | \n"," \n"," \n","  Open in Vertex AI Workbench\n"," \n"," | \n","
\n"]},{"cell_type":"markdown","metadata":{"id":"-nLS57E2TO5y"},"source":["## Overview\n","\n","[Imagen on Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/image/overview) (image Generative AI) offers a variety of features:\n","- Image generation\n","- Image editing\n","- Visual captioning\n","- Visual question answering\n","\n","This notebook focuses on **visual question answering** only.\n","\n","[Visual question answering (VQA) with Imagen](https://cloud.google.com/vertex-ai/docs/generative-ai/image/visual-question-answering) can understand the content of an image and answer questions about it. The model takes in an image and a question as input, and then using the image as context to produce one or more answers to the question.\n","\n","The visual question answering (VQA) can be used for a variety of use cases, including:\n","- assisting the visually impaired to gain more information about the images\n","- answering customer questions about products or services in the image\n","- creating interactive learning environment and providing interactive learning experiences"]},{"cell_type":"markdown","metadata":{"id":"iXsvgIuwTPZw"},"source":["### Objectives\n","\n","In this notebook, you will learn how to use the Vertex AI Python SDK to:\n","\n","- Answering questions about images using the Imagen's visual question answering features\n","\n","- Experiment with different parameters, such as:\n"," - number of answers to be provided by the model"]},{"cell_type":"markdown","metadata":{"id":"skXAu__iqks_"},"source":["### Costs\n","\n","This tutorial uses billable components of Google Cloud:\n","- Vertex AI (Imagen)\n","\n","Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing) and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage."]},{"cell_type":"markdown","metadata":{"id":"mvKl-BtQTRiQ"},"source":["## Getting Started"]},{"cell_type":"markdown","metadata":{"id":"PwFMpIMrTV_4"},"source":["### Install Vertex AI SDK, other packages and their dependencies"]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":21358,"status":"ok","timestamp":1704745393930,"user":{"displayName":"Zhang Zewei","userId":"07162348264595431723"},"user_tz":300},"id":"WYUu8VMdJs3V","colab":{"base_uri":"https://localhost:8080/"},"outputId":"16a31034-4e81-4920-912b-d2065e9eb9a3"},"outputs":[{"output_type":"stream","name":"stdout","text":["\u001b[33m WARNING: The script tb-gcp-uploader is installed in '/root/.local/bin' which is not on PATH.\n"," Consider adding this directory to PATH or, if you prefer to suppress this warning, use --no-warn-script-location.\u001b[0m\u001b[33m\n","\u001b[0m"]}],"source":["%pip install --upgrade --user google-cloud-aiplatform>=1.29.0"]},{"cell_type":"markdown","metadata":{"id":"R5Xep4W9lq-Z"},"source":["### Restart current runtime\n","\n","To use the newly installed packages in this Jupyter runtime, you must restart the runtime. 
You can do this by running the cell below, which will restart the current kernel."]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":3,"status":"ok","timestamp":1704745393930,"user":{"displayName":"Zhang Zewei","userId":"07162348264595431723"},"user_tz":300},"id":"XRvKdaPDTznN","outputId":"88bab8b5-a4f0-4462-bd3b-9e5b88bb59be"},"outputs":[{"output_type":"execute_result","data":{"text/plain":["{'status': 'ok', 'restart': True}"]},"metadata":{},"execution_count":2}],"source":["# Restart kernel after installs so that your environment can access the new packages\n","import IPython\n","import time\n","\n","app = IPython.Application.instance()\n","app.kernel.do_shutdown(True)"]},{"cell_type":"markdown","metadata":{"id":"SbmM4z7FOBpM"},"source":["\n","⚠️ The kernel is going to restart. Please wait until it is finished before continuing to the next step. ⚠️\n","
\n","\n"]},{"cell_type":"markdown","metadata":{"id":"opUxT_k5TdgP"},"source":["### Authenticate your notebook environment (Colab only)\n","\n","If you are running this notebook on Google Colab, you will need to authenticate your environment. To do this, run the new cell below. This step is not required if you are using [Vertex AI Workbench](https://cloud.google.com/vertex-ai-workbench)."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"vbNgv4q1T2Mi"},"outputs":[],"source":["import sys\n","\n","if 'google.colab' in sys.modules:\n","\n"," # Authenticate user to Google Cloud\n"," from google.colab import auth\n"," auth.authenticate_user()"]},{"cell_type":"markdown","metadata":{"id":"ybBXSukZkgjg"},"source":["### Define Google Cloud project information (Colab only)\n","\n","If you are running this notebook on Google Colab, you need to define Google Cloud project information to be used. In the following cell, you will define the information, import Vertex AI package, and initialize it. This step is also not required if you are using [Vertex AI Workbench](https://cloud.google.com/vertex-ai-workbench)."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"MkVX07nOC90T"},"outputs":[],"source":["if 'google.colab' in sys.modules:\n","\n"," # Define project information\n"," PROJECT_ID = \"\" # @param {type:\"string\"}\n"," LOCATION = \"\" # @param {type:\"string\"}\n","\n"," # Initialize Vertex AI\n"," import vertexai\n"," vertexai.init(project=PROJECT_ID, location=LOCATION)"]},{"cell_type":"markdown","metadata":{"id":"uQrBo92-W_IL"},"source":["# Strat Here"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"hRnt-Xs9W-ag"},"outputs":[],"source":["!pip install --upgrade google-cloud-aiplatform"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"GZpOZbzKY79M"},"outputs":[],"source":["from google.colab import drive\n","drive.mount('/content/drive')"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"Vqer9SUFXH80"},"outputs":[],"source":["import base64\n","import os\n","import vertexai\n","from vertexai.preview.generative_models import GenerativeModel, Part\n","\n","def encode_image(image_path):\n"," with open(image_path, \"rb\") as image_file:\n"," return base64.b64encode(image_file.read()).decode('utf-8')\n","\n","\n","def generate(prompt, original_img, ours_img, dragdif_img, sde_img):\n"," model = GenerativeModel(\"gemini-pro-vision\")\n"," responses = model.generate_content(\n"," [prompt, original_img, ours_img, dragdif_img, sde_img],\n"," generation_config={\n"," \"max_output_tokens\": 2048,\n"," \"temperature\": 0.4,\n"," \"top_p\": 1,\n"," \"top_k\": 32\n"," },\n"," )\n","\n"," return responses.text\n","\n","\n","def write_2_txt(evaluation: str, path: str):\n"," # Extract directory from the path\n"," directory = os.path.dirname(path)\n","\n"," # Check if the directory exists, and create it if it doesn't\n"," if not os.path.exists(directory):\n"," os.makedirs(directory)\n","\n"," # Now, write to the file\n"," with open(path, 'w') as file:\n"," file.write(evaluation)\n"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"6bS1XftdXUI0"},"outputs":[],"source":["from pathlib import Path\n","from vertexai.preview.generative_models import Part\n","from tqdm import tqdm\n","dataset_dir = Path('/content/drive/MyDrive/dataset')\n","gooddrag_dir = Path('/content/drive/MyDrive/gooddrag')\n","dragdiffusion_dir = Path('/content/drive/MyDrive/dragdiffusion')\n","sde_dir = Path('/content/drive/MyDrive/final_drag/sde')\n","result_dir = 
Path('/content/drive/MyDrive/GScore')\n","\n","prompt = prompt = '''Conduct a detailed evaluation of three modified images, labeled 'A', 'B', and 'C', in comparison to an original image (Image 1). Image 1 serves as the baseline and will not be evaluated. Focus on assessing the quality of 'A' (Image 2), 'B' (Image 3), and 'C' (Image 4), particularly in terms of their naturalness and the presence or absence of artifacts. Examine how well each algorithm preserves the integrity of the original image while introducing modifications. Look for any signs of distortions, unnatural colors, pixelation, or other visual inconsistencies. Rate each image on a scale from 1 to 10, where 10 represents excellent quality with seamless modifications, and 1 indicates poor quality with significant and noticeable artifacts. Provide a comprehensive analysis for each rating, highlighting specific aspects of the image that influenced your evaluation. Answers must be in English.'''\n","for i in range(10):\n"," result_dir = Path(f'/content/drive/MyDrive/GScore/result_{i}')\n"," for item in tqdm(dataset_dir.iterdir(), desc='Evaluating:'):\n"," img_name = item.name\n"," original_image = Part.from_data(data=base64.b64decode(encode_image(dataset_dir / img_name/ 'original.jpg')), mime_type=\"image/jpeg\")\n"," gooddrag_image = Part.from_data(data=base64.b64decode(encode_image(gooddrag_dir / img_name/ 'output_image.png')), mime_type=\"image/png\")\n"," dragdiffusion_image = Part.from_data(data=base64.b64decode(encode_image(dragdiffusion_dir / img_name/ 'output_image.png')), mime_type=\"image/png\")\n"," sde_image = Part.from_data(data=base64.b64decode(encode_image(sde_dir / f'{img_name}.png')), mime_type=\"image/png\")\n"," evaluation_file_path = result_dir / f'{img_name}.txt'\n"," cur_evaluation = generate(prompt, original_image, gooddrag_image, dragdiffusion_image, sde_image)\n"," write_2_txt(cur_evaluation, evaluation_file_path)\n"]}],"metadata":{"colab":{"provenance":[{"file_id":"1SjGPEFeIdOOvJ4zVRpJagxjip0QUPCMp","timestamp":1713362599717},{"file_id":"https://github.com/GoogleCloudPlatform/generative-ai/blob/main/vision/getting-started/visual_question_answering.ipynb","timestamp":1702608367588}]},"environment":{"kernel":"python3","name":"tf2-gpu.2-11.m110","type":"gcloud","uri":"gcr.io/deeplearning-platform-release/tf2-gpu.2-11:m110"},"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.12"}},"nbformat":4,"nbformat_minor":0}
--------------------------------------------------------------------------------
/evaluation/compute_DAI.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | from PIL import Image
4 | from scipy.ndimage import gaussian_filter
5 | from pathlib import Path
6 | import logging
7 | from torchvision import transforms
8 |
9 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
10 |
11 |
12 | def load_image(path):
13 | """ Load an image from the given path. """
14 | return np.array(Image.open(path))
15 |
16 |
17 | def get_patch(image, center, radius):
18 | """ Extract a patch from the image centered at 'center' with the given 'radius'. """
19 | x, y = center
20 | return image[y - radius:y + radius + 1, x - radius:x + radius + 1]
21 |
22 |
23 | def calculate_difference(patch1, patch2):
24 | """ Calculate the sum of squared differences (squared L2 distance) between two patches. """
25 | difference = patch1 - patch2
26 | squared_difference = np.square(difference)
27 | l2_distance = np.sum(squared_difference)
28 |
29 | return l2_distance
30 |
31 |
32 | def compute_dai(original_image, result_image, points, radius):
33 | """ Compute the Drag Accuracy Index (DAI) for the given images and points. """
34 | dai = 0
35 | for start, target in points:
36 | original_patch = get_patch(original_image, start, radius)
37 | result_patch = get_patch(result_image, target, radius)
38 | dai += calculate_difference(original_patch, result_patch)
39 | dai /= len(points)
40 | dai /= cal_patch_size(radius)
41 | return dai / len(points)
42 |
43 |
44 | def get_points(points_dir):
45 | with open(points_dir, 'r') as file:
46 | points_data = json.load(file)
47 | points = points_data['points']
48 |
49 | # Assuming pairs of points: [start, target, start, target, ...]
50 | point_pairs = [(points[i], points[i + 1]) for i in range(0, len(points), 2)]
51 | return point_pairs
52 |
53 |
54 | def cal_patch_size(radius: int):
55 | return (1 + 2 * radius) ** 2
56 |
57 |
58 | def compute_average_dai(radius, dataset_path, original_dataset_path=None):
59 | """ Compute the average DAI for a given dataset. """
60 | dataset_dir = Path(dataset_path)
61 | original_dataset_dir = Path(original_dataset_path) if original_dataset_path else dataset_dir
62 | total_dai, num_folders = 0, 0
63 | transform = transforms.Compose([
64 | transforms.ToTensor(),
65 | ])
66 |
67 | for item in dataset_dir.iterdir():
68 | if item.is_dir() or (item.is_file() and original_dataset_path):
69 | folder_name = item.stem if item.is_file() else item.name
70 | original_image_path = original_dataset_dir / folder_name / 'original.jpg'
71 | result_image_path = item if item.is_file() else item / 'output_image.png'
72 | points_json_path = original_dataset_dir / folder_name / 'points.json'
73 |
74 | if original_image_path.exists() and result_image_path.exists() and points_json_path.exists():
75 | original_image = load_image(str(original_image_path))
76 | result_image = load_image(str(result_image_path))
77 | point_pairs = get_points(str(points_json_path))
78 |
79 | original_image = transform(original_image).permute(1, 2, 0).numpy()
80 | result_image = transform(result_image).permute(1, 2, 0).numpy()
81 | dai = compute_dai(original_image, result_image, point_pairs, radius)
82 | total_dai += dai
83 | num_folders += 1
84 | else:
85 | logging.warning(f"Missing files in {folder_name}")
86 |
87 | if num_folders > 0:
88 | average_dai = total_dai / num_folders
89 | logging.info(f'Average DAI for {dataset_dir} with r3 {radius} is {average_dai:.4f}. Total {num_folders} images.')
90 | else:
91 | logging.warning("No valid folders found for DAI calculation.")
92 |
93 |
94 | def main():
95 | gamma = [1, 5, 10, 20]
96 | result_folder = './bench_result'
97 | data_folder = './dataset'
98 | for r in gamma:
99 | compute_average_dai(r, result_folder, data_folder)
100 |
101 |
102 | if __name__ == '__main__':
103 | main()
104 |
--------------------------------------------------------------------------------
/gooddrag_ui.py:
--------------------------------------------------------------------------------
1 | # *************************************************************************
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # *************************************************************************
14 |
15 | import os
16 | import gradio as gr
17 | from utils.ui_utils import (
18 | get_points, undo_points, show_cur_points,
19 | clear_all, store_img, train_lora_interface, run_gooddrag, save_image_mask_points, save_drag_result,
20 | save_intermediate_images, create_video
21 | )
22 |
23 | LENGTH = 512
24 |
25 |
26 | def create_markdown_section():
27 | gr.Markdown("""
28 | # GoodDrag ✨
29 |
30 | 👋 Welcome to GoodDrag! Follow these steps to easily manipulate your images:
31 |
32 | 1. **Upload Image:** 📤 Either drag and drop an image or click to upload in the **Draw Mask** box.
33 | 2. **Prepare for Training:** 🛠️
34 | - Set the path for the LoRA algorithm for your image.
35 | - Click the **Train LoRA** button to initiate the training process.
36 | 3. **Draw and Click:** ✏️
37 | - Use the **Draw Mask** box to create a mask on your image.
38 | - Next, go to the **Click Points** box. Here, you can add multiple pairs of points by clicking on the desired locations.
39 | 4. **Save Current Data (Optional):** 💾
40 | - If you wish to save the current state (including the image, mask, points, and the composite image with mask and points), specify the data path.
41 | - Click **Save Current Data** to store these elements.
42 | 5. **Run Drag Process:** ▶️
43 | - Click the **Run** button to process the image based on the drawn mask and points.
44 | 6. **Save the Results (Optional):** 🏁
45 | - Specify a path to save the final dragged image, the new points, and an image showing the new points.
46 | - Click **Save Result** to download these items.
47 | 7. **Save Intermediate Images (Optional):** 📸
48 | - For those interested in viewing the drag process step-by-step, check the **Save Intermediate Images** option under the **Get Intermediate Images** section.
49 | - To obtain a video of the drag process, ensure all optional steps above have been completed, then click the **Get Video** button.
50 |
51 | Enjoy creating with GoodDrag! 🌟
52 |
53 | """)
54 |
55 |
56 | def create_base_model_config_ui():
57 | with gr.Tab("Diffusion Model"):
58 | with gr.Row():
59 | local_models_dir = 'local_pretrained_models'
60 | os.makedirs(local_models_dir, exist_ok=True)
61 | local_models_choice = \
62 | [os.path.join(local_models_dir, d) for d in os.listdir(local_models_dir) if
63 | os.path.isdir(os.path.join(local_models_dir, d))]
64 | model_path = gr.Dropdown(value="runwayml/stable-diffusion-v1-5",
65 | label="Diffusion Model Path",
66 | choices=[
67 | "runwayml/stable-diffusion-v1-5",
68 | "stabilityai/stable-diffusion-2-1-base",
69 | "stabilityai/stable-diffusion-xl-base-1.0",
70 | ] + local_models_choice
71 | )
72 | vae_path = gr.Dropdown(value="stabilityai/sd-vae-ft-mse",
73 | label="VAE choice",
74 | choices=["stabilityai/sd-vae-ft-mse",
75 | "default"] + local_models_choice
76 | )
77 |
78 | return model_path, vae_path
79 |
80 |
81 | def create_lora_parameters_ui():
82 | with gr.Tab("LoRA Parameters"):
83 | with gr.Row():
84 | lora_step = gr.Number(value=70, label="LoRA training steps", precision=0)
85 | lora_lr = gr.Number(value=0.0005, label="LoRA learning rate")
86 | lora_batch_size = gr.Number(value=4, label="LoRA batch size", precision=0)
87 | lora_rank = gr.Number(value=16, label="LoRA rank", precision=0)
88 |
89 | return lora_step, lora_lr, lora_batch_size, lora_rank
90 |
91 |
92 | def create_real_image_editing_ui():
93 | with gr.Row():
94 | with gr.Column():
95 | gr.Markdown("📤 Draw Mask")
96 | canvas = gr.Image(type="numpy", tool="sketch", label="Draw your mask on the image",
97 | show_label=True, height=LENGTH, width=LENGTH) # for mask painting
98 | with gr.Row():
99 | train_lora_button = gr.Button("Train LoRA")
100 | lora_path = gr.Textbox(value=f"./lora_data/test", label="LoRA Path",
101 | placeholder="Enter path for LoRA data")
102 |
103 | with gr.Row():
104 | lora_status_bar = gr.Textbox(label="LoRA Training Status", interactive=False)
105 |
106 | with gr.Column():
107 | gr.Markdown("✏️ Click Points")
108 | input_image = gr.Image(type="numpy", label="Click on the image to mark points",
109 | show_label=True, height=LENGTH, width=LENGTH) # for points clicking
110 | with gr.Row():
111 | undo_button = gr.Button("Undo Point")
112 | save_button = gr.Button('Save Current Data')
113 | data_dir = gr.Textbox(value='./dataset/test', label="Data Directory",
114 | placeholder="Enter directory path for mask and points")
115 |
116 | with gr.Column():
117 | gr.Markdown("🖼️ Editing Result")
118 | output_image = gr.Image(type="numpy", label="View the editing results here",
119 | show_label=True, height=LENGTH, width=LENGTH)
120 | with gr.Row():
121 | run_button = gr.Button("Run")
122 | clear_all_button = gr.Button("Clear All")
123 | save_result = gr.Button("Save Result")
124 | show_points = gr.Button("Show Points")
125 | result_save_path = gr.Textbox(value='./result/test', label="Result Folder",
126 | placeholder="Enter path to save the results")
127 |
128 | return canvas, train_lora_button, lora_path, lora_status_bar, input_image, undo_button, save_button, data_dir, \
129 | output_image, run_button, clear_all_button, show_points, result_save_path, save_result
130 |
131 |
132 | def create_drag_parameters_ui():
133 | with gr.Tab("Drag Parameters"):
134 | with gr.Row():
135 | latent_lr = gr.Number(value=0.02, label="Learning rate")
136 | prompt = gr.Textbox(label="Prompt")
137 | drag_end_step = gr.Number(value=7, label="End time step", precision=0)
138 | drag_per_step = gr.Number(value=10, label="Point tracking number per time step", precision=0)
139 |
140 | return latent_lr, prompt, drag_end_step, drag_per_step
141 |
142 |
143 | def create_advance_parameters_ui():
144 | with gr.Tab("Advanced Parameters"):
145 | with gr.Row():
146 | r1 = gr.Number(value=4, label="Motion supervision feature patch size", precision=0)
147 | r2 = gr.Number(value=12, label="Point tracking feature patch size", precision=0)
148 | drag_distance = gr.Number(value=4, label="The distance for motion supervision", precision=0)
149 | feature_idx = gr.Number(value=3, label="The index of the features [0,3]", precision=0)
150 | max_drag_per_track = gr.Number(value=3,
151 | label="Motion supervision times for each point tracking",
152 | precision=0)
153 |
154 | with gr.Row():
155 | lam = gr.Number(value=0.2, label="Lambda", info="Regularization strength on unmasked areas")
156 | inversion_strength = gr.Slider(0, 1.0,
157 | value=0.75,
158 | label="Inversion strength")
159 | max_track_no_change = gr.Number(value=10, label="Early stop",
160 | info="The maximum number of times the points can remain unchanged before stopping early.")
161 |
162 | return (r1, r2, drag_distance, feature_idx, max_drag_per_track, lam,
163 | inversion_strength, max_track_no_change)
164 |
165 |
166 | def create_intermediate_save_ui():
167 | with gr.Tab("Get Intermediate Images"):
168 | with gr.Row():
169 | save_intermediates_images = gr.Checkbox(label='Save intermediate images')
170 | get_mp4 = gr.Button("Get video")
171 |
172 | return save_intermediates_images, get_mp4
173 |
174 |
175 | def attach_canvas_event(canvas: gr.State, original_image: gr.State,
176 | selected_points: gr.State, input_image, mask):
177 | canvas.edit(
178 | store_img,
179 | [canvas],
180 | [original_image, selected_points, input_image, mask]
181 | )
182 |
183 |
184 | def attach_input_image_event(input_image, selected_points):
185 | input_image.select(
186 | get_points,
187 | [input_image, selected_points],
188 | [input_image]
189 | )
190 |
191 |
192 | def attach_undo_button_event(undo_button, original_image, selected_points, mask, input_image):
193 | undo_button.click(
194 | undo_points,
195 | [original_image, mask],
196 | [input_image, selected_points]
197 | )
198 |
199 |
200 | def attach_train_lora_button_event(train_lora_button, original_image, prompt,
201 | model_path, vae_path, lora_path,
202 | lora_step, lora_lr, lora_batch_size, lora_rank,
203 | lora_status_bar):
204 | train_lora_button.click(
205 | train_lora_interface,
206 | [original_image, prompt, model_path, vae_path, lora_path,
207 | lora_step, lora_lr, lora_batch_size, lora_rank],
208 | [lora_status_bar]
209 | )
210 |
211 |
212 | def attach_run_button_event(run_button, original_image, input_image, mask, prompt,
213 | selected_points, inversion_strength, lam, latent_lr,
214 | model_path, vae_path, lora_path,
215 | drag_end_step, drag_per_step,
216 | output_image, r1, r2, d, feature_idx, new_points,
217 | max_drag_per_track, max_track_no_change,
218 | result_save_path, save_intermediates_images):
219 | run_button.click(
220 | run_gooddrag,
221 | [original_image, input_image, mask, prompt, selected_points,
222 | inversion_strength, lam, latent_lr, model_path, vae_path,
223 | lora_path, drag_end_step, drag_per_step, r1, r2, d,
224 | max_drag_per_track, max_track_no_change, feature_idx, result_save_path, save_intermediates_images],
225 | [output_image, new_points]
226 | )
227 |
228 |
229 | def attach_show_points_event(show_points, output_image, selected_points):
230 | show_points.click(
231 | show_cur_points,
232 | [output_image, selected_points],
233 | [output_image]
234 | )
235 |
236 |
237 | def attach_clear_all_button_event(clear_all_button, canvas, input_image,
238 | output_image, selected_points, original_image, mask):
239 | clear_all_button.click(
240 | clear_all,
241 | [gr.Number(value=LENGTH, visible=False, precision=0)],
242 | [canvas, input_image, output_image, selected_points, original_image, mask]
243 | )
244 |
245 |
246 | def attach_save_button_event(save_button, mask, selected_points, input_image, save_dir):
247 | """
248 | Attaches an event to the save button to trigger the save function.
249 | """
250 | save_button.click(
251 | save_image_mask_points,
252 | inputs=[mask, selected_points, input_image, save_dir],
253 | outputs=[]
254 | )
255 |
256 |
257 | def attach_save_result_event(save_result, output_image, new_points, result_path):
258 | """
259 | Attaches an event to the save button to trigger the save function.
260 | """
261 | save_result.click(
262 | save_drag_result,
263 | inputs=[output_image, new_points, result_path],
264 | outputs=[]
265 | )
266 |
267 |
268 | def attach_video_event(get_mp4_button, result_save_path, data_dir):
269 | get_mp4_button.click(
270 | create_video,
271 | inputs=[result_save_path, data_dir]
272 | )
273 |
274 |
275 | def main():
276 | with gr.Blocks() as demo:
277 | mask = gr.State(value=None)
278 | selected_points = gr.State([])
279 | new_points = gr.State([])
280 | original_image = gr.State(value=None)
281 | create_markdown_section()
282 | intermediate_images = gr.State([])
283 |
284 | canvas, train_lora_button, lora_path, lora_status_bar, input_image, undo_button, save_button, data_dir, \
285 | output_image, run_button, clear_all_button, show_points, result_save_path, \
286 | save_result = create_real_image_editing_ui()
287 |
288 | latent_lr, prompt, drag_end_step, drag_per_step = create_drag_parameters_ui()
289 |
290 | model_path, vae_path = create_base_model_config_ui()
291 | lora_step, lora_lr, lora_batch_size, lora_rank = create_lora_parameters_ui()
292 | r1, r2, d, feature_idx, max_drag_per_track, lam, inversion_strength, max_track_no_change = \
293 | create_advance_parameters_ui()
294 | save_intermediates_images, get_mp4_button = create_intermediate_save_ui()
295 |
296 | attach_canvas_event(canvas, original_image, selected_points, input_image, mask)
297 | attach_input_image_event(input_image, selected_points)
298 | attach_undo_button_event(undo_button, original_image, selected_points, mask, input_image)
299 | attach_train_lora_button_event(train_lora_button, original_image, prompt, model_path, vae_path, lora_path,
300 | lora_step, lora_lr, lora_batch_size, lora_rank, lora_status_bar)
301 | attach_run_button_event(run_button, original_image, input_image, mask, prompt, selected_points,
302 | inversion_strength, lam, latent_lr, model_path, vae_path, lora_path,
303 | drag_end_step, drag_per_step, output_image,
304 | r1, r2, d, feature_idx, new_points, max_drag_per_track,
305 | max_track_no_change, result_save_path, save_intermediates_images)
306 | attach_show_points_event(show_points, output_image, new_points)
307 | attach_clear_all_button_event(clear_all_button, canvas, input_image, output_image, selected_points,
308 | original_image, mask)
309 | attach_save_button_event(save_button, mask, selected_points, input_image, data_dir)
310 | attach_save_result_event(save_result, output_image, new_points, result_save_path)
311 | attach_video_event(get_mp4_button, result_save_path, data_dir)
312 |
313 | demo.queue().launch(share=True, debug=True)
314 |
315 |
316 | if __name__ == '__main__':
317 | main()
318 |
--------------------------------------------------------------------------------
/pipeline.py:
--------------------------------------------------------------------------------
1 | # *************************************************************************
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # *************************************************************************
14 |
15 |
16 | import torch
17 | import numpy as np
18 | import copy
19 |
20 | import torch.nn.functional as F
21 | from einops import rearrange
22 | from tqdm import tqdm
23 | from PIL import Image
24 | from typing import Any, Dict, List, Optional, Tuple, Union
25 |
26 | from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
27 | from utils.drag_utils import point_tracking, check_handle_reach_target, interpolate_feature_patch
28 | from utils.attn_utils import register_attention_editor_diffusers, MutualSelfAttentionControl
29 | from diffusers import DDIMScheduler, AutoencoderKL
30 | from pytorch_lightning import seed_everything
31 | from accelerate import Accelerator
32 |
33 |
34 | # from diffusers.models.attention_processor import LoRAAttnProcessor2_0
35 |
36 |
37 | # override unet forward
38 | # The only difference from diffusers:
39 | # return intermediate UNet features of all UpSample blocks
40 | def override_forward(self):
41 | def forward(
42 | sample: torch.FloatTensor,
43 | timestep: Union[torch.Tensor, float, int],
44 | encoder_hidden_states: torch.Tensor,
45 | class_labels: Optional[torch.Tensor] = None,
46 | timestep_cond: Optional[torch.Tensor] = None,
47 | attention_mask: Optional[torch.Tensor] = None,
48 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
49 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
50 | mid_block_additional_residual: Optional[torch.Tensor] = None,
51 | return_intermediates: bool = False,
52 | last_up_block_idx: int = None,
53 | ):
54 | # By default samples have to be AT least a multiple of the overall upsampling factor.
55 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
56 | # However, the upsampling interpolation output size can be forced to fit any upsampling size
57 | # on the fly if necessary.
58 | default_overall_up_factor = 2 ** self.num_upsamplers
59 |
60 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
61 | forward_upsample_size = False
62 | upsample_size = None
63 |
64 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
65 | forward_upsample_size = True
66 |
67 | # prepare attention_mask
68 | if attention_mask is not None:
69 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
70 | attention_mask = attention_mask.unsqueeze(1)
71 |
72 | # 0. center input if necessary
73 | if self.config.center_input_sample:
74 | sample = 2 * sample - 1.0
75 |
76 | # 1. time
77 | timesteps = timestep
78 | if not torch.is_tensor(timesteps):
79 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
80 | # This would be a good case for the `match` statement (Python 3.10+)
81 | is_mps = sample.device.type == "mps"
82 | if isinstance(timestep, float):
83 | dtype = torch.float32 if is_mps else torch.float64
84 | else:
85 | dtype = torch.int32 if is_mps else torch.int64
86 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
87 | elif len(timesteps.shape) == 0:
88 | timesteps = timesteps[None].to(sample.device)
89 |
90 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
91 | timesteps = timesteps.expand(sample.shape[0])
92 |
93 | t_emb = self.time_proj(timesteps)
94 |
95 | # `Timesteps` does not contain any weights and will always return f32 tensors
96 | # but time_embedding might actually be running in fp16. so we need to cast here.
97 | # there might be better ways to encapsulate this.
98 | t_emb = t_emb.to(dtype=self.dtype)
99 |
100 | emb = self.time_embedding(t_emb, timestep_cond)
101 |
102 | if self.class_embedding is not None:
103 | if class_labels is None:
104 | raise ValueError("class_labels should be provided when num_class_embeds > 0")
105 |
106 | if self.config.class_embed_type == "timestep":
107 | class_labels = self.time_proj(class_labels)
108 |
109 | # `Timesteps` does not contain any weights and will always return f32 tensors
110 | # there might be better ways to encapsulate this.
111 | class_labels = class_labels.to(dtype=sample.dtype)
112 |
113 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
114 |
115 | if self.config.class_embeddings_concat:
116 | emb = torch.cat([emb, class_emb], dim=-1)
117 | else:
118 | emb = emb + class_emb
119 |
120 | if self.config.addition_embed_type == "text":
121 | aug_emb = self.add_embedding(encoder_hidden_states)
122 | emb = emb + aug_emb
123 |
124 | if self.time_embed_act is not None:
125 | emb = self.time_embed_act(emb)
126 |
127 | if self.encoder_hid_proj is not None:
128 | encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
129 |
130 | # 2. pre-process
131 | sample = self.conv_in(sample)
132 |
133 | # 3. down
134 | down_block_res_samples = (sample,)
135 | for downsample_block in self.down_blocks:
136 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
137 | sample, res_samples = downsample_block(
138 | hidden_states=sample,
139 | temb=emb,
140 | encoder_hidden_states=encoder_hidden_states,
141 | attention_mask=attention_mask,
142 | cross_attention_kwargs=cross_attention_kwargs,
143 | )
144 | else:
145 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
146 |
147 | down_block_res_samples += res_samples
148 |
149 | if down_block_additional_residuals is not None:
150 | new_down_block_res_samples = ()
151 |
152 | for down_block_res_sample, down_block_additional_residual in zip(
153 | down_block_res_samples, down_block_additional_residuals
154 | ):
155 | down_block_res_sample = down_block_res_sample + down_block_additional_residual
156 | new_down_block_res_samples += (down_block_res_sample,)
157 |
158 | down_block_res_samples = new_down_block_res_samples
159 |
160 | # 4. mid
161 | if self.mid_block is not None:
162 | sample = self.mid_block(
163 | sample,
164 | emb,
165 | encoder_hidden_states=encoder_hidden_states,
166 | attention_mask=attention_mask,
167 | cross_attention_kwargs=cross_attention_kwargs,
168 | )
169 |
170 | if mid_block_additional_residual is not None:
171 | sample = sample + mid_block_additional_residual
172 |
173 | # 5. up
174 | # only difference from diffusers:
175 | # save the intermediate features of unet upsample blocks
176 | # the 0-th element is the mid-block output
177 | all_intermediate_features = [sample]
178 | for i, upsample_block in enumerate(self.up_blocks):
179 | is_final_block = i == len(self.up_blocks) - 1
180 |
181 | res_samples = down_block_res_samples[-len(upsample_block.resnets):]
182 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
183 |
184 | # if we have not reached the final block and need to forward the
185 | # upsample size, we do it here
186 | if not is_final_block and forward_upsample_size:
187 | upsample_size = down_block_res_samples[-1].shape[2:]
188 |
189 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
190 | sample = upsample_block(
191 | hidden_states=sample,
192 | temb=emb,
193 | res_hidden_states_tuple=res_samples,
194 | encoder_hidden_states=encoder_hidden_states,
195 | cross_attention_kwargs=cross_attention_kwargs,
196 | upsample_size=upsample_size,
197 | attention_mask=attention_mask,
198 | )
199 | else:
200 | sample = upsample_block(
201 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
202 | )
203 | all_intermediate_features.append(sample)
204 | # return early to save computation time if needed
205 | if last_up_block_idx is not None and i == last_up_block_idx:
206 | return all_intermediate_features
207 |
208 | # 6. post-process
209 | if self.conv_norm_out:
210 | sample = self.conv_norm_out(sample)
211 | sample = self.conv_act(sample)
212 | sample = self.conv_out(sample)
213 |
214 | # only difference from diffusers, return intermediate results
215 | if return_intermediates:
216 | return sample, all_intermediate_features
217 | else:
218 | return sample
219 |
220 | return forward
221 |
222 |
223 | class GoodDragger:
224 | def __init__(self, device, model_path: str, prompt: str,
225 | full_height: int, full_width: int,
226 | inversion_strength: float,
227 | r1: int = 4, r2: int = 12, beta: int = 4,
228 | drag_end_step: int = 10, track_per_denoise: int = 10,
229 | lam: float = 0.2, latent_lr: float = 0.01,
230 | n_inference_step: int = 50, guidance_scale: float = 1.0, feature_idx: int = 3,
231 | compare_mode: bool = False,
232 | vae_path: str = "default", lora_path: str = '', seed: int = 42,
233 | max_drag_per_track: int = 10, drag_loss_threshold: float = 4.0, once_drag: bool = False,
234 | max_track_no_change: int = 10):
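    | # Key hyperparameters (matching the labels in gooddrag_ui.py):
    | #   r1: feature patch radius used for motion supervision
    | #   r2: feature patch radius searched during point tracking
    | #   beta: maximum drag distance applied per motion-supervision step
    | #   drag_end_step (t2): number of denoising steps during which dragging is performed
    | #   track_per_denoise: point-tracking rounds per dragging denoising step
    | #   lam: weight of the regularization that keeps unmasked regions unchanged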
235 | self.device = device
236 | self.vae_path = vae_path
237 | self.lora_path = lora_path
238 | scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012,
239 | beta_schedule="scaled_linear", clip_sample=False,
240 | set_alpha_to_one=False, steps_offset=1)
241 |
242 | is_sdxl = 'xl' in model_path
243 | self.is_sdxl = is_sdxl
244 | if is_sdxl:
245 | self.model = StableDiffusionXLPipeline.from_pretrained(model_path, scheduler=scheduler).to(self.device)
246 | self.model.unet.config.addition_embed_type = None
247 | else:
248 | self.model = StableDiffusionPipeline.from_pretrained(model_path, scheduler=scheduler).to(self.device)
249 | self.modify_unet_forward()
250 | if vae_path != "default":
251 | self.model.vae = AutoencoderKL.from_pretrained(
252 | vae_path
253 | ).to(self.device, self.model.vae.dtype)
254 |
255 | self.set_lora()
256 |
257 | self.model.vae.requires_grad_(False)
258 | self.model.text_encoder.requires_grad_(False)
259 |
260 | seed_everything(seed)
261 |
262 | self.prompt = prompt
263 | self.full_height = full_height
264 | self.full_width = full_width
265 | self.sup_res_h = int(0.5 * full_height)
266 | self.sup_res_w = int(0.5 * full_width)
267 |
268 | self.n_inference_step = n_inference_step
269 | self.n_actual_inference_step = round(inversion_strength * self.n_inference_step)
270 | self.guidance_scale = guidance_scale
271 |
272 | self.unet_feature_idx = [feature_idx]
273 |
274 | self.r_1 = r1
275 | self.r_2 = r2
276 | self.lam = lam
277 | self.beta = beta
278 |
279 | self.lr = latent_lr
280 | self.compare_mode = compare_mode
281 |
282 | self.t2 = drag_end_step
283 | self.track_per_denoise = track_per_denoise
284 | self.total_drag = int(track_per_denoise * self.t2)
285 |
286 | self.model.scheduler.set_timesteps(self.n_inference_step)
287 |
288 | self.do_drag = True
289 | self.drag_count = 0
290 | self.max_drag_per_track = max_drag_per_track
291 |
292 | self.drag_loss_threshold = drag_loss_threshold * ((2 * self.r_1) ** 2)
293 | self.once_drag = once_drag
294 | self.no_change_track_num = 0
295 | self.max_no_change_track_num = max_track_no_change
296 |
297 | def set_lora(self):
298 | if self.lora_path == "":
299 | print("applying default parameters")
300 | self.model.unet.set_default_attn_processor()
301 | else:
302 | print("applying lora: " + self.lora_path)
303 | self.model.unet.load_attn_procs(self.lora_path)
304 |
305 | def modify_unet_forward(self):
306 | self.model.unet.forward = override_forward(self.model.unet)
307 |
308 | def get_handle_target_points(self, points):
309 | handle_points = []
310 | target_points = []
311 |
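    | # `points` alternates handle/target pairs in full-resolution (x, y) pixel coordinates;
    | # rescale them to the half-resolution feature grid (sup_res_h x sup_res_w) used for
    | # motion supervision and point tracking.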
312 | for idx, point in enumerate(points):
313 | cur_point = torch.tensor(
314 | [point[1] / self.full_height * self.sup_res_h, point[0] / self.full_width * self.sup_res_w])
315 | cur_point = torch.round(cur_point)
316 | if idx % 2 == 0:
317 | handle_points.append(cur_point)
318 | else:
319 | target_points.append(cur_point)
320 | print(f'handle points: {handle_points}')
321 | print(f'target points: {target_points}')
322 | return handle_points, target_points
323 |
324 | def inv_step(
325 | self,
326 | model_output: torch.FloatTensor,
327 | timestep: int,
328 | x: torch.FloatTensor,
329 | verbose=False
330 | ):
331 | """
332 | Inverse sampling for DDIM Inversion
333 | """
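    | # One DDIM inversion step: given x_t at the (less noisy) timestep, estimate
    | #   pred_x0 = (x_t - sqrt(1 - a_t) * eps) / sqrt(a_t)
    | # and move one step towards higher noise:
    | #   x_{t+1} = sqrt(a_{t+1}) * pred_x0 + sqrt(1 - a_{t+1}) * eps
    | # where a_t is alphas_cumprod[t] and eps is the UNet noise prediction (model_output).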
334 | if verbose:
335 | print("timestep: ", timestep)
336 | next_step = timestep
337 | timestep = min(
338 | timestep - self.model.scheduler.config.num_train_timesteps // self.model.scheduler.num_inference_steps, 999)
339 | alpha_prod_t = self.model.scheduler.alphas_cumprod[
340 | timestep] if timestep >= 0 else self.model.scheduler.final_alpha_cumprod
341 | alpha_prod_t_next = self.model.scheduler.alphas_cumprod[next_step]
342 | beta_prod_t = 1 - alpha_prod_t
343 | pred_x0 = (x - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
344 | pred_dir = (1 - alpha_prod_t_next) ** 0.5 * model_output
345 | x_next = alpha_prod_t_next ** 0.5 * pred_x0 + pred_dir
346 | return x_next, pred_x0
347 |
348 | @torch.no_grad()
349 | def image2latent(self, image):
350 | if isinstance(image, Image.Image):
351 | image = np.array(image)
352 | image = torch.from_numpy(image).float() / 127.5 - 1
353 | image = image.permute(2, 0, 1).unsqueeze(0).to(self.device)
354 |
355 | latents = self.model.vae.encode(image)['latent_dist'].mean
356 | latents = latents * 0.18215
357 | return latents
358 |
359 | @torch.no_grad()
360 | def latent2image(self, latents, return_type='np'):
361 | latents = 1 / 0.18215 * latents.detach()
362 | image = self.model.vae.decode(latents)['sample']
363 | if return_type == 'np':
364 | image = (image / 2 + 0.5).clamp(0, 1)
365 | image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
366 | image = (image * 255).astype(np.uint8)
367 | elif return_type == "pt":
368 | image = (image / 2 + 0.5).clamp(0, 1)
369 |
370 | return image
371 |
372 | @torch.no_grad()
373 | def get_text_embeddings(self, prompt):
374 | text_input = self.model.tokenizer(
375 | prompt,
376 | padding="max_length",
377 | max_length=77,
378 | return_tensors="pt"
379 | )
380 | text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.device))[0]
381 | return text_embeddings
382 |
383 | def forward_unet_features(self, z, t, encoder_hidden_states):
384 | unet_output, all_intermediate_features = self.model.unet(
385 | z,
386 | t,
387 | encoder_hidden_states=encoder_hidden_states,
388 | return_intermediates=True
389 | )
390 |
391 | all_return_features = []
392 | for idx in self.unet_feature_idx:
393 | feat = all_intermediate_features[idx]
394 | feat = F.interpolate(feat, (self.sup_res_h, self.sup_res_w), mode='bilinear')
395 | all_return_features.append(feat)
396 | return_features = torch.cat(all_return_features, dim=1)
397 |
398 | del all_intermediate_features
399 | torch.cuda.empty_cache()
400 |
401 | return unet_output, return_features
402 |
403 | @torch.no_grad()
404 | def invert(
405 | self,
406 | image: torch.Tensor,
407 | prompt,
408 | return_intermediates=False,
409 | ):
410 | """
411 | Invert a real image into a noise map with deterministic DDIM inversion.
412 | """
413 | batch_size = image.shape[0]
414 | if isinstance(prompt, list):
415 | if batch_size == 1:
416 | image = image.expand(len(prompt), -1, -1, -1)
417 | elif isinstance(prompt, str):
418 | if batch_size > 1:
419 | prompt = [prompt] * batch_size
420 |
421 | if self.is_sdxl:
422 | text_embeddings, _, _, _ = self.model.encode_prompt(prompt)
423 | else:
424 | text_embeddings = self.get_text_embeddings(prompt)
425 |
426 | latents = self.image2latent(image)
427 |
428 | if self.guidance_scale > 1.:
429 | unconditional_embeddings = self.get_text_embeddings([''] * batch_size)
430 | text_embeddings = torch.cat([unconditional_embeddings, text_embeddings], dim=0)
431 |
432 | print("Valid timesteps: ", self.model.scheduler.timesteps)
433 | latents_list = [latents]
434 | pred_x0_list = [latents]
435 | for i, t in enumerate(tqdm(reversed(self.model.scheduler.timesteps), desc="DDIM Inversion")):
436 | if self.n_actual_inference_step is not None and i >= self.n_actual_inference_step:
437 | continue
438 |
439 | if self.guidance_scale > 1.:
440 | model_inputs = torch.cat([latents] * 2)
441 | else:
442 | model_inputs = latents
443 |
444 | t_ = self.model.scheduler.timesteps[-(i + 2)]
445 |
446 | noise_pred = self.model.unet(model_inputs, t, encoder_hidden_states=text_embeddings)
447 | if self.guidance_scale > 1.:
448 | noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
449 | noise_pred = noise_pred_uncon + self.guidance_scale * (noise_pred_con - noise_pred_uncon)
450 |
451 | latents, pred_x0 = self.inv_step(noise_pred, t, latents)
452 | latents_list.append(latents)
453 | pred_x0_list.append(pred_x0)
454 |
455 | if return_intermediates:
456 | return latents, latents_list
457 | return latents
458 |
459 | def get_original_features(self, init_code, text_embeddings):
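    | # Denoise the inverted latent without any dragging and cache, for steps up to
    | # drag_end_step, both the per-step latents and the UNet up-block features.
    | # These serve as the frozen reference for motion supervision and point tracking.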
460 | timesteps = self.model.scheduler.timesteps
461 | start_time_step_idx = self.n_inference_step - self.n_actual_inference_step
462 | original_step_output = {}
463 | features = {}
464 | cur_latents = init_code.detach().clone()
465 | with torch.no_grad():
466 | for i, t in enumerate(tqdm(timesteps[start_time_step_idx:],
467 | desc="Denoising for mask features")):
468 | if i <= self.t2:
469 | model_inputs = cur_latents
470 | noise_pred, F0 = self.forward_unet_features(model_inputs, t, encoder_hidden_states=text_embeddings)
471 | cur_latents = self.model.scheduler.step(noise_pred, t, model_inputs, return_dict=False)[0]
472 | original_step_output[t.item()] = cur_latents.cpu()
473 | features[t.item()] = F0.cpu()
474 |
475 | del noise_pred, cur_latents, F0
476 | torch.cuda.empty_cache()
477 |
478 | return original_step_output, features
479 |
480 | def get_noise_features(self, input_latents, t, text_embeddings):
481 | unet_output, F1 = self.forward_unet_features(input_latents, t, encoder_hidden_states=text_embeddings)
482 | return unet_output, F1
483 |
484 | def cal_motion_supervision_loss(self, handle_points, target_points, F1, x_prev_updated, original_prev,
485 | interp_mask, original_features, original_points, alpha=None):
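    | # Motion supervision: for every handle point pi that has not yet reached its target ti,
    | # take a small step di towards ti (at most `beta` pixels) and pull the current features
    | # around pi + di towards the original features around the initial handle location
    | # (L1 loss between the two (2*r1+1)^2 patches). The second term penalizes changes
    | # outside the user-drawn mask so the rest of the image stays close to the original.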
486 | drag_loss = 0.0
487 | for i_ in range(len(handle_points)):
488 | pi, ti = handle_points[i_], target_points[i_]
489 | norm_dis = (ti - pi).norm()
490 | if norm_dis < 2.:
491 | continue
492 |
493 | di = (ti - pi) / (ti - pi).norm() * min(self.beta, norm_dis)
494 |
495 | original_features.requires_grad_(True)
496 | pi = original_points[i_]
497 | f0_patch = original_features[:, :, int(pi[0]) - self.r_1:int(pi[0]) + self.r_1 + 1,
498 | int(pi[1]) - self.r_1:int(pi[1]) + self.r_1 + 1].detach()
499 |
500 | pi = handle_points[i_]
501 | f1_patch = interpolate_feature_patch(F1, pi[0] + di[0], pi[1] + di[1], self.r_1)
502 | drag_loss += ((2 * self.r_1) ** 2) * F.l1_loss(f0_patch, f1_patch)
503 |
504 | print(f'Loss from drag: {drag_loss}')
505 | loss = drag_loss + self.lam * ((x_prev_updated - original_prev)
506 | * (1.0 - interp_mask)).abs().sum()
507 | print('Loss total=%f' % loss)
508 | return loss, drag_loss
509 |
510 | def track_step(self, original_feature, original_feature_, F1, F1_, handle_points, handle_points_init):
511 | if self.compare_mode:
512 | handle_points = point_tracking(original_feature,
513 | F1, handle_points, handle_points_init, self.r_2)
514 | else:
515 | handle_points = point_tracking(original_feature_,
516 | F1_, handle_points, handle_points_init, self.r_2)
517 | return handle_points
518 |
519 | def compare_tensor_lists(self, lst1, lst2):
520 | if len(lst1) != len(lst2):
521 | return False
522 | return all(torch.equal(t1, t2) for t1, t2 in zip(lst1, lst2))
523 |
524 | def gooddrag_step(self, init_code, t, t_, text_embeddings, handle_points, target_points,
525 | features, handle_points_init, original_step_output, interp_mask):
526 | drag_latents = init_code.clone().detach()
527 | drag_latents.requires_grad_(True)
528 |
529 | first_drag = True
530 | need_track = False
531 | track_num = 0
532 | cur_drag_per_track = 0
533 | self.compare_mode = True
534 | accelerator = Accelerator(
535 | gradient_accumulation_steps=1,
536 | mixed_precision='fp16'
537 | )
538 |
539 | optimizer = torch.optim.Adam([drag_latents], lr=self.lr)
540 |
541 | drag_latents, self.model.unet, optimizer = accelerator.prepare(drag_latents, self.model.unet, optimizer)
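    | # Alternate motion supervision and point tracking at this denoising timestep:
    | # optimize the latent with Adam until either `max_drag_per_track` supervision steps
    | # have run or the drag loss falls below the threshold, then re-localize the handle
    | # points by point tracking. Stop early when all handles reach their targets or the
    | # tracked points stop moving for `max_track_no_change` consecutive trackings.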
542 | while track_num < self.track_per_denoise:
543 | optimizer.zero_grad()
544 | unet_output, F1 = self.forward_unet_features(drag_latents, t, text_embeddings)
545 | x_prev_updated = self.model.scheduler.step(unet_output, t, drag_latents, return_dict=False)[0]
546 |
547 | if (need_track or first_drag) and (not self.compare_mode):
548 | with torch.no_grad():
549 | _, F1_ = self.forward_unet_features(x_prev_updated, t_, text_embeddings)
550 |
551 | if first_drag:
552 | first_drag = False
553 | if self.compare_mode:
554 | handle_points = point_tracking(features[t.item()].cuda(),
555 | F1, handle_points, handle_points_init, self.r_2)
556 | else:
557 | handle_points = point_tracking(features[t_.item()].cuda(),
558 | F1_, handle_points, handle_points_init, self.r_2)
559 |
560 | print(f'After denoising, new handle points: {handle_points}, drag count: {self.drag_count}')
561 |
562 | # break if all handle points have reached the targets
563 | if check_handle_reach_target(handle_points, target_points):
564 | self.do_drag = False
565 | print('Reached the target points')
566 | break
567 |
568 | if self.no_change_track_num == self.max_no_change_track_num:
569 | self.do_drag = False
570 | print('Early stop.')
571 | break
572 |
573 | del unet_output
574 | if need_track and (not self.compare_mode):
575 | del _
576 | torch.cuda.empty_cache()
577 |
578 | loss, drag_loss = self.cal_motion_supervision_loss(handle_points, target_points, F1, x_prev_updated,
579 | original_step_output[t.item()].cuda(), interp_mask,
580 | original_features=features[t.item()].cuda(),
581 | original_points=handle_points_init)
582 |
583 | accelerator.backward(loss)
584 |
585 | optimizer.step()
586 |
587 | cur_drag_per_track += 1
588 | need_track = (cur_drag_per_track == self.max_drag_per_track) or (
589 | drag_loss <= self.drag_loss_threshold) or self.once_drag
590 | if need_track:
591 | track_num += 1
592 | handle_points_cur = copy.deepcopy(handle_points)
593 | if self.compare_mode:
594 | handle_points = point_tracking(features[t.item()].cuda(),
595 | F1, handle_points, handle_points_init, self.r_2)
596 | else:
597 | handle_points = point_tracking(features[t_.item()].cuda(),
598 | F1_, handle_points, handle_points_init, self.r_2)
599 |
600 | if self.compare_tensor_lists(handle_points, handle_points_cur):
601 | self.no_change_track_num += 1
602 | print(f'Handle points unchanged for {self.no_change_track_num} consecutive trackings.')
603 | else:
604 | self.no_change_track_num = 0
605 |
606 | self.drag_count += 1
607 | cur_drag_per_track = 0
608 | print(f'New handle points: {handle_points}, drag count: {self.drag_count}')
609 |
610 | init_code = drag_latents.clone().detach()
611 | init_code.requires_grad_(False)
612 | del optimizer, drag_latents
613 | torch.cuda.empty_cache()
614 |
615 | return init_code, handle_points
616 |
617 | def prepare_mask(self, mask):
618 | mask = torch.from_numpy(mask).float() / 255.
619 | mask[mask > 0.0] = 1.0
620 | mask = rearrange(mask, "h w -> 1 1 h w").cuda()
621 | mask = F.interpolate(mask, (self.sup_res_h, self.sup_res_w), mode="nearest")
622 | return mask
623 |
624 | def set_latent_masactrl(self):
625 | editor = MutualSelfAttentionControl(start_step=0,
626 | start_layer=10,
627 | total_steps=self.n_inference_step,
628 | guidance_scale=self.guidance_scale)
629 | if self.lora_path == "":
630 | register_attention_editor_diffusers(self.model, editor, attn_processor='attn_proc')
631 | else:
632 | register_attention_editor_diffusers(self.model, editor, attn_processor='lora_attn_proc')
633 |
634 | def get_intermediate_images(self, intermediate_images, intermediate_images_original, intermediate_images_t_idx,
635 | valid_timestep, text_embeddings):
636 | for i in range(len(intermediate_images)-1):
637 | current_original_code = intermediate_images_original[i].to(self.device)
638 | current_init_code = intermediate_images[i].to(self.device)
639 |
640 | self.set_latent_masactrl()
641 |
642 | for inter_i, inter_t in enumerate(valid_timestep[intermediate_images_t_idx[i] + 1:]):
643 | with torch.no_grad():
644 | noise_pred_all = self.model.unet(torch.cat([current_original_code, current_init_code]), inter_t,
645 | encoder_hidden_states=torch.cat(
646 | [text_embeddings, text_embeddings]))
647 | noise_pred = noise_pred_all[1]
648 | noise_pred_original = noise_pred_all[0]
649 | current_init_code = \
650 | self.model.scheduler.step(noise_pred, inter_t, current_init_code, return_dict=False)[0]
651 | current_original_code = \
652 | self.model.scheduler.step(noise_pred_original, inter_t, current_original_code,
653 | return_dict=False)[0]
654 | intermediate_images[i] = self.latent2image(current_init_code, return_type="pt").cpu()
655 | intermediate_images.pop()
656 | return intermediate_images
657 |
658 | def good_drag(self,
659 | source_image,
660 | points,
661 | mask,
662 | return_intermediate_images=False,
663 | return_intermediate_features=False
664 | ):
665 | init_code = self.invert(source_image, self.prompt)
666 | original_init = init_code.detach().clone()
667 | if self.is_sdxl:
668 | text_embeddings, _, _, _ = self.model.encode_prompt(self.prompt)
669 | text_embeddings = text_embeddings.detach()
670 | else:
671 | text_embeddings = self.get_text_embeddings(self.prompt).detach()
672 |
673 | self.model.text_encoder.to('cpu')
674 | self.model.vae.encoder.to('cpu')
675 |
676 | timesteps = self.model.scheduler.timesteps
677 | start_time_step_idx = self.n_inference_step - self.n_actual_inference_step
678 |
679 | handle_points, target_points = self.get_handle_target_points(points)
680 | original_step_output, features = self.get_original_features(init_code, text_embeddings)
681 |
682 | handle_points_init = copy.deepcopy(handle_points)
683 | mask = self.prepare_mask(mask)
684 | interp_mask = F.interpolate(mask, (init_code.shape[2], init_code.shape[3]), mode='nearest')
685 |
686 | intermediate_features = [init_code.detach().clone().cpu()] if return_intermediate_features else []
687 | valid_timestep = timesteps[start_time_step_idx:]
688 | set_mutual = True
689 |
690 | intermediate_images, intermediate_images_original, intermediate_images_t_idx = [], [], []
691 |
692 | did_drag = False
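    | # Main loop: for the first t2 denoising steps run the drag optimization on the latent
    | # before denoising it; afterwards simply denoise. The untouched `original_init` latent
    | # is denoised in parallel and, once dragging is finished, mutual self-attention
    | # (set_latent_masactrl) lets the edited branch borrow keys/values from the original
    | # branch to better preserve appearance.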
693 | for i, t in enumerate(tqdm(valid_timestep,
694 | desc="Drag and Denoise")):
695 | if i < self.t2 and self.do_drag and (self.no_change_track_num != self.max_no_change_track_num):
696 | t_ = valid_timestep[i + 1]
697 | init_code, handle_points = self.gooddrag_step(init_code, t, t_, text_embeddings, handle_points,
698 | target_points, features, handle_points_init,
699 | original_step_output, interp_mask)
700 | did_drag = True
701 | else:
702 | if set_mutual:
703 | set_mutual = False
704 | self.set_latent_masactrl()
705 |
706 | with torch.no_grad():
707 | noise_pred_all = self.model.unet(torch.cat([original_init, init_code]), t,
708 | encoder_hidden_states=torch.cat([text_embeddings, text_embeddings]))
709 | noise_pred = noise_pred_all[1]
710 | noise_pred_original = noise_pred_all[0]
711 | init_code = self.model.scheduler.step(noise_pred, t, init_code, return_dict=False)[0]
712 | original_init = self.model.scheduler.step(noise_pred_original, t, original_init, return_dict=False)[0]
713 |
714 | if did_drag and return_intermediate_images:
715 | current_init_code = init_code.detach().clone()
716 | current_original_code = original_init.detach().clone()
717 |
718 | intermediate_images.append(current_init_code.cpu())
719 | intermediate_images_original.append(current_original_code.cpu())
720 | intermediate_images_t_idx.append(i)
721 | did_drag = False
722 | if return_intermediate_features:
723 | intermediate_features.append(init_code.detach().clone().cpu())
724 |
725 | if return_intermediate_images:
726 | intermediate_images = self.get_intermediate_images(intermediate_images, intermediate_images_original,
727 | intermediate_images_t_idx, valid_timestep, text_embeddings)
728 |
729 | image = self.latent2image(init_code, return_type="pt")
730 | return image, intermediate_features, handle_points, intermediate_images
731 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | --find-links https://download.pytorch.org/whl/torch_stable.html
2 | torch==2.0.1+cu118
3 | scipy>=1.11.1
4 | transformers==4.27.0
5 | diffusers==0.20.0
6 | gradio==3.50.2
7 | einops==0.6.1
8 | pytorch_lightning==2.0.8
9 | accelerate==0.17.0
10 | opencv-python==4.8.0.76
11 | torchvision==0.15.2+cu118
12 |
--------------------------------------------------------------------------------
/utils/__pycache__/attn_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/utils/__pycache__/attn_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/drag_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/utils/__pycache__/drag_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/lora_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/utils/__pycache__/lora_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/ui_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zewei-Zhang/GoodDrag/f6d2aff55a544d6254b15b66542ba2704f154a82/utils/__pycache__/ui_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/attn_utils.py:
--------------------------------------------------------------------------------
1 | # *************************************************************************
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # *************************************************************************
14 |
15 |
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 |
20 | from einops import rearrange, repeat
21 |
22 |
23 | class AttentionBase:
24 |
25 | def __init__(self):
26 | self.cur_step = 0
27 | self.num_att_layers = -1
28 | self.cur_att_layer = 0
29 |
30 | def after_step(self):
31 | pass
32 |
33 | def __call__(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
34 | out = self.forward(q, k, v, is_cross, place_in_unet, num_heads, **kwargs)
35 | self.cur_att_layer += 1
36 | if self.cur_att_layer == self.num_att_layers:
37 | self.cur_att_layer = 0
38 | self.cur_step += 1
39 | # after step
40 | self.after_step()
41 | return out
42 |
43 | def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
44 | out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
45 | out = rearrange(out, 'b h n d -> b n (h d)')
46 | return out
47 |
48 | def reset(self):
49 | self.cur_step = 0
50 | self.cur_att_layer = 0
51 |
52 |
53 | class MutualSelfAttentionControl(AttentionBase):
54 |
55 | def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, guidance_scale=7.5):
56 | """
57 | Mutual self-attention control for Stable-Diffusion model
58 | Args:
59 | start_step: the step to start mutual self-attention control
60 | start_layer: the layer to start mutual self-attention control
61 | layer_idx: list of the layers to apply mutual self-attention control
62 | step_idx: list of the steps to apply mutual self-attention control
63 | total_steps: the total number of steps
64 | """
65 | super().__init__()
66 | self.total_steps = total_steps
67 | self.start_step = start_step
68 | self.start_layer = start_layer
69 | self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, 16))
70 | self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps))
71 | # store the guidance scale to decide whether there is an unconditional branch
72 | self.guidance_scale = guidance_scale
73 | print("step_idx: ", self.step_idx)
74 | print("layer_idx: ", self.layer_idx)
75 |
76 | def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
77 | """
78 | Attention forward function
79 | """
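    | # For the selected steps/layers, both branches attend to the keys/values of the
    | # first (source) batch element, so the edited branch reuses the appearance of the
    | # original image while keeping its own queries.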
80 | if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
81 | return super().forward(q, k, v, is_cross, place_in_unet, num_heads, **kwargs)
82 |
83 | if self.guidance_scale > 1.0:
84 | qu, qc = q[0:2], q[2:4]
85 | ku, kc = k[0:2], k[2:4]
86 | vu, vc = v[0:2], v[2:4]
87 |
88 | # merge queries of source and target branch into one so we can use torch API
89 | qu = torch.cat([qu[0:1], qu[1:2]], dim=2)
90 | qc = torch.cat([qc[0:1], qc[1:2]], dim=2)
91 |
92 | out_u = F.scaled_dot_product_attention(qu, ku[0:1], vu[0:1], attn_mask=None, dropout_p=0.0, is_causal=False)
93 | out_u = torch.cat(out_u.chunk(2, dim=2), dim=0) # split the queries into source and target batch
94 | out_u = rearrange(out_u, 'b h n d -> b n (h d)')
95 |
96 | out_c = F.scaled_dot_product_attention(qc, kc[0:1], vc[0:1], attn_mask=None, dropout_p=0.0, is_causal=False)
97 | out_c = torch.cat(out_c.chunk(2, dim=2), dim=0) # split the queries into source and target batch
98 | out_c = rearrange(out_c, 'b h n d -> b n (h d)')
99 |
100 | out = torch.cat([out_u, out_c], dim=0)
101 | else:
102 | q = torch.cat([q[0:1], q[1:2]], dim=2)
103 | out = F.scaled_dot_product_attention(q, k[0:1], v[0:1], attn_mask=None, dropout_p=0.0, is_causal=False)
104 | out = torch.cat(out.chunk(2, dim=2), dim=0) # split the queries into source and target batch
105 | out = rearrange(out, 'b h n d -> b n (h d)')
106 | return out
107 |
108 | # forward function for default attention processor
109 | # modified from __call__ function of AttnProcessor in diffusers
110 | def override_attn_proc_forward(attn, editor, place_in_unet):
111 | def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
112 | """
113 | The attention computation follows the original LDM CrossAttention implementation,
114 | except that the attention itself is computed by the registered editor
115 | """
116 | if encoder_hidden_states is not None:
117 | context = encoder_hidden_states
118 | if attention_mask is not None:
119 | mask = attention_mask
120 |
121 | to_out = attn.to_out
122 | if isinstance(to_out, nn.modules.container.ModuleList):
123 | to_out = attn.to_out[0]
124 | else:
125 | to_out = attn.to_out
126 |
127 | h = attn.heads
128 | q = attn.to_q(x)
129 | is_cross = context is not None
130 | context = context if is_cross else x
131 | k = attn.to_k(context)
132 | v = attn.to_v(context)
133 |
134 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
135 |
136 | # the only difference
137 | out = editor(
138 | q, k, v, is_cross, place_in_unet,
139 | attn.heads, scale=attn.scale)
140 |
141 | return to_out(out)
142 |
143 | return forward
144 |
145 | # forward function for lora attention processor
146 | # modified from __call__ function of LoRAAttnProcessor2_0 in diffusers v0.17.1
147 | def override_lora_attn_proc_forward(attn, editor, place_in_unet):
148 | def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, lora_scale=1.0):
149 | residual = hidden_states
150 | input_ndim = hidden_states.ndim
151 | is_cross = encoder_hidden_states is not None
152 |
153 | if input_ndim == 4:
154 | batch_size, channel, height, width = hidden_states.shape
155 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
156 |
157 | batch_size, sequence_length, _ = (
158 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
159 | )
160 |
161 | if attention_mask is not None:
162 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
163 | # scaled_dot_product_attention expects attention_mask shape to be
164 | # (batch, heads, source_length, target_length)
165 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
166 |
167 | if attn.group_norm is not None:
168 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
169 |
170 | query = attn.to_q(hidden_states) + lora_scale * attn.processor.to_q_lora(hidden_states)
171 |
172 | if encoder_hidden_states is None:
173 | encoder_hidden_states = hidden_states
174 | elif attn.norm_cross:
175 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
176 |
177 | key = attn.to_k(encoder_hidden_states) + lora_scale * attn.processor.to_k_lora(encoder_hidden_states)
178 | value = attn.to_v(encoder_hidden_states) + lora_scale * attn.processor.to_v_lora(encoder_hidden_states)
179 |
180 | query, key, value = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=attn.heads), (query, key, value))
181 |
182 | # the only difference
183 | hidden_states = editor(
184 | query, key, value, is_cross, place_in_unet,
185 | attn.heads, scale=attn.scale)
186 |
187 | # linear proj
188 | hidden_states = attn.to_out[0](hidden_states) + lora_scale * attn.processor.to_out_lora(hidden_states)
189 | # dropout
190 | hidden_states = attn.to_out[1](hidden_states)
191 |
192 | if input_ndim == 4:
193 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
194 |
195 | if attn.residual_connection:
196 | hidden_states = hidden_states + residual
197 |
198 | hidden_states = hidden_states / attn.rescale_output_factor
199 |
200 | return hidden_states
201 |
202 | return forward
203 |
204 | def register_attention_editor_diffusers(model, editor: AttentionBase, attn_processor='attn_proc'):
205 | """
206 | Register an attention editor to a Diffusers pipeline, adapted from [Prompt-to-Prompt]
207 | """
208 | def register_editor(net, count, place_in_unet):
209 | for name, subnet in net.named_children():
210 | if net.__class__.__name__ == 'Attention': # spatial Transformer layer
211 | if attn_processor == 'attn_proc':
212 | net.forward = override_attn_proc_forward(net, editor, place_in_unet)
213 | elif attn_processor == 'lora_attn_proc':
214 | net.forward = override_lora_attn_proc_forward(net, editor, place_in_unet)
215 | else:
216 | raise NotImplementedError("not implemented")
217 | return count + 1
218 | elif hasattr(net, 'children'):
219 | count = register_editor(subnet, count, place_in_unet)
220 | return count
221 |
222 | cross_att_count = 0
223 | for net_name, net in model.unet.named_children():
224 | if "down" in net_name:
225 | cross_att_count += register_editor(net, 0, "down")
226 | elif "mid" in net_name:
227 | cross_att_count += register_editor(net, 0, "mid")
228 | elif "up" in net_name:
229 | cross_att_count += register_editor(net, 0, "up")
230 | editor.num_att_layers = cross_att_count
231 |
--------------------------------------------------------------------------------
/utils/drag_utils.py:
--------------------------------------------------------------------------------
1 | # *************************************************************************
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # *************************************************************************
14 |
15 |
16 | import torch
17 | from typing import List
18 |
19 |
20 | def calculate_l1_distance(tensor1: torch.Tensor, tensor2: torch.Tensor) -> torch.Tensor:
21 | """Calculate the L1 (Manhattan) distance between two tensors."""
22 | return torch.sum(torch.abs(tensor1 - tensor2), dim=1)
23 |
24 |
25 | def calculate_l2_distance(tensor1: torch.Tensor, tensor2: torch.Tensor) -> torch.Tensor:
26 | """Calculate the L2 (Euclidean) distance between two tensors."""
27 | return torch.sqrt(torch.sum((tensor1 - tensor2) ** 2, dim=1))
28 |
29 |
30 | def calculate_cosine_similarity(tensor1: torch.Tensor, tensor2: torch.Tensor) -> torch.Tensor:
31 | """Calculate the Cosine Similarity between two tensors."""
32 | numerator = torch.sum(tensor1 * tensor2, dim=1)
33 | denominator = torch.sqrt(torch.sum(tensor1 ** 2, dim=1)) * torch.sqrt(torch.sum(tensor2 ** 2, dim=1))
34 | return numerator / denominator
35 |
36 |
37 | def get_neighboring_patch(tensor: torch.Tensor, center: tuple, radius: int) -> torch.Tensor:
38 | """Get a neighboring patch from a tensor centered at a specific point."""
39 | r1, r2 = int(center[0]) - radius, int(center[0]) + radius + 1
40 | c1, c2 = int(center[1]) - radius, int(center[1]) + radius + 1
41 | return tensor[:, :, r1:r2, c1:c2]
42 |
43 |
44 | def update_handle_points(handle_points: torch.Tensor, all_dist: torch.Tensor, r2) -> torch.Tensor:
45 | """Update handle points based on computed distances."""
46 | row, col = divmod(all_dist.argmin().item(), all_dist.shape[-1])
47 | updated_point = torch.tensor([
48 | handle_points[0] - r2 + row,
49 | handle_points[1] - r2 + col
50 | ])
51 | return updated_point
52 |
53 |
54 | def point_tracking(F0: torch.Tensor, F1: torch.Tensor, handle_points: List[torch.Tensor],
55 | handle_points_init: List[torch.Tensor], r2, distance_type: str = 'l1') -> List[torch.Tensor]:
56 | """Track points between F0 and F1 tensors."""
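    | # For each handle point, compare the feature of its initial location in F0 with all
    | # features in a (2*r2+1)^2 neighborhood around its current location in F1, and move
    | # the handle to the best-matching position (L1 distance by default).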
57 | with torch.no_grad():
58 | for i in range(len(handle_points)):
59 | pi0, pi = handle_points_init[i], handle_points[i]
60 | f0 = F0[:, :, int(pi0[0]), int(pi0[1])]
61 | f0_expanded = f0.unsqueeze(dim=-1).unsqueeze(dim=-1)
62 |
63 | F1_neighbor = get_neighboring_patch(F1, pi, r2)
64 |
65 | # Switch case for different distance functions
66 | if distance_type == 'l1':
67 | all_dist = calculate_l1_distance(f0_expanded, F1_neighbor)
68 | elif distance_type == 'l2':
69 | all_dist = calculate_l2_distance(f0_expanded, F1_neighbor)
70 | elif distance_type == 'cosine':
71 | all_dist = -calculate_cosine_similarity(f0_expanded, F1_neighbor) # Negative for minimization
72 |
73 | all_dist = all_dist.squeeze(dim=0)
74 | handle_points[i] = update_handle_points(pi, all_dist, r2)
75 |
76 | return handle_points
77 |
78 |
79 | def interpolate_feature_patch(feat: torch.Tensor, y: float, x: float, r: int) -> torch.Tensor:
80 | """Obtain the bilinear interpolated feature patch."""
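    | # Bilinear interpolation of an r-radius patch centered at the fractional location (y, x):
    | # blend the four integer-aligned patches around it with weights that sum to 1, i.e.
    | # w00=(x1-x)(y1-y), w01=(x1-x)(y-y0), w10=(x-x0)(y1-y), w11=(x-x0)(y-y0).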
81 | x0, y0 = torch.floor(x).long(), torch.floor(y).long()
82 | x1, y1 = x0 + 1, y0 + 1
83 |
84 | weights = torch.tensor([(x1 - x) * (y1 - y), (x1 - x) * (y - y0), (x - x0) * (y1 - y), (x - x0) * (y - y0)])
85 | weights = weights.to(feat.device)
86 |
87 | patches = torch.stack([
88 | feat[:, :, y0 - r:y0 + r + 1, x0 - r:x0 + r + 1],
89 | feat[:, :, y1 - r:y1 + r + 1, x0 - r:x0 + r + 1],
90 | feat[:, :, y0 - r:y0 + r + 1, x1 - r:x1 + r + 1],
91 | feat[:, :, y1 - r:y1 + r + 1, x1 - r:x1 + r + 1]
92 | ])
93 |
94 | return torch.sum(weights.view(-1, 1, 1, 1, 1) * patches, dim=0)
95 |
96 |
97 | def check_handle_reach_target(handle_points: list, target_points: list) -> bool:
98 | """Check if handle points are close to target points."""
99 | all_dists = torch.tensor([(p - q).norm().item() for p, q in zip(handle_points, target_points)])
100 | return (all_dists < 2.0).all().item()
101 |
--------------------------------------------------------------------------------
/utils/lora_utils.py:
--------------------------------------------------------------------------------
1 | # *************************************************************************
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # *************************************************************************
14 |
15 |
16 | from PIL import Image
17 | import os
18 | import numpy as np
19 | from einops import rearrange
20 | import torch
21 | import torch.nn.functional as F
22 | from torchvision import transforms
23 | from accelerate import Accelerator
24 | from accelerate.utils import set_seed
25 | from PIL import Image
26 | from tqdm import tqdm
27 |
28 | from transformers import AutoTokenizer, PretrainedConfig
29 |
30 | import diffusers
31 | from diffusers import (
32 | AutoencoderKL,
33 | DDPMScheduler,
34 | DiffusionPipeline,
35 | DPMSolverMultistepScheduler,
36 | StableDiffusionPipeline,
37 | UNet2DConditionModel,
38 | )
39 | from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
40 | from diffusers.models.attention_processor import (
41 | AttnAddedKVProcessor,
42 | AttnAddedKVProcessor2_0,
43 | LoRAAttnAddedKVProcessor,
44 | LoRAAttnProcessor,
45 | LoRAAttnProcessor2_0,
46 | SlicedAttnAddedKVProcessor,
47 | )
48 | from diffusers.optimization import get_scheduler
49 | from diffusers.utils import check_min_version
50 | from diffusers.utils.import_utils import is_xformers_available
51 | from diffusers import StableDiffusionXLPipeline
52 |
53 | # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
54 | check_min_version("0.17.0")
55 |
56 |
57 | def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
58 | text_encoder_config = PretrainedConfig.from_pretrained(
59 | pretrained_model_name_or_path,
60 | subfolder="text_encoder",
61 | revision=revision,
62 | )
63 | model_class = text_encoder_config.architectures[0]
64 |
65 | if model_class == "CLIPTextModel":
66 | from transformers import CLIPTextModel
67 |
68 | return CLIPTextModel
69 | elif model_class == "RobertaSeriesModelWithTransformation":
70 | from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
71 |
72 | return RobertaSeriesModelWithTransformation
73 | elif model_class == "T5EncoderModel":
74 | from transformers import T5EncoderModel
75 |
76 | return T5EncoderModel
77 | else:
78 | raise ValueError(f"{model_class} is not supported.")
79 |
80 |
81 | def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
82 | if tokenizer_max_length is not None:
83 | max_length = tokenizer_max_length
84 | else:
85 | max_length = tokenizer.model_max_length
86 |
87 | text_inputs = tokenizer(
88 | prompt,
89 | truncation=True,
90 | padding="max_length",
91 | max_length=max_length,
92 | return_tensors="pt",
93 | )
94 |
95 | return text_inputs
96 |
97 |
98 | def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=False):
99 | text_input_ids = input_ids.to(text_encoder.device)
100 |
101 | if text_encoder_use_attention_mask:
102 | attention_mask = attention_mask.to(text_encoder.device)
103 | else:
104 | attention_mask = None
105 |
106 | prompt_embeds = text_encoder(
107 | text_input_ids,
108 | attention_mask=attention_mask,
109 | )
110 | prompt_embeds = prompt_embeds[0]
111 |
112 | return prompt_embeds
113 |
114 |
115 | # model_path: path of the base diffusion model; vae_path: VAE path, or "default" for the model's own VAE
116 | # image: input image (not yet pre-processed)
117 | # save_lora_path: the path to save the LoRA weights
118 | # prompt: the user input prompt
119 | # lora_step: number of LoRA training steps
120 | # lora_lr: learning rate of LoRA training
121 | # lora_rank: the rank of the LoRA; lora_batch_size: batch size used for LoRA training
122 | # save_interval: the frequency of saving LoRA checkpoints
123 | def train_lora(image,
124 | prompt,
125 | model_path,
126 | vae_path,
127 | save_lora_path,
128 | lora_step,
129 | lora_lr,
130 | lora_batch_size,
131 | lora_rank,
132 | progress,
133 | use_gradio_progress=True,
134 | save_interval=-1):
135 | # initialize accelerator
136 | accelerator = Accelerator(
137 | gradient_accumulation_steps=1,
138 | mixed_precision='fp16'
139 | )
140 | set_seed(0)
141 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
142 |
143 | is_sdxl = 'xl' in model_path
144 | if is_sdxl:
145 | model = StableDiffusionXLPipeline.from_pretrained(model_path).to(device)
146 | tokenizer = model.tokenizer
147 | noise_scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
148 | text_encoder = model.text_encoder
149 | vae = model.vae
150 | unet = model.unet
151 | unet.config.addition_embed_type = None
152 | else:
153 | # Load the tokenizer
154 | tokenizer = AutoTokenizer.from_pretrained(
155 | model_path,
156 | subfolder="tokenizer",
157 | revision=None,
158 | use_fast=False,
159 | )
160 | # initialize the model
161 | noise_scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
162 | text_encoder_cls = import_model_class_from_model_name_or_path(model_path, revision=None)
163 | text_encoder = text_encoder_cls.from_pretrained(
164 | model_path, subfolder="text_encoder", revision=None
165 | )
166 | if vae_path == "default":
167 | vae = AutoencoderKL.from_pretrained(
168 | model_path, subfolder="vae", revision=None
169 | )
170 | else:
171 | vae = AutoencoderKL.from_pretrained(vae_path)
172 |
173 | unet = UNet2DConditionModel.from_pretrained(
174 | model_path, subfolder="unet", revision=None
175 | )
176 |
177 | # freeze the base model weights and move them to the target device and dtype
178 |
179 | vae.requires_grad_(False)
180 | text_encoder.requires_grad_(False)
181 | unet.requires_grad_(False)
182 |
183 | unet.to(device, dtype=torch.float16)
184 | vae.to(device, dtype=torch.float16)
185 | text_encoder.to(device, dtype=torch.float16)
186 |
187 | # initialize UNet LoRA
188 | unet_lora_attn_procs = {}
189 | for name, attn_processor in unet.attn_processors.items():
190 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
191 | if name.startswith("mid_block"):
192 | hidden_size = unet.config.block_out_channels[-1]
193 | elif name.startswith("up_blocks"):
194 | block_id = int(name[len("up_blocks.")])
195 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
196 | elif name.startswith("down_blocks"):
197 | block_id = int(name[len("down_blocks.")])
198 | hidden_size = unet.config.block_out_channels[block_id]
199 | else:
200 | raise NotImplementedError("name must start with up_blocks, mid_block, or down_blocks")
201 |
202 | if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
203 | lora_attn_processor_class = LoRAAttnAddedKVProcessor
204 | else:
205 | lora_attn_processor_class = (
206 | LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
207 | )
208 | unet_lora_attn_procs[name] = lora_attn_processor_class(
209 | hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank
210 | )
211 |
212 | unet.set_attn_processor(unet_lora_attn_procs)
213 | unet_lora_layers = AttnProcsLayers(unet.attn_processors)
214 |
215 | # Optimizer creation
216 | params_to_optimize = unet_lora_layers.parameters()
217 | optimizer = torch.optim.AdamW(
218 | params_to_optimize,
219 | lr=lora_lr,
220 | betas=(0.9, 0.999),
221 | weight_decay=1e-2,
222 | eps=1e-08,
223 | )
224 |
225 | lr_scheduler = get_scheduler(
226 | "constant",
227 | optimizer=optimizer,
228 | num_warmup_steps=0,
229 | num_training_steps=lora_step,
230 | num_cycles=1,
231 | power=1.0,
232 | )
233 |
234 | # prepare accelerator
235 | unet_lora_layers = accelerator.prepare_model(unet_lora_layers)
236 | optimizer = accelerator.prepare_optimizer(optimizer)
237 | lr_scheduler = accelerator.prepare_scheduler(lr_scheduler)
238 |
239 | # initialize text embeddings
240 | with torch.no_grad():
241 | if is_sdxl:
242 | text_embedding, _, _, _ = model.encode_prompt(prompt)
243 | text_embedding = text_embedding.repeat(lora_batch_size, 1, 1).half()
244 | else:
245 | text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None)
246 | text_embedding = encode_prompt(
247 | text_encoder,
248 | text_inputs.input_ids,
249 | text_inputs.attention_mask,
250 | text_encoder_use_attention_mask=False
251 | )
252 | text_embedding = text_embedding.repeat(lora_batch_size, 1, 1)
253 |
254 | # initialize latent distribution
255 | image_transforms = transforms.Compose(
256 | [
257 | transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
258 | transforms.RandomCrop(512),
259 | transforms.ToTensor(),
260 | transforms.Normalize([0.5], [0.5]),
261 | ]
262 | )
263 |
264 | unet.train()
265 | image_batch = []
266 | for _ in range(lora_batch_size):
267 | image_transformed = image_transforms(Image.fromarray(image)).to(device, dtype=torch.float16)
268 | image_transformed = image_transformed.unsqueeze(dim=0)
269 | image_batch.append(image_transformed)
270 |
271 | # repeat the image_transformed to enable multi-batch training
272 | image_batch = torch.cat(image_batch, dim=0)
273 |
274 | latents_dist = vae.encode(image_batch).latent_dist
275 |
276 | if use_gradio_progress:
277 | progress_bar = progress.tqdm(range(lora_step), desc="training LoRA")
278 | else:
279 | progress_bar = tqdm(range(lora_step), desc="training LoRA")
280 | for step in progress_bar:
281 | # unet.train()
282 | # image_batch = []
283 | # for _ in range(lora_batch_size):
284 | # image_transformed = image_transforms(Image.fromarray(image)).to(device, dtype=torch.float16)
285 | # image_transformed = image_transformed.unsqueeze(dim=0)
286 | # image_batch.append(image_transformed)
287 |
288 | # repeat the image_transformed to enable multi-batch training
289 | # image_batch = torch.cat(image_batch, dim=0)
290 |
291 | # latents_dist = vae.encode(image_batch).latent_dist
292 | model_input = latents_dist.sample() * vae.config.scaling_factor
293 | # Sample noise that we'll add to the latents
294 | noise = torch.randn_like(model_input)
295 | bsz, channels, height, width = model_input.shape
296 | # Sample a random timestep for each image
297 | timesteps = torch.randint(
298 | 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
299 | )
300 | timesteps = timesteps.long()
301 |
302 | # Add noise to the model input according to the noise magnitude at each timestep
303 | # (this is the forward diffusion process)
304 | noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
305 |
306 | # Predict the noise residual
307 | model_pred = unet(noisy_model_input, timesteps, text_embedding).sample
308 |
309 | # Get the target for loss depending on the prediction type
310 | if noise_scheduler.config.prediction_type == "epsilon":
311 | target = noise
312 | elif noise_scheduler.config.prediction_type == "v_prediction":
313 | target = noise_scheduler.get_velocity(model_input, noise, timesteps)
314 | else:
315 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
316 |
317 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
318 | accelerator.backward(loss)
319 | optimizer.step()
320 | lr_scheduler.step()
321 | optimizer.zero_grad()
322 |
323 | if save_interval > 0 and (step + 1) % save_interval == 0:
324 | save_lora_path_intermediate = os.path.join(save_lora_path, str(step + 1))
325 | if not os.path.isdir(save_lora_path_intermediate):
326 | os.mkdir(save_lora_path_intermediate)
327 | # unet = unet.to(torch.float32)
328 | # unwrap_model is used to remove all special modules added when doing distributed training
329 | # so here, there is no need to call unwrap_model
330 | # unet_lora_layers = accelerator.unwrap_model(unet_lora_layers)
331 | LoraLoaderMixin.save_lora_weights(
332 | save_directory=save_lora_path_intermediate,
333 | unet_lora_layers=unet_lora_layers,
334 | text_encoder_lora_layers=None,
335 | )
336 | # unet = unet.to(torch.float16)
337 |
338 | # save the trained lora
339 | # unet = unet.to(torch.float32)
340 | # unwrap_model is used to remove all special modules added when doing distributed training
341 | # so here, there is no need to call unwrap_model
342 | # unet_lora_layers = accelerator.unwrap_model(unet_lora_layers)
343 | LoraLoaderMixin.save_lora_weights(
344 | save_directory=save_lora_path,
345 | unet_lora_layers=unet_lora_layers,
346 | text_encoder_lora_layers=None,
347 | )
348 |
349 | return
350 |
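# A minimal sketch of calling train_lora outside the Gradio UI. The image path,
# checkpoint name, output directory, and hyperparameter values below are illustrative
# assumptions, not values fixed by this file:
#
#   import numpy as np
#   from PIL import Image
#
#   image = np.array(Image.open("dataset/cat_2/original.jpg"))
#   train_lora(image,
#              prompt="",
#              model_path="runwayml/stable-diffusion-v1-5",
#              vae_path="default",
#              save_lora_path="./lora_tmp",
#              lora_step=80,
#              lora_lr=5e-4,
#              lora_batch_size=4,
#              lora_rank=16,
#              progress=None,
#              use_gradio_progress=False)  # progress is ignored here; a plain tqdm bar is used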
--------------------------------------------------------------------------------
/utils/ui_utils.py:
--------------------------------------------------------------------------------
1 | # *************************************************************************
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # *************************************************************************
14 |
15 |
16 | import os
17 | import shutil
18 | import json
19 | from pathlib import Path
20 | from typing import List, Tuple
21 |
22 | import cv2
23 | import numpy as np
24 | import gradio as gr
25 | from copy import deepcopy
26 | from einops import rearrange
27 | from types import SimpleNamespace
28 |
29 | import datetime
30 | import PIL
31 | from PIL import Image
32 | from PIL.ImageOps import exif_transpose
33 | import torch
34 | import torch.nn.functional as F
35 |
36 | from diffusers import DDIMScheduler, AutoencoderKL
37 | from pipeline import GoodDragger
38 |
39 | from torchvision.utils import save_image
40 | from pytorch_lightning import seed_everything
41 |
42 | from .lora_utils import train_lora
43 |
44 |
45 | # -------------- general UI functionality --------------
46 | def clear_all(length=512):
47 | return gr.Image.update(value=None, height=length, width=length), \
48 | gr.Image.update(value=None, height=length, width=length), \
49 | gr.Image.update(value=None, height=length, width=length), \
50 | [], None, None
51 |
52 |
53 | def mask_image(image,
54 | mask,
55 | color=[255, 0, 0],
56 | alpha=0.5):
57 | """ Overlay mask on image for visualization purpose.
58 | Args:
59 | image (H, W, 3) or (H, W): input image
60 | mask (H, W): mask to be overlaid
61 | color: the color of overlaid mask
62 | alpha: the transparency of the mask
63 | """
64 | out = deepcopy(image)
65 | img = deepcopy(image)
66 | img[mask == 1] = color
67 | out = cv2.addWeighted(img, alpha, out, 1 - alpha, 0, out)
68 | return out
69 |
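# A small worked example with made-up values: overlay a translucent red square
# on a blank 64x64 image.
#
#   img = np.zeros((64, 64, 3), dtype=np.uint8)
#   m = np.zeros((64, 64), dtype=np.uint8)
#   m[16:48, 16:48] = 1
#   overlay = mask_image(img, m, color=[255, 0, 0], alpha=0.5)  # red wherever m == 1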
70 |
71 | def store_img(img, length=512):
72 | image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255.
73 | height, width, _ = image.shape
74 | image = Image.fromarray(image)
75 | image = exif_transpose(image)
76 | image = image.resize((length, int(length * height / width)), PIL.Image.BILINEAR)
77 | mask = cv2.resize(mask, (length, int(length * height / width)), interpolation=cv2.INTER_NEAREST)
78 | image = np.array(image)
79 |
80 | if mask.sum() > 0:
81 | mask = np.uint8(mask > 0)
82 | masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3)
83 | else:
84 | masked_img = image.copy()
85 | # when new image is uploaded, `selected_points` should be empty
86 | return image, [], masked_img, mask
87 |
88 |
89 | # the user clicks the image to select points; show the selected points on the image
90 | def get_points(img,
91 | sel_pix,
92 | evt: gr.SelectData):
93 | # collect the selected point
94 | sel_pix.append(evt.index)
95 | # draw points
96 | points = []
97 | for idx, point in enumerate(sel_pix):
98 | if idx % 2 == 0:
99 | # draw a red circle at the handle point
100 | cv2.circle(img, tuple(point), 5, (255, 0, 0), -1)
101 | else:
102 | # draw a blue circle at the target point
103 | cv2.circle(img, tuple(point), 5, (0, 0, 255), -1)
104 | points.append(tuple(point))
105 | # draw an arrow from handle point to target point
106 | if len(points) == 2:
107 | cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 4, tipLength=0.5)
108 | points = []
109 | return img if isinstance(img, np.ndarray) else np.array(img)
110 |
111 |
112 | def show_cur_points(img,
113 | sel_pix,
114 | bgr=False):
115 | # draw points
116 | points = []
117 | for idx, point in enumerate(sel_pix):
118 | if idx % 2 == 0:
119 | # draw a red circle at the handle point
120 | red = (255, 0, 0) if not bgr else (0, 0, 255)
121 | cv2.circle(img, tuple(point), 5, red, -1)
122 | else:
123 | # draw a blue circle at the target point
124 | blue = (0, 0, 255) if not bgr else (255, 0, 0)
125 | cv2.circle(img, tuple(point), 5, blue, -1)
126 | points.append(tuple(point))
127 | # draw an arrow from handle point to target point
128 | if len(points) == 2:
129 | cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 4, tipLength=0.5)
130 | points = []
131 | return img if isinstance(img, np.ndarray) else np.array(img)
132 |
133 |
134 | # clear all handle/target points
135 | def undo_points(original_image,
136 | mask):
137 | if mask.sum() > 0:
138 | mask = np.uint8(mask > 0)
139 | masked_img = mask_image(original_image, 1 - mask, color=[0, 0, 0], alpha=0.3)
140 | else:
141 | masked_img = original_image.copy()
142 | return masked_img, []
143 |
144 |
145 | def clear_folder(folder_path):
146 | # Check if the folder exists
147 | if os.path.exists(folder_path):
148 | # Iterate over all the files and directories in the folder
149 | for filename in os.listdir(folder_path):
150 | file_path = os.path.join(folder_path, filename)
151 |
152 | # Check if it's a file or a directory
153 | if os.path.isfile(file_path) or os.path.islink(file_path):
154 | os.unlink(file_path) # Remove the file or link
155 | elif os.path.isdir(file_path):
156 | shutil.rmtree(file_path) # Remove the directory and all its contents
157 |
158 |
159 | def train_lora_interface(original_image,
160 | prompt,
161 | model_path,
162 | vae_path,
163 | lora_path,
164 | lora_step,
165 | lora_lr,
166 | lora_batch_size,
167 | lora_rank,
168 | progress=gr.Progress(),
169 | use_gradio_progress=True):
170 | if not os.path.exists(lora_path):
171 | os.makedirs(lora_path)
172 |
173 | clear_folder(lora_path)
174 |
175 | train_lora(
176 | original_image,
177 | prompt,
178 | model_path,
179 | vae_path,
180 | lora_path,
181 | lora_step,
182 | lora_lr,
183 | lora_batch_size,
184 | lora_rank,
185 | progress,
186 | use_gradio_progress)
187 | return "Training LoRA Done!"
188 |
189 |
190 | def preprocess_image(image,
191 | device):
192 | image = torch.from_numpy(image).float() / 127.5 - 1 # [-1, 1]
193 | image = rearrange(image, "h w c -> 1 c h w")
194 | image = image.to(device)
195 | return image
196 |
197 |
198 | def save_images_with_pillow(images, base_filename='image'):
199 | for index, img in enumerate(images):
200 | # Convert array to Image object and save
201 | img_pil = Image.fromarray(img)
202 | folder_path = './save'; os.makedirs(folder_path, exist_ok=True)  # make sure the output folder exists
203 | filename = os.path.join(folder_path, "{}_{}.png".format(base_filename, index))
204 | img_pil.save(filename)
205 | print(f"Saved: {filename}")
206 |
207 |
208 | def get_original_points(handle_points: List[torch.Tensor],
209 | full_h: int,
210 | full_w: int,
211 | sup_res_w,
212 | sup_res_h,
213 | ) -> List[torch.Tensor]:
214 | """
215 | Convert handle points from the dragger's working coordinate space back to their original UI coordinates.
216 | 
217 | Args:
218 | handle_points: List of handle points in the dragger's working (row, col) coordinates.
219 | full_h: Original height of the UI canvas.
220 | full_w: Original width of the UI canvas.
221 | sup_res_w: Width of the coordinate space in which handle_points are expressed (dragger.sup_res_w).
222 | sup_res_h: Height of the coordinate space in which handle_points are expressed (dragger.sup_res_h).
223 | 
224 | Returns:
225 | original_handle_points: List of handle points in original UI (x, y) coordinates.
226 | """
227 | original_handle_points = []
228 |
229 | for cur_point in handle_points:
230 | original_point = torch.round(
231 | torch.tensor([cur_point[1] * full_w / sup_res_w, cur_point[0] * full_h / sup_res_h]))
232 | original_handle_points.append(original_point)
233 |
234 | return original_handle_points
235 |
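# A worked example with assumed values (a 512x512 canvas and a 256x256 working
# resolution, i.e. half the canvas size): a working-space point (row=100, col=50)
# maps back to UI coordinates (x=100, y=200).
#
#   pts = get_original_points([torch.tensor([100., 50.])],
#                             full_h=512, full_w=512,
#                             sup_res_w=256, sup_res_h=256)
#   # pts[0] == tensor([100., 200.])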
236 |
237 | def save_image_mask_points(mask, points, image_with_points, output_dir='./saved_data'):
238 | """
239 | Saves the mask and points to the specified directory.
240 |
241 | Args:
242 | mask: The mask data as a numpy array.
243 | points: The list of points collected from the user interaction.
244 | image_with_points: The image with points clicked by the user.
245 | output_dir: The directory where to save the data.
246 | """
247 | os.makedirs(output_dir, exist_ok=True)
248 |
249 | # Save mask
250 | mask_path = os.path.join(output_dir, f"mask.png")
251 | Image.fromarray(mask.astype(np.uint8) * 255).save(mask_path)
252 |
253 | # Save points
254 | points_path = os.path.join(output_dir, f"points.json")
255 | with open(points_path, 'w') as f:
256 | json.dump({'points': points}, f)
257 |
258 | image_with_points_path = os.path.join(output_dir, "image_with_points.jpg")
259 | Image.fromarray(image_with_points).save(image_with_points_path)
260 |
261 | return
262 |
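# With output_dir="./dataset/my_case" (an illustrative path), this produces the
# same layout as the bundled examples under dataset/:
#
#   dataset/my_case/mask.png               binary mask saved as 0/255
#   dataset/my_case/points.json            {"points": [[x1, y1], [x2, y2], ...]},
#                                          alternating handle and target points
#   dataset/my_case/image_with_points.jpg  the source image annotated with the clicks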
263 |
264 | def save_drag_result(output_image, new_points, result_path):
265 | output_image = cv2.cvtColor(output_image, cv2.COLOR_RGB2BGR)
266 |
267 | result_dir = result_path
268 | os.makedirs(result_dir, exist_ok=True)
269 | output_image_path = os.path.join(result_dir, 'output_image.png')
270 | cv2.imwrite(output_image_path, output_image)
271 |
272 | img_with_new_points = show_cur_points(np.ascontiguousarray(output_image), new_points, bgr=True)
273 | new_points_image_path = os.path.join(result_dir, 'image_with_new_points.png')
274 | cv2.imwrite(new_points_image_path, img_with_new_points)
275 |
276 | points_path = os.path.join(result_dir, f'new_points.json')
277 | with open(points_path, 'w') as f:
278 | json.dump({'points': new_points}, f)
279 |
280 |
281 | def save_intermediate_images(intermediate_images, result_dir):
282 | for i in range(len(intermediate_images)):
283 | intermediate_images[i] = cv2.cvtColor(intermediate_images[i], cv2.COLOR_RGB2BGR)
284 | intermediate_images_path = os.path.join(result_dir, f'output_image_{i}.png')
285 | cv2.imwrite(intermediate_images_path, intermediate_images[i])
286 |
287 |
288 | def create_video(image_folder, data_folder, fps=2, first_frame_duration=2, last_frame_extra_duration=2):
289 | """
290 | Creates an MP4 video (dragging.mp4) from the annotated input image and the saved intermediate drag results using OpenCV.
291 | """
292 | img_folder = Path(image_folder)
293 | img_num = len(list(img_folder.glob('*.png')))
294 |
295 | # Path to the original image with points
296 | data_folder = Path(data_folder)
297 | original_path = data_folder / 'image_with_points.jpg'
298 | output_path = img_folder / 'dragging.mp4'
299 | # use the annotated original image as the first frame
300 | img_files = [original_path]
301 |
302 | # Load the first image to determine the size
303 | frame = cv2.imread(str(img_files[0]))
304 | height, width, layers = frame.shape
305 | size = (int(width), int(height))
306 |
307 | # Define the codec and create VideoWriter object
308 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') # 'mp4v' for .mp4 format
309 | video = cv2.VideoWriter(str(output_path), fourcc, int(fps), size)
310 |
311 | for _ in range(int(fps * first_frame_duration)):
312 | video.write(frame)
313 |
314 | # Add images to video
315 | for i in range(img_num - 2):
316 | video.write(cv2.imread(str(img_folder / f'output_image_{i}.png')))
317 |
318 | last_frame = cv2.imread(str(img_folder / 'output_image.png'))
319 | for _ in range(int(fps * last_frame_extra_duration)):
320 | video.write(last_frame)
321 |
322 | video.release()
323 |
324 |
325 | def run_gooddrag(source_image,
326 | image_with_clicks,
327 | mask,
328 | prompt,
329 | points,
330 | inversion_strength,
331 | lam,
332 | latent_lr,
333 | model_path,
334 | vae_path,
335 | lora_path,
336 | drag_end_step,
337 | track_per_step,
338 | r1,
339 | r2,
340 | d,
341 | max_drag_per_track,
342 | max_track_no_change,
343 | feature_idx=3,
344 | result_save_path='',
345 | return_intermediate_images=True,
346 | drag_loss_threshold=0,
347 | save_intermedia=False,
348 | compare_mode=False,
349 | once_drag=False,
350 | ):
351 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
352 | height, width = source_image.shape[:2]
353 | n_inference_step = 50
354 | guidance_scale = 1.0
355 | seed = 42
356 | dragger = GoodDragger(device, model_path, prompt, height, width, inversion_strength, r1, r2, d,
357 | drag_end_step, track_per_step, lam, latent_lr,
358 | n_inference_step, guidance_scale, feature_idx, compare_mode, vae_path, lora_path, seed,
359 | max_drag_per_track, drag_loss_threshold, once_drag, max_track_no_change)
360 |
361 | source_image = preprocess_image(source_image, device)
362 |
363 | gen_image, intermediate_features, new_points_handle, intermediate_images = \
364 | dragger.good_drag(source_image, points,
365 | mask,
366 | return_intermediate_images=return_intermediate_images)
367 |
368 | new_points_handle = get_original_points(new_points_handle, height, width, dragger.sup_res_w, dragger.sup_res_h)
369 | if save_intermedia:
370 | drag_image = [dragger.latent2image(i.cuda()) for i in intermediate_features]
371 | save_images_with_pillow(drag_image, base_filename='drag_image')
372 |
373 | gen_image = F.interpolate(gen_image, (height, width), mode='bilinear')
374 |
375 | out_image = gen_image.cpu().permute(0, 2, 3, 1).numpy()[0]
376 | out_image = (out_image * 255).astype(np.uint8)
377 |
378 | new_points = []
379 | for i in range(len(new_points_handle)):
380 | new_cur_handle_points = new_points_handle[i].numpy().tolist()
381 | new_cur_handle_points = [int(point) for point in new_cur_handle_points]
382 | new_points.append(new_cur_handle_points)
383 | new_points.append(points[i * 2 + 1])
384 |
385 | print(f'points {points}')
386 | print(f'new points {new_points}')
387 |
388 | if return_intermediate_images:
389 | os.makedirs(result_save_path, exist_ok=True)
390 | for i in range(len(intermediate_images)):
391 | intermediate_images[i] = F.interpolate(intermediate_images[i], (height, width), mode='bilinear')
392 | intermediate_images[i] = intermediate_images[i].cpu().permute(0, 2, 3, 1).numpy()[0]
393 | intermediate_images[i] = (intermediate_images[i] * 255).astype(np.uint8)
394 |
395 | for i in range(len(intermediate_images)):
396 | intermediate_images[i] = cv2.cvtColor(intermediate_images[i], cv2.COLOR_RGB2BGR)
397 | intermediate_images_path = os.path.join(result_save_path, f'output_image_{i}.png')
398 | cv2.imwrite(intermediate_images_path, intermediate_images[i])
399 |
400 | return out_image, new_points
401 |
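# A minimal offline driver, sketched under assumptions: the inputs follow the
# dataset/cat_2 layout shipped with this repo, a LoRA has already been trained
# into ./lora_tmp (e.g. via train_lora_interface), and the hyperparameter values
# below are illustrative placeholders rather than recommended settings.
#
#   import json
#   import numpy as np
#   from PIL import Image
#
#   image = np.array(Image.open("dataset/cat_2/original.jpg"))
#   mask = np.uint8(np.array(Image.open("dataset/cat_2/mask.png").convert("L")) > 0)
#   points = json.load(open("dataset/cat_2/points.json"))["points"]
#
#   out_image, new_points = run_gooddrag(
#       source_image=image, image_with_clicks=image, mask=mask, prompt="",
#       points=points, inversion_strength=0.75, lam=0.2, latent_lr=0.02,
#       model_path="runwayml/stable-diffusion-v1-5", vae_path="default",
#       lora_path="./lora_tmp", drag_end_step=7, track_per_step=10,
#       r1=4, r2=12, d=4, max_drag_per_track=3, max_track_no_change=10,
#       result_save_path="./results/cat_2")
#
#   Image.fromarray(out_image).save("./results/cat_2/output_image.png")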
--------------------------------------------------------------------------------
/webui.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 |
3 | REM Temporary file for modified requirements
4 | set TEMP_REQ_FILE=temp_requirements.txt
5 |
6 | REM Detect CUDA version using nvcc
7 | for /f "delims=" %%i in ('nvcc --version ^| findstr /i "release"') do set "CUDA_VER_FULL=%%i"
8 |
9 | REM Extract the version number, assuming it's in the format "Cuda compilation tools, release 11.8, V11.8.89"
10 | for /f "tokens=5 delims=, " %%a in ("%CUDA_VER_FULL%") do set "CUDA_VER=%%a"
11 | for /f "tokens=1 delims=vV" %%b in ("%CUDA_VER%") do set "CUDA_VER=%%b"
12 | for /f "tokens=1-2 delims=." %%c in ("%CUDA_VER%") do set "CUDA_MAJOR=%%c" & set "CUDA_MINOR=%%d"
13 |
14 | REM Concatenate major and minor version numbers to form the CUDA tag
15 | set "CUDA_TAG=cu%CUDA_MAJOR%%CUDA_MINOR%"
16 |
17 | echo Detected CUDA Tag: %CUDA_TAG%
18 |
19 | REM Modify the torch and torchvision lines in requirements.txt to include the CUDA version
20 | for /F "tokens=*" %%A in (requirements.txt) do (
21 | echo %%A | findstr /I "torch==" >nul
22 | if errorlevel 1 (
23 | echo %%A | findstr /I "torchvision==" >nul
24 | if errorlevel 1 (
25 | echo %%A >> "%TEMP_REQ_FILE%"
26 | ) else (
27 | echo torchvision==0.15.2+%CUDA_TAG% >> "%TEMP_REQ_FILE%"
28 | )
29 | ) else (
30 | echo torch==2.0.1+%CUDA_TAG% >> "%TEMP_REQ_FILE%"
31 | )
32 | )
33 |
34 | REM Replace the original requirements file with the modified one
35 | move /Y "%TEMP_REQ_FILE%" requirements.txt
36 |
37 |
38 | REM Define the virtual environment directory
39 | set VENV_DIR=GoodDrag
40 |
41 | REM Check if Python is installed
42 | python --version > nul 2>&1
43 | if %errorlevel% neq 0 (
44 | echo Python is not installed. Please install it first.
45 | pause
46 | exit /b
47 | )
48 |
49 |
50 |
51 | REM Create a virtual environment if it doesn't exist
52 | if not exist "%VENV_DIR%" (
53 | echo Creating virtual environment...
54 | python -m venv %VENV_DIR%
55 | )
56 |
57 | REM Activate the virtual environment
58 | call %VENV_DIR%\Scripts\activate.bat
59 |
60 | REM Install dependencies
61 | pip install -r requirements.txt
62 |
63 | REM Run the Python script
64 | echo Starting gooddrag_ui.py...
65 | python gooddrag_ui.py
66 |
67 | REM Deactivate the virtual environment on script exit
68 | call %VENV_DIR%\Scripts\deactivate.bat
69 |
70 | echo Script finished. Press any key to exit.
71 | pause > nul
72 |
--------------------------------------------------------------------------------
/webui.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Temporary file for modified requirements
4 | TEMP_REQ_FILE="temp_requirements.txt"
5 |
6 | # Detect CUDA version using nvcc
7 | CUDA_VER_FULL=$(nvcc --version | grep "release" | sed -n 's/.*release \([0-9]*\)\.\([0-9]*\),.*/\1\2/p')
8 |
9 | # Set the CUDA tag
10 | CUDA_TAG="cu$CUDA_VER_FULL"
11 |
12 | echo "Detected CUDA Tag: $CUDA_TAG"
13 |
14 | # Modify the torch and torchvision lines in requirements.txt to include the CUDA version
15 | while IFS= read -r line; do
16 | if [[ "$line" == torch==* ]]; then
17 | echo "torch==2.0.1+$CUDA_TAG" >> "$TEMP_REQ_FILE"
18 | elif [[ "$line" == torchvision==* ]]; then
19 | echo "torchvision==0.15.2+$CUDA_TAG" >> "$TEMP_REQ_FILE"
20 | else
21 | echo "$line" >> "$TEMP_REQ_FILE"
22 | fi
23 | done < requirements.txt
24 |
25 | # Replace the original requirements file with the modified one
26 | mv "$TEMP_REQ_FILE" requirements.txt
27 |
28 | # Define the virtual environment directory
29 | VENV_DIR="GoodDrag"
30 |
31 | # Check if Python 3 is installed
32 | if ! command -v python3 &> /dev/null; then
33 | echo "Python 3 is not installed. Please install it first."
34 | exit 1
35 | fi
36 |
37 | # Create a virtual environment if it doesn't exist
38 | if [ ! -d "$VENV_DIR" ]; then
39 | echo "Creating virtual environment..."
40 | python3 -m venv "$VENV_DIR"
41 | fi
42 |
43 | # Activate the virtual environment
44 | source "$VENV_DIR/bin/activate"
45 |
46 | # Install dependencies
47 | echo "Installing dependencies..."
48 | pip install -r requirements.txt
49 |
50 | # Run the Python script
51 | echo "Starting gooddrag_ui.py..."
52 | python3 gooddrag_ui.py
53 |
54 | # Deactivate the virtual environment on script exit
55 | deactivate
56 |
57 | echo "Script finished."
--------------------------------------------------------------------------------