├── .gitignore ├── .isort.cfg ├── .pre-commit-config.yaml ├── LICENSE ├── README.assets ├── a__,_a___and_a___are_standing_near_a_forest---baseline---e156a9f7.png ├── a__,_a___and_a___near_the_castle,_4K,_high_quality,_high_resolution,_best_quality---baseline---e27b4344.png ├── base5.png ├── concept_list.jpg └── lora_weight.jpg ├── README.md ├── datasets ├── benchmark_prompts │ └── single-concept │ │ ├── characters │ │ └── test_character.txt │ │ ├── objects │ │ ├── test_chair.txt │ │ ├── test_pet.txt │ │ ├── test_plushy.txt │ │ ├── test_table.txt │ │ └── test_vase.txt │ │ └── scenes │ │ └── test_scene.txt ├── data_cfgs │ └── MixofShow │ │ ├── multi-concept │ │ ├── anime │ │ │ └── hina+kario+tezuka_anythingv4.json │ │ └── real │ │ │ └── potter+hermione+thanos_chilloutmix.json │ │ └── single-concept │ │ ├── characters │ │ ├── anime │ │ │ ├── hina_amano.json │ │ │ ├── mitsuha_miyamizu.json │ │ │ ├── miyazono_kaori.json │ │ │ ├── son_goku.json │ │ │ └── tezuka_kunimitsu.json │ │ └── real │ │ │ ├── batman.json │ │ │ ├── hermione.json │ │ │ ├── ironman.json │ │ │ ├── potter.json │ │ │ └── thanos.json │ │ ├── objects │ │ └── real │ │ │ ├── B2.json │ │ │ ├── carA.json │ │ │ ├── catA.json │ │ │ ├── chair.json │ │ │ ├── dogA.json │ │ │ ├── dogB.json │ │ │ ├── f35.json │ │ │ ├── porsche_356a.json │ │ │ ├── sailboat.json │ │ │ ├── table.json │ │ │ ├── vase.json │ │ │ └── yacht.json │ │ └── scenes │ │ └── real │ │ ├── pyramid.json │ │ └── wululu.json ├── validation_prompts │ └── single-concept │ │ ├── characters │ │ ├── test_girl.txt │ │ ├── test_goku.txt │ │ ├── test_man.txt │ │ └── test_woman.txt │ │ ├── objects │ │ ├── test_airplane.txt │ │ ├── test_boat.txt │ │ ├── test_car.txt │ │ ├── test_cat.txt │ │ ├── test_chair.txt │ │ ├── test_dog.txt │ │ ├── test_table.txt │ │ └── test_vase.txt │ │ └── scenes │ │ └── test_scene.txt └── validation_spatial_condition │ ├── characters-objects │ ├── bengio+lecun+chair.txt │ ├── bengio+lecun+chair_pose.png │ ├── bengio+lecun+chair_sketch.png │ ├── harry+catA+dogA.txt │ ├── harry+catA+dogA_pose.png │ ├── harry+catA+dogA_sketch.png │ ├── harry_heminone_scene.txt │ ├── harry_heminone_scene_pose.png │ └── harry_heminone_scene_sketch.png │ ├── multi-characters │ ├── anime_pose │ │ ├── hina_mitsuha_kario.png │ │ ├── hina_mitsuha_kario.txt │ │ ├── hina_tezuka_kario.png │ │ ├── hina_tezuka_kario.txt │ │ ├── hina_tezuka_mitsuha_goku_kaori.png │ │ └── hina_tezuka_mitsuha_goku_kaori.txt │ ├── anime_pose_2x │ │ └── hina_tezuka_kario_2x.png │ ├── real_pose │ │ ├── bengio_lecun_bengio.png │ │ ├── bengio_lecun_bengio.txt │ │ ├── harry_hermione_thanos.png │ │ └── harry_hermione_thanos.txt │ └── real_pose_2x │ │ └── harry_hermione_thanos_2x.png │ └── multi-objects │ ├── dogA_catA_dogB.jpg │ ├── dogA_catA_dogB.txt │ ├── two_chair_table_vase.jpg │ └── two_chair_table_vase.txt ├── docs └── Dataset.md ├── fuse.sh ├── gradient_fusion.py ├── mixofshow ├── data │ ├── __init__.py │ ├── lora_dataset.py │ ├── pil_transform.py │ └── prompt_dataset.py ├── models │ └── edlora.py ├── pipelines │ ├── pipeline_edlora.py │ ├── pipeline_regionally_t2iadapter.py │ └── trainer_edlora.py └── utils │ ├── __init__.py │ ├── arial.ttf │ ├── convert_edlora_to_diffusers.py │ ├── ptp_util.py │ ├── registry.py │ └── util.py ├── options ├── test │ └── EDLoRA │ │ ├── anime │ │ └── 1001_EDLoRA_hina_Anyv4_B4_Iter1K.yml │ │ └── human │ │ ├── 8101_EDLoRA_potter_Cmix_B4_Repeat500.yml │ │ └── 8102_EDLoRA_hermione_Cmix_B4_Repeat500.yml └── train │ └── EDLoRA │ ├── anime │ ├── 1001_1_EDLoRA_hina_Anyv4_B4_Repeat500_v6_final_nomask.yml │ ├── 1002_1_EDLoRA_kaori_Anyv4_B4_Repeat500_v6_final_nomask.yml │ └── 1003_1_EDLoRA_tezuka_Anyv4_B4_Repeat500_v6_final_nomask.yml │ └── real │ ├── 8101_EDLoRA_potter_Cmix_B4_Repeat500.yml │ ├── 8102_EDLoRA_hermione_Cmix_B4_Repeat500.yml │ └── 8103_EDLoRA_thanos_Cmix_B4_Repeat250.yml ├── regionally_controlable_sampling.py ├── regionally_sample.sh ├── requirements.txt ├── test_edlora.py └── train_edlora.py /.gitignore: -------------------------------------------------------------------------------- 1 | datasets/old/* 2 | datasets/data/* 3 | visualization/* 4 | experiments/old/* 5 | results/* 6 | tb_logger/* 7 | wandb/* 8 | *idea/* 9 | *.DS_Store 10 | 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | pip-wheel-metadata/ 34 | share/python-wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .nox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | *.py,cover 61 | .hypothesis/ 62 | .pytest_cache/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | db.sqlite3 72 | db.sqlite3-journal 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no crosskv-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | known_third_party = IPython,PIL,accelerate,cv2,diffusers,einops,numpy,omegaconf,packaging,torch,torchvision,tqdm,transformers 3 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | # flake8 3 | - repo: https://github.com/PyCQA/flake8 4 | rev: 3.8.3 5 | hooks: 6 | - id: flake8 7 | args: ["--ignore=W504, W503, E128, E124", "--max-line-length=248"] 8 | exclude: ^(mixofshow/models/adapter.py)$|^(test_animatediff.py)$|^(videoswap/utils/convert_lora_safetensor_to_diffusers.py)$|^videoswap/models/animatediff_models/.*\.py$ 9 | 10 | # modify known_third_party 11 | - repo: https://github.com/asottile/seed-isort-config 12 | rev: v2.2.0 13 | hooks: 14 | - id: seed-isort-config 15 | 16 | # isort 17 | - repo: https://github.com/timothycrosley/isort 18 | rev: 5.12.0 19 | hooks: 20 | - id: isort 21 | args: [--line-length=120] 22 | 23 | # yapf 24 | - repo: https://github.com/pre-commit/mirrors-yapf 25 | rev: v0.30.0 26 | hooks: 27 | - id: yapf 28 | args: [--style, "{based_on_style: pep8, column_limit: 248}"] 29 | 30 | # pre-commit-hooks 31 | - repo: https://github.com/pre-commit/pre-commit-hooks 32 | rev: v3.2.0 33 | hooks: 34 | - id: trailing-whitespace # Trim trailing whitespace 35 | - id: check-yaml # Attempt to load all yaml files to verify syntax 36 | - id: check-merge-conflict # Check for files that contain merge conflict strings 37 | - id: double-quote-string-fixer # Replace double quoted strings with single quoted strings 38 | - id: end-of-file-fixer # Make sure files end in a newline and only a newline 39 | - id: requirements-txt-fixer # Sort entries in requirements.txt and remove incorrect entry for pkg-resources==0.0.0 40 | - id: fix-encoding-pragma # Remove the coding pragma: # -*- coding: utf-8 -*- 41 | args: ["--remove"] 42 | - id: mixed-line-ending # Replace or check mixed line ending 43 | args: ["--fix=lf"] 44 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. 2 | Mix-of-Show is licensed under the Apache License Version 2.0 except for the third-party components listed below. 3 | 4 | Apache License 5 | 6 | Version 2.0, January 2004 7 | 8 | http://www.apache.org/licenses/ 9 | 10 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 11 | 1. Definitions. 12 | 13 | "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. 14 | 15 | "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. 18 | 19 | "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 20 | 21 | "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. 22 | 23 | "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. 24 | 25 | "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). 26 | 27 | "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 28 | 29 | "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 30 | 31 | "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 32 | 33 | 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 34 | 35 | 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 36 | 37 | 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: 38 | 39 | You must give any other recipients of the Work or Derivative Works a copy of this License; and 40 | 41 | You must cause any modified files to carry prominent notices stating that You changed the files; and 42 | 43 | You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and 44 | 45 | If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. 46 | 47 | You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 48 | 49 | 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 50 | 51 | 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 52 | 53 | 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 54 | 55 | 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 56 | 57 | 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. 58 | 59 | END OF TERMS AND CONDITIONS 60 | 61 | 62 | 63 | Other dependencies and licenses: 64 | 65 | 66 | Open Source Software Licensed under the BSD 3-Clause License: 67 | -------------------------------------------------------------------- 68 | 1. torchvision 69 | Copyright (c) Soumith Chintala 2016 70 | 71 | 72 | Terms of the BSD 3-Clause License: 73 | -------------------------------------------------------------------- 74 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 75 | 76 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 77 | 78 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 79 | 80 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 81 | 82 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH 83 | 84 | 85 | 86 | Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein: 87 | -------------------------------------------------------------------- 88 | 1. torch 89 | Copyright (c) 2016- Facebook, Inc (Adam Paszke) 90 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala) 91 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) 92 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) 93 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) 94 | Copyright (c) 2011-2013 NYU (Clement Farabet) 95 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) 96 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 97 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) 98 | Copyright (c) 2016-present, Facebook Inc. All rights reserved. 99 | 100 | All contributions by Facebook: 101 | Copyright (c) 2016 Facebook Inc. 102 | 103 | All contributions by Google: 104 | Copyright (c) 2015 Google Inc. 105 | All rights reserved. 106 | 107 | All contributions by Yangqing Jia: 108 | Copyright (c) 2015 Yangqing Jia 109 | All rights reserved. 110 | 111 | All contributions from Caffe: 112 | Copyright(c) 2013, 2014, 2015, the respective contributors 113 | All rights reserved. 114 | 115 | All other contributions: 116 | Copyright(c) 2015, 2016 the respective contributors 117 | All rights reserved. 118 | 119 | 120 | A copy of the BSD 3-Clause License is included in this file. 121 | 122 | For the license of other third party components, please refer to the following URL: 123 | https://github.com/pytorch/pytorch/blob/v1.7.1/NOTICE 124 | 125 | 126 | 127 | Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein: 128 | -------------------------------------------------------------------- 129 | 1. numpy 130 | Copyright (c) 2005-2021, NumPy Developers. 131 | 132 | 133 | A copy of the BSD 3-Clause License is included in this file. 134 | 135 | For the license of other third party components, please refer to the following URL: 136 | https://github.com/numpy/numpy/blob/v1.20.1/LICENSES_bundled.txt 137 | 138 | 139 | 140 | Open Source Software Licensed under the MIT License and Other Licenses of the Third-Party Components therein: 141 | -------------------------------------------------------------------- 142 | 1. opencv-python 143 | Copyright (c) Olli-Pekka Heinisuo 144 | 145 | 146 | Terms of the MIT License: 147 | -------------------------------------------------------------------- 148 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 149 | 150 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 151 | 152 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 153 | 154 | For the license of other third party components, please refer to the following URL: 155 | https://github.com/opencv/opencv-python/blob/48/README.md 156 | 157 | 158 | Open Source Software Licensed under the MIT License and Other Licenses of the Third-Party Components therein: 159 | -------------------------------------------------------------------- 160 | 1. tqdm 161 | Copyright (c) 2013 noamraph 162 | 163 | 164 | A copy of the MIT License is included in this file. 165 | 166 | For the license of other third party components, please refer to the following URL: 167 | https://github.com/tqdm/tqdm/blob/v4.56.2/LICENCE 168 | 169 | 170 | 171 | Open Source Software Licensed under the HPND License: 172 | -------------------------------------------------------------------- 173 | 1. Pillow 174 | Copyright © 2010-2021 by Alex Clark and contributors 175 | 176 | 177 | Terms of the HPND License: 178 | -------------------------------------------------------------------- 179 | By obtaining, using, and/or copying this software and/or its associated 180 | documentation, you agree that you have read, understood, and will comply 181 | with the following terms and conditions: 182 | 183 | Permission to use, copy, modify, and distribute this software and its 184 | associated documentation for any purpose and without fee is hereby granted, 185 | provided that the above copyright notice appears in all copies, and that 186 | both that copyright notice and this permission notice appear in supporting 187 | documentation, and that the name of Secret Labs AB or the author not be 188 | used in advertising or publicity pertaining to distribution of the software 189 | without specific, written prior permission. 190 | 191 | SECRET LABS AB AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS 192 | SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. 193 | IN NO EVENT SHALL SECRET LABS AB OR THE AUTHOR BE LIABLE FOR ANY SPECIAL, 194 | INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM 195 | LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE 196 | OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR 197 | PERFORMANCE OF THIS SOFTWARE. 198 | -------------------------------------------------------------------------------- /README.assets/a__,_a___and_a___are_standing_near_a_forest---baseline---e156a9f7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/Mix-of-Show/51270fd5f9907cada8f416a4eff191a74f842660/README.assets/a__,_a___and_a___are_standing_near_a_forest---baseline---e156a9f7.png -------------------------------------------------------------------------------- /README.assets/a__,_a___and_a___near_the_castle,_4K,_high_quality,_high_resolution,_best_quality---baseline---e27b4344.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/Mix-of-Show/51270fd5f9907cada8f416a4eff191a74f842660/README.assets/a__,_a___and_a___near_the_castle,_4K,_high_quality,_high_resolution,_best_quality---baseline---e27b4344.png -------------------------------------------------------------------------------- /README.assets/base5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/Mix-of-Show/51270fd5f9907cada8f416a4eff191a74f842660/README.assets/base5.png -------------------------------------------------------------------------------- /README.assets/concept_list.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/Mix-of-Show/51270fd5f9907cada8f416a4eff191a74f842660/README.assets/concept_list.jpg -------------------------------------------------------------------------------- /README.assets/lora_weight.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/Mix-of-Show/51270fd5f9907cada8f416a4eff191a74f842660/README.assets/lora_weight.jpg -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mix-of-Show 2 | 3 | 🚩🚩🚩: Main branch for community has been released (keep updating the main branch). 4 | 5 | ------ 6 | 7 | Official codes for Mix-of-Show. This branch is for application, including simplified codes, memory/speed optimization and performance improvement. For research purpose, please refer to original [research branch](https://github.com/TencentARC/Mix-of-Show/tree/research_branch) (paper results, evaluation, and comparison methods). 8 | 9 | **[NeurIPS 2023]**- **[Mix-of-Show: Decentralized Low-Rank Adaptation for Multi-Concept Customization of Diffusion Models](https://arxiv.org/abs/2305.18292)** 10 |
11 | [Yuchao Gu](https://ycgu.site/), [Xintao Wang](https://xinntao.github.io/), [Jay Zhangjie Wu](https://zhangjiewu.github.io/), [Yunjun Shi](https://yujun-shi.github.io/), [Yunpeng Chen](https://cypw.github.io/), Zihan Fan, Wuyou Xiao, [Rui Zhao](https://ruizhaocv.github.io/), Shuning Chang, [Weijia Wu](https://weijiawu.github.io/), [Yixiao Ge](https://geyixiao.com/), Ying Shan, [Mike Zheng Shou](https://sites.google.com/view/showlab) 12 |
13 | 14 | [![Project Website](https://img.shields.io/badge/Project-Website-orange)](https://showlab.github.io/Mix-of-Show/)[![arXiv](https://img.shields.io/badge/arXiv-2305.18292-b31b1b.svg)](https://arxiv.org/abs/2305.18292) 15 | 16 | ## 📋 Results 17 | 18 | ### Single-Concept Results 19 | 20 | Difference from LoRA: 21 | 22 | 1) In ED-LoRA, the embedding (LoRA weight=0) already encodes stable identity (use harry potter as example): 23 | 24 | ![lora_weight](./README.assets/lora_weight.jpg) 25 | 26 | 2) Based on the ED-LoRA, we can support multi-concept fusion without much identity loss. 27 | 28 | ### Multi-Concept Results 29 | 30 | **Concept List:** 31 | 32 | ![concept_list](./README.assets/concept_list.jpg) 33 | 34 | **Anime Character**: 35 | 36 | ![a__,_a___and_a___are_standing_near_a_forest---baseline---e156a9f7](./README.assets/a__,_a___and_a___are_standing_near_a_forest---baseline---e156a9f7.png) 37 | 38 | **Real Character**: 39 | 40 | ![a__,_a___and_a___near_the_castle,_4K,_high_quality,_high_resolution,_best_quality---baseline---e27b4344](./README.assets/a__,_a___and_a___near_the_castle,_4K,_high_quality,_high_resolution,_best_quality---baseline---e27b4344.png) 41 | 42 | ![base5](./README.assets/base5.png) 43 | 44 | ------ 45 | 46 | Conneting Mix-of-Show with Stable Diffusion Video for Animatation: 47 | 48 | 49 | 50 | https://github.com/TencentARC/Mix-of-Show/assets/31696690/5a677e99-2c86-41dc-a9da-ba92b3155717 51 | 52 | 53 | 54 | 55 | ## 🚩 Updates/Todo List 56 | 57 | - [ ] StableDiffusion XL support. 58 | - [ ] Colab Demo. 59 | - [x] Oct. 8, 2023. Add Attention Reg & Quality Improvement. 60 | - [x] Oct. 3, 2023. Release Main Branch for Community. 61 | - [x] Jun. 12, 2023. Research Code Released. Please switch to [research branch](https://github.com/TencentARC/Mix-of-Show/tree/research_branch). 62 | 63 | ## :wrench: Dependencies and Installation 64 | 65 | - Python >= 3.9 (Recommend to use [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html)) 66 | - Diffusers==0.19.3 67 | - XFormer (is recommend to save memory) 68 | 69 | ## ⏬ Pretrained Model and Data Preparation 70 | 71 | ### Pretrained Model Preparation 72 | 73 | We adopt the [ChilloutMix](https://civitai.com/models/6424/chilloutmix) for real-world concepts, and [Anything-v4](https://huggingface.co/andite/anything-v4.0) for anime concepts. 74 | 75 | ```bash 76 | git clone https://github.com/TencentARC/Mix-of-Show.git 77 | 78 | cd experiments/pretrained_models 79 | 80 | # Diffusers-version ChilloutMix 81 | git-lfs clone https://huggingface.co/windwhinny/chilloutmix.git 82 | 83 | # Diffusers-version Anything-v4 84 | git-lfs clone https://huggingface.co/andite/anything-v4.0.git 85 | ``` 86 | 87 | ### Data Preparation 88 | 89 | Note: Data selection and tagging are important in single-concept tuning. We strongly recommend checking the data processing in [sd-scripts](https://github.com/kohya-ss/sd-scripts). **In our ED-LoRA, we do not require any regularization dataset.** The detailed dataset preparation steps can refer to [Dataset.md](docs/Dataset.md). Our preprocessed data used in this repo is available at [Google Drive](https://drive.google.com/file/d/1O5oev8861N_KmKtqefb45l3SiSblbo5O/view?usp=sharing). 90 | 91 | ## :computer: Single-Client Concept Tuning 92 | 93 | ### Step 1: Modify the Config 94 | 95 | Before tuning, it is essential to specify the data paths and adjust certain hyperparameters in the corresponding config file. Followings are some basic config settings to be modified. 96 | 97 | ```yaml 98 | datasets: 99 | train: 100 | # Concept data config 101 | concept_list: datasets/data_cfgs/edlora/single-concept/characters/anime/hina_amano.json 102 | replace_mapping: 103 | : # concept new token 104 | val_vis: 105 | # Validation prompt for visualization during tuning 106 | prompts: datasets/validation_prompts/single-concept/characters/test_girl.txt 107 | replace_mapping: 108 | : # Concept new token 109 | 110 | models: 111 | enable_edlora: true # true means ED-LoRA, false means vallina LoRA 112 | new_concept_token: + # Concept new token, use "+" to connect 113 | initializer_token: +girl 114 | # Init token, only need to revise the later one based on the semantic category of given concept 115 | 116 | val: 117 | val_during_save: true # When saving checkpoint, visualize sample results. 118 | compose_visualize: true # Compose all samples into a large grid figure for visualization 119 | ``` 120 | 121 | ### Step 2: Start Tuning 122 | 123 | We tune each concept with 2 A100 GPU. Similar to LoRA, community user can enable gradient accumulation, xformer, gradient checkpoint for tuning on one GPU. 124 | 125 | ```bash 126 | accelerate launch train_edlora.py -opt options/train/EDLoRA/real/8101_EDLoRA_potter_Cmix_B4_Repeat500.yml 127 | ``` 128 | 129 | ### Step 3: Sample 130 | 131 | **Download our trained model** from [Google Drive](https://drive.google.com/drive/folders/1ArvKsxj41PcWbw_UZcyc5NDcKEQjK8pl?usp=sharing). 132 | 133 | Direct sample image: 134 | 135 | ```python 136 | import torch 137 | from diffusers import DPMSolverMultistepScheduler 138 | from mixofshow.pipelines.pipeline_edlora import EDLoRAPipeline, StableDiffusionPipeline 139 | from mixofshow.utils.convert_edlora_to_diffusers import convert_edlora 140 | 141 | pretrained_model_path = 'experiments/pretrained_models/chilloutmix' 142 | lora_model_path = 'experiments/2002_EDLoRA_hermione_Cmix_B4_Iter1K/models/checkpoint-latest/edlora.pth' 143 | enable_edlora = True # True for edlora, False for lora 144 | 145 | pipeclass = EDLoRAPipeline if enable_edlora else StableDiffusionPipeline 146 | pipe = pipeclass.from_pretrained(pretrained_model_path, scheduler=DPMSolverMultistepScheduler.from_pretrained(pretrained_model_path, subfolder='scheduler'), torch_dtype=torch.float16).to('cuda') 147 | pipe, new_concept_cfg = convert_edlora(pipe, torch.load(lora_model_path), enable_edlora=enable_edlora, alpha=0.7) 148 | pipe.set_new_concept_cfg(new_concept_cfg) 149 | 150 | TOK = ' ' # the TOK is the concept name when training lora/edlora 151 | prompt = f'a {TOK} in front of eiffel tower, 4K, high quality, high resolution' 152 | negative_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality' 153 | image = pipe(prompt, negative_prompt=negative_prompt, height=768, width=512, num_inference_steps=50, guidance_scale=7.5).images[0] 154 | image.save('res.jpg') 155 | ``` 156 | 157 | Or sampling image grids for comprehensive visualization: specify the model path in test config, and run following command. 158 | 159 | ```bash 160 | python test_edlora.py -opt options/test/EDLoRA/human/8101_EDLoRA_potter_Cmix_B4_Repeat500.yml 161 | ``` 162 | 163 | ## :computer: Center-Node Concept Fusion 164 | 165 | ### Step 1: Collect Concept Models 166 | 167 | Collect all concept models you want to extend the pretrained model and modify the config in **datasets/data_cfgs/MixofShow/multi-concept/real/*** accordingly. 168 | 169 | ```yaml 170 | [ 171 | { 172 | "lora_path": "experiments/EDLoRA_Models/Base_Chilloutmix/characters/edlora_potter.pth", # ED-LoRA path 173 | "unet_alpha": 1.0, # usually use full identity = 1.0 174 | "text_encoder_alpha": 1.0, # usually use full identity = 1.0 175 | "concept_name": " " # new concept token 176 | }, 177 | { 178 | "lora_path": "experiments/EDLoRA_Models/Base_Chilloutmix/characters/edlora_hermione.pth", 179 | "unet_alpha": 1.0, 180 | "text_encoder_alpha": 1.0, 181 | "concept_name": " " 182 | }, 183 | 184 | ... # keep adding new concepts for extending the pretrained models 185 | ] 186 | ``` 187 | 188 | ### Step 2: Gradient Fusion 189 | 190 | ```bash 191 | bash fuse.sh 192 | ``` 193 | 194 | ### Step 3: Sample 195 | 196 | **Download our fused model** from [Google Drive](https://drive.google.com/drive/folders/1ArvKsxj41PcWbw_UZcyc5NDcKEQjK8pl?usp=sharing). 197 | 198 | **Single-concept sampling from fused model:** 199 | 200 | ```python 201 | import json 202 | import os 203 | 204 | import torch 205 | from diffusers import DPMSolverMultistepScheduler 206 | 207 | from mixofshow.pipelines.pipeline_edlora import EDLoRAPipeline 208 | 209 | pretrained_model_path = 'experiments/composed_edlora/chilloutmix/potter+hermione+thanos_chilloutmix/combined_model_base' 210 | enable_edlora = True # True for edlora, False for lora 211 | 212 | pipe = EDLoRAPipeline.from_pretrained(pretrained_model_path, scheduler=DPMSolverMultistepScheduler.from_pretrained(pretrained_model_path, subfolder='scheduler'), torch_dtype=torch.float16).to('cuda') 213 | with open(f'{pretrained_model_path}/new_concept_cfg.json', 'r') as fr: 214 | new_concept_cfg = json.load(fr) 215 | pipe.set_new_concept_cfg(new_concept_cfg) 216 | 217 | TOK = ' ' # the TOK is the concept name when training lora/edlora 218 | prompt = f'a {TOK} in front of mount fuji' 219 | negative_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality' 220 | 221 | image = pipe(prompt, negative_prompt=negative_prompt, height=1024, width=512, num_inference_steps=50, generator=torch.Generator('cuda').manual_seed(1), guidance_scale=7.5).images[0] 222 | 223 | image.save(f'res.jpg') 224 | ``` 225 | 226 | **Regionally controllable multi-concept sampling:** 227 | 228 | ```bash 229 | bash regionally_sample.sh 230 | ``` 231 | 232 | ## 📜 License and Acknowledgement 233 | 234 | This project is released under the [Apache 2.0 license](LICENSE).
235 | This codebase builds on [diffusers](https://github.com/huggingface/diffusers). Thanks for open-sourcing! Besides, we acknowledge following amazing open-sourcing projects: 236 | 237 | - LoRA for Diffusion Models (https://github.com/cloneofsimo/lora, https://github.com/kohya-ss/sd-scripts). 238 | 239 | 240 | - Custom Diffusion (https://github.com/adobe-research/custom-diffusion). 241 | 242 | 243 | - T2I-Adapter (https://github.com/TencentARC/T2I-Adapter). 244 | 245 | 246 | 247 | ## 🌏 Citation 248 | 249 | ```bibtex 250 | @article{gu2023mixofshow, 251 | title={Mix-of-Show: Decentralized Low-Rank Adaptation for Multi-Concept Customization of Diffusion Models}, 252 | author={Gu, Yuchao and Wang, Xintao and Wu, Jay Zhangjie and Shi, Yujun and Chen Yunpeng and Fan, Zihan and Xiao, Wuyou and Zhao, Rui and Chang, Shuning and Wu, Weijia and Ge, Yixiao and Shan Ying and Shou, Mike Zheng}, 253 | journal={arXiv preprint arXiv:2305.18292}, 254 | year={2023} 255 | } 256 | ``` 257 | 258 | 259 | 260 | ## 📧 Contact 261 | 262 | If you have any questions and improvement suggestions, please email Yuchao Gu (yuchaogu9710@gmail.com), or open an issue. 263 | -------------------------------------------------------------------------------- /datasets/benchmark_prompts/single-concept/characters/test_character.txt: -------------------------------------------------------------------------------- 1 | A photo of on the beach, small waves, detailed symmetric face, beautiful composition 2 | A , in front of Eiffel tower 3 | A , near the mount fuji 4 | A , in the forest 5 | A , walking on the street 6 | 7 | A , cyberpunk 2077, 4K, 3d render in unreal engine 8 | A watercolor painting of a 9 | A painting of a in the style of Vincent Van Gogh 10 | A painting of a in the style of Claude Monet 11 | A in the style of Pixel Art 12 | 13 | A sit on the chair 14 | A ride a horse 15 | A , wearing a headphone 16 | A , wearing a sunglass 17 | A , wearing a Santa hat 18 | 19 | A smiling 20 | An angry 21 | A running 22 | A jumping 23 | A is lying down 24 | -------------------------------------------------------------------------------- /datasets/benchmark_prompts/single-concept/objects/test_chair.txt: -------------------------------------------------------------------------------- 1 | A , near the beach 2 | A , near the Eiffel tower 3 | A , near the mount fuji 4 | A , in the living room 5 | A , in times square 6 | 7 | A , cyberpunk 2077, 4K, 3d render in unreal engine 8 | A watercolor painting of a 9 | A painting of a in the style of Vincent Van Gogh 10 | A painting of a in the style of Claude Monet 11 | A in the style of Pixel Art 12 | 13 | A girl sit on a 14 | A cat sit on 15 | A dog sit on 16 | A vase is placed on the 17 | A boy sit on the 18 | 19 | A grey 20 | A close view of 21 | A top view of 22 | A in rainbow colors 23 | A broken 24 | -------------------------------------------------------------------------------- /datasets/benchmark_prompts/single-concept/objects/test_pet.txt: -------------------------------------------------------------------------------- 1 | A , in the swimming pool 2 | A , in front of Eiffel tower 3 | A , near the mount fuji 4 | A , in the forest 5 | A , walking on the street 6 | 7 | A , cyberpunk 2077, 4K, 3d render in unreal engine 8 | A watercolor painting of a 9 | A painting of a in the style of Vincent Van Gogh 10 | A painting of a in the style of Claude Monet 11 | A in the style of Pixel Art 12 | 13 | A sit on the chair 14 | A on the boat 15 | A , wearing a headphone 16 | A , wearing a sunglass 17 | A playing with a ball 18 | 19 | A sad 20 | An angry 21 | A running 22 | A jumping 23 | A is lying down 24 | -------------------------------------------------------------------------------- /datasets/benchmark_prompts/single-concept/objects/test_plushy.txt: -------------------------------------------------------------------------------- 1 | A , in the ocean 2 | A , near the Eiffel tower 3 | A , near the mount fuji 4 | A , in the forest 5 | A , in times square 6 | 7 | A , cyberpunk 2077, 4K, 3d render in unreal engine 8 | A watercolor painting of a 9 | A painting of a in the style of Vincent Van Gogh 10 | A painting of a in the style of Claude Monet 11 | A in the style of Pixel Art 12 | 13 | A on a table 14 | A on a chair 15 | A on a skateboard 16 | A child is playing a 17 | A on a carpet 18 | 19 | A close view of 20 | A top view of 21 | A in rainbow colors 22 | a fallen 23 | A broken 24 | -------------------------------------------------------------------------------- /datasets/benchmark_prompts/single-concept/objects/test_table.txt: -------------------------------------------------------------------------------- 1 | A , near the beach 2 | A , near the Eiffel tower 3 | A , near the mount fuji 4 | A , in the living room 5 | A , in times square 6 | 7 | A , cyberpunk 2077, 4K, 3d render in unreal engine 8 | A watercolor painting of a 9 | A painting of a in the style of Vincent Van Gogh 10 | A painting of a in the style of Claude Monet 11 | A in the style of Pixel Art 12 | 13 | A cat sit on the 14 | A book is placed on 15 | A teapot is placed on 16 | A vase is placed on the 17 | A dog sit on the 18 | 19 | A close view of 20 | A top view of 21 | A bottom view of 22 | A in rainbow colors 23 | A broken 24 | -------------------------------------------------------------------------------- /datasets/benchmark_prompts/single-concept/objects/test_vase.txt: -------------------------------------------------------------------------------- 1 | A , in the ocean 2 | A , near the Eiffel tower 3 | A , near the mount fuji 4 | A , buried in the sands 5 | A , in times square 6 | 7 | A , cyberpunk 2077, 4K, 3d render in unreal engine 8 | A watercolor painting of a 9 | A painting of a in the style of Vincent Van Gogh 10 | A painting of a in the style of Claude Monet 11 | A in the style of Pixel Art 12 | 13 | A on a table 14 | A on a chair 15 | A with a colorful flower bouquet 16 | Milk poured into a 17 | A on a carpet 18 | 19 | A green 20 | A grey 21 | A multi-color 22 | a fallen 23 | A broken 24 | -------------------------------------------------------------------------------- /datasets/benchmark_prompts/single-concept/scenes/test_scene.txt: -------------------------------------------------------------------------------- 1 | A , in the snow 2 | A , at night 3 | A , in autumn 4 | A , in a sunny day 5 | A , in thunder and lightning 6 | 7 | A , cyberpunk 2077, 4K, 3d render in unreal engine 8 | A watercolor painting of a 9 | A painting of a in the style of Vincent Van Gogh 10 | A painting of a in the style of Claude Monet 11 | A in the style of Pixel Art 12 | 13 | A girl near the 14 | A boy near the 15 | A dog near the 16 | A cat near the 17 | Many people near the 18 | 19 | A in rainbow colors 20 | A made of metal 21 | A close view of 22 | A top view of 23 | A bottom view of 24 | -------------------------------------------------------------------------------- /datasets/data_cfgs/MixofShow/multi-concept/anime/hina+kario+tezuka_anythingv4.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "lora_path": "experiments/1001_EDLoRA_hina_Anyv4_B4_Repeat250/models/edlora_model-latest.pth", 4 | "unet_alpha": 1.0, 5 | "text_encoder_alpha": 1.0, 6 | "concept_name": " " 7 | }, 8 | { 9 | "lora_path": "experiments/1002_EDLoRA_kaori_Anyv4_B4_Repeat250/models/edlora_model-latest.pth", 10 | "unet_alpha": 1.0, 11 | "text_encoder_alpha": 1.0, 12 | "concept_name": " " 13 | }, 14 | { 15 | "lora_path": "experiments/1003_EDLoRA_tezuka_Anyv4_B4_Repeat250/models/edlora_model-latest.pth", 16 | "unet_alpha": 1.0, 17 | "text_encoder_alpha": 1.0, 18 | "concept_name": " " 19 | } 20 | ] 21 | -------------------------------------------------------------------------------- /datasets/data_cfgs/MixofShow/multi-concept/real/potter+hermione+thanos_chilloutmix.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "lora_path": "experiments/2101_EDLoRA_potter_Cmix_B4_Repeat250/models/edlora_model-latest.pth", 4 | "unet_alpha": 1.0, 5 | "text_encoder_alpha": 1.0, 6 | "concept_name": " " 7 | }, 8 | { 9 | "lora_path": "experiments/2102_EDLoRA_hermione_Cmix_B4_Repeat250/models/edlora_model-latest.pth", 10 | "unet_alpha": 1.0, 11 | "text_encoder_alpha": 1.0, 12 | "concept_name": " " 13 | }, 14 | { 15 | "lora_path": "experiments/2103_EDLoRA_thanos_Cmix_B4_Repeat250/models/edlora_model-latest.pth", 16 | "unet_alpha": 1.0, 17 | "text_encoder_alpha": 1.0, 18 | "concept_name": " " 19 | } 20 | ] 21 | -------------------------------------------------------------------------------- /datasets/data_cfgs/MixofShow/single-concept/characters/anime/hina_amano.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instance_prompt": "", 4 | "instance_data_dir": "datasets/data/characters/anime/Hina_Amano/image", 5 | "caption_dir": "datasets/data/characters/anime/Hina_Amano/caption", 6 | "mask_dir": "datasets/data/characters/anime/Hina_Amano/mask" 7 | } 8 | ] 9 | -------------------------------------------------------------------------------- /datasets/data_cfgs/MixofShow/single-concept/characters/anime/mitsuha_miyamizu.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instance_prompt": "", 4 | "instance_data_dir": "datasets/data/characters/anime/Mitsuha_Miyamizu/image", 5 | "caption_dir": "datasets/data/characters/anime/Mitsuha_Miyamizu/caption", 6 | "mask_dir": "datasets/data/characters/anime/Mitsuha_Miyamizu/mask" 7 | } 8 | ] 9 | -------------------------------------------------------------------------------- /datasets/data_cfgs/MixofShow/single-concept/characters/anime/miyazono_kaori.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instance_prompt": "", 4 | "instance_data_dir": "datasets/data/characters/anime/Miyazono_Kaori/image", 5 | "caption_dir": "datasets/data/characters/anime/Miyazono_Kaori/caption", 6 | "mask_dir": "datasets/data/characters/anime/Miyazono_Kaori/mask" 7 | } 8 | ] 9 | -------------------------------------------------------------------------------- /datasets/data_cfgs/MixofShow/single-concept/characters/anime/son_goku.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instance_prompt": "", 4 | "instance_data_dir": "datasets/data/characters/anime/Son_Goku/image", 5 | "caption_dir": "datasets/data/characters/anime/Son_Goku/caption", 6 | "mask_dir": "datasets/data/characters/anime/Son_Goku/mask" 7 | } 8 | ] 9 | -------------------------------------------------------------------------------- /datasets/data_cfgs/MixofShow/single-concept/characters/anime/tezuka_kunimitsu.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instance_prompt": "", 4 | "instance_data_dir": "datasets/data/characters/anime/Tezuka_Kunimitsu/image", 5 | "caption_dir": "datasets/data/characters/anime/Tezuka_Kunimitsu/caption", 6 | "mask_dir": "datasets/data/characters/anime/Tezuka_Kunimitsu/mask" 7 | } 8 | ] 9 | -------------------------------------------------------------------------------- /datasets/data_cfgs/MixofShow/single-concept/characters/real/batman.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instance_prompt": "", 4 | "instance_data_dir": "datasets/data/characters/real/Batman/image", 5 | "caption_dir": "datasets/data/characters/real/Batman/caption", 6 | "mask_dir": "datasets/data/characters/real/Batman/mask" 7 | } 8 | ] 9 | -------------------------------------------------------------------------------- /datasets/data_cfgs/MixofShow/single-concept/characters/real/hermione.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instance_prompt": "", 4 | "instance_data_dir": "datasets/data/characters/real/Hermione_Granger/image", 5 | "caption_dir": "datasets/data/characters/real/Hermione_Granger/caption", 6 | "mask_dir": "datasets/data/characters/real/Hermione_Granger/mask" 7 | } 8 | ] 9 | -------------------------------------------------------------------------------- /datasets/data_cfgs/MixofShow/single-concept/characters/real/ironman.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instance_prompt": "", 4 | "instance_data_dir": "datasets/data/characters/real/Ironman/image", 5 | "caption_dir": "datasets/data/characters/real/Ironman/caption", 6 | "mask_dir": "datasets/data/characters/real/Ironman/mask" 7 | } 8 | ] 9 | -------------------------------------------------------------------------------- /datasets/data_cfgs/MixofShow/single-concept/characters/real/potter.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instance_prompt": "", 4 | "instance_data_dir": "datasets/data/characters/real/Harry_Potter/image", 5 | "caption_dir": "datasets/data/characters/real/Harry_Potter/caption", 6 | "mask_dir": "datasets/data/characters/real/Harry_Potter/mask" 7 | } 8 | ] 9 | -------------------------------------------------------------------------------- /datasets/data_cfgs/MixofShow/single-concept/characters/real/thanos.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instance_prompt": "", 4 | "instance_data_dir": "datasets/data/characters/real/Thanos/image", 5 | "caption_dir": "datasets/data/characters/real/Thanos/caption", 6 | "mask_dir": "datasets/data/characters/real/Thanos/mask" 7 | } 8 | ] 9 | -------------------------------------------------------------------------------- /datasets/data_cfgs/MixofShow/single-concept/objects/real/B2.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instance_prompt": "", 4 | "instance_data_dir": "datasets/data/objects/real/airplane/B2/image", 5 | "mask_dir": "datasets/data/objects/real/airplane/B2/mask", 6 | "caption_dir": "datasets/data/objects/real/airplane/B2/caption" 7 | } 8 | ] 9 | -------------------------------------------------------------------------------- /datasets/data_cfgs/MixofShow/single-concept/objects/real/carA.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instance_prompt": "", 4 | "instance_data_dir": "datasets/data/objects/real/car/carA/image", 5 | "mask_dir": "datasets/data/objects/real/car/carA/mask", 6 | "caption_dir": "datasets/data/objects/real/car/carA/caption" 7 | } 8 | ] 9 | -------------------------------------------------------------------------------- /datasets/data_cfgs/MixofShow/single-concept/objects/real/catA.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instance_prompt": "", 4 | "instance_data_dir": "datasets/data/objects/real/cat/catA/image", 5 | "caption_dir": "datasets/data/objects/real/cat/catA/caption", 6 | "mask_dir": "datasets/data/objects/real/cat/catA/mask" 7 | } 8 | ] 9 | -------------------------------------------------------------------------------- /datasets/data_cfgs/MixofShow/single-concept/objects/real/chair.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instance_prompt": "", 4 | "instance_data_dir": "datasets/data/objects/real/chair/image", 5 | "caption_dir": "datasets/data/objects/real/chair/caption" 6 | } 7 | ] 8 | -------------------------------------------------------------------------------- /datasets/data_cfgs/MixofShow/single-concept/objects/real/dogA.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instance_prompt": "", 4 | "instance_data_dir": "datasets/data/objects/real/dog/dogA/image", 5 | "caption_dir": "datasets/data/objects/real/dog/dogA/caption", 6 | "mask_dir": "datasets/data/objects/real/dog/dogA/mask" 7 | } 8 | ] 9 | -------------------------------------------------------------------------------- /datasets/data_cfgs/MixofShow/single-concept/objects/real/dogB.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instance_prompt": "", 4 | "instance_data_dir": "datasets/data/objects/real/dog/dogB/image", 5 | "caption_dir": "datasets/data/objects/real/dog/dogB/caption", 6 | "mask_dir": "datasets/data/objects/real/dog/dogB/mask" 7 | } 8 | ] 9 | -------------------------------------------------------------------------------- /datasets/data_cfgs/MixofShow/single-concept/objects/real/f35.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instance_prompt": "", 4 | "instance_data_dir": "datasets/data/objects/real/airplane/f35/image", 5 | "mask_dir": "datasets/data/objects/real/airplane/f35/mask", 6 | "caption_dir": "datasets/data/objects/real/airplane/f35/caption" 7 | } 8 | ] 9 | -------------------------------------------------------------------------------- /datasets/data_cfgs/MixofShow/single-concept/objects/real/porsche_356a.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instance_prompt": "", 4 | "instance_data_dir": "datasets/data/objects/real/car/porsche_356a/image", 5 | "mask_dir": "datasets/data/objects/real/car/porsche_356a/mask", 6 | "caption_dir": "datasets/data/objects/real/car/porsche_356a/caption" 7 | } 8 | ] 9 | -------------------------------------------------------------------------------- /datasets/data_cfgs/MixofShow/single-concept/objects/real/sailboat.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instance_prompt": "", 4 | "instance_data_dir": "datasets/data/objects/real/boat/sailboat/image", 5 | "mask_dir": "datasets/data/objects/real/boat/sailboat/mask", 6 | "caption_dir": "datasets/data/objects/real/boat/sailboat/caption" 7 | } 8 | ] 9 | -------------------------------------------------------------------------------- /datasets/data_cfgs/MixofShow/single-concept/objects/real/table.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instance_prompt": "", 4 | "instance_data_dir": "datasets/data/objects/real/table/image", 5 | "caption_dir": "datasets/data/objects/real/table/caption" 6 | } 7 | ] 8 | -------------------------------------------------------------------------------- /datasets/data_cfgs/MixofShow/single-concept/objects/real/vase.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instance_prompt": "", 4 | "instance_data_dir": "datasets/data/objects/real/vase/image", 5 | "caption_dir": "datasets/data/objects/real/vase/caption" 6 | } 7 | ] 8 | -------------------------------------------------------------------------------- /datasets/data_cfgs/MixofShow/single-concept/objects/real/yacht.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instance_prompt": "", 4 | "instance_data_dir": "datasets/data/objects/real/boat/yacht/image", 5 | "mask_dir": "datasets/data/objects/real/boat/yacht/mask", 6 | "caption_dir": "datasets/data/objects/real/boat/yacht/caption" 7 | } 8 | ] 9 | -------------------------------------------------------------------------------- /datasets/data_cfgs/MixofShow/single-concept/scenes/real/pyramid.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instance_prompt": "", 4 | "instance_data_dir": "datasets/data/scenes/real/Pyramid/image", 5 | "caption_dir": "datasets/data/scenes/real/Pyramid/caption" 6 | } 7 | ] 8 | -------------------------------------------------------------------------------- /datasets/data_cfgs/MixofShow/single-concept/scenes/real/wululu.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instance_prompt": "", 4 | "instance_data_dir": "datasets/data/scenes/real/Wululu/image" 5 | } 6 | ] 7 | -------------------------------------------------------------------------------- /datasets/validation_prompts/single-concept/characters/test_girl.txt: -------------------------------------------------------------------------------- 1 | photo of a 2 | 3 | photo of a girl 4 | Ultra HD quality of is walking on the street 5 | a , wearing a red hat 6 | a , wearing a blue shirt 7 | a in front of eiffel tower 8 | a sit on the chair 9 | a photo of on the beach, small waves, detailed symmetric face, beautiful composition 10 | a pencil sketch of 11 | , cyberpunk 2077, 4K, 3d render in unreal engine 12 | -------------------------------------------------------------------------------- /datasets/validation_prompts/single-concept/characters/test_goku.txt: -------------------------------------------------------------------------------- 1 | photo of a 2 | 3 | photo of a goku 4 | Ultra HD quality of is walking on the street 5 | a in front of eiffel tower 6 | a sit on the chair 7 | a photo of on the beach, small waves, detailed symmetric face, beautiful composition 8 | a pencil sketch of 9 | , cyberpunk 2077, 4K, 3d render in unreal engine 10 | a , wearing a red hat 11 | a , wearing a blue shirt 12 | -------------------------------------------------------------------------------- /datasets/validation_prompts/single-concept/characters/test_man.txt: -------------------------------------------------------------------------------- 1 | photo of a 2 | 3 | photo of a man 4 | Ultra HD quality of is walking on the street 5 | a , wearing a red hat 6 | a , wearing a blue shirt 7 | a in front of eiffel tower 8 | a sit on the chair 9 | a photo of on the beach, small waves, detailed symmetric face, beautiful composition 10 | a pencil sketch of 11 | , cyberpunk 2077, 4K, 3d render in unreal engine 12 | -------------------------------------------------------------------------------- /datasets/validation_prompts/single-concept/characters/test_woman.txt: -------------------------------------------------------------------------------- 1 | photo of a 2 | 3 | photo of a woman 4 | Ultra HD quality of is walking on the street 5 | a , wearing a red hat 6 | a , wearing a blue shirt 7 | a in front of eiffel tower 8 | a sit on the chair 9 | a photo of on the beach, small waves, detailed symmetric face, beautiful composition 10 | a pencil sketch of 11 | , cyberpunk 2077, 4K, 3d render in unreal engine 12 | -------------------------------------------------------------------------------- /datasets/validation_prompts/single-concept/objects/test_airplane.txt: -------------------------------------------------------------------------------- 1 | photo of a 2 | 3 | photo of a airplane 4 | Ultra HD quality of is flying on the blue sky 5 | Ultra HD quality of is flying on the rainy sky, lightning 6 | -------------------------------------------------------------------------------- /datasets/validation_prompts/single-concept/objects/test_boat.txt: -------------------------------------------------------------------------------- 1 | photo of a 2 | 3 | photo of a boat 4 | Ultra HD quality of is driving on the sea 5 | Ultra HD quality of is driving on the sea, heavy rainy day, lightning 6 | -------------------------------------------------------------------------------- /datasets/validation_prompts/single-concept/objects/test_car.txt: -------------------------------------------------------------------------------- 1 | photo of a 2 | 3 | photo of a car 4 | Ultra HD quality of is driving on the street 5 | a driving down a curvy road in the countryside 6 | a in front of eiffel tower 7 | -------------------------------------------------------------------------------- /datasets/validation_prompts/single-concept/objects/test_cat.txt: -------------------------------------------------------------------------------- 1 | photo of a 2 | 3 | photo of a cat 4 | Ultra HD quality of is walking on the street 5 | a , wearing a headphone 6 | a in front of eiffel tower 7 | a sit on the chair 8 | a is swimming in the swimming pool 9 | -------------------------------------------------------------------------------- /datasets/validation_prompts/single-concept/objects/test_chair.txt: -------------------------------------------------------------------------------- 1 | photo of a 2 | 3 | photo of a chair 4 | a girl sit on a 5 | a cat sit on a 6 | a , indoor 7 | -------------------------------------------------------------------------------- /datasets/validation_prompts/single-concept/objects/test_dog.txt: -------------------------------------------------------------------------------- 1 | photo of a 2 | 3 | photo of a dog 4 | Ultra HD quality of is walking on the street 5 | a , wearing a headphone 6 | a in front of eiffel tower 7 | a sit on the chair 8 | a is swimming in the swimming pool 9 | -------------------------------------------------------------------------------- /datasets/validation_prompts/single-concept/objects/test_table.txt: -------------------------------------------------------------------------------- 1 | photo of a 2 | 3 | photo of a table 4 | a flowers on a 5 | a book on a 6 | a , near the swimming pool 7 | a , indoor 8 | -------------------------------------------------------------------------------- /datasets/validation_prompts/single-concept/objects/test_vase.txt: -------------------------------------------------------------------------------- 1 | photo of a 2 | 3 | photo of a vase 4 | a rose in a 5 | a on a table 6 | -------------------------------------------------------------------------------- /datasets/validation_prompts/single-concept/scenes/test_scene.txt: -------------------------------------------------------------------------------- 1 | photo of a 2 | 3 | photo of a pyramid 4 | photo of a rock 5 | photo of a , at night 6 | two people stand in front of 7 | -------------------------------------------------------------------------------- /datasets/validation_spatial_condition/characters-objects/bengio+lecun+chair.txt: -------------------------------------------------------------------------------- 1 | char1=' sit on a ' 2 | box1='[31, 2, 512, 376]' 3 | 4 | char2=' sit on a ' 5 | box2='[30, 644, 506, 1011]' 6 | 7 | region3='a ' 8 | region3='[110, 438, 342, 568]' 9 | 10 | region4='a ' 11 | region4='[290, 344, 488, 674]' 12 | -------------------------------------------------------------------------------- /datasets/validation_spatial_condition/characters-objects/bengio+lecun+chair_pose.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/Mix-of-Show/51270fd5f9907cada8f416a4eff191a74f842660/datasets/validation_spatial_condition/characters-objects/bengio+lecun+chair_pose.png -------------------------------------------------------------------------------- /datasets/validation_spatial_condition/characters-objects/bengio+lecun+chair_sketch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/Mix-of-Show/51270fd5f9907cada8f416a4eff191a74f842660/datasets/validation_spatial_condition/characters-objects/bengio+lecun+chair_sketch.png -------------------------------------------------------------------------------- /datasets/validation_spatial_condition/characters-objects/harry+catA+dogA.txt: -------------------------------------------------------------------------------- 1 | char1=' sit on a ' 2 | box1='[0, 0, 512, 400]' 3 | 4 | region3='a ' 5 | region3='[60, 501, 350, 706]' 6 | 7 | region4='a ' 8 | region4='[57, 692, 343, 940]' 9 | 10 | region5='a ' 11 | region5='[280, 423, 508, 983]' 12 | -------------------------------------------------------------------------------- /datasets/validation_spatial_condition/characters-objects/harry+catA+dogA_pose.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/Mix-of-Show/51270fd5f9907cada8f416a4eff191a74f842660/datasets/validation_spatial_condition/characters-objects/harry+catA+dogA_pose.png -------------------------------------------------------------------------------- /datasets/validation_spatial_condition/characters-objects/harry+catA+dogA_sketch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/Mix-of-Show/51270fd5f9907cada8f416a4eff191a74f842660/datasets/validation_spatial_condition/characters-objects/harry+catA+dogA_sketch.png -------------------------------------------------------------------------------- /datasets/validation_spatial_condition/characters-objects/harry_heminone_scene.txt: -------------------------------------------------------------------------------- 1 | char1=' ' 2 | box1='[0, 315, 512, 530]' 3 | 4 | region3=' ' 5 | region3='[0, 502, 512, 747]' 6 | 7 | char1=' ' 8 | box1='[221, 43, 512, 258]' 9 | 10 | region3=' ' 11 | region3='[228, 752, 512, 1016]' 12 | -------------------------------------------------------------------------------- /datasets/validation_spatial_condition/characters-objects/harry_heminone_scene_pose.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/Mix-of-Show/51270fd5f9907cada8f416a4eff191a74f842660/datasets/validation_spatial_condition/characters-objects/harry_heminone_scene_pose.png -------------------------------------------------------------------------------- /datasets/validation_spatial_condition/characters-objects/harry_heminone_scene_sketch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/Mix-of-Show/51270fd5f9907cada8f416a4eff191a74f842660/datasets/validation_spatial_condition/characters-objects/harry_heminone_scene_sketch.png -------------------------------------------------------------------------------- /datasets/validation_spatial_condition/multi-characters/anime_pose/hina_mitsuha_kario.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/Mix-of-Show/51270fd5f9907cada8f416a4eff191a74f842660/datasets/validation_spatial_condition/multi-characters/anime_pose/hina_mitsuha_kario.png -------------------------------------------------------------------------------- /datasets/validation_spatial_condition/multi-characters/anime_pose/hina_mitsuha_kario.txt: -------------------------------------------------------------------------------- 1 | char1=' ' 2 | box1='[61, 115, 512, 273]' 3 | 4 | char2=' ' 5 | box2='[49, 323, 512, 500]' 6 | 7 | char3=' ' 8 | box3='[53, 519, 512, 715]' 9 | -------------------------------------------------------------------------------- /datasets/validation_spatial_condition/multi-characters/anime_pose/hina_tezuka_kario.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/Mix-of-Show/51270fd5f9907cada8f416a4eff191a74f842660/datasets/validation_spatial_condition/multi-characters/anime_pose/hina_tezuka_kario.png -------------------------------------------------------------------------------- /datasets/validation_spatial_condition/multi-characters/anime_pose/hina_tezuka_kario.txt: -------------------------------------------------------------------------------- 1 | char1=' ' 2 | box1='[61, 115, 512, 273]' 3 | 4 | char2=' ' 5 | box2='[19, 292, 512, 512]' 6 | 7 | char3=' ' 8 | box3='[48, 519, 512, 706]' 9 | -------------------------------------------------------------------------------- /datasets/validation_spatial_condition/multi-characters/anime_pose/hina_tezuka_mitsuha_goku_kaori.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/Mix-of-Show/51270fd5f9907cada8f416a4eff191a74f842660/datasets/validation_spatial_condition/multi-characters/anime_pose/hina_tezuka_mitsuha_goku_kaori.png -------------------------------------------------------------------------------- /datasets/validation_spatial_condition/multi-characters/anime_pose/hina_tezuka_mitsuha_goku_kaori.txt: -------------------------------------------------------------------------------- 1 | char1=' ' 2 | box1='[61, 18, 512, 192]' 3 | 4 | char2=' ' 5 | box2='[20, 194, 512, 407]' 6 | 7 | char3=' ' 8 | box3='[82, 433, 512, 614]' 9 | 10 | char4='goku' 11 | box4='[9, 627, 512, 803]' 12 | 13 | char5=' ' 14 | box5='[71, 803, 512, 978]' 15 | 16 | 17 | 18 | char1_prompt='[a , near a lake]'char1_neg_prompt="[${context_neg_prompt}]"box1='[61, 18, 512, 166]'adptor_weight1="[1.0,1.0,1.0,1.0]" # fine to coarse 19 | char2_prompt='[a , near a lake]' 20 | char2_neg_prompt="[${context_neg_prompt}]" 21 | box2='[20, 167, 512, 387]'adptor_weight2="[1.0,1.0,1.0,1.0]" 22 | 23 | char3_prompt='[a , near a lake]' 24 | char3_neg_prompt="[${context_neg_prompt}]" 25 | box3='[82, 413, 512, 584]'adptor_weight3="[1.0,1.0,1.0,1.0]" 26 | char4_prompt='[a , near a lake]' 27 | char4_neg_prompt="[${context_neg_prompt}]" 28 | box4='[22, 615, 512, 793]' 29 | adptor_weight4="[1.0,1.0,1.0,1.0]" 30 | 31 | char5_prompt='[a , near a lake]' 32 | char5_neg_prompt="[${context_neg_prompt}]" 33 | box5='[71, 818, 512, 983]' 34 | adptor_weight5="[1.0,1.0,1.0,1.0]" 35 | -------------------------------------------------------------------------------- /datasets/validation_spatial_condition/multi-characters/anime_pose_2x/hina_tezuka_kario_2x.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/Mix-of-Show/51270fd5f9907cada8f416a4eff191a74f842660/datasets/validation_spatial_condition/multi-characters/anime_pose_2x/hina_tezuka_kario_2x.png -------------------------------------------------------------------------------- /datasets/validation_spatial_condition/multi-characters/real_pose/bengio_lecun_bengio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/Mix-of-Show/51270fd5f9907cada8f416a4eff191a74f842660/datasets/validation_spatial_condition/multi-characters/real_pose/bengio_lecun_bengio.png -------------------------------------------------------------------------------- /datasets/validation_spatial_condition/multi-characters/real_pose/bengio_lecun_bengio.txt: -------------------------------------------------------------------------------- 1 | char1=' ' 2 | box1='[6, 51, 512, 293]' 3 | 4 | char2=' ' 5 | box2='[1, 350, 512, 618]' 6 | 7 | char3=' ' 8 | box3='[3, 657, 512, 923]' 9 | -------------------------------------------------------------------------------- /datasets/validation_spatial_condition/multi-characters/real_pose/harry_hermione_thanos.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/Mix-of-Show/51270fd5f9907cada8f416a4eff191a74f842660/datasets/validation_spatial_condition/multi-characters/real_pose/harry_hermione_thanos.png -------------------------------------------------------------------------------- /datasets/validation_spatial_condition/multi-characters/real_pose/harry_hermione_thanos.txt: -------------------------------------------------------------------------------- 1 | char1=' ' 2 | box1='[4, 28, 512, 251]' 3 | 4 | char2=' ' 5 | box2='[7, 215, 512, 453]' 6 | 7 | char3=' ' 8 | box3='[1, 651, 512, 996]' 9 | -------------------------------------------------------------------------------- /datasets/validation_spatial_condition/multi-characters/real_pose_2x/harry_hermione_thanos_2x.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/Mix-of-Show/51270fd5f9907cada8f416a4eff191a74f842660/datasets/validation_spatial_condition/multi-characters/real_pose_2x/harry_hermione_thanos_2x.png -------------------------------------------------------------------------------- /datasets/validation_spatial_condition/multi-objects/dogA_catA_dogB.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/Mix-of-Show/51270fd5f9907cada8f416a4eff191a74f842660/datasets/validation_spatial_condition/multi-objects/dogA_catA_dogB.jpg -------------------------------------------------------------------------------- /datasets/validation_spatial_condition/multi-objects/dogA_catA_dogB.txt: -------------------------------------------------------------------------------- 1 | char1=' ' 2 | box1='[160, 76, 505, 350]' 3 | 4 | char2=' ' 5 | box2='[162, 370, 500, 685]' 6 | 7 | char3=' ' 8 | box3='[134, 666, 512, 1005]' 9 | -------------------------------------------------------------------------------- /datasets/validation_spatial_condition/multi-objects/two_chair_table_vase.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/Mix-of-Show/51270fd5f9907cada8f416a4eff191a74f842660/datasets/validation_spatial_condition/multi-objects/two_chair_table_vase.jpg -------------------------------------------------------------------------------- /datasets/validation_spatial_condition/multi-objects/two_chair_table_vase.txt: -------------------------------------------------------------------------------- 1 | char1=' ' 2 | box1='[150, 6, 463, 293]' 3 | 4 | char2='a ' 5 | box2='[53, 438, 302, 565]' 6 | 7 | char3=' ' 8 | box3='[160, 724, 457, 1002]' 9 | 10 | char4='a ' 11 | box4='[248, 344, 468, 664]' 12 | -------------------------------------------------------------------------------- /docs/Dataset.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/Mix-of-Show/51270fd5f9907cada8f416a4eff191a74f842660/docs/Dataset.md -------------------------------------------------------------------------------- /fuse.sh: -------------------------------------------------------------------------------- 1 | # fuse real character 2 | config_file="potter+hermione+thanos_chilloutmix" 3 | 4 | python gradient_fusion.py \ 5 | --concept_cfg="datasets/data_cfgs/MixofShow/multi-concept/real/${config_file}.json" \ 6 | --save_path="experiments/composed_edlora/chilloutmix/${config_file}" \ 7 | --pretrained_models="experiments/pretrained_models/chilloutmix" \ 8 | --optimize_textenc_iters=500 \ 9 | --optimize_unet_iters=50 10 | 11 | # fuse anime character 12 | config_file="hina+kario+tezuka_anythingv4" 13 | 14 | python gradient_fusion.py \ 15 | --concept_cfg="datasets/data_cfgs/MixofShow/multi-concept/anime/${config_file}.json" \ 16 | --save_path="experiments/composed_edlora/anythingv4/${config_file}" \ 17 | --pretrained_models="experiments/pretrained_models/anything-v4.0" \ 18 | --optimize_textenc_iters=500 \ 19 | --optimize_unet_iters=50 20 | -------------------------------------------------------------------------------- /mixofshow/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/Mix-of-Show/51270fd5f9907cada8f416a4eff191a74f842660/mixofshow/data/__init__.py -------------------------------------------------------------------------------- /mixofshow/data/lora_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | import re 5 | from pathlib import Path 6 | 7 | from PIL import Image 8 | from torch.utils.data import Dataset 9 | 10 | from mixofshow.data.pil_transform import PairCompose, build_transform 11 | 12 | 13 | class LoraDataset(Dataset): 14 | """ 15 | A dataset to prepare the instance and class images with the prompts for fine-tuning the model. 16 | It pre-processes the images and the tokenizes prompts. 17 | """ 18 | def __init__(self, opt): 19 | self.opt = opt 20 | self.instance_images_path = [] 21 | 22 | with open(opt['concept_list'], 'r') as f: 23 | concept_list = json.load(f) 24 | 25 | replace_mapping = opt.get('replace_mapping', {}) 26 | use_caption = opt.get('use_caption', False) 27 | use_mask = opt.get('use_mask', False) 28 | 29 | for concept in concept_list: 30 | instance_prompt = concept['instance_prompt'] 31 | caption_dir = concept.get('caption_dir') 32 | mask_dir = concept.get('mask_dir') 33 | 34 | instance_prompt = self.process_text(instance_prompt, replace_mapping) 35 | 36 | inst_img_path = [] 37 | for x in Path(concept['instance_data_dir']).iterdir(): 38 | if x.is_file() and x.name != '.DS_Store': 39 | basename = os.path.splitext(os.path.basename(x))[0] 40 | caption_path = os.path.join(caption_dir, f'{basename}.txt') if caption_dir is not None else None 41 | 42 | if use_caption and caption_path is not None and os.path.exists(caption_path): 43 | with open(caption_path, 'r') as fr: 44 | line = fr.readlines()[0] 45 | instance_prompt_image = self.process_text(line, replace_mapping) 46 | else: 47 | instance_prompt_image = instance_prompt 48 | 49 | if use_mask and mask_dir is not None: 50 | mask_path = os.path.join(mask_dir, f'{basename}.png') 51 | else: 52 | mask_path = None 53 | 54 | inst_img_path.append((x, instance_prompt_image, mask_path)) 55 | 56 | self.instance_images_path.extend(inst_img_path) 57 | 58 | random.shuffle(self.instance_images_path) 59 | self.num_instance_images = len(self.instance_images_path) 60 | 61 | self.instance_transform = PairCompose([ 62 | build_transform(transform_opt) 63 | for transform_opt in opt['instance_transform'] 64 | ]) 65 | 66 | def process_text(self, instance_prompt, replace_mapping): 67 | for k, v in replace_mapping.items(): 68 | instance_prompt = instance_prompt.replace(k, v) 69 | instance_prompt = instance_prompt.strip() 70 | instance_prompt = re.sub(' +', ' ', instance_prompt) 71 | return instance_prompt 72 | 73 | def __len__(self): 74 | return self.num_instance_images * self.opt['dataset_enlarge_ratio'] 75 | 76 | def __getitem__(self, index): 77 | example = {} 78 | instance_image, instance_prompt, instance_mask = self.instance_images_path[index % self.num_instance_images] 79 | instance_image = Image.open(instance_image).convert('RGB') 80 | 81 | extra_args = {'prompts': instance_prompt} 82 | if instance_mask is not None: 83 | instance_mask = Image.open(instance_mask).convert('L') 84 | extra_args.update({'mask': instance_mask}) 85 | 86 | instance_image, extra_args = self.instance_transform(instance_image, **extra_args) 87 | example['images'] = instance_image 88 | 89 | if 'mask' in extra_args: 90 | example['masks'] = extra_args['mask'] 91 | example['masks'] = example['masks'].unsqueeze(0) 92 | else: 93 | pass 94 | 95 | if 'img_mask' in extra_args: 96 | example['img_masks'] = extra_args['img_mask'] 97 | example['img_masks'] = example['img_masks'].unsqueeze(0) 98 | else: 99 | raise NotImplementedError 100 | 101 | example['prompts'] = extra_args['prompts'] 102 | return example 103 | -------------------------------------------------------------------------------- /mixofshow/data/pil_transform.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import random 3 | from copy import deepcopy 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torchvision.transforms.functional as F 10 | from PIL import Image 11 | from torchvision.transforms import CenterCrop, Normalize, RandomCrop, RandomHorizontalFlip, Resize 12 | from torchvision.transforms.functional import InterpolationMode 13 | 14 | from mixofshow.utils.registry import TRANSFORM_REGISTRY 15 | 16 | 17 | def build_transform(opt): 18 | """Build performance evaluator from options. 19 | Args: 20 | opt (dict): Configuration. 21 | """ 22 | opt = deepcopy(opt) 23 | transform_type = opt.pop('type') 24 | transform = TRANSFORM_REGISTRY.get(transform_type)(**opt) 25 | return transform 26 | 27 | 28 | TRANSFORM_REGISTRY.register(Normalize) 29 | TRANSFORM_REGISTRY.register(Resize) 30 | TRANSFORM_REGISTRY.register(RandomHorizontalFlip) 31 | TRANSFORM_REGISTRY.register(CenterCrop) 32 | TRANSFORM_REGISTRY.register(RandomCrop) 33 | 34 | 35 | @TRANSFORM_REGISTRY.register() 36 | class BILINEARResize(Resize): 37 | def __init__(self, size): 38 | super(BILINEARResize, 39 | self).__init__(size, interpolation=InterpolationMode.BILINEAR) 40 | 41 | 42 | @TRANSFORM_REGISTRY.register() 43 | class PairRandomCrop(nn.Module): 44 | def __init__(self, size): 45 | super().__init__() 46 | if isinstance(size, int): 47 | self.height, self.width = size, size 48 | else: 49 | self.height, self.width = size 50 | 51 | def forward(self, img, **kwargs): 52 | img_width, img_height = img.size 53 | mask_width, mask_height = kwargs['mask'].size 54 | 55 | assert img_height >= self.height and img_height == mask_height 56 | assert img_width >= self.width and img_width == mask_width 57 | 58 | x = random.randint(0, img_width - self.width) 59 | y = random.randint(0, img_height - self.height) 60 | img = F.crop(img, y, x, self.height, self.width) 61 | kwargs['mask'] = F.crop(kwargs['mask'], y, x, self.height, self.width) 62 | return img, kwargs 63 | 64 | 65 | @TRANSFORM_REGISTRY.register() 66 | class ToTensor(nn.Module): 67 | def __init__(self) -> None: 68 | super().__init__() 69 | 70 | def forward(self, pic): 71 | return F.to_tensor(pic) 72 | 73 | def __repr__(self) -> str: 74 | return f'{self.__class__.__name__}()' 75 | 76 | 77 | @TRANSFORM_REGISTRY.register() 78 | class PairRandomHorizontalFlip(torch.nn.Module): 79 | def __init__(self, p=0.5): 80 | super().__init__() 81 | self.p = p 82 | 83 | def forward(self, img, **kwargs): 84 | if torch.rand(1) < self.p: 85 | kwargs['mask'] = F.hflip(kwargs['mask']) 86 | return F.hflip(img), kwargs 87 | return img, kwargs 88 | 89 | 90 | @TRANSFORM_REGISTRY.register() 91 | class PairResize(nn.Module): 92 | def __init__(self, size): 93 | super().__init__() 94 | self.resize = Resize(size=size) 95 | 96 | def forward(self, img, **kwargs): 97 | kwargs['mask'] = self.resize(kwargs['mask']) 98 | img = self.resize(img) 99 | return img, kwargs 100 | 101 | 102 | class PairCompose(nn.Module): 103 | def __init__(self, transforms): 104 | super().__init__() 105 | self.transforms = transforms 106 | 107 | def __call__(self, img, **kwargs): 108 | for t in self.transforms: 109 | if len(inspect.signature(t.forward).parameters 110 | ) == 1: # count how many args, not count self 111 | img = t(img) 112 | else: 113 | img, kwargs = t(img, **kwargs) 114 | return img, kwargs 115 | 116 | def __repr__(self) -> str: 117 | format_string = self.__class__.__name__ + '(' 118 | for t in self.transforms: 119 | format_string += '\n' 120 | format_string += f' {t}' 121 | format_string += '\n)' 122 | return format_string 123 | 124 | 125 | @TRANSFORM_REGISTRY.register() 126 | class HumanResizeCropFinalV3(nn.Module): 127 | def __init__(self, size, crop_p=0.5): 128 | super().__init__() 129 | self.size = size 130 | self.crop_p = crop_p 131 | self.random_crop = RandomCrop(size=size) 132 | self.paired_random_crop = PairRandomCrop(size=size) 133 | 134 | def forward(self, img, **kwargs): 135 | # step 1: short edge resize to 512 136 | img = F.resize(img, size=self.size) 137 | if 'mask' in kwargs: 138 | kwargs['mask'] = F.resize(kwargs['mask'], size=self.size) 139 | 140 | # step 2: random crop 141 | width, height = img.size 142 | if random.random() < self.crop_p: 143 | if height > width: 144 | crop_pos = random.randint(0, height - width) 145 | img = F.crop(img, 0, 0, width + crop_pos, width) 146 | if 'mask' in kwargs: 147 | kwargs['mask'] = F.crop(kwargs['mask'], 0, 0, width + crop_pos, width) 148 | else: 149 | if 'mask' in kwargs: 150 | img, kwargs = self.paired_random_crop(img, **kwargs) 151 | else: 152 | img = self.random_crop(img) 153 | else: 154 | img = img 155 | 156 | # step 3: long edge resize 157 | img = F.resize(img, size=self.size - 1, max_size=self.size) 158 | if 'mask' in kwargs: 159 | kwargs['mask'] = F.resize(kwargs['mask'], size=self.size - 1, max_size=self.size) 160 | 161 | new_width, new_height = img.size 162 | 163 | img = np.array(img) 164 | if 'mask' in kwargs: 165 | kwargs['mask'] = np.array(kwargs['mask']) / 255 166 | 167 | start_y = random.randint(0, 512 - new_height) 168 | start_x = random.randint(0, 512 - new_width) 169 | 170 | res_img = np.zeros((self.size, self.size, 3), dtype=np.uint8) 171 | res_mask = np.zeros((self.size, self.size)) 172 | res_img_mask = np.zeros((self.size, self.size)) 173 | 174 | res_img[start_y:start_y + new_height, start_x:start_x + new_width, :] = img 175 | if 'mask' in kwargs: 176 | res_mask[start_y:start_y + new_height, start_x:start_x + new_width] = kwargs['mask'] 177 | kwargs['mask'] = res_mask 178 | 179 | res_img_mask[start_y:start_y + new_height, start_x:start_x + new_width] = 1 180 | kwargs['img_mask'] = res_img_mask 181 | 182 | img = Image.fromarray(res_img) 183 | 184 | if 'mask' in kwargs: 185 | kwargs['mask'] = cv2.resize(kwargs['mask'], (self.size // 8, self.size // 8), cv2.INTER_NEAREST) 186 | kwargs['mask'] = torch.from_numpy(kwargs['mask']) 187 | kwargs['img_mask'] = cv2.resize(kwargs['img_mask'], (self.size // 8, self.size // 8), cv2.INTER_NEAREST) 188 | kwargs['img_mask'] = torch.from_numpy(kwargs['img_mask']) 189 | return img, kwargs 190 | 191 | 192 | @TRANSFORM_REGISTRY.register() 193 | class ResizeFillMaskNew(nn.Module): 194 | def __init__(self, size, crop_p, scale_ratio): 195 | super().__init__() 196 | self.size = size 197 | self.crop_p = crop_p 198 | self.scale_ratio = scale_ratio 199 | self.random_crop = RandomCrop(size=size) 200 | self.paired_random_crop = PairRandomCrop(size=size) 201 | 202 | def forward(self, img, **kwargs): 203 | # width, height = img.size 204 | 205 | # step 1: short edge resize to 512 206 | img = F.resize(img, size=self.size) 207 | if 'mask' in kwargs: 208 | kwargs['mask'] = F.resize(kwargs['mask'], size=self.size) 209 | 210 | # step 2: random crop 211 | if random.random() < self.crop_p: 212 | if 'mask' in kwargs: 213 | img, kwargs = self.paired_random_crop(img, **kwargs) # 51 214 | else: 215 | img = self.random_crop(img) # 512 216 | else: 217 | # long edge resize 218 | img = F.resize(img, size=self.size - 1, max_size=self.size) 219 | if 'mask' in kwargs: 220 | kwargs['mask'] = F.resize(kwargs['mask'], size=self.size - 1, max_size=self.size) 221 | 222 | # step 3: random aspect ratio 223 | width, height = img.size 224 | ratio = random.uniform(*self.scale_ratio) 225 | 226 | img = F.resize(img, size=(int(height * ratio), int(width * ratio))) 227 | if 'mask' in kwargs: 228 | kwargs['mask'] = F.resize(kwargs['mask'], size=(int(height * ratio), int(width * ratio)), interpolation=0) 229 | 230 | # step 4: random place 231 | new_width, new_height = img.size 232 | 233 | img = np.array(img) 234 | if 'mask' in kwargs: 235 | kwargs['mask'] = np.array(kwargs['mask']) / 255 236 | 237 | start_y = random.randint(0, 512 - new_height) 238 | start_x = random.randint(0, 512 - new_width) 239 | 240 | res_img = np.zeros((self.size, self.size, 3), dtype=np.uint8) 241 | res_mask = np.zeros((self.size, self.size)) 242 | res_img_mask = np.zeros((self.size, self.size)) 243 | 244 | res_img[start_y:start_y + new_height, start_x:start_x + new_width, :] = img 245 | if 'mask' in kwargs: 246 | res_mask[start_y:start_y + new_height, start_x:start_x + new_width] = kwargs['mask'] 247 | kwargs['mask'] = res_mask 248 | 249 | res_img_mask[start_y:start_y + new_height, start_x:start_x + new_width] = 1 250 | kwargs['img_mask'] = res_img_mask 251 | 252 | img = Image.fromarray(res_img) 253 | 254 | if 'mask' in kwargs: 255 | kwargs['mask'] = cv2.resize(kwargs['mask'], (self.size // 8, self.size // 8), cv2.INTER_NEAREST) 256 | kwargs['mask'] = torch.from_numpy(kwargs['mask']) 257 | kwargs['img_mask'] = cv2.resize(kwargs['img_mask'], (self.size // 8, self.size // 8), cv2.INTER_NEAREST) 258 | kwargs['img_mask'] = torch.from_numpy(kwargs['img_mask']) 259 | 260 | return img, kwargs 261 | 262 | 263 | @TRANSFORM_REGISTRY.register() 264 | class ShuffleCaption(nn.Module): 265 | def __init__(self, keep_token_num): 266 | super().__init__() 267 | self.keep_token_num = keep_token_num 268 | 269 | def forward(self, img, **kwargs): 270 | prompts = kwargs['prompts'].strip() 271 | 272 | fixed_tokens = [] 273 | flex_tokens = [t.strip() for t in prompts.strip().split(',')] 274 | if self.keep_token_num > 0: 275 | fixed_tokens = flex_tokens[:self.keep_token_num] 276 | flex_tokens = flex_tokens[self.keep_token_num:] 277 | 278 | random.shuffle(flex_tokens) 279 | prompts = ', '.join(fixed_tokens + flex_tokens) 280 | kwargs['prompts'] = prompts 281 | return img, kwargs 282 | 283 | 284 | @TRANSFORM_REGISTRY.register() 285 | class EnhanceText(nn.Module): 286 | def __init__(self, enhance_type='object'): 287 | super().__init__() 288 | STYLE_TEMPLATE = [ 289 | 'a painting in the style of {}', 290 | 'a rendering in the style of {}', 291 | 'a cropped painting in the style of {}', 292 | 'the painting in the style of {}', 293 | 'a clean painting in the style of {}', 294 | 'a dirty painting in the style of {}', 295 | 'a dark painting in the style of {}', 296 | 'a picture in the style of {}', 297 | 'a cool painting in the style of {}', 298 | 'a close-up painting in the style of {}', 299 | 'a bright painting in the style of {}', 300 | 'a cropped painting in the style of {}', 301 | 'a good painting in the style of {}', 302 | 'a close-up painting in the style of {}', 303 | 'a rendition in the style of {}', 304 | 'a nice painting in the style of {}', 305 | 'a small painting in the style of {}', 306 | 'a weird painting in the style of {}', 307 | 'a large painting in the style of {}', 308 | ] 309 | 310 | OBJECT_TEMPLATE = [ 311 | 'a photo of a {}', 312 | 'a rendering of a {}', 313 | 'a cropped photo of the {}', 314 | 'the photo of a {}', 315 | 'a photo of a clean {}', 316 | 'a photo of a dirty {}', 317 | 'a dark photo of the {}', 318 | 'a photo of my {}', 319 | 'a photo of the cool {}', 320 | 'a close-up photo of a {}', 321 | 'a bright photo of the {}', 322 | 'a cropped photo of a {}', 323 | 'a photo of the {}', 324 | 'a good photo of the {}', 325 | 'a photo of one {}', 326 | 'a close-up photo of the {}', 327 | 'a rendition of the {}', 328 | 'a photo of the clean {}', 329 | 'a rendition of a {}', 330 | 'a photo of a nice {}', 331 | 'a good photo of a {}', 332 | 'a photo of the nice {}', 333 | 'a photo of the small {}', 334 | 'a photo of the weird {}', 335 | 'a photo of the large {}', 336 | 'a photo of a cool {}', 337 | 'a photo of a small {}', 338 | ] 339 | 340 | HUMAN_TEMPLATE = [ 341 | 'a photo of a {}', 'a photo of one {}', 'a photo of the {}', 342 | 'the photo of a {}', 'a rendering of a {}', 343 | 'a rendition of the {}', 'a rendition of a {}', 344 | 'a cropped photo of the {}', 'a cropped photo of a {}', 345 | 'a bad photo of the {}', 'a bad photo of a {}', 346 | 'a photo of a weird {}', 'a weird photo of a {}', 347 | 'a bright photo of the {}', 'a good photo of the {}', 348 | 'a photo of a nice {}', 'a good photo of a {}', 349 | 'a photo of a cool {}', 'a bright photo of the {}' 350 | ] 351 | 352 | if enhance_type == 'object': 353 | self.templates = OBJECT_TEMPLATE 354 | elif enhance_type == 'style': 355 | self.templates = STYLE_TEMPLATE 356 | elif enhance_type == 'human': 357 | self.templates = HUMAN_TEMPLATE 358 | else: 359 | raise NotImplementedError 360 | 361 | def forward(self, img, **kwargs): 362 | concept_token = kwargs['prompts'].strip() 363 | kwargs['prompts'] = random.choice(self.templates).format(concept_token) 364 | return img, kwargs 365 | -------------------------------------------------------------------------------- /mixofshow/data/prompt_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import re 4 | 5 | import torch 6 | from torch.utils.data import Dataset 7 | 8 | 9 | class PromptDataset(Dataset): 10 | 'A simple dataset to prepare the prompts to generate class images on multiple GPUs.' 11 | 12 | def __init__(self, opt): 13 | self.opt = opt 14 | 15 | self.prompts = opt['prompts'] 16 | 17 | if isinstance(self.prompts, list): 18 | self.prompts = self.prompts 19 | elif os.path.exists(self.prompts): 20 | # is file 21 | with open(self.prompts, 'r') as fr: 22 | lines = fr.readlines() 23 | lines = [item.strip() for item in lines] 24 | self.prompts = lines 25 | else: 26 | raise ValueError( 27 | 'prompts should be a prompt file path or prompt list, please check!' 28 | ) 29 | 30 | self.prompts = self.replace_placeholder(self.prompts) 31 | 32 | self.num_samples_per_prompt = opt['num_samples_per_prompt'] 33 | self.prompts_to_generate = [ 34 | (p, i) for i in range(1, self.num_samples_per_prompt + 1) 35 | for p in self.prompts 36 | ] 37 | self.latent_size = opt['latent_size'] # (4,64,64) 38 | self.share_latent_across_prompt = opt.get('share_latent_across_prompt', True) # (true, false) 39 | 40 | def replace_placeholder(self, prompts): 41 | # replace placehold token 42 | replace_mapping = self.opt.get('replace_mapping', {}) 43 | new_lines = [] 44 | for line in self.prompts: 45 | if len(line.strip()) == 0: 46 | continue 47 | for k, v in replace_mapping.items(): 48 | line = line.replace(k, v) 49 | line = line.strip() 50 | line = re.sub(' +', ' ', line) 51 | new_lines.append(line) 52 | return new_lines 53 | 54 | def __len__(self): 55 | return len(self.prompts_to_generate) 56 | 57 | def __getitem__(self, index): 58 | prompt, indice = self.prompts_to_generate[index] 59 | example = {} 60 | example['prompts'] = prompt 61 | example['indices'] = indice 62 | if self.share_latent_across_prompt: 63 | seed = indice 64 | else: 65 | seed = random.randint(0, 1000) 66 | example['latents'] = torch.randn(self.latent_size, generator=torch.manual_seed(seed)) 67 | return example 68 | -------------------------------------------------------------------------------- /mixofshow/models/edlora.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | from diffusers.models.attention_processor import AttnProcessor 6 | from diffusers.utils.import_utils import is_xformers_available 7 | 8 | if is_xformers_available(): 9 | import xformers 10 | 11 | 12 | def remove_edlora_unet_attention_forward(unet): 13 | def change_forward(unet): # omit proceesor in new diffusers 14 | for name, layer in unet.named_children(): 15 | if layer.__class__.__name__ == 'Attention' and name == 'attn2': 16 | layer.set_processor(AttnProcessor()) 17 | else: 18 | change_forward(layer) 19 | change_forward(unet) 20 | 21 | 22 | class EDLoRA_Control_AttnProcessor: 23 | r""" 24 | Default processor for performing attention-related computations. 25 | """ 26 | def __init__(self, cross_attention_idx, place_in_unet, controller, attention_op=None): 27 | self.cross_attention_idx = cross_attention_idx 28 | self.place_in_unet = place_in_unet 29 | self.controller = controller 30 | self.attention_op = attention_op 31 | 32 | def __call__( 33 | self, 34 | attn, 35 | hidden_states, 36 | encoder_hidden_states=None, 37 | attention_mask=None, 38 | temb=None, 39 | ): 40 | residual = hidden_states 41 | 42 | if attn.spatial_norm is not None: 43 | hidden_states = attn.spatial_norm(hidden_states, temb) 44 | 45 | input_ndim = hidden_states.ndim 46 | 47 | if input_ndim == 4: 48 | batch_size, channel, height, width = hidden_states.shape 49 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 50 | 51 | if encoder_hidden_states is None: 52 | is_cross = False 53 | encoder_hidden_states = hidden_states 54 | else: 55 | is_cross = True 56 | if len(encoder_hidden_states.shape) == 4: # multi-layer embedding 57 | encoder_hidden_states = encoder_hidden_states[:, self.cross_attention_idx, ...] 58 | else: # single layer embedding 59 | encoder_hidden_states = encoder_hidden_states 60 | 61 | assert not attn.norm_cross 62 | 63 | batch_size, sequence_length, _ = encoder_hidden_states.shape 64 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 65 | 66 | if attn.group_norm is not None: 67 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 68 | 69 | query = attn.to_q(hidden_states) 70 | key = attn.to_k(encoder_hidden_states) 71 | value = attn.to_v(encoder_hidden_states) 72 | 73 | query = attn.head_to_batch_dim(query).contiguous() 74 | key = attn.head_to_batch_dim(key).contiguous() 75 | value = attn.head_to_batch_dim(value).contiguous() 76 | 77 | if is_xformers_available() and not is_cross: 78 | hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) 79 | hidden_states = hidden_states.to(query.dtype) 80 | else: 81 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 82 | attention_probs = self.controller(attention_probs, is_cross, self.place_in_unet) 83 | hidden_states = torch.bmm(attention_probs, value) 84 | 85 | hidden_states = attn.batch_to_head_dim(hidden_states) 86 | 87 | # linear proj 88 | hidden_states = attn.to_out[0](hidden_states) 89 | # dropout 90 | hidden_states = attn.to_out[1](hidden_states) 91 | 92 | if input_ndim == 4: 93 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 94 | 95 | if attn.residual_connection: 96 | hidden_states = hidden_states + residual 97 | 98 | hidden_states = hidden_states / attn.rescale_output_factor 99 | 100 | return hidden_states 101 | 102 | 103 | class EDLoRA_AttnProcessor: 104 | def __init__(self, cross_attention_idx, attention_op=None): 105 | self.attention_op = attention_op 106 | self.cross_attention_idx = cross_attention_idx 107 | 108 | def __call__( 109 | self, 110 | attn, 111 | hidden_states, 112 | encoder_hidden_states=None, 113 | attention_mask=None, 114 | temb=None, 115 | ): 116 | residual = hidden_states 117 | 118 | if attn.spatial_norm is not None: 119 | hidden_states = attn.spatial_norm(hidden_states, temb) 120 | 121 | input_ndim = hidden_states.ndim 122 | 123 | if input_ndim == 4: 124 | batch_size, channel, height, width = hidden_states.shape 125 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 126 | 127 | if encoder_hidden_states is None: 128 | encoder_hidden_states = hidden_states 129 | else: 130 | if len(encoder_hidden_states.shape) == 4: # multi-layer embedding 131 | encoder_hidden_states = encoder_hidden_states[:, self.cross_attention_idx, ...] 132 | else: # single layer embedding 133 | encoder_hidden_states = encoder_hidden_states 134 | 135 | assert not attn.norm_cross 136 | 137 | batch_size, sequence_length, _ = encoder_hidden_states.shape 138 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 139 | 140 | if attn.group_norm is not None: 141 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 142 | 143 | query = attn.to_q(hidden_states) 144 | key = attn.to_k(encoder_hidden_states) 145 | value = attn.to_v(encoder_hidden_states) 146 | 147 | query = attn.head_to_batch_dim(query).contiguous() 148 | key = attn.head_to_batch_dim(key).contiguous() 149 | value = attn.head_to_batch_dim(value).contiguous() 150 | 151 | if is_xformers_available(): 152 | hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) 153 | hidden_states = hidden_states.to(query.dtype) 154 | else: 155 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 156 | hidden_states = torch.bmm(attention_probs, value) 157 | 158 | hidden_states = attn.batch_to_head_dim(hidden_states) 159 | 160 | # linear proj 161 | hidden_states = attn.to_out[0](hidden_states) 162 | # dropout 163 | hidden_states = attn.to_out[1](hidden_states) 164 | 165 | if input_ndim == 4: 166 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 167 | 168 | if attn.residual_connection: 169 | hidden_states = hidden_states + residual 170 | 171 | hidden_states = hidden_states / attn.rescale_output_factor 172 | 173 | return hidden_states 174 | 175 | 176 | def revise_edlora_unet_attention_forward(unet): 177 | def change_forward(unet, count): 178 | for name, layer in unet.named_children(): 179 | if layer.__class__.__name__ == 'Attention' and 'attn2' in name: 180 | layer.set_processor(EDLoRA_AttnProcessor(count)) 181 | count += 1 182 | else: 183 | count = change_forward(layer, count) 184 | return count 185 | 186 | # use this to ensure the order 187 | cross_attention_idx = change_forward(unet.down_blocks, 0) 188 | cross_attention_idx = change_forward(unet.mid_block, cross_attention_idx) 189 | cross_attention_idx = change_forward(unet.up_blocks, cross_attention_idx) 190 | print(f'Number of attention layer registered {cross_attention_idx}') 191 | 192 | 193 | def revise_edlora_unet_attention_controller_forward(unet, controller): 194 | class DummyController: 195 | def __call__(self, *args): 196 | return args[0] 197 | 198 | def __init__(self): 199 | self.num_att_layers = 0 200 | 201 | if controller is None: 202 | controller = DummyController() 203 | 204 | def change_forward(unet, count, place_in_unet): 205 | for name, layer in unet.named_children(): 206 | if layer.__class__.__name__ == 'Attention' and 'attn2' in name: # only register controller for cross-attention 207 | layer.set_processor(EDLoRA_Control_AttnProcessor(count, place_in_unet, controller)) 208 | count += 1 209 | else: 210 | count = change_forward(layer, count, place_in_unet) 211 | return count 212 | 213 | # use this to ensure the order 214 | cross_attention_idx = change_forward(unet.down_blocks, 0, 'down') 215 | cross_attention_idx = change_forward(unet.mid_block, cross_attention_idx, 'mid') 216 | cross_attention_idx = change_forward(unet.up_blocks, cross_attention_idx, 'up') 217 | print(f'Number of attention layer registered {cross_attention_idx}') 218 | controller.num_att_layers = cross_attention_idx 219 | 220 | 221 | class LoRALinearLayer(nn.Module): 222 | def __init__(self, name, original_module, rank=4, alpha=1): 223 | super().__init__() 224 | 225 | self.name = name 226 | 227 | if original_module.__class__.__name__ == 'Conv2d': 228 | in_channels, out_channels = original_module.in_channels, original_module.out_channels 229 | self.lora_down = torch.nn.Conv2d(in_channels, rank, (1, 1), bias=False) 230 | self.lora_up = torch.nn.Conv2d(rank, out_channels, (1, 1), bias=False) 231 | else: 232 | in_features, out_features = original_module.in_features, original_module.out_features 233 | self.lora_down = nn.Linear(in_features, rank, bias=False) 234 | self.lora_up = nn.Linear(rank, out_features, bias=False) 235 | 236 | self.register_buffer('alpha', torch.tensor(alpha)) 237 | 238 | torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) 239 | torch.nn.init.zeros_(self.lora_up.weight) 240 | 241 | self.original_forward = original_module.forward 242 | original_module.forward = self.forward 243 | 244 | def forward(self, hidden_states): 245 | hidden_states = self.original_forward(hidden_states) + self.alpha * self.lora_up(self.lora_down(hidden_states)) 246 | return hidden_states 247 | -------------------------------------------------------------------------------- /mixofshow/pipelines/pipeline_edlora.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, List, Optional, Union 2 | 3 | import torch 4 | from diffusers import StableDiffusionPipeline 5 | from diffusers.configuration_utils import FrozenDict 6 | from diffusers.models import AutoencoderKL, UNet2DConditionModel 7 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput 8 | from diffusers.schedulers import KarrasDiffusionSchedulers 9 | from diffusers.utils import deprecate 10 | from einops import rearrange 11 | from packaging import version 12 | from transformers import CLIPTextModel, CLIPTokenizer 13 | 14 | from mixofshow.models.edlora import (revise_edlora_unet_attention_controller_forward, 15 | revise_edlora_unet_attention_forward) 16 | 17 | 18 | def bind_concept_prompt(prompts, new_concept_cfg): 19 | if isinstance(prompts, str): 20 | prompts = [prompts] 21 | new_prompts = [] 22 | for prompt in prompts: 23 | prompt = [prompt] * 16 24 | for concept_name, new_token_cfg in new_concept_cfg.items(): 25 | prompt = [ 26 | p.replace(concept_name, new_name) for p, new_name in zip(prompt, new_token_cfg['concept_token_names']) 27 | ] 28 | new_prompts.extend(prompt) 29 | return new_prompts 30 | 31 | 32 | class EDLoRAPipeline(StableDiffusionPipeline): 33 | 34 | def __init__( 35 | self, 36 | vae: AutoencoderKL, 37 | text_encoder: CLIPTextModel, 38 | tokenizer: CLIPTokenizer, 39 | unet: UNet2DConditionModel, 40 | scheduler: KarrasDiffusionSchedulers, 41 | safety_checker=None, 42 | feature_extractor=None, 43 | requires_safety_checker: bool = False, 44 | ): 45 | if hasattr(scheduler.config, 'steps_offset') and scheduler.config.steps_offset != 1: 46 | deprecation_message = ( 47 | f'The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`' 48 | f' should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure ' 49 | 'to update the config accordingly as leaving `steps_offset` might led to incorrect results' 50 | ' in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,' 51 | ' it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`' 52 | ' file' 53 | ) 54 | deprecate('steps_offset!=1', '1.0.0', deprecation_message, standard_warn=False) 55 | new_config = dict(scheduler.config) 56 | new_config['steps_offset'] = 1 57 | scheduler._internal_dict = FrozenDict(new_config) 58 | 59 | if hasattr(scheduler.config, 'clip_sample') and scheduler.config.clip_sample is True: 60 | deprecation_message = ( 61 | f'The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`.' 62 | ' `clip_sample` should be set to False in the configuration file. Please make sure to update the' 63 | ' config accordingly as not setting `clip_sample` in the config might lead to incorrect results in' 64 | ' future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very' 65 | ' nice if you could open a Pull request for the `scheduler/scheduler_config.json` file' 66 | ) 67 | deprecate('clip_sample not set', '1.0.0', deprecation_message, standard_warn=False) 68 | new_config = dict(scheduler.config) 69 | new_config['clip_sample'] = False 70 | scheduler._internal_dict = FrozenDict(new_config) 71 | 72 | is_unet_version_less_0_9_0 = hasattr(unet.config, '_diffusers_version') and version.parse( 73 | version.parse(unet.config._diffusers_version).base_version 74 | ) < version.parse('0.9.0.dev0') 75 | is_unet_sample_size_less_64 = hasattr(unet.config, 'sample_size') and unet.config.sample_size < 64 76 | if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: 77 | deprecation_message = ( 78 | 'The configuration file of the unet has set the default `sample_size` to smaller than' 79 | ' 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the' 80 | ' following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-' 81 | ' CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5' 82 | " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" 83 | ' configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`' 84 | ' in the config might lead to incorrect results in future versions. If you have downloaded this' 85 | ' checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for' 86 | ' the `unet/config.json` file' 87 | ) 88 | deprecate('sample_size<64', '1.0.0', deprecation_message, standard_warn=False) 89 | new_config = dict(unet.config) 90 | new_config['sample_size'] = 64 91 | unet._internal_dict = FrozenDict(new_config) 92 | 93 | revise_edlora_unet_attention_forward(unet) 94 | self.register_modules( 95 | vae=vae, 96 | text_encoder=text_encoder, 97 | tokenizer=tokenizer, 98 | unet=unet, 99 | scheduler=scheduler 100 | ) 101 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 102 | self.new_concept_cfg = None 103 | 104 | def set_new_concept_cfg(self, new_concept_cfg=None): 105 | self.new_concept_cfg = new_concept_cfg 106 | 107 | def set_controller(self, controller): 108 | self.controller = controller 109 | revise_edlora_unet_attention_controller_forward(self.unet, controller) 110 | 111 | def _encode_prompt(self, 112 | prompt, 113 | new_concept_cfg, 114 | device, 115 | num_images_per_prompt, 116 | do_classifier_free_guidance, 117 | negative_prompt=None, 118 | prompt_embeds: Optional[torch.FloatTensor] = None, 119 | negative_prompt_embeds: Optional[torch.FloatTensor] = None 120 | ): 121 | 122 | assert num_images_per_prompt == 1, 'only support num_images_per_prompt=1 now' 123 | 124 | if prompt is not None and isinstance(prompt, str): 125 | batch_size = 1 126 | elif prompt is not None and isinstance(prompt, list): 127 | batch_size = len(prompt) 128 | else: 129 | batch_size = prompt_embeds.shape[0] 130 | 131 | if prompt_embeds is None: 132 | 133 | prompt_extend = bind_concept_prompt(prompt, new_concept_cfg) 134 | 135 | text_inputs = self.tokenizer( 136 | prompt_extend, 137 | padding='max_length', 138 | max_length=self.tokenizer.model_max_length, 139 | truncation=True, 140 | return_tensors='pt', 141 | ) 142 | text_input_ids = text_inputs.input_ids 143 | 144 | prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] 145 | prompt_embeds = rearrange(prompt_embeds, '(b n) m c -> b n m c', b=batch_size) 146 | 147 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 148 | 149 | bs_embed, layer_num, seq_len, _ = prompt_embeds.shape 150 | 151 | # get unconditional embeddings for classifier free guidance 152 | if do_classifier_free_guidance and negative_prompt_embeds is None: 153 | uncond_tokens: List[str] 154 | if negative_prompt is None: 155 | uncond_tokens = [''] * batch_size 156 | elif type(prompt) is not type(negative_prompt): 157 | raise TypeError(f'`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=' 158 | f' {type(prompt)}.') 159 | elif isinstance(negative_prompt, str): 160 | uncond_tokens = [negative_prompt] 161 | elif batch_size != len(negative_prompt): 162 | raise ValueError( 163 | f'`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:' 164 | f' {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches' 165 | ' the batch size of `prompt`.') 166 | else: 167 | uncond_tokens = negative_prompt 168 | 169 | uncond_input = self.tokenizer( 170 | uncond_tokens, 171 | padding='max_length', 172 | max_length=seq_len, 173 | truncation=True, 174 | return_tensors='pt', 175 | ) 176 | 177 | negative_prompt_embeds = self.text_encoder(uncond_input.input_ids.to(device))[0] 178 | 179 | if do_classifier_free_guidance: 180 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 181 | seq_len = negative_prompt_embeds.shape[1] 182 | 183 | negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 184 | negative_prompt_embeds = (negative_prompt_embeds).view(batch_size, 1, seq_len, -1).repeat(1, layer_num, 1, 1) 185 | 186 | # For classifier free guidance, we need to do two forward passes. 187 | # Here we concatenate the unconditional and text embeddings into a single batch 188 | # to avoid doing two forward passes 189 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) 190 | return prompt_embeds 191 | 192 | @torch.no_grad() 193 | def __call__( 194 | self, 195 | prompt: Union[str, List[str]] = None, 196 | height: Optional[int] = None, 197 | width: Optional[int] = None, 198 | num_inference_steps: int = 50, 199 | guidance_scale: float = 7.5, 200 | negative_prompt: Optional[Union[str, List[str]]] = None, 201 | num_images_per_prompt: Optional[int] = 1, 202 | eta: float = 0.0, 203 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 204 | latents: Optional[torch.FloatTensor] = None, 205 | prompt_embeds: Optional[torch.FloatTensor] = None, 206 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 207 | output_type: Optional[str] = 'pil', 208 | return_dict: bool = True, 209 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 210 | callback_steps: int = 1, 211 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 212 | ): 213 | 214 | # 0. Default height and width to unet 215 | height = height or self.unet.config.sample_size * self.vae_scale_factor 216 | width = width or self.unet.config.sample_size * self.vae_scale_factor 217 | 218 | # 1. Check inputs. Raise error if not correct 219 | self.check_inputs(prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) 220 | 221 | # 2. Define call parameters 222 | if prompt is not None and isinstance(prompt, str): 223 | batch_size = 1 224 | elif prompt is not None and isinstance(prompt, list): 225 | batch_size = len(prompt) 226 | else: 227 | batch_size = prompt_embeds.shape[0] 228 | 229 | device = self._execution_device 230 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 231 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 232 | # corresponds to doing no classifier free guidance. 233 | do_classifier_free_guidance = guidance_scale > 1.0 234 | 235 | # 3. Encode input prompt, this support pplus and edlora (layer-wise embedding) 236 | assert self.new_concept_cfg is not None 237 | prompt_embeds = self._encode_prompt( 238 | prompt, 239 | self.new_concept_cfg, 240 | device, 241 | num_images_per_prompt, 242 | do_classifier_free_guidance, 243 | negative_prompt, 244 | prompt_embeds=prompt_embeds, 245 | negative_prompt_embeds=negative_prompt_embeds, 246 | ) 247 | 248 | # 4. Prepare timesteps 249 | self.scheduler.set_timesteps(num_inference_steps, device=device) 250 | timesteps = self.scheduler.timesteps 251 | 252 | # 5. Prepare latent variables 253 | num_channels_latents = self.unet.in_channels 254 | latents = self.prepare_latents( 255 | batch_size * num_images_per_prompt, 256 | num_channels_latents, 257 | height, 258 | width, 259 | prompt_embeds.dtype, 260 | device, 261 | generator, 262 | latents, 263 | ) 264 | 265 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 266 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 267 | 268 | # 7. Denoising loop 269 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 270 | with self.progress_bar(total=num_inference_steps) as progress_bar: 271 | for i, t in enumerate(timesteps): 272 | # expand the latents if we are doing classifier free guidance 273 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 274 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 275 | 276 | # predict the noise residual 277 | noise_pred = self.unet( 278 | latent_model_input, 279 | t, 280 | encoder_hidden_states=prompt_embeds, 281 | cross_attention_kwargs=cross_attention_kwargs, 282 | ).sample 283 | 284 | # perform guidance 285 | if do_classifier_free_guidance: 286 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 287 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 288 | 289 | # compute the previous noisy sample x_t -> x_t-1 290 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 291 | 292 | if hasattr(self, 'controller'): 293 | dtype = latents.dtype 294 | latents = self.controller.step_callback(latents) 295 | latents = latents.to(dtype) 296 | 297 | # call the callback, if provided 298 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 299 | progress_bar.update() 300 | if callback is not None and i % callback_steps == 0: 301 | callback(i, t, latents) 302 | 303 | if output_type == 'latent': 304 | image = latents 305 | elif output_type == 'pil': 306 | # 8. Post-processing 307 | image = self.decode_latents(latents) 308 | 309 | # 10. Convert to PIL 310 | image = self.numpy_to_pil(image) 311 | else: 312 | # 8. Post-processing 313 | image = self.decode_latents(latents) 314 | 315 | # Offload last model to CPU 316 | if hasattr(self, 'final_offload_hook') and self.final_offload_hook is not None: 317 | self.final_offload_hook.offload() 318 | 319 | if not return_dict: 320 | return (image) 321 | 322 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None) 323 | -------------------------------------------------------------------------------- /mixofshow/pipelines/trainer_edlora.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import math 3 | import re 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from accelerate.logging import get_logger 9 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel 10 | from diffusers.utils.import_utils import is_xformers_available 11 | from einops import rearrange 12 | from transformers import CLIPTextModel, CLIPTokenizer 13 | 14 | from mixofshow.models.edlora import (LoRALinearLayer, revise_edlora_unet_attention_controller_forward, 15 | revise_edlora_unet_attention_forward) 16 | from mixofshow.pipelines.pipeline_edlora import bind_concept_prompt 17 | from mixofshow.utils.ptp_util import AttentionStore 18 | 19 | 20 | class EDLoRATrainer(nn.Module): 21 | def __init__( 22 | self, 23 | pretrained_path, 24 | new_concept_token, 25 | initializer_token, 26 | enable_edlora, # true for ED-LoRA, false for LoRA 27 | finetune_cfg=None, 28 | noise_offset=None, 29 | attn_reg_weight=None, 30 | reg_full_identity=True, # True for thanos, False for real person (don't need to encode clothes) 31 | use_mask_loss=True, 32 | enable_xformers=False, 33 | gradient_checkpoint=False 34 | ): 35 | super().__init__() 36 | 37 | # 1. Load the model. 38 | self.vae = AutoencoderKL.from_pretrained(pretrained_path, subfolder='vae') 39 | self.tokenizer = CLIPTokenizer.from_pretrained(pretrained_path, subfolder='tokenizer') 40 | self.text_encoder = CLIPTextModel.from_pretrained(pretrained_path, subfolder='text_encoder') 41 | self.unet = UNet2DConditionModel.from_pretrained(pretrained_path, subfolder='unet') 42 | 43 | if gradient_checkpoint: 44 | self.unet.enable_gradient_checkpointing() 45 | 46 | if enable_xformers: 47 | assert is_xformers_available(), 'need to install xformer first' 48 | 49 | # 2. Define train scheduler 50 | self.scheduler = DDPMScheduler.from_pretrained(pretrained_path, subfolder='scheduler') 51 | 52 | # 3. define training cfg 53 | self.enable_edlora = enable_edlora 54 | self.new_concept_cfg = self.init_new_concept(new_concept_token, initializer_token, enable_edlora=enable_edlora) 55 | 56 | self.attn_reg_weight = attn_reg_weight 57 | self.reg_full_identity = reg_full_identity 58 | if self.attn_reg_weight is not None: 59 | self.controller = AttentionStore(training=True) 60 | revise_edlora_unet_attention_controller_forward(self.unet, self.controller) # support both lora and edlora forward 61 | else: 62 | revise_edlora_unet_attention_forward(self.unet) # support both lora and edlora forward 63 | 64 | if finetune_cfg: 65 | self.set_finetune_cfg(finetune_cfg) 66 | 67 | self.noise_offset = noise_offset 68 | self.use_mask_loss = use_mask_loss 69 | 70 | def set_finetune_cfg(self, finetune_cfg): 71 | logger = get_logger('mixofshow', log_level='INFO') 72 | params_to_freeze = [self.vae.parameters(), self.text_encoder.parameters(), self.unet.parameters()] 73 | 74 | # step 1: close all parameters, required_grad to False 75 | for params in itertools.chain(*params_to_freeze): 76 | params.requires_grad = False 77 | 78 | # step 2: begin to add trainable paramters 79 | params_group_list = [] 80 | 81 | # 1. text embedding 82 | if finetune_cfg['text_embedding']['enable_tuning']: 83 | text_embedding_cfg = finetune_cfg['text_embedding'] 84 | 85 | params_list = [] 86 | for params in self.text_encoder.get_input_embeddings().parameters(): 87 | params.requires_grad = True 88 | params_list.append(params) 89 | 90 | params_group = {'params': params_list, 'lr': text_embedding_cfg['lr']} 91 | if 'weight_decay' in text_embedding_cfg: 92 | params_group.update({'weight_decay': text_embedding_cfg['weight_decay']}) 93 | params_group_list.append(params_group) 94 | logger.info(f"optimizing embedding using lr: {text_embedding_cfg['lr']}") 95 | 96 | # 2. text encoder 97 | if finetune_cfg['text_encoder']['enable_tuning'] and finetune_cfg['text_encoder'].get('lora_cfg'): 98 | text_encoder_cfg = finetune_cfg['text_encoder'] 99 | 100 | where = text_encoder_cfg['lora_cfg'].pop('where') 101 | assert where in ['CLIPEncoderLayer', 'CLIPAttention'] 102 | 103 | self.text_encoder_lora = nn.ModuleList() 104 | params_list = [] 105 | 106 | for name, module in self.text_encoder.named_modules(): 107 | if module.__class__.__name__ == where: 108 | for child_name, child_module in module.named_modules(): 109 | if child_module.__class__.__name__ == 'Linear': 110 | lora_module = LoRALinearLayer(name + '.' + child_name, child_module, **text_encoder_cfg['lora_cfg']) 111 | self.text_encoder_lora.append(lora_module) 112 | params_list.extend(list(lora_module.parameters())) 113 | 114 | params_group_list.append({'params': params_list, 'lr': text_encoder_cfg['lr']}) 115 | logger.info(f"optimizing text_encoder ({len(self.text_encoder_lora)} LoRAs), using lr: {text_encoder_cfg['lr']}") 116 | 117 | # 3. unet 118 | if finetune_cfg['unet']['enable_tuning'] and finetune_cfg['unet'].get('lora_cfg'): 119 | unet_cfg = finetune_cfg['unet'] 120 | 121 | where = unet_cfg['lora_cfg'].pop('where') 122 | assert where in ['Transformer2DModel', 'Attention'] 123 | 124 | self.unet_lora = nn.ModuleList() 125 | params_list = [] 126 | 127 | for name, module in self.unet.named_modules(): 128 | if module.__class__.__name__ == where: 129 | for child_name, child_module in module.named_modules(): 130 | if child_module.__class__.__name__ == 'Linear' or (child_module.__class__.__name__ == 'Conv2d' and child_module.kernel_size == (1, 1)): 131 | lora_module = LoRALinearLayer(name + '.' + child_name, child_module, **unet_cfg['lora_cfg']) 132 | self.unet_lora.append(lora_module) 133 | params_list.extend(list(lora_module.parameters())) 134 | 135 | params_group_list.append({'params': params_list, 'lr': unet_cfg['lr']}) 136 | logger.info(f"optimizing unet ({len(self.unet_lora)} LoRAs), using lr: {unet_cfg['lr']}") 137 | 138 | # 4. optimize params 139 | self.params_to_optimize_iterator = params_group_list 140 | 141 | def get_params_to_optimize(self): 142 | return self.params_to_optimize_iterator 143 | 144 | def init_new_concept(self, new_concept_tokens, initializer_tokens, enable_edlora=True): 145 | logger = get_logger('mixofshow', log_level='INFO') 146 | new_concept_cfg = {} 147 | new_concept_tokens = new_concept_tokens.split('+') 148 | 149 | if initializer_tokens is None: 150 | initializer_tokens = [''] * len(new_concept_tokens) 151 | else: 152 | initializer_tokens = initializer_tokens.split('+') 153 | assert len(new_concept_tokens) == len(initializer_tokens), 'concept token should match init token.' 154 | 155 | for idx, (concept_name, init_token) in enumerate(zip(new_concept_tokens, initializer_tokens)): 156 | if enable_edlora: 157 | num_new_embedding = 16 158 | else: 159 | num_new_embedding = 1 160 | new_token_names = [f'' for layer_id in range(num_new_embedding)] 161 | 162 | num_added_tokens = self.tokenizer.add_tokens(new_token_names) 163 | assert num_added_tokens == len(new_token_names), 'some token is already in tokenizer' 164 | new_token_ids = [self.tokenizer.convert_tokens_to_ids(token_name) for token_name in new_token_names] 165 | 166 | # init embedding 167 | self.text_encoder.resize_token_embeddings(len(self.tokenizer)) 168 | token_embeds = self.text_encoder.get_input_embeddings().weight.data 169 | 170 | if init_token.startswith('', init_token)[0]) 172 | init_feature = torch.randn_like(token_embeds[0]) * sigma_val 173 | logger.info(f'{concept_name} ({min(new_token_ids)}-{max(new_token_ids)}) is random initialized by: {init_token}') 174 | else: 175 | # Convert the initializer_token, placeholder_token to ids 176 | init_token_ids = self.tokenizer.encode(init_token, add_special_tokens=False) 177 | # print(token_ids) 178 | # Check if initializer_token is a single token or a sequence of tokens 179 | if len(init_token_ids) > 1 or init_token_ids[0] == 40497: 180 | raise ValueError('The initializer token must be a single existing token.') 181 | init_feature = token_embeds[init_token_ids] 182 | logger.info(f'{concept_name} ({min(new_token_ids)}-{max(new_token_ids)}) is random initialized by existing token ({init_token}): {init_token_ids[0]}') 183 | 184 | for token_id in new_token_ids: 185 | token_embeds[token_id] = init_feature.clone() 186 | 187 | new_concept_cfg.update({ 188 | concept_name: { 189 | 'concept_token_ids': new_token_ids, 190 | 'concept_token_names': new_token_names 191 | } 192 | }) 193 | 194 | return new_concept_cfg 195 | 196 | def get_all_concept_token_ids(self): 197 | new_concept_token_ids = [] 198 | for _, new_token_cfg in self.new_concept_cfg.items(): 199 | new_concept_token_ids.extend(new_token_cfg['concept_token_ids']) 200 | return new_concept_token_ids 201 | 202 | def forward(self, images, prompts, masks, img_masks): 203 | latents = self.vae.encode(images).latent_dist.sample() 204 | latents = latents * 0.18215 205 | 206 | # Sample noise that we'll add to the latents 207 | noise = torch.randn_like(latents) 208 | if self.noise_offset is not None: 209 | noise += self.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) 210 | 211 | bsz = latents.shape[0] 212 | # Sample a random timestep for each image 213 | timesteps = torch.randint(0, self.scheduler.config.num_train_timesteps, (bsz, ), device=latents.device) 214 | timesteps = timesteps.long() 215 | 216 | # Add noise to the latents according to the noise magnitude at each timestep 217 | # (this is the forward diffusion process) 218 | noisy_latents = self.scheduler.add_noise(latents, noise, timesteps) 219 | 220 | if self.enable_edlora: 221 | prompts = bind_concept_prompt(prompts, new_concept_cfg=self.new_concept_cfg) # edlora 222 | 223 | # get text ids 224 | text_input_ids = self.tokenizer( 225 | prompts, 226 | padding='max_length', 227 | max_length=self.tokenizer.model_max_length, 228 | truncation=True, 229 | return_tensors='pt').input_ids.to(latents.device) 230 | 231 | # Get the text embedding for conditioning 232 | encoder_hidden_states = self.text_encoder(text_input_ids)[0] 233 | if self.enable_edlora: 234 | encoder_hidden_states = rearrange(encoder_hidden_states, '(b n) m c -> b n m c', b=latents.shape[0]) # edlora 235 | 236 | # Predict the noise residual 237 | model_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample 238 | 239 | # Get the target for loss depending on the prediction type 240 | if self.scheduler.config.prediction_type == 'epsilon': 241 | target = noise 242 | elif self.scheduler.config.prediction_type == 'v_prediction': 243 | target = self.scheduler.get_velocity(latents, noise, timesteps) 244 | else: 245 | raise ValueError(f'Unknown prediction type {self.scheduler.config.prediction_type}') 246 | 247 | if self.use_mask_loss: 248 | loss_mask = masks 249 | else: 250 | loss_mask = img_masks 251 | loss = F.mse_loss(model_pred.float(), target.float(), reduction='none') 252 | loss = ((loss * loss_mask).sum([1, 2, 3]) / loss_mask.sum([1, 2, 3])).mean() 253 | 254 | if self.attn_reg_weight is not None: 255 | attention_maps = self.controller.get_average_attention() 256 | attention_loss = self.cal_attn_reg(attention_maps, masks, text_input_ids) 257 | if not torch.isnan(attention_loss): # full mask 258 | loss = loss + attention_loss 259 | self.controller.reset() 260 | 261 | return loss 262 | 263 | def cal_attn_reg(self, attention_maps, masks, text_input_ids): 264 | ''' 265 | attention_maps: {down_cross:[], mid_cross:[], up_cross:[]} 266 | masks: torch.Size([1, 1, 64, 64]) 267 | text_input_ids: torch.Size([16, 77]) 268 | ''' 269 | # step 1: find token position 270 | batch_size = masks.shape[0] 271 | text_input_ids = rearrange(text_input_ids, '(b l) n -> b l n', b=batch_size) 272 | # print(masks.shape) # torch.Size([2, 1, 64, 64]) 273 | # print(text_input_ids.shape) # torch.Size([2, 16, 77]) 274 | 275 | new_token_pos = [] 276 | all_concept_token_ids = self.get_all_concept_token_ids() 277 | for text in text_input_ids: 278 | text = text[0] # even multi-layer embedding, we extract the first one 279 | new_token_pos.append([idx for idx in range(len(text)) if text[idx] in all_concept_token_ids]) 280 | 281 | # step2: aggregate attention maps with resolution and concat heads 282 | attention_groups = {'64': [], '32': [], '16': [], '8': []} 283 | for _, attention_list in attention_maps.items(): 284 | for attn in attention_list: 285 | res = int(math.sqrt(attn.shape[1])) 286 | cross_map = attn.reshape(batch_size, -1, res, res, attn.shape[-1]) 287 | attention_groups[str(res)].append(cross_map) 288 | 289 | for k, cross_map in attention_groups.items(): 290 | cross_map = torch.cat(cross_map, dim=-4) # concat heads 291 | cross_map = cross_map.sum(-4) / cross_map.shape[-4] # e.g., 64 torch.Size([2, 64, 64, 77]) 292 | cross_map = torch.stack([batch_map[..., batch_pos] for batch_pos, batch_map in zip(new_token_pos, cross_map)]) # torch.Size([2, 64, 64, 2]) 293 | attention_groups[k] = cross_map 294 | 295 | attn_reg_total = 0 296 | # step3: calculate loss for each resolution: -> is to penalize outside mask, to align with mask 297 | for k, cross_map in attention_groups.items(): 298 | map_adjective, map_subject = cross_map[..., 0], cross_map[..., 1] 299 | 300 | map_subject = map_subject / map_subject.max() 301 | map_adjective = map_adjective / map_adjective.max() 302 | 303 | gt_mask = F.interpolate(masks, size=map_subject.shape[1:], mode='nearest').squeeze(1) 304 | 305 | if self.reg_full_identity: 306 | loss_subject = F.mse_loss(map_subject.float(), gt_mask.float(), reduction='mean') 307 | else: 308 | loss_subject = map_subject[gt_mask == 0].mean() 309 | 310 | loss_adjective = map_adjective[gt_mask == 0].mean() 311 | 312 | attn_reg_total += self.attn_reg_weight * (loss_subject + loss_adjective) 313 | return attn_reg_total 314 | 315 | def load_delta_state_dict(self, delta_state_dict): 316 | # load embedding 317 | logger = get_logger('mixofshow', log_level='INFO') 318 | 319 | if 'new_concept_embedding' in delta_state_dict and len(delta_state_dict['new_concept_embedding']) != 0: 320 | new_concept_tokens = list(delta_state_dict['new_concept_embedding'].keys()) 321 | 322 | # check whether new concept is initialized 323 | token_embeds = self.text_encoder.get_input_embeddings().weight.data 324 | if set(new_concept_tokens) != set(self.new_concept_cfg.keys()): 325 | logger.warning('Your checkpoint have different concept with your model, loading existing concepts') 326 | 327 | for concept_name, concept_cfg in self.new_concept_cfg.items(): 328 | logger.info(f'load: concept_{concept_name}') 329 | token_embeds[concept_cfg['concept_token_ids']] = token_embeds[ 330 | concept_cfg['concept_token_ids']].copy_(delta_state_dict['new_concept_embedding'][concept_name]) 331 | 332 | # load text_encoder 333 | if 'text_encoder' in delta_state_dict and len(delta_state_dict['text_encoder']) != 0: 334 | load_keys = delta_state_dict['text_encoder'].keys() 335 | if hasattr(self, 'text_encoder_lora') and len(load_keys) == 2 * len(self.text_encoder_lora): 336 | logger.info('loading LoRA for text encoder:') 337 | for lora_module in self.text_encoder_lora: 338 | for name, param, in lora_module.named_parameters(): 339 | logger.info(f'load: {lora_module.name}.{name}') 340 | param.data.copy_(delta_state_dict['text_encoder'][f'{lora_module.name}.{name}']) 341 | else: 342 | for name, param, in self.text_encoder.named_parameters(): 343 | if name in load_keys and 'token_embedding' not in name: 344 | logger.info(f'load: {name}') 345 | param.data.copy_(delta_state_dict['text_encoder'][f'{name}']) 346 | 347 | # load unet 348 | if 'unet' in delta_state_dict and len(delta_state_dict['unet']) != 0: 349 | load_keys = delta_state_dict['unet'].keys() 350 | if hasattr(self, 'unet_lora') and len(load_keys) == 2 * len(self.unet_lora): 351 | logger.info('loading LoRA for unet:') 352 | for lora_module in self.unet_lora: 353 | for name, param, in lora_module.named_parameters(): 354 | logger.info(f'load: {lora_module.name}.{name}') 355 | param.data.copy_(delta_state_dict['unet'][f'{lora_module.name}.{name}']) 356 | else: 357 | for name, param, in self.unet.named_parameters(): 358 | if name in load_keys: 359 | logger.info(f'load: {name}') 360 | param.data.copy_(delta_state_dict['unet'][f'{name}']) 361 | 362 | def delta_state_dict(self): 363 | delta_dict = {'new_concept_embedding': {}, 'text_encoder': {}, 'unet': {}} 364 | 365 | # save_embedding 366 | for concept_name, concept_cfg in self.new_concept_cfg.items(): 367 | learned_embeds = self.text_encoder.get_input_embeddings().weight[concept_cfg['concept_token_ids']] 368 | delta_dict['new_concept_embedding'][concept_name] = learned_embeds.detach().cpu() 369 | 370 | # save text model 371 | for lora_module in self.text_encoder_lora: 372 | for name, param, in lora_module.named_parameters(): 373 | delta_dict['text_encoder'][f'{lora_module.name}.{name}'] = param.cpu().clone() 374 | 375 | # save unet model 376 | for lora_module in self.unet_lora: 377 | for name, param, in lora_module.named_parameters(): 378 | delta_dict['unet'][f'{lora_module.name}.{name}'] = param.cpu().clone() 379 | 380 | return delta_dict -------------------------------------------------------------------------------- /mixofshow/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/Mix-of-Show/51270fd5f9907cada8f416a4eff191a74f842660/mixofshow/utils/__init__.py -------------------------------------------------------------------------------- /mixofshow/utils/arial.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/Mix-of-Show/51270fd5f9907cada8f416a4eff191a74f842660/mixofshow/utils/arial.ttf -------------------------------------------------------------------------------- /mixofshow/utils/convert_edlora_to_diffusers.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | 4 | def load_new_concept(pipe, new_concept_embedding, enable_edlora=True): 5 | new_concept_cfg = {} 6 | 7 | for idx, (concept_name, concept_embedding) in enumerate(new_concept_embedding.items()): 8 | if enable_edlora: 9 | num_new_embedding = 16 10 | else: 11 | num_new_embedding = 1 12 | new_token_names = [f'' for layer_id in range(num_new_embedding)] 13 | num_added_tokens = pipe.tokenizer.add_tokens(new_token_names) 14 | assert num_added_tokens == len(new_token_names), 'some token is already in tokenizer' 15 | new_token_ids = [pipe.tokenizer.convert_tokens_to_ids(token_name) for token_name in new_token_names] 16 | 17 | # init embedding 18 | pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer)) 19 | token_embeds = pipe.text_encoder.get_input_embeddings().weight.data 20 | token_embeds[new_token_ids] = concept_embedding.clone().to(token_embeds.device, dtype=token_embeds.dtype) 21 | print(f'load embedding: {concept_name}') 22 | 23 | new_concept_cfg.update({ 24 | concept_name: { 25 | 'concept_token_ids': new_token_ids, 26 | 'concept_token_names': new_token_names 27 | } 28 | }) 29 | 30 | return pipe, new_concept_cfg 31 | 32 | 33 | def merge_lora_into_weight(original_state_dict, lora_state_dict, model_type, alpha): 34 | def get_lora_down_name(original_layer_name): 35 | if model_type == 'text_encoder': 36 | lora_down_name = original_layer_name.replace('q_proj.weight', 'q_proj.lora_down.weight') \ 37 | .replace('k_proj.weight', 'k_proj.lora_down.weight') \ 38 | .replace('v_proj.weight', 'v_proj.lora_down.weight') \ 39 | .replace('out_proj.weight', 'out_proj.lora_down.weight') \ 40 | .replace('fc1.weight', 'fc1.lora_down.weight') \ 41 | .replace('fc2.weight', 'fc2.lora_down.weight') 42 | else: 43 | lora_down_name = k.replace('to_q.weight', 'to_q.lora_down.weight') \ 44 | .replace('to_k.weight', 'to_k.lora_down.weight') \ 45 | .replace('to_v.weight', 'to_v.lora_down.weight') \ 46 | .replace('to_out.0.weight', 'to_out.0.lora_down.weight') \ 47 | .replace('ff.net.0.proj.weight', 'ff.net.0.proj.lora_down.weight') \ 48 | .replace('ff.net.2.weight', 'ff.net.2.lora_down.weight') \ 49 | .replace('proj_out.weight', 'proj_out.lora_down.weight') \ 50 | .replace('proj_in.weight', 'proj_in.lora_down.weight') 51 | 52 | return lora_down_name 53 | 54 | assert model_type in ['unet', 'text_encoder'] 55 | new_state_dict = copy.deepcopy(original_state_dict) 56 | 57 | load_cnt = 0 58 | for k in new_state_dict.keys(): 59 | lora_down_name = get_lora_down_name(k) 60 | lora_up_name = lora_down_name.replace('lora_down', 'lora_up') 61 | 62 | if lora_up_name in lora_state_dict: 63 | load_cnt += 1 64 | original_params = new_state_dict[k] 65 | lora_down_params = lora_state_dict[lora_down_name].to(original_params.device) 66 | lora_up_params = lora_state_dict[lora_up_name].to(original_params.device) 67 | if len(original_params.shape) == 4: 68 | lora_param = lora_up_params.squeeze() @ lora_down_params.squeeze() 69 | lora_param = lora_param.unsqueeze(-1).unsqueeze(-1) 70 | else: 71 | lora_param = lora_up_params @ lora_down_params 72 | merge_params = original_params + alpha * lora_param 73 | new_state_dict[k] = merge_params 74 | 75 | print(f'load {load_cnt} LoRAs of {model_type}') 76 | return new_state_dict 77 | 78 | 79 | def convert_edlora(pipe, state_dict, enable_edlora, alpha=0.6): 80 | 81 | state_dict = state_dict['params'] if 'params' in state_dict.keys() else state_dict 82 | 83 | # step 1: load embedding 84 | if 'new_concept_embedding' in state_dict and len(state_dict['new_concept_embedding']) != 0: 85 | pipe, new_concept_cfg = load_new_concept(pipe, state_dict['new_concept_embedding'], enable_edlora) 86 | 87 | # step 2: merge lora weight to unet 88 | unet_lora_state_dict = state_dict['unet'] 89 | pretrained_unet_state_dict = pipe.unet.state_dict() 90 | updated_unet_state_dict = merge_lora_into_weight(pretrained_unet_state_dict, unet_lora_state_dict, model_type='unet', alpha=alpha) 91 | pipe.unet.load_state_dict(updated_unet_state_dict) 92 | 93 | # step 3: merge lora weight to text_encoder 94 | text_encoder_lora_state_dict = state_dict['text_encoder'] 95 | pretrained_text_encoder_state_dict = pipe.text_encoder.state_dict() 96 | updated_text_encoder_state_dict = merge_lora_into_weight(pretrained_text_encoder_state_dict, text_encoder_lora_state_dict, model_type='text_encoder', alpha=alpha) 97 | pipe.text_encoder.load_state_dict(updated_text_encoder_state_dict) 98 | 99 | return pipe, new_concept_cfg 100 | -------------------------------------------------------------------------------- /mixofshow/utils/ptp_util.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import List, Tuple 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | from IPython.display import display 8 | from PIL import Image 9 | 10 | 11 | class EmptyControl: 12 | def step_callback(self, x_t): 13 | return x_t 14 | 15 | def between_steps(self): 16 | return 17 | 18 | def __call__(self, attn, is_cross: bool, place_in_unet: str): 19 | return attn 20 | 21 | 22 | class AttentionControl(abc.ABC): 23 | def step_callback(self, x_t): 24 | return x_t 25 | 26 | def between_steps(self): 27 | return 28 | 29 | @property 30 | def num_uncond_att_layers(self): 31 | return self.num_att_layers if self.low_resource else 0 32 | 33 | @abc.abstractmethod 34 | def forward(self, attn, is_cross: bool, place_in_unet: str): 35 | raise NotImplementedError 36 | 37 | def __call__(self, attn, is_cross: bool, place_in_unet: str): 38 | if self.cur_att_layer >= self.num_uncond_att_layers: 39 | if self.low_resource: 40 | attn = self.forward(attn, is_cross, place_in_unet) 41 | else: 42 | if self.training: 43 | attn = self.forward(attn, is_cross, place_in_unet) 44 | else: 45 | h = attn.shape[0] 46 | attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet) 47 | 48 | self.cur_att_layer += 1 49 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: 50 | self.cur_att_layer = 0 51 | self.cur_step += 1 52 | self.between_steps() 53 | return attn 54 | 55 | def reset(self): 56 | self.cur_step = 0 57 | self.cur_att_layer = 0 58 | 59 | def __init__(self, low_resource, training): 60 | self.cur_step = 0 61 | self.num_att_layers = -1 62 | self.cur_att_layer = 0 63 | self.low_resource = low_resource 64 | self.training = training 65 | 66 | 67 | class AttentionStore(AttentionControl): 68 | @staticmethod 69 | def get_empty_store(): 70 | return { 71 | 'down_cross': [], 72 | 'mid_cross': [], 73 | 'up_cross': [], 74 | 'down_self': [], 75 | 'mid_self': [], 76 | 'up_self': [] 77 | } 78 | 79 | def forward(self, attn, is_cross: bool, place_in_unet: str): 80 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" 81 | self.step_store[key].append(attn) 82 | return attn 83 | 84 | def between_steps(self): 85 | if len(self.attention_store) == 0: 86 | self.attention_store = self.step_store 87 | else: 88 | for key in self.attention_store: 89 | for i in range(len(self.attention_store[key])): 90 | self.attention_store[key][i] = self.attention_store[key][i] + self.step_store[key][i] 91 | self.step_store = self.get_empty_store() 92 | 93 | def get_average_attention(self): 94 | average_attention = { 95 | key: [item / self.cur_step for item in self.attention_store[key]] 96 | for key in self.attention_store 97 | } 98 | return average_attention 99 | 100 | def reset(self): 101 | super(AttentionStore, self).reset() 102 | self.step_store = self.get_empty_store() 103 | self.attention_store = {} 104 | 105 | def __init__(self, low_resource=False, training=False): 106 | super(AttentionStore, self).__init__(low_resource, training) 107 | self.step_store = self.get_empty_store() 108 | self.attention_store = {} 109 | 110 | 111 | def text_under_image(image: np.ndarray, 112 | text: str, 113 | text_color: Tuple[int, int, int] = (0, 0, 0)): 114 | h, w, c = image.shape 115 | offset = int(h * .2) 116 | img = np.ones((h + offset, w, c), dtype=np.uint8) * 255 117 | font = cv2.FONT_HERSHEY_SIMPLEX 118 | # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size) 119 | img[:h] = image 120 | textsize = cv2.getTextSize(text, font, 1, 2)[0] 121 | text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2 122 | cv2.putText(img, text, (text_x, text_y), font, 1, text_color, 2) 123 | return img 124 | 125 | 126 | def view_images(images, num_rows=1, offset_ratio=0.02, notebook=True): 127 | if type(images) is list: 128 | num_empty = len(images) % num_rows 129 | elif images.ndim == 4: 130 | num_empty = images.shape[0] % num_rows 131 | else: 132 | images = [images] 133 | num_empty = 0 134 | 135 | empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255 136 | images = [image.astype(np.uint8) 137 | for image in images] + [empty_images] * num_empty 138 | num_items = len(images) 139 | 140 | h, w, c = images[0].shape 141 | offset = int(h * offset_ratio) 142 | num_cols = num_items // num_rows 143 | image_ = np.ones( 144 | (h * num_rows + offset * (num_rows - 1), w * num_cols + offset * 145 | (num_cols - 1), 3), 146 | dtype=np.uint8) * 255 147 | for i in range(num_rows): 148 | for j in range(num_cols): 149 | image_[i * (h + offset):i * (h + offset) + h:, j * (w + offset):j * 150 | (w + offset) + w] = images[i * num_cols + j] 151 | 152 | pil_img = Image.fromarray(image_) 153 | if notebook is True: 154 | display(pil_img) 155 | else: 156 | return pil_img 157 | 158 | 159 | def aggregate_attention(attention_store: AttentionStore, res: int, 160 | from_where: List[str], prompts: List[str], 161 | is_cross: bool, select: int): 162 | out = [] 163 | attention_maps = attention_store.get_average_attention() 164 | num_pixels = res**2 165 | for location in from_where: 166 | for item in attention_maps[ 167 | f"{location}_{'cross' if is_cross else 'self'}"]: 168 | if item.shape[1] == num_pixels: 169 | cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select] 170 | out.append(cross_maps) 171 | out = torch.cat(out, dim=0) 172 | out = out.sum(0) / out.shape[0] 173 | return out.cpu() 174 | 175 | 176 | def show_cross_attention(attention_store: AttentionStore, 177 | res: int, 178 | from_where: List[str], 179 | prompts: List[str], 180 | tokenizer, 181 | select: int = 0, 182 | notebook=True): 183 | tokens = tokenizer.encode(prompts[select]) 184 | decoder = tokenizer.decode 185 | attention_maps = aggregate_attention(attention_store, res, from_where, prompts, True, select) 186 | 187 | images = [] 188 | for i in range(len(tokens)): 189 | image = attention_maps[:, :, i] 190 | image = 255 * image / image.max() 191 | image = image.unsqueeze(-1).expand(*image.shape, 3) 192 | image = image.numpy().astype(np.uint8) 193 | image = np.array(Image.fromarray(image).resize((256, 256))) 194 | image = text_under_image(image, decoder(int(tokens[i]))) 195 | images.append(image) 196 | 197 | if notebook is True: 198 | view_images(np.stack(images, axis=0)) 199 | else: 200 | return view_images(np.stack(images, axis=0), notebook=False) 201 | -------------------------------------------------------------------------------- /mixofshow/utils/registry.py: -------------------------------------------------------------------------------- 1 | # Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501 2 | 3 | 4 | class Registry(): 5 | """ 6 | The registry that provides name -> object mapping, to support third-party 7 | users' custom modules. 8 | 9 | To create a registry (e.g. a backbone registry): 10 | 11 | .. code-block:: python 12 | 13 | BACKBONE_REGISTRY = Registry('BACKBONE') 14 | 15 | To register an object: 16 | 17 | .. code-block:: python 18 | 19 | @BACKBONE_REGISTRY.register() 20 | class MyBackbone(): 21 | ... 22 | 23 | Or: 24 | 25 | .. code-block:: python 26 | 27 | BACKBONE_REGISTRY.register(MyBackbone) 28 | """ 29 | def __init__(self, name): 30 | """ 31 | Args: 32 | name (str): the name of this registry 33 | """ 34 | self._name = name 35 | self._obj_map = {} 36 | 37 | def _do_register(self, name, obj): 38 | assert (name not in self._obj_map), ( 39 | f"An object named '{name}' was already registered " 40 | f"in '{self._name}' registry!") 41 | self._obj_map[name] = obj 42 | 43 | def register(self, obj=None): 44 | """ 45 | Register the given object under the the name `obj.__name__`. 46 | Can be used as either a decorator or not. 47 | See docstring of this class for usage. 48 | """ 49 | if obj is None: 50 | # used as a decorator 51 | def deco(func_or_class): 52 | name = func_or_class.__name__ 53 | self._do_register(name, func_or_class) 54 | return func_or_class 55 | 56 | return deco 57 | 58 | # used as a function call 59 | name = obj.__name__ 60 | self._do_register(name, obj) 61 | 62 | def get(self, name): 63 | ret = self._obj_map.get(name) 64 | if ret is None: 65 | raise KeyError( 66 | f"No object named '{name}' found in '{self._name}' registry!") 67 | return ret 68 | 69 | def __contains__(self, name): 70 | return name in self._obj_map 71 | 72 | def __iter__(self): 73 | return iter(self._obj_map.items()) 74 | 75 | def keys(self): 76 | return self._obj_map.keys() 77 | 78 | 79 | TRANSFORM_REGISTRY = Registry('transform') 80 | -------------------------------------------------------------------------------- /mixofshow/utils/util.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import os 4 | import os.path 5 | import os.path as osp 6 | import time 7 | from collections import OrderedDict 8 | 9 | import PIL 10 | import torch 11 | from accelerate.logging import get_logger 12 | from accelerate.state import PartialState 13 | from PIL import Image, ImageDraw, ImageFont 14 | from torchvision.transforms.transforms import ToTensor 15 | from torchvision.utils import make_grid 16 | 17 | NEGATIVE_PROMPT = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality' 18 | 19 | 20 | # ----------- file/logger util ---------- 21 | def get_time_str(): 22 | return time.strftime('%Y%m%d_%H%M%S', time.localtime()) 23 | 24 | 25 | def mkdir_and_rename(path): 26 | """mkdirs. If path exists, rename it with timestamp and create a new one. 27 | 28 | Args: 29 | path (str): Folder path. 30 | """ 31 | if osp.exists(path): 32 | new_name = path + '_archived_' + get_time_str() 33 | print(f'Path already exists. Rename it to {new_name}', flush=True) 34 | os.rename(path, new_name) 35 | os.makedirs(path, exist_ok=True) 36 | 37 | 38 | def make_exp_dirs(opt): 39 | """Make dirs for experiments.""" 40 | path_opt = opt['path'].copy() 41 | if opt['is_train']: 42 | mkdir_and_rename(path_opt.pop('experiments_root')) 43 | else: 44 | mkdir_and_rename(path_opt.pop('results_root')) 45 | for key, path in path_opt.items(): 46 | if ('strict_load' in key) or ('pretrain_network' in key) or ( 47 | 'resume' in key) or ('param_key' in key) or ('lora_path' in key): 48 | continue 49 | else: 50 | os.makedirs(path, exist_ok=True) 51 | 52 | 53 | def copy_opt_file(opt_file, experiments_root): 54 | # copy the yml file to the experiment root 55 | import sys 56 | import time 57 | from shutil import copyfile 58 | cmd = ' '.join(sys.argv) 59 | filename = osp.join(experiments_root, osp.basename(opt_file)) 60 | copyfile(opt_file, filename) 61 | 62 | with open(filename, 'r+') as f: 63 | lines = f.readlines() 64 | lines.insert( 65 | 0, f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {cmd}\n\n') 66 | f.seek(0) 67 | f.writelines(lines) 68 | 69 | 70 | def set_path_logger(accelerator, root_path, config_path, opt, is_train=True): 71 | opt['is_train'] = is_train 72 | 73 | if is_train: 74 | experiments_root = osp.join(root_path, 'experiments', opt['name']) 75 | opt['path']['experiments_root'] = experiments_root 76 | opt['path']['models'] = osp.join(experiments_root, 'models') 77 | opt['path']['log'] = experiments_root 78 | opt['path']['visualization'] = osp.join(experiments_root, 79 | 'visualization') 80 | else: 81 | results_root = osp.join(root_path, 'results', opt['name']) 82 | opt['path']['results_root'] = results_root 83 | opt['path']['log'] = results_root 84 | opt['path']['visualization'] = osp.join(results_root, 'visualization') 85 | 86 | # Handle the output folder creation 87 | if accelerator.is_main_process: 88 | make_exp_dirs(opt) 89 | 90 | accelerator.wait_for_everyone() 91 | 92 | if is_train: 93 | copy_opt_file(config_path, opt['path']['experiments_root']) 94 | log_file = osp.join(opt['path']['log'], 95 | f"train_{opt['name']}_{get_time_str()}.log") 96 | set_logger(log_file) 97 | else: 98 | copy_opt_file(config_path, opt['path']['results_root']) 99 | log_file = osp.join(opt['path']['log'], 100 | f"test_{opt['name']}_{get_time_str()}.log") 101 | set_logger(log_file) 102 | 103 | 104 | def set_logger(log_file=None): 105 | # Make one log on every process with the configuration for debugging. 106 | format_str = '%(asctime)s %(levelname)s: %(message)s' 107 | log_level = logging.INFO 108 | handlers = [] 109 | 110 | file_handler = logging.FileHandler(log_file, 'w') 111 | file_handler.setFormatter(logging.Formatter(format_str)) 112 | file_handler.setLevel(log_level) 113 | handlers.append(file_handler) 114 | 115 | stream_handler = logging.StreamHandler() 116 | stream_handler.setFormatter(logging.Formatter(format_str)) 117 | handlers.append(stream_handler) 118 | 119 | logging.basicConfig(handlers=handlers, level=log_level) 120 | 121 | 122 | def dict2str(opt, indent_level=1): 123 | """dict to string for printing options. 124 | 125 | Args: 126 | opt (dict): Option dict. 127 | indent_level (int): Indent level. Default: 1. 128 | 129 | Return: 130 | (str): Option string for printing. 131 | """ 132 | msg = '\n' 133 | for k, v in opt.items(): 134 | if isinstance(v, dict): 135 | msg += ' ' * (indent_level * 2) + k + ':[' 136 | msg += dict2str(v, indent_level + 1) 137 | msg += ' ' * (indent_level * 2) + ']\n' 138 | else: 139 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' 140 | return msg 141 | 142 | 143 | class MessageLogger(): 144 | """Message logger for printing. 145 | 146 | Args: 147 | opt (dict): Config. It contains the following keys: 148 | name (str): Exp name. 149 | logger (dict): Contains 'print_freq' (str) for logger interval. 150 | train (dict): Contains 'total_iter' (int) for total iters. 151 | use_tb_logger (bool): Use tensorboard logger. 152 | start_iter (int): Start iter. Default: 1. 153 | tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None. 154 | """ 155 | def __init__(self, opt, start_iter=1): 156 | self.exp_name = opt['name'] 157 | self.interval = opt['logger']['print_freq'] 158 | self.start_iter = start_iter 159 | self.max_iters = opt['train']['total_iter'] 160 | self.start_time = time.time() 161 | self.logger = get_logger('mixofshow', log_level='INFO') 162 | 163 | def reset_start_time(self): 164 | self.start_time = time.time() 165 | 166 | def __call__(self, log_vars): 167 | """Format logging message. 168 | 169 | Args: 170 | log_vars (dict): It contains the following keys: 171 | epoch (int): Epoch number. 172 | iter (int): Current iter. 173 | lrs (list): List for learning rates. 174 | 175 | time (float): Iter time. 176 | data_time (float): Data time for each iter. 177 | """ 178 | # epoch, iter, learning rates 179 | current_iter = log_vars.pop('iter') 180 | lrs = log_vars.pop('lrs') 181 | 182 | message = ( 183 | f'[{self.exp_name[:5]}..][Iter:{current_iter:8,d}, lr:(' 184 | ) 185 | for v in lrs: 186 | message += f'{v:.3e},' 187 | message += ')] ' 188 | 189 | # time and estimated time 190 | total_time = time.time() - self.start_time 191 | time_sec_avg = total_time / (current_iter - self.start_iter + 1) 192 | eta_sec = time_sec_avg * (self.max_iters - current_iter - 1) 193 | eta_str = str(datetime.timedelta(seconds=int(eta_sec))) 194 | message += f'[eta: {eta_str}] ' 195 | 196 | # other items, especially losses 197 | for k, v in log_vars.items(): 198 | message += f'{k}: {v:.4e} ' 199 | 200 | self.logger.info(message) 201 | 202 | 203 | def reduce_loss_dict(accelerator, loss_dict): 204 | """reduce loss dict. 205 | 206 | In distributed training, it averages the losses among different GPUs . 207 | 208 | Args: 209 | loss_dict (OrderedDict): Loss dict. 210 | """ 211 | with torch.no_grad(): 212 | keys = [] 213 | losses = [] 214 | for name, value in loss_dict.items(): 215 | keys.append(name) 216 | losses.append(value) 217 | losses = torch.stack(losses, 0) 218 | losses = accelerator.reduce(losses) 219 | 220 | world_size = PartialState().num_processes 221 | losses /= world_size 222 | 223 | loss_dict = {key: loss for key, loss in zip(keys, losses)} 224 | 225 | log_dict = OrderedDict() 226 | for name, value in loss_dict.items(): 227 | log_dict[name] = value.mean().item() 228 | 229 | return log_dict 230 | 231 | 232 | def pil_imwrite(img, file_path, auto_mkdir=True): 233 | """Write image to file. 234 | Args: 235 | img (ndarray): Image array to be written. 236 | file_path (str): Image file path. 237 | params (None or list): Same as opencv's :func:`imwrite` interface. 238 | auto_mkdir (bool): If the parent folder of `file_path` does not exist, 239 | whether to create it automatically. 240 | Returns: 241 | bool: Successful or not. 242 | """ 243 | assert isinstance( 244 | img, PIL.Image.Image), 'model should return a list of PIL images' 245 | if auto_mkdir: 246 | dir_name = os.path.abspath(os.path.dirname(file_path)) 247 | os.makedirs(dir_name, exist_ok=True) 248 | img.save(file_path) 249 | 250 | 251 | def draw_prompt(text, height, width, font_size=45): 252 | img = Image.new('RGB', (width, height), (255, 255, 255)) 253 | draw = ImageDraw.Draw(img) 254 | font = ImageFont.truetype( 255 | osp.join(osp.dirname(osp.abspath(__file__)), 'arial.ttf'), font_size) 256 | 257 | guess_count = 0 258 | 259 | while font.font.getsize(text[:guess_count])[0][ 260 | 0] + 0.1 * width < width - 0.1 * width and guess_count < len( 261 | text): # centerize 262 | guess_count += 1 263 | 264 | text_new = '' 265 | for idx, s in enumerate(text): 266 | if idx % guess_count == 0: 267 | text_new += '\n' 268 | if s == ' ': 269 | s = '' # new line trip the first space 270 | text_new += s 271 | 272 | draw.text([int(0.1 * width), int(0.3 * height)], 273 | text_new, 274 | font=font, 275 | fill='black') 276 | return img 277 | 278 | 279 | def compose_visualize(dir_path): 280 | file_list = sorted(os.listdir(dir_path)) 281 | img_list = [] 282 | info_dict = {'prompts': set(), 'sample_args': set(), 'suffix': set()} 283 | for filename in file_list: 284 | prompt, sample_args, index, suffix = osp.splitext( 285 | osp.basename(filename))[0].split('---') 286 | 287 | filepath = osp.join(dir_path, filename) 288 | img = ToTensor()(Image.open(filepath)) 289 | height, width = img.shape[1:] 290 | 291 | if prompt not in info_dict['prompts']: 292 | img_list.append(ToTensor()(draw_prompt(prompt, 293 | height=height, 294 | width=width, 295 | font_size=45))) 296 | info_dict['prompts'].add(prompt) 297 | info_dict['sample_args'].add(sample_args) 298 | info_dict['suffix'].add(suffix) 299 | 300 | img_list.append(img) 301 | assert len( 302 | info_dict['sample_args'] 303 | ) == 1, 'compose dir should contain images form same sample args.' 304 | assert len(info_dict['suffix'] 305 | ) == 1, 'compose dir should contain images form same suffix.' 306 | 307 | grid = make_grid(img_list, nrow=len(img_list) // len(info_dict['prompts'])) 308 | # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer 309 | ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to( 310 | 'cpu', torch.uint8).numpy() 311 | im = Image.fromarray(ndarr) 312 | save_name = f"{info_dict['sample_args'].pop()}---{info_dict['suffix'].pop()}.jpg" 313 | im.save(osp.join(osp.dirname(dir_path), save_name)) 314 | -------------------------------------------------------------------------------- /options/test/EDLoRA/anime/1001_EDLoRA_hina_Anyv4_B4_Iter1K.yml: -------------------------------------------------------------------------------- 1 | name: 1001_EDLoRA_hina_Anyv4_B4_Iter1K 2 | manual_seed: 0 3 | mixed_precision: fp16 4 | 5 | # dataset and data loader settings 6 | datasets: 7 | val_vis: 8 | name: PromptDataset 9 | prompts: datasets/validation_prompts/single-concept/characters/test_girl.txt 10 | num_samples_per_prompt: 8 11 | latent_size: [ 4,64,64 ] 12 | replace_mapping: 13 | : 14 | batch_size_per_gpu: 4 15 | 16 | models: 17 | pretrained_path: experiments/pretrained_models/anything-v4.0 18 | enable_edlora: true # true means ED-LoRA, false means vallina LoRA 19 | alpha_list: [0, 0.4, 0.6, 1.0] # 0 means only visualize embedding (without lora weight) 20 | 21 | # path 22 | path: 23 | lora_path: experiments/1001_EDLoRA_hina_Anyv4_B4_Iter1K/models/lora_model-latest.pth 24 | 25 | # validation settings 26 | val: 27 | compose_visualize: true 28 | sample: 29 | num_inference_steps: 50 30 | guidance_scale: 7.5 31 | -------------------------------------------------------------------------------- /options/test/EDLoRA/human/8101_EDLoRA_potter_Cmix_B4_Repeat500.yml: -------------------------------------------------------------------------------- 1 | name: 8101_EDLoRA_potter_Cmix_B4_Repeat500 2 | manual_seed: 0 3 | mixed_precision: fp16 4 | 5 | # dataset and data loader settings 6 | datasets: 7 | val_vis: 8 | name: PromptDataset 9 | prompts: datasets/validation_prompts/single-concept/characters/test_man.txt 10 | num_samples_per_prompt: 8 11 | latent_size: [ 4,64,64 ] 12 | replace_mapping: 13 | : 14 | batch_size_per_gpu: 4 15 | 16 | models: 17 | pretrained_path: experiments/pretrained_models/chilloutmix 18 | enable_edlora: true # true means ED-LoRA, false means vallina LoRA 19 | 20 | # path 21 | path: 22 | lora_path: experiments/human/8101_EDLoRA_potter_Cmix_B4_Repeat500/models/edlora_model-latest.pth 23 | 24 | # validation settings 25 | val: 26 | compose_visualize: true 27 | alpha_list: [0, 0.7, 1.0] # 0 means only visualize embedding (without lora weight) 28 | sample: 29 | num_inference_steps: 50 30 | guidance_scale: 7.5 31 | -------------------------------------------------------------------------------- /options/test/EDLoRA/human/8102_EDLoRA_hermione_Cmix_B4_Repeat500.yml: -------------------------------------------------------------------------------- 1 | name: 8102_EDLoRA_hermione_Cmix_B4_Repeat500_v6_final 2 | manual_seed: 0 3 | mixed_precision: fp16 4 | 5 | # dataset and data loader settings 6 | datasets: 7 | val_vis: 8 | name: PromptDataset 9 | prompts: datasets/validation_prompts/single-concept/characters/test_woman.txt 10 | num_samples_per_prompt: 8 11 | latent_size: [ 4,64,64 ] 12 | replace_mapping: 13 | : 14 | batch_size_per_gpu: 4 15 | 16 | models: 17 | pretrained_path: experiments/pretrained_models/chilloutmix 18 | enable_edlora: true # true means ED-LoRA, false means vallina LoRA 19 | 20 | # path 21 | path: 22 | lora_path: experiments/8102_EDLoRA_hermione_Cmix_B4_Repeat500_v6_final/models/edlora_model-latest.pth 23 | 24 | # validation settings 25 | val: 26 | compose_visualize: true 27 | alpha_list: [0, 0.7, 1.0] # 0 means only visualize embedding (without lora weight) 28 | sample: 29 | num_inference_steps: 50 30 | guidance_scale: 7.5 31 | -------------------------------------------------------------------------------- /options/train/EDLoRA/anime/1001_1_EDLoRA_hina_Anyv4_B4_Repeat500_v6_final_nomask.yml: -------------------------------------------------------------------------------- 1 | name: 1001_1_EDLoRA_hina_Anyv4_B4_Repeat500_v6_final_nomask 2 | manual_seed: 0 3 | mixed_precision: fp16 4 | gradient_accumulation_steps: 1 5 | 6 | # dataset and data loader settings 7 | datasets: 8 | train: 9 | name: LoraDataset 10 | concept_list: datasets/data_cfgs/MixofShow/single-concept/characters/anime/hina_amano.json 11 | use_caption: true 12 | use_mask: true 13 | instance_transform: 14 | - { type: HumanResizeCropFinalV3, size: 512, crop_p: 0.5 } 15 | - { type: ToTensor } 16 | - { type: Normalize, mean: [ 0.5 ], std: [ 0.5 ] } 17 | - { type: ShuffleCaption, keep_token_num: 1 } 18 | - { type: EnhanceText, enhance_type: human } 19 | replace_mapping: 20 | : 21 | batch_size_per_gpu: 2 22 | dataset_enlarge_ratio: 500 23 | 24 | val_vis: 25 | name: PromptDataset 26 | prompts: datasets/validation_prompts/single-concept/characters/test_girl.txt 27 | num_samples_per_prompt: 8 28 | latent_size: [ 4,64,64 ] 29 | replace_mapping: 30 | : 31 | batch_size_per_gpu: 4 32 | 33 | models: 34 | pretrained_path: experiments/pretrained_models/anything-v4.0 35 | enable_edlora: true # true means ED-LoRA, false means vallina LoRA 36 | finetune_cfg: 37 | text_embedding: 38 | enable_tuning: true 39 | lr: !!float 1e-3 40 | text_encoder: 41 | enable_tuning: true 42 | lora_cfg: 43 | rank: 4 44 | alpha: 1.0 45 | where: CLIPAttention 46 | lr: !!float 1e-5 47 | unet: 48 | enable_tuning: true 49 | lora_cfg: 50 | rank: 4 51 | alpha: 1.0 52 | where: Attention 53 | lr: !!float 1e-4 54 | new_concept_token: + 55 | initializer_token: +girl 56 | noise_offset: 0.01 57 | attn_reg_weight: 0.01 58 | reg_full_identity: false 59 | use_mask_loss: false 60 | gradient_checkpoint: false 61 | enable_xformers: true 62 | 63 | # path 64 | path: 65 | pretrain_network: ~ 66 | 67 | # training settings 68 | train: 69 | optim_g: 70 | type: AdamW 71 | lr: !!float 0.0 # no use since we define different component lr in model 72 | weight_decay: 0.01 73 | betas: [ 0.9, 0.999 ] # align with taming 74 | 75 | # dropkv 76 | scheduler: linear 77 | emb_norm_threshold: !!float 5.5e-1 78 | 79 | # validation settings 80 | val: 81 | val_during_save: true 82 | compose_visualize: true 83 | alpha_list: [0, 0.7, 1.0] # 0 means only visualize embedding (without lora weight) 84 | sample: 85 | num_inference_steps: 50 86 | guidance_scale: 7.5 87 | 88 | # logging settings 89 | logger: 90 | print_freq: 10 91 | save_checkpoint_freq: !!float 10000 92 | -------------------------------------------------------------------------------- /options/train/EDLoRA/anime/1002_1_EDLoRA_kaori_Anyv4_B4_Repeat500_v6_final_nomask.yml: -------------------------------------------------------------------------------- 1 | name: 1002_1_EDLoRA_kaori_Anyv4_B4_Repeat500_v6_final_nomask 2 | manual_seed: 0 3 | mixed_precision: fp16 4 | gradient_accumulation_steps: 1 5 | 6 | # dataset and data loader settings 7 | datasets: 8 | train: 9 | name: LoraDataset 10 | concept_list: datasets/data_cfgs/MixofShow/single-concept/characters/anime/miyazono_kaori.json 11 | use_caption: true 12 | use_mask: true 13 | instance_transform: 14 | - { type: HumanResizeCropFinalV3, size: 512, crop_p: 0.5 } 15 | - { type: ToTensor } 16 | - { type: Normalize, mean: [ 0.5 ], std: [ 0.5 ] } 17 | - { type: ShuffleCaption, keep_token_num: 1 } 18 | - { type: EnhanceText, enhance_type: human } 19 | replace_mapping: 20 | : 21 | batch_size_per_gpu: 2 22 | dataset_enlarge_ratio: 500 23 | 24 | val_vis: 25 | name: PromptDataset 26 | prompts: datasets/validation_prompts/single-concept/characters/test_girl.txt 27 | num_samples_per_prompt: 8 28 | latent_size: [ 4,64,64 ] 29 | replace_mapping: 30 | : 31 | batch_size_per_gpu: 4 32 | 33 | models: 34 | pretrained_path: experiments/pretrained_models/anything-v4.0 35 | enable_edlora: true # true means ED-LoRA, false means vallina LoRA 36 | finetune_cfg: 37 | text_embedding: 38 | enable_tuning: true 39 | lr: !!float 1e-3 40 | text_encoder: 41 | enable_tuning: true 42 | lora_cfg: 43 | rank: 4 44 | alpha: 1.0 45 | where: CLIPAttention 46 | lr: !!float 1e-5 47 | unet: 48 | enable_tuning: true 49 | lora_cfg: 50 | rank: 4 51 | alpha: 1.0 52 | where: Attention 53 | lr: !!float 1e-4 54 | new_concept_token: + 55 | initializer_token: +girl 56 | noise_offset: 0.01 57 | attn_reg_weight: 0.01 58 | reg_full_identity: false 59 | use_mask_loss: false 60 | gradient_checkpoint: false 61 | enable_xformers: true 62 | 63 | # path 64 | path: 65 | pretrain_network: ~ 66 | 67 | # training settings 68 | train: 69 | optim_g: 70 | type: AdamW 71 | lr: !!float 0.0 # no use since we define different component lr in model 72 | weight_decay: 0.01 73 | betas: [ 0.9, 0.999 ] # align with taming 74 | 75 | # dropkv 76 | scheduler: linear 77 | emb_norm_threshold: !!float 5.5e-1 78 | 79 | # validation settings 80 | val: 81 | val_during_save: true 82 | compose_visualize: true 83 | alpha_list: [0, 0.7, 1.0] # 0 means only visualize embedding (without lora weight) 84 | sample: 85 | num_inference_steps: 50 86 | guidance_scale: 7.5 87 | 88 | # logging settings 89 | logger: 90 | print_freq: 10 91 | save_checkpoint_freq: !!float 10000 92 | -------------------------------------------------------------------------------- /options/train/EDLoRA/anime/1003_1_EDLoRA_tezuka_Anyv4_B4_Repeat500_v6_final_nomask.yml: -------------------------------------------------------------------------------- 1 | name: 1003_1_EDLoRA_tezuka_Anyv4_B4_Repeat500_v6_final_nomask 2 | manual_seed: 0 3 | mixed_precision: fp16 4 | gradient_accumulation_steps: 1 5 | 6 | # dataset and data loader settings 7 | datasets: 8 | train: 9 | name: LoraDataset 10 | concept_list: datasets/data_cfgs/MixofShow/single-concept/characters/anime/tezuka_kunimitsu.json 11 | use_caption: true 12 | use_mask: true 13 | instance_transform: 14 | - { type: HumanResizeCropFinalV3, size: 512, crop_p: 0.5 } 15 | - { type: ToTensor } 16 | - { type: Normalize, mean: [ 0.5 ], std: [ 0.5 ] } 17 | - { type: ShuffleCaption, keep_token_num: 1 } 18 | - { type: EnhanceText, enhance_type: human } 19 | replace_mapping: 20 | : 21 | batch_size_per_gpu: 2 22 | dataset_enlarge_ratio: 500 23 | 24 | val_vis: 25 | name: PromptDataset 26 | prompts: datasets/validation_prompts/single-concept/characters/test_man.txt 27 | num_samples_per_prompt: 8 28 | latent_size: [ 4,64,64 ] 29 | replace_mapping: 30 | : 31 | batch_size_per_gpu: 4 32 | 33 | models: 34 | pretrained_path: experiments/pretrained_models/anything-v4.0 35 | enable_edlora: true # true means ED-LoRA, false means vallina LoRA 36 | finetune_cfg: 37 | text_embedding: 38 | enable_tuning: true 39 | lr: !!float 1e-3 40 | text_encoder: 41 | enable_tuning: true 42 | lora_cfg: 43 | rank: 4 44 | alpha: 1.0 45 | where: CLIPAttention 46 | lr: !!float 1e-5 47 | unet: 48 | enable_tuning: true 49 | lora_cfg: 50 | rank: 4 51 | alpha: 1.0 52 | where: Attention 53 | lr: !!float 1e-4 54 | new_concept_token: + 55 | initializer_token: +man 56 | noise_offset: 0.01 57 | attn_reg_weight: 0.01 58 | reg_full_identity: false 59 | use_mask_loss: false 60 | gradient_checkpoint: false 61 | enable_xformers: true 62 | 63 | # path 64 | path: 65 | pretrain_network: ~ 66 | 67 | # training settings 68 | train: 69 | optim_g: 70 | type: AdamW 71 | lr: !!float 0.0 # no use since we define different component lr in model 72 | weight_decay: 0.01 73 | betas: [ 0.9, 0.999 ] # align with taming 74 | 75 | # dropkv 76 | scheduler: linear 77 | emb_norm_threshold: !!float 5.5e-1 78 | 79 | # validation settings 80 | val: 81 | val_during_save: true 82 | compose_visualize: true 83 | alpha_list: [0, 0.7, 1.0] # 0 means only visualize embedding (without lora weight) 84 | sample: 85 | num_inference_steps: 50 86 | guidance_scale: 7.5 87 | 88 | # logging settings 89 | logger: 90 | print_freq: 10 91 | save_checkpoint_freq: !!float 10000 92 | -------------------------------------------------------------------------------- /options/train/EDLoRA/real/8101_EDLoRA_potter_Cmix_B4_Repeat500.yml: -------------------------------------------------------------------------------- 1 | name: 8101_EDLoRA_potter_Cmix_B4_Repeat500 2 | manual_seed: 0 3 | mixed_precision: fp16 4 | gradient_accumulation_steps: 1 5 | 6 | # dataset and data loader settings 7 | datasets: 8 | train: 9 | name: LoraDataset 10 | concept_list: datasets/data_cfgs/MixofShow/single-concept/characters/real/potter.json 11 | use_caption: true 12 | use_mask: true 13 | instance_transform: 14 | - { type: HumanResizeCropFinalV3, size: 512, crop_p: 0.5 } 15 | - { type: ToTensor } 16 | - { type: Normalize, mean: [ 0.5 ], std: [ 0.5 ] } 17 | - { type: ShuffleCaption, keep_token_num: 1 } 18 | - { type: EnhanceText, enhance_type: human } 19 | replace_mapping: 20 | : 21 | batch_size_per_gpu: 2 22 | dataset_enlarge_ratio: 500 23 | 24 | val_vis: 25 | name: PromptDataset 26 | prompts: datasets/validation_prompts/single-concept/characters/test_man.txt 27 | num_samples_per_prompt: 8 28 | latent_size: [ 4,64,64 ] 29 | replace_mapping: 30 | : 31 | batch_size_per_gpu: 4 32 | 33 | models: 34 | pretrained_path: experiments/pretrained_models/chilloutmix 35 | enable_edlora: true # true means ED-LoRA, false means vallina LoRA 36 | finetune_cfg: 37 | text_embedding: 38 | enable_tuning: true 39 | lr: !!float 1e-3 40 | text_encoder: 41 | enable_tuning: true 42 | lora_cfg: 43 | rank: 4 44 | alpha: 1.0 45 | where: CLIPAttention 46 | lr: !!float 1e-5 47 | unet: 48 | enable_tuning: true 49 | lora_cfg: 50 | rank: 4 51 | alpha: 1.0 52 | where: Attention 53 | lr: !!float 1e-4 54 | new_concept_token: + 55 | initializer_token: +man 56 | noise_offset: 0.01 57 | attn_reg_weight: 0.01 58 | reg_full_identity: false 59 | use_mask_loss: true 60 | gradient_checkpoint: false 61 | enable_xformers: true 62 | 63 | # path 64 | path: 65 | pretrain_network: ~ 66 | 67 | # training settings 68 | train: 69 | optim_g: 70 | type: AdamW 71 | lr: !!float 0.0 # no use since we define different component lr in model 72 | weight_decay: 0.01 73 | betas: [ 0.9, 0.999 ] # align with taming 74 | 75 | # dropkv 76 | unet_kv_drop_rate: 0 77 | scheduler: linear 78 | emb_norm_threshold: !!float 5.5e-1 79 | 80 | # validation settings 81 | val: 82 | val_during_save: true 83 | compose_visualize: true 84 | alpha_list: [0, 0.7, 1.0] # 0 means only visualize embedding (without lora weight) 85 | sample: 86 | num_inference_steps: 50 87 | guidance_scale: 7.5 88 | 89 | # logging settings 90 | logger: 91 | print_freq: 10 92 | save_checkpoint_freq: !!float 10000 93 | -------------------------------------------------------------------------------- /options/train/EDLoRA/real/8102_EDLoRA_hermione_Cmix_B4_Repeat500.yml: -------------------------------------------------------------------------------- 1 | name: 8102_EDLoRA_hermione_Cmix_B4_Repeat500 2 | manual_seed: 0 3 | mixed_precision: fp16 4 | gradient_accumulation_steps: 1 5 | 6 | # dataset and data loader settings 7 | datasets: 8 | train: 9 | name: LoraDataset 10 | concept_list: datasets/data_cfgs/MixofShow/single-concept/characters/real/hermione.json 11 | use_caption: true 12 | use_mask: true 13 | instance_transform: 14 | - { type: HumanResizeCropFinalV3, size: 512, crop_p: 0.5 } 15 | - { type: ToTensor } 16 | - { type: Normalize, mean: [ 0.5 ], std: [ 0.5 ] } 17 | - { type: ShuffleCaption, keep_token_num: 1 } 18 | - { type: EnhanceText, enhance_type: human } 19 | replace_mapping: 20 | : 21 | batch_size_per_gpu: 2 22 | dataset_enlarge_ratio: 500 23 | 24 | val_vis: 25 | name: PromptDataset 26 | prompts: datasets/validation_prompts/single-concept/characters/test_woman.txt 27 | num_samples_per_prompt: 8 28 | latent_size: [ 4,64,64 ] 29 | replace_mapping: 30 | : 31 | batch_size_per_gpu: 4 32 | 33 | models: 34 | pretrained_path: experiments/pretrained_models/chilloutmix 35 | enable_edlora: true # true means ED-LoRA, false means vallina LoRA 36 | finetune_cfg: 37 | text_embedding: 38 | enable_tuning: true 39 | lr: !!float 1e-3 40 | text_encoder: 41 | enable_tuning: true 42 | lora_cfg: 43 | rank: 4 44 | alpha: 1.0 45 | where: CLIPAttention 46 | lr: !!float 1e-5 47 | unet: 48 | enable_tuning: true 49 | lora_cfg: 50 | rank: 4 51 | alpha: 1.0 52 | where: Attention 53 | lr: !!float 1e-4 54 | new_concept_token: + 55 | initializer_token: +woman 56 | noise_offset: 0.01 57 | attn_reg_weight: 0.01 58 | reg_full_identity: false 59 | use_mask_loss: true 60 | gradient_checkpoint: false 61 | enable_xformers: true 62 | 63 | # path 64 | path: 65 | pretrain_network: ~ 66 | 67 | # training settings 68 | train: 69 | optim_g: 70 | type: AdamW 71 | lr: !!float 0.0 # no use since we define different component lr in model 72 | weight_decay: 0.01 73 | betas: [ 0.9, 0.999 ] # align with taming 74 | 75 | # dropkv 76 | unet_kv_drop_rate: 0 77 | scheduler: linear 78 | emb_norm_threshold: !!float 5.5e-1 79 | 80 | # validation settings 81 | val: 82 | val_during_save: true 83 | compose_visualize: true 84 | alpha_list: [0, 0.7, 1.0] # 0 means only visualize embedding (without lora weight) 85 | sample: 86 | num_inference_steps: 50 87 | guidance_scale: 7.5 88 | 89 | # logging settings 90 | logger: 91 | print_freq: 10 92 | save_checkpoint_freq: !!float 10000 93 | -------------------------------------------------------------------------------- /options/train/EDLoRA/real/8103_EDLoRA_thanos_Cmix_B4_Repeat250.yml: -------------------------------------------------------------------------------- 1 | name: 8103_EDLoRA_thanos_Cmix_B4_Repeat250 2 | manual_seed: 0 3 | mixed_precision: fp16 4 | gradient_accumulation_steps: 1 5 | 6 | # dataset and data loader settings 7 | datasets: 8 | train: 9 | name: LoraDataset 10 | concept_list: datasets/data_cfgs/MixofShow/single-concept/characters/real/thanos.json 11 | use_caption: true 12 | use_mask: true 13 | instance_transform: 14 | - { type: HumanResizeCropFinalV3, size: 512, crop_p: 0.5 } 15 | - { type: ToTensor } 16 | - { type: Normalize, mean: [ 0.5 ], std: [ 0.5 ] } 17 | - { type: ShuffleCaption, keep_token_num: 1 } 18 | - { type: EnhanceText, enhance_type: human } 19 | replace_mapping: 20 | : 21 | batch_size_per_gpu: 2 22 | dataset_enlarge_ratio: 250 23 | 24 | val_vis: 25 | name: PromptDataset 26 | prompts: datasets/validation_prompts/single-concept/characters/test_man.txt 27 | num_samples_per_prompt: 8 28 | latent_size: [ 4,64,64 ] 29 | replace_mapping: 30 | : 31 | batch_size_per_gpu: 4 32 | 33 | models: 34 | pretrained_path: experiments/pretrained_models/chilloutmix 35 | enable_edlora: true # true means ED-LoRA, false means vallina LoRA 36 | finetune_cfg: 37 | text_embedding: 38 | enable_tuning: true 39 | lr: !!float 1e-3 40 | text_encoder: 41 | enable_tuning: true 42 | lora_cfg: 43 | rank: 4 44 | alpha: 1.0 45 | where: CLIPAttention 46 | lr: !!float 1e-5 47 | unet: 48 | enable_tuning: true 49 | lora_cfg: 50 | rank: 4 51 | alpha: 1.0 52 | where: Attention 53 | lr: !!float 1e-4 54 | new_concept_token: + 55 | initializer_token: +man 56 | noise_offset: 0.01 57 | attn_reg_weight: 0.01 58 | reg_full_identity: true 59 | use_mask_loss: true 60 | gradient_checkpoint: false 61 | enable_xformers: true 62 | 63 | # path 64 | path: 65 | pretrain_network: ~ 66 | 67 | # training settings 68 | train: 69 | optim_g: 70 | type: AdamW 71 | lr: !!float 0.0 # no use since we define different component lr in model 72 | weight_decay: 0.01 73 | betas: [ 0.9, 0.999 ] # align with taming 74 | 75 | # dropkv 76 | unet_kv_drop_rate: 0 77 | scheduler: linear 78 | emb_norm_threshold: !!float 5.5e-1 79 | 80 | # validation settings 81 | val: 82 | val_during_save: true 83 | compose_visualize: true 84 | alpha_list: [0, 0.7, 1.0] # 0 means only visualize embedding (without lora weight) 85 | sample: 86 | num_inference_steps: 50 87 | guidance_scale: 7.5 88 | 89 | # logging settings 90 | logger: 91 | print_freq: 10 92 | save_checkpoint_freq: !!float 10000 93 | -------------------------------------------------------------------------------- /regionally_controlable_sampling.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import hashlib 3 | import json 4 | import os.path 5 | 6 | import torch 7 | from diffusers import DPMSolverMultistepScheduler 8 | from diffusers.models import T2IAdapter 9 | from PIL import Image 10 | 11 | from mixofshow.pipelines.pipeline_regionally_t2iadapter import RegionallyT2IAdapterPipeline 12 | 13 | 14 | def sample_image(pipe, 15 | input_prompt, 16 | input_neg_prompt=None, 17 | generator=None, 18 | num_inference_steps=50, 19 | guidance_scale=7.5, 20 | sketch_adaptor_weight=1.0, 21 | region_sketch_adaptor_weight='', 22 | keypose_adaptor_weight=1.0, 23 | region_keypose_adaptor_weight='', 24 | **extra_kargs 25 | ): 26 | 27 | keypose_condition = extra_kargs.pop('keypose_condition') 28 | if keypose_condition is not None: 29 | keypose_adapter_input = [keypose_condition] * len(input_prompt) 30 | else: 31 | keypose_adapter_input = None 32 | 33 | sketch_condition = extra_kargs.pop('sketch_condition') 34 | if sketch_condition is not None: 35 | sketch_adapter_input = [sketch_condition] * len(input_prompt) 36 | else: 37 | sketch_adapter_input = None 38 | 39 | images = pipe( 40 | prompt=input_prompt, 41 | negative_prompt=input_neg_prompt, 42 | keypose_adapter_input=keypose_adapter_input, 43 | keypose_adaptor_weight=keypose_adaptor_weight, 44 | region_keypose_adaptor_weight=region_keypose_adaptor_weight, 45 | sketch_adapter_input=sketch_adapter_input, 46 | sketch_adaptor_weight=sketch_adaptor_weight, 47 | region_sketch_adaptor_weight=region_sketch_adaptor_weight, 48 | generator=generator, 49 | guidance_scale=guidance_scale, 50 | num_inference_steps=num_inference_steps, 51 | **extra_kargs).images 52 | return images 53 | 54 | 55 | def build_model(pretrained_model, device): 56 | pipe = RegionallyT2IAdapterPipeline.from_pretrained(pretrained_model, torch_dtype=torch.float16).to(device) 57 | assert os.path.exists(os.path.join(pretrained_model, 'new_concept_cfg.json')) 58 | with open(os.path.join(pretrained_model, 'new_concept_cfg.json'), 'r') as json_file: 59 | new_concept_cfg = json.load(json_file) 60 | pipe.set_new_concept_cfg(new_concept_cfg) 61 | pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(pretrained_model, subfolder='scheduler') 62 | pipe.keypose_adapter = T2IAdapter.from_pretrained('TencentARC/t2iadapter_openpose_sd14v1', torch_dtype=torch.float16).to(device) 63 | pipe.sketch_adapter = T2IAdapter.from_pretrained('TencentARC/t2iadapter_sketch_sd14v1', torch_dtype=torch.float16).to(device) 64 | return pipe 65 | 66 | 67 | def prepare_text(prompt, region_prompts, height, width): 68 | ''' 69 | Args: 70 | prompt_entity: [subject1]-*-[attribute1]-*-[Location1]|[subject2]-*-[attribute2]-*-[Location2]|[global text] 71 | Returns: 72 | full_prompt: subject1, attribute1 and subject2, attribute2, global text 73 | context_prompt: subject1 and subject2, global text 74 | entity_collection: [(subject1, attribute1), Location1] 75 | ''' 76 | region_collection = [] 77 | 78 | regions = region_prompts.split('|') 79 | 80 | for region in regions: 81 | if region == '': 82 | break 83 | prompt_region, neg_prompt_region, pos = region.split('-*-') 84 | prompt_region = prompt_region.replace('[', '').replace(']', '') 85 | neg_prompt_region = neg_prompt_region.replace('[', '').replace(']', '') 86 | pos = eval(pos) 87 | if len(pos) == 0: 88 | pos = [0, 0, 1, 1] 89 | else: 90 | pos[0], pos[2] = pos[0] / height, pos[2] / height 91 | pos[1], pos[3] = pos[1] / width, pos[3] / width 92 | 93 | region_collection.append((prompt_region, neg_prompt_region, pos)) 94 | return (prompt, region_collection) 95 | 96 | 97 | def parse_args(): 98 | parser = argparse.ArgumentParser('', add_help=False) 99 | parser.add_argument('--pretrained_model', default='experiments/composed_edlora/anythingv4/hina+kario+tezuka+mitsuha+son_anythingv4/combined_model_base', type=str) 100 | parser.add_argument('--sketch_condition', default=None, type=str) 101 | parser.add_argument('--sketch_adaptor_weight', default=1.0, type=float) 102 | parser.add_argument('--region_sketch_adaptor_weight', default='', type=str) 103 | parser.add_argument('--keypose_condition', default=None, type=str) 104 | parser.add_argument('--keypose_adaptor_weight', default=1.0, type=float) 105 | parser.add_argument('--region_keypose_adaptor_weight', default='', type=str) 106 | parser.add_argument('--save_dir', default=None, type=str) 107 | parser.add_argument('--prompt', default='photo of a toy', type=str) 108 | parser.add_argument('--negative_prompt', default='', type=str) 109 | parser.add_argument('--prompt_rewrite', default='', type=str) 110 | parser.add_argument('--seed', default=16141, type=int) 111 | parser.add_argument('--suffix', default='', type=str) 112 | return parser.parse_args() 113 | 114 | 115 | if __name__ == '__main__': 116 | args = parse_args() 117 | 118 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 119 | pipe = build_model(args.pretrained_model, device) 120 | 121 | if args.sketch_condition is not None and os.path.exists(args.sketch_condition): 122 | sketch_condition = Image.open(args.sketch_condition).convert('L') 123 | width_sketch, height_sketch = sketch_condition.size 124 | print('use sketch condition') 125 | else: 126 | sketch_condition, width_sketch, height_sketch = None, 0, 0 127 | print('skip sketch condition') 128 | 129 | if args.keypose_condition is not None and os.path.exists(args.keypose_condition): 130 | keypose_condition = Image.open(args.keypose_condition).convert('RGB') 131 | width_pose, height_pose = keypose_condition.size 132 | print('use pose condition') 133 | else: 134 | keypose_condition, width_pose, height_pose = None, 0, 0 135 | print('skip pose condition') 136 | 137 | if width_sketch != 0 and width_pose != 0: 138 | assert width_sketch == width_pose and height_sketch == height_pose, 'conditions should be same size' 139 | width, height = max(width_pose, width_sketch), max(height_pose, height_sketch) 140 | 141 | kwargs = { 142 | 'sketch_condition': sketch_condition, 143 | 'keypose_condition': keypose_condition, 144 | 'height': height, 145 | 'width': width, 146 | } 147 | 148 | prompts = [args.prompt] 149 | prompts_rewrite = [args.prompt_rewrite] 150 | input_prompt = [prepare_text(p, p_w, height, width) for p, p_w in zip(prompts, prompts_rewrite)] 151 | save_prompt = input_prompt[0][0] 152 | 153 | image = sample_image( 154 | pipe, 155 | input_prompt=input_prompt, 156 | input_neg_prompt=[args.negative_prompt] * len(input_prompt), 157 | generator=torch.Generator(device).manual_seed(args.seed), 158 | sketch_adaptor_weight=args.sketch_adaptor_weight, 159 | region_sketch_adaptor_weight=args.region_sketch_adaptor_weight, 160 | keypose_adaptor_weight=args.keypose_adaptor_weight, 161 | region_keypose_adaptor_weight=args.region_keypose_adaptor_weight, 162 | **kwargs) 163 | 164 | print(f'save to: {args.save_dir}') 165 | 166 | configs = [ 167 | f'pretrained_model: {args.pretrained_model}\n', 168 | f'context_prompt: {args.prompt}\n', f'neg_context_prompt: {args.negative_prompt}\n', 169 | f'sketch_condition: {args.sketch_condition}\n', f'sketch_adaptor_weight: {args.sketch_adaptor_weight}\n', 170 | f'region_sketch_adaptor_weight: {args.region_sketch_adaptor_weight}\n', 171 | f'keypose_condition: {args.keypose_condition}\n', f'keypose_adaptor_weight: {args.keypose_adaptor_weight}\n', 172 | f'region_keypose_adaptor_weight: {args.region_keypose_adaptor_weight}\n', f'random seed: {args.seed}\n', 173 | f'prompt_rewrite: {args.prompt_rewrite}\n' 174 | ] 175 | hash_code = hashlib.sha256(''.join(configs).encode('utf-8')).hexdigest()[:8] 176 | 177 | save_prompt = save_prompt.replace(' ', '_') 178 | save_name = f'{save_prompt}---{args.suffix}---{hash_code}.png' 179 | save_dir = os.path.join(args.save_dir, f'seed_{args.seed}') 180 | save_path = os.path.join(save_dir, save_name) 181 | save_config_path = os.path.join(save_dir, save_name.replace('.png', '.txt')) 182 | 183 | os.makedirs(save_dir, exist_ok=True) 184 | image[0].save(os.path.join(save_dir, save_name)) 185 | 186 | with open(save_config_path, 'w') as fw: 187 | fw.writelines(configs) 188 | -------------------------------------------------------------------------------- /regionally_sample.sh: -------------------------------------------------------------------------------- 1 | #---------------------------------------------anime------------------------------------------- 2 | 3 | anime_character=0 4 | 5 | if [ ${anime_character} -eq 1 ] 6 | then 7 | fused_model="experiments/composed_edlora/anythingv4/hina+kario+tezuka_anythingv4/combined_model_base" 8 | expdir="hina+kario+tezuka_anythingv4" 9 | 10 | keypose_condition='datasets/validation_spatial_condition/multi-characters/anime_pose_2x/hina_tezuka_kario_2x.png' 11 | keypose_adaptor_weight=1.0 12 | sketch_condition='' 13 | sketch_adaptor_weight=1.0 14 | 15 | context_prompt='two girls and a boy are standing near a forest' 16 | context_neg_prompt='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality' 17 | 18 | region1_prompt='[a , standing near a forest]' 19 | region1_neg_prompt="[${context_neg_prompt}]" 20 | region1='[12, 36, 1024, 600]' 21 | 22 | region2_prompt='[a , standing near a forest]' 23 | region2_neg_prompt="[${context_neg_prompt}]" 24 | region2='[18, 696, 1024, 1180]' 25 | 26 | region5_prompt='[a , standing near a forest]' 27 | region5_neg_prompt="[${context_neg_prompt}]" 28 | region5='[142, 1259, 1024, 1956]' 29 | 30 | prompt_rewrite="${region1_prompt}-*-${region1_neg_prompt}-*-${region1}|${region2_prompt}-*-${region2_neg_prompt}-*-${region2}|${region5_prompt}-*-${region5_neg_prompt}-*-${region5}" 31 | 32 | python regionally_controlable_sampling.py \ 33 | --pretrained_model=${fused_model} \ 34 | --sketch_adaptor_weight=${sketch_adaptor_weight}\ 35 | --sketch_condition=${sketch_condition} \ 36 | --keypose_adaptor_weight=${keypose_adaptor_weight}\ 37 | --keypose_condition=${keypose_condition} \ 38 | --save_dir="results/multi-concept/${expdir}" \ 39 | --prompt="${context_prompt}" \ 40 | --negative_prompt="${context_neg_prompt}" \ 41 | --prompt_rewrite="${prompt_rewrite}" \ 42 | --suffix="baseline" \ 43 | --seed=19 44 | fi 45 | 46 | #---------------------------------------------real------------------------------------------- 47 | 48 | real_character=1 49 | 50 | if [ ${real_character} -eq 1 ] 51 | then 52 | fused_model="experiments/composed_edlora/chilloutmix/potter+hermione+thanos_chilloutmix/combined_model_base" 53 | expdir="potter+hermione+thanos_chilloutmix" 54 | 55 | keypose_condition='datasets/validation_spatial_condition/multi-characters/real_pose_2x/harry_hermione_thanos_2x.png' 56 | keypose_adaptor_weight=1.0 57 | 58 | sketch_condition='' 59 | sketch_adaptor_weight=1.0 60 | 61 | context_prompt='three people near the castle, 4K, high quality, high resolution, best quality' 62 | context_neg_prompt='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality' 63 | 64 | region1_prompt='[a , in Hogwarts uniform, holding hands, near the castle, 4K, high quality, high resolution, best quality]' 65 | region1_neg_prompt="[${context_neg_prompt}]" 66 | region1='[4, 6, 1024, 490]' 67 | 68 | region2_prompt='[a , girl, in Hogwarts uniform, near the castle, 4K, high quality, high resolution, best quality]' 69 | region2_neg_prompt="[${context_neg_prompt}]" 70 | region2='[14, 490, 1024, 920]' 71 | 72 | region3_prompt='[a , purple armor, near the castle, 4K, high quality, high resolution, best quality]' 73 | region3_neg_prompt="[${context_neg_prompt}]" 74 | region3='[2, 1302, 1024, 1992]' 75 | 76 | prompt_rewrite="${region1_prompt}-*-${region1_neg_prompt}-*-${region1}|${region2_prompt}-*-${region2_neg_prompt}-*-${region2}|${region3_prompt}-*-${region3_neg_prompt}-*-${region3}" 77 | 78 | python regionally_controlable_sampling.py \ 79 | --pretrained_model=${fused_model} \ 80 | --sketch_adaptor_weight=${sketch_adaptor_weight}\ 81 | --sketch_condition=${sketch_condition} \ 82 | --keypose_adaptor_weight=${keypose_adaptor_weight}\ 83 | --keypose_condition=${keypose_condition} \ 84 | --save_dir="results/multi-concept/${expdir}" \ 85 | --prompt="${context_prompt}" \ 86 | --negative_prompt="${context_neg_prompt}" \ 87 | --prompt_rewrite="${prompt_rewrite}" \ 88 | --suffix="baseline" \ 89 | --seed=14 90 | fi 91 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | diffusers 3 | transformers 4 | -------------------------------------------------------------------------------- /test_edlora.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | 5 | import torch 6 | import torch.utils.checkpoint 7 | from accelerate import Accelerator 8 | from accelerate.logging import get_logger 9 | from accelerate.utils import set_seed 10 | from diffusers import DPMSolverMultistepScheduler 11 | from diffusers.utils import check_min_version 12 | from omegaconf import OmegaConf 13 | from tqdm import tqdm 14 | 15 | from mixofshow.data.prompt_dataset import PromptDataset 16 | from mixofshow.pipelines.pipeline_edlora import EDLoRAPipeline, StableDiffusionPipeline 17 | from mixofshow.utils.convert_edlora_to_diffusers import convert_edlora 18 | from mixofshow.utils.util import NEGATIVE_PROMPT, compose_visualize, dict2str, pil_imwrite, set_path_logger 19 | 20 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 21 | check_min_version('0.18.2') 22 | 23 | 24 | def visual_validation(accelerator, pipe, dataloader, current_iter, opt): 25 | dataset_name = dataloader.dataset.opt['name'] 26 | pipe.unet.eval() 27 | pipe.text_encoder.eval() 28 | 29 | for idx, val_data in enumerate(tqdm(dataloader)): 30 | output = pipe( 31 | prompt=val_data['prompts'], 32 | latents=val_data['latents'].to(dtype=torch.float16), 33 | negative_prompt=[NEGATIVE_PROMPT] * len(val_data['prompts']), 34 | num_inference_steps=opt['val']['sample'].get('num_inference_steps', 50), 35 | guidance_scale=opt['val']['sample'].get('guidance_scale', 7.5), 36 | ).images 37 | 38 | for img, prompt, indice in zip(output, val_data['prompts'], val_data['indices']): 39 | img_name = '{prompt}---G_{guidance_scale}_S_{steps}---{indice}'.format( 40 | prompt=prompt.replace(' ', '_'), 41 | guidance_scale=opt['val']['sample'].get('guidance_scale', 7.5), 42 | steps=opt['val']['sample'].get('num_inference_steps', 50), 43 | indice=indice) 44 | 45 | save_img_path = osp.join(opt['path']['visualization'], dataset_name, f'{current_iter}', f'{img_name}---{current_iter}.png') 46 | 47 | pil_imwrite(img, save_img_path) 48 | # tentative for out of GPU memory 49 | del output 50 | torch.cuda.empty_cache() 51 | 52 | # Save the lora layers, final eval 53 | accelerator.wait_for_everyone() 54 | 55 | if opt['val'].get('compose_visualize'): 56 | if accelerator.is_main_process: 57 | compose_visualize(os.path.dirname(save_img_path)) 58 | 59 | 60 | def test(root_path, args): 61 | 62 | # load config 63 | opt = OmegaConf.to_container(OmegaConf.load(args.opt), resolve=True) 64 | 65 | # set accelerator, mix-precision set in the environment by "accelerate config" 66 | accelerator = Accelerator(mixed_precision=opt['mixed_precision']) 67 | 68 | # set experiment dir 69 | with accelerator.main_process_first(): 70 | set_path_logger(accelerator, root_path, args.opt, opt, is_train=False) 71 | 72 | # get logger 73 | logger = get_logger('mixofshow', log_level='INFO') 74 | logger.info(accelerator.state, main_process_only=True) 75 | 76 | logger.info(dict2str(opt)) 77 | 78 | # If passed along, set the training seed now. 79 | if opt.get('manual_seed') is not None: 80 | set_seed(opt['manual_seed']) 81 | 82 | # Get the training dataset 83 | valset_cfg = opt['datasets']['val_vis'] 84 | val_dataset = PromptDataset(valset_cfg) 85 | val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=valset_cfg['batch_size_per_gpu'], shuffle=False) 86 | 87 | enable_edlora = opt['models']['enable_edlora'] 88 | 89 | for lora_alpha in opt['val']['alpha_list']: 90 | pipeclass = EDLoRAPipeline if enable_edlora else StableDiffusionPipeline 91 | pipe = pipeclass.from_pretrained(opt['models']['pretrained_path'], 92 | scheduler=DPMSolverMultistepScheduler.from_pretrained(opt['models']['pretrained_path'], subfolder='scheduler'), 93 | torch_dtype=torch.float16).to('cuda') 94 | pipe, new_concept_cfg = convert_edlora(pipe, torch.load(opt['path']['lora_path']), enable_edlora=enable_edlora, alpha=lora_alpha) 95 | pipe.set_new_concept_cfg(new_concept_cfg) 96 | # visualize embedding + LoRA weight shift 97 | logger.info(f'Start validation sample lora({lora_alpha}):') 98 | 99 | lora_type = 'edlora' if enable_edlora else 'lora' 100 | visual_validation(accelerator, pipe, val_dataloader, f'validation_{lora_type}_{lora_alpha}', opt) 101 | del pipe 102 | 103 | 104 | if __name__ == '__main__': 105 | parser = argparse.ArgumentParser() 106 | parser.add_argument('-opt', type=str, default='options/test/EDLoRA/EDLoRA_hina_Anyv4_B4_Iter1K.yml') 107 | args = parser.parse_args() 108 | 109 | root_path = osp.abspath(osp.join(__file__, osp.pardir)) 110 | test(root_path, args) 111 | -------------------------------------------------------------------------------- /train_edlora.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import os 4 | import os.path as osp 5 | 6 | import torch 7 | import torch.utils.checkpoint 8 | from accelerate import Accelerator 9 | from accelerate.logging import get_logger 10 | from accelerate.utils import set_seed 11 | from diffusers import DPMSolverMultistepScheduler 12 | from diffusers.optimization import get_scheduler 13 | from diffusers.utils import check_min_version 14 | from omegaconf import OmegaConf 15 | 16 | from mixofshow.data.lora_dataset import LoraDataset 17 | from mixofshow.data.prompt_dataset import PromptDataset 18 | from mixofshow.pipelines.pipeline_edlora import EDLoRAPipeline, StableDiffusionPipeline 19 | from mixofshow.pipelines.trainer_edlora import EDLoRATrainer 20 | from mixofshow.utils.convert_edlora_to_diffusers import convert_edlora 21 | from mixofshow.utils.util import MessageLogger, dict2str, reduce_loss_dict, set_path_logger 22 | from test_edlora import visual_validation 23 | 24 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 25 | check_min_version('0.18.2') 26 | 27 | 28 | def train(root_path, args): 29 | 30 | # load config 31 | opt = OmegaConf.to_container(OmegaConf.load(args.opt), resolve=True) 32 | 33 | # set accelerator, mix-precision set in the environment by "accelerate config" 34 | accelerator = Accelerator(mixed_precision=opt['mixed_precision'], gradient_accumulation_steps=opt['gradient_accumulation_steps']) 35 | 36 | # set experiment dir 37 | with accelerator.main_process_first(): 38 | set_path_logger(accelerator, root_path, args.opt, opt, is_train=True) 39 | 40 | # get logger 41 | logger = get_logger('mixofshow', log_level='INFO') 42 | logger.info(accelerator.state, main_process_only=True) 43 | 44 | logger.info(dict2str(opt)) 45 | 46 | # If passed along, set the training seed now. 47 | if opt.get('manual_seed') is not None: 48 | set_seed(opt['manual_seed']) 49 | 50 | # Load model 51 | EDLoRA_trainer = EDLoRATrainer(**opt['models']) 52 | 53 | # set optimizer 54 | train_opt = opt['train'] 55 | optim_type = train_opt['optim_g'].pop('type') 56 | assert optim_type == 'AdamW', 'only support AdamW now' 57 | optimizer = torch.optim.AdamW(EDLoRA_trainer.get_params_to_optimize(), **train_opt['optim_g']) 58 | 59 | # Get the training dataset 60 | trainset_cfg = opt['datasets']['train'] 61 | train_dataset = LoraDataset(trainset_cfg) 62 | train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=trainset_cfg['batch_size_per_gpu'], shuffle=True, drop_last=True) 63 | 64 | # Get the training dataset 65 | valset_cfg = opt['datasets']['val_vis'] 66 | val_dataset = PromptDataset(valset_cfg) 67 | val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=valset_cfg['batch_size_per_gpu'], shuffle=False) 68 | 69 | # Prepare everything with our `accelerator`. 70 | EDLoRA_trainer, optimizer, train_dataloader, val_dataloader = accelerator.prepare(EDLoRA_trainer, optimizer, train_dataloader, val_dataloader) 71 | 72 | # Train! 73 | total_batch_size = opt['datasets']['train']['batch_size_per_gpu'] * accelerator.num_processes * opt['gradient_accumulation_steps'] 74 | total_iter = len(train_dataset) / total_batch_size 75 | opt['train']['total_iter'] = total_iter 76 | 77 | logger.info('***** Running training *****') 78 | logger.info(f' Num examples = {len(train_dataset)}') 79 | logger.info(f" Instantaneous batch size per device = {opt['datasets']['train']['batch_size_per_gpu']}") 80 | logger.info(f' Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}') 81 | logger.info(f' Total optimization steps = {total_iter}') 82 | global_step = 0 83 | 84 | # Scheduler 85 | lr_scheduler = get_scheduler( 86 | 'linear', 87 | optimizer=optimizer, 88 | num_warmup_steps=0, 89 | num_training_steps=total_iter * opt['gradient_accumulation_steps'], 90 | ) 91 | 92 | def make_data_yielder(dataloader): 93 | while True: 94 | for batch in dataloader: 95 | yield batch 96 | accelerator.wait_for_everyone() 97 | 98 | train_data_yielder = make_data_yielder(train_dataloader) 99 | 100 | msg_logger = MessageLogger(opt, global_step) 101 | stop_emb_update = False 102 | 103 | original_embedding = copy.deepcopy(accelerator.unwrap_model(EDLoRA_trainer).text_encoder.get_input_embeddings().weight) 104 | 105 | while global_step < opt['train']['total_iter']: 106 | with accelerator.accumulate(EDLoRA_trainer): 107 | 108 | accelerator.unwrap_model(EDLoRA_trainer).unet.train() 109 | accelerator.unwrap_model(EDLoRA_trainer).text_encoder.train() 110 | loss_dict = {} 111 | 112 | batch = next(train_data_yielder) 113 | 114 | if 'masks' in batch: 115 | masks = batch['masks'] 116 | else: 117 | masks = batch['img_masks'] 118 | 119 | loss = EDLoRA_trainer(batch['images'], batch['prompts'], masks, batch['img_masks']) 120 | loss_dict['loss'] = loss 121 | 122 | # get fix embedding and learn embedding 123 | index_no_updates = torch.arange(len(accelerator.unwrap_model(EDLoRA_trainer).tokenizer)) != -1 124 | if not stop_emb_update: 125 | for token_id in accelerator.unwrap_model(EDLoRA_trainer).get_all_concept_token_ids(): 126 | index_no_updates[token_id] = False 127 | 128 | accelerator.backward(loss) 129 | optimizer.step() 130 | lr_scheduler.step() 131 | optimizer.zero_grad() 132 | 133 | if accelerator.sync_gradients: 134 | # set no update token to origin 135 | token_embeds = accelerator.unwrap_model(EDLoRA_trainer).text_encoder.get_input_embeddings().weight 136 | token_embeds.data[index_no_updates, :] = original_embedding.data[index_no_updates, :] 137 | 138 | token_embeds = accelerator.unwrap_model(EDLoRA_trainer).text_encoder.get_input_embeddings().weight 139 | concept_token_ids = accelerator.unwrap_model(EDLoRA_trainer).get_all_concept_token_ids() 140 | loss_dict['Norm_mean'] = token_embeds[concept_token_ids].norm(dim=-1).mean() 141 | if stop_emb_update is False and float(loss_dict['Norm_mean']) >= train_opt.get('emb_norm_threshold', 5.5e-1): 142 | stop_emb_update = True 143 | original_embedding = copy.deepcopy(accelerator.unwrap_model(EDLoRA_trainer).text_encoder.get_input_embeddings().weight) 144 | 145 | log_dict = reduce_loss_dict(accelerator, loss_dict) 146 | 147 | # Checks if the accelerator has performed an optimization step behind the scenes 148 | if accelerator.sync_gradients: 149 | global_step += 1 150 | 151 | if global_step % opt['logger']['print_freq'] == 0: 152 | log_vars = {'iter': global_step} 153 | log_vars.update({'lrs': lr_scheduler.get_last_lr()}) 154 | log_vars.update(log_dict) 155 | msg_logger(log_vars) 156 | 157 | if global_step % opt['logger']['save_checkpoint_freq'] == 0: 158 | save_and_validation(accelerator, opt, EDLoRA_trainer, val_dataloader, global_step, logger) 159 | 160 | # Save the lora layers, final eval 161 | accelerator.wait_for_everyone() 162 | save_and_validation(accelerator, opt, EDLoRA_trainer, val_dataloader, 'latest', logger) 163 | 164 | 165 | def save_and_validation(accelerator, opt, EDLoRA_trainer, val_dataloader, global_step, logger): 166 | enable_edlora = opt['models']['enable_edlora'] 167 | lora_type = 'edlora' if enable_edlora else 'lora' 168 | save_path = os.path.join(opt['path']['models'], f'{lora_type}_model-{global_step}.pth') 169 | 170 | if accelerator.is_main_process: 171 | accelerator.save({'params': accelerator.unwrap_model(EDLoRA_trainer).delta_state_dict()}, save_path) 172 | logger.info(f'Save state to {save_path}') 173 | 174 | accelerator.wait_for_everyone() 175 | 176 | if opt['val']['val_during_save']: 177 | logger.info(f'Start validation {save_path}:') 178 | for lora_alpha in opt['val']['alpha_list']: 179 | pipeclass = EDLoRAPipeline if enable_edlora else StableDiffusionPipeline 180 | 181 | pipe = pipeclass.from_pretrained(opt['models']['pretrained_path'], 182 | scheduler=DPMSolverMultistepScheduler.from_pretrained(opt['models']['pretrained_path'], subfolder='scheduler'), 183 | torch_dtype=torch.float16).to('cuda') 184 | pipe, new_concept_cfg = convert_edlora(pipe, torch.load(save_path), enable_edlora=enable_edlora, alpha=lora_alpha) 185 | pipe.set_new_concept_cfg(new_concept_cfg) 186 | pipe.set_progress_bar_config(disable=True) 187 | visual_validation(accelerator, pipe, val_dataloader, f'Iters-{global_step}_Alpha-{lora_alpha}', opt) 188 | 189 | del pipe 190 | 191 | 192 | if __name__ == '__main__': 193 | parser = argparse.ArgumentParser() 194 | parser.add_argument('-opt', type=str, default='options/train/EDLoRA/EDLoRA_hina_Anyv4_B4_Iter1K.yml') 195 | args = parser.parse_args() 196 | 197 | root_path = osp.abspath(osp.join(__file__, osp.pardir)) 198 | train(root_path, args) 199 | --------------------------------------------------------------------------------