├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── diffsynth
│   ├── __init__.py
│   ├── controlnets
│   │   ├── __init__.py
│   │   ├── controlnet_unit.py
│   │   └── processors.py
│   ├── data
│   │   ├── __init__.py
│   │   └── video.py
│   ├── extensions
│   │   ├── ESRGAN
│   │   │   └── __init__.py
│   │   ├── FastBlend
│   │   │   ├── __init__.py
│   │   │   ├── api.py
│   │   │   ├── cupy_kernels.py
│   │   │   ├── data.py
│   │   │   ├── patch_match.py
│   │   │   └── runners
│   │   │       ├── __init__.py
│   │   │       ├── accurate.py
│   │   │       ├── balanced.py
│   │   │       ├── fast.py
│   │   │       └── interpolation.py
│   │   └── RIFE
│   │       └── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── downloader.py
│   │   ├── hunyuan_dit.py
│   │   ├── hunyuan_dit_text_encoder.py
│   │   ├── sd_controlnet.py
│   │   ├── sd_ipadapter.py
│   │   ├── sd_lora.py
│   │   ├── sd_motion.py
│   │   ├── sd_text_encoder.py
│   │   ├── sd_unet.py
│   │   ├── sd_vae_decoder.py
│   │   ├── sd_vae_encoder.py
│   │   ├── sdxl_ipadapter.py
│   │   ├── sdxl_motion.py
│   │   ├── sdxl_text_encoder.py
│   │   ├── sdxl_unet.py
│   │   ├── sdxl_vae_decoder.py
│   │   ├── sdxl_vae_encoder.py
│   │   ├── svd_image_encoder.py
│   │   ├── svd_unet.py
│   │   ├── svd_vae_decoder.py
│   │   ├── svd_vae_encoder.py
│   │   └── tiler.py
│   ├── pipelines
│   │   ├── __init__.py
│   │   ├── dancer.py
│   │   ├── hunyuan_dit.py
│   │   ├── stable_diffusion.py
│   │   ├── stable_diffusion_video.py
│   │   ├── stable_diffusion_xl.py
│   │   ├── stable_diffusion_xl_video.py
│   │   └── stable_video_diffusion.py
│   ├── processors
│   │   ├── FastBlend.py
│   │   ├── PILEditor.py
│   │   ├── RIFE.py
│   │   ├── __init__.py
│   │   ├── base.py
│   │   └── sequencial_processor.py
│   ├── prompts
│   │   ├── __init__.py
│   │   ├── hunyuan_dit_prompter.py
│   │   ├── sd_prompter.py
│   │   ├── sdxl_prompter.py
│   │   └── utils.py
│   ├── schedulers
│   │   ├── __init__.py
│   │   ├── continuous_ode.py
│   │   └── ddim.py
│   └── tokenizer_configs
│       ├── hunyuan_dit
│       │   ├── tokenizer
│       │   │   ├── special_tokens_map.json
│       │   │   ├── tokenizer_config.json
│       │   │   ├── vocab.txt
│       │   │   └── vocab_org.txt
│       │   └── tokenizer_t5
│       │       ├── config.json
│       │       ├── special_tokens_map.json
│       │       ├── spiece.model
│       │       └── tokenizer_config.json
│       ├── stable_diffusion
│       │   └── tokenizer
│       │       ├── merges.txt
│       │       ├── special_tokens_map.json
│       │       ├── tokenizer_config.json
│       │       └── vocab.json
│       └── stable_diffusion_xl
│           └── tokenizer_2
│               ├── merges.txt
│               ├── special_tokens_map.json
│               ├── tokenizer_config.json
│               └── vocab.json
├── models
│   └── put diffsynth studio models here
├── requirements.txt
├── studio_nodes.py
├── util_nodes.py
├── web.png
├── web
│   └── js
│       ├── previewVideo.js
│       └── uploadVideo.js
└── workfolws
    ├── diffutoon_workflow.json
    └── exvideo_workflow.json
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | models/*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ComfyUI-DiffSynth-Studio
2 | Make [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) available in ComfyUI.
3 |
4 |
5 |
6 |
7 |
8 |
9 | ## How to use
10 | Tested on Python 3.10, an RTX 2080 Ti (11 GB), and torch==2.3.0+cu121.
11 |
12 | Make sure `ffmpeg` works in your command line.
13 | For Linux:
14 | ```
15 | apt update
16 | apt install ffmpeg
17 | ```
18 | For Windows, you can install `ffmpeg` automatically with [WingetUI](https://github.com/marticliment/WingetUI).
19 |
20 | Then:
21 | ```
22 | # in ComfyUI/custom_nodes
23 | git clone https://github.com/AIFSH/ComfyUI-DiffSynth-Studio.git
24 | cd ComfyUI-DiffSynth-Studio
25 | pip install -r requirements.txt
26 | ```
27 | Weights will be downloaded automatically from Hugging Face or ModelScope.
28 |
29 | ## Tutorial
30 | - [Demo for Diffutoon](https://b23.tv/z7hEXlX) (DiffSynth-Studio ComfyUI plugin: Diffutoon node, on Bilibili)
31 | - [Run on 4090](https://www.xiangongyun.com/image/detail/13706bf7-f3e6-4e29-bb97-c79405f5def4)
32 | - WeChat: aifsh_98
33 |
34 | ## Node Details and Workflows
35 | ### ExVideo Node
36 | [ExVideo workflow](./workfolws/exvideo_workflow.json)
37 | ```
38 | "image":("IMAGE",),
39 | "svd_base_model":("SD_MODEL_PATH",),
40 | "exvideo_model":("SD_MODEL_PATH",),
41 | "num_frames":("INT",{
42 | "default": 128
43 | }),
44 | "fps":("INT",{
45 | "default": 30
46 | }),
47 | "num_inference_steps":("INT",{
48 | "default": 50
49 | }),
50 | "if_upscale":("BOOLEAN",{
51 | "default": True,
52 | }),
53 | "seed": ("INT",{
54 | "default": 1
55 | })
56 | ```
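These are the node's required inputs. For orientation, here is a minimal sketch of how such a declaration typically looks inside a ComfyUI custom node; the real `ExVideoNode` lives in `studio_nodes.py` (not reproduced here), so the method name, return type, and category below are assumptions, not the actual implementation.

```python
# Hypothetical sketch of a ComfyUI node exposing the ExVideo inputs listed above.
# The actual ExVideoNode in studio_nodes.py may differ; FUNCTION/CATEGORY names are assumed.
class ExVideoNodeSketch:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "svd_base_model": ("SD_MODEL_PATH",),
                "exvideo_model": ("SD_MODEL_PATH",),
                "num_frames": ("INT", {"default": 128}),
                "fps": ("INT", {"default": 30}),
                "num_inference_steps": ("INT", {"default": 50}),
                "if_upscale": ("BOOLEAN", {"default": True}),
                "seed": ("INT", {"default": 1}),
            }
        }

    RETURN_TYPES = ("VIDEO",)
    FUNCTION = "gen_video"
    CATEGORY = "DiffSynth-Studio"

    def gen_video(self, image, svd_base_model, exvideo_model, num_frames,
                  fps, num_inference_steps, if_upscale, seed):
        # The real node would build the DiffSynth SVD/ExVideo pipeline here
        # and return the path of the generated video.
        raise NotImplementedError
```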
57 | ### Image Synthesis
58 | Coming soon.
59 | ### Diffutoon Node
60 | [Diffutoon workflow](./workfolws/diffutoon_workflow.json)
61 | ```
62 | "required":{
63 | "source_video_path": ("VIDEO",),
64 | "sd_model_path":("SD_MODEL_PATH",),
65 | "postive_prompt":("TEXT",),
66 | "negative_prompt":("TEXT",),
67 |     "start":("INT",{
68 |         "default": 0 ## the second of the source video at which toon shading starts
69 |     }),
70 |     "length":("INT",{
71 |         "default": -1 ## how much of the video to shade; -1 means the whole video
72 |     }),
73 | "seed":("INT",{
74 | "default": 42
75 | }),
76 | "cfg_scale":("INT",{
77 | "default": 3
78 | }),
79 | "num_inference_steps":("INT",{
80 | "default": 10
81 | }),
82 |     "animatediff_batch_size":("INT",{
83 |         "default": 32 ## lower this if you run out of VRAM
84 |     }),
85 |     "animatediff_stride":("INT",{
86 |         "default": 16 ## lower this if you run out of VRAM
87 |     }),
88 |     "vram_limit_level":("INT",{
89 |         "default": 0 ## if the process gets killed, try 1
90 |     }),
91 | },
92 | "optional":{
93 | "controlnet1":("ControlNetConfigUnit",),
94 | "controlnet2":("ControlNetConfigUnit",),
95 | "controlnet3":("ControlNetConfigUnit",),
96 | }
97 | ```
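The optional `controlnet1`..`controlnet3` inputs each carry a `ControlNetConfigUnit`, the small config class defined in `diffsynth/controlnets/controlnet_unit.py`. Below is a minimal sketch of how such units can be built; the model paths are placeholders, and the exact wiring inside `DiffutoonNode` is an assumption.

```python
from diffsynth import ControlNetConfigUnit  # exported via diffsynth/controlnets/__init__.py

# Hypothetical example: two ControlNet units for a Diffutoon-style run.
# processor_id must be one of the ids accepted by diffsynth's Annotator:
# "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile".
controlnet_units = [
    ControlNetConfigUnit(
        processor_id="lineart",
        model_path="models/ControlNet/control_v11p_sd15_lineart.pth",  # placeholder path
        scale=0.5,
    ),
    ControlNetConfigUnit(
        processor_id="tile",
        model_path="models/ControlNet/control_v11f1e_sd15_tile.pth",  # placeholder path
        scale=0.5,
    ),
]
# Units connected to controlnet1..controlnet3 would be collected into a list like
# this and passed on to the DiffSynth pipeline together with the other inputs.
```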
98 |
99 | ### Video Stylization
100 |
101 | ### Chinese Models
102 |
103 | ## Getting help
104 | WeChat: aifsh_98
105 | Donations are appreciated if you need priority support,
106 | but please feel free to open an issue with your question.
107 |
108 | Is setting up the Windows environment too hard? Add WeChat aifsh_98 to get a Windows one-click package with a donation, or open an issue and wait for the community to answer your questions.
109 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from .util_nodes import LoadVideo,PreViewVideo
2 | from .studio_nodes import DiffTextNode,DiffutoonNode,SDPathLoader,ControlNetPathLoader,ExVideoNode
3 | WEB_DIRECTORY = "./web"
4 | # A dictionary that contains all nodes you want to export with their names
5 | # NOTE: names should be globally unique
6 | NODE_CLASS_MAPPINGS = {
7 | "LoadVideo": LoadVideo,
8 | "PreViewVideo": PreViewVideo,
9 | "SDPathLoader": SDPathLoader,
10 | "DiffTextNode": DiffTextNode,
11 | "DiffutoonNode": DiffutoonNode,
12 | "ControlNetPathLoader": ControlNetPathLoader,
13 | "ExVideoNode": ExVideoNode
14 | }
15 |
16 | # A dictionary that contains the friendly/humanly readable titles for the nodes
17 | NODE_DISPLAY_NAME_MAPPINGS = {
18 | "LoadVideo": "LoadVideo",
19 | "PreViewVideo": "PreViewVideo",
20 | "SDPathLoader": "SDPathLoader",
21 | "DiffTextNode": "DiffTextNode",
22 | "DiffutoonNode": "DiffutoonNode",
23 | "ControlNetPathLoader": "ControlNetPathLoader",
24 | "ExVideoNode": "ExVideoNode"
25 | }
--------------------------------------------------------------------------------
/diffsynth/__init__.py:
--------------------------------------------------------------------------------
1 | from .data import *
2 | from .models import *
3 | from .prompts import *
4 | from .schedulers import *
5 | from .pipelines import *
6 | from .controlnets import *
7 |
--------------------------------------------------------------------------------
/diffsynth/controlnets/__init__.py:
--------------------------------------------------------------------------------
1 | from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager
2 | from .processors import Annotator
3 |
--------------------------------------------------------------------------------
/diffsynth/controlnets/controlnet_unit.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from .processors import Processor_id
4 |
5 |
6 | class ControlNetConfigUnit:
7 | def __init__(self, processor_id: Processor_id, model_path, scale=1.0):
8 | self.processor_id = processor_id
9 | self.model_path = model_path
10 | self.scale = scale
11 |
12 |
13 | class ControlNetUnit:
14 | def __init__(self, processor, model, scale=1.0):
15 | self.processor = processor
16 | self.model = model
17 | self.scale = scale
18 |
19 |
20 | class MultiControlNetManager:
21 | def __init__(self, controlnet_units=[]):
22 | self.processors = [unit.processor for unit in controlnet_units]
23 | self.models = [unit.model for unit in controlnet_units]
24 | self.scales = [unit.scale for unit in controlnet_units]
25 |
26 | def process_image(self, image, processor_id=None):
27 | if processor_id is None:
28 | processed_image = [processor(image) for processor in self.processors]
29 | else:
30 | processed_image = [self.processors[processor_id](image)]
31 | processed_image = torch.concat([
32 | torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
33 | for image_ in processed_image
34 | ], dim=0)
35 | return processed_image
36 |
37 | def __call__(
38 | self,
39 | sample, timestep, encoder_hidden_states, conditionings,
40 | tiled=False, tile_size=64, tile_stride=32
41 | ):
42 | res_stack = None
43 | for conditioning, model, scale in zip(conditionings, self.models, self.scales):
44 | res_stack_ = model(
45 | sample, timestep, encoder_hidden_states, conditioning,
46 | tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
47 | )
48 | res_stack_ = [res * scale for res in res_stack_]
49 | if res_stack is None:
50 | res_stack = res_stack_
51 | else:
52 | res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
53 | return res_stack
54 |
--------------------------------------------------------------------------------
/diffsynth/controlnets/processors.py:
--------------------------------------------------------------------------------
1 | import os
2 | import folder_paths
3 | from typing_extensions import Literal, TypeAlias
4 | import warnings
5 | with warnings.catch_warnings():
6 | warnings.simplefilter("ignore")
7 | from controlnet_aux.processor import (
8 | CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector
9 | )
10 |
11 | annotators_dir = os.path.join(folder_paths.models_dir, "Annotators")
12 | Processor_id: TypeAlias = Literal[
13 | "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"
14 | ]
15 |
16 | class Annotator:
17 | def __init__(self, processor_id: Processor_id, model_path=annotators_dir, detect_resolution=None):
18 | if processor_id == "canny":
19 | self.processor = CannyDetector()
20 | elif processor_id == "depth":
21 | self.processor = MidasDetector.from_pretrained(model_path).to("cuda")
22 | elif processor_id == "softedge":
23 | self.processor = HEDdetector.from_pretrained(model_path).to("cuda")
24 | elif processor_id == "lineart":
25 | self.processor = LineartDetector.from_pretrained(model_path).to("cuda")
26 | elif processor_id == "lineart_anime":
27 | self.processor = LineartAnimeDetector.from_pretrained(model_path).to("cuda")
28 | elif processor_id == "openpose":
29 | self.processor = OpenposeDetector.from_pretrained(model_path).to("cuda")
30 | elif processor_id == "tile":
31 | self.processor = None
32 | else:
33 | raise ValueError(f"Unsupported processor_id: {processor_id}")
34 |
35 | self.processor_id = processor_id
36 | self.detect_resolution = detect_resolution
37 |
38 | def __call__(self, image):
39 | width, height = image.size
40 | if self.processor_id == "openpose":
41 | kwargs = {
42 | "include_body": True,
43 | "include_hand": True,
44 | "include_face": True
45 | }
46 | else:
47 | kwargs = {}
48 | if self.processor is not None:
49 | detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height)
50 | image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs)
51 | image = image.resize((width, height))
52 | return image
53 |
54 |
--------------------------------------------------------------------------------
/diffsynth/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .video import VideoData, save_video, save_frames
2 |
--------------------------------------------------------------------------------
/diffsynth/data/video.py:
--------------------------------------------------------------------------------
1 | import imageio, os
2 | import numpy as np
3 | from PIL import Image
4 | from tqdm import tqdm
5 |
6 |
7 | class LowMemoryVideo:
8 | def __init__(self, file_name):
9 | self.reader = imageio.get_reader(file_name)
10 |
11 | def __len__(self):
12 | return self.reader.count_frames()
13 |
14 | def __getitem__(self, item):
15 | return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB")
16 |
17 | def __del__(self):
18 | self.reader.close()
19 |
20 |
21 | def split_file_name(file_name):
22 | result = []
23 | number = -1
24 | for i in file_name:
25 | if ord(i)>=ord("0") and ord(i)<=ord("9"):
26 | if number == -1:
27 | number = 0
28 | number = number*10 + ord(i) - ord("0")
29 | else:
30 | if number != -1:
31 | result.append(number)
32 | number = -1
33 | result.append(i)
34 | if number != -1:
35 | result.append(number)
36 | result = tuple(result)
37 | return result
38 |
39 |
40 | def search_for_images(folder):
41 | file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
42 | file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
43 | file_list = [i[1] for i in sorted(file_list)]
44 | file_list = [os.path.join(folder, i) for i in file_list]
45 | return file_list
46 |
47 |
48 | class LowMemoryImageFolder:
49 | def __init__(self, folder, file_list=None):
50 | if file_list is None:
51 | self.file_list = search_for_images(folder)
52 | else:
53 | self.file_list = [os.path.join(folder, file_name) for file_name in file_list]
54 |
55 | def __len__(self):
56 | return len(self.file_list)
57 |
58 | def __getitem__(self, item):
59 | return Image.open(self.file_list[item]).convert("RGB")
60 |
61 | def __del__(self):
62 | pass
63 |
64 |
65 | def crop_and_resize(image, height, width):
66 | image = np.array(image)
67 | image_height, image_width, _ = image.shape
68 | if image_height / image_width < height / width:
69 | croped_width = int(image_height / height * width)
70 | left = (image_width - croped_width) // 2
71 | image = image[:, left: left+croped_width]
72 | image = Image.fromarray(image).resize((width, height))
73 | else:
74 | croped_height = int(image_width / width * height)
75 | left = (image_height - croped_height) // 2
76 | image = image[left: left+croped_height, :]
77 | image = Image.fromarray(image).resize((width, height))
78 | return image
79 |
80 |
81 | class VideoData:
82 | def __init__(self, video_file=None, image_folder=None, height=None, width=None, **kwargs):
83 | if video_file is not None:
84 | self.data_type = "video"
85 | self.data = LowMemoryVideo(video_file, **kwargs)
86 | elif image_folder is not None:
87 | self.data_type = "images"
88 | self.data = LowMemoryImageFolder(image_folder, **kwargs)
89 | else:
90 | raise ValueError("Cannot open video or image folder")
91 | self.length = None
92 | self.set_shape(height, width)
93 |
94 | def raw_data(self):
95 | frames = []
96 | for i in range(self.__len__()):
97 | frames.append(self.__getitem__(i))
98 | return frames
99 |
100 | def set_length(self, length):
101 | self.length = length
102 |
103 | def set_shape(self, height, width):
104 | self.height = height
105 | self.width = width
106 |
107 | def __len__(self):
108 | if self.length is None:
109 | return len(self.data)
110 | else:
111 | return self.length
112 |
113 | def shape(self):
114 | if self.height is not None and self.width is not None:
115 | return self.height, self.width
116 | else:
117 |             width, height = self.__getitem__(0).size  # PIL Image.size is (width, height)
118 |             return height, width
119 |
120 | def __getitem__(self, item):
121 | frame = self.data.__getitem__(item)
122 | width, height = frame.size
123 | if self.height is not None and self.width is not None:
124 | if self.height != height or self.width != width:
125 | frame = crop_and_resize(frame, self.height, self.width)
126 | return frame
127 |
128 | def __del__(self):
129 | pass
130 |
131 | def save_images(self, folder):
132 | os.makedirs(folder, exist_ok=True)
133 | for i in tqdm(range(self.__len__()), desc="Saving images"):
134 | frame = self.__getitem__(i)
135 | frame.save(os.path.join(folder, f"{i}.png"))
136 |
137 |
138 | def save_video(frames, save_path, fps, quality=9):
139 | writer = imageio.get_writer(save_path, fps=fps, quality=quality)
140 | for frame in tqdm(frames, desc="Saving video"):
141 | frame = np.array(frame)
142 | writer.append_data(frame)
143 | writer.close()
144 |
145 | def save_frames(frames, save_path):
146 | os.makedirs(save_path, exist_ok=True)
147 | for i, frame in enumerate(tqdm(frames, desc="Saving images")):
148 | frame.save(os.path.join(save_path, f"{i}.png"))
149 |
--------------------------------------------------------------------------------
/diffsynth/extensions/ESRGAN/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from einops import repeat
3 | from PIL import Image
4 | import numpy as np
5 |
6 |
7 | class ResidualDenseBlock(torch.nn.Module):
8 |
9 | def __init__(self, num_feat=64, num_grow_ch=32):
10 | super(ResidualDenseBlock, self).__init__()
11 | self.conv1 = torch.nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
12 | self.conv2 = torch.nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
13 | self.conv3 = torch.nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
14 | self.conv4 = torch.nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
15 | self.conv5 = torch.nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
16 | self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
17 |
18 | def forward(self, x):
19 | x1 = self.lrelu(self.conv1(x))
20 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
21 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
22 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
23 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
24 | return x5 * 0.2 + x
25 |
26 |
27 | class RRDB(torch.nn.Module):
28 |
29 | def __init__(self, num_feat, num_grow_ch=32):
30 | super(RRDB, self).__init__()
31 | self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
32 | self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
33 | self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
34 |
35 | def forward(self, x):
36 | out = self.rdb1(x)
37 | out = self.rdb2(out)
38 | out = self.rdb3(out)
39 | return out * 0.2 + x
40 |
41 |
42 | class RRDBNet(torch.nn.Module):
43 |
44 | def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32):
45 | super(RRDBNet, self).__init__()
46 | self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
47 |         self.body = torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)])
48 | self.conv_body = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
49 | # upsample
50 | self.conv_up1 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
51 | self.conv_up2 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
52 | self.conv_hr = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
53 | self.conv_last = torch.nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
54 | self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
55 |
56 | def forward(self, x):
57 | feat = x
58 | feat = self.conv_first(feat)
59 | body_feat = self.conv_body(self.body(feat))
60 | feat = feat + body_feat
61 | # upsample
62 | feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
63 | feat = self.lrelu(self.conv_up1(feat))
64 | feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
65 | feat = self.lrelu(self.conv_up2(feat))
66 | out = self.conv_last(self.lrelu(self.conv_hr(feat)))
67 | return out
68 |
69 |
70 | class ESRGAN(torch.nn.Module):
71 | def __init__(self, model):
72 | super().__init__()
73 | self.model = model
74 |
75 | @staticmethod
76 | def from_pretrained(model_path):
77 | model = RRDBNet()
78 | state_dict = torch.load(model_path, map_location="cpu")["params_ema"]
79 | model.load_state_dict(state_dict)
80 | model.eval()
81 | return ESRGAN(model)
82 |
83 | def process_image(self, image):
84 | image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1)
85 | return image
86 |
87 | def process_images(self, images):
88 | images = [self.process_image(image) for image in images]
89 | images = torch.stack(images)
90 | return images
91 |
92 | def decode_images(self, images):
93 | images = (images.permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
94 | images = [Image.fromarray(image) for image in images]
95 | return images
96 |
97 | @torch.no_grad()
98 | def upscale(self, images, batch_size=4, progress_bar=lambda x:x):
99 | # Preprocess
100 | input_tensor = self.process_images(images)
101 |
102 | # Interpolate
103 | output_tensor = []
104 | for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
105 | batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
106 | batch_input_tensor = input_tensor[batch_id: batch_id_]
107 | batch_input_tensor = batch_input_tensor.to(
108 | device=self.model.conv_first.weight.device,
109 | dtype=self.model.conv_first.weight.dtype)
110 | batch_output_tensor = self.model(batch_input_tensor)
111 | output_tensor.append(batch_output_tensor.cpu())
112 |
113 | # Output
114 | output_tensor = torch.concat(output_tensor, dim=0)
115 |
116 | # To images
117 | output_images = self.decode_images(output_tensor)
118 | return output_images
119 |
--------------------------------------------------------------------------------
/diffsynth/extensions/FastBlend/__init__.py:
--------------------------------------------------------------------------------
1 | from .runners.fast import TableManager, PyramidPatchMatcher
2 | from PIL import Image
3 | import numpy as np
4 | import cupy as cp
5 |
6 |
7 | class FastBlendSmoother:
8 | def __init__(self):
9 | self.batch_size = 8
10 | self.window_size = 64
11 | self.ebsynth_config = {
12 | "minimum_patch_size": 5,
13 | "threads_per_block": 8,
14 | "num_iter": 5,
15 | "gpu_id": 0,
16 | "guide_weight": 10.0,
17 | "initialize": "identity",
18 | "tracking_window_size": 0,
19 | }
20 |
21 | @staticmethod
22 | def from_model_manager(model_manager):
23 | # TODO: fetch GPU ID from model_manager
24 | return FastBlendSmoother()
25 |
26 | def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config):
27 | frames_guide = [np.array(frame) for frame in frames_guide]
28 | frames_style = [np.array(frame) for frame in frames_style]
29 | table_manager = TableManager()
30 | patch_match_engine = PyramidPatchMatcher(
31 | image_height=frames_style[0].shape[0],
32 | image_width=frames_style[0].shape[1],
33 | channel=3,
34 | **ebsynth_config
35 | )
36 | # left part
37 | table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="FastBlend Step 1/4")
38 | table_l = table_manager.remapping_table_to_blending_table(table_l)
39 | table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="FastBlend Step 2/4")
40 | # right part
41 | table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="FastBlend Step 3/4")
42 | table_r = table_manager.remapping_table_to_blending_table(table_r)
43 | table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="FastBlend Step 4/4")[::-1]
44 | # merge
45 | frames = []
46 | for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
47 | weight_m = -1
48 | weight = weight_l + weight_m + weight_r
49 | frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
50 | frames.append(frame)
51 | frames = [Image.fromarray(frame.clip(0, 255).astype("uint8")) for frame in frames]
52 | return frames
53 |
54 | def __call__(self, rendered_frames, original_frames=None, **kwargs):
55 | frames = self.run(
56 | original_frames, rendered_frames,
57 | self.batch_size, self.window_size, self.ebsynth_config
58 | )
59 | mempool = cp.get_default_memory_pool()
60 | pinned_mempool = cp.get_default_pinned_memory_pool()
61 | mempool.free_all_blocks()
62 | pinned_mempool.free_all_blocks()
63 | return frames
--------------------------------------------------------------------------------
/diffsynth/extensions/FastBlend/cupy_kernels.py:
--------------------------------------------------------------------------------
1 | import cupy as cp
2 |
3 | remapping_kernel = cp.RawKernel(r'''
4 | extern "C" __global__
5 | void remap(
6 | const int height,
7 | const int width,
8 | const int channel,
9 | const int patch_size,
10 | const int pad_size,
11 | const float* source_style,
12 | const int* nnf,
13 | float* target_style
14 | ) {
15 | const int r = (patch_size - 1) / 2;
16 | const int x = blockDim.x * blockIdx.x + threadIdx.x;
17 | const int y = blockDim.y * blockIdx.y + threadIdx.y;
18 | if (x >= height or y >= width) return;
19 | const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
20 | const int pid = (x + pad_size) * (width + pad_size * 2) + (y + pad_size);
21 | const int min_px = x < r ? -x : -r;
22 | const int max_px = x + r > height - 1 ? height - 1 - x : r;
23 | const int min_py = y < r ? -y : -r;
24 | const int max_py = y + r > width - 1 ? width - 1 - y : r;
25 | int num = 0;
26 | for (int px = min_px; px <= max_px; px++){
27 | for (int py = min_py; py <= max_py; py++){
28 | const int nid = (x + px) * width + y + py;
29 | const int x_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 0] - px;
30 | const int y_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 1] - py;
31 | if (x_ < 0 or y_ < 0 or x_ >= height or y_ >= width)continue;
32 | const int pid_ = (x_ + pad_size) * (width + pad_size * 2) + (y_ + pad_size);
33 | num++;
34 | for (int c = 0; c < channel; c++){
35 | target_style[z + pid * channel + c] += source_style[z + pid_ * channel + c];
36 | }
37 | }
38 | }
39 | for (int c = 0; c < channel; c++){
40 | target_style[z + pid * channel + c] /= num;
41 | }
42 | }
43 | ''', 'remap')
44 |
45 |
46 | patch_error_kernel = cp.RawKernel(r'''
47 | extern "C" __global__
48 | void patch_error(
49 | const int height,
50 | const int width,
51 | const int channel,
52 | const int patch_size,
53 | const int pad_size,
54 | const float* source,
55 | const int* nnf,
56 | const float* target,
57 | float* error
58 | ) {
59 | const int r = (patch_size - 1) / 2;
60 | const int x = blockDim.x * blockIdx.x + threadIdx.x;
61 | const int y = blockDim.y * blockIdx.y + threadIdx.y;
62 | const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
63 | if (x >= height or y >= width) return;
64 | const int x_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 0];
65 | const int y_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 1];
66 | float e = 0;
67 | for (int px = -r; px <= r; px++){
68 | for (int py = -r; py <= r; py++){
69 | const int pid = (x + pad_size + px) * (width + pad_size * 2) + y + pad_size + py;
70 | const int pid_ = (x_ + pad_size + px) * (width + pad_size * 2) + y_ + pad_size + py;
71 | for (int c = 0; c < channel; c++){
72 | const float diff = target[z + pid * channel + c] - source[z + pid_ * channel + c];
73 | e += diff * diff;
74 | }
75 | }
76 | }
77 | error[blockIdx.z * height * width + x * width + y] = e;
78 | }
79 | ''', 'patch_error')
80 |
81 |
82 | pairwise_patch_error_kernel = cp.RawKernel(r'''
83 | extern "C" __global__
84 | void pairwise_patch_error(
85 | const int height,
86 | const int width,
87 | const int channel,
88 | const int patch_size,
89 | const int pad_size,
90 | const float* source_a,
91 | const int* nnf_a,
92 | const float* source_b,
93 | const int* nnf_b,
94 | float* error
95 | ) {
96 | const int r = (patch_size - 1) / 2;
97 | const int x = blockDim.x * blockIdx.x + threadIdx.x;
98 | const int y = blockDim.y * blockIdx.y + threadIdx.y;
99 | const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
100 | if (x >= height or y >= width) return;
101 | const int z_nnf = blockIdx.z * height * width * 2 + (x * width + y) * 2;
102 | const int x_a = nnf_a[z_nnf + 0];
103 | const int y_a = nnf_a[z_nnf + 1];
104 | const int x_b = nnf_b[z_nnf + 0];
105 | const int y_b = nnf_b[z_nnf + 1];
106 | float e = 0;
107 | for (int px = -r; px <= r; px++){
108 | for (int py = -r; py <= r; py++){
109 | const int pid_a = (x_a + pad_size + px) * (width + pad_size * 2) + y_a + pad_size + py;
110 | const int pid_b = (x_b + pad_size + px) * (width + pad_size * 2) + y_b + pad_size + py;
111 | for (int c = 0; c < channel; c++){
112 | const float diff = source_a[z + pid_a * channel + c] - source_b[z + pid_b * channel + c];
113 | e += diff * diff;
114 | }
115 | }
116 | }
117 | error[blockIdx.z * height * width + x * width + y] = e;
118 | }
119 | ''', 'pairwise_patch_error')
120 |
--------------------------------------------------------------------------------
/diffsynth/extensions/FastBlend/data.py:
--------------------------------------------------------------------------------
1 | import imageio, os
2 | import numpy as np
3 | from PIL import Image
4 |
5 |
6 | def read_video(file_name):
7 | reader = imageio.get_reader(file_name)
8 | video = []
9 | for frame in reader:
10 | frame = np.array(frame)
11 | video.append(frame)
12 | reader.close()
13 | return video
14 |
15 |
16 | def get_video_fps(file_name):
17 | reader = imageio.get_reader(file_name)
18 | fps = reader.get_meta_data()["fps"]
19 | reader.close()
20 | return fps
21 |
22 |
23 | def save_video(frames_path, video_path, num_frames, fps):
24 | writer = imageio.get_writer(video_path, fps=fps, quality=9)
25 | for i in range(num_frames):
26 | frame = np.array(Image.open(os.path.join(frames_path, "%05d.png" % i)))
27 | writer.append_data(frame)
28 | writer.close()
29 | return video_path
30 |
31 |
32 | class LowMemoryVideo:
33 | def __init__(self, file_name):
34 | self.reader = imageio.get_reader(file_name)
35 |
36 | def __len__(self):
37 | return self.reader.count_frames()
38 |
39 | def __getitem__(self, item):
40 | return np.array(self.reader.get_data(item))
41 |
42 | def __del__(self):
43 | self.reader.close()
44 |
45 |
46 | def split_file_name(file_name):
47 | result = []
48 | number = -1
49 | for i in file_name:
50 | if ord(i)>=ord("0") and ord(i)<=ord("9"):
51 | if number == -1:
52 | number = 0
53 | number = number*10 + ord(i) - ord("0")
54 | else:
55 | if number != -1:
56 | result.append(number)
57 | number = -1
58 | result.append(i)
59 | if number != -1:
60 | result.append(number)
61 | result = tuple(result)
62 | return result
63 |
64 |
65 | def search_for_images(folder):
66 | file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
67 | file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
68 | file_list = [i[1] for i in sorted(file_list)]
69 | file_list = [os.path.join(folder, i) for i in file_list]
70 | return file_list
71 |
72 |
73 | def read_images(folder):
74 | file_list = search_for_images(folder)
75 | frames = [np.array(Image.open(i)) for i in file_list]
76 | return frames
77 |
78 |
79 | class LowMemoryImageFolder:
80 | def __init__(self, folder, file_list=None):
81 | if file_list is None:
82 | self.file_list = search_for_images(folder)
83 | else:
84 | self.file_list = [os.path.join(folder, file_name) for file_name in file_list]
85 |
86 | def __len__(self):
87 | return len(self.file_list)
88 |
89 | def __getitem__(self, item):
90 | return np.array(Image.open(self.file_list[item]))
91 |
92 | def __del__(self):
93 | pass
94 |
95 |
96 | class VideoData:
97 | def __init__(self, video_file, image_folder, **kwargs):
98 | if video_file is not None:
99 | self.data_type = "video"
100 | self.data = LowMemoryVideo(video_file, **kwargs)
101 | elif image_folder is not None:
102 | self.data_type = "images"
103 | self.data = LowMemoryImageFolder(image_folder, **kwargs)
104 | else:
105 | raise ValueError("Cannot open video or image folder")
106 | self.length = None
107 | self.height = None
108 | self.width = None
109 |
110 | def raw_data(self):
111 | frames = []
112 | for i in range(self.__len__()):
113 | frames.append(self.__getitem__(i))
114 | return frames
115 |
116 | def set_length(self, length):
117 | self.length = length
118 |
119 | def set_shape(self, height, width):
120 | self.height = height
121 | self.width = width
122 |
123 | def __len__(self):
124 | if self.length is None:
125 | return len(self.data)
126 | else:
127 | return self.length
128 |
129 | def shape(self):
130 | if self.height is not None and self.width is not None:
131 | return self.height, self.width
132 | else:
133 | height, width, _ = self.__getitem__(0).shape
134 | return height, width
135 |
136 | def __getitem__(self, item):
137 | frame = self.data.__getitem__(item)
138 | height, width, _ = frame.shape
139 | if self.height is not None and self.width is not None:
140 | if self.height != height or self.width != width:
141 | frame = Image.fromarray(frame).resize((self.width, self.height))
142 | frame = np.array(frame)
143 | return frame
144 |
145 | def __del__(self):
146 | pass
147 |
--------------------------------------------------------------------------------
/diffsynth/extensions/FastBlend/runners/__init__.py:
--------------------------------------------------------------------------------
1 | from .accurate import AccurateModeRunner
2 | from .fast import FastModeRunner
3 | from .balanced import BalancedModeRunner
4 | from .interpolation import InterpolationModeRunner, InterpolationModeSingleFrameRunner
5 |
--------------------------------------------------------------------------------
/diffsynth/extensions/FastBlend/runners/accurate.py:
--------------------------------------------------------------------------------
1 | from ..patch_match import PyramidPatchMatcher
2 | import os
3 | import numpy as np
4 | from PIL import Image
5 | from tqdm import tqdm
6 |
7 |
8 | class AccurateModeRunner:
9 | def __init__(self):
10 | pass
11 |
12 | def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Accurate Mode", save_path=None):
13 | patch_match_engine = PyramidPatchMatcher(
14 | image_height=frames_style[0].shape[0],
15 | image_width=frames_style[0].shape[1],
16 | channel=3,
17 | use_mean_target_style=True,
18 | **ebsynth_config
19 | )
20 | # run
21 | n = len(frames_style)
22 | for target in tqdm(range(n), desc=desc):
23 | l, r = max(target - window_size, 0), min(target + window_size + 1, n)
24 | remapped_frames = []
25 | for i in range(l, r, batch_size):
26 | j = min(i + batch_size, r)
27 | source_guide = np.stack([frames_guide[source] for source in range(i, j)])
28 | target_guide = np.stack([frames_guide[target]] * (j - i))
29 | source_style = np.stack([frames_style[source] for source in range(i, j)])
30 | _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
31 | remapped_frames.append(target_style)
32 | frame = np.concatenate(remapped_frames, axis=0).mean(axis=0)
33 | frame = frame.clip(0, 255).astype("uint8")
34 | if save_path is not None:
35 | Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
--------------------------------------------------------------------------------
/diffsynth/extensions/FastBlend/runners/balanced.py:
--------------------------------------------------------------------------------
1 | from ..patch_match import PyramidPatchMatcher
2 | import os
3 | import numpy as np
4 | from PIL import Image
5 | from tqdm import tqdm
6 |
7 |
8 | class BalancedModeRunner:
9 | def __init__(self):
10 | pass
11 |
12 | def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Balanced Mode", save_path=None):
13 | patch_match_engine = PyramidPatchMatcher(
14 | image_height=frames_style[0].shape[0],
15 | image_width=frames_style[0].shape[1],
16 | channel=3,
17 | **ebsynth_config
18 | )
19 | # tasks
20 | n = len(frames_style)
21 | tasks = []
22 | for target in range(n):
23 | for source in range(target - window_size, target + window_size + 1):
24 | if source >= 0 and source < n and source != target:
25 | tasks.append((source, target))
26 | # run
27 | frames = [(None, 1) for i in range(n)]
28 | for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
29 | tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
30 | source_guide = np.stack([frames_guide[source] for source, target in tasks_batch])
31 | target_guide = np.stack([frames_guide[target] for source, target in tasks_batch])
32 | source_style = np.stack([frames_style[source] for source, target in tasks_batch])
33 | _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
34 | for (source, target), result in zip(tasks_batch, target_style):
35 | frame, weight = frames[target]
36 | if frame is None:
37 | frame = frames_style[target]
38 | frames[target] = (
39 | frame * (weight / (weight + 1)) + result / (weight + 1),
40 | weight + 1
41 | )
42 | if weight + 1 == min(n, target + window_size + 1) - max(0, target - window_size):
43 | frame = frame.clip(0, 255).astype("uint8")
44 | if save_path is not None:
45 | Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
46 | frames[target] = (None, 1)
47 |
--------------------------------------------------------------------------------
/diffsynth/extensions/FastBlend/runners/fast.py:
--------------------------------------------------------------------------------
1 | from ..patch_match import PyramidPatchMatcher
2 | import functools, os
3 | import numpy as np
4 | from PIL import Image
5 | from tqdm import tqdm
6 |
7 |
8 | class TableManager:
9 | def __init__(self):
10 | pass
11 |
12 | def task_list(self, n):
13 | tasks = []
14 | max_level = 1
15 |         while (1<<max_level)<=n:
16 |             max_level += 1
17 |         for i in range(n):
18 |             j = i
19 |             for level in range(max_level):
20 |                 if i&(1<<level):
21 |                     break
22 |                 j |= 1<<level
23 |                 if j>=n:
24 |                     break
25 | meta_data = {
26 | "source": i,
27 | "target": j,
28 | "level": level + 1
29 | }
30 | tasks.append(meta_data)
31 | tasks.sort(key=functools.cmp_to_key(lambda u, v: u["level"]-v["level"]))
32 | return tasks
33 |
34 | def build_remapping_table(self, frames_guide, frames_style, patch_match_engine, batch_size, desc=""):
35 | n = len(frames_guide)
36 | tasks = self.task_list(n)
37 | remapping_table = [[(frames_style[i], 1)] for i in range(n)]
38 | for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
39 | tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
40 | source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
41 | target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
42 | source_style = np.stack([frames_style[task["source"]] for task in tasks_batch])
43 | _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
44 | for task, result in zip(tasks_batch, target_style):
45 | target, level = task["target"], task["level"]
46 | if len(remapping_table[target])==level:
47 | remapping_table[target].append((result, 1))
48 | else:
49 | frame, weight = remapping_table[target][level]
50 | remapping_table[target][level] = (
51 | frame * (weight / (weight + 1)) + result / (weight + 1),
52 | weight + 1
53 | )
54 | return remapping_table
55 |
56 | def remapping_table_to_blending_table(self, table):
57 | for i in range(len(table)):
58 | for j in range(1, len(table[i])):
59 | frame_1, weight_1 = table[i][j-1]
60 | frame_2, weight_2 = table[i][j]
61 | frame = (frame_1 + frame_2) / 2
62 | weight = weight_1 + weight_2
63 | table[i][j] = (frame, weight)
64 | return table
65 |
66 | def tree_query(self, leftbound, rightbound):
67 | node_list = []
68 | node_index = rightbound
69 | while node_index>=leftbound:
70 | node_level = 0
71 |             while (1<<node_level)&node_index and node_index-(1<<(node_level+1))+1>=leftbound:
72 | node_level += 1
73 | node_list.append((node_index, node_level))
74 |             node_index -= 1<<node_level
75 |         return node_list
--------------------------------------------------------------------------------
/diffsynth/extensions/FastBlend/runners/interpolation.py:
--------------------------------------------------------------------------------
30 |         if index_style[0]>0:
31 | tasks = []
32 | for m in range(index_style[0]):
33 | tasks.append((index_style[0], m, index_style[0]))
34 | task_group.append(tasks)
35 | # middle frames
36 | for l, r in zip(index_style[:-1], index_style[1:]):
37 | tasks = []
38 | for m in range(l, r):
39 | tasks.append((l, m, r))
40 | task_group.append(tasks)
41 | # last frame
42 | tasks = []
43 | for m in range(index_style[-1], n):
44 | tasks.append((index_style[-1], m, index_style[-1]))
45 | task_group.append(tasks)
46 | return task_group
47 |
48 | def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
49 | patch_match_engine = PyramidPatchMatcher(
50 | image_height=frames_style[0].shape[0],
51 | image_width=frames_style[0].shape[1],
52 | channel=3,
53 | use_mean_target_style=False,
54 | use_pairwise_patch_error=True,
55 | **ebsynth_config
56 | )
57 | # task
58 | index_dict = self.get_index_dict(index_style)
59 | task_group = self.get_task_group(index_style, len(frames_guide))
60 | # run
61 | for tasks in task_group:
62 | index_start, index_end = min([i[1] for i in tasks]), max([i[1] for i in tasks])
63 | for batch_id in tqdm(range(0, len(tasks), batch_size), desc=f"Rendering frames {index_start}...{index_end}"):
64 | tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
65 | source_guide, target_guide, source_style = [], [], []
66 | for l, m, r in tasks_batch:
67 | # l -> m
68 | source_guide.append(frames_guide[l])
69 | target_guide.append(frames_guide[m])
70 | source_style.append(frames_style[index_dict[l]])
71 | # r -> m
72 | source_guide.append(frames_guide[r])
73 | target_guide.append(frames_guide[m])
74 | source_style.append(frames_style[index_dict[r]])
75 | source_guide = np.stack(source_guide)
76 | target_guide = np.stack(target_guide)
77 | source_style = np.stack(source_style)
78 | _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
79 | if save_path is not None:
80 | for frame_l, frame_r, (l, m, r) in zip(target_style[0::2], target_style[1::2], tasks_batch):
81 | weight_l, weight_r = self.get_weight(l, m, r)
82 | frame = frame_l * weight_l + frame_r * weight_r
83 | frame = frame.clip(0, 255).astype("uint8")
84 | Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % m))
85 |
86 |
87 | class InterpolationModeSingleFrameRunner:
88 | def __init__(self):
89 | pass
90 |
91 | def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
92 | # check input
93 | tracking_window_size = ebsynth_config["tracking_window_size"]
94 | if tracking_window_size * 2 >= batch_size:
95 | raise ValueError("batch_size should be larger than track_window_size * 2")
96 | frame_style = frames_style[0]
97 | frame_guide = frames_guide[index_style[0]]
98 | patch_match_engine = PyramidPatchMatcher(
99 | image_height=frame_style.shape[0],
100 | image_width=frame_style.shape[1],
101 | channel=3,
102 | **ebsynth_config
103 | )
104 | # run
105 | frame_id, n = 0, len(frames_guide)
106 | for i in tqdm(range(0, n, batch_size - tracking_window_size * 2), desc=f"Rendering frames 0...{n}"):
107 | if i + batch_size > n:
108 | l, r = max(n - batch_size, 0), n
109 | else:
110 | l, r = i, i + batch_size
111 | source_guide = np.stack([frame_guide] * (r-l))
112 | target_guide = np.stack([frames_guide[i] for i in range(l, r)])
113 | source_style = np.stack([frame_style] * (r-l))
114 | _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
115 | for i, frame in zip(range(l, r), target_style):
116 | if i==frame_id:
117 | frame = frame.clip(0, 255).astype("uint8")
118 | Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % frame_id))
119 | frame_id += 1
120 | if r < n and r-frame_id <= tracking_window_size:
121 | break
122 |
--------------------------------------------------------------------------------
/diffsynth/extensions/RIFE/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from PIL import Image
6 |
7 |
8 | def warp(tenInput, tenFlow, device):
9 | backwarp_tenGrid = {}
10 | k = (str(tenFlow.device), str(tenFlow.size()))
11 | if k not in backwarp_tenGrid:
12 | tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
13 | 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
14 | tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
15 | 1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
16 | backwarp_tenGrid[k] = torch.cat(
17 | [tenHorizontal, tenVertical], 1).to(device)
18 |
19 | tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
20 | tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
21 |
22 | g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
23 | return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
24 |
25 |
26 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
27 | return nn.Sequential(
28 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
29 | padding=padding, dilation=dilation, bias=True),
30 | nn.PReLU(out_planes)
31 | )
32 |
33 |
34 | class IFBlock(nn.Module):
35 | def __init__(self, in_planes, c=64):
36 | super(IFBlock, self).__init__()
37 | self.conv0 = nn.Sequential(conv(in_planes, c//2, 3, 2, 1), conv(c//2, c, 3, 2, 1),)
38 | self.convblock0 = nn.Sequential(conv(c, c), conv(c, c))
39 | self.convblock1 = nn.Sequential(conv(c, c), conv(c, c))
40 | self.convblock2 = nn.Sequential(conv(c, c), conv(c, c))
41 | self.convblock3 = nn.Sequential(conv(c, c), conv(c, c))
42 | self.conv1 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 4, 4, 2, 1))
43 | self.conv2 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 1, 4, 2, 1))
44 |
45 | def forward(self, x, flow, scale=1):
46 | x = F.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
47 | flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 1. / scale
48 | feat = self.conv0(torch.cat((x, flow), 1))
49 | feat = self.convblock0(feat) + feat
50 | feat = self.convblock1(feat) + feat
51 | feat = self.convblock2(feat) + feat
52 | feat = self.convblock3(feat) + feat
53 | flow = self.conv1(feat)
54 | mask = self.conv2(feat)
55 | flow = F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * scale
56 | mask = F.interpolate(mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
57 | return flow, mask
58 |
59 |
60 | class IFNet(nn.Module):
61 | def __init__(self):
62 | super(IFNet, self).__init__()
63 | self.block0 = IFBlock(7+4, c=90)
64 | self.block1 = IFBlock(7+4, c=90)
65 | self.block2 = IFBlock(7+4, c=90)
66 | self.block_tea = IFBlock(10+4, c=90)
67 |
68 | def forward(self, x, scale_list=[4, 2, 1], training=False):
69 | if training == False:
70 | channel = x.shape[1] // 2
71 | img0 = x[:, :channel]
72 | img1 = x[:, channel:]
73 | flow_list = []
74 | merged = []
75 | mask_list = []
76 | warped_img0 = img0
77 | warped_img1 = img1
78 | flow = (x[:, :4]).detach() * 0
79 | mask = (x[:, :1]).detach() * 0
80 | block = [self.block0, self.block1, self.block2]
81 | for i in range(3):
82 | f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i])
83 | f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i])
84 | flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
85 | mask = mask + (m0 + (-m1)) / 2
86 | mask_list.append(mask)
87 | flow_list.append(flow)
88 | warped_img0 = warp(img0, flow[:, :2], device=x.device)
89 | warped_img1 = warp(img1, flow[:, 2:4], device=x.device)
90 | merged.append((warped_img0, warped_img1))
91 | '''
92 | c0 = self.contextnet(img0, flow[:, :2])
93 | c1 = self.contextnet(img1, flow[:, 2:4])
94 | tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
95 | res = tmp[:, 1:4] * 2 - 1
96 | '''
97 | for i in range(3):
98 | mask_list[i] = torch.sigmoid(mask_list[i])
99 | merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
100 | return flow_list, mask_list[2], merged
101 |
102 | def state_dict_converter(self):
103 | return IFNetStateDictConverter()
104 |
105 |
106 | class IFNetStateDictConverter:
107 | def __init__(self):
108 | pass
109 |
110 | def from_diffusers(self, state_dict):
111 | state_dict_ = {k.replace("module.", ""): v for k, v in state_dict.items()}
112 | return state_dict_
113 |
114 | def from_civitai(self, state_dict):
115 | return self.from_diffusers(state_dict)
116 |
117 |
118 | class RIFEInterpolater:
119 | def __init__(self, model, device="cuda"):
120 | self.model = model
121 | self.device = device
122 |         # IFNet does not support float16; keep it in float32.
123 | self.torch_dtype = torch.float32
124 |
125 | @staticmethod
126 | def from_model_manager(model_manager):
127 | return RIFEInterpolater(model_manager.RIFE, device=model_manager.device)
128 |
129 | def process_image(self, image):
130 | width, height = image.size
131 | if width % 32 != 0 or height % 32 != 0:
132 |             width = (width + 31) // 32 * 32   # round up to a multiple of 32
133 |             height = (height + 31) // 32 * 32
134 | image = image.resize((width, height))
135 | image = torch.Tensor(np.array(image, dtype=np.float32)[:, :, [2,1,0]] / 255).permute(2, 0, 1)
136 | return image
137 |
138 | def process_images(self, images):
139 | images = [self.process_image(image) for image in images]
140 | images = torch.stack(images)
141 | return images
142 |
143 | def decode_images(self, images):
144 | images = (images[:, [2,1,0]].permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
145 | images = [Image.fromarray(image) for image in images]
146 | return images
147 |
148 | def add_interpolated_images(self, images, interpolated_images):
149 | output_images = []
150 | for image, interpolated_image in zip(images, interpolated_images):
151 | output_images.append(image)
152 | output_images.append(interpolated_image)
153 | output_images.append(images[-1])
154 | return output_images
155 |
156 |
157 | @torch.no_grad()
158 | def interpolate_(self, images, scale=1.0):
159 | input_tensor = self.process_images(images)
160 | input_tensor = torch.cat((input_tensor[:-1], input_tensor[1:]), dim=1)
161 | input_tensor = input_tensor.to(device=self.device, dtype=self.torch_dtype)
162 | flow, mask, merged = self.model(input_tensor, [4/scale, 2/scale, 1/scale])
163 | output_images = self.decode_images(merged[2].cpu())
164 | if output_images[0].size != images[0].size:
165 | output_images = [image.resize(images[0].size) for image in output_images]
166 | return output_images
167 |
168 |
169 | @torch.no_grad()
170 | def interpolate(self, images, scale=1.0, batch_size=4, num_iter=1, progress_bar=lambda x:x):
171 | # Preprocess
172 | processed_images = self.process_images(images)
173 |
174 | for iter in range(num_iter):
175 | # Input
176 | input_tensor = torch.cat((processed_images[:-1], processed_images[1:]), dim=1)
177 |
178 | # Interpolate
179 | output_tensor = []
180 | for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
181 | batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
182 | batch_input_tensor = input_tensor[batch_id: batch_id_]
183 | batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
184 | flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
185 | output_tensor.append(merged[2].cpu())
186 |
187 | # Output
188 | output_tensor = torch.concat(output_tensor, dim=0).clip(0, 1)
189 | processed_images = self.add_interpolated_images(processed_images, output_tensor)
190 | processed_images = torch.stack(processed_images)
191 |
192 | # To images
193 | output_images = self.decode_images(processed_images)
194 | if output_images[0].size != images[0].size:
195 | output_images = [image.resize(images[0].size) for image in output_images]
196 | return output_images
197 |
198 |
199 | class RIFESmoother(RIFEInterpolater):
200 | def __init__(self, model, device="cuda"):
201 | super(RIFESmoother, self).__init__(model, device=device)
202 |
203 | @staticmethod
204 | def from_model_manager(model_manager):
205 | return RIFESmoother(model_manager.RIFE, device=model_manager.device)
206 |
207 | def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
208 | output_tensor = []
209 | for batch_id in range(0, input_tensor.shape[0], batch_size):
210 | batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
211 | batch_input_tensor = input_tensor[batch_id: batch_id_]
212 | batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
213 | flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
214 | output_tensor.append(merged[2].cpu())
215 | output_tensor = torch.concat(output_tensor, dim=0)
216 | return output_tensor
217 |
218 | @torch.no_grad()
219 | def __call__(self, rendered_frames, scale=1.0, batch_size=4, num_iter=1, **kwargs):
220 | # Preprocess
221 | processed_images = self.process_images(rendered_frames)
222 |
223 | for iter in range(num_iter):
224 | # Input
225 | input_tensor = torch.cat((processed_images[:-2], processed_images[2:]), dim=1)
226 |
227 | # Interpolate
228 | output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)
229 |
230 | # Blend
231 | input_tensor = torch.cat((processed_images[1:-1], output_tensor), dim=1)
232 | output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)
233 |
234 | # Add to frames
235 | processed_images[1:-1] = output_tensor
236 |
237 | # To images
238 | output_images = self.decode_images(processed_images)
239 | if output_images[0].size != rendered_frames[0].size:
240 | output_images = [image.resize(rendered_frames[0].size) for image in output_images]
241 | return output_images
242 |
--------------------------------------------------------------------------------
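A minimal usage sketch of the interpolater above, assuming a RIFE flownet checkpoint has already been downloaded (the paths and file names below are placeholders). RIFEInterpolater.from_model_manager is the intended entry point inside the pipelines; constructing it directly works as well:

import torch
from PIL import Image
from diffsynth.extensions.RIFE import IFNet, RIFEInterpolater

model = IFNet()
state_dict = torch.load("models/RIFE/flownet.pkl", map_location="cpu")  # placeholder path
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model = model.to("cuda").eval()

interpolater = RIFEInterpolater(model, device="cuda")
frames = [Image.open(f"frame_{i:03d}.png").convert("RGB") for i in range(2)]  # placeholder frames
# Each iteration roughly doubles the frame count by inserting one frame between neighbours.
new_frames = interpolater.interpolate(frames, scale=1.0, batch_size=4, num_iter=1)
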
/diffsynth/models/attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from einops import rearrange
3 |
4 |
5 | def low_version_attention(query, key, value, attn_bias=None):
6 | scale = 1 / query.shape[-1] ** 0.5
7 | query = query * scale
8 | attn = torch.matmul(query, key.transpose(-2, -1))
9 | if attn_bias is not None:
10 | attn = attn + attn_bias
11 | attn = attn.softmax(-1)
12 | return attn @ value
13 |
14 |
15 | class Attention(torch.nn.Module):
16 |
17 | def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
18 | super().__init__()
19 | dim_inner = head_dim * num_heads
20 | kv_dim = kv_dim if kv_dim is not None else q_dim
21 | self.num_heads = num_heads
22 | self.head_dim = head_dim
23 |
24 | self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
25 | self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
26 | self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
27 | self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
28 |
29 | def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
30 | batch_size = q.shape[0]
31 | ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
32 | ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
33 | ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
34 | hidden_states = hidden_states + scale * ip_hidden_states
35 | return hidden_states
36 |
37 | def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
38 | if encoder_hidden_states is None:
39 | encoder_hidden_states = hidden_states
40 |
41 | batch_size = encoder_hidden_states.shape[0]
42 |
43 | q = self.to_q(hidden_states)
44 | k = self.to_k(encoder_hidden_states)
45 | v = self.to_v(encoder_hidden_states)
46 |
47 | q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
48 | k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
49 | v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
50 |
51 | if qkv_preprocessor is not None:
52 | q, k, v = qkv_preprocessor(q, k, v)
53 |
54 | hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
55 | if ipadapter_kwargs is not None:
56 | hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
57 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
58 | hidden_states = hidden_states.to(q.dtype)
59 |
60 | hidden_states = self.to_out(hidden_states)
61 |
62 | return hidden_states
63 |
64 | def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
65 | if encoder_hidden_states is None:
66 | encoder_hidden_states = hidden_states
67 |
68 | q = self.to_q(hidden_states)
69 | k = self.to_k(encoder_hidden_states)
70 | v = self.to_v(encoder_hidden_states)
71 |
72 | q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
73 | k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
74 | v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)
75 |
76 | if attn_mask is not None:
77 | hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
78 | else:
79 | import xformers.ops as xops
80 | hidden_states = xops.memory_efficient_attention(q, k, v)
81 | hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)
82 |
83 | hidden_states = hidden_states.to(q.dtype)
84 | hidden_states = self.to_out(hidden_states)
85 |
86 | return hidden_states
87 |
88 | def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
89 | return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)
--------------------------------------------------------------------------------
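A quick shape check for the Attention module above. The dimensions are illustrative (they match the 320-channel SD blocks used elsewhere in this repository); kv_dim is only needed when the keys/values come from a different embedding than the queries:

import torch
from diffsynth.models.attention import Attention

self_attn = Attention(q_dim=320, num_heads=8, head_dim=40, bias_out=True)               # k/v share q_dim
cross_attn = Attention(q_dim=320, num_heads=8, head_dim=40, kv_dim=768, bias_out=True)  # k/v from a 768-dim context

hidden_states = torch.randn(2, 64, 320)  # (batch, query tokens, q_dim)
text_emb = torch.randn(2, 77, 768)       # (batch, context tokens, kv_dim)

out_self = self_attn(hidden_states)
out_cross = cross_attn(hidden_states, encoder_hidden_states=text_emb)
print(out_self.shape, out_cross.shape)  # both torch.Size([2, 64, 320])
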
/diffsynth/models/downloader.py:
--------------------------------------------------------------------------------
1 | from huggingface_hub import hf_hub_download
2 | from http.cookiejar import CookieJar
3 | from pathlib import Path
4 | from typing import Dict, Optional, List, Union
5 | import copy, uuid, requests, io, platform, pickle, os, urllib
6 | from requests.adapters import Retry
7 | from tqdm import tqdm
8 |
9 |
10 | def _get_sep(path):
11 | if isinstance(path, bytes):
12 | return b'/'
13 | else:
14 | return '/'
15 |
16 |
17 | def expanduser(path):
18 | """Expand ~ and ~user constructions. If user or $HOME is unknown,
19 | do nothing."""
20 | path = os.fspath(path)
21 | if isinstance(path, bytes):
22 | tilde = b'~'
23 | else:
24 | tilde = '~'
25 | if not path.startswith(tilde):
26 | return path
27 | sep = _get_sep(path)
28 | i = path.find(sep, 1)
29 | if i < 0:
30 | i = len(path)
31 | if i == 1:
32 | if 'HOME' not in os.environ:
33 | import pwd
34 | try:
35 | userhome = pwd.getpwuid(os.getuid()).pw_dir
36 | except KeyError:
37 | # bpo-10496: if the current user identifier doesn't exist in the
38 | # password database, return the path unchanged
39 | return path
40 | else:
41 | userhome = os.environ['HOME']
42 | else:
43 | import pwd
44 | name = path[1:i]
45 | if isinstance(name, bytes):
46 | name = str(name, 'ASCII')
47 | try:
48 | pwent = pwd.getpwnam(name)
49 | except KeyError:
50 | # bpo-10496: if the user name from the path doesn't exist in the
51 | # password database, return the path unchanged
52 | return path
53 | userhome = pwent.pw_dir
54 | if isinstance(path, bytes):
55 | userhome = os.fsencode(userhome)
56 | root = b'/'
57 | else:
58 | root = '/'
59 | userhome = userhome.rstrip(root)
60 | return (userhome + path[i:]) or root
61 |
62 |
63 |
64 | class ModelScopeConfig:
65 | DEFAULT_CREDENTIALS_PATH = Path.home().joinpath('.modelscope', 'credentials')
66 | path_credential = expanduser(DEFAULT_CREDENTIALS_PATH)
67 | COOKIES_FILE_NAME = 'cookies'
68 | GIT_TOKEN_FILE_NAME = 'git_token'
69 | USER_INFO_FILE_NAME = 'user'
70 | USER_SESSION_ID_FILE_NAME = 'session'
71 |
72 | @staticmethod
73 | def make_sure_credential_path_exist():
74 | os.makedirs(ModelScopeConfig.path_credential, exist_ok=True)
75 |
76 | @staticmethod
77 | def get_user_session_id():
78 | session_path = os.path.join(ModelScopeConfig.path_credential,
79 | ModelScopeConfig.USER_SESSION_ID_FILE_NAME)
80 | session_id = ''
81 | if os.path.exists(session_path):
82 | with open(session_path, 'rb') as f:
83 | session_id = str(f.readline().strip(), encoding='utf-8')
84 | return session_id
85 | if session_id == '' or len(session_id) != 32:
86 | session_id = str(uuid.uuid4().hex)
87 | ModelScopeConfig.make_sure_credential_path_exist()
88 | with open(session_path, 'w+') as wf:
89 | wf.write(session_id)
90 |
91 | return session_id
92 |
93 | @staticmethod
94 | def get_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str:
95 | """Formats a user-agent string with basic info about a request.
96 |
97 | Args:
98 | user_agent (`str`, `dict`, *optional*):
99 | The user agent info in the form of a dictionary or a single string.
100 |
101 | Returns:
102 | The formatted user-agent string.
103 | """
104 |
105 |         # include some more telemetry when executing in dedicated
106 | # cloud containers
107 | MODELSCOPE_CLOUD_ENVIRONMENT = 'MODELSCOPE_ENVIRONMENT'
108 | MODELSCOPE_CLOUD_USERNAME = 'MODELSCOPE_USERNAME'
109 | env = 'custom'
110 | if MODELSCOPE_CLOUD_ENVIRONMENT in os.environ:
111 | env = os.environ[MODELSCOPE_CLOUD_ENVIRONMENT]
112 | user_name = 'unknown'
113 | if MODELSCOPE_CLOUD_USERNAME in os.environ:
114 | user_name = os.environ[MODELSCOPE_CLOUD_USERNAME]
115 |
116 | ua = 'modelscope/%s; python/%s; session_id/%s; platform/%s; processor/%s; env/%s; user/%s' % (
117 | "1.15.0",
118 | platform.python_version(),
119 | ModelScopeConfig.get_user_session_id(),
120 | platform.platform(),
121 | platform.processor(),
122 | env,
123 | user_name,
124 | )
125 | if isinstance(user_agent, dict):
126 | ua += '; ' + '; '.join(f'{k}/{v}' for k, v in user_agent.items())
127 | elif isinstance(user_agent, str):
128 | ua += '; ' + user_agent
129 | return ua
130 |
131 | @staticmethod
132 | def get_cookies():
133 | cookies_path = os.path.join(ModelScopeConfig.path_credential,
134 | ModelScopeConfig.COOKIES_FILE_NAME)
135 | if os.path.exists(cookies_path):
136 | with open(cookies_path, 'rb') as f:
137 | cookies = pickle.load(f)
138 | return cookies
139 | return None
140 |
141 |
142 |
143 | def modelscope_http_get_model_file(
144 | url: str,
145 | local_dir: str,
146 | file_name: str,
147 | file_size: int,
148 | cookies: CookieJar,
149 | headers: Optional[Dict[str, str]] = None,
150 | ):
151 |     """Download a remote file, retrying up to 5 times before giving up on errors.
152 |
153 | Args:
154 | url(str):
155 | actual download url of the file
156 | local_dir(str):
157 | local directory where the downloaded file stores
158 | file_name(str):
159 | name of the file stored in `local_dir`
160 | file_size(int):
161 | The file size.
162 | cookies(CookieJar):
163 |             cookies used to authenticate the user, which is required for downloading private repos
164 | headers(Dict[str, str], optional):
165 | http headers to carry necessary info when requesting the remote file
166 |
167 | Raises:
168 | FileDownloadError: File download failed.
169 |
170 | """
171 | get_headers = {} if headers is None else copy.deepcopy(headers)
172 | get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
173 | temp_file_path = os.path.join(local_dir, file_name)
174 | # retry sleep 0.5s, 1s, 2s, 4s
175 | retry = Retry(
176 | total=5,
177 | backoff_factor=1,
178 | allowed_methods=['GET'])
179 | while True:
180 | try:
181 | progress = tqdm(
182 | unit='B',
183 | unit_scale=True,
184 | unit_divisor=1024,
185 | total=file_size,
186 | initial=0,
187 | desc='Downloading',
188 | )
189 | partial_length = 0
190 | if os.path.exists(
191 | temp_file_path): # download partial, continue download
192 | with open(temp_file_path, 'rb') as f:
193 | partial_length = f.seek(0, io.SEEK_END)
194 | progress.update(partial_length)
195 | if partial_length > file_size:
196 | break
197 | get_headers['Range'] = 'bytes=%s-%s' % (partial_length,
198 | file_size - 1)
199 | with open(temp_file_path, 'ab') as f:
200 | r = requests.get(
201 | url,
202 | stream=True,
203 | headers=get_headers,
204 | cookies=cookies,
205 | timeout=60)
206 | r.raise_for_status()
207 | for chunk in r.iter_content(
208 | chunk_size=1024 * 1024 * 1):
209 | if chunk: # filter out keep-alive new chunks
210 | progress.update(len(chunk))
211 | f.write(chunk)
212 | progress.close()
213 | break
214 |         except Exception as e:  # no matter what happens, we will retry.
215 | retry = retry.increment('GET', url, error=e)
216 | retry.sleep()
217 |
218 |
219 | def get_endpoint():
220 | MODELSCOPE_URL_SCHEME = 'https://'
221 | DEFAULT_MODELSCOPE_DOMAIN = 'www.modelscope.cn'
222 | modelscope_domain = os.getenv('MODELSCOPE_DOMAIN',
223 | DEFAULT_MODELSCOPE_DOMAIN)
224 | return MODELSCOPE_URL_SCHEME + modelscope_domain
225 |
226 |
227 | def get_file_download_url(model_id: str, file_path: str, revision: str):
228 | """Format file download url according to `model_id`, `revision` and `file_path`.
229 | e.g., Given `model_id=john/bert`, `revision=master`, `file_path=README.md`,
230 |     the resulting download url is: https://modelscope.cn/api/v1/models/john/bert/repo?Revision=master&FilePath=README.md
231 |
232 | Args:
233 | model_id (str): The model_id.
234 | file_path (str): File path
235 | revision (str): File revision.
236 |
237 | Returns:
238 | str: The file url.
239 | """
240 | file_path = urllib.parse.quote_plus(file_path)
241 | revision = urllib.parse.quote_plus(revision)
242 | download_url_template = '{endpoint}/api/v1/models/{model_id}/repo?Revision={revision}&FilePath={file_path}'
243 | return download_url_template.format(
244 | endpoint=get_endpoint(),
245 | model_id=model_id,
246 | revision=revision,
247 | file_path=file_path,
248 | )
249 |
250 |
251 | def download_from_modelscope(model_id, origin_file_path, local_dir):
252 | os.makedirs(local_dir, exist_ok=True)
253 | if os.path.basename(origin_file_path) in os.listdir(local_dir):
254 |         print(f"{os.path.basename(origin_file_path)} is already in {local_dir}.")
255 | return
256 | else:
257 | print(f"Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
258 | headers = {'user-agent': ModelScopeConfig.get_user_agent(user_agent=None)}
259 | cookies = ModelScopeConfig.get_cookies()
260 | url = get_file_download_url(model_id=model_id, file_path=origin_file_path, revision="master")
261 | modelscope_http_get_model_file(
262 | url,
263 | local_dir,
264 | os.path.basename(origin_file_path),
265 | file_size=0,
266 | headers=headers,
267 | cookies=cookies
268 | )
269 |
270 |
271 | def download_from_huggingface(model_id, origin_file_path, local_dir):
272 | os.makedirs(local_dir, exist_ok=True)
273 | if os.path.basename(origin_file_path) in os.listdir(local_dir):
274 |         print(f"{os.path.basename(origin_file_path)} is already in {local_dir}.")
275 | return
276 | else:
277 | print(f"Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
278 | hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
279 |
--------------------------------------------------------------------------------
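Minimal usage sketch of the two download helpers above. The model id, file name and target directory are placeholders; both calls are no-ops when the file is already present in local_dir:

from diffsynth.models.downloader import download_from_modelscope, download_from_huggingface

download_from_modelscope("namespace/model-name", "model.safetensors", "models/example")   # placeholder ids
download_from_huggingface("namespace/model-name", "model.safetensors", "models/example")  # placeholder ids
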
/diffsynth/models/hunyuan_dit_text_encoder.py:
--------------------------------------------------------------------------------
1 | from transformers import BertModel, BertConfig, T5EncoderModel, T5Config
2 | import torch
3 |
4 |
5 |
6 | class HunyuanDiTCLIPTextEncoder(BertModel):
7 | def __init__(self):
8 | config = BertConfig(
9 | _name_or_path = "",
10 | architectures = ["BertModel"],
11 | attention_probs_dropout_prob = 0.1,
12 | bos_token_id = 0,
13 | classifier_dropout = None,
14 | directionality = "bidi",
15 | eos_token_id = 2,
16 | hidden_act = "gelu",
17 | hidden_dropout_prob = 0.1,
18 | hidden_size = 1024,
19 | initializer_range = 0.02,
20 | intermediate_size = 4096,
21 | layer_norm_eps = 1e-12,
22 | max_position_embeddings = 512,
23 | model_type = "bert",
24 | num_attention_heads = 16,
25 | num_hidden_layers = 24,
26 | output_past = True,
27 | pad_token_id = 0,
28 | pooler_fc_size = 768,
29 | pooler_num_attention_heads = 12,
30 | pooler_num_fc_layers = 3,
31 | pooler_size_per_head = 128,
32 | pooler_type = "first_token_transform",
33 | position_embedding_type = "absolute",
34 | torch_dtype = "float32",
35 | transformers_version = "4.37.2",
36 | type_vocab_size = 2,
37 | use_cache = True,
38 | vocab_size = 47020
39 | )
40 | super().__init__(config, add_pooling_layer=False)
41 | self.eval()
42 |
43 | def forward(self, input_ids, attention_mask, clip_skip=1):
44 | input_shape = input_ids.size()
45 |
46 | batch_size, seq_length = input_shape
47 | device = input_ids.device
48 |
49 | past_key_values_length = 0
50 |
51 | if attention_mask is None:
52 | attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
53 |
54 | extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
55 |
56 | embedding_output = self.embeddings(
57 | input_ids=input_ids,
58 | position_ids=None,
59 | token_type_ids=None,
60 | inputs_embeds=None,
61 | past_key_values_length=0,
62 | )
63 | encoder_outputs = self.encoder(
64 | embedding_output,
65 | attention_mask=extended_attention_mask,
66 | head_mask=None,
67 | encoder_hidden_states=None,
68 | encoder_attention_mask=None,
69 | past_key_values=None,
70 | use_cache=False,
71 | output_attentions=False,
72 | output_hidden_states=True,
73 | return_dict=True,
74 | )
75 | all_hidden_states = encoder_outputs.hidden_states
76 | prompt_emb = all_hidden_states[-clip_skip]
77 | if clip_skip > 1:
78 | mean, std = all_hidden_states[-1].mean(), all_hidden_states[-1].std()
79 | prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
80 | return prompt_emb
81 |
82 | def state_dict_converter(self):
83 | return HunyuanDiTCLIPTextEncoderStateDictConverter()
84 |
85 |
86 |
87 | class HunyuanDiTT5TextEncoder(T5EncoderModel):
88 | def __init__(self):
89 | config = T5Config(
90 | _name_or_path = "../HunyuanDiT/t2i/mt5",
91 | architectures = ["MT5ForConditionalGeneration"],
92 | classifier_dropout = 0.0,
93 | d_ff = 5120,
94 | d_kv = 64,
95 | d_model = 2048,
96 | decoder_start_token_id = 0,
97 | dense_act_fn = "gelu_new",
98 | dropout_rate = 0.1,
99 | eos_token_id = 1,
100 | feed_forward_proj = "gated-gelu",
101 | initializer_factor = 1.0,
102 | is_encoder_decoder = True,
103 | is_gated_act = True,
104 | layer_norm_epsilon = 1e-06,
105 | model_type = "t5",
106 | num_decoder_layers = 24,
107 | num_heads = 32,
108 | num_layers = 24,
109 | output_past = True,
110 | pad_token_id = 0,
111 | relative_attention_max_distance = 128,
112 | relative_attention_num_buckets = 32,
113 | tie_word_embeddings = False,
114 | tokenizer_class = "T5Tokenizer",
115 | transformers_version = "4.37.2",
116 | use_cache = True,
117 | vocab_size = 250112
118 | )
119 | super().__init__(config)
120 | self.eval()
121 |
122 | def forward(self, input_ids, attention_mask, clip_skip=1):
123 | outputs = super().forward(
124 | input_ids=input_ids,
125 | attention_mask=attention_mask,
126 | output_hidden_states=True,
127 | )
128 | prompt_emb = outputs.hidden_states[-clip_skip]
129 | if clip_skip > 1:
130 | mean, std = outputs.hidden_states[-1].mean(), outputs.hidden_states[-1].std()
131 | prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
132 | return prompt_emb
133 |
134 | def state_dict_converter(self):
135 | return HunyuanDiTT5TextEncoderStateDictConverter()
136 |
137 |
138 |
139 | class HunyuanDiTCLIPTextEncoderStateDictConverter():
140 | def __init__(self):
141 | pass
142 |
143 | def from_diffusers(self, state_dict):
144 | state_dict_ = {name[5:]: param for name, param in state_dict.items() if name.startswith("bert.")}
145 | return state_dict_
146 |
147 | def from_civitai(self, state_dict):
148 | return self.from_diffusers(state_dict)
149 |
150 |
151 | class HunyuanDiTT5TextEncoderStateDictConverter():
152 | def __init__(self):
153 | pass
154 |
155 | def from_diffusers(self, state_dict):
156 | state_dict_ = {name: param for name, param in state_dict.items() if name.startswith("encoder.")}
157 | state_dict_["shared.weight"] = state_dict["shared.weight"]
158 | return state_dict_
159 |
160 | def from_civitai(self, state_dict):
161 | return self.from_diffusers(state_dict)
162 |
--------------------------------------------------------------------------------
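Both encoders above apply the same clip_skip trick: they take the hidden states clip_skip layers from the end and rescale them to the mean/std of the final layer. A standalone sketch of that renormalization on dummy tensors:

import torch

def renormalize_clip_skip(hidden_states, clip_skip):
    # hidden_states: list/tuple of per-layer tensors, final layer last.
    prompt_emb = hidden_states[-clip_skip]
    if clip_skip > 1:
        mean, std = hidden_states[-1].mean(), hidden_states[-1].std()
        prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
    return prompt_emb

layers = [torch.randn(1, 77, 1024) for _ in range(25)]  # dummy hidden states
print(renormalize_clip_skip(layers, clip_skip=2).shape)  # torch.Size([1, 77, 1024])
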
/diffsynth/models/sd_ipadapter.py:
--------------------------------------------------------------------------------
1 | from .svd_image_encoder import SVDImageEncoder
2 | from .sdxl_ipadapter import IpAdapterImageProjModel, IpAdapterModule, SDXLIpAdapterStateDictConverter
3 | from transformers import CLIPImageProcessor
4 | import torch
5 |
6 |
7 | class IpAdapterCLIPImageEmbedder(SVDImageEncoder):
8 | def __init__(self):
9 | super().__init__()
10 | self.image_processor = CLIPImageProcessor()
11 |
12 | def forward(self, image):
13 | pixel_values = self.image_processor(images=image, return_tensors="pt").pixel_values
14 | pixel_values = pixel_values.to(device=self.embeddings.class_embedding.device, dtype=self.embeddings.class_embedding.dtype)
15 | return super().forward(pixel_values)
16 |
17 |
18 | class SDIpAdapter(torch.nn.Module):
19 | def __init__(self):
20 | super().__init__()
21 | shape_list = [(768, 320)] * 2 + [(768, 640)] * 2 + [(768, 1280)] * 5 + [(768, 640)] * 3 + [(768, 320)] * 3 + [(768, 1280)] * 1
22 | self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(*shape) for shape in shape_list])
23 | self.image_proj = IpAdapterImageProjModel(cross_attention_dim=768, clip_embeddings_dim=1024, clip_extra_context_tokens=4)
24 | self.set_full_adapter()
25 |
26 | def set_full_adapter(self):
27 | block_ids = [1, 4, 9, 12, 17, 20, 40, 43, 46, 50, 53, 56, 60, 63, 66, 29]
28 | self.call_block_id = {(i, 0): j for j, i in enumerate(block_ids)}
29 |
30 | def set_less_adapter(self):
31 |         # IP-Adapter for SD v1.5 doesn't support this feature, so fall back to the full adapter.
32 |         self.set_full_adapter()
33 |
34 | def forward(self, hidden_states, scale=1.0):
35 | hidden_states = self.image_proj(hidden_states)
36 | hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
37 | ip_kv_dict = {}
38 | for (block_id, transformer_id) in self.call_block_id:
39 | ipadapter_id = self.call_block_id[(block_id, transformer_id)]
40 | ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
41 | if block_id not in ip_kv_dict:
42 | ip_kv_dict[block_id] = {}
43 | ip_kv_dict[block_id][transformer_id] = {
44 | "ip_k": ip_k,
45 | "ip_v": ip_v,
46 | "scale": scale
47 | }
48 | return ip_kv_dict
49 |
50 | def state_dict_converter(self):
51 | return SDIpAdapterStateDictConverter()
52 |
53 |
54 | class SDIpAdapterStateDictConverter(SDXLIpAdapterStateDictConverter):
55 | def __init__(self):
56 | pass
57 |
--------------------------------------------------------------------------------
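A small sketch of what SDIpAdapter produces. The image embedding here is random; in the pipelines it comes from IpAdapterCLIPImageEmbedder, and the resulting dictionary is routed into the UNet attention blocks as ipadapter_kwargs:

import torch
from diffsynth.models.sd_ipadapter import SDIpAdapter

adapter = SDIpAdapter()
image_emb = torch.randn(1, 1024)          # dummy CLIP image embedding (clip_embeddings_dim=1024)
ip_kv_dict = adapter(image_emb, scale=0.8)
# ip_kv_dict[block_id][transformer_id] -> {"ip_k": ..., "ip_v": ..., "scale": 0.8}
print(sorted(ip_kv_dict.keys()))
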
/diffsynth/models/sd_lora.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .sd_unet import SDUNetStateDictConverter, SDUNet
3 | from .sd_text_encoder import SDTextEncoderStateDictConverter, SDTextEncoder
4 |
5 |
6 | class SDLoRA:
7 | def __init__(self):
8 | pass
9 |
10 | def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0, device="cuda"):
11 | special_keys = {
12 | "down.blocks": "down_blocks",
13 | "up.blocks": "up_blocks",
14 | "mid.block": "mid_block",
15 | "proj.in": "proj_in",
16 | "proj.out": "proj_out",
17 | "transformer.blocks": "transformer_blocks",
18 | "to.q": "to_q",
19 | "to.k": "to_k",
20 | "to.v": "to_v",
21 | "to.out": "to_out",
22 | }
23 | state_dict_ = {}
24 | for key in state_dict:
25 | if ".lora_up" not in key:
26 | continue
27 | if not key.startswith(lora_prefix):
28 | continue
29 |             weight_up = state_dict[key].to(device=device, dtype=torch.float16)
30 |             weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device=device, dtype=torch.float16)
31 | if len(weight_up.shape) == 4:
32 | weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
33 | weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
34 | lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
35 | else:
36 | lora_weight = alpha * torch.mm(weight_up, weight_down)
37 | target_name = key.split(".")[0].replace("_", ".")[len(lora_prefix):] + ".weight"
38 | for special_key in special_keys:
39 | target_name = target_name.replace(special_key, special_keys[special_key])
40 | state_dict_[target_name] = lora_weight.cpu()
41 | return state_dict_
42 |
43 | def add_lora_to_unet(self, unet: SDUNet, state_dict_lora, alpha=1.0, device="cuda"):
44 | state_dict_unet = unet.state_dict()
45 | state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_unet_", alpha=alpha, device=device)
46 | state_dict_lora = SDUNetStateDictConverter().from_diffusers(state_dict_lora)
47 | if len(state_dict_lora) > 0:
48 | for name in state_dict_lora:
49 | state_dict_unet[name] += state_dict_lora[name].to(device=device)
50 | unet.load_state_dict(state_dict_unet)
51 |
52 | def add_lora_to_text_encoder(self, text_encoder: SDTextEncoder, state_dict_lora, alpha=1.0, device="cuda"):
53 | state_dict_text_encoder = text_encoder.state_dict()
54 | state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_te_", alpha=alpha, device=device)
55 | state_dict_lora = SDTextEncoderStateDictConverter().from_diffusers(state_dict_lora)
56 | if len(state_dict_lora) > 0:
57 | for name in state_dict_lora:
58 | state_dict_text_encoder[name] += state_dict_lora[name].to(device=device)
59 | text_encoder.load_state_dict(state_dict_text_encoder)
60 |
61 |
--------------------------------------------------------------------------------
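The core of convert_state_dict above, shown on dummy tensors for a single linear layer: the merged delta is alpha * (lora_up @ lora_down) and has the same shape as the weight it is added to (convolution LoRAs get the same product with the trailing 1x1 dimensions squeezed out and restored):

import torch

rank, in_features, out_features, alpha = 4, 768, 320, 1.0
weight = torch.randn(out_features, in_features)  # original linear weight
lora_up = torch.randn(out_features, rank)
lora_down = torch.randn(rank, in_features)

merged = weight + alpha * torch.mm(lora_up, lora_down)
print(merged.shape)  # torch.Size([320, 768])
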
/diffsynth/models/sd_motion.py:
--------------------------------------------------------------------------------
1 | from .sd_unet import SDUNet, Attention, GEGLU
2 | import torch
3 | from einops import rearrange, repeat
4 |
5 |
6 | class TemporalTransformerBlock(torch.nn.Module):
7 |
8 | def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32):
9 | super().__init__()
10 |
11 | # 1. Self-Attn
12 | self.pe1 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
13 | self.norm1 = torch.nn.LayerNorm(dim, elementwise_affine=True)
14 | self.attn1 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
15 |
16 | # 2. Cross-Attn
17 | self.pe2 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
18 | self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=True)
19 | self.attn2 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
20 |
21 | # 3. Feed-forward
22 | self.norm3 = torch.nn.LayerNorm(dim, elementwise_affine=True)
23 | self.act_fn = GEGLU(dim, dim * 4)
24 | self.ff = torch.nn.Linear(dim * 4, dim)
25 |
26 |
27 | def forward(self, hidden_states, batch_size=1):
28 |
29 | # 1. Self-Attention
30 | norm_hidden_states = self.norm1(hidden_states)
31 | norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
32 | attn_output = self.attn1(norm_hidden_states + self.pe1[:, :norm_hidden_states.shape[1]])
33 | attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
34 | hidden_states = attn_output + hidden_states
35 |
36 | # 2. Cross-Attention
37 | norm_hidden_states = self.norm2(hidden_states)
38 | norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
39 | attn_output = self.attn2(norm_hidden_states + self.pe2[:, :norm_hidden_states.shape[1]])
40 | attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
41 | hidden_states = attn_output + hidden_states
42 |
43 | # 3. Feed-forward
44 | norm_hidden_states = self.norm3(hidden_states)
45 | ff_output = self.act_fn(norm_hidden_states)
46 | ff_output = self.ff(ff_output)
47 | hidden_states = ff_output + hidden_states
48 |
49 | return hidden_states
50 |
51 |
52 | class TemporalBlock(torch.nn.Module):
53 |
54 | def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
55 | super().__init__()
56 | inner_dim = num_attention_heads * attention_head_dim
57 |
58 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
59 | self.proj_in = torch.nn.Linear(in_channels, inner_dim)
60 |
61 | self.transformer_blocks = torch.nn.ModuleList([
62 | TemporalTransformerBlock(
63 | inner_dim,
64 | num_attention_heads,
65 | attention_head_dim
66 | )
67 | for d in range(num_layers)
68 | ])
69 |
70 | self.proj_out = torch.nn.Linear(inner_dim, in_channels)
71 |
72 | def forward(self, hidden_states, time_emb, text_emb, res_stack, batch_size=1):
73 | batch, _, height, width = hidden_states.shape
74 | residual = hidden_states
75 |
76 | hidden_states = self.norm(hidden_states)
77 | inner_dim = hidden_states.shape[1]
78 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
79 | hidden_states = self.proj_in(hidden_states)
80 |
81 | for block in self.transformer_blocks:
82 | hidden_states = block(
83 | hidden_states,
84 | batch_size=batch_size
85 | )
86 |
87 | hidden_states = self.proj_out(hidden_states)
88 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
89 | hidden_states = hidden_states + residual
90 |
91 | return hidden_states, time_emb, text_emb, res_stack
92 |
93 |
94 | class SDMotionModel(torch.nn.Module):
95 | def __init__(self):
96 | super().__init__()
97 | self.motion_modules = torch.nn.ModuleList([
98 | TemporalBlock(8, 40, 320, eps=1e-6),
99 | TemporalBlock(8, 40, 320, eps=1e-6),
100 | TemporalBlock(8, 80, 640, eps=1e-6),
101 | TemporalBlock(8, 80, 640, eps=1e-6),
102 | TemporalBlock(8, 160, 1280, eps=1e-6),
103 | TemporalBlock(8, 160, 1280, eps=1e-6),
104 | TemporalBlock(8, 160, 1280, eps=1e-6),
105 | TemporalBlock(8, 160, 1280, eps=1e-6),
106 | TemporalBlock(8, 160, 1280, eps=1e-6),
107 | TemporalBlock(8, 160, 1280, eps=1e-6),
108 | TemporalBlock(8, 160, 1280, eps=1e-6),
109 | TemporalBlock(8, 160, 1280, eps=1e-6),
110 | TemporalBlock(8, 160, 1280, eps=1e-6),
111 | TemporalBlock(8, 160, 1280, eps=1e-6),
112 | TemporalBlock(8, 160, 1280, eps=1e-6),
113 | TemporalBlock(8, 80, 640, eps=1e-6),
114 | TemporalBlock(8, 80, 640, eps=1e-6),
115 | TemporalBlock(8, 80, 640, eps=1e-6),
116 | TemporalBlock(8, 40, 320, eps=1e-6),
117 | TemporalBlock(8, 40, 320, eps=1e-6),
118 | TemporalBlock(8, 40, 320, eps=1e-6),
119 | ])
120 | self.call_block_id = {
121 | 1: 0,
122 | 4: 1,
123 | 9: 2,
124 | 12: 3,
125 | 17: 4,
126 | 20: 5,
127 | 24: 6,
128 | 26: 7,
129 | 29: 8,
130 | 32: 9,
131 | 34: 10,
132 | 36: 11,
133 | 40: 12,
134 | 43: 13,
135 | 46: 14,
136 | 50: 15,
137 | 53: 16,
138 | 56: 17,
139 | 60: 18,
140 | 63: 19,
141 | 66: 20
142 | }
143 |
144 | def forward(self):
145 | pass
146 |
147 | def state_dict_converter(self):
148 | return SDMotionModelStateDictConverter()
149 |
150 |
151 | class SDMotionModelStateDictConverter:
152 | def __init__(self):
153 | pass
154 |
155 | def from_diffusers(self, state_dict):
156 | rename_dict = {
157 | "norm": "norm",
158 | "proj_in": "proj_in",
159 | "transformer_blocks.0.attention_blocks.0.to_q": "transformer_blocks.0.attn1.to_q",
160 | "transformer_blocks.0.attention_blocks.0.to_k": "transformer_blocks.0.attn1.to_k",
161 | "transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v",
162 | "transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out",
163 | "transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1",
164 | "transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q",
165 | "transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k",
166 | "transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v",
167 | "transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out",
168 | "transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2",
169 | "transformer_blocks.0.norms.0": "transformer_blocks.0.norm1",
170 | "transformer_blocks.0.norms.1": "transformer_blocks.0.norm2",
171 | "transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj",
172 | "transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff",
173 | "transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3",
174 | "proj_out": "proj_out",
175 | }
176 | name_list = sorted([i for i in state_dict if i.startswith("down_blocks.")])
177 | name_list += sorted([i for i in state_dict if i.startswith("mid_block.")])
178 | name_list += sorted([i for i in state_dict if i.startswith("up_blocks.")])
179 | state_dict_ = {}
180 | last_prefix, module_id = "", -1
181 | for name in name_list:
182 | names = name.split(".")
183 | prefix_index = names.index("temporal_transformer") + 1
184 | prefix = ".".join(names[:prefix_index])
185 | if prefix != last_prefix:
186 | last_prefix = prefix
187 | module_id += 1
188 | middle_name = ".".join(names[prefix_index:-1])
189 | suffix = names[-1]
190 | if "pos_encoder" in names:
191 | rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name]])
192 | else:
193 | rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix])
194 | state_dict_[rename] = state_dict[name]
195 | return state_dict_
196 |
197 | def from_civitai(self, state_dict):
198 | return self.from_diffusers(state_dict)
199 |
--------------------------------------------------------------------------------
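The key trick in TemporalTransformerBlock above is the reshape around the attention call: spatial tokens are regrouped so each spatial location becomes a sequence over frames, attention is applied along that frame axis, and the layout is restored afterwards. A dummy round trip:

import torch
from einops import rearrange

batch_size, frames, tokens, channels = 1, 16, 64, 320
hidden_states = torch.randn(batch_size * frames, tokens, channels)  # "(b f) h c"

temporal = rearrange(hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
print(temporal.shape)  # torch.Size([64, 16, 320]) -> attention runs over the 16 frames
restored = rearrange(temporal, "(b h) f c -> (b f) h c", b=batch_size)
print(torch.equal(hidden_states, restored))  # True
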
/diffsynth/models/sdxl_ipadapter.py:
--------------------------------------------------------------------------------
1 | from .svd_image_encoder import SVDImageEncoder
2 | from transformers import CLIPImageProcessor
3 | import torch
4 |
5 |
6 | class IpAdapterXLCLIPImageEmbedder(SVDImageEncoder):
7 | def __init__(self):
8 | super().__init__(embed_dim=1664, encoder_intermediate_size=8192, projection_dim=1280, num_encoder_layers=48, num_heads=16, head_dim=104)
9 | self.image_processor = CLIPImageProcessor()
10 |
11 | def forward(self, image):
12 | pixel_values = self.image_processor(images=image, return_tensors="pt").pixel_values
13 | pixel_values = pixel_values.to(device=self.embeddings.class_embedding.device, dtype=self.embeddings.class_embedding.dtype)
14 | return super().forward(pixel_values)
15 |
16 |
17 | class IpAdapterImageProjModel(torch.nn.Module):
18 | def __init__(self, cross_attention_dim=2048, clip_embeddings_dim=1280, clip_extra_context_tokens=4):
19 | super().__init__()
20 | self.cross_attention_dim = cross_attention_dim
21 | self.clip_extra_context_tokens = clip_extra_context_tokens
22 | self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
23 | self.norm = torch.nn.LayerNorm(cross_attention_dim)
24 |
25 | def forward(self, image_embeds):
26 | clip_extra_context_tokens = self.proj(image_embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
27 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
28 | return clip_extra_context_tokens
29 |
30 |
31 | class IpAdapterModule(torch.nn.Module):
32 | def __init__(self, input_dim, output_dim):
33 | super().__init__()
34 | self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
35 | self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
36 |
37 | def forward(self, hidden_states):
38 | ip_k = self.to_k_ip(hidden_states)
39 | ip_v = self.to_v_ip(hidden_states)
40 | return ip_k, ip_v
41 |
42 |
43 | class SDXLIpAdapter(torch.nn.Module):
44 | def __init__(self):
45 | super().__init__()
46 | shape_list = [(2048, 640)] * 4 + [(2048, 1280)] * 50 + [(2048, 640)] * 6 + [(2048, 1280)] * 10
47 | self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(*shape) for shape in shape_list])
48 | self.image_proj = IpAdapterImageProjModel()
49 | self.set_full_adapter()
50 |
51 | def set_full_adapter(self):
52 | map_list = sum([
53 | [(7, i) for i in range(2)],
54 | [(10, i) for i in range(2)],
55 | [(15, i) for i in range(10)],
56 | [(18, i) for i in range(10)],
57 | [(25, i) for i in range(10)],
58 | [(28, i) for i in range(10)],
59 | [(31, i) for i in range(10)],
60 | [(35, i) for i in range(2)],
61 | [(38, i) for i in range(2)],
62 | [(41, i) for i in range(2)],
63 | [(21, i) for i in range(10)],
64 | ], [])
65 | self.call_block_id = {i: j for j, i in enumerate(map_list)}
66 |
67 | def set_less_adapter(self):
68 | map_list = sum([
69 | [(7, i) for i in range(2)],
70 | [(10, i) for i in range(2)],
71 | [(15, i) for i in range(10)],
72 | [(18, i) for i in range(10)],
73 | [(25, i) for i in range(10)],
74 | [(28, i) for i in range(10)],
75 | [(31, i) for i in range(10)],
76 | [(35, i) for i in range(2)],
77 | [(38, i) for i in range(2)],
78 | [(41, i) for i in range(2)],
79 | [(21, i) for i in range(10)],
80 | ], [])
81 | self.call_block_id = {i: j for j, i in enumerate(map_list) if j>=34 and j<44}
82 |
83 | def forward(self, hidden_states, scale=1.0):
84 | hidden_states = self.image_proj(hidden_states)
85 | hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
86 | ip_kv_dict = {}
87 | for (block_id, transformer_id) in self.call_block_id:
88 | ipadapter_id = self.call_block_id[(block_id, transformer_id)]
89 | ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
90 | if block_id not in ip_kv_dict:
91 | ip_kv_dict[block_id] = {}
92 | ip_kv_dict[block_id][transformer_id] = {
93 | "ip_k": ip_k,
94 | "ip_v": ip_v,
95 | "scale": scale
96 | }
97 | return ip_kv_dict
98 |
99 | def state_dict_converter(self):
100 | return SDXLIpAdapterStateDictConverter()
101 |
102 |
103 | class SDXLIpAdapterStateDictConverter:
104 | def __init__(self):
105 | pass
106 |
107 | def from_diffusers(self, state_dict):
108 | state_dict_ = {}
109 | for name in state_dict["ip_adapter"]:
110 | names = name.split(".")
111 | layer_id = str(int(names[0]) // 2)
112 | name_ = ".".join(["ipadapter_modules"] + [layer_id] + names[1:])
113 | state_dict_[name_] = state_dict["ip_adapter"][name]
114 | for name in state_dict["image_proj"]:
115 | name_ = "image_proj." + name
116 | state_dict_[name_] = state_dict["image_proj"][name]
117 | return state_dict_
118 |
119 | def from_civitai(self, state_dict):
120 | return self.from_diffusers(state_dict)
121 |
122 |
--------------------------------------------------------------------------------
/diffsynth/models/sdxl_motion.py:
--------------------------------------------------------------------------------
1 | from .sd_motion import TemporalBlock
2 | import torch
3 |
4 |
5 |
6 | class SDXLMotionModel(torch.nn.Module):
7 | def __init__(self):
8 | super().__init__()
9 | self.motion_modules = torch.nn.ModuleList([
10 | TemporalBlock(8, 320//8, 320, eps=1e-6),
11 | TemporalBlock(8, 320//8, 320, eps=1e-6),
12 |
13 | TemporalBlock(8, 640//8, 640, eps=1e-6),
14 | TemporalBlock(8, 640//8, 640, eps=1e-6),
15 |
16 | TemporalBlock(8, 1280//8, 1280, eps=1e-6),
17 | TemporalBlock(8, 1280//8, 1280, eps=1e-6),
18 |
19 | TemporalBlock(8, 1280//8, 1280, eps=1e-6),
20 | TemporalBlock(8, 1280//8, 1280, eps=1e-6),
21 | TemporalBlock(8, 1280//8, 1280, eps=1e-6),
22 |
23 | TemporalBlock(8, 640//8, 640, eps=1e-6),
24 | TemporalBlock(8, 640//8, 640, eps=1e-6),
25 | TemporalBlock(8, 640//8, 640, eps=1e-6),
26 |
27 | TemporalBlock(8, 320//8, 320, eps=1e-6),
28 | TemporalBlock(8, 320//8, 320, eps=1e-6),
29 | TemporalBlock(8, 320//8, 320, eps=1e-6),
30 | ])
31 | self.call_block_id = {
32 | 0: 0,
33 | 2: 1,
34 | 7: 2,
35 | 10: 3,
36 | 15: 4,
37 | 18: 5,
38 | 25: 6,
39 | 28: 7,
40 | 31: 8,
41 | 35: 9,
42 | 38: 10,
43 | 41: 11,
44 | 44: 12,
45 | 46: 13,
46 | 48: 14,
47 | }
48 |
49 | def forward(self):
50 | pass
51 |
52 | def state_dict_converter(self):
53 | return SDMotionModelStateDictConverter()
54 |
55 |
56 | class SDMotionModelStateDictConverter:
57 | def __init__(self):
58 | pass
59 |
60 | def from_diffusers(self, state_dict):
61 | rename_dict = {
62 | "norm": "norm",
63 | "proj_in": "proj_in",
64 | "transformer_blocks.0.attention_blocks.0.to_q": "transformer_blocks.0.attn1.to_q",
65 | "transformer_blocks.0.attention_blocks.0.to_k": "transformer_blocks.0.attn1.to_k",
66 | "transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v",
67 | "transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out",
68 | "transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1",
69 | "transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q",
70 | "transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k",
71 | "transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v",
72 | "transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out",
73 | "transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2",
74 | "transformer_blocks.0.norms.0": "transformer_blocks.0.norm1",
75 | "transformer_blocks.0.norms.1": "transformer_blocks.0.norm2",
76 | "transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj",
77 | "transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff",
78 | "transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3",
79 | "proj_out": "proj_out",
80 | }
81 | name_list = sorted([i for i in state_dict if i.startswith("down_blocks.")])
82 | name_list += sorted([i for i in state_dict if i.startswith("mid_block.")])
83 | name_list += sorted([i for i in state_dict if i.startswith("up_blocks.")])
84 | state_dict_ = {}
85 | last_prefix, module_id = "", -1
86 | for name in name_list:
87 | names = name.split(".")
88 | prefix_index = names.index("temporal_transformer") + 1
89 | prefix = ".".join(names[:prefix_index])
90 | if prefix != last_prefix:
91 | last_prefix = prefix
92 | module_id += 1
93 | middle_name = ".".join(names[prefix_index:-1])
94 | suffix = names[-1]
95 | if "pos_encoder" in names:
96 | rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name]])
97 | else:
98 | rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix])
99 | state_dict_[rename] = state_dict[name]
100 | return state_dict_
101 |
102 | def from_civitai(self, state_dict):
103 | return self.from_diffusers(state_dict)
104 |
--------------------------------------------------------------------------------
/diffsynth/models/sdxl_vae_decoder.py:
--------------------------------------------------------------------------------
1 | from .sd_vae_decoder import SDVAEDecoder, SDVAEDecoderStateDictConverter
2 |
3 |
4 | class SDXLVAEDecoder(SDVAEDecoder):
5 | def __init__(self):
6 | super().__init__()
7 | self.scaling_factor = 0.13025
8 |
9 | def state_dict_converter(self):
10 | return SDXLVAEDecoderStateDictConverter()
11 |
12 |
13 | class SDXLVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
14 | def __init__(self):
15 | super().__init__()
16 |
--------------------------------------------------------------------------------
/diffsynth/models/sdxl_vae_encoder.py:
--------------------------------------------------------------------------------
1 | from .sd_vae_encoder import SDVAEEncoderStateDictConverter, SDVAEEncoder
2 |
3 |
4 | class SDXLVAEEncoder(SDVAEEncoder):
5 | def __init__(self):
6 | super().__init__()
7 | self.scaling_factor = 0.13025
8 |
9 | def state_dict_converter(self):
10 | return SDXLVAEEncoderStateDictConverter()
11 |
12 |
13 | class SDXLVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
14 | def __init__(self):
15 | super().__init__()
16 |
--------------------------------------------------------------------------------
/diffsynth/models/tiler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from einops import rearrange, repeat
3 |
4 |
5 | class TileWorker:
6 | def __init__(self):
7 | pass
8 |
9 |
10 | def mask(self, height, width, border_width):
11 | # Create a mask with shape (height, width).
12 | # The centre area is filled with 1, and the border line is filled with values in range (0, 1].
13 | x = torch.arange(height).repeat(width, 1).T
14 | y = torch.arange(width).repeat(height, 1)
15 | mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values
16 | mask = (mask / border_width).clip(0, 1)
17 | return mask
18 |
19 |
20 | def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype):
21 | # Convert a tensor (b, c, h, w) to (b, c, tile_size, tile_size, tile_num)
22 | batch_size, channel, _, _ = model_input.shape
23 | model_input = model_input.to(device=tile_device, dtype=tile_dtype)
24 | unfold_operator = torch.nn.Unfold(
25 | kernel_size=(tile_size, tile_size),
26 | stride=(tile_stride, tile_stride)
27 | )
28 | model_input = unfold_operator(model_input)
29 | model_input = model_input.view((batch_size, channel, tile_size, tile_size, -1))
30 |
31 | return model_input
32 |
33 |
34 | def tiled_inference(self, forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype):
35 | # Call y=forward_fn(x) for each tile
36 | tile_num = model_input.shape[-1]
37 | model_output_stack = []
38 |
39 | for tile_id in range(0, tile_num, tile_batch_size):
40 |
41 | # process input
42 | tile_id_ = min(tile_id + tile_batch_size, tile_num)
43 | x = model_input[:, :, :, :, tile_id: tile_id_]
44 | x = x.to(device=inference_device, dtype=inference_dtype)
45 | x = rearrange(x, "b c h w n -> (n b) c h w")
46 |
47 | # process output
48 | y = forward_fn(x)
49 | y = rearrange(y, "(n b) c h w -> b c h w n", n=tile_id_-tile_id)
50 | y = y.to(device=tile_device, dtype=tile_dtype)
51 | model_output_stack.append(y)
52 |
53 | model_output = torch.concat(model_output_stack, dim=-1)
54 | return model_output
55 |
56 |
57 | def io_scale(self, model_output, tile_size):
58 |         # Determine the size modification that happened in forward_fn
59 | # We only consider the same scale on height and width.
60 | io_scale = model_output.shape[2] / tile_size
61 | return io_scale
62 |
63 |
64 | def untile(self, model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype):
65 |         # The inverse of tile
66 | mask = self.mask(tile_size, tile_size, border_width)
67 | mask = mask.to(device=tile_device, dtype=tile_dtype)
68 | mask = rearrange(mask, "h w -> 1 1 h w 1")
69 | model_output = model_output * mask
70 |
71 | fold_operator = torch.nn.Fold(
72 | output_size=(height, width),
73 | kernel_size=(tile_size, tile_size),
74 | stride=(tile_stride, tile_stride)
75 | )
76 | mask = repeat(mask[0, 0, :, :, 0], "h w -> 1 (h w) n", n=model_output.shape[-1])
77 | model_output = rearrange(model_output, "b c h w n -> b (c h w) n")
78 | model_output = fold_operator(model_output) / fold_operator(mask)
79 |
80 | return model_output
81 |
82 |
83 | def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_batch_size=1, tile_device="cpu", tile_dtype=torch.float32, border_width=None):
84 | # Prepare
85 | inference_device, inference_dtype = model_input.device, model_input.dtype
86 | height, width = model_input.shape[2], model_input.shape[3]
87 | border_width = int(tile_stride*0.5) if border_width is None else border_width
88 |
89 | # tile
90 | model_input = self.tile(model_input, tile_size, tile_stride, tile_device, tile_dtype)
91 |
92 | # inference
93 | model_output = self.tiled_inference(forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype)
94 |
95 | # resize
96 | io_scale = self.io_scale(model_output, tile_size)
97 | height, width = int(height*io_scale), int(width*io_scale)
98 | tile_size, tile_stride = int(tile_size*io_scale), int(tile_stride*io_scale)
99 | border_width = int(border_width*io_scale)
100 |
101 | # untile
102 | model_output = self.untile(model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype)
103 |
104 | # Done!
105 | model_output = model_output.to(device=inference_device, dtype=inference_dtype)
106 | return model_output
--------------------------------------------------------------------------------
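Minimal usage sketch of TileWorker.tiled_forward with an identity forward_fn and a dummy latent (sizes are illustrative). Because overlapping tiles are blended with the border mask and then renormalized, an identity function reproduces the input up to floating-point error:

import torch
from diffsynth.models.tiler import TileWorker

worker = TileWorker()
latent = torch.randn(1, 4, 128, 128)
output = worker.tiled_forward(
    lambda x: x,          # stand-in for a VAE / UNet forward pass
    latent,
    tile_size=64,
    tile_stride=32,
    tile_batch_size=4,
    tile_device="cpu",
    tile_dtype=torch.float32,
)
print(output.shape, torch.allclose(output, latent, atol=1e-5))
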
/diffsynth/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | from .stable_diffusion import SDImagePipeline
2 | from .stable_diffusion_xl import SDXLImagePipeline
3 | from .stable_diffusion_video import SDVideoPipeline, SDVideoPipelineRunner
4 | from .stable_diffusion_xl_video import SDXLVideoPipeline
5 | from .stable_video_diffusion import SVDVideoPipeline
6 | from .hunyuan_dit import HunyuanDiTImagePipeline
7 |
--------------------------------------------------------------------------------
/diffsynth/pipelines/dancer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from ..models import SDUNet, SDMotionModel, SDXLUNet, SDXLMotionModel
3 | from ..models.sd_unet import PushBlock, PopBlock
4 | from ..controlnets import MultiControlNetManager
5 |
6 |
7 | def lets_dance(
8 | unet: SDUNet,
9 | motion_modules: SDMotionModel = None,
10 | controlnet: MultiControlNetManager = None,
11 | sample = None,
12 | timestep = None,
13 | encoder_hidden_states = None,
14 | ipadapter_kwargs_list = {},
15 | controlnet_frames = None,
16 | unet_batch_size = 1,
17 | controlnet_batch_size = 1,
18 | cross_frame_attention = False,
19 | tiled=False,
20 | tile_size=64,
21 | tile_stride=32,
22 | device = "cuda",
23 | vram_limit_level = 0,
24 | ):
25 | # 1. ControlNet
26 | # This part will be repeated on overlapping frames if animatediff_batch_size > animatediff_stride.
27 | # I leave it here because I intend to do something interesting with the ControlNets.
28 | controlnet_insert_block_id = 30
29 | if controlnet is not None and controlnet_frames is not None:
30 | res_stacks = []
31 | # process the controlnet frames in batches
32 | for batch_id in range(0, sample.shape[0], controlnet_batch_size):
33 | batch_id_ = min(batch_id + controlnet_batch_size, sample.shape[0])
34 | res_stack = controlnet(
35 | sample[batch_id: batch_id_],
36 | timestep,
37 | encoder_hidden_states[batch_id: batch_id_],
38 | controlnet_frames[:, batch_id: batch_id_],
39 | tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
40 | )
41 | if vram_limit_level >= 1:
42 | res_stack = [res.cpu() for res in res_stack]
43 | res_stacks.append(res_stack)
44 | # concatenate the residuals from all batches
45 | additional_res_stack = []
46 | for i in range(len(res_stacks[0])):
47 | res = torch.concat([res_stack[i] for res_stack in res_stacks], dim=0)
48 | additional_res_stack.append(res)
49 | else:
50 | additional_res_stack = None
51 |
52 | # 2. time
53 | time_emb = unet.time_proj(timestep[None]).to(sample.dtype)
54 | time_emb = unet.time_embedding(time_emb)
55 |
56 | # 3. pre-process
57 | height, width = sample.shape[2], sample.shape[3]
58 | hidden_states = unet.conv_in(sample)
59 | text_emb = encoder_hidden_states
60 | res_stack = [hidden_states.cpu() if vram_limit_level>=1 else hidden_states]
61 |
62 | # 4. blocks
63 | for block_id, block in enumerate(unet.blocks):
64 | # 4.1 UNet
65 | if isinstance(block, PushBlock):
66 | hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
67 | if vram_limit_level>=1:
68 | res_stack[-1] = res_stack[-1].cpu()
69 | elif isinstance(block, PopBlock):
70 | if vram_limit_level>=1:
71 | res_stack[-1] = res_stack[-1].to(device)
72 | hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
73 | else:
74 | hidden_states_input = hidden_states
75 | hidden_states_output = []
76 | for batch_id in range(0, sample.shape[0], unet_batch_size):
77 | batch_id_ = min(batch_id + unet_batch_size, sample.shape[0])
78 | hidden_states, _, _, _ = block(
79 | hidden_states_input[batch_id: batch_id_],
80 | time_emb,
81 | text_emb[batch_id: batch_id_],
82 | res_stack,
83 | cross_frame_attention=cross_frame_attention,
84 | ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {}),
85 | tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
86 | )
87 | hidden_states_output.append(hidden_states)
88 | hidden_states = torch.concat(hidden_states_output, dim=0)
89 | # 4.2 AnimateDiff
90 | if motion_modules is not None:
91 | if block_id in motion_modules.call_block_id:
92 | motion_module_id = motion_modules.call_block_id[block_id]
93 | hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id](
94 | hidden_states, time_emb, text_emb, res_stack,
95 | batch_size=1
96 | )
97 | # 4.3 ControlNet
98 | if block_id == controlnet_insert_block_id and additional_res_stack is not None:
99 | hidden_states += additional_res_stack.pop().to(device)
100 | if vram_limit_level>=1:
101 | res_stack = [(res.to(device) + additional_res.to(device)).cpu() for res, additional_res in zip(res_stack, additional_res_stack)]
102 | else:
103 | res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)]
104 |
105 | # 5. output
106 | hidden_states = unet.conv_norm_out(hidden_states)
107 | hidden_states = unet.conv_act(hidden_states)
108 | hidden_states = unet.conv_out(hidden_states)
109 |
110 | return hidden_states
111 |
112 |
113 |
114 |
115 | def lets_dance_xl(
116 | unet: SDXLUNet,
117 | motion_modules: SDXLMotionModel = None,
118 | controlnet: MultiControlNetManager = None,
119 | sample = None,
120 | add_time_id = None,
121 | add_text_embeds = None,
122 | timestep = None,
123 | encoder_hidden_states = None,
124 | ipadapter_kwargs_list = {},
125 | controlnet_frames = None,
126 | unet_batch_size = 1,
127 | controlnet_batch_size = 1,
128 | cross_frame_attention = False,
129 | tiled=False,
130 | tile_size=64,
131 | tile_stride=32,
132 | device = "cuda",
133 | vram_limit_level = 0,
134 | ):
135 | # 2. time
136 | t_emb = unet.time_proj(timestep[None]).to(sample.dtype)
137 | t_emb = unet.time_embedding(t_emb)
138 |
139 | time_embeds = unet.add_time_proj(add_time_id)
140 | time_embeds = time_embeds.reshape((add_text_embeds.shape[0], -1))
141 | add_embeds = torch.concat([add_text_embeds, time_embeds], dim=-1)
142 | add_embeds = add_embeds.to(sample.dtype)
143 | add_embeds = unet.add_time_embedding(add_embeds)
144 |
145 | time_emb = t_emb + add_embeds
146 |
147 | # 3. pre-process
148 | height, width = sample.shape[2], sample.shape[3]
149 | hidden_states = unet.conv_in(sample)
150 | text_emb = encoder_hidden_states
151 | res_stack = [hidden_states]
152 |
153 | # 4. blocks
154 | for block_id, block in enumerate(unet.blocks):
155 | hidden_states, time_emb, text_emb, res_stack = block(
156 | hidden_states, time_emb, text_emb, res_stack,
157 | tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
158 | ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {})
159 | )
160 | # 4.2 AnimateDiff
161 | if motion_modules is not None:
162 | if block_id in motion_modules.call_block_id:
163 | motion_module_id = motion_modules.call_block_id[block_id]
164 | hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id](
165 | hidden_states, time_emb, text_emb, res_stack,
166 | batch_size=1
167 | )
168 |
169 | # 5. output
170 | hidden_states = unet.conv_norm_out(hidden_states)
171 | hidden_states = unet.conv_act(hidden_states)
172 | hidden_states = unet.conv_out(hidden_states)
173 |
174 | return hidden_states
--------------------------------------------------------------------------------
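
For orientation, lets_dance and lets_dance_xl are the shared denoising cores used by the pipelines below: they compute ControlNet residuals per frame batch, run the UNet blocks (optionally offloading the residual stack to the CPU when vram_limit_level >= 1), and interleave AnimateDiff motion modules at the block ids they register. A minimal sketch of one classifier-free-guidance branch as the pipelines invoke it; unet, latents and prompt_emb are assumed to have been prepared elsewhere (see SDImagePipeline below).

import torch
from diffsynth.pipelines.dancer import lets_dance

# Assumed to exist already: unet (an SDUNet on "cuda"),
# latents of shape (num_frames, 4, height//8, width//8),
# prompt_emb of shape (num_frames, sequence_length, embedding_dim).
timestep = torch.IntTensor((999,))[0].to("cuda")
noise_pred = lets_dance(
    unet,
    motion_modules=None,        # pass an SDMotionModel for AnimateDiff
    controlnet=None,            # pass a MultiControlNetManager plus controlnet_frames to condition
    sample=latents,
    timestep=timestep,
    encoder_hidden_states=prompt_emb,
    unet_batch_size=1,          # frames processed per UNet forward pass
    vram_limit_level=0,         # >= 1 keeps the residual stack on the CPU
    device="cuda",
)
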
/diffsynth/pipelines/stable_diffusion.py:
--------------------------------------------------------------------------------
1 | from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDIpAdapter, IpAdapterCLIPImageEmbedder
2 | from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
3 | from ..prompts import SDPrompter
4 | from ..schedulers import EnhancedDDIMScheduler
5 | from .dancer import lets_dance
6 | from typing import List
7 | import torch
8 | from tqdm import tqdm
9 | from PIL import Image
10 | import numpy as np
11 |
12 |
13 | class SDImagePipeline(torch.nn.Module):
14 |
15 | def __init__(self, device="cuda", torch_dtype=torch.float16):
16 | super().__init__()
17 | self.scheduler = EnhancedDDIMScheduler()
18 | self.prompter = SDPrompter()
19 | self.device = device
20 | self.torch_dtype = torch_dtype
21 | # models
22 | self.text_encoder: SDTextEncoder = None
23 | self.unet: SDUNet = None
24 | self.vae_decoder: SDVAEDecoder = None
25 | self.vae_encoder: SDVAEEncoder = None
26 | self.controlnet: MultiControlNetManager = None
27 | self.ipadapter_image_encoder: IpAdapterCLIPImageEmbedder = None
28 | self.ipadapter: SDIpAdapter = None
29 |
30 |
31 | def fetch_main_models(self, model_manager: ModelManager):
32 | self.text_encoder = model_manager.text_encoder
33 | self.unet = model_manager.unet
34 | self.vae_decoder = model_manager.vae_decoder
35 | self.vae_encoder = model_manager.vae_encoder
36 |
37 |
38 | def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
39 | controlnet_units = []
40 | for config in controlnet_config_units:
41 | controlnet_unit = ControlNetUnit(
42 | Annotator(config.processor_id),
43 | model_manager.get_model_with_model_path(config.model_path),
44 | config.scale
45 | )
46 | controlnet_units.append(controlnet_unit)
47 | self.controlnet = MultiControlNetManager(controlnet_units)
48 |
49 |
50 | def fetch_ipadapter(self, model_manager: ModelManager):
51 | if "ipadapter" in model_manager.model:
52 | self.ipadapter = model_manager.ipadapter
53 | if "ipadapter_image_encoder" in model_manager.model:
54 | self.ipadapter_image_encoder = model_manager.ipadapter_image_encoder
55 |
56 |
57 | def fetch_prompter(self, model_manager: ModelManager):
58 | self.prompter.load_from_model_manager(model_manager)
59 |
60 |
61 | @staticmethod
62 | def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
63 | pipe = SDImagePipeline(
64 | device=model_manager.device,
65 | torch_dtype=model_manager.torch_dtype,
66 | )
67 | pipe.fetch_main_models(model_manager)
68 | pipe.fetch_prompter(model_manager)
69 | pipe.fetch_controlnet_models(model_manager, controlnet_config_units)
70 | pipe.fetch_ipadapter(model_manager)
71 | return pipe
72 |
73 |
74 | def preprocess_image(self, image):
75 | image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
76 | return image
77 |
78 |
79 | def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
80 | image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
81 | image = image.cpu().permute(1, 2, 0).numpy()
82 | image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
83 | return image
84 |
85 |
86 | @torch.no_grad()
87 | def __call__(
88 | self,
89 | prompt,
90 | negative_prompt="",
91 | cfg_scale=7.5,
92 | clip_skip=1,
93 | input_image=None,
94 | ipadapter_images=None,
95 | ipadapter_scale=1.0,
96 | controlnet_image=None,
97 | denoising_strength=1.0,
98 | height=512,
99 | width=512,
100 | num_inference_steps=20,
101 | tiled=False,
102 | tile_size=64,
103 | tile_stride=32,
104 | progress_bar_cmd=tqdm,
105 | progress_bar_st=None,
106 | ):
107 | # Prepare scheduler
108 | self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
109 |
110 | # Prepare latent tensors
111 | if input_image is not None:
112 | image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
113 | latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
114 | noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
115 | latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
116 | else:
117 | latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
118 |
119 | # Encode prompts
120 | prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device, positive=True)
121 | prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device, positive=False)
122 |
123 | # IP-Adapter
124 | if ipadapter_images is not None:
125 | ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
126 | ipadapter_kwargs_list_posi = self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)
127 | ipadapter_kwargs_list_nega = self.ipadapter(torch.zeros_like(ipadapter_image_encoding))
128 | else:
129 | ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {}, {}
130 |
131 | # Prepare ControlNets
132 | if controlnet_image is not None:
133 | controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype)
134 | controlnet_image = controlnet_image.unsqueeze(1)
135 |
136 | # Denoise
137 | for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
138 | timestep = torch.IntTensor((timestep,))[0].to(self.device)
139 |
140 | # Classifier-free guidance
141 | noise_pred_posi = lets_dance(
142 | self.unet, motion_modules=None, controlnet=self.controlnet,
143 | sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_image,
144 | tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
145 | ipadapter_kwargs_list=ipadapter_kwargs_list_posi,
146 | device=self.device, vram_limit_level=0
147 | )
148 | noise_pred_nega = lets_dance(
149 | self.unet, motion_modules=None, controlnet=self.controlnet,
150 | sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_image,
151 | tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
152 | ipadapter_kwargs_list=ipadapter_kwargs_list_nega,
153 | device=self.device, vram_limit_level=0
154 | )
155 | noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
156 |
157 | # DDIM
158 | latents = self.scheduler.step(noise_pred, timestep, latents)
159 |
160 | # UI
161 | if progress_bar_st is not None:
162 | progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
163 |
164 | # Decode image
165 | image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
166 |
167 | return image
168 |
--------------------------------------------------------------------------------
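
A minimal text-to-image sketch using SDImagePipeline. The ModelManager loading call and the checkpoint path are assumptions (ModelManager lives in diffsynth/models, outside this excerpt); the pipeline call itself follows the __call__ signature above and returns a PIL image.

import torch
from diffsynth.models import ModelManager
from diffsynth.pipelines import SDImagePipeline

# Assumption: ModelManager accepts a dtype/device and a list of checkpoint paths this way.
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
model_manager.load_models(["models/stable_diffusion/v1-5-pruned-emaonly.safetensors"])  # hypothetical path

pipe = SDImagePipeline.from_model_manager(model_manager)
image = pipe(
    prompt="a watercolor painting of a lighthouse at dusk",
    negative_prompt="lowres, blurry",
    cfg_scale=7.5,
    height=512, width=512,
    num_inference_steps=20,
)
image.save("lighthouse.png")
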
/diffsynth/pipelines/stable_diffusion_xl.py:
--------------------------------------------------------------------------------
1 | from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
2 | # TODO: SDXL ControlNet
3 | from ..prompts import SDXLPrompter
4 | from ..schedulers import EnhancedDDIMScheduler
5 | from .dancer import lets_dance_xl
6 | import torch
7 | from tqdm import tqdm
8 | from PIL import Image
9 | import numpy as np
10 |
11 |
12 | class SDXLImagePipeline(torch.nn.Module):
13 |
14 | def __init__(self, device="cuda", torch_dtype=torch.float16):
15 | super().__init__()
16 | self.scheduler = EnhancedDDIMScheduler()
17 | self.prompter = SDXLPrompter()
18 | self.device = device
19 | self.torch_dtype = torch_dtype
20 | # models
21 | self.text_encoder: SDXLTextEncoder = None
22 | self.text_encoder_2: SDXLTextEncoder2 = None
23 | self.unet: SDXLUNet = None
24 | self.vae_decoder: SDXLVAEDecoder = None
25 | self.vae_encoder: SDXLVAEEncoder = None
26 | self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
27 | self.ipadapter: SDXLIpAdapter = None
28 | # TODO: SDXL ControlNet
29 |
30 | def fetch_main_models(self, model_manager: ModelManager):
31 | self.text_encoder = model_manager.text_encoder
32 | self.text_encoder_2 = model_manager.text_encoder_2
33 | self.unet = model_manager.unet
34 | self.vae_decoder = model_manager.vae_decoder
35 | self.vae_encoder = model_manager.vae_encoder
36 |
37 |
38 | def fetch_controlnet_models(self, model_manager: ModelManager, **kwargs):
39 | # TODO: SDXL ControlNet
40 | pass
41 |
42 |
43 | def fetch_ipadapter(self, model_manager: ModelManager):
44 | if "ipadapter_xl" in model_manager.model:
45 | self.ipadapter = model_manager.ipadapter_xl
46 | if "ipadapter_xl_image_encoder" in model_manager.model:
47 | self.ipadapter_image_encoder = model_manager.ipadapter_xl_image_encoder
48 |
49 |
50 | def fetch_prompter(self, model_manager: ModelManager):
51 | self.prompter.load_from_model_manager(model_manager)
52 |
53 |
54 | @staticmethod
55 | def from_model_manager(model_manager: ModelManager, controlnet_config_units = [], **kwargs):
56 | pipe = SDXLImagePipeline(
57 | device=model_manager.device,
58 | torch_dtype=model_manager.torch_dtype,
59 | )
60 | pipe.fetch_main_models(model_manager)
61 | pipe.fetch_prompter(model_manager)
62 | pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units)
63 | pipe.fetch_ipadapter(model_manager)
64 | return pipe
65 |
66 |
67 | def preprocess_image(self, image):
68 | image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
69 | return image
70 |
71 |
72 | def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
73 | image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
74 | image = image.cpu().permute(1, 2, 0).numpy()
75 | image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
76 | return image
77 |
78 |
79 | @torch.no_grad()
80 | def __call__(
81 | self,
82 | prompt,
83 | negative_prompt="",
84 | cfg_scale=7.5,
85 | clip_skip=1,
86 | clip_skip_2=2,
87 | input_image=None,
88 | ipadapter_images=None,
89 | ipadapter_scale=1.0,
90 | controlnet_image=None,
91 | denoising_strength=1.0,
92 | height=1024,
93 | width=1024,
94 | num_inference_steps=20,
95 | tiled=False,
96 | tile_size=64,
97 | tile_stride=32,
98 | progress_bar_cmd=tqdm,
99 | progress_bar_st=None,
100 | ):
101 | # Prepare scheduler
102 | self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
103 |
104 | # Prepare latent tensors
105 | if input_image is not None:
106 | image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
107 | latents = self.vae_encoder(image.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
108 | noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
109 | latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
110 | else:
111 | latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
112 |
113 | # Encode prompts
114 | add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt(
115 | self.text_encoder,
116 | self.text_encoder_2,
117 | prompt,
118 | clip_skip=clip_skip, clip_skip_2=clip_skip_2,
119 | device=self.device,
120 | positive=True,
121 | )
122 | if cfg_scale != 1.0:
123 | add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt(
124 | self.text_encoder,
125 | self.text_encoder_2,
126 | negative_prompt,
127 | clip_skip=clip_skip, clip_skip_2=clip_skip_2,
128 | device=self.device,
129 | positive=False,
130 | )
131 |
132 | # Prepare positional id
133 | add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)
134 |
135 | # IP-Adapter
136 | if ipadapter_images is not None:
137 | ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
138 | ipadapter_kwargs_list_posi = self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)
139 | ipadapter_kwargs_list_nega = self.ipadapter(torch.zeros_like(ipadapter_image_encoding))
140 | else:
141 | ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {}, {}
142 |
143 | # Denoise
144 | for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
145 | timestep = torch.IntTensor((timestep,))[0].to(self.device)
146 |
147 | # Classifier-free guidance
148 | noise_pred_posi = lets_dance_xl(
149 | self.unet,
150 | sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi,
151 | add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
152 | tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
153 | ipadapter_kwargs_list=ipadapter_kwargs_list_posi,
154 | )
155 | if cfg_scale != 1.0:
156 | noise_pred_nega = lets_dance_xl(
157 | self.unet,
158 | sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega,
159 | add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega,
160 | tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
161 | ipadapter_kwargs_list=ipadapter_kwargs_list_nega,
162 | )
163 | noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
164 | else:
165 | noise_pred = noise_pred_posi
166 |
167 | latents = self.scheduler.step(noise_pred, timestep, latents)
168 |
169 | if progress_bar_st is not None:
170 | progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
171 |
172 | # Decode image
173 | image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
174 |
175 | return image
176 |
--------------------------------------------------------------------------------
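
An image-to-image sketch for SDXLImagePipeline: passing input_image with denoising_strength < 1.0 encodes the image into latents, adds noise at the first timestep, and denoises from there. Model loading follows the same assumed ModelManager pattern as the SD example; the checkpoint path and file names are hypothetical.

import torch
from PIL import Image
from diffsynth.models import ModelManager
from diffsynth.pipelines import SDXLImagePipeline

model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
model_manager.load_models(["models/stable_diffusion_xl/sd_xl_base_1.0.safetensors"])  # hypothetical path

pipe = SDXLImagePipeline.from_model_manager(model_manager)
image = pipe(
    prompt="the same scene at golden hour, highly detailed",
    input_image=Image.open("lighthouse.png").resize((1024, 1024)),
    denoising_strength=0.6,   # < 1.0 keeps part of the input structure
    height=1024, width=1024,
    num_inference_steps=30,
)
image.save("lighthouse_xl.png")
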
/diffsynth/pipelines/stable_diffusion_xl_video.py:
--------------------------------------------------------------------------------
1 | from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLMotionModel
2 | from .dancer import lets_dance_xl
3 | # TODO: SDXL ControlNet
4 | from ..prompts import SDXLPrompter
5 | from ..schedulers import EnhancedDDIMScheduler
6 | import torch
7 | from tqdm import tqdm
8 | from PIL import Image
9 | import numpy as np
10 |
11 |
12 | class SDXLVideoPipeline(torch.nn.Module):
13 |
14 | def __init__(self, device="cuda", torch_dtype=torch.float16, use_animatediff=True):
15 | super().__init__()
16 | self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_animatediff else "scaled_linear")
17 | self.prompter = SDXLPrompter()
18 | self.device = device
19 | self.torch_dtype = torch_dtype
20 | # models
21 | self.text_encoder: SDXLTextEncoder = None
22 | self.text_encoder_2: SDXLTextEncoder2 = None
23 | self.unet: SDXLUNet = None
24 | self.vae_decoder: SDXLVAEDecoder = None
25 | self.vae_encoder: SDXLVAEEncoder = None
26 | # TODO: SDXL ControlNet
27 | self.motion_modules: SDXLMotionModel = None
28 |
29 |
30 | def fetch_main_models(self, model_manager: ModelManager):
31 | self.text_encoder = model_manager.text_encoder
32 | self.text_encoder_2 = model_manager.text_encoder_2
33 | self.unet = model_manager.unet
34 | self.vae_decoder = model_manager.vae_decoder
35 | self.vae_encoder = model_manager.vae_encoder
36 |
37 |
38 | def fetch_controlnet_models(self, model_manager: ModelManager, **kwargs):
39 | # TODO: SDXL ControlNet
40 | pass
41 |
42 |
43 | def fetch_motion_modules(self, model_manager: ModelManager):
44 | if "motion_modules_xl" in model_manager.model:
45 | self.motion_modules = model_manager.motion_modules_xl
46 |
47 |
48 | def fetch_prompter(self, model_manager: ModelManager):
49 | self.prompter.load_from_model_manager(model_manager)
50 |
51 |
52 | @staticmethod
53 | def from_model_manager(model_manager: ModelManager, controlnet_config_units = [], **kwargs):
54 | pipe = SDXLVideoPipeline(
55 | device=model_manager.device,
56 | torch_dtype=model_manager.torch_dtype,
57 | use_animatediff="motion_modules_xl" in model_manager.model
58 | )
59 | pipe.fetch_main_models(model_manager)
60 | pipe.fetch_motion_modules(model_manager)
61 | pipe.fetch_prompter(model_manager)
62 | pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units)
63 | return pipe
64 |
65 |
66 | def preprocess_image(self, image):
67 | image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
68 | return image
69 |
70 |
71 | def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
72 | image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
73 | image = image.cpu().permute(1, 2, 0).numpy()
74 | image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
75 | return image
76 |
77 |
78 | def decode_images(self, latents, tiled=False, tile_size=64, tile_stride=32):
79 | images = [
80 | self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
81 | for frame_id in range(latents.shape[0])
82 | ]
83 | return images
84 |
85 |
86 | def encode_images(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
87 | latents = []
88 | for image in processed_images:
89 | image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
90 | latent = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).cpu()
91 | latents.append(latent)
92 | latents = torch.concat(latents, dim=0)
93 | return latents
94 |
95 |
96 | @torch.no_grad()
97 | def __call__(
98 | self,
99 | prompt,
100 | negative_prompt="",
101 | cfg_scale=7.5,
102 | clip_skip=1,
103 | clip_skip_2=2,
104 | num_frames=None,
105 | input_frames=None,
106 | controlnet_frames=None,
107 | denoising_strength=1.0,
108 | height=512,
109 | width=512,
110 | num_inference_steps=20,
111 | animatediff_batch_size = 16,
112 | animatediff_stride = 8,
113 | unet_batch_size = 1,
114 | controlnet_batch_size = 1,
115 | cross_frame_attention = False,
116 | smoother=None,
117 | smoother_progress_ids=[],
118 | vram_limit_level=0,
119 | progress_bar_cmd=tqdm,
120 | progress_bar_st=None,
121 | ):
122 | # Prepare scheduler
123 | self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
124 |
125 | # Prepare latent tensors
126 | if self.motion_modules is None:
127 | noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
128 | else:
129 | noise = torch.randn((num_frames, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
130 | if input_frames is None or denoising_strength == 1.0:
131 | latents = noise
132 | else:
133 | latents = self.encode_images(input_frames)
134 | latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
135 |
136 | # Encode prompts
137 | add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt(
138 | self.text_encoder,
139 | self.text_encoder_2,
140 | prompt,
141 | clip_skip=clip_skip, clip_skip_2=clip_skip_2,
142 | device=self.device,
143 | positive=True,
144 | )
145 | if cfg_scale != 1.0:
146 | add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt(
147 | self.text_encoder,
148 | self.text_encoder_2,
149 | negative_prompt,
150 | clip_skip=clip_skip, clip_skip_2=clip_skip_2,
151 | device=self.device,
152 | positive=False,
153 | )
154 |
155 | # Prepare positional id
156 | add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)
157 |
158 | # Denoise
159 | for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
160 | timestep = torch.IntTensor((timestep,))[0].to(self.device)
161 |
162 | # Classifier-free guidance
163 | noise_pred_posi = lets_dance_xl(
164 | self.unet, motion_modules=self.motion_modules, controlnet=None,
165 | sample=latents, add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
166 | timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_frames,
167 | cross_frame_attention=cross_frame_attention,
168 | device=self.device, vram_limit_level=vram_limit_level
169 | )
170 | if cfg_scale != 1.0:
171 | noise_pred_nega = lets_dance_xl(
172 | self.unet, motion_modules=self.motion_modules, controlnet=None,
173 | sample=latents, add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega,
174 | timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames,
175 | cross_frame_attention=cross_frame_attention,
176 | device=self.device, vram_limit_level=vram_limit_level
177 | )
178 | noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
179 | else:
180 | noise_pred = noise_pred_posi
181 |
182 | latents = self.scheduler.step(noise_pred, timestep, latents)
183 |
184 | if progress_bar_st is not None:
185 | progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
186 |
187 | # Decode image
188 | image = self.decode_images(latents.to(torch.float32))
189 |
190 | return image
191 |
--------------------------------------------------------------------------------
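
A text-to-video sketch for SDXLVideoPipeline. When an AnimateDiff-XL motion module is loaded, from_model_manager switches the scheduler to the linear beta schedule and the motion modules run inside lets_dance_xl; the call returns a list of PIL frames. The checkpoint paths and the loading call are hypothetical, as above.

import torch
from diffsynth.models import ModelManager
from diffsynth.pipelines import SDXLVideoPipeline

model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
model_manager.load_models([
    "models/stable_diffusion_xl/sd_xl_base_1.0.safetensors",  # hypothetical paths
    "models/AnimateDiff/mm_sdxl_v10_beta.ckpt",
])

pipe = SDXLVideoPipeline.from_model_manager(model_manager)
frames = pipe(
    prompt="a paper boat drifting down a rainy street, anime style",
    num_frames=16,
    height=1024, width=1024,
    num_inference_steps=25,
)
for i, frame in enumerate(frames):
    frame.save(f"frame_{i:04d}.png")
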
/diffsynth/processors/FastBlend.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import cupy as cp
3 | import numpy as np
4 | from tqdm import tqdm
5 | from ..extensions.FastBlend.patch_match import PyramidPatchMatcher
6 | from ..extensions.FastBlend.runners.fast import TableManager
7 | from .base import VideoProcessor
8 |
9 |
10 | class FastBlendSmoother(VideoProcessor):
11 | def __init__(
12 | self,
13 | inference_mode="fast", batch_size=8, window_size=60,
14 | minimum_patch_size=5, threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0, initialize="identity", tracking_window_size=0
15 | ):
16 | self.inference_mode = inference_mode
17 | self.batch_size = batch_size
18 | self.window_size = window_size
19 | self.ebsynth_config = {
20 | "minimum_patch_size": minimum_patch_size,
21 | "threads_per_block": threads_per_block,
22 | "num_iter": num_iter,
23 | "gpu_id": gpu_id,
24 | "guide_weight": guide_weight,
25 | "initialize": initialize,
26 | "tracking_window_size": tracking_window_size
27 | }
28 |
29 | @staticmethod
30 | def from_model_manager(model_manager, **kwargs):
31 | # TODO: fetch GPU ID from model_manager
32 | return FastBlendSmoother(**kwargs)
33 |
34 | def inference_fast(self, frames_guide, frames_style):
35 | table_manager = TableManager()
36 | patch_match_engine = PyramidPatchMatcher(
37 | image_height=frames_style[0].shape[0],
38 | image_width=frames_style[0].shape[1],
39 | channel=3,
40 | **self.ebsynth_config
41 | )
42 | # left part
43 | table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, self.batch_size, desc="Fast Mode Step 1/4")
44 | table_l = table_manager.remapping_table_to_blending_table(table_l)
45 | table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, self.window_size, self.batch_size, desc="Fast Mode Step 2/4")
46 | # right part
47 | table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, self.batch_size, desc="Fast Mode Step 3/4")
48 | table_r = table_manager.remapping_table_to_blending_table(table_r)
49 | table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, self.window_size, self.batch_size, desc="Fast Mode Step 4/4")[::-1]
50 | # merge
51 | frames = []
52 | for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
53 | weight_m = -1  # the target frame is counted in both window sums, so subtract one copy of it
54 | weight = weight_l + weight_m + weight_r
55 | frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
56 | frames.append(frame)
57 | frames = [frame.clip(0, 255).astype("uint8") for frame in frames]
58 | frames = [Image.fromarray(frame) for frame in frames]
59 | return frames
60 |
61 | def inference_balanced(self, frames_guide, frames_style):
62 | patch_match_engine = PyramidPatchMatcher(
63 | image_height=frames_style[0].shape[0],
64 | image_width=frames_style[0].shape[1],
65 | channel=3,
66 | **self.ebsynth_config
67 | )
68 | output_frames = []
69 | # tasks
70 | n = len(frames_style)
71 | tasks = []
72 | for target in range(n):
73 | for source in range(target - self.window_size, target + self.window_size + 1):
74 | if source >= 0 and source < n and source != target:
75 | tasks.append((source, target))
76 | # run
77 | frames = [(None, 1) for i in range(n)]
78 | for batch_id in tqdm(range(0, len(tasks), self.batch_size), desc="Balanced Mode"):
79 | tasks_batch = tasks[batch_id: min(batch_id+self.batch_size, len(tasks))]
80 | source_guide = np.stack([frames_guide[source] for source, target in tasks_batch])
81 | target_guide = np.stack([frames_guide[target] for source, target in tasks_batch])
82 | source_style = np.stack([frames_style[source] for source, target in tasks_batch])
83 | _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
84 | for (source, target), result in zip(tasks_batch, target_style):
85 | frame, weight = frames[target]
86 | if frame is None:
87 | frame = frames_style[target]
88 | frames[target] = (
89 | frame * (weight / (weight + 1)) + result / (weight + 1),
90 | weight + 1
91 | )
92 | if weight + 1 == min(n, target + self.window_size + 1) - max(0, target - self.window_size):
93 | frame = frame.clip(0, 255).astype("uint8")
94 | output_frames.append(Image.fromarray(frame))
95 | frames[target] = (None, 1)
96 | return output_frames
97 |
98 | def inference_accurate(self, frames_guide, frames_style):
99 | patch_match_engine = PyramidPatchMatcher(
100 | image_height=frames_style[0].shape[0],
101 | image_width=frames_style[0].shape[1],
102 | channel=3,
103 | use_mean_target_style=True,
104 | **self.ebsynth_config
105 | )
106 | output_frames = []
107 | # run
108 | n = len(frames_style)
109 | for target in tqdm(range(n), desc="Accurate Mode"):
110 | l, r = max(target - self.window_size, 0), min(target + self.window_size + 1, n)
111 | remapped_frames = []
112 | for i in range(l, r, self.batch_size):
113 | j = min(i + self.batch_size, r)
114 | source_guide = np.stack([frames_guide[source] for source in range(i, j)])
115 | target_guide = np.stack([frames_guide[target]] * (j - i))
116 | source_style = np.stack([frames_style[source] for source in range(i, j)])
117 | _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
118 | remapped_frames.append(target_style)
119 | frame = np.concatenate(remapped_frames, axis=0).mean(axis=0)
120 | frame = frame.clip(0, 255).astype("uint8")
121 | output_frames.append(Image.fromarray(frame))
122 | return output_frames
123 |
124 | def release_vram(self):
125 | mempool = cp.get_default_memory_pool()
126 | pinned_mempool = cp.get_default_pinned_memory_pool()
127 | mempool.free_all_blocks()
128 | pinned_mempool.free_all_blocks()
129 |
130 | def __call__(self, rendered_frames, original_frames=None, **kwargs):
131 | rendered_frames = [np.array(frame) for frame in rendered_frames]
132 | original_frames = [np.array(frame) for frame in original_frames]
133 | if self.inference_mode == "fast":
134 | output_frames = self.inference_fast(original_frames, rendered_frames)
135 | elif self.inference_mode == "balanced":
136 | output_frames = self.inference_balanced(original_frames, rendered_frames)
137 | elif self.inference_mode == "accurate":
138 | output_frames = self.inference_accurate(original_frames, rendered_frames)
139 | else:
140 | raise ValueError("inference_mode must be fast, balanced or accurate")
141 | self.release_vram()
142 | return output_frames
143 |
--------------------------------------------------------------------------------
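
FastBlendSmoother deflickers a stylized sequence by patch-matching each rendered frame against its neighbours within window_size, using the original frames as the guide. A usage sketch with hypothetical frame paths; a CUDA-capable GPU is required because the patch matcher runs on cupy.

from PIL import Image
from diffsynth.processors.FastBlend import FastBlendSmoother

# Hypothetical inputs: guide frames (the source video) and the stylized frames to smooth,
# as equally sized PIL images.
original_frames = [Image.open(f"guide/{i:04d}.png") for i in range(16)]
rendered_frames = [Image.open(f"styled/{i:04d}.png") for i in range(16)]

smoother = FastBlendSmoother(inference_mode="fast", batch_size=8, window_size=15)
blended_frames = smoother(rendered_frames, original_frames=original_frames)
for i, frame in enumerate(blended_frames):
    frame.save(f"blended/{i:04d}.png")
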
/diffsynth/processors/PILEditor.py:
--------------------------------------------------------------------------------
1 | from PIL import ImageEnhance
2 | from .base import VideoProcessor
3 |
4 |
5 | class ContrastEditor(VideoProcessor):
6 | def __init__(self, rate=1.5):
7 | self.rate = rate
8 |
9 | @staticmethod
10 | def from_model_manager(model_manager, **kwargs):
11 | return ContrastEditor(**kwargs)
12 |
13 | def __call__(self, rendered_frames, **kwargs):
14 | rendered_frames = [ImageEnhance.Contrast(i).enhance(self.rate) for i in rendered_frames]
15 | return rendered_frames
16 |
17 |
18 | class SharpnessEditor(VideoProcessor):
19 | def __init__(self, rate=1.5):
20 | self.rate = rate
21 |
22 | @staticmethod
23 | def from_model_manager(model_manager, **kwargs):
24 | return SharpnessEditor(**kwargs)
25 |
26 | def __call__(self, rendered_frames, **kwargs):
27 | rendered_frames = [ImageEnhance.Sharpness(i).enhance(self.rate) for i in rendered_frames]
28 | return rendered_frames
29 |
--------------------------------------------------------------------------------
/diffsynth/processors/RIFE.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from PIL import Image
4 | from .base import VideoProcessor
5 |
6 |
7 | class RIFESmoother(VideoProcessor):
8 | def __init__(self, model, device="cuda", scale=1.0, batch_size=4, interpolate=True):
9 | self.model = model
10 | self.device = device
11 |
12 | # IFNet does not support float16, so it is run in float32
13 | self.torch_dtype = torch.float32
14 |
15 | # Other parameters
16 | self.scale = scale
17 | self.batch_size = batch_size
18 | self.interpolate = interpolate
19 |
20 | @staticmethod
21 | def from_model_manager(model_manager, **kwargs):
22 | return RIFESmoother(model_manager.RIFE, device=model_manager.device, **kwargs)
23 |
24 | def process_image(self, image):
25 | width, height = image.size
26 | if width % 32 != 0 or height % 32 != 0:
27 | width = (width + 31) // 32 * 32  # round up to a multiple of 32, as required by IFNet
28 | height = (height + 31) // 32 * 32
29 | image = image.resize((width, height))
30 | image = torch.Tensor(np.array(image, dtype=np.float32)[:, :, [2,1,0]] / 255).permute(2, 0, 1)
31 | return image
32 |
33 | def process_images(self, images):
34 | images = [self.process_image(image) for image in images]
35 | images = torch.stack(images)
36 | return images
37 |
38 | def decode_images(self, images):
39 | images = (images[:, [2,1,0]].permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
40 | images = [Image.fromarray(image) for image in images]
41 | return images
42 |
43 | def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
44 | output_tensor = []
45 | for batch_id in range(0, input_tensor.shape[0], batch_size):
46 | batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
47 | batch_input_tensor = input_tensor[batch_id: batch_id_]
48 | batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
49 | flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
50 | output_tensor.append(merged[2].cpu())
51 | output_tensor = torch.concat(output_tensor, dim=0)
52 | return output_tensor
53 |
54 | @torch.no_grad()
55 | def __call__(self, rendered_frames, **kwargs):
56 | # Preprocess
57 | processed_images = self.process_images(rendered_frames)
58 |
59 | # Input
60 | input_tensor = torch.cat((processed_images[:-2], processed_images[2:]), dim=1)
61 |
62 | # Interpolate
63 | output_tensor = self.process_tensors(input_tensor, scale=self.scale, batch_size=self.batch_size)
64 |
65 | if self.interpolate:
66 | # Blend
67 | input_tensor = torch.cat((processed_images[1:-1], output_tensor), dim=1)
68 | output_tensor = self.process_tensors(input_tensor, scale=self.scale, batch_size=self.batch_size)
69 | processed_images[1:-1] = output_tensor
70 | else:
71 | processed_images[1:-1] = (processed_images[1:-1] + output_tensor) / 2
72 |
73 | # To images
74 | output_images = self.decode_images(processed_images)
75 | if output_images[0].size != rendered_frames[0].size:
76 | output_images = [image.resize(rendered_frames[0].size) for image in output_images]
77 | return output_images
78 |
--------------------------------------------------------------------------------
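
RIFESmoother uses the RIFE flow network to re-synthesize every interior frame from its two neighbours, which suppresses frame-to-frame flicker; the first and last frames pass through unchanged. The loading call and model path below are assumptions (the RIFE weights must end up on model_manager.RIFE), and the frame paths are hypothetical.

from PIL import Image
from diffsynth.models import ModelManager
from diffsynth.processors.RIFE import RIFESmoother

model_manager = ModelManager(device="cuda")
model_manager.load_models(["models/RIFE/flownet.pkl"])  # hypothetical path; must populate model_manager.RIFE

smoother = RIFESmoother.from_model_manager(model_manager, scale=1.0, batch_size=4)
frames = [Image.open(f"blended/{i:04d}.png") for i in range(16)]
smoothed_frames = smoother(frames)
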
/diffsynth/processors/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-DiffSynth-Studio/91d0a6bd9efa6a1d4a0b880984e21fe66acf4eb0/diffsynth/processors/__init__.py
--------------------------------------------------------------------------------
/diffsynth/processors/base.py:
--------------------------------------------------------------------------------
1 | class VideoProcessor:
2 | def __init__(self):
3 | pass
4 |
5 | def __call__(self):
6 | raise NotImplementedError
7 |
--------------------------------------------------------------------------------
/diffsynth/processors/sequencial_processor.py:
--------------------------------------------------------------------------------
1 | from .base import VideoProcessor
2 |
3 |
4 | class AutoVideoProcessor(VideoProcessor):
5 | def __init__(self):
6 | pass
7 |
8 | @staticmethod
9 | def from_model_manager(model_manager, processor_type, **kwargs):
10 | if processor_type == "FastBlend":
11 | from .FastBlend import FastBlendSmoother
12 | return FastBlendSmoother.from_model_manager(model_manager, **kwargs)
13 | elif processor_type == "Contrast":
14 | from .PILEditor import ContrastEditor
15 | return ContrastEditor.from_model_manager(model_manager, **kwargs)
16 | elif processor_type == "Sharpness":
17 | from .PILEditor import SharpnessEditor
18 | return SharpnessEditor.from_model_manager(model_manager, **kwargs)
19 | elif processor_type == "RIFE":
20 | from .RIFE import RIFESmoother
21 | return RIFESmoother.from_model_manager(model_manager, **kwargs)
22 | else:
23 | raise ValueError(f"invalid processor_type: {processor_type}")
24 |
25 |
26 | class SequencialProcessor(VideoProcessor):
27 | def __init__(self, processors=[]):
28 | self.processors = processors
29 |
30 | @staticmethod
31 | def from_model_manager(model_manager, configs):
32 | processors = [
33 | AutoVideoProcessor.from_model_manager(model_manager, config["processor_type"], **config["config"])
34 | for config in configs
35 | ]
36 | return SequencialProcessor(processors)
37 |
38 | def __call__(self, rendered_frames, **kwargs):
39 | for processor in self.processors:
40 | rendered_frames = processor(rendered_frames, **kwargs)
41 | return rendered_frames
42 |
--------------------------------------------------------------------------------
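
SequencialProcessor chains several post-processors; AutoVideoProcessor maps each processor_type string to the matching class and forwards the config dict to its constructor. A configuration sketch, assuming model_manager, rendered_frames and original_frames already exist as in the examples above:

from diffsynth.processors.sequencial_processor import SequencialProcessor

configs = [
    {"processor_type": "FastBlend", "config": {"inference_mode": "fast", "window_size": 15}},
    {"processor_type": "Contrast",  "config": {"rate": 1.2}},
    {"processor_type": "Sharpness", "config": {"rate": 1.3}},
]
post_processor = SequencialProcessor.from_model_manager(model_manager, configs)
output_frames = post_processor(rendered_frames, original_frames=original_frames)
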
/diffsynth/prompts/__init__.py:
--------------------------------------------------------------------------------
1 | from .sd_prompter import SDPrompter
2 | from .sdxl_prompter import SDXLPrompter
3 | from .hunyuan_dit_prompter import HunyuanDiTPrompter
4 |
--------------------------------------------------------------------------------
/diffsynth/prompts/hunyuan_dit_prompter.py:
--------------------------------------------------------------------------------
1 | from .utils import Prompter
2 | from transformers import BertModel, T5EncoderModel, BertTokenizer, AutoTokenizer
3 | import warnings, os
4 |
5 |
6 | class HunyuanDiTPrompter(Prompter):
7 | def __init__(
8 | self,
9 | tokenizer_path=None,
10 | tokenizer_t5_path=None
11 | ):
12 | if tokenizer_path is None:
13 | base_path = os.path.dirname(os.path.dirname(__file__))
14 | tokenizer_path = os.path.join(base_path, "tokenizer_configs/hunyuan_dit/tokenizer")
15 | if tokenizer_t5_path is None:
16 | base_path = os.path.dirname(os.path.dirname(__file__))
17 | tokenizer_t5_path = os.path.join(base_path, "tokenizer_configs/hunyuan_dit/tokenizer_t5")
18 | super().__init__()
19 | self.tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
20 | with warnings.catch_warnings():
21 | warnings.simplefilter("ignore")
22 | self.tokenizer_t5 = AutoTokenizer.from_pretrained(tokenizer_t5_path)
23 |
24 |
25 | def encode_prompt_using_signle_model(self, prompt, text_encoder, tokenizer, max_length, clip_skip, device):
26 | text_inputs = tokenizer(
27 | prompt,
28 | padding="max_length",
29 | max_length=max_length,
30 | truncation=True,
31 | return_attention_mask=True,
32 | return_tensors="pt",
33 | )
34 | text_input_ids = text_inputs.input_ids
35 | attention_mask = text_inputs.attention_mask.to(device)
36 | prompt_embeds = text_encoder(
37 | text_input_ids.to(device),
38 | attention_mask=attention_mask,
39 | clip_skip=clip_skip
40 | )
41 | return prompt_embeds, attention_mask
42 |
43 |
44 | def encode_prompt(
45 | self,
46 | text_encoder: BertModel,
47 | text_encoder_t5: T5EncoderModel,
48 | prompt,
49 | clip_skip=1,
50 | clip_skip_2=1,
51 | positive=True,
52 | device="cuda"
53 | ):
54 | prompt = self.process_prompt(prompt, positive=positive)
55 |
56 | # CLIP
57 | prompt_emb, attention_mask = self.encode_prompt_using_signle_model(prompt, text_encoder, self.tokenizer, self.tokenizer.model_max_length, clip_skip, device)
58 |
59 | # T5
60 | prompt_emb_t5, attention_mask_t5 = self.encode_prompt_using_signle_model(prompt, text_encoder_t5, self.tokenizer_t5, self.tokenizer_t5.model_max_length, clip_skip_2, device)
61 |
62 | return prompt_emb, attention_mask, prompt_emb_t5, attention_mask_t5
63 |
--------------------------------------------------------------------------------
/diffsynth/prompts/sd_prompter.py:
--------------------------------------------------------------------------------
1 | from .utils import Prompter, tokenize_long_prompt
2 | from transformers import CLIPTokenizer
3 | from ..models import SDTextEncoder
4 | import os
5 |
6 |
7 | class SDPrompter(Prompter):
8 | def __init__(self, tokenizer_path=None):
9 | if tokenizer_path is None:
10 | base_path = os.path.dirname(os.path.dirname(__file__))
11 | tokenizer_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion/tokenizer")
12 | super().__init__()
13 | self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
14 |
15 | def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True):
16 | prompt = self.process_prompt(prompt, positive=positive)
17 | input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
18 | prompt_emb = text_encoder(input_ids, clip_skip=clip_skip)
19 | prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
20 |
21 | return prompt_emb
--------------------------------------------------------------------------------
/diffsynth/prompts/sdxl_prompter.py:
--------------------------------------------------------------------------------
1 | from .utils import Prompter, tokenize_long_prompt
2 | from transformers import CLIPTokenizer
3 | from ..models import SDXLTextEncoder, SDXLTextEncoder2
4 | import torch, os
5 |
6 |
7 | class SDXLPrompter(Prompter):
8 | def __init__(
9 | self,
10 | tokenizer_path=None,
11 | tokenizer_2_path=None
12 | ):
13 | if tokenizer_path is None:
14 | base_path = os.path.dirname(os.path.dirname(__file__))
15 | tokenizer_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion/tokenizer")
16 | if tokenizer_2_path is None:
17 | base_path = os.path.dirname(os.path.dirname(__file__))
18 | tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_xl/tokenizer_2")
19 | super().__init__()
20 | self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
21 | self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
22 |
23 | def encode_prompt(
24 | self,
25 | text_encoder: SDXLTextEncoder,
26 | text_encoder_2: SDXLTextEncoder2,
27 | prompt,
28 | clip_skip=1,
29 | clip_skip_2=2,
30 | positive=True,
31 | device="cuda"
32 | ):
33 | prompt = self.process_prompt(prompt, positive=positive)
34 |
35 | # 1
36 | input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
37 | prompt_emb_1 = text_encoder(input_ids, clip_skip=clip_skip)
38 |
39 | # 2
40 | input_ids_2 = tokenize_long_prompt(self.tokenizer_2, prompt).to(device)
41 | add_text_embeds, prompt_emb_2 = text_encoder_2(input_ids_2, clip_skip=clip_skip_2)
42 |
43 | # Merge
44 | prompt_emb = torch.concatenate([prompt_emb_1, prompt_emb_2], dim=-1)
45 |
46 | # For a very long prompt, we only use the first 77 tokens to compute `add_text_embeds`.
47 | add_text_embeds = add_text_embeds[0:1]
48 | prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
49 | return add_text_embeds, prompt_emb
50 |
--------------------------------------------------------------------------------
/diffsynth/prompts/utils.py:
--------------------------------------------------------------------------------
1 | from transformers import CLIPTokenizer, AutoTokenizer
2 | from ..models import ModelManager
3 | import os
4 |
5 |
6 | def tokenize_long_prompt(tokenizer, prompt):
7 | # Get model_max_length from the tokenizer
8 | length = tokenizer.model_max_length
9 |
10 | # To avoid the length warning, temporarily set tokenizer.model_max_length to a very large value.
11 | tokenizer.model_max_length = 99999999
12 |
13 | # Tokenize it!
14 | input_ids = tokenizer(prompt, return_tensors="pt").input_ids
15 |
16 | # Determine the real length.
17 | max_length = (input_ids.shape[1] + length - 1) // length * length
18 |
19 | # Restore tokenizer.model_max_length
20 | tokenizer.model_max_length = length
21 |
22 | # Tokenize it again with fixed length.
23 | input_ids = tokenizer(
24 | prompt,
25 | return_tensors="pt",
26 | padding="max_length",
27 | max_length=max_length,
28 | truncation=True
29 | ).input_ids
30 |
31 | # Reshape input_ids to fit the text encoder.
32 | num_sentence = input_ids.shape[1] // length
33 | input_ids = input_ids.reshape((num_sentence, length))
34 |
35 | return input_ids
36 |
37 |
38 | class BeautifulPrompt:
39 | def __init__(self, tokenizer_path=None, model=None):
40 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
41 | self.model = model
42 | self.template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:'
43 |
44 | def __call__(self, raw_prompt):
45 | model_input = self.template.format(raw_prompt=raw_prompt)
46 | input_ids = self.tokenizer.encode(model_input, return_tensors='pt').to(self.model.device)
47 | outputs = self.model.generate(
48 | input_ids,
49 | max_new_tokens=384,
50 | do_sample=True,
51 | temperature=0.9,
52 | top_k=50,
53 | top_p=0.95,
54 | repetition_penalty=1.1,
55 | num_return_sequences=1
56 | )
57 | prompt = raw_prompt + ", " + self.tokenizer.batch_decode(
58 | outputs[:, input_ids.size(1):],
59 | skip_special_tokens=True
60 | )[0].strip()
61 | return prompt
62 |
63 |
64 | class Translator:
65 | def __init__(self, tokenizer_path=None, model=None):
66 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
67 | self.model = model
68 |
69 | def __call__(self, prompt):
70 | input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.model.device)
71 | output_ids = self.model.generate(input_ids)
72 | prompt = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
73 | return prompt
74 |
75 |
76 | class Prompter:
77 | def __init__(self):
78 | self.tokenizer: CLIPTokenizer = None
79 | self.keyword_dict = {}
80 | self.translator: Translator = None
81 | self.beautiful_prompt: BeautifulPrompt = None
82 |
83 | def load_textual_inversion(self, textual_inversion_dict):
84 | self.keyword_dict = {}
85 | additional_tokens = []
86 | for keyword in textual_inversion_dict:
87 | tokens, _ = textual_inversion_dict[keyword]
88 | additional_tokens += tokens
89 | self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
90 | self.tokenizer.add_tokens(additional_tokens)
91 |
92 | def load_beautiful_prompt(self, model, model_path):
93 | model_folder = os.path.dirname(model_path)
94 | self.beautiful_prompt = BeautifulPrompt(tokenizer_path=model_folder, model=model)
95 | if model_folder.endswith("v2"):
96 | self.beautiful_prompt.template = """Converts a simple image description into a prompt. \
97 | Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
98 | or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
99 | but make sure there is a correlation between the input and output.\n\
100 | ### Input: {raw_prompt}\n### Output:"""
101 |
102 | def load_translator(self, model, model_path):
103 | model_folder = os.path.dirname(model_path)
104 | self.translator = Translator(tokenizer_path=model_folder, model=model)
105 |
106 | def load_from_model_manager(self, model_manager: ModelManager):
107 | self.load_textual_inversion(model_manager.textual_inversion_dict)
108 | if "translator" in model_manager.model:
109 | self.load_translator(model_manager.model["translator"], model_manager.model_path["translator"])
110 | if "beautiful_prompt" in model_manager.model:
111 | self.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
112 |
113 | def process_prompt(self, prompt, positive=True):
114 | for keyword in self.keyword_dict:
115 | if keyword in prompt:
116 | prompt = prompt.replace(keyword, self.keyword_dict[keyword])
117 | if positive and self.translator is not None:
118 | prompt = self.translator(prompt)
119 | print(f"Your prompt is translated: \"{prompt}\"")
120 | if positive and self.beautiful_prompt is not None:
121 | prompt = self.beautiful_prompt(prompt)
122 | print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
123 | return prompt
124 |
--------------------------------------------------------------------------------
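
tokenize_long_prompt sidesteps CLIP's 77-token limit by padding the prompt up to the next multiple of model_max_length and reshaping the ids into (num_sentence, 77); the prompters then concatenate the per-chunk embeddings back into a single sequence. A small illustration, assuming the working directory is the repository root so the bundled tokenizer config resolves:

from transformers import CLIPTokenizer
from diffsynth.prompts.utils import tokenize_long_prompt

tokenizer = CLIPTokenizer.from_pretrained("diffsynth/tokenizer_configs/stable_diffusion/tokenizer")

short_ids = tokenize_long_prompt(tokenizer, "a cat")
long_ids = tokenize_long_prompt(tokenizer, "a cat wearing a tiny wizard hat, " * 30)

print(short_ids.shape)  # torch.Size([1, 77]): one 77-token chunk
print(long_ids.shape)   # several 77-token chunks, e.g. torch.Size([4, 77])
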
/diffsynth/schedulers/__init__.py:
--------------------------------------------------------------------------------
1 | from .ddim import EnhancedDDIMScheduler
2 | from .continuous_ode import ContinuousODEScheduler
3 |
--------------------------------------------------------------------------------
/diffsynth/schedulers/continuous_ode.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class ContinuousODEScheduler():
5 |
6 | def __init__(self, num_inference_steps=100, sigma_max=700.0, sigma_min=0.002, rho=7.0):
7 | self.sigma_max = sigma_max
8 | self.sigma_min = sigma_min
9 | self.rho = rho
10 | self.set_timesteps(num_inference_steps)
11 |
12 |
13 | def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0):
14 | ramp = torch.linspace(1-denoising_strength, 1, num_inference_steps)
15 | min_inv_rho = torch.pow(torch.tensor((self.sigma_min,)), (1 / self.rho))
16 | max_inv_rho = torch.pow(torch.tensor((self.sigma_max,)), (1 / self.rho))
17 | self.sigmas = torch.pow(max_inv_rho + ramp * (min_inv_rho - max_inv_rho), self.rho)
18 | self.timesteps = torch.log(self.sigmas) * 0.25
19 |
20 |
21 | def step(self, model_output, timestep, sample, to_final=False):
22 | timestep_id = torch.argmin((self.timesteps - timestep).abs())
23 | sigma = self.sigmas[timestep_id]
24 | sample *= (sigma*sigma + 1).sqrt()
25 | estimated_sample = -sigma / (sigma*sigma + 1).sqrt() * model_output + 1 / (sigma*sigma + 1) * sample
26 | if to_final or timestep_id + 1 >= len(self.timesteps):
27 | prev_sample = estimated_sample
28 | else:
29 | sigma_ = self.sigmas[timestep_id + 1]
30 | derivative = 1 / sigma * (sample - estimated_sample)
31 | prev_sample = sample + derivative * (sigma_ - sigma)
32 | prev_sample /= (sigma_*sigma_ + 1).sqrt()
33 | return prev_sample
34 |
35 |
36 | def return_to_timestep(self, timestep, sample, sample_stablized):
37 | # This scheduler doesn't support this function.
38 | pass
39 |
40 |
41 | def add_noise(self, original_samples, noise, timestep):
42 | timestep_id = torch.argmin((self.timesteps - timestep).abs())
43 | sigma = self.sigmas[timestep_id]
44 | sample = (original_samples + noise * sigma) / (sigma*sigma + 1).sqrt()
45 | return sample
46 |
47 |
48 | def training_target(self, sample, noise, timestep):
49 | timestep_id = torch.argmin((self.timesteps - timestep).abs())
50 | sigma = self.sigmas[timestep_id]
51 | target = (-(sigma*sigma + 1).sqrt() / sigma + 1 / (sigma*sigma + 1).sqrt() / sigma) * sample + 1 / (sigma*sigma + 1).sqrt() * noise
52 | return target
53 |
54 |
55 | def training_weight(self, timestep):
56 | timestep_id = torch.argmin((self.timesteps - timestep).abs())
57 | sigma = self.sigmas[timestep_id]
58 | weight = (1 + sigma*sigma).sqrt() / sigma
59 | return weight
60 |
--------------------------------------------------------------------------------
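
The sigma grid built in ContinuousODEScheduler.set_timesteps is the Karras-style rho-schedule: with denoising_strength = 1 and N inference steps,

$$\sigma_i = \Big(\sigma_{\max}^{1/\rho} + \tfrac{i}{N-1}\big(\sigma_{\min}^{1/\rho} - \sigma_{\max}^{1/\rho}\big)\Big)^{\rho}, \qquad t_i = \tfrac{1}{4}\log\sigma_i, \qquad i = 0,\dots,N-1,$$

so sigma decreases monotonically from sigma_max to sigma_min; a denoising_strength below 1 starts the ramp at 1 - denoising_strength instead of 0, skipping the noisiest part of the schedule. step() then performs a first-order Euler update toward the estimated clean sample, working in the same scaled parameterization used by add_noise (samples are divided by sqrt(sigma^2 + 1)).
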
/diffsynth/schedulers/ddim.py:
--------------------------------------------------------------------------------
1 | import torch, math
2 |
3 |
4 | class EnhancedDDIMScheduler():
5 |
6 | def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", prediction_type="epsilon"):
7 | self.num_train_timesteps = num_train_timesteps
8 | if beta_schedule == "scaled_linear":
9 | betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32))
10 | elif beta_schedule == "linear":
11 | betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
12 | else:
13 | raise NotImplementedError(f"{beta_schedule} is not implemented")
14 | self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0).tolist()
15 | self.set_timesteps(10)
16 | self.prediction_type = prediction_type
17 |
18 |
19 | def set_timesteps(self, num_inference_steps, denoising_strength=1.0):
20 | # The timesteps are aligned to 999...0, which is different from other implementations,
21 | # but I think this implementation is more reasonable in theory.
22 | max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0)
23 | num_inference_steps = min(num_inference_steps, max_timestep + 1)
24 | if num_inference_steps == 1:
25 | self.timesteps = [max_timestep]
26 | else:
27 | step_length = max_timestep / (num_inference_steps - 1)
28 | self.timesteps = [round(max_timestep - i*step_length) for i in range(num_inference_steps)]
29 |
30 |
31 | def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev):
32 | if self.prediction_type == "epsilon":
33 | weight_e = math.sqrt(1 - alpha_prod_t_prev) - math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t)
34 | weight_x = math.sqrt(alpha_prod_t_prev / alpha_prod_t)
35 | prev_sample = sample * weight_x + model_output * weight_e
36 | elif self.prediction_type == "v_prediction":
37 | weight_e = -math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t)) + math.sqrt(alpha_prod_t * (1 - alpha_prod_t_prev))
38 | weight_x = math.sqrt(alpha_prod_t * alpha_prod_t_prev) + math.sqrt((1 - alpha_prod_t) * (1 - alpha_prod_t_prev))
39 | prev_sample = sample * weight_x + model_output * weight_e
40 | else:
41 | raise NotImplementedError(f"{self.prediction_type} is not implemented")
42 | return prev_sample
43 |
44 |
45 | def step(self, model_output, timestep, sample, to_final=False):
46 | alpha_prod_t = self.alphas_cumprod[timestep]
47 | timestep_id = self.timesteps.index(timestep)
48 | if to_final or timestep_id + 1 >= len(self.timesteps):
49 | alpha_prod_t_prev = 1.0
50 | else:
51 | timestep_prev = self.timesteps[timestep_id + 1]
52 | alpha_prod_t_prev = self.alphas_cumprod[timestep_prev]
53 |
54 | return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev)
55 |
56 |
57 | def return_to_timestep(self, timestep, sample, sample_stablized):
58 | alpha_prod_t = self.alphas_cumprod[timestep]
59 | noise_pred = (sample - math.sqrt(alpha_prod_t) * sample_stablized) / math.sqrt(1 - alpha_prod_t)
60 | return noise_pred
61 |
62 |
63 | def add_noise(self, original_samples, noise, timestep):
64 | sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[timestep])
65 | sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[timestep])
66 | noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
67 | return noisy_samples
68 |
69 | def training_target(self, sample, noise, timestep):
70 | sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[timestep])
71 | sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[timestep])
72 | target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
73 | return target
74 |
--------------------------------------------------------------------------------
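For orientation, a minimal sketch of how EnhancedDDIMScheduler is driven in a denoising loop. fake_unet is a stand-in epsilon predictor for illustration only, not a model from this repository:

import torch
from diffsynth.schedulers.ddim import EnhancedDDIMScheduler

def fake_unet(latents, timestep):
    # Placeholder for a real epsilon-prediction UNet.
    return torch.zeros_like(latents)

scheduler = EnhancedDDIMScheduler(prediction_type="epsilon")
scheduler.set_timesteps(num_inference_steps=20)

latents = torch.randn(1, 4, 64, 64)      # start from pure noise
for timestep in scheduler.timesteps:     # 999 ... 0, as noted in set_timesteps
    noise_pred = fake_unet(latents, timestep)
    latents = scheduler.step(noise_pred, timestep, latents)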
/diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/special_tokens_map.json:
--------------------------------------------------------------------------------
1 | {
2 | "cls_token": "[CLS]",
3 | "mask_token": "[MASK]",
4 | "pad_token": "[PAD]",
5 | "sep_token": "[SEP]",
6 | "unk_token": "[UNK]"
7 | }
8 |
--------------------------------------------------------------------------------
/diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/tokenizer_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "cls_token": "[CLS]",
3 | "do_basic_tokenize": true,
4 | "do_lower_case": true,
5 | "mask_token": "[MASK]",
6 | "name_or_path": "hfl/chinese-roberta-wwm-ext",
7 | "never_split": null,
8 | "pad_token": "[PAD]",
9 | "sep_token": "[SEP]",
10 | "special_tokens_map_file": "/home/chenweifeng/.cache/huggingface/hub/models--hfl--chinese-roberta-wwm-ext/snapshots/5c58d0b8ec1d9014354d691c538661bf00bfdb44/special_tokens_map.json",
11 | "strip_accents": null,
12 | "tokenize_chinese_chars": true,
13 | "tokenizer_class": "BertTokenizer",
14 | "unk_token": "[UNK]",
15 | "model_max_length": 77
16 | }
17 |
--------------------------------------------------------------------------------
/diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "/home/patrick/t5/mt5-xl",
3 | "architectures": [
4 | "MT5ForConditionalGeneration"
5 | ],
6 | "d_ff": 5120,
7 | "d_kv": 64,
8 | "d_model": 2048,
9 | "decoder_start_token_id": 0,
10 | "dropout_rate": 0.1,
11 | "eos_token_id": 1,
12 | "feed_forward_proj": "gated-gelu",
13 | "initializer_factor": 1.0,
14 | "is_encoder_decoder": true,
15 | "layer_norm_epsilon": 1e-06,
16 | "model_type": "mt5",
17 | "num_decoder_layers": 24,
18 | "num_heads": 32,
19 | "num_layers": 24,
20 | "output_past": true,
21 | "pad_token_id": 0,
22 | "relative_attention_num_buckets": 32,
23 | "tie_word_embeddings": false,
24 | "tokenizer_class": "T5Tokenizer",
25 | "transformers_version": "4.10.0.dev0",
26 | "use_cache": true,
27 | "vocab_size": 250112
28 | }
29 |
--------------------------------------------------------------------------------
/diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json:
--------------------------------------------------------------------------------
1 | {"eos_token": "", "unk_token": "", "pad_token": ""}
--------------------------------------------------------------------------------
/diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-DiffSynth-Studio/91d0a6bd9efa6a1d4a0b880984e21fe66acf4eb0/diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model
--------------------------------------------------------------------------------
/diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json:
--------------------------------------------------------------------------------
1 | {"eos_token": "", "unk_token": "", "pad_token": "", "extra_ids": 0, "additional_special_tokens": null, "special_tokens_map_file": "", "tokenizer_file": null, "name_or_path": "google/mt5-small", "model_max_length": 256, "legacy": true}
--------------------------------------------------------------------------------
/diffsynth/tokenizer_configs/stable_diffusion/tokenizer/special_tokens_map.json:
--------------------------------------------------------------------------------
1 | {
2 | "bos_token": {
3 | "content": "<|startoftext|>",
4 | "lstrip": false,
5 | "normalized": true,
6 | "rstrip": false,
7 | "single_word": false
8 | },
9 | "eos_token": {
10 | "content": "<|endoftext|>",
11 | "lstrip": false,
12 | "normalized": true,
13 | "rstrip": false,
14 | "single_word": false
15 | },
16 | "pad_token": "<|endoftext|>",
17 | "unk_token": {
18 | "content": "<|endoftext|>",
19 | "lstrip": false,
20 | "normalized": true,
21 | "rstrip": false,
22 | "single_word": false
23 | }
24 | }
25 |
--------------------------------------------------------------------------------
/diffsynth/tokenizer_configs/stable_diffusion/tokenizer/tokenizer_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "add_prefix_space": false,
3 | "bos_token": {
4 | "__type": "AddedToken",
5 | "content": "<|startoftext|>",
6 | "lstrip": false,
7 | "normalized": true,
8 | "rstrip": false,
9 | "single_word": false
10 | },
11 | "do_lower_case": true,
12 | "eos_token": {
13 | "__type": "AddedToken",
14 | "content": "<|endoftext|>",
15 | "lstrip": false,
16 | "normalized": true,
17 | "rstrip": false,
18 | "single_word": false
19 | },
20 | "errors": "replace",
21 | "model_max_length": 77,
22 | "name_or_path": "openai/clip-vit-large-patch14",
23 | "pad_token": "<|endoftext|>",
24 | "special_tokens_map_file": "./special_tokens_map.json",
25 | "tokenizer_class": "CLIPTokenizer",
26 | "unk_token": {
27 | "__type": "AddedToken",
28 | "content": "<|endoftext|>",
29 | "lstrip": false,
30 | "normalized": true,
31 | "rstrip": false,
32 | "single_word": false
33 | }
34 | }
35 |
--------------------------------------------------------------------------------
/diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json:
--------------------------------------------------------------------------------
1 | {
2 | "bos_token": {
3 | "content": "<|startoftext|>",
4 | "lstrip": false,
5 | "normalized": true,
6 | "rstrip": false,
7 | "single_word": false
8 | },
9 | "eos_token": {
10 | "content": "<|endoftext|>",
11 | "lstrip": false,
12 | "normalized": true,
13 | "rstrip": false,
14 | "single_word": false
15 | },
16 | "pad_token": "!",
17 | "unk_token": {
18 | "content": "<|endoftext|>",
19 | "lstrip": false,
20 | "normalized": true,
21 | "rstrip": false,
22 | "single_word": false
23 | }
24 | }
--------------------------------------------------------------------------------
/diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "add_prefix_space": false,
3 | "added_tokens_decoder": {
4 | "0": {
5 | "content": "!",
6 | "lstrip": false,
7 | "normalized": false,
8 | "rstrip": false,
9 | "single_word": false,
10 | "special": true
11 | },
12 | "49406": {
13 | "content": "<|startoftext|>",
14 | "lstrip": false,
15 | "normalized": true,
16 | "rstrip": false,
17 | "single_word": false,
18 | "special": true
19 | },
20 | "49407": {
21 | "content": "<|endoftext|>",
22 | "lstrip": false,
23 | "normalized": true,
24 | "rstrip": false,
25 | "single_word": false,
26 | "special": true
27 | }
28 | },
29 | "bos_token": "<|startoftext|>",
30 | "clean_up_tokenization_spaces": true,
31 | "do_lower_case": true,
32 | "eos_token": "<|endoftext|>",
33 | "errors": "replace",
34 | "model_max_length": 77,
35 | "pad_token": "!",
36 | "tokenizer_class": "CLIPTokenizer",
37 | "unk_token": "<|endoftext|>"
38 | }
--------------------------------------------------------------------------------
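The tokenizer_configs directories above bundle the upstream tokenizer files so the prompters can be built offline. Roughly, they can be loaded with transformers straight from the local paths; a sketch with paths relative to the repository root (not code taken from this repository):

from transformers import BertTokenizer, CLIPTokenizer, T5Tokenizer

# Paths are relative to the repository root; sentencepiece (see
# requirements.txt) is needed for the T5/mT5 tokenizer.
hunyuan_tokenizer = BertTokenizer.from_pretrained("diffsynth/tokenizer_configs/hunyuan_dit/tokenizer")
hunyuan_t5_tokenizer = T5Tokenizer.from_pretrained("diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5")
sd_tokenizer = CLIPTokenizer.from_pretrained("diffsynth/tokenizer_configs/stable_diffusion/tokenizer")
sdxl_tokenizer_2 = CLIPTokenizer.from_pretrained("diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2")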
/models/put diffsynth studio models here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-DiffSynth-Studio/91d0a6bd9efa6a1d4a0b880984e21fe66acf4eb0/models/put diffsynth studio models here
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | cupy-cuda12x
2 | pip
3 | transformers
4 | controlnet-aux==0.0.7
5 | streamlit
6 | streamlit-drawable-canvas
7 | imageio
8 | imageio[ffmpeg]
9 | safetensors
10 | einops
11 | sentencepiece
--------------------------------------------------------------------------------
/util_nodes.py:
--------------------------------------------------------------------------------
1 | import os
2 | import folder_paths
3 | now_dir = os.path.dirname(os.path.abspath(__file__))
4 | input_dir = folder_paths.get_input_directory()
5 | output_dir = folder_paths.get_output_directory()
6 |
7 | class LoadVideo:
8 | @classmethod
9 | def INPUT_TYPES(s):
10 | files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.split('.')[-1] in ["mp4", "webm","mkv","avi"]]
11 | return {"required":{
12 | "video":(files,),
13 | }}
14 |
15 | CATEGORY = "AIFSH_DiffSynth-Studio"
16 | DESCRIPTION = "Load a video file from the ComfyUI input directory"
17 |
18 | RETURN_TYPES = ("VIDEO",)
19 |
20 | OUTPUT_NODE = False
21 |
22 | FUNCTION = "load_video"
23 |
24 | def load_video(self, video):
25 | video_path = os.path.join(input_dir,video)
26 | return (video_path,)
27 |
28 | class PreViewVideo:
29 | @classmethod
30 | def INPUT_TYPES(s):
31 | return {"required":{
32 | "video":("VIDEO",),
33 | }}
34 |
35 | CATEGORY = "AIFSH_DiffSynth-Studio"
36 | DESCRIPTION = "Preview a generated video in the browser UI"
37 |
38 | RETURN_TYPES = ()
39 |
40 | OUTPUT_NODE = True
41 |
42 | FUNCTION = "load_video"
43 |
44 | def load_video(self, video):
45 | video_name = os.path.basename(video)
46 | video_path_name = os.path.basename(os.path.dirname(video))
47 | return {"ui":{"video":[video_name,video_path_name]}}
48 |
--------------------------------------------------------------------------------
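LoadVideo and PreViewVideo above follow the standard ComfyUI custom-node interface (INPUT_TYPES, RETURN_TYPES, FUNCTION, OUTPUT_NODE). For context, such classes are conventionally exposed to ComfyUI through NODE_CLASS_MAPPINGS in the package __init__.py; the sketch below shows that pattern only and is not the actual contents of this repository's __init__.py:

# Sketch of the usual ComfyUI registration pattern (illustrative only).
from .util_nodes import LoadVideo, PreViewVideo

NODE_CLASS_MAPPINGS = {
    "LoadVideo": LoadVideo,
    "PreViewVideo": PreViewVideo,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "LoadVideo": "Load Video",
    "PreViewVideo": "Preview Video",
}
WEB_DIRECTORY = "./web/js"  # lets ComfyUI serve previewVideo.js and uploadVideo.js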
/web.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-DiffSynth-Studio/91d0a6bd9efa6a1d4a0b880984e21fe66acf4eb0/web.png
--------------------------------------------------------------------------------
/web/js/previewVideo.js:
--------------------------------------------------------------------------------
1 | import { app } from "../../../scripts/app.js";
2 | import { api } from '../../../scripts/api.js'
3 |
4 | function fitHeight(node) {
5 | node.setSize([node.size[0], node.computeSize([node.size[0], node.size[1]])[1]])
6 | node?.graph?.setDirtyCanvas(true);
7 | }
8 | function chainCallback(object, property, callback) {
9 | if (object == undefined) {
10 | //This should not happen.
11 | console.error("Tried to add callback to non-existent object")
12 | return;
13 | }
14 | if (property in object) {
15 | const callback_orig = object[property]
16 | object[property] = function () {
17 | const r = callback_orig.apply(this, arguments);
18 | callback.apply(this, arguments);
19 | return r
20 | };
21 | } else {
22 | object[property] = callback;
23 | }
24 | }
25 |
26 | function addPreviewOptions(nodeType) {
27 | chainCallback(nodeType.prototype, "getExtraMenuOptions", function(_, options) {
28 | // The intended way of appending options is returning a list of extra options,
29 | // but this isn't used in widgetInputs.js and would require
30 | // less generalization of chainCallback
31 | let optNew = []
32 | try {
33 | const previewWidget = this.widgets.find((w) => w.name === "videopreview");
34 |
35 | let url = null
36 | if (previewWidget.videoEl?.hidden == false && previewWidget.videoEl.src) {
37 | //Use full quality video
38 | //url = api.apiURL('/view?' + new URLSearchParams(previewWidget.value.params));
39 | url = previewWidget.videoEl.src
40 | }
41 | if (url) {
42 | optNew.push(
43 | {
44 | content: "Open preview",
45 | callback: () => {
46 | window.open(url, "_blank")
47 | },
48 | },
49 | {
50 | content: "Save preview",
51 | callback: () => {
52 | const a = document.createElement("a");
53 | a.href = url;
54 | a.setAttribute("download", new URLSearchParams(previewWidget.value.params).get("filename"));
55 | document.body.append(a);
56 | a.click();
57 | requestAnimationFrame(() => a.remove());
58 | },
59 | }
60 | );
61 | }
62 | if(options.length > 0 && options[0] != null && optNew.length > 0) {
63 | optNew.push(null);
64 | }
65 | options.unshift(...optNew);
66 |
67 | } catch (error) {
68 | console.log(error);
69 | }
70 |
71 | });
72 | }
73 | function previewVideo(node,file,type){
74 | var element = document.createElement("div");
75 | const previewNode = node;
76 | var previewWidget = node.addDOMWidget("videopreview", "preview", element, {
77 | serialize: false,
78 | hideOnZoom: false,
79 | getValue() {
80 | return element.value;
81 | },
82 | setValue(v) {
83 | element.value = v;
84 | },
85 | });
86 | previewWidget.computeSize = function(width) {
87 | if (this.aspectRatio && !this.parentEl.hidden) {
88 | let height = (previewNode.size[0]-20)/ this.aspectRatio + 10;
89 | if (!(height > 0)) {
90 | height = 0;
91 | }
92 | this.computedHeight = height + 10;
93 | return [width, height];
94 | }
95 | return [width, -4];//no loaded src, widget should not display
96 | }
97 | // element.style['pointer-events'] = "none"
98 | previewWidget.value = {hidden: false, paused: false, params: {}}
99 | previewWidget.parentEl = document.createElement("div");
100 | previewWidget.parentEl.className = "video_preview";
101 | previewWidget.parentEl.style['width'] = "100%"
102 | element.appendChild(previewWidget.parentEl);
103 | previewWidget.videoEl = document.createElement("video");
104 | previewWidget.videoEl.controls = true;
105 | previewWidget.videoEl.loop = false;
106 | previewWidget.videoEl.muted = false;
107 | previewWidget.videoEl.style['width'] = "100%"
108 | previewWidget.videoEl.addEventListener("loadedmetadata", () => {
109 |
110 | previewWidget.aspectRatio = previewWidget.videoEl.videoWidth / previewWidget.videoEl.videoHeight;
111 | fitHeight(previewNode);
112 | });
113 | previewWidget.videoEl.addEventListener("error", () => {
114 | //TODO: consider a way to properly notify the user why a preview isn't shown.
115 | previewWidget.parentEl.hidden = true;
116 | fitHeight(previewNode);
117 | });
118 |
119 | let params = {
120 | "filename": file,
121 | "type": type,
122 | }
123 |
124 | previewWidget.parentEl.hidden = previewWidget.value.hidden;
125 | previewWidget.videoEl.autoplay = !previewWidget.value.paused && !previewWidget.value.hidden;
126 | let target_width = 256
127 | if (element.style?.width) {
128 | //overscale to allow scrolling. Endpoint won't return higher than native
129 | target_width = element.style.width.slice(0,-2)*2;
130 | }
131 | if (!params.force_size || params.force_size.includes("?") || params.force_size == "Disabled") {
132 | params.force_size = target_width+"x?"
133 | } else {
134 | let size = params.force_size.split("x")
135 | let ar = parseInt(size[0])/parseInt(size[1])
136 | params.force_size = target_width+"x"+(target_width/ar)
137 | }
138 |
139 | previewWidget.videoEl.src = api.apiURL('/view?' + new URLSearchParams(params));
140 |
141 | previewWidget.videoEl.hidden = false;
142 | previewWidget.parentEl.appendChild(previewWidget.videoEl)
143 | }
144 |
145 | app.registerExtension({
146 | name: "DiffSynth-Studio.VideoPreviewer",
147 | async beforeRegisterNodeDef(nodeType, nodeData, app) {
148 | if (nodeData?.name == "PreViewVideo") {
149 | nodeType.prototype.onExecuted = function (data) {
150 | previewVideo(this, data.video[0], data.video[1]);
151 | }
152 | }
153 | }
154 | });
155 |
--------------------------------------------------------------------------------
/web/js/uploadVideo.js:
--------------------------------------------------------------------------------
1 | import { app } from "../../../scripts/app.js";
2 | import { api } from '../../../scripts/api.js'
3 | import { ComfyWidgets } from "../../../scripts/widgets.js"
4 |
5 | function fitHeight(node) {
6 | node.setSize([node.size[0], node.computeSize([node.size[0], node.size[1]])[1]])
7 | node?.graph?.setDirtyCanvas(true);
8 | }
9 |
10 | function previewVideo(node,file){
11 | while (node.widgets.length > 2){
12 | node.widgets.pop()
13 | }
14 | try {
15 | var el = document.getElementById("uploadVideo");
16 | el.remove();
17 | } catch (error) {
18 | console.log(error);
19 | }
20 | var element = document.createElement("div");
21 | element.id = "uploadVideo";
22 | const previewNode = node;
23 | var previewWidget = node.addDOMWidget("videopreview", "preview", element, {
24 | serialize: false,
25 | hideOnZoom: false,
26 | getValue() {
27 | return element.value;
28 | },
29 | setValue(v) {
30 | element.value = v;
31 | },
32 | });
33 | previewWidget.computeSize = function(width) {
34 | if (this.aspectRatio && !this.parentEl.hidden) {
35 | let height = (previewNode.size[0]-20)/ this.aspectRatio + 10;
36 | if (!(height > 0)) {
37 | height = 0;
38 | }
39 | this.computedHeight = height + 10;
40 | return [width, height];
41 | }
42 | return [width, -4];//no loaded src, widget should not display
43 | }
44 | // element.style['pointer-events'] = "none"
45 | previewWidget.value = {hidden: false, paused: false, params: {}}
46 | previewWidget.parentEl = document.createElement("div");
47 | previewWidget.parentEl.className = "video_preview";
48 | previewWidget.parentEl.style['width'] = "100%"
49 | element.appendChild(previewWidget.parentEl);
50 | previewWidget.videoEl = document.createElement("video");
51 | previewWidget.videoEl.controls = true;
52 | previewWidget.videoEl.loop = false;
53 | previewWidget.videoEl.muted = false;
54 | previewWidget.videoEl.style['width'] = "100%"
55 | previewWidget.videoEl.addEventListener("loadedmetadata", () => {
56 |
57 | previewWidget.aspectRatio = previewWidget.videoEl.videoWidth / previewWidget.videoEl.videoHeight;
58 | fitHeight(previewNode);
59 | });
60 | previewWidget.videoEl.addEventListener("error", () => {
61 | //TODO: consider a way to properly notify the user why a preview isn't shown.
62 | previewWidget.parentEl.hidden = true;
63 | fitHeight(previewNode);
64 | });
65 |
66 | let params = {
67 | "filename": file,
68 | "type": "input",
69 | }
70 |
71 | previewWidget.parentEl.hidden = previewWidget.value.hidden;
72 | previewWidget.videoEl.autoplay = !previewWidget.value.paused && !previewWidget.value.hidden;
73 | let target_width = 256
74 | if (element.style?.width) {
75 | //overscale to allow scrolling. Endpoint won't return higher than native
76 | target_width = element.style.width.slice(0,-2)*2;
77 | }
78 | if (!params.force_size || params.force_size.includes("?") || params.force_size == "Disabled") {
79 | params.force_size = target_width+"x?"
80 | } else {
81 | let size = params.force_size.split("x")
82 | let ar = parseInt(size[0])/parseInt(size[1])
83 | params.force_size = target_width+"x"+(target_width/ar)
84 | }
85 |
86 | previewWidget.videoEl.src = api.apiURL('/view?' + new URLSearchParams(params));
87 |
88 | previewWidget.videoEl.hidden = false;
89 | previewWidget.parentEl.appendChild(previewWidget.videoEl)
90 | }
91 |
92 | function videoUpload(node, inputName, inputData, app) {
93 | const videoWidget = node.widgets.find((w) => w.name === "video");
94 | let uploadWidget;
95 | /*
96 | Proxy the widget's value so uploads resolve to "subfolder/filename [type]"
97 | */
98 | var default_value = videoWidget.value;
99 | Object.defineProperty(videoWidget, "value", {
100 | set : function(value) {
101 | this._real_value = value;
102 | },
103 |
104 | get : function() {
105 | let value = "";
106 | if (this._real_value) {
107 | value = this._real_value;
108 | } else {
109 | return default_value;
110 | }
111 |
112 | if (value.filename) {
113 | let real_value = value;
114 | value = "";
115 | if (real_value.subfolder) {
116 | value = real_value.subfolder + "/";
117 | }
118 |
119 | value += real_value.filename;
120 |
121 | if(real_value.type && real_value.type !== "input")
122 | value += ` [${real_value.type}]`;
123 | }
124 | return value;
125 | }
126 | });
127 | async function uploadFile(file, updateNode, pasted = false) {
128 | try {
129 | // Wrap file in formdata so it includes filename
130 | const body = new FormData();
131 | body.append("image", file);
132 | if (pasted) body.append("subfolder", "pasted");
133 | const resp = await api.fetchApi("/upload/image", {
134 | method: "POST",
135 | body,
136 | });
137 |
138 | if (resp.status === 200) {
139 | const data = await resp.json();
140 | // Add the file to the dropdown list and update the widget value
141 | let path = data.name;
142 | if (data.subfolder) path = data.subfolder + "/" + path;
143 |
144 | if (!videoWidget.options.values.includes(path)) {
145 | videoWidget.options.values.push(path);
146 | }
147 |
148 | if (updateNode) {
149 | videoWidget.value = path;
150 | previewVideo(node,path)
151 |
152 | }
153 | } else {
154 | alert(resp.status + " - " + resp.statusText);
155 | }
156 | } catch (error) {
157 | alert(error);
158 | }
159 | }
160 |
161 | const fileInput = document.createElement("input");
162 | Object.assign(fileInput, {
163 | type: "file",
164 | accept: "video/webm,video/mp4,video/x-matroska,video/x-msvideo",
165 | style: "display: none",
166 | onchange: async () => {
167 | if (fileInput.files.length) {
168 | await uploadFile(fileInput.files[0], true);
169 | }
170 | },
171 | });
172 | document.body.append(fileInput);
173 |
174 | // Create the button widget for selecting the files
175 | uploadWidget = node.addWidget("button", "choose video file to upload", "Video", () => {
176 | fileInput.click();
177 | });
178 |
179 | uploadWidget.serialize = false;
180 |
181 | previewVideo(node, videoWidget.value);
182 | const cb = node.callback;
183 | videoWidget.callback = function () {
184 | previewVideo(node,videoWidget.value);
185 | if (cb) {
186 | return cb.apply(this, arguments);
187 | }
188 | };
189 |
190 | return { widget: uploadWidget };
191 | }
192 |
193 | ComfyWidgets.VIDEOPLOAD = videoUpload;
194 |
195 | app.registerExtension({
196 | name: "DiffSynth-Studio.UploadVideo",
197 | async beforeRegisterNodeDef(nodeType, nodeData, app) {
198 | if (nodeData?.name == "LoadVideo") {
199 | nodeData.input.required.upload = ["VIDEOPLOAD"];
200 | }
201 | },
202 | });
203 |
204 |
--------------------------------------------------------------------------------
/workfolws/diffutoon_workflow.json:
--------------------------------------------------------------------------------
1 | {
2 | "last_node_id": 18,
3 | "last_link_id": 22,
4 | "nodes": [
5 | {
6 | "id": 10,
7 | "type": "PreViewVideo",
8 | "pos": [
9 | 1017.6000366210938,
10 | 406.20001220703125
11 | ],
12 | "size": {
13 | "0": 210,
14 | "1": 26
15 | },
16 | "flags": {},
17 | "order": 7,
18 | "mode": 0,
19 | "inputs": [
20 | {
21 | "name": "video",
22 | "type": "VIDEO",
23 | "link": 21
24 | }
25 | ],
26 | "properties": {
27 | "Node name for S&R": "PreViewVideo"
28 | }
29 | },
30 | {
31 | "id": 6,
32 | "type": "DiffTextNode",
33 | "pos": [
34 | 505,
35 | 485
36 | ],
37 | "size": {
38 | "0": 400,
39 | "1": 200
40 | },
41 | "flags": {},
42 | "order": 0,
43 | "mode": 0,
44 | "outputs": [
45 | {
46 | "name": "TEXT",
47 | "type": "TEXT",
48 | "links": [
49 | 18
50 | ],
51 | "shape": 3
52 | }
53 | ],
54 | "properties": {
55 | "Node name for S&R": "DiffTextNode"
56 | },
57 | "widgets_values": [
58 | "verybadimagenegative_v1.3"
59 | ]
60 | },
61 | {
62 | "id": 16,
63 | "type": "ControlNetPathLoader",
64 | "pos": [
65 | 1138,
66 | 206
67 | ],
68 | "size": {
69 | "0": 315,
70 | "1": 106
71 | },
72 | "flags": {},
73 | "order": 1,
74 | "mode": 0,
75 | "outputs": [
76 | {
77 | "name": "ControlNetConfigUnit",
78 | "type": "ControlNetConfigUnit",
79 | "links": [
80 | 20
81 | ],
82 | "shape": 3
83 | }
84 | ],
85 | "properties": {
86 | "Node name for S&R": "ControlNetPathLoader"
87 | },
88 | "widgets_values": [
89 | "tile",
90 | 0.5,
91 | null
92 | ]
93 | },
94 | {
95 | "id": 15,
96 | "type": "ControlNetPathLoader",
97 | "pos": [
98 | 1128,
99 | 17
100 | ],
101 | "size": {
102 | "0": 315,
103 | "1": 106
104 | },
105 | "flags": {},
106 | "order": 2,
107 | "mode": 0,
108 | "outputs": [
109 | {
110 | "name": "ControlNetConfigUnit",
111 | "type": "ControlNetConfigUnit",
112 | "links": [
113 | 19
114 | ],
115 | "shape": 3
116 | }
117 | ],
118 | "properties": {
119 | "Node name for S&R": "ControlNetPathLoader"
120 | },
121 | "widgets_values": [
122 | "lineart",
123 | 0.5,
124 | "control_v11p_sd15_lineart.pth"
125 | ]
126 | },
127 | {
128 | "id": 3,
129 | "type": "LoadVideo",
130 | "pos": [
131 | 153,
132 | -100
133 | ],
134 | "size": {
135 | "0": 315,
136 | "1": 383
137 | },
138 | "flags": {},
139 | "order": 3,
140 | "mode": 0,
141 | "outputs": [
142 | {
143 | "name": "VIDEO",
144 | "type": "VIDEO",
145 | "links": [
146 | 15
147 | ],
148 | "shape": 3,
149 | "slot_index": 0
150 | }
151 | ],
152 | "properties": {
153 | "Node name for S&R": "LoadVideo"
154 | },
155 | "widgets_values": [
156 | "diffutoondemo.mp4",
157 | "Video",
158 | {
159 | "hidden": false,
160 | "paused": false,
161 | "params": {}
162 | }
163 | ]
164 | },
165 | {
166 | "id": 5,
167 | "type": "DiffTextNode",
168 | "pos": [
169 | 104,
170 | 664
171 | ],
172 | "size": {
173 | "0": 400,
174 | "1": 200
175 | },
176 | "flags": {},
177 | "order": 4,
178 | "mode": 0,
179 | "outputs": [
180 | {
181 | "name": "TEXT",
182 | "type": "TEXT",
183 | "links": [
184 | 17
185 | ],
186 | "shape": 3,
187 | "slot_index": 0
188 | }
189 | ],
190 | "properties": {
191 | "Node name for S&R": "DiffTextNode"
192 | },
193 | "widgets_values": [
194 | "best quality, perfect anime illustration, light, a girl is dancing, smile, solo"
195 | ]
196 | },
197 | {
198 | "id": 17,
199 | "type": "DiffutoonNode",
200 | "pos": [
201 | 674.969524572754,
202 | 27.48489999999994
203 | ],
204 | "size": {
205 | "0": 315,
206 | "1": 370
207 | },
208 | "flags": {},
209 | "order": 6,
210 | "mode": 0,
211 | "inputs": [
212 | {
213 | "name": "source_video_path",
214 | "type": "VIDEO",
215 | "link": 15
216 | },
217 | {
218 | "name": "sd_model_path",
219 | "type": "SD_MODEL_PATH",
220 | "link": 22,
221 | "slot_index": 1
222 | },
223 | {
224 | "name": "postive_prompt",
225 | "type": "TEXT",
226 | "link": 17
227 | },
228 | {
229 | "name": "negative_prompt",
230 | "type": "TEXT",
231 | "link": 18,
232 | "slot_index": 3
233 | },
234 | {
235 | "name": "controlnet1",
236 | "type": "ControlNetConfigUnit",
237 | "link": 19,
238 | "slot_index": 4
239 | },
240 | {
241 | "name": "controlnet2",
242 | "type": "ControlNetConfigUnit",
243 | "link": 20,
244 | "slot_index": 5
245 | },
246 | {
247 | "name": "controlnet3",
248 | "type": "ControlNetConfigUnit",
249 | "link": null
250 | }
251 | ],
252 | "outputs": [
253 | {
254 | "name": "VIDEO",
255 | "type": "VIDEO",
256 | "links": [
257 | 21
258 | ],
259 | "shape": 3,
260 | "slot_index": 0
261 | }
262 | ],
263 | "properties": {
264 | "Node name for S&R": "DiffutoonNode"
265 | },
266 | "widgets_values": [
267 | 40,
268 | 1,
269 | 531,
270 | "randomize",
271 | 3,
272 | 10,
273 | 32,
274 | 16,
275 | 0
276 | ]
277 | },
278 | {
279 | "id": 18,
280 | "type": "SDPathLoader",
281 | "pos": [
282 | 94,
283 | 399
284 | ],
285 | "size": {
286 | "0": 315,
287 | "1": 130
288 | },
289 | "flags": {},
290 | "order": 5,
291 | "mode": 0,
292 | "outputs": [
293 | {
294 | "name": "SD_MODEL_PATH",
295 | "type": "SD_MODEL_PATH",
296 | "links": [
297 | 22
298 | ],
299 | "shape": 3
300 | }
301 | ],
302 | "properties": {
303 | "Node name for S&R": "SDPathLoader"
304 | },
305 | "widgets_values": [
306 | "philz1337x/flat2DAnimerge_v45Sharp",
307 | "flat2DAnimerge_v45Sharp.safetensors",
308 | "HuggingFace",
309 | "flat2DAnimerge_v45Sharp.safetensors"
310 | ]
311 | }
312 | ],
313 | "links": [
314 | [
315 | 15,
316 | 3,
317 | 0,
318 | 17,
319 | 0,
320 | "VIDEO"
321 | ],
322 | [
323 | 17,
324 | 5,
325 | 0,
326 | 17,
327 | 2,
328 | "TEXT"
329 | ],
330 | [
331 | 18,
332 | 6,
333 | 0,
334 | 17,
335 | 3,
336 | "TEXT"
337 | ],
338 | [
339 | 19,
340 | 15,
341 | 0,
342 | 17,
343 | 4,
344 | "ControlNetConfigUnit"
345 | ],
346 | [
347 | 20,
348 | 16,
349 | 0,
350 | 17,
351 | 5,
352 | "ControlNetConfigUnit"
353 | ],
354 | [
355 | 21,
356 | 17,
357 | 0,
358 | 10,
359 | 0,
360 | "VIDEO"
361 | ],
362 | [
363 | 22,
364 | 18,
365 | 0,
366 | 17,
367 | 1,
368 | "SD_MODEL_PATH"
369 | ]
370 | ],
371 | "groups": [],
372 | "config": {},
373 | "extra": {
374 | "ds": {
375 | "scale": 0.6830134553650705,
376 | "offset": [
377 | 88.7050890441895,
378 | 88.17900000000009
379 | ]
380 | }
381 | },
382 | "version": 0.4
383 | }
--------------------------------------------------------------------------------
/workfolws/exvideo_workflow.json:
--------------------------------------------------------------------------------
1 | {
2 | "last_node_id": 5,
3 | "last_link_id": 4,
4 | "nodes": [
5 | {
6 | "id": 3,
7 | "type": "SDPathLoader",
8 | "pos": [
9 | 40,
10 | 423
11 | ],
12 | "size": {
13 | "0": 420.60003662109375,
14 | "1": 136.60003662109375
15 | },
16 | "flags": {},
17 | "order": 0,
18 | "mode": 0,
19 | "outputs": [
20 | {
21 | "name": "SD_MODEL_PATH",
22 | "type": "SD_MODEL_PATH",
23 | "links": [
24 | 2
25 | ],
26 | "shape": 3
27 | }
28 | ],
29 | "properties": {
30 | "Node name for S&R": "SDPathLoader"
31 | },
32 | "widgets_values": [
33 | "stabilityai/stable-video-diffusion-img2vid-xt",
34 | "svd_xt.safetensors",
35 | "HuggingFace",
36 | "svd_xt.safetensors"
37 | ]
38 | },
39 | {
40 | "id": 4,
41 | "type": "SDPathLoader",
42 | "pos": [
43 | 514,
44 | 435
45 | ],
46 | "size": {
47 | "0": 489.800048828125,
48 | "1": 130
49 | },
50 | "flags": {},
51 | "order": 1,
52 | "mode": 0,
53 | "outputs": [
54 | {
55 | "name": "SD_MODEL_PATH",
56 | "type": "SD_MODEL_PATH",
57 | "links": [
58 | 3
59 | ],
60 | "shape": 3
61 | }
62 | ],
63 | "properties": {
64 | "Node name for S&R": "SDPathLoader"
65 | },
66 | "widgets_values": [
67 | "ECNU-CILab/ExVideo-SVD-128f-v1",
68 | "model.fp16.safetensors",
69 | "HuggingFace",
70 | "model.fp16.safetensors"
71 | ]
72 | },
73 | {
74 | "id": 2,
75 | "type": "LoadImage",
76 | "pos": [
77 | 51,
78 | 60
79 | ],
80 | "size": {
81 | "0": 315,
82 | "1": 314
83 | },
84 | "flags": {},
85 | "order": 2,
86 | "mode": 0,
87 | "outputs": [
88 | {
89 | "name": "IMAGE",
90 | "type": "IMAGE",
91 | "links": [
92 | 1
93 | ],
94 | "shape": 3
95 | },
96 | {
97 | "name": "MASK",
98 | "type": "MASK",
99 | "links": null,
100 | "shape": 3
101 | }
102 | ],
103 | "properties": {
104 | "Node name for S&R": "LoadImage"
105 | },
106 | "widgets_values": [
107 | "01c6925ae17636a801214a61208e31.png@2o.png",
108 | "image"
109 | ]
110 | },
111 | {
112 | "id": 1,
113 | "type": "ExVideoNode",
114 | "pos": [
115 | 633,
116 | 97
117 | ],
118 | "size": {
119 | "0": 315,
120 | "1": 218
121 | },
122 | "flags": {},
123 | "order": 3,
124 | "mode": 0,
125 | "inputs": [
126 | {
127 | "name": "image",
128 | "type": "IMAGE",
129 | "link": 1,
130 | "slot_index": 0
131 | },
132 | {
133 | "name": "svd_base_model",
134 | "type": "SD_MODEL_PATH",
135 | "link": 2,
136 | "slot_index": 1
137 | },
138 | {
139 | "name": "exvideo_model",
140 | "type": "SD_MODEL_PATH",
141 | "link": 3,
142 | "slot_index": 2
143 | }
144 | ],
145 | "outputs": [
146 | {
147 | "name": "VIDEO",
148 | "type": "VIDEO",
149 | "links": [
150 | 4
151 | ],
152 | "shape": 3,
153 | "slot_index": 0
154 | }
155 | ],
156 | "properties": {
157 | "Node name for S&R": "ExVideoNode"
158 | },
159 | "widgets_values": [
160 | 50,
161 | 25,
162 | 20,
163 | true,
164 | 874,
165 | "randomize"
166 | ]
167 | },
168 | {
169 | "id": 5,
170 | "type": "PreViewVideo",
171 | "pos": [
172 | 1062,
173 | 105
174 | ],
175 | "size": {
176 | "0": 210,
177 | "1": 434
178 | },
179 | "flags": {},
180 | "order": 4,
181 | "mode": 0,
182 | "inputs": [
183 | {
184 | "name": "video",
185 | "type": "VIDEO",
186 | "link": 4
187 | }
188 | ],
189 | "properties": {
190 | "Node name for S&R": "PreViewVideo"
191 | },
192 | "widgets_values": [
193 | {
194 | "hidden": false,
195 | "paused": false,
196 | "params": {}
197 | }
198 | ]
199 | }
200 | ],
201 | "links": [
202 | [
203 | 1,
204 | 2,
205 | 0,
206 | 1,
207 | 0,
208 | "IMAGE"
209 | ],
210 | [
211 | 2,
212 | 3,
213 | 0,
214 | 1,
215 | 1,
216 | "SD_MODEL_PATH"
217 | ],
218 | [
219 | 3,
220 | 4,
221 | 0,
222 | 1,
223 | 2,
224 | "SD_MODEL_PATH"
225 | ],
226 | [
227 | 4,
228 | 1,
229 | 0,
230 | 5,
231 | 0,
232 | "VIDEO"
233 | ]
234 | ],
235 | "groups": [],
236 | "config": {},
237 | "extra": {
238 | "ds": {
239 | "scale": 1,
240 | "offset": [
241 | 0,
242 | -0.79998779296875
243 | ]
244 | }
245 | },
246 | "version": 0.4
247 | }
--------------------------------------------------------------------------------