├── LICENSE ├── Light_A_Video_node.py ├── README.md ├── __init__.py ├── animate_repo └── config.json ├── assets ├── example_ic.png ├── example_in.png ├── example_w.png ├── input_animatediff │ ├── bear.mp4 │ ├── bloom.mp4 │ ├── boat.mp4 │ ├── camera.mp4 │ ├── car.mp4 │ ├── cat.mp4 │ ├── cat2.mp4 │ ├── coin.mp4 │ ├── cow.mp4 │ ├── dog2.mp4 │ ├── flowers.mp4 │ ├── fox.mp4 │ ├── girl.mp4 │ ├── girl2.mp4 │ ├── juice.mp4 │ ├── man.mp4 │ ├── man2.mp4 │ ├── man3.mp4 │ ├── man4.mp4 │ ├── plane.mp4 │ ├── toy.mp4 │ ├── wolf2.mp4 │ └── woman.mp4 ├── input_cog │ └── bear.mp4 └── input_wan │ ├── bear.mp4 │ └── man.mp4 ├── configs ├── cog_relight │ └── bear.yaml ├── relight │ ├── bear.yaml │ ├── boat.yaml │ ├── car.yaml │ ├── cat.yaml │ ├── cow.yaml │ ├── flowers.yaml │ ├── fox.yaml │ ├── girl.yaml │ ├── girl2.yaml │ ├── juice.yaml │ ├── man2.yaml │ ├── man4.yaml │ ├── plane.yaml │ ├── toy.yaml │ └── woman.yaml ├── relight_inpaint │ ├── bloom.yaml │ ├── camera.yaml │ ├── car.yaml │ ├── car_2.yaml │ ├── cat2.yaml │ ├── coin.yaml │ ├── dog2.yaml │ ├── man3.yaml │ ├── man3_2.yaml │ ├── water.yaml │ └── wolf2.yaml └── wan_relight │ ├── bear.yaml │ └── man.yaml ├── lav_cog_relight.py ├── lav_paint.py ├── lav_relight.py ├── lav_wan_relight.py ├── node_utils.py ├── pyproject.toml ├── requirements.txt ├── sam2.py ├── sd_repo ├── README.md ├── feature_extractor │ └── preprocessor_config.json ├── model_index.json ├── safety_checker │ └── config.json ├── scheduler │ └── scheduler_config.json ├── text_encoder │ └── config.json ├── tokenizer │ ├── merges.txt │ ├── special_tokens_map.json │ ├── tokenizer_config.json │ └── vocab.json ├── unet │ └── config.json ├── v1-inference.yaml └── vae │ └── config.json ├── src ├── animatediff_eul.py ├── animatediff_inpaint_pipe.py ├── animatediff_pipe.py ├── cogvideo_ddim.py ├── cogvideo_pipe.py ├── ic_light.py ├── ic_light_pipe.py ├── tools.py ├── wan_flowmatch.py └── wan_pipe.py └── utils └── tools.py /LICENSE: 
-------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Light_A_Video_node.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | import os 4 | import torch 5 | import gc 6 | import numpy as np 7 | from diffusers import StableDiffusionPipeline 8 | from .lav_relight import load_ic_light_model,infer_relight 9 | from .lav_wan_relight import load_ic_light_wan,infer_relight_wan 10 | from .lav_cog_relight import load_ic_light_cog,infer_relight_cog 11 | from .node_utils import load_images,tensor2pil_list 12 | import folder_paths 13 | from .src.ic_light import BGSource 14 | from .src.tools import set_all_seed 15 | 16 | MAX_SEED = np.iinfo(np.int32).max 17 | current_node_path = os.path.dirname(os.path.abspath(__file__)) 18 | device = torch.device( 19 | "cuda") if torch.cuda.is_available() else torch.device("cpu") 20 | 21 | # add checkpoints dir 22 | Light_A_Video_weigths_path = os.path.join(folder_paths.models_dir, "Light_A_Video") 23 | if not os.path.exists(Light_A_Video_weigths_path): 24 | 
os.makedirs(Light_A_Video_weigths_path) 25 | folder_paths.add_model_folder_path("Light_A_Video", Light_A_Video_weigths_path) 26 | 27 | 28 | class Light_A_Video_Loader: 29 | def __init__(self): 30 | pass 31 | 32 | @classmethod 33 | def INPUT_TYPES(s): 34 | return { 35 | "required": { 36 | "repo": ("STRING", {"default":"Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},), 37 | "model": (folder_paths.get_filename_list("checkpoints"),), 38 | "motion_adapter_model": (["none"] + folder_paths.get_filename_list("controlnet"),), 39 | "ic_light_model": (["none"] + folder_paths.get_filename_list("controlnet"),), 40 | "mode":(["relight","inpaint"],), 41 | }, 42 | } 43 | 44 | RETURN_TYPES = ("MODEL_Light_A_Video",) 45 | RETURN_NAMES = ("model",) 46 | FUNCTION = "loader_main" 47 | CATEGORY = "Light_A_Video" 48 | 49 | def loader_main(self, repo,model,motion_adapter_model, ic_light_model,mode): 50 | 51 | adopted_dtype = torch.float16 52 | # ic light 53 | sd_repo = os.path.join(current_node_path, "sd_repo") 54 | original_config_file=os.path.join(sd_repo,"v1-inference.yaml") #fix for desktop comfyUI 55 | if model!="none": 56 | ckpt_path=folder_paths.get_full_path("checkpoints",model) 57 | else: 58 | raise "no sd1.5 checkpoint" 59 | try: 60 | sd_pipe = StableDiffusionPipeline.from_single_file( 61 | ckpt_path,config=sd_repo, original_config=original_config_file) 62 | except: 63 | sd_pipe = StableDiffusionPipeline.from_single_file( 64 | ckpt_path, config=sd_repo,original_config_file=original_config_file) 65 | ic_light_model=folder_paths.get_full_path("controlnet",ic_light_model) 66 | 67 | # video model 68 | if repo: 69 | if "wan" in repo.lower(): 70 | print("***********Load wan diffuser ***********") 71 | pipe,ic_light_pipe=load_ic_light_wan(repo,sd_pipe,sd_repo,ckpt_path,ic_light_model,device,adopted_dtype) 72 | video_mode="wan" 73 | elif "cog" in repo.lower(): 74 | print("***********Load cogvideox diffuser ***********") 75 | 
pipe,ic_light_pipe=load_ic_light_cog(repo,sd_pipe,sd_repo,ckpt_path,ic_light_model,device,adopted_dtype) 76 | video_mode="cog" 77 | else: 78 | raise "no string match wan or cog,check your repo name" 79 | else: 80 | # load animatediff model 81 | motion_repo=os.path.join(current_node_path,"animate_repo") 82 | print("***********Load animatediff model ***********") 83 | motion_adapter_model=folder_paths.get_full_path("controlnet",motion_adapter_model) 84 | pipe,ic_light_pipe=load_ic_light_model(sd_pipe,ic_light_model,ckpt_path,sd_repo,motion_repo,motion_adapter_model,device,adopted_dtype,mode) 85 | video_mode="animate" 86 | print("***********Load model done ***********") 87 | gc.collect() 88 | torch.cuda.empty_cache() 89 | return ({"model":pipe,"ic_light_pipe":ic_light_pipe,"mode":mode,"adopted_dtype":adopted_dtype,"video_mode":video_mode},) 90 | 91 | 92 | class Light_A_Video_Sampler: 93 | def __init__(self): 94 | pass 95 | 96 | @classmethod 97 | def INPUT_TYPES(s): 98 | return { 99 | "required": { 100 | "model": ("MODEL_Light_A_Video",), 101 | "images": ("IMAGE",), 102 | "relight_prompt": ("STRING", {"default": "a bear walking on the rock, nature lighting, soft light", "multiline": True}), 103 | "vdm_prompt": ("STRING", {"default": "a bear walking on the rock", "multiline": True}), 104 | "inpaint_prompt": ("STRING", {"default": "a car driving on the beach, sunset over sea", "multiline": True}), 105 | "n_prompt": ("STRING", {"default": "bad quality, worse quality", "multiline": True}), 106 | "seed": ("INT", {"default": 0, "min": 0, "max": MAX_SEED}), 107 | "num_step": ("INT", {"default": 25, "min": 1, "max": 1024, "step": 1, "display": "number"}), 108 | "strength": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.05}), 109 | "text_guide_scale": ("INT", {"default": 2, "min": 1, "max": 20, "step": 1, "display": "number"}), 110 | "width": ("INT", {"default": 512, "min": 64, "max": 2048, "step": 16, "display": "number"}), 111 | "height": ("INT", {"default": 
512, "min": 64, "max": 2048, "step": 16, "display": "number"}), 112 | "num_frames": ("INT", {"default": 49, "min": 10, "max": 1024, "step": 1, "display": "number"}), 113 | "bg_target": (["LEFT", "RIGHT", "TOP", "BOTTOM","NONE",],), 114 | "mask_repo": ("STRING", {"default": "ZhengPeng7/BiRefNet"},),}, 115 | "optional": {"mask_img": ("IMAGE",), 116 | "fps": ("FLOAT", {"default": 8.0, "min": 8.0, "max": 100.0, "step": 0.1}), 117 | }, 118 | } 119 | 120 | RETURN_TYPES = ("IMAGE", ) 121 | RETURN_NAMES = ("image",) 122 | FUNCTION = "sampler_main" 123 | CATEGORY = "Light_A_Video" 124 | 125 | def sampler_main(self, model,images,relight_prompt,vdm_prompt,inpaint_prompt, n_prompt,seed, num_step, strength,text_guide_scale,width, height,num_frames,bg_target,mask_repo,**kwargs): 126 | set_all_seed(42) 127 | 128 | local_sam=os.path.join(Light_A_Video_weigths_path,"sam2_b.pt") 129 | if not os.path.exists(local_sam): 130 | local_sam="sam2_b.pt" 131 | ref_image_list=tensor2pil_list(images,width,height) 132 | 133 | if bg_target=="NONE": 134 | bg_source = BGSource.NONE 135 | elif bg_target=="LEFT": 136 | bg_source = BGSource.LEFT 137 | elif bg_target=="RIGHT": 138 | bg_source = BGSource.RIGHT 139 | elif bg_target=="TOP": 140 | bg_source = BGSource.TOP 141 | else: 142 | bg_source = BGSource.BOTTOM 143 | 144 | mask_img=kwargs.get("mask_img",None) 145 | fps=kwargs.get("fps",8.0) 146 | ic_light_pipe=model.get("ic_light_pipe") 147 | pipe=model.get("model") 148 | video_mode=model.get("video_mode") 149 | mode=model.get("mode") 150 | adopted_dtype=model.get("adopted_dtype") 151 | 152 | if isinstance(mask_img,torch.Tensor): 153 | mask_list=tensor2pil_list(mask_img,width,height) 154 | # for i,mask in enumerate(mask_list): 155 | # mask.save(f"mask{i}.png") 156 | if len(mask_list)==1: 157 | mask_list=mask_list*len(ref_image_list) 158 | print("not enough mask, repeat mask to match the number of images") 159 | else: 160 | mask_list=None 161 | 162 | print("***********Start infer ***********") 163 | 
if video_mode=="wan": 164 | iamge = infer_relight_wan(ic_light_pipe,pipe,strength,num_step,text_guide_scale,seed,width,height,n_prompt,relight_prompt,vdm_prompt,ref_image_list,bg_source,num_frames,mode,mask_list,device,adopted_dtype,mask_repo,fps,local_sam) 165 | elif video_mode=="cog": 166 | iamge = infer_relight_cog(ic_light_pipe,pipe,strength,num_step,text_guide_scale,seed,width,height,n_prompt,relight_prompt,vdm_prompt,ref_image_list,bg_source,num_frames,mode,mask_list,device,adopted_dtype,mask_repo,fps,local_sam) 167 | else: 168 | iamge = infer_relight(ic_light_pipe,pipe,strength,num_step,text_guide_scale,seed,width,height,n_prompt,relight_prompt,inpaint_prompt,ref_image_list,bg_source,mode,mask_list,device,adopted_dtype,mask_repo,fps,local_sam) 169 | 170 | torch.cuda.empty_cache() 171 | 172 | return (load_images(iamge),) 173 | 174 | 175 | NODE_CLASS_MAPPINGS = { 176 | "Light_A_Video_Loader": Light_A_Video_Loader, 177 | "Light_A_Video_Sampler": Light_A_Video_Sampler, 178 | } 179 | 180 | NODE_DISPLAY_NAME_MAPPINGS = { 181 | "Light_A_Video_Loader": "Light_A_Video_Loader", 182 | "Light_A_Video_Sampler": "Light_A_Video_Sampler", 183 | } 184 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI_Light_A_Video 2 | [Light-A-Video](https://github.com/bcmi/Light-A-Video): Training-free Video Relighting via Progressive Light Fusion,you can use it in comfyUI 3 | 4 | 5 | # Update 6 | * support ‘cogvideox’ and ‘wan2.1 diffusers’ 视频底模同步官方代码,支持cogvideox 和 ‘万相2.1 diffusers’ 7 | * wan2.1 image size set to 832 * 480, cogvideox image size set to 720 * 480 万相的图片尺寸设置832 * 480,cog设置720 * 480 8 | * wan2.1 need 4090 or more Vram,目前还没有优化,主要是wan的T5太大了 9 | 10 | # 1. 
Installation 11 | 12 | In the ./ComfyUI/custom_nodes directory, run the following: 13 | ``` 14 | git clone https://github.com/smthemex/ComfyUI_Light_A_Video.git 15 | ``` 16 | --- 17 | 18 | # 2. Requirements 19 | ``` 20 | pip install -r requirements.txt 21 | ``` 22 | * If you use sam2 to get the mask image (inpaint mode), you need 'ultralytics>=8.3.51'; a slightly older version may also work but is untested. 使用sam2模式获取mask图片时(内绘模式),需要'ultralytics>=8.3.51',可能低一两个版本也能用,不测试了。 23 | 24 | * Wan2.1 needs the diffusers main branch, so install it as below/万相需要的diffuser版本太新,需要按以下方法安装: 25 | ``` 26 | pip install git+https://github.com/huggingface/diffusers 27 | 28 | or 29 | 30 | git clone https://github.com/huggingface/diffusers.git 31 | cd diffusers 32 | pip install -e ".[torch]" 33 | ``` 34 | # 3. Model 35 | **3.1 Base models** 36 | * any SD1.5 checkpoint 37 | ``` 38 | -- ComfyUI/models/checkpoints 39 | ├── any sd1.5 checkpoints 40 | ``` 41 | * iclight_sd15_fc.safetensors from [here](https://huggingface.co/lllyasviel/ic-light/tree/main) 42 | ``` 43 | -- ComfyUI/models/controlnet 44 | ├── iclight_sd15_fc.safetensors 45 | ``` 46 | **3.2 Use AnimateDiff** 47 | * animatediff-motion-adapter-v1-5-3.safetensors from [here](https://huggingface.co/guoyww/animatediff-motion-adapter-v1-5-3/tree/main) 48 | ``` 49 | -- ComfyUI/models/controlnet 50 | ├── animatediff-motion-adapter-v1-5-3.safetensors # rename or not 或者随便换个名字 51 | ``` 52 | * If you use [sam2](https://github.com/ultralytics/assets/releases/download/v8.3.0/sam2_b.pt) to get the mask image, download this model (it is not needed when using BiRefNet); sam2 prompts at the frame center, so keep the subject centered. 如果使用sam2需要才下载模型,你使用BiRefNet是不用的(当然要下BiRefNet模型),注意sam2的注意点在正中,所以主体最好在中间。 53 | ``` 54 | -- ComfyUI/models/Light_A_Video 55 | ├── sam2_b.pt #会自动下载 56 | ``` 57 | **3.3 Use Wan2.1** 58 | * Fill in the repo [Wan-AI/Wan2.1-T2V-1.3B-Diffusers](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B-Diffusers) or a local repo path (使用抱脸的repo在线下载或者预下载存放在本地的本地repo地址) 59 | 60 | **3.4 Use CogVideoX** 61 | * Fill in the repo [THUDM/CogVideoX-2b](https://huggingface.co/THUDM/CogVideoX-2b/tree/main) or a local repo path (使用抱脸的repo在线下载或者预下载存放在本地的本地repo地址) 62 | 63 | # 4. Tips 64 | * 
第二条为cog和wan专用prompt,只需要填写主体,比如一只熊什么的,不要填写灯光(The second prompt is only for cog and wan: describe just the subject, e.g. a bear, and do not describe the lighting); 65 | * The middle prompt is used for inpaint mode: describe the subject only, with no lighting description; 66 | * 中间的prompt是用于内绘模式的,无需填写灯光提示,需要填写主体相关的prompt; 67 | * mask_repo: to get the mask, either fill in 'ZhengPeng7/BiRefNet' (or a local absolute path), leave it empty to use sam2 automatically, or connect a mask video to the 'mask_img' input; 68 | * 获取mask的方法要么填'ZhengPeng7/BiRefNet'的repo或者本地绝对地址,要么不填,会自动用sam2模式,或者用mask_img接口连入mask视频; 69 | 70 | 71 | # 5. Examples 72 | * wan2.1 ic-light 73 | ![](https://github.com/smthemex/ComfyUI_Light_A_Video/blob/main/assets/example_w.png) 74 | * animatediff ic-light 75 | ![](https://github.com/smthemex/ComfyUI_Light_A_Video/blob/main/assets/example_ic.png) 76 | * animatediff inpaint 77 | ![](https://github.com/smthemex/ComfyUI_Light_A_Video/blob/main/assets/example_in.png) 78 | 79 | 80 | # Citation 81 | ``` 82 | @article{zhou2025light, 83 | title={Light-A-Video: Training-free Video Relighting via Progressive Light Fusion}, 84 | author={Zhou, Yujie and Bu, Jiazi and Ling, Pengyang and Zhang, Pan and Wu, Tong and Huang, Qidong and Li, Jinsong and Dong, Xiaoyi and Zang, Yuhang and Cao, Yuhang and others}, 85 | journal={arXiv preprint arXiv:2502.08590}, 86 | year={2025} 87 | } 88 | ``` 89 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .Light_A_Video_node import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 3 | 4 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] 5 | -------------------------------------------------------------------------------- /animate_repo/config.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "MotionAdapter", 3 | "_diffusers_version": "0.25.0.dev0", 4 | "block_out_channels": [ 5 | 320, 6 | 640, 7 | 1280, 8 | 1280 9 | ], 10 | "motion_layers_per_block": 2, 11 | "motion_max_seq_length": 32, 12 | "motion_mid_block_layers_per_block": 1, 13 | "motion_norm_num_groups": 32, 14 | "motion_num_attention_heads": 8, 15 | "use_motion_mid_block": false 16 | } 17 | -------------------------------------------------------------------------------- /assets/example_ic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Light_A_Video/1773b3792857f4394a169fe8afbebd574cb0a429/assets/example_ic.png -------------------------------------------------------------------------------- /assets/example_in.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Light_A_Video/1773b3792857f4394a169fe8afbebd574cb0a429/assets/example_in.png -------------------------------------------------------------------------------- /assets/example_w.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Light_A_Video/1773b3792857f4394a169fe8afbebd574cb0a429/assets/example_w.png -------------------------------------------------------------------------------- /assets/input_animatediff/bear.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Light_A_Video/1773b3792857f4394a169fe8afbebd574cb0a429/assets/input_animatediff/bear.mp4 -------------------------------------------------------------------------------- /assets/input_animatediff/bloom.mp4: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/smthemex/ComfyUI_Light_A_Video/1773b3792857f4394a169fe8afbebd574cb0a429/assets/input_animatediff/bloom.mp4 -------------------------------------------------------------------------------- /assets/input_animatediff/boat.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Light_A_Video/1773b3792857f4394a169fe8afbebd574cb0a429/assets/input_animatediff/boat.mp4 -------------------------------------------------------------------------------- /assets/input_animatediff/camera.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Light_A_Video/1773b3792857f4394a169fe8afbebd574cb0a429/assets/input_animatediff/camera.mp4 -------------------------------------------------------------------------------- /assets/input_animatediff/car.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Light_A_Video/1773b3792857f4394a169fe8afbebd574cb0a429/assets/input_animatediff/car.mp4 -------------------------------------------------------------------------------- /assets/input_animatediff/cat.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Light_A_Video/1773b3792857f4394a169fe8afbebd574cb0a429/assets/input_animatediff/cat.mp4 -------------------------------------------------------------------------------- /assets/input_animatediff/cat2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Light_A_Video/1773b3792857f4394a169fe8afbebd574cb0a429/assets/input_animatediff/cat2.mp4 -------------------------------------------------------------------------------- /assets/input_animatediff/coin.mp4: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Light_A_Video/1773b3792857f4394a169fe8afbebd574cb0a429/assets/input_animatediff/coin.mp4 -------------------------------------------------------------------------------- /assets/input_animatediff/cow.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Light_A_Video/1773b3792857f4394a169fe8afbebd574cb0a429/assets/input_animatediff/cow.mp4 -------------------------------------------------------------------------------- /assets/input_animatediff/dog2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Light_A_Video/1773b3792857f4394a169fe8afbebd574cb0a429/assets/input_animatediff/dog2.mp4 -------------------------------------------------------------------------------- /assets/input_animatediff/flowers.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Light_A_Video/1773b3792857f4394a169fe8afbebd574cb0a429/assets/input_animatediff/flowers.mp4 -------------------------------------------------------------------------------- /assets/input_animatediff/fox.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Light_A_Video/1773b3792857f4394a169fe8afbebd574cb0a429/assets/input_animatediff/fox.mp4 -------------------------------------------------------------------------------- /assets/input_animatediff/girl.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Light_A_Video/1773b3792857f4394a169fe8afbebd574cb0a429/assets/input_animatediff/girl.mp4 -------------------------------------------------------------------------------- 
/assets/input_animatediff/girl2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Light_A_Video/1773b3792857f4394a169fe8afbebd574cb0a429/assets/input_animatediff/girl2.mp4 -------------------------------------------------------------------------------- /assets/input_animatediff/juice.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Light_A_Video/1773b3792857f4394a169fe8afbebd574cb0a429/assets/input_animatediff/juice.mp4 -------------------------------------------------------------------------------- /assets/input_animatediff/man.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Light_A_Video/1773b3792857f4394a169fe8afbebd574cb0a429/assets/input_animatediff/man.mp4 -------------------------------------------------------------------------------- /assets/input_animatediff/man2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Light_A_Video/1773b3792857f4394a169fe8afbebd574cb0a429/assets/input_animatediff/man2.mp4 -------------------------------------------------------------------------------- /assets/input_animatediff/man3.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Light_A_Video/1773b3792857f4394a169fe8afbebd574cb0a429/assets/input_animatediff/man3.mp4 -------------------------------------------------------------------------------- /assets/input_animatediff/man4.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Light_A_Video/1773b3792857f4394a169fe8afbebd574cb0a429/assets/input_animatediff/man4.mp4 
-------------------------------------------------------------------------------- /assets/input_animatediff/plane.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Light_A_Video/1773b3792857f4394a169fe8afbebd574cb0a429/assets/input_animatediff/plane.mp4 -------------------------------------------------------------------------------- /assets/input_animatediff/toy.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Light_A_Video/1773b3792857f4394a169fe8afbebd574cb0a429/assets/input_animatediff/toy.mp4 -------------------------------------------------------------------------------- /assets/input_animatediff/wolf2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Light_A_Video/1773b3792857f4394a169fe8afbebd574cb0a429/assets/input_animatediff/wolf2.mp4 -------------------------------------------------------------------------------- /assets/input_animatediff/woman.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Light_A_Video/1773b3792857f4394a169fe8afbebd574cb0a429/assets/input_animatediff/woman.mp4 -------------------------------------------------------------------------------- /assets/input_cog/bear.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Light_A_Video/1773b3792857f4394a169fe8afbebd574cb0a429/assets/input_cog/bear.mp4 -------------------------------------------------------------------------------- /assets/input_wan/bear.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Light_A_Video/1773b3792857f4394a169fe8afbebd574cb0a429/assets/input_wan/bear.mp4 
-------------------------------------------------------------------------------- /assets/input_wan/man.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Light_A_Video/1773b3792857f4394a169fe8afbebd574cb0a429/assets/input_wan/man.mp4 -------------------------------------------------------------------------------- /configs/cog_relight/bear.yaml: -------------------------------------------------------------------------------- 1 | n_prompt: "bad quality, worse quality" 2 | vdm_prompt: "a bear walking on the rock" 3 | relight_prompt: "a bear walking on the rock, nature lighting, soft light" 4 | video_path: "input_cog/bear.mp4" 5 | bg_source: "TOP" ## LEFT, RIGHT, BOTTOM, TOP 6 | save_path: "output" 7 | 8 | width: 720 9 | height: 480 10 | strength: 0.4 11 | gamma: 0.7 12 | num_step: 25 13 | text_guide_scale: 2 14 | seed: 42 -------------------------------------------------------------------------------- /configs/relight/bear.yaml: -------------------------------------------------------------------------------- 1 | n_prompt: "bad quality, worse quality" 2 | relight_prompt: "a bear walking on the rock, nature lighting, key light" 3 | video_path: "input_animatediff/bear.mp4" 4 | bg_source: "TOP" ## NONE, LEFT, RIGHT, BOTTOM, TOP 5 | save_path: "output" 6 | 7 | width: 512 8 | height: 512 9 | strength: 0.5 10 | gamma: 0.5 11 | num_step: 25 12 | text_guide_scale: 2 13 | seed: 42 -------------------------------------------------------------------------------- /configs/relight/boat.yaml: -------------------------------------------------------------------------------- 1 | n_prompt: "bad quality, worse quality" 2 | relight_prompt: "a boat floating on the sea, sunset" 3 | video_path: "input_animatediff/boat.mp4" 4 | bg_source: "TOP" ## NONE, LEFT, RIGHT, BOTTOM, TOP 5 | save_path: "output" 6 | 7 | width: 512 8 | height: 512 9 | strength: 0.5 10 | gamma: 0.5 11 | num_step: 25 12 | 
text_guide_scale: 2 13 | seed: 42 -------------------------------------------------------------------------------- /configs/relight/car.yaml: -------------------------------------------------------------------------------- 1 | n_prompt: "bad quality, worse quality" 2 | relight_prompt: "a car driving on the street, neon light" 3 | video_path: "input_animatediff/car.mp4" 4 | bg_source: "RIGHT" ## NONE, LEFT, RIGHT, BOTTOM, TOP 5 | save_path: "output" 6 | 7 | width: 512 8 | height: 512 9 | strength: 0.5 10 | gamma: 0.5 11 | num_step: 25 12 | text_guide_scale: 2 13 | seed: 2060 -------------------------------------------------------------------------------- /configs/relight/cat.yaml: -------------------------------------------------------------------------------- 1 | n_prompt: "bad quality, worse quality" 2 | relight_prompt: "a cat, red and blue neon light" 3 | video_path: "input_animatediff/cat.mp4" 4 | bg_source: "LEFT" ## NONE, LEFT, RIGHT, BOTTOM, TOP 5 | save_path: "output" 6 | 7 | width: 512 8 | height: 512 9 | strength: 0.5 10 | gamma: 0.5 11 | num_step: 25 12 | text_guide_scale: 2 13 | seed: 42 -------------------------------------------------------------------------------- /configs/relight/cow.yaml: -------------------------------------------------------------------------------- 1 | n_prompt: "bad quality, worse quality" 2 | relight_prompt: "a cow drinking water in the river, sunset" 3 | video_path: "input_animatediff/cow.mp4" 4 | bg_source: "RIGHT" ## NONE, LEFT, RIGHT, BOTTOM, TOP 5 | save_path: "output" 6 | 7 | width: 512 8 | height: 512 9 | strength: 0.5 10 | gamma: 0.5 11 | num_step: 25 12 | text_guide_scale: 2 13 | seed: 42 -------------------------------------------------------------------------------- /configs/relight/flowers.yaml: -------------------------------------------------------------------------------- 1 | n_prompt: "bad quality, worse quality, unclear, blurry" 2 | relight_prompt: "A basket of flowers, sunshine, hard light" 3 | video_path: 
"input_animatediff/flowers.mp4" 4 | bg_source: "LEFT" ## NONE, LEFT, RIGHT, BOTTOM, TOP 5 | save_path: "output" 6 | 7 | width: 512 8 | height: 512 9 | strength: 0.5 10 | gamma: 0.5 11 | num_step: 25 12 | text_guide_scale: 2 13 | seed: 42 -------------------------------------------------------------------------------- /configs/relight/fox.yaml: -------------------------------------------------------------------------------- 1 | n_prompt: "bad quality, worse quality" 2 | relight_prompt: "a fox, sunlight filtering through trees, dappled light" 3 | video_path: "input_animatediff/fox.mp4" 4 | bg_source: "LEFT" ## NONE, LEFT, RIGHT, BOTTOM, TOP 5 | save_path: "output" 6 | 7 | width: 512 8 | height: 512 9 | strength: 0.5 10 | gamma: 0.5 11 | num_step: 25 12 | text_guide_scale: 2 13 | seed: 42 -------------------------------------------------------------------------------- /configs/relight/girl.yaml: -------------------------------------------------------------------------------- 1 | n_prompt: "bad quality, worse quality" 2 | relight_prompt: "a girl, magic lit, sci-fi RGB glowing, key lighting" 3 | video_path: "input_animatediff/girl.mp4" 4 | bg_source: "BOTTOM" ## NONE, LEFT, RIGHT, BOTTOM, TOP 5 | save_path: "output" 6 | 7 | width: 512 8 | height: 512 9 | strength: 0.5 10 | gamma: 0.5 11 | num_step: 25 12 | text_guide_scale: 2 13 | seed: 42 -------------------------------------------------------------------------------- /configs/relight/girl2.yaml: -------------------------------------------------------------------------------- 1 | n_prompt: "bad quality, worse quality" 2 | relight_prompt: "an anime girl, neon light" 3 | video_path: "input_animatediff/girl2.mp4" 4 | bg_source: "RIGHT" ## NONE, LEFT, RIGHT, BOTTOM, TOP 5 | save_path: "output" 6 | 7 | width: 512 8 | height: 512 9 | strength: 0.5 10 | gamma: 0.5 11 | num_step: 25 12 | text_guide_scale: 2 13 | seed: 42 -------------------------------------------------------------------------------- 
/configs/relight/juice.yaml: -------------------------------------------------------------------------------- 1 | n_prompt: "bad quality, worse quality" 2 | relight_prompt: "Pour juice into a glass, magic golden lit" 3 | video_path: "input_animatediff/juice.mp4" 4 | bg_source: "RIGHT" ## NONE, LEFT, RIGHT, BOTTOM, TOP 5 | save_path: "output" 6 | 7 | width: 512 8 | height: 512 9 | strength: 0.5 10 | gamma: 0.5 11 | num_step: 25 12 | text_guide_scale: 2 13 | seed: 42 -------------------------------------------------------------------------------- /configs/relight/man2.yaml: -------------------------------------------------------------------------------- 1 | n_prompt: "bad quality, worse quality" 2 | relight_prompt: "handsome man with glasses, shadow from window, sunshine" 3 | video_path: "input_animatediff/man2.mp4" 4 | bg_source: "RIGHT" ## NONE, LEFT, RIGHT, BOTTOM, TOP 5 | save_path: "output" 6 | 7 | width: 512 8 | height: 512 9 | strength: 0.5 10 | gamma: 0.5 11 | num_step: 25 12 | text_guide_scale: 2 13 | seed: 42 -------------------------------------------------------------------------------- /configs/relight/man4.yaml: -------------------------------------------------------------------------------- 1 | n_prompt: "bad quality, worse quality" 2 | relight_prompt: "handsome man with glasses, sunlight through the blinds" 3 | video_path: "input_animatediff/man4.mp4" 4 | bg_source: "LEFT" ## NONE, LEFT, RIGHT, BOTTOM, TOP 5 | save_path: "output" 6 | 7 | width: 512 8 | height: 512 9 | strength: 0.5 10 | gamma: 0.5 11 | num_step: 25 12 | text_guide_scale: 2 13 | seed: 42 -------------------------------------------------------------------------------- /configs/relight/plane.yaml: -------------------------------------------------------------------------------- 1 | n_prompt: "bad quality, worse quality" 2 | relight_prompt: "a plane on the runway, bottom neon light" 3 | video_path: "input_animatediff/plane.mp4" 4 | bg_source: "BOTTOM" ## NONE, LEFT, RIGHT, BOTTOM, TOP 5 | 
save_path: "output" 6 | 7 | width: 512 8 | height: 512 9 | strength: 0.5 10 | gamma: 0.5 11 | num_step: 25 12 | text_guide_scale: 2 13 | seed: 42 -------------------------------------------------------------------------------- /configs/relight/toy.yaml: -------------------------------------------------------------------------------- 1 | n_prompt: "bad quality, worse quality" 2 | relight_prompt: "a maneki-neko toy, cozy bedroom illumination" 3 | video_path: "input_animatediff/toy.mp4" 4 | bg_source: "RIGHT" ## NONE, LEFT, RIGHT, BOTTOM, TOP 5 | save_path: "output" 6 | 7 | width: 512 8 | height: 512 9 | strength: 0.5 10 | gamma: 0.5 11 | num_step: 25 12 | text_guide_scale: 2 13 | seed: 42 -------------------------------------------------------------------------------- /configs/relight/woman.yaml: -------------------------------------------------------------------------------- 1 | n_prompt: "bad quality, worse quality" 2 | relight_prompt: "a woman with curly hair, natural lighting, warm atmosphere" 3 | video_path: "input_animatediff/woman.mp4" 4 | bg_source: "LEFT" ## NONE, LEFT, RIGHT, BOTTOM, TOP 5 | save_path: "output" 6 | 7 | width: 512 8 | height: 512 9 | strength: 0.5 10 | gamma: 0.5 11 | num_step: 25 12 | text_guide_scale: 2 13 | seed: 42 -------------------------------------------------------------------------------- /configs/relight_inpaint/bloom.yaml: -------------------------------------------------------------------------------- 1 | n_prompt: "bad quality, worse quality" 2 | inpaint_prompt: "a red flower blooming in the river" 3 | relight_prompt: "a red flower blooming in the river, nature lighting" 4 | 5 | video_path: "input_animatediff/bloom.mp4" 6 | bg_source: "TOP" ## NONE, LEFT, RIGHT, BOTTOM, TOP 7 | save_path: "output" 8 | 9 | width: 512 10 | height: 512 11 | strength: 0.4 12 | gamma: 0.5 13 | num_step: 50 14 | text_guide_scale: 4 15 | seed: 8776 -------------------------------------------------------------------------------- 
/configs/relight_inpaint/camera.yaml: -------------------------------------------------------------------------------- 1 | n_prompt: "bad quality, worse quality" 2 | inpaint_prompt: "A tiny camera on a tray, cyberpunk" 3 | relight_prompt: "A tiny camera on a tray, cyberpunk, neon light" 4 | 5 | video_path: "input_animatediff/camera.mp4" 6 | bg_source: "LEFT" ## NONE, LEFT, RIGHT, BOTTOM, TOP 7 | save_path: "output" 8 | 9 | width: 512 10 | height: 512 11 | strength: 0.4 12 | gamma: 0.5 13 | num_step: 50 14 | text_guide_scale: 3 15 | seed: 1333 -------------------------------------------------------------------------------- /configs/relight_inpaint/car.yaml: -------------------------------------------------------------------------------- 1 | n_prompt: "bad quality, worse quality" 2 | inpaint_prompt: "a car driving on the street" 3 | relight_prompt: "a car driving on the street, neon light" 4 | 5 | video_path: "input_animatediff/car.mp4" 6 | bg_source: "RIGHT" ## NONE, LEFT, RIGHT, BOTTOM, TOP 7 | save_path: "output" 8 | 9 | width: 512 10 | height: 512 11 | strength: 0.5 12 | gamma: 0.5 13 | num_step: 50 14 | text_guide_scale: 2 15 | seed: 6561 -------------------------------------------------------------------------------- /configs/relight_inpaint/car_2.yaml: -------------------------------------------------------------------------------- 1 | n_prompt: "bad quality, worse quality" 2 | inpaint_prompt: "a car driving on the beach, sunset over sea" 3 | relight_prompt: "a car driving on the beach, sunset over sea, left light, shadow" 4 | 5 | video_path: "input_animatediff/car.mp4" 6 | bg_source: "LEFT" ## NONE, LEFT, RIGHT, BOTTOM, TOP 7 | save_path: "output" 8 | 9 | width: 512 10 | height: 512 11 | strength: 0.5 12 | gamma: 0.5 13 | num_step: 50 14 | text_guide_scale: 2 15 | seed: 2409 -------------------------------------------------------------------------------- /configs/relight_inpaint/cat2.yaml: 
-------------------------------------------------------------------------------- 1 | n_prompt: "bad quality, worse quality" 2 | inpaint_prompt: "A cat walking on a runway, red and blue neon lights on both sides" 3 | relight_prompt: "A cat walking on a runway, red and blue neon lights on both sides, key light" 4 | 5 | video_path: "input_animatediff/cat2.mp4" 6 | bg_source: "LEFT" ## NONE, LEFT, RIGHT, BOTTOM, TOP 7 | save_path: "output" 8 | 9 | width: 512 10 | height: 512 11 | strength: 0.5 12 | gamma: 0.5 13 | num_step: 50 14 | text_guide_scale: 5 15 | seed: 2949 -------------------------------------------------------------------------------- /configs/relight_inpaint/coin.yaml: -------------------------------------------------------------------------------- 1 | n_prompt: "bad quality, worse quality" 2 | inpaint_prompt: "A coin on the desk" 3 | relight_prompt: "A coin on the desk, natural lighting" 4 | 5 | video_path: "input_animatediff/coin.mp4" 6 | bg_source: "TOP" ## NONE, LEFT, RIGHT, BOTTOM, TOP 7 | save_path: "output" 8 | 9 | width: 512 10 | height: 512 11 | strength: 0.4 12 | gamma: 0.5 13 | num_step: 80 14 | text_guide_scale: 2 15 | seed: 4013 -------------------------------------------------------------------------------- /configs/relight_inpaint/dog2.yaml: -------------------------------------------------------------------------------- 1 | n_prompt: "bad quality, worse quality" 2 | inpaint_prompt: "a dog in the room, sunshine from window" 3 | relight_prompt: "a dog in the room, sunshine from window" 4 | 5 | video_path: "input_animatediff/dog2.mp4" 6 | bg_source: "RIGHT" ## NONE, LEFT, RIGHT, BOTTOM, TOP 7 | save_path: "output" 8 | 9 | width: 512 10 | height: 512 11 | strength: 0.4 12 | gamma: 0.5 13 | num_step: 50 14 | text_guide_scale: 2 15 | seed: 4550 -------------------------------------------------------------------------------- /configs/relight_inpaint/man3.yaml: -------------------------------------------------------------------------------- 1 | 
n_prompt: "bad quality, worse quality" 2 | inpaint_prompt: "A man in the classroom" 3 | relight_prompt: "A man in the classroom, sunshine from the window" 4 | 5 | video_path: "input_animatediff/man3.mp4" 6 | bg_source: "RIGHT" ## NONE, LEFT, RIGHT, BOTTOM, TOP 7 | save_path: "output" 8 | 9 | width: 512 10 | height: 512 11 | strength: 0.5 12 | gamma: 0.5 13 | num_step: 50 14 | text_guide_scale: 3 15 | seed: 3931 -------------------------------------------------------------------------------- /configs/relight_inpaint/man3_2.yaml: -------------------------------------------------------------------------------- 1 | n_prompt: "bad quality, worse quality" 2 | inpaint_prompt: "A man in a bar, left yellow and right purple neon lights" 3 | relight_prompt: "A man in a bar, left yellow and right purple neon lights, hard light" 4 | 5 | video_path: "input_animatediff/man3.mp4" 6 | bg_source: "RIGHT" ## NONE, LEFT, RIGHT, BOTTOM, TOP 7 | save_path: "output" 8 | 9 | width: 512 10 | height: 512 11 | strength: 0.5 12 | gamma: 0.5 13 | num_step: 50 14 | text_guide_scale: 4 15 | seed: 9528 -------------------------------------------------------------------------------- /configs/relight_inpaint/water.yaml: -------------------------------------------------------------------------------- 1 | n_prompt: "bad quality, worse quality" 2 | inpaint_prompt: "a glass of water, in the forest, magic golden lit" 3 | relight_prompt: "a glass of water, in the forest, magic golden lit, key light" 4 | 5 | video_path: "input_animatediff/water.mp4" 6 | bg_source: "TOP" ## NONE, LEFT, RIGHT, BOTTOM, TOP 7 | save_path: "output" 8 | 9 | width: 512 10 | height: 512 11 | strength: 0.4 12 | gamma: 0.5 13 | num_step: 50 14 | text_guide_scale: 4 15 | seed: 796 -------------------------------------------------------------------------------- /configs/relight_inpaint/wolf2.yaml: -------------------------------------------------------------------------------- 1 | n_prompt: "bad quality, worse quality" 2 | 
inpaint_prompt: "a wolf stands in an alley, detailed face, neon, Wong Kar-wai, warm" 3 | relight_prompt: "a wolf stands in an alley, detailed face, neon, Wong Kar-wai, warm, right light" 4 | 5 | video_path: "input_animatediff/wolf2.mp4" 6 | bg_source: "RIGHT" ## NONE, LEFT, RIGHT, BOTTOM, TOP 7 | save_path: "output" 8 | 9 | width: 512 10 | height: 512 11 | strength: 0.5 12 | gamma: 0.5 13 | num_step: 50 14 | text_guide_scale: 5 15 | seed: 2172 -------------------------------------------------------------------------------- /configs/wan_relight/bear.yaml: -------------------------------------------------------------------------------- 1 | n_prompt: "bad quality, worse quality" 2 | vdm_prompt: "a bear walking on the rock" 3 | relight_prompt: "a bear walking on the rock, nature lighting, soft light" 4 | video_path: "input_wan/bear.mp4" 5 | bg_source: "TOP" ## LEFT, RIGHT, BOTTOM, TOP 6 | save_path: "output" 7 | 8 | width: 832 9 | height: 480 10 | num_frames: 49 11 | strength: 0.4 12 | gamma: 0.7 13 | num_step: 25 14 | text_guide_scale: 2 15 | seed: 42 -------------------------------------------------------------------------------- /configs/wan_relight/man.yaml: -------------------------------------------------------------------------------- 1 | n_prompt: "bad quality, worse quality" 2 | vdm_prompt: "a man walking in the factory" 3 | relight_prompt: "a man walking in the factory, red and blue neon light" 4 | video_path: "input_wan/man.mp4" 5 | bg_source: "LEFT" ## LEFT, RIGHT, BOTTOM, TOP 6 | save_path: "output" 7 | 8 | width: 512 9 | height: 512 10 | num_frames: 49 11 | strength: 0.4 12 | gamma: 0.7 13 | num_step: 25 14 | text_guide_scale: 2 15 | seed: 42 -------------------------------------------------------------------------------- /lav_cog_relight.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import imageio 4 | import argparse 5 | import safetensors.torch as sf 6 | from omegaconf import 
OmegaConf 7 | import torch.nn.functional as F 8 | from torch.hub import download_url_to_file 9 | 10 | from diffusers import CogVideoXDDIMScheduler 11 | from transformers import CLIPTextModel, CLIPTokenizer 12 | from diffusers.models.attention_processor import AttnProcessor2_0 13 | from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler 14 | 15 | from .src.ic_light import BGSource 16 | from .src.ic_light_pipe import StableDiffusionImg2ImgPipeline 17 | from .src.cogvideo_pipe import CogVideoXVideoToVideoPipeline 18 | from .utils.tools import set_all_seed, read_video 19 | 20 | 21 | 22 | def load_ic_light_cog(repo,sd_pipe,sd_repo,ckpt_path,ic_light_model,device,adopted_dtype): 23 | 24 | 25 | # config = OmegaConf.load(args.config) 26 | # device = torch.device('cuda') 27 | # adopted_dtype = torch.float16 28 | # set_all_seed(42) 29 | 30 | ## vdm model 31 | pipe = CogVideoXVideoToVideoPipeline.from_pretrained(repo, torch_dtype=adopted_dtype) 32 | pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config) 33 | pipe = pipe.to(device=device, dtype=adopted_dtype) 34 | pipe.vae.requires_grad_(False) 35 | pipe.transformer.requires_grad_(False) 36 | 37 | ## module 38 | tokenizer = CLIPTokenizer.from_pretrained(sd_repo, subfolder="tokenizer") 39 | text_encoder = sd_pipe.text_encoder 40 | vae = sd_pipe.vae 41 | unet =sd_pipe.unet 42 | 43 | with torch.no_grad(): 44 | new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding) 45 | new_conv_in.weight.zero_() 46 | new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight) 47 | new_conv_in.bias = unet.conv_in.bias 48 | unet.conv_in = new_conv_in 49 | unet_original_forward = unet.forward 50 | 51 | def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs): 52 | c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample) 53 | c_concat = torch.cat([c_concat] * (sample.shape[0] // 
c_concat.shape[0]), dim=0) 54 | new_sample = torch.cat([sample, c_concat], dim=1) 55 | kwargs['cross_attention_kwargs'] = {} 56 | return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs) 57 | unet.forward = hooked_unet_forward 58 | 59 | ## ic-light model loader 60 | if not os.path.exists(ic_light_model): 61 | download_url_to_file(url='https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fc.safetensors', 62 | dst=ic_light_model) #TO DO: change to ic-light model 63 | sd_offset = sf.load_file(ic_light_model) 64 | sd_origin = unet.state_dict() 65 | sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()} 66 | unet.load_state_dict(sd_merged, strict=True) 67 | del sd_offset, sd_origin, sd_merged 68 | text_encoder = text_encoder.to(device=device, dtype=adopted_dtype) 69 | vae = vae.to(device=device, dtype=adopted_dtype) 70 | unet = unet.to(device=device, dtype=adopted_dtype) 71 | 72 | unet.set_attn_processor(AttnProcessor2_0()) 73 | vae.set_attn_processor(AttnProcessor2_0()) 74 | 75 | # Consistent light attention 76 | @torch.inference_mode() 77 | def custom_forward_CLA(self, 78 | hidden_states, 79 | gamma=0.7, 80 | encoder_hidden_states=None, 81 | attention_mask=None, 82 | cross_attention_kwargs=None 83 | ): 84 | 85 | batch_size, sequence_length, channel = hidden_states.shape 86 | 87 | residual = hidden_states 88 | input_ndim = hidden_states.ndim 89 | if input_ndim == 4: 90 | batch_size, channel, height, width = hidden_states.shape 91 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 92 | 93 | if attention_mask is not None: 94 | if attention_mask.shape[-1] != query.shape[1]: 95 | target_length = query.shape[1] 96 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) 97 | attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) 98 | if self.group_norm is not None: 99 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 
100 | if encoder_hidden_states is None: 101 | encoder_hidden_states = hidden_states 102 | 103 | query = self.to_q(hidden_states) 104 | key = self.to_k(encoder_hidden_states) 105 | value = self.to_v(encoder_hidden_states) 106 | inner_dim = key.shape[-1] 107 | head_dim = inner_dim // self.heads 108 | query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) 109 | key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) 110 | value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) 111 | 112 | hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) 113 | shape = query.shape 114 | mean_key = key.reshape(2,-1,shape[1],shape[2],shape[3]).mean(dim=1,keepdim=True).expand(-1,shape[0]//2,-1,-1,-1).reshape(shape[0],shape[1],shape[2],shape[3]) 115 | mean_value = value.reshape(2,-1,shape[1],shape[2],shape[3]).mean(dim=1,keepdim=True).expand(-1,shape[0]//2,-1,-1,-1).reshape(shape[0],shape[1],shape[2],shape[3]) 116 | hidden_states_mean = F.scaled_dot_product_attention(query, mean_key, mean_value, attn_mask=None, dropout_p=0.0, is_causal=False) 117 | 118 | hidden_states = (1-gamma)*hidden_states + gamma*hidden_states_mean 119 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim) 120 | hidden_states = hidden_states.to(query.dtype) 121 | hidden_states = self.to_out[0](hidden_states) 122 | hidden_states = self.to_out[1](hidden_states) 123 | 124 | if input_ndim == 4: 125 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 126 | 127 | if self.residual_connection: 128 | hidden_states = hidden_states + residual 129 | 130 | hidden_states = hidden_states / self.rescale_output_factor 131 | return hidden_states 132 | 133 | ### attention 134 | from types import MethodType 135 | @torch.inference_mode() 136 | def prep_unet_self_attention(unet): 137 | 138 | for name, module in unet.named_modules(): 139 | 
module_name = type(module).__name__ 140 | 141 | name_split_list = name.split(".") 142 | cond_1 = name_split_list[0] in "up_blocks" 143 | cond_2 = name_split_list[-1] in ('attn1') 144 | 145 | if "Attention" in module_name and cond_1 and cond_2: 146 | cond_3 = name_split_list[1] 147 | if cond_3 not in "3": 148 | module.forward = MethodType(custom_forward_CLA, module) 149 | 150 | return unet 151 | 152 | ## attn module 153 | unet = prep_unet_self_attention(unet) 154 | 155 | ## ic-light-scheduler 156 | ic_light_scheduler = DPMSolverMultistepScheduler( 157 | num_train_timesteps=1000, 158 | beta_start=0.00085, 159 | beta_end=0.012, 160 | algorithm_type="sde-dpmsolver++", 161 | use_karras_sigmas=True, 162 | steps_offset=1 163 | ) 164 | ic_light_pipe = StableDiffusionImg2ImgPipeline( 165 | vae=vae, 166 | text_encoder=text_encoder, 167 | tokenizer=tokenizer, 168 | unet=unet, 169 | scheduler=ic_light_scheduler, 170 | safety_checker=None, 171 | requires_safety_checker=False, 172 | feature_extractor=None, 173 | image_encoder=None 174 | ) 175 | ic_light_pipe = ic_light_pipe.to(device=device, dtype=adopted_dtype) 176 | ic_light_pipe.vae.requires_grad_(False) 177 | ic_light_pipe.unet.requires_grad_(False) 178 | 179 | pipe.enable_sequential_cpu_offload() 180 | pipe.vae.enable_slicing() 181 | pipe.vae.enable_tiling() 182 | 183 | return pipe,ic_light_pipe 184 | 185 | def infer_relight_cog(ic_light_pipe,pipe,strength,num_step,text_guide_scale,seed,image_width,image_height,n_prompt,relight_prompt,vdm_prompt,video_list,bg_source,num_frames,mode,mask_list,device,adopted_dtype,repo,fps,local_sam): 186 | ############################# params ###################################### 187 | # strength = config.get("strength", 0.4) 188 | # num_step = config.get("num_step", 25) 189 | # text_guide_scale = config.get("text_guide_scale", 2) 190 | # seed = config.get("seed") 191 | # image_width = config.get("width", 720) 192 | # image_height = config.get("height", 480) 193 | # negative_prompt = 
config.get("n_prompt", "") 194 | # vdm_prompt = config.get("vdm_prompt", "") 195 | # relight_prompt = config.get("relight_prompt", "") 196 | # video_path = config.get("video_path", "") 197 | # bg_source = BGSource[config.get("bg_source")] 198 | # save_path = config.get("save_path") 199 | 200 | ############################## infer ##################################### 201 | generator = torch.manual_seed(seed) 202 | # video_name = os.path.basename(video_path) 203 | # video_list, video_name = read_video(video_path, image_width, image_height) 204 | 205 | print("################## begin ##################") 206 | with torch.no_grad(): 207 | num_inference_steps = int(round(num_step / strength)) 208 | 209 | output = pipe( 210 | ic_light_pipe=ic_light_pipe, 211 | relight_prompt=relight_prompt, 212 | bg_source=bg_source, 213 | video=video_list, 214 | prompt=vdm_prompt, 215 | strength=strength, 216 | negative_prompt=n_prompt, 217 | guidance_scale=text_guide_scale, 218 | num_inference_steps=num_inference_steps, 219 | height=image_height, 220 | width=image_width, 221 | generator=generator, 222 | ) 223 | 224 | frames = output.frames[0] 225 | # results_path = f"{save_path}/relight_{video_name}" 226 | # imageio.mimwrite(results_path, frames, fps=14) 227 | # print(f"relight! 
\n prompt:{relight_prompt}, light:{bg_source.value}, save in {results_path}.") 228 | return frames 229 | 230 | # if __name__ == "__main__": 231 | # parser = argparse.ArgumentParser() 232 | 233 | # parser.add_argument("--sd_model", type=str, default="stablediffusionapi/realistic-vision-v51") 234 | # parser.add_argument("--vdm_model", type=str, default="THUDM/CogVideoX-2b") 235 | # parser.add_argument("--ic_light_model", type=str, default="./models/iclight_sd15_fc.safetensors") 236 | 237 | # parser.add_argument("--config", type=str, default="configs/cog_relight/bear.yaml", help="the config file for each sample.") 238 | 239 | # args = parser.parse_args() 240 | # main(args) 241 | -------------------------------------------------------------------------------- /lav_paint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import imageio 4 | import argparse 5 | from types import MethodType 6 | import safetensors.torch as sf 7 | import torch.nn.functional as F 8 | from omegaconf import OmegaConf 9 | from transformers import CLIPTextModel, CLIPTokenizer 10 | from diffusers import MotionAdapter, EulerAncestralDiscreteScheduler, AutoencoderKL 11 | from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler 12 | from diffusers.models.attention_processor import AttnProcessor2_0 13 | 14 | from .src.ic_light import BGSource 15 | from .src.ic_light import Relighter 16 | from .src.animatediff_inpaint_pipe import AnimateDiffVideoToVideoPipeline 17 | from .src.ic_light_pipe import StableDiffusionImg2ImgPipeline 18 | from .src.tools import read_video, read_mask,set_all_seed, get_fg_video 19 | 20 | 21 | 22 | 23 | def main(args): 24 | 25 | config = OmegaConf.load(args.config) 26 | device = torch.device('cuda') 27 | adopted_dtype = torch.float16 28 | set_all_seed(42) 29 | 30 | ## vdm model 31 | adapter = MotionAdapter.from_pretrained(args.motion_adapter_model) 32 | 
#MotionAdapter.from_single_file(args.motion_adapter_model, torch_dtype=adopted_dtype) 33 | ## pipeline 34 | pipe = AnimateDiffVideoToVideoPipeline.from_pretrained(args.sd_model, motion_adapter=adapter) 35 | 36 | eul_scheduler = EulerAncestralDiscreteScheduler.from_pretrained( 37 | args.sd_model, 38 | subfolder="scheduler", 39 | beta_schedule="linear", 40 | ) 41 | 42 | pipe.scheduler = eul_scheduler 43 | pipe.enable_vae_slicing() 44 | pipe = pipe.to(device=device, dtype=adopted_dtype) 45 | pipe.vae.requires_grad_(False) 46 | pipe.unet.requires_grad_(False) 47 | 48 | ## ic-light model 49 | tokenizer = CLIPTokenizer.from_pretrained(args.sd_model, subfolder="tokenizer") 50 | text_encoder = CLIPTextModel.from_pretrained(args.sd_model, subfolder="text_encoder") 51 | vae = AutoencoderKL.from_pretrained(args.sd_model, subfolder="vae") 52 | unet = UNet2DConditionModel.from_pretrained(args.sd_model, subfolder="unet") 53 | with torch.no_grad(): 54 | new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding) 55 | new_conv_in.weight.zero_() #torch.Size([320, 8, 3, 3]) 56 | new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight) 57 | new_conv_in.bias = unet.conv_in.bias 58 | unet.conv_in = new_conv_in 59 | unet_original_forward = unet.forward 60 | 61 | def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs): 62 | 63 | c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample) 64 | c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0) 65 | new_sample = torch.cat([sample, c_concat], dim=1) 66 | kwargs['cross_attention_kwargs'] = {} 67 | return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs) 68 | unet.forward = hooked_unet_forward 69 | 70 | ## ic-light model loader 71 | sd_offset = sf.load_file(args.ic_light_model) 72 | sd_origin = unet.state_dict() 73 | sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()} 
74 | unet.load_state_dict(sd_merged, strict=True) 75 | del sd_offset, sd_origin, sd_merged 76 | text_encoder = text_encoder.to(device=device, dtype=adopted_dtype) 77 | vae = vae.to(device=device, dtype=adopted_dtype) 78 | unet = unet.to(device=device, dtype=adopted_dtype) 79 | unet.set_attn_processor(AttnProcessor2_0()) 80 | vae.set_attn_processor(AttnProcessor2_0()) 81 | 82 | # Consistent light attention 83 | @torch.inference_mode() 84 | def custom_forward_CLA(self, 85 | hidden_states, 86 | gamma=config.get("gamma", 0.5), 87 | encoder_hidden_states=None, 88 | attention_mask=None, 89 | cross_attention_kwargs=None 90 | ): 91 | 92 | batch_size, sequence_length, channel = hidden_states.shape 93 | 94 | residual = hidden_states 95 | input_ndim = hidden_states.ndim 96 | if input_ndim == 4: 97 | batch_size, channel, height, width = hidden_states.shape 98 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 99 | 100 | if attention_mask is not None: 101 | if attention_mask.shape[-1] != query.shape[1]: 102 | target_length = query.shape[1] 103 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) 104 | attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) 105 | if self.group_norm is not None: 106 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 107 | if encoder_hidden_states is None: 108 | encoder_hidden_states = hidden_states 109 | 110 | query = self.to_q(hidden_states) 111 | key = self.to_k(encoder_hidden_states) 112 | value = self.to_v(encoder_hidden_states) 113 | inner_dim = key.shape[-1] 114 | head_dim = inner_dim // self.heads 115 | query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) 116 | key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) 117 | value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) 118 | 119 | hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, 
is_causal=False) 120 | shape = query.shape 121 | 122 | # addition key and value 123 | mean_key = key.reshape(2,-1,shape[1],shape[2],shape[3]).mean(dim=1,keepdim=True) 124 | mean_value = value.reshape(2,-1,shape[1],shape[2],shape[3]).mean(dim=1,keepdim=True) 125 | mean_key = mean_key.expand(-1,shape[0]//2,-1,-1,-1).reshape(shape[0],shape[1],shape[2],shape[3]) 126 | mean_value = mean_value.expand(-1,shape[0]//2,-1,-1,-1).reshape(shape[0],shape[1],shape[2],shape[3]) 127 | add_hidden_state = F.scaled_dot_product_attention(query, mean_key, mean_value, attn_mask=None, dropout_p=0.0, is_causal=False) 128 | 129 | # mix 130 | hidden_states = (1-gamma)*hidden_states + gamma*add_hidden_state 131 | 132 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim) 133 | hidden_states = hidden_states.to(query.dtype) 134 | hidden_states = self.to_out[0](hidden_states) 135 | hidden_states = self.to_out[1](hidden_states) 136 | 137 | if input_ndim == 4: 138 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 139 | 140 | if self.residual_connection: 141 | hidden_states = hidden_states + residual 142 | 143 | hidden_states = hidden_states / self.rescale_output_factor 144 | return hidden_states 145 | 146 | ### attention 147 | @torch.inference_mode() 148 | def prep_unet_self_attention(unet): 149 | for name, module in unet.named_modules(): 150 | module_name = type(module).__name__ 151 | 152 | name_split_list = name.split(".") 153 | cond_1 = name_split_list[0] in "up_blocks" 154 | cond_2 = name_split_list[-1] in ('attn1') 155 | 156 | if "Attention" in module_name and cond_1 and cond_2: 157 | cond_3 = name_split_list[1] 158 | if cond_3 not in "3": 159 | module.forward = MethodType(custom_forward_CLA, module) 160 | 161 | return unet 162 | 163 | ## consistency light attention 164 | unet = prep_unet_self_attention(unet) 165 | 166 | ## ic-light-scheduler 167 | ic_light_scheduler = DPMSolverMultistepScheduler( 168 | 
num_train_timesteps=1000, 169 | beta_start=0.00085, 170 | beta_end=0.012, 171 | algorithm_type="sde-dpmsolver++", 172 | use_karras_sigmas=True, 173 | steps_offset=1 174 | ) 175 | ic_light_pipe = StableDiffusionImg2ImgPipeline( 176 | vae=vae, 177 | text_encoder=text_encoder, 178 | tokenizer=tokenizer, 179 | unet=unet, 180 | scheduler=ic_light_scheduler, 181 | safety_checker=None, 182 | requires_safety_checker=False, 183 | feature_extractor=None, 184 | image_encoder=None 185 | ) 186 | ic_light_pipe = ic_light_pipe.to(device) 187 | 188 | ############################# params ###################################### 189 | strength = config.get("strength", 0.5) 190 | num_step = config.get("num_step", 50) 191 | text_guide_scale = config.get("text_guide_scale", 4) 192 | seed = config.get("seed") 193 | image_width = config.get("width", 512) 194 | image_height = config.get("height", 512) 195 | n_prompt = config.get("n_prompt", "") 196 | inpaint_prompt = config.get("inpaint_prompt", "") 197 | relight_prompt = config.get("relight_prompt", "") 198 | video_path = config.get("video_path", "") 199 | bg_source = BGSource[config.get("bg_source")] 200 | save_path = config.get("save_path") 201 | 202 | ############################## infer ##################################### 203 | generator = torch.manual_seed(seed) 204 | video_name = os.path.basename(video_path) 205 | video_list, video_name = read_video(video_path, image_width, image_height) 206 | mask_folder = os.path.join("masks_animatediff", video_name.split('.')[-2]) 207 | mask_list = read_mask(mask_folder) 208 | 209 | print("################## begin ##################") 210 | ## get foreground video 211 | fg_video_tensor = get_fg_video(video_list, mask_list, device, adopted_dtype) ## torch.Size([16, 3, 512, 512]) 212 | 213 | with torch.no_grad(): 214 | relighter = Relighter( 215 | pipeline=ic_light_pipe, 216 | relight_prompt=relight_prompt, 217 | bg_source=bg_source, 218 | generator=generator, 219 | ) 220 | vdm_init_latent = 
relighter(fg_video_tensor) 221 | 222 | ## infer 223 | num_inference_steps = num_step 224 | output = pipe( 225 | ic_light_pipe=ic_light_pipe, 226 | relight_prompt=relight_prompt, 227 | bg_source=bg_source, 228 | mask=mask_list, 229 | vdm_init_latent=vdm_init_latent, 230 | video=video_list, 231 | prompt=inpaint_prompt, 232 | strength=strength, 233 | negative_prompt=n_prompt, 234 | guidance_scale=text_guide_scale, 235 | num_inference_steps=num_inference_steps, 236 | height=image_height, 237 | width=image_width, 238 | generator=generator, 239 | ) 240 | 241 | frames = output.frames[0] 242 | results_path = f"{save_path}/inpaint_{video_name}" 243 | imageio.mimwrite(results_path, frames, fps=8) 244 | print(f"relight with bg generation! prompt:{relight_prompt}, light:{bg_source.value}, save in {results_path}.") 245 | 246 | if __name__ == "__main__": 247 | parser = argparse.ArgumentParser() 248 | 249 | parser.add_argument("--sd_model", type=str, default="stablediffusionapi/realistic-vision-v51") 250 | parser.add_argument("--motion_adapter_model", type=str, default="guoyww/animatediff-motion-adapter-v1-5-3") 251 | parser.add_argument("--ic_light_model", type=str, default="./models/iclight_sd15_fc.safetensors") 252 | 253 | parser.add_argument("--config", type=str, default="configs/relight_inpaint/car.yaml", help="the config file for each sample.") 254 | 255 | args = parser.parse_args() 256 | main(args) -------------------------------------------------------------------------------- /lav_relight.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from types import MethodType 4 | import safetensors.torch as sf 5 | import torch.nn.functional as F 6 | from transformers import CLIPTokenizer 7 | from diffusers import MotionAdapter, EulerAncestralDiscreteScheduler 8 | from diffusers import UNet2DConditionModel, DPMSolverMultistepScheduler 9 | from diffusers.models.attention_processor import AttnProcessor2_0 10 | from 
torch.hub import download_url_to_file 11 | from .src.tools import get_fg_video 12 | from .src.animatediff_pipe import AnimateDiffVideoToVideoPipeline 13 | from .src.animatediff_inpaint_pipe import AnimateDiffVideoToVideoPipeline as AnimateDiffVideoToVideoPipeline_inpaint 14 | from .src.ic_light_pipe import StableDiffusionImg2ImgPipeline 15 | from .src.ic_light import Relighter 16 | from .node_utils import image2masks 17 | import logging 18 | logging.basicConfig(level=logging.INFO) 19 | 20 | 21 | def load_ic_light_model(pipeline,ic_light_model,ckpt_path,sd_repo,motion_repo,motion_adapter_model,device,adopted_dtype,mode): 22 | # vdm model 23 | adapter = MotionAdapter.from_single_file(motion_adapter_model,config=motion_repo) 24 | Unet=UNet2DConditionModel.from_single_file(ckpt_path, config=os.path.join(sd_repo, "unet")) 25 | # animate pipeline 26 | Vae=pipeline.vae 27 | tokenizer = CLIPTokenizer.from_pretrained(sd_repo, subfolder="tokenizer") 28 | Text_encoder =pipeline.text_encoder 29 | del pipeline 30 | 31 | if mode=='relight': 32 | pipe=AnimateDiffVideoToVideoPipeline.from_pretrained(sd_repo,unet=Unet,vae=Vae,text_encoder=Text_encoder,tokenizer=tokenizer,motion_adapter=adapter) 33 | else: 34 | pipe=AnimateDiffVideoToVideoPipeline_inpaint.from_pretrained(sd_repo,unet=Unet,vae=Vae,text_encoder=Text_encoder,tokenizer=tokenizer,motion_adapter=adapter) 35 | eul_scheduler = EulerAncestralDiscreteScheduler.from_pretrained( 36 | sd_repo, 37 | subfolder="scheduler", 38 | beta_schedule="linear", 39 | ) 40 | 41 | pipe.scheduler = eul_scheduler 42 | pipe.enable_vae_slicing() 43 | pipe = pipe.to(device=device, dtype=adopted_dtype) 44 | pipe.vae.requires_grad_(False) 45 | pipe.unet.requires_grad_(False) 46 | 47 | unet=UNet2DConditionModel.from_single_file(ckpt_path,config=os.path.join(sd_repo, "unet")) 48 | tokenizer=pipe.tokenizer 49 | text_encoder=pipe.text_encoder 50 | vae=Vae 51 | 52 | 53 | with torch.no_grad(): 54 | new_conv_in = torch.nn.Conv2d(8, 
unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding) 55 | new_conv_in.weight.zero_() #torch.Size([320, 8, 3, 3]) 56 | new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight) 57 | new_conv_in.bias = unet.conv_in.bias 58 | unet.conv_in = new_conv_in 59 | unet_original_forward = unet.forward 60 | def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs): 61 | 62 | c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample) 63 | c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0) 64 | new_sample = torch.cat([sample, c_concat], dim=1) 65 | kwargs['cross_attention_kwargs'] = {} 66 | return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs) 67 | 68 | unet.forward = hooked_unet_forward 69 | 70 | if not os.path.exists(ic_light_model): 71 | download_url_to_file(url='https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fc.safetensors', 72 | dst=ic_light_model) #TO DO: change to ic-light model 73 | 74 | sd_offset = sf.load_file(ic_light_model) 75 | sd_origin = unet.state_dict() 76 | sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()} 77 | unet.load_state_dict(sd_merged, strict=True) 78 | del sd_offset, sd_origin, sd_merged 79 | text_encoder = text_encoder.to(device=device, dtype=adopted_dtype) 80 | vae = vae.to(device=device, dtype=adopted_dtype) 81 | unet = unet.to(device=device, dtype=adopted_dtype) 82 | 83 | unet.set_attn_processor(AttnProcessor2_0()) 84 | vae.set_attn_processor(AttnProcessor2_0()) 85 | 86 | @torch.inference_mode() 87 | def custom_forward_CLA(self, 88 | hidden_states, 89 | gamma=0.5, 90 | encoder_hidden_states=None, 91 | attention_mask=None, 92 | cross_attention_kwargs=None 93 | ): 94 | 95 | batch_size, sequence_length, channel = hidden_states.shape 96 | 97 | residual = hidden_states 98 | input_ndim = hidden_states.ndim 99 | if input_ndim == 4: 100 | batch_size, channel, height, width = 
hidden_states.shape 101 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 102 | 103 | if attention_mask is not None: 104 | if attention_mask.shape[-1] != query.shape[1]: 105 | target_length = query.shape[1] 106 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) 107 | attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) 108 | if self.group_norm is not None: 109 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 110 | if encoder_hidden_states is None: 111 | encoder_hidden_states = hidden_states 112 | 113 | query = self.to_q(hidden_states) 114 | key = self.to_k(encoder_hidden_states) 115 | value = self.to_v(encoder_hidden_states) 116 | inner_dim = key.shape[-1] 117 | head_dim = inner_dim // self.heads 118 | query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) 119 | key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) 120 | value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) 121 | 122 | hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) 123 | shape = query.shape 124 | 125 | # addition key and value 126 | mean_key = key.reshape(2,-1,shape[1],shape[2],shape[3]).mean(dim=1,keepdim=True) 127 | mean_value = value.reshape(2,-1,shape[1],shape[2],shape[3]).mean(dim=1,keepdim=True) 128 | mean_key = mean_key.expand(-1,shape[0]//2,-1,-1,-1).reshape(shape[0],shape[1],shape[2],shape[3]) 129 | mean_value = mean_value.expand(-1,shape[0]//2,-1,-1,-1).reshape(shape[0],shape[1],shape[2],shape[3]) 130 | add_hidden_state = F.scaled_dot_product_attention(query, mean_key, mean_value, attn_mask=None, dropout_p=0.0, is_causal=False) 131 | 132 | # mix 133 | hidden_states = (1-gamma)*hidden_states + gamma*add_hidden_state 134 | 135 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim) 136 | hidden_states = 
hidden_states.to(query.dtype) 137 | hidden_states = self.to_out[0](hidden_states) 138 | hidden_states = self.to_out[1](hidden_states) 139 | 140 | if input_ndim == 4: 141 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 142 | 143 | if self.residual_connection: 144 | hidden_states = hidden_states + residual 145 | 146 | hidden_states = hidden_states / self.rescale_output_factor 147 | return hidden_states 148 | 149 | ### attention 150 | @torch.inference_mode() 151 | def prep_unet_self_attention(unet): 152 | for name, module in unet.named_modules(): 153 | module_name = type(module).__name__ 154 | 155 | name_split_list = name.split(".") 156 | cond_1 = name_split_list[0] in "up_blocks" 157 | cond_2 = name_split_list[-1] in ('attn1') 158 | 159 | if "Attention" in module_name and cond_1 and cond_2: 160 | cond_3 = name_split_list[1] 161 | if cond_3 not in "3": 162 | module.forward = MethodType(custom_forward_CLA, module) 163 | 164 | return unet 165 | 166 | ## consistency light attention 167 | unet = prep_unet_self_attention(unet) 168 | ## ic-light-scheduler 169 | ic_light_scheduler = DPMSolverMultistepScheduler( 170 | num_train_timesteps=1000, 171 | beta_start=0.00085, 172 | beta_end=0.012, 173 | algorithm_type="sde-dpmsolver++", 174 | use_karras_sigmas=True, 175 | steps_offset=1 176 | ) 177 | ic_light_pipe = StableDiffusionImg2ImgPipeline( 178 | vae=vae, 179 | text_encoder=text_encoder, 180 | tokenizer=tokenizer, 181 | unet=unet, 182 | scheduler=ic_light_scheduler, 183 | safety_checker=None, 184 | requires_safety_checker=False, 185 | feature_extractor=None, 186 | image_encoder=None 187 | ) 188 | ic_light_pipe = ic_light_pipe.to(device) 189 | return pipe,ic_light_pipe 190 | 191 | 192 | def infer_relight(ic_light_pipe,pipe,strength,num_step,text_guide_scale,seed,image_width,image_height,n_prompt,relight_prompt,inpaint_prompt,video_list,bg_source,mode,mask_list,device,adopted_dtype,repo,fps,local_sam): 193 | 194 | generator = 
torch.manual_seed(seed) 195 | if mode == "inpaint": 196 | if mask_list is None: 197 | if repo: 198 | print("No mask_list provided, generating mask_list from repo...") 199 | mask_list = image2masks(repo,video_list) 200 | else: 201 | from .sam2 import get_mask 202 | mask_list=get_mask(video_list,fps,local_sam) 203 | print("No mask_list provided, use sam2 generating mask_list from video...") 204 | fg_video_tensor = get_fg_video(video_list, mask_list, device, adopted_dtype) 205 | with torch.no_grad(): 206 | relighter = Relighter( 207 | pipeline=ic_light_pipe, 208 | relight_prompt=relight_prompt, 209 | bg_source=bg_source, 210 | generator=generator, 211 | ) 212 | vdm_init_latent = relighter(fg_video_tensor) 213 | 214 | ## infer 215 | num_inference_steps = num_step 216 | output = pipe( 217 | ic_light_pipe=ic_light_pipe, 218 | relight_prompt=relight_prompt, 219 | bg_source=bg_source, 220 | mask=mask_list, 221 | vdm_init_latent=vdm_init_latent, 222 | video=video_list, 223 | prompt=inpaint_prompt, 224 | strength=strength, 225 | negative_prompt=n_prompt, 226 | guidance_scale=text_guide_scale, 227 | num_inference_steps=num_inference_steps, 228 | height=image_height, 229 | width=image_width, 230 | generator=generator, 231 | ) 232 | else: 233 | with torch.no_grad(): 234 | num_inference_steps = int(round(num_step / strength)) 235 | 236 | output = pipe( 237 | ic_light_pipe=ic_light_pipe, 238 | relight_prompt=relight_prompt, 239 | bg_source=bg_source, 240 | video=video_list, 241 | prompt=relight_prompt, 242 | strength=strength, 243 | negative_prompt=n_prompt, 244 | guidance_scale=text_guide_scale, 245 | num_inference_steps=num_inference_steps, 246 | height=image_height, 247 | width=image_width, 248 | generator=generator, 249 | ) 250 | frames = output.frames[0] 251 | return frames 252 | 253 | -------------------------------------------------------------------------------- /lav_wan_relight.py: -------------------------------------------------------------------------------- 1 | 
import os 2 | import torch 3 | import imageio 4 | #import argparse 5 | import numpy as np 6 | import safetensors.torch as sf 7 | 8 | from omegaconf import OmegaConf 9 | import torch.nn.functional as F 10 | from torch.hub import download_url_to_file 11 | import gc 12 | from transformers import CLIPTextModel, CLIPTokenizer 13 | from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler 14 | from diffusers.models.attention_processor import AttnProcessor2_0 15 | 16 | from .src.ic_light_pipe import StableDiffusionImg2ImgPipeline 17 | from .src.wan_pipe import WanVideoToVideoPipeline 18 | from .src.ic_light import BGSource 19 | from diffusers import FlowMatchEulerDiscreteScheduler 20 | 21 | 22 | def load_ic_light_wan(repo,sd_pipe,sd_repo,ckpt_path,ic_light_model,device,adopted_dtype): 23 | # config = OmegaConf.load(args.config) 24 | # device = torch.device('cuda') 25 | # adopted_dtype = torch.float16 26 | #set_all_seed(42) 27 | ## vdm model 28 | from diffusers import AutoencoderKLWan 29 | vae = AutoencoderKLWan.from_pretrained(repo, subfolder="vae", torch_dtype=adopted_dtype) 30 | pipe = WanVideoToVideoPipeline.from_pretrained(repo, vae=vae, torch_dtype=adopted_dtype) 31 | 32 | FlowMatching_scheduler = FlowMatchEulerDiscreteScheduler(shift=3.0) 33 | pipe.scheduler = FlowMatching_scheduler 34 | 35 | pipe = pipe.to(device=device, dtype=adopted_dtype) 36 | pipe.vae.requires_grad_(False) 37 | pipe.transformer.requires_grad_(False) 38 | 39 | ## module 40 | tokenizer = CLIPTokenizer.from_pretrained(sd_repo, subfolder="tokenizer") 41 | unet=sd_pipe.unet #UNet2DConditionModel.from_pretrained(sd_repo, subfolder="unet") 确认是否有效 42 | text_encoder = sd_pipe.text_encoder 43 | vae = sd_pipe.vae 44 | del sd_pipe 45 | gc.collect() 46 | 47 | 48 | with torch.no_grad(): 49 | new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding) 50 | new_conv_in.weight.zero_() #torch.Size([320, 8, 3, 3]) 51 
| new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight) 52 | new_conv_in.bias = unet.conv_in.bias 53 | unet.conv_in = new_conv_in 54 | unet_original_forward = unet.forward 55 | def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs): 56 | c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample) 57 | c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0) 58 | new_sample = torch.cat([sample, c_concat], dim=1) 59 | kwargs['cross_attention_kwargs'] = {} 60 | return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs) 61 | unet.forward = hooked_unet_forward 62 | 63 | ## ic-light model loader 64 | if not os.path.exists(ic_light_model): 65 | download_url_to_file(url='https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fc.safetensors', 66 | dst=ic_light_model) #TO DO: change to ic-light model 67 | sd_offset = sf.load_file(ic_light_model) 68 | sd_origin = unet.state_dict() 69 | sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()} 70 | unet.load_state_dict(sd_merged, strict=True) 71 | del sd_offset, sd_origin, sd_merged 72 | text_encoder = text_encoder.to(device=device, dtype=adopted_dtype) 73 | vae = vae.to(device=device, dtype=adopted_dtype) 74 | unet = unet.to(device=device, dtype=adopted_dtype) 75 | 76 | unet.set_attn_processor(AttnProcessor2_0()) 77 | vae.set_attn_processor(AttnProcessor2_0()) 78 | @torch.inference_mode() 79 | def custom_forward_CLA(self, 80 | hidden_states, 81 | gamma=0.7, 82 | encoder_hidden_states=None, 83 | attention_mask=None, 84 | cross_attention_kwargs=None 85 | ): 86 | 87 | batch_size, sequence_length, channel = hidden_states.shape 88 | 89 | residual = hidden_states 90 | input_ndim = hidden_states.ndim 91 | if input_ndim == 4: 92 | batch_size, channel, height, width = hidden_states.shape 93 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 94 | 95 | if attention_mask is not None: 96 | 
if attention_mask.shape[-1] != query.shape[1]: 97 | target_length = query.shape[1] 98 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) 99 | attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) 100 | if self.group_norm is not None: 101 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 102 | if encoder_hidden_states is None: 103 | encoder_hidden_states = hidden_states 104 | 105 | query = self.to_q(hidden_states) 106 | key = self.to_k(encoder_hidden_states) 107 | value = self.to_v(encoder_hidden_states) 108 | inner_dim = key.shape[-1] 109 | head_dim = inner_dim // self.heads 110 | query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) 111 | key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) 112 | value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) 113 | 114 | hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) 115 | shape = query.shape 116 | mean_key = key.reshape(2,-1,shape[1],shape[2],shape[3]).mean(dim=1,keepdim=True).expand(-1,shape[0]//2,-1,-1,-1).reshape(shape[0],shape[1],shape[2],shape[3]) 117 | mean_value = value.reshape(2,-1,shape[1],shape[2],shape[3]).mean(dim=1,keepdim=True).expand(-1,shape[0]//2,-1,-1,-1).reshape(shape[0],shape[1],shape[2],shape[3]) 118 | hidden_states_mean = F.scaled_dot_product_attention(query, mean_key, mean_value, attn_mask=None, dropout_p=0.0, is_causal=False) 119 | 120 | hidden_states = (1-gamma)*hidden_states + gamma*hidden_states_mean 121 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim) 122 | hidden_states = hidden_states.to(query.dtype) 123 | 124 | hidden_states = self.to_out[0](hidden_states) 125 | hidden_states = self.to_out[1](hidden_states) 126 | 127 | if input_ndim == 4: 128 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 129 | 130 | if 
self.residual_connection: 131 | hidden_states = hidden_states + residual 132 | 133 | hidden_states = hidden_states / self.rescale_output_factor 134 | 135 | 136 | 137 | 138 | 139 | return hidden_states 140 | 141 | ### attention 142 | from types import MethodType 143 | @torch.inference_mode() 144 | def prep_unet_self_attention(unet): 145 | 146 | for name, module in unet.named_modules(): 147 | module_name = type(module).__name__ 148 | 149 | name_split_list = name.split(".") 150 | cond_1 = name_split_list[0] in "up_blocks" 151 | cond_2 = name_split_list[-1] in ('attn1') 152 | 153 | if "Attention" in module_name and cond_1 and cond_2: 154 | cond_3 = name_split_list[1] 155 | if cond_3 not in "3": 156 | module.forward = MethodType(custom_forward_CLA, module) 157 | 158 | return unet 159 | 160 | ## attn module 161 | unet = prep_unet_self_attention(unet) 162 | 163 | ## ic-light-scheduler 164 | ic_light_scheduler = DPMSolverMultistepScheduler( 165 | num_train_timesteps=1000, 166 | beta_start=0.00085, 167 | beta_end=0.012, 168 | algorithm_type="sde-dpmsolver++", 169 | use_karras_sigmas=True, 170 | steps_offset=1 171 | ) 172 | ic_light_pipe = StableDiffusionImg2ImgPipeline( 173 | vae=vae, 174 | text_encoder=text_encoder, 175 | tokenizer=tokenizer, 176 | unet=unet, 177 | scheduler=ic_light_scheduler, 178 | safety_checker=None, 179 | requires_safety_checker=False, 180 | feature_extractor=None, 181 | image_encoder=None 182 | ) 183 | ic_light_pipe = ic_light_pipe.to(device=device, dtype=adopted_dtype) 184 | ic_light_pipe.vae.requires_grad_(False) 185 | ic_light_pipe.unet.requires_grad_(False) 186 | 187 | # pipeline = { 188 | # "transformer": pipe.transformer, 189 | # "text_encoder": pipe.text_encoder, 190 | # "vae": pipe.vae 191 | # } 192 | 193 | 194 | return pipe,ic_light_pipe 195 | 196 | 197 | def 
infer_relight_wan(ic_light_pipe,pipe,strength,num_step,text_guide_scale,seed,image_width,image_height,n_prompt,relight_prompt,vdm_prompt,video_list,bg_source,num_frames,mode,mask_list,device,adopted_dtype,repo,fps,local_sam): 198 | ############################## infer ##################################### 199 | generator = torch.manual_seed(seed) 200 | # video_name = os.path.basename(video_path) 201 | # video_list, video_name = read_video(video_path, image_width, image_height) 202 | print("################## begin ##################") 203 | with torch.no_grad(): 204 | num_inference_steps = int(round(num_step / strength)) 205 | output = pipe( 206 | ic_light_pipe=ic_light_pipe, 207 | relight_prompt=relight_prompt, 208 | bg_source=bg_source, 209 | video=video_list, 210 | prompt=vdm_prompt, 211 | negative_prompt=n_prompt, 212 | strength=strength, 213 | guidance_scale=text_guide_scale, 214 | num_inference_steps=num_inference_steps, 215 | height=image_height, 216 | num_frames=num_frames, 217 | width=image_width, 218 | generator=generator, 219 | ) 220 | 221 | frames = output.frames[0] 222 | # frames = (frames * 255).astype(np.uint8) 223 | # results_path = f"{save_path}/relight_{video_name}" 224 | # imageio.mimwrite(results_path, frames, fps=14) 225 | # print(f"relight! 
prompt:{relight_prompt}, light:{bg_source.value}, save in {results_path}.") 226 | return frames 227 | 228 | 229 | #def main(args): 230 | 231 | # config = OmegaConf.load(args.config) 232 | # device = torch.device('cuda') 233 | # adopted_dtype = torch.float16 234 | # set_all_seed(42) 235 | 236 | # ## vdm model 237 | # vae = AutoencoderKLWan.from_pretrained(args.vdm_model, subfolder="vae", torch_dtype=adopted_dtype) 238 | # pipe = WanVideoToVideoPipeline.from_pretrained(args.vdm_model, vae=vae, torch_dtype=adopted_dtype) 239 | 240 | # pipe = pipe.to(device=device, dtype=adopted_dtype) 241 | # pipe.vae.requires_grad_(False) 242 | # pipe.transformer.requires_grad_(False) 243 | 244 | # ## module 245 | # tokenizer = CLIPTokenizer.from_pretrained(args.sd_model, subfolder="tokenizer") 246 | # text_encoder = CLIPTextModel.from_pretrained(args.sd_model, subfolder="text_encoder") 247 | # vae = AutoencoderKL.from_pretrained(args.sd_model, subfolder="vae") 248 | # unet = UNet2DConditionModel.from_pretrained(args.sd_model, subfolder="unet") 249 | # with torch.no_grad(): 250 | # new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding) 251 | # new_conv_in.weight.zero_() #torch.Size([320, 8, 3, 3]) 252 | # new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight) 253 | # new_conv_in.bias = unet.conv_in.bias 254 | # unet.conv_in = new_conv_in 255 | # unet_original_forward = unet.forward 256 | 257 | # def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs): 258 | # c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample) 259 | # c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0) 260 | # new_sample = torch.cat([sample, c_concat], dim=1) 261 | # kwargs['cross_attention_kwargs'] = {} 262 | # return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs) 263 | # unet.forward = hooked_unet_forward 264 | 265 | # ## ic-light model loader 
266 | # if not os.path.exists(args.ic_light_model): 267 | # download_url_to_file(url='https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fc.safetensors', 268 | # dst=args.ic_light_model) 269 | # sd_offset = sf.load_file(args.ic_light_model) 270 | # sd_origin = unet.state_dict() 271 | # sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()} 272 | # unet.load_state_dict(sd_merged, strict=True) 273 | # del sd_offset, sd_origin, sd_merged 274 | # text_encoder = text_encoder.to(device=device, dtype=adopted_dtype) 275 | # vae = vae.to(device=device, dtype=adopted_dtype) 276 | # unet = unet.to(device=device, dtype=adopted_dtype) 277 | # unet.set_attn_processor(AttnProcessor2_0()) 278 | # vae.set_attn_processor(AttnProcessor2_0()) 279 | 280 | # Consistent light attention 281 | # @torch.inference_mode() 282 | # def custom_forward_CLA(self, 283 | # hidden_states, 284 | # gamma=0.7, 285 | # encoder_hidden_states=None, 286 | # attention_mask=None, 287 | # cross_attention_kwargs=None 288 | # ): 289 | 290 | # batch_size, sequence_length, channel = hidden_states.shape 291 | 292 | # residual = hidden_states 293 | # input_ndim = hidden_states.ndim 294 | # if input_ndim == 4: 295 | # batch_size, channel, height, width = hidden_states.shape 296 | # hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 297 | 298 | # if attention_mask is not None: 299 | # if attention_mask.shape[-1] != query.shape[1]: 300 | # target_length = query.shape[1] 301 | # attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) 302 | # attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) 303 | # if self.group_norm is not None: 304 | # hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 305 | # if encoder_hidden_states is None: 306 | # encoder_hidden_states = hidden_states 307 | 308 | # query = self.to_q(hidden_states) 309 | # key = self.to_k(encoder_hidden_states) 310 | # value = 
self.to_v(encoder_hidden_states) 311 | # inner_dim = key.shape[-1] 312 | # head_dim = inner_dim // self.heads 313 | # query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) 314 | # key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) 315 | # value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) 316 | 317 | # hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) 318 | # shape = query.shape 319 | # mean_key = key.reshape(2,-1,shape[1],shape[2],shape[3]).mean(dim=1,keepdim=True).expand(-1,shape[0]//2,-1,-1,-1).reshape(shape[0],shape[1],shape[2],shape[3]) 320 | # mean_value = value.reshape(2,-1,shape[1],shape[2],shape[3]).mean(dim=1,keepdim=True).expand(-1,shape[0]//2,-1,-1,-1).reshape(shape[0],shape[1],shape[2],shape[3]) 321 | # hidden_states_mean = F.scaled_dot_product_attention(query, mean_key, mean_value, attn_mask=None, dropout_p=0.0, is_causal=False) 322 | 323 | # hidden_states = (1-gamma)*hidden_states + gamma*hidden_states_mean 324 | # hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim) 325 | # hidden_states = hidden_states.to(query.dtype) 326 | 327 | # hidden_states = self.to_out[0](hidden_states) 328 | # hidden_states = self.to_out[1](hidden_states) 329 | 330 | # if input_ndim == 4: 331 | # hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 332 | 333 | # if self.residual_connection: 334 | # hidden_states = hidden_states + residual 335 | 336 | # hidden_states = hidden_states / self.rescale_output_factor 337 | 338 | 339 | # return hidden_states 340 | 341 | # ### attention 342 | # from types import MethodType 343 | # @torch.inference_mode() 344 | # def prep_unet_self_attention(unet): 345 | 346 | # for name, module in unet.named_modules(): 347 | # module_name = type(module).__name__ 348 | 349 | # name_split_list = name.split(".") 350 | # cond_1 = name_split_list[0] in 
"up_blocks" 351 | # cond_2 = name_split_list[-1] in ('attn1') 352 | 353 | # if "Attention" in module_name and cond_1 and cond_2: 354 | # cond_3 = name_split_list[1] 355 | # if cond_3 not in "3": 356 | # module.forward = MethodType(custom_forward_CLA, module) 357 | 358 | # return unet 359 | 360 | # ## attn module 361 | # unet = prep_unet_self_attention(unet) 362 | 363 | # ## ic-light-scheduler 364 | # ic_light_scheduler = DPMSolverMultistepScheduler( 365 | # num_train_timesteps=1000, 366 | # beta_start=0.00085, 367 | # beta_end=0.012, 368 | # algorithm_type="sde-dpmsolver++", 369 | # use_karras_sigmas=True, 370 | # steps_offset=1 371 | # ) 372 | # ic_light_pipe = StableDiffusionImg2ImgPipeline( 373 | # vae=vae, 374 | # text_encoder=text_encoder, 375 | # tokenizer=tokenizer, 376 | # unet=unet, 377 | # scheduler=ic_light_scheduler, 378 | # safety_checker=None, 379 | # requires_safety_checker=False, 380 | # feature_extractor=None, 381 | # image_encoder=None 382 | # ) 383 | # ic_light_pipe = ic_light_pipe.to(device=device, dtype=adopted_dtype) 384 | # ic_light_pipe.vae.requires_grad_(False) 385 | # ic_light_pipe.unet.requires_grad_(False) 386 | 387 | ############################# params ###################################### 388 | # strength = config.get("strength", 0.4) 389 | # num_step = config.get("num_step", 25) 390 | # text_guide_scale = config.get("text_guide_scale", 2) 391 | # seed = config.get("seed") 392 | # image_width = config.get("width", 512) 393 | # image_height = config.get("height", 512) 394 | # negative_prompt = config.get("n_prompt", "") 395 | # vdm_prompt = config.get("vdm_prompt", "") 396 | # relight_prompt = config.get("relight_prompt", "") 397 | # video_path = config.get("video_path", "") 398 | # bg_source = BGSource[config.get("bg_source")] 399 | # save_path = config.get("save_path") 400 | # num_frames = config.get("num_frames", 49) 401 | 402 | # ############################## infer ##################################### 403 | # generator = 
torch.manual_seed(seed) 404 | # video_name = os.path.basename(video_path) 405 | # video_list, video_name = read_video(video_path, image_width, image_height) 406 | 407 | # print("################## begin ##################") 408 | # with torch.no_grad(): 409 | # num_inference_steps = int(round(num_step / strength)) 410 | 411 | # output = pipe( 412 | # ic_light_pipe=ic_light_pipe, 413 | # relight_prompt=relight_prompt, 414 | # bg_source=bg_source, 415 | # video=video_list, 416 | # prompt=vdm_prompt, 417 | # negative_prompt=negative_prompt, 418 | # strength=strength, 419 | # guidance_scale=text_guide_scale, 420 | # num_inference_steps=num_inference_steps, 421 | # height=image_height, 422 | # num_frames=num_frames, 423 | # width=image_width, 424 | # generator=generator, 425 | # ) 426 | 427 | # frames = output.frames[0] 428 | # frames = (frames * 255).astype(np.uint8) 429 | # results_path = f"{save_path}/relight_{video_name}" 430 | # imageio.mimwrite(results_path, frames, fps=14) 431 | # print(f"relight! 
# prompt:{relight_prompt}, light:{bg_source.value}, save in {results_path}.") 432 | 433 | # if __name__ == "__main__": 434 | # parser = argparse.ArgumentParser() 435 | 436 | # parser.add_argument("--sd_model", type=str, default="stablediffusionapi/realistic-vision-v51") 437 | # parser.add_argument("--vdm_model", type=str, default="Wan-AI/Wan2.1-T2V-1.3B-Diffusers") 438 | # parser.add_argument("--ic_light_model", type=str, default="./models/iclight_sd15_fc.safetensors") 439 | 440 | # parser.add_argument("--config", type=str, default="configs/wan_relight/man.yaml", help="the config file for each sample.") 441 | 442 | # args = parser.parse_args() 443 | # main(args) 444 | -------------------------------------------------------------------------------- /node_utils.py: --------------------------------------------------------------------------------
# !/usr/bin/env python
# -*- coding: UTF-8 -*-
import os
import torch
from PIL import Image
import numpy as np
import cv2
import gc
import time
from transformers import AutoModelForImageSegmentation
from comfy.utils import common_upscale,ProgressBar
from huggingface_hub import hf_hub_download
import torchvision.transforms as transforms
cur_path = os.path.dirname(os.path.abspath(__file__))
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"


