├── .gitignore
├── LICENSE.txt
├── README.md
├── assets
│   ├── application.png
│   ├── examples
│   │   ├── 1-newton.jpg
│   │   ├── 1-output-1.png
│   │   ├── 2-output-1.png
│   │   ├── 2-stylegan2-ffhq-0100.png
│   │   ├── 2-stylegan2-ffhq-0293.png
│   │   ├── 3-output-1.png
│   │   ├── 3-output-2.png
│   │   ├── 3-output-3.png
│   │   ├── 3-output-4.png
│   │   ├── 3-style-1.png
│   │   ├── 3-style-2.jpg
│   │   ├── 3-style-3.jpg
│   │   ├── 3-stylegan2-ffhq-0293.png
│   │   └── 3-stylegan2-ffhq-0381.png
│   ├── framework.png
│   └── highlight.png
├── gradio_app.py
├── requirements.txt
└── uniportrait
    ├── __init__.py
    ├── curricular_face
    │   ├── __init__.py
    │   ├── backbone
    │   │   ├── __init__.py
    │   │   ├── common.py
    │   │   ├── model_irse.py
    │   │   └── model_resnet.py
    │   └── inference.py
    ├── inversion.py
    ├── resampler.py
    ├── uniportrait_attention_processor.py
    └── uniportrait_pipeline.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | .DS_Store
3 | *.dat
4 | *.mat
5 |
6 | training/
7 | lightning_logs/
8 | image_log/
9 |
10 | *.png
11 | *.jpg
12 | *.jpeg
13 | *.webp
14 |
15 | *.pth
16 | *.pt
17 | *.ckpt
18 | *.safetensors
19 |
20 | # Byte-compiled / optimized / DLL files
21 | __pycache__/
22 | *.py[cod]
23 | *$py.class
24 |
25 | # C extensions
26 | *.so
27 |
28 | # Distribution / packaging
29 | .Python
30 | build/
31 | develop-eggs/
32 | dist/
33 | downloads/
34 | eggs/
35 | .eggs/
36 | lib/
37 | lib64/
38 | parts/
39 | sdist/
40 | var/
41 | wheels/
42 | pip-wheel-metadata/
43 | share/python-wheels/
44 | *.egg-info/
45 | .installed.cfg
46 | *.egg
47 | MANIFEST
48 |
49 | # PyInstaller
50 | # Usually these files are written by a python script from a template
51 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
52 | *.manifest
53 | *.spec
54 |
55 | # Installer logs
56 | pip-log.txt
57 | pip-delete-this-directory.txt
58 |
59 | # Unit test / coverage reports
60 | htmlcov/
61 | .tox/
62 | .nox/
63 | .coverage
64 | .coverage.*
65 | .cache
66 | nosetests.xml
67 | coverage.xml
68 | *.cover
69 | *.py,cover
70 | .hypothesis/
71 | .pytest_cache/
72 |
73 | # Translations
74 | *.mo
75 | *.pot
76 |
77 | # Django stuff:
78 | *.log
79 | local_settings.py
80 | db.sqlite3
81 | db.sqlite3-journal
82 |
83 | # Flask stuff:
84 | instance/
85 | .webassets-cache
86 |
87 | # Scrapy stuff:
88 | .scrapy
89 |
90 | # Sphinx documentation
91 | docs/_build/
92 |
93 | # PyBuilder
94 | target/
95 |
96 | # Jupyter Notebook
97 | .ipynb_checkpoints
98 |
99 | # IPython
100 | profile_default/
101 | ipython_config.py
102 |
103 | # pyenv
104 | .python-version
105 |
106 | # pipenv
107 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
108 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
109 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
110 | # install all needed dependencies.
111 | #Pipfile.lock
112 |
113 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
114 | __pypackages__/
115 |
116 | # Celery stuff
117 | celerybeat-schedule
118 | celerybeat.pid
119 |
120 | # SageMath parsed files
121 | *.sage.py
122 |
123 | # Environments
124 | .env
125 | .venv
126 | env/
127 | venv/
128 | ENV/
129 | env.bak/
130 | venv.bak/
131 |
132 | # Spyder project settings
133 | .spyderproject
134 | .spyproject
135 |
136 | # Rope project settings
137 | .ropeproject
138 |
139 | # mkdocs documentation
140 | /site
141 |
142 | # mypy
143 | .mypy_cache/
144 | .dmypy.json
145 | dmypy.json
146 |
147 | # Pyre type checker
148 | .pyre/
149 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # UniPortrait: A Unified Framework for Identity-Preserving Single- and Multi-Human Image Personalization
2 |
12 | UniPortrait is an innovative human image personalization framework. It customizes single- and multi-ID images in a
13 | unified manner, providing high-fidelity identity preservation, extensive facial editability, free-form text descriptions,
14 | and no requirement for a predetermined layout.
15 |
16 | ---
17 |
18 | ## Release
19 |
20 | - [2025/05/01] 🔥 We release the code and demo for the `FLUX.1-dev` version of [AnyStory](https://github.com/junjiehe96/AnyStory), a unified approach to general subject personalization.
21 | - [2024/10/18] 🔥 We release the inference code and demo, which integrate
22 | [ControlNet](https://github.com/lllyasviel/ControlNet),
23 | [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter),
24 | and [StyleAligned](https://github.com/google/style-aligned). The weights for this version are consistent with the
25 | HuggingFace space and the experiments in the paper. We are now working on generalizing our method to more advanced
26 | diffusion models and more general custom concepts. Please stay tuned!
27 | - [2024/08/12] 🔥 We release the [technical report](https://arxiv.org/abs/2408.05939)
28 | , [project page](https://aigcdesigngroup.github.io/UniPortrait-Page/),
29 | and [HuggingFace demo](https://huggingface.co/spaces/Junjie96/UniPortrait) 🤗!
30 |
31 | ## Quickstart
32 |
33 | ```shell
34 | # Clone repository
35 | git clone https://github.com/junjiehe96/UniPortrait.git
36 |
37 | # install requirements
38 | cd UniPortrait
39 | pip install -r requirements.txt
40 |
41 | # download the models
42 | git lfs install
43 | git clone https://huggingface.co/Junjie96/UniPortrait models
44 | # download ip-adapter models
45 | # Note: we recommend downloading manually, since not all IP-Adapter models are required.
46 | git clone https://huggingface.co/h94/IP-Adapter models/IP-Adapter
47 |
48 | # then you can use the gradio app
49 | python gradio_app.py
50 | ```
51 |
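For programmatic use outside the Gradio demo, the sketch below mirrors how `gradio_app.py` wires the components together. Treat it as an illustration rather than a supported API: it assumes the checkpoints were downloaded to `models/` as above, that a CUDA GPU is available, that a face is detected in the ID image, and it reuses the same insightface alignment and `attn_args` setup as the demo.

```python
import cv2
import numpy as np
import torch
from PIL import Image
from diffusers import AutoencoderKL, ControlNetModel, DDIMScheduler, StableDiffusionControlNetPipeline
from insightface.app import FaceAnalysis
from insightface.utils import face_align

from uniportrait.uniportrait_attention_processor import attn_args
from uniportrait.uniportrait_pipeline import UniPortraitPipeline

torch_dtype = torch.float16

# Stable Diffusion v1-5 base + OpenPose ControlNet, assembled as in gradio_app.py
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch_dtype)
scheduler = DDIMScheduler(num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012,
                          beta_schedule="scaled_linear", clip_sample=False,
                          set_alpha_to_one=False, steps_offset=1)
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(dtype=torch_dtype)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE", controlnet=[controlnet],
    torch_dtype=torch_dtype, scheduler=scheduler, vae=vae)

uniportrait = UniPortraitPipeline(
    pipe, "models/IP-Adapter/models/image_encoder",
    ip_ckpt="models/IP-Adapter/models/ip-adapter_sd15.bin",
    face_backbone_ckpt="models/glint360k_curricular_face_r101_backbone.bin",
    uniportrait_faceid_ckpt="models/uniportrait-faceid_sd15.bin",
    uniportrait_router_ckpt="models/uniportrait-router_sd15.bin",
    device="cuda", torch_dtype=torch_dtype)

# Detect and align the ID face into a 224x224 crop, the same preprocessing the demo applies.
face_app = FaceAnalysis(providers=["CUDAExecutionProvider"], allowed_modules=["detection"])
face_app.prepare(ctx_id=0, det_size=(640, 640))
bgr = cv2.cvtColor(np.array(Image.open("assets/examples/1-newton.jpg").convert("RGB")), cv2.COLOR_RGB2BGR)
face = sorted(face_app.get(bgr), key=lambda f: (f.bbox[2] - f.bbox[0]) * (f.bbox[3] - f.bbox[1]))[-1]
aligned_face = Image.fromarray(
    cv2.cvtColor(face_align.norm_crop(bgr, landmark=face.kps, image_size=224), cv2.COLOR_BGR2RGB))

# Configure the global attention arguments for a single-ID generation (demo defaults).
attn_args.reset()
attn_args.lora_scale = 1.0   # single-ID LoRA
attn_args.faceid_scale = 0.7
attn_args.num_faceids = 1

images = uniportrait.generate(
    prompt=["a man reading a book in a cafe"], negative_prompt=["nsfw"],
    cond_faceids=[{"refs": [aligned_face]}], face_structure_scale=0.1,
    seed=42, guidance_scale=7.5, num_inference_steps=25,
    image=[torch.zeros([1, 3, 512, 512])],  # dummy pose map; ControlNet scale is set to 0
    controlnet_conditioning_scale=[0.0])
images[0].save("output.png")
```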
52 | ## Applications
53 |
54 | ![application](assets/application.png)
55 |
56 | ## **Acknowledgements**
57 |
58 | This code is built on some excellent repos, including [diffusers](https://github.com/huggingface/diffusers), [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter) and [StyleAligned](https://github.com/google/style-aligned). Highly appreciate their great work!
59 |
60 | ## Cite
61 |
62 | If you find UniPortrait useful for your research and applications, please cite us using this BibTeX:
63 |
64 | ```bibtex
65 | @article{he2024uniportrait,
66 | title={UniPortrait: A Unified Framework for Identity-Preserving Single-and Multi-Human Image Personalization},
67 | author={He, Junjie and Geng, Yifeng and Bo, Liefeng},
68 | journal={arXiv preprint arXiv:2408.05939},
69 | year={2024}
70 | }
71 | ```
72 |
73 | For any questions, please feel free to open an issue or contact us via hejunjie1103@gmail.com.
74 |
--------------------------------------------------------------------------------
/assets/application.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/assets/application.png
--------------------------------------------------------------------------------
/assets/examples/1-newton.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/assets/examples/1-newton.jpg
--------------------------------------------------------------------------------
/assets/examples/1-output-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/assets/examples/1-output-1.png
--------------------------------------------------------------------------------
/assets/examples/2-output-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/assets/examples/2-output-1.png
--------------------------------------------------------------------------------
/assets/examples/2-stylegan2-ffhq-0100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/assets/examples/2-stylegan2-ffhq-0100.png
--------------------------------------------------------------------------------
/assets/examples/2-stylegan2-ffhq-0293.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/assets/examples/2-stylegan2-ffhq-0293.png
--------------------------------------------------------------------------------
/assets/examples/3-output-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/assets/examples/3-output-1.png
--------------------------------------------------------------------------------
/assets/examples/3-output-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/assets/examples/3-output-2.png
--------------------------------------------------------------------------------
/assets/examples/3-output-3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/assets/examples/3-output-3.png
--------------------------------------------------------------------------------
/assets/examples/3-output-4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/assets/examples/3-output-4.png
--------------------------------------------------------------------------------
/assets/examples/3-style-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/assets/examples/3-style-1.png
--------------------------------------------------------------------------------
/assets/examples/3-style-2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/assets/examples/3-style-2.jpg
--------------------------------------------------------------------------------
/assets/examples/3-style-3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/assets/examples/3-style-3.jpg
--------------------------------------------------------------------------------
/assets/examples/3-stylegan2-ffhq-0293.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/assets/examples/3-stylegan2-ffhq-0293.png
--------------------------------------------------------------------------------
/assets/examples/3-stylegan2-ffhq-0381.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/assets/examples/3-stylegan2-ffhq-0381.png
--------------------------------------------------------------------------------
/assets/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/assets/framework.png
--------------------------------------------------------------------------------
/assets/highlight.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/assets/highlight.png
--------------------------------------------------------------------------------
/gradio_app.py:
--------------------------------------------------------------------------------
1 | import os
2 | from io import BytesIO
3 |
4 | import cv2
5 | import gradio as gr
6 | import numpy as np
7 | import torch
8 | from PIL import Image
9 | from diffusers import DDIMScheduler, AutoencoderKL, ControlNetModel, StableDiffusionControlNetPipeline
10 | from insightface.app import FaceAnalysis
11 | from insightface.utils import face_align
12 |
13 | from uniportrait import inversion
14 | from uniportrait.uniportrait_attention_processor import attn_args
15 | from uniportrait.uniportrait_pipeline import UniPortraitPipeline
16 |
17 | port = 7860
18 |
19 | device = "cuda"
20 | torch_dtype = torch.float16
21 |
22 | # base
23 | base_model_path = "SG161222/Realistic_Vision_V5.1_noVAE"
24 | vae_model_path = "stabilityai/sd-vae-ft-mse"
25 | controlnet_pose_ckpt = "lllyasviel/control_v11p_sd15_openpose"
26 | # specific
27 | image_encoder_path = "models/IP-Adapter/models/image_encoder"
28 | ip_ckpt = "models/IP-Adapter/models/ip-adapter_sd15.bin"
29 | face_backbone_ckpt = "models/glint360k_curricular_face_r101_backbone.bin"
30 | uniportrait_faceid_ckpt = "models/uniportrait-faceid_sd15.bin"
31 | uniportrait_router_ckpt = "models/uniportrait-router_sd15.bin"
32 |
33 | # load controlnet
34 | pose_controlnet = ControlNetModel.from_pretrained(controlnet_pose_ckpt, torch_dtype=torch_dtype)
35 |
36 | # load SD pipeline
37 | noise_scheduler = DDIMScheduler(
38 | num_train_timesteps=1000,
39 | beta_start=0.00085,
40 | beta_end=0.012,
41 | beta_schedule="scaled_linear",
42 | clip_sample=False,
43 | set_alpha_to_one=False,
44 | steps_offset=1,
45 | )
46 | vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch_dtype)
47 | pipe = StableDiffusionControlNetPipeline.from_pretrained(
48 | base_model_path,
49 | controlnet=[pose_controlnet],
50 | torch_dtype=torch_dtype,
51 | scheduler=noise_scheduler,
52 | vae=vae,
53 | # feature_extractor=None,
54 | # safety_checker=None,
55 | )
56 |
57 | # load uniportrait pipeline
58 | uniportrait_pipeline = UniPortraitPipeline(pipe, image_encoder_path, ip_ckpt=ip_ckpt,
59 | face_backbone_ckpt=face_backbone_ckpt,
60 | uniportrait_faceid_ckpt=uniportrait_faceid_ckpt,
61 | uniportrait_router_ckpt=uniportrait_router_ckpt,
62 | device=device, torch_dtype=torch_dtype)
63 |
64 | # load face detection assets
65 | face_app = FaceAnalysis(providers=['CUDAExecutionProvider'], allowed_modules=["detection"])
66 | face_app.prepare(ctx_id=0, det_size=(640, 640))
67 |
68 |
69 | def pad_np_bgr_image(np_image, scale=1.25):
70 | assert scale >= 1.0, "scale should be >= 1.0"
71 | pad_scale = scale - 1.0
72 | h, w = np_image.shape[:2]
73 | top = bottom = int(h * pad_scale)
74 | left = right = int(w * pad_scale)
75 | ret = cv2.copyMakeBorder(np_image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(128, 128, 128))
76 | return ret, (left, top)
77 |
78 |
79 | def process_faceid_image(pil_faceid_image):
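# Detect faces in the input PIL image (retrying once on a gray-padded copy if none are found),
# keep the largest detection, and return a 224x224 insightface-aligned RGB crop of it.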
80 | np_faceid_image = np.array(pil_faceid_image.convert("RGB"))
81 | img = cv2.cvtColor(np_faceid_image, cv2.COLOR_RGB2BGR)
82 | faces = face_app.get(img) # bgr
83 | if len(faces) == 0:
84 | # padding, try again
85 | _h, _w = img.shape[:2]
86 | _img, left_top_coord = pad_np_bgr_image(img)
87 | faces = face_app.get(_img)
88 | if len(faces) == 0:
89 | raise gr.Error("No face detected in the image. Please upload an image with a clearly visible face.")
90 |
91 | min_coord = np.array([0, 0])
92 | max_coord = np.array([_w, _h])
93 | sub_coord = np.array([left_top_coord[0], left_top_coord[1]])
94 | for face in faces:
95 | face.bbox = np.minimum(np.maximum(face.bbox.reshape(-1, 2) - sub_coord, min_coord), max_coord).reshape(4)
96 | face.kps = face.kps - sub_coord
97 |
98 | faces = sorted(faces, key=lambda x: abs((x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1])), reverse=True)
99 | faceid_face = faces[0]
100 | norm_face = face_align.norm_crop(img, landmark=faceid_face.kps, image_size=224)
101 | pil_faceid_align_image = Image.fromarray(cv2.cvtColor(norm_face, cv2.COLOR_BGR2RGB))
102 |
103 | return pil_faceid_align_image
104 |
105 |
106 | def prepare_single_faceid_cond_kwargs(pil_faceid_image=None, pil_faceid_supp_images=None,
107 | pil_faceid_mix_images=None, mix_scales=None):
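# Builds the per-identity condition dict passed to UniPortraitPipeline.generate() via `cond_faceids`:
# {"refs": [aligned primary + supplementary face crops]}, plus optional "mix_refs"/"mix_scales"
# when face-mixing references are provided. Returns None if no valid ID image is given.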
108 | pil_faceid_align_images = []
109 | if pil_faceid_image:
110 | pil_faceid_align_images.append(process_faceid_image(pil_faceid_image))
111 | if pil_faceid_supp_images and len(pil_faceid_supp_images) > 0:
112 | for pil_faceid_supp_image in pil_faceid_supp_images:
113 | if isinstance(pil_faceid_supp_image, Image.Image):
114 | pil_faceid_align_images.append(process_faceid_image(pil_faceid_supp_image))
115 | else:
116 | pil_faceid_align_images.append(
117 | process_faceid_image(Image.open(BytesIO(pil_faceid_supp_image)))
118 | )
119 |
120 | mix_refs = []
121 | mix_ref_scales = []
122 | if pil_faceid_mix_images:
123 | for pil_faceid_mix_image, mix_scale in zip(pil_faceid_mix_images, mix_scales):
124 | if pil_faceid_mix_image:
125 | mix_refs.append(process_faceid_image(pil_faceid_mix_image))
126 | mix_ref_scales.append(mix_scale)
127 |
128 | single_faceid_cond_kwargs = None
129 | if len(pil_faceid_align_images) > 0:
130 | single_faceid_cond_kwargs = {
131 | "refs": pil_faceid_align_images
132 | }
133 | if len(mix_refs) > 0:
134 | single_faceid_cond_kwargs["mix_refs"] = mix_refs
135 | single_faceid_cond_kwargs["mix_scales"] = mix_ref_scales
136 |
137 | return single_faceid_cond_kwargs
138 |
139 |
140 | def text_to_single_id_generation_process(
141 | pil_faceid_image=None, pil_faceid_supp_images=None,
142 | pil_faceid_mix_image_1=None, mix_scale_1=0.0,
143 | pil_faceid_mix_image_2=None, mix_scale_2=0.0,
144 | faceid_scale=0.0, face_structure_scale=0.0,
145 | prompt="", negative_prompt="",
146 | num_samples=1, seed=-1,
147 | image_resolution="512x512",
148 | inference_steps=25,
149 | ):
150 | if seed == -1:
151 | seed = None
152 |
153 | single_faceid_cond_kwargs = prepare_single_faceid_cond_kwargs(pil_faceid_image,
154 | pil_faceid_supp_images,
155 | [pil_faceid_mix_image_1, pil_faceid_mix_image_2],
156 | [mix_scale_1, mix_scale_2])
157 |
158 | cond_faceids = [single_faceid_cond_kwargs] if single_faceid_cond_kwargs else []
159 |
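# attn_args is the global attention-processor configuration defined in
# uniportrait.uniportrait_attention_processor; the fields set below select the single- vs.
# multi-ID LoRA and the strength of the face-ID conditioning for this generation call.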
160 | # reset attn args
161 | attn_args.reset()
162 | # set faceid condition
163 | attn_args.lora_scale = 1.0 if len(cond_faceids) == 1 else 0.0 # single-faceid lora
164 | attn_args.multi_id_lora_scale = 1.0 if len(cond_faceids) > 1 else 0.0 # multi-faceid lora
165 | attn_args.faceid_scale = faceid_scale if len(cond_faceids) > 0 else 0.0
166 | attn_args.num_faceids = len(cond_faceids)
167 | print(attn_args)
168 |
169 | h, w = int(image_resolution.split("x")[0]), int(image_resolution.split("x")[1])
170 | prompt = [prompt] * num_samples
171 | negative_prompt = [negative_prompt] * num_samples
172 | images = uniportrait_pipeline.generate(prompt=prompt, negative_prompt=negative_prompt,
173 | cond_faceids=cond_faceids, face_structure_scale=face_structure_scale,
174 | seed=seed, guidance_scale=7.5,
175 | num_inference_steps=inference_steps,
176 | image=[torch.zeros([1, 3, h, w])],
177 | controlnet_conditioning_scale=[0.0])
178 | final_out = []
179 | for pil_image in images:
180 | final_out.append(pil_image)
181 |
182 | for single_faceid_cond_kwargs in cond_faceids:
183 | final_out.extend(single_faceid_cond_kwargs["refs"])
184 | if "mix_refs" in single_faceid_cond_kwargs:
185 | final_out.extend(single_faceid_cond_kwargs["mix_refs"])
186 |
187 | return final_out
188 |
189 |
190 | def text_to_multi_id_generation_process(
191 | pil_faceid_image_1=None, pil_faceid_supp_images_1=None,
192 | pil_faceid_mix_image_1_1=None, mix_scale_1_1=0.0,
193 | pil_faceid_mix_image_1_2=None, mix_scale_1_2=0.0,
194 | pil_faceid_image_2=None, pil_faceid_supp_images_2=None,
195 | pil_faceid_mix_image_2_1=None, mix_scale_2_1=0.0,
196 | pil_faceid_mix_image_2_2=None, mix_scale_2_2=0.0,
197 | faceid_scale=0.0, face_structure_scale=0.0,
198 | prompt="", negative_prompt="",
199 | num_samples=1, seed=-1,
200 | image_resolution="512x512",
201 | inference_steps=25,
202 | ):
203 | if seed == -1:
204 | seed = None
205 |
206 | faceid_cond_kwargs_1 = prepare_single_faceid_cond_kwargs(pil_faceid_image_1,
207 | pil_faceid_supp_images_1,
208 | [pil_faceid_mix_image_1_1,
209 | pil_faceid_mix_image_1_2],
210 | [mix_scale_1_1, mix_scale_1_2])
211 | faceid_cond_kwargs_2 = prepare_single_faceid_cond_kwargs(pil_faceid_image_2,
212 | pil_faceid_supp_images_2,
213 | [pil_faceid_mix_image_2_1,
214 | pil_faceid_mix_image_2_2],
215 | [mix_scale_2_1, mix_scale_2_2])
216 | cond_faceids = []
217 | if faceid_cond_kwargs_1 is not None:
218 | cond_faceids.append(faceid_cond_kwargs_1)
219 | if faceid_cond_kwargs_2 is not None:
220 | cond_faceids.append(faceid_cond_kwargs_2)
221 |
222 | # reset attn args
223 | attn_args.reset()
224 | # set faceid condition
225 | attn_args.lora_scale = 1.0 if len(cond_faceids) == 1 else 0.0 # single-faceid lora
226 | attn_args.multi_id_lora_scale = 1.0 if len(cond_faceids) > 1 else 0.0 # multi-faceid lora
227 | attn_args.faceid_scale = faceid_scale if len(cond_faceids) > 0 else 0.0
228 | attn_args.num_faceids = len(cond_faceids)
229 | print(attn_args)
230 |
231 | h, w = int(image_resolution.split("x")[0]), int(image_resolution.split("x")[1])
232 | prompt = [prompt] * num_samples
233 | negative_prompt = [negative_prompt] * num_samples
234 | images = uniportrait_pipeline.generate(prompt=prompt, negative_prompt=negative_prompt,
235 | cond_faceids=cond_faceids, face_structure_scale=face_structure_scale,
236 | seed=seed, guidance_scale=7.5,
237 | num_inference_steps=inference_steps,
238 | image=[torch.zeros([1, 3, h, w])],
239 | controlnet_conditioning_scale=[0.0])
240 |
241 | final_out = []
242 | for pil_image in images:
243 | final_out.append(pil_image)
244 |
245 | for single_faceid_cond_kwargs in cond_faceids:
246 | final_out.extend(single_faceid_cond_kwargs["refs"])
247 | if "mix_refs" in single_faceid_cond_kwargs:
248 | final_out.extend(single_faceid_cond_kwargs["mix_refs"])
249 |
250 | return final_out
251 |
252 |
253 | def image_to_single_id_generation_process(
254 | pil_faceid_image=None, pil_faceid_supp_images=None,
255 | pil_faceid_mix_image_1=None, mix_scale_1=0.0,
256 | pil_faceid_mix_image_2=None, mix_scale_2=0.0,
257 | faceid_scale=0.0, face_structure_scale=0.0,
258 | pil_ip_image=None, ip_scale=1.0,
259 | num_samples=1, seed=-1, image_resolution="768x512",
260 | inference_steps=25,
261 | ):
262 | if seed == -1:
263 | seed = None
264 |
265 | single_faceid_cond_kwargs = prepare_single_faceid_cond_kwargs(pil_faceid_image,
266 | pil_faceid_supp_images,
267 | [pil_faceid_mix_image_1, pil_faceid_mix_image_2],
268 | [mix_scale_1, mix_scale_2])
269 |
270 | cond_faceids = [single_faceid_cond_kwargs] if single_faceid_cond_kwargs else []
271 |
272 | h, w = int(image_resolution.split("x")[0]), int(image_resolution.split("x")[1])
273 |
274 | # Image Prompt and Style Aligned
275 | if pil_ip_image is None:
276 | raise gr.Error("Please upload a reference image")
277 | attn_args.reset()
278 | pil_ip_image = pil_ip_image.convert("RGB").resize((w, h))
279 | zts = inversion.ddim_inversion(uniportrait_pipeline.pipe, np.array(pil_ip_image), "", inference_steps, 2)
280 | zT, inversion_callback = inversion.make_inversion_callback(zts, offset=0)
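# StyleAligned-style reference conditioning: zT is the DDIM-inverted latent of the portrait
# reference, and inversion_callback keeps batch element 0 on the reference's inversion
# trajectory during sampling. The first output therefore reconstructs the reference while the
# remaining samples share self-attention with it; that reconstruction is dropped below
# (`images = images[1:]`).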
281 |
282 | # reset attn args
283 | attn_args.reset()
284 | # set ip condition
285 | attn_args.ip_scale = ip_scale if pil_ip_image else 0.0
286 | # set faceid condition
287 | attn_args.lora_scale = 1.0 if len(cond_faceids) == 1 else 0.0 # lora for single faceid
288 | attn_args.multi_id_lora_scale = 1.0 if len(cond_faceids) > 1 else 0.0 # lora for >1 faceids
289 | attn_args.faceid_scale = faceid_scale if len(cond_faceids) > 0 else 0.0
290 | attn_args.num_faceids = len(cond_faceids)
291 | # set shared self-attn
292 | attn_args.enable_share_attn = True
293 | attn_args.shared_score_shift = -0.5
294 | print(attn_args)
295 |
296 | prompt = [""] * (1 + num_samples)
297 | negative_prompt = [""] * (1 + num_samples)
298 | images = uniportrait_pipeline.generate(prompt=prompt, negative_prompt=negative_prompt,
299 | pil_ip_image=pil_ip_image,
300 | cond_faceids=cond_faceids, face_structure_scale=face_structure_scale,
301 | seed=seed, guidance_scale=7.5,
302 | num_inference_steps=inference_steps,
303 | image=[torch.zeros([1, 3, h, w])],
304 | controlnet_conditioning_scale=[0.0],
305 | zT=zT, callback_on_step_end=inversion_callback)
306 | images = images[1:]
307 |
308 | final_out = []
309 | for pil_image in images:
310 | final_out.append(pil_image)
311 |
312 | for single_faceid_cond_kwargs in cond_faceids:
313 | final_out.extend(single_faceid_cond_kwargs["refs"])
314 | if "mix_refs" in single_faceid_cond_kwargs:
315 | final_out.extend(single_faceid_cond_kwargs["mix_refs"])
316 |
317 | return final_out
318 |
319 |
320 | def text_to_single_id_generation_block():
321 | gr.Markdown("## Text-to-Single-ID Generation")
322 | gr.HTML(text_to_single_id_description)
323 | gr.HTML(text_to_single_id_tips)
324 | with gr.Row():
325 | with gr.Column(scale=1, min_width=100):
326 | prompt = gr.Textbox(value="", label='Prompt', lines=2)
327 | negative_prompt = gr.Textbox(value="nsfw", label='Negative Prompt')
328 |
329 | run_button = gr.Button(value="Run")
330 | with gr.Accordion("Options", open=True):
331 | image_resolution = gr.Dropdown(choices=["768x512", "512x512", "512x768"], value="512x512",
332 | label="Image Resolution (HxW)")
333 | seed = gr.Slider(label="Seed (-1 indicates random)", minimum=-1, maximum=2147483647, step=1,
334 | value=2147483647)
335 | num_samples = gr.Slider(label="Images", minimum=1, maximum=4, value=2, step=1)
336 | inference_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=25, step=1, visible=False)
337 |
338 | faceid_scale = gr.Slider(label="Face ID Scale", minimum=0.0, maximum=1.0, step=0.01, value=0.7)
339 | face_structure_scale = gr.Slider(label="Face Structure Scale", minimum=0.0, maximum=1.0,
340 | step=0.01, value=0.1)
341 |
342 | with gr.Column(scale=2, min_width=100):
343 | with gr.Row(equal_height=False):
344 | pil_faceid_image = gr.Image(type="pil", label="ID Image")
345 | with gr.Accordion("ID Supplements", open=True):
346 | with gr.Row():
347 | pil_faceid_supp_images = gr.File(file_count="multiple", file_types=["image"],
348 | type="binary", label="Additional ID Images")
349 | with gr.Row():
350 | with gr.Column(scale=1, min_width=100):
351 | pil_faceid_mix_image_1 = gr.Image(type="pil", label="Mix ID 1")
352 | mix_scale_1 = gr.Slider(label="Mix Scale 1", minimum=0.0, maximum=1.0, step=0.01, value=0.0)
353 | with gr.Column(scale=1, min_width=100):
354 | pil_faceid_mix_image_2 = gr.Image(type="pil", label="Mix ID 2")
355 | mix_scale_2 = gr.Slider(label="Mix Scale 2", minimum=0.0, maximum=1.0, step=0.01, value=0.0)
356 |
357 | with gr.Row():
358 | example_output = gr.Image(type="pil", label="(Example Output)", visible=False)
359 | result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", columns=4, preview=True,
360 | format="png")
361 | with gr.Row():
362 | examples = [
363 | [
364 | "A young man with short black hair, wearing a black hoodie with a hood, was paired with a blue denim jacket with yellow details.",
365 | "assets/examples/1-newton.jpg",
366 | "assets/examples/1-output-1.png",
367 | ],
368 | ]
369 | gr.Examples(
370 | label="Examples",
371 | examples=examples,
372 | fn=lambda x, y, z: (x, y),
373 | inputs=[prompt, pil_faceid_image, example_output],
374 | outputs=[prompt, pil_faceid_image]
375 | )
376 | ips = [
377 | pil_faceid_image, pil_faceid_supp_images,
378 | pil_faceid_mix_image_1, mix_scale_1,
379 | pil_faceid_mix_image_2, mix_scale_2,
380 | faceid_scale, face_structure_scale,
381 | prompt, negative_prompt,
382 | num_samples, seed,
383 | image_resolution,
384 | inference_steps,
385 | ]
386 | run_button.click(fn=text_to_single_id_generation_process, inputs=ips, outputs=[result_gallery])
387 |
388 |
389 | def text_to_multi_id_generation_block():
390 | gr.Markdown("## Text-to-Multi-ID Generation")
391 | gr.HTML(text_to_multi_id_description)
392 | gr.HTML(text_to_multi_id_tips)
393 | with gr.Row():
394 | with gr.Column(scale=1, min_width=100):
395 | prompt = gr.Textbox(value="", label='Prompt', lines=2)
396 | negative_prompt = gr.Textbox(value="nsfw", label='Negative Prompt')
397 | run_button = gr.Button(value="Run")
398 | with gr.Accordion("Options", open=True):
399 | image_resolution = gr.Dropdown(choices=["768x512", "512x512", "512x768"], value="512x512",
400 | label="Image Resolution (HxW)")
401 | seed = gr.Slider(label="Seed (-1 indicates random)", minimum=-1, maximum=2147483647, step=1,
402 | value=2147483647)
403 | num_samples = gr.Slider(label="Images", minimum=1, maximum=4, value=2, step=1)
404 | inference_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=25, step=1, visible=False)
405 |
406 | faceid_scale = gr.Slider(label="Face ID Scale", minimum=0.0, maximum=1.0, step=0.01, value=0.7)
407 | face_structure_scale = gr.Slider(label="Face Structure Scale", minimum=0.0, maximum=1.0,
408 | step=0.01, value=0.3)
409 |
410 | with gr.Column(scale=2, min_width=100):
411 | with gr.Row(equal_height=False):
412 | with gr.Column(scale=1, min_width=100):
413 | pil_faceid_image_1 = gr.Image(type="pil", label="First ID")
414 | with gr.Accordion("First ID Supplements", open=False):
415 | with gr.Row():
416 | pil_faceid_supp_images_1 = gr.File(file_count="multiple", file_types=["image"],
417 | type="binary", label="Additional ID Images")
418 | with gr.Row():
419 | with gr.Column(scale=1, min_width=100):
420 | pil_faceid_mix_image_1_1 = gr.Image(type="pil", label="Mix ID 1")
421 | mix_scale_1_1 = gr.Slider(label="Mix Scale 1", minimum=0.0, maximum=1.0, step=0.01,
422 | value=0.0)
423 | with gr.Column(scale=1, min_width=100):
424 | pil_faceid_mix_image_1_2 = gr.Image(type="pil", label="Mix ID 2")
425 | mix_scale_1_2 = gr.Slider(label="Mix Scale 2", minimum=0.0, maximum=1.0, step=0.01,
426 | value=0.0)
427 | with gr.Column(scale=1, min_width=100):
428 | pil_faceid_image_2 = gr.Image(type="pil", label="Second ID")
429 | with gr.Accordion("Second ID Supplements", open=False):
430 | with gr.Row():
431 | pil_faceid_supp_images_2 = gr.File(file_count="multiple", file_types=["image"],
432 | type="binary", label="Additional ID Images")
433 | with gr.Row():
434 | with gr.Column(scale=1, min_width=100):
435 | pil_faceid_mix_image_2_1 = gr.Image(type="pil", label="Mix ID 1")
436 | mix_scale_2_1 = gr.Slider(label="Mix Scale 1", minimum=0.0, maximum=1.0, step=0.01,
437 | value=0.0)
438 | with gr.Column(scale=1, min_width=100):
439 | pil_faceid_mix_image_2_2 = gr.Image(type="pil", label="Mix ID 2")
440 | mix_scale_2_2 = gr.Slider(label="Mix Scale 2", minimum=0.0, maximum=1.0, step=0.01,
441 | value=0.0)
442 |
443 | with gr.Row():
444 | example_output = gr.Image(type="pil", label="(Example Output)", visible=False)
445 | result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", columns=4, preview=True,
446 | format="png")
447 | with gr.Row():
448 | examples = [
449 | [
450 | "The two female models, fair-skinned, wore a white V-neck short-sleeved top with a light smile on the corners of their mouths. The background was off-white.",
451 | "assets/examples/2-stylegan2-ffhq-0100.png",
452 | "assets/examples/2-stylegan2-ffhq-0293.png",
453 | "assets/examples/2-output-1.png",
454 | ],
455 | ]
456 | gr.Examples(
457 | label="Examples",
458 | examples=examples,
459 | inputs=[prompt, pil_faceid_image_1, pil_faceid_image_2, example_output],
460 | )
461 | ips = [
462 | pil_faceid_image_1, pil_faceid_supp_images_1,
463 | pil_faceid_mix_image_1_1, mix_scale_1_1,
464 | pil_faceid_mix_image_1_2, mix_scale_1_2,
465 | pil_faceid_image_2, pil_faceid_supp_images_2,
466 | pil_faceid_mix_image_2_1, mix_scale_2_1,
467 | pil_faceid_mix_image_2_2, mix_scale_2_2,
468 | faceid_scale, face_structure_scale,
469 | prompt, negative_prompt,
470 | num_samples, seed,
471 | image_resolution,
472 | inference_steps,
473 | ]
474 | run_button.click(fn=text_to_multi_id_generation_process, inputs=ips, outputs=[result_gallery])
475 |
476 |
477 | def image_to_single_id_generation_block():
478 | gr.Markdown("## Image-to-Single-ID Generation")
479 | gr.HTML(image_to_single_id_description)
480 | gr.HTML(image_to_single_id_tips)
481 | with gr.Row():
482 | with gr.Column(scale=1, min_width=100):
483 | run_button = gr.Button(value="Run")
484 | seed = gr.Slider(label="Seed (-1 indicates random)", minimum=-1, maximum=2147483647, step=1,
485 | value=2147483647)
486 | num_samples = gr.Slider(label="Images", minimum=1, maximum=4, value=2, step=1)
487 | image_resolution = gr.Dropdown(choices=["768x512", "512x512", "512x768"], value="512x512",
488 | label="Image Resolution (HxW)")
489 | inference_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=25, step=1, visible=False)
490 |
491 | ip_scale = gr.Slider(label="Reference Scale", minimum=0.0, maximum=1.0, step=0.01, value=0.7)
492 | faceid_scale = gr.Slider(label="Face ID Scale", minimum=0.0, maximum=1.0, step=0.01, value=0.7)
493 | face_structure_scale = gr.Slider(label="Face Structure Scale", minimum=0.0, maximum=1.0, step=0.01,
494 | value=0.3)
495 |
496 | with gr.Column(scale=3, min_width=100):
497 | with gr.Row(equal_height=False):
498 | pil_ip_image = gr.Image(type="pil", label="Portrait Reference")
499 | pil_faceid_image = gr.Image(type="pil", label="ID Image")
500 | with gr.Accordion("ID Supplements", open=True):
501 | with gr.Row():
502 | pil_faceid_supp_images = gr.File(file_count="multiple", file_types=["image"],
503 | type="binary", label="Additional ID Images")
504 | with gr.Row():
505 | with gr.Column(scale=1, min_width=100):
506 | pil_faceid_mix_image_1 = gr.Image(type="pil", label="Mix ID 1")
507 | mix_scale_1 = gr.Slider(label="Mix Scale 1", minimum=0.0, maximum=1.0, step=0.01, value=0.0)
508 | with gr.Column(scale=1, min_width=100):
509 | pil_faceid_mix_image_2 = gr.Image(type="pil", label="Mix ID 2")
510 | mix_scale_2 = gr.Slider(label="Mix Scale 2", minimum=0.0, maximum=1.0, step=0.01, value=0.0)
511 | with gr.Row():
512 | with gr.Column(scale=3, min_width=100):
513 | example_output = gr.Image(type="pil", label="(Example Output)", visible=False)
514 | result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", columns=4,
515 | preview=True, format="png")
516 | with gr.Row():
517 | examples = [
518 | [
519 | "assets/examples/3-style-1.png",
520 | "assets/examples/3-stylegan2-ffhq-0293.png",
521 | 0.7,
522 | 0.3,
523 | "assets/examples/3-output-1.png",
524 | ],
525 | [
526 | "assets/examples/3-style-1.png",
527 | "assets/examples/3-stylegan2-ffhq-0293.png",
528 | 0.6,
529 | 0.0,
530 | "assets/examples/3-output-2.png",
531 | ],
532 | [
533 | "assets/examples/3-style-2.jpg",
534 | "assets/examples/3-stylegan2-ffhq-0381.png",
535 | 0.7,
536 | 0.3,
537 | "assets/examples/3-output-3.png",
538 | ],
539 | [
540 | "assets/examples/3-style-3.jpg",
541 | "assets/examples/3-stylegan2-ffhq-0381.png",
542 | 0.6,
543 | 0.0,
544 | "assets/examples/3-output-4.png",
545 | ],
546 | ]
547 | gr.Examples(
548 | label="Examples",
549 | examples=examples,
550 | fn=lambda x, y, z, w, v: (x, y, z, w),
551 | inputs=[pil_ip_image, pil_faceid_image, faceid_scale, face_structure_scale, example_output],
552 | outputs=[pil_ip_image, pil_faceid_image, faceid_scale, face_structure_scale]
553 | )
554 | ips = [
555 | pil_faceid_image, pil_faceid_supp_images,
556 | pil_faceid_mix_image_1, mix_scale_1,
557 | pil_faceid_mix_image_2, mix_scale_2,
558 | faceid_scale, face_structure_scale,
559 | pil_ip_image, ip_scale,
560 | num_samples, seed, image_resolution,
561 | inference_steps,
562 | ]
563 | run_button.click(fn=image_to_single_id_generation_process, inputs=ips, outputs=[result_gallery])
564 |
565 |
566 | if __name__ == "__main__":
567 | os.environ["no_proxy"] = "localhost,127.0.0.1,::1"
568 |
569 | title = r"""
570 | <h1>UniPortrait: A Unified Framework for Identity-Preserving Single- and Multi-Human Image Personalization</h1>
571 | """
582 |
583 | title_description = r"""
584 | This is the official 🤗 Gradio demo for UniPortrait: A Unified Framework for Identity-Preserving Single- and Multi-Human Image Personalization.
585 | The demo provides three capabilities: text-to-single-ID personalization, text-to-multi-ID personalization, and image-to-single-ID personalization. All of these are based on the Stable Diffusion v1-5 model. Feel free to give them a try! 😊
586 | """
587 |
588 | text_to_single_id_description = r"""🚀🚀🚀Quick start:
589 | 1. Enter a text prompt (Chinese or English), upload an image with a face, and click the Run button. 🤗
590 | """
591 |
592 | text_to_single_id_tips = r"""💡💡💡Tips:
593 | 1. Try to avoid generating faces that are too small, as this may lead to artifacts. (Currently, the short side of the generated image is limited to 512.)
594 | 2. Uploading multiple reference photos of the same face helps improve prompt and ID consistency. Additional references can be uploaded under "ID Supplements".
595 | 3. Appropriate values of "Face ID Scale" and "Face Structure Scale" are important for balancing ID fidelity and text alignment. We recommend a "Face ID Scale" of 0.5~0.7 and a "Face Structure Scale" of 0.0~0.4.
596 | """
597 |
598 | text_to_multi_id_description = r"""🚀🚀🚀Quick start:
599 | 1. Enter a text prompt (Chinese or English), upload a face image in each of the "First ID" and "Second ID" blocks, and click the Run button. 🤗
600 | """
601 |
602 | text_to_multi_id_tips = r"""💡💡💡Tips:
603 | 1. Try to avoid generating faces that are too small, as this may lead to artifacts. (Currently, the short side of the generated image is limited to 512.)
604 | 2. Uploading multiple reference photos of the same face helps improve prompt and ID consistency. Additional references can be uploaded under "ID Supplements".
605 | 3. Appropriate values of "Face ID Scale" and "Face Structure Scale" are important for balancing ID fidelity and text alignment. We recommend a "Face ID Scale" of 0.3~0.7 and a "Face Structure Scale" of 0.0~0.4.
606 | """
607 |
608 | image_to_single_id_description = r"""🚀🚀🚀Quick start: Upload an image as the portrait reference (it can be in any style), upload a face image, and click the Run button. 🤗
"""
609 |
610 | image_to_single_id_tips = r"""💡💡💡Tips:
611 | 1. Try to avoid generating faces that are too small, as this may lead to artifacts. (Currently, the short side of the generated image is limited to 512.)
612 | 2. Uploading multiple reference photos of the same face helps improve ID consistency. Additional references can be uploaded under "ID Supplements".
613 | 3. Appropriate values of "Face ID Scale" and "Face Structure Scale" are important for balancing fidelity to the portrait reference and to the ID. We recommend a "Face ID Scale" of 0.5~0.7 and a "Face Structure Scale" of 0.0~0.4.
614 | """
615 |
616 | citation = r"""
617 | ---
618 | 📝 **Citation**
619 |
620 | If our work is helpful for your research or applications, please cite us via:
621 | ```bibtex
622 | @article{he2024uniportrait,
623 | title={UniPortrait: A Unified Framework for Identity-Preserving Single-and Multi-Human Image Personalization},
624 | author={He, Junjie and Geng, Yifeng and Bo, Liefeng},
625 | journal={arXiv preprint arXiv:2408.05939},
626 | year={2024}
627 | }
628 | ```
629 | 📧 **Contact**
630 |
631 | If you have any questions, please feel free to open an issue or reach out to us directly at hejunjie1103@gmail.com.
632 | """
633 |
634 | block = gr.Blocks(title="UniPortrait").queue()
635 | with block:
636 | gr.HTML(title)
637 | gr.HTML(title_description)
638 |
639 | with gr.TabItem("Text-to-Single-ID"):
640 | text_to_single_id_generation_block()
641 |
642 | with gr.TabItem("Text-to-Multi-ID"):
643 | text_to_multi_id_generation_block()
644 |
645 | with gr.TabItem("Image-to-Single-ID (Stylization)"):
646 | image_to_single_id_generation_block()
647 |
648 | gr.Markdown(citation)
649 |
650 | block.launch(server_name='0.0.0.0', share=False, server_port=port, allowed_paths=["/"])
651 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | diffusers
2 | gradio
3 | onnxruntime-gpu
4 | insightface
5 | torch
6 | tqdm
7 | transformers
8 |
--------------------------------------------------------------------------------
/uniportrait/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/uniportrait/__init__.py
--------------------------------------------------------------------------------
/uniportrait/curricular_face/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junjiehe96/UniPortrait/a4deff2b48e3f67c8466905f27b5a662fc9e912e/uniportrait/curricular_face/__init__.py
--------------------------------------------------------------------------------
/uniportrait/curricular_face/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | # The implementation is adapted from TFace, made publicly available under the Apache-2.0 license at
2 | # https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone
3 | from .model_irse import (IR_18, IR_34, IR_50, IR_101, IR_152, IR_200, IR_SE_50,
4 | IR_SE_101, IR_SE_152, IR_SE_200)
5 | from .model_resnet import ResNet_50, ResNet_101, ResNet_152
6 |
7 | _model_dict = {
8 | 'ResNet_50': ResNet_50,
9 | 'ResNet_101': ResNet_101,
10 | 'ResNet_152': ResNet_152,
11 | 'IR_18': IR_18,
12 | 'IR_34': IR_34,
13 | 'IR_50': IR_50,
14 | 'IR_101': IR_101,
15 | 'IR_152': IR_152,
16 | 'IR_200': IR_200,
17 | 'IR_SE_50': IR_SE_50,
18 | 'IR_SE_101': IR_SE_101,
19 | 'IR_SE_152': IR_SE_152,
20 | 'IR_SE_200': IR_SE_200
21 | }
22 |
23 |
24 | def get_model(key):
25 | """ Get different backbone network by key,
26 | support ResNet50, ResNet_101, ResNet_152
27 | IR_18, IR_34, IR_50, IR_101, IR_152, IR_200,
28 | IR_SE_50, IR_SE_101, IR_SE_152, IR_SE_200.
29 | """
30 | if key in _model_dict.keys():
31 | return _model_dict[key]
32 | else:
33 | raise KeyError('unsupported model: {}'.format(key))
34 |
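# Illustrative usage (assuming the standard TFace constructor signature, where each backbone
# takes an input_size argument such as [112, 112]):
#   backbone_cls = get_model("IR_101")
#   backbone = backbone_cls([112, 112])
# Unknown keys raise a KeyError.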
--------------------------------------------------------------------------------
/uniportrait/curricular_face/backbone/common.py:
--------------------------------------------------------------------------------
1 | # The implementation is adapted from TFace, made publicly available under the Apache-2.0 license at
2 | # https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone/common.py
3 | import torch.nn as nn
4 | from torch.nn import (Conv2d, Module, ReLU,
5 | Sigmoid)
6 |
7 |
8 | def initialize_weights(modules):
9 | """ Weight initilize, conv2d and linear is initialized with kaiming_normal
10 | """
11 | for m in modules:
12 | if isinstance(m, nn.Conv2d):
13 | nn.init.kaiming_normal_(
14 | m.weight, mode='fan_out', nonlinearity='relu')
15 | if m.bias is not None:
16 | m.bias.data.zero_()
17 | elif isinstance(m, nn.BatchNorm2d):
18 | m.weight.data.fill_(1)
19 | m.bias.data.zero_()
20 | elif isinstance(m, nn.Linear):
21 | nn.init.kaiming_normal_(
22 | m.weight, mode='fan_out', nonlinearity='relu')
23 | if m.bias is not None:
24 | m.bias.data.zero_()
25 |
26 |
27 | class Flatten(Module):
28 | """ Flat tensor
29 | """
30 |
31 | def forward(self, input):
32 | return input.view(input.size(0), -1)
33 |
34 |
35 | class SEModule(Module):
36 | """ SE block
37 | """
38 |
39 | def __init__(self, channels, reduction):
40 | super(SEModule, self).__init__()
41 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
42 | self.fc1 = Conv2d(
43 | channels,
44 | channels // reduction,
45 | kernel_size=1,
46 | padding=0,
47 | bias=False)
48 |
49 | nn.init.xavier_uniform_(self.fc1.weight.data)
50 |
51 | self.relu = ReLU(inplace=True)
52 | self.fc2 = Conv2d(
53 | channels // reduction,
54 | channels,
55 | kernel_size=1,
56 | padding=0,
57 | bias=False)
58 |
59 | self.sigmoid = Sigmoid()
60 |
61 | def forward(self, x):
62 | module_input = x
63 | x = self.avg_pool(x)
64 | x = self.fc1(x)
65 | x = self.relu(x)
66 | x = self.fc2(x)
67 | x = self.sigmoid(x)
68 |
69 | return module_input * x
70 |
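# Illustrative usage: the SE block rescales channels and preserves the input shape, e.g.
#   se = SEModule(channels=64, reduction=16)
#   out = se(torch.randn(1, 64, 56, 56))  # out.shape == (1, 64, 56, 56)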
--------------------------------------------------------------------------------
/uniportrait/curricular_face/backbone/model_irse.py:
--------------------------------------------------------------------------------
1 | # The implementation is adapted from TFace, made publicly available under the Apache-2.0 license at
2 | # https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone/model_irse.py
3 | from collections import namedtuple
4 |
5 | from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear,
6 | MaxPool2d, Module, PReLU, Sequential)
7 |
8 | from .common import Flatten, SEModule, initialize_weights
9 |
10 |
11 | class BasicBlockIR(Module):
12 | """ BasicBlock for IRNet
13 | """
14 |
15 | def __init__(self, in_channel, depth, stride):
16 | super(BasicBlockIR, self).__init__()
17 | if in_channel == depth:
18 | self.shortcut_layer = MaxPool2d(1, stride)
19 | else:
20 | self.shortcut_layer = Sequential(
21 | Conv2d(in_channel, depth, (1, 1), stride, bias=False),
22 | BatchNorm2d(depth))
23 | self.res_layer = Sequential(
24 | BatchNorm2d(in_channel),
25 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
26 | BatchNorm2d(depth), PReLU(depth),
27 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
28 | BatchNorm2d(depth))
29 |
30 | def forward(self, x):
31 | shortcut = self.shortcut_layer(x)
32 | res = self.res_layer(x)
33 |
34 | return res + shortcut
35 |
36 |
37 | class BottleneckIR(Module):
38 | """ BasicBlock with bottleneck for IRNet
39 | """
40 |
41 | def __init__(self, in_channel, depth, stride):
42 | super(BottleneckIR, self).__init__()
43 | reduction_channel = depth // 4
44 | if in_channel == depth:
45 | self.shortcut_layer = MaxPool2d(1, stride)
46 | else:
47 | self.shortcut_layer = Sequential(
48 | Conv2d(in_channel, depth, (1, 1), stride, bias=False),
49 | BatchNorm2d(depth))
50 | self.res_layer = Sequential(
51 | BatchNorm2d(in_channel),
52 | Conv2d(
53 | in_channel, reduction_channel, (1, 1), (1, 1), 0, bias=False),
54 | BatchNorm2d(reduction_channel), PReLU(reduction_channel),
55 | Conv2d(
56 | reduction_channel,
57 | reduction_channel, (3, 3), (1, 1),
58 | 1,
59 | bias=False), BatchNorm2d(reduction_channel),
60 | PReLU(reduction_channel),
61 | Conv2d(reduction_channel, depth, (1, 1), stride, 0, bias=False),
62 | BatchNorm2d(depth))
63 |
64 | def forward(self, x):
65 | shortcut = self.shortcut_layer(x)
66 | res = self.res_layer(x)
67 |
68 | return res + shortcut
69 |
70 |
71 | class BasicBlockIRSE(BasicBlockIR):
72 |
73 | def __init__(self, in_channel, depth, stride):
74 | super(BasicBlockIRSE, self).__init__(in_channel, depth, stride)
75 | self.res_layer.add_module('se_block', SEModule(depth, 16))
76 |
77 |
78 | class BottleneckIRSE(BottleneckIR):
79 |
80 | def __init__(self, in_channel, depth, stride):
81 | super(BottleneckIRSE, self).__init__(in_channel, depth, stride)
82 | self.res_layer.add_module('se_block', SEModule(depth, 16))
83 |
84 |
85 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
86 | '''A named tuple describing a ResNet block.'''
87 |
88 |
89 | def get_block(in_channel, depth, num_units, stride=2):
90 | return [Bottleneck(in_channel, depth, stride)] + \
91 | [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
92 |
93 |
94 | def get_blocks(num_layers):
95 | if num_layers == 18:
96 | blocks = [
97 | get_block(in_channel=64, depth=64, num_units=2),
98 | get_block(in_channel=64, depth=128, num_units=2),
99 | get_block(in_channel=128, depth=256, num_units=2),
100 | get_block(in_channel=256, depth=512, num_units=2)
101 | ]
102 | elif num_layers == 34:
103 | blocks = [
104 | get_block(in_channel=64, depth=64, num_units=3),
105 | get_block(in_channel=64, depth=128, num_units=4),
106 | get_block(in_channel=128, depth=256, num_units=6),
107 | get_block(in_channel=256, depth=512, num_units=3)
108 | ]
109 | elif num_layers == 50:
110 | blocks = [
111 | get_block(in_channel=64, depth=64, num_units=3),
112 | get_block(in_channel=64, depth=128, num_units=4),
113 | get_block(in_channel=128, depth=256, num_units=14),
114 | get_block(in_channel=256, depth=512, num_units=3)
115 | ]
116 | elif num_layers == 100:
117 | blocks = [
118 | get_block(in_channel=64, depth=64, num_units=3),
119 | get_block(in_channel=64, depth=128, num_units=13),
120 | get_block(in_channel=128, depth=256, num_units=30),
121 | get_block(in_channel=256, depth=512, num_units=3)
122 | ]
123 | elif num_layers == 152:
124 | blocks = [
125 | get_block(in_channel=64, depth=256, num_units=3),
126 | get_block(in_channel=256, depth=512, num_units=8),
127 | get_block(in_channel=512, depth=1024, num_units=36),
128 | get_block(in_channel=1024, depth=2048, num_units=3)
129 | ]
130 | elif num_layers == 200:
131 | blocks = [
132 | get_block(in_channel=64, depth=256, num_units=3),
133 | get_block(in_channel=256, depth=512, num_units=24),
134 | get_block(in_channel=512, depth=1024, num_units=36),
135 | get_block(in_channel=1024, depth=2048, num_units=3)
136 | ]
137 |
138 | return blocks
139 |
140 |
141 | class Backbone(Module):
142 |
143 | def __init__(self, input_size, num_layers, mode='ir'):
144 | """ Args:
145 |             input_size: input resolution, [112, 112] or [224, 224]
146 |             num_layers: depth of the backbone (18, 34, 50, 100, 152 or 200)
147 |             mode: 'ir' or 'ir_se'
148 | """
149 | super(Backbone, self).__init__()
150 | assert input_size[0] in [112, 224], \
151 | 'input_size should be [112, 112] or [224, 224]'
152 | assert num_layers in [18, 34, 50, 100, 152, 200], \
153 |             'num_layers should be 18, 34, 50, 100, 152 or 200'
154 | assert mode in ['ir', 'ir_se'], \
155 | 'mode should be ir or ir_se'
156 | self.input_layer = Sequential(
157 | Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64),
158 | PReLU(64))
159 | blocks = get_blocks(num_layers)
160 | if num_layers <= 100:
161 | if mode == 'ir':
162 | unit_module = BasicBlockIR
163 | elif mode == 'ir_se':
164 | unit_module = BasicBlockIRSE
165 | output_channel = 512
166 | else:
167 | if mode == 'ir':
168 | unit_module = BottleneckIR
169 | elif mode == 'ir_se':
170 | unit_module = BottleneckIRSE
171 | output_channel = 2048
172 |
173 | if input_size[0] == 112:
174 | self.output_layer = Sequential(
175 | BatchNorm2d(output_channel), Dropout(0.4), Flatten(),
176 | Linear(output_channel * 7 * 7, 512),
177 | BatchNorm1d(512, affine=False))
178 | else:
179 | self.output_layer = Sequential(
180 | BatchNorm2d(output_channel), Dropout(0.4), Flatten(),
181 | Linear(output_channel * 14 * 14, 512),
182 | BatchNorm1d(512, affine=False))
183 |
184 | modules = []
185 | mid_layer_indices = [] # [2, 15, 45, 48], total 49 layers for IR101
186 | for block in blocks:
187 | if len(mid_layer_indices) == 0:
188 | mid_layer_indices.append(len(block) - 1)
189 | else:
190 | mid_layer_indices.append(len(block) + mid_layer_indices[-1])
191 | for bottleneck in block:
192 | modules.append(
193 | unit_module(bottleneck.in_channel, bottleneck.depth,
194 | bottleneck.stride))
195 | self.body = Sequential(*modules)
196 | self.mid_layer_indices = mid_layer_indices[-4:]
197 |
198 | initialize_weights(self.modules())
199 |
200 | def forward(self, x, return_mid_feats=False):
201 | x = self.input_layer(x)
202 | if not return_mid_feats:
203 | x = self.body(x)
204 | x = self.output_layer(x)
205 | return x
206 | else:
207 | out_feats = []
208 | for idx, module in enumerate(self.body):
209 | x = module(x)
210 | if idx in self.mid_layer_indices:
211 | out_feats.append(x)
212 | x = self.output_layer(x)
213 | return x, out_feats
214 |
215 |
216 | def IR_18(input_size):
217 | """ Constructs a ir-18 model.
218 | """
219 | model = Backbone(input_size, 18, 'ir')
220 |
221 | return model
222 |
223 |
224 | def IR_34(input_size):
225 | """ Constructs a ir-34 model.
226 | """
227 | model = Backbone(input_size, 34, 'ir')
228 |
229 | return model
230 |
231 |
232 | def IR_50(input_size):
233 | """ Constructs a ir-50 model.
234 | """
235 | model = Backbone(input_size, 50, 'ir')
236 |
237 | return model
238 |
239 |
240 | def IR_101(input_size):
241 | """ Constructs a ir-101 model.
242 | """
243 | model = Backbone(input_size, 100, 'ir')
244 |
245 | return model
246 |
247 |
248 | def IR_152(input_size):
249 | """ Constructs a ir-152 model.
250 | """
251 | model = Backbone(input_size, 152, 'ir')
252 |
253 | return model
254 |
255 |
256 | def IR_200(input_size):
257 | """ Constructs a ir-200 model.
258 | """
259 | model = Backbone(input_size, 200, 'ir')
260 |
261 | return model
262 |
263 |
264 | def IR_SE_50(input_size):
265 | """ Constructs a ir_se-50 model.
266 | """
267 | model = Backbone(input_size, 50, 'ir_se')
268 |
269 | return model
270 |
271 |
272 | def IR_SE_101(input_size):
273 | """ Constructs a ir_se-101 model.
274 | """
275 | model = Backbone(input_size, 100, 'ir_se')
276 |
277 | return model
278 |
279 |
280 | def IR_SE_152(input_size):
281 | """ Constructs a ir_se-152 model.
282 | """
283 | model = Backbone(input_size, 152, 'ir_se')
284 |
285 | return model
286 |
287 |
288 | def IR_SE_200(input_size):
289 | """ Constructs a ir_se-200 model.
290 | """
291 | model = Backbone(input_size, 200, 'ir_se')
292 |
293 | return model
294 |
--------------------------------------------------------------------------------
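As a rough shape check for the IR backbones (the numbers follow from the block table above, assuming a 112x112 input): the embedding is 512-d, and with return_mid_feats=True the four stage outputs are also returned, which is exactly what uniportrait_pipeline.py consumes.

import torch

from uniportrait.curricular_face.backbone.model_irse import IR_101

model = IR_101([112, 112]).eval()
with torch.no_grad():
    emb, mid_feats = model(torch.randn(1, 3, 112, 112), return_mid_feats=True)
print(emb.shape)                             # torch.Size([1, 512])
print([tuple(f.shape) for f in mid_feats])   # [(1, 64, 56, 56), (1, 128, 28, 28), (1, 256, 14, 14), (1, 512, 7, 7)]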
/uniportrait/curricular_face/backbone/model_resnet.py:
--------------------------------------------------------------------------------
1 | # The implementation is adapted from TFace, made publicly available under the Apache-2.0 license at
2 | # https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone/model_resnet.py
3 | import torch.nn as nn
4 | from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear,
5 | MaxPool2d, Module, ReLU, Sequential)
6 |
7 | from .common import initialize_weights
8 |
9 |
10 | def conv3x3(in_planes, out_planes, stride=1):
11 | """ 3x3 convolution with padding
12 | """
13 | return Conv2d(
14 | in_planes,
15 | out_planes,
16 | kernel_size=3,
17 | stride=stride,
18 | padding=1,
19 | bias=False)
20 |
21 |
22 | def conv1x1(in_planes, out_planes, stride=1):
23 | """ 1x1 convolution
24 | """
25 | return Conv2d(
26 | in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
27 |
28 |
29 | class Bottleneck(Module):
30 | expansion = 4
31 |
32 | def __init__(self, inplanes, planes, stride=1, downsample=None):
33 | super(Bottleneck, self).__init__()
34 | self.conv1 = conv1x1(inplanes, planes)
35 | self.bn1 = BatchNorm2d(planes)
36 | self.conv2 = conv3x3(planes, planes, stride)
37 | self.bn2 = BatchNorm2d(planes)
38 | self.conv3 = conv1x1(planes, planes * self.expansion)
39 | self.bn3 = BatchNorm2d(planes * self.expansion)
40 | self.relu = ReLU(inplace=True)
41 | self.downsample = downsample
42 | self.stride = stride
43 |
44 | def forward(self, x):
45 | identity = x
46 |
47 | out = self.conv1(x)
48 | out = self.bn1(out)
49 | out = self.relu(out)
50 |
51 | out = self.conv2(out)
52 | out = self.bn2(out)
53 | out = self.relu(out)
54 |
55 | out = self.conv3(out)
56 | out = self.bn3(out)
57 |
58 | if self.downsample is not None:
59 | identity = self.downsample(x)
60 |
61 | out += identity
62 | out = self.relu(out)
63 |
64 | return out
65 |
66 |
67 | class ResNet(Module):
68 | """ ResNet backbone
69 | """
70 |
71 | def __init__(self, input_size, block, layers, zero_init_residual=True):
72 | """ Args:
73 |             input_size: input resolution, [112, 112] or [224, 224]
74 |             block: block class used to build the network (e.g. Bottleneck)
75 |             layers: number of blocks in each of the four stages
76 | """
77 | super(ResNet, self).__init__()
78 | assert input_size[0] in [112, 224], \
79 | 'input_size should be [112, 112] or [224, 224]'
80 | self.inplanes = 64
81 | self.conv1 = Conv2d(
82 | 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
83 | self.bn1 = BatchNorm2d(64)
84 | self.relu = ReLU(inplace=True)
85 | self.maxpool = MaxPool2d(kernel_size=3, stride=2, padding=1)
86 | self.layer1 = self._make_layer(block, 64, layers[0])
87 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
88 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
89 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
90 |
91 | self.bn_o1 = BatchNorm2d(2048)
92 | self.dropout = Dropout()
93 | if input_size[0] == 112:
94 | self.fc = Linear(2048 * 4 * 4, 512)
95 | else:
96 | self.fc = Linear(2048 * 7 * 7, 512)
97 | self.bn_o2 = BatchNorm1d(512)
98 |
99 |         initialize_weights(self.modules())
100 | if zero_init_residual:
101 | for m in self.modules():
102 | if isinstance(m, Bottleneck):
103 | nn.init.constant_(m.bn3.weight, 0)
104 |
105 | def _make_layer(self, block, planes, blocks, stride=1):
106 | downsample = None
107 | if stride != 1 or self.inplanes != planes * block.expansion:
108 | downsample = Sequential(
109 | conv1x1(self.inplanes, planes * block.expansion, stride),
110 | BatchNorm2d(planes * block.expansion),
111 | )
112 |
113 | layers = []
114 | layers.append(block(self.inplanes, planes, stride, downsample))
115 | self.inplanes = planes * block.expansion
116 | for _ in range(1, blocks):
117 | layers.append(block(self.inplanes, planes))
118 |
119 | return Sequential(*layers)
120 |
121 | def forward(self, x):
122 | x = self.conv1(x)
123 | x = self.bn1(x)
124 | x = self.relu(x)
125 | x = self.maxpool(x)
126 |
127 | x = self.layer1(x)
128 | x = self.layer2(x)
129 | x = self.layer3(x)
130 | x = self.layer4(x)
131 |
132 | x = self.bn_o1(x)
133 | x = self.dropout(x)
134 | x = x.view(x.size(0), -1)
135 | x = self.fc(x)
136 | x = self.bn_o2(x)
137 |
138 | return x
139 |
140 |
141 | def ResNet_50(input_size, **kwargs):
142 | """ Constructs a ResNet-50 model.
143 | """
144 | model = ResNet(input_size, Bottleneck, [3, 4, 6, 3], **kwargs)
145 |
146 | return model
147 |
148 |
149 | def ResNet_101(input_size, **kwargs):
150 | """ Constructs a ResNet-101 model.
151 | """
152 | model = ResNet(input_size, Bottleneck, [3, 4, 23, 3], **kwargs)
153 |
154 | return model
155 |
156 |
157 | def ResNet_152(input_size, **kwargs):
158 | """ Constructs a ResNet-152 model.
159 | """
160 | model = ResNet(input_size, Bottleneck, [3, 8, 36, 3], **kwargs)
161 |
162 | return model
163 |
--------------------------------------------------------------------------------
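A corresponding sketch for the plain ResNet variant (112x112 input assumed): the last stage produces a 2048x4x4 map, which is why the fully connected layer above is sized 2048 * 4 * 4 -> 512.

import torch

from uniportrait.curricular_face.backbone.model_resnet import ResNet_50

model = ResNet_50([112, 112]).eval()
with torch.no_grad():
    emb = model(torch.randn(1, 3, 112, 112))
print(emb.shape)  # torch.Size([1, 512])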
/uniportrait/curricular_face/inference.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 |
4 | import cv2
5 | import numpy as np
6 | import torch
7 | from tqdm.auto import tqdm
8 |
9 | from .backbone import get_model
10 |
11 |
12 | @torch.no_grad()
13 | def inference(name, weight, src_norm_dir):
14 | face_model = get_model(name)([112, 112])
15 | face_model.load_state_dict(torch.load(weight, map_location="cpu"))
16 | face_model = face_model.to("cpu")
17 | face_model.eval()
18 |
19 | id2src_norm = {}
20 | for src_id in sorted(list(os.listdir(src_norm_dir))):
21 | id2src_norm[src_id] = sorted(list(glob.glob(f"{os.path.join(src_norm_dir, src_id)}/*")))
22 |
23 | total_sims = []
24 | for id_name in tqdm(id2src_norm):
25 | src_face_embeddings = []
26 | for src_img_path in id2src_norm[id_name]:
27 | src_img = cv2.imread(src_img_path)
28 | src_img = cv2.resize(src_img, (112, 112))
29 | src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
30 | src_img = np.transpose(src_img, (2, 0, 1))
31 | src_img = torch.from_numpy(src_img).unsqueeze(0).float()
32 | src_img.div_(255).sub_(0.5).div_(0.5)
33 | embedding = face_model(src_img).detach().cpu().numpy()[0]
34 | embedding = embedding / np.linalg.norm(embedding)
35 | src_face_embeddings.append(embedding) # 512
36 |
37 | num = len(src_face_embeddings)
38 | src_face_embeddings = np.stack(src_face_embeddings) # n, 512
39 | sim = src_face_embeddings @ src_face_embeddings.T # n, n
40 | mean_sim = (np.sum(sim) - num * 1.0) / ((num - 1) * num)
41 | print(f"{id_name}: {mean_sim}")
42 | total_sims.append(mean_sim)
43 |
44 | return np.mean(total_sims)
45 |
46 |
47 | if __name__ == "__main__":
48 | name = 'IR_101'
49 | weight = "models/glint360k_curricular_face_r101_backbone.bin"
50 | src_norm_dir = "/disk1/hejunjie.hjj/data/normface-AFD-id-20"
51 | mean_sim = inference(name, weight, src_norm_dir)
52 | print(f"total: {mean_sim:.4f}") # total: 0.6299
53 |
--------------------------------------------------------------------------------
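The evaluation script above expects one sub-directory per identity, each containing aligned face crops, and reports the mean pairwise cosine similarity within every identity. A hedged invocation sketch (directory layout and paths are placeholders):

# data/normface-crops/
# ├── id_0001/   0.png 1.png ...
# └── id_0002/   0.png 1.png ...
from uniportrait.curricular_face.inference import inference

mean_sim = inference(
    name="IR_101",
    weight="models/glint360k_curricular_face_r101_backbone.bin",  # placeholder path
    src_norm_dir="data/normface-crops",                           # placeholder path
)
print(f"mean intra-identity similarity: {mean_sim:.4f}")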
/uniportrait/inversion.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/google/style-aligned/blob/main/inversion.py
2 |
3 | from __future__ import annotations
4 |
5 | from typing import Callable
6 |
7 | import numpy as np
8 | import torch
9 | from diffusers import StableDiffusionPipeline
10 | from tqdm import tqdm
11 |
12 | T = torch.Tensor
13 | InversionCallback = Callable[[StableDiffusionPipeline, int, T, dict[str, T]], dict[str, T]]
14 |
15 |
16 | def _encode_text_with_negative(model: StableDiffusionPipeline, prompt: str) -> T:
17 | device = model._execution_device
18 | prompt_embeds = model._encode_prompt(
19 | prompt, device=device, num_images_per_prompt=1, do_classifier_free_guidance=True,
20 | negative_prompt="")
21 | return prompt_embeds
22 |
23 |
24 | def _encode_image(model: StableDiffusionPipeline, image: np.ndarray) -> T:
25 | model.vae.to(dtype=torch.float32)
26 | image = torch.from_numpy(image).float() / 255.
27 | image = (image * 2 - 1).permute(2, 0, 1).unsqueeze(0)
28 | latent = model.vae.encode(image.to(model.vae.device))['latent_dist'].mean * model.vae.config.scaling_factor
29 | model.vae.to(dtype=torch.float16)
30 | return latent
31 |
32 |
33 | def _next_step(model: StableDiffusionPipeline, model_output: T, timestep: int, sample: T) -> T:
34 | timestep, next_timestep = min(
35 | timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps, 999), timestep
36 | alpha_prod_t = model.scheduler.alphas_cumprod[
37 | int(timestep)] if timestep >= 0 else model.scheduler.final_alpha_cumprod
38 | alpha_prod_t_next = model.scheduler.alphas_cumprod[int(next_timestep)]
39 | beta_prod_t = 1 - alpha_prod_t
40 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
41 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
42 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
43 | return next_sample
44 |
45 |
46 | def _get_noise_pred(model: StableDiffusionPipeline, latent: T, t: T, context: T, guidance_scale: float):
47 | latents_input = torch.cat([latent] * 2)
48 | noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
49 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
50 | noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
51 | # latents = next_step(model, noise_pred, t, latent)
52 | return noise_pred
53 |
54 |
55 | def _ddim_loop(model: StableDiffusionPipeline, z0, prompt, guidance_scale) -> T:
56 | all_latent = [z0]
57 | text_embedding = _encode_text_with_negative(model, prompt)
58 | image_embedding = torch.zeros_like(text_embedding[:, :1]).repeat(1, 4, 1) # for ip embedding
59 | text_embedding = torch.cat([text_embedding, image_embedding], dim=1)
60 | latent = z0.clone().detach().half()
61 | for i in tqdm(range(model.scheduler.num_inference_steps)):
62 | t = model.scheduler.timesteps[len(model.scheduler.timesteps) - i - 1]
63 | noise_pred = _get_noise_pred(model, latent, t, text_embedding, guidance_scale)
64 | latent = _next_step(model, noise_pred, t, latent)
65 | all_latent.append(latent)
66 | return torch.cat(all_latent).flip(0)
67 |
68 |
69 | def make_inversion_callback(zts, offset: int = 0) -> tuple[T, InversionCallback]:
70 | def callback_on_step_end(pipeline: StableDiffusionPipeline, i: int, t: T, callback_kwargs: dict[str, T]) -> dict[
71 | str, T]:
72 | latents = callback_kwargs['latents']
73 | latents[0] = zts[max(offset + 1, i + 1)].to(latents.device, latents.dtype)
74 | return {'latents': latents}
75 |
76 | return zts[offset], callback_on_step_end
77 |
78 |
79 | @torch.no_grad()
80 | def ddim_inversion(model: StableDiffusionPipeline, x0: np.ndarray, prompt: str, num_inference_steps: int,
81 | guidance_scale, ) -> T:
82 | z0 = _encode_image(model, x0)
83 | model.scheduler.set_timesteps(num_inference_steps, device=z0.device)
84 | zs = _ddim_loop(model, z0, prompt, guidance_scale)
85 | return zs
86 |
--------------------------------------------------------------------------------
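A hedged sketch of how the inversion utilities above could be wired to a diffusers StableDiffusionPipeline: invert a reference image with DDIM, then let the callback overwrite the first latent in the batch at every step so that sample replays the inversion trajectory. The model id, prompt, and offset are assumptions, and the code presumes a diffusers version that still exposes _encode_prompt and supports callback_on_step_end.

import numpy as np
import torch
from diffusers import DDIMScheduler, StableDiffusionPipeline

from uniportrait.inversion import ddim_inversion, make_inversion_callback

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

ref = np.zeros((512, 512, 3), dtype=np.uint8)  # stand-in for a real H x W x 3 reference image
zts = ddim_inversion(pipe, ref, "a portrait photo", num_inference_steps=50, guidance_scale=2.0)
zT, callback = make_inversion_callback(zts, offset=5)

out = pipe("a portrait photo", num_inference_steps=50,
           latents=zT[None].half(), callback_on_step_end=callback)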
/uniportrait/resampler.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2 | # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
3 |
4 | import math
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 |
10 | # FFN
11 | def FeedForward(dim, mult=4):
12 | inner_dim = int(dim * mult)
13 | return nn.Sequential(
14 | nn.LayerNorm(dim),
15 | nn.Linear(dim, inner_dim, bias=False),
16 | nn.GELU(),
17 | nn.Linear(inner_dim, dim, bias=False),
18 | )
19 |
20 |
21 | def reshape_tensor(x, heads):
22 | bs, length, width = x.shape
23 | # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
24 | x = x.view(bs, length, heads, -1)
25 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
26 | x = x.transpose(1, 2)
27 |     # shape remains (bs, n_heads, length, dim_per_head)
28 | x = x.reshape(bs, heads, length, -1)
29 | return x
30 |
31 |
32 | class PerceiverAttention(nn.Module):
33 | def __init__(self, *, dim, dim_head=64, heads=8):
34 | super().__init__()
35 | self.scale = dim_head ** -0.5
36 | self.dim_head = dim_head
37 | self.heads = heads
38 | inner_dim = dim_head * heads
39 |
40 | self.norm1 = nn.LayerNorm(dim)
41 | self.norm2 = nn.LayerNorm(dim)
42 |
43 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
44 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
45 | self.to_out = nn.Linear(inner_dim, dim, bias=False)
46 |
47 | def forward(self, x, latents, attention_mask=None):
48 | """
49 | Args:
50 | x (torch.Tensor): image features
51 | shape (b, n1, D)
52 | latents (torch.Tensor): latent features
53 | shape (b, n2, D)
54 | attention_mask (torch.Tensor): attention mask
55 | shape (b, n1, 1)
56 | """
57 | x = self.norm1(x)
58 | latents = self.norm2(latents)
59 |
60 | b, l, _ = latents.shape
61 |
62 | q = self.to_q(latents)
63 | kv_input = torch.cat((x, latents), dim=-2)
64 | k, v = self.to_kv(kv_input).chunk(2, dim=-1)
65 |
66 | q = reshape_tensor(q, self.heads)
67 | k = reshape_tensor(k, self.heads)
68 | v = reshape_tensor(v, self.heads)
69 |
70 | # attention
71 | scale = 1 / math.sqrt(math.sqrt(self.dim_head))
72 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
73 | if attention_mask is not None:
74 | attention_mask = attention_mask.transpose(1, 2) # (b, 1, n1)
75 | attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, :, :1]).repeat(1, 1, l)],
76 | dim=2) # b, 1, n1+n2
77 | attention_mask = (attention_mask - 1.) * 100. # 0 means kept and -100 means dropped
78 | attention_mask = attention_mask.unsqueeze(1)
79 | weight = weight + attention_mask # b, h, n2, n1+n2
80 |
81 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
82 | out = weight @ v
83 |
84 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
85 |
86 | return self.to_out(out)
87 |
88 |
89 | class UniPortraitFaceIDResampler(torch.nn.Module):
90 | def __init__(
91 | self,
92 | intrinsic_id_embedding_dim=512,
93 | structure_embedding_dim=64 + 128 + 256 + 1280,
94 | num_tokens=16,
95 | depth=6,
96 | dim=768,
97 | dim_head=64,
98 | heads=12,
99 | ff_mult=4,
100 | output_dim=768,
101 | ):
102 | super().__init__()
103 |
104 | self.latents = torch.nn.Parameter(torch.randn(1, num_tokens, dim) / dim ** 0.5)
105 |
106 | self.proj_id = torch.nn.Sequential(
107 | torch.nn.Linear(intrinsic_id_embedding_dim, intrinsic_id_embedding_dim * 2),
108 | torch.nn.GELU(),
109 | torch.nn.Linear(intrinsic_id_embedding_dim * 2, dim),
110 | )
111 | self.proj_clip = torch.nn.Sequential(
112 | torch.nn.Linear(structure_embedding_dim, structure_embedding_dim * 2),
113 | torch.nn.GELU(),
114 | torch.nn.Linear(structure_embedding_dim * 2, dim),
115 | )
116 |
117 | self.layers = torch.nn.ModuleList([])
118 | for _ in range(depth):
119 | self.layers.append(
120 | torch.nn.ModuleList(
121 | [
122 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
123 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
124 | FeedForward(dim=dim, mult=ff_mult),
125 | ]
126 | )
127 | )
128 |
129 | self.proj_out = torch.nn.Linear(dim, output_dim)
130 | self.norm_out = torch.nn.LayerNorm(output_dim)
131 |
132 | def forward(
133 | self,
134 | intrinsic_id_embeds,
135 | structure_embeds,
136 | structure_scale=1.0,
137 | intrinsic_id_attention_mask=None,
138 | structure_attention_mask=None
139 | ):
140 |
141 | latents = self.latents.repeat(intrinsic_id_embeds.size(0), 1, 1)
142 |
143 | intrinsic_id_embeds = self.proj_id(intrinsic_id_embeds)
144 | structure_embeds = self.proj_clip(structure_embeds)
145 |
146 | for attn1, attn2, ff in self.layers:
147 | latents = attn1(intrinsic_id_embeds, latents, intrinsic_id_attention_mask) + latents
148 | latents = structure_scale * attn2(structure_embeds, latents, structure_attention_mask) + latents
149 | latents = ff(latents) + latents
150 |
151 | latents = self.proj_out(latents)
152 | return self.norm_out(latents)
153 |
--------------------------------------------------------------------------------
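A shape-level sketch of the resampler's contract, matching the way uniportrait_pipeline.py feeds it: 49 intrinsic-ID tokens taken from the 7x7x512 recognition feature and 256 structure tokens of width 64+128+256+1280, resampled into 16 face-ID prompt tokens.

import torch

from uniportrait.resampler import UniPortraitFaceIDResampler

resampler = UniPortraitFaceIDResampler()                  # defaults: 16 tokens, dim 768
intrinsic_id = torch.randn(1, 49, 512)                    # b, 7*7, 512 recognition features
structure = torch.randn(1, 256, 64 + 128 + 256 + 1280)    # b, 16*16, fused structure features
tokens = resampler(intrinsic_id, structure, structure_scale=0.5)
print(tokens.shape)  # torch.Size([1, 16, 768])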
/uniportrait/uniportrait_attention_processor.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from diffusers.models.lora import LoRALinearLayer
6 |
7 |
8 | class AttentionArgs(object):
9 | def __init__(self) -> None:
10 | # ip condition
11 | self.ip_scale = 0.0
12 | self.ip_mask = None # ip attention mask
13 |
14 | # faceid condition
15 | self.lora_scale = 0.0 # lora for single faceid
16 | self.multi_id_lora_scale = 0.0 # lora for multiple faceids
17 | self.faceid_scale = 0.0
18 | self.num_faceids = 0
19 | self.faceid_mask = None # faceid attention mask; if not None, it will override the routing map
20 |
21 | # style aligned
22 | self.enable_share_attn: bool = False
23 | self.adain_queries_and_keys: bool = False
24 | self.shared_score_scale: float = 1.0
25 | self.shared_score_shift: float = 0.0
26 |
27 | def reset(self):
28 | # ip condition
29 | self.ip_scale = 0.0
30 | self.ip_mask = None # ip attention mask
31 |
32 | # faceid condition
33 | self.lora_scale = 0.0 # lora for single faceid
34 | self.multi_id_lora_scale = 0.0 # lora for multiple faceids
35 | self.faceid_scale = 0.0
36 | self.num_faceids = 0
37 | self.faceid_mask = None # faceid attention mask; if not None, it will override the routing map
38 |
39 | # style aligned
40 | self.enable_share_attn: bool = False
41 | self.adain_queries_and_keys: bool = False
42 | self.shared_score_scale: float = 1.0
43 | self.shared_score_shift: float = 0.0
44 |
45 | def __repr__(self):
46 | indent_str = ' '
47 | s = f",\n{indent_str}".join(f"{attr}={value}" for attr, value in vars(self).items())
48 | return self.__class__.__name__ + '(' + f'\n{indent_str}' + s + ')'
49 |
50 |
51 | attn_args = AttentionArgs()
52 |
53 |
54 | def expand_first(feat, scale=1., ):
55 | b = feat.shape[0]
56 | feat_style = torch.stack((feat[0], feat[b // 2])).unsqueeze(1)
57 | if scale == 1:
58 | feat_style = feat_style.expand(2, b // 2, *feat.shape[1:])
59 | else:
60 | feat_style = feat_style.repeat(1, b // 2, 1, 1, 1)
61 | feat_style = torch.cat([feat_style[:, :1], scale * feat_style[:, 1:]], dim=1)
62 | return feat_style.reshape(*feat.shape)
63 |
64 |
65 | def concat_first(feat, dim=2, scale=1.):
66 | feat_style = expand_first(feat, scale=scale)
67 | return torch.cat((feat, feat_style), dim=dim)
68 |
69 |
70 | def calc_mean_std(feat, eps: float = 1e-5):
71 | feat_std = (feat.var(dim=-2, keepdims=True) + eps).sqrt()
72 | feat_mean = feat.mean(dim=-2, keepdims=True)
73 | return feat_mean, feat_std
74 |
75 |
76 | def adain(feat):
77 | feat_mean, feat_std = calc_mean_std(feat)
78 | feat_style_mean = expand_first(feat_mean)
79 | feat_style_std = expand_first(feat_std)
80 | feat = (feat - feat_mean) / feat_std
81 | feat = feat * feat_style_std + feat_style_mean
82 | return feat
83 |
84 |
85 | class UniPortraitLoRAAttnProcessor2_0(nn.Module):
86 |
87 | def __init__(
88 | self,
89 | hidden_size=None,
90 | cross_attention_dim=None,
91 | rank=128,
92 | network_alpha=None,
93 | ):
94 | super().__init__()
95 |
96 | self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
97 | self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
98 | self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
99 | self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
100 |
101 | self.to_q_multi_id_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
102 | self.to_k_multi_id_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
103 | self.to_v_multi_id_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
104 | self.to_out_multi_id_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
105 |
106 | def __call__(
107 | self,
108 | attn,
109 | hidden_states,
110 | encoder_hidden_states=None,
111 | attention_mask=None,
112 | temb=None,
113 | *args,
114 | **kwargs,
115 | ):
116 | residual = hidden_states
117 |
118 | if attn.spatial_norm is not None:
119 | hidden_states = attn.spatial_norm(hidden_states, temb)
120 |
121 | input_ndim = hidden_states.ndim
122 |
123 | if input_ndim == 4:
124 | batch_size, channel, height, width = hidden_states.shape
125 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
126 |
127 | batch_size, sequence_length, _ = (
128 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
129 | )
130 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
131 |
132 | if attn.group_norm is not None:
133 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
134 |
135 | if encoder_hidden_states is None:
136 | encoder_hidden_states = hidden_states
137 | elif attn.norm_cross:
138 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
139 |
140 | query = attn.to_q(hidden_states)
141 | key = attn.to_k(encoder_hidden_states)
142 | value = attn.to_v(encoder_hidden_states)
143 | if attn_args.lora_scale > 0.0:
144 | query = query + attn_args.lora_scale * self.to_q_lora(hidden_states)
145 | key = key + attn_args.lora_scale * self.to_k_lora(encoder_hidden_states)
146 | value = value + attn_args.lora_scale * self.to_v_lora(encoder_hidden_states)
147 | elif attn_args.multi_id_lora_scale > 0.0:
148 | query = query + attn_args.multi_id_lora_scale * self.to_q_multi_id_lora(hidden_states)
149 | key = key + attn_args.multi_id_lora_scale * self.to_k_multi_id_lora(encoder_hidden_states)
150 | value = value + attn_args.multi_id_lora_scale * self.to_v_multi_id_lora(encoder_hidden_states)
151 |
152 | inner_dim = key.shape[-1]
153 | head_dim = inner_dim // attn.heads
154 |
155 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
156 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
157 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
158 |
159 | if attn_args.enable_share_attn:
160 | if attn_args.adain_queries_and_keys:
161 | query = adain(query)
162 | key = adain(key)
163 | key = concat_first(key, -2, scale=attn_args.shared_score_scale)
164 | value = concat_first(value, -2)
165 | if attn_args.shared_score_shift != 0:
166 | attention_mask = torch.zeros_like(key[:, :, :, :1]).transpose(-1, -2) # b, h, 1, k
167 | attention_mask[:, :, :, query.shape[2]:] += attn_args.shared_score_shift
168 | hidden_states = F.scaled_dot_product_attention(
169 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale
170 | )
171 | else:
172 | hidden_states = F.scaled_dot_product_attention(
173 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale
174 | )
175 | else:
176 | hidden_states = F.scaled_dot_product_attention(
177 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale
178 | )
179 |
180 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
181 | hidden_states = hidden_states.to(query.dtype)
182 |
183 | # linear proj
184 | output_hidden_states = attn.to_out[0](hidden_states)
185 | if attn_args.lora_scale > 0.0:
186 | output_hidden_states = output_hidden_states + attn_args.lora_scale * self.to_out_lora(hidden_states)
187 | elif attn_args.multi_id_lora_scale > 0.0:
188 | output_hidden_states = output_hidden_states + attn_args.multi_id_lora_scale * self.to_out_multi_id_lora(
189 | hidden_states)
190 | hidden_states = output_hidden_states
191 |
192 | # dropout
193 | hidden_states = attn.to_out[1](hidden_states)
194 |
195 | if input_ndim == 4:
196 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
197 |
198 | if attn.residual_connection:
199 | hidden_states = hidden_states + residual
200 |
201 | hidden_states = hidden_states / attn.rescale_output_factor
202 |
203 | return hidden_states
204 |
205 |
206 | class UniPortraitLoRAIPAttnProcessor2_0(nn.Module):
207 |
208 | def __init__(self, hidden_size, cross_attention_dim=None, rank=128, network_alpha=None,
209 | num_ip_tokens=4, num_faceid_tokens=16):
210 | super().__init__()
211 |
212 | self.num_ip_tokens = num_ip_tokens
213 | self.num_faceid_tokens = num_faceid_tokens
214 |
215 | self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
216 | self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
217 | self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
218 | self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
219 |
220 | self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
221 | self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
222 |
223 | self.to_k_faceid = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
224 | self.to_v_faceid = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
225 |
226 | self.to_q_multi_id_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
227 | self.to_k_multi_id_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
228 | self.to_v_multi_id_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
229 | self.to_out_multi_id_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
230 |
231 | self.to_q_router = nn.Sequential(
232 | nn.Linear(hidden_size, hidden_size * 2),
233 | nn.GELU(),
234 | nn.Linear(hidden_size * 2, hidden_size, bias=False),
235 | )
236 | self.to_k_router = nn.Sequential(
237 | nn.Linear(cross_attention_dim or hidden_size, (cross_attention_dim or hidden_size) * 2),
238 | nn.GELU(),
239 | nn.Linear((cross_attention_dim or hidden_size) * 2, hidden_size, bias=False),
240 | )
241 | self.aggr_router = nn.Linear(num_faceid_tokens, 1)
242 |
243 | def __call__(
244 | self,
245 | attn,
246 | hidden_states,
247 | encoder_hidden_states=None,
248 | attention_mask=None,
249 | temb=None,
250 | *args,
251 | **kwargs,
252 | ):
253 | residual = hidden_states
254 |
255 | if attn.spatial_norm is not None:
256 | hidden_states = attn.spatial_norm(hidden_states, temb)
257 |
258 | input_ndim = hidden_states.ndim
259 |
260 | if input_ndim == 4:
261 | batch_size, channel, height, width = hidden_states.shape
262 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
263 |
264 | batch_size, sequence_length, _ = (
265 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
266 | )
267 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
268 |
269 | if attn.group_norm is not None:
270 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
271 |
272 | if encoder_hidden_states is None:
273 | encoder_hidden_states = hidden_states
274 | else:
275 | # split hidden states
276 | faceid_end = encoder_hidden_states.shape[1]
277 | ip_end = faceid_end - self.num_faceid_tokens * attn_args.num_faceids
278 | text_end = ip_end - self.num_ip_tokens
279 |
280 | prompt_hidden_states = encoder_hidden_states[:, :text_end]
281 | ip_hidden_states = encoder_hidden_states[:, text_end: ip_end]
282 | faceid_hidden_states = encoder_hidden_states[:, ip_end: faceid_end]
283 |
284 | encoder_hidden_states = prompt_hidden_states
285 | if attn.norm_cross:
286 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
287 |
288 | # for router
289 | if attn_args.num_faceids > 1:
290 | router_query = self.to_q_router(hidden_states) # bs, s*s, dim
291 | router_hidden_states = faceid_hidden_states.reshape(batch_size, attn_args.num_faceids,
292 | self.num_faceid_tokens, -1) # bs, num, id_tokens, d
293 | router_hidden_states = self.aggr_router(router_hidden_states.transpose(-1, -2)).squeeze(-1) # bs, num, d
294 | router_key = self.to_k_router(router_hidden_states) # bs, num, dim
295 | router_logits = torch.bmm(router_query, router_key.transpose(-1, -2)) # bs, s*s, num
296 | index = router_logits.max(dim=-1, keepdim=True)[1]
297 | routing_map = torch.zeros_like(router_logits).scatter_(-1, index, 1.0)
298 | routing_map = routing_map.transpose(1, 2).unsqueeze(-1) # bs, num, s*s, 1
299 | else:
300 | routing_map = hidden_states.new_ones(size=(1, 1, hidden_states.shape[1], 1))
301 |
302 | # for text
303 | query = attn.to_q(hidden_states)
304 | key = attn.to_k(encoder_hidden_states)
305 | value = attn.to_v(encoder_hidden_states)
306 | if attn_args.lora_scale > 0.0:
307 | query = query + attn_args.lora_scale * self.to_q_lora(hidden_states)
308 | key = key + attn_args.lora_scale * self.to_k_lora(encoder_hidden_states)
309 | value = value + attn_args.lora_scale * self.to_v_lora(encoder_hidden_states)
310 | elif attn_args.multi_id_lora_scale > 0.0:
311 | query = query + attn_args.multi_id_lora_scale * self.to_q_multi_id_lora(hidden_states)
312 | key = key + attn_args.multi_id_lora_scale * self.to_k_multi_id_lora(encoder_hidden_states)
313 | value = value + attn_args.multi_id_lora_scale * self.to_v_multi_id_lora(encoder_hidden_states)
314 |
315 | inner_dim = key.shape[-1]
316 | head_dim = inner_dim // attn.heads
317 |
318 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
319 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
320 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
321 |
322 | hidden_states = F.scaled_dot_product_attention(
323 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale
324 | )
325 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
326 | hidden_states = hidden_states.to(query.dtype)
327 |
328 | # for ip-adapter
329 | if attn_args.ip_scale > 0.0:
330 | ip_key = self.to_k_ip(ip_hidden_states)
331 | ip_value = self.to_v_ip(ip_hidden_states)
332 |
333 | ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
334 | ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
335 |
336 | ip_hidden_states = F.scaled_dot_product_attention(
337 | query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=attn.scale
338 | )
339 | ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
340 | ip_hidden_states = ip_hidden_states.to(query.dtype)
341 |
342 | if attn_args.ip_mask is not None:
343 | ip_mask = attn_args.ip_mask
344 | h, w = ip_mask.shape[-2:]
345 | ratio = (h * w / query.shape[2]) ** 0.5
346 | ip_mask = torch.nn.functional.interpolate(ip_mask, scale_factor=1 / ratio,
347 | mode='nearest').reshape(
348 | [1, -1, 1])
349 | ip_hidden_states = ip_hidden_states * ip_mask
350 |
351 | if attn_args.enable_share_attn:
352 | ip_hidden_states[0] = 0.
353 | ip_hidden_states[batch_size // 2] = 0.
354 | else:
355 | ip_hidden_states = torch.zeros_like(hidden_states)
356 |
357 | # for faceid-adapter
358 | if attn_args.faceid_scale > 0.0:
359 | faceid_key = self.to_k_faceid(faceid_hidden_states)
360 | faceid_value = self.to_v_faceid(faceid_hidden_states)
361 |
362 | faceid_query = query[:, None].expand(-1, attn_args.num_faceids, -1, -1,
363 | -1) # 2*bs, num, heads, s*s, dim/heads
364 | faceid_key = faceid_key.view(batch_size, attn_args.num_faceids, self.num_faceid_tokens, attn.heads,
365 | head_dim).transpose(2, 3)
366 | faceid_value = faceid_value.view(batch_size, attn_args.num_faceids, self.num_faceid_tokens, attn.heads,
367 | head_dim).transpose(2, 3)
368 |
369 | faceid_hidden_states = F.scaled_dot_product_attention(
370 | faceid_query, faceid_key, faceid_value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=attn.scale
371 | ) # 2*bs, num, heads, s*s, dim/heads
372 |
373 | faceid_hidden_states = faceid_hidden_states.transpose(2, 3).reshape(batch_size, attn_args.num_faceids, -1,
374 | attn.heads * head_dim)
375 | faceid_hidden_states = faceid_hidden_states.to(query.dtype) # 2*bs, num, s*s, dim
376 |
377 | if attn_args.faceid_mask is not None:
378 | faceid_mask = attn_args.faceid_mask # 1, num, h, w
379 | h, w = faceid_mask.shape[-2:]
380 | ratio = (h * w / query.shape[2]) ** 0.5
381 | faceid_mask = F.interpolate(faceid_mask, scale_factor=1 / ratio,
382 | mode='bilinear').flatten(2).unsqueeze(-1) # 1, num, s*s, 1
383 | faceid_mask = faceid_mask / faceid_mask.sum(1, keepdim=True).clip(min=1e-3) # 1, num, s*s, 1
384 | faceid_hidden_states = (faceid_mask * faceid_hidden_states).sum(1) # 2*bs, s*s, dim
385 | else:
386 | faceid_hidden_states = (routing_map * faceid_hidden_states).sum(1) # 2*bs, s*s, dim
387 |
388 | if attn_args.enable_share_attn:
389 | faceid_hidden_states[0] = 0.
390 | faceid_hidden_states[batch_size // 2] = 0.
391 | else:
392 | faceid_hidden_states = torch.zeros_like(hidden_states)
393 |
394 | hidden_states = hidden_states + \
395 | attn_args.ip_scale * ip_hidden_states + \
396 | attn_args.faceid_scale * faceid_hidden_states
397 |
398 | # linear proj
399 | output_hidden_states = attn.to_out[0](hidden_states)
400 | if attn_args.lora_scale > 0.0:
401 | output_hidden_states = output_hidden_states + attn_args.lora_scale * self.to_out_lora(hidden_states)
402 | elif attn_args.multi_id_lora_scale > 0.0:
403 | output_hidden_states = output_hidden_states + attn_args.multi_id_lora_scale * self.to_out_multi_id_lora(
404 | hidden_states)
405 | hidden_states = output_hidden_states
406 |
407 | # dropout
408 | hidden_states = attn.to_out[1](hidden_states)
409 |
410 | if input_ndim == 4:
411 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
412 |
413 | if attn.residual_connection:
414 | hidden_states = hidden_states + residual
415 |
416 | hidden_states = hidden_states / attn.rescale_output_factor
417 |
418 | return hidden_states
419 |
420 |
421 | # for controlnet
422 | class UniPortraitCNAttnProcessor2_0:
423 | def __init__(self, num_ip_tokens=4, num_faceid_tokens=16):
424 |
425 | self.num_ip_tokens = num_ip_tokens
426 | self.num_faceid_tokens = num_faceid_tokens
427 |
428 | def __call__(
429 | self,
430 | attn,
431 | hidden_states,
432 | encoder_hidden_states=None,
433 | attention_mask=None,
434 | temb=None,
435 | *args,
436 | **kwargs,
437 | ):
438 | residual = hidden_states
439 |
440 | if attn.spatial_norm is not None:
441 | hidden_states = attn.spatial_norm(hidden_states, temb)
442 |
443 | input_ndim = hidden_states.ndim
444 |
445 | if input_ndim == 4:
446 | batch_size, channel, height, width = hidden_states.shape
447 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
448 |
449 | batch_size, sequence_length, _ = (
450 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
451 | )
452 | if attention_mask is not None:
453 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
454 | # scaled_dot_product_attention expects attention_mask shape to be
455 | # (batch, heads, source_length, target_length)
456 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
457 |
458 | if attn.group_norm is not None:
459 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
460 |
461 | query = attn.to_q(hidden_states)
462 |
463 | if encoder_hidden_states is None:
464 | encoder_hidden_states = hidden_states
465 | else:
466 | text_end = encoder_hidden_states.shape[1] - self.num_faceid_tokens * attn_args.num_faceids \
467 | - self.num_ip_tokens
468 | encoder_hidden_states = encoder_hidden_states[:, :text_end] # only use text
469 | if attn.norm_cross:
470 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
471 |
472 | key = attn.to_k(encoder_hidden_states)
473 | value = attn.to_v(encoder_hidden_states)
474 |
475 | inner_dim = key.shape[-1]
476 | head_dim = inner_dim // attn.heads
477 |
478 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
479 |
480 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
481 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
482 |
483 | hidden_states = F.scaled_dot_product_attention(
484 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale
485 | )
486 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
487 | hidden_states = hidden_states.to(query.dtype)
488 |
489 | # linear proj
490 | hidden_states = attn.to_out[0](hidden_states)
491 | # dropout
492 | hidden_states = attn.to_out[1](hidden_states)
493 |
494 | if input_ndim == 4:
495 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
496 |
497 | if attn.residual_connection:
498 | hidden_states = hidden_states + residual
499 |
500 | hidden_states = hidden_states / attn.rescale_output_factor
501 |
502 | return hidden_states
503 |
--------------------------------------------------------------------------------
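The processors above read their per-generation configuration from the module-level attn_args singleton rather than from call arguments. A hedged sketch of how it might be configured around a pipeline call (the scale values are illustrative, not recommended settings):

from uniportrait.uniportrait_attention_processor import attn_args

attn_args.reset()             # clear state left over from a previous generation
attn_args.lora_scale = 1.0    # enable the single-ID LoRA branch
attn_args.faceid_scale = 0.7  # strength of the face-ID token contribution
attn_args.num_faceids = 1     # number of identity token groups appended to the prompt
# ... run the diffusion pipeline here; every processor picks these values up globally ...
attn_args.reset()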
/uniportrait/uniportrait_pipeline.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from diffusers import ControlNetModel
5 | from diffusers.pipelines.controlnet import MultiControlNetModel
6 | from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
7 |
8 | from .curricular_face.backbone import get_model
9 | from .resampler import UniPortraitFaceIDResampler
10 | from .uniportrait_attention_processor import UniPortraitCNAttnProcessor2_0 as UniPortraitCNAttnProcessor
11 | from .uniportrait_attention_processor import UniPortraitLoRAAttnProcessor2_0 as UniPortraitLoRAAttnProcessor
12 | from .uniportrait_attention_processor import UniPortraitLoRAIPAttnProcessor2_0 as UniPortraitLoRAIPAttnProcessor
13 |
14 |
15 | class ImageProjModel(nn.Module):
16 | """Projection Model"""
17 |
18 | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
19 | super().__init__()
20 |
21 | self.cross_attention_dim = cross_attention_dim
22 | self.clip_extra_context_tokens = clip_extra_context_tokens
23 | self.proj = nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
24 | self.norm = nn.LayerNorm(cross_attention_dim)
25 |
26 | def forward(self, image_embeds):
27 | embeds = image_embeds # b, c
28 | clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens,
29 | self.cross_attention_dim)
30 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
31 | return clip_extra_context_tokens
32 |
33 |
34 | class UniPortraitPipeline:
35 |
36 | def __init__(self, sd_pipe, image_encoder_path, ip_ckpt=None, face_backbone_ckpt=None, uniportrait_faceid_ckpt=None,
37 | uniportrait_router_ckpt=None, num_ip_tokens=4, num_faceid_tokens=16,
38 | lora_rank=128, device=torch.device("cuda"), torch_dtype=torch.float16):
39 |
40 | self.image_encoder_path = image_encoder_path
41 | self.ip_ckpt = ip_ckpt
42 | self.uniportrait_faceid_ckpt = uniportrait_faceid_ckpt
43 | self.uniportrait_router_ckpt = uniportrait_router_ckpt
44 |
45 | self.num_ip_tokens = num_ip_tokens
46 | self.num_faceid_tokens = num_faceid_tokens
47 | self.lora_rank = lora_rank
48 |
49 | self.device = device
50 | self.torch_dtype = torch_dtype
51 |
52 | self.pipe = sd_pipe.to(self.device)
53 |
54 | # load clip image encoder
55 | self.clip_image_processor = CLIPImageProcessor(size={"shortest_edge": 224}, do_center_crop=False,
56 | use_square_size=True)
57 | self.clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
58 | self.device, dtype=self.torch_dtype)
59 | # load face backbone
60 | self.facerecog_model = get_model("IR_101")([112, 112])
61 | self.facerecog_model.load_state_dict(torch.load(face_backbone_ckpt, map_location="cpu"))
62 | self.facerecog_model = self.facerecog_model.to(self.device, dtype=torch_dtype)
63 | self.facerecog_model.eval()
64 | # image proj model
65 | self.image_proj_model = self.init_image_proj()
66 | # faceid proj model
67 | self.faceid_proj_model = self.init_faceid_proj()
68 | # set uniportrait and ip adapter
69 | self.set_uniportrait_and_ip_adapter()
70 | # load uniportrait and ip adapter
71 | self.load_uniportrait_and_ip_adapter()
72 |
73 | def init_image_proj(self):
74 | image_proj_model = ImageProjModel(
75 | cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
76 | clip_embeddings_dim=self.clip_image_encoder.config.projection_dim,
77 | clip_extra_context_tokens=self.num_ip_tokens,
78 | ).to(self.device, dtype=self.torch_dtype)
79 | return image_proj_model
80 |
81 | def init_faceid_proj(self):
82 | faceid_proj_model = UniPortraitFaceIDResampler(
83 | intrinsic_id_embedding_dim=512,
84 | structure_embedding_dim=64 + 128 + 256 + self.clip_image_encoder.config.hidden_size,
85 | num_tokens=16, depth=6,
86 | dim=self.pipe.unet.config.cross_attention_dim, dim_head=64,
87 | heads=12, ff_mult=4,
88 | output_dim=self.pipe.unet.config.cross_attention_dim
89 | ).to(self.device, dtype=self.torch_dtype)
90 | return faceid_proj_model
91 |
92 | def set_uniportrait_and_ip_adapter(self):
93 | unet = self.pipe.unet
94 | attn_procs = {}
95 | for name in unet.attn_processors.keys():
96 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
97 | if name.startswith("mid_block"):
98 | hidden_size = unet.config.block_out_channels[-1]
99 | elif name.startswith("up_blocks"):
100 | block_id = int(name[len("up_blocks.")])
101 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
102 | elif name.startswith("down_blocks"):
103 | block_id = int(name[len("down_blocks.")])
104 | hidden_size = unet.config.block_out_channels[block_id]
105 | if cross_attention_dim is None:
106 | attn_procs[name] = UniPortraitLoRAAttnProcessor(
107 | hidden_size=hidden_size,
108 | cross_attention_dim=cross_attention_dim,
109 | rank=self.lora_rank,
110 | ).to(self.device, dtype=self.torch_dtype).eval()
111 | else:
112 | attn_procs[name] = UniPortraitLoRAIPAttnProcessor(
113 | hidden_size=hidden_size,
114 | cross_attention_dim=cross_attention_dim,
115 | rank=self.lora_rank,
116 | num_ip_tokens=self.num_ip_tokens,
117 | num_faceid_tokens=self.num_faceid_tokens,
118 | ).to(self.device, dtype=self.torch_dtype).eval()
119 | unet.set_attn_processor(attn_procs)
120 | if hasattr(self.pipe, "controlnet"):
121 | if isinstance(self.pipe.controlnet, ControlNetModel):
122 | self.pipe.controlnet.set_attn_processor(
123 | UniPortraitCNAttnProcessor(
124 | num_ip_tokens=self.num_ip_tokens,
125 | num_faceid_tokens=self.num_faceid_tokens,
126 | )
127 | )
128 | elif isinstance(self.pipe.controlnet, MultiControlNetModel):
129 | for module in self.pipe.controlnet.nets:
130 | module.set_attn_processor(
131 | UniPortraitCNAttnProcessor(
132 | num_ip_tokens=self.num_ip_tokens,
133 | num_faceid_tokens=self.num_faceid_tokens,
134 | )
135 | )
136 | else:
137 |                 raise ValueError(f"unsupported controlnet type: {type(self.pipe.controlnet)}")
138 |
139 | def load_uniportrait_and_ip_adapter(self):
140 | if self.ip_ckpt:
141 | print(f"loading from {self.ip_ckpt}...")
142 | state_dict = torch.load(self.ip_ckpt, map_location="cpu")
143 | self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=False)
144 | ip_layers = nn.ModuleList(self.pipe.unet.attn_processors.values())
145 | ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)
146 |
147 | if self.uniportrait_faceid_ckpt:
148 | print(f"loading from {self.uniportrait_faceid_ckpt}...")
149 | state_dict = torch.load(self.uniportrait_faceid_ckpt, map_location="cpu")
150 | self.faceid_proj_model.load_state_dict(state_dict["faceid_proj"], strict=True)
151 | ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
152 | ip_layers.load_state_dict(state_dict["faceid_adapter"], strict=False)
153 |
154 | if self.uniportrait_router_ckpt:
155 | print(f"loading from {self.uniportrait_router_ckpt}...")
156 | state_dict = torch.load(self.uniportrait_router_ckpt, map_location="cpu")
157 | router_state_dict = {}
158 | for k, v in state_dict["faceid_adapter"].items():
159 | if "lora." in k:
160 | router_state_dict[k.replace("lora.", "multi_id_lora.")] = v
161 | elif "router." in k:
162 | router_state_dict[k] = v
163 | ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
164 | ip_layers.load_state_dict(router_state_dict, strict=False)
165 |
166 | @torch.inference_mode()
167 | def get_ip_embeds(self, pil_ip_image):
168 | ip_image = self.clip_image_processor(images=pil_ip_image, return_tensors="pt").pixel_values
169 | ip_image = ip_image.to(self.device, dtype=self.torch_dtype) # (b, 3, 224, 224), values being normalized
170 | ip_embeds = self.clip_image_encoder(ip_image).image_embeds
171 | ip_prompt_embeds = self.image_proj_model(ip_embeds)
172 | uncond_ip_prompt_embeds = self.image_proj_model(torch.zeros_like(ip_embeds))
173 | return ip_prompt_embeds, uncond_ip_prompt_embeds
174 |
175 | @torch.inference_mode()
176 | def get_single_faceid_embeds(self, pil_face_images, face_structure_scale):
177 | face_clip_image = self.clip_image_processor(images=pil_face_images, return_tensors="pt").pixel_values
178 | face_clip_image = face_clip_image.to(self.device, dtype=self.torch_dtype) # (b, 3, 224, 224)
179 | face_clip_embeds = self.clip_image_encoder(
180 | face_clip_image, output_hidden_states=True).hidden_states[-2][:, 1:] # b, 256, 1280
181 |
182 | OPENAI_CLIP_MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=self.device,
183 | dtype=self.torch_dtype).reshape(-1, 1, 1)
184 | OPENAI_CLIP_STD = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=self.device,
185 | dtype=self.torch_dtype).reshape(-1, 1, 1)
186 | facerecog_image = face_clip_image * OPENAI_CLIP_STD + OPENAI_CLIP_MEAN # [0, 1]
187 | facerecog_image = torch.clamp((facerecog_image - 0.5) / 0.5, -1, 1) # [-1, 1]
188 | facerecog_image = F.interpolate(facerecog_image, size=(112, 112), mode="bilinear", align_corners=False)
189 | facerecog_embeds = self.facerecog_model(facerecog_image, return_mid_feats=True)[1]
190 |
191 | face_intrinsic_id_embeds = facerecog_embeds[-1] # (b, 512, 7, 7)
192 | face_intrinsic_id_embeds = face_intrinsic_id_embeds.flatten(2).permute(0, 2, 1) # b, 49, 512
193 |
194 | facerecog_structure_embeds = facerecog_embeds[:-1] # (b, 64, 56, 56), (b, 128, 28, 28), (b, 256, 14, 14)
195 | facerecog_structure_embeds = torch.cat([
196 | F.interpolate(feat, size=(16, 16), mode="bilinear", align_corners=False)
197 | for feat in facerecog_structure_embeds], dim=1) # b, 448, 16, 16
198 | facerecog_structure_embeds = facerecog_structure_embeds.flatten(2).permute(0, 2, 1) # b, 256, 448
199 | face_structure_embeds = torch.cat([facerecog_structure_embeds, face_clip_embeds], dim=-1) # b, 256, 1728
200 |
201 | uncond_face_clip_embeds = self.clip_image_encoder(
202 | torch.zeros_like(face_clip_image[:1]), output_hidden_states=True).hidden_states[-2][:, 1:] # 1, 256, 1280
203 | uncond_face_structure_embeds = torch.cat(
204 | [torch.zeros_like(facerecog_structure_embeds[:1]), uncond_face_clip_embeds], dim=-1) # 1, 256, 1728
205 |
206 | faceid_prompt_embeds = self.faceid_proj_model(
207 | face_intrinsic_id_embeds.flatten(0, 1).unsqueeze(0),
208 | face_structure_embeds.flatten(0, 1).unsqueeze(0),
209 | structure_scale=face_structure_scale,
210 |         ) # [1, 16, 768]
211 |
212 | uncond_faceid_prompt_embeds = self.faceid_proj_model(
213 | torch.zeros_like(face_intrinsic_id_embeds[:1]),
214 | uncond_face_structure_embeds,
215 | structure_scale=face_structure_scale,
216 | ) # [1, 16, 768]
217 |
218 | return faceid_prompt_embeds, uncond_faceid_prompt_embeds
219 |
220 | def generate(
221 | self,
222 | prompt=None,
223 | negative_prompt=None,
224 | pil_ip_image=None,
225 | cond_faceids=None,
226 | face_structure_scale=0.0,
227 | seed=-1,
228 | guidance_scale=7.5,
229 | num_inference_steps=30,
230 | zT=None,
231 | **kwargs,
232 | ):
233 | """
234 | Args:
235 | prompt:
236 | negative_prompt:
237 | pil_ip_image:
238 | cond_faceids: [
239 | {
240 | "refs": [PIL.Image] or PIL.Image,
241 | (Optional) "mix_refs": [PIL.Image],
242 | (Optional) "mix_scales": [float],
243 | },
244 | ...
245 | ]
246 | face_structure_scale:
247 | seed:
248 | guidance_scale:
249 | num_inference_steps:
250 | zT:
251 | **kwargs:
252 | Returns:
253 | """
254 |
255 | if seed is not None:
256 | torch.manual_seed(seed)
257 | torch.cuda.manual_seed_all(seed)
258 |
259 | with torch.inference_mode():
260 | prompt_embeds, negative_prompt_embeds = self.pipe.encode_prompt(
261 | prompt, device=self.device, num_images_per_prompt=1, do_classifier_free_guidance=True,
262 | negative_prompt=negative_prompt)
263 | num_prompts = prompt_embeds.shape[0]
264 |
265 | if pil_ip_image is not None:
266 | ip_prompt_embeds, uncond_ip_prompt_embeds = self.get_ip_embeds(pil_ip_image)
267 | ip_prompt_embeds = ip_prompt_embeds.repeat(num_prompts, 1, 1)
268 | uncond_ip_prompt_embeds = uncond_ip_prompt_embeds.repeat(num_prompts, 1, 1)
269 | else:
270 | ip_prompt_embeds = uncond_ip_prompt_embeds = \
271 | torch.zeros_like(prompt_embeds[:, :1]).repeat(1, self.num_ip_tokens, 1)
272 |
273 | prompt_embeds = torch.cat([prompt_embeds, ip_prompt_embeds], dim=1)
274 | negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_ip_prompt_embeds], dim=1)
275 |
276 | if cond_faceids and len(cond_faceids) > 0:
277 | all_faceid_prompt_embeds = []
278 | all_uncond_faceid_prompt_embeds = []
279 | for curr_faceid_info in cond_faceids:
280 | refs = curr_faceid_info["refs"]
281 | faceid_prompt_embeds, uncond_faceid_prompt_embeds = \
282 | self.get_single_faceid_embeds(refs, face_structure_scale)
283 | if "mix_refs" in curr_faceid_info:
284 | mix_refs = curr_faceid_info["mix_refs"]
285 | mix_scales = curr_faceid_info["mix_scales"]
286 |
287 | master_face_mix_scale = 1.0 - sum(mix_scales)
288 | faceid_prompt_embeds = faceid_prompt_embeds * master_face_mix_scale
289 | for mix_ref, mix_scale in zip(mix_refs, mix_scales):
290 | faceid_mix_prompt_embeds, _ = self.get_single_faceid_embeds(mix_ref, face_structure_scale)
291 | faceid_prompt_embeds = faceid_prompt_embeds + faceid_mix_prompt_embeds * mix_scale
292 |
293 | all_faceid_prompt_embeds.append(faceid_prompt_embeds)
294 | all_uncond_faceid_prompt_embeds.append(uncond_faceid_prompt_embeds)
295 |
296 | faceid_prompt_embeds = torch.cat(all_faceid_prompt_embeds, dim=1)
297 | uncond_faceid_prompt_embeds = torch.cat(all_uncond_faceid_prompt_embeds, dim=1)
298 | faceid_prompt_embeds = faceid_prompt_embeds.repeat(num_prompts, 1, 1)
299 | uncond_faceid_prompt_embeds = uncond_faceid_prompt_embeds.repeat(num_prompts, 1, 1)
300 |
301 | prompt_embeds = torch.cat([prompt_embeds, faceid_prompt_embeds], dim=1)
302 | negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_faceid_prompt_embeds], dim=1)
303 |
304 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
305 | if zT is not None:
306 | h_, w_ = kwargs["image"][0].shape[-2:]
307 | latents = torch.randn(num_prompts, 4, h_ // 8, w_ // 8, device=self.device, generator=generator,
308 | dtype=self.pipe.unet.dtype)
309 | latents[0] = zT
310 | else:
311 | latents = None
312 |
313 | images = self.pipe(
314 | prompt_embeds=prompt_embeds,
315 | negative_prompt_embeds=negative_prompt_embeds,
316 | guidance_scale=guidance_scale,
317 | num_inference_steps=num_inference_steps,
318 | generator=generator,
319 | latents=latents,
320 | **kwargs,
321 | ).images
322 |
323 | return images
324 |
--------------------------------------------------------------------------------
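get_single_faceid_embeds() combines three feature streams per reference image: the intrinsic ID map from the recognition backbone (512x7x7, flattened to 49 tokens), the backbone's intermediate feature maps resized to 16x16 and stacked (64 + 128 + 256 = 448 channels over 256 tokens), and the penultimate CLIP hidden states (256 tokens x 1280), giving 256 structure tokens of width 448 + 1280 = 1728. Below is a toy shape check of that bookkeeping, using random tensors in place of the real model outputs; the shapes are taken from the comments in the method above.

```python
import torch
import torch.nn.functional as F

b = 2  # two reference images of the same identity (illustrative)

# Stand-ins for the recognition backbone and CLIP outputs (shapes from the code comments).
intrinsic_id = torch.randn(b, 512, 7, 7)
mid_feats = [torch.randn(b, 64, 56, 56), torch.randn(b, 128, 28, 28), torch.randn(b, 256, 14, 14)]
clip_patch_tokens = torch.randn(b, 256, 1280)  # penultimate CLIP hidden states without the CLS token

id_tokens = intrinsic_id.flatten(2).permute(0, 2, 1)  # (b, 49, 512)
structure = torch.cat(
    [F.interpolate(f, size=(16, 16), mode="bilinear", align_corners=False) for f in mid_feats],
    dim=1,
).flatten(2).permute(0, 2, 1)  # (b, 256, 448)
structure_tokens = torch.cat([structure, clip_patch_tokens], dim=-1)  # (b, 256, 1728)

assert id_tokens.shape == (b, 49, 512)
assert structure_tokens.shape == (b, 256, 1728)
```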
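Inside generate(), the conditioning tensors grow along the token dimension: the encoded text prompt is followed by the IP (style) tokens, which are zeros when no pil_ip_image is given, and then by one 16-token block per entry in cond_faceids; negative_prompt_embeds is extended with the corresponding unconditional blocks in the same order. The sketch below only illustrates the resulting shape of that concatenation; the token counts (77 text tokens, 4 IP tokens) are assumptions for the example, not values read from this repository.

```python
import torch

num_text_tokens = 77     # CLIP text encoder length (assumption)
num_ip_tokens = 4        # stand-in for self.num_ip_tokens (assumption)
num_faceid_tokens = 16   # per-identity tokens produced by the face ID projector
num_ids = 2              # number of entries in cond_faceids

prompt_embeds = torch.randn(1, num_text_tokens, 768)
ip_tokens = torch.zeros(1, num_ip_tokens, 768)  # zero IP tokens when pil_ip_image is None
faceid_tokens = torch.randn(1, num_faceid_tokens * num_ids, 768)

conditioning = torch.cat([prompt_embeds, ip_tokens, faceid_tokens], dim=1)
assert conditioning.shape == (1, num_text_tokens + num_ip_tokens + num_faceid_tokens * num_ids, 768)
```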
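For orientation, here is a minimal, hypothetical usage sketch of generate() as documented above. It assumes `pipeline` is an already constructed instance of the pipeline class defined in uniportrait_pipeline.py with all checkpoints loaded, and that the image files exist; depending on how the underlying diffusers pipeline was built (e.g. with ControlNet), additional keyword arguments such as a conditioning image may also be required.

```python
from PIL import Image

# Hypothetical inputs: file names are placeholders, not assets from this repository.
id_a = Image.open("person_a.jpg")  # reference face for the first identity
id_b = Image.open("person_b.jpg")  # reference face for the second identity
id_c = Image.open("person_c.jpg")  # identity blended into the second face

cond_faceids = [
    {"refs": [id_a]},              # first face: a single identity
    {                              # second face: 70% B + 30% C
        "refs": [id_b],
        "mix_refs": [id_c],
        "mix_scales": [0.3],       # "refs" keeps 1 - 0.3 = 0.7 of its own embedding
    },
]

images = pipeline.generate(
    prompt="two people hiking in the mountains, photorealistic",
    negative_prompt="blurry, low quality, deformed",
    cond_faceids=cond_faceids,
    face_structure_scale=0.3,      # > 0 also injects facial-structure features
    seed=42,
    guidance_scale=7.5,
    num_inference_steps=30,
)
images[0].save("output.png")
```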