├── .github └── workflows │ └── publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── environment.yml ├── example_workflows ├── hellomeme_image_workflow.jpg ├── hellomeme_video_workflow.jpg ├── image_generation.jpg ├── image_generation.json ├── image_style_transfer.jpg ├── image_style_transfer.json ├── video_generation.jpg └── video_generation.json ├── examples ├── amns.mp4 ├── chillout.jpg ├── civitai2.jpg ├── helloicon.png ├── i5.jpg ├── jgz.mp4 ├── majicmix2.jpg ├── qie.mp4 ├── tiktok.mp4 ├── toon.png └── yao.jpg ├── hellomeme ├── __init__.py ├── model_config.json ├── models │ ├── __init__.py │ ├── hm3_denoising_3d.py │ ├── hm3_denoising_motion.py │ ├── hm_adapters.py │ ├── hm_blocks.py │ ├── hm_control.py │ ├── hm_denoising_3d.py │ └── hm_denoising_motion.py ├── pipelines │ ├── __init__.py │ ├── pipline_hm3_image.py │ ├── pipline_hm3_video.py │ ├── pipline_hm5_image.py │ ├── pipline_hm5_video.py │ ├── pipline_hm_image.py │ └── pipline_hm_video.py ├── tools │ ├── __init__.py │ ├── hello_3dmm.py │ ├── hello_arkit.py │ ├── hello_camera_demo.py │ ├── hello_face_alignment.py │ ├── hello_face_det.py │ ├── pdf.py │ ├── sr.py │ └── utils.py └── utils.py ├── meme.py └── pyproject.toml /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | paths: 9 | - "pyproject.toml" 10 | 11 | permissions: 12 | issues: write 13 | 14 | jobs: 15 | publish-node: 16 | name: Publish Custom Node to registry 17 | runs-on: ubuntu-latest 18 | if: ${{ github.repository_owner == 'HelloVision' }} 19 | steps: 20 | - name: Check out code 21 | uses: actions/checkout@v4 22 | - name: Publish Custom Node 23 | uses: Comfy-Org/publish-node-action@v1 24 | with: 25 | ## Add your own personal access token to your Github Repository secrets and reference it here. 26 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.pyd 3 | *.pth 4 | *.pkl 5 | *.pyc* 6 | *.pyd* 7 | *_fps15.mp4 8 | *__pycache__/* 9 | .idea/ 10 | data/results 11 | pretrained_models/ 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 HelloVision 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

HelloMeme: Integrating Spatial Knitting Attentions to Embed High-Level and Fidelity-Rich Conditions in Diffusion Models

2 | 3 |
4 | Shengkai Zhang, Nianhong Jiao, Tian Li, Chaojie Yang, Chenhui Xue*, Boya Niu*, Jun Gao 5 |
6 | 7 |
8 | HelloVision | HelloGroup Inc. 9 |
10 | 11 |
12 | * Intern 13 |
14 | 15 |
16 |
17 | 18 | 19 | 20 | 21 | 22 |
23 | 
24 | 
25 | ## 🔆 New Features/Updates
26 | - ✅ `02/09/2025` **HelloMemeV3** is now available.
27 | [YouTube Demo](https://www.youtube.com/watch?v=DAUA0EYjsZA)
28 | 
29 | - ✅ `12/17/2024` Added ModelScope support ([ModelScope Demo](https://www.modelscope.cn/studios/songkey/HelloMeme)).
30 | - ✅ `12/08/2024` Added **HelloMemeV2** (select "v2" in the version option of the LoadHelloMemeImage/Video Node). Its features include:
31 | a. Improved expression consistency between the generated video and the driving video.
32 | b. Better compatibility with third-party checkpoints.
33 | c. Reduced VRAM usage.
34 | [YouTube Demo](https://www.youtube.com/watch?v=-2s_pLAKoRg)
35 | 
36 | - ✅ `11/29/2024` a. Optimized the algorithm; b. Added VAE selection functionality; c. Introduced a super-resolution feature.
37 | [YouTube Demo](https://www.youtube.com/watch?v=fM5nyn6q02Y)
38 | 
39 | - ✅ `11/14/2024` Added the `HMControlNet2` module, which uses the `PD-FGC` motion module to extract facial expression information (`drive_exp2`); restructured the ComfyUI interface; and optimized VRAM usage.
40 | [YouTube Demo](https://www.youtube.com/watch?v=ZvoMHyRm310)
41 | 
42 | - ✅ `11/12/2024` Added a newly fine-tuned version of [`Animatediff`](https://huggingface.co/songkey/hm_animatediff_frame12) with a patch size of 12, which uses less VRAM (tested on a 2080 Ti).
43 | - ✅ `11/11/2024` ~~Optimized VRAM usage and added `HMVideoSimplePipeline` (`workflows/hellomeme_video_simple_workflow.json`), which doesn’t use Animatediff and can run on machines with less than 12 GB of VRAM.~~
44 | - ✅ `11/6/2024` The face proportion in the reference image significantly affects the generation quality. We have encapsulated the **recommended image cropping method** used during training into a `CropReferenceImage` Node. Refer to the workflows in the `ComfyUI_HelloMeme/workflows` directory: `hellomeme_video_cropref_workflow.json` and `hellomeme_image_cropref_workflow.json`.
45 | 
46 | 
47 | ## Introduction
48 | 
49 | This repository is the official implementation of the [`HelloMeme`](https://arxiv.org/pdf/2410.22901) ComfyUI interface, featuring both image and video generation. Example workflow files can be found in the `ComfyUI_HelloMeme/workflows` directory. Test images and videos are saved in the `ComfyUI_HelloMeme/examples` directory. Below are screenshots of the interfaces for image and video generation.
50 | 
51 | > [!NOTE]
52 | > [Custom models should be placed in the directories listed below.](https://github.com/HelloVision/ComfyUI_HelloMeme/issues/5#issuecomment-2461247829)
53 | > 
54 | > **Checkpoints** under: `ComfyUI/models/checkpoints`
55 | > 
56 | > **LoRAs** under: `ComfyUI/models/loras`
57 | 
58 | 
59 | ### Workflows
60 | 
61 | | Workflow file | Video Generation | Image Generation | HMControlNet | HMControlNet2 |
62 | |---------------|------------------|------------------|--------------|---------------|
63 | | image_generation.json | | ✅ | | ✅ |
64 | | image_style_transfer.json | | ✅ | | ✅ |
65 | | video_generation.json | ✅ | | | ✅ |
66 | 
67 | ### Image Generation Interface
68 | 
69 |

70 | image_generation_interface 71 |

72 | 73 | ### Video Generation Interface 74 | 75 |

76 | video_generation_interface 77 |
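The table in the Workflows section above lists which pipelines each bundled JSON file exercises. For a quick look at how those graphs are wired without opening ComfyUI, the short script below parses one of the files shipped in `example_workflows/` and prints its nodes and typed connections. This is a minimal sketch rather than part of the node pack: it assumes it is run from the repository root, relies only on the standard ComfyUI export layout (`nodes` / `links`) that these files use, and the `WORKFLOW` constant is just an example path that can point at any of the three workflows.

```python
# inspect_workflow.py -- sketch: summarize a bundled ComfyUI workflow graph.
# Assumes it runs from the repository root and that the JSON keeps the standard
# ComfyUI export layout ("nodes" / "links") used by the files in example_workflows/.
import json
from pathlib import Path

WORKFLOW = Path("example_workflows/image_generation.json")  # any bundled workflow

def summarize(path: Path) -> None:
    graph = json.loads(path.read_text(encoding="utf-8"))

    # Map node id -> node type so the connection listing is readable.
    node_types = {node["id"]: node["type"] for node in graph["nodes"]}

    print(f"{path.name}: {len(graph['nodes'])} nodes, {len(graph['links'])} links")
    for node in graph["nodes"]:
        print(f"  [{node['id']:>3}] {node['type']}")

    # Each entry in "links" is [link_id, src_node, src_slot, dst_node, dst_slot, data_type].
    print("connections:")
    for _link_id, src, src_slot, dst, dst_slot, dtype in graph["links"]:
        print(f"  {node_types[src]}#{src}[{src_slot}] -> {node_types[dst]}#{dst}[{dst_slot}] ({dtype})")

if __name__ == "__main__":
    summarize(WORKFLOW)
```

Run against `image_generation.json`, this should list the `HMFaceToolkitsLoader`, `GetFaceLandmarks`, `GetHeadPose`, `GetExpression2`, `HMImagePipelineLoader`, and `HMPipelineImage` nodes visible in the Image Generation Interface screenshot above, along with the `FACE_TOOLKITS`, `FACELANDMARKS222`, `HEAD_POSE`, and `EXPRESSION` connections between them.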

78 | 79 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .meme import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 2 | 3 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HelloVision/ComfyUI_HelloMeme/ce368e562f530b9efaf623b72e619b50b82a1ead/environment.yml -------------------------------------------------------------------------------- /example_workflows/hellomeme_image_workflow.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HelloVision/ComfyUI_HelloMeme/ce368e562f530b9efaf623b72e619b50b82a1ead/example_workflows/hellomeme_image_workflow.jpg -------------------------------------------------------------------------------- /example_workflows/hellomeme_video_workflow.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HelloVision/ComfyUI_HelloMeme/ce368e562f530b9efaf623b72e619b50b82a1ead/example_workflows/hellomeme_video_workflow.jpg -------------------------------------------------------------------------------- /example_workflows/image_generation.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HelloVision/ComfyUI_HelloMeme/ce368e562f530b9efaf623b72e619b50b82a1ead/example_workflows/image_generation.jpg -------------------------------------------------------------------------------- /example_workflows/image_generation.json: -------------------------------------------------------------------------------- 1 | { 2 | "id": "a8e8ebf3-8dc8-4d76-ad73-dfe0b945cc77", 3 | "revision": 0, 4 | "last_node_id": 80, 5 | "last_link_id": 187, 6 | "nodes": [ 7 | { 8 | "id": 2, 9 | "type": "HMFaceToolkitsLoader", 10 | "pos": [ 11 | 45.41996383666992, 12 | 769.9866333007812 13 | ], 14 | "size": [ 15 | 315, 16 | 82 17 | ], 18 | "flags": {}, 19 | "order": 0, 20 | "mode": 0, 21 | "inputs": [], 22 | "outputs": [ 23 | { 24 | "name": "face_toolkits", 25 | "type": "FACE_TOOLKITS", 26 | "slot_index": 0, 27 | "links": [ 28 | 122, 29 | 129, 30 | 147, 31 | 149, 32 | 173, 33 | 174 34 | ] 35 | } 36 | ], 37 | "properties": { 38 | "Node name for S&R": "HMFaceToolkitsLoader" 39 | }, 40 | "widgets_values": [ 41 | 0, 42 | "huggingface" 43 | ] 44 | }, 45 | { 46 | "id": 61, 47 | "type": "GetFaceLandmarks", 48 | "pos": [ 49 | 519.7196655273438, 50 | 938.032958984375 51 | ], 52 | "size": [ 53 | 170.62362670898438, 54 | 46 55 | ], 56 | "flags": {}, 57 | "order": 4, 58 | "mode": 0, 59 | "inputs": [ 60 | { 61 | "name": "face_toolkits", 62 | "type": "FACE_TOOLKITS", 63 | "link": 129 64 | }, 65 | { 66 | "name": "images", 67 | "type": "IMAGE", 68 | "link": 179 69 | } 70 | ], 71 | "outputs": [ 72 | { 73 | "name": "landmarks", 74 | "type": "FACELANDMARKS222", 75 | "links": [ 76 | 145, 77 | 177 78 | ] 79 | } 80 | ], 81 | "properties": { 82 | "Node name for S&R": "GetFaceLandmarks" 83 | }, 84 | "widgets_values": [] 85 | }, 86 | { 87 | "id": 60, 88 | "type": "GetFaceLandmarks", 89 | "pos": [ 90 | 518.9132690429688, 91 | 612.9461669921875 92 | ], 93 | "size": [ 94 | 170.62362670898438, 95 | 46 96 | ], 97 | "flags": {}, 98 | "order": 5, 99 | "mode": 0, 100 | "inputs": [ 101 
| { 102 | "name": "face_toolkits", 103 | "type": "FACE_TOOLKITS", 104 | "link": 122 105 | }, 106 | { 107 | "name": "images", 108 | "type": "IMAGE", 109 | "link": 121 110 | } 111 | ], 112 | "outputs": [ 113 | { 114 | "name": "landmarks", 115 | "type": "FACELANDMARKS222", 116 | "links": [ 117 | 143, 118 | 178 119 | ] 120 | } 121 | ], 122 | "properties": { 123 | "Node name for S&R": "GetFaceLandmarks" 124 | }, 125 | "widgets_values": [] 126 | }, 127 | { 128 | "id": 77, 129 | "type": "LoadImage", 130 | "pos": [ 131 | 61.306365966796875, 132 | 1023.740478515625 133 | ], 134 | "size": [ 135 | 270, 136 | 314 137 | ], 138 | "flags": {}, 139 | "order": 1, 140 | "mode": 0, 141 | "inputs": [], 142 | "outputs": [ 143 | { 144 | "name": "IMAGE", 145 | "type": "IMAGE", 146 | "links": [ 147 | 179, 148 | 180, 149 | 181 150 | ] 151 | }, 152 | { 153 | "name": "MASK", 154 | "type": "MASK", 155 | "links": null 156 | } 157 | ], 158 | "properties": { 159 | "Node name for S&R": "LoadImage" 160 | }, 161 | "widgets_values": [ 162 | "yao.jpg", 163 | "image" 164 | ] 165 | }, 166 | { 167 | "id": 67, 168 | "type": "GetHeadPose", 169 | "pos": [ 170 | 814.1528930664062, 171 | 495.9793701171875 172 | ], 173 | "size": [ 174 | 270, 175 | 98 176 | ], 177 | "flags": {}, 178 | "order": 8, 179 | "mode": 0, 180 | "inputs": [ 181 | { 182 | "name": "face_toolkits", 183 | "type": "FACE_TOOLKITS", 184 | "link": 147 185 | }, 186 | { 187 | "name": "images", 188 | "type": "IMAGE", 189 | "link": 151 190 | }, 191 | { 192 | "name": "landmarks", 193 | "type": "FACELANDMARKS222", 194 | "link": 143 195 | } 196 | ], 197 | "outputs": [ 198 | { 199 | "name": "head_pose", 200 | "type": "HEAD_POSE", 201 | "links": [ 202 | 183 203 | ] 204 | } 205 | ], 206 | "properties": { 207 | "Node name for S&R": "GetHeadPose" 208 | }, 209 | "widgets_values": [ 210 | true 211 | ] 212 | }, 213 | { 214 | "id": 75, 215 | "type": "GetExpression2", 216 | "pos": [ 217 | 814.1359252929688, 218 | 689.6795654296875 219 | ], 220 | "size": [ 221 | 191.63729858398438, 222 | 66 223 | ], 224 | "flags": {}, 225 | "order": 9, 226 | "mode": 0, 227 | "inputs": [ 228 | { 229 | "name": "face_toolkits", 230 | "type": "FACE_TOOLKITS", 231 | "link": 173 232 | }, 233 | { 234 | "name": "images", 235 | "type": "IMAGE", 236 | "link": 175 237 | }, 238 | { 239 | "name": "landmarks", 240 | "type": "FACELANDMARKS222", 241 | "link": 178 242 | } 243 | ], 244 | "outputs": [ 245 | { 246 | "name": "expression", 247 | "type": "EXPRESSION", 248 | "links": [ 249 | 184 250 | ] 251 | } 252 | ], 253 | "properties": { 254 | "Node name for S&R": "GetExpression2" 255 | }, 256 | "widgets_values": [] 257 | }, 258 | { 259 | "id": 69, 260 | "type": "GetHeadPose", 261 | "pos": [ 262 | 819.799560546875, 263 | 860.5927734375 264 | ], 265 | "size": [ 266 | 270, 267 | 98 268 | ], 269 | "flags": {}, 270 | "order": 6, 271 | "mode": 0, 272 | "inputs": [ 273 | { 274 | "name": "face_toolkits", 275 | "type": "FACE_TOOLKITS", 276 | "link": 149 277 | }, 278 | { 279 | "name": "images", 280 | "type": "IMAGE", 281 | "link": 180 282 | }, 283 | { 284 | "name": "landmarks", 285 | "type": "FACELANDMARKS222", 286 | "link": 145 287 | } 288 | ], 289 | "outputs": [ 290 | { 291 | "name": "head_pose", 292 | "type": "HEAD_POSE", 293 | "links": [ 294 | 185 295 | ] 296 | } 297 | ], 298 | "properties": { 299 | "Node name for S&R": "GetHeadPose" 300 | }, 301 | "widgets_values": [ 302 | true 303 | ] 304 | }, 305 | { 306 | "id": 76, 307 | "type": "GetExpression2", 308 | "pos": [ 309 | 833.6573486328125, 310 | 1042.0396728515625 311 | ], 312 
| "size": [ 313 | 191.63729858398438, 314 | 66 315 | ], 316 | "flags": {}, 317 | "order": 7, 318 | "mode": 0, 319 | "inputs": [ 320 | { 321 | "name": "face_toolkits", 322 | "type": "FACE_TOOLKITS", 323 | "link": 174 324 | }, 325 | { 326 | "name": "images", 327 | "type": "IMAGE", 328 | "link": 181 329 | }, 330 | { 331 | "name": "landmarks", 332 | "type": "FACELANDMARKS222", 333 | "link": 177 334 | } 335 | ], 336 | "outputs": [ 337 | { 338 | "name": "expression", 339 | "type": "EXPRESSION", 340 | "links": [ 341 | 186 342 | ] 343 | } 344 | ], 345 | "properties": { 346 | "Node name for S&R": "GetExpression2" 347 | }, 348 | "widgets_values": [] 349 | }, 350 | { 351 | "id": 79, 352 | "type": "HMPipelineImage", 353 | "pos": [ 354 | 1334.226318359375, 355 | 613.9535522460938 356 | ], 357 | "size": [ 358 | 270, 359 | 306 360 | ], 361 | "flags": {}, 362 | "order": 10, 363 | "mode": 0, 364 | "inputs": [ 365 | { 366 | "name": "hm_image_pipeline", 367 | "type": "HMIMAGEPIPELINE", 368 | "link": 182 369 | }, 370 | { 371 | "name": "ref_head_pose", 372 | "type": "HEAD_POSE", 373 | "link": 183 374 | }, 375 | { 376 | "name": "ref_expression", 377 | "type": "EXPRESSION", 378 | "link": 184 379 | }, 380 | { 381 | "name": "drive_head_pose", 382 | "type": "HEAD_POSE", 383 | "link": 185 384 | }, 385 | { 386 | "name": "drive_expression", 387 | "type": "EXPRESSION", 388 | "link": 186 389 | } 390 | ], 391 | "outputs": [ 392 | { 393 | "name": "IMAGE", 394 | "type": "IMAGE", 395 | "links": [ 396 | 187 397 | ] 398 | }, 399 | { 400 | "name": "LATENT", 401 | "type": "LATENT", 402 | "links": null 403 | } 404 | ], 405 | "properties": { 406 | "Node name for S&R": "HMPipelineImage" 407 | }, 408 | "widgets_values": [ 409 | 0, 410 | "", 411 | "", 412 | 25, 413 | 815482659530374, 414 | "randomize", 415 | 1.5, 416 | 0 417 | ] 418 | }, 419 | { 420 | "id": 58, 421 | "type": "LoadImage", 422 | "pos": [ 423 | 37.4466552734375, 424 | 260.9000244140625 425 | ], 426 | "size": [ 427 | 315, 428 | 314 429 | ], 430 | "flags": {}, 431 | "order": 2, 432 | "mode": 0, 433 | "inputs": [], 434 | "outputs": [ 435 | { 436 | "name": "IMAGE", 437 | "type": "IMAGE", 438 | "slot_index": 0, 439 | "links": [ 440 | 121, 441 | 151, 442 | 175 443 | ] 444 | }, 445 | { 446 | "name": "MASK", 447 | "type": "MASK", 448 | "links": null 449 | } 450 | ], 451 | "properties": { 452 | "Node name for S&R": "LoadImage" 453 | }, 454 | "widgets_values": [ 455 | "chillout.jpg", 456 | "image" 457 | ] 458 | }, 459 | { 460 | "id": 80, 461 | "type": "PreviewImage", 462 | "pos": [ 463 | 1703.6793212890625, 464 | 613.9529418945312 465 | ], 466 | "size": [ 467 | 210, 468 | 258 469 | ], 470 | "flags": {}, 471 | "order": 11, 472 | "mode": 0, 473 | "inputs": [ 474 | { 475 | "name": "images", 476 | "type": "IMAGE", 477 | "link": 187 478 | } 479 | ], 480 | "outputs": [], 481 | "properties": { 482 | "Node name for S&R": "PreviewImage" 483 | }, 484 | "widgets_values": [] 485 | }, 486 | { 487 | "id": 78, 488 | "type": "HMImagePipelineLoader", 489 | "pos": [ 490 | 797.7930297851562, 491 | 188.84007263183594 492 | ], 493 | "size": [ 494 | 298.3472595214844, 495 | 226 496 | ], 497 | "flags": {}, 498 | "order": 3, 499 | "mode": 0, 500 | "inputs": [], 501 | "outputs": [ 502 | { 503 | "name": "hm_image_pipeline", 504 | "type": "HMIMAGEPIPELINE", 505 | "links": [ 506 | 182 507 | ] 508 | } 509 | ], 510 | "properties": { 511 | "Node name for S&R": "HMImagePipelineLoader" 512 | }, 513 | "widgets_values": [ 514 | "[preset]DisneyPixarCartoonB", 515 | "None", 516 | "same as checkpoint", 517 | "v5", 
518 | "x1", 519 | "huggingface", 520 | 1, 521 | "fp32" 522 | ] 523 | } 524 | ], 525 | "links": [ 526 | [ 527 | 121, 528 | 58, 529 | 0, 530 | 60, 531 | 1, 532 | "IMAGE" 533 | ], 534 | [ 535 | 122, 536 | 2, 537 | 0, 538 | 60, 539 | 0, 540 | "FACE_TOOLKITS" 541 | ], 542 | [ 543 | 129, 544 | 2, 545 | 0, 546 | 61, 547 | 0, 548 | "FACE_TOOLKITS" 549 | ], 550 | [ 551 | 143, 552 | 60, 553 | 0, 554 | 67, 555 | 2, 556 | "FACELANDMARKS222" 557 | ], 558 | [ 559 | 145, 560 | 61, 561 | 0, 562 | 69, 563 | 2, 564 | "FACELANDMARKS222" 565 | ], 566 | [ 567 | 147, 568 | 2, 569 | 0, 570 | 67, 571 | 0, 572 | "FACE_TOOLKITS" 573 | ], 574 | [ 575 | 149, 576 | 2, 577 | 0, 578 | 69, 579 | 0, 580 | "FACE_TOOLKITS" 581 | ], 582 | [ 583 | 151, 584 | 58, 585 | 0, 586 | 67, 587 | 1, 588 | "IMAGE" 589 | ], 590 | [ 591 | 173, 592 | 2, 593 | 0, 594 | 75, 595 | 0, 596 | "FACE_TOOLKITS" 597 | ], 598 | [ 599 | 174, 600 | 2, 601 | 0, 602 | 76, 603 | 0, 604 | "FACE_TOOLKITS" 605 | ], 606 | [ 607 | 175, 608 | 58, 609 | 0, 610 | 75, 611 | 1, 612 | "IMAGE" 613 | ], 614 | [ 615 | 177, 616 | 61, 617 | 0, 618 | 76, 619 | 2, 620 | "FACELANDMARKS222" 621 | ], 622 | [ 623 | 178, 624 | 60, 625 | 0, 626 | 75, 627 | 2, 628 | "FACELANDMARKS222" 629 | ], 630 | [ 631 | 179, 632 | 77, 633 | 0, 634 | 61, 635 | 1, 636 | "IMAGE" 637 | ], 638 | [ 639 | 180, 640 | 77, 641 | 0, 642 | 69, 643 | 1, 644 | "IMAGE" 645 | ], 646 | [ 647 | 181, 648 | 77, 649 | 0, 650 | 76, 651 | 1, 652 | "IMAGE" 653 | ], 654 | [ 655 | 182, 656 | 78, 657 | 0, 658 | 79, 659 | 0, 660 | "HMIMAGEPIPELINE" 661 | ], 662 | [ 663 | 183, 664 | 67, 665 | 0, 666 | 79, 667 | 1, 668 | "HEAD_POSE" 669 | ], 670 | [ 671 | 184, 672 | 75, 673 | 0, 674 | 79, 675 | 2, 676 | "EXPRESSION" 677 | ], 678 | [ 679 | 185, 680 | 69, 681 | 0, 682 | 79, 683 | 3, 684 | "HEAD_POSE" 685 | ], 686 | [ 687 | 186, 688 | 76, 689 | 0, 690 | 79, 691 | 4, 692 | "EXPRESSION" 693 | ], 694 | [ 695 | 187, 696 | 79, 697 | 0, 698 | 80, 699 | 0, 700 | "IMAGE" 701 | ] 702 | ], 703 | "groups": [], 704 | "config": {}, 705 | "extra": { 706 | "ds": { 707 | "scale": 0.8264462809917358, 708 | "offset": [ 709 | 493.27712102477454, 710 | -69.45337656976547 711 | ] 712 | }, 713 | "frontendVersion": "1.19.2" 714 | }, 715 | "version": 0.4 716 | } -------------------------------------------------------------------------------- /example_workflows/image_style_transfer.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HelloVision/ComfyUI_HelloMeme/ce368e562f530b9efaf623b72e619b50b82a1ead/example_workflows/image_style_transfer.jpg -------------------------------------------------------------------------------- /example_workflows/image_style_transfer.json: -------------------------------------------------------------------------------- 1 | { 2 | "id": "a8e8ebf3-8dc8-4d76-ad73-dfe0b945cc77", 3 | "revision": 0, 4 | "last_node_id": 80, 5 | "last_link_id": 190, 6 | "nodes": [ 7 | { 8 | "id": 2, 9 | "type": "HMFaceToolkitsLoader", 10 | "pos": [ 11 | 45.41996383666992, 12 | 769.9866333007812 13 | ], 14 | "size": [ 15 | 315, 16 | 82 17 | ], 18 | "flags": {}, 19 | "order": 0, 20 | "mode": 0, 21 | "inputs": [], 22 | "outputs": [ 23 | { 24 | "name": "face_toolkits", 25 | "type": "FACE_TOOLKITS", 26 | "slot_index": 0, 27 | "links": [ 28 | 122, 29 | 129, 30 | 147, 31 | 149, 32 | 173, 33 | 174 34 | ] 35 | } 36 | ], 37 | "properties": { 38 | "Node name for S&R": "HMFaceToolkitsLoader" 39 | }, 40 | "widgets_values": [ 41 | 0, 42 | "huggingface" 43 | ] 44 | }, 45 | { 46 | "id": 61, 47 
| "type": "GetFaceLandmarks", 48 | "pos": [ 49 | 519.7196655273438, 50 | 938.032958984375 51 | ], 52 | "size": [ 53 | 170.62362670898438, 54 | 46 55 | ], 56 | "flags": {}, 57 | "order": 4, 58 | "mode": 0, 59 | "inputs": [ 60 | { 61 | "name": "face_toolkits", 62 | "type": "FACE_TOOLKITS", 63 | "link": 129 64 | }, 65 | { 66 | "name": "images", 67 | "type": "IMAGE", 68 | "link": 188 69 | } 70 | ], 71 | "outputs": [ 72 | { 73 | "name": "landmarks", 74 | "type": "FACELANDMARKS222", 75 | "links": [ 76 | 145, 77 | 177 78 | ] 79 | } 80 | ], 81 | "properties": { 82 | "Node name for S&R": "GetFaceLandmarks" 83 | }, 84 | "widgets_values": [] 85 | }, 86 | { 87 | "id": 60, 88 | "type": "GetFaceLandmarks", 89 | "pos": [ 90 | 518.9132690429688, 91 | 612.9461669921875 92 | ], 93 | "size": [ 94 | 170.62362670898438, 95 | 46 96 | ], 97 | "flags": {}, 98 | "order": 3, 99 | "mode": 0, 100 | "inputs": [ 101 | { 102 | "name": "face_toolkits", 103 | "type": "FACE_TOOLKITS", 104 | "link": 122 105 | }, 106 | { 107 | "name": "images", 108 | "type": "IMAGE", 109 | "link": 121 110 | } 111 | ], 112 | "outputs": [ 113 | { 114 | "name": "landmarks", 115 | "type": "FACELANDMARKS222", 116 | "links": [ 117 | 143, 118 | 178 119 | ] 120 | } 121 | ], 122 | "properties": { 123 | "Node name for S&R": "GetFaceLandmarks" 124 | }, 125 | "widgets_values": [] 126 | }, 127 | { 128 | "id": 67, 129 | "type": "GetHeadPose", 130 | "pos": [ 131 | 814.1528930664062, 132 | 495.9793701171875 133 | ], 134 | "size": [ 135 | 270, 136 | 98 137 | ], 138 | "flags": {}, 139 | "order": 5, 140 | "mode": 0, 141 | "inputs": [ 142 | { 143 | "name": "face_toolkits", 144 | "type": "FACE_TOOLKITS", 145 | "link": 147 146 | }, 147 | { 148 | "name": "images", 149 | "type": "IMAGE", 150 | "link": 151 151 | }, 152 | { 153 | "name": "landmarks", 154 | "type": "FACELANDMARKS222", 155 | "link": 143 156 | } 157 | ], 158 | "outputs": [ 159 | { 160 | "name": "head_pose", 161 | "type": "HEAD_POSE", 162 | "links": [ 163 | 183 164 | ] 165 | } 166 | ], 167 | "properties": { 168 | "Node name for S&R": "GetHeadPose" 169 | }, 170 | "widgets_values": [ 171 | true 172 | ] 173 | }, 174 | { 175 | "id": 75, 176 | "type": "GetExpression2", 177 | "pos": [ 178 | 814.1359252929688, 179 | 689.6795654296875 180 | ], 181 | "size": [ 182 | 191.63729858398438, 183 | 66 184 | ], 185 | "flags": {}, 186 | "order": 6, 187 | "mode": 0, 188 | "inputs": [ 189 | { 190 | "name": "face_toolkits", 191 | "type": "FACE_TOOLKITS", 192 | "link": 173 193 | }, 194 | { 195 | "name": "images", 196 | "type": "IMAGE", 197 | "link": 175 198 | }, 199 | { 200 | "name": "landmarks", 201 | "type": "FACELANDMARKS222", 202 | "link": 178 203 | } 204 | ], 205 | "outputs": [ 206 | { 207 | "name": "expression", 208 | "type": "EXPRESSION", 209 | "links": [ 210 | 184 211 | ] 212 | } 213 | ], 214 | "properties": { 215 | "Node name for S&R": "GetExpression2" 216 | }, 217 | "widgets_values": [] 218 | }, 219 | { 220 | "id": 69, 221 | "type": "GetHeadPose", 222 | "pos": [ 223 | 819.799560546875, 224 | 860.5927734375 225 | ], 226 | "size": [ 227 | 270, 228 | 98 229 | ], 230 | "flags": {}, 231 | "order": 7, 232 | "mode": 0, 233 | "inputs": [ 234 | { 235 | "name": "face_toolkits", 236 | "type": "FACE_TOOLKITS", 237 | "link": 149 238 | }, 239 | { 240 | "name": "images", 241 | "type": "IMAGE", 242 | "link": 189 243 | }, 244 | { 245 | "name": "landmarks", 246 | "type": "FACELANDMARKS222", 247 | "link": 145 248 | } 249 | ], 250 | "outputs": [ 251 | { 252 | "name": "head_pose", 253 | "type": "HEAD_POSE", 254 | "links": [ 255 | 185 
256 | ] 257 | } 258 | ], 259 | "properties": { 260 | "Node name for S&R": "GetHeadPose" 261 | }, 262 | "widgets_values": [ 263 | true 264 | ] 265 | }, 266 | { 267 | "id": 76, 268 | "type": "GetExpression2", 269 | "pos": [ 270 | 833.6573486328125, 271 | 1042.0396728515625 272 | ], 273 | "size": [ 274 | 191.63729858398438, 275 | 66 276 | ], 277 | "flags": {}, 278 | "order": 8, 279 | "mode": 0, 280 | "inputs": [ 281 | { 282 | "name": "face_toolkits", 283 | "type": "FACE_TOOLKITS", 284 | "link": 174 285 | }, 286 | { 287 | "name": "images", 288 | "type": "IMAGE", 289 | "link": 190 290 | }, 291 | { 292 | "name": "landmarks", 293 | "type": "FACELANDMARKS222", 294 | "link": 177 295 | } 296 | ], 297 | "outputs": [ 298 | { 299 | "name": "expression", 300 | "type": "EXPRESSION", 301 | "links": [ 302 | 186 303 | ] 304 | } 305 | ], 306 | "properties": { 307 | "Node name for S&R": "GetExpression2" 308 | }, 309 | "widgets_values": [] 310 | }, 311 | { 312 | "id": 78, 313 | "type": "HMImagePipelineLoader", 314 | "pos": [ 315 | 797.7930297851562, 316 | 188.84007263183594 317 | ], 318 | "size": [ 319 | 298.3472595214844, 320 | 226 321 | ], 322 | "flags": {}, 323 | "order": 1, 324 | "mode": 0, 325 | "inputs": [], 326 | "outputs": [ 327 | { 328 | "name": "hm_image_pipeline", 329 | "type": "HMIMAGEPIPELINE", 330 | "links": [ 331 | 182 332 | ] 333 | } 334 | ], 335 | "properties": { 336 | "Node name for S&R": "HMImagePipelineLoader" 337 | }, 338 | "widgets_values": [ 339 | "[preset]DisneyPixarCartoonB", 340 | "[preset]BabyFaceV1", 341 | "same as checkpoint", 342 | "v5", 343 | "x1", 344 | "huggingface", 345 | 1, 346 | "fp32" 347 | ] 348 | }, 349 | { 350 | "id": 79, 351 | "type": "HMPipelineImage", 352 | "pos": [ 353 | 1334.226318359375, 354 | 613.9535522460938 355 | ], 356 | "size": [ 357 | 270, 358 | 306 359 | ], 360 | "flags": {}, 361 | "order": 9, 362 | "mode": 0, 363 | "inputs": [ 364 | { 365 | "name": "hm_image_pipeline", 366 | "type": "HMIMAGEPIPELINE", 367 | "link": 182 368 | }, 369 | { 370 | "name": "ref_head_pose", 371 | "type": "HEAD_POSE", 372 | "link": 183 373 | }, 374 | { 375 | "name": "ref_expression", 376 | "type": "EXPRESSION", 377 | "link": 184 378 | }, 379 | { 380 | "name": "drive_head_pose", 381 | "type": "HEAD_POSE", 382 | "link": 185 383 | }, 384 | { 385 | "name": "drive_expression", 386 | "type": "EXPRESSION", 387 | "link": 186 388 | } 389 | ], 390 | "outputs": [ 391 | { 392 | "name": "IMAGE", 393 | "type": "IMAGE", 394 | "links": [ 395 | 187 396 | ] 397 | }, 398 | { 399 | "name": "LATENT", 400 | "type": "LATENT", 401 | "links": null 402 | } 403 | ], 404 | "properties": { 405 | "Node name for S&R": "HMPipelineImage" 406 | }, 407 | "widgets_values": [ 408 | 0, 409 | "", 410 | "", 411 | 25, 412 | 485743939425196, 413 | "randomize", 414 | 1.5, 415 | 0 416 | ] 417 | }, 418 | { 419 | "id": 58, 420 | "type": "LoadImage", 421 | "pos": [ 422 | 41.47996520996094, 423 | 358.5067138671875 424 | ], 425 | "size": [ 426 | 315, 427 | 314 428 | ], 429 | "flags": {}, 430 | "order": 2, 431 | "mode": 0, 432 | "inputs": [], 433 | "outputs": [ 434 | { 435 | "name": "IMAGE", 436 | "type": "IMAGE", 437 | "slot_index": 0, 438 | "links": [ 439 | 121, 440 | 151, 441 | 175, 442 | 188, 443 | 189, 444 | 190 445 | ] 446 | }, 447 | { 448 | "name": "MASK", 449 | "type": "MASK", 450 | "links": null 451 | } 452 | ], 453 | "properties": { 454 | "Node name for S&R": "LoadImage" 455 | }, 456 | "widgets_values": [ 457 | "chillout.jpg", 458 | "image" 459 | ] 460 | }, 461 | { 462 | "id": 80, 463 | "type": "PreviewImage", 464 | 
"pos": [ 465 | 1703.6793212890625, 466 | 613.1461791992188 467 | ], 468 | "size": [ 469 | 210, 470 | 258 471 | ], 472 | "flags": {}, 473 | "order": 10, 474 | "mode": 0, 475 | "inputs": [ 476 | { 477 | "name": "images", 478 | "type": "IMAGE", 479 | "link": 187 480 | } 481 | ], 482 | "outputs": [], 483 | "properties": { 484 | "Node name for S&R": "PreviewImage" 485 | }, 486 | "widgets_values": [] 487 | } 488 | ], 489 | "links": [ 490 | [ 491 | 121, 492 | 58, 493 | 0, 494 | 60, 495 | 1, 496 | "IMAGE" 497 | ], 498 | [ 499 | 122, 500 | 2, 501 | 0, 502 | 60, 503 | 0, 504 | "FACE_TOOLKITS" 505 | ], 506 | [ 507 | 129, 508 | 2, 509 | 0, 510 | 61, 511 | 0, 512 | "FACE_TOOLKITS" 513 | ], 514 | [ 515 | 143, 516 | 60, 517 | 0, 518 | 67, 519 | 2, 520 | "FACELANDMARKS222" 521 | ], 522 | [ 523 | 145, 524 | 61, 525 | 0, 526 | 69, 527 | 2, 528 | "FACELANDMARKS222" 529 | ], 530 | [ 531 | 147, 532 | 2, 533 | 0, 534 | 67, 535 | 0, 536 | "FACE_TOOLKITS" 537 | ], 538 | [ 539 | 149, 540 | 2, 541 | 0, 542 | 69, 543 | 0, 544 | "FACE_TOOLKITS" 545 | ], 546 | [ 547 | 151, 548 | 58, 549 | 0, 550 | 67, 551 | 1, 552 | "IMAGE" 553 | ], 554 | [ 555 | 173, 556 | 2, 557 | 0, 558 | 75, 559 | 0, 560 | "FACE_TOOLKITS" 561 | ], 562 | [ 563 | 174, 564 | 2, 565 | 0, 566 | 76, 567 | 0, 568 | "FACE_TOOLKITS" 569 | ], 570 | [ 571 | 175, 572 | 58, 573 | 0, 574 | 75, 575 | 1, 576 | "IMAGE" 577 | ], 578 | [ 579 | 177, 580 | 61, 581 | 0, 582 | 76, 583 | 2, 584 | "FACELANDMARKS222" 585 | ], 586 | [ 587 | 178, 588 | 60, 589 | 0, 590 | 75, 591 | 2, 592 | "FACELANDMARKS222" 593 | ], 594 | [ 595 | 182, 596 | 78, 597 | 0, 598 | 79, 599 | 0, 600 | "HMIMAGEPIPELINE" 601 | ], 602 | [ 603 | 183, 604 | 67, 605 | 0, 606 | 79, 607 | 1, 608 | "HEAD_POSE" 609 | ], 610 | [ 611 | 184, 612 | 75, 613 | 0, 614 | 79, 615 | 2, 616 | "EXPRESSION" 617 | ], 618 | [ 619 | 185, 620 | 69, 621 | 0, 622 | 79, 623 | 3, 624 | "HEAD_POSE" 625 | ], 626 | [ 627 | 186, 628 | 76, 629 | 0, 630 | 79, 631 | 4, 632 | "EXPRESSION" 633 | ], 634 | [ 635 | 187, 636 | 79, 637 | 0, 638 | 80, 639 | 0, 640 | "IMAGE" 641 | ], 642 | [ 643 | 188, 644 | 58, 645 | 0, 646 | 61, 647 | 1, 648 | "IMAGE" 649 | ], 650 | [ 651 | 189, 652 | 58, 653 | 0, 654 | 69, 655 | 1, 656 | "IMAGE" 657 | ], 658 | [ 659 | 190, 660 | 58, 661 | 0, 662 | 76, 663 | 1, 664 | "IMAGE" 665 | ] 666 | ], 667 | "groups": [], 668 | "config": {}, 669 | "extra": { 670 | "ds": { 671 | "scale": 0.8264462809917358, 672 | "offset": [ 673 | 496.5037384564152, 674 | -71.873413496035 675 | ] 676 | }, 677 | "frontendVersion": "1.19.2" 678 | }, 679 | "version": 0.4 680 | } -------------------------------------------------------------------------------- /example_workflows/video_generation.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HelloVision/ComfyUI_HelloMeme/ce368e562f530b9efaf623b72e619b50b82a1ead/example_workflows/video_generation.jpg -------------------------------------------------------------------------------- /example_workflows/video_generation.json: -------------------------------------------------------------------------------- 1 | { 2 | "id": "4c8b7397-a69c-46cc-b553-00e600557365", 3 | "revision": 0, 4 | "last_node_id": 76, 5 | "last_link_id": 178, 6 | "nodes": [ 7 | { 8 | "id": 75, 9 | "type": "GetExpression2", 10 | "pos": [ 11 | 814.1359252929688, 12 | 689.6795654296875 13 | ], 14 | "size": [ 15 | 191.63729858398438, 16 | 66 17 | ], 18 | "flags": {}, 19 | "order": 9, 20 | "mode": 0, 21 | "inputs": [ 22 | { 23 | "name": "face_toolkits", 24 | 
"type": "FACE_TOOLKITS", 25 | "link": 173 26 | }, 27 | { 28 | "name": "images", 29 | "type": "IMAGE", 30 | "link": 175 31 | }, 32 | { 33 | "name": "landmarks", 34 | "type": "FACELANDMARKS222", 35 | "link": 178 36 | } 37 | ], 38 | "outputs": [ 39 | { 40 | "name": "expression", 41 | "type": "EXPRESSION", 42 | "links": [ 43 | 171 44 | ] 45 | } 46 | ], 47 | "properties": { 48 | "Node name for S&R": "GetExpression2" 49 | }, 50 | "widgets_values": [] 51 | }, 52 | { 53 | "id": 76, 54 | "type": "GetExpression2", 55 | "pos": [ 56 | 833.6573486328125, 57 | 1042.0396728515625 58 | ], 59 | "size": [ 60 | 191.63729858398438, 61 | 66 62 | ], 63 | "flags": {}, 64 | "order": 7, 65 | "mode": 0, 66 | "inputs": [ 67 | { 68 | "name": "face_toolkits", 69 | "type": "FACE_TOOLKITS", 70 | "link": 174 71 | }, 72 | { 73 | "name": "images", 74 | "type": "IMAGE", 75 | "link": 176 76 | }, 77 | { 78 | "name": "landmarks", 79 | "type": "FACELANDMARKS222", 80 | "link": 177 81 | } 82 | ], 83 | "outputs": [ 84 | { 85 | "name": "expression", 86 | "type": "EXPRESSION", 87 | "links": [ 88 | 172 89 | ] 90 | } 91 | ], 92 | "properties": { 93 | "Node name for S&R": "GetExpression2" 94 | }, 95 | "widgets_values": [] 96 | }, 97 | { 98 | "id": 2, 99 | "type": "HMFaceToolkitsLoader", 100 | "pos": [ 101 | 45.41996383666992, 102 | 769.9866333007812 103 | ], 104 | "size": [ 105 | 315, 106 | 82 107 | ], 108 | "flags": {}, 109 | "order": 0, 110 | "mode": 0, 111 | "inputs": [], 112 | "outputs": [ 113 | { 114 | "name": "face_toolkits", 115 | "type": "FACE_TOOLKITS", 116 | "slot_index": 0, 117 | "links": [ 118 | 122, 119 | 129, 120 | 147, 121 | 149, 122 | 173, 123 | 174 124 | ] 125 | } 126 | ], 127 | "properties": { 128 | "Node name for S&R": "HMFaceToolkitsLoader" 129 | }, 130 | "widgets_values": [ 131 | 0, 132 | "huggingface" 133 | ] 134 | }, 135 | { 136 | "id": 61, 137 | "type": "GetFaceLandmarks", 138 | "pos": [ 139 | 519.7196655273438, 140 | 938.032958984375 141 | ], 142 | "size": [ 143 | 170.62362670898438, 144 | 46 145 | ], 146 | "flags": {}, 147 | "order": 4, 148 | "mode": 0, 149 | "inputs": [ 150 | { 151 | "name": "face_toolkits", 152 | "type": "FACE_TOOLKITS", 153 | "link": 129 154 | }, 155 | { 156 | "name": "images", 157 | "type": "IMAGE", 158 | "link": 131 159 | } 160 | ], 161 | "outputs": [ 162 | { 163 | "name": "landmarks", 164 | "type": "FACELANDMARKS222", 165 | "links": [ 166 | 145, 167 | 177 168 | ] 169 | } 170 | ], 171 | "properties": { 172 | "Node name for S&R": "GetFaceLandmarks" 173 | }, 174 | "widgets_values": [] 175 | }, 176 | { 177 | "id": 60, 178 | "type": "GetFaceLandmarks", 179 | "pos": [ 180 | 518.9132690429688, 181 | 612.9461669921875 182 | ], 183 | "size": [ 184 | 170.62362670898438, 185 | 46 186 | ], 187 | "flags": {}, 188 | "order": 5, 189 | "mode": 0, 190 | "inputs": [ 191 | { 192 | "name": "face_toolkits", 193 | "type": "FACE_TOOLKITS", 194 | "link": 122 195 | }, 196 | { 197 | "name": "images", 198 | "type": "IMAGE", 199 | "link": 121 200 | } 201 | ], 202 | "outputs": [ 203 | { 204 | "name": "landmarks", 205 | "type": "FACELANDMARKS222", 206 | "links": [ 207 | 143, 208 | 178 209 | ] 210 | } 211 | ], 212 | "properties": { 213 | "Node name for S&R": "GetFaceLandmarks" 214 | }, 215 | "widgets_values": [] 216 | }, 217 | { 218 | "id": 59, 219 | "type": "VHS_VideoCombine", 220 | "pos": [ 221 | 1658.52490234375, 222 | 685.6950073242188 223 | ], 224 | "size": [ 225 | 329.22479248046875, 226 | 633.2247924804688 227 | ], 228 | "flags": {}, 229 | "order": 11, 230 | "mode": 0, 231 | "inputs": [ 232 | { 233 | "name": 
"images", 234 | "type": "IMAGE", 235 | "link": 170 236 | }, 237 | { 238 | "name": "audio", 239 | "shape": 7, 240 | "type": "AUDIO", 241 | "link": 118 242 | }, 243 | { 244 | "name": "meta_batch", 245 | "shape": 7, 246 | "type": "VHS_BatchManager", 247 | "link": null 248 | }, 249 | { 250 | "name": "vae", 251 | "shape": 7, 252 | "type": "VAE", 253 | "link": null 254 | } 255 | ], 256 | "outputs": [ 257 | { 258 | "name": "Filenames", 259 | "type": "VHS_FILENAMES", 260 | "links": null 261 | } 262 | ], 263 | "properties": { 264 | "Node name for S&R": "VHS_VideoCombine" 265 | }, 266 | "widgets_values": { 267 | "frame_rate": 8, 268 | "loop_count": 0, 269 | "filename_prefix": "AnimateDiff", 270 | "format": "video/h264-mp4", 271 | "pix_fmt": "yuv420p", 272 | "crf": 18, 273 | "save_metadata": true, 274 | "pingpong": false, 275 | "save_output": true, 276 | "videopreview": { 277 | "hidden": false, 278 | "paused": false, 279 | "params": { 280 | "filename": "AnimateDiff_00703-audio.mp4", 281 | "subfolder": "", 282 | "type": "output", 283 | "format": "video/h264-mp4", 284 | "frame_rate": 8 285 | }, 286 | "muted": false 287 | } 288 | } 289 | }, 290 | { 291 | "id": 74, 292 | "type": "HMPipelineVideo", 293 | "pos": [ 294 | 1296.2967529296875, 295 | 659.4104614257812 296 | ], 297 | "size": [ 298 | 270, 299 | 330 300 | ], 301 | "flags": {}, 302 | "order": 10, 303 | "mode": 0, 304 | "inputs": [ 305 | { 306 | "name": "hm_video_pipeline", 307 | "type": "HMVIDEOPIPELINE", 308 | "link": 165 309 | }, 310 | { 311 | "name": "ref_head_pose", 312 | "type": "HEAD_POSE", 313 | "link": 166 314 | }, 315 | { 316 | "name": "ref_expression", 317 | "type": "EXPRESSION", 318 | "link": 171 319 | }, 320 | { 321 | "name": "drive_head_pose", 322 | "type": "HEAD_POSE", 323 | "link": 168 324 | }, 325 | { 326 | "name": "drive_expression", 327 | "type": "EXPRESSION", 328 | "link": 172 329 | } 330 | ], 331 | "outputs": [ 332 | { 333 | "name": "IMAGE", 334 | "type": "IMAGE", 335 | "links": [ 336 | 170 337 | ] 338 | }, 339 | { 340 | "name": "LATENT", 341 | "type": "LATENT", 342 | "links": null 343 | } 344 | ], 345 | "properties": { 346 | "Node name for S&R": "HMPipelineVideo" 347 | }, 348 | "widgets_values": [ 349 | 0, 350 | 4, 351 | "", 352 | "", 353 | 25, 354 | 1075819929118817, 355 | "randomize", 356 | 1.5, 357 | 0 358 | ] 359 | }, 360 | { 361 | "id": 69, 362 | "type": "GetHeadPose", 363 | "pos": [ 364 | 819.799560546875, 365 | 860.5927734375 366 | ], 367 | "size": [ 368 | 270, 369 | 98 370 | ], 371 | "flags": {}, 372 | "order": 6, 373 | "mode": 0, 374 | "inputs": [ 375 | { 376 | "name": "face_toolkits", 377 | "type": "FACE_TOOLKITS", 378 | "link": 149 379 | }, 380 | { 381 | "name": "images", 382 | "type": "IMAGE", 383 | "link": 153 384 | }, 385 | { 386 | "name": "landmarks", 387 | "type": "FACELANDMARKS222", 388 | "link": 145 389 | } 390 | ], 391 | "outputs": [ 392 | { 393 | "name": "head_pose", 394 | "type": "HEAD_POSE", 395 | "links": [ 396 | 168 397 | ] 398 | } 399 | ], 400 | "properties": { 401 | "Node name for S&R": "GetHeadPose" 402 | }, 403 | "widgets_values": [ 404 | true 405 | ] 406 | }, 407 | { 408 | "id": 48, 409 | "type": "VHS_LoadVideo", 410 | "pos": [ 411 | 74.54000091552734, 412 | 996.8733520507812 413 | ], 414 | "size": [ 415 | 247.455078125, 416 | 262 417 | ], 418 | "flags": {}, 419 | "order": 1, 420 | "mode": 0, 421 | "inputs": [ 422 | { 423 | "name": "meta_batch", 424 | "shape": 7, 425 | "type": "VHS_BatchManager", 426 | "link": null 427 | }, 428 | { 429 | "name": "vae", 430 | "shape": 7, 431 | "type": "VAE", 432 | 
"link": null 433 | } 434 | ], 435 | "outputs": [ 436 | { 437 | "name": "IMAGE", 438 | "type": "IMAGE", 439 | "slot_index": 0, 440 | "links": [ 441 | 131, 442 | 153, 443 | 176 444 | ] 445 | }, 446 | { 447 | "name": "frame_count", 448 | "type": "INT", 449 | "links": null 450 | }, 451 | { 452 | "name": "audio", 453 | "type": "AUDIO", 454 | "slot_index": 2, 455 | "links": [ 456 | 118 457 | ] 458 | }, 459 | { 460 | "name": "video_info", 461 | "type": "VHS_VIDEOINFO", 462 | "links": null 463 | } 464 | ], 465 | "properties": { 466 | "Node name for S&R": "VHS_LoadVideo" 467 | }, 468 | "widgets_values": { 469 | "video": "jgz.mp4", 470 | "force_rate": 8, 471 | "force_size": "Disabled", 472 | "custom_width": 512, 473 | "custom_height": 512, 474 | "frame_load_cap": 0, 475 | "skip_first_frames": 0, 476 | "select_every_nth": 1, 477 | "choose video to upload": "image", 478 | "videopreview": { 479 | "hidden": false, 480 | "paused": false, 481 | "params": { 482 | "force_rate": 8, 483 | "frame_load_cap": 0, 484 | "skip_first_frames": 0, 485 | "select_every_nth": 1, 486 | "filename": "jgz.mp4", 487 | "type": "input", 488 | "format": "video/mp4" 489 | }, 490 | "muted": false 491 | } 492 | } 493 | }, 494 | { 495 | "id": 58, 496 | "type": "LoadImage", 497 | "pos": [ 498 | 37.4466552734375, 499 | 260.9000244140625 500 | ], 501 | "size": [ 502 | 315, 503 | 314 504 | ], 505 | "flags": {}, 506 | "order": 2, 507 | "mode": 0, 508 | "inputs": [], 509 | "outputs": [ 510 | { 511 | "name": "IMAGE", 512 | "type": "IMAGE", 513 | "slot_index": 0, 514 | "links": [ 515 | 121, 516 | 151, 517 | 175 518 | ] 519 | }, 520 | { 521 | "name": "MASK", 522 | "type": "MASK", 523 | "links": null 524 | } 525 | ], 526 | "properties": { 527 | "Node name for S&R": "LoadImage" 528 | }, 529 | "widgets_values": [ 530 | "i5.jpg", 531 | "image" 532 | ] 533 | }, 534 | { 535 | "id": 67, 536 | "type": "GetHeadPose", 537 | "pos": [ 538 | 814.1528930664062, 539 | 495.9793701171875 540 | ], 541 | "size": [ 542 | 270, 543 | 98 544 | ], 545 | "flags": {}, 546 | "order": 8, 547 | "mode": 0, 548 | "inputs": [ 549 | { 550 | "name": "face_toolkits", 551 | "type": "FACE_TOOLKITS", 552 | "link": 147 553 | }, 554 | { 555 | "name": "images", 556 | "type": "IMAGE", 557 | "link": 151 558 | }, 559 | { 560 | "name": "landmarks", 561 | "type": "FACELANDMARKS222", 562 | "link": 143 563 | } 564 | ], 565 | "outputs": [ 566 | { 567 | "name": "head_pose", 568 | "type": "HEAD_POSE", 569 | "links": [ 570 | 166 571 | ] 572 | } 573 | ], 574 | "properties": { 575 | "Node name for S&R": "GetHeadPose" 576 | }, 577 | "widgets_values": [ 578 | true 579 | ] 580 | }, 581 | { 582 | "id": 57, 583 | "type": "HMVideoPipelineLoader", 584 | "pos": [ 585 | 798.4796752929688, 586 | 193.7666015625 587 | ], 588 | "size": [ 589 | 352.79998779296875, 590 | 226 591 | ], 592 | "flags": {}, 593 | "order": 3, 594 | "mode": 0, 595 | "inputs": [], 596 | "outputs": [ 597 | { 598 | "name": "hm_video_pipeline", 599 | "type": "HMVIDEOPIPELINE", 600 | "slot_index": 0, 601 | "links": [ 602 | 165 603 | ] 604 | } 605 | ], 606 | "properties": { 607 | "Node name for S&R": "HMVideoPipelineLoader" 608 | }, 609 | "widgets_values": [ 610 | "[preset]RealisticVisionV60B1", 611 | "None", 612 | "same as checkpoint", 613 | "v5", 614 | "x1", 615 | "huggingface", 616 | 1, 617 | "fp32" 618 | ] 619 | } 620 | ], 621 | "links": [ 622 | [ 623 | 118, 624 | 48, 625 | 2, 626 | 59, 627 | 1, 628 | "AUDIO" 629 | ], 630 | [ 631 | 121, 632 | 58, 633 | 0, 634 | 60, 635 | 1, 636 | "IMAGE" 637 | ], 638 | [ 639 | 122, 640 | 2, 641 | 0, 
642 | 60, 643 | 0, 644 | "FACE_TOOLKITS" 645 | ], 646 | [ 647 | 129, 648 | 2, 649 | 0, 650 | 61, 651 | 0, 652 | "FACE_TOOLKITS" 653 | ], 654 | [ 655 | 131, 656 | 48, 657 | 0, 658 | 61, 659 | 1, 660 | "IMAGE" 661 | ], 662 | [ 663 | 143, 664 | 60, 665 | 0, 666 | 67, 667 | 2, 668 | "FACELANDMARKS222" 669 | ], 670 | [ 671 | 145, 672 | 61, 673 | 0, 674 | 69, 675 | 2, 676 | "FACELANDMARKS222" 677 | ], 678 | [ 679 | 147, 680 | 2, 681 | 0, 682 | 67, 683 | 0, 684 | "FACE_TOOLKITS" 685 | ], 686 | [ 687 | 149, 688 | 2, 689 | 0, 690 | 69, 691 | 0, 692 | "FACE_TOOLKITS" 693 | ], 694 | [ 695 | 151, 696 | 58, 697 | 0, 698 | 67, 699 | 1, 700 | "IMAGE" 701 | ], 702 | [ 703 | 153, 704 | 48, 705 | 0, 706 | 69, 707 | 1, 708 | "IMAGE" 709 | ], 710 | [ 711 | 165, 712 | 57, 713 | 0, 714 | 74, 715 | 0, 716 | "HMVIDEOPIPELINE" 717 | ], 718 | [ 719 | 166, 720 | 67, 721 | 0, 722 | 74, 723 | 1, 724 | "HEAD_POSE" 725 | ], 726 | [ 727 | 168, 728 | 69, 729 | 0, 730 | 74, 731 | 3, 732 | "HEAD_POSE" 733 | ], 734 | [ 735 | 170, 736 | 74, 737 | 0, 738 | 59, 739 | 0, 740 | "IMAGE" 741 | ], 742 | [ 743 | 171, 744 | 75, 745 | 0, 746 | 74, 747 | 2, 748 | "EXPRESSION" 749 | ], 750 | [ 751 | 172, 752 | 76, 753 | 0, 754 | 74, 755 | 4, 756 | "EXPRESSION" 757 | ], 758 | [ 759 | 173, 760 | 2, 761 | 0, 762 | 75, 763 | 0, 764 | "FACE_TOOLKITS" 765 | ], 766 | [ 767 | 174, 768 | 2, 769 | 0, 770 | 76, 771 | 0, 772 | "FACE_TOOLKITS" 773 | ], 774 | [ 775 | 175, 776 | 58, 777 | 0, 778 | 75, 779 | 1, 780 | "IMAGE" 781 | ], 782 | [ 783 | 176, 784 | 48, 785 | 0, 786 | 76, 787 | 1, 788 | "IMAGE" 789 | ], 790 | [ 791 | 177, 792 | 61, 793 | 0, 794 | 76, 795 | 2, 796 | "FACELANDMARKS222" 797 | ], 798 | [ 799 | 178, 800 | 60, 801 | 0, 802 | 75, 803 | 2, 804 | "FACELANDMARKS222" 805 | ] 806 | ], 807 | "groups": [], 808 | "config": {}, 809 | "extra": { 810 | "ds": { 811 | "scale": 0.8264462809917358, 812 | "offset": [ 813 | 493.27712102477454, 814 | -69.45337656976547 815 | ] 816 | }, 817 | "frontendVersion": "1.19.2" 818 | }, 819 | "version": 0.4 820 | } -------------------------------------------------------------------------------- /examples/amns.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HelloVision/ComfyUI_HelloMeme/ce368e562f530b9efaf623b72e619b50b82a1ead/examples/amns.mp4 -------------------------------------------------------------------------------- /examples/chillout.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HelloVision/ComfyUI_HelloMeme/ce368e562f530b9efaf623b72e619b50b82a1ead/examples/chillout.jpg -------------------------------------------------------------------------------- /examples/civitai2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HelloVision/ComfyUI_HelloMeme/ce368e562f530b9efaf623b72e619b50b82a1ead/examples/civitai2.jpg -------------------------------------------------------------------------------- /examples/helloicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HelloVision/ComfyUI_HelloMeme/ce368e562f530b9efaf623b72e619b50b82a1ead/examples/helloicon.png -------------------------------------------------------------------------------- /examples/i5.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HelloVision/ComfyUI_HelloMeme/ce368e562f530b9efaf623b72e619b50b82a1ead/examples/i5.jpg -------------------------------------------------------------------------------- /examples/jgz.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HelloVision/ComfyUI_HelloMeme/ce368e562f530b9efaf623b72e619b50b82a1ead/examples/jgz.mp4 -------------------------------------------------------------------------------- /examples/majicmix2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HelloVision/ComfyUI_HelloMeme/ce368e562f530b9efaf623b72e619b50b82a1ead/examples/majicmix2.jpg -------------------------------------------------------------------------------- /examples/qie.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HelloVision/ComfyUI_HelloMeme/ce368e562f530b9efaf623b72e619b50b82a1ead/examples/qie.mp4 -------------------------------------------------------------------------------- /examples/tiktok.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HelloVision/ComfyUI_HelloMeme/ce368e562f530b9efaf623b72e619b50b82a1ead/examples/tiktok.mp4 -------------------------------------------------------------------------------- /examples/toon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HelloVision/ComfyUI_HelloMeme/ce368e562f530b9efaf623b72e619b50b82a1ead/examples/toon.png -------------------------------------------------------------------------------- /examples/yao.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HelloVision/ComfyUI_HelloMeme/ce368e562f530b9efaf623b72e619b50b82a1ead/examples/yao.jpg -------------------------------------------------------------------------------- /hellomeme/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | @File : __init__.py.py 5 | @Author : Songkey 6 | @Email : songkey@pku.edu.cn 7 | @Date : 8/28/2024 8 | @Desc : 9 | """ 10 | 11 | from .pipelines import (HMImagePipeline, HMVideoPipeline, 12 | HM3ImagePipeline, HM3VideoPipeline, 13 | HM5ImagePipeline, HM5VideoPipeline) -------------------------------------------------------------------------------- /hellomeme/model_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "sd15": { 3 | "checkpoints": { 4 | "SD1.5": "songkey/stable-diffusion-v1-5", 5 | "[preset]RealisticVisionV60B1": "songkey/realisticVisionV60B1_v51VAE", 6 | "[preset]DisneyPixarCartoonB": "songkey/disney-pixar-cartoon-b", 7 | "[preset]toonyou_beta6": "songkey/toonyou_beta6", 8 | "[preset]LZ_2DCartoon_V2": "songkey/LZ_2DCartoon_V2", 9 | "[preset]meinamix_v12Final": "songkey/meinamix_v12Final", 10 | "[preset]animedark_v10": "songkey/animedark_v10", 11 | "[preset]absolutereality_v181": "songkey/absolutereality_v181", 12 | "[preset]dreamshaper_8": "songkey/dreamshaper_8", 13 | "[preset]epicphotogasm_ultimateFidelity": "songkey/epicphotogasm_ultimateFidelity", 14 | "[preset]epicrealism_naturalSinRC1VAE": "songkey/epicrealism_naturalSinRC1VAE", 15 | "[preset]xxmix9realistic_v40": "songkey/xxmix9realistic_v40", 16 | "[preset]cyberrealistic_v80": "songkey/cyberrealistic_v80" 17 | }, 18 | "loras": { 
19 | "[preset]BabyFaceV1": ["songkey/loras_sd_1_5", "baby_face_v1.safetensors"], 20 | "[preset]MoreDetails": ["songkey/loras_sd_1_5", "more_details.safetensors"], 21 | "[preset]PixelPortraitV1": ["songkey/loras_sd_1_5", "pixel-portrait-v1.safetensors"], 22 | "[preset]Drawing": ["songkey/loras_sd_1_5", "Drawing.safetensors"] 23 | } 24 | }, 25 | "prompt": "(best quality), highly detailed, ultra-detailed, headshot, person, well-placed five sense organs, looking at the viewer, centered composition, sharp focus, realistic skin texture" 26 | } -------------------------------------------------------------------------------- /hellomeme/models/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | @File : __init__.py 5 | @Author : Songkey 6 | @Email : songkey@pku.edu.cn 7 | @Date : 8/14/2024 8 | @Desc : 9 | """ 10 | 11 | from .hm_denoising_motion import HMDenoisingMotion 12 | from .hm_control import (HMControlNet, HMControlNet2, HMV2ControlNet, HMV2ControlNet2, 13 | HMV3ControlNet, HMControlNetBase, HM5ControlNetBase, 14 | HM4SD15ControlProj, HM5SD15ControlProj) 15 | from .hm_adapters import (HMReferenceAdapter, HM3ReferenceAdapter, HM5ReferenceAdapter, 16 | HM3MotionAdapter, HM5MotionAdapter, HMPipeline) 17 | from .hm_denoising_3d import HMDenoising3D 18 | from .hm3_denoising_3d import HM3Denoising3D 19 | from .hm3_denoising_motion import HM3DenoisingMotion 20 | -------------------------------------------------------------------------------- /hellomeme/models/hm3_denoising_3d.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | @File : models6/hm_denoising_3d.py 5 | @Author : Songkey 6 | @Email : songkey@pku.edu.cn 7 | @Date : 1/3/2025 8 | @Desc : 9 | adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unets/unet_2d_condition.py 10 | """ 11 | 12 | import torch 13 | import torch.utils.checkpoint 14 | from typing import Any, Dict, Optional, Tuple, Union 15 | 16 | from einops import rearrange 17 | 18 | from diffusers.utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers 19 | from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel, UNet2DConditionOutput 20 | from .hm_adapters import CopyWeights, InsertReferenceAdapter 21 | 22 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 23 | 24 | class HM3Denoising3D(UNet2DConditionModel, CopyWeights, InsertReferenceAdapter): 25 | def forward( 26 | self, 27 | sample: torch.Tensor, 28 | timestep: Union[torch.Tensor, float, int], 29 | encoder_hidden_states: torch.Tensor, 30 | reference_hidden_states: Optional[dict] = None, 31 | control_hidden_states: Optional[dict] = None, 32 | motion_pad_hidden_states: Optional[dict] = None, 33 | use_motion: bool = False, 34 | class_labels: Optional[torch.Tensor] = None, 35 | timestep_cond: Optional[torch.Tensor] = None, 36 | attention_mask: Optional[torch.Tensor] = None, 37 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 38 | added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, 39 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, 40 | mid_block_additional_residual: Optional[torch.Tensor] = None, 41 | down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, 42 | encoder_attention_mask: Optional[torch.Tensor] = None, 43 | return_dict: bool = True, 44 | ) -> Union[UNet2DConditionOutput, Tuple]: 45 | # By default samples have 
to be AT least a multiple of the overall upsampling factor. 46 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). 47 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 48 | # on the fly if necessary. 49 | default_overall_up_factor = 2**self.num_upsamplers 50 | 51 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 52 | forward_upsample_size = False 53 | upsample_size = None 54 | 55 | for dim in sample.shape[-2:]: 56 | if dim % default_overall_up_factor != 0: 57 | # Forward upsample size to force interpolation output size. 58 | forward_upsample_size = True 59 | break 60 | 61 | # ensure attention_mask is a bias, and give it a singleton query_tokens dimension 62 | # expects mask of shape: 63 | # [batch, key_tokens] 64 | # adds singleton query_tokens dimension: 65 | # [batch, 1, key_tokens] 66 | # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: 67 | # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) 68 | # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) 69 | if attention_mask is not None: 70 | # assume that mask is expressed as: 71 | # (1 = keep, 0 = discard) 72 | # convert mask into a bias that can be added to attention scores: 73 | # (keep = +0, discard = -10000.0) 74 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 75 | attention_mask = attention_mask.unsqueeze(1) 76 | 77 | # convert encoder_attention_mask to a bias the same way we do for attention_mask 78 | if encoder_attention_mask is not None: 79 | encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 80 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1) 81 | 82 | # 0. center input if necessary 83 | if self.config.center_input_sample: 84 | sample = 2 * sample - 1.0 85 | 86 | # 1. 
time 87 | t_emb = self.get_time_embed(sample=sample, timestep=timestep) 88 | emb = self.time_embedding(t_emb, timestep_cond) 89 | 90 | class_emb = self.get_class_embed(sample=sample, class_labels=class_labels) 91 | if class_emb is not None: 92 | if self.config.class_embeddings_concat: 93 | emb = torch.cat([emb, class_emb], dim=-1) 94 | else: 95 | emb = emb + class_emb 96 | 97 | aug_emb = self.get_aug_embed( 98 | emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs 99 | ) 100 | if self.config.addition_embed_type == "image_hint": 101 | aug_emb, hint = aug_emb 102 | sample = torch.cat([sample, hint], dim=1) 103 | 104 | emb = emb + aug_emb if aug_emb is not None else emb 105 | 106 | if self.time_embed_act is not None: 107 | emb = self.time_embed_act(emb) 108 | 109 | num_frames = sample.shape[2] 110 | emb = emb.repeat_interleave(repeats=num_frames, dim=0) 111 | 112 | if not added_cond_kwargs is None: 113 | if 'image_embeds' in added_cond_kwargs: 114 | if isinstance(added_cond_kwargs['image_embeds'], torch.Tensor): 115 | added_cond_kwargs['image_embeds'] = added_cond_kwargs['image_embeds'].repeat_interleave(repeats=num_frames, dim=0) 116 | else: 117 | added_cond_kwargs['image_embeds'] = [x.repeat_interleave(repeats=num_frames, dim=0) for x in added_cond_kwargs['image_embeds']] 118 | 119 | if len(encoder_hidden_states.shape) == 3: 120 | encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0) 121 | elif len(encoder_hidden_states.shape) == 4: 122 | encoder_hidden_states = rearrange(encoder_hidden_states, "b f l d -> (b f) l d") 123 | 124 | encoder_hidden_states = self.process_encoder_hidden_states( 125 | encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs 126 | ) 127 | 128 | # 2. pre-process 129 | sample = rearrange(sample, "b c f h w -> (b f) c h w") 130 | sample = self.conv_in(sample) 131 | 132 | # 2.5 GLIGEN position net 133 | if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: 134 | cross_attention_kwargs = cross_attention_kwargs.copy() 135 | gligen_args = cross_attention_kwargs.pop("gligen") 136 | cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} 137 | 138 | # 3. down 139 | # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated 140 | # to the internal blocks and will raise deprecation warnings. this will be confusing for our users. 
141 | if cross_attention_kwargs is not None: 142 | cross_attention_kwargs = cross_attention_kwargs.copy() 143 | lora_scale = cross_attention_kwargs.pop("scale", 1.0) 144 | else: 145 | lora_scale = 1.0 146 | 147 | if USE_PEFT_BACKEND: 148 | # weight the lora layers by setting `lora_scale` for each PEFT layer 149 | scale_lora_layers(self, lora_scale) 150 | 151 | is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None 152 | # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets 153 | is_adapter = down_intrablock_additional_residuals is not None 154 | # maintain backward compatibility for legacy usage, where 155 | # T2I-Adapter and ControlNet both use down_block_additional_residuals arg 156 | # but can only use one or the other 157 | if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: 158 | deprecate( 159 | "T2I should not use down_block_additional_residuals", 160 | "1.3.0", 161 | "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ 162 | and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ 163 | for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", 164 | standard_warn=False, 165 | ) 166 | down_intrablock_additional_residuals = down_block_additional_residuals 167 | is_adapter = True 168 | 169 | res_cache = dict() 170 | down_block_res_samples = (sample,) 171 | for idx, downsample_block in enumerate(self.down_blocks): 172 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 173 | # For t2i-adapter CrossAttnDownBlock2D 174 | additional_residuals = {} 175 | if is_adapter and len(down_intrablock_additional_residuals) > 0: 176 | additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) 177 | 178 | sample, res_samples = downsample_block( 179 | hidden_states=sample, 180 | temb=emb, 181 | encoder_hidden_states=encoder_hidden_states, 182 | attention_mask=attention_mask, 183 | cross_attention_kwargs=cross_attention_kwargs, 184 | encoder_attention_mask=encoder_attention_mask, 185 | **additional_residuals, 186 | ) 187 | else: 188 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 189 | if is_adapter and len(down_intrablock_additional_residuals) > 0: 190 | sample += down_intrablock_additional_residuals.pop(0) 191 | 192 | res_cache[f"down_{idx}"] = sample.clone() 193 | if not control_hidden_states is None and f'down3_{idx}' in control_hidden_states: 194 | sample += rearrange(control_hidden_states[f'down3_{idx}'], "b c f h w -> (b f) c h w") 195 | if hasattr(self, 'motion_down') and use_motion: 196 | sample = self.motion_down[idx](sample, 197 | None if motion_pad_hidden_states is None else motion_pad_hidden_states[f'down_{idx}'], 198 | emb, num_frames) 199 | 200 | down_block_res_samples += res_samples 201 | 202 | if is_controlnet: 203 | new_down_block_res_samples = () 204 | 205 | for down_block_res_sample, down_block_additional_residual in zip( 206 | down_block_res_samples, down_block_additional_residuals 207 | ): 208 | down_block_res_sample = down_block_res_sample + down_block_additional_residual 209 | new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) 210 | 211 | down_block_res_samples = new_down_block_res_samples 212 | 213 | # 4. 
mid 214 | if self.mid_block is not None: 215 | if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: 216 | sample = self.mid_block( 217 | sample, 218 | emb, 219 | encoder_hidden_states=encoder_hidden_states, 220 | attention_mask=attention_mask, 221 | cross_attention_kwargs=cross_attention_kwargs, 222 | encoder_attention_mask=encoder_attention_mask, 223 | ) 224 | else: 225 | sample = self.mid_block(sample, emb) 226 | 227 | # To support T2I-Adapter-XL 228 | if ( 229 | is_adapter 230 | and len(down_intrablock_additional_residuals) > 0 231 | and sample.shape == down_intrablock_additional_residuals[0].shape 232 | ): 233 | sample += down_intrablock_additional_residuals.pop(0) 234 | 235 | if is_controlnet: 236 | sample = sample + mid_block_additional_residual 237 | 238 | # 5. up 239 | for i, upsample_block in enumerate(self.up_blocks): 240 | is_final_block = i == len(self.up_blocks) - 1 241 | 242 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 243 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 244 | 245 | # if we have not reached the final block and need to forward the 246 | # upsample size, we do it here 247 | if not is_final_block and forward_upsample_size: 248 | upsample_size = down_block_res_samples[-1].shape[2:] 249 | 250 | res_cache[f"up_{i}"] = sample.clone() 251 | if not control_hidden_states is None and f'up3_{i}' in control_hidden_states: 252 | sample += rearrange(control_hidden_states[f'up3_{i}'], "b c f h w -> (b f) c h w") 253 | if hasattr(self, "reference_modules_up") and not reference_hidden_states is None and f'up_{i}' in reference_hidden_states: 254 | sample = self.reference_modules_up[i](sample, reference_hidden_states[f'up_{i}'], num_frames) 255 | if hasattr(self, 'motion_up') and use_motion: 256 | sample = self.motion_up[i](sample, 257 | None if motion_pad_hidden_states is None else motion_pad_hidden_states[f'up_{i}'], 258 | emb, num_frames) 259 | 260 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 261 | sample = upsample_block( 262 | hidden_states=sample, 263 | temb=emb, 264 | res_hidden_states_tuple=res_samples, 265 | encoder_hidden_states=encoder_hidden_states, 266 | cross_attention_kwargs=cross_attention_kwargs, 267 | upsample_size=upsample_size, 268 | attention_mask=attention_mask, 269 | encoder_attention_mask=encoder_attention_mask, 270 | ) 271 | else: 272 | sample = upsample_block( 273 | hidden_states=sample, 274 | temb=emb, 275 | res_hidden_states_tuple=res_samples, 276 | upsample_size=upsample_size, 277 | ) 278 | 279 | # 6. 
post-process 280 | if self.conv_norm_out: 281 | sample = self.conv_norm_out(sample) 282 | sample = self.conv_act(sample) 283 | sample = self.conv_out(sample) 284 | 285 | if USE_PEFT_BACKEND: 286 | # remove `lora_scale` from each PEFT layer 287 | unscale_lora_layers(self, lora_scale) 288 | 289 | # reshape to (batch, channel, framerate, width, height) 290 | sample = rearrange(sample, "(b f) c h w -> b c f h w", f=num_frames) 291 | 292 | if not return_dict: 293 | return (sample, res_cache) 294 | 295 | return (UNet2DConditionOutput(sample=sample), res_cache) 296 | -------------------------------------------------------------------------------- /hellomeme/models/hm3_denoising_motion.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | @File : models6/hm_denoising_motion.py 5 | @Author : Songkey 6 | @Email : songkey@pku.edu.cn 7 | @Date : 1/3/2025 8 | @Desc : 9 | adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unets/unet_motion_model.py 10 | """ 11 | 12 | import torch 13 | import torch.utils.checkpoint 14 | from typing import Any, Dict, Optional, Tuple, Union 15 | 16 | from einops import rearrange 17 | 18 | from diffusers.utils import logging 19 | from diffusers.models.unets.unet_motion_model import UNetMotionModel, UNetMotionOutput 20 | from .hm_adapters import InsertReferenceAdapter 21 | 22 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 23 | 24 | 25 | class HM3DenoisingMotion(UNetMotionModel, InsertReferenceAdapter): 26 | def forward( 27 | self, 28 | sample: torch.Tensor, 29 | timestep: Union[torch.Tensor, float, int], 30 | encoder_hidden_states: torch.Tensor, 31 | reference_hidden_states: Optional[dict] = None, 32 | control_hidden_states: Optional[torch.Tensor] = None, 33 | motion_pad_hidden_states: Optional[dict] = None, 34 | use_motion: bool = False, 35 | timestep_cond: Optional[torch.Tensor] = None, 36 | attention_mask: Optional[torch.Tensor] = None, 37 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 38 | added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, 39 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, 40 | mid_block_additional_residual: Optional[torch.Tensor] = None, 41 | return_dict: bool = True, 42 | ) -> Union[UNetMotionOutput, Tuple[torch.Tensor]]: 43 | 44 | # By default samples have to be AT least a multiple of the overall upsampling factor. 45 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). 46 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 47 | # on the fly if necessary. 48 | default_overall_up_factor = 2 ** self.num_upsamplers 49 | 50 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 51 | forward_upsample_size = False 52 | upsample_size = None 53 | 54 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 55 | logger.info("Forward upsample size to force interpolation output size.") 56 | forward_upsample_size = True 57 | 58 | # prepare attention_mask 59 | if attention_mask is not None: 60 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 61 | attention_mask = attention_mask.unsqueeze(1) 62 | 63 | # 1. time 64 | timesteps = timestep 65 | if not torch.is_tensor(timesteps): 66 | # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can 67 | # This would be a good case for the `match` statement (Python 3.10+) 68 | is_mps = sample.device.type == "mps" 69 | if isinstance(timestep, float): 70 | dtype = torch.float32 if is_mps else torch.float64 71 | else: 72 | dtype = torch.int32 if is_mps else torch.int64 73 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 74 | elif len(timesteps.shape) == 0: 75 | timesteps = timesteps[None].to(sample.device) 76 | 77 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 78 | num_frames = sample.shape[2] 79 | timesteps = timesteps.expand(sample.shape[0]) 80 | 81 | t_emb = self.time_proj(timesteps) 82 | 83 | # timesteps does not contain any weights and will always return f32 tensors 84 | # but time_embedding might actually be running in fp16. so we need to cast here. 85 | # there might be better ways to encapsulate this. 86 | t_emb = t_emb.to(dtype=self.dtype) 87 | 88 | emb = self.time_embedding(t_emb, timestep_cond) 89 | aug_emb = None 90 | 91 | if self.config.addition_embed_type == "text_time": 92 | if "text_embeds" not in added_cond_kwargs: 93 | raise ValueError( 94 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" 95 | ) 96 | 97 | text_embeds = added_cond_kwargs.get("text_embeds") 98 | if "time_ids" not in added_cond_kwargs: 99 | raise ValueError( 100 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" 101 | ) 102 | time_ids = added_cond_kwargs.get("time_ids") 103 | time_embeds = self.add_time_proj(time_ids.flatten()) 104 | time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) 105 | 106 | add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) 107 | add_embeds = add_embeds.to(emb.dtype) 108 | aug_emb = self.add_embedding(add_embeds) 109 | 110 | emb = emb if aug_emb is None else emb + aug_emb 111 | emb = emb.repeat_interleave(repeats=num_frames, dim=0) 112 | 113 | if len(encoder_hidden_states.shape) == 3: 114 | encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0) 115 | elif len(encoder_hidden_states.shape) == 4: 116 | encoder_hidden_states = rearrange(encoder_hidden_states, "b f l d -> (b f) l d") 117 | 118 | if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": 119 | if "image_embeds" not in added_cond_kwargs: 120 | raise ValueError( 121 | f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" 122 | ) 123 | image_embeds = added_cond_kwargs.get("image_embeds") 124 | image_embeds = self.encoder_hid_proj(image_embeds) 125 | image_embeds = [image_embed.repeat_interleave(repeats=num_frames, dim=0) for image_embed in image_embeds] 126 | encoder_hidden_states = (encoder_hidden_states, image_embeds) 127 | 128 | # 2. pre-process 129 | sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) 130 | sample = self.conv_in(sample) 131 | 132 | # 3. 
down 133 | down_block_res_samples = (sample,) 134 | for idx, downsample_block in enumerate(self.down_blocks): 135 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 136 | sample, res_samples = downsample_block( 137 | hidden_states=sample, 138 | temb=emb, 139 | encoder_hidden_states=encoder_hidden_states, 140 | attention_mask=attention_mask, 141 | num_frames=num_frames, 142 | cross_attention_kwargs=cross_attention_kwargs, 143 | ) 144 | else: 145 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames) 146 | 147 | if not control_hidden_states is None and f'down3_{idx}' in control_hidden_states: 148 | sample += rearrange(control_hidden_states[f'down3_{idx}'], "b c f h w -> (b f) c h w") 149 | if hasattr(self, 'motion_down') and use_motion: 150 | sample = self.motion_down[idx](sample, motion_pad_hidden_states[f'down_{idx}'], emb, num_frames) 151 | 152 | down_block_res_samples += res_samples 153 | 154 | if down_block_additional_residuals is not None: 155 | new_down_block_res_samples = () 156 | 157 | for down_block_res_sample, down_block_additional_residual in zip( 158 | down_block_res_samples, down_block_additional_residuals 159 | ): 160 | down_block_res_sample = down_block_res_sample + down_block_additional_residual 161 | new_down_block_res_samples += (down_block_res_sample,) 162 | 163 | down_block_res_samples = new_down_block_res_samples 164 | 165 | # 4. mid 166 | if self.mid_block is not None: 167 | # To support older versions of motion modules that don't have a mid_block 168 | if hasattr(self.mid_block, "motion_modules"): 169 | sample = self.mid_block( 170 | sample, 171 | emb, 172 | encoder_hidden_states=encoder_hidden_states, 173 | attention_mask=attention_mask, 174 | num_frames=num_frames, 175 | cross_attention_kwargs=cross_attention_kwargs, 176 | ) 177 | else: 178 | sample = self.mid_block( 179 | sample, 180 | emb, 181 | encoder_hidden_states=encoder_hidden_states, 182 | attention_mask=attention_mask, 183 | cross_attention_kwargs=cross_attention_kwargs, 184 | ) 185 | 186 | if mid_block_additional_residual is not None: 187 | sample = sample + mid_block_additional_residual 188 | 189 | # 5. 
up 190 | for i, upsample_block in enumerate(self.up_blocks): 191 | is_final_block = i == len(self.up_blocks) - 1 192 | 193 | res_samples = down_block_res_samples[-len(upsample_block.resnets):] 194 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 195 | 196 | # if we have not reached the final block and need to forward the 197 | # upsample size, we do it here 198 | if not is_final_block and forward_upsample_size: 199 | upsample_size = down_block_res_samples[-1].shape[2:] 200 | 201 | if not control_hidden_states is None and f'up3_{i}' in control_hidden_states: 202 | sample += rearrange(control_hidden_states[f'up3_{i}'], "b c f h w -> (b f) c h w") 203 | if hasattr(self, "reference_modules_up") and not reference_hidden_states is None and f'up_{i}' in reference_hidden_states: 204 | sample = self.reference_modules_up[i](sample, reference_hidden_states[f'up_{i}'], num_frames) 205 | if hasattr(self, 'motion_up') and use_motion: 206 | sample = self.motion_up[i](sample, motion_pad_hidden_states[f'up_{i}'], emb, num_frames) 207 | 208 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 209 | sample = upsample_block( 210 | hidden_states=sample, 211 | temb=emb, 212 | res_hidden_states_tuple=res_samples, 213 | encoder_hidden_states=encoder_hidden_states, 214 | upsample_size=upsample_size, 215 | attention_mask=attention_mask, 216 | num_frames=num_frames, 217 | cross_attention_kwargs=cross_attention_kwargs, 218 | ) 219 | else: 220 | sample = upsample_block( 221 | hidden_states=sample, 222 | temb=emb, 223 | res_hidden_states_tuple=res_samples, 224 | upsample_size=upsample_size, 225 | num_frames=num_frames, 226 | ) 227 | 228 | # 6. post-process 229 | if self.conv_norm_out: 230 | sample = self.conv_norm_out(sample) 231 | sample = self.conv_act(sample) 232 | 233 | sample = self.conv_out(sample) 234 | 235 | # reshape to (batch, channel, framerate, width, height) 236 | sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4) 237 | 238 | if not return_dict: 239 | return (sample,) 240 | 241 | return UNetMotionOutput(sample=sample) -------------------------------------------------------------------------------- /hellomeme/models/hm_denoising_3d.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | @File : models6/hm_denoising_3d.py 5 | @Author : Songkey 6 | @Email : songkey@pku.edu.cn 7 | @Date : 8/14/2024 8 | @Desc : 删除实验代码,精简结构 9 | adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unets/unet_2d_condition.py 10 | """ 11 | 12 | import torch 13 | import torch.utils.checkpoint 14 | from typing import Any, Dict, Optional, Tuple, Union 15 | 16 | from einops import rearrange 17 | 18 | from diffusers.utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers 19 | from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel, UNet2DConditionOutput 20 | from .hm_adapters import CopyWeights, InsertReferenceAdapter 21 | 22 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 23 | 24 | 25 | class HMDenoising3D(UNet2DConditionModel, CopyWeights, InsertReferenceAdapter): 26 | def forward( 27 | self, 28 | sample: torch.Tensor, 29 | timestep: Union[torch.Tensor, float, int], 30 | encoder_hidden_states: torch.Tensor, 31 | reference_hidden_states: Optional[dict] = None, 32 | control_hidden_states: Optional[torch.Tensor] = None, 33 | class_labels: 
Optional[torch.Tensor] = None, 34 | timestep_cond: Optional[torch.Tensor] = None, 35 | attention_mask: Optional[torch.Tensor] = None, 36 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 37 | added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, 38 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, 39 | mid_block_additional_residual: Optional[torch.Tensor] = None, 40 | down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, 41 | encoder_attention_mask: Optional[torch.Tensor] = None, 42 | return_dict: bool = True, 43 | ) -> Union[UNet2DConditionOutput, Tuple]: 44 | # By default samples have to be AT least a multiple of the overall upsampling factor. 45 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). 46 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 47 | # on the fly if necessary. 48 | default_overall_up_factor = 2**self.num_upsamplers 49 | 50 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 51 | forward_upsample_size = False 52 | upsample_size = None 53 | 54 | for dim in sample.shape[-2:]: 55 | if dim % default_overall_up_factor != 0: 56 | # Forward upsample size to force interpolation output size. 57 | forward_upsample_size = True 58 | break 59 | 60 | # ensure attention_mask is a bias, and give it a singleton query_tokens dimension 61 | # expects mask of shape: 62 | # [batch, key_tokens] 63 | # adds singleton query_tokens dimension: 64 | # [batch, 1, key_tokens] 65 | # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: 66 | # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) 67 | # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) 68 | if attention_mask is not None: 69 | # assume that mask is expressed as: 70 | # (1 = keep, 0 = discard) 71 | # convert mask into a bias that can be added to attention scores: 72 | # (keep = +0, discard = -10000.0) 73 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 74 | attention_mask = attention_mask.unsqueeze(1) 75 | 76 | # convert encoder_attention_mask to a bias the same way we do for attention_mask 77 | if encoder_attention_mask is not None: 78 | encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 79 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1) 80 | 81 | # 0. center input if necessary 82 | if self.config.center_input_sample: 83 | sample = 2 * sample - 1.0 84 | 85 | # 1. 
time 86 | t_emb = self.get_time_embed(sample=sample, timestep=timestep) 87 | emb = self.time_embedding(t_emb, timestep_cond) 88 | aug_emb = None 89 | 90 | class_emb = self.get_class_embed(sample=sample, class_labels=class_labels) 91 | if class_emb is not None: 92 | if self.config.class_embeddings_concat: 93 | emb = torch.cat([emb, class_emb], dim=-1) 94 | else: 95 | emb = emb + class_emb 96 | 97 | aug_emb = self.get_aug_embed( 98 | emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs 99 | ) 100 | if self.config.addition_embed_type == "image_hint": 101 | aug_emb, hint = aug_emb 102 | sample = torch.cat([sample, hint], dim=1) 103 | 104 | emb = emb + aug_emb if aug_emb is not None else emb 105 | 106 | if self.time_embed_act is not None: 107 | emb = self.time_embed_act(emb) 108 | 109 | num_frames = sample.shape[2] 110 | emb = emb.repeat_interleave(repeats=num_frames, dim=0) 111 | 112 | if len(encoder_hidden_states.shape) == 3: 113 | encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0) 114 | elif len(encoder_hidden_states.shape) == 4: 115 | encoder_hidden_states = rearrange(encoder_hidden_states, "b f l d -> (b f) l d") 116 | 117 | if not added_cond_kwargs is None and 'image_embeds' in added_cond_kwargs: 118 | if isinstance(added_cond_kwargs['image_embeds'], torch.Tensor): 119 | added_cond_kwargs['image_embeds'] = added_cond_kwargs['image_embeds'].repeat_interleave(repeats=num_frames, dim=0) 120 | else: 121 | added_cond_kwargs['image_embeds'] = [x.repeat_interleave(repeats=num_frames, dim=0) for x in added_cond_kwargs['image_embeds']] 122 | 123 | encoder_hidden_states = self.process_encoder_hidden_states( 124 | encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs 125 | ) 126 | 127 | # 2. pre-process 128 | sample = rearrange(sample, "b c f h w -> (b f) c h w") 129 | sample = self.conv_in(sample) 130 | 131 | # 2.5 GLIGEN position net 132 | if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: 133 | cross_attention_kwargs = cross_attention_kwargs.copy() 134 | gligen_args = cross_attention_kwargs.pop("gligen") 135 | cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} 136 | 137 | # 3. down 138 | # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated 139 | # to the internal blocks and will raise deprecation warnings. this will be confusing for our users. 
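        # (editor's note, not in the original source) A few lines below this variant also fills a
        # `res_cache` dict (keys "down_{idx}", "mid", "up_{i}") with clones of the intermediate
        # activations and returns it alongside the sample. Judging from the HM pipelines in this
        # repo, the assumed usage is to run the reference image through a twin UNet once and feed
        # the returned cache back in as `reference_hidden_states`, roughly:
        #
        #     cached_res = ref_unet(ref_latents.unsqueeze(2), 0,
        #                           encoder_hidden_states=prompt_embeds, return_dict=False)[1]
        #     noise_pred = unet(latents.unsqueeze(2), t,
        #                       encoder_hidden_states=prompt_embeds,
        #                       reference_hidden_states=cached_res, return_dict=False)[0]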
140 | if cross_attention_kwargs is not None: 141 | cross_attention_kwargs = cross_attention_kwargs.copy() 142 | lora_scale = cross_attention_kwargs.pop("scale", 1.0) 143 | else: 144 | lora_scale = 1.0 145 | 146 | if USE_PEFT_BACKEND: 147 | # weight the lora layers by setting `lora_scale` for each PEFT layer 148 | scale_lora_layers(self, lora_scale) 149 | 150 | is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None 151 | # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets 152 | is_adapter = down_intrablock_additional_residuals is not None 153 | # maintain backward compatibility for legacy usage, where 154 | # T2I-Adapter and ControlNet both use down_block_additional_residuals arg 155 | # but can only use one or the other 156 | if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: 157 | deprecate( 158 | "T2I should not use down_block_additional_residuals", 159 | "1.3.0", 160 | "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ 161 | and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ 162 | for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", 163 | standard_warn=False, 164 | ) 165 | down_intrablock_additional_residuals = down_block_additional_residuals 166 | is_adapter = True 167 | 168 | res_cache = dict() 169 | down_block_res_samples = (sample,) 170 | for idx, downsample_block in enumerate(self.down_blocks): 171 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 172 | # For t2i-adapter CrossAttnDownBlock2D 173 | additional_residuals = {} 174 | if is_adapter and len(down_intrablock_additional_residuals) > 0: 175 | additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) 176 | 177 | sample, res_samples = downsample_block( 178 | hidden_states=sample, 179 | temb=emb, 180 | encoder_hidden_states=encoder_hidden_states, 181 | attention_mask=attention_mask, 182 | cross_attention_kwargs=cross_attention_kwargs, 183 | encoder_attention_mask=encoder_attention_mask, 184 | **additional_residuals, 185 | ) 186 | res_cache[f"down_{idx}"] = sample.clone() 187 | else: 188 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 189 | if is_adapter and len(down_intrablock_additional_residuals) > 0: 190 | sample += down_intrablock_additional_residuals.pop(0) 191 | 192 | if not control_hidden_states is None and f'down_{idx}' in control_hidden_states: 193 | sample += rearrange(control_hidden_states[f'down_{idx}'], "b c f h w -> (b f) c h w") 194 | if not control_hidden_states is None and f'down2_{idx}' in control_hidden_states: 195 | sample += rearrange(control_hidden_states[f'down2_{idx}'], "b c f h w -> (b f) c h w") 196 | if hasattr(self, 'reference_modules_down') and not reference_hidden_states is None and f'down_{idx}' in reference_hidden_states: 197 | sample = self.reference_modules_down[idx](sample, reference_hidden_states[f'down_{idx}'], num_frames) 198 | 199 | down_block_res_samples += res_samples 200 | 201 | if is_controlnet: 202 | new_down_block_res_samples = () 203 | 204 | for down_block_res_sample, down_block_additional_residual in zip( 205 | down_block_res_samples, down_block_additional_residuals 206 | ): 207 | down_block_res_sample = down_block_res_sample + down_block_additional_residual 208 | new_down_block_res_samples = 
new_down_block_res_samples + (down_block_res_sample,) 209 | 210 | down_block_res_samples = new_down_block_res_samples 211 | 212 | # 4. mid 213 | if self.mid_block is not None: 214 | if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: 215 | sample = self.mid_block( 216 | sample, 217 | emb, 218 | encoder_hidden_states=encoder_hidden_states, 219 | attention_mask=attention_mask, 220 | cross_attention_kwargs=cross_attention_kwargs, 221 | encoder_attention_mask=encoder_attention_mask, 222 | ) 223 | else: 224 | sample = self.mid_block(sample, emb) 225 | if hasattr(self, 'reference_modules_mid') and not reference_hidden_states is None and f'mid' in reference_hidden_states: 226 | sample = self.reference_modules_mid(sample, reference_hidden_states[f'mid'], num_frames) 227 | 228 | # To support T2I-Adapter-XL 229 | if ( 230 | is_adapter 231 | and len(down_intrablock_additional_residuals) > 0 232 | and sample.shape == down_intrablock_additional_residuals[0].shape 233 | ): 234 | sample += down_intrablock_additional_residuals.pop(0) 235 | res_cache[f"mid"] = sample.clone() 236 | 237 | if is_controlnet: 238 | sample = sample + mid_block_additional_residual 239 | 240 | # 5. up 241 | for i, upsample_block in enumerate(self.up_blocks): 242 | is_final_block = i == len(self.up_blocks) - 1 243 | 244 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 245 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 246 | 247 | # if we have not reached the final block and need to forward the 248 | # upsample size, we do it here 249 | if not is_final_block and forward_upsample_size: 250 | upsample_size = down_block_res_samples[-1].shape[2:] 251 | 252 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 253 | res_cache[f"up_{i}"] = sample.clone() 254 | if not control_hidden_states is None and f'up_v2_{i}' in control_hidden_states: 255 | sample += rearrange(control_hidden_states[f'up_v2_{i}'], "b c f h w -> (b f) c h w") 256 | if not control_hidden_states is None and f'up2_v2_{i}' in control_hidden_states: 257 | sample += rearrange(control_hidden_states[f'up2_v2_{i}'], "b c f h w -> (b f) c h w") 258 | if hasattr(self, "reference_modules_up") and not reference_hidden_states is None and f'up_{i}' in reference_hidden_states: 259 | sample = self.reference_modules_up[i-1](sample, reference_hidden_states[f'up_{i}'], num_frames) 260 | 261 | sample = upsample_block( 262 | hidden_states=sample, 263 | temb=emb, 264 | res_hidden_states_tuple=res_samples, 265 | encoder_hidden_states=encoder_hidden_states, 266 | cross_attention_kwargs=cross_attention_kwargs, 267 | upsample_size=upsample_size, 268 | attention_mask=attention_mask, 269 | encoder_attention_mask=encoder_attention_mask, 270 | ) 271 | else: 272 | if not control_hidden_states is None and f'up_v2_{i}' in control_hidden_states: 273 | sample += rearrange(control_hidden_states[f'up_v2_{i}'], "b c f h w -> (b f) c h w") 274 | if not control_hidden_states is None and f'up2_v2_{i}' in control_hidden_states: 275 | sample += rearrange(control_hidden_states[f'up2_v2_{i}'], "b c f h w -> (b f) c h w") 276 | sample = upsample_block( 277 | hidden_states=sample, 278 | temb=emb, 279 | res_hidden_states_tuple=res_samples, 280 | upsample_size=upsample_size, 281 | ) 282 | 283 | # 6. 
post-process 284 | if self.conv_norm_out: 285 | sample = self.conv_norm_out(sample) 286 | sample = self.conv_act(sample) 287 | sample = self.conv_out(sample) 288 | 289 | if USE_PEFT_BACKEND: 290 | # remove `lora_scale` from each PEFT layer 291 | unscale_lora_layers(self, lora_scale) 292 | 293 | # reshape to (batch, channel, framerate, width, height) 294 | sample = rearrange(sample, "(b f) c h w -> b c f h w", f=num_frames) 295 | 296 | if not return_dict: 297 | return (sample, res_cache) 298 | 299 | return (UNet2DConditionOutput(sample=sample), res_cache) 300 | -------------------------------------------------------------------------------- /hellomeme/models/hm_denoising_motion.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | @File : models6/hm_denoising_motion.py 5 | @Author : Songkey 6 | @Email : songkey@pku.edu.cn 7 | @Date : 9/9/2024 8 | @Desc : 9 | adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unets/unet_motion_model.py 10 | """ 11 | 12 | import torch 13 | import torch.utils.checkpoint 14 | from typing import Any, Dict, Optional, Tuple, Union 15 | 16 | from einops import rearrange 17 | 18 | from diffusers.utils import logging 19 | from diffusers.models.unets.unet_motion_model import UNetMotionModel, UNetMotionOutput 20 | from .hm_adapters import InsertReferenceAdapter 21 | 22 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 23 | 24 | 25 | class HMDenoisingMotion(UNetMotionModel, InsertReferenceAdapter): 26 | def forward( 27 | self, 28 | sample: torch.Tensor, 29 | timestep: Union[torch.Tensor, float, int], 30 | encoder_hidden_states: torch.Tensor, 31 | reference_hidden_states: Optional[dict] = None, 32 | control_hidden_states: Optional[torch.Tensor] = None, 33 | timestep_cond: Optional[torch.Tensor] = None, 34 | attention_mask: Optional[torch.Tensor] = None, 35 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 36 | added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, 37 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, 38 | mid_block_additional_residual: Optional[torch.Tensor] = None, 39 | return_dict: bool = True, 40 | ) -> Union[UNetMotionOutput, Tuple[torch.Tensor]]: 41 | 42 | # By default samples have to be AT least a multiple of the overall upsampling factor. 43 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). 44 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 45 | # on the fly if necessary. 46 | default_overall_up_factor = 2 ** self.num_upsamplers 47 | 48 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 49 | forward_upsample_size = False 50 | upsample_size = None 51 | 52 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 53 | logger.info("Forward upsample size to force interpolation output size.") 54 | forward_upsample_size = True 55 | 56 | # prepare attention_mask 57 | if attention_mask is not None: 58 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 59 | attention_mask = attention_mask.unsqueeze(1) 60 | 61 | # 1. time 62 | timesteps = timestep 63 | if not torch.is_tensor(timesteps): 64 | # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can 65 | # This would be a good case for the `match` statement (Python 3.10+) 66 | is_mps = sample.device.type == "mps" 67 | if isinstance(timestep, float): 68 | dtype = torch.float32 if is_mps else torch.float64 69 | else: 70 | dtype = torch.int32 if is_mps else torch.int64 71 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 72 | elif len(timesteps.shape) == 0: 73 | timesteps = timesteps[None].to(sample.device) 74 | 75 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 76 | num_frames = sample.shape[2] 77 | timesteps = timesteps.expand(sample.shape[0]) 78 | 79 | t_emb = self.time_proj(timesteps) 80 | 81 | # timesteps does not contain any weights and will always return f32 tensors 82 | # but time_embedding might actually be running in fp16. so we need to cast here. 83 | # there might be better ways to encapsulate this. 84 | t_emb = t_emb.to(dtype=self.dtype) 85 | 86 | emb = self.time_embedding(t_emb, timestep_cond) 87 | aug_emb = None 88 | 89 | if self.config.addition_embed_type == "text_time": 90 | if "text_embeds" not in added_cond_kwargs: 91 | raise ValueError( 92 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" 93 | ) 94 | 95 | text_embeds = added_cond_kwargs.get("text_embeds") 96 | if "time_ids" not in added_cond_kwargs: 97 | raise ValueError( 98 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" 99 | ) 100 | time_ids = added_cond_kwargs.get("time_ids") 101 | time_embeds = self.add_time_proj(time_ids.flatten()) 102 | time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) 103 | 104 | add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) 105 | add_embeds = add_embeds.to(emb.dtype) 106 | aug_emb = self.add_embedding(add_embeds) 107 | 108 | emb = emb if aug_emb is None else emb + aug_emb 109 | emb = emb.repeat_interleave(repeats=num_frames, dim=0) 110 | if len(encoder_hidden_states.shape) == 3: 111 | encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0) 112 | elif len(encoder_hidden_states.shape) == 4: 113 | encoder_hidden_states = rearrange(encoder_hidden_states, "b f l d -> (b f) l d") 114 | 115 | if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": 116 | if "image_embeds" not in added_cond_kwargs: 117 | raise ValueError( 118 | f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" 119 | ) 120 | image_embeds = added_cond_kwargs.get("image_embeds") 121 | image_embeds = self.encoder_hid_proj(image_embeds) 122 | image_embeds = [image_embed.repeat_interleave(repeats=num_frames, dim=0) for image_embed in image_embeds] 123 | encoder_hidden_states = (encoder_hidden_states, image_embeds) 124 | 125 | # 2. pre-process 126 | sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) 127 | sample = self.conv_in(sample) 128 | 129 | # 3. 
down 130 | down_block_res_samples = (sample,) 131 | for idx, downsample_block in enumerate(self.down_blocks): 132 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 133 | sample, res_samples = downsample_block( 134 | hidden_states=sample, 135 | temb=emb, 136 | encoder_hidden_states=encoder_hidden_states, 137 | attention_mask=attention_mask, 138 | num_frames=num_frames, 139 | cross_attention_kwargs=cross_attention_kwargs, 140 | ) 141 | else: 142 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames) 143 | 144 | if not control_hidden_states is None and f'down_{idx}' in control_hidden_states: 145 | sample += rearrange(control_hidden_states[f'down_{idx}'], "b c f h w -> (b f) c h w") 146 | if not control_hidden_states is None and f'down2_{idx}' in control_hidden_states: 147 | sample += rearrange(control_hidden_states[f'down2_{idx}'], "b c f h w -> (b f) c h w") 148 | 149 | if hasattr(self, 'reference_modules_down') and not reference_hidden_states is None and f'down_{idx}' in reference_hidden_states: 150 | sample = self.reference_modules_down[idx](sample, reference_hidden_states[f'down_{idx}'], num_frames) 151 | 152 | down_block_res_samples += res_samples 153 | 154 | if down_block_additional_residuals is not None: 155 | new_down_block_res_samples = () 156 | 157 | for down_block_res_sample, down_block_additional_residual in zip( 158 | down_block_res_samples, down_block_additional_residuals 159 | ): 160 | down_block_res_sample = down_block_res_sample + down_block_additional_residual 161 | new_down_block_res_samples += (down_block_res_sample,) 162 | 163 | down_block_res_samples = new_down_block_res_samples 164 | 165 | # 4. mid 166 | if self.mid_block is not None: 167 | # To support older versions of motion modules that don't have a mid_block 168 | if hasattr(self.mid_block, "motion_modules"): 169 | sample = self.mid_block( 170 | sample, 171 | emb, 172 | encoder_hidden_states=encoder_hidden_states, 173 | attention_mask=attention_mask, 174 | num_frames=num_frames, 175 | cross_attention_kwargs=cross_attention_kwargs, 176 | ) 177 | else: 178 | sample = self.mid_block( 179 | sample, 180 | emb, 181 | encoder_hidden_states=encoder_hidden_states, 182 | attention_mask=attention_mask, 183 | cross_attention_kwargs=cross_attention_kwargs, 184 | ) 185 | if hasattr(self, 'reference_modules_mid') and not reference_hidden_states is None and f'mid' in reference_hidden_states: 186 | sample = self.reference_modules_mid(sample, reference_hidden_states[f'mid'], num_frames) 187 | 188 | if mid_block_additional_residual is not None: 189 | sample = sample + mid_block_additional_residual 190 | 191 | # 5. 
up 192 | for i, upsample_block in enumerate(self.up_blocks): 193 | is_final_block = i == len(self.up_blocks) - 1 194 | 195 | res_samples = down_block_res_samples[-len(upsample_block.resnets):] 196 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 197 | 198 | # if we have not reached the final block and need to forward the 199 | # upsample size, we do it here 200 | if not is_final_block and forward_upsample_size: 201 | upsample_size = down_block_res_samples[-1].shape[2:] 202 | 203 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 204 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 205 | if not control_hidden_states is None and f'up_v2_{i}' in control_hidden_states: 206 | sample += rearrange(control_hidden_states[f'up_v2_{i}'], "b c f h w -> (b f) c h w") 207 | if not control_hidden_states is None and f'up2_v2_{i}' in control_hidden_states: 208 | sample += rearrange(control_hidden_states[f'up2_v2_{i}'], "b c f h w -> (b f) c h w") 209 | if hasattr(self, 210 | "reference_modules_up") and not reference_hidden_states is None and f'up_{i}' in reference_hidden_states: 211 | sample = self.reference_modules_up[i - 1](sample, reference_hidden_states[f'up_{i}'], 212 | num_frames) 213 | 214 | sample = upsample_block( 215 | hidden_states=sample, 216 | temb=emb, 217 | res_hidden_states_tuple=res_samples, 218 | encoder_hidden_states=encoder_hidden_states, 219 | upsample_size=upsample_size, 220 | attention_mask=attention_mask, 221 | num_frames=num_frames, 222 | cross_attention_kwargs=cross_attention_kwargs, 223 | ) 224 | else: 225 | if not control_hidden_states is None and f'up_v2_{i}' in control_hidden_states: 226 | sample += rearrange(control_hidden_states[f'up_v2_{i}'], "b c f h w -> (b f) c h w") 227 | if not control_hidden_states is None and f'up2_v2_{i}' in control_hidden_states: 228 | sample += rearrange(control_hidden_states[f'up2_v2_{i}'], "b c f h w -> (b f) c h w") 229 | sample = upsample_block( 230 | hidden_states=sample, 231 | temb=emb, 232 | res_hidden_states_tuple=res_samples, 233 | upsample_size=upsample_size, 234 | num_frames=num_frames, 235 | ) 236 | 237 | # 6. 
post-process 238 | if self.conv_norm_out: 239 | sample = self.conv_norm_out(sample) 240 | sample = self.conv_act(sample) 241 | 242 | sample = self.conv_out(sample) 243 | 244 | # reshape to (batch, channel, framerate, width, height) 245 | sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4) 246 | 247 | if not return_dict: 248 | return (sample,) 249 | 250 | return UNetMotionOutput(sample=sample) -------------------------------------------------------------------------------- /hellomeme/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | @File : __init__.py.py 5 | @Author : Songkey 6 | @Email : songkey@pku.edu.cn 7 | @Date : 8/29/2024 8 | @Desc : 9 | """ 10 | 11 | from .pipline_hm_image import HMImagePipeline 12 | from .pipline_hm_video import HMVideoPipeline 13 | from .pipline_hm3_image import HM3ImagePipeline 14 | from .pipline_hm3_video import HM3VideoPipeline 15 | from .pipline_hm5_image import HM5ImagePipeline 16 | from .pipline_hm5_video import HM5VideoPipeline -------------------------------------------------------------------------------- /hellomeme/pipelines/pipline_hm3_image.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | @File : hm_pipline_image.py 5 | @Author : Songkey 6 | @Email : songkey@pku.edu.cn 7 | @Date : 1/3/2025 8 | @Desc : 9 | adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py 10 | """ 11 | 12 | import copy 13 | from typing import Any, Callable, Dict, List, Optional, Union 14 | import torch 15 | 16 | from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback 17 | from diffusers.image_processor import PipelineImageInput 18 | from diffusers.utils import deprecate 19 | from diffusers.utils.torch_utils import randn_tensor 20 | from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput 21 | from diffusers import DPMSolverMultistepScheduler 22 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import retrieve_timesteps, retrieve_latents 23 | from ..models import HM3Denoising3D, HMV3ControlNet, HMPipeline, HM3ReferenceAdapter, HMControlNetBase, HM4SD15ControlProj 24 | 25 | class HM3ImagePipeline(HMPipeline): 26 | def caryomitosis(self, **kwargs): 27 | if hasattr(self, "unet_ref"): 28 | del self.unet_ref 29 | self.unet_ref = HM3Denoising3D.from_unet2d(self.unet) 30 | self.unet_ref.cpu() 31 | 32 | if not isinstance(self.unet, HM3Denoising3D): 33 | unet = HM3Denoising3D.from_unet2d(unet=self.unet) 34 | # todo: 不够优雅 35 | del self.unet 36 | self.unet = unet 37 | self.unet.cpu() 38 | 39 | self.vae.cpu() 40 | self.vae_decode = copy.deepcopy(self.vae) 41 | self.text_encoder.cpu() 42 | self.text_encoder_ref = copy.deepcopy(self.text_encoder) 43 | self.safety_checker.cpu() 44 | 45 | def insert_hm_modules(self, version='v3', dtype=torch.float16, modelscope=False): 46 | self.version = version 47 | if modelscope: 48 | from modelscope import snapshot_download 49 | if version == 'v3': 50 | hm_reference_dir = snapshot_download('songkey/hm3_reference') 51 | hm_control_dir = snapshot_download('songkey/hm3_control_mix') 52 | else: 53 | hm_reference_dir = snapshot_download('songkey/hm4_reference') 54 | hm_control_dir = snapshot_download('songkey/hm_control_base') 55 | hm_control_proj_dir = snapshot_download('songkey/hm4_control_proj') 
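        # (editor's note, not in the original source) The `else` branch below passes Hugging Face
        # Hub repo ids straight to `from_pretrained`, so the reference/control weights are
        # presumably fetched from the Hub on first use, whereas the ModelScope branch above
        # snapshots them to a local directory first. An assumed call order, inferred from how
        # these helpers depend on each other (`insert_hm_modules` uses the copies created in
        # `caryomitosis`):
        #
        #     pipe.caryomitosis()                               # clone unet / vae / text encoder
        #     pipe.insert_hm_modules(version='v3', dtype=torch.float16, modelscope=False)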
56 | else: 57 | if version == 'v3': 58 | hm_reference_dir = 'songkey/hm3_reference' 59 | hm_control_dir = 'songkey/hm3_control_mix' 60 | else: 61 | hm_reference_dir = 'songkey/hm4_reference' 62 | hm_control_dir = 'songkey/hm_control_base' 63 | hm_control_proj_dir = 'songkey/hm4_control_proj' 64 | 65 | if isinstance(self.unet, HM3Denoising3D): 66 | hm_adapter = HM3ReferenceAdapter.from_pretrained(hm_reference_dir) 67 | self.unet.insert_reference_adapter(hm_adapter) 68 | self.unet.to(device='cpu', dtype=dtype).eval() 69 | 70 | if hasattr(self, "unet_ref"): 71 | self.unet_ref.to(device='cpu', dtype=dtype).eval() 72 | 73 | if hasattr(self, "mp_control"): 74 | del self.mp_control 75 | 76 | if hasattr(self, "mp_control_proj"): 77 | del self.mp_control_proj 78 | 79 | if version == 'v3': 80 | self.mp_control = HMV3ControlNet.from_pretrained(hm_control_dir) 81 | else: 82 | self.mp_control = HMControlNetBase.from_pretrained(hm_control_dir) 83 | self.mp_control_proj = HM4SD15ControlProj.from_pretrained(hm_control_proj_dir) 84 | 85 | self.mp_control_proj.to(device='cpu', dtype=dtype).eval() 86 | 87 | self.mp_control.to(device='cpu', dtype=dtype).eval() 88 | 89 | self.vae.to(device='cpu', dtype=dtype).eval() 90 | self.vae_decode.to(device='cpu', dtype=dtype).eval() 91 | self.text_encoder.to(device='cpu', dtype=dtype).eval() 92 | 93 | @torch.no_grad() 94 | def __call__( 95 | self, 96 | prompt: Union[str, List[str]] = None, 97 | image: PipelineImageInput = None, 98 | drive_params: Dict[str, Any] = None, 99 | strength: float = 0.8, 100 | num_inference_steps: Optional[int] = 50, 101 | timesteps: List[int] = None, 102 | sigmas: List[float] = None, 103 | guidance_scale: Optional[float] = 7.5, 104 | negative_prompt: Optional[Union[str, List[str]]] = None, 105 | eta: Optional[float] = 0.0, 106 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 107 | prompt_embeds: Optional[torch.Tensor] = None, 108 | negative_prompt_embeds: Optional[torch.Tensor] = None, 109 | ip_adapter_image: Optional[PipelineImageInput] = None, 110 | ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, 111 | output_type: Optional[str] = "pil", 112 | device: Optional[str] = "cpu", 113 | return_dict: bool = True, 114 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 115 | clip_skip: int = None, 116 | callback_on_step_end: Optional[ 117 | Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] 118 | ] = None, 119 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 120 | **kwargs, 121 | ): 122 | callback = kwargs.pop("callback", None) 123 | callback_steps = kwargs.pop("callback_steps", None) 124 | num_images_per_prompt = 1 125 | 126 | if callback is not None: 127 | deprecate( 128 | "callback", 129 | "1.0.0", 130 | "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", 131 | ) 132 | if callback_steps is not None: 133 | deprecate( 134 | "callback_steps", 135 | "1.0.0", 136 | "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", 137 | ) 138 | 139 | if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): 140 | callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs 141 | 142 | # 1. Check inputs. 
Raise error if not correct 143 | self.check_inputs( 144 | prompt, 145 | strength, 146 | callback_steps, 147 | negative_prompt, 148 | prompt_embeds, 149 | negative_prompt_embeds, 150 | ip_adapter_image, 151 | ip_adapter_image_embeds, 152 | callback_on_step_end_tensor_inputs, 153 | ) 154 | 155 | self._guidance_scale = guidance_scale 156 | self._clip_skip = clip_skip 157 | self._cross_attention_kwargs = cross_attention_kwargs 158 | self._interrupt = False 159 | 160 | # 2. Define call parameters 161 | if prompt is not None and isinstance(prompt, str): 162 | batch_size = 1 163 | elif prompt is not None and isinstance(prompt, list): 164 | batch_size = len(prompt) 165 | else: 166 | batch_size = prompt_embeds.shape[0] 167 | 168 | # device = self.device 169 | 170 | # 3. Encode input prompt 171 | text_encoder_lora_scale = ( 172 | self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None 173 | ) 174 | 175 | self.text_encoder_ref.to(device=device) 176 | prompt_embeds_ref, negative_prompt_embeds_ref = self.encode_prompt_sk( 177 | self.text_encoder_ref, 178 | prompt, 179 | device, 180 | num_images_per_prompt, 181 | self.do_classifier_free_guidance, 182 | negative_prompt, 183 | prompt_embeds=prompt_embeds, 184 | negative_prompt_embeds=negative_prompt_embeds, 185 | lora_scale=text_encoder_lora_scale, 186 | clip_skip=self.clip_skip, 187 | ) 188 | self.text_encoder_ref.cpu() 189 | 190 | self.text_encoder.to(device=device) 191 | prompt_embeds, negative_prompt_embeds = self.encode_prompt_sk( 192 | self.text_encoder, 193 | prompt, 194 | device, 195 | num_images_per_prompt, 196 | self.do_classifier_free_guidance, 197 | negative_prompt, 198 | prompt_embeds=prompt_embeds, 199 | negative_prompt_embeds=negative_prompt_embeds, 200 | lora_scale=text_encoder_lora_scale, 201 | clip_skip=self.clip_skip, 202 | ) 203 | self.text_encoder.cpu() 204 | 205 | # For classifier free guidance, we need to do two forward passes. 206 | # Here we concatenate the unconditional and text embeddings into a single batch 207 | # to avoid doing two forward passes 208 | if self.do_classifier_free_guidance: 209 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) 210 | prompt_embeds_ref = torch.cat([negative_prompt_embeds_ref, prompt_embeds_ref]) 211 | 212 | if ip_adapter_image is not None or ip_adapter_image_embeds is not None: 213 | image_embeds = self.prepare_ip_adapter_image_embeds( 214 | ip_adapter_image, 215 | ip_adapter_image_embeds, 216 | device, 217 | batch_size * num_images_per_prompt, 218 | self.do_classifier_free_guidance, 219 | ) 220 | 221 | # 4. Preprocess 222 | image = self.image_processor.preprocess(image).to(device=device, dtype=prompt_embeds.dtype) 223 | 224 | scheduler = DPMSolverMultistepScheduler( 225 | num_train_timesteps=1000, 226 | beta_start=0.00085, 227 | beta_end=0.012, 228 | beta_schedule="scaled_linear", 229 | # use_karras_sigmas=True, 230 | algorithm_type="sde-dpmsolver++", 231 | ) 232 | 233 | # 5. set timesteps 234 | timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device, timesteps, sigmas) 235 | 236 | # 6. 
Prepare reference latents 237 | self.vae.to(device=device) 238 | ref_latents = [ 239 | retrieve_latents(self.vae.encode(image[i: i + 1].to(device=device)), generator=generator) 240 | for i in range(batch_size) 241 | ] 242 | self.vae.cpu() 243 | 244 | ref_latents = torch.cat(ref_latents, dim=0) 245 | ref_latents = self.vae.config.scaling_factor * ref_latents 246 | c, h, w = ref_latents.shape[1:] 247 | 248 | condition = drive_params['condition'].clone().to(device=device) 249 | if self.do_classifier_free_guidance: 250 | condition = torch.cat([torch.ones_like(condition) * -1, condition], dim=0) 251 | 252 | control_latents = {} 253 | self.mp_control.to(device=device) 254 | if hasattr(self, 'mp_control_proj') and self.version == 'v4': 255 | self.mp_control_proj.to(device=device) 256 | if 'drive_coeff' in drive_params: 257 | drive_coeff = drive_params['drive_coeff'].clone().to(device=device) 258 | face_parts = drive_params['face_parts'].clone().to(device=device) 259 | if self.do_classifier_free_guidance: 260 | drive_coeff = torch.cat([torch.zeros_like(drive_coeff), drive_coeff], dim=0) 261 | face_parts = torch.cat([torch.zeros_like(face_parts), face_parts], dim=0) 262 | control_latents1 = self.mp_control(condition=condition, drive_coeff=drive_coeff, face_parts=face_parts) 263 | if self.version == 'v4': 264 | control_latents1 = self.mp_control_proj(control_latents1) 265 | control_latents.update(control_latents1) 266 | elif 'pd_fpg' in drive_params: 267 | pd_fpg = drive_params['pd_fpg'].clone().to(device=device) 268 | if self.do_classifier_free_guidance: 269 | pd_fpg = torch.cat([torch.zeros_like(pd_fpg), pd_fpg], dim=0) 270 | control_latents2 = self.mp_control(condition=condition, emo_embedding=pd_fpg) 271 | if self.version == 'v4': 272 | control_latents2 = self.mp_control_proj(control_latents2) 273 | control_latents.update(control_latents2) 274 | self.mp_control.cpu() 275 | if self.version == 'v4': 276 | self.mp_control_proj.cpu() 277 | 278 | # 7. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline 279 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 280 | 281 | # 7.1 Add image embeds for IP-Adapter 282 | added_cond_kwargs = ( 283 | {"image_embeds": image_embeds} 284 | if ip_adapter_image is not None or ip_adapter_image_embeds is not None 285 | else None 286 | ) 287 | base_noise = randn_tensor([batch_size, c, h, w], dtype=prompt_embeds.dtype, generator=generator).to(device=device) 288 | 289 | latent_model_input = torch.cat([torch.zeros_like(ref_latents), ref_latents]) if ( 290 | self.do_classifier_free_guidance) else ref_latents 291 | # latent_model_input = torch.cat([ref_latents_neg, ref_latents], dim=0) 292 | self.unet_ref.to(device=device) 293 | cached_res = self.unet_ref( 294 | latent_model_input.unsqueeze(2), 295 | 0, 296 | encoder_hidden_states=prompt_embeds_ref, 297 | return_dict=False, 298 | )[1] 299 | self.unet_ref.cpu() 300 | 301 | # 7.2 Optionally get Guidance Scale Embedding 302 | timestep_cond = None 303 | if self.unet.config.time_cond_proj_dim is not None: 304 | guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) 305 | timestep_cond = self.get_guidance_scale_embedding( 306 | guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim 307 | ).to(device=device, dtype=prompt_embeds.dtype) 308 | 309 | latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) 310 | # base_noise = randn_tensor([batch_size, c, h, w], dtype=prompt_embeds.dtype, generator=generator).to(device=device) 311 | latents = base_noise * scheduler.init_noise_sigma 312 | # 8. Denoising loop 313 | num_warmup_steps = len(timesteps) - num_inference_steps * scheduler.order 314 | self._num_timesteps = len(timesteps) 315 | self.unet.to(device=device) 316 | with self.progress_bar(total=num_inference_steps) as progress_bar: 317 | for i, t in enumerate(timesteps): 318 | if self.interrupt: 319 | continue 320 | 321 | # expand the latents if we are doing classifier free guidance 322 | latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents 323 | latent_model_input = scheduler.scale_model_input(latent_model_input, t) 324 | 325 | # predict the noise residual 326 | noise_pred = self.unet( 327 | latent_model_input.unsqueeze(2), 328 | t, 329 | encoder_hidden_states=prompt_embeds, 330 | reference_hidden_states=cached_res, 331 | control_hidden_states=control_latents, 332 | timestep_cond=timestep_cond, 333 | cross_attention_kwargs=self.cross_attention_kwargs, 334 | added_cond_kwargs=added_cond_kwargs, 335 | return_dict=False, 336 | )[0][:,:,0,:,:] 337 | 338 | # perform guidance 339 | if self.do_classifier_free_guidance: 340 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 341 | noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) 342 | 343 | # compute the previous noisy sample x_t -> x_t-1 344 | latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] 345 | 346 | if callback_on_step_end is not None: 347 | callback_kwargs = {} 348 | for k in callback_on_step_end_tensor_inputs: 349 | callback_kwargs[k] = locals()[k] 350 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 351 | 352 | latents = callback_outputs.pop("latents", latents) 353 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 354 | negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) 355 | 356 | # call 
the callback, if provided 357 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0): 358 | progress_bar.update() 359 | if callback is not None and i % callback_steps == 0: 360 | step_idx = i // getattr(scheduler, "order", 1) 361 | callback(step_idx, t, latents) 362 | 363 | self.unet.cpu() 364 | 365 | self.vae_decode.to(device=device) 366 | if not output_type == "latent": 367 | image = self.vae_decode.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ 368 | 0 369 | ] 370 | else: 371 | image = latents 372 | self.vae_decode.cpu() 373 | 374 | do_denormalize = [True] * image.shape[0] 375 | 376 | image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) 377 | 378 | # Offload all models 379 | self.maybe_free_model_hooks() 380 | 381 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None), latents.detach().cpu() / self.vae.config.scaling_factor 382 | -------------------------------------------------------------------------------- /hellomeme/pipelines/pipline_hm5_image.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | @File : hm_pipline_image.py 5 | @Author : Songkey 6 | @Email : songkey@pku.edu.cn 7 | @Date : 1/3/2025 8 | @Desc : 9 | adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py 10 | """ 11 | 12 | import copy 13 | from typing import Any, Callable, Dict, List, Optional, Union 14 | import torch 15 | 16 | from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback 17 | from diffusers.image_processor import PipelineImageInput 18 | from diffusers.utils import deprecate 19 | from diffusers.utils.torch_utils import randn_tensor 20 | from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput 21 | from diffusers import DPMSolverMultistepScheduler 22 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import retrieve_timesteps, retrieve_latents 23 | from ..models import (HM3Denoising3D, 24 | HMPipeline, HM5ReferenceAdapter, 25 | HM5ControlNetBase, 26 | HM5SD15ControlProj) 27 | 28 | class HM5ImagePipeline(HMPipeline): 29 | def caryomitosis(self, **kwargs): 30 | if hasattr(self, "unet_ref"): 31 | del self.unet_ref 32 | self.unet_ref = HM3Denoising3D.from_unet2d(self.unet) 33 | self.unet_ref.cpu() 34 | 35 | if not isinstance(self.unet, HM3Denoising3D): 36 | unet = HM3Denoising3D.from_unet2d(unet=self.unet) 37 | # todo: 不够优雅 38 | del self.unet 39 | self.unet = unet 40 | self.unet.cpu() 41 | 42 | self.vae.cpu() 43 | self.vae_decode = copy.deepcopy(self.vae) 44 | self.text_encoder.cpu() 45 | self.text_encoder_ref = copy.deepcopy(self.text_encoder) 46 | self.safety_checker.cpu() 47 | 48 | def insert_hm_modules(self, version='v5', dtype=torch.float16, modelscope=False): 49 | 50 | self.version = version 51 | if modelscope: 52 | from modelscope import snapshot_download 53 | hm_reference_dir = snapshot_download('songkey/hm5_reference') 54 | hm_control_dir = snapshot_download('songkey/hm5_control_base') 55 | hm_control_proj_dir = snapshot_download('songkey/hm5_control_proj') 56 | else: 57 | hm_reference_dir = 'songkey/hm5_reference' 58 | hm_control_dir = 'songkey/hm5_control_base' 59 | hm_control_proj_dir = 'songkey/hm5_control_proj' 60 | 61 | if isinstance(self.unet, HM3Denoising3D): 62 | hm_adapter = 
HM5ReferenceAdapter.from_pretrained(hm_reference_dir) 63 | 64 | self.unet.insert_reference_adapter(hm_adapter) 65 | self.unet.to(device='cpu', dtype=dtype).eval() 66 | 67 | if hasattr(self, "unet_ref"): 68 | self.unet_ref.to(device='cpu', dtype=dtype).eval() 69 | 70 | if hasattr(self, "mp_control"): 71 | del self.mp_control 72 | 73 | if hasattr(self, "mp_control_proj"): 74 | del self.mp_control_proj 75 | 76 | self.mp_control = HM5ControlNetBase.from_pretrained(hm_control_dir) 77 | self.mp_control_proj = HM5SD15ControlProj.from_pretrained(hm_control_proj_dir) 78 | 79 | self.mp_control.to(device='cpu', dtype=dtype).eval() 80 | self.mp_control_proj.to(device='cpu', dtype=dtype).eval() 81 | 82 | self.vae.to(device='cpu', dtype=dtype).eval() 83 | self.vae_decode.to(device='cpu', dtype=dtype).eval() 84 | self.text_encoder.to(device='cpu', dtype=dtype).eval() 85 | 86 | @torch.no_grad() 87 | def __call__( 88 | self, 89 | prompt: Union[str, List[str]] = None, 90 | image: PipelineImageInput = None, 91 | drive_params: Dict[str, Any] = None, 92 | strength: float = 0.8, 93 | num_inference_steps: Optional[int] = 50, 94 | timesteps: List[int] = None, 95 | sigmas: List[float] = None, 96 | guidance_scale: Optional[float] = 7.5, 97 | negative_prompt: Optional[Union[str, List[str]]] = None, 98 | eta: Optional[float] = 0.0, 99 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 100 | prompt_embeds: Optional[torch.Tensor] = None, 101 | negative_prompt_embeds: Optional[torch.Tensor] = None, 102 | ip_adapter_image: Optional[PipelineImageInput] = None, 103 | ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, 104 | output_type: Optional[str] = "pil", 105 | device: Optional[str] = "cpu", 106 | return_dict: bool = True, 107 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 108 | clip_skip: int = None, 109 | callback_on_step_end: Optional[ 110 | Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] 111 | ] = None, 112 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 113 | **kwargs, 114 | ): 115 | callback = kwargs.pop("callback", None) 116 | callback_steps = kwargs.pop("callback_steps", None) 117 | num_images_per_prompt = 1 118 | 119 | if callback is not None: 120 | deprecate( 121 | "callback", 122 | "1.0.0", 123 | "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", 124 | ) 125 | if callback_steps is not None: 126 | deprecate( 127 | "callback_steps", 128 | "1.0.0", 129 | "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", 130 | ) 131 | 132 | if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): 133 | callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs 134 | 135 | # 1. Check inputs. Raise error if not correct 136 | self.check_inputs( 137 | prompt, 138 | strength, 139 | callback_steps, 140 | negative_prompt, 141 | prompt_embeds, 142 | negative_prompt_embeds, 143 | ip_adapter_image, 144 | ip_adapter_image_embeds, 145 | callback_on_step_end_tensor_inputs, 146 | ) 147 | 148 | self._guidance_scale = guidance_scale 149 | self._clip_skip = clip_skip 150 | self._cross_attention_kwargs = cross_attention_kwargs 151 | self._interrupt = False 152 | 153 | # 2. 
Define call parameters 154 | if prompt is not None and isinstance(prompt, str): 155 | batch_size = 1 156 | elif prompt is not None and isinstance(prompt, list): 157 | batch_size = len(prompt) 158 | else: 159 | batch_size = prompt_embeds.shape[0] 160 | 161 | # device = self.device 162 | 163 | # 3. Encode input prompt 164 | text_encoder_lora_scale = ( 165 | self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None 166 | ) 167 | 168 | self.text_encoder_ref.to(device=device) 169 | prompt_embeds_ref, negative_prompt_embeds_ref = self.encode_prompt_sk( 170 | self.text_encoder_ref, 171 | prompt, 172 | device, 173 | num_images_per_prompt, 174 | self.do_classifier_free_guidance, 175 | negative_prompt, 176 | prompt_embeds=prompt_embeds, 177 | negative_prompt_embeds=negative_prompt_embeds, 178 | lora_scale=text_encoder_lora_scale, 179 | clip_skip=self.clip_skip, 180 | ) 181 | self.text_encoder_ref.cpu() 182 | 183 | self.text_encoder.to(device=device) 184 | prompt_embeds, negative_prompt_embeds = self.encode_prompt_sk( 185 | self.text_encoder, 186 | prompt, 187 | device, 188 | num_images_per_prompt, 189 | self.do_classifier_free_guidance, 190 | negative_prompt, 191 | prompt_embeds=prompt_embeds, 192 | negative_prompt_embeds=negative_prompt_embeds, 193 | lora_scale=text_encoder_lora_scale, 194 | clip_skip=self.clip_skip, 195 | ) 196 | self.text_encoder.cpu() 197 | 198 | # For classifier free guidance, we need to do two forward passes. 199 | # Here we concatenate the unconditional and text embeddings into a single batch 200 | # to avoid doing two forward passes 201 | if self.do_classifier_free_guidance: 202 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) 203 | prompt_embeds_ref = torch.cat([negative_prompt_embeds_ref, prompt_embeds_ref]) 204 | 205 | if ip_adapter_image is not None or ip_adapter_image_embeds is not None: 206 | image_embeds = self.prepare_ip_adapter_image_embeds( 207 | ip_adapter_image, 208 | ip_adapter_image_embeds, 209 | device, 210 | batch_size * num_images_per_prompt, 211 | self.do_classifier_free_guidance, 212 | ) 213 | 214 | # 4. Preprocess 215 | image = self.image_processor.preprocess(image).to(device=device, dtype=prompt_embeds.dtype) 216 | 217 | scheduler = DPMSolverMultistepScheduler( 218 | num_train_timesteps=1000, 219 | beta_start=0.00085, 220 | beta_end=0.012, 221 | beta_schedule="scaled_linear", 222 | # use_karras_sigmas=True, 223 | algorithm_type="sde-dpmsolver++", 224 | ) 225 | 226 | # 5. set timesteps 227 | timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device, timesteps, sigmas) 228 | 229 | # 6. 
Prepare reference latents 230 | self.vae.to(device=device) 231 | ref_latents = [ 232 | retrieve_latents(self.vae.encode(image[i: i + 1].to(device=device)), generator=generator) 233 | for i in range(batch_size) 234 | ] 235 | self.vae.cpu() 236 | 237 | ref_latents = torch.cat(ref_latents, dim=0) 238 | ref_latents = self.vae.config.scaling_factor * ref_latents 239 | c, h, w = ref_latents.shape[1:] 240 | 241 | condition = drive_params['condition'].clone().to(device=device) 242 | if self.do_classifier_free_guidance: 243 | condition = torch.cat([torch.ones_like(condition) * -1, condition], dim=0) 244 | 245 | control_latents = {} 246 | self.mp_control.to(device=device) 247 | self.mp_control_proj.to(device=device) 248 | if 'drive_coeff' in drive_params: 249 | drive_coeff = drive_params['drive_coeff'].clone().to(device=device) 250 | face_parts = drive_params['face_parts'].clone().to(device=device) 251 | if self.do_classifier_free_guidance: 252 | drive_coeff = torch.cat([torch.zeros_like(drive_coeff), drive_coeff], dim=0) 253 | face_parts = torch.cat([torch.zeros_like(face_parts), face_parts], dim=0) 254 | control_latents1 = self.mp_control(condition=condition, drive_coeff=drive_coeff, face_parts=face_parts) 255 | control_latents1 = self.mp_control_proj(control_latents1) 256 | control_latents.update(control_latents1) 257 | elif 'pd_fpg' in drive_params: 258 | pd_fpg = drive_params['pd_fpg'].clone().to(device=device) 259 | if self.do_classifier_free_guidance: 260 | pd_fpg = torch.cat([torch.zeros_like(pd_fpg), pd_fpg], dim=0) 261 | control_latents2 = self.mp_control(condition=condition, emo_embedding=pd_fpg) 262 | control_latents2 = self.mp_control_proj(control_latents2) 263 | control_latents.update(control_latents2) 264 | self.mp_control.cpu() 265 | self.mp_control_proj.cpu() 266 | 267 | # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 268 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 269 | 270 | # 7.1 Add image embeds for IP-Adapter 271 | added_cond_kwargs = ( 272 | {"image_embeds": image_embeds} 273 | if ip_adapter_image is not None or ip_adapter_image_embeds is not None 274 | else None 275 | ) 276 | base_noise = randn_tensor([batch_size, c, h, w], dtype=prompt_embeds.dtype, generator=generator).to(device=device) 277 | 278 | latent_model_input = torch.cat([torch.zeros_like(ref_latents), ref_latents]) if ( 279 | self.do_classifier_free_guidance) else ref_latents 280 | # latent_model_input = torch.cat([ref_latents_neg, ref_latents], dim=0) 281 | self.unet_ref.to(device=device) 282 | cached_res = self.unet_ref( 283 | latent_model_input.unsqueeze(2), 284 | 0, 285 | encoder_hidden_states=prompt_embeds_ref, 286 | return_dict=False, 287 | )[1] 288 | self.unet_ref.cpu() 289 | 290 | # 7.2 Optionally get Guidance Scale Embedding 291 | timestep_cond = None 292 | if self.unet.config.time_cond_proj_dim is not None: 293 | guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) 294 | timestep_cond = self.get_guidance_scale_embedding( 295 | guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim 296 | ).to(device=device, dtype=prompt_embeds.dtype) 297 | 298 | latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) 299 | 300 | latents = base_noise * scheduler.init_noise_sigma 301 | # 8. 
Denoising loop 302 | num_warmup_steps = len(timesteps) - num_inference_steps * scheduler.order 303 | self._num_timesteps = len(timesteps) 304 | self.unet.to(device=device) 305 | with self.progress_bar(total=num_inference_steps) as progress_bar: 306 | for i, t in enumerate(timesteps): 307 | if self.interrupt: 308 | continue 309 | 310 | # expand the latents if we are doing classifier free guidance 311 | latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents 312 | latent_model_input = scheduler.scale_model_input(latent_model_input, t) 313 | 314 | # predict the noise residual 315 | noise_pred = self.unet( 316 | latent_model_input.unsqueeze(2), 317 | t, 318 | encoder_hidden_states=prompt_embeds, 319 | reference_hidden_states=cached_res, 320 | control_hidden_states=control_latents, 321 | timestep_cond=timestep_cond, 322 | cross_attention_kwargs=self.cross_attention_kwargs, 323 | added_cond_kwargs=added_cond_kwargs, 324 | return_dict=False, 325 | )[0][:,:,0,:,:] 326 | 327 | # perform guidance 328 | if self.do_classifier_free_guidance: 329 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 330 | noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) 331 | 332 | # compute the previous noisy sample x_t -> x_t-1 333 | latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] 334 | 335 | if callback_on_step_end is not None: 336 | callback_kwargs = {} 337 | for k in callback_on_step_end_tensor_inputs: 338 | callback_kwargs[k] = locals()[k] 339 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 340 | 341 | latents = callback_outputs.pop("latents", latents) 342 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 343 | negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) 344 | 345 | # call the callback, if provided 346 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0): 347 | progress_bar.update() 348 | if callback is not None and i % callback_steps == 0: 349 | step_idx = i // getattr(scheduler, "order", 1) 350 | callback(step_idx, t, latents) 351 | 352 | self.unet.cpu() 353 | 354 | self.vae_decode.to(device=device) 355 | if not output_type == "latent": 356 | image = self.vae_decode.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ 357 | 0 358 | ] 359 | else: 360 | image = latents 361 | self.vae_decode.cpu() 362 | 363 | do_denormalize = [True] * image.shape[0] 364 | 365 | image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) 366 | 367 | # Offload all models 368 | self.maybe_free_model_hooks() 369 | 370 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None), latents.detach().cpu() / self.vae.config.scaling_factor 371 | -------------------------------------------------------------------------------- /hellomeme/pipelines/pipline_hm_image.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | @File : hm_pipline_image.py 5 | @Author : Songkey 6 | @Email : songkey@pku.edu.cn 7 | @Date : 8/29/2024 8 | @Desc : 9 | adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py 10 | """ 11 | 12 | import copy 13 | from typing import Any, Callable, Dict, List, Optional, Union 14 | import torch 15 | 16 | from diffusers 
import EulerDiscreteScheduler 17 | from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback 18 | from diffusers.image_processor import PipelineImageInput 19 | from diffusers.utils import deprecate 20 | from diffusers.utils.torch_utils import randn_tensor 21 | from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput 22 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import retrieve_timesteps, retrieve_latents 23 | 24 | from ..models import HMDenoising3D, HMControlNet, HMControlNet2, HMV2ControlNet, HMV2ControlNet2, HMPipeline 25 | from ..models import HMReferenceAdapter 26 | 27 | class HMImagePipeline(HMPipeline): 28 | def caryomitosis(self, **kwargs): 29 | if hasattr(self, "unet_ref"): 30 | del self.unet_ref 31 | self.unet_ref = HMDenoising3D.from_unet2d(self.unet) 32 | self.unet_ref.cpu() 33 | 34 | if not isinstance(self.unet, HMDenoising3D): 35 | unet = HMDenoising3D.from_unet2d(unet=self.unet) 36 | # todo: 不够优雅 37 | del self.unet 38 | self.unet = unet 39 | self.unet.cpu() 40 | 41 | self.vae.cpu() 42 | self.vae_decode = copy.deepcopy(self.vae) 43 | self.text_encoder.cpu() 44 | self.text_encoder_ref = copy.deepcopy(self.text_encoder) 45 | self.safety_checker.cpu() 46 | 47 | def insert_hm_modules(self, version, dtype, modelscope=False): 48 | if modelscope: 49 | from modelscope import snapshot_download 50 | hm_reference_dir = snapshot_download('songkey/hm_reference') 51 | hm2_reference_dir = snapshot_download('songkey/hm2_reference') 52 | hm_control_dir = snapshot_download('songkey/hm_control') 53 | hm_control2_dir = snapshot_download('songkey/hm_control2') 54 | hm2_control_dir = snapshot_download('songkey/hm2_control') 55 | hm2_control2_dir = snapshot_download('songkey/hm2_control2') 56 | else: 57 | hm_reference_dir = 'songkey/hm_reference' 58 | hm2_reference_dir = 'songkey/hm2_reference' 59 | hm_control_dir = 'songkey/hm_control' 60 | hm_control2_dir = 'songkey/hm_control2' 61 | hm2_control_dir = 'songkey/hm2_control' 62 | hm2_control2_dir = 'songkey/hm2_control2' 63 | 64 | if isinstance(self.unet, HMDenoising3D): 65 | if version == 'v1': 66 | hm_adapter = HMReferenceAdapter.from_pretrained(hm_reference_dir) 67 | else: 68 | hm_adapter = HMReferenceAdapter.from_pretrained(hm2_reference_dir) 69 | self.unet.insert_reference_adapter(hm_adapter) 70 | self.unet.to(device='cpu', dtype=dtype).eval() 71 | 72 | if hasattr(self, "unet_ref"): 73 | self.unet_ref.to(device='cpu', dtype=dtype).eval() 74 | 75 | if hasattr(self, "mp_control"): 76 | del self.mp_control 77 | if version == 'v1': 78 | self.mp_control = HMControlNet.from_pretrained(hm_control_dir) 79 | else: 80 | self.mp_control = HMV2ControlNet.from_pretrained(hm2_control_dir) 81 | self.mp_control.to(device='cpu', dtype=dtype).eval() 82 | 83 | if hasattr(self, "mp_control2"): 84 | del self.mp_control2 85 | if version == 'v1': 86 | self.mp_control2 = HMControlNet2.from_pretrained(hm_control2_dir) 87 | else: 88 | self.mp_control2 = HMV2ControlNet2.from_pretrained(hm2_control2_dir) 89 | self.mp_control2.to(device='cpu', dtype=dtype).eval() 90 | 91 | self.vae.to(device='cpu', dtype=dtype).eval() 92 | self.vae_decode.to(device='cpu', dtype=dtype).eval() 93 | self.text_encoder.to(device='cpu', dtype=dtype).eval() 94 | 95 | @torch.no_grad() 96 | def __call__( 97 | self, 98 | prompt: Union[str, List[str]] = None, 99 | image: PipelineImageInput = None, 100 | drive_params: Dict[str, Any] = None, 101 | strength: float = 0.8, 102 | num_inference_steps: Optional[int] = 50, 
103 | timesteps: List[int] = None, 104 | sigmas: List[float] = None, 105 | guidance_scale: Optional[float] = 7.5, 106 | negative_prompt: Optional[Union[str, List[str]]] = None, 107 | eta: Optional[float] = 0.0, 108 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 109 | prompt_embeds: Optional[torch.Tensor] = None, 110 | negative_prompt_embeds: Optional[torch.Tensor] = None, 111 | ip_adapter_image: Optional[PipelineImageInput] = None, 112 | ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, 113 | output_type: Optional[str] = "pil", 114 | device: Optional[str] = "cpu", 115 | return_dict: bool = True, 116 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 117 | clip_skip: int = None, 118 | callback_on_step_end: Optional[ 119 | Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] 120 | ] = None, 121 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 122 | **kwargs, 123 | ): 124 | callback = kwargs.pop("callback", None) 125 | callback_steps = kwargs.pop("callback_steps", None) 126 | num_images_per_prompt = 1 127 | 128 | if callback is not None: 129 | deprecate( 130 | "callback", 131 | "1.0.0", 132 | "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", 133 | ) 134 | if callback_steps is not None: 135 | deprecate( 136 | "callback_steps", 137 | "1.0.0", 138 | "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", 139 | ) 140 | 141 | if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): 142 | callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs 143 | 144 | # 1. Check inputs. Raise error if not correct 145 | self.check_inputs( 146 | prompt, 147 | strength, 148 | callback_steps, 149 | negative_prompt, 150 | prompt_embeds, 151 | negative_prompt_embeds, 152 | ip_adapter_image, 153 | ip_adapter_image_embeds, 154 | callback_on_step_end_tensor_inputs, 155 | ) 156 | 157 | self._guidance_scale = guidance_scale 158 | self._clip_skip = clip_skip 159 | self._cross_attention_kwargs = cross_attention_kwargs 160 | self._interrupt = False 161 | 162 | # 2. Define call parameters 163 | if prompt is not None and isinstance(prompt, str): 164 | batch_size = 1 165 | elif prompt is not None and isinstance(prompt, list): 166 | batch_size = len(prompt) 167 | else: 168 | batch_size = prompt_embeds.shape[0] 169 | 170 | # device = self.device 171 | 172 | # 3. 
Encode input prompt 173 | text_encoder_lora_scale = ( 174 | self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None 175 | ) 176 | 177 | self.text_encoder_ref.to(device=device) 178 | prompt_embeds_ref, negative_prompt_embeds_ref = self.encode_prompt_sk( 179 | self.text_encoder_ref, 180 | prompt, 181 | device, 182 | num_images_per_prompt, 183 | self.do_classifier_free_guidance, 184 | negative_prompt, 185 | prompt_embeds=prompt_embeds, 186 | negative_prompt_embeds=negative_prompt_embeds, 187 | lora_scale=text_encoder_lora_scale, 188 | clip_skip=self.clip_skip, 189 | ) 190 | self.text_encoder_ref.cpu() 191 | 192 | self.text_encoder.to(device=device) 193 | prompt_embeds, negative_prompt_embeds = self.encode_prompt_sk( 194 | self.text_encoder, 195 | prompt, 196 | device, 197 | num_images_per_prompt, 198 | self.do_classifier_free_guidance, 199 | negative_prompt, 200 | prompt_embeds=prompt_embeds, 201 | negative_prompt_embeds=negative_prompt_embeds, 202 | lora_scale=text_encoder_lora_scale, 203 | clip_skip=self.clip_skip, 204 | ) 205 | self.text_encoder.cpu() 206 | 207 | # For classifier free guidance, we need to do two forward passes. 208 | # Here we concatenate the unconditional and text embeddings into a single batch 209 | # to avoid doing two forward passes 210 | if self.do_classifier_free_guidance: 211 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) 212 | prompt_embeds_ref = torch.cat([negative_prompt_embeds_ref, prompt_embeds_ref]) 213 | 214 | if ip_adapter_image is not None or ip_adapter_image_embeds is not None: 215 | image_embeds = self.prepare_ip_adapter_image_embeds( 216 | ip_adapter_image, 217 | ip_adapter_image_embeds, 218 | device, 219 | batch_size * num_images_per_prompt, 220 | self.do_classifier_free_guidance, 221 | ) 222 | 223 | # 4. Preprocess 224 | image = self.image_processor.preprocess(image).to(device=device, dtype=prompt_embeds.dtype) 225 | 226 | scheduler = EulerDiscreteScheduler( 227 | num_train_timesteps=1000, 228 | beta_start=0.00085, 229 | beta_end=0.012, 230 | beta_schedule="scaled_linear", 231 | ) 232 | 233 | # 5. set timesteps 234 | timesteps, num_inference_steps = retrieve_timesteps( 235 | scheduler, num_inference_steps, device, timesteps, sigmas 236 | ) 237 | # timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) 238 | latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) 239 | 240 | # 6. 
Prepare reference latents 241 | self.vae.to(device=device) 242 | ref_latents = [ 243 | retrieve_latents(self.vae.encode(image[i: i + 1].to(device=device)), generator=generator) 244 | for i in range(batch_size) 245 | ] 246 | self.vae.cpu() 247 | 248 | ref_latents = torch.cat(ref_latents, dim=0) 249 | ref_latents = self.vae.config.scaling_factor * ref_latents 250 | c, h, w = ref_latents.shape[1:] 251 | 252 | condition = drive_params['condition'].clone().to(device=device) 253 | if self.do_classifier_free_guidance: 254 | condition = torch.cat([torch.ones_like(condition) * -1, condition], dim=0) 255 | 256 | control_latents = {} 257 | if 'drive_coeff' in drive_params: 258 | self.mp_control.to(device=device) 259 | drive_coeff = drive_params['drive_coeff'].clone().to(device=device) 260 | face_parts = drive_params['face_parts'].clone().to(device=device) 261 | if self.do_classifier_free_guidance: 262 | drive_coeff = torch.cat([torch.zeros_like(drive_coeff), drive_coeff], dim=0) 263 | face_parts = torch.cat([torch.zeros_like(face_parts), face_parts], dim=0) 264 | control_latents1 = self.mp_control(condition=condition, drive_coeff=drive_coeff, face_parts=face_parts) 265 | control_latents.update(control_latents1) 266 | self.mp_control.cpu() 267 | 268 | if 'pd_fpg' in drive_params: 269 | self.mp_control2.to(device=device) 270 | pd_fpg = drive_params['pd_fpg'].clone().to(device=device) 271 | if self.do_classifier_free_guidance: 272 | neg_pd_fpg = drive_params['neg_pd_fpg'].clone().to(device=device) 273 | neg_pd_fpg.repeat_interleave(pd_fpg.size(1), dim=1) 274 | pd_fpg = torch.cat([neg_pd_fpg, pd_fpg], dim=0) 275 | control_latents2 = self.mp_control2(condition=condition, emo_embedding=pd_fpg) 276 | control_latents.update(control_latents2) 277 | self.mp_control2.cpu() 278 | 279 | # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 280 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 281 | 282 | # 7.1 Add image embeds for IP-Adapter 283 | added_cond_kwargs = ( 284 | {"image_embeds": image_embeds} 285 | if ip_adapter_image is not None or ip_adapter_image_embeds is not None 286 | else None 287 | ) 288 | 289 | latent_model_input = torch.cat([torch.zeros_like(ref_latents), ref_latents]) if self.do_classifier_free_guidance else ref_latents 290 | self.unet_ref.to(device=device) 291 | cached_res = self.unet_ref( 292 | latent_model_input.unsqueeze(2), 293 | 0, 294 | encoder_hidden_states=prompt_embeds_ref, 295 | return_dict=False, 296 | )[1] 297 | self.unet_ref.cpu() 298 | 299 | # 7.2 Optionally get Guidance Scale Embedding 300 | timestep_cond = None 301 | if self.unet.config.time_cond_proj_dim is not None: 302 | guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) 303 | timestep_cond = self.get_guidance_scale_embedding( 304 | guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim 305 | ).to(device=device, dtype=prompt_embeds.dtype) 306 | 307 | latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) 308 | base_noise = randn_tensor([batch_size, c, h, w], dtype=prompt_embeds.dtype, generator=generator).to(device=device) 309 | latents = base_noise * scheduler.init_noise_sigma 310 | # 8. 
Denoising loop 311 | num_warmup_steps = len(timesteps) - num_inference_steps * scheduler.order 312 | self._num_timesteps = len(timesteps) 313 | self.unet.to(device=device) 314 | with self.progress_bar(total=num_inference_steps) as progress_bar: 315 | for i, t in enumerate(timesteps): 316 | if self.interrupt: 317 | continue 318 | 319 | # expand the latents if we are doing classifier free guidance 320 | latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents 321 | latent_model_input = scheduler.scale_model_input(latent_model_input, t) 322 | 323 | # predict the noise residual 324 | noise_pred = self.unet( 325 | latent_model_input.unsqueeze(2), 326 | t, 327 | encoder_hidden_states=prompt_embeds, 328 | reference_hidden_states=cached_res, 329 | control_hidden_states=control_latents, 330 | timestep_cond=timestep_cond, 331 | cross_attention_kwargs=self.cross_attention_kwargs, 332 | added_cond_kwargs=added_cond_kwargs, 333 | return_dict=False, 334 | )[0][:,:,0,:,:] 335 | 336 | # perform guidance 337 | if self.do_classifier_free_guidance: 338 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 339 | noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) 340 | 341 | # compute the previous noisy sample x_t -> x_t-1 342 | latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] 343 | 344 | if callback_on_step_end is not None: 345 | callback_kwargs = {} 346 | for k in callback_on_step_end_tensor_inputs: 347 | callback_kwargs[k] = locals()[k] 348 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 349 | 350 | latents = callback_outputs.pop("latents", latents) 351 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 352 | negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) 353 | 354 | # call the callback, if provided 355 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0): 356 | progress_bar.update() 357 | if callback is not None and i % callback_steps == 0: 358 | step_idx = i // getattr(scheduler, "order", 1) 359 | callback(step_idx, t, latents) 360 | 361 | self.unet.cpu() 362 | 363 | self.vae_decode.to(device=device) 364 | if not output_type == "latent": 365 | image = self.vae_decode.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ 366 | 0 367 | ] 368 | else: 369 | image = latents 370 | self.vae_decode.cpu() 371 | 372 | do_denormalize = [True] * image.shape[0] 373 | 374 | image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) 375 | 376 | # Offload all models 377 | self.maybe_free_model_hooks() 378 | 379 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None), latents.detach().cpu() / self.vae.config.scaling_factor 380 | -------------------------------------------------------------------------------- /hellomeme/tools/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | # @File : __init__.py 4 | # @Author : Songkey 5 | # @Email : songkey@pku.edu.cn 6 | # @Date : 8/28/2024 7 | # @Desc : 8 | 9 | from .hello_arkit import HelloARKitBSPred 10 | from .hello_face_det import HelloFaceDet 11 | from .hello_camera_demo import HelloCameraDemo 12 | from .hello_3dmm import Hello3DMMPred 13 | from .hello_face_alignment import HelloFaceAlignment 14 | from .pdf import FanEncoder 15 | 
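The classes re-exported above are thin ONNX wrappers that are designed to be chained: HelloFaceAlignment produces the 222-point landmark set that Hello3DMMPred and HelloARKitBSPred consume. The sketch below shows one plausible way to wire them together on a single frame; the image path is a placeholder and the overall flow is an illustration inferred from the tool modules that follow, not an officially documented API.

import cv2
from hellomeme.tools import HelloFaceAlignment, Hello3DMMPred, HelloARKitBSPred

frame = cv2.imread("face.jpg")             # hypothetical input; any BGR image read by OpenCV

aligner = HelloFaceAlignment(gpu_id=None)  # ONNX weights are fetched from the Hub on first use
result = aligner.forward(frame)            # picks the largest detected face; returns None if no face is found
if result is not None:
    landmarks = result["pt222"]            # 222 x 2 landmark array

    # Head rotation (Euler angles) and translation from the 3DMM predictor.
    angles, trans = Hello3DMMPred(gpu_id=None).forward_params(frame, landmarks)

    # ARKit-style blendshape coefficients for the same face.
    blendshapes = HelloARKitBSPred(gpu_id=None).forward(frame, landmarks)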
-------------------------------------------------------------------------------- /hellomeme/tools/hello_3dmm.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | @File : test.py 5 | @Author : Songkey 6 | @Email : songkey@pku.edu.cn 7 | @Date : 11/1/2024 8 | @Desc : Created by Shengjie Wu (wu.shengjie@immomo.com) 9 | 这可能是一个很强大的模型 10 | """ 11 | 12 | import numpy as np 13 | import cv2 14 | import os.path as osp 15 | 16 | from .utils import get_warp_mat_bbox_by_gt_pts_float, create_onnx_session 17 | 18 | def crop_transl_to_full_transl(crop_trans, crop_center, scale, full_center, focal_length): 19 | """ 20 | :param crop_trans: (3), float 21 | :param crop_center: (2), float 22 | :param scale: (1), float 23 | :param full_center: (2), float 24 | :param focal_length: (1), float 25 | :return: 26 | """ 27 | crop_c_x, crop_c_y = crop_center 28 | full_c_x, full_c_y = full_center 29 | bs = 2 * focal_length / scale / crop_trans[2] 30 | full_x = crop_trans[0] - 2 * (crop_c_x - full_c_x) / bs 31 | full_y = crop_trans[1] + 2 * (crop_c_y - full_c_y) / bs 32 | full_z = crop_trans[2] * scale 33 | 34 | full_trans = np.array([full_x, full_y, full_z], dtype=np.float32) 35 | 36 | return full_trans 37 | 38 | class Hello3DMMPred(object): 39 | def __init__(self, gpu_id=None, modelscope=False): 40 | if modelscope: 41 | from modelscope import snapshot_download 42 | model_path = osp.join(snapshot_download('songkey/hello_group_facemodel'), 'hello_3dmm.onnx') 43 | else: 44 | from huggingface_hub import hf_hub_download 45 | model_path = hf_hub_download('songkey/hello_group_facemodel', filename='hello_3dmm.onnx') 46 | self.deep3d_pred_net = create_onnx_session(model_path, gpu_id=gpu_id) 47 | self.deep3d_pred_net_input_name = self.deep3d_pred_net.get_inputs()[0].name 48 | self.deep3d_pred_net_output_name = [output.name for output in self.deep3d_pred_net.get_outputs()] 49 | 50 | self.image_size = 224 51 | self.camera_init_z = -0.4 52 | self.camera_init_focal_len = 386.2879122887948 53 | self.used_focal_len = -5.0 / self.camera_init_z * self.camera_init_focal_len 54 | self.id_dims = 526 55 | self.exp_dims = 203 56 | self.tex_dims = 439 57 | 58 | def forward_params(self, src_image, src_pt): 59 | align_mat_info = get_warp_mat_bbox_by_gt_pts_float(src_pt, base_angle=0, dst_size=self.image_size, expand_ratio=0.35, return_info=True) 60 | align_mat = align_mat_info["M"] 61 | 62 | align_image_rgb_uint8 = cv2.cvtColor(cv2.warpAffine(src_image, align_mat, (self.image_size, self.image_size)), cv2.COLOR_BGR2RGB) 63 | 64 | # cv2.imshow('align_image_rgb_uint8', align_image_rgb_uint8) 65 | 66 | align_image_rgb_fp32 = align_image_rgb_uint8.astype(np.float32) / 255.0 67 | align_image_rgb_fp32_onnx_input = align_image_rgb_fp32.copy().transpose((2, 0, 1))[np.newaxis, ...] 
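        # Run the 3DMM ONNX model on the aligned 224x224 RGB crop (1x3x224x224, float32 in [0, 1]).
        # Judging from the slicing below, the coefficient vector packs identity (526), expression (203)
        # and texture (439) dims, then 3 rotation angles, 27 further values (possibly lighting), and the translation.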
68 | pred_coeffs = self.deep3d_pred_net.run(self.deep3d_pred_net_output_name, 69 | {self.deep3d_pred_net_input_name: align_image_rgb_fp32_onnx_input})[0] 70 | 71 | angles = pred_coeffs[:, self.id_dims + self.exp_dims + self.tex_dims:self.id_dims + self.exp_dims + self.tex_dims + 3] 72 | translations = pred_coeffs[:, self.id_dims + self.exp_dims + self.tex_dims + 3 + 27:] 73 | 74 | crop_global_transl = crop_transl_to_full_transl(translations[0], 75 | crop_center=[align_mat_info["center_x"], 76 | align_mat_info["center_y"]], 77 | scale=align_mat_info["scale"], 78 | full_center=[src_image.shape[1] * 0.5, src_image.shape[0] * 0.5], 79 | focal_length=self.used_focal_len) 80 | return angles, crop_global_transl[np.newaxis, :] 81 | 82 | def compute_rotation_matrix(angles): 83 | n_b = angles.shape[0] 84 | sinx = np.sin(angles[:, 0]) 85 | siny = np.sin(angles[:, 1]) 86 | sinz = np.sin(angles[:, 2]) 87 | cosx = np.cos(angles[:, 0]) 88 | cosy = np.cos(angles[:, 1]) 89 | cosz = np.cos(angles[:, 2]) 90 | rotXYZ = np.eye(3).reshape(1, 3, 3).repeat(n_b*3, 0).reshape(3, n_b, 3, 3) 91 | rotXYZ[0, :, 1, 1] = cosx 92 | rotXYZ[0, :, 1, 2] = -sinx 93 | rotXYZ[0, :, 2, 1] = sinx 94 | rotXYZ[0, :, 2, 2] = cosx 95 | rotXYZ[1, :, 0, 0] = cosy 96 | rotXYZ[1, :, 0, 2] = siny 97 | rotXYZ[1, :, 2, 0] = -siny 98 | rotXYZ[1, :, 2, 2] = cosy 99 | rotXYZ[2, :, 0, 0] = cosz 100 | rotXYZ[2, :, 0, 1] = -sinz 101 | rotXYZ[2, :, 1, 0] = sinz 102 | rotXYZ[2, :, 1, 1] = cosz 103 | rotation = np.matmul(np.matmul(rotXYZ[2], rotXYZ[1]), rotXYZ[0]) 104 | return rotation.transpose(0, 2, 1) 105 | 106 | def rigid_transform(vs, rot, trans): 107 | vs_r = np.matmul(vs, rot) 108 | vs_t = vs_r + trans.reshape(-1, 1, 3) 109 | return vs_t 110 | 111 | def perspective_projection_points(points, image_w, image_h, used_focal_len): 112 | batch_size = points.shape[0] 113 | K = np.zeros([batch_size, 3, 3]) 114 | K[:, 0, 0] = used_focal_len 115 | K[:, 1, 1] = used_focal_len 116 | K[:, 2, 2] = 1. 
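    # The next two assignments place the principal point at the image center, completing a
    # standard pinhole intrinsics matrix K = [[f, 0, w/2], [0, f, h/2], [0, 0, 1]].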
117 | K[:, 0, 2] = image_w * 0.5 118 | K[:, 1, 2] = image_h * 0.5 119 | 120 | reverse_z = np.array([[1, 0, 0], [0, 1, 0], [0, 0, -1]])[np.newaxis, :, :].repeat(batch_size, 0) 121 | 122 | # Transform points 123 | aug_projection = np.matmul(points, reverse_z) 124 | aug_projection = np.matmul(aug_projection, K.transpose((0, 2, 1))) 125 | 126 | # Apply perspective distortion 127 | projected_points = aug_projection[:, :, :2] / aug_projection[:, :, 2:] 128 | return projected_points 129 | 130 | def get_project_points_rect(angle, trans, image_w, image_h, used_focal_len=4828.598903609935): 131 | vs = np.array( 132 | [[-1, -1, 0], [-1, 1, 0], [1, 1, 0], [1, -1, 0]], 133 | ) * 0.05 134 | vs = vs[np.newaxis, :, :] 135 | 136 | rotation = compute_rotation_matrix(angle) 137 | translation = trans.copy() 138 | translation[0, 2] *= 0.05 139 | 140 | vs_t = rigid_transform(vs, rotation, translation) 141 | 142 | project_points = perspective_projection_points(vs_t, image_w, image_h, used_focal_len*0.05) 143 | project_points = np.stack([project_points[:, :, 0], image_h - project_points[:, :, 1]], axis=2) 144 | 145 | return project_points[0] 146 | 147 | -------------------------------------------------------------------------------- /hellomeme/tools/hello_arkit.py: -------------------------------------------------------------------------------- 1 | """ 2 | @File : test.py 3 | @Author : Songkey 4 | @Email : songkey@pku.edu.cn 5 | @Date : 11/1/2024 6 | @Desc : Created by Shengjie Wu (wu.shengjie@immomo.com) 7 | """ 8 | 9 | import numpy as np 10 | import cv2 11 | import os.path as osp 12 | from .utils import create_onnx_session, get_warp_mat_bbox_by_gt_pts_float 13 | 14 | class HelloARKitBSPred(object): 15 | def __init__(self, gpu_id=0, modelscope=False): 16 | if modelscope: 17 | from modelscope import snapshot_download 18 | model_path = osp.join(snapshot_download('songkey/hello_group_facemodel'), 'hello_arkit_blendshape.onnx') 19 | else: 20 | from huggingface_hub import hf_hub_download 21 | model_path = hf_hub_download('songkey/hello_group_facemodel', filename='hello_arkit_blendshape.onnx') 22 | 23 | self.face_rig_net = create_onnx_session(model_path, gpu_id=gpu_id) 24 | self.onnx_input_name = self.face_rig_net.get_inputs()[0].name 25 | self.onnx_output_name = [output.name for output in self.face_rig_net.get_outputs()] 26 | self.image_size = 224 27 | self.expand_ratio = 0.15 28 | 29 | def forward(self, src_image, src_pt): 30 | left_eye_corner = src_pt[74] 31 | right_eye_corner = src_pt[96] 32 | radian = np.arctan2(right_eye_corner[1] - left_eye_corner[1], right_eye_corner[0] - left_eye_corner[0] + 0.00000001) 33 | rotate_angle = np.rad2deg(radian) 34 | align_warp_mat = get_warp_mat_bbox_by_gt_pts_float(src_pt, base_angle=rotate_angle, dst_size=self.image_size, 35 | expand_ratio=self.expand_ratio) 36 | face_rig_input = cv2.warpAffine(src_image, align_warp_mat, (self.image_size, self.image_size)) 37 | 38 | face_rig_onnx_input = face_rig_input.transpose((2, 0, 1)).astype(np.float32)[np.newaxis, :, :, :] / 255.0 39 | face_rig_params = self.face_rig_net.run(self.onnx_output_name, 40 | {self.onnx_input_name: face_rig_onnx_input}) 41 | face_rig_params = face_rig_params[0][0] 42 | return face_rig_params 43 | -------------------------------------------------------------------------------- /hellomeme/tools/hello_face_alignment.py: -------------------------------------------------------------------------------- 1 | """ 2 | @File : test.py 3 | @Author : Songkey 4 | @Email : songkey@pku.edu.cn 5 | @Date : 11/1/2024 6 | @Desc : 
Created by Shengjie Wu (wu.shengjie@immomo.com) 7 | """ 8 | 9 | import cv2 10 | import os.path as osp 11 | import numpy as np 12 | from .hello_face_det import HelloFaceDet 13 | from .utils import get_warp_mat_bbox, get_warp_mat_bbox_by_gt_pts_float, transform_points 14 | from .utils import create_onnx_session 15 | 16 | class HelloFaceAlignment(object): 17 | def __init__(self, gpu_id=None, modelscope=False): 18 | expand_ratio = 0.15 19 | 20 | if modelscope: 21 | from modelscope import snapshot_download 22 | alignment_model_path = osp.join(snapshot_download('songkey/hello_group_facemodel'), 'hello_face_landmark.onnx') 23 | det_model_path = osp.join(snapshot_download('songkey/hello_group_facemodel'), 'hello_face_det.onnx') 24 | else: 25 | from huggingface_hub import hf_hub_download 26 | alignment_model_path = hf_hub_download('songkey/hello_group_facemodel', filename='hello_face_landmark.onnx') 27 | det_model_path = hf_hub_download('songkey/hello_group_facemodel', filename='hello_face_det.onnx') 28 | self.face_alignment_net_222 = ( 29 | create_onnx_session(alignment_model_path, gpu_id=gpu_id)) 30 | self.onnx_input_name_222 = self.face_alignment_net_222.get_inputs()[0].name 31 | self.onnx_output_name_222 = [output.name for output in self.face_alignment_net_222.get_outputs()] 32 | self.face_image_size = 128 33 | 34 | self.face_detector = HelloFaceDet(det_model_path, gpu_id=gpu_id) 35 | self.expand_ratio = expand_ratio 36 | 37 | def onnx_infer(self, input_uint8): 38 | assert input_uint8.shape[0] == input_uint8.shape[1] == self.face_image_size 39 | onnx_input = input_uint8.transpose((2, 0, 1)).astype(np.float32)[np.newaxis, :, :, :] / 255.0 40 | landmark, euler, prob = self.face_alignment_net_222.run(self.onnx_output_name_222, 41 | {self.onnx_input_name_222: onnx_input}) 42 | 43 | landmark = np.reshape(landmark[0], (2, -1)).transpose((1, 0)) * self.face_image_size 44 | left_eye_corner = landmark[74] 45 | right_eye_corner = landmark[96] 46 | radian = np.arctan2(right_eye_corner[1] - left_eye_corner[1], 47 | right_eye_corner[0] - left_eye_corner[0] + 0.00000001) 48 | euler_rad = np.array([euler[0, 0], euler[0, 1], radian], dtype=np.float32) 49 | prob = prob[0] 50 | 51 | return landmark, euler_rad, prob 52 | 53 | def forward(self, src_image, face_box=None, pre_pts=None, iterations=3): 54 | if pre_pts is None: 55 | if face_box is None: 56 | # Detect max size face 57 | bounding_boxes, _, score = self.face_detector.detect(src_image) 58 | print("facedet score", score) 59 | if len(bounding_boxes) == 0: 60 | return None 61 | bbox = np.zeros(4, dtype=np.float32) 62 | if len(bounding_boxes) >= 1: 63 | max_area = 0.0 64 | for each_bbox in bounding_boxes: 65 | area = (each_bbox[2] - each_bbox[0]) * (each_bbox[3] - each_bbox[1]) 66 | if area > max_area: 67 | bbox[:4] = each_bbox[:4] 68 | max_area = area 69 | else: 70 | bbox = bounding_boxes[0, :4] 71 | else: 72 | bbox = face_box.copy() 73 | M_Face = get_warp_mat_bbox(bbox, 0, self.face_image_size, expand_ratio=self.expand_ratio) 74 | else: 75 | left_eye_corner = pre_pts[74] 76 | right_eye_corner = pre_pts[96] 77 | 78 | radian = np.arctan2(right_eye_corner[1] - left_eye_corner[1], 79 | right_eye_corner[0] - left_eye_corner[0] + 0.00000001) 80 | M_Face = get_warp_mat_bbox_by_gt_pts_float(pre_pts, np.rad2deg(radian), self.face_image_size, 81 | expand_ratio=self.expand_ratio) 82 | 83 | face_input = cv2.warpAffine(src_image, M_Face, (self.face_image_size, self.face_image_size)) 84 | landmarks, euler, prob = self.onnx_infer(face_input) 85 | landmarks = 
transform_points(landmarks, M_Face, invert=True) 86 | 87 | # Repeat 88 | for i in range(iterations - 1): 89 | M_Face = get_warp_mat_bbox_by_gt_pts_float(landmarks, np.rad2deg(euler[2]), self.face_image_size, 90 | expand_ratio=self.expand_ratio) 91 | face_input = cv2.warpAffine(src_image, M_Face, (self.face_image_size, self.face_image_size)) 92 | landmarks, euler, prob = self.onnx_infer(face_input) 93 | landmarks = transform_points(landmarks, M_Face, invert=True) 94 | 95 | return_dict = { 96 | "pt222": landmarks, 97 | "euler_rad": euler, 98 | "prob": prob, 99 | "M_Face": M_Face, 100 | "face_input": face_input 101 | } 102 | 103 | return return_dict 104 | -------------------------------------------------------------------------------- /hellomeme/tools/hello_face_det.py: -------------------------------------------------------------------------------- 1 | """ 2 | @File : test.py 3 | @Author : Songkey 4 | @Email : songkey@pku.edu.cn 5 | @Date : 11/1/2024 6 | @Desc : Created by Zemin An (an.zemin@hellogroup.com) 7 | """ 8 | 9 | from abc import ABCMeta, abstractmethod 10 | import cv2 11 | import numpy as np 12 | from scipy.special import softmax 13 | import os.path as osp 14 | from .utils import create_onnx_session 15 | 16 | songkey_weights_dir = 'pretrained_models' 17 | 18 | _COLORS = ( 19 | np.array( 20 | [ 21 | 0.000, 22 | 0.447, 23 | 0.741, 24 | ] 25 | ) 26 | .astype(np.float32) 27 | .reshape(-1, 3) 28 | ) 29 | 30 | def get_resize_matrix(raw_shape, dst_shape, keep_ratio): 31 | """ 32 | Get resize matrix for resizing raw img to input size 33 | :param raw_shape: (width, height) of raw image 34 | :param dst_shape: (width, height) of input image 35 | :param keep_ratio: whether keep original ratio 36 | :return: 3x3 Matrix 37 | """ 38 | r_w, r_h = raw_shape 39 | d_w, d_h = dst_shape 40 | Rs = np.eye(3) 41 | if keep_ratio: 42 | C = np.eye(3) 43 | C[0, 2] = -r_w / 2 44 | C[1, 2] = -r_h / 2 45 | 46 | if r_w / r_h < d_w / d_h: 47 | ratio = d_h / r_h 48 | else: 49 | ratio = d_w / r_w 50 | Rs[0, 0] *= ratio 51 | Rs[1, 1] *= ratio 52 | 53 | T = np.eye(3) 54 | T[0, 2] = 0.5 * d_w 55 | T[1, 2] = 0.5 * d_h 56 | return T @ Rs @ C 57 | else: 58 | Rs[0, 0] *= d_w / r_w 59 | Rs[1, 1] *= d_h / r_h 60 | return Rs 61 | 62 | def warp_boxes(boxes, M, width, height): 63 | """Apply transform to boxes 64 | Copy from nanodet/data/transform/warp.py 65 | """ 66 | n = len(boxes) 67 | if n: 68 | # warp points 69 | xy = np.ones((n * 4, 3)) 70 | xy[:, :2] = boxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape( 71 | n * 4, 2 72 | ) # x1y1, x2y2, x1y2, x2y1 73 | xy = xy @ M.T # transform 74 | xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8) # rescale 75 | # create new boxes 76 | x = xy[:, [0, 2, 4, 6]] 77 | y = xy[:, [1, 3, 5, 7]] 78 | xy = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T 79 | # clip boxes 80 | xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width) 81 | xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height) 82 | return xy.astype(np.float32) 83 | else: 84 | return boxes 85 | 86 | def overlay_bbox_cv(img, all_box, class_names): 87 | """Draw result boxes 88 | Copy from nanodet/util/visualization.py 89 | """ 90 | # all_box array of [label, x0, y0, x1, y1, score] 91 | all_box.sort(key=lambda v: v[5]) 92 | for box in all_box: 93 | label, x0, y0, x1, y1, score = box 94 | # color = self.cmap(i)[:3] 95 | color = (_COLORS[label] * 255).astype(np.uint8).tolist() 96 | text = "{}:{:.1f}%".format(class_names[label], score * 100) 97 | txt_color = (0, 0, 0) if np.mean(_COLORS[label]) > 0.5 else (255, 255, 255) 98 | font = 
cv2.FONT_HERSHEY_SIMPLEX 99 | txt_size = cv2.getTextSize(text, font, 0.5, 2)[0] 100 | cv2.rectangle(img, (x0, y0), (x1, y1), color, 2) 101 | 102 | cv2.rectangle( 103 | img, 104 | (x0, y0 - txt_size[1] - 1), 105 | (x0 + txt_size[0] + txt_size[1], y0 - 1), 106 | color, 107 | -1, 108 | ) 109 | cv2.putText(img, text, (x0, y0 - 1), font, 0.5, txt_color, thickness=1) 110 | return img 111 | 112 | def hard_nms(box_scores, iou_threshold, top_k=-1, candidate_size=200): 113 | """ 114 | 115 | Args: 116 | box_scores (N, 5): boxes in corner-form and probabilities. 117 | iou_threshold: intersection over union threshold. 118 | top_k: keep top_k results. If k <= 0, keep all the results. 119 | candidate_size: only consider the candidates with the highest scores. 120 | Returns: 121 | picked: a list of indexes of the kept boxes 122 | """ 123 | scores = box_scores[:, -1] 124 | boxes = box_scores[:, :-1] 125 | picked = [] 126 | # _, indexes = scores.sort(descending=True) 127 | indexes = np.argsort(scores) 128 | # indexes = indexes[:candidate_size] 129 | indexes = indexes[-candidate_size:] 130 | while len(indexes) > 0: 131 | # current = indexes[0] 132 | current = indexes[-1] 133 | picked.append(current) 134 | if 0 < top_k == len(picked) or len(indexes) == 1: 135 | break 136 | current_box = boxes[current, :] 137 | # indexes = indexes[1:] 138 | indexes = indexes[:-1] 139 | rest_boxes = boxes[indexes, :] 140 | iou = iou_of( 141 | rest_boxes, 142 | np.expand_dims(current_box, axis=0), 143 | ) 144 | indexes = indexes[iou <= iou_threshold] 145 | 146 | return box_scores[picked, :] 147 | 148 | 149 | def iou_of(boxes0, boxes1, eps=1e-5): 150 | """Return intersection-over-union (Jaccard index) of boxes. 151 | 152 | Args: 153 | boxes0 (N, 4): ground truth boxes. 154 | boxes1 (N or 1, 4): predicted boxes. 155 | eps: a small number to avoid 0 as denominator. 156 | Returns: 157 | iou (N): IoU values. 158 | """ 159 | overlap_left_top = np.maximum(boxes0[..., :2], boxes1[..., :2]) 160 | overlap_right_bottom = np.minimum(boxes0[..., 2:], boxes1[..., 2:]) 161 | 162 | overlap_area = area_of(overlap_left_top, overlap_right_bottom) 163 | area0 = area_of(boxes0[..., :2], boxes0[..., 2:]) 164 | area1 = area_of(boxes1[..., :2], boxes1[..., 2:]) 165 | return overlap_area / (area0 + area1 - overlap_area + eps) 166 | 167 | 168 | def area_of(left_top, right_bottom): 169 | """Compute the areas of rectangles given two corners. 170 | 171 | Args: 172 | left_top (N, 2): left top corner. 173 | right_bottom (N, 2): right bottom corner. 174 | 175 | Returns: 176 | area (N): return the area. 
177 | """ 178 | hw = np.clip(right_bottom - left_top, 0.0, None) 179 | return hw[..., 0] * hw[..., 1] 180 | 181 | 182 | class NanoDetABC(metaclass=ABCMeta): 183 | def __init__( 184 | self, 185 | input_shape=[272, 160], 186 | reg_max=7, 187 | strides=[8, 16, 32], 188 | prob_threshold=0.4, 189 | iou_threshold=0.3, 190 | num_candidate=1000, 191 | top_k=-1, 192 | class_names=["face"] 193 | ): 194 | self.strides = strides 195 | self.input_shape = input_shape 196 | self.reg_max = reg_max 197 | self.prob_threshold = prob_threshold 198 | self.iou_threshold = iou_threshold 199 | self.num_candidate = num_candidate 200 | self.top_k = top_k 201 | self.img_mean = [103.53, 116.28, 123.675] 202 | self.img_std = [57.375, 57.12, 58.395] 203 | self.input_size = (self.input_shape[1], self.input_shape[0]) 204 | self.class_names = class_names 205 | self.num_classes = len(self.class_names) 206 | 207 | def preprocess(self, img): 208 | # resize image 209 | ResizeM = get_resize_matrix((img.shape[1], img.shape[0]), self.input_size, True) 210 | img_resize = cv2.warpPerspective(img, ResizeM, dsize=self.input_size) 211 | 212 | # normalize image 213 | img_input = img_resize.astype(np.float32) / 255 214 | img_mean = np.array(self.img_mean, dtype=np.float32).reshape(1, 1, 3) / 255 215 | img_std = np.array(self.img_std, dtype=np.float32).reshape(1, 1, 3) / 255 216 | img_input = (img_input - img_mean) / img_std 217 | 218 | # expand dims 219 | img_input = np.transpose(img_input, [2, 0, 1]) 220 | img_input = np.expand_dims(img_input, axis=0) 221 | return img_input, ResizeM 222 | 223 | def postprocess(self, scores, raw_boxes, ResizeM, raw_shape): 224 | # generate centers 225 | decode_boxes = [] 226 | select_scores = [] 227 | for stride, box_distribute, score in zip(self.strides, raw_boxes, scores): 228 | # centers 229 | fm_h = self.input_shape[0] / stride 230 | fm_w = self.input_shape[1] / stride 231 | 232 | h_range = np.arange(fm_h) 233 | w_range = np.arange(fm_w) 234 | ww, hh = np.meshgrid(w_range, h_range) 235 | 236 | ct_row = hh.flatten() * stride 237 | ct_col = ww.flatten() * stride 238 | 239 | center = np.stack((ct_col, ct_row, ct_col, ct_row), axis=1) 240 | 241 | # box distribution to distance 242 | reg_range = np.arange(self.reg_max + 1) 243 | box_distance = box_distribute.reshape((-1, self.reg_max + 1)) 244 | box_distance = softmax(box_distance, axis=1) 245 | box_distance = box_distance * np.expand_dims(reg_range, axis=0) 246 | box_distance = np.sum(box_distance, axis=1).reshape((-1, 4)) 247 | box_distance = box_distance * stride 248 | 249 | # top K candidate 250 | topk_idx = np.argsort(score.max(axis=1))[::-1] 251 | topk_idx = topk_idx[: self.num_candidate] 252 | center = center[topk_idx] 253 | score = score[topk_idx] 254 | box_distance = box_distance[topk_idx] 255 | 256 | # decode box 257 | decode_box = center + [-1, -1, 1, 1] * box_distance 258 | 259 | select_scores.append(score) 260 | decode_boxes.append(decode_box) 261 | 262 | # nms 263 | bboxes = np.concatenate(decode_boxes, axis=0) 264 | confidences = np.concatenate(select_scores, axis=0) 265 | picked_box_probs = [] 266 | picked_labels = [] 267 | for class_index in range(0, confidences.shape[1]): 268 | probs = confidences[:, class_index] 269 | mask = probs > self.prob_threshold 270 | probs = probs[mask] 271 | if probs.shape[0] == 0: 272 | continue 273 | subset_boxes = bboxes[mask, :] 274 | box_probs = np.concatenate([subset_boxes, probs.reshape(-1, 1)], axis=1) 275 | box_probs = hard_nms( 276 | box_probs, 277 | iou_threshold=self.iou_threshold, 278 | 
top_k=self.top_k, 279 | ) 280 | picked_box_probs.append(box_probs) 281 | picked_labels.extend([class_index] * box_probs.shape[0]) 282 | if not picked_box_probs: 283 | return np.array([]), np.array([]), np.array([]) 284 | picked_box_probs = np.concatenate(picked_box_probs) 285 | 286 | # resize output boxes 287 | picked_box_probs[:, :4] = warp_boxes( 288 | picked_box_probs[:, :4], np.linalg.inv(ResizeM), raw_shape[1], raw_shape[0] 289 | ) 290 | return ( 291 | picked_box_probs[:, :4].astype(np.int32), 292 | np.array(picked_labels), 293 | picked_box_probs[:, 4], 294 | ) 295 | 296 | @abstractmethod 297 | def infer_image(self, img_input): 298 | pass 299 | 300 | def detect(self, img): 301 | raw_shape = img.shape 302 | img_input, ResizeM = self.preprocess(img) 303 | scores, raw_boxes = self.infer_image(img_input) 304 | if scores[0].ndim == 1: # handling num_classes=1 case 305 | scores = [x[:, None] for x in scores] 306 | bbox, label, score = self.postprocess(scores, raw_boxes, ResizeM, raw_shape) 307 | 308 | return bbox, label, score 309 | 310 | class HelloFaceDet(NanoDetABC): 311 | def __init__(self, model_path=osp.join(songkey_weights_dir, 'face/nanodet_humandet_320-192_220302_model_20220315_test3.onnx'), gpu_id=None, *args, **kwargs): 312 | super(HelloFaceDet, self).__init__(*args, **kwargs) 313 | # print("Using ONNX as inference backend") 314 | # print(f"Using weight: {model_path}") 315 | 316 | # load model 317 | self.model_path = model_path 318 | self.ort_session = create_onnx_session(model_path, gpu_id=gpu_id) 319 | self.input_name = self.ort_session.get_inputs()[0].name 320 | 321 | def infer_image(self, img_input): 322 | inference_results = self.ort_session.run(None, {self.input_name: img_input}) 323 | 324 | scores = [np.squeeze(x) for x in inference_results[:3]] 325 | raw_boxes = [np.squeeze(x) for x in inference_results[3:]] 326 | return scores, raw_boxes 327 | -------------------------------------------------------------------------------- /hellomeme/tools/pdf.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | @File : pdf.py 5 | @Author : Songkey 6 | @Email : songkey@pku.edu.cn 7 | @Date : 11/7/2024 8 | @Desc : Adapted from: https://github.com/Dorniwang/PD-FGC-inference/blob/main/lib/models/networks/encoder.py 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | from diffusers.models.modeling_utils import ModelMixin 16 | from diffusers.configuration_utils import ConfigMixin, register_to_config 17 | 18 | def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False): 19 | "3x3 convolution with padding" 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, 21 | stride=strd, padding=padding, bias=bias) 22 | 23 | class ConvBlock(nn.Module): 24 | def __init__(self, in_planes, out_planes): 25 | super(ConvBlock, self).__init__() 26 | self.bn1 = nn.BatchNorm2d(in_planes) 27 | self.conv1 = conv3x3(in_planes, int(out_planes / 2)) 28 | self.bn2 = nn.BatchNorm2d(int(out_planes / 2)) 29 | self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4)) 30 | self.bn3 = nn.BatchNorm2d(int(out_planes / 4)) 31 | self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4)) 32 | 33 | if in_planes != out_planes: 34 | self.downsample = nn.Sequential( 35 | nn.BatchNorm2d(in_planes), 36 | nn.ReLU(True), 37 | nn.Conv2d(in_planes, out_planes, 38 | kernel_size=1, stride=1, bias=False), 39 | ) 40 | else: 41 | self.downsample = None 42 | 43 | def forward(self, x): 44 | residual = x 45 | 
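        # Pre-activation dense block: three BN -> ReLU -> 3x3 conv stages (out_planes/2, out_planes/4,
        # out_planes/4 channels) are concatenated and added to the (optionally 1x1-projected) residual.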
46 | out1 = self.bn1(x) 47 | out1 = F.relu(out1, True) 48 | out1 = self.conv1(out1) 49 | 50 | out2 = self.bn2(out1) 51 | out2 = F.relu(out2, True) 52 | out2 = self.conv2(out2) 53 | 54 | out3 = self.bn3(out2) 55 | out3 = F.relu(out3, True) 56 | out3 = self.conv3(out3) 57 | 58 | out3 = torch.cat((out1, out2, out3), 1) 59 | 60 | if self.downsample is not None: 61 | residual = self.downsample(residual) 62 | 63 | out3 += residual 64 | 65 | return out3 66 | 67 | 68 | class HourGlass(nn.Module): 69 | def __init__(self, num_modules, depth, num_features): 70 | super(HourGlass, self).__init__() 71 | self.num_modules = num_modules 72 | self.depth = depth 73 | self.features = num_features 74 | self.dropout = nn.Dropout(0.5) 75 | 76 | self._generate_network(self.depth) 77 | 78 | def _generate_network(self, level): 79 | self.add_module('b1_' + str(level), ConvBlock(256, 256)) 80 | 81 | self.add_module('b2_' + str(level), ConvBlock(256, 256)) 82 | 83 | if level > 1: 84 | self._generate_network(level - 1) 85 | else: 86 | self.add_module('b2_plus_' + str(level), ConvBlock(256, 256)) 87 | 88 | self.add_module('b3_' + str(level), ConvBlock(256, 256)) 89 | 90 | def _forward(self, level, inp): 91 | # Upper branch 92 | up1 = inp 93 | up1 = self._modules['b1_' + str(level)](up1) 94 | up1 = self.dropout(up1) 95 | # Lower branch 96 | low1 = F.max_pool2d(inp, 2, stride=2) 97 | low1 = self._modules['b2_' + str(level)](low1) 98 | 99 | if level > 1: 100 | low2 = self._forward(level - 1, low1) 101 | else: 102 | low2 = low1 103 | low2 = self._modules['b2_plus_' + str(level)](low2) 104 | 105 | low3 = low2 106 | low3 = self._modules['b3_' + str(level)](low3) 107 | up1size = up1.size() 108 | rescale_size = (up1size[2], up1size[3]) 109 | up2 = F.interpolate(low3, size=rescale_size, mode='bilinear') 110 | 111 | return up1 + up2 112 | 113 | def forward(self, x): 114 | return self._forward(self.depth, x) 115 | 116 | class FAN_use(nn.Module): 117 | def __init__(self): 118 | super(FAN_use, self).__init__() 119 | self.num_modules = 1 120 | 121 | # Base part 122 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 123 | self.bn1 = nn.BatchNorm2d(64) 124 | self.conv2 = ConvBlock(64, 128) 125 | self.conv3 = ConvBlock(128, 128) 126 | self.conv4 = ConvBlock(128, 256) 127 | 128 | # Stacking part 129 | hg_module = 0 130 | self.add_module('m' + str(hg_module), HourGlass(1, 4, 256)) 131 | self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256)) 132 | self.add_module('conv_last' + str(hg_module), 133 | nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) 134 | self.add_module('l' + str(hg_module), nn.Conv2d(256, 135 | 68, kernel_size=1, stride=1, padding=0)) 136 | self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256)) 137 | 138 | if hg_module < self.num_modules - 1: 139 | self.add_module( 140 | 'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) 141 | self.add_module('al' + str(hg_module), nn.Conv2d(68, 142 | 256, kernel_size=1, stride=1, padding=0)) 143 | 144 | self.avgpool = nn.MaxPool2d((2, 2), 2) 145 | self.conv6 = nn.Conv2d(68, 1, 3, 2, 1) 146 | self.fc = nn.Linear(28 * 28, 512) 147 | self.bn5 = nn.BatchNorm2d(68) 148 | self.relu = nn.ReLU(True) 149 | 150 | def forward(self, x): 151 | x = F.relu(self.bn1(self.conv1(x)), True) 152 | x = F.max_pool2d(self.conv2(x), 2) 153 | x = self.conv3(x) 154 | x = self.conv4(x) 155 | 156 | previous = x 157 | 158 | i = 0 159 | hg = self._modules['m' + str(i)](previous) 160 | 161 | ll = hg 162 | ll = self._modules['top_m_' + str(i)](ll) 
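        # 'conv_last'/'bn_end' refine the hourglass features and the 'l' head
        # projects them to 68 heatmap channels (the standard 68 facial
        # landmarks); conv6 then collapses these to a single strided map that
        # is flattened and mapped to a 512-D feature by the final fc layer.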
163 | 164 | ll = self._modules['bn_end' + str(i)](self._modules['conv_last' + str(i)](ll)) 165 | tmp_out = self._modules['l' + str(i)](F.relu(ll)) 166 | 167 | net = self.relu(self.bn5(tmp_out)) 168 | net = self.conv6(net) 169 | net = net.view(-1, net.shape[-2] * net.shape[-1]) 170 | net = self.relu(net) 171 | net = self.fc(net) 172 | return net 173 | 174 | class FanEncoder(ModelMixin, ConfigMixin): 175 | @register_to_config 176 | def __init__(self, pose_dim=6, eye_dim=6): 177 | super().__init__() 178 | 179 | self.model = FAN_use() 180 | 181 | self.to_mouth = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.BatchNorm1d(512), nn.Linear(512, 512)) 182 | self.mouth_embed = nn.Sequential(nn.ReLU(), nn.Linear(512, 512 - pose_dim - eye_dim)) 183 | 184 | # self.to_headpose = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.BatchNorm1d(512), nn.Linear(512, 512)) 185 | # self.headpose_embed = nn.Sequential(nn.ReLU(), nn.Linear(512, pose_dim)) 186 | 187 | self.to_eye = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.BatchNorm1d(512), nn.Linear(512, 512)) 188 | self.eye_embed = nn.Sequential(nn.ReLU(), nn.Linear(512, eye_dim)) 189 | 190 | self.to_emo = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.BatchNorm1d(512), nn.Linear(512, 512)) 191 | self.emo_embed = nn.Sequential(nn.ReLU(), nn.Linear(512, 30)) 192 | 193 | def forward_feature(self, x): 194 | net = self.model(x) 195 | return net 196 | 197 | def forward(self, x): 198 | x = self.model(x) 199 | mouth_feat = self.to_mouth(x) 200 | # headpose_feat = self.to_headpose(x) 201 | # headpose_emb = self.headpose_embed(headpose_feat) 202 | eye_feat = self.to_eye(x) 203 | eye_embed = self.eye_embed(eye_feat) 204 | emo_feat = self.to_emo(x) 205 | emo_embed = self.emo_embed(emo_feat) 206 | 207 | return torch.cat([eye_embed, emo_embed, mouth_feat], dim=1) 208 | # return headpose_emb, eye_embed, emo_embed, mouth_feat 209 | -------------------------------------------------------------------------------- /hellomeme/tools/sr.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | @File : sr.py 5 | @Author : Songkey 6 | @Email : songkey@pku.edu.cn 7 | @Date : 5/30/2025 8 | @Desc : adapted from: https://github.com/xinntao/Real-ESRGAN 9 | """ 10 | 11 | import torch 12 | from torch import nn as nn 13 | from torch.nn import functional as F 14 | import cv2 15 | import numpy as np 16 | import math 17 | import os.path as osp 18 | 19 | def pixel_unshuffle(x, scale): 20 | """ Pixel unshuffle. 21 | 22 | Args: 23 | x (Tensor): Input feature with shape (b, c, hh, hw). 24 | scale (int): Downsample ratio. 25 | 26 | Returns: 27 | Tensor: the pixel unshuffled feature. 28 | """ 29 | b, c, hh, hw = x.size() 30 | out_channel = c * (scale**2) 31 | assert hh % scale == 0 and hw % scale == 0 32 | h = hh // scale 33 | w = hw // scale 34 | x_view = x.view(b, c, h, scale, w, scale) 35 | return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) 36 | 37 | def make_layer(basic_block, num_basic_block, **kwarg): 38 | """Make layers by stacking the same blocks. 39 | 40 | Args: 41 | basic_block (nn.module): nn.module class for basic block. 42 | num_basic_block (int): number of blocks. 43 | 44 | Returns: 45 | nn.Sequential: Stacked blocks in nn.Sequential. 46 | """ 47 | layers = [] 48 | for _ in range(num_basic_block): 49 | layers.append(basic_block(**kwarg)) 50 | return nn.Sequential(*layers) 51 | 52 | class ResidualDenseBlock(nn.Module): 53 | """Residual Dense Block. 54 | 55 | Used in RRDB block in ESRGAN. 
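    Each conv takes the concatenation of the block input and all previous
    conv outputs (dense connectivity); the last conv projects back to
    num_feat channels and is added to the input with a 0.2 residual scale.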
56 | 57 | Args: 58 | num_feat (int): Channel number of intermediate features. 59 | num_grow_ch (int): Channels for each growth. 60 | """ 61 | 62 | def __init__(self, num_feat=64, num_grow_ch=32): 63 | super(ResidualDenseBlock, self).__init__() 64 | self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) 65 | self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) 66 | self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1) 67 | self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1) 68 | self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) 69 | 70 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 71 | 72 | # initialization 73 | # default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) 74 | 75 | def forward(self, x): 76 | x1 = self.lrelu(self.conv1(x)) 77 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) 78 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 79 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) 80 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 81 | # Empirically, we use 0.2 to scale the residual for better performance 82 | return x5 * 0.2 + x 83 | 84 | 85 | class RRDB(nn.Module): 86 | """Residual in Residual Dense Block. 87 | 88 | Used in RRDB-Net in ESRGAN. 89 | 90 | Args: 91 | num_feat (int): Channel number of intermediate features. 92 | num_grow_ch (int): Channels for each growth. 93 | """ 94 | 95 | def __init__(self, num_feat, num_grow_ch=32): 96 | super(RRDB, self).__init__() 97 | self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) 98 | self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) 99 | self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) 100 | 101 | def forward(self, x): 102 | out = self.rdb1(x) 103 | out = self.rdb2(out) 104 | out = self.rdb3(out) 105 | # Empirically, we use 0.2 to scale the residual for better performance 106 | return out * 0.2 + x 107 | 108 | class RRDBNet(nn.Module): 109 | """Networks consisting of Residual in Residual Dense Block, which is used 110 | in ESRGAN. 111 | 112 | ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. 113 | 114 | We extend ESRGAN for scale x2 and scale x1. 115 | Note: This is one option for scale 1, scale 2 in RRDBNet. 116 | We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size 117 | and enlarge the channel size before feeding inputs into the main ESRGAN architecture. 118 | 119 | Args: 120 | num_in_ch (int): Channel number of inputs. 121 | num_out_ch (int): Channel number of outputs. 122 | num_feat (int): Channel number of intermediate features. 123 | Default: 64 124 | num_block (int): Block number in the trunk network. Defaults: 23 125 | num_grow_ch (int): Channels for each growth. Default: 32. 
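        scale (int): Overall upsampling factor (1, 2 or 4). The trunk always
            upsamples features by 4 (two nearest-neighbour x2 steps), so for
            scale 2 and scale 1 the input is first pixel-unshuffled by 2 and 4
            respectively. Default: 4.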
126 | """ 127 | 128 | def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32): 129 | super(RRDBNet, self).__init__() 130 | self.scale = scale 131 | if scale == 2: 132 | num_in_ch = num_in_ch * 4 133 | elif scale == 1: 134 | num_in_ch = num_in_ch * 16 135 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 136 | self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch) 137 | self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 138 | # upsample 139 | self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 140 | self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 141 | self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 142 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 143 | 144 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 145 | 146 | def forward(self, x): 147 | if self.scale == 2: 148 | feat = pixel_unshuffle(x, scale=2) 149 | elif self.scale == 1: 150 | feat = pixel_unshuffle(x, scale=4) 151 | else: 152 | feat = x 153 | feat = self.conv_first(feat) 154 | body_feat = self.conv_body(self.body(feat)) 155 | feat = feat + body_feat 156 | # upsample 157 | feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) 158 | feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) 159 | out = self.conv_last(self.lrelu(self.conv_hr(feat))) 160 | return out 161 | 162 | class RealESRGANer(): 163 | def __init__(self, 164 | scale, 165 | tile=0, 166 | tile_pad=10, 167 | pre_pad=10, 168 | half=True, 169 | device=None, 170 | gpu_id=None, 171 | modelscope=False): 172 | self.scale = scale 173 | self.tile_size = tile 174 | self.tile_pad = tile_pad 175 | self.pre_pad = pre_pad 176 | self.mod_scale = None 177 | self.half = half 178 | 179 | # initialize model 180 | if gpu_id: 181 | self.device = torch.device( 182 | f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device 183 | else: 184 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device 185 | 186 | if modelscope: 187 | from modelscope import snapshot_download 188 | model_path = osp.join(snapshot_download('songkey/ESRGAN'), 'RealESRGAN_x2plus.pth') 189 | else: 190 | from huggingface_hub import hf_hub_download 191 | model_path = hf_hub_download('songkey/ESRGAN', filename='RealESRGAN_x2plus.pth') 192 | 193 | loadnet = torch.load(model_path, map_location=torch.device('cpu')) 194 | 195 | # prefer to use params_ema 196 | if 'params_ema' in loadnet: 197 | keyname = 'params_ema' 198 | else: 199 | keyname = 'params' 200 | self.model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2) 201 | self.model.load_state_dict(loadnet[keyname], strict=True) 202 | 203 | self.model.eval() 204 | self.model = self.model.to(self.device) 205 | if self.half: 206 | self.model = self.model.half() 207 | 208 | def dni(self, net_a, net_b, dni_weight, key='params', loc='cpu'): 209 | """Deep network interpolation. 
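
        Blends two checkpoints parameter-wise:
        out[k] = dni_weight[0] * net_a[k] + dni_weight[1] * net_b[k]
        for every tensor k stored under ``key``.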
210 | 211 | ``Paper: Deep Network Interpolation for Continuous Imagery Effect Transition`` 212 | """ 213 | net_a = torch.load(net_a, map_location=torch.device(loc)) 214 | net_b = torch.load(net_b, map_location=torch.device(loc)) 215 | for k, v_a in net_a[key].items(): 216 | net_a[key][k] = dni_weight[0] * v_a + dni_weight[1] * net_b[key][k] 217 | return net_a 218 | 219 | def pre_process(self, img): 220 | """Pre-process, such as pre-pad and mod pad, so that the images can be divisible 221 | """ 222 | img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float() 223 | self.img = img.unsqueeze(0).to(self.device) 224 | if self.half: 225 | self.img = self.img.half() 226 | 227 | # pre_pad 228 | if self.pre_pad != 0: 229 | self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect') 230 | # mod pad for divisible borders 231 | if self.scale == 2: 232 | self.mod_scale = 2 233 | elif self.scale == 1: 234 | self.mod_scale = 4 235 | if self.mod_scale is not None: 236 | self.mod_pad_h, self.mod_pad_w = 0, 0 237 | _, _, h, w = self.img.size() 238 | if (h % self.mod_scale != 0): 239 | self.mod_pad_h = (self.mod_scale - h % self.mod_scale) 240 | if (w % self.mod_scale != 0): 241 | self.mod_pad_w = (self.mod_scale - w % self.mod_scale) 242 | self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect') 243 | 244 | def process(self): 245 | # model inference 246 | self.output = self.model(self.img) 247 | 248 | def tile_process(self): 249 | """It will first crop input images to tiles, and then process each tile. 250 | Finally, all the processed tiles are merged into one images. 251 | 252 | Modified from: https://github.com/ata4/esrgan-launcher 253 | """ 254 | batch, channel, height, width = self.img.shape 255 | output_height = height * self.scale 256 | output_width = width * self.scale 257 | output_shape = (batch, channel, output_height, output_width) 258 | 259 | # start with black image 260 | self.output = self.img.new_zeros(output_shape) 261 | tiles_x = math.ceil(width / self.tile_size) 262 | tiles_y = math.ceil(height / self.tile_size) 263 | 264 | # loop over all tiles 265 | for y in range(tiles_y): 266 | for x in range(tiles_x): 267 | # extract tile from input image 268 | ofs_x = x * self.tile_size 269 | ofs_y = y * self.tile_size 270 | # input tile area on total image 271 | input_start_x = ofs_x 272 | input_end_x = min(ofs_x + self.tile_size, width) 273 | input_start_y = ofs_y 274 | input_end_y = min(ofs_y + self.tile_size, height) 275 | 276 | # input tile area on total image with padding 277 | input_start_x_pad = max(input_start_x - self.tile_pad, 0) 278 | input_end_x_pad = min(input_end_x + self.tile_pad, width) 279 | input_start_y_pad = max(input_start_y - self.tile_pad, 0) 280 | input_end_y_pad = min(input_end_y + self.tile_pad, height) 281 | 282 | # input tile dimensions 283 | input_tile_width = input_end_x - input_start_x 284 | input_tile_height = input_end_y - input_start_y 285 | tile_idx = y * tiles_x + x + 1 286 | input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad] 287 | 288 | # upscale tile 289 | try: 290 | with torch.no_grad(): 291 | output_tile = self.model(input_tile) 292 | except RuntimeError as error: 293 | print('Error', error) 294 | print(f'\tTile {tile_idx}/{tiles_x * tiles_y}') 295 | 296 | # output tile area on total image 297 | output_start_x = input_start_x * self.scale 298 | output_end_x = input_end_x * self.scale 299 | output_start_y = input_start_y * self.scale 300 | output_end_y = input_end_y * self.scale 
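                # the model was run on the padded tile, so the block below
                # crops the padding margins (scaled by self.scale) before
                # writing the tile back into the full-resolution canvas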
301 | 302 | # output tile area without padding 303 | output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale 304 | output_end_x_tile = output_start_x_tile + input_tile_width * self.scale 305 | output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale 306 | output_end_y_tile = output_start_y_tile + input_tile_height * self.scale 307 | 308 | # put tile into output image 309 | self.output[:, :, output_start_y:output_end_y, 310 | output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile, 311 | output_start_x_tile:output_end_x_tile] 312 | 313 | def post_process(self): 314 | # remove extra pad 315 | if self.mod_scale is not None: 316 | _, _, h, w = self.output.size() 317 | self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale] 318 | # remove prepad 319 | if self.pre_pad != 0: 320 | _, _, h, w = self.output.size() 321 | self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale] 322 | return self.output 323 | 324 | @torch.no_grad() 325 | def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'): 326 | h_input, w_input = img.shape[0:2] 327 | # img: numpy 328 | img = img.astype(np.float32) 329 | if np.max(img) > 256: # 16-bit image 330 | max_range = 65535 331 | print('\tInput is a 16-bit image') 332 | else: 333 | max_range = 255 334 | img = img / max_range 335 | if len(img.shape) == 2: # gray image 336 | img_mode = 'L' 337 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) 338 | elif img.shape[2] == 4: # RGBA image with alpha channel 339 | img_mode = 'RGBA' 340 | alpha = img[:, :, 3] 341 | img = img[:, :, 0:3] 342 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 343 | if alpha_upsampler == 'realesrgan': 344 | alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB) 345 | else: 346 | img_mode = 'RGB' 347 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 348 | 349 | # ------------------- process image (without the alpha channel) ------------------- # 350 | self.pre_process(img) 351 | if self.tile_size > 0: 352 | self.tile_process() 353 | else: 354 | self.process() 355 | output_img = self.post_process() 356 | output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy() 357 | output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0)) 358 | if img_mode == 'L': 359 | output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY) 360 | 361 | # ------------------- process the alpha channel if necessary ------------------- # 362 | if img_mode == 'RGBA': 363 | if alpha_upsampler == 'realesrgan': 364 | self.pre_process(alpha) 365 | if self.tile_size > 0: 366 | self.tile_process() 367 | else: 368 | self.process() 369 | output_alpha = self.post_process() 370 | output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy() 371 | output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0)) 372 | output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY) 373 | else: # use the cv2 resize for alpha channel 374 | h, w = alpha.shape[0:2] 375 | output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR) 376 | 377 | # merge the alpha channel 378 | output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA) 379 | output_img[:, :, 3] = output_alpha 380 | 381 | # ------------------------------ return ------------------------------ # 382 | if max_range == 65535: # 16-bit image 383 | output = (output_img * 65535.0).round().astype(np.uint16) 384 | else: 385 | output = (output_img * 255.0).round().astype(np.uint8) 386 | 
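        # the network output is already upscaled by self.scale; if the caller
        # requested a different outscale, resize the result with Lanczos
        # interpolation to match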
387 | if outscale is not None and outscale != float(self.scale): 388 | output = cv2.resize( 389 | output, ( 390 | int(w_input * outscale), 391 | int(h_input * outscale), 392 | ), interpolation=cv2.INTER_LANCZOS4) 393 | 394 | return output, img_mode -------------------------------------------------------------------------------- /hellomeme/tools/utils.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | # @File : utils.py 4 | # @Author : Songkey 5 | # @Email : songkey@pku.edu.cn 6 | # @Date : 8/18/2024 7 | # @Desc : 8 | 9 | import onnx, onnxruntime 10 | import time 11 | import cv2 12 | import numpy as np 13 | import math 14 | 15 | def create_onnx_session(onnx_path, gpu_id=None)->onnxruntime.InferenceSession: 16 | start = time.perf_counter() 17 | onnx_model = onnx.load(onnx_path) 18 | onnx.checker.check_model(onnx_model) 19 | providers = [ 20 | ('CUDAExecutionProvider', { 21 | 'device_id': int(gpu_id), 22 | 'arena_extend_strategy': 'kNextPowerOfTwo', 23 | #'cuda_mem_limit': 5 * 1024 * 1024 * 1024, 24 | 'cudnn_conv_algo_search': 'EXHAUSTIVE', 25 | 'do_copy_in_default_stream': True, 26 | }), 27 | 'CPUExecutionProvider', 28 | ] if (gpu_id is not None and gpu_id >= 0) else ['CPUExecutionProvider'] 29 | 30 | sess = onnxruntime.InferenceSession(onnx_path, providers=providers) 31 | print('create onnx session cost: {:.3f}s. {}'.format(time.perf_counter() - start, onnx_path)) 32 | return sess 33 | 34 | def smoothing_factor(t_e, cutoff): 35 | r = 2 * math.pi * cutoff * t_e 36 | return r / (r + 1) 37 | 38 | def exponential_smoothing(a, x, x_prev): 39 | return a * x + (1 - a) * x_prev 40 | 41 | class OneEuroFilter: 42 | def __init__(self, dx0=0.0, d_cutoff=1.0): 43 | """Initialize the one euro filter.""" 44 | # self.min_cutoff = float(min_cutoff) 45 | # self.beta = float(beta) 46 | self.d_cutoff = float(d_cutoff) 47 | self.dx_prev = float(dx0) 48 | # self.t_e = fcmin 49 | 50 | def __call__(self, x, x_prev, fcmin=1.0, min_cutoff=1.0, beta=0.0): 51 | if x_prev is None: 52 | return x 53 | # t_e = 1 54 | a_d = smoothing_factor(fcmin, self.d_cutoff) 55 | dx = (x - x_prev) / fcmin 56 | dx_hat = exponential_smoothing(a_d, dx, self.dx_prev) 57 | cutoff = min_cutoff + beta * abs(dx_hat) 58 | a = smoothing_factor(fcmin, cutoff) 59 | x_hat = exponential_smoothing(a, x, x_prev) 60 | self.dx_prev = dx_hat 61 | return x_hat 62 | 63 | def get_warp_mat_bbox(face_bbox, base_angle, dst_size=128, expand_ratio=0.15, aug_angle=0.0, aug_scale=1.0): 64 | face_x_min, face_y_min, face_x_max, face_y_max = face_bbox 65 | face_x_center = (face_x_min + face_x_max) / 2 66 | face_y_center = (face_y_min + face_y_max) / 2 67 | face_width = face_x_max - face_x_min 68 | face_height = face_y_max - face_y_min 69 | scale = dst_size / max(face_width, face_height) * (1 - expand_ratio) * aug_scale 70 | M = cv2.getRotationMatrix2D((face_x_center, face_y_center), angle=base_angle + aug_angle, scale=scale) 71 | offset = [dst_size / 2 - face_x_center, dst_size / 2 - face_y_center] 72 | M[:, 2] += offset 73 | return M 74 | 75 | def transform_points(points, mat, invert=False): 76 | if invert: 77 | mat = cv2.invertAffineTransform(mat) 78 | points = np.expand_dims(points, axis=1) 79 | points = cv2.transform(points, mat, points.shape) 80 | points = np.squeeze(points) 81 | return points 82 | 83 | def get_warp_mat_bbox_by_gt_pts_float(gt_pts, base_angle=0.0, dst_size=128, expand_ratio=0.15, return_info=False): 84 | # step 1 85 | face_x_min, face_x_max = np.min(gt_pts[:, 0]), np.max(gt_pts[:, 0]) 
86 | face_y_min, face_y_max = np.min(gt_pts[:, 1]), np.max(gt_pts[:, 1]) 87 | face_x_center = (face_x_min + face_x_max) / 2 88 | face_y_center = (face_y_min + face_y_max) / 2 89 | M_step_1 = cv2.getRotationMatrix2D((face_x_center, face_y_center), angle=base_angle, scale=1.0) 90 | pts_step_1 = transform_points(gt_pts, M_step_1) 91 | face_x_min_step_1, face_x_max_step_1 = np.min(pts_step_1[:, 0]), np.max(pts_step_1[:, 0]) 92 | face_y_min_step_1, face_y_max_step_1 = np.min(pts_step_1[:, 1]), np.max(pts_step_1[:, 1]) 93 | # step 2 94 | face_width = face_x_max_step_1 - face_x_min_step_1 95 | face_height = face_y_max_step_1 - face_y_min_step_1 96 | scale = dst_size / max(face_width, face_height) * (1 - expand_ratio) 97 | M_step_2 = cv2.getRotationMatrix2D((face_x_center, face_y_center), angle=base_angle, scale=scale) 98 | pts_step_2 = transform_points(gt_pts, M_step_2) 99 | face_x_min_step_2, face_x_max_step_2 = np.min(pts_step_2[:, 0]), np.max(pts_step_2[:, 0]) 100 | face_y_min_step_2, face_y_max_step_2 = np.min(pts_step_2[:, 1]), np.max(pts_step_2[:, 1]) 101 | face_x_center_step_2 = (face_x_min_step_2 + face_x_max_step_2) / 2 102 | face_y_center_step_2 = (face_y_min_step_2 + face_y_max_step_2) / 2 103 | 104 | M = cv2.getRotationMatrix2D((face_x_center, face_y_center), angle=base_angle, scale=scale) 105 | offset = [dst_size / 2 - face_x_center_step_2, dst_size / 2 - face_y_center_step_2] 106 | M[:, 2] += offset 107 | 108 | if not return_info: 109 | return M 110 | else: 111 | transform_info = { 112 | "M": M, 113 | "center_x": face_x_center, 114 | "center_y": face_y_center, 115 | "rotate_angle": base_angle, 116 | "scale": scale 117 | } 118 | return transform_info 119 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui_hellomeme" 3 | description = "This repository is the official implementation of the [a/HelloMeme](https://arxiv.org/pdf/2410.22901) ComfyUI interface" 4 | version = "1.3.11" 5 | license = {file = "LICENSE"} 6 | 7 | [project.urls] 8 | Repository = "https://github.com/HelloVision/ComfyUI_HelloMeme" 9 | # Used by Comfy Registry https://comfyregistry.org 10 | 11 | [tool.comfy] 12 | PublisherId = "hellomeme-api" 13 | DisplayName = "hellomeme-api" 14 | Icon = "https://github.com/HelloVision/ComfyUI_HelloMeme/blob/eee70819a114a44c4c689c865b42fa0e523a31e5/examples/helloicon.png" 15 | --------------------------------------------------------------------------------