def cv2pil(cv_image):
    """Convert an OpenCV image (BGR channel order) to a PIL image.

    :param cv_image: OpenCV image as an ``np.ndarray`` in BGR order.
    :return: PIL image in RGB order.
    """
    # OpenCV stores colour channels as BGR; PIL expects RGB.
    rgb_image = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb_image)


def image2masks(repo,video_image):
    """Predict a foreground mask for every image with a Hub segmentation model.

    The model is loaded from ``repo`` on every call and run on the module-level
    ``device``. It is moved back to the CPU and released afterwards, including
    when inference raises.

    :param repo: Hugging Face repo id or local path accepted by
        ``AutoModelForImageSegmentation.from_pretrained`` (remote code is trusted).
    :param video_image: iterable of PIL images.
    :return: list of RGB PIL masks, each resized to the size of its input image.
    """
    start_time = time.time()
    model = AutoModelForImageSegmentation.from_pretrained(repo, trust_remote_code=True)
    torch.set_float32_matmul_precision("high")
    # Use the detected device instead of hard-coding "cuda" so CPU/MPS hosts work.
    model.to(device)
    model.eval()
    # Data settings: resize to 1024x1024 and normalize with ImageNet mean/std.
    image_size = (1024, 1024)
    transform_image = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    masks = []
    try:
        with torch.no_grad():
            for img in video_image:
                # convert("RGB") guards against RGBA/L inputs, which the
                # 3-channel Normalize above would reject.
                input_images = transform_image(img.convert("RGB")).unsqueeze(0).to(device)
                # The last element of the model output is used as the mask logits.
                preds = model(input_images)[-1].sigmoid().cpu()
                pred_pil = transforms.ToPILImage()(preds[0].squeeze())
                masks.append(pred_pil.resize(img.size).convert('RGB'))
        load_time = time.time() - start_time
        print(f"image2masks infer time: {load_time:.4f} s")
    finally:
        # Always release accelerator memory, even if inference failed.
        model.to('cpu')
        del model
        gc.collect()
        torch.cuda.empty_cache()
    return masks


