├── .github
│   └── workflows
│       └── publish.yml
├── Paints-UNDO
│   ├── LICENSE
│   ├── README.md
│   ├── ZZX_PaintsUndo.py
│   ├── diffusers_helper
│   │   ├── cat_cond.py
│   │   ├── code_cond.py
│   │   ├── k_diffusion.py
│   │   └── utils.py
│   ├── diffusers_vdm
│   │   ├── attention.py
│   │   ├── basics.py
│   │   ├── dynamic_tsnr_sampler.py
│   │   ├── improved_clip_vision.py
│   │   ├── pipeline.py
│   │   ├── projection.py
│   │   ├── unet.py
│   │   ├── utils.py
│   │   └── vae.py
│   ├── gradio_app.py
│   ├── imgs
│   │   ├── 1.jpg
│   │   ├── 2.jpg
│   │   └── 3.jpg
│   ├── memory_management.py
│   ├── requirements.txt
│   └── wd14tagger.py
├── README.md
├── __init__.py
├── nodes
│   ├── ZZX_Stream.py
│   └── ZZX_VFC.py
├── pyproject.toml
├── requirements.txt
└── workflows
    ├── PaintsUndo.png
    ├── StreamRecorder+VideoFormatConverter.png
    ├── VideoFormatConverter.png
    ├── workflow-StreamRecorder.json
    └── workflow-VideoFormatConverter.json

/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish to Comfy registry
2 | on:
3 |   workflow_dispatch:
4 |   push:
5 |     branches:
6 |       - main
7 |       - master
8 |     paths:
9 |       - "pyproject.toml"
10 | 
11 | jobs:
12 |   publish-node:
13 |     name: Publish Custom Node to registry
14 |     runs-on: ubuntu-latest
15 |     # Skip the workflow if this is a forked repository.
16 |     if: github.event.repository.fork == false
17 |     steps:
18 |       - name: Check out code
19 |         uses: actions/checkout@v4
20 |       - name: Publish Custom Node
21 |         uses: Comfy-Org/publish-node-action@main
22 |         with:
23 |           ## Add your own personal access token to your GitHub Repository secrets and reference it here.
24 |           personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
25 | 
--------------------------------------------------------------------------------
/Paints-UNDO/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 | 
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 | 1. Definitions.
8 | 
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 | 
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 | 
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 | 
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Paints-UNDO/README.md: -------------------------------------------------------------------------------- 1 | # Paints-Undo 2 | 3 | PaintsUndo: A Base Model of Drawing Behaviors in Digital Paintings 4 | 5 | Paints-Undo is a project aimed at providing base models of human drawing behaviors with a hope that future AI models can better align with the real needs of human artists. 
6 | 
7 | The name "Paints-Undo" is inspired by the fact that the model's outputs look like pressing the "undo" button (usually Ctrl+Z) many times in digital painting software.
8 | 
9 | Paints-Undo presents a family of models that take an image as input and then output the drawing sequence of that image. The model displays all kinds of human behaviors, including but not limited to sketching, inking, coloring, shading, transforming, left-right flipping, color curve tuning, changing the visibility of layers, and even changing the overall idea during the drawing process.
10 | 
11 | *This page does not contain any examples. All examples are on the Git page below:*
12 | 
13 | [>>> Click Here to See the Example Page <<<](https://lllyasviel.github.io/pages/paints_undo/)
14 | 
15 | **This GitHub repo is the only official page of PaintsUndo. We do not have any other websites.**
16 | 
17 | **Do note that many fake PaintsUndo websites have appeared on Google and social media recently.**
18 | 
19 | # Get Started
20 | 
21 | You can deploy PaintsUndo locally via:
22 | 
23 |     git clone https://github.com/lllyasviel/Paints-UNDO.git
24 |     cd Paints-UNDO
25 |     conda create -n paints_undo python=3.10
26 |     conda activate paints_undo
27 |     pip install xformers
28 |     pip install -r requirements.txt
29 |     python gradio_app.py
30 | 
31 | (If you do not know how to use these commands, you can paste them into ChatGPT and ask it to explain and give more detailed instructions.)
32 | 
33 | The inference is tested with 24GB VRAM on an Nvidia 4090 and a 3090TI. It may also work with 16GB VRAM, but it does not work with 8GB. My estimation is that, under extreme optimization (including weight offloading and sliced attention), the theoretical minimum VRAM requirement is about 10~12.5 GB.
34 | 
35 | You can expect to process one image in about 5 to 10 minutes, depending on your settings. As a typical result, you will get a 25-second video at 4 FPS, with a resolution of 320x512, 512x320, 384x448, or 448x384.
36 | 
37 | Because the processing time is, in most cases, significantly longer than most task quotas on HuggingFace Space, I personally do not recommend deploying this to HuggingFace Space, to avoid placing an unnecessary burden on the HF servers.
38 | 
39 | If you do not have the required computation devices and still want an online solution, one option is to wait for us to release a Colab notebook (but I am not sure whether the Colab free tier will work).
40 | 
41 | # Model Notes
42 | 
43 | We currently release two models, `paints_undo_single_frame` and `paints_undo_multi_frame`. Let's call them the single-frame model and the multi-frame model.
44 | 
45 | The single-frame model takes one image and an `operation step` as input, and outputs one single image. We assume that an artwork can always be created with 1000 human operations (for example, one brush stroke is one operation), so the `operation step` is an integer from 0 to 999. The number 0 is the finished final artwork, and the number 999 is the first brush stroke drawn on the pure white canvas. You can understand this model as an "undo" (Ctrl+Z) model: you input the final image and indicate how many times you want to press "Ctrl+Z", and the model gives you a "simulated" screenshot after those "Ctrl+Z"s are pressed. For example, if your `operation step` is 100, it means you want to simulate pressing "Ctrl+Z" 100 times on this image to get its appearance after the 100th "Ctrl+Z".
46 | 
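As a concrete illustration of this convention, the ComfyUI wrapper node `ZZX_PaintsUndo` (defined in `ZZX_PaintsUndo.py` in this repository) exposes the operation step as its `undo_steps` input. The following is only a minimal usage sketch, not an official API example; it assumes the node class is importable (ComfyUI normally instantiates it for you) and that the input is a `[1, H, W, 3]` float tensor with values in 0..1, which is what `process_image` expects:

    import torch

    node = ZZX_PaintsUndo()              # class from ZZX_PaintsUndo.py; ComfyUI normally creates this
    image = torch.rand(1, 512, 512, 3)   # placeholder for a real [1, H, W, 3] image tensor in 0..1

    undone_image, used_prompt = node.process_image(
        image=image,
        Prompt="",        # empty prompt -> the bundled WD14 tagger generates one automatically
        undo_steps=100,   # simulate pressing Ctrl+Z 100 times (smaller = closer to the finished artwork)
        seed=12345,
    )

Here `undone_image` is the simulated earlier state of the drawing, and `used_prompt` is the prompt that was actually used (either yours or the tagger's output).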
47 | The multi-frame model takes two images as input and outputs 16 intermediate frames between the two input images. The result is much more consistent than the single-frame model, but it is also much slower, less "creative", and limited to 16 frames.
48 | 
49 | In this repo, the default method is to use them together. We first infer the single-frame model about 5-7 times to get 5-7 "keyframes", and then we use the multi-frame model to "interpolate" those keyframes to generate a relatively long video.
50 | 
51 | In theory this system can be used in many ways and can even produce infinitely long videos, but in practice the results are good when the final frame count is about 100-500.
52 | 
53 | ### Model Architecture (paints_undo_single_frame)
54 | 
55 | The model is a modified SD1.5 architecture trained with a different beta schedule, a different clip skip, and the aforementioned `operation step` condition. To be specific, the model is trained with the betas of:
56 | 
57 | `betas = torch.linspace(0.00085, 0.020, 1000, dtype=torch.float64)`
58 | 
59 | For comparison, the original SD1.5 is trained with the betas of:
60 | 
61 | `betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000, dtype=torch.float64) ** 2`
62 | 
63 | You can notice the difference in the ending betas and the removed square. The choice of this scheduler is based on our internal user study.
64 | 
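To make the comparison concrete, here is a small sketch (not part of the repo) that simply builds both schedules as defined above and prints their final beta and terminal cumulative alpha product; the PaintsUndo schedule ends with larger betas, i.e., a lower terminal signal-to-noise ratio:

    import torch

    # PaintsUndo single-frame schedule: plain linear betas with a larger final value
    betas_undo = torch.linspace(0.00085, 0.020, 1000, dtype=torch.float64)

    # Original SD1.5 schedule: "scaled linear" (interpolate in sqrt space, then square)
    betas_sd15 = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000, dtype=torch.float64) ** 2

    for name, betas in [("paints_undo", betas_undo), ("sd15", betas_sd15)]:
        alpha_bar = torch.cumprod(1.0 - betas, dim=0)
        print(f"{name}: final beta = {betas[-1].item():.4f}, "
              f"terminal alpha_bar = {alpha_bar[-1].item():.3e}")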
"roll": Each frame's Spatial Self-Attention also attend to full spatial contexts of its previous and next frames, based on the ordering of `torch.roll`. 84 | 85 | Note that this is by default disabled in inference to save GPU memory. 86 | 87 | **CLIP**: The CLIP of SD2.1. 88 | 89 | **CLIP-Vision**: Our implementation of Clip Vision (ViT/H) that supports arbitrary aspect ratios by interpolating the positional embedding. After experimenting with linear interpolation, nearest neighbor, and Rotary Positional Encoding (RoPE), our final choice is nearest neighbor. Note that this is different from Crafter methods that resize or center-crop images to 224x224. 90 | 91 | **Image Projection**: Our implementation of a tiny transformer that takes two frames as inputs and outputs 16 image embeddings for each frame. Note that this is different from Crafter methods that only use one image. 92 | 93 | # Tutorial 94 | 95 | After you get into the Gradio interface: 96 | 97 | Step 0: Upload an image or just click an Example image on the bottom of the page. 98 | 99 | Step 1: In the UI titled "step 1", click generate prompts to get the global prompt. 100 | 101 | Step 2: In the UI titled "step 2", click "Generate Key Frames". You can change seeds or other parameters on the left. 102 | 103 | Step 3: In the UI titled "step 3", click "Generate Video". You can change seeds or other parameters on the left. 104 | 105 | # Cite 106 | 107 | @Misc{paintsundo, 108 | author = {Paints-Undo Team}, 109 | title = {Paints-Undo GitHub Page}, 110 | year = {2024}, 111 | } 112 | 113 | # Applications 114 | 115 | Typical use cases of PaintsUndo: 116 | 117 | 1. Use PaintsUndo as a base model to analyze human behavior to build AI tools that align with human behavior and human demands, for seamless collaboration between AI and humans in a perfectly controlled workflow. 118 | 119 | 2. Combine PaintsUndo with sketch-guided image generators to achieve “PaintsRedo”, so as to move forward or backward arbitrarily in any of your finished/unfinished artworks to enhance human creativity power. * 120 | 121 | 3. Use PaintsUndo to view different possible procedures of your own artworks for artistic inspirations. 122 | 123 | 4. Use the outputs of PaintsUndo as a kind of video/movie After Effects to achieve specific creative purposes. 124 | 125 | and much more ... 126 | 127 | * *this is already possible - if you use PaintsUndo to Undo 500 steps, and want to Redo 100 steps with different possibilities, you can use ControlNet to finish it (so that it becomes step 0) and then undo 400 steps. More integrated solution is still under experiments.* 128 | 129 | # Disclaimer 130 | 131 | This project aims to develop base models of human drawing behaviors, facilitating future AI systems to better meet the real needs of human artists. Users are granted the freedom to create content using this tool, but they are expected to comply with local laws and use it responsibly. Users must not employ the tool to generate false information or incite confrontation. The developers do not assume any responsibility for potential misuse by users. 
132 | 
--------------------------------------------------------------------------------
/Paints-UNDO/ZZX_PaintsUndo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 | 
6 | # Import the required helper modules
7 | from .memory_management import load_models_to_gpu, unload_all_models
8 | from .wd14tagger import default_interrogator
9 | from .diffusers_helper.k_diffusion import KDiffusionSampler
10 | from .diffusers_helper.cat_cond import unet_add_concat_conds
11 | from .diffusers_helper.code_cond import unet_add_coded_conds
12 | from transformers import CLIPTextModel, CLIPTokenizer
13 | from diffusers import AutoencoderKL, UNet2DConditionModel
14 | from diffusers.models.attention_processor import AttnProcessor2_0
15 | 
16 | class ModifiedUNet(UNet2DConditionModel):
17 |     @classmethod
18 |     def from_config(cls, *args, **kwargs):
19 |         m = super().from_config(*args, **kwargs)
20 |         unet_add_concat_conds(unet=m, new_channels=4)
21 |         unet_add_coded_conds(unet=m, added_number_count=1)
22 |         return m
23 | 
24 | class ZZX_PaintsUndo:
25 |     def __init__(self):
26 |         self.model_name = 'lllyasviel/paints_undo_single_frame'
27 |         self.tokenizer = None
28 |         self.text_encoder = None
29 |         self.vae = None
30 |         self.unet = None
31 |         self.k_sampler = None
32 |         self.initialize_models()
33 | 
34 |     @classmethod
35 |     def INPUT_TYPES(s):
36 |         return {
37 |             "required": {
38 |                 "image": ("IMAGE",),
39 |                 "Prompt": ("STRING", {"default": "", "multiline": True}),
40 |                 "undo_steps": ("INT", {"default": 5, "min": 1, "max": 999, "step": 1}),
41 |                 "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
42 |             },
43 |         }
44 | 
45 |     RETURN_TYPES = ("IMAGE", "STRING")
46 |     RETURN_NAMES = ("image", "prompt")
47 |     FUNCTION = "process_image"
48 |     CATEGORY = "ZZX/PaintsUndo"
49 | 
50 |     def initialize_models(self):
51 |         os.environ['HF_HOME'] = os.path.join(os.path.dirname(__file__), 'hf_download')
52 |         dtype = torch.float16
53 | 
54 |         self.tokenizer = CLIPTokenizer.from_pretrained(self.model_name, subfolder="tokenizer")
55 |         self.text_encoder = CLIPTextModel.from_pretrained(self.model_name, subfolder="text_encoder").to(dtype)
56 |         self.vae = AutoencoderKL.from_pretrained(self.model_name, subfolder="vae").to(dtype)
57 |         self.unet = ModifiedUNet.from_pretrained(self.model_name, subfolder="unet").to(dtype)
58 | 
59 |         self.unet.set_attn_processor(AttnProcessor2_0())
60 |         self.vae.set_attn_processor(AttnProcessor2_0())
61 | 
62 |         self.k_sampler = KDiffusionSampler(
63 |             self.unet,
64 |             timesteps=1000,
65 |             linear_start=0.00085,
66 |             linear_end=0.020,
67 |             linear=True
68 |         )
69 | 
70 |         unload_all_models([self.vae, self.text_encoder, self.unet])
71 | 
72 |     def process_image(self, image, Prompt, undo_steps, seed):
73 |         print("Starting process_image method")
74 |         pil_image = Image.fromarray(np.clip(255. * image[0].cpu().numpy(), 0, 255).astype(np.uint8))
75 | 
76 |         generated_prompt = ""
77 |         if not Prompt:
78 |             generated_prompt = default_interrogator(pil_image)
79 |             Prompt = generated_prompt
80 | 
81 |         print(f"Input image shape: {np.array(pil_image).shape}")
82 |         print(f"Prompt: {Prompt}")
83 |         print(f"Undo steps: {undo_steps}")
84 |         print(f"Seed: {seed}")
85 | 
86 |         result = self.paints_undo_process(pil_image, Prompt, undo_steps, seed)
87 | 
88 |         print(f"Result shape after paints_undo_process: {result.shape}")
89 | 
90 |         # Handle a possible BGR output
91 |         if result.shape[0] == 3:
92 |             print("Detected 3-channel output, assuming BGR order")
93 |             result = result[::-1]  # reverse the channel order from BGR to RGB
94 |             result = np.transpose(result, (1, 2, 0))  # convert from [C, H, W] to [H, W, C]
95 |         elif result.ndim == 3 and result.shape[2] == 3:
96 |             print("Result is already in [H, W, C] format")
97 |         else:
98 |             print(f"Unexpected result shape: {result.shape}")
99 |             # If the shape is unexpected, other handling could be attempted here, or an error raised
100 | 
101 |         # Make sure the result is a numpy array in [H, W, C] format
102 |         if result.ndim == 2:
103 |             result = np.stack([result] * 3, axis=-1)  # grayscale image -> convert to RGB
104 |         elif result.shape[2] == 1:
105 |             result = np.repeat(result, 3, axis=2)  # single channel -> repeat three times to get RGB
106 | 
107 |         # Convert to a [C, H, W] torch.Tensor with values in the range 0-1
108 |         output_image = torch.from_numpy(result).float().permute(2, 0, 1) / 255.0
109 | 
110 |         print(f"Final output image shape: {output_image.shape}")
111 |         print(f"Final output image min/max values: {output_image.min()}, {output_image.max()}")
112 | 
113 |         # Decide which prompt to output
114 |         output_prompt = Prompt if Prompt else generated_prompt
115 | 
116 |         return (output_image, output_prompt)
117 | 
118 |     def paints_undo_process(self, image, prompt, undo_steps, seed):
119 |         print("Starting paints_undo_process method")
120 |         load_models_to_gpu([self.vae, self.text_encoder, self.unet])
121 | 
122 |         dtype = self.unet.dtype
123 | 
124 |         image = np.array(image)
125 |         concat_conds = torch.from_numpy(image).unsqueeze(0).to(self.vae.device, dtype=dtype) / 127.5 - 1.0
126 |         concat_conds = self.vae.encode(concat_conds.permute(0, 3, 1, 2)).latent_dist.mode() * self.vae.config.scaling_factor
127 | 
128 |         print(f"Concat_conds shape: {concat_conds.shape}")
129 | 
130 |         conds = self.encode_prompt(prompt)
131 |         unconds = self.encode_prompt("")
132 | 
133 |         generator = torch.Generator(device=self.unet.device).manual_seed(seed)
134 | 
135 |         fs = torch.tensor([undo_steps], device=self.unet.device, dtype=torch.long)
136 |         latents = self.k_sampler(
137 |             initial_latent=torch.zeros_like(concat_conds),
138 |             strength=0.8,
139 |             num_inference_steps=30,
140 |             guidance_scale=7.5,
141 |             batch_size=1,
142 |             generator=generator,
143 |             prompt_embeds=conds,
144 |             negative_prompt_embeds=unconds,
145 |             cross_attention_kwargs={'concat_conds': concat_conds, 'coded_conds': fs},
146 |         )
147 | 
148 |         print(f"Latents shape after sampling: {latents.shape}")
149 | 
150 |         images = self.vae.decode(latents / self.vae.config.scaling_factor).sample
151 |         images = (images / 2 + 0.5).clamp(0, 1)
152 | 
153 |         print(f"Images shape after VAE decode: {images.shape}")
154 |         print(f"Images min/max values: {images.min()}, {images.max()}")
155 | 
156 |         images = images.cpu().permute(0, 2, 3, 1).float().numpy()
157 | 
158 |         unload_all_models([self.vae, self.text_encoder, self.unet])
159 | 
160 |         final_image = (images[0] * 255).astype(np.uint8)
161 |         print(f"Final image shape: {final_image.shape}")
162 |         print(f"Final image min/max values: {final_image.min()}, {final_image.max()}")
163 | 
164 |         return final_image
165 | 
166 |     def encode_prompt(self,
prompt): 167 | text_inputs = self.tokenizer( 168 | prompt, 169 | padding="max_length", 170 | max_length=self.tokenizer.model_max_length, 171 | truncation=True, 172 | return_tensors="pt", 173 | ) 174 | text_input_ids = text_inputs.input_ids.to(self.text_encoder.device) 175 | prompt_embeds = self.text_encoder(text_input_ids)[0] 176 | return prompt_embeds 177 | 178 | NODE_CLASS_MAPPINGS = { 179 | "ZZX_PaintsUndo": ZZX_PaintsUndo 180 | } -------------------------------------------------------------------------------- /Paints-UNDO/diffusers_helper/cat_cond.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def unet_add_concat_conds(unet, new_channels=4): 5 | with torch.no_grad(): 6 | new_conv_in = torch.nn.Conv2d(4 + new_channels, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding) 7 | new_conv_in.weight.zero_() 8 | new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight) 9 | new_conv_in.bias = unet.conv_in.bias 10 | unet.conv_in = new_conv_in 11 | 12 | unet_original_forward = unet.forward 13 | 14 | def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs): 15 | cross_attention_kwargs = {k: v for k, v in kwargs['cross_attention_kwargs'].items()} 16 | c_concat = cross_attention_kwargs.pop('concat_conds') 17 | kwargs['cross_attention_kwargs'] = cross_attention_kwargs 18 | 19 | c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0).to(sample) 20 | new_sample = torch.cat([sample, c_concat], dim=1) 21 | return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs) 22 | 23 | unet.forward = hooked_unet_forward 24 | return 25 | -------------------------------------------------------------------------------- /Paints-UNDO/diffusers_helper/code_cond.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps 4 | 5 | 6 | def unet_add_coded_conds(unet, added_number_count=1): 7 | unet.add_time_proj = Timesteps(256, True, 0) 8 | unet.add_embedding = TimestepEmbedding(256 * added_number_count, 1280) 9 | 10 | def get_aug_embed(emb, encoder_hidden_states, added_cond_kwargs): 11 | coded_conds = added_cond_kwargs.get("coded_conds") 12 | batch_size = coded_conds.shape[0] 13 | time_embeds = unet.add_time_proj(coded_conds.flatten()) 14 | time_embeds = time_embeds.reshape((batch_size, -1)) 15 | time_embeds = time_embeds.to(emb) 16 | aug_emb = unet.add_embedding(time_embeds) 17 | return aug_emb 18 | 19 | unet.get_aug_embed = get_aug_embed 20 | 21 | unet_original_forward = unet.forward 22 | 23 | def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs): 24 | cross_attention_kwargs = {k: v for k, v in kwargs['cross_attention_kwargs'].items()} 25 | coded_conds = cross_attention_kwargs.pop('coded_conds') 26 | kwargs['cross_attention_kwargs'] = cross_attention_kwargs 27 | 28 | coded_conds = torch.cat([coded_conds] * (sample.shape[0] // coded_conds.shape[0]), dim=0).to(sample.device) 29 | kwargs['added_cond_kwargs'] = dict(coded_conds=coded_conds) 30 | return unet_original_forward(sample, timestep, encoder_hidden_states, **kwargs) 31 | 32 | unet.forward = hooked_unet_forward 33 | 34 | return 35 | -------------------------------------------------------------------------------- /Paints-UNDO/diffusers_helper/k_diffusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 
import numpy as np 3 | 4 | from tqdm import tqdm 5 | 6 | 7 | @torch.no_grad() 8 | def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, progress_tqdm=None): 9 | """DPM-Solver++(2M).""" 10 | extra_args = {} if extra_args is None else extra_args 11 | s_in = x.new_ones([x.shape[0]]) 12 | sigma_fn = lambda t: t.neg().exp() 13 | t_fn = lambda sigma: sigma.log().neg() 14 | old_denoised = None 15 | 16 | bar = tqdm if progress_tqdm is None else progress_tqdm 17 | 18 | for i in bar(range(len(sigmas) - 1)): 19 | denoised = model(x, sigmas[i] * s_in, **extra_args) 20 | if callback is not None: 21 | callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) 22 | t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1]) 23 | h = t_next - t 24 | if old_denoised is None or sigmas[i + 1] == 0: 25 | x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised 26 | else: 27 | h_last = t - t_fn(sigmas[i - 1]) 28 | r = h_last / h 29 | denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised 30 | x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d 31 | old_denoised = denoised 32 | return x 33 | 34 | 35 | class KModel: 36 | def __init__(self, unet, timesteps=1000, linear_start=0.00085, linear_end=0.012, linear=False): 37 | if linear: 38 | betas = torch.linspace(linear_start, linear_end, timesteps, dtype=torch.float64) 39 | else: 40 | betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, timesteps, dtype=torch.float64) ** 2 41 | 42 | alphas = 1. - betas 43 | alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32) 44 | 45 | self.sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 46 | self.log_sigmas = self.sigmas.log() 47 | self.sigma_data = 1.0 48 | self.unet = unet 49 | return 50 | 51 | @property 52 | def sigma_min(self): 53 | return self.sigmas[0] 54 | 55 | @property 56 | def sigma_max(self): 57 | return self.sigmas[-1] 58 | 59 | def timestep(self, sigma): 60 | log_sigma = sigma.log() 61 | dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None] 62 | return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device) 63 | 64 | def get_sigmas_karras(self, n, rho=7.): 65 | ramp = torch.linspace(0, 1, n) 66 | min_inv_rho = self.sigma_min ** (1 / rho) 67 | max_inv_rho = self.sigma_max ** (1 / rho) 68 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho 69 | return torch.cat([sigmas, sigmas.new_zeros([1])]) 70 | 71 | def __call__(self, x, sigma, **extra_args): 72 | x_ddim_space = x / (sigma[:, None, None, None] ** 2 + self.sigma_data ** 2) ** 0.5 73 | x_ddim_space = x_ddim_space.to(dtype=self.unet.dtype) 74 | t = self.timestep(sigma) 75 | cfg_scale = extra_args['cfg_scale'] 76 | eps_positive = self.unet(x_ddim_space, t, return_dict=False, **extra_args['positive'])[0] 77 | eps_negative = self.unet(x_ddim_space, t, return_dict=False, **extra_args['negative'])[0] 78 | noise_pred = eps_negative + cfg_scale * (eps_positive - eps_negative) 79 | return x - noise_pred * sigma[:, None, None, None] 80 | 81 | 82 | class KDiffusionSampler: 83 | def __init__(self, unet, **kwargs): 84 | self.unet = unet 85 | self.k_model = KModel(unet=unet, **kwargs) 86 | 87 | @torch.inference_mode() 88 | def __call__( 89 | self, 90 | initial_latent = None, 91 | strength = 1.0, 92 | num_inference_steps = 25, 93 | guidance_scale = 5.0, 94 | batch_size = 1, 95 | generator = None, 96 | prompt_embeds = None, 97 | negative_prompt_embeds = None, 98 | cross_attention_kwargs = None, 99 | same_noise_in_batch 
= False, 100 | progress_tqdm = None, 101 | ): 102 | 103 | device = self.unet.device 104 | 105 | # Sigmas 106 | 107 | sigmas = self.k_model.get_sigmas_karras(int(num_inference_steps/strength)) 108 | sigmas = sigmas[-(num_inference_steps + 1):].to(device) 109 | 110 | # Initial latents 111 | 112 | if same_noise_in_batch: 113 | noise = torch.randn(initial_latent.shape, generator=generator, device=device, dtype=self.unet.dtype).repeat(batch_size, 1, 1, 1) 114 | initial_latent = initial_latent.repeat(batch_size, 1, 1, 1).to(device=device, dtype=self.unet.dtype) 115 | else: 116 | initial_latent = initial_latent.repeat(batch_size, 1, 1, 1).to(device=device, dtype=self.unet.dtype) 117 | noise = torch.randn(initial_latent.shape, generator=generator, device=device, dtype=self.unet.dtype) 118 | 119 | latents = initial_latent + noise * sigmas[0].to(initial_latent) 120 | 121 | # Batch 122 | 123 | latents = latents.to(device) 124 | prompt_embeds = prompt_embeds.repeat(batch_size, 1, 1).to(device) 125 | negative_prompt_embeds = negative_prompt_embeds.repeat(batch_size, 1, 1).to(device) 126 | 127 | # Feeds 128 | 129 | sampler_kwargs = dict( 130 | cfg_scale=guidance_scale, 131 | positive=dict( 132 | encoder_hidden_states=prompt_embeds, 133 | cross_attention_kwargs=cross_attention_kwargs 134 | ), 135 | negative=dict( 136 | encoder_hidden_states=negative_prompt_embeds, 137 | cross_attention_kwargs=cross_attention_kwargs, 138 | ) 139 | ) 140 | 141 | # Sample 142 | 143 | results = sample_dpmpp_2m(self.k_model, latents, sigmas, extra_args=sampler_kwargs, progress_tqdm=progress_tqdm) 144 | 145 | return results 146 | -------------------------------------------------------------------------------- /Paints-UNDO/diffusers_helper/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import glob 5 | import torch 6 | import einops 7 | import torchvision 8 | 9 | import safetensors.torch as sf 10 | 11 | 12 | def write_to_json(data, file_path): 13 | temp_file_path = file_path + ".tmp" 14 | with open(temp_file_path, 'wt', encoding='utf-8') as temp_file: 15 | json.dump(data, temp_file, indent=4) 16 | os.replace(temp_file_path, file_path) 17 | return 18 | 19 | 20 | def read_from_json(file_path): 21 | with open(file_path, 'rt', encoding='utf-8') as file: 22 | data = json.load(file) 23 | return data 24 | 25 | 26 | def get_active_parameters(m): 27 | return {k:v for k, v in m.named_parameters() if v.requires_grad} 28 | 29 | 30 | def cast_training_params(m, dtype=torch.float32): 31 | for param in m.parameters(): 32 | if param.requires_grad: 33 | param.data = param.to(dtype) 34 | return 35 | 36 | 37 | def set_attr_recursive(obj, attr, value): 38 | attrs = attr.split(".") 39 | for name in attrs[:-1]: 40 | obj = getattr(obj, name) 41 | setattr(obj, attrs[-1], value) 42 | return 43 | 44 | 45 | @torch.no_grad() 46 | def batch_mixture(a, b, probability_a=0.5, mask_a=None): 47 | assert a.shape == b.shape, "Tensors must have the same shape" 48 | batch_size = a.size(0) 49 | 50 | if mask_a is None: 51 | mask_a = torch.rand(batch_size) < probability_a 52 | 53 | mask_a = mask_a.to(a.device) 54 | mask_a = mask_a.reshape((batch_size,) + (1,) * (a.dim() - 1)) 55 | result = torch.where(mask_a, a, b) 56 | return result 57 | 58 | 59 | @torch.no_grad() 60 | def zero_module(module): 61 | for p in module.parameters(): 62 | p.detach().zero_() 63 | return module 64 | 65 | 66 | def load_last_state(model, folder='accelerator_output'): 67 | file_pattern = 
os.path.join(folder, '**', 'model.safetensors') 68 | files = glob.glob(file_pattern, recursive=True) 69 | 70 | if not files: 71 | print("No model.safetensors files found in the specified folder.") 72 | return 73 | 74 | newest_file = max(files, key=os.path.getmtime) 75 | state_dict = sf.load_file(newest_file) 76 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) 77 | 78 | if missing_keys: 79 | print("Missing keys:", missing_keys) 80 | if unexpected_keys: 81 | print("Unexpected keys:", unexpected_keys) 82 | 83 | print("Loaded model state from:", newest_file) 84 | return 85 | 86 | 87 | def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32): 88 | tags = tags_str.split(', ') 89 | tags = random.sample(tags, k=min(random.randint(min_length, max_length), len(tags))) 90 | prompt = ', '.join(tags) 91 | return prompt 92 | 93 | 94 | def save_bcthw_as_mp4(x, output_filename, fps=10): 95 | b, c, t, h, w = x.shape 96 | 97 | per_row = b 98 | for p in [6, 5, 4, 3, 2]: 99 | if b % p == 0: 100 | per_row = p 101 | break 102 | 103 | os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True) 104 | x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5 105 | x = x.detach().cpu().to(torch.uint8) 106 | x = einops.rearrange(x, '(m n) c t h w -> t (m h) (n w) c', n=per_row) 107 | torchvision.io.write_video(output_filename, x, fps=fps, video_codec='h264', options={'crf': '0'}) 108 | return x 109 | 110 | 111 | def save_bcthw_as_png(x, output_filename): 112 | os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True) 113 | x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5 114 | x = x.detach().cpu().to(torch.uint8) 115 | x = einops.rearrange(x, 'b c t h w -> c (b h) (t w)') 116 | torchvision.io.write_png(x, output_filename) 117 | return output_filename 118 | 119 | 120 | def add_tensors_with_padding(tensor1, tensor2): 121 | if tensor1.shape == tensor2.shape: 122 | return tensor1 + tensor2 123 | 124 | shape1 = tensor1.shape 125 | shape2 = tensor2.shape 126 | 127 | new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2)) 128 | 129 | padded_tensor1 = torch.zeros(new_shape) 130 | padded_tensor2 = torch.zeros(new_shape) 131 | 132 | padded_tensor1[tuple(slice(0, s) for s in shape1)] = tensor1 133 | padded_tensor2[tuple(slice(0, s) for s in shape2)] = tensor2 134 | 135 | result = padded_tensor1 + padded_tensor2 136 | return result 137 | -------------------------------------------------------------------------------- /Paints-UNDO/diffusers_vdm/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import xformers.ops 3 | import torch.nn.functional as F 4 | 5 | from torch import nn 6 | from einops import rearrange, repeat 7 | from functools import partial 8 | from diffusers_vdm.basics import zero_module, checkpoint, default, make_temporal_window 9 | 10 | 11 | def sdp(q, k, v, heads): 12 | b, _, C = q.shape 13 | dim_head = C // heads 14 | 15 | q, k, v = map( 16 | lambda t: t.unsqueeze(3) 17 | .reshape(b, t.shape[1], heads, dim_head) 18 | .permute(0, 2, 1, 3) 19 | .reshape(b * heads, t.shape[1], dim_head) 20 | .contiguous(), 21 | (q, k, v), 22 | ) 23 | 24 | out = xformers.ops.memory_efficient_attention(q, k, v) 25 | 26 | out = ( 27 | out.unsqueeze(0) 28 | .reshape(b, heads, out.shape[1], dim_head) 29 | .permute(0, 2, 1, 3) 30 | .reshape(b, out.shape[1], heads * dim_head) 31 | ) 32 | 33 | return out 34 | 35 | 36 | class 
RelativePosition(nn.Module): 37 | """ https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py """ 38 | 39 | def __init__(self, num_units, max_relative_position): 40 | super().__init__() 41 | self.num_units = num_units 42 | self.max_relative_position = max_relative_position 43 | self.embeddings_table = nn.Parameter(torch.Tensor(max_relative_position * 2 + 1, num_units)) 44 | nn.init.xavier_uniform_(self.embeddings_table) 45 | 46 | def forward(self, length_q, length_k): 47 | device = self.embeddings_table.device 48 | range_vec_q = torch.arange(length_q, device=device) 49 | range_vec_k = torch.arange(length_k, device=device) 50 | distance_mat = range_vec_k[None, :] - range_vec_q[:, None] 51 | distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position) 52 | final_mat = distance_mat_clipped + self.max_relative_position 53 | final_mat = final_mat.long() 54 | embeddings = self.embeddings_table[final_mat] 55 | return embeddings 56 | 57 | 58 | class CrossAttention(nn.Module): 59 | 60 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., 61 | relative_position=False, temporal_length=None, video_length=None, image_cross_attention=False, 62 | image_cross_attention_scale=1.0, image_cross_attention_scale_learnable=False, 63 | text_context_len=77, temporal_window_for_spatial_self_attention=False): 64 | super().__init__() 65 | inner_dim = dim_head * heads 66 | context_dim = default(context_dim, query_dim) 67 | 68 | self.scale = dim_head**-0.5 69 | self.heads = heads 70 | self.dim_head = dim_head 71 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 72 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 73 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 74 | 75 | self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) 76 | 77 | self.is_temporal_attention = temporal_length is not None 78 | 79 | self.relative_position = relative_position 80 | if self.relative_position: 81 | assert self.is_temporal_attention 82 | self.relative_position_k = RelativePosition(num_units=dim_head, max_relative_position=temporal_length) 83 | self.relative_position_v = RelativePosition(num_units=dim_head, max_relative_position=temporal_length) 84 | 85 | self.video_length = video_length 86 | self.temporal_window_for_spatial_self_attention = temporal_window_for_spatial_self_attention 87 | self.temporal_window_type = 'prv' 88 | 89 | self.image_cross_attention = image_cross_attention 90 | self.image_cross_attention_scale = image_cross_attention_scale 91 | self.text_context_len = text_context_len 92 | self.image_cross_attention_scale_learnable = image_cross_attention_scale_learnable 93 | if self.image_cross_attention: 94 | self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False) 95 | self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False) 96 | if image_cross_attention_scale_learnable: 97 | self.register_parameter('alpha', nn.Parameter(torch.tensor(0.)) ) 98 | 99 | def forward(self, x, context=None, mask=None): 100 | if self.is_temporal_attention: 101 | return self.temporal_forward(x, context=context, mask=mask) 102 | else: 103 | return self.spatial_forward(x, context=context, mask=mask) 104 | 105 | def temporal_forward(self, x, context=None, mask=None): 106 | assert mask is None, 'Attention mask not implemented!' 107 | assert context is None, 'Temporal attention only supports self attention!' 
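        # x is shaped (batch * h * w, t, channels) here: TemporalTransformer flattens the
        # spatial grid into the batch dimension before calling this block, so the
        # self-attention below only mixes information across the t frames.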
108 | 109 | q = self.to_q(x) 110 | k = self.to_k(x) 111 | v = self.to_v(x) 112 | 113 | out = sdp(q, k, v, self.heads) 114 | 115 | return self.to_out(out) 116 | 117 | def spatial_forward(self, x, context=None, mask=None): 118 | assert mask is None, 'Attention mask not implemented!' 119 | 120 | spatial_self_attn = (context is None) 121 | k_ip, v_ip, out_ip = None, None, None 122 | 123 | q = self.to_q(x) 124 | context = default(context, x) 125 | 126 | if spatial_self_attn: 127 | k = self.to_k(context) 128 | v = self.to_v(context) 129 | 130 | if self.temporal_window_for_spatial_self_attention: 131 | k = make_temporal_window(k, t=self.video_length, method=self.temporal_window_type) 132 | v = make_temporal_window(v, t=self.video_length, method=self.temporal_window_type) 133 | elif self.image_cross_attention: 134 | context, context_image = context 135 | k = self.to_k(context) 136 | v = self.to_v(context) 137 | k_ip = self.to_k_ip(context_image) 138 | v_ip = self.to_v_ip(context_image) 139 | else: 140 | raise NotImplementedError('Traditional prompt-only attention without IP-Adapter is illegal now.') 141 | 142 | out = sdp(q, k, v, self.heads) 143 | 144 | if k_ip is not None: 145 | out_ip = sdp(q, k_ip, v_ip, self.heads) 146 | 147 | if self.image_cross_attention_scale_learnable: 148 | out = out + self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha) + 1) 149 | else: 150 | out = out + self.image_cross_attention_scale * out_ip 151 | 152 | return self.to_out(out) 153 | 154 | 155 | class BasicTransformerBlock(nn.Module): 156 | 157 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, 158 | disable_self_attn=False, attention_cls=None, video_length=None, image_cross_attention=False, image_cross_attention_scale=1.0, image_cross_attention_scale_learnable=False, text_context_len=77): 159 | super().__init__() 160 | attn_cls = CrossAttention if attention_cls is None else attention_cls 161 | self.disable_self_attn = disable_self_attn 162 | self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, 163 | context_dim=context_dim if self.disable_self_attn else None, video_length=video_length) 164 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 165 | self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout, video_length=video_length, image_cross_attention=image_cross_attention, image_cross_attention_scale=image_cross_attention_scale, image_cross_attention_scale_learnable=image_cross_attention_scale_learnable,text_context_len=text_context_len) 166 | self.image_cross_attention = image_cross_attention 167 | 168 | self.norm1 = nn.LayerNorm(dim) 169 | self.norm2 = nn.LayerNorm(dim) 170 | self.norm3 = nn.LayerNorm(dim) 171 | self.checkpoint = checkpoint 172 | 173 | 174 | def forward(self, x, context=None, mask=None, **kwargs): 175 | ## implementation tricks: because checkpointing doesn't support non-tensor (e.g. 
None or scalar) arguments 176 | input_tuple = (x,) ## should not be (x), otherwise *input_tuple will decouple x into multiple arguments 177 | if context is not None: 178 | input_tuple = (x, context) 179 | if mask is not None: 180 | forward_mask = partial(self._forward, mask=mask) 181 | return checkpoint(forward_mask, (x,), self.parameters(), self.checkpoint) 182 | return checkpoint(self._forward, input_tuple, self.parameters(), self.checkpoint) 183 | 184 | 185 | def _forward(self, x, context=None, mask=None): 186 | x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None, mask=mask) + x 187 | x = self.attn2(self.norm2(x), context=context, mask=mask) + x 188 | x = self.ff(self.norm3(x)) + x 189 | return x 190 | 191 | 192 | class SpatialTransformer(nn.Module): 193 | """ 194 | Transformer block for image-like data in spatial axis. 195 | First, project the input (aka embedding) 196 | and reshape to b, t, d. 197 | Then apply standard transformer action. 198 | Finally, reshape to image 199 | NEW: use_linear for more efficiency instead of the 1x1 convs 200 | """ 201 | 202 | def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None, 203 | use_checkpoint=True, disable_self_attn=False, use_linear=False, video_length=None, 204 | image_cross_attention=False, image_cross_attention_scale_learnable=False): 205 | super().__init__() 206 | self.in_channels = in_channels 207 | inner_dim = n_heads * d_head 208 | self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 209 | if not use_linear: 210 | self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 211 | else: 212 | self.proj_in = nn.Linear(in_channels, inner_dim) 213 | 214 | attention_cls = None 215 | self.transformer_blocks = nn.ModuleList([ 216 | BasicTransformerBlock( 217 | inner_dim, 218 | n_heads, 219 | d_head, 220 | dropout=dropout, 221 | context_dim=context_dim, 222 | disable_self_attn=disable_self_attn, 223 | checkpoint=use_checkpoint, 224 | attention_cls=attention_cls, 225 | video_length=video_length, 226 | image_cross_attention=image_cross_attention, 227 | image_cross_attention_scale_learnable=image_cross_attention_scale_learnable, 228 | ) for d in range(depth) 229 | ]) 230 | if not use_linear: 231 | self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) 232 | else: 233 | self.proj_out = zero_module(nn.Linear(inner_dim, in_channels)) 234 | self.use_linear = use_linear 235 | 236 | 237 | def forward(self, x, context=None, **kwargs): 238 | b, c, h, w = x.shape 239 | x_in = x 240 | x = self.norm(x) 241 | if not self.use_linear: 242 | x = self.proj_in(x) 243 | x = rearrange(x, 'b c h w -> b (h w) c').contiguous() 244 | if self.use_linear: 245 | x = self.proj_in(x) 246 | for i, block in enumerate(self.transformer_blocks): 247 | x = block(x, context=context, **kwargs) 248 | if self.use_linear: 249 | x = self.proj_out(x) 250 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() 251 | if not self.use_linear: 252 | x = self.proj_out(x) 253 | return x + x_in 254 | 255 | 256 | class TemporalTransformer(nn.Module): 257 | """ 258 | Transformer block for image-like data in temporal axis. 259 | First, reshape to b, t, d. 260 | Then apply standard transformer action. 
261 | Finally, reshape to image 262 | """ 263 | def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None, 264 | use_checkpoint=True, use_linear=False, only_self_att=True, causal_attention=False, causal_block_size=1, 265 | relative_position=False, temporal_length=None): 266 | super().__init__() 267 | self.only_self_att = only_self_att 268 | self.relative_position = relative_position 269 | self.causal_attention = causal_attention 270 | self.causal_block_size = causal_block_size 271 | 272 | self.in_channels = in_channels 273 | inner_dim = n_heads * d_head 274 | self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 275 | self.proj_in = nn.Conv1d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 276 | if not use_linear: 277 | self.proj_in = nn.Conv1d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 278 | else: 279 | self.proj_in = nn.Linear(in_channels, inner_dim) 280 | 281 | if relative_position: 282 | assert(temporal_length is not None) 283 | attention_cls = partial(CrossAttention, relative_position=True, temporal_length=temporal_length) 284 | else: 285 | attention_cls = partial(CrossAttention, temporal_length=temporal_length) 286 | if self.causal_attention: 287 | assert(temporal_length is not None) 288 | self.mask = torch.tril(torch.ones([1, temporal_length, temporal_length])) 289 | 290 | if self.only_self_att: 291 | context_dim = None 292 | self.transformer_blocks = nn.ModuleList([ 293 | BasicTransformerBlock( 294 | inner_dim, 295 | n_heads, 296 | d_head, 297 | dropout=dropout, 298 | context_dim=context_dim, 299 | attention_cls=attention_cls, 300 | checkpoint=use_checkpoint) for d in range(depth) 301 | ]) 302 | if not use_linear: 303 | self.proj_out = zero_module(nn.Conv1d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) 304 | else: 305 | self.proj_out = zero_module(nn.Linear(inner_dim, in_channels)) 306 | self.use_linear = use_linear 307 | 308 | def forward(self, x, context=None): 309 | b, c, t, h, w = x.shape 310 | x_in = x 311 | x = self.norm(x) 312 | x = rearrange(x, 'b c t h w -> (b h w) c t').contiguous() 313 | if not self.use_linear: 314 | x = self.proj_in(x) 315 | x = rearrange(x, 'bhw c t -> bhw t c').contiguous() 316 | if self.use_linear: 317 | x = self.proj_in(x) 318 | 319 | temp_mask = None 320 | if self.causal_attention: 321 | # slice the from mask map 322 | temp_mask = self.mask[:,:t,:t].to(x.device) 323 | 324 | if temp_mask is not None: 325 | mask = temp_mask.to(x.device) 326 | mask = repeat(mask, 'l i j -> (l bhw) i j', bhw=b*h*w) 327 | else: 328 | mask = None 329 | 330 | if self.only_self_att: 331 | ## note: if no context is given, cross-attention defaults to self-attention 332 | for i, block in enumerate(self.transformer_blocks): 333 | x = block(x, mask=mask) 334 | x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous() 335 | else: 336 | x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous() 337 | context = rearrange(context, '(b t) l con -> b t l con', t=t).contiguous() 338 | for i, block in enumerate(self.transformer_blocks): 339 | # calculate each batch one by one (since number in shape could not greater then 65,535 for some package) 340 | for j in range(b): 341 | context_j = repeat( 342 | context[j], 343 | 't l con -> (t r) l con', r=(h * w) // t, t=t).contiguous() 344 | ## note: causal mask will not applied in cross-attention case 345 | x[j] = block(x[j], context=context_j) 346 | 347 | if self.use_linear: 348 | x = self.proj_out(x) 349 | x = rearrange(x, 
'b (h w) t c -> b c t h w', h=h, w=w).contiguous() 350 | if not self.use_linear: 351 | x = rearrange(x, 'b hw t c -> (b hw) c t').contiguous() 352 | x = self.proj_out(x) 353 | x = rearrange(x, '(b h w) c t -> b c t h w', b=b, h=h, w=w).contiguous() 354 | 355 | return x + x_in 356 | 357 | 358 | class GEGLU(nn.Module): 359 | def __init__(self, dim_in, dim_out): 360 | super().__init__() 361 | self.proj = nn.Linear(dim_in, dim_out * 2) 362 | 363 | def forward(self, x): 364 | x, gate = self.proj(x).chunk(2, dim=-1) 365 | return x * F.gelu(gate) 366 | 367 | 368 | class FeedForward(nn.Module): 369 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 370 | super().__init__() 371 | inner_dim = int(dim * mult) 372 | dim_out = default(dim_out, dim) 373 | project_in = nn.Sequential( 374 | nn.Linear(dim, inner_dim), 375 | nn.GELU() 376 | ) if not glu else GEGLU(dim, inner_dim) 377 | 378 | self.net = nn.Sequential( 379 | project_in, 380 | nn.Dropout(dropout), 381 | nn.Linear(inner_dim, dim_out) 382 | ) 383 | 384 | def forward(self, x): 385 | return self.net(x) 386 | -------------------------------------------------------------------------------- /Paints-UNDO/diffusers_vdm/basics.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 9 | 10 | 11 | import torch 12 | import torch.nn as nn 13 | import einops 14 | 15 | from inspect import isfunction 16 | 17 | 18 | def zero_module(module): 19 | """ 20 | Zero out the parameters of a module and return it. 21 | """ 22 | for p in module.parameters(): 23 | p.detach().zero_() 24 | return module 25 | 26 | def scale_module(module, scale): 27 | """ 28 | Scale the parameters of a module and return it. 29 | """ 30 | for p in module.parameters(): 31 | p.detach().mul_(scale) 32 | return module 33 | 34 | 35 | def conv_nd(dims, *args, **kwargs): 36 | """ 37 | Create a 1D, 2D, or 3D convolution module. 38 | """ 39 | if dims == 1: 40 | return nn.Conv1d(*args, **kwargs) 41 | elif dims == 2: 42 | return nn.Conv2d(*args, **kwargs) 43 | elif dims == 3: 44 | return nn.Conv3d(*args, **kwargs) 45 | raise ValueError(f"unsupported dimensions: {dims}") 46 | 47 | 48 | def linear(*args, **kwargs): 49 | """ 50 | Create a linear module. 51 | """ 52 | return nn.Linear(*args, **kwargs) 53 | 54 | 55 | def avg_pool_nd(dims, *args, **kwargs): 56 | """ 57 | Create a 1D, 2D, or 3D average pooling module. 58 | """ 59 | if dims == 1: 60 | return nn.AvgPool1d(*args, **kwargs) 61 | elif dims == 2: 62 | return nn.AvgPool2d(*args, **kwargs) 63 | elif dims == 3: 64 | return nn.AvgPool3d(*args, **kwargs) 65 | raise ValueError(f"unsupported dimensions: {dims}") 66 | 67 | 68 | def nonlinearity(type='silu'): 69 | if type == 'silu': 70 | return nn.SiLU() 71 | elif type == 'leaky_relu': 72 | return nn.LeakyReLU() 73 | 74 | 75 | def normalization(channels, num_groups=32): 76 | """ 77 | Make a standard normalization layer. 78 | :param channels: number of input channels. 79 | :return: an nn.Module for normalization. 
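    Illustrative usage (a minimal sketch with an assumed channel count of 320;
    `normalization(320)` is equivalent to `nn.GroupNorm(32, 320)`):

        >>> layer = normalization(320)
        >>> layer(torch.randn(2, 320, 8, 8)).shape   # shape is preserved
        torch.Size([2, 320, 8, 8])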
80 | """ 81 | return nn.GroupNorm(num_groups, channels) 82 | 83 | 84 | def default(val, d): 85 | if exists(val): 86 | return val 87 | return d() if isfunction(d) else d 88 | 89 | 90 | def exists(val): 91 | return val is not None 92 | 93 | 94 | def extract_into_tensor(a, t, x_shape): 95 | b, *_ = t.shape 96 | out = a.gather(-1, t) 97 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 98 | 99 | 100 | def make_temporal_window(x, t, method): 101 | assert method in ['roll', 'prv', 'first'] 102 | 103 | if method == 'roll': 104 | m = einops.rearrange(x, '(b t) d c -> b t d c', t=t) 105 | l = torch.roll(m, shifts=1, dims=1) 106 | r = torch.roll(m, shifts=-1, dims=1) 107 | 108 | recon = torch.cat([l, m, r], dim=2) 109 | del l, m, r 110 | 111 | recon = einops.rearrange(recon, 'b t d c -> (b t) d c') 112 | return recon 113 | 114 | if method == 'prv': 115 | x = einops.rearrange(x, '(b t) d c -> b t d c', t=t) 116 | prv = torch.cat([x[:, :1], x[:, :-1]], dim=1) 117 | 118 | recon = torch.cat([x, prv], dim=2) 119 | del x, prv 120 | 121 | recon = einops.rearrange(recon, 'b t d c -> (b t) d c') 122 | return recon 123 | 124 | if method == 'first': 125 | x = einops.rearrange(x, '(b t) d c -> b t d c', t=t) 126 | prv = x[:, [0], :, :].repeat(1, t, 1, 1) 127 | 128 | recon = torch.cat([x, prv], dim=2) 129 | del x, prv 130 | 131 | recon = einops.rearrange(recon, 'b t d c -> (b t) d c') 132 | return recon 133 | 134 | 135 | def checkpoint(func, inputs, params, flag): 136 | """ 137 | Evaluate a function without caching intermediate activations, allowing for 138 | reduced memory at the expense of extra compute in the backward pass. 139 | :param func: the function to evaluate. 140 | :param inputs: the argument sequence to pass to `func`. 141 | :param params: a sequence of parameters `func` depends on but does not 142 | explicitly take as arguments. 143 | :param flag: if False, disable gradient checkpointing. 144 | """ 145 | if flag: 146 | return torch.utils.checkpoint.checkpoint(func, *inputs, use_reentrant=False) 147 | else: 148 | return func(*inputs) 149 | -------------------------------------------------------------------------------- /Paints-UNDO/diffusers_vdm/dynamic_tsnr_sampler.py: -------------------------------------------------------------------------------- 1 | # everything that can improve v-prediction model 2 | # dynamic scaling + tsnr + beta modifier + dynamic cfg rescale + ... 3 | # written by lvmin at stanford 2024 4 | 5 | import torch 6 | import numpy as np 7 | 8 | from tqdm import tqdm 9 | from functools import partial 10 | from diffusers_vdm.basics import extract_into_tensor 11 | 12 | 13 | to_torch = partial(torch.tensor, dtype=torch.float32) 14 | 15 | 16 | def rescale_zero_terminal_snr(betas): 17 | # Convert betas to alphas_bar_sqrt 18 | alphas = 1.0 - betas 19 | alphas_cumprod = np.cumprod(alphas, axis=0) 20 | alphas_bar_sqrt = np.sqrt(alphas_cumprod) 21 | 22 | # Store old values. 23 | alphas_bar_sqrt_0 = alphas_bar_sqrt[0].copy() 24 | alphas_bar_sqrt_T = alphas_bar_sqrt[-1].copy() 25 | 26 | # Shift so the last timestep is zero. 27 | alphas_bar_sqrt -= alphas_bar_sqrt_T 28 | 29 | # Scale so the first timestep is back to the old value. 
30 | alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) 31 | 32 | # Convert alphas_bar_sqrt to betas 33 | alphas_bar = alphas_bar_sqrt**2 # Revert sqrt 34 | alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod 35 | alphas = np.concatenate([alphas_bar[0:1], alphas]) 36 | betas = 1 - alphas 37 | 38 | return betas 39 | 40 | 41 | def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): 42 | std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) 43 | std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) 44 | 45 | # rescale the results from guidance (fixes overexposure) 46 | noise_pred_rescaled = noise_cfg * (std_text / std_cfg) 47 | 48 | # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images 49 | noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg 50 | 51 | return noise_cfg 52 | 53 | 54 | class SamplerDynamicTSNR(torch.nn.Module): 55 | @torch.no_grad() 56 | def __init__(self, unet, terminal_scale=0.7): 57 | super().__init__() 58 | self.unet = unet 59 | 60 | self.is_v = True 61 | self.n_timestep = 1000 62 | self.guidance_rescale = 0.7 63 | 64 | linear_start = 0.00085 65 | linear_end = 0.012 66 | 67 | betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5, self.n_timestep, dtype=np.float64) ** 2 68 | betas = rescale_zero_terminal_snr(betas) 69 | alphas = 1. - betas 70 | 71 | alphas_cumprod = np.cumprod(alphas, axis=0) 72 | 73 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod).to(unet.device)) 74 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)).to(unet.device)) 75 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. 
- alphas_cumprod)).to(unet.device)) 76 | 77 | # Dynamic TSNR 78 | turning_step = 400 79 | scale_arr = np.concatenate([ 80 | np.linspace(1.0, terminal_scale, turning_step), 81 | np.full(self.n_timestep - turning_step, terminal_scale) 82 | ]) 83 | self.register_buffer('scale_arr', to_torch(scale_arr).to(unet.device)) 84 | 85 | def predict_eps_from_z_and_v(self, x_t, t, v): 86 | return self.sqrt_alphas_cumprod[t] * v + self.sqrt_one_minus_alphas_cumprod[t] * x_t 87 | 88 | def predict_start_from_z_and_v(self, x_t, t, v): 89 | return self.sqrt_alphas_cumprod[t] * x_t - self.sqrt_one_minus_alphas_cumprod[t] * v 90 | 91 | def q_sample(self, x0, t, noise): 92 | return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x0.shape) * x0 + 93 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) 94 | 95 | def get_v(self, x0, t, noise): 96 | return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x0.shape) * noise - 97 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x0.shape) * x0) 98 | 99 | def dynamic_x0_rescale(self, x0, t): 100 | return x0 * extract_into_tensor(self.scale_arr, t, x0.shape) 101 | 102 | @torch.no_grad() 103 | def get_ground_truth(self, x0, noise, t): 104 | x0 = self.dynamic_x0_rescale(x0, t) 105 | xt = self.q_sample(x0, t, noise) 106 | target = self.get_v(x0, t, noise) if self.is_v else noise 107 | return xt, target 108 | 109 | def get_uniform_trailing_steps(self, steps): 110 | c = self.n_timestep / steps 111 | ddim_timesteps = np.flip(np.round(np.arange(self.n_timestep, 0, -c))).astype(np.int64) 112 | steps_out = ddim_timesteps - 1 113 | return torch.tensor(steps_out, device=self.unet.device, dtype=torch.long) 114 | 115 | @torch.no_grad() 116 | def forward(self, latent_shape, steps, extra_args, progress_tqdm=None): 117 | bar = tqdm if progress_tqdm is None else progress_tqdm 118 | 119 | eta = 1.0 120 | 121 | timesteps = self.get_uniform_trailing_steps(steps) 122 | timesteps_prev = torch.nn.functional.pad(timesteps[:-1], pad=(1, 0)) 123 | 124 | x = torch.randn(latent_shape, device=self.unet.device, dtype=self.unet.dtype) 125 | 126 | alphas = self.alphas_cumprod[timesteps] 127 | alphas_prev = self.alphas_cumprod[timesteps_prev] 128 | scale_arr = self.scale_arr[timesteps] 129 | scale_arr_prev = self.scale_arr[timesteps_prev] 130 | 131 | sqrt_one_minus_alphas = torch.sqrt(1 - alphas) 132 | sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy())) 133 | 134 | s_in = x.new_ones((x.shape[0])) 135 | s_x = x.new_ones((x.shape[0], ) + (1, ) * (x.ndim - 1)) 136 | for i in bar(range(len(timesteps))): 137 | index = len(timesteps) - 1 - i 138 | t = timesteps[index].item() 139 | 140 | model_output = self.model_apply(x, t * s_in, **extra_args) 141 | 142 | if self.is_v: 143 | e_t = self.predict_eps_from_z_and_v(x, t, model_output) 144 | else: 145 | e_t = model_output 146 | 147 | a_prev = alphas_prev[index].item() * s_x 148 | sigma_t = sigmas[index].item() * s_x 149 | 150 | if self.is_v: 151 | pred_x0 = self.predict_start_from_z_and_v(x, t, model_output) 152 | else: 153 | a_t = alphas[index].item() * s_x 154 | sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x 155 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 156 | 157 | # dynamic rescale 158 | scale_t = scale_arr[index].item() * s_x 159 | prev_scale_t = scale_arr_prev[index].item() * s_x 160 | rescale = (prev_scale_t / scale_t) 161 | pred_x0 = pred_x0 * rescale 162 | 163 | dir_xt = (1. 
- a_prev - sigma_t ** 2).sqrt() * e_t 164 | noise = sigma_t * torch.randn_like(x) 165 | x = a_prev.sqrt() * pred_x0 + dir_xt + noise 166 | 167 | return x 168 | 169 | @torch.no_grad() 170 | def model_apply(self, x, t, **extra_args): 171 | x = x.to(device=self.unet.device, dtype=self.unet.dtype) 172 | cfg_scale = extra_args['cfg_scale'] 173 | p = self.unet(x, t, **extra_args['positive']) 174 | n = self.unet(x, t, **extra_args['negative']) 175 | o = n + cfg_scale * (p - n) 176 | o_better = rescale_noise_cfg(o, p, guidance_rescale=self.guidance_rescale) 177 | return o_better 178 | -------------------------------------------------------------------------------- /Paints-UNDO/diffusers_vdm/improved_clip_vision.py: -------------------------------------------------------------------------------- 1 | # A CLIP Vision supporting arbitrary aspect ratios, by lllyasviel 2 | # The input range is changed to [-1, 1] rather than [0, 1] !!!! (same as VAE's range) 3 | 4 | import torch 5 | import types 6 | import einops 7 | 8 | from abc import ABCMeta 9 | from transformers import CLIPVisionModelWithProjection 10 | 11 | 12 | def preprocess(image): 13 | mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=image.device, dtype=image.dtype)[None, :, None, None] 14 | std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=image.device, dtype=image.dtype)[None, :, None, None] 15 | 16 | scale = 16 / min(image.shape[2], image.shape[3]) 17 | image = torch.nn.functional.interpolate( 18 | image, 19 | size=(14 * round(scale * image.shape[2]), 14 * round(scale * image.shape[3])), 20 | mode="bicubic", 21 | antialias=True 22 | ) 23 | 24 | return (image - mean) / std 25 | 26 | 27 | def arbitrary_positional_encoding(p, H, W): 28 | weight = p.weight 29 | cls = weight[:1] 30 | pos = weight[1:] 31 | pos = einops.rearrange(pos, '(H W) C -> 1 C H W', H=16, W=16) 32 | pos = torch.nn.functional.interpolate(pos, size=(H, W), mode="nearest") 33 | pos = einops.rearrange(pos, '1 C H W -> (H W) C') 34 | weight = torch.cat([cls, pos])[None] 35 | return weight 36 | 37 | 38 | def improved_clipvision_embedding_forward(self, pixel_values): 39 | pixel_values = pixel_values * 0.5 + 0.5 40 | pixel_values = preprocess(pixel_values) 41 | batch_size = pixel_values.shape[0] 42 | target_dtype = self.patch_embedding.weight.dtype 43 | patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) 44 | B, C, H, W = patch_embeds.shape 45 | patch_embeds = einops.rearrange(patch_embeds, 'B C H W -> B (H W) C') 46 | class_embeds = self.class_embedding.expand(batch_size, 1, -1) 47 | embeddings = torch.cat([class_embeds, patch_embeds], dim=1) 48 | embeddings = embeddings + arbitrary_positional_encoding(self.position_embedding, H, W) 49 | return embeddings 50 | 51 | 52 | class ImprovedCLIPVisionModelWithProjection(CLIPVisionModelWithProjection, metaclass=ABCMeta): 53 | def __init__(self, config): 54 | super().__init__(config) 55 | self.vision_model.embeddings.forward = types.MethodType( 56 | improved_clipvision_embedding_forward, 57 | self.vision_model.embeddings 58 | ) 59 | -------------------------------------------------------------------------------- /Paints-UNDO/diffusers_vdm/pipeline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import einops 4 | 5 | from diffusers import DiffusionPipeline 6 | from transformers import CLIPTextModel, CLIPTokenizer 7 | from huggingface_hub import snapshot_download 8 | from diffusers_vdm.vae import VideoAutoencoderKL 9 | from 
diffusers_vdm.projection import Resampler 10 | from diffusers_vdm.unet import UNet3DModel 11 | from diffusers_vdm.improved_clip_vision import ImprovedCLIPVisionModelWithProjection 12 | from diffusers_vdm.dynamic_tsnr_sampler import SamplerDynamicTSNR 13 | 14 | 15 | class LatentVideoDiffusionPipeline(DiffusionPipeline): 16 | def __init__(self, tokenizer, text_encoder, image_encoder, vae, image_projection, unet, fp16=True, eval=True): 17 | super().__init__() 18 | 19 | self.loading_components = dict( 20 | vae=vae, 21 | text_encoder=text_encoder, 22 | tokenizer=tokenizer, 23 | unet=unet, 24 | image_encoder=image_encoder, 25 | image_projection=image_projection 26 | ) 27 | 28 | for k, v in self.loading_components.items(): 29 | setattr(self, k, v) 30 | 31 | if fp16: 32 | self.vae.half() 33 | self.text_encoder.half() 34 | self.unet.half() 35 | self.image_encoder.half() 36 | self.image_projection.half() 37 | 38 | self.vae.requires_grad_(False) 39 | self.text_encoder.requires_grad_(False) 40 | self.image_encoder.requires_grad_(False) 41 | 42 | self.vae.eval() 43 | self.text_encoder.eval() 44 | self.image_encoder.eval() 45 | 46 | if eval: 47 | self.unet.eval() 48 | self.image_projection.eval() 49 | else: 50 | self.unet.train() 51 | self.image_projection.train() 52 | 53 | def to(self, *args, **kwargs): 54 | for k, v in self.loading_components.items(): 55 | if hasattr(v, 'to'): 56 | v.to(*args, **kwargs) 57 | return self 58 | 59 | def save_pretrained(self, save_directory, **kwargs): 60 | for k, v in self.loading_components.items(): 61 | folder = os.path.join(save_directory, k) 62 | os.makedirs(folder, exist_ok=True) 63 | v.save_pretrained(folder) 64 | return 65 | 66 | @classmethod 67 | def from_pretrained(cls, repo_id, fp16=True, eval=True, token=None): 68 | local_folder = snapshot_download(repo_id=repo_id, token=token) 69 | return cls( 70 | tokenizer=CLIPTokenizer.from_pretrained(os.path.join(local_folder, "tokenizer")), 71 | text_encoder=CLIPTextModel.from_pretrained(os.path.join(local_folder, "text_encoder")), 72 | image_encoder=ImprovedCLIPVisionModelWithProjection.from_pretrained(os.path.join(local_folder, "image_encoder")), 73 | vae=VideoAutoencoderKL.from_pretrained(os.path.join(local_folder, "vae")), 74 | image_projection=Resampler.from_pretrained(os.path.join(local_folder, "image_projection")), 75 | unet=UNet3DModel.from_pretrained(os.path.join(local_folder, "unet")), 76 | fp16=fp16, 77 | eval=eval 78 | ) 79 | 80 | @torch.inference_mode() 81 | def encode_cropped_prompt_77tokens(self, prompt: str): 82 | cond_ids = self.tokenizer(prompt, 83 | padding="max_length", 84 | max_length=self.tokenizer.model_max_length, 85 | truncation=True, 86 | return_tensors="pt").input_ids.to(self.text_encoder.device) 87 | cond = self.text_encoder(cond_ids, attention_mask=None).last_hidden_state 88 | return cond 89 | 90 | @torch.inference_mode() 91 | def encode_clip_vision(self, frames): 92 | b, c, t, h, w = frames.shape 93 | frames = einops.rearrange(frames, 'b c t h w -> (b t) c h w') 94 | clipvision_embed = self.image_encoder(frames).last_hidden_state 95 | clipvision_embed = einops.rearrange(clipvision_embed, '(b t) d c -> b t d c', t=t) 96 | return clipvision_embed 97 | 98 | @torch.inference_mode() 99 | def encode_latents(self, videos, return_hidden_states=True): 100 | b, c, t, h, w = videos.shape 101 | x = einops.rearrange(videos, 'b c t h w -> (b t) c h w') 102 | encoder_posterior, hidden_states = self.vae.encode(x, return_hidden_states=return_hidden_states) 103 | z = encoder_posterior.mode() * 
self.vae.scale_factor 104 | z = einops.rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t) 105 | 106 | if not return_hidden_states: 107 | return z 108 | 109 | hidden_states = [einops.rearrange(h, '(b t) c h w -> b c t h w', b=b) for h in hidden_states] 110 | hidden_states = [h[:, :, [0, -1], :, :] for h in hidden_states] # only need first and last 111 | 112 | return z, hidden_states 113 | 114 | @torch.inference_mode() 115 | def decode_latents(self, latents, hidden_states): 116 | B, C, T, H, W = latents.shape 117 | latents = einops.rearrange(latents, 'b c t h w -> (b t) c h w') 118 | latents = latents.to(device=self.vae.device, dtype=self.vae.dtype) / self.vae.scale_factor 119 | pixels = self.vae.decode(latents, ref_context=hidden_states, timesteps=T) 120 | pixels = einops.rearrange(pixels, '(b t) c h w -> b c t h w', b=B, t=T) 121 | return pixels 122 | 123 | @torch.inference_mode() 124 | def __call__( 125 | self, 126 | batch_size: int = 1, 127 | steps: int = 50, 128 | guidance_scale: float = 5.0, 129 | positive_text_cond = None, 130 | negative_text_cond = None, 131 | positive_image_cond = None, 132 | negative_image_cond = None, 133 | concat_cond = None, 134 | fs = 3, 135 | progress_tqdm = None, 136 | ): 137 | unet_is_training = self.unet.training 138 | 139 | if unet_is_training: 140 | self.unet.eval() 141 | 142 | device = self.unet.device 143 | dtype = self.unet.dtype 144 | dynamic_tsnr_model = SamplerDynamicTSNR(self.unet) 145 | 146 | # Batch 147 | 148 | concat_cond = concat_cond.repeat(batch_size, 1, 1, 1, 1).to(device=device, dtype=dtype) # b, c, t, h, w 149 | positive_text_cond = positive_text_cond.repeat(batch_size, 1, 1).to(concat_cond) # b, f, c 150 | negative_text_cond = negative_text_cond.repeat(batch_size, 1, 1).to(concat_cond) # b, f, c 151 | positive_image_cond = positive_image_cond.repeat(batch_size, 1, 1, 1).to(concat_cond) # b, t, l, c 152 | negative_image_cond = negative_image_cond.repeat(batch_size, 1, 1, 1).to(concat_cond) 153 | 154 | if isinstance(fs, torch.Tensor): 155 | fs = fs.repeat(batch_size, ).to(dtype=torch.long, device=device) # b 156 | else: 157 | fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=device) # b 158 | 159 | # Initial latents 160 | 161 | latent_shape = concat_cond.shape 162 | 163 | # Feeds 164 | 165 | sampler_kwargs = dict( 166 | cfg_scale=guidance_scale, 167 | positive=dict( 168 | context_text=positive_text_cond, 169 | context_img=positive_image_cond, 170 | fs=fs, 171 | concat_cond=concat_cond 172 | ), 173 | negative=dict( 174 | context_text=negative_text_cond, 175 | context_img=negative_image_cond, 176 | fs=fs, 177 | concat_cond=concat_cond 178 | ) 179 | ) 180 | 181 | # Sample 182 | 183 | results = dynamic_tsnr_model(latent_shape, steps, extra_args=sampler_kwargs, progress_tqdm=progress_tqdm) 184 | 185 | if unet_is_training: 186 | self.unet.train() 187 | 188 | return results 189 | -------------------------------------------------------------------------------- /Paints-UNDO/diffusers_vdm/projection.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py 2 | # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py 3 | # and https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py 4 | 5 | 6 | import math 7 | import torch 8 | import einops 9 | import torch.nn as nn 10 | 11 | from huggingface_hub import PyTorchModelHubMixin 12 | 13 | 14 | class 
ImageProjModel(nn.Module): 15 | """Projection Model""" 16 | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): 17 | super().__init__() 18 | self.cross_attention_dim = cross_attention_dim 19 | self.clip_extra_context_tokens = clip_extra_context_tokens 20 | self.proj = nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) 21 | self.norm = nn.LayerNorm(cross_attention_dim) 22 | 23 | def forward(self, image_embeds): 24 | #embeds = image_embeds 25 | embeds = image_embeds.type(list(self.proj.parameters())[0].dtype) 26 | clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim) 27 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens) 28 | return clip_extra_context_tokens 29 | 30 | 31 | # FFN 32 | def FeedForward(dim, mult=4): 33 | inner_dim = int(dim * mult) 34 | return nn.Sequential( 35 | nn.LayerNorm(dim), 36 | nn.Linear(dim, inner_dim, bias=False), 37 | nn.GELU(), 38 | nn.Linear(inner_dim, dim, bias=False), 39 | ) 40 | 41 | 42 | def reshape_tensor(x, heads): 43 | bs, length, width = x.shape 44 | #(bs, length, width) --> (bs, length, n_heads, dim_per_head) 45 | x = x.view(bs, length, heads, -1) 46 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) 47 | x = x.transpose(1, 2) 48 | # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) 49 | x = x.reshape(bs, heads, length, -1) 50 | return x 51 | 52 | 53 | class PerceiverAttention(nn.Module): 54 | def __init__(self, *, dim, dim_head=64, heads=8): 55 | super().__init__() 56 | self.scale = dim_head**-0.5 57 | self.dim_head = dim_head 58 | self.heads = heads 59 | inner_dim = dim_head * heads 60 | 61 | self.norm1 = nn.LayerNorm(dim) 62 | self.norm2 = nn.LayerNorm(dim) 63 | 64 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 65 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 66 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 67 | 68 | 69 | def forward(self, x, latents): 70 | """ 71 | Args: 72 | x (torch.Tensor): image features 73 | shape (b, n1, D) 74 | latent (torch.Tensor): latent features 75 | shape (b, n2, D) 76 | """ 77 | x = self.norm1(x) 78 | latents = self.norm2(latents) 79 | 80 | b, l, _ = latents.shape 81 | 82 | q = self.to_q(latents) 83 | kv_input = torch.cat((x, latents), dim=-2) 84 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 85 | 86 | q = reshape_tensor(q, self.heads) 87 | k = reshape_tensor(k, self.heads) 88 | v = reshape_tensor(v, self.heads) 89 | 90 | # attention 91 | scale = 1 / math.sqrt(math.sqrt(self.dim_head)) 92 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards 93 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 94 | out = weight @ v 95 | 96 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1) 97 | 98 | return self.to_out(out) 99 | 100 | 101 | class Resampler(nn.Module, PyTorchModelHubMixin): 102 | def __init__( 103 | self, 104 | dim=1024, 105 | depth=8, 106 | dim_head=64, 107 | heads=16, 108 | num_queries=8, 109 | embedding_dim=768, 110 | output_dim=1024, 111 | ff_mult=4, 112 | video_length=16, 113 | input_frames_length=2, 114 | ): 115 | super().__init__() 116 | self.num_queries = num_queries 117 | self.video_length = video_length 118 | 119 | self.latents = nn.Parameter(torch.randn(1, num_queries * video_length, dim) / dim**0.5) 120 | self.input_pos = nn.Parameter(torch.zeros(1, input_frames_length, 1, embedding_dim)) 121 | 122 | self.proj_in 
= nn.Linear(embedding_dim, dim) 123 | self.proj_out = nn.Linear(dim, output_dim) 124 | self.norm_out = nn.LayerNorm(output_dim) 125 | 126 | self.layers = nn.ModuleList([]) 127 | for _ in range(depth): 128 | self.layers.append( 129 | nn.ModuleList( 130 | [ 131 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 132 | FeedForward(dim=dim, mult=ff_mult), 133 | ] 134 | ) 135 | ) 136 | 137 | def forward(self, x): 138 | latents = self.latents.repeat(x.size(0), 1, 1) 139 | 140 | x = x + self.input_pos 141 | x = einops.rearrange(x, 'b ti d c -> b (ti d) c') 142 | x = self.proj_in(x) 143 | 144 | for attn, ff in self.layers: 145 | latents = attn(x, latents) + latents 146 | latents = ff(latents) + latents 147 | 148 | latents = self.proj_out(latents) 149 | latents = self.norm_out(latents) 150 | 151 | latents = einops.rearrange(latents, 'b (to l) c -> b to l c', to=self.video_length) 152 | return latents 153 | 154 | @property 155 | def device(self): 156 | return next(self.parameters()).device 157 | 158 | @property 159 | def dtype(self): 160 | return next(self.parameters()).dtype 161 | -------------------------------------------------------------------------------- /Paints-UNDO/diffusers_vdm/unet.py: -------------------------------------------------------------------------------- 1 | # https://github.com/AILab-CVC/VideoCrafter 2 | # https://github.com/Doubiiu/DynamiCrafter 3 | # https://github.com/ToonCrafter/ToonCrafter 4 | # Then edited by lllyasviel 5 | 6 | from functools import partial 7 | from abc import abstractmethod 8 | import torch 9 | import math 10 | import torch.nn as nn 11 | from einops import rearrange, repeat 12 | import torch.nn.functional as F 13 | from diffusers_vdm.basics import checkpoint 14 | from diffusers_vdm.basics import ( 15 | zero_module, 16 | conv_nd, 17 | linear, 18 | avg_pool_nd, 19 | normalization 20 | ) 21 | from diffusers_vdm.attention import SpatialTransformer, TemporalTransformer 22 | from huggingface_hub import PyTorchModelHubMixin 23 | 24 | 25 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 26 | """ 27 | Create sinusoidal timestep embeddings. 28 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 29 | These may be fractional. 30 | :param dim: the dimension of the output. 31 | :param max_period: controls the minimum frequency of the embeddings. 32 | :return: an [N x dim] Tensor of positional embeddings. 33 | """ 34 | if not repeat_only: 35 | half = dim // 2 36 | freqs = torch.exp( 37 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 38 | ).to(device=timesteps.device) 39 | args = timesteps[:, None].float() * freqs[None] 40 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 41 | if dim % 2: 42 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 43 | else: 44 | embedding = repeat(timesteps, 'b -> b d', d=dim) 45 | return embedding 46 | 47 | 48 | class TimestepBlock(nn.Module): 49 | """ 50 | Any module where forward() takes timestep embeddings as a second argument. 51 | """ 52 | 53 | @abstractmethod 54 | def forward(self, x, emb): 55 | """ 56 | Apply the module to `x` given `emb` timestep embeddings. 57 | """ 58 | 59 | 60 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock): 61 | """ 62 | A sequential module that passes timestep embeddings to the children that 63 | support it as an extra input. 
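    Dispatch is by layer type (see `forward` below): `TimestepBlock` children receive
    `emb` and `batch_size`, `SpatialTransformer` children receive `context`, and
    `TemporalTransformer` children receive `context` after the input is reshaped from
    (b*f, c, h, w) to (b, c, f, h, w); any other layer receives only `x`.

    Illustrative construction (a sketch mirroring the first input block built further
    down in this file; the tensor shapes are assumptions):

        block = TimestepEmbedSequential(conv_nd(2, in_channels, model_channels, 3, padding=1))
        h = block(x, emb, context=context, batch_size=b)   # x: (b*t, in_channels, h, w)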
64 | """ 65 | 66 | def forward(self, x, emb, context=None, batch_size=None): 67 | for layer in self: 68 | if isinstance(layer, TimestepBlock): 69 | x = layer(x, emb, batch_size=batch_size) 70 | elif isinstance(layer, SpatialTransformer): 71 | x = layer(x, context) 72 | elif isinstance(layer, TemporalTransformer): 73 | x = rearrange(x, '(b f) c h w -> b c f h w', b=batch_size) 74 | x = layer(x, context) 75 | x = rearrange(x, 'b c f h w -> (b f) c h w') 76 | else: 77 | x = layer(x) 78 | return x 79 | 80 | 81 | class Downsample(nn.Module): 82 | """ 83 | A downsampling layer with an optional convolution. 84 | :param channels: channels in the inputs and outputs. 85 | :param use_conv: a bool determining if a convolution is applied. 86 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 87 | downsampling occurs in the inner-two dimensions. 88 | """ 89 | 90 | def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): 91 | super().__init__() 92 | self.channels = channels 93 | self.out_channels = out_channels or channels 94 | self.use_conv = use_conv 95 | self.dims = dims 96 | stride = 2 if dims != 3 else (1, 2, 2) 97 | if use_conv: 98 | self.op = conv_nd( 99 | dims, self.channels, self.out_channels, 3, stride=stride, padding=padding 100 | ) 101 | else: 102 | assert self.channels == self.out_channels 103 | self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) 104 | 105 | def forward(self, x): 106 | assert x.shape[1] == self.channels 107 | return self.op(x) 108 | 109 | 110 | class Upsample(nn.Module): 111 | """ 112 | An upsampling layer with an optional convolution. 113 | :param channels: channels in the inputs and outputs. 114 | :param use_conv: a bool determining if a convolution is applied. 115 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 116 | upsampling occurs in the inner-two dimensions. 117 | """ 118 | 119 | def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): 120 | super().__init__() 121 | self.channels = channels 122 | self.out_channels = out_channels or channels 123 | self.use_conv = use_conv 124 | self.dims = dims 125 | if use_conv: 126 | self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) 127 | 128 | def forward(self, x): 129 | assert x.shape[1] == self.channels 130 | if self.dims == 3: 131 | x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode='nearest') 132 | else: 133 | x = F.interpolate(x, scale_factor=2, mode='nearest') 134 | if self.use_conv: 135 | x = self.conv(x) 136 | return x 137 | 138 | 139 | class ResBlock(TimestepBlock): 140 | """ 141 | A residual block that can optionally change the number of channels. 142 | :param channels: the number of input channels. 143 | :param emb_channels: the number of timestep embedding channels. 144 | :param dropout: the rate of dropout. 145 | :param out_channels: if specified, the number of out channels. 146 | :param use_conv: if True and out_channels is specified, use a spatial 147 | convolution instead of a smaller 1x1 convolution to change the 148 | channels in the skip connection. 149 | :param dims: determines if the signal is 1D, 2D, or 3D. 150 | :param up: if True, use this block for upsampling. 151 | :param down: if True, use this block for downsampling. 152 | :param use_temporal_conv: if True, use the temporal convolution. 153 | :param use_image_dataset: if True, the temporal parameters will not be optimized. 
154 | """ 155 | 156 | def __init__( 157 | self, 158 | channels, 159 | emb_channels, 160 | dropout, 161 | out_channels=None, 162 | use_scale_shift_norm=False, 163 | dims=2, 164 | use_checkpoint=False, 165 | use_conv=False, 166 | up=False, 167 | down=False, 168 | use_temporal_conv=False, 169 | tempspatial_aware=False 170 | ): 171 | super().__init__() 172 | self.channels = channels 173 | self.emb_channels = emb_channels 174 | self.dropout = dropout 175 | self.out_channels = out_channels or channels 176 | self.use_conv = use_conv 177 | self.use_checkpoint = use_checkpoint 178 | self.use_scale_shift_norm = use_scale_shift_norm 179 | self.use_temporal_conv = use_temporal_conv 180 | 181 | self.in_layers = nn.Sequential( 182 | normalization(channels), 183 | nn.SiLU(), 184 | conv_nd(dims, channels, self.out_channels, 3, padding=1), 185 | ) 186 | 187 | self.updown = up or down 188 | 189 | if up: 190 | self.h_upd = Upsample(channels, False, dims) 191 | self.x_upd = Upsample(channels, False, dims) 192 | elif down: 193 | self.h_upd = Downsample(channels, False, dims) 194 | self.x_upd = Downsample(channels, False, dims) 195 | else: 196 | self.h_upd = self.x_upd = nn.Identity() 197 | 198 | self.emb_layers = nn.Sequential( 199 | nn.SiLU(), 200 | nn.Linear( 201 | emb_channels, 202 | 2 * self.out_channels if use_scale_shift_norm else self.out_channels, 203 | ), 204 | ) 205 | self.out_layers = nn.Sequential( 206 | normalization(self.out_channels), 207 | nn.SiLU(), 208 | nn.Dropout(p=dropout), 209 | zero_module(nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)), 210 | ) 211 | 212 | if self.out_channels == channels: 213 | self.skip_connection = nn.Identity() 214 | elif use_conv: 215 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1) 216 | else: 217 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) 218 | 219 | if self.use_temporal_conv: 220 | self.temopral_conv = TemporalConvBlock( 221 | self.out_channels, 222 | self.out_channels, 223 | dropout=0.1, 224 | spatial_aware=tempspatial_aware 225 | ) 226 | 227 | def forward(self, x, emb, batch_size=None): 228 | """ 229 | Apply the block to a Tensor, conditioned on a timestep embedding. 230 | :param x: an [N x C x ...] Tensor of features. 231 | :param emb: an [N x emb_channels] Tensor of timestep embeddings. 232 | :return: an [N x C x ...] Tensor of outputs. 
233 | """ 234 | input_tuple = (x, emb) 235 | if batch_size: 236 | forward_batchsize = partial(self._forward, batch_size=batch_size) 237 | return checkpoint(forward_batchsize, input_tuple, self.parameters(), self.use_checkpoint) 238 | return checkpoint(self._forward, input_tuple, self.parameters(), self.use_checkpoint) 239 | 240 | def _forward(self, x, emb, batch_size=None): 241 | if self.updown: 242 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] 243 | h = in_rest(x) 244 | h = self.h_upd(h) 245 | x = self.x_upd(x) 246 | h = in_conv(h) 247 | else: 248 | h = self.in_layers(x) 249 | emb_out = self.emb_layers(emb).type(h.dtype) 250 | while len(emb_out.shape) < len(h.shape): 251 | emb_out = emb_out[..., None] 252 | if self.use_scale_shift_norm: 253 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:] 254 | scale, shift = torch.chunk(emb_out, 2, dim=1) 255 | h = out_norm(h) * (1 + scale) + shift 256 | h = out_rest(h) 257 | else: 258 | h = h + emb_out 259 | h = self.out_layers(h) 260 | h = self.skip_connection(x) + h 261 | 262 | if self.use_temporal_conv and batch_size: 263 | h = rearrange(h, '(b t) c h w -> b c t h w', b=batch_size) 264 | h = self.temopral_conv(h) 265 | h = rearrange(h, 'b c t h w -> (b t) c h w') 266 | return h 267 | 268 | 269 | class TemporalConvBlock(nn.Module): 270 | """ 271 | Adapted from modelscope: https://github.com/modelscope/modelscope/blob/master/modelscope/models/multi_modal/video_synthesis/unet_sd.py 272 | """ 273 | 274 | def __init__(self, in_channels, out_channels=None, dropout=0.0, spatial_aware=False): 275 | super(TemporalConvBlock, self).__init__() 276 | if out_channels is None: 277 | out_channels = in_channels 278 | self.in_channels = in_channels 279 | self.out_channels = out_channels 280 | th_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 3, 1) 281 | th_padding_shape = (1, 0, 0) if not spatial_aware else (1, 1, 0) 282 | tw_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 1, 3) 283 | tw_padding_shape = (1, 0, 0) if not spatial_aware else (1, 0, 1) 284 | 285 | # conv layers 286 | self.conv1 = nn.Sequential( 287 | nn.GroupNorm(32, in_channels), nn.SiLU(), 288 | nn.Conv3d(in_channels, out_channels, th_kernel_shape, padding=th_padding_shape)) 289 | self.conv2 = nn.Sequential( 290 | nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout), 291 | nn.Conv3d(out_channels, in_channels, tw_kernel_shape, padding=tw_padding_shape)) 292 | self.conv3 = nn.Sequential( 293 | nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout), 294 | nn.Conv3d(out_channels, in_channels, th_kernel_shape, padding=th_padding_shape)) 295 | self.conv4 = nn.Sequential( 296 | nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout), 297 | nn.Conv3d(out_channels, in_channels, tw_kernel_shape, padding=tw_padding_shape)) 298 | 299 | # zero out the last layer params,so the conv block is identity 300 | nn.init.zeros_(self.conv4[-1].weight) 301 | nn.init.zeros_(self.conv4[-1].bias) 302 | 303 | def forward(self, x): 304 | identity = x 305 | x = self.conv1(x) 306 | x = self.conv2(x) 307 | x = self.conv3(x) 308 | x = self.conv4(x) 309 | 310 | return identity + x 311 | 312 | 313 | class UNet3DModel(nn.Module, PyTorchModelHubMixin): 314 | """ 315 | The full UNet model with attention and timestep embedding. 316 | :param in_channels: in_channels in the input Tensor. 317 | :param model_channels: base channel count for the model. 318 | :param out_channels: channels in the output Tensor. 319 | :param num_res_blocks: number of residual blocks per downsample. 
320 | :param attention_resolutions: a collection of downsample rates at which 321 | attention will take place. May be a set, list, or tuple. 322 | For example, if this contains 4, then at 4x downsampling, attention 323 | will be used. 324 | :param dropout: the dropout probability. 325 | :param channel_mult: channel multiplier for each level of the UNet. 326 | :param conv_resample: if True, use learned convolutions for upsampling and 327 | downsampling. 328 | :param dims: determines if the signal is 1D, 2D, or 3D. 329 | :param num_classes: if specified (as an int), then this model will be 330 | class-conditional with `num_classes` classes. 331 | :param use_checkpoint: use gradient checkpointing to reduce memory usage. 332 | :param num_heads: the number of attention heads in each attention layer. 333 | :param num_heads_channels: if specified, ignore num_heads and instead use 334 | a fixed channel width per attention head. 335 | :param num_heads_upsample: works with num_heads to set a different number 336 | of heads for upsampling. Deprecated. 337 | :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. 338 | :param resblock_updown: use residual blocks for up/downsampling. 339 | :param use_new_attention_order: use a different attention pattern for potentially 340 | increased efficiency. 341 | """ 342 | 343 | def __init__(self, 344 | in_channels, 345 | model_channels, 346 | out_channels, 347 | num_res_blocks, 348 | attention_resolutions, 349 | dropout=0.0, 350 | channel_mult=(1, 2, 4, 8), 351 | conv_resample=True, 352 | dims=2, 353 | context_dim=None, 354 | use_scale_shift_norm=False, 355 | resblock_updown=False, 356 | num_heads=-1, 357 | num_head_channels=-1, 358 | transformer_depth=1, 359 | use_linear=False, 360 | temporal_conv=False, 361 | tempspatial_aware=False, 362 | temporal_attention=True, 363 | use_relative_position=True, 364 | use_causal_attention=False, 365 | temporal_length=None, 366 | addition_attention=False, 367 | temporal_selfatt_only=True, 368 | image_cross_attention=False, 369 | image_cross_attention_scale_learnable=False, 370 | default_fs=4, 371 | fs_condition=False, 372 | ): 373 | super(UNet3DModel, self).__init__() 374 | if num_heads == -1: 375 | assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' 376 | if num_head_channels == -1: 377 | assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' 378 | 379 | self.in_channels = in_channels 380 | self.model_channels = model_channels 381 | self.out_channels = out_channels 382 | self.num_res_blocks = num_res_blocks 383 | self.attention_resolutions = attention_resolutions 384 | self.dropout = dropout 385 | self.channel_mult = channel_mult 386 | self.conv_resample = conv_resample 387 | self.temporal_attention = temporal_attention 388 | time_embed_dim = model_channels * 4 389 | self.use_checkpoint = use_checkpoint = False # moved to self.enable_gradient_checkpointing() 390 | temporal_self_att_only = True 391 | self.addition_attention = addition_attention 392 | self.temporal_length = temporal_length 393 | self.image_cross_attention = image_cross_attention 394 | self.image_cross_attention_scale_learnable = image_cross_attention_scale_learnable 395 | self.default_fs = default_fs 396 | self.fs_condition = fs_condition 397 | 398 | ## Time embedding blocks 399 | self.time_embed = nn.Sequential( 400 | linear(model_channels, time_embed_dim), 401 | nn.SiLU(), 402 | linear(time_embed_dim, time_embed_dim), 403 | ) 404 | if fs_condition: 405 | self.fps_embedding = 
nn.Sequential( 406 | linear(model_channels, time_embed_dim), 407 | nn.SiLU(), 408 | linear(time_embed_dim, time_embed_dim), 409 | ) 410 | nn.init.zeros_(self.fps_embedding[-1].weight) 411 | nn.init.zeros_(self.fps_embedding[-1].bias) 412 | ## Input Block 413 | self.input_blocks = nn.ModuleList( 414 | [ 415 | TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1)) 416 | ] 417 | ) 418 | if self.addition_attention: 419 | self.init_attn = TimestepEmbedSequential( 420 | TemporalTransformer( 421 | model_channels, 422 | n_heads=8, 423 | d_head=num_head_channels, 424 | depth=transformer_depth, 425 | context_dim=context_dim, 426 | use_checkpoint=use_checkpoint, only_self_att=temporal_selfatt_only, 427 | causal_attention=False, relative_position=use_relative_position, 428 | temporal_length=temporal_length)) 429 | 430 | input_block_chans = [model_channels] 431 | ch = model_channels 432 | ds = 1 433 | for level, mult in enumerate(channel_mult): 434 | for _ in range(num_res_blocks): 435 | layers = [ 436 | ResBlock(ch, time_embed_dim, dropout, 437 | out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint, 438 | use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware, 439 | use_temporal_conv=temporal_conv 440 | ) 441 | ] 442 | ch = mult * model_channels 443 | if ds in attention_resolutions: 444 | if num_head_channels == -1: 445 | dim_head = ch // num_heads 446 | else: 447 | num_heads = ch // num_head_channels 448 | dim_head = num_head_channels 449 | layers.append( 450 | SpatialTransformer(ch, num_heads, dim_head, 451 | depth=transformer_depth, context_dim=context_dim, use_linear=use_linear, 452 | use_checkpoint=use_checkpoint, disable_self_attn=False, 453 | video_length=temporal_length, 454 | image_cross_attention=self.image_cross_attention, 455 | image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable, 456 | ) 457 | ) 458 | if self.temporal_attention: 459 | layers.append( 460 | TemporalTransformer(ch, num_heads, dim_head, 461 | depth=transformer_depth, context_dim=context_dim, use_linear=use_linear, 462 | use_checkpoint=use_checkpoint, only_self_att=temporal_self_att_only, 463 | causal_attention=use_causal_attention, 464 | relative_position=use_relative_position, 465 | temporal_length=temporal_length 466 | ) 467 | ) 468 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 469 | input_block_chans.append(ch) 470 | if level != len(channel_mult) - 1: 471 | out_ch = ch 472 | self.input_blocks.append( 473 | TimestepEmbedSequential( 474 | ResBlock(ch, time_embed_dim, dropout, 475 | out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint, 476 | use_scale_shift_norm=use_scale_shift_norm, 477 | down=True 478 | ) 479 | if resblock_updown 480 | else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) 481 | ) 482 | ) 483 | ch = out_ch 484 | input_block_chans.append(ch) 485 | ds *= 2 486 | 487 | if num_head_channels == -1: 488 | dim_head = ch // num_heads 489 | else: 490 | num_heads = ch // num_head_channels 491 | dim_head = num_head_channels 492 | layers = [ 493 | ResBlock(ch, time_embed_dim, dropout, 494 | dims=dims, use_checkpoint=use_checkpoint, 495 | use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware, 496 | use_temporal_conv=temporal_conv 497 | ), 498 | SpatialTransformer(ch, num_heads, dim_head, 499 | depth=transformer_depth, context_dim=context_dim, use_linear=use_linear, 500 | use_checkpoint=use_checkpoint, disable_self_attn=False, video_length=temporal_length, 501 
| image_cross_attention=self.image_cross_attention, 502 | image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable 503 | ) 504 | ] 505 | if self.temporal_attention: 506 | layers.append( 507 | TemporalTransformer(ch, num_heads, dim_head, 508 | depth=transformer_depth, context_dim=context_dim, use_linear=use_linear, 509 | use_checkpoint=use_checkpoint, only_self_att=temporal_self_att_only, 510 | causal_attention=use_causal_attention, relative_position=use_relative_position, 511 | temporal_length=temporal_length 512 | ) 513 | ) 514 | layers.append( 515 | ResBlock(ch, time_embed_dim, dropout, 516 | dims=dims, use_checkpoint=use_checkpoint, 517 | use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware, 518 | use_temporal_conv=temporal_conv 519 | ) 520 | ) 521 | 522 | ## Middle Block 523 | self.middle_block = TimestepEmbedSequential(*layers) 524 | 525 | ## Output Block 526 | self.output_blocks = nn.ModuleList([]) 527 | for level, mult in list(enumerate(channel_mult))[::-1]: 528 | for i in range(num_res_blocks + 1): 529 | ich = input_block_chans.pop() 530 | layers = [ 531 | ResBlock(ch + ich, time_embed_dim, dropout, 532 | out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint, 533 | use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware, 534 | use_temporal_conv=temporal_conv 535 | ) 536 | ] 537 | ch = model_channels * mult 538 | if ds in attention_resolutions: 539 | if num_head_channels == -1: 540 | dim_head = ch // num_heads 541 | else: 542 | num_heads = ch // num_head_channels 543 | dim_head = num_head_channels 544 | layers.append( 545 | SpatialTransformer(ch, num_heads, dim_head, 546 | depth=transformer_depth, context_dim=context_dim, use_linear=use_linear, 547 | use_checkpoint=use_checkpoint, disable_self_attn=False, 548 | video_length=temporal_length, 549 | image_cross_attention=self.image_cross_attention, 550 | image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable 551 | ) 552 | ) 553 | if self.temporal_attention: 554 | layers.append( 555 | TemporalTransformer(ch, num_heads, dim_head, 556 | depth=transformer_depth, context_dim=context_dim, use_linear=use_linear, 557 | use_checkpoint=use_checkpoint, only_self_att=temporal_self_att_only, 558 | causal_attention=use_causal_attention, 559 | relative_position=use_relative_position, 560 | temporal_length=temporal_length 561 | ) 562 | ) 563 | if level and i == num_res_blocks: 564 | out_ch = ch 565 | layers.append( 566 | ResBlock(ch, time_embed_dim, dropout, 567 | out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint, 568 | use_scale_shift_norm=use_scale_shift_norm, 569 | up=True 570 | ) 571 | if resblock_updown 572 | else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) 573 | ) 574 | ds //= 2 575 | self.output_blocks.append(TimestepEmbedSequential(*layers)) 576 | 577 | self.out = nn.Sequential( 578 | normalization(ch), 579 | nn.SiLU(), 580 | zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), 581 | ) 582 | 583 | @property 584 | def device(self): 585 | return next(self.parameters()).device 586 | 587 | @property 588 | def dtype(self): 589 | return next(self.parameters()).dtype 590 | 591 | def forward(self, x, timesteps, context_text=None, context_img=None, concat_cond=None, fs=None, **kwargs): 592 | b, _, t, _, _ = x.shape 593 | 594 | t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).type(x.dtype) 595 | emb = self.time_embed(t_emb) 596 | 597 | context_text = 
context_text.repeat_interleave(repeats=t, dim=0) 598 | context_img = rearrange(context_img, 'b t l c -> (b t) l c') 599 | 600 | context = (context_text, context_img) 601 | 602 | emb = emb.repeat_interleave(repeats=t, dim=0) 603 | 604 | if concat_cond is not None: 605 | x = torch.cat([x, concat_cond], dim=1) 606 | 607 | ## always in shape (b t) c h w, except for temporal layer 608 | x = rearrange(x, 'b c t h w -> (b t) c h w') 609 | 610 | ## combine emb 611 | if self.fs_condition: 612 | if fs is None: 613 | fs = torch.tensor( 614 | [self.default_fs] * b, dtype=torch.long, device=x.device) 615 | fs_emb = timestep_embedding(fs, self.model_channels, repeat_only=False).type(x.dtype) 616 | 617 | fs_embed = self.fps_embedding(fs_emb) 618 | fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0) 619 | emb = emb + fs_embed 620 | 621 | h = x 622 | hs = [] 623 | for id, module in enumerate(self.input_blocks): 624 | h = module(h, emb, context=context, batch_size=b) 625 | if id == 0 and self.addition_attention: 626 | h = self.init_attn(h, emb, context=context, batch_size=b) 627 | hs.append(h) 628 | 629 | h = self.middle_block(h, emb, context=context, batch_size=b) 630 | 631 | for module in self.output_blocks: 632 | h = torch.cat([h, hs.pop()], dim=1) 633 | h = module(h, emb, context=context, batch_size=b) 634 | h = h.type(x.dtype) 635 | y = self.out(h) 636 | 637 | y = rearrange(y, '(b t) c h w -> b c t h w', b=b) 638 | return y 639 | 640 | def enable_gradient_checkpointing(self, enable=True, verbose=False): 641 | for k, v in self.named_modules(): 642 | if hasattr(v, 'checkpoint'): 643 | v.checkpoint = enable 644 | if verbose: 645 | print(f'{k}.checkpoint = {enable}') 646 | if hasattr(v, 'use_checkpoint'): 647 | v.use_checkpoint = enable 648 | if verbose: 649 | print(f'{k}.use_checkpoint = {enable}') 650 | return 651 | -------------------------------------------------------------------------------- /Paints-UNDO/diffusers_vdm/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import einops 5 | import torchvision 6 | 7 | 8 | def resize_and_center_crop(image, target_width, target_height, interpolation=cv2.INTER_AREA): 9 | original_height, original_width = image.shape[:2] 10 | k = max(target_height / original_height, target_width / original_width) 11 | new_width = int(round(original_width * k)) 12 | new_height = int(round(original_height * k)) 13 | resized_image = cv2.resize(image, (new_width, new_height), interpolation=interpolation) 14 | x_start = (new_width - target_width) // 2 15 | y_start = (new_height - target_height) // 2 16 | cropped_image = resized_image[y_start:y_start + target_height, x_start:x_start + target_width] 17 | return cropped_image 18 | 19 | 20 | def save_bcthw_as_mp4(x, output_filename, fps=10): 21 | b, c, t, h, w = x.shape 22 | 23 | per_row = b 24 | for p in [6, 5, 4, 3, 2]: 25 | if b % p == 0: 26 | per_row = p 27 | break 28 | 29 | os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True) 30 | x = torch.clamp(x.float(), -1., 1.) 
* 127.5 + 127.5 31 | x = x.detach().cpu().to(torch.uint8) 32 | x = einops.rearrange(x, '(m n) c t h w -> t (m h) (n w) c', n=per_row) 33 | torchvision.io.write_video(output_filename, x, fps=fps, video_codec='h264', options={'crf': '1'}) 34 | return x 35 | 36 | 37 | def save_bcthw_as_png(x, output_filename): 38 | os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True) 39 | x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5 40 | x = x.detach().cpu().to(torch.uint8) 41 | x = einops.rearrange(x, 'b c t h w -> c (b h) (t w)') 42 | torchvision.io.write_png(x, output_filename) 43 | return output_filename 44 | -------------------------------------------------------------------------------- /Paints-UNDO/diffusers_vdm/vae.py: -------------------------------------------------------------------------------- 1 | # video VAE with many components from lots of repos 2 | # collected by lvmin 3 | 4 | 5 | import torch 6 | import xformers.ops 7 | import torch.nn as nn 8 | 9 | from einops import rearrange, repeat 10 | from diffusers_vdm.basics import default, exists, zero_module, conv_nd, linear, normalization 11 | from diffusers_vdm.unet import Upsample, Downsample 12 | from huggingface_hub import PyTorchModelHubMixin 13 | 14 | 15 | def chunked_attention(q, k, v, batch_chunk=0): 16 | # if batch_chunk > 0 and not torch.is_grad_enabled(): 17 | # batch_size = q.size(0) 18 | # chunks = [slice(i, i + batch_chunk) for i in range(0, batch_size, batch_chunk)] 19 | # 20 | # out_chunks = [] 21 | # for chunk in chunks: 22 | # q_chunk = q[chunk] 23 | # k_chunk = k[chunk] 24 | # v_chunk = v[chunk] 25 | # 26 | # out_chunk = torch.nn.functional.scaled_dot_product_attention( 27 | # q_chunk, k_chunk, v_chunk, attn_mask=None 28 | # ) 29 | # out_chunks.append(out_chunk) 30 | # 31 | # out = torch.cat(out_chunks, dim=0) 32 | # else: 33 | # out = torch.nn.functional.scaled_dot_product_attention( 34 | # q, k, v, attn_mask=None 35 | # ) 36 | out = xformers.ops.memory_efficient_attention(q, k, v) 37 | return out 38 | 39 | 40 | def nonlinearity(x): 41 | return x * torch.sigmoid(x) 42 | 43 | 44 | def GroupNorm(in_channels, num_groups=32): 45 | return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) 46 | 47 | 48 | class DiagonalGaussianDistribution: 49 | def __init__(self, parameters, deterministic=False): 50 | self.parameters = parameters 51 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 52 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 53 | self.deterministic = deterministic 54 | self.std = torch.exp(0.5 * self.logvar) 55 | self.var = torch.exp(self.logvar) 56 | if self.deterministic: 57 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 58 | 59 | def sample(self, noise=None): 60 | if noise is None: 61 | noise = torch.randn(self.mean.shape) 62 | 63 | x = self.mean + self.std * noise.to(device=self.parameters.device) 64 | return x 65 | 66 | def mode(self): 67 | return self.mean 68 | 69 | 70 | class EncoderDownSampleBlock(nn.Module): 71 | def __init__(self, in_channels, with_conv): 72 | super().__init__() 73 | self.with_conv = with_conv 74 | self.in_channels = in_channels 75 | if self.with_conv: 76 | self.conv = torch.nn.Conv2d(in_channels, 77 | in_channels, 78 | kernel_size=3, 79 | stride=2, 80 | padding=0) 81 | 82 | def forward(self, x): 83 | if self.with_conv: 84 | pad = (0, 1, 0, 1) 85 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0) 86 | x = self.conv(x) 87 | else: 88 | x = 
torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) 89 | return x 90 | 91 | 92 | class ResnetBlock(nn.Module): 93 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, 94 | dropout, temb_channels=512): 95 | super().__init__() 96 | self.in_channels = in_channels 97 | out_channels = in_channels if out_channels is None else out_channels 98 | self.out_channels = out_channels 99 | self.use_conv_shortcut = conv_shortcut 100 | 101 | self.norm1 = GroupNorm(in_channels) 102 | self.conv1 = torch.nn.Conv2d(in_channels, 103 | out_channels, 104 | kernel_size=3, 105 | stride=1, 106 | padding=1) 107 | if temb_channels > 0: 108 | self.temb_proj = torch.nn.Linear(temb_channels, 109 | out_channels) 110 | self.norm2 = GroupNorm(out_channels) 111 | self.dropout = torch.nn.Dropout(dropout) 112 | self.conv2 = torch.nn.Conv2d(out_channels, 113 | out_channels, 114 | kernel_size=3, 115 | stride=1, 116 | padding=1) 117 | if self.in_channels != self.out_channels: 118 | if self.use_conv_shortcut: 119 | self.conv_shortcut = torch.nn.Conv2d(in_channels, 120 | out_channels, 121 | kernel_size=3, 122 | stride=1, 123 | padding=1) 124 | else: 125 | self.nin_shortcut = torch.nn.Conv2d(in_channels, 126 | out_channels, 127 | kernel_size=1, 128 | stride=1, 129 | padding=0) 130 | 131 | def forward(self, x, temb): 132 | h = x 133 | h = self.norm1(h) 134 | h = nonlinearity(h) 135 | h = self.conv1(h) 136 | 137 | if temb is not None: 138 | h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] 139 | 140 | h = self.norm2(h) 141 | h = nonlinearity(h) 142 | h = self.dropout(h) 143 | h = self.conv2(h) 144 | 145 | if self.in_channels != self.out_channels: 146 | if self.use_conv_shortcut: 147 | x = self.conv_shortcut(x) 148 | else: 149 | x = self.nin_shortcut(x) 150 | 151 | return x + h 152 | 153 | 154 | class Encoder(nn.Module): 155 | def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, 156 | attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, 157 | resolution, z_channels, double_z=True, **kwargs): 158 | super().__init__() 159 | self.ch = ch 160 | self.temb_ch = 0 161 | self.num_resolutions = len(ch_mult) 162 | self.num_res_blocks = num_res_blocks 163 | self.resolution = resolution 164 | self.in_channels = in_channels 165 | 166 | # downsampling 167 | self.conv_in = torch.nn.Conv2d(in_channels, 168 | self.ch, 169 | kernel_size=3, 170 | stride=1, 171 | padding=1) 172 | 173 | curr_res = resolution 174 | in_ch_mult = (1,) + tuple(ch_mult) 175 | self.in_ch_mult = in_ch_mult 176 | self.down = nn.ModuleList() 177 | for i_level in range(self.num_resolutions): 178 | block = nn.ModuleList() 179 | attn = nn.ModuleList() 180 | block_in = ch * in_ch_mult[i_level] 181 | block_out = ch * ch_mult[i_level] 182 | for i_block in range(self.num_res_blocks): 183 | block.append(ResnetBlock(in_channels=block_in, 184 | out_channels=block_out, 185 | temb_channels=self.temb_ch, 186 | dropout=dropout)) 187 | block_in = block_out 188 | if curr_res in attn_resolutions: 189 | attn.append(Attention(block_in)) 190 | down = nn.Module() 191 | down.block = block 192 | down.attn = attn 193 | if i_level != self.num_resolutions - 1: 194 | down.downsample = EncoderDownSampleBlock(block_in, resamp_with_conv) 195 | curr_res = curr_res // 2 196 | self.down.append(down) 197 | 198 | # middle 199 | self.mid = nn.Module() 200 | self.mid.block_1 = ResnetBlock(in_channels=block_in, 201 | out_channels=block_in, 202 | temb_channels=self.temb_ch, 203 | dropout=dropout) 204 | self.mid.attn_1 = Attention(block_in) 205 | 
self.mid.block_2 = ResnetBlock(in_channels=block_in, 206 | out_channels=block_in, 207 | temb_channels=self.temb_ch, 208 | dropout=dropout) 209 | 210 | # end 211 | self.norm_out = GroupNorm(block_in) 212 | self.conv_out = torch.nn.Conv2d(block_in, 213 | 2 * z_channels if double_z else z_channels, 214 | kernel_size=3, 215 | stride=1, 216 | padding=1) 217 | 218 | def forward(self, x, return_hidden_states=False): 219 | # timestep embedding 220 | temb = None 221 | 222 | # print(f'encoder-input={x.shape}') 223 | # downsampling 224 | hs = [self.conv_in(x)] 225 | 226 | ## if we return hidden states for decoder usage, we will store them in a list 227 | if return_hidden_states: 228 | hidden_states = [] 229 | # print(f'encoder-conv in feat={hs[0].shape}') 230 | for i_level in range(self.num_resolutions): 231 | for i_block in range(self.num_res_blocks): 232 | h = self.down[i_level].block[i_block](hs[-1], temb) 233 | # print(f'encoder-down feat={h.shape}') 234 | if len(self.down[i_level].attn) > 0: 235 | h = self.down[i_level].attn[i_block](h) 236 | hs.append(h) 237 | if return_hidden_states: 238 | hidden_states.append(h) 239 | if i_level != self.num_resolutions - 1: 240 | # print(f'encoder-downsample (input)={hs[-1].shape}') 241 | hs.append(self.down[i_level].downsample(hs[-1])) 242 | # print(f'encoder-downsample (output)={hs[-1].shape}') 243 | if return_hidden_states: 244 | hidden_states.append(hs[0]) 245 | # middle 246 | h = hs[-1] 247 | h = self.mid.block_1(h, temb) 248 | # print(f'encoder-mid1 feat={h.shape}') 249 | h = self.mid.attn_1(h) 250 | h = self.mid.block_2(h, temb) 251 | # print(f'encoder-mid2 feat={h.shape}') 252 | 253 | # end 254 | h = self.norm_out(h) 255 | h = nonlinearity(h) 256 | h = self.conv_out(h) 257 | # print(f'end feat={h.shape}') 258 | if return_hidden_states: 259 | return h, hidden_states 260 | else: 261 | return h 262 | 263 | 264 | class ConvCombiner(nn.Module): 265 | def __init__(self, ch): 266 | super().__init__() 267 | self.conv = nn.Conv2d(ch, ch, 1, padding=0) 268 | 269 | nn.init.zeros_(self.conv.weight) 270 | nn.init.zeros_(self.conv.bias) 271 | 272 | def forward(self, x, context): 273 | ## x: b c h w, context: b c 2 h w 274 | b, c, l, h, w = context.shape 275 | bt, c, h, w = x.shape 276 | context = rearrange(context, "b c l h w -> (b l) c h w") 277 | context = self.conv(context) 278 | context = rearrange(context, "(b l) c h w -> b c l h w", l=l) 279 | x = rearrange(x, "(b t) c h w -> b c t h w", t=bt // b) 280 | x[:, :, 0] = x[:, :, 0] + context[:, :, 0] 281 | x[:, :, -1] = x[:, :, -1] + context[:, :, -1] 282 | x = rearrange(x, "b c t h w -> (b t) c h w") 283 | return x 284 | 285 | 286 | class AttentionCombiner(nn.Module): 287 | def __init__( 288 | self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs 289 | ): 290 | super().__init__() 291 | 292 | inner_dim = dim_head * heads 293 | context_dim = default(context_dim, query_dim) 294 | 295 | self.heads = heads 296 | self.dim_head = dim_head 297 | 298 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 299 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 300 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 301 | 302 | self.to_out = nn.Sequential( 303 | nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) 304 | ) 305 | self.attention_op = None 306 | 307 | self.norm = GroupNorm(query_dim) 308 | nn.init.zeros_(self.to_out[0].weight) 309 | nn.init.zeros_(self.to_out[0].bias) 310 | 311 | def forward( 312 | self, 313 | x, 314 | context=None, 315 | mask=None, 316 | ): 317 | 
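        # --- Added explanatory note (not part of the original source) ---
        # Cross-attention between per-frame decoder features and the two reference frames.
        # Shape walk-through with hypothetical sizes (b=1 video, t=16 frames, c=query_dim,
        # h, w = spatial size):
        #   x:       (b*t, c, h, w)   features of every generated frame
        #   context: (b, c, 2, h, w)  features of the starting and ending frame
        #            (see the ref_context comment in VideoDecoder.forward)
        # Queries are taken from every frame; keys/values come only from the two reference
        # frames and are repeated across t below. Because to_out is zero-initialized in
        # __init__, a freshly initialized block acts as an identity residual (x + 0), and
        # the loaded weights determine how much reference detail is injected.
        # -----------------------------------------------------------------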
bt, c, h, w = x.shape 318 | h_ = self.norm(x) 319 | h_ = rearrange(h_, "b c h w -> b (h w) c") 320 | q = self.to_q(h_) 321 | 322 | b, c, l, h, w = context.shape 323 | context = rearrange(context, "b c l h w -> (b l) (h w) c") 324 | k = self.to_k(context) 325 | v = self.to_v(context) 326 | 327 | t = bt // b 328 | k = repeat(k, "(b l) d c -> (b t) (l d) c", l=l, t=t) 329 | v = repeat(v, "(b l) d c -> (b t) (l d) c", l=l, t=t) 330 | 331 | b, _, _ = q.shape 332 | q, k, v = map( 333 | lambda t: t.unsqueeze(3) 334 | .reshape(b, t.shape[1], self.heads, self.dim_head) 335 | .permute(0, 2, 1, 3) 336 | .reshape(b * self.heads, t.shape[1], self.dim_head) 337 | .contiguous(), 338 | (q, k, v), 339 | ) 340 | 341 | out = chunked_attention( 342 | q, k, v, batch_chunk=1 343 | ) 344 | 345 | if exists(mask): 346 | raise NotImplementedError 347 | 348 | out = ( 349 | out.unsqueeze(0) 350 | .reshape(b, self.heads, out.shape[1], self.dim_head) 351 | .permute(0, 2, 1, 3) 352 | .reshape(b, out.shape[1], self.heads * self.dim_head) 353 | ) 354 | out = self.to_out(out) 355 | out = rearrange(out, "bt (h w) c -> bt c h w", h=h, w=w, c=c) 356 | return x + out 357 | 358 | 359 | class Attention(nn.Module): 360 | def __init__(self, in_channels): 361 | super().__init__() 362 | self.in_channels = in_channels 363 | 364 | self.norm = GroupNorm(in_channels) 365 | self.q = torch.nn.Conv2d( 366 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 367 | ) 368 | self.k = torch.nn.Conv2d( 369 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 370 | ) 371 | self.v = torch.nn.Conv2d( 372 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 373 | ) 374 | self.proj_out = torch.nn.Conv2d( 375 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 376 | ) 377 | 378 | def attention(self, h_: torch.Tensor) -> torch.Tensor: 379 | h_ = self.norm(h_) 380 | q = self.q(h_) 381 | k = self.k(h_) 382 | v = self.v(h_) 383 | 384 | # compute attention 385 | B, C, H, W = q.shape 386 | q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v)) 387 | 388 | q, k, v = map( 389 | lambda t: t.unsqueeze(3) 390 | .reshape(B, t.shape[1], 1, C) 391 | .permute(0, 2, 1, 3) 392 | .reshape(B * 1, t.shape[1], C) 393 | .contiguous(), 394 | (q, k, v), 395 | ) 396 | 397 | out = chunked_attention( 398 | q, k, v, batch_chunk=1 399 | ) 400 | 401 | out = ( 402 | out.unsqueeze(0) 403 | .reshape(B, 1, out.shape[1], C) 404 | .permute(0, 2, 1, 3) 405 | .reshape(B, out.shape[1], C) 406 | ) 407 | return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C) 408 | 409 | def forward(self, x, **kwargs): 410 | h_ = x 411 | h_ = self.attention(h_) 412 | h_ = self.proj_out(h_) 413 | return x + h_ 414 | 415 | 416 | class VideoDecoder(nn.Module): 417 | def __init__( 418 | self, 419 | *, 420 | ch, 421 | out_ch, 422 | ch_mult=(1, 2, 4, 8), 423 | num_res_blocks, 424 | attn_resolutions, 425 | dropout=0.0, 426 | resamp_with_conv=True, 427 | in_channels, 428 | resolution, 429 | z_channels, 430 | give_pre_end=False, 431 | tanh_out=False, 432 | use_linear_attn=False, 433 | attn_level=[2, 3], 434 | video_kernel_size=[3, 1, 1], 435 | alpha: float = 0.0, 436 | merge_strategy: str = "learned", 437 | **kwargs, 438 | ): 439 | super().__init__() 440 | self.video_kernel_size = video_kernel_size 441 | self.alpha = alpha 442 | self.merge_strategy = merge_strategy 443 | self.ch = ch 444 | self.temb_ch = 0 445 | self.num_resolutions = len(ch_mult) 446 | self.num_res_blocks = num_res_blocks 447 | self.resolution = resolution 448 | self.in_channels = 
in_channels 449 | self.give_pre_end = give_pre_end 450 | self.tanh_out = tanh_out 451 | self.attn_level = attn_level 452 | # compute in_ch_mult, block_in and curr_res at lowest res 453 | in_ch_mult = (1,) + tuple(ch_mult) 454 | block_in = ch * ch_mult[self.num_resolutions - 1] 455 | curr_res = resolution // 2 ** (self.num_resolutions - 1) 456 | self.z_shape = (1, z_channels, curr_res, curr_res) 457 | 458 | # z to block_in 459 | self.conv_in = torch.nn.Conv2d( 460 | z_channels, block_in, kernel_size=3, stride=1, padding=1 461 | ) 462 | 463 | # middle 464 | self.mid = nn.Module() 465 | self.mid.block_1 = VideoResBlock( 466 | in_channels=block_in, 467 | out_channels=block_in, 468 | temb_channels=self.temb_ch, 469 | dropout=dropout, 470 | video_kernel_size=self.video_kernel_size, 471 | alpha=self.alpha, 472 | merge_strategy=self.merge_strategy, 473 | ) 474 | self.mid.attn_1 = Attention(block_in) 475 | self.mid.block_2 = VideoResBlock( 476 | in_channels=block_in, 477 | out_channels=block_in, 478 | temb_channels=self.temb_ch, 479 | dropout=dropout, 480 | video_kernel_size=self.video_kernel_size, 481 | alpha=self.alpha, 482 | merge_strategy=self.merge_strategy, 483 | ) 484 | 485 | # upsampling 486 | self.up = nn.ModuleList() 487 | self.attn_refinement = nn.ModuleList() 488 | for i_level in reversed(range(self.num_resolutions)): 489 | block = nn.ModuleList() 490 | attn = nn.ModuleList() 491 | block_out = ch * ch_mult[i_level] 492 | for i_block in range(self.num_res_blocks + 1): 493 | block.append( 494 | VideoResBlock( 495 | in_channels=block_in, 496 | out_channels=block_out, 497 | temb_channels=self.temb_ch, 498 | dropout=dropout, 499 | video_kernel_size=self.video_kernel_size, 500 | alpha=self.alpha, 501 | merge_strategy=self.merge_strategy, 502 | ) 503 | ) 504 | block_in = block_out 505 | if curr_res in attn_resolutions: 506 | attn.append(Attention(block_in)) 507 | up = nn.Module() 508 | up.block = block 509 | up.attn = attn 510 | if i_level != 0: 511 | up.upsample = Upsample(block_in, resamp_with_conv) 512 | curr_res = curr_res * 2 513 | self.up.insert(0, up) # prepend to get consistent order 514 | 515 | if i_level in self.attn_level: 516 | self.attn_refinement.insert(0, AttentionCombiner(block_in)) 517 | else: 518 | self.attn_refinement.insert(0, ConvCombiner(block_in)) 519 | # end 520 | self.norm_out = GroupNorm(block_in) 521 | self.attn_refinement.append(ConvCombiner(block_in)) 522 | self.conv_out = DecoderConv3D( 523 | block_in, out_ch, kernel_size=3, stride=1, padding=1, video_kernel_size=self.video_kernel_size 524 | ) 525 | 526 | def forward(self, z, ref_context=None, **kwargs): 527 | ## ref_context: b c 2 h w, 2 means starting and ending frame 528 | # assert z.shape[1:] == self.z_shape[1:] 529 | self.last_z_shape = z.shape 530 | # timestep embedding 531 | temb = None 532 | 533 | # z to block_in 534 | h = self.conv_in(z) 535 | 536 | # middle 537 | h = self.mid.block_1(h, temb, **kwargs) 538 | h = self.mid.attn_1(h, **kwargs) 539 | h = self.mid.block_2(h, temb, **kwargs) 540 | 541 | # upsampling 542 | for i_level in reversed(range(self.num_resolutions)): 543 | for i_block in range(self.num_res_blocks + 1): 544 | h = self.up[i_level].block[i_block](h, temb, **kwargs) 545 | if len(self.up[i_level].attn) > 0: 546 | h = self.up[i_level].attn[i_block](h, **kwargs) 547 | if ref_context: 548 | h = self.attn_refinement[i_level](x=h, context=ref_context[i_level]) 549 | if i_level != 0: 550 | h = self.up[i_level].upsample(h) 551 | 552 | # end 553 | if self.give_pre_end: 554 | return h 555 | 556 | h 
= self.norm_out(h) 557 | h = nonlinearity(h) 558 | if ref_context: 559 | # print(h.shape, ref_context[i_level].shape) #torch.Size([8, 128, 256, 256]) torch.Size([1, 128, 2, 256, 256]) 560 | h = self.attn_refinement[-1](x=h, context=ref_context[-1]) 561 | h = self.conv_out(h, **kwargs) 562 | if self.tanh_out: 563 | h = torch.tanh(h) 564 | return h 565 | 566 | 567 | class TimeStackBlock(torch.nn.Module): 568 | def __init__( 569 | self, 570 | channels: int, 571 | emb_channels: int, 572 | dropout: float, 573 | out_channels: int = None, 574 | use_conv: bool = False, 575 | use_scale_shift_norm: bool = False, 576 | dims: int = 2, 577 | use_checkpoint: bool = False, 578 | up: bool = False, 579 | down: bool = False, 580 | kernel_size: int = 3, 581 | exchange_temb_dims: bool = False, 582 | skip_t_emb: bool = False, 583 | ): 584 | super().__init__() 585 | self.channels = channels 586 | self.emb_channels = emb_channels 587 | self.dropout = dropout 588 | self.out_channels = out_channels or channels 589 | self.use_conv = use_conv 590 | self.use_checkpoint = use_checkpoint 591 | self.use_scale_shift_norm = use_scale_shift_norm 592 | self.exchange_temb_dims = exchange_temb_dims 593 | 594 | if isinstance(kernel_size, list): 595 | padding = [k // 2 for k in kernel_size] 596 | else: 597 | padding = kernel_size // 2 598 | 599 | self.in_layers = nn.Sequential( 600 | normalization(channels), 601 | nn.SiLU(), 602 | conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding), 603 | ) 604 | 605 | self.updown = up or down 606 | 607 | if up: 608 | self.h_upd = Upsample(channels, False, dims) 609 | self.x_upd = Upsample(channels, False, dims) 610 | elif down: 611 | self.h_upd = Downsample(channels, False, dims) 612 | self.x_upd = Downsample(channels, False, dims) 613 | else: 614 | self.h_upd = self.x_upd = nn.Identity() 615 | 616 | self.skip_t_emb = skip_t_emb 617 | self.emb_out_channels = ( 618 | 2 * self.out_channels if use_scale_shift_norm else self.out_channels 619 | ) 620 | if self.skip_t_emb: 621 | # print(f"Skipping timestep embedding in {self.__class__.__name__}") 622 | assert not self.use_scale_shift_norm 623 | self.emb_layers = None 624 | self.exchange_temb_dims = False 625 | else: 626 | self.emb_layers = nn.Sequential( 627 | nn.SiLU(), 628 | linear( 629 | emb_channels, 630 | self.emb_out_channels, 631 | ), 632 | ) 633 | 634 | self.out_layers = nn.Sequential( 635 | normalization(self.out_channels), 636 | nn.SiLU(), 637 | nn.Dropout(p=dropout), 638 | zero_module( 639 | conv_nd( 640 | dims, 641 | self.out_channels, 642 | self.out_channels, 643 | kernel_size, 644 | padding=padding, 645 | ) 646 | ), 647 | ) 648 | 649 | if self.out_channels == channels: 650 | self.skip_connection = nn.Identity() 651 | elif use_conv: 652 | self.skip_connection = conv_nd( 653 | dims, channels, self.out_channels, kernel_size, padding=padding 654 | ) 655 | else: 656 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) 657 | 658 | def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: 659 | if self.updown: 660 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] 661 | h = in_rest(x) 662 | h = self.h_upd(h) 663 | x = self.x_upd(x) 664 | h = in_conv(h) 665 | else: 666 | h = self.in_layers(x) 667 | 668 | if self.skip_t_emb: 669 | emb_out = torch.zeros_like(h) 670 | else: 671 | emb_out = self.emb_layers(emb).type(h.dtype) 672 | while len(emb_out.shape) < len(h.shape): 673 | emb_out = emb_out[..., None] 674 | if self.use_scale_shift_norm: 675 | out_norm, out_rest = self.out_layers[0], 
self.out_layers[1:] 676 | scale, shift = torch.chunk(emb_out, 2, dim=1) 677 | h = out_norm(h) * (1 + scale) + shift 678 | h = out_rest(h) 679 | else: 680 | if self.exchange_temb_dims: 681 | emb_out = rearrange(emb_out, "b t c ... -> b c t ...") 682 | h = h + emb_out 683 | h = self.out_layers(h) 684 | return self.skip_connection(x) + h 685 | 686 | 687 | class VideoResBlock(ResnetBlock): 688 | def __init__( 689 | self, 690 | out_channels, 691 | *args, 692 | dropout=0.0, 693 | video_kernel_size=3, 694 | alpha=0.0, 695 | merge_strategy="learned", 696 | **kwargs, 697 | ): 698 | super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs) 699 | if video_kernel_size is None: 700 | video_kernel_size = [3, 1, 1] 701 | self.time_stack = TimeStackBlock( 702 | channels=out_channels, 703 | emb_channels=0, 704 | dropout=dropout, 705 | dims=3, 706 | use_scale_shift_norm=False, 707 | use_conv=False, 708 | up=False, 709 | down=False, 710 | kernel_size=video_kernel_size, 711 | use_checkpoint=True, 712 | skip_t_emb=True, 713 | ) 714 | 715 | self.merge_strategy = merge_strategy 716 | if self.merge_strategy == "fixed": 717 | self.register_buffer("mix_factor", torch.Tensor([alpha])) 718 | elif self.merge_strategy == "learned": 719 | self.register_parameter( 720 | "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) 721 | ) 722 | else: 723 | raise ValueError(f"unknown merge strategy {self.merge_strategy}") 724 | 725 | def get_alpha(self, bs): 726 | if self.merge_strategy == "fixed": 727 | return self.mix_factor 728 | elif self.merge_strategy == "learned": 729 | return torch.sigmoid(self.mix_factor) 730 | else: 731 | raise NotImplementedError() 732 | 733 | def forward(self, x, temb, skip_video=False, timesteps=None): 734 | assert isinstance(timesteps, int) 735 | 736 | b, c, h, w = x.shape 737 | 738 | x = super().forward(x, temb) 739 | 740 | if not skip_video: 741 | x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) 742 | 743 | x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) 744 | 745 | x = self.time_stack(x, temb) 746 | 747 | alpha = self.get_alpha(bs=b // timesteps) 748 | x = alpha * x + (1.0 - alpha) * x_mix 749 | 750 | x = rearrange(x, "b c t h w -> (b t) c h w") 751 | return x 752 | 753 | 754 | class DecoderConv3D(torch.nn.Conv2d): 755 | def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs): 756 | super().__init__(in_channels, out_channels, *args, **kwargs) 757 | if isinstance(video_kernel_size, list): 758 | padding = [int(k // 2) for k in video_kernel_size] 759 | else: 760 | padding = int(video_kernel_size // 2) 761 | 762 | self.time_mix_conv = torch.nn.Conv3d( 763 | in_channels=out_channels, 764 | out_channels=out_channels, 765 | kernel_size=video_kernel_size, 766 | padding=padding, 767 | ) 768 | 769 | def forward(self, input, timesteps, skip_video=False): 770 | x = super().forward(input) 771 | if skip_video: 772 | return x 773 | x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) 774 | x = self.time_mix_conv(x) 775 | return rearrange(x, "b c t h w -> (b t) c h w") 776 | 777 | 778 | class VideoAutoencoderKL(torch.nn.Module, PyTorchModelHubMixin): 779 | def __init__(self, 780 | double_z=True, 781 | z_channels=4, 782 | resolution=256, 783 | in_channels=3, 784 | out_ch=3, 785 | ch=128, 786 | ch_mult=[], 787 | num_res_blocks=2, 788 | attn_resolutions=[], 789 | dropout=0.0, 790 | ): 791 | super().__init__() 792 | self.encoder = Encoder(double_z=double_z, z_channels=z_channels, resolution=resolution, in_channels=in_channels, 793 | 
out_ch=out_ch, ch=ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks, 794 | attn_resolutions=attn_resolutions, dropout=dropout) 795 | self.decoder = VideoDecoder(double_z=double_z, z_channels=z_channels, resolution=resolution, 796 | in_channels=in_channels, out_ch=out_ch, ch=ch, ch_mult=ch_mult, 797 | num_res_blocks=num_res_blocks, attn_resolutions=attn_resolutions, dropout=dropout) 798 | self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * z_channels, 1) 799 | self.post_quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1) 800 | self.scale_factor = 0.18215 801 | 802 | def encode(self, x, return_hidden_states=False, **kwargs): 803 | if return_hidden_states: 804 | h, hidden = self.encoder(x, return_hidden_states) 805 | moments = self.quant_conv(h) 806 | posterior = DiagonalGaussianDistribution(moments) 807 | return posterior, hidden 808 | else: 809 | h = self.encoder(x) 810 | moments = self.quant_conv(h) 811 | posterior = DiagonalGaussianDistribution(moments) 812 | return posterior, None 813 | 814 | def decode(self, z, **kwargs): 815 | if len(kwargs) == 0: 816 | z = self.post_quant_conv(z) 817 | dec = self.decoder(z, **kwargs) 818 | return dec 819 | 820 | @property 821 | def device(self): 822 | return next(self.parameters()).device 823 | 824 | @property 825 | def dtype(self): 826 | return next(self.parameters()).dtype 827 | -------------------------------------------------------------------------------- /Paints-UNDO/gradio_app.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ['HF_HOME'] = os.path.join(os.path.dirname(__file__), 'hf_download') 4 | result_dir = os.path.join('./', 'results') 5 | os.makedirs(result_dir, exist_ok=True) 6 | 7 | 8 | import functools 9 | import os 10 | import random 11 | import gradio as gr 12 | import numpy as np 13 | import torch 14 | import wd14tagger 15 | import memory_management 16 | import uuid 17 | 18 | from PIL import Image 19 | from diffusers_helper.code_cond import unet_add_coded_conds 20 | from diffusers_helper.cat_cond import unet_add_concat_conds 21 | from diffusers_helper.k_diffusion import KDiffusionSampler 22 | from diffusers import AutoencoderKL, UNet2DConditionModel 23 | from diffusers.models.attention_processor import AttnProcessor2_0 24 | from transformers import CLIPTextModel, CLIPTokenizer 25 | from diffusers_vdm.pipeline import LatentVideoDiffusionPipeline 26 | from diffusers_vdm.utils import resize_and_center_crop, save_bcthw_as_mp4 27 | 28 | 29 | class ModifiedUNet(UNet2DConditionModel): 30 | @classmethod 31 | def from_config(cls, *args, **kwargs): 32 | m = super().from_config(*args, **kwargs) 33 | unet_add_concat_conds(unet=m, new_channels=4) 34 | unet_add_coded_conds(unet=m, added_number_count=1) 35 | return m 36 | 37 | 38 | model_name = 'lllyasviel/paints_undo_single_frame' 39 | tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer") 40 | text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder").to(torch.float16) 41 | vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae").to(torch.bfloat16) # bfloat16 vae 42 | unet = ModifiedUNet.from_pretrained(model_name, subfolder="unet").to(torch.float16) 43 | 44 | unet.set_attn_processor(AttnProcessor2_0()) 45 | vae.set_attn_processor(AttnProcessor2_0()) 46 | 47 | video_pipe = LatentVideoDiffusionPipeline.from_pretrained( 48 | 'lllyasviel/paints_undo_multi_frame', 49 | fp16=True 50 | ) 51 | 52 | memory_management.unload_all_models([ 53 | video_pipe.unet, video_pipe.vae, 
video_pipe.text_encoder, video_pipe.image_projection, video_pipe.image_encoder, 54 | unet, vae, text_encoder 55 | ]) 56 | 57 | k_sampler = KDiffusionSampler( 58 | unet=unet, 59 | timesteps=1000, 60 | linear_start=0.00085, 61 | linear_end=0.020, 62 | linear=True 63 | ) 64 | 65 | 66 | def find_best_bucket(h, w, options): 67 | min_metric = float('inf') 68 | best_bucket = None 69 | for (bucket_h, bucket_w) in options: 70 | metric = abs(h * bucket_w - w * bucket_h) 71 | if metric <= min_metric: 72 | min_metric = metric 73 | best_bucket = (bucket_h, bucket_w) 74 | return best_bucket 75 | 76 | 77 | @torch.inference_mode() 78 | def encode_cropped_prompt_77tokens(txt: str): 79 | memory_management.load_models_to_gpu(text_encoder) 80 | cond_ids = tokenizer(txt, 81 | padding="max_length", 82 | max_length=tokenizer.model_max_length, 83 | truncation=True, 84 | return_tensors="pt").input_ids.to(device=text_encoder.device) 85 | text_cond = text_encoder(cond_ids, attention_mask=None).last_hidden_state 86 | return text_cond 87 | 88 | 89 | @torch.inference_mode() 90 | def pytorch2numpy(imgs): 91 | results = [] 92 | for x in imgs: 93 | y = x.movedim(0, -1) 94 | y = y * 127.5 + 127.5 95 | y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8) 96 | results.append(y) 97 | return results 98 | 99 | 100 | @torch.inference_mode() 101 | def numpy2pytorch(imgs): 102 | h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0 103 | h = h.movedim(-1, 1) 104 | return h 105 | 106 | 107 | def resize_without_crop(image, target_width, target_height): 108 | pil_image = Image.fromarray(image) 109 | resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS) 110 | return np.array(resized_image) 111 | 112 | 113 | @torch.inference_mode() 114 | def interrogator_process(x): 115 | return wd14tagger.default_interrogator(x) 116 | 117 | 118 | @torch.inference_mode() 119 | def process(input_fg, prompt, input_undo_steps, image_width, image_height, seed, steps, n_prompt, cfg, 120 | progress=gr.Progress()): 121 | rng = torch.Generator(device=memory_management.gpu).manual_seed(int(seed)) 122 | 123 | memory_management.load_models_to_gpu(vae) 124 | fg = resize_and_center_crop(input_fg, image_width, image_height) 125 | concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype) 126 | concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor 127 | 128 | memory_management.load_models_to_gpu(text_encoder) 129 | conds = encode_cropped_prompt_77tokens(prompt) 130 | unconds = encode_cropped_prompt_77tokens(n_prompt) 131 | 132 | memory_management.load_models_to_gpu(unet) 133 | fs = torch.tensor(input_undo_steps).to(device=unet.device, dtype=torch.long) 134 | initial_latents = torch.zeros_like(concat_conds) 135 | concat_conds = concat_conds.to(device=unet.device, dtype=unet.dtype) 136 | latents = k_sampler( 137 | initial_latent=initial_latents, 138 | strength=1.0, 139 | num_inference_steps=steps, 140 | guidance_scale=cfg, 141 | batch_size=len(input_undo_steps), 142 | generator=rng, 143 | prompt_embeds=conds, 144 | negative_prompt_embeds=unconds, 145 | cross_attention_kwargs={'concat_conds': concat_conds, 'coded_conds': fs}, 146 | same_noise_in_batch=True, 147 | progress_tqdm=functools.partial(progress.tqdm, desc='Generating Key Frames') 148 | ).to(vae.dtype) / vae.config.scaling_factor 149 | 150 | memory_management.load_models_to_gpu(vae) 151 | pixels = vae.decode(latents).sample 152 | pixels = pytorch2numpy(pixels) 153 | pixels = [fg] + pixels + [np.zeros_like(fg) + 
255] 154 | 155 | return pixels 156 | 157 | 158 | @torch.inference_mode() 159 | def process_video_inner(image_1, image_2, prompt, seed=123, steps=25, cfg_scale=7.5, fs=3, progress_tqdm=None): 160 | random.seed(seed) 161 | np.random.seed(seed) 162 | torch.manual_seed(seed) 163 | torch.cuda.manual_seed_all(seed) 164 | 165 | frames = 16 166 | 167 | target_height, target_width = find_best_bucket( 168 | image_1.shape[0], image_1.shape[1], 169 | options=[(320, 512), (384, 448), (448, 384), (512, 320)] 170 | ) 171 | 172 | image_1 = resize_and_center_crop(image_1, target_width=target_width, target_height=target_height) 173 | image_2 = resize_and_center_crop(image_2, target_width=target_width, target_height=target_height) 174 | input_frames = numpy2pytorch([image_1, image_2]) 175 | input_frames = input_frames.unsqueeze(0).movedim(1, 2) 176 | 177 | memory_management.load_models_to_gpu(video_pipe.text_encoder) 178 | positive_text_cond = video_pipe.encode_cropped_prompt_77tokens(prompt) 179 | negative_text_cond = video_pipe.encode_cropped_prompt_77tokens("") 180 | 181 | memory_management.load_models_to_gpu([video_pipe.image_projection, video_pipe.image_encoder]) 182 | input_frames = input_frames.to(device=video_pipe.image_encoder.device, dtype=video_pipe.image_encoder.dtype) 183 | positive_image_cond = video_pipe.encode_clip_vision(input_frames) 184 | positive_image_cond = video_pipe.image_projection(positive_image_cond) 185 | negative_image_cond = video_pipe.encode_clip_vision(torch.zeros_like(input_frames)) 186 | negative_image_cond = video_pipe.image_projection(negative_image_cond) 187 | 188 | memory_management.load_models_to_gpu([video_pipe.vae]) 189 | input_frames = input_frames.to(device=video_pipe.vae.device, dtype=video_pipe.vae.dtype) 190 | input_frame_latents, vae_hidden_states = video_pipe.encode_latents(input_frames, return_hidden_states=True) 191 | first_frame = input_frame_latents[:, :, 0] 192 | last_frame = input_frame_latents[:, :, 1] 193 | concat_cond = torch.stack([first_frame] + [torch.zeros_like(first_frame)] * (frames - 2) + [last_frame], dim=2) 194 | 195 | memory_management.load_models_to_gpu([video_pipe.unet]) 196 | latents = video_pipe( 197 | batch_size=1, 198 | steps=int(steps), 199 | guidance_scale=cfg_scale, 200 | positive_text_cond=positive_text_cond, 201 | negative_text_cond=negative_text_cond, 202 | positive_image_cond=positive_image_cond, 203 | negative_image_cond=negative_image_cond, 204 | concat_cond=concat_cond, 205 | fs=fs, 206 | progress_tqdm=progress_tqdm 207 | ) 208 | 209 | memory_management.load_models_to_gpu([video_pipe.vae]) 210 | video = video_pipe.decode_latents(latents, vae_hidden_states) 211 | return video, image_1, image_2 212 | 213 | 214 | @torch.inference_mode() 215 | def process_video(keyframes, prompt, steps, cfg, fps, seed, progress=gr.Progress()): 216 | result_frames = [] 217 | cropped_images = [] 218 | 219 | for i, (im1, im2) in enumerate(zip(keyframes[:-1], keyframes[1:])): 220 | im1 = np.array(Image.open(im1[0])) 221 | im2 = np.array(Image.open(im2[0])) 222 | frames, im1, im2 = process_video_inner( 223 | im1, im2, prompt, seed=seed + i, steps=steps, cfg_scale=cfg, fs=3, 224 | progress_tqdm=functools.partial(progress.tqdm, desc=f'Generating Videos ({i + 1}/{len(keyframes) - 1})') 225 | ) 226 | result_frames.append(frames[:, :, :-1, :, :]) 227 | cropped_images.append([im1, im2]) 228 | 229 | video = torch.cat(result_frames, dim=2) 230 | video = torch.flip(video, dims=[2]) 231 | 232 | uuid_name = str(uuid.uuid4()) 233 | output_filename = 
os.path.join(result_dir, uuid_name + '.mp4') 234 | Image.fromarray(cropped_images[0][0]).save(os.path.join(result_dir, uuid_name + '.png')) 235 | video = save_bcthw_as_mp4(video, output_filename, fps=fps) 236 | video = [x.cpu().numpy() for x in video] 237 | return output_filename, video 238 | 239 | 240 | block = gr.Blocks().queue() 241 | with block: 242 | gr.Markdown('# Paints-Undo') 243 | 244 | with gr.Accordion(label='Step 1: Upload Image and Generate Prompt', open=True): 245 | with gr.Row(): 246 | with gr.Column(): 247 | input_fg = gr.Image(sources=['upload'], type="numpy", label="Image", height=512) 248 | with gr.Column(): 249 | prompt_gen_button = gr.Button(value="Generate Prompt", interactive=False) 250 | prompt = gr.Textbox(label="Output Prompt", interactive=True) 251 | 252 | with gr.Accordion(label='Step 2: Generate Key Frames', open=True): 253 | with gr.Row(): 254 | with gr.Column(): 255 | input_undo_steps = gr.Dropdown(label="Operation Steps", value=[400, 600, 800, 900, 950, 999], 256 | choices=list(range(1000)), multiselect=True) 257 | seed = gr.Slider(label='Stage 1 Seed', minimum=0, maximum=50000, step=1, value=12345) 258 | image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64) 259 | image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64) 260 | steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=50, step=1) 261 | cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=3.0, step=0.01) 262 | n_prompt = gr.Textbox(label="Negative Prompt", 263 | value='lowres, bad anatomy, bad hands, cropped, worst quality') 264 | 265 | with gr.Column(): 266 | key_gen_button = gr.Button(value="Generate Key Frames", interactive=False) 267 | result_gallery = gr.Gallery(height=512, object_fit='contain', label='Outputs', columns=4) 268 | 269 | with gr.Accordion(label='Step 3: Generate All Videos', open=True): 270 | with gr.Row(): 271 | with gr.Column(): 272 | # Note that, at "Step 3: Generate All Videos", using "1girl, masterpiece, best quality" 273 | # or "1boy, masterpiece, best quality" or just "masterpiece, best quality" leads to better results. 274 | # Do NOT modify this to use the prompts generated from Step 1 !! 
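                # (Added note, not part of the original source.) Step 3 takes the key frames
                # from the Step 2 gallery, runs process_video on each pair of consecutive
                # frames, concatenates the clips and reverses them, so the final video plays
                # from the blank canvas back to the finished image. The prompt below only
                # conditions the frame-interpolation model, which is consistent with the
                # advice above to keep it generic rather than reuse the Step 1 prompt.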
275 | i2v_input_text = gr.Text(label='Prompts', value='1girl, masterpiece, best quality') 276 | i2v_seed = gr.Slider(label='Stage 2 Seed', minimum=0, maximum=50000, step=1, value=123) 277 | i2v_cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='CFG Scale', value=7.5, 278 | elem_id="i2v_cfg_scale") 279 | i2v_steps = gr.Slider(minimum=1, maximum=60, step=1, elem_id="i2v_steps", 280 | label="Sampling steps", value=50) 281 | i2v_fps = gr.Slider(minimum=1, maximum=30, step=1, elem_id="i2v_motion", label="FPS", value=4) 282 | with gr.Column(): 283 | i2v_end_btn = gr.Button("Generate Video", interactive=False) 284 | i2v_output_video = gr.Video(label="Generated Video", elem_id="output_vid", autoplay=True, 285 | show_share_button=True, height=512) 286 | with gr.Row(): 287 | i2v_output_images = gr.Gallery(height=512, label="Output Frames", object_fit="contain", columns=8) 288 | 289 | input_fg.change(lambda: ["", gr.update(interactive=True), gr.update(interactive=False), gr.update(interactive=False)], 290 | outputs=[prompt, prompt_gen_button, key_gen_button, i2v_end_btn]) 291 | 292 | prompt_gen_button.click( 293 | fn=interrogator_process, 294 | inputs=[input_fg], 295 | outputs=[prompt] 296 | ).then(lambda: [gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=False)], 297 | outputs=[prompt_gen_button, key_gen_button, i2v_end_btn]) 298 | 299 | key_gen_button.click( 300 | fn=process, 301 | inputs=[input_fg, prompt, input_undo_steps, image_width, image_height, seed, steps, n_prompt, cfg], 302 | outputs=[result_gallery] 303 | ).then(lambda: [gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True)], 304 | outputs=[prompt_gen_button, key_gen_button, i2v_end_btn]) 305 | 306 | i2v_end_btn.click( 307 | inputs=[result_gallery, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_fps, i2v_seed], 308 | outputs=[i2v_output_video, i2v_output_images], 309 | fn=process_video 310 | ) 311 | 312 | dbs = [ 313 | ['./imgs/1.jpg', 12345, 123], 314 | ['./imgs/2.jpg', 37000, 12345], 315 | ['./imgs/3.jpg', 3000, 3000], 316 | ] 317 | 318 | gr.Examples( 319 | examples=dbs, 320 | inputs=[input_fg, seed, i2v_seed], 321 | examples_per_page=1024 322 | ) 323 | 324 | block.queue().launch(server_name='0.0.0.0') 325 | -------------------------------------------------------------------------------- /Paints-UNDO/imgs/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZZXYWQ/ComfyUI-ZZXYWQ/f5b3deb3c7f8bc7f0a7d1f83fca54d14b4ea7361/Paints-UNDO/imgs/1.jpg -------------------------------------------------------------------------------- /Paints-UNDO/imgs/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZZXYWQ/ComfyUI-ZZXYWQ/f5b3deb3c7f8bc7f0a7d1f83fca54d14b4ea7361/Paints-UNDO/imgs/2.jpg -------------------------------------------------------------------------------- /Paints-UNDO/imgs/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZZXYWQ/ComfyUI-ZZXYWQ/f5b3deb3c7f8bc7f0a7d1f83fca54d14b4ea7361/Paints-UNDO/imgs/3.jpg -------------------------------------------------------------------------------- /Paints-UNDO/memory_management.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from contextlib import contextmanager 3 | 4 | 5 | high_vram = False 6 | gpu = torch.device('cuda') 7 | cpu = torch.device('cpu') 8 | 9 | 
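# --- Added explanatory note (not part of the original source) ---
# The line below allocates a tiny tensor on the GPU so the CUDA context is
# initialized once at import time; empty_cache() then releases that scratch
# allocation. The helpers further down track which models are resident on the
# GPU (models_in_gpu) and, because high_vram is False above, move every model
# not needed for the current stage back to the CPU before loading the next one.
# This is what lets gradio_app.py call load_models_to_gpu() before each step
# instead of keeping all models in VRAM at once.
# -----------------------------------------------------------------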
torch.zeros((1, 1)).to(gpu, torch.float32) 10 | torch.cuda.empty_cache() 11 | 12 | models_in_gpu = [] 13 | 14 | 15 | @contextmanager 16 | def movable_bnb_model(m): 17 | if hasattr(m, 'quantization_method'): 18 | m.quantization_method_backup = m.quantization_method 19 | del m.quantization_method 20 | try: 21 | yield None 22 | finally: 23 | if hasattr(m, 'quantization_method_backup'): 24 | m.quantization_method = m.quantization_method_backup 25 | del m.quantization_method_backup 26 | return 27 | 28 | 29 | def load_models_to_gpu(models): 30 | global models_in_gpu 31 | 32 | if not isinstance(models, (tuple, list)): 33 | models = [models] 34 | 35 | models_to_remain = [m for m in set(models) if m in models_in_gpu] 36 | models_to_load = [m for m in set(models) if m not in models_in_gpu] 37 | models_to_unload = [m for m in set(models_in_gpu) if m not in models_to_remain] 38 | 39 | if not high_vram: 40 | for m in models_to_unload: 41 | with movable_bnb_model(m): 42 | m.to(cpu) 43 | print('Unload to CPU:', m.__class__.__name__) 44 | models_in_gpu = models_to_remain 45 | 46 | for m in models_to_load: 47 | with movable_bnb_model(m): 48 | m.to(gpu) 49 | print('Load to GPU:', m.__class__.__name__) 50 | 51 | models_in_gpu = list(set(models_in_gpu + models)) 52 | torch.cuda.empty_cache() 53 | return 54 | 55 | 56 | def unload_all_models(extra_models=None): 57 | global models_in_gpu 58 | 59 | if extra_models is None: 60 | extra_models = [] 61 | 62 | if not isinstance(extra_models, (tuple, list)): 63 | extra_models = [extra_models] 64 | 65 | models_in_gpu = list(set(models_in_gpu + extra_models)) 66 | 67 | return load_models_to_gpu([]) 68 | -------------------------------------------------------------------------------- /Paints-UNDO/requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers==0.28.0 2 | transformers==4.41.1 3 | gradio==4.31.5 4 | bitsandbytes==0.43.1 5 | accelerate==0.30.1 6 | protobuf==3.20 7 | opencv-python 8 | tensorboardX 9 | safetensors 10 | pillow 11 | einops 12 | torch 13 | peft 14 | xformers 15 | onnxruntime 16 | av 17 | torchvision 18 | -------------------------------------------------------------------------------- /Paints-UNDO/wd14tagger.py: -------------------------------------------------------------------------------- 1 | # https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags 2 | 3 | 4 | import os 5 | import csv 6 | import numpy as np 7 | import onnxruntime as ort 8 | 9 | from PIL import Image 10 | from onnxruntime import InferenceSession 11 | from torch.hub import download_url_to_file 12 | 13 | 14 | global_model = None 15 | global_csv = None 16 | 17 | 18 | def download_model(url, local_path): 19 | if os.path.exists(local_path): 20 | return local_path 21 | 22 | temp_path = local_path + '.tmp' 23 | download_url_to_file(url=url, dst=temp_path) 24 | os.rename(temp_path, local_path) 25 | return local_path 26 | 27 | 28 | def default_interrogator(image, threshold=0.35, character_threshold=0.85, exclude_tags=""): 29 | global global_model, global_csv 30 | 31 | model_name = "wd-v1-4-moat-tagger-v2" 32 | 33 | model_onnx_filename = download_model( 34 | url=f'https://huggingface.co/lllyasviel/misc/resolve/main/{model_name}.onnx', 35 | local_path=f'./{model_name}.onnx', 36 | ) 37 | 38 | model_csv_filename = download_model( 39 | url=f'https://huggingface.co/lllyasviel/misc/resolve/main/{model_name}.csv', 40 | local_path=f'./{model_name}.csv', 41 | ) 42 | 43 | if global_model is not None: 44 | model = global_model 45 | else: 46 | # assert 
'CUDAExecutionProvider' in ort.get_available_providers(), 'CUDA Install Failed!' 47 | # model = InferenceSession(model_onnx_filename, providers=['CUDAExecutionProvider']) 48 | model = InferenceSession(model_onnx_filename, providers=['CPUExecutionProvider']) 49 | global_model = model 50 | 51 | input = model.get_inputs()[0] 52 | height = input.shape[1] 53 | 54 | if isinstance(image, str): 55 | image = Image.open(image) # RGB 56 | elif isinstance(image, np.ndarray): 57 | image = Image.fromarray(image) 58 | else: 59 | image = image 60 | 61 | ratio = float(height) / max(image.size) 62 | new_size = tuple([int(x*ratio) for x in image.size]) 63 | image = image.resize(new_size, Image.LANCZOS) 64 | square = Image.new("RGB", (height, height), (255, 255, 255)) 65 | square.paste(image, ((height-new_size[0])//2, (height-new_size[1])//2)) 66 | 67 | image = np.array(square).astype(np.float32) 68 | image = image[:, :, ::-1] # RGB -> BGR 69 | image = np.expand_dims(image, 0) 70 | 71 | if global_csv is not None: 72 | csv_lines = global_csv 73 | else: 74 | csv_lines = [] 75 | with open(model_csv_filename) as f: 76 | reader = csv.reader(f) 77 | next(reader) 78 | for row in reader: 79 | csv_lines.append(row) 80 | global_csv = csv_lines 81 | 82 | tags = [] 83 | general_index = None 84 | character_index = None 85 | for line_num, row in enumerate(csv_lines): 86 | if general_index is None and row[2] == "0": 87 | general_index = line_num 88 | elif character_index is None and row[2] == "4": 89 | character_index = line_num 90 | tags.append(row[1]) 91 | 92 | label_name = model.get_outputs()[0].name 93 | probs = model.run([label_name], {input.name: image})[0] 94 | 95 | result = list(zip(tags, probs[0])) 96 | 97 | general = [item for item in result[general_index:character_index] if item[1] > threshold] 98 | character = [item for item in result[character_index:] if item[1] > character_threshold] 99 | 100 | all = character + general 101 | remove = [s.strip() for s in exclude_tags.lower().split(",")] 102 | all = [tag for tag in all if tag[0] not in remove] 103 | 104 | res = ", ".join((item[0].replace("(", "\\(").replace(")", "\\)") for item in all)).replace('_', ' ') 105 | return res 106 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ZZX Nodes 2 | ===== 3 | 4 | ZZX_PaintsUndo 5 | == 6 | 7 | ![menu](workflows/PaintsUndo.png) 8 | 9 | Original author https://github.com/lllyasviel/Paints-UNDO 10 | 11 | Reference to the original author's file, now you can undo the line drawing, set it to black if there is a problem. 12 | 13 | Known bugs: 14 | Three pictures will be output at a time. 15 | 16 | Plan: 17 | 18 | ✓Undo each part of the line drawing (picture) 19 | 20 | xUndo coloring (picture) 21 | 22 | x (video) 23 | 24 | Node file: ZZX_PaintsUndo.py, there are some comments and some attempts (some useless attempts were not deleted) 25 | 26 | 27 | 原作者https://github.com/lllyasviel/Paints-UNDO 28 | 29 | 引用原作者文件,现在可以撤销线稿,设置若有问题则输出为黑. 30 | 31 | 已知bug: 32 | 一次会输出三张图。 33 | 34 | 计划: 35 | 36 | ✓撤销成为 线稿各部分(图片) 37 | 38 | x撤销着色(图片) 39 | 40 | x(视频) 41 | 42 | 节点文件:ZZX_PaintsUndo.py,有一些注释和一些尝试(没删干净一些无用尝试) 43 | 44 | 45 | 46 | 2.StreamRecorder 47 | == 48 | A streaming media receives a local recording node: 49 | Streaming formats available: rtmp, .m3u8. 50 | Optional formats: mp4, mov, mkv, avi. 51 | Optional encoding: av1, h264, h265. 
52 | Optional: whether to use local time,If yes, then this start_time takes effect, the time inside is hours/minutes/seconds. 53 | Optional: recording time (slider). 54 | Optional: whether to use cuda. 55 | 56 | 57 | 一个流媒体收到录本地的节点: 58 | 可输入流媒体格式:rtmp,.m3u8。 59 | 可选格式:mp4,mov,mkv,avi。 60 | 可选编码:av1,h264,h265。 61 | 可选:是否使用本地时间, 62 | 如果是,则这start_time生效,里面的时间是时/分/秒。 63 | 可选:录制时间(滑块)。 64 | 可选:是否使用cuda。 65 | 66 | AVI (avi): h264 67 | 68 | MOV (mov): h264, hevc 69 | 70 | MKV (mkv): av1, h264, hevc 71 | 72 | MP4 (mp4): av1, h264, hevc 73 | 74 | HLS (hls): av1, h264, hevc 75 | 76 | DASH (dash): av1, h264, hevc 77 | 78 | MSS (mss): av1, h264, hevc 79 | 80 | SRT (srt): h264, hevc 81 | 82 | FLV (flv): h264 83 | 84 | WebM (webm): av1, h264 85 | 86 | RTMP (rtmp): h264 87 | 88 | RTSP (rtsp): h264, hevc 89 | 90 | M3U8 (m3u8): av1, h264, hevc 91 | 92 | 1.VideoFormatConverter: 93 | == 94 | A video transcoding node: 95 | Optional formats: mp4, mov, mkv, avi. 96 | Optional encoding: av1, h264, h265. 97 | Optional frame rate: 8, 15, 24, 25, 30, 50, 59, 60, 120. 98 | Optional width and height: (not listed one by one). 99 | Optional audio: mp3, aac. 100 | Optional frequency: 44100, 48000. 101 | 102 | 一个视频转码的节点: 103 | 可选格式:mp4,mov,mkv,avi。 104 | 可选编码:av1,h264,h265。 105 | 可选帧率:8,15,24,25,30,50,59,60,120。 106 | 可选宽高:(不一一列举)。 107 | 可选音频:mp3,aac。 108 | 可选频率:44100,48000。 109 | 110 | 111 | welcome 112 | == 113 | Options can be added to the above nodes, 114 | Please give feedback if you have any questions 115 | 116 | 以上节点均可添加选项, 117 | 有问题请反馈 118 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | 4 | NODE_CLASS_MAPPINGS = {} 5 | 6 | # 自动导入 nodes 文件夹中的所有符合条件的节点 7 | nodes_folder = os.path.dirname(__file__) + os.sep + 'nodes' 8 | for node in os.listdir(nodes_folder): 9 | if node.startswith('ZZX_') and node.endswith('.py'): 10 | node = node.split('.')[0] 11 | node_import = importlib.import_module('custom_nodes.ComfyUI-ZZXYWQ.nodes.' + node) 12 | print('Imported node from nodes: ' + node) 13 | # 获取节点类映射并更新全局 NODE_CLASS_MAPPINGS 14 | NODE_CLASS_MAPPINGS.update(node_import.NODE_CLASS_MAPPINGS) 15 | 16 | # 自动导入 Paints-UNDO 文件夹中的所有符合条件的节点 17 | paints_undo_folder = os.path.dirname(__file__) + os.sep + 'Paints-UNDO' 18 | for node in os.listdir(paints_undo_folder): 19 | if node.startswith('ZZX_') and node.endswith('.py'): 20 | node = node.split('.')[0] 21 | node_import = importlib.import_module('custom_nodes.ComfyUI-ZZXYWQ.Paints-UNDO.' 
+ node) 22 | print('Imported node from Paints-UNDO: ' + node) 23 | # 获取节点类映射并更新全局 NODE_CLASS_MAPPINGS 24 | NODE_CLASS_MAPPINGS.update(node_import.NODE_CLASS_MAPPINGS) 25 | -------------------------------------------------------------------------------- /nodes/ZZX_Stream.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import datetime 4 | import tkinter as tk 5 | from tkinter import filedialog 6 | import subprocess 7 | import cv2 8 | 9 | class StreamRecorder: 10 | 11 | def __init__(self): 12 | self.input_url = "" 13 | self.output_path = "" 14 | 15 | @classmethod 16 | def INPUT_TYPES(s): 17 | now = datetime.datetime.now() 18 | default_start_time = now.strftime("%H/%M/%S") # Example default time 19 | return { 20 | "required": { 21 | "stream_url": ("STRING", {"multiline": False, "default": ""}), 22 | "use_local_time": (["true", "false"], {"default": "false"}), 23 | "start_time": ("STRING", {"multiline": False, "default": default_start_time}), # Example default time 24 | "record_hours": ("INT", {"default": 0, "min": 0, "max": 12, "step": 1, "display": "slider"}), 25 | "record_minutes": ("INT", {"default": 0, "min": 0, "max": 59, "step": 1, "display": "slider"}), 26 | "record_seconds": ("INT", {"default": 60, "min": 0, "max": 59, "step": 1, "display": "slider"}), 27 | "output_filename": ("STRING", {"multiline": False, "default": ""}), 28 | "video_format": (["avi", "mov", "mkv", "mp4", "hls", "dash", "mss", "srt", "flv", "webm", "rtmp", "rtsp", "m3u8", "http", "https"], {"default": "mp4"}), 29 | "codec": (["av1", "h264", "h264(NVENC)", "hevc", "hevc(NVENC)"], {"default": "h264(NVENC)"}), 30 | "video_quality": ("INT", {"default": 10, "min": 5, "max": 40, "step": 1, "display": "slider"}), 31 | "use_cuda": (["true", "false"], {"default": "true"}), # Add CUDA option 32 | "output_path": ("STRING", {"multiline": False, "default": ""}), 33 | }, 34 | } 35 | 36 | RETURN_TYPES = ("STRING",) 37 | RETURN_NAMES = ("output_filename",) 38 | FUNCTION = "record_stream" 39 | 40 | CATEGORY = "ZZX/Stream" 41 | 42 | def select_output_file(self): 43 | root = tk.Tk() 44 | root.withdraw() 45 | self.output_path = filedialog.asksaveasfilename(title="Select output video file") 46 | return self.output_path 47 | 48 | def get_unique_filename(self, output_path, output_filename, video_format): 49 | base_name, ext = os.path.splitext(output_filename) 50 | counter = 0 51 | while True: 52 | new_filename = f"{base_name}_{counter:04d}.{video_format}" 53 | full_path = os.path.join(output_path, new_filename) 54 | if not os.path.exists(full_path): 55 | return full_path 56 | counter += 1 57 | 58 | def calculate_bitrate(self, video_quality): 59 | min_quality = 1 60 | max_quality = 40 61 | min_bitrate = 100 # kbps for quality=40 62 | max_bitrate = 10000 # kbps for quality=1 63 | return int((max_bitrate - min_bitrate) / (max_quality - min_quality) * (video_quality - min_quality) + min_bitrate) 64 | 65 | def check_format_support(self, video_format, codec): 66 | format_support = { 67 | "avi": ["h264", "h264(NVENC)"], 68 | "mov": ["h264", "h264(NVENC)", "hevc", "hevc(NVENC)"], 69 | "mkv": ["av1", "h264", "h264(NVENC)", "hevc", "hevc(NVENC)"], 70 | "mp4": ["av1", "h264", "h264(NVENC)", "hevc", "hevc(NVENC)"], 71 | "hls": ["av1", "h264", "h264(NVENC)", "hevc", "hevc(NVENC)"], 72 | "dash": ["av1", "h264", "h264(NVENC)", "hevc", "hevc(NVENC)"], 73 | "mss": ["av1", "h264", "h264(NVENC)", "hevc", "hevc(NVENC)"], 74 | "srt": ["h264", "h264(NVENC)", "hevc", "hevc(NVENC)"], 75 | 
"flv": ["h264", "h264(NVENC)"], 76 | "webm": ["av1", "h264", "h264(NVENC)"], 77 | "rtmp": ["h264", "h264(NVENC)"], 78 | "rtsp": ["h264", "h264(NVENC)", "hevc", "hevc(NVENC)"], 79 | "m3u8": ["av1", "h264", "h264(NVENC)", "hevc", "hevc(NVENC)"], 80 | "http": ["av1", "h264", "h264(NVENC)", "hevc", "hevc(NVENC)"], 81 | "https": ["av1", "h264", "h264(NVENC)", "hevc", "hevc(NVENC)"] 82 | } 83 | 84 | if codec not in format_support.get(video_format, []): 85 | raise ValueError( 86 | "请选择正确编码模式:\n" 87 | "AVI (avi): h264\n" 88 | "MOV (mov): h264, hevc\n" 89 | "MKV (mkv): av1, h264, hevc\n" 90 | "MP4 (mp4): av1, h264, hevc\n" 91 | "HLS (hls): av1, h264, hevc\n" 92 | "DASH (dash): av1, h264, hevc\n" 93 | "MSS (mss): av1, h264, hevc\n" 94 | "SRT (srt): h264, hevc\n" 95 | "FLV (flv): h264\n" 96 | "WebM (webm): av1, h264\n" 97 | "RTMP (rtmp): h264\n" 98 | "RTSP (rtsp): h264, hevc\n" 99 | "M3U8 (m3u8): av1, h264, hevc\n" 100 | "HTTP (http): av1, h264, hevc\n" 101 | "HTTPS (https): av1, h264, hevc" 102 | ) 103 | 104 | def record_stream(self, stream_url, use_local_time, start_time, record_hours, record_minutes, record_seconds, output_filename, video_format, codec, video_quality, use_cuda, output_path): 105 | self.check_format_support(video_format, codec) 106 | 107 | if not stream_url: 108 | raise ValueError("Stream URL is required") 109 | 110 | if not output_path: 111 | output_path = self.select_output_file() 112 | 113 | # Ensure the output path has the correct file extension 114 | if not output_path.endswith("/"): 115 | output_path += "/" 116 | 117 | # Ensure the output filename has the correct format extension 118 | if not output_filename.endswith(f".{video_format}"): 119 | output_filename += f".{video_format}" 120 | 121 | # Get a unique filename 122 | output_full_path = self.get_unique_filename(output_path, output_filename, video_format) 123 | 124 | # Calculate bitrate based on video_quality 125 | bitrate = self.calculate_bitrate(video_quality) 126 | 127 | # Map codec to correct FFmpeg encoder names 128 | codec_map = { 129 | "h264(NVENC)": "h264_nvenc", 130 | "hevc(NVENC)": "hevc_nvenc", 131 | "hevc": "libx265", 132 | "av1": "libaom-av1" 133 | } 134 | if codec in codec_map: 135 | codec = codec_map[codec] 136 | 137 | # Calculate the total record duration in seconds 138 | total_duration = record_hours * 3600 + record_minutes * 60 + record_seconds 139 | 140 | # Handle local time scheduling 141 | if use_local_time == "true": 142 | now = datetime.datetime.now() 143 | start_time_parts = list(map(int, start_time.split('/'))) 144 | start_dt = now.replace(hour=start_time_parts[0], minute=start_time_parts[1], second=start_time_parts[2], microsecond=0) 145 | 146 | if start_dt < now: 147 | raise ValueError("输入有误,请重新输入时间") 148 | 149 | while datetime.datetime.now() < start_dt: 150 | time.sleep(1) 151 | 152 | # Construct the FFmpeg command for recording the stream 153 | cmd = ["ffmpeg"] 154 | 155 | if use_cuda == "true": 156 | cmd.extend(["-hwaccel", "cuda"]) 157 | 158 | cmd.extend([ 159 | "-i", stream_url, 160 | "-t", str(int(total_duration)), 161 | "-c:v", codec, 162 | "-b:v", f"{bitrate}k", 163 | "-c:a", "copy", # Copy the audio codec from the source 164 | "-y", output_full_path 165 | ]) 166 | 167 | if video_format == "hls": 168 | cmd.extend(["-f", "hls", "-hls_time", "10", "-hls_playlist_type", "vod", "-hls_segment_filename", os.path.join(output_path, "segment_%03d.ts")]) 169 | elif video_format == "dash": 170 | cmd.extend(["-f", "dash", "-seg_duration", "10", "-init_seg_name", "init.m4s", "-media_seg_name", 
"segment_%03d.m4s"]) 171 | elif video_format == "mss": 172 | cmd.extend(["-f", "hds", "-hls_time", "10", "-hls_playlist_type", "vod", "-hls_segment_filename", os.path.join(output_path, "segment_%03d.f4m")]) 173 | elif video_format == "srt": 174 | cmd.extend(["-f", "mpegts"]) 175 | elif video_format == "flv": 176 | cmd.extend(["-f", "flv"]) 177 | elif video_format == "webm": 178 | cmd.extend(["-f", "webm"]) 179 | elif video_format == "rtmp": 180 | cmd.extend(["-f", "flv"]) # RTMP streams are usually in FLV format 181 | elif video_format == "rtsp": 182 | cmd.extend(["-f", "rtsp"]) 183 | elif video_format == "m3u8": 184 | cmd.extend(["-f", "hls"]) 185 | elif video_format == "http" or video_format == "https": 186 | cmd.extend(["-f", "hls"]) 187 | 188 | # Run the command 189 | result = subprocess.run(cmd, capture_output=True, text=True) 190 | 191 | if result.returncode != 0: 192 | print(f"FFmpeg command: {' '.join(cmd)}") 193 | raise RuntimeError(f"FFmpeg error: {result.stderr}") 194 | 195 | return (output_full_path,) 196 | 197 | # A dictionary that contains all nodes you want to export with their names 198 | NODE_CLASS_MAPPINGS = { 199 | "StreamRecorder": StreamRecorder 200 | } 201 | 202 | # A dictionary that contains the friendly/humanly readable titles for the nodes 203 | NODE_DISPLAY_NAME_MAPPINGS = { 204 | "StreamRecorder": "Stream Recorder" 205 | } 206 | -------------------------------------------------------------------------------- /nodes/ZZX_VFC.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tkinter as tk 3 | from tkinter import filedialog 4 | import subprocess 5 | import cv2 6 | 7 | class VideoFormatConverter: 8 | 9 | def __init__(self): 10 | self.input_path = "" 11 | self.output_path = "" 12 | 13 | @classmethod 14 | def INPUT_TYPES(s): 15 | return { 16 | "required": { 17 | "video_path": ("STRING", {"multiline": False, "default": ""}), 18 | "output_enabled": (["true", "false"],), 19 | "output_filename": ("STRING", {"multiline": False, "default": ""}), 20 | "video_format": (["avi", "mov", "mkv", "mp4"],{"default": "mp4"}), 21 | "codec": (["av1", "h264", "h264(NVENC)", "hevc", "hevc(NVENC)"],{"default": "h264"}), 22 | "video_quality": ("INT", {"default": 10, "min": 5, "max": 40, "step": 1, "display": "slider"}), 23 | "frame_rate": (["8", "15", "24", "25", "30", "50", "59", "60", "120"],{"default": "25"}), 24 | "opencl_acceleration": (["enable", "disable"],), 25 | "video_width": (["272", "300", "320", "360", "400", "450", "480", "512", "540", "600", "640", "720", "800", "960", "1080", "1280", "1440", "1536", "1920", "2560"], {"default": "1280"}), 26 | "video_height": (["272", "300", "320", "360", "400", "450", "480", "512", "540", "600", "640", "720", "800", "960", "1080", "1280", "1440", "1536", "1920", "2560"], {"default": "720"}), 27 | "scaling_filter": (["bilinear", "bicubic", "neighbor", "area", "bicublin", "lanczos"],{"default": "bicubic"}), 28 | "processing_method": (["fill", "crop"],), 29 | "audio_codec": (["copy", "mp3", "aac"],{"default": "aac"}), 30 | "bit_rate": (["96", "128", "192"],{"default": "192"}), 31 | "audio_channels": (["original", "mono", "stereo"],{"default": "stereo"}), 32 | "sample_rate": (["44100", "48000"],{"default": "48000"}), 33 | "output_path": ("STRING", {"multiline": False, "default": ""}), 34 | }, 35 | } 36 | 37 | RETURN_TYPES = ("STRING", "VHS_VIDEOINFO") 38 | RETURN_NAMES = ("output_filename", "video_info") 39 | FUNCTION = "process_video" 40 | 41 | CATEGORY = "ZZX/Video" 42 | 43 | def 
select_input_file(self): 44 | root = tk.Tk() 45 | root.withdraw() 46 | self.input_path = filedialog.askopenfilename(title="Select input video file") 47 | return self.input_path 48 | 49 | def select_output_file(self): 50 | root = tk.Tk() 51 | root.withdraw() 52 | self.output_path = filedialog.asksaveasfilename(title="Select output video file") 53 | return self.output_path 54 | 55 | def get_unique_filename(self, output_path, output_filename, video_format): 56 | base_name, ext = os.path.splitext(output_filename) 57 | counter = 0 58 | while True: 59 | new_filename = f"{base_name}_{counter:04d}{ext}" 60 | full_path = os.path.join(output_path, new_filename) 61 | if not os.path.exists(full_path): 62 | return full_path 63 | counter += 1 64 | 65 | def calculate_bitrate(self, original_bitrate, video_quality): 66 | min_quality = 1 67 | max_quality = 40 68 | min_bitrate = 100 # kbps for quality=40 69 | max_bitrate = 10000 # kbps for quality=1 70 | return int((max_bitrate - min_bitrate) / (min_quality - max_quality) * (video_quality - max_quality) + max_bitrate) 71 | 72 | def process_video(self, video_path, output_enabled, output_filename, video_format, codec, video_quality, frame_rate, opencl_acceleration, video_width, video_height, scaling_filter, processing_method, audio_codec, bit_rate, audio_channels, sample_rate, output_path): 73 | valid_codecs = { 74 | "avi": ["av1", "h264", "h264(NVENC)"], 75 | "mov": ["h264", "h264(NVENC)", "hevc", "hevc(NVENC)"], 76 | "mkv": ["av1", "h264", "h264(NVENC)", "hevc", "hevc(NVENC)"], 77 | "mp4": ["av1", "h264", "h264(NVENC)", "hevc", "hevc(NVENC)"] 78 | } 79 | 80 | if codec not in valid_codecs.get(video_format, []): 81 | raise ValueError("选择的格式不正确Incorrect format selected") 82 | 83 | if not video_path: 84 | video_path = self.select_input_file() 85 | 86 | if output_enabled == "false": 87 | return ("Output disabled",) 88 | 89 | if not output_path: 90 | output_path = self.select_output_file() 91 | 92 | # Ensure the input paths are correctly formatted 93 | video_path = video_path.replace("\\", "/") 94 | output_path = output_path.replace("\\", "/") 95 | 96 | # Ensure the output path has the correct file extension 97 | if not output_path.endswith("/"): 98 | output_path += "/" 99 | 100 | # Ensure the output filename has the correct format extension 101 | if not output_filename.endswith(f".{video_format}"): 102 | output_filename += f".{video_format}" 103 | 104 | # Get a unique filename 105 | output_full_path = self.get_unique_filename(output_path, output_filename, video_format) 106 | 107 | # Read original bitrate from video file 108 | video_cap = cv2.VideoCapture(video_path) 109 | if not video_cap.isOpened(): 110 | raise ValueError(f"{video_path} could not be loaded with cv.") 111 | source_fps = video_cap.get(cv2.CAP_PROP_FPS) 112 | source_width = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 113 | source_height = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 114 | source_frame_count = int(video_cap.get(cv2.CAP_PROP_FRAME_COUNT)) 115 | source_duration = source_frame_count / source_fps 116 | original_bitrate = video_cap.get(cv2.CAP_PROP_BITRATE) / 1000 117 | video_cap.release() 118 | 119 | # Calculate bitrate based on video_quality 120 | bitrate = self.calculate_bitrate(original_bitrate, video_quality) 121 | 122 | # Map codec to correct FFmpeg encoder names 123 | codec_map = { 124 | "h264(NVENC)": "h264_nvenc", 125 | "hevc(NVENC)": "hevc_nvenc", 126 | "hevc": "libx265", 127 | "av1": "libaom-av1" 128 | } 129 | if codec in codec_map: 130 | codec = codec_map[codec] 131 | 132 | 
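        # --- Added explanatory note (not part of the original source) ---
        # Example of what the command assembled below resolves to, assuming a
        # hypothetical input /videos/in.mp4, codec "h264(NVENC)" (mapped to
        # h264_nvenc above), a computed bitrate of 2500 kbps, frame_rate "25",
        # 1280x720 bicubic scaling, and aac audio at 192 kbps, stereo, 48000 Hz:
        #   ffmpeg -i /videos/in.mp4 -c:v h264_nvenc -b:v 2500k -r 25 \
        #          -vf scale=1280:720:flags=bicubic -c:a aac -b:a 192k \
        #          -ac 2 -ar 48000 -y /out/out_0000.mp4
        # With opencl_acceleration == "enable", "-hwaccel opencl" is inserted
        # right after "ffmpeg".
        # -----------------------------------------------------------------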
# Construct the FFmpeg command based on video format and codec 133 | cmd = [ 134 | "ffmpeg", 135 | "-i", video_path, 136 | "-c:v", codec, 137 | "-b:v", f"{bitrate}k", 138 | "-r", frame_rate, 139 | "-vf", f"scale={video_width}:{video_height}:flags={scaling_filter}", 140 | "-c:a", audio_codec, 141 | "-b:a", f"{bit_rate}k", 142 | *(["-ac", "2"] if audio_channels == "stereo" else ["-ac", "1"] if audio_channels == "mono" else []), # "-ac" expects a channel count; omit it to keep the source layout when audio_channels is "copy" 143 | "-ar", sample_rate, 144 | "-y", output_full_path 145 | ] 146 | 147 | if opencl_acceleration == "enable": 148 | cmd.insert(1, "-hwaccel") 149 | cmd.insert(2, "opencl") 150 | 151 | # Run the command 152 | result = subprocess.run(cmd, capture_output=True, text=True) 153 | 154 | if result.returncode != 0: 155 | raise RuntimeError(f"FFmpeg error: {result.stderr}") 156 | 157 | # Extract video information similar to load_video_nodes.py 158 | video_cap = cv2.VideoCapture(output_full_path) 159 | loaded_fps = video_cap.get(cv2.CAP_PROP_FPS) 160 | loaded_width = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 161 | loaded_height = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 162 | loaded_frame_count = int(video_cap.get(cv2.CAP_PROP_FRAME_COUNT)) 163 | loaded_duration = loaded_frame_count / loaded_fps 164 | video_info = { 165 | "source_fps": source_fps, 166 | "source_frame_count": source_frame_count, 167 | "source_duration": source_duration, 168 | "source_width": source_width, 169 | "source_height": source_height, 170 | "loaded_fps": loaded_fps, 171 | "loaded_frame_count": loaded_frame_count, 172 | "loaded_duration": loaded_duration, 173 | "loaded_width": loaded_width, 174 | "loaded_height": loaded_height, 175 | } 176 | video_cap.release() 177 | 178 | return (output_full_path, video_info) 179 | 180 | # A dictionary that contains all nodes you want to export with their names 181 | NODE_CLASS_MAPPINGS = { 182 | "VideoFormatConverter": VideoFormatConverter 183 | } 184 | 185 | # A dictionary that contains the human-readable titles for the nodes 186 | NODE_DISPLAY_NAME_MAPPINGS = { 187 | "VideoFormatConverter": "Video Format Converter" 188 | } 189 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "ZZX nodes" 3 | description = "Node(s): VideoFormatConverter(VFC)" 4 | version = "1.0.0" 5 | license = { file = "LICENSE" } 6 | 7 | [project.urls] 8 | Repository = "https://github.com/ZZXYWQ/ComfyUI-ZZXYWQ" 9 | # Used by Comfy Registry https://comfyregistry.org 10 | 11 | [tool.comfy] 12 | PublisherId = "ZZXYWQ" 13 | DisplayName = "ComfyUI-ZZX" 14 | Icon = "" 15 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZZXYWQ/ComfyUI-ZZXYWQ/f5b3deb3c7f8bc7f0a7d1f83fca54d14b4ea7361/requirements.txt -------------------------------------------------------------------------------- /workflows/PaintsUndo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZZXYWQ/ComfyUI-ZZXYWQ/f5b3deb3c7f8bc7f0a7d1f83fca54d14b4ea7361/workflows/PaintsUndo.png -------------------------------------------------------------------------------- /workflows/StreamRecorder+VideoFormatConverter.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZXYWQ/ComfyUI-ZZXYWQ/f5b3deb3c7f8bc7f0a7d1f83fca54d14b4ea7361/workflows/StreamRecorder+VideoFormatConverter.png -------------------------------------------------------------------------------- /workflows/VideoFormatConverter.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZZXYWQ/ComfyUI-ZZXYWQ/f5b3deb3c7f8bc7f0a7d1f83fca54d14b4ea7361/workflows/VideoFormatConverter.png -------------------------------------------------------------------------------- /workflows/workflow-StreamRecorder.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 16, 3 | "last_link_id": 9, 4 | "nodes": [ 5 | { 6 | "id": 12, 7 | "type": "VHS_VideoInfo", 8 | "pos": [ 9 | 910, 10 | 440 11 | ], 12 | "size": { 13 | "0": 393, 14 | "1": 206 15 | }, 16 | "flags": {}, 17 | "order": 5, 18 | "mode": 0, 19 | "inputs": [ 20 | { 21 | "name": "video_info", 22 | "type": "VHS_VIDEOINFO", 23 | "link": 5 24 | } 25 | ], 26 | "outputs": [ 27 | { 28 | "name": "source_fps🟨", 29 | "type": "FLOAT", 30 | "links": null, 31 | "shape": 3 32 | }, 33 | { 34 | "name": "source_frame_count🟨", 35 | "type": "INT", 36 | "links": null, 37 | "shape": 3 38 | }, 39 | { 40 | "name": "source_duration🟨", 41 | "type": "FLOAT", 42 | "links": null, 43 | "shape": 3 44 | }, 45 | { 46 | "name": "source_width🟨", 47 | "type": "INT", 48 | "links": [ 49 | 6 50 | ], 51 | "shape": 3, 52 | "slot_index": 3 53 | }, 54 | { 55 | "name": "source_height🟨", 56 | "type": "INT", 57 | "links": [ 58 | 7 59 | ], 60 | "shape": 3, 61 | "slot_index": 4 62 | }, 63 | { 64 | "name": "loaded_fps🟦", 65 | "type": "FLOAT", 66 | "links": null, 67 | "shape": 3 68 | }, 69 | { 70 | "name": "loaded_frame_count🟦", 71 | "type": "INT", 72 | "links": null, 73 | "shape": 3 74 | }, 75 | { 76 | "name": "loaded_duration🟦", 77 | "type": "FLOAT", 78 | "links": null, 79 | "shape": 3 80 | }, 81 | { 82 | "name": "loaded_width🟦", 83 | "type": "INT", 84 | "links": null, 85 | "shape": 3 86 | }, 87 | { 88 | "name": "loaded_height🟦", 89 | "type": "INT", 90 | "links": null, 91 | "shape": 3 92 | } 93 | ], 94 | "properties": { 95 | "Node name for S&R": "VHS_VideoInfo" 96 | }, 97 | "widgets_values": {} 98 | }, 99 | { 100 | "id": 14, 101 | "type": "easy showAnything", 102 | "pos": [ 103 | 1350, 104 | 580 105 | ], 106 | "size": { 107 | "0": 210, 108 | "1": 80 109 | }, 110 | "flags": {}, 111 | "order": 8, 112 | "mode": 0, 113 | "inputs": [ 114 | { 115 | "name": "anything", 116 | "type": "*", 117 | "link": 7 118 | } 119 | ], 120 | "title": "easy showAnything", 121 | "properties": { 122 | "Node name for S&R": "easy showAnything" 123 | }, 124 | "widgets_values": [ 125 | "1080" 126 | ] 127 | }, 128 | { 129 | "id": 13, 130 | "type": "easy showAnything", 131 | "pos": [ 132 | 1350, 133 | 520 134 | ], 135 | "size": { 136 | "0": 210, 137 | "1": 80 138 | }, 139 | "flags": {}, 140 | "order": 7, 141 | "mode": 0, 142 | "inputs": [ 143 | { 144 | "name": "anything", 145 | "type": "*", 146 | "link": 6 147 | } 148 | ], 149 | "title": "easy showAnything", 150 | "properties": { 151 | "Node name for S&R": "easy showAnything" 152 | }, 153 | "widgets_values": [ 154 | "1920" 155 | ] 156 | }, 157 | { 158 | "id": 10, 159 | "type": "ShowText|pysssss", 160 | "pos": [ 161 | 550, 162 | 40 163 | ], 164 | "size": { 165 | "0": 320, 166 | "1": 80 167 | }, 168 | "flags": {}, 169 | "order": 2, 170 | "mode": 0, 171 | "inputs": [ 172 | { 173 | "name": "text", 174 | "type": "STRING", 175 | "link": 
3, 176 | "widget": { 177 | "name": "text" 178 | } 179 | } 180 | ], 181 | "outputs": [ 182 | { 183 | "name": "STRING", 184 | "type": "STRING", 185 | "links": null, 186 | "shape": 6 187 | } 188 | ], 189 | "properties": { 190 | "Node name for S&R": "ShowText|pysssss" 191 | }, 192 | "widgets_values": [ 193 | "", 194 | "E:\\BaiduSyncdisk\\Input/Arirang Korea stream_0000.mp4" 195 | ] 196 | }, 197 | { 198 | "id": 11, 199 | "type": "ShowText|pysssss", 200 | "pos": [ 201 | 900, 202 | 40 203 | ], 204 | "size": { 205 | "0": 320, 206 | "1": 80 207 | }, 208 | "flags": {}, 209 | "order": 3, 210 | "mode": 0, 211 | "inputs": [ 212 | { 213 | "name": "text", 214 | "type": "STRING", 215 | "link": 4, 216 | "widget": { 217 | "name": "text" 218 | } 219 | } 220 | ], 221 | "outputs": [ 222 | { 223 | "name": "STRING", 224 | "type": "STRING", 225 | "links": null, 226 | "shape": 6 227 | } 228 | ], 229 | "properties": { 230 | "Node name for S&R": "ShowText|pysssss" 231 | }, 232 | "widgets_values": [ 233 | "", 234 | "E:/BaiduSyncdisk/Input/Arirang Korea S2V_0000.mov" 235 | ] 236 | }, 237 | { 238 | "id": 9, 239 | "type": "VideoFormatConverter", 240 | "pos": [ 241 | 560, 242 | 180 243 | ], 244 | "size": { 245 | "0": 320, 246 | "1": 460 247 | }, 248 | "flags": {}, 249 | "order": 1, 250 | "mode": 0, 251 | "inputs": [ 252 | { 253 | "name": "video_path", 254 | "type": "STRING", 255 | "link": 2, 256 | "widget": { 257 | "name": "video_path" 258 | } 259 | } 260 | ], 261 | "outputs": [ 262 | { 263 | "name": "output_filename", 264 | "type": "STRING", 265 | "links": [ 266 | 4, 267 | 8 268 | ], 269 | "shape": 3, 270 | "slot_index": 0 271 | }, 272 | { 273 | "name": "video_info", 274 | "type": "VHS_VIDEOINFO", 275 | "links": [ 276 | 5 277 | ], 278 | "shape": 3, 279 | "slot_index": 1 280 | } 281 | ], 282 | "properties": { 283 | "Node name for S&R": "VideoFormatConverter" 284 | }, 285 | "widgets_values": [ 286 | "", 287 | "true", 288 | "Arirang Korea S2V", 289 | "mov", 290 | "hevc(NVENC)", 291 | 12, 292 | "24", 293 | "enable", 294 | "1920", 295 | "1080", 296 | "bicubic", 297 | "fill", 298 | "aac", 299 | "192", 300 | "stereo", 301 | "48000", 302 | "E:\\BaiduSyncdisk\\Input" 303 | ] 304 | }, 305 | { 306 | "id": 15, 307 | "type": "VHS_LoadVideoPath", 308 | "pos": [ 309 | 1230, 310 | 150 311 | ], 312 | "size": [ 313 | 320, 314 | 210 315 | ], 316 | "flags": {}, 317 | "order": 4, 318 | "mode": 0, 319 | "inputs": [ 320 | { 321 | "name": "meta_batch", 322 | "type": "VHS_BatchManager", 323 | "link": null 324 | }, 325 | { 326 | "name": "vae", 327 | "type": "VAE", 328 | "link": null, 329 | "slot_index": 1 330 | }, 331 | { 332 | "name": "video", 333 | "type": "STRING", 334 | "link": 8, 335 | "widget": { 336 | "name": "video" 337 | }, 338 | "slot_index": 2 339 | } 340 | ], 341 | "outputs": [ 342 | { 343 | "name": "IMAGE", 344 | "type": "IMAGE", 345 | "links": [ 346 | 9 347 | ], 348 | "shape": 3, 349 | "slot_index": 0 350 | }, 351 | { 352 | "name": "frame_count", 353 | "type": "INT", 354 | "links": null, 355 | "shape": 3 356 | }, 357 | { 358 | "name": "audio", 359 | "type": "VHS_AUDIO", 360 | "links": null, 361 | "shape": 3 362 | }, 363 | { 364 | "name": "video_info", 365 | "type": "VHS_VIDEOINFO", 366 | "links": null, 367 | "shape": 3 368 | } 369 | ], 370 | "properties": { 371 | "Node name for S&R": "VHS_LoadVideoPath" 372 | }, 373 | "widgets_values": { 374 | "video": "X://insert/path/here.mp4", 375 | "force_rate": 0, 376 | "force_size": "Disabled", 377 | "custom_width": 512, 378 | "custom_height": 512, 379 | "frame_load_cap": 0, 380 | 
"skip_first_frames": 0, 381 | "select_every_nth": 1, 382 | "videopreview": { 383 | "hidden": false, 384 | "paused": false, 385 | "params": { 386 | "frame_load_cap": 0, 387 | "skip_first_frames": 0, 388 | "force_rate": 0, 389 | "filename": "X://insert/path/here.mp4", 390 | "type": "path", 391 | "format": "video/mp4", 392 | "select_every_nth": 1 393 | } 394 | } 395 | } 396 | }, 397 | { 398 | "id": 8, 399 | "type": "StreamRecorder", 400 | "pos": [ 401 | 210, 402 | 80 403 | ], 404 | "size": { 405 | "0": 315, 406 | "1": 322 407 | }, 408 | "flags": {}, 409 | "order": 0, 410 | "mode": 0, 411 | "outputs": [ 412 | { 413 | "name": "output_filename", 414 | "type": "STRING", 415 | "links": [ 416 | 2, 417 | 3 418 | ], 419 | "shape": 3, 420 | "slot_index": 0 421 | } 422 | ], 423 | "properties": { 424 | "Node name for S&R": "StreamRecorder" 425 | }, 426 | "widgets_values": [ 427 | "http://amdlive.ctnd.com.edgesuite.net/arirang_1ch/smil:arirang_1ch.smil/playlist.m3u8", 428 | "false", 429 | "23/11/44", 430 | 0, 431 | 0, 432 | 4, 433 | "Arirang Korea stream", 434 | "mp4", 435 | "h264(NVENC)", 436 | 30, 437 | "true", 438 | "E:\\BaiduSyncdisk\\Input" 439 | ] 440 | }, 441 | { 442 | "id": 16, 443 | "type": "PreviewImage", 444 | "pos": [ 445 | 1600, 446 | 40 447 | ], 448 | "size": [ 449 | 710, 450 | 820 451 | ], 452 | "flags": {}, 453 | "order": 6, 454 | "mode": 0, 455 | "inputs": [ 456 | { 457 | "name": "images", 458 | "type": "IMAGE", 459 | "link": 9 460 | } 461 | ], 462 | "properties": { 463 | "Node name for S&R": "PreviewImage" 464 | } 465 | } 466 | ], 467 | "links": [ 468 | [ 469 | 2, 470 | 8, 471 | 0, 472 | 9, 473 | 0, 474 | "STRING" 475 | ], 476 | [ 477 | 3, 478 | 8, 479 | 0, 480 | 10, 481 | 0, 482 | "STRING" 483 | ], 484 | [ 485 | 4, 486 | 9, 487 | 0, 488 | 11, 489 | 0, 490 | "STRING" 491 | ], 492 | [ 493 | 5, 494 | 9, 495 | 1, 496 | 12, 497 | 0, 498 | "VHS_VIDEOINFO" 499 | ], 500 | [ 501 | 6, 502 | 12, 503 | 3, 504 | 13, 505 | 0, 506 | "*" 507 | ], 508 | [ 509 | 7, 510 | 12, 511 | 4, 512 | 14, 513 | 0, 514 | "*" 515 | ], 516 | [ 517 | 8, 518 | 9, 519 | 0, 520 | 15, 521 | 2, 522 | "STRING" 523 | ], 524 | [ 525 | 9, 526 | 15, 527 | 0, 528 | 16, 529 | 0, 530 | "IMAGE" 531 | ] 532 | ], 533 | "groups": [], 534 | "config": {}, 535 | "extra": { 536 | "ds": { 537 | "scale": 0.7247295000000004, 538 | "offset": { 539 | "0": 515.2600708007812, 540 | "1": 379.6805114746094 541 | } 542 | } 543 | }, 544 | "version": 0.4 545 | } -------------------------------------------------------------------------------- /workflows/workflow-VideoFormatConverter.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 10, 3 | "last_link_id": 10, 4 | "nodes": [ 5 | { 6 | "id": 7, 7 | "type": "VideoFormatConverter", 8 | "pos": [ 9 | 350, 10 | 230 11 | ], 12 | "size": { 13 | "0": 315, 14 | "1": 462 15 | }, 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "outputs": [ 20 | { 21 | "name": "video_info", 22 | "type": "VHS_VIDEOINFO", 23 | "links": [ 24 | 7 25 | ], 26 | "shape": 3, 27 | "slot_index": 0 28 | }, 29 | { 30 | "name": "output_filename", 31 | "type": "STRING", 32 | "links": [ 33 | 6, 34 | 10 35 | ], 36 | "shape": 3, 37 | "slot_index": 1 38 | } 39 | ], 40 | "properties": { 41 | "Node name for S&R": "VideoFormatConverter" 42 | }, 43 | "widgets_values": [ 44 | "E:\\BaiduSyncdisk\\Input\\2024-07-06-163618.mov", 45 | "true", 46 | "Test0709", 47 | "mov", 48 | "h264(NVENC)", 49 | 15, 50 | "25", 51 | "enable", 52 | "960", 53 | "800", 54 | "bicubic", 55 | "fill", 56 | "aac", 
57 | "192", 58 | "stereo", 59 | "48000", 60 | "E:\\BaiduSyncdisk\\Input" 61 | ] 62 | }, 63 | { 64 | "id": 3, 65 | "type": "VHS_VideoInfo", 66 | "pos": [ 67 | 700, 68 | 510 69 | ], 70 | "size": { 71 | "0": 393, 72 | "1": 206 73 | }, 74 | "flags": {}, 75 | "order": 1, 76 | "mode": 0, 77 | "inputs": [ 78 | { 79 | "name": "video_info", 80 | "type": "VHS_VIDEOINFO", 81 | "link": 7 82 | } 83 | ], 84 | "outputs": [ 85 | { 86 | "name": "source_fps🟨", 87 | "type": "FLOAT", 88 | "links": null, 89 | "shape": 3 90 | }, 91 | { 92 | "name": "source_frame_count🟨", 93 | "type": "INT", 94 | "links": null, 95 | "shape": 3 96 | }, 97 | { 98 | "name": "source_duration🟨", 99 | "type": "FLOAT", 100 | "links": null, 101 | "shape": 3 102 | }, 103 | { 104 | "name": "source_width🟨", 105 | "type": "INT", 106 | "links": null, 107 | "shape": 3 108 | }, 109 | { 110 | "name": "source_height🟨", 111 | "type": "INT", 112 | "links": null, 113 | "shape": 3 114 | }, 115 | { 116 | "name": "loaded_fps🟦", 117 | "type": "FLOAT", 118 | "links": null, 119 | "shape": 3 120 | }, 121 | { 122 | "name": "loaded_frame_count🟦", 123 | "type": "INT", 124 | "links": null, 125 | "shape": 3 126 | }, 127 | { 128 | "name": "loaded_duration🟦", 129 | "type": "FLOAT", 130 | "links": null, 131 | "shape": 3, 132 | "slot_index": 7 133 | }, 134 | { 135 | "name": "loaded_width🟦", 136 | "type": "INT", 137 | "links": [ 138 | 5 139 | ], 140 | "shape": 3, 141 | "slot_index": 8 142 | }, 143 | { 144 | "name": "loaded_height🟦", 145 | "type": "INT", 146 | "links": null, 147 | "shape": 3 148 | } 149 | ], 150 | "properties": { 151 | "Node name for S&R": "VHS_VideoInfo" 152 | }, 153 | "widgets_values": {} 154 | }, 155 | { 156 | "id": 9, 157 | "type": "VHS_LoadVideoPath", 158 | "pos": [ 159 | 710, 160 | 110 161 | ], 162 | "size": [ 163 | 320, 164 | 240 165 | ], 166 | "flags": {}, 167 | "order": 3, 168 | "mode": 0, 169 | "inputs": [ 170 | { 171 | "name": "meta_batch", 172 | "type": "VHS_BatchManager", 173 | "link": null 174 | }, 175 | { 176 | "name": "vae", 177 | "type": "VAE", 178 | "link": null 179 | }, 180 | { 181 | "name": "video", 182 | "type": "STRING", 183 | "link": 10, 184 | "widget": { 185 | "name": "video" 186 | }, 187 | "slot_index": 2 188 | } 189 | ], 190 | "outputs": [ 191 | { 192 | "name": "IMAGE", 193 | "type": "IMAGE", 194 | "links": [ 195 | 9 196 | ], 197 | "shape": 3, 198 | "slot_index": 0 199 | }, 200 | { 201 | "name": "frame_count", 202 | "type": "INT", 203 | "links": null, 204 | "shape": 3 205 | }, 206 | { 207 | "name": "audio", 208 | "type": "VHS_AUDIO", 209 | "links": null, 210 | "shape": 3 211 | }, 212 | { 213 | "name": "video_info", 214 | "type": "VHS_VIDEOINFO", 215 | "links": null, 216 | "shape": 3 217 | } 218 | ], 219 | "title": "VHS_LoadVideoPath", 220 | "properties": { 221 | "Node name for S&R": "VHS_LoadVideoPath" 222 | }, 223 | "widgets_values": { 224 | "video": "X://insert/path/here.mp4", 225 | "force_rate": 0, 226 | "force_size": "Disabled", 227 | "custom_width": 512, 228 | "custom_height": 512, 229 | "frame_load_cap": 0, 230 | "skip_first_frames": 0, 231 | "select_every_nth": 1, 232 | "videopreview": { 233 | "hidden": false, 234 | "paused": false, 235 | "params": { 236 | "frame_load_cap": 0, 237 | "skip_first_frames": 0, 238 | "force_rate": 0, 239 | "filename": "X://insert/path/here.mp4", 240 | "type": "path", 241 | "format": "video/mp4", 242 | "select_every_nth": 1 243 | } 244 | } 245 | } 246 | }, 247 | { 248 | "id": 10, 249 | "type": "PreviewImage", 250 | "pos": [ 251 | 1050, 252 | -100 253 | ], 254 | "size": [ 255 | 350, 256 | 
450 257 | ], 258 | "flags": {}, 259 | "order": 5, 260 | "mode": 0, 261 | "inputs": [ 262 | { 263 | "name": "images", 264 | "type": "IMAGE", 265 | "link": 9 266 | } 267 | ], 268 | "properties": { 269 | "Node name for S&R": "PreviewImage" 270 | } 271 | }, 272 | { 273 | "id": 4, 274 | "type": "ShowText|pysssss", 275 | "pos": [ 276 | 340, 277 | 90 278 | ], 279 | "size": { 280 | "0": 320, 281 | "1": 80 282 | }, 283 | "flags": {}, 284 | "order": 2, 285 | "mode": 0, 286 | "inputs": [ 287 | { 288 | "name": "text", 289 | "type": "STRING", 290 | "link": 6, 291 | "widget": { 292 | "name": "text" 293 | } 294 | } 295 | ], 296 | "outputs": [ 297 | { 298 | "name": "STRING", 299 | "type": "STRING", 300 | "links": [], 301 | "shape": 6, 302 | "slot_index": 0 303 | } 304 | ], 305 | "properties": { 306 | "Node name for S&R": "ShowText|pysssss" 307 | }, 308 | "widgets_values": [ 309 | "", 310 | "E:/BaiduSyncdisk/Input/Test0709_0003.mov" 311 | ] 312 | }, 313 | { 314 | "id": 6, 315 | "type": "easy showAnything", 316 | "pos": [ 317 | 1180, 318 | 440 319 | ], 320 | "size": { 321 | "0": 210, 322 | "1": 80 323 | }, 324 | "flags": {}, 325 | "order": 4, 326 | "mode": 0, 327 | "inputs": [ 328 | { 329 | "name": "anything", 330 | "type": "*", 331 | "link": 5 332 | } 333 | ], 334 | "title": "easy showAnything", 335 | "properties": { 336 | "Node name for S&R": "easy showAnything" 337 | }, 338 | "widgets_values": [ 339 | "960" 340 | ] 341 | } 342 | ], 343 | "links": [ 344 | [ 345 | 5, 346 | 3, 347 | 8, 348 | 6, 349 | 0, 350 | "*" 351 | ], 352 | [ 353 | 6, 354 | 7, 355 | 1, 356 | 4, 357 | 0, 358 | "STRING" 359 | ], 360 | [ 361 | 7, 362 | 7, 363 | 0, 364 | 3, 365 | 0, 366 | "VHS_VIDEOINFO" 367 | ], 368 | [ 369 | 9, 370 | 9, 371 | 0, 372 | 10, 373 | 0, 374 | "IMAGE" 375 | ], 376 | [ 377 | 10, 378 | 7, 379 | 1, 380 | 9, 381 | 2, 382 | "STRING" 383 | ] 384 | ], 385 | "groups": [], 386 | "config": {}, 387 | "extra": { 388 | "ds": { 389 | "scale": 0.8769226950000011, 390 | "offset": { 391 | "0": -141.84170532226562, 392 | "1": 181.61871337890625 393 | } 394 | } 395 | }, 396 | "version": 0.4 397 | } --------------------------------------------------------------------------------
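For reference, the VideoFormatConverter class dumped above can also be exercised outside ComfyUI. The sketch below mirrors the widgets_values saved in workflows/workflow-VideoFormatConverter.json; the import path, the no-argument constructor, and the presence of ffmpeg on PATH (plus the packages from requirements.txt) are assumptions rather than anything the repository states.

# Minimal sketch: call the node's process_video() directly (assumed import path and environment).
from nodes.ZZX_VFC import VideoFormatConverter  # assumes the repo root is on sys.path

converter = VideoFormatConverter()
output_file, video_info = converter.process_video(
    video_path="E:\\BaiduSyncdisk\\Input\\2024-07-06-163618.mov",
    output_enabled="true",
    output_filename="Test0709",
    video_format="mov",
    codec="h264(NVENC)",       # mapped to the h264_nvenc encoder internally
    video_quality=15,          # 1 = highest bitrate, 40 = lowest bitrate
    frame_rate="25",
    opencl_acceleration="enable",
    video_width="960",
    video_height="800",
    scaling_filter="bicubic",
    processing_method="fill",
    audio_codec="aac",
    bit_rate="192",
    audio_channels="stereo",
    sample_rate="48000",
    output_path="E:\\BaiduSyncdisk\\Input",
)
print(output_file)                  # e.g. E:/BaiduSyncdisk/Input/Test0709_0000.mov
print(video_info["loaded_width"], video_info["loaded_height"])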