├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── app.py
├── examples
│   ├── example_01.mp4
│   ├── example_02.mp4
│   ├── example_03.mp4
│   ├── example_04.mp4
│   ├── example_05.mp4
│   └── example_06.mp4
├── normalcrafter
│   ├── __init__.py
│   ├── normal_crafter_ppl.py
│   ├── unet.py
│   └── utils.py
├── requirements.txt
└── run.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.7z filter=lfs diff=lfs merge=lfs -text
2 | *.arrow filter=lfs diff=lfs merge=lfs -text
3 | *.bin filter=lfs diff=lfs merge=lfs -text
4 | *.bz2 filter=lfs diff=lfs merge=lfs -text
5 | *.ckpt filter=lfs diff=lfs merge=lfs -text
6 | *.ftz filter=lfs diff=lfs merge=lfs -text
7 | *.gz filter=lfs diff=lfs merge=lfs -text
8 | *.h5 filter=lfs diff=lfs merge=lfs -text
9 | *.joblib filter=lfs diff=lfs merge=lfs -text
10 | *.lfs.* filter=lfs diff=lfs merge=lfs -text
11 | *.mlmodel filter=lfs diff=lfs merge=lfs -text
12 | *.model filter=lfs diff=lfs merge=lfs -text
13 | *.msgpack filter=lfs diff=lfs merge=lfs -text
14 | *.npy filter=lfs diff=lfs merge=lfs -text
15 | *.npz filter=lfs diff=lfs merge=lfs -text
16 | *.onnx filter=lfs diff=lfs merge=lfs -text
17 | *.ot filter=lfs diff=lfs merge=lfs -text
18 | *.parquet filter=lfs diff=lfs merge=lfs -text
19 | *.pb filter=lfs diff=lfs merge=lfs -text
20 | *.pickle filter=lfs diff=lfs merge=lfs -text
21 | *.pkl filter=lfs diff=lfs merge=lfs -text
22 | *.pt filter=lfs diff=lfs merge=lfs -text
23 | *.pth filter=lfs diff=lfs merge=lfs -text
24 | *.rar filter=lfs diff=lfs merge=lfs -text
25 | *.safetensors filter=lfs diff=lfs merge=lfs -text
26 | saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27 | *.tar.* filter=lfs diff=lfs merge=lfs -text
28 | *.tar filter=lfs diff=lfs merge=lfs -text
29 | *.tflite filter=lfs diff=lfs merge=lfs -text
30 | *.tgz filter=lfs diff=lfs merge=lfs -text
31 | *.wasm filter=lfs diff=lfs merge=lfs -text
32 | *.xz filter=lfs diff=lfs merge=lfs -text
33 | *.zip filter=lfs diff=lfs merge=lfs -text
34 | *.zst filter=lfs diff=lfs merge=lfs -text
35 | *tfevents* filter=lfs diff=lfs merge=lfs -text
36 | *.gif filter=lfs diff=lfs merge=lfs -text
37 | *.mp4 filter=lfs diff=lfs merge=lfs -text
38 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | ### Python template
2 | # Byte-compiled / optimized / DLL files
3 | __pycache__/
4 | *.py[cod]
5 | *$py.class
6 |
7 | # C extensions
8 | *.so
9 |
10 | #
11 | .gradio
12 | .github
13 | demo_output
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | share/python-wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | *.py,cover
55 | .hypothesis/
56 | .pytest_cache/
57 | cover/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | .pybuilder/
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # IPython
87 | profile_default/
88 | ipython_config.py
89 |
90 | # pyenv
91 | # For a library or package, you might want to ignore these files since the code is
92 | # intended to run in multiple environments; otherwise, check them in:
93 | # .python-version
94 |
95 | # pipenv
96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
99 | # install all needed dependencies.
100 | #Pipfile.lock
101 |
102 | # poetry
103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
104 | # This is especially recommended for binary packages to ensure reproducibility, and is more
105 | # commonly ignored for libraries.
106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
107 | #poetry.lock
108 |
109 | # pdm
110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
111 | #pdm.lock
112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
113 | # in version control.
114 | # https://pdm.fming.dev/#use-with-ide
115 | .pdm.toml
116 |
117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
118 | __pypackages__/
119 |
120 | # Celery stuff
121 | celerybeat-schedule
122 | celerybeat.pid
123 |
124 | # SageMath parsed files
125 | *.sage.py
126 |
127 | # Environments
128 | .env
129 | .venv
130 | env/
131 | venv/
132 | ENV/
133 | env.bak/
134 | venv.bak/
135 |
136 | # Spyder project settings
137 | .spyderproject
138 | .spyproject
139 |
140 | # Rope project settings
141 | .ropeproject
142 |
143 | # mkdocs documentation
144 | /site
145 |
146 | # mypy
147 | .mypy_cache/
148 | .dmypy.json
149 | dmypy.json
150 |
151 | # Pyre type checker
152 | .pyre/
153 |
154 | # pytype static type analyzer
155 | .pytype/
156 |
157 | # Cython debug symbols
158 | cython_debug/
159 |
160 | # PyCharm
161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
163 | # and can be added to the global gitignore or merged into this file. For a more nuclear
164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
165 | .idea/
166 |
167 | /logs
168 | /gin-config
169 | *.json
170 | /eval/*csv
171 | *__pycache__
172 | scripts/
173 | eval/
174 | *.DS_Store
175 | benchmark/datasets
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Yanrui Bin
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## ___***NormalCrafter: Learning Temporally Consistent Video Normal from Video Diffusion Priors***___
2 |
3 | _**[Yanrui Bin¹](https://scholar.google.com/citations?user=_9fN3mEAAAAJ&hl=zh-CN), [Wenbo Hu²*](https://wbhu.github.io),
4 | [Haoyuan Wang³](https://www.whyy.site/),
5 | [Xinya Chen⁴](https://xinyachen21.github.io/),
6 | [Bing Wang² †](https://bingcs.github.io/)**_
7 |
8 | ¹ Spatial Intelligence Group, The Hong Kong Polytechnic University
9 | ² ARC Lab, Tencent PCG
10 | ³ City University of Hong Kong
11 | ⁴ Huazhong University of Science and Technology
12 |
13 |
14 | ## 🔆 Notice
15 | We recommend using English in issues, so that developers from around the world can discuss, share experiences, and answer questions together.
16 |
17 | For business licensing and other related inquiries, don't hesitate to contact `binyanrui@gmail.com`.
18 |
19 | ## 🔆 Introduction
20 | 🤗 If you find NormalCrafter useful, **please help ⭐ this repo**; stars really matter to open-source projects. Thanks!
21 |
22 | 🔥 NormalCrafter can generate temporally consistent normal sequences
23 | with fine-grained details from open-world videos of arbitrary length.
24 |
25 | - `[24-04-01]` 🔥🔥🔥 **NormalCrafter** is now released. Have fun!
26 | ## 🚀 Quick Start
27 |
28 | ### 🤖 Gradio Demo
29 | - Online demo: [NormalCrafter](https://huggingface.co/spaces/Yanrui95/NormalCrafter)
30 | - Local demo:
31 | ```bash
32 | gradio app.py
33 | ```
34 |
35 | ### 🛠️ Installation
36 | 1. Clone this repo:
37 | ```bash
38 | git clone git@github.com:Binyr/NormalCrafter.git
39 | ```
40 | 2. Install dependencies (please refer to [requirements.txt](requirements.txt)):
41 | ```bash
42 | pip install -r requirements.txt
43 | ```
44 |
45 |
46 |
47 | ### 🤗 Model Zoo
48 | [NormalCrafter](https://huggingface.co/Yanrui95/NormalCrafter) is available on the Hugging Face Hub.
49 |
50 | ### 🏃‍♂️ Inference
51 | #### 1. High-resolution inference (requires a GPU with ~20GB of memory at 1024x576 resolution):
52 | ```bash
53 | python run.py --video-path examples/example_01.mp4
54 | ```
55 |
56 | #### 2. Low-resolution inference (requires a GPU with ~6GB of memory at 512x256 resolution):
57 | ```bash
58 | python run.py --video-path examples/example_01.mp4 --max-res 512
59 | ```
60 |
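61 | `run.py` also accepts several comma-separated paths, e.g. `--video-path examples/example_01.mp4,examples/example_02.mp4`, and processes them one after another.
62 |
63 | ### 🐍 Programmatic Usage
64 | If you want to call the pipeline from your own code, the demo scripts boil down to roughly the following sketch (assembled from `app.py` and `run.py`; the output file name is arbitrary):
65 | ```python
66 | import torch
67 | from diffusers import AutoencoderKLTemporalDecoder
68 |
69 | from normalcrafter.normal_crafter_ppl import NormalCrafterPipeline
70 | from normalcrafter.unet import DiffusersUNetSpatioTemporalConditionModelNormalCrafter
71 | from normalcrafter.utils import read_video_frames, vis_sequence_normal, save_video
72 |
73 | # load the fine-tuned UNet and VAE, then reuse the remaining SVD components
74 | unet = DiffusersUNetSpatioTemporalConditionModelNormalCrafter.from_pretrained(
75 |     "Yanrui95/NormalCrafter", subfolder="unet", low_cpu_mem_usage=True
76 | )
77 | vae = AutoencoderKLTemporalDecoder.from_pretrained("Yanrui95/NormalCrafter", subfolder="vae")
78 | unet.to(dtype=torch.float16)
79 | vae.to(dtype=torch.float16)
80 | pipe = NormalCrafterPipeline.from_pretrained(
81 |     "stabilityai/stable-video-diffusion-img2vid-xt",
82 |     unet=unet, vae=vae, torch_dtype=torch.float16, variant="fp16",
83 | ).to("cuda")
84 |
85 | # read frames, run sliding-window inference, and save the visualization
86 | frames, fps = read_video_frames("examples/example_01.mp4", process_length=-1, target_fps=-1, max_res=1024)
87 | with torch.inference_mode():
88 |     normals = pipe(frames, decode_chunk_size=7, time_step_size=10, window_size=14).frames[0]
89 | save_video(vis_sequence_normal(normals), "example_01_normals.mp4", fps=fps)
90 | ```
91 |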
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import os
3 |
4 | import numpy as np
5 | import spaces
6 | import gradio as gr
7 | import torch
8 | from diffusers.training_utils import set_seed
9 | from diffusers import AutoencoderKLTemporalDecoder
10 |
11 | from normalcrafter.normal_crafter_ppl import NormalCrafterPipeline
12 | from normalcrafter.unet import DiffusersUNetSpatioTemporalConditionModelNormalCrafter
13 |
14 | import uuid
15 | import random
16 | from huggingface_hub import hf_hub_download
17 |
18 | from normalcrafter.utils import read_video_frames, vis_sequence_normal, save_video
19 |
20 | examples = [
21 | ["examples/example_01.mp4", 1024, -1, -1],
22 | ["examples/example_02.mp4", 1024, -1, -1],
23 | ["examples/example_03.mp4", 1024, -1, -1],
24 | ["examples/example_04.mp4", 1024, -1, -1],
25 | # ["examples/example_05.mp4", 1024, -1, -1],
26 | # ["examples/example_06.mp4", 1024, -1, -1],
27 | ]
28 |
29 | pretrained_model_name_or_path = "Yanrui95/NormalCrafter"
30 | weight_dtype = torch.float16
31 | unet = DiffusersUNetSpatioTemporalConditionModelNormalCrafter.from_pretrained(
32 | pretrained_model_name_or_path,
33 | subfolder="unet",
34 | low_cpu_mem_usage=True,
35 | )
36 | vae = AutoencoderKLTemporalDecoder.from_pretrained(
37 | pretrained_model_name_or_path, subfolder="vae")
38 |
39 | vae.to(dtype=weight_dtype)
40 | unet.to(dtype=weight_dtype)
41 |
42 | pipe = NormalCrafterPipeline.from_pretrained(
43 | "stabilityai/stable-video-diffusion-img2vid-xt",
44 | unet=unet,
45 | vae=vae,
46 | torch_dtype=weight_dtype,
47 | variant="fp16",
48 | )
49 | pipe.to("cuda")
50 |
51 |
52 | @spaces.GPU(duration=120)
53 | def infer_depth(
54 | video: str,
55 | max_res: int = 1024,
56 | process_length: int = -1,
57 | target_fps: int = -1,
58 | #
59 | save_folder: str = "./demo_output",
60 | window_size: int = 14,
61 | time_step_size: int = 10,
62 | decode_chunk_size: int = 7,
63 | seed: int = 42,
64 | save_npz: bool = False,
65 | ):
66 | set_seed(seed)
67 | pipe.enable_xformers_memory_efficient_attention()
68 |
69 | frames, target_fps = read_video_frames(video, process_length, target_fps, max_res)
70 |
71 |     # infer the normal maps using the NormalCrafter pipeline
72 | with torch.inference_mode():
73 | res = pipe(
74 | frames,
75 | decode_chunk_size=decode_chunk_size,
76 | time_step_size=time_step_size,
77 | window_size=window_size,
78 | ).frames[0]
79 |
80 |     # visualize the normal maps and save the results
81 | vis = vis_sequence_normal(res)
82 |     # save the normal maps and visualization with the target FPS
83 | save_path = os.path.join(save_folder, os.path.splitext(os.path.basename(video))[0])
84 | print(f"==> saving results to {save_path}")
85 | os.makedirs(os.path.dirname(save_path), exist_ok=True)
86 | if save_npz:
87 | np.savez_compressed(save_path + ".npz", normal=res)
88 | save_video(vis, save_path + "_vis.mp4", fps=target_fps)
89 | save_video(frames, save_path + "_input.mp4", fps=target_fps)
90 |
91 | # clear the cache for the next video
92 | gc.collect()
93 | torch.cuda.empty_cache()
94 |
95 | return [
96 | save_path + "_input.mp4",
97 | save_path + "_vis.mp4",
98 |
99 | ]
100 |
101 |
102 | def construct_demo():
103 | with gr.Blocks(analytics_enabled=False) as depthcrafter_iface:
104 | gr.Markdown(
105 | """
106 |
112 | """
113 | )
114 |
115 | with gr.Row(equal_height=True):
116 | with gr.Column(scale=1):
117 | input_video = gr.Video(label="Input Video")
118 |
119 | # with gr.Tab(label="Output"):
120 | with gr.Column(scale=2):
121 | with gr.Row(equal_height=True):
122 | output_video_1 = gr.Video(
123 | label="Preprocessed Video",
124 | interactive=False,
125 | autoplay=True,
126 | loop=True,
127 | show_share_button=True,
128 | scale=5,
129 | )
130 | output_video_2 = gr.Video(
131 | label="Generated Normal Video",
132 | interactive=False,
133 | autoplay=True,
134 | loop=True,
135 | show_share_button=True,
136 | scale=5,
137 | )
138 |
139 | with gr.Row(equal_height=True):
140 | with gr.Column(scale=1):
141 | with gr.Row(equal_height=False):
142 | with gr.Accordion("Advanced Settings", open=False):
143 | max_res = gr.Slider(
144 | label="Max Resolution",
145 | minimum=512,
146 | maximum=1024,
147 | value=1024,
148 | step=64,
149 | )
150 | process_length = gr.Slider(
151 | label="Process Length",
152 | minimum=-1,
153 | maximum=280,
154 | value=60,
155 | step=1,
156 | )
157 | process_target_fps = gr.Slider(
158 | label="Target FPS",
159 | minimum=-1,
160 | maximum=30,
161 | value=15,
162 | step=1,
163 | )
164 | generate_btn = gr.Button("Generate")
165 | with gr.Column(scale=2):
166 | pass
167 |
168 | gr.Examples(
169 | examples=examples,
170 | inputs=[
171 | input_video,
172 | max_res,
173 | process_length,
174 | process_target_fps,
175 | ],
176 | outputs=[output_video_1, output_video_2],
177 | fn=infer_depth,
178 | cache_examples="lazy",
179 | )
180 | # gr.Markdown(
181 | # """
182 | # Note:
183 | # For time quota consideration, we set the default parameters to be more efficient here,
184 | # with a trade-off of shorter video length and slightly lower quality.
185 | # You may adjust the parameters according to our
186 | # [Github Repo]
187 | # for better results if you have enough time quota.
188 | #
189 | # """
190 | # )
191 |
192 | generate_btn.click(
193 | fn=infer_depth,
194 | inputs=[
195 | input_video,
196 | max_res,
197 | process_length,
198 | process_target_fps,
199 | ],
200 | outputs=[output_video_1, output_video_2],
201 | )
202 |
203 | return depthcrafter_iface
204 |
205 |
206 | if __name__ == "__main__":
207 | demo = construct_demo()
208 | demo.queue()
209 | # demo.launch(server_name="0.0.0.0", server_port=12345, debug=True, share=False)
210 | demo.launch(share=True)
211 |
--------------------------------------------------------------------------------
/examples/example_01.mp4:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:3eb7fefd157bd9b403cf0b524c7c4f3cb6d9f82b9d6a48eba2146412fc9e64a2
3 | size 5727137
4 |
--------------------------------------------------------------------------------
/examples/example_02.mp4:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:ea3c4e4c8cd9682d92c25170d8df333fead210118802fbe22198dde478dc5489
3 | size 3150525
4 |
--------------------------------------------------------------------------------
/examples/example_03.mp4:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:5d332877a98bb41ff86a639139a03e383e91880bca722bba7e2518878fca54f6
3 | size 3013435
4 |
--------------------------------------------------------------------------------
/examples/example_04.mp4:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:b2aa4962216adce71b1c47f395be435b23105df35f3892646e237b935ac1c74f
3 | size 3591374
4 |
--------------------------------------------------------------------------------
/examples/example_05.mp4:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:e8d2319060f9a1d3cfcb9de317e4a5b138657fd741c530ed3983f6565c2eda44
3 | size 3553683
4 |
--------------------------------------------------------------------------------
/examples/example_06.mp4:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:e3a2619b029129f34884c761cc278b6842620bfed96d4bb52c8aa07bc1d82a8b
3 | size 5596872
4 |
--------------------------------------------------------------------------------
/normalcrafter/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Binyr/NormalCrafter/497d163460404bdd57697e90bde95062f62a5e92/normalcrafter/__init__.py
--------------------------------------------------------------------------------
/normalcrafter/normal_crafter_ppl.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Callable, Dict, List, Optional, Union
3 |
4 | import numpy as np
5 | import PIL.Image
6 | import torch
7 | import torch.nn.functional as F
8 | import math
9 |
10 | from diffusers.utils import BaseOutput, logging
11 | from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
12 | from diffusers import DiffusionPipeline
13 | from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import StableVideoDiffusionPipelineOutput, StableVideoDiffusionPipeline
14 | from PIL import Image
15 | import cv2
16 |
17 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
18 |
19 | class NormalCrafterPipeline(StableVideoDiffusionPipeline):
20 |
21 | def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free_guidance, scale=1, image_size=None):
22 | dtype = next(self.image_encoder.parameters()).dtype
23 |
24 | if not isinstance(image, torch.Tensor):
25 | image = self.video_processor.pil_to_numpy(image) # (0, 255) -> (0, 1)
26 | image = self.video_processor.numpy_to_pt(image) # (n, h, w, c) -> (n, c, h, w)
27 |
28 | # We normalize the image before resizing to match with the original implementation.
29 | # Then we unnormalize it after resizing.
30 | pixel_values = image
31 | B, C, H, W = pixel_values.shape
32 | patches = [pixel_values]
33 | # patches = []
34 | for i in range(1, scale):
35 | num_patches_HW_this_level = i + 1
36 | patch_H = H // num_patches_HW_this_level + 1
37 | patch_W = W // num_patches_HW_this_level + 1
38 | for j in range(num_patches_HW_this_level):
39 | for k in range(num_patches_HW_this_level):
40 | patches.append(pixel_values[:, :, j*patch_H:(j+1)*patch_H, k*patch_W:(k+1)*patch_W])
41 |
42 | def encode_image(image):
43 | image = image * 2.0 - 1.0
44 | if image_size is not None:
45 | image = _resize_with_antialiasing(image, image_size)
46 | else:
47 | image = _resize_with_antialiasing(image, (224, 224))
48 | image = (image + 1.0) / 2.0
49 |
50 |             # Normalize the image for CLIP input
51 | image = self.feature_extractor(
52 | images=image,
53 | do_normalize=True,
54 | do_center_crop=False,
55 | do_resize=False,
56 | do_rescale=False,
57 | return_tensors="pt",
58 | ).pixel_values
59 |
60 | image = image.to(device=device, dtype=dtype)
61 | image_embeddings = self.image_encoder(image).image_embeds
62 | if len(image_embeddings.shape) < 3:
63 | image_embeddings = image_embeddings.unsqueeze(1)
64 | return image_embeddings
65 |
66 | image_embeddings = []
67 | for patch in patches:
68 | image_embeddings.append(encode_image(patch))
69 | image_embeddings = torch.cat(image_embeddings, dim=1)
70 |
71 | # duplicate image embeddings for each generation per prompt, using mps friendly method
72 | # import pdb
73 | # pdb.set_trace()
74 | bs_embed, seq_len, _ = image_embeddings.shape
75 | image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
76 | image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
77 |
78 | if do_classifier_free_guidance:
79 | negative_image_embeddings = torch.zeros_like(image_embeddings)
80 |
81 | # For classifier free guidance, we need to do two forward passes.
82 | # Here we concatenate the unconditional and text embeddings into a single batch
83 | # to avoid doing two forward passes
84 | image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])
85 |
86 | return image_embeddings
87 |
88 |     def encode_video_vae(self, images, chunk_size: int = 14):
89 | if isinstance(images, list):
90 | width, height = images[0].size
91 | else:
92 | height, width = images[0].shape[:2]
93 | needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
94 | if needs_upcasting:
95 | self.vae.to(dtype=torch.float32)
96 |
97 | device = self._execution_device
98 |         images = self.video_processor.preprocess_video(images, height=height, width=width).to(device, self.vae.dtype) # torch tensor in range (-1, 1) with shape (1, c, t, h, w)
99 | images = images.squeeze(0) # from (1, c, t, h, w) -> (c, t, h, w)
100 | images = images.permute(1,0,2,3) # c, t, h, w -> (t, c, h, w)
101 |
102 | video_latents = []
103 | # chunk_size = 14
104 | for i in range(0, images.shape[0], chunk_size):
105 | video_latents.append(self.vae.encode(images[i : i + chunk_size]).latent_dist.mode())
106 | image_latents = torch.cat(video_latents)
107 |
108 | # cast back to fp16 if needed
109 | if needs_upcasting:
110 | self.vae.to(dtype=torch.float16)
111 |
112 | return image_latents
113 |
114 | def pad_image(self, images, scale=64):
115 | def get_pad(newW, W):
116 | pad_W = (newW - W) // 2
117 | if W % 2 == 1:
118 | pad_Ws = [pad_W, pad_W + 1]
119 | else:
120 | pad_Ws = [pad_W, pad_W]
121 | return pad_Ws
122 |
123 | if type(images[0]) is np.ndarray:
124 | H, W = images[0].shape[:2]
125 | else:
126 | W, H = images[0].size
127 |
128 | if W % scale == 0 and H % scale == 0:
129 | return images, None
130 | newW = int(np.ceil(W / scale) * scale)
131 | newH = int(np.ceil(H / scale) * scale)
132 |
133 | pad_Ws = get_pad(newW, W)
134 | pad_Hs = get_pad(newH, H)
135 |
136 | new_images = []
137 | for image in images:
138 | if type(image) is np.ndarray:
139 | image = cv2.copyMakeBorder(image, *pad_Hs, *pad_Ws, cv2.BORDER_CONSTANT, value=(1.,1.,1.))
140 | new_images.append(image)
141 | else:
142 | image = np.array(image)
143 | image = cv2.copyMakeBorder(image, *pad_Hs, *pad_Ws, cv2.BORDER_CONSTANT, value=(255,255,255))
144 | new_images.append(Image.fromarray(image))
145 | return new_images, pad_Hs+pad_Ws
146 |
147 | def unpad_image(self, v, pad_HWs):
148 | t, b, l, r = pad_HWs
149 | if t > 0 or b > 0:
150 | v = v[:, :, t:-b]
151 | if l > 0 or r > 0:
152 | v = v[:, :, :, l:-r]
153 | return v
154 |
155 | @torch.no_grad()
156 | def __call__(
157 | self,
158 | images: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
159 | decode_chunk_size: Optional[int] = None,
160 | time_step_size: Optional[int] = 1,
161 | window_size: Optional[int] = 1,
162 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
163 | return_dict: bool = True
164 | ):
165 | images, pad_HWs = self.pad_image(images)
166 |
167 | # 0. Default height and width to unet
168 | width, height = images[0].size
169 | num_frames = len(images)
170 |
171 | # 1. Check inputs. Raise error if not correct
172 | self.check_inputs(images, height, width)
173 |
174 | # 2. Define call parameters
175 | batch_size = 1
176 | device = self._execution_device
177 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
178 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
179 | # corresponds to doing no classifier free guidance.
180 | self._guidance_scale = 1.0
181 | num_videos_per_prompt = 1
182 | do_classifier_free_guidance = False
183 | num_inference_steps = 1
184 | fps = 7
185 | motion_bucket_id = 127
186 | noise_aug_strength = 0.
187 | num_videos_per_prompt = 1
188 | output_type = "np"
189 | data_keys = ["normal"]
190 | use_linear_merge = True
191 | determineTrain = True
192 | encode_image_scale = 1
193 | encode_image_WH = None
194 |
195 | decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else 7
196 |
197 |         # 3. Encode input images using CLIP. (num_image * num_videos_per_prompt, 1, 1024)
198 | image_embeddings = self._encode_image(images, device, num_videos_per_prompt, do_classifier_free_guidance=do_classifier_free_guidance, scale=encode_image_scale, image_size=encode_image_WH)
199 | # 4. Encode input image using VAE
200 |         image_latents = self.encode_video_vae(images, chunk_size=decode_chunk_size).to(image_embeddings.dtype)
201 |
202 | # image_latents [num_frames, channels, height, width] ->[1, num_frames, channels, height, width]
203 | image_latents = image_latents.unsqueeze(0)
204 |
205 | # 5. Get Added Time IDs
206 | added_time_ids = self._get_add_time_ids(
207 | fps,
208 | motion_bucket_id,
209 | noise_aug_strength,
210 | image_embeddings.dtype,
211 | batch_size,
212 | num_videos_per_prompt,
213 | do_classifier_free_guidance,
214 | )
215 | added_time_ids = added_time_ids.to(device)
216 |
217 |         # get start and end frame indices for each window
218 | def get_ses(num_frames):
219 | ses = []
220 | for i in range(0, num_frames, time_step_size):
221 | ses.append([i, i+window_size])
222 | num_to_remain = 0
223 | for se in ses:
224 | if se[1] > num_frames:
225 | continue
226 | num_to_remain += 1
227 | ses = ses[:num_to_remain]
228 |
229 | if ses[-1][-1] < num_frames:
230 | ses.append([num_frames - window_size, num_frames])
231 | return ses
232 | ses = get_ses(num_frames)
233 |
234 | pred = None
235 | for i, se in enumerate(ses):
236 | window_num_frames = window_size
237 | window_image_embeddings = image_embeddings[se[0]:se[1]]
238 | window_image_latents = image_latents[:, se[0]:se[1]]
239 | window_added_time_ids = added_time_ids
240 | # import pdb
241 | # pdb.set_trace()
242 | if i == 0 or time_step_size == window_size:
243 | to_replace_latents = None
244 | else:
245 | last_se = ses[i-1]
246 | num_to_replace_latents = last_se[1] - se[0]
247 | to_replace_latents = pred[:, -num_to_replace_latents:]
248 |
249 | latents = self.generate(
250 | num_inference_steps,
251 | device,
252 | batch_size,
253 | num_videos_per_prompt,
254 | window_num_frames,
255 | height,
256 | width,
257 | window_image_embeddings,
258 | generator,
259 | determineTrain,
260 | to_replace_latents,
261 | do_classifier_free_guidance,
262 | window_image_latents,
263 | window_added_time_ids
264 | )
265 |
266 |             # merge the previous and current latents in the overlapping window
267 | if to_replace_latents is not None and use_linear_merge:
268 | num_img_condition = to_replace_latents.shape[1]
269 | weight = torch.linspace(1., 0., num_img_condition+2)[1:-1].to(device)
270 | weight = weight[None, :, None, None, None]
271 | latents[:, :num_img_condition] = to_replace_latents * weight + latents[:, :num_img_condition] * (1 - weight)
272 |
273 | if pred is None:
274 | pred = latents
275 | else:
276 | pred = torch.cat([pred[:, :se[0]], latents], dim=1)
277 |
278 | if not output_type == "latent":
279 | # cast back to fp16 if needed
280 | needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
281 | if needs_upcasting:
282 | self.vae.to(dtype=torch.float16)
283 | # latents has shape (1, num_frames, 12, h, w)
284 |
285 | def decode_latents(latents, num_frames, decode_chunk_size):
286 | frames = self.decode_latents(latents, num_frames, decode_chunk_size) # in range(-1, 1)
287 | frames = self.video_processor.postprocess_video(video=frames, output_type="np")
288 | frames = frames * 2 - 1 # from range(0, 1) -> range(-1, 1)
289 | return frames
290 |
291 | frames = decode_latents(pred, num_frames, decode_chunk_size)
292 | if pad_HWs is not None:
293 | frames = self.unpad_image(frames, pad_HWs)
294 | else:
295 | frames = pred
296 |
297 | self.maybe_free_model_hooks()
298 |
299 | if not return_dict:
300 | return frames
301 |
302 | return StableVideoDiffusionPipelineOutput(frames=frames)
303 |
304 |
305 | def generate(
306 | self,
307 | num_inference_steps,
308 | device,
309 | batch_size,
310 | num_videos_per_prompt,
311 | num_frames,
312 | height,
313 | width,
314 | image_embeddings,
315 | generator,
316 | determineTrain,
317 | to_replace_latents,
318 | do_classifier_free_guidance,
319 | image_latents,
320 | added_time_ids,
321 | latents=None,
322 | ):
323 | # 6. Prepare timesteps
324 | self.scheduler.set_timesteps(num_inference_steps, device=device)
325 | timesteps = self.scheduler.timesteps
326 |
327 | # 7. Prepare latent variables
328 | num_channels_latents = self.unet.config.in_channels
329 | latents = self.prepare_latents(
330 | batch_size * num_videos_per_prompt,
331 | num_frames,
332 | num_channels_latents,
333 | height,
334 | width,
335 | image_embeddings.dtype,
336 | device,
337 | generator,
338 | latents,
339 | )
340 | if determineTrain:
341 | latents[...] = 0.
342 |
343 | # 8. Denoising loop
344 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
345 | self._num_timesteps = len(timesteps)
346 | with self.progress_bar(total=num_inference_steps) as progress_bar:
347 | for i, t in enumerate(timesteps):
348 |                 # replace part of the latents with conditions. TODO: the t embedding should also be replaced
349 | if to_replace_latents is not None:
350 | num_img_condition = to_replace_latents.shape[1]
351 | if not determineTrain:
352 | _noise = randn_tensor(to_replace_latents.shape, generator=generator, device=device, dtype=image_embeddings.dtype)
353 | noisy_to_replace_latents = self.scheduler.add_noise(to_replace_latents, _noise, t.unsqueeze(0))
354 | latents[:, :num_img_condition] = noisy_to_replace_latents
355 | else:
356 | latents[:, :num_img_condition] = to_replace_latents
357 |
358 |
359 | # expand the latents if we are doing classifier free guidance
360 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
361 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
362 | timestep = t
363 |                 # Concatenate image_latents over the channels dimension
364 | latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
365 | # predict the noise residual
366 | noise_pred = self.unet(
367 | latent_model_input,
368 | timestep,
369 | encoder_hidden_states=image_embeddings,
370 | added_time_ids=added_time_ids,
371 | return_dict=False,
372 | )[0]
373 |
374 | # perform guidance
375 | if do_classifier_free_guidance:
376 | noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
377 | noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
378 |
379 | # compute the previous noisy sample x_t -> x_t-1
380 | scheduler_output = self.scheduler.step(noise_pred, t, latents)
381 | latents = scheduler_output.prev_sample
382 |
383 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
384 | progress_bar.update()
385 |
386 | return latents
387 | # resizing utils
388 | # TODO: clean up later
389 | def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
390 | h, w = input.shape[-2:]
391 | factors = (h / size[0], w / size[1])
392 |
393 | # First, we have to determine sigma
394 | # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
395 | sigmas = (
396 | max((factors[0] - 1.0) / 2.0, 0.001),
397 | max((factors[1] - 1.0) / 2.0, 0.001),
398 | )
399 |
400 | # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
401 | # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
402 | # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
403 | ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
404 |
405 | # Make sure it is odd
406 | if (ks[0] % 2) == 0:
407 | ks = ks[0] + 1, ks[1]
408 |
409 | if (ks[1] % 2) == 0:
410 | ks = ks[0], ks[1] + 1
411 |
412 | input = _gaussian_blur2d(input, ks, sigmas)
413 |
414 | output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
415 | return output
416 |
417 |
418 | def _compute_padding(kernel_size):
419 | """Compute padding tuple."""
420 |     # 4 or 6 ints: (padding_left, padding_right, padding_top, padding_bottom)
421 | # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
422 | if len(kernel_size) < 2:
423 | raise AssertionError(kernel_size)
424 | computed = [k - 1 for k in kernel_size]
425 |
426 | # for even kernels we need to do asymmetric padding :(
427 | out_padding = 2 * len(kernel_size) * [0]
428 |
429 | for i in range(len(kernel_size)):
430 | computed_tmp = computed[-(i + 1)]
431 |
432 | pad_front = computed_tmp // 2
433 | pad_rear = computed_tmp - pad_front
434 |
435 | out_padding[2 * i + 0] = pad_front
436 | out_padding[2 * i + 1] = pad_rear
437 |
438 | return out_padding
439 |
440 |
441 | def _filter2d(input, kernel):
442 | # prepare kernel
443 | b, c, h, w = input.shape
444 | tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
445 |
446 | tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
447 |
448 | height, width = tmp_kernel.shape[-2:]
449 |
450 | padding_shape: list[int] = _compute_padding([height, width])
451 | input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
452 |
453 | # kernel and input tensor reshape to align element-wise or batch-wise params
454 | tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
455 | input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
456 |
457 | # convolve the tensor with the kernel.
458 | output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
459 |
460 | out = output.view(b, c, h, w)
461 | return out
462 |
463 |
464 | def _gaussian(window_size: int, sigma):
465 | if isinstance(sigma, float):
466 | sigma = torch.tensor([[sigma]])
467 |
468 | batch_size = sigma.shape[0]
469 |
470 | x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
471 |
472 | if window_size % 2 == 0:
473 | x = x + 0.5
474 |
475 | gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
476 |
477 | return gauss / gauss.sum(-1, keepdim=True)
478 |
479 |
480 | def _gaussian_blur2d(input, kernel_size, sigma):
481 | if isinstance(sigma, tuple):
482 | sigma = torch.tensor([sigma], dtype=input.dtype)
483 | else:
484 | sigma = sigma.to(dtype=input.dtype)
485 |
486 | ky, kx = int(kernel_size[0]), int(kernel_size[1])
487 | bs = sigma.shape[0]
488 | kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
489 | kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
490 | out_x = _filter2d(input, kernel_x[..., None, :])
491 | out = _filter2d(out_x, kernel_y[..., None])
492 |
493 | return out
494 |
--------------------------------------------------------------------------------
/normalcrafter/unet.py:
--------------------------------------------------------------------------------
1 | from diffusers import UNetSpatioTemporalConditionModel
2 | from diffusers.models.unets.unet_spatio_temporal_condition import UNetSpatioTemporalConditionOutput
3 | from diffusers.utils import is_torch_version
4 | import torch
5 | from typing import Any, Dict, Optional, Tuple, Union
6 |
7 | def create_custom_forward(module, return_dict=None):
8 | def custom_forward(*inputs):
9 | if return_dict is not None:
10 | return module(*inputs, return_dict=return_dict)
11 | else:
12 | return module(*inputs)
13 |
14 | return custom_forward
15 | CKPT_KWARGS = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
16 |
17 |
18 | class DiffusersUNetSpatioTemporalConditionModelNormalCrafter(UNetSpatioTemporalConditionModel):
19 |
20 | @staticmethod
21 | def forward_crossattn_down_block_dino(
22 | module,
23 | hidden_states: torch.Tensor,
24 | temb: Optional[torch.Tensor] = None,
25 | encoder_hidden_states: Optional[torch.Tensor] = None,
26 | image_only_indicator: Optional[torch.Tensor] = None,
27 | dino_down_block_res_samples = None,
28 | ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
29 | output_states = ()
30 | self = module
31 | blocks = list(zip(self.resnets, self.attentions))
32 | for resnet, attn in blocks:
33 | if self.training and self.gradient_checkpointing: # TODO
34 | hidden_states = torch.utils.checkpoint.checkpoint(
35 | create_custom_forward(resnet),
36 | hidden_states,
37 | temb,
38 | image_only_indicator,
39 | **CKPT_KWARGS,
40 | )
41 |
42 | hidden_states = torch.utils.checkpoint.checkpoint(
43 | create_custom_forward(attn),
44 | hidden_states,
45 | encoder_hidden_states,
46 | image_only_indicator,
47 | False,
48 | **CKPT_KWARGS,
49 | )[0]
50 | else:
51 | hidden_states = resnet(
52 | hidden_states,
53 | temb,
54 | image_only_indicator=image_only_indicator,
55 | )
56 | hidden_states = attn(
57 | hidden_states,
58 | encoder_hidden_states=encoder_hidden_states,
59 | image_only_indicator=image_only_indicator,
60 | return_dict=False,
61 | )[0]
62 |
63 | if dino_down_block_res_samples is not None:
64 | hidden_states += dino_down_block_res_samples.pop(0)
65 |
66 | output_states = output_states + (hidden_states,)
67 |
68 | if self.downsamplers is not None:
69 | for downsampler in self.downsamplers:
70 | hidden_states = downsampler(hidden_states)
71 | if dino_down_block_res_samples is not None:
72 | hidden_states += dino_down_block_res_samples.pop(0)
73 |
74 | output_states = output_states + (hidden_states,)
75 |
76 | return hidden_states, output_states
77 | @staticmethod
78 | def forward_down_block_dino(
79 | module,
80 | hidden_states: torch.Tensor,
81 | temb: Optional[torch.Tensor] = None,
82 | image_only_indicator: Optional[torch.Tensor] = None,
83 | dino_down_block_res_samples = None,
84 | ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
85 | self = module
86 | output_states = ()
87 | for resnet in self.resnets:
88 | if self.training and self.gradient_checkpointing:
89 | if is_torch_version(">=", "1.11.0"):
90 | hidden_states = torch.utils.checkpoint.checkpoint(
91 | create_custom_forward(resnet),
92 | hidden_states,
93 | temb,
94 | image_only_indicator,
95 | use_reentrant=False,
96 | )
97 | else:
98 | hidden_states = torch.utils.checkpoint.checkpoint(
99 | create_custom_forward(resnet),
100 | hidden_states,
101 | temb,
102 | image_only_indicator,
103 | )
104 | else:
105 | hidden_states = resnet(
106 | hidden_states,
107 | temb,
108 | image_only_indicator=image_only_indicator,
109 | )
110 | if dino_down_block_res_samples is not None:
111 | hidden_states += dino_down_block_res_samples.pop(0)
112 | output_states = output_states + (hidden_states,)
113 |
114 | if self.downsamplers is not None:
115 | for downsampler in self.downsamplers:
116 | hidden_states = downsampler(hidden_states)
117 | if dino_down_block_res_samples is not None:
118 | hidden_states += dino_down_block_res_samples.pop(0)
119 | output_states = output_states + (hidden_states,)
120 |
121 | return hidden_states, output_states
122 |
123 |
124 | def forward(
125 | self,
126 | sample: torch.FloatTensor,
127 | timestep: Union[torch.Tensor, float, int],
128 | encoder_hidden_states: torch.Tensor,
129 | added_time_ids: torch.Tensor,
130 | return_dict: bool = True,
131 | image_controlnet_down_block_res_samples = None,
132 | image_controlnet_mid_block_res_sample = None,
133 | dino_down_block_res_samples = None,
134 |
135 | ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
136 | r"""
137 | The [`UNetSpatioTemporalConditionModel`] forward method.
138 |
139 | Args:
140 | sample (`torch.FloatTensor`):
141 | The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
142 | timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
143 | encoder_hidden_states (`torch.FloatTensor`):
144 | The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
145 | added_time_ids: (`torch.FloatTensor`):
146 | The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
147 | embeddings and added to the time embeddings.
148 | return_dict (`bool`, *optional*, defaults to `True`):
149 |                 Whether or not to return a [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain
150 | tuple.
151 | Returns:
152 |             [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
153 |                 If `return_dict` is True, an [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise
154 | a `tuple` is returned where the first element is the sample tensor.
155 | """
156 | if not hasattr(self, "custom_gradient_checkpointing"):
157 | self.custom_gradient_checkpointing = False
158 |
159 | # 1. time
160 | timesteps = timestep
161 | if not torch.is_tensor(timesteps):
162 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
163 | # This would be a good case for the `match` statement (Python 3.10+)
164 | is_mps = sample.device.type == "mps"
165 | if isinstance(timestep, float):
166 | dtype = torch.float32 if is_mps else torch.float64
167 | else:
168 | dtype = torch.int32 if is_mps else torch.int64
169 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
170 | elif len(timesteps.shape) == 0:
171 | timesteps = timesteps[None].to(sample.device)
172 |
173 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
174 | batch_size, num_frames = sample.shape[:2]
175 | if len(timesteps.shape) == 1:
176 | timesteps = timesteps.expand(batch_size)
177 | else:
178 | timesteps = timesteps.reshape(batch_size * num_frames)
179 | t_emb = self.time_proj(timesteps) # (B, C)
180 |
181 | # `Timesteps` does not contain any weights and will always return f32 tensors
182 | # but time_embedding might actually be running in fp16. so we need to cast here.
183 | # there might be better ways to encapsulate this.
184 | t_emb = t_emb.to(dtype=sample.dtype)
185 |
186 | emb = self.time_embedding(t_emb) # (B, C)
187 |
188 | time_embeds = self.add_time_proj(added_time_ids.flatten())
189 | time_embeds = time_embeds.reshape((batch_size, -1))
190 | time_embeds = time_embeds.to(emb.dtype)
191 | aug_emb = self.add_embedding(time_embeds)
192 | if emb.shape[0] == 1:
193 | emb = emb + aug_emb
194 | # Repeat the embeddings num_video_frames times
195 | # emb: [batch, channels] -> [batch * frames, channels]
196 | emb = emb.repeat_interleave(num_frames, dim=0)
197 | else:
198 | aug_emb = aug_emb.repeat_interleave(num_frames, dim=0)
199 | emb = emb + aug_emb
200 |
201 | # Flatten the batch and frames dimensions
202 | # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
203 | sample = sample.flatten(0, 1)
204 |
205 | # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
206 | # here, our encoder_hidden_states is [batch * frames, 1, channels]
207 |
208 | if not sample.shape[0] == encoder_hidden_states.shape[0]:
209 | encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
210 | # 2. pre-process
211 | sample = self.conv_in(sample)
212 |
213 | image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
214 |
215 | if dino_down_block_res_samples is not None:
216 | dino_down_block_res_samples = [x for x in dino_down_block_res_samples]
217 | sample += dino_down_block_res_samples.pop(0)
218 |
219 | down_block_res_samples = (sample,)
220 | for downsample_block in self.down_blocks:
221 | if dino_down_block_res_samples is None:
222 | if self.custom_gradient_checkpointing:
223 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
224 | sample, res_samples = torch.utils.checkpoint.checkpoint(
225 | create_custom_forward(downsample_block),
226 | sample,
227 | emb,
228 | encoder_hidden_states,
229 | image_only_indicator,
230 | **CKPT_KWARGS,
231 | )
232 | else:
233 | sample, res_samples = torch.utils.checkpoint.checkpoint(
234 | create_custom_forward(downsample_block),
235 | sample,
236 | emb,
237 | image_only_indicator,
238 | **CKPT_KWARGS,
239 | )
240 | else:
241 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
242 | sample, res_samples = downsample_block(
243 | hidden_states=sample,
244 | temb=emb,
245 | encoder_hidden_states=encoder_hidden_states,
246 | image_only_indicator=image_only_indicator,
247 | )
248 | else:
249 | sample, res_samples = downsample_block(
250 | hidden_states=sample,
251 | temb=emb,
252 | image_only_indicator=image_only_indicator,
253 | )
254 | else:
255 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
256 | sample, res_samples = self.forward_crossattn_down_block_dino(
257 | downsample_block,
258 | sample,
259 | emb,
260 | encoder_hidden_states,
261 | image_only_indicator,
262 | dino_down_block_res_samples,
263 | )
264 | else:
265 | sample, res_samples = self.forward_down_block_dino(
266 | downsample_block,
267 | sample,
268 | emb,
269 | image_only_indicator,
270 | dino_down_block_res_samples,
271 | )
272 | down_block_res_samples += res_samples
273 |
274 | if image_controlnet_down_block_res_samples is not None:
275 | new_down_block_res_samples = ()
276 |
277 | for down_block_res_sample, image_controlnet_down_block_res_sample in zip(
278 | down_block_res_samples, image_controlnet_down_block_res_samples
279 | ):
280 | down_block_res_sample = (down_block_res_sample + image_controlnet_down_block_res_sample) / 2
281 | new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
282 |
283 | down_block_res_samples = new_down_block_res_samples
284 |
285 | # 4. mid
286 | if self.custom_gradient_checkpointing:
287 | sample = torch.utils.checkpoint.checkpoint(
288 | create_custom_forward(self.mid_block),
289 | sample,
290 | emb,
291 | encoder_hidden_states,
292 | image_only_indicator,
293 | **CKPT_KWARGS,
294 | )
295 | else:
296 | sample = self.mid_block(
297 | hidden_states=sample,
298 | temb=emb,
299 | encoder_hidden_states=encoder_hidden_states,
300 | image_only_indicator=image_only_indicator,
301 | )
302 |
303 | if image_controlnet_mid_block_res_sample is not None:
304 | sample = (sample + image_controlnet_mid_block_res_sample) / 2
305 |
306 | # 5. up
307 | mid_up_block_out_samples = [sample, ]
308 |         down_block_out_samples = []
309 | for i, upsample_block in enumerate(self.up_blocks):
310 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
311 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
312 |             down_block_out_samples.append(res_samples[-1])
313 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
314 | if self.custom_gradient_checkpointing:
315 | sample = torch.utils.checkpoint.checkpoint(
316 | create_custom_forward(upsample_block),
317 | sample,
318 | res_samples,
319 | emb,
320 | encoder_hidden_states,
321 | image_only_indicator,
322 | **CKPT_KWARGS
323 | )
324 | else:
325 | sample = upsample_block(
326 | hidden_states=sample,
327 | temb=emb,
328 | res_hidden_states_tuple=res_samples,
329 | encoder_hidden_states=encoder_hidden_states,
330 | image_only_indicator=image_only_indicator,
331 | )
332 | else:
333 | if self.custom_gradient_checkpointing:
334 | sample = torch.utils.checkpoint.checkpoint(
335 | create_custom_forward(upsample_block),
336 | sample,
337 | res_samples,
338 | emb,
339 | image_only_indicator,
340 | **CKPT_KWARGS
341 | )
342 | else:
343 | sample = upsample_block(
344 | hidden_states=sample,
345 | temb=emb,
346 | res_hidden_states_tuple=res_samples,
347 | image_only_indicator=image_only_indicator,
348 | )
349 | mid_up_block_out_samples.append(sample)
350 | # 6. post-process
351 | sample = self.conv_norm_out(sample)
352 | sample = self.conv_act(sample)
353 | if self.custom_gradient_checkpointing:
354 | sample = torch.utils.checkpoint.checkpoint(
355 | create_custom_forward(self.conv_out),
356 | sample,
357 | **CKPT_KWARGS
358 | )
359 | else:
360 | sample = self.conv_out(sample)
361 |
362 | # 7. Reshape back to original shape
363 | sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
364 |
365 | if not return_dict:
366 |             return (sample, down_block_out_samples[::-1], mid_up_block_out_samples)
367 |
368 | return UNetSpatioTemporalConditionOutput(sample=sample)
--------------------------------------------------------------------------------
/normalcrafter/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Union, List
2 | import tempfile
3 | import numpy as np
4 | import PIL.Image
5 | import matplotlib.cm as cm
6 | import mediapy
7 | import torch
8 | from decord import VideoReader, cpu
9 |
10 |
11 | def read_video_frames(video_path, process_length, target_fps, max_res):
12 | print("==> processing video: ", video_path)
13 | vid = VideoReader(video_path, ctx=cpu(0))
14 | print("==> original video shape: ", (len(vid), *vid.get_batch([0]).shape[1:]))
15 | original_height, original_width = vid.get_batch([0]).shape[1:3]
16 |
17 | if max(original_height, original_width) > max_res:
18 | scale = max_res / max(original_height, original_width)
19 | height = round(original_height * scale)
20 | width = round(original_width * scale)
21 | else:
22 | height = original_height
23 | width = original_width
24 |
25 | vid = VideoReader(video_path, ctx=cpu(0), width=width, height=height)
26 |
27 | fps = vid.get_avg_fps() if target_fps == -1 else target_fps
28 | stride = round(vid.get_avg_fps() / fps)
29 | stride = max(stride, 1)
30 | frames_idx = list(range(0, len(vid), stride))
31 | print(
32 | f"==> downsampled shape: {len(frames_idx), *vid.get_batch([0]).shape[1:]}, with stride: {stride}"
33 | )
34 | if process_length != -1 and process_length < len(frames_idx):
35 | frames_idx = frames_idx[:process_length]
36 | print(
37 | f"==> final processing shape: {len(frames_idx), *vid.get_batch([0]).shape[1:]}"
38 | )
39 | frames = vid.get_batch(frames_idx).asnumpy().astype(np.uint8)
40 | frames = [PIL.Image.fromarray(x) for x in frames]
41 |
42 | return frames, fps
43 |
44 | def save_video(
45 | video_frames: Union[List[np.ndarray], List[PIL.Image.Image]],
46 | output_video_path: str = None,
47 | fps: int = 10,
48 | crf: int = 18,
49 | ) -> str:
50 | if output_video_path is None:
51 | output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
52 |
53 | if isinstance(video_frames[0], np.ndarray):
54 | video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames]
55 |
56 | elif isinstance(video_frames[0], PIL.Image.Image):
57 | video_frames = [np.array(frame) for frame in video_frames]
58 | mediapy.write_video(output_video_path, video_frames, fps=fps, crf=crf)
59 | return output_video_path
60 |
61 | def vis_sequence_normal(normals: np.ndarray):
62 | normals = normals.clip(-1., 1.)
63 | normals = normals * 0.5 + 0.5
64 | return normals
65 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.30.1
2 | aiofiles==23.2.1
3 | annotated-types==0.7.0
4 | anyio==4.9.0
5 | asttokens==3.0.0
6 | certifi==2025.1.31
7 | charset-normalizer==3.4.1
8 | click==8.1.8
9 | cmake==4.0.0
10 | contourpy==1.3.1
11 | cycler==0.12.1
12 | decorator==5.2.1
13 | decord==0.6.0
14 | diffusers==0.29.1
15 | einops==0.8.1
16 | exceptiongroup==1.2.2
17 | executing==2.2.0
18 | fastapi==0.115.12
19 | ffmpy==0.5.0
20 | filelock==3.18.0
21 | fire==0.6.0
22 | fonttools==4.57.0
23 | fsspec==2025.3.2
24 | gradio==5.23.3
25 | gradio_client==1.8.0
26 | groovy==0.1.2
27 | h11==0.14.0
28 | httpcore==1.0.7
29 | httpx==0.28.1
30 | huggingface-hub==0.30.1
31 | idna==3.10
32 | importlib_metadata==8.6.1
33 | ipython==8.35.0
34 | jedi==0.19.2
35 | Jinja2==3.1.6
36 | kiwisolver==1.4.8
37 | lit==18.1.8
38 | markdown-it-py==3.0.0
39 | MarkupSafe==3.0.2
40 | matplotlib==3.8.4
41 | matplotlib-inline==0.1.7
42 | mdurl==0.1.2
43 | mediapy==1.2.0
44 | mpmath==1.3.0
45 | mypy-extensions==1.0.0
46 | networkx==3.4.2
47 | numpy==1.26.4
48 | nvidia-cublas-cu11==11.10.3.66
49 | nvidia-cuda-cupti-cu11==11.7.101
50 | nvidia-cuda-nvrtc-cu11==11.7.99
51 | nvidia-cuda-runtime-cu11==11.7.99
52 | nvidia-cudnn-cu11==8.5.0.96
53 | nvidia-cufft-cu11==10.9.0.58
54 | nvidia-curand-cu11==10.2.10.91
55 | nvidia-cusolver-cu11==11.4.0.1
56 | nvidia-cusparse-cu11==11.7.4.91
57 | nvidia-nccl-cu11==2.14.3
58 | nvidia-nvtx-cu11==11.7.91
59 | opencv-python==4.11.0.86
60 | OpenEXR==3.2.4
61 | orjson==3.10.16
62 | packaging==24.2
63 | pandas==2.2.3
64 | parso==0.8.4
65 | pexpect==4.9.0
66 | pillow==11.1.0
67 | prompt_toolkit==3.0.50
68 | psutil==5.9.8
69 | ptyprocess==0.7.0
70 | pure_eval==0.2.3
71 | pydantic==2.11.2
72 | pydantic_core==2.33.1
73 | pydub==0.25.1
74 | Pygments==2.19.1
75 | pyparsing==3.2.3
76 | pyre-extensions==0.0.29
77 | python-dateutil==2.9.0.post0
78 | python-multipart==0.0.20
79 | pytz==2025.2
80 | PyYAML==6.0.2
81 | regex==2024.11.6
82 | requests==2.32.3
83 | rich==14.0.0
84 | ruff==0.11.4
85 | safehttpx==0.1.6
86 | safetensors==0.5.3
87 | semantic-version==2.10.0
88 | shellingham==1.5.4
89 | six==1.17.0
90 | sniffio==1.3.1
91 | spaces==0.34.1
92 | stack-data==0.6.3
93 | starlette==0.46.1
94 | sympy==1.13.3
95 | termcolor==3.0.1
96 | tokenizers==0.19.1
97 | tomlkit==0.13.2
98 | torch==2.0.1
99 | tqdm==4.67.1
100 | traitlets==5.14.3
101 | transformers==4.41.2
102 | triton==2.0.0
103 | typer==0.15.2
104 | typing-inspect==0.9.0
105 | typing-inspection==0.4.0
106 | typing_extensions==4.13.1
107 | tzdata==2025.2
108 | urllib3==2.3.0
109 | uvicorn==0.34.0
110 | wcwidth==0.2.13
111 | websockets==15.0.1
112 | xformers==0.0.20
113 | zipp==3.21.0
114 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import os
3 | import numpy as np
4 | import torch
5 |
6 | from diffusers.training_utils import set_seed
7 | from diffusers import AutoencoderKLTemporalDecoder
8 | from fire import Fire
9 |
10 | from normalcrafter.normal_crafter_ppl import NormalCrafterPipeline
11 | from normalcrafter.unet import DiffusersUNetSpatioTemporalConditionModelNormalCrafter
12 | from normalcrafter.utils import vis_sequence_normal, save_video, read_video_frames
13 |
14 |
15 | class DepthCrafterDemo:
16 | def __init__(
17 | self,
18 | unet_path: str,
19 | pre_train_path: str,
20 | cpu_offload: str = "model",
21 | ):
22 | unet = DiffusersUNetSpatioTemporalConditionModelNormalCrafter.from_pretrained(
23 | unet_path,
24 | subfolder="unet",
25 | low_cpu_mem_usage=True,
26 | )
27 | vae = AutoencoderKLTemporalDecoder.from_pretrained(
28 | unet_path, subfolder="vae"
29 | )
30 | weight_dtype = torch.float16
31 | vae.to(dtype=weight_dtype)
32 | unet.to(dtype=weight_dtype)
33 | # load weights of other components from the provided checkpoint
34 | self.pipe = NormalCrafterPipeline.from_pretrained(
35 | pre_train_path,
36 | unet=unet,
37 | vae=vae,
38 | torch_dtype=weight_dtype,
39 | variant="fp16",
40 | )
41 |
42 | # for saving memory, we can offload the model to CPU, or even run the model sequentially to save more memory
43 | if cpu_offload is not None:
44 | if cpu_offload == "sequential":
45 |                 # This is slower, but saves more memory
46 | self.pipe.enable_sequential_cpu_offload()
47 | elif cpu_offload == "model":
48 | self.pipe.enable_model_cpu_offload()
49 | else:
50 | raise ValueError(f"Unknown cpu offload option: {cpu_offload}")
51 | else:
52 | self.pipe.to("cuda")
53 |         # try to enable xformers memory efficient attention (attention slicing is left disabled)
54 | try:
55 | self.pipe.enable_xformers_memory_efficient_attention()
56 | except Exception as e:
57 | print(e)
58 | print("Xformers is not enabled")
59 | # self.pipe.enable_attention_slicing()
60 |
61 | def infer(
62 | self,
63 | video: str,
64 | save_folder: str = "./demo_output",
65 | window_size: int = 14,
66 | time_step_size: int = 10,
67 | process_length: int = 195,
68 | decode_chunk_size: int = 7,
69 | max_res: int = 1024,
70 | dataset: str = "open",
71 | target_fps: int = 15,
72 | seed: int = 42,
73 | save_npz: bool = False,
74 | ):
75 | set_seed(seed)
76 |
77 | frames, target_fps = read_video_frames(
78 | video,
79 | process_length,
80 | target_fps,
81 | max_res,
82 | )
83 |         # infer the normal maps using the NormalCrafter pipeline
84 | with torch.inference_mode():
85 | res = self.pipe(
86 | frames,
87 | decode_chunk_size=decode_chunk_size,
88 | time_step_size=time_step_size,
89 | window_size=window_size,
90 | ).frames[0]
91 |         # visualize the normal maps and save the results
92 |         vis = vis_sequence_normal(res)
93 |         # save the normal maps and visualization with the target FPS
94 | save_path = os.path.join(
95 | save_folder, os.path.splitext(os.path.basename(video))[0]
96 | )
97 | os.makedirs(os.path.dirname(save_path), exist_ok=True)
98 | save_video(vis, save_path + "_vis.mp4", fps=target_fps)
99 | save_video(frames, save_path + "_input.mp4", fps=target_fps)
100 | if save_npz:
101 |             np.savez_compressed(save_path + ".npz", normal=res)
102 |
103 | return [
104 | save_path + "_input.mp4",
105 | save_path + "_vis.mp4",
106 | ]
107 |
108 | def run(
109 | self,
110 | input_video,
111 | num_denoising_steps,
112 | guidance_scale,
113 | max_res=1024,
114 | process_length=195,
115 | ):
116 |         # num_denoising_steps and guidance_scale are not used by infer();
117 |         # they are kept in the signature for interface compatibility
118 |         res_path = self.infer(
119 |             input_video,
120 |             max_res=max_res,
121 |             process_length=process_length,
122 |         )
123 | # clear the cache for the next video
124 | gc.collect()
125 | torch.cuda.empty_cache()
126 | return res_path[:2]
127 |
128 |
129 | def main(
130 | video_path: str,
131 | save_folder: str = "./demo_output",
132 | unet_path: str = "Yanrui95/NormalCrafter",
133 | pre_train_path: str = "stabilityai/stable-video-diffusion-img2vid-xt",
134 | process_length: int = -1,
135 | cpu_offload: str = "model",
136 | target_fps: int = -1,
137 | seed: int = 42,
138 | window_size: int = 14,
139 | time_step_size: int = 10,
140 | max_res: int = 1024,
141 | dataset: str = "open",
142 | save_npz: bool = False
143 | ):
144 | depthcrafter_demo = DepthCrafterDemo(
145 | unet_path=unet_path,
146 | pre_train_path=pre_train_path,
147 | cpu_offload=cpu_offload,
148 | )
149 | # process the videos, the video paths are separated by comma
150 | video_paths = video_path.split(",")
151 | for video in video_paths:
152 | depthcrafter_demo.infer(
153 | video,
154 | save_folder=save_folder,
155 | window_size=window_size,
156 | process_length=process_length,
157 | time_step_size=time_step_size,
158 | max_res=max_res,
159 | dataset=dataset,
160 | target_fps=target_fps,
161 | seed=seed,
162 | save_npz=save_npz,
163 | )
164 | # clear the cache for the next video
165 | gc.collect()
166 | torch.cuda.empty_cache()
167 |
168 |
169 | if __name__ == "__main__":
170 |     # running configs
171 |     # the most important arguments for memory saving are `cpu_offload`, `max_res`, and `window_size`
172 |     # the most important arguments for the trade-off between quality and speed are
173 |     # `window_size`, `time_step_size`, and `max_res`
174 |     Fire(main)
175 |
--------------------------------------------------------------------------------