def tensor_to_pil(tensor):
    """Convert an image tensor with values in [0, 1] to a PIL RGB image.

    Singleton dimensions are squeezed away; the result must be ``(H, W, 3)``.
    """
    # detach()/cpu() allow grad-tracking and GPU/MPS tensors (no-ops otherwise).
    image_np = tensor.detach().squeeze().mul(255).clamp(0, 255).byte().cpu().numpy()
    return Image.fromarray(image_np, mode='RGB')


def tensor2pil_list(image,width,height):
    """Split a ``(B, H, W, C)`` batch into a list of PIL images resized to ``width`` x ``height``."""
    B, _, _, _ = image.size()  # also enforces a 4-D input
    # torch.chunk with chunks == B yields one (1, H, W, C) slice per image (B == 1 included).
    return [tensor2pil_upscale(img, width, height) for img in torch.chunk(image, chunks=B)]


def tensor_upscale(img_tensor, width, height):
    """Resize a channels-last ``(B, H, W, C)`` tensor with nearest-exact + center crop."""
    samples = img_tensor.movedim(-1, 1)  # common_upscale works on (B, C, H, W)
    img = common_upscale(samples, width, height, "nearest-exact", "center")
    return img.movedim(1, -1)


def tensor2pil_upscale(img_tensor, width, height):
    """Resize a single-image channels-last tensor and convert it to a PIL image."""
    return tensor_to_pil(tensor_upscale(img_tensor, width, height))


