├── LICENSE
├── README.md
├── T2I_inference.py
├── T2I_inference_merge_lora.py
├── assets
│   ├── example1.png
│   ├── example2.png
│   ├── example3.png
│   ├── example4.png
│   ├── example5.png
│   ├── hyperdreambooth.png
│   ├── hypernet.png
│   ├── img_3.png
│   ├── img_5.png
│   ├── img_6.png
│   ├── light_lora.png
│   ├── pipeline.png
│   ├── preoptnet.png
│   └── relax_relax_fast_finetune.png
├── export_hypernet_weight.py
├── export_hypernet_weight.sh
├── export_preoptnet_weight.py
├── export_preoptnet_weight.sh
├── fast_finetune.py
├── fast_finetune.sh
├── images
│   ├── 1.jpg
│   ├── 2.jpg
│   └── 3.jpg
├── modules
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-311.pyc
│   │   └── attention.cpython-311.pyc
│   ├── attention.py
│   ├── hypernet.py
│   ├── hypernet_test.py
│   ├── light_lora.py
│   ├── light_lora_pro.py
│   ├── lora.py
│   ├── ortho_lora.py
│   ├── relax_lora.py
│   ├── relax_lora_backup.py
│   └── utils
│       ├── __init__.py
│       ├── lora_utils.py
│       └── xformers_utils.py
├── rank_relax.py
├── rank_relax_test.py
├── requirements.txt
├── setup.py
├── train_dreambooth_light_lora.py
├── train_dreambooth_light_lora.sh
├── train_dreambooth_lora.py
├── train_dreambooth_lora.sh
├── train_hypernet.py
├── train_hypernet.sh
├── train_hypernet_pro.py
├── train_hypernet_pro.sh
├── train_preoptnet.py
└── train_preoptnet.sh
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [2023] [KohakuBlueLeaf]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
203 | ---
204 |
205 | HyperKohaku includes a number of components and libraries with separate copyright notices and license terms. Your use of those components is subject to the terms and conditions of the following licenses:
206 |
207 | - train_hyperdreambooth.py:
208 | Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch
209 | Project URL: https://github.com/huggingface/diffusers
210 | License: The Apache Software License, Version 2.0 (ASL-2.0) http://www.apache.org/licenses/LICENSE-2.0.txt
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # HyperDreamBooth
2 |
3 |
4 | ## Overview
5 |
6 | In this project, I provide a diffusers-based implementation of [HyperDreamBooth: HyperNetworks for Fast Personalization of Text-to-Image Models](https://arxiv.org/abs/2307.06949).
7 | Some of the code comes from this project: https://github.com/KohakuBlueleaf/HyperKohaku?tab=readme-ov-file
8 |
9 | ## Main Steps
10 |
11 |
12 |
13 |
14 | ### 1. Lightweight DreamBooth
15 |
16 | Run the scripts below to test if Lightweight DreamBooth is working properly. It should generate normal images just like a standard LoRA.
17 |
18 | `sh train_dreambooth_light_lora.sh`
19 | `python T2I_inference.py`
20 |
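The sketch below illustrates the lightweight-LoRA idea as configured elsewhere in this repo (rank=1, down_dim=160, up_dim=80, frozen "aux" matrices plus tiny trainable cores). It is a minimal illustration, not the actual `modules/light_lora.py` implementation; the class name, initialization scales, and matrix orientations are assumptions.

```python
import torch
import torch.nn as nn


class LightLoRASketch(nn.Module):
    """Hypothetical LiLoRA layer: only (down_dim + up_dim) * rank values are trainable."""

    def __init__(self, in_features, out_features, rank=1, down_dim=160, up_dim=80):
        super().__init__()
        # frozen random auxiliary projections (seed-reproducible in the real code)
        self.down_aux = nn.Parameter(torch.randn(in_features, down_dim) / in_features**0.5,
                                     requires_grad=False)
        self.up_aux = nn.Parameter(torch.randn(up_dim, out_features) / up_dim**0.5,
                                   requires_grad=False)
        # tiny trainable cores: these are the values a hypernetwork can later predict
        self.down = nn.Parameter(torch.randn(down_dim, rank) * 0.01)
        self.up = nn.Parameter(torch.zeros(rank, up_dim))  # zero init keeps the initial delta at 0

    def forward(self, hidden_states):
        # residual LoRA path: x -> down_aux -> down -> up -> up_aux
        return hidden_states @ self.down_aux @ self.down @ self.up @ self.up_aux
```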
21 |
22 |
23 |
24 |
25 | ### 2. Weight Pre-Optimization
26 |
27 | Run the scripts to pre-optimize the weights, then export the corresponding LoRA for a given identity. Batch training over multiple identities is supported.
28 |
29 | `sh train_preoptnet.sh`
30 | `sh export_preoptnet_weight.sh`
31 | `python T2I_inference.py`
32 |
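Conceptually, the pre-optimized weights can be pictured as a lookup table of LiLoRA weight vectors indexed by identity id, which is consistent with how `export_preoptnet_weight.py` queries `pre_opt_net([reference_image_id])`. The sketch below is an assumption for illustration, not the actual `PreOptHyperDream` module; the class name and the layer count are placeholders.

```python
import torch
import torch.nn as nn


class PreOptTableSketch(nn.Module):
    """Hypothetical stand-in: one pre-optimized LiLoRA weight vector per identity and layer."""

    def __init__(self, num_identities, num_layers, weight_dim):
        super().__init__()
        # trained during weight pre-optimization, one row per identity
        self.table = nn.Embedding(num_identities, num_layers * weight_dim)
        self.num_layers = num_layers
        self.weight_dim = weight_dim

    def forward(self, identity_ids):
        ids = torch.as_tensor(identity_ids, dtype=torch.long)
        flat = self.table(ids)  # (batch, num_layers * weight_dim)
        return flat.view(len(ids), self.num_layers, self.weight_dim)


# e.g. fetch the weight vectors for identity 10, as the export script does by id
# (96 layers is a placeholder count; weight_dim = (down_dim + up_dim) * rank)
weights = PreOptTableSketch(num_identities=100, num_layers=96, weight_dim=(160 + 80) * 1)([10])
```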
33 |
34 |
35 |
36 |
37 | ### 3. Hypernet Training
38 |
39 | Run the scripts for hypernetwork training, then export the corresponding LoRA based on the input image.
40 |
41 | `sh train_hypernet.sh`
42 | `sh export_hypernet_weight.sh`
43 | `python T2I_inference.py`
44 |
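Conceptually, the hypernetwork maps a ViT embedding of the reference face to one weight vector of size (down_dim + up_dim) * rank per LiLoRA layer; `export_hypernet_weight.py` then writes these vectors back into the layers and converts them to a standard LoRA. The sketch below only shows that mapping, with an assumed single linear head and placeholder sizes; the real `HyperDream` module in `modules/hypernet.py` is more elaborate.

```python
import torch
import torch.nn as nn


class HyperNetSketch(nn.Module):
    """Assumed minimal head: image embedding -> one weight vector per LiLoRA layer."""

    def __init__(self, img_embed_dim, num_layers, weight_dim):
        super().__init__()
        self.num_layers = num_layers
        self.weight_dim = weight_dim
        self.head = nn.Linear(img_embed_dim, num_layers * weight_dim)

    def forward(self, img_embedding):
        # img_embedding: (batch, img_embed_dim), e.g. from a timm ViT encoder
        flat = self.head(img_embedding)
        return flat.view(-1, self.num_layers, self.weight_dim)


# placeholder sizes: 768-dim image embedding, 96 LoRA layers, weight_dim = (160 + 80) * 1
predicted = HyperNetSketch(768, 96, (160 + 80) * 1)(torch.randn(1, 768))  # -> (1, 96, 240)
```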
45 |
46 |
47 |
48 |
49 |
50 | ### 4. Rank-Relaxed Fast Finetuning
51 |
52 | Rank-relaxed fast finetuning can be performed by simply adjusting the LoRALinearLayer. It comprises two LoRA structures: a frozen LoRA (r=1), initialized with the weights predicted by the Hypernet, and a trainable LoRA (r>1) with zero initialization. The frozen and trainable LoRAs are merged and exported as a standard LoRA, which is then reloaded for fast finetuning.
53 |
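A minimal sketch of such a layer follows, assuming plain down/up matrices (the repo's `modules/relax_lora.py` may organize this differently): the frozen rank-1 path carries the hypernetwork prediction, while the zero-initialized rank-r path is the part that gets trained.

```python
import torch
import torch.nn as nn


class RankRelaxedLoRASketch(nn.Module):
    def __init__(self, in_features, out_features, frozen_down, frozen_up, rank=4):
        super().__init__()
        # frozen rank-1 LoRA, initialized from the hypernetwork prediction
        self.frozen_down = nn.Parameter(frozen_down, requires_grad=False)  # (in_features, 1)
        self.frozen_up = nn.Parameter(frozen_up, requires_grad=False)      # (1, out_features)
        # trainable rank-r LoRA; zero-init up keeps the initial delta equal to the prediction
        self.down = nn.Parameter(torch.randn(in_features, rank) * 0.01)
        self.up = nn.Parameter(torch.zeros(rank, out_features))

    def forward(self, hidden_states):
        frozen_delta = hidden_states @ self.frozen_down @ self.frozen_up
        trainable_delta = hidden_states @ self.down @ self.up
        return frozen_delta + trainable_delta
```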
54 | The detailed steps are as follows:
55 |
56 | 1. Merge the linear layers of the frozen and trainable LoRA into a single linear layer. This can be done with simple multiplication and addition of their weight matrices (see the sketch after this list), which directly gives:
57 | Rank(merged_lora) ≤ Rank(frozen_lora) + Rank(trainable_lora).
58 |
59 | 2. Apply an SVD (with the diagonal matrix absorbed into the left factor) to the weight matrix of this merged linear layer to derive the final LoRA's Up and Down matrices. By retaining all non-zero singular values (N = Rank(merged_lora)), the SVD truncation error is zero, which is exactly the objective.
60 |
61 | 3. Restore the weights of the merged LoRA and perform standard LoRA finetuning. Approximately 25 steps are enough.
62 |
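The merge and re-factorization in steps 1-2 can be sketched as follows, assuming plain down/up matrices (names and shapes are illustrative, not the `rank_relax.py` API). Keeping r_frozen + r_trainable singular values guarantees zero truncation error because the merged delta cannot exceed that rank.

```python
import torch


def merge_and_refactor(frozen_down, frozen_up, train_down, train_up):
    # frozen_down: (in, 1), frozen_up: (1, out); train_down: (in, r), train_up: (r, out)
    delta = frozen_down @ frozen_up + train_down @ train_up    # merged LoRA delta
    new_rank = frozen_down.shape[1] + train_down.shape[1]      # Rank(delta) <= r_frozen + r_trainable
    U, S, Vh = torch.linalg.svd(delta, full_matrices=False)
    merged_down = U[:, :new_rank] * S[:new_rank]               # absorb S into the left factor
    merged_up = Vh[:new_rank, :]
    return merged_down, merged_up


# sanity check: the re-factored LoRA reproduces the merged delta exactly (up to float error)
fd, fu, td, tu = torch.randn(320, 1), torch.randn(1, 320), torch.randn(320, 4), torch.randn(4, 320)
down, up = merge_and_refactor(fd, fu, td, tu)
assert torch.allclose(down @ up, fd @ fu + td @ tu, atol=1e-3)
```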
63 | `python rank_relax.py`
64 | `sh fast_finetune.sh`
65 | `python T2I_inference.py`
66 |
67 |
68 |
69 |
70 |
71 | ### 5. Experiments
72 |
73 | 1. Set appropriate hyperparameters to train light_lora directly and verify that the results look normal. Then run batch training of light_lora in preoptnet over all identities.
74 |
75 | 2. When only the hypernetwork prediction is used, the local details show some flaws.
76 | 3. Applying rank-relaxed fast finetuning on top of the hypernetwork prediction reduces training to about 25 steps and significantly improves the results.
77 |
78 | 4. When combined with a style LoRA, the results are as follows.
79 |
80 | 5. If the Down Aux and Up Aux matrices are made learnable, the Weight Pre-Optimization and fast-finetuning stages are no longer needed, and the results are even better.
81 | `sh train_hypernet_pro.sh`
82 | `sh export_hypernet_weight.sh`
83 | `python T2I_inference.py`
84 |
85 |
86 |
87 | ------
88 |
89 | ## Main Contributor
90 |
91 | chenxinhua: chenxinhua1002@163.com
92 |
--------------------------------------------------------------------------------
/T2I_inference.py:
--------------------------------------------------------------------------------
1 | from diffusers import StableDiffusionPipeline, DDIMScheduler
2 | import torch
3 | import time
4 | pretrain_model_path="stable-diffusion-models/realisticVisionV40_v40VAE"
5 |
6 | noise_scheduler = DDIMScheduler(
7 | num_train_timesteps=1000,
8 | beta_start=0.00085,
9 | beta_end=0.012,
10 | beta_schedule="scaled_linear",
11 | clip_sample=False,
12 | set_alpha_to_one=False,
13 | steps_offset=1
14 | )
15 |
16 | pipe = StableDiffusionPipeline.from_pretrained(pretrain_model_path,
17 | torch_dtype=torch.float16,
18 | scheduler=noise_scheduler,
19 | requires_safety_checker=False)
20 |
21 |
22 |
23 |
24 | lora_model_path = "/projects/AIGC/experiments2/rank_relax/"
25 |
26 | pipe.load_lora_weights(lora_model_path, weight_name="pytorch_lora_weights.safetensors")
27 |
28 | # pipe.load_lora_weights(lora_model_path)
29 | prompt = "A [V] face"
30 |
31 | negative_prompt = "nsfw,easynegative"
32 | # negative_prompt = "nsfw, easynegative, paintings, sketches, (worst quality:2), low res, normal quality, ((monochrome)), skin spots, acnes, skin blemishes, extra fingers, fewer fingers, strange fingers, bad hand, mole, ((extra legs)), ((extra hands)), bad-hands-5"
33 | # prompt = "1girl, stulmna, exquisitely detailed skin, looking at viewer, ultra high res, delicate"
34 | # prompt = "A [v] face"
35 | # prompt = "A pixarstyle of a [V] face"
36 | # prompt = "A [V] face with bark skin"
37 | # prompt = "A [V] face"
38 | # prompt = "A professional portrait of a [V] face"
39 | # prompt = "1girl, lineart, monochrome"
40 |
41 | # prompt = "1girl,(exquisitely detailed skin:1.3), looking at viewer, ultra high res, delicate"
42 | # prompt = "1boy, a professional detailed high-quality image, looking at viewer"
43 | # prompt = "1girl, stulmno, solo, best quality, looking at viewer"
44 | # prompt = "1girl, solo, best quality, looking at viewer"
45 |
46 | # prompt = "(upper body: 1.5),(white background:1.4), (illustration:1.1),(best quality),(masterpiece:1.1),(extremely detailed CG unity 8k wallpaper:1.1), (colorful:0.9),(panorama shot:1.4),(full body:1.05),(solo:1.2), (ink splashing),(color splashing),((watercolor)), clear sharp focus,{ 1boy standing },((chinese style )),(flowers,woods),outdoors,rocks, looking at viewer, happy expression ,soft smile, detailed face, clothing decorative pattern details, black hair,black eyes, "
47 |
48 | pipe.to("cuda")
49 |
50 | t0 = time.time()
51 | for i in range(10):
52 | # image = pipe(prompt, negative_prompt=negative_prompt, height=512, width=512, num_inference_steps=30, guidance_scale=7.5,cross_attention_kwargs={"scale":1}).images[0]
53 | image = pipe(prompt, height=512, width=512, num_inference_steps=30).images[0]
54 | image.save("aigc_samples/test_%d.jpg" % i)
55 | t1 = time.time()
56 | print("time elapsed: %f"%((t1-t0)/10))
57 | print("LoRA: %s"%lora_model_path)
58 |
--------------------------------------------------------------------------------
/T2I_inference_merge_lora.py:
--------------------------------------------------------------------------------
1 | from diffusers import StableDiffusionPipeline, DDIMScheduler
2 | import torch
3 | import random
4 | pretrain_model_path="stable-diffusion-models/realisticVisionV40_v40VAE"
5 |
6 | noise_scheduler = DDIMScheduler(
7 | num_train_timesteps=1000,
8 | beta_start=0.00085,
9 | beta_end=0.012,
10 | beta_schedule="scaled_linear",
11 | clip_sample=False,
12 | set_alpha_to_one=False,
13 | steps_offset=1
14 | )
15 |
16 | pipe = StableDiffusionPipeline.from_pretrained(pretrain_model_path,
17 | torch_dtype=torch.float16,
18 | scheduler=noise_scheduler,
19 | requires_safety_checker=False)
20 |
21 |
22 |
23 | # personal lora
24 | dir1 = "projects/AIGC/lora_model_test"
25 | lora_model_path1 = "pytorch_lora_weights.safetensors"
26 | # prompt = "A [V] face"
27 | dir2="projects/AIGC/experiments/style_lora"
28 | lora_model_path2 = "pixarStyleModel_lora128.safetensors"
29 | # lora_model_path2 = "Watercolor_Painting_by_vizsumit.safetensors"
30 | # lora_model_path2 = "Professional_Portrait_1.5-000008.safetensors"
31 | prompt = "A pixarstyle of a [V] face"
32 | # prompt = "A watercolor paint of a [V] face"
33 | # prompt = "A professional portrait of a [V] face"
34 |
35 | negative_prompt = "nsfw,easynegative"
36 |
37 | pipe.to("cuda")
38 |
39 |
40 | pipe.load_lora_weights(dir1, weight_name=lora_model_path1, adapter_name="person")
41 | pipe.load_lora_weights(dir2, weight_name=lora_model_path2, adapter_name="style")
42 | # pipe.set_adapters(["person", "style"], adapter_weights=[0.6, 0.4]) #pixar
43 |
44 | pipe.set_adapters(["person", "style"], adapter_weights=[0.4, 0.4]) #watercolor
45 | # Fuses the LoRAs into the Unet
46 | pipe.fuse_lora()
47 |
48 | for i in range(10):
49 | seed = random.randint(0, 100)
50 | generator = torch.Generator(device="cuda").manual_seed(seed)
51 | # image = pipe(prompt, negative_prompt=negative_prompt, height=512, width=512, num_inference_steps=30, guidance_scale=7.5,cross_attention_kwargs={"scale":1}).images[0]
52 | # image = pipe(prompt, height=512, width=512, num_inference_steps=30, guidance_scale=7.5, cross_attention_kwargs={"scale": 1.0}, generator=torch.manual_seed(0)).images[0]
53 | # image = pipe(prompt, height=512, width=512, num_inference_steps=30, guidance_scale=7.5, cross_attention_kwargs={"scale": 1.0}).images[0]
54 | image = pipe(prompt, height=512, width=512, num_inference_steps=30, generator=generator).images[0]
55 |
56 | image.save("aigc_samples/test_after_export_%d.jpg" % i)
57 |
58 | # Gets the Unet back to the original state
59 | pipe.unfuse_lora()
--------------------------------------------------------------------------------
/assets/example1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/assets/example1.png
--------------------------------------------------------------------------------
/assets/example2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/assets/example2.png
--------------------------------------------------------------------------------
/assets/example3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/assets/example3.png
--------------------------------------------------------------------------------
/assets/example4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/assets/example4.png
--------------------------------------------------------------------------------
/assets/example5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/assets/example5.png
--------------------------------------------------------------------------------
/assets/hyperdreambooth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/assets/hyperdreambooth.png
--------------------------------------------------------------------------------
/assets/hypernet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/assets/hypernet.png
--------------------------------------------------------------------------------
/assets/img_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/assets/img_3.png
--------------------------------------------------------------------------------
/assets/img_5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/assets/img_5.png
--------------------------------------------------------------------------------
/assets/img_6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/assets/img_6.png
--------------------------------------------------------------------------------
/assets/light_lora.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/assets/light_lora.png
--------------------------------------------------------------------------------
/assets/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/assets/pipeline.png
--------------------------------------------------------------------------------
/assets/preoptnet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/assets/preoptnet.png
--------------------------------------------------------------------------------
/assets/relax_relax_fast_finetune.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/assets/relax_relax_fast_finetune.png
--------------------------------------------------------------------------------
/export_hypernet_weight.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 |
4 | # Modified by KohakuBlueLeaf
5 | # Modified from diffusers/example/dreambooth/train_dreambooth_lora.py
6 | # see original licensed below
7 | # =======================================================================
8 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
9 | #
10 | # Licensed under the Apache License, Version 2.0 (the "License");
11 | # you may not use this file except in compliance with the License.
12 | # You may obtain a copy of the License at
13 | #
14 | # http://www.apache.org/licenses/LICENSE-2.0
15 | #
16 | # Unless required by applicable law or agreed to in writing, software
17 | # distributed under the License is distributed on an "AS IS" BASIS,
18 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19 | # See the License for the specific language governing permissions and
20 | # =======================================================================
21 |
22 | import argparse
23 | import os
24 | from packaging import version
25 | from PIL import Image
26 | import torch
27 | from torchvision.transforms import CenterCrop, Resize, ToTensor, Compose, Normalize
28 | from diffusers import StableDiffusionPipeline
29 | from diffusers.models.attention_processor import (
30 | AttnAddedKVProcessor,
31 | AttnAddedKVProcessor2_0,
32 | SlicedAttnAddedKVProcessor,
33 | )
34 |
35 |
36 | from modules.light_lora import LoRALinearLayer, LoraLoaderMixin
37 | from modules.utils.lora_utils import unet_lora_state_dict, text_encoder_lora_state_dict
38 | from modules.hypernet import HyperDream
39 |
40 |
41 | def parse_args(input_args=None):
42 | parser = argparse.ArgumentParser(description="Simple example of a training script.")
43 | parser.add_argument(
44 | "--pretrained_model_name_or_path",
45 | type=str,
46 | default=None,
47 | required=True,
48 | help="Path to pretrained model or model identifier from huggingface.co/models.",
49 | )
50 | parser.add_argument(
51 | "--revision",
52 | type=str,
53 | default=None,
54 | required=False,
55 | help="Revision of pretrained model identifier from huggingface.co/models.",
56 | )
57 | parser.add_argument(
58 | "--hypernet_model_path",
59 | type=str,
60 | default=None,
61 | required=True,
62 | help="Path to pretrained hyperkohaku model",
63 | )
64 | parser.add_argument(
65 | "--output_dir",
66 | type=str,
67 | default=None,
68 | required=True,
69 | )
70 | parser.add_argument(
71 | "--reference_image_path",
72 | type=str,
73 | default=None,
74 | required=True,
75 | help="Path to reference image",
76 | )
77 | parser.add_argument(
78 | "--vit_model_name",
79 | type=str,
80 | default="vit_base_patch16_224",
81 | help="The ViT encoder used in hypernet encoder.",
82 | )
83 |
84 | parser.add_argument(
85 | "--rank",
86 | type=int,
87 | default=1,
88 | help=("The dimension of the LoRA update matrices."),
89 | )
90 | parser.add_argument(
91 | "--down_dim",
92 | type=int,
93 | default=160,
94 | help=("The dimension of the LoRA update matrices."),
95 | )
96 | parser.add_argument(
97 | "--up_dim",
98 | type=int,
99 | default=80,
100 | help=("The dimension of the LoRA update matrices."),
101 | )
102 | parser.add_argument(
103 | "--train_text_encoder",
104 | action="store_true",
105 | help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
106 | )
107 | parser.add_argument(
108 | "--patch_mlp",
109 | action="store_true",
110 | help="Whether to train the text encoder with mlp. If set, the text encoder should be float32 precision.",
111 | )
112 | parser.add_argument(
113 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
114 | )
115 |
116 | if input_args is not None:
117 | args = parser.parse_args(input_args)
118 | else:
119 | args = parser.parse_args()
120 | return args
121 |
122 | def main(args):
123 | # Load Model
124 | pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, torch_dtype=torch.float32)
125 | # pipe.to("cuda")
126 |
127 | unet = pipe.unet
128 | text_encoder = pipe.text_encoder
129 |
130 | unet_lora_parameters = []
131 | unet_lora_linear_layers = []
132 | for i, (attn_processor_name, attn_processor) in enumerate(unet.attn_processors.items()):
133 | print("unet.attn_processor->%d:%s" % (i, attn_processor_name), attn_processor)
134 | # Parse the attention module.
135 | attn_module = unet
136 | for n in attn_processor_name.split(".")[:-1]:
137 | attn_module = getattr(attn_module, n)
138 | print("attn_module:", attn_module)
139 |
140 | # Set the `lora_layer` attribute of the attention-related matrices.
141 | attn_module.to_q.set_lora_layer(
142 | LoRALinearLayer(
143 | in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features,
144 | rank=args.rank, down_dim=args.down_dim, up_dim=args.up_dim, is_train=False,
145 | )
146 | )
147 | attn_module.to_k.set_lora_layer(
148 | LoRALinearLayer(
149 | in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features,
150 | rank=args.rank, down_dim=args.down_dim, up_dim=args.up_dim, is_train=False,
151 | )
152 | )
153 | attn_module.to_v.set_lora_layer(
154 | LoRALinearLayer(
155 | in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features,
156 | rank=args.rank, down_dim=args.down_dim, up_dim=args.up_dim, is_train=False,
157 | )
158 | )
159 | attn_module.to_out[0].set_lora_layer(
160 | LoRALinearLayer(
161 | in_features=attn_module.to_out[0].in_features, out_features=attn_module.to_out[0].out_features,
162 | rank=args.rank, down_dim=args.down_dim, up_dim=args.up_dim, is_train=False,
163 | )
164 | )
165 | # Accumulate the LoRA params to optimize.
166 | unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
167 | unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
168 | unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
169 | unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
170 |
171 | # Accumulate the LoRALinearLayers.
172 | unet_lora_linear_layers.append(attn_module.to_q.lora_layer)
173 | unet_lora_linear_layers.append(attn_module.to_k.lora_layer)
174 | unet_lora_linear_layers.append(attn_module.to_v.lora_layer)
175 | unet_lora_linear_layers.append(attn_module.to_out[0].lora_layer)
176 |
177 | if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
178 | attn_module.add_k_proj.set_lora_layer(
179 | LoRALinearLayer(
180 | in_features=attn_module.add_k_proj.in_features,
181 | out_features=attn_module.add_k_proj.out_features,
182 | rank=args.rank, down_dim=args.down_dim, up_dim=args.up_dim, is_train=False,
183 | )
184 | )
185 | attn_module.add_v_proj.set_lora_layer(
186 | LoRALinearLayer(
187 | in_features=attn_module.add_v_proj.in_features,
188 | out_features=attn_module.add_v_proj.out_features,
189 | rank=args.rank, down_dim=args.down_dim, up_dim=args.up_dim, is_train=False,
190 | )
191 | )
192 | unet_lora_parameters.extend(attn_module.add_k_proj.lora_layer.parameters())
193 | unet_lora_parameters.extend(attn_module.add_v_proj.lora_layer.parameters())
194 |
195 | unet_lora_linear_layers.append(attn_module.add_k_proj.lora_layer)
196 | unet_lora_linear_layers.append(attn_module.add_v_proj.lora_layer)
197 |
198 | # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
199 | # So, instead, we monkey-patch the forward calls of its attention-blocks.
200 | if args.train_text_encoder:
201 | # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
202 | # if patch_mlp is True, the finetuning will cover the text encoder mlp,
203 | # otherwise only the text encoder attention, total lora is (12+12)*4=96
204 | text_lora_parameters, text_lora_linear_layers = LoraLoaderMixin._modify_text_encoder(text_encoder,
205 | dtype=torch.float32,
206 | rank=args.rank,
207 | down_dim=args.down_dim,
208 | up_dim=args.up_dim,
209 | patch_mlp=args.patch_mlp,
210 | is_train=False)
211 | # total loras
212 | lora_linear_layers = unet_lora_linear_layers + text_lora_linear_layers \
213 | if args.train_text_encoder else unet_lora_linear_layers
214 |
215 | if args.vit_model_name == "vit_base_patch16_224":
216 | img_encoder_model_name = "vit_base_patch16_224"
217 | ref_img_size = 224
218 | mean = [0.5000]
219 | std = [0.5000]
220 | elif args.vit_model_name == "vit_huge_patch14_clip_224":
221 | img_encoder_model_name = "vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k"
222 | ref_img_size = 224
223 | mean = [0.4815, 0.4578, 0.4082]
224 | std = [0.2686, 0.2613, 0.2758]
225 | elif args.vit_model_name == "vit_huge_patch14_clip_336":
226 | img_encoder_model_name = "vit_huge_patch14_clip_336.laion2b_ft_in12k_in1k"
227 | ref_img_size = 336
228 | mean = [0.4815, 0.4578, 0.4082]
229 | std = [0.2686, 0.2613, 0.2758]
230 | else:
231 | raise ValueError("%s is not supported!" % args.vit_model_name)
232 |
233 | hypernet_transposes = Compose([
234 | Resize(size=ref_img_size),
235 | CenterCrop(size=(ref_img_size, ref_img_size)),
236 | ToTensor(),
237 | Normalize(mean=mean, std=std),
238 | ])
239 |
240 | hypernetwork = HyperDream(
241 | img_encoder_model_name=img_encoder_model_name,
242 | ref_img_size=ref_img_size,
243 | weight_num=len(lora_linear_layers),
244 | weight_dim=(args.up_dim + args.down_dim) * args.rank,
245 | )
246 | hypernetwork.set_lilora(lora_linear_layers)
247 |
248 | if os.path.isdir(args.hypernet_model_path):
249 | path = os.path.join(args.hypernet_model_path, "hypernetwork.bin")
250 | weight = torch.load(path)
251 | sd = weight['hypernetwork']
252 | hypernetwork.load_state_dict(sd)
253 | else:
254 | weight = torch.load(args.hypernet_model_path)
255 | sd = weight['hypernetwork']
256 | hypernetwork.load_state_dict(sd)
257 |
258 | for i, lilora in enumerate(lora_linear_layers):
259 | seed = weight['aux_seed_%d' % i]
260 | down_aux = weight['down_aux_%d' % i]
261 | up_aux = weight['up_aux_%d' % i]
262 | lilora.update_aux(seed, down_aux, up_aux)
263 |
264 | print(f"Hypernet weights loaded from: {args.hypernet_model_path}")
265 |
266 | hypernetwork = hypernetwork.to("cuda")
267 | hypernetwork = hypernetwork.eval()
268 |
269 | ref_img = Image.open(args.reference_image_path).convert("RGB")
270 | ref_img = hypernet_transposes(ref_img).unsqueeze(0).to("cuda")
271 |
272 | with torch.no_grad():
273 | weight, weight_list = hypernetwork(ref_img)
274 | print("weight>>>>>>>>>>>:",weight.shape, weight)
275 |
276 | # convert down and up weights to linear layer as LoRALinearLayer
277 | for weight, lora_layer in zip(weight_list, lora_linear_layers):
278 | lora_layer.update_weight(weight)
279 | lora_layer.convert_to_standard_lora()
280 |
281 | unet_lora_layers_to_save = unet_lora_state_dict(unet)
282 | text_encoder_lora_layers_to_save = None
283 | if args.train_text_encoder:
284 | text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(text_encoder, patch_mlp=args.patch_mlp)
285 |
286 | LoraLoaderMixin.save_lora_weights(
287 | save_directory=args.output_dir,
288 | unet_lora_layers=unet_lora_layers_to_save,
289 | text_encoder_lora_layers=text_encoder_lora_layers_to_save,
290 | )
291 |
292 | print("Export LoRA to: %s"%args.output_dir)
293 | print("==================================complete======================================")
294 |
295 |
296 | if __name__ == "__main__":
297 | args = parse_args()
298 | main(args)
--------------------------------------------------------------------------------
/export_hypernet_weight.sh:
--------------------------------------------------------------------------------
1 | export MODEL_NAME="stable-diffusion-models/realisticVisionV40_v40VAE"
2 | export HYPER_WEIGHT_DIR="projects/AIGC/experiments2/hypernet/CelebA-HQ-10k-no-pretrain2"
3 | export OUTPUT_DIR="projects/AIGC/lora_model_test"
4 | export reference_image_path="dataset/FFHQ_test/00015.png"
5 |
6 |
7 | python "export_hypernet_weight.py" \
8 | --pretrained_model_name_or_path $MODEL_NAME \
9 | --hypernet_model_path $HYPER_WEIGHT_DIR \
10 | --output_dir $OUTPUT_DIR \
11 | --rank 1 \
12 | --down_dim 160 \
13 | --up_dim 80 \
14 | --train_text_encoder \
15 | --reference_image_path $reference_image_path
16 |
--------------------------------------------------------------------------------
/export_preoptnet_weight.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 |
4 | # Modified by KohakuBlueLeaf
5 | # Modified from diffusers/example/dreambooth/train_dreambooth_lora.py
6 | # see original licensed below
7 | # =======================================================================
8 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
9 | #
10 | # Licensed under the Apache License, Version 2.0 (the "License");
11 | # you may not use this file except in compliance with the License.
12 | # You may obtain a copy of the License at
13 | #
14 | # http://www.apache.org/licenses/LICENSE-2.0
15 | #
16 | # Unless required by applicable law or agreed to in writing, software
17 | # distributed under the License is distributed on an "AS IS" BASIS,
18 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19 | # See the License for the specific language governing permissions and
20 | # =======================================================================
21 |
22 | import argparse
23 | import os
24 | from packaging import version
25 |
26 | import torch
27 | from transformers import AutoTokenizer, PretrainedConfig
28 |
29 | from diffusers import (
30 | UNet2DConditionModel,
31 | )
32 | from diffusers.models.attention_processor import (
33 | AttnAddedKVProcessor,
34 | AttnAddedKVProcessor2_0,
35 | SlicedAttnAddedKVProcessor,
36 | )
37 | from diffusers import StableDiffusionPipeline
38 |
39 | from modules.light_lora import LoRALinearLayer, LoraLoaderMixin
40 | from modules.utils.lora_utils import unet_lora_state_dict, text_encoder_lora_state_dict
41 | from modules.hypernet import PreOptHyperDream
42 |
43 |
44 | def parse_args(input_args=None):
45 | parser = argparse.ArgumentParser(description="Simple example of a training script.")
46 | parser.add_argument(
47 | "--pretrained_model_name_or_path",
48 | type=str,
49 | default=None,
50 | required=True,
51 | help="Path to pretrained model or model identifier from huggingface.co/models.",
52 | )
53 | parser.add_argument(
54 | "--revision",
55 | type=str,
56 | default=None,
57 | required=False,
58 | help="Revision of pretrained model identifier from huggingface.co/models.",
59 | )
60 | parser.add_argument(
61 | "--pre_opt_weight_path",
62 | type=str,
63 | default=None,
64 | required=True,
65 | help="Path to pretrained hyperkohaku model",
66 | )
67 |
68 | parser.add_argument(
69 | "--output_dir",
70 | type=str,
71 | default=None,
72 | required=True,
73 | )
74 |
75 | parser.add_argument(
76 | "--rank",
77 | type=int,
78 | default=1,
79 | help=("The dimension of the LoRA update matrices."),
80 | )
81 | parser.add_argument(
82 | "--down_dim",
83 | type=int,
84 | default=160,
85 | help=("The dimension of the LoRA update matrices."),
86 | )
87 | parser.add_argument(
88 | "--up_dim",
89 | type=int,
90 | default=80,
91 | help=("The dimension of the LoRA update matrices."),
92 | )
93 | parser.add_argument(
94 | "--train_text_encoder",
95 | action="store_true",
96 | help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
97 | )
98 | parser.add_argument(
99 | "--patch_mlp",
100 | action="store_true",
101 | help="Whether to train the text encoder with mlp. If set, the text encoder should be float32 precision.",
102 | )
103 | parser.add_argument(
104 | "--reference_image_id",
105 | type=int,
106 | default=1,
107 | help=("id in the celeb-a dataset"),
108 | )
109 |
110 | parser.add_argument(
111 | "--total_identities",
112 | type=int,
113 | default=30000,
114 | help=("The identities size of the training dataset."),
115 | )
116 |
117 | parser.add_argument(
118 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
119 | )
120 | if input_args is not None:
121 | args = parser.parse_args(input_args)
122 | else:
123 | args = parser.parse_args()
124 | return args
125 |
126 |
127 | def main(args):
128 | # Load Model
129 | pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, torch_dtype=torch.float32)
130 | # pipe.to("cuda")
131 |
132 | unet = pipe.unet
133 | text_encoder = pipe.text_encoder
134 |
135 | unet_lora_parameters = []
136 | unet_lora_linear_layers = []
137 | for i, (attn_processor_name, attn_processor) in enumerate(unet.attn_processors.items()):
138 | print("unet.attn_processor->%d:%s" % (i, attn_processor_name), attn_processor)
139 | # Parse the attention module.
140 | attn_module = unet
141 | for n in attn_processor_name.split(".")[:-1]:
142 | attn_module = getattr(attn_module, n)
143 | print("attn_module:", attn_module)
144 |
145 | # Set the `lora_layer` attribute of the attention-related matrices.
146 | attn_module.to_q.set_lora_layer(
147 | LoRALinearLayer(
148 | in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features,
149 | rank=args.rank, down_dim=args.down_dim, up_dim=args.up_dim, is_train=False,
150 | )
151 | )
152 | attn_module.to_k.set_lora_layer(
153 | LoRALinearLayer(
154 | in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features,
155 | rank=args.rank, down_dim=args.down_dim, up_dim=args.up_dim, is_train=False,
156 | )
157 | )
158 | attn_module.to_v.set_lora_layer(
159 | LoRALinearLayer(
160 | in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features,
161 | rank=args.rank, down_dim=args.down_dim, up_dim=args.up_dim, is_train=False,
162 | )
163 | )
164 | attn_module.to_out[0].set_lora_layer(
165 | LoRALinearLayer(
166 | in_features=attn_module.to_out[0].in_features, out_features=attn_module.to_out[0].out_features,
167 | rank=args.rank, down_dim=args.down_dim, up_dim=args.up_dim, is_train=False,
168 | )
169 | )
170 | # Accumulate the LoRA params to optimize.
171 | unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
172 | unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
173 | unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
174 | unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
175 |
176 | # Accumulate the LoRALinearLayers.
177 | unet_lora_linear_layers.append(attn_module.to_q.lora_layer)
178 | unet_lora_linear_layers.append(attn_module.to_k.lora_layer)
179 | unet_lora_linear_layers.append(attn_module.to_v.lora_layer)
180 | unet_lora_linear_layers.append(attn_module.to_out[0].lora_layer)
181 |
182 | if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
183 | attn_module.add_k_proj.set_lora_layer(
184 | LoRALinearLayer(
185 | in_features=attn_module.add_k_proj.in_features, out_features=attn_module.add_k_proj.out_features,
186 | rank=args.rank, down_dim=args.down_dim, up_dim=args.up_dim, is_train=False,
187 | )
188 | )
189 | attn_module.add_v_proj.set_lora_layer(
190 | LoRALinearLayer(
191 | in_features=attn_module.add_v_proj.in_features, out_features=attn_module.add_v_proj.out_features,
192 | rank=args.rank, down_dim=args.down_dim, up_dim=args.up_dim, is_train=False,
193 | )
194 | )
195 | unet_lora_parameters.extend(attn_module.add_k_proj.lora_layer.parameters())
196 | unet_lora_parameters.extend(attn_module.add_v_proj.lora_layer.parameters())
197 |
198 | unet_lora_linear_layers.append(attn_module.add_k_proj.lora_layer)
199 | unet_lora_linear_layers.append(attn_module.add_v_proj.lora_layer)
200 |
201 | # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
202 | # So, instead, we monkey-patch the forward calls of its attention-blocks.
203 | if args.train_text_encoder:
204 | # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
205 | # if patch_mlp is True, the finetuning will cover the text encoder mlp,
206 | # otherwise only the text encoder attention, total lora is (12+12)*4=96
207 | text_lora_parameters, text_lora_linear_layers = LoraLoaderMixin._modify_text_encoder(text_encoder,
208 | dtype=torch.float32,
209 | rank=args.rank,
210 | down_dim=args.down_dim,
211 | up_dim=args.up_dim,
212 | patch_mlp=args.patch_mlp,
213 | is_train=False)
214 |
215 | lora_linear_layers = unet_lora_linear_layers + text_lora_linear_layers \
216 | if args.train_text_encoder else unet_lora_linear_layers
217 |
218 | # create PreOptHyperDream and set lilora
219 | pre_opt_net = PreOptHyperDream(args.rank, args.down_dim, args.up_dim)
220 | pre_opt_net.set_lilora(lora_linear_layers, args.total_identities)
221 |
222 | # load weight
223 | if os.path.isfile(args.pre_opt_weight_path):
224 | weight = torch.load(args.pre_opt_weight_path)
225 | else:
226 | weight = torch.load(os.path.join(args.pre_opt_weight_path, 'pre_optimized.bin'))
227 | # load pre-optimized lilora weights for each identity
228 | sd = weight['pre_optimized']
229 | pre_opt_net.load_state_dict(sd)
230 | pre_opt_net.requires_grad_(False)
231 | pre_opt_net.set_device('cuda')
232 |
233 | for i,lilora in enumerate(lora_linear_layers):
234 | seed = weight['aux_seed_%d'%i]
235 | down_aux = weight['down_aux_%d'%i]
236 | up_aux = weight['up_aux_%d'%i]
237 | lilora.update_aux(seed, down_aux, up_aux)
238 |
239 | print(f"PreOptNet weights loaded from: {args.pre_opt_weight_path}")
240 |
241 | with torch.no_grad():
242 | # get pre-optimized weights according to identity
243 | weights, weight_list = pre_opt_net([args.reference_image_id])
244 | for weight, lora_layer in zip(weight_list, lora_linear_layers):
245 | lora_layer.update_weight(weight)
246 | lora_layer.convert_to_standard_lora()
247 | print("weight",weight)
248 |
249 | unet_lora_layers_to_save = unet_lora_state_dict(unet)
250 | text_encoder_lora_layers_to_save = None
251 | if args.train_text_encoder:
252 | text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(text_encoder, patch_mlp=args.patch_mlp)
253 |
254 | LoraLoaderMixin.save_lora_weights(
255 | save_directory=args.output_dir,
256 | unet_lora_layers=unet_lora_layers_to_save,
257 | text_encoder_lora_layers=text_encoder_lora_layers_to_save,
258 | )
259 |
260 | print("Save LoRA to: %s"%args.output_dir)
261 | print("==================================complete======================================")
262 |
263 | if __name__ == "__main__":
264 | args = parse_args()
265 | main(args)
--------------------------------------------------------------------------------
/export_preoptnet_weight.sh:
--------------------------------------------------------------------------------
1 | export MODEL_NAME="stable-diffusion-models/realisticVisionV40_v40VAE"
2 | export PRE_OPTNET_WEIGHT_DIR="projects/AIGC/experiments/pretrained/CelebA-HQ-100"
3 | export OUTPUT_DIR="projects/AIGC/lora_model_test"
4 |
5 | python "export_preoptnet_weight.py" \
6 | --pretrained_model_name_or_path $MODEL_NAME \
7 | --pre_opt_weight_path $PRE_OPTNET_WEIGHT_DIR \
8 | --output_dir $OUTPUT_DIR \
10 | --rank 1 \
11 | --down_dim 160 \
12 | --up_dim 80 \
13 | --train_text_encoder \
14 | --total_identities 100 \
15 | --reference_image_id 10
16 |
--------------------------------------------------------------------------------
/fast_finetune.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 |
16 | import argparse
17 | import copy
18 | import gc
19 | import itertools
20 | import logging
21 | import math
22 | import os
23 | import time
24 | import shutil
25 | import warnings
26 | from pathlib import Path
27 |
28 | import numpy as np
29 | import torch
30 | import torch.nn.functional as F
31 | import torch.utils.checkpoint
32 | import transformers
33 | from accelerate import Accelerator
34 | from accelerate.logging import get_logger
35 | from accelerate.utils import ProjectConfiguration, set_seed
36 | from huggingface_hub import create_repo, upload_folder
37 | from huggingface_hub.utils import insecure_hashlib
38 | from packaging import version
39 | from PIL import Image
40 | from PIL.ImageOps import exif_transpose
41 | from torch.utils.data import Dataset
42 | from torchvision import transforms
43 | from tqdm.auto import tqdm
44 | from transformers import AutoTokenizer, PretrainedConfig
45 |
46 | import diffusers
47 | from diffusers import (
48 | AutoencoderKL,
49 | DDPMScheduler,
50 | DiffusionPipeline,
51 | DPMSolverMultistepScheduler,
52 | StableDiffusionPipeline,
53 | UNet2DConditionModel,
54 | )
55 | from diffusers.models.attention_processor import (
56 | AttnAddedKVProcessor,
57 | AttnAddedKVProcessor2_0,
58 | SlicedAttnAddedKVProcessor,
59 | )
60 | from diffusers.optimization import get_scheduler
61 | from diffusers.training_utils import unet_lora_state_dict
62 | from diffusers.utils import check_min_version, is_wandb_available
63 | from diffusers.utils.import_utils import is_xformers_available
64 |
65 | from modules.lora import LoRALinearLayer, LoraLoaderMixin
66 |
67 | # from diffusers.models.lora import LoRALinearLayer
68 | # from diffusers.loaders import LoraLoaderMixin
69 |
70 |
71 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
72 | check_min_version("0.25.0.dev0")
73 |
74 | logger = get_logger(__name__)
75 |
76 |
77 | # TODO: This function should be removed once training scripts are rewritten in PEFT
78 | def text_encoder_lora_state_dict(text_encoder, patch_mlp=False):
79 | state_dict = {}
80 |
81 | def text_encoder_attn_modules(text_encoder):
82 | from transformers import CLIPTextModel, CLIPTextModelWithProjection
83 |
84 | attn_modules = []
85 |
86 | if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
87 | for i, layer in enumerate(text_encoder.text_model.encoder.layers):
88 | name = f"text_model.encoder.layers.{i}.self_attn"
89 | mod = layer.self_attn
90 | attn_modules.append((name, mod))
91 |
92 | return attn_modules
93 |
94 | def text_encoder_mlp_modules(text_encoder):
95 | from transformers import CLIPTextModel, CLIPTextModelWithProjection
96 | mlp_modules = []
97 | if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
98 | for i, layer in enumerate(text_encoder.text_model.encoder.layers):
99 | name = f"text_model.encoder.layers.{i}.mlp"
100 | mod = layer.mlp
101 | mlp_modules.append((name, mod))
102 |
103 | return mlp_modules
104 |
105 | # text encoder attn layer
106 | for name, module in text_encoder_attn_modules(text_encoder):
107 | for k, v in module.q_proj.lora_linear_layer.state_dict().items():
108 | state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
109 |
110 | for k, v in module.k_proj.lora_linear_layer.state_dict().items():
111 | state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
112 |
113 | for k, v in module.v_proj.lora_linear_layer.state_dict().items():
114 | state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
115 |
116 | for k, v in module.out_proj.lora_linear_layer.state_dict().items():
117 | state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
118 |
119 | # text encoder mlp layer
120 | if patch_mlp:
121 | for name, module in text_encoder_mlp_modules(text_encoder):
122 | for k, v in module.fc1.lora_linear_layer.state_dict().items():
123 | state_dict[f"{name}.fc1.lora_linear_layer.{k}"] = v
124 |
125 | for k, v in module.fc2.lora_linear_layer.state_dict().items():
126 | state_dict[f"{name}.fc2.lora_linear_layer.{k}"] = v
127 |
128 | return state_dict
129 |
130 |
131 | def save_model_card(
132 | repo_id: str,
133 | images=None,
134 | base_model=str,
135 | train_text_encoder=False,
136 | prompt=str,
137 | repo_folder=None,
138 | pipeline: DiffusionPipeline = None,
139 | ):
140 | img_str = ""
141 | for i, image in enumerate(images):
142 | image.save(os.path.join(repo_folder, f"image_{i}.png"))
143 | img_str += f"\n"
144 |
145 | yaml = f"""
146 | ---
147 | license: creativeml-openrail-m
148 | base_model: {base_model}
149 | instance_prompt: {prompt}
150 | tags:
151 | - {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'}
152 | - {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'}
153 | - text-to-image
154 | - diffusers
155 | - lora
156 | inference: true
157 | ---
158 | """
159 | model_card = f"""
160 | # LoRA DreamBooth - {repo_id}
161 |
162 | These are LoRA adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n
163 | {img_str}
164 |
165 | LoRA for the text encoder was enabled: {train_text_encoder}.
166 | """
167 | with open(os.path.join(repo_folder, "README.md"), "w") as f:
168 | f.write(yaml + model_card)
169 |
170 |
171 | def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
172 | text_encoder_config = PretrainedConfig.from_pretrained(
173 | pretrained_model_name_or_path,
174 | subfolder="text_encoder",
175 | revision=revision,
176 | )
177 | model_class = text_encoder_config.architectures[0]
178 |
179 | if model_class == "CLIPTextModel":
180 | from transformers import CLIPTextModel
181 |
182 | return CLIPTextModel
183 | elif model_class == "RobertaSeriesModelWithTransformation":
184 | from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
185 |
186 | return RobertaSeriesModelWithTransformation
187 | elif model_class == "T5EncoderModel":
188 | from transformers import T5EncoderModel
189 |
190 | return T5EncoderModel
191 | else:
192 | raise ValueError(f"{model_class} is not supported.")
193 |
194 |
195 | def parse_args(input_args=None):
196 | parser = argparse.ArgumentParser(description="Simple example of a training script.")
197 | parser.add_argument(
198 | "--pretrained_model_name_or_path",
199 | type=str,
200 | default=None,
201 | required=True,
202 | help="Path to pretrained model or model identifier from huggingface.co/models.",
203 | )
204 | parser.add_argument(
205 | "--revision",
206 | type=str,
207 | default=None,
208 | required=False,
209 | help="Revision of pretrained model identifier from huggingface.co/models.",
210 | )
211 | parser.add_argument(
212 | "--variant",
213 | type=str,
214 | default=None,
215 |         help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
216 | )
217 | parser.add_argument(
218 | "--tokenizer_name",
219 | type=str,
220 | default=None,
221 | help="Pretrained tokenizer name or path if not the same as model_name",
222 | )
223 | parser.add_argument(
224 | "--instance_image_path",
225 | type=str,
226 | default=None,
227 | required=True,
228 | help="A folder containing the training data of instance images.",
229 | )
230 | parser.add_argument(
231 | "--class_data_dir",
232 | type=str,
233 | default=None,
234 | required=False,
235 | help="A folder containing the training data of class images.",
236 | )
237 | parser.add_argument(
238 | "--instance_prompt",
239 | type=str,
240 | default=None,
241 | required=True,
242 | help="The prompt with identifier specifying the instance",
243 | )
244 | parser.add_argument(
245 | "--class_prompt",
246 | type=str,
247 | default=None,
248 | help="The prompt to specify images in the same class as provided instance images.",
249 | )
250 | parser.add_argument(
251 | "--validation_prompt",
252 | type=str,
253 | default=None,
254 | help="A prompt that is used during validation to verify that the model is learning.",
255 | )
256 | parser.add_argument(
257 | "--num_validation_images",
258 | type=int,
259 | default=5,
260 | help="Number of images that should be generated during validation with `validation_prompt`.",
261 | )
262 |
263 | parser.add_argument(
264 | "--with_prior_preservation",
265 | default=False,
266 | action="store_true",
267 | help="Flag to add prior preservation loss.",
268 | )
269 | parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
270 | parser.add_argument(
271 | "--num_class_images",
272 | type=int,
273 | default=100,
274 | help=(
275 | "Minimal class images for prior preservation loss. If there are not enough images already present in"
276 | " class_data_dir, additional images will be sampled with class_prompt."
277 | ),
278 | )
279 | parser.add_argument(
280 | "--output_dir",
281 | type=str,
282 | default="lora-dreambooth-model",
283 | help="The output directory where the model predictions and checkpoints will be written.",
284 | )
285 | parser.add_argument(
286 | "--resume_dir",
287 | type=str,
288 | default="",
289 |         help="The output directory where the pretrained model predictions and checkpoints have been written.",
290 | )
291 |
292 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
293 | parser.add_argument(
294 | "--resolution",
295 | type=int,
296 | default=512,
297 | help=(
298 | "The resolution for input images, all the images in the train/validation dataset will be resized to this"
299 | " resolution"
300 | ),
301 | )
302 | parser.add_argument(
303 | "--center_crop",
304 | default=False,
305 | action="store_true",
306 | help=(
307 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
308 | " cropped. The images will be resized to the resolution first before cropping."
309 | ),
310 | )
311 | parser.add_argument(
312 | "--train_text_encoder",
313 | action="store_true",
314 |         help="Whether to train the text encoder. If set, the text encoder should be in float32 precision.",
315 | )
316 | parser.add_argument(
317 | "--patch_mlp",
318 | action="store_true",
319 |         help="Whether to also apply LoRA to the text encoder MLP layers. If set, the text encoder should be in float32 precision.",
320 | )
321 | parser.add_argument(
322 | "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
323 | )
324 | parser.add_argument(
325 | "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
326 | )
327 |     parser.add_argument("--num_train_steps", type=int, default=1, help="Number of optimization steps to run per batch; with a single instance image this is the total number of fast fine-tuning steps.")
328 | parser.add_argument(
329 | "--max_train_steps",
330 | type=int,
331 | default=None,
332 |         help="Total number of training steps to perform. If provided, overrides the limit derived from num_train_steps.",
333 | )
334 | parser.add_argument(
335 | "--checkpoints_total_limit",
336 | type=int,
337 | default=None,
338 | help=("Max number of checkpoints to store."),
339 | )
340 | parser.add_argument(
341 | "--resume_from_checkpoint",
342 | type=str,
343 | default=None,
344 | help=(
345 | "Whether training should be resumed from a previous checkpoint. Use a path saved by"
346 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
347 | ),
348 | )
349 | parser.add_argument(
350 | "--gradient_accumulation_steps",
351 | type=int,
352 | default=1,
353 |         help="Number of update steps to accumulate before performing a backward/update pass.",
354 | )
355 | parser.add_argument(
356 | "--gradient_checkpointing",
357 | action="store_true",
358 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
359 | )
360 | parser.add_argument(
361 | "--learning_rate",
362 | type=float,
363 | default=5e-4,
364 | help="Initial learning rate (after the potential warmup period) to use.",
365 | )
366 | parser.add_argument(
367 | "--scale_lr",
368 | action="store_true",
369 | default=False,
370 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
371 | )
372 | parser.add_argument(
373 | "--lr_scheduler",
374 | type=str,
375 | default="constant",
376 | help=(
377 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
378 | ' "constant", "constant_with_warmup"]'
379 | ),
380 | )
381 | parser.add_argument(
382 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
383 | )
384 | parser.add_argument(
385 | "--lr_num_cycles",
386 | type=int,
387 | default=1,
388 | help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
389 | )
390 | parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
391 | parser.add_argument(
392 | "--dataloader_num_workers",
393 | type=int,
394 | default=0,
395 | help=(
396 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
397 | ),
398 | )
399 | parser.add_argument(
400 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
401 | )
402 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
403 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
404 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
405 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
406 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
407 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
408 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
409 | parser.add_argument(
410 | "--hub_model_id",
411 | type=str,
412 | default=None,
413 | help="The name of the repository to keep in sync with the local `output_dir`.",
414 | )
415 | parser.add_argument(
416 | "--logging_dir",
417 | type=str,
418 | default="logs",
419 | help=(
420 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
421 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
422 | ),
423 | )
424 | parser.add_argument(
425 | "--allow_tf32",
426 | action="store_true",
427 | help=(
428 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
429 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
430 | ),
431 | )
432 | parser.add_argument(
433 | "--report_to",
434 | type=str,
435 | default="tensorboard",
436 | help=(
437 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
438 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
439 | ),
440 | )
441 | parser.add_argument(
442 | "--mixed_precision",
443 | type=str,
444 | default=None,
445 | choices=["no", "fp16", "bf16"],
446 | help=(
447 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
448 |             " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
449 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
450 | ),
451 | )
452 | parser.add_argument(
453 | "--prior_generation_precision",
454 | type=str,
455 | default=None,
456 | choices=["no", "fp32", "fp16", "bf16"],
457 | help=(
458 | "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
459 |             " 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available, else fp32."
460 | ),
461 | )
462 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
463 | parser.add_argument(
464 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
465 | )
466 | parser.add_argument(
467 | "--pre_compute_text_embeddings",
468 | action="store_true",
469 | help="Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.",
470 | )
471 | parser.add_argument(
472 | "--tokenizer_max_length",
473 | type=int,
474 | default=None,
475 | required=False,
476 | help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.",
477 | )
478 | parser.add_argument(
479 | "--text_encoder_use_attention_mask",
480 | action="store_true",
481 | required=False,
482 | help="Whether to use attention mask for the text encoder",
483 | )
484 | parser.add_argument(
485 | "--validation_images",
486 | required=False,
487 | default=None,
488 | nargs="+",
489 | help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.",
490 | )
491 | parser.add_argument(
492 | "--class_labels_conditioning",
493 | required=False,
494 | default=None,
495 | help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
496 | )
497 | parser.add_argument(
498 | "--rank",
499 | type=int,
500 | default=4,
501 | help=("The dimension of the LoRA update matrices."),
502 | )
503 |
504 | if input_args is not None:
505 | args = parser.parse_args(input_args)
506 | else:
507 | args = parser.parse_args()
508 |
509 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
510 | if env_local_rank != -1 and env_local_rank != args.local_rank:
511 | args.local_rank = env_local_rank
512 |
513 | if args.with_prior_preservation:
514 | if args.class_data_dir is None:
515 | raise ValueError("You must specify a data directory for class images.")
516 | if args.class_prompt is None:
517 | raise ValueError("You must specify prompt for class images.")
518 | else:
519 | # logger is not available yet
520 | if args.class_data_dir is not None:
521 | warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
522 | if args.class_prompt is not None:
523 | warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
524 |
525 | if args.train_text_encoder and args.pre_compute_text_embeddings:
526 | raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`")
527 |
528 | return args
529 |
530 |
531 | class DreamBoothDataset(Dataset):
532 | """
533 | A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
534 |     It pre-processes the images and tokenizes the prompts.
535 | """
536 |
537 | def __init__(
538 | self,
539 | instance_image_path,
540 | instance_prompt,
541 | tokenizer,
542 | class_data_root=None,
543 | class_prompt=None,
544 | class_num=None,
545 | size=512,
546 | center_crop=False,
547 | encoder_hidden_states=None,
548 | class_prompt_encoder_hidden_states=None,
549 | tokenizer_max_length=None,
550 | ):
551 | self.size = size
552 | self.center_crop = center_crop
553 | self.tokenizer = tokenizer
554 | self.encoder_hidden_states = encoder_hidden_states
555 | self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
556 | self.tokenizer_max_length = tokenizer_max_length
557 |
558 | if not os.path.isfile(instance_image_path):
559 |             raise ValueError(f"{instance_image_path} is not a file!")
560 |
561 | self.instance_images_path = [instance_image_path]
562 |
563 | self.num_instance_images = len(self.instance_images_path)
564 | self.instance_prompt = instance_prompt
565 | self._length = self.num_instance_images
566 |
567 | if class_data_root is not None:
568 | self.class_data_root = Path(class_data_root)
569 | self.class_data_root.mkdir(parents=True, exist_ok=True)
570 | self.class_images_path = list(self.class_data_root.iterdir())
571 | if class_num is not None:
572 | self.num_class_images = min(len(self.class_images_path), class_num)
573 | else:
574 | self.num_class_images = len(self.class_images_path)
575 | self._length = max(self.num_class_images, self.num_instance_images)
576 | self.class_prompt = class_prompt
577 | else:
578 | self.class_data_root = None
579 |
580 | self.image_transforms = transforms.Compose(
581 | [
582 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
583 | transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
584 | transforms.ToTensor(),
585 | transforms.Normalize([0.5], [0.5]),
586 | ]
587 | )
588 |
589 | def __len__(self):
590 | return self._length
591 |
592 | def __getitem__(self, index):
593 | example = {}
594 | instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
595 | instance_image = exif_transpose(instance_image)
596 |
597 | if not instance_image.mode == "RGB":
598 | instance_image = instance_image.convert("RGB")
599 | example["instance_images"] = self.image_transforms(instance_image)
600 |
601 | if self.encoder_hidden_states is not None:
602 | example["instance_prompt_ids"] = self.encoder_hidden_states
603 | else:
604 | text_inputs = tokenize_prompt(
605 | self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length
606 | )
607 | example["instance_prompt_ids"] = text_inputs.input_ids
608 | example["instance_attention_mask"] = text_inputs.attention_mask
609 |
610 | if self.class_data_root:
611 | class_image = Image.open(self.class_images_path[index % self.num_class_images])
612 | class_image = exif_transpose(class_image)
613 |
614 | if not class_image.mode == "RGB":
615 | class_image = class_image.convert("RGB")
616 | example["class_images"] = self.image_transforms(class_image)
617 |
618 | if self.class_prompt_encoder_hidden_states is not None:
619 | example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
620 | else:
621 | class_text_inputs = tokenize_prompt(
622 | self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
623 | )
624 | example["class_prompt_ids"] = class_text_inputs.input_ids
625 | example["class_attention_mask"] = class_text_inputs.attention_mask
626 |
627 | return example
628 |
629 |
630 | def collate_fn(examples, with_prior_preservation=False):
631 | has_attention_mask = "instance_attention_mask" in examples[0]
632 |
633 | input_ids = [example["instance_prompt_ids"] for example in examples]
634 | pixel_values = [example["instance_images"] for example in examples]
635 |
636 | if has_attention_mask:
637 | attention_mask = [example["instance_attention_mask"] for example in examples]
638 |
639 | # Concat class and instance examples for prior preservation.
640 | # We do this to avoid doing two forward passes.
641 | if with_prior_preservation:
642 | input_ids += [example["class_prompt_ids"] for example in examples]
643 | pixel_values += [example["class_images"] for example in examples]
644 | if has_attention_mask:
645 | attention_mask += [example["class_attention_mask"] for example in examples]
646 |
647 | pixel_values = torch.stack(pixel_values)
648 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
649 |
650 | input_ids = torch.cat(input_ids, dim=0)
651 |
652 | batch = {
653 | "input_ids": input_ids,
654 | "pixel_values": pixel_values,
655 | }
656 |
657 | if has_attention_mask:
658 |         batch["attention_mask"] = torch.cat(attention_mask, dim=0)
659 |
660 | return batch
661 |
662 |
663 | class PromptDataset(Dataset):
664 | "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
665 |
666 | def __init__(self, prompt, num_samples):
667 | self.prompt = prompt
668 | self.num_samples = num_samples
669 |
670 | def __len__(self):
671 | return self.num_samples
672 |
673 | def __getitem__(self, index):
674 | example = {}
675 | example["prompt"] = self.prompt
676 | example["index"] = index
677 | return example
678 |
679 |
680 | def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
681 | if tokenizer_max_length is not None:
682 | max_length = tokenizer_max_length
683 | else:
684 | max_length = tokenizer.model_max_length
685 |
686 | text_inputs = tokenizer(
687 | prompt,
688 | truncation=True,
689 | padding="max_length",
690 | max_length=max_length,
691 | return_tensors="pt",
692 | )
693 |
694 | return text_inputs
695 |
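# Note (illustrative): for a single prompt string, tokenize_prompt returns a BatchEncoding whose
# input_ids and attention_mask are tensors of shape (1, max_length), padded/truncated to the
# tokenizer's max length (77 for the standard CLIP tokenizer), because return_tensors="pt" is used.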
696 |
697 | def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):
698 | text_input_ids = input_ids.to(text_encoder.device)
699 |
700 | if text_encoder_use_attention_mask:
701 | attention_mask = attention_mask.to(text_encoder.device)
702 | else:
703 | attention_mask = None
704 |
705 | prompt_embeds = text_encoder(
706 | text_input_ids,
707 | attention_mask=attention_mask,
708 | )
709 | prompt_embeds = prompt_embeds[0]
710 |
711 | return prompt_embeds
712 |
713 |
714 | def main(args):
715 | logging_dir = Path(args.output_dir, args.logging_dir)
716 |
717 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
718 |
719 | accelerator = Accelerator(
720 | gradient_accumulation_steps=args.gradient_accumulation_steps,
721 | mixed_precision=args.mixed_precision,
722 | log_with=args.report_to,
723 | project_config=accelerator_project_config,
724 | )
725 |
726 | if args.report_to == "wandb":
727 | if not is_wandb_available():
728 | raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
729 | import wandb
730 |
731 | # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
732 | # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
733 | # TODO (sayakpaul): Remove this check when gradient accumulation with two models is enabled in accelerate.
734 | if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
735 | raise ValueError(
736 | "Gradient accumulation is not supported when training the text encoder in distributed training. "
737 | "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
738 | )
739 |
740 | # Make one log on every process with the configuration for debugging.
741 | logging.basicConfig(
742 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
743 | datefmt="%m/%d/%Y %H:%M:%S",
744 | level=logging.INFO,
745 | )
746 | logger.info(accelerator.state, main_process_only=False)
747 | if accelerator.is_local_main_process:
748 | transformers.utils.logging.set_verbosity_warning()
749 | diffusers.utils.logging.set_verbosity_info()
750 | else:
751 | transformers.utils.logging.set_verbosity_error()
752 | diffusers.utils.logging.set_verbosity_error()
753 |
754 | # If passed along, set the training seed now.
755 | if args.seed is not None:
756 | set_seed(args.seed)
757 |
758 | # Generate class images if prior preservation is enabled.
759 | if args.with_prior_preservation:
760 | class_images_dir = Path(args.class_data_dir)
761 | if not class_images_dir.exists():
762 | class_images_dir.mkdir(parents=True)
763 | cur_class_images = len(list(class_images_dir.iterdir()))
764 |
765 | if cur_class_images < args.num_class_images:
766 | torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
767 | if args.prior_generation_precision == "fp32":
768 | torch_dtype = torch.float32
769 | elif args.prior_generation_precision == "fp16":
770 | torch_dtype = torch.float16
771 | elif args.prior_generation_precision == "bf16":
772 | torch_dtype = torch.bfloat16
773 | pipeline = DiffusionPipeline.from_pretrained(
774 | args.pretrained_model_name_or_path,
775 | torch_dtype=torch_dtype,
776 | safety_checker=None,
777 | revision=args.revision,
778 | variant=args.variant,
779 | )
780 | pipeline.set_progress_bar_config(disable=True)
781 |
782 | num_new_images = args.num_class_images - cur_class_images
783 | logger.info(f"Number of class images to sample: {num_new_images}.")
784 |
785 | sample_dataset = PromptDataset(args.class_prompt, num_new_images)
786 | sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
787 |
788 | sample_dataloader = accelerator.prepare(sample_dataloader)
789 | pipeline.to(accelerator.device)
790 |
791 | for example in tqdm(
792 | sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
793 | ):
794 | images = pipeline(example["prompt"]).images
795 |
796 | for i, image in enumerate(images):
797 | hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
798 | image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
799 | image.save(image_filename)
800 |
801 | del pipeline
802 | if torch.cuda.is_available():
803 | torch.cuda.empty_cache()
804 |
805 | # Handle the repository creation
806 | if accelerator.is_main_process:
807 | if args.output_dir is not None:
808 | os.makedirs(args.output_dir, exist_ok=True)
809 |
810 | if args.push_to_hub:
811 | repo_id = create_repo(
812 | repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
813 | ).repo_id
814 |
815 | # Load the tokenizer
816 | if args.tokenizer_name:
817 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
818 | elif args.pretrained_model_name_or_path:
819 | tokenizer = AutoTokenizer.from_pretrained(
820 | args.pretrained_model_name_or_path,
821 | subfolder="tokenizer",
822 | revision=args.revision,
823 | use_fast=False,
824 | )
825 |
826 | # import correct text encoder class
827 | text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
828 |
829 | # Load scheduler and models
830 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
831 | text_encoder = text_encoder_cls.from_pretrained(
832 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
833 | )
834 | try:
835 | vae = AutoencoderKL.from_pretrained(
836 | args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
837 | )
838 | except OSError:
839 | # IF does not have a VAE so let's just set it to None
840 | # We don't have to error out here
841 | vae = None
842 |
843 | unet = UNet2DConditionModel.from_pretrained(
844 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
845 | )
846 |
847 | # We only train the additional adapter LoRA layers
848 | if vae is not None:
849 | vae.requires_grad_(False)
850 | text_encoder.requires_grad_(False)
851 | unet.requires_grad_(False)
852 |
853 |
854 | # TODO: RESUME
855 | if args.resume_dir:
856 | lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(args.resume_dir)
857 | # print(lora_state_dict.keys())
858 |
859 | # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
860 | # as these weights are only used for inference, keeping weights in full precision is not required.
861 | weight_dtype = torch.float32
862 | if accelerator.mixed_precision == "fp16":
863 | weight_dtype = torch.float16
864 | elif accelerator.mixed_precision == "bf16":
865 | weight_dtype = torch.bfloat16
866 |
867 | # Move unet, vae and text_encoder to device and cast to weight_dtype
868 | unet.to(accelerator.device, dtype=weight_dtype)
869 | if vae is not None:
870 | vae.to(accelerator.device, dtype=weight_dtype)
871 | text_encoder.to(accelerator.device, dtype=weight_dtype)
872 |
873 | if args.enable_xformers_memory_efficient_attention:
874 | if is_xformers_available():
875 | import xformers
876 |
877 | xformers_version = version.parse(xformers.__version__)
878 | if xformers_version == version.parse("0.0.16"):
879 |                 logger.warning(
880 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
881 | )
882 | unet.enable_xformers_memory_efficient_attention()
883 | else:
884 | raise ValueError("xformers is not available. Make sure it is installed correctly")
885 |
886 | if args.gradient_checkpointing:
887 | unet.enable_gradient_checkpointing()
888 | if args.train_text_encoder:
889 | text_encoder.gradient_checkpointing_enable()
890 |
891 | # now we will add new LoRA weights to the attention layers
892 | # It's important to realize here how many attention weights will be added and of which sizes
893 | # The sizes of the attention layers consist only of two different variables:
894 | # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
895 | # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
896 |
897 | # Let's first see how many attention processors we will have to set.
898 | # For Stable Diffusion, it should be equal to:
899 | # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
900 | # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
901 | # - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
902 | # => 32 layers
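    # Rough worked example (illustrative, not from the original comments): for a projection with
    # in_features = out_features = 320 and the default rank of 4, each LoRA layer adds
    # 4*(320 + 320) = 2,560 trainable parameters, versus the 320*320 = 102,400 frozen weights
    # of the projection it wraps.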
903 |
904 |     # Set up the LoRA layers: 32 attention modules in total, each with 4 projections (to_q/to_k/to_v/to_out), i.e. 32*4 = 128 LoRA layers
905 | unet_lora_parameters = []
906 | for i, (attn_processor_name, attn_processor) in enumerate(unet.attn_processors.items()):
907 | print("unet.attn_processor->%d:%s" % (i, attn_processor_name), attn_processor)
908 | # Parse the attention module.
909 | attn_module = unet
910 | for n in attn_processor_name.split(".")[:-1]:
911 | attn_module = getattr(attn_module, n)
912 | print("attn_module:",attn_module)
913 |
914 | # Set the `lora_layer` attribute of the attention-related matrices.
915 | attn_module.to_q.set_lora_layer(
916 | LoRALinearLayer(
917 | in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
918 | )
919 | )
920 | attn_module.to_k.set_lora_layer(
921 | LoRALinearLayer(
922 | in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
923 | )
924 | )
925 | attn_module.to_v.set_lora_layer(
926 | LoRALinearLayer(
927 | in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
928 | )
929 | )
930 | attn_module.to_out[0].set_lora_layer(
931 | LoRALinearLayer(
932 | in_features=attn_module.to_out[0].in_features,
933 | out_features=attn_module.to_out[0].out_features,
934 | rank=args.rank,
935 | )
936 | )
937 | # Accumulate the LoRA params to optimize.
938 | unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
939 | unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
940 | unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
941 | unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
942 |         # Copy the LoRA weights for to_q/to_k/to_v/to_out from the checkpoint loaded via --resume_dir
943 | for layer_name in ['to_q', 'to_k', 'to_v', 'to_out']:
944 | attn_processor_name = attn_processor_name.replace('.processor', '')
945 | if layer_name == 'to_out':
946 | layer = getattr(attn_module, layer_name)[0].lora_layer
947 | down_key = "unet.%s.%s.0.lora.down.weight" % (attn_processor_name, layer_name)
948 | up_key = "unet.%s.%s.0.lora.up.weight" % (attn_processor_name, layer_name)
949 | else:
950 | layer = getattr(attn_module, layer_name).lora_layer
951 | down_key = "unet.%s.%s.lora.down.weight" % (attn_processor_name, layer_name)
952 | up_key = "unet.%s.%s.lora.up.weight" % (attn_processor_name, layer_name)
953 |
954 | layer.down.weight.data.copy_(lora_state_dict[down_key].to(torch.float32))
955 | layer.up.weight.data.copy_(lora_state_dict[up_key].to(torch.float32))
956 |
957 | if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
958 | attn_module.add_k_proj.set_lora_layer(
959 | LoRALinearLayer(
960 | in_features=attn_module.add_k_proj.in_features,
961 | out_features=attn_module.add_k_proj.out_features,
962 | rank=args.rank,
963 | )
964 | )
965 | attn_module.add_v_proj.set_lora_layer(
966 | LoRALinearLayer(
967 | in_features=attn_module.add_v_proj.in_features,
968 | out_features=attn_module.add_v_proj.out_features,
969 | rank=args.rank,
970 | )
971 | )
972 | unet_lora_parameters.extend(attn_module.add_k_proj.lora_layer.parameters())
973 | unet_lora_parameters.extend(attn_module.add_v_proj.lora_layer.parameters())
974 |
975 | for layer_name in ['add_k_proj', 'add_v_proj']:
976 | attn_processor_name = attn_processor_name.replace('.processor', '')
977 | layer = getattr(attn_module, layer_name).lora_layer
978 | down_key = "unet.%s.%s.lora.down.weight" % (attn_processor_name, layer_name)
979 | up_key = "unet.%s.%s.lora.up.weight" % (attn_processor_name, layer_name)
980 | # copy weights
981 | layer.down.weight.data.copy_(lora_state_dict[down_key].to(torch.float32))
982 | layer.up.weight.data.copy_(lora_state_dict[up_key].to(torch.float32))
983 | print("unet add_proj lora initialized!")
984 |
985 | # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
986 | # So, instead, we monkey-patch the forward calls of its attention-blocks.
987 | if args.train_text_encoder:
988 | # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
989 |         # If patch_mlp is True, fine-tuning also covers the text encoder MLP layers (fc1/fc2); otherwise only the attention projections are patched. The LoRA layers are initialized from the state dict loaded via --resume_dir.
990 | text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, lora_state_dict, dtype=torch.float32, rank=args.rank, patch_mlp=args.patch_mlp)
991 |
992 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
993 | def save_model_hook(models, weights, output_dir):
994 | if accelerator.is_main_process:
995 |             # There are only two options here: either only the unet LoRA (attn processor) layers are saved,
996 |             # or both the unet and the text encoder LoRA layers are.
997 | unet_lora_layers_to_save = None
998 | text_encoder_lora_layers_to_save = None
999 |
1000 | for model in models:
1001 | if isinstance(model, type(accelerator.unwrap_model(unet))):
1002 | unet_lora_layers_to_save = unet_lora_state_dict(model)
1003 | elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
1004 | text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model, patch_mlp=args.patch_mlp)
1005 | else:
1006 | raise ValueError(f"unexpected save model: {model.__class__}")
1007 |
1008 | # make sure to pop weight so that corresponding model is not saved again
1009 | weights.pop()
1010 |
1011 | LoraLoaderMixin.save_lora_weights(
1012 | output_dir,
1013 | unet_lora_layers=unet_lora_layers_to_save,
1014 | text_encoder_lora_layers=text_encoder_lora_layers_to_save,
1015 | )
1016 |
1017 | def load_model_hook(models, input_dir):
1018 | unet_ = None
1019 | text_encoder_ = None
1020 |
1021 | while len(models) > 0:
1022 | model = models.pop()
1023 |
1024 | if isinstance(model, type(accelerator.unwrap_model(unet))):
1025 | unet_ = model
1026 | elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
1027 | text_encoder_ = model
1028 | else:
1029 |                 raise ValueError(f"unexpected model when loading: {model.__class__}")
1030 |
1031 | lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
1032 | LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
1033 | LoraLoaderMixin.load_lora_into_text_encoder(
1034 | lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
1035 | )
1036 |
1037 | accelerator.register_save_state_pre_hook(save_model_hook)
1038 | accelerator.register_load_state_pre_hook(load_model_hook)
1039 |
1040 | # Enable TF32 for faster training on Ampere GPUs,
1041 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
1042 | if args.allow_tf32:
1043 | torch.backends.cuda.matmul.allow_tf32 = True
1044 |
1045 | if args.scale_lr:
1046 | args.learning_rate = (
1047 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
1048 | )
1049 |
1050 |     # Use 8-bit Adam for lower memory usage, or to fine-tune the model on 16GB GPUs
1051 | if args.use_8bit_adam:
1052 | try:
1053 | import bitsandbytes as bnb
1054 | except ImportError:
1055 | raise ImportError(
1056 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
1057 | )
1058 |
1059 | optimizer_class = bnb.optim.AdamW8bit
1060 | else:
1061 | optimizer_class = torch.optim.AdamW
1062 |
1063 | # Optimizer creation
1064 | params_to_optimize = (
1065 | itertools.chain(unet_lora_parameters, text_lora_parameters)
1066 | if args.train_text_encoder
1067 | else unet_lora_parameters
1068 | )
1069 | optimizer = optimizer_class(
1070 | params_to_optimize,
1071 | lr=args.learning_rate,
1072 | betas=(args.adam_beta1, args.adam_beta2),
1073 | weight_decay=args.adam_weight_decay,
1074 | eps=args.adam_epsilon,
1075 | )
1076 |
1077 | if args.pre_compute_text_embeddings:
1078 |
1079 | def compute_text_embeddings(prompt):
1080 | with torch.no_grad():
1081 | text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length)
1082 | prompt_embeds = encode_prompt(
1083 | text_encoder,
1084 | text_inputs.input_ids,
1085 | text_inputs.attention_mask,
1086 | text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
1087 | )
1088 |
1089 | return prompt_embeds
1090 |
1091 | pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
1092 | validation_prompt_negative_prompt_embeds = compute_text_embeddings("")
1093 |
1094 | if args.validation_prompt is not None:
1095 | validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt)
1096 | else:
1097 | validation_prompt_encoder_hidden_states = None
1098 |
1099 | if args.class_prompt is not None:
1100 | pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt)
1101 | else:
1102 | pre_computed_class_prompt_encoder_hidden_states = None
1103 |
1104 | text_encoder = None
1105 | tokenizer = None
1106 |
1107 | gc.collect()
1108 | torch.cuda.empty_cache()
1109 | else:
1110 | pre_computed_encoder_hidden_states = None
1111 | validation_prompt_encoder_hidden_states = None
1112 | validation_prompt_negative_prompt_embeds = None
1113 | pre_computed_class_prompt_encoder_hidden_states = None
1114 |
1115 | # Dataset and DataLoaders creation:
1116 | train_dataset = DreamBoothDataset(
1117 | instance_image_path=args.instance_image_path,
1118 | instance_prompt=args.instance_prompt,
1119 | class_data_root=args.class_data_dir if args.with_prior_preservation else None,
1120 | class_prompt=args.class_prompt,
1121 | class_num=args.num_class_images,
1122 | tokenizer=tokenizer,
1123 | size=args.resolution,
1124 | center_crop=args.center_crop,
1125 | encoder_hidden_states=pre_computed_encoder_hidden_states,
1126 | class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,
1127 | tokenizer_max_length=args.tokenizer_max_length,
1128 | )
1129 |
1130 | train_dataloader = torch.utils.data.DataLoader(
1131 | train_dataset,
1132 | batch_size=args.train_batch_size,
1133 | shuffle=True,
1134 | collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
1135 | num_workers=args.dataloader_num_workers,
1136 | )
1137 |
1138 | # Scheduler and math around the number of training steps.
1139 | overrode_max_train_steps = False
1140 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1141 | if args.max_train_steps is None:
1142 | args.max_train_steps = args.num_train_steps * num_update_steps_per_epoch
1143 | overrode_max_train_steps = True
1144 |
1145 | lr_scheduler = get_scheduler(
1146 | args.lr_scheduler,
1147 | optimizer=optimizer,
1148 | num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
1149 | num_training_steps=args.max_train_steps * accelerator.num_processes,
1150 | num_cycles=args.lr_num_cycles,
1151 | power=args.lr_power,
1152 | )
1153 |
1154 | # Prepare everything with our `accelerator`.
1155 | if args.train_text_encoder:
1156 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1157 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler
1158 | )
1159 | else:
1160 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1161 | unet, optimizer, train_dataloader, lr_scheduler
1162 | )
1163 |
1164 | # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1165 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1166 | if overrode_max_train_steps:
1167 | args.max_train_steps = args.num_train_steps * num_update_steps_per_epoch
1168 | # Afterwards we recalculate our number of training epochs
1169 |
1170 | # We need to initialize the trackers we use, and also store our configuration.
1171 | # The trackers initializes automatically on the main process.
1172 | if accelerator.is_main_process:
1173 | tracker_config = vars(copy.deepcopy(args))
1174 | tracker_config.pop("validation_images")
1175 | accelerator.init_trackers("dreambooth-lora", config=tracker_config)
1176 |
1177 | # Train!
1178 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1179 |
1180 | logger.info("***** Running training *****")
1181 | logger.info(f" Num examples = {len(train_dataset)}")
1182 | logger.info(f" Num batches each epoch = {len(train_dataloader)}")
1183 | logger.info(f" Num Steps = {args.num_train_steps}")
1184 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1185 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1186 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1187 | logger.info(f" Total optimization steps = {args.max_train_steps}")
1188 | global_step = 0
1189 | first_epoch = 0
1190 |
1191 |
1192 | # Potentially load in the weights and states from a previous save
1193 | if args.resume_from_checkpoint:
1194 | if args.resume_from_checkpoint != "latest":
1195 | path = os.path.basename(args.resume_from_checkpoint)
1196 | else:
1197 |             # Get the most recent checkpoint
1198 | dirs = os.listdir(args.output_dir)
1199 | dirs = [d for d in dirs if d.startswith("checkpoint")]
1200 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1201 | path = dirs[-1] if len(dirs) > 0 else None
1202 |
1203 | if path is None:
1204 | accelerator.print(
1205 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1206 | )
1207 | args.resume_from_checkpoint = None
1208 | initial_global_step = 0
1209 | else:
1210 | accelerator.print(f"Resuming from checkpoint {path}")
1211 | accelerator.load_state(os.path.join(args.output_dir, path))
1212 | global_step = int(path.split("-")[1])
1213 |
1214 | initial_global_step = global_step
1215 | first_epoch = global_step // num_update_steps_per_epoch
1216 | else:
1217 | initial_global_step = 0
1218 |
1219 | progress_bar = tqdm(
1220 | range(0, args.num_train_steps),
1221 | initial=initial_global_step,
1222 | desc="Steps",
1223 | # Only show the progress bar once on each machine.
1224 | disable=not accelerator.is_local_main_process,
1225 | )
1226 |
1227 | t0 = time.time()
1228 | unet.train()
1229 | if args.train_text_encoder:
1230 | text_encoder.train()
1231 | for batch in train_dataloader:
1232 | for step in range(args.num_train_steps):
1233 | with accelerator.accumulate(unet):
1234 | pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
1235 |
1236 | if vae is not None:
1237 | # Convert images to latent space
1238 | model_input = vae.encode(pixel_values).latent_dist.sample()
1239 | model_input = model_input * vae.config.scaling_factor
1240 | else:
1241 | model_input = pixel_values
1242 |
1243 | # Sample noise that we'll add to the latents
1244 | noise = torch.randn_like(model_input)
1245 | bsz, channels, height, width = model_input.shape
1246 | # Sample a random timestep for each image
1247 | timesteps = torch.randint(
1248 | 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
1249 | )
1250 | timesteps = timesteps.long()
1251 |
1252 | # Add noise to the model input according to the noise magnitude at each timestep
1253 | # (this is the forward diffusion process)
1254 | noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
1255 |
1256 | # Get the text embedding for conditioning
1257 | if args.pre_compute_text_embeddings:
1258 | encoder_hidden_states = batch["input_ids"]
1259 | else:
1260 | encoder_hidden_states = encode_prompt(
1261 | text_encoder,
1262 | batch["input_ids"],
1263 | batch["attention_mask"],
1264 | text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
1265 | )
1266 |
1267 | if accelerator.unwrap_model(unet).config.in_channels == channels * 2:
1268 | noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
1269 |
1270 | if args.class_labels_conditioning == "timesteps":
1271 | class_labels = timesteps
1272 | else:
1273 | class_labels = None
1274 |
1275 | # Predict the noise residual
1276 | model_pred = unet(
1277 | noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels
1278 | ).sample
1279 |
1280 |                 # If the model predicts variance, throw away that prediction; we only train on the
1281 |                 # simplified training objective. This means that all schedulers using the fine-tuned
1282 |                 # model must be configured to use one of the fixed variance types.
1283 | if model_pred.shape[1] == 6:
1284 | model_pred, _ = torch.chunk(model_pred, 2, dim=1)
1285 |
1286 | # Get the target for loss depending on the prediction type
1287 | if noise_scheduler.config.prediction_type == "epsilon":
1288 | target = noise
1289 | elif noise_scheduler.config.prediction_type == "v_prediction":
1290 | target = noise_scheduler.get_velocity(model_input, noise, timesteps)
1291 | else:
1292 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
1293 |
1294 | if args.with_prior_preservation:
1295 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
1296 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
1297 | target, target_prior = torch.chunk(target, 2, dim=0)
1298 |
1299 | # Compute instance loss
1300 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1301 |
1302 | # Compute prior loss
1303 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
1304 |
1305 | # Add the prior loss to the instance loss.
1306 | loss = loss + args.prior_loss_weight * prior_loss
1307 | else:
1308 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1309 |
1310 | accelerator.backward(loss)
1311 | if accelerator.sync_gradients:
1312 | params_to_clip = (
1313 | itertools.chain(unet_lora_parameters, text_lora_parameters)
1314 | if args.train_text_encoder
1315 | else unet_lora_parameters
1316 | )
1317 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1318 | optimizer.step()
1319 | lr_scheduler.step()
1320 | optimizer.zero_grad()
1321 |
1322 | # Checks if the accelerator has performed an optimization step behind the scenes
1323 | if accelerator.sync_gradients:
1324 | progress_bar.update(1)
1325 | global_step += 1
1326 |
1327 |
1328 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1329 | progress_bar.set_postfix(**logs)
1330 | accelerator.log(logs, step=global_step)
1331 |
1332 | if global_step >= args.max_train_steps:
1333 | break
1334 |
1335 | t1 = time.time()
1336 | # Save the lora layers
1337 | accelerator.wait_for_everyone()
1338 | if accelerator.is_main_process:
1339 | unet = accelerator.unwrap_model(unet)
1340 | unet = unet.to(torch.float32)
1341 | unet_lora_layers = unet_lora_state_dict(unet)
1342 |
1343 | if text_encoder is not None and args.train_text_encoder:
1344 | text_encoder = accelerator.unwrap_model(text_encoder)
1345 | text_encoder = text_encoder.to(torch.float32)
1346 | text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder, patch_mlp=args.patch_mlp)
1347 | else:
1348 | text_encoder_lora_layers = None
1349 |
1350 | LoraLoaderMixin.save_lora_weights(
1351 | save_directory=args.output_dir,
1352 | unet_lora_layers=unet_lora_layers,
1353 | text_encoder_lora_layers=text_encoder_lora_layers,
1354 | )
1355 |         print("unet_lora_layers:", len(unet_lora_layers))
1356 |         print("text_encoder_lora_layers:", len(text_encoder_lora_layers) if text_encoder_lora_layers else 0)
1357 |
1358 | # Final inference
1359 | # Load previous pipeline
1360 | pipeline = DiffusionPipeline.from_pretrained(
1361 | args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype
1362 | )
1363 |
1364 | # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
1365 | scheduler_args = {}
1366 |
1367 | if "variance_type" in pipeline.scheduler.config:
1368 | variance_type = pipeline.scheduler.config.variance_type
1369 |
1370 | if variance_type in ["learned", "learned_range"]:
1371 | variance_type = "fixed_small"
1372 |
1373 | scheduler_args["variance_type"] = variance_type
1374 |
1375 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
1376 |
1377 | pipeline = pipeline.to(accelerator.device)
1378 |
1379 | # load attention processors
1380 | pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")
1381 |
1382 | # run inference
1383 | images = []
1384 | if args.validation_prompt and args.num_validation_images > 0:
1385 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
1386 | images = [
1387 | pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
1388 | for _ in range(args.num_validation_images)
1389 | ]
1390 |         os.makedirs("aigc_samples", exist_ok=True)
1391 |         for i, image in enumerate(images):
1392 |             image.save("aigc_samples/validation_%d.jpg" % i)
1392 |
1393 | for tracker in accelerator.trackers:
1394 | if tracker.name == "tensorboard":
1395 | np_images = np.stack([np.asarray(img) for img in images])
1396 |             tracker.writer.add_images("test", np_images, global_step, dataformats="NHWC")
1397 | if tracker.name == "wandb":
1398 | tracker.log(
1399 | {
1400 | "test": [
1401 | wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
1402 | for i, image in enumerate(images)
1403 | ]
1404 | }
1405 | )
1406 |
1407 | if args.push_to_hub:
1408 | save_model_card(
1409 | repo_id,
1410 | images=images,
1411 | base_model=args.pretrained_model_name_or_path,
1412 | train_text_encoder=args.train_text_encoder,
1413 | prompt=args.instance_prompt,
1414 | repo_folder=args.output_dir,
1415 | pipeline=pipeline,
1416 | )
1417 | upload_folder(
1418 | repo_id=repo_id,
1419 | folder_path=args.output_dir,
1420 | commit_message="End of training",
1421 | ignore_patterns=["step_*", "epoch_*"],
1422 | )
1423 |
1424 | accelerator.end_training()
1425 |     print("\ntime elapsed: %.2f seconds" % (t1 - t0))
1426 | print("==================================complete======================================")
1427 |
1428 |
1429 | if __name__ == "__main__":
1430 | args = parse_args()
1431 | main(args)
1432 |
--------------------------------------------------------------------------------
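
A minimal sketch of the low-rank update behind the LoRALinearLayer wrappers used in fast_finetune.py above. This is simplified for illustration only; TinyLoRALinear is a hypothetical name, and this is not the exact diffusers implementation:

import torch
import torch.nn as nn

class TinyLoRALinear(nn.Module):
    """Computes up(down(x)), which is added on top of the frozen base projection W(x)."""
    def __init__(self, in_features: int, out_features: int, rank: int = 4):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)  # A: project into a rank-r subspace
        self.up = nn.Linear(rank, out_features, bias=False)   # B: project back to the output space
        nn.init.normal_(self.down.weight, std=1.0 / rank)
        nn.init.zeros_(self.up.weight)  # common LoRA init; fast_finetune.py overwrites these with the weights resumed from --resume_dir

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.up(self.down(x))

Only the down/up matrices receive gradients, which is why the script optimizes just unet_lora_parameters (plus the text-encoder LoRA parameters when --train_text_encoder is set) while the UNet, VAE and text encoder stay frozen.
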
/fast_finetune.sh:
--------------------------------------------------------------------------------
1 | export MODEL_NAME="stable-diffusion-models/realisticVisionV40_v40VAE"
2 | INSTANCE_IMAGE_PATH="projects/AIGC/dataset/FFHQ_test/00083.png"
3 |
4 | export RESUME_DIR="projects/AIGC/experiments2/rank_relax"
5 | export OUTPUT_DIR="projects/AIGC/experiments2/fast_finetune"
6 |
7 |
8 | CUDA_VISIBLE_DEVICES=0 \
9 | accelerate launch --mixed_precision="fp16" fast_finetune.py \
10 | --pretrained_model_name_or_path=$MODEL_NAME \
11 | --instance_image_path=$INSTANCE_IMAGE_PATH \
12 | --instance_prompt="A [V] face" \
13 | --resolution=512 \
14 | --train_batch_size=1 \
15 | --num_train_steps=25 \
16 | --learning_rate=1e-4 --lr_scheduler="constant" --lr_warmup_steps=0 \
17 | --seed=42 \
18 | --rank=4 \
19 | --output_dir=$OUTPUT_DIR \
20 | --resume_dir=$RESUME_DIR \
21 | --num_validation_images=5 \
22 | --validation_prompt="A [V] face" \
23 | --train_text_encoder
24 | # --patch_mlp \
25 | # --resume_from_checkpoint=$RESUME_DIR \
26 |
27 |
28 |
--------------------------------------------------------------------------------
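
Once fast_finetune.sh finishes, the LoRA weights written to $OUTPUT_DIR (pytorch_lora_weights.safetensors) can be loaded for inference. A rough sketch assuming a standard diffusers pipeline; the paths mirror the variables above, while the prompt, dtype and output filename are illustrative (the repo's own T2I_inference.py may differ):

import torch
from diffusers import DiffusionPipeline

# Base model used by the script above.
pipe = DiffusionPipeline.from_pretrained(
    "stable-diffusion-models/realisticVisionV40_v40VAE", torch_dtype=torch.float16
).to("cuda")

# Load the LoRA produced by fast_finetune.py from OUTPUT_DIR.
pipe.load_lora_weights(
    "projects/AIGC/experiments2/fast_finetune", weight_name="pytorch_lora_weights.safetensors"
)

image = pipe("A [V] face", num_inference_steps=25).images[0]
image.save("fast_finetune_sample.jpg")
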
/images/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/images/1.jpg
--------------------------------------------------------------------------------
/images/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/images/2.jpg
--------------------------------------------------------------------------------
/images/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/images/3.jpg
--------------------------------------------------------------------------------
/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/modules/__init__.py
--------------------------------------------------------------------------------
/modules/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/modules/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/attention.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/modules/__pycache__/attention.cpython-311.pyc
--------------------------------------------------------------------------------
/modules/attention.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from einops import rearrange, repeat
8 |
9 | from .utils import default
10 | from .utils.xformers_utils import (
11 | XFORMERS_AVAIL,
12 | memory_efficient_attention
13 | )
14 |
15 |
16 | class GEGLU(nn.Module):
17 | def __init__(self, dim_in, dim_out):
18 | super().__init__()
19 | self.proj = nn.Linear(dim_in, dim_out * 2)
20 |
21 | def forward(self, x):
22 | x, gate = self.proj(x).chunk(2, dim=-1)
23 | return x * F.gelu(gate)
24 |
25 |
26 | class FeedForward(nn.Module):
27 | def __init__(self, dim, dim_out=None, mult=4, glu=False):
28 | super().__init__()
29 | inner_dim = int(dim * mult)
30 | dim_out = default(dim_out, dim)
31 | project_in = nn.Sequential(
32 | nn.Linear(dim, inner_dim),
33 | nn.GELU()
34 | ) if not glu else GEGLU(dim, inner_dim)
35 |
36 | self.net = nn.Sequential(
37 | project_in,
38 | nn.Linear(inner_dim, dim_out)
39 | )
40 | # nn.init.constant_(self.net[-1].weight, 0)
41 | # nn.init.constant_(self.net[-1].bias, 0)
42 |
43 | def forward(self, x):
44 | return self.net(x)
45 |
46 |
47 | MEMORY_LAYOUTS = {
48 | 'torch': (
49 | 'b n (h d) -> b h n d',
50 | 'b h n d -> b n (h d)',
51 | lambda x: (1, x, 1, 1),
52 | ),
53 | 'xformers': (
54 | 'b n (h d) -> b n h d',
55 | 'b n h d -> b n (h d)',
56 | lambda x: (1, 1, x, 1),
57 | ),
58 | 'vanilla': (
59 | 'b n (h d) -> b h n d',
60 | 'b h n d -> b n (h d)',
61 | lambda x: (1, x, 1, 1),
62 | )
63 | }
64 | ATTN_FUNCTION = {
65 | 'torch': F.scaled_dot_product_attention,
66 | 'xformers': memory_efficient_attention
67 | }
68 |
69 |
70 | def vanilla_attention(q, k, v, mask, scale=None):
71 | if scale is None:
72 | scale = math.sqrt(q.size(-1))
73 | scores = torch.bmm(q, k.transpose(-1, -2)) / scale
74 | if mask is not None:
75 | mask = rearrange(mask, 'b ... -> b (...)')
76 | max_neg_value = -torch.finfo(scores.dtype).max
77 | mask = repeat(mask, 'b j -> (b h) j', h=q.size(-3))
78 | scores = scores.masked_fill(~mask, max_neg_value)
79 | p_attn = F.softmax(scores, dim=-1)
80 | return torch.bmm(p_attn, v)
81 |
82 |
83 | class Attention(nn.Module):
84 | '''
85 |     Attention class without norm and residual
86 |     (you need to wrap them yourself)
87 | '''
88 | def __init__(
89 | self,
90 | in_ch,
91 | context_ch=None,
92 | heads=8,
93 | head_ch=64,
94 | self_cross=False,
95 | single_kv_head=False,
96 | attn_backend='torch',
97 | # attn_backend='xformers',
98 | cosine_attn=False,
99 | qk_head_ch=-1
100 | ):
101 | super().__init__()
102 | if heads==-1:
103 | assert in_ch%head_ch==0
104 | heads = in_ch//head_ch
105 | if head_ch==-1:
106 | assert in_ch%heads==0
107 | head_ch = in_ch//heads
108 | if qk_head_ch==-1:
109 | qk_head_ch = head_ch
110 | q_ch = heads * qk_head_ch
111 | k_ch = (1 if single_kv_head else heads) * qk_head_ch
112 | v_ch = (1 if single_kv_head else heads) * head_ch
113 | inner_ch = heads * head_ch
114 | assert inner_ch == in_ch
115 | use_context = context_ch is not None
116 | context_ch = default(context_ch, in_ch)
117 |
118 | if attn_backend == 'xformers':
119 | assert XFORMERS_AVAIL
120 | if attn_backend == 'torch':
121 | assert torch.version.__version__ >= '2.0.0'
122 |
123 | self.heads = heads
124 | self.self_cross = self_cross
125 | self.single_kv_head = single_kv_head
126 | self.attn = ATTN_FUNCTION[attn_backend]
127 | self.memory_layout = MEMORY_LAYOUTS[attn_backend]
128 | self.cosine_attn = cosine_attn
129 |
130 | if cosine_attn:
131 | self.scale = nn.Parameter(torch.ones(MEMORY_LAYOUTS[attn_backend][2](heads)))
132 | else:
133 | self.scale = None
134 |
135 | self.q = nn.Linear(in_ch, q_ch, bias=False)
136 | if self_cross and use_context:
137 | self.k = nn.Linear(in_ch, k_ch, bias=False)
138 | self.v = nn.Linear(in_ch, v_ch, bias=False)
139 | self.ck = nn.Linear(context_ch, k_ch, bias=False)
140 | self.cv = nn.Linear(context_ch, v_ch, bias=False)
141 | else:
142 | self.k = nn.Linear(context_ch, k_ch, bias=False)
143 | self.v = nn.Linear(context_ch, v_ch, bias=False)
144 |
145 | self.out = nn.Linear(inner_ch, in_ch)
146 |
147 | def forward(self, x:torch.Tensor, context=None, mask=None):
148 | # Input Projection
149 | heads = self.heads
150 | q = self.q(x)
151 | if self.self_cross:
152 | k = self.k(x)
153 | v = self.v(x)
154 | if context is not None:
155 | ck = self.ck(context)
156 | cv = self.cv(context)
157 | k = torch.concat([k, ck], dim=1)
158 | v = torch.concat([v, cv], dim=1)
159 | else:
160 | ctx = default(context, x)
161 | k = self.k(ctx)
162 | v = self.v(ctx)
163 |
164 | # Rearrange for Attention
165 | q = rearrange(q, self.memory_layout[0], h=heads)
166 | if self.single_kv_head:
167 | k = k.unsqueeze(1)
168 | v = v.unsqueeze(1)
169 |
170 | b, _, seq, _ = k.shape
171 | k = k.expand(b, heads, seq, k.size(3))
172 | v = v.expand(b, heads, seq, v.size(3))
173 | else:
174 | k = rearrange(k, self.memory_layout[0], h=heads)
175 | v = rearrange(v, self.memory_layout[0], h=heads)
176 |
177 | if self.cosine_attn:
178 | q = (F.normalize(q, dim=-1) * math.sqrt(q.size(-1))).to(v.dtype)
179 | k = (F.normalize(k, dim=-1) * self.scale).to(v.dtype)
180 |
181 | # Attention
182 | out = self.attn(
183 | q.contiguous(), k.contiguous(), v.contiguous(), mask
184 | )
185 |
186 | # Output Projection
187 | out = rearrange(out, self.memory_layout[1], h=heads)
188 | return self.out(out)
189 |
190 |
191 | class TransformerBlock(nn.Module):
192 | def __init__(
193 | self,
194 | dim,
195 | n_heads,
196 | d_head,
197 | context_dim=None,
198 | gated_ff=True,
199 | self_cross=False,
200 | single_kv_head=False,
201 | attn_backend='torch',
202 | cosine_attn=False,
203 | qk_head_ch=-1,
204 | disable_self_attn=False,
205 | single_attn=False,
206 | ):
207 | super().__init__()
208 | self.single_attn = single_attn
209 | self.disable_self_attn = disable_self_attn or single_attn
210 |
211 | self.norm1 = nn.LayerNorm(dim)
212 | self.attn1 = Attention(
213 | dim, context_dim if self.disable_self_attn else None, n_heads, d_head,
214 | self_cross, single_kv_head, attn_backend, cosine_attn, qk_head_ch
215 | )
216 | if not single_attn:
217 | self.norm2 = nn.LayerNorm(dim)
218 | self.attn2 = Attention(
219 | dim, context_dim, n_heads, d_head,
220 | self_cross, single_kv_head, attn_backend, cosine_attn, qk_head_ch
221 | )
222 | self.norm3 = nn.LayerNorm(dim)
223 | self.ff = FeedForward(dim, glu=gated_ff)
224 |
225 | def forward(self, x, context=None):
226 | x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
227 | if not self.single_attn and (context is not None or self.attn2.self_cross):
228 | x = self.attn2(self.norm2(x), context=context) + x
229 | x = self.ff(self.norm3(x)) + x
230 | return x
--------------------------------------------------------------------------------
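
A minimal usage sketch for the attention modules above. The batch, token, and channel sizes are illustrative (not taken from this repository), and the default attn_backend='torch' assumes torch>=2.0, as pinned in requirements.txt:

import torch
from modules.attention import Attention, TransformerBlock

x = torch.randn(2, 16, 320)    # (batch, tokens, channels)
ctx = torch.randn(2, 77, 768)  # (batch, context tokens, context channels)

# Self-attention: context_ch is None, so k/v are computed from x itself.
# Note that Attention omits norm and residual, as its docstring says.
self_attn = Attention(in_ch=320, heads=8, head_ch=40)
y = self_attn(x)               # -> (2, 16, 320)

# TransformerBlock adds the LayerNorms, residuals, and feed-forward around
# a self-attention and a cross-attention layer.
block = TransformerBlock(dim=320, n_heads=8, d_head=40, context_dim=768)
z = block(x, context=ctx)      # -> (2, 16, 320)
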
/modules/hypernet.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import *
3 | import numpy as np
4 | import torch
5 | import platform
6 |
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torch.utils.checkpoint as checkpoint
10 |
11 | from torchvision.transforms.functional import resize
12 |
13 | from timm import create_model
14 | # model download: https://github.com/huggingface/pytorch-image-models
15 |
16 | # from einops import rearrange
17 |
18 | from .attention import TransformerBlock
19 | from .light_lora import LoRALinearLayer
20 |
21 |
22 | def _get_sinusoid_encoding_table(n_position, d_hid):
23 | ''' Sinusoid position encoding table '''
24 |
25 | def get_position_angle_vec(position):
26 |         # compute the per-dimension angle for this position
27 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
28 |
29 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
30 |     # even columns ([:, 0::2]) are dim 2i, odd columns ([:, 1::2]) are dim 2i+1
31 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
32 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
33 |
34 | return torch.FloatTensor(sinusoid_table).unsqueeze(0)
35 |
36 |
37 | class WeightDecoder(nn.Module):
38 | def __init__(
39 | self,
40 | weight_dim: int = 150,
41 | weight_num: int = 168,
42 | decoder_blocks: int = 4,
43 | add_constant: bool = False,
44 | ):
45 | super(WeightDecoder, self).__init__()
46 | self.weight_num = weight_num
47 | self.weight_dim = weight_dim
48 |
49 | self.register_buffer(
50 | 'block_pos_emb',
51 | _get_sinusoid_encoding_table(weight_num * 2, weight_dim)
52 | )
53 |
54 | # calc heads for mem-eff or flash_attn
55 | heads = 1
56 | while weight_dim % heads == 0 and weight_dim // heads > 64:
57 | heads *= 2
58 | heads //= 2
59 |
60 | self.pos_emb_proj = nn.Linear(weight_dim, weight_dim, bias=False)
61 | self.decoder_model = nn.ModuleList(
62 | TransformerBlock(weight_dim, heads, weight_dim // heads, context_dim=weight_dim, gated_ff=False)
63 | for _ in range(decoder_blocks)
64 | )
65 | # self.delta_proj = nn.Linear(weight_dim, weight_dim, bias=False)
66 | self.delta_proj = nn.Sequential(
67 | nn.LayerNorm(weight_dim),
68 | nn.Linear(weight_dim, weight_dim, bias=False)
69 | )
70 | self.init_weights(add_constant)
71 |
72 | def init_weights(self, add_constant: bool = False):
73 | def basic_init(module):
74 | if isinstance(module, nn.Linear):
75 | nn.init.xavier_uniform_(module.weight)
76 | if module.bias is not None:
77 | nn.init.constant_(module.bias, 0)
78 |
79 | self.apply(basic_init)
80 |
81 |         # When training without pre-optimization, consider using the following init
82 |         # together with self.down = down @ down_aux + 1 in LiLoRAAttnProcessor
83 | # if add_constant:
84 | # torch.nn.init.constant_(self.delta_proj[1].weight, 0)
85 |
86 | # advice from Nataniel Ruiz, looks like 1e-3 is small enough
87 | # else:
88 | # torch.nn.init.normal_(self.delta_proj[1].weight, std=1e-3)
89 | torch.nn.init.normal_(self.delta_proj[1].weight, std=1e-3)
90 |
91 | def forward(self, weight, features):
92 | pos_emb = self.pos_emb_proj(self.block_pos_emb[:, :weight.size(1)].clone().detach())
93 | h = weight + pos_emb
94 | for decoder in self.decoder_model:
95 | h = decoder(h, context=features)
96 | weight = weight + self.delta_proj(h)
97 | return weight
98 |
99 |
100 | class ImgWeightGenerator(nn.Module):
101 | def __init__(
102 | self,
103 | encoder_model_name: str = "vit_base_patch16_224",
104 | train_encoder: bool = False,
105 | reference_size: int = 224,
106 | weight_dim: int = 240,
107 | weight_num: int = 176,
108 | decoder_blocks: int = 4,
109 | sample_iters: int = 1,
110 | add_constant: bool = False,
111 | ):
112 | super(ImgWeightGenerator, self).__init__()
113 | self.ref_size = reference_size
114 | self.weight_num = weight_num
115 | self.weight_dim = weight_dim
116 | self.sample_iters = sample_iters
117 | self.train_encoder = train_encoder
118 | self.encoder_model_name = encoder_model_name
119 | self.register_buffer(
120 | 'block_pos_emb',
121 | _get_sinusoid_encoding_table(weight_num * 2, weight_dim)
122 | )
123 | # operating system
124 | if platform.system().lower() == "linux":
125 | model_dir = "projects/AIGC/model_zoo"
126 | elif platform.system().lower() == "darwin":
127 | model_dir = "workspace/model_zoo"
128 | else:
129 |             raise ValueError('Unsupported operating system!')
130 |
131 | # vit encoder
132 | print("encoder_model_name: ", encoder_model_name)
133 | if encoder_model_name == "vit_base_patch16_224":
134 | # forward_features: [1, 197, 768]
135 | checkpoint_path = os.path.join(model_dir, "B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz")
136 |         elif encoder_model_name == "vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k":
137 | # forward_features: [1, 257, 1280]
138 | checkpoint_path = os.path.join(model_dir, "models--timm--vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k.safetensors")
139 | elif encoder_model_name == "vit_huge_patch14_clip_336.laion2b_ft_in12k_in1k":
140 | # forward_features: [1, 577, 1280]
141 | checkpoint_path = os.path.join(model_dir, "models--timm--vit_huge_patch14_clip_336.laion2b_ft_in12k_in1k.safetensors")
142 | else:
143 |             raise ValueError('%s is not supported!' % encoder_model_name)
144 | # Create ViT Model
145 | self.encoder_model: nn.Module = create_model(encoder_model_name, checkpoint_path=checkpoint_path)
146 | # self.encoder_model: nn.Module = create_model(encoder_model_name, pretrained=True)
147 | for p in self.encoder_model.parameters():
148 | p.requires_grad_(train_encoder)
149 |
150 | # check encoder model shape and format
151 | test_input = torch.randn(1, 3, reference_size, reference_size)
152 | test_output = self.encoder_model.forward_features(test_input)
153 | if isinstance(test_output, list):
154 | test_output = test_output[-1]
155 | if len(test_output.shape) == 4:
156 | # B, C, H, W -> B, L, C
157 | test_output = test_output.view(1, test_output.size(1), -1).transpose(1, 2)
158 |
159 | self.feature_proj = nn.Linear(test_output.shape[-1], weight_dim, bias=False)
160 | self.pos_emb_proj = nn.Linear(weight_dim, weight_dim, bias=False)
161 | self.decoder_model = WeightDecoder(weight_dim, weight_num, decoder_blocks, add_constant)
162 |
163 | def encode_features(self, ref_img):
164 | # ref_img = resize(ref_img, [self.ref_size, self.ref_size], antialias=True)
165 | if not self.train_encoder:
166 | with torch.no_grad():
167 | img_features = self.encoder_model.forward_features(ref_img)
168 | else:
169 | img_features = self.encoder_model.forward_features(ref_img)
170 | if isinstance(img_features, list):
171 | img_features = img_features[-1]
172 | if len(img_features.shape) == 4:
173 | # B, C, H, W -> B, L, C
174 | img_features = img_features.view(img_features.size(0), img_features.size(1), -1).transpose(1, 2)
175 | # print("img_features", img_features)
176 | return img_features
177 |
178 | def decode_weight(self, img_features, iters=None, weight=None):
179 | img_features = self.feature_proj(img_features)
180 |
181 | if weight is None:
182 | weight = torch.zeros(
183 | img_features.size(0), self.weight_num, self.weight_dim,
184 | device=img_features.device
185 | )
186 |
187 | for i in range(iters or self.sample_iters):
188 | weight = self.decoder_model(weight, img_features)
189 | return weight
190 |
191 | def forward(self, ref_img, iters=None, weight=None, ensure_grad=0):
192 | img_features = self.encode_features(ref_img) + ensure_grad
193 | weight = self.decode_weight(img_features, iters, weight)
194 | return weight
195 |
196 |
197 | class HyperDream(nn.Module):
198 | def __init__(
199 | self,
200 | img_encoder_model_name: str = "vit_base_patch16_224",
201 | ref_img_size: int = 224,
202 | weight_dim: int = 240,
203 | weight_num: int = 176,
204 | decoder_blocks: int = 4,
205 | sample_iters: int = 4,
206 | add_constant: bool = False,
207 | train_encoder: bool = False,
208 | ):
209 | super(HyperDream, self).__init__()
210 | self.img_weight_generator = ImgWeightGenerator(
211 | encoder_model_name=img_encoder_model_name,
212 | reference_size=ref_img_size,
213 | weight_dim=weight_dim,
214 | weight_num=weight_num,
215 | decoder_blocks=decoder_blocks,
216 | sample_iters=sample_iters,
217 | train_encoder=train_encoder,
218 | add_constant=add_constant,
219 | )
220 | self.weight_dim = weight_dim
221 | self.add_constant = add_constant
222 | self.liloras: Dict[str, LoRALinearLayer] = {}
223 | self.liloras_keys: List[str] = []
224 | self.gradient_checkpointing = False
225 |
226 | def enable_gradient_checkpointing(self):
227 | self.gradient_checkpointing = True
228 |
229 | def train_params(self):
230 | return [p for p in self.parameters() if p.requires_grad]
231 |
232 | def set_lilora(self, liloras):
233 | self.liloras = liloras
234 | if isinstance(liloras, dict):
235 | self.liloras_keys = list(liloras.keys()) # for fixed order
236 | else:
237 | self.liloras_keys = range(len(liloras))
238 | length = len(self.liloras_keys)
239 | print(f"Total LiLoRAs: {length}, Hypernet params for each image: {length * self.weight_dim}")
240 |
241 | def gen_weight(self, ref_img, iters, weight, ensure_grad=0):
242 | weights = self.img_weight_generator(ref_img, iters, weight, ensure_grad)
243 | weight_list = weights.split(1, dim=1) # [b, n, dim] -> n*[b, 1, dim]
244 | return weights, [weight.squeeze(1) for weight in weight_list]
245 |
246 | def forward(self, ref_img: torch.Tensor, iters: int = None, weight: torch.Tensor = None):
247 | if self.training and self.gradient_checkpointing:
248 | ensure_grad = torch.zeros(1, device=ref_img.device).requires_grad_(True)
249 | weights, weight_list = checkpoint.checkpoint(
250 | self.gen_weight, ref_img, iters, weight, ensure_grad
251 | )
252 | else:
253 | weights, weight_list = self.gen_weight(ref_img, iters, weight)
254 |
255 | # for key, weight in zip(self.liloras_keys, weight_list):
256 | # self.liloras[key].update_weight(weight, self.add_constant)
257 |
258 | return weights, weight_list
259 |
260 |
261 | class PreOptHyperDream(nn.Module):
262 | def __init__(
263 | self,
264 | rank: int = 1,
265 | down_dim: int = 100,
266 | up_dim: int = 50,
267 | ):
268 | super(PreOptHyperDream, self).__init__()
269 | self.weights = nn.Parameter(torch.tensor(0.0))
270 | self.rank = rank
271 | self.down_dim = down_dim
272 | self.up_dim = up_dim
273 | self.params_per_lora = (down_dim + up_dim) * rank
274 | self.liloras: Dict[str, LoRALinearLayer] = {}
275 | self.liloras_keys: List[str] = []
276 | self.gradient_checkpointing = False
277 | self.device = 'cpu'
278 |
279 | def enable_gradient_checkpointing(self):
280 | self.gradient_checkpointing = True
281 |
282 | def train_params(self):
283 | return [p for p in self.parameters() if p.requires_grad]
284 |
285 | def set_device(self, device):
286 | self.device = device
287 |
288 | def set_lilora(self, liloras, identities=1):
289 | self.liloras = liloras
290 | if isinstance(liloras, dict):
291 | self.liloras_keys = list(liloras.keys()) # for fixed order
292 | elif isinstance(liloras, list):
293 | self.liloras_keys = range(len(liloras))
294 | else:
295 |             raise TypeError("liloras must be a dict or a list!")
296 |
297 |
298 | length = len(self.liloras_keys)
299 | print(f"Total LiLoRAs: {length}, Pre-Optimized params for each image: {length * self.params_per_lora}")
300 | print(f"Pre-Optimized params: {length * self.params_per_lora * identities / 1e6:.1f}M")
301 | del self.weights
302 |
303 | self.length = length
304 |         self.weights = nn.ParameterList(
305 |             nn.Parameter(torch.concat([
306 |                 torch.randn(1, length, self.down_dim * self.rank),
307 |                 torch.zeros(1, length, self.up_dim * self.rank)
308 |             ], dim=-1)) for _ in range(identities)
309 |         )
310 |
311 | def gen_weight(self, identities: torch.Tensor):
312 | weights = torch.concat([self.weights[id] for id in identities], dim=0).to(self.device)
313 | weight_list = weights.split(1, dim=1) # [b, n, dim] -> n*[b, 1, dim] -> n*[b, dim]
314 | return weights, [weight.squeeze(1) for weight in weight_list]
315 |
316 | def forward(self, identities: torch.Tensor):
317 | if self.training and self.gradient_checkpointing:
318 | weights, weight_list = checkpoint.checkpoint(
319 | self.gen_weight, identities
320 | )
321 | else:
322 | weights, weight_list = self.gen_weight(identities)
323 |
324 | # for key, weight in zip(self.liloras_keys, weight_list):
325 | # self.liloras[key].update_weight(weight)
326 |
327 | return weights, weight_list
--------------------------------------------------------------------------------
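
A hedged sketch of how HyperDream above might be driven for inference. It assumes a Linux/macOS machine with the ViT checkpoint available in the model_zoo directory referenced by ImgWeightGenerator; the random tensor stands in for a normalized reference face crop, and weight_num=176 is just the default:

import torch
from modules.hypernet import HyperDream

hypernet = HyperDream(
    img_encoder_model_name="vit_base_patch16_224",
    ref_img_size=224,
    weight_dim=240,   # (down_dim + up_dim) * rank, e.g. (160 + 80) * 1
    weight_num=176,   # one weight vector per patched linear layer
).eval()

ref_img = torch.randn(1, 3, 224, 224)   # placeholder for a normalized face crop
with torch.no_grad():
    weights, weight_list = hypernet(ref_img)

print(weights.shape)      # torch.Size([1, 176, 240])
print(len(weight_list))   # 176 entries, each of shape [1, 240]
# Each entry is then split into a down/up pair and copied into its LoRALinearLayer,
# as the commented-out update_weight call in forward() suggests and as rank_relax.py
# does explicitly.
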
/modules/hypernet_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import numpy as np
4 | import timm
5 | import platform
6 | from torch.nn import TransformerDecoder, TransformerDecoderLayer
7 | import torch.nn as nn
8 | from torchvision import transforms
9 | from torchvision.transforms.functional import resize
10 | from PIL import Image
11 |
12 | # Build a transform that bundles the preprocessing steps
13 | transform = transforms.Compose([
14 |     transforms.Resize((224, 224)),  # resize the image to the size the model expects
15 |     transforms.ToTensor(),  # convert a PIL Image or ndarray to a tensor scaled to [0, 1]
16 |     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
17 |     # normalize with the ImageNet mean and standard deviation
18 | ])
19 |
20 |
21 | class PositionalEncoding(nn.Module):
22 | """
23 |     This PositionalEncoding expects input of shape (time, batch, channel).
24 |     If the input is (batch, time, channel), first transpose/permute it to (time, batch, channel)
25 |     and then pass it to PositionalEncoding.
26 | """
27 | def __init__(self, d_model, dropout=0.1, max_len=5000):
28 | super(PositionalEncoding, self).__init__()
29 | self.dropout = nn.Dropout(p=dropout)
30 | pe = torch.zeros(max_len, d_model)
31 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
32 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
33 | pe[:, 0::2] = torch.sin(position * div_term)
34 | pe[:, 1::2] = torch.cos(position * div_term)
35 | pe = pe.unsqueeze(0).transpose(0, 1)
36 | self.register_buffer('pe', pe)
37 |
38 | def forward(self, x):
39 | x = x + self.pe[:x.size(0), :]
40 | return self.dropout(x)
41 |
42 |
43 | class VisualImageEncoder(torch.nn.Module):
44 | def __init__(self):
45 | global checkpoint_path
46 | super(VisualImageEncoder, self).__init__()
47 |         # load a pretrained ViT image encoder
48 | # self.vit_encoder = timm.create_model('vit_huge_patch16_224', pretrained=True)
49 | # self.vit_encoder = timm.create_model('vit_base_patch16_224', pretrained=True)
50 | # TODO: vit encoder
51 | if platform.system().lower() == "linux":
52 | checkpoint_path = "projects/AIGC/models/liloras/vit_base_patch16_224/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz"
53 | elif platform.system().lower() == "darwin":
54 | checkpoint_path = "model_zoo/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz"
55 | self.vit_encoder: nn.Module = timm.create_model(model_name="vit_base_patch16_224",
56 | checkpoint_path=checkpoint_path)
57 |
58 | def forward(self, x):
59 | """
60 |         The input image size is 224x224.
61 |         """
62 |         features = self.vit_encoder.forward_features(x)  # extract image features as a (B, T, C) tensor
63 | return features
64 |
65 |
66 | class WeightTransformerDecoder(nn.Module):
67 | def __init__(self, d_model, nhead, num_layers):
68 | super(WeightTransformerDecoder, self).__init__()
69 | self.model_type = 'Transformer'
70 | self.src_mask = None
71 | self.d_model = d_model
72 | self.pos_encoder = PositionalEncoding(d_model)
73 | decoder_layer = TransformerDecoderLayer(d_model, nhead)
74 | self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers)
75 | self.weight_proj = nn.Sequential(
76 | nn.LayerNorm(d_model),
77 | nn.Linear(d_model, d_model, bias=False)
78 | )
79 |
80 |         # initialize the linear layer weights with a uniform distribution
81 | for name, param in self.transformer_decoder.named_parameters():
82 | if 'linear' in name and 'weight' in name:
83 | nn.init.uniform_(param.data)
84 |
85 |         # initialize the biases to zero
86 | for name, param in self.transformer_decoder.named_parameters():
87 | if 'linear' in name and 'bias' in name:
88 | nn.init.zeros_(param.data)
89 |
90 |
91 |
92 | def forward(self, weight_embedding, face_embedding):
93 | """
94 |         # Example: create a random weight_embedding and face_embedding
95 | weight_embedding = torch.rand(seq_length, batch_size, embedding_dim)
96 | face_embedding = torch.rand(seq_length, batch_size, embedding_dim)
97 | """
98 | # if self.src_mask is None or self.src_mask.size(0) != len(weight_embedding):
99 | # device = weight_embedding.device
100 | # mask = self._generate_square_subsequent_mask(len(weight_embedding)).to(device)
101 | # self.src_mask = mask
102 | # hidden_embedding = self.transformer_decoder(pos_embedding, face_embedding, self.src_mask)
103 |
104 | pos_embedding = self.pos_encoder(weight_embedding)
105 | hidden_embedding = self.transformer_decoder(pos_embedding, face_embedding)
106 | output = self.weight_proj(hidden_embedding)
107 | return output
108 |
109 |
110 | def _generate_square_subsequent_mask(self, sz):
111 | mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
112 | mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
113 | return mask
114 |
115 |
116 | class HyperNetwork(torch.nn.Module):
117 | def __init__(
118 | self,
119 | ref_img_size: tuple[int] = (224, 224),
120 | rank: int = 1,
121 | down_dim: int = 128,
122 | up_dim: int = 64,
123 | weight_num: int = 128,
124 | iters: int = 4,
125 | train_encoder: bool = False):
126 | super(HyperNetwork, self).__init__()
127 |
128 | self.weight_dim = (down_dim + up_dim) * rank
129 | self.weight_num = weight_num
130 | self.iters = iters
131 | self.train_encoder = train_encoder
132 | self.ref_img_size = ref_img_size
133 | self.visual_image_encoder = VisualImageEncoder()
134 | self.weight_transformer_decoder = WeightTransformerDecoder(d_model=self.weight_dim, nhead=8, num_layers=4)
135 | total_params = sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6
136 | print('Number of hypernetwork parameters: {:.2f}M'.format(total_params))
137 |
138 | # check encoder model shape and format
139 | test_input = torch.randn(1, 3, *ref_img_size)
140 | test_output = self.visual_image_encoder(test_input)
141 | if len(test_output.shape) == 3:
142 | # default shape in (B,T,C)
143 | pass
144 | elif len(test_output.shape) == 4:
145 | # B, C, H, W -> B, T, C
146 | test_output = test_output.view(1, test_output.size(1), -1).transpose(1, 2)
147 | else:
148 | raise ValueError("Output dimension must be 3 or 4")
149 |         # set up the projection from the encoder's output feature dimension
150 | feature_dim = test_output.size(-1)
151 | self.feature_proj = nn.Linear(feature_dim, self.weight_dim, bias=False)
152 |
153 | if not train_encoder:
154 |             # make the visual_image_encoder non-trainable
155 | for param in self.visual_image_encoder.parameters():
156 | param.requires_grad = False
157 |             # put the visual_image_encoder in eval mode
158 | self.visual_image_encoder.eval()
159 |
160 | def train(self, mode=True):
161 | super().train(mode)
162 | if not self.train_encoder:
163 |             self.visual_image_encoder.eval()  # keep visual_image_encoder in eval mode at all times
164 |
165 | def train_params(self):
166 | return [p for p in self.parameters() if p.requires_grad]
167 |
168 | def forward(self, x):
169 | x = resize(x, self.ref_img_size, antialias=True)
170 | # batch first
171 | image_features = self.visual_image_encoder(x)
172 | # print("image_features:", image_features, image_features.shape)
173 | face_embedding = self.feature_proj(image_features)
174 | # weight_embedding zero initialization
175 | weight_embedding = torch.zeros(face_embedding.size(0), self.weight_num, self.weight_dim,
176 | device=image_features.device)
177 | # batch first to time first
178 | face_embedding = face_embedding.permute(1, 0, 2)
179 | weight_embedding = weight_embedding.permute(1, 0, 2)
180 | # Iterative Prediction
181 | for i in range(self.iters):
182 | weight_embedding += self.weight_transformer_decoder(weight_embedding, face_embedding)
183 | # print("weight_embedding_%d"%i, weight_embedding)
184 | # print("weight_embedding Prediction", weight_embedding.shape)
185 | # time first to batch first
186 | weight_embedding = weight_embedding.permute(1, 0, 2)
187 | # weight = self.weight_proj(weight_embedding)
188 | print("weight:",weight_embedding, weight_embedding.shape)
189 | return weight_embedding
190 |
191 |
--------------------------------------------------------------------------------
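
A hedged sketch of how this prototype is presumably exercised if appended to the file: preprocess one image with the transform defined above and run it through HyperNetwork. The image path is a placeholder, and the hard-coded ViT checkpoint path in VisualImageEncoder must exist locally:

if __name__ == "__main__":
    img = Image.open("images/1.jpg").convert("RGB")   # placeholder image path
    x = transform(img).unsqueeze(0)                   # [1, 3, 224, 224]
    hypernet = HyperNetwork(rank=1, down_dim=128, up_dim=64, weight_num=128, iters=4)
    weight_embedding = hypernet(x)                    # [1, 128, 192] = [B, weight_num, (128 + 64) * rank]
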
/modules/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import *
2 |
3 | import numpy as np
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | import torch.linalg as linalg
10 |
11 | from tqdm import tqdm
12 |
13 |
14 | def default(val, d):
15 | return val if val is not None else d
--------------------------------------------------------------------------------
/modules/utils/lora_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Any, Dict, Iterable, Optional, Union
3 | from diffusers.models import UNet2DConditionModel
4 |
5 | def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
6 | r"""
7 | Returns:
8 | A state dict containing just the LoRA parameters.
9 | """
10 | lora_state_dict = {}
11 |
12 | for name, module in unet.named_modules():
13 | if hasattr(module, "set_lora_layer"):
14 | lora_layer = getattr(module, "lora_layer")
15 | if lora_layer is not None:
16 | current_lora_layer_sd = lora_layer.state_dict()
17 | for lora_layer_matrix_name, lora_param in current_lora_layer_sd.items():
18 | # The matrix name can either be "down" or "up".
19 | lora_state_dict[f"{name}.lora.{lora_layer_matrix_name}"] = lora_param
20 |
21 | return lora_state_dict
22 |
23 |
24 | def text_encoder_lora_state_dict(text_encoder, patch_mlp=False):
25 | state_dict = {}
26 |
27 | def text_encoder_attn_modules(text_encoder):
28 | from transformers import CLIPTextModel, CLIPTextModelWithProjection
29 |
30 | attn_modules = []
31 |
32 | if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
33 | for i, layer in enumerate(text_encoder.text_model.encoder.layers):
34 | name = f"text_model.encoder.layers.{i}.self_attn"
35 | mod = layer.self_attn
36 | attn_modules.append((name, mod))
37 |
38 | return attn_modules
39 |
40 | def text_encoder_mlp_modules(text_encoder):
41 | from transformers import CLIPTextModel, CLIPTextModelWithProjection
42 | mlp_modules = []
43 | if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
44 | for i, layer in enumerate(text_encoder.text_model.encoder.layers):
45 | name = f"text_model.encoder.layers.{i}.mlp"
46 | mod = layer.mlp
47 | mlp_modules.append((name, mod))
48 |
49 | return mlp_modules
50 |
51 | # text encoder attn layer
52 | for name, module in text_encoder_attn_modules(text_encoder):
53 | for k, v in module.q_proj.lora_linear_layer.state_dict().items():
54 | state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
55 |
56 | for k, v in module.k_proj.lora_linear_layer.state_dict().items():
57 | state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
58 |
59 | for k, v in module.v_proj.lora_linear_layer.state_dict().items():
60 | state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
61 |
62 | for k, v in module.out_proj.lora_linear_layer.state_dict().items():
63 | state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
64 |
65 | # text encoder mlp layer
66 | if patch_mlp:
67 | for name, module in text_encoder_mlp_modules(text_encoder):
68 | for k, v in module.fc1.lora_linear_layer.state_dict().items():
69 | state_dict[f"{name}.fc1.lora_linear_layer.{k}"] = v
70 |
71 | for k, v in module.fc2.lora_linear_layer.state_dict().items():
72 | state_dict[f"{name}.fc2.lora_linear_layer.{k}"] = v
73 |
74 | return state_dict
75 |
--------------------------------------------------------------------------------
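
A short, hedged usage sketch for the two helpers above. `pipe` is assumed to be a StableDiffusionPipeline whose UNet and text encoder have already been patched with LoRA layers (as rank_relax.py does); on an unpatched model the text-encoder helper has no lora_linear_layer attributes to read:

unet_lora_sd = unet_lora_state_dict(pipe.unet)
text_lora_sd = text_encoder_lora_state_dict(pipe.text_encoder, patch_mlp=False)
print(len(unet_lora_sd), len(text_lora_sd))   # number of LoRA tensors collected per model
# These dicts are what LoraLoaderMixin.save_lora_weights consumes in rank_relax.py.
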
/modules/utils/xformers_utils.py:
--------------------------------------------------------------------------------
1 | memory_efficient_attention = None
2 | try:
3 | import xformers
4 | except ImportError:
5 | pass
6 |
7 | try:
8 | from xformers.ops import memory_efficient_attention
9 | XFORMERS_AVAIL = True
10 | except ImportError:
11 | memory_efficient_attention = None
12 | XFORMERS_AVAIL = False
13 |
--------------------------------------------------------------------------------
/rank_relax.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import itertools
4 | from PIL import Image
5 | from torchvision.transforms import CenterCrop, Resize, ToTensor, Compose, Normalize
6 | from diffusers.models.attention_processor import (
7 | AttnAddedKVProcessor,
8 | AttnAddedKVProcessor2_0,
9 | SlicedAttnAddedKVProcessor,
10 | )
11 | from diffusers import StableDiffusionPipeline
12 | import torch.utils.checkpoint
13 | from modules.relax_lora import LoRALinearLayer, LoraLoaderMixin
14 | from modules.utils.lora_utils import unet_lora_state_dict, text_encoder_lora_state_dict
15 | from modules.hypernet import HyperDream
16 |
17 | pretrain_model_path = "stable-diffusion-models/realisticVisionV40_v40VAE"
18 | hypernet_model_path = "projects/AIGC/experiments2/hypernet/CelebA-HQ-10k-pretrain2/checkpoint-250000"
19 | reference_image_path = "projects/AIGC/dataset/FFHQ_test/00019.png"
20 | output_dir = "projects/AIGC/experiments2/rank_relax"
21 |
22 | # Parameter Settings
23 | train_text_encoder = True
24 | patch_mlp = False
25 | down_dim = 160
26 | up_dim = 80
27 | rank_in = 1 # hypernet output
28 | rank_out = 4 # rank relax output
29 | # vit_model_name = "vit_base_patch16_224"
30 | # vit_model_name = "vit_huge_patch14_clip_224"
31 | vit_model_name = "vit_huge_patch14_clip_336"
32 |
33 | t0 = time.time()
34 | # TODO: 1.load predicted lora weights
35 | pipe = StableDiffusionPipeline.from_pretrained(pretrain_model_path, torch_dtype=torch.float32)
36 | # state_dict, network_alphas = pipe.lora_state_dict(lora_model_path)
37 | pipe.to("cuda")
38 | unet = pipe.unet
39 | text_encoder = pipe.text_encoder
40 |
41 | # TODO: 2.Create rank_relaxed LoRA
42 | unet_lora_parameters = []
43 | unet_lora_linear_layers = []
44 | print("Create a combined LoRA consisted of Frozen LoRA and Trainable LoRA.")
45 | for i, (attn_processor_name, attn_processor) in enumerate(unet.attn_processors.items()):
46 | print("unet.attn_processor->%d:%s" % (i, attn_processor_name), attn_processor)
47 | # attn_processor_name: mid_block.attentions.0.transformer_blocks.0.attn1.processor
48 | # Parse the attention module.
49 | attn_module = unet
50 | for n in attn_processor_name.split(".")[:-1]:
51 | attn_module = getattr(attn_module, n)
52 | print("attn_module:", attn_module)
53 |
54 | # Set the `lora_layer` attribute of the attention-related matrices.
55 | attn_module.to_q.set_lora_layer(
56 | LoRALinearLayer(
57 | in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=rank_out
58 | )
59 | )
60 | attn_module.to_k.set_lora_layer(
61 | LoRALinearLayer(
62 | in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=rank_out
63 | )
64 | )
65 | attn_module.to_v.set_lora_layer(
66 | LoRALinearLayer(
67 | in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=rank_out
68 | )
69 | )
70 | attn_module.to_out[0].set_lora_layer(
71 | LoRALinearLayer(
72 | in_features=attn_module.to_out[0].in_features, out_features=attn_module.to_out[0].out_features,
73 | rank=rank_out,
74 | )
75 | )
76 | # Accumulate the LoRA params to optimize.
77 | unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
78 | unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
79 | unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
80 | unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
81 |
82 |     # Accumulate the LoRALinearLayers to optimize.
83 | unet_lora_linear_layers.append(attn_module.to_q.lora_layer)
84 | unet_lora_linear_layers.append(attn_module.to_k.lora_layer)
85 | unet_lora_linear_layers.append(attn_module.to_v.lora_layer)
86 | unet_lora_linear_layers.append(attn_module.to_out[0].lora_layer)
87 |
88 | if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
89 | attn_module.add_k_proj.set_lora_layer(
90 | LoRALinearLayer(
91 | in_features=attn_module.add_k_proj.in_features,
92 | out_features=attn_module.add_k_proj.out_features,
93 | rank=rank_out,
94 | )
95 | )
96 | attn_module.add_v_proj.set_lora_layer(
97 | LoRALinearLayer(
98 | in_features=attn_module.add_v_proj.in_features,
99 | out_features=attn_module.add_v_proj.out_features,
100 | rank=rank_out,
101 | )
102 | )
103 | unet_lora_parameters.extend(attn_module.add_k_proj.lora_layer.parameters())
104 | unet_lora_parameters.extend(attn_module.add_v_proj.lora_layer.parameters())
105 |
106 | unet_lora_linear_layers.append(attn_module.add_k_proj.lora_layer)
107 | unet_lora_linear_layers.append(attn_module.add_v_proj.lora_layer)
108 |
109 | # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
110 | # So, instead, we monkey-patch the forward calls of its attention-blocks.
111 | if train_text_encoder:
112 | # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
113 |     # if patch_mlp is True, fine-tuning also covers the text encoder MLP; otherwise only the text encoder attention (total LoRA layers: (12+12)*4=96)
114 |     # if state_dict is not None, the frozen linear layers will be initialized.
115 | text_lora_parameters, text_encoder_lora_linear_layers = LoraLoaderMixin._modify_text_encoder(text_encoder,
116 | state_dict=None,
117 | dtype=torch.float32,
118 | rank=rank_out,
119 | patch_mlp=patch_mlp)
120 | # print(text_encoder_lora_linear_layers)
121 |
122 | # total loras
123 | lora_linear_layers = unet_lora_linear_layers + text_encoder_lora_linear_layers if train_text_encoder else unet_lora_linear_layers
124 | print("========================================================================")
125 | t1 = time.time()
126 |
127 | # TODO: 3.Load the hypernet, predict LiLoRA weights, and convert them to a standard LoRA
128 | print("Create Hypernet...")
129 | if vit_model_name == "vit_base_patch16_224":
130 | img_encoder_model_name = "vit_base_patch16_224"
131 | ref_img_size = 224
132 | mean = [0.5000]
133 | std = [0.5000]
134 | elif vit_model_name == "vit_huge_patch14_clip_224":
135 | img_encoder_model_name = "vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k"
136 | ref_img_size = 224
137 | mean = [0.4815, 0.4578, 0.4082]
138 | std = [0.2686, 0.2613, 0.2758]
139 | elif vit_model_name == "vit_huge_patch14_clip_336":
140 | img_encoder_model_name = "vit_huge_patch14_clip_336.laion2b_ft_in12k_in1k"
141 | ref_img_size = 336
142 | mean = [0.4815, 0.4578, 0.4082]
143 | std = [0.2686, 0.2613, 0.2758]
144 | else:
145 | raise ValueError("%s does not supports!" % vit_model_name)
146 |
147 | hypernet_transforms = Compose([
148 | Resize(size=ref_img_size),
149 | CenterCrop(size=(ref_img_size, ref_img_size)),
150 | ToTensor(),
151 | Normalize(mean=mean, std=std),
152 | ])
153 |
154 | hypernetwork = HyperDream(
155 | img_encoder_model_name=img_encoder_model_name,
156 | ref_img_size=ref_img_size,
157 | weight_num=len(lora_linear_layers),
158 | weight_dim=(up_dim + down_dim) * rank_in,
159 | )
160 | hypernetwork.set_lilora(lora_linear_layers)
161 |
162 | if os.path.isdir(hypernet_model_path):
163 | path = os.path.join(hypernet_model_path, "hypernetwork.bin")
164 | weight_dict = torch.load(path)
165 | sd = weight_dict['hypernetwork']
166 | hypernetwork.load_state_dict(sd)
167 | else:
168 |     weight_dict = torch.load(hypernet_model_path)
169 | sd = weight_dict['hypernetwork']
170 | hypernetwork.load_state_dict(sd)
171 |
172 | for i, lilora in enumerate(lora_linear_layers):
173 | seed = weight_dict['aux_seed_%d' % i]
174 | down_aux = weight_dict['down_aux_%d' % i]
175 | up_aux = weight_dict['up_aux_%d' % i]
176 |
177 | print(f"Hypernet weights loaded from: {hypernet_model_path}")
178 |
179 | hypernetwork = hypernetwork.to("cuda")
180 | hypernetwork = hypernetwork.eval()
181 |
182 | ref_img = Image.open(reference_image_path).convert("RGB")
183 | ref_img = hypernet_transforms(ref_img).unsqueeze(0).to("cuda")
184 | # warmup
185 | _ = hypernetwork(ref_img)
186 |
187 | t2_0 = time.time()
188 | weight, weight_list = hypernetwork(ref_img)
189 | print("weight>>>>>>>>>>>:", weight.shape, weight)
190 | t2_1 = time.time()
191 |
192 | t111 = time.time()
193 | # convert down and up weights to linear layer as LoRALinearLayer
194 | for i, (weight, lora_layer) in enumerate(zip(weight_list, lora_linear_layers)):
195 | seed = weight_dict['aux_seed_%d' % i]
196 | down_aux = weight_dict['down_aux_%d' % i]
197 | up_aux = weight_dict['up_aux_%d' % i]
198 | # reshape weight
199 | down_weight, up_weight = weight.split([down_dim * rank_in, up_dim * rank_in], dim=-1)
200 | down_weight = down_weight.reshape(rank_in, -1)
201 | up_weight = up_weight.reshape(-1, rank_in)
202 | # make weight, matrix multiplication
203 | down = down_weight @ down_aux
204 | up = up_aux @ up_weight
205 | lora_layer.down.weight.data.copy_(down.to(torch.float32))
206 | lora_layer.up.weight.data.copy_(up.to(torch.float32))
207 | t222 = time.time()
208 |
209 | print("Convert to standard LoRA...")
210 | for lora_linear_layer in lora_linear_layers:
211 | lora_linear_layer = lora_linear_layer.to("cuda")
212 | lora_linear_layer.convert_to_standard_lora()
213 | print("========================================================================")
214 | t2 = time.time()
215 |
216 | # TODO: 4.Save standard LoRA
217 | print("Save standard LoRA...")
218 | unet_lora_layers_to_save = unet_lora_state_dict(unet)
219 | text_encoder_lora_layers_to_save = None
220 | if train_text_encoder:
221 | text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(text_encoder, patch_mlp=patch_mlp)
222 |
223 | LoraLoaderMixin.save_lora_weights(
224 | save_directory=output_dir,
225 | unet_lora_layers=unet_lora_layers_to_save,
226 | text_encoder_lora_layers=text_encoder_lora_layers_to_save,
227 | )
228 |
229 | t3 = time.time()
230 |
231 | print("Successfully save LoRA to: %s" % (output_dir))
232 | print("load pipeline: %f" % (t1 - t0))
233 | print("load hypernet: %f" % (t2_0 - t1))
234 | print("hypernet inference: %f" % (t2_1 - t2_0))
235 | print("copy weight: %f" % (t222 - t111))
236 |
237 | print("rank relax: %f" % (t2 - t2_1))
238 | print("model save: %f" % (t3 - t2))
239 | print("total time: %f" % (t3 - t0))
240 | print("==================================complete======================================")
241 |
--------------------------------------------------------------------------------
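
For reference, a worked shape example of the split/reshape/matmul performed in the weight-copy loop above, using this script's settings (down_dim=160, up_dim=80, rank_in=1). The 320-dim projection size is illustrative, and random aux matrices stand in for the frozen down_aux/up_aux loaded from the hypernet checkpoint:

import torch

down_dim, up_dim, rank_in = 160, 80, 1
in_features = out_features = 320                         # illustrative layer size

weight = torch.randn(1, (down_dim + up_dim) * rank_in)   # one [1, 240] vector per layer
down_weight, up_weight = weight.split([down_dim * rank_in, up_dim * rank_in], dim=-1)
down_weight = down_weight.reshape(rank_in, -1)           # [1, 160]
up_weight = up_weight.reshape(-1, rank_in)               # [80, 1]

down_aux = torch.randn(down_dim, in_features)            # stands in for the checkpoint's down_aux
up_aux = torch.randn(out_features, up_dim)               # stands in for the checkpoint's up_aux
down = down_weight @ down_aux                            # [rank_in, in_features]  -> [1, 320]
up = up_aux @ up_weight                                  # [out_features, rank_in] -> [320, 1]
# `down` and `up` form a rank-1 LoRA pair; the script copies them into the frozen
# part of the combined LoRA, which convert_to_standard_lora() then turns into a
# standard LoRA at rank_out=4 before saving.
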
/rank_relax_test.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | from diffusers.models.attention_processor import (
4 | AttnAddedKVProcessor,
5 | AttnAddedKVProcessor2_0,
6 | SlicedAttnAddedKVProcessor,
7 | )
8 | from diffusers import StableDiffusionPipeline
9 | import torch.utils.checkpoint
10 | from modules.relax_lora import LoRALinearLayer, LoraLoaderMixin
11 | from modules.utils.lora_utils import unet_lora_state_dict, text_encoder_lora_state_dict
12 |
13 | t0 = time.time()
14 |
15 | pretrain_model_path="stable-diffusion-models/realisticVisionV40_v40VAE"
16 | lora_model_path = "projects/AIGC/lora_model_test"
17 | output_dir = "projects/AIGC/experiments2/rank_relax"
18 |
19 | train_text_encoder = True
20 | patch_mlp = False
21 |
22 | # TODO: 1.load predicted lora weights
23 | pipe = StableDiffusionPipeline.from_pretrained(pretrain_model_path, torch_dtype=torch.float32)
24 | state_dict, network_alphas = pipe.lora_state_dict(lora_model_path)
25 | pipe.to("cuda")
26 |
27 | unet = pipe.unet
28 | text_encoder = pipe.text_encoder
29 | # print(state_dict.keys())
30 | # unet.up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_v.lora.down.weight
31 | # unet.down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.lora.down.weight
32 | # text_encoder.text_model.encoder.layers.11.self_attn.out_proj.lora_linear_layer.down.weight
33 | # text_encoder.text_model.encoder.layers.11.mlp.fc1.lora_linear_layer.down.weight
34 |
35 | # TODO: 2.Create rank_relaxed LoRA and initialize the frozen linear layers
36 | rank = 4 # the relax lora rank is 4.
37 | unet_lora_parameters = []
38 | unet_lora_linear_layers = []
39 | print("Create a combined LoRA consisted of Frozen LoRA and Trainable LoRA.")
40 | for i, (attn_processor_name, attn_processor) in enumerate(unet.attn_processors.items()):
41 | print("unet.attn_processor->%d:%s" % (i, attn_processor_name), attn_processor)
42 | # attn_processor_name: mid_block.attentions.0.transformer_blocks.0.attn1.processor
43 | # Parse the attention module.
44 | attn_module = unet
45 | for n in attn_processor_name.split(".")[:-1]:
46 | attn_module = getattr(attn_module, n)
47 | print("attn_module:",attn_module)
48 |
49 | # Set the `lora_layer` attribute of the attention-related matrices.
50 | attn_module.to_q.set_lora_layer(
51 | LoRALinearLayer(
52 | in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=rank
53 | )
54 | )
55 | attn_module.to_k.set_lora_layer(
56 | LoRALinearLayer(
57 | in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=rank
58 | )
59 | )
60 | attn_module.to_v.set_lora_layer(
61 | LoRALinearLayer(
62 | in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=rank
63 | )
64 | )
65 | attn_module.to_out[0].set_lora_layer(
66 | LoRALinearLayer(
67 | in_features=attn_module.to_out[0].in_features,
68 | out_features=attn_module.to_out[0].out_features,
69 | rank=rank,
70 | )
71 | )
72 | # Accumulate the LoRA params to optimize.
73 | unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
74 | unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
75 | unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
76 | unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
77 |
78 |     # Accumulate the LoRALinearLayers to optimize.
79 | unet_lora_linear_layers.append(attn_module.to_q.lora_layer)
80 | unet_lora_linear_layers.append(attn_module.to_k.lora_layer)
81 | unet_lora_linear_layers.append(attn_module.to_v.lora_layer)
82 | unet_lora_linear_layers.append(attn_module.to_out[0].lora_layer)
83 |
84 | # Set predicted weights to frozen lora
85 | # static_dict: unet.up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.lora.up.weight,
86 | # static_dict: unet.up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_k.lora.down.weight
87 | # attn_processor_name: mid_block.attentions.0.transformer_blocks.0.attn1.processor
88 | for layer_name in ['to_q', 'to_k', 'to_v', 'to_out']:
89 | attn_processor_name = attn_processor_name.replace('.processor', '')
90 | if layer_name == 'to_out':
91 | layer = getattr(attn_module, layer_name)[0].lora_layer
92 | down_key = "unet.%s.%s.0.lora.down.weight" % (attn_processor_name, layer_name)
93 | up_key = "unet.%s.%s.0.lora.up.weight" % (attn_processor_name, layer_name)
94 | else:
95 | layer = getattr(attn_module, layer_name).lora_layer
96 | down_key = "unet.%s.%s.lora.down.weight" % (attn_processor_name, layer_name)
97 | up_key = "unet.%s.%s.lora.up.weight" % (attn_processor_name, layer_name)
98 | # copy weights
99 | layer.down.weight.data.copy_(state_dict[down_key].to(torch.float32))
100 | layer.up.weight.data.copy_(state_dict[up_key].to(torch.float32))
101 | print("unet attention lora initialized!")
102 |
103 | if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
104 | attn_module.add_k_proj.set_lora_layer(
105 | LoRALinearLayer(
106 | in_features=attn_module.add_k_proj.in_features,
107 | out_features=attn_module.add_k_proj.out_features,
108 | rank=rank,
109 | )
110 | )
111 | attn_module.add_v_proj.set_lora_layer(
112 | LoRALinearLayer(
113 | in_features=attn_module.add_v_proj.in_features,
114 | out_features=attn_module.add_v_proj.out_features,
115 | rank=rank,
116 | )
117 | )
118 | unet_lora_parameters.extend(attn_module.add_k_proj.lora_layer.parameters())
119 | unet_lora_parameters.extend(attn_module.add_v_proj.lora_layer.parameters())
120 |
121 | unet_lora_linear_layers.append(attn_module.add_k_proj.lora_layer)
122 | unet_lora_linear_layers.append(attn_module.add_v_proj.lora_layer)
123 |
124 | for layer_name in ['add_k_proj', 'add_v_proj']:
125 | attn_processor_name = attn_processor_name.replace('.processor', '')
126 | layer = getattr(attn_module, layer_name).lora_layer
127 | down_key = "unet.%s.%s.lora.down.weight" % (attn_processor_name, layer_name)
128 | up_key = "unet.%s.%s.lora.up.weight" % (attn_processor_name, layer_name)
129 | # copy weights
130 | layer.down.weight.data.copy_(state_dict[down_key].to(torch.float32))
131 | layer.up.weight.data.copy_(state_dict[up_key].to(torch.float32))
132 | print("unet add_proj lora initialized!")
133 |
134 | # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
135 | # So, instead, we monkey-patch the forward calls of its attention-blocks.
136 | if train_text_encoder:
137 | # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
138 |     # if patch_mlp is True, fine-tuning also covers the text encoder MLP; otherwise only the text encoder attention (total LoRA layers: (12+12)*4=96)
139 |     # if state_dict is not None, the frozen linear layers will be initialized.
140 | text_lora_parameters, text_encoder_lora_linear_layers = LoraLoaderMixin._modify_text_encoder(text_encoder, state_dict, dtype=torch.float32, rank=rank, patch_mlp=patch_mlp)
141 | # print(text_encoder_lora_linear_layers)
142 |
143 |
144 | # TODO: 3.Convert rank_lora to a standard LoRA
145 | print("Convert rank_lora to a standard LoRA...")
146 | lora_linear_layers = unet_lora_linear_layers + text_encoder_lora_linear_layers \
147 | if train_text_encoder else unet_lora_linear_layers
148 |
149 | for lora_linear_layer in lora_linear_layers:
150 | lora_linear_layer = lora_linear_layer.to("cuda")
151 | lora_linear_layer.convert_to_standard_lora()
152 |
153 |
154 | # TODO: 4.Save standard LoRA
155 | print("Save standard LoRA...")
156 | unet_lora_layers_to_save = unet_lora_state_dict(unet)
157 | text_encoder_lora_layers_to_save = None
158 | if train_text_encoder:
159 | text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(text_encoder, patch_mlp=patch_mlp)
160 |
161 |
162 | LoraLoaderMixin.save_lora_weights(
163 | save_directory=output_dir,
164 | unet_lora_layers=unet_lora_layers_to_save,
165 | text_encoder_lora_layers=text_encoder_lora_layers_to_save,
166 | )
167 |
168 | t1 = time.time()
169 |
170 | print("Successfully save LoRA to: %s" % (output_dir))
171 | print("time elapsed: %f"%(t1-t0))
172 | print("==================================complete======================================")
173 |
174 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.0.1
2 | torchvision
3 | einops
4 | accelerate
5 | timm
6 | transformers
7 | # diffusers 0.25.0.dev0 is only published from source, so install it straight from git:
8 | diffusers @ git+https://github.com/huggingface/diffusers.git
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | Simple check list from AllenNLP repo: https://github.com/allenai/allennlp/blob/main/setup.py
17 |
18 | To create the package for PyPI.
19 |
20 | 1. Run `make pre-release` (or `make pre-patch` for a patch release) then run `make fix-copies` to fix the index of the
21 | documentation.
22 |
23 | If releasing on a special branch, copy the updated README.md on the main branch for the commit you will make
24 | for the post-release and run `make fix-copies` on the main branch as well.
25 |
26 | 2. Run Tests for Amazon Sagemaker. The documentation is located in `./tests/sagemaker/README.md`, otherwise @philschmid.
27 |
28 | 3. Unpin specific versions from setup.py that use a git install.
29 |
30 | 4. Checkout the release branch (v-release, for example v4.19-release), and commit these changes with the
31 | message: "Release: " and push.
32 |
33 | 5. Wait for the tests on main to be completed and be green (otherwise revert and fix bugs).
34 |
35 | 6. Add a tag in git to mark the release: "git tag v -m 'Adds tag v for PyPI'"
36 | Push the tag to git: git push --tags origin v-release
37 |
38 | 7. Build both the sources and the wheel. Do not change anything in setup.py between
39 | creating the wheel and the source distribution (obviously).
40 |
41 | For the wheel, run: "python setup.py bdist_wheel" in the top level directory
42 | (This will build a wheel for the Python version you use to build it).
43 |
44 | For the sources, run: "python setup.py sdist"
45 | You should now have a /dist directory with both .whl and .tar.gz source versions.
46 |
47 | Long story cut short, you need to run both before you can upload the distribution to the
48 | test PyPI and the actual PyPI servers:
49 |
50 | python setup.py bdist_wheel && python setup.py sdist
51 |
52 | 8. Check that everything looks correct by uploading the package to the PyPI test server:
53 |
54 | twine upload dist/* -r pypitest
55 | (pypi suggests using twine as other methods upload files via plaintext.)
56 | You may have to specify the repository url, use the following command then:
57 | twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
58 |
59 | Check that you can install it in a virtualenv by running:
60 | pip install -i https://testpypi.python.org/pypi diffusers
61 |
62 | If you are testing from a Colab Notebook, for instance, then do:
63 | pip install diffusers && pip uninstall diffusers
64 | pip install -i https://testpypi.python.org/pypi diffusers
65 |
66 | Check you can run the following commands:
67 | python -c "from diffusers import __version__; print(__version__)"
68 | python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('fusing/unet-ldm-dummy-update'); pipe()"
69 | python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('hf-internal-testing/tiny-stable-diffusion-pipe', safety_checker=None); pipe('ah suh du')"
70 | python -c "from diffusers import *"
71 |
72 | 9. Upload the final version to the actual PyPI:
73 | twine upload dist/* -r pypi
74 |
75 | 10. Prepare the release notes and publish them on GitHub once everything is looking hunky-dory.
76 |
77 | 11. Run `make post-release` (or, for a patch release, `make post-patch`). If you were on a branch for the release,
78 | you need to go back to main before executing this.
79 | """
80 |
81 | import os
82 | import re
83 | import sys
84 | from distutils.core import Command
85 |
86 | from setuptools import find_packages, setup
87 |
88 |
89 | # IMPORTANT:
90 | # 1. all dependencies should be listed here with their version requirements if any
91 | # 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py
92 | _deps = [
93 | "Pillow", # keep the PIL.Image.Resampling deprecation away
94 | "accelerate>=0.11.0",
95 | "compel==0.1.8",
96 | "datasets",
97 | "filelock",
98 | "flax>=0.4.1",
99 | "hf-doc-builder>=0.3.0",
100 | "huggingface-hub>=0.19.4",
101 | "requests-mock==1.10.0",
102 | "importlib_metadata",
103 | "invisible-watermark>=0.2.0",
104 | "isort>=5.5.4",
105 | "jax>=0.4.1",
106 | "jaxlib>=0.4.1",
107 | "Jinja2",
108 | "k-diffusion>=0.0.12",
109 | "torchsde",
110 | "note_seq",
111 | "librosa",
112 | "numpy",
113 | "omegaconf",
114 | "parameterized",
115 | "peft>=0.6.0",
116 | "protobuf>=3.20.3,<4",
117 | "pytest",
118 | "pytest-timeout",
119 | "pytest-xdist",
120 | "python>=3.8.0",
121 | "ruff==0.1.5",
122 | "safetensors>=0.3.1",
123 | "sentencepiece>=0.1.91,!=0.1.92",
124 | "GitPython<3.1.19",
125 | "scipy",
126 | "onnx",
127 | "regex!=2019.12.17",
128 | "requests",
129 | "tensorboard",
130 | "torch>=1.4",
131 | "torchvision",
132 | "transformers>=4.25.1",
133 | "urllib3<=2.0.0",
134 | ]
135 |
136 | # this is a lookup table with items like:
137 | #
138 | # tokenizers: "huggingface-hub==0.8.0"
139 | # packaging: "packaging"
140 | #
141 | # some of the values are versioned whereas others aren't.
142 | deps = {b: a for a, b in (re.findall(r"^(([^!=<>~]+)(?:[!=<>~].*)?$)", x)[0] for x in _deps)}
143 |
144 | # since we save this data in src/diffusers/dependency_versions_table.py it can be easily accessed from
145 | # anywhere. If you need to quickly access the data from this table in a shell, you can do so easily with:
146 | #
147 | # python -c 'import sys; from diffusers.dependency_versions_table import deps; \
148 | # print(" ".join([deps[x] for x in sys.argv[1:]]))' tokenizers datasets
149 | #
150 | # Just pass the desired package names to that script as it's shown with 2 packages above.
151 | #
152 | # If diffusers is not yet installed and the work is done from the cloned repo remember to add `PYTHONPATH=src` to the script above
153 | #
154 | # You can then feed this for example to `pip`:
155 | #
156 | # pip install -U $(python -c 'import sys; from diffusers.dependency_versions_table import deps; \
157 | # print(" ".join([deps[x] for x in sys.argv[1:]]))' tokenizers datasets)
158 | #
159 |
160 |
161 | def deps_list(*pkgs):
162 | return [deps[pkg] for pkg in pkgs]
163 |
164 |
165 | class DepsTableUpdateCommand(Command):
166 | """
167 | A custom distutils command that updates the dependency table.
168 | usage: python setup.py deps_table_update
169 | """
170 |
171 | description = "build runtime dependency table"
172 | user_options = [
173 | # format: (long option, short option, description).
174 | (
175 | "dep-table-update",
176 | None,
177 | "updates src/diffusers/dependency_versions_table.py",
178 | ),
179 | ]
180 |
181 | def initialize_options(self):
182 | pass
183 |
184 | def finalize_options(self):
185 | pass
186 |
187 | def run(self):
188 | entries = "\n".join([f' "{k}": "{v}",' for k, v in deps.items()])
189 | content = [
190 | "# THIS FILE HAS BEEN AUTOGENERATED. To update:",
191 | "# 1. modify the `_deps` dict in setup.py",
192 | "# 2. run `make deps_table_update`",
193 | "deps = {",
194 | entries,
195 | "}",
196 | "",
197 | ]
198 | target = "src/diffusers/dependency_versions_table.py"
199 | print(f"updating {target}")
200 | with open(target, "w", encoding="utf-8", newline="\n") as f:
201 | f.write("\n".join(content))
202 |
203 |
204 | extras = {}
205 | extras["quality"] = deps_list("urllib3", "isort", "ruff", "hf-doc-builder")
206 | extras["docs"] = deps_list("hf-doc-builder")
207 | extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2")
208 | extras["test"] = deps_list(
209 | "compel",
210 | "GitPython",
211 | "datasets",
212 | "Jinja2",
213 | "invisible-watermark",
214 | "k-diffusion",
215 | "librosa",
216 | "omegaconf",
217 | "parameterized",
218 | "pytest",
219 | "pytest-timeout",
220 | "pytest-xdist",
221 | "requests-mock",
222 | "safetensors",
223 | "sentencepiece",
224 | "scipy",
225 | "torchvision",
226 | "transformers",
227 | )
228 | extras["torch"] = deps_list("torch", "accelerate")
229 |
230 | if os.name == "nt": # windows
231 | extras["flax"] = [] # jax is not supported on windows
232 | else:
233 | extras["flax"] = deps_list("jax", "jaxlib", "flax")
234 |
235 | extras["dev"] = (
236 | extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"]
237 | )
238 |
239 | install_requires = [
240 | deps["importlib_metadata"],
241 | deps["filelock"],
242 | deps["huggingface-hub"],
243 | deps["numpy"],
244 | deps["regex"],
245 | deps["requests"],
246 | deps["safetensors"],
247 | deps["Pillow"],
248 | ]
249 |
250 | version_range_max = max(sys.version_info[1], 10) + 1
251 |
252 | setup(
253 | name="diffusers",
254 | version="0.25.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
255 | description="State-of-the-art diffusion in PyTorch and JAX.",
256 | long_description=open("README.md", "r", encoding="utf-8").read(),
257 | long_description_content_type="text/markdown",
258 | keywords="deep learning diffusion jax pytorch stable diffusion audioldm",
259 | license="Apache 2.0 License",
260 | author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/diffusers/graphs/contributors)",
261 | author_email="patrick@huggingface.co",
262 | url="https://github.com/huggingface/diffusers",
263 | package_dir={"": "src"},
264 | packages=find_packages("src"),
265 | package_data={"diffusers": ["py.typed"]},
266 | include_package_data=True,
267 | python_requires=">=3.8.0",
268 | install_requires=list(install_requires),
269 | extras_require=extras,
270 | entry_points={"console_scripts": ["diffusers-cli=diffusers.commands.diffusers_cli:main"]},
271 | classifiers=[
272 | "Development Status :: 5 - Production/Stable",
273 | "Intended Audience :: Developers",
274 | "Intended Audience :: Education",
275 | "Intended Audience :: Science/Research",
276 | "License :: OSI Approved :: Apache Software License",
277 | "Operating System :: OS Independent",
278 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
279 | "Programming Language :: Python :: 3",
280 | ]
281 | + [f"Programming Language :: Python :: 3.{i}" for i in range(8, version_range_max)],
282 | cmdclass={"deps_table_update": DepsTableUpdateCommand},
283 | )
284 |
--------------------------------------------------------------------------------
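Note: the extras declared above ("quality", "docs", "training", "test", "torch", "flax", "dev") can be combined in an editable install, and the `deps_table_update` command documented in DepsTableUpdateCommand regenerates the pinned-versions table after `_deps` is edited. A minimal sketch, assuming the diffusers-style layout this setup.py targets (package sources under src/):

    # editable install with the torch + training extras defined above
    pip install -e ".[torch,training]"

    # after modifying the `_deps` dict in setup.py, regenerate the table it writes
    python setup.py deps_table_update
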
/train_dreambooth_light_lora.sh:
--------------------------------------------------------------------------------
1 | export MODEL_NAME="stable-diffusion-models/realisticVisionV40_v40VAE"
2 | export INSTANCE_DIR="projects/AIGC/dataset/AI_drawing/instance_dir"
3 | export OUTPUT_DIR="projects/AIGC/lora_model_test"
4 |
5 |
6 | CUDA_VISIBLE_DEVICES=0 \
7 | accelerate launch --mixed_precision="fp16" train_dreambooth_light_lora.py \
8 | --pretrained_model_name_or_path=$MODEL_NAME \
9 | --instance_data_dir=$INSTANCE_DIR \
10 | --instance_prompt="A [V] face" \
11 | --resolution=512 \
12 | --train_batch_size=1 \
13 | --num_train_epochs=301 --checkpointing_steps=500 \
14 | --learning_rate=1e-3 --lr_scheduler="constant" --lr_warmup_steps=0 \
15 | --cfg_drop_rate 0.1 \
16 | --rank=1 \
17 | --down_dim=160 \
18 | --up_dim=80 \
19 | --output_dir=$OUTPUT_DIR \
20 | --num_validation_images=5 \
21 | --validation_prompt="A [V] face" \
22 | --validation_epochs=300 \
23 | --train_text_encoder
24 | # --patch_mlp \
25 | # --resume_from_checkpoint "latest" \
26 | # --seed=42
27 |
28 |
29 |
30 |
--------------------------------------------------------------------------------
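Note: the three `export` lines at the top of train_dreambooth_light_lora.sh are paths from the author's environment. A minimal sketch of reusing the script is to replace them with your own values (the values below are hypothetical placeholders; the repo's images/ folder is used only as an example instance directory) and then launch it:

    export MODEL_NAME="path/to/your/sd15_checkpoint"   # local directory or Hub id of a Stable Diffusion 1.x model
    export INSTANCE_DIR="./images"                     # folder of subject photos, e.g. the sample images/ shipped in this repo
    export OUTPUT_DIR="./lora_model_test"              # checkpoints and trained weights are written here

    bash train_dreambooth_light_lora.sh
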
/train_dreambooth_lora.sh:
--------------------------------------------------------------------------------
1 | export MODEL_NAME="stable-diffusion-models/realisticVisionV40_v40VAE"
2 | export INSTANCE_DIR="projects/AIGC/dataset/AI_drawing/instance_dir"
3 | export OUTPUT_DIR="projects/AIGC/lora_model_test"
4 |
5 | CUDA_VISIBLE_DEVICES=0 \
6 | accelerate launch --mixed_precision="fp16" train_dreambooth_lora.py \
7 | --pretrained_model_name_or_path=$MODEL_NAME \
8 | --instance_data_dir=$INSTANCE_DIR \
9 | --instance_prompt="A [V] face" \
10 | --resolution=512 \
11 | --train_batch_size=1 \
12 | --num_train_epochs=201 --checkpointing_steps=500 \
13 | --learning_rate=5e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
14 | --seed=42 \
15 | --rank=4 \
16 | --output_dir=$OUTPUT_DIR \
17 | --num_validation_images=5 \
18 | --validation_prompt="A [V] face" \
19 | --validation_epochs=200 \
20 | --train_text_encoder
21 | # --patch_mlp \
22 | # --resume_from_checkpoint=$RESUME_DIR \
23 |
24 |
25 |
--------------------------------------------------------------------------------
/train_hypernet.sh:
--------------------------------------------------------------------------------
1 | # Model
2 | export MODEL_NAME="stable-diffusion-models/realisticVisionV40_v40VAE"
3 | export PRE_OPT_WEIGHT_DIR="projects/AIGC/experiments2/pretrained/CelebA-HQ-10k-fake"
4 | # alternative --vit_model_name: vit_base_patch16_224
5 | # Image
6 | export INSTANCE_DIR="projects/AIGC/dataset/CelebA-HQ-10k"
7 | export VALIDATION_INPUT_DIR="projects/AIGC/dataset/FFHQ_test"
8 | # Output
9 | export VALIDATION_OUTPUT_DIR="projects/AIGC/experiments2/validation_outputs_10k-4"
10 | export OUTPUT_DIR="experiments/hypernet/CelebA-HQ-10k"
11 |
12 |
13 | CUDA_VISIBLE_DEVICES=0 \
14 | accelerate launch --mixed_precision="fp16" train_hypernet.py \
15 | --pretrained_model_name_or_path $MODEL_NAME \
16 | --pre_opt_weight_path $PRE_OPT_WEIGHT_DIR \
17 | --instance_data_dir $INSTANCE_DIR \
18 | --vit_model_name vit_huge_patch14_clip_336 \
19 | --instance_prompt "A [V] face" \
20 | --output_dir $OUTPUT_DIR \
21 | --allow_tf32 \
22 | --resolution 512 \
23 | --learning_rate 1e-3 \
24 | --lr_scheduler cosine \
25 | --lr_warmup_steps 100 \
26 | --checkpoints_total_limit 10 \
27 | --checkpointing_steps 10000 \
28 | --cfg_drop_rate 0.1 \
29 | --seed=42 \
30 | --rank 1 \
31 | --down_dim 160 \
32 | --up_dim 80 \
33 | --train_batch_size 4 \
34 | --pre_opt_weight_coeff 0.02 \
35 | --num_train_epochs 400 \
36 | --resume_from_checkpoint "latest" \
37 | --validation_prompt="A [V] face" \
38 | --validation_input_dir $VALIDATION_INPUT_DIR \
39 | --validation_output_dir $VALIDATION_OUTPUT_DIR \
40 | --validation_epochs 10 \
41 | --train_text_encoder
42 |
43 |
44 | # when --train_text_encoder is used, do not pass --pre_compute_text_embeddings
45 | # --pre_compute_text_embeddings \
46 | # --train_text_encoder
47 | # --patch_mlp \
48 |
49 |
50 |
--------------------------------------------------------------------------------
/train_hypernet_pro.sh:
--------------------------------------------------------------------------------
1 | # Model
2 | export MODEL_NAME="stable-diffusion-models/realisticVisionV40_v40VAE"
3 | export PRE_OPT_WEIGHT_DIR="projects/AIGC/experiments2/pretrained/CelebA-HQ-30k"
4 | # alternative --vit_model_name: vit_base_patch16_224
5 | # Image
6 | export INSTANCE_DIR="projects/AIGC/dataset/CelebA-HQ-30k"
7 | export VALIDATION_INPUT_DIR="projects/AIGC/dataset/FFHQ_test"
8 | # Output
9 | export VALIDATION_OUTPUT_DIR="projects/AIGC/experiments2/validation_outputs_10k-7"
10 | export OUTPUT_DIR="AIGC/experiments2/hypernet/CelebA-HQ-30k-no-pretrain-2"
11 |
12 | CUDA_VISIBLE_DEVICES=0 \
13 | accelerate launch --mixed_precision="fp16" train_hypernet_pro.py \
14 | --pretrained_model_name_or_path $MODEL_NAME \
15 | --pre_opt_weight_path $PRE_OPT_WEIGHT_DIR \
16 | --instance_data_dir $INSTANCE_DIR \
17 | --vit_model_name vit_huge_patch14_clip_336 \
18 | --instance_prompt "A [V] face" \
19 | --output_dir $OUTPUT_DIR \
20 | --allow_tf32 \
21 | --resolution 512 \
22 | --learning_rate 5e-4 \
23 | --lr_scheduler cosine \
24 | --lr_warmup_steps 100 \
25 | --checkpoints_total_limit 10 \
26 | --checkpointing_steps 10000 \
27 | --cfg_drop_rate 0.1 \
28 | --seed=42 \
29 | --rank 1 \
30 | --down_dim 160 \
31 | --up_dim 80 \
32 | --train_batch_size 16 \
33 | --pre_opt_weight_coeff 0.0 \
34 | --num_train_epochs 200 \
35 | --resume_from_checkpoint "latest" \
36 | --validation_prompt="A [V] face" \
37 | --validation_input_dir $VALIDATION_INPUT_DIR \
38 | --validation_output_dir $VALIDATION_OUTPUT_DIR \
39 | --validation_epochs 10 \
40 | --train_text_encoder
41 |
42 |
43 | # when --train_text_encoder is used, do not pass --pre_compute_text_embeddings
44 | # --pre_compute_text_embeddings \
45 | # --train_text_encoder
46 | # --patch_mlp \
47 |
48 |
49 |
--------------------------------------------------------------------------------
/train_preoptnet.sh:
--------------------------------------------------------------------------------
1 | # run training
2 | export MODEL_NAME="stable-diffusion-models/realisticVisionV40_v40VAE"
3 | export INSTANCE_DIR="projects/AIGC/dataset/CelebA-HQ-10k"
4 | export OUTPUT_DIR="experiments/pretrained/CelebA-HQ-10k"
5 |
6 | CUDA_VISIBLE_DEVICES=0 \
7 | accelerate launch --mixed_precision="fp16" train_preoptnet.py \
8 | --pretrained_model_name_or_path $MODEL_NAME \
9 | --instance_data_dir $INSTANCE_DIR \
10 | --instance_prompt "A [V] face" \
11 | --output_dir $OUTPUT_DIR \
12 | --resolution 512 \
13 | --learning_rate 1e-3 \
14 | --lr_scheduler constant \
15 | --checkpoints_total_limit 2 \
16 | --checkpointing_steps 4 \
17 | --cfg_drop_rate 0.1 \
18 | --seed=42 \
19 | --rank 1 \
20 | --down_dim 160 \
21 | --up_dim 80 \
22 | --train_batch_size 8 \
23 | --train_steps_per_identity 300 \
24 | --resume_from_checkpoint "latest" \
25 | --validation_prompt="A [V] face" \
26 | --train_text_encoder
27 |
28 |
29 | # when --train_text_encoder is used, do not pass --pre_compute_text_embeddings
30 | # --pre_compute_text_embeddings \
31 | # --train_text_encoder
32 | # --patch_mlp \
--------------------------------------------------------------------------------
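Note: none of the scripts above state how they chain together; the stage order below is only an assumption drawn from the file names, the hyperdreambooth.png / pipeline.png assets, and the `--pre_opt_weight_path` flag that the hypernet scripts take. Paths inside each script still have to be adjusted as in their headers:

    bash train_preoptnet.sh           # pre-optimization stage, writes weights under experiments/pretrained/...
    bash export_preoptnet_weight.sh   # export them to the directory later passed as PRE_OPT_WEIGHT_DIR
    bash train_hypernet.sh            # hypernetwork training (train_hypernet_pro.sh is the variant shown above)
    bash export_hypernet_weight.sh    # export the hypernet-predicted weights
    bash fast_finetune.sh             # fast subject-specific finetuning
    # image generation then goes through T2I_inference.py or T2I_inference_merge_lora.py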