├── .gitignore
├── LICENSE.txt
├── README.md
├── assets
│   ├── application.png
│   ├── examples
│   │   ├── 1-newton.jpg
│   │   ├── 1-output-1.png
│   │   ├── 2-output-1.png
│   │   ├── 2-stylegan2-ffhq-0100.png
│   │   ├── 2-stylegan2-ffhq-0293.png
│   │   ├── 3-output-1.png
│   │   ├── 3-output-2.png
│   │   ├── 3-output-3.png
│   │   ├── 3-output-4.png
│   │   ├── 3-style-1.png
│   │   ├── 3-style-2.jpg
│   │   ├── 3-style-3.jpg
│   │   ├── 3-stylegan2-ffhq-0293.png
│   │   └── 3-stylegan2-ffhq-0381.png
│   ├── framework.png
│   └── highlight.png
├── gradio_app.py
├── requirements.txt
└── uniportrait
    ├── __init__.py
    ├── curricular_face
    │   ├── __init__.py
    │   ├── backbone
    │   │   ├── __init__.py
    │   │   ├── common.py
    │   │   ├── model_irse.py
    │   │   └── model_resnet.py
    │   └── inference.py
    ├── inversion.py
    ├── resampler.py
    ├── uniportrait_attention_processor.py
    └── uniportrait_pipeline.py
/.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | .DS_Store 3 | *.dat 4 | *.mat 5 | 6 | training/ 7 | lightning_logs/ 8 | image_log/ 9 | 10 | *.png 11 | *.jpg 12 | *.jpeg 13 | *.webp 14 | 15 | *.pth 16 | *.pt 17 | *.ckpt 18 | *.safetensors 19 | 20 | # Byte-compiled / optimized / DLL files 21 | __pycache__/ 22 | *.py[cod] 23 | *$py.class 24 | 25 | # C extensions 26 | *.so 27 | 28 | # Distribution / packaging 29 | .Python 30 | build/ 31 | develop-eggs/ 32 | dist/ 33 | downloads/ 34 | eggs/ 35 | .eggs/ 36 | lib/ 37 | lib64/ 38 | parts/ 39 | sdist/ 40 | var/ 41 | wheels/ 42 | pip-wheel-metadata/ 43 | share/python-wheels/ 44 | *.egg-info/ 45 | .installed.cfg 46 | *.egg 47 | MANIFEST 48 | 49 | # PyInstaller 50 | # Usually these files are written by a python script from a template 51 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 52 | *.manifest 53 | *.spec 54 | 55 | # Installer logs 56 | pip-log.txt 57 | pip-delete-this-directory.txt 58 | 59 | # Unit test / coverage reports 60 | htmlcov/ 61 | .tox/ 62 | .nox/ 63 | .coverage 64 | .coverage.* 65 | .cache 66 | nosetests.xml 67 | coverage.xml 68 | *.cover 69 | *.py,cover 70 | .hypothesis/ 71 | .pytest_cache/ 72 | 73 | # Translations 74 | *.mo 75 | *.pot 76 | 77 | # Django stuff: 78 | *.log 79 | local_settings.py 80 | db.sqlite3 81 | db.sqlite3-journal 82 | 83 | # Flask stuff: 84 | instance/ 85 | .webassets-cache 86 | 87 | # Scrapy stuff: 88 | .scrapy 89 | 90 | # Sphinx documentation 91 | docs/_build/ 92 | 93 | # PyBuilder 94 | target/ 95 | 96 | # Jupyter Notebook 97 | .ipynb_checkpoints 98 | 99 | # IPython 100 | profile_default/ 101 | ipython_config.py 102 | 103 | # pyenv 104 | .python-version 105 | 106 | # pipenv 107 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 108 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 109 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 110 | # install all needed dependencies. 111 | #Pipfile.lock 112 | 113 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | UniPortrait: A Unified Framework for Identity-Preserving Single- and Multi-Human Image Personalization
3 | 4 | 5 | 6 | 7 | 8 | 
9 | 10 | 11 | 12 | UniPortrait is an innovative human image personalization framework. It customizes single- and multi-ID images in a 13 | unified manner, providing high-fidelity identity preservation, extensive facial editability, free-form text descriptions, 14 | and no requirement for a predetermined layout. 15 | 16 | --- 17 | 18 | ## Release 19 | 20 | - [2025/05/01] 🔥 We release the code and demo for the `FLUX.1-dev` version of [AnyStory](https://github.com/junjiehe96/AnyStory), a unified approach to general subject personalization. 21 | - [2024/10/18] 🔥 We release the inference code and demo, which simply 22 | integrates [ControlNet](https://github.com/lllyasviel/ControlNet) 23 | , [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter), 24 | and [StyleAligned](https://github.com/google/style-aligned). The weights for this version are consistent with the 25 | Hugging Face Space and the experiments in the paper. We are now working on generalizing our method to more advanced 26 | diffusion models and more general custom concepts. Please stay tuned! 27 | - [2024/08/12] 🔥 We release the [technical report](https://arxiv.org/abs/2408.05939) 28 | , [project page](https://aigcdesigngroup.github.io/UniPortrait-Page/), 29 | and [HuggingFace demo](https://huggingface.co/spaces/Junjie96/UniPortrait) 🤗! 30 | 31 | ## Quickstart 32 | 33 | ```shell 34 | # Clone the repository 35 | git clone https://github.com/junjiehe96/UniPortrait.git 36 | 37 | # Install requirements 38 | cd UniPortrait 39 | pip install -r requirements.txt 40 | 41 | # Download the UniPortrait models 42 | git lfs install 43 | git clone https://huggingface.co/Junjie96/UniPortrait models 44 | # Download the IP-Adapter models 45 | # Note: we recommend downloading the required files manually; not all IP-Adapter models are needed. 46 | git clone https://huggingface.co/h94/IP-Adapter models/IP-Adapter 47 | 48 | # Then you can launch the gradio app 49 | python gradio_app.py 50 | ``` 51 | 52 | ## Applications 53 | 54 | 55 | 56 | ## **Acknowledgements** 57 | 58 | This code is built on several excellent repos, including [diffusers](https://github.com/huggingface/diffusers), [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter) and [StyleAligned](https://github.com/google/style-aligned). We highly appreciate their great work! 59 | 60 | ## Cite 61 | 62 | If you find UniPortrait useful for your research and applications, please cite us using this BibTeX: 63 | 64 | ```bibtex 65 | @article{he2024uniportrait, 66 | title={UniPortrait: A Unified Framework for Identity-Preserving Single-and Multi-Human Image Personalization}, 67 | author={He, Junjie and Geng, Yifeng and Bo, Liefeng}, 68 | journal={arXiv preprint arXiv:2408.05939}, 69 | year={2024} 70 | } 71 | ``` 72 | 73 | For any questions, please feel free to open an issue or contact us via hejunjie1103@gmail.com. 
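## Programmatic usage (sketch)

Beyond the Gradio demo, the pipeline can also be scripted directly. The snippet below is a minimal sketch distilled from `gradio_app.py`, not an official API example: it assumes the checkpoints were downloaded into `models/` as in the Quickstart, and `face.jpg`, `output.png`, the prompt, and the seed are placeholders. The image passed in `refs` should be an aligned face crop (the demo detects and aligns faces with insightface before passing them in).

```python
import torch
from PIL import Image
from diffusers import AutoencoderKL, ControlNetModel, DDIMScheduler, StableDiffusionControlNetPipeline

from uniportrait.uniportrait_attention_processor import attn_args
from uniportrait.uniportrait_pipeline import UniPortraitPipeline

device, dtype = "cuda", torch.float16

# Base SD-1.5 components (same checkpoints as the demo)
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_openpose", torch_dtype=dtype)
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(dtype=dtype)
scheduler = DDIMScheduler(num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012,
                          beta_schedule="scaled_linear", clip_sample=False,
                          set_alpha_to_one=False, steps_offset=1)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE", controlnet=[controlnet],
    torch_dtype=dtype, scheduler=scheduler, vae=vae)

# Wrap the SD pipeline with the UniPortrait adapters
uniportrait = UniPortraitPipeline(
    pipe, "models/IP-Adapter/models/image_encoder",
    ip_ckpt="models/IP-Adapter/models/ip-adapter_sd15.bin",
    face_backbone_ckpt="models/glint360k_curricular_face_r101_backbone.bin",
    uniportrait_faceid_ckpt="models/uniportrait-faceid_sd15.bin",
    uniportrait_router_ckpt="models/uniportrait-router_sd15.bin",
    device=device, torch_dtype=dtype)

# Configure the attention processors for a single identity with text-only conditioning
attn_args.reset()
attn_args.lora_scale = 1.0    # single-ID LoRA
attn_args.faceid_scale = 0.7  # recommended range 0.5~0.7
attn_args.num_faceids = 1

# "face.jpg" is a placeholder for an aligned face crop; the ControlNet input is zeroed out,
# mirroring the text-to-single-ID path of the demo
images = uniportrait.generate(
    prompt=["a young man reading a book in a cafe"], negative_prompt=["nsfw"],
    cond_faceids=[{"refs": [Image.open("face.jpg")]}],
    face_structure_scale=0.1, seed=42, guidance_scale=7.5, num_inference_steps=25,
    image=[torch.zeros([1, 3, 512, 512])], controlnet_conditioning_scale=[0.0])
images[0].save("output.png")
```

The scale settings above mirror the demo defaults; raising `face_structure_scale` (up to about 0.4) transfers more of the reference's facial structure at some cost to prompt adherence.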
74 | -------------------------------------------------------------------------------- /assets/application.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/assets/application.png -------------------------------------------------------------------------------- /assets/examples/1-newton.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/assets/examples/1-newton.jpg -------------------------------------------------------------------------------- /assets/examples/1-output-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/assets/examples/1-output-1.png -------------------------------------------------------------------------------- /assets/examples/2-output-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/assets/examples/2-output-1.png -------------------------------------------------------------------------------- /assets/examples/2-stylegan2-ffhq-0100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/assets/examples/2-stylegan2-ffhq-0100.png -------------------------------------------------------------------------------- /assets/examples/2-stylegan2-ffhq-0293.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/assets/examples/2-stylegan2-ffhq-0293.png -------------------------------------------------------------------------------- /assets/examples/3-output-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/assets/examples/3-output-1.png -------------------------------------------------------------------------------- /assets/examples/3-output-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/assets/examples/3-output-2.png -------------------------------------------------------------------------------- /assets/examples/3-output-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/assets/examples/3-output-3.png -------------------------------------------------------------------------------- /assets/examples/3-output-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/assets/examples/3-output-4.png -------------------------------------------------------------------------------- /assets/examples/3-style-1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/assets/examples/3-style-1.png -------------------------------------------------------------------------------- /assets/examples/3-style-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/assets/examples/3-style-2.jpg -------------------------------------------------------------------------------- /assets/examples/3-style-3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/assets/examples/3-style-3.jpg -------------------------------------------------------------------------------- /assets/examples/3-stylegan2-ffhq-0293.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/assets/examples/3-stylegan2-ffhq-0293.png -------------------------------------------------------------------------------- /assets/examples/3-stylegan2-ffhq-0381.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/assets/examples/3-stylegan2-ffhq-0381.png -------------------------------------------------------------------------------- /assets/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/assets/framework.png -------------------------------------------------------------------------------- /assets/highlight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/assets/highlight.png -------------------------------------------------------------------------------- /gradio_app.py: -------------------------------------------------------------------------------- 1 | import os 2 | from io import BytesIO 3 | 4 | import cv2 5 | import gradio as gr 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | from diffusers import DDIMScheduler, AutoencoderKL, ControlNetModel, StableDiffusionControlNetPipeline 10 | from insightface.app import FaceAnalysis 11 | from insightface.utils import face_align 12 | 13 | from uniportrait import inversion 14 | from uniportrait.uniportrait_attention_processor import attn_args 15 | from uniportrait.uniportrait_pipeline import UniPortraitPipeline 16 | 17 | port = 7860 18 | 19 | device = "cuda" 20 | torch_dtype = torch.float16 21 | 22 | # base 23 | base_model_path = "SG161222/Realistic_Vision_V5.1_noVAE" 24 | vae_model_path = "stabilityai/sd-vae-ft-mse" 25 | controlnet_pose_ckpt = "lllyasviel/control_v11p_sd15_openpose" 26 | # specific 27 | image_encoder_path = "models/IP-Adapter/models/image_encoder" 28 | ip_ckpt = "models/IP-Adapter/models/ip-adapter_sd15.bin" 29 | face_backbone_ckpt = "models/glint360k_curricular_face_r101_backbone.bin" 30 | uniportrait_faceid_ckpt = "models/uniportrait-faceid_sd15.bin" 31 | uniportrait_router_ckpt = "models/uniportrait-router_sd15.bin" 32 | 33 | # load controlnet 34 | pose_controlnet = ControlNetModel.from_pretrained(controlnet_pose_ckpt, torch_dtype=torch_dtype) 35 
| 36 | # load SD pipeline 37 | noise_scheduler = DDIMScheduler( 38 | num_train_timesteps=1000, 39 | beta_start=0.00085, 40 | beta_end=0.012, 41 | beta_schedule="scaled_linear", 42 | clip_sample=False, 43 | set_alpha_to_one=False, 44 | steps_offset=1, 45 | ) 46 | vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch_dtype) 47 | pipe = StableDiffusionControlNetPipeline.from_pretrained( 48 | base_model_path, 49 | controlnet=[pose_controlnet], 50 | torch_dtype=torch_dtype, 51 | scheduler=noise_scheduler, 52 | vae=vae, 53 | # feature_extractor=None, 54 | # safety_checker=None, 55 | ) 56 | 57 | # load uniportrait pipeline 58 | uniportrait_pipeline = UniPortraitPipeline(pipe, image_encoder_path, ip_ckpt=ip_ckpt, 59 | face_backbone_ckpt=face_backbone_ckpt, 60 | uniportrait_faceid_ckpt=uniportrait_faceid_ckpt, 61 | uniportrait_router_ckpt=uniportrait_router_ckpt, 62 | device=device, torch_dtype=torch_dtype) 63 | 64 | # load face detection assets 65 | face_app = FaceAnalysis(providers=['CUDAExecutionProvider'], allowed_modules=["detection"]) 66 | face_app.prepare(ctx_id=0, det_size=(640, 640)) 67 | 68 | 69 | def pad_np_bgr_image(np_image, scale=1.25): 70 | assert scale >= 1.0, "scale should be >= 1.0" 71 | pad_scale = scale - 1.0 72 | h, w = np_image.shape[:2] 73 | top = bottom = int(h * pad_scale) 74 | left = right = int(w * pad_scale) 75 | ret = cv2.copyMakeBorder(np_image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(128, 128, 128)) 76 | return ret, (left, top) 77 | 78 | 79 | def process_faceid_image(pil_faceid_image): 80 | np_faceid_image = np.array(pil_faceid_image.convert("RGB")) 81 | img = cv2.cvtColor(np_faceid_image, cv2.COLOR_RGB2BGR) 82 | faces = face_app.get(img) # bgr 83 | if len(faces) == 0: 84 | # padding, try again 85 | _h, _w = img.shape[:2] 86 | _img, left_top_coord = pad_np_bgr_image(img) 87 | faces = face_app.get(_img) 88 | if len(faces) == 0: 89 | gr.Info("Warning: No face detected in the image. 
Continue processing...") 90 | 91 | min_coord = np.array([0, 0]) 92 | max_coord = np.array([_w, _h]) 93 | sub_coord = np.array([left_top_coord[0], left_top_coord[1]]) 94 | for face in faces: 95 | face.bbox = np.minimum(np.maximum(face.bbox.reshape(-1, 2) - sub_coord, min_coord), max_coord).reshape(4) 96 | face.kps = face.kps - sub_coord 97 | 98 | faces = sorted(faces, key=lambda x: abs((x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1])), reverse=True) 99 | faceid_face = faces[0] 100 | norm_face = face_align.norm_crop(img, landmark=faceid_face.kps, image_size=224) 101 | pil_faceid_align_image = Image.fromarray(cv2.cvtColor(norm_face, cv2.COLOR_BGR2RGB)) 102 | 103 | return pil_faceid_align_image 104 | 105 | 106 | def prepare_single_faceid_cond_kwargs(pil_faceid_image=None, pil_faceid_supp_images=None, 107 | pil_faceid_mix_images=None, mix_scales=None): 108 | pil_faceid_align_images = [] 109 | if pil_faceid_image: 110 | pil_faceid_align_images.append(process_faceid_image(pil_faceid_image)) 111 | if pil_faceid_supp_images and len(pil_faceid_supp_images) > 0: 112 | for pil_faceid_supp_image in pil_faceid_supp_images: 113 | if isinstance(pil_faceid_supp_image, Image.Image): 114 | pil_faceid_align_images.append(process_faceid_image(pil_faceid_supp_image)) 115 | else: 116 | pil_faceid_align_images.append( 117 | process_faceid_image(Image.open(BytesIO(pil_faceid_supp_image))) 118 | ) 119 | 120 | mix_refs = [] 121 | mix_ref_scales = [] 122 | if pil_faceid_mix_images: 123 | for pil_faceid_mix_image, mix_scale in zip(pil_faceid_mix_images, mix_scales): 124 | if pil_faceid_mix_image: 125 | mix_refs.append(process_faceid_image(pil_faceid_mix_image)) 126 | mix_ref_scales.append(mix_scale) 127 | 128 | single_faceid_cond_kwargs = None 129 | if len(pil_faceid_align_images) > 0: 130 | single_faceid_cond_kwargs = { 131 | "refs": pil_faceid_align_images 132 | } 133 | if len(mix_refs) > 0: 134 | single_faceid_cond_kwargs["mix_refs"] = mix_refs 135 | single_faceid_cond_kwargs["mix_scales"] = mix_ref_scales 136 | 137 | return single_faceid_cond_kwargs 138 | 139 | 140 | def text_to_single_id_generation_process( 141 | pil_faceid_image=None, pil_faceid_supp_images=None, 142 | pil_faceid_mix_image_1=None, mix_scale_1=0.0, 143 | pil_faceid_mix_image_2=None, mix_scale_2=0.0, 144 | faceid_scale=0.0, face_structure_scale=0.0, 145 | prompt="", negative_prompt="", 146 | num_samples=1, seed=-1, 147 | image_resolution="512x512", 148 | inference_steps=25, 149 | ): 150 | if seed == -1: 151 | seed = None 152 | 153 | single_faceid_cond_kwargs = prepare_single_faceid_cond_kwargs(pil_faceid_image, 154 | pil_faceid_supp_images, 155 | [pil_faceid_mix_image_1, pil_faceid_mix_image_2], 156 | [mix_scale_1, mix_scale_2]) 157 | 158 | cond_faceids = [single_faceid_cond_kwargs] if single_faceid_cond_kwargs else [] 159 | 160 | # reset attn args 161 | attn_args.reset() 162 | # set faceid condition 163 | attn_args.lora_scale = 1.0 if len(cond_faceids) == 1 else 0.0 # single-faceid lora 164 | attn_args.multi_id_lora_scale = 1.0 if len(cond_faceids) > 1 else 0.0 # multi-faceid lora 165 | attn_args.faceid_scale = faceid_scale if len(cond_faceids) > 0 else 0.0 166 | attn_args.num_faceids = len(cond_faceids) 167 | print(attn_args) 168 | 169 | h, w = int(image_resolution.split("x")[0]), int(image_resolution.split("x")[1]) 170 | prompt = [prompt] * num_samples 171 | negative_prompt = [negative_prompt] * num_samples 172 | images = uniportrait_pipeline.generate(prompt=prompt, negative_prompt=negative_prompt, 173 | cond_faceids=cond_faceids, 
face_structure_scale=face_structure_scale, 174 | seed=seed, guidance_scale=7.5, 175 | num_inference_steps=inference_steps, 176 | image=[torch.zeros([1, 3, h, w])], 177 | controlnet_conditioning_scale=[0.0]) 178 | final_out = [] 179 | for pil_image in images: 180 | final_out.append(pil_image) 181 | 182 | for single_faceid_cond_kwargs in cond_faceids: 183 | final_out.extend(single_faceid_cond_kwargs["refs"]) 184 | if "mix_refs" in single_faceid_cond_kwargs: 185 | final_out.extend(single_faceid_cond_kwargs["mix_refs"]) 186 | 187 | return final_out 188 | 189 | 190 | def text_to_multi_id_generation_process( 191 | pil_faceid_image_1=None, pil_faceid_supp_images_1=None, 192 | pil_faceid_mix_image_1_1=None, mix_scale_1_1=0.0, 193 | pil_faceid_mix_image_1_2=None, mix_scale_1_2=0.0, 194 | pil_faceid_image_2=None, pil_faceid_supp_images_2=None, 195 | pil_faceid_mix_image_2_1=None, mix_scale_2_1=0.0, 196 | pil_faceid_mix_image_2_2=None, mix_scale_2_2=0.0, 197 | faceid_scale=0.0, face_structure_scale=0.0, 198 | prompt="", negative_prompt="", 199 | num_samples=1, seed=-1, 200 | image_resolution="512x512", 201 | inference_steps=25, 202 | ): 203 | if seed == -1: 204 | seed = None 205 | 206 | faceid_cond_kwargs_1 = prepare_single_faceid_cond_kwargs(pil_faceid_image_1, 207 | pil_faceid_supp_images_1, 208 | [pil_faceid_mix_image_1_1, 209 | pil_faceid_mix_image_1_2], 210 | [mix_scale_1_1, mix_scale_1_2]) 211 | faceid_cond_kwargs_2 = prepare_single_faceid_cond_kwargs(pil_faceid_image_2, 212 | pil_faceid_supp_images_2, 213 | [pil_faceid_mix_image_2_1, 214 | pil_faceid_mix_image_2_2], 215 | [mix_scale_2_1, mix_scale_2_2]) 216 | cond_faceids = [] 217 | if faceid_cond_kwargs_1 is not None: 218 | cond_faceids.append(faceid_cond_kwargs_1) 219 | if faceid_cond_kwargs_2 is not None: 220 | cond_faceids.append(faceid_cond_kwargs_2) 221 | 222 | # reset attn args 223 | attn_args.reset() 224 | # set faceid condition 225 | attn_args.lora_scale = 1.0 if len(cond_faceids) == 1 else 0.0 # single-faceid lora 226 | attn_args.multi_id_lora_scale = 1.0 if len(cond_faceids) > 1 else 0.0 # multi-faceid lora 227 | attn_args.faceid_scale = faceid_scale if len(cond_faceids) > 0 else 0.0 228 | attn_args.num_faceids = len(cond_faceids) 229 | print(attn_args) 230 | 231 | h, w = int(image_resolution.split("x")[0]), int(image_resolution.split("x")[1]) 232 | prompt = [prompt] * num_samples 233 | negative_prompt = [negative_prompt] * num_samples 234 | images = uniportrait_pipeline.generate(prompt=prompt, negative_prompt=negative_prompt, 235 | cond_faceids=cond_faceids, face_structure_scale=face_structure_scale, 236 | seed=seed, guidance_scale=7.5, 237 | num_inference_steps=inference_steps, 238 | image=[torch.zeros([1, 3, h, w])], 239 | controlnet_conditioning_scale=[0.0]) 240 | 241 | final_out = [] 242 | for pil_image in images: 243 | final_out.append(pil_image) 244 | 245 | for single_faceid_cond_kwargs in cond_faceids: 246 | final_out.extend(single_faceid_cond_kwargs["refs"]) 247 | if "mix_refs" in single_faceid_cond_kwargs: 248 | final_out.extend(single_faceid_cond_kwargs["mix_refs"]) 249 | 250 | return final_out 251 | 252 | 253 | def image_to_single_id_generation_process( 254 | pil_faceid_image=None, pil_faceid_supp_images=None, 255 | pil_faceid_mix_image_1=None, mix_scale_1=0.0, 256 | pil_faceid_mix_image_2=None, mix_scale_2=0.0, 257 | faceid_scale=0.0, face_structure_scale=0.0, 258 | pil_ip_image=None, ip_scale=1.0, 259 | num_samples=1, seed=-1, image_resolution="768x512", 260 | inference_steps=25, 261 | ): 262 | if seed == -1: 263 | 
seed = None 264 | 265 | single_faceid_cond_kwargs = prepare_single_faceid_cond_kwargs(pil_faceid_image, 266 | pil_faceid_supp_images, 267 | [pil_faceid_mix_image_1, pil_faceid_mix_image_2], 268 | [mix_scale_1, mix_scale_2]) 269 | 270 | cond_faceids = [single_faceid_cond_kwargs] if single_faceid_cond_kwargs else [] 271 | 272 | h, w = int(image_resolution.split("x")[0]), int(image_resolution.split("x")[1]) 273 | 274 | # Image Prompt and Style Aligned 275 | if pil_ip_image is None: 276 | gr.Error("Please upload a reference image") 277 | attn_args.reset() 278 | pil_ip_image = pil_ip_image.convert("RGB").resize((w, h)) 279 | zts = inversion.ddim_inversion(uniportrait_pipeline.pipe, np.array(pil_ip_image), "", inference_steps, 2) 280 | zT, inversion_callback = inversion.make_inversion_callback(zts, offset=0) 281 | 282 | # reset attn args 283 | attn_args.reset() 284 | # set ip condition 285 | attn_args.ip_scale = ip_scale if pil_ip_image else 0.0 286 | # set faceid condition 287 | attn_args.lora_scale = 1.0 if len(cond_faceids) == 1 else 0.0 # lora for single faceid 288 | attn_args.multi_id_lora_scale = 1.0 if len(cond_faceids) > 1 else 0.0 # lora for >1 faceids 289 | attn_args.faceid_scale = faceid_scale if len(cond_faceids) > 0 else 0.0 290 | attn_args.num_faceids = len(cond_faceids) 291 | # set shared self-attn 292 | attn_args.enable_share_attn = True 293 | attn_args.shared_score_shift = -0.5 294 | print(attn_args) 295 | 296 | prompt = [""] * (1 + num_samples) 297 | negative_prompt = [""] * (1 + num_samples) 298 | images = uniportrait_pipeline.generate(prompt=prompt, negative_prompt=negative_prompt, 299 | pil_ip_image=pil_ip_image, 300 | cond_faceids=cond_faceids, face_structure_scale=face_structure_scale, 301 | seed=seed, guidance_scale=7.5, 302 | num_inference_steps=inference_steps, 303 | image=[torch.zeros([1, 3, h, w])], 304 | controlnet_conditioning_scale=[0.0], 305 | zT=zT, callback_on_step_end=inversion_callback) 306 | images = images[1:] 307 | 308 | final_out = [] 309 | for pil_image in images: 310 | final_out.append(pil_image) 311 | 312 | for single_faceid_cond_kwargs in cond_faceids: 313 | final_out.extend(single_faceid_cond_kwargs["refs"]) 314 | if "mix_refs" in single_faceid_cond_kwargs: 315 | final_out.extend(single_faceid_cond_kwargs["mix_refs"]) 316 | 317 | return final_out 318 | 319 | 320 | def text_to_single_id_generation_block(): 321 | gr.Markdown("## Text-to-Single-ID Generation") 322 | gr.HTML(text_to_single_id_description) 323 | gr.HTML(text_to_single_id_tips) 324 | with gr.Row(): 325 | with gr.Column(scale=1, min_width=100): 326 | prompt = gr.Textbox(value="", label='Prompt', lines=2) 327 | negative_prompt = gr.Textbox(value="nsfw", label='Negative Prompt') 328 | 329 | run_button = gr.Button(value="Run") 330 | with gr.Accordion("Options", open=True): 331 | image_resolution = gr.Dropdown(choices=["768x512", "512x512", "512x768"], value="512x512", 332 | label="Image Resolution (HxW)") 333 | seed = gr.Slider(label="Seed (-1 indicates random)", minimum=-1, maximum=2147483647, step=1, 334 | value=2147483647) 335 | num_samples = gr.Slider(label="Images", minimum=1, maximum=4, value=2, step=1) 336 | inference_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=25, step=1, visible=False) 337 | 338 | faceid_scale = gr.Slider(label="Face ID Scale", minimum=0.0, maximum=1.0, step=0.01, value=0.7) 339 | face_structure_scale = gr.Slider(label="Face Structure Scale", minimum=0.0, maximum=1.0, 340 | step=0.01, value=0.1) 341 | 342 | with gr.Column(scale=2, min_width=100): 343 
| with gr.Row(equal_height=False): 344 | pil_faceid_image = gr.Image(type="pil", label="ID Image") 345 | with gr.Accordion("ID Supplements", open=True): 346 | with gr.Row(): 347 | pil_faceid_supp_images = gr.File(file_count="multiple", file_types=["image"], 348 | type="binary", label="Additional ID Images") 349 | with gr.Row(): 350 | with gr.Column(scale=1, min_width=100): 351 | pil_faceid_mix_image_1 = gr.Image(type="pil", label="Mix ID 1") 352 | mix_scale_1 = gr.Slider(label="Mix Scale 1", minimum=0.0, maximum=1.0, step=0.01, value=0.0) 353 | with gr.Column(scale=1, min_width=100): 354 | pil_faceid_mix_image_2 = gr.Image(type="pil", label="Mix ID 2") 355 | mix_scale_2 = gr.Slider(label="Mix Scale 2", minimum=0.0, maximum=1.0, step=0.01, value=0.0) 356 | 357 | with gr.Row(): 358 | example_output = gr.Image(type="pil", label="(Example Output)", visible=False) 359 | result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", columns=4, preview=True, 360 | format="png") 361 | with gr.Row(): 362 | examples = [ 363 | [ 364 | "A young man with short black hair, wearing a black hoodie with a hood, was paired with a blue denim jacket with yellow details.", 365 | "assets/examples/1-newton.jpg", 366 | "assets/examples/1-output-1.png", 367 | ], 368 | ] 369 | gr.Examples( 370 | label="Examples", 371 | examples=examples, 372 | fn=lambda x, y, z: (x, y), 373 | inputs=[prompt, pil_faceid_image, example_output], 374 | outputs=[prompt, pil_faceid_image] 375 | ) 376 | ips = [ 377 | pil_faceid_image, pil_faceid_supp_images, 378 | pil_faceid_mix_image_1, mix_scale_1, 379 | pil_faceid_mix_image_2, mix_scale_2, 380 | faceid_scale, face_structure_scale, 381 | prompt, negative_prompt, 382 | num_samples, seed, 383 | image_resolution, 384 | inference_steps, 385 | ] 386 | run_button.click(fn=text_to_single_id_generation_process, inputs=ips, outputs=[result_gallery]) 387 | 388 | 389 | def text_to_multi_id_generation_block(): 390 | gr.Markdown("## Text-to-Multi-ID Generation") 391 | gr.HTML(text_to_multi_id_description) 392 | gr.HTML(text_to_multi_id_tips) 393 | with gr.Row(): 394 | with gr.Column(scale=1, min_width=100): 395 | prompt = gr.Textbox(value="", label='Prompt', lines=2) 396 | negative_prompt = gr.Textbox(value="nsfw", label='Negative Prompt') 397 | run_button = gr.Button(value="Run") 398 | with gr.Accordion("Options", open=True): 399 | image_resolution = gr.Dropdown(choices=["768x512", "512x512", "512x768"], value="512x512", 400 | label="Image Resolution (HxW)") 401 | seed = gr.Slider(label="Seed (-1 indicates random)", minimum=-1, maximum=2147483647, step=1, 402 | value=2147483647) 403 | num_samples = gr.Slider(label="Images", minimum=1, maximum=4, value=2, step=1) 404 | inference_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=25, step=1, visible=False) 405 | 406 | faceid_scale = gr.Slider(label="Face ID Scale", minimum=0.0, maximum=1.0, step=0.01, value=0.7) 407 | face_structure_scale = gr.Slider(label="Face Structure Scale", minimum=0.0, maximum=1.0, 408 | step=0.01, value=0.3) 409 | 410 | with gr.Column(scale=2, min_width=100): 411 | with gr.Row(equal_height=False): 412 | with gr.Column(scale=1, min_width=100): 413 | pil_faceid_image_1 = gr.Image(type="pil", label="First ID") 414 | with gr.Accordion("First ID Supplements", open=False): 415 | with gr.Row(): 416 | pil_faceid_supp_images_1 = gr.File(file_count="multiple", file_types=["image"], 417 | type="binary", label="Additional ID Images") 418 | with gr.Row(): 419 | with gr.Column(scale=1, min_width=100): 420 | 
pil_faceid_mix_image_1_1 = gr.Image(type="pil", label="Mix ID 1") 421 | mix_scale_1_1 = gr.Slider(label="Mix Scale 1", minimum=0.0, maximum=1.0, step=0.01, 422 | value=0.0) 423 | with gr.Column(scale=1, min_width=100): 424 | pil_faceid_mix_image_1_2 = gr.Image(type="pil", label="Mix ID 2") 425 | mix_scale_1_2 = gr.Slider(label="Mix Scale 2", minimum=0.0, maximum=1.0, step=0.01, 426 | value=0.0) 427 | with gr.Column(scale=1, min_width=100): 428 | pil_faceid_image_2 = gr.Image(type="pil", label="Second ID") 429 | with gr.Accordion("Second ID Supplements", open=False): 430 | with gr.Row(): 431 | pil_faceid_supp_images_2 = gr.File(file_count="multiple", file_types=["image"], 432 | type="binary", label="Additional ID Images") 433 | with gr.Row(): 434 | with gr.Column(scale=1, min_width=100): 435 | pil_faceid_mix_image_2_1 = gr.Image(type="pil", label="Mix ID 1") 436 | mix_scale_2_1 = gr.Slider(label="Mix Scale 1", minimum=0.0, maximum=1.0, step=0.01, 437 | value=0.0) 438 | with gr.Column(scale=1, min_width=100): 439 | pil_faceid_mix_image_2_2 = gr.Image(type="pil", label="Mix ID 2") 440 | mix_scale_2_2 = gr.Slider(label="Mix Scale 2", minimum=0.0, maximum=1.0, step=0.01, 441 | value=0.0) 442 | 443 | with gr.Row(): 444 | example_output = gr.Image(type="pil", label="(Example Output)", visible=False) 445 | result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", columns=4, preview=True, 446 | format="png") 447 | with gr.Row(): 448 | examples = [ 449 | [ 450 | "The two female models, fair-skinned, wore a white V-neck short-sleeved top with a light smile on the corners of their mouths. The background was off-white.", 451 | "assets/examples/2-stylegan2-ffhq-0100.png", 452 | "assets/examples/2-stylegan2-ffhq-0293.png", 453 | "assets/examples/2-output-1.png", 454 | ], 455 | ] 456 | gr.Examples( 457 | label="Examples", 458 | examples=examples, 459 | inputs=[prompt, pil_faceid_image_1, pil_faceid_image_2, example_output], 460 | ) 461 | ips = [ 462 | pil_faceid_image_1, pil_faceid_supp_images_1, 463 | pil_faceid_mix_image_1_1, mix_scale_1_1, 464 | pil_faceid_mix_image_1_2, mix_scale_1_2, 465 | pil_faceid_image_2, pil_faceid_supp_images_2, 466 | pil_faceid_mix_image_2_1, mix_scale_2_1, 467 | pil_faceid_mix_image_2_2, mix_scale_2_2, 468 | faceid_scale, face_structure_scale, 469 | prompt, negative_prompt, 470 | num_samples, seed, 471 | image_resolution, 472 | inference_steps, 473 | ] 474 | run_button.click(fn=text_to_multi_id_generation_process, inputs=ips, outputs=[result_gallery]) 475 | 476 | 477 | def image_to_single_id_generation_block(): 478 | gr.Markdown("## Image-to-Single-ID Generation") 479 | gr.HTML(image_to_single_id_description) 480 | gr.HTML(image_to_single_id_tips) 481 | with gr.Row(): 482 | with gr.Column(scale=1, min_width=100): 483 | run_button = gr.Button(value="Run") 484 | seed = gr.Slider(label="Seed (-1 indicates random)", minimum=-1, maximum=2147483647, step=1, 485 | value=2147483647) 486 | num_samples = gr.Slider(label="Images", minimum=1, maximum=4, value=2, step=1) 487 | image_resolution = gr.Dropdown(choices=["768x512", "512x512", "512x768"], value="512x512", 488 | label="Image Resolution (HxW)") 489 | inference_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=25, step=1, visible=False) 490 | 491 | ip_scale = gr.Slider(label="Reference Scale", minimum=0.0, maximum=1.0, step=0.01, value=0.7) 492 | faceid_scale = gr.Slider(label="Face ID Scale", minimum=0.0, maximum=1.0, step=0.01, value=0.7) 493 | face_structure_scale = gr.Slider(label="Face 
Structure Scale", minimum=0.0, maximum=1.0, step=0.01, 494 | value=0.3) 495 | 496 | with gr.Column(scale=3, min_width=100): 497 | with gr.Row(equal_height=False): 498 | pil_ip_image = gr.Image(type="pil", label="Portrait Reference") 499 | pil_faceid_image = gr.Image(type="pil", label="ID Image") 500 | with gr.Accordion("ID Supplements", open=True): 501 | with gr.Row(): 502 | pil_faceid_supp_images = gr.File(file_count="multiple", file_types=["image"], 503 | type="binary", label="Additional ID Images") 504 | with gr.Row(): 505 | with gr.Column(scale=1, min_width=100): 506 | pil_faceid_mix_image_1 = gr.Image(type="pil", label="Mix ID 1") 507 | mix_scale_1 = gr.Slider(label="Mix Scale 1", minimum=0.0, maximum=1.0, step=0.01, value=0.0) 508 | with gr.Column(scale=1, min_width=100): 509 | pil_faceid_mix_image_2 = gr.Image(type="pil", label="Mix ID 2") 510 | mix_scale_2 = gr.Slider(label="Mix Scale 2", minimum=0.0, maximum=1.0, step=0.01, value=0.0) 511 | with gr.Row(): 512 | with gr.Column(scale=3, min_width=100): 513 | example_output = gr.Image(type="pil", label="(Example Output)", visible=False) 514 | result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", columns=4, 515 | preview=True, format="png") 516 | with gr.Row(): 517 | examples = [ 518 | [ 519 | "assets/examples/3-style-1.png", 520 | "assets/examples/3-stylegan2-ffhq-0293.png", 521 | 0.7, 522 | 0.3, 523 | "assets/examples/3-output-1.png", 524 | ], 525 | [ 526 | "assets/examples/3-style-1.png", 527 | "assets/examples/3-stylegan2-ffhq-0293.png", 528 | 0.6, 529 | 0.0, 530 | "assets/examples/3-output-2.png", 531 | ], 532 | [ 533 | "assets/examples/3-style-2.jpg", 534 | "assets/examples/3-stylegan2-ffhq-0381.png", 535 | 0.7, 536 | 0.3, 537 | "assets/examples/3-output-3.png", 538 | ], 539 | [ 540 | "assets/examples/3-style-3.jpg", 541 | "assets/examples/3-stylegan2-ffhq-0381.png", 542 | 0.6, 543 | 0.0, 544 | "assets/examples/3-output-4.png", 545 | ], 546 | ] 547 | gr.Examples( 548 | label="Examples", 549 | examples=examples, 550 | fn=lambda x, y, z, w, v: (x, y, z, w), 551 | inputs=[pil_ip_image, pil_faceid_image, faceid_scale, face_structure_scale, example_output], 552 | outputs=[pil_ip_image, pil_faceid_image, faceid_scale, face_structure_scale] 553 | ) 554 | ips = [ 555 | pil_faceid_image, pil_faceid_supp_images, 556 | pil_faceid_mix_image_1, mix_scale_1, 557 | pil_faceid_mix_image_2, mix_scale_2, 558 | faceid_scale, face_structure_scale, 559 | pil_ip_image, ip_scale, 560 | num_samples, seed, image_resolution, 561 | inference_steps, 562 | ] 563 | run_button.click(fn=image_to_single_id_generation_process, inputs=ips, outputs=[result_gallery]) 564 | 565 | 566 | if __name__ == "__main__": 567 | os.environ["no_proxy"] = "localhost,127.0.0.1,::1" 568 | 569 | title = r""" 570 |
571 | UniPortrait: A Unified Framework for Identity-Preserving Single- and Multi-Human Image Personalization 572 | 
573 | 574 |   575 | Project Page 576 |   577 | 578 |
579 |
580 |
581 | """ 582 | 583 | title_description = r""" 584 | This is the official 🤗 Gradio demo for UniPortrait: A Unified Framework for Identity-Preserving Single- and Multi-Human Image Personalization.
585 | The demo provides three capabilities: text-to-single-ID personalization, text-to-multi-ID personalization, and image-to-single-ID personalization. All of these are based on the Stable Diffusion v1-5 model. Feel free to give them a try! 😊 586 | """ 587 | 588 | text_to_single_id_description = r"""🚀🚀🚀Quick start:
589 | 1. Enter a text prompt (Chinese or English), upload an image with a face, and click the Run button. 🤗
590 | """ 591 | 592 | text_to_single_id_tips = r"""💡💡💡Tips:
593 | 1. Try to avoid generating faces that are too small in the image, as this may lead to artifacts. (Currently, the short side of the generated image is limited to 512 pixels.)
594 | 2. It's a good idea to upload multiple reference photos of your face to improve prompt and ID consistency. Additional references can be uploaded under "ID Supplements".
595 | 3. Appropriate values of "Face ID Scale" and "Face Structure Scale" are important for balancing ID fidelity and text alignment. We recommend a "Face ID Scale" of 0.5~0.7 and a "Face Structure Scale" of 0.0~0.4.
596 | """ 597 | 598 | text_to_multi_id_description = r"""🚀🚀🚀Quick start:
599 | 1. Enter a text prompt (Chinese or English), upload an image with a face in each of the "First ID" and "Second ID" blocks, and click the Run button. 🤗
600 | """ 601 | 602 | text_to_multi_id_tips = r"""💡💡💡Tips:
603 | 1. Try to avoid generating faces that are too small in the image, as this may lead to artifacts. (Currently, the short side of the generated image is limited to 512 pixels.)
604 | 2. It's a good idea to upload multiple reference photos of your face to improve prompt and ID consistency. Additional references can be uploaded under the "ID Supplements" accordions.
605 | 3. Appropriate values of "Face ID Scale" and "Face Structure Scale" are important for balancing ID fidelity and text alignment. We recommend a "Face ID Scale" of 0.3~0.7 and a "Face Structure Scale" of 0.0~0.4.
606 | """ 607 | 608 | image_to_single_id_description = r"""🚀🚀🚀Quick start: Upload an image as the portrait reference (it can be in any style), upload a face image, and click the Run button. 🤗
""" 609 | 610 | image_to_single_id_tips = r"""💡💡💡Tips:
611 | 1. Try to avoid generating faces that are too small in the image, as this may lead to artifacts. (Currently, the short side of the generated image is limited to 512 pixels.)
612 | 2. It's a good idea to upload multiple reference photos of your face to improve ID consistency. Additional references can be uploaded under "ID Supplements".
613 | 3. Appropriate values of "Face ID Scale" and "Face Structure Scale" are important for balancing alignment with the portrait reference against ID fidelity. We recommend a "Face ID Scale" of 0.5~0.7 and a "Face Structure Scale" of 0.0~0.4.
614 | """ 615 | 616 | citation = r""" 617 | --- 618 | 📝 **Citation** 619 |
620 | If our work is helpful for your research or applications, please cite us via: 621 | ```bibtex 622 | @article{he2024uniportrait, 623 | title={UniPortrait: A Unified Framework for Identity-Preserving Single-and Multi-Human Image Personalization}, 624 | author={He, Junjie and Geng, Yifeng and Bo, Liefeng}, 625 | journal={arXiv preprint arXiv:2408.05939}, 626 | year={2024} 627 | } 628 | ``` 629 | 📧 **Contact** 630 |
631 | If you have any questions, please feel free to open an issue or directly reach us out at hejunjie1103@gmail.com. 632 | """ 633 | 634 | block = gr.Blocks(title="UniPortrait").queue() 635 | with block: 636 | gr.HTML(title) 637 | gr.HTML(title_description) 638 | 639 | with gr.TabItem("Text-to-Single-ID"): 640 | text_to_single_id_generation_block() 641 | 642 | with gr.TabItem("Text-to-Multi-ID"): 643 | text_to_multi_id_generation_block() 644 | 645 | with gr.TabItem("Image-to-Single-ID (Stylization)"): 646 | image_to_single_id_generation_block() 647 | 648 | gr.Markdown(citation) 649 | 650 | block.launch(server_name='0.0.0.0', share=False, server_port=port, allowed_paths=["/"]) 651 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers 2 | gradio 3 | onnxruntime-gpu 4 | insightface 5 | torch 6 | tqdm 7 | transformers 8 | -------------------------------------------------------------------------------- /uniportrait/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/uniportrait/__init__.py -------------------------------------------------------------------------------- /uniportrait/curricular_face/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/uniportrait/curricular_face/__init__.py -------------------------------------------------------------------------------- /uniportrait/curricular_face/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | # The implementation is adopted from TFace,made pubicly available under the Apache-2.0 license at 2 | # https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone 3 | from .model_irse import (IR_18, IR_34, IR_50, IR_101, IR_152, IR_200, IR_SE_50, 4 | IR_SE_101, IR_SE_152, IR_SE_200) 5 | from .model_resnet import ResNet_50, ResNet_101, ResNet_152 6 | 7 | _model_dict = { 8 | 'ResNet_50': ResNet_50, 9 | 'ResNet_101': ResNet_101, 10 | 'ResNet_152': ResNet_152, 11 | 'IR_18': IR_18, 12 | 'IR_34': IR_34, 13 | 'IR_50': IR_50, 14 | 'IR_101': IR_101, 15 | 'IR_152': IR_152, 16 | 'IR_200': IR_200, 17 | 'IR_SE_50': IR_SE_50, 18 | 'IR_SE_101': IR_SE_101, 19 | 'IR_SE_152': IR_SE_152, 20 | 'IR_SE_200': IR_SE_200 21 | } 22 | 23 | 24 | def get_model(key): 25 | """ Get different backbone network by key, 26 | support ResNet50, ResNet_101, ResNet_152 27 | IR_18, IR_34, IR_50, IR_101, IR_152, IR_200, 28 | IR_SE_50, IR_SE_101, IR_SE_152, IR_SE_200. 
29 | """ 30 | if key in _model_dict.keys(): 31 | return _model_dict[key] 32 | else: 33 | raise KeyError('not support model {}'.format(key)) 34 | -------------------------------------------------------------------------------- /uniportrait/curricular_face/backbone/common.py: -------------------------------------------------------------------------------- 1 | # The implementation is adopted from TFace,made pubicly available under the Apache-2.0 license at 2 | # https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone/common.py 3 | import torch.nn as nn 4 | from torch.nn import (Conv2d, Module, ReLU, 5 | Sigmoid) 6 | 7 | 8 | def initialize_weights(modules): 9 | """ Weight initilize, conv2d and linear is initialized with kaiming_normal 10 | """ 11 | for m in modules: 12 | if isinstance(m, nn.Conv2d): 13 | nn.init.kaiming_normal_( 14 | m.weight, mode='fan_out', nonlinearity='relu') 15 | if m.bias is not None: 16 | m.bias.data.zero_() 17 | elif isinstance(m, nn.BatchNorm2d): 18 | m.weight.data.fill_(1) 19 | m.bias.data.zero_() 20 | elif isinstance(m, nn.Linear): 21 | nn.init.kaiming_normal_( 22 | m.weight, mode='fan_out', nonlinearity='relu') 23 | if m.bias is not None: 24 | m.bias.data.zero_() 25 | 26 | 27 | class Flatten(Module): 28 | """ Flat tensor 29 | """ 30 | 31 | def forward(self, input): 32 | return input.view(input.size(0), -1) 33 | 34 | 35 | class SEModule(Module): 36 | """ SE block 37 | """ 38 | 39 | def __init__(self, channels, reduction): 40 | super(SEModule, self).__init__() 41 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 42 | self.fc1 = Conv2d( 43 | channels, 44 | channels // reduction, 45 | kernel_size=1, 46 | padding=0, 47 | bias=False) 48 | 49 | nn.init.xavier_uniform_(self.fc1.weight.data) 50 | 51 | self.relu = ReLU(inplace=True) 52 | self.fc2 = Conv2d( 53 | channels // reduction, 54 | channels, 55 | kernel_size=1, 56 | padding=0, 57 | bias=False) 58 | 59 | self.sigmoid = Sigmoid() 60 | 61 | def forward(self, x): 62 | module_input = x 63 | x = self.avg_pool(x) 64 | x = self.fc1(x) 65 | x = self.relu(x) 66 | x = self.fc2(x) 67 | x = self.sigmoid(x) 68 | 69 | return module_input * x 70 | -------------------------------------------------------------------------------- /uniportrait/curricular_face/backbone/model_irse.py: -------------------------------------------------------------------------------- 1 | # The implementation is adopted from TFace,made pubicly available under the Apache-2.0 license at 2 | # https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone/model_irse.py 3 | from collections import namedtuple 4 | 5 | from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear, 6 | MaxPool2d, Module, PReLU, Sequential) 7 | 8 | from .common import Flatten, SEModule, initialize_weights 9 | 10 | 11 | class BasicBlockIR(Module): 12 | """ BasicBlock for IRNet 13 | """ 14 | 15 | def __init__(self, in_channel, depth, stride): 16 | super(BasicBlockIR, self).__init__() 17 | if in_channel == depth: 18 | self.shortcut_layer = MaxPool2d(1, stride) 19 | else: 20 | self.shortcut_layer = Sequential( 21 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 22 | BatchNorm2d(depth)) 23 | self.res_layer = Sequential( 24 | BatchNorm2d(in_channel), 25 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 26 | BatchNorm2d(depth), PReLU(depth), 27 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 28 | BatchNorm2d(depth)) 29 | 30 | def forward(self, x): 31 | shortcut = self.shortcut_layer(x) 32 | res = self.res_layer(x) 33 | 34 | return res 
+ shortcut 35 | 36 | 37 | class BottleneckIR(Module): 38 | """ BasicBlock with bottleneck for IRNet 39 | """ 40 | 41 | def __init__(self, in_channel, depth, stride): 42 | super(BottleneckIR, self).__init__() 43 | reduction_channel = depth // 4 44 | if in_channel == depth: 45 | self.shortcut_layer = MaxPool2d(1, stride) 46 | else: 47 | self.shortcut_layer = Sequential( 48 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 49 | BatchNorm2d(depth)) 50 | self.res_layer = Sequential( 51 | BatchNorm2d(in_channel), 52 | Conv2d( 53 | in_channel, reduction_channel, (1, 1), (1, 1), 0, bias=False), 54 | BatchNorm2d(reduction_channel), PReLU(reduction_channel), 55 | Conv2d( 56 | reduction_channel, 57 | reduction_channel, (3, 3), (1, 1), 58 | 1, 59 | bias=False), BatchNorm2d(reduction_channel), 60 | PReLU(reduction_channel), 61 | Conv2d(reduction_channel, depth, (1, 1), stride, 0, bias=False), 62 | BatchNorm2d(depth)) 63 | 64 | def forward(self, x): 65 | shortcut = self.shortcut_layer(x) 66 | res = self.res_layer(x) 67 | 68 | return res + shortcut 69 | 70 | 71 | class BasicBlockIRSE(BasicBlockIR): 72 | 73 | def __init__(self, in_channel, depth, stride): 74 | super(BasicBlockIRSE, self).__init__(in_channel, depth, stride) 75 | self.res_layer.add_module('se_block', SEModule(depth, 16)) 76 | 77 | 78 | class BottleneckIRSE(BottleneckIR): 79 | 80 | def __init__(self, in_channel, depth, stride): 81 | super(BottleneckIRSE, self).__init__(in_channel, depth, stride) 82 | self.res_layer.add_module('se_block', SEModule(depth, 16)) 83 | 84 | 85 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 86 | '''A named tuple describing a ResNet block.''' 87 | 88 | 89 | def get_block(in_channel, depth, num_units, stride=2): 90 | return [Bottleneck(in_channel, depth, stride)] + \ 91 | [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 92 | 93 | 94 | def get_blocks(num_layers): 95 | if num_layers == 18: 96 | blocks = [ 97 | get_block(in_channel=64, depth=64, num_units=2), 98 | get_block(in_channel=64, depth=128, num_units=2), 99 | get_block(in_channel=128, depth=256, num_units=2), 100 | get_block(in_channel=256, depth=512, num_units=2) 101 | ] 102 | elif num_layers == 34: 103 | blocks = [ 104 | get_block(in_channel=64, depth=64, num_units=3), 105 | get_block(in_channel=64, depth=128, num_units=4), 106 | get_block(in_channel=128, depth=256, num_units=6), 107 | get_block(in_channel=256, depth=512, num_units=3) 108 | ] 109 | elif num_layers == 50: 110 | blocks = [ 111 | get_block(in_channel=64, depth=64, num_units=3), 112 | get_block(in_channel=64, depth=128, num_units=4), 113 | get_block(in_channel=128, depth=256, num_units=14), 114 | get_block(in_channel=256, depth=512, num_units=3) 115 | ] 116 | elif num_layers == 100: 117 | blocks = [ 118 | get_block(in_channel=64, depth=64, num_units=3), 119 | get_block(in_channel=64, depth=128, num_units=13), 120 | get_block(in_channel=128, depth=256, num_units=30), 121 | get_block(in_channel=256, depth=512, num_units=3) 122 | ] 123 | elif num_layers == 152: 124 | blocks = [ 125 | get_block(in_channel=64, depth=256, num_units=3), 126 | get_block(in_channel=256, depth=512, num_units=8), 127 | get_block(in_channel=512, depth=1024, num_units=36), 128 | get_block(in_channel=1024, depth=2048, num_units=3) 129 | ] 130 | elif num_layers == 200: 131 | blocks = [ 132 | get_block(in_channel=64, depth=256, num_units=3), 133 | get_block(in_channel=256, depth=512, num_units=24), 134 | get_block(in_channel=512, depth=1024, num_units=36), 135 | 
get_block(in_channel=1024, depth=2048, num_units=3) 136 | ] 137 | 138 | return blocks 139 | 140 | 141 | class Backbone(Module): 142 | 143 | def __init__(self, input_size, num_layers, mode='ir'): 144 | """ Args: 145 | input_size: input_size of backbone 146 | num_layers: num_layers of backbone 147 | mode: support ir or irse 148 | """ 149 | super(Backbone, self).__init__() 150 | assert input_size[0] in [112, 224], \ 151 | 'input_size should be [112, 112] or [224, 224]' 152 | assert num_layers in [18, 34, 50, 100, 152, 200], \ 153 | 'num_layers should be 18, 34, 50, 100 or 152' 154 | assert mode in ['ir', 'ir_se'], \ 155 | 'mode should be ir or ir_se' 156 | self.input_layer = Sequential( 157 | Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), 158 | PReLU(64)) 159 | blocks = get_blocks(num_layers) 160 | if num_layers <= 100: 161 | if mode == 'ir': 162 | unit_module = BasicBlockIR 163 | elif mode == 'ir_se': 164 | unit_module = BasicBlockIRSE 165 | output_channel = 512 166 | else: 167 | if mode == 'ir': 168 | unit_module = BottleneckIR 169 | elif mode == 'ir_se': 170 | unit_module = BottleneckIRSE 171 | output_channel = 2048 172 | 173 | if input_size[0] == 112: 174 | self.output_layer = Sequential( 175 | BatchNorm2d(output_channel), Dropout(0.4), Flatten(), 176 | Linear(output_channel * 7 * 7, 512), 177 | BatchNorm1d(512, affine=False)) 178 | else: 179 | self.output_layer = Sequential( 180 | BatchNorm2d(output_channel), Dropout(0.4), Flatten(), 181 | Linear(output_channel * 14 * 14, 512), 182 | BatchNorm1d(512, affine=False)) 183 | 184 | modules = [] 185 | mid_layer_indices = [] # [2, 15, 45, 48], total 49 layers for IR101 186 | for block in blocks: 187 | if len(mid_layer_indices) == 0: 188 | mid_layer_indices.append(len(block) - 1) 189 | else: 190 | mid_layer_indices.append(len(block) + mid_layer_indices[-1]) 191 | for bottleneck in block: 192 | modules.append( 193 | unit_module(bottleneck.in_channel, bottleneck.depth, 194 | bottleneck.stride)) 195 | self.body = Sequential(*modules) 196 | self.mid_layer_indices = mid_layer_indices[-4:] 197 | 198 | initialize_weights(self.modules()) 199 | 200 | def forward(self, x, return_mid_feats=False): 201 | x = self.input_layer(x) 202 | if not return_mid_feats: 203 | x = self.body(x) 204 | x = self.output_layer(x) 205 | return x 206 | else: 207 | out_feats = [] 208 | for idx, module in enumerate(self.body): 209 | x = module(x) 210 | if idx in self.mid_layer_indices: 211 | out_feats.append(x) 212 | x = self.output_layer(x) 213 | return x, out_feats 214 | 215 | 216 | def IR_18(input_size): 217 | """ Constructs a ir-18 model. 218 | """ 219 | model = Backbone(input_size, 18, 'ir') 220 | 221 | return model 222 | 223 | 224 | def IR_34(input_size): 225 | """ Constructs a ir-34 model. 226 | """ 227 | model = Backbone(input_size, 34, 'ir') 228 | 229 | return model 230 | 231 | 232 | def IR_50(input_size): 233 | """ Constructs a ir-50 model. 234 | """ 235 | model = Backbone(input_size, 50, 'ir') 236 | 237 | return model 238 | 239 | 240 | def IR_101(input_size): 241 | """ Constructs a ir-101 model. 242 | """ 243 | model = Backbone(input_size, 100, 'ir') 244 | 245 | return model 246 | 247 | 248 | def IR_152(input_size): 249 | """ Constructs a ir-152 model. 250 | """ 251 | model = Backbone(input_size, 152, 'ir') 252 | 253 | return model 254 | 255 | 256 | def IR_200(input_size): 257 | """ Constructs a ir-200 model. 
258 | """ 259 | model = Backbone(input_size, 200, 'ir') 260 | 261 | return model 262 | 263 | 264 | def IR_SE_50(input_size): 265 | """ Constructs a ir_se-50 model. 266 | """ 267 | model = Backbone(input_size, 50, 'ir_se') 268 | 269 | return model 270 | 271 | 272 | def IR_SE_101(input_size): 273 | """ Constructs a ir_se-101 model. 274 | """ 275 | model = Backbone(input_size, 100, 'ir_se') 276 | 277 | return model 278 | 279 | 280 | def IR_SE_152(input_size): 281 | """ Constructs a ir_se-152 model. 282 | """ 283 | model = Backbone(input_size, 152, 'ir_se') 284 | 285 | return model 286 | 287 | 288 | def IR_SE_200(input_size): 289 | """ Constructs a ir_se-200 model. 290 | """ 291 | model = Backbone(input_size, 200, 'ir_se') 292 | 293 | return model 294 | -------------------------------------------------------------------------------- /uniportrait/curricular_face/backbone/model_resnet.py: -------------------------------------------------------------------------------- 1 | # The implementation is adopted from TFace,made pubicly available under the Apache-2.0 license at 2 | # https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone/model_resnet.py 3 | import torch.nn as nn 4 | from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear, 5 | MaxPool2d, Module, ReLU, Sequential) 6 | 7 | from .common import initialize_weights 8 | 9 | 10 | def conv3x3(in_planes, out_planes, stride=1): 11 | """ 3x3 convolution with padding 12 | """ 13 | return Conv2d( 14 | in_planes, 15 | out_planes, 16 | kernel_size=3, 17 | stride=stride, 18 | padding=1, 19 | bias=False) 20 | 21 | 22 | def conv1x1(in_planes, out_planes, stride=1): 23 | """ 1x1 convolution 24 | """ 25 | return Conv2d( 26 | in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 27 | 28 | 29 | class Bottleneck(Module): 30 | expansion = 4 31 | 32 | def __init__(self, inplanes, planes, stride=1, downsample=None): 33 | super(Bottleneck, self).__init__() 34 | self.conv1 = conv1x1(inplanes, planes) 35 | self.bn1 = BatchNorm2d(planes) 36 | self.conv2 = conv3x3(planes, planes, stride) 37 | self.bn2 = BatchNorm2d(planes) 38 | self.conv3 = conv1x1(planes, planes * self.expansion) 39 | self.bn3 = BatchNorm2d(planes * self.expansion) 40 | self.relu = ReLU(inplace=True) 41 | self.downsample = downsample 42 | self.stride = stride 43 | 44 | def forward(self, x): 45 | identity = x 46 | 47 | out = self.conv1(x) 48 | out = self.bn1(out) 49 | out = self.relu(out) 50 | 51 | out = self.conv2(out) 52 | out = self.bn2(out) 53 | out = self.relu(out) 54 | 55 | out = self.conv3(out) 56 | out = self.bn3(out) 57 | 58 | if self.downsample is not None: 59 | identity = self.downsample(x) 60 | 61 | out += identity 62 | out = self.relu(out) 63 | 64 | return out 65 | 66 | 67 | class ResNet(Module): 68 | """ ResNet backbone 69 | """ 70 | 71 | def __init__(self, input_size, block, layers, zero_init_residual=True): 72 | """ Args: 73 | input_size: input_size of backbone 74 | block: block function 75 | layers: layers in each block 76 | """ 77 | super(ResNet, self).__init__() 78 | assert input_size[0] in [112, 224], \ 79 | 'input_size should be [112, 112] or [224, 224]' 80 | self.inplanes = 64 81 | self.conv1 = Conv2d( 82 | 3, 64, kernel_size=7, stride=2, padding=3, bias=False) 83 | self.bn1 = BatchNorm2d(64) 84 | self.relu = ReLU(inplace=True) 85 | self.maxpool = MaxPool2d(kernel_size=3, stride=2, padding=1) 86 | self.layer1 = self._make_layer(block, 64, layers[0]) 87 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 88 | self.layer3 = 
self._make_layer(block, 256, layers[2], stride=2) 89 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 90 | 91 | self.bn_o1 = BatchNorm2d(2048) 92 | self.dropout = Dropout() 93 | if input_size[0] == 112: 94 | self.fc = Linear(2048 * 4 * 4, 512) 95 | else: 96 | self.fc = Linear(2048 * 7 * 7, 512) 97 | self.bn_o2 = BatchNorm1d(512) 98 | 99 | initialize_weights(self.modules) 100 | if zero_init_residual: 101 | for m in self.modules(): 102 | if isinstance(m, Bottleneck): 103 | nn.init.constant_(m.bn3.weight, 0) 104 | 105 | def _make_layer(self, block, planes, blocks, stride=1): 106 | downsample = None 107 | if stride != 1 or self.inplanes != planes * block.expansion: 108 | downsample = Sequential( 109 | conv1x1(self.inplanes, planes * block.expansion, stride), 110 | BatchNorm2d(planes * block.expansion), 111 | ) 112 | 113 | layers = [] 114 | layers.append(block(self.inplanes, planes, stride, downsample)) 115 | self.inplanes = planes * block.expansion 116 | for _ in range(1, blocks): 117 | layers.append(block(self.inplanes, planes)) 118 | 119 | return Sequential(*layers) 120 | 121 | def forward(self, x): 122 | x = self.conv1(x) 123 | x = self.bn1(x) 124 | x = self.relu(x) 125 | x = self.maxpool(x) 126 | 127 | x = self.layer1(x) 128 | x = self.layer2(x) 129 | x = self.layer3(x) 130 | x = self.layer4(x) 131 | 132 | x = self.bn_o1(x) 133 | x = self.dropout(x) 134 | x = x.view(x.size(0), -1) 135 | x = self.fc(x) 136 | x = self.bn_o2(x) 137 | 138 | return x 139 | 140 | 141 | def ResNet_50(input_size, **kwargs): 142 | """ Constructs a ResNet-50 model. 143 | """ 144 | model = ResNet(input_size, Bottleneck, [3, 4, 6, 3], **kwargs) 145 | 146 | return model 147 | 148 | 149 | def ResNet_101(input_size, **kwargs): 150 | """ Constructs a ResNet-101 model. 151 | """ 152 | model = ResNet(input_size, Bottleneck, [3, 4, 23, 3], **kwargs) 153 | 154 | return model 155 | 156 | 157 | def ResNet_152(input_size, **kwargs): 158 | """ Constructs a ResNet-152 model. 
159 | """ 160 | model = ResNet(input_size, Bottleneck, [3, 8, 36, 3], **kwargs) 161 | 162 | return model 163 | -------------------------------------------------------------------------------- /uniportrait/curricular_face/inference.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | from tqdm.auto import tqdm 8 | 9 | from .backbone import get_model 10 | 11 | 12 | @torch.no_grad() 13 | def inference(name, weight, src_norm_dir): 14 | face_model = get_model(name)([112, 112]) 15 | face_model.load_state_dict(torch.load(weight, map_location="cpu")) 16 | face_model = face_model.to("cpu") 17 | face_model.eval() 18 | 19 | id2src_norm = {} 20 | for src_id in sorted(list(os.listdir(src_norm_dir))): 21 | id2src_norm[src_id] = sorted(list(glob.glob(f"{os.path.join(src_norm_dir, src_id)}/*"))) 22 | 23 | total_sims = [] 24 | for id_name in tqdm(id2src_norm): 25 | src_face_embeddings = [] 26 | for src_img_path in id2src_norm[id_name]: 27 | src_img = cv2.imread(src_img_path) 28 | src_img = cv2.resize(src_img, (112, 112)) 29 | src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB) 30 | src_img = np.transpose(src_img, (2, 0, 1)) 31 | src_img = torch.from_numpy(src_img).unsqueeze(0).float() 32 | src_img.div_(255).sub_(0.5).div_(0.5) 33 | embedding = face_model(src_img).detach().cpu().numpy()[0] 34 | embedding = embedding / np.linalg.norm(embedding) 35 | src_face_embeddings.append(embedding) # 512 36 | 37 | num = len(src_face_embeddings) 38 | src_face_embeddings = np.stack(src_face_embeddings) # n, 512 39 | sim = src_face_embeddings @ src_face_embeddings.T # n, n 40 | mean_sim = (np.sum(sim) - num * 1.0) / ((num - 1) * num) 41 | print(f"{id_name}: {mean_sim}") 42 | total_sims.append(mean_sim) 43 | 44 | return np.mean(total_sims) 45 | 46 | 47 | if __name__ == "__main__": 48 | name = 'IR_101' 49 | weight = "models/glint360k_curricular_face_r101_backbone.bin" 50 | src_norm_dir = "/disk1/hejunjie.hjj/data/normface-AFD-id-20" 51 | mean_sim = inference(name, weight, src_norm_dir) 52 | print(f"total: {mean_sim:.4f}") # total: 0.6299 53 | -------------------------------------------------------------------------------- /uniportrait/inversion.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/google/style-aligned/blob/main/inversion.py 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Callable 6 | 7 | import numpy as np 8 | import torch 9 | from diffusers import StableDiffusionPipeline 10 | from tqdm import tqdm 11 | 12 | T = torch.Tensor 13 | InversionCallback = Callable[[StableDiffusionPipeline, int, T, dict[str, T]], dict[str, T]] 14 | 15 | 16 | def _encode_text_with_negative(model: StableDiffusionPipeline, prompt: str) -> tuple[dict[str, T], T]: 17 | device = model._execution_device 18 | prompt_embeds = model._encode_prompt( 19 | prompt, device=device, num_images_per_prompt=1, do_classifier_free_guidance=True, 20 | negative_prompt="") 21 | return prompt_embeds 22 | 23 | 24 | def _encode_image(model: StableDiffusionPipeline, image: np.ndarray) -> T: 25 | model.vae.to(dtype=torch.float32) 26 | image = torch.from_numpy(image).float() / 255. 
27 | image = (image * 2 - 1).permute(2, 0, 1).unsqueeze(0) 28 | latent = model.vae.encode(image.to(model.vae.device))['latent_dist'].mean * model.vae.config.scaling_factor 29 | model.vae.to(dtype=torch.float16) 30 | return latent 31 | 32 | 33 | def _next_step(model: StableDiffusionPipeline, model_output: T, timestep: int, sample: T) -> T: 34 | timestep, next_timestep = min( 35 | timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps, 999), timestep 36 | alpha_prod_t = model.scheduler.alphas_cumprod[ 37 | int(timestep)] if timestep >= 0 else model.scheduler.final_alpha_cumprod 38 | alpha_prod_t_next = model.scheduler.alphas_cumprod[int(next_timestep)] 39 | beta_prod_t = 1 - alpha_prod_t 40 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 41 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output 42 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction 43 | return next_sample 44 | 45 | 46 | def _get_noise_pred(model: StableDiffusionPipeline, latent: T, t: T, context: T, guidance_scale: float): 47 | latents_input = torch.cat([latent] * 2) 48 | noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"] 49 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) 50 | noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) 51 | # latents = next_step(model, noise_pred, t, latent) 52 | return noise_pred 53 | 54 | 55 | def _ddim_loop(model: StableDiffusionPipeline, z0, prompt, guidance_scale) -> T: 56 | all_latent = [z0] 57 | text_embedding = _encode_text_with_negative(model, prompt) 58 | image_embedding = torch.zeros_like(text_embedding[:, :1]).repeat(1, 4, 1) # for ip embedding 59 | text_embedding = torch.cat([text_embedding, image_embedding], dim=1) 60 | latent = z0.clone().detach().half() 61 | for i in tqdm(range(model.scheduler.num_inference_steps)): 62 | t = model.scheduler.timesteps[len(model.scheduler.timesteps) - i - 1] 63 | noise_pred = _get_noise_pred(model, latent, t, text_embedding, guidance_scale) 64 | latent = _next_step(model, noise_pred, t, latent) 65 | all_latent.append(latent) 66 | return torch.cat(all_latent).flip(0) 67 | 68 | 69 | def make_inversion_callback(zts, offset: int = 0) -> [T, InversionCallback]: 70 | def callback_on_step_end(pipeline: StableDiffusionPipeline, i: int, t: T, callback_kwargs: dict[str, T]) -> dict[ 71 | str, T]: 72 | latents = callback_kwargs['latents'] 73 | latents[0] = zts[max(offset + 1, i + 1)].to(latents.device, latents.dtype) 74 | return {'latents': latents} 75 | 76 | return zts[offset], callback_on_step_end 77 | 78 | 79 | @torch.no_grad() 80 | def ddim_inversion(model: StableDiffusionPipeline, x0: np.ndarray, prompt: str, num_inference_steps: int, 81 | guidance_scale, ) -> T: 82 | z0 = _encode_image(model, x0) 83 | model.scheduler.set_timesteps(num_inference_steps, device=z0.device) 84 | zs = _ddim_loop(model, z0, prompt, guidance_scale) 85 | return zs 86 | -------------------------------------------------------------------------------- /uniportrait/resampler.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py 2 | # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py 3 | 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | # FFN 11 | def 
FeedForward(dim, mult=4): 12 | inner_dim = int(dim * mult) 13 | return nn.Sequential( 14 | nn.LayerNorm(dim), 15 | nn.Linear(dim, inner_dim, bias=False), 16 | nn.GELU(), 17 | nn.Linear(inner_dim, dim, bias=False), 18 | ) 19 | 20 | 21 | def reshape_tensor(x, heads): 22 | bs, length, width = x.shape 23 | # (bs, length, width) --> (bs, length, n_heads, dim_per_head) 24 | x = x.view(bs, length, heads, -1) 25 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) 26 | x = x.transpose(1, 2) 27 | # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) 28 | x = x.reshape(bs, heads, length, -1) 29 | return x 30 | 31 | 32 | class PerceiverAttention(nn.Module): 33 | def __init__(self, *, dim, dim_head=64, heads=8): 34 | super().__init__() 35 | self.scale = dim_head ** -0.5 36 | self.dim_head = dim_head 37 | self.heads = heads 38 | inner_dim = dim_head * heads 39 | 40 | self.norm1 = nn.LayerNorm(dim) 41 | self.norm2 = nn.LayerNorm(dim) 42 | 43 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 44 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 45 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 46 | 47 | def forward(self, x, latents, attention_mask=None): 48 | """ 49 | Args: 50 | x (torch.Tensor): image features 51 | shape (b, n1, D) 52 | latents (torch.Tensor): latent features 53 | shape (b, n2, D) 54 | attention_mask (torch.Tensor): attention mask 55 | shape (b, n1, 1) 56 | """ 57 | x = self.norm1(x) 58 | latents = self.norm2(latents) 59 | 60 | b, l, _ = latents.shape 61 | 62 | q = self.to_q(latents) 63 | kv_input = torch.cat((x, latents), dim=-2) 64 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 65 | 66 | q = reshape_tensor(q, self.heads) 67 | k = reshape_tensor(k, self.heads) 68 | v = reshape_tensor(v, self.heads) 69 | 70 | # attention 71 | scale = 1 / math.sqrt(math.sqrt(self.dim_head)) 72 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards 73 | if attention_mask is not None: 74 | attention_mask = attention_mask.transpose(1, 2) # (b, 1, n1) 75 | attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, :, :1]).repeat(1, 1, l)], 76 | dim=2) # b, 1, n1+n2 77 | attention_mask = (attention_mask - 1.) * 100. 
# 0 means kept and -100 means dropped 78 | attention_mask = attention_mask.unsqueeze(1) 79 | weight = weight + attention_mask # b, h, n2, n1+n2 80 | 81 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 82 | out = weight @ v 83 | 84 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1) 85 | 86 | return self.to_out(out) 87 | 88 | 89 | class UniPortraitFaceIDResampler(torch.nn.Module): 90 | def __init__( 91 | self, 92 | intrinsic_id_embedding_dim=512, 93 | structure_embedding_dim=64 + 128 + 256 + 1280, 94 | num_tokens=16, 95 | depth=6, 96 | dim=768, 97 | dim_head=64, 98 | heads=12, 99 | ff_mult=4, 100 | output_dim=768, 101 | ): 102 | super().__init__() 103 | 104 | self.latents = torch.nn.Parameter(torch.randn(1, num_tokens, dim) / dim ** 0.5) 105 | 106 | self.proj_id = torch.nn.Sequential( 107 | torch.nn.Linear(intrinsic_id_embedding_dim, intrinsic_id_embedding_dim * 2), 108 | torch.nn.GELU(), 109 | torch.nn.Linear(intrinsic_id_embedding_dim * 2, dim), 110 | ) 111 | self.proj_clip = torch.nn.Sequential( 112 | torch.nn.Linear(structure_embedding_dim, structure_embedding_dim * 2), 113 | torch.nn.GELU(), 114 | torch.nn.Linear(structure_embedding_dim * 2, dim), 115 | ) 116 | 117 | self.layers = torch.nn.ModuleList([]) 118 | for _ in range(depth): 119 | self.layers.append( 120 | torch.nn.ModuleList( 121 | [ 122 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 123 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 124 | FeedForward(dim=dim, mult=ff_mult), 125 | ] 126 | ) 127 | ) 128 | 129 | self.proj_out = torch.nn.Linear(dim, output_dim) 130 | self.norm_out = torch.nn.LayerNorm(output_dim) 131 | 132 | def forward( 133 | self, 134 | intrinsic_id_embeds, 135 | structure_embeds, 136 | structure_scale=1.0, 137 | intrinsic_id_attention_mask=None, 138 | structure_attention_mask=None 139 | ): 140 | 141 | latents = self.latents.repeat(intrinsic_id_embeds.size(0), 1, 1) 142 | 143 | intrinsic_id_embeds = self.proj_id(intrinsic_id_embeds) 144 | structure_embeds = self.proj_clip(structure_embeds) 145 | 146 | for attn1, attn2, ff in self.layers: 147 | latents = attn1(intrinsic_id_embeds, latents, intrinsic_id_attention_mask) + latents 148 | latents = structure_scale * attn2(structure_embeds, latents, structure_attention_mask) + latents 149 | latents = ff(latents) + latents 150 | 151 | latents = self.proj_out(latents) 152 | return self.norm_out(latents) 153 | -------------------------------------------------------------------------------- /uniportrait/uniportrait_attention_processor.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from diffusers.models.lora import LoRALinearLayer 6 | 7 | 8 | class AttentionArgs(object): 9 | def __init__(self) -> None: 10 | # ip condition 11 | self.ip_scale = 0.0 12 | self.ip_mask = None # ip attention mask 13 | 14 | # faceid condition 15 | self.lora_scale = 0.0 # lora for single faceid 16 | self.multi_id_lora_scale = 0.0 # lora for multiple faceids 17 | self.faceid_scale = 0.0 18 | self.num_faceids = 0 19 | self.faceid_mask = None # faceid attention mask; if not None, it will override the routing map 20 | 21 | # style aligned 22 | self.enable_share_attn: bool = False 23 | self.adain_queries_and_keys: bool = False 24 | self.shared_score_scale: float = 1.0 25 | self.shared_score_shift: float = 0.0 26 | 27 | def 
reset(self): 28 | # ip condition 29 | self.ip_scale = 0.0 30 | self.ip_mask = None # ip attention mask 31 | 32 | # faceid condition 33 | self.lora_scale = 0.0 # lora for single faceid 34 | self.multi_id_lora_scale = 0.0 # lora for multiple faceids 35 | self.faceid_scale = 0.0 36 | self.num_faceids = 0 37 | self.faceid_mask = None # faceid attention mask; if not None, it will override the routing map 38 | 39 | # style aligned 40 | self.enable_share_attn: bool = False 41 | self.adain_queries_and_keys: bool = False 42 | self.shared_score_scale: float = 1.0 43 | self.shared_score_shift: float = 0.0 44 | 45 | def __repr__(self): 46 | indent_str = ' ' 47 | s = f",\n{indent_str}".join(f"{attr}={value}" for attr, value in vars(self).items()) 48 | return self.__class__.__name__ + '(' + f'\n{indent_str}' + s + ')' 49 | 50 | 51 | attn_args = AttentionArgs() 52 | 53 | 54 | def expand_first(feat, scale=1., ): 55 | b = feat.shape[0] 56 | feat_style = torch.stack((feat[0], feat[b // 2])).unsqueeze(1) 57 | if scale == 1: 58 | feat_style = feat_style.expand(2, b // 2, *feat.shape[1:]) 59 | else: 60 | feat_style = feat_style.repeat(1, b // 2, 1, 1, 1) 61 | feat_style = torch.cat([feat_style[:, :1], scale * feat_style[:, 1:]], dim=1) 62 | return feat_style.reshape(*feat.shape) 63 | 64 | 65 | def concat_first(feat, dim=2, scale=1.): 66 | feat_style = expand_first(feat, scale=scale) 67 | return torch.cat((feat, feat_style), dim=dim) 68 | 69 | 70 | def calc_mean_std(feat, eps: float = 1e-5): 71 | feat_std = (feat.var(dim=-2, keepdims=True) + eps).sqrt() 72 | feat_mean = feat.mean(dim=-2, keepdims=True) 73 | return feat_mean, feat_std 74 | 75 | 76 | def adain(feat): 77 | feat_mean, feat_std = calc_mean_std(feat) 78 | feat_style_mean = expand_first(feat_mean) 79 | feat_style_std = expand_first(feat_std) 80 | feat = (feat - feat_mean) / feat_std 81 | feat = feat * feat_style_std + feat_style_mean 82 | return feat 83 | 84 | 85 | class UniPortraitLoRAAttnProcessor2_0(nn.Module): 86 | 87 | def __init__( 88 | self, 89 | hidden_size=None, 90 | cross_attention_dim=None, 91 | rank=128, 92 | network_alpha=None, 93 | ): 94 | super().__init__() 95 | 96 | self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) 97 | self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) 98 | self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) 99 | self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) 100 | 101 | self.to_q_multi_id_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) 102 | self.to_k_multi_id_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) 103 | self.to_v_multi_id_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) 104 | self.to_out_multi_id_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) 105 | 106 | def __call__( 107 | self, 108 | attn, 109 | hidden_states, 110 | encoder_hidden_states=None, 111 | attention_mask=None, 112 | temb=None, 113 | *args, 114 | **kwargs, 115 | ): 116 | residual = hidden_states 117 | 118 | if attn.spatial_norm is not None: 119 | hidden_states = attn.spatial_norm(hidden_states, temb) 120 | 121 | input_ndim = hidden_states.ndim 122 | 123 | if input_ndim == 4: 124 | batch_size, channel, height, width = hidden_states.shape 125 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 126 | 127 | 
batch_size, sequence_length, _ = ( 128 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 129 | ) 130 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 131 | 132 | if attn.group_norm is not None: 133 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 134 | 135 | if encoder_hidden_states is None: 136 | encoder_hidden_states = hidden_states 137 | elif attn.norm_cross: 138 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 139 | 140 | query = attn.to_q(hidden_states) 141 | key = attn.to_k(encoder_hidden_states) 142 | value = attn.to_v(encoder_hidden_states) 143 | if attn_args.lora_scale > 0.0: 144 | query = query + attn_args.lora_scale * self.to_q_lora(hidden_states) 145 | key = key + attn_args.lora_scale * self.to_k_lora(encoder_hidden_states) 146 | value = value + attn_args.lora_scale * self.to_v_lora(encoder_hidden_states) 147 | elif attn_args.multi_id_lora_scale > 0.0: 148 | query = query + attn_args.multi_id_lora_scale * self.to_q_multi_id_lora(hidden_states) 149 | key = key + attn_args.multi_id_lora_scale * self.to_k_multi_id_lora(encoder_hidden_states) 150 | value = value + attn_args.multi_id_lora_scale * self.to_v_multi_id_lora(encoder_hidden_states) 151 | 152 | inner_dim = key.shape[-1] 153 | head_dim = inner_dim // attn.heads 154 | 155 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 156 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 157 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 158 | 159 | if attn_args.enable_share_attn: 160 | if attn_args.adain_queries_and_keys: 161 | query = adain(query) 162 | key = adain(key) 163 | key = concat_first(key, -2, scale=attn_args.shared_score_scale) 164 | value = concat_first(value, -2) 165 | if attn_args.shared_score_shift != 0: 166 | attention_mask = torch.zeros_like(key[:, :, :, :1]).transpose(-1, -2) # b, h, 1, k 167 | attention_mask[:, :, :, query.shape[2]:] += attn_args.shared_score_shift 168 | hidden_states = F.scaled_dot_product_attention( 169 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale 170 | ) 171 | else: 172 | hidden_states = F.scaled_dot_product_attention( 173 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale 174 | ) 175 | else: 176 | hidden_states = F.scaled_dot_product_attention( 177 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale 178 | ) 179 | 180 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 181 | hidden_states = hidden_states.to(query.dtype) 182 | 183 | # linear proj 184 | output_hidden_states = attn.to_out[0](hidden_states) 185 | if attn_args.lora_scale > 0.0: 186 | output_hidden_states = output_hidden_states + attn_args.lora_scale * self.to_out_lora(hidden_states) 187 | elif attn_args.multi_id_lora_scale > 0.0: 188 | output_hidden_states = output_hidden_states + attn_args.multi_id_lora_scale * self.to_out_multi_id_lora( 189 | hidden_states) 190 | hidden_states = output_hidden_states 191 | 192 | # dropout 193 | hidden_states = attn.to_out[1](hidden_states) 194 | 195 | if input_ndim == 4: 196 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 197 | 198 | if attn.residual_connection: 199 | hidden_states = hidden_states + residual 200 | 201 | hidden_states = hidden_states / 
attn.rescale_output_factor 202 | 203 | return hidden_states 204 | 205 | 206 | class UniPortraitLoRAIPAttnProcessor2_0(nn.Module): 207 | 208 | def __init__(self, hidden_size, cross_attention_dim=None, rank=128, network_alpha=None, 209 | num_ip_tokens=4, num_faceid_tokens=16): 210 | super().__init__() 211 | 212 | self.num_ip_tokens = num_ip_tokens 213 | self.num_faceid_tokens = num_faceid_tokens 214 | 215 | self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) 216 | self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) 217 | self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) 218 | self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) 219 | 220 | self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 221 | self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 222 | 223 | self.to_k_faceid = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 224 | self.to_v_faceid = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 225 | 226 | self.to_q_multi_id_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) 227 | self.to_k_multi_id_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) 228 | self.to_v_multi_id_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) 229 | self.to_out_multi_id_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) 230 | 231 | self.to_q_router = nn.Sequential( 232 | nn.Linear(hidden_size, hidden_size * 2), 233 | nn.GELU(), 234 | nn.Linear(hidden_size * 2, hidden_size, bias=False), 235 | ) 236 | self.to_k_router = nn.Sequential( 237 | nn.Linear(cross_attention_dim or hidden_size, (cross_attention_dim or hidden_size) * 2), 238 | nn.GELU(), 239 | nn.Linear((cross_attention_dim or hidden_size) * 2, hidden_size, bias=False), 240 | ) 241 | self.aggr_router = nn.Linear(num_faceid_tokens, 1) 242 | 243 | def __call__( 244 | self, 245 | attn, 246 | hidden_states, 247 | encoder_hidden_states=None, 248 | attention_mask=None, 249 | temb=None, 250 | *args, 251 | **kwargs, 252 | ): 253 | residual = hidden_states 254 | 255 | if attn.spatial_norm is not None: 256 | hidden_states = attn.spatial_norm(hidden_states, temb) 257 | 258 | input_ndim = hidden_states.ndim 259 | 260 | if input_ndim == 4: 261 | batch_size, channel, height, width = hidden_states.shape 262 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 263 | 264 | batch_size, sequence_length, _ = ( 265 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 266 | ) 267 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 268 | 269 | if attn.group_norm is not None: 270 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 271 | 272 | if encoder_hidden_states is None: 273 | encoder_hidden_states = hidden_states 274 | else: 275 | # split hidden states 276 | faceid_end = encoder_hidden_states.shape[1] 277 | ip_end = faceid_end - self.num_faceid_tokens * attn_args.num_faceids 278 | text_end = ip_end - self.num_ip_tokens 279 | 280 | prompt_hidden_states = encoder_hidden_states[:, :text_end] 281 | ip_hidden_states = encoder_hidden_states[:, text_end: ip_end] 282 | faceid_hidden_states = encoder_hidden_states[:, ip_end: faceid_end] 283 | 284 | 
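# encoder_hidden_states arrives laid out as [text tokens | num_ip_tokens IP tokens |
# num_faceid_tokens * num_faceids face-ID tokens], matching the concatenation performed in
# UniPortraitPipeline.generate; the slices above recover the three conditioning streams.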
encoder_hidden_states = prompt_hidden_states 285 | if attn.norm_cross: 286 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 287 | 288 | # for router 289 | if attn_args.num_faceids > 1: 290 | router_query = self.to_q_router(hidden_states) # bs, s*s, dim 291 | router_hidden_states = faceid_hidden_states.reshape(batch_size, attn_args.num_faceids, 292 | self.num_faceid_tokens, -1) # bs, num, id_tokens, d 293 | router_hidden_states = self.aggr_router(router_hidden_states.transpose(-1, -2)).squeeze(-1) # bs, num, d 294 | router_key = self.to_k_router(router_hidden_states) # bs, num, dim 295 | router_logits = torch.bmm(router_query, router_key.transpose(-1, -2)) # bs, s*s, num 296 | index = router_logits.max(dim=-1, keepdim=True)[1] 297 | routing_map = torch.zeros_like(router_logits).scatter_(-1, index, 1.0) 298 | routing_map = routing_map.transpose(1, 2).unsqueeze(-1) # bs, num, s*s, 1 299 | else: 300 | routing_map = hidden_states.new_ones(size=(1, 1, hidden_states.shape[1], 1)) 301 | 302 | # for text 303 | query = attn.to_q(hidden_states) 304 | key = attn.to_k(encoder_hidden_states) 305 | value = attn.to_v(encoder_hidden_states) 306 | if attn_args.lora_scale > 0.0: 307 | query = query + attn_args.lora_scale * self.to_q_lora(hidden_states) 308 | key = key + attn_args.lora_scale * self.to_k_lora(encoder_hidden_states) 309 | value = value + attn_args.lora_scale * self.to_v_lora(encoder_hidden_states) 310 | elif attn_args.multi_id_lora_scale > 0.0: 311 | query = query + attn_args.multi_id_lora_scale * self.to_q_multi_id_lora(hidden_states) 312 | key = key + attn_args.multi_id_lora_scale * self.to_k_multi_id_lora(encoder_hidden_states) 313 | value = value + attn_args.multi_id_lora_scale * self.to_v_multi_id_lora(encoder_hidden_states) 314 | 315 | inner_dim = key.shape[-1] 316 | head_dim = inner_dim // attn.heads 317 | 318 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 319 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 320 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 321 | 322 | hidden_states = F.scaled_dot_product_attention( 323 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale 324 | ) 325 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 326 | hidden_states = hidden_states.to(query.dtype) 327 | 328 | # for ip-adapter 329 | if attn_args.ip_scale > 0.0: 330 | ip_key = self.to_k_ip(ip_hidden_states) 331 | ip_value = self.to_v_ip(ip_hidden_states) 332 | 333 | ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 334 | ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 335 | 336 | ip_hidden_states = F.scaled_dot_product_attention( 337 | query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=attn.scale 338 | ) 339 | ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 340 | ip_hidden_states = ip_hidden_states.to(query.dtype) 341 | 342 | if attn_args.ip_mask is not None: 343 | ip_mask = attn_args.ip_mask 344 | h, w = ip_mask.shape[-2:] 345 | ratio = (h * w / query.shape[2]) ** 0.5 346 | ip_mask = torch.nn.functional.interpolate(ip_mask, scale_factor=1 / ratio, 347 | mode='nearest').reshape( 348 | [1, -1, 1]) 349 | ip_hidden_states = ip_hidden_states * ip_mask 350 | 351 | if attn_args.enable_share_attn: 352 | ip_hidden_states[0] = 0. 353 | ip_hidden_states[batch_size // 2] = 0. 
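# With style-aligned shared attention enabled, batch indices 0 and batch_size // 2 are the
# first sample of the unconditional / conditional halves (the shared style anchor, cf.
# expand_first), so their IP contribution is zeroed to keep the anchor free of the image prompt.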
354 | else: 355 | ip_hidden_states = torch.zeros_like(hidden_states) 356 | 357 | # for faceid-adapter 358 | if attn_args.faceid_scale > 0.0: 359 | faceid_key = self.to_k_faceid(faceid_hidden_states) 360 | faceid_value = self.to_v_faceid(faceid_hidden_states) 361 | 362 | faceid_query = query[:, None].expand(-1, attn_args.num_faceids, -1, -1, 363 | -1) # 2*bs, num, heads, s*s, dim/heads 364 | faceid_key = faceid_key.view(batch_size, attn_args.num_faceids, self.num_faceid_tokens, attn.heads, 365 | head_dim).transpose(2, 3) 366 | faceid_value = faceid_value.view(batch_size, attn_args.num_faceids, self.num_faceid_tokens, attn.heads, 367 | head_dim).transpose(2, 3) 368 | 369 | faceid_hidden_states = F.scaled_dot_product_attention( 370 | faceid_query, faceid_key, faceid_value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=attn.scale 371 | ) # 2*bs, num, heads, s*s, dim/heads 372 | 373 | faceid_hidden_states = faceid_hidden_states.transpose(2, 3).reshape(batch_size, attn_args.num_faceids, -1, 374 | attn.heads * head_dim) 375 | faceid_hidden_states = faceid_hidden_states.to(query.dtype) # 2*bs, num, s*s, dim 376 | 377 | if attn_args.faceid_mask is not None: 378 | faceid_mask = attn_args.faceid_mask # 1, num, h, w 379 | h, w = faceid_mask.shape[-2:] 380 | ratio = (h * w / query.shape[2]) ** 0.5 381 | faceid_mask = F.interpolate(faceid_mask, scale_factor=1 / ratio, 382 | mode='bilinear').flatten(2).unsqueeze(-1) # 1, num, s*s, 1 383 | faceid_mask = faceid_mask / faceid_mask.sum(1, keepdim=True).clip(min=1e-3) # 1, num, s*s, 1 384 | faceid_hidden_states = (faceid_mask * faceid_hidden_states).sum(1) # 2*bs, s*s, dim 385 | else: 386 | faceid_hidden_states = (routing_map * faceid_hidden_states).sum(1) # 2*bs, s*s, dim 387 | 388 | if attn_args.enable_share_attn: 389 | faceid_hidden_states[0] = 0. 390 | faceid_hidden_states[batch_size // 2] = 0. 
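# As with the IP branch above, the shared-attention style anchor receives no face-ID injection.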
391 | else: 392 | faceid_hidden_states = torch.zeros_like(hidden_states) 393 | 394 | hidden_states = hidden_states + \ 395 | attn_args.ip_scale * ip_hidden_states + \ 396 | attn_args.faceid_scale * faceid_hidden_states 397 | 398 | # linear proj 399 | output_hidden_states = attn.to_out[0](hidden_states) 400 | if attn_args.lora_scale > 0.0: 401 | output_hidden_states = output_hidden_states + attn_args.lora_scale * self.to_out_lora(hidden_states) 402 | elif attn_args.multi_id_lora_scale > 0.0: 403 | output_hidden_states = output_hidden_states + attn_args.multi_id_lora_scale * self.to_out_multi_id_lora( 404 | hidden_states) 405 | hidden_states = output_hidden_states 406 | 407 | # dropout 408 | hidden_states = attn.to_out[1](hidden_states) 409 | 410 | if input_ndim == 4: 411 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 412 | 413 | if attn.residual_connection: 414 | hidden_states = hidden_states + residual 415 | 416 | hidden_states = hidden_states / attn.rescale_output_factor 417 | 418 | return hidden_states 419 | 420 | 421 | # for controlnet 422 | class UniPortraitCNAttnProcessor2_0: 423 | def __init__(self, num_ip_tokens=4, num_faceid_tokens=16): 424 | 425 | self.num_ip_tokens = num_ip_tokens 426 | self.num_faceid_tokens = num_faceid_tokens 427 | 428 | def __call__( 429 | self, 430 | attn, 431 | hidden_states, 432 | encoder_hidden_states=None, 433 | attention_mask=None, 434 | temb=None, 435 | *args, 436 | **kwargs, 437 | ): 438 | residual = hidden_states 439 | 440 | if attn.spatial_norm is not None: 441 | hidden_states = attn.spatial_norm(hidden_states, temb) 442 | 443 | input_ndim = hidden_states.ndim 444 | 445 | if input_ndim == 4: 446 | batch_size, channel, height, width = hidden_states.shape 447 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 448 | 449 | batch_size, sequence_length, _ = ( 450 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 451 | ) 452 | if attention_mask is not None: 453 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 454 | # scaled_dot_product_attention expects attention_mask shape to be 455 | # (batch, heads, source_length, target_length) 456 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 457 | 458 | if attn.group_norm is not None: 459 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 460 | 461 | query = attn.to_q(hidden_states) 462 | 463 | if encoder_hidden_states is None: 464 | encoder_hidden_states = hidden_states 465 | else: 466 | text_end = encoder_hidden_states.shape[1] - self.num_faceid_tokens * attn_args.num_faceids \ 467 | - self.num_ip_tokens 468 | encoder_hidden_states = encoder_hidden_states[:, :text_end] # only use text 469 | if attn.norm_cross: 470 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 471 | 472 | key = attn.to_k(encoder_hidden_states) 473 | value = attn.to_v(encoder_hidden_states) 474 | 475 | inner_dim = key.shape[-1] 476 | head_dim = inner_dim // attn.heads 477 | 478 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 479 | 480 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 481 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 482 | 483 | hidden_states = F.scaled_dot_product_attention( 484 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale 485 | ) 486 
| hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 487 | hidden_states = hidden_states.to(query.dtype) 488 | 489 | # linear proj 490 | hidden_states = attn.to_out[0](hidden_states) 491 | # dropout 492 | hidden_states = attn.to_out[1](hidden_states) 493 | 494 | if input_ndim == 4: 495 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 496 | 497 | if attn.residual_connection: 498 | hidden_states = hidden_states + residual 499 | 500 | hidden_states = hidden_states / attn.rescale_output_factor 501 | 502 | return hidden_states 503 | -------------------------------------------------------------------------------- /uniportrait/uniportrait_pipeline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from diffusers import ControlNetModel 5 | from diffusers.pipelines.controlnet import MultiControlNetModel 6 | from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection 7 | 8 | from .curricular_face.backbone import get_model 9 | from .resampler import UniPortraitFaceIDResampler 10 | from .uniportrait_attention_processor import UniPortraitCNAttnProcessor2_0 as UniPortraitCNAttnProcessor 11 | from .uniportrait_attention_processor import UniPortraitLoRAAttnProcessor2_0 as UniPortraitLoRAAttnProcessor 12 | from .uniportrait_attention_processor import UniPortraitLoRAIPAttnProcessor2_0 as UniPortraitLoRAIPAttnProcessor 13 | 14 | 15 | class ImageProjModel(nn.Module): 16 | """Projection Model""" 17 | 18 | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): 19 | super().__init__() 20 | 21 | self.cross_attention_dim = cross_attention_dim 22 | self.clip_extra_context_tokens = clip_extra_context_tokens 23 | self.proj = nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) 24 | self.norm = nn.LayerNorm(cross_attention_dim) 25 | 26 | def forward(self, image_embeds): 27 | embeds = image_embeds # b, c 28 | clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, 29 | self.cross_attention_dim) 30 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens) 31 | return clip_extra_context_tokens 32 | 33 | 34 | class UniPortraitPipeline: 35 | 36 | def __init__(self, sd_pipe, image_encoder_path, ip_ckpt=None, face_backbone_ckpt=None, uniportrait_faceid_ckpt=None, 37 | uniportrait_router_ckpt=None, num_ip_tokens=4, num_faceid_tokens=16, 38 | lora_rank=128, device=torch.device("cuda"), torch_dtype=torch.float16): 39 | 40 | self.image_encoder_path = image_encoder_path 41 | self.ip_ckpt = ip_ckpt 42 | self.uniportrait_faceid_ckpt = uniportrait_faceid_ckpt 43 | self.uniportrait_router_ckpt = uniportrait_router_ckpt 44 | 45 | self.num_ip_tokens = num_ip_tokens 46 | self.num_faceid_tokens = num_faceid_tokens 47 | self.lora_rank = lora_rank 48 | 49 | self.device = device 50 | self.torch_dtype = torch_dtype 51 | 52 | self.pipe = sd_pipe.to(self.device) 53 | 54 | # load clip image encoder 55 | self.clip_image_processor = CLIPImageProcessor(size={"shortest_edge": 224}, do_center_crop=False, 56 | use_square_size=True) 57 | self.clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( 58 | self.device, dtype=self.torch_dtype) 59 | # load face backbone 60 | self.facerecog_model = get_model("IR_101")([112, 112]) 61 | 
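# The CurricularFace IR-101 backbone doubles as the identity encoder: its final 7x7x512
# feature map supplies the intrinsic-ID embedding, and its three earlier stages supply the
# multi-scale structure features consumed by UniPortraitFaceIDResampler
# (see get_single_faceid_embeds).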
self.facerecog_model.load_state_dict(torch.load(face_backbone_ckpt, map_location="cpu")) 62 | self.facerecog_model = self.facerecog_model.to(self.device, dtype=torch_dtype) 63 | self.facerecog_model.eval() 64 | # image proj model 65 | self.image_proj_model = self.init_image_proj() 66 | # faceid proj model 67 | self.faceid_proj_model = self.init_faceid_proj() 68 | # set uniportrait and ip adapter 69 | self.set_uniportrait_and_ip_adapter() 70 | # load uniportrait and ip adapter 71 | self.load_uniportrait_and_ip_adapter() 72 | 73 | def init_image_proj(self): 74 | image_proj_model = ImageProjModel( 75 | cross_attention_dim=self.pipe.unet.config.cross_attention_dim, 76 | clip_embeddings_dim=self.clip_image_encoder.config.projection_dim, 77 | clip_extra_context_tokens=self.num_ip_tokens, 78 | ).to(self.device, dtype=self.torch_dtype) 79 | return image_proj_model 80 | 81 | def init_faceid_proj(self): 82 | faceid_proj_model = UniPortraitFaceIDResampler( 83 | intrinsic_id_embedding_dim=512, 84 | structure_embedding_dim=64 + 128 + 256 + self.clip_image_encoder.config.hidden_size, 85 | num_tokens=16, depth=6, 86 | dim=self.pipe.unet.config.cross_attention_dim, dim_head=64, 87 | heads=12, ff_mult=4, 88 | output_dim=self.pipe.unet.config.cross_attention_dim 89 | ).to(self.device, dtype=self.torch_dtype) 90 | return faceid_proj_model 91 | 92 | def set_uniportrait_and_ip_adapter(self): 93 | unet = self.pipe.unet 94 | attn_procs = {} 95 | for name in unet.attn_processors.keys(): 96 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim 97 | if name.startswith("mid_block"): 98 | hidden_size = unet.config.block_out_channels[-1] 99 | elif name.startswith("up_blocks"): 100 | block_id = int(name[len("up_blocks.")]) 101 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] 102 | elif name.startswith("down_blocks"): 103 | block_id = int(name[len("down_blocks.")]) 104 | hidden_size = unet.config.block_out_channels[block_id] 105 | if cross_attention_dim is None: 106 | attn_procs[name] = UniPortraitLoRAAttnProcessor( 107 | hidden_size=hidden_size, 108 | cross_attention_dim=cross_attention_dim, 109 | rank=self.lora_rank, 110 | ).to(self.device, dtype=self.torch_dtype).eval() 111 | else: 112 | attn_procs[name] = UniPortraitLoRAIPAttnProcessor( 113 | hidden_size=hidden_size, 114 | cross_attention_dim=cross_attention_dim, 115 | rank=self.lora_rank, 116 | num_ip_tokens=self.num_ip_tokens, 117 | num_faceid_tokens=self.num_faceid_tokens, 118 | ).to(self.device, dtype=self.torch_dtype).eval() 119 | unet.set_attn_processor(attn_procs) 120 | if hasattr(self.pipe, "controlnet"): 121 | if isinstance(self.pipe.controlnet, ControlNetModel): 122 | self.pipe.controlnet.set_attn_processor( 123 | UniPortraitCNAttnProcessor( 124 | num_ip_tokens=self.num_ip_tokens, 125 | num_faceid_tokens=self.num_faceid_tokens, 126 | ) 127 | ) 128 | elif isinstance(self.pipe.controlnet, MultiControlNetModel): 129 | for module in self.pipe.controlnet.nets: 130 | module.set_attn_processor( 131 | UniPortraitCNAttnProcessor( 132 | num_ip_tokens=self.num_ip_tokens, 133 | num_faceid_tokens=self.num_faceid_tokens, 134 | ) 135 | ) 136 | else: 137 | raise ValueError 138 | 139 | def load_uniportrait_and_ip_adapter(self): 140 | if self.ip_ckpt: 141 | print(f"loading from {self.ip_ckpt}...") 142 | state_dict = torch.load(self.ip_ckpt, map_location="cpu") 143 | self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=False) 144 | ip_layers = 
nn.ModuleList(self.pipe.unet.attn_processors.values()) 145 | ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False) 146 | 147 | if self.uniportrait_faceid_ckpt: 148 | print(f"loading from {self.uniportrait_faceid_ckpt}...") 149 | state_dict = torch.load(self.uniportrait_faceid_ckpt, map_location="cpu") 150 | self.faceid_proj_model.load_state_dict(state_dict["faceid_proj"], strict=True) 151 | ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) 152 | ip_layers.load_state_dict(state_dict["faceid_adapter"], strict=False) 153 | 154 | if self.uniportrait_router_ckpt: 155 | print(f"loading from {self.uniportrait_router_ckpt}...") 156 | state_dict = torch.load(self.uniportrait_router_ckpt, map_location="cpu") 157 | router_state_dict = {} 158 | for k, v in state_dict["faceid_adapter"].items(): 159 | if "lora." in k: 160 | router_state_dict[k.replace("lora.", "multi_id_lora.")] = v 161 | elif "router." in k: 162 | router_state_dict[k] = v 163 | ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) 164 | ip_layers.load_state_dict(router_state_dict, strict=False) 165 | 166 | @torch.inference_mode() 167 | def get_ip_embeds(self, pil_ip_image): 168 | ip_image = self.clip_image_processor(images=pil_ip_image, return_tensors="pt").pixel_values 169 | ip_image = ip_image.to(self.device, dtype=self.torch_dtype) # (b, 3, 224, 224), values being normalized 170 | ip_embeds = self.clip_image_encoder(ip_image).image_embeds 171 | ip_prompt_embeds = self.image_proj_model(ip_embeds) 172 | uncond_ip_prompt_embeds = self.image_proj_model(torch.zeros_like(ip_embeds)) 173 | return ip_prompt_embeds, uncond_ip_prompt_embeds 174 | 175 | @torch.inference_mode() 176 | def get_single_faceid_embeds(self, pil_face_images, face_structure_scale): 177 | face_clip_image = self.clip_image_processor(images=pil_face_images, return_tensors="pt").pixel_values 178 | face_clip_image = face_clip_image.to(self.device, dtype=self.torch_dtype) # (b, 3, 224, 224) 179 | face_clip_embeds = self.clip_image_encoder( 180 | face_clip_image, output_hidden_states=True).hidden_states[-2][:, 1:] # b, 256, 1280 181 | 182 | OPENAI_CLIP_MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=self.device, 183 | dtype=self.torch_dtype).reshape(-1, 1, 1) 184 | OPENAI_CLIP_STD = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=self.device, 185 | dtype=self.torch_dtype).reshape(-1, 1, 1) 186 | facerecog_image = face_clip_image * OPENAI_CLIP_STD + OPENAI_CLIP_MEAN # [0, 1] 187 | facerecog_image = torch.clamp((facerecog_image - 0.5) / 0.5, -1, 1) # [-1, 1] 188 | facerecog_image = F.interpolate(facerecog_image, size=(112, 112), mode="bilinear", align_corners=False) 189 | facerecog_embeds = self.facerecog_model(facerecog_image, return_mid_feats=True)[1] 190 | 191 | face_intrinsic_id_embeds = facerecog_embeds[-1] # (b, 512, 7, 7) 192 | face_intrinsic_id_embeds = face_intrinsic_id_embeds.flatten(2).permute(0, 2, 1) # b, 49, 512 193 | 194 | facerecog_structure_embeds = facerecog_embeds[:-1] # (b, 64, 56, 56), (b, 128, 28, 28), (b, 256, 14, 14) 195 | facerecog_structure_embeds = torch.cat([ 196 | F.interpolate(feat, size=(16, 16), mode="bilinear", align_corners=False) 197 | for feat in facerecog_structure_embeds], dim=1) # b, 448, 16, 16 198 | facerecog_structure_embeds = facerecog_structure_embeds.flatten(2).permute(0, 2, 1) # b, 256, 448 199 | face_structure_embeds = torch.cat([facerecog_structure_embeds, face_clip_embeds], dim=-1) # b, 256, 1728 200 | 201 | uncond_face_clip_embeds = 
self.clip_image_encoder( 202 | torch.zeros_like(face_clip_image[:1]), output_hidden_states=True).hidden_states[-2][:, 1:] # 1, 256, 1280 203 | uncond_face_structure_embeds = torch.cat( 204 | [torch.zeros_like(facerecog_structure_embeds[:1]), uncond_face_clip_embeds], dim=-1) # 1, 256, 1728 205 | 206 | faceid_prompt_embeds = self.faceid_proj_model( 207 | face_intrinsic_id_embeds.flatten(0, 1).unsqueeze(0), 208 | face_structure_embeds.flatten(0, 1).unsqueeze(0), 209 | structure_scale=face_structure_scale, 210 | ) # [b, 16, 768] 211 | 212 | uncond_faceid_prompt_embeds = self.faceid_proj_model( 213 | torch.zeros_like(face_intrinsic_id_embeds[:1]), 214 | uncond_face_structure_embeds, 215 | structure_scale=face_structure_scale, 216 | ) # [1, 16, 768] 217 | 218 | return faceid_prompt_embeds, uncond_faceid_prompt_embeds 219 | 220 | def generate( 221 | self, 222 | prompt=None, 223 | negative_prompt=None, 224 | pil_ip_image=None, 225 | cond_faceids=None, 226 | face_structure_scale=0.0, 227 | seed=-1, 228 | guidance_scale=7.5, 229 | num_inference_steps=30, 230 | zT=None, 231 | **kwargs, 232 | ): 233 | """ 234 | Args: 235 | prompt: 236 | negative_prompt: 237 | pil_ip_image: 238 | cond_faceids: [ 239 | { 240 | "refs": [PIL.Image] or PIL.Image, 241 | (Optional) "mix_refs": [PIL.Image], 242 | (Optional) "mix_scales": [float], 243 | }, 244 | ... 245 | ] 246 | face_structure_scale: 247 | seed: 248 | guidance_scale: 249 | num_inference_steps: 250 | zT: 251 | **kwargs: 252 | Returns: 253 | """ 254 | 255 | if seed is not None: 256 | torch.manual_seed(seed) 257 | torch.cuda.manual_seed_all(seed) 258 | 259 | with torch.inference_mode(): 260 | prompt_embeds, negative_prompt_embeds = self.pipe.encode_prompt( 261 | prompt, device=self.device, num_images_per_prompt=1, do_classifier_free_guidance=True, 262 | negative_prompt=negative_prompt) 263 | num_prompts = prompt_embeds.shape[0] 264 | 265 | if pil_ip_image is not None: 266 | ip_prompt_embeds, uncond_ip_prompt_embeds = self.get_ip_embeds(pil_ip_image) 267 | ip_prompt_embeds = ip_prompt_embeds.repeat(num_prompts, 1, 1) 268 | uncond_ip_prompt_embeds = uncond_ip_prompt_embeds.repeat(num_prompts, 1, 1) 269 | else: 270 | ip_prompt_embeds = uncond_ip_prompt_embeds = \ 271 | torch.zeros_like(prompt_embeds[:, :1]).repeat(1, self.num_ip_tokens, 1) 272 | 273 | prompt_embeds = torch.cat([prompt_embeds, ip_prompt_embeds], dim=1) 274 | negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_ip_prompt_embeds], dim=1) 275 | 276 | if cond_faceids and len(cond_faceids) > 0: 277 | all_faceid_prompt_embeds = [] 278 | all_uncond_faceid_prompt_embeds = [] 279 | for curr_faceid_info in cond_faceids: 280 | refs = curr_faceid_info["refs"] 281 | faceid_prompt_embeds, uncond_faceid_prompt_embeds = \ 282 | self.get_single_faceid_embeds(refs, face_structure_scale) 283 | if "mix_refs" in curr_faceid_info: 284 | mix_refs = curr_faceid_info["mix_refs"] 285 | mix_scales = curr_faceid_info["mix_scales"] 286 | 287 | master_face_mix_scale = 1.0 - sum(mix_scales) 288 | faceid_prompt_embeds = faceid_prompt_embeds * master_face_mix_scale 289 | for mix_ref, mix_scale in zip(mix_refs, mix_scales): 290 | faceid_mix_prompt_embeds, _ = self.get_single_faceid_embeds(mix_ref, face_structure_scale) 291 | faceid_prompt_embeds = faceid_prompt_embeds + faceid_mix_prompt_embeds * mix_scale 292 | 293 | all_faceid_prompt_embeds.append(faceid_prompt_embeds) 294 | all_uncond_faceid_prompt_embeds.append(uncond_faceid_prompt_embeds) 295 | 296 | faceid_prompt_embeds = torch.cat(all_faceid_prompt_embeds, 
dim=1) 297 | uncond_faceid_prompt_embeds = torch.cat(all_uncond_faceid_prompt_embeds, dim=1) 298 | faceid_prompt_embeds = faceid_prompt_embeds.repeat(num_prompts, 1, 1) 299 | uncond_faceid_prompt_embeds = uncond_faceid_prompt_embeds.repeat(num_prompts, 1, 1) 300 | 301 | prompt_embeds = torch.cat([prompt_embeds, faceid_prompt_embeds], dim=1) 302 | negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_faceid_prompt_embeds], dim=1) 303 | 304 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None 305 | if zT is not None: 306 | h_, w_ = kwargs["image"][0].shape[-2:] 307 | latents = torch.randn(num_prompts, 4, h_ // 8, w_ // 8, device=self.device, generator=generator, 308 | dtype=self.pipe.unet.dtype) 309 | latents[0] = zT 310 | else: 311 | latents = None 312 | 313 | images = self.pipe( 314 | prompt_embeds=prompt_embeds, 315 | negative_prompt_embeds=negative_prompt_embeds, 316 | guidance_scale=guidance_scale, 317 | num_inference_steps=num_inference_steps, 318 | generator=generator, 319 | latents=latents, 320 | **kwargs, 321 | ).images 322 | 323 | return images 324 | --------------------------------------------------------------------------------
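A minimal usage sketch of UniPortraitPipeline for single-identity generation, assuming a Stable Diffusion 1.5 base model and locally downloaded checkpoints. The constructor and generate() arguments follow uniportrait_pipeline.py above; the base-model ID, checkpoint paths, attention scale values, and example image path are illustrative assumptions rather than values prescribed by this repository (gradio_app.py remains the authoritative entry point).

import torch
from PIL import Image
from diffusers import DDIMScheduler, StableDiffusionPipeline

from uniportrait.uniportrait_attention_processor import attn_args
from uniportrait.uniportrait_pipeline import UniPortraitPipeline

# Base text-to-image pipeline (model ID and scheduler choice are assumptions).
sd_pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    safety_checker=None,
)
sd_pipe.scheduler = DDIMScheduler.from_config(sd_pipe.scheduler.config)

# Checkpoint paths are placeholders; point them at the downloaded weights.
uniportrait = UniPortraitPipeline(
    sd_pipe,
    image_encoder_path="models/image_encoder",
    ip_ckpt="models/ip-adapter_sd15.bin",
    face_backbone_ckpt="models/glint360k_curricular_face_r101_backbone.bin",
    uniportrait_faceid_ckpt="models/uniportrait-faceid_sd15.bin",
    uniportrait_router_ckpt="models/uniportrait-router_sd15.bin",
    device=torch.device("cuda"),
    torch_dtype=torch.float16,
)

# Configure the global attention arguments for a single identity; the scale values here
# are illustrative, not tuned recommendations.
attn_args.reset()
attn_args.lora_scale = 1.0    # enable the single-faceid LoRA branch
attn_args.faceid_scale = 1.0  # strength of the 16 face-ID tokens
attn_args.num_faceids = 1     # one identity, hence one group of face-ID tokens

face_ref = Image.open("face_reference.jpg").convert("RGB")  # any reasonably tight face crop
images = uniportrait.generate(
    prompt="a person reading a book in a sunlit cafe, best quality",
    negative_prompt="lowres, bad anatomy, blurry",
    cond_faceids=[{"refs": [face_ref]}],
    face_structure_scale=0.3,  # 0 keeps only intrinsic ID; larger values copy more facial structure
    seed=42,
    guidance_scale=7.5,
    num_inference_steps=30,
    width=512,
    height=512,  # forwarded to the underlying StableDiffusionPipeline via **kwargs
)
images[0].save("uniportrait_output.png")

For multiple identities, one would presumably pass several entries in cond_faceids and drive the multi-ID branch through attn_args.multi_id_lora_scale instead of lora_scale; that reading is inferred from the AttentionArgs fields and attention processors above, not from a documented recipe.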