def tensor2cv(tensor_image,RGB2BGR=True):
    """Convert an image tensor to a uint8 OpenCV array, rescaling by its maximum.

    :param tensor_image: ``(H, W, C)`` or ``(1, H, W, C)`` tensor.
    :param RGB2BGR: convert the channel order to BGR for OpenCV.
    :return: ``np.uint8`` array.
    """
    if len(tensor_image.shape) == 4:  # bhwc to hwc
        tensor_image = tensor_image.squeeze(0)
    # detach()/cpu() handle grad tensors and any device (not only CUDA).
    array = tensor_image.detach().cpu().numpy()
    # De-normalize so the maximum maps to 255; skip when the max is not positive
    # (e.g. an all-black image) to avoid dividing by zero.
    max_value = array.max()
    if max_value > 0:
        array = array * 255 / max_value
    img_cv2 = np.uint8(np.clip(array, 0, 255))  # float32 to uint8
    if RGB2BGR:
        img_cv2 = cv2.cvtColor(img_cv2, cv2.COLOR_RGB2BGR)
    return img_cv2


def cvargb2tensor(img):
    """Convert an ``(H, W, C)`` RGB ndarray (assumed 0-255) to a ``(1, C, H, W)`` float tensor."""
    assert isinstance(img, np.ndarray), 'the img type is {}, but ndarry expected'.format(type(img))
    img = torch.from_numpy(img.transpose((2, 0, 1)))
    return img.float().div(255).unsqueeze(0)  # 255 could also be changed to 256


