├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── diffsynth ├── __init__.py ├── controlnets │ ├── __init__.py │ ├── controlnet_unit.py │ └── processors.py ├── data │ ├── __init__.py │ └── video.py ├── extensions │ ├── ESRGAN │ │ └── __init__.py │ ├── FastBlend │ │ ├── __init__.py │ │ ├── api.py │ │ ├── cupy_kernels.py │ │ ├── data.py │ │ ├── patch_match.py │ │ └── runners │ │ │ ├── __init__.py │ │ │ ├── accurate.py │ │ │ ├── balanced.py │ │ │ ├── fast.py │ │ │ └── interpolation.py │ └── RIFE │ │ └── __init__.py ├── models │ ├── __init__.py │ ├── attention.py │ ├── downloader.py │ ├── hunyuan_dit.py │ ├── hunyuan_dit_text_encoder.py │ ├── sd_controlnet.py │ ├── sd_ipadapter.py │ ├── sd_lora.py │ ├── sd_motion.py │ ├── sd_text_encoder.py │ ├── sd_unet.py │ ├── sd_vae_decoder.py │ ├── sd_vae_encoder.py │ ├── sdxl_ipadapter.py │ ├── sdxl_motion.py │ ├── sdxl_text_encoder.py │ ├── sdxl_unet.py │ ├── sdxl_vae_decoder.py │ ├── sdxl_vae_encoder.py │ ├── svd_image_encoder.py │ ├── svd_unet.py │ ├── svd_vae_decoder.py │ ├── svd_vae_encoder.py │ └── tiler.py ├── pipelines │ ├── __init__.py │ ├── dancer.py │ ├── hunyuan_dit.py │ ├── stable_diffusion.py │ ├── stable_diffusion_video.py │ ├── stable_diffusion_xl.py │ ├── stable_diffusion_xl_video.py │ └── stable_video_diffusion.py ├── processors │ ├── FastBlend.py │ ├── PILEditor.py │ ├── RIFE.py │ ├── __init__.py │ ├── base.py │ └── sequencial_processor.py ├── prompts │ ├── __init__.py │ ├── hunyuan_dit_prompter.py │ ├── sd_prompter.py │ ├── sdxl_prompter.py │ └── utils.py ├── schedulers │ ├── __init__.py │ ├── continuous_ode.py │ └── ddim.py └── tokenizer_configs │ ├── hunyuan_dit │ ├── tokenizer │ │ ├── special_tokens_map.json │ │ ├── tokenizer_config.json │ │ ├── vocab.txt │ │ └── vocab_org.txt │ └── tokenizer_t5 │ │ ├── config.json │ │ ├── special_tokens_map.json │ │ ├── spiece.model │ │ └── tokenizer_config.json │ ├── stable_diffusion │ └── tokenizer │ │ ├── merges.txt │ │ ├── special_tokens_map.json │ │ ├── tokenizer_config.json │ │ └── vocab.json │ └── stable_diffusion_xl │ └── tokenizer_2 │ ├── merges.txt │ ├── special_tokens_map.json │ ├── tokenizer_config.json │ └── vocab.json ├── models └── put diffsynth studio models here ├── requirements.txt ├── studio_nodes.py ├── util_nodes.py ├── web.png ├── web └── js │ ├── previewVideo.js │ └── uploadVideo.js └── workfolws ├── diffutoon_workflow.json └── exvideo_workflow.json /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | models/* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI-DiffSynth-Studio 2 | Make [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) available in ComfyUI 3 |
4 | ![webpage](web.png) 5 | 6 | 7 |
8 | 9 | ## How to use 10 | Tested with Python 3.10, a 2080 Ti (11 GB), and torch==2.3.0+cu121. 11 | 12 | Make sure `ffmpeg` works on your command line. 13 | For Linux: 14 | ``` 15 | apt update 16 | apt install ffmpeg 17 | ``` 18 | For Windows, you can install `ffmpeg` automatically with [WingetUI](https://github.com/marticliment/WingetUI). 19 | 20 | Then: 21 | ``` 22 | # in ComfyUI/custom_nodes 23 | git clone https://github.com/AIFSH/ComfyUI-DiffSynth-Studio.git 24 | cd ComfyUI-DiffSynth-Studio 25 | pip install -r requirements.txt 26 | ``` 27 | Weights will be downloaded from Hugging Face or ModelScope. 28 | 29 | ## Tutorial 30 | - [Demo for Diffutoon](https://b23.tv/z7hEXlX): "DiffSynth-Studio! The Diffutoon node of the ComfyUI plugin" on Bilibili 31 | - [Run on 4090](https://www.xiangongyun.com/image/detail/13706bf7-f3e6-4e29-bb97-c79405f5def4) 32 | - WeChat: aifsh_98 33 | 34 | ## Nodes Detail and Workflow 35 | ### ExVideo Node 36 | [ExVideo workflow](./workfolws/exvideo_workflow.json) 37 | ``` 38 | "image":("IMAGE",), 39 | "svd_base_model":("SD_MODEL_PATH",), 40 | "exvideo_model":("SD_MODEL_PATH",), 41 | "num_frames":("INT",{ 42 | "default": 128 43 | }), 44 | "fps":("INT",{ 45 | "default": 30 46 | }), 47 | "num_inference_steps":("INT",{ 48 | "default": 50 49 | }), 50 | "if_upscale":("BOOLEAN",{ 51 | "default": True, 52 | }), 53 | "seed": ("INT",{ 54 | "default": 1 55 | }) 56 | ``` 57 | ### Image Synthesis 58 | Coming soon. 59 | ### Diffutoon Node 60 | [Diffutoon workflow](./workfolws/diffutoon_workflow.json) 61 | ``` 62 | "required":{ 63 | "source_video_path": ("VIDEO",), 64 | "sd_model_path":("SD_MODEL_PATH",), 65 | "postive_prompt":("TEXT",), 66 | "negative_prompt":("TEXT",), 67 | "start":("INT",{ 68 | "default": 0 ## the second of the source video at which shading starts 69 | }), 70 | "length":("INT",{ 71 | "default": -1 ## how many seconds to shade; -1 means the whole video 72 | }), 73 | "seed":("INT",{ 74 | "default": 42 75 | }), 76 | "cfg_scale":("INT",{ 77 | "default": 3 78 | }), 79 | "num_inference_steps":("INT",{ 80 | "default": 10 81 | }), 82 | "animatediff_batch_size":("INT",{ 83 | "default": 32 ## lower this until it fits on your GPU 84 | }), 85 | "animatediff_stride":("INT",{ 86 | "default": 16 ## lower this until it fits on your GPU 87 | }), 88 | "vram_limit_level":("INT",{ 89 | "default": 0 ## if the process gets killed, try 1 90 | }), 91 | }, 92 | "optional":{ 93 | "controlnet1":("ControlNetConfigUnit",), 94 | "controlnet2":("ControlNetConfigUnit",), 95 | "controlnet3":("ControlNetConfigUnit",), 96 | } 97 | ```
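The three optional `controlnet` inputs take `ControlNetConfigUnit` values, which in the graph are presumably wired in from the `ControlNetPathLoader` node. As a point of reference, here is a minimal sketch of what such a unit carries, based on the bundled `diffsynth.controlnets` package; the checkpoint path below is a hypothetical example, not a file shipped with this repo.
```
# Sketch only, not the node implementation: what a ControlNetConfigUnit holds.
from diffsynth.controlnets import ControlNetConfigUnit

unit = ControlNetConfigUnit(
    processor_id="lineart",  # one of: canny, depth, softedge, lineart, lineart_anime, openpose, tile
    model_path="models/control_v11p_sd15_lineart.pth",  # hypothetical path to a ControlNet checkpoint
    scale=0.5,  # multiplies this ControlNet's residuals before they are summed with the others
)
```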
98 | 99 | ### Video Stylization 100 | 101 | ### Chinese Models 102 | 103 | ## Getting answers quickly 104 | WeChat: aifsh_98 105 | Donations are welcome, 106 | but please feel free to open a new issue and ask your question. 107 | 108 | Finding the Windows environment setup too hard? Add WeChat aifsh_98 and send a tip to get a one-click Windows package, or of course you can open an issue and wait for someone to answer your questions. 109 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .util_nodes import LoadVideo,PreViewVideo 2 | from .studio_nodes import DiffTextNode,DiffutoonNode,SDPathLoader,ControlNetPathLoader,ExVideoNode 3 | WEB_DIRECTORY = "./web" 4 | # A dictionary that contains all nodes you want to export with their names 5 | # NOTE: names should be globally unique 6 | NODE_CLASS_MAPPINGS = { 7 | "LoadVideo": LoadVideo, 8 | "PreViewVideo": PreViewVideo, 9 | "SDPathLoader": SDPathLoader, 10 | "DiffTextNode": DiffTextNode, 11 | "DiffutoonNode": DiffutoonNode, 12 | "ControlNetPathLoader": ControlNetPathLoader, 13 | "ExVideoNode": ExVideoNode 14 | } 15 | 16 | # A dictionary that contains the friendly/humanly readable titles for the nodes 17 | NODE_DISPLAY_NAME_MAPPINGS = { 18 | "LoadVideo": "LoadVideo", 19 | "PreViewVideo": "PreViewVideo", 20 | "SDPathLoader": "SDPathLoader", 21 | "DiffTextNode": "DiffTextNode", 22 | "DiffutoonNode": "DiffutoonNode", 23 | "ControlNetPathLoader": "ControlNetPathLoader", 24 | "ExVideoNode": "ExVideoNode" 25 | } -------------------------------------------------------------------------------- /diffsynth/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import * 2 | from .models import * 3 | from .prompts import * 4 | from .schedulers import * 5 | from .pipelines import * 6 | from .controlnets import * 7 | -------------------------------------------------------------------------------- /diffsynth/controlnets/__init__.py: -------------------------------------------------------------------------------- 1 | from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager 2 | from .processors import Annotator 3 | -------------------------------------------------------------------------------- /diffsynth/controlnets/controlnet_unit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from .processors import Processor_id 4 | 5 | 6 | class ControlNetConfigUnit: 7 | def __init__(self, processor_id: Processor_id, model_path, scale=1.0): 8 | self.processor_id = processor_id 9 | self.model_path = model_path 10 | self.scale = scale 11 | 12 | 13 | class ControlNetUnit: 14 | def __init__(self, processor, model, scale=1.0): 15 | self.processor = processor 16 | self.model = model 17 | self.scale = scale 18 | 19 | 20 | class MultiControlNetManager: 21 | def __init__(self, controlnet_units=[]): 22 | self.processors = [unit.processor for unit in controlnet_units] 23 | self.models = [unit.model for unit in controlnet_units] 24 | self.scales = [unit.scale for unit in controlnet_units] 25 | 26 | def process_image(self, image, processor_id=None): 27 | if processor_id is None: 28 | processed_image = [processor(image) for processor in self.processors] 29 | else: 30 | processed_image = [self.processors[processor_id](image)]
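# Note: each processed control image (a PIL image) is converted below to a float32 tensor in [0, 1], permuted to (C, H, W), given a batch dimension, and concatenated along dim 0.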
31 | processed_image = torch.concat([ 32 | torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0) 33 | for image_ in processed_image 34 | ], dim=0) 35 | return processed_image 36 | 37 | def __call__( 38 | self, 39 | sample, timestep, encoder_hidden_states, conditionings, 40 | tiled=False, tile_size=64, tile_stride=32 41 | ): 42 | res_stack = None 43 | for conditioning, model, scale in zip(conditionings, self.models, self.scales): 44 | res_stack_ = model( 45 | sample, timestep, encoder_hidden_states, conditioning, 46 | tiled=tiled, tile_size=tile_size, tile_stride=tile_stride 47 | ) 48 | res_stack_ = [res * scale for res in res_stack_] 49 | if res_stack is None: 50 | res_stack = res_stack_ 51 | else: 52 | res_stack = [i + j for i, j in zip(res_stack, res_stack_)] 53 | return res_stack 54 | -------------------------------------------------------------------------------- /diffsynth/controlnets/processors.py: -------------------------------------------------------------------------------- 1 | import os 2 | import folder_paths 3 | from typing_extensions import Literal, TypeAlias 4 | import warnings 5 | with warnings.catch_warnings(): 6 | warnings.simplefilter("ignore") 7 | from controlnet_aux.processor import ( 8 | CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector 9 | ) 10 | 11 | annotators_dir = os.path.join(folder_paths.models_dir, "Annotators") 12 | Processor_id: TypeAlias = Literal[ 13 | "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile" 14 | ] 15 | 16 | class Annotator: 17 | def __init__(self, processor_id: Processor_id, model_path=annotators_dir, detect_resolution=None): 18 | if processor_id == "canny": 19 | self.processor = CannyDetector() 20 | elif processor_id == "depth": 21 | self.processor = MidasDetector.from_pretrained(model_path).to("cuda") 22 | elif processor_id == "softedge": 23 | self.processor = HEDdetector.from_pretrained(model_path).to("cuda") 24 | elif processor_id == "lineart": 25 | self.processor = LineartDetector.from_pretrained(model_path).to("cuda") 26 | elif processor_id == "lineart_anime": 27 | self.processor = LineartAnimeDetector.from_pretrained(model_path).to("cuda") 28 | elif processor_id == "openpose": 29 | self.processor = OpenposeDetector.from_pretrained(model_path).to("cuda") 30 | elif processor_id == "tile": 31 | self.processor = None 32 | else: 33 | raise ValueError(f"Unsupported processor_id: {processor_id}") 34 | 35 | self.processor_id = processor_id 36 | self.detect_resolution = detect_resolution 37 | 38 | def __call__(self, image): 39 | width, height = image.size 40 | if self.processor_id == "openpose": 41 | kwargs = { 42 | "include_body": True, 43 | "include_hand": True, 44 | "include_face": True 45 | } 46 | else: 47 | kwargs = {} 48 | if self.processor is not None: 49 | detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height) 50 | image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs) 51 | image = image.resize((width, height)) 52 | return image 53 | 54 | -------------------------------------------------------------------------------- /diffsynth/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .video import VideoData, save_video, save_frames 2 | -------------------------------------------------------------------------------- /diffsynth/data/video.py: 
-------------------------------------------------------------------------------- 1 | import imageio, os 2 | import numpy as np 3 | from PIL import Image 4 | from tqdm import tqdm 5 | 6 | 7 | class LowMemoryVideo: 8 | def __init__(self, file_name): 9 | self.reader = imageio.get_reader(file_name) 10 | 11 | def __len__(self): 12 | return self.reader.count_frames() 13 | 14 | def __getitem__(self, item): 15 | return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB") 16 | 17 | def __del__(self): 18 | self.reader.close() 19 | 20 | 21 | def split_file_name(file_name): 22 | result = [] 23 | number = -1 24 | for i in file_name: 25 | if ord(i)>=ord("0") and ord(i)<=ord("9"): 26 | if number == -1: 27 | number = 0 28 | number = number*10 + ord(i) - ord("0") 29 | else: 30 | if number != -1: 31 | result.append(number) 32 | number = -1 33 | result.append(i) 34 | if number != -1: 35 | result.append(number) 36 | result = tuple(result) 37 | return result 38 | 39 | 40 | def search_for_images(folder): 41 | file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")] 42 | file_list = [(split_file_name(file_name), file_name) for file_name in file_list] 43 | file_list = [i[1] for i in sorted(file_list)] 44 | file_list = [os.path.join(folder, i) for i in file_list] 45 | return file_list 46 | 47 | 48 | class LowMemoryImageFolder: 49 | def __init__(self, folder, file_list=None): 50 | if file_list is None: 51 | self.file_list = search_for_images(folder) 52 | else: 53 | self.file_list = [os.path.join(folder, file_name) for file_name in file_list] 54 | 55 | def __len__(self): 56 | return len(self.file_list) 57 | 58 | def __getitem__(self, item): 59 | return Image.open(self.file_list[item]).convert("RGB") 60 | 61 | def __del__(self): 62 | pass 63 | 64 | 65 | def crop_and_resize(image, height, width): 66 | image = np.array(image) 67 | image_height, image_width, _ = image.shape 68 | if image_height / image_width < height / width: 69 | croped_width = int(image_height / height * width) 70 | left = (image_width - croped_width) // 2 71 | image = image[:, left: left+croped_width] 72 | image = Image.fromarray(image).resize((width, height)) 73 | else: 74 | croped_height = int(image_width / width * height) 75 | left = (image_height - croped_height) // 2 76 | image = image[left: left+croped_height, :] 77 | image = Image.fromarray(image).resize((width, height)) 78 | return image 79 | 80 | 81 | class VideoData: 82 | def __init__(self, video_file=None, image_folder=None, height=None, width=None, **kwargs): 83 | if video_file is not None: 84 | self.data_type = "video" 85 | self.data = LowMemoryVideo(video_file, **kwargs) 86 | elif image_folder is not None: 87 | self.data_type = "images" 88 | self.data = LowMemoryImageFolder(image_folder, **kwargs) 89 | else: 90 | raise ValueError("Cannot open video or image folder") 91 | self.length = None 92 | self.set_shape(height, width) 93 | 94 | def raw_data(self): 95 | frames = [] 96 | for i in range(self.__len__()): 97 | frames.append(self.__getitem__(i)) 98 | return frames 99 | 100 | def set_length(self, length): 101 | self.length = length 102 | 103 | def set_shape(self, height, width): 104 | self.height = height 105 | self.width = width 106 | 107 | def __len__(self): 108 | if self.length is None: 109 | return len(self.data) 110 | else: 111 | return self.length 112 | 113 | def shape(self): 114 | if self.height is not None and self.width is not None: 115 | return self.height, self.width 116 | else: 117 | height, width = self.__getitem__(0).size 118 | 
return height, width 119 | 120 | def __getitem__(self, item): 121 | frame = self.data.__getitem__(item) 122 | width, height = frame.size 123 | if self.height is not None and self.width is not None: 124 | if self.height != height or self.width != width: 125 | frame = crop_and_resize(frame, self.height, self.width) 126 | return frame 127 | 128 | def __del__(self): 129 | pass 130 | 131 | def save_images(self, folder): 132 | os.makedirs(folder, exist_ok=True) 133 | for i in tqdm(range(self.__len__()), desc="Saving images"): 134 | frame = self.__getitem__(i) 135 | frame.save(os.path.join(folder, f"{i}.png")) 136 | 137 | 138 | def save_video(frames, save_path, fps, quality=9): 139 | writer = imageio.get_writer(save_path, fps=fps, quality=quality) 140 | for frame in tqdm(frames, desc="Saving video"): 141 | frame = np.array(frame) 142 | writer.append_data(frame) 143 | writer.close() 144 | 145 | def save_frames(frames, save_path): 146 | os.makedirs(save_path, exist_ok=True) 147 | for i, frame in enumerate(tqdm(frames, desc="Saving images")): 148 | frame.save(os.path.join(save_path, f"{i}.png")) 149 | -------------------------------------------------------------------------------- /diffsynth/extensions/ESRGAN/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import repeat 3 | from PIL import Image 4 | import numpy as np 5 | 6 | 7 | class ResidualDenseBlock(torch.nn.Module): 8 | 9 | def __init__(self, num_feat=64, num_grow_ch=32): 10 | super(ResidualDenseBlock, self).__init__() 11 | self.conv1 = torch.nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) 12 | self.conv2 = torch.nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) 13 | self.conv3 = torch.nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1) 14 | self.conv4 = torch.nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1) 15 | self.conv5 = torch.nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) 16 | self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True) 17 | 18 | def forward(self, x): 19 | x1 = self.lrelu(self.conv1(x)) 20 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) 21 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 22 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) 23 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 24 | return x5 * 0.2 + x 25 | 26 | 27 | class RRDB(torch.nn.Module): 28 | 29 | def __init__(self, num_feat, num_grow_ch=32): 30 | super(RRDB, self).__init__() 31 | self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) 32 | self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) 33 | self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) 34 | 35 | def forward(self, x): 36 | out = self.rdb1(x) 37 | out = self.rdb2(out) 38 | out = self.rdb3(out) 39 | return out * 0.2 + x 40 | 41 | 42 | class RRDBNet(torch.nn.Module): 43 | 44 | def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32): 45 | super(RRDBNet, self).__init__() 46 | self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 47 | self.body = torch.torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)]) 48 | self.conv_body = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1) 49 | # upsample 50 | self.conv_up1 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1) 51 | self.conv_up2 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1) 52 | self.conv_hr = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1) 53 | self.conv_last = torch.nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 54 | self.lrelu = 
torch.nn.LeakyReLU(negative_slope=0.2, inplace=True) 55 | 56 | def forward(self, x): 57 | feat = x 58 | feat = self.conv_first(feat) 59 | body_feat = self.conv_body(self.body(feat)) 60 | feat = feat + body_feat 61 | # upsample 62 | feat = repeat(feat, "B C H W -> B C (H 2) (W 2)") 63 | feat = self.lrelu(self.conv_up1(feat)) 64 | feat = repeat(feat, "B C H W -> B C (H 2) (W 2)") 65 | feat = self.lrelu(self.conv_up2(feat)) 66 | out = self.conv_last(self.lrelu(self.conv_hr(feat))) 67 | return out 68 | 69 | 70 | class ESRGAN(torch.nn.Module): 71 | def __init__(self, model): 72 | super().__init__() 73 | self.model = model 74 | 75 | @staticmethod 76 | def from_pretrained(model_path): 77 | model = RRDBNet() 78 | state_dict = torch.load(model_path, map_location="cpu")["params_ema"] 79 | model.load_state_dict(state_dict) 80 | model.eval() 81 | return ESRGAN(model) 82 | 83 | def process_image(self, image): 84 | image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1) 85 | return image 86 | 87 | def process_images(self, images): 88 | images = [self.process_image(image) for image in images] 89 | images = torch.stack(images) 90 | return images 91 | 92 | def decode_images(self, images): 93 | images = (images.permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8) 94 | images = [Image.fromarray(image) for image in images] 95 | return images 96 | 97 | @torch.no_grad() 98 | def upscale(self, images, batch_size=4, progress_bar=lambda x:x): 99 | # Preprocess 100 | input_tensor = self.process_images(images) 101 | 102 | # Interpolate 103 | output_tensor = [] 104 | for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)): 105 | batch_id_ = min(batch_id + batch_size, input_tensor.shape[0]) 106 | batch_input_tensor = input_tensor[batch_id: batch_id_] 107 | batch_input_tensor = batch_input_tensor.to( 108 | device=self.model.conv_first.weight.device, 109 | dtype=self.model.conv_first.weight.dtype) 110 | batch_output_tensor = self.model(batch_input_tensor) 111 | output_tensor.append(batch_output_tensor.cpu()) 112 | 113 | # Output 114 | output_tensor = torch.concat(output_tensor, dim=0) 115 | 116 | # To images 117 | output_images = self.decode_images(output_tensor) 118 | return output_images 119 | -------------------------------------------------------------------------------- /diffsynth/extensions/FastBlend/__init__.py: -------------------------------------------------------------------------------- 1 | from .runners.fast import TableManager, PyramidPatchMatcher 2 | from PIL import Image 3 | import numpy as np 4 | import cupy as cp 5 | 6 | 7 | class FastBlendSmoother: 8 | def __init__(self): 9 | self.batch_size = 8 10 | self.window_size = 64 11 | self.ebsynth_config = { 12 | "minimum_patch_size": 5, 13 | "threads_per_block": 8, 14 | "num_iter": 5, 15 | "gpu_id": 0, 16 | "guide_weight": 10.0, 17 | "initialize": "identity", 18 | "tracking_window_size": 0, 19 | } 20 | 21 | @staticmethod 22 | def from_model_manager(model_manager): 23 | # TODO: fetch GPU ID from model_manager 24 | return FastBlendSmoother() 25 | 26 | def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config): 27 | frames_guide = [np.array(frame) for frame in frames_guide] 28 | frames_style = [np.array(frame) for frame in frames_style] 29 | table_manager = TableManager() 30 | patch_match_engine = PyramidPatchMatcher( 31 | image_height=frames_style[0].shape[0], 32 | image_width=frames_style[0].shape[1], 33 | channel=3, 34 | **ebsynth_config 35 | ) 36 | # left part 37 | table_l = 
table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="FastBlend Step 1/4") 38 | table_l = table_manager.remapping_table_to_blending_table(table_l) 39 | table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="FastBlend Step 2/4") 40 | # right part 41 | table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="FastBlend Step 3/4") 42 | table_r = table_manager.remapping_table_to_blending_table(table_r) 43 | table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="FastBlend Step 4/4")[::-1] 44 | # merge 45 | frames = [] 46 | for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r): 47 | weight_m = -1 48 | weight = weight_l + weight_m + weight_r 49 | frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight) 50 | frames.append(frame) 51 | frames = [Image.fromarray(frame.clip(0, 255).astype("uint8")) for frame in frames] 52 | return frames 53 | 54 | def __call__(self, rendered_frames, original_frames=None, **kwargs): 55 | frames = self.run( 56 | original_frames, rendered_frames, 57 | self.batch_size, self.window_size, self.ebsynth_config 58 | ) 59 | mempool = cp.get_default_memory_pool() 60 | pinned_mempool = cp.get_default_pinned_memory_pool() 61 | mempool.free_all_blocks() 62 | pinned_mempool.free_all_blocks() 63 | return frames -------------------------------------------------------------------------------- /diffsynth/extensions/FastBlend/cupy_kernels.py: -------------------------------------------------------------------------------- 1 | import cupy as cp 2 | 3 | remapping_kernel = cp.RawKernel(r''' 4 | extern "C" __global__ 5 | void remap( 6 | const int height, 7 | const int width, 8 | const int channel, 9 | const int patch_size, 10 | const int pad_size, 11 | const float* source_style, 12 | const int* nnf, 13 | float* target_style 14 | ) { 15 | const int r = (patch_size - 1) / 2; 16 | const int x = blockDim.x * blockIdx.x + threadIdx.x; 17 | const int y = blockDim.y * blockIdx.y + threadIdx.y; 18 | if (x >= height or y >= width) return; 19 | const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel; 20 | const int pid = (x + pad_size) * (width + pad_size * 2) + (y + pad_size); 21 | const int min_px = x < r ? -x : -r; 22 | const int max_px = x + r > height - 1 ? height - 1 - x : r; 23 | const int min_py = y < r ? -y : -r; 24 | const int max_py = y + r > width - 1 ? 
width - 1 - y : r; 25 | int num = 0; 26 | for (int px = min_px; px <= max_px; px++){ 27 | for (int py = min_py; py <= max_py; py++){ 28 | const int nid = (x + px) * width + y + py; 29 | const int x_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 0] - px; 30 | const int y_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 1] - py; 31 | if (x_ < 0 or y_ < 0 or x_ >= height or y_ >= width)continue; 32 | const int pid_ = (x_ + pad_size) * (width + pad_size * 2) + (y_ + pad_size); 33 | num++; 34 | for (int c = 0; c < channel; c++){ 35 | target_style[z + pid * channel + c] += source_style[z + pid_ * channel + c]; 36 | } 37 | } 38 | } 39 | for (int c = 0; c < channel; c++){ 40 | target_style[z + pid * channel + c] /= num; 41 | } 42 | } 43 | ''', 'remap') 44 | 45 | 46 | patch_error_kernel = cp.RawKernel(r''' 47 | extern "C" __global__ 48 | void patch_error( 49 | const int height, 50 | const int width, 51 | const int channel, 52 | const int patch_size, 53 | const int pad_size, 54 | const float* source, 55 | const int* nnf, 56 | const float* target, 57 | float* error 58 | ) { 59 | const int r = (patch_size - 1) / 2; 60 | const int x = blockDim.x * blockIdx.x + threadIdx.x; 61 | const int y = blockDim.y * blockIdx.y + threadIdx.y; 62 | const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel; 63 | if (x >= height or y >= width) return; 64 | const int x_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 0]; 65 | const int y_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 1]; 66 | float e = 0; 67 | for (int px = -r; px <= r; px++){ 68 | for (int py = -r; py <= r; py++){ 69 | const int pid = (x + pad_size + px) * (width + pad_size * 2) + y + pad_size + py; 70 | const int pid_ = (x_ + pad_size + px) * (width + pad_size * 2) + y_ + pad_size + py; 71 | for (int c = 0; c < channel; c++){ 72 | const float diff = target[z + pid * channel + c] - source[z + pid_ * channel + c]; 73 | e += diff * diff; 74 | } 75 | } 76 | } 77 | error[blockIdx.z * height * width + x * width + y] = e; 78 | } 79 | ''', 'patch_error') 80 | 81 | 82 | pairwise_patch_error_kernel = cp.RawKernel(r''' 83 | extern "C" __global__ 84 | void pairwise_patch_error( 85 | const int height, 86 | const int width, 87 | const int channel, 88 | const int patch_size, 89 | const int pad_size, 90 | const float* source_a, 91 | const int* nnf_a, 92 | const float* source_b, 93 | const int* nnf_b, 94 | float* error 95 | ) { 96 | const int r = (patch_size - 1) / 2; 97 | const int x = blockDim.x * blockIdx.x + threadIdx.x; 98 | const int y = blockDim.y * blockIdx.y + threadIdx.y; 99 | const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel; 100 | if (x >= height or y >= width) return; 101 | const int z_nnf = blockIdx.z * height * width * 2 + (x * width + y) * 2; 102 | const int x_a = nnf_a[z_nnf + 0]; 103 | const int y_a = nnf_a[z_nnf + 1]; 104 | const int x_b = nnf_b[z_nnf + 0]; 105 | const int y_b = nnf_b[z_nnf + 1]; 106 | float e = 0; 107 | for (int px = -r; px <= r; px++){ 108 | for (int py = -r; py <= r; py++){ 109 | const int pid_a = (x_a + pad_size + px) * (width + pad_size * 2) + y_a + pad_size + py; 110 | const int pid_b = (x_b + pad_size + px) * (width + pad_size * 2) + y_b + pad_size + py; 111 | for (int c = 0; c < channel; c++){ 112 | const float diff = source_a[z + pid_a * channel + c] - source_b[z + pid_b * channel + c]; 113 | e += diff * diff; 114 | } 115 | } 116 | } 117 | error[blockIdx.z * height * width + x * width + y] = e; 118 | } 119 | ''', 
'pairwise_patch_error') 120 | -------------------------------------------------------------------------------- /diffsynth/extensions/FastBlend/data.py: -------------------------------------------------------------------------------- 1 | import imageio, os 2 | import numpy as np 3 | from PIL import Image 4 | 5 | 6 | def read_video(file_name): 7 | reader = imageio.get_reader(file_name) 8 | video = [] 9 | for frame in reader: 10 | frame = np.array(frame) 11 | video.append(frame) 12 | reader.close() 13 | return video 14 | 15 | 16 | def get_video_fps(file_name): 17 | reader = imageio.get_reader(file_name) 18 | fps = reader.get_meta_data()["fps"] 19 | reader.close() 20 | return fps 21 | 22 | 23 | def save_video(frames_path, video_path, num_frames, fps): 24 | writer = imageio.get_writer(video_path, fps=fps, quality=9) 25 | for i in range(num_frames): 26 | frame = np.array(Image.open(os.path.join(frames_path, "%05d.png" % i))) 27 | writer.append_data(frame) 28 | writer.close() 29 | return video_path 30 | 31 | 32 | class LowMemoryVideo: 33 | def __init__(self, file_name): 34 | self.reader = imageio.get_reader(file_name) 35 | 36 | def __len__(self): 37 | return self.reader.count_frames() 38 | 39 | def __getitem__(self, item): 40 | return np.array(self.reader.get_data(item)) 41 | 42 | def __del__(self): 43 | self.reader.close() 44 | 45 | 46 | def split_file_name(file_name): 47 | result = [] 48 | number = -1 49 | for i in file_name: 50 | if ord(i)>=ord("0") and ord(i)<=ord("9"): 51 | if number == -1: 52 | number = 0 53 | number = number*10 + ord(i) - ord("0") 54 | else: 55 | if number != -1: 56 | result.append(number) 57 | number = -1 58 | result.append(i) 59 | if number != -1: 60 | result.append(number) 61 | result = tuple(result) 62 | return result 63 | 64 | 65 | def search_for_images(folder): 66 | file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")] 67 | file_list = [(split_file_name(file_name), file_name) for file_name in file_list] 68 | file_list = [i[1] for i in sorted(file_list)] 69 | file_list = [os.path.join(folder, i) for i in file_list] 70 | return file_list 71 | 72 | 73 | def read_images(folder): 74 | file_list = search_for_images(folder) 75 | frames = [np.array(Image.open(i)) for i in file_list] 76 | return frames 77 | 78 | 79 | class LowMemoryImageFolder: 80 | def __init__(self, folder, file_list=None): 81 | if file_list is None: 82 | self.file_list = search_for_images(folder) 83 | else: 84 | self.file_list = [os.path.join(folder, file_name) for file_name in file_list] 85 | 86 | def __len__(self): 87 | return len(self.file_list) 88 | 89 | def __getitem__(self, item): 90 | return np.array(Image.open(self.file_list[item])) 91 | 92 | def __del__(self): 93 | pass 94 | 95 | 96 | class VideoData: 97 | def __init__(self, video_file, image_folder, **kwargs): 98 | if video_file is not None: 99 | self.data_type = "video" 100 | self.data = LowMemoryVideo(video_file, **kwargs) 101 | elif image_folder is not None: 102 | self.data_type = "images" 103 | self.data = LowMemoryImageFolder(image_folder, **kwargs) 104 | else: 105 | raise ValueError("Cannot open video or image folder") 106 | self.length = None 107 | self.height = None 108 | self.width = None 109 | 110 | def raw_data(self): 111 | frames = [] 112 | for i in range(self.__len__()): 113 | frames.append(self.__getitem__(i)) 114 | return frames 115 | 116 | def set_length(self, length): 117 | self.length = length 118 | 119 | def set_shape(self, height, width): 120 | self.height = height 121 | self.width = width 122 
| 123 | def __len__(self): 124 | if self.length is None: 125 | return len(self.data) 126 | else: 127 | return self.length 128 | 129 | def shape(self): 130 | if self.height is not None and self.width is not None: 131 | return self.height, self.width 132 | else: 133 | height, width, _ = self.__getitem__(0).shape 134 | return height, width 135 | 136 | def __getitem__(self, item): 137 | frame = self.data.__getitem__(item) 138 | height, width, _ = frame.shape 139 | if self.height is not None and self.width is not None: 140 | if self.height != height or self.width != width: 141 | frame = Image.fromarray(frame).resize((self.width, self.height)) 142 | frame = np.array(frame) 143 | return frame 144 | 145 | def __del__(self): 146 | pass 147 | -------------------------------------------------------------------------------- /diffsynth/extensions/FastBlend/runners/__init__.py: -------------------------------------------------------------------------------- 1 | from .accurate import AccurateModeRunner 2 | from .fast import FastModeRunner 3 | from .balanced import BalancedModeRunner 4 | from .interpolation import InterpolationModeRunner, InterpolationModeSingleFrameRunner 5 | -------------------------------------------------------------------------------- /diffsynth/extensions/FastBlend/runners/accurate.py: -------------------------------------------------------------------------------- 1 | from ..patch_match import PyramidPatchMatcher 2 | import os 3 | import numpy as np 4 | from PIL import Image 5 | from tqdm import tqdm 6 | 7 | 8 | class AccurateModeRunner: 9 | def __init__(self): 10 | pass 11 | 12 | def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Accurate Mode", save_path=None): 13 | patch_match_engine = PyramidPatchMatcher( 14 | image_height=frames_style[0].shape[0], 15 | image_width=frames_style[0].shape[1], 16 | channel=3, 17 | use_mean_target_style=True, 18 | **ebsynth_config 19 | ) 20 | # run 21 | n = len(frames_style) 22 | for target in tqdm(range(n), desc=desc): 23 | l, r = max(target - window_size, 0), min(target + window_size + 1, n) 24 | remapped_frames = [] 25 | for i in range(l, r, batch_size): 26 | j = min(i + batch_size, r) 27 | source_guide = np.stack([frames_guide[source] for source in range(i, j)]) 28 | target_guide = np.stack([frames_guide[target]] * (j - i)) 29 | source_style = np.stack([frames_style[source] for source in range(i, j)]) 30 | _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style) 31 | remapped_frames.append(target_style) 32 | frame = np.concatenate(remapped_frames, axis=0).mean(axis=0) 33 | frame = frame.clip(0, 255).astype("uint8") 34 | if save_path is not None: 35 | Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target)) -------------------------------------------------------------------------------- /diffsynth/extensions/FastBlend/runners/balanced.py: -------------------------------------------------------------------------------- 1 | from ..patch_match import PyramidPatchMatcher 2 | import os 3 | import numpy as np 4 | from PIL import Image 5 | from tqdm import tqdm 6 | 7 | 8 | class BalancedModeRunner: 9 | def __init__(self): 10 | pass 11 | 12 | def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Balanced Mode", save_path=None): 13 | patch_match_engine = PyramidPatchMatcher( 14 | image_height=frames_style[0].shape[0], 15 | image_width=frames_style[0].shape[1], 16 | channel=3, 17 | **ebsynth_config 18 | ) 19 | # tasks 20 | n = 
len(frames_style) 21 | tasks = [] 22 | for target in range(n): 23 | for source in range(target - window_size, target + window_size + 1): 24 | if source >= 0 and source < n and source != target: 25 | tasks.append((source, target)) 26 | # run 27 | frames = [(None, 1) for i in range(n)] 28 | for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc): 29 | tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))] 30 | source_guide = np.stack([frames_guide[source] for source, target in tasks_batch]) 31 | target_guide = np.stack([frames_guide[target] for source, target in tasks_batch]) 32 | source_style = np.stack([frames_style[source] for source, target in tasks_batch]) 33 | _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style) 34 | for (source, target), result in zip(tasks_batch, target_style): 35 | frame, weight = frames[target] 36 | if frame is None: 37 | frame = frames_style[target] 38 | frames[target] = ( 39 | frame * (weight / (weight + 1)) + result / (weight + 1), 40 | weight + 1 41 | ) 42 | if weight + 1 == min(n, target + window_size + 1) - max(0, target - window_size): 43 | frame = frame.clip(0, 255).astype("uint8") 44 | if save_path is not None: 45 | Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target)) 46 | frames[target] = (None, 1) 47 | -------------------------------------------------------------------------------- /diffsynth/extensions/FastBlend/runners/fast.py: -------------------------------------------------------------------------------- 1 | from ..patch_match import PyramidPatchMatcher 2 | import functools, os 3 | import numpy as np 4 | from PIL import Image 5 | from tqdm import tqdm 6 | 7 | 8 | class TableManager: 9 | def __init__(self): 10 | pass 11 | 12 | def task_list(self, n): 13 | tasks = [] 14 | max_level = 1 15 | while (1<=n: 24 | break 25 | meta_data = { 26 | "source": i, 27 | "target": j, 28 | "level": level + 1 29 | } 30 | tasks.append(meta_data) 31 | tasks.sort(key=functools.cmp_to_key(lambda u, v: u["level"]-v["level"])) 32 | return tasks 33 | 34 | def build_remapping_table(self, frames_guide, frames_style, patch_match_engine, batch_size, desc=""): 35 | n = len(frames_guide) 36 | tasks = self.task_list(n) 37 | remapping_table = [[(frames_style[i], 1)] for i in range(n)] 38 | for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc): 39 | tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))] 40 | source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch]) 41 | target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch]) 42 | source_style = np.stack([frames_style[task["source"]] for task in tasks_batch]) 43 | _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style) 44 | for task, result in zip(tasks_batch, target_style): 45 | target, level = task["target"], task["level"] 46 | if len(remapping_table[target])==level: 47 | remapping_table[target].append((result, 1)) 48 | else: 49 | frame, weight = remapping_table[target][level] 50 | remapping_table[target][level] = ( 51 | frame * (weight / (weight + 1)) + result / (weight + 1), 52 | weight + 1 53 | ) 54 | return remapping_table 55 | 56 | def remapping_table_to_blending_table(self, table): 57 | for i in range(len(table)): 58 | for j in range(1, len(table[i])): 59 | frame_1, weight_1 = table[i][j-1] 60 | frame_2, weight_2 = table[i][j] 61 | frame = (frame_1 + frame_2) / 2 62 | weight = weight_1 + weight_2 63 | table[i][j] = (frame, weight) 64 | 
return table 65 | 66 | def tree_query(self, leftbound, rightbound): 67 | node_list = [] 68 | node_index = rightbound 69 | while node_index>=leftbound: 70 | node_level = 0 71 | while (1<=leftbound: 72 | node_level += 1 73 | node_list.append((node_index, node_level)) 74 | node_index -= 1<0: 31 | tasks = [] 32 | for m in range(index_style[0]): 33 | tasks.append((index_style[0], m, index_style[0])) 34 | task_group.append(tasks) 35 | # middle frames 36 | for l, r in zip(index_style[:-1], index_style[1:]): 37 | tasks = [] 38 | for m in range(l, r): 39 | tasks.append((l, m, r)) 40 | task_group.append(tasks) 41 | # last frame 42 | tasks = [] 43 | for m in range(index_style[-1], n): 44 | tasks.append((index_style[-1], m, index_style[-1])) 45 | task_group.append(tasks) 46 | return task_group 47 | 48 | def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None): 49 | patch_match_engine = PyramidPatchMatcher( 50 | image_height=frames_style[0].shape[0], 51 | image_width=frames_style[0].shape[1], 52 | channel=3, 53 | use_mean_target_style=False, 54 | use_pairwise_patch_error=True, 55 | **ebsynth_config 56 | ) 57 | # task 58 | index_dict = self.get_index_dict(index_style) 59 | task_group = self.get_task_group(index_style, len(frames_guide)) 60 | # run 61 | for tasks in task_group: 62 | index_start, index_end = min([i[1] for i in tasks]), max([i[1] for i in tasks]) 63 | for batch_id in tqdm(range(0, len(tasks), batch_size), desc=f"Rendering frames {index_start}...{index_end}"): 64 | tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))] 65 | source_guide, target_guide, source_style = [], [], [] 66 | for l, m, r in tasks_batch: 67 | # l -> m 68 | source_guide.append(frames_guide[l]) 69 | target_guide.append(frames_guide[m]) 70 | source_style.append(frames_style[index_dict[l]]) 71 | # r -> m 72 | source_guide.append(frames_guide[r]) 73 | target_guide.append(frames_guide[m]) 74 | source_style.append(frames_style[index_dict[r]]) 75 | source_guide = np.stack(source_guide) 76 | target_guide = np.stack(target_guide) 77 | source_style = np.stack(source_style) 78 | _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style) 79 | if save_path is not None: 80 | for frame_l, frame_r, (l, m, r) in zip(target_style[0::2], target_style[1::2], tasks_batch): 81 | weight_l, weight_r = self.get_weight(l, m, r) 82 | frame = frame_l * weight_l + frame_r * weight_r 83 | frame = frame.clip(0, 255).astype("uint8") 84 | Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % m)) 85 | 86 | 87 | class InterpolationModeSingleFrameRunner: 88 | def __init__(self): 89 | pass 90 | 91 | def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None): 92 | # check input 93 | tracking_window_size = ebsynth_config["tracking_window_size"] 94 | if tracking_window_size * 2 >= batch_size: 95 | raise ValueError("batch_size should be larger than track_window_size * 2") 96 | frame_style = frames_style[0] 97 | frame_guide = frames_guide[index_style[0]] 98 | patch_match_engine = PyramidPatchMatcher( 99 | image_height=frame_style.shape[0], 100 | image_width=frame_style.shape[1], 101 | channel=3, 102 | **ebsynth_config 103 | ) 104 | # run 105 | frame_id, n = 0, len(frames_guide) 106 | for i in tqdm(range(0, n, batch_size - tracking_window_size * 2), desc=f"Rendering frames 0...{n}"): 107 | if i + batch_size > n: 108 | l, r = max(n - batch_size, 0), n 109 | else: 110 | l, r = i, i + batch_size 111 | source_guide = 
np.stack([frame_guide] * (r-l)) 112 | target_guide = np.stack([frames_guide[i] for i in range(l, r)]) 113 | source_style = np.stack([frame_style] * (r-l)) 114 | _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style) 115 | for i, frame in zip(range(l, r), target_style): 116 | if i==frame_id: 117 | frame = frame.clip(0, 255).astype("uint8") 118 | Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % frame_id)) 119 | frame_id += 1 120 | if r < n and r-frame_id <= tracking_window_size: 121 | break 122 | -------------------------------------------------------------------------------- /diffsynth/extensions/RIFE/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from PIL import Image 6 | 7 | 8 | def warp(tenInput, tenFlow, device): 9 | backwarp_tenGrid = {} 10 | k = (str(tenFlow.device), str(tenFlow.size())) 11 | if k not in backwarp_tenGrid: 12 | tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view( 13 | 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1) 14 | tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view( 15 | 1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3]) 16 | backwarp_tenGrid[k] = torch.cat( 17 | [tenHorizontal, tenVertical], 1).to(device) 18 | 19 | tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), 20 | tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1) 21 | 22 | g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1) 23 | return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True) 24 | 25 | 26 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): 27 | return nn.Sequential( 28 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, 29 | padding=padding, dilation=dilation, bias=True), 30 | nn.PReLU(out_planes) 31 | ) 32 | 33 | 34 | class IFBlock(nn.Module): 35 | def __init__(self, in_planes, c=64): 36 | super(IFBlock, self).__init__() 37 | self.conv0 = nn.Sequential(conv(in_planes, c//2, 3, 2, 1), conv(c//2, c, 3, 2, 1),) 38 | self.convblock0 = nn.Sequential(conv(c, c), conv(c, c)) 39 | self.convblock1 = nn.Sequential(conv(c, c), conv(c, c)) 40 | self.convblock2 = nn.Sequential(conv(c, c), conv(c, c)) 41 | self.convblock3 = nn.Sequential(conv(c, c), conv(c, c)) 42 | self.conv1 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 4, 4, 2, 1)) 43 | self.conv2 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 1, 4, 2, 1)) 44 | 45 | def forward(self, x, flow, scale=1): 46 | x = F.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) 47 | flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 1. 
/ scale 48 | feat = self.conv0(torch.cat((x, flow), 1)) 49 | feat = self.convblock0(feat) + feat 50 | feat = self.convblock1(feat) + feat 51 | feat = self.convblock2(feat) + feat 52 | feat = self.convblock3(feat) + feat 53 | flow = self.conv1(feat) 54 | mask = self.conv2(feat) 55 | flow = F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * scale 56 | mask = F.interpolate(mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) 57 | return flow, mask 58 | 59 | 60 | class IFNet(nn.Module): 61 | def __init__(self): 62 | super(IFNet, self).__init__() 63 | self.block0 = IFBlock(7+4, c=90) 64 | self.block1 = IFBlock(7+4, c=90) 65 | self.block2 = IFBlock(7+4, c=90) 66 | self.block_tea = IFBlock(10+4, c=90) 67 | 68 | def forward(self, x, scale_list=[4, 2, 1], training=False): 69 | if training == False: 70 | channel = x.shape[1] // 2 71 | img0 = x[:, :channel] 72 | img1 = x[:, channel:] 73 | flow_list = [] 74 | merged = [] 75 | mask_list = [] 76 | warped_img0 = img0 77 | warped_img1 = img1 78 | flow = (x[:, :4]).detach() * 0 79 | mask = (x[:, :1]).detach() * 0 80 | block = [self.block0, self.block1, self.block2] 81 | for i in range(3): 82 | f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i]) 83 | f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i]) 84 | flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2 85 | mask = mask + (m0 + (-m1)) / 2 86 | mask_list.append(mask) 87 | flow_list.append(flow) 88 | warped_img0 = warp(img0, flow[:, :2], device=x.device) 89 | warped_img1 = warp(img1, flow[:, 2:4], device=x.device) 90 | merged.append((warped_img0, warped_img1)) 91 | ''' 92 | c0 = self.contextnet(img0, flow[:, :2]) 93 | c1 = self.contextnet(img1, flow[:, 2:4]) 94 | tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1) 95 | res = tmp[:, 1:4] * 2 - 1 96 | ''' 97 | for i in range(3): 98 | mask_list[i] = torch.sigmoid(mask_list[i]) 99 | merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i]) 100 | return flow_list, mask_list[2], merged 101 | 102 | def state_dict_converter(self): 103 | return IFNetStateDictConverter() 104 | 105 | 106 | class IFNetStateDictConverter: 107 | def __init__(self): 108 | pass 109 | 110 | def from_diffusers(self, state_dict): 111 | state_dict_ = {k.replace("module.", ""): v for k, v in state_dict.items()} 112 | return state_dict_ 113 | 114 | def from_civitai(self, state_dict): 115 | return self.from_diffusers(state_dict) 116 | 117 | 118 | class RIFEInterpolater: 119 | def __init__(self, model, device="cuda"): 120 | self.model = model 121 | self.device = device 122 | # IFNet only does not support float16 123 | self.torch_dtype = torch.float32 124 | 125 | @staticmethod 126 | def from_model_manager(model_manager): 127 | return RIFEInterpolater(model_manager.RIFE, device=model_manager.device) 128 | 129 | def process_image(self, image): 130 | width, height = image.size 131 | if width % 32 != 0 or height % 32 != 0: 132 | width = (width + 31) // 32 133 | height = (height + 31) // 32 134 | image = image.resize((width, height)) 135 | image = torch.Tensor(np.array(image, dtype=np.float32)[:, :, [2,1,0]] / 255).permute(2, 0, 1) 136 | return image 137 | 138 | def process_images(self, images): 139 | images = [self.process_image(image) for image in images] 140 | images = torch.stack(images) 141 | return 
images 142 | 143 | def decode_images(self, images): 144 | images = (images[:, [2,1,0]].permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8) 145 | images = [Image.fromarray(image) for image in images] 146 | return images 147 | 148 | def add_interpolated_images(self, images, interpolated_images): 149 | output_images = [] 150 | for image, interpolated_image in zip(images, interpolated_images): 151 | output_images.append(image) 152 | output_images.append(interpolated_image) 153 | output_images.append(images[-1]) 154 | return output_images 155 | 156 | 157 | @torch.no_grad() 158 | def interpolate_(self, images, scale=1.0): 159 | input_tensor = self.process_images(images) 160 | input_tensor = torch.cat((input_tensor[:-1], input_tensor[1:]), dim=1) 161 | input_tensor = input_tensor.to(device=self.device, dtype=self.torch_dtype) 162 | flow, mask, merged = self.model(input_tensor, [4/scale, 2/scale, 1/scale]) 163 | output_images = self.decode_images(merged[2].cpu()) 164 | if output_images[0].size != images[0].size: 165 | output_images = [image.resize(images[0].size) for image in output_images] 166 | return output_images 167 | 168 | 169 | @torch.no_grad() 170 | def interpolate(self, images, scale=1.0, batch_size=4, num_iter=1, progress_bar=lambda x:x): 171 | # Preprocess 172 | processed_images = self.process_images(images) 173 | 174 | for iter in range(num_iter): 175 | # Input 176 | input_tensor = torch.cat((processed_images[:-1], processed_images[1:]), dim=1) 177 | 178 | # Interpolate 179 | output_tensor = [] 180 | for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)): 181 | batch_id_ = min(batch_id + batch_size, input_tensor.shape[0]) 182 | batch_input_tensor = input_tensor[batch_id: batch_id_] 183 | batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype) 184 | flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale]) 185 | output_tensor.append(merged[2].cpu()) 186 | 187 | # Output 188 | output_tensor = torch.concat(output_tensor, dim=0).clip(0, 1) 189 | processed_images = self.add_interpolated_images(processed_images, output_tensor) 190 | processed_images = torch.stack(processed_images) 191 | 192 | # To images 193 | output_images = self.decode_images(processed_images) 194 | if output_images[0].size != images[0].size: 195 | output_images = [image.resize(images[0].size) for image in output_images] 196 | return output_images 197 | 198 | 199 | class RIFESmoother(RIFEInterpolater): 200 | def __init__(self, model, device="cuda"): 201 | super(RIFESmoother, self).__init__(model, device=device) 202 | 203 | @staticmethod 204 | def from_model_manager(model_manager): 205 | return RIFESmoother(model_manager.RIFE, device=model_manager.device) 206 | 207 | def process_tensors(self, input_tensor, scale=1.0, batch_size=4): 208 | output_tensor = [] 209 | for batch_id in range(0, input_tensor.shape[0], batch_size): 210 | batch_id_ = min(batch_id + batch_size, input_tensor.shape[0]) 211 | batch_input_tensor = input_tensor[batch_id: batch_id_] 212 | batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype) 213 | flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale]) 214 | output_tensor.append(merged[2].cpu()) 215 | output_tensor = torch.concat(output_tensor, dim=0) 216 | return output_tensor 217 | 218 | @torch.no_grad() 219 | def __call__(self, rendered_frames, scale=1.0, batch_size=4, num_iter=1, **kwargs): 220 | # Preprocess 221 | processed_images = 
self.process_images(rendered_frames) 222 | 223 | for iter in range(num_iter): 224 | # Input 225 | input_tensor = torch.cat((processed_images[:-2], processed_images[2:]), dim=1) 226 | 227 | # Interpolate 228 | output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size) 229 | 230 | # Blend 231 | input_tensor = torch.cat((processed_images[1:-1], output_tensor), dim=1) 232 | output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size) 233 | 234 | # Add to frames 235 | processed_images[1:-1] = output_tensor 236 | 237 | # To images 238 | output_images = self.decode_images(processed_images) 239 | if output_images[0].size != rendered_frames[0].size: 240 | output_images = [image.resize(rendered_frames[0].size) for image in output_images] 241 | return output_images 242 | -------------------------------------------------------------------------------- /diffsynth/models/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | 4 | 5 | def low_version_attention(query, key, value, attn_bias=None): 6 | scale = 1 / query.shape[-1] ** 0.5 7 | query = query * scale 8 | attn = torch.matmul(query, key.transpose(-2, -1)) 9 | if attn_bias is not None: 10 | attn = attn + attn_bias 11 | attn = attn.softmax(-1) 12 | return attn @ value 13 | 14 | 15 | class Attention(torch.nn.Module): 16 | 17 | def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False): 18 | super().__init__() 19 | dim_inner = head_dim * num_heads 20 | kv_dim = kv_dim if kv_dim is not None else q_dim 21 | self.num_heads = num_heads 22 | self.head_dim = head_dim 23 | 24 | self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q) 25 | self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) 26 | self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) 27 | self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out) 28 | 29 | def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0): 30 | batch_size = q.shape[0] 31 | ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) 32 | ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) 33 | ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v) 34 | hidden_states = hidden_states + scale * ip_hidden_states 35 | return hidden_states 36 | 37 | def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None): 38 | if encoder_hidden_states is None: 39 | encoder_hidden_states = hidden_states 40 | 41 | batch_size = encoder_hidden_states.shape[0] 42 | 43 | q = self.to_q(hidden_states) 44 | k = self.to_k(encoder_hidden_states) 45 | v = self.to_v(encoder_hidden_states) 46 | 47 | q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) 48 | k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) 49 | v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) 50 | 51 | if qkv_preprocessor is not None: 52 | q, k, v = qkv_preprocessor(q, k, v) 53 | 54 | hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) 55 | if ipadapter_kwargs is not None: 56 | hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs) 57 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) 58 | hidden_states = 
hidden_states.to(q.dtype) 59 | 60 | hidden_states = self.to_out(hidden_states) 61 | 62 | return hidden_states 63 | 64 | def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None): 65 | if encoder_hidden_states is None: 66 | encoder_hidden_states = hidden_states 67 | 68 | q = self.to_q(hidden_states) 69 | k = self.to_k(encoder_hidden_states) 70 | v = self.to_v(encoder_hidden_states) 71 | 72 | q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads) 73 | k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads) 74 | v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads) 75 | 76 | if attn_mask is not None: 77 | hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask) 78 | else: 79 | import xformers.ops as xops 80 | hidden_states = xops.memory_efficient_attention(q, k, v) 81 | hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads) 82 | 83 | hidden_states = hidden_states.to(q.dtype) 84 | hidden_states = self.to_out(hidden_states) 85 | 86 | return hidden_states 87 | 88 | def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None): 89 | return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor) -------------------------------------------------------------------------------- /diffsynth/models/downloader.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import hf_hub_download 2 | from http.cookiejar import CookieJar 3 | from pathlib import Path 4 | from typing import Dict, Optional, List, Union 5 | import copy, uuid, requests, io, platform, pickle, os, urllib 6 | from requests.adapters import Retry 7 | from tqdm import tqdm 8 | 9 | 10 | def _get_sep(path): 11 | if isinstance(path, bytes): 12 | return b'/' 13 | else: 14 | return '/' 15 | 16 | 17 | def expanduser(path): 18 | """Expand ~ and ~user constructions. 
If user or $HOME is unknown, 19 | do nothing.""" 20 | path = os.fspath(path) 21 | if isinstance(path, bytes): 22 | tilde = b'~' 23 | else: 24 | tilde = '~' 25 | if not path.startswith(tilde): 26 | return path 27 | sep = _get_sep(path) 28 | i = path.find(sep, 1) 29 | if i < 0: 30 | i = len(path) 31 | if i == 1: 32 | if 'HOME' not in os.environ: 33 | import pwd 34 | try: 35 | userhome = pwd.getpwuid(os.getuid()).pw_dir 36 | except KeyError: 37 | # bpo-10496: if the current user identifier doesn't exist in the 38 | # password database, return the path unchanged 39 | return path 40 | else: 41 | userhome = os.environ['HOME'] 42 | else: 43 | import pwd 44 | name = path[1:i] 45 | if isinstance(name, bytes): 46 | name = str(name, 'ASCII') 47 | try: 48 | pwent = pwd.getpwnam(name) 49 | except KeyError: 50 | # bpo-10496: if the user name from the path doesn't exist in the 51 | # password database, return the path unchanged 52 | return path 53 | userhome = pwent.pw_dir 54 | if isinstance(path, bytes): 55 | userhome = os.fsencode(userhome) 56 | root = b'/' 57 | else: 58 | root = '/' 59 | userhome = userhome.rstrip(root) 60 | return (userhome + path[i:]) or root 61 | 62 | 63 | 64 | class ModelScopeConfig: 65 | DEFAULT_CREDENTIALS_PATH = Path.home().joinpath('.modelscope', 'credentials') 66 | path_credential = expanduser(DEFAULT_CREDENTIALS_PATH) 67 | COOKIES_FILE_NAME = 'cookies' 68 | GIT_TOKEN_FILE_NAME = 'git_token' 69 | USER_INFO_FILE_NAME = 'user' 70 | USER_SESSION_ID_FILE_NAME = 'session' 71 | 72 | @staticmethod 73 | def make_sure_credential_path_exist(): 74 | os.makedirs(ModelScopeConfig.path_credential, exist_ok=True) 75 | 76 | @staticmethod 77 | def get_user_session_id(): 78 | session_path = os.path.join(ModelScopeConfig.path_credential, 79 | ModelScopeConfig.USER_SESSION_ID_FILE_NAME) 80 | session_id = '' 81 | if os.path.exists(session_path): 82 | with open(session_path, 'rb') as f: 83 | session_id = str(f.readline().strip(), encoding='utf-8') 84 | return session_id 85 | if session_id == '' or len(session_id) != 32: 86 | session_id = str(uuid.uuid4().hex) 87 | ModelScopeConfig.make_sure_credential_path_exist() 88 | with open(session_path, 'w+') as wf: 89 | wf.write(session_id) 90 | 91 | return session_id 92 | 93 | @staticmethod 94 | def get_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str: 95 | """Formats a user-agent string with basic info about a request. 96 | 97 | Args: 98 | user_agent (`str`, `dict`, *optional*): 99 | The user agent info in the form of a dictionary or a single string. 100 | 101 | Returns: 102 | The formatted user-agent string. 
103 | """ 104 | 105 | # include some more telemetrics when executing in dedicated 106 | # cloud containers 107 | MODELSCOPE_CLOUD_ENVIRONMENT = 'MODELSCOPE_ENVIRONMENT' 108 | MODELSCOPE_CLOUD_USERNAME = 'MODELSCOPE_USERNAME' 109 | env = 'custom' 110 | if MODELSCOPE_CLOUD_ENVIRONMENT in os.environ: 111 | env = os.environ[MODELSCOPE_CLOUD_ENVIRONMENT] 112 | user_name = 'unknown' 113 | if MODELSCOPE_CLOUD_USERNAME in os.environ: 114 | user_name = os.environ[MODELSCOPE_CLOUD_USERNAME] 115 | 116 | ua = 'modelscope/%s; python/%s; session_id/%s; platform/%s; processor/%s; env/%s; user/%s' % ( 117 | "1.15.0", 118 | platform.python_version(), 119 | ModelScopeConfig.get_user_session_id(), 120 | platform.platform(), 121 | platform.processor(), 122 | env, 123 | user_name, 124 | ) 125 | if isinstance(user_agent, dict): 126 | ua += '; ' + '; '.join(f'{k}/{v}' for k, v in user_agent.items()) 127 | elif isinstance(user_agent, str): 128 | ua += '; ' + user_agent 129 | return ua 130 | 131 | @staticmethod 132 | def get_cookies(): 133 | cookies_path = os.path.join(ModelScopeConfig.path_credential, 134 | ModelScopeConfig.COOKIES_FILE_NAME) 135 | if os.path.exists(cookies_path): 136 | with open(cookies_path, 'rb') as f: 137 | cookies = pickle.load(f) 138 | return cookies 139 | return None 140 | 141 | 142 | 143 | def modelscope_http_get_model_file( 144 | url: str, 145 | local_dir: str, 146 | file_name: str, 147 | file_size: int, 148 | cookies: CookieJar, 149 | headers: Optional[Dict[str, str]] = None, 150 | ): 151 | """Download remote file, will retry 5 times before giving up on errors. 152 | 153 | Args: 154 | url(str): 155 | actual download url of the file 156 | local_dir(str): 157 | local directory where the downloaded file stores 158 | file_name(str): 159 | name of the file stored in `local_dir` 160 | file_size(int): 161 | The file size. 162 | cookies(CookieJar): 163 | cookies used to authentication the user, which is used for downloading private repos 164 | headers(Dict[str, str], optional): 165 | http headers to carry necessary info when requesting the remote file 166 | 167 | Raises: 168 | FileDownloadError: File download failed. 169 | 170 | """ 171 | get_headers = {} if headers is None else copy.deepcopy(headers) 172 | get_headers['X-Request-ID'] = str(uuid.uuid4().hex) 173 | temp_file_path = os.path.join(local_dir, file_name) 174 | # retry sleep 0.5s, 1s, 2s, 4s 175 | retry = Retry( 176 | total=5, 177 | backoff_factor=1, 178 | allowed_methods=['GET']) 179 | while True: 180 | try: 181 | progress = tqdm( 182 | unit='B', 183 | unit_scale=True, 184 | unit_divisor=1024, 185 | total=file_size, 186 | initial=0, 187 | desc='Downloading', 188 | ) 189 | partial_length = 0 190 | if os.path.exists( 191 | temp_file_path): # download partial, continue download 192 | with open(temp_file_path, 'rb') as f: 193 | partial_length = f.seek(0, io.SEEK_END) 194 | progress.update(partial_length) 195 | if partial_length > file_size: 196 | break 197 | get_headers['Range'] = 'bytes=%s-%s' % (partial_length, 198 | file_size - 1) 199 | with open(temp_file_path, 'ab') as f: 200 | r = requests.get( 201 | url, 202 | stream=True, 203 | headers=get_headers, 204 | cookies=cookies, 205 | timeout=60) 206 | r.raise_for_status() 207 | for chunk in r.iter_content( 208 | chunk_size=1024 * 1024 * 1): 209 | if chunk: # filter out keep-alive new chunks 210 | progress.update(len(chunk)) 211 | f.write(chunk) 212 | progress.close() 213 | break 214 | except (Exception) as e: # no matter what happen, we will retry. 
215 | retry = retry.increment('GET', url, error=e) 216 | retry.sleep() 217 | 218 | 219 | def get_endpoint(): 220 | MODELSCOPE_URL_SCHEME = 'https://' 221 | DEFAULT_MODELSCOPE_DOMAIN = 'www.modelscope.cn' 222 | modelscope_domain = os.getenv('MODELSCOPE_DOMAIN', 223 | DEFAULT_MODELSCOPE_DOMAIN) 224 | return MODELSCOPE_URL_SCHEME + modelscope_domain 225 | 226 | 227 | def get_file_download_url(model_id: str, file_path: str, revision: str): 228 | """Format file download url according to `model_id`, `revision` and `file_path`. 229 | e.g., Given `model_id=john/bert`, `revision=master`, `file_path=README.md`, 230 | the resulted download url is: https://modelscope.cn/api/v1/models/john/bert/repo?Revision=master&FilePath=README.md 231 | 232 | Args: 233 | model_id (str): The model_id. 234 | file_path (str): File path 235 | revision (str): File revision. 236 | 237 | Returns: 238 | str: The file url. 239 | """ 240 | file_path = urllib.parse.quote_plus(file_path) 241 | revision = urllib.parse.quote_plus(revision) 242 | download_url_template = '{endpoint}/api/v1/models/{model_id}/repo?Revision={revision}&FilePath={file_path}' 243 | return download_url_template.format( 244 | endpoint=get_endpoint(), 245 | model_id=model_id, 246 | revision=revision, 247 | file_path=file_path, 248 | ) 249 | 250 | 251 | def download_from_modelscope(model_id, origin_file_path, local_dir): 252 | os.makedirs(local_dir, exist_ok=True) 253 | if os.path.basename(origin_file_path) in os.listdir(local_dir): 254 | print(f"{os.path.basename(origin_file_path)} has been already in {local_dir}.") 255 | return 256 | else: 257 | print(f"Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}") 258 | headers = {'user-agent': ModelScopeConfig.get_user_agent(user_agent=None)} 259 | cookies = ModelScopeConfig.get_cookies() 260 | url = get_file_download_url(model_id=model_id, file_path=origin_file_path, revision="master") 261 | modelscope_http_get_model_file( 262 | url, 263 | local_dir, 264 | os.path.basename(origin_file_path), 265 | file_size=0, 266 | headers=headers, 267 | cookies=cookies 268 | ) 269 | 270 | 271 | def download_from_huggingface(model_id, origin_file_path, local_dir): 272 | os.makedirs(local_dir, exist_ok=True) 273 | if os.path.basename(origin_file_path) in os.listdir(local_dir): 274 | print(f"{os.path.basename(origin_file_path)} has been already in {local_dir}.") 275 | return 276 | else: 277 | print(f"Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}") 278 | hf_hub_download(model_id, origin_file_path, local_dir=local_dir) 279 | -------------------------------------------------------------------------------- /diffsynth/models/hunyuan_dit_text_encoder.py: -------------------------------------------------------------------------------- 1 | from transformers import BertModel, BertConfig, T5EncoderModel, T5Config 2 | import torch 3 | 4 | 5 | 6 | class HunyuanDiTCLIPTextEncoder(BertModel): 7 | def __init__(self): 8 | config = BertConfig( 9 | _name_or_path = "", 10 | architectures = ["BertModel"], 11 | attention_probs_dropout_prob = 0.1, 12 | bos_token_id = 0, 13 | classifier_dropout = None, 14 | directionality = "bidi", 15 | eos_token_id = 2, 16 | hidden_act = "gelu", 17 | hidden_dropout_prob = 0.1, 18 | hidden_size = 1024, 19 | initializer_range = 0.02, 20 | intermediate_size = 4096, 21 | layer_norm_eps = 1e-12, 22 | max_position_embeddings = 512, 23 | model_type = "bert", 24 | num_attention_heads = 16, 25 | num_hidden_layers = 24, 26 | output_past = True, 27 | pad_token_id = 0, 
28 | pooler_fc_size = 768, 29 | pooler_num_attention_heads = 12, 30 | pooler_num_fc_layers = 3, 31 | pooler_size_per_head = 128, 32 | pooler_type = "first_token_transform", 33 | position_embedding_type = "absolute", 34 | torch_dtype = "float32", 35 | transformers_version = "4.37.2", 36 | type_vocab_size = 2, 37 | use_cache = True, 38 | vocab_size = 47020 39 | ) 40 | super().__init__(config, add_pooling_layer=False) 41 | self.eval() 42 | 43 | def forward(self, input_ids, attention_mask, clip_skip=1): 44 | input_shape = input_ids.size() 45 | 46 | batch_size, seq_length = input_shape 47 | device = input_ids.device 48 | 49 | past_key_values_length = 0 50 | 51 | if attention_mask is None: 52 | attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) 53 | 54 | extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) 55 | 56 | embedding_output = self.embeddings( 57 | input_ids=input_ids, 58 | position_ids=None, 59 | token_type_ids=None, 60 | inputs_embeds=None, 61 | past_key_values_length=0, 62 | ) 63 | encoder_outputs = self.encoder( 64 | embedding_output, 65 | attention_mask=extended_attention_mask, 66 | head_mask=None, 67 | encoder_hidden_states=None, 68 | encoder_attention_mask=None, 69 | past_key_values=None, 70 | use_cache=False, 71 | output_attentions=False, 72 | output_hidden_states=True, 73 | return_dict=True, 74 | ) 75 | all_hidden_states = encoder_outputs.hidden_states 76 | prompt_emb = all_hidden_states[-clip_skip] 77 | if clip_skip > 1: 78 | mean, std = all_hidden_states[-1].mean(), all_hidden_states[-1].std() 79 | prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean 80 | return prompt_emb 81 | 82 | def state_dict_converter(self): 83 | return HunyuanDiTCLIPTextEncoderStateDictConverter() 84 | 85 | 86 | 87 | class HunyuanDiTT5TextEncoder(T5EncoderModel): 88 | def __init__(self): 89 | config = T5Config( 90 | _name_or_path = "../HunyuanDiT/t2i/mt5", 91 | architectures = ["MT5ForConditionalGeneration"], 92 | classifier_dropout = 0.0, 93 | d_ff = 5120, 94 | d_kv = 64, 95 | d_model = 2048, 96 | decoder_start_token_id = 0, 97 | dense_act_fn = "gelu_new", 98 | dropout_rate = 0.1, 99 | eos_token_id = 1, 100 | feed_forward_proj = "gated-gelu", 101 | initializer_factor = 1.0, 102 | is_encoder_decoder = True, 103 | is_gated_act = True, 104 | layer_norm_epsilon = 1e-06, 105 | model_type = "t5", 106 | num_decoder_layers = 24, 107 | num_heads = 32, 108 | num_layers = 24, 109 | output_past = True, 110 | pad_token_id = 0, 111 | relative_attention_max_distance = 128, 112 | relative_attention_num_buckets = 32, 113 | tie_word_embeddings = False, 114 | tokenizer_class = "T5Tokenizer", 115 | transformers_version = "4.37.2", 116 | use_cache = True, 117 | vocab_size = 250112 118 | ) 119 | super().__init__(config) 120 | self.eval() 121 | 122 | def forward(self, input_ids, attention_mask, clip_skip=1): 123 | outputs = super().forward( 124 | input_ids=input_ids, 125 | attention_mask=attention_mask, 126 | output_hidden_states=True, 127 | ) 128 | prompt_emb = outputs.hidden_states[-clip_skip] 129 | if clip_skip > 1: 130 | mean, std = outputs.hidden_states[-1].mean(), outputs.hidden_states[-1].std() 131 | prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean 132 | return prompt_emb 133 | 134 | def state_dict_converter(self): 135 | return HunyuanDiTT5TextEncoderStateDictConverter() 136 | 137 | 138 | 139 | class HunyuanDiTCLIPTextEncoderStateDictConverter(): 140 | def __init__(self): 
141 | pass 142 | 143 | def from_diffusers(self, state_dict): 144 | state_dict_ = {name[5:]: param for name, param in state_dict.items() if name.startswith("bert.")} 145 | return state_dict_ 146 | 147 | def from_civitai(self, state_dict): 148 | return self.from_diffusers(state_dict) 149 | 150 | 151 | class HunyuanDiTT5TextEncoderStateDictConverter(): 152 | def __init__(self): 153 | pass 154 | 155 | def from_diffusers(self, state_dict): 156 | state_dict_ = {name: param for name, param in state_dict.items() if name.startswith("encoder.")} 157 | state_dict_["shared.weight"] = state_dict["shared.weight"] 158 | return state_dict_ 159 | 160 | def from_civitai(self, state_dict): 161 | return self.from_diffusers(state_dict) 162 | -------------------------------------------------------------------------------- /diffsynth/models/sd_ipadapter.py: -------------------------------------------------------------------------------- 1 | from .svd_image_encoder import SVDImageEncoder 2 | from .sdxl_ipadapter import IpAdapterImageProjModel, IpAdapterModule, SDXLIpAdapterStateDictConverter 3 | from transformers import CLIPImageProcessor 4 | import torch 5 | 6 | 7 | class IpAdapterCLIPImageEmbedder(SVDImageEncoder): 8 | def __init__(self): 9 | super().__init__() 10 | self.image_processor = CLIPImageProcessor() 11 | 12 | def forward(self, image): 13 | pixel_values = self.image_processor(images=image, return_tensors="pt").pixel_values 14 | pixel_values = pixel_values.to(device=self.embeddings.class_embedding.device, dtype=self.embeddings.class_embedding.dtype) 15 | return super().forward(pixel_values) 16 | 17 | 18 | class SDIpAdapter(torch.nn.Module): 19 | def __init__(self): 20 | super().__init__() 21 | shape_list = [(768, 320)] * 2 + [(768, 640)] * 2 + [(768, 1280)] * 5 + [(768, 640)] * 3 + [(768, 320)] * 3 + [(768, 1280)] * 1 22 | self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(*shape) for shape in shape_list]) 23 | self.image_proj = IpAdapterImageProjModel(cross_attention_dim=768, clip_embeddings_dim=1024, clip_extra_context_tokens=4) 24 | self.set_full_adapter() 25 | 26 | def set_full_adapter(self): 27 | block_ids = [1, 4, 9, 12, 17, 20, 40, 43, 46, 50, 53, 56, 60, 63, 66, 29] 28 | self.call_block_id = {(i, 0): j for j, i in enumerate(block_ids)} 29 | 30 | def set_less_adapter(self): 31 | # IP-Adapter for SD v1.5 doesn't support this feature. 
32 | self.set_full_adapter() 33 | 34 | def forward(self, hidden_states, scale=1.0): 35 | hidden_states = self.image_proj(hidden_states) 36 | hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1]) 37 | ip_kv_dict = {} 38 | for (block_id, transformer_id) in self.call_block_id: 39 | ipadapter_id = self.call_block_id[(block_id, transformer_id)] 40 | ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states) 41 | if block_id not in ip_kv_dict: 42 | ip_kv_dict[block_id] = {} 43 | ip_kv_dict[block_id][transformer_id] = { 44 | "ip_k": ip_k, 45 | "ip_v": ip_v, 46 | "scale": scale 47 | } 48 | return ip_kv_dict 49 | 50 | def state_dict_converter(self): 51 | return SDIpAdapterStateDictConverter() 52 | 53 | 54 | class SDIpAdapterStateDictConverter(SDXLIpAdapterStateDictConverter): 55 | def __init__(self): 56 | pass 57 | -------------------------------------------------------------------------------- /diffsynth/models/sd_lora.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .sd_unet import SDUNetStateDictConverter, SDUNet 3 | from .sd_text_encoder import SDTextEncoderStateDictConverter, SDTextEncoder 4 | 5 | 6 | class SDLoRA: 7 | def __init__(self): 8 | pass 9 | 10 | def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0, device="cuda"): 11 | special_keys = { 12 | "down.blocks": "down_blocks", 13 | "up.blocks": "up_blocks", 14 | "mid.block": "mid_block", 15 | "proj.in": "proj_in", 16 | "proj.out": "proj_out", 17 | "transformer.blocks": "transformer_blocks", 18 | "to.q": "to_q", 19 | "to.k": "to_k", 20 | "to.v": "to_v", 21 | "to.out": "to_out", 22 | } 23 | state_dict_ = {} 24 | for key in state_dict: 25 | if ".lora_up" not in key: 26 | continue 27 | if not key.startswith(lora_prefix): 28 | continue 29 | weight_up = state_dict[key].to(device="cuda", dtype=torch.float16) 30 | weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16) 31 | if len(weight_up.shape) == 4: 32 | weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32) 33 | weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32) 34 | lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) 35 | else: 36 | lora_weight = alpha * torch.mm(weight_up, weight_down) 37 | target_name = key.split(".")[0].replace("_", ".")[len(lora_prefix):] + ".weight" 38 | for special_key in special_keys: 39 | target_name = target_name.replace(special_key, special_keys[special_key]) 40 | state_dict_[target_name] = lora_weight.cpu() 41 | return state_dict_ 42 | 43 | def add_lora_to_unet(self, unet: SDUNet, state_dict_lora, alpha=1.0, device="cuda"): 44 | state_dict_unet = unet.state_dict() 45 | state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_unet_", alpha=alpha, device=device) 46 | state_dict_lora = SDUNetStateDictConverter().from_diffusers(state_dict_lora) 47 | if len(state_dict_lora) > 0: 48 | for name in state_dict_lora: 49 | state_dict_unet[name] += state_dict_lora[name].to(device=device) 50 | unet.load_state_dict(state_dict_unet) 51 | 52 | def add_lora_to_text_encoder(self, text_encoder: SDTextEncoder, state_dict_lora, alpha=1.0, device="cuda"): 53 | state_dict_text_encoder = text_encoder.state_dict() 54 | state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_te_", alpha=alpha, device=device) 55 | state_dict_lora = SDTextEncoderStateDictConverter().from_diffusers(state_dict_lora) 56 | if len(state_dict_lora) > 0: 57 | for name in 
state_dict_lora: 58 | state_dict_text_encoder[name] += state_dict_lora[name].to(device=device) 59 | text_encoder.load_state_dict(state_dict_text_encoder) 60 | 61 | -------------------------------------------------------------------------------- /diffsynth/models/sd_motion.py: -------------------------------------------------------------------------------- 1 | from .sd_unet import SDUNet, Attention, GEGLU 2 | import torch 3 | from einops import rearrange, repeat 4 | 5 | 6 | class TemporalTransformerBlock(torch.nn.Module): 7 | 8 | def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32): 9 | super().__init__() 10 | 11 | # 1. Self-Attn 12 | self.pe1 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim)) 13 | self.norm1 = torch.nn.LayerNorm(dim, elementwise_affine=True) 14 | self.attn1 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True) 15 | 16 | # 2. Cross-Attn 17 | self.pe2 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim)) 18 | self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=True) 19 | self.attn2 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True) 20 | 21 | # 3. Feed-forward 22 | self.norm3 = torch.nn.LayerNorm(dim, elementwise_affine=True) 23 | self.act_fn = GEGLU(dim, dim * 4) 24 | self.ff = torch.nn.Linear(dim * 4, dim) 25 | 26 | 27 | def forward(self, hidden_states, batch_size=1): 28 | 29 | # 1. Self-Attention 30 | norm_hidden_states = self.norm1(hidden_states) 31 | norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size) 32 | attn_output = self.attn1(norm_hidden_states + self.pe1[:, :norm_hidden_states.shape[1]]) 33 | attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size) 34 | hidden_states = attn_output + hidden_states 35 | 36 | # 2. Cross-Attention 37 | norm_hidden_states = self.norm2(hidden_states) 38 | norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size) 39 | attn_output = self.attn2(norm_hidden_states + self.pe2[:, :norm_hidden_states.shape[1]]) 40 | attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size) 41 | hidden_states = attn_output + hidden_states 42 | 43 | # 3. 
Feed-forward 44 | norm_hidden_states = self.norm3(hidden_states) 45 | ff_output = self.act_fn(norm_hidden_states) 46 | ff_output = self.ff(ff_output) 47 | hidden_states = ff_output + hidden_states 48 | 49 | return hidden_states 50 | 51 | 52 | class TemporalBlock(torch.nn.Module): 53 | 54 | def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5): 55 | super().__init__() 56 | inner_dim = num_attention_heads * attention_head_dim 57 | 58 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True) 59 | self.proj_in = torch.nn.Linear(in_channels, inner_dim) 60 | 61 | self.transformer_blocks = torch.nn.ModuleList([ 62 | TemporalTransformerBlock( 63 | inner_dim, 64 | num_attention_heads, 65 | attention_head_dim 66 | ) 67 | for d in range(num_layers) 68 | ]) 69 | 70 | self.proj_out = torch.nn.Linear(inner_dim, in_channels) 71 | 72 | def forward(self, hidden_states, time_emb, text_emb, res_stack, batch_size=1): 73 | batch, _, height, width = hidden_states.shape 74 | residual = hidden_states 75 | 76 | hidden_states = self.norm(hidden_states) 77 | inner_dim = hidden_states.shape[1] 78 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) 79 | hidden_states = self.proj_in(hidden_states) 80 | 81 | for block in self.transformer_blocks: 82 | hidden_states = block( 83 | hidden_states, 84 | batch_size=batch_size 85 | ) 86 | 87 | hidden_states = self.proj_out(hidden_states) 88 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() 89 | hidden_states = hidden_states + residual 90 | 91 | return hidden_states, time_emb, text_emb, res_stack 92 | 93 | 94 | class SDMotionModel(torch.nn.Module): 95 | def __init__(self): 96 | super().__init__() 97 | self.motion_modules = torch.nn.ModuleList([ 98 | TemporalBlock(8, 40, 320, eps=1e-6), 99 | TemporalBlock(8, 40, 320, eps=1e-6), 100 | TemporalBlock(8, 80, 640, eps=1e-6), 101 | TemporalBlock(8, 80, 640, eps=1e-6), 102 | TemporalBlock(8, 160, 1280, eps=1e-6), 103 | TemporalBlock(8, 160, 1280, eps=1e-6), 104 | TemporalBlock(8, 160, 1280, eps=1e-6), 105 | TemporalBlock(8, 160, 1280, eps=1e-6), 106 | TemporalBlock(8, 160, 1280, eps=1e-6), 107 | TemporalBlock(8, 160, 1280, eps=1e-6), 108 | TemporalBlock(8, 160, 1280, eps=1e-6), 109 | TemporalBlock(8, 160, 1280, eps=1e-6), 110 | TemporalBlock(8, 160, 1280, eps=1e-6), 111 | TemporalBlock(8, 160, 1280, eps=1e-6), 112 | TemporalBlock(8, 160, 1280, eps=1e-6), 113 | TemporalBlock(8, 80, 640, eps=1e-6), 114 | TemporalBlock(8, 80, 640, eps=1e-6), 115 | TemporalBlock(8, 80, 640, eps=1e-6), 116 | TemporalBlock(8, 40, 320, eps=1e-6), 117 | TemporalBlock(8, 40, 320, eps=1e-6), 118 | TemporalBlock(8, 40, 320, eps=1e-6), 119 | ]) 120 | self.call_block_id = { 121 | 1: 0, 122 | 4: 1, 123 | 9: 2, 124 | 12: 3, 125 | 17: 4, 126 | 20: 5, 127 | 24: 6, 128 | 26: 7, 129 | 29: 8, 130 | 32: 9, 131 | 34: 10, 132 | 36: 11, 133 | 40: 12, 134 | 43: 13, 135 | 46: 14, 136 | 50: 15, 137 | 53: 16, 138 | 56: 17, 139 | 60: 18, 140 | 63: 19, 141 | 66: 20 142 | } 143 | 144 | def forward(self): 145 | pass 146 | 147 | def state_dict_converter(self): 148 | return SDMotionModelStateDictConverter() 149 | 150 | 151 | class SDMotionModelStateDictConverter: 152 | def __init__(self): 153 | pass 154 | 155 | def from_diffusers(self, state_dict): 156 | rename_dict = { 157 | "norm": "norm", 158 | "proj_in": "proj_in", 159 | "transformer_blocks.0.attention_blocks.0.to_q": 
"transformer_blocks.0.attn1.to_q", 160 | "transformer_blocks.0.attention_blocks.0.to_k": "transformer_blocks.0.attn1.to_k", 161 | "transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v", 162 | "transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out", 163 | "transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1", 164 | "transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q", 165 | "transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k", 166 | "transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v", 167 | "transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out", 168 | "transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2", 169 | "transformer_blocks.0.norms.0": "transformer_blocks.0.norm1", 170 | "transformer_blocks.0.norms.1": "transformer_blocks.0.norm2", 171 | "transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj", 172 | "transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff", 173 | "transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3", 174 | "proj_out": "proj_out", 175 | } 176 | name_list = sorted([i for i in state_dict if i.startswith("down_blocks.")]) 177 | name_list += sorted([i for i in state_dict if i.startswith("mid_block.")]) 178 | name_list += sorted([i for i in state_dict if i.startswith("up_blocks.")]) 179 | state_dict_ = {} 180 | last_prefix, module_id = "", -1 181 | for name in name_list: 182 | names = name.split(".") 183 | prefix_index = names.index("temporal_transformer") + 1 184 | prefix = ".".join(names[:prefix_index]) 185 | if prefix != last_prefix: 186 | last_prefix = prefix 187 | module_id += 1 188 | middle_name = ".".join(names[prefix_index:-1]) 189 | suffix = names[-1] 190 | if "pos_encoder" in names: 191 | rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name]]) 192 | else: 193 | rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix]) 194 | state_dict_[rename] = state_dict[name] 195 | return state_dict_ 196 | 197 | def from_civitai(self, state_dict): 198 | return self.from_diffusers(state_dict) 199 | -------------------------------------------------------------------------------- /diffsynth/models/sdxl_ipadapter.py: -------------------------------------------------------------------------------- 1 | from .svd_image_encoder import SVDImageEncoder 2 | from transformers import CLIPImageProcessor 3 | import torch 4 | 5 | 6 | class IpAdapterXLCLIPImageEmbedder(SVDImageEncoder): 7 | def __init__(self): 8 | super().__init__(embed_dim=1664, encoder_intermediate_size=8192, projection_dim=1280, num_encoder_layers=48, num_heads=16, head_dim=104) 9 | self.image_processor = CLIPImageProcessor() 10 | 11 | def forward(self, image): 12 | pixel_values = self.image_processor(images=image, return_tensors="pt").pixel_values 13 | pixel_values = pixel_values.to(device=self.embeddings.class_embedding.device, dtype=self.embeddings.class_embedding.dtype) 14 | return super().forward(pixel_values) 15 | 16 | 17 | class IpAdapterImageProjModel(torch.nn.Module): 18 | def __init__(self, cross_attention_dim=2048, clip_embeddings_dim=1280, clip_extra_context_tokens=4): 19 | super().__init__() 20 | self.cross_attention_dim = cross_attention_dim 21 | self.clip_extra_context_tokens = clip_extra_context_tokens 22 | self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * 
cross_attention_dim) 23 | self.norm = torch.nn.LayerNorm(cross_attention_dim) 24 | 25 | def forward(self, image_embeds): 26 | clip_extra_context_tokens = self.proj(image_embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim) 27 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens) 28 | return clip_extra_context_tokens 29 | 30 | 31 | class IpAdapterModule(torch.nn.Module): 32 | def __init__(self, input_dim, output_dim): 33 | super().__init__() 34 | self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False) 35 | self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False) 36 | 37 | def forward(self, hidden_states): 38 | ip_k = self.to_k_ip(hidden_states) 39 | ip_v = self.to_v_ip(hidden_states) 40 | return ip_k, ip_v 41 | 42 | 43 | class SDXLIpAdapter(torch.nn.Module): 44 | def __init__(self): 45 | super().__init__() 46 | shape_list = [(2048, 640)] * 4 + [(2048, 1280)] * 50 + [(2048, 640)] * 6 + [(2048, 1280)] * 10 47 | self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(*shape) for shape in shape_list]) 48 | self.image_proj = IpAdapterImageProjModel() 49 | self.set_full_adapter() 50 | 51 | def set_full_adapter(self): 52 | map_list = sum([ 53 | [(7, i) for i in range(2)], 54 | [(10, i) for i in range(2)], 55 | [(15, i) for i in range(10)], 56 | [(18, i) for i in range(10)], 57 | [(25, i) for i in range(10)], 58 | [(28, i) for i in range(10)], 59 | [(31, i) for i in range(10)], 60 | [(35, i) for i in range(2)], 61 | [(38, i) for i in range(2)], 62 | [(41, i) for i in range(2)], 63 | [(21, i) for i in range(10)], 64 | ], []) 65 | self.call_block_id = {i: j for j, i in enumerate(map_list)} 66 | 67 | def set_less_adapter(self): 68 | map_list = sum([ 69 | [(7, i) for i in range(2)], 70 | [(10, i) for i in range(2)], 71 | [(15, i) for i in range(10)], 72 | [(18, i) for i in range(10)], 73 | [(25, i) for i in range(10)], 74 | [(28, i) for i in range(10)], 75 | [(31, i) for i in range(10)], 76 | [(35, i) for i in range(2)], 77 | [(38, i) for i in range(2)], 78 | [(41, i) for i in range(2)], 79 | [(21, i) for i in range(10)], 80 | ], []) 81 | self.call_block_id = {i: j for j, i in enumerate(map_list) if j>=34 and j<44} 82 | 83 | def forward(self, hidden_states, scale=1.0): 84 | hidden_states = self.image_proj(hidden_states) 85 | hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1]) 86 | ip_kv_dict = {} 87 | for (block_id, transformer_id) in self.call_block_id: 88 | ipadapter_id = self.call_block_id[(block_id, transformer_id)] 89 | ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states) 90 | if block_id not in ip_kv_dict: 91 | ip_kv_dict[block_id] = {} 92 | ip_kv_dict[block_id][transformer_id] = { 93 | "ip_k": ip_k, 94 | "ip_v": ip_v, 95 | "scale": scale 96 | } 97 | return ip_kv_dict 98 | 99 | def state_dict_converter(self): 100 | return SDXLIpAdapterStateDictConverter() 101 | 102 | 103 | class SDXLIpAdapterStateDictConverter: 104 | def __init__(self): 105 | pass 106 | 107 | def from_diffusers(self, state_dict): 108 | state_dict_ = {} 109 | for name in state_dict["ip_adapter"]: 110 | names = name.split(".") 111 | layer_id = str(int(names[0]) // 2) 112 | name_ = ".".join(["ipadapter_modules"] + [layer_id] + names[1:]) 113 | state_dict_[name_] = state_dict["ip_adapter"][name] 114 | for name in state_dict["image_proj"]: 115 | name_ = "image_proj." 
+ name 116 | state_dict_[name_] = state_dict["image_proj"][name] 117 | return state_dict_ 118 | 119 | def from_civitai(self, state_dict): 120 | return self.from_diffusers(state_dict) 121 | 122 | -------------------------------------------------------------------------------- /diffsynth/models/sdxl_motion.py: -------------------------------------------------------------------------------- 1 | from .sd_motion import TemporalBlock 2 | import torch 3 | 4 | 5 | 6 | class SDXLMotionModel(torch.nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | self.motion_modules = torch.nn.ModuleList([ 10 | TemporalBlock(8, 320//8, 320, eps=1e-6), 11 | TemporalBlock(8, 320//8, 320, eps=1e-6), 12 | 13 | TemporalBlock(8, 640//8, 640, eps=1e-6), 14 | TemporalBlock(8, 640//8, 640, eps=1e-6), 15 | 16 | TemporalBlock(8, 1280//8, 1280, eps=1e-6), 17 | TemporalBlock(8, 1280//8, 1280, eps=1e-6), 18 | 19 | TemporalBlock(8, 1280//8, 1280, eps=1e-6), 20 | TemporalBlock(8, 1280//8, 1280, eps=1e-6), 21 | TemporalBlock(8, 1280//8, 1280, eps=1e-6), 22 | 23 | TemporalBlock(8, 640//8, 640, eps=1e-6), 24 | TemporalBlock(8, 640//8, 640, eps=1e-6), 25 | TemporalBlock(8, 640//8, 640, eps=1e-6), 26 | 27 | TemporalBlock(8, 320//8, 320, eps=1e-6), 28 | TemporalBlock(8, 320//8, 320, eps=1e-6), 29 | TemporalBlock(8, 320//8, 320, eps=1e-6), 30 | ]) 31 | self.call_block_id = { 32 | 0: 0, 33 | 2: 1, 34 | 7: 2, 35 | 10: 3, 36 | 15: 4, 37 | 18: 5, 38 | 25: 6, 39 | 28: 7, 40 | 31: 8, 41 | 35: 9, 42 | 38: 10, 43 | 41: 11, 44 | 44: 12, 45 | 46: 13, 46 | 48: 14, 47 | } 48 | 49 | def forward(self): 50 | pass 51 | 52 | def state_dict_converter(self): 53 | return SDMotionModelStateDictConverter() 54 | 55 | 56 | class SDMotionModelStateDictConverter: 57 | def __init__(self): 58 | pass 59 | 60 | def from_diffusers(self, state_dict): 61 | rename_dict = { 62 | "norm": "norm", 63 | "proj_in": "proj_in", 64 | "transformer_blocks.0.attention_blocks.0.to_q": "transformer_blocks.0.attn1.to_q", 65 | "transformer_blocks.0.attention_blocks.0.to_k": "transformer_blocks.0.attn1.to_k", 66 | "transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v", 67 | "transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out", 68 | "transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1", 69 | "transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q", 70 | "transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k", 71 | "transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v", 72 | "transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out", 73 | "transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2", 74 | "transformer_blocks.0.norms.0": "transformer_blocks.0.norm1", 75 | "transformer_blocks.0.norms.1": "transformer_blocks.0.norm2", 76 | "transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj", 77 | "transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff", 78 | "transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3", 79 | "proj_out": "proj_out", 80 | } 81 | name_list = sorted([i for i in state_dict if i.startswith("down_blocks.")]) 82 | name_list += sorted([i for i in state_dict if i.startswith("mid_block.")]) 83 | name_list += sorted([i for i in state_dict if i.startswith("up_blocks.")]) 84 | state_dict_ = {} 85 | last_prefix, module_id = "", -1 86 | for name in name_list: 87 | names = name.split(".") 88 | prefix_index = 
names.index("temporal_transformer") + 1 89 | prefix = ".".join(names[:prefix_index]) 90 | if prefix != last_prefix: 91 | last_prefix = prefix 92 | module_id += 1 93 | middle_name = ".".join(names[prefix_index:-1]) 94 | suffix = names[-1] 95 | if "pos_encoder" in names: 96 | rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name]]) 97 | else: 98 | rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix]) 99 | state_dict_[rename] = state_dict[name] 100 | return state_dict_ 101 | 102 | def from_civitai(self, state_dict): 103 | return self.from_diffusers(state_dict) 104 | -------------------------------------------------------------------------------- /diffsynth/models/sdxl_vae_decoder.py: -------------------------------------------------------------------------------- 1 | from .sd_vae_decoder import SDVAEDecoder, SDVAEDecoderStateDictConverter 2 | 3 | 4 | class SDXLVAEDecoder(SDVAEDecoder): 5 | def __init__(self): 6 | super().__init__() 7 | self.scaling_factor = 0.13025 8 | 9 | def state_dict_converter(self): 10 | return SDXLVAEDecoderStateDictConverter() 11 | 12 | 13 | class SDXLVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter): 14 | def __init__(self): 15 | super().__init__() 16 | -------------------------------------------------------------------------------- /diffsynth/models/sdxl_vae_encoder.py: -------------------------------------------------------------------------------- 1 | from .sd_vae_encoder import SDVAEEncoderStateDictConverter, SDVAEEncoder 2 | 3 | 4 | class SDXLVAEEncoder(SDVAEEncoder): 5 | def __init__(self): 6 | super().__init__() 7 | self.scaling_factor = 0.13025 8 | 9 | def state_dict_converter(self): 10 | return SDXLVAEEncoderStateDictConverter() 11 | 12 | 13 | class SDXLVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter): 14 | def __init__(self): 15 | super().__init__() 16 | -------------------------------------------------------------------------------- /diffsynth/models/tiler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange, repeat 3 | 4 | 5 | class TileWorker: 6 | def __init__(self): 7 | pass 8 | 9 | 10 | def mask(self, height, width, border_width): 11 | # Create a mask with shape (height, width). 12 | # The centre area is filled with 1, and the border line is filled with values in range (0, 1]. 
13 | x = torch.arange(height).repeat(width, 1).T 14 | y = torch.arange(width).repeat(height, 1) 15 | mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values 16 | mask = (mask / border_width).clip(0, 1) 17 | return mask 18 | 19 | 20 | def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype): 21 | # Convert a tensor (b, c, h, w) to (b, c, tile_size, tile_size, tile_num) 22 | batch_size, channel, _, _ = model_input.shape 23 | model_input = model_input.to(device=tile_device, dtype=tile_dtype) 24 | unfold_operator = torch.nn.Unfold( 25 | kernel_size=(tile_size, tile_size), 26 | stride=(tile_stride, tile_stride) 27 | ) 28 | model_input = unfold_operator(model_input) 29 | model_input = model_input.view((batch_size, channel, tile_size, tile_size, -1)) 30 | 31 | return model_input 32 | 33 | 34 | def tiled_inference(self, forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype): 35 | # Call y=forward_fn(x) for each tile 36 | tile_num = model_input.shape[-1] 37 | model_output_stack = [] 38 | 39 | for tile_id in range(0, tile_num, tile_batch_size): 40 | 41 | # process input 42 | tile_id_ = min(tile_id + tile_batch_size, tile_num) 43 | x = model_input[:, :, :, :, tile_id: tile_id_] 44 | x = x.to(device=inference_device, dtype=inference_dtype) 45 | x = rearrange(x, "b c h w n -> (n b) c h w") 46 | 47 | # process output 48 | y = forward_fn(x) 49 | y = rearrange(y, "(n b) c h w -> b c h w n", n=tile_id_-tile_id) 50 | y = y.to(device=tile_device, dtype=tile_dtype) 51 | model_output_stack.append(y) 52 | 53 | model_output = torch.concat(model_output_stack, dim=-1) 54 | return model_output 55 | 56 | 57 | def io_scale(self, model_output, tile_size): 58 | # Determine the size modification that happened in forward_fn 59 | # We only consider the same scale on height and width. 
60 | io_scale = model_output.shape[2] / tile_size 61 | return io_scale 62 | 63 | 64 | def untile(self, model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype): 65 | # The reversed function of tile 66 | mask = self.mask(tile_size, tile_size, border_width) 67 | mask = mask.to(device=tile_device, dtype=tile_dtype) 68 | mask = rearrange(mask, "h w -> 1 1 h w 1") 69 | model_output = model_output * mask 70 | 71 | fold_operator = torch.nn.Fold( 72 | output_size=(height, width), 73 | kernel_size=(tile_size, tile_size), 74 | stride=(tile_stride, tile_stride) 75 | ) 76 | mask = repeat(mask[0, 0, :, :, 0], "h w -> 1 (h w) n", n=model_output.shape[-1]) 77 | model_output = rearrange(model_output, "b c h w n -> b (c h w) n") 78 | model_output = fold_operator(model_output) / fold_operator(mask) 79 | 80 | return model_output 81 | 82 | 83 | def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_batch_size=1, tile_device="cpu", tile_dtype=torch.float32, border_width=None): 84 | # Prepare 85 | inference_device, inference_dtype = model_input.device, model_input.dtype 86 | height, width = model_input.shape[2], model_input.shape[3] 87 | border_width = int(tile_stride*0.5) if border_width is None else border_width 88 | 89 | # tile 90 | model_input = self.tile(model_input, tile_size, tile_stride, tile_device, tile_dtype) 91 | 92 | # inference 93 | model_output = self.tiled_inference(forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype) 94 | 95 | # resize 96 | io_scale = self.io_scale(model_output, tile_size) 97 | height, width = int(height*io_scale), int(width*io_scale) 98 | tile_size, tile_stride = int(tile_size*io_scale), int(tile_stride*io_scale) 99 | border_width = int(border_width*io_scale) 100 | 101 | # untile 102 | model_output = self.untile(model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype) 103 | 104 | # Done! 105 | model_output = model_output.to(device=inference_device, dtype=inference_dtype) 106 | return model_output -------------------------------------------------------------------------------- /diffsynth/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from .stable_diffusion import SDImagePipeline 2 | from .stable_diffusion_xl import SDXLImagePipeline 3 | from .stable_diffusion_video import SDVideoPipeline, SDVideoPipelineRunner 4 | from .stable_diffusion_xl_video import SDXLVideoPipeline 5 | from .stable_video_diffusion import SVDVideoPipeline 6 | from .hunyuan_dit import HunyuanDiTImagePipeline 7 | -------------------------------------------------------------------------------- /diffsynth/pipelines/dancer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ..models import SDUNet, SDMotionModel, SDXLUNet, SDXLMotionModel 3 | from ..models.sd_unet import PushBlock, PopBlock 4 | from ..controlnets import MultiControlNetManager 5 | 6 | 7 | def lets_dance( 8 | unet: SDUNet, 9 | motion_modules: SDMotionModel = None, 10 | controlnet: MultiControlNetManager = None, 11 | sample = None, 12 | timestep = None, 13 | encoder_hidden_states = None, 14 | ipadapter_kwargs_list = {}, 15 | controlnet_frames = None, 16 | unet_batch_size = 1, 17 | controlnet_batch_size = 1, 18 | cross_frame_attention = False, 19 | tiled=False, 20 | tile_size=64, 21 | tile_stride=32, 22 | device = "cuda", 23 | vram_limit_level = 0, 24 | ): 25 | # 1. 
ControlNet 26 | # This part will be repeated on overlapping frames if animatediff_batch_size > animatediff_stride. 27 | # I leave it here because I intend to do something interesting on the ControlNets. 28 | controlnet_insert_block_id = 30 29 | if controlnet is not None and controlnet_frames is not None: 30 | res_stacks = [] 31 | # process controlnet frames with batch 32 | for batch_id in range(0, sample.shape[0], controlnet_batch_size): 33 | batch_id_ = min(batch_id + controlnet_batch_size, sample.shape[0]) 34 | res_stack = controlnet( 35 | sample[batch_id: batch_id_], 36 | timestep, 37 | encoder_hidden_states[batch_id: batch_id_], 38 | controlnet_frames[:, batch_id: batch_id_], 39 | tiled=tiled, tile_size=tile_size, tile_stride=tile_stride 40 | ) 41 | if vram_limit_level >= 1: 42 | res_stack = [res.cpu() for res in res_stack] 43 | res_stacks.append(res_stack) 44 | # concat the residual 45 | additional_res_stack = [] 46 | for i in range(len(res_stacks[0])): 47 | res = torch.concat([res_stack[i] for res_stack in res_stacks], dim=0) 48 | additional_res_stack.append(res) 49 | else: 50 | additional_res_stack = None 51 | 52 | # 2. time 53 | time_emb = unet.time_proj(timestep[None]).to(sample.dtype) 54 | time_emb = unet.time_embedding(time_emb) 55 | 56 | # 3. pre-process 57 | height, width = sample.shape[2], sample.shape[3] 58 | hidden_states = unet.conv_in(sample) 59 | text_emb = encoder_hidden_states 60 | res_stack = [hidden_states.cpu() if vram_limit_level>=1 else hidden_states] 61 | 62 | # 4. blocks 63 | for block_id, block in enumerate(unet.blocks): 64 | # 4.1 UNet 65 | if isinstance(block, PushBlock): 66 | hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) 67 | if vram_limit_level>=1: 68 | res_stack[-1] = res_stack[-1].cpu() 69 | elif isinstance(block, PopBlock): 70 | if vram_limit_level>=1: 71 | res_stack[-1] = res_stack[-1].to(device) 72 | hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) 73 | else: 74 | hidden_states_input = hidden_states 75 | hidden_states_output = [] 76 | for batch_id in range(0, sample.shape[0], unet_batch_size): 77 | batch_id_ = min(batch_id + unet_batch_size, sample.shape[0]) 78 | hidden_states, _, _, _ = block( 79 | hidden_states_input[batch_id: batch_id_], 80 | time_emb, 81 | text_emb[batch_id: batch_id_], 82 | res_stack, 83 | cross_frame_attention=cross_frame_attention, 84 | ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {}), 85 | tiled=tiled, tile_size=tile_size, tile_stride=tile_stride 86 | ) 87 | hidden_states_output.append(hidden_states) 88 | hidden_states = torch.concat(hidden_states_output, dim=0) 89 | # 4.2 AnimateDiff 90 | if motion_modules is not None: 91 | if block_id in motion_modules.call_block_id: 92 | motion_module_id = motion_modules.call_block_id[block_id] 93 | hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id]( 94 | hidden_states, time_emb, text_emb, res_stack, 95 | batch_size=1 96 | ) 97 | # 4.3 ControlNet 98 | if block_id == controlnet_insert_block_id and additional_res_stack is not None: 99 | hidden_states += additional_res_stack.pop().to(device) 100 | if vram_limit_level>=1: 101 | res_stack = [(res.to(device) + additional_res.to(device)).cpu() for res, additional_res in zip(res_stack, additional_res_stack)] 102 | else: 103 | res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)] 104 | 105 | # 5. 
output 106 | hidden_states = unet.conv_norm_out(hidden_states) 107 | hidden_states = unet.conv_act(hidden_states) 108 | hidden_states = unet.conv_out(hidden_states) 109 | 110 | return hidden_states 111 | 112 | 113 | 114 | 115 | def lets_dance_xl( 116 | unet: SDXLUNet, 117 | motion_modules: SDXLMotionModel = None, 118 | controlnet: MultiControlNetManager = None, 119 | sample = None, 120 | add_time_id = None, 121 | add_text_embeds = None, 122 | timestep = None, 123 | encoder_hidden_states = None, 124 | ipadapter_kwargs_list = {}, 125 | controlnet_frames = None, 126 | unet_batch_size = 1, 127 | controlnet_batch_size = 1, 128 | cross_frame_attention = False, 129 | tiled=False, 130 | tile_size=64, 131 | tile_stride=32, 132 | device = "cuda", 133 | vram_limit_level = 0, 134 | ): 135 | # 2. time 136 | t_emb = unet.time_proj(timestep[None]).to(sample.dtype) 137 | t_emb = unet.time_embedding(t_emb) 138 | 139 | time_embeds = unet.add_time_proj(add_time_id) 140 | time_embeds = time_embeds.reshape((add_text_embeds.shape[0], -1)) 141 | add_embeds = torch.concat([add_text_embeds, time_embeds], dim=-1) 142 | add_embeds = add_embeds.to(sample.dtype) 143 | add_embeds = unet.add_time_embedding(add_embeds) 144 | 145 | time_emb = t_emb + add_embeds 146 | 147 | # 3. pre-process 148 | height, width = sample.shape[2], sample.shape[3] 149 | hidden_states = unet.conv_in(sample) 150 | text_emb = encoder_hidden_states 151 | res_stack = [hidden_states] 152 | 153 | # 4. blocks 154 | for block_id, block in enumerate(unet.blocks): 155 | hidden_states, time_emb, text_emb, res_stack = block( 156 | hidden_states, time_emb, text_emb, res_stack, 157 | tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, 158 | ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {}) 159 | ) 160 | # 4.2 AnimateDiff 161 | if motion_modules is not None: 162 | if block_id in motion_modules.call_block_id: 163 | motion_module_id = motion_modules.call_block_id[block_id] 164 | hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id]( 165 | hidden_states, time_emb, text_emb, res_stack, 166 | batch_size=1 167 | ) 168 | 169 | # 5. 
output 170 | hidden_states = unet.conv_norm_out(hidden_states) 171 | hidden_states = unet.conv_act(hidden_states) 172 | hidden_states = unet.conv_out(hidden_states) 173 | 174 | return hidden_states -------------------------------------------------------------------------------- /diffsynth/pipelines/stable_diffusion.py: -------------------------------------------------------------------------------- 1 | from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDIpAdapter, IpAdapterCLIPImageEmbedder 2 | from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator 3 | from ..prompts import SDPrompter 4 | from ..schedulers import EnhancedDDIMScheduler 5 | from .dancer import lets_dance 6 | from typing import List 7 | import torch 8 | from tqdm import tqdm 9 | from PIL import Image 10 | import numpy as np 11 | 12 | 13 | class SDImagePipeline(torch.nn.Module): 14 | 15 | def __init__(self, device="cuda", torch_dtype=torch.float16): 16 | super().__init__() 17 | self.scheduler = EnhancedDDIMScheduler() 18 | self.prompter = SDPrompter() 19 | self.device = device 20 | self.torch_dtype = torch_dtype 21 | # models 22 | self.text_encoder: SDTextEncoder = None 23 | self.unet: SDUNet = None 24 | self.vae_decoder: SDVAEDecoder = None 25 | self.vae_encoder: SDVAEEncoder = None 26 | self.controlnet: MultiControlNetManager = None 27 | self.ipadapter_image_encoder: IpAdapterCLIPImageEmbedder = None 28 | self.ipadapter: SDIpAdapter = None 29 | 30 | 31 | def fetch_main_models(self, model_manager: ModelManager): 32 | self.text_encoder = model_manager.text_encoder 33 | self.unet = model_manager.unet 34 | self.vae_decoder = model_manager.vae_decoder 35 | self.vae_encoder = model_manager.vae_encoder 36 | 37 | 38 | def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]): 39 | controlnet_units = [] 40 | for config in controlnet_config_units: 41 | controlnet_unit = ControlNetUnit( 42 | Annotator(config.processor_id), 43 | model_manager.get_model_with_model_path(config.model_path), 44 | config.scale 45 | ) 46 | controlnet_units.append(controlnet_unit) 47 | self.controlnet = MultiControlNetManager(controlnet_units) 48 | 49 | 50 | def fetch_ipadapter(self, model_manager: ModelManager): 51 | if "ipadapter" in model_manager.model: 52 | self.ipadapter = model_manager.ipadapter 53 | if "ipadapter_image_encoder" in model_manager.model: 54 | self.ipadapter_image_encoder = model_manager.ipadapter_image_encoder 55 | 56 | 57 | def fetch_prompter(self, model_manager: ModelManager): 58 | self.prompter.load_from_model_manager(model_manager) 59 | 60 | 61 | @staticmethod 62 | def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]): 63 | pipe = SDImagePipeline( 64 | device=model_manager.device, 65 | torch_dtype=model_manager.torch_dtype, 66 | ) 67 | pipe.fetch_main_models(model_manager) 68 | pipe.fetch_prompter(model_manager) 69 | pipe.fetch_controlnet_models(model_manager, controlnet_config_units) 70 | pipe.fetch_ipadapter(model_manager) 71 | return pipe 72 | 73 | 74 | def preprocess_image(self, image): 75 | image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0) 76 | return image 77 | 78 | 79 | def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32): 80 | image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] 81 | image = 
image.cpu().permute(1, 2, 0).numpy() 82 | image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) 83 | return image 84 | 85 | 86 | @torch.no_grad() 87 | def __call__( 88 | self, 89 | prompt, 90 | negative_prompt="", 91 | cfg_scale=7.5, 92 | clip_skip=1, 93 | input_image=None, 94 | ipadapter_images=None, 95 | ipadapter_scale=1.0, 96 | controlnet_image=None, 97 | denoising_strength=1.0, 98 | height=512, 99 | width=512, 100 | num_inference_steps=20, 101 | tiled=False, 102 | tile_size=64, 103 | tile_stride=32, 104 | progress_bar_cmd=tqdm, 105 | progress_bar_st=None, 106 | ): 107 | # Prepare scheduler 108 | self.scheduler.set_timesteps(num_inference_steps, denoising_strength) 109 | 110 | # Prepare latent tensors 111 | if input_image is not None: 112 | image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype) 113 | latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) 114 | noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype) 115 | latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) 116 | else: 117 | latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype) 118 | 119 | # Encode prompts 120 | prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device, positive=True) 121 | prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device, positive=False) 122 | 123 | # IP-Adapter 124 | if ipadapter_images is not None: 125 | ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images) 126 | ipadapter_kwargs_list_posi = self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale) 127 | ipadapter_kwargs_list_nega = self.ipadapter(torch.zeros_like(ipadapter_image_encoding)) 128 | else: 129 | ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {}, {} 130 | 131 | # Prepare ControlNets 132 | if controlnet_image is not None: 133 | controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype) 134 | controlnet_image = controlnet_image.unsqueeze(1) 135 | 136 | # Denoise 137 | for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): 138 | timestep = torch.IntTensor((timestep,))[0].to(self.device) 139 | 140 | # Classifier-free guidance 141 | noise_pred_posi = lets_dance( 142 | self.unet, motion_modules=None, controlnet=self.controlnet, 143 | sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_image, 144 | tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, 145 | ipadapter_kwargs_list=ipadapter_kwargs_list_posi, 146 | device=self.device, vram_limit_level=0 147 | ) 148 | noise_pred_nega = lets_dance( 149 | self.unet, motion_modules=None, controlnet=self.controlnet, 150 | sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_image, 151 | tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, 152 | ipadapter_kwargs_list=ipadapter_kwargs_list_nega, 153 | device=self.device, vram_limit_level=0 154 | ) 155 | noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) 156 | 157 | # DDIM 158 | latents = self.scheduler.step(noise_pred, timestep, latents) 159 | 160 | # UI 161 | if progress_bar_st is not None: 162 | progress_bar_st.progress(progress_id / len(self.scheduler.timesteps)) 
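            # A note on the guidance arithmetic above (illustrative only): noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
            # extrapolates away from the negative-prompt prediction. For example, with cfg_scale=7.5 it equals
            # 7.5 * noise_pred_posi - 6.5 * noise_pred_nega, and with cfg_scale=1.0 it reduces to noise_pred_posi.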
163 | 164 | # Decode image 165 | image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) 166 | 167 | return image 168 | -------------------------------------------------------------------------------- /diffsynth/pipelines/stable_diffusion_xl.py: -------------------------------------------------------------------------------- 1 | from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder 2 | # TODO: SDXL ControlNet 3 | from ..prompts import SDXLPrompter 4 | from ..schedulers import EnhancedDDIMScheduler 5 | from .dancer import lets_dance_xl 6 | import torch 7 | from tqdm import tqdm 8 | from PIL import Image 9 | import numpy as np 10 | 11 | 12 | class SDXLImagePipeline(torch.nn.Module): 13 | 14 | def __init__(self, device="cuda", torch_dtype=torch.float16): 15 | super().__init__() 16 | self.scheduler = EnhancedDDIMScheduler() 17 | self.prompter = SDXLPrompter() 18 | self.device = device 19 | self.torch_dtype = torch_dtype 20 | # models 21 | self.text_encoder: SDXLTextEncoder = None 22 | self.text_encoder_2: SDXLTextEncoder2 = None 23 | self.unet: SDXLUNet = None 24 | self.vae_decoder: SDXLVAEDecoder = None 25 | self.vae_encoder: SDXLVAEEncoder = None 26 | self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None 27 | self.ipadapter: SDXLIpAdapter = None 28 | # TODO: SDXL ControlNet 29 | 30 | def fetch_main_models(self, model_manager: ModelManager): 31 | self.text_encoder = model_manager.text_encoder 32 | self.text_encoder_2 = model_manager.text_encoder_2 33 | self.unet = model_manager.unet 34 | self.vae_decoder = model_manager.vae_decoder 35 | self.vae_encoder = model_manager.vae_encoder 36 | 37 | 38 | def fetch_controlnet_models(self, model_manager: ModelManager, **kwargs): 39 | # TODO: SDXL ControlNet 40 | pass 41 | 42 | 43 | def fetch_ipadapter(self, model_manager: ModelManager): 44 | if "ipadapter_xl" in model_manager.model: 45 | self.ipadapter = model_manager.ipadapter_xl 46 | if "ipadapter_xl_image_encoder" in model_manager.model: 47 | self.ipadapter_image_encoder = model_manager.ipadapter_xl_image_encoder 48 | 49 | 50 | def fetch_prompter(self, model_manager: ModelManager): 51 | self.prompter.load_from_model_manager(model_manager) 52 | 53 | 54 | @staticmethod 55 | def from_model_manager(model_manager: ModelManager, controlnet_config_units = [], **kwargs): 56 | pipe = SDXLImagePipeline( 57 | device=model_manager.device, 58 | torch_dtype=model_manager.torch_dtype, 59 | ) 60 | pipe.fetch_main_models(model_manager) 61 | pipe.fetch_prompter(model_manager) 62 | pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units) 63 | pipe.fetch_ipadapter(model_manager) 64 | return pipe 65 | 66 | 67 | def preprocess_image(self, image): 68 | image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0) 69 | return image 70 | 71 | 72 | def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32): 73 | image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] 74 | image = image.cpu().permute(1, 2, 0).numpy() 75 | image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) 76 | return image 77 | 78 | 79 | @torch.no_grad() 80 | def __call__( 81 | self, 82 | prompt, 83 | negative_prompt="", 84 | cfg_scale=7.5, 85 | clip_skip=1, 86 | clip_skip_2=2, 87 | input_image=None, 88 | ipadapter_images=None, 89 | 
ipadapter_scale=1.0, 90 | controlnet_image=None, 91 | denoising_strength=1.0, 92 | height=1024, 93 | width=1024, 94 | num_inference_steps=20, 95 | tiled=False, 96 | tile_size=64, 97 | tile_stride=32, 98 | progress_bar_cmd=tqdm, 99 | progress_bar_st=None, 100 | ): 101 | # Prepare scheduler 102 | self.scheduler.set_timesteps(num_inference_steps, denoising_strength) 103 | 104 | # Prepare latent tensors 105 | if input_image is not None: 106 | image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype) 107 | latents = self.vae_encoder(image.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype) 108 | noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype) 109 | latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) 110 | else: 111 | latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype) 112 | 113 | # Encode prompts 114 | add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt( 115 | self.text_encoder, 116 | self.text_encoder_2, 117 | prompt, 118 | clip_skip=clip_skip, clip_skip_2=clip_skip_2, 119 | device=self.device, 120 | positive=True, 121 | ) 122 | if cfg_scale != 1.0: 123 | add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt( 124 | self.text_encoder, 125 | self.text_encoder_2, 126 | negative_prompt, 127 | clip_skip=clip_skip, clip_skip_2=clip_skip_2, 128 | device=self.device, 129 | positive=False, 130 | ) 131 | 132 | # Prepare positional id 133 | add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device) 134 | 135 | # IP-Adapter 136 | if ipadapter_images is not None: 137 | ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images) 138 | ipadapter_kwargs_list_posi = self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale) 139 | ipadapter_kwargs_list_nega = self.ipadapter(torch.zeros_like(ipadapter_image_encoding)) 140 | else: 141 | ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {}, {} 142 | 143 | # Denoise 144 | for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): 145 | timestep = torch.IntTensor((timestep,))[0].to(self.device) 146 | 147 | # Classifier-free guidance 148 | noise_pred_posi = lets_dance_xl( 149 | self.unet, 150 | sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi, 151 | add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi, 152 | tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, 153 | ipadapter_kwargs_list=ipadapter_kwargs_list_posi, 154 | ) 155 | if cfg_scale != 1.0: 156 | noise_pred_nega = lets_dance_xl( 157 | self.unet, 158 | sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, 159 | add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega, 160 | tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, 161 | ipadapter_kwargs_list=ipadapter_kwargs_list_nega, 162 | ) 163 | noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) 164 | else: 165 | noise_pred = noise_pred_posi 166 | 167 | latents = self.scheduler.step(noise_pred, timestep, latents) 168 | 169 | if progress_bar_st is not None: 170 | progress_bar_st.progress(progress_id / len(self.scheduler.timesteps)) 171 | 172 | # Decode image 173 | image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) 174 | 175 | return image 176 | 
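# Usage sketch: a minimal, hypothetical example of driving the SDXLImagePipeline defined above.
# It assumes a ModelManager has already been loaded with SDXL weights elsewhere (the loading API
# lives in diffsynth/models and is not shown in this file), so the ModelManager constructor and
# load_models call below are assumptions and the checkpoint path is a placeholder; only
# from_model_manager and __call__ are taken directly from the class above.
import torch
from diffsynth.models import ModelManager
from diffsynth.pipelines import SDXLImagePipeline

model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")  # assumed constructor signature
model_manager.load_models(["models/sd_xl_base_1.0.safetensors"])        # assumed loader, placeholder path
pipe = SDXLImagePipeline.from_model_manager(model_manager)
image = pipe(
    prompt="a watercolor lighthouse at dawn, highly detailed",
    negative_prompt="lowres, blurry",
    cfg_scale=7.5, height=1024, width=1024, num_inference_steps=20,
)
image.save("image.png")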
-------------------------------------------------------------------------------- /diffsynth/pipelines/stable_diffusion_xl_video.py: -------------------------------------------------------------------------------- 1 | from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLMotionModel 2 | from .dancer import lets_dance_xl 3 | # TODO: SDXL ControlNet 4 | from ..prompts import SDXLPrompter 5 | from ..schedulers import EnhancedDDIMScheduler 6 | import torch 7 | from tqdm import tqdm 8 | from PIL import Image 9 | import numpy as np 10 | 11 | 12 | class SDXLVideoPipeline(torch.nn.Module): 13 | 14 | def __init__(self, device="cuda", torch_dtype=torch.float16, use_animatediff=True): 15 | super().__init__() 16 | self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_animatediff else "scaled_linear") 17 | self.prompter = SDXLPrompter() 18 | self.device = device 19 | self.torch_dtype = torch_dtype 20 | # models 21 | self.text_encoder: SDXLTextEncoder = None 22 | self.text_encoder_2: SDXLTextEncoder2 = None 23 | self.unet: SDXLUNet = None 24 | self.vae_decoder: SDXLVAEDecoder = None 25 | self.vae_encoder: SDXLVAEEncoder = None 26 | # TODO: SDXL ControlNet 27 | self.motion_modules: SDXLMotionModel = None 28 | 29 | 30 | def fetch_main_models(self, model_manager: ModelManager): 31 | self.text_encoder = model_manager.text_encoder 32 | self.text_encoder_2 = model_manager.text_encoder_2 33 | self.unet = model_manager.unet 34 | self.vae_decoder = model_manager.vae_decoder 35 | self.vae_encoder = model_manager.vae_encoder 36 | 37 | 38 | def fetch_controlnet_models(self, model_manager: ModelManager, **kwargs): 39 | # TODO: SDXL ControlNet 40 | pass 41 | 42 | 43 | def fetch_motion_modules(self, model_manager: ModelManager): 44 | if "motion_modules_xl" in model_manager.model: 45 | self.motion_modules = model_manager.motion_modules_xl 46 | 47 | 48 | def fetch_prompter(self, model_manager: ModelManager): 49 | self.prompter.load_from_model_manager(model_manager) 50 | 51 | 52 | @staticmethod 53 | def from_model_manager(model_manager: ModelManager, controlnet_config_units = [], **kwargs): 54 | pipe = SDXLVideoPipeline( 55 | device=model_manager.device, 56 | torch_dtype=model_manager.torch_dtype, 57 | use_animatediff="motion_modules_xl" in model_manager.model 58 | ) 59 | pipe.fetch_main_models(model_manager) 60 | pipe.fetch_motion_modules(model_manager) 61 | pipe.fetch_prompter(model_manager) 62 | pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units) 63 | return pipe 64 | 65 | 66 | def preprocess_image(self, image): 67 | image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0) 68 | return image 69 | 70 | 71 | def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32): 72 | image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] 73 | image = image.cpu().permute(1, 2, 0).numpy() 74 | image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) 75 | return image 76 | 77 | 78 | def decode_images(self, latents, tiled=False, tile_size=64, tile_stride=32): 79 | images = [ 80 | self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) 81 | for frame_id in range(latents.shape[0]) 82 | ] 83 | return images 84 | 85 | 86 | def encode_images(self, processed_images, tiled=False, tile_size=64, tile_stride=32): 87 | latents = [] 88 | for image 
in processed_images: 89 | image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype) 90 | latent = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).cpu() 91 | latents.append(latent) 92 | latents = torch.concat(latents, dim=0) 93 | return latents 94 | 95 | 96 | @torch.no_grad() 97 | def __call__( 98 | self, 99 | prompt, 100 | negative_prompt="", 101 | cfg_scale=7.5, 102 | clip_skip=1, 103 | clip_skip_2=2, 104 | num_frames=None, 105 | input_frames=None, 106 | controlnet_frames=None, 107 | denoising_strength=1.0, 108 | height=512, 109 | width=512, 110 | num_inference_steps=20, 111 | animatediff_batch_size = 16, 112 | animatediff_stride = 8, 113 | unet_batch_size = 1, 114 | controlnet_batch_size = 1, 115 | cross_frame_attention = False, 116 | smoother=None, 117 | smoother_progress_ids=[], 118 | vram_limit_level=0, 119 | progress_bar_cmd=tqdm, 120 | progress_bar_st=None, 121 | ): 122 | # Prepare scheduler 123 | self.scheduler.set_timesteps(num_inference_steps, denoising_strength) 124 | 125 | # Prepare latent tensors 126 | if self.motion_modules is None: 127 | noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1) 128 | else: 129 | noise = torch.randn((num_frames, 4, height//8, width//8), device="cuda", dtype=self.torch_dtype) 130 | if input_frames is None or denoising_strength == 1.0: 131 | latents = noise 132 | else: 133 | latents = self.encode_images(input_frames) 134 | latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) 135 | 136 | # Encode prompts 137 | add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt( 138 | self.text_encoder, 139 | self.text_encoder_2, 140 | prompt, 141 | clip_skip=clip_skip, clip_skip_2=clip_skip_2, 142 | device=self.device, 143 | positive=True, 144 | ) 145 | if cfg_scale != 1.0: 146 | add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt( 147 | self.text_encoder, 148 | self.text_encoder_2, 149 | negative_prompt, 150 | clip_skip=clip_skip, clip_skip_2=clip_skip_2, 151 | device=self.device, 152 | positive=False, 153 | ) 154 | 155 | # Prepare positional id 156 | add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device) 157 | 158 | # Denoise 159 | for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): 160 | timestep = torch.IntTensor((timestep,))[0].to(self.device) 161 | 162 | # Classifier-free guidance 163 | noise_pred_posi = lets_dance_xl( 164 | self.unet, motion_modules=self.motion_modules, controlnet=None, 165 | sample=latents, add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi, 166 | timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_frames, 167 | cross_frame_attention=cross_frame_attention, 168 | device=self.device, vram_limit_level=vram_limit_level 169 | ) 170 | if cfg_scale != 1.0: 171 | noise_pred_nega = lets_dance_xl( 172 | self.unet, motion_modules=self.motion_modules, controlnet=None, 173 | sample=latents, add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega, 174 | timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames, 175 | cross_frame_attention=cross_frame_attention, 176 | device=self.device, vram_limit_level=vram_limit_level 177 | ) 178 | noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) 179 | else: 180 | noise_pred = noise_pred_posi 181 | 182 | latents = self.scheduler.step(noise_pred, timestep, 
latents) 183 | 184 | if progress_bar_st is not None: 185 | progress_bar_st.progress(progress_id / len(self.scheduler.timesteps)) 186 | 187 | # Decode image 188 | image = self.decode_images(latents.to(torch.float32)) 189 | 190 | return image 191 | -------------------------------------------------------------------------------- /diffsynth/processors/FastBlend.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import cupy as cp 3 | import numpy as np 4 | from tqdm import tqdm 5 | from ..extensions.FastBlend.patch_match import PyramidPatchMatcher 6 | from ..extensions.FastBlend.runners.fast import TableManager 7 | from .base import VideoProcessor 8 | 9 | 10 | class FastBlendSmoother(VideoProcessor): 11 | def __init__( 12 | self, 13 | inference_mode="fast", batch_size=8, window_size=60, 14 | minimum_patch_size=5, threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0, initialize="identity", tracking_window_size=0 15 | ): 16 | self.inference_mode = inference_mode 17 | self.batch_size = batch_size 18 | self.window_size = window_size 19 | self.ebsynth_config = { 20 | "minimum_patch_size": minimum_patch_size, 21 | "threads_per_block": threads_per_block, 22 | "num_iter": num_iter, 23 | "gpu_id": gpu_id, 24 | "guide_weight": guide_weight, 25 | "initialize": initialize, 26 | "tracking_window_size": tracking_window_size 27 | } 28 | 29 | @staticmethod 30 | def from_model_manager(model_manager, **kwargs): 31 | # TODO: fetch GPU ID from model_manager 32 | return FastBlendSmoother(**kwargs) 33 | 34 | def inference_fast(self, frames_guide, frames_style): 35 | table_manager = TableManager() 36 | patch_match_engine = PyramidPatchMatcher( 37 | image_height=frames_style[0].shape[0], 38 | image_width=frames_style[0].shape[1], 39 | channel=3, 40 | **self.ebsynth_config 41 | ) 42 | # left part 43 | table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, self.batch_size, desc="Fast Mode Step 1/4") 44 | table_l = table_manager.remapping_table_to_blending_table(table_l) 45 | table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, self.window_size, self.batch_size, desc="Fast Mode Step 2/4") 46 | # right part 47 | table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, self.batch_size, desc="Fast Mode Step 3/4") 48 | table_r = table_manager.remapping_table_to_blending_table(table_r) 49 | table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, self.window_size, self.batch_size, desc="Fast Mode Step 4/4")[::-1] 50 | # merge 51 | frames = [] 52 | for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r): 53 | weight_m = -1 54 | weight = weight_l + weight_m + weight_r 55 | frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight) 56 | frames.append(frame) 57 | frames = [frame.clip(0, 255).astype("uint8") for frame in frames] 58 | frames = [Image.fromarray(frame) for frame in frames] 59 | return frames 60 | 61 | def inference_balanced(self, frames_guide, frames_style): 62 | patch_match_engine = PyramidPatchMatcher( 63 | image_height=frames_style[0].shape[0], 64 | image_width=frames_style[0].shape[1], 65 | channel=3, 66 | **self.ebsynth_config 67 | ) 68 | output_frames = [] 69 | # tasks 70 | n = len(frames_style) 71 | tasks = [] 72 | for target in range(n): 73 | for source in range(target - self.window_size, target + 
self.window_size + 1): 74 | if source >= 0 and source < n and source != target: 75 | tasks.append((source, target)) 76 | # run 77 | frames = [(None, 1) for i in range(n)] 78 | for batch_id in tqdm(range(0, len(tasks), self.batch_size), desc="Balanced Mode"): 79 | tasks_batch = tasks[batch_id: min(batch_id+self.batch_size, len(tasks))] 80 | source_guide = np.stack([frames_guide[source] for source, target in tasks_batch]) 81 | target_guide = np.stack([frames_guide[target] for source, target in tasks_batch]) 82 | source_style = np.stack([frames_style[source] for source, target in tasks_batch]) 83 | _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style) 84 | for (source, target), result in zip(tasks_batch, target_style): 85 | frame, weight = frames[target] 86 | if frame is None: 87 | frame = frames_style[target] 88 | frames[target] = ( 89 | frame * (weight / (weight + 1)) + result / (weight + 1), 90 | weight + 1 91 | ) 92 | if weight + 1 == min(n, target + self.window_size + 1) - max(0, target - self.window_size): 93 | frame = frame.clip(0, 255).astype("uint8") 94 | output_frames.append(Image.fromarray(frame)) 95 | frames[target] = (None, 1) 96 | return output_frames 97 | 98 | def inference_accurate(self, frames_guide, frames_style): 99 | patch_match_engine = PyramidPatchMatcher( 100 | image_height=frames_style[0].shape[0], 101 | image_width=frames_style[0].shape[1], 102 | channel=3, 103 | use_mean_target_style=True, 104 | **self.ebsynth_config 105 | ) 106 | output_frames = [] 107 | # run 108 | n = len(frames_style) 109 | for target in tqdm(range(n), desc="Accurate Mode"): 110 | l, r = max(target - self.window_size, 0), min(target + self.window_size + 1, n) 111 | remapped_frames = [] 112 | for i in range(l, r, self.batch_size): 113 | j = min(i + self.batch_size, r) 114 | source_guide = np.stack([frames_guide[source] for source in range(i, j)]) 115 | target_guide = np.stack([frames_guide[target]] * (j - i)) 116 | source_style = np.stack([frames_style[source] for source in range(i, j)]) 117 | _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style) 118 | remapped_frames.append(target_style) 119 | frame = np.concatenate(remapped_frames, axis=0).mean(axis=0) 120 | frame = frame.clip(0, 255).astype("uint8") 121 | output_frames.append(Image.fromarray(frame)) 122 | return output_frames 123 | 124 | def release_vram(self): 125 | mempool = cp.get_default_memory_pool() 126 | pinned_mempool = cp.get_default_pinned_memory_pool() 127 | mempool.free_all_blocks() 128 | pinned_mempool.free_all_blocks() 129 | 130 | def __call__(self, rendered_frames, original_frames=None, **kwargs): 131 | rendered_frames = [np.array(frame) for frame in rendered_frames] 132 | original_frames = [np.array(frame) for frame in original_frames] 133 | if self.inference_mode == "fast": 134 | output_frames = self.inference_fast(original_frames, rendered_frames) 135 | elif self.inference_mode == "balanced": 136 | output_frames = self.inference_balanced(original_frames, rendered_frames) 137 | elif self.inference_mode == "accurate": 138 | output_frames = self.inference_accurate(original_frames, rendered_frames) 139 | else: 140 | raise ValueError("inference_mode must be fast, balanced or accurate") 141 | self.release_vram() 142 | return output_frames 143 | -------------------------------------------------------------------------------- /diffsynth/processors/PILEditor.py: -------------------------------------------------------------------------------- 1 | from PIL import 
ImageEnhance 2 | from .base import VideoProcessor 3 | 4 | 5 | class ContrastEditor(VideoProcessor): 6 | def __init__(self, rate=1.5): 7 | self.rate = rate 8 | 9 | @staticmethod 10 | def from_model_manager(model_manager, **kwargs): 11 | return ContrastEditor(**kwargs) 12 | 13 | def __call__(self, rendered_frames, **kwargs): 14 | rendered_frames = [ImageEnhance.Contrast(i).enhance(self.rate) for i in rendered_frames] 15 | return rendered_frames 16 | 17 | 18 | class SharpnessEditor(VideoProcessor): 19 | def __init__(self, rate=1.5): 20 | self.rate = rate 21 | 22 | @staticmethod 23 | def from_model_manager(model_manager, **kwargs): 24 | return SharpnessEditor(**kwargs) 25 | 26 | def __call__(self, rendered_frames, **kwargs): 27 | rendered_frames = [ImageEnhance.Sharpness(i).enhance(self.rate) for i in rendered_frames] 28 | return rendered_frames 29 | -------------------------------------------------------------------------------- /diffsynth/processors/RIFE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from PIL import Image 4 | from .base import VideoProcessor 5 | 6 | 7 | class RIFESmoother(VideoProcessor): 8 | def __init__(self, model, device="cuda", scale=1.0, batch_size=4, interpolate=True): 9 | self.model = model 10 | self.device = device 11 | 12 | # IFNet only does not support float16 13 | self.torch_dtype = torch.float32 14 | 15 | # Other parameters 16 | self.scale = scale 17 | self.batch_size = batch_size 18 | self.interpolate = interpolate 19 | 20 | @staticmethod 21 | def from_model_manager(model_manager, **kwargs): 22 | return RIFESmoother(model_manager.RIFE, device=model_manager.device, **kwargs) 23 | 24 | def process_image(self, image): 25 | width, height = image.size 26 | if width % 32 != 0 or height % 32 != 0: 27 | width = (width + 31) // 32 28 | height = (height + 31) // 32 29 | image = image.resize((width, height)) 30 | image = torch.Tensor(np.array(image, dtype=np.float32)[:, :, [2,1,0]] / 255).permute(2, 0, 1) 31 | return image 32 | 33 | def process_images(self, images): 34 | images = [self.process_image(image) for image in images] 35 | images = torch.stack(images) 36 | return images 37 | 38 | def decode_images(self, images): 39 | images = (images[:, [2,1,0]].permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8) 40 | images = [Image.fromarray(image) for image in images] 41 | return images 42 | 43 | def process_tensors(self, input_tensor, scale=1.0, batch_size=4): 44 | output_tensor = [] 45 | for batch_id in range(0, input_tensor.shape[0], batch_size): 46 | batch_id_ = min(batch_id + batch_size, input_tensor.shape[0]) 47 | batch_input_tensor = input_tensor[batch_id: batch_id_] 48 | batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype) 49 | flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale]) 50 | output_tensor.append(merged[2].cpu()) 51 | output_tensor = torch.concat(output_tensor, dim=0) 52 | return output_tensor 53 | 54 | @torch.no_grad() 55 | def __call__(self, rendered_frames, **kwargs): 56 | # Preprocess 57 | processed_images = self.process_images(rendered_frames) 58 | 59 | # Input 60 | input_tensor = torch.cat((processed_images[:-2], processed_images[2:]), dim=1) 61 | 62 | # Interpolate 63 | output_tensor = self.process_tensors(input_tensor, scale=self.scale, batch_size=self.batch_size) 64 | 65 | if self.interpolate: 66 | # Blend 67 | input_tensor = torch.cat((processed_images[1:-1], output_tensor), dim=1) 68 | 
output_tensor = self.process_tensors(input_tensor, scale=self.scale, batch_size=self.batch_size) 69 | processed_images[1:-1] = output_tensor 70 | else: 71 | processed_images[1:-1] = (processed_images[1:-1] + output_tensor) / 2 72 | 73 | # To images 74 | output_images = self.decode_images(processed_images) 75 | if output_images[0].size != rendered_frames[0].size: 76 | output_images = [image.resize(rendered_frames[0].size) for image in output_images] 77 | return output_images 78 | -------------------------------------------------------------------------------- /diffsynth/processors/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-DiffSynth-Studio/91d0a6bd9efa6a1d4a0b880984e21fe66acf4eb0/diffsynth/processors/__init__.py -------------------------------------------------------------------------------- /diffsynth/processors/base.py: -------------------------------------------------------------------------------- 1 | class VideoProcessor: 2 | def __init__(self): 3 | pass 4 | 5 | def __call__(self): 6 | raise NotImplementedError 7 | -------------------------------------------------------------------------------- /diffsynth/processors/sequencial_processor.py: -------------------------------------------------------------------------------- 1 | from .base import VideoProcessor 2 | 3 | 4 | class AutoVideoProcessor(VideoProcessor): 5 | def __init__(self): 6 | pass 7 | 8 | @staticmethod 9 | def from_model_manager(model_manager, processor_type, **kwargs): 10 | if processor_type == "FastBlend": 11 | from .FastBlend import FastBlendSmoother 12 | return FastBlendSmoother.from_model_manager(model_manager, **kwargs) 13 | elif processor_type == "Contrast": 14 | from .PILEditor import ContrastEditor 15 | return ContrastEditor.from_model_manager(model_manager, **kwargs) 16 | elif processor_type == "Sharpness": 17 | from .PILEditor import SharpnessEditor 18 | return SharpnessEditor.from_model_manager(model_manager, **kwargs) 19 | elif processor_type == "RIFE": 20 | from .RIFE import RIFESmoother 21 | return RIFESmoother.from_model_manager(model_manager, **kwargs) 22 | else: 23 | raise ValueError(f"invalid processor_type: {processor_type}") 24 | 25 | 26 | class SequencialProcessor(VideoProcessor): 27 | def __init__(self, processors=[]): 28 | self.processors = processors 29 | 30 | @staticmethod 31 | def from_model_manager(model_manager, configs): 32 | processors = [ 33 | AutoVideoProcessor.from_model_manager(model_manager, config["processor_type"], **config["config"]) 34 | for config in configs 35 | ] 36 | return SequencialProcessor(processors) 37 | 38 | def __call__(self, rendered_frames, **kwargs): 39 | for processor in self.processors: 40 | rendered_frames = processor(rendered_frames, **kwargs) 41 | return rendered_frames 42 | -------------------------------------------------------------------------------- /diffsynth/prompts/__init__.py: -------------------------------------------------------------------------------- 1 | from .sd_prompter import SDPrompter 2 | from .sdxl_prompter import SDXLPrompter 3 | from .hunyuan_dit_prompter import HunyuanDiTPrompter 4 | -------------------------------------------------------------------------------- /diffsynth/prompts/hunyuan_dit_prompter.py: -------------------------------------------------------------------------------- 1 | from .utils import Prompter 2 | from transformers import BertModel, T5EncoderModel, BertTokenizer, AutoTokenizer 3 | import warnings, os 4 | 5 | 6 | class 
HunyuanDiTPrompter(Prompter): 7 | def __init__( 8 | self, 9 | tokenizer_path=None, 10 | tokenizer_t5_path=None 11 | ): 12 | if tokenizer_path is None: 13 | base_path = os.path.dirname(os.path.dirname(__file__)) 14 | tokenizer_path = os.path.join(base_path, "tokenizer_configs/hunyuan_dit/tokenizer") 15 | if tokenizer_t5_path is None: 16 | base_path = os.path.dirname(os.path.dirname(__file__)) 17 | tokenizer_t5_path = os.path.join(base_path, "tokenizer_configs/hunyuan_dit/tokenizer_t5") 18 | super().__init__() 19 | self.tokenizer = BertTokenizer.from_pretrained(tokenizer_path) 20 | with warnings.catch_warnings(): 21 | warnings.simplefilter("ignore") 22 | self.tokenizer_t5 = AutoTokenizer.from_pretrained(tokenizer_t5_path) 23 | 24 | 25 | def encode_prompt_using_signle_model(self, prompt, text_encoder, tokenizer, max_length, clip_skip, device): 26 | text_inputs = tokenizer( 27 | prompt, 28 | padding="max_length", 29 | max_length=max_length, 30 | truncation=True, 31 | return_attention_mask=True, 32 | return_tensors="pt", 33 | ) 34 | text_input_ids = text_inputs.input_ids 35 | attention_mask = text_inputs.attention_mask.to(device) 36 | prompt_embeds = text_encoder( 37 | text_input_ids.to(device), 38 | attention_mask=attention_mask, 39 | clip_skip=clip_skip 40 | ) 41 | return prompt_embeds, attention_mask 42 | 43 | 44 | def encode_prompt( 45 | self, 46 | text_encoder: BertModel, 47 | text_encoder_t5: T5EncoderModel, 48 | prompt, 49 | clip_skip=1, 50 | clip_skip_2=1, 51 | positive=True, 52 | device="cuda" 53 | ): 54 | prompt = self.process_prompt(prompt, positive=positive) 55 | 56 | # CLIP 57 | prompt_emb, attention_mask = self.encode_prompt_using_signle_model(prompt, text_encoder, self.tokenizer, self.tokenizer.model_max_length, clip_skip, device) 58 | 59 | # T5 60 | prompt_emb_t5, attention_mask_t5 = self.encode_prompt_using_signle_model(prompt, text_encoder_t5, self.tokenizer_t5, self.tokenizer_t5.model_max_length, clip_skip_2, device) 61 | 62 | return prompt_emb, attention_mask, prompt_emb_t5, attention_mask_t5 63 | -------------------------------------------------------------------------------- /diffsynth/prompts/sd_prompter.py: -------------------------------------------------------------------------------- 1 | from .utils import Prompter, tokenize_long_prompt 2 | from transformers import CLIPTokenizer 3 | from ..models import SDTextEncoder 4 | import os 5 | 6 | 7 | class SDPrompter(Prompter): 8 | def __init__(self, tokenizer_path=None): 9 | if tokenizer_path is None: 10 | base_path = os.path.dirname(os.path.dirname(__file__)) 11 | tokenizer_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion/tokenizer") 12 | super().__init__() 13 | self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path) 14 | 15 | def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True): 16 | prompt = self.process_prompt(prompt, positive=positive) 17 | input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device) 18 | prompt_emb = text_encoder(input_ids, clip_skip=clip_skip) 19 | prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1)) 20 | 21 | return prompt_emb -------------------------------------------------------------------------------- /diffsynth/prompts/sdxl_prompter.py: -------------------------------------------------------------------------------- 1 | from .utils import Prompter, tokenize_long_prompt 2 | from transformers import CLIPTokenizer 3 | from ..models import SDXLTextEncoder, SDXLTextEncoder2 4 | import 
torch, os 5 | 6 | 7 | class SDXLPrompter(Prompter): 8 | def __init__( 9 | self, 10 | tokenizer_path=None, 11 | tokenizer_2_path=None 12 | ): 13 | if tokenizer_path is None: 14 | base_path = os.path.dirname(os.path.dirname(__file__)) 15 | tokenizer_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion/tokenizer") 16 | if tokenizer_2_path is None: 17 | base_path = os.path.dirname(os.path.dirname(__file__)) 18 | tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_xl/tokenizer_2") 19 | super().__init__() 20 | self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path) 21 | self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path) 22 | 23 | def encode_prompt( 24 | self, 25 | text_encoder: SDXLTextEncoder, 26 | text_encoder_2: SDXLTextEncoder2, 27 | prompt, 28 | clip_skip=1, 29 | clip_skip_2=2, 30 | positive=True, 31 | device="cuda" 32 | ): 33 | prompt = self.process_prompt(prompt, positive=positive) 34 | 35 | # 1 36 | input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device) 37 | prompt_emb_1 = text_encoder(input_ids, clip_skip=clip_skip) 38 | 39 | # 2 40 | input_ids_2 = tokenize_long_prompt(self.tokenizer_2, prompt).to(device) 41 | add_text_embeds, prompt_emb_2 = text_encoder_2(input_ids_2, clip_skip=clip_skip_2) 42 | 43 | # Merge 44 | prompt_emb = torch.concatenate([prompt_emb_1, prompt_emb_2], dim=-1) 45 | 46 | # For very long prompt, we only use the first 77 tokens to compute `add_text_embeds`. 47 | add_text_embeds = add_text_embeds[0:1] 48 | prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1)) 49 | return add_text_embeds, prompt_emb 50 | -------------------------------------------------------------------------------- /diffsynth/prompts/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import CLIPTokenizer, AutoTokenizer 2 | from ..models import ModelManager 3 | import os 4 | 5 | 6 | def tokenize_long_prompt(tokenizer, prompt): 7 | # Get model_max_length from self.tokenizer 8 | length = tokenizer.model_max_length 9 | 10 | # To avoid the warning. set self.tokenizer.model_max_length to +oo. 11 | tokenizer.model_max_length = 99999999 12 | 13 | # Tokenize it! 14 | input_ids = tokenizer(prompt, return_tensors="pt").input_ids 15 | 16 | # Determine the real length. 17 | max_length = (input_ids.shape[1] + length - 1) // length * length 18 | 19 | # Restore tokenizer.model_max_length 20 | tokenizer.model_max_length = length 21 | 22 | # Tokenize it again with fixed length. 23 | input_ids = tokenizer( 24 | prompt, 25 | return_tensors="pt", 26 | padding="max_length", 27 | max_length=max_length, 28 | truncation=True 29 | ).input_ids 30 | 31 | # Reshape input_ids to fit the text encoder. 
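    # Worked example: with a CLIP tokenizer whose model_max_length is 77, a prompt that
    # tokenizes to 150 ids is padded up to max_length = (150 + 76) // 77 * 77 = 154, so
    # num_sentence = 154 // 77 = 2 and input_ids is reshaped from (1, 154) to (2, 77),
    # i.e. the long prompt becomes two 77-token chunks that the text encoder sees as a batch.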
32 | num_sentence = input_ids.shape[1] // length 33 | input_ids = input_ids.reshape((num_sentence, length)) 34 | 35 | return input_ids 36 | 37 | 38 | class BeautifulPrompt: 39 | def __init__(self, tokenizer_path=None, model=None): 40 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) 41 | self.model = model 42 | self.template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:' 43 | 44 | def __call__(self, raw_prompt): 45 | model_input = self.template.format(raw_prompt=raw_prompt) 46 | input_ids = self.tokenizer.encode(model_input, return_tensors='pt').to(self.model.device) 47 | outputs = self.model.generate( 48 | input_ids, 49 | max_new_tokens=384, 50 | do_sample=True, 51 | temperature=0.9, 52 | top_k=50, 53 | top_p=0.95, 54 | repetition_penalty=1.1, 55 | num_return_sequences=1 56 | ) 57 | prompt = raw_prompt + ", " + self.tokenizer.batch_decode( 58 | outputs[:, input_ids.size(1):], 59 | skip_special_tokens=True 60 | )[0].strip() 61 | return prompt 62 | 63 | 64 | class Translator: 65 | def __init__(self, tokenizer_path=None, model=None): 66 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) 67 | self.model = model 68 | 69 | def __call__(self, prompt): 70 | input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.model.device) 71 | output_ids = self.model.generate(input_ids) 72 | prompt = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] 73 | return prompt 74 | 75 | 76 | class Prompter: 77 | def __init__(self): 78 | self.tokenizer: CLIPTokenizer = None 79 | self.keyword_dict = {} 80 | self.translator: Translator = None 81 | self.beautiful_prompt: BeautifulPrompt = None 82 | 83 | def load_textual_inversion(self, textual_inversion_dict): 84 | self.keyword_dict = {} 85 | additional_tokens = [] 86 | for keyword in textual_inversion_dict: 87 | tokens, _ = textual_inversion_dict[keyword] 88 | additional_tokens += tokens 89 | self.keyword_dict[keyword] = " " + " ".join(tokens) + " " 90 | self.tokenizer.add_tokens(additional_tokens) 91 | 92 | def load_beautiful_prompt(self, model, model_path): 93 | model_folder = os.path.dirname(model_path) 94 | self.beautiful_prompt = BeautifulPrompt(tokenizer_path=model_folder, model=model) 95 | if model_folder.endswith("v2"): 96 | self.beautiful_prompt.template = """Converts a simple image description into a prompt. \ 97 | Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \ 98 | or use a number to specify the weight. 
You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \ 99 | but make sure there is a correlation between the input and output.\n\ 100 | ### Input: {raw_prompt}\n### Output:""" 101 | 102 | def load_translator(self, model, model_path): 103 | model_folder = os.path.dirname(model_path) 104 | self.translator = Translator(tokenizer_path=model_folder, model=model) 105 | 106 | def load_from_model_manager(self, model_manager: ModelManager): 107 | self.load_textual_inversion(model_manager.textual_inversion_dict) 108 | if "translator" in model_manager.model: 109 | self.load_translator(model_manager.model["translator"], model_manager.model_path["translator"]) 110 | if "beautiful_prompt" in model_manager.model: 111 | self.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"]) 112 | 113 | def process_prompt(self, prompt, positive=True): 114 | for keyword in self.keyword_dict: 115 | if keyword in prompt: 116 | prompt = prompt.replace(keyword, self.keyword_dict[keyword]) 117 | if positive and self.translator is not None: 118 | prompt = self.translator(prompt) 119 | print(f"Your prompt is translated: \"{prompt}\"") 120 | if positive and self.beautiful_prompt is not None: 121 | prompt = self.beautiful_prompt(prompt) 122 | print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"") 123 | return prompt 124 | -------------------------------------------------------------------------------- /diffsynth/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | from .ddim import EnhancedDDIMScheduler 2 | from .continuous_ode import ContinuousODEScheduler 3 | -------------------------------------------------------------------------------- /diffsynth/schedulers/continuous_ode.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class ContinuousODEScheduler(): 5 | 6 | def __init__(self, num_inference_steps=100, sigma_max=700.0, sigma_min=0.002, rho=7.0): 7 | self.sigma_max = sigma_max 8 | self.sigma_min = sigma_min 9 | self.rho = rho 10 | self.set_timesteps(num_inference_steps) 11 | 12 | 13 | def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0): 14 | ramp = torch.linspace(1-denoising_strength, 1, num_inference_steps) 15 | min_inv_rho = torch.pow(torch.tensor((self.sigma_min,)), (1 / self.rho)) 16 | max_inv_rho = torch.pow(torch.tensor((self.sigma_max,)), (1 / self.rho)) 17 | self.sigmas = torch.pow(max_inv_rho + ramp * (min_inv_rho - max_inv_rho), self.rho) 18 | self.timesteps = torch.log(self.sigmas) * 0.25 19 | 20 | 21 | def step(self, model_output, timestep, sample, to_final=False): 22 | timestep_id = torch.argmin((self.timesteps - timestep).abs()) 23 | sigma = self.sigmas[timestep_id] 24 | sample *= (sigma*sigma + 1).sqrt() 25 | estimated_sample = -sigma / (sigma*sigma + 1).sqrt() * model_output + 1 / (sigma*sigma + 1) * sample 26 | if to_final or timestep_id + 1 >= len(self.timesteps): 27 | prev_sample = estimated_sample 28 | else: 29 | sigma_ = self.sigmas[timestep_id + 1] 30 | derivative = 1 / sigma * (sample - estimated_sample) 31 | prev_sample = sample + derivative * (sigma_ - sigma) 32 | prev_sample /= (sigma_*sigma_ + 1).sqrt() 33 | return prev_sample 34 | 35 | 36 | def return_to_timestep(self, timestep, sample, sample_stablized): 37 | # This scheduler doesn't support this function. 
38 | pass 39 | 40 | 41 | def add_noise(self, original_samples, noise, timestep): 42 | timestep_id = torch.argmin((self.timesteps - timestep).abs()) 43 | sigma = self.sigmas[timestep_id] 44 | sample = (original_samples + noise * sigma) / (sigma*sigma + 1).sqrt() 45 | return sample 46 | 47 | 48 | def training_target(self, sample, noise, timestep): 49 | timestep_id = torch.argmin((self.timesteps - timestep).abs()) 50 | sigma = self.sigmas[timestep_id] 51 | target = (-(sigma*sigma + 1).sqrt() / sigma + 1 / (sigma*sigma + 1).sqrt() / sigma) * sample + 1 / (sigma*sigma + 1).sqrt() * noise 52 | return target 53 | 54 | 55 | def training_weight(self, timestep): 56 | timestep_id = torch.argmin((self.timesteps - timestep).abs()) 57 | sigma = self.sigmas[timestep_id] 58 | weight = (1 + sigma*sigma).sqrt() / sigma 59 | return weight 60 | -------------------------------------------------------------------------------- /diffsynth/schedulers/ddim.py: -------------------------------------------------------------------------------- 1 | import torch, math 2 | 3 | 4 | class EnhancedDDIMScheduler(): 5 | 6 | def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", prediction_type="epsilon"): 7 | self.num_train_timesteps = num_train_timesteps 8 | if beta_schedule == "scaled_linear": 9 | betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32)) 10 | elif beta_schedule == "linear": 11 | betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) 12 | else: 13 | raise NotImplementedError(f"{beta_schedule} is not implemented") 14 | self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0).tolist() 15 | self.set_timesteps(10) 16 | self.prediction_type = prediction_type 17 | 18 | 19 | def set_timesteps(self, num_inference_steps, denoising_strength=1.0): 20 | # The timesteps are aligned to 999...0, which is different from other implementations, 21 | # but I think this implementation is more reasonable in theory. 
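        # For example, with num_train_timesteps=1000, denoising_strength=1.0 and num_inference_steps=4:
        # max_timestep = 999, step_length = 999 / 3 = 333, and the timesteps are [999, 666, 333, 0].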
22 | max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0) 23 | num_inference_steps = min(num_inference_steps, max_timestep + 1) 24 | if num_inference_steps == 1: 25 | self.timesteps = [max_timestep] 26 | else: 27 | step_length = max_timestep / (num_inference_steps - 1) 28 | self.timesteps = [round(max_timestep - i*step_length) for i in range(num_inference_steps)] 29 | 30 | 31 | def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev): 32 | if self.prediction_type == "epsilon": 33 | weight_e = math.sqrt(1 - alpha_prod_t_prev) - math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t) 34 | weight_x = math.sqrt(alpha_prod_t_prev / alpha_prod_t) 35 | prev_sample = sample * weight_x + model_output * weight_e 36 | elif self.prediction_type == "v_prediction": 37 | weight_e = -math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t)) + math.sqrt(alpha_prod_t * (1 - alpha_prod_t_prev)) 38 | weight_x = math.sqrt(alpha_prod_t * alpha_prod_t_prev) + math.sqrt((1 - alpha_prod_t) * (1 - alpha_prod_t_prev)) 39 | prev_sample = sample * weight_x + model_output * weight_e 40 | else: 41 | raise NotImplementedError(f"{self.prediction_type} is not implemented") 42 | return prev_sample 43 | 44 | 45 | def step(self, model_output, timestep, sample, to_final=False): 46 | alpha_prod_t = self.alphas_cumprod[timestep] 47 | timestep_id = self.timesteps.index(timestep) 48 | if to_final or timestep_id + 1 >= len(self.timesteps): 49 | alpha_prod_t_prev = 1.0 50 | else: 51 | timestep_prev = self.timesteps[timestep_id + 1] 52 | alpha_prod_t_prev = self.alphas_cumprod[timestep_prev] 53 | 54 | return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev) 55 | 56 | 57 | def return_to_timestep(self, timestep, sample, sample_stablized): 58 | alpha_prod_t = self.alphas_cumprod[timestep] 59 | noise_pred = (sample - math.sqrt(alpha_prod_t) * sample_stablized) / math.sqrt(1 - alpha_prod_t) 60 | return noise_pred 61 | 62 | 63 | def add_noise(self, original_samples, noise, timestep): 64 | sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[timestep]) 65 | sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[timestep]) 66 | noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise 67 | return noisy_samples 68 | 69 | def training_target(self, sample, noise, timestep): 70 | sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[timestep]) 71 | sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[timestep]) 72 | target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample 73 | return target 74 | -------------------------------------------------------------------------------- /diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "cls_token": "[CLS]", 3 | "mask_token": "[MASK]", 4 | "pad_token": "[PAD]", 5 | "sep_token": "[SEP]", 6 | "unk_token": "[UNK]" 7 | } 8 | -------------------------------------------------------------------------------- /diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "cls_token": "[CLS]", 3 | "do_basic_tokenize": true, 4 | "do_lower_case": true, 5 | "mask_token": "[MASK]", 6 | "name_or_path": "hfl/chinese-roberta-wwm-ext", 7 | "never_split": null, 8 | "pad_token": "[PAD]", 9 | "sep_token": "[SEP]", 10 | "special_tokens_map_file": 
"/home/chenweifeng/.cache/huggingface/hub/models--hfl--chinese-roberta-wwm-ext/snapshots/5c58d0b8ec1d9014354d691c538661bf00bfdb44/special_tokens_map.json", 11 | "strip_accents": null, 12 | "tokenize_chinese_chars": true, 13 | "tokenizer_class": "BertTokenizer", 14 | "unk_token": "[UNK]", 15 | "model_max_length": 77 16 | } 17 | -------------------------------------------------------------------------------- /diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "/home/patrick/t5/mt5-xl", 3 | "architectures": [ 4 | "MT5ForConditionalGeneration" 5 | ], 6 | "d_ff": 5120, 7 | "d_kv": 64, 8 | "d_model": 2048, 9 | "decoder_start_token_id": 0, 10 | "dropout_rate": 0.1, 11 | "eos_token_id": 1, 12 | "feed_forward_proj": "gated-gelu", 13 | "initializer_factor": 1.0, 14 | "is_encoder_decoder": true, 15 | "layer_norm_epsilon": 1e-06, 16 | "model_type": "mt5", 17 | "num_decoder_layers": 24, 18 | "num_heads": 32, 19 | "num_layers": 24, 20 | "output_past": true, 21 | "pad_token_id": 0, 22 | "relative_attention_num_buckets": 32, 23 | "tie_word_embeddings": false, 24 | "tokenizer_class": "T5Tokenizer", 25 | "transformers_version": "4.10.0.dev0", 26 | "use_cache": true, 27 | "vocab_size": 250112 28 | } 29 | -------------------------------------------------------------------------------- /diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {"eos_token": "", "unk_token": "", "pad_token": ""} -------------------------------------------------------------------------------- /diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-DiffSynth-Studio/91d0a6bd9efa6a1d4a0b880984e21fe66acf4eb0/diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model -------------------------------------------------------------------------------- /diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | {"eos_token": "", "unk_token": "", "pad_token": "", "extra_ids": 0, "additional_special_tokens": null, "special_tokens_map_file": "", "tokenizer_file": null, "name_or_path": "google/mt5-small", "model_max_length": 256, "legacy": true} -------------------------------------------------------------------------------- /diffsynth/tokenizer_configs/stable_diffusion/tokenizer/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "bos_token": { 3 | "content": "<|startoftext|>", 4 | "lstrip": false, 5 | "normalized": true, 6 | "rstrip": false, 7 | "single_word": false 8 | }, 9 | "eos_token": { 10 | "content": "<|endoftext|>", 11 | "lstrip": false, 12 | "normalized": true, 13 | "rstrip": false, 14 | "single_word": false 15 | }, 16 | "pad_token": "<|endoftext|>", 17 | "unk_token": { 18 | "content": "<|endoftext|>", 19 | "lstrip": false, 20 | "normalized": true, 21 | "rstrip": false, 22 | "single_word": false 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /diffsynth/tokenizer_configs/stable_diffusion/tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_prefix_space": false, 3 | 
"bos_token": { 4 | "__type": "AddedToken", 5 | "content": "<|startoftext|>", 6 | "lstrip": false, 7 | "normalized": true, 8 | "rstrip": false, 9 | "single_word": false 10 | }, 11 | "do_lower_case": true, 12 | "eos_token": { 13 | "__type": "AddedToken", 14 | "content": "<|endoftext|>", 15 | "lstrip": false, 16 | "normalized": true, 17 | "rstrip": false, 18 | "single_word": false 19 | }, 20 | "errors": "replace", 21 | "model_max_length": 77, 22 | "name_or_path": "openai/clip-vit-large-patch14", 23 | "pad_token": "<|endoftext|>", 24 | "special_tokens_map_file": "./special_tokens_map.json", 25 | "tokenizer_class": "CLIPTokenizer", 26 | "unk_token": { 27 | "__type": "AddedToken", 28 | "content": "<|endoftext|>", 29 | "lstrip": false, 30 | "normalized": true, 31 | "rstrip": false, 32 | "single_word": false 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "bos_token": { 3 | "content": "<|startoftext|>", 4 | "lstrip": false, 5 | "normalized": true, 6 | "rstrip": false, 7 | "single_word": false 8 | }, 9 | "eos_token": { 10 | "content": "<|endoftext|>", 11 | "lstrip": false, 12 | "normalized": true, 13 | "rstrip": false, 14 | "single_word": false 15 | }, 16 | "pad_token": "!", 17 | "unk_token": { 18 | "content": "<|endoftext|>", 19 | "lstrip": false, 20 | "normalized": true, 21 | "rstrip": false, 22 | "single_word": false 23 | } 24 | } -------------------------------------------------------------------------------- /diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_prefix_space": false, 3 | "added_tokens_decoder": { 4 | "0": { 5 | "content": "!", 6 | "lstrip": false, 7 | "normalized": false, 8 | "rstrip": false, 9 | "single_word": false, 10 | "special": true 11 | }, 12 | "49406": { 13 | "content": "<|startoftext|>", 14 | "lstrip": false, 15 | "normalized": true, 16 | "rstrip": false, 17 | "single_word": false, 18 | "special": true 19 | }, 20 | "49407": { 21 | "content": "<|endoftext|>", 22 | "lstrip": false, 23 | "normalized": true, 24 | "rstrip": false, 25 | "single_word": false, 26 | "special": true 27 | } 28 | }, 29 | "bos_token": "<|startoftext|>", 30 | "clean_up_tokenization_spaces": true, 31 | "do_lower_case": true, 32 | "eos_token": "<|endoftext|>", 33 | "errors": "replace", 34 | "model_max_length": 77, 35 | "pad_token": "!", 36 | "tokenizer_class": "CLIPTokenizer", 37 | "unk_token": "<|endoftext|>" 38 | } -------------------------------------------------------------------------------- /models/put diffsynth studio models here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-DiffSynth-Studio/91d0a6bd9efa6a1d4a0b880984e21fe66acf4eb0/models/put diffsynth studio models here -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cupy-cuda12x 2 | pip 3 | transformers 4 | controlnet-aux==0.0.7 5 | streamlit 6 | streamlit-drawable-canvas 7 | imageio 8 | imageio[ffmpeg] 9 | safetensors 10 | einops 11 | sentencepiece -------------------------------------------------------------------------------- /util_nodes.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import folder_paths 3 | now_dir = os.path.dirname(os.path.abspath(__file__)) 4 | input_dir = folder_paths.get_input_directory() 5 | output_dir = folder_paths.get_output_directory() 6 | 7 | class LoadVideo: 8 | @classmethod 9 | def INPUT_TYPES(s): 10 | files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.split('.')[-1] in ["mp4", "webm","mkv","avi"]] 11 | return {"required":{ 12 | "video":(files,), 13 | }} 14 | 15 | CATEGORY = "AIFSH_DiffSynth-Studio" 16 | DESCRIPTION = "hello world!" 17 | 18 | RETURN_TYPES = ("VIDEO",) 19 | 20 | OUTPUT_NODE = False 21 | 22 | FUNCTION = "load_video" 23 | 24 | def load_video(self, video): 25 | video_path = os.path.join(input_dir,video) 26 | return (video_path,) 27 | 28 | class PreViewVideo: 29 | @classmethod 30 | def INPUT_TYPES(s): 31 | return {"required":{ 32 | "video":("VIDEO",), 33 | }} 34 | 35 | CATEGORY = "AIFSH_DiffSynth-Studio" 36 | DESCRIPTION = "hello world!" 37 | 38 | RETURN_TYPES = () 39 | 40 | OUTPUT_NODE = True 41 | 42 | FUNCTION = "load_video" 43 | 44 | def load_video(self, video): 45 | video_name = os.path.basename(video) 46 | video_path_name = os.path.basename(os.path.dirname(video)) 47 | return {"ui":{"video":[video_name,video_path_name]}} 48 | -------------------------------------------------------------------------------- /web.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIFSH/ComfyUI-DiffSynth-Studio/91d0a6bd9efa6a1d4a0b880984e21fe66acf4eb0/web.png -------------------------------------------------------------------------------- /web/js/previewVideo.js: -------------------------------------------------------------------------------- 1 | import { app } from "../../../scripts/app.js"; 2 | import { api } from '../../../scripts/api.js' 3 | 4 | function fitHeight(node) { 5 | node.setSize([node.size[0], node.computeSize([node.size[0], node.size[1]])[1]]) 6 | node?.graph?.setDirtyCanvas(true); 7 | } 8 | function chainCallback(object, property, callback) { 9 | if (object == undefined) { 10 | //This should not happen. 11 | console.error("Tried to add callback to non-existant object") 12 | return; 13 | } 14 | if (property in object) { 15 | const callback_orig = object[property] 16 | object[property] = function () { 17 | const r = callback_orig.apply(this, arguments); 18 | callback.apply(this, arguments); 19 | return r 20 | }; 21 | } else { 22 | object[property] = callback; 23 | } 24 | } 25 | 26 | function addPreviewOptions(nodeType) { 27 | chainCallback(nodeType.prototype, "getExtraMenuOptions", function(_, options) { 28 | // The intended way of appending options is returning a list of extra options, 29 | // but this isn't used in widgetInputs.js and would require 30 | // less generalization of chainCallback 31 | let optNew = [] 32 | try { 33 | const previewWidget = this.widgets.find((w) => w.name === "videopreview"); 34 | 35 | let url = null 36 | if (previewWidget.videoEl?.hidden == false && previewWidget.videoEl.src) { 37 | //Use full quality video 38 | //url = api.apiURL('/view?' 
+ new URLSearchParams(previewWidget.value.params)); 39 | url = previewWidget.videoEl.src 40 | } 41 | if (url) { 42 | optNew.push( 43 | { 44 | content: "Open preview", 45 | callback: () => { 46 | window.open(url, "_blank") 47 | }, 48 | }, 49 | { 50 | content: "Save preview", 51 | callback: () => { 52 | const a = document.createElement("a"); 53 | a.href = url; 54 | a.setAttribute("download", new URLSearchParams(previewWidget.value.params).get("filename")); 55 | document.body.append(a); 56 | a.click(); 57 | requestAnimationFrame(() => a.remove()); 58 | }, 59 | } 60 | ); 61 | } 62 | if(options.length > 0 && options[0] != null && optNew.length > 0) { 63 | optNew.push(null); 64 | } 65 | options.unshift(...optNew); 66 | 67 | } catch (error) { 68 | console.log(error); 69 | } 70 | 71 | }); 72 | } 73 | function previewVideo(node,file,type){ 74 | var element = document.createElement("div"); 75 | const previewNode = node; 76 | var previewWidget = node.addDOMWidget("videopreview", "preview", element, { 77 | serialize: false, 78 | hideOnZoom: false, 79 | getValue() { 80 | return element.value; 81 | }, 82 | setValue(v) { 83 | element.value = v; 84 | }, 85 | }); 86 | previewWidget.computeSize = function(width) { 87 | if (this.aspectRatio && !this.parentEl.hidden) { 88 | let height = (previewNode.size[0]-20)/ this.aspectRatio + 10; 89 | if (!(height > 0)) { 90 | height = 0; 91 | } 92 | this.computedHeight = height + 10; 93 | return [width, height]; 94 | } 95 | return [width, -4];//no loaded src, widget should not display 96 | } 97 | // element.style['pointer-events'] = "none" 98 | previewWidget.value = {hidden: false, paused: false, params: {}} 99 | previewWidget.parentEl = document.createElement("div"); 100 | previewWidget.parentEl.className = "video_preview"; 101 | previewWidget.parentEl.style['width'] = "100%" 102 | element.appendChild(previewWidget.parentEl); 103 | previewWidget.videoEl = document.createElement("video"); 104 | previewWidget.videoEl.controls = true; 105 | previewWidget.videoEl.loop = false; 106 | previewWidget.videoEl.muted = false; 107 | previewWidget.videoEl.style['width'] = "100%" 108 | previewWidget.videoEl.addEventListener("loadedmetadata", () => { 109 | 110 | previewWidget.aspectRatio = previewWidget.videoEl.videoWidth / previewWidget.videoEl.videoHeight; 111 | fitHeight(this); 112 | }); 113 | previewWidget.videoEl.addEventListener("error", () => { 114 | //TODO: consider a way to properly notify the user why a preview isn't shown. 115 | previewWidget.parentEl.hidden = true; 116 | fitHeight(this); 117 | }); 118 | 119 | let params = { 120 | "filename": file, 121 | "type": type, 122 | } 123 | 124 | previewWidget.parentEl.hidden = previewWidget.value.hidden; 125 | previewWidget.videoEl.autoplay = !previewWidget.value.paused && !previewWidget.value.hidden; 126 | let target_width = 256 127 | if (element.style?.width) { 128 | //overscale to allow scrolling. Endpoint won't return higher than native 129 | target_width = element.style.width.slice(0,-2)*2; 130 | } 131 | if (!params.force_size || params.force_size.includes("?") || params.force_size == "Disabled") { 132 | params.force_size = target_width+"x?" 133 | } else { 134 | let size = params.force_size.split("x") 135 | let ar = parseInt(size[0])/parseInt(size[1]) 136 | params.force_size = target_width+"x"+(target_width/ar) 137 | } 138 | 139 | previewWidget.videoEl.src = api.apiURL('/view?' 
+ new URLSearchParams(params)); 140 | 141 | previewWidget.videoEl.hidden = false; 142 | previewWidget.parentEl.appendChild(previewWidget.videoEl) 143 | } 144 | 145 | app.registerExtension({ 146 | name: "DiffSynth-Studio.VideoPreviewer", 147 | async beforeRegisterNodeDef(nodeType, nodeData, app) { 148 | if (nodeData?.name == "PreViewVideo") { 149 | nodeType.prototype.onExecuted = function (data) { 150 | previewVideo(this, data.video[0], data.video[1]); 151 | } 152 | } 153 | } 154 | }); 155 | -------------------------------------------------------------------------------- /web/js/uploadVideo.js: -------------------------------------------------------------------------------- 1 | import { app } from "../../../scripts/app.js"; 2 | import { api } from '../../../scripts/api.js' 3 | import { ComfyWidgets } from "../../../scripts/widgets.js" 4 | 5 | function fitHeight(node) { 6 | node.setSize([node.size[0], node.computeSize([node.size[0], node.size[1]])[1]]) 7 | node?.graph?.setDirtyCanvas(true); 8 | } 9 | 10 | function previewVideo(node,file){ 11 | while (node.widgets.length > 2){ 12 | node.widgets.pop() 13 | } 14 | try { 15 | var el = document.getElementById("uploadVideo"); 16 | el.remove(); 17 | } catch (error) { 18 | console.log(error); 19 | } 20 | var element = document.createElement("div"); 21 | element.id = "uploadVideo"; 22 | const previewNode = node; 23 | var previewWidget = node.addDOMWidget("videopreview", "preview", element, { 24 | serialize: false, 25 | hideOnZoom: false, 26 | getValue() { 27 | return element.value; 28 | }, 29 | setValue(v) { 30 | element.value = v; 31 | }, 32 | }); 33 | previewWidget.computeSize = function(width) { 34 | if (this.aspectRatio && !this.parentEl.hidden) { 35 | let height = (previewNode.size[0]-20)/ this.aspectRatio + 10; 36 | if (!(height > 0)) { 37 | height = 0; 38 | } 39 | this.computedHeight = height + 10; 40 | return [width, height]; 41 | } 42 | return [width, -4];//no loaded src, widget should not display 43 | } 44 | // element.style['pointer-events'] = "none" 45 | previewWidget.value = {hidden: false, paused: false, params: {}} 46 | previewWidget.parentEl = document.createElement("div"); 47 | previewWidget.parentEl.className = "video_preview"; 48 | previewWidget.parentEl.style['width'] = "100%" 49 | element.appendChild(previewWidget.parentEl); 50 | previewWidget.videoEl = document.createElement("video"); 51 | previewWidget.videoEl.controls = true; 52 | previewWidget.videoEl.loop = false; 53 | previewWidget.videoEl.muted = false; 54 | previewWidget.videoEl.style['width'] = "100%" 55 | previewWidget.videoEl.addEventListener("loadedmetadata", () => { 56 | 57 | previewWidget.aspectRatio = previewWidget.videoEl.videoWidth / previewWidget.videoEl.videoHeight; 58 | fitHeight(this); 59 | }); 60 | previewWidget.videoEl.addEventListener("error", () => { 61 | //TODO: consider a way to properly notify the user why a preview isn't shown. 62 | previewWidget.parentEl.hidden = true; 63 | fitHeight(this); 64 | }); 65 | 66 | let params = { 67 | "filename": file, 68 | "type": "input", 69 | } 70 | 71 | previewWidget.parentEl.hidden = previewWidget.value.hidden; 72 | previewWidget.videoEl.autoplay = !previewWidget.value.paused && !previewWidget.value.hidden; 73 | let target_width = 256 74 | if (element.style?.width) { 75 | //overscale to allow scrolling. 
Endpoint won't return higher than native 76 | target_width = element.style.width.slice(0,-2)*2; 77 | } 78 | if (!params.force_size || params.force_size.includes("?") || params.force_size == "Disabled") { 79 | params.force_size = target_width+"x?" 80 | } else { 81 | let size = params.force_size.split("x") 82 | let ar = parseInt(size[0])/parseInt(size[1]) 83 | params.force_size = target_width+"x"+(target_width/ar) 84 | } 85 | 86 | previewWidget.videoEl.src = api.apiURL('/view?' + new URLSearchParams(params)); 87 | 88 | previewWidget.videoEl.hidden = false; 89 | previewWidget.parentEl.appendChild(previewWidget.videoEl) 90 | } 91 | 92 | function videoUpload(node, inputName, inputData, app) { 93 | const videoWidget = node.widgets.find((w) => w.name === "video"); 94 | let uploadWidget; 95 | /* 96 | A method that returns the required style for the html 97 | */ 98 | var default_value = videoWidget.value; 99 | Object.defineProperty(videoWidget, "value", { 100 | set : function(value) { 101 | this._real_value = value; 102 | }, 103 | 104 | get : function() { 105 | let value = ""; 106 | if (this._real_value) { 107 | value = this._real_value; 108 | } else { 109 | return default_value; 110 | } 111 | 112 | if (value.filename) { 113 | let real_value = value; 114 | value = ""; 115 | if (real_value.subfolder) { 116 | value = real_value.subfolder + "/"; 117 | } 118 | 119 | value += real_value.filename; 120 | 121 | if(real_value.type && real_value.type !== "input") 122 | value += ` [${real_value.type}]`; 123 | } 124 | return value; 125 | } 126 | }); 127 | async function uploadFile(file, updateNode, pasted = false) { 128 | try { 129 | // Wrap file in formdata so it includes filename 130 | const body = new FormData(); 131 | body.append("image", file); 132 | if (pasted) body.append("subfolder", "pasted"); 133 | const resp = await api.fetchApi("/upload/image", { 134 | method: "POST", 135 | body, 136 | }); 137 | 138 | if (resp.status === 200) { 139 | const data = await resp.json(); 140 | // Add the file to the dropdown list and update the widget value 141 | let path = data.name; 142 | if (data.subfolder) path = data.subfolder + "/" + path; 143 | 144 | if (!videoWidget.options.values.includes(path)) { 145 | videoWidget.options.values.push(path); 146 | } 147 | 148 | if (updateNode) { 149 | videoWidget.value = path; 150 | previewVideo(node,path) 151 | 152 | } 153 | } else { 154 | alert(resp.status + " - " + resp.statusText); 155 | } 156 | } catch (error) { 157 | alert(error); 158 | } 159 | } 160 | 161 | const fileInput = document.createElement("input"); 162 | Object.assign(fileInput, { 163 | type: "file", 164 | accept: "video/webm,video/mp4,video/mkv,video/avi", 165 | style: "display: none", 166 | onchange: async () => { 167 | if (fileInput.files.length) { 168 | await uploadFile(fileInput.files[0], true); 169 | } 170 | }, 171 | }); 172 | document.body.append(fileInput); 173 | 174 | // Create the button widget for selecting the files 175 | uploadWidget = node.addWidget("button", "choose video file to upload", "Video", () => { 176 | fileInput.click(); 177 | }); 178 | 179 | uploadWidget.serialize = false; 180 | 181 | previewVideo(node, videoWidget.value); 182 | const cb = node.callback; 183 | videoWidget.callback = function () { 184 | previewVideo(node,videoWidget.value); 185 | if (cb) { 186 | return cb.apply(this, arguments); 187 | } 188 | }; 189 | 190 | return { widget: uploadWidget }; 191 | } 192 | 193 | ComfyWidgets.VIDEOPLOAD = videoUpload; 194 | 195 | app.registerExtension({ 196 | name: 
"DiffSynth-Studio.UploadVideo", 197 | async beforeRegisterNodeDef(nodeType, nodeData, app) { 198 | if (nodeData?.name == "LoadVideo") { 199 | nodeData.input.required.upload = ["VIDEOPLOAD"]; 200 | } 201 | }, 202 | }); 203 | 204 | -------------------------------------------------------------------------------- /workfolws/diffutoon_workflow.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 18, 3 | "last_link_id": 22, 4 | "nodes": [ 5 | { 6 | "id": 10, 7 | "type": "PreViewVideo", 8 | "pos": [ 9 | 1017.6000366210938, 10 | 406.20001220703125 11 | ], 12 | "size": { 13 | "0": 210, 14 | "1": 26 15 | }, 16 | "flags": {}, 17 | "order": 7, 18 | "mode": 0, 19 | "inputs": [ 20 | { 21 | "name": "video", 22 | "type": "VIDEO", 23 | "link": 21 24 | } 25 | ], 26 | "properties": { 27 | "Node name for S&R": "PreViewVideo" 28 | } 29 | }, 30 | { 31 | "id": 6, 32 | "type": "DiffTextNode", 33 | "pos": [ 34 | 505, 35 | 485 36 | ], 37 | "size": { 38 | "0": 400, 39 | "1": 200 40 | }, 41 | "flags": {}, 42 | "order": 0, 43 | "mode": 0, 44 | "outputs": [ 45 | { 46 | "name": "TEXT", 47 | "type": "TEXT", 48 | "links": [ 49 | 18 50 | ], 51 | "shape": 3 52 | } 53 | ], 54 | "properties": { 55 | "Node name for S&R": "DiffTextNode" 56 | }, 57 | "widgets_values": [ 58 | "verybadimagenegative_v1.3" 59 | ] 60 | }, 61 | { 62 | "id": 16, 63 | "type": "ControlNetPathLoader", 64 | "pos": [ 65 | 1138, 66 | 206 67 | ], 68 | "size": { 69 | "0": 315, 70 | "1": 106 71 | }, 72 | "flags": {}, 73 | "order": 1, 74 | "mode": 0, 75 | "outputs": [ 76 | { 77 | "name": "ControlNetConfigUnit", 78 | "type": "ControlNetConfigUnit", 79 | "links": [ 80 | 20 81 | ], 82 | "shape": 3 83 | } 84 | ], 85 | "properties": { 86 | "Node name for S&R": "ControlNetPathLoader" 87 | }, 88 | "widgets_values": [ 89 | "tile", 90 | 0.5, 91 | null 92 | ] 93 | }, 94 | { 95 | "id": 15, 96 | "type": "ControlNetPathLoader", 97 | "pos": [ 98 | 1128, 99 | 17 100 | ], 101 | "size": { 102 | "0": 315, 103 | "1": 106 104 | }, 105 | "flags": {}, 106 | "order": 2, 107 | "mode": 0, 108 | "outputs": [ 109 | { 110 | "name": "ControlNetConfigUnit", 111 | "type": "ControlNetConfigUnit", 112 | "links": [ 113 | 19 114 | ], 115 | "shape": 3 116 | } 117 | ], 118 | "properties": { 119 | "Node name for S&R": "ControlNetPathLoader" 120 | }, 121 | "widgets_values": [ 122 | "lineart", 123 | 0.5, 124 | "control_v11p_sd15_lineart.pth" 125 | ] 126 | }, 127 | { 128 | "id": 3, 129 | "type": "LoadVideo", 130 | "pos": [ 131 | 153, 132 | -100 133 | ], 134 | "size": { 135 | "0": 315, 136 | "1": 383 137 | }, 138 | "flags": {}, 139 | "order": 3, 140 | "mode": 0, 141 | "outputs": [ 142 | { 143 | "name": "VIDEO", 144 | "type": "VIDEO", 145 | "links": [ 146 | 15 147 | ], 148 | "shape": 3, 149 | "slot_index": 0 150 | } 151 | ], 152 | "properties": { 153 | "Node name for S&R": "LoadVideo" 154 | }, 155 | "widgets_values": [ 156 | "diffutoondemo.mp4", 157 | "Video", 158 | { 159 | "hidden": false, 160 | "paused": false, 161 | "params": {} 162 | } 163 | ] 164 | }, 165 | { 166 | "id": 5, 167 | "type": "DiffTextNode", 168 | "pos": [ 169 | 104, 170 | 664 171 | ], 172 | "size": { 173 | "0": 400, 174 | "1": 200 175 | }, 176 | "flags": {}, 177 | "order": 4, 178 | "mode": 0, 179 | "outputs": [ 180 | { 181 | "name": "TEXT", 182 | "type": "TEXT", 183 | "links": [ 184 | 17 185 | ], 186 | "shape": 3, 187 | "slot_index": 0 188 | } 189 | ], 190 | "properties": { 191 | "Node name for S&R": "DiffTextNode" 192 | }, 193 | "widgets_values": [ 194 | "best quality, perfect 
anime illustration, light, a girl is dancing, smile, solo" 195 | ] 196 | }, 197 | { 198 | "id": 17, 199 | "type": "DiffutoonNode", 200 | "pos": [ 201 | 674.969524572754, 202 | 27.48489999999994 203 | ], 204 | "size": { 205 | "0": 315, 206 | "1": 370 207 | }, 208 | "flags": {}, 209 | "order": 6, 210 | "mode": 0, 211 | "inputs": [ 212 | { 213 | "name": "source_video_path", 214 | "type": "VIDEO", 215 | "link": 15 216 | }, 217 | { 218 | "name": "sd_model_path", 219 | "type": "SD_MODEL_PATH", 220 | "link": 22, 221 | "slot_index": 1 222 | }, 223 | { 224 | "name": "postive_prompt", 225 | "type": "TEXT", 226 | "link": 17 227 | }, 228 | { 229 | "name": "negative_prompt", 230 | "type": "TEXT", 231 | "link": 18, 232 | "slot_index": 3 233 | }, 234 | { 235 | "name": "controlnet1", 236 | "type": "ControlNetConfigUnit", 237 | "link": 19, 238 | "slot_index": 4 239 | }, 240 | { 241 | "name": "controlnet2", 242 | "type": "ControlNetConfigUnit", 243 | "link": 20, 244 | "slot_index": 5 245 | }, 246 | { 247 | "name": "controlnet3", 248 | "type": "ControlNetConfigUnit", 249 | "link": null 250 | } 251 | ], 252 | "outputs": [ 253 | { 254 | "name": "VIDEO", 255 | "type": "VIDEO", 256 | "links": [ 257 | 21 258 | ], 259 | "shape": 3, 260 | "slot_index": 0 261 | } 262 | ], 263 | "properties": { 264 | "Node name for S&R": "DiffutoonNode" 265 | }, 266 | "widgets_values": [ 267 | 40, 268 | 1, 269 | 531, 270 | "randomize", 271 | 3, 272 | 10, 273 | 32, 274 | 16, 275 | 0 276 | ] 277 | }, 278 | { 279 | "id": 18, 280 | "type": "SDPathLoader", 281 | "pos": [ 282 | 94, 283 | 399 284 | ], 285 | "size": { 286 | "0": 315, 287 | "1": 130 288 | }, 289 | "flags": {}, 290 | "order": 5, 291 | "mode": 0, 292 | "outputs": [ 293 | { 294 | "name": "SD_MODEL_PATH", 295 | "type": "SD_MODEL_PATH", 296 | "links": [ 297 | 22 298 | ], 299 | "shape": 3 300 | } 301 | ], 302 | "properties": { 303 | "Node name for S&R": "SDPathLoader" 304 | }, 305 | "widgets_values": [ 306 | "philz1337x/flat2DAnimerge_v45Sharp", 307 | "flat2DAnimerge_v45Sharp.safetensors", 308 | "HuggingFace", 309 | "flat2DAnimerge_v45Sharp.safetensors" 310 | ] 311 | } 312 | ], 313 | "links": [ 314 | [ 315 | 15, 316 | 3, 317 | 0, 318 | 17, 319 | 0, 320 | "VIDEO" 321 | ], 322 | [ 323 | 17, 324 | 5, 325 | 0, 326 | 17, 327 | 2, 328 | "TEXT" 329 | ], 330 | [ 331 | 18, 332 | 6, 333 | 0, 334 | 17, 335 | 3, 336 | "TEXT" 337 | ], 338 | [ 339 | 19, 340 | 15, 341 | 0, 342 | 17, 343 | 4, 344 | "ControlNetConfigUnit" 345 | ], 346 | [ 347 | 20, 348 | 16, 349 | 0, 350 | 17, 351 | 5, 352 | "ControlNetConfigUnit" 353 | ], 354 | [ 355 | 21, 356 | 17, 357 | 0, 358 | 10, 359 | 0, 360 | "VIDEO" 361 | ], 362 | [ 363 | 22, 364 | 18, 365 | 0, 366 | 17, 367 | 1, 368 | "SD_MODEL_PATH" 369 | ] 370 | ], 371 | "groups": [], 372 | "config": {}, 373 | "extra": { 374 | "ds": { 375 | "scale": 0.6830134553650705, 376 | "offset": [ 377 | 88.7050890441895, 378 | 88.17900000000009 379 | ] 380 | } 381 | }, 382 | "version": 0.4 383 | } -------------------------------------------------------------------------------- /workfolws/exvideo_workflow.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 5, 3 | "last_link_id": 4, 4 | "nodes": [ 5 | { 6 | "id": 3, 7 | "type": "SDPathLoader", 8 | "pos": [ 9 | 40, 10 | 423 11 | ], 12 | "size": { 13 | "0": 420.60003662109375, 14 | "1": 136.60003662109375 15 | }, 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "outputs": [ 20 | { 21 | "name": "SD_MODEL_PATH", 22 | "type": "SD_MODEL_PATH", 23 | "links": [ 24 | 2 25 | 
], 26 | "shape": 3 27 | } 28 | ], 29 | "properties": { 30 | "Node name for S&R": "SDPathLoader" 31 | }, 32 | "widgets_values": [ 33 | "stabilityai/stable-video-diffusion-img2vid-xt", 34 | "svd_xt.safetensors", 35 | "HuggingFace", 36 | "svd_xt.safetensors" 37 | ] 38 | }, 39 | { 40 | "id": 4, 41 | "type": "SDPathLoader", 42 | "pos": [ 43 | 514, 44 | 435 45 | ], 46 | "size": { 47 | "0": 489.800048828125, 48 | "1": 130 49 | }, 50 | "flags": {}, 51 | "order": 1, 52 | "mode": 0, 53 | "outputs": [ 54 | { 55 | "name": "SD_MODEL_PATH", 56 | "type": "SD_MODEL_PATH", 57 | "links": [ 58 | 3 59 | ], 60 | "shape": 3 61 | } 62 | ], 63 | "properties": { 64 | "Node name for S&R": "SDPathLoader" 65 | }, 66 | "widgets_values": [ 67 | "ECNU-CILab/ExVideo-SVD-128f-v1", 68 | "model.fp16.safetensors", 69 | "HuggingFace", 70 | "model.fp16.safetensors" 71 | ] 72 | }, 73 | { 74 | "id": 2, 75 | "type": "LoadImage", 76 | "pos": [ 77 | 51, 78 | 60 79 | ], 80 | "size": { 81 | "0": 315, 82 | "1": 314 83 | }, 84 | "flags": {}, 85 | "order": 2, 86 | "mode": 0, 87 | "outputs": [ 88 | { 89 | "name": "IMAGE", 90 | "type": "IMAGE", 91 | "links": [ 92 | 1 93 | ], 94 | "shape": 3 95 | }, 96 | { 97 | "name": "MASK", 98 | "type": "MASK", 99 | "links": null, 100 | "shape": 3 101 | } 102 | ], 103 | "properties": { 104 | "Node name for S&R": "LoadImage" 105 | }, 106 | "widgets_values": [ 107 | "01c6925ae17636a801214a61208e31.png@2o.png", 108 | "image" 109 | ] 110 | }, 111 | { 112 | "id": 1, 113 | "type": "ExVideoNode", 114 | "pos": [ 115 | 633, 116 | 97 117 | ], 118 | "size": { 119 | "0": 315, 120 | "1": 218 121 | }, 122 | "flags": {}, 123 | "order": 3, 124 | "mode": 0, 125 | "inputs": [ 126 | { 127 | "name": "image", 128 | "type": "IMAGE", 129 | "link": 1, 130 | "slot_index": 0 131 | }, 132 | { 133 | "name": "svd_base_model", 134 | "type": "SD_MODEL_PATH", 135 | "link": 2, 136 | "slot_index": 1 137 | }, 138 | { 139 | "name": "exvideo_model", 140 | "type": "SD_MODEL_PATH", 141 | "link": 3, 142 | "slot_index": 2 143 | } 144 | ], 145 | "outputs": [ 146 | { 147 | "name": "VIDEO", 148 | "type": "VIDEO", 149 | "links": [ 150 | 4 151 | ], 152 | "shape": 3, 153 | "slot_index": 0 154 | } 155 | ], 156 | "properties": { 157 | "Node name for S&R": "ExVideoNode" 158 | }, 159 | "widgets_values": [ 160 | 50, 161 | 25, 162 | 20, 163 | true, 164 | 874, 165 | "randomize" 166 | ] 167 | }, 168 | { 169 | "id": 5, 170 | "type": "PreViewVideo", 171 | "pos": [ 172 | 1062, 173 | 105 174 | ], 175 | "size": { 176 | "0": 210, 177 | "1": 434 178 | }, 179 | "flags": {}, 180 | "order": 4, 181 | "mode": 0, 182 | "inputs": [ 183 | { 184 | "name": "video", 185 | "type": "VIDEO", 186 | "link": 4 187 | } 188 | ], 189 | "properties": { 190 | "Node name for S&R": "PreViewVideo" 191 | }, 192 | "widgets_values": [ 193 | { 194 | "hidden": false, 195 | "paused": false, 196 | "params": {} 197 | } 198 | ] 199 | } 200 | ], 201 | "links": [ 202 | [ 203 | 1, 204 | 2, 205 | 0, 206 | 1, 207 | 0, 208 | "IMAGE" 209 | ], 210 | [ 211 | 2, 212 | 3, 213 | 0, 214 | 1, 215 | 1, 216 | "SD_MODEL_PATH" 217 | ], 218 | [ 219 | 3, 220 | 4, 221 | 0, 222 | 1, 223 | 2, 224 | "SD_MODEL_PATH" 225 | ], 226 | [ 227 | 4, 228 | 1, 229 | 0, 230 | 5, 231 | 0, 232 | "VIDEO" 233 | ] 234 | ], 235 | "groups": [], 236 | "config": {}, 237 | "extra": { 238 | "ds": { 239 | "scale": 1, 240 | "offset": [ 241 | 0, 242 | -0.79998779296875 243 | ] 244 | } 245 | }, 246 | "version": 0.4 247 | } --------------------------------------------------------------------------------
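
Usage note (editorial addition, not part of the repository): the methods shown above in diffsynth/schedulers/ddim.py (set_timesteps, step, add_noise, denoise) form a small DDIM-style sampling API. The sketch below is a minimal, hypothetical illustration of how such a scheduler could drive an img2img-style denoising loop; the class name EnhancedDDIMScheduler, the unet callable, and sample_img2img are assumptions for illustration only and are not taken from this repository.

    # Hypothetical sketch only; assumes a scheduler exposing the methods shown
    # in diffsynth/schedulers/ddim.py above. Names marked "assumed" are not
    # guaranteed to exist in this repository.
    import torch
    from diffsynth.schedulers.ddim import EnhancedDDIMScheduler  # assumed class name / import path

    def sample_img2img(unet, latents, noise, prompt_emb,
                       num_inference_steps=20, denoising_strength=0.7):
        scheduler = EnhancedDDIMScheduler()  # assumed default constructor
        scheduler.set_timesteps(num_inference_steps, denoising_strength)

        # Push the input latents to the first (largest) retained timestep,
        # mirroring add_noise(): sqrt(alpha_bar_t)*x0 + sqrt(1-alpha_bar_t)*noise.
        sample = scheduler.add_noise(latents, noise, scheduler.timesteps[0])

        for timestep in scheduler.timesteps:
            with torch.no_grad():
                # unet is a placeholder for whatever model produces epsilon- or
                # v-prediction outputs matching scheduler.prediction_type.
                model_output = unet(sample, timestep, prompt_emb)
            # step() looks up alpha_prod_t / alpha_prod_t_prev from the timestep
            # schedule and applies denoise() to move one step toward x0.
            sample = scheduler.step(model_output, timestep, sample)
        return sample

For training, the same scheduler's training_target() would give the regression target for a given (sample, noise, timestep) triple under v-prediction-style weighting, paired with add_noise() to build the noisy input.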