├── .gitignore ├── LICENSE ├── README.md ├── assets ├── bg_cond.png ├── controlnet_output.png ├── demo_result.png ├── dreamshaper_sd.png ├── fg_cond.png ├── ipadapter_output.png ├── man_crop.png ├── man_mask.png ├── result_bg_fg_cond.png ├── result_blended_bg_fg_cond.png ├── result_blended_fg_bg_cond.png ├── result_conv_sdxl.png ├── result_fg_bg_cond.png ├── result_inpaint_sdxl.png ├── result_joint_0.png ├── result_joint_1.png ├── result_joint_2.png └── result_sdxl.png ├── layer_diffuse ├── __init__.py ├── loaders.py ├── models │ ├── __init__.py │ ├── attention_processors.py │ └── modules.py └── utils.py ├── requirements.txt ├── test_diffusers_bg_fg_cond.py ├── test_diffusers_fg_bg_cond.py ├── test_diffusers_fg_only.py ├── test_diffusers_fg_only_conv_sdxl.py ├── test_diffusers_fg_only_sdxl.py ├── test_diffusers_fg_only_sdxl_img2img.py └── test_diffusers_joint.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Vinh H. Pham 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Diffusers API of Transparent Image Layer Diffusion using Latent Transparency 2 | 3 | 🤗 **Hugging Face**: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/rootonchair/diffuser_layerdiffuse) 🔥🔥🔥 4 | 5 | Create transparent image with Diffusers! 
6 | 7 | This is a port of the original [SD Webui's Layer Diffusion](https://github.com/layerdiffusion/sd-forge-layerdiffuse) to Diffusers, extending the ability to generate transparent images to your favorite API. 8 | 9 | 10 | Paper: [Transparent Image Layer Diffusion using Latent Transparency](https://arxiv.org/abs/2402.17113) 11 | ## Setup 12 | ```bash 13 | pip install -r requirements.txt 14 | ``` 15 | 16 | ## Quickstart 17 | 18 | Generate transparent images with SD1.5 models. In this example, we will use [digiplay/Juggernaut_final](https://huggingface.co/digiplay/Juggernaut_final) as the base model. 19 | 20 | ### Stable Diffusion 1.5 21 | 22 | ```python 23 | from huggingface_hub import hf_hub_download 24 | from safetensors.torch import load_file 25 | import torch 26 | 27 | from diffusers import StableDiffusionPipeline 28 | 29 | from layer_diffuse.models import TransparentVAEDecoder 30 | from layer_diffuse.loaders import load_lora_to_unet 31 | 32 | 33 | model_path = hf_hub_download( 34 | 'LayerDiffusion/layerdiffusion-v1', 35 | 'layer_sd15_vae_transparent_decoder.safetensors', 36 | ) 37 | 38 | vae_transparent_decoder = TransparentVAEDecoder.from_pretrained("digiplay/Juggernaut_final", subfolder="vae", torch_dtype=torch.float16).to("cuda") 39 | vae_transparent_decoder.set_transparent_decoder(load_file(model_path)) 40 | 41 | pipeline = StableDiffusionPipeline.from_pretrained("digiplay/Juggernaut_final", vae=vae_transparent_decoder, torch_dtype=torch.float16, safety_checker=None).to("cuda") 42 | 43 | model_path = hf_hub_download( 44 | 'LayerDiffusion/layerdiffusion-v1', 45 | 'layer_sd15_transparent_attn.safetensors' 46 | ) 47 | 48 | load_lora_to_unet(pipeline.unet, model_path, frames=1) 49 | 50 | images = pipeline(prompt="a dog sitting in room, high quality", 51 | width=512, height=512, 52 | num_images_per_prompt=1, return_dict=False)[0] 53 | ``` 54 | 55 | This will produce the image below: 56 | 57 | ![demo_result](assets/demo_result.png) 58 | 59 | ### Stable Diffusion XL 60 | 61 | The SDXL version is a LoRA and is compatible with any Diffusers usage: ControlNet, IP-Adapter, etc. 62 | 63 | ```python 64 | from huggingface_hub import hf_hub_download 65 | from safetensors.torch import load_file 66 | import torch 67 | 68 | from diffusers import StableDiffusionXLPipeline 69 | 70 | from layer_diffuse.models import TransparentVAEDecoder 71 | 72 | 73 | transparent_vae = TransparentVAEDecoder.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) 74 | model_path = hf_hub_download( 75 | 'LayerDiffusion/layerdiffusion-v1', 76 | 'vae_transparent_decoder.safetensors', 77 | ) 78 | transparent_vae.set_transparent_decoder(load_file(model_path)) 79 | 80 | pipeline = StableDiffusionXLPipeline.from_pretrained( 81 | "stabilityai/stable-diffusion-xl-base-1.0", 82 | vae=transparent_vae, 83 | torch_dtype=torch.float16, variant="fp16", use_safetensors=True 84 | ).to("cuda") 85 | pipeline.load_lora_weights('rootonchair/diffuser_layerdiffuse', weight_name='diffuser_layer_xl_transparent_attn.safetensors') 86 | 87 | seed = torch.randint(high=1000000, size=(1,)).item() 88 | prompt = "a cute corgi" 89 | negative_prompt = "" 90 | images = pipeline(prompt=prompt, 91 | negative_prompt=negative_prompt, 92 | generator=torch.Generator(device='cuda').manual_seed(seed), 93 | num_images_per_prompt=1, return_dict=False)[0] 94 | 95 | images[0].save("result_sdxl.png") 96 | ``` 97 | 98 | ## Scripts 99 | 100 | - `test_diffusers_fg_only.py`: Generate only a transparent foreground image 101 | - `test_diffusers_joint.py`: Generate foreground, background, and blended images together.
Hence `num_images_per_prompt` must be set to a batch size of 3 102 | - `test_diffusers_fg_bg_cond.py`: Generate a foreground conditioned on a provided background. Hence `num_images_per_prompt` must be set to a batch size of 2 103 | - `test_diffusers_bg_fg_cond.py`: Generate a background conditioned on a provided foreground. Hence `num_images_per_prompt` must be set to a batch size of 2 105 | - `test_diffusers_fg_only_sdxl.py`: Generate only a transparent foreground image using Attention injection in SDXL 106 | - `test_diffusers_fg_only_conv_sdxl.py`: Generate only a transparent foreground image using Conv injection in SDXL 107 | - `test_diffusers_fg_only_sdxl_img2img.py`: Inpaint a transparent foreground image using Attention injection in SDXL 108 | 109 | According to the author, Attention injection results in better generation quality, while Conv injection results in better prompt alignment. 110 | 111 | ## Example 112 | ### Stable Diffusion 1.5 113 | #### Generate only a transparent image with SD1.5 114 | ![demo_dreamshaper](assets/dreamshaper_sd.png) 115 | #### Generate foreground and background together 116 | 117 | | Foreground | Background | Blended | 118 | |:-------------------------------------:|:-------------------------------------:|:-------------------------------------:| 119 | | ![fg](assets/result_joint_0.png) | ![bg](assets/result_joint_1.png) | ![blend](assets/result_joint_2.png) | 120 | 121 | 122 | #### Use with ControlNet 123 | 124 | ![controlnet](assets/controlnet_output.png) 125 | 126 | #### Use with IP-Adapter 127 | 128 | ![ip_adapter](assets/ipadapter_output.png) 129 | 130 | #### Generate foreground conditioned on background 131 | 132 | The blended image will not have the correct colors, but you can composite the foreground image onto the conditioning background yourself, as in the sketch below.
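For example, here is a minimal compositing sketch using Pillow (not one of the repo's scripts; the file names are placeholders for a saved foreground result and the background image that was used as the condition):

```python
from PIL import Image

# Placeholder paths: an RGBA foreground saved from test_diffusers_fg_bg_cond.py
# and the background image that was passed in as the condition.
foreground = Image.open("result_fg_bg_cond.png").convert("RGBA")
background = Image.open("bg_cond.png").convert("RGBA")

# Match the background to the foreground size, then alpha-composite the
# transparent foreground on top of it.
background = background.resize(foreground.size)
blended = Image.alpha_composite(background, foreground)
blended.save("result_blended_fg_bg_cond.png")
```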
133 | 134 | | Foreground | Background (Condition) | Blended | 135 | |:-------------------------------------:|:-------------------------------------:|:-------------------------------------:| 136 | | ![fg](assets/result_fg_bg_cond.png) | ![bg](assets/bg_cond.png) | ![blend](assets/result_blended_fg_bg_cond.png) | 137 | 138 | 139 | #### Generate background conditioned on foreground 140 | 141 | | Foreground (Condition) | Background | Blended | 142 | |:-------------------------------------:|:-------------------------------------:|:-------------------------------------:| 143 | | ![fg](assets/fg_cond.png) | ![bg](assets/result_bg_fg_cond.png) | ![blend](assets/result_blended_bg_fg_cond.png) | 144 | 145 | ### Stable Diffusion XL 146 | #### Combine with other LoRAs 147 | Combine with the SDXL LoRA [nerijs/pixel-art-xl](https://huggingface.co/nerijs/pixel-art-xl) 148 | 149 | | Attn Injection (LoRA) | Conv Injection (Weight diff) | 150 | |:-------------------------------------:|:-------------------------------------:| 151 | | ![sdxl_attn](assets/result_sdxl.png) | ![sdxl_conv](assets/result_conv_sdxl.png) | 152 | 153 | #### Inpaint 154 | Use the inpaint pipeline to refine a poorly cropped transparent image 155 | 156 | | Foreground | Mask | Inpaint | 157 | |:-------------------------------------:|:-------------------------------------:|:-------------------------------------:| 158 | | ![man_crop](assets/man_crop.png) | ![mask](assets/man_mask.png) | ![inpaint](assets/result_inpaint_sdxl.png) | 159 | 160 | ## Acknowledgments 161 | This work is based on the great code at 162 | [https://github.com/layerdiffusion/sd-forge-layerdiffuse](https://github.com/layerdiffusion/sd-forge-layerdiffuse) 163 | -------------------------------------------------------------------------------- /assets/bg_cond.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rootonchair/diffuser_layerdiffuse/d06cf19244c024e50944f59703842ecd4ae6478d/assets/bg_cond.png -------------------------------------------------------------------------------- /assets/controlnet_output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rootonchair/diffuser_layerdiffuse/d06cf19244c024e50944f59703842ecd4ae6478d/assets/controlnet_output.png -------------------------------------------------------------------------------- /assets/demo_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rootonchair/diffuser_layerdiffuse/d06cf19244c024e50944f59703842ecd4ae6478d/assets/demo_result.png -------------------------------------------------------------------------------- /assets/dreamshaper_sd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rootonchair/diffuser_layerdiffuse/d06cf19244c024e50944f59703842ecd4ae6478d/assets/dreamshaper_sd.png -------------------------------------------------------------------------------- /assets/fg_cond.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rootonchair/diffuser_layerdiffuse/d06cf19244c024e50944f59703842ecd4ae6478d/assets/fg_cond.png -------------------------------------------------------------------------------- /assets/ipadapter_output.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/rootonchair/diffuser_layerdiffuse/d06cf19244c024e50944f59703842ecd4ae6478d/assets/ipadapter_output.png -------------------------------------------------------------------------------- /assets/man_crop.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rootonchair/diffuser_layerdiffuse/d06cf19244c024e50944f59703842ecd4ae6478d/assets/man_crop.png -------------------------------------------------------------------------------- /assets/man_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rootonchair/diffuser_layerdiffuse/d06cf19244c024e50944f59703842ecd4ae6478d/assets/man_mask.png -------------------------------------------------------------------------------- /assets/result_bg_fg_cond.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rootonchair/diffuser_layerdiffuse/d06cf19244c024e50944f59703842ecd4ae6478d/assets/result_bg_fg_cond.png -------------------------------------------------------------------------------- /assets/result_blended_bg_fg_cond.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rootonchair/diffuser_layerdiffuse/d06cf19244c024e50944f59703842ecd4ae6478d/assets/result_blended_bg_fg_cond.png -------------------------------------------------------------------------------- /assets/result_blended_fg_bg_cond.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rootonchair/diffuser_layerdiffuse/d06cf19244c024e50944f59703842ecd4ae6478d/assets/result_blended_fg_bg_cond.png -------------------------------------------------------------------------------- /assets/result_conv_sdxl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rootonchair/diffuser_layerdiffuse/d06cf19244c024e50944f59703842ecd4ae6478d/assets/result_conv_sdxl.png -------------------------------------------------------------------------------- /assets/result_fg_bg_cond.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rootonchair/diffuser_layerdiffuse/d06cf19244c024e50944f59703842ecd4ae6478d/assets/result_fg_bg_cond.png -------------------------------------------------------------------------------- /assets/result_inpaint_sdxl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rootonchair/diffuser_layerdiffuse/d06cf19244c024e50944f59703842ecd4ae6478d/assets/result_inpaint_sdxl.png -------------------------------------------------------------------------------- /assets/result_joint_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rootonchair/diffuser_layerdiffuse/d06cf19244c024e50944f59703842ecd4ae6478d/assets/result_joint_0.png -------------------------------------------------------------------------------- /assets/result_joint_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rootonchair/diffuser_layerdiffuse/d06cf19244c024e50944f59703842ecd4ae6478d/assets/result_joint_1.png -------------------------------------------------------------------------------- /assets/result_joint_2.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/rootonchair/diffuser_layerdiffuse/d06cf19244c024e50944f59703842ecd4ae6478d/assets/result_joint_2.png -------------------------------------------------------------------------------- /assets/result_sdxl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rootonchair/diffuser_layerdiffuse/d06cf19244c024e50944f59703842ecd4ae6478d/assets/result_sdxl.png -------------------------------------------------------------------------------- /layer_diffuse/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rootonchair/diffuser_layerdiffuse/d06cf19244c024e50944f59703842ecd4ae6478d/layer_diffuse/__init__.py -------------------------------------------------------------------------------- /layer_diffuse/loaders.py: -------------------------------------------------------------------------------- 1 | from safetensors.torch import load_file 2 | from diffusers.models.attention_processor import Attention, IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, AttnProcessor2_0 3 | from .models import LoraLoader, AttentionSharingProcessor, IPAdapterAttnShareProcessor, AttentionSharingProcessor2_0, IPAdapterAttnShareProcessor2_0 4 | 5 | 6 | def merge_delta_weights_into_unet(pipe, delta_weights): 7 | unet_weights = pipe.unet.state_dict() 8 | 9 | for k in delta_weights.keys(): 10 | assert k in unet_weights.keys(), k 11 | 12 | for key in delta_weights.keys(): 13 | dtype = unet_weights[key].dtype 14 | unet_weights[key] = unet_weights[key].to(dtype=delta_weights[key].dtype) + delta_weights[key].to(device=unet_weights[key].device) 15 | unet_weights[key] = unet_weights[key].to(dtype) 16 | pipe.unet.load_state_dict(unet_weights, strict=True) 17 | return pipe 18 | 19 | 20 | def get_kwargs_encoder(): 21 | pass 22 | 23 | 24 | def get_attr(obj, attr): 25 | attrs = attr.split(".") 26 | for name in attrs: 27 | obj = getattr(obj, name) 28 | return obj 29 | 30 | 31 | def load_lora_to_unet(unet, model_path, frames=1, use_control=False): 32 | module_mapping_sd15 = {0: 'input_blocks.1.1.transformer_blocks.0.attn1', 1: 'input_blocks.1.1.transformer_blocks.0.attn2', 2: 'input_blocks.2.1.transformer_blocks.0.attn1', 3: 'input_blocks.2.1.transformer_blocks.0.attn2', 4: 'input_blocks.4.1.transformer_blocks.0.attn1', 5: 'input_blocks.4.1.transformer_blocks.0.attn2', 6: 'input_blocks.5.1.transformer_blocks.0.attn1', 7: 'input_blocks.5.1.transformer_blocks.0.attn2', 8: 'input_blocks.7.1.transformer_blocks.0.attn1', 9: 'input_blocks.7.1.transformer_blocks.0.attn2', 10: 'input_blocks.8.1.transformer_blocks.0.attn1', 11: 'input_blocks.8.1.transformer_blocks.0.attn2', 12: 'output_blocks.3.1.transformer_blocks.0.attn1', 13: 'output_blocks.3.1.transformer_blocks.0.attn2', 14: 'output_blocks.4.1.transformer_blocks.0.attn1', 15: 'output_blocks.4.1.transformer_blocks.0.attn2', 16: 'output_blocks.5.1.transformer_blocks.0.attn1', 17: 'output_blocks.5.1.transformer_blocks.0.attn2', 18: 'output_blocks.6.1.transformer_blocks.0.attn1', 19: 'output_blocks.6.1.transformer_blocks.0.attn2', 20: 'output_blocks.7.1.transformer_blocks.0.attn1', 21: 'output_blocks.7.1.transformer_blocks.0.attn2', 22: 'output_blocks.8.1.transformer_blocks.0.attn1', 23: 'output_blocks.8.1.transformer_blocks.0.attn2', 24: 'output_blocks.9.1.transformer_blocks.0.attn1', 25: 'output_blocks.9.1.transformer_blocks.0.attn2', 26: 
'output_blocks.10.1.transformer_blocks.0.attn1', 27: 'output_blocks.10.1.transformer_blocks.0.attn2', 28: 'output_blocks.11.1.transformer_blocks.0.attn1', 29: 'output_blocks.11.1.transformer_blocks.0.attn2', 30: 'middle_block.1.transformer_blocks.0.attn1', 31: 'middle_block.1.transformer_blocks.0.attn2'} 33 | 34 | sd15_to_diffusers = { 35 | 'input_blocks.1.1.transformer_blocks.0.attn1': 'down_blocks.0.attentions.0.transformer_blocks.0.attn1', 36 | 'input_blocks.1.1.transformer_blocks.0.attn2': 'down_blocks.0.attentions.0.transformer_blocks.0.attn2', 37 | 'input_blocks.2.1.transformer_blocks.0.attn1': 'down_blocks.0.attentions.1.transformer_blocks.0.attn1', 38 | 'input_blocks.2.1.transformer_blocks.0.attn2': 'down_blocks.0.attentions.1.transformer_blocks.0.attn2', 39 | 'input_blocks.4.1.transformer_blocks.0.attn1': 'down_blocks.1.attentions.0.transformer_blocks.0.attn1', 40 | 'input_blocks.4.1.transformer_blocks.0.attn2': 'down_blocks.1.attentions.0.transformer_blocks.0.attn2', 41 | 'input_blocks.5.1.transformer_blocks.0.attn1': 'down_blocks.1.attentions.1.transformer_blocks.0.attn1', 42 | 'input_blocks.5.1.transformer_blocks.0.attn2': 'down_blocks.1.attentions.1.transformer_blocks.0.attn2', 43 | 'input_blocks.7.1.transformer_blocks.0.attn1': 'down_blocks.2.attentions.0.transformer_blocks.0.attn1', 44 | 'input_blocks.7.1.transformer_blocks.0.attn2': 'down_blocks.2.attentions.0.transformer_blocks.0.attn2', 45 | 'input_blocks.8.1.transformer_blocks.0.attn1': 'down_blocks.2.attentions.1.transformer_blocks.0.attn1', 46 | 'input_blocks.8.1.transformer_blocks.0.attn2': 'down_blocks.2.attentions.1.transformer_blocks.0.attn2', 47 | 'output_blocks.3.1.transformer_blocks.0.attn1': "up_blocks.1.attentions.0.transformer_blocks.0.attn1", 48 | 'output_blocks.3.1.transformer_blocks.0.attn2': "up_blocks.1.attentions.0.transformer_blocks.0.attn2", 49 | 'output_blocks.4.1.transformer_blocks.0.attn1': "up_blocks.1.attentions.1.transformer_blocks.0.attn1", 50 | 'output_blocks.4.1.transformer_blocks.0.attn2': "up_blocks.1.attentions.1.transformer_blocks.0.attn2", 51 | 'output_blocks.5.1.transformer_blocks.0.attn1': "up_blocks.1.attentions.2.transformer_blocks.0.attn1", 52 | 'output_blocks.5.1.transformer_blocks.0.attn2': "up_blocks.1.attentions.2.transformer_blocks.0.attn2", 53 | 'output_blocks.6.1.transformer_blocks.0.attn1': "up_blocks.2.attentions.0.transformer_blocks.0.attn1", 54 | 'output_blocks.6.1.transformer_blocks.0.attn2': "up_blocks.2.attentions.0.transformer_blocks.0.attn2", 55 | 'output_blocks.7.1.transformer_blocks.0.attn1': "up_blocks.2.attentions.1.transformer_blocks.0.attn1", 56 | 'output_blocks.7.1.transformer_blocks.0.attn2': "up_blocks.2.attentions.1.transformer_blocks.0.attn2", 57 | 'output_blocks.8.1.transformer_blocks.0.attn1': "up_blocks.2.attentions.2.transformer_blocks.0.attn1", 58 | 'output_blocks.8.1.transformer_blocks.0.attn2': "up_blocks.2.attentions.2.transformer_blocks.0.attn2", 59 | 'output_blocks.9.1.transformer_blocks.0.attn1': "up_blocks.3.attentions.0.transformer_blocks.0.attn1", 60 | 'output_blocks.9.1.transformer_blocks.0.attn2': "up_blocks.3.attentions.0.transformer_blocks.0.attn2", 61 | 'output_blocks.10.1.transformer_blocks.0.attn1': "up_blocks.3.attentions.1.transformer_blocks.0.attn1", 62 | 'output_blocks.10.1.transformer_blocks.0.attn2': "up_blocks.3.attentions.1.transformer_blocks.0.attn2", 63 | 'output_blocks.11.1.transformer_blocks.0.attn1': "up_blocks.3.attentions.2.transformer_blocks.0.attn1", 64 | 'output_blocks.11.1.transformer_blocks.0.attn2': 
"up_blocks.3.attentions.2.transformer_blocks.0.attn2", 65 | 'middle_block.1.transformer_blocks.0.attn1': "mid_block.attentions.0.transformer_blocks.0.attn1", 66 | 'middle_block.1.transformer_blocks.0.attn2': "mid_block.attentions.0.transformer_blocks.0.attn2", 67 | } 68 | 69 | layer_list = [] 70 | for i in range(32): 71 | real_key = module_mapping_sd15[i] 72 | diffuser_key = sd15_to_diffusers[real_key] 73 | attn_module: Attention = get_attr(unet, diffuser_key) 74 | if isinstance(attn_module.processor, IPAdapterAttnProcessor2_0): 75 | u = IPAdapterAttnShareProcessor2_0(attn_module, frames=frames, use_control=use_control).to(unet.dtype) 76 | elif isinstance(attn_module.processor, IPAdapterAttnProcessor): 77 | u = IPAdapterAttnShareProcessor(attn_module, frames=frames, use_control=use_control).to(unet.dtype) 78 | elif isinstance(attn_module.processor, AttnProcessor2_0): 79 | u = AttentionSharingProcessor2_0(attn_module, frames=frames, use_control=use_control).to(unet.dtype) 80 | else: 81 | u = AttentionSharingProcessor(attn_module, frames=frames, use_control=use_control).to(unet.dtype) 82 | u = u.to(unet.device) 83 | layer_list.append(u) 84 | attn_module.set_processor(u) 85 | 86 | loader = LoraLoader(layer_list, use_control=use_control) 87 | lora_state_dict = load_file(model_path) 88 | loader.load_state_dict(lora_state_dict, strict=False) 89 | 90 | return loader.kwargs_encoder -------------------------------------------------------------------------------- /layer_diffuse/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules import * 2 | from .attention_processors import * -------------------------------------------------------------------------------- /layer_diffuse/models/attention_processors.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | import einops 5 | from typing import Optional, List 6 | 7 | from diffusers.image_processor import IPAdapterMaskProcessor 8 | from diffusers.models.attention_processor import Attention 9 | 10 | 11 | class HookerLayers(torch.nn.Module): 12 | def __init__(self, layer_list): 13 | super().__init__() 14 | self.layers = torch.nn.ModuleList(layer_list) 15 | 16 | 17 | class AdditionalAttentionCondsEncoder(torch.nn.Module): 18 | def __init__(self): 19 | super().__init__() 20 | 21 | self.blocks_0 = torch.nn.Sequential( 22 | torch.nn.Conv2d(3, 32, kernel_size=3, padding=1, stride=1), 23 | torch.nn.SiLU(), 24 | torch.nn.Conv2d(32, 32, kernel_size=3, padding=1, stride=1), 25 | torch.nn.SiLU(), 26 | torch.nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=2), 27 | torch.nn.SiLU(), 28 | torch.nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=1), 29 | torch.nn.SiLU(), 30 | torch.nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=2), 31 | torch.nn.SiLU(), 32 | torch.nn.Conv2d(128, 128, kernel_size=3, padding=1, stride=1), 33 | torch.nn.SiLU(), 34 | torch.nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2), 35 | torch.nn.SiLU(), 36 | torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1), 37 | torch.nn.SiLU(), 38 | ) # 64*64*256 39 | 40 | self.blocks_1 = torch.nn.Sequential( 41 | torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=2), 42 | torch.nn.SiLU(), 43 | torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1), 44 | torch.nn.SiLU(), 45 | ) # 32*32*256 46 | 47 | self.blocks_2 = torch.nn.Sequential( 48 | torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, 
stride=2), 49 | torch.nn.SiLU(), 50 | torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1), 51 | torch.nn.SiLU(), 52 | ) # 16*16*256 53 | 54 | self.blocks_3 = torch.nn.Sequential( 55 | torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=2), 56 | torch.nn.SiLU(), 57 | torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1), 58 | torch.nn.SiLU(), 59 | ) # 8*8*256 60 | 61 | self.blks = [self.blocks_0, self.blocks_1, self.blocks_2, self.blocks_3] 62 | 63 | def __call__(self, h): 64 | results = {} 65 | for b in self.blks: 66 | h = b(h) 67 | results[int(h.shape[2]) * int(h.shape[3])] = h 68 | return results 69 | 70 | 71 | class LoraLoader(torch.nn.Module): 72 | def __init__(self, layer_list, use_control=False): 73 | super().__init__() 74 | self.hookers = HookerLayers(layer_list) 75 | 76 | if use_control: 77 | self.kwargs_encoder = AdditionalAttentionCondsEncoder() 78 | else: 79 | self.kwargs_encoder = None 80 | 81 | 82 | class LoRALinearLayer(torch.nn.Module): 83 | def __init__(self, in_features: int, out_features: int, rank: int = 256): 84 | super().__init__() 85 | self.down = torch.nn.Linear(in_features, rank, bias=False) 86 | self.up = torch.nn.Linear(rank, out_features, bias=False) 87 | 88 | def forward(self, h, org): 89 | org_weight = org.weight.to(h) 90 | if hasattr(org, 'bias'): 91 | org_bias = org.bias.to(h) if org.bias is not None else None 92 | else: 93 | org_bias = None 94 | down_weight = self.down.weight 95 | up_weight = self.up.weight 96 | final_weight = org_weight + torch.mm(up_weight, down_weight) 97 | return torch.nn.functional.linear(h, final_weight, org_bias) 98 | 99 | 100 | class AttentionSharingProcessor(nn.Module): 101 | def __init__(self, module, frames=2, rank=256, use_control=False): 102 | super().__init__() 103 | 104 | self.heads = module.heads 105 | self.frames = frames 106 | self.original_module = [module] 107 | q_in_channels, q_out_channels = module.to_q.in_features, module.to_q.out_features 108 | k_in_channels, k_out_channels = module.to_k.in_features, module.to_k.out_features 109 | v_in_channels, v_out_channels = module.to_v.in_features, module.to_v.out_features 110 | o_in_channels, o_out_channels = module.to_out[0].in_features, module.to_out[0].out_features 111 | 112 | hidden_size = k_out_channels 113 | 114 | self.to_q_lora = [LoRALinearLayer(q_in_channels, q_out_channels, rank) for _ in range(self.frames)] 115 | self.to_k_lora = [LoRALinearLayer(k_in_channels, k_out_channels, rank) for _ in range(self.frames)] 116 | self.to_v_lora = [LoRALinearLayer(v_in_channels, v_out_channels, rank) for _ in range(self.frames)] 117 | self.to_out_lora = [LoRALinearLayer(o_in_channels, o_out_channels, rank) for _ in range(self.frames)] 118 | 119 | self.to_q_lora = torch.nn.ModuleList(self.to_q_lora) 120 | self.to_k_lora = torch.nn.ModuleList(self.to_k_lora) 121 | self.to_v_lora = torch.nn.ModuleList(self.to_v_lora) 122 | self.to_out_lora = torch.nn.ModuleList(self.to_out_lora) 123 | 124 | self.temporal_i = torch.nn.Linear(in_features=hidden_size, out_features=hidden_size) 125 | self.temporal_n = torch.nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) 126 | self.temporal_q = torch.nn.Linear(in_features=hidden_size, out_features=hidden_size) 127 | self.temporal_k = torch.nn.Linear(in_features=hidden_size, out_features=hidden_size) 128 | self.temporal_v = torch.nn.Linear(in_features=hidden_size, out_features=hidden_size) 129 | self.temporal_o = torch.nn.Linear(in_features=hidden_size, out_features=hidden_size) 130 | 131 | self.control_convs = None 
132 | 133 | if use_control: 134 | self.control_convs = [torch.nn.Sequential( 135 | torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1), 136 | torch.nn.SiLU(), 137 | torch.nn.Conv2d(256, hidden_size, kernel_size=1), 138 | ) for _ in range(self.frames)] 139 | self.control_convs = torch.nn.ModuleList(self.control_convs) 140 | 141 | def __call__( 142 | self, 143 | attn: Attention, 144 | hidden_states: torch.FloatTensor, 145 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 146 | attention_mask: Optional[torch.FloatTensor] = None, 147 | temb: Optional[torch.Tensor] = None, 148 | layerdiffuse_control_signals: Optional[torch.Tensor] = None, 149 | ) -> torch.Tensor: 150 | 151 | if attn.spatial_norm is not None: 152 | hidden_states = attn.spatial_norm(hidden_states, temb) 153 | 154 | input_ndim = hidden_states.ndim 155 | 156 | if input_ndim == 4: 157 | batch_size, channel, height, width = hidden_states.shape 158 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 159 | 160 | batch_size, sequence_length, _ = ( 161 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 162 | ) 163 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 164 | 165 | if attn.group_norm is not None: 166 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 167 | 168 | # LayerDiffuse main logic 169 | modified_hidden_states = einops.rearrange(hidden_states, '(b f) d c -> f b d c', f=self.frames) 170 | 171 | if self.control_convs is not None: 172 | context_dim = int(modified_hidden_states.shape[2]) 173 | control_outs = [] 174 | for f in range(self.frames): 175 | control_signal = layerdiffuse_control_signals[context_dim].to(modified_hidden_states) 176 | control = self.control_convs[f](control_signal) 177 | control = einops.rearrange(control, 'b c h w -> b (h w) c') 178 | control_outs.append(control) 179 | control_outs = torch.stack(control_outs, dim=0) 180 | modified_hidden_states = modified_hidden_states + control_outs.to(modified_hidden_states) 181 | 182 | if encoder_hidden_states is None: 183 | framed_context = modified_hidden_states 184 | else: 185 | if attn.norm_cross: 186 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 187 | framed_context = einops.rearrange(encoder_hidden_states, '(b f) d c -> f b d c', f=self.frames) 188 | 189 | 190 | attn_outs = [] 191 | for f in range(self.frames): 192 | fcf = framed_context[f] 193 | 194 | query = self.to_q_lora[f](modified_hidden_states[f], attn.to_q) 195 | key = self.to_k_lora[f](fcf, attn.to_k) 196 | value = self.to_v_lora[f](fcf, attn.to_v) 197 | 198 | query = attn.head_to_batch_dim(query) 199 | key = attn.head_to_batch_dim(key) 200 | value = attn.head_to_batch_dim(value) 201 | 202 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 203 | output = torch.bmm(attention_probs, value) 204 | output = attn.batch_to_head_dim(output) 205 | output = self.to_out_lora[f](output, attn.to_out[0]) 206 | output = attn.to_out[1](output) 207 | attn_outs.append(output) 208 | 209 | attn_outs = torch.stack(attn_outs, dim=0) 210 | modified_hidden_states = modified_hidden_states + attn_outs.to(modified_hidden_states) 211 | modified_hidden_states = einops.rearrange(modified_hidden_states, 'f b d c -> (b f) d c', f=self.frames) 212 | modified_hidden_states = modified_hidden_states / attn.rescale_output_factor 213 | 214 | x = modified_hidden_states 215 | x = self.temporal_n(x) 216 | x = 
self.temporal_i(x) 217 | d = x.shape[1] 218 | 219 | x = einops.rearrange(x, "(b f) d c -> (b d) f c", f=self.frames) 220 | 221 | query = self.temporal_q(x) 222 | key = self.temporal_k(x) 223 | value = self.temporal_v(x) 224 | 225 | query = attn.head_to_batch_dim(query) 226 | key = attn.head_to_batch_dim(key) 227 | value = attn.head_to_batch_dim(value) 228 | 229 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 230 | x = torch.bmm(attention_probs, value) 231 | x = attn.batch_to_head_dim(x) 232 | 233 | x = self.temporal_o(x) 234 | x = einops.rearrange(x, "(b d) f c -> (b f) d c", d=d) 235 | 236 | modified_hidden_states = modified_hidden_states + x 237 | 238 | hidden_states = modified_hidden_states - hidden_states 239 | 240 | if input_ndim == 4: 241 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 242 | 243 | return hidden_states 244 | 245 | 246 | class IPAdapterAttnShareProcessor(nn.Module): 247 | def __init__(self, module, frames=2, rank=256): 248 | super().__init__() 249 | 250 | self.heads = module.heads 251 | self.frames = frames 252 | self.original_module = [module] 253 | q_in_channels, q_out_channels = module.to_q.in_features, module.to_q.out_features 254 | k_in_channels, k_out_channels = module.to_k.in_features, module.to_k.out_features 255 | v_in_channels, v_out_channels = module.to_v.in_features, module.to_v.out_features 256 | o_in_channels, o_out_channels = module.to_out[0].in_features, module.to_out[0].out_features 257 | 258 | hidden_size = k_out_channels 259 | 260 | self.to_q_lora = [LoRALinearLayer(q_in_channels, q_out_channels, rank) for _ in range(self.frames)] 261 | self.to_k_lora = [LoRALinearLayer(k_in_channels, k_out_channels, rank) for _ in range(self.frames)] 262 | self.to_v_lora = [LoRALinearLayer(v_in_channels, v_out_channels, rank) for _ in range(self.frames)] 263 | self.to_out_lora = [LoRALinearLayer(o_in_channels, o_out_channels, rank) for _ in range(self.frames)] 264 | 265 | self.to_q_lora = torch.nn.ModuleList(self.to_q_lora) 266 | self.to_k_lora = torch.nn.ModuleList(self.to_k_lora) 267 | self.to_v_lora = torch.nn.ModuleList(self.to_v_lora) 268 | self.to_out_lora = torch.nn.ModuleList(self.to_out_lora) 269 | 270 | self.temporal_i = torch.nn.Linear(in_features=hidden_size, out_features=hidden_size) 271 | self.temporal_n = torch.nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) 272 | self.temporal_q = torch.nn.Linear(in_features=hidden_size, out_features=hidden_size) 273 | self.temporal_k = torch.nn.Linear(in_features=hidden_size, out_features=hidden_size) 274 | self.temporal_v = torch.nn.Linear(in_features=hidden_size, out_features=hidden_size) 275 | self.temporal_o = torch.nn.Linear(in_features=hidden_size, out_features=hidden_size) 276 | 277 | # IP-Adapter part 278 | self.scale = module.processor.scale 279 | self.num_tokens = module.processor.num_tokens 280 | 281 | self.to_k_ip = module.processor.to_k_ip 282 | self.to_v_ip = module.processor.to_v_ip 283 | 284 | def _fuse_ip_adapter( 285 | self, 286 | attn: Attention, 287 | batch_size: int, 288 | query: torch.Tensor, 289 | hidden_states: torch.Tensor, 290 | ip_hidden_states: Optional[torch.Tensor] = None, 291 | ip_adapter_masks: Optional[torch.Tensor] = None, 292 | ): 293 | # for ip-adapter 294 | for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( 295 | ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks 296 | ): 297 | skip = False 298 | if isinstance(scale, list): 299 | if all(s == 0 for s in 
scale): 300 | skip = True 301 | elif scale == 0: 302 | skip = True 303 | if not skip: 304 | if mask is not None: 305 | if not isinstance(scale, list): 306 | scale = [scale] * mask.shape[1] 307 | 308 | current_num_images = mask.shape[1] 309 | for i in range(current_num_images): 310 | ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) 311 | ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :]) 312 | 313 | ip_key = attn.head_to_batch_dim(ip_key) 314 | ip_value = attn.head_to_batch_dim(ip_value) 315 | 316 | ip_attention_probs = attn.get_attention_scores(query, ip_key, None) 317 | _current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) 318 | _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states) 319 | 320 | mask_downsample = IPAdapterMaskProcessor.downsample( 321 | mask[:, i, :, :], 322 | batch_size, 323 | _current_ip_hidden_states.shape[1], 324 | _current_ip_hidden_states.shape[2], 325 | ) 326 | 327 | mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) 328 | 329 | hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample) 330 | else: 331 | ip_key = to_k_ip(current_ip_hidden_states) 332 | ip_value = to_v_ip(current_ip_hidden_states) 333 | 334 | ip_key = attn.head_to_batch_dim(ip_key) 335 | ip_value = attn.head_to_batch_dim(ip_value) 336 | 337 | ip_attention_probs = attn.get_attention_scores(query, ip_key, None) 338 | current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) 339 | current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states) 340 | 341 | hidden_states = hidden_states + scale * current_ip_hidden_states 342 | 343 | return hidden_states 344 | 345 | def __call__( 346 | self, 347 | attn: Attention, 348 | hidden_states: torch.Tensor, 349 | encoder_hidden_states: Optional[torch.Tensor] = None, 350 | attention_mask: Optional[torch.Tensor] = None, 351 | temb: Optional[torch.Tensor] = None, 352 | ip_adapter_masks: Optional[torch.Tensor] = None, 353 | ) -> torch.Tensor: 354 | # separate ip_hidden_states from encoder_hidden_states 355 | if encoder_hidden_states is not None: 356 | if isinstance(encoder_hidden_states, tuple): 357 | encoder_hidden_states, ip_hidden_states = encoder_hidden_states 358 | else: 359 | deprecation_message = ( 360 | "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release." 361 | " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning." 
362 | ) 363 | print(deprecation_message) 364 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0] 365 | encoder_hidden_states, ip_hidden_states = ( 366 | encoder_hidden_states[:, :end_pos, :], 367 | [encoder_hidden_states[:, end_pos:, :]], 368 | ) 369 | 370 | if attn.spatial_norm is not None: 371 | hidden_states = attn.spatial_norm(hidden_states, temb) 372 | 373 | input_ndim = hidden_states.ndim 374 | 375 | if input_ndim == 4: 376 | batch_size, channel, height, width = hidden_states.shape 377 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 378 | 379 | batch_size, sequence_length, _ = ( 380 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 381 | ) 382 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 383 | 384 | if attn.group_norm is not None: 385 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 386 | 387 | 388 | modified_hidden_states = einops.rearrange(hidden_states, '(b f) d c -> f b d c', f=self.frames) 389 | 390 | if encoder_hidden_states is None: 391 | framed_context = modified_hidden_states 392 | else: 393 | if attn.norm_cross: 394 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 395 | framed_context = einops.rearrange(encoder_hidden_states, '(b f) d c -> f b d c', f=self.frames) 396 | 397 | 398 | if ip_adapter_masks is not None: 399 | if not isinstance(ip_adapter_masks, List): 400 | # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width] 401 | ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1)) 402 | if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)): 403 | raise ValueError( 404 | f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match " 405 | f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states " 406 | f"({len(ip_hidden_states)})" 407 | ) 408 | else: 409 | for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)): 410 | if not isinstance(mask, torch.Tensor) or mask.ndim != 4: 411 | raise ValueError( 412 | "Each element of the ip_adapter_masks array should be a tensor with shape " 413 | "[1, num_images_for_ip_adapter, height, width]." 
414 | " Please use `IPAdapterMaskProcessor` to preprocess your mask" 415 | ) 416 | if mask.shape[1] != ip_state.shape[1]: 417 | raise ValueError( 418 | f"Number of masks ({mask.shape[1]}) does not match " 419 | f"number of ip images ({ip_state.shape[1]}) at index {index}" 420 | ) 421 | if isinstance(scale, list) and not len(scale) == mask.shape[1]: 422 | raise ValueError( 423 | f"Number of masks ({mask.shape[1]}) does not match " 424 | f"number of scales ({len(scale)}) at index {index}" 425 | ) 426 | else: 427 | ip_adapter_masks = [None] * len(self.scale) 428 | 429 | 430 | attn_outs = [] 431 | for f in range(self.frames): 432 | fcf = framed_context[f] 433 | 434 | query = self.to_q_lora[f](modified_hidden_states[f], attn.to_q) 435 | key = self.to_k_lora[f](fcf, attn.to_k) 436 | value = self.to_v_lora[f](fcf, attn.to_v) 437 | 438 | query = attn.head_to_batch_dim(query) 439 | key = attn.head_to_batch_dim(key) 440 | value = attn.head_to_batch_dim(value) 441 | 442 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 443 | output = torch.bmm(attention_probs, value) 444 | output = attn.batch_to_head_dim(output) 445 | 446 | # IP-Adapter process 447 | output = self._fuse_ip_adapter( 448 | attn=attn, 449 | batch_size=batch_size, 450 | query=query, 451 | hidden_states=output, 452 | ip_hidden_states=ip_hidden_states, 453 | ip_adapter_masks=ip_adapter_masks 454 | ) 455 | 456 | output = self.to_out_lora[f](output, attn.to_out[0]) 457 | output = attn.to_out[1](output) 458 | attn_outs.append(output) 459 | 460 | attn_outs = torch.stack(attn_outs, dim=0) 461 | modified_hidden_states = modified_hidden_states + attn_outs.to(modified_hidden_states) 462 | modified_hidden_states = einops.rearrange(modified_hidden_states, 'f b d c -> (b f) d c', f=self.frames) 463 | 464 | x = modified_hidden_states 465 | x = self.temporal_n(x) 466 | x = self.temporal_i(x) 467 | d = x.shape[1] 468 | 469 | x = einops.rearrange(x, "(b f) d c -> (b d) f c", f=self.frames) 470 | 471 | query = self.temporal_q(x) 472 | key = self.temporal_k(x) 473 | value = self.temporal_v(x) 474 | 475 | query = attn.head_to_batch_dim(query) 476 | key = attn.head_to_batch_dim(key) 477 | value = attn.head_to_batch_dim(value) 478 | 479 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 480 | x = torch.bmm(attention_probs, value) 481 | x = attn.batch_to_head_dim(x) 482 | 483 | x = self.temporal_o(x) 484 | x = einops.rearrange(x, "(b d) f c -> (b f) d c", d=d) 485 | 486 | modified_hidden_states = modified_hidden_states + x 487 | 488 | hidden_states = modified_hidden_states - hidden_states 489 | 490 | if input_ndim == 4: 491 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 492 | 493 | return hidden_states 494 | 495 | 496 | class AttentionSharingProcessor2_0(nn.Module): 497 | def __init__(self, module, frames=2, rank=256, use_control=False): 498 | super().__init__() 499 | 500 | self.heads = module.heads 501 | self.frames = frames 502 | self.original_module = [module] 503 | q_in_channels, q_out_channels = module.to_q.in_features, module.to_q.out_features 504 | k_in_channels, k_out_channels = module.to_k.in_features, module.to_k.out_features 505 | v_in_channels, v_out_channels = module.to_v.in_features, module.to_v.out_features 506 | o_in_channels, o_out_channels = module.to_out[0].in_features, module.to_out[0].out_features 507 | 508 | hidden_size = k_out_channels 509 | 510 | self.to_q_lora = [LoRALinearLayer(q_in_channels, q_out_channels, rank) for _ in range(self.frames)] 
511 | self.to_k_lora = [LoRALinearLayer(k_in_channels, k_out_channels, rank) for _ in range(self.frames)] 512 | self.to_v_lora = [LoRALinearLayer(v_in_channels, v_out_channels, rank) for _ in range(self.frames)] 513 | self.to_out_lora = [LoRALinearLayer(o_in_channels, o_out_channels, rank) for _ in range(self.frames)] 514 | 515 | self.to_q_lora = torch.nn.ModuleList(self.to_q_lora) 516 | self.to_k_lora = torch.nn.ModuleList(self.to_k_lora) 517 | self.to_v_lora = torch.nn.ModuleList(self.to_v_lora) 518 | self.to_out_lora = torch.nn.ModuleList(self.to_out_lora) 519 | 520 | self.temporal_i = torch.nn.Linear(in_features=hidden_size, out_features=hidden_size) 521 | self.temporal_n = torch.nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) 522 | self.temporal_q = torch.nn.Linear(in_features=hidden_size, out_features=hidden_size) 523 | self.temporal_k = torch.nn.Linear(in_features=hidden_size, out_features=hidden_size) 524 | self.temporal_v = torch.nn.Linear(in_features=hidden_size, out_features=hidden_size) 525 | self.temporal_o = torch.nn.Linear(in_features=hidden_size, out_features=hidden_size) 526 | 527 | self.control_convs = None 528 | 529 | if use_control: 530 | self.control_convs = [torch.nn.Sequential( 531 | torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1), 532 | torch.nn.SiLU(), 533 | torch.nn.Conv2d(256, hidden_size, kernel_size=1), 534 | ) for _ in range(self.frames)] 535 | self.control_convs = torch.nn.ModuleList(self.control_convs) 536 | 537 | def __call__( 538 | self, 539 | attn: Attention, 540 | hidden_states: torch.FloatTensor, 541 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 542 | attention_mask: Optional[torch.FloatTensor] = None, 543 | temb: Optional[torch.Tensor] = None, 544 | layerdiffuse_control_signals: Optional[torch.Tensor] = None, 545 | ) -> torch.Tensor: 546 | 547 | if attn.spatial_norm is not None: 548 | hidden_states = attn.spatial_norm(hidden_states, temb) 549 | 550 | input_ndim = hidden_states.ndim 551 | 552 | if input_ndim == 4: 553 | batch_size, channel, height, width = hidden_states.shape 554 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 555 | 556 | batch_size, sequence_length, _ = ( 557 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 558 | ) 559 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 560 | 561 | if attn.group_norm is not None: 562 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 563 | 564 | # LayerDiffuse main logic 565 | modified_hidden_states = einops.rearrange(hidden_states, '(b f) d c -> f b d c', f=self.frames) 566 | 567 | if self.control_convs is not None: 568 | context_dim = int(modified_hidden_states.shape[2]) 569 | control_outs = [] 570 | for f in range(self.frames): 571 | control_signal = layerdiffuse_control_signals[context_dim].to(modified_hidden_states) 572 | control = self.control_convs[f](control_signal) 573 | control = einops.rearrange(control, 'b c h w -> b (h w) c') 574 | control_outs.append(control) 575 | control_outs = torch.stack(control_outs, dim=0) 576 | modified_hidden_states = modified_hidden_states + control_outs.to(modified_hidden_states) 577 | 578 | if encoder_hidden_states is None: 579 | framed_context = modified_hidden_states 580 | else: 581 | if attn.norm_cross: 582 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 583 | framed_context = einops.rearrange(encoder_hidden_states, '(b f) d c -> f b d 
c', f=self.frames) 584 | 585 | 586 | attn_outs = [] 587 | for f in range(self.frames): 588 | fcf = framed_context[f] 589 | frame_batch_size = fcf.size(0) 590 | 591 | query = self.to_q_lora[f](modified_hidden_states[f], attn.to_q) 592 | key = self.to_k_lora[f](fcf, attn.to_k) 593 | value = self.to_v_lora[f](fcf, attn.to_v) 594 | 595 | inner_dim = key.shape[-1] 596 | head_dim = inner_dim // attn.heads 597 | 598 | query = query.view(frame_batch_size, -1, attn.heads, head_dim).transpose(1, 2) 599 | key = key.view(frame_batch_size, -1, attn.heads, head_dim).transpose(1, 2) 600 | value = value.view(frame_batch_size, -1, attn.heads, head_dim).transpose(1, 2) 601 | 602 | output = F.scaled_dot_product_attention( 603 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 604 | ) 605 | 606 | output = output.transpose(1, 2).reshape(frame_batch_size, -1, attn.heads * head_dim) 607 | output = output.to(query.dtype) 608 | 609 | output = self.to_out_lora[f](output, attn.to_out[0]) 610 | output = attn.to_out[1](output) 611 | attn_outs.append(output) 612 | 613 | attn_outs = torch.stack(attn_outs, dim=0) 614 | modified_hidden_states = modified_hidden_states + attn_outs.to(modified_hidden_states) 615 | modified_hidden_states = einops.rearrange(modified_hidden_states, 'f b d c -> (b f) d c', f=self.frames) 616 | modified_hidden_states = modified_hidden_states / attn.rescale_output_factor 617 | 618 | x = modified_hidden_states 619 | x = self.temporal_n(x) 620 | x = self.temporal_i(x) 621 | d = x.shape[1] 622 | 623 | x = einops.rearrange(x, "(b f) d c -> (b d) f c", f=self.frames) 624 | 625 | temporal_batch_size = x.size(0) 626 | 627 | query = self.temporal_q(x) 628 | key = self.temporal_k(x) 629 | value = self.temporal_v(x) 630 | 631 | inner_dim = key.shape[-1] 632 | head_dim = inner_dim // attn.heads 633 | 634 | query = query.view(temporal_batch_size, -1, attn.heads, head_dim).transpose(1, 2) 635 | key = key.view(temporal_batch_size, -1, attn.heads, head_dim).transpose(1, 2) 636 | value = value.view(temporal_batch_size, -1, attn.heads, head_dim).transpose(1, 2) 637 | 638 | x = F.scaled_dot_product_attention( 639 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 640 | ) 641 | 642 | x = x.transpose(1, 2).reshape(temporal_batch_size, -1, attn.heads * head_dim) 643 | x = x.to(query.dtype) 644 | 645 | x = self.temporal_o(x) 646 | x = einops.rearrange(x, "(b d) f c -> (b f) d c", d=d) 647 | 648 | modified_hidden_states = modified_hidden_states + x 649 | 650 | hidden_states = modified_hidden_states - hidden_states 651 | 652 | if input_ndim == 4: 653 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 654 | 655 | return hidden_states 656 | 657 | 658 | class IPAdapterAttnShareProcessor2_0(nn.Module): 659 | def __init__(self, module, frames=2, rank=256): 660 | super().__init__() 661 | 662 | self.heads = module.heads 663 | self.frames = frames 664 | self.original_module = [module] 665 | q_in_channels, q_out_channels = module.to_q.in_features, module.to_q.out_features 666 | k_in_channels, k_out_channels = module.to_k.in_features, module.to_k.out_features 667 | v_in_channels, v_out_channels = module.to_v.in_features, module.to_v.out_features 668 | o_in_channels, o_out_channels = module.to_out[0].in_features, module.to_out[0].out_features 669 | 670 | hidden_size = k_out_channels 671 | 672 | self.to_q_lora = [LoRALinearLayer(q_in_channels, q_out_channels, rank) for _ in range(self.frames)] 673 | self.to_k_lora = 
[LoRALinearLayer(k_in_channels, k_out_channels, rank) for _ in range(self.frames)] 674 | self.to_v_lora = [LoRALinearLayer(v_in_channels, v_out_channels, rank) for _ in range(self.frames)] 675 | self.to_out_lora = [LoRALinearLayer(o_in_channels, o_out_channels, rank) for _ in range(self.frames)] 676 | 677 | self.to_q_lora = torch.nn.ModuleList(self.to_q_lora) 678 | self.to_k_lora = torch.nn.ModuleList(self.to_k_lora) 679 | self.to_v_lora = torch.nn.ModuleList(self.to_v_lora) 680 | self.to_out_lora = torch.nn.ModuleList(self.to_out_lora) 681 | 682 | self.temporal_i = torch.nn.Linear(in_features=hidden_size, out_features=hidden_size) 683 | self.temporal_n = torch.nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) 684 | self.temporal_q = torch.nn.Linear(in_features=hidden_size, out_features=hidden_size) 685 | self.temporal_k = torch.nn.Linear(in_features=hidden_size, out_features=hidden_size) 686 | self.temporal_v = torch.nn.Linear(in_features=hidden_size, out_features=hidden_size) 687 | self.temporal_o = torch.nn.Linear(in_features=hidden_size, out_features=hidden_size) 688 | 689 | # IP-Adapter part 690 | self.scale = module.processor.scale 691 | self.num_tokens = module.processor.num_tokens 692 | 693 | self.to_k_ip = module.processor.to_k_ip 694 | self.to_v_ip = module.processor.to_v_ip 695 | 696 | def _fuse_ip_adapter( 697 | self, 698 | attn: Attention, 699 | batch_size: int, 700 | head_dim: int, 701 | query: torch.Tensor, 702 | hidden_states: torch.Tensor, 703 | ip_hidden_states: Optional[torch.Tensor] = None, 704 | ip_adapter_masks: Optional[torch.Tensor] = None, 705 | ): 706 | # for ip-adapter 707 | 708 | for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( 709 | ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks 710 | ): 711 | skip = False 712 | if isinstance(scale, list): 713 | if all(s == 0 for s in scale): 714 | skip = True 715 | elif scale == 0: 716 | skip = True 717 | if not skip: 718 | if mask is not None: 719 | if not isinstance(scale, list): 720 | scale = [scale] * mask.shape[1] 721 | 722 | current_num_images = mask.shape[1] 723 | for i in range(current_num_images): 724 | ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) 725 | ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :]) 726 | 727 | ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 728 | ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 729 | 730 | _current_ip_hidden_states = F.scaled_dot_product_attention( 731 | query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False 732 | ) 733 | 734 | _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape( 735 | batch_size, -1, attn.heads * head_dim 736 | ) 737 | _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype) 738 | 739 | mask_downsample = IPAdapterMaskProcessor.downsample( 740 | mask[:, i, :, :], 741 | batch_size, 742 | _current_ip_hidden_states.shape[1], 743 | _current_ip_hidden_states.shape[2], 744 | ) 745 | 746 | mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) 747 | hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample) 748 | else: 749 | ip_key = to_k_ip(current_ip_hidden_states) 750 | ip_value = to_v_ip(current_ip_hidden_states) 751 | 752 | ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 753 | ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 754 | 755 | current_ip_hidden_states = 
F.scaled_dot_product_attention( 756 | query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False 757 | ) 758 | 759 | current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( 760 | batch_size, -1, attn.heads * head_dim 761 | ) 762 | current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) 763 | 764 | hidden_states = hidden_states + scale * current_ip_hidden_states 765 | 766 | return hidden_states 767 | 768 | def __call__( 769 | self, 770 | attn: Attention, 771 | hidden_states: torch.Tensor, 772 | encoder_hidden_states: Optional[torch.Tensor] = None, 773 | attention_mask: Optional[torch.Tensor] = None, 774 | temb: Optional[torch.Tensor] = None, 775 | ip_adapter_masks: Optional[torch.Tensor] = None, 776 | ) -> torch.Tensor: 777 | # separate ip_hidden_states from encoder_hidden_states 778 | if encoder_hidden_states is not None: 779 | if isinstance(encoder_hidden_states, tuple): 780 | encoder_hidden_states, ip_hidden_states = encoder_hidden_states 781 | else: 782 | deprecation_message = ( 783 | "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release." 784 | " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning." 785 | ) 786 | print(deprecation_message) 787 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0] 788 | encoder_hidden_states, ip_hidden_states = ( 789 | encoder_hidden_states[:, :end_pos, :], 790 | [encoder_hidden_states[:, end_pos:, :]], 791 | ) 792 | 793 | if attn.spatial_norm is not None: 794 | hidden_states = attn.spatial_norm(hidden_states, temb) 795 | 796 | input_ndim = hidden_states.ndim 797 | 798 | if input_ndim == 4: 799 | batch_size, channel, height, width = hidden_states.shape 800 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 801 | 802 | batch_size, sequence_length, _ = ( 803 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 804 | ) 805 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 806 | 807 | if attn.group_norm is not None: 808 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 809 | 810 | 811 | modified_hidden_states = einops.rearrange(hidden_states, '(b f) d c -> f b d c', f=self.frames) 812 | 813 | if encoder_hidden_states is None: 814 | framed_context = modified_hidden_states 815 | else: 816 | if attn.norm_cross: 817 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 818 | framed_context = einops.rearrange(encoder_hidden_states, '(b f) d c -> f b d c', f=self.frames) 819 | 820 | 821 | if ip_adapter_masks is not None: 822 | if not isinstance(ip_adapter_masks, List): 823 | # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width] 824 | ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1)) 825 | if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)): 826 | raise ValueError( 827 | f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match " 828 | f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states " 829 | f"({len(ip_hidden_states)})" 830 | ) 831 | else: 832 | for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)): 833 | if not isinstance(mask, torch.Tensor) or mask.ndim != 4: 834 | raise ValueError( 835 | "Each element of the ip_adapter_masks array should be a 
tensor with shape " 836 | "[1, num_images_for_ip_adapter, height, width]." 837 | " Please use `IPAdapterMaskProcessor` to preprocess your mask" 838 | ) 839 | if mask.shape[1] != ip_state.shape[1]: 840 | raise ValueError( 841 | f"Number of masks ({mask.shape[1]}) does not match " 842 | f"number of ip images ({ip_state.shape[1]}) at index {index}" 843 | ) 844 | if isinstance(scale, list) and not len(scale) == mask.shape[1]: 845 | raise ValueError( 846 | f"Number of masks ({mask.shape[1]}) does not match " 847 | f"number of scales ({len(scale)}) at index {index}" 848 | ) 849 | else: 850 | ip_adapter_masks = [None] * len(self.scale) 851 | 852 | 853 | attn_outs = [] 854 | for f in range(self.frames): 855 | fcf = framed_context[f] 856 | frame_batch_size = fcf.size(0) 857 | 858 | query = self.to_q_lora[f](modified_hidden_states[f], attn.to_q) 859 | key = self.to_k_lora[f](fcf, attn.to_k) 860 | value = self.to_v_lora[f](fcf, attn.to_v) 861 | 862 | inner_dim = key.shape[-1] 863 | head_dim = inner_dim // attn.heads 864 | 865 | query = query.view(frame_batch_size, -1, attn.heads, head_dim).transpose(1, 2) 866 | 867 | key = key.view(frame_batch_size, -1, attn.heads, head_dim).transpose(1, 2) 868 | value = value.view(frame_batch_size, -1, attn.heads, head_dim).transpose(1, 2) 869 | 870 | output = F.scaled_dot_product_attention( 871 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 872 | ) 873 | 874 | output = output.transpose(1, 2).reshape(frame_batch_size, -1, attn.heads * head_dim) 875 | output = output.to(query.dtype) 876 | 877 | # IP-Adapter process 878 | output = self._fuse_ip_adapter( 879 | attn=attn, 880 | batch_size=frame_batch_size, 881 | head_dim=head_dim, 882 | query=query, 883 | hidden_states=output, 884 | ip_hidden_states=ip_hidden_states, 885 | ip_adapter_masks=ip_adapter_masks 886 | ) 887 | 888 | output = self.to_out_lora[f](output, attn.to_out[0]) 889 | output = attn.to_out[1](output) 890 | attn_outs.append(output) 891 | 892 | attn_outs = torch.stack(attn_outs, dim=0) 893 | modified_hidden_states = modified_hidden_states + attn_outs.to(modified_hidden_states) 894 | modified_hidden_states = einops.rearrange(modified_hidden_states, 'f b d c -> (b f) d c', f=self.frames) 895 | 896 | x = modified_hidden_states 897 | x = self.temporal_n(x) 898 | x = self.temporal_i(x) 899 | d = x.shape[1] 900 | 901 | x = einops.rearrange(x, "(b f) d c -> (b d) f c", f=self.frames) 902 | 903 | temporal_batch_size = x.size(0) 904 | 905 | query = self.temporal_q(x) 906 | key = self.temporal_k(x) 907 | value = self.temporal_v(x) 908 | 909 | inner_dim = key.shape[-1] 910 | head_dim = inner_dim // attn.heads 911 | 912 | query = query.view(temporal_batch_size, -1, attn.heads, head_dim).transpose(1, 2) 913 | key = key.view(temporal_batch_size, -1, attn.heads, head_dim).transpose(1, 2) 914 | value = value.view(temporal_batch_size, -1, attn.heads, head_dim).transpose(1, 2) 915 | 916 | x = F.scaled_dot_product_attention( 917 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 918 | ) 919 | 920 | x = x.transpose(1, 2).reshape(temporal_batch_size, -1, attn.heads * head_dim) 921 | x = x.to(query.dtype) 922 | 923 | x = self.temporal_o(x) 924 | x = einops.rearrange(x, "(b d) f c -> (b f) d c", d=d) 925 | 926 | modified_hidden_states = modified_hidden_states + x 927 | 928 | hidden_states = modified_hidden_states - hidden_states 929 | 930 | if input_ndim == 4: 931 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 932 | 933 | 
return hidden_states -------------------------------------------------------------------------------- /layer_diffuse/models/modules.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import cv2 4 | import numpy as np 5 | import importlib.metadata 6 | from packaging.version import parse 7 | from tqdm import tqdm 8 | from typing import Optional, Tuple, Union 9 | from PIL import Image 10 | 11 | from diffusers import AutoencoderKL 12 | from diffusers.utils.torch_utils import randn_tensor 13 | from diffusers.configuration_utils import ConfigMixin, register_to_config 14 | from diffusers.models.modeling_utils import ModelMixin 15 | from diffusers.models.autoencoders.vae import DecoderOutput 16 | 17 | 18 | diffusers_version = importlib.metadata.version('diffusers') 19 | 20 | def check_diffusers_version(min_version="0.25.0"): 21 | assert parse(diffusers_version) >= parse( 22 | min_version 23 | ), f"diffusers>={min_version} requirement not satisfied. Please install correct diffusers version." 24 | 25 | check_diffusers_version() 26 | 27 | if parse(diffusers_version) >= parse("0.29.0"): 28 | from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block 29 | else: 30 | from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block 31 | 32 | 33 | def zero_module(module): 34 | """ 35 | Zero out the parameters of a module and return it. 36 | """ 37 | for p in module.parameters(): 38 | p.detach().zero_() 39 | return module 40 | 41 | 42 | class LatentTransparencyOffsetEncoder(torch.nn.Module): 43 | def __init__(self, *args, **kwargs): 44 | super().__init__(*args, **kwargs) 45 | self.blocks = torch.nn.Sequential( 46 | torch.nn.Conv2d(4, 32, kernel_size=3, padding=1, stride=1), 47 | nn.SiLU(), 48 | torch.nn.Conv2d(32, 32, kernel_size=3, padding=1, stride=1), 49 | nn.SiLU(), 50 | torch.nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=2), 51 | nn.SiLU(), 52 | torch.nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=1), 53 | nn.SiLU(), 54 | torch.nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=2), 55 | nn.SiLU(), 56 | torch.nn.Conv2d(128, 128, kernel_size=3, padding=1, stride=1), 57 | nn.SiLU(), 58 | torch.nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2), 59 | nn.SiLU(), 60 | torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1), 61 | nn.SiLU(), 62 | zero_module(torch.nn.Conv2d(256, 4, kernel_size=3, padding=1, stride=1)), 63 | ) 64 | 65 | def __call__(self, x): 66 | return self.blocks(x) 67 | 68 | 69 | # 1024 * 1024 * 3 -> 16 * 16 * 512 -> 1024 * 1024 * 3 70 | class UNet1024(ModelMixin, ConfigMixin): 71 | @register_to_config 72 | def __init__( 73 | self, 74 | in_channels: int = 3, 75 | out_channels: int = 3, 76 | down_block_types: Tuple[str] = ("DownBlock2D", "DownBlock2D", "DownBlock2D", "DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"), 77 | up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D"), 78 | block_out_channels: Tuple[int] = (32, 32, 64, 128, 256, 512, 512), 79 | layers_per_block: int = 2, 80 | mid_block_scale_factor: float = 1, 81 | downsample_padding: int = 1, 82 | downsample_type: str = "conv", 83 | upsample_type: str = "conv", 84 | dropout: float = 0.0, 85 | act_fn: str = "silu", 86 | attention_head_dim: Optional[int] = 8, 87 | norm_num_groups: int = 4, 88 | norm_eps: float = 1e-5, 89 | ): 90 | super().__init__() 91 | 92 | # input 93 | 
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) 94 | self.latent_conv_in = zero_module(nn.Conv2d(4, block_out_channels[2], kernel_size=1)) 95 | 96 | self.down_blocks = nn.ModuleList([]) 97 | self.mid_block = None 98 | self.up_blocks = nn.ModuleList([]) 99 | 100 | # down 101 | output_channel = block_out_channels[0] 102 | for i, down_block_type in enumerate(down_block_types): 103 | input_channel = output_channel 104 | output_channel = block_out_channels[i] 105 | is_final_block = i == len(block_out_channels) - 1 106 | 107 | down_block = get_down_block( 108 | down_block_type, 109 | num_layers=layers_per_block, 110 | in_channels=input_channel, 111 | out_channels=output_channel, 112 | temb_channels=None, 113 | add_downsample=not is_final_block, 114 | resnet_eps=norm_eps, 115 | resnet_act_fn=act_fn, 116 | resnet_groups=norm_num_groups, 117 | attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel, 118 | downsample_padding=downsample_padding, 119 | resnet_time_scale_shift="default", 120 | downsample_type=downsample_type, 121 | dropout=dropout, 122 | ) 123 | self.down_blocks.append(down_block) 124 | 125 | # mid 126 | self.mid_block = UNetMidBlock2D( 127 | in_channels=block_out_channels[-1], 128 | temb_channels=None, 129 | dropout=dropout, 130 | resnet_eps=norm_eps, 131 | resnet_act_fn=act_fn, 132 | output_scale_factor=mid_block_scale_factor, 133 | resnet_time_scale_shift="default", 134 | attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1], 135 | resnet_groups=norm_num_groups, 136 | attn_groups=None, 137 | add_attention=True, 138 | ) 139 | 140 | # up 141 | reversed_block_out_channels = list(reversed(block_out_channels)) 142 | output_channel = reversed_block_out_channels[0] 143 | for i, up_block_type in enumerate(up_block_types): 144 | prev_output_channel = output_channel 145 | output_channel = reversed_block_out_channels[i] 146 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] 147 | 148 | is_final_block = i == len(block_out_channels) - 1 149 | 150 | up_block = get_up_block( 151 | up_block_type, 152 | num_layers=layers_per_block + 1, 153 | in_channels=input_channel, 154 | out_channels=output_channel, 155 | prev_output_channel=prev_output_channel, 156 | temb_channels=None, 157 | add_upsample=not is_final_block, 158 | resnet_eps=norm_eps, 159 | resnet_act_fn=act_fn, 160 | resnet_groups=norm_num_groups, 161 | attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel, 162 | resnet_time_scale_shift="default", 163 | upsample_type=upsample_type, 164 | dropout=dropout, 165 | ) 166 | self.up_blocks.append(up_block) 167 | prev_output_channel = output_channel 168 | 169 | # out 170 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) 171 | self.conv_act = nn.SiLU() 172 | self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) 173 | 174 | def forward(self, x, latent): 175 | sample_latent = self.latent_conv_in(latent) 176 | sample = self.conv_in(x) 177 | emb = None 178 | 179 | down_block_res_samples = (sample,) 180 | for i, downsample_block in enumerate(self.down_blocks): 181 | if i == 3: 182 | sample = sample + sample_latent 183 | 184 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 185 | down_block_res_samples += res_samples 186 | 187 | sample = self.mid_block(sample, emb) 188 | 189 | for 
upsample_block in self.up_blocks: 190 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 191 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 192 | sample = upsample_block(sample, res_samples, emb) 193 | 194 | sample = self.conv_norm_out(sample) 195 | sample = self.conv_act(sample) 196 | sample = self.conv_out(sample) 197 | return sample 198 | 199 | 200 | def checkerboard(shape): 201 | return np.indices(shape).sum(axis=0) % 2 202 | 203 | 204 | class TransparentVAEDecoder(AutoencoderKL): 205 | @register_to_config 206 | def __init__( 207 | self, 208 | in_channels: int = 3, 209 | out_channels: int = 3, 210 | down_block_types: Tuple[str] = ("DownEncoderBlock2D",), 211 | up_block_types: Tuple[str] = ("UpDecoderBlock2D",), 212 | block_out_channels: Tuple[int] = (64,), 213 | layers_per_block: int = 1, 214 | act_fn: str = "silu", 215 | latent_channels: int = 4, 216 | norm_num_groups: int = 32, 217 | sample_size: int = 32, 218 | scaling_factor: float = 0.18215, 219 | latents_mean: Optional[Tuple[float]] = None, 220 | latents_std: Optional[Tuple[float]] = None, 221 | force_upcast: float = True, 222 | ): 223 | self.mod_number = None 224 | super().__init__(in_channels, out_channels, down_block_types, up_block_types, block_out_channels, layers_per_block, act_fn, latent_channels, norm_num_groups, sample_size, scaling_factor, latents_mean, latents_std, force_upcast) 225 | 226 | def set_transparent_decoder(self, sd, mod_number=1): 227 | model = UNet1024(in_channels=3, out_channels=4) 228 | model.load_state_dict(sd, strict=True) 229 | model.to(device=self.device, dtype=self.dtype) 230 | model.eval() 231 | 232 | self.transparent_decoder = model 233 | self.mod_number = mod_number 234 | 235 | def estimate_single_pass(self, pixel, latent): 236 | y = self.transparent_decoder(pixel, latent) 237 | return y 238 | 239 | def estimate_augmented(self, pixel, latent): 240 | args = [ 241 | [False, 0], [False, 1], [False, 2], [False, 3], [True, 0], [True, 1], [True, 2], [True, 3], 242 | ] 243 | 244 | result = [] 245 | 246 | for flip, rok in tqdm(args): 247 | feed_pixel = pixel.clone() 248 | feed_latent = latent.clone() 249 | 250 | if flip: 251 | feed_pixel = torch.flip(feed_pixel, dims=(3,)) 252 | feed_latent = torch.flip(feed_latent, dims=(3,)) 253 | 254 | feed_pixel = torch.rot90(feed_pixel, k=rok, dims=(2, 3)) 255 | feed_latent = torch.rot90(feed_latent, k=rok, dims=(2, 3)) 256 | 257 | eps = self.estimate_single_pass(feed_pixel, feed_latent).clip(0, 1) 258 | eps = torch.rot90(eps, k=-rok, dims=(2, 3)) 259 | 260 | if flip: 261 | eps = torch.flip(eps, dims=(3,)) 262 | 263 | result += [eps] 264 | 265 | result = torch.stack(result, dim=0) 266 | median = torch.median(result, dim=0).values 267 | return median 268 | 269 | def decode(self, z: torch.Tensor, return_dict: bool = True, generator=None) -> Union[DecoderOutput, torch.Tensor]: 270 | pixel = super().decode(z, return_dict=False, generator=generator)[0] 271 | pixel = pixel / 2 + 0.5 272 | 273 | 274 | result_pixel = [] 275 | for i in range(int(z.shape[0])): 276 | if self.mod_number is None or (self.mod_number != 1 and i % self.mod_number != 0): 277 | img = torch.cat((pixel[i:i+1], torch.ones_like(pixel[i:i+1,:1,:,:])), dim=1) 278 | result_pixel.append(img) 279 | continue 280 | 281 | y = self.estimate_augmented(pixel[i:i+1], z[i:i+1]) 282 | 283 | y = y.clip(0, 1).movedim(1, -1) 284 | alpha = y[..., :1] 285 | fg = y[..., 1:] 286 | 287 | B, H, W, C = fg.shape 288 | cb = checkerboard(shape=(H // 64, W // 64)) 289 | cb = 
cv2.resize(cb, (W, H), interpolation=cv2.INTER_NEAREST) 290 | cb = (0.5 + (cb - 0.5) * 0.1)[None, ..., None] 291 | cb = torch.from_numpy(cb).to(fg) 292 | 293 | png = torch.cat([fg, alpha], dim=3) 294 | png = png.permute(0, 3, 1, 2) 295 | result_pixel.append(png) 296 | 297 | result_pixel = torch.cat(result_pixel, dim=0) 298 | result_pixel = (result_pixel - 0.5) * 2 299 | 300 | if not return_dict: 301 | return (result_pixel, ) 302 | return DecoderOutput(sample=result_pixel) 303 | 304 | 305 | def build_alpha_pyramid(color, alpha, dk=1.2): 306 | pyramid = [] 307 | current_premultiplied_color = color * alpha 308 | current_alpha = alpha 309 | 310 | while True: 311 | pyramid.append((current_premultiplied_color, current_alpha)) 312 | 313 | H, W, C = current_alpha.shape 314 | if min(H, W) == 1: 315 | break 316 | 317 | current_premultiplied_color = cv2.resize(current_premultiplied_color, (int(W / dk), int(H / dk)), interpolation=cv2.INTER_AREA) 318 | current_alpha = cv2.resize(current_alpha, (int(W / dk), int(H / dk)), interpolation=cv2.INTER_AREA)[:, :, None] 319 | return pyramid[::-1] 320 | 321 | 322 | def pad_rgb(np_rgba_hwc_uint8): 323 | np_rgba_hwc = np_rgba_hwc_uint8.astype(np.float32) / 255.0 324 | pyramid = build_alpha_pyramid(color=np_rgba_hwc[..., :3], alpha=np_rgba_hwc[..., 3:]) 325 | 326 | top_c, top_a = pyramid[0] 327 | fg = np.sum(top_c, axis=(0, 1), keepdims=True) / np.sum(top_a, axis=(0, 1), keepdims=True).clip(1e-8, 1e32) 328 | 329 | for layer_c, layer_a in pyramid: 330 | layer_h, layer_w, _ = layer_c.shape 331 | fg = cv2.resize(fg, (layer_w, layer_h), interpolation=cv2.INTER_LINEAR) 332 | fg = layer_c + fg * (1.0 - layer_a) 333 | 334 | return fg 335 | 336 | 337 | def convert_rgba2rgb(img): 338 | background = Image.new("RGB", img.size, (127, 127, 127)) 339 | background.paste(img, mask=img.split()[3]) 340 | return background 341 | 342 | 343 | class TransparentVAEEncoder: 344 | def __init__(self, sd, device="cpu", torch_dtype=torch.float32): 345 | self.load_device = device 346 | self.dtype = torch_dtype 347 | 348 | model = LatentTransparencyOffsetEncoder() 349 | model.load_state_dict(sd, strict=True) 350 | model.to(device=self.load_device, dtype=self.dtype) 351 | model.eval() 352 | 353 | self.model = model 354 | 355 | @torch.no_grad() 356 | def _encode(self, image): 357 | list_of_np_rgba_hwc_uint8 = [np.array(image)] 358 | list_of_np_rgb_padded = [pad_rgb(x) for x in list_of_np_rgba_hwc_uint8] 359 | rgb_padded_bchw_01 = torch.from_numpy(np.stack(list_of_np_rgb_padded, axis=0)).float().movedim(-1, 1) 360 | rgba_bchw_01 = torch.from_numpy(np.stack(list_of_np_rgba_hwc_uint8, axis=0)).float().movedim(-1, 1) / 255.0 361 | a_bchw_01 = rgba_bchw_01[:, 3:, :, :] 362 | offset_feed = torch.cat([a_bchw_01, rgb_padded_bchw_01], dim=1).to(device=self.load_device, dtype=self.dtype) 363 | offset = self.model(offset_feed) 364 | return offset 365 | 366 | def encode(self, image, pipeline, mask=None): 367 | latent_offset = self._encode(image) 368 | 369 | init_image = convert_rgba2rgb(image) 370 | 371 | init_image = pipeline.image_processor.preprocess(init_image) 372 | init_image = init_image.to(device=pipeline.vae.device, dtype=pipeline.vae.dtype) 373 | latents = pipeline.vae.encode(init_image).latent_dist 374 | latents = latents.mean + latents.std * latent_offset.to(latents.mean) 375 | latents = pipeline.vae.config.scaling_factor * latents 376 | 377 | if mask is not None: 378 | mask = pipeline.mask_processor.preprocess(mask) 379 | mask = mask.to(device=pipeline.vae.device, dtype=pipeline.vae.dtype) 380 
| masked_image = init_image * (mask < 0.5) 381 | masked_image_latents = pipeline.vae.encode(masked_image).latent_dist 382 | masked_image_latents = masked_image_latents.mean + masked_image_latents.std * latent_offset.to(masked_image_latents.mean) 383 | masked_image_latents = pipeline.vae.config.scaling_factor * masked_image_latents 384 | 385 | return latents, masked_image_latents 386 | 387 | return latents -------------------------------------------------------------------------------- /layer_diffuse/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import torch 4 | 5 | def rgba2rgbfp32(x): 6 | rgb = x[..., :3].astype(np.float32) / 255.0 7 | a = x[..., 3:4].astype(np.float32) / 255.0 8 | return 0.5 + (rgb - 0.5) * a 9 | 10 | 11 | def to255unit8(x): 12 | return (x * 255.0).clip(0, 255).astype(np.uint8) 13 | 14 | 15 | def safe_numpy(x): 16 | # A very safe method to make sure that Apple/Mac works 17 | y = x 18 | 19 | # below is very boring but do not change these. If you change these Apple or Mac may fail. 20 | y = y.copy() 21 | y = np.ascontiguousarray(y) 22 | y = y.copy() 23 | return y 24 | 25 | 26 | def high_quality_resize(x, size): 27 | if x.shape[0] != size[1] or x.shape[1] != size[0]: 28 | if (size[0] * size[1]) < (x.shape[0] * x.shape[1]): 29 | interpolation = cv2.INTER_AREA 30 | else: 31 | interpolation = cv2.INTER_LANCZOS4 32 | 33 | y = cv2.resize(x, size, interpolation=interpolation) 34 | else: 35 | y = x 36 | return y 37 | 38 | 39 | def crop_and_resize_image(detected_map, resize_mode, h, w): 40 | if resize_mode == 0: 41 | detected_map = high_quality_resize(detected_map, (w, h)) 42 | detected_map = safe_numpy(detected_map) 43 | return detected_map 44 | 45 | old_h, old_w, _ = detected_map.shape 46 | old_w = float(old_w) 47 | old_h = float(old_h) 48 | k0 = float(h) / old_h 49 | k1 = float(w) / old_w 50 | 51 | safeint = lambda x: int(np.round(x)) 52 | 53 | if resize_mode == 2: 54 | k = min(k0, k1) 55 | borders = np.concatenate([detected_map[0, :, :], detected_map[-1, :, :], detected_map[:, 0, :], detected_map[:, -1, :]], axis=0) 56 | high_quality_border_color = np.median(borders, axis=0).astype(detected_map.dtype) 57 | high_quality_background = np.tile(high_quality_border_color[None, None], [h, w, 1]) 58 | detected_map = high_quality_resize(detected_map, (safeint(old_w * k), safeint(old_h * k))) 59 | new_h, new_w, _ = detected_map.shape 60 | pad_h = max(0, (h - new_h) // 2) 61 | pad_w = max(0, (w - new_w) // 2) 62 | high_quality_background[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = detected_map 63 | detected_map = high_quality_background 64 | detected_map = safe_numpy(detected_map) 65 | return detected_map 66 | else: 67 | k = max(k0, k1) 68 | detected_map = high_quality_resize(detected_map, (safeint(old_w * k), safeint(old_h * k))) 69 | new_h, new_w, _ = detected_map.shape 70 | pad_h = max(0, (new_h - h) // 2) 71 | pad_w = max(0, (new_w - w) // 2) 72 | detected_map = detected_map[pad_h:pad_h+h, pad_w:pad_w+w] 73 | detected_map = safe_numpy(detected_map) 74 | return detected_map 75 | 76 | 77 | def pytorch_to_numpy(x): 78 | return [np.clip(255. 
* y.cpu().numpy(), 0, 255).astype(np.uint8) for y in x] 79 | 80 | 81 | def numpy_to_pytorch(x): 82 | y = x.astype(np.float32) / 255.0 83 | y = y[None] 84 | y = np.ascontiguousarray(y.copy()) 85 | y = torch.from_numpy(y).float() 86 | return y -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers 2 | torch 3 | safetensors 4 | huggingface_hub 5 | einops 6 | opencv-python 7 | tqdm -------------------------------------------------------------------------------- /test_diffusers_bg_fg_cond.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from huggingface_hub import hf_hub_download 4 | import torch 5 | import numpy as np 6 | 7 | from diffusers import StableDiffusionPipeline 8 | 9 | from layer_diffuse.models import TransparentVAEDecoder 10 | from layer_diffuse.loaders import load_lora_to_unet 11 | from layer_diffuse.utils import rgba2rgbfp32, crop_and_resize_image 12 | 13 | 14 | 15 | if __name__ == "__main__": 16 | 17 | model_path = hf_hub_download( 18 | 'LayerDiffusion/layerdiffusion-v1', 19 | 'layer_sd15_vae_transparent_decoder.safetensors', 20 | ) 21 | 22 | vae_transparent_decoder = TransparentVAEDecoder.from_pretrained("digiplay/Juggernaut_final", subfolder="vae", torch_dtype=torch.float16).to("cuda") 23 | 24 | pipeline = StableDiffusionPipeline.from_pretrained("digiplay/Juggernaut_final", vae=vae_transparent_decoder, torch_dtype=torch.float16, safety_checker=None).to("cuda") 25 | 26 | model_path = hf_hub_download( 27 | 'LayerDiffusion/layerdiffusion-v1', 28 | 'layer_sd15_fg2bg.safetensors' 29 | ) 30 | 31 | kwargs_encoder = load_lora_to_unet(pipeline.unet, model_path, frames=2, use_control=True) 32 | 33 | fg_image = np.array(Image.open(os.path.join("assets", "fg_cond.png"))) 34 | fg_image = crop_and_resize_image(rgba2rgbfp32(fg_image), 1, 512, 512) 35 | fg_image = torch.from_numpy(np.ascontiguousarray(fg_image[None].copy())).movedim(-1, 1) 36 | fg_image = fg_image.cpu().float() * 2.0 - 1.0 37 | fg_signal = kwargs_encoder(fg_image) 38 | 39 | 40 | 41 | image = pipeline(prompt="in room, high quality, 4K", 42 | width=512, height=512, 43 | cross_attention_kwargs={"layerdiffuse_control_signals": fg_signal}, 44 | num_images_per_prompt=2, return_dict=False)[0] 45 | 46 | 47 | image[0].save("fg_result.png") 48 | image[1].save("fg_result1.png") 49 | 50 | -------------------------------------------------------------------------------- /test_diffusers_fg_bg_cond.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from huggingface_hub import hf_hub_download 4 | from safetensors.torch import load_file 5 | import torch 6 | import numpy as np 7 | 8 | from diffusers import StableDiffusionPipeline 9 | 10 | from layer_diffuse.models import TransparentVAEDecoder 11 | from layer_diffuse.loaders import load_lora_to_unet 12 | from layer_diffuse.utils import rgba2rgbfp32, crop_and_resize_image 13 | 14 | 15 | 16 | if __name__ == "__main__": 17 | 18 | model_path = hf_hub_download( 19 | 'LayerDiffusion/layerdiffusion-v1', 20 | 'layer_sd15_vae_transparent_decoder.safetensors', 21 | ) 22 | 23 | vae_transparent_decoder = TransparentVAEDecoder.from_pretrained("digiplay/Juggernaut_final", subfolder="vae", torch_dtype=torch.float16).to("cuda") 24 | vae_transparent_decoder.set_transparent_decoder(load_file(model_path), mod_number=2) 25 
| 26 | pipeline = StableDiffusionPipeline.from_pretrained("digiplay/Juggernaut_final", vae=vae_transparent_decoder, torch_dtype=torch.float16, safety_checker=None).to("cuda") 27 | 28 | model_path = hf_hub_download( 29 | 'LayerDiffusion/layerdiffusion-v1', 30 | 'layer_sd15_bg2fg.safetensors' 31 | ) 32 | 33 | kwargs_encoder = load_lora_to_unet(pipeline.unet, model_path, frames=2, use_control=True) 34 | 35 | bg_image = np.array(Image.open(os.path.join("assets", "bg_cond.png"))) 36 | bg_image = crop_and_resize_image(rgba2rgbfp32(bg_image), 1, 512, 512) 37 | bg_image = torch.from_numpy(np.ascontiguousarray(bg_image[None].copy())).movedim(-1, 1) 38 | bg_image = bg_image.cpu().float() * 2.0 - 1.0 39 | bg_signal = kwargs_encoder(bg_image) 40 | 41 | 42 | 43 | image = pipeline(prompt="a dog sitting in room, high quality", 44 | width=512, height=512, 45 | cross_attention_kwargs={"layerdiffuse_control_signals": bg_signal}, 46 | num_images_per_prompt=2, return_dict=False)[0] 47 | 48 | 49 | image[0].save("result.png") 50 | image[1].save("result1.png") 51 | 52 | -------------------------------------------------------------------------------- /test_diffusers_fg_only.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import hf_hub_download 2 | from safetensors.torch import load_file 3 | import torch 4 | 5 | from diffusers import StableDiffusionPipeline 6 | 7 | from layer_diffuse.models import TransparentVAEDecoder 8 | from layer_diffuse.loaders import load_lora_to_unet 9 | 10 | 11 | 12 | if __name__ == "__main__": 13 | 14 | model_path = hf_hub_download( 15 | 'LayerDiffusion/layerdiffusion-v1', 16 | 'layer_sd15_vae_transparent_decoder.safetensors', 17 | ) 18 | 19 | vae_transparent_decoder = TransparentVAEDecoder.from_pretrained("digiplay/Juggernaut_final", subfolder="vae", torch_dtype=torch.float16).to("cuda") 20 | vae_transparent_decoder.set_transparent_decoder(load_file(model_path)) 21 | 22 | pipeline = StableDiffusionPipeline.from_pretrained("digiplay/Juggernaut_final", vae=vae_transparent_decoder, torch_dtype=torch.float16, safety_checker=None).to("cuda") 23 | 24 | model_path = hf_hub_download( 25 | 'LayerDiffusion/layerdiffusion-v1', 26 | 'layer_sd15_transparent_attn.safetensors' 27 | ) 28 | 29 | load_lora_to_unet(pipeline.unet, model_path, frames=1) 30 | 31 | image = pipeline(prompt="a dog sitting in room, high quality", 32 | width=512, height=512, 33 | num_images_per_prompt=3, return_dict=False)[0] 34 | 35 | 36 | image[0].save("result.png") 37 | image[1].save("result1.png") 38 | image[2].save("result2.png") 39 | 40 | -------------------------------------------------------------------------------- /test_diffusers_fg_only_conv_sdxl.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import hf_hub_download 2 | from safetensors.torch import load_file 3 | import torch 4 | 5 | from diffusers import StableDiffusionXLPipeline 6 | 7 | from layer_diffuse.models import TransparentVAEDecoder 8 | from layer_diffuse.loaders import merge_delta_weights_into_unet 9 | 10 | 11 | if __name__ == "__main__": 12 | 13 | transparent_vae = TransparentVAEDecoder.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) 14 | transparent_vae.config.force_upcast = False 15 | model_path = hf_hub_download( 16 | 'LayerDiffusion/layerdiffusion-v1', 17 | 'vae_transparent_decoder.safetensors', 18 | ) 19 | transparent_vae.set_transparent_decoder(load_file(model_path)) 20 | 21 | pipeline = 
StableDiffusionXLPipeline.from_pretrained( 22 | "stabilityai/stable-diffusion-xl-base-1.0", 23 | vae=transparent_vae, 24 | torch_dtype=torch.float16, variant="fp16", use_safetensors=True, add_watermarker=False 25 | ).to("cuda") 26 | model_path = hf_hub_download( 27 | 'rootonchair/diffuser_layerdiffuse', 28 | 'diffuser_layer_xl_transparent_conv.safetensors' 29 | ) 30 | diff_state_dict = load_file(model_path) 31 | merge_delta_weights_into_unet(pipeline, diff_state_dict) 32 | 33 | 34 | seed = torch.randint(high=1000000, size=(1,)).item() 35 | prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" 36 | negative_prompt = "bad quality, distorted" 37 | images = pipeline(prompt=prompt, 38 | negative_prompt=negative_prompt, 39 | generator=torch.Generator(device='cuda').manual_seed(seed), 40 | num_images_per_prompt=1, return_dict=False)[0] 41 | 42 | images[0].save("result_conv_sdxl.png") 43 | -------------------------------------------------------------------------------- /test_diffusers_fg_only_sdxl.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import hf_hub_download 2 | from safetensors.torch import load_file 3 | import torch 4 | 5 | from diffusers import StableDiffusionXLPipeline 6 | 7 | from layer_diffuse.models import TransparentVAEDecoder 8 | 9 | if __name__ == "__main__": 10 | 11 | transparent_vae = TransparentVAEDecoder.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) 12 | transparent_vae.config.force_upcast = False 13 | model_path = hf_hub_download( 14 | 'LayerDiffusion/layerdiffusion-v1', 15 | 'vae_transparent_decoder.safetensors', 16 | ) 17 | transparent_vae.set_transparent_decoder(load_file(model_path)) 18 | 19 | pipeline = StableDiffusionXLPipeline.from_pretrained( 20 | "stabilityai/stable-diffusion-xl-base-1.0", 21 | vae=transparent_vae, 22 | torch_dtype=torch.float16, variant="fp16", use_safetensors=True, add_watermarker=False 23 | ).to("cuda") 24 | pipeline.load_lora_weights('rootonchair/diffuser_layerdiffuse', weight_name='diffuser_layer_xl_transparent_attn.safetensors') 25 | 26 | seed = torch.randint(high=1000000, size=(1,)).item() 27 | prompt = "a cute corgi" 28 | negative_prompt = "" 29 | images = pipeline(prompt=prompt, 30 | negative_prompt=negative_prompt, 31 | generator=torch.Generator(device='cuda').manual_seed(seed), 32 | num_images_per_prompt=1, return_dict=False)[0] 33 | 34 | images[0].save("result_sdxl.png") 35 | 36 | -------------------------------------------------------------------------------- /test_diffusers_fg_only_sdxl_img2img.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import hf_hub_download 2 | from safetensors.torch import load_file 3 | import torch 4 | from PIL import Image 5 | 6 | from diffusers import StableDiffusionXLInpaintPipeline 7 | 8 | from layer_diffuse.models import TransparentVAEDecoder, TransparentVAEEncoder 9 | 10 | 11 | if __name__ == "__main__": 12 | model_path = hf_hub_download( 13 | 'LayerDiffusion/layerdiffusion-v1', 14 | 'vae_transparent_encoder.safetensors' 15 | ) 16 | 17 | vae_transparent_encoder = TransparentVAEEncoder(load_file(model_path)) 18 | 19 | transparent_vae = TransparentVAEDecoder.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to("cuda") 20 | transparent_vae.config.force_upcast = False 21 | model_path = hf_hub_download( 22 | 'LayerDiffusion/layerdiffusion-v1', 23 | 'vae_transparent_decoder.safetensors', 24 | ) 25 | 
transparent_vae.set_transparent_decoder(load_file(model_path)) 26 | 27 | pipeline = StableDiffusionXLInpaintPipeline.from_pretrained( 28 | "stabilityai/stable-diffusion-xl-base-1.0", 29 | vae=transparent_vae, 30 | torch_dtype=torch.float16, variant="fp16", use_safetensors=True, add_watermarker=False 31 | ).to("cuda") 32 | pipeline.load_lora_weights('rootonchair/diffuser_layerdiffuse', weight_name='diffuser_layer_xl_transparent_attn.safetensors') 33 | 34 | init_image = Image.open("assets/man_crop.png").resize((1024, 1024)) 35 | mask_image = Image.open("assets/man_mask.png") 36 | 37 | latents, masked_image_latents = vae_transparent_encoder.encode(init_image, pipeline, mask=mask_image) 38 | 39 | 40 | seed = 42 41 | prompt = "a handsome man" 42 | negative_prompt = "bad, ugly" 43 | images = pipeline(prompt=prompt, 44 | negative_prompt=negative_prompt, 45 | image=latents, 46 | masked_image_latents=masked_image_latents, 47 | strength=1.0, 48 | mask_image=mask_image, 49 | generator=torch.Generator(device='cuda').manual_seed(seed), 50 | num_images_per_prompt=1, return_dict=False)[0] 51 | 52 | images[0].save("result_inpaint_sdxl.png") 53 | 54 | -------------------------------------------------------------------------------- /test_diffusers_joint.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import hf_hub_download 2 | from safetensors.torch import load_file 3 | import torch 4 | 5 | from diffusers import StableDiffusionPipeline 6 | 7 | from layer_diffuse.models import TransparentVAEDecoder 8 | from layer_diffuse.loaders import load_lora_to_unet 9 | 10 | 11 | 12 | if __name__ == "__main__": 13 | model_path = hf_hub_download( 14 | 'LayerDiffusion/layerdiffusion-v1', 15 | 'layer_sd15_vae_transparent_decoder.safetensors', 16 | ) 17 | 18 | vae_transparent_decoder = TransparentVAEDecoder.from_pretrained("digiplay/Juggernaut_final", subfolder="vae", torch_dtype=torch.float16).to("cuda") 19 | vae_transparent_decoder.set_transparent_decoder(load_file(model_path), mod_number=3) 20 | pipeline = StableDiffusionPipeline.from_pretrained("digiplay/Juggernaut_final", vae=vae_transparent_decoder, safety_checker=None, torch_dtype=torch.float16).to("cuda") 21 | 22 | model_path = hf_hub_download( 23 | 'LayerDiffusion/layerdiffusion-v1', 24 | 'layer_sd15_joint.safetensors' 25 | ) 26 | 27 | load_lora_to_unet(pipeline.unet, model_path, frames=3) 28 | 29 | image = pipeline(prompt="a dog sitting in room, high quality", width=512, height=512, num_images_per_prompt=3, return_dict=False)[0] 30 | 31 | 32 | image[0].save("result_joint_0.png") 33 | image[1].save("result_joint_1.png") 34 | image[2].save("result_joint_2.png") 35 | 36 | --------------------------------------------------------------------------------
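The test scripts above all save their results as RGBA PNGs. When eyeballing those outputs it is easy to mistake a fully opaque image for a transparent one, so a quick preview that composites the result over a soft checkerboard (the same idea as the `checkerboard` helper in `layer_diffuse/models/modules.py`) can be handy. The snippet below is a minimal sketch and not part of the repository; the file names `preview_alpha.py`, `result.png`, and `result_preview.png` are assumptions, and it relies only on NumPy and Pillow.

```python
# preview_alpha.py -- hypothetical helper, not part of this repository.
# Composites an RGBA result over a soft checkerboard so transparent regions
# are easy to spot when previewing outputs of the test_diffusers_* scripts.
import numpy as np
from PIL import Image


def checkerboard(h, w, cell=64):
    # Same idea as layer_diffuse.models.modules.checkerboard: a 0/1 grid,
    # expanded to `cell`-sized squares and softened to ~0.45/0.55 greys.
    grid = np.indices((h // cell + 1, w // cell + 1)).sum(axis=0) % 2
    grid = np.kron(grid, np.ones((cell, cell)))[:h, :w]
    return (0.5 + (grid - 0.5) * 0.1)[..., None]


def composite_over_checkerboard(rgba: Image.Image) -> Image.Image:
    x = np.asarray(rgba.convert("RGBA")).astype(np.float32) / 255.0
    rgb, alpha = x[..., :3], x[..., 3:]
    bg = np.repeat(checkerboard(*x.shape[:2]), 3, axis=-1)
    out = rgb * alpha + bg * (1.0 - alpha)
    return Image.fromarray((out * 255.0).clip(0, 255).astype(np.uint8))


if __name__ == "__main__":
    # result.png is assumed to come from e.g. test_diffusers_fg_only.py.
    composite_over_checkerboard(Image.open("result.png")).save("result_preview.png")
```

Run it after any of the `test_diffusers_*` scripts: transparent regions show up as the grey checker pattern, while opaque regions hide it completely.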