├── .gitignore ├── LICENSE ├── README.md ├── assets └── teaser.gif ├── enviroment.yml ├── examples ├── bottle_b+lid_b │ ├── 0_bottle_b0.jpg │ ├── 0_bottle_b1.jpg │ ├── 0_bottle_b2.jpg │ ├── 1_lid_b0.jpg │ ├── 1_lid_b1.jpg │ ├── 1_lid_b2.jpg │ └── masks │ │ ├── 0_bottle_b0.png │ │ ├── 0_bottle_b1.png │ │ ├── 0_bottle_b2.png │ │ ├── 1_lid_b0.png │ │ ├── 1_lid_b1.png │ │ ├── 1_lid_b2.png │ │ └── others │ │ ├── 0_bottle_b0.png │ │ ├── 0_bottle_b1.png │ │ └── 0_bottle_b2.png ├── person_k+hair_c │ ├── 0_person_k0.jpg │ ├── 0_person_k1.jpg │ ├── 0_person_k2.jpg │ ├── 1_hair_c0.jpg │ ├── 1_hair_c1.jpg │ ├── 1_hair_c2.jpg │ └── masks │ │ ├── 0_person_k0.png │ │ ├── 0_person_k1.png │ │ ├── 0_person_k2.png │ │ ├── 1_hair_c0.png │ │ ├── 1_hair_c1.png │ │ ├── 1_hair_c2.png │ │ └── others │ │ ├── 0_person_k0.png │ │ ├── 0_person_k1.png │ │ └── 0_person_k2.png └── tower_a+roof_a │ ├── 0_tower_a0.jpg │ ├── 0_tower_a1.jpg │ ├── 0_tower_a2.jpg │ ├── 1_roof_a0.jpg │ ├── 1_roof_a1.jpg │ ├── 1_roof_a2.jpg │ └── masks │ ├── 0_tower_a0.png │ ├── 0_tower_a1.png │ ├── 0_tower_a2.png │ ├── 1_roof_a0.png │ ├── 1_roof_a1.png │ ├── 1_roof_a2.png │ └── others │ ├── 0_tower_a0.png │ ├── 0_tower_a1.png │ └── 0_tower_a2.png ├── inference.py ├── ptp_utils.py ├── scripts ├── inference.sh └── train.sh ├── tools ├── __init__.py └── mask_generation.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # tests and logs 10 | tests/fixtures/cached_*_text.txt 11 | logs/ 12 | lightning_logs/ 13 | lang_code_data/ 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # celery beat schedule file 90 | celerybeat-schedule 91 | 92 | # SageMath parsed files 93 | *.sage.py 94 | 95 | # Environments 96 | .env 97 | .venv 98 | env/ 99 | venv/ 100 | ENV/ 101 | env.bak/ 102 | venv.bak/ 103 | 104 | # Spyder project settings 105 | .spyderproject 106 | .spyproject 107 | 108 | # Rope project settings 109 | .ropeproject 110 | 111 | # mkdocs documentation 112 | /site 113 | 114 | # mypy 115 | .mypy_cache/ 116 | .dmypy.json 117 | dmypy.json 118 | 119 | # Pyre type checker 120 | .pyre/ 121 | 122 | # vscode 123 | .vs 124 | .vscode 125 | 126 | # Pycharm 127 | .idea 128 | 129 | # TF code 130 | tensorflow_code 131 | 132 | # Models 133 | proc_data 134 | 135 | # examples 136 | runs 137 | /runs_old 138 | /wandb 139 | /examples/runs 140 | /examples/**/*.args 141 | /examples/rag/sweep 142 | 143 | # data 144 | /data 145 | serialization_dir 146 | 147 | # emacs 148 | *.*~ 149 | debug.env 150 | 151 | # vim 152 | .*.swp 153 | 154 | #ctags 155 | tags 156 | 157 | # pre-commit 158 | .pre-commit* 159 | 160 | # .lock 161 | *.lock 162 | 163 | # DS_Store (MacOS) 164 | .DS_Store 165 | # RL pipelines may produce mp4 outputs 166 | *.mp4 167 | 168 | # dependencies 169 | /transformers 170 | 171 | # Project 172 | outputs/ 173 | inputs/ 174 | results* 175 | models/ 176 | dataset* 177 | resources/ 178 | .history/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MagicTailor: Component-Controllable Personalization in Text-to-Image Diffusion Models 2 | 3 | [![Page](https://img.shields.io/badge/Project-Page-green?logo=github&logoColor=white)](https://correr-zhou.github.io/MagicTailor/) 4 | [![Paper](https://img.shields.io/badge/arXiv-Paper-b31b1b?logo=arxiv&logoColor=white)](https://arxiv.org/pdf/2410.13370) 5 | [![Hugging Face](https://img.shields.io/badge/Hugging_Face-%233_Paper_of_the_Day-yellow?logo=huggingface&logoColor=white)](https://huggingface.co/papers?date=2024-10-21) 6 | [![News](https://img.shields.io/badge/Neuronad-News-980e5a?logo=googlechrome&logoColor=white)](https://neuronad.com/ai-news/tech/magictailor-personalization-in-text-to-image-generation/) 7 | [![Video](https://img.shields.io/badge/@ManuAGI01-Video-blue?logo=X&logoColor=white)](https://x.com/ManuAGI01/status/1850923512598516046) 8 | 9 | [Donghao Zhou](https://scholar.google.com/citations?hl=en&user=RsLS11MAAAAJ)1*, 10 | [Jiancheng Huang](https://huangjch526.github.io/)2*, 11 | [Jinbin Bai](https://noyii.github.io/)3, 12 | [Jiaze Wang](https://jiazewang.com/)1, 13 | [Hao Chen](https://scholar.google.com.hk/citations?user=tT03tysAAAAJ&hl=zh-CN)1, 14 | [Guangyong Chen](https://guangyongchen.github.io/)4,
15 | [Xiaowei Hu](https://xw-hu.github.io/)5†, 16 | [Pheng-Ann Heng](http://www.cse.cuhk.edu.hk/~pheng/)1 17 | 18 | 1CUHK   19 | 2SIAT, CAS   20 | 3NUS   21 | 4Zhejiang Lab   22 | 5Shanghai AI Lab 23 | 24 |
25 | 26 | ![teaser](assets/teaser.gif) 27 | 28 | We present **MagicTailor** to enable **component-controllable personalization**, a newly formulated task aiming to reconfigure specific components of concepts during personalization. 29 | 30 |
<details>
31 | <summary>Abstract</summary>
32 | Recent advancements in text-to-image (T2I) diffusion models have enabled the creation of high-quality images from text prompts, but they still struggle to generate images with precise control over specific visual concepts. Existing approaches can replicate a given concept by learning from reference images, yet they lack the flexibility for fine-grained customization of the individual component within the concept. In this paper, we introduce component-controllable personalization, a novel task that pushes the boundaries of T2I models by allowing users to reconfigure and personalize specific components of concepts. This task is particularly challenging due to two primary obstacles: semantic pollution, where unwanted visual elements corrupt the personalized concept, and semantic imbalance, which causes disproportionate learning of visual semantics. To overcome these challenges, we design MagicTailor, an innovative framework that leverages Dynamic Masked Degradation (DM-Deg) to dynamically perturb undesired visual semantics and Dual-Stream Balancing (DS-Bal) to establish a balanced learning paradigm for visual semantics. Extensive comparisons, ablations, and analyses demonstrate that MagicTailor not only excels in this challenging task but also holds significant promise for practical applications, paving the way for more nuanced and creative image generation.
33 | </details>
34 | 35 | 36 | ## 🔥 Updates
37 | - 2024.10: Our code is released! Feel free to [contact me](mailto:dhzhou@link.cuhk.edu.hk) if anything is unclear.
38 | - 2024.10: [Our paper](https://arxiv.org/pdf/2410.13370) is available. The code is coming soon!
39 | 40 | 41 | ## 🛠️ Installation
42 | 1. Install the conda environment:
43 | ```
44 | conda env create -f enviroment.yml
45 | ```
46 | 2. Install other dependencies (here we take CUDA 11.6 as an example):
47 | ```
48 | conda activate magictailor
49 | pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
50 | ```
51 | 3. Clone the Grounded-SAM repository:
52 | ```
53 | cd {PATH-TO-THIS-CODE}
54 | git clone https://github.com/IDEA-Research/Grounded-Segment-Anything.git
55 | ```
56 | 4. Follow the ["Install without Docker"](https://github.com/IDEA-Research/Grounded-Segment-Anything) section to set up Grounded-SAM (please make sure that the CUDA version used for this installation is the same as that of PyTorch).
57 | 58 | > ❗You can skip Steps 3 and 4 if you just want to have a quick try using the example images we provide.
59 | 60 | ## 🔬 Training and Inference
61 | 62 | ### Preparing Data
63 | Directly use the example images in `./examples`, or prepare your own pair:
64 | 1. Create a folder named `{CONCEPT}_{ID}+{COMPONENT}_{ID}`, where `{CONCEPT}` and `{COMPONENT}` are the category names for the concept and component respectively, and `{ID}` is a custom index (you can set it to whatever you want) that helps you distinguish between different pairs.
65 | 2. Put the reference images into this folder, and rename them as `0_{CONCEPT}_{ID}.jpg` and `1_{COMPONENT}_{ID}.jpg` (with an image index appended, as in the example below) for the images of the concept and component respectively.
66 | 3. Finally, the data will be organized like:
67 | ```
68 | person_a+hair_a/
69 | ├── 0_person_a0.jpg
70 | ├── 0_person_a1.jpg
71 | ├── 0_person_a2.jpg
72 | ├── 1_hair_a0.jpg
73 | ├── 1_hair_a1.jpg
74 | └── 1_hair_a2.jpg
75 | ```
76 | 77 | ### Training
78 | You can train MagicTailor with default hyperparameters:
79 | ```
80 | python train.py --instance_data_dir {PATH-TO-PREPARED-DATA}
81 | ```
82 | For example:
83 | ```
84 | python train.py --instance_data_dir examples/person_k+hair_c
85 | ```
86 | > ❗Please check the quality of the masks output by Grounded-SAM to ensure that the model runs correctly.
87 | 88 | Alternatively, you can train it with customized hyperparameters, such as:
89 | ```
90 | python train.py \
91 | --instance_data_dir examples/person_k+hair_c \
92 | --phase1_train_steps 200 \
93 | --phase2_train_steps 300 \
94 | --phase1_learning_rate 1e-4 \
95 | --phase2_learning_rate 1e-5 \
96 | --lora_rank 32 \
97 | --alpha 0.5 \
98 | --gamma 32 \
99 | --lambda_preservation 0.2
100 | ```
101 | You can refer to [our paper](https://arxiv.org/pdf/2410.13370) or `train.py` to understand the meaning of the arguments.
102 | Adjusting these hyperparameters helps yield better results.
103 | 104 | We also provide a more detailed training script in `scripts/train.sh` for research or development purposes, which supports further modification.
105 | 106 | ### Inference
107 | After training, the trained model will be saved in `outputs/magictailor`. Two placeholder tokens will be assigned, one to the concept and one to the component, for text-to-image generation.
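If you prefer to load the trained model in your own script rather than calling `inference.py`, the sketch below condenses how `inference.py` (included later in this repository) restores the saved artifacts: the LoRA weights, the extended tokenizer, and the learned token embeddings. The base model and paths follow the repository defaults; the prompt string, device, and output filename are placeholders you should adapt to your run.
```
# Condensed from inference.py: load a trained MagicTailor model for custom use.
import os

import torch
from diffusers import DiffusionPipeline, DDIMScheduler
from transformers import AutoTokenizer

model_path = "outputs/magictailor"  # default output directory of train.py

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base",
    torch_dtype=torch.float16,
)
pipeline.load_lora_weights(model_path)  # LoRA weights saved during training

# Text-encoder embeddings that contain the learned placeholder tokens
state_dict = torch.load(os.path.join(model_path, "token_embedding.pth"))
pipeline.text_encoder.get_input_embeddings().weight.data = \
    state_dict["weight"].type(torch.float16)

# Tokenizer extended with the placeholder tokens
pipeline.tokenizer = AutoTokenizer.from_pretrained(
    os.path.join(model_path, "tokenizer"), use_fast=False
)

# Same sampler settings as inference.py
pipeline.scheduler = DDIMScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
    clip_sample=False, set_alpha_to_one=False,
)
pipeline.to("cuda")  # assumption: a single CUDA device is available

# Placeholder prompt: replace it with one that uses the placeholder tokens
# assigned during training.
image = pipeline(
    "a photo of the personalized concept",
    num_inference_steps=50,
    guidance_scale=7.5,
).images[0]
image.save("result.jpg")
```
In most cases, simply running `inference.py` as shown next is enough.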
108 | 109 | Then, you can generate images with the saved model, just like: 110 | ``` 111 | python inference.py \ 112 | --model_path "outputs/magictailor" \ 113 | --prompt " with " \ 114 | --output_path "outputs/inference/result.jpg" 115 | ``` 116 | 117 | ## 💡 Usage tips 118 | 1. If you face the GPU memory limit, please consider reducing the number of reference images or moving the momentum denoising U-Net (self.unet_m) to another GPU if applicable (see the corresponding code in `train.py`). 119 | 2. While our default hyperparameters are suitable in most cases, further adjustment of hyperparameters for a training pair is still recommended, which helps to achieve a better trade-off between text alignment and identity fidelity. 120 | 121 | 122 | ## 📑 Citation 123 | If you find that our work is helpful in your research, please consider citing our paper: 124 | ```latex 125 | @article{zhou2024magictailor, 126 | title={MagicTailor: Component-Controllable Personalization in Text-to-Image Diffusion Models}, 127 | author={Zhou, Donghao and Huang, Jiancheng and Bai, Jinbin and Wang, Jiaze and Chen, Hao and Chen, Guangyong and Hu, Xiaowei and Heng, Pheng-Ann}, 128 | journal={arXiv preprint arXiv:2410.13370}, 129 | year={2024} 130 | } 131 | ``` 132 | 133 | 134 | ## 🤝 Acknowledgement 135 | Our code is built upon the repositories of [diffusers](https://github.com/huggingface/diffusers) and [Break-A-Scene](https://github.com/google/break-a-scene/). Thank their authors for their excellent work. 136 | -------------------------------------------------------------------------------- /assets/teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/assets/teaser.gif -------------------------------------------------------------------------------- /enviroment.yml: -------------------------------------------------------------------------------- 1 | name: magictailor 2 | channels: 3 | - huggingface 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _openmp_mutex=5.1=1_gnu 10 | - _sysroot_linux-64_curr_repodata_hack=3=haa98f57_10 11 | - abseil-cpp=20211102.0=h27087fc_1 12 | - aiosignal=1.3.1=pyhd8ed1ab_0 13 | - arrow-cpp=11.0.0=py310h7516544_0 14 | - attrs=23.1.0=pyh71513ae_1 15 | - aws-c-common=0.4.57=he6710b0_1 16 | - aws-c-event-stream=0.1.6=h2531618_5 17 | - aws-checksums=0.1.9=he6710b0_0 18 | - aws-sdk-cpp=1.8.185=hce553d0_0 19 | - binutils_impl_linux-64=2.38=h2a08ee3_1 20 | - binutils_linux-64=2.38.0=hc2dff05_0 21 | - blas=1.0=mkl 22 | - boost-cpp=1.65.1=0 23 | - brotli-bin=1.0.9=h166bdaf_7 24 | - brotli-python=1.0.9=py310hd8f1fbe_7 25 | - bzip2=1.0.8=h7b6447c_0 26 | - c-ares=1.19.0=h5eee18b_0 27 | - ca-certificates=2024.7.2=h06a4308_0 28 | - charset-normalizer=3.3.2=pyhd8ed1ab_0 29 | - click=8.1.3=unix_pyhd8ed1ab_2 30 | - colorama=0.4.6=pyhd8ed1ab_0 31 | - dataclasses=0.8=pyhc8e2a94_3 32 | - ffmpeg=4.3=hf484d3e_0 33 | - filelock=3.12.0=pyhd8ed1ab_0 34 | - flake8=7.0.0=py310h06a4308_0 35 | - freetype=2.12.1=h4a9f257_0 36 | - frozenlist=1.3.0=py310h5764c6d_1 37 | - fsspec=2023.5.0=pyh1a96a4e_0 38 | - future=0.18.3=pyhd8ed1ab_0 39 | - gflags=2.2.2=he1b5a44_1004 40 | - glog=0.5.0=h48cff8f_0 41 | - gmp=6.2.1=h295c915_3 42 | - gnutls=3.6.15=he1e5248_0 43 | - grpc-cpp=1.46.1=h33aed49_1 44 | - icu=58.2=hf484d3e_1000 45 | - idna=3.4=pyhd8ed1ab_0 46 | - importlib-metadata=6.6.0=pyha770c72_0 47 | - importlib_metadata=6.6.0=hd8ed1ab_0 48 | - 
intel-openmp=2023.1.0=hdb19cb5_46305 49 | - joblib=1.2.0=pyhd8ed1ab_0 50 | - jpeg=9e=h5eee18b_1 51 | - kernel-headers_linux-64=3.10.0=h57e8cba_10 52 | - krb5=1.19.4=h568e23c_0 53 | - lame=3.100=h7b6447c_0 54 | - lcms2=2.12=h3be6417_0 55 | - ld_impl_linux-64=2.38=h1181459_1 56 | - lerc=3.0=h295c915_0 57 | - libbrotlicommon=1.0.9=h166bdaf_7 58 | - libbrotlidec=1.0.9=h166bdaf_7 59 | - libbrotlienc=1.0.9=h166bdaf_7 60 | - libcublas=12.1.0.26=0 61 | - libcufft=11.0.2.4=0 62 | - libcufile=1.9.0.20=0 63 | - libcurand=10.3.5.119=0 64 | - libcurl=7.88.1=h91b91d3_0 65 | - libcusolver=11.4.4.55=0 66 | - libcusparse=12.0.2.55=0 67 | - libdeflate=1.17=h5eee18b_1 68 | - libedit=3.1.20221030=h5eee18b_0 69 | - libev=4.33=h516909a_1 70 | - libevent=2.1.10=h9b69904_4 71 | - libffi=3.4.4=h6a678d5_0 72 | - libgcc-devel_linux-64=11.2.0=h1234567_1 73 | - libgcc-ng=11.2.0=h1234567_1 74 | - libgomp=11.2.0=h1234567_1 75 | - libiconv=1.16=h7f8727e_2 76 | - libidn2=2.3.4=h5eee18b_0 77 | - libnghttp2=1.46.0=hce63b2e_0 78 | - libnpp=12.0.2.50=0 79 | - libnvjitlink=12.1.105=0 80 | - libnvjpeg=12.1.1.14=0 81 | - libpng=1.6.39=h5eee18b_0 82 | - libprotobuf=3.20.3=he621ea3_0 83 | - libssh2=1.10.0=h8f2d780_0 84 | - libstdcxx-devel_linux-64=11.2.0=h1234567_1 85 | - libstdcxx-ng=11.2.0=h1234567_1 86 | - libtasn1=4.19.0=h5eee18b_0 87 | - libthrift=0.15.0=he6d91bd_0 88 | - libtiff=4.5.1=h6a678d5_0 89 | - libunistring=0.9.10=h27cfd23_0 90 | - libuuid=1.41.5=h5eee18b_0 91 | - libwebp-base=1.3.2=h5eee18b_0 92 | - lz4-c=1.9.4=h6a678d5_0 93 | - mccabe=0.7.0=pyhd3eb1b0_0 94 | - mkl=2023.1.0=h6d00ec8_46342 95 | - mkl-service=2.4.0=py310h5eee18b_1 96 | - mkl_fft=1.3.6=py310h1128e8f_1 97 | - mkl_random=1.2.2=py310h1128e8f_1 98 | - ncurses=6.4=h6a678d5_0 99 | - nettle=3.7.3=hbbd107a_1 100 | - ninja-base=1.10.2=hd09550d_5 101 | - numpy=1.24.3=py310h5f9d8c6_1 102 | - numpy-base=1.24.3=py310hb5e798b_1 103 | - openh264=2.1.1=h4ff587b_0 104 | - openjpeg=2.4.0=h3ad879b_0 105 | - openssl=1.1.1w=h7f8727e_0 106 | - orc=1.7.4=hb3bc3d3_1 107 | - packaging=23.1=pyhd8ed1ab_0 108 | - pycodestyle=2.11.1=py310h06a4308_0 109 | - pycparser=2.21=pyhd8ed1ab_0 110 | - pyflakes=3.2.0=py310h06a4308_0 111 | - pysocks=1.7.1=pyha2e5f31_6 112 | - python=3.10.11=h7a1cb2a_2 113 | - python-dateutil=2.8.2=pyhd8ed1ab_0 114 | - python-xxhash=2.0.2=py310h5eee18b_1 115 | - python_abi=3.10=2_cp310 116 | - pytz=2023.3=pyhd8ed1ab_0 117 | - pyyaml=6.0=py310h5764c6d_4 118 | - re2=2022.04.01=h27087fc_0 119 | - readline=8.2=h5eee18b_0 120 | - regex=2022.4.24=py310h5764c6d_0 121 | - requests=2.32.3=py310h06a4308_0 122 | - responses=0.13.3=pyhd3eb1b0_0 123 | - sacremoses=master=py_0 124 | - six=1.16.0=pyh6c4a22f_0 125 | - snappy=1.1.9=h295c915_0 126 | - sqlite=3.41.2=h5eee18b_0 127 | - sysroot_linux-64=2.17=h57e8cba_10 128 | - tbb=2021.8.0=hdb19cb5_0 129 | - tk=8.6.12=h1ccaba5_0 130 | - typing-extensions=4.11.0=py310h06a4308_0 131 | - typing_extensions=4.11.0=py310h06a4308_0 132 | - utf8proc=2.6.1=h27cfd23_0 133 | - wheel=0.43.0=pyhd8ed1ab_1 134 | - xxhash=0.8.0=h7f98852_3 135 | - xz=5.4.2=h5eee18b_0 136 | - yaml=0.2.5=h7f98852_2 137 | - zipp=3.15.0=pyhd8ed1ab_0 138 | - zlib=1.2.13=h5eee18b_0 139 | - zstd=1.5.5=hc292b87_0 140 | - pip: 141 | - absl-py==2.1.0 142 | - accelerate==0.27.2 143 | - addict==2.4.0 144 | - aiohttp==3.9.5 145 | - alembic==1.13.1 146 | - asttokens==2.4.1 147 | - async-timeout==4.0.3 148 | - banal==1.0.6 149 | - bitsandbytes==0.43.3 150 | - boto3==1.34.105 151 | - botocore==1.34.105 152 | - certifi==2024.8.30 153 | - cffi==1.16.0 154 | - chumpy==0.70 155 | - 
cmake==3.30.2 156 | - coloredlogs==15.0.1 157 | - comm==0.2.2 158 | - contourpy==1.2.1 159 | - crcmod==1.7 160 | - cryptography==42.0.7 161 | - cycler==0.12.1 162 | - cython==3.0.10 163 | - dataset==1.6.2 164 | - datasets==2.20.0 165 | - debugpy==1.8.1 166 | - decorator==5.1.1 167 | - defusedxml==0.7.1 168 | - diffusers==0.27.0 169 | - dill==0.3.8 170 | - docker-pycreds==0.4.0 171 | - easydict==1.13 172 | - einops==0.8.0 173 | - exceptiongroup==1.2.1 174 | - executing==2.0.1 175 | - facexlib==0.3.0 176 | - filterpy==1.4.5 177 | - flatbuffers==24.3.25 178 | - fonttools==4.51.0 179 | - freetype-py==2.4.0 180 | - ftfy==6.2.0 181 | - gitdb==4.0.11 182 | - gitpython==3.1.43 183 | - greenlet==3.0.3 184 | - groundingdino==0.1.0 185 | - grpcio==1.64.1 186 | - huggingface-hub==0.23.0 187 | - humanfriendly==10.0 188 | - icecream==2.1.3 189 | - imageio==2.34.1 190 | - imgaug==0.4.0 191 | - iniconfig==2.0.0 192 | - ipykernel==6.29.4 193 | - ipython==8.24.0 194 | - jedi==0.19.1 195 | - jinja2==3.1.4 196 | - jmespath==0.10.0 197 | - json-tricks==3.17.3 198 | - jsonschema==4.22.0 199 | - jsonschema-specifications==2023.12.1 200 | - jupyter-client==8.6.1 201 | - jupyter-core==5.7.2 202 | - kiwisolver==1.4.5 203 | - lazy-loader==0.4 204 | - llvmlite==0.42.0 205 | - lmdb==1.5.1 206 | - mako==1.3.5 207 | - markupsafe==2.1.5 208 | - mdurl==0.1.2 209 | - model-index==0.1.11 210 | - mpmath==1.3.0 211 | - multidict==6.0.5 212 | - multiprocess==0.70.16 213 | - munkres==1.1.4 214 | - nest-asyncio==1.6.0 215 | - networkx==3.3 216 | - numba==0.59.1 217 | - opencv-python==4.9.0.80 218 | - opencv-python-headless==4.9.0.80 219 | - opendatalab==0.0.10 220 | - openmim==0.3.9 221 | - ordered-set==4.1.0 222 | - oss2==2.17.0 223 | - pandas==2.2.2 224 | - parso==0.8.4 225 | - peft==0.12.0 226 | - pexpect==4.9.0 227 | - pillow==10.3.0 228 | - platformdirs==4.2.2 229 | - pluggy==1.5.0 230 | - plyfile==1.0.3 231 | - prompt-toolkit==3.0.43 232 | - protobuf==4.25.3 233 | - psutil==5.9.8 234 | - ptyprocess==0.7.0 235 | - pure-eval==0.2.2 236 | - pyarrow==16.1.0 237 | - pyarrow-hotfix==0.6 238 | - pycryptodome==3.20.0 239 | - pyglet==1.5.27 240 | - pygments==2.18.0 241 | - pyiqa==0.1.11 242 | - pyopengl==3.1.0 243 | - pyparsing==3.1.2 244 | - pyproject-toml==0.0.10 245 | - pyrender==0.1.45 246 | - pytest==8.3.2 247 | - pyzmq==26.0.3 248 | - referencing==0.35.1 249 | - rich==13.4.2 250 | - rpds-py==0.18.1 251 | - s3transfer==0.10.1 252 | - safetensors==0.4.3 253 | - sentencepiece==0.2.0 254 | - sentry-sdk==2.1.1 255 | - setproctitle==1.3.3 256 | - setuptools==69.5.1 257 | - shapely==2.0.4 258 | - smmap==5.0.1 259 | - smplx==0.1.28 260 | - sqlalchemy==1.4.52 261 | - stack-data==0.6.3 262 | - supervision==0.21.0 263 | - sympy==1.12 264 | - tabulate==0.9.0 265 | - tensorboard==2.17.0 266 | - tensorboard-data-server==0.7.2 267 | - tensorboardx==2.6.2.2 268 | - threadpoolctl==3.5.0 269 | - tifffile==2024.5.10 270 | - timm==0.9.16 271 | - tokenizers==0.15.2 272 | - toml==0.10.2 273 | - tomli==2.0.1 274 | - tornado==6.4 275 | - tqdm==4.65.0 276 | - traitlets==5.14.3 277 | - transformers==4.37.2 278 | - trimesh==4.3.2 279 | - triton==2.3.0 280 | - tzdata==2024.1 281 | - urllib3==1.26.18 282 | - wcwidth==0.2.13 283 | - werkzeug==3.0.3 284 | - xtcocotools==1.14.3 285 | - yacs==0.1.8 286 | - yapf==0.40.2 287 | - yarl==1.9.4 288 | -------------------------------------------------------------------------------- /examples/bottle_b+lid_b/0_bottle_b0.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/bottle_b+lid_b/0_bottle_b0.jpg -------------------------------------------------------------------------------- /examples/bottle_b+lid_b/0_bottle_b1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/bottle_b+lid_b/0_bottle_b1.jpg -------------------------------------------------------------------------------- /examples/bottle_b+lid_b/0_bottle_b2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/bottle_b+lid_b/0_bottle_b2.jpg -------------------------------------------------------------------------------- /examples/bottle_b+lid_b/1_lid_b0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/bottle_b+lid_b/1_lid_b0.jpg -------------------------------------------------------------------------------- /examples/bottle_b+lid_b/1_lid_b1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/bottle_b+lid_b/1_lid_b1.jpg -------------------------------------------------------------------------------- /examples/bottle_b+lid_b/1_lid_b2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/bottle_b+lid_b/1_lid_b2.jpg -------------------------------------------------------------------------------- /examples/bottle_b+lid_b/masks/0_bottle_b0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/bottle_b+lid_b/masks/0_bottle_b0.png -------------------------------------------------------------------------------- /examples/bottle_b+lid_b/masks/0_bottle_b1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/bottle_b+lid_b/masks/0_bottle_b1.png -------------------------------------------------------------------------------- /examples/bottle_b+lid_b/masks/0_bottle_b2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/bottle_b+lid_b/masks/0_bottle_b2.png -------------------------------------------------------------------------------- /examples/bottle_b+lid_b/masks/1_lid_b0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/bottle_b+lid_b/masks/1_lid_b0.png -------------------------------------------------------------------------------- /examples/bottle_b+lid_b/masks/1_lid_b1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/bottle_b+lid_b/masks/1_lid_b1.png -------------------------------------------------------------------------------- /examples/bottle_b+lid_b/masks/1_lid_b2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/bottle_b+lid_b/masks/1_lid_b2.png -------------------------------------------------------------------------------- /examples/bottle_b+lid_b/masks/others/0_bottle_b0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/bottle_b+lid_b/masks/others/0_bottle_b0.png -------------------------------------------------------------------------------- /examples/bottle_b+lid_b/masks/others/0_bottle_b1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/bottle_b+lid_b/masks/others/0_bottle_b1.png -------------------------------------------------------------------------------- /examples/bottle_b+lid_b/masks/others/0_bottle_b2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/bottle_b+lid_b/masks/others/0_bottle_b2.png -------------------------------------------------------------------------------- /examples/person_k+hair_c/0_person_k0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/person_k+hair_c/0_person_k0.jpg -------------------------------------------------------------------------------- /examples/person_k+hair_c/0_person_k1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/person_k+hair_c/0_person_k1.jpg -------------------------------------------------------------------------------- /examples/person_k+hair_c/0_person_k2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/person_k+hair_c/0_person_k2.jpg -------------------------------------------------------------------------------- /examples/person_k+hair_c/1_hair_c0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/person_k+hair_c/1_hair_c0.jpg -------------------------------------------------------------------------------- /examples/person_k+hair_c/1_hair_c1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/person_k+hair_c/1_hair_c1.jpg -------------------------------------------------------------------------------- /examples/person_k+hair_c/1_hair_c2.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/person_k+hair_c/1_hair_c2.jpg -------------------------------------------------------------------------------- /examples/person_k+hair_c/masks/0_person_k0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/person_k+hair_c/masks/0_person_k0.png -------------------------------------------------------------------------------- /examples/person_k+hair_c/masks/0_person_k1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/person_k+hair_c/masks/0_person_k1.png -------------------------------------------------------------------------------- /examples/person_k+hair_c/masks/0_person_k2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/person_k+hair_c/masks/0_person_k2.png -------------------------------------------------------------------------------- /examples/person_k+hair_c/masks/1_hair_c0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/person_k+hair_c/masks/1_hair_c0.png -------------------------------------------------------------------------------- /examples/person_k+hair_c/masks/1_hair_c1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/person_k+hair_c/masks/1_hair_c1.png -------------------------------------------------------------------------------- /examples/person_k+hair_c/masks/1_hair_c2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/person_k+hair_c/masks/1_hair_c2.png -------------------------------------------------------------------------------- /examples/person_k+hair_c/masks/others/0_person_k0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/person_k+hair_c/masks/others/0_person_k0.png -------------------------------------------------------------------------------- /examples/person_k+hair_c/masks/others/0_person_k1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/person_k+hair_c/masks/others/0_person_k1.png -------------------------------------------------------------------------------- /examples/person_k+hair_c/masks/others/0_person_k2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/person_k+hair_c/masks/others/0_person_k2.png -------------------------------------------------------------------------------- /examples/tower_a+roof_a/0_tower_a0.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/tower_a+roof_a/0_tower_a0.jpg -------------------------------------------------------------------------------- /examples/tower_a+roof_a/0_tower_a1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/tower_a+roof_a/0_tower_a1.jpg -------------------------------------------------------------------------------- /examples/tower_a+roof_a/0_tower_a2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/tower_a+roof_a/0_tower_a2.jpg -------------------------------------------------------------------------------- /examples/tower_a+roof_a/1_roof_a0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/tower_a+roof_a/1_roof_a0.jpg -------------------------------------------------------------------------------- /examples/tower_a+roof_a/1_roof_a1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/tower_a+roof_a/1_roof_a1.jpg -------------------------------------------------------------------------------- /examples/tower_a+roof_a/1_roof_a2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/tower_a+roof_a/1_roof_a2.jpg -------------------------------------------------------------------------------- /examples/tower_a+roof_a/masks/0_tower_a0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/tower_a+roof_a/masks/0_tower_a0.png -------------------------------------------------------------------------------- /examples/tower_a+roof_a/masks/0_tower_a1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/tower_a+roof_a/masks/0_tower_a1.png -------------------------------------------------------------------------------- /examples/tower_a+roof_a/masks/0_tower_a2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/tower_a+roof_a/masks/0_tower_a2.png -------------------------------------------------------------------------------- /examples/tower_a+roof_a/masks/1_roof_a0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/tower_a+roof_a/masks/1_roof_a0.png -------------------------------------------------------------------------------- /examples/tower_a+roof_a/masks/1_roof_a1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/tower_a+roof_a/masks/1_roof_a1.png -------------------------------------------------------------------------------- /examples/tower_a+roof_a/masks/1_roof_a2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/tower_a+roof_a/masks/1_roof_a2.png -------------------------------------------------------------------------------- /examples/tower_a+roof_a/masks/others/0_tower_a0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/tower_a+roof_a/masks/others/0_tower_a0.png -------------------------------------------------------------------------------- /examples/tower_a+roof_a/masks/others/0_tower_a1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/tower_a+roof_a/masks/others/0_tower_a1.png -------------------------------------------------------------------------------- /examples/tower_a+roof_a/masks/others/0_tower_a2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Correr-Zhou/MagicTailor/5197aec785655b3a855fc87b1974b09e78298947/examples/tower_a+roof_a/masks/others/0_tower_a2.png -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | """ 2 | Below is the copyright notice from Google. 3 | 4 | Please also follow this license when you modify or distribute the code. 5 | """ 6 | 7 | """ 8 | Copyright 2023 Google LLC 9 | 10 | Licensed under the Apache License, Version 2.0 (the "License"); 11 | you may not use this file except in compliance with the License. 12 | You may obtain a copy of the License at 13 | 14 | https://www.apache.org/licenses/LICENSE-2.0 15 | 16 | Unless required by applicable law or agreed to in writing, software 17 | distributed under the License is distributed on an "AS IS" BASIS, 18 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 19 | See the License for the specific language governing permissions and 20 | limitations under the License. 
21 | """ 22 | 23 | import argparse 24 | 25 | from diffusers import DiffusionPipeline, DDIMScheduler 26 | from transformers import AutoTokenizer 27 | import torch 28 | 29 | import os 30 | 31 | class MagicTailorInference: 32 | def __init__(self): 33 | self._parse_args() 34 | self._load_pipeline() 35 | 36 | def _parse_args(self): 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument("--pretrained_model_name_or_path", type=str, 39 | default="stabilityai/stable-diffusion-2-1-base", 40 | ) 41 | parser.add_argument("--model_path", type=str, required=True) 42 | parser.add_argument("--prompt", type=str, required=True) 43 | parser.add_argument("--output_path", type=str, required=True) 44 | parser.add_argument("--device", type=str, default="cuda") 45 | self.args = parser.parse_args() 46 | 47 | def _load_pipeline(self): 48 | self.pipeline = DiffusionPipeline.from_pretrained( 49 | self.args.pretrained_model_name_or_path, 50 | torch_dtype=torch.float16, 51 | ) 52 | self.pipeline.load_lora_weights(self.args.model_path) 53 | 54 | token_embedding_path = os.path.join(self.args.model_path, 'token_embedding.pth') 55 | token_embedding_state_dict = torch.load(token_embedding_path) 56 | self.pipeline.text_encoder.get_input_embeddings().weight.data = \ 57 | token_embedding_state_dict['weight'].type(torch.float16) 58 | 59 | self.pipeline.tokenizer = AutoTokenizer.from_pretrained( 60 | os.path.join(self.args.model_path, 'tokenizer'), 61 | use_fast=False 62 | ) 63 | 64 | self.pipeline.scheduler = DDIMScheduler( 65 | beta_start=0.00085, 66 | beta_end=0.012, 67 | beta_schedule="scaled_linear", 68 | clip_sample=False, 69 | set_alpha_to_one=False, 70 | ) 71 | self.num_inference_steps = 50 72 | self.guidance_scale = 7.5 73 | 74 | self.pipeline.enable_vae_slicing() 75 | self.pipeline.to(self.args.device) 76 | 77 | @torch.no_grad() 78 | def infer_and_save(self, prompts): 79 | images = self.pipeline( 80 | prompts, 81 | num_inference_steps=self.num_inference_steps, 82 | guidance_scale=self.guidance_scale, 83 | ).images 84 | if not self.args.output_path: 85 | self.args.output_path = os.path.join(self.args.model_path, "inference/result.jpg") 86 | os.makedirs(os.path.dirname(self.args.output_path), exist_ok=True) 87 | images[0].save(self.args.output_path) 88 | print(f"The genearated image is saved to: {self.args.output_path}") 89 | 90 | 91 | if __name__ == "__main__": 92 | inference = MagicTailorInference() 93 | inference.infer_and_save( 94 | prompts=[inference.args.prompt] 95 | ) 96 | -------------------------------------------------------------------------------- /ptp_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Below is the copyright notice from Google. 3 | 4 | Please also follow this license when you modify or distribute the code. 5 | """ 6 | 7 | """ 8 | Copyright 2023 Google LLC 9 | 10 | Licensed under the Apache License, Version 2.0 (the "License"); 11 | you may not use this file except in compliance with the License. 12 | You may obtain a copy of the License at 13 | 14 | https://www.apache.org/licenses/LICENSE-2.0 15 | 16 | Unless required by applicable law or agreed to in writing, software 17 | distributed under the License is distributed on an "AS IS" BASIS, 18 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 19 | See the License for the specific language governing permissions and 20 | limitations under the License. 
21 | """ 22 | 23 | import abc 24 | 25 | import cv2 26 | import numpy as np 27 | import torch 28 | 29 | from PIL import Image 30 | from typing import Union, Tuple, List, Dict, Optional 31 | import torch.nn.functional as nnf 32 | 33 | 34 | def text_under_image( 35 | image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0) 36 | ) -> np.ndarray: 37 | h, w, c = image.shape 38 | offset = int(h * 0.2) 39 | img = np.ones((h + offset, w, c), dtype=np.uint8) * 255 40 | font = cv2.FONT_HERSHEY_SIMPLEX 41 | img[:h] = image 42 | textsize = cv2.getTextSize(text, font, 1, 2)[0] 43 | text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2 44 | cv2.putText(img, text, (text_x, text_y), font, 1, text_color, 2) 45 | return img 46 | 47 | 48 | def view_images( 49 | images: Union[np.ndarray, List], 50 | num_rows: int = 1, 51 | offset_ratio: float = 0.02, 52 | display_image: bool = True, 53 | ) -> Image.Image: 54 | """Displays a list of images in a grid.""" 55 | if type(images) is list: 56 | num_empty = len(images) % num_rows 57 | elif images.ndim == 4: 58 | num_empty = images.shape[0] % num_rows 59 | else: 60 | images = [images] 61 | num_empty = 0 62 | 63 | empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255 64 | images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty 65 | num_items = len(images) 66 | 67 | h, w, c = images[0].shape 68 | offset = int(h * offset_ratio) 69 | num_cols = num_items // num_rows 70 | image_ = ( 71 | np.ones( 72 | ( 73 | h * num_rows + offset * (num_rows - 1), 74 | w * num_cols + offset * (num_cols - 1), 75 | 3, 76 | ), 77 | dtype=np.uint8, 78 | ) 79 | * 255 80 | ) 81 | for i in range(num_rows): 82 | for j in range(num_cols): 83 | image_[ 84 | i * (h + offset) : i * (h + offset) + h :, 85 | j * (w + offset) : j * (w + offset) + w, 86 | ] = images[i * num_cols + j] 87 | 88 | pil_img = Image.fromarray(image_) 89 | 90 | return pil_img 91 | 92 | 93 | class AttentionControl(abc.ABC): 94 | def step_callback(self, x_t): 95 | return x_t 96 | 97 | def between_steps(self): 98 | return 99 | 100 | @property 101 | def num_uncond_att_layers(self): 102 | return 0 103 | 104 | @abc.abstractmethod 105 | def forward(self, attn, is_cross: bool, place_in_unet: str): 106 | raise NotImplementedError 107 | 108 | def __call__(self, attn, is_cross: bool, place_in_unet: str): 109 | if self.cur_att_layer >= self.num_uncond_att_layers: 110 | h = attn.shape[0] 111 | # attn[h // 2 :] = self.forward(attn[h // 2 :], is_cross, place_in_unet) 112 | # self.forward(attn[h // 2 :], is_cross, place_in_unet) 113 | self.forward(attn, is_cross, place_in_unet) 114 | self.cur_att_layer += 1 115 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: 116 | self.cur_att_layer = 0 117 | self.cur_step += 1 118 | self.between_steps() 119 | return attn 120 | 121 | def reset(self): 122 | self.cur_step = 0 123 | self.cur_att_layer = 0 124 | 125 | def __init__(self): 126 | self.cur_step = 0 127 | self.num_att_layers = -1 128 | self.cur_att_layer = 0 129 | 130 | 131 | class EmptyControl(AttentionControl): 132 | def forward(self, attn, is_cross: bool, place_in_unet: str): 133 | return attn 134 | 135 | 136 | class AttentionStore(AttentionControl): 137 | @staticmethod 138 | def get_empty_store(): 139 | return { 140 | "down_cross": [], 141 | "mid_cross": [], 142 | "up_cross": [], 143 | "down_self": [], 144 | "mid_self": [], 145 | "up_self": [], 146 | } 147 | 148 | def forward(self, attn, is_cross: bool, place_in_unet: str): 149 | key = 
f"{place_in_unet}_{'cross' if is_cross else 'self'}" 150 | if attn.shape[1] <= 32**2: 151 | self.step_store[key].append(attn) 152 | return attn 153 | 154 | def between_steps(self): 155 | if len(self.attention_store) == 0: 156 | self.attention_store = self.step_store 157 | else: 158 | for key in self.attention_store: 159 | for i in range(len(self.attention_store[key])): 160 | self.attention_store[key][i] += self.step_store[key][i] 161 | self.step_store = self.get_empty_store() 162 | 163 | def get_average_attention(self): 164 | average_attention = { 165 | key: [item / self.cur_step for item in self.attention_store[key]] 166 | for key in self.attention_store 167 | } 168 | return average_attention 169 | 170 | def reset(self): 171 | super(AttentionStore, self).reset() 172 | self.step_store = self.get_empty_store() 173 | self.attention_store = {} 174 | 175 | def __init__(self): 176 | super(AttentionStore, self).__init__() 177 | self.step_store = self.get_empty_store() 178 | self.attention_store = {} 179 | 180 | 181 | class LocalBlend: 182 | def __call__(self, x_t, attention_store): 183 | k = 1 184 | maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3] 185 | maps = [ 186 | item.reshape(self.alpha_layers.shape[0], -1, 1, 16, 16, self.max_num_words) 187 | for item in maps 188 | ] 189 | maps = torch.cat(maps, dim=1) 190 | maps = (maps * self.alpha_layers).sum(-1).mean(1) 191 | mask = nnf.max_pool2d(maps, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k)) 192 | mask = nnf.interpolate(mask, size=(x_t.shape[2:])) 193 | mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0] 194 | mask = mask.gt(self.threshold) 195 | mask = (mask[:1] + mask[1:]).float() 196 | x_t = x_t[:1] + mask * (x_t - x_t[:1]) 197 | return x_t 198 | 199 | def __init__( 200 | self, 201 | prompts: List[str], 202 | words: [List[List[str]]], 203 | tokenizer, 204 | device, 205 | threshold=0.3, 206 | max_num_words=77, 207 | ): 208 | self.max_num_words = 77 209 | 210 | alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, self.max_num_words) 211 | for i, (prompt, words_) in enumerate(zip(prompts, words)): 212 | if type(words_) is str: 213 | words_ = [words_] 214 | for word in words_: 215 | ind = get_word_inds(prompt, word, tokenizer) 216 | alpha_layers[i, :, :, :, :, ind] = 1 217 | self.alpha_layers = alpha_layers.to(device) 218 | self.threshold = threshold 219 | 220 | 221 | class AttentionControlEdit(AttentionStore, abc.ABC): 222 | def step_callback(self, x_t): 223 | if self.local_blend is not None: 224 | x_t = self.local_blend(x_t, self.attention_store) 225 | return x_t 226 | 227 | def replace_self_attention(self, attn_base, att_replace): 228 | if att_replace.shape[2] <= 16**2: 229 | return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape) 230 | else: 231 | return att_replace 232 | 233 | @abc.abstractmethod 234 | def replace_cross_attention(self, attn_base, att_replace): 235 | raise NotImplementedError 236 | 237 | def forward(self, attn, is_cross: bool, place_in_unet: str): 238 | super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet) 239 | 240 | if is_cross or ( 241 | self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1] 242 | ): 243 | h = attn.shape[0] // (self.batch_size) 244 | attn = attn.reshape(self.batch_size, h, *attn.shape[1:]) 245 | attn_base, attn_repalce = attn[0], attn[1:] 246 | if is_cross: 247 | alpha_words = self.cross_replace_alpha[self.cur_step] 248 | attn_repalce_new = ( 249 | self.replace_cross_attention(attn_base, attn_repalce) * 
alpha_words 250 | + (1 - alpha_words) * attn_repalce 251 | ) 252 | attn[1:] = attn_repalce_new 253 | else: 254 | attn[1:] = self.replace_self_attention(attn_base, attn_repalce) 255 | attn = attn.reshape(self.batch_size * h, *attn.shape[2:]) 256 | return attn 257 | 258 | def __init__( 259 | self, 260 | prompts, 261 | num_steps: int, 262 | cross_replace_steps: Union[ 263 | float, Tuple[float, float], Dict[str, Tuple[float, float]] 264 | ], 265 | self_replace_steps: Union[float, Tuple[float, float]], 266 | local_blend: Optional[LocalBlend], 267 | tokenizer, 268 | device, 269 | ): 270 | super(AttentionControlEdit, self).__init__() 271 | 272 | self.tokenizer = tokenizer 273 | self.device = device 274 | 275 | self.batch_size = len(prompts) 276 | self.cross_replace_alpha = get_time_words_attention_alpha( 277 | prompts, num_steps, cross_replace_steps, self.tokenizer 278 | ).to(self.device) 279 | if type(self_replace_steps) is float: 280 | self_replace_steps = 0, self_replace_steps 281 | self.num_self_replace = int(num_steps * self_replace_steps[0]), int( 282 | num_steps * self_replace_steps[1] 283 | ) 284 | self.local_blend = local_blend 285 | 286 | 287 | class AttentionReplace(AttentionControlEdit): 288 | def replace_cross_attention(self, attn_base, att_replace): 289 | return torch.einsum("hpw,bwn->bhpn", attn_base, self.mapper) 290 | 291 | def __init__( 292 | self, 293 | prompts, 294 | num_steps: int, 295 | cross_replace_steps: float, 296 | self_replace_steps: float, 297 | local_blend: Optional[LocalBlend] = None, 298 | tokenizer=None, 299 | device=None, 300 | ): 301 | super(AttentionReplace, self).__init__( 302 | prompts, 303 | num_steps, 304 | cross_replace_steps, 305 | self_replace_steps, 306 | local_blend, 307 | tokenizer, 308 | device, 309 | ) 310 | self.mapper = get_replacement_mapper(prompts, self.tokenizer).to(self.device) 311 | 312 | 313 | class AttentionRefine(AttentionControlEdit): 314 | def replace_cross_attention(self, attn_base, att_replace): 315 | attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3) 316 | attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas) 317 | return attn_replace 318 | 319 | def __init__( 320 | self, 321 | prompts, 322 | num_steps: int, 323 | cross_replace_steps: float, 324 | self_replace_steps: float, 325 | local_blend: Optional[LocalBlend] = None, 326 | tokenizer=None, 327 | device=None, 328 | ): 329 | super(AttentionRefine, self).__init__( 330 | prompts, 331 | num_steps, 332 | cross_replace_steps, 333 | self_replace_steps, 334 | local_blend, 335 | tokenizer, 336 | device, 337 | ) 338 | self.mapper, alphas = get_refinement_mapper(prompts, self.tokenizer) 339 | self.mapper, alphas = self.mapper.to(self.device), alphas.to(self.device) 340 | self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1]) 341 | 342 | 343 | class AttentionReweight(AttentionControlEdit): 344 | def replace_cross_attention(self, attn_base, att_replace): 345 | if self.prev_controller is not None: 346 | attn_base = self.prev_controller.replace_cross_attention( 347 | attn_base, att_replace 348 | ) 349 | attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :] 350 | return attn_replace 351 | 352 | def __init__( 353 | self, 354 | prompts, 355 | num_steps: int, 356 | cross_replace_steps: float, 357 | self_replace_steps: float, 358 | equalizer, 359 | local_blend: Optional[LocalBlend] = None, 360 | controller: Optional[AttentionControlEdit] = None, 361 | tokenizer=None, 362 | device=None, 363 | ): 364 | 
super(AttentionReweight, self).__init__( 365 | prompts, 366 | num_steps, 367 | cross_replace_steps, 368 | self_replace_steps, 369 | local_blend, 370 | tokenizer, 371 | device, 372 | ) 373 | self.equalizer = equalizer.to(self.device) 374 | self.prev_controller = controller 375 | 376 | 377 | def get_equalizer( 378 | text: str, 379 | word_select: Union[int, Tuple[int, ...]], 380 | values: Union[List[float], Tuple[float, ...]], 381 | tokenizer, 382 | ): 383 | if type(word_select) is int or type(word_select) is str: 384 | word_select = (word_select,) 385 | equalizer = torch.ones(len(values), 77) 386 | values = torch.tensor(values, dtype=torch.float32) 387 | for word in word_select: 388 | inds = get_word_inds(text, word, tokenizer) 389 | equalizer[:, inds] = values 390 | return equalizer 391 | 392 | 393 | def update_alpha_time_word( 394 | alpha, 395 | bounds: Union[float, Tuple[float, float]], 396 | prompt_ind: int, 397 | word_inds: Optional[torch.Tensor] = None, 398 | ): 399 | if type(bounds) is float: 400 | bounds = 0, bounds 401 | start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0]) 402 | if word_inds is None: 403 | word_inds = torch.arange(alpha.shape[2]) 404 | alpha[:start, prompt_ind, word_inds] = 0 405 | alpha[start:end, prompt_ind, word_inds] = 1 406 | alpha[end:, prompt_ind, word_inds] = 0 407 | return alpha 408 | 409 | 410 | def get_time_words_attention_alpha( 411 | prompts, 412 | num_steps, 413 | cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]], 414 | tokenizer, 415 | max_num_words=77, 416 | ): 417 | if type(cross_replace_steps) is not dict: 418 | cross_replace_steps = {"default_": cross_replace_steps} 419 | if "default_" not in cross_replace_steps: 420 | cross_replace_steps["default_"] = (0.0, 1.0) 421 | alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words) 422 | for i in range(len(prompts) - 1): 423 | alpha_time_words = update_alpha_time_word( 424 | alpha_time_words, cross_replace_steps["default_"], i 425 | ) 426 | for key, item in cross_replace_steps.items(): 427 | if key != "default_": 428 | inds = [ 429 | get_word_inds(prompts[i], key, tokenizer) 430 | for i in range(1, len(prompts)) 431 | ] 432 | for i, ind in enumerate(inds): 433 | if len(ind) > 0: 434 | alpha_time_words = update_alpha_time_word( 435 | alpha_time_words, item, i, ind 436 | ) 437 | alpha_time_words = alpha_time_words.reshape( 438 | num_steps + 1, len(prompts) - 1, 1, 1, max_num_words 439 | ) 440 | return alpha_time_words 441 | 442 | 443 | class ScoreParams: 444 | def __init__(self, gap, match, mismatch): 445 | self.gap = gap 446 | self.match = match 447 | self.mismatch = mismatch 448 | 449 | def mis_match_char(self, x, y): 450 | if x != y: 451 | return self.mismatch 452 | else: 453 | return self.match 454 | 455 | 456 | def get_matrix(size_x, size_y, gap): 457 | matrix = [] 458 | for i in range(len(size_x) + 1): 459 | sub_matrix = [] 460 | for j in range(len(size_y) + 1): 461 | sub_matrix.append(0) 462 | matrix.append(sub_matrix) 463 | for j in range(1, len(size_y) + 1): 464 | matrix[0][j] = j * gap 465 | for i in range(1, len(size_x) + 1): 466 | matrix[i][0] = i * gap 467 | return matrix 468 | 469 | 470 | def get_matrix(size_x, size_y, gap): 471 | matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32) 472 | matrix[0, 1:] = (np.arange(size_y) + 1) * gap 473 | matrix[1:, 0] = (np.arange(size_x) + 1) * gap 474 | return matrix 475 | 476 | 477 | def get_traceback_matrix(size_x, size_y): 478 | matrix = np.zeros((size_x + 1, size_y + 1), 
dtype=np.int32) 479 | matrix[0, 1:] = 1 480 | matrix[1:, 0] = 2 481 | matrix[0, 0] = 4 482 | return matrix 483 | 484 | 485 | def global_align(x, y, score): 486 | matrix = get_matrix(len(x), len(y), score.gap) 487 | trace_back = get_traceback_matrix(len(x), len(y)) 488 | for i in range(1, len(x) + 1): 489 | for j in range(1, len(y) + 1): 490 | left = matrix[i, j - 1] + score.gap 491 | up = matrix[i - 1, j] + score.gap 492 | diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1]) 493 | matrix[i, j] = max(left, up, diag) 494 | if matrix[i, j] == left: 495 | trace_back[i, j] = 1 496 | elif matrix[i, j] == up: 497 | trace_back[i, j] = 2 498 | else: 499 | trace_back[i, j] = 3 500 | return matrix, trace_back 501 | 502 | 503 | def get_aligned_sequences(x, y, trace_back): 504 | x_seq = [] 505 | y_seq = [] 506 | i = len(x) 507 | j = len(y) 508 | mapper_y_to_x = [] 509 | while i > 0 or j > 0: 510 | if trace_back[i, j] == 3: 511 | x_seq.append(x[i - 1]) 512 | y_seq.append(y[j - 1]) 513 | i = i - 1 514 | j = j - 1 515 | mapper_y_to_x.append((j, i)) 516 | elif trace_back[i][j] == 1: 517 | x_seq.append("-") 518 | y_seq.append(y[j - 1]) 519 | j = j - 1 520 | mapper_y_to_x.append((j, -1)) 521 | elif trace_back[i][j] == 2: 522 | x_seq.append(x[i - 1]) 523 | y_seq.append("-") 524 | i = i - 1 525 | elif trace_back[i][j] == 4: 526 | break 527 | mapper_y_to_x.reverse() 528 | return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64) 529 | 530 | 531 | def get_mapper(x: str, y: str, tokenizer, max_len=77): 532 | x_seq = tokenizer.encode(x) 533 | y_seq = tokenizer.encode(y) 534 | score = ScoreParams(0, 1, -1) 535 | matrix, trace_back = global_align(x_seq, y_seq, score) 536 | mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1] 537 | alphas = torch.ones(max_len) 538 | alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float() 539 | mapper = torch.zeros(max_len, dtype=torch.int64) 540 | mapper[: mapper_base.shape[0]] = mapper_base[:, 1] 541 | mapper[mapper_base.shape[0] :] = len(y_seq) + torch.arange(max_len - len(y_seq)) 542 | return mapper, alphas 543 | 544 | 545 | def get_refinement_mapper(prompts, tokenizer, max_len=77): 546 | x_seq = prompts[0] 547 | mappers, alphas = [], [] 548 | for i in range(1, len(prompts)): 549 | mapper, alpha = get_mapper(x_seq, prompts[i], tokenizer, max_len) 550 | mappers.append(mapper) 551 | alphas.append(alpha) 552 | return torch.stack(mappers), torch.stack(alphas) 553 | 554 | 555 | def get_word_inds(text: str, word_place: int, tokenizer): 556 | split_text = text.split(" ") 557 | if type(word_place) is str: 558 | word_place = [i for i, word in enumerate(split_text) if word_place == word] 559 | elif type(word_place) is int: 560 | word_place = [word_place] 561 | out = [] 562 | if len(word_place) > 0: 563 | words_encode = [ 564 | tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text) 565 | ][1:-1] 566 | cur_len, ptr = 0, 0 567 | 568 | for i in range(len(words_encode)): 569 | cur_len += len(words_encode[i]) 570 | if ptr in word_place: 571 | out.append(i + 1) 572 | if cur_len >= len(split_text[ptr]): 573 | ptr += 1 574 | cur_len = 0 575 | return np.array(out) 576 | 577 | 578 | def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77): 579 | words_x = x.split(" ") 580 | words_y = y.split(" ") 581 | if len(words_x) != len(words_y): 582 | raise ValueError( 583 | f"attention replacement edit can only be applied on prompts with the same length" 584 | f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} 
words." 585 | ) 586 | inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]] 587 | inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace] 588 | inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace] 589 | mapper = np.zeros((max_len, max_len)) 590 | i = j = 0 591 | cur_inds = 0 592 | while i < max_len and j < max_len: 593 | if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i: 594 | inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds] 595 | if len(inds_source_) == len(inds_target_): 596 | mapper[inds_source_, inds_target_] = 1 597 | else: 598 | ratio = 1 / len(inds_target_) 599 | for i_t in inds_target_: 600 | mapper[inds_source_, i_t] = ratio 601 | cur_inds += 1 602 | i += len(inds_source_) 603 | j += len(inds_target_) 604 | elif cur_inds < len(inds_source): 605 | mapper[i, j] = 1 606 | i += 1 607 | j += 1 608 | else: 609 | mapper[j, j] = 1 610 | i += 1 611 | j += 1 612 | 613 | return torch.from_numpy(mapper).float() 614 | 615 | 616 | def get_replacement_mapper(prompts, tokenizer, max_len=77): 617 | x_seq = prompts[0] 618 | mappers = [] 619 | for i in range(1, len(prompts)): 620 | mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len) 621 | mappers.append(mapper) 622 | return torch.stack(mappers) 623 | -------------------------------------------------------------------------------- /scripts/inference.sh: -------------------------------------------------------------------------------- 1 | MODEL_PATH="outputs/magictailor" 2 | PROMPT=" with , on the beach" 3 | OUTPUT_PATH="outputs/inference/result.jpg" 4 | 5 | python inference.py \ 6 | --model_path $MODEL_PATH \ 7 | --prompt $PROMPT \ 8 | --output_path $OUTPUT_PATH 9 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | INPUT_DIR="examples/person_k+hair_c" 2 | OUTPUT_DIR="outputs/magictailor" 3 | 4 | LR1=1e-4 5 | LR2=1e-5 6 | STEP1=200 7 | STEP2=300 8 | LORA_RANK=32 9 | LA=1e-2 10 | ALPHA=0.5 11 | GAMMA=32 12 | ED=0.99 13 | LP=0.2 14 | 15 | python train.py \ 16 | --seed 0 \ 17 | --mixed_precision fp16 \ 18 | --dataloader_num_workers 8 \ 19 | --pretrained_model_name_or_path stabilityai/stable-diffusion-2-1-base \ 20 | --instance_data_dir $INPUT_DIR \ 21 | --output_dir $OUTPUT_DIR \ 22 | --scale_lr \ 23 | --gsam_repo_dir Grounded-Segment-Anything \ 24 | --phase1_train_steps $STEP1 \ 25 | --phase2_train_steps $STEP2 \ 26 | --phase1_learning_rate $LR1 \ 27 | --phase2_learning_rate $LR2 \ 28 | --lora_rank $LORA_RANK \ 29 | --alpha $ALPHA \ 30 | --gamma $GAMMA \ 31 | --ema_decay $ED \ 32 | --lambda_attention $LA \ 33 | --lambda_preservation $LP \ 34 | --placeholder_token "" \ 35 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | from mask_generation import * -------------------------------------------------------------------------------- /tools/mask_generation.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import supervision as sv 4 | import os 5 | import argparse 6 | import warnings 7 | from collections import OrderedDict 8 | from tqdm import tqdm 9 | import sys 10 | 11 | import torch 12 | import torchvision 13 | 14 | from groundingdino.util.inference import Model 15 | from segment_anything import 
sam_model_registry, SamPredictor 16 | 17 | 18 | GDINO_CONFIG_RELATIVE_PATH = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py" 19 | GDINO_CKPT_RELATIVE_PATH = "groundingdino_swint_ogc.pth" 20 | SAM_ENCODER_VERSION = "vit_h" 21 | SAM_CKPT_RELATIVE_PATH = "sam_vit_h_4b8939.pth" 22 | SPEC_CLASS_BOX_MAX_NUMS = { 23 | "brow": 2, 24 | "eyebrow": 2, 25 | "eye": 2, 26 | "ear": 2, 27 | "wheel": 2, 28 | "window": 3, 29 | } 30 | 31 | 32 | def check_mask_existence(image_dir): 33 | assert os.path.isdir(image_dir) 34 | mask_dir = os.path.join(image_dir, 'masks') 35 | if not os.path.isdir(mask_dir): 36 | return False 37 | image_basenames = set( 38 | os.path.splitext(f)[0] 39 | for f in os.listdir(image_dir) 40 | if os.path.isfile(os.path.join(image_dir, f)) and f.endswith(('.png', '.jpg', '.jpeg')) 41 | ) 42 | mask_basenames = set( 43 | os.path.splitext(f)[0] 44 | for f in os.listdir(mask_dir) 45 | if os.path.isfile(os.path.join(mask_dir, f)) and f.endswith(('.png', '.jpg', '.jpeg')) 46 | ) 47 | if image_basenames == mask_basenames: 48 | return True 49 | else: 50 | return False 51 | 52 | 53 | # Prompting SAM with detected boxes 54 | def segment(sam_predictor: SamPredictor, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray: 55 | sam_predictor.set_image(image) 56 | result_masks = [] 57 | for box in xyxy: 58 | masks, scores, logits = sam_predictor.predict( 59 | box=box, 60 | multimask_output=True 61 | ) 62 | index = np.argmax(scores) 63 | result_masks.append(masks[index]) 64 | return np.array(result_masks) 65 | 66 | 67 | def get_class_names(image_dir, no_redundant=False, class_nums=-1): 68 | filenames = [ 69 | f for f in sorted(os.listdir(image_dir)) 70 | if os.path.isfile(os.path.join(image_dir, f)) and f.endswith(('.png', '.jpg', '.jpeg')) 71 | ] 72 | class_names = [ 73 | filename.split('_')[1] for filename in filenames 74 | ] 75 | if no_redundant: 76 | class_names = list(OrderedDict.fromkeys(class_names)) 77 | if class_nums != -1: 78 | class_names = class_names[:class_nums] 79 | 80 | return class_names, filenames 81 | 82 | 83 | def save_mask_image(mask_image, image_name, mask_dir): 84 | mask_image = np.where(mask_image, 255, 0).astype(np.uint8) 85 | mask_image = np.transpose(mask_image, (1, 2, 0)) 86 | cv2.imwrite(os.path.join(mask_dir, image_name+".png"), mask_image) 87 | 88 | 89 | def generate_masks(args, grounding_dino_model, sam_predictor, image_dir, save_logs=True, check_existence=False): 90 | 91 | tqdm.write("-" * 50) 92 | tqdm.write(f"Processing: {image_dir}") 93 | 94 | mask_dir = os.path.join(image_dir, 'masks') 95 | mask_others_dir = os.path.join(mask_dir, "others") 96 | os.makedirs(mask_dir, exist_ok=True) 97 | os.makedirs(mask_others_dir, exist_ok=True) 98 | 99 | if check_existence and check_mask_existence(image_dir): 100 | tqdm.write("Masks alreadly exist.") 101 | return 102 | 103 | # warnings.filterwarnings("ignore") 104 | 105 | # get class names for segmentation 106 | class_names, filenames = get_class_names(image_dir) 107 | assert len(class_names) > 1 108 | seg_class_names = [] 109 | concept_name = class_names[0] 110 | for name_i in class_names: 111 | seg_class_names.append([name_i]) 112 | if name_i == concept_name: 113 | for name_j in class_names: 114 | if name_j != concept_name and name_j not in seg_class_names[-1]: 115 | seg_class_names[-1].append(name_j) 116 | tqdm.write(f"seg_class_names: {seg_class_names}") 117 | 118 | for i, filename in enumerate(filenames): 119 | 120 | file_path = os.path.join(image_dir, filename) 121 | image_name = os.path.splitext(filename)[0] 122 
| image = cv2.imread(file_path) 123 | 124 | classes = seg_class_names[i] 125 | 126 | # detect objects 127 | detections = grounding_dino_model.predict_with_classes( 128 | image=image, 129 | classes=classes, 130 | box_threshold=args.box_threshold, 131 | text_threshold=args.text_threshold 132 | ) 133 | 134 | # annotate image with detections 135 | box_annotator = sv.BoxAnnotator() 136 | 137 | # NMS post process 138 | # print(f"Before NMS: {len(detections.xyxy)} boxes") 139 | nms_idx = torchvision.ops.nms( 140 | torch.from_numpy(detections.xyxy), 141 | torch.from_numpy(detections.confidence), 142 | args.nms_threshold 143 | ).numpy().tolist() 144 | detections.xyxy = detections.xyxy[nms_idx] 145 | detections.confidence = detections.confidence[nms_idx] 146 | detections.class_id = detections.class_id[nms_idx] 147 | # print(f"After NMS: {len(detections.xyxy)} boxes") 148 | 149 | # get top-K boxes 150 | topk_nums = [ 151 | SPEC_CLASS_BOX_MAX_NUMS[c] if c in SPEC_CLASS_BOX_MAX_NUMS else 1 152 | for c in classes 153 | ] 154 | topk_idx = [] 155 | for j, id in enumerate(set(detections.class_id)): 156 | k = topk_nums[j] 157 | id_idx = np.where(detections.class_id == id)[0] 158 | id_confidence = detections.confidence[id_idx] 159 | topk_idx.append(id_idx[np.argsort(id_confidence)[-k:]]) 160 | topk_idx = np.hstack(topk_idx) 161 | detections.xyxy = detections.xyxy[topk_idx] 162 | detections.confidence = detections.confidence[topk_idx] 163 | detections.class_id = detections.class_id[topk_idx] 164 | 165 | # convert detections to masks 166 | detections.mask = segment( 167 | sam_predictor=sam_predictor, 168 | image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB), 169 | xyxy=detections.xyxy 170 | ) 171 | 172 | # save the mask image 173 | masks = [] 174 | for j, id in enumerate(set(detections.class_id)): 175 | id_idx = np.where(detections.class_id == id)[0] 176 | masks.append(np.any(detections.mask[id_idx], axis=0, keepdims=True)) 177 | masks = np.concatenate(masks) 178 | if masks.shape[0] > 1: 179 | masks[1:, :, :] = np.logical_not(masks[1:, :, :]) 180 | mask_image_w_comp = np.expand_dims(masks[0], axis=0) # mask w/ component 181 | mask_image = np.all(masks, axis=0, keepdims=True) # mask wo/ component 182 | save_mask_image(mask_image, image_name, mask_dir) 183 | save_mask_image(mask_image_w_comp, image_name, mask_others_dir) 184 | else: 185 | mask_image = masks 186 | save_mask_image(mask_image, image_name, mask_dir) 187 | 188 | # save logs 189 | if save_logs: 190 | # init log dir 191 | log_dir = os.path.join(image_dir, 'logs') 192 | os.makedirs(log_dir, exist_ok=True) 193 | 194 | # get the annotated image of grounding dino 195 | labels = [ 196 | f"{classes[class_id]} {confidence:0.2f}" 197 | for _, _, confidence, class_id, _, _ 198 | in detections] 199 | annotated_frame = box_annotator.annotate(scene=image.copy(), detections=detections, labels=labels) 200 | # cv2.imwrite(os.path.join(log_dir, image_name+"_gdino.jpg"), annotated_frame) 201 | 202 | # get the annotated image of grounding-SAM 203 | box_annotator = sv.BoxAnnotator() 204 | mask_annotator = sv.MaskAnnotator() 205 | annotated_image = mask_annotator.annotate(scene=image.copy(), detections=detections) 206 | annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections, labels=labels) 207 | cv2.imwrite(os.path.join(log_dir, image_name+"_gsam.jpg"), annotated_image) 208 | 209 | # warnings.resetwarnings() 210 | 211 | def get_gdino_and_sam_model(args, device): 212 | # GroundingDINO config and checkpoint 213 | gdino_config_path = 
os.path.join(args.gsam_repo_dir, GDINO_CONFIG_RELATIVE_PATH) 214 | gdino_ckpt_path = os.path.join(args.gsam_repo_dir, GDINO_CKPT_RELATIVE_PATH) 215 | 216 | # Segment-Anything checkpoint 217 | sam_ckpt_path = os.path.join(args.gsam_repo_dir, SAM_CKPT_RELATIVE_PATH) 218 | 219 | # Building GroundingDINO inference model 220 | grounding_dino_model = Model(model_config_path=gdino_config_path, model_checkpoint_path=gdino_ckpt_path) 221 | 222 | # Building SAM Model and SAM Predictor 223 | sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=sam_ckpt_path) 224 | sam.to(device=device) 225 | sam.eval() 226 | sam_predictor = SamPredictor(sam) 227 | 228 | return grounding_dino_model, sam_predictor 229 | 230 | if __name__ == "__main__": 231 | 232 | parser = argparse.ArgumentParser("Mask Generation with G-SAM", add_help=True) 233 | 234 | parser.add_argument( 235 | "--gsam_repo_dir", default="Grounded-Segment-Anything", 236 | type=str, help="dir to gsam repo", 237 | # required=True, 238 | ) 239 | parser.add_argument( 240 | "--dataset_dir", default="dataset_test", # ./tailorbench 241 | type=str, help="dir to the dataset", 242 | # required=True, 243 | ) 244 | 245 | parser.add_argument("--box_threshold", type=float, default=0.25, help="box threshold") 246 | parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold") 247 | parser.add_argument("--nms_threshold", type=float, default=0.8, help="nms threshold") 248 | 249 | args = parser.parse_args() 250 | 251 | print("Start geneartive masks for dataset ...") 252 | print(f"Dataset directory: {args.dataset_dir}") 253 | 254 | warnings.filterwarnings("ignore") 255 | 256 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 257 | grounding_dino_model, sam_predictor = get_gdino_and_sam_model(args, device) 258 | 259 | subfolders = [ 260 | f for f in sorted(os.listdir(args.dataset_dir)) 261 | if os.path.isdir(os.path.join(args.dataset_dir, f)) 262 | ] 263 | 264 | for subfolder in tqdm(subfolders, file=sys.stdout, desc="Pair Progress"): 265 | subfolder_dir = os.path.join(args.dataset_dir, subfolder) 266 | generate_masks(args, grounding_dino_model, sam_predictor, subfolder_dir, check_existence=True) 267 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Below is the copyright notice from Google. 3 | 4 | Please also follow this license when you modify or distribute the code. 5 | """ 6 | 7 | """ 8 | Copyright 2023 Google LLC 9 | 10 | Licensed under the Apache License, Version 2.0 (the "License"); 11 | you may not use this file except in compliance with the License. 12 | You may obtain a copy of the License at 13 | 14 | https://www.apache.org/licenses/LICENSE-2.0 15 | 16 | Unless required by applicable law or agreed to in writing, software 17 | distributed under the License is distributed on an "AS IS" BASIS, 18 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 19 | See the License for the specific language governing permissions and 20 | limitations under the License. 
21 | """ 22 | 23 | import argparse 24 | import hashlib 25 | import itertools 26 | import logging 27 | import math 28 | import os 29 | import warnings 30 | from pathlib import Path 31 | from typing import List, Optional 32 | import random 33 | 34 | import torch 35 | import torch.nn as nn 36 | import torch.nn.functional as F 37 | import torchvision.transforms.functional as TF 38 | import torch.utils.checkpoint 39 | from torch.utils.data import Dataset 40 | import numpy as np 41 | 42 | import datasets 43 | import diffusers 44 | import transformers 45 | from accelerate import Accelerator 46 | from accelerate.logging import get_logger 47 | from accelerate.utils import ProjectConfiguration, set_seed 48 | from diffusers import ( 49 | AutoencoderKL, 50 | DDPMScheduler, 51 | DiffusionPipeline, 52 | StableDiffusionPipeline, 53 | UNet2DConditionModel, 54 | DDIMScheduler, 55 | ) 56 | from diffusers.optimization import get_scheduler 57 | from diffusers.utils import check_min_version 58 | from diffusers.utils.import_utils import is_xformers_available 59 | from PIL import Image 60 | from torchvision import transforms 61 | from tqdm.auto import tqdm 62 | from transformers import AutoTokenizer, PretrainedConfig 63 | import ptp_utils 64 | from ptp_utils import AttentionStore 65 | # from diffusers.models.cross_attention import CrossAttention 66 | from diffusers.models.attention import Attention as CrossAttention, FeedForward, AdaLayerNorm 67 | 68 | import cv2 69 | 70 | from diffusers.loaders import LoraLoaderMixin 71 | from peft import LoraConfig, get_peft_model 72 | from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict 73 | from diffusers.utils import ( 74 | check_min_version, 75 | convert_state_dict_to_diffusers, 76 | convert_unet_state_dict_to_peft, 77 | ) 78 | from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params 79 | from diffusers.utils.torch_utils import is_compiled_module 80 | 81 | from typing import Any, List, Optional, Union 82 | import math 83 | 84 | from tools.mask_generation import check_mask_existence, generate_masks, get_gdino_and_sam_model, get_class_names 85 | 86 | 87 | check_min_version("0.12.0") 88 | 89 | logger = get_logger(__name__) 90 | 91 | 92 | def import_model_class_from_model_name_or_path( 93 | pretrained_model_name_or_path: str, revision: str 94 | ): 95 | text_encoder_config = PretrainedConfig.from_pretrained( 96 | pretrained_model_name_or_path, 97 | subfolder="text_encoder", 98 | revision=revision, 99 | ) 100 | model_class = text_encoder_config.architectures[0] 101 | 102 | if model_class == "CLIPTextModel": 103 | from transformers import CLIPTextModel 104 | 105 | return CLIPTextModel 106 | elif model_class == "RobertaSeriesModelWithTransformation": 107 | from diffusers.pipelines.alt_diffusion.modeling_roberta_series import ( 108 | RobertaSeriesModelWithTransformation, 109 | ) 110 | 111 | return RobertaSeriesModelWithTransformation 112 | else: 113 | raise ValueError(f"{model_class} is not supported.") 114 | 115 | 116 | def parse_args(input_args=None): 117 | parser = argparse.ArgumentParser( 118 | description="Simple example of a training script.") 119 | 120 | # data 121 | parser.add_argument( 122 | "--instance_data_dir", 123 | type=str, 124 | default=None, 125 | required=True, 126 | help="A folder containing the training data of instance images.", 127 | ) 128 | parser.add_argument( 129 | "--output_dir", 130 | type=str, 131 | default="outputs/magictailor", 132 | help="The output directory where the model predictions and 
checkpoints will be written.", 133 | ) 134 | 135 | # pipeline 136 | parser.add_argument( 137 | "--phase1_train_steps", 138 | type=int, 139 | default=200, 140 | help="Number of trainig steps for the first phase (warm-up).", 141 | ) 142 | parser.add_argument( 143 | "--phase2_train_steps", 144 | type=int, 145 | default=300, 146 | help="Number of trainig steps for the second phase (DS-Bal).", 147 | ) 148 | parser.add_argument( 149 | "--phase1_learning_rate", 150 | type=float, 151 | default=1e-4, 152 | help="Learning rate for the first phase (warm-up).", 153 | ) 154 | parser.add_argument( 155 | "--phase2_learning_rate", 156 | type=float, 157 | default=1e-5, 158 | help="Learning rate for the second phase (DS-Bal).", 159 | ) 160 | parser.add_argument( 161 | "--scale_lr", 162 | action="store_true", 163 | default=False, 164 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 165 | ) 166 | parser.add_argument("--lora_rank", type=int, default=32) 167 | 168 | # cross-attention loss 169 | parser.add_argument("--lambda_attention", type=float, default=1e-2) 170 | 171 | # DM-Deg 172 | parser.add_argument("--alpha", type=float, default=0.5) 173 | parser.add_argument("--gamma", type=float, default=32) 174 | 175 | # DS-Bal 176 | parser.add_argument("--ema_decay", type=float, default=0.99) 177 | parser.add_argument("--lambda_preservation", type=float, default=0.2) 178 | 179 | # ckpt 180 | parser.add_argument("--checkpoint_dir", type=str) 181 | parser.add_argument( 182 | "--resume_from_checkpoint", 183 | type=str, 184 | default=None, 185 | help= 186 | ("Whether training should be resumed from a previous checkpoint. Use a path saved by" 187 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 188 | ), 189 | ) 190 | parser.add_argument( 191 | "--checkpointing_steps", 192 | type=int, 193 | default=5000, 194 | help= 195 | ("Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" 196 | " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" 197 | " training using `--resume_from_checkpoint`."), 198 | ) 199 | 200 | # seed 201 | parser.add_argument( 202 | "--seed", 203 | type=int, 204 | default=0, 205 | help="A seed for reproducible training." 206 | ) 207 | 208 | # resolution 209 | parser.add_argument( 210 | "--resolution", 211 | type=int, 212 | default=512, 213 | help= 214 | ("The resolution for input images" 215 | " resolution"), 216 | ) 217 | 218 | # model 219 | parser.add_argument( 220 | "--pretrained_model_name_or_path", 221 | type=str, 222 | default="stabilityai/stable-diffusion-2-1-base", 223 | help= 224 | "Path to pretrained model or model identifier from huggingface.co/models.", 225 | ) 226 | parser.add_argument( 227 | "--revision", 228 | type=str, 229 | default=None, 230 | required=False, 231 | help= 232 | ("Revision of pretrained model identifier from huggingface.co/models. 
Trainable model components should be" 233 | " float32 precision."), 234 | ) 235 | parser.add_argument( 236 | "--tokenizer_name", 237 | type=str, 238 | default=None, 239 | help="Pretrained tokenizer name or path if not the same as model_name", 240 | ) 241 | 242 | # Grounding SAM 243 | parser.add_argument( 244 | "--gsam_repo_dir", 245 | default="Grounded-Segment-Anything", 246 | type=str, 247 | help="dir to gsam repo", 248 | ) 249 | parser.add_argument("--box_threshold", 250 | type=float, 251 | default=0.25, 252 | help="box threshold") 253 | parser.add_argument("--text_threshold", 254 | type=float, 255 | default=0.25, 256 | help="text threshold") 257 | parser.add_argument("--nms_threshold", 258 | type=float, 259 | default=0.8, 260 | help="nms threshold") 261 | 262 | # logging 263 | parser.add_argument( 264 | "--log_checkpoints", 265 | action="store_true", 266 | help="Indicator to log intermediate model checkpoints", 267 | ) 268 | parser.add_argument("--img_log_steps", type=int, default=500) 269 | parser.add_argument( 270 | "--logging_dir", 271 | type=str, 272 | default="logs", 273 | help= 274 | ("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 275 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."), 276 | ) 277 | parser.add_argument( 278 | "--report_to", 279 | type=str, 280 | default="tensorboard", 281 | help= 282 | ('The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 283 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 284 | ), 285 | ) 286 | 287 | # prompt 288 | parser.add_argument( 289 | "--placeholder_token", 290 | type=str, 291 | default="", 292 | help="A token to use as a placeholder for the concept.", 293 | ) 294 | parser.add_argument("--inference_prompt", default=None, type=str) 295 | 296 | # training 297 | parser.add_argument( 298 | "--gradient_accumulation_steps", 299 | type=int, 300 | default=1, 301 | help= 302 | "Number of updates steps to accumulate before performing a backward/update pass.", 303 | ) 304 | parser.add_argument( 305 | "--gradient_checkpointing", 306 | action="store_true", 307 | help= 308 | "Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 309 | ) 310 | parser.add_argument( 311 | "--lr_scheduler", 312 | type=str, 313 | default="constant", 314 | help= 315 | ('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 316 | ' "constant", "constant_with_warmup"]'), 317 | ) 318 | parser.add_argument( 319 | "--lr_warmup_steps", 320 | type=int, 321 | default=0, 322 | help="Number of steps for the warmup in the lr scheduler.", 323 | ) 324 | parser.add_argument( 325 | "--lr_num_cycles", 326 | type=int, 327 | default=1, 328 | help= 329 | "Number of hard resets of the lr in cosine_with_restarts scheduler.", 330 | ) 331 | parser.add_argument( 332 | "--lr_power", 333 | type=float, 334 | default=1.0, 335 | help="Power factor of the polynomial scheduler.", 336 | ) 337 | parser.add_argument( 338 | "--use_8bit_adam", 339 | action="store_true", 340 | help="Whether or not to use 8-bit Adam from bitsandbytes.", 341 | ) 342 | parser.add_argument( 343 | "--dataloader_num_workers", 344 | type=int, 345 | default=8, 346 | help= 347 | ("Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 
348 | ), 349 | ) 350 | parser.add_argument( 351 | "--adam_beta1", 352 | type=float, 353 | default=0.9, 354 | help="The beta1 parameter for the Adam optimizer.", 355 | ) 356 | parser.add_argument( 357 | "--adam_beta2", 358 | type=float, 359 | default=0.999, 360 | help="The beta2 parameter for the Adam optimizer.", 361 | ) 362 | parser.add_argument("--adam_weight_decay", 363 | type=float, 364 | default=1e-2, 365 | help="Weight decay to use.") 366 | parser.add_argument( 367 | "--adam_epsilon", 368 | type=float, 369 | default=1e-08, 370 | help="Epsilon value for the Adam optimizer", 371 | ) 372 | parser.add_argument("--max_grad_norm", 373 | default=1.0, 374 | type=float, 375 | help="Max gradient norm.") 376 | parser.add_argument( 377 | "--local_rank", 378 | type=int, 379 | default=-1, 380 | help="For distributed training: local_rank", 381 | ) 382 | parser.add_argument( 383 | "--enable_xformers_memory_efficient_attention", 384 | action="store_true", 385 | help="Whether or not to use xformers.", 386 | ) 387 | parser.add_argument( 388 | "--set_grads_to_none", 389 | action="store_true", 390 | help= 391 | ("Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" 392 | " behaviors, so disable this argument if it causes any problems. More info:" 393 | " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" 394 | ), 395 | ) 396 | 397 | # data type 398 | parser.add_argument( 399 | "--allow_tf32", 400 | action="store_true", 401 | help= 402 | ("Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 403 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 404 | ), 405 | ) 406 | parser.add_argument( 407 | "--mixed_precision", 408 | type=str, 409 | default="fp16", 410 | choices=["no", "fp16", "bf16"], 411 | help= 412 | ("Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 413 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 414 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 
415 | ), 416 | ) 417 | 418 | # for debugging (do not modify them, some of them will be overwritten) 419 | parser.add_argument( 420 | "--do_not_apply_masked_loss", 421 | action="store_false", 422 | help="Use masked loss instead of standard loss", 423 | dest="apply_masked_loss" 424 | ) 425 | parser.add_argument("--num_train_epochs", type=int, default=1) 426 | parser.add_argument( 427 | "--train_batch_size", 428 | type=int, 429 | default=1, 430 | help="Batch size (per device) for the training dataloader.", 431 | ) 432 | parser.add_argument( 433 | "--num_of_assets", 434 | type=int, 435 | default=2, 436 | ) 437 | 438 | if input_args is not None: 439 | args = parser.parse_args(input_args) 440 | else: 441 | args = parser.parse_args() 442 | 443 | args.max_train_steps = args.phase1_train_steps + args.phase2_train_steps 444 | 445 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 446 | if env_local_rank != -1 and env_local_rank != args.local_rank: 447 | args.local_rank = env_local_rank 448 | 449 | return args 450 | 451 | 452 | def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None): 453 | if tokenizer_max_length is not None: 454 | max_length = tokenizer_max_length 455 | else: 456 | max_length = tokenizer.model_max_length 457 | 458 | text_inputs = tokenizer( 459 | prompt, 460 | truncation=True, 461 | padding="max_length", 462 | max_length=max_length, 463 | return_tensors="pt", 464 | ) 465 | 466 | return text_inputs 467 | 468 | 469 | class CompCtrlPersDataset(Dataset): 470 | def __init__( 471 | self, 472 | instance_data_root, 473 | placeholder_tokens, 474 | tokenizer, 475 | size=512, 476 | flip_p=0.5, 477 | ): 478 | self.size = size 479 | self.tokenizer = tokenizer 480 | self.flip_p = flip_p 481 | 482 | self.image_transforms = transforms.Compose( 483 | [ 484 | transforms.Resize( 485 | size, interpolation=transforms.InterpolationMode.BILINEAR), 486 | transforms.ToTensor(), 487 | ] 488 | ) 489 | self.mask_transforms = transforms.Compose( 490 | [ 491 | transforms.Resize( 492 | size, interpolation=transforms.InterpolationMode.BILINEAR), 493 | transforms.ToTensor(), 494 | ] 495 | ) 496 | self.image_process = transforms.Compose( 497 | [ 498 | transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1), 499 | transforms.Normalize([0.5], [0.5]), 500 | ] 501 | ) 502 | 503 | if not Path(instance_data_root).exists(): 504 | raise ValueError("Instance images root doesn't exists.") 505 | 506 | self.placeholder_tokens = placeholder_tokens 507 | 508 | # get paths for images and masks 509 | self.instance_images_path = [ 510 | f for f in sorted(os.listdir(instance_data_root)) 511 | if os.path.isfile(os.path.join(instance_data_root, f)) and f.endswith(('.png', '.jpg', '.jpeg')) 512 | ] 513 | mask_dir = os.path.join(instance_data_root, "masks") 514 | self.instance_masks_path = [ 515 | f for f in sorted(os.listdir(mask_dir)) 516 | if os.path.isfile(os.path.join(mask_dir, f)) and f.endswith(('.png', '.jpg', '.jpeg')) 517 | ] 518 | assert len(self.instance_images_path) == len(self.instance_masks_path) 519 | 520 | # load images and masks 521 | self.instance_images = [] 522 | self.instance_masks = [] 523 | for i in range(len(self.instance_images_path)): 524 | # load and transform masks 525 | instance_mask_path = os.path.join(mask_dir, 526 | self.instance_masks_path[i]) 527 | mask = Image.open(instance_mask_path) 528 | mask = self.mask_transforms(mask)[0, None, None, ...] 
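# [Editor's note] Shape bookkeeping for the indexing just above, left as a comment because this
# point sits inside CompCtrlPersDataset.__init__: `self.mask_transforms(mask)` returns a
# (C, H, W) tensor, and `[0, None, None, ...]` keeps only the first channel while prepending two
# singleton dims, i.e. a (1, 1, H, W) mask. The per-image masks appended below are later joined
# with `torch.cat(self.instance_masks)` into an (N, 1, H, W) batch that broadcasts against the
# (N, C, H, W) image batch built by `torch.stack(self.instance_images)`. A minimal illustrative
# sketch with toy sizes (not part of the original file):
#
#     import torch
#     m = torch.rand(3, 512, 512)          # stand-in for a transformed mask, shape (C, H, W)
#     m = m[0, None, None, ...]            # -> shape (1, 1, 512, 512)
#     masks = torch.cat([m, m], dim=0)     # -> shape (2, 1, 512, 512)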
529 | self.instance_masks.append(mask) 530 | # load and transform images 531 | instance_image_path = os.path.join(instance_data_root, 532 | self.instance_images_path[i]) 533 | image = Image.open(instance_image_path) 534 | image = self.image_transforms(image) 535 | self.instance_images.append(image) 536 | self.instance_images = torch.stack(self.instance_images) 537 | self.instance_masks = torch.cat(self.instance_masks) 538 | 539 | # get formatted prompts 540 | pair_name = os.path.basename(instance_data_root) 541 | sample_names_and_ids = pair_name.split('+') 542 | sample_names = [s.split('_')[0] for s in sample_names_and_ids] 543 | self.instance_prompts = [] 544 | for p in self.instance_images_path: 545 | sample_name = p.split('_')[1] 546 | sample_idx = sample_names.index(sample_name) 547 | referent = self.placeholder_tokens[sample_idx] 548 | prompt = f"A photo of {referent}" 549 | self.instance_prompts.append(prompt) 550 | 551 | # tokenize the prompts to get input ids 552 | text_inputs = tokenize_prompt( 553 | self.tokenizer, 554 | self.instance_prompts, 555 | ) 556 | self.instance_prompt_ids = text_inputs.input_ids 557 | 558 | self.num_instance_images = len(self.instance_images_path) 559 | self._length = self.num_instance_images 560 | 561 | def __len__(self): 562 | return self._length 563 | 564 | def __getitem__(self, index): 565 | example = {} 566 | 567 | example["instance_images"] = self.instance_images.clone() 568 | example["instance_masks"] = self.instance_masks.clone() 569 | example["instance_prompt_ids"] = self.instance_prompt_ids.clone() 570 | example["instance_images"] = self.image_process(example["instance_images"]) 571 | if random.random() > self.flip_p: 572 | example["instance_images"] = TF.hflip(example["instance_images"]) 573 | example["instance_masks"] = TF.hflip(example["instance_masks"]) 574 | 575 | return example 576 | 577 | 578 | def collate_fn(examples): 579 | input_ids = [example["instance_prompt_ids"] for example in examples] 580 | pixel_values = [example["instance_images"] for example in examples] 581 | masks = [example["instance_masks"] for example in examples] 582 | 583 | pixel_values = torch.cat(pixel_values, dim=0) 584 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 585 | 586 | input_ids = torch.cat(input_ids, dim=0) 587 | masks = torch.cat(masks) 588 | masks = masks.to(memory_format=torch.contiguous_format).float() 589 | 590 | batch = { 591 | "input_ids": input_ids, 592 | "pixel_values": pixel_values, 593 | "instance_masks": masks, 594 | } 595 | return batch 596 | 597 | 598 | class MagicTailor: 599 | def __init__(self): 600 | self.args = parse_args() 601 | self.main() 602 | 603 | def main(self): 604 | 605 | # overwrite args configs for a pair 606 | instance_images_path = [ 607 | f for f in sorted(os.listdir(self.args.instance_data_dir)) 608 | if os.path.isfile(os.path.join(self.args.instance_data_dir, f)) and f.endswith(('.png', '.jpg', '.jpeg')) 609 | ] 610 | self.args.train_batch_size = len(instance_images_path) 611 | sample_names_and_ids = os.path.basename(self.args.instance_data_dir).split('+') 612 | self.args.num_of_assets = len(sample_names_and_ids) 613 | self.args.initializer_tokens = [s.split('_')[0] for s in sample_names_and_ids] 614 | 615 | logging_dir = Path(self.args.output_dir, self.args.logging_dir) 616 | 617 | accelerator_project_config = ProjectConfiguration( 618 | project_dir=self.args.output_dir, logging_dir=logging_dir) 619 | 620 | self.accelerator = Accelerator( 621 |
gradient_accumulation_steps=self.args.gradient_accumulation_steps, 622 | mixed_precision=self.args.mixed_precision, 623 | log_with=self.args.report_to, 624 | project_config=accelerator_project_config, 625 | # logging_dir=logging_dir, 626 | ) 627 | 628 | if ( 629 | self.args.gradient_accumulation_steps > 1 630 | and self.accelerator.num_processes > 1 631 | ): 632 | raise ValueError( 633 | "Gradient accumulation is not supported when training the text encoder in distributed training. " 634 | "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." 635 | ) 636 | 637 | logging.basicConfig( 638 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 639 | datefmt="%m/%d/%Y %H:%M:%S", 640 | level=logging.INFO, 641 | ) 642 | logger.info(self.accelerator.state, main_process_only=False) 643 | if self.accelerator.is_local_main_process: 644 | datasets.utils.logging.set_verbosity_warning() 645 | transformers.utils.logging.set_verbosity_warning() 646 | diffusers.utils.logging.set_verbosity_info() 647 | else: 648 | datasets.utils.logging.set_verbosity_error() 649 | transformers.utils.logging.set_verbosity_error() 650 | diffusers.utils.logging.set_verbosity_error() 651 | 652 | # if passed along, set the training seed now. 653 | if self.args.seed is not None: 654 | set_seed(self.args.seed) 655 | 656 | # text-guided mask generation 657 | if check_mask_existence(self.args.instance_data_dir): 658 | print("Masks alreadly exist.") 659 | else: 660 | print("Perform text-guided mask generation.") 661 | grounding_dino_model, sam_predictor = get_gdino_and_sam_model(self.args, self.accelerator.device) 662 | generate_masks(self.args, grounding_dino_model, sam_predictor, self.args.instance_data_dir, save_logs=True) 663 | del grounding_dino_model 664 | del sam_predictor 665 | 666 | # handle the repository creation 667 | if self.accelerator.is_main_process: 668 | os.makedirs(self.args.output_dir, exist_ok=True) 669 | 670 | # import correct text encoder class 671 | text_encoder_cls = import_model_class_from_model_name_or_path( 672 | self.args.pretrained_model_name_or_path, self.args.revision 673 | ) 674 | 675 | # load scheduler and models 676 | self.noise_scheduler = DDPMScheduler.from_pretrained( 677 | self.args.pretrained_model_name_or_path, subfolder="scheduler" 678 | ) 679 | self.text_encoder = text_encoder_cls.from_pretrained( 680 | self.args.pretrained_model_name_or_path, 681 | subfolder="text_encoder", 682 | revision=self.args.revision, 683 | ) 684 | self.vae = AutoencoderKL.from_pretrained( 685 | self.args.pretrained_model_name_or_path, 686 | subfolder="vae", 687 | revision=self.args.revision, 688 | ) 689 | self.unet = UNet2DConditionModel.from_pretrained( 690 | self.args.pretrained_model_name_or_path, 691 | subfolder="unet", 692 | revision=self.args.revision, 693 | ) 694 | 695 | # load the tokenizer 696 | if self.args.tokenizer_name: 697 | self.tokenizer = AutoTokenizer.from_pretrained( 698 | self.args.tokenizer_name, revision=self.args.revision, use_fast=False 699 | ) 700 | elif self.args.pretrained_model_name_or_path: 701 | self.tokenizer = AutoTokenizer.from_pretrained( 702 | self.args.pretrained_model_name_or_path, 703 | subfolder="tokenizer", 704 | revision=self.args.revision, 705 | use_fast=False, 706 | ) 707 | 708 | # add placeholder tokens to tokenizer 709 | self.placeholder_tokens = [ 710 | self.args.placeholder_token.replace(">", f"{idx}>") 711 | for idx in range(self.args.num_of_assets) 712 | ] 713 | num_added_tokens = 
self.tokenizer.add_tokens(self.placeholder_tokens) 714 | assert num_added_tokens == self.args.num_of_assets 715 | self.placeholder_token_ids = self.tokenizer.convert_tokens_to_ids( 716 | self.placeholder_tokens 717 | ) 718 | self.text_encoder.resize_token_embeddings(len(self.tokenizer)) 719 | 720 | if len(self.args.initializer_tokens) > 0: 721 | token_embeds = self.text_encoder.get_input_embeddings().weight.data 722 | for tkn_idx, initializer_token in enumerate(self.args.initializer_tokens): 723 | curr_token_ids = self.tokenizer.encode( 724 | initializer_token, add_special_tokens=False 725 | ) 726 | token_embeds[self.placeholder_token_ids[tkn_idx]] = token_embeds[ 727 | curr_token_ids[0] 728 | ].clone() 729 | else: 730 | token_embeds = self.text_encoder.get_input_embeddings().weight.data 731 | token_embeds[-self.args.num_of_assets :] = token_embeds[ 732 | -3 * self.args.num_of_assets : -2 * self.args.num_of_assets 733 | ] 734 | 735 | # set validation scheduler for logging 736 | self.validation_scheduler = DDIMScheduler( 737 | beta_start=0.00085, 738 | beta_end=0.012, 739 | beta_schedule="scaled_linear", 740 | clip_sample=False, 741 | set_alpha_to_one=False, 742 | ) 743 | self.validation_scheduler.set_timesteps(50) 744 | 745 | if self.args.enable_xformers_memory_efficient_attention: 746 | if is_xformers_available(): 747 | self.unet.enable_xformers_memory_efficient_attention() 748 | print("Enable xformers.") 749 | else: 750 | raise ValueError( 751 | "xformers is not available. Make sure it is installed correctly" 752 | ) 753 | 754 | if self.args.gradient_checkpointing: 755 | self.unet.enable_gradient_checkpointing() 756 | 757 | if self.args.allow_tf32: 758 | torch.backends.cuda.matmul.allow_tf32 = True 759 | 760 | if self.args.scale_lr: 761 | self.args.phase1_learning_rate = ( 762 | self.args.phase1_learning_rate 763 | * self.args.gradient_accumulation_steps 764 | * self.args.train_batch_size 765 | * self.accelerator.num_processes 766 | ) 767 | self.args.phase2_learning_rate = ( 768 | self.args.phase2_learning_rate 769 | * self.args.gradient_accumulation_steps 770 | * self.args.train_batch_size 771 | * self.accelerator.num_processes 772 | ) 773 | 774 | if self.args.use_8bit_adam: 775 | try: 776 | import bitsandbytes as bnb 777 | except ImportError: 778 | raise ImportError( 779 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 
780 | ) 781 | 782 | optimizer_class = bnb.optim.AdamW8bit 783 | else: 784 | optimizer_class = torch.optim.AdamW 785 | 786 | # setup LoRA 787 | self.vae.requires_grad_(False) 788 | self.unet.requires_grad_(False) 789 | 790 | self.text_encoder.text_model.encoder.requires_grad_(False) 791 | self.text_encoder.text_model.final_layer_norm.requires_grad_(False) 792 | self.text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) 793 | 794 | unet_lora_config = LoraConfig( 795 | r=self.args.lora_rank, 796 | lora_alpha=self.args.lora_rank, 797 | init_lora_weights="gaussian", 798 | target_modules=["to_k", "to_q", "to_v", "to_out.0", "proj_in", "proj_out", "ff.net.2"], 799 | ) 800 | self.unet.add_adapter(unet_lora_config) 801 | 802 | lora_params = list(filter(lambda p: p.requires_grad, self.unet.parameters())) 803 | params_to_optimize = ( 804 | itertools.chain( 805 | lora_params, 806 | self.text_encoder.get_input_embeddings().parameters(), 807 | ) 808 | ) 809 | 810 | optimizer = optimizer_class( 811 | params_to_optimize, 812 | lr=self.args.phase1_learning_rate, 813 | betas=(self.args.adam_beta1, self.args.adam_beta2), 814 | weight_decay=self.args.adam_weight_decay, 815 | eps=self.args.adam_epsilon, 816 | ) 817 | 818 | # saving and loading ckpt for the model with LoRA 819 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format 820 | def save_model_hook(models, weights, output_dir): 821 | if self.accelerator.is_main_process: 822 | # there are only two options here. Either are just the unet attn processor layers 823 | # or there are the unet and text encoder atten layers 824 | unet_lora_layers_to_save = None 825 | text_encoder_lora_layers_to_save = None 826 | 827 | for model in models: 828 | if isinstance(model, type(self.unwrap_model(self.unet))): 829 | unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model)) 830 | elif isinstance(model, type(self.unwrap_model(self.text_encoder))): 831 | text_encoder_lora_layers_to_save = None 832 | else: 833 | raise ValueError(f"unexpected save model: {model.__class__}") 834 | 835 | # make sure to pop weight so that corresponding model is not saved again 836 | weights.pop() 837 | 838 | LoraLoaderMixin.save_lora_weights( 839 | output_dir, 840 | unet_lora_layers=unet_lora_layers_to_save, 841 | text_encoder_lora_layers=text_encoder_lora_layers_to_save, 842 | ) 843 | 844 | def load_model_hook(models, input_dir): 845 | unet_ = None 846 | text_encoder_ = None 847 | 848 | while len(models) > 0: 849 | model = models.pop() 850 | 851 | if isinstance(model, type(self.unwrap_model(self.unet))): 852 | unet_ = model 853 | elif isinstance(model, type(self.unwrap_model(self.text_encoder))): 854 | text_encoder_ = model 855 | else: 856 | raise ValueError(f"unexpected save model: {model.__class__}") 857 | 858 | lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir) 859 | 860 | unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} 861 | unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) 862 | incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") 863 | 864 | if incompatible_keys is not None: 865 | # check only for unexpected keys 866 | unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) 867 | if unexpected_keys: 868 | logger.warning( 869 | f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " 870 | 
f" {unexpected_keys}. " 871 | ) 872 | 873 | if self.args.mixed_precision == "fp16": 874 | models = [unet_] 875 | # only upcast trainable parameters (LoRA) into fp32 876 | cast_training_params(models, dtype=torch.float32) 877 | 878 | self.accelerator.register_save_state_pre_hook(save_model_hook) 879 | self.accelerator.register_load_state_pre_hook(load_model_hook) 880 | 881 | # create dataLoaders 882 | train_dataset = CompCtrlPersDataset( 883 | instance_data_root=self.args.instance_data_dir, 884 | placeholder_tokens=self.placeholder_tokens, 885 | tokenizer=self.tokenizer, 886 | size=self.args.resolution, 887 | ) 888 | 889 | train_dataloader = torch.utils.data.DataLoader( 890 | train_dataset, 891 | batch_size=1, # load all images for once time, so here batch_size set to 1 892 | shuffle=True, 893 | collate_fn=lambda examples: collate_fn(examples), 894 | num_workers=self.args.dataloader_num_workers, 895 | ) 896 | 897 | # math around the number of training steps 898 | # (nothing important here, you can just skip and keep this) 899 | overrode_max_train_steps = False 900 | num_update_steps_per_epoch = math.ceil( 901 | len(train_dataloader) / self.args.gradient_accumulation_steps 902 | ) 903 | if self.args.max_train_steps is None: 904 | self.args.max_train_steps = ( 905 | self.args.num_train_epochs * num_update_steps_per_epoch 906 | ) 907 | overrode_max_train_steps = True 908 | 909 | lr_scheduler = get_scheduler( 910 | self.args.lr_scheduler, 911 | optimizer=optimizer, 912 | num_warmup_steps=self.args.lr_warmup_steps 913 | * self.args.gradient_accumulation_steps, 914 | num_training_steps=self.args.max_train_steps 915 | * self.args.gradient_accumulation_steps, 916 | num_cycles=self.args.lr_num_cycles, 917 | power=self.args.lr_power, 918 | ) 919 | 920 | ( 921 | self.unet, 922 | self.text_encoder, 923 | optimizer, 924 | train_dataloader, 925 | lr_scheduler, 926 | ) = self.accelerator.prepare( 927 | self.unet, self.text_encoder, optimizer, train_dataloader, lr_scheduler 928 | ) 929 | 930 | # for mixed precision training we cast the text_encoder and vae weights to half-precision 931 | # as these models are only used for inference, keeping weights in full precision is not required 932 | self.weight_dtype = torch.float32 933 | if self.accelerator.mixed_precision == "fp16": 934 | self.weight_dtype = torch.float16 935 | elif self.accelerator.mixed_precision == "bf16": 936 | self.weight_dtype = torch.bfloat16 937 | 938 | # move vae and text_encoder to device and cast to weight_dtype 939 | self.vae.to(self.accelerator.device, dtype=self.weight_dtype) 940 | 941 | low_precision_error_string = ( 942 | "Please make sure to always have all model weights in full float32 precision when starting training - even if" 943 | " doing mixed precision training. copy of the weights should still be float32." 944 | ) 945 | 946 | if self.accelerator.unwrap_model(self.unet).dtype != torch.float32: 947 | raise ValueError( 948 | f"Unet loaded as datatype {self.accelerator.unwrap_model(self.unet).dtype}. {low_precision_error_string}" 949 | ) 950 | 951 | if self.accelerator.unwrap_model(self.text_encoder).dtype != torch.float32: 952 | raise ValueError( 953 | f"Text encoder loaded as datatype {self.accelerator.unwrap_model(self.text_encoder).dtype}." 
954 | f" {low_precision_error_string}" 955 | ) 956 | 957 | # we need to recalculate our total training steps as the size of the training dataloader may have changed 958 | # (nothing important here, you can just skip and keep this) 959 | num_update_steps_per_epoch = math.ceil( 960 | len(train_dataloader) / self.args.gradient_accumulation_steps 961 | ) 962 | if overrode_max_train_steps: 963 | self.args.max_train_steps = ( 964 | self.args.num_train_epochs * num_update_steps_per_epoch 965 | ) 966 | self.args.num_train_epochs = math.ceil( 967 | self.args.max_train_steps / num_update_steps_per_epoch 968 | ) 969 | 970 | if len(self.args.initializer_tokens) > 0: 971 | self.args.initializer_tokens = ", ".join(self.args.initializer_tokens) 972 | 973 | # we need to initialize the trackers we use, and also store our configuration 974 | # the trackers are initialized automatically on the main process 975 | if self.accelerator.is_main_process: 976 | self.accelerator.init_trackers("MagicTailor", config=vars(self.args)) 977 | 978 | # set the inference prompt if it is not given 979 | if self.args.inference_prompt is None: 980 | if len(self.placeholder_tokens) > 1: 981 | self.args.inference_prompt = f"{self.placeholder_tokens[0]} with " + " and ".join( 982 | self.placeholder_tokens[1:] 983 | ) 984 | else: 985 | self.args.inference_prompt = self.placeholder_tokens[0] 986 | 987 | # begin training 988 | total_batch_size = ( 989 | self.args.train_batch_size 990 | * self.accelerator.num_processes 991 | * self.args.gradient_accumulation_steps 992 | ) 993 | 994 | logger.info("***** Running training *****") 995 | logger.info(f" Total number of reference images = {len(train_dataset)}") 996 | logger.info(f" Total optimization steps = {self.args.max_train_steps}") 997 | global_step = 0 998 | first_epoch = 0 999 | 1000 | # potentially load in the weights and states from a previous save 1001 | if self.args.resume_from_checkpoint: 1002 | if self.args.resume_from_checkpoint != "latest": 1003 | path = os.path.basename(self.args.resume_from_checkpoint) 1004 | else: 1005 | # get the most recent checkpoint 1006 | dirs = os.listdir(self.args.checkpoint_dir) 1007 | dirs = [d for d in dirs if d.startswith("checkpoint")] 1008 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 1009 | path = dirs[-1] if len(dirs) > 0 else None 1010 | 1011 | if path is None: 1012 | self.accelerator.print( 1013 | f"Checkpoint '{self.args.resume_from_checkpoint}' does not exist. Starting a new training run."
1014 | ) 1015 | self.args.resume_from_checkpoint = None 1016 | else: 1017 | self.accelerator.print(f"Resuming from checkpoint {path}") 1018 | self.accelerator.load_state(os.path.join(self.args.checkpoint_dir, path)) 1019 | global_step = int(path.split("-")[1]) 1020 | resume_global_step = global_step * self.args.gradient_accumulation_steps 1021 | first_epoch = global_step // num_update_steps_per_epoch 1022 | resume_step = resume_global_step % ( 1023 | num_update_steps_per_epoch * self.args.gradient_accumulation_steps 1024 | ) 1025 | 1026 | # only show the progress bar once on each machine 1027 | progress_bar = tqdm( 1028 | range(global_step, self.args.max_train_steps), 1029 | disable=not self.accelerator.is_local_main_process, 1030 | ) 1031 | progress_bar.set_description("Steps") 1032 | 1033 | # create the attention controller 1034 | self.controller = AttentionStore() 1035 | self.register_attention_control(self.controller) 1036 | 1037 | for epoch in range(first_epoch, self.args.num_train_epochs): 1038 | self.unet.train() 1039 | for step, batch in enumerate(train_dataloader): 1040 | 1041 | if global_step == self.args.phase1_train_steps: 1042 | print("Warm-up ends. Switch to the DS-Bal Paradigm.") 1043 | # setup dual-streaming denoising U-Nets 1044 | # self.unet -> online denoising U-Net 1045 | # self.unet_m -> momentum denoising U-Net 1046 | self.unet_m_device = self.unet.device # modify this to move unet_m to another GPU if you reach the GPU memory limit 1047 | self.unet_m = EMA(self.unet, decay=self.args.ema_decay, device=self.unet_m_device) 1048 | self.unet_m.requires_grad_(False) 1049 | # change lr 1050 | for param_group in optimizer.param_groups: 1051 | param_group['lr'] = self.args.phase2_learning_rate 1052 | 1053 | logs = {} 1054 | 1055 | # skip steps until we reach the resumed step 1056 | if ( 1057 | self.args.resume_from_checkpoint 1058 | and epoch == first_epoch 1059 | and step < resume_step 1060 | ): 1061 | if step % self.args.gradient_accumulation_steps == 0: 1062 | progress_bar.update(1) 1063 | continue 1064 | 1065 | # core training code 1066 | with self.accelerator.accumulate(self.unet): 1067 | 1068 | # DM-Deg 1069 | # adjust weight 1070 | curr_weight = self.args.alpha * (1 - ((global_step + 1) / self.args.max_train_steps) ** self.args.gamma) 1071 | # add noise 1072 | raw_noise = torch.randn_like(batch["pixel_values"]) 1073 | masked_noise = raw_noise * (1 - batch["instance_masks"]) 1074 | batch["pixel_values"] += curr_weight * masked_noise 1075 | batch["pixel_values"] = torch.clamp(batch["pixel_values"], -1, 1) 1076 | 1077 | # convert images to latent space 1078 | latents = self.vae.encode( 1079 | batch["pixel_values"].to(dtype=self.weight_dtype) 1080 | ).latent_dist.sample() 1081 | latents = latents * 0.18215 1082 | 1083 | # sample noise that we'll add to the latents 1084 | noise = torch.randn_like(latents) 1085 | bsz = latents.shape[0] 1086 | 1087 | # sample a random timestep for each image 1088 | timesteps = torch.randint( 1089 | 0, 1090 | self.noise_scheduler.config.num_train_timesteps, 1091 | (bsz,), 1092 | device=latents.device, 1093 | ) 1094 | timesteps = timesteps.long() 1095 | 1096 | # add noise to the latents according to the noise magnitude at each timestep 1097 | # (this is the forward diffusion process) 1098 | noisy_latents = self.noise_scheduler.add_noise( 1099 | latents, noise, timesteps 1100 | ) 1101 | 1102 | # get the text embedding for conditioning 1103 | encoder_hidden_states = self.text_encoder(batch["input_ids"])[0] 1104 | 1105 | # predict the noise 
residual 1106 | model_pred = self.unet( 1107 | noisy_latents, timesteps, encoder_hidden_states 1108 | ).sample 1109 | 1110 | # get the target for loss depending on the prediction type 1111 | if self.noise_scheduler.config.prediction_type == "epsilon": 1112 | target = noise 1113 | elif self.noise_scheduler.config.prediction_type == "v_prediction": 1114 | target = self.noise_scheduler.get_velocity( 1115 | latents, noise, timesteps 1116 | ) 1117 | else: 1118 | raise ValueError( 1119 | f"Unknown prediction type {self.noise_scheduler.config.prediction_type}" 1120 | ) 1121 | 1122 | # masked diffusion loss 1123 | if self.args.apply_masked_loss: 1124 | masks = batch["instance_masks"] 1125 | downsampled_masks = F.interpolate(input=masks, 1126 | size=(64, 64)) 1127 | model_pred = model_pred * downsampled_masks 1128 | target = target * downsampled_masks 1129 | 1130 | if global_step < self.args.phase1_train_steps: 1131 | # warm-up 1132 | diff_loss = F.mse_loss( 1133 | model_pred.float(), target.float(), reduction="mean" 1134 | ) 1135 | loss = diff_loss 1136 | else: 1137 | # DS-Bal 1138 | model_pred_m = self.unet_m( 1139 | noisy_latents.detach(), timesteps.detach(), encoder_hidden_states.detach() 1140 | ).sample 1141 | 1142 | # use the following one if self.unet_m and self.unet in different GPUs 1143 | # model_pred_m = self.unet_m( 1144 | # noisy_latents.detach().to(self.unet_m_device), 1145 | # timesteps.detach().to(self.unet_m_device), 1146 | # encoder_hidden_states.detach().to(self.unet_m_device) 1147 | # ).sample.to(self.accelerator.device) 1148 | 1149 | if self.args.apply_masked_loss: 1150 | model_pred_m = model_pred_m * downsampled_masks 1151 | sample_wise_shape = (self.args.num_of_assets, -1, *(target.shape[1:])) 1152 | 1153 | # Sample-wise Min-Max Optimization 1154 | unet_loss = F.mse_loss( 1155 | model_pred.float(), target.float(), reduction="none" 1156 | ) 1157 | unet_loss = unet_loss.reshape(sample_wise_shape) 1158 | unet_loss = unet_loss.mean(dim=(1,2,3,4)) 1159 | max_diff_loss = unet_loss.max() 1160 | 1161 | # Selective Preserving Regularization 1162 | unet_m_loss = F.mse_loss( 1163 | model_pred.float(), model_pred_m.float(), reduction="none" 1164 | ) 1165 | unet_m_loss = unet_m_loss.reshape(sample_wise_shape) 1166 | unet_m_loss = unet_m_loss.mean(dim=(1,2,3,4)) 1167 | selected_idx = set(range(self.args.num_of_assets)) 1168 | max_idx = torch.argmax(unet_loss).item() 1169 | selected_idx.discard(max_idx) 1170 | selected_idx = torch.tensor(list(selected_idx)) 1171 | pres_loss = unet_m_loss[selected_idx].mean() 1172 | 1173 | loss = max_diff_loss + self.args.lambda_preservation * pres_loss 1174 | 1175 | # cross-attention loss 1176 | if self.args.lambda_attention != 0: 1177 | attn_loss = 0 1178 | losses_attn = [] 1179 | 1180 | GT_masks = F.interpolate( 1181 | input=batch["instance_masks"], size=(16, 16) 1182 | ) 1183 | agg_attn = self.aggregate_attention( 1184 | res=16, 1185 | from_where=("up", "down"), 1186 | is_cross=True, 1187 | ) 1188 | 1189 | # set for curr_placeholder_token_id assignment with mask_id 1190 | self.serial_token_ids = [ 1191 | int(f.split('_')[0]) for f in sorted(os.listdir(self.args.instance_data_dir)) 1192 | if os.path.isfile(os.path.join(self.args.instance_data_dir, f)) and f.endswith(('.png', '.jpg', '.jpeg')) 1193 | ] 1194 | 1195 | for mask_id in range(len(GT_masks)): 1196 | 1197 | curr_placeholder_token_id = self.placeholder_token_ids[self.serial_token_ids[mask_id]] 1198 | 1199 | curr_cond_batch_idx = mask_id # set to this because mask num is equal to image num 1200 | 
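# (cross-attention loss, per reference image) the block below finds where this image's
# placeholder token appears in its tokenized prompt (asset_idx), takes the aggregated
# 16x16 cross-attention map at that token position, normalizes it by its maximum, and
# regresses it toward the downsampled ground-truth mask with an MSE loss.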
1201 | asset_idx = ( 1202 | ( 1203 | batch["input_ids"][curr_cond_batch_idx] 1204 | == curr_placeholder_token_id 1205 | ) 1206 | .nonzero() 1207 | .item() 1208 | ) 1209 | # asset_attn_mask = agg_attn[..., asset_idx] 1210 | asset_attn_mask = agg_attn[mask_id, ..., asset_idx] 1211 | asset_attn_mask = ( 1212 | asset_attn_mask / asset_attn_mask.max() # normalize the attention mask 1213 | ) 1214 | losses_attn.append( 1215 | F.mse_loss( 1216 | GT_masks[mask_id, 0].float(), 1217 | asset_attn_mask.float(), 1218 | reduction="mean", 1219 | ) 1220 | ) 1221 | 1222 | losses_attn = torch.stack(losses_attn) 1223 | attn_loss = losses_attn.mean() 1224 | loss = loss + self.args.lambda_attention * attn_loss 1225 | 1226 | self.accelerator.backward(loss) 1227 | 1228 | # no need to keep the attention store 1229 | self.controller.attention_store = {} 1230 | self.controller.cur_step = 0 1231 | 1232 | if self.accelerator.sync_gradients: 1233 | params_to_clip = (self.unet.parameters()) 1234 | self.accelerator.clip_grad_norm_( 1235 | params_to_clip, self.args.max_grad_norm 1236 | ) 1237 | 1238 | optimizer.step() 1239 | lr_scheduler.step() 1240 | optimizer.zero_grad(set_to_none=self.args.set_grads_to_none) 1241 | 1242 | # update momentum denoising U-Net 1243 | if global_step >= self.args.phase1_train_steps: 1244 | self.unet_m.update_parameters(self.unet) 1245 | 1246 | # checks if the accelerator has performed an optimization step behind the scenes 1247 | if self.accelerator.sync_gradients: 1248 | progress_bar.update(1) 1249 | global_step += 1 1250 | 1251 | # save checkpoints 1252 | if global_step % self.args.checkpointing_steps == 0: 1253 | if self.accelerator.is_main_process: 1254 | save_path = os.path.join( 1255 | self.args.output_dir, f"checkpoint-{global_step}" 1256 | ) 1257 | self.accelerator.save_state(save_path) 1258 | logger.info(f"Saved state to {save_path}") 1259 | 1260 | # save images for logging 1261 | if ( 1262 | self.args.log_checkpoints 1263 | and (global_step % self.args.img_log_steps == 0 or 1264 | global_step == self.args.max_train_steps) 1265 | ): 1266 | ckpts_path = os.path.join( 1267 | self.args.output_dir, "models", f"{global_step:05}" 1268 | ) 1269 | os.makedirs(ckpts_path, exist_ok=True) 1270 | self.save_pipeline(ckpts_path) 1271 | 1272 | img_logs_path = os.path.join(self.args.output_dir, "img_logs") 1273 | os.makedirs(img_logs_path, exist_ok=True) 1274 | 1275 | if self.args.lambda_attention != 0: 1276 | self.controller.cur_step = 1 1277 | 1278 | for mask_id in range(len(GT_masks)): 1279 | log_curr_cond_batch_idx = mask_id 1280 | log_sentence = batch["input_ids"][log_curr_cond_batch_idx] 1281 | log_sentence = log_sentence[ 1282 | (log_sentence != 0) 1283 | & (log_sentence != 49406) 1284 | & (log_sentence != 49407) 1285 | ] 1286 | log_sentence = self.tokenizer.decode(log_sentence) 1287 | self.save_cross_attention_vis( 1288 | log_sentence, 1289 | attention_maps=agg_attn[mask_id].detach().cpu(), 1290 | path=os.path.join( 1291 | img_logs_path, f"{global_step:05}_attn_{mask_id}.jpg" 1292 | ), 1293 | ) 1294 | self.controller.cur_step = 0 1295 | self.controller.attention_store = {} 1296 | 1297 | self.perform_full_inference( 1298 | path=os.path.join( 1299 | img_logs_path, f"{global_step:05}_infer_img.jpg" 1300 | ) 1301 | ) 1302 | 1303 | full_agg_attn = self.aggregate_attention( 1304 | res=16, from_where=("up", "down"), is_cross=True, is_inference=True 1305 | ) 1306 | self.save_cross_attention_vis( 1307 | self.args.inference_prompt, 1308 | attention_maps=full_agg_attn.detach().cpu(), 1309 | 
path=os.path.join( 1310 | img_logs_path, f"{global_step:05}_infer_attn.jpg" 1311 | ), 1312 | ) 1313 | 1314 | self.controller.cur_step = 0 1315 | self.controller.attention_store = {} 1316 | 1317 | if global_step >= self.args.max_train_steps: 1318 | break 1319 | 1320 | self.save_pipeline(self.args.output_dir) 1321 | 1322 | self.accelerator.end_training() 1323 | 1324 | def unwrap_model(self, model): 1325 | model = self.accelerator.unwrap_model(model) 1326 | model = model._orig_mod if is_compiled_module(model) else model 1327 | return model 1328 | 1329 | def save_pipeline(self, path): 1330 | self.accelerator.wait_for_everyone() 1331 | if self.accelerator.is_main_process: 1332 | # saving LoRA weights 1333 | unet = self.unwrap_model(self.unet) 1334 | unet = unet.to(torch.float32) 1335 | unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet)) 1336 | text_encoder_state_dict = None 1337 | LoraLoaderMixin.save_lora_weights( 1338 | save_directory=path, 1339 | unet_lora_layers=unet_lora_state_dict, 1340 | text_encoder_lora_layers=text_encoder_state_dict, 1341 | ) 1342 | # saving token embeddings 1343 | torch.save( 1344 | self.text_encoder.get_input_embeddings().state_dict(), 1345 | os.path.join(path, 'token_embedding.pth'), 1346 | ) 1347 | # saving the tokenizer 1348 | self.tokenizer.save_pretrained(os.path.join(path, 'tokenizer')) 1349 | 1350 | def register_attention_control(self, controller): 1351 | attn_procs = {} 1352 | cross_att_count = 0 1353 | for name in self.unet.attn_processors.keys(): 1354 | cross_attention_dim = ( 1355 | None 1356 | if name.endswith("attn1.processor") 1357 | else self.unet.config.cross_attention_dim 1358 | ) 1359 | if name.startswith("mid_block"): 1360 | hidden_size = self.unet.config.block_out_channels[-1] 1361 | place_in_unet = "mid" 1362 | elif name.startswith("up_blocks"): 1363 | block_id = int(name[len("up_blocks.")]) 1364 | hidden_size = list(reversed(self.unet.config.block_out_channels))[ 1365 | block_id 1366 | ] 1367 | place_in_unet = "up" 1368 | elif name.startswith("down_blocks"): 1369 | block_id = int(name[len("down_blocks.")]) 1370 | hidden_size = self.unet.config.block_out_channels[block_id] 1371 | place_in_unet = "down" 1372 | else: 1373 | continue 1374 | cross_att_count += 1 1375 | attn_procs[name] = P2PCrossAttnProcessor( 1376 | controller=controller, place_in_unet=place_in_unet 1377 | ) 1378 | 1379 | self.unet.set_attn_processor(attn_procs) 1380 | controller.num_att_layers = cross_att_count 1381 | 1382 | def get_average_attention(self): 1383 | average_attention = { 1384 | key: [ 1385 | item / self.controller.cur_step 1386 | for item in self.controller.attention_store[key] 1387 | ] 1388 | for key in self.controller.attention_store 1389 | } 1390 | return average_attention 1391 | 1392 | def aggregate_attention( 1393 | self, res: int, from_where: List[str], is_cross: bool, is_inference=False, 1394 | ): 1395 | out = [] 1396 | attention_maps = self.get_average_attention() 1397 | num_pixels = res**2 1398 | for location in from_where: 1399 | for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: 1400 | if item.shape[1] == num_pixels: 1401 | if is_inference: 1402 | cross_maps = item.reshape( 1403 | 2, -1, res, res, item.shape[-1] 1404 | )[1] 1405 | else: 1406 | cross_maps = item.reshape( 1407 | self.args.train_batch_size, -1, res, res, item.shape[-1] 1408 | ) 1409 | out.append(cross_maps) 1410 | 1411 | if is_inference: 1412 | out = torch.cat(out, dim=0) 1413 | out = out.sum(0) / out.shape[0] 1414 | else: 1415 | 
out = torch.cat(out, dim=1) 1416 | out = out.sum(1) / out.shape[1] 1417 | 1418 | return out 1419 | 1420 | @torch.no_grad() 1421 | def perform_full_inference(self, path, guidance_scale=7.5): 1422 | self.unet.eval() 1423 | self.text_encoder.eval() 1424 | 1425 | latents = torch.randn((1, 4, 64, 64), device=self.accelerator.device) 1426 | uncond_input = self.tokenizer( 1427 | [""], 1428 | padding="max_length", 1429 | max_length=self.tokenizer.model_max_length, 1430 | return_tensors="pt", 1431 | ).to(self.accelerator.device) 1432 | input_ids = self.tokenizer( 1433 | self.args.inference_prompt, 1434 | padding="max_length", 1435 | truncation=True, 1436 | max_length=self.tokenizer.model_max_length, 1437 | return_tensors="pt", 1438 | ).input_ids.to(self.accelerator.device) 1439 | cond_embeddings = self.text_encoder(input_ids)[0] 1440 | uncond_embeddings = self.text_encoder(uncond_input.input_ids)[0] 1441 | text_embeddings = torch.cat([uncond_embeddings, cond_embeddings]) 1442 | 1443 | for t in self.validation_scheduler.timesteps: 1444 | latent_model_input = torch.cat([latents] * 2) 1445 | latent_model_input = self.validation_scheduler.scale_model_input( 1446 | latent_model_input, timestep=t 1447 | ) 1448 | 1449 | pred = self.unet( 1450 | latent_model_input, t, encoder_hidden_states=text_embeddings 1451 | ) 1452 | noise_pred = pred.sample 1453 | 1454 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 1455 | noise_pred = noise_pred_uncond + guidance_scale * ( 1456 | noise_pred_text - noise_pred_uncond 1457 | ) 1458 | 1459 | latents = self.validation_scheduler.step(noise_pred, t, latents).prev_sample 1460 | latents = 1 / 0.18215 * latents 1461 | 1462 | images = self.vae.decode(latents.to(self.weight_dtype)).sample 1463 | images = (images / 2 + 0.5).clamp(0, 1) 1464 | images = images.detach().cpu().permute(0, 2, 3, 1).numpy() 1465 | images = (images * 255).round().astype("uint8") 1466 | 1467 | self.unet.train() 1468 | 1469 | Image.fromarray(images[0]).save(path) 1470 | 1471 | def make_image_grid(self, images, rows, cols, resize=None): 1472 | """ 1473 | Prepares a single grid of images. Useful for visualization purposes. 
1474 | """ 1475 | assert len(images) == rows * cols 1476 | if resize is not None: 1477 | images = [img.resize((resize, resize)) for img in images] 1478 | w, h = images[0].size 1479 | grid = Image.new("RGB", size=(cols * w, rows * h)) 1480 | for i, img in enumerate(images): 1481 | grid.paste(img, box=(i % cols * w, i // cols * h)) 1482 | return grid 1483 | 1484 | @torch.no_grad() 1485 | def save_cross_attention_vis(self, prompt, attention_maps, path): 1486 | tokens = self.tokenizer.encode(prompt) 1487 | images = [] 1488 | for i in range(len(tokens)): 1489 | if int(tokens[i]) in [0, 49406, 49407]: 1490 | continue 1491 | image = attention_maps[:, :, i] 1492 | image = 255 * image / image.max() 1493 | image = image.unsqueeze(-1).expand(*image.shape, 3) 1494 | image = image.numpy().astype(np.uint8) 1495 | image = np.array(Image.fromarray(image).resize((512, 512))) 1496 | image = image[:, :, ::-1].copy() 1497 | image = cv2.applyColorMap(image, cv2.COLORMAP_JET) 1498 | image = image[:, :, ::-1].copy() 1499 | image = ptp_utils.text_under_image( 1500 | image, self.tokenizer.decode(int(tokens[i])) 1501 | ) 1502 | images.append(image) 1503 | vis = ptp_utils.view_images(np.stack(images, axis=0)) 1504 | vis.save(path) 1505 | 1506 | 1507 | class P2PCrossAttnProcessor: 1508 | def __init__(self, controller, place_in_unet): 1509 | super().__init__() 1510 | self.controller = controller 1511 | self.place_in_unet = place_in_unet 1512 | 1513 | def __call__( 1514 | self, 1515 | attn: CrossAttention, 1516 | hidden_states, 1517 | encoder_hidden_states=None, 1518 | attention_mask=None, 1519 | ): 1520 | batch_size, sequence_length, _ = hidden_states.shape 1521 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 1522 | 1523 | query = attn.to_q(hidden_states) 1524 | 1525 | is_cross = encoder_hidden_states is not None 1526 | encoder_hidden_states = ( 1527 | encoder_hidden_states 1528 | if encoder_hidden_states is not None 1529 | else hidden_states 1530 | ) 1531 | key = attn.to_k(encoder_hidden_states) 1532 | value = attn.to_v(encoder_hidden_states) 1533 | 1534 | query = attn.head_to_batch_dim(query) 1535 | key = attn.head_to_batch_dim(key) 1536 | value = attn.head_to_batch_dim(value) 1537 | 1538 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 1539 | 1540 | # one line change 1541 | self.controller(attention_probs, is_cross, self.place_in_unet) 1542 | 1543 | hidden_states = torch.bmm(attention_probs, value) 1544 | hidden_states = attn.batch_to_head_dim(hidden_states) 1545 | 1546 | # linear proj 1547 | hidden_states = attn.to_out[0](hidden_states) 1548 | # dropout 1549 | hidden_states = attn.to_out[1](hidden_states) 1550 | 1551 | return hidden_states 1552 | 1553 | 1554 | class EMA(torch.optim.swa_utils.AveragedModel): 1555 | """ 1556 | Maintains moving averages of model parameters using an exponential decay. 1557 | ``ema_avg = decay * avg_model_param + (1 - decay) * model_param`` 1558 | `torch.optim.swa_utils.AveragedModel `_ 1559 | is used to compute the EMA. 1560 | """ 1561 | def __init__(self, model, decay, device="cpu"): 1562 | def ema_avg(avg_model_param, model_param, num_averaged): 1563 | return decay * avg_model_param + (1 - decay) * model_param 1564 | super().__init__(model, device, ema_avg, use_buffers=True) 1565 | 1566 | 1567 | if __name__ == "__main__": 1568 | MagicTailor() 1569 | --------------------------------------------------------------------------------