def cv2tensor(img):
    """Convert an ``(H, W, 3)`` BGR ndarray (assumed 0-255) to a ``(1, 3, H, W)`` RGB float tensor."""
    assert isinstance(img, np.ndarray), 'the img type is {}, but ndarry expected'.format(type(img))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = torch.from_numpy(img.transpose((2, 0, 1)))
    return img.float().div(255).unsqueeze(0)  # 255 could also be changed to 256


def images_generator(img_list: list,):
    """Yield the target ``(width, height)`` first, then every image as float32 RGB.

    The target size is the most common size in ``img_list``. Images of any other
    size are resized to it (lanczos, "center" crop mode), so every yielded array
    has the same shape, as ``np.fromiter`` in ``load_images`` requires.

    Yields:
        A ``(width, height)`` tuple, then one ``(height, width, 3)`` float32
        ``np.ndarray`` with values in [0, 1] per input image.

    Raises:
        TypeError: if an element is neither a PIL image nor an ``np.ndarray``.
    """
    # Count how often each (width, height) occurs.
    sizes = {}
    for image_ in img_list:
        if isinstance(image_, Image.Image):
            key = image_.size
        elif isinstance(image_, np.ndarray):
            key = image_.shape[:2][::-1]
        else:
            raise TypeError("unsupport image list,must be pil or cv2!!!")
        sizes[key] = sizes.get(key, 0) + 1
    width, height = max(sizes.items(), key=lambda x: x[1])[0]
    yield width, height

    # any to float32 RGB array of the target size
    def load_image(img_in):
        if isinstance(img_in, Image.Image):
            arr = np.array(img_in.convert("RGB"), dtype=np.float32)
        elif isinstance(img_in, np.ndarray):
            arr = cv2.cvtColor(img_in, cv2.COLOR_BGR2RGB).astype(np.float32)
        else:
            raise TypeError("unsupport image list,must be pil,cv2 or tensor!!!")
        i = torch.from_numpy(arr).div_(255)
        if i.shape[0] != height or i.shape[1] != width:
            # common_upscale expects (B, C, H, W); convert there and back.
            i = i.movedim(-1, 0).unsqueeze(0)
            i = common_upscale(i, width, height, "lanczos", "center")
            i = i.squeeze(0).movedim(0, -1)
        return i.numpy()

    total_images = len(img_list)
    pbar = ProgressBar(total_images)
    for processed_images, image in enumerate(map(load_image, img_list), start=1):
        yield image
        pbar.update_absolute(processed_images, total_images)


def load_images(img_list: list,):
    """Load PIL/OpenCV images into one ``(N, H, W, 3)`` float32 tensor in [0, 1].

    Raises:
        FileNotFoundError: if ``img_list`` is empty.
        TypeError: if an element is neither a PIL image nor an ``np.ndarray``.
    """
    if not img_list:
        raise FileNotFoundError("No images could be loaded .")
    gen = images_generator(img_list)
    (width, height) = next(gen)
    return torch.from_numpy(np.fromiter(gen, np.dtype((np.float32, (height, width, 3)))))


def tensor2pil(tensor):
    """Alias of ``tensor_to_pil``: convert an image tensor in [0, 1] to a PIL RGB image."""
    return tensor_to_pil(tensor)


def pil2narry(img):
    """Convert a PIL image to a float32 tensor in [0, 1] with a leading batch dimension."""
    narry = torch.from_numpy(np.array(img).astype(np.float32) / 255.0).unsqueeze(0)
    return narry


