├── .github └── workflows │ └── publish.yml ├── LICENSE ├── README.md ├── __init__.py ├── configs ├── controlnet-sdxl-promax │ └── config_promax.json └── sdxl │ ├── scheduler │ └── scheduler_config.json │ └── unet │ └── config.json ├── controlnet_union.py ├── diffusers-image-outpaint-workflow.json ├── nodes.py ├── pipeline_fill_sd_xl.py ├── pyproject.toml ├── requirements.txt └── utils.py /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - "pyproject.toml" 9 | 10 | permissions: 11 | issues: write 12 | 13 | jobs: 14 | publish-node: 15 | name: Publish Custom Node to registry 16 | runs-on: ubuntu-latest 17 | if: ${{ github.repository_owner == 'GiusTex' }} 18 | steps: 19 | - name: Check out code 20 | uses: actions/checkout@v4 21 | - name: Publish Custom Node 22 | uses: Comfy-Org/publish-node-action@v1 23 | with: 24 | ## Add your own personal access token to your Github Repository secrets and reference it here. 25 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ComfyUI nodes for outpainting images with diffusers, based on [diffusers-image-outpaint](https://huggingface.co/spaces/fffiloni/diffusers-image-outpaint/tree/main) by fffiloni. 2 | 3 | ![image](https://github.com/user-attachments/assets/1a02c2d1-f24e-4ad2-acdc-a2cbb15a1f14) 4 | 5 | #### Updates: 6 | - 15/05/2025: Fixed `missing 'loaded_keys'` error. More details below. 
7 | - 17/11/2024: 8 | - Added more options to the Pad Image node (resize image, custom resize image percentage, mask overlap percentage, overlap left/right/top/bottom). 9 | - Side notes: 10 | - Images with round corners now work, since the new editable mask covers them, like in the original huggingface space. 11 | - You can use the "mask" and "diffusers outpaint cnet image" outputs to preview the mask and image. 12 | - The same [workflow file](https://github.com/GiusTex/ComfyUI-DiffusersImageOutpaint/blob/New-Pad-Node-Options/Diffusers-Outpaint-DoubleWorkflow.json) contains a workflow with the checkpoint-loader-simple node and another one with clip + vae loader nodes. 13 | - 22/10/2024: 14 | - Unet and Controlnet model loading through ComfyUI nodes has been canceled, since I can't find a way to load them properly; more info at the end. 15 | - Added a guide on changing the model used. 16 | - 20/10/2024: No more need to download tokenizers or text encoders! The comfyui clip loader now works, so you can use your own clip models. You can also use the Checkpoint Loader Simple node to skip the clip selection part. 17 | - 10/2024: You no longer need the diffusers vae, and you can run the extension in low vram mode using `sequential_cpu_offload` (also thanks to [zmwv823](https://github.com/GiusTex/ComfyUI-DiffusersImageOutpaint/pull/4)), which pushes the vram usage from *8.3 gb* down to **_6 gb_**. 18 | 19 | ## Installation 20 | - Download this extension or `git clone` it into comfyui/custom_nodes; then, if comfyui-manager didn't already install the requirements or you have missing modules, run `cd your/path/to/this/extension` and `pip install -r requirements.txt` from the comfyui virtual env. 21 | - Download an sdxl model ([example](https://huggingface.co/SG161222/RealVisXL_V5.0_Lightning/blob/main/unet/diffusion_pytorch_model.fp16.safetensors)) into comfyui/models/diffusion_models; 22 | - Download an sdxl controlnet model ([example](https://huggingface.co/xinsir/controlnet-union-sdxl-1.0/blob/main/diffusion_pytorch_model_promax.safetensors)) into comfyui/models/controlnet. 23 | 24 | **⚠ Choosing model and controlnet**: So far I have only tried `RealVisXL_V5.0_Lightning` and `controlnet-union-promax_sdxl`. Mixing RealVisXL with controlnet-union (the non-promax version) gave an error, so other model/controlnet combinations may fail too, but I haven't tried many of them, so I can't say. A quick check of the expected file layout is sketched below. 25 | 26 |
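If you want to double-check that everything landed in the right place, here is a minimal sketch of a layout check. It assumes a default ComfyUI folder structure and the example filenames from the links above; adjust both to your setup.

```python
# Minimal sketch: verify the files described in the Installation steps are in place.
# Assumes the default ComfyUI layout and the example filenames linked above.
from pathlib import Path

comfy = Path("ComfyUI")  # adjust to your ComfyUI install path
expected = [
    comfy / "custom_nodes" / "ComfyUI-DiffusersImageOutpaint" / "requirements.txt",
    comfy / "models" / "diffusion_models" / "diffusion_pytorch_model.fp16.safetensors",
    comfy / "models" / "controlnet" / "diffusion_pytorch_model_promax.safetensors",
]
for path in expected:
    print(("OK      " if path.exists() else "MISSING ") + str(path))
```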
27 | <details> <summary>Some considerations</summary> 28 | 29 | Flux is still beyond me (even though I think I was quite close). I haven't tried integrating other model types, and after my Flux failure I don't think I'll try adding any others. 30 | 31 | Since only sdxl models work for now, the configs are hardcoded. 32 | 33 | </details>
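For context on why the hardcoded configs matter: a bare unet `.safetensors` download carries no config of its own, so the architecture has to come from the bundled `configs/sdxl/unet/config.json`. The sketch below only illustrates that idea with plain diffusers calls; the extension's actual loading code lives in `nodes.py` and may differ, and the checkpoint path is just the example file from the Installation section.

```python
# Illustrative sketch only: pair a bare SDXL unet checkpoint with the bundled config.
# The real node implementation (nodes.py) may load models differently.
import json
import torch
from diffusers import UNet2DConditionModel
from safetensors.torch import load_file

with open("configs/sdxl/unet/config.json") as f:
    unet_config = json.load(f)

# Build the architecture from the hardcoded config, then fill it with the downloaded weights.
unet = UNet2DConditionModel.from_config(unet_config)
state_dict = load_file("ComfyUI/models/diffusion_models/diffusion_pytorch_model.fp16.safetensors")
unet.load_state_dict(state_dict)
unet = unet.to(dtype=torch.float16)
```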
34 | 35 | - (Dual) Clip Loader node: if you use the Clip Loader instead of Checkpoint Loader Simple and want to use an `sdxl type` model like RealVisXL_V5.0_Lightning, you can download `clip_l` and `clip_g` from [here](https://huggingface.co/Comfy-Org/stable-diffusion-3.5-fp8/tree/main/text_encoders). You can use [this workflow](https://github.com/GiusTex/ComfyUI-DiffusersImageOutpaint/blob/New-Pad-Node-Options/Diffusers-Outpaint-DoubleWorkflow.json) (replace model.fp16 with `clip_g`). 36 | 37 | ## Overview 38 | - **Minimum VRAM**: 6 gb with a 1280x720 image, rtx 3060, RealVisXL_V5.0_Lightning, sdxl-vae-fp16-fix and controlnet-union-sdxl-promax using `sequential_cpu_offload`, otherwise 8.3 gb; 39 | - ~As seen in [this issue](https://github.com/GiusTex/ComfyUI-DiffusersImageOutpaint/issues/7#issuecomment-2410852908), images with **square corners** are required~. 40 | 41 | The extension provides 5 nodes: 42 | - **Load Diffuser Model**: a simple node to load diffusion `models`. You can download them from Huggingface (the extension doesn't download them automatically). Put them inside the `diffusion_models` folder. It supports only unets; for whole checkpoints (unet+clip+vae) you can use the checkpoint loader simple by comfyanonymous; 43 | - **Load Diffuser Controlnet**: a simple node to load `controlnet` models. You can download them from Huggingface (the extension doesn't download them automatically). Put them inside the `controlnet` folder; 44 | - **Pad Image for Diffusers Outpaint**: this node resizes the image based on the specified `width` and `height`, then resizes it again based on the `resize_image` percentage; if possible it places the mask based on the specified `alignment`, otherwise it reverts to the default "middle" `alignment`; 45 | - **Encode Diffusers Outpaint Prompt**: self-explanatory. Works like `clip text encode (prompt)` and specifies what to add to the image; 46 | - **Diffusers Image Outpaint**: this is the main node, which outpaints the image. Currently the generation process is based on fffiloni's, so you can't reproduce a specific outpaint, and the `seed` option you see is only used to update the UI and generate a new image. You can specify the number of `steps` used to generate the image. 47 | 48 | You _can_ also pass the image and mask to the `vae encode (for inpainting)` node and then pass the latent to a `sampler`, but controlnets and ip-adapters won't always give results as good as diffusers outpaint, and they require a different workflow, not covered by this extension. 49 | 50 | ## Missing 'loaded_keys' error 51 | Recent versions of `transformers` and `diffusers` broke some things, so you need to revert to older versions. Here is a command with some working versions (found [here](https://huggingface.co/spaces/fffiloni/diffusers-image-outpaint/blob/main/requirements.txt)); run it inside your comfyui env: `pip install transformers==4.45.0 --upgrade diffusers==0.32.2 --upgrade`, or, if you use the portable version, run this in the ComfyUI_windows_portable folder: 52 | `python_embeded\python.exe -m pip install transformers==4.45.0 --upgrade diffusers==0.32.2 --upgrade`.
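After downgrading, you can quickly confirm that the pinned versions are the ones actually being imported (run this with the same Python environment ComfyUI uses):

```python
# Print the installed versions; they should match the pins above
# (transformers 4.45.0, diffusers 0.32.2).
from importlib.metadata import version

for pkg in ("transformers", "diffusers"):
    print(pkg, version(pkg))
```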
53 | 54 | ## Credits 55 | diffusers-image-outpaint by [fffiloni](https://huggingface.co/spaces/fffiloni/diffusers-image-outpaint/tree/main) 56 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .nodes import (PadImageForDiffusersOutpaint, LoadDiffuserModel, LoadDiffuserControlnet, EncodeDiffusersOutpaintPrompt, DiffusersImageOutpaint) 2 | 3 | NODE_CLASS_MAPPINGS = { 4 | "PadImageForDiffusersOutpaint": PadImageForDiffusersOutpaint, 5 | "LoadDiffuserModel": LoadDiffuserModel, 6 | "LoadDiffuserControlnet": LoadDiffuserControlnet, 7 | "EncodeDiffusersOutpaintPrompt": EncodeDiffusersOutpaintPrompt, 8 | "DiffusersImageOutpaint": DiffusersImageOutpaint 9 | } 10 | 11 | NODE_DISPLAY_NAME_MAPPINGS = { 12 | "PadImageForDiffusersOutpaint": "Pad Image For Diffusers Outpaint", 13 | "LoadDiffuserModel": "Load Diffuser Model", 14 | "LoadDiffuserControlnet": "Load Diffuser Controlnet", 15 | "EncodeDiffusersOutpaintPrompt": "Encode Diffusers Outpaint Prompt", 16 | "DiffusersImageOutpaint": "Diffusers Image Outpaint" 17 | } 18 | -------------------------------------------------------------------------------- /configs/controlnet-sdxl-promax/config_promax.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "ControlNetModel", 3 | "_diffusers_version": "0.20.0.dev0", 4 | "act_fn": "silu", 5 | "addition_embed_type": "text_time", 6 | "addition_embed_type_num_heads": 64, 7 | "addition_time_embed_dim": 256, 8 | "attention_head_dim": [ 9 | 5, 10 | 10, 11 | 20 12 | ], 13 | "block_out_channels": [ 14 | 320, 15 | 640, 16 | 1280 17 | ], 18 | "class_embed_type": null, 19 | "conditioning_channels": 3, 20 | "conditioning_embedding_out_channels": [ 21 | 16, 22 | 32, 23 | 96, 24 | 256 25 | ], 26 | "controlnet_conditioning_channel_order": "rgb", 27 | "cross_attention_dim": 2048, 28 | "down_block_types": [ 29 | "DownBlock2D", 30 | "CrossAttnDownBlock2D", 31 | "CrossAttnDownBlock2D" 32 | ], 33 | "downsample_padding": 1, 34 | "encoder_hid_dim": null, 35 | "encoder_hid_dim_type": null, 36 | "flip_sin_to_cos": true, 37 | "freq_shift": 0, 38 | "global_pool_conditions": false, 39 | "in_channels": 4, 40 | "layers_per_block": 2, 41 | "mid_block_scale_factor": 1, 42 | "norm_eps": 1e-05, 43 | "norm_num_groups": 32, 44 | "num_attention_heads": null, 45 | "num_class_embeds": null, 46 | "only_cross_attention": false, 47 | "projection_class_embeddings_input_dim": 2816, 48 | "resnet_time_scale_shift": "default", 49 | "transformer_layers_per_block": [ 50 | 1, 51 | 2, 52 | 10 53 | ], 54 | "upcast_attention": null, 55 | "use_linear_projection": true, 56 | "num_control_type": 8 57 | } 58 | -------------------------------------------------------------------------------- /configs/sdxl/scheduler/scheduler_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "DDIMScheduler", 3 | "_diffusers_version": "0.30.0.dev0", 4 | "beta_end": 0.012, 5 | "beta_schedule": "scaled_linear", 6 | "beta_start": 0.00085, 7 | "clip_sample": false, 8 | "clip_sample_range": 1.0, 9 | "dynamic_thresholding_ratio": 0.995, 10 | "num_train_timesteps": 1000, 11 | "prediction_type": "epsilon", 12 | "rescale_betas_zero_snr": false, 13 | "sample_max_value": 1.0, 14 | "set_alpha_to_one": false, 15 | "steps_offset": 1, 16 | "thresholding": false, 17 | "timestep_spacing": "leading", 18 | "trained_betas": null 19 | } 20 | 
-------------------------------------------------------------------------------- /configs/sdxl/unet/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "UNet2DConditionModel", 3 | "act_fn": "silu", 4 | "addition_embed_type": "text_time", 5 | "addition_embed_type_num_heads": 64, 6 | "addition_time_embed_dim": 256, 7 | "attention_head_dim": [ 8 | 5, 9 | 10, 10 | 20 11 | ], 12 | "attention_type": "default", 13 | "block_out_channels": [ 14 | 320, 15 | 640, 16 | 1280 17 | ], 18 | "center_input_sample": false, 19 | "class_embed_type": null, 20 | "class_embeddings_concat": false, 21 | "conv_in_kernel": 3, 22 | "conv_out_kernel": 3, 23 | "cross_attention_dim": 2048, 24 | "cross_attention_norm": null, 25 | "down_block_types": [ 26 | "DownBlock2D", 27 | "CrossAttnDownBlock2D", 28 | "CrossAttnDownBlock2D" 29 | ], 30 | "downsample_padding": 1, 31 | "dropout": 0.0, 32 | "dual_cross_attention": false, 33 | "encoder_hid_dim": null, 34 | "encoder_hid_dim_type": null, 35 | "flip_sin_to_cos": true, 36 | "freq_shift": 0, 37 | "in_channels": 4, 38 | "layers_per_block": 2, 39 | "mid_block_only_cross_attention": null, 40 | "mid_block_scale_factor": 1, 41 | "mid_block_type": "UNetMidBlock2DCrossAttn", 42 | "norm_eps": 1e-05, 43 | "norm_num_groups": 32, 44 | "num_attention_heads": null, 45 | "num_class_embeds": null, 46 | "only_cross_attention": false, 47 | "out_channels": 4, 48 | "projection_class_embeddings_input_dim": 2816, 49 | "resnet_out_scale_factor": 1.0, 50 | "resnet_skip_time_act": false, 51 | "resnet_time_scale_shift": "default", 52 | "reverse_transformer_layers_per_block": null, 53 | "sample_size": 128, 54 | "time_cond_proj_dim": null, 55 | "time_embedding_act_fn": null, 56 | "time_embedding_dim": null, 57 | "time_embedding_type": "positional", 58 | "timestep_post_act": null, 59 | "transformer_layers_per_block": [ 60 | 1, 61 | 2, 62 | 10 63 | ], 64 | "up_block_types": [ 65 | "CrossAttnUpBlock2D", 66 | "CrossAttnUpBlock2D", 67 | "UpBlock2D" 68 | ], 69 | "upcast_attention": false, 70 | "use_linear_projection": true 71 | } 72 | -------------------------------------------------------------------------------- /controlnet_union.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from collections import OrderedDict 15 | from dataclasses import dataclass 16 | from typing import Any, Dict, List, Optional, Tuple, Union 17 | 18 | import torch 19 | from diffusers.configuration_utils import ConfigMixin, register_to_config 20 | from diffusers.loaders import FromOriginalModelMixin 21 | from diffusers.models.attention_processor import ( 22 | ADDED_KV_ATTENTION_PROCESSORS, 23 | CROSS_ATTENTION_PROCESSORS, 24 | AttentionProcessor, 25 | AttnAddedKVProcessor, 26 | AttnProcessor, 27 | ) 28 | from diffusers.models.embeddings import ( 29 | TextImageProjection, 30 | TextImageTimeEmbedding, 31 | TextTimeEmbedding, 32 | TimestepEmbedding, 33 | Timesteps, 34 | ) 35 | from diffusers.models.modeling_utils import ModelMixin 36 | from diffusers.models.unets.unet_2d_blocks import ( 37 | CrossAttnDownBlock2D, 38 | DownBlock2D, 39 | UNetMidBlock2DCrossAttn, 40 | get_down_block, 41 | ) 42 | from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel 43 | from diffusers.utils import BaseOutput, logging 44 | from torch import nn 45 | from torch.nn import functional as F 46 | 47 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 48 | 49 | 50 | # Transformer Block 51 | # Used to exchange info between different conditions and input image 52 | # With reference to https://github.com/TencentARC/T2I-Adapter/blob/SD/ldm/modules/encoders/adapter.py#L147 53 | class QuickGELU(nn.Module): 54 | def forward(self, x: torch.Tensor): 55 | return x * torch.sigmoid(1.702 * x) 56 | 57 | 58 | class LayerNorm(nn.LayerNorm): 59 | """Subclass torch's LayerNorm to handle fp16.""" 60 | 61 | def forward(self, x: torch.Tensor): 62 | orig_type = x.dtype 63 | ret = super().forward(x) 64 | return ret.type(orig_type) 65 | 66 | 67 | class ResidualAttentionBlock(nn.Module): 68 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 69 | super().__init__() 70 | 71 | self.attn = nn.MultiheadAttention(d_model, n_head) 72 | self.ln_1 = LayerNorm(d_model) 73 | self.mlp = nn.Sequential( 74 | OrderedDict( 75 | [ 76 | ("c_fc", nn.Linear(d_model, d_model * 4)), 77 | ("gelu", QuickGELU()), 78 | ("c_proj", nn.Linear(d_model * 4, d_model)), 79 | ] 80 | ) 81 | ) 82 | self.ln_2 = LayerNorm(d_model) 83 | self.attn_mask = attn_mask 84 | 85 | def attention(self, x: torch.Tensor): 86 | self.attn_mask = ( 87 | self.attn_mask.to(dtype=x.dtype, device=x.device) 88 | if self.attn_mask is not None 89 | else None 90 | ) 91 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 92 | 93 | def forward(self, x: torch.Tensor): 94 | x = x + self.attention(self.ln_1(x)) 95 | x = x + self.mlp(self.ln_2(x)) 96 | return x 97 | 98 | 99 | # ----------------------------------------------------------------------------------------------------- 100 | 101 | 102 | @dataclass 103 | class ControlNetOutput(BaseOutput): 104 | """ 105 | The output of [`ControlNetModel`]. 106 | 107 | Args: 108 | down_block_res_samples (`tuple[torch.Tensor]`): 109 | A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should 110 | be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be 111 | used to condition the original UNet's downsampling activations. 112 | mid_down_block_re_sample (`torch.Tensor`): 113 | The activation of the midde block (the lowest sample resolution). Each tensor should be of shape 114 | `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`. 
115 | Output can be used to condition the original UNet's middle block activation. 116 | """ 117 | 118 | down_block_res_samples: Tuple[torch.Tensor] 119 | mid_block_res_sample: torch.Tensor 120 | 121 | 122 | class ControlNetConditioningEmbedding(nn.Module): 123 | """ 124 | Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN 125 | [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized 126 | training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the 127 | convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides 128 | (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full 129 | model) to encode image-space conditions ... into feature maps ..." 130 | """ 131 | 132 | # original setting is (16, 32, 96, 256) 133 | def __init__( 134 | self, 135 | conditioning_embedding_channels: int, 136 | conditioning_channels: int = 3, 137 | block_out_channels: Tuple[int] = (48, 96, 192, 384), 138 | ): 139 | super().__init__() 140 | 141 | self.conv_in = nn.Conv2d( 142 | conditioning_channels, block_out_channels[0], kernel_size=3, padding=1 143 | ) 144 | 145 | self.blocks = nn.ModuleList([]) 146 | 147 | for i in range(len(block_out_channels) - 1): 148 | channel_in = block_out_channels[i] 149 | channel_out = block_out_channels[i + 1] 150 | self.blocks.append( 151 | nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1) 152 | ) 153 | self.blocks.append( 154 | nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2) 155 | ) 156 | 157 | self.conv_out = zero_module( 158 | nn.Conv2d( 159 | block_out_channels[-1], 160 | conditioning_embedding_channels, 161 | kernel_size=3, 162 | padding=1, 163 | ) 164 | ) 165 | 166 | def forward(self, conditioning): 167 | embedding = self.conv_in(conditioning) 168 | embedding = F.silu(embedding) 169 | 170 | for block in self.blocks: 171 | embedding = block(embedding) 172 | embedding = F.silu(embedding) 173 | 174 | embedding = self.conv_out(embedding) 175 | 176 | return embedding 177 | 178 | 179 | class ControlNetModel_Union(ModelMixin, ConfigMixin, FromOriginalModelMixin): 180 | """ 181 | A ControlNet model. 182 | 183 | Args: 184 | in_channels (`int`, defaults to 4): 185 | The number of channels in the input sample. 186 | flip_sin_to_cos (`bool`, defaults to `True`): 187 | Whether to flip the sin to cos in the time embedding. 188 | freq_shift (`int`, defaults to 0): 189 | The frequency shift to apply to the time embedding. 190 | down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): 191 | The tuple of downsample blocks to use. 192 | only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`): 193 | block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`): 194 | The tuple of output channels for each block. 195 | layers_per_block (`int`, defaults to 2): 196 | The number of layers per block. 197 | downsample_padding (`int`, defaults to 1): 198 | The padding to use for the downsampling convolution. 199 | mid_block_scale_factor (`float`, defaults to 1): 200 | The scale factor to use for the mid block. 201 | act_fn (`str`, defaults to "silu"): 202 | The activation function to use. 203 | norm_num_groups (`int`, *optional*, defaults to 32): 204 | The number of groups to use for the normalization. 
If None, normalization and activation layers is skipped 205 | in post-processing. 206 | norm_eps (`float`, defaults to 1e-5): 207 | The epsilon to use for the normalization. 208 | cross_attention_dim (`int`, defaults to 1280): 209 | The dimension of the cross attention features. 210 | transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): 211 | The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for 212 | [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], 213 | [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. 214 | encoder_hid_dim (`int`, *optional*, defaults to None): 215 | If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` 216 | dimension to `cross_attention_dim`. 217 | encoder_hid_dim_type (`str`, *optional*, defaults to `None`): 218 | If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text 219 | embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. 220 | attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8): 221 | The dimension of the attention heads. 222 | use_linear_projection (`bool`, defaults to `False`): 223 | class_embed_type (`str`, *optional*, defaults to `None`): 224 | The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None, 225 | `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. 226 | addition_embed_type (`str`, *optional*, defaults to `None`): 227 | Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or 228 | "text". "text" will use the `TextTimeEmbedding` layer. 229 | num_class_embeds (`int`, *optional*, defaults to 0): 230 | Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing 231 | class conditioning with `class_embed_type` equal to `None`. 232 | upcast_attention (`bool`, defaults to `False`): 233 | resnet_time_scale_shift (`str`, defaults to `"default"`): 234 | Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`. 235 | projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`): 236 | The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when 237 | `class_embed_type="projection"`. 238 | controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): 239 | The channel order of conditional image. Will convert to `rgb` if it's `bgr`. 240 | conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`): 241 | The tuple of output channel for each block in the `conditioning_embedding` layer. 
242 | global_pool_conditions (`bool`, defaults to `False`): 243 | """ 244 | 245 | _supports_gradient_checkpointing = True 246 | 247 | @register_to_config 248 | def __init__( 249 | self, 250 | in_channels: int = 4, 251 | conditioning_channels: int = 3, 252 | flip_sin_to_cos: bool = True, 253 | freq_shift: int = 0, 254 | down_block_types: Tuple[str] = ( 255 | "CrossAttnDownBlock2D", 256 | "CrossAttnDownBlock2D", 257 | "CrossAttnDownBlock2D", 258 | "DownBlock2D", 259 | ), 260 | only_cross_attention: Union[bool, Tuple[bool]] = False, 261 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280), 262 | layers_per_block: int = 2, 263 | downsample_padding: int = 1, 264 | mid_block_scale_factor: float = 1, 265 | act_fn: str = "silu", 266 | norm_num_groups: Optional[int] = 32, 267 | norm_eps: float = 1e-5, 268 | cross_attention_dim: int = 1280, 269 | transformer_layers_per_block: Union[int, Tuple[int]] = 1, 270 | encoder_hid_dim: Optional[int] = None, 271 | encoder_hid_dim_type: Optional[str] = None, 272 | attention_head_dim: Union[int, Tuple[int]] = 8, 273 | num_attention_heads: Optional[Union[int, Tuple[int]]] = None, 274 | use_linear_projection: bool = False, 275 | class_embed_type: Optional[str] = None, 276 | addition_embed_type: Optional[str] = None, 277 | addition_time_embed_dim: Optional[int] = None, 278 | num_class_embeds: Optional[int] = None, 279 | upcast_attention: bool = False, 280 | resnet_time_scale_shift: str = "default", 281 | projection_class_embeddings_input_dim: Optional[int] = None, 282 | controlnet_conditioning_channel_order: str = "rgb", 283 | conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), 284 | global_pool_conditions: bool = False, 285 | addition_embed_type_num_heads=64, 286 | num_control_type=6, 287 | ): 288 | super().__init__() 289 | 290 | # If `num_attention_heads` is not defined (which is the case for most models) 291 | # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. 292 | # The reason for this behavior is to correct for incorrectly named variables that were introduced 293 | # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 294 | # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking 295 | # which is why we correct for the naming here. 296 | num_attention_heads = num_attention_heads or attention_head_dim 297 | 298 | # Check inputs 299 | if len(block_out_channels) != len(down_block_types): 300 | raise ValueError( 301 | f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." 302 | ) 303 | 304 | if not isinstance(only_cross_attention, bool) and len( 305 | only_cross_attention 306 | ) != len(down_block_types): 307 | raise ValueError( 308 | f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." 309 | ) 310 | 311 | if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len( 312 | down_block_types 313 | ): 314 | raise ValueError( 315 | f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 
316 | ) 317 | 318 | if isinstance(transformer_layers_per_block, int): 319 | transformer_layers_per_block = [transformer_layers_per_block] * len( 320 | down_block_types 321 | ) 322 | 323 | # input 324 | conv_in_kernel = 3 325 | conv_in_padding = (conv_in_kernel - 1) // 2 326 | self.conv_in = nn.Conv2d( 327 | in_channels, 328 | block_out_channels[0], 329 | kernel_size=conv_in_kernel, 330 | padding=conv_in_padding, 331 | ) 332 | 333 | # time 334 | time_embed_dim = block_out_channels[0] * 4 335 | self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) 336 | timestep_input_dim = block_out_channels[0] 337 | self.time_embedding = TimestepEmbedding( 338 | timestep_input_dim, 339 | time_embed_dim, 340 | act_fn=act_fn, 341 | ) 342 | 343 | if encoder_hid_dim_type is None and encoder_hid_dim is not None: 344 | encoder_hid_dim_type = "text_proj" 345 | self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) 346 | logger.info( 347 | "encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined." 348 | ) 349 | 350 | if encoder_hid_dim is None and encoder_hid_dim_type is not None: 351 | raise ValueError( 352 | f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." 353 | ) 354 | 355 | if encoder_hid_dim_type == "text_proj": 356 | self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) 357 | elif encoder_hid_dim_type == "text_image_proj": 358 | # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much 359 | # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use 360 | # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` 361 | self.encoder_hid_proj = TextImageProjection( 362 | text_embed_dim=encoder_hid_dim, 363 | image_embed_dim=cross_attention_dim, 364 | cross_attention_dim=cross_attention_dim, 365 | ) 366 | 367 | elif encoder_hid_dim_type is not None: 368 | raise ValueError( 369 | f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." 370 | ) 371 | else: 372 | self.encoder_hid_proj = None 373 | 374 | # class embedding 375 | if class_embed_type is None and num_class_embeds is not None: 376 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) 377 | elif class_embed_type == "timestep": 378 | self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 379 | elif class_embed_type == "identity": 380 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) 381 | elif class_embed_type == "projection": 382 | if projection_class_embeddings_input_dim is None: 383 | raise ValueError( 384 | "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" 385 | ) 386 | # The projection `class_embed_type` is the same as the timestep `class_embed_type` except 387 | # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings 388 | # 2. it projects from an arbitrary input dimension. 389 | # 390 | # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. 391 | # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. 392 | # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 
393 | self.class_embedding = TimestepEmbedding( 394 | projection_class_embeddings_input_dim, time_embed_dim 395 | ) 396 | else: 397 | self.class_embedding = None 398 | 399 | if addition_embed_type == "text": 400 | if encoder_hid_dim is not None: 401 | text_time_embedding_from_dim = encoder_hid_dim 402 | else: 403 | text_time_embedding_from_dim = cross_attention_dim 404 | 405 | self.add_embedding = TextTimeEmbedding( 406 | text_time_embedding_from_dim, 407 | time_embed_dim, 408 | num_heads=addition_embed_type_num_heads, 409 | ) 410 | elif addition_embed_type == "text_image": 411 | # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much 412 | # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use 413 | # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` 414 | self.add_embedding = TextImageTimeEmbedding( 415 | text_embed_dim=cross_attention_dim, 416 | image_embed_dim=cross_attention_dim, 417 | time_embed_dim=time_embed_dim, 418 | ) 419 | elif addition_embed_type == "text_time": 420 | self.add_time_proj = Timesteps( 421 | addition_time_embed_dim, flip_sin_to_cos, freq_shift 422 | ) 423 | self.add_embedding = TimestepEmbedding( 424 | projection_class_embeddings_input_dim, time_embed_dim 425 | ) 426 | 427 | elif addition_embed_type is not None: 428 | raise ValueError( 429 | f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'." 430 | ) 431 | 432 | # control net conditioning embedding 433 | self.controlnet_cond_embedding = ControlNetConditioningEmbedding( 434 | conditioning_embedding_channels=block_out_channels[0], 435 | block_out_channels=conditioning_embedding_out_channels, 436 | conditioning_channels=conditioning_channels, 437 | ) 438 | 439 | # Copyright by Qi Xin(2024/07/06) 440 | # Condition Transformer(fuse single/multi conditions with input image) 441 | # The Condition Transformer augment the feature representation of conditions 442 | # The overall design is somewhat like resnet. The output of Condition Transformer is used to predict a condition bias adding to the original condition feature. 443 | # num_control_type = 6 444 | num_trans_channel = 320 445 | num_trans_head = 8 446 | num_trans_layer = 1 447 | num_proj_channel = 320 448 | task_scale_factor = num_trans_channel**0.5 449 | 450 | self.task_embedding = nn.Parameter( 451 | task_scale_factor * torch.randn(num_control_type, num_trans_channel) 452 | ) 453 | self.transformer_layes = nn.Sequential( 454 | *[ 455 | ResidualAttentionBlock(num_trans_channel, num_trans_head) 456 | for _ in range(num_trans_layer) 457 | ] 458 | ) 459 | self.spatial_ch_projs = zero_module( 460 | nn.Linear(num_trans_channel, num_proj_channel) 461 | ) 462 | # ----------------------------------------------------------------------------------------------------- 463 | 464 | # Copyright by Qi Xin(2024/07/06) 465 | # Control Encoder to distinguish different control conditions 466 | # A simple but effective module, consists of an embedding layer and a linear layer, to inject the control info to time embedding. 
467 | self.control_type_proj = Timesteps( 468 | addition_time_embed_dim, flip_sin_to_cos, freq_shift 469 | ) 470 | self.control_add_embedding = TimestepEmbedding( 471 | addition_time_embed_dim * num_control_type, time_embed_dim 472 | ) 473 | # ----------------------------------------------------------------------------------------------------- 474 | 475 | self.down_blocks = nn.ModuleList([]) 476 | self.controlnet_down_blocks = nn.ModuleList([]) 477 | 478 | if isinstance(only_cross_attention, bool): 479 | only_cross_attention = [only_cross_attention] * len(down_block_types) 480 | 481 | if isinstance(attention_head_dim, int): 482 | attention_head_dim = (attention_head_dim,) * len(down_block_types) 483 | 484 | if isinstance(num_attention_heads, int): 485 | num_attention_heads = (num_attention_heads,) * len(down_block_types) 486 | 487 | # down 488 | output_channel = block_out_channels[0] 489 | 490 | controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) 491 | controlnet_block = zero_module(controlnet_block) 492 | self.controlnet_down_blocks.append(controlnet_block) 493 | 494 | for i, down_block_type in enumerate(down_block_types): 495 | input_channel = output_channel 496 | output_channel = block_out_channels[i] 497 | is_final_block = i == len(block_out_channels) - 1 498 | 499 | down_block = get_down_block( 500 | down_block_type, 501 | num_layers=layers_per_block, 502 | transformer_layers_per_block=transformer_layers_per_block[i], 503 | in_channels=input_channel, 504 | out_channels=output_channel, 505 | temb_channels=time_embed_dim, 506 | add_downsample=not is_final_block, 507 | resnet_eps=norm_eps, 508 | resnet_act_fn=act_fn, 509 | resnet_groups=norm_num_groups, 510 | cross_attention_dim=cross_attention_dim, 511 | num_attention_heads=num_attention_heads[i], 512 | attention_head_dim=attention_head_dim[i] 513 | if attention_head_dim[i] is not None 514 | else output_channel, 515 | downsample_padding=downsample_padding, 516 | use_linear_projection=use_linear_projection, 517 | only_cross_attention=only_cross_attention[i], 518 | upcast_attention=upcast_attention, 519 | resnet_time_scale_shift=resnet_time_scale_shift, 520 | ) 521 | self.down_blocks.append(down_block) 522 | 523 | for _ in range(layers_per_block): 524 | controlnet_block = nn.Conv2d( 525 | output_channel, output_channel, kernel_size=1 526 | ) 527 | controlnet_block = zero_module(controlnet_block) 528 | self.controlnet_down_blocks.append(controlnet_block) 529 | 530 | if not is_final_block: 531 | controlnet_block = nn.Conv2d( 532 | output_channel, output_channel, kernel_size=1 533 | ) 534 | controlnet_block = zero_module(controlnet_block) 535 | self.controlnet_down_blocks.append(controlnet_block) 536 | 537 | # mid 538 | mid_block_channel = block_out_channels[-1] 539 | 540 | controlnet_block = nn.Conv2d( 541 | mid_block_channel, mid_block_channel, kernel_size=1 542 | ) 543 | controlnet_block = zero_module(controlnet_block) 544 | self.controlnet_mid_block = controlnet_block 545 | 546 | self.mid_block = UNetMidBlock2DCrossAttn( 547 | transformer_layers_per_block=transformer_layers_per_block[-1], 548 | in_channels=mid_block_channel, 549 | temb_channels=time_embed_dim, 550 | resnet_eps=norm_eps, 551 | resnet_act_fn=act_fn, 552 | output_scale_factor=mid_block_scale_factor, 553 | resnet_time_scale_shift=resnet_time_scale_shift, 554 | cross_attention_dim=cross_attention_dim, 555 | num_attention_heads=num_attention_heads[-1], 556 | resnet_groups=norm_num_groups, 557 | use_linear_projection=use_linear_projection, 558 | 
upcast_attention=upcast_attention, 559 | ) 560 | 561 | @classmethod 562 | def from_unet( 563 | cls, 564 | unet: UNet2DConditionModel, 565 | controlnet_conditioning_channel_order: str = "rgb", 566 | conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), 567 | load_weights_from_unet: bool = True, 568 | ): 569 | r""" 570 | Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`]. 571 | 572 | Parameters: 573 | unet (`UNet2DConditionModel`): 574 | The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied 575 | where applicable. 576 | """ 577 | transformer_layers_per_block = ( 578 | unet.config.transformer_layers_per_block 579 | if "transformer_layers_per_block" in unet.config 580 | else 1 581 | ) 582 | encoder_hid_dim = ( 583 | unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None 584 | ) 585 | encoder_hid_dim_type = ( 586 | unet.config.encoder_hid_dim_type 587 | if "encoder_hid_dim_type" in unet.config 588 | else None 589 | ) 590 | addition_embed_type = ( 591 | unet.config.addition_embed_type 592 | if "addition_embed_type" in unet.config 593 | else None 594 | ) 595 | addition_time_embed_dim = ( 596 | unet.config.addition_time_embed_dim 597 | if "addition_time_embed_dim" in unet.config 598 | else None 599 | ) 600 | 601 | controlnet = cls( 602 | encoder_hid_dim=encoder_hid_dim, 603 | encoder_hid_dim_type=encoder_hid_dim_type, 604 | addition_embed_type=addition_embed_type, 605 | addition_time_embed_dim=addition_time_embed_dim, 606 | transformer_layers_per_block=transformer_layers_per_block, 607 | # transformer_layers_per_block=[1, 2, 5], 608 | in_channels=unet.config.in_channels, 609 | flip_sin_to_cos=unet.config.flip_sin_to_cos, 610 | freq_shift=unet.config.freq_shift, 611 | down_block_types=unet.config.down_block_types, 612 | only_cross_attention=unet.config.only_cross_attention, 613 | block_out_channels=unet.config.block_out_channels, 614 | layers_per_block=unet.config.layers_per_block, 615 | downsample_padding=unet.config.downsample_padding, 616 | mid_block_scale_factor=unet.config.mid_block_scale_factor, 617 | act_fn=unet.config.act_fn, 618 | norm_num_groups=unet.config.norm_num_groups, 619 | norm_eps=unet.config.norm_eps, 620 | cross_attention_dim=unet.config.cross_attention_dim, 621 | attention_head_dim=unet.config.attention_head_dim, 622 | num_attention_heads=unet.config.num_attention_heads, 623 | use_linear_projection=unet.config.use_linear_projection, 624 | class_embed_type=unet.config.class_embed_type, 625 | num_class_embeds=unet.config.num_class_embeds, 626 | upcast_attention=unet.config.upcast_attention, 627 | resnet_time_scale_shift=unet.config.resnet_time_scale_shift, 628 | projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, 629 | controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, 630 | conditioning_embedding_out_channels=conditioning_embedding_out_channels, 631 | ) 632 | 633 | if load_weights_from_unet: 634 | controlnet.conv_in.load_state_dict(unet.conv_in.state_dict()) 635 | controlnet.time_proj.load_state_dict(unet.time_proj.state_dict()) 636 | controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict()) 637 | 638 | if controlnet.class_embedding: 639 | controlnet.class_embedding.load_state_dict( 640 | unet.class_embedding.state_dict() 641 | ) 642 | 643 | controlnet.down_blocks.load_state_dict( 644 | unet.down_blocks.state_dict(), strict=False 645 | ) 646 | controlnet.mid_block.load_state_dict( 647 | 
unet.mid_block.state_dict(), strict=False 648 | ) 649 | 650 | return controlnet 651 | 652 | @property 653 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors 654 | def attn_processors(self) -> Dict[str, AttentionProcessor]: 655 | r""" 656 | Returns: 657 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 658 | indexed by its weight name. 659 | """ 660 | # set recursively 661 | processors = {} 662 | 663 | def fn_recursive_add_processors( 664 | name: str, 665 | module: torch.nn.Module, 666 | processors: Dict[str, AttentionProcessor], 667 | ): 668 | if hasattr(module, "get_processor"): 669 | processors[f"{name}.processor"] = module.get_processor( 670 | return_deprecated_lora=True 671 | ) 672 | 673 | for sub_name, child in module.named_children(): 674 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 675 | 676 | return processors 677 | 678 | for name, module in self.named_children(): 679 | fn_recursive_add_processors(name, module, processors) 680 | 681 | return processors 682 | 683 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor 684 | def set_attn_processor( 685 | self, 686 | processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], 687 | _remove_lora=False, 688 | ): 689 | r""" 690 | Sets the attention processor to use to compute attention. 691 | 692 | Parameters: 693 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 694 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 695 | for **all** `Attention` layers. 696 | 697 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 698 | processor. This is strongly recommended when setting trainable attention processors. 699 | 700 | """ 701 | count = len(self.attn_processors.keys()) 702 | 703 | if isinstance(processor, dict) and len(processor) != count: 704 | raise ValueError( 705 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 706 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 707 | ) 708 | 709 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 710 | if hasattr(module, "set_processor"): 711 | if not isinstance(processor, dict): 712 | module.set_processor(processor, _remove_lora=_remove_lora) 713 | else: 714 | module.set_processor( 715 | processor.pop(f"{name}.processor"), _remove_lora=_remove_lora 716 | ) 717 | 718 | for sub_name, child in module.named_children(): 719 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 720 | 721 | for name, module in self.named_children(): 722 | fn_recursive_attn_processor(name, module, processor) 723 | 724 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor 725 | def set_default_attn_processor(self): 726 | """ 727 | Disables custom attention processors and sets the default attention implementation. 
728 | """ 729 | if all( 730 | proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS 731 | for proc in self.attn_processors.values() 732 | ): 733 | processor = AttnAddedKVProcessor() 734 | elif all( 735 | proc.__class__ in CROSS_ATTENTION_PROCESSORS 736 | for proc in self.attn_processors.values() 737 | ): 738 | processor = AttnProcessor() 739 | else: 740 | raise ValueError( 741 | f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" 742 | ) 743 | 744 | self.set_attn_processor(processor, _remove_lora=True) 745 | 746 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice 747 | def set_attention_slice(self, slice_size): 748 | r""" 749 | Enable sliced attention computation. 750 | 751 | When this option is enabled, the attention module splits the input tensor in slices to compute attention in 752 | several steps. This is useful for saving some memory in exchange for a small decrease in speed. 753 | 754 | Args: 755 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): 756 | When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If 757 | `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is 758 | provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` 759 | must be a multiple of `slice_size`. 760 | """ 761 | sliceable_head_dims = [] 762 | 763 | def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): 764 | if hasattr(module, "set_attention_slice"): 765 | sliceable_head_dims.append(module.sliceable_head_dim) 766 | 767 | for child in module.children(): 768 | fn_recursive_retrieve_sliceable_dims(child) 769 | 770 | # retrieve number of attention layers 771 | for module in self.children(): 772 | fn_recursive_retrieve_sliceable_dims(module) 773 | 774 | num_sliceable_layers = len(sliceable_head_dims) 775 | 776 | if slice_size == "auto": 777 | # half the attention head size is usually a good trade-off between 778 | # speed and memory 779 | slice_size = [dim // 2 for dim in sliceable_head_dims] 780 | elif slice_size == "max": 781 | # make smallest slice possible 782 | slice_size = num_sliceable_layers * [1] 783 | 784 | slice_size = ( 785 | num_sliceable_layers * [slice_size] 786 | if not isinstance(slice_size, list) 787 | else slice_size 788 | ) 789 | 790 | if len(slice_size) != len(sliceable_head_dims): 791 | raise ValueError( 792 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" 793 | f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 794 | ) 795 | 796 | for i in range(len(slice_size)): 797 | size = slice_size[i] 798 | dim = sliceable_head_dims[i] 799 | if size is not None and size > dim: 800 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.") 801 | 802 | # Recursively walk through all the children. 
803 | # Any children which exposes the set_attention_slice method 804 | # gets the message 805 | def fn_recursive_set_attention_slice( 806 | module: torch.nn.Module, slice_size: List[int] 807 | ): 808 | if hasattr(module, "set_attention_slice"): 809 | module.set_attention_slice(slice_size.pop()) 810 | 811 | for child in module.children(): 812 | fn_recursive_set_attention_slice(child, slice_size) 813 | 814 | reversed_slice_size = list(reversed(slice_size)) 815 | for module in self.children(): 816 | fn_recursive_set_attention_slice(module, reversed_slice_size) 817 | 818 | def _set_gradient_checkpointing(self, module, value=False): 819 | if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): 820 | module.gradient_checkpointing = value 821 | 822 | def forward( 823 | self, 824 | sample: torch.FloatTensor, 825 | timestep: Union[torch.Tensor, float, int], 826 | encoder_hidden_states: torch.Tensor, 827 | controlnet_cond_list: torch.FloatTensor, 828 | conditioning_scale: float = 1.0, 829 | class_labels: Optional[torch.Tensor] = None, 830 | timestep_cond: Optional[torch.Tensor] = None, 831 | attention_mask: Optional[torch.Tensor] = None, 832 | added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, 833 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 834 | guess_mode: bool = False, 835 | return_dict: bool = True, 836 | ) -> Union[ControlNetOutput, Tuple]: 837 | """ 838 | The [`ControlNetModel`] forward method. 839 | 840 | Args: 841 | sample (`torch.FloatTensor`): 842 | The noisy input tensor. 843 | timestep (`Union[torch.Tensor, float, int]`): 844 | The number of timesteps to denoise an input. 845 | encoder_hidden_states (`torch.Tensor`): 846 | The encoder hidden states. 847 | controlnet_cond (`torch.FloatTensor`): 848 | The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. 849 | conditioning_scale (`float`, defaults to `1.0`): 850 | The scale factor for ControlNet outputs. 851 | class_labels (`torch.Tensor`, *optional*, defaults to `None`): 852 | Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. 853 | timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): 854 | Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the 855 | timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep 856 | embeddings. 857 | attention_mask (`torch.Tensor`, *optional*, defaults to `None`): 858 | An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask 859 | is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large 860 | negative values to the attention scores corresponding to "discard" tokens. 861 | added_cond_kwargs (`dict`): 862 | Additional conditions for the Stable Diffusion XL UNet. 863 | cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): 864 | A kwargs dictionary that if specified is passed along to the `AttnProcessor`. 865 | guess_mode (`bool`, defaults to `False`): 866 | In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if 867 | you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. 868 | return_dict (`bool`, defaults to `True`): 869 | Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. 
870 | 871 | Returns: 872 | [`~models.controlnet.ControlNetOutput`] **or** `tuple`: 873 | If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is 874 | returned where the first element is the sample tensor. 875 | """ 876 | # check channel order 877 | channel_order = self.config.controlnet_conditioning_channel_order 878 | 879 | if channel_order == "rgb": 880 | # in rgb order by default 881 | ... 882 | # elif channel_order == "bgr": 883 | # controlnet_cond = torch.flip(controlnet_cond, dims=[1]) 884 | else: 885 | raise ValueError( 886 | f"unknown `controlnet_conditioning_channel_order`: {channel_order}" 887 | ) 888 | 889 | # prepare attention_mask 890 | if attention_mask is not None: 891 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 892 | attention_mask = attention_mask.unsqueeze(1) 893 | 894 | # 1. time 895 | timesteps = timestep 896 | if not torch.is_tensor(timesteps): 897 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 898 | # This would be a good case for the `match` statement (Python 3.10+) 899 | is_mps = sample.device.type == "mps" 900 | if isinstance(timestep, float): 901 | dtype = torch.float32 if is_mps else torch.float64 902 | else: 903 | dtype = torch.int32 if is_mps else torch.int64 904 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 905 | elif len(timesteps.shape) == 0: 906 | timesteps = timesteps[None].to(sample.device) 907 | 908 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 909 | timesteps = timesteps.expand(sample.shape[0]) 910 | 911 | t_emb = self.time_proj(timesteps) 912 | 913 | # timesteps does not contain any weights and will always return f32 tensors 914 | # but time_embedding might actually be running in fp16. so we need to cast here. 915 | # there might be better ways to encapsulate this. 
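# NOTE (illustrative, not part of the upstream code): with a fp16 checkpoint,
# `self.time_proj(timesteps)` returns float32 sinusoidal features while `self.time_embedding`
# holds float16 weights; the cast below to `sample.dtype` keeps the time-embedding Linear
# layers in a single dtype and avoids a runtime dtype-mismatch error.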
916 | t_emb = t_emb.to(dtype=sample.dtype) 917 | 918 | emb = self.time_embedding(t_emb, timestep_cond) 919 | aug_emb = None 920 | 921 | if self.class_embedding is not None: 922 | if class_labels is None: 923 | raise ValueError( 924 | "class_labels should be provided when num_class_embeds > 0" 925 | ) 926 | 927 | if self.config.class_embed_type == "timestep": 928 | class_labels = self.time_proj(class_labels) 929 | 930 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) 931 | emb = emb + class_emb 932 | 933 | if self.config.addition_embed_type is not None: 934 | if self.config.addition_embed_type == "text": 935 | aug_emb = self.add_embedding(encoder_hidden_states) 936 | 937 | elif self.config.addition_embed_type == "text_time": 938 | if "text_embeds" not in added_cond_kwargs: 939 | raise ValueError( 940 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" 941 | ) 942 | text_embeds = added_cond_kwargs.get("text_embeds") 943 | if "time_ids" not in added_cond_kwargs: 944 | raise ValueError( 945 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" 946 | ) 947 | time_ids = added_cond_kwargs.get("time_ids") 948 | time_embeds = self.add_time_proj(time_ids.flatten()) 949 | time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) 950 | 951 | add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) 952 | add_embeds = add_embeds.to(emb.dtype) 953 | aug_emb = self.add_embedding(add_embeds) 954 | 955 | # Copyright by Qi Xin(2024/07/06) 956 | # inject control type info to time embedding to distinguish different control conditions 957 | control_type = added_cond_kwargs.get("control_type") 958 | control_embeds = self.control_type_proj(control_type.flatten()) 959 | control_embeds = control_embeds.reshape((t_emb.shape[0], -1)) 960 | control_embeds = control_embeds.to(emb.dtype) 961 | control_emb = self.control_add_embedding(control_embeds) 962 | emb = emb + control_emb 963 | # --------------------------------------------------------------------------------- 964 | 965 | emb = emb + aug_emb if aug_emb is not None else emb 966 | 967 | # 2. pre-process 968 | sample = self.conv_in(sample) 969 | indices = torch.nonzero(control_type[0]) 970 | 971 | # Copyright by Qi Xin(2024/07/06) 972 | # add single/multi conditons to input image. 
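# NOTE (hedged sketch of the fusion below, added for clarity): with batch N, block channels C
# and K active conditions (the indices where `control_type` is non-zero), the loop builds K+1
# pooled tokens of shape (N, 1, C) — one per embedded condition (plus its task embedding) and a
# final one for the noisy `sample` itself — concatenates them to (N, K+1, C) and runs them
# through `self.transformer_layes`. Each condition token is then projected by
# `self.spatial_ch_projs` into a per-channel bias `alpha` of shape (N, C, 1, 1), and
# `condition_list[idx] + alpha` is accumulated into a fuser tensor that is finally added to
# `sample`. The original comments and code follow unchanged.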
973 | # Condition Transformer provides an easy and effective way to fuse different features naturally 974 | inputs = [] 975 | condition_list = [] 976 | 977 | for idx in range(indices.shape[0] + 1): 978 | if idx == indices.shape[0]: 979 | controlnet_cond = sample 980 | feat_seq = torch.mean(controlnet_cond, dim=(2, 3)) # N * C 981 | else: 982 | controlnet_cond = self.controlnet_cond_embedding( 983 | controlnet_cond_list[indices[idx][0]] 984 | ) 985 | feat_seq = torch.mean(controlnet_cond, dim=(2, 3)) # N * C 986 | feat_seq = feat_seq + self.task_embedding[indices[idx][0]] 987 | 988 | inputs.append(feat_seq.unsqueeze(1)) 989 | condition_list.append(controlnet_cond) 990 | 991 | x = torch.cat(inputs, dim=1) # NxLxC 992 | x = self.transformer_layes(x) 993 | 994 | controlnet_cond_fuser = sample * 0.0 995 | for idx in range(indices.shape[0]): 996 | alpha = self.spatial_ch_projs(x[:, idx]) 997 | alpha = alpha.unsqueeze(-1).unsqueeze(-1) 998 | controlnet_cond_fuser += condition_list[idx] + alpha 999 | 1000 | sample = sample + controlnet_cond_fuser 1001 | # ------------------------------------------------------------------------------------------- 1002 | 1003 | # 3. down 1004 | down_block_res_samples = (sample,) 1005 | for downsample_block in self.down_blocks: 1006 | if ( 1007 | hasattr(downsample_block, "has_cross_attention") 1008 | and downsample_block.has_cross_attention 1009 | ): 1010 | sample, res_samples = downsample_block( 1011 | hidden_states=sample, 1012 | temb=emb, 1013 | encoder_hidden_states=encoder_hidden_states, 1014 | attention_mask=attention_mask, 1015 | cross_attention_kwargs=cross_attention_kwargs, 1016 | ) 1017 | else: 1018 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 1019 | 1020 | down_block_res_samples += res_samples 1021 | 1022 | # 4. mid 1023 | if self.mid_block is not None: 1024 | sample = self.mid_block( 1025 | sample, 1026 | emb, 1027 | encoder_hidden_states=encoder_hidden_states, 1028 | attention_mask=attention_mask, 1029 | cross_attention_kwargs=cross_attention_kwargs, 1030 | ) 1031 | 1032 | # 5. Control net blocks 1033 | 1034 | controlnet_down_block_res_samples = () 1035 | 1036 | for down_block_res_sample, controlnet_block in zip( 1037 | down_block_res_samples, self.controlnet_down_blocks 1038 | ): 1039 | down_block_res_sample = controlnet_block(down_block_res_sample) 1040 | controlnet_down_block_res_samples = controlnet_down_block_res_samples + ( 1041 | down_block_res_sample, 1042 | ) 1043 | 1044 | down_block_res_samples = controlnet_down_block_res_samples 1045 | 1046 | mid_block_res_sample = self.controlnet_mid_block(sample) 1047 | 1048 | # 6. 
scaling 1049 | if guess_mode and not self.config.global_pool_conditions: 1050 | scales = torch.logspace( 1051 | -1, 0, len(down_block_res_samples) + 1, device=sample.device 1052 | ) # 0.1 to 1.0 1053 | scales = scales * conditioning_scale 1054 | down_block_res_samples = [ 1055 | sample * scale for sample, scale in zip(down_block_res_samples, scales) 1056 | ] 1057 | mid_block_res_sample = mid_block_res_sample * scales[-1] # last one 1058 | else: 1059 | down_block_res_samples = [ 1060 | sample * conditioning_scale for sample in down_block_res_samples 1061 | ] 1062 | mid_block_res_sample = mid_block_res_sample * conditioning_scale 1063 | 1064 | if self.config.global_pool_conditions: 1065 | down_block_res_samples = [ 1066 | torch.mean(sample, dim=(2, 3), keepdim=True) 1067 | for sample in down_block_res_samples 1068 | ] 1069 | mid_block_res_sample = torch.mean( 1070 | mid_block_res_sample, dim=(2, 3), keepdim=True 1071 | ) 1072 | 1073 | if not return_dict: 1074 | return (down_block_res_samples, mid_block_res_sample) 1075 | 1076 | return ControlNetOutput( 1077 | down_block_res_samples=down_block_res_samples, 1078 | mid_block_res_sample=mid_block_res_sample, 1079 | ) 1080 | 1081 | 1082 | def zero_module(module): 1083 | for p in module.parameters(): 1084 | nn.init.zeros_(p) 1085 | return module 1086 | -------------------------------------------------------------------------------- /diffusers-image-outpaint-workflow.json: -------------------------------------------------------------------------------- 1 | { 2 | "id": "5e709e31-1e9f-475e-a837-14abe1d4f292", 3 | "revision": 0, 4 | "last_node_id": 586, 5 | "last_link_id": 1133, 6 | "nodes": [ 7 | { 8 | "id": 499, 9 | "type": "PadImageForDiffusersOutpaint", 10 | "pos": [ 11 | -5320, 12 | 420 13 | ], 14 | "size": [ 15 | 290, 16 | 314 17 | ], 18 | "flags": {}, 19 | "order": 7, 20 | "mode": 0, 21 | "inputs": [ 22 | { 23 | "name": "image", 24 | "type": "IMAGE", 25 | "link": 906 26 | } 27 | ], 28 | "outputs": [ 29 | { 30 | "name": "IMAGE", 31 | "type": "IMAGE", 32 | "links": [] 33 | }, 34 | { 35 | "name": "MASK", 36 | "type": "MASK", 37 | "links": [] 38 | }, 39 | { 40 | "name": "diffuser_outpaint_cnet_image", 41 | "type": "IMAGE", 42 | "links": [ 43 | 1100 44 | ] 45 | } 46 | ], 47 | "properties": { 48 | "cnr_id": "ComfyUI-DiffusersImageOutpaint", 49 | "ver": "6a51ce5d3baa2171a85f51d462ef9f30ff9b5d26", 50 | "Node name for S&R": "PadImageForDiffusersOutpaint", 51 | "aux_id": "GiusTex/ComfyUI-DiffusersImageOutpaint", 52 | "widget_ue_connectable": {} 53 | }, 54 | "widgets_values": [ 55 | 720, 56 | 1280, 57 | "Middle", 58 | "Full", 59 | 50, 60 | 10, 61 | true, 62 | true, 63 | true, 64 | true 65 | ], 66 | "color": "#233", 67 | "bgcolor": "#355" 68 | }, 69 | { 70 | "id": 494, 71 | "type": "VAEDecode", 72 | "pos": [ 73 | -4710, 74 | -70 75 | ], 76 | "size": [ 77 | 140, 78 | 46 79 | ], 80 | "flags": {}, 81 | "order": 9, 82 | "mode": 0, 83 | "inputs": [ 84 | { 85 | "name": "samples", 86 | "type": "LATENT", 87 | "link": 1101 88 | }, 89 | { 90 | "name": "vae", 91 | "type": "VAE", 92 | "link": 908 93 | } 94 | ], 95 | "outputs": [ 96 | { 97 | "name": "IMAGE", 98 | "type": "IMAGE", 99 | "links": [ 100 | 899 101 | ] 102 | } 103 | ], 104 | "properties": { 105 | "cnr_id": "comfy-core", 106 | "ver": "0.3.30", 107 | "Node name for S&R": "VAEDecode", 108 | "widget_ue_connectable": {} 109 | }, 110 | "widgets_values": [], 111 | "color": "#323", 112 | "bgcolor": "#535" 113 | }, 114 | { 115 | "id": 495, 116 | "type": "PreviewImage", 117 | "pos": [ 118 | -4550, 119 | -70 120 | ], 
121 | "size": [ 122 | 250, 123 | 310 124 | ], 125 | "flags": {}, 126 | "order": 10, 127 | "mode": 0, 128 | "inputs": [ 129 | { 130 | "name": "images", 131 | "type": "IMAGE", 132 | "link": 899 133 | } 134 | ], 135 | "outputs": [], 136 | "properties": { 137 | "cnr_id": "comfy-core", 138 | "ver": "0.3.30", 139 | "Node name for S&R": "PreviewImage", 140 | "widget_ue_connectable": {} 141 | }, 142 | "widgets_values": [] 143 | }, 144 | { 145 | "id": 497, 146 | "type": "DualCLIPLoader", 147 | "pos": [ 148 | -5570, 149 | 200 150 | ], 151 | "size": [ 152 | 270, 153 | 130 154 | ], 155 | "flags": {}, 156 | "order": 0, 157 | "mode": 0, 158 | "inputs": [], 159 | "outputs": [ 160 | { 161 | "name": "CLIP", 162 | "type": "CLIP", 163 | "links": [ 164 | 1110, 165 | 1111 166 | ] 167 | } 168 | ], 169 | "properties": { 170 | "cnr_id": "comfy-core", 171 | "ver": "0.3.30", 172 | "Node name for S&R": "DualCLIPLoader", 173 | "widget_ue_connectable": {} 174 | }, 175 | "widgets_values": [ 176 | "clip_l.safetensors", 177 | "clip_g.safetensors", 178 | "sdxl", 179 | "default" 180 | ], 181 | "color": "#223", 182 | "bgcolor": "#335" 183 | }, 184 | { 185 | "id": 569, 186 | "type": "EncodeDiffusersOutpaintPrompt", 187 | "pos": [ 188 | -5280, 189 | 200 190 | ], 191 | "size": [ 192 | 252.08065795898438, 193 | 136 194 | ], 195 | "flags": {}, 196 | "order": 6, 197 | "mode": 0, 198 | "inputs": [ 199 | { 200 | "name": "clip", 201 | "type": "CLIP", 202 | "link": 1111 203 | } 204 | ], 205 | "outputs": [ 206 | { 207 | "name": "diffusers_conditioning", 208 | "type": "CONDITIONING", 209 | "links": [ 210 | 1099 211 | ] 212 | } 213 | ], 214 | "properties": { 215 | "cnr_id": "ComfyUI-DiffusersImageOutpaint", 216 | "ver": "6a51ce5d3baa2171a85f51d462ef9f30ff9b5d26", 217 | "Node name for S&R": "EncodeDiffusersOutpaintPrompt", 218 | "aux_id": "GiusTex/ComfyUI-DiffusersImageOutpaint", 219 | "widget_ue_connectable": {} 220 | }, 221 | "widgets_values": [ 222 | "auto", 223 | "auto", 224 | "" 225 | ], 226 | "color": "#322", 227 | "bgcolor": "#533" 228 | }, 229 | { 230 | "id": 501, 231 | "type": "VAELoader", 232 | "pos": [ 233 | -5010, 234 | 260 235 | ], 236 | "size": [ 237 | 270, 238 | 58 239 | ], 240 | "flags": {}, 241 | "order": 1, 242 | "mode": 0, 243 | "inputs": [], 244 | "outputs": [ 245 | { 246 | "name": "VAE", 247 | "type": "VAE", 248 | "links": [ 249 | 908 250 | ] 251 | } 252 | ], 253 | "properties": { 254 | "cnr_id": "comfy-core", 255 | "ver": "0.3.30", 256 | "Node name for S&R": "VAELoader", 257 | "widget_ue_connectable": {} 258 | }, 259 | "widgets_values": [ 260 | "sdxl_vae.safetensors" 261 | ], 262 | "color": "#223", 263 | "bgcolor": "#335" 264 | }, 265 | { 266 | "id": 573, 267 | "type": "DiffusersImageOutpaint", 268 | "pos": [ 269 | -4990, 270 | -60 271 | ], 272 | "size": [ 273 | 247.341796875, 274 | 278 275 | ], 276 | "flags": {}, 277 | "order": 8, 278 | "mode": 0, 279 | "inputs": [ 280 | { 281 | "name": "model", 282 | "type": "MODEL", 283 | "link": 1131 284 | }, 285 | { 286 | "name": "scheduler_configs", 287 | "type": "SCHEDULER", 288 | "link": 1132 289 | }, 290 | { 291 | "name": "control_net", 292 | "type": "CONTROL_NET", 293 | "link": 1133 294 | }, 295 | { 296 | "name": "positive", 297 | "type": "CONDITIONING", 298 | "link": 1098 299 | }, 300 | { 301 | "name": "negative", 302 | "type": "CONDITIONING", 303 | "link": 1099 304 | }, 305 | { 306 | "name": "diffuser_outpaint_cnet_image", 307 | "type": "IMAGE", 308 | "link": 1100 309 | } 310 | ], 311 | "outputs": [ 312 | { 313 | "name": "LATENT", 314 | "type": "LATENT", 315 | "links": [ 
316 | 1101 317 | ] 318 | } 319 | ], 320 | "properties": { 321 | "cnr_id": "ComfyUI-DiffusersImageOutpaint", 322 | "ver": "6a51ce5d3baa2171a85f51d462ef9f30ff9b5d26", 323 | "Node name for S&R": "DiffusersImageOutpaint", 324 | "aux_id": "GiusTex/ComfyUI-DiffusersImageOutpaint", 325 | "widget_ue_connectable": {} 326 | }, 327 | "widgets_values": [ 328 | 1.5, 329 | 1, 330 | 8, 331 | "auto", 332 | "auto", 333 | false 334 | ], 335 | "color": "#232", 336 | "bgcolor": "#353" 337 | }, 338 | { 339 | "id": 531, 340 | "type": "LoadDiffuserModel", 341 | "pos": [ 342 | -5610, 343 | -190 344 | ], 345 | "size": [ 346 | 290, 347 | 150 348 | ], 349 | "flags": {}, 350 | "order": 2, 351 | "mode": 0, 352 | "inputs": [], 353 | "outputs": [ 354 | { 355 | "name": "model", 356 | "type": "MODEL", 357 | "links": [ 358 | 1131 359 | ] 360 | }, 361 | { 362 | "name": "scheduler configs", 363 | "type": "SCHEDULER", 364 | "links": [ 365 | 1132 366 | ] 367 | } 368 | ], 369 | "properties": { 370 | "cnr_id": "ComfyUI-DiffusersImageOutpaint", 371 | "ver": "6a51ce5d3baa2171a85f51d462ef9f30ff9b5d26", 372 | "Node name for S&R": "LoadDiffuserModel", 373 | "aux_id": "GiusTex/ComfyUI-DiffusersImageOutpaint", 374 | "widget_ue_connectable": {} 375 | }, 376 | "widgets_values": [ 377 | "RealVisXL_V5.0_Lightning_unet.safetensors", 378 | "auto", 379 | "auto", 380 | "sdxl" 381 | ], 382 | "color": "#223", 383 | "bgcolor": "#335" 384 | }, 385 | { 386 | "id": 500, 387 | "type": "LoadImage", 388 | "pos": [ 389 | -5640, 390 | 430 391 | ], 392 | "size": [ 393 | 270, 394 | 314 395 | ], 396 | "flags": {}, 397 | "order": 3, 398 | "mode": 0, 399 | "inputs": [], 400 | "outputs": [ 401 | { 402 | "name": "IMAGE", 403 | "type": "IMAGE", 404 | "links": [ 405 | 906 406 | ] 407 | }, 408 | { 409 | "name": "MASK", 410 | "type": "MASK", 411 | "links": null 412 | } 413 | ], 414 | "properties": { 415 | "cnr_id": "comfy-core", 416 | "ver": "0.3.30", 417 | "Node name for S&R": "LoadImage", 418 | "widget_ue_connectable": {} 419 | }, 420 | "widgets_values": [ 421 | "20230403_183417.jpg", 422 | "image" 423 | ] 424 | }, 425 | { 426 | "id": 570, 427 | "type": "EncodeDiffusersOutpaintPrompt", 428 | "pos": [ 429 | -5280, 430 | 10 431 | ], 432 | "size": [ 433 | 252.08065795898438, 434 | 136 435 | ], 436 | "flags": {}, 437 | "order": 5, 438 | "mode": 0, 439 | "inputs": [ 440 | { 441 | "name": "clip", 442 | "type": "CLIP", 443 | "link": 1110 444 | } 445 | ], 446 | "outputs": [ 447 | { 448 | "name": "diffusers_conditioning", 449 | "type": "CONDITIONING", 450 | "links": [ 451 | 1098 452 | ] 453 | } 454 | ], 455 | "properties": { 456 | "cnr_id": "ComfyUI-DiffusersImageOutpaint", 457 | "ver": "6a51ce5d3baa2171a85f51d462ef9f30ff9b5d26", 458 | "Node name for S&R": "EncodeDiffusersOutpaintPrompt", 459 | "aux_id": "GiusTex/ComfyUI-DiffusersImageOutpaint", 460 | "widget_ue_connectable": {} 461 | }, 462 | "widgets_values": [ 463 | "auto", 464 | "auto", 465 | "a verdant valley with waterfalls, rainbow" 466 | ], 467 | "color": "#232", 468 | "bgcolor": "#353" 469 | }, 470 | { 471 | "id": 532, 472 | "type": "LoadDiffuserControlnet", 473 | "pos": [ 474 | -5650, 475 | 10 476 | ], 477 | "size": [ 478 | 330, 479 | 130 480 | ], 481 | "flags": {}, 482 | "order": 4, 483 | "mode": 0, 484 | "inputs": [], 485 | "outputs": [ 486 | { 487 | "name": "CONTROL_NET", 488 | "type": "CONTROL_NET", 489 | "links": [ 490 | 1133 491 | ] 492 | } 493 | ], 494 | "properties": { 495 | "cnr_id": "ComfyUI-DiffusersImageOutpaint", 496 | "ver": "6a51ce5d3baa2171a85f51d462ef9f30ff9b5d26", 497 | "Node name for S&R": 
"LoadDiffuserControlnet", 498 | "aux_id": "GiusTex/ComfyUI-DiffusersImageOutpaint", 499 | "widget_ue_connectable": {} 500 | }, 501 | "widgets_values": [ 502 | "controlnet-union-promax_sdxl.safetensors", 503 | "auto", 504 | "auto", 505 | "controlnet-sdxl-promax" 506 | ], 507 | "color": "#432", 508 | "bgcolor": "#653" 509 | } 510 | ], 511 | "links": [ 512 | [ 513 | 899, 514 | 494, 515 | 0, 516 | 495, 517 | 0, 518 | "IMAGE" 519 | ], 520 | [ 521 | 906, 522 | 500, 523 | 0, 524 | 499, 525 | 0, 526 | "IMAGE" 527 | ], 528 | [ 529 | 908, 530 | 501, 531 | 0, 532 | 494, 533 | 1, 534 | "VAE" 535 | ], 536 | [ 537 | 1098, 538 | 570, 539 | 0, 540 | 573, 541 | 3, 542 | "CONDITIONING" 543 | ], 544 | [ 545 | 1099, 546 | 569, 547 | 0, 548 | 573, 549 | 4, 550 | "CONDITIONING" 551 | ], 552 | [ 553 | 1100, 554 | 499, 555 | 2, 556 | 573, 557 | 5, 558 | "IMAGE" 559 | ], 560 | [ 561 | 1101, 562 | 573, 563 | 0, 564 | 494, 565 | 0, 566 | "LATENT" 567 | ], 568 | [ 569 | 1110, 570 | 497, 571 | 0, 572 | 570, 573 | 0, 574 | "CLIP" 575 | ], 576 | [ 577 | 1111, 578 | 497, 579 | 0, 580 | 569, 581 | 0, 582 | "CLIP" 583 | ], 584 | [ 585 | 1131, 586 | 531, 587 | 0, 588 | 573, 589 | 0, 590 | "MODEL" 591 | ], 592 | [ 593 | 1132, 594 | 531, 595 | 1, 596 | 573, 597 | 1, 598 | "SCHEDULER" 599 | ], 600 | [ 601 | 1133, 602 | 532, 603 | 0, 604 | 573, 605 | 2, 606 | "CONTROL_NET" 607 | ] 608 | ], 609 | "groups": [], 610 | "config": {}, 611 | "extra": { 612 | "ds": { 613 | "scale": 0.7972024500000006, 614 | "offset": [ 615 | 5835.209356481534, 616 | 220.1462025110432 617 | ] 618 | }, 619 | "frontendVersion": "1.19.9", 620 | "groupNodes": {}, 621 | "ue_links": [], 622 | "links_added_by_ue": [], 623 | "VHS_latentpreview": true, 624 | "VHS_latentpreviewrate": 0, 625 | "VHS_MetadataImage": true, 626 | "VHS_KeepIntermediate": true 627 | }, 628 | "version": 0.4 629 | } -------------------------------------------------------------------------------- /nodes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from PIL import Image, ImageDraw 4 | from .utils import get_config_folder_list, tensor2pil, pil2tensor, diffuserOutpaintSamples, get_device_by_name, get_dtype_by_name, clearVram, test_scheduler_scale_model_input 5 | import folder_paths 6 | from diffusers.models import UNet2DConditionModel 7 | from diffusers import TCDScheduler 8 | from .controlnet_union import ControlNetModel_Union 9 | from diffusers.models.model_loading_utils import load_state_dict 10 | from safetensors.torch import load_file 11 | 12 | import logging 13 | 14 | 15 | # Get the absolute path of various directories 16 | my_dir = os.path.dirname(os.path.abspath(__file__)) 17 | 18 | def update_folder_names_and_paths(key, targets=[]): 19 | # check for existing key 20 | base = folder_paths.folder_names_and_paths.get(key, ([], {})) 21 | base = base[0] if isinstance(base[0], (list, set, tuple)) else [] 22 | # find base key & add w/ fallback, sanity check + warning 23 | target = next((x for x in targets if x in folder_paths.folder_names_and_paths), targets[0]) 24 | orig, _ = folder_paths.folder_names_and_paths.get(target, ([], {})) 25 | folder_paths.folder_names_and_paths[key] = (orig or base, {".gguf"}) 26 | if base and base != orig: 27 | logging.warning(f"Unknown file list already present on key {key}: {base}") 28 | 29 | def can_expand(source_width, source_height, target_width, target_height, alignment): 30 | """Checks if the image can be expanded based on the alignment.""" 31 | if alignment in ("Left", "Right") 
and source_width >= target_width: 32 | return False 33 | if alignment in ("Top", "Bottom") and source_height >= target_height: 34 | return False 35 | return True 36 | 37 | 38 | class PadImageForDiffusersOutpaint: 39 | _alignment_options = ["Middle", "Left", "Right", "Top", "Bottom"] 40 | _resize_option = ["Full", "50%", "33%", "25%", "Custom"] 41 | @classmethod 42 | def INPUT_TYPES(s): 43 | return { 44 | "required": { 45 | "image": ("IMAGE",), 46 | "width": ("INT", {"default": 720, "tooltip": "The width used for the image."}), 47 | "height": ("INT", {"default": 1280, "tooltip": "The height used for the image."}), 48 | "alignment": (s._alignment_options, {"tooltip": "Where the original image should be in the outpainted one"}), 49 | "resize_image": (s._resize_option, {"tooltip": "Resize input image"}), 50 | "custom_resize_image_percentage": ("INT", {"min": 1, "default": 50, "max": 100, "step": 1, "tooltip": "Custom resize (%)"}), 51 | "mask_overlap_percentage": ("INT", {"min": 1, "default": 10, "max": 50, "step": 1, "tooltip": "Mask overlap (%)"}), 52 | "overlap_left": ("BOOLEAN", {"default": True}), 53 | "overlap_right": ("BOOLEAN", {"default": True}), 54 | "overlap_top": ("BOOLEAN", {"default": True}), 55 | "overlap_bottom": ("BOOLEAN", {"default": True}), 56 | }, 57 | } 58 | 59 | RETURN_TYPES = ("IMAGE", "MASK", "IMAGE") 60 | RETURN_NAMES = ("IMAGE", "MASK", "diffuser_outpaint_cnet_image") 61 | FUNCTION = "prepare_image_and_mask" 62 | CATEGORY = "DiffusersOutpaint" 63 | 64 | def prepare_image_and_mask(self, image, width, height, mask_overlap_percentage, resize_image, custom_resize_image_percentage, overlap_left, overlap_right, overlap_top, overlap_bottom, alignment="Middle"): 65 | im=tensor2pil(image) 66 | source=im.convert('RGB') 67 | 68 | target_size = (width, height) 69 | 70 | # Calculate the scaling factor to fit the image within the target size 71 | scale_factor = min(target_size[0] / source.width, target_size[1] / source.height) 72 | new_width = int(source.width * scale_factor) 73 | new_height = int(source.height * scale_factor) 74 | 75 | # Resize the source image to fit within target size 76 | source = source.resize((new_width, new_height), Image.LANCZOS) 77 | 78 | # Initialize new_width and new_height 79 | new_width, new_height = source.width, source.height 80 | 81 | # Apply resize option using percentages 82 | if resize_image == "Full": 83 | resize_percentage = 100 84 | elif resize_image == "50%": 85 | resize_percentage = 50 86 | elif resize_image == "33%": 87 | resize_percentage = 33 88 | elif resize_image == "25%": 89 | resize_percentage = 25 90 | else: # Custom 91 | resize_percentage = custom_resize_image_percentage 92 | 93 | # Calculate new dimensions based on percentage 94 | resize_factor = resize_percentage / 100 95 | new_width = int(source.width * resize_factor) 96 | new_height = int(source.height * resize_factor) 97 | 98 | # Ensure minimum size of 64 pixels 99 | new_width = max(new_width, 64) 100 | new_height = max(new_height, 64) 101 | 102 | # Resize the image 103 | source = source.resize((new_width, new_height), Image.LANCZOS) 104 | 105 | # Calculate the overlap in pixels based on the percentage 106 | overlap_x = int(new_width * (mask_overlap_percentage / 100)) 107 | overlap_y = int(new_height * (mask_overlap_percentage / 100)) 108 | 109 | # Ensure minimum overlap of 1 pixel 110 | overlap_x = max(overlap_x, 1) 111 | overlap_y = max(overlap_y, 1) 112 | 113 | # Calculate margins based on alignment 114 | if alignment == "Middle": 115 | margin_x = (target_size[0] - 
source.width) // 2 116 | margin_y = (target_size[1] - source.height) // 2 117 | elif alignment == "Left": 118 | margin_x = 0 119 | margin_y = (target_size[1] - source.height) // 2 120 | elif alignment == "Right": 121 | margin_x = target_size[0] - source.width 122 | margin_y = (target_size[1] - source.height) // 2 123 | elif alignment == "Top": 124 | margin_x = (target_size[0] - source.width) // 2 125 | margin_y = 0 126 | elif alignment == "Bottom": 127 | margin_x = (target_size[0] - source.width) // 2 128 | margin_y = target_size[1] - source.height 129 | 130 | # Adjust margins to eliminate gaps 131 | margin_x = max(0, min(margin_x, target_size[0] - new_width)) 132 | margin_y = max(0, min(margin_y, target_size[1] - new_height)) 133 | 134 | # Create a new background image and paste the resized source image 135 | background = Image.new('RGB', target_size, (255, 255, 255)) 136 | background.paste(source, (margin_x, margin_y)) 137 | 138 | image=pil2tensor(background) 139 | #---------------------------------------------------- 140 | # Create the mask 141 | d1, d2, d3, d4 = image.size() 142 | left, top, bottom, right = 0, 0, 0, 0 143 | # Image 144 | new_image = torch.ones( 145 | (d1, d2 + top + bottom, d3 + left + right, d4), 146 | dtype=torch.float32, 147 | ) * 0.5 148 | new_image[:, top:top + d2, left:left + d3, :] = image 149 | 150 | im=tensor2pil(new_image) 151 | pil_new_image=im.convert('RGB') 152 | #---------------------------------------------------- 153 | 154 | # Create the mask 155 | mask = Image.new('L', target_size, 255) 156 | mask_draw = ImageDraw.Draw(mask) 157 | #---------------------------------------------------- 158 | # Calculate overlap areas 159 | white_gaps_patch = 2 160 | 161 | left_overlap = margin_x + overlap_x if overlap_left else margin_x + white_gaps_patch 162 | right_overlap = margin_x + new_width - overlap_x if overlap_right else margin_x + new_width - white_gaps_patch 163 | top_overlap = margin_y + overlap_y if overlap_top else margin_y + white_gaps_patch 164 | bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height - white_gaps_patch 165 | #---------------------------------------------------- 166 | # Mask coordinates 167 | if alignment == "Left": 168 | left_overlap = margin_x + overlap_x if overlap_left else margin_x 169 | elif alignment == "Right": 170 | right_overlap = margin_x + new_width - overlap_x if overlap_right else margin_x + new_width 171 | elif alignment == "Top": 172 | top_overlap = margin_y + overlap_y if overlap_top else margin_y 173 | elif alignment == "Bottom": 174 | bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height 175 | 176 | # Draw the mask 177 | mask_draw.rectangle([ 178 | (left_overlap, top_overlap), 179 | (right_overlap, bottom_overlap) 180 | ], fill=0) 181 | 182 | tensor_mask=pil2tensor(mask) 183 | #---------------------------------------------------- 184 | if not can_expand(background.width, background.height, width, height, alignment): 185 | alignment = "Middle" 186 | 187 | cnet_image = pil_new_image.copy() # copy background as cnet_image 188 | cnet_image.paste(0, (0, 0), mask) # paste mask over cnet_image, cropping it a bit 189 | 190 | tensor_cnet_image=pil2tensor(cnet_image) 191 | 192 | return (new_image, tensor_mask, tensor_cnet_image,) 193 | 194 | 195 | class LoadDiffuserModel: 196 | @classmethod 197 | def INPUT_TYPES(s): 198 | return { 199 | "required": { 200 | "unet_name": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "The name of 
the unet (model) to load."}), 201 | "device": (["auto", "cuda", "cpu", "mps", "xpu", "meta"],{"default": "auto", "tooltip": "Device for inference, default is auto checked by comfyui"}), 202 | "dtype": (["auto","fp16","bf16","fp32", "fp8_e4m3fn", "fp8_e4m3fnuz", "fp8_e5m2", "fp8_e5m2fnuz"],{"default":"auto", "tooltip": "Model precision for inference, default is auto checked by comfyui"}), 203 | "model_type": (get_config_folder_list("configs"), {"default": "sdxl", "tooltip": "The json configs used for the unet. (Put unet config in \"configs/your model type/unet\", and scheduler config in \"configs/your model type/scheduler\")."}), 204 | }, 205 | } 206 | 207 | RETURN_TYPES = ("MODEL", "SCHEDULER") 208 | RETURN_NAMES = ("model", "scheduler configs") 209 | FUNCTION = "load" 210 | CATEGORY = "DiffusersOutpaint" 211 | 212 | def load(self, unet_name, device, dtype, model_type): 213 | 214 | # Go 2 folders back 215 | comfy_dir = os.path.dirname(os.path.dirname(my_dir)) 216 | unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name) 217 | 218 | device = get_device_by_name(device) 219 | dtype = get_dtype_by_name(dtype) 220 | 221 | if model_type == "sdxl": 222 | print("Loading sdxl unet...") 223 | 224 | unet = UNet2DConditionModel.from_config(f"{comfy_dir}/custom_nodes/ComfyUI-DiffusersImageOutpaint/configs", subfolder=f"{model_type}/unet").to(device, dtype) 225 | unet.load_state_dict(load_file(unet_path)) 226 | 227 | scheduler = TCDScheduler.from_config(f"{comfy_dir}/custom_nodes/ComfyUI-DiffusersImageOutpaint/configs", subfolder=f"{model_type}/scheduler") 228 | 229 | scale_model_input_method = test_scheduler_scale_model_input(comfy_dir, model_type) 230 | 231 | scheduler_configs = { 232 | "scheduler": scheduler, 233 | "scale_model_input_method": scale_model_input_method, 234 | } 235 | 236 | return (unet, scheduler_configs,) 237 | 238 | 239 | class LoadDiffuserControlnet: 240 | @classmethod 241 | def INPUT_TYPES(s): 242 | return { 243 | "required": { 244 | "controlnet_model": (folder_paths.get_filename_list("controlnet"), {"tooltip": "The controlnet model used for denoising the input latent."}), 245 | "device": (["auto", "cuda", "cpu", "mps", "xpu", "meta"],{"default": "auto", "tooltip": "Device for inference, default is auto checked by comfyui"}), 246 | "dtype": (["auto","fp16","bf16","fp32", "fp8_e4m3fn", "fp8_e4m3fnuz", "fp8_e5m2", "fp8_e5m2fnuz"],{"default":"auto", "tooltip": "Model precision for inference, default is auto checked by comfyui"}), 247 | "controlnet_type": (get_config_folder_list("configs"), {"default": "controlnet-sdxl-promax", "tooltip": "The json configs used for controlnet. 
(Put config(s) in \"configs/your controlnet type\")."}), 248 | }, 249 | } 250 | 251 | RETURN_TYPES = ("CONTROL_NET",) 252 | FUNCTION = "load" 253 | CATEGORY = "DiffusersOutpaint" 254 | 255 | def load(self, controlnet_model, device, dtype, controlnet_type): 256 | 257 | # Go 2 folders back 258 | comfy_dir = os.path.dirname(os.path.dirname(my_dir)) 259 | controlnet_path = folder_paths.get_full_path_or_raise("controlnet", controlnet_model) 260 | 261 | device = get_device_by_name(device) 262 | dtype = get_dtype_by_name(dtype) 263 | 264 | if controlnet_type == "controlnet-sdxl-promax": 265 | print("Loading controlnet-sdxl-promax...") 266 | controlnet_model = ControlNetModel_Union.from_config(f"{comfy_dir}/custom_nodes/ComfyUI-DiffusersImageOutpaint/configs/{controlnet_type}/config_promax.json") 267 | 268 | state_dict = load_state_dict(load_file(controlnet_path)) 269 | 270 | model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model( 271 | controlnet_model, state_dict, controlnet_path, controlnet_path 272 | ) 273 | 274 | controlnet_model.to(device, dtype) 275 | 276 | del model, state_dict, controlnet_path 277 | 278 | clearVram(device) 279 | 280 | return (controlnet_model,) 281 | 282 | 283 | class EncodeDiffusersOutpaintPrompt: 284 | @classmethod 285 | def INPUT_TYPES(s): 286 | return { 287 | "required": { 288 | "device": (["auto", "cuda", "cpu", "mps", "xpu", "meta"],{"default": "auto", "tooltip": "Device for inference, default is auto checked by comfyui"}), 289 | "dtype": (["auto","fp16","bf16","fp32", "fp8_e4m3fn", "fp8_e4m3fnuz", "fp8_e5m2", "fp8_e5m2fnuz"],{"default":"auto", "tooltip": "Model precision for inference, default is auto checked by comfyui"}), 290 | "text": ("STRING", {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}), 291 | "clip": ("CLIP", {"tooltip": "The CLIP model used for encoding the text."}) 292 | } 293 | } 294 | RETURN_TYPES = ("CONDITIONING",) 295 | RETURN_NAMES = ("diffusers_conditioning",) 296 | OUTPUT_TOOLTIPS = ("A conditioning containing the embedded text used to guide the diffusion model.",) 297 | FUNCTION = "encode" 298 | CATEGORY = "DiffusersOutpaint" 299 | DESCRIPTION = "Encodes a text prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images." 
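# NOTE (hedged usage sketch; `clip` comes from a core loader such as DualCLIPLoader and is not
# defined in this file): `encode` appends ", high quality, 4k" to the prompt, tokenizes it with
# the supplied CLIP, and returns a plain dict rather than a regular ComfyUI conditioning. For
# SDXL dual-CLIP the shapes are typically:
#
#     (cond,) = EncodeDiffusersOutpaintPrompt().encode("auto", "auto", "a castle", clip)
#     cond["prompt_embeds"].shape          # e.g. torch.Size([1, 77, 2048])
#     cond["pooled_prompt_embeds"].shape   # e.g. torch.Size([1, 1280])
#
# Both the `positive` and `negative` inputs of DiffusersImageOutpaint expect this dict.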
300 | 301 | def encode(self, device, dtype, text, clip): 302 | device = get_device_by_name(device) 303 | dtype = get_dtype_by_name(dtype) 304 | 305 | text = f"{text}, high quality, 4k" 306 | tokens = clip.tokenize(text) 307 | output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True) 308 | prompt_embeds = output.pop("cond") 309 | 310 | prompt_embeds = prompt_embeds.to(device, dtype=dtype) 311 | pooled_prompt_embeds = output["pooled_output"].to(device, dtype=dtype) 312 | 313 | bs_embed, seq_len, _ = prompt_embeds.shape 314 | 315 | # duplicate text embeddings for each generation per prompt, using mps friendly method 316 | prompt_embeds = prompt_embeds.repeat(1, 1, 1) 317 | prompt_embeds = prompt_embeds.view(bs_embed * 1, seq_len, -1) 318 | 319 | pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view(bs_embed * 1, -1) 320 | 321 | diffusers_conditioning = { 322 | "prompt_embeds": prompt_embeds, 323 | "pooled_prompt_embeds": pooled_prompt_embeds, 324 | } 325 | 326 | return (diffusers_conditioning,) 327 | 328 | 329 | class DiffusersImageOutpaint: 330 | @classmethod 331 | def INPUT_TYPES(s): 332 | return { 333 | "required": { 334 | "model": ("MODEL", {"tooltip": "The model used for denoising the input latent."}), 335 | "scheduler_configs": ("SCHEDULER",), 336 | "control_net": ("CONTROL_NET",), 337 | "positive": ("CONDITIONING", {"tooltip": "The prompt describing what you want."}), 338 | "negative": ("CONDITIONING", {"tooltip": "The prompt describing what you don't want."}), 339 | "diffuser_outpaint_cnet_image": ("IMAGE", {"tooltip": "The image to outpaint."}), 340 | "guidance_scale": ("FLOAT", {"default": 1.50, "min": 1.01, "max": 10, "step": 0.01, "tooltip": "The Classifier-Free Guidance scale balances creativity and adherence to the prompt. 
Higher values result in images more closely matching the prompt, however too high values will negatively impact quality."}), 341 | "controlnet_strength": ("FLOAT", {"default": 1.00, "min": 0.00, "max": 10, "step": 0.01}), 342 | "steps": ("INT", {"default": 8, "min": 4, "max": 20, "tooltip": "The number of steps used in the denoising process."}), 343 | "device": (["auto", "cuda", "cpu", "mps", "xpu", "meta"],{"default": "auto", "tooltip": "Device for inference, default is auto checked by comfyui"}), 344 | "dtype": (["auto","fp16","bf16","fp32", "fp8_e4m3fn", "fp8_e4m3fnuz", "fp8_e5m2", "fp8_e5m2fnuz"],{"default":"auto", "tooltip": "Model precision for inference, default is auto checked by comfyui"}), 345 | "sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "Inference by default needs around 8gb vram, if this option is on it will move controlnet and unet back and forth between cpu and vram, to have only one model loaded at a time (around 6 gb vram used), useful for gpus under 8gb but will impact inference speed."}), 346 | }, 347 | } 348 | 349 | RETURN_TYPES = ("LATENT",) 350 | FUNCTION = "sample" 351 | CATEGORY = "DiffusersOutpaint" 352 | 353 | def sample(self, device, dtype, sequential_cpu_offload, scheduler_configs, model, control_net, positive, negative, diffuser_outpaint_cnet_image, guidance_scale, controlnet_strength, steps): 354 | cnet_image = diffuser_outpaint_cnet_image 355 | cnet_image=tensor2pil(cnet_image) 356 | cnet_image=cnet_image.convert('RGB') 357 | 358 | keep_model_device = sequential_cpu_offload 359 | 360 | scheduler = scheduler_configs["scheduler"] 361 | scale_model_input_method = scheduler_configs["scale_model_input_method"] 362 | 363 | last_rgb_latent = diffuserOutpaintSamples(device, dtype, keep_model_device, scheduler, scale_model_input_method, model, control_net, positive, negative, 364 | cnet_image, controlnet_strength, guidance_scale, steps) 365 | 366 | return ({"samples":last_rgb_latent},) 367 | -------------------------------------------------------------------------------- /pipeline_fill_sd_xl.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
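# NOTE (hedged usage sketch, not part of the original file): `StableDiffusionXLFillPipeline`
# below is a stripped-down SDXL + ControlNet-Union outpainting loop. Its `__call__` is a
# generator; `diffuserOutpaintSamples` in utils.py materialises it with `list(...)` and keeps
# the last yielded latent (already divided by the SDXL VAE scaling factor 0.13025). Names such
# as `unet`, `controlnet`, `scheduler`, `cond` and `uncond` are placeholders for objects
# produced elsewhere in this package:
#
#     pipe = StableDiffusionXLFillPipeline()
#     final_latent = list(pipe(
#         controlnet_model=controlnet, device="cuda", dtype=torch.float16,
#         keep_model_device=False, scheduler=scheduler, unet=unet, timesteps=None,
#         scale_model_input_method=scheduler.scale_model_input,
#         prompt_embeds=cond["prompt_embeds"],
#         pooled_prompt_embeds=cond["pooled_prompt_embeds"],
#         negative_prompt_embeds=uncond["prompt_embeds"],
#         negative_pooled_prompt_embeds=uncond["pooled_prompt_embeds"],
#         image=pil_image, num_inference_steps=8, guidance_scale=1.5,
#         controlnet_conditioning_scale=1.0,
#     ))[-1]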
14 | 15 | from typing import List, Optional, Union 16 | 17 | import cv2 18 | import PIL.Image 19 | import torch 20 | from diffusers.image_processor import PipelineImageInput, VaeImageProcessor 21 | from diffusers.schedulers import KarrasDiffusionSchedulers 22 | from diffusers.utils.torch_utils import randn_tensor 23 | from tqdm import tqdm 24 | 25 | from .controlnet_union import ControlNetModel_Union 26 | from comfy.utils import ProgressBar 27 | 28 | 29 | def latents_to_rgb(latents): 30 | weights = ((60, -60, 25, -70), (60, -5, 15, -50), (60, 10, -5, -35)) 31 | 32 | weights_tensor = torch.t( 33 | torch.tensor(weights, dtype=latents.dtype).to(latents.device) 34 | ) 35 | biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype).to( 36 | latents.device 37 | ) 38 | rgb_tensor = torch.einsum( 39 | "...lxy,lr -> ...rxy", latents, weights_tensor 40 | ) + biases_tensor.unsqueeze(-1).unsqueeze(-1) 41 | image_array = rgb_tensor.clamp(0, 255)[0].byte().cpu().numpy() 42 | image_array = image_array.transpose(1, 2, 0) # Change the order of dimensions 43 | 44 | denoised_image = cv2.fastNlMeansDenoisingColored(image_array, None, 10, 10, 7, 21) 45 | blurred_image = cv2.GaussianBlur(denoised_image, (5, 5), 0) 46 | final_image = PIL.Image.fromarray(blurred_image) 47 | 48 | width, height = final_image.size 49 | final_image = final_image.resize( 50 | (width * 8, height * 8), PIL.Image.Resampling.LANCZOS 51 | ) 52 | 53 | return final_image 54 | 55 | 56 | def retrieve_timesteps( 57 | scheduler, 58 | num_inference_steps: Optional[int] = None, 59 | device: Optional[Union[str, torch.device]] = None, 60 | **kwargs, 61 | ): 62 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) 63 | timesteps = scheduler.timesteps 64 | 65 | return timesteps, num_inference_steps 66 | 67 | 68 | class StableDiffusionXLFillPipeline: 69 | 70 | def __init__( 71 | self, 72 | ): 73 | super().__init__() 74 | 75 | self.vae_scale_factor = 8 76 | self.image_processor = VaeImageProcessor( 77 | vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True 78 | ) 79 | self.control_image_processor = VaeImageProcessor( 80 | vae_scale_factor=self.vae_scale_factor, 81 | do_convert_rgb=True, 82 | do_normalize=False, 83 | ) 84 | 85 | def prepare_image(self, image, device, dtype, do_classifier_free_guidance=False): 86 | image = self.control_image_processor.preprocess(image).to(dtype=torch.float32) 87 | 88 | image_batch_size = image.shape[0] 89 | 90 | image = image.repeat_interleave(image_batch_size, dim=0) 91 | image = image.to(device=device, dtype=dtype) 92 | 93 | if do_classifier_free_guidance: 94 | image = torch.cat([image] * 2) 95 | 96 | return image 97 | 98 | def prepare_latents( 99 | self, batch_size, num_channels_latents, height, width, dtype, device 100 | ): 101 | shape = ( 102 | batch_size, 103 | num_channels_latents, 104 | int(height) // self.vae_scale_factor, 105 | int(width) // self.vae_scale_factor, 106 | ) 107 | 108 | latents = randn_tensor(shape, device=device, dtype=dtype) 109 | 110 | # scale the initial noise by the standard deviation required by the scheduler 111 | latents = latents * self.scheduler.init_noise_sigma 112 | return latents 113 | 114 | @property 115 | def guidance_scale(self): 116 | return self._guidance_scale 117 | 118 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 119 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 120 | # corresponds to doing no classifier free guidance. 
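# NOTE (illustrative): concretely, step 8 of __call__ later combines the two predictions as
#     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# so guidance_scale == 1 collapses to the conditional prediction alone, which is why the
# property below only enables CFG (and the doubled batch) for guidance_scale > 1.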
121 | @property 122 | def do_classifier_free_guidance(self): 123 | if hasattr(self.unet, 'config'): 124 | return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None 125 | else: 126 | return self._guidance_scale > 1 127 | 128 | @property 129 | def num_timesteps(self): 130 | return self._num_timesteps 131 | 132 | @torch.no_grad() 133 | def __call__( 134 | self, 135 | controlnet_model, 136 | device, 137 | dtype, 138 | keep_model_device, 139 | scheduler: KarrasDiffusionSchedulers, 140 | unet: object, 141 | timesteps, 142 | scale_model_input_method, 143 | prompt_embeds: torch.Tensor, 144 | pooled_prompt_embeds: torch.Tensor, 145 | negative_prompt_embeds: torch.Tensor, 146 | negative_pooled_prompt_embeds: torch.Tensor, 147 | image: PipelineImageInput = None, 148 | num_inference_steps: int = 8, 149 | guidance_scale: float = 1.5, 150 | controlnet_conditioning_scale: Union[float, List[float]] = 1.0, 151 | ): 152 | self.controlnet = controlnet_model 153 | self._guidance_scale = guidance_scale 154 | self.unet = unet 155 | self.scheduler = scheduler 156 | self.timesteps = timesteps 157 | self.scale_model_input_method=scale_model_input_method 158 | 159 | 160 | # 2. Define call parameters 161 | batch_size = 1 162 | 163 | # 4. Prepare image 164 | if isinstance(self.controlnet, ControlNetModel_Union): 165 | image = self.prepare_image( 166 | image=image, 167 | device=device, 168 | dtype=self.controlnet.dtype, 169 | do_classifier_free_guidance=self.do_classifier_free_guidance, 170 | ) 171 | height, width = image.shape[-2:] 172 | else: 173 | assert False 174 | 175 | # 5. Prepare timesteps 176 | timesteps, num_inference_steps = retrieve_timesteps( 177 | self.scheduler, num_inference_steps, device 178 | ) 179 | self._num_timesteps = len(timesteps) 180 | 181 | # 6. Prepare latent variables 182 | num_channels_latents = self.unet.config.in_channels 183 | latents = self.prepare_latents( 184 | batch_size, 185 | num_channels_latents, 186 | height, 187 | width, 188 | dtype, 189 | device, 190 | ) 191 | 192 | # 7 Prepare added time ids & embeddings 193 | add_text_embeds = pooled_prompt_embeds 194 | 195 | add_time_ids = negative_add_time_ids = torch.tensor( 196 | image.shape[-2:] + torch.Size([0, 0]) + image.shape[-2:] 197 | ).unsqueeze(0) 198 | 199 | if self.do_classifier_free_guidance: 200 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) 201 | add_text_embeds = torch.cat( 202 | [negative_pooled_prompt_embeds, add_text_embeds], dim=0 203 | ) 204 | add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) 205 | 206 | add_text_embeds = add_text_embeds.to(device) 207 | add_time_ids = add_time_ids.to(device).repeat(batch_size, 1) 208 | 209 | controlnet_image_list = [0, 0, 0, 0, 0, 0, image, 0] 210 | union_control_type = ( 211 | torch.Tensor([0, 0, 0, 0, 0, 0, 1, 0]) 212 | .to(device, dtype=prompt_embeds.dtype) 213 | .repeat(batch_size * 2, 1) 214 | ) 215 | 216 | added_cond_kwargs = { 217 | "text_embeds": add_text_embeds, 218 | "time_ids": add_time_ids, 219 | "control_type": union_control_type, 220 | } 221 | 222 | controlnet_prompt_embeds = prompt_embeds.to(device) 223 | controlnet_added_cond_kwargs = added_cond_kwargs 224 | 225 | # 8. 
Denoising loop 226 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 227 | ComfyUI_ProgressBar = ProgressBar(int(num_inference_steps)) 228 | 229 | with tqdm(total=num_inference_steps) as pbar: 230 | for i, t in enumerate(timesteps): 231 | # expand the latents if we are doing classifier free guidance 232 | latent_model_input = ( 233 | torch.cat([latents] * 2) 234 | if self.do_classifier_free_guidance 235 | else latents 236 | ) 237 | latent_model_input = self.scheduler.scale_model_input( 238 | latent_model_input, t 239 | ) 240 | 241 | # controlnet(s) inference 242 | control_model_input = latent_model_input 243 | 244 | self.controlnet.to(device) 245 | down_block_res_samples, mid_block_res_sample = self.controlnet( 246 | control_model_input, 247 | t, 248 | encoder_hidden_states=controlnet_prompt_embeds, 249 | controlnet_cond_list=controlnet_image_list, 250 | conditioning_scale=controlnet_conditioning_scale, 251 | guess_mode=False, 252 | added_cond_kwargs=controlnet_added_cond_kwargs, 253 | return_dict=False, 254 | ) 255 | 256 | if keep_model_device: 257 | self.controlnet.to('cpu') 258 | 259 | try: 260 | # predict the noise residual 261 | self.unet.to(device) 262 | noise_pred = self.unet( 263 | latent_model_input, 264 | t, 265 | encoder_hidden_states=prompt_embeds, 266 | timestep_cond=None, 267 | cross_attention_kwargs={}, 268 | down_block_additional_residuals=down_block_res_samples, 269 | mid_block_additional_residual=mid_block_res_sample, 270 | added_cond_kwargs=added_cond_kwargs, 271 | return_dict=False, 272 | )[0] 273 | if keep_model_device: 274 | self.unet.to('cpu') 275 | except torch.cuda.OutOfMemoryError as e: # Free vram when OOM 276 | self.unet.to('cpu') 277 | print('\033[93m', 'Gpu is out of memory!', '\033[0m') 278 | raise e 279 | 280 | # perform guidance 281 | if self.do_classifier_free_guidance: 282 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 283 | noise_pred = noise_pred_uncond + guidance_scale * ( 284 | noise_pred_text - noise_pred_uncond 285 | ) 286 | 287 | # compute the previous noisy sample x_t -> x_t-1 288 | latents = self.scheduler.step( 289 | noise_pred, t, latents, return_dict=False 290 | )[0] 291 | 292 | if i == 2: 293 | prompt_embeds = prompt_embeds[-1:] 294 | add_text_embeds = add_text_embeds[-1:] 295 | add_time_ids = add_time_ids[-1:] 296 | union_control_type = union_control_type[-1:] 297 | 298 | added_cond_kwargs = { 299 | "text_embeds": add_text_embeds, 300 | "time_ids": add_time_ids, 301 | "control_type": union_control_type, 302 | } 303 | 304 | controlnet_prompt_embeds = prompt_embeds 305 | controlnet_added_cond_kwargs = added_cond_kwargs 306 | 307 | image = image[-1:] 308 | controlnet_image_list = [0, 0, 0, 0, 0, 0, image, 0] 309 | 310 | self._guidance_scale = 0.0 311 | 312 | if i == len(timesteps) - 1 or ( 313 | (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 314 | ): 315 | pbar.update() 316 | ComfyUI_ProgressBar.update(1) 317 | 318 | latents = latents / 0.13025 319 | yield latents 320 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui-diffusersimageoutpaint" 3 | description = "ComfyUI nodes for outpainting images with diffusers, based on [a/diffusers-image-outpaint](https://huggingface.co/spaces/fffiloni/diffusers-image-outpaint/tree/main) by fffiloni." 
4 | version = "1.0.0" 5 | license = {file = "LICENSE"} 6 | dependencies = ["torch", "numpy==1.26.4", "transformers", "accelerate", "diffusers", "fastapi<0.113.0", "opencv-python"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/GiusTex/ComfyUI-DiffusersImageOutpaint" 10 | # Used by Comfy Registry https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "giustex" 14 | DisplayName = "ComfyUI-DiffusersImageOutpaint" 15 | Icon = "" 16 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | numpy==1.26.4 3 | transformers==4.45.0 4 | accelerate 5 | diffusers==0.32.2 6 | fastapi<0.113.0 7 | opencv-python 8 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import gc 3 | import os 4 | import numpy as np 5 | import json 6 | import comfy.model_management as mm 7 | from PIL import Image 8 | 9 | from folder_paths import map_legacy, folder_names_and_paths 10 | from .pipeline_fill_sd_xl import StableDiffusionXLFillPipeline 11 | 12 | 13 | def get_first_folder_list(folder_name: str) -> tuple[list[str], dict[str, float], float]: 14 | folder_name = map_legacy(folder_name) 15 | global folder_names_and_paths 16 | folders = folder_names_and_paths[folder_name] 17 | if folder_name == "unet": 18 | root_folder = folders[0][0] 19 | elif folder_name == "diffusion_models": 20 | root_folder = folders[0][1] 21 | elif folder_name == "controlnet": 22 | root_folder = folders[0][0] 23 | visible_folders = [name for name in os.listdir(root_folder) if os.path.isdir(os.path.join(root_folder, name))] 24 | return visible_folders 25 | 26 | def get_config_folder_list(folder_name: str) -> tuple[list[str], dict[str, float], float]: 27 | my_dir = os.path.dirname(os.path.abspath(__file__)) 28 | configs_dir = f"{my_dir}/{folder_name}" 29 | 30 | folders = [f for f in os.listdir(configs_dir) if os.path.isdir(os.path.join(configs_dir, f))] 31 | return folders 32 | 33 | # Tensor to PIL (grabbed from WAS Suite) 34 | def tensor2pil(image: torch.Tensor) -> Image.Image: 35 | return Image.fromarray(np.clip(255. 
* image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)) 36 | 37 | # Convert PIL to Tensor (grabbed from WAS Suite) 38 | def pil2tensor(image: Image.Image) -> torch.Tensor: 39 | return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0) 40 | 41 | 42 | def get_device_by_name(device): 43 | if device == 'auto': 44 | device = mm.get_torch_device() 45 | return device 46 | 47 | 48 | def get_dtype_by_name(dtype): 49 | if dtype == 'auto': 50 | if mm.should_use_fp16(): 51 | dtype = torch.float16 52 | elif mm.should_use_bf16(): 53 | dtype = torch.bfloat16 54 | else: 55 | dtype = torch.float32 56 | elif dtype== "fp16": 57 | dtype = torch.float16 58 | elif dtype == "bf16": 59 | dtype = torch.bfloat16 60 | elif dtype == "fp32": 61 | dtype = torch.float32 62 | elif dtype == "fp8_e4m3fn": 63 | dtype = torch.float8_e4m3fn 64 | elif dtype == "fp8_e4m3fnuz": 65 | dtype = torch.float8_e4m3fnuz 66 | elif dtype == "fp8_e5m2": 67 | dtype = torch.float8_e5m2 68 | elif dtype == "fp8_e5m2fnuz": 69 | dtype = torch.float8_e5m2fnuz 70 | 71 | return dtype 72 | 73 | def clearVram(device): 74 | gc.collect() 75 | 76 | if device.type == "cuda": 77 | torch.cuda.empty_cache() 78 | torch.cuda.ipc_collect() 79 | elif device.type == "mps": 80 | torch.mps.empty_cache() 81 | elif device.type == "xla": 82 | torch.xla.empty_cache() 83 | elif device.type == "xpu": 84 | torch.xpu.empty_cache() 85 | elif device.type == "meta": 86 | torch.meta.empty_cache() 87 | 88 | 89 | class TCDScheduler_Custom: 90 | def __init__(self, **kwargs): 91 | for key, value in kwargs.items(): 92 | setattr(self, key, value) 93 | 94 | def scale_model_input(self, input, t): 95 | scale_factor = getattr(self, 'scale_factor', 1) 96 | return input * scale_factor 97 | 98 | def __repr__(self): 99 | attrs = {key: value for key, value in self.__dict__.items()} 100 | return f"TCDScheduler({attrs})" 101 | 102 | 103 | def test_scheduler_scale_model_input(comfy_dir, model_type): 104 | scheduler_config_path = f"{comfy_dir}/custom_nodes/ComfyUI-DiffusersImageOutpaint/configs/{model_type}/scheduler/scheduler_config.json" 105 | 106 | with open(scheduler_config_path, 'r') as f: 107 | config = json.load(f) 108 | 109 | scheduler = TCDScheduler_Custom(**config) 110 | scale_model_input_method = scheduler.scale_model_input 111 | 112 | return scale_model_input_method 113 | 114 | 115 | def diffuserOutpaintSamples(device, dtype, keep_model_device, scheduler, scale_model_input_method, model, control_net, positive, negative, 116 | cnet_image, controlnet_strength, guidance_scale, steps): 117 | 118 | prompt_embeds = positive["prompt_embeds"] 119 | pooled_prompt_embeds = positive["pooled_prompt_embeds"] 120 | negative_prompt_embeds = negative["prompt_embeds"] 121 | negative_pooled_prompt_embeds = negative["pooled_prompt_embeds"] 122 | controlnet_model = control_net 123 | 124 | device = get_device_by_name(device) 125 | dtype = get_dtype_by_name(dtype) 126 | 127 | timesteps = None 128 | unet = model 129 | 130 | pipe = StableDiffusionXLFillPipeline() 131 | 132 | rgb_latents = list(pipe( 133 | prompt_embeds=prompt_embeds, 134 | negative_prompt_embeds=negative_prompt_embeds, 135 | pooled_prompt_embeds=pooled_prompt_embeds, 136 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 137 | image=cnet_image, 138 | num_inference_steps=steps, 139 | controlnet_model=controlnet_model, 140 | controlnet_conditioning_scale=controlnet_strength, 141 | guidance_scale=guidance_scale, 142 | device=device, 143 | dtype=dtype, 144 | unet=unet, 145 | timesteps=timesteps, 146 | 
scale_model_input_method=scale_model_input_method, 147 | keep_model_device=keep_model_device, 148 | scheduler=scheduler, 149 | )) 150 | 151 | last_rgb_latent = rgb_latents[-1] # keep the final latent yielded by the pipeline 152 | 153 | del pipe, unet, controlnet_model, scheduler, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds 154 | 155 | clearVram(device) 156 | 157 | return last_rgb_latent 158 | --------------------------------------------------------------------------------
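End-to-end sketch (illustrative only). The bundled diffusers-image-outpaint-workflow.json wires the custom nodes in the order below; the Python equivalent assumes a running ComfyUI environment in which these classes are importable and in which `clip` and `image_tensor` (plus a VAE for decoding) come from core loader nodes — it is a sketch of the data flow, not a supported API:

    model, scheduler_configs = LoadDiffuserModel().load(
        "RealVisXL_V5.0_Lightning_unet.safetensors", "auto", "auto", "sdxl")
    (control_net,) = LoadDiffuserControlnet().load(
        "controlnet-union-promax_sdxl.safetensors", "auto", "auto", "controlnet-sdxl-promax")
    (positive,) = EncodeDiffusersOutpaintPrompt().encode(
        "auto", "auto", "a verdant valley with waterfalls, rainbow", clip)
    (negative,) = EncodeDiffusersOutpaintPrompt().encode("auto", "auto", "", clip)
    _, _, cnet_image = PadImageForDiffusersOutpaint().prepare_image_and_mask(
        image_tensor, 720, 1280, 10, "Full", 50, True, True, True, True, alignment="Middle")
    (latent,) = DiffusersImageOutpaint().sample(
        "auto", "auto", False, scheduler_configs, model, control_net,
        positive, negative, cnet_image, 1.5, 1.0, 8)
    # latent["samples"] is then decoded by a standard VAEDecode node (e.g. with sdxl_vae.safetensors).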