├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── configs ├── stable_diffusion │ └── tokenizer │ │ ├── merges.txt │ │ ├── special_tokens_map.json │ │ ├── tokenizer_config.json │ │ └── vocab.json └── stable_diffusion_xl │ └── tokenizer_2 │ ├── merges.txt │ ├── special_tokens_map.json │ ├── tokenizer_config.json │ └── vocab.json ├── diffsynth ├── __init__.py ├── controlnets │ ├── __init__.py │ ├── controlnet_unit.py │ └── processors.py ├── data │ ├── __init__.py │ └── video.py ├── extensions │ ├── ESRGAN │ │ └── __init__.py │ ├── FastBlend │ │ ├── __init__.py │ │ ├── api.py │ │ ├── cupy_kernels.py │ │ ├── data.py │ │ ├── patch_match.py │ │ └── runners │ │ │ ├── __init__.py │ │ │ ├── accurate.py │ │ │ ├── balanced.py │ │ │ ├── fast.py │ │ │ └── interpolation.py │ └── RIFE │ │ └── __init__.py ├── models │ ├── __init__.py │ ├── attention.py │ ├── hunyuan_dit.py │ ├── hunyuan_dit_text_encoder.py │ ├── sd_controlnet.py │ ├── sd_ipadapter.py │ ├── sd_lora.py │ ├── sd_motion.py │ ├── sd_text_encoder.py │ ├── sd_unet.py │ ├── sd_vae_decoder.py │ ├── sd_vae_encoder.py │ ├── sdxl_ipadapter.py │ ├── sdxl_motion.py │ ├── sdxl_text_encoder.py │ ├── sdxl_unet.py │ ├── sdxl_vae_decoder.py │ ├── sdxl_vae_encoder.py │ ├── svd_image_encoder.py │ ├── svd_unet.py │ ├── svd_vae_decoder.py │ ├── svd_vae_encoder.py │ └── tiler.py ├── pipelines │ ├── __init__.py │ ├── dancer.py │ ├── hunyuan_dit.py │ ├── stable_diffusion.py │ ├── stable_diffusion_video.py │ ├── stable_diffusion_xl.py │ ├── stable_diffusion_xl_video.py │ └── stable_video_diffusion.py ├── processors │ ├── FastBlend.py │ ├── PILEditor.py │ ├── RIFE.py │ ├── __init__.py │ ├── base.py │ └── sequencial_processor.py ├── prompts │ ├── __init__.py │ ├── hunyuan_dit_prompter.py │ ├── sd_prompter.py │ ├── sdxl_prompter.py │ └── utils.py └── schedulers │ ├── __init__.py │ ├── continuous_ode.py │ └── ddim.py ├── examples ├── exvideo_example_workflow_01.json └── stone_elemental_example.png └── nodes.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | *pyc 3 | .vscode 4 | __pycache__ 5 | *.egg-info 6 | *.bak 7 | checkpoints 8 | results 9 | backup -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [2023] [Zhongjie Duan] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI DiffSynth wrapper nodes 2 | 3 | # WORK IN PROGRESS 4 | 5 | Currently only the new extended SVD model "ExVideo" is supported. 6 | 7 | 8 | ## Original repo: 9 | https://github.com/modelscope/DiffSynth-Studio 10 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 2 | 3 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] -------------------------------------------------------------------------------- /configs/stable_diffusion/tokenizer/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "bos_token": { 3 | "content": "<|startoftext|>", 4 | "lstrip": false, 5 | "normalized": true, 6 | "rstrip": false, 7 | "single_word": false 8 | }, 9 | "eos_token": { 10 | "content": "<|endoftext|>", 11 | "lstrip": false, 12 | "normalized": true, 13 | "rstrip": false, 14 | "single_word": false 15 | }, 16 | "pad_token": "<|endoftext|>", 17 | "unk_token": { 18 | "content": "<|endoftext|>", 19 | "lstrip": false, 20 | "normalized": true, 21 | "rstrip": false, 22 | "single_word": false 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /configs/stable_diffusion/tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_prefix_space": false, 3 | "bos_token": { 4 | "__type": "AddedToken", 5 | "content": "<|startoftext|>", 6 | "lstrip": false, 7 | "normalized": true, 8 | "rstrip": false, 9 | "single_word": false 10 | }, 11 | "do_lower_case": true, 12 | "eos_token": { 13 | "__type": "AddedToken", 14 | "content": "<|endoftext|>", 15 | "lstrip": false, 16 | "normalized": true, 17 | "rstrip": false, 18 | "single_word": false 19 | }, 20 | "errors": "replace", 21 | "model_max_length": 77, 22 | "name_or_path": "openai/clip-vit-large-patch14", 23 | "pad_token": "<|endoftext|>", 24 | "special_tokens_map_file": "./special_tokens_map.json", 25 | "tokenizer_class": "CLIPTokenizer", 26 | "unk_token": { 27 | "__type": "AddedToken", 28 | "content": "<|endoftext|>", 29 | "lstrip": false, 30 | "normalized": true, 31 | "rstrip": false, 32 | "single_word": false 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "bos_token": { 3 | "content": "<|startoftext|>", 4 | "lstrip": false, 5 | "normalized": true, 6 | "rstrip": false, 7 | "single_word": false 8 | }, 9 | "eos_token": { 10 | "content": "<|endoftext|>", 11 | "lstrip": false, 12 | "normalized": true, 13 | "rstrip": false, 14 | "single_word": false 15 | }, 16 | "pad_token": 
"!", 17 | "unk_token": { 18 | "content": "<|endoftext|>", 19 | "lstrip": false, 20 | "normalized": true, 21 | "rstrip": false, 22 | "single_word": false 23 | } 24 | } -------------------------------------------------------------------------------- /configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_prefix_space": false, 3 | "added_tokens_decoder": { 4 | "0": { 5 | "content": "!", 6 | "lstrip": false, 7 | "normalized": false, 8 | "rstrip": false, 9 | "single_word": false, 10 | "special": true 11 | }, 12 | "49406": { 13 | "content": "<|startoftext|>", 14 | "lstrip": false, 15 | "normalized": true, 16 | "rstrip": false, 17 | "single_word": false, 18 | "special": true 19 | }, 20 | "49407": { 21 | "content": "<|endoftext|>", 22 | "lstrip": false, 23 | "normalized": true, 24 | "rstrip": false, 25 | "single_word": false, 26 | "special": true 27 | } 28 | }, 29 | "bos_token": "<|startoftext|>", 30 | "clean_up_tokenization_spaces": true, 31 | "do_lower_case": true, 32 | "eos_token": "<|endoftext|>", 33 | "errors": "replace", 34 | "model_max_length": 77, 35 | "pad_token": "!", 36 | "tokenizer_class": "CLIPTokenizer", 37 | "unk_token": "<|endoftext|>" 38 | } -------------------------------------------------------------------------------- /diffsynth/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import * 2 | from .models import * 3 | from .prompts import * 4 | from .schedulers import * 5 | from .pipelines import * 6 | from .controlnets import * 7 | -------------------------------------------------------------------------------- /diffsynth/controlnets/__init__.py: -------------------------------------------------------------------------------- 1 | from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager 2 | from .processors import Annotator 3 | -------------------------------------------------------------------------------- /diffsynth/controlnets/controlnet_unit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from .processors import Processor_id 4 | 5 | 6 | class ControlNetConfigUnit: 7 | def __init__(self, processor_id: Processor_id, model_path, scale=1.0): 8 | self.processor_id = processor_id 9 | self.model_path = model_path 10 | self.scale = scale 11 | 12 | 13 | class ControlNetUnit: 14 | def __init__(self, processor, model, scale=1.0): 15 | self.processor = processor 16 | self.model = model 17 | self.scale = scale 18 | 19 | 20 | class MultiControlNetManager: 21 | def __init__(self, controlnet_units=[]): 22 | self.processors = [unit.processor for unit in controlnet_units] 23 | self.models = [unit.model for unit in controlnet_units] 24 | self.scales = [unit.scale for unit in controlnet_units] 25 | 26 | def process_image(self, image, processor_id=None): 27 | if processor_id is None: 28 | processed_image = [processor(image) for processor in self.processors] 29 | else: 30 | processed_image = [self.processors[processor_id](image)] 31 | processed_image = torch.concat([ 32 | torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0) 33 | for image_ in processed_image 34 | ], dim=0) 35 | return processed_image 36 | 37 | def __call__( 38 | self, 39 | sample, timestep, encoder_hidden_states, conditionings, 40 | tiled=False, tile_size=64, tile_stride=32 41 | ): 42 | res_stack = None 43 | for conditioning, 
model, scale in zip(conditionings, self.models, self.scales): 44 | res_stack_ = model( 45 | sample, timestep, encoder_hidden_states, conditioning, 46 | tiled=tiled, tile_size=tile_size, tile_stride=tile_stride 47 | ) 48 | res_stack_ = [res * scale for res in res_stack_] 49 | if res_stack is None: 50 | res_stack = res_stack_ 51 | else: 52 | res_stack = [i + j for i, j in zip(res_stack, res_stack_)] 53 | return res_stack 54 | -------------------------------------------------------------------------------- /diffsynth/controlnets/processors.py: -------------------------------------------------------------------------------- 1 | from typing_extensions import Literal, TypeAlias 2 | import warnings 3 | # with warnings.catch_warnings(): 4 | # warnings.simplefilter("ignore") 5 | # from controlnet_aux.processor import ( 6 | # CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector 7 | # ) 8 | 9 | 10 | Processor_id: TypeAlias = Literal[ 11 | "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile" 12 | ] 13 | 14 | class Annotator: 15 | def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None): 16 | if processor_id == "canny": 17 | self.processor = CannyDetector() 18 | elif processor_id == "depth": 19 | self.processor = MidasDetector.from_pretrained(model_path).to("cuda") 20 | elif processor_id == "softedge": 21 | self.processor = HEDdetector.from_pretrained(model_path).to("cuda") 22 | elif processor_id == "lineart": 23 | self.processor = LineartDetector.from_pretrained(model_path).to("cuda") 24 | elif processor_id == "lineart_anime": 25 | self.processor = LineartAnimeDetector.from_pretrained(model_path).to("cuda") 26 | elif processor_id == "openpose": 27 | self.processor = OpenposeDetector.from_pretrained(model_path).to("cuda") 28 | elif processor_id == "tile": 29 | self.processor = None 30 | else: 31 | raise ValueError(f"Unsupported processor_id: {processor_id}") 32 | 33 | self.processor_id = processor_id 34 | self.detect_resolution = detect_resolution 35 | 36 | def __call__(self, image): 37 | width, height = image.size 38 | if self.processor_id == "openpose": 39 | kwargs = { 40 | "include_body": True, 41 | "include_hand": True, 42 | "include_face": True 43 | } 44 | else: 45 | kwargs = {} 46 | if self.processor is not None: 47 | detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height) 48 | image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs) 49 | image = image.resize((width, height)) 50 | return image 51 | 52 | -------------------------------------------------------------------------------- /diffsynth/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .video import VideoData, save_video, save_frames 2 | -------------------------------------------------------------------------------- /diffsynth/data/video.py: -------------------------------------------------------------------------------- 1 | import imageio, os 2 | import numpy as np 3 | from PIL import Image 4 | from tqdm import tqdm 5 | 6 | 7 | class LowMemoryVideo: 8 | def __init__(self, file_name): 9 | self.reader = imageio.get_reader(file_name) 10 | 11 | def __len__(self): 12 | return self.reader.count_frames() 13 | 14 | def __getitem__(self, item): 15 | return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB") 16 | 17 | def __del__(self): 18 | 
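        # Close the imageio reader when this wrapper is garbage-collected, so the
        # underlying video file handle is released instead of lingering until exit.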
self.reader.close() 19 | 20 | 21 | def split_file_name(file_name): 22 | result = [] 23 | number = -1 24 | for i in file_name: 25 | if ord(i)>=ord("0") and ord(i)<=ord("9"): 26 | if number == -1: 27 | number = 0 28 | number = number*10 + ord(i) - ord("0") 29 | else: 30 | if number != -1: 31 | result.append(number) 32 | number = -1 33 | result.append(i) 34 | if number != -1: 35 | result.append(number) 36 | result = tuple(result) 37 | return result 38 | 39 | 40 | def search_for_images(folder): 41 | file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")] 42 | file_list = [(split_file_name(file_name), file_name) for file_name in file_list] 43 | file_list = [i[1] for i in sorted(file_list)] 44 | file_list = [os.path.join(folder, i) for i in file_list] 45 | return file_list 46 | 47 | 48 | class LowMemoryImageFolder: 49 | def __init__(self, folder, file_list=None): 50 | if file_list is None: 51 | self.file_list = search_for_images(folder) 52 | else: 53 | self.file_list = [os.path.join(folder, file_name) for file_name in file_list] 54 | 55 | def __len__(self): 56 | return len(self.file_list) 57 | 58 | def __getitem__(self, item): 59 | return Image.open(self.file_list[item]).convert("RGB") 60 | 61 | def __del__(self): 62 | pass 63 | 64 | 65 | def crop_and_resize(image, height, width): 66 | image = np.array(image) 67 | image_height, image_width, _ = image.shape 68 | if image_height / image_width < height / width: 69 | croped_width = int(image_height / height * width) 70 | left = (image_width - croped_width) // 2 71 | image = image[:, left: left+croped_width] 72 | image = Image.fromarray(image).resize((width, height)) 73 | else: 74 | croped_height = int(image_width / width * height) 75 | left = (image_height - croped_height) // 2 76 | image = image[left: left+croped_height, :] 77 | image = Image.fromarray(image).resize((width, height)) 78 | return image 79 | 80 | 81 | class VideoData: 82 | def __init__(self, video_file=None, image_folder=None, height=None, width=None, **kwargs): 83 | if video_file is not None: 84 | self.data_type = "video" 85 | self.data = LowMemoryVideo(video_file, **kwargs) 86 | elif image_folder is not None: 87 | self.data_type = "images" 88 | self.data = LowMemoryImageFolder(image_folder, **kwargs) 89 | else: 90 | raise ValueError("Cannot open video or image folder") 91 | self.length = None 92 | self.set_shape(height, width) 93 | 94 | def raw_data(self): 95 | frames = [] 96 | for i in range(self.__len__()): 97 | frames.append(self.__getitem__(i)) 98 | return frames 99 | 100 | def set_length(self, length): 101 | self.length = length 102 | 103 | def set_shape(self, height, width): 104 | self.height = height 105 | self.width = width 106 | 107 | def __len__(self): 108 | if self.length is None: 109 | return len(self.data) 110 | else: 111 | return self.length 112 | 113 | def shape(self): 114 | if self.height is not None and self.width is not None: 115 | return self.height, self.width 116 | else: 117 | height, width, _ = self.__getitem__(0).shape 118 | return height, width 119 | 120 | def __getitem__(self, item): 121 | frame = self.data.__getitem__(item) 122 | width, height = frame.size 123 | if self.height is not None and self.width is not None: 124 | if self.height != height or self.width != width: 125 | frame = crop_and_resize(frame, self.height, self.width) 126 | return frame 127 | 128 | def __del__(self): 129 | pass 130 | 131 | def save_images(self, folder): 132 | os.makedirs(folder, exist_ok=True) 133 | for i in tqdm(range(self.__len__()), 
desc="Saving images"): 134 | frame = self.__getitem__(i) 135 | frame.save(os.path.join(folder, f"{i}.png")) 136 | 137 | 138 | def save_video(frames, save_path, fps, quality=9): 139 | writer = imageio.get_writer(save_path, fps=fps, quality=quality) 140 | for frame in tqdm(frames, desc="Saving video"): 141 | frame = np.array(frame) 142 | writer.append_data(frame) 143 | writer.close() 144 | 145 | def save_frames(frames, save_path): 146 | os.makedirs(save_path, exist_ok=True) 147 | for i, frame in enumerate(tqdm(frames, desc="Saving images")): 148 | frame.save(os.path.join(save_path, f"{i}.png")) 149 | -------------------------------------------------------------------------------- /diffsynth/extensions/ESRGAN/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import repeat 3 | from PIL import Image 4 | import numpy as np 5 | 6 | 7 | class ResidualDenseBlock(torch.nn.Module): 8 | 9 | def __init__(self, num_feat=64, num_grow_ch=32): 10 | super(ResidualDenseBlock, self).__init__() 11 | self.conv1 = torch.nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) 12 | self.conv2 = torch.nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) 13 | self.conv3 = torch.nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1) 14 | self.conv4 = torch.nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1) 15 | self.conv5 = torch.nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) 16 | self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True) 17 | 18 | def forward(self, x): 19 | x1 = self.lrelu(self.conv1(x)) 20 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) 21 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 22 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) 23 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 24 | return x5 * 0.2 + x 25 | 26 | 27 | class RRDB(torch.nn.Module): 28 | 29 | def __init__(self, num_feat, num_grow_ch=32): 30 | super(RRDB, self).__init__() 31 | self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) 32 | self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) 33 | self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) 34 | 35 | def forward(self, x): 36 | out = self.rdb1(x) 37 | out = self.rdb2(out) 38 | out = self.rdb3(out) 39 | return out * 0.2 + x 40 | 41 | 42 | class RRDBNet(torch.nn.Module): 43 | 44 | def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32): 45 | super(RRDBNet, self).__init__() 46 | self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 47 | self.body = torch.torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)]) 48 | self.conv_body = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1) 49 | # upsample 50 | self.conv_up1 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1) 51 | self.conv_up2 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1) 52 | self.conv_hr = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1) 53 | self.conv_last = torch.nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 54 | self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True) 55 | 56 | def forward(self, x): 57 | feat = x 58 | feat = self.conv_first(feat) 59 | body_feat = self.conv_body(self.body(feat)) 60 | feat = feat + body_feat 61 | # upsample 62 | feat = repeat(feat, "B C H W -> B C (H 2) (W 2)") 63 | feat = self.lrelu(self.conv_up1(feat)) 64 | feat = repeat(feat, "B C H W -> B C (H 2) (W 2)") 65 | feat = self.lrelu(self.conv_up2(feat)) 66 | out = self.conv_last(self.lrelu(self.conv_hr(feat))) 67 | return out 68 | 69 
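# Note: the reference Real-ESRGAN RRDBNet typically upsamples with
# F.interpolate(..., scale_factor=2, mode="nearest"); the forward pass above instead
# uses einops.repeat to duplicate each pixel into a 2x2 block, an equivalent
# nearest-neighbour x2 upsample, applied twice for a total x4 output.
#
# Hedged usage sketch (the checkpoint path is an assumption, not part of this repo):
#   esrgan = ESRGAN.from_pretrained("models/ESRGAN/RealESRGAN_x4plus.pth")  # hypothetical path
#   esrgan.to("cuda")                      # optional; upscale() follows the model's device/dtype
#   upscaled = esrgan.upscale(pil_images, batch_size=4)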
| 70 | class ESRGAN(torch.nn.Module): 71 | def __init__(self, model): 72 | super().__init__() 73 | self.model = model 74 | 75 | @staticmethod 76 | def from_pretrained(model_path): 77 | model = RRDBNet() 78 | state_dict = torch.load(model_path, map_location="cpu")["params_ema"] 79 | model.load_state_dict(state_dict) 80 | model.eval() 81 | return ESRGAN(model) 82 | 83 | def process_image(self, image): 84 | image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1) 85 | return image 86 | 87 | def process_images(self, images): 88 | images = [self.process_image(image) for image in images] 89 | images = torch.stack(images) 90 | return images 91 | 92 | def decode_images(self, images): 93 | images = (images.permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8) 94 | images = [Image.fromarray(image) for image in images] 95 | return images 96 | 97 | @torch.no_grad() 98 | def upscale(self, images, batch_size=4, progress_bar=lambda x:x): 99 | # Preprocess 100 | input_tensor = self.process_images(images) 101 | 102 | # Interpolate 103 | output_tensor = [] 104 | for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)): 105 | batch_id_ = min(batch_id + batch_size, input_tensor.shape[0]) 106 | batch_input_tensor = input_tensor[batch_id: batch_id_] 107 | batch_input_tensor = batch_input_tensor.to( 108 | device=self.model.conv_first.weight.device, 109 | dtype=self.model.conv_first.weight.dtype) 110 | batch_output_tensor = self.model(batch_input_tensor) 111 | output_tensor.append(batch_output_tensor.cpu()) 112 | 113 | # Output 114 | output_tensor = torch.concat(output_tensor, dim=0) 115 | 116 | # To images 117 | output_images = self.decode_images(output_tensor) 118 | return output_images 119 | -------------------------------------------------------------------------------- /diffsynth/extensions/FastBlend/__init__.py: -------------------------------------------------------------------------------- 1 | from .runners.fast import TableManager, PyramidPatchMatcher 2 | from PIL import Image 3 | import numpy as np 4 | import cupy as cp 5 | 6 | 7 | class FastBlendSmoother: 8 | def __init__(self): 9 | self.batch_size = 8 10 | self.window_size = 64 11 | self.ebsynth_config = { 12 | "minimum_patch_size": 5, 13 | "threads_per_block": 8, 14 | "num_iter": 5, 15 | "gpu_id": 0, 16 | "guide_weight": 10.0, 17 | "initialize": "identity", 18 | "tracking_window_size": 0, 19 | } 20 | 21 | @staticmethod 22 | def from_model_manager(model_manager): 23 | # TODO: fetch GPU ID from model_manager 24 | return FastBlendSmoother() 25 | 26 | def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config): 27 | frames_guide = [np.array(frame) for frame in frames_guide] 28 | frames_style = [np.array(frame) for frame in frames_style] 29 | table_manager = TableManager() 30 | patch_match_engine = PyramidPatchMatcher( 31 | image_height=frames_style[0].shape[0], 32 | image_width=frames_style[0].shape[1], 33 | channel=3, 34 | **ebsynth_config 35 | ) 36 | # left part 37 | table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="FastBlend Step 1/4") 38 | table_l = table_manager.remapping_table_to_blending_table(table_l) 39 | table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="FastBlend Step 2/4") 40 | # right part 41 | table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="FastBlend Step 3/4") 42 | 
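        # The reverse ("right") pass above and the window sum below mirror the forward
        # pass on time-reversed frames. In the merge loop further down, weight_m = -1
        # effectively removes one copy of the style frame from the blend, since the
        # style frame already contributes to both the left and right tables.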
table_r = table_manager.remapping_table_to_blending_table(table_r) 43 | table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="FastBlend Step 4/4")[::-1] 44 | # merge 45 | frames = [] 46 | for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r): 47 | weight_m = -1 48 | weight = weight_l + weight_m + weight_r 49 | frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight) 50 | frames.append(frame) 51 | frames = [Image.fromarray(frame.clip(0, 255).astype("uint8")) for frame in frames] 52 | return frames 53 | 54 | def __call__(self, rendered_frames, original_frames=None, **kwargs): 55 | frames = self.run( 56 | original_frames, rendered_frames, 57 | self.batch_size, self.window_size, self.ebsynth_config 58 | ) 59 | mempool = cp.get_default_memory_pool() 60 | pinned_mempool = cp.get_default_pinned_memory_pool() 61 | mempool.free_all_blocks() 62 | pinned_mempool.free_all_blocks() 63 | return frames -------------------------------------------------------------------------------- /diffsynth/extensions/FastBlend/cupy_kernels.py: -------------------------------------------------------------------------------- 1 | import cupy as cp 2 | 3 | remapping_kernel = cp.RawKernel(r''' 4 | extern "C" __global__ 5 | void remap( 6 | const int height, 7 | const int width, 8 | const int channel, 9 | const int patch_size, 10 | const int pad_size, 11 | const float* source_style, 12 | const int* nnf, 13 | float* target_style 14 | ) { 15 | const int r = (patch_size - 1) / 2; 16 | const int x = blockDim.x * blockIdx.x + threadIdx.x; 17 | const int y = blockDim.y * blockIdx.y + threadIdx.y; 18 | if (x >= height or y >= width) return; 19 | const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel; 20 | const int pid = (x + pad_size) * (width + pad_size * 2) + (y + pad_size); 21 | const int min_px = x < r ? -x : -r; 22 | const int max_px = x + r > height - 1 ? height - 1 - x : r; 23 | const int min_py = y < r ? -y : -r; 24 | const int max_py = y + r > width - 1 ? 
width - 1 - y : r; 25 | int num = 0; 26 | for (int px = min_px; px <= max_px; px++){ 27 | for (int py = min_py; py <= max_py; py++){ 28 | const int nid = (x + px) * width + y + py; 29 | const int x_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 0] - px; 30 | const int y_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 1] - py; 31 | if (x_ < 0 or y_ < 0 or x_ >= height or y_ >= width)continue; 32 | const int pid_ = (x_ + pad_size) * (width + pad_size * 2) + (y_ + pad_size); 33 | num++; 34 | for (int c = 0; c < channel; c++){ 35 | target_style[z + pid * channel + c] += source_style[z + pid_ * channel + c]; 36 | } 37 | } 38 | } 39 | for (int c = 0; c < channel; c++){ 40 | target_style[z + pid * channel + c] /= num; 41 | } 42 | } 43 | ''', 'remap') 44 | 45 | 46 | patch_error_kernel = cp.RawKernel(r''' 47 | extern "C" __global__ 48 | void patch_error( 49 | const int height, 50 | const int width, 51 | const int channel, 52 | const int patch_size, 53 | const int pad_size, 54 | const float* source, 55 | const int* nnf, 56 | const float* target, 57 | float* error 58 | ) { 59 | const int r = (patch_size - 1) / 2; 60 | const int x = blockDim.x * blockIdx.x + threadIdx.x; 61 | const int y = blockDim.y * blockIdx.y + threadIdx.y; 62 | const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel; 63 | if (x >= height or y >= width) return; 64 | const int x_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 0]; 65 | const int y_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 1]; 66 | float e = 0; 67 | for (int px = -r; px <= r; px++){ 68 | for (int py = -r; py <= r; py++){ 69 | const int pid = (x + pad_size + px) * (width + pad_size * 2) + y + pad_size + py; 70 | const int pid_ = (x_ + pad_size + px) * (width + pad_size * 2) + y_ + pad_size + py; 71 | for (int c = 0; c < channel; c++){ 72 | const float diff = target[z + pid * channel + c] - source[z + pid_ * channel + c]; 73 | e += diff * diff; 74 | } 75 | } 76 | } 77 | error[blockIdx.z * height * width + x * width + y] = e; 78 | } 79 | ''', 'patch_error') 80 | 81 | 82 | pairwise_patch_error_kernel = cp.RawKernel(r''' 83 | extern "C" __global__ 84 | void pairwise_patch_error( 85 | const int height, 86 | const int width, 87 | const int channel, 88 | const int patch_size, 89 | const int pad_size, 90 | const float* source_a, 91 | const int* nnf_a, 92 | const float* source_b, 93 | const int* nnf_b, 94 | float* error 95 | ) { 96 | const int r = (patch_size - 1) / 2; 97 | const int x = blockDim.x * blockIdx.x + threadIdx.x; 98 | const int y = blockDim.y * blockIdx.y + threadIdx.y; 99 | const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel; 100 | if (x >= height or y >= width) return; 101 | const int z_nnf = blockIdx.z * height * width * 2 + (x * width + y) * 2; 102 | const int x_a = nnf_a[z_nnf + 0]; 103 | const int y_a = nnf_a[z_nnf + 1]; 104 | const int x_b = nnf_b[z_nnf + 0]; 105 | const int y_b = nnf_b[z_nnf + 1]; 106 | float e = 0; 107 | for (int px = -r; px <= r; px++){ 108 | for (int py = -r; py <= r; py++){ 109 | const int pid_a = (x_a + pad_size + px) * (width + pad_size * 2) + y_a + pad_size + py; 110 | const int pid_b = (x_b + pad_size + px) * (width + pad_size * 2) + y_b + pad_size + py; 111 | for (int c = 0; c < channel; c++){ 112 | const float diff = source_a[z + pid_a * channel + c] - source_b[z + pid_b * channel + c]; 113 | e += diff * diff; 114 | } 115 | } 116 | } 117 | error[blockIdx.z * height * width + x * width + y] = e; 118 | } 119 | ''', 
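# pairwise_patch_error: for each target pixel, fetch the patch that nnf_a maps to in
# source_a and the patch that nnf_b maps to in source_b (both padded by pad_size) and
# write the summed squared difference into error. This kernel is selected when the
# matcher is built with use_pairwise_patch_error=True, e.g. in the interpolation
# runner, presumably to compare two candidate remappings per pixel.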
'pairwise_patch_error') 120 | -------------------------------------------------------------------------------- /diffsynth/extensions/FastBlend/data.py: -------------------------------------------------------------------------------- 1 | import imageio, os 2 | import numpy as np 3 | from PIL import Image 4 | 5 | 6 | def read_video(file_name): 7 | reader = imageio.get_reader(file_name) 8 | video = [] 9 | for frame in reader: 10 | frame = np.array(frame) 11 | video.append(frame) 12 | reader.close() 13 | return video 14 | 15 | 16 | def get_video_fps(file_name): 17 | reader = imageio.get_reader(file_name) 18 | fps = reader.get_meta_data()["fps"] 19 | reader.close() 20 | return fps 21 | 22 | 23 | def save_video(frames_path, video_path, num_frames, fps): 24 | writer = imageio.get_writer(video_path, fps=fps, quality=9) 25 | for i in range(num_frames): 26 | frame = np.array(Image.open(os.path.join(frames_path, "%05d.png" % i))) 27 | writer.append_data(frame) 28 | writer.close() 29 | return video_path 30 | 31 | 32 | class LowMemoryVideo: 33 | def __init__(self, file_name): 34 | self.reader = imageio.get_reader(file_name) 35 | 36 | def __len__(self): 37 | return self.reader.count_frames() 38 | 39 | def __getitem__(self, item): 40 | return np.array(self.reader.get_data(item)) 41 | 42 | def __del__(self): 43 | self.reader.close() 44 | 45 | 46 | def split_file_name(file_name): 47 | result = [] 48 | number = -1 49 | for i in file_name: 50 | if ord(i)>=ord("0") and ord(i)<=ord("9"): 51 | if number == -1: 52 | number = 0 53 | number = number*10 + ord(i) - ord("0") 54 | else: 55 | if number != -1: 56 | result.append(number) 57 | number = -1 58 | result.append(i) 59 | if number != -1: 60 | result.append(number) 61 | result = tuple(result) 62 | return result 63 | 64 | 65 | def search_for_images(folder): 66 | file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")] 67 | file_list = [(split_file_name(file_name), file_name) for file_name in file_list] 68 | file_list = [i[1] for i in sorted(file_list)] 69 | file_list = [os.path.join(folder, i) for i in file_list] 70 | return file_list 71 | 72 | 73 | def read_images(folder): 74 | file_list = search_for_images(folder) 75 | frames = [np.array(Image.open(i)) for i in file_list] 76 | return frames 77 | 78 | 79 | class LowMemoryImageFolder: 80 | def __init__(self, folder, file_list=None): 81 | if file_list is None: 82 | self.file_list = search_for_images(folder) 83 | else: 84 | self.file_list = [os.path.join(folder, file_name) for file_name in file_list] 85 | 86 | def __len__(self): 87 | return len(self.file_list) 88 | 89 | def __getitem__(self, item): 90 | return np.array(Image.open(self.file_list[item])) 91 | 92 | def __del__(self): 93 | pass 94 | 95 | 96 | class VideoData: 97 | def __init__(self, video_file, image_folder, **kwargs): 98 | if video_file is not None: 99 | self.data_type = "video" 100 | self.data = LowMemoryVideo(video_file, **kwargs) 101 | elif image_folder is not None: 102 | self.data_type = "images" 103 | self.data = LowMemoryImageFolder(image_folder, **kwargs) 104 | else: 105 | raise ValueError("Cannot open video or image folder") 106 | self.length = None 107 | self.height = None 108 | self.width = None 109 | 110 | def raw_data(self): 111 | frames = [] 112 | for i in range(self.__len__()): 113 | frames.append(self.__getitem__(i)) 114 | return frames 115 | 116 | def set_length(self, length): 117 | self.length = length 118 | 119 | def set_shape(self, height, width): 120 | self.height = height 121 | self.width = width 122 
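    # This VideoData mirrors diffsynth/data/video.py but yields numpy arrays rather
    # than PIL images, and plain-resizes (no center crop) when a target shape is set.
    # set_length / set_shape only record limits; frames are loaded and resized lazily
    # in __getitem__ below.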
| 123 | def __len__(self): 124 | if self.length is None: 125 | return len(self.data) 126 | else: 127 | return self.length 128 | 129 | def shape(self): 130 | if self.height is not None and self.width is not None: 131 | return self.height, self.width 132 | else: 133 | height, width, _ = self.__getitem__(0).shape 134 | return height, width 135 | 136 | def __getitem__(self, item): 137 | frame = self.data.__getitem__(item) 138 | height, width, _ = frame.shape 139 | if self.height is not None and self.width is not None: 140 | if self.height != height or self.width != width: 141 | frame = Image.fromarray(frame).resize((self.width, self.height)) 142 | frame = np.array(frame) 143 | return frame 144 | 145 | def __del__(self): 146 | pass 147 | -------------------------------------------------------------------------------- /diffsynth/extensions/FastBlend/runners/__init__.py: -------------------------------------------------------------------------------- 1 | from .accurate import AccurateModeRunner 2 | from .fast import FastModeRunner 3 | from .balanced import BalancedModeRunner 4 | from .interpolation import InterpolationModeRunner, InterpolationModeSingleFrameRunner 5 | -------------------------------------------------------------------------------- /diffsynth/extensions/FastBlend/runners/accurate.py: -------------------------------------------------------------------------------- 1 | from ..patch_match import PyramidPatchMatcher 2 | import os 3 | import numpy as np 4 | from PIL import Image 5 | from tqdm import tqdm 6 | 7 | 8 | class AccurateModeRunner: 9 | def __init__(self): 10 | pass 11 | 12 | def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Accurate Mode", save_path=None): 13 | patch_match_engine = PyramidPatchMatcher( 14 | image_height=frames_style[0].shape[0], 15 | image_width=frames_style[0].shape[1], 16 | channel=3, 17 | use_mean_target_style=True, 18 | **ebsynth_config 19 | ) 20 | # run 21 | n = len(frames_style) 22 | for target in tqdm(range(n), desc=desc): 23 | l, r = max(target - window_size, 0), min(target + window_size + 1, n) 24 | remapped_frames = [] 25 | for i in range(l, r, batch_size): 26 | j = min(i + batch_size, r) 27 | source_guide = np.stack([frames_guide[source] for source in range(i, j)]) 28 | target_guide = np.stack([frames_guide[target]] * (j - i)) 29 | source_style = np.stack([frames_style[source] for source in range(i, j)]) 30 | _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style) 31 | remapped_frames.append(target_style) 32 | frame = np.concatenate(remapped_frames, axis=0).mean(axis=0) 33 | frame = frame.clip(0, 255).astype("uint8") 34 | if save_path is not None: 35 | Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target)) -------------------------------------------------------------------------------- /diffsynth/extensions/FastBlend/runners/balanced.py: -------------------------------------------------------------------------------- 1 | from ..patch_match import PyramidPatchMatcher 2 | import os 3 | import numpy as np 4 | from PIL import Image 5 | from tqdm import tqdm 6 | 7 | 8 | class BalancedModeRunner: 9 | def __init__(self): 10 | pass 11 | 12 | def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Balanced Mode", save_path=None): 13 | patch_match_engine = PyramidPatchMatcher( 14 | image_height=frames_style[0].shape[0], 15 | image_width=frames_style[0].shape[1], 16 | channel=3, 17 | **ebsynth_config 18 | ) 19 | # tasks 20 | n = 
len(frames_style) 21 | tasks = [] 22 | for target in range(n): 23 | for source in range(target - window_size, target + window_size + 1): 24 | if source >= 0 and source < n and source != target: 25 | tasks.append((source, target)) 26 | # run 27 | frames = [(None, 1) for i in range(n)] 28 | for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc): 29 | tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))] 30 | source_guide = np.stack([frames_guide[source] for source, target in tasks_batch]) 31 | target_guide = np.stack([frames_guide[target] for source, target in tasks_batch]) 32 | source_style = np.stack([frames_style[source] for source, target in tasks_batch]) 33 | _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style) 34 | for (source, target), result in zip(tasks_batch, target_style): 35 | frame, weight = frames[target] 36 | if frame is None: 37 | frame = frames_style[target] 38 | frames[target] = ( 39 | frame * (weight / (weight + 1)) + result / (weight + 1), 40 | weight + 1 41 | ) 42 | if weight + 1 == min(n, target + window_size + 1) - max(0, target - window_size): 43 | frame = frame.clip(0, 255).astype("uint8") 44 | if save_path is not None: 45 | Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target)) 46 | frames[target] = (None, 1) 47 | -------------------------------------------------------------------------------- /diffsynth/extensions/FastBlend/runners/fast.py: -------------------------------------------------------------------------------- 1 | from ..patch_match import PyramidPatchMatcher 2 | import functools, os 3 | import numpy as np 4 | from PIL import Image 5 | from tqdm import tqdm 6 | 7 | 8 | class TableManager: 9 | def __init__(self): 10 | pass 11 | 12 | def task_list(self, n): 13 | tasks = [] 14 | max_level = 1 15 | while (1<=n: 24 | break 25 | meta_data = { 26 | "source": i, 27 | "target": j, 28 | "level": level + 1 29 | } 30 | tasks.append(meta_data) 31 | tasks.sort(key=functools.cmp_to_key(lambda u, v: u["level"]-v["level"])) 32 | return tasks 33 | 34 | def build_remapping_table(self, frames_guide, frames_style, patch_match_engine, batch_size, desc=""): 35 | n = len(frames_guide) 36 | tasks = self.task_list(n) 37 | remapping_table = [[(frames_style[i], 1)] for i in range(n)] 38 | for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc): 39 | tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))] 40 | source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch]) 41 | target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch]) 42 | source_style = np.stack([frames_style[task["source"]] for task in tasks_batch]) 43 | _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style) 44 | for task, result in zip(tasks_batch, target_style): 45 | target, level = task["target"], task["level"] 46 | if len(remapping_table[target])==level: 47 | remapping_table[target].append((result, 1)) 48 | else: 49 | frame, weight = remapping_table[target][level] 50 | remapping_table[target][level] = ( 51 | frame * (weight / (weight + 1)) + result / (weight + 1), 52 | weight + 1 53 | ) 54 | return remapping_table 55 | 56 | def remapping_table_to_blending_table(self, table): 57 | for i in range(len(table)): 58 | for j in range(1, len(table[i])): 59 | frame_1, weight_1 = table[i][j-1] 60 | frame_2, weight_2 = table[i][j] 61 | frame = (frame_1 + frame_2) / 2 62 | weight = weight_1 + weight_2 63 | table[i][j] = (frame, weight) 64 | 
return table 65 | 66 | def tree_query(self, leftbound, rightbound): 67 | node_list = [] 68 | node_index = rightbound 69 | while node_index>=leftbound: 70 | node_level = 0 71 | while (1<=leftbound: 72 | node_level += 1 73 | node_list.append((node_index, node_level)) 74 | node_index -= 1<0: 31 | tasks = [] 32 | for m in range(index_style[0]): 33 | tasks.append((index_style[0], m, index_style[0])) 34 | task_group.append(tasks) 35 | # middle frames 36 | for l, r in zip(index_style[:-1], index_style[1:]): 37 | tasks = [] 38 | for m in range(l, r): 39 | tasks.append((l, m, r)) 40 | task_group.append(tasks) 41 | # last frame 42 | tasks = [] 43 | for m in range(index_style[-1], n): 44 | tasks.append((index_style[-1], m, index_style[-1])) 45 | task_group.append(tasks) 46 | return task_group 47 | 48 | def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None): 49 | patch_match_engine = PyramidPatchMatcher( 50 | image_height=frames_style[0].shape[0], 51 | image_width=frames_style[0].shape[1], 52 | channel=3, 53 | use_mean_target_style=False, 54 | use_pairwise_patch_error=True, 55 | **ebsynth_config 56 | ) 57 | # task 58 | index_dict = self.get_index_dict(index_style) 59 | task_group = self.get_task_group(index_style, len(frames_guide)) 60 | # run 61 | for tasks in task_group: 62 | index_start, index_end = min([i[1] for i in tasks]), max([i[1] for i in tasks]) 63 | for batch_id in tqdm(range(0, len(tasks), batch_size), desc=f"Rendering frames {index_start}...{index_end}"): 64 | tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))] 65 | source_guide, target_guide, source_style = [], [], [] 66 | for l, m, r in tasks_batch: 67 | # l -> m 68 | source_guide.append(frames_guide[l]) 69 | target_guide.append(frames_guide[m]) 70 | source_style.append(frames_style[index_dict[l]]) 71 | # r -> m 72 | source_guide.append(frames_guide[r]) 73 | target_guide.append(frames_guide[m]) 74 | source_style.append(frames_style[index_dict[r]]) 75 | source_guide = np.stack(source_guide) 76 | target_guide = np.stack(target_guide) 77 | source_style = np.stack(source_style) 78 | _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style) 79 | if save_path is not None: 80 | for frame_l, frame_r, (l, m, r) in zip(target_style[0::2], target_style[1::2], tasks_batch): 81 | weight_l, weight_r = self.get_weight(l, m, r) 82 | frame = frame_l * weight_l + frame_r * weight_r 83 | frame = frame.clip(0, 255).astype("uint8") 84 | Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % m)) 85 | 86 | 87 | class InterpolationModeSingleFrameRunner: 88 | def __init__(self): 89 | pass 90 | 91 | def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None): 92 | # check input 93 | tracking_window_size = ebsynth_config["tracking_window_size"] 94 | if tracking_window_size * 2 >= batch_size: 95 | raise ValueError("batch_size should be larger than track_window_size * 2") 96 | frame_style = frames_style[0] 97 | frame_guide = frames_guide[index_style[0]] 98 | patch_match_engine = PyramidPatchMatcher( 99 | image_height=frame_style.shape[0], 100 | image_width=frame_style.shape[1], 101 | channel=3, 102 | **ebsynth_config 103 | ) 104 | # run 105 | frame_id, n = 0, len(frames_guide) 106 | for i in tqdm(range(0, n, batch_size - tracking_window_size * 2), desc=f"Rendering frames 0...{n}"): 107 | if i + batch_size > n: 108 | l, r = max(n - batch_size, 0), n 109 | else: 110 | l, r = i, i + batch_size 111 | source_guide = 
np.stack([frame_guide] * (r-l)) 112 | target_guide = np.stack([frames_guide[i] for i in range(l, r)]) 113 | source_style = np.stack([frame_style] * (r-l)) 114 | _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style) 115 | for i, frame in zip(range(l, r), target_style): 116 | if i==frame_id: 117 | frame = frame.clip(0, 255).astype("uint8") 118 | Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % frame_id)) 119 | frame_id += 1 120 | if r < n and r-frame_id <= tracking_window_size: 121 | break 122 | -------------------------------------------------------------------------------- /diffsynth/extensions/RIFE/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from PIL import Image 6 | 7 | 8 | def warp(tenInput, tenFlow, device): 9 | backwarp_tenGrid = {} 10 | k = (str(tenFlow.device), str(tenFlow.size())) 11 | if k not in backwarp_tenGrid: 12 | tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view( 13 | 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1) 14 | tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view( 15 | 1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3]) 16 | backwarp_tenGrid[k] = torch.cat( 17 | [tenHorizontal, tenVertical], 1).to(device) 18 | 19 | tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), 20 | tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1) 21 | 22 | g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1) 23 | return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True) 24 | 25 | 26 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): 27 | return nn.Sequential( 28 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, 29 | padding=padding, dilation=dilation, bias=True), 30 | nn.PReLU(out_planes) 31 | ) 32 | 33 | 34 | class IFBlock(nn.Module): 35 | def __init__(self, in_planes, c=64): 36 | super(IFBlock, self).__init__() 37 | self.conv0 = nn.Sequential(conv(in_planes, c//2, 3, 2, 1), conv(c//2, c, 3, 2, 1),) 38 | self.convblock0 = nn.Sequential(conv(c, c), conv(c, c)) 39 | self.convblock1 = nn.Sequential(conv(c, c), conv(c, c)) 40 | self.convblock2 = nn.Sequential(conv(c, c), conv(c, c)) 41 | self.convblock3 = nn.Sequential(conv(c, c), conv(c, c)) 42 | self.conv1 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 4, 4, 2, 1)) 43 | self.conv2 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 1, 4, 2, 1)) 44 | 45 | def forward(self, x, flow, scale=1): 46 | x = F.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) 47 | flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 1. 
/ scale 48 | feat = self.conv0(torch.cat((x, flow), 1)) 49 | feat = self.convblock0(feat) + feat 50 | feat = self.convblock1(feat) + feat 51 | feat = self.convblock2(feat) + feat 52 | feat = self.convblock3(feat) + feat 53 | flow = self.conv1(feat) 54 | mask = self.conv2(feat) 55 | flow = F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * scale 56 | mask = F.interpolate(mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) 57 | return flow, mask 58 | 59 | 60 | class IFNet(nn.Module): 61 | def __init__(self): 62 | super(IFNet, self).__init__() 63 | self.block0 = IFBlock(7+4, c=90) 64 | self.block1 = IFBlock(7+4, c=90) 65 | self.block2 = IFBlock(7+4, c=90) 66 | self.block_tea = IFBlock(10+4, c=90) 67 | 68 | def forward(self, x, scale_list=[4, 2, 1], training=False): 69 | if training == False: 70 | channel = x.shape[1] // 2 71 | img0 = x[:, :channel] 72 | img1 = x[:, channel:] 73 | flow_list = [] 74 | merged = [] 75 | mask_list = [] 76 | warped_img0 = img0 77 | warped_img1 = img1 78 | flow = (x[:, :4]).detach() * 0 79 | mask = (x[:, :1]).detach() * 0 80 | block = [self.block0, self.block1, self.block2] 81 | for i in range(3): 82 | f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i]) 83 | f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i]) 84 | flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2 85 | mask = mask + (m0 + (-m1)) / 2 86 | mask_list.append(mask) 87 | flow_list.append(flow) 88 | warped_img0 = warp(img0, flow[:, :2], device=x.device) 89 | warped_img1 = warp(img1, flow[:, 2:4], device=x.device) 90 | merged.append((warped_img0, warped_img1)) 91 | ''' 92 | c0 = self.contextnet(img0, flow[:, :2]) 93 | c1 = self.contextnet(img1, flow[:, 2:4]) 94 | tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1) 95 | res = tmp[:, 1:4] * 2 - 1 96 | ''' 97 | for i in range(3): 98 | mask_list[i] = torch.sigmoid(mask_list[i]) 99 | merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i]) 100 | return flow_list, mask_list[2], merged 101 | 102 | def state_dict_converter(self): 103 | return IFNetStateDictConverter() 104 | 105 | 106 | class IFNetStateDictConverter: 107 | def __init__(self): 108 | pass 109 | 110 | def from_diffusers(self, state_dict): 111 | state_dict_ = {k.replace("module.", ""): v for k, v in state_dict.items()} 112 | return state_dict_ 113 | 114 | def from_civitai(self, state_dict): 115 | return self.from_diffusers(state_dict) 116 | 117 | 118 | class RIFEInterpolater: 119 | def __init__(self, model, device="cuda"): 120 | self.model = model 121 | self.device = device 122 | # IFNet only does not support float16 123 | self.torch_dtype = torch.float32 124 | 125 | @staticmethod 126 | def from_model_manager(model_manager): 127 | return RIFEInterpolater(model_manager.RIFE, device=model_manager.device) 128 | 129 | def process_image(self, image): 130 | width, height = image.size 131 | if width % 32 != 0 or height % 32 != 0: 132 | width = (width + 31) // 32 133 | height = (height + 31) // 32 134 | image = image.resize((width, height)) 135 | image = torch.Tensor(np.array(image, dtype=np.float32)[:, :, [2,1,0]] / 255).permute(2, 0, 1) 136 | return image 137 | 138 | def process_images(self, images): 139 | images = [self.process_image(image) for image in images] 140 | images = torch.stack(images) 141 | return 
images 142 | 143 | def decode_images(self, images): 144 | images = (images[:, [2,1,0]].permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8) 145 | images = [Image.fromarray(image) for image in images] 146 | return images 147 | 148 | def add_interpolated_images(self, images, interpolated_images): 149 | output_images = [] 150 | for image, interpolated_image in zip(images, interpolated_images): 151 | output_images.append(image) 152 | output_images.append(interpolated_image) 153 | output_images.append(images[-1]) 154 | return output_images 155 | 156 | 157 | @torch.no_grad() 158 | def interpolate_(self, images, scale=1.0): 159 | input_tensor = self.process_images(images) 160 | input_tensor = torch.cat((input_tensor[:-1], input_tensor[1:]), dim=1) 161 | input_tensor = input_tensor.to(device=self.device, dtype=self.torch_dtype) 162 | flow, mask, merged = self.model(input_tensor, [4/scale, 2/scale, 1/scale]) 163 | output_images = self.decode_images(merged[2].cpu()) 164 | if output_images[0].size != images[0].size: 165 | output_images = [image.resize(images[0].size) for image in output_images] 166 | return output_images 167 | 168 | 169 | @torch.no_grad() 170 | def interpolate(self, images, scale=1.0, batch_size=4, num_iter=1, progress_bar=lambda x:x): 171 | # Preprocess 172 | processed_images = self.process_images(images) 173 | 174 | for iter in range(num_iter): 175 | # Input 176 | input_tensor = torch.cat((processed_images[:-1], processed_images[1:]), dim=1) 177 | 178 | # Interpolate 179 | output_tensor = [] 180 | for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)): 181 | batch_id_ = min(batch_id + batch_size, input_tensor.shape[0]) 182 | batch_input_tensor = input_tensor[batch_id: batch_id_] 183 | batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype) 184 | flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale]) 185 | output_tensor.append(merged[2].cpu()) 186 | 187 | # Output 188 | output_tensor = torch.concat(output_tensor, dim=0).clip(0, 1) 189 | processed_images = self.add_interpolated_images(processed_images, output_tensor) 190 | processed_images = torch.stack(processed_images) 191 | 192 | # To images 193 | output_images = self.decode_images(processed_images) 194 | if output_images[0].size != images[0].size: 195 | output_images = [image.resize(images[0].size) for image in output_images] 196 | return output_images 197 | 198 | 199 | class RIFESmoother(RIFEInterpolater): 200 | def __init__(self, model, device="cuda"): 201 | super(RIFESmoother, self).__init__(model, device=device) 202 | 203 | @staticmethod 204 | def from_model_manager(model_manager): 205 | return RIFESmoother(model_manager.RIFE, device=model_manager.device) 206 | 207 | def process_tensors(self, input_tensor, scale=1.0, batch_size=4): 208 | output_tensor = [] 209 | for batch_id in range(0, input_tensor.shape[0], batch_size): 210 | batch_id_ = min(batch_id + batch_size, input_tensor.shape[0]) 211 | batch_input_tensor = input_tensor[batch_id: batch_id_] 212 | batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype) 213 | flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale]) 214 | output_tensor.append(merged[2].cpu()) 215 | output_tensor = torch.concat(output_tensor, dim=0) 216 | return output_tensor 217 | 218 | @torch.no_grad() 219 | def __call__(self, rendered_frames, scale=1.0, batch_size=4, num_iter=1, **kwargs): 220 | # Preprocess 221 | processed_images = 
self.process_images(rendered_frames) 222 | 223 | for iter in range(num_iter): 224 | # Input 225 | input_tensor = torch.cat((processed_images[:-2], processed_images[2:]), dim=1) 226 | 227 | # Interpolate 228 | output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size) 229 | 230 | # Blend 231 | input_tensor = torch.cat((processed_images[1:-1], output_tensor), dim=1) 232 | output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size) 233 | 234 | # Add to frames 235 | processed_images[1:-1] = output_tensor 236 | 237 | # To images 238 | output_images = self.decode_images(processed_images) 239 | if output_images[0].size != rendered_frames[0].size: 240 | output_images = [image.resize(rendered_frames[0].size) for image in output_images] 241 | return output_images 242 | -------------------------------------------------------------------------------- /diffsynth/models/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | 4 | 5 | def low_version_attention(query, key, value, attn_bias=None): 6 | scale = 1 / query.shape[-1] ** 0.5 7 | query = query * scale 8 | attn = torch.matmul(query, key.transpose(-2, -1)) 9 | if attn_bias is not None: 10 | attn = attn + attn_bias 11 | attn = attn.softmax(-1) 12 | return attn @ value 13 | 14 | 15 | class Attention(torch.nn.Module): 16 | 17 | def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False): 18 | super().__init__() 19 | dim_inner = head_dim * num_heads 20 | kv_dim = kv_dim if kv_dim is not None else q_dim 21 | self.num_heads = num_heads 22 | self.head_dim = head_dim 23 | 24 | self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q) 25 | self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) 26 | self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) 27 | self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out) 28 | 29 | def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0): 30 | batch_size = q.shape[0] 31 | ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) 32 | ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) 33 | ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v) 34 | hidden_states = hidden_states + scale * ip_hidden_states 35 | return hidden_states 36 | 37 | def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None): 38 | if encoder_hidden_states is None: 39 | encoder_hidden_states = hidden_states 40 | 41 | batch_size = encoder_hidden_states.shape[0] 42 | 43 | q = self.to_q(hidden_states) 44 | k = self.to_k(encoder_hidden_states) 45 | v = self.to_v(encoder_hidden_states) 46 | 47 | q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) 48 | k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) 49 | v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) 50 | 51 | if qkv_preprocessor is not None: 52 | q, k, v = qkv_preprocessor(q, k, v) 53 | 54 | hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) 55 | if ipadapter_kwargs is not None: 56 | hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs) 57 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) 58 | hidden_states = 
hidden_states.to(q.dtype) 59 | 60 | hidden_states = self.to_out(hidden_states) 61 | 62 | return hidden_states 63 | 64 | def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None): 65 | if encoder_hidden_states is None: 66 | encoder_hidden_states = hidden_states 67 | 68 | q = self.to_q(hidden_states) 69 | k = self.to_k(encoder_hidden_states) 70 | v = self.to_v(encoder_hidden_states) 71 | 72 | q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads) 73 | k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads) 74 | v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads) 75 | 76 | if attn_mask is not None: 77 | hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask) 78 | else: 79 | import xformers.ops as xops 80 | hidden_states = xops.memory_efficient_attention(q, k, v) 81 | hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads) 82 | 83 | hidden_states = hidden_states.to(q.dtype) 84 | hidden_states = self.to_out(hidden_states) 85 | 86 | return hidden_states 87 | 88 | def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None): 89 | return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor) -------------------------------------------------------------------------------- /diffsynth/models/hunyuan_dit_text_encoder.py: -------------------------------------------------------------------------------- 1 | from transformers import BertModel, BertConfig, T5EncoderModel, T5Config 2 | import torch 3 | 4 | 5 | 6 | class HunyuanDiTCLIPTextEncoder(BertModel): 7 | def __init__(self): 8 | config = BertConfig( 9 | _name_or_path = "", 10 | architectures = ["BertModel"], 11 | attention_probs_dropout_prob = 0.1, 12 | bos_token_id = 0, 13 | classifier_dropout = None, 14 | directionality = "bidi", 15 | eos_token_id = 2, 16 | hidden_act = "gelu", 17 | hidden_dropout_prob = 0.1, 18 | hidden_size = 1024, 19 | initializer_range = 0.02, 20 | intermediate_size = 4096, 21 | layer_norm_eps = 1e-12, 22 | max_position_embeddings = 512, 23 | model_type = "bert", 24 | num_attention_heads = 16, 25 | num_hidden_layers = 24, 26 | output_past = True, 27 | pad_token_id = 0, 28 | pooler_fc_size = 768, 29 | pooler_num_attention_heads = 12, 30 | pooler_num_fc_layers = 3, 31 | pooler_size_per_head = 128, 32 | pooler_type = "first_token_transform", 33 | position_embedding_type = "absolute", 34 | torch_dtype = "float32", 35 | transformers_version = "4.37.2", 36 | type_vocab_size = 2, 37 | use_cache = True, 38 | vocab_size = 47020 39 | ) 40 | super().__init__(config, add_pooling_layer=False) 41 | self.eval() 42 | 43 | def forward(self, input_ids, attention_mask, clip_skip=1): 44 | input_shape = input_ids.size() 45 | 46 | batch_size, seq_length = input_shape 47 | device = input_ids.device 48 | 49 | past_key_values_length = 0 50 | 51 | if attention_mask is None: 52 | attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) 53 | 54 | extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) 55 | 56 | embedding_output = self.embeddings( 57 | input_ids=input_ids, 58 | position_ids=None, 59 | token_type_ids=None, 60 | inputs_embeds=None, 61 | past_key_values_length=0, 62 | ) 63 | encoder_outputs = self.encoder( 64 | embedding_output, 65 | attention_mask=extended_attention_mask, 66 | head_mask=None, 67 | 
encoder_hidden_states=None, 68 | encoder_attention_mask=None, 69 | past_key_values=None, 70 | use_cache=False, 71 | output_attentions=False, 72 | output_hidden_states=True, 73 | return_dict=True, 74 | ) 75 | all_hidden_states = encoder_outputs.hidden_states 76 | prompt_emb = all_hidden_states[-clip_skip] 77 | if clip_skip > 1: 78 | mean, std = all_hidden_states[-1].mean(), all_hidden_states[-1].std() 79 | prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean 80 | return prompt_emb 81 | 82 | def state_dict_converter(self): 83 | return HunyuanDiTCLIPTextEncoderStateDictConverter() 84 | 85 | 86 | 87 | class HunyuanDiTT5TextEncoder(T5EncoderModel): 88 | def __init__(self): 89 | config = T5Config( 90 | _name_or_path = "../HunyuanDiT/t2i/mt5", 91 | architectures = ["MT5ForConditionalGeneration"], 92 | classifier_dropout = 0.0, 93 | d_ff = 5120, 94 | d_kv = 64, 95 | d_model = 2048, 96 | decoder_start_token_id = 0, 97 | dense_act_fn = "gelu_new", 98 | dropout_rate = 0.1, 99 | eos_token_id = 1, 100 | feed_forward_proj = "gated-gelu", 101 | initializer_factor = 1.0, 102 | is_encoder_decoder = True, 103 | is_gated_act = True, 104 | layer_norm_epsilon = 1e-06, 105 | model_type = "t5", 106 | num_decoder_layers = 24, 107 | num_heads = 32, 108 | num_layers = 24, 109 | output_past = True, 110 | pad_token_id = 0, 111 | relative_attention_max_distance = 128, 112 | relative_attention_num_buckets = 32, 113 | tie_word_embeddings = False, 114 | tokenizer_class = "T5Tokenizer", 115 | transformers_version = "4.37.2", 116 | use_cache = True, 117 | vocab_size = 250112 118 | ) 119 | super().__init__(config) 120 | self.eval() 121 | 122 | def forward(self, input_ids, attention_mask, clip_skip=1): 123 | outputs = super().forward( 124 | input_ids=input_ids, 125 | attention_mask=attention_mask, 126 | output_hidden_states=True, 127 | ) 128 | prompt_emb = outputs.hidden_states[-clip_skip] 129 | if clip_skip > 1: 130 | mean, std = outputs.hidden_states[-1].mean(), outputs.hidden_states[-1].std() 131 | prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean 132 | return prompt_emb 133 | 134 | def state_dict_converter(self): 135 | return HunyuanDiTT5TextEncoderStateDictConverter() 136 | 137 | 138 | 139 | class HunyuanDiTCLIPTextEncoderStateDictConverter(): 140 | def __init__(self): 141 | pass 142 | 143 | def from_diffusers(self, state_dict): 144 | state_dict_ = {name[5:]: param for name, param in state_dict.items() if name.startswith("bert.")} 145 | return state_dict_ 146 | 147 | def from_civitai(self, state_dict): 148 | return self.from_diffusers(state_dict) 149 | 150 | 151 | class HunyuanDiTT5TextEncoderStateDictConverter(): 152 | def __init__(self): 153 | pass 154 | 155 | def from_diffusers(self, state_dict): 156 | state_dict_ = {name: param for name, param in state_dict.items() if name.startswith("encoder.")} 157 | state_dict_["shared.weight"] = state_dict["shared.weight"] 158 | return state_dict_ 159 | 160 | def from_civitai(self, state_dict): 161 | return self.from_diffusers(state_dict) 162 | -------------------------------------------------------------------------------- /diffsynth/models/sd_ipadapter.py: -------------------------------------------------------------------------------- 1 | from .svd_image_encoder import SVDImageEncoder 2 | from .sdxl_ipadapter import IpAdapterImageProjModel, IpAdapterModule, SDXLIpAdapterStateDictConverter 3 | from transformers import CLIPImageProcessor 4 | import torch 5 | 6 | 7 | class IpAdapterCLIPImageEmbedder(SVDImageEncoder): 8 | def 
__init__(self): 9 | super().__init__() 10 | self.image_processor = CLIPImageProcessor() 11 | 12 | def forward(self, image): 13 | pixel_values = self.image_processor(images=image, return_tensors="pt").pixel_values 14 | pixel_values = pixel_values.to(device=self.embeddings.class_embedding.device, dtype=self.embeddings.class_embedding.dtype) 15 | return super().forward(pixel_values) 16 | 17 | 18 | class SDIpAdapter(torch.nn.Module): 19 | def __init__(self): 20 | super().__init__() 21 | shape_list = [(768, 320)] * 2 + [(768, 640)] * 2 + [(768, 1280)] * 5 + [(768, 640)] * 3 + [(768, 320)] * 3 + [(768, 1280)] * 1 22 | self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(*shape) for shape in shape_list]) 23 | self.image_proj = IpAdapterImageProjModel(cross_attention_dim=768, clip_embeddings_dim=1024, clip_extra_context_tokens=4) 24 | self.set_full_adapter() 25 | 26 | def set_full_adapter(self): 27 | block_ids = [1, 4, 9, 12, 17, 20, 40, 43, 46, 50, 53, 56, 60, 63, 66, 29] 28 | self.call_block_id = {(i, 0): j for j, i in enumerate(block_ids)} 29 | 30 | def set_less_adapter(self): 31 | # IP-Adapter for SD v1.5 doesn't support this feature. 32 | self.set_full_adapter(self) 33 | 34 | def forward(self, hidden_states, scale=1.0): 35 | hidden_states = self.image_proj(hidden_states) 36 | hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1]) 37 | ip_kv_dict = {} 38 | for (block_id, transformer_id) in self.call_block_id: 39 | ipadapter_id = self.call_block_id[(block_id, transformer_id)] 40 | ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states) 41 | if block_id not in ip_kv_dict: 42 | ip_kv_dict[block_id] = {} 43 | ip_kv_dict[block_id][transformer_id] = { 44 | "ip_k": ip_k, 45 | "ip_v": ip_v, 46 | "scale": scale 47 | } 48 | return ip_kv_dict 49 | 50 | def state_dict_converter(self): 51 | return SDIpAdapterStateDictConverter() 52 | 53 | 54 | class SDIpAdapterStateDictConverter(SDXLIpAdapterStateDictConverter): 55 | def __init__(self): 56 | pass 57 | -------------------------------------------------------------------------------- /diffsynth/models/sd_lora.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .sd_unet import SDUNetStateDictConverter, SDUNet 3 | from .sd_text_encoder import SDTextEncoderStateDictConverter, SDTextEncoder 4 | 5 | 6 | class SDLoRA: 7 | def __init__(self): 8 | pass 9 | 10 | def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0, device="cuda"): 11 | special_keys = { 12 | "down.blocks": "down_blocks", 13 | "up.blocks": "up_blocks", 14 | "mid.block": "mid_block", 15 | "proj.in": "proj_in", 16 | "proj.out": "proj_out", 17 | "transformer.blocks": "transformer_blocks", 18 | "to.q": "to_q", 19 | "to.k": "to_k", 20 | "to.v": "to_v", 21 | "to.out": "to_out", 22 | } 23 | state_dict_ = {} 24 | for key in state_dict: 25 | if ".lora_up" not in key: 26 | continue 27 | if not key.startswith(lora_prefix): 28 | continue 29 | weight_up = state_dict[key].to(device="cuda", dtype=torch.float16) 30 | weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16) 31 | if len(weight_up.shape) == 4: 32 | weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32) 33 | weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32) 34 | lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) 35 | else: 36 | lora_weight = alpha * torch.mm(weight_up, weight_down) 37 | target_name = key.split(".")[0].replace("_", 
".")[len(lora_prefix):] + ".weight" 38 | for special_key in special_keys: 39 | target_name = target_name.replace(special_key, special_keys[special_key]) 40 | state_dict_[target_name] = lora_weight.cpu() 41 | return state_dict_ 42 | 43 | def add_lora_to_unet(self, unet: SDUNet, state_dict_lora, alpha=1.0, device="cuda"): 44 | state_dict_unet = unet.state_dict() 45 | state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_unet_", alpha=alpha, device=device) 46 | state_dict_lora = SDUNetStateDictConverter().from_diffusers(state_dict_lora) 47 | if len(state_dict_lora) > 0: 48 | for name in state_dict_lora: 49 | state_dict_unet[name] += state_dict_lora[name].to(device=device) 50 | unet.load_state_dict(state_dict_unet) 51 | 52 | def add_lora_to_text_encoder(self, text_encoder: SDTextEncoder, state_dict_lora, alpha=1.0, device="cuda"): 53 | state_dict_text_encoder = text_encoder.state_dict() 54 | state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_te_", alpha=alpha, device=device) 55 | state_dict_lora = SDTextEncoderStateDictConverter().from_diffusers(state_dict_lora) 56 | if len(state_dict_lora) > 0: 57 | for name in state_dict_lora: 58 | state_dict_text_encoder[name] += state_dict_lora[name].to(device=device) 59 | text_encoder.load_state_dict(state_dict_text_encoder) 60 | 61 | -------------------------------------------------------------------------------- /diffsynth/models/sd_motion.py: -------------------------------------------------------------------------------- 1 | from .sd_unet import SDUNet, Attention, GEGLU 2 | import torch 3 | from einops import rearrange, repeat 4 | 5 | 6 | class TemporalTransformerBlock(torch.nn.Module): 7 | 8 | def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32): 9 | super().__init__() 10 | 11 | # 1. Self-Attn 12 | self.pe1 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim)) 13 | self.norm1 = torch.nn.LayerNorm(dim, elementwise_affine=True) 14 | self.attn1 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True) 15 | 16 | # 2. Cross-Attn 17 | self.pe2 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim)) 18 | self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=True) 19 | self.attn2 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True) 20 | 21 | # 3. Feed-forward 22 | self.norm3 = torch.nn.LayerNorm(dim, elementwise_affine=True) 23 | self.act_fn = GEGLU(dim, dim * 4) 24 | self.ff = torch.nn.Linear(dim * 4, dim) 25 | 26 | 27 | def forward(self, hidden_states, batch_size=1): 28 | 29 | # 1. Self-Attention 30 | norm_hidden_states = self.norm1(hidden_states) 31 | norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size) 32 | attn_output = self.attn1(norm_hidden_states + self.pe1[:, :norm_hidden_states.shape[1]]) 33 | attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size) 34 | hidden_states = attn_output + hidden_states 35 | 36 | # 2. Cross-Attention 37 | norm_hidden_states = self.norm2(hidden_states) 38 | norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size) 39 | attn_output = self.attn2(norm_hidden_states + self.pe2[:, :norm_hidden_states.shape[1]]) 40 | attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size) 41 | hidden_states = attn_output + hidden_states 42 | 43 | # 3. 
Feed-forward 44 | norm_hidden_states = self.norm3(hidden_states) 45 | ff_output = self.act_fn(norm_hidden_states) 46 | ff_output = self.ff(ff_output) 47 | hidden_states = ff_output + hidden_states 48 | 49 | return hidden_states 50 | 51 | 52 | class TemporalBlock(torch.nn.Module): 53 | 54 | def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5): 55 | super().__init__() 56 | inner_dim = num_attention_heads * attention_head_dim 57 | 58 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True) 59 | self.proj_in = torch.nn.Linear(in_channels, inner_dim) 60 | 61 | self.transformer_blocks = torch.nn.ModuleList([ 62 | TemporalTransformerBlock( 63 | inner_dim, 64 | num_attention_heads, 65 | attention_head_dim 66 | ) 67 | for d in range(num_layers) 68 | ]) 69 | 70 | self.proj_out = torch.nn.Linear(inner_dim, in_channels) 71 | 72 | def forward(self, hidden_states, time_emb, text_emb, res_stack, batch_size=1): 73 | batch, _, height, width = hidden_states.shape 74 | residual = hidden_states 75 | 76 | hidden_states = self.norm(hidden_states) 77 | inner_dim = hidden_states.shape[1] 78 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) 79 | hidden_states = self.proj_in(hidden_states) 80 | 81 | for block in self.transformer_blocks: 82 | hidden_states = block( 83 | hidden_states, 84 | batch_size=batch_size 85 | ) 86 | 87 | hidden_states = self.proj_out(hidden_states) 88 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() 89 | hidden_states = hidden_states + residual 90 | 91 | return hidden_states, time_emb, text_emb, res_stack 92 | 93 | 94 | class SDMotionModel(torch.nn.Module): 95 | def __init__(self): 96 | super().__init__() 97 | self.motion_modules = torch.nn.ModuleList([ 98 | TemporalBlock(8, 40, 320, eps=1e-6), 99 | TemporalBlock(8, 40, 320, eps=1e-6), 100 | TemporalBlock(8, 80, 640, eps=1e-6), 101 | TemporalBlock(8, 80, 640, eps=1e-6), 102 | TemporalBlock(8, 160, 1280, eps=1e-6), 103 | TemporalBlock(8, 160, 1280, eps=1e-6), 104 | TemporalBlock(8, 160, 1280, eps=1e-6), 105 | TemporalBlock(8, 160, 1280, eps=1e-6), 106 | TemporalBlock(8, 160, 1280, eps=1e-6), 107 | TemporalBlock(8, 160, 1280, eps=1e-6), 108 | TemporalBlock(8, 160, 1280, eps=1e-6), 109 | TemporalBlock(8, 160, 1280, eps=1e-6), 110 | TemporalBlock(8, 160, 1280, eps=1e-6), 111 | TemporalBlock(8, 160, 1280, eps=1e-6), 112 | TemporalBlock(8, 160, 1280, eps=1e-6), 113 | TemporalBlock(8, 80, 640, eps=1e-6), 114 | TemporalBlock(8, 80, 640, eps=1e-6), 115 | TemporalBlock(8, 80, 640, eps=1e-6), 116 | TemporalBlock(8, 40, 320, eps=1e-6), 117 | TemporalBlock(8, 40, 320, eps=1e-6), 118 | TemporalBlock(8, 40, 320, eps=1e-6), 119 | ]) 120 | self.call_block_id = { 121 | 1: 0, 122 | 4: 1, 123 | 9: 2, 124 | 12: 3, 125 | 17: 4, 126 | 20: 5, 127 | 24: 6, 128 | 26: 7, 129 | 29: 8, 130 | 32: 9, 131 | 34: 10, 132 | 36: 11, 133 | 40: 12, 134 | 43: 13, 135 | 46: 14, 136 | 50: 15, 137 | 53: 16, 138 | 56: 17, 139 | 60: 18, 140 | 63: 19, 141 | 66: 20 142 | } 143 | 144 | def forward(self): 145 | pass 146 | 147 | def state_dict_converter(self): 148 | return SDMotionModelStateDictConverter() 149 | 150 | 151 | class SDMotionModelStateDictConverter: 152 | def __init__(self): 153 | pass 154 | 155 | def from_diffusers(self, state_dict): 156 | rename_dict = { 157 | "norm": "norm", 158 | "proj_in": "proj_in", 159 | "transformer_blocks.0.attention_blocks.0.to_q": 
"transformer_blocks.0.attn1.to_q", 160 | "transformer_blocks.0.attention_blocks.0.to_k": "transformer_blocks.0.attn1.to_k", 161 | "transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v", 162 | "transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out", 163 | "transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1", 164 | "transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q", 165 | "transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k", 166 | "transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v", 167 | "transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out", 168 | "transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2", 169 | "transformer_blocks.0.norms.0": "transformer_blocks.0.norm1", 170 | "transformer_blocks.0.norms.1": "transformer_blocks.0.norm2", 171 | "transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj", 172 | "transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff", 173 | "transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3", 174 | "proj_out": "proj_out", 175 | } 176 | name_list = sorted([i for i in state_dict if i.startswith("down_blocks.")]) 177 | name_list += sorted([i for i in state_dict if i.startswith("mid_block.")]) 178 | name_list += sorted([i for i in state_dict if i.startswith("up_blocks.")]) 179 | state_dict_ = {} 180 | last_prefix, module_id = "", -1 181 | for name in name_list: 182 | names = name.split(".") 183 | prefix_index = names.index("temporal_transformer") + 1 184 | prefix = ".".join(names[:prefix_index]) 185 | if prefix != last_prefix: 186 | last_prefix = prefix 187 | module_id += 1 188 | middle_name = ".".join(names[prefix_index:-1]) 189 | suffix = names[-1] 190 | if "pos_encoder" in names: 191 | rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name]]) 192 | else: 193 | rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix]) 194 | state_dict_[rename] = state_dict[name] 195 | return state_dict_ 196 | 197 | def from_civitai(self, state_dict): 198 | return self.from_diffusers(state_dict) 199 | -------------------------------------------------------------------------------- /diffsynth/models/sdxl_ipadapter.py: -------------------------------------------------------------------------------- 1 | from .svd_image_encoder import SVDImageEncoder 2 | from transformers import CLIPImageProcessor 3 | import torch 4 | 5 | 6 | class IpAdapterXLCLIPImageEmbedder(SVDImageEncoder): 7 | def __init__(self): 8 | super().__init__(embed_dim=1664, encoder_intermediate_size=8192, projection_dim=1280, num_encoder_layers=48, num_heads=16, head_dim=104) 9 | self.image_processor = CLIPImageProcessor() 10 | 11 | def forward(self, image): 12 | pixel_values = self.image_processor(images=image, return_tensors="pt").pixel_values 13 | pixel_values = pixel_values.to(device=self.embeddings.class_embedding.device, dtype=self.embeddings.class_embedding.dtype) 14 | return super().forward(pixel_values) 15 | 16 | 17 | class IpAdapterImageProjModel(torch.nn.Module): 18 | def __init__(self, cross_attention_dim=2048, clip_embeddings_dim=1280, clip_extra_context_tokens=4): 19 | super().__init__() 20 | self.cross_attention_dim = cross_attention_dim 21 | self.clip_extra_context_tokens = clip_extra_context_tokens 22 | self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * 
cross_attention_dim) 23 | self.norm = torch.nn.LayerNorm(cross_attention_dim) 24 | 25 | def forward(self, image_embeds): 26 | clip_extra_context_tokens = self.proj(image_embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim) 27 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens) 28 | return clip_extra_context_tokens 29 | 30 | 31 | class IpAdapterModule(torch.nn.Module): 32 | def __init__(self, input_dim, output_dim): 33 | super().__init__() 34 | self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False) 35 | self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False) 36 | 37 | def forward(self, hidden_states): 38 | ip_k = self.to_k_ip(hidden_states) 39 | ip_v = self.to_v_ip(hidden_states) 40 | return ip_k, ip_v 41 | 42 | 43 | class SDXLIpAdapter(torch.nn.Module): 44 | def __init__(self): 45 | super().__init__() 46 | shape_list = [(2048, 640)] * 4 + [(2048, 1280)] * 50 + [(2048, 640)] * 6 + [(2048, 1280)] * 10 47 | self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(*shape) for shape in shape_list]) 48 | self.image_proj = IpAdapterImageProjModel() 49 | self.set_full_adapter() 50 | 51 | def set_full_adapter(self): 52 | map_list = sum([ 53 | [(7, i) for i in range(2)], 54 | [(10, i) for i in range(2)], 55 | [(15, i) for i in range(10)], 56 | [(18, i) for i in range(10)], 57 | [(25, i) for i in range(10)], 58 | [(28, i) for i in range(10)], 59 | [(31, i) for i in range(10)], 60 | [(35, i) for i in range(2)], 61 | [(38, i) for i in range(2)], 62 | [(41, i) for i in range(2)], 63 | [(21, i) for i in range(10)], 64 | ], []) 65 | self.call_block_id = {i: j for j, i in enumerate(map_list)} 66 | 67 | def set_less_adapter(self): 68 | map_list = sum([ 69 | [(7, i) for i in range(2)], 70 | [(10, i) for i in range(2)], 71 | [(15, i) for i in range(10)], 72 | [(18, i) for i in range(10)], 73 | [(25, i) for i in range(10)], 74 | [(28, i) for i in range(10)], 75 | [(31, i) for i in range(10)], 76 | [(35, i) for i in range(2)], 77 | [(38, i) for i in range(2)], 78 | [(41, i) for i in range(2)], 79 | [(21, i) for i in range(10)], 80 | ], []) 81 | self.call_block_id = {i: j for j, i in enumerate(map_list) if j>=34 and j<44} 82 | 83 | def forward(self, hidden_states, scale=1.0): 84 | hidden_states = self.image_proj(hidden_states) 85 | hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1]) 86 | ip_kv_dict = {} 87 | for (block_id, transformer_id) in self.call_block_id: 88 | ipadapter_id = self.call_block_id[(block_id, transformer_id)] 89 | ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states) 90 | if block_id not in ip_kv_dict: 91 | ip_kv_dict[block_id] = {} 92 | ip_kv_dict[block_id][transformer_id] = { 93 | "ip_k": ip_k, 94 | "ip_v": ip_v, 95 | "scale": scale 96 | } 97 | return ip_kv_dict 98 | 99 | def state_dict_converter(self): 100 | return SDXLIpAdapterStateDictConverter() 101 | 102 | 103 | class SDXLIpAdapterStateDictConverter: 104 | def __init__(self): 105 | pass 106 | 107 | def from_diffusers(self, state_dict): 108 | state_dict_ = {} 109 | for name in state_dict["ip_adapter"]: 110 | names = name.split(".") 111 | layer_id = str(int(names[0]) // 2) 112 | name_ = ".".join(["ipadapter_modules"] + [layer_id] + names[1:]) 113 | state_dict_[name_] = state_dict["ip_adapter"][name] 114 | for name in state_dict["image_proj"]: 115 | name_ = "image_proj." 
+ name 116 | state_dict_[name_] = state_dict["image_proj"][name] 117 | return state_dict_ 118 | 119 | def from_civitai(self, state_dict): 120 | return self.from_diffusers(state_dict) 121 | 122 | -------------------------------------------------------------------------------- /diffsynth/models/sdxl_motion.py: -------------------------------------------------------------------------------- 1 | from .sd_motion import TemporalBlock 2 | import torch 3 | 4 | 5 | 6 | class SDXLMotionModel(torch.nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | self.motion_modules = torch.nn.ModuleList([ 10 | TemporalBlock(8, 320//8, 320, eps=1e-6), 11 | TemporalBlock(8, 320//8, 320, eps=1e-6), 12 | 13 | TemporalBlock(8, 640//8, 640, eps=1e-6), 14 | TemporalBlock(8, 640//8, 640, eps=1e-6), 15 | 16 | TemporalBlock(8, 1280//8, 1280, eps=1e-6), 17 | TemporalBlock(8, 1280//8, 1280, eps=1e-6), 18 | 19 | TemporalBlock(8, 1280//8, 1280, eps=1e-6), 20 | TemporalBlock(8, 1280//8, 1280, eps=1e-6), 21 | TemporalBlock(8, 1280//8, 1280, eps=1e-6), 22 | 23 | TemporalBlock(8, 640//8, 640, eps=1e-6), 24 | TemporalBlock(8, 640//8, 640, eps=1e-6), 25 | TemporalBlock(8, 640//8, 640, eps=1e-6), 26 | 27 | TemporalBlock(8, 320//8, 320, eps=1e-6), 28 | TemporalBlock(8, 320//8, 320, eps=1e-6), 29 | TemporalBlock(8, 320//8, 320, eps=1e-6), 30 | ]) 31 | self.call_block_id = { 32 | 0: 0, 33 | 2: 1, 34 | 7: 2, 35 | 10: 3, 36 | 15: 4, 37 | 18: 5, 38 | 25: 6, 39 | 28: 7, 40 | 31: 8, 41 | 35: 9, 42 | 38: 10, 43 | 41: 11, 44 | 44: 12, 45 | 46: 13, 46 | 48: 14, 47 | } 48 | 49 | def forward(self): 50 | pass 51 | 52 | def state_dict_converter(self): 53 | return SDMotionModelStateDictConverter() 54 | 55 | 56 | class SDMotionModelStateDictConverter: 57 | def __init__(self): 58 | pass 59 | 60 | def from_diffusers(self, state_dict): 61 | rename_dict = { 62 | "norm": "norm", 63 | "proj_in": "proj_in", 64 | "transformer_blocks.0.attention_blocks.0.to_q": "transformer_blocks.0.attn1.to_q", 65 | "transformer_blocks.0.attention_blocks.0.to_k": "transformer_blocks.0.attn1.to_k", 66 | "transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v", 67 | "transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out", 68 | "transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1", 69 | "transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q", 70 | "transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k", 71 | "transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v", 72 | "transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out", 73 | "transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2", 74 | "transformer_blocks.0.norms.0": "transformer_blocks.0.norm1", 75 | "transformer_blocks.0.norms.1": "transformer_blocks.0.norm2", 76 | "transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj", 77 | "transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff", 78 | "transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3", 79 | "proj_out": "proj_out", 80 | } 81 | name_list = sorted([i for i in state_dict if i.startswith("down_blocks.")]) 82 | name_list += sorted([i for i in state_dict if i.startswith("mid_block.")]) 83 | name_list += sorted([i for i in state_dict if i.startswith("up_blocks.")]) 84 | state_dict_ = {} 85 | last_prefix, module_id = "", -1 86 | for name in name_list: 87 | names = name.split(".") 88 | prefix_index = 
names.index("temporal_transformer") + 1 89 | prefix = ".".join(names[:prefix_index]) 90 | if prefix != last_prefix: 91 | last_prefix = prefix 92 | module_id += 1 93 | middle_name = ".".join(names[prefix_index:-1]) 94 | suffix = names[-1] 95 | if "pos_encoder" in names: 96 | rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name]]) 97 | else: 98 | rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix]) 99 | state_dict_[rename] = state_dict[name] 100 | return state_dict_ 101 | 102 | def from_civitai(self, state_dict): 103 | return self.from_diffusers(state_dict) 104 | -------------------------------------------------------------------------------- /diffsynth/models/sdxl_vae_decoder.py: -------------------------------------------------------------------------------- 1 | from .sd_vae_decoder import SDVAEDecoder, SDVAEDecoderStateDictConverter 2 | 3 | 4 | class SDXLVAEDecoder(SDVAEDecoder): 5 | def __init__(self): 6 | super().__init__() 7 | self.scaling_factor = 0.13025 8 | 9 | def state_dict_converter(self): 10 | return SDXLVAEDecoderStateDictConverter() 11 | 12 | 13 | class SDXLVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter): 14 | def __init__(self): 15 | super().__init__() 16 | -------------------------------------------------------------------------------- /diffsynth/models/sdxl_vae_encoder.py: -------------------------------------------------------------------------------- 1 | from .sd_vae_encoder import SDVAEEncoderStateDictConverter, SDVAEEncoder 2 | 3 | 4 | class SDXLVAEEncoder(SDVAEEncoder): 5 | def __init__(self): 6 | super().__init__() 7 | self.scaling_factor = 0.13025 8 | 9 | def state_dict_converter(self): 10 | return SDXLVAEEncoderStateDictConverter() 11 | 12 | 13 | class SDXLVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter): 14 | def __init__(self): 15 | super().__init__() 16 | -------------------------------------------------------------------------------- /diffsynth/models/svd_vae_encoder.py: -------------------------------------------------------------------------------- 1 | from .sd_vae_encoder import SDVAEEncoderStateDictConverter, SDVAEEncoder 2 | 3 | 4 | class SVDVAEEncoder(SDVAEEncoder): 5 | def __init__(self): 6 | super().__init__() 7 | self.scaling_factor = 0.13025 8 | 9 | def state_dict_converter(self): 10 | return SVDVAEEncoderStateDictConverter() 11 | 12 | 13 | class SVDVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter): 14 | def __init__(self): 15 | super().__init__() 16 | 17 | def from_diffusers(self, state_dict): 18 | return super().from_diffusers(state_dict) 19 | 20 | def from_civitai(self, state_dict): 21 | rename_dict = { 22 | "conditioner.embedders.3.encoder.encoder.conv_in.bias": "conv_in.bias", 23 | "conditioner.embedders.3.encoder.encoder.conv_in.weight": "conv_in.weight", 24 | "conditioner.embedders.3.encoder.encoder.conv_out.bias": "conv_out.bias", 25 | "conditioner.embedders.3.encoder.encoder.conv_out.weight": "conv_out.weight", 26 | "conditioner.embedders.3.encoder.encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias", 27 | "conditioner.embedders.3.encoder.encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight", 28 | "conditioner.embedders.3.encoder.encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias", 29 | "conditioner.embedders.3.encoder.encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight", 30 | "conditioner.embedders.3.encoder.encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias", 31 | 
"conditioner.embedders.3.encoder.encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight", 32 | "conditioner.embedders.3.encoder.encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias", 33 | "conditioner.embedders.3.encoder.encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight", 34 | "conditioner.embedders.3.encoder.encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias", 35 | "conditioner.embedders.3.encoder.encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight", 36 | "conditioner.embedders.3.encoder.encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias", 37 | "conditioner.embedders.3.encoder.encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight", 38 | "conditioner.embedders.3.encoder.encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias", 39 | "conditioner.embedders.3.encoder.encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight", 40 | "conditioner.embedders.3.encoder.encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias", 41 | "conditioner.embedders.3.encoder.encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight", 42 | "conditioner.embedders.3.encoder.encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias", 43 | "conditioner.embedders.3.encoder.encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight", 44 | "conditioner.embedders.3.encoder.encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias", 45 | "conditioner.embedders.3.encoder.encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight", 46 | "conditioner.embedders.3.encoder.encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias", 47 | "conditioner.embedders.3.encoder.encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight", 48 | "conditioner.embedders.3.encoder.encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias", 49 | "conditioner.embedders.3.encoder.encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight", 50 | "conditioner.embedders.3.encoder.encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias", 51 | "conditioner.embedders.3.encoder.encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight", 52 | "conditioner.embedders.3.encoder.encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias", 53 | "conditioner.embedders.3.encoder.encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight", 54 | "conditioner.embedders.3.encoder.encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias", 55 | "conditioner.embedders.3.encoder.encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight", 56 | "conditioner.embedders.3.encoder.encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias", 57 | "conditioner.embedders.3.encoder.encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight", 58 | "conditioner.embedders.3.encoder.encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias", 59 | "conditioner.embedders.3.encoder.encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight", 60 | "conditioner.embedders.3.encoder.encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias", 61 | "conditioner.embedders.3.encoder.encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight", 62 | "conditioner.embedders.3.encoder.encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias", 63 | "conditioner.embedders.3.encoder.encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight", 64 | "conditioner.embedders.3.encoder.encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias", 65 | "conditioner.embedders.3.encoder.encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight", 66 | 
"conditioner.embedders.3.encoder.encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias", 67 | "conditioner.embedders.3.encoder.encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight", 68 | "conditioner.embedders.3.encoder.encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias", 69 | "conditioner.embedders.3.encoder.encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight", 70 | "conditioner.embedders.3.encoder.encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias", 71 | "conditioner.embedders.3.encoder.encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight", 72 | "conditioner.embedders.3.encoder.encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias", 73 | "conditioner.embedders.3.encoder.encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight", 74 | "conditioner.embedders.3.encoder.encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias", 75 | "conditioner.embedders.3.encoder.encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight", 76 | "conditioner.embedders.3.encoder.encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias", 77 | "conditioner.embedders.3.encoder.encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight", 78 | "conditioner.embedders.3.encoder.encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias", 79 | "conditioner.embedders.3.encoder.encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight", 80 | "conditioner.embedders.3.encoder.encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias", 81 | "conditioner.embedders.3.encoder.encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight", 82 | "conditioner.embedders.3.encoder.encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias", 83 | "conditioner.embedders.3.encoder.encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight", 84 | "conditioner.embedders.3.encoder.encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias", 85 | "conditioner.embedders.3.encoder.encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight", 86 | "conditioner.embedders.3.encoder.encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias", 87 | "conditioner.embedders.3.encoder.encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight", 88 | "conditioner.embedders.3.encoder.encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias", 89 | "conditioner.embedders.3.encoder.encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight", 90 | "conditioner.embedders.3.encoder.encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias", 91 | "conditioner.embedders.3.encoder.encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight", 92 | "conditioner.embedders.3.encoder.encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias", 93 | "conditioner.embedders.3.encoder.encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight", 94 | "conditioner.embedders.3.encoder.encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias", 95 | "conditioner.embedders.3.encoder.encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight", 96 | "conditioner.embedders.3.encoder.encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias", 97 | "conditioner.embedders.3.encoder.encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight", 98 | "conditioner.embedders.3.encoder.encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias", 99 | "conditioner.embedders.3.encoder.encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight", 100 | "conditioner.embedders.3.encoder.encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias", 101 | 
"conditioner.embedders.3.encoder.encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight", 102 | "conditioner.embedders.3.encoder.encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias", 103 | "conditioner.embedders.3.encoder.encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight", 104 | "conditioner.embedders.3.encoder.encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias", 105 | "conditioner.embedders.3.encoder.encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight", 106 | "conditioner.embedders.3.encoder.encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias", 107 | "conditioner.embedders.3.encoder.encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight", 108 | "conditioner.embedders.3.encoder.encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias", 109 | "conditioner.embedders.3.encoder.encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight", 110 | "conditioner.embedders.3.encoder.encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias", 111 | "conditioner.embedders.3.encoder.encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight", 112 | "conditioner.embedders.3.encoder.encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias", 113 | "conditioner.embedders.3.encoder.encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight", 114 | "conditioner.embedders.3.encoder.encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias", 115 | "conditioner.embedders.3.encoder.encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight", 116 | "conditioner.embedders.3.encoder.encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias", 117 | "conditioner.embedders.3.encoder.encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight", 118 | "conditioner.embedders.3.encoder.encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias", 119 | "conditioner.embedders.3.encoder.encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight", 120 | "conditioner.embedders.3.encoder.encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias", 121 | "conditioner.embedders.3.encoder.encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight", 122 | "conditioner.embedders.3.encoder.encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias", 123 | "conditioner.embedders.3.encoder.encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight", 124 | "conditioner.embedders.3.encoder.encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias", 125 | "conditioner.embedders.3.encoder.encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight", 126 | "conditioner.embedders.3.encoder.encoder.norm_out.bias": "conv_norm_out.bias", 127 | "conditioner.embedders.3.encoder.encoder.norm_out.weight": "conv_norm_out.weight", 128 | "conditioner.embedders.3.encoder.quant_conv.bias": "quant_conv.bias", 129 | "conditioner.embedders.3.encoder.quant_conv.weight": "quant_conv.weight", 130 | } 131 | state_dict_ = {} 132 | for name in state_dict: 133 | if name in rename_dict: 134 | param = state_dict[name] 135 | if "transformer_blocks" in rename_dict[name]: 136 | param = param.squeeze() 137 | state_dict_[rename_dict[name]] = param 138 | return state_dict_ 139 | -------------------------------------------------------------------------------- /diffsynth/models/tiler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange, repeat 3 | 4 | 5 | class TileWorker: 6 | def __init__(self): 7 | pass 8 | 9 | 10 | def mask(self, height, width, border_width): 11 | # 
Create a mask with shape (height, width). 12 | # The centre area is filled with 1, and the border line is filled with values in range (0, 1]. 13 | x = torch.arange(height).repeat(width, 1).T 14 | y = torch.arange(width).repeat(height, 1) 15 | mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values 16 | mask = (mask / border_width).clip(0, 1) 17 | return mask 18 | 19 | 20 | def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype): 21 | # Convert a tensor (b, c, h, w) to (b, c, tile_size, tile_size, tile_num) 22 | batch_size, channel, _, _ = model_input.shape 23 | model_input = model_input.to(device=tile_device, dtype=tile_dtype) 24 | unfold_operator = torch.nn.Unfold( 25 | kernel_size=(tile_size, tile_size), 26 | stride=(tile_stride, tile_stride) 27 | ) 28 | model_input = unfold_operator(model_input) 29 | model_input = model_input.view((batch_size, channel, tile_size, tile_size, -1)) 30 | 31 | return model_input 32 | 33 | 34 | def tiled_inference(self, forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype): 35 | # Call y=forward_fn(x) for each tile 36 | tile_num = model_input.shape[-1] 37 | model_output_stack = [] 38 | 39 | for tile_id in range(0, tile_num, tile_batch_size): 40 | 41 | # process input 42 | tile_id_ = min(tile_id + tile_batch_size, tile_num) 43 | x = model_input[:, :, :, :, tile_id: tile_id_] 44 | x = x.to(device=inference_device, dtype=inference_dtype) 45 | x = rearrange(x, "b c h w n -> (n b) c h w") 46 | 47 | # process output 48 | y = forward_fn(x) 49 | y = rearrange(y, "(n b) c h w -> b c h w n", n=tile_id_-tile_id) 50 | y = y.to(device=tile_device, dtype=tile_dtype) 51 | model_output_stack.append(y) 52 | 53 | model_output = torch.concat(model_output_stack, dim=-1) 54 | return model_output 55 | 56 | 57 | def io_scale(self, model_output, tile_size): 58 | # Determine the size modification that happened in forward_fn 59 | # We only consider the same scale on height and width. 
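# Illustrative example (hypothetical numbers, not part of the original code): if forward_fn is a
# 4x upscaler and tile_size is 64, each 64x64 tile comes back as 256x256, so
# io_scale = model_output.shape[2] / tile_size = 256 / 64 = 4.0. tiled_forward then multiplies
# height, width, tile_size, tile_stride and border_width by this factor before untiling.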
60 | io_scale = model_output.shape[2] / tile_size 61 | return io_scale 62 | 63 | 64 | def untile(self, model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype): 65 | # The reversed function of tile 66 | mask = self.mask(tile_size, tile_size, border_width) 67 | mask = mask.to(device=tile_device, dtype=tile_dtype) 68 | mask = rearrange(mask, "h w -> 1 1 h w 1") 69 | model_output = model_output * mask 70 | 71 | fold_operator = torch.nn.Fold( 72 | output_size=(height, width), 73 | kernel_size=(tile_size, tile_size), 74 | stride=(tile_stride, tile_stride) 75 | ) 76 | mask = repeat(mask[0, 0, :, :, 0], "h w -> 1 (h w) n", n=model_output.shape[-1]) 77 | model_output = rearrange(model_output, "b c h w n -> b (c h w) n") 78 | model_output = fold_operator(model_output) / fold_operator(mask) 79 | 80 | return model_output 81 | 82 | 83 | def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_batch_size=1, tile_device="cpu", tile_dtype=torch.float32, border_width=None): 84 | # Prepare 85 | inference_device, inference_dtype = model_input.device, model_input.dtype 86 | height, width = model_input.shape[2], model_input.shape[3] 87 | border_width = int(tile_stride*0.5) if border_width is None else border_width 88 | 89 | # tile 90 | model_input = self.tile(model_input, tile_size, tile_stride, tile_device, tile_dtype) 91 | 92 | # inference 93 | model_output = self.tiled_inference(forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype) 94 | 95 | # resize 96 | io_scale = self.io_scale(model_output, tile_size) 97 | height, width = int(height*io_scale), int(width*io_scale) 98 | tile_size, tile_stride = int(tile_size*io_scale), int(tile_stride*io_scale) 99 | border_width = int(border_width*io_scale) 100 | 101 | # untile 102 | model_output = self.untile(model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype) 103 | 104 | # Done! 105 | model_output = model_output.to(device=inference_device, dtype=inference_dtype) 106 | return model_output -------------------------------------------------------------------------------- /diffsynth/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from .stable_diffusion import SDImagePipeline 2 | from .stable_diffusion_xl import SDXLImagePipeline 3 | from .stable_diffusion_video import SDVideoPipeline, SDVideoPipelineRunner 4 | from .stable_diffusion_xl_video import SDXLVideoPipeline 5 | from .stable_video_diffusion import SVDVideoPipeline 6 | from .hunyuan_dit import HunyuanDiTImagePipeline 7 | -------------------------------------------------------------------------------- /diffsynth/pipelines/dancer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ..models import SDUNet, SDMotionModel, SDXLUNet, SDXLMotionModel 3 | from ..models.sd_unet import PushBlock, PopBlock 4 | from ..controlnets import MultiControlNetManager 5 | 6 | 7 | def lets_dance( 8 | unet: SDUNet, 9 | motion_modules: SDMotionModel = None, 10 | controlnet: MultiControlNetManager = None, 11 | sample = None, 12 | timestep = None, 13 | encoder_hidden_states = None, 14 | ipadapter_kwargs_list = {}, 15 | controlnet_frames = None, 16 | unet_batch_size = 1, 17 | controlnet_batch_size = 1, 18 | cross_frame_attention = False, 19 | tiled=False, 20 | tile_size=64, 21 | tile_stride=32, 22 | device = "cuda", 23 | vram_limit_level = 0, 24 | ): 25 | # 1. 
ControlNet 26 | # This part will be repeated on overlapping frames if animatediff_batch_size > animatediff_stride. 27 | # I leave it here because I intend to do something interesting on the ControlNets. 28 | controlnet_insert_block_id = 30 29 | if controlnet is not None and controlnet_frames is not None: 30 | res_stacks = [] 31 | # process controlnet frames with batch 32 | for batch_id in range(0, sample.shape[0], controlnet_batch_size): 33 | batch_id_ = min(batch_id + controlnet_batch_size, sample.shape[0]) 34 | res_stack = controlnet( 35 | sample[batch_id: batch_id_], 36 | timestep, 37 | encoder_hidden_states[batch_id: batch_id_], 38 | controlnet_frames[:, batch_id: batch_id_], 39 | tiled=tiled, tile_size=tile_size, tile_stride=tile_stride 40 | ) 41 | if vram_limit_level >= 1: 42 | res_stack = [res.cpu() for res in res_stack] 43 | res_stacks.append(res_stack) 44 | # concat the residual 45 | additional_res_stack = [] 46 | for i in range(len(res_stacks[0])): 47 | res = torch.concat([res_stack[i] for res_stack in res_stacks], dim=0) 48 | additional_res_stack.append(res) 49 | else: 50 | additional_res_stack = None 51 | 52 | # 2. time 53 | time_emb = unet.time_proj(timestep[None]).to(sample.dtype) 54 | time_emb = unet.time_embedding(time_emb) 55 | 56 | # 3. pre-process 57 | height, width = sample.shape[2], sample.shape[3] 58 | hidden_states = unet.conv_in(sample) 59 | text_emb = encoder_hidden_states 60 | res_stack = [hidden_states.cpu() if vram_limit_level>=1 else hidden_states] 61 | 62 | # 4. blocks 63 | for block_id, block in enumerate(unet.blocks): 64 | # 4.1 UNet 65 | if isinstance(block, PushBlock): 66 | hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) 67 | if vram_limit_level>=1: 68 | res_stack[-1] = res_stack[-1].cpu() 69 | elif isinstance(block, PopBlock): 70 | if vram_limit_level>=1: 71 | res_stack[-1] = res_stack[-1].to(device) 72 | hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) 73 | else: 74 | hidden_states_input = hidden_states 75 | hidden_states_output = [] 76 | for batch_id in range(0, sample.shape[0], unet_batch_size): 77 | batch_id_ = min(batch_id + unet_batch_size, sample.shape[0]) 78 | hidden_states, _, _, _ = block( 79 | hidden_states_input[batch_id: batch_id_], 80 | time_emb, 81 | text_emb[batch_id: batch_id_], 82 | res_stack, 83 | cross_frame_attention=cross_frame_attention, 84 | ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {}), 85 | tiled=tiled, tile_size=tile_size, tile_stride=tile_stride 86 | ) 87 | hidden_states_output.append(hidden_states) 88 | hidden_states = torch.concat(hidden_states_output, dim=0) 89 | # 4.2 AnimateDiff 90 | if motion_modules is not None: 91 | if block_id in motion_modules.call_block_id: 92 | motion_module_id = motion_modules.call_block_id[block_id] 93 | hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id]( 94 | hidden_states, time_emb, text_emb, res_stack, 95 | batch_size=1 96 | ) 97 | # 4.3 ControlNet 98 | if block_id == controlnet_insert_block_id and additional_res_stack is not None: 99 | hidden_states += additional_res_stack.pop().to(device) 100 | if vram_limit_level>=1: 101 | res_stack = [(res.to(device) + additional_res.to(device)).cpu() for res, additional_res in zip(res_stack, additional_res_stack)] 102 | else: 103 | res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)] 104 | 105 | # 5. 
output 106 | hidden_states = unet.conv_norm_out(hidden_states) 107 | hidden_states = unet.conv_act(hidden_states) 108 | hidden_states = unet.conv_out(hidden_states) 109 | 110 | return hidden_states 111 | 112 | 113 | 114 | 115 | def lets_dance_xl( 116 | unet: SDXLUNet, 117 | motion_modules: SDXLMotionModel = None, 118 | controlnet: MultiControlNetManager = None, 119 | sample = None, 120 | add_time_id = None, 121 | add_text_embeds = None, 122 | timestep = None, 123 | encoder_hidden_states = None, 124 | ipadapter_kwargs_list = {}, 125 | controlnet_frames = None, 126 | unet_batch_size = 1, 127 | controlnet_batch_size = 1, 128 | cross_frame_attention = False, 129 | tiled=False, 130 | tile_size=64, 131 | tile_stride=32, 132 | device = "cuda", 133 | vram_limit_level = 0, 134 | ): 135 | # 2. time 136 | t_emb = unet.time_proj(timestep[None]).to(sample.dtype) 137 | t_emb = unet.time_embedding(t_emb) 138 | 139 | time_embeds = unet.add_time_proj(add_time_id) 140 | time_embeds = time_embeds.reshape((add_text_embeds.shape[0], -1)) 141 | add_embeds = torch.concat([add_text_embeds, time_embeds], dim=-1) 142 | add_embeds = add_embeds.to(sample.dtype) 143 | add_embeds = unet.add_time_embedding(add_embeds) 144 | 145 | time_emb = t_emb + add_embeds 146 | 147 | # 3. pre-process 148 | height, width = sample.shape[2], sample.shape[3] 149 | hidden_states = unet.conv_in(sample) 150 | text_emb = encoder_hidden_states 151 | res_stack = [hidden_states] 152 | 153 | # 4. blocks 154 | for block_id, block in enumerate(unet.blocks): 155 | hidden_states, time_emb, text_emb, res_stack = block( 156 | hidden_states, time_emb, text_emb, res_stack, 157 | tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, 158 | ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {}) 159 | ) 160 | # 4.2 AnimateDiff 161 | if motion_modules is not None: 162 | if block_id in motion_modules.call_block_id: 163 | motion_module_id = motion_modules.call_block_id[block_id] 164 | hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id]( 165 | hidden_states, time_emb, text_emb, res_stack, 166 | batch_size=1 167 | ) 168 | 169 | # 5. 
output 170 | hidden_states = unet.conv_norm_out(hidden_states) 171 | hidden_states = unet.conv_act(hidden_states) 172 | hidden_states = unet.conv_out(hidden_states) 173 | 174 | return hidden_states -------------------------------------------------------------------------------- /diffsynth/pipelines/stable_diffusion.py: -------------------------------------------------------------------------------- 1 | from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDIpAdapter, IpAdapterCLIPImageEmbedder 2 | from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator 3 | from ..prompts import SDPrompter 4 | from ..schedulers import EnhancedDDIMScheduler 5 | from .dancer import lets_dance 6 | from typing import List 7 | import torch 8 | from tqdm import tqdm 9 | from PIL import Image 10 | import numpy as np 11 | 12 | 13 | class SDImagePipeline(torch.nn.Module): 14 | 15 | def __init__(self, device="cuda", torch_dtype=torch.float16): 16 | super().__init__() 17 | self.scheduler = EnhancedDDIMScheduler() 18 | self.prompter = SDPrompter() 19 | self.device = device 20 | self.torch_dtype = torch_dtype 21 | # models 22 | self.text_encoder: SDTextEncoder = None 23 | self.unet: SDUNet = None 24 | self.vae_decoder: SDVAEDecoder = None 25 | self.vae_encoder: SDVAEEncoder = None 26 | self.controlnet: MultiControlNetManager = None 27 | self.ipadapter_image_encoder: IpAdapterCLIPImageEmbedder = None 28 | self.ipadapter: SDIpAdapter = None 29 | 30 | 31 | def fetch_main_models(self, model_manager: ModelManager): 32 | self.text_encoder = model_manager.text_encoder 33 | self.unet = model_manager.unet 34 | self.vae_decoder = model_manager.vae_decoder 35 | self.vae_encoder = model_manager.vae_encoder 36 | 37 | 38 | def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]): 39 | controlnet_units = [] 40 | for config in controlnet_config_units: 41 | controlnet_unit = ControlNetUnit( 42 | Annotator(config.processor_id), 43 | model_manager.get_model_with_model_path(config.model_path), 44 | config.scale 45 | ) 46 | controlnet_units.append(controlnet_unit) 47 | self.controlnet = MultiControlNetManager(controlnet_units) 48 | 49 | 50 | def fetch_ipadapter(self, model_manager: ModelManager): 51 | if "ipadapter" in model_manager.model: 52 | self.ipadapter = model_manager.ipadapter 53 | if "ipadapter_image_encoder" in model_manager.model: 54 | self.ipadapter_image_encoder = model_manager.ipadapter_image_encoder 55 | 56 | 57 | def fetch_prompter(self, model_manager: ModelManager): 58 | self.prompter.load_from_model_manager(model_manager) 59 | 60 | 61 | @staticmethod 62 | def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]): 63 | pipe = SDImagePipeline( 64 | device=model_manager.device, 65 | torch_dtype=model_manager.torch_dtype, 66 | ) 67 | pipe.fetch_main_models(model_manager) 68 | pipe.fetch_prompter(model_manager) 69 | pipe.fetch_controlnet_models(model_manager, controlnet_config_units) 70 | pipe.fetch_ipadapter(model_manager) 71 | return pipe 72 | 73 | 74 | def preprocess_image(self, image): 75 | image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0) 76 | return image 77 | 78 | 79 | def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32): 80 | image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] 81 | image = 
image.cpu().permute(1, 2, 0).numpy() 82 | image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) 83 | return image 84 | 85 | 86 | @torch.no_grad() 87 | def __call__( 88 | self, 89 | prompt, 90 | negative_prompt="", 91 | cfg_scale=7.5, 92 | clip_skip=1, 93 | input_image=None, 94 | ipadapter_images=None, 95 | ipadapter_scale=1.0, 96 | controlnet_image=None, 97 | denoising_strength=1.0, 98 | height=512, 99 | width=512, 100 | num_inference_steps=20, 101 | tiled=False, 102 | tile_size=64, 103 | tile_stride=32, 104 | progress_bar_cmd=tqdm, 105 | progress_bar_st=None, 106 | ): 107 | # Prepare scheduler 108 | self.scheduler.set_timesteps(num_inference_steps, denoising_strength) 109 | 110 | # Prepare latent tensors 111 | if input_image is not None: 112 | image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype) 113 | latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) 114 | noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype) 115 | latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) 116 | else: 117 | latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype) 118 | 119 | # Encode prompts 120 | prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device, positive=True) 121 | prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device, positive=False) 122 | 123 | # IP-Adapter 124 | if ipadapter_images is not None: 125 | ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images) 126 | ipadapter_kwargs_list_posi = self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale) 127 | ipadapter_kwargs_list_nega = self.ipadapter(torch.zeros_like(ipadapter_image_encoding)) 128 | else: 129 | ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {}, {} 130 | 131 | # Prepare ControlNets 132 | if controlnet_image is not None: 133 | controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype) 134 | controlnet_image = controlnet_image.unsqueeze(1) 135 | 136 | # Denoise 137 | for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): 138 | timestep = torch.IntTensor((timestep,))[0].to(self.device) 139 | 140 | # Classifier-free guidance 141 | noise_pred_posi = lets_dance( 142 | self.unet, motion_modules=None, controlnet=self.controlnet, 143 | sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_image, 144 | tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, 145 | ipadapter_kwargs_list=ipadapter_kwargs_list_posi, 146 | device=self.device, vram_limit_level=0 147 | ) 148 | noise_pred_nega = lets_dance( 149 | self.unet, motion_modules=None, controlnet=self.controlnet, 150 | sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_image, 151 | tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, 152 | ipadapter_kwargs_list=ipadapter_kwargs_list_nega, 153 | device=self.device, vram_limit_level=0 154 | ) 155 | noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) 156 | 157 | # DDIM 158 | latents = self.scheduler.step(noise_pred, timestep, latents) 159 | 160 | # UI 161 | if progress_bar_st is not None: 162 | progress_bar_st.progress(progress_id / len(self.scheduler.timesteps)) 
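        # When the loop above finishes, `latents` holds the fully denoised sample in
        # latent space; decode_image converts it back to a PIL image below. Note that,
        # unlike the SDXL pipeline further down, both guidance branches are evaluated
        # here even when cfg_scale == 1.0.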
163 | 164 | # Decode image 165 | image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) 166 | 167 | return image 168 | -------------------------------------------------------------------------------- /diffsynth/pipelines/stable_diffusion_xl.py: -------------------------------------------------------------------------------- 1 | from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder 2 | # TODO: SDXL ControlNet 3 | from ..prompts import SDXLPrompter 4 | from ..schedulers import EnhancedDDIMScheduler 5 | from .dancer import lets_dance_xl 6 | import torch 7 | from tqdm import tqdm 8 | from PIL import Image 9 | import numpy as np 10 | 11 | 12 | class SDXLImagePipeline(torch.nn.Module): 13 | 14 | def __init__(self, device="cuda", torch_dtype=torch.float16): 15 | super().__init__() 16 | self.scheduler = EnhancedDDIMScheduler() 17 | self.prompter = SDXLPrompter() 18 | self.device = device 19 | self.torch_dtype = torch_dtype 20 | # models 21 | self.text_encoder: SDXLTextEncoder = None 22 | self.text_encoder_2: SDXLTextEncoder2 = None 23 | self.unet: SDXLUNet = None 24 | self.vae_decoder: SDXLVAEDecoder = None 25 | self.vae_encoder: SDXLVAEEncoder = None 26 | self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None 27 | self.ipadapter: SDXLIpAdapter = None 28 | # TODO: SDXL ControlNet 29 | 30 | def fetch_main_models(self, model_manager: ModelManager): 31 | self.text_encoder = model_manager.text_encoder 32 | self.text_encoder_2 = model_manager.text_encoder_2 33 | self.unet = model_manager.unet 34 | self.vae_decoder = model_manager.vae_decoder 35 | self.vae_encoder = model_manager.vae_encoder 36 | 37 | 38 | def fetch_controlnet_models(self, model_manager: ModelManager, **kwargs): 39 | # TODO: SDXL ControlNet 40 | pass 41 | 42 | 43 | def fetch_ipadapter(self, model_manager: ModelManager): 44 | if "ipadapter_xl" in model_manager.model: 45 | self.ipadapter = model_manager.ipadapter_xl 46 | if "ipadapter_xl_image_encoder" in model_manager.model: 47 | self.ipadapter_image_encoder = model_manager.ipadapter_xl_image_encoder 48 | 49 | 50 | def fetch_prompter(self, model_manager: ModelManager): 51 | self.prompter.load_from_model_manager(model_manager) 52 | 53 | 54 | @staticmethod 55 | def from_model_manager(model_manager: ModelManager, controlnet_config_units = [], **kwargs): 56 | pipe = SDXLImagePipeline( 57 | device=model_manager.device, 58 | torch_dtype=model_manager.torch_dtype, 59 | ) 60 | pipe.fetch_main_models(model_manager) 61 | pipe.fetch_prompter(model_manager) 62 | pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units) 63 | pipe.fetch_ipadapter(model_manager) 64 | return pipe 65 | 66 | 67 | def preprocess_image(self, image): 68 | image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0) 69 | return image 70 | 71 | 72 | def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32): 73 | image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] 74 | image = image.cpu().permute(1, 2, 0).numpy() 75 | image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) 76 | return image 77 | 78 | 79 | @torch.no_grad() 80 | def __call__( 81 | self, 82 | prompt, 83 | negative_prompt="", 84 | cfg_scale=7.5, 85 | clip_skip=1, 86 | clip_skip_2=2, 87 | input_image=None, 88 | ipadapter_images=None, 89 | 
ipadapter_scale=1.0, 90 | controlnet_image=None, 91 | denoising_strength=1.0, 92 | height=1024, 93 | width=1024, 94 | num_inference_steps=20, 95 | tiled=False, 96 | tile_size=64, 97 | tile_stride=32, 98 | progress_bar_cmd=tqdm, 99 | progress_bar_st=None, 100 | ): 101 | # Prepare scheduler 102 | self.scheduler.set_timesteps(num_inference_steps, denoising_strength) 103 | 104 | # Prepare latent tensors 105 | if input_image is not None: 106 | image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype) 107 | latents = self.vae_encoder(image.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype) 108 | noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype) 109 | latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) 110 | else: 111 | latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype) 112 | 113 | # Encode prompts 114 | add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt( 115 | self.text_encoder, 116 | self.text_encoder_2, 117 | prompt, 118 | clip_skip=clip_skip, clip_skip_2=clip_skip_2, 119 | device=self.device, 120 | positive=True, 121 | ) 122 | if cfg_scale != 1.0: 123 | add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt( 124 | self.text_encoder, 125 | self.text_encoder_2, 126 | negative_prompt, 127 | clip_skip=clip_skip, clip_skip_2=clip_skip_2, 128 | device=self.device, 129 | positive=False, 130 | ) 131 | 132 | # Prepare positional id 133 | add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device) 134 | 135 | # IP-Adapter 136 | if ipadapter_images is not None: 137 | ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images) 138 | ipadapter_kwargs_list_posi = self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale) 139 | ipadapter_kwargs_list_nega = self.ipadapter(torch.zeros_like(ipadapter_image_encoding)) 140 | else: 141 | ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {}, {} 142 | 143 | # Denoise 144 | for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): 145 | timestep = torch.IntTensor((timestep,))[0].to(self.device) 146 | 147 | # Classifier-free guidance 148 | noise_pred_posi = lets_dance_xl( 149 | self.unet, 150 | sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi, 151 | add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi, 152 | tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, 153 | ipadapter_kwargs_list=ipadapter_kwargs_list_posi, 154 | ) 155 | if cfg_scale != 1.0: 156 | noise_pred_nega = lets_dance_xl( 157 | self.unet, 158 | sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, 159 | add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega, 160 | tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, 161 | ipadapter_kwargs_list=ipadapter_kwargs_list_nega, 162 | ) 163 | noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) 164 | else: 165 | noise_pred = noise_pred_posi 166 | 167 | latents = self.scheduler.step(noise_pred, timestep, latents) 168 | 169 | if progress_bar_st is not None: 170 | progress_bar_st.progress(progress_id / len(self.scheduler.timesteps)) 171 | 172 | # Decode image 173 | image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) 174 | 175 | return image 176 | 
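Taken together, SDXLImagePipeline reduces to a few lines of user-facing code. The snippet below is a minimal, hypothetical sketch: it assumes a ModelManager that has already been populated with the SDXL text encoders, UNet and VAE (the loading helpers live in diffsynth/models, outside this section), and the constructor arguments shown for it are assumptions rather than a documented API.

import torch
from diffsynth.models import ModelManager
from diffsynth.pipelines import SDXLImagePipeline

# Assumption: the manager already holds text_encoder, text_encoder_2, unet,
# vae_encoder and vae_decoder; how they are loaded is not shown in this section.
model_manager = ModelManager(device="cuda", torch_dtype=torch.float16)  # hypothetical arguments

pipe = SDXLImagePipeline.from_model_manager(model_manager)
image = pipe(
    prompt="a castle on a hill at sunset, highly detailed",
    negative_prompt="blurry, low quality",
    cfg_scale=7.5,
    height=1024,
    width=1024,
    num_inference_steps=20,
)
image.save("castle.png")

The same call pattern applies to SDImagePipeline above, with 512x512 defaults and an always-evaluated negative branch.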
-------------------------------------------------------------------------------- /diffsynth/pipelines/stable_diffusion_xl_video.py: -------------------------------------------------------------------------------- 1 | from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLMotionModel 2 | from .dancer import lets_dance_xl 3 | # TODO: SDXL ControlNet 4 | from ..prompts import SDXLPrompter 5 | from ..schedulers import EnhancedDDIMScheduler 6 | import torch 7 | from tqdm import tqdm 8 | from PIL import Image 9 | import numpy as np 10 | 11 | 12 | class SDXLVideoPipeline(torch.nn.Module): 13 | 14 | def __init__(self, device="cuda", torch_dtype=torch.float16, use_animatediff=True): 15 | super().__init__() 16 | self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_animatediff else "scaled_linear") 17 | self.prompter = SDXLPrompter() 18 | self.device = device 19 | self.torch_dtype = torch_dtype 20 | # models 21 | self.text_encoder: SDXLTextEncoder = None 22 | self.text_encoder_2: SDXLTextEncoder2 = None 23 | self.unet: SDXLUNet = None 24 | self.vae_decoder: SDXLVAEDecoder = None 25 | self.vae_encoder: SDXLVAEEncoder = None 26 | # TODO: SDXL ControlNet 27 | self.motion_modules: SDXLMotionModel = None 28 | 29 | 30 | def fetch_main_models(self, model_manager: ModelManager): 31 | self.text_encoder = model_manager.text_encoder 32 | self.text_encoder_2 = model_manager.text_encoder_2 33 | self.unet = model_manager.unet 34 | self.vae_decoder = model_manager.vae_decoder 35 | self.vae_encoder = model_manager.vae_encoder 36 | 37 | 38 | def fetch_controlnet_models(self, model_manager: ModelManager, **kwargs): 39 | # TODO: SDXL ControlNet 40 | pass 41 | 42 | 43 | def fetch_motion_modules(self, model_manager: ModelManager): 44 | if "motion_modules_xl" in model_manager.model: 45 | self.motion_modules = model_manager.motion_modules_xl 46 | 47 | 48 | def fetch_prompter(self, model_manager: ModelManager): 49 | self.prompter.load_from_model_manager(model_manager) 50 | 51 | 52 | @staticmethod 53 | def from_model_manager(model_manager: ModelManager, controlnet_config_units = [], **kwargs): 54 | pipe = SDXLVideoPipeline( 55 | device=model_manager.device, 56 | torch_dtype=model_manager.torch_dtype, 57 | use_animatediff="motion_modules_xl" in model_manager.model 58 | ) 59 | pipe.fetch_main_models(model_manager) 60 | pipe.fetch_motion_modules(model_manager) 61 | pipe.fetch_prompter(model_manager) 62 | pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units) 63 | return pipe 64 | 65 | 66 | def preprocess_image(self, image): 67 | image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0) 68 | return image 69 | 70 | 71 | def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32): 72 | image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] 73 | image = image.cpu().permute(1, 2, 0).numpy() 74 | image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) 75 | return image 76 | 77 | 78 | def decode_images(self, latents, tiled=False, tile_size=64, tile_stride=32): 79 | images = [ 80 | self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) 81 | for frame_id in range(latents.shape[0]) 82 | ] 83 | return images 84 | 85 | 86 | def encode_images(self, processed_images, tiled=False, tile_size=64, tile_stride=32): 87 | latents = [] 88 | for image 
in processed_images: 89 | image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype) 90 | latent = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).cpu() 91 | latents.append(latent) 92 | latents = torch.concat(latents, dim=0) 93 | return latents 94 | 95 | 96 | @torch.no_grad() 97 | def __call__( 98 | self, 99 | prompt, 100 | negative_prompt="", 101 | cfg_scale=7.5, 102 | clip_skip=1, 103 | clip_skip_2=2, 104 | num_frames=None, 105 | input_frames=None, 106 | controlnet_frames=None, 107 | denoising_strength=1.0, 108 | height=512, 109 | width=512, 110 | num_inference_steps=20, 111 | animatediff_batch_size = 16, 112 | animatediff_stride = 8, 113 | unet_batch_size = 1, 114 | controlnet_batch_size = 1, 115 | cross_frame_attention = False, 116 | smoother=None, 117 | smoother_progress_ids=[], 118 | vram_limit_level=0, 119 | progress_bar_cmd=tqdm, 120 | progress_bar_st=None, 121 | ): 122 | # Prepare scheduler 123 | self.scheduler.set_timesteps(num_inference_steps, denoising_strength) 124 | 125 | # Prepare latent tensors 126 | if self.motion_modules is None: 127 | noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1) 128 | else: 129 | noise = torch.randn((num_frames, 4, height//8, width//8), device="cuda", dtype=self.torch_dtype) 130 | if input_frames is None or denoising_strength == 1.0: 131 | latents = noise 132 | else: 133 | latents = self.encode_images(input_frames) 134 | latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) 135 | 136 | # Encode prompts 137 | add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt( 138 | self.text_encoder, 139 | self.text_encoder_2, 140 | prompt, 141 | clip_skip=clip_skip, clip_skip_2=clip_skip_2, 142 | device=self.device, 143 | positive=True, 144 | ) 145 | if cfg_scale != 1.0: 146 | add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt( 147 | self.text_encoder, 148 | self.text_encoder_2, 149 | negative_prompt, 150 | clip_skip=clip_skip, clip_skip_2=clip_skip_2, 151 | device=self.device, 152 | positive=False, 153 | ) 154 | 155 | # Prepare positional id 156 | add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device) 157 | 158 | # Denoise 159 | for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): 160 | timestep = torch.IntTensor((timestep,))[0].to(self.device) 161 | 162 | # Classifier-free guidance 163 | noise_pred_posi = lets_dance_xl( 164 | self.unet, motion_modules=self.motion_modules, controlnet=None, 165 | sample=latents, add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi, 166 | timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_frames, 167 | cross_frame_attention=cross_frame_attention, 168 | device=self.device, vram_limit_level=vram_limit_level 169 | ) 170 | if cfg_scale != 1.0: 171 | noise_pred_nega = lets_dance_xl( 172 | self.unet, motion_modules=self.motion_modules, controlnet=None, 173 | sample=latents, add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega, 174 | timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames, 175 | cross_frame_attention=cross_frame_attention, 176 | device=self.device, vram_limit_level=vram_limit_level 177 | ) 178 | noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) 179 | else: 180 | noise_pred = noise_pred_posi 181 | 182 | latents = self.scheduler.step(noise_pred, timestep, 
latents) 183 | 184 | if progress_bar_st is not None: 185 | progress_bar_st.progress(progress_id / len(self.scheduler.timesteps)) 186 | 187 | # Decode image 188 | image = self.decode_images(latents.to(torch.float32)) 189 | 190 | return image 191 | -------------------------------------------------------------------------------- /diffsynth/pipelines/stable_video_diffusion.py: -------------------------------------------------------------------------------- 1 | from ..models import ModelManager, SVDImageEncoder, SVDUNet, SVDVAEEncoder, SVDVAEDecoder 2 | from ..schedulers import ContinuousODEScheduler 3 | import torch 4 | from tqdm import tqdm 5 | from PIL import Image 6 | import numpy as np 7 | from einops import rearrange, repeat 8 | from comfy.utils import ProgressBar 9 | 10 | 11 | 12 | class SVDVideoPipeline(torch.nn.Module): 13 | 14 | def __init__(self, device="cuda", torch_dtype=torch.float16): 15 | super().__init__() 16 | self.scheduler = ContinuousODEScheduler() 17 | self.device = device 18 | self.torch_dtype = torch_dtype 19 | # models 20 | self.image_encoder: SVDImageEncoder = None 21 | self.unet: SVDUNet = None 22 | self.vae_encoder: SVDVAEEncoder = None 23 | self.vae_decoder: SVDVAEDecoder = None 24 | 25 | 26 | def fetch_main_models(self, model_manager: ModelManager): 27 | self.image_encoder = model_manager.image_encoder 28 | self.unet = model_manager.unet 29 | self.vae_encoder = model_manager.vae_encoder 30 | self.vae_decoder = model_manager.vae_decoder 31 | 32 | 33 | @staticmethod 34 | def from_model_manager(model_manager: ModelManager, **kwargs): 35 | pipe = SVDVideoPipeline(device=model_manager.device, torch_dtype=model_manager.torch_dtype) 36 | pipe.fetch_main_models(model_manager) 37 | return pipe 38 | 39 | 40 | def preprocess_image(self, image): 41 | #image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0) 42 | return image 43 | 44 | 45 | def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32): 46 | image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] 47 | image = image.cpu().permute(1, 2, 0).numpy() 48 | image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) 49 | return image 50 | 51 | 52 | def encode_image_with_clip(self, image): 53 | image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype) 54 | image = SVDCLIPImageProcessor().resize_with_antialiasing(image, (224, 224)) 55 | image = (image + 1.0) / 2.0 56 | mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).reshape(1, 3, 1, 1).to(device=self.device, dtype=self.torch_dtype) 57 | std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).reshape(1, 3, 1, 1).to(device=self.device, dtype=self.torch_dtype) 58 | image = (image - mean) / std 59 | image_emb = self.image_encoder(image) 60 | return image_emb 61 | 62 | 63 | def encode_image_with_vae(self, image, noise_aug_strength): 64 | image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype) 65 | noise = torch.randn(image.shape, device="cpu", dtype=self.torch_dtype).to(self.device) 66 | image = image + noise_aug_strength * noise 67 | image_emb = self.vae_encoder(image) / self.vae_encoder.scaling_factor 68 | return image_emb 69 | 70 | 71 | def encode_video_with_vae(self, video): 72 | #video = torch.concat([self.preprocess_image(frame) for frame in video], dim=0) 73 | video = rearrange(video, "T C H W -> 1 C T H W") 74 | video = video.to(device=self.device, dtype=self.torch_dtype) 75 | 
latents = self.vae_encoder.encode_video(video) 76 | latents = rearrange(latents[0], "C T H W -> T C H W") 77 | return latents 78 | 79 | 80 | def tensor2video(self, frames): 81 | frames = rearrange(frames, "C T H W -> T H W C") 82 | frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8) 83 | frames = [Image.fromarray(frame) for frame in frames] 84 | return frames 85 | 86 | 87 | def calculate_noise_pred( 88 | self, 89 | latents, 90 | timestep, 91 | add_time_id, 92 | cfg_scales, 93 | image_emb_vae_posi, image_emb_clip_posi, 94 | image_emb_vae_nega, image_emb_clip_nega 95 | ): 96 | # Positive side 97 | noise_pred_posi = self.unet( 98 | torch.cat([latents, image_emb_vae_posi], dim=1), 99 | timestep, image_emb_clip_posi, add_time_id 100 | ) 101 | # Negative side 102 | noise_pred_nega = self.unet( 103 | torch.cat([latents, image_emb_vae_nega], dim=1), 104 | timestep, image_emb_clip_nega, add_time_id 105 | ) 106 | 107 | # Classifier-free guidance 108 | noise_pred = noise_pred_nega + cfg_scales * (noise_pred_posi - noise_pred_nega) 109 | 110 | return noise_pred 111 | 112 | 113 | def post_process_latents(self, latents, post_normalize=True, contrast_enhance_scale=1.0): 114 | if post_normalize: 115 | mean, std = latents.mean(), latents.std() 116 | latents = (latents - latents.mean(dim=[1, 2, 3], keepdim=True)) / latents.std(dim=[1, 2, 3], keepdim=True) * std + mean 117 | latents = latents * contrast_enhance_scale 118 | return latents 119 | 120 | 121 | @torch.no_grad() 122 | def __call__( 123 | self, 124 | input_image=None, 125 | input_video=None, 126 | mask_frames=[], 127 | mask_frame_ids=[], 128 | min_cfg_scale=1.0, 129 | max_cfg_scale=3.0, 130 | denoising_strength=1.0, 131 | num_frames=25, 132 | height=576, 133 | width=1024, 134 | fps=7, 135 | motion_bucket_id=127, 136 | noise_aug_strength=0.02, 137 | num_inference_steps=20, 138 | post_normalize=True, 139 | contrast_enhance_scale=1.2, 140 | progress_bar_cmd=tqdm, 141 | progress_bar_st=None, 142 | ): 143 | # Prepare scheduler 144 | self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength) 145 | 146 | # Prepare latent tensors 147 | noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).to(self.device) 148 | if denoising_strength == 1.0: 149 | latents = noise.clone() 150 | else: 151 | latents = self.encode_video_with_vae(input_video) 152 | latents = self.scheduler.add_noise(latents, noise, self.scheduler.timesteps[0]) 153 | 154 | # Prepare mask frames 155 | if len(mask_frames) > 0: 156 | mask_latents = self.encode_video_with_vae(mask_frames) 157 | 158 | # Encode image 159 | image_emb_clip_posi = self.encode_image_with_clip(input_image) 160 | image_emb_clip_nega = torch.zeros_like(image_emb_clip_posi) 161 | image_emb_vae_posi = repeat(self.encode_image_with_vae(input_image, noise_aug_strength), "B C H W -> (B T) C H W", T=num_frames) 162 | image_emb_vae_nega = torch.zeros_like(image_emb_vae_posi) 163 | 164 | # Prepare classifier-free guidance 165 | cfg_scales = torch.linspace(min_cfg_scale, max_cfg_scale, num_frames) 166 | cfg_scales = cfg_scales.reshape(num_frames, 1, 1, 1).to(device=self.device, dtype=self.torch_dtype) 167 | 168 | # Prepare positional id 169 | add_time_id = torch.tensor([[fps-1, motion_bucket_id, noise_aug_strength]], device=self.device) 170 | 171 | # Denoise 172 | comfy_pbar = ProgressBar(num_inference_steps) 173 | for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): 174 | 175 | # Mask frames 176 | for 
frame_id, mask_frame_id in enumerate(mask_frame_ids): 177 | latents[mask_frame_id] = self.scheduler.add_noise(mask_latents[frame_id], noise[mask_frame_id], timestep) 178 | 179 | # Fetch model output 180 | noise_pred = self.calculate_noise_pred( 181 | latents, timestep, add_time_id, cfg_scales, 182 | image_emb_vae_posi, image_emb_clip_posi, image_emb_vae_nega, image_emb_clip_nega 183 | ) 184 | 185 | # Forward Euler 186 | latents = self.scheduler.step(noise_pred, timestep, latents) 187 | 188 | # Update progress bar 189 | if progress_bar_st is not None: 190 | progress_bar_st.progress(progress_id / len(self.scheduler.timesteps)) 191 | comfy_pbar.update(1) 192 | 193 | # Decode image 194 | latents = self.post_process_latents(latents, post_normalize=post_normalize, contrast_enhance_scale=contrast_enhance_scale) 195 | video = self.vae_decoder.decode_video(latents, progress_bar=progress_bar_cmd) 196 | video = self.tensor2video(video) 197 | 198 | return video 199 | 200 | 201 | 202 | class SVDCLIPImageProcessor: 203 | def __init__(self): 204 | pass 205 | 206 | def resize_with_antialiasing(self, input, size, interpolation="bicubic", align_corners=True): 207 | h, w = input.shape[-2:] 208 | factors = (h / size[0], w / size[1]) 209 | 210 | # First, we have to determine sigma 211 | # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171 212 | sigmas = ( 213 | max((factors[0] - 1.0) / 2.0, 0.001), 214 | max((factors[1] - 1.0) / 2.0, 0.001), 215 | ) 216 | 217 | # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma 218 | # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206 219 | # But they do it in the 2 passes, which gives better results. 
Let's try 2 sigmas for now 220 | ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3)) 221 | 222 | # Make sure it is odd 223 | if (ks[0] % 2) == 0: 224 | ks = ks[0] + 1, ks[1] 225 | 226 | if (ks[1] % 2) == 0: 227 | ks = ks[0], ks[1] + 1 228 | 229 | input = self._gaussian_blur2d(input, ks, sigmas) 230 | 231 | output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners) 232 | return output 233 | 234 | 235 | def _compute_padding(self, kernel_size): 236 | """Compute padding tuple.""" 237 | # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom) 238 | # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad 239 | if len(kernel_size) < 2: 240 | raise AssertionError(kernel_size) 241 | computed = [k - 1 for k in kernel_size] 242 | 243 | # for even kernels we need to do asymmetric padding :( 244 | out_padding = 2 * len(kernel_size) * [0] 245 | 246 | for i in range(len(kernel_size)): 247 | computed_tmp = computed[-(i + 1)] 248 | 249 | pad_front = computed_tmp // 2 250 | pad_rear = computed_tmp - pad_front 251 | 252 | out_padding[2 * i + 0] = pad_front 253 | out_padding[2 * i + 1] = pad_rear 254 | 255 | return out_padding 256 | 257 | 258 | def _filter2d(self, input, kernel): 259 | # prepare kernel 260 | b, c, h, w = input.shape 261 | tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype) 262 | 263 | tmp_kernel = tmp_kernel.expand(-1, c, -1, -1) 264 | 265 | height, width = tmp_kernel.shape[-2:] 266 | 267 | padding_shape: list[int] = self._compute_padding([height, width]) 268 | input = torch.nn.functional.pad(input, padding_shape, mode="reflect") 269 | 270 | # kernel and input tensor reshape to align element-wise or batch-wise params 271 | tmp_kernel = tmp_kernel.reshape(-1, 1, height, width) 272 | input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1)) 273 | 274 | # convolve the tensor with the kernel. 
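        # The reshapes above make this a depthwise (grouped) convolution: with
        # groups equal to the number of kernel copies, every channel is filtered
        # by its own copy of the Gaussian kernel, and the reflect padding applied
        # earlier keeps the spatial size unchanged.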
275 | output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1) 276 | 277 | out = output.view(b, c, h, w) 278 | return out 279 | 280 | 281 | def _gaussian(self, window_size: int, sigma): 282 | if isinstance(sigma, float): 283 | sigma = torch.tensor([[sigma]]) 284 | 285 | batch_size = sigma.shape[0] 286 | 287 | x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1) 288 | 289 | if window_size % 2 == 0: 290 | x = x + 0.5 291 | 292 | gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0))) 293 | 294 | return gauss / gauss.sum(-1, keepdim=True) 295 | 296 | 297 | def _gaussian_blur2d(self, input, kernel_size, sigma): 298 | if isinstance(sigma, tuple): 299 | sigma = torch.tensor([sigma], dtype=input.dtype) 300 | else: 301 | sigma = sigma.to(dtype=input.dtype) 302 | 303 | ky, kx = int(kernel_size[0]), int(kernel_size[1]) 304 | bs = sigma.shape[0] 305 | kernel_x = self._gaussian(kx, sigma[:, 1].view(bs, 1)) 306 | kernel_y = self._gaussian(ky, sigma[:, 0].view(bs, 1)) 307 | out_x = self._filter2d(input, kernel_x[..., None, :]) 308 | out = self._filter2d(out_x, kernel_y[..., None]) 309 | 310 | return out 311 | -------------------------------------------------------------------------------- /diffsynth/processors/FastBlend.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import cupy as cp 3 | import numpy as np 4 | from tqdm import tqdm 5 | from ..extensions.FastBlend.patch_match import PyramidPatchMatcher 6 | from ..extensions.FastBlend.runners.fast import TableManager 7 | from .base import VideoProcessor 8 | 9 | 10 | class FastBlendSmoother(VideoProcessor): 11 | def __init__( 12 | self, 13 | inference_mode="fast", batch_size=8, window_size=60, 14 | minimum_patch_size=5, threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0, initialize="identity", tracking_window_size=0 15 | ): 16 | self.inference_mode = inference_mode 17 | self.batch_size = batch_size 18 | self.window_size = window_size 19 | self.ebsynth_config = { 20 | "minimum_patch_size": minimum_patch_size, 21 | "threads_per_block": threads_per_block, 22 | "num_iter": num_iter, 23 | "gpu_id": gpu_id, 24 | "guide_weight": guide_weight, 25 | "initialize": initialize, 26 | "tracking_window_size": tracking_window_size 27 | } 28 | 29 | @staticmethod 30 | def from_model_manager(model_manager, **kwargs): 31 | # TODO: fetch GPU ID from model_manager 32 | return FastBlendSmoother(**kwargs) 33 | 34 | def inference_fast(self, frames_guide, frames_style): 35 | table_manager = TableManager() 36 | patch_match_engine = PyramidPatchMatcher( 37 | image_height=frames_style[0].shape[0], 38 | image_width=frames_style[0].shape[1], 39 | channel=3, 40 | **self.ebsynth_config 41 | ) 42 | # left part 43 | table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, self.batch_size, desc="Fast Mode Step 1/4") 44 | table_l = table_manager.remapping_table_to_blending_table(table_l) 45 | table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, self.window_size, self.batch_size, desc="Fast Mode Step 2/4") 46 | # right part 47 | table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, self.batch_size, desc="Fast Mode Step 3/4") 48 | table_r = table_manager.remapping_table_to_blending_table(table_r) 49 | table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, 
self.window_size, self.batch_size, desc="Fast Mode Step 4/4")[::-1] 50 | # merge 51 | frames = [] 52 | for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r): 53 | weight_m = -1 54 | weight = weight_l + weight_m + weight_r 55 | frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight) 56 | frames.append(frame) 57 | frames = [frame.clip(0, 255).astype("uint8") for frame in frames] 58 | frames = [Image.fromarray(frame) for frame in frames] 59 | return frames 60 | 61 | def inference_balanced(self, frames_guide, frames_style): 62 | patch_match_engine = PyramidPatchMatcher( 63 | image_height=frames_style[0].shape[0], 64 | image_width=frames_style[0].shape[1], 65 | channel=3, 66 | **self.ebsynth_config 67 | ) 68 | output_frames = [] 69 | # tasks 70 | n = len(frames_style) 71 | tasks = [] 72 | for target in range(n): 73 | for source in range(target - self.window_size, target + self.window_size + 1): 74 | if source >= 0 and source < n and source != target: 75 | tasks.append((source, target)) 76 | # run 77 | frames = [(None, 1) for i in range(n)] 78 | for batch_id in tqdm(range(0, len(tasks), self.batch_size), desc="Balanced Mode"): 79 | tasks_batch = tasks[batch_id: min(batch_id+self.batch_size, len(tasks))] 80 | source_guide = np.stack([frames_guide[source] for source, target in tasks_batch]) 81 | target_guide = np.stack([frames_guide[target] for source, target in tasks_batch]) 82 | source_style = np.stack([frames_style[source] for source, target in tasks_batch]) 83 | _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style) 84 | for (source, target), result in zip(tasks_batch, target_style): 85 | frame, weight = frames[target] 86 | if frame is None: 87 | frame = frames_style[target] 88 | frames[target] = ( 89 | frame * (weight / (weight + 1)) + result / (weight + 1), 90 | weight + 1 91 | ) 92 | if weight + 1 == min(n, target + self.window_size + 1) - max(0, target - self.window_size): 93 | frame = frame.clip(0, 255).astype("uint8") 94 | output_frames.append(Image.fromarray(frame)) 95 | frames[target] = (None, 1) 96 | return output_frames 97 | 98 | def inference_accurate(self, frames_guide, frames_style): 99 | patch_match_engine = PyramidPatchMatcher( 100 | image_height=frames_style[0].shape[0], 101 | image_width=frames_style[0].shape[1], 102 | channel=3, 103 | use_mean_target_style=True, 104 | **self.ebsynth_config 105 | ) 106 | output_frames = [] 107 | # run 108 | n = len(frames_style) 109 | for target in tqdm(range(n), desc="Accurate Mode"): 110 | l, r = max(target - self.window_size, 0), min(target + self.window_size + 1, n) 111 | remapped_frames = [] 112 | for i in range(l, r, self.batch_size): 113 | j = min(i + self.batch_size, r) 114 | source_guide = np.stack([frames_guide[source] for source in range(i, j)]) 115 | target_guide = np.stack([frames_guide[target]] * (j - i)) 116 | source_style = np.stack([frames_style[source] for source in range(i, j)]) 117 | _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style) 118 | remapped_frames.append(target_style) 119 | frame = np.concatenate(remapped_frames, axis=0).mean(axis=0) 120 | frame = frame.clip(0, 255).astype("uint8") 121 | output_frames.append(Image.fromarray(frame)) 122 | return output_frames 123 | 124 | def release_vram(self): 125 | mempool = cp.get_default_memory_pool() 126 | pinned_mempool = cp.get_default_pinned_memory_pool() 127 | mempool.free_all_blocks() 128 | 
pinned_mempool.free_all_blocks() 129 | 130 | def __call__(self, rendered_frames, original_frames=None, **kwargs): 131 | rendered_frames = [np.array(frame) for frame in rendered_frames] 132 | original_frames = [np.array(frame) for frame in original_frames] 133 | if self.inference_mode == "fast": 134 | output_frames = self.inference_fast(original_frames, rendered_frames) 135 | elif self.inference_mode == "balanced": 136 | output_frames = self.inference_balanced(original_frames, rendered_frames) 137 | elif self.inference_mode == "accurate": 138 | output_frames = self.inference_accurate(original_frames, rendered_frames) 139 | else: 140 | raise ValueError("inference_mode must be fast, balanced or accurate") 141 | self.release_vram() 142 | return output_frames 143 | -------------------------------------------------------------------------------- /diffsynth/processors/PILEditor.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageEnhance 2 | from .base import VideoProcessor 3 | 4 | 5 | class ContrastEditor(VideoProcessor): 6 | def __init__(self, rate=1.5): 7 | self.rate = rate 8 | 9 | @staticmethod 10 | def from_model_manager(model_manager, **kwargs): 11 | return ContrastEditor(**kwargs) 12 | 13 | def __call__(self, rendered_frames, **kwargs): 14 | rendered_frames = [ImageEnhance.Contrast(i).enhance(self.rate) for i in rendered_frames] 15 | return rendered_frames 16 | 17 | 18 | class SharpnessEditor(VideoProcessor): 19 | def __init__(self, rate=1.5): 20 | self.rate = rate 21 | 22 | @staticmethod 23 | def from_model_manager(model_manager, **kwargs): 24 | return SharpnessEditor(**kwargs) 25 | 26 | def __call__(self, rendered_frames, **kwargs): 27 | rendered_frames = [ImageEnhance.Sharpness(i).enhance(self.rate) for i in rendered_frames] 28 | return rendered_frames 29 | -------------------------------------------------------------------------------- /diffsynth/processors/RIFE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from PIL import Image 4 | from .base import VideoProcessor 5 | 6 | 7 | class RIFESmoother(VideoProcessor): 8 | def __init__(self, model, device="cuda", scale=1.0, batch_size=4, interpolate=True): 9 | self.model = model 10 | self.device = device 11 | 12 | # IFNet only does not support float16 13 | self.torch_dtype = torch.float32 14 | 15 | # Other parameters 16 | self.scale = scale 17 | self.batch_size = batch_size 18 | self.interpolate = interpolate 19 | 20 | @staticmethod 21 | def from_model_manager(model_manager, **kwargs): 22 | return RIFESmoother(model_manager.RIFE, device=model_manager.device, **kwargs) 23 | 24 | def process_image(self, image): 25 | width, height = image.size 26 | if width % 32 != 0 or height % 32 != 0: 27 | width = (width + 31) // 32 28 | height = (height + 31) // 32 29 | image = image.resize((width, height)) 30 | image = torch.Tensor(np.array(image, dtype=np.float32)[:, :, [2,1,0]] / 255).permute(2, 0, 1) 31 | return image 32 | 33 | def process_images(self, images): 34 | images = [self.process_image(image) for image in images] 35 | images = torch.stack(images) 36 | return images 37 | 38 | def decode_images(self, images): 39 | images = (images[:, [2,1,0]].permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8) 40 | images = [Image.fromarray(image) for image in images] 41 | return images 42 | 43 | def process_tensors(self, input_tensor, scale=1.0, batch_size=4): 44 | output_tensor = [] 45 | for batch_id in range(0, 
input_tensor.shape[0], batch_size): 46 | batch_id_ = min(batch_id + batch_size, input_tensor.shape[0]) 47 | batch_input_tensor = input_tensor[batch_id: batch_id_] 48 | batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype) 49 | flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale]) 50 | output_tensor.append(merged[2].cpu()) 51 | output_tensor = torch.concat(output_tensor, dim=0) 52 | return output_tensor 53 | 54 | @torch.no_grad() 55 | def __call__(self, rendered_frames, **kwargs): 56 | # Preprocess 57 | processed_images = self.process_images(rendered_frames) 58 | 59 | # Input 60 | input_tensor = torch.cat((processed_images[:-2], processed_images[2:]), dim=1) 61 | 62 | # Interpolate 63 | output_tensor = self.process_tensors(input_tensor, scale=self.scale, batch_size=self.batch_size) 64 | 65 | if self.interpolate: 66 | # Blend 67 | input_tensor = torch.cat((processed_images[1:-1], output_tensor), dim=1) 68 | output_tensor = self.process_tensors(input_tensor, scale=self.scale, batch_size=self.batch_size) 69 | processed_images[1:-1] = output_tensor 70 | else: 71 | processed_images[1:-1] = (processed_images[1:-1] + output_tensor) / 2 72 | 73 | # To images 74 | output_images = self.decode_images(processed_images) 75 | if output_images[0].size != rendered_frames[0].size: 76 | output_images = [image.resize(rendered_frames[0].size) for image in output_images] 77 | return output_images 78 | -------------------------------------------------------------------------------- /diffsynth/processors/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kijai/ComfyUI-DiffSynthWrapper/33ec4bb988922e5ea8abaa52168b1b97dcf88d82/diffsynth/processors/__init__.py -------------------------------------------------------------------------------- /diffsynth/processors/base.py: -------------------------------------------------------------------------------- 1 | class VideoProcessor: 2 | def __init__(self): 3 | pass 4 | 5 | def __call__(self): 6 | raise NotImplementedError 7 | -------------------------------------------------------------------------------- /diffsynth/processors/sequencial_processor.py: -------------------------------------------------------------------------------- 1 | from .base import VideoProcessor 2 | 3 | 4 | class AutoVideoProcessor(VideoProcessor): 5 | def __init__(self): 6 | pass 7 | 8 | @staticmethod 9 | def from_model_manager(model_manager, processor_type, **kwargs): 10 | if processor_type == "FastBlend": 11 | from .FastBlend import FastBlendSmoother 12 | return FastBlendSmoother.from_model_manager(model_manager, **kwargs) 13 | elif processor_type == "Contrast": 14 | from .PILEditor import ContrastEditor 15 | return ContrastEditor.from_model_manager(model_manager, **kwargs) 16 | elif processor_type == "Sharpness": 17 | from .PILEditor import SharpnessEditor 18 | return SharpnessEditor.from_model_manager(model_manager, **kwargs) 19 | elif processor_type == "RIFE": 20 | from .RIFE import RIFESmoother 21 | return RIFESmoother.from_model_manager(model_manager, **kwargs) 22 | else: 23 | raise ValueError(f"invalid processor_type: {processor_type}") 24 | 25 | 26 | class SequencialProcessor(VideoProcessor): 27 | def __init__(self, processors=[]): 28 | self.processors = processors 29 | 30 | @staticmethod 31 | def from_model_manager(model_manager, configs): 32 | processors = [ 33 | AutoVideoProcessor.from_model_manager(model_manager, config["processor_type"], 
**config["config"]) 34 | for config in configs 35 | ] 36 | return SequencialProcessor(processors) 37 | 38 | def __call__(self, rendered_frames, **kwargs): 39 | for processor in self.processors: 40 | rendered_frames = processor(rendered_frames, **kwargs) 41 | return rendered_frames 42 | -------------------------------------------------------------------------------- /diffsynth/prompts/__init__.py: -------------------------------------------------------------------------------- 1 | from .sd_prompter import SDPrompter 2 | from .sdxl_prompter import SDXLPrompter 3 | from .hunyuan_dit_prompter import HunyuanDiTPrompter 4 | -------------------------------------------------------------------------------- /diffsynth/prompts/hunyuan_dit_prompter.py: -------------------------------------------------------------------------------- 1 | from .utils import Prompter 2 | from transformers import BertModel, T5EncoderModel, BertTokenizer, AutoTokenizer 3 | import warnings 4 | 5 | 6 | class HunyuanDiTPrompter(Prompter): 7 | def __init__( 8 | self, 9 | tokenizer_path="configs/hunyuan_dit/tokenizer", 10 | tokenizer_t5_path="configs/hunyuan_dit/tokenizer_t5" 11 | ): 12 | super().__init__() 13 | self.tokenizer = BertTokenizer.from_pretrained(tokenizer_path) 14 | with warnings.catch_warnings(): 15 | warnings.simplefilter("ignore") 16 | self.tokenizer_t5 = AutoTokenizer.from_pretrained(tokenizer_t5_path) 17 | 18 | 19 | def encode_prompt_using_signle_model(self, prompt, text_encoder, tokenizer, max_length, clip_skip, device): 20 | text_inputs = tokenizer( 21 | prompt, 22 | padding="max_length", 23 | max_length=max_length, 24 | truncation=True, 25 | return_attention_mask=True, 26 | return_tensors="pt", 27 | ) 28 | text_input_ids = text_inputs.input_ids 29 | attention_mask = text_inputs.attention_mask.to(device) 30 | prompt_embeds = text_encoder( 31 | text_input_ids.to(device), 32 | attention_mask=attention_mask, 33 | clip_skip=clip_skip 34 | ) 35 | return prompt_embeds, attention_mask 36 | 37 | 38 | def encode_prompt( 39 | self, 40 | text_encoder: BertModel, 41 | text_encoder_t5: T5EncoderModel, 42 | prompt, 43 | clip_skip=1, 44 | clip_skip_2=1, 45 | positive=True, 46 | device="cuda" 47 | ): 48 | prompt = self.process_prompt(prompt, positive=positive) 49 | 50 | # CLIP 51 | prompt_emb, attention_mask = self.encode_prompt_using_signle_model(prompt, text_encoder, self.tokenizer, self.tokenizer.model_max_length, clip_skip, device) 52 | 53 | # T5 54 | prompt_emb_t5, attention_mask_t5 = self.encode_prompt_using_signle_model(prompt, text_encoder_t5, self.tokenizer_t5, self.tokenizer_t5.model_max_length, clip_skip_2, device) 55 | 56 | return prompt_emb, attention_mask, prompt_emb_t5, attention_mask_t5 57 | -------------------------------------------------------------------------------- /diffsynth/prompts/sd_prompter.py: -------------------------------------------------------------------------------- 1 | from .utils import Prompter, tokenize_long_prompt 2 | from transformers import CLIPTokenizer 3 | from ..models import SDTextEncoder 4 | 5 | 6 | class SDPrompter(Prompter): 7 | def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"): 8 | super().__init__() 9 | self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path) 10 | 11 | def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True): 12 | prompt = self.process_prompt(prompt, positive=positive) 13 | input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device) 14 | prompt_emb = text_encoder(input_ids, 
clip_skip=clip_skip) 15 | prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1)) 16 | 17 | return prompt_emb -------------------------------------------------------------------------------- /diffsynth/prompts/sdxl_prompter.py: -------------------------------------------------------------------------------- 1 | from .utils import Prompter, tokenize_long_prompt 2 | from transformers import CLIPTokenizer 3 | from ..models import SDXLTextEncoder, SDXLTextEncoder2 4 | import torch 5 | 6 | 7 | class SDXLPrompter(Prompter): 8 | def __init__( 9 | self, 10 | tokenizer_path="configs/stable_diffusion/tokenizer", 11 | tokenizer_2_path="configs/stable_diffusion_xl/tokenizer_2" 12 | ): 13 | super().__init__() 14 | self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path) 15 | self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path) 16 | 17 | def encode_prompt( 18 | self, 19 | text_encoder: SDXLTextEncoder, 20 | text_encoder_2: SDXLTextEncoder2, 21 | prompt, 22 | clip_skip=1, 23 | clip_skip_2=2, 24 | positive=True, 25 | device="cuda" 26 | ): 27 | prompt = self.process_prompt(prompt, positive=positive) 28 | 29 | # 1 30 | input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device) 31 | prompt_emb_1 = text_encoder(input_ids, clip_skip=clip_skip) 32 | 33 | # 2 34 | input_ids_2 = tokenize_long_prompt(self.tokenizer_2, prompt).to(device) 35 | add_text_embeds, prompt_emb_2 = text_encoder_2(input_ids_2, clip_skip=clip_skip_2) 36 | 37 | # Merge 38 | prompt_emb = torch.concatenate([prompt_emb_1, prompt_emb_2], dim=-1) 39 | 40 | # For very long prompt, we only use the first 77 tokens to compute `add_text_embeds`. 41 | add_text_embeds = add_text_embeds[0:1] 42 | prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1)) 43 | return add_text_embeds, prompt_emb 44 | -------------------------------------------------------------------------------- /diffsynth/prompts/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import CLIPTokenizer, AutoTokenizer 2 | from ..models import ModelManager 3 | import os 4 | 5 | 6 | def tokenize_long_prompt(tokenizer, prompt): 7 | # Get model_max_length from self.tokenizer 8 | length = tokenizer.model_max_length 9 | 10 | # To avoid the warning. set self.tokenizer.model_max_length to +oo. 11 | tokenizer.model_max_length = 99999999 12 | 13 | # Tokenize it! 14 | input_ids = tokenizer(prompt, return_tensors="pt").input_ids 15 | 16 | # Determine the real length. 17 | max_length = (input_ids.shape[1] + length - 1) // length * length 18 | 19 | # Restore tokenizer.model_max_length 20 | tokenizer.model_max_length = length 21 | 22 | # Tokenize it again with fixed length. 23 | input_ids = tokenizer( 24 | prompt, 25 | return_tensors="pt", 26 | padding="max_length", 27 | max_length=max_length, 28 | truncation=True 29 | ).input_ids 30 | 31 | # Reshape input_ids to fit the text encoder. 
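    # Worked example: with model_max_length == 77, a prompt that tokenizes to
    # 100 ids is padded up to max_length == 154 and reshaped below into a
    # (2, 77) batch, so the text encoder processes two fixed-length chunks;
    # the prompters later flatten the chunk embeddings back into one sequence.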
32 | num_sentence = input_ids.shape[1] // length 33 | input_ids = input_ids.reshape((num_sentence, length)) 34 | 35 | return input_ids 36 | 37 | 38 | class BeautifulPrompt: 39 | def __init__(self, tokenizer_path="configs/beautiful_prompt/tokenizer", model=None): 40 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) 41 | self.model = model 42 | self.template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:' 43 | 44 | def __call__(self, raw_prompt): 45 | model_input = self.template.format(raw_prompt=raw_prompt) 46 | input_ids = self.tokenizer.encode(model_input, return_tensors='pt').to(self.model.device) 47 | outputs = self.model.generate( 48 | input_ids, 49 | max_new_tokens=384, 50 | do_sample=True, 51 | temperature=0.9, 52 | top_k=50, 53 | top_p=0.95, 54 | repetition_penalty=1.1, 55 | num_return_sequences=1 56 | ) 57 | prompt = raw_prompt + ", " + self.tokenizer.batch_decode( 58 | outputs[:, input_ids.size(1):], 59 | skip_special_tokens=True 60 | )[0].strip() 61 | return prompt 62 | 63 | 64 | class Translator: 65 | def __init__(self, tokenizer_path="configs/translator/tokenizer", model=None): 66 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) 67 | self.model = model 68 | 69 | def __call__(self, prompt): 70 | input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.model.device) 71 | output_ids = self.model.generate(input_ids) 72 | prompt = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] 73 | return prompt 74 | 75 | 76 | class Prompter: 77 | def __init__(self): 78 | self.tokenizer: CLIPTokenizer = None 79 | self.keyword_dict = {} 80 | self.translator: Translator = None 81 | self.beautiful_prompt: BeautifulPrompt = None 82 | 83 | def load_textual_inversion(self, textual_inversion_dict): 84 | self.keyword_dict = {} 85 | additional_tokens = [] 86 | for keyword in textual_inversion_dict: 87 | tokens, _ = textual_inversion_dict[keyword] 88 | additional_tokens += tokens 89 | self.keyword_dict[keyword] = " " + " ".join(tokens) + " " 90 | self.tokenizer.add_tokens(additional_tokens) 91 | 92 | def load_beautiful_prompt(self, model, model_path): 93 | model_folder = os.path.dirname(model_path) 94 | self.beautiful_prompt = BeautifulPrompt(tokenizer_path=model_folder, model=model) 95 | if model_folder.endswith("v2"): 96 | self.beautiful_prompt.template = """Converts a simple image description into a prompt. \ 97 | Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \ 98 | or use a number to specify the weight. 
You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \ 99 | but make sure there is a correlation between the input and output.\n\ 100 | ### Input: {raw_prompt}\n### Output:""" 101 | 102 | def load_translator(self, model, model_path): 103 | model_folder = os.path.dirname(model_path) 104 | self.translator = Translator(tokenizer_path=model_folder, model=model) 105 | 106 | def load_from_model_manager(self, model_manager: ModelManager): 107 | self.load_textual_inversion(model_manager.textual_inversion_dict) 108 | if "translator" in model_manager.model: 109 | self.load_translator(model_manager.model["translator"], model_manager.model_path["translator"]) 110 | if "beautiful_prompt" in model_manager.model: 111 | self.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"]) 112 | 113 | def process_prompt(self, prompt, positive=True): 114 | for keyword in self.keyword_dict: 115 | if keyword in prompt: 116 | prompt = prompt.replace(keyword, self.keyword_dict[keyword]) 117 | if positive and self.translator is not None: 118 | prompt = self.translator(prompt) 119 | print(f"Your prompt is translated: \"{prompt}\"") 120 | if positive and self.beautiful_prompt is not None: 121 | prompt = self.beautiful_prompt(prompt) 122 | print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"") 123 | return prompt 124 | -------------------------------------------------------------------------------- /diffsynth/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | from .ddim import EnhancedDDIMScheduler 2 | from .continuous_ode import ContinuousODEScheduler 3 | -------------------------------------------------------------------------------- /diffsynth/schedulers/continuous_ode.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class ContinuousODEScheduler(): 5 | 6 | def __init__(self, num_inference_steps=100, sigma_max=700.0, sigma_min=0.002, rho=7.0): 7 | self.sigma_max = sigma_max 8 | self.sigma_min = sigma_min 9 | self.rho = rho 10 | self.set_timesteps(num_inference_steps) 11 | 12 | 13 | def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0): 14 | ramp = torch.linspace(1-denoising_strength, 1, num_inference_steps) 15 | min_inv_rho = torch.pow(torch.tensor((self.sigma_min,)), (1 / self.rho)) 16 | max_inv_rho = torch.pow(torch.tensor((self.sigma_max,)), (1 / self.rho)) 17 | self.sigmas = torch.pow(max_inv_rho + ramp * (min_inv_rho - max_inv_rho), self.rho) 18 | self.timesteps = torch.log(self.sigmas) * 0.25 19 | 20 | 21 | def step(self, model_output, timestep, sample, to_final=False): 22 | timestep_id = torch.argmin((self.timesteps - timestep).abs()) 23 | sigma = self.sigmas[timestep_id] 24 | sample *= (sigma*sigma + 1).sqrt() 25 | estimated_sample = -sigma / (sigma*sigma + 1).sqrt() * model_output + 1 / (sigma*sigma + 1) * sample 26 | if to_final or timestep_id + 1 >= len(self.timesteps): 27 | prev_sample = estimated_sample 28 | else: 29 | sigma_ = self.sigmas[timestep_id + 1] 30 | derivative = 1 / sigma * (sample - estimated_sample) 31 | prev_sample = sample + derivative * (sigma_ - sigma) 32 | prev_sample /= (sigma_*sigma_ + 1).sqrt() 33 | return prev_sample 34 | 35 | 36 | def return_to_timestep(self, timestep, sample, sample_stablized): 37 | # This scheduler doesn't support this function. 
38 | pass 39 | 40 | 41 | def add_noise(self, original_samples, noise, timestep): 42 | timestep_id = torch.argmin((self.timesteps - timestep).abs()) 43 | sigma = self.sigmas[timestep_id] 44 | sample = (original_samples + noise * sigma) / (sigma*sigma + 1).sqrt() 45 | return sample 46 | 47 | 48 | def training_target(self, sample, noise, timestep): 49 | timestep_id = torch.argmin((self.timesteps - timestep).abs()) 50 | sigma = self.sigmas[timestep_id] 51 | target = (-(sigma*sigma + 1).sqrt() / sigma + 1 / (sigma*sigma + 1).sqrt() / sigma) * sample + 1 / (sigma*sigma + 1).sqrt() * noise 52 | return target 53 | 54 | 55 | def training_weight(self, timestep): 56 | timestep_id = torch.argmin((self.timesteps - timestep).abs()) 57 | sigma = self.sigmas[timestep_id] 58 | weight = (1 + sigma*sigma).sqrt() / sigma 59 | return weight 60 | -------------------------------------------------------------------------------- /diffsynth/schedulers/ddim.py: -------------------------------------------------------------------------------- 1 | import torch, math 2 | 3 | 4 | class EnhancedDDIMScheduler(): 5 | 6 | def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", prediction_type="epsilon"): 7 | self.num_train_timesteps = num_train_timesteps 8 | if beta_schedule == "scaled_linear": 9 | betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32)) 10 | elif beta_schedule == "linear": 11 | betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) 12 | else: 13 | raise NotImplementedError(f"{beta_schedule} is not implemented") 14 | self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0).tolist() 15 | self.set_timesteps(10) 16 | self.prediction_type = prediction_type 17 | 18 | 19 | def set_timesteps(self, num_inference_steps, denoising_strength=1.0): 20 | # The timesteps are aligned to 999...0, which is different from other implementations, 21 | # but I think this implementation is more reasonable in theory. 
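# For example, with num_train_timesteps=1000, denoising_strength=1.0 and
# num_inference_steps=5, max_timestep is 999, step_length is 999 / 4, and the
# schedule becomes [999, 749, 500, 250, 0], i.e. it always starts at the last
# training timestep and ends exactly at timestep 0.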
22 | max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0) 23 | num_inference_steps = min(num_inference_steps, max_timestep + 1) 24 | if num_inference_steps == 1: 25 | self.timesteps = [max_timestep] 26 | else: 27 | step_length = max_timestep / (num_inference_steps - 1) 28 | self.timesteps = [round(max_timestep - i*step_length) for i in range(num_inference_steps)] 29 | 30 | 31 | def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev): 32 | if self.prediction_type == "epsilon": 33 | weight_e = math.sqrt(1 - alpha_prod_t_prev) - math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t) 34 | weight_x = math.sqrt(alpha_prod_t_prev / alpha_prod_t) 35 | prev_sample = sample * weight_x + model_output * weight_e 36 | elif self.prediction_type == "v_prediction": 37 | weight_e = -math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t)) + math.sqrt(alpha_prod_t * (1 - alpha_prod_t_prev)) 38 | weight_x = math.sqrt(alpha_prod_t * alpha_prod_t_prev) + math.sqrt((1 - alpha_prod_t) * (1 - alpha_prod_t_prev)) 39 | prev_sample = sample * weight_x + model_output * weight_e 40 | else: 41 | raise NotImplementedError(f"{self.prediction_type} is not implemented") 42 | return prev_sample 43 | 44 | 45 | def step(self, model_output, timestep, sample, to_final=False): 46 | alpha_prod_t = self.alphas_cumprod[timestep] 47 | timestep_id = self.timesteps.index(timestep) 48 | if to_final or timestep_id + 1 >= len(self.timesteps): 49 | alpha_prod_t_prev = 1.0 50 | else: 51 | timestep_prev = self.timesteps[timestep_id + 1] 52 | alpha_prod_t_prev = self.alphas_cumprod[timestep_prev] 53 | 54 | return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev) 55 | 56 | 57 | def return_to_timestep(self, timestep, sample, sample_stablized): 58 | alpha_prod_t = self.alphas_cumprod[timestep] 59 | noise_pred = (sample - math.sqrt(alpha_prod_t) * sample_stablized) / math.sqrt(1 - alpha_prod_t) 60 | return noise_pred 61 | 62 | 63 | def add_noise(self, original_samples, noise, timestep): 64 | sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[timestep]) 65 | sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[timestep]) 66 | noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise 67 | return noisy_samples 68 | 69 | def training_target(self, sample, noise, timestep): 70 | sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[timestep]) 71 | sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[timestep]) 72 | target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample 73 | return target 74 | -------------------------------------------------------------------------------- /examples/exvideo_example_workflow_01.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 9, 3 | "last_link_id": 16, 4 | "nodes": [ 5 | { 6 | "id": 2, 7 | "type": "DownloadAndLoadDiffSynthExVideoSVD", 8 | "pos": [ 9 | 47, 10 | -2 11 | ], 12 | "size": { 13 | "0": 386.8787536621094, 14 | "1": 82.6969985961914 15 | }, 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "outputs": [ 20 | { 21 | "name": "diffsynth_model", 22 | "type": "DIFFSYNTHMODEL", 23 | "links": [ 24 | 7 25 | ], 26 | "shape": 3 27 | } 28 | ], 29 | "properties": { 30 | "Node name for S&R": "DownloadAndLoadDiffSynthExVideoSVD" 31 | }, 32 | "widgets_values": [ 33 | "ECNU-CILab/ExVideo-SVD-128f-v1", 34 | "svd_xt.safetensors" 35 | ] 36 | }, 37 | { 38 | "id": 8, 39 | "type": "ImageResizeKJ", 40 | "pos": [ 41 | 311, 42 | 235 43 | ], 44 | "size": 
{ 45 | "0": 315, 46 | "1": 242 47 | }, 48 | "flags": {}, 49 | "order": 2, 50 | "mode": 0, 51 | "inputs": [ 52 | { 53 | "name": "image", 54 | "type": "IMAGE", 55 | "link": 10 56 | }, 57 | { 58 | "name": "get_image_size", 59 | "type": "IMAGE", 60 | "link": null 61 | }, 62 | { 63 | "name": "width_input", 64 | "type": "INT", 65 | "link": null, 66 | "widget": { 67 | "name": "width_input" 68 | } 69 | }, 70 | { 71 | "name": "height_input", 72 | "type": "INT", 73 | "link": null, 74 | "widget": { 75 | "name": "height_input" 76 | } 77 | } 78 | ], 79 | "outputs": [ 80 | { 81 | "name": "IMAGE", 82 | "type": "IMAGE", 83 | "links": [ 84 | 11 85 | ], 86 | "shape": 3, 87 | "slot_index": 0 88 | }, 89 | { 90 | "name": "width", 91 | "type": "INT", 92 | "links": [ 93 | 12 94 | ], 95 | "shape": 3, 96 | "slot_index": 1 97 | }, 98 | { 99 | "name": "height", 100 | "type": "INT", 101 | "links": [ 102 | 13 103 | ], 104 | "shape": 3, 105 | "slot_index": 2 106 | } 107 | ], 108 | "properties": { 109 | "Node name for S&R": "ImageResizeKJ" 110 | }, 111 | "widgets_values": [ 112 | 512, 113 | 512, 114 | "lanczos", 115 | false, 116 | 64, 117 | 0, 118 | 0 119 | ] 120 | }, 121 | { 122 | "id": 3, 123 | "type": "LoadImage", 124 | "pos": [ 125 | -66, 126 | 230 127 | ], 128 | "size": { 129 | "0": 314.8787536621094, 130 | "1": 403.3636474609375 131 | }, 132 | "flags": {}, 133 | "order": 1, 134 | "mode": 0, 135 | "outputs": [ 136 | { 137 | "name": "IMAGE", 138 | "type": "IMAGE", 139 | "links": [ 140 | 10 141 | ], 142 | "shape": 3, 143 | "slot_index": 0 144 | }, 145 | { 146 | "name": "MASK", 147 | "type": "MASK", 148 | "links": null, 149 | "shape": 3 150 | } 151 | ], 152 | "properties": { 153 | "Node name for S&R": "LoadImage" 154 | }, 155 | "widgets_values": [ 156 | "ComfyUI_temp_mkxxn_00010_.png", 157 | "image" 158 | ] 159 | }, 160 | { 161 | "id": 6, 162 | "type": "DiffSynthSampler", 163 | "pos": [ 164 | 709, 165 | 199 166 | ], 167 | "size": { 168 | "0": 315, 169 | "1": 410 170 | }, 171 | "flags": {}, 172 | "order": 3, 173 | "mode": 0, 174 | "inputs": [ 175 | { 176 | "name": "diffsynth_model", 177 | "type": "DIFFSYNTHMODEL", 178 | "link": 7 179 | }, 180 | { 181 | "name": "image", 182 | "type": "IMAGE", 183 | "link": 11 184 | }, 185 | { 186 | "name": "input_video", 187 | "type": "IMAGE", 188 | "link": null 189 | }, 190 | { 191 | "name": "width", 192 | "type": "INT", 193 | "link": 12, 194 | "widget": { 195 | "name": "width" 196 | } 197 | }, 198 | { 199 | "name": "height", 200 | "type": "INT", 201 | "link": 13, 202 | "widget": { 203 | "name": "height" 204 | } 205 | } 206 | ], 207 | "outputs": [ 208 | { 209 | "name": "image", 210 | "type": "IMAGE", 211 | "links": [ 212 | 16 213 | ], 214 | "shape": 3, 215 | "slot_index": 0 216 | } 217 | ], 218 | "properties": { 219 | "Node name for S&R": "DiffSynthSampler" 220 | }, 221 | "widgets_values": [ 222 | 128, 223 | 512, 224 | 512, 225 | 30, 226 | 127, 227 | 30, 228 | 2, 229 | 2, 230 | 1.2, 231 | 0.02, 232 | 1, 233 | 1, 234 | "fixed", 235 | false 236 | ] 237 | }, 238 | { 239 | "id": 7, 240 | "type": "VHS_VideoCombine", 241 | "pos": [ 242 | 1094, 243 | 47 244 | ], 245 | "size": [ 246 | 483.8788146972656, 247 | 767.8788146972656 248 | ], 249 | "flags": {}, 250 | "order": 4, 251 | "mode": 0, 252 | "inputs": [ 253 | { 254 | "name": "images", 255 | "type": "IMAGE", 256 | "link": 16 257 | }, 258 | { 259 | "name": "audio", 260 | "type": "VHS_AUDIO", 261 | "link": null 262 | }, 263 | { 264 | "name": "meta_batch", 265 | "type": "VHS_BatchManager", 266 | "link": null 267 | } 268 | ], 269 | "outputs": [ 
270 | { 271 | "name": "Filenames", 272 | "type": "VHS_FILENAMES", 273 | "links": null, 274 | "shape": 3 275 | } 276 | ], 277 | "properties": { 278 | "Node name for S&R": "VHS_VideoCombine" 279 | }, 280 | "widgets_values": { 281 | "frame_rate": 24, 282 | "loop_count": 0, 283 | "filename_prefix": "DiffSynthSVD", 284 | "format": "video/h264-mp4", 285 | "pix_fmt": "yuv420p", 286 | "crf": 19, 287 | "save_metadata": true, 288 | "pingpong": false, 289 | "save_output": true, 290 | "videopreview": { 291 | "hidden": false, 292 | "paused": false, 293 | "params": { 294 | "filename": "DiffSynthSVD_00026.mp4", 295 | "subfolder": "", 296 | "type": "output", 297 | "format": "video/h264-mp4" 298 | } 299 | } 300 | } 301 | } 302 | ], 303 | "links": [ 304 | [ 305 | 7, 306 | 2, 307 | 0, 308 | 6, 309 | 0, 310 | "DIFFSYNTHMODEL" 311 | ], 312 | [ 313 | 10, 314 | 3, 315 | 0, 316 | 8, 317 | 0, 318 | "IMAGE" 319 | ], 320 | [ 321 | 11, 322 | 8, 323 | 0, 324 | 6, 325 | 1, 326 | "IMAGE" 327 | ], 328 | [ 329 | 12, 330 | 8, 331 | 1, 332 | 6, 333 | 3, 334 | "INT" 335 | ], 336 | [ 337 | 13, 338 | 8, 339 | 2, 340 | 6, 341 | 4, 342 | "INT" 343 | ], 344 | [ 345 | 16, 346 | 6, 347 | 0, 348 | 7, 349 | 0, 350 | "IMAGE" 351 | ] 352 | ], 353 | "groups": [], 354 | "config": {}, 355 | "extra": { 356 | "ds": { 357 | "scale": 1, 358 | "offset": [ 359 | 290.121249112216, 360 | 177.63634144176135 361 | ] 362 | } 363 | }, 364 | "version": 0.4 365 | } -------------------------------------------------------------------------------- /examples/stone_elemental_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kijai/ComfyUI-DiffSynthWrapper/33ec4bb988922e5ea8abaa52168b1b97dcf88d82/examples/stone_elemental_example.png -------------------------------------------------------------------------------- /nodes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import sys 4 | 5 | from torchvision import transforms 6 | 7 | import comfy.model_management as mm 8 | from comfy.utils import ProgressBar 9 | import folder_paths 10 | 11 | script_directory = os.path.dirname(os.path.abspath(__file__)) 12 | sys.path.append(script_directory) 13 | 14 | from diffsynth import ModelManager, SVDVideoPipeline 15 | 16 | class DownloadAndLoadDiffSynthExVideoSVD: 17 | @classmethod 18 | def INPUT_TYPES(s): 19 | return {"required": { 20 | "diffsynth_model": ( 21 | [ 22 | 'ECNU-CILab/ExVideo-SVD-128f-v1', 23 | ], 24 | { 25 | "default": 'ECNU-CILab/ExVideo-SVD-128f-v1' 26 | }), 27 | "svd_model": (folder_paths.get_filename_list("checkpoints"),), 28 | }, 29 | } 30 | 31 | RETURN_TYPES = ("DIFFSYNTHMODEL",) 32 | RETURN_NAMES = ("diffsynth_model",) 33 | FUNCTION = "loadmodel" 34 | CATEGORY = "DiffSynthWrapper" 35 | 36 | def loadmodel(self, diffsynth_model, svd_model): 37 | device = mm.get_torch_device() 38 | offload_device = mm.unet_offload_device() 39 | dtype = torch.float16 40 | 41 | svd_model_path = folder_paths.get_full_path("checkpoints", svd_model) 42 | 43 | model_name = diffsynth_model.rsplit('/', 1)[-1] 44 | model_path = os.path.join(folder_paths.models_dir, "diffsynth", model_name) 45 | model_full_path = os.path.join(model_path, "model.fp16.safetensors") 46 | 47 | if not os.path.exists(model_full_path): 48 | print(f"Downloading DiffSynth model to: {model_full_path}") 49 | from huggingface_hub import snapshot_download 50 | snapshot_download(repo_id="ECNU-CILab/ExVideo-SVD-128f-v1", 51 | allow_patterns=['*fp16*'], 52 | 
local_dir=model_path, 53 | local_dir_use_symlinks=False) 54 | 55 | print(f"Loading DiffSynth model from: {model_full_path}") 56 | print(f"Loading SVD model from: {svd_model_path}") 57 | model_manager = ModelManager(torch_dtype=dtype, device=device) 58 | model_manager.load_models([svd_model_path, model_full_path]) 59 | pipe = SVDVideoPipeline.from_model_manager(model_manager) 60 | 61 | diffsynth_model = { 62 | 'pipe': pipe, 63 | 'dtype': dtype 64 | } 65 | 66 | return (diffsynth_model,) 67 | 68 | class DiffSynthSampler: 69 | @classmethod 70 | def INPUT_TYPES(s): 71 | return { 72 | "required": { 73 | "diffsynth_model": ("DIFFSYNTHMODEL", ), 74 | "image": ("IMAGE", ), 75 | "frames": ("INT", {"default": 128, "min": 1, "max": 128, "step": 1}), 76 | "width": ("INT", {"default": 512, "min": 1, "max": 2048, "step": 1}), 77 | "height": ("INT", {"default": 512, "min": 1, "max": 2048, "step": 1}), 78 | "steps": ("INT", {"default": 25, "min": 1, "max": 512, "step": 1}), 79 | "motion_bucket_id": ("INT", {"default": 127, "min": 0, "max": 255, "step": 1}), 80 | "fps": ("INT", {"default": 30, "min": 1, "max": 512, "step": 1}), 81 | "min_cfg_scale": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.01}), 82 | "max_cfg_scale": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.01}), 83 | "contrast_enhance_scale": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 10.0, "step": 0.01}), 84 | "noise_aug_strength": ("FLOAT", {"default": 0.02, "min": 0.0, "max": 10.0, "step": 0.01}), 85 | "denoising_strength": ("FLOAT", {"default": 1., "min": 0.0, "max": 1.0, "step": 0.01}), 86 | "seed": ("INT", {"default": 123, "min": 0, "max": 0xffffffffffffffff, "step": 1}), 87 | "keep_model_loaded": ("BOOLEAN", {"default": False}), 88 | }, 89 | "optional": { 90 | "input_video": ("IMAGE", ), 91 | } 92 | } 93 | 94 | RETURN_TYPES = ("IMAGE",) 95 | RETURN_NAMES =("image",) 96 | FUNCTION = "process" 97 | CATEGORY = "DiffSynthWrapper" 98 | 99 | def process(self, diffsynth_model, height, width, steps, motion_bucket_id, fps, frames, image, 100 | seed, min_cfg_scale, max_cfg_scale, denoising_strength, contrast_enhance_scale, noise_aug_strength, 101 | keep_model_loaded, input_video=None): 102 | device = mm.get_torch_device() 103 | offload_device = mm.unet_offload_device() 104 | 105 | pipe = diffsynth_model['pipe'] 106 | pipe.to(device) 107 | torch.manual_seed(seed) 108 | 109 | input_image = image.clone().permute(0, 3, 1, 2) * 2 - 1 110 | 111 | if input_video is not None: 112 | input_video = input_video.permute(0, 3, 1, 2) * 2 - 1 113 | 114 | video = pipe( 115 | input_image=input_image, 116 | input_video=input_video, 117 | num_frames=frames, 118 | fps=fps, 119 | height=height, 120 | width=width, 121 | motion_bucket_id=motion_bucket_id, 122 | num_inference_steps=steps, 123 | min_cfg_scale=min_cfg_scale, 124 | max_cfg_scale=max_cfg_scale, 125 | contrast_enhance_scale=contrast_enhance_scale, 126 | noise_aug_strength=noise_aug_strength, 127 | denoising_strength=denoising_strength, 128 | ) 129 | if not keep_model_loaded: 130 | pipe.to(offload_device) 131 | mm.soft_empty_cache() 132 | 133 | transform = transforms.ToTensor() 134 | tensors_list = [transform(image) for image in video] 135 | batch_tensor = torch.stack(tensors_list, dim=0) 136 | batch_tensor = batch_tensor.permute(0, 2, 3, 1).cpu().float() 137 | 138 | return (batch_tensor,) 139 | 140 | NODE_CLASS_MAPPINGS = { 141 | "DownloadAndLoadDiffSynthExVideoSVD": DownloadAndLoadDiffSynthExVideoSVD, 142 | "DiffSynthSampler": DiffSynthSampler, 143 | } 144 | 
NODE_DISPLAY_NAME_MAPPINGS = { 145 | "DownloadAndLoadDiffSynthExVideoSVD": "DownloadAndLoadDiffSynthExVideoSVD", 146 | "DiffSynthSampler": "DiffSynth Sampler", 147 | } --------------------------------------------------------------------------------
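For reference, the call sequence that DiffSynthSampler wraps can be reproduced directly with the bundled diffsynth package. The sketch below mirrors nodes.py and is not part of the repository itself; the checkpoint paths are placeholders, and the random tensor merely stands in for a real conditioning image in BCHW layout scaled to [-1, 1].

import torch
from diffsynth import ModelManager, SVDVideoPipeline

# Placeholder paths: point these at a local SVD checkpoint and the downloaded
# ExVideo-SVD-128f-v1 fp16 weights.
svd_model_path = "checkpoints/svd_xt.safetensors"
exvideo_model_path = "models/diffsynth/ExVideo-SVD-128f-v1/model.fp16.safetensors"

# Load both checkpoints into one pipeline, as DownloadAndLoadDiffSynthExVideoSVD does.
# Assumes a CUDA device, which the ComfyUI node obtains via comfy.model_management.
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
model_manager.load_models([svd_model_path, exvideo_model_path])
pipe = SVDVideoPipeline.from_model_manager(model_manager)

# DiffSynthSampler feeds the pipeline a BCHW image scaled to [-1, 1];
# here a random tensor stands in for a real conditioning frame.
torch.manual_seed(123)
input_image = torch.rand((1, 3, 512, 512)) * 2 - 1

video = pipe(
    input_image=input_image,
    input_video=None,
    num_frames=128,
    fps=30,
    height=512,
    width=512,
    motion_bucket_id=127,
    num_inference_steps=25,
    min_cfg_scale=2.0,
    max_cfg_scale=2.0,
    contrast_enhance_scale=1.2,
    noise_aug_strength=0.02,
    denoising_strength=1.0,
)
# `video` is a sequence of frames; nodes.py converts each one with
# torchvision.transforms.ToTensor() and stacks them into a BHWC batch.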