def equalize_lists(list1, list2):
    """Repeat the shorter list cyclically so that both lists have the same length.

    The inputs are not modified; the shorter list is replaced by a new list.

    Args:
        list1 (list): first list.
        list2 (list): second list.

    Returns:
        tuple: ``(list1, list2)`` with equal lengths.

    Raises:
        ValueError: if exactly one of the lists is empty (nothing to repeat).
    """
    len1 = len(list1)
    len2 = len(list2)
    if len1 == len2:
        return list1, list2
    if min(len1, len2) == 0:
        raise ValueError("cannot equalize an empty list with a non-empty list")
    if len1 < len2:
        print("list1 is shorter than list2, copying list1 to match list2's length.")
        list1 = (list1 * (len2 // len1 + 1))[:len2]  # repeat, then trim to length
    else:
        print("list2 is shorter than list1, copying list2 to match list1's length.")
        list2 = (list2 * (len1 // len2 + 1))[:len1]  # repeat, then trim to length
    return list1, list2


def file_exists(directory, filename):
    """Return True if ``directory/filename`` exists and is a regular file."""
    return os.path.isfile(os.path.join(directory, filename))


def download_weights(file_dir,repo_id,subfolder="",pt_name=""):
    """Return the local path of ``pt_name``, downloading it from the Hub if missing.

    The file is expected at ``file_dir/subfolder/pt_name`` (or
    ``file_dir/pt_name`` when ``subfolder`` is empty). The target directory is
    created if needed, and the download goes to ``local_dir=file_dir``.
    """
    target_dir = os.path.join(file_dir, subfolder) if subfolder else file_dir
    file_path = os.path.join(target_dir, pt_name)
    os.makedirs(target_dir, exist_ok=True)
    if not os.path.exists(file_path):
        file_path = hf_hub_download(
            repo_id=repo_id,
            subfolder=subfolder or None,
            filename=pt_name,
            local_dir=file_dir,
        )
    return file_path
# -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "light_a_video" 3 | description = "Light-A-Video: Training-free Video Relighting via Progressive Light Fusion,you can use it in comfyUI"
# 4 | version = "1.0.0" 5 | license = {file = "LICENSE"} 6 | dependencies = ["#diffusers==0.32.1", "diffusers", "#transformers==4.47.1", "transformers", "opencv-python", "safetensors", "pillow", "einops", "peft", "imageio", "omegaconf", "#ultralytic", "tqdm==4.67.1", "protobuf==3.20", "torch", "torchvision", "moviepy"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/smthemex/ComfyUI_Light_A_Video" 10 | # Used by Comfy Registry https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "smthemex" 14 | DisplayName = "ComfyUI_Light_A_Video" 15 | Icon = "" 16 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | #diffusers==0.32.1 2 | diffusers 3 | #transformers==4.47.1 4 | transformers 5 | opencv-python 6 | safetensors 7 | pillow 8 | einops 9 | peft 10 | imageio 11 | omegaconf 12 | #ultralytic 13 | tqdm==4.67.1 14 | protobuf==3.20 15 | torch 16 | torchvision 17 | moviepy -------------------------------------------------------------------------------- /sam2.py: --------------------------------------------------------------------------------
import os
import torch
#import argparse
import numpy as np
from PIL import Image
import random
import imageio
import folder_paths


def get_mask(video_list,fps,local_sam, x=255, y=255):
    """Track one object through a clip with SAM2 and return one mask per frame.

    The frames are encoded to a temporary mp4 in the ComfyUI input directory,
    which is removed again once the predictor has run. Each mask is also saved
    there as ``NNN.png``.

    :param video_list: frames convertible with ``np.array`` (e.g. PIL images).
    :param fps: frame rate used when encoding the temporary clip.
    :param local_sam: SAM2 checkpoint passed to the predictor as ``model``.
    :param x: x coordinate of the positive point prompt.
    :param y: y coordinate of the positive point prompt.
    :return: list of RGB PIL mask images, one per predictor result.
    """
    from ultralytics.models.sam import SAM2VideoPredictor #ultralytics>=8.3.0
    overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024, model=local_sam)
    predictor = SAM2VideoPredictor(overrides=overrides)
    input_dir = folder_paths.get_input_directory()
    file_prefix = ''.join(random.choice("0123456789") for _ in range(6))
    frames = [np.array(img) for img in video_list]
    video_file = os.path.join(input_dir, f"audio_{file_prefix}_temp.mp4")
    imageio.mimsave(video_file, frames, fps=fps, codec='libx264')
    try:
        results = predictor(source=video_file,points=[x, y],labels=[1])
    finally:
        # The clip only exists as predictor input; don't leave it in the input dir.
        # Best-effort: e.g. on Windows the file may still be locked by a reader.
        try:
            os.remove(video_file)
        except OSError:
            pass
    mask_list = []
    for i, result in enumerate(results):
        mask = result.masks.data.squeeze().to(torch.float16)
        mask = (mask * 255).cpu().numpy().astype(np.uint8)
        mask_image = Image.fromarray(mask).convert('RGB')
        mask_list.append(mask_image)
        mask_image.save(os.path.join(input_dir, f'{str(i).zfill(3)}.png'))
    return mask_list
# -------------------------------------------------------------------------------- /sd_repo/README.md: -------------------------------------------------------------------------------- 1 | --- 2 | license: creativeml-openrail-m 3 | tags: 4 | - modelslab.com 5 | - stable-diffusion-api 6 | - text-to-image 7 | - ultra-realistic 8 | pinned: true 9 | --- 10 | 11 | # API Inference 12 | 13 | ![generated from modelslab.com](https://cdn2.stablediffusionapi.com/generations/bf190b5a-fe19-437c-ba05-82f29cb1f7ad-0.png) 14 | ## Get API Key 15 | 16 | Get API key from [ModelsLab API](http://modelslab.com), No Payment needed. 17 | 18 | Replace Key in below code, change **model_id** to "realistic-vision-v51" 19 | 20 | Coding in PHP/Node/Java etc?
Have a look at docs for more code examples: [View docs](https://modelslab.com/docs) 21 | 22 | Try model for free: [Generate Images](https://modelslab.com/models/realistic-vision-v51) 23 | 24 | Model link: [View model](https://modelslab.com/models/realistic-vision-v51) 25 | 26 | View all models: [View Models](https://modelslab.com/models) 27 | 28 | import requests 29 | import json 30 | 31 | url = "https://modelslab.com/api/v6/images/text2img" 32 | 33 | payload = json.dumps({ 34 | "key": "your_api_key", 35 | "model_id": "realistic-vision-v51", 36 | "prompt": "ultra realistic close up portrait ((beautiful pale cyberpunk female with heavy black eyeliner)), blue eyes, shaved side haircut, hyper detail, cinematic lighting, magic neon, dark red city, Canon EOS R3, nikon, f/1.4, ISO 200, 1/160s, 8K, RAW, unedited, symmetrical balance, in-frame, 8K", 37 | "negative_prompt": "painting, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, cloned face, skinny, glitchy, double torso, extra arms, extra hands, mangled fingers, missing lips, ugly face, distorted face, extra legs, anime", 38 | "width": "512", 39 | "height": "512", 40 | "samples": "1", 41 | "num_inference_steps": "30", 42 | "safety_checker": "no", 43 | "enhance_prompt": "yes", 44 | "seed": None, 45 | "guidance_scale": 7.5, 46 | "multi_lingual": "no", 47 | "panorama": "no", 48 | "self_attention": "no", 49 | "upscale": "no", 50 | "embeddings": "embeddings_model_id", 51 | "lora": "lora_model_id", 52 | "webhook": None, 53 | "track_id": None 54 | }) 55 | 56 | headers = { 57 | 'Content-Type': 'application/json' 58 | } 59 | 60 | response = requests.request("POST", url, headers=headers, data=payload) 61 | 62 | print(response.text) 63 | 64 | > Use this coupon code to get 25% off **DMGG0RBN** -------------------------------------------------------------------------------- /sd_repo/feature_extractor/preprocessor_config.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "crop_size": { 3 | "height": 224, 4 | "width": 224 5 | }, 6 | "do_center_crop": true, 7 | "do_convert_rgb": true, 8 | "do_normalize": true, 9 | "do_rescale": true, 10 | "do_resize": true, 11 | "feature_extractor_type": "CLIPFeatureExtractor", 12 | "image_mean": [ 13 | 0.48145466, 14 | 0.4578275, 15 | 0.40821073 16 | ], 17 | "image_processor_type": "CLIPFeatureExtractor", 18 | "image_std": [ 19 | 0.26862954, 20 | 0.26130258, 21 | 0.27577711 22 | ], 23 | "resample": 3, 24 | "rescale_factor": 0.00392156862745098, 25 | "size": { 26 | "shortest_edge": 224 27 | }, 28 | "use_square_size": false 29 | } 30 | -------------------------------------------------------------------------------- /sd_repo/model_index.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "StableDiffusionPipeline", 3 | "_diffusers_version": "0.25.0.dev0", 4 | "feature_extractor": [ 5 | "transformers", 6 | "CLIPFeatureExtractor" 7 | ], 8 | "image_encoder": [ 9 | null, 10 | null 11 | ], 12 | "requires_safety_checker": true, 13 | "safety_checker": [ 14 | "stable_diffusion", 15 | "StableDiffusionSafetyChecker" 16 | ], 17 | "scheduler": [ 18 | "diffusers", 19 | "PNDMScheduler" 20 | ], 21 | "text_encoder": [ 22 | "transformers", 23 | "CLIPTextModel" 24 | ], 25 | "tokenizer": [ 26 | "transformers", 27 | "CLIPTokenizer" 28 | ], 29 | "unet": [ 30 | "diffusers", 31 | "UNet2DConditionModel" 32 | ], 33 | "vae": [ 34 | "diffusers", 35 | "AutoencoderKL" 36 | ] 37 | } 38 | -------------------------------------------------------------------------------- /sd_repo/safety_checker/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "CompVis/stable-diffusion-safety-checker", 3 | "architectures": [ 4 | "StableDiffusionSafetyChecker" 5 | ], 6 | "initializer_factor": 1.0, 7 | "logit_scale_init_value": 2.6592, 8 
| "model_type": "clip", 9 | "projection_dim": 768, 10 | "text_config": { 11 | "dropout": 0.0, 12 | "hidden_size": 768, 13 | "intermediate_size": 3072, 14 | "model_type": "clip_text_model", 15 | "num_attention_heads": 12 16 | }, 17 | "torch_dtype": "float16", 18 | "transformers_version": "4.35.2", 19 | "vision_config": { 20 | "dropout": 0.0, 21 | "hidden_size": 1024, 22 | "intermediate_size": 4096, 23 | "model_type": "clip_vision_model", 24 | "num_attention_heads": 16, 25 | "num_hidden_layers": 24, 26 | "patch_size": 14 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /sd_repo/scheduler/scheduler_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "PNDMScheduler", 3 | "_diffusers_version": "0.25.0.dev0", 4 | "beta_end": 0.012, 5 | "beta_schedule": "scaled_linear", 6 | "beta_start": 0.00085, 7 | "clip_sample": false, 8 | "num_train_timesteps": 1000, 9 | "prediction_type": "epsilon", 10 | "set_alpha_to_one": false, 11 | "skip_prk_steps": true, 12 | "steps_offset": 1, 13 | "timestep_spacing": "leading", 14 | "trained_betas": null 15 | } 16 | -------------------------------------------------------------------------------- /sd_repo/text_encoder/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "CLIPTextModel" 4 | ], 5 | "attention_dropout": 0.0, 6 | "bos_token_id": 0, 7 | "dropout": 0.0, 8 | "eos_token_id": 2, 9 | "hidden_act": "quick_gelu", 10 | "hidden_size": 768, 11 | "initializer_factor": 1.0, 12 | "initializer_range": 0.02, 13 | "intermediate_size": 3072, 14 | "layer_norm_eps": 1e-05, 15 | "max_position_embeddings": 77, 16 | "model_type": "clip_text_model", 17 | "num_attention_heads": 12, 18 | "num_hidden_layers": 12, 19 | "pad_token_id": 1, 20 | "projection_dim": 768, 21 | "torch_dtype": "float16", 22 | "transformers_version": "4.35.2", 23 | "vocab_size": 49408 24 | } 25 | 
-------------------------------------------------------------------------------- /sd_repo/tokenizer/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "bos_token": { 3 | "content": "<|startoftext|>", 4 | "lstrip": false, 5 | "normalized": true, 6 | "rstrip": false, 7 | "single_word": false 8 | }, 9 | "eos_token": { 10 | "content": "<|endoftext|>", 11 | "lstrip": false, 12 | "normalized": true, 13 | "rstrip": false, 14 | "single_word": false 15 | }, 16 | "pad_token": { 17 | "content": "<|endoftext|>", 18 | "lstrip": false, 19 | "normalized": false, 20 | "rstrip": false, 21 | "single_word": false 22 | }, 23 | "unk_token": { 24 | "content": "<|endoftext|>", 25 | "lstrip": false, 26 | "normalized": true, 27 | "rstrip": false, 28 | "single_word": false 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /sd_repo/tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_prefix_space": false, 3 | "added_tokens_decoder": { 4 | "49406": { 5 | "content": "<|startoftext|>", 6 | "lstrip": false, 7 | "normalized": true, 8 | "rstrip": false, 9 | "single_word": false, 10 | "special": true 11 | }, 12 | "49407": { 13 | "content": "<|endoftext|>", 14 | "lstrip": false, 15 | "normalized": true, 16 | "rstrip": false, 17 | "single_word": false, 18 | "special": true 19 | } 20 | }, 21 | "bos_token": "<|startoftext|>", 22 | "clean_up_tokenization_spaces": true, 23 | "do_lower_case": true, 24 | "eos_token": "<|endoftext|>", 25 | "errors": "replace", 26 | "model_max_length": 77, 27 | "pad_token": "<|endoftext|>", 28 | "tokenizer_class": "CLIPTokenizer", 29 | "unk_token": "<|endoftext|>" 30 | } 31 | -------------------------------------------------------------------------------- /sd_repo/unet/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": 
"UNet2DConditionModel", 3 | "_diffusers_version": "0.25.0.dev0", 4 | "act_fn": "silu", 5 | "addition_embed_type": null, 6 | "addition_embed_type_num_heads": 64, 7 | "addition_time_embed_dim": null, 8 | "attention_head_dim": 8, 9 | "attention_type": "default", 10 | "block_out_channels": [ 11 | 320, 12 | 640, 13 | 1280, 14 | 1280 15 | ], 16 | "center_input_sample": false, 17 | "class_embed_type": null, 18 | "class_embeddings_concat": false, 19 | "conv_in_kernel": 3, 20 | "conv_out_kernel": 3, 21 | "cross_attention_dim": 768, 22 | "cross_attention_norm": null, 23 | "down_block_types": [ 24 | "CrossAttnDownBlock2D", 25 | "CrossAttnDownBlock2D", 26 | "CrossAttnDownBlock2D", 27 | "DownBlock2D" 28 | ], 29 | "downsample_padding": 1, 30 | "dropout": 0.0, 31 | "dual_cross_attention": false, 32 | "encoder_hid_dim": null, 33 | "encoder_hid_dim_type": null, 34 | "flip_sin_to_cos": true, 35 | "freq_shift": 0, 36 | "in_channels": 4, 37 | "layers_per_block": 2, 38 | "mid_block_only_cross_attention": null, 39 | "mid_block_scale_factor": 1, 40 | "mid_block_type": "UNetMidBlock2DCrossAttn", 41 | "norm_eps": 1e-05, 42 | "norm_num_groups": 32, 43 | "num_attention_heads": null, 44 | "num_class_embeds": null, 45 | "only_cross_attention": false, 46 | "out_channels": 4, 47 | "projection_class_embeddings_input_dim": null, 48 | "resnet_out_scale_factor": 1.0, 49 | "resnet_skip_time_act": false, 50 | "resnet_time_scale_shift": "default", 51 | "reverse_transformer_layers_per_block": null, 52 | "sample_size": 64, 53 | "time_cond_proj_dim": null, 54 | "time_embedding_act_fn": null, 55 | "time_embedding_dim": null, 56 | "time_embedding_type": "positional", 57 | "timestep_post_act": null, 58 | "transformer_layers_per_block": 1, 59 | "up_block_types": [ 60 | "UpBlock2D", 61 | "CrossAttnUpBlock2D", 62 | "CrossAttnUpBlock2D", 63 | "CrossAttnUpBlock2D" 64 | ], 65 | "upcast_attention": false, 66 | "use_linear_projection": false 67 | } 68 | 
-------------------------------------------------------------------------------- /sd_repo/v1-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 10000 ] 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. ] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 71 | 
# -------------------------------------------------------------------------------- /sd_repo/vae/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "AutoencoderKL", 3 | "_diffusers_version": "0.25.0.dev0", 4 | "act_fn": "silu", 5 | "block_out_channels": [ 6 | 128, 7 | 256, 8 | 512, 9 | 512 10 | ], 11 | "down_block_types": [ 12 | "DownEncoderBlock2D", 13 | "DownEncoderBlock2D", 14 | "DownEncoderBlock2D", 15 | "DownEncoderBlock2D" 16 | ], 17 | "force_upcast": true, 18 | "in_channels": 3, 19 | "latent_channels": 4, 20 | "layers_per_block": 2, 21 | "norm_num_groups": 32, 22 | "out_channels": 3, 23 | "sample_size": 512, 24 | "scaling_factor": 0.18215, 25 | "up_block_types": [ 26 | "UpDecoderBlock2D", 27 | "UpDecoderBlock2D", 28 | "UpDecoderBlock2D", 29 | "UpDecoderBlock2D" 30 | ] 31 | } 32 | -------------------------------------------------------------------------------- /src/animatediff_eul.py: --------------------------------------------------------------------------------
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
from diffusers.utils import (
    USE_PEFT_BACKEND,
    BaseOutput,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor


@dataclass
class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
    # Output of `eul_step`; declared as a dataclass like every diffusers BaseOutput.
    # prev_sample: sample to feed into the next denoising step.
    prev_sample: torch.FloatTensor
    # pred_original_sample: the x_0 target used for this step (the fused latent).
    pred_original_sample: Optional[torch.FloatTensor] = None


def eul_step(
    self,
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    fusion_latent,
    pipe,
    generator: Optional[torch.Generator] = None,
    return_dict: bool = True,
) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
    """Perform one Euler-ancestral step that denoises towards ``fusion_latent``.

    The update follows diffusers' ``EulerAncestralDiscreteScheduler.step``,
    except that the x_0 estimate is the externally supplied ``fusion_latent``
    rather than one derived from ``model_output``. The function is meant to be
    bound to a scheduler instance: it reads ``self.sigmas``, ``self.config`` and
    ``self.step_index``, and advances ``self._step_index``.

    Args:
        model_output: model prediction. Only its shape, dtype and device are
            used (for noise sampling and the output dtype).
        timestep: one of ``scheduler.timesteps``; integer indices are rejected.
        sample: current noisy latent.
        fusion_latent: x_0 target that replaces the scheduler's own estimate.
        pipe: unused; kept for call-site compatibility.
        generator: optional RNG for the ancestral noise.
        return_dict: return an output object instead of a 1-tuple.

    Returns:
        ``EulerAncestralDiscreteSchedulerOutput`` or ``(prev_sample,)``.

    Raises:
        ValueError: for integer timesteps or an unknown ``prediction_type``.
        NotImplementedError: for ``prediction_type == "sample"``.
    """
    if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
        raise ValueError(
            (
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep."
            ),
        )

    if self.step_index is None:
        self._init_step_index(timestep)

    sigma = self.sigmas[self.step_index]

    # Upcast to avoid precision issues when computing prev_sample
    sample = sample.to(torch.float32)

    # 1. The scheduler's own x_0 estimate would be replaced by `fusion_latent`
    #    anyway, so only validate the prediction type instead of computing an
    #    unused tensor.
    prediction_type = self.config.prediction_type
    if prediction_type == "sample":
        raise NotImplementedError("prediction_type not implemented yet: sample")
    if prediction_type not in ("epsilon", "v_prediction"):
        raise ValueError(
            f"prediction_type given as {prediction_type} must be one of `epsilon`, or `v_prediction`"
        )

    ## fusion latent
    pred_original_sample = fusion_latent

    # Split the step into a deterministic part (sigma_down) and fresh noise (sigma_up).
    sigma_from = self.sigmas[self.step_index]
    sigma_to = self.sigmas[self.step_index + 1]
    sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5

    # 2. Convert to an ODE derivative
    derivative = (sample - pred_original_sample) / sigma
    dt = sigma_down - sigma

    prev_sample = sample + derivative * dt

    device = model_output.device
    noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)

    prev_sample = prev_sample + noise * sigma_up

    # Cast sample back to model compatible dtype
    prev_sample = prev_sample.to(model_output.dtype)

    # upon completion increase step index by one
    self._step_index += 1

    if not return_dict:
        return (prev_sample,)

    return EulerAncestralDiscreteSchedulerOutput(
        prev_sample=prev_sample, pred_original_sample=pred_original_sample
    )
