├── LICENSE ├── README.md ├── T2I_inference.py ├── T2I_inference_merge_lora.py ├── assets ├── example1.png ├── example2.png ├── example3.png ├── example4.png ├── example5.png ├── hyperdreambooth.png ├── hypernet.png ├── img_3.png ├── img_5.png ├── img_6.png ├── light_lora.png ├── pipeline.png ├── preoptnet.png └── relax_relax_fast_finetune.png ├── export_hypernet_weight.py ├── export_hypernet_weight.sh ├── export_preoptnet_weight.py ├── export_preoptnet_weight.sh ├── fast_finetune.py ├── fast_finetune.sh ├── images ├── 1.jpg ├── 2.jpg └── 3.jpg ├── modules ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-311.pyc │ └── attention.cpython-311.pyc ├── attention.py ├── hypernet.py ├── hypernet_test.py ├── light_lora.py ├── light_lora_pro.py ├── lora.py ├── ortho_lora.py ├── relax_lora.py ├── relax_lora_backup.py └── utils │ ├── __init__.py │ ├── lora_utils.py │ └── xformers_utils.py ├── rank_relax.py ├── rank_relax_test.py ├── requirements.txt ├── setup.py ├── train_dreambooth_light_lora.py ├── train_dreambooth_light_lora.sh ├── train_dreambooth_lora.py ├── train_dreambooth_lora.sh ├── train_hypernet.py ├── train_hypernet.sh ├── train_hypernet_pro.py ├── train_hypernet_pro.sh ├── train_preoptnet.py └── train_preoptnet.sh /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [2023] [KohakuBlueLeaf] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | 203 | --- 204 | 205 | HyperKohaku includes a number of components and libraries with separate copyright notices and license terms. 
Your use of those components is subject to the terms and conditions of the following licenses: 206 | 207 | - train_hyperdreambooth.py: 208 | Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch 209 | Project URL: https://github.com/huggingface/diffusers 210 | License: The Apache Software License, Version 2.0 (ASL-2.0) http://www.apache.org/licenses/LICENSE-2.0.txt -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HyperDreamBooth 2 | 3 | 4 | ## Overview 5 | 6 | In this project, I provide a diffusers-based implementation of [HyperDreamBooth: HyperNetworks for Fast Personalization of Text-to-Image Models](https://arxiv.org/abs/2307.06949). 7 | Some of the code comes from this project: https://github.com/KohakuBlueleaf/HyperKohaku?tab=readme-ov-file 8 | 9 | ## Main Steps 10 | 11 | Architecture of HyperDreamBooth 12 | 13 | 14 | ### 1. Lightweight DreamBooth 15 | 16 | Run the scripts below to test whether Lightweight DreamBooth is working properly. It should generate normal images, just like a standard LoRA. 17 | 18 | `sh train_dreambooth_light_lora.sh` 19 | `python T2I_inference.py` 20 | 21 | Lightweight DreamBooth 22 | 23 | 24 | 25 | ### 2. Weight Pre-Optimization 26 | 27 | Run the scripts for weight pre-optimization, then export the corresponding LoRA according to the identity. Batch training is supported. 28 | 29 | `sh train_preoptnet.sh` 30 | `sh export_preoptnet_weight.sh` 31 | `python T2I_inference.py` 32 | 33 | Weight Pre-Optimization 34 | 35 | 36 | 37 | ### 3. Hypernet Training 38 | 39 | Run the scripts for hypernetwork training, then export the corresponding LoRA based on the input image. 40 | 41 | `sh train_hypernet.sh` 42 | `sh export_hypernet_weight.sh` 43 | `python T2I_inference.py` 44 | 45 | Hypernet Training 46 | 47 | 48 | 49 | 50 | ### 4. Rank Relaxed Fast Finetuning 51 | 52 | Rank-relaxed fast finetuning can be implemented by merely adjusting the LoRALinearLayer. It comprises two LoRA branches: a frozen LoRA (r=1) initialized with the weights predicted by the hypernet, and a trainable LoRA (r>1) initialized to zero. The frozen and trainable LoRAs need to be merged and exported as a standard LoRA, which is then restored for fast finetuning. 53 | 54 | The detailed steps are as follows (a minimal sketch of the merge-and-SVD step is given after this section): 55 | 56 | 1. Merge the linear layers of the frozen and trainable LoRAs into a single linear layer. This can be achieved with simple multiplication and addition of their weight matrices, and we can easily obtain: 57 | Rank(merged_lora) ≤ Rank(frozen_lora) + Rank(trainable_lora). 58 | 59 | 2. Apply SVD (with the diagonal matrix absorbed into the left factor) to the weight matrix of this newly merged linear layer to derive the final LoRA's Up and Down matrices. By retaining all non-zero singular values (N = Rank(merged_lora)), the SVD truncation error becomes zero, which is exactly our objective. 60 | 61 | 3. Restore the weights of the merged LoRA and perform a standard LoRA finetuning. It should take approximately 25 steps. 62 | 63 | `python rank_relax.py` 64 | `sh fast_finetune.sh` 65 | `python T2I_inference.py` 66 | 67 | Rank Relaxed Fast Finetuning 68 | 69 | 70 | 
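For illustration, here is a minimal, self-contained sketch of the merge-and-SVD step above. It is not the repository's `rank_relax.py`; the function name, shapes, and toy dimensions are hypothetical, and it only assumes the usual LoRA convention that the weight update is `up @ down`.

```python
import torch

def merge_loras_via_svd(down_frozen, up_frozen, down_train, up_train):
    """Merge a frozen LoRA (rank r1) and a trainable LoRA (rank r2) into one
    standard LoRA of rank <= r1 + r2 with zero SVD truncation error.

    Shapes follow the usual LoRA convention:
      down_*: (rank, in_features), up_*: (out_features, rank)
    """
    # Step 1: the merged update is the sum of the two low-rank updates.
    delta_w = up_frozen @ down_frozen + up_train @ down_train  # (out, in)
    new_rank = down_frozen.shape[0] + down_train.shape[0]      # r1 + r2

    # Step 2: SVD of the merged update; keeping r1 + r2 singular values keeps
    # every non-zero one, so the reconstruction is exact.
    u, s, vh = torch.linalg.svd(delta_w, full_matrices=False)
    u, s, vh = u[:, :new_rank], s[:new_rank], vh[:new_rank, :]

    # Absorb the diagonal matrix into the left factor to get the final Up/Down.
    new_up = u * s.unsqueeze(0)  # (out_features, new_rank)
    new_down = vh                # (new_rank, in_features)
    return new_down, new_up

# Toy check with made-up dimensions (in=320, out=320, r_frozen=1, r_trainable=4).
# The trainable branch is zero-initialized in the repo; small random values are
# used here only so that the merge is non-trivial.
dtype = torch.float64
down_f, up_f = torch.randn(1, 320, dtype=dtype), torch.randn(320, 1, dtype=dtype)
down_t, up_t = 0.01 * torch.randn(4, 320, dtype=dtype), torch.randn(320, 4, dtype=dtype)
new_down, new_up = merge_loras_via_svd(down_f, up_f, down_t, up_t)
assert torch.allclose(new_up @ new_down, up_f @ down_f + up_t @ down_t, atol=1e-8)
```

Because every non-zero singular value is kept, the exported Up/Down pair reproduces the merged update exactly, which is what the rank-relaxation step relies on.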
71 | ### 5. Experiments 72 | 73 | 1. Set appropriate hyperparameters to directly train light_lora, ensuring that the results are normal. Then, conduct batch training of light_lora in preoptnet for all IDs. 74 | 75 | 2. When using only the hypernetwork prediction, there are some flaws in the local details. 76 | 3. Implementing rank-relaxed fast finetuning on top of the hypernetwork prediction reduces the training steps to 25 and significantly improves the results. 77 | 78 | 4. If combined with a style LoRA, the effect is as follows. 79 | 80 | 5. If the Down Aux and Up Aux matrices are set to be learnable, there is no need for the Weight Pre-Optimization and Fast Finetuning steps, and the results are even better. 81 | `sh train_hypernet_pro.sh` 82 | `sh export_hypernet_weight.sh` 83 | `python T2I_inference.py` 84 | 85 | 86 | 87 | ------ 88 | 89 | ## Main Contributor 90 | 91 | chenxinhua: chenxinhua1002@163.com 92 | -------------------------------------------------------------------------------- /T2I_inference.py: -------------------------------------------------------------------------------- 1 | from diffusers import StableDiffusionPipeline, DDIMScheduler 2 | import torch 3 | import time 4 | pretrain_model_path="stable-diffusion-models/realisticVisionV40_v40VAE" 5 | 6 | noise_scheduler = DDIMScheduler( 7 | num_train_timesteps=1000, 8 | beta_start=0.00085, 9 | beta_end=0.012, 10 | beta_schedule="scaled_linear", 11 | clip_sample=False, 12 | set_alpha_to_one=False, 13 | steps_offset=1 14 | ) 15 | 16 | pipe = StableDiffusionPipeline.from_pretrained(pretrain_model_path, 17 | torch_dtype=torch.float16, 18 | scheduler=noise_scheduler, 19 | requires_safety_checker=False) 20 | 21 | 22 | 23 | 24 | lora_model_path = "/projects/AIGC/experiments2/rank_relax/" 25 | 26 | pipe.load_lora_weights(lora_model_path, weight_name="pytorch_lora_weights.safetensors") 27 | 28 | # pipe.load_lora_weights(lora_model_path) 29 | prompt = "A [V] face" 30 | 31 | negative_prompt = "nsfw,easynegative" 32 | # negative_prompt = "nsfw, easynegative, paintings, sketches, (worst quality:2), low res, normal quality, ((monochrome)), skin spots, acnes, skin blemishes, extra fingers, fewer fingers, strange fingers, bad hand, mole, ((extra legs)), ((extra hands)), bad-hands-5" 33 | # prompt = "1girl, stulmna, exquisitely detailed skin, looking at viewer, ultra high res, delicate" 34 | # prompt = "A [v] face" 35 | # prompt = "A pixarstyle of a [V] face" 36 | # prompt = "A [V] face with bark skin" 37 | # prompt = "A [V] face" 38 | # prompt = "A professional portrait of a [V] face" 39 | # prompt = "1girl, lineart, monochrome" 40 | 41 | # prompt = "1girl,(exquisitely detailed skin:1.3), looking at viewer, ultra high res, delicate" 42 | # prompt = "1boy, a professional detailed high-quality image, looking at viewer" 43 | # prompt = "1girl, stulmno, solo, best quality, looking at viewer" 44 | # prompt = "1girl, solo, best quality, looking at viewer" 45 | 46 | # prompt = "(upper body: 1.5),(white background:1.4), (illustration:1.1),(best quality),(masterpiece:1.1),(extremely detailed CG unity 8k wallpaper:1.1), (colorful:0.9),(panorama shot:1.4),(full body:1.05),(solo:1.2), (ink splashing),(color splashing),((watercolor)), clear sharp focus,{ 1boy standing },((chinese style )),(flowers,woods),outdoors,rocks, looking at viewer, happy expression ,soft smile, detailed face, clothing decorative pattern details, black hair,black eyes, " 47 | 48 | pipe.to("cuda") 49 | 50 | t0 = time.time() 51 | for i in range(10): 52 | # image = pipe(prompt, negative_prompt=negative_prompt, height=512, width=512, num_inference_steps=30, guidance_scale=7.5,cross_attention_kwargs={"scale":1}).images[0] 53 | image = pipe(prompt, height=512, width=512, 
num_inference_steps=30).images[0] 54 | image.save("aigc_samples/test_%d.jpg" % i) 55 | t1 = time.time() 56 | print("time elapsed: %f"%((t1-t0)/10)) 57 | print("LoRA: %s"%lora_model_path) 58 | -------------------------------------------------------------------------------- /T2I_inference_merge_lora.py: -------------------------------------------------------------------------------- 1 | from diffusers import StableDiffusionPipeline, DDIMScheduler 2 | import torch 3 | import random 4 | pretrain_model_path="stable-diffusion-models/realisticVisionV40_v40VAE" 5 | 6 | noise_scheduler = DDIMScheduler( 7 | num_train_timesteps=1000, 8 | beta_start=0.00085, 9 | beta_end=0.012, 10 | beta_schedule="scaled_linear", 11 | clip_sample=False, 12 | set_alpha_to_one=False, 13 | steps_offset=1 14 | ) 15 | 16 | pipe = StableDiffusionPipeline.from_pretrained(pretrain_model_path, 17 | torch_dtype=torch.float16, 18 | scheduler=noise_scheduler, 19 | requires_safety_checker=False) 20 | 21 | 22 | 23 | # personal lora 24 | dir1 = "projects/AIGC/lora_model_test" 25 | lora_model_path1 = "pytorch_lora_weights.safetensors" 26 | # prompt = "A [V] face" 27 | dir2="projects/AIGC/experiments/style_lora" 28 | lora_model_path2 = "pixarStyleModel_lora128.safetensors" 29 | # lora_model_path2 = "Watercolor_Painting_by_vizsumit.safetensors" 30 | # lora_model_path2 = "Professional_Portrait_1.5-000008.safetensors" 31 | prompt = "A pixarstyle of a [V] face" 32 | # prompt = "A watercolor paint of a [V] face" 33 | # prompt = "A professional portrait of a [V] face" 34 | 35 | negative_prompt = "nsfw,easynegative" 36 | 37 | pipe.to("cuda") 38 | 39 | 40 | pipe.load_lora_weights(dir1, weight_name=lora_model_path1, adapter_name="person") 41 | pipe.load_lora_weights(dir2, weight_name=lora_model_path2, adapter_name="style") 42 | # pipe.set_adapters(["person", "style"], adapter_weights=[0.6, 0.4]) #pixar 43 | 44 | pipe.set_adapters(["person", "style"], adapter_weights=[0.4, 0.4]) #watercolor 45 | # Fuses the LoRAs into the Unet 46 | pipe.fuse_lora() 47 | 48 | for i in range(10): 49 | seed = random.randint(0, 100) 50 | generator = torch.Generator(device="cuda").manual_seed(seed) 51 | # image = pipe(prompt, negative_prompt=negative_prompt, height=512, width=512, num_inference_steps=30, guidance_scale=7.5,cross_attention_kwargs={"scale":1}).images[0] 52 | # image = pipe(prompt, height=512, width=512, num_inference_steps=30, guidance_scale=7.5, cross_attention_kwargs={"scale": 1.0}, generator=torch.manual_seed(0)).images[0] 53 | # image = pipe(prompt, height=512, width=512, num_inference_steps=30, guidance_scale=7.5, cross_attention_kwargs={"scale": 1.0}).images[0] 54 | image = pipe(prompt, height=512, width=512, num_inference_steps=30, generator=generator).images[0] 55 | 56 | image.save("aigc_samples/test_after_export_%d.jpg" % i) 57 | 58 | # Gets the Unet back to the original state 59 | pipe.unfuse_lora() -------------------------------------------------------------------------------- /assets/example1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/assets/example1.png -------------------------------------------------------------------------------- /assets/example2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/assets/example2.png 
-------------------------------------------------------------------------------- /assets/example3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/assets/example3.png -------------------------------------------------------------------------------- /assets/example4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/assets/example4.png -------------------------------------------------------------------------------- /assets/example5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/assets/example5.png -------------------------------------------------------------------------------- /assets/hyperdreambooth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/assets/hyperdreambooth.png -------------------------------------------------------------------------------- /assets/hypernet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/assets/hypernet.png -------------------------------------------------------------------------------- /assets/img_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/assets/img_3.png -------------------------------------------------------------------------------- /assets/img_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/assets/img_5.png -------------------------------------------------------------------------------- /assets/img_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/assets/img_6.png -------------------------------------------------------------------------------- /assets/light_lora.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/assets/light_lora.png -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/assets/pipeline.png -------------------------------------------------------------------------------- /assets/preoptnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/assets/preoptnet.png -------------------------------------------------------------------------------- /assets/relax_relax_fast_finetune.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/assets/relax_relax_fast_finetune.png -------------------------------------------------------------------------------- /export_hypernet_weight.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | # Modified by KohakuBlueLeaf 5 | # Modified from diffusers/example/dreambooth/train_dreambooth_lora.py 6 | # see original licensed below 7 | # ======================================================================= 8 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved. 9 | # 10 | # Licensed under the Apache License, Version 2.0 (the "License"); 11 | # you may not use this file except in compliance with the License. 12 | # You may obtain a copy of the License at 13 | # 14 | # http://www.apache.org/licenses/LICENSE-2.0 15 | # 16 | # Unless required by applicable law or agreed to in writing, software 17 | # distributed under the License is distributed on an "AS IS" BASIS, 18 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 19 | # See the License for the specific language governing permissions and 20 | # ======================================================================= 21 | 22 | import argparse 23 | import os 24 | from packaging import version 25 | from PIL import Image 26 | import torch 27 | from torchvision.transforms import CenterCrop, Resize, ToTensor, Compose, Normalize 28 | from diffusers import StableDiffusionPipeline 29 | from diffusers.models.attention_processor import ( 30 | AttnAddedKVProcessor, 31 | AttnAddedKVProcessor2_0, 32 | SlicedAttnAddedKVProcessor, 33 | ) 34 | 35 | 36 | from modules.light_lora import LoRALinearLayer, LoraLoaderMixin 37 | from modules.utils.lora_utils import unet_lora_state_dict, text_encoder_lora_state_dict 38 | from modules.hypernet import HyperDream 39 | 40 | 41 | def parse_args(input_args=None): 42 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 43 | parser.add_argument( 44 | "--pretrained_model_name_or_path", 45 | type=str, 46 | default=None, 47 | required=True, 48 | help="Path to pretrained model or model identifier from huggingface.co/models.", 49 | ) 50 | parser.add_argument( 51 | "--revision", 52 | type=str, 53 | default=None, 54 | required=False, 55 | help="Revision of pretrained model identifier from huggingface.co/models.", 56 | ) 57 | parser.add_argument( 58 | "--hypernet_model_path", 59 | type=str, 60 | default=None, 61 | required=True, 62 | help="Path to pretrained hyperkohaku model", 63 | ) 64 | parser.add_argument( 65 | "--output_dir", 66 | type=str, 67 | default=None, 68 | required=True, 69 | ) 70 | parser.add_argument( 71 | "--reference_image_path", 72 | type=str, 73 | default=None, 74 | required=True, 75 | help="Path to reference image", 76 | ) 77 | parser.add_argument( 78 | "--vit_model_name", 79 | type=str, 80 | default="vit_base_patch16_224", 81 | help="The ViT encoder used in hypernet encoder.", 82 | ) 83 | 84 | parser.add_argument( 85 | "--rank", 86 | type=int, 87 | default=1, 88 | help=("The dimension of the LoRA update matrices."), 89 | ) 90 | parser.add_argument( 91 | "--down_dim", 92 | type=int, 93 | default=160, 94 | help=("The dimension of the LoRA update matrices."), 95 | ) 96 | parser.add_argument( 97 | "--up_dim", 98 | type=int, 99 | default=80, 100 | help=("The dimension of the LoRA 
update matrices."), 101 | ) 102 | parser.add_argument( 103 | "--train_text_encoder", 104 | action="store_true", 105 | help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", 106 | ) 107 | parser.add_argument( 108 | "--patch_mlp", 109 | action="store_true", 110 | help="Whether to train the text encoder with mlp. If set, the text encoder should be float32 precision.", 111 | ) 112 | parser.add_argument( 113 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 114 | ) 115 | 116 | if input_args is not None: 117 | args = parser.parse_args(input_args) 118 | else: 119 | args = parser.parse_args() 120 | return args 121 | 122 | def main(args): 123 | # Load Model 124 | pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, torch_dtype=torch.float32) 125 | # pipe.to("cuda") 126 | 127 | unet = pipe.unet 128 | text_encoder = pipe.text_encoder 129 | 130 | unet_lora_parameters = [] 131 | unet_lora_linear_layers = [] 132 | for i, (attn_processor_name, attn_processor) in enumerate(unet.attn_processors.items()): 133 | print("unet.attn_processor->%d:%s" % (i, attn_processor_name), attn_processor) 134 | # Parse the attention module. 135 | attn_module = unet 136 | for n in attn_processor_name.split(".")[:-1]: 137 | attn_module = getattr(attn_module, n) 138 | print("attn_module:", attn_module) 139 | 140 | # Set the `lora_layer` attribute of the attention-related matrices. 141 | attn_module.to_q.set_lora_layer( 142 | LoRALinearLayer( 143 | in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, 144 | rank=args.rank, down_dim=args.down_dim, up_dim=args.up_dim, is_train=False, 145 | ) 146 | ) 147 | attn_module.to_k.set_lora_layer( 148 | LoRALinearLayer( 149 | in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, 150 | rank=args.rank, down_dim=args.down_dim, up_dim=args.up_dim, is_train=False, 151 | ) 152 | ) 153 | attn_module.to_v.set_lora_layer( 154 | LoRALinearLayer( 155 | in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, 156 | rank=args.rank, down_dim=args.down_dim, up_dim=args.up_dim, is_train=False, 157 | ) 158 | ) 159 | attn_module.to_out[0].set_lora_layer( 160 | LoRALinearLayer( 161 | in_features=attn_module.to_out[0].in_features, out_features=attn_module.to_out[0].out_features, 162 | rank=args.rank, down_dim=args.down_dim, up_dim=args.up_dim, is_train=False, 163 | ) 164 | ) 165 | # Accumulate the LoRA params to optimize. 166 | unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters()) 167 | unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters()) 168 | unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters()) 169 | unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters()) 170 | 171 | # Accumulate the LoRALinerLayer to optimize. 
172 | unet_lora_linear_layers.append(attn_module.to_q.lora_layer) 173 | unet_lora_linear_layers.append(attn_module.to_k.lora_layer) 174 | unet_lora_linear_layers.append(attn_module.to_v.lora_layer) 175 | unet_lora_linear_layers.append(attn_module.to_out[0].lora_layer) 176 | 177 | if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)): 178 | attn_module.add_k_proj.set_lora_layer( 179 | LoRALinearLayer( 180 | in_features=attn_module.add_k_proj.in_features, 181 | out_features=attn_module.add_k_proj.out_features, 182 | rank=args.rank, down_dim=args.down_dim, up_dim=args.up_dim, is_train=False, 183 | ) 184 | ) 185 | attn_module.add_v_proj.set_lora_layer( 186 | LoRALinearLayer( 187 | in_features=attn_module.add_v_proj.in_features, 188 | out_features=attn_module.add_v_proj.out_features, 189 | rank=args.rank, down_dim=args.down_dim, up_dim=args.up_dim, is_train=False, 190 | ) 191 | ) 192 | unet_lora_parameters.extend(attn_module.add_k_proj.lora_layer.parameters()) 193 | unet_lora_parameters.extend(attn_module.add_v_proj.lora_layer.parameters()) 194 | 195 | unet_lora_linear_layers.append(attn_module.add_k_proj.lora_layer) 196 | unet_lora_linear_layers.append(attn_module.add_v_proj.lora_layer) 197 | 198 | # The text encoder comes from 🤗 transformers, so we cannot directly modify it. 199 | # So, instead, we monkey-patch the forward calls of its attention-blocks. 200 | if args.train_text_encoder: 201 | # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 202 | # if patch_mlp is True, the finetuning will cover the text encoder mlp, 203 | # otherwise only the text encoder attention, total lora is (12+12)*4=96 204 | text_lora_parameters, text_lora_linear_layers = LoraLoaderMixin._modify_text_encoder(text_encoder, 205 | dtype=torch.float32, 206 | rank=args.rank, 207 | down_dim=args.down_dim, 208 | up_dim=args.up_dim, 209 | patch_mlp=args.patch_mlp, 210 | is_train=False) 211 | # total loras 212 | lora_linear_layers = unet_lora_linear_layers + text_lora_linear_layers \ 213 | if args.train_text_encoder else unet_lora_linear_layers 214 | 215 | if args.vit_model_name == "vit_base_patch16_224": 216 | img_encoder_model_name = "vit_base_patch16_224" 217 | ref_img_size = 224 218 | mean = [0.5000] 219 | std = [0.5000] 220 | elif args.vit_model_name == "vit_huge_patch14_clip_224": 221 | img_encoder_model_name = "vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k" 222 | ref_img_size = 224 223 | mean = [0.4815, 0.4578, 0.4082] 224 | std = [0.2686, 0.2613, 0.2758] 225 | elif args.vit_model_name == "vit_huge_patch14_clip_336": 226 | img_encoder_model_name = "vit_huge_patch14_clip_336.laion2b_ft_in12k_in1k" 227 | ref_img_size = 336 228 | mean = [0.4815, 0.4578, 0.4082] 229 | std = [0.2686, 0.2613, 0.2758] 230 | else: 231 | raise ValueError("%s does not supports!" 
% args.img_encoder_model_name) 232 | 233 | hypernet_transposes = Compose([ 234 | Resize(size=ref_img_size), 235 | CenterCrop(size=(ref_img_size, ref_img_size)), 236 | ToTensor(), 237 | Normalize(mean=mean, std=std), 238 | ]) 239 | 240 | hypernetwork = HyperDream( 241 | img_encoder_model_name=img_encoder_model_name, 242 | ref_img_size=ref_img_size, 243 | weight_num=len(lora_linear_layers), 244 | weight_dim=(args.up_dim + args.down_dim) * args.rank, 245 | ) 246 | hypernetwork.set_lilora(lora_linear_layers) 247 | 248 | if os.path.isdir(args.hypernet_model_path): 249 | path = os.path.join(args.hypernet_model_path, "hypernetwork.bin") 250 | weight = torch.load(path) 251 | sd = weight['hypernetwork'] 252 | hypernetwork.load_state_dict(sd) 253 | else: 254 | weight = torch.load(args.hypernet_model_path['hypernetwork']) 255 | sd = weight['hypernetwork'] 256 | hypernetwork.load_state_dict(sd) 257 | 258 | for i, lilora in enumerate(lora_linear_layers): 259 | seed = weight['aux_seed_%d' % i] 260 | down_aux = weight['down_aux_%d' % i] 261 | up_aux = weight['up_aux_%d' % i] 262 | lilora.update_aux(seed, down_aux, up_aux) 263 | 264 | print(f"Hypernet weights loaded from: {args.hypernet_model_path}") 265 | 266 | hypernetwork = hypernetwork.to("cuda") 267 | hypernetwork = hypernetwork.eval() 268 | 269 | ref_img = Image.open(args.reference_image_path).convert("RGB") 270 | ref_img = hypernet_transposes(ref_img).unsqueeze(0).to("cuda") 271 | 272 | with torch.no_grad(): 273 | weight, weight_list = hypernetwork(ref_img) 274 | print("weight>>>>>>>>>>>:",weight.shape, weight) 275 | 276 | # convert down and up weights to linear layer as LoRALinearLayer 277 | for weight, lora_layer in zip(weight_list, lora_linear_layers): 278 | lora_layer.update_weight(weight) 279 | lora_layer.convert_to_standard_lora() 280 | 281 | unet_lora_layers_to_save = unet_lora_state_dict(unet) 282 | text_encoder_lora_layers_to_save = None 283 | if args.train_text_encoder: 284 | text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(text_encoder, patch_mlp=args.patch_mlp) 285 | 286 | LoraLoaderMixin.save_lora_weights( 287 | save_directory=args.output_dir, 288 | unet_lora_layers=unet_lora_layers_to_save, 289 | text_encoder_lora_layers=text_encoder_lora_layers_to_save, 290 | ) 291 | 292 | print("Export LoRA to: %s"%args.output_dir) 293 | print("==================================complete======================================") 294 | 295 | 296 | if __name__ == "__main__": 297 | args = parse_args() 298 | main(args) -------------------------------------------------------------------------------- /export_hypernet_weight.sh: -------------------------------------------------------------------------------- 1 | export MODEL_NAME="stable-diffusion-models/realisticVisionV40_v40VAE" 2 | export HYPER_WEIGHT_DIR="projects/AIGC/experiments2/hypernet/CelebA-HQ-10k-no-pretrain2" 3 | export OUTPUT_DIR="projects/AIGC/lora_model_test" 4 | export reference_image_path="dataset/FFHQ_test/00015.png" 5 | 6 | 7 | python "export_hypernet_weight.py" \ 8 | --pretrained_model_name_or_path $MODEL_NAME \ 9 | --hypernet_model_path $HYPER_WEIGHT_DIR \ 10 | --output_dir $OUTPUT_DIR \ 11 | --rank 1 \ 12 | --down_dim 160 \ 13 | --up_dim 80 \ 14 | --train_text_encoder \ 15 | --reference_image_path $reference_image_path 16 | -------------------------------------------------------------------------------- /export_preoptnet_weight.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | # 
Modified by KohakuBlueLeaf 5 | # Modified from diffusers/example/dreambooth/train_dreambooth_lora.py 6 | # see original licensed below 7 | # ======================================================================= 8 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved. 9 | # 10 | # Licensed under the Apache License, Version 2.0 (the "License"); 11 | # you may not use this file except in compliance with the License. 12 | # You may obtain a copy of the License at 13 | # 14 | # http://www.apache.org/licenses/LICENSE-2.0 15 | # 16 | # Unless required by applicable law or agreed to in writing, software 17 | # distributed under the License is distributed on an "AS IS" BASIS, 18 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 19 | # See the License for the specific language governing permissions and 20 | # ======================================================================= 21 | 22 | import argparse 23 | import os 24 | from packaging import version 25 | 26 | import torch 27 | from transformers import AutoTokenizer, PretrainedConfig 28 | 29 | from diffusers import ( 30 | UNet2DConditionModel, 31 | ) 32 | from diffusers.models.attention_processor import ( 33 | AttnAddedKVProcessor, 34 | AttnAddedKVProcessor2_0, 35 | SlicedAttnAddedKVProcessor, 36 | ) 37 | from diffusers import StableDiffusionPipeline 38 | 39 | from modules.light_lora import LoRALinearLayer, LoraLoaderMixin 40 | from modules.utils.lora_utils import unet_lora_state_dict, text_encoder_lora_state_dict 41 | from modules.hypernet import PreOptHyperDream 42 | 43 | 44 | def parse_args(input_args=None): 45 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 46 | parser.add_argument( 47 | "--pretrained_model_name_or_path", 48 | type=str, 49 | default=None, 50 | required=True, 51 | help="Path to pretrained model or model identifier from huggingface.co/models.", 52 | ) 53 | parser.add_argument( 54 | "--revision", 55 | type=str, 56 | default=None, 57 | required=False, 58 | help="Revision of pretrained model identifier from huggingface.co/models.", 59 | ) 60 | parser.add_argument( 61 | "--pre_opt_weight_path", 62 | type=str, 63 | default=None, 64 | required=True, 65 | help="Path to pretrained hyperkohaku model", 66 | ) 67 | 68 | parser.add_argument( 69 | "--output_dir", 70 | type=str, 71 | default=None, 72 | required=True, 73 | ) 74 | 75 | parser.add_argument( 76 | "--rank", 77 | type=int, 78 | default=1, 79 | help=("The dimension of the LoRA update matrices."), 80 | ) 81 | parser.add_argument( 82 | "--down_dim", 83 | type=int, 84 | default=160, 85 | help=("The dimension of the LoRA update matrices."), 86 | ) 87 | parser.add_argument( 88 | "--up_dim", 89 | type=int, 90 | default=80, 91 | help=("The dimension of the LoRA update matrices."), 92 | ) 93 | parser.add_argument( 94 | "--train_text_encoder", 95 | action="store_true", 96 | help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", 97 | ) 98 | parser.add_argument( 99 | "--patch_mlp", 100 | action="store_true", 101 | help="Whether to train the text encoder with mlp. 
If set, the text encoder should be float32 precision.", 102 | ) 103 | parser.add_argument( 104 | "--reference_image_id", 105 | type=int, 106 | default=1, 107 | help=("id in the celeb-a dataset"), 108 | ) 109 | 110 | parser.add_argument( 111 | "--total_identities", 112 | type=int, 113 | default=30000, 114 | help=("The identities size of the training dataset."), 115 | ) 116 | 117 | parser.add_argument( 118 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 119 | ) 120 | if input_args is not None: 121 | args = parser.parse_args(input_args) 122 | else: 123 | args = parser.parse_args() 124 | return args 125 | 126 | 127 | def main(args): 128 | # Load Model 129 | pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, torch_dtype=torch.float32) 130 | # pipe.to("cuda") 131 | 132 | unet = pipe.unet 133 | text_encoder = pipe.text_encoder 134 | 135 | unet_lora_parameters = [] 136 | unet_lora_linear_layers = [] 137 | for i, (attn_processor_name, attn_processor) in enumerate(unet.attn_processors.items()): 138 | print("unet.attn_processor->%d:%s" % (i, attn_processor_name), attn_processor) 139 | # Parse the attention module. 140 | attn_module = unet 141 | for n in attn_processor_name.split(".")[:-1]: 142 | attn_module = getattr(attn_module, n) 143 | print("attn_module:", attn_module) 144 | 145 | # Set the `lora_layer` attribute of the attention-related matrices. 146 | attn_module.to_q.set_lora_layer( 147 | LoRALinearLayer( 148 | in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, 149 | rank=args.rank, down_dim=args.down_dim, up_dim=args.up_dim, is_train=False, 150 | ) 151 | ) 152 | attn_module.to_k.set_lora_layer( 153 | LoRALinearLayer( 154 | in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, 155 | rank=args.rank, down_dim=args.down_dim, up_dim=args.up_dim, is_train=False, 156 | ) 157 | ) 158 | attn_module.to_v.set_lora_layer( 159 | LoRALinearLayer( 160 | in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, 161 | rank=args.rank, down_dim=args.down_dim, up_dim=args.up_dim, is_train=False, 162 | ) 163 | ) 164 | attn_module.to_out[0].set_lora_layer( 165 | LoRALinearLayer( 166 | in_features=attn_module.to_out[0].in_features, out_features=attn_module.to_out[0].out_features, 167 | rank=args.rank, down_dim=args.down_dim, up_dim=args.up_dim, is_train=False, 168 | ) 169 | ) 170 | # Accumulate the LoRA params to optimize. 171 | unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters()) 172 | unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters()) 173 | unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters()) 174 | unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters()) 175 | 176 | # Accumulate the LoRALinerLayer to optimize. 
177 | unet_lora_linear_layers.append(attn_module.to_q.lora_layer) 178 | unet_lora_linear_layers.append(attn_module.to_k.lora_layer) 179 | unet_lora_linear_layers.append(attn_module.to_v.lora_layer) 180 | unet_lora_linear_layers.append(attn_module.to_out[0].lora_layer) 181 | 182 | if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)): 183 | attn_module.add_k_proj.set_lora_layer( 184 | LoRALinearLayer( 185 | in_features=attn_module.add_k_proj.in_features, out_features=attn_module.add_k_proj.out_features, 186 | rank=args.rank, down_dim=args.down_dim, up_dim=args.up_dim, is_train=False, 187 | ) 188 | ) 189 | attn_module.add_v_proj.set_lora_layer( 190 | LoRALinearLayer( 191 | in_features=attn_module.add_v_proj.in_features, out_features=attn_module.add_v_proj.out_features, 192 | rank=args.rank, down_dim=args.down_dim, up_dim=args.up_dim, is_train=False, 193 | ) 194 | ) 195 | unet_lora_parameters.extend(attn_module.add_k_proj.lora_layer.parameters()) 196 | unet_lora_parameters.extend(attn_module.add_v_proj.lora_layer.parameters()) 197 | 198 | unet_lora_linear_layers.append(attn_module.add_k_proj.lora_layer) 199 | unet_lora_linear_layers.append(attn_module.add_v_proj.lora_layer) 200 | 201 | # The text encoder comes from 🤗 transformers, so we cannot directly modify it. 202 | # So, instead, we monkey-patch the forward calls of its attention-blocks. 203 | if args.train_text_encoder: 204 | # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 205 | # if patch_mlp is True, the finetuning will cover the text encoder mlp, 206 | # otherwise only the text encoder attention, total lora is (12+12)*4=96 207 | text_lora_parameters, text_lora_linear_layers = LoraLoaderMixin._modify_text_encoder(text_encoder, 208 | dtype=torch.float32, 209 | rank=args.rank, 210 | down_dim=args.down_dim, 211 | up_dim=args.up_dim, 212 | patch_mlp=args.patch_mlp, 213 | is_train=False) 214 | 215 | lora_linear_layers = unet_lora_linear_layers + text_lora_linear_layers \ 216 | if args.train_text_encoder else unet_lora_linear_layers 217 | 218 | # create PreOptHyperDream and set lilora 219 | pre_opt_net = PreOptHyperDream(args.rank, args.down_dim, args.up_dim) 220 | pre_opt_net.set_lilora(lora_linear_layers, args.total_identities) 221 | 222 | # load weight 223 | if os.path.isfile(args.pre_opt_weight_path): 224 | weight = torch.load(args.pre_opt_weight_path) 225 | else: 226 | weight = torch.load(os.path.join(args.pre_opt_weight_path, 'pre_optimized.bin')) 227 | # load pre-optimized lilora weights for each identity 228 | sd = weight['pre_optimized'] 229 | pre_opt_net.load_state_dict(sd) 230 | pre_opt_net.requires_grad_(False) 231 | pre_opt_net.set_device('cuda') 232 | 233 | for i,lilora in enumerate(lora_linear_layers): 234 | seed = weight['aux_seed_%d'%i] 235 | down_aux = weight['down_aux_%d'%i] 236 | up_aux = weight['up_aux_%d'%i] 237 | lilora.update_aux(seed, down_aux, up_aux) 238 | 239 | print(f"PreOptNet weights loaded from: {args.pre_opt_weight_path}") 240 | 241 | with torch.no_grad(): 242 | # get pre-optimized weights according to identity 243 | weights, weight_list = pre_opt_net([args.reference_image_id]) 244 | for weight, lora_layer in zip(weight_list, lora_linear_layers): 245 | lora_layer.update_weight(weight) 246 | lora_layer.convert_to_standard_lora() 247 | print("weight",weight) 248 | 249 | unet_lora_layers_to_save = unet_lora_state_dict(unet) 250 | text_encoder_lora_layers_to_save = None 251 | if args.train_text_encoder: 252 | 
text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(text_encoder, patch_mlp=args.patch_mlp) 253 | 254 | LoraLoaderMixin.save_lora_weights( 255 | save_directory=args.output_dir, 256 | unet_lora_layers=unet_lora_layers_to_save, 257 | text_encoder_lora_layers=text_encoder_lora_layers_to_save, 258 | ) 259 | 260 | print("Save LoRA to: %s"%args.output_dir) 261 | print("==================================complete======================================") 262 | 263 | if __name__ == "__main__": 264 | args = parse_args() 265 | main(args) -------------------------------------------------------------------------------- /export_preoptnet_weight.sh: -------------------------------------------------------------------------------- 1 | export MODEL_NAME="stable-diffusion-models/realisticVisionV40_v40VAE" 2 | export PRE_OPTNET_WEIGHT_DIR="projects/AIGC/experiments/pretrained/CelebA-HQ-100" 3 | export OUTPUT_DIR="projects/AIGC/lora_model_test" 4 | 5 | python "export_preoptnet_weight.py" \ 6 | --pretrained_model_name_or_path $MODEL_NAME \ 7 | --pre_opt_weight_path $PRE_OPTNET_WEIGHT_DIR \ 8 | --output_dir $OUTPUT_DIR \ 9 | --vit_model_name vit_huge_patch14_clip_336 \ 10 | --rank 1 \ 11 | --down_dim 160 \ 12 | --up_dim 80 \ 13 | --train_text_encoder \ 14 | --total_identities 100 \ 15 | --reference_image_id 10 16 | -------------------------------------------------------------------------------- /fast_finetune.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | 16 | import argparse 17 | import copy 18 | import gc 19 | import itertools 20 | import logging 21 | import math 22 | import os 23 | import time 24 | import shutil 25 | import warnings 26 | from pathlib import Path 27 | 28 | import numpy as np 29 | import torch 30 | import torch.nn.functional as F 31 | import torch.utils.checkpoint 32 | import transformers 33 | from accelerate import Accelerator 34 | from accelerate.logging import get_logger 35 | from accelerate.utils import ProjectConfiguration, set_seed 36 | from huggingface_hub import create_repo, upload_folder 37 | from huggingface_hub.utils import insecure_hashlib 38 | from packaging import version 39 | from PIL import Image 40 | from PIL.ImageOps import exif_transpose 41 | from torch.utils.data import Dataset 42 | from torchvision import transforms 43 | from tqdm.auto import tqdm 44 | from transformers import AutoTokenizer, PretrainedConfig 45 | 46 | import diffusers 47 | from diffusers import ( 48 | AutoencoderKL, 49 | DDPMScheduler, 50 | DiffusionPipeline, 51 | DPMSolverMultistepScheduler, 52 | StableDiffusionPipeline, 53 | UNet2DConditionModel, 54 | ) 55 | from diffusers.models.attention_processor import ( 56 | AttnAddedKVProcessor, 57 | AttnAddedKVProcessor2_0, 58 | SlicedAttnAddedKVProcessor, 59 | ) 60 | from diffusers.optimization import get_scheduler 61 | from diffusers.training_utils import unet_lora_state_dict 62 | from diffusers.utils import check_min_version, is_wandb_available 63 | from diffusers.utils.import_utils import is_xformers_available 64 | 65 | from modules.lora import LoRALinearLayer, LoraLoaderMixin 66 | 67 | # from diffusers.models.lora import LoRALinearLayer 68 | # from diffusers.loaders import LoraLoaderMixin 69 | 70 | 71 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
72 | check_min_version("0.25.0.dev0") 73 | 74 | logger = get_logger(__name__) 75 | 76 | 77 | # TODO: This function should be removed once training scripts are rewritten in PEFT 78 | def text_encoder_lora_state_dict(text_encoder, patch_mlp=False): 79 | state_dict = {} 80 | 81 | def text_encoder_attn_modules(text_encoder): 82 | from transformers import CLIPTextModel, CLIPTextModelWithProjection 83 | 84 | attn_modules = [] 85 | 86 | if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): 87 | for i, layer in enumerate(text_encoder.text_model.encoder.layers): 88 | name = f"text_model.encoder.layers.{i}.self_attn" 89 | mod = layer.self_attn 90 | attn_modules.append((name, mod)) 91 | 92 | return attn_modules 93 | 94 | def text_encoder_mlp_modules(text_encoder): 95 | from transformers import CLIPTextModel, CLIPTextModelWithProjection 96 | mlp_modules = [] 97 | if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): 98 | for i, layer in enumerate(text_encoder.text_model.encoder.layers): 99 | name = f"text_model.encoder.layers.{i}.mlp" 100 | mod = layer.mlp 101 | mlp_modules.append((name, mod)) 102 | 103 | return mlp_modules 104 | 105 | # text encoder attn layer 106 | for name, module in text_encoder_attn_modules(text_encoder): 107 | for k, v in module.q_proj.lora_linear_layer.state_dict().items(): 108 | state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v 109 | 110 | for k, v in module.k_proj.lora_linear_layer.state_dict().items(): 111 | state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v 112 | 113 | for k, v in module.v_proj.lora_linear_layer.state_dict().items(): 114 | state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v 115 | 116 | for k, v in module.out_proj.lora_linear_layer.state_dict().items(): 117 | state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v 118 | 119 | # text encoder mlp layer 120 | if patch_mlp: 121 | for name, module in text_encoder_mlp_modules(text_encoder): 122 | for k, v in module.fc1.lora_linear_layer.state_dict().items(): 123 | state_dict[f"{name}.fc1.lora_linear_layer.{k}"] = v 124 | 125 | for k, v in module.fc2.lora_linear_layer.state_dict().items(): 126 | state_dict[f"{name}.fc2.lora_linear_layer.{k}"] = v 127 | 128 | return state_dict 129 | 130 | 131 | def save_model_card( 132 | repo_id: str, 133 | images=None, 134 | base_model=str, 135 | train_text_encoder=False, 136 | prompt=str, 137 | repo_folder=None, 138 | pipeline: DiffusionPipeline = None, 139 | ): 140 | img_str = "" 141 | for i, image in enumerate(images): 142 | image.save(os.path.join(repo_folder, f"image_{i}.png")) 143 | img_str += f"![img_{i}](./image_{i}.png)\n" 144 | 145 | yaml = f""" 146 | --- 147 | license: creativeml-openrail-m 148 | base_model: {base_model} 149 | instance_prompt: {prompt} 150 | tags: 151 | - {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'} 152 | - {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'} 153 | - text-to-image 154 | - diffusers 155 | - lora 156 | inference: true 157 | --- 158 | """ 159 | model_card = f""" 160 | # LoRA DreamBooth - {repo_id} 161 | 162 | These are LoRA adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n 163 | {img_str} 164 | 165 | LoRA for the text encoder was enabled: {train_text_encoder}. 
166 | """ 167 | with open(os.path.join(repo_folder, "README.md"), "w") as f: 168 | f.write(yaml + model_card) 169 | 170 | 171 | def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): 172 | text_encoder_config = PretrainedConfig.from_pretrained( 173 | pretrained_model_name_or_path, 174 | subfolder="text_encoder", 175 | revision=revision, 176 | ) 177 | model_class = text_encoder_config.architectures[0] 178 | 179 | if model_class == "CLIPTextModel": 180 | from transformers import CLIPTextModel 181 | 182 | return CLIPTextModel 183 | elif model_class == "RobertaSeriesModelWithTransformation": 184 | from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation 185 | 186 | return RobertaSeriesModelWithTransformation 187 | elif model_class == "T5EncoderModel": 188 | from transformers import T5EncoderModel 189 | 190 | return T5EncoderModel 191 | else: 192 | raise ValueError(f"{model_class} is not supported.") 193 | 194 | 195 | def parse_args(input_args=None): 196 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 197 | parser.add_argument( 198 | "--pretrained_model_name_or_path", 199 | type=str, 200 | default=None, 201 | required=True, 202 | help="Path to pretrained model or model identifier from huggingface.co/models.", 203 | ) 204 | parser.add_argument( 205 | "--revision", 206 | type=str, 207 | default=None, 208 | required=False, 209 | help="Revision of pretrained model identifier from huggingface.co/models.", 210 | ) 211 | parser.add_argument( 212 | "--variant", 213 | type=str, 214 | default=None, 215 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", 216 | ) 217 | parser.add_argument( 218 | "--tokenizer_name", 219 | type=str, 220 | default=None, 221 | help="Pretrained tokenizer name or path if not the same as model_name", 222 | ) 223 | parser.add_argument( 224 | "--instance_image_path", 225 | type=str, 226 | default=None, 227 | required=True, 228 | help="A folder containing the training data of instance images.", 229 | ) 230 | parser.add_argument( 231 | "--class_data_dir", 232 | type=str, 233 | default=None, 234 | required=False, 235 | help="A folder containing the training data of class images.", 236 | ) 237 | parser.add_argument( 238 | "--instance_prompt", 239 | type=str, 240 | default=None, 241 | required=True, 242 | help="The prompt with identifier specifying the instance", 243 | ) 244 | parser.add_argument( 245 | "--class_prompt", 246 | type=str, 247 | default=None, 248 | help="The prompt to specify images in the same class as provided instance images.", 249 | ) 250 | parser.add_argument( 251 | "--validation_prompt", 252 | type=str, 253 | default=None, 254 | help="A prompt that is used during validation to verify that the model is learning.", 255 | ) 256 | parser.add_argument( 257 | "--num_validation_images", 258 | type=int, 259 | default=5, 260 | help="Number of images that should be generated during validation with `validation_prompt`.", 261 | ) 262 | 263 | parser.add_argument( 264 | "--with_prior_preservation", 265 | default=False, 266 | action="store_true", 267 | help="Flag to add prior preservation loss.", 268 | ) 269 | parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") 270 | parser.add_argument( 271 | "--num_class_images", 272 | type=int, 273 | default=100, 274 | help=( 275 | "Minimal class images for prior preservation loss. 
If there are not enough images already present in" 276 | " class_data_dir, additional images will be sampled with class_prompt." 277 | ), 278 | ) 279 | parser.add_argument( 280 | "--output_dir", 281 | type=str, 282 | default="lora-dreambooth-model", 283 | help="The output directory where the model predictions and checkpoints will be written.", 284 | ) 285 | parser.add_argument( 286 | "--resume_dir", 287 | type=str, 288 | default="", 289 | help="The output directory where the pretrained model predictions and checkpoints has be written.", 290 | ) 291 | 292 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 293 | parser.add_argument( 294 | "--resolution", 295 | type=int, 296 | default=512, 297 | help=( 298 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 299 | " resolution" 300 | ), 301 | ) 302 | parser.add_argument( 303 | "--center_crop", 304 | default=False, 305 | action="store_true", 306 | help=( 307 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly" 308 | " cropped. The images will be resized to the resolution first before cropping." 309 | ), 310 | ) 311 | parser.add_argument( 312 | "--train_text_encoder", 313 | action="store_true", 314 | help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", 315 | ) 316 | parser.add_argument( 317 | "--patch_mlp", 318 | action="store_true", 319 | help="Whether to train the text encoder with mlp. If set, the text encoder should be float32 precision.", 320 | ) 321 | parser.add_argument( 322 | "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." 323 | ) 324 | parser.add_argument( 325 | "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." 326 | ) 327 | parser.add_argument("--num_train_steps", type=int, default=1) 328 | parser.add_argument( 329 | "--max_train_steps", 330 | type=int, 331 | default=None, 332 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 333 | ) 334 | parser.add_argument( 335 | "--checkpoints_total_limit", 336 | type=int, 337 | default=None, 338 | help=("Max number of checkpoints to store."), 339 | ) 340 | parser.add_argument( 341 | "--resume_from_checkpoint", 342 | type=str, 343 | default=None, 344 | help=( 345 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 346 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 
347 | ), 348 | ) 349 | parser.add_argument( 350 | "--gradient_accumulation_steps", 351 | type=int, 352 | default=1, 353 | help="Number of updates steps to accumulate before performing a backward/update pass.", 354 | ) 355 | parser.add_argument( 356 | "--gradient_checkpointing", 357 | action="store_true", 358 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 359 | ) 360 | parser.add_argument( 361 | "--learning_rate", 362 | type=float, 363 | default=5e-4, 364 | help="Initial learning rate (after the potential warmup period) to use.", 365 | ) 366 | parser.add_argument( 367 | "--scale_lr", 368 | action="store_true", 369 | default=False, 370 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 371 | ) 372 | parser.add_argument( 373 | "--lr_scheduler", 374 | type=str, 375 | default="constant", 376 | help=( 377 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 378 | ' "constant", "constant_with_warmup"]' 379 | ), 380 | ) 381 | parser.add_argument( 382 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 383 | ) 384 | parser.add_argument( 385 | "--lr_num_cycles", 386 | type=int, 387 | default=1, 388 | help="Number of hard resets of the lr in cosine_with_restarts scheduler.", 389 | ) 390 | parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") 391 | parser.add_argument( 392 | "--dataloader_num_workers", 393 | type=int, 394 | default=0, 395 | help=( 396 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 397 | ), 398 | ) 399 | parser.add_argument( 400 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 401 | ) 402 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 403 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 404 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 405 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 406 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 407 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 408 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 409 | parser.add_argument( 410 | "--hub_model_id", 411 | type=str, 412 | default=None, 413 | help="The name of the repository to keep in sync with the local `output_dir`.", 414 | ) 415 | parser.add_argument( 416 | "--logging_dir", 417 | type=str, 418 | default="logs", 419 | help=( 420 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 421 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 422 | ), 423 | ) 424 | parser.add_argument( 425 | "--allow_tf32", 426 | action="store_true", 427 | help=( 428 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see" 429 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 430 | ), 431 | ) 432 | parser.add_argument( 433 | "--report_to", 434 | type=str, 435 | default="tensorboard", 436 | help=( 437 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 438 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 439 | ), 440 | ) 441 | parser.add_argument( 442 | "--mixed_precision", 443 | type=str, 444 | default=None, 445 | choices=["no", "fp16", "bf16"], 446 | help=( 447 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 448 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 449 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 450 | ), 451 | ) 452 | parser.add_argument( 453 | "--prior_generation_precision", 454 | type=str, 455 | default=None, 456 | choices=["no", "fp32", "fp16", "bf16"], 457 | help=( 458 | "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 459 | " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." 460 | ), 461 | ) 462 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 463 | parser.add_argument( 464 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 465 | ) 466 | parser.add_argument( 467 | "--pre_compute_text_embeddings", 468 | action="store_true", 469 | help="Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.", 470 | ) 471 | parser.add_argument( 472 | "--tokenizer_max_length", 473 | type=int, 474 | default=None, 475 | required=False, 476 | help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.", 477 | ) 478 | parser.add_argument( 479 | "--text_encoder_use_attention_mask", 480 | action="store_true", 481 | required=False, 482 | help="Whether to use attention mask for the text encoder", 483 | ) 484 | parser.add_argument( 485 | "--validation_images", 486 | required=False, 487 | default=None, 488 | nargs="+", 489 | help="Optional set of images to use for validation. 
Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.", 490 | ) 491 | parser.add_argument( 492 | "--class_labels_conditioning", 493 | required=False, 494 | default=None, 495 | help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.", 496 | ) 497 | parser.add_argument( 498 | "--rank", 499 | type=int, 500 | default=4, 501 | help=("The dimension of the LoRA update matrices."), 502 | ) 503 | 504 | if input_args is not None: 505 | args = parser.parse_args(input_args) 506 | else: 507 | args = parser.parse_args() 508 | 509 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 510 | if env_local_rank != -1 and env_local_rank != args.local_rank: 511 | args.local_rank = env_local_rank 512 | 513 | if args.with_prior_preservation: 514 | if args.class_data_dir is None: 515 | raise ValueError("You must specify a data directory for class images.") 516 | if args.class_prompt is None: 517 | raise ValueError("You must specify prompt for class images.") 518 | else: 519 | # logger is not available yet 520 | if args.class_data_dir is not None: 521 | warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") 522 | if args.class_prompt is not None: 523 | warnings.warn("You need not use --class_prompt without --with_prior_preservation.") 524 | 525 | if args.train_text_encoder and args.pre_compute_text_embeddings: 526 | raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`") 527 | 528 | return args 529 | 530 | 531 | class DreamBoothDataset(Dataset): 532 | """ 533 | A dataset to prepare the instance and class images with the prompts for fine-tuning the model. 534 | It pre-processes the images and the tokenizes prompts. 
535 | """ 536 | 537 | def __init__( 538 | self, 539 | instance_image_path, 540 | instance_prompt, 541 | tokenizer, 542 | class_data_root=None, 543 | class_prompt=None, 544 | class_num=None, 545 | size=512, 546 | center_crop=False, 547 | encoder_hidden_states=None, 548 | class_prompt_encoder_hidden_states=None, 549 | tokenizer_max_length=None, 550 | ): 551 | self.size = size 552 | self.center_crop = center_crop 553 | self.tokenizer = tokenizer 554 | self.encoder_hidden_states = encoder_hidden_states 555 | self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states 556 | self.tokenizer_max_length = tokenizer_max_length 557 | 558 | if not os.path.isfile(instance_image_path): 559 | raise ValueError('%s is not a file!'%instance_image_path) 560 | 561 | self.instance_images_path = [instance_image_path] 562 | 563 | self.num_instance_images = len(self.instance_images_path) 564 | self.instance_prompt = instance_prompt 565 | self._length = self.num_instance_images 566 | 567 | if class_data_root is not None: 568 | self.class_data_root = Path(class_data_root) 569 | self.class_data_root.mkdir(parents=True, exist_ok=True) 570 | self.class_images_path = list(self.class_data_root.iterdir()) 571 | if class_num is not None: 572 | self.num_class_images = min(len(self.class_images_path), class_num) 573 | else: 574 | self.num_class_images = len(self.class_images_path) 575 | self._length = max(self.num_class_images, self.num_instance_images) 576 | self.class_prompt = class_prompt 577 | else: 578 | self.class_data_root = None 579 | 580 | self.image_transforms = transforms.Compose( 581 | [ 582 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), 583 | transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), 584 | transforms.ToTensor(), 585 | transforms.Normalize([0.5], [0.5]), 586 | ] 587 | ) 588 | 589 | def __len__(self): 590 | return self._length 591 | 592 | def __getitem__(self, index): 593 | example = {} 594 | instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) 595 | instance_image = exif_transpose(instance_image) 596 | 597 | if not instance_image.mode == "RGB": 598 | instance_image = instance_image.convert("RGB") 599 | example["instance_images"] = self.image_transforms(instance_image) 600 | 601 | if self.encoder_hidden_states is not None: 602 | example["instance_prompt_ids"] = self.encoder_hidden_states 603 | else: 604 | text_inputs = tokenize_prompt( 605 | self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length 606 | ) 607 | example["instance_prompt_ids"] = text_inputs.input_ids 608 | example["instance_attention_mask"] = text_inputs.attention_mask 609 | 610 | if self.class_data_root: 611 | class_image = Image.open(self.class_images_path[index % self.num_class_images]) 612 | class_image = exif_transpose(class_image) 613 | 614 | if not class_image.mode == "RGB": 615 | class_image = class_image.convert("RGB") 616 | example["class_images"] = self.image_transforms(class_image) 617 | 618 | if self.class_prompt_encoder_hidden_states is not None: 619 | example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states 620 | else: 621 | class_text_inputs = tokenize_prompt( 622 | self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length 623 | ) 624 | example["class_prompt_ids"] = class_text_inputs.input_ids 625 | example["class_attention_mask"] = class_text_inputs.attention_mask 626 | 627 | return example 628 | 629 | 630 | def collate_fn(examples, 
with_prior_preservation=False): 631 | has_attention_mask = "instance_attention_mask" in examples[0] 632 | 633 | input_ids = [example["instance_prompt_ids"] for example in examples] 634 | pixel_values = [example["instance_images"] for example in examples] 635 | 636 | if has_attention_mask: 637 | attention_mask = [example["instance_attention_mask"] for example in examples] 638 | 639 | # Concat class and instance examples for prior preservation. 640 | # We do this to avoid doing two forward passes. 641 | if with_prior_preservation: 642 | input_ids += [example["class_prompt_ids"] for example in examples] 643 | pixel_values += [example["class_images"] for example in examples] 644 | if has_attention_mask: 645 | attention_mask += [example["class_attention_mask"] for example in examples] 646 | 647 | pixel_values = torch.stack(pixel_values) 648 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 649 | 650 | input_ids = torch.cat(input_ids, dim=0) 651 | 652 | batch = { 653 | "input_ids": input_ids, 654 | "pixel_values": pixel_values, 655 | } 656 | 657 | if has_attention_mask: 658 | batch["attention_mask"] = attention_mask 659 | 660 | return batch 661 | 662 | 663 | class PromptDataset(Dataset): 664 | "A simple dataset to prepare the prompts to generate class images on multiple GPUs." 665 | 666 | def __init__(self, prompt, num_samples): 667 | self.prompt = prompt 668 | self.num_samples = num_samples 669 | 670 | def __len__(self): 671 | return self.num_samples 672 | 673 | def __getitem__(self, index): 674 | example = {} 675 | example["prompt"] = self.prompt 676 | example["index"] = index 677 | return example 678 | 679 | 680 | def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None): 681 | if tokenizer_max_length is not None: 682 | max_length = tokenizer_max_length 683 | else: 684 | max_length = tokenizer.model_max_length 685 | 686 | text_inputs = tokenizer( 687 | prompt, 688 | truncation=True, 689 | padding="max_length", 690 | max_length=max_length, 691 | return_tensors="pt", 692 | ) 693 | 694 | return text_inputs 695 | 696 | 697 | def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None): 698 | text_input_ids = input_ids.to(text_encoder.device) 699 | 700 | if text_encoder_use_attention_mask: 701 | attention_mask = attention_mask.to(text_encoder.device) 702 | else: 703 | attention_mask = None 704 | 705 | prompt_embeds = text_encoder( 706 | text_input_ids, 707 | attention_mask=attention_mask, 708 | ) 709 | prompt_embeds = prompt_embeds[0] 710 | 711 | return prompt_embeds 712 | 713 | 714 | def main(args): 715 | logging_dir = Path(args.output_dir, args.logging_dir) 716 | 717 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) 718 | 719 | accelerator = Accelerator( 720 | gradient_accumulation_steps=args.gradient_accumulation_steps, 721 | mixed_precision=args.mixed_precision, 722 | log_with=args.report_to, 723 | project_config=accelerator_project_config, 724 | ) 725 | 726 | if args.report_to == "wandb": 727 | if not is_wandb_available(): 728 | raise ImportError("Make sure to install wandb if you want to use it for logging during training.") 729 | import wandb 730 | 731 | # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate 732 | # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. 
733 | # TODO (sayakpaul): Remove this check when gradient accumulation with two models is enabled in accelerate. 734 | if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: 735 | raise ValueError( 736 | "Gradient accumulation is not supported when training the text encoder in distributed training. " 737 | "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." 738 | ) 739 | 740 | # Make one log on every process with the configuration for debugging. 741 | logging.basicConfig( 742 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 743 | datefmt="%m/%d/%Y %H:%M:%S", 744 | level=logging.INFO, 745 | ) 746 | logger.info(accelerator.state, main_process_only=False) 747 | if accelerator.is_local_main_process: 748 | transformers.utils.logging.set_verbosity_warning() 749 | diffusers.utils.logging.set_verbosity_info() 750 | else: 751 | transformers.utils.logging.set_verbosity_error() 752 | diffusers.utils.logging.set_verbosity_error() 753 | 754 | # If passed along, set the training seed now. 755 | if args.seed is not None: 756 | set_seed(args.seed) 757 | 758 | # Generate class images if prior preservation is enabled. 759 | if args.with_prior_preservation: 760 | class_images_dir = Path(args.class_data_dir) 761 | if not class_images_dir.exists(): 762 | class_images_dir.mkdir(parents=True) 763 | cur_class_images = len(list(class_images_dir.iterdir())) 764 | 765 | if cur_class_images < args.num_class_images: 766 | torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 767 | if args.prior_generation_precision == "fp32": 768 | torch_dtype = torch.float32 769 | elif args.prior_generation_precision == "fp16": 770 | torch_dtype = torch.float16 771 | elif args.prior_generation_precision == "bf16": 772 | torch_dtype = torch.bfloat16 773 | pipeline = DiffusionPipeline.from_pretrained( 774 | args.pretrained_model_name_or_path, 775 | torch_dtype=torch_dtype, 776 | safety_checker=None, 777 | revision=args.revision, 778 | variant=args.variant, 779 | ) 780 | pipeline.set_progress_bar_config(disable=True) 781 | 782 | num_new_images = args.num_class_images - cur_class_images 783 | logger.info(f"Number of class images to sample: {num_new_images}.") 784 | 785 | sample_dataset = PromptDataset(args.class_prompt, num_new_images) 786 | sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) 787 | 788 | sample_dataloader = accelerator.prepare(sample_dataloader) 789 | pipeline.to(accelerator.device) 790 | 791 | for example in tqdm( 792 | sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process 793 | ): 794 | images = pipeline(example["prompt"]).images 795 | 796 | for i, image in enumerate(images): 797 | hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() 798 | image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" 799 | image.save(image_filename) 800 | 801 | del pipeline 802 | if torch.cuda.is_available(): 803 | torch.cuda.empty_cache() 804 | 805 | # Handle the repository creation 806 | if accelerator.is_main_process: 807 | if args.output_dir is not None: 808 | os.makedirs(args.output_dir, exist_ok=True) 809 | 810 | if args.push_to_hub: 811 | repo_id = create_repo( 812 | repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token 813 | ).repo_id 814 | 815 | # Load the tokenizer 816 | if args.tokenizer_name: 817 | tokenizer = 
AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) 818 | elif args.pretrained_model_name_or_path: 819 | tokenizer = AutoTokenizer.from_pretrained( 820 | args.pretrained_model_name_or_path, 821 | subfolder="tokenizer", 822 | revision=args.revision, 823 | use_fast=False, 824 | ) 825 | 826 | # import correct text encoder class 827 | text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) 828 | 829 | # Load scheduler and models 830 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") 831 | text_encoder = text_encoder_cls.from_pretrained( 832 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant 833 | ) 834 | try: 835 | vae = AutoencoderKL.from_pretrained( 836 | args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant 837 | ) 838 | except OSError: 839 | # IF does not have a VAE so let's just set it to None 840 | # We don't have to error out here 841 | vae = None 842 | 843 | unet = UNet2DConditionModel.from_pretrained( 844 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant 845 | ) 846 | 847 | # We only train the additional adapter LoRA layers 848 | if vae is not None: 849 | vae.requires_grad_(False) 850 | text_encoder.requires_grad_(False) 851 | unet.requires_grad_(False) 852 | 853 | 854 | # TODO: RESUME 855 | if args.resume_dir: 856 | lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(args.resume_dir) 857 | # print(lora_state_dict.keys()) 858 | 859 | # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision 860 | # as these weights are only used for inference, keeping weights in full precision is not required. 861 | weight_dtype = torch.float32 862 | if accelerator.mixed_precision == "fp16": 863 | weight_dtype = torch.float16 864 | elif accelerator.mixed_precision == "bf16": 865 | weight_dtype = torch.bfloat16 866 | 867 | # Move unet, vae and text_encoder to device and cast to weight_dtype 868 | unet.to(accelerator.device, dtype=weight_dtype) 869 | if vae is not None: 870 | vae.to(accelerator.device, dtype=weight_dtype) 871 | text_encoder.to(accelerator.device, dtype=weight_dtype) 872 | 873 | if args.enable_xformers_memory_efficient_attention: 874 | if is_xformers_available(): 875 | import xformers 876 | 877 | xformers_version = version.parse(xformers.__version__) 878 | if xformers_version == version.parse("0.0.16"): 879 | logger.warn( 880 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 881 | ) 882 | unet.enable_xformers_memory_efficient_attention() 883 | else: 884 | raise ValueError("xformers is not available. 
Make sure it is installed correctly") 885 | 886 | if args.gradient_checkpointing: 887 | unet.enable_gradient_checkpointing() 888 | if args.train_text_encoder: 889 | text_encoder.gradient_checkpointing_enable() 890 | 891 | # now we will add new LoRA weights to the attention layers 892 | # It's important to realize here how many attention weights will be added and of which sizes 893 | # The sizes of the attention layers consist only of two different variables: 894 | # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`. 895 | # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`. 896 | 897 | # Let's first see how many attention processors we will have to set. 898 | # For Stable Diffusion, it should be equal to: 899 | # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12 900 | # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2 901 | # - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18 902 | # => 32 layers 903 | 904 | # Set correct lora layers, the total layer is 32, total lora is 32*4=128 905 | unet_lora_parameters = [] 906 | for i, (attn_processor_name, attn_processor) in enumerate(unet.attn_processors.items()): 907 | print("unet.attn_processor->%d:%s" % (i, attn_processor_name), attn_processor) 908 | # Parse the attention module. 909 | attn_module = unet 910 | for n in attn_processor_name.split(".")[:-1]: 911 | attn_module = getattr(attn_module, n) 912 | print("attn_module:",attn_module) 913 | 914 | # Set the `lora_layer` attribute of the attention-related matrices. 915 | attn_module.to_q.set_lora_layer( 916 | LoRALinearLayer( 917 | in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank 918 | ) 919 | ) 920 | attn_module.to_k.set_lora_layer( 921 | LoRALinearLayer( 922 | in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank 923 | ) 924 | ) 925 | attn_module.to_v.set_lora_layer( 926 | LoRALinearLayer( 927 | in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank 928 | ) 929 | ) 930 | attn_module.to_out[0].set_lora_layer( 931 | LoRALinearLayer( 932 | in_features=attn_module.to_out[0].in_features, 933 | out_features=attn_module.to_out[0].out_features, 934 | rank=args.rank, 935 | ) 936 | ) 937 | # Accumulate the LoRA params to optimize. 
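        # Each LoRALinearLayer wraps a rank-`args.rank` down projection and an up projection, so the four adapted
        # matrices per attention module (to_q, to_k, to_v, to_out[0]) contribute eight small trainable weight tensors.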
938 | unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters()) 939 | unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters()) 940 | unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters()) 941 | unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters()) 942 | # TODO: # copy weights 943 | for layer_name in ['to_q', 'to_k', 'to_v', 'to_out']: 944 | attn_processor_name = attn_processor_name.replace('.processor', '') 945 | if layer_name == 'to_out': 946 | layer = getattr(attn_module, layer_name)[0].lora_layer 947 | down_key = "unet.%s.%s.0.lora.down.weight" % (attn_processor_name, layer_name) 948 | up_key = "unet.%s.%s.0.lora.up.weight" % (attn_processor_name, layer_name) 949 | else: 950 | layer = getattr(attn_module, layer_name).lora_layer 951 | down_key = "unet.%s.%s.lora.down.weight" % (attn_processor_name, layer_name) 952 | up_key = "unet.%s.%s.lora.up.weight" % (attn_processor_name, layer_name) 953 | 954 | layer.down.weight.data.copy_(lora_state_dict[down_key].to(torch.float32)) 955 | layer.up.weight.data.copy_(lora_state_dict[up_key].to(torch.float32)) 956 | 957 | if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)): 958 | attn_module.add_k_proj.set_lora_layer( 959 | LoRALinearLayer( 960 | in_features=attn_module.add_k_proj.in_features, 961 | out_features=attn_module.add_k_proj.out_features, 962 | rank=args.rank, 963 | ) 964 | ) 965 | attn_module.add_v_proj.set_lora_layer( 966 | LoRALinearLayer( 967 | in_features=attn_module.add_v_proj.in_features, 968 | out_features=attn_module.add_v_proj.out_features, 969 | rank=args.rank, 970 | ) 971 | ) 972 | unet_lora_parameters.extend(attn_module.add_k_proj.lora_layer.parameters()) 973 | unet_lora_parameters.extend(attn_module.add_v_proj.lora_layer.parameters()) 974 | 975 | for layer_name in ['add_k_proj', 'add_v_proj']: 976 | attn_processor_name = attn_processor_name.replace('.processor', '') 977 | layer = getattr(attn_module, layer_name).lora_layer 978 | down_key = "unet.%s.%s.lora.down.weight" % (attn_processor_name, layer_name) 979 | up_key = "unet.%s.%s.lora.up.weight" % (attn_processor_name, layer_name) 980 | # copy weights 981 | layer.down.weight.data.copy_(lora_state_dict[down_key].to(torch.float32)) 982 | layer.up.weight.data.copy_(lora_state_dict[up_key].to(torch.float32)) 983 | print("unet add_proj lora initialized!") 984 | 985 | # The text encoder comes from 🤗 transformers, so we cannot directly modify it. 986 | # So, instead, we monkey-patch the forward calls of its attention-blocks. 987 | if args.train_text_encoder: 988 | # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 989 | # if patch_mlp is True, the finetuning will cover the text encoder mlp, otherwise only the text encoder attention, total lora is (12+12)*4=96 990 | text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, lora_state_dict, dtype=torch.float32, rank=args.rank, patch_mlp=args.patch_mlp) 991 | 992 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format 993 | def save_model_hook(models, weights, output_dir): 994 | if accelerator.is_main_process: 995 | # there are only two options here. 
Either are just the unet attn processor layers 996 | # or there are the unet and text encoder atten layers 997 | unet_lora_layers_to_save = None 998 | text_encoder_lora_layers_to_save = None 999 | 1000 | for model in models: 1001 | if isinstance(model, type(accelerator.unwrap_model(unet))): 1002 | unet_lora_layers_to_save = unet_lora_state_dict(model) 1003 | elif isinstance(model, type(accelerator.unwrap_model(text_encoder))): 1004 | text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model, patch_mlp=args.patch_mlp) 1005 | else: 1006 | raise ValueError(f"unexpected save model: {model.__class__}") 1007 | 1008 | # make sure to pop weight so that corresponding model is not saved again 1009 | weights.pop() 1010 | 1011 | LoraLoaderMixin.save_lora_weights( 1012 | output_dir, 1013 | unet_lora_layers=unet_lora_layers_to_save, 1014 | text_encoder_lora_layers=text_encoder_lora_layers_to_save, 1015 | ) 1016 | 1017 | def load_model_hook(models, input_dir): 1018 | unet_ = None 1019 | text_encoder_ = None 1020 | 1021 | while len(models) > 0: 1022 | model = models.pop() 1023 | 1024 | if isinstance(model, type(accelerator.unwrap_model(unet))): 1025 | unet_ = model 1026 | elif isinstance(model, type(accelerator.unwrap_model(text_encoder))): 1027 | text_encoder_ = model 1028 | else: 1029 | raise ValueError(f"unexpected save model: {model.__class__}") 1030 | 1031 | lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir) 1032 | LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_) 1033 | LoraLoaderMixin.load_lora_into_text_encoder( 1034 | lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_ 1035 | ) 1036 | 1037 | accelerator.register_save_state_pre_hook(save_model_hook) 1038 | accelerator.register_load_state_pre_hook(load_model_hook) 1039 | 1040 | # Enable TF32 for faster training on Ampere GPUs, 1041 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 1042 | if args.allow_tf32: 1043 | torch.backends.cuda.matmul.allow_tf32 = True 1044 | 1045 | if args.scale_lr: 1046 | args.learning_rate = ( 1047 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 1048 | ) 1049 | 1050 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 1051 | if args.use_8bit_adam: 1052 | try: 1053 | import bitsandbytes as bnb 1054 | except ImportError: 1055 | raise ImportError( 1056 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 
1057 | ) 1058 | 1059 | optimizer_class = bnb.optim.AdamW8bit 1060 | else: 1061 | optimizer_class = torch.optim.AdamW 1062 | 1063 | # Optimizer creation 1064 | params_to_optimize = ( 1065 | itertools.chain(unet_lora_parameters, text_lora_parameters) 1066 | if args.train_text_encoder 1067 | else unet_lora_parameters 1068 | ) 1069 | optimizer = optimizer_class( 1070 | params_to_optimize, 1071 | lr=args.learning_rate, 1072 | betas=(args.adam_beta1, args.adam_beta2), 1073 | weight_decay=args.adam_weight_decay, 1074 | eps=args.adam_epsilon, 1075 | ) 1076 | 1077 | if args.pre_compute_text_embeddings: 1078 | 1079 | def compute_text_embeddings(prompt): 1080 | with torch.no_grad(): 1081 | text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length) 1082 | prompt_embeds = encode_prompt( 1083 | text_encoder, 1084 | text_inputs.input_ids, 1085 | text_inputs.attention_mask, 1086 | text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, 1087 | ) 1088 | 1089 | return prompt_embeds 1090 | 1091 | pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt) 1092 | validation_prompt_negative_prompt_embeds = compute_text_embeddings("") 1093 | 1094 | if args.validation_prompt is not None: 1095 | validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt) 1096 | else: 1097 | validation_prompt_encoder_hidden_states = None 1098 | 1099 | if args.class_prompt is not None: 1100 | pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt) 1101 | else: 1102 | pre_computed_class_prompt_encoder_hidden_states = None 1103 | 1104 | text_encoder = None 1105 | tokenizer = None 1106 | 1107 | gc.collect() 1108 | torch.cuda.empty_cache() 1109 | else: 1110 | pre_computed_encoder_hidden_states = None 1111 | validation_prompt_encoder_hidden_states = None 1112 | validation_prompt_negative_prompt_embeds = None 1113 | pre_computed_class_prompt_encoder_hidden_states = None 1114 | 1115 | # Dataset and DataLoaders creation: 1116 | train_dataset = DreamBoothDataset( 1117 | instance_image_path=args.instance_image_path, 1118 | instance_prompt=args.instance_prompt, 1119 | class_data_root=args.class_data_dir if args.with_prior_preservation else None, 1120 | class_prompt=args.class_prompt, 1121 | class_num=args.num_class_images, 1122 | tokenizer=tokenizer, 1123 | size=args.resolution, 1124 | center_crop=args.center_crop, 1125 | encoder_hidden_states=pre_computed_encoder_hidden_states, 1126 | class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states, 1127 | tokenizer_max_length=args.tokenizer_max_length, 1128 | ) 1129 | 1130 | train_dataloader = torch.utils.data.DataLoader( 1131 | train_dataset, 1132 | batch_size=args.train_batch_size, 1133 | shuffle=True, 1134 | collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), 1135 | num_workers=args.dataloader_num_workers, 1136 | ) 1137 | 1138 | # Scheduler and math around the number of training steps. 
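    # If --max_train_steps is not provided, it is derived below from --num_train_steps and the number of
    # update steps per epoch (len(train_dataloader) / gradient_accumulation_steps, rounded up).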
1139 | overrode_max_train_steps = False 1140 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 1141 | if args.max_train_steps is None: 1142 | args.max_train_steps = args.num_train_steps * num_update_steps_per_epoch 1143 | overrode_max_train_steps = True 1144 | 1145 | lr_scheduler = get_scheduler( 1146 | args.lr_scheduler, 1147 | optimizer=optimizer, 1148 | num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, 1149 | num_training_steps=args.max_train_steps * accelerator.num_processes, 1150 | num_cycles=args.lr_num_cycles, 1151 | power=args.lr_power, 1152 | ) 1153 | 1154 | # Prepare everything with our `accelerator`. 1155 | if args.train_text_encoder: 1156 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 1157 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler 1158 | ) 1159 | else: 1160 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 1161 | unet, optimizer, train_dataloader, lr_scheduler 1162 | ) 1163 | 1164 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 1165 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 1166 | if overrode_max_train_steps: 1167 | args.max_train_steps = args.num_train_steps * num_update_steps_per_epoch 1168 | # Afterwards we recalculate our number of training epochs 1169 | 1170 | # We need to initialize the trackers we use, and also store our configuration. 1171 | # The trackers initializes automatically on the main process. 1172 | if accelerator.is_main_process: 1173 | tracker_config = vars(copy.deepcopy(args)) 1174 | tracker_config.pop("validation_images") 1175 | accelerator.init_trackers("dreambooth-lora", config=tracker_config) 1176 | 1177 | # Train! 1178 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 1179 | 1180 | logger.info("***** Running training *****") 1181 | logger.info(f" Num examples = {len(train_dataset)}") 1182 | logger.info(f" Num batches each epoch = {len(train_dataloader)}") 1183 | logger.info(f" Num Steps = {args.num_train_steps}") 1184 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 1185 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 1186 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 1187 | logger.info(f" Total optimization steps = {args.max_train_steps}") 1188 | global_step = 0 1189 | first_epoch = 0 1190 | 1191 | 1192 | # Potentially load in the weights and states from a previous save 1193 | if args.resume_from_checkpoint: 1194 | if args.resume_from_checkpoint != "latest": 1195 | path = os.path.basename(args.resume_from_checkpoint) 1196 | else: 1197 | # Get the mos recent checkpoint 1198 | dirs = os.listdir(args.output_dir) 1199 | dirs = [d for d in dirs if d.startswith("checkpoint")] 1200 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 1201 | path = dirs[-1] if len(dirs) > 0 else None 1202 | 1203 | if path is None: 1204 | accelerator.print( 1205 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 
1206 | ) 1207 | args.resume_from_checkpoint = None 1208 | initial_global_step = 0 1209 | else: 1210 | accelerator.print(f"Resuming from checkpoint {path}") 1211 | accelerator.load_state(os.path.join(args.output_dir, path)) 1212 | global_step = int(path.split("-")[1]) 1213 | 1214 | initial_global_step = global_step 1215 | first_epoch = global_step // num_update_steps_per_epoch 1216 | else: 1217 | initial_global_step = 0 1218 | 1219 | progress_bar = tqdm( 1220 | range(0, args.num_train_steps), 1221 | initial=initial_global_step, 1222 | desc="Steps", 1223 | # Only show the progress bar once on each machine. 1224 | disable=not accelerator.is_local_main_process, 1225 | ) 1226 | 1227 | t0 = time.time() 1228 | unet.train() 1229 | if args.train_text_encoder: 1230 | text_encoder.train() 1231 | for batch in train_dataloader: 1232 | for step in range(args.num_train_steps): 1233 | with accelerator.accumulate(unet): 1234 | pixel_values = batch["pixel_values"].to(dtype=weight_dtype) 1235 | 1236 | if vae is not None: 1237 | # Convert images to latent space 1238 | model_input = vae.encode(pixel_values).latent_dist.sample() 1239 | model_input = model_input * vae.config.scaling_factor 1240 | else: 1241 | model_input = pixel_values 1242 | 1243 | # Sample noise that we'll add to the latents 1244 | noise = torch.randn_like(model_input) 1245 | bsz, channels, height, width = model_input.shape 1246 | # Sample a random timestep for each image 1247 | timesteps = torch.randint( 1248 | 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device 1249 | ) 1250 | timesteps = timesteps.long() 1251 | 1252 | # Add noise to the model input according to the noise magnitude at each timestep 1253 | # (this is the forward diffusion process) 1254 | noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) 1255 | 1256 | # Get the text embedding for conditioning 1257 | if args.pre_compute_text_embeddings: 1258 | encoder_hidden_states = batch["input_ids"] 1259 | else: 1260 | encoder_hidden_states = encode_prompt( 1261 | text_encoder, 1262 | batch["input_ids"], 1263 | batch["attention_mask"], 1264 | text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, 1265 | ) 1266 | 1267 | if accelerator.unwrap_model(unet).config.in_channels == channels * 2: 1268 | noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) 1269 | 1270 | if args.class_labels_conditioning == "timesteps": 1271 | class_labels = timesteps 1272 | else: 1273 | class_labels = None 1274 | 1275 | # Predict the noise residual 1276 | model_pred = unet( 1277 | noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels 1278 | ).sample 1279 | 1280 | # if model predicts variance, throw away the prediction. we will only train on the 1281 | # simplified training objective. This means that all schedulers using the fine tuned 1282 | # model must be configured to use one of the fixed variance variance types. 
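                    # A 6-channel output means the UNet also predicts (learned) variance; torch.chunk keeps only
                    # the first 3 channels, i.e. the noise estimate used by the simplified objective.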
1283 | if model_pred.shape[1] == 6: 1284 | model_pred, _ = torch.chunk(model_pred, 2, dim=1) 1285 | 1286 | # Get the target for loss depending on the prediction type 1287 | if noise_scheduler.config.prediction_type == "epsilon": 1288 | target = noise 1289 | elif noise_scheduler.config.prediction_type == "v_prediction": 1290 | target = noise_scheduler.get_velocity(model_input, noise, timesteps) 1291 | else: 1292 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 1293 | 1294 | if args.with_prior_preservation: 1295 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 1296 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) 1297 | target, target_prior = torch.chunk(target, 2, dim=0) 1298 | 1299 | # Compute instance loss 1300 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 1301 | 1302 | # Compute prior loss 1303 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") 1304 | 1305 | # Add the prior loss to the instance loss. 1306 | loss = loss + args.prior_loss_weight * prior_loss 1307 | else: 1308 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 1309 | 1310 | accelerator.backward(loss) 1311 | if accelerator.sync_gradients: 1312 | params_to_clip = ( 1313 | itertools.chain(unet_lora_parameters, text_lora_parameters) 1314 | if args.train_text_encoder 1315 | else unet_lora_parameters 1316 | ) 1317 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) 1318 | optimizer.step() 1319 | lr_scheduler.step() 1320 | optimizer.zero_grad() 1321 | 1322 | # Checks if the accelerator has performed an optimization step behind the scenes 1323 | if accelerator.sync_gradients: 1324 | progress_bar.update(1) 1325 | global_step += 1 1326 | 1327 | 1328 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 1329 | progress_bar.set_postfix(**logs) 1330 | accelerator.log(logs, step=global_step) 1331 | 1332 | if global_step >= args.max_train_steps: 1333 | break 1334 | 1335 | t1 = time.time() 1336 | # Save the lora layers 1337 | accelerator.wait_for_everyone() 1338 | if accelerator.is_main_process: 1339 | unet = accelerator.unwrap_model(unet) 1340 | unet = unet.to(torch.float32) 1341 | unet_lora_layers = unet_lora_state_dict(unet) 1342 | 1343 | if text_encoder is not None and args.train_text_encoder: 1344 | text_encoder = accelerator.unwrap_model(text_encoder) 1345 | text_encoder = text_encoder.to(torch.float32) 1346 | text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder, patch_mlp=args.patch_mlp) 1347 | else: 1348 | text_encoder_lora_layers = None 1349 | 1350 | LoraLoaderMixin.save_lora_weights( 1351 | save_directory=args.output_dir, 1352 | unet_lora_layers=unet_lora_layers, 1353 | text_encoder_lora_layers=text_encoder_lora_layers, 1354 | ) 1355 | print("unet_lora_layers>>>>>>>>",len(unet_lora_layers)) 1356 | print("text_encoder_lora_layers>>>>>",len(text_encoder_lora_layers)) 1357 | 1358 | # Final inference 1359 | # Load previous pipeline 1360 | pipeline = DiffusionPipeline.from_pretrained( 1361 | args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype 1362 | ) 1363 | 1364 | # We train on the simplified learning objective. 
If we were previously predicting a variance, we need the scheduler to ignore it 1365 | scheduler_args = {} 1366 | 1367 | if "variance_type" in pipeline.scheduler.config: 1368 | variance_type = pipeline.scheduler.config.variance_type 1369 | 1370 | if variance_type in ["learned", "learned_range"]: 1371 | variance_type = "fixed_small" 1372 | 1373 | scheduler_args["variance_type"] = variance_type 1374 | 1375 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) 1376 | 1377 | pipeline = pipeline.to(accelerator.device) 1378 | 1379 | # load attention processors 1380 | pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors") 1381 | 1382 | # run inference 1383 | images = [] 1384 | if args.validation_prompt and args.num_validation_images > 0: 1385 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None 1386 | images = [ 1387 | pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] 1388 | for _ in range(args.num_validation_images) 1389 | ] 1390 | for i,image in enumerate(images): 1391 | image.save("aigc_samples/validation_%d.jpg"%i) 1392 | 1393 | for tracker in accelerator.trackers: 1394 | if tracker.name == "tensorboard": 1395 | np_images = np.stack([np.asarray(img) for img in images]) 1396 | tracker.writer.add_images("test", np_images, step, dataformats="NHWC") 1397 | if tracker.name == "wandb": 1398 | tracker.log( 1399 | { 1400 | "test": [ 1401 | wandb.Image(image, caption=f"{i}: {args.validation_prompt}") 1402 | for i, image in enumerate(images) 1403 | ] 1404 | } 1405 | ) 1406 | 1407 | if args.push_to_hub: 1408 | save_model_card( 1409 | repo_id, 1410 | images=images, 1411 | base_model=args.pretrained_model_name_or_path, 1412 | train_text_encoder=args.train_text_encoder, 1413 | prompt=args.instance_prompt, 1414 | repo_folder=args.output_dir, 1415 | pipeline=pipeline, 1416 | ) 1417 | upload_folder( 1418 | repo_id=repo_id, 1419 | folder_path=args.output_dir, 1420 | commit_message="End of training", 1421 | ignore_patterns=["step_*", "epoch_*"], 1422 | ) 1423 | 1424 | accelerator.end_training() 1425 | print("\ntime elapsed: %f"%(t1-t0)) 1426 | print("==================================complete======================================") 1427 | 1428 | 1429 | if __name__ == "__main__": 1430 | args = parse_args() 1431 | main(args) 1432 | -------------------------------------------------------------------------------- /fast_finetune.sh: -------------------------------------------------------------------------------- 1 | export MODEL_NAME="stable-diffusion-models/realisticVisionV40_v40VAE" 2 | INSTANCE_IMAGE_PATH="projects/AIGC/dataset/FFHQ_test/00083.png" 3 | 4 | export RESUME_DIR="projects/AIGC/experiments2/rank_relax" 5 | export OUTPUT_DIR="projects/AIGC/experiments2/fast_finetune" 6 | 7 | 8 | CUDA_VISIBLE_DEVICES=0 \ 9 | accelerate launch --mixed_precision="fp16" fast_finetune.py \ 10 | --pretrained_model_name_or_path=$MODEL_NAME \ 11 | --instance_image_path=$INSTANCE_IMAGE_PATH \ 12 | --instance_prompt="A [V] face" \ 13 | --resolution=512 \ 14 | --train_batch_size=1 \ 15 | --num_train_steps=25 \ 16 | --learning_rate=1e-4 --lr_scheduler="constant" --lr_warmup_steps=0 \ 17 | --seed=42 \ 18 | --rank=4 \ 19 | --output_dir=$OUTPUT_DIR \ 20 | --resume_dir=$RESUME_DIR \ 21 | --num_validation_images=5 \ 22 | --validation_prompt="A [V] face" \ 23 | --train_text_encoder 24 | # --patch_mlp \ 25 | # --resume_from_checkpoint=$RESUME_DIR \ 26 | 27 | 28 | 
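Note: the script above launches fast_finetune.py, initialising the LoRA layers from the rank-relaxed weights in RESUME_DIR and saving the fine-tuned adapter to OUTPUT_DIR. For a quick standalone check of that adapter, the sketch below mirrors the final-inference block of fast_finetune.py; the model and output paths are just the placeholders exported in the script above, so substitute your own.

import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

# Paths below are the placeholders from fast_finetune.sh; replace them with your own.
model_name = "stable-diffusion-models/realisticVisionV40_v40VAE"
output_dir = "projects/AIGC/experiments2/fast_finetune"

pipeline = DiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.float16)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline = pipeline.to("cuda")

# fast_finetune.py saves the adapter as pytorch_lora_weights.safetensors in --output_dir.
pipeline.load_lora_weights(output_dir, weight_name="pytorch_lora_weights.safetensors")

generator = torch.Generator(device="cuda").manual_seed(42)
image = pipeline("A [V] face", num_inference_steps=25, generator=generator).images[0]
image.save("fast_finetune_sample.jpg")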
-------------------------------------------------------------------------------- /images/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/images/1.jpg -------------------------------------------------------------------------------- /images/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/images/2.jpg -------------------------------------------------------------------------------- /images/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/images/3.jpg -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/modules/__init__.py -------------------------------------------------------------------------------- /modules/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/modules/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /modules/__pycache__/attention.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MFaceTech/HyperDreamBooth/cdeadc831256d3cafc8c7c1fec019a22dbf182f8/modules/__pycache__/attention.cpython-311.pyc -------------------------------------------------------------------------------- /modules/attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from einops import rearrange, repeat 8 | 9 | from .utils import default 10 | from .utils.xformers_utils import ( 11 | XFORMERS_AVAIL, 12 | memory_efficient_attention 13 | ) 14 | 15 | 16 | class GEGLU(nn.Module): 17 | def __init__(self, dim_in, dim_out): 18 | super().__init__() 19 | self.proj = nn.Linear(dim_in, dim_out * 2) 20 | 21 | def forward(self, x): 22 | x, gate = self.proj(x).chunk(2, dim=-1) 23 | return x * F.gelu(gate) 24 | 25 | 26 | class FeedForward(nn.Module): 27 | def __init__(self, dim, dim_out=None, mult=4, glu=False): 28 | super().__init__() 29 | inner_dim = int(dim * mult) 30 | dim_out = default(dim_out, dim) 31 | project_in = nn.Sequential( 32 | nn.Linear(dim, inner_dim), 33 | nn.GELU() 34 | ) if not glu else GEGLU(dim, inner_dim) 35 | 36 | self.net = nn.Sequential( 37 | project_in, 38 | nn.Linear(inner_dim, dim_out) 39 | ) 40 | # nn.init.constant_(self.net[-1].weight, 0) 41 | # nn.init.constant_(self.net[-1].bias, 0) 42 | 43 | def forward(self, x): 44 | return self.net(x) 45 | 46 | 47 | MEMORY_LAYOUTS = { 48 | 'torch': ( 49 | 'b n (h d) -> b h n d', 50 | 'b h n d -> b n (h d)', 51 | lambda x: (1, x, 1, 1), 52 | ), 53 | 'xformers': ( 54 | 'b n (h d) -> b n h d', 55 | 'b n h d -> b n (h d)', 56 | lambda x: (1, 1, x, 1), 57 | ), 58 | 'vanilla': ( 59 | 'b n (h d) -> b h n d', 60 | 'b h n d -> b n (h d)', 61 | lambda x: (1, x, 1, 1), 62 
| ) 63 | } 64 | ATTN_FUNCTION = { 65 | 'torch': F.scaled_dot_product_attention, 66 | 'xformers': memory_efficient_attention 67 | } 68 | 69 | 70 | def vanilla_attention(q, k, v, mask, scale=None): 71 | if scale is None: 72 | scale = math.sqrt(q.size(-1)) 73 | scores = torch.bmm(q, k.transpose(-1, -2)) / scale 74 | if mask is not None: 75 | mask = rearrange(mask, 'b ... -> b (...)') 76 | max_neg_value = -torch.finfo(scores.dtype).max 77 | mask = repeat(mask, 'b j -> (b h) j', h=q.size(-3)) 78 | scores = scores.masked_fill(~mask, max_neg_value) 79 | p_attn = F.softmax(scores, dim=-1) 80 | return torch.bmm(p_attn, v) 81 | 82 | 83 | class Attention(nn.Module): 84 | ''' 85 | Attention Class without norm and residual 86 | (You need to wrap them by your self) 87 | ''' 88 | def __init__( 89 | self, 90 | in_ch, 91 | context_ch=None, 92 | heads=8, 93 | head_ch=64, 94 | self_cross=False, 95 | single_kv_head=False, 96 | attn_backend='torch', 97 | # attn_backend='xformers', 98 | cosine_attn=False, 99 | qk_head_ch=-1 100 | ): 101 | super().__init__() 102 | if heads==-1: 103 | assert in_ch%head_ch==0 104 | heads = in_ch//head_ch 105 | if head_ch==-1: 106 | assert in_ch%heads==0 107 | head_ch = in_ch//heads 108 | if qk_head_ch==-1: 109 | qk_head_ch = head_ch 110 | q_ch = heads * qk_head_ch 111 | k_ch = (1 if single_kv_head else heads) * qk_head_ch 112 | v_ch = (1 if single_kv_head else heads) * head_ch 113 | inner_ch = heads * head_ch 114 | assert inner_ch == in_ch 115 | use_context = context_ch is not None 116 | context_ch = default(context_ch, in_ch) 117 | 118 | if attn_backend == 'xformers': 119 | assert XFORMERS_AVAIL 120 | if attn_backend == 'torch': 121 | assert torch.version.__version__ >= '2.0.0' 122 | 123 | self.heads = heads 124 | self.self_cross = self_cross 125 | self.single_kv_head = single_kv_head 126 | self.attn = ATTN_FUNCTION[attn_backend] 127 | self.memory_layout = MEMORY_LAYOUTS[attn_backend] 128 | self.cosine_attn = cosine_attn 129 | 130 | if cosine_attn: 131 | self.scale = nn.Parameter(torch.ones(MEMORY_LAYOUTS[attn_backend][2](heads))) 132 | else: 133 | self.scale = None 134 | 135 | self.q = nn.Linear(in_ch, q_ch, bias=False) 136 | if self_cross and use_context: 137 | self.k = nn.Linear(in_ch, k_ch, bias=False) 138 | self.v = nn.Linear(in_ch, v_ch, bias=False) 139 | self.ck = nn.Linear(context_ch, k_ch, bias=False) 140 | self.cv = nn.Linear(context_ch, v_ch, bias=False) 141 | else: 142 | self.k = nn.Linear(context_ch, k_ch, bias=False) 143 | self.v = nn.Linear(context_ch, v_ch, bias=False) 144 | 145 | self.out = nn.Linear(inner_ch, in_ch) 146 | 147 | def forward(self, x:torch.Tensor, context=None, mask=None): 148 | # Input Projection 149 | heads = self.heads 150 | q = self.q(x) 151 | if self.self_cross: 152 | k = self.k(x) 153 | v = self.v(x) 154 | if context is not None: 155 | ck = self.ck(context) 156 | cv = self.cv(context) 157 | k = torch.concat([k, ck], dim=1) 158 | v = torch.concat([v, cv], dim=1) 159 | else: 160 | ctx = default(context, x) 161 | k = self.k(ctx) 162 | v = self.v(ctx) 163 | 164 | # Rearrange for Attention 165 | q = rearrange(q, self.memory_layout[0], h=heads) 166 | if self.single_kv_head: 167 | k = k.unsqueeze(1) 168 | v = v.unsqueeze(1) 169 | 170 | b, _, seq, _ = k.shape 171 | k = k.expand(b, heads, seq, k.size(3)) 172 | v = v.expand(b, heads, seq, v.size(3)) 173 | else: 174 | k = rearrange(k, self.memory_layout[0], h=heads) 175 | v = rearrange(v, self.memory_layout[0], h=heads) 176 | 177 | if self.cosine_attn: 178 | q = (F.normalize(q, dim=-1) * 
math.sqrt(q.size(-1))).to(v.dtype) 179 | k = (F.normalize(k, dim=-1) * self.scale).to(v.dtype) 180 | 181 | # Attention 182 | out = self.attn( 183 | q.contiguous(), k.contiguous(), v.contiguous(), mask 184 | ) 185 | 186 | # Output Projection 187 | out = rearrange(out, self.memory_layout[1], h=heads) 188 | return self.out(out) 189 | 190 | 191 | class TransformerBlock(nn.Module): 192 | def __init__( 193 | self, 194 | dim, 195 | n_heads, 196 | d_head, 197 | context_dim=None, 198 | gated_ff=True, 199 | self_cross=False, 200 | single_kv_head=False, 201 | attn_backend='torch', 202 | cosine_attn=False, 203 | qk_head_ch=-1, 204 | disable_self_attn=False, 205 | single_attn=False, 206 | ): 207 | super().__init__() 208 | self.single_attn = single_attn 209 | self.disable_self_attn = disable_self_attn or single_attn 210 | 211 | self.norm1 = nn.LayerNorm(dim) 212 | self.attn1 = Attention( 213 | dim, context_dim if self.disable_self_attn else None, n_heads, d_head, 214 | self_cross, single_kv_head, attn_backend, cosine_attn, qk_head_ch 215 | ) 216 | if not single_attn: 217 | self.norm2 = nn.LayerNorm(dim) 218 | self.attn2 = Attention( 219 | dim, context_dim, n_heads, d_head, 220 | self_cross, single_kv_head, attn_backend, cosine_attn, qk_head_ch 221 | ) 222 | self.norm3 = nn.LayerNorm(dim) 223 | self.ff = FeedForward(dim, glu=gated_ff) 224 | 225 | def forward(self, x, context=None): 226 | x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x 227 | if not self.single_attn and (context is not None or self.attn2.self_cross): 228 | x = self.attn2(self.norm2(x), context=context) + x 229 | x = self.ff(self.norm3(x)) + x 230 | return x -------------------------------------------------------------------------------- /modules/hypernet.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import * 3 | import numpy as np 4 | import torch 5 | import platform 6 | 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.utils.checkpoint as checkpoint 10 | 11 | from torchvision.transforms.functional import resize 12 | 13 | from timm import create_model 14 | # model download: https://github.com/huggingface/pytorch-image-models 15 | 16 | # from einops import rearrange 17 | 18 | from .attention import TransformerBlock 19 | from .light_lora import LoRALinearLayer 20 | 21 | 22 | def _get_sinusoid_encoding_table(n_position, d_hid): 23 | ''' Sinusoid position encoding table ''' 24 | 25 | def get_position_angle_vec(position): 26 | # this part calculate the position In brackets 27 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] 28 | 29 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) 30 | # [:, 0::2] are all even subscripts, is dim_2i 31 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 32 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 33 | 34 | return torch.FloatTensor(sinusoid_table).unsqueeze(0) 35 | 36 | 37 | class WeightDecoder(nn.Module): 38 | def __init__( 39 | self, 40 | weight_dim: int = 150, 41 | weight_num: int = 168, 42 | decoder_blocks: int = 4, 43 | add_constant: bool = False, 44 | ): 45 | super(WeightDecoder, self).__init__() 46 | self.weight_num = weight_num 47 | self.weight_dim = weight_dim 48 | 49 | self.register_buffer( 50 | 'block_pos_emb', 51 | _get_sinusoid_encoding_table(weight_num * 2, weight_dim) 52 | ) 53 | 54 | # calc heads for mem-eff or flash_attn 55 | heads = 
1 56 | while weight_dim % heads == 0 and weight_dim // heads > 64: 57 | heads *= 2 58 | heads //= 2 59 | 60 | self.pos_emb_proj = nn.Linear(weight_dim, weight_dim, bias=False) 61 | self.decoder_model = nn.ModuleList( 62 | TransformerBlock(weight_dim, heads, weight_dim // heads, context_dim=weight_dim, gated_ff=False) 63 | for _ in range(decoder_blocks) 64 | ) 65 | # self.delta_proj = nn.Linear(weight_dim, weight_dim, bias=False) 66 | self.delta_proj = nn.Sequential( 67 | nn.LayerNorm(weight_dim), 68 | nn.Linear(weight_dim, weight_dim, bias=False) 69 | ) 70 | self.init_weights(add_constant) 71 | 72 | def init_weights(self, add_constant: bool = False): 73 | def basic_init(module): 74 | if isinstance(module, nn.Linear): 75 | nn.init.xavier_uniform_(module.weight) 76 | if module.bias is not None: 77 | nn.init.constant_(module.bias, 0) 78 | 79 | self.apply(basic_init) 80 | 81 | # For no pre-optimized training, you should consider use the following init 82 | # with self.down = down@down_aux + 1 in LiLoRAAttnProcessor 83 | # if add_constant: 84 | # torch.nn.init.constant_(self.delta_proj[1].weight, 0) 85 | 86 | # advice from Nataniel Ruiz, looks like 1e-3 is small enough 87 | # else: 88 | # torch.nn.init.normal_(self.delta_proj[1].weight, std=1e-3) 89 | torch.nn.init.normal_(self.delta_proj[1].weight, std=1e-3) 90 | 91 | def forward(self, weight, features): 92 | pos_emb = self.pos_emb_proj(self.block_pos_emb[:, :weight.size(1)].clone().detach()) 93 | h = weight + pos_emb 94 | for decoder in self.decoder_model: 95 | h = decoder(h, context=features) 96 | weight = weight + self.delta_proj(h) 97 | return weight 98 | 99 | 100 | class ImgWeightGenerator(nn.Module): 101 | def __init__( 102 | self, 103 | encoder_model_name: str = "vit_base_patch16_224", 104 | train_encoder: bool = False, 105 | reference_size: int = 224, 106 | weight_dim: int = 240, 107 | weight_num: int = 176, 108 | decoder_blocks: int = 4, 109 | sample_iters: int = 1, 110 | add_constant: bool = False, 111 | ): 112 | super(ImgWeightGenerator, self).__init__() 113 | self.ref_size = reference_size 114 | self.weight_num = weight_num 115 | self.weight_dim = weight_dim 116 | self.sample_iters = sample_iters 117 | self.train_encoder = train_encoder 118 | self.encoder_model_name = encoder_model_name 119 | self.register_buffer( 120 | 'block_pos_emb', 121 | _get_sinusoid_encoding_table(weight_num * 2, weight_dim) 122 | ) 123 | # operating system 124 | if platform.system().lower() == "linux": 125 | model_dir = "projects/AIGC/model_zoo" 126 | elif platform.system().lower() == "darwin": 127 | model_dir = "workspace/model_zoo" 128 | else: 129 | raise ValueError('Operating system does not support!') 130 | 131 | # vit encoder 132 | print("encoder_model_name: ", encoder_model_name) 133 | if encoder_model_name == "vit_base_patch16_224": 134 | # forward_features: [1, 197, 768] 135 | checkpoint_path = os.path.join(model_dir, "B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz") 136 | elif encoder_model_name == "models--timm--vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k.safetensors": 137 | # forward_features: [1, 257, 1280] 138 | checkpoint_path = os.path.join(model_dir, "models--timm--vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k.safetensors") 139 | elif encoder_model_name == "vit_huge_patch14_clip_336.laion2b_ft_in12k_in1k": 140 | # forward_features: [1, 577, 1280] 141 | checkpoint_path = os.path.join(model_dir, "models--timm--vit_huge_patch14_clip_336.laion2b_ft_in12k_in1k.safetensors") 142 | else: 
143 | raise ValueError('%s does not exist!'%encoder_model_name) 144 | # Create ViT Model 145 | self.encoder_model: nn.Module = create_model(encoder_model_name, checkpoint_path=checkpoint_path) 146 | # self.encoder_model: nn.Module = create_model(encoder_model_name, pretrained=True) 147 | for p in self.encoder_model.parameters(): 148 | p.requires_grad_(train_encoder) 149 | 150 | # check encoder model shape and format 151 | test_input = torch.randn(1, 3, reference_size, reference_size) 152 | test_output = self.encoder_model.forward_features(test_input) 153 | if isinstance(test_output, list): 154 | test_output = test_output[-1] 155 | if len(test_output.shape) == 4: 156 | # B, C, H, W -> B, L, C 157 | test_output = test_output.view(1, test_output.size(1), -1).transpose(1, 2) 158 | 159 | self.feature_proj = nn.Linear(test_output.shape[-1], weight_dim, bias=False) 160 | self.pos_emb_proj = nn.Linear(weight_dim, weight_dim, bias=False) 161 | self.decoder_model = WeightDecoder(weight_dim, weight_num, decoder_blocks, add_constant) 162 | 163 | def encode_features(self, ref_img): 164 | # ref_img = resize(ref_img, [self.ref_size, self.ref_size], antialias=True) 165 | if not self.train_encoder: 166 | with torch.no_grad(): 167 | img_features = self.encoder_model.forward_features(ref_img) 168 | else: 169 | img_features = self.encoder_model.forward_features(ref_img) 170 | if isinstance(img_features, list): 171 | img_features = img_features[-1] 172 | if len(img_features.shape) == 4: 173 | # B, C, H, W -> B, L, C 174 | img_features = img_features.view(img_features.size(0), img_features.size(1), -1).transpose(1, 2) 175 | # print("img_features", img_features) 176 | return img_features 177 | 178 | def decode_weight(self, img_features, iters=None, weight=None): 179 | img_features = self.feature_proj(img_features) 180 | 181 | if weight is None: 182 | weight = torch.zeros( 183 | img_features.size(0), self.weight_num, self.weight_dim, 184 | device=img_features.device 185 | ) 186 | 187 | for i in range(iters or self.sample_iters): 188 | weight = self.decoder_model(weight, img_features) 189 | return weight 190 | 191 | def forward(self, ref_img, iters=None, weight=None, ensure_grad=0): 192 | img_features = self.encode_features(ref_img) + ensure_grad 193 | weight = self.decode_weight(img_features, iters, weight) 194 | return weight 195 | 196 | 197 | class HyperDream(nn.Module): 198 | def __init__( 199 | self, 200 | img_encoder_model_name: str = "vit_base_patch16_224", 201 | ref_img_size: int = 224, 202 | weight_dim: int = 240, 203 | weight_num: int = 176, 204 | decoder_blocks: int = 4, 205 | sample_iters: int = 4, 206 | add_constant: bool = False, 207 | train_encoder: bool = False, 208 | ): 209 | super(HyperDream, self).__init__() 210 | self.img_weight_generator = ImgWeightGenerator( 211 | encoder_model_name=img_encoder_model_name, 212 | reference_size=ref_img_size, 213 | weight_dim=weight_dim, 214 | weight_num=weight_num, 215 | decoder_blocks=decoder_blocks, 216 | sample_iters=sample_iters, 217 | train_encoder=train_encoder, 218 | add_constant=add_constant, 219 | ) 220 | self.weight_dim = weight_dim 221 | self.add_constant = add_constant 222 | self.liloras: Dict[str, LoRALinearLayer] = {} 223 | self.liloras_keys: List[str] = [] 224 | self.gradient_checkpointing = False 225 | 226 | def enable_gradient_checkpointing(self): 227 | self.gradient_checkpointing = True 228 | 229 | def train_params(self): 230 | return [p for p in self.parameters() if p.requires_grad] 231 | 232 | def set_lilora(self, liloras): 233 | 
self.liloras = liloras 234 | if isinstance(liloras, dict): 235 | self.liloras_keys = list(liloras.keys()) # for fixed order 236 | else: 237 | self.liloras_keys = range(len(liloras)) 238 | length = len(self.liloras_keys) 239 | print(f"Total LiLoRAs: {length}, Hypernet params for each image: {length * self.weight_dim}") 240 | 241 | def gen_weight(self, ref_img, iters, weight, ensure_grad=0): 242 | weights = self.img_weight_generator(ref_img, iters, weight, ensure_grad) 243 | weight_list = weights.split(1, dim=1) # [b, n, dim] -> n*[b, 1, dim] 244 | return weights, [weight.squeeze(1) for weight in weight_list] 245 | 246 | def forward(self, ref_img: torch.Tensor, iters: int = None, weight: torch.Tensor = None): 247 | if self.training and self.gradient_checkpointing: 248 | ensure_grad = torch.zeros(1, device=ref_img.device).requires_grad_(True) 249 | weights, weight_list = checkpoint.checkpoint( 250 | self.gen_weight, ref_img, iters, weight, ensure_grad 251 | ) 252 | else: 253 | weights, weight_list = self.gen_weight(ref_img, iters, weight) 254 | 255 | # for key, weight in zip(self.liloras_keys, weight_list): 256 | # self.liloras[key].update_weight(weight, self.add_constant) 257 | 258 | return weights, weight_list 259 | 260 | 261 | class PreOptHyperDream(nn.Module): 262 | def __init__( 263 | self, 264 | rank: int = 1, 265 | down_dim: int = 100, 266 | up_dim: int = 50, 267 | ): 268 | super(PreOptHyperDream, self).__init__() 269 | self.weights = nn.Parameter(torch.tensor(0.0)) 270 | self.rank = rank 271 | self.down_dim = down_dim 272 | self.up_dim = up_dim 273 | self.params_per_lora = (down_dim + up_dim) * rank 274 | self.liloras: Dict[str, LoRALinearLayer] = {} 275 | self.liloras_keys: List[str] = [] 276 | self.gradient_checkpointing = False 277 | self.device = 'cpu' 278 | 279 | def enable_gradient_checkpointing(self): 280 | self.gradient_checkpointing = True 281 | 282 | def train_params(self): 283 | return [p for p in self.parameters() if p.requires_grad] 284 | 285 | def set_device(self, device): 286 | self.device = device 287 | 288 | def set_lilora(self, liloras, identities=1): 289 | self.liloras = liloras 290 | if isinstance(liloras, dict): 291 | self.liloras_keys = list(liloras.keys()) # for fixed order 292 | elif isinstance(liloras, list): 293 | self.liloras_keys = range(len(liloras)) 294 | else: 295 | raise TypeError("liloras only support dict and list!") 296 | 297 | 298 | length = len(self.liloras_keys) 299 | print(f"Total LiLoRAs: {length}, Pre-Optimized params for each image: {length * self.params_per_lora}") 300 | print(f"Pre-Optimized params: {length * self.params_per_lora * identities / 1e6:.1f}M") 301 | del self.weights 302 | 303 | self.length = length 304 | self.weights = nn.ParameterList( 305 | torch.concat([ 306 | torch.randn(1, length, self.down_dim * self.rank), 307 | torch.zeros(1, length, self.up_dim * self.rank) 308 | ], dim=-1) for _ in range(identities) 309 | ) 310 | 311 | def gen_weight(self, identities: torch.Tensor): 312 | weights = torch.concat([self.weights[id] for id in identities], dim=0).to(self.device) 313 | weight_list = weights.split(1, dim=1) # [b, n, dim] -> n*[b, 1, dim] -> n*[b, dim] 314 | return weights, [weight.squeeze(1) for weight in weight_list] 315 | 316 | def forward(self, identities: torch.Tensor): 317 | if self.training and self.gradient_checkpointing: 318 | weights, weight_list = checkpoint.checkpoint( 319 | self.gen_weight, identities 320 | ) 321 | else: 322 | weights, weight_list = self.gen_weight(identities) 323 | 324 | # for key, weight in 
zip(self.liloras_keys, weight_list): 325 | # self.liloras[key].update_weight(weight) 326 | 327 | return weights, weight_list -------------------------------------------------------------------------------- /modules/hypernet_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import numpy as np 4 | import timm 5 | import platform 6 | from torch.nn import TransformerDecoder, TransformerDecoderLayer 7 | import torch.nn as nn 8 | from torchvision import transforms 9 | from torchvision.transforms.functional import resize 10 | from PIL import Image 11 | 12 | # Create a transform object that chains the preprocessing steps 13 | transform = transforms.Compose([ 14 | transforms.Resize((224, 224)), # Resize the image to the size the model expects 15 | transforms.ToTensor(), # Convert a PIL Image or ndarray to a tensor and scale it to [0, 1] 16 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 17 | # Normalize the image; the mean and std here are the ImageNet statistics 18 | ]) 19 | 20 | 21 | class PositionalEncoding(nn.Module): 22 | """ 23 | PositionalEncoding expects inputs of shape (time, batch, channel). 24 | If the input is (batch, time, channel), first transpose it to (time, batch, channel) with torch.transpose or torch.permute, 25 | then pass it to PositionalEncoding. 26 | """ 27 | def __init__(self, d_model, dropout=0.1, max_len=5000): 28 | super(PositionalEncoding, self).__init__() 29 | self.dropout = nn.Dropout(p=dropout) 30 | pe = torch.zeros(max_len, d_model) 31 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 32 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) 33 | pe[:, 0::2] = torch.sin(position * div_term) 34 | pe[:, 1::2] = torch.cos(position * div_term) 35 | pe = pe.unsqueeze(0).transpose(0, 1) 36 | self.register_buffer('pe', pe) 37 | 38 | def forward(self, x): 39 | x = x + self.pe[:x.size(0), :] 40 | return self.dropout(x) 41 | 42 | 43 | class VisualImageEncoder(torch.nn.Module): 44 | def __init__(self): 45 | global checkpoint_path 46 | super(VisualImageEncoder, self).__init__() 47 | # Load a pretrained ViT encoder 48 | # self.vit_encoder = timm.create_model('vit_huge_patch16_224', pretrained=True) 49 | # self.vit_encoder = timm.create_model('vit_base_patch16_224', pretrained=True) 50 | # TODO: vit encoder 51 | if platform.system().lower() == "linux": 52 | checkpoint_path = "projects/AIGC/models/liloras/vit_base_patch16_224/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz" 53 | elif platform.system().lower() == "darwin": 54 | checkpoint_path = "model_zoo/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz" 55 | self.vit_encoder: nn.Module = timm.create_model(model_name="vit_base_patch16_224", 56 | checkpoint_path=checkpoint_path) 57 | 58 | def forward(self, x): 59 | """ 60 | The input image size is 224x224. 61 | """ 62 | features = self.vit_encoder.forward_features(x) # Extract image features as a 3D (B, T, C) tensor 63 | return features 64 | 65 | 66 | class WeightTransformerDecoder(nn.Module): 67 | def __init__(self, d_model, nhead, num_layers): 68 | super(WeightTransformerDecoder, self).__init__() 69 | self.model_type = 'Transformer' 70 | self.src_mask = None 71 | self.d_model = d_model 72 | self.pos_encoder = PositionalEncoding(d_model) 73 | decoder_layer = TransformerDecoderLayer(d_model, nhead) 74 | self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers) 75 | self.weight_proj = nn.Sequential( 76 | nn.LayerNorm(d_model), 77 | nn.Linear(d_model, d_model, bias=False) 78 | ) 79 | 80 | # Initialize the linear layers' weights from a uniform distribution 81 | for name, param in 
self.transformer_decoder.named_parameters(): 82 | if 'linear' in name and 'weight' in name: 83 | nn.init.uniform_(param.data) 84 | 85 | # Initialize the bias terms to zero 86 | for name, param in self.transformer_decoder.named_parameters(): 87 | if 'linear' in name and 'bias' in name: 88 | nn.init.zeros_(param.data) 89 | 90 | 91 | 92 | def forward(self, weight_embedding, face_embedding): 93 | """ 94 | # Create a random weight_embedding and face_embedding, e.g.: 95 | weight_embedding = torch.rand(seq_length, batch_size, embedding_dim) 96 | face_embedding = torch.rand(seq_length, batch_size, embedding_dim) 97 | """ 98 | # if self.src_mask is None or self.src_mask.size(0) != len(weight_embedding): 99 | # device = weight_embedding.device 100 | # mask = self._generate_square_subsequent_mask(len(weight_embedding)).to(device) 101 | # self.src_mask = mask 102 | # hidden_embedding = self.transformer_decoder(pos_embedding, face_embedding, self.src_mask) 103 | 104 | pos_embedding = self.pos_encoder(weight_embedding) 105 | hidden_embedding = self.transformer_decoder(pos_embedding, face_embedding) 106 | output = self.weight_proj(hidden_embedding) 107 | return output 108 | 109 | 110 | def _generate_square_subsequent_mask(self, sz): 111 | mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) 112 | mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) 113 | return mask 114 | 115 | 116 | class HyperNetwork(torch.nn.Module): 117 | def __init__( 118 | self, 119 | ref_img_size: tuple[int] = (224, 224), 120 | rank: int = 1, 121 | down_dim: int = 128, 122 | up_dim: int = 64, 123 | weight_num: int = 128, 124 | iters: int = 4, 125 | train_encoder: bool = False): 126 | super(HyperNetwork, self).__init__() 127 | 128 | self.weight_dim = (down_dim + up_dim) * rank 129 | self.weight_num = weight_num 130 | self.iters = iters 131 | self.train_encoder = train_encoder 132 | self.ref_img_size = ref_img_size 133 | self.visual_image_encoder = VisualImageEncoder() 134 | self.weight_transformer_decoder = WeightTransformerDecoder(d_model=self.weight_dim, nhead=8, num_layers=4) 135 | total_params = sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6 136 | print('Number of hypernetwork parameters: {:.2f}M'.format(total_params)) 137 | 138 | # check encoder model shape and format 139 | test_input = torch.randn(1, 3, *ref_img_size) 140 | test_output = self.visual_image_encoder(test_input) 141 | if len(test_output.shape) == 3: 142 | # default shape in (B,T,C) 143 | pass 144 | elif len(test_output.shape) == 4: 145 | # B, C, H, W -> B, T, C 146 | test_output = test_output.view(1, test_output.size(1), -1).transpose(1, 2) 147 | else: 148 | raise ValueError("Output dimension must be 3 or 4") 149 | # Sized according to the encoder's output feature dimension 150 | feature_dim = test_output.size(-1) 151 | self.feature_proj = nn.Linear(feature_dim, self.weight_dim, bias=False) 152 | 153 | if not train_encoder: 154 | # Make visual_image_encoder non-trainable 155 | for param in self.visual_image_encoder.parameters(): 156 | param.requires_grad = False 157 | # Put visual_image_encoder in eval mode 158 | self.visual_image_encoder.eval() 159 | 160 | def train(self, mode=True): 161 | super().train(mode) 162 | if not self.train_encoder: 163 | self.visual_image_encoder.eval() # Keep visual_image_encoder in eval mode at all times 164 | 165 | def train_params(self): 166 | return [p for p in self.parameters() if p.requires_grad] 167 | 168 | def forward(self, x): 169 | x = resize(x, self.ref_img_size, antialias=True) 170 | # batch first 171 | image_features = self.visual_image_encoder(x) 172 | # print("image_features:", 
image_features, image_features.shape) 173 | face_embedding = self.feature_proj(image_features) 174 | # weight_embedding zero initialization 175 | weight_embedding = torch.zeros(face_embedding.size(0), self.weight_num, self.weight_dim, 176 | device=image_features.device) 177 | # batch first to time first 178 | face_embedding = face_embedding.permute(1, 0, 2) 179 | weight_embedding = weight_embedding.permute(1, 0, 2) 180 | # Iterative Prediction 181 | for i in range(self.iters): 182 | weight_embedding += self.weight_transformer_decoder(weight_embedding, face_embedding) 183 | # print("weight_embedding_%d"%i, weight_embedding) 184 | # print("weight_embedding Prediction", weight_embedding.shape) 185 | # time first to batch first 186 | weight_embedding = weight_embedding.permute(1, 0, 2) 187 | # weight = self.weight_proj(weight_embedding) 188 | print("weight:",weight_embedding, weight_embedding.shape) 189 | return weight_embedding 190 | 191 | -------------------------------------------------------------------------------- /modules/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | import torch.linalg as linalg 10 | 11 | from tqdm import tqdm 12 | 13 | 14 | def default(val, d): 15 | return val if val is not None else d -------------------------------------------------------------------------------- /modules/utils/lora_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Any, Dict, Iterable, Optional, Union 3 | from diffusers.models import UNet2DConditionModel 4 | 5 | def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]: 6 | r""" 7 | Returns: 8 | A state dict containing just the LoRA parameters. 9 | """ 10 | lora_state_dict = {} 11 | 12 | for name, module in unet.named_modules(): 13 | if hasattr(module, "set_lora_layer"): 14 | lora_layer = getattr(module, "lora_layer") 15 | if lora_layer is not None: 16 | current_lora_layer_sd = lora_layer.state_dict() 17 | for lora_layer_matrix_name, lora_param in current_lora_layer_sd.items(): 18 | # The matrix name can either be "down" or "up". 
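# Illustrative note (assuming the usual diffusers LoRALinearLayer layout, where the layer's
# state_dict keys are "down.weight" / "up.weight"): a resulting entry typically looks like
#   "down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.lora.down.weight"
# before the "unet." prefix is added at save time.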
19 | lora_state_dict[f"{name}.lora.{lora_layer_matrix_name}"] = lora_param 20 | 21 | return lora_state_dict 22 | 23 | 24 | def text_encoder_lora_state_dict(text_encoder, patch_mlp=False): 25 | state_dict = {} 26 | 27 | def text_encoder_attn_modules(text_encoder): 28 | from transformers import CLIPTextModel, CLIPTextModelWithProjection 29 | 30 | attn_modules = [] 31 | 32 | if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): 33 | for i, layer in enumerate(text_encoder.text_model.encoder.layers): 34 | name = f"text_model.encoder.layers.{i}.self_attn" 35 | mod = layer.self_attn 36 | attn_modules.append((name, mod)) 37 | 38 | return attn_modules 39 | 40 | def text_encoder_mlp_modules(text_encoder): 41 | from transformers import CLIPTextModel, CLIPTextModelWithProjection 42 | mlp_modules = [] 43 | if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): 44 | for i, layer in enumerate(text_encoder.text_model.encoder.layers): 45 | name = f"text_model.encoder.layers.{i}.mlp" 46 | mod = layer.mlp 47 | mlp_modules.append((name, mod)) 48 | 49 | return mlp_modules 50 | 51 | # text encoder attn layer 52 | for name, module in text_encoder_attn_modules(text_encoder): 53 | for k, v in module.q_proj.lora_linear_layer.state_dict().items(): 54 | state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v 55 | 56 | for k, v in module.k_proj.lora_linear_layer.state_dict().items(): 57 | state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v 58 | 59 | for k, v in module.v_proj.lora_linear_layer.state_dict().items(): 60 | state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v 61 | 62 | for k, v in module.out_proj.lora_linear_layer.state_dict().items(): 63 | state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v 64 | 65 | # text encoder mlp layer 66 | if patch_mlp: 67 | for name, module in text_encoder_mlp_modules(text_encoder): 68 | for k, v in module.fc1.lora_linear_layer.state_dict().items(): 69 | state_dict[f"{name}.fc1.lora_linear_layer.{k}"] = v 70 | 71 | for k, v in module.fc2.lora_linear_layer.state_dict().items(): 72 | state_dict[f"{name}.fc2.lora_linear_layer.{k}"] = v 73 | 74 | return state_dict 75 | -------------------------------------------------------------------------------- /modules/utils/xformers_utils.py: -------------------------------------------------------------------------------- 1 | memory_efficient_attention = None 2 | try: 3 | import xformers 4 | except: 5 | pass 6 | 7 | try: 8 | from xformers.ops import memory_efficient_attention 9 | XFORMERS_AVAIL = True 10 | except: 11 | memory_efficient_attention = None 12 | XFORMERS_AVAIL = False 13 | -------------------------------------------------------------------------------- /rank_relax.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import itertools 4 | from PIL import Image 5 | from torchvision.transforms import CenterCrop, Resize, ToTensor, Compose, Normalize 6 | from diffusers.models.attention_processor import ( 7 | AttnAddedKVProcessor, 8 | AttnAddedKVProcessor2_0, 9 | SlicedAttnAddedKVProcessor, 10 | ) 11 | from diffusers import StableDiffusionPipeline 12 | import torch.utils.checkpoint 13 | from modules.relax_lora import LoRALinearLayer, LoraLoaderMixin 14 | from modules.utils.lora_utils import unet_lora_state_dict, text_encoder_lora_state_dict 15 | from modules.hypernet import HyperDream 16 | 17 | pretrain_model_path = "stable-diffusion-models/realisticVisionV40_v40VAE" 18 | hypernet_model_path = 
"projects/AIGC/experiments2/hypernet/CelebA-HQ-10k-pretrain2/checkpoint-250000" 19 | reference_image_path = "projects/AIGC/dataset/FFHQ_test/00019.png" 20 | output_dir = "projects/AIGC/experiments2/rank_relax" 21 | 22 | # Parameter Settings 23 | train_text_encoder = True 24 | patch_mlp = False 25 | down_dim = 160 26 | up_dim = 80 27 | rank_in = 1 # hypernet output 28 | rank_out = 4 # rank relax output 29 | # vit_model_name = "vit_base_patch16_224" 30 | # vit_model_name = "vit_huge_patch14_clip_224" 31 | vit_model_name = "vit_huge_patch14_clip_336" 32 | 33 | t0 = time.time() 34 | # TODO: 1.load predicted lora weights 35 | pipe = StableDiffusionPipeline.from_pretrained(pretrain_model_path, torch_dtype=torch.float32) 36 | # state_dict, network_alphas = pipe.lora_state_dict(lora_model_path) 37 | pipe.to("cuda") 38 | unet = pipe.unet 39 | text_encoder = pipe.text_encoder 40 | 41 | # TODO: 2.Create rank_relaxed LoRA 42 | unet_lora_parameters = [] 43 | unet_lora_linear_layers = [] 44 | print("Create a combined LoRA consisted of Frozen LoRA and Trainable LoRA.") 45 | for i, (attn_processor_name, attn_processor) in enumerate(unet.attn_processors.items()): 46 | print("unet.attn_processor->%d:%s" % (i, attn_processor_name), attn_processor) 47 | # attn_processor_name: mid_block.attentions.0.transformer_blocks.0.attn1.processor 48 | # Parse the attention module. 49 | attn_module = unet 50 | for n in attn_processor_name.split(".")[:-1]: 51 | attn_module = getattr(attn_module, n) 52 | print("attn_module:", attn_module) 53 | 54 | # Set the `lora_layer` attribute of the attention-related matrices. 55 | attn_module.to_q.set_lora_layer( 56 | LoRALinearLayer( 57 | in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=rank_out 58 | ) 59 | ) 60 | attn_module.to_k.set_lora_layer( 61 | LoRALinearLayer( 62 | in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=rank_out 63 | ) 64 | ) 65 | attn_module.to_v.set_lora_layer( 66 | LoRALinearLayer( 67 | in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=rank_out 68 | ) 69 | ) 70 | attn_module.to_out[0].set_lora_layer( 71 | LoRALinearLayer( 72 | in_features=attn_module.to_out[0].in_features, out_features=attn_module.to_out[0].out_features, 73 | rank=rank_out, 74 | ) 75 | ) 76 | # Accumulate the LoRA params to optimize. 77 | unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters()) 78 | unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters()) 79 | unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters()) 80 | unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters()) 81 | 82 | # Accumulate the LoRALinerLayer to optimize. 
83 | unet_lora_linear_layers.append(attn_module.to_q.lora_layer) 84 | unet_lora_linear_layers.append(attn_module.to_k.lora_layer) 85 | unet_lora_linear_layers.append(attn_module.to_v.lora_layer) 86 | unet_lora_linear_layers.append(attn_module.to_out[0].lora_layer) 87 | 88 | if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)): 89 | attn_module.add_k_proj.set_lora_layer( 90 | LoRALinearLayer( 91 | in_features=attn_module.add_k_proj.in_features, 92 | out_features=attn_module.add_k_proj.out_features, 93 | rank=rank_out, 94 | ) 95 | ) 96 | attn_module.add_v_proj.set_lora_layer( 97 | LoRALinearLayer( 98 | in_features=attn_module.add_v_proj.in_features, 99 | out_features=attn_module.add_v_proj.out_features, 100 | rank=rank_out, 101 | ) 102 | ) 103 | unet_lora_parameters.extend(attn_module.add_k_proj.lora_layer.parameters()) 104 | unet_lora_parameters.extend(attn_module.add_v_proj.lora_layer.parameters()) 105 | 106 | unet_lora_linear_layers.append(attn_module.add_k_proj.lora_layer) 107 | unet_lora_linear_layers.append(attn_module.add_v_proj.lora_layer) 108 | 109 | # The text encoder comes from 🤗 transformers, so we cannot directly modify it. 110 | # So, instead, we monkey-patch the forward calls of its attention-blocks. 111 | if train_text_encoder: 112 | # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 113 | # if patch_mlp is True, the finetuning will cover the text encoder mlp, otherwise only the text encoder attention, total lora is (12+12)*4=96 114 | # if state_dict is not None, the frozen linear will be initializaed. 115 | text_lora_parameters, text_encoder_lora_linear_layers = LoraLoaderMixin._modify_text_encoder(text_encoder, 116 | state_dict=None, 117 | dtype=torch.float32, 118 | rank=rank_out, 119 | patch_mlp=patch_mlp) 120 | # print(text_encoder_lora_linear_layers) 121 | 122 | # total loras 123 | lora_linear_layers = unet_lora_linear_layers + text_encoder_lora_linear_layers if train_text_encoder else unet_lora_linear_layers 124 | print("========================================================================") 125 | t1 = time.time() 126 | 127 | # TODO: 3.Convert rank_lora to a standard LoRA 128 | print("Create Hypernet...") 129 | if vit_model_name == "vit_base_patch16_224": 130 | img_encoder_model_name = "vit_base_patch16_224" 131 | ref_img_size = 224 132 | mean = [0.5000] 133 | std = [0.5000] 134 | elif vit_model_name == "vit_huge_patch14_clip_224": 135 | img_encoder_model_name = "vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k" 136 | ref_img_size = 224 137 | mean = [0.4815, 0.4578, 0.4082] 138 | std = [0.2686, 0.2613, 0.2758] 139 | elif vit_model_name == "vit_huge_patch14_clip_336": 140 | img_encoder_model_name = "vit_huge_patch14_clip_336.laion2b_ft_in12k_in1k" 141 | ref_img_size = 336 142 | mean = [0.4815, 0.4578, 0.4082] 143 | std = [0.2686, 0.2613, 0.2758] 144 | else: 145 | raise ValueError("%s does not supports!" 
% vit_model_name) 146 | 147 | hypernet_transposes = Compose([ 148 | Resize(size=ref_img_size), 149 | CenterCrop(size=(ref_img_size, ref_img_size)), 150 | ToTensor(), 151 | Normalize(mean=mean, std=std), 152 | ]) 153 | 154 | hypernetwork = HyperDream( 155 | img_encoder_model_name=img_encoder_model_name, 156 | ref_img_size=ref_img_size, 157 | weight_num=len(lora_linear_layers), 158 | weight_dim=(up_dim + down_dim) * rank_in, 159 | ) 160 | hypernetwork.set_lilora(lora_linear_layers) 161 | 162 | if os.path.isdir(hypernet_model_path): 163 | path = os.path.join(hypernet_model_path, "hypernetwork.bin") 164 | weight_dict = torch.load(path) 165 | sd = weight_dict['hypernetwork'] 166 | hypernetwork.load_state_dict(sd) 167 | else: 168 | weight_dict = torch.load(hypernet_model_path['hypernetwork']) 169 | sd = weight_dict['hypernetwork'] 170 | hypernetwork.load_state_dict(sd) 171 | 172 | for i, lilora in enumerate(lora_linear_layers): 173 | seed = weight_dict['aux_seed_%d' % i] 174 | down_aux = weight_dict['down_aux_%d' % i] 175 | up_aux = weight_dict['up_aux_%d' % i] 176 | 177 | print(f"Hypernet weights loaded from: {hypernet_model_path}") 178 | 179 | hypernetwork = hypernetwork.to("cuda") 180 | hypernetwork = hypernetwork.eval() 181 | 182 | ref_img = Image.open(reference_image_path).convert("RGB") 183 | ref_img = hypernet_transposes(ref_img).unsqueeze(0).to("cuda") 184 | # warmup 185 | _ = hypernetwork(ref_img) 186 | 187 | t2_0 = time.time() 188 | weight, weight_list = hypernetwork(ref_img) 189 | print("weight>>>>>>>>>>>:", weight.shape, weight) 190 | t2_1 = time.time() 191 | 192 | t111 = time.time() 193 | # convert down and up weights to linear layer as LoRALinearLayer 194 | for i, (weight, lora_layer) in enumerate(zip(weight_list, lora_linear_layers)): 195 | seed = weight_dict['aux_seed_%d' % i] 196 | down_aux = weight_dict['down_aux_%d' % i] 197 | up_aux = weight_dict['up_aux_%d' % i] 198 | # reshape weight 199 | down_weight, up_weight = weight.split([down_dim * rank_in, up_dim * rank_in], dim=-1) 200 | down_weight = down_weight.reshape(rank_in, -1) 201 | up_weight = up_weight.reshape(-1, rank_in) 202 | # make weight, matrix multiplication 203 | down = down_weight @ down_aux 204 | up = up_aux @ up_weight 205 | lora_layer.down.weight.data.copy_(down.to(torch.float32)) 206 | lora_layer.up.weight.data.copy_(up.to(torch.float32)) 207 | t222 = time.time() 208 | 209 | print("Convert to standard LoRA...") 210 | for lora_linear_layer in lora_linear_layers: 211 | lora_linear_layer = lora_linear_layer.to("cuda") 212 | lora_linear_layer.convert_to_standard_lora() 213 | print("========================================================================") 214 | t2 = time.time() 215 | 216 | # TODO: 4.Save standard LoRA 217 | print("Save standard LoRA...") 218 | unet_lora_layers_to_save = unet_lora_state_dict(unet) 219 | text_encoder_lora_layers_to_save = None 220 | if train_text_encoder: 221 | text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(text_encoder, patch_mlp=patch_mlp) 222 | 223 | LoraLoaderMixin.save_lora_weights( 224 | save_directory=output_dir, 225 | unet_lora_layers=unet_lora_layers_to_save, 226 | text_encoder_lora_layers=text_encoder_lora_layers_to_save, 227 | ) 228 | 229 | t3 = time.time() 230 | 231 | print("Successfully save LoRA to: %s" % (output_dir)) 232 | print("load pipeline: %f" % (t1 - t0)) 233 | print("load hypernet: %f" % (t2_0 - t1)) 234 | print("hypernet inference: %f" % (t2_1 - t2_0)) 235 | print("copy weight: %f" % (t222 - t111)) 236 | 237 | print("rank relax: %f" % (t2 - 
t2_1)) 238 | print("model save: %f" % (t3 - t2)) 239 | print("total time: %f" % (t3 - t0)) 240 | print("==================================complete======================================") 241 | -------------------------------------------------------------------------------- /rank_relax_test.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from diffusers.models.attention_processor import ( 4 | AttnAddedKVProcessor, 5 | AttnAddedKVProcessor2_0, 6 | SlicedAttnAddedKVProcessor, 7 | ) 8 | from diffusers import StableDiffusionPipeline 9 | import torch.utils.checkpoint 10 | from modules.relax_lora import LoRALinearLayer, LoraLoaderMixin 11 | from modules.utils.lora_utils import unet_lora_state_dict, text_encoder_lora_state_dict 12 | 13 | t0 = time.time() 14 | 15 | pretrain_model_path="stable-diffusion-models/realisticVisionV40_v40VAE" 16 | lora_model_path = "projects/AIGC/lora_model_test" 17 | output_dir = "projects/AIGC/experiments2/rank_relax" 18 | 19 | train_text_encoder = True 20 | patch_mlp = False 21 | 22 | # TODO: 1.load predicted lora weights 23 | pipe = StableDiffusionPipeline.from_pretrained(pretrain_model_path, torch_dtype=torch.float32) 24 | state_dict, network_alphas = pipe.lora_state_dict(lora_model_path) 25 | pipe.to("cuda") 26 | 27 | unet = pipe.unet 28 | text_encoder = pipe.text_encoder 29 | # print(state_dict.keys()) 30 | # unet.up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_v.lora.down.weight 31 | # unet.down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.lora.down.weight 32 | # text_encoder.text_model.encoder.layers.11.self_attn.out_proj.lora_linear_layer.down.weight 33 | # text_encoder.text_model.encoder.layers.11.mlp.fc1.lora_linear_layer.down.weight 34 | 35 | # TODO: 2.Create rank_relaxed LoRA and initialize the froze linear layer 36 | rank = 4 # the relax lora rank is 4. 37 | unet_lora_parameters = [] 38 | unet_lora_linear_layers = [] 39 | print("Create a combined LoRA consisted of Frozen LoRA and Trainable LoRA.") 40 | for i, (attn_processor_name, attn_processor) in enumerate(unet.attn_processors.items()): 41 | print("unet.attn_processor->%d:%s" % (i, attn_processor_name), attn_processor) 42 | # attn_processor_name: mid_block.attentions.0.transformer_blocks.0.attn1.processor 43 | # Parse the attention module. 44 | attn_module = unet 45 | for n in attn_processor_name.split(".")[:-1]: 46 | attn_module = getattr(attn_module, n) 47 | print("attn_module:",attn_module) 48 | 49 | # Set the `lora_layer` attribute of the attention-related matrices. 50 | attn_module.to_q.set_lora_layer( 51 | LoRALinearLayer( 52 | in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=rank 53 | ) 54 | ) 55 | attn_module.to_k.set_lora_layer( 56 | LoRALinearLayer( 57 | in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=rank 58 | ) 59 | ) 60 | attn_module.to_v.set_lora_layer( 61 | LoRALinearLayer( 62 | in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=rank 63 | ) 64 | ) 65 | attn_module.to_out[0].set_lora_layer( 66 | LoRALinearLayer( 67 | in_features=attn_module.to_out[0].in_features, 68 | out_features=attn_module.to_out[0].out_features, 69 | rank=rank, 70 | ) 71 | ) 72 | # Accumulate the LoRA params to optimize. 
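# Presumably the relax_lora.LoRALinearLayer holds both a frozen LoRA (filled in below from
# the predicted state_dict, per the "Set predicted weights to frozen lora" step) and a
# trainable relaxed LoRA; only the latter is meant to be optimized during fast finetuning.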
73 | unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters()) 74 | unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters()) 75 | unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters()) 76 | unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters()) 77 | 78 | # Accumulate the LoRALinerLayer to optimize. 79 | unet_lora_linear_layers.append(attn_module.to_q.lora_layer) 80 | unet_lora_linear_layers.append(attn_module.to_k.lora_layer) 81 | unet_lora_linear_layers.append(attn_module.to_v.lora_layer) 82 | unet_lora_linear_layers.append(attn_module.to_out[0].lora_layer) 83 | 84 | # Set predicted weights to frozen lora 85 | # static_dict: unet.up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.lora.up.weight, 86 | # static_dict: unet.up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_k.lora.down.weight 87 | # attn_processor_name: mid_block.attentions.0.transformer_blocks.0.attn1.processor 88 | for layer_name in ['to_q', 'to_k', 'to_v', 'to_out']: 89 | attn_processor_name = attn_processor_name.replace('.processor', '') 90 | if layer_name == 'to_out': 91 | layer = getattr(attn_module, layer_name)[0].lora_layer 92 | down_key = "unet.%s.%s.0.lora.down.weight" % (attn_processor_name, layer_name) 93 | up_key = "unet.%s.%s.0.lora.up.weight" % (attn_processor_name, layer_name) 94 | else: 95 | layer = getattr(attn_module, layer_name).lora_layer 96 | down_key = "unet.%s.%s.lora.down.weight" % (attn_processor_name, layer_name) 97 | up_key = "unet.%s.%s.lora.up.weight" % (attn_processor_name, layer_name) 98 | # copy weights 99 | layer.down.weight.data.copy_(state_dict[down_key].to(torch.float32)) 100 | layer.up.weight.data.copy_(state_dict[up_key].to(torch.float32)) 101 | print("unet attention lora initialized!") 102 | 103 | if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)): 104 | attn_module.add_k_proj.set_lora_layer( 105 | LoRALinearLayer( 106 | in_features=attn_module.add_k_proj.in_features, 107 | out_features=attn_module.add_k_proj.out_features, 108 | rank=rank, 109 | ) 110 | ) 111 | attn_module.add_v_proj.set_lora_layer( 112 | LoRALinearLayer( 113 | in_features=attn_module.add_v_proj.in_features, 114 | out_features=attn_module.add_v_proj.out_features, 115 | rank=rank, 116 | ) 117 | ) 118 | unet_lora_parameters.extend(attn_module.add_k_proj.lora_layer.parameters()) 119 | unet_lora_parameters.extend(attn_module.add_v_proj.lora_layer.parameters()) 120 | 121 | unet_lora_linear_layers.append(attn_module.add_k_proj.lora_layer) 122 | unet_lora_linear_layers.append(attn_module.add_v_proj.lora_layer) 123 | 124 | for layer_name in ['add_k_proj', 'add_v_proj']: 125 | attn_processor_name = attn_processor_name.replace('.processor', '') 126 | layer = getattr(attn_module, layer_name).lora_layer 127 | down_key = "unet.%s.%s.lora.down.weight" % (attn_processor_name, layer_name) 128 | up_key = "unet.%s.%s.lora.up.weight" % (attn_processor_name, layer_name) 129 | # copy weights 130 | layer.down.weight.data.copy_(state_dict[down_key].to(torch.float32)) 131 | layer.up.weight.data.copy_(state_dict[up_key].to(torch.float32)) 132 | print("unet add_proj lora initialized!") 133 | 134 | # The text encoder comes from 🤗 transformers, so we cannot directly modify it. 135 | # So, instead, we monkey-patch the forward calls of its attention-blocks. 
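# Purely illustrative sanity check (assumes the diffusers key naming shown in the comments
# near the top of this script): print a few text-encoder LoRA keys from the state dict
# loaded above, before they are used below to initialize the frozen linear layers.
for _key in [k for k in state_dict if k.startswith("text_encoder.")][:4]:
    print("text-encoder lora key:", _key, tuple(state_dict[_key].shape))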
136 | if train_text_encoder: 137 | # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 138 | # if patch_mlp is True, the finetuning will cover the text encoder mlp, otherwise only the text encoder attention, total lora is (12+12)*4=96 139 | # if state_dict is not None, the frozen linear will be initializaed. 140 | text_lora_parameters, text_encoder_lora_linear_layers = LoraLoaderMixin._modify_text_encoder(text_encoder, state_dict, dtype=torch.float32, rank=rank, patch_mlp=patch_mlp) 141 | # print(text_encoder_lora_linear_layers) 142 | 143 | 144 | # TODO: 3.Convert rank_lora to a standard LoRA 145 | print("Convert rank_lora to a standard LoRA...") 146 | lora_linear_layers = unet_lora_linear_layers + text_encoder_lora_linear_layers \ 147 | if train_text_encoder else unet_lora_linear_layers 148 | 149 | for lora_linear_layer in lora_linear_layers: 150 | lora_linear_layer = lora_linear_layer.to("cuda") 151 | lora_linear_layer.convert_to_standard_lora() 152 | 153 | 154 | # TODO: 4.Save standard LoRA 155 | print("Save standard LoRA...") 156 | unet_lora_layers_to_save = unet_lora_state_dict(unet) 157 | text_encoder_lora_layers_to_save = None 158 | if train_text_encoder: 159 | text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(text_encoder, patch_mlp=patch_mlp) 160 | 161 | 162 | LoraLoaderMixin.save_lora_weights( 163 | save_directory=output_dir, 164 | unet_lora_layers=unet_lora_layers_to_save, 165 | text_encoder_lora_layers=text_encoder_lora_layers_to_save, 166 | ) 167 | 168 | t1 = time.time() 169 | 170 | print("Successfully save LoRA to: %s" % (output_dir)) 171 | print("time elapsed: %f"%(t1-t0)) 172 | print("==================================complete======================================") 173 | 174 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.0.1 2 | torchvision 3 | einops 4 | accelerate 5 | timm 6 | transformers 7 | diffusers==0.25.0.dev0 8 | git+https://github.com/huggingface/diffusers.git -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Simple check list from AllenNLP repo: https://github.com/allenai/allennlp/blob/main/setup.py 17 | 18 | To create the package for PyPI. 19 | 20 | 1. Run `make pre-release` (or `make pre-patch` for a patch release) then run `make fix-copies` to fix the index of the 21 | documentation. 22 | 23 | If releasing on a special branch, copy the updated README.md on the main branch for the commit you will make 24 | for the post-release and run `make fix-copies` on the main branch as well. 25 | 26 | 2. Run Tests for Amazon Sagemaker. 
The documentation is located in `./tests/sagemaker/README.md`, otherwise @philschmid. 27 | 28 | 3. Unpin specific versions from setup.py that use a git install. 29 | 30 | 4. Checkout the release branch (v-release, for example v4.19-release), and commit these changes with the 31 | message: "Release: " and push. 32 | 33 | 5. Wait for the tests on main to be completed and be green (otherwise revert and fix bugs). 34 | 35 | 6. Add a tag in git to mark the release: "git tag v -m 'Adds tag v for PyPI'" 36 | Push the tag to git: git push --tags origin v-release 37 | 38 | 7. Build both the sources and the wheel. Do not change anything in setup.py between 39 | creating the wheel and the source distribution (obviously). 40 | 41 | For the wheel, run: "python setup.py bdist_wheel" in the top level directory 42 | (This will build a wheel for the Python version you use to build it). 43 | 44 | For the sources, run: "python setup.py sdist" 45 | You should now have a /dist directory with both .whl and .tar.gz source versions. 46 | 47 | Long story cut short, you need to run both before you can upload the distribution to the 48 | test PyPI and the actual PyPI servers: 49 | 50 | python setup.py bdist_wheel && python setup.py sdist 51 | 52 | 8. Check that everything looks correct by uploading the package to the PyPI test server: 53 | 54 | twine upload dist/* -r pypitest 55 | (pypi suggests using twine as other methods upload files via plaintext.) 56 | You may have to specify the repository url, use the following command then: 57 | twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/ 58 | 59 | Check that you can install it in a virtualenv by running: 60 | pip install -i https://testpypi.python.org/pypi diffusers 61 | 62 | If you are testing from a Colab Notebook, for instance, then do: 63 | pip install diffusers && pip uninstall diffusers 64 | pip install -i https://testpypi.python.org/pypi diffusers 65 | 66 | Check you can run the following commands: 67 | python -c "from diffusers import __version__; print(__version__)" 68 | python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('fusing/unet-ldm-dummy-update'); pipe()" 69 | python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('hf-internal-testing/tiny-stable-diffusion-pipe', safety_checker=None); pipe('ah suh du')" 70 | python -c "from diffusers import *" 71 | 72 | 9. Upload the final version to the actual PyPI: 73 | twine upload dist/* -r pypi 74 | 75 | 10. Prepare the release notes and publish them on GitHub once everything is looking hunky-dory. 76 | 77 | 11. Run `make post-release` (or, for a patch release, `make post-patch`). If you were on a branch for the release, 78 | you need to go back to main before executing this. 79 | """ 80 | 81 | import os 82 | import re 83 | import sys 84 | from distutils.core import Command 85 | 86 | from setuptools import find_packages, setup 87 | 88 | 89 | # IMPORTANT: 90 | # 1. all dependencies should be listed here with their version requirements if any 91 | # 2. 
once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py 92 | _deps = [ 93 | "Pillow", # keep the PIL.Image.Resampling deprecation away 94 | "accelerate>=0.11.0", 95 | "compel==0.1.8", 96 | "datasets", 97 | "filelock", 98 | "flax>=0.4.1", 99 | "hf-doc-builder>=0.3.0", 100 | "huggingface-hub>=0.19.4", 101 | "requests-mock==1.10.0", 102 | "importlib_metadata", 103 | "invisible-watermark>=0.2.0", 104 | "isort>=5.5.4", 105 | "jax>=0.4.1", 106 | "jaxlib>=0.4.1", 107 | "Jinja2", 108 | "k-diffusion>=0.0.12", 109 | "torchsde", 110 | "note_seq", 111 | "librosa", 112 | "numpy", 113 | "omegaconf", 114 | "parameterized", 115 | "peft>=0.6.0", 116 | "protobuf>=3.20.3,<4", 117 | "pytest", 118 | "pytest-timeout", 119 | "pytest-xdist", 120 | "python>=3.8.0", 121 | "ruff==0.1.5", 122 | "safetensors>=0.3.1", 123 | "sentencepiece>=0.1.91,!=0.1.92", 124 | "GitPython<3.1.19", 125 | "scipy", 126 | "onnx", 127 | "regex!=2019.12.17", 128 | "requests", 129 | "tensorboard", 130 | "torch>=1.4", 131 | "torchvision", 132 | "transformers>=4.25.1", 133 | "urllib3<=2.0.0", 134 | ] 135 | 136 | # this is a lookup table with items like: 137 | # 138 | # tokenizers: "huggingface-hub==0.8.0" 139 | # packaging: "packaging" 140 | # 141 | # some of the values are versioned whereas others aren't. 142 | deps = {b: a for a, b in (re.findall(r"^(([^!=<>~]+)(?:[!=<>~].*)?$)", x)[0] for x in _deps)} 143 | 144 | # since we save this data in src/diffusers/dependency_versions_table.py it can be easily accessed from 145 | # anywhere. If you need to quickly access the data from this table in a shell, you can do so easily with: 146 | # 147 | # python -c 'import sys; from diffusers.dependency_versions_table import deps; \ 148 | # print(" ".join([deps[x] for x in sys.argv[1:]]))' tokenizers datasets 149 | # 150 | # Just pass the desired package names to that script as it's shown with 2 packages above. 151 | # 152 | # If diffusers is not yet installed and the work is done from the cloned repo remember to add `PYTHONPATH=src` to the script above 153 | # 154 | # You can then feed this for example to `pip`: 155 | # 156 | # pip install -U $(python -c 'import sys; from diffusers.dependency_versions_table import deps; \ 157 | # print(" ".join([deps[x] for x in sys.argv[1:]]))' tokenizers datasets) 158 | # 159 | 160 | 161 | def deps_list(*pkgs): 162 | return [deps[pkg] for pkg in pkgs] 163 | 164 | 165 | class DepsTableUpdateCommand(Command): 166 | """ 167 | A custom distutils command that updates the dependency table. 168 | usage: python setup.py deps_table_update 169 | """ 170 | 171 | description = "build runtime dependency table" 172 | user_options = [ 173 | # format: (long option, short option, description). 174 | ( 175 | "dep-table-update", 176 | None, 177 | "updates src/diffusers/dependency_versions_table.py", 178 | ), 179 | ] 180 | 181 | def initialize_options(self): 182 | pass 183 | 184 | def finalize_options(self): 185 | pass 186 | 187 | def run(self): 188 | entries = "\n".join([f' "{k}": "{v}",' for k, v in deps.items()]) 189 | content = [ 190 | "# THIS FILE HAS BEEN AUTOGENERATED. To update:", 191 | "# 1. modify the `_deps` dict in setup.py", 192 | "# 2. 
run `make deps_table_update`", 193 | "deps = {", 194 | entries, 195 | "}", 196 | "", 197 | ] 198 | target = "src/diffusers/dependency_versions_table.py" 199 | print(f"updating {target}") 200 | with open(target, "w", encoding="utf-8", newline="\n") as f: 201 | f.write("\n".join(content)) 202 | 203 | 204 | extras = {} 205 | extras["quality"] = deps_list("urllib3", "isort", "ruff", "hf-doc-builder") 206 | extras["docs"] = deps_list("hf-doc-builder") 207 | extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2") 208 | extras["test"] = deps_list( 209 | "compel", 210 | "GitPython", 211 | "datasets", 212 | "Jinja2", 213 | "invisible-watermark", 214 | "k-diffusion", 215 | "librosa", 216 | "omegaconf", 217 | "parameterized", 218 | "pytest", 219 | "pytest-timeout", 220 | "pytest-xdist", 221 | "requests-mock", 222 | "safetensors", 223 | "sentencepiece", 224 | "scipy", 225 | "torchvision", 226 | "transformers", 227 | ) 228 | extras["torch"] = deps_list("torch", "accelerate") 229 | 230 | if os.name == "nt": # windows 231 | extras["flax"] = [] # jax is not supported on windows 232 | else: 233 | extras["flax"] = deps_list("jax", "jaxlib", "flax") 234 | 235 | extras["dev"] = ( 236 | extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"] 237 | ) 238 | 239 | install_requires = [ 240 | deps["importlib_metadata"], 241 | deps["filelock"], 242 | deps["huggingface-hub"], 243 | deps["numpy"], 244 | deps["regex"], 245 | deps["requests"], 246 | deps["safetensors"], 247 | deps["Pillow"], 248 | ] 249 | 250 | version_range_max = max(sys.version_info[1], 10) + 1 251 | 252 | setup( 253 | name="diffusers", 254 | version="0.25.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) 255 | description="State-of-the-art diffusion in PyTorch and JAX.", 256 | long_description=open("README.md", "r", encoding="utf-8").read(), 257 | long_description_content_type="text/markdown", 258 | keywords="deep learning diffusion jax pytorch stable diffusion audioldm", 259 | license="Apache 2.0 License", 260 | author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/diffusers/graphs/contributors)", 261 | author_email="patrick@huggingface.co", 262 | url="https://github.com/huggingface/diffusers", 263 | package_dir={"": "src"}, 264 | packages=find_packages("src"), 265 | package_data={"diffusers": ["py.typed"]}, 266 | include_package_data=True, 267 | python_requires=">=3.8.0", 268 | install_requires=list(install_requires), 269 | extras_require=extras, 270 | entry_points={"console_scripts": ["diffusers-cli=diffusers.commands.diffusers_cli:main"]}, 271 | classifiers=[ 272 | "Development Status :: 5 - Production/Stable", 273 | "Intended Audience :: Developers", 274 | "Intended Audience :: Education", 275 | "Intended Audience :: Science/Research", 276 | "License :: OSI Approved :: Apache Software License", 277 | "Operating System :: OS Independent", 278 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 279 | "Programming Language :: Python :: 3", 280 | ] 281 | + [f"Programming Language :: Python :: 3.{i}" for i in range(8, version_range_max)], 282 | cmdclass={"deps_table_update": DepsTableUpdateCommand}, 283 | ) 284 | -------------------------------------------------------------------------------- /train_dreambooth_light_lora.sh: -------------------------------------------------------------------------------- 1 | export 
MODEL_NAME="stable-diffusion-models/realisticVisionV40_v40VAE" 2 | export INSTANCE_DIR="projects/AIGC/dataset/AI_drawing/instance_dir" 3 | export OUTPUT_DIR="projects/AIGC/lora_model_test" 4 | 5 | 6 | CUDA_VISIBLE_DEVICES=0 \ 7 | accelerate launch --mixed_precision="fp16" train_dreambooth_light_lora.py \ 8 | --pretrained_model_name_or_path=$MODEL_NAME \ 9 | --instance_data_dir=$INSTANCE_DIR \ 10 | --instance_prompt="A [V] face" \ 11 | --resolution=512 \ 12 | --train_batch_size=1 \ 13 | --num_train_epochs=301 --checkpointing_steps=500 \ 14 | --learning_rate=1e-3 --lr_scheduler="constant" --lr_warmup_steps=0 \ 15 | --cfg_drop_rate 0.1 \ 16 | --rank=1 \ 17 | --down_dim=160 \ 18 | --up_dim=80 \ 19 | --output_dir=$OUTPUT_DIR \ 20 | --num_validation_images=5 \ 21 | --validation_prompt="A [V] face" \ 22 | --validation_epochs=300 \ 23 | --train_text_encoder 24 | # --patch_mlp \ 25 | # --resume_from_checkpoint "latest" \ 26 | # --seed=42 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /train_dreambooth_lora.sh: -------------------------------------------------------------------------------- 1 | export MODEL_NAME="stable-diffusion-models/realisticVisionV40_v40VAE" 2 | export INSTANCE_DIR="projects/AIGC/dataset/AI_drawing/instance_dir" 3 | export OUTPUT_DIR="projects/AIGC/lora_model_test" 4 | 5 | CUDA_VISIBLE_DEVICES=0 \ 6 | accelerate launch --mixed_precision="fp16" train_dreambooth_lora.py \ 7 | --pretrained_model_name_or_path=$MODEL_NAME \ 8 | --instance_data_dir=$INSTANCE_DIR \ 9 | --instance_prompt="A [V] face" \ 10 | --resolution=512 \ 11 | --train_batch_size=1 \ 12 | --num_train_epochs=201 --checkpointing_steps=500 \ 13 | --learning_rate=5e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \ 14 | --seed=42 \ 15 | --rank=4 \ 16 | --output_dir=$OUTPUT_DIR \ 17 | --num_validation_images=5 \ 18 | --validation_prompt="A [V] face" \ 19 | --validation_epochs=200 \ 20 | --train_text_encoder 21 | # --patch_mlp \ 22 | # --resume_from_checkpoint=$RESUME_DIR \ 23 | 24 | 25 | -------------------------------------------------------------------------------- /train_hypernet.sh: -------------------------------------------------------------------------------- 1 | # Model 2 | export MODEL_NAME="stable-diffusion-models/realisticVisionV40_v40VAE" 3 | export PRE_OPT_WEIGHT_DIR="projects/AIGC/experiments2/pretrained/CelebA-HQ-10k-fake" 4 | #vit_base_patch16_224 5 | # Image 6 | export INSTANCE_DIR="projects/AIGC/dataset/CelebA-HQ-10k" 7 | export VALIDATION_INPUT_DIR="projects/AIGC/dataset/FFHQ_test" 8 | # Output 9 | export VALIDATION_OUTPUT_DIR="projects/AIGC/experiments2/validation_outputs_10k-4" 10 | export OUTPUT_DIR="experiments/hypernet/CelebA-HQ-10k" 11 | 12 | 13 | CUDA_VISIBLE_DEVICES=0 \ 14 | accelerate launch --mixed_precision="fp16" train_hypernet.py \ 15 | --pretrained_model_name_or_path $MODEL_NAME \ 16 | --pre_opt_weight_path $PRE_OPT_WEIGHT_DIR \ 17 | --instance_data_dir $INSTANCE_DIR \ 18 | --vit_model_name vit_huge_patch14_clip_336 \ 19 | --instance_prompt "A [V] face" \ 20 | --output_dir $OUTPUT_DIR \ 21 | --allow_tf32 \ 22 | --resolution 512 \ 23 | --learning_rate 1e-3 \ 24 | --lr_scheduler cosine \ 25 | --lr_warmup_steps 100 \ 26 | --checkpoints_total_limit 10 \ 27 | --checkpointing_steps 10000 \ 28 | --cfg_drop_rate 0.1 \ 29 | --seed=42 \ 30 | --rank 1 \ 31 | --down_dim 160 \ 32 | --up_dim 80 \ 33 | --train_batch_size 4 \ 34 | --pre_opt_weight_coeff 0.02 \ 35 | --num_train_epochs 400 \ 36 | --resume_from_checkpoint "latest" \ 37 | 
--validation_prompt="A [V] face" \ 38 | --validation_input_dir $VALIDATION_INPUT_DIR \ 39 | --validation_output_dir $VALIDATION_OUTPUT_DIR \ 40 | --validation_epochs 10 \ 41 | --train_text_encoder 42 | 43 | 44 | # when train_text_encoder, not pre_compute_text_embeddings 45 | # --pre_compute_text_embeddings \ 46 | # --train_text_encoder 47 | # --patch_mlp \ 48 | 49 | 50 | -------------------------------------------------------------------------------- /train_hypernet_pro.sh: -------------------------------------------------------------------------------- 1 | # Model 2 | export MODEL_NAME="stable-diffusion-models/realisticVisionV40_v40VAE" 3 | export PRE_OPT_WEIGHT_DIR="projects/AIGC/experiments2/pretrained/CelebA-HQ-30k" 4 | #vit_base_patch16_224 5 | # Image 6 | export INSTANCE_DIR="projects/AIGC/dataset/CelebA-HQ-30k" 7 | export VALIDATION_INPUT_DIR="projects/AIGC/dataset/FFHQ_test" 8 | # Output 9 | export VALIDATION_OUTPUT_DIR="projects/AIGC/experiments2/validation_outputs_10k-7" 10 | export OUTPUT_DIR="AIGC/experiments2/hypernet/CelebA-HQ-30k-no-pretrain-2" 11 | 12 | CUDA_VISIBLE_DEVICES=0 \ 13 | accelerate launch --mixed_precision="fp16" train_hypernet_pro.py \ 14 | --pretrained_model_name_or_path $MODEL_NAME \ 15 | --pre_opt_weight_path $PRE_OPT_WEIGHT_DIR \ 16 | --instance_data_dir $INSTANCE_DIR \ 17 | --vit_model_name vit_huge_patch14_clip_336 \ 18 | --instance_prompt "A [V] face" \ 19 | --output_dir $OUTPUT_DIR \ 20 | --allow_tf32 \ 21 | --resolution 512 \ 22 | --learning_rate 5e-4 \ 23 | --lr_scheduler cosine \ 24 | --lr_warmup_steps 100 \ 25 | --checkpoints_total_limit 10 \ 26 | --checkpointing_steps 10000 \ 27 | --cfg_drop_rate 0.1 \ 28 | --seed=42 \ 29 | --rank 1 \ 30 | --down_dim 160 \ 31 | --up_dim 80 \ 32 | --train_batch_size 16 \ 33 | --pre_opt_weight_coeff 0.0 \ 34 | --num_train_epochs 200 \ 35 | --resume_from_checkpoint "latest" \ 36 | --validation_prompt="A [V] face" \ 37 | --validation_input_dir $VALIDATION_INPUT_DIR \ 38 | --validation_output_dir $VALIDATION_OUTPUT_DIR \ 39 | --validation_epochs 10 \ 40 | --train_text_encoder 41 | 42 | 43 | # when train_text_encoder, not pre_compute_text_embeddings 44 | # --pre_compute_text_embeddings \ 45 | # --train_text_encoder 46 | # --patch_mlp \ 47 | 48 | 49 | -------------------------------------------------------------------------------- /train_preoptnet.sh: -------------------------------------------------------------------------------- 1 | # run train 2 | export MODEL_NAME="stable-diffusion-models/realisticVisionV40_v40VAE" 3 | export INSTANCE_DIR="projects/AIGC/dataset/CelebA-HQ-10k" 4 | export OUTPUT_DIR="experiments/pretrained/CelebA-HQ-10k" 5 | 6 | CUDA_VISIBLE_DEVICES=0 \ 7 | accelerate launch --mixed_precision="fp16" train_preoptnet.py \ 8 | --pretrained_model_name_or_path $MODEL_NAME \ 9 | --instance_data_dir $INSTANCE_DIR \ 10 | --instance_prompt "A [V] face" \ 11 | --output_dir $OUTPUT_DIR \ 12 | --resolution 512 \ 13 | --learning_rate 1e-3 \ 14 | --lr_scheduler constant \ 15 | --checkpoints_total_limit 2 \ 16 | --checkpointing_steps 4 \ 17 | --cfg_drop_rate 0.1 \ 18 | --seed=42 \ 19 | --rank 1 \ 20 | --down_dim 160 \ 21 | --up_dim 80 \ 22 | --train_batch_size 8 \ 23 | --train_steps_per_identity 300 \ 24 | --resume_from_checkpoint "latest" \ 25 | --validation_prompt="A [V] face" \ 26 | --train_text_encoder 27 | 28 | 29 | # when train_text_encoder, not pre_compute_text_embeddings 30 | # --pre_compute_text_embeddings \ 31 | # --train_text_encoder 32 | # --patch_mlp \ 
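# Sizing note (based on PreOptHyperDream in modules/hypernet.py): each identity in
# $INSTANCE_DIR stores (down_dim + up_dim) * rank = (160 + 80) * 1 = 240 values per LiLoRA
# layer, so the pre-optimized weight buffer grows linearly with the number of identities.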
--------------------------------------------------------------------------------