├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── app.py ├── examples ├── example_01.mp4 ├── example_02.mp4 ├── example_03.mp4 ├── example_04.mp4 ├── example_05.mp4 └── example_06.mp4 ├── normalcrafter ├── __init__.py ├── normal_crafter_ppl.py ├── unet.py └── utils.py ├── requirements.txt └── run.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.7z filter=lfs diff=lfs merge=lfs -text 2 | *.arrow filter=lfs diff=lfs merge=lfs -text 3 | *.bin filter=lfs diff=lfs merge=lfs -text 4 | *.bz2 filter=lfs diff=lfs merge=lfs -text 5 | *.ckpt filter=lfs diff=lfs merge=lfs -text 6 | *.ftz filter=lfs diff=lfs merge=lfs -text 7 | *.gz filter=lfs diff=lfs merge=lfs -text 8 | *.h5 filter=lfs diff=lfs merge=lfs -text 9 | *.joblib filter=lfs diff=lfs merge=lfs -text 10 | *.lfs.* filter=lfs diff=lfs merge=lfs -text 11 | *.mlmodel filter=lfs diff=lfs merge=lfs -text 12 | *.model filter=lfs diff=lfs merge=lfs -text 13 | *.msgpack filter=lfs diff=lfs merge=lfs -text 14 | *.npy filter=lfs diff=lfs merge=lfs -text 15 | *.npz filter=lfs diff=lfs merge=lfs -text 16 | *.onnx filter=lfs diff=lfs merge=lfs -text 17 | *.ot filter=lfs diff=lfs merge=lfs -text 18 | *.parquet filter=lfs diff=lfs merge=lfs -text 19 | *.pb filter=lfs diff=lfs merge=lfs -text 20 | *.pickle filter=lfs diff=lfs merge=lfs -text 21 | *.pkl filter=lfs diff=lfs merge=lfs -text 22 | *.pt filter=lfs diff=lfs merge=lfs -text 23 | *.pth filter=lfs diff=lfs merge=lfs -text 24 | *.rar filter=lfs diff=lfs merge=lfs -text 25 | *.safetensors filter=lfs diff=lfs merge=lfs -text 26 | saved_model/**/* filter=lfs diff=lfs merge=lfs -text 27 | *.tar.* filter=lfs diff=lfs merge=lfs -text 28 | *.tar filter=lfs diff=lfs merge=lfs -text 29 | *.tflite filter=lfs diff=lfs merge=lfs -text 30 | *.tgz filter=lfs diff=lfs merge=lfs -text 31 | *.wasm filter=lfs diff=lfs merge=lfs -text 32 | *.xz filter=lfs diff=lfs merge=lfs -text 33 | *.zip filter=lfs diff=lfs merge=lfs -text 34 | *.zst filter=lfs diff=lfs merge=lfs -text 35 | *tfevents* filter=lfs diff=lfs merge=lfs -text 36 | *.gif filter=lfs diff=lfs merge=lfs -text 37 | *.mp4 filter=lfs diff=lfs merge=lfs -text 38 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ### Python template 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # 11 | .gradio 12 | .github 13 | demo_output 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # poetry 103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 104 | # This is especially recommended for binary packages to ensure reproducibility, and is more 105 | # commonly ignored for libraries. 106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 107 | #poetry.lock 108 | 109 | # pdm 110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 111 | #pdm.lock 112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 113 | # in version control. 114 | # https://pdm.fming.dev/#use-with-ide 115 | .pdm.toml 116 | 117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
165 | .idea/ 166 | 167 | /logs 168 | /gin-config 169 | *.json 170 | /eval/*csv 171 | *__pycache__ 172 | scripts/ 173 | eval/ 174 | *.DS_Store 175 | benchmark/datasets -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Yanrui Bin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## ___***NormalCrafter: Learning Temporally Consistent Video Normal from Video Diffusion Priors***___ 2 | 3 | _**[Yanrui Bin1](https://scholar.google.com/citations?user=_9fN3mEAAAAJ&hl=zh-CN),[Wenbo Hu2*](https://wbhu.github.io), 4 | [Haoyuan Wang3](https://www.whyy.site/), 5 | [Xinya Chen4](https://xinyachen21.github.io/), 6 | [Bing Wang2 †](https://bingcs.github.io/)**_ 7 |

8 | 1Spatial Intelligence Group, The Hong Kong Polytechnic University 9 | 2ARC Lab, Tencent PCG 10 | 3City University of Hong Kong 11 | 4Huazhong University of Science and Technology 12 | 13 | 14 | ## 🔆 Notice 15 | We recommend that everyone use English to communicate on issues, as this helps developers from around the world discuss, share experiences, and answer questions together. 16 | 17 | For business licensing and other related inquiries, don't hesitate to contact `binyanrui@gmail.com`. 18 | 19 | ## 🔆 Introduction 20 | 🤗 If you find NormalCrafter useful, **please help ⭐ this repo**, which is important to Open-Source projects. Thanks! 21 | 22 | 🔥 NormalCrafter can generate temporally consistent normal sequences 23 | with fine-grained details from open-world videos with arbitrary lengths. 24 | 25 | - `[24-04-01]` 🔥🔥🔥 **NormalCrafter** is released now, have fun! 26 | ## 🚀 Quick Start 27 | 28 | ### 🤖 Gradio Demo 29 | - Online demo: [NormalCrafter](https://huggingface.co/spaces/Yanrui95/NormalCrafter) 30 | - Local demo: 31 | ```bash 32 | gradio app.py 33 | ``` 34 | 35 | ### 🛠️ Installation 36 | 1. Clone this repo: 37 | ```bash 38 | git clone git@github.com:Binyr/NormalCrafter.git 39 | ``` 40 | 2. Install dependencies (please refer to [requirements.txt](requirements.txt)): 41 | ```bash 42 | pip install -r requirements.txt 43 | ``` 44 | 45 | 46 | 47 | ### 🤗 Model Zoo 48 | [NormalCrafter](https://huggingface.co/Yanrui95/NormalCrafter) is available in the Hugging Face Model Hub. 49 | 50 | ### 🏃‍♂️ Inference 51 | #### 1. High-resolution inference, requires a GPU with ~20GB memory for 1024x576 resolution: 52 | ```bash 53 | python run.py --video-path examples/example_01.mp4 54 | ``` 55 | 56 | #### 2. Low-resolution inference requires a GPU with ~6GB memory for 512x256 resolution: 57 | ```bash 58 | python run.py --video-path examples/example_01.mp4 --max-res 512 59 | ``` 60 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | 4 | import numpy as np 5 | import spaces 6 | import gradio as gr 7 | import torch 8 | from diffusers.training_utils import set_seed 9 | from diffusers import AutoencoderKLTemporalDecoder 10 | 11 | from normalcrafter.normal_crafter_ppl import NormalCrafterPipeline 12 | from normalcrafter.unet import DiffusersUNetSpatioTemporalConditionModelNormalCrafter 13 | 14 | import uuid 15 | import random 16 | from huggingface_hub import hf_hub_download 17 | 18 | from normalcrafter.utils import read_video_frames, vis_sequence_normal, save_video 19 | 20 | examples = [ 21 | ["examples/example_01.mp4", 1024, -1, -1], 22 | ["examples/example_02.mp4", 1024, -1, -1], 23 | ["examples/example_03.mp4", 1024, -1, -1], 24 | ["examples/example_04.mp4", 1024, -1, -1], 25 | # ["examples/example_05.mp4", 1024, -1, -1], 26 | # ["examples/example_06.mp4", 1024, -1, -1], 27 | ] 28 | 29 | pretrained_model_name_or_path = "Yanrui95/NormalCrafter" 30 | weight_dtype = torch.float16 31 | unet = DiffusersUNetSpatioTemporalConditionModelNormalCrafter.from_pretrained( 32 | pretrained_model_name_or_path, 33 | subfolder="unet", 34 | low_cpu_mem_usage=True, 35 | ) 36 | vae = AutoencoderKLTemporalDecoder.from_pretrained( 37 | pretrained_model_name_or_path, subfolder="vae") 38 | 39 | vae.to(dtype=weight_dtype) 40 | unet.to(dtype=weight_dtype) 41 | 42 | pipe = NormalCrafterPipeline.from_pretrained( 43 | "stabilityai/stable-video-diffusion-img2vid-xt", 44 | unet=unet, 45 | 
vae=vae, 46 | torch_dtype=weight_dtype, 47 | variant="fp16", 48 | ) 49 | pipe.to("cuda") 50 | 51 | 52 | @spaces.GPU(duration=120) 53 | def infer_depth( 54 | video: str, 55 | max_res: int = 1024, 56 | process_length: int = -1, 57 | target_fps: int = -1, 58 | # 59 | save_folder: str = "./demo_output", 60 | window_size: int = 14, 61 | time_step_size: int = 10, 62 | decode_chunk_size: int = 7, 63 | seed: int = 42, 64 | save_npz: bool = False, 65 | ): 66 | set_seed(seed) 67 | pipe.enable_xformers_memory_efficient_attention() 68 | 69 | frames, target_fps = read_video_frames(video, process_length, target_fps, max_res) 70 | 71 | # inference the depth map using the DepthCrafter pipeline 72 | with torch.inference_mode(): 73 | res = pipe( 74 | frames, 75 | decode_chunk_size=decode_chunk_size, 76 | time_step_size=time_step_size, 77 | window_size=window_size, 78 | ).frames[0] 79 | 80 | # visualize the depth map and save the results 81 | vis = vis_sequence_normal(res) 82 | # save the depth map and visualization with the target FPS 83 | save_path = os.path.join(save_folder, os.path.splitext(os.path.basename(video))[0]) 84 | print(f"==> saving results to {save_path}") 85 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 86 | if save_npz: 87 | np.savez_compressed(save_path + ".npz", normal=res) 88 | save_video(vis, save_path + "_vis.mp4", fps=target_fps) 89 | save_video(frames, save_path + "_input.mp4", fps=target_fps) 90 | 91 | # clear the cache for the next video 92 | gc.collect() 93 | torch.cuda.empty_cache() 94 | 95 | return [ 96 | save_path + "_input.mp4", 97 | save_path + "_vis.mp4", 98 | 99 | ] 100 | 101 | 102 | def construct_demo(): 103 | with gr.Blocks(analytics_enabled=False) as depthcrafter_iface: 104 | gr.Markdown( 105 | """ 106 |

NormalCrafter: Learning Temporally Consistent Video Normal from Video Diffusion Priors

\ 107 | If you find NormalCrafter useful, please help ⭐ the \ 108 | [Github Repo]\ 109 | , which is important to Open-Source projects. Thanks!\ 110 | [ArXiv] \ 111 | [Project Page]
112 | """ 113 | ) 114 | 115 | with gr.Row(equal_height=True): 116 | with gr.Column(scale=1): 117 | input_video = gr.Video(label="Input Video") 118 | 119 | # with gr.Tab(label="Output"): 120 | with gr.Column(scale=2): 121 | with gr.Row(equal_height=True): 122 | output_video_1 = gr.Video( 123 | label="Preprocessed Video", 124 | interactive=False, 125 | autoplay=True, 126 | loop=True, 127 | show_share_button=True, 128 | scale=5, 129 | ) 130 | output_video_2 = gr.Video( 131 | label="Generated Normal Video", 132 | interactive=False, 133 | autoplay=True, 134 | loop=True, 135 | show_share_button=True, 136 | scale=5, 137 | ) 138 | 139 | with gr.Row(equal_height=True): 140 | with gr.Column(scale=1): 141 | with gr.Row(equal_height=False): 142 | with gr.Accordion("Advanced Settings", open=False): 143 | max_res = gr.Slider( 144 | label="Max Resolution", 145 | minimum=512, 146 | maximum=1024, 147 | value=1024, 148 | step=64, 149 | ) 150 | process_length = gr.Slider( 151 | label="Process Length", 152 | minimum=-1, 153 | maximum=280, 154 | value=60, 155 | step=1, 156 | ) 157 | process_target_fps = gr.Slider( 158 | label="Target FPS", 159 | minimum=-1, 160 | maximum=30, 161 | value=15, 162 | step=1, 163 | ) 164 | generate_btn = gr.Button("Generate") 165 | with gr.Column(scale=2): 166 | pass 167 | 168 | gr.Examples( 169 | examples=examples, 170 | inputs=[ 171 | input_video, 172 | max_res, 173 | process_length, 174 | process_target_fps, 175 | ], 176 | outputs=[output_video_1, output_video_2], 177 | fn=infer_depth, 178 | cache_examples="lazy", 179 | ) 180 | # gr.Markdown( 181 | # """ 182 | # Note: 183 | # For time quota consideration, we set the default parameters to be more efficient here, 184 | # with a trade-off of shorter video length and slightly lower quality. 185 | # You may adjust the parameters according to our 186 | # [Github Repo] 187 | # for better results if you have enough time quota. 
188 | # 189 | # """ 190 | # ) 191 | 192 | generate_btn.click( 193 | fn=infer_depth, 194 | inputs=[ 195 | input_video, 196 | max_res, 197 | process_length, 198 | process_target_fps, 199 | ], 200 | outputs=[output_video_1, output_video_2], 201 | ) 202 | 203 | return depthcrafter_iface 204 | 205 | 206 | if __name__ == "__main__": 207 | demo = construct_demo() 208 | demo.queue() 209 | # demo.launch(server_name="0.0.0.0", server_port=12345, debug=True, share=False) 210 | demo.launch(share=True) 211 | -------------------------------------------------------------------------------- /examples/example_01.mp4: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:3eb7fefd157bd9b403cf0b524c7c4f3cb6d9f82b9d6a48eba2146412fc9e64a2 3 | size 5727137 4 | -------------------------------------------------------------------------------- /examples/example_02.mp4: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:ea3c4e4c8cd9682d92c25170d8df333fead210118802fbe22198dde478dc5489 3 | size 3150525 4 | -------------------------------------------------------------------------------- /examples/example_03.mp4: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:5d332877a98bb41ff86a639139a03e383e91880bca722bba7e2518878fca54f6 3 | size 3013435 4 | -------------------------------------------------------------------------------- /examples/example_04.mp4: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:b2aa4962216adce71b1c47f395be435b23105df35f3892646e237b935ac1c74f 3 | size 3591374 4 | -------------------------------------------------------------------------------- /examples/example_05.mp4: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:e8d2319060f9a1d3cfcb9de317e4a5b138657fd741c530ed3983f6565c2eda44 3 | size 3553683 4 | -------------------------------------------------------------------------------- /examples/example_06.mp4: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:e3a2619b029129f34884c761cc278b6842620bfed96d4bb52c8aa07bc1d82a8b 3 | size 5596872 4 | -------------------------------------------------------------------------------- /normalcrafter/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Binyr/NormalCrafter/497d163460404bdd57697e90bde95062f62a5e92/normalcrafter/__init__.py -------------------------------------------------------------------------------- /normalcrafter/normal_crafter_ppl.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Callable, Dict, List, Optional, Union 3 | 4 | import numpy as np 5 | import PIL.Image 6 | import torch 7 | import torch.nn.functional as F 8 | import math 9 | 10 | from diffusers.utils import BaseOutput, logging 11 | from diffusers.utils.torch_utils import is_compiled_module, randn_tensor 12 | from diffusers import DiffusionPipeline 13 | from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import StableVideoDiffusionPipelineOutput, 
StableVideoDiffusionPipeline 14 | from PIL import Image 15 | import cv2 16 | 17 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 18 | 19 | class NormalCrafterPipeline(StableVideoDiffusionPipeline): 20 | 21 | def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free_guidance, scale=1, image_size=None): 22 | dtype = next(self.image_encoder.parameters()).dtype 23 | 24 | if not isinstance(image, torch.Tensor): 25 | image = self.video_processor.pil_to_numpy(image) # (0, 255) -> (0, 1) 26 | image = self.video_processor.numpy_to_pt(image) # (n, h, w, c) -> (n, c, h, w) 27 | 28 | # We normalize the image before resizing to match with the original implementation. 29 | # Then we unnormalize it after resizing. 30 | pixel_values = image 31 | B, C, H, W = pixel_values.shape 32 | patches = [pixel_values] 33 | # patches = [] 34 | for i in range(1, scale): 35 | num_patches_HW_this_level = i + 1 36 | patch_H = H // num_patches_HW_this_level + 1 37 | patch_W = W // num_patches_HW_this_level + 1 38 | for j in range(num_patches_HW_this_level): 39 | for k in range(num_patches_HW_this_level): 40 | patches.append(pixel_values[:, :, j*patch_H:(j+1)*patch_H, k*patch_W:(k+1)*patch_W]) 41 | 42 | def encode_image(image): 43 | image = image * 2.0 - 1.0 44 | if image_size is not None: 45 | image = _resize_with_antialiasing(image, image_size) 46 | else: 47 | image = _resize_with_antialiasing(image, (224, 224)) 48 | image = (image + 1.0) / 2.0 49 | 50 | # Normalize the image with for CLIP input 51 | image = self.feature_extractor( 52 | images=image, 53 | do_normalize=True, 54 | do_center_crop=False, 55 | do_resize=False, 56 | do_rescale=False, 57 | return_tensors="pt", 58 | ).pixel_values 59 | 60 | image = image.to(device=device, dtype=dtype) 61 | image_embeddings = self.image_encoder(image).image_embeds 62 | if len(image_embeddings.shape) < 3: 63 | image_embeddings = image_embeddings.unsqueeze(1) 64 | return image_embeddings 65 | 66 | image_embeddings = [] 67 | for patch in patches: 68 | image_embeddings.append(encode_image(patch)) 69 | image_embeddings = torch.cat(image_embeddings, dim=1) 70 | 71 | # duplicate image embeddings for each generation per prompt, using mps friendly method 72 | # import pdb 73 | # pdb.set_trace() 74 | bs_embed, seq_len, _ = image_embeddings.shape 75 | image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1) 76 | image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) 77 | 78 | if do_classifier_free_guidance: 79 | negative_image_embeddings = torch.zeros_like(image_embeddings) 80 | 81 | # For classifier free guidance, we need to do two forward passes. 
82 | # Here we concatenate the unconditional and text embeddings into a single batch 83 | # to avoid doing two forward passes 84 | image_embeddings = torch.cat([negative_image_embeddings, image_embeddings]) 85 | 86 | return image_embeddings 87 | 88 | def ecnode_video_vae(self, images, chunk_size: int = 14): 89 | if isinstance(images, list): 90 | width, height = images[0].size 91 | else: 92 | height, width = images[0].shape[:2] 93 | needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast 94 | if needs_upcasting: 95 | self.vae.to(dtype=torch.float32) 96 | 97 | device = self._execution_device 98 | images = self.video_processor.preprocess_video(images, height=height, width=width).to(device, self.vae.dtype) # torch type in range(-1, 1) with (1,3,h,w) 99 | images = images.squeeze(0) # from (1, c, t, h, w) -> (c, t, h, w) 100 | images = images.permute(1,0,2,3) # c, t, h, w -> (t, c, h, w) 101 | 102 | video_latents = [] 103 | # chunk_size = 14 104 | for i in range(0, images.shape[0], chunk_size): 105 | video_latents.append(self.vae.encode(images[i : i + chunk_size]).latent_dist.mode()) 106 | image_latents = torch.cat(video_latents) 107 | 108 | # cast back to fp16 if needed 109 | if needs_upcasting: 110 | self.vae.to(dtype=torch.float16) 111 | 112 | return image_latents 113 | 114 | def pad_image(self, images, scale=64): 115 | def get_pad(newW, W): 116 | pad_W = (newW - W) // 2 117 | if W % 2 == 1: 118 | pad_Ws = [pad_W, pad_W + 1] 119 | else: 120 | pad_Ws = [pad_W, pad_W] 121 | return pad_Ws 122 | 123 | if type(images[0]) is np.ndarray: 124 | H, W = images[0].shape[:2] 125 | else: 126 | W, H = images[0].size 127 | 128 | if W % scale == 0 and H % scale == 0: 129 | return images, None 130 | newW = int(np.ceil(W / scale) * scale) 131 | newH = int(np.ceil(H / scale) * scale) 132 | 133 | pad_Ws = get_pad(newW, W) 134 | pad_Hs = get_pad(newH, H) 135 | 136 | new_images = [] 137 | for image in images: 138 | if type(image) is np.ndarray: 139 | image = cv2.copyMakeBorder(image, *pad_Hs, *pad_Ws, cv2.BORDER_CONSTANT, value=(1.,1.,1.)) 140 | new_images.append(image) 141 | else: 142 | image = np.array(image) 143 | image = cv2.copyMakeBorder(image, *pad_Hs, *pad_Ws, cv2.BORDER_CONSTANT, value=(255,255,255)) 144 | new_images.append(Image.fromarray(image)) 145 | return new_images, pad_Hs+pad_Ws 146 | 147 | def unpad_image(self, v, pad_HWs): 148 | t, b, l, r = pad_HWs 149 | if t > 0 or b > 0: 150 | v = v[:, :, t:-b] 151 | if l > 0 or r > 0: 152 | v = v[:, :, :, l:-r] 153 | return v 154 | 155 | @torch.no_grad() 156 | def __call__( 157 | self, 158 | images: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor], 159 | decode_chunk_size: Optional[int] = None, 160 | time_step_size: Optional[int] = 1, 161 | window_size: Optional[int] = 1, 162 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 163 | return_dict: bool = True 164 | ): 165 | images, pad_HWs = self.pad_image(images) 166 | 167 | # 0. Default height and width to unet 168 | width, height = images[0].size 169 | num_frames = len(images) 170 | 171 | # 1. Check inputs. Raise error if not correct 172 | self.check_inputs(images, height, width) 173 | 174 | # 2. Define call parameters 175 | batch_size = 1 176 | device = self._execution_device 177 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 178 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 179 | # corresponds to doing no classifier free guidance. 
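# NormalCrafter fixes the sampling configuration below to a single deterministic pass:
# guidance_scale = 1.0 with do_classifier_free_guidance = False disables classifier-free
# guidance, num_inference_steps = 1 performs one denoising step, noise_aug_strength = 0
# adds no noise to the conditioning, and determineTrain = True zeroes the initial latents
# in generate(), so the prediction is deterministic given the input frames.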
180 | self._guidance_scale = 1.0 181 | num_videos_per_prompt = 1 182 | do_classifier_free_guidance = False 183 | num_inference_steps = 1 184 | fps = 7 185 | motion_bucket_id = 127 186 | noise_aug_strength = 0. 187 | num_videos_per_prompt = 1 188 | output_type = "np" 189 | data_keys = ["normal"] 190 | use_linear_merge = True 191 | determineTrain = True 192 | encode_image_scale = 1 193 | encode_image_WH = None 194 | 195 | decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else 7 196 | 197 | # 3. Encode input image using using clip. (num_image * num_videos_per_prompt, 1, 1024) 198 | image_embeddings = self._encode_image(images, device, num_videos_per_prompt, do_classifier_free_guidance=do_classifier_free_guidance, scale=encode_image_scale, image_size=encode_image_WH) 199 | # 4. Encode input image using VAE 200 | image_latents = self.ecnode_video_vae(images, chunk_size=decode_chunk_size).to(image_embeddings.dtype) 201 | 202 | # image_latents [num_frames, channels, height, width] ->[1, num_frames, channels, height, width] 203 | image_latents = image_latents.unsqueeze(0) 204 | 205 | # 5. Get Added Time IDs 206 | added_time_ids = self._get_add_time_ids( 207 | fps, 208 | motion_bucket_id, 209 | noise_aug_strength, 210 | image_embeddings.dtype, 211 | batch_size, 212 | num_videos_per_prompt, 213 | do_classifier_free_guidance, 214 | ) 215 | added_time_ids = added_time_ids.to(device) 216 | 217 | # get Start and End frame idx for each window 218 | def get_ses(num_frames): 219 | ses = [] 220 | for i in range(0, num_frames, time_step_size): 221 | ses.append([i, i+window_size]) 222 | num_to_remain = 0 223 | for se in ses: 224 | if se[1] > num_frames: 225 | continue 226 | num_to_remain += 1 227 | ses = ses[:num_to_remain] 228 | 229 | if ses[-1][-1] < num_frames: 230 | ses.append([num_frames - window_size, num_frames]) 231 | return ses 232 | ses = get_ses(num_frames) 233 | 234 | pred = None 235 | for i, se in enumerate(ses): 236 | window_num_frames = window_size 237 | window_image_embeddings = image_embeddings[se[0]:se[1]] 238 | window_image_latents = image_latents[:, se[0]:se[1]] 239 | window_added_time_ids = added_time_ids 240 | # import pdb 241 | # pdb.set_trace() 242 | if i == 0 or time_step_size == window_size: 243 | to_replace_latents = None 244 | else: 245 | last_se = ses[i-1] 246 | num_to_replace_latents = last_se[1] - se[0] 247 | to_replace_latents = pred[:, -num_to_replace_latents:] 248 | 249 | latents = self.generate( 250 | num_inference_steps, 251 | device, 252 | batch_size, 253 | num_videos_per_prompt, 254 | window_num_frames, 255 | height, 256 | width, 257 | window_image_embeddings, 258 | generator, 259 | determineTrain, 260 | to_replace_latents, 261 | do_classifier_free_guidance, 262 | window_image_latents, 263 | window_added_time_ids 264 | ) 265 | 266 | # merge last_latents and current latents in overlap window 267 | if to_replace_latents is not None and use_linear_merge: 268 | num_img_condition = to_replace_latents.shape[1] 269 | weight = torch.linspace(1., 0., num_img_condition+2)[1:-1].to(device) 270 | weight = weight[None, :, None, None, None] 271 | latents[:, :num_img_condition] = to_replace_latents * weight + latents[:, :num_img_condition] * (1 - weight) 272 | 273 | if pred is None: 274 | pred = latents 275 | else: 276 | pred = torch.cat([pred[:, :se[0]], latents], dim=1) 277 | 278 | if not output_type == "latent": 279 | # cast back to fp16 if needed 280 | needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast 281 | if needs_upcasting: 
282 | self.vae.to(dtype=torch.float16) 283 | # latents has shape (1, num_frames, 12, h, w) 284 | 285 | def decode_latents(latents, num_frames, decode_chunk_size): 286 | frames = self.decode_latents(latents, num_frames, decode_chunk_size) # in range(-1, 1) 287 | frames = self.video_processor.postprocess_video(video=frames, output_type="np") 288 | frames = frames * 2 - 1 # from range(0, 1) -> range(-1, 1) 289 | return frames 290 | 291 | frames = decode_latents(pred, num_frames, decode_chunk_size) 292 | if pad_HWs is not None: 293 | frames = self.unpad_image(frames, pad_HWs) 294 | else: 295 | frames = pred 296 | 297 | self.maybe_free_model_hooks() 298 | 299 | if not return_dict: 300 | return frames 301 | 302 | return StableVideoDiffusionPipelineOutput(frames=frames) 303 | 304 | 305 | def generate( 306 | self, 307 | num_inference_steps, 308 | device, 309 | batch_size, 310 | num_videos_per_prompt, 311 | num_frames, 312 | height, 313 | width, 314 | image_embeddings, 315 | generator, 316 | determineTrain, 317 | to_replace_latents, 318 | do_classifier_free_guidance, 319 | image_latents, 320 | added_time_ids, 321 | latents=None, 322 | ): 323 | # 6. Prepare timesteps 324 | self.scheduler.set_timesteps(num_inference_steps, device=device) 325 | timesteps = self.scheduler.timesteps 326 | 327 | # 7. Prepare latent variables 328 | num_channels_latents = self.unet.config.in_channels 329 | latents = self.prepare_latents( 330 | batch_size * num_videos_per_prompt, 331 | num_frames, 332 | num_channels_latents, 333 | height, 334 | width, 335 | image_embeddings.dtype, 336 | device, 337 | generator, 338 | latents, 339 | ) 340 | if determineTrain: 341 | latents[...] = 0. 342 | 343 | # 8. Denoising loop 344 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 345 | self._num_timesteps = len(timesteps) 346 | with self.progress_bar(total=num_inference_steps) as progress_bar: 347 | for i, t in enumerate(timesteps): 348 | # replace part of latents with conditons. 
ToDo: t embedding should also replace 349 | if to_replace_latents is not None: 350 | num_img_condition = to_replace_latents.shape[1] 351 | if not determineTrain: 352 | _noise = randn_tensor(to_replace_latents.shape, generator=generator, device=device, dtype=image_embeddings.dtype) 353 | noisy_to_replace_latents = self.scheduler.add_noise(to_replace_latents, _noise, t.unsqueeze(0)) 354 | latents[:, :num_img_condition] = noisy_to_replace_latents 355 | else: 356 | latents[:, :num_img_condition] = to_replace_latents 357 | 358 | 359 | # expand the latents if we are doing classifier free guidance 360 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 361 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 362 | timestep = t 363 | # Concatenate image_latents over channels dimention 364 | latent_model_input = torch.cat([latent_model_input, image_latents], dim=2) 365 | # predict the noise residual 366 | noise_pred = self.unet( 367 | latent_model_input, 368 | timestep, 369 | encoder_hidden_states=image_embeddings, 370 | added_time_ids=added_time_ids, 371 | return_dict=False, 372 | )[0] 373 | 374 | # perform guidance 375 | if do_classifier_free_guidance: 376 | noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) 377 | noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) 378 | 379 | # compute the previous noisy sample x_t -> x_t-1 380 | scheduler_output = self.scheduler.step(noise_pred, t, latents) 381 | latents = scheduler_output.prev_sample 382 | 383 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 384 | progress_bar.update() 385 | 386 | return latents 387 | # resizing utils 388 | # TODO: clean up later 389 | def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True): 390 | h, w = input.shape[-2:] 391 | factors = (h / size[0], w / size[1]) 392 | 393 | # First, we have to determine sigma 394 | # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171 395 | sigmas = ( 396 | max((factors[0] - 1.0) / 2.0, 0.001), 397 | max((factors[1] - 1.0) / 2.0, 0.001), 398 | ) 399 | 400 | # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma 401 | # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206 402 | # But they do it in the 2 passes, which gives better results. 
Let's try 2 sigmas for now 403 | ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3)) 404 | 405 | # Make sure it is odd 406 | if (ks[0] % 2) == 0: 407 | ks = ks[0] + 1, ks[1] 408 | 409 | if (ks[1] % 2) == 0: 410 | ks = ks[0], ks[1] + 1 411 | 412 | input = _gaussian_blur2d(input, ks, sigmas) 413 | 414 | output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners) 415 | return output 416 | 417 | 418 | def _compute_padding(kernel_size): 419 | """Compute padding tuple.""" 420 | # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom) 421 | # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad 422 | if len(kernel_size) < 2: 423 | raise AssertionError(kernel_size) 424 | computed = [k - 1 for k in kernel_size] 425 | 426 | # for even kernels we need to do asymmetric padding :( 427 | out_padding = 2 * len(kernel_size) * [0] 428 | 429 | for i in range(len(kernel_size)): 430 | computed_tmp = computed[-(i + 1)] 431 | 432 | pad_front = computed_tmp // 2 433 | pad_rear = computed_tmp - pad_front 434 | 435 | out_padding[2 * i + 0] = pad_front 436 | out_padding[2 * i + 1] = pad_rear 437 | 438 | return out_padding 439 | 440 | 441 | def _filter2d(input, kernel): 442 | # prepare kernel 443 | b, c, h, w = input.shape 444 | tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype) 445 | 446 | tmp_kernel = tmp_kernel.expand(-1, c, -1, -1) 447 | 448 | height, width = tmp_kernel.shape[-2:] 449 | 450 | padding_shape: list[int] = _compute_padding([height, width]) 451 | input = torch.nn.functional.pad(input, padding_shape, mode="reflect") 452 | 453 | # kernel and input tensor reshape to align element-wise or batch-wise params 454 | tmp_kernel = tmp_kernel.reshape(-1, 1, height, width) 455 | input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1)) 456 | 457 | # convolve the tensor with the kernel. 
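# groups=tmp_kernel.size(0) turns this into a grouped convolution with one kernel per
# (batch, channel) slice, so each channel is blurred independently with the separable
# 1-D Gaussian kernel.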
458 | output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1) 459 | 460 | out = output.view(b, c, h, w) 461 | return out 462 | 463 | 464 | def _gaussian(window_size: int, sigma): 465 | if isinstance(sigma, float): 466 | sigma = torch.tensor([[sigma]]) 467 | 468 | batch_size = sigma.shape[0] 469 | 470 | x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1) 471 | 472 | if window_size % 2 == 0: 473 | x = x + 0.5 474 | 475 | gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0))) 476 | 477 | return gauss / gauss.sum(-1, keepdim=True) 478 | 479 | 480 | def _gaussian_blur2d(input, kernel_size, sigma): 481 | if isinstance(sigma, tuple): 482 | sigma = torch.tensor([sigma], dtype=input.dtype) 483 | else: 484 | sigma = sigma.to(dtype=input.dtype) 485 | 486 | ky, kx = int(kernel_size[0]), int(kernel_size[1]) 487 | bs = sigma.shape[0] 488 | kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1)) 489 | kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1)) 490 | out_x = _filter2d(input, kernel_x[..., None, :]) 491 | out = _filter2d(out_x, kernel_y[..., None]) 492 | 493 | return out 494 | -------------------------------------------------------------------------------- /normalcrafter/unet.py: -------------------------------------------------------------------------------- 1 | from diffusers import UNetSpatioTemporalConditionModel 2 | from diffusers.models.unets.unet_spatio_temporal_condition import UNetSpatioTemporalConditionOutput 3 | from diffusers.utils import is_torch_version 4 | import torch 5 | from typing import Any, Dict, Optional, Tuple, Union 6 | 7 | def create_custom_forward(module, return_dict=None): 8 | def custom_forward(*inputs): 9 | if return_dict is not None: 10 | return module(*inputs, return_dict=return_dict) 11 | else: 12 | return module(*inputs) 13 | 14 | return custom_forward 15 | CKPT_KWARGS = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 16 | 17 | 18 | class DiffusersUNetSpatioTemporalConditionModelNormalCrafter(UNetSpatioTemporalConditionModel): 19 | 20 | @staticmethod 21 | def forward_crossattn_down_block_dino( 22 | module, 23 | hidden_states: torch.Tensor, 24 | temb: Optional[torch.Tensor] = None, 25 | encoder_hidden_states: Optional[torch.Tensor] = None, 26 | image_only_indicator: Optional[torch.Tensor] = None, 27 | dino_down_block_res_samples = None, 28 | ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: 29 | output_states = () 30 | self = module 31 | blocks = list(zip(self.resnets, self.attentions)) 32 | for resnet, attn in blocks: 33 | if self.training and self.gradient_checkpointing: # TODO 34 | hidden_states = torch.utils.checkpoint.checkpoint( 35 | create_custom_forward(resnet), 36 | hidden_states, 37 | temb, 38 | image_only_indicator, 39 | **CKPT_KWARGS, 40 | ) 41 | 42 | hidden_states = torch.utils.checkpoint.checkpoint( 43 | create_custom_forward(attn), 44 | hidden_states, 45 | encoder_hidden_states, 46 | image_only_indicator, 47 | False, 48 | **CKPT_KWARGS, 49 | )[0] 50 | else: 51 | hidden_states = resnet( 52 | hidden_states, 53 | temb, 54 | image_only_indicator=image_only_indicator, 55 | ) 56 | hidden_states = attn( 57 | hidden_states, 58 | encoder_hidden_states=encoder_hidden_states, 59 | image_only_indicator=image_only_indicator, 60 | return_dict=False, 61 | )[0] 62 | 63 | if dino_down_block_res_samples is not None: 64 | hidden_states += dino_down_block_res_samples.pop(0) 65 | 66 | output_states = output_states + (hidden_states,) 67 | 68 | if 
self.downsamplers is not None: 69 | for downsampler in self.downsamplers: 70 | hidden_states = downsampler(hidden_states) 71 | if dino_down_block_res_samples is not None: 72 | hidden_states += dino_down_block_res_samples.pop(0) 73 | 74 | output_states = output_states + (hidden_states,) 75 | 76 | return hidden_states, output_states 77 | @staticmethod 78 | def forward_down_block_dino( 79 | module, 80 | hidden_states: torch.Tensor, 81 | temb: Optional[torch.Tensor] = None, 82 | image_only_indicator: Optional[torch.Tensor] = None, 83 | dino_down_block_res_samples = None, 84 | ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: 85 | self = module 86 | output_states = () 87 | for resnet in self.resnets: 88 | if self.training and self.gradient_checkpointing: 89 | if is_torch_version(">=", "1.11.0"): 90 | hidden_states = torch.utils.checkpoint.checkpoint( 91 | create_custom_forward(resnet), 92 | hidden_states, 93 | temb, 94 | image_only_indicator, 95 | use_reentrant=False, 96 | ) 97 | else: 98 | hidden_states = torch.utils.checkpoint.checkpoint( 99 | create_custom_forward(resnet), 100 | hidden_states, 101 | temb, 102 | image_only_indicator, 103 | ) 104 | else: 105 | hidden_states = resnet( 106 | hidden_states, 107 | temb, 108 | image_only_indicator=image_only_indicator, 109 | ) 110 | if dino_down_block_res_samples is not None: 111 | hidden_states += dino_down_block_res_samples.pop(0) 112 | output_states = output_states + (hidden_states,) 113 | 114 | if self.downsamplers is not None: 115 | for downsampler in self.downsamplers: 116 | hidden_states = downsampler(hidden_states) 117 | if dino_down_block_res_samples is not None: 118 | hidden_states += dino_down_block_res_samples.pop(0) 119 | output_states = output_states + (hidden_states,) 120 | 121 | return hidden_states, output_states 122 | 123 | 124 | def forward( 125 | self, 126 | sample: torch.FloatTensor, 127 | timestep: Union[torch.Tensor, float, int], 128 | encoder_hidden_states: torch.Tensor, 129 | added_time_ids: torch.Tensor, 130 | return_dict: bool = True, 131 | image_controlnet_down_block_res_samples = None, 132 | image_controlnet_mid_block_res_sample = None, 133 | dino_down_block_res_samples = None, 134 | 135 | ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]: 136 | r""" 137 | The [`UNetSpatioTemporalConditionModel`] forward method. 138 | 139 | Args: 140 | sample (`torch.FloatTensor`): 141 | The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`. 142 | timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. 143 | encoder_hidden_states (`torch.FloatTensor`): 144 | The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`. 145 | added_time_ids: (`torch.FloatTensor`): 146 | The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal 147 | embeddings and added to the time embeddings. 148 | return_dict (`bool`, *optional*, defaults to `True`): 149 | Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain 150 | tuple. 151 | Returns: 152 | [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`: 153 | If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise 154 | a `tuple` is returned where the first element is the sample tensor. 
155 | """ 156 | if not hasattr(self, "custom_gradient_checkpointing"): 157 | self.custom_gradient_checkpointing = False 158 | 159 | # 1. time 160 | timesteps = timestep 161 | if not torch.is_tensor(timesteps): 162 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 163 | # This would be a good case for the `match` statement (Python 3.10+) 164 | is_mps = sample.device.type == "mps" 165 | if isinstance(timestep, float): 166 | dtype = torch.float32 if is_mps else torch.float64 167 | else: 168 | dtype = torch.int32 if is_mps else torch.int64 169 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 170 | elif len(timesteps.shape) == 0: 171 | timesteps = timesteps[None].to(sample.device) 172 | 173 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 174 | batch_size, num_frames = sample.shape[:2] 175 | if len(timesteps.shape) == 1: 176 | timesteps = timesteps.expand(batch_size) 177 | else: 178 | timesteps = timesteps.reshape(batch_size * num_frames) 179 | t_emb = self.time_proj(timesteps) # (B, C) 180 | 181 | # `Timesteps` does not contain any weights and will always return f32 tensors 182 | # but time_embedding might actually be running in fp16. so we need to cast here. 183 | # there might be better ways to encapsulate this. 184 | t_emb = t_emb.to(dtype=sample.dtype) 185 | 186 | emb = self.time_embedding(t_emb) # (B, C) 187 | 188 | time_embeds = self.add_time_proj(added_time_ids.flatten()) 189 | time_embeds = time_embeds.reshape((batch_size, -1)) 190 | time_embeds = time_embeds.to(emb.dtype) 191 | aug_emb = self.add_embedding(time_embeds) 192 | if emb.shape[0] == 1: 193 | emb = emb + aug_emb 194 | # Repeat the embeddings num_video_frames times 195 | # emb: [batch, channels] -> [batch * frames, channels] 196 | emb = emb.repeat_interleave(num_frames, dim=0) 197 | else: 198 | aug_emb = aug_emb.repeat_interleave(num_frames, dim=0) 199 | emb = emb + aug_emb 200 | 201 | # Flatten the batch and frames dimensions 202 | # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width] 203 | sample = sample.flatten(0, 1) 204 | 205 | # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels] 206 | # here, our encoder_hidden_states is [batch * frames, 1, channels] 207 | 208 | if not sample.shape[0] == encoder_hidden_states.shape[0]: 209 | encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0) 210 | # 2. 
pre-process 211 | sample = self.conv_in(sample) 212 | 213 | image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device) 214 | 215 | if dino_down_block_res_samples is not None: 216 | dino_down_block_res_samples = [x for x in dino_down_block_res_samples] 217 | sample += dino_down_block_res_samples.pop(0) 218 | 219 | down_block_res_samples = (sample,) 220 | for downsample_block in self.down_blocks: 221 | if dino_down_block_res_samples is None: 222 | if self.custom_gradient_checkpointing: 223 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 224 | sample, res_samples = torch.utils.checkpoint.checkpoint( 225 | create_custom_forward(downsample_block), 226 | sample, 227 | emb, 228 | encoder_hidden_states, 229 | image_only_indicator, 230 | **CKPT_KWARGS, 231 | ) 232 | else: 233 | sample, res_samples = torch.utils.checkpoint.checkpoint( 234 | create_custom_forward(downsample_block), 235 | sample, 236 | emb, 237 | image_only_indicator, 238 | **CKPT_KWARGS, 239 | ) 240 | else: 241 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 242 | sample, res_samples = downsample_block( 243 | hidden_states=sample, 244 | temb=emb, 245 | encoder_hidden_states=encoder_hidden_states, 246 | image_only_indicator=image_only_indicator, 247 | ) 248 | else: 249 | sample, res_samples = downsample_block( 250 | hidden_states=sample, 251 | temb=emb, 252 | image_only_indicator=image_only_indicator, 253 | ) 254 | else: 255 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 256 | sample, res_samples = self.forward_crossattn_down_block_dino( 257 | downsample_block, 258 | sample, 259 | emb, 260 | encoder_hidden_states, 261 | image_only_indicator, 262 | dino_down_block_res_samples, 263 | ) 264 | else: 265 | sample, res_samples = self.forward_down_block_dino( 266 | downsample_block, 267 | sample, 268 | emb, 269 | image_only_indicator, 270 | dino_down_block_res_samples, 271 | ) 272 | down_block_res_samples += res_samples 273 | 274 | if image_controlnet_down_block_res_samples is not None: 275 | new_down_block_res_samples = () 276 | 277 | for down_block_res_sample, image_controlnet_down_block_res_sample in zip( 278 | down_block_res_samples, image_controlnet_down_block_res_samples 279 | ): 280 | down_block_res_sample = (down_block_res_sample + image_controlnet_down_block_res_sample) / 2 281 | new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) 282 | 283 | down_block_res_samples = new_down_block_res_samples 284 | 285 | # 4. mid 286 | if self.custom_gradient_checkpointing: 287 | sample = torch.utils.checkpoint.checkpoint( 288 | create_custom_forward(self.mid_block), 289 | sample, 290 | emb, 291 | encoder_hidden_states, 292 | image_only_indicator, 293 | **CKPT_KWARGS, 294 | ) 295 | else: 296 | sample = self.mid_block( 297 | hidden_states=sample, 298 | temb=emb, 299 | encoder_hidden_states=encoder_hidden_states, 300 | image_only_indicator=image_only_indicator, 301 | ) 302 | 303 | if image_controlnet_mid_block_res_sample is not None: 304 | sample = (sample + image_controlnet_mid_block_res_sample) / 2 305 | 306 | # 5. 
up 307 | mid_up_block_out_samples = [sample, ] 308 | down_block_out_sampels = [] 309 | for i, upsample_block in enumerate(self.up_blocks): 310 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 311 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 312 | down_block_out_sampels.append(res_samples[-1]) 313 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 314 | if self.custom_gradient_checkpointing: 315 | sample = torch.utils.checkpoint.checkpoint( 316 | create_custom_forward(upsample_block), 317 | sample, 318 | res_samples, 319 | emb, 320 | encoder_hidden_states, 321 | image_only_indicator, 322 | **CKPT_KWARGS 323 | ) 324 | else: 325 | sample = upsample_block( 326 | hidden_states=sample, 327 | temb=emb, 328 | res_hidden_states_tuple=res_samples, 329 | encoder_hidden_states=encoder_hidden_states, 330 | image_only_indicator=image_only_indicator, 331 | ) 332 | else: 333 | if self.custom_gradient_checkpointing: 334 | sample = torch.utils.checkpoint.checkpoint( 335 | create_custom_forward(upsample_block), 336 | sample, 337 | res_samples, 338 | emb, 339 | image_only_indicator, 340 | **CKPT_KWARGS 341 | ) 342 | else: 343 | sample = upsample_block( 344 | hidden_states=sample, 345 | temb=emb, 346 | res_hidden_states_tuple=res_samples, 347 | image_only_indicator=image_only_indicator, 348 | ) 349 | mid_up_block_out_samples.append(sample) 350 | # 6. post-process 351 | sample = self.conv_norm_out(sample) 352 | sample = self.conv_act(sample) 353 | if self.custom_gradient_checkpointing: 354 | sample = torch.utils.checkpoint.checkpoint( 355 | create_custom_forward(self.conv_out), 356 | sample, 357 | **CKPT_KWARGS 358 | ) 359 | else: 360 | sample = self.conv_out(sample) 361 | 362 | # 7. 
Reshape back to original shape 363 | sample = sample.reshape(batch_size, num_frames, *sample.shape[1:]) 364 | 365 | if not return_dict: 366 | return (sample, down_block_out_sampels[::-1], mid_up_block_out_samples) 367 | 368 | return UNetSpatioTemporalConditionOutput(sample=sample) -------------------------------------------------------------------------------- /normalcrafter/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List 2 | import tempfile 3 | import numpy as np 4 | import PIL.Image 5 | import matplotlib.cm as cm 6 | import mediapy 7 | import torch 8 | from decord import VideoReader, cpu 9 | 10 | 11 | def read_video_frames(video_path, process_length, target_fps, max_res): 12 | print("==> processing video: ", video_path) 13 | vid = VideoReader(video_path, ctx=cpu(0)) 14 | print("==> original video shape: ", (len(vid), *vid.get_batch([0]).shape[1:])) 15 | original_height, original_width = vid.get_batch([0]).shape[1:3] 16 | 17 | if max(original_height, original_width) > max_res: 18 | scale = max_res / max(original_height, original_width) 19 | height = round(original_height * scale) 20 | width = round(original_width * scale) 21 | else: 22 | height = original_height 23 | width = original_width 24 | 25 | vid = VideoReader(video_path, ctx=cpu(0), width=width, height=height) 26 | 27 | fps = vid.get_avg_fps() if target_fps == -1 else target_fps 28 | stride = round(vid.get_avg_fps() / fps) 29 | stride = max(stride, 1) 30 | frames_idx = list(range(0, len(vid), stride)) 31 | print( 32 | f"==> downsampled shape: {len(frames_idx), *vid.get_batch([0]).shape[1:]}, with stride: {stride}" 33 | ) 34 | if process_length != -1 and process_length < len(frames_idx): 35 | frames_idx = frames_idx[:process_length] 36 | print( 37 | f"==> final processing shape: {len(frames_idx), *vid.get_batch([0]).shape[1:]}" 38 | ) 39 | frames = vid.get_batch(frames_idx).asnumpy().astype(np.uint8) 40 | frames = [PIL.Image.fromarray(x) for x in frames] 41 | 42 | return frames, fps 43 | 44 | def save_video( 45 | video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], 46 | output_video_path: str = None, 47 | fps: int = 10, 48 | crf: int = 18, 49 | ) -> str: 50 | if output_video_path is None: 51 | output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name 52 | 53 | if isinstance(video_frames[0], np.ndarray): 54 | video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames] 55 | 56 | elif isinstance(video_frames[0], PIL.Image.Image): 57 | video_frames = [np.array(frame) for frame in video_frames] 58 | mediapy.write_video(output_video_path, video_frames, fps=fps, crf=crf) 59 | return output_video_path 60 | 61 | def vis_sequence_normal(normals: np.ndarray): 62 | normals = normals.clip(-1., 1.) 
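# map unit normals from [-1, 1] to [0, 1] so they can be saved directly as an RGB video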
63 | normals = normals * 0.5 + 0.5 64 | return normals 65 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.30.1 2 | aiofiles==23.2.1 3 | annotated-types==0.7.0 4 | anyio==4.9.0 5 | asttokens==3.0.0 6 | certifi==2025.1.31 7 | charset-normalizer==3.4.1 8 | click==8.1.8 9 | cmake==4.0.0 10 | contourpy==1.3.1 11 | cycler==0.12.1 12 | decorator==5.2.1 13 | decord==0.6.0 14 | diffusers==0.29.1 15 | einops==0.8.1 16 | exceptiongroup==1.2.2 17 | executing==2.2.0 18 | fastapi==0.115.12 19 | ffmpy==0.5.0 20 | filelock==3.18.0 21 | fire==0.6.0 22 | fonttools==4.57.0 23 | fsspec==2025.3.2 24 | gradio==5.23.3 25 | gradio_client==1.8.0 26 | groovy==0.1.2 27 | h11==0.14.0 28 | httpcore==1.0.7 29 | httpx==0.28.1 30 | huggingface-hub==0.30.1 31 | idna==3.10 32 | importlib_metadata==8.6.1 33 | ipython==8.35.0 34 | jedi==0.19.2 35 | Jinja2==3.1.6 36 | kiwisolver==1.4.8 37 | lit==18.1.8 38 | markdown-it-py==3.0.0 39 | MarkupSafe==3.0.2 40 | matplotlib==3.8.4 41 | matplotlib-inline==0.1.7 42 | mdurl==0.1.2 43 | mediapy==1.2.0 44 | mpmath==1.3.0 45 | mypy-extensions==1.0.0 46 | networkx==3.4.2 47 | numpy==1.26.4 48 | nvidia-cublas-cu11==11.10.3.66 49 | nvidia-cuda-cupti-cu11==11.7.101 50 | nvidia-cuda-nvrtc-cu11==11.7.99 51 | nvidia-cuda-runtime-cu11==11.7.99 52 | nvidia-cudnn-cu11==8.5.0.96 53 | nvidia-cufft-cu11==10.9.0.58 54 | nvidia-curand-cu11==10.2.10.91 55 | nvidia-cusolver-cu11==11.4.0.1 56 | nvidia-cusparse-cu11==11.7.4.91 57 | nvidia-nccl-cu11==2.14.3 58 | nvidia-nvtx-cu11==11.7.91 59 | opencv-python==4.11.0.86 60 | OpenEXR==3.2.4 61 | orjson==3.10.16 62 | packaging==24.2 63 | pandas==2.2.3 64 | parso==0.8.4 65 | pexpect==4.9.0 66 | pillow==11.1.0 67 | prompt_toolkit==3.0.50 68 | psutil==5.9.8 69 | ptyprocess==0.7.0 70 | pure_eval==0.2.3 71 | pydantic==2.11.2 72 | pydantic_core==2.33.1 73 | pydub==0.25.1 74 | Pygments==2.19.1 75 | pyparsing==3.2.3 76 | pyre-extensions==0.0.29 77 | python-dateutil==2.9.0.post0 78 | python-multipart==0.0.20 79 | pytz==2025.2 80 | PyYAML==6.0.2 81 | regex==2024.11.6 82 | requests==2.32.3 83 | rich==14.0.0 84 | ruff==0.11.4 85 | safehttpx==0.1.6 86 | safetensors==0.5.3 87 | semantic-version==2.10.0 88 | shellingham==1.5.4 89 | six==1.17.0 90 | sniffio==1.3.1 91 | spaces==0.34.1 92 | stack-data==0.6.3 93 | starlette==0.46.1 94 | sympy==1.13.3 95 | termcolor==3.0.1 96 | tokenizers==0.19.1 97 | tomlkit==0.13.2 98 | torch==2.0.1 99 | tqdm==4.67.1 100 | traitlets==5.14.3 101 | transformers==4.41.2 102 | triton==2.0.0 103 | typer==0.15.2 104 | typing-inspect==0.9.0 105 | typing-inspection==0.4.0 106 | typing_extensions==4.13.1 107 | tzdata==2025.2 108 | urllib3==2.3.0 109 | uvicorn==0.34.0 110 | wcwidth==0.2.13 111 | websockets==15.0.1 112 | xformers==0.0.20 113 | zipp==3.21.0 114 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | import numpy as np 4 | import torch 5 | 6 | from diffusers.training_utils import set_seed 7 | from diffusers import AutoencoderKLTemporalDecoder 8 | from fire import Fire 9 | 10 | from normalcrafter.normal_crafter_ppl import NormalCrafterPipeline 11 | from normalcrafter.unet import DiffusersUNetSpatioTemporalConditionModelNormalCrafter 12 | from normalcrafter.utils import vis_sequence_normal, save_video, read_video_frames 13 | 14 | 15 | class 
DepthCrafterDemo: 16 | def __init__( 17 | self, 18 | unet_path: str, 19 | pre_train_path: str, 20 | cpu_offload: str = "model", 21 | ): 22 | unet = DiffusersUNetSpatioTemporalConditionModelNormalCrafter.from_pretrained( 23 | unet_path, 24 | subfolder="unet", 25 | low_cpu_mem_usage=True, 26 | ) 27 | vae = AutoencoderKLTemporalDecoder.from_pretrained( 28 | unet_path, subfolder="vae" 29 | ) 30 | weight_dtype = torch.float16 31 | vae.to(dtype=weight_dtype) 32 | unet.to(dtype=weight_dtype) 33 | # load weights of other components from the provided checkpoint 34 | self.pipe = NormalCrafterPipeline.from_pretrained( 35 | pre_train_path, 36 | unet=unet, 37 | vae=vae, 38 | torch_dtype=weight_dtype, 39 | variant="fp16", 40 | ) 41 | 42 | # for saving memory, we can offload the model to CPU, or even run the model sequentially to save more memory 43 | if cpu_offload is not None: 44 | if cpu_offload == "sequential": 45 | # This will slow, but save more memory 46 | self.pipe.enable_sequential_cpu_offload() 47 | elif cpu_offload == "model": 48 | self.pipe.enable_model_cpu_offload() 49 | else: 50 | raise ValueError(f"Unknown cpu offload option: {cpu_offload}") 51 | else: 52 | self.pipe.to("cuda") 53 | # enable attention slicing and xformers memory efficient attention 54 | try: 55 | self.pipe.enable_xformers_memory_efficient_attention() 56 | except Exception as e: 57 | print(e) 58 | print("Xformers is not enabled") 59 | # self.pipe.enable_attention_slicing() 60 | 61 | def infer( 62 | self, 63 | video: str, 64 | save_folder: str = "./demo_output", 65 | window_size: int = 14, 66 | time_step_size: int = 10, 67 | process_length: int = 195, 68 | decode_chunk_size: int = 7, 69 | max_res: int = 1024, 70 | dataset: str = "open", 71 | target_fps: int = 15, 72 | seed: int = 42, 73 | save_npz: bool = False, 74 | ): 75 | set_seed(seed) 76 | 77 | frames, target_fps = read_video_frames( 78 | video, 79 | process_length, 80 | target_fps, 81 | max_res, 82 | ) 83 | # inference the depth map using the DepthCrafter pipeline 84 | with torch.inference_mode(): 85 | res = self.pipe( 86 | frames, 87 | decode_chunk_size=decode_chunk_size, 88 | time_step_size=time_step_size, 89 | window_size=window_size, 90 | ).frames[0] 91 | # visualize the depth map and save the results 92 | vis = vis_sequence_normal(res) 93 | # save the depth map and visualization with the target FPS 94 | save_path = os.path.join( 95 | save_folder, os.path.splitext(os.path.basename(video))[0] 96 | ) 97 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 98 | save_video(vis, save_path + "_vis.mp4", fps=target_fps) 99 | save_video(frames, save_path + "_input.mp4", fps=target_fps) 100 | if save_npz: 101 | np.savez_compressed(save_path + ".npz", depth=res) 102 | 103 | return [ 104 | save_path + "_input.mp4", 105 | save_path + "_vis.mp4", 106 | ] 107 | 108 | def run( 109 | self, 110 | input_video, 111 | num_denoising_steps, 112 | guidance_scale, 113 | max_res=1024, 114 | process_length=195, 115 | ): 116 | res_path = self.infer( 117 | input_video, 118 | num_denoising_steps, 119 | guidance_scale, 120 | max_res=max_res, 121 | process_length=process_length, 122 | ) 123 | # clear the cache for the next video 124 | gc.collect() 125 | torch.cuda.empty_cache() 126 | return res_path[:2] 127 | 128 | 129 | def main( 130 | video_path: str, 131 | save_folder: str = "./demo_output", 132 | unet_path: str = "Yanrui95/NormalCrafter", 133 | pre_train_path: str = "stabilityai/stable-video-diffusion-img2vid-xt", 134 | process_length: int = -1, 135 | cpu_offload: str = "model", 136 | 
target_fps: int = -1,
137 |     seed: int = 42,
138 |     window_size: int = 14,
139 |     time_step_size: int = 10,
140 |     max_res: int = 1024,
141 |     dataset: str = "open",
142 |     save_npz: bool = False
143 | ):
144 |     depthcrafter_demo = DepthCrafterDemo(
145 |         unet_path=unet_path,
146 |         pre_train_path=pre_train_path,
147 |         cpu_offload=cpu_offload,
148 |     )
149 |     # process the videos; multiple paths may be given, separated by commas
150 |     video_paths = video_path.split(",")
151 |     for video in video_paths:
152 |         depthcrafter_demo.infer(
153 |             video,
154 |             save_folder=save_folder,
155 |             window_size=window_size,
156 |             process_length=process_length,
157 |             time_step_size=time_step_size,
158 |             max_res=max_res,
159 |             dataset=dataset,
160 |             target_fps=target_fps,
161 |             seed=seed,
162 |             save_npz=save_npz,
163 |         )
164 |         # clear the cache for the next video
165 |         gc.collect()
166 |         torch.cuda.empty_cache()
167 |
168 |
169 | if __name__ == "__main__":
170 |     # running configs
171 |     # the most important arguments for memory saving are `cpu_offload`, `max_res`, and `window_size`
172 |     # the main quality/speed trade-off comes from `max_res`, `window_size`, and `time_step_size`;
173 |     # the pipeline itself always runs a single deterministic denoising step without classifier-free guidance
174 |     Fire(main)
175 |
--------------------------------------------------------------------------------
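For readers who want to call the pipeline from their own code rather than through `run.py` or `app.py`, the following is a minimal sketch that mirrors the flow of those two scripts (model IDs, example paths, and default parameters are taken from the repository; the output filename is an arbitrary choice):

```python
import torch
from diffusers import AutoencoderKLTemporalDecoder

from normalcrafter.normal_crafter_ppl import NormalCrafterPipeline
from normalcrafter.unet import DiffusersUNetSpatioTemporalConditionModelNormalCrafter
from normalcrafter.utils import read_video_frames, vis_sequence_normal, save_video

weight_dtype = torch.float16

# load the fine-tuned UNet and VAE, then plug them into the SVD pipeline
unet = DiffusersUNetSpatioTemporalConditionModelNormalCrafter.from_pretrained(
    "Yanrui95/NormalCrafter", subfolder="unet", low_cpu_mem_usage=True
).to(dtype=weight_dtype)
vae = AutoencoderKLTemporalDecoder.from_pretrained(
    "Yanrui95/NormalCrafter", subfolder="vae"
).to(dtype=weight_dtype)
pipe = NormalCrafterPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt",
    unet=unet,
    vae=vae,
    torch_dtype=weight_dtype,
    variant="fp16",
)
pipe.enable_model_cpu_offload()  # or pipe.to("cuda") if memory allows

# read frames, run the sliding-window inference, and save the visualization
frames, fps = read_video_frames("examples/example_01.mp4", process_length=-1, target_fps=-1, max_res=1024)
with torch.inference_mode():
    normals = pipe(frames, decode_chunk_size=7, time_step_size=10, window_size=14).frames[0]
save_video(vis_sequence_normal(normals), "example_01_normal_vis.mp4", fps=fps)
```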