# -------------------------------------------------------------------------------- /src/cogvideo_ddim.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. 2 | # All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | 16 | # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion 17 | # and https://github.com/hojonathanho/diffusion 18 | 19 | import math 20 | from dataclasses import dataclass 21 | from typing import List, Optional, Tuple, Union 22 | 23 | import numpy as np 24 | import torch 25 | 26 | from diffusers.configuration_utils import ConfigMixin, register_to_config 27 | from diffusers.utils import BaseOutput 28 | from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin 29 | 30 | 31 | @dataclass 32 | # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM 33 | class DDIMSchedulerOutput(BaseOutput): 34 | """ 35 | Output class for the scheduler's `step` function output. 36 | 37 | Args: 38 | prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): 39 | Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the 40 | denoising loop. 41 | pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): 42 | The predicted denoised sample `(x_{0})` based on the model output from the current timestep. 43 | `pred_original_sample` can be used to preview progress or for guidance. 44 | """ 45 | 46 | prev_sample: torch.Tensor 47 | pred_original_sample: Optional[torch.Tensor] = None 48 | 49 | 50 | # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar 51 | def cog_ddim_step( 52 | self, 53 | model_output: torch.Tensor, 54 | timestep: int, 55 | sample: torch.Tensor, 56 | fusion_target, 57 | eta: float = 0.0, 58 | use_clipped_model_output: bool = False, 59 | generator=None, 60 | variance_noise: Optional[torch.Tensor] = None, 61 | return_dict: bool = True, 62 | ) -> Union[DDIMSchedulerOutput, Tuple]: 63 | """ 64 | Predict the sample from the previous timestep by reversing the SDE. 
This function propagates the diffusion 65 | process from the learned model outputs (most often the predicted noise). 66 | 67 | Args: 68 | model_output (`torch.Tensor`): 69 | The direct output from learned diffusion model. 70 | timestep (`float`): 71 | The current discrete timestep in the diffusion chain. 72 | sample (`torch.Tensor`): 73 | A current instance of a sample created by the diffusion process. 74 | eta (`float`): 75 | The weight of noise for added noise in diffusion step. 76 | use_clipped_model_output (`bool`, defaults to `False`): 77 | If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary 78 | because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no 79 | clipping has happened, "corrected" `model_output` would coincide with the one provided as input and 80 | `use_clipped_model_output` has no effect. 81 | generator (`torch.Generator`, *optional*): 82 | A random number generator. 83 | variance_noise (`torch.Tensor`): 84 | Alternative to generating noise with `generator` by directly providing the noise for the variance 85 | itself. Useful for methods such as [`CycleDiffusion`]. 86 | return_dict (`bool`, *optional*, defaults to `True`): 87 | Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`. 88 | 89 | Returns: 90 | [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`: 91 | If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a 92 | tuple is returned where the first element is the sample tensor. 
93 | 94 | """ 95 | if self.num_inference_steps is None: 96 | raise ValueError( 97 | "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" 98 | ) 99 | 100 | # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf 101 | # Ideally, read DDIM paper in-detail understanding 102 | 103 | # Notation ( -> 104 | # - pred_noise_t -> e_theta(x_t, t) 105 | # - pred_original_sample -> f_theta(x_t, t) or x_0 106 | # - std_dev_t -> sigma_t 107 | # - eta -> η 108 | # - pred_sample_direction -> "direction pointing to x_t" 109 | # - pred_prev_sample -> "x_t-1" 110 | 111 | # 1. get previous step value (=t-1) 112 | prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps 113 | 114 | # 2. compute alphas, betas 115 | alpha_prod_t = self.alphas_cumprod[timestep] 116 | alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod 117 | 118 | beta_prod_t = 1 - alpha_prod_t 119 | 120 | # 3. 
compute predicted original sample from predicted noise also called 121 | # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 122 | # To make style tests pass, commented out `pred_epsilon` as it is an unused variable 123 | if self.config.prediction_type == "epsilon": 124 | pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) 125 | # pred_epsilon = model_output 126 | elif self.config.prediction_type == "sample": 127 | pred_original_sample = model_output 128 | # pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) 129 | elif self.config.prediction_type == "v_prediction": 130 | pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output 131 | # pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample 132 | else: 133 | raise ValueError( 134 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" 135 | " `v_prediction`" 136 | ) 137 | 138 | ## insert fusion target 139 | pred_original_sample = fusion_target 140 | 141 | a_t = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 142 | b_t = alpha_prod_t_prev**0.5 - alpha_prod_t**0.5 * a_t 143 | 144 | prev_sample = a_t * sample + b_t * pred_original_sample 145 | 146 | if not return_dict: 147 | return ( 148 | prev_sample, 149 | pred_original_sample, 150 | ) 151 | 152 | return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) 153 | -------------------------------------------------------------------------------- /src/ic_light.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from enum import Enum 4 | import math 5 | 6 | import torch.nn.functional as F 7 | from .tools import resize_and_center_crop, numpy2pytorch, pad, decode_latents, encode_video 8 | 9 | class BGSource(Enum): 10 | NONE = "None" 11 | LEFT = "Left 
Light" 12 | RIGHT = "Right Light" 13 | TOP = "Top Light" 14 | BOTTOM = "Bottom Light" 15 | 16 | class Relighter: 17 | def __init__(self, 18 | pipeline, 19 | relight_prompt="", 20 | num_frames=16, 21 | image_width=512, 22 | image_height=512, 23 | num_samples=1, 24 | steps=15, 25 | cfg=2, 26 | lowres_denoise=0.9, 27 | bg_source=BGSource.RIGHT, 28 | generator=None, 29 | ): 30 | 31 | self.pipeline = pipeline 32 | self.image_width = image_width 33 | self.image_height = image_height 34 | self.num_samples = num_samples 35 | self.steps = steps 36 | self.cfg = cfg 37 | self.lowres_denoise = lowres_denoise 38 | self.bg_source = bg_source 39 | self.generator = generator 40 | self.device = pipeline.device 41 | self.num_frames = num_frames 42 | self.vae = self.pipeline.vae 43 | 44 | self.a_prompt = "best quality" 45 | self.n_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" 46 | positive_prompt = relight_prompt + ', ' + self.a_prompt 47 | negative_prompt = self.n_prompt 48 | tokenizer = self.pipeline.tokenizer 49 | device = self.pipeline.device 50 | vae = self.vae 51 | 52 | conds, unconds = self.encode_prompt_pair(tokenizer, device, positive_prompt, negative_prompt) 53 | input_bg = self.create_background() 54 | # bg = resize_and_center_crop(input_bg, self.image_width, self.image_height) 55 | # bg_latent = numpy2pytorch([bg], device, vae.dtype) 56 | # bg_latent = vae.encode(bg_latent).latent_dist.mode() * vae.config.scaling_factor 57 | if self.bg_source == BGSource.NONE: 58 | shape = (1, 4, self.image_width//8, self.image_height//8) 59 | bg_latent = torch.randn(shape, generator=generator, device=device, dtype=vae.dtype) 60 | else: 61 | bg = resize_and_center_crop(input_bg, self.image_width, self.image_height) 62 | bg_latent = numpy2pytorch([bg], device, vae.dtype) 63 | bg_latent = vae.encode(bg_latent).latent_dist.mode() * vae.config.scaling_factor 64 | self.bg_latent = bg_latent.repeat(self.num_frames, 1, 1, 1) ## 固定光源 65 | self.conds = 
conds.repeat(self.num_frames, 1, 1) 66 | self.unconds = unconds.repeat(self.num_frames, 1, 1) 67 | 68 | def encode_prompt_inner(self, tokenizer, txt): 69 | max_length = tokenizer.model_max_length 70 | chunk_length = tokenizer.model_max_length - 2 71 | id_start = tokenizer.bos_token_id 72 | id_end = tokenizer.eos_token_id 73 | id_pad = id_end 74 | 75 | tokens = tokenizer(txt, truncation=False, add_special_tokens=False)["input_ids"] 76 | chunks = [[id_start] + tokens[i: i + chunk_length] + [id_end] for i in range(0, len(tokens), chunk_length)] 77 | chunks = [pad(ck, id_pad, max_length) for ck in chunks] 78 | 79 | token_ids = torch.tensor(chunks).to(device=self.device, dtype=torch.int64) 80 | conds = self.pipeline.text_encoder(token_ids).last_hidden_state 81 | return conds 82 | 83 | def encode_prompt_pair(self, tokenizer, device, positive_prompt, negative_prompt): 84 | c = self.encode_prompt_inner(tokenizer, positive_prompt) 85 | uc = self.encode_prompt_inner(tokenizer, negative_prompt) 86 | 87 | c_len = float(len(c)) 88 | uc_len = float(len(uc)) 89 | max_count = max(c_len, uc_len) 90 | c_repeat = int(math.ceil(max_count / c_len)) 91 | uc_repeat = int(math.ceil(max_count / uc_len)) 92 | max_chunk = max(len(c), len(uc)) 93 | 94 | c = torch.cat([c] * c_repeat, dim=0)[:max_chunk] 95 | uc = torch.cat([uc] * uc_repeat, dim=0)[:max_chunk] 96 | 97 | c = torch.cat([p[None, ...] for p in c], dim=1) 98 | uc = torch.cat([p[None, ...] 
for p in uc], dim=1) 99 | 100 | return c.to(device), uc.to(device) 101 | 102 | def create_background(self): 103 | 104 | max_pix = 255 105 | min_pix = 0 106 | 107 | print(f"max light pix:{max_pix}, min light pix:{min_pix}") 108 | 109 | if self.bg_source == BGSource.NONE: 110 | return None 111 | elif self.bg_source == BGSource.LEFT: 112 | gradient = np.linspace(max_pix, min_pix, self.image_width) 113 | image = np.tile(gradient, (self.image_height, 1)) 114 | return np.stack((image,) * 3, axis=-1).astype(np.uint8) 115 | elif self.bg_source == BGSource.RIGHT: 116 | gradient = np.linspace(min_pix, max_pix, self.image_width) 117 | image = np.tile(gradient, (self.image_height, 1)) 118 | return np.stack((image,) * 3, axis=-1).astype(np.uint8) 119 | elif self.bg_source == BGSource.TOP: 120 | gradient = np.linspace(max_pix, min_pix, self.image_height)[:, None] 121 | image = np.tile(gradient, (1, self.image_width)) 122 | return np.stack((image,) * 3, axis=-1).astype(np.uint8) 123 | elif self.bg_source == BGSource.BOTTOM: 124 | gradient = np.linspace(min_pix, max_pix, self.image_height)[:, None] 125 | image = np.tile(gradient, (1, self.image_width)) 126 | return np.stack((image,) * 3, axis=-1).astype(np.uint8) 127 | else: 128 | raise ValueError('Wrong initial latent!') 129 | 130 | @torch.no_grad() 131 | def __call__(self, input_video, init_latent=None, input_strength=None): 132 | input_latent = encode_video(self.vae, input_video)* self.vae.config.scaling_factor 133 | 134 | if input_strength: 135 | light_strength = input_strength 136 | else: 137 | light_strength = self.lowres_denoise 138 | 139 | if not init_latent: 140 | init_latent = self.bg_latent 141 | 142 | latents = self.pipeline( 143 | image=init_latent, 144 | strength=light_strength, 145 | prompt_embeds=self.conds, 146 | negative_prompt_embeds=self.unconds, 147 | width=self.image_width, 148 | height=self.image_height, 149 | num_inference_steps=int(round(self.steps / self.lowres_denoise)), 150 | 
num_images_per_prompt=self.num_samples, 151 | generator=self.generator, 152 | output_type='latent', 153 | guidance_scale=self.cfg, 154 | cross_attention_kwargs={'concat_conds': input_latent}, 155 | ).images.to(self.pipeline.vae.dtype) 156 | 157 | relight_video = decode_latents(self.vae, latents) 158 | return relight_video -------------------------------------------------------------------------------- /src/tools.py: -------------------------------------------------------------------------------- 1 | from PIL import Image,ImageSequence 2 | import numpy as np 3 | import torch 4 | try: 5 | from moviepy import VideoFileClip 6 | except: 7 | from moviepy.editor import VideoFileClip 8 | 9 | import os 10 | import imageio 11 | import random 12 | from diffusers.utils import export_to_video 13 | 14 | def resize_and_center_crop(image, target_width, target_height): 15 | pil_image = Image.fromarray(image) 16 | original_width, original_height = pil_image.size 17 | scale_factor = max(target_width / original_width, target_height / original_height) 18 | resized_width = int(round(original_width * scale_factor)) 19 | resized_height = int(round(original_height * scale_factor)) 20 | resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS) 21 | 22 | left = (resized_width - target_width) / 2 23 | top = (resized_height - target_height) / 2 24 | right = (resized_width + target_width) / 2 25 | bottom = (resized_height + target_height) / 2 26 | cropped_image = resized_image.crop((left, top, right, bottom)) 27 | return np.array(cropped_image) 28 | 29 | def numpy2pytorch(imgs, device, dtype): 30 | h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0 31 | h = h.movedim(-1, 1) 32 | return h.to(device=device, dtype=dtype) 33 | 34 | def get_fg_video(video_list, mask_list, device, dtype): 35 | video_np = np.stack(video_list, axis=0) 36 | mask_np = np.stack(mask_list, axis=0) 37 | mask_bool = mask_np == 255 38 | video_fg = np.where(mask_bool, video_np, 127) 39 | 
40 | h = torch.from_numpy(video_fg).float() / 127.0 - 1.0 41 | h = h.movedim(-1, 1) 42 | return h.to(device=device, dtype=dtype) 43 | 44 | 45 | def pad(x, p, i): 46 | return x[:i] if len(x) >= i else x + [p] * (i - len(x)) 47 | 48 | def gif_to_mp4(gif_path, mp4_path): 49 | clip = VideoFileClip(gif_path) 50 | clip.write_videofile(mp4_path) 51 | 52 | def generate_light_sequence(light_tensor, num_frames=16, direction="r"): 53 | 54 | if direction in "l": 55 | target_tensor = torch.rot90(light_tensor, k=1, dims=(2, 3)) 56 | elif direction in "r": 57 | target_tensor = torch.rot90(light_tensor, k=-1, dims=(2, 3)) 58 | else: 59 | raise ValueError("direction must be either 'r' for right or 'l' for left") 60 | 61 | # Generate the sequence 62 | out_list = [] 63 | for frame_idx in range(num_frames): 64 | t = frame_idx / (num_frames - 1) 65 | interpolated_matrix = (1 - t) * light_tensor + t * target_tensor 66 | out_list.append(interpolated_matrix) 67 | 68 | out_tensor = torch.stack(out_list, dim=0).squeeze(1) 69 | 70 | return out_tensor 71 | 72 | def tensor2vid(video: torch.Tensor, processor, output_type="np"): 73 | 74 | batch_size, channels, num_frames, height, width = video.shape ## [1, 4, 16, 512, 512] 75 | outputs = [] 76 | for batch_idx in range(batch_size): 77 | batch_vid = video[batch_idx].permute(1, 0, 2, 3) 78 | batch_output = processor.postprocess(batch_vid, output_type) 79 | 80 | outputs.append(batch_output) 81 | 82 | return outputs 83 | 84 | def read_video(video_path:str, image_width, image_height): 85 | extension = video_path.split('.')[-1].lower() 86 | video_name = os.path.basename(video_path) 87 | video_list = [] 88 | 89 | if extension in "gif": 90 | ## input from gif 91 | video = Image.open(video_path) 92 | for i, frame in enumerate(ImageSequence.Iterator(video)): 93 | frame = np.array(frame.convert("RGB")) 94 | frame = resize_and_center_crop(frame, image_width, image_height) 95 | video_list.append(frame) 96 | elif extension in "mp4": 97 | ## input from mp4 98 | 
reader = imageio.get_reader(video_path) 99 | for frame in reader: 100 | frame = resize_and_center_crop(frame, image_width, image_height) 101 | video_list.append(frame) 102 | else: 103 | raise ValueError('Wrong input type') 104 | 105 | video_list = [Image.fromarray(frame) for frame in video_list] 106 | 107 | return video_list, video_name 108 | 109 | def read_mask(mask_folder:str): 110 | mask_files = os.listdir(mask_folder) 111 | mask_files = sorted(mask_files) 112 | mask_list = [] 113 | for mask_file in mask_files: 114 | mask_path = os.path.join(mask_folder, mask_file) 115 | mask = Image.open(mask_path).convert('RGB') 116 | mask_list.append(mask) 117 | 118 | return mask_list 119 | 120 | def decode_latents(vae, latents, decode_chunk_size: int = 16): 121 | 122 | latents = 1 / vae.config.scaling_factor * latents 123 | video = [] 124 | for i in range(0, latents.shape[0], decode_chunk_size): 125 | batch_latents = latents[i : i + decode_chunk_size] 126 | batch_latents = vae.decode(batch_latents).sample 127 | video.append(batch_latents) 128 | 129 | video = torch.cat(video) 130 | 131 | return video 132 | 133 | def encode_video(vae, video, decode_chunk_size: int = 16) -> torch.Tensor: 134 | latents = [] 135 | for i in range(0, len(video), decode_chunk_size): 136 | batch_video = video[i : i + decode_chunk_size] 137 | batch_video = vae.encode(batch_video).latent_dist.mode() 138 | latents.append(batch_video) 139 | return torch.cat(latents) 140 | 141 | def vis_video(input_video, video_processor, save_path): 142 | ## shape: 1, c, f, h, w 143 | relight_video = video_processor.postprocess_video(video=input_video, output_type="pil") 144 | export_to_video(relight_video[0], save_path) 145 | 146 | def set_all_seed(seed): 147 | torch.manual_seed(seed) 148 | torch.cuda.manual_seed(seed) 149 | torch.cuda.manual_seed_all(seed) 150 | np.random.seed(seed) 151 | random.seed(seed) 152 | torch.backends.cudnn.deterministic = True 
--------------------------------------------------------------------------------
/src/wan_flowmatch.py:
--------------------------------------------------------------------------------
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, is_scipy_available, logging
from diffusers.schedulers.scheduling_utils import SchedulerMixin

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    prev_sample: torch.FloatTensor

def wan_flowmatch_step(
    self,
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    fusion_sample: torch.FloatTensor,
    s_churn: float = 0.0,
    s_tmin: float = 0.0,
    s_tmax: float = float("inf"),
    s_noise: float = 1.0,
    generator: Optional[torch.Generator] = None,
    return_dict: bool = True,
) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
    """
    Take one Euler step of a flow-matching sampler toward `fusion_sample` instead of the model's prediction.

    Written as a free function that takes the scheduler as `self`. It reads
    `self.sigmas` and `self.step_index`, calls `self._init_step_index` on first
    use, and increments `self._step_index`. The velocity is
    `(sample - fusion_sample) / sigma`, so the update reduces to
    `prev = fusion_sample + (sigma_next / sigma) * (sample - fusion_sample)`,
    i.e. it moves linearly from `sample` toward `fusion_sample`.

    Args:
        model_output (`torch.FloatTensor`):
            The direct output from learned diffusion model. In this variant it is used only to
            choose the output dtype; its values do not enter the update.
        timestep (`float`):
            The current discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            A current instance of a sample created by the diffusion process (upcast to float32 for the update).
        fusion_sample (`torch.FloatTensor`):
            Target clean sample that replaces the model's x_0 estimate.
            Presumably broadcast-compatible with `sample` (TODO confirm).
        s_churn (`float`):
        s_tmin (`float`):
        s_tmax (`float`):
        s_noise (`float`, defaults to 1.0):
            Accepted for signature compatibility; unused in this function.
        generator (`torch.Generator`, *optional*):
            Accepted for signature compatibility; unused (the step is deterministic).
        return_dict (`bool`):
            Whether or not to return a
            [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple.

    Returns:
        [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
            If return_dict is `True`,
            [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] is returned,
            otherwise a tuple is returned where the first element is the sample tensor.

    Raises:
        ValueError: if `timestep` is an int, IntTensor or LongTensor (an index rather than a schedule value).
    """

    if (
        isinstance(timestep, int)
        or isinstance(timestep, torch.IntTensor)
        or isinstance(timestep, torch.LongTensor)
    ):
        raise ValueError(
            (
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `FlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep."
            ),
        )

    # Lazily locate the position of `timestep` in the schedule on the first call.
    if self.step_index is None:
        self._init_step_index(timestep)

    # Upcast to avoid precision issues when computing prev_sample
    sample = sample.to(torch.float32)

    sigma = self.sigmas[self.step_index]
    sigma_next = self.sigmas[self.step_index + 1]

    ## new direction: velocity implied by treating `fusion_sample` as the clean target,
    ## used in place of `model_output`.
    fusion_vector = (sample - fusion_sample) / sigma
    prev_sample = sample + (sigma_next - sigma) * fusion_vector

    # Cast sample back to model compatible dtype
    prev_sample = prev_sample.to(model_output.dtype)

    # upon completion increase step index by one
    self._step_index += 1

    if not return_dict:
        return (prev_sample,)

    return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)

--------------------------------------------------------------------------------
/src/wan_pipe.py:
--------------------------------------------------------------------------------
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 | 15 | import html 16 | from typing import Callable, Dict, List, Optional, Union 17 | 18 | import ftfy 19 | import regex as re 20 | import torch 21 | from transformers import AutoTokenizer, UMT5EncoderModel 22 | 23 | from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback 24 | from diffusers.models import AutoencoderKLWan, WanTransformer3DModel 25 | from diffusers.schedulers import FlowMatchEulerDiscreteScheduler 26 | from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring 27 | from diffusers.utils.torch_utils import randn_tensor 28 | from diffusers.video_processor import VideoProcessor 29 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 30 | 31 | import PIL 32 | import numpy as np 33 | from diffusers.utils import BaseOutput 34 | from dataclasses import dataclass 35 | from einops import rearrange 36 | from .wan_flowmatch import wan_flowmatch_step 37 | from diffusers.utils import export_to_gif 38 | 39 | PipelineImageInput = Union[ 40 | PIL.Image.Image, 41 | np.ndarray, 42 | torch.Tensor, 43 | List[PIL.Image.Image], 44 | List[np.ndarray], 45 | List[torch.Tensor], 46 | ] 47 | 48 | @dataclass 49 | class WanPipelineOutput(BaseOutput): 50 | r""" 51 | Output class for Wan pipelines. 52 | 53 | Args: 54 | frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): 55 | List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing 56 | denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape 57 | `(batch_size, num_frames, channels, height, width)`. 
58 | """ 59 | 60 | frames: torch.Tensor 61 | 62 | XLA_AVAILABLE = False 63 | 64 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 65 | 66 | 67 | EXAMPLE_DOC_STRING = """ 68 | Examples: 69 | ```python 70 | >>> import torch 71 | >>> from diffusers import AutoencoderKLWan, WanPipeline 72 | >>> from diffusers.utils import export_to_video 73 | 74 | >>> # Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers 75 | >>> model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers" 76 | >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) 77 | >>> pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) 78 | >>> pipe.to("cuda") 79 | 80 | >>> prompt = "A cat walks on the grass, realistic" 81 | >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" 82 | 83 | >>> output = pipe( 84 | ... prompt=prompt, 85 | ... negative_prompt=negative_prompt, 86 | ... height=480, 87 | ... width=832, 88 | ... num_frames=81, 89 | ... guidance_scale=5.0, 90 | ... 
).frames[0] 91 | >>> export_to_video(output, "output.mp4", fps=15) 92 | ``` 93 | """ 94 | 95 | import inspect 96 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps 97 | def retrieve_timesteps( 98 | scheduler, 99 | num_inference_steps: Optional[int] = None, 100 | device: Optional[Union[str, torch.device]] = None, 101 | timesteps: Optional[List[int]] = None, 102 | sigmas: Optional[List[float]] = None, 103 | **kwargs, 104 | ): 105 | r""" 106 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles 107 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 108 | 109 | Args: 110 | scheduler (`SchedulerMixin`): 111 | The scheduler to get timesteps from. 112 | num_inference_steps (`int`): 113 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` 114 | must be `None`. 115 | device (`str` or `torch.device`, *optional*): 116 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 117 | timesteps (`List[int]`, *optional*): 118 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, 119 | `num_inference_steps` and `sigmas` must be `None`. 120 | sigmas (`List[float]`, *optional*): 121 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, 122 | `num_inference_steps` and `timesteps` must be `None`. 123 | 124 | Returns: 125 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the 126 | second element is the number of inference steps. 127 | """ 128 | 129 | if timesteps is not None and sigmas is not None: 130 | raise ValueError("Only one of `timesteps` or `sigmas` can be passed. 
Please choose one to set custom values") 131 | if timesteps is not None: 132 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 133 | if not accepts_timesteps: 134 | raise ValueError( 135 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 136 | f" timestep schedules. Please check whether you are using the correct scheduler." 137 | ) 138 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) 139 | timesteps = scheduler.timesteps 140 | num_inference_steps = len(timesteps) 141 | elif sigmas is not None: 142 | accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 143 | if not accept_sigmas: 144 | raise ValueError( 145 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 146 | f" sigmas schedules. Please check whether you are using the correct scheduler." 147 | ) 148 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) 149 | timesteps = scheduler.timesteps 150 | num_inference_steps = len(timesteps) 151 | else: 152 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) 153 | timesteps = scheduler.timesteps 154 | return timesteps, num_inference_steps 155 | 156 | def basic_clean(text): 157 | text = ftfy.fix_text(text) 158 | text = html.unescape(html.unescape(text)) 159 | return text.strip() 160 | 161 | 162 | def whitespace_clean(text): 163 | text = re.sub(r"\s+", " ", text) 164 | text = text.strip() 165 | return text 166 | 167 | def retrieve_latents( 168 | encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" 169 | ): 170 | if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": 171 | return encoder_output.latent_dist.sample(generator) 172 | elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": 173 | return encoder_output.latent_dist.mode() 174 | elif 
hasattr(encoder_output, "latents"): 175 | return encoder_output.latents 176 | else: 177 | raise AttributeError("Could not access latents of provided encoder_output") 178 | 179 | 180 | def prompt_clean(text): 181 | text = whitespace_clean(basic_clean(text)) 182 | return text 183 | 184 | 185 | class WanVideoToVideoPipeline(DiffusionPipeline): 186 | r""" 187 | Pipeline for text-to-video generation using Wan. 188 | 189 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods 190 | implemented for all pipelines (downloading, saving, running on a particular device, etc.). 191 | 192 | Args: 193 | tokenizer ([`T5Tokenizer`]): 194 | Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), 195 | specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. 196 | text_encoder ([`T5EncoderModel`]): 197 | [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically 198 | the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. 199 | transformer ([`WanTransformer3DModel`]): 200 | Conditional Transformer to denoise the input latents. 201 | scheduler ([`UniPCMultistepScheduler`]): 202 | A scheduler to be used in combination with `transformer` to denoise the encoded image latents. 203 | vae ([`AutoencoderKLWan`]): 204 | Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. 
205 | """ 206 | 207 | model_cpu_offload_seq = "text_encoder->transformer->vae" 208 | _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] 209 | 210 | def __init__( 211 | self, 212 | tokenizer: AutoTokenizer, 213 | text_encoder: UMT5EncoderModel, 214 | transformer: WanTransformer3DModel, 215 | vae: AutoencoderKLWan, 216 | scheduler: FlowMatchEulerDiscreteScheduler, 217 | ): 218 | super().__init__() 219 | 220 | self.register_modules( 221 | vae=vae, 222 | text_encoder=text_encoder, 223 | tokenizer=tokenizer, 224 | transformer=transformer, 225 | scheduler=scheduler, 226 | ) 227 | 228 | self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 229 | self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 230 | self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) 231 | 232 | def _get_t5_prompt_embeds( 233 | self, 234 | prompt: Union[str, List[str]] = None, 235 | num_videos_per_prompt: int = 1, 236 | max_sequence_length: int = 226, 237 | device: Optional[torch.device] = None, 238 | dtype: Optional[torch.dtype] = None, 239 | ): 240 | device = device or self._execution_device 241 | dtype = dtype or self.text_encoder.dtype 242 | 243 | prompt = [prompt] if isinstance(prompt, str) else prompt 244 | prompt = [prompt_clean(u) for u in prompt] 245 | batch_size = len(prompt) 246 | 247 | text_inputs = self.tokenizer( 248 | prompt, 249 | padding="max_length", 250 | max_length=max_sequence_length, 251 | truncation=True, 252 | add_special_tokens=True, 253 | return_attention_mask=True, 254 | return_tensors="pt", 255 | ) 256 | text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask 257 | seq_lens = mask.gt(0).sum(dim=1).long() 258 | 259 | prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state 260 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 261 | 
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] 262 | prompt_embeds = torch.stack( 263 | [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 264 | ) 265 | 266 | # duplicate text embeddings for each generation per prompt, using mps friendly method 267 | _, seq_len, _ = prompt_embeds.shape 268 | prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) 269 | prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) 270 | 271 | return prompt_embeds 272 | 273 | # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps 274 | def get_timesteps(self, num_inference_steps, timesteps, strength, device): 275 | # get the original timestep using init_timestep 276 | init_timestep = min(int(num_inference_steps * strength), num_inference_steps) 277 | 278 | t_start = max(num_inference_steps - init_timestep, 0) 279 | timesteps = timesteps[t_start * self.scheduler.order :] 280 | 281 | return timesteps, num_inference_steps - t_start 282 | 283 | def encode_prompt( 284 | self, 285 | prompt: Union[str, List[str]], 286 | negative_prompt: Optional[Union[str, List[str]]] = None, 287 | do_classifier_free_guidance: bool = True, 288 | num_videos_per_prompt: int = 1, 289 | prompt_embeds: Optional[torch.Tensor] = None, 290 | negative_prompt_embeds: Optional[torch.Tensor] = None, 291 | max_sequence_length: int = 226, 292 | device: Optional[torch.device] = None, 293 | dtype: Optional[torch.dtype] = None, 294 | ): 295 | r""" 296 | Encodes the prompt into text encoder hidden states. 297 | 298 | Args: 299 | prompt (`str` or `List[str]`, *optional*): 300 | prompt to be encoded 301 | negative_prompt (`str` or `List[str]`, *optional*): 302 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 303 | `negative_prompt_embeds` instead. 
Ignored when not using guidance (i.e., ignored if `guidance_scale` is 304 | less than `1`). 305 | do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): 306 | Whether to use classifier free guidance or not. 307 | num_videos_per_prompt (`int`, *optional*, defaults to 1): 308 | Number of videos that should be generated per prompt. torch device to place the resulting embeddings on 309 | prompt_embeds (`torch.Tensor`, *optional*): 310 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 311 | provided, text embeddings will be generated from `prompt` input argument. 312 | negative_prompt_embeds (`torch.Tensor`, *optional*): 313 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 314 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 315 | argument. 316 | device: (`torch.device`, *optional*): 317 | torch device 318 | dtype: (`torch.dtype`, *optional*): 319 | torch dtype 320 | """ 321 | device = device or self._execution_device 322 | 323 | prompt = [prompt] if isinstance(prompt, str) else prompt 324 | if prompt is not None: 325 | batch_size = len(prompt) 326 | else: 327 | batch_size = prompt_embeds.shape[0] 328 | 329 | if prompt_embeds is None: 330 | prompt_embeds = self._get_t5_prompt_embeds( 331 | prompt=prompt, 332 | num_videos_per_prompt=num_videos_per_prompt, 333 | max_sequence_length=max_sequence_length, 334 | device=device, 335 | dtype=dtype, 336 | ) 337 | 338 | if do_classifier_free_guidance and negative_prompt_embeds is None: 339 | negative_prompt = negative_prompt or "" 340 | negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt 341 | 342 | if prompt is not None and type(prompt) is not type(negative_prompt): 343 | raise TypeError( 344 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 345 | f" 
{type(prompt)}." 346 | ) 347 | elif batch_size != len(negative_prompt): 348 | raise ValueError( 349 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 350 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 351 | " the batch size of `prompt`." 352 | ) 353 | 354 | negative_prompt_embeds = self._get_t5_prompt_embeds( 355 | prompt=negative_prompt, 356 | num_videos_per_prompt=num_videos_per_prompt, 357 | max_sequence_length=max_sequence_length, 358 | device=device, 359 | dtype=dtype, 360 | ) 361 | 362 | return prompt_embeds, negative_prompt_embeds 363 | 364 | def check_inputs( 365 | self, 366 | prompt, 367 | negative_prompt, 368 | height, 369 | width, 370 | prompt_embeds=None, 371 | negative_prompt_embeds=None, 372 | callback_on_step_end_tensor_inputs=None, 373 | ): 374 | if height % 16 != 0 or width % 16 != 0: 375 | raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") 376 | 377 | if callback_on_step_end_tensor_inputs is not None and not all( 378 | k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs 379 | ): 380 | raise ValueError( 381 | f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" 382 | ) 383 | 384 | if prompt is not None and prompt_embeds is not None: 385 | raise ValueError( 386 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 387 | " only forward one of the two." 388 | ) 389 | elif negative_prompt is not None and negative_prompt_embeds is not None: 390 | raise ValueError( 391 | f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" 392 | " only forward one of the two." 
393 | ) 394 | elif prompt is None and prompt_embeds is None: 395 | raise ValueError( 396 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 397 | ) 398 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): 399 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 400 | elif negative_prompt is not None and ( 401 | not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) 402 | ): 403 | raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") 404 | 405 | def prepare_latents( 406 | self, 407 | video: Optional[torch.Tensor] = None, 408 | batch_size: int = 1, 409 | num_channels_latents: int = 16, 410 | height: int = 720, 411 | width: int = 1280, 412 | num_frames: int = 81, 413 | dtype: Optional[torch.dtype] = None, 414 | device: Optional[torch.device] = None, 415 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 416 | latents: Optional[torch.Tensor] = None, 417 | timestep: Optional[torch.Tensor] = None, 418 | ) -> torch.Tensor: 419 | num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 420 | latent_height = height // self.vae_scale_factor_spatial 421 | latent_width = width // self.vae_scale_factor_spatial 422 | 423 | shape = ( 424 | batch_size, 425 | num_channels_latents, 426 | num_latent_frames, 427 | latent_height, 428 | latent_width, 429 | ) 430 | init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video] 431 | init_latents = torch.cat(init_latents, dim=0).to(dtype) 432 | 433 | latents_mean = ( 434 | torch.tensor(self.vae.config.latents_mean) 435 | .view(1, self.vae.config.z_dim, 1, 1, 1) 436 | .to(device, dtype) 437 | ) 438 | latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( 439 | device, dtype 440 | ) 441 | init_latents = 
(init_latents - latents_mean) * latents_std 442 | 443 | 444 | noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 445 | latents = self.scheduler.scale_noise(init_latents, timestep, noise) 446 | 447 | return latents 448 | 449 | def encode_latents(self, video: torch.Tensor) -> torch.Tensor: 450 | 451 | latents = self.vae.encode(video).latent_dist.mode() 452 | latents_mean = ( 453 | torch.tensor(self.vae.config.latents_mean) 454 | .view(1, self.vae.config.z_dim, 1, 1, 1) 455 | .to(latents.device, latents.dtype) 456 | ) 457 | latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( 458 | latents.device, latents.dtype 459 | ) 460 | 461 | latents = (latents - latents_mean) * latents_std 462 | 463 | return latents 464 | 465 | def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: 466 | 467 | latents = latents.to(self.vae.dtype) 468 | latents_mean = ( 469 | torch.tensor(self.vae.config.latents_mean) 470 | .view(1, self.vae.config.z_dim, 1, 1, 1) 471 | .to(latents.device, latents.dtype) 472 | ) 473 | latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( 474 | latents.device, latents.dtype 475 | ) 476 | latents = latents / latents_std + latents_mean 477 | video = self.vae.decode(latents, return_dict=False)[0] 478 | 479 | return video 480 | 481 | @property 482 | def guidance_scale(self): 483 | return self._guidance_scale 484 | 485 | @property 486 | def do_classifier_free_guidance(self): 487 | return self._guidance_scale > 1.0 488 | 489 | @property 490 | def num_timesteps(self): 491 | return self._num_timesteps 492 | 493 | @property 494 | def current_timestep(self): 495 | return self._current_timestep 496 | 497 | @property 498 | def interrupt(self): 499 | return self._interrupt 500 | 501 | @torch.no_grad() 502 | @replace_example_docstring(EXAMPLE_DOC_STRING) 503 | def __call__( 504 | self, 505 | ic_light_pipe=None, 506 | relight_prompt=None, 
507 | bg_source=None, 508 | video: List[List[PipelineImageInput]] = None, 509 | strength = 0, 510 | prompt: Union[str, List[str]] = None, 511 | negative_prompt: Union[str, List[str]] = None, 512 | height: int = 720, 513 | width: int = 1280, 514 | num_frames: int = 81, 515 | timesteps: Optional[List[int]] = None, 516 | num_inference_steps: int = 50, 517 | guidance_scale: float = 5.0, 518 | num_videos_per_prompt: Optional[int] = 1, 519 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 520 | latents: Optional[torch.Tensor] = None, 521 | prompt_embeds: Optional[torch.Tensor] = None, 522 | negative_prompt_embeds: Optional[torch.Tensor] = None, 523 | output_type: Optional[str] = "np", 524 | return_dict: bool = True, 525 | callback_on_step_end: Optional[ 526 | Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] 527 | ] = None, 528 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 529 | max_sequence_length: int = 512, 530 | ): 531 | r""" 532 | The call function to the pipeline for generation. 533 | 534 | Args: 535 | prompt (`str` or `List[str]`, *optional*): 536 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. 537 | instead. 538 | height (`int`, defaults to `720`): 539 | The height in pixels of the generated image. 540 | width (`int`, defaults to `1280`): 541 | The width in pixels of the generated image. 542 | num_frames (`int`, defaults to `129`): 543 | The number of frames in the generated video. 544 | num_inference_steps (`int`, defaults to `50`): 545 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 546 | expense of slower inference. 547 | guidance_scale (`float`, defaults to `5.0`): 548 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 549 | `guidance_scale` is defined as `w` of equation 2. 
of [Imagen 550 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 551 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 552 | usually at the expense of lower image quality. 553 | num_videos_per_prompt (`int`, *optional*, defaults to 1): 554 | The number of images to generate per prompt. 555 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 556 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make 557 | generation deterministic. 558 | latents (`torch.Tensor`, *optional*): 559 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image 560 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 561 | tensor is generated by sampling using the supplied random `generator`. 562 | prompt_embeds (`torch.Tensor`, *optional*): 563 | Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not 564 | provided, text embeddings are generated from the `prompt` input argument. 565 | output_type (`str`, *optional*, defaults to `"pil"`): 566 | The output format of the generated image. Choose between `PIL.Image` or `np.array`. 567 | return_dict (`bool`, *optional*, defaults to `True`): 568 | Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple. 569 | callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): 570 | A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of 571 | each denoising step during the inference. with the following arguments: `callback_on_step_end(self: 572 | DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a 573 | list of all tensors as specified by `callback_on_step_end_tensor_inputs`. 
574 | callback_on_step_end_tensor_inputs (`List`, *optional*): 575 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 576 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the 577 | `._callback_tensor_inputs` attribute of your pipeline class. 578 | autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`): 579 | The dtype to use for the torch.amp.autocast. 580 | 581 | Examples: 582 | 583 | Returns: 584 | [`~WanPipelineOutput`] or `tuple`: 585 | If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where 586 | the first element is a list with the generated images and the second element is a list of `bool`s 587 | indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. 588 | """ 589 | self._guidance_scale = guidance_scale 590 | self._current_timestep = None 591 | self._interrupt = False 592 | 593 | device = self._execution_device 594 | 595 | # 2. Define call parameters 596 | if prompt is not None and isinstance(prompt, str): 597 | batch_size = 1 598 | elif prompt is not None and isinstance(prompt, list): 599 | batch_size = len(prompt) 600 | else: 601 | batch_size = prompt_embeds.shape[0] 602 | 603 | # 3. Encode input prompt 604 | prompt_embeds, negative_prompt_embeds = self.encode_prompt( 605 | prompt=prompt, 606 | negative_prompt=negative_prompt, 607 | do_classifier_free_guidance=self.do_classifier_free_guidance, 608 | num_videos_per_prompt=num_videos_per_prompt, 609 | prompt_embeds=prompt_embeds, 610 | negative_prompt_embeds=negative_prompt_embeds, 611 | max_sequence_length=max_sequence_length, 612 | device=device, 613 | ) 614 | 615 | transformer_dtype = self.transformer.dtype 616 | prompt_embeds = prompt_embeds.to(transformer_dtype) 617 | if negative_prompt_embeds is not None: 618 | negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) 619 | 620 | # 4. 
Prepare timesteps 621 | timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) 622 | timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device) 623 | latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) 624 | self._num_timesteps = len(timesteps) 625 | 626 | if latents is None: 627 | video = self.video_processor.preprocess_video(video, height=height, width=width) 628 | video = video.to(device=device, dtype=prompt_embeds.dtype) 629 | org_target = video 630 | # 5. Prepare latent variables 631 | num_channels_latents = self.transformer.config.in_channels 632 | 633 | latents = self.prepare_latents( 634 | video, 635 | batch_size * num_videos_per_prompt, 636 | num_channels_latents, 637 | height, 638 | width, 639 | num_frames, 640 | prompt_embeds.dtype, 641 | device, 642 | generator, 643 | latents, 644 | latent_timestep, 645 | ) 646 | 647 | #################### Relighter ################### 648 | from .ic_light import Relighter 649 | num_frames = video.shape[2] 650 | relighter = Relighter( 651 | pipeline=ic_light_pipe, 652 | relight_prompt=relight_prompt, 653 | bg_source=bg_source, 654 | generator=generator, 655 | num_frames=num_frames, 656 | image_width=width, 657 | image_height=height, 658 | ) 659 | 660 | # 6. 
Denoising loop 661 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 662 | self._num_timesteps = len(timesteps) 663 | 664 | with self.progress_bar(total=num_inference_steps) as progress_bar: 665 | for i, t in enumerate(timesteps): 666 | if self.interrupt: 667 | continue 668 | 669 | self._current_timestep = t 670 | latent_model_input = latents.to(transformer_dtype) 671 | timestep = t.expand(latents.shape[0]) 672 | 673 | ## init the step 674 | self.scheduler._init_step_index(timestep) 675 | 676 | noise_pred = self.transformer( 677 | hidden_states=latent_model_input, 678 | timestep=timestep, 679 | encoder_hidden_states=prompt_embeds, 680 | return_dict=False, 681 | )[0] 682 | 683 | if self.do_classifier_free_guidance: 684 | noise_uncond = self.transformer( 685 | hidden_states=latent_model_input, 686 | timestep=timestep, 687 | encoder_hidden_states=negative_prompt_embeds, 688 | return_dict=False, 689 | )[0] 690 | noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) 691 | 692 | ## Progressive light fusion 693 | k = 1 694 | lbd = 1 - (i / (num_inference_steps - 1)) ** k 695 | 696 | if lbd > 0.2: 697 | 698 | ## get pred_x0 699 | sigma = self.scheduler.sigmas[self.scheduler.step_index] 700 | pred_x0_latent = latents - sigma * noise_pred 701 | consist_target = self.decode_latents(pred_x0_latent) 702 | 703 | ## detail compensation 704 | if i == 0: 705 | detail_diff = org_target - consist_target 706 | consist_target = consist_target + lbd * detail_diff 707 | 708 | consist_target = rearrange(consist_target, "1 c f h w -> f c h w") 709 | relight_target = relighter(consist_target) 710 | 711 | fusion_target = consist_target + lbd * (relight_target - consist_target) 712 | fusion_target = rearrange(fusion_target, "f c h w -> 1 c f h w") 713 | 714 | fusion_target = self.encode_latents(fusion_target) 715 | output = wan_flowmatch_step(self.scheduler, noise_pred, t, latents, fusion_target, return_dict=False) 716 | else: 717 | output = 
self.scheduler.step(noise_pred, t, latents, return_dict=False) 718 | 719 | latents = output[0] 720 | 721 | if callback_on_step_end is not None: 722 | callback_kwargs = {} 723 | for k in callback_on_step_end_tensor_inputs: 724 | callback_kwargs[k] = locals()[k] 725 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 726 | 727 | latents = callback_outputs.pop("latents", latents) 728 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 729 | negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) 730 | 731 | # call the callback, if provided 732 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 733 | progress_bar.update() 734 | 735 | self._current_timestep = None 736 | 737 | if not output_type == "latent": 738 | latents = latents.to(self.vae.dtype) 739 | latents_mean = ( 740 | torch.tensor(self.vae.config.latents_mean) 741 | .view(1, self.vae.config.z_dim, 1, 1, 1) 742 | .to(latents.device, latents.dtype) 743 | ) 744 | latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( 745 | latents.device, latents.dtype 746 | ) 747 | latents = latents / latents_std + latents_mean 748 | video = self.vae.decode(latents, return_dict=False)[0] 749 | video = self.video_processor.postprocess_video(video, output_type=output_type) 750 | 751 | # Offload all models 752 | self.maybe_free_model_hooks() 753 | 754 | if not return_dict: 755 | return (video,) 756 | 757 | return WanPipelineOutput(frames=video) 758 | -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | from PIL import Image,ImageSequence 2 | import numpy as np 3 | import torch 4 | from moviepy.video.io.VideoFileClip import VideoFileClip 5 | import os 6 | import imageio 7 | import random 8 | from diffusers.utils import 
export_to_video 9 | 10 | def resize_and_center_crop(image, target_width, target_height): 11 | pil_image = Image.fromarray(image) 12 | original_width, original_height = pil_image.size 13 | scale_factor = max(target_width / original_width, target_height / original_height) 14 | resized_width = int(round(original_width * scale_factor)) 15 | resized_height = int(round(original_height * scale_factor)) 16 | resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS) 17 | 18 | left = (resized_width - target_width) / 2 19 | top = (resized_height - target_height) / 2 20 | right = (resized_width + target_width) / 2 21 | bottom = (resized_height + target_height) / 2 22 | cropped_image = resized_image.crop((left, top, right, bottom)) 23 | return np.array(cropped_image) 24 | 25 | def numpy2pytorch(imgs, device, dtype): 26 | h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0 27 | h = h.movedim(-1, 1) 28 | return h.to(device=device, dtype=dtype) 29 | 30 | def get_fg_video(video_list, mask_list, device, dtype): 31 | video_np = np.stack(video_list, axis=0) 32 | mask_np = np.stack(mask_list, axis=0) 33 | mask_bool = mask_np == 255 34 | video_fg = np.where(mask_bool, video_np, 127) 35 | 36 | h = torch.from_numpy(video_fg).float() / 127.0 - 1.0 37 | h = h.movedim(-1, 1) 38 | return h.to(device=device, dtype=dtype) 39 | 40 | 41 | def pad(x, p, i): 42 | return x[:i] if len(x) >= i else x + [p] * (i - len(x)) 43 | 44 | def gif_to_mp4(gif_path, mp4_path): 45 | clip = VideoFileClip(gif_path) 46 | clip.write_videofile(mp4_path) 47 | 48 | def generate_light_sequence(light_tensor, num_frames=16, direction="r"): 49 | 50 | if direction in "l": 51 | target_tensor = torch.rot90(light_tensor, k=1, dims=(2, 3)) 52 | elif direction in "r": 53 | target_tensor = torch.rot90(light_tensor, k=-1, dims=(2, 3)) 54 | else: 55 | raise ValueError("direction must be either 'r' for right or 'l' for left") 56 | 57 | # Generate the sequence 58 | out_list = [] 59 | for frame_idx 
in range(num_frames): 60 | t = frame_idx / (num_frames - 1) 61 | interpolated_matrix = (1 - t) * light_tensor + t * target_tensor 62 | out_list.append(interpolated_matrix) 63 | 64 | out_tensor = torch.stack(out_list, dim=0).squeeze(1) 65 | 66 | return out_tensor 67 | 68 | def tensor2vid(video: torch.Tensor, processor, output_type="np"): 69 | 70 | batch_size, channels, num_frames, height, width = video.shape ## [1, 4, 16, 512, 512] 71 | outputs = [] 72 | for batch_idx in range(batch_size): 73 | batch_vid = video[batch_idx].permute(1, 0, 2, 3) 74 | batch_output = processor.postprocess(batch_vid, output_type) 75 | 76 | outputs.append(batch_output) 77 | 78 | return outputs 79 | 80 | def read_video(video_path:str, image_width, image_height): 81 | extension = video_path.split('.')[-1].lower() 82 | video_name = os.path.basename(video_path) 83 | video_list = [] 84 | 85 | if extension in "gif": 86 | ## input from gif 87 | video = Image.open(video_path) 88 | for i, frame in enumerate(ImageSequence.Iterator(video)): 89 | frame = np.array(frame.convert("RGB")) 90 | frame = resize_and_center_crop(frame, image_width, image_height) 91 | video_list.append(frame) 92 | elif extension in "mp4": 93 | ## input from mp4 94 | reader = imageio.get_reader(video_path) 95 | for frame in reader: 96 | frame = resize_and_center_crop(frame, image_width, image_height) 97 | video_list.append(frame) 98 | else: 99 | raise ValueError('Wrong input type') 100 | 101 | video_list = [Image.fromarray(frame) for frame in video_list] 102 | 103 | return video_list, video_name 104 | 105 | def read_mask(mask_folder:str): 106 | mask_files = os.listdir(mask_folder) 107 | mask_files = sorted(mask_files) 108 | mask_list = [] 109 | for mask_file in mask_files: 110 | mask_path = os.path.join(mask_folder, mask_file) 111 | mask = Image.open(mask_path).convert('RGB') 112 | mask_list.append(mask) 113 | 114 | return mask_list 115 | 116 | def decode_latents(vae, latents, decode_chunk_size: int = 16): 117 | 118 | latents = 
1 / vae.config.scaling_factor * latents 119 | video = [] 120 | for i in range(0, latents.shape[0], decode_chunk_size): 121 | batch_latents = latents[i : i + decode_chunk_size] 122 | batch_latents = vae.decode(batch_latents).sample 123 | video.append(batch_latents) 124 | 125 | video = torch.cat(video) 126 | 127 | return video 128 | 129 | def encode_video(vae, video, decode_chunk_size: int = 16) -> torch.Tensor: 130 | latents = [] 131 | for i in range(0, len(video), decode_chunk_size): 132 | batch_video = video[i : i + decode_chunk_size] 133 | batch_video = vae.encode(batch_video).latent_dist.mode() 134 | latents.append(batch_video) 135 | return torch.cat(latents) 136 | 137 | def vis_video(input_video, video_processor, save_path): 138 | ## shape: 1, c, f, h, w 139 | relight_video = video_processor.postprocess_video(video=input_video, output_type="pil") 140 | export_to_video(relight_video[0], save_path) 141 | 142 | def set_all_seed(seed): 143 | torch.manual_seed(seed) 144 | torch.cuda.manual_seed(seed) 145 | torch.cuda.manual_seed_all(seed) 146 | np.random.seed(seed) 147 | random.seed(seed) 148 | torch.backends.cudnn.deterministic = True 149 | --------------------------------------------------------------------------------