26 |
27 | This example, based on this [MJPEG server](https://github.com/radames/Real-Time-Latent-Consistency-Model/), runs image-to-image generation on a live webcam feed or screen capture in your web browser.
28 |
29 | ## Usage
30 |
31 | ### 1. Prepare Dependencies
32 |
33 | You need Node.js 18+ and Python 3.10 to run this example. Please make sure you've installed all dependencies according to the [installation instructions](../README.md#installation).
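
If you want to double-check the toolchain first, here is a quick version check (not part of the original instructions, just a sanity check):

```bash
node -v            # should report v18 or newer
python --version   # should report Python 3.10.x
```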
34 |
35 | ```bash
36 | cd frontend
37 | npm i
38 | npm run build
39 | cd ..
40 | pip install -r requirements.txt
41 | ```
42 |
43 | If you have trouble installing `npm`, you can install it via `conda` instead:
44 |
45 | ```bash
46 | conda install -c conda-forge nodejs
47 | ```
48 |
49 | ### 2. Run Demo
50 |
51 | If you run the demo with the default [settings](./demo_cfg.yaml), you first need to download the model for the `felted` style:
52 |
53 | ```bash
54 | bash ../scripts/download_model.sh felted
55 | ```
56 |
57 | Then, you can run the demo with the following command, and open `http://127.0.0.1:7860` in your browser:
58 |
59 | ```bash
60 | # with TensorRT acceleration; be patient on the first run, it may take more than 20 minutes
61 | python main.py --port 7860 --host 127.0.0.1 --acceleration tensorrt
62 | # if you don't have TensorRT, you can run it with `none` acceleration
63 | python main.py --port 7860 --host 127.0.0.1 --acceleration none
64 | ```
65 |
66 | If you want to run this demo on a remote server, set the host to `0.0.0.0`, e.g.:
67 |
68 | ```bash
69 | python main.py --port 7860 --host 0.0.0.0 --acceleration tensorrt
70 | ```
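
Alternatively, if you prefer not to expose the demo publicly, you can keep `--host 127.0.0.1` on the remote machine and forward the port over SSH from your local machine (a generic sketch; replace `user@remote` with your own server):

```bash
# forward local port 7860 to the demo running on the remote host
ssh -N -L 7860:127.0.0.1:7860 user@remote
# then open http://127.0.0.1:7860 in your local browser
```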
71 |
--------------------------------------------------------------------------------
/demo/frontend/src/lib/components/PipelineOptions.svelte:
--------------------------------------------------------------------------------
1 |
19 |
20 |
113 | There are {currentQueueSize}
114 | user(s) sharing the same GPU, affecting real-time performance. Maximum queue size is {maxQueueSize}.
115 | Duplicate and run it on your own GPU.
120 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
36 |
37 | * **Uni-directional** Temporal Attention with **Warmup** Mechanism
38 | * **Multitimestep KV-Cache** for Temporal Attention during Inference
39 | * **Depth Prior** for Better Structure Consistency
40 | * Compatible with **DreamBooth and LoRA** for Various Styles
41 | * **TensorRT** Supported
42 |
43 | The speed evaluation is conducted on **Ubuntu 20.04.6 LTS** with **PyTorch 2.2.2**, an **RTX 4090 GPU**, and an **Intel(R) Xeon(R) Platinum 8352V CPU**. The number of denoising steps is set to 2.
44 |
45 | | Resolution | TensorRT | FPS |
46 | | :--------: | :------: | :-------: |
47 | | 512 x 512 | **On** | **16.43** |
48 | | 512 x 512 | Off | 6.91 |
49 | | 768 x 512 | **On** | **12.15** |
50 | | 768 x 512 | Off | 6.29 |
51 |
52 | ## Installation
53 |
54 | ### Step0: Clone This Repository and Submodules
55 |
56 | ```bash
57 | git clone https://github.com/open-mmlab/Live2Diff.git
58 | # or via ssh
59 | git clone git@github.com:open-mmlab/Live2Diff.git
60 |
61 | cd Live2Diff
62 | git submodule update --init --recursive
63 | ```
64 |
65 | ### Step1: Make Environment
66 |
67 | Create virtual environment via conda:
68 |
69 | ```bash
70 | conda create -n live2diff python=3.10
71 | conda activate live2diff
72 | ```
73 |
74 | ### Step2: Install PyTorch and xformers
75 |
76 | Select the appropriate version for your system.
77 |
78 | ```bash
79 | # CUDA 11.8
80 | pip install torch torchvision xformers --index-url https://download.pytorch.org/whl/cu118
81 | # CUDA 12.1
82 | pip install torch torchvision xformers --index-url https://download.pytorch.org/whl/cu121
83 | ```
84 |
85 | Please refer to https://pytorch.org/ for more details.
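
After installing, you can verify that PyTorch was built with CUDA support and that xformers is importable (an optional sanity check, not part of the original steps):

```bash
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
python -c "import xformers; print(xformers.__version__)"
```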
86 |
87 | ### Step3: Install Project
88 |
89 | If you want to use TensorRT acceleration (recommended), install the project with the following command:
90 |
91 | ```bash
92 | # for cuda 11.x
93 | pip install ."[tensorrt_cu11]"
94 | # for cuda 12.x
95 | pip install ."[tensorrt_cu12]"
96 | ```
97 |
98 | Otherwise, you can install it via
99 |
100 | ```bash
101 | pip install .
102 | ```
103 |
104 | If you want to install it in development mode (a.k.a. an "editable install"), add the `-e` option.
105 |
106 | ```bash
107 | # for cuda 11.x
108 | pip install -e ."[tensorrt_cu11]"
109 | # for cuda 12.x
110 | pip install -e ."[tensorrt_cu12]"
111 | # or
112 | pip install -e .
113 | ```
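
To confirm the installation succeeded, you can try a quick import (assuming the Python package is named `live2diff`, matching the source directory of this repository):

```bash
python -c "import live2diff; print('live2diff imported successfully')"
```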
114 |
115 | ### Step4: Download Checkpoints and Demo Data
116 |
117 | 1. Download StableDiffusion-v1-5
118 |
119 | ```bash
120 | huggingface-cli download runwayml/stable-diffusion-v1-5 --local-dir ./models/Model/stable-diffusion-v1-5
121 | ```
122 |
123 | 2. Download the checkpoint from [HuggingFace](https://huggingface.co/Leoxing/Live2Diff) and put it under the `models` folder (an example download command is sketched after the directory layout below).
124 |
125 | 3. Download the depth detector from MiDaS's official [release](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt) and put it under the `models` folder.
126 |
127 | 4. Apply for a download token from [civitAI](https://education.civitai.com/civitais-guide-to-downloading-via-api/) and then download the DreamBooth checkpoints and LoRAs via the script:
128 |
129 | ```bash
130 | # download all DreamBooth/LoRA checkpoints
131 | bash scripts/download.sh all YOUR_TOKEN
132 | # or download the one you want to use
133 | bash scripts/download.sh disney YOUR_TOKEN
134 | ```
135 |
136 | 5. Download demo data from [OneDrive](https://pjlab-my.sharepoint.cn/:f:/g/personal/xingzhening_pjlab_org_cn/EpefezlxFXNBk93RDttYLMUBP2bofb6AZDfyRIkGapmIrQ?e=A6h2Eb).
137 |
138 | Then the directory structure of the `models` and `data` folders should look like this:
139 |
140 | ```bash
141 | ./
142 | |-- models
143 | | |-- LoRA
144 | | | |-- MoXinV1.safetensors
145 | | | `-- ...
146 | | |-- Model
147 | | | |-- 3Guofeng3_v34.safetensors
148 | | | |-- ...
149 | | | `-- stable-diffusion-v1-5
150 | | |-- live2diff.ckpt
151 | | `-- dpt_hybrid_384.pt
152 | `--data
153 | |-- 1.mp4
154 | |-- 2.mp4
155 | |-- 3.mp4
156 | `-- 4.mp4
157 | ```
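
For steps 2 and 3 above, the downloads can also be scripted. A minimal sketch, assuming `huggingface-cli` and `wget` are available and that the checkpoint file is named `live2diff.ckpt` as in the layout above:

```bash
# step 2: Live2Diff checkpoint from HuggingFace
huggingface-cli download Leoxing/Live2Diff live2diff.ckpt --local-dir ./models
# step 3: MiDaS depth detector
wget https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt -P ./models
```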
158 |
159 | ### Note
160 |
161 | The above installation steps (e.g. the [download script](#step4-download-checkpoints-and-demo-data)) target Linux and are not well tested on Windows. If you run into any difficulties, please feel free to open an issue 🤗.
162 |
163 | ## Quick Start
164 |
165 | You can try the examples under the [`data`](./data) directory. For example:
166 | ```bash
167 | # with TensorRT acceleration; be patient on the first run, it may take more than 20 minutes
168 | python test.py ./data/1.mp4 ./configs/disneyPixar.yaml --max-frames -1 --prompt "1man is talking" --output work_dirs/1-disneyPixar.mp4 --height 512 --width 512 --acceleration tensorrt
169 |
170 | # without TensorRT acceleration
171 | python test.py ./data/2.mp4 ./configs/disneyPixar.yaml --max-frames -1 --prompt "1man is talking" --output work_dirs/2-disneyPixar.mp4 --height 512 --width 512 --acceleration none
172 | ```
173 |
174 | You can adjust the denoising strength via `--num-inference-steps`, `--strength`, and `--t-index-list`. Please refer to `test.py` for more details.
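
For instance, a lighter stylization could be run as follows; the flag values here are only illustrative (assuming an integer step count and a float strength in (0, 1]), not tuned recommendations:

```bash
python test.py ./data/1.mp4 ./configs/disneyPixar.yaml --max-frames -1 --prompt "1man is talking" --output work_dirs/1-disneyPixar-weak.mp4 --height 512 --width 512 --acceleration none --num-inference-steps 50 --strength 0.4
```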
175 |
176 | ## Troubleshooting
177 |
178 | 1. If you hit a CUDA out-of-memory error with TensorRT, try using a shorter `t-index-list` or a smaller `strength`. When running inference with TensorRT, we maintain a group of buffers for the KV-cache, which consumes extra memory. A shorter `t-index-list` or a smaller `strength` reduces the size of the KV-cache and saves GPU memory.
179 |
180 | ## Real-Time Video2Video Demo
181 |
182 | There is an interactive video-to-video demo in the [`demo`](./demo) directory!
183 |
184 | Please refer to [`demo/README.md`](./demo/README.md) for more details.
185 |
186 |
187 |
188 |
189 |
190 |
191 |
Human Face (Web Camera Input)
192 |
193 |
194 |
Anime Character (Screen Video Input)
195 |
196 |
197 |
198 |
199 |
201 |
202 |
204 |
205 |
206 |
207 |
208 |
209 |
210 | ## Acknowledgements
211 |
212 | The video and image demos in this GitHub repository were generated using [LCM-LoRA](https://huggingface.co/latent-consistency/lcm-lora-sdv1-5). The stream batch technique from [StreamDiffusion](https://github.com/cumulo-autumn/StreamDiffusion) is used for model acceleration. The design of the video diffusion model is adapted from [AnimateDiff](https://github.com/guoyww/AnimateDiff). We use a third-party implementation of [MiDaS](https://github.com/lewiji/MiDaS) that supports ONNX export. Our online demo is modified from [Real-Time-Latent-Consistency-Model](https://github.com/radames/Real-Time-Latent-Consistency-Model/).
213 |
214 | ## BibTex
215 |
216 | If you find it helpful, please consider citing our work:
217 |
218 | ```bibtex
219 | @article{xing2024live2diff,
220 | title={Live2Diff: Live Stream Translation via Uni-directional Attention in Video Diffusion Models},
221 | author={Zhening Xing and Gereon Fox and Yanhong Zeng and Xingang Pan and Mohamed Elgharib and Christian Theobalt and Kai Chen},
222 |   journal={arXiv preprint arXiv:2407.08701},
223 | year={2024}
224 | }
225 | ```
226 |
--------------------------------------------------------------------------------
/live2diff/animatediff/models/resnet.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2 | from typing import Tuple
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from einops import rearrange
8 |
9 |
10 | def zero_module(module):
11 | # Zero out the parameters of a module and return it.
12 | for p in module.parameters():
13 | p.detach().zero_()
14 | return module
15 |
16 |
17 | class MappingNetwork(nn.Module):
18 | """
19 | Modified from https://github.com/huggingface/diffusers/blob/196835695ed6fa3ec53b888088d9d5581e8f8e94/src/diffusers/models/controlnet.py#L66-L108 # noqa
20 | """
21 |
22 | def __init__(
23 | self,
24 | conditioning_embedding_channels: int,
25 | conditioning_channels: int = 3,
26 | block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
27 | ):
28 | super().__init__()
29 |
30 | self.conv_in = InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
31 |
32 | self.blocks = nn.ModuleList([])
33 |
34 | for i in range(len(block_out_channels) - 1):
35 | channel_in = block_out_channels[i]
36 | channel_out = block_out_channels[i + 1]
37 | self.blocks.append(InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1))
38 | self.blocks.append(InflatedConv3d(channel_in, channel_out, kernel_size=3, padding=1))
39 |
40 | self.conv_out = zero_module(
41 | InflatedConv3d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
42 | )
43 |
44 | def forward(self, conditioning):
45 | embedding = self.conv_in(conditioning)
46 | embedding = F.silu(embedding)
47 |
48 | for block in self.blocks:
49 | embedding = block(embedding)
50 | embedding = F.silu(embedding)
51 |
52 | embedding = self.conv_out(embedding)
53 |
54 | return embedding
55 |
56 |
57 | class InflatedConv3d(nn.Conv2d):
58 | def forward(self, x):
59 | video_length = x.shape[2]
60 |
61 | x = rearrange(x, "b c f h w -> (b f) c h w")
62 | x = super().forward(x)
63 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
64 |
65 | return x
66 |
67 |
68 | class InflatedGroupNorm(nn.GroupNorm):
69 | def forward(self, x):
70 | video_length = x.shape[2]
71 |
72 | x = rearrange(x, "b c f h w -> (b f) c h w")
73 | x = super().forward(x)
74 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
75 |
76 | return x
77 |
78 |
79 | class Upsample3D(nn.Module):
80 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
81 | super().__init__()
82 | self.channels = channels
83 | self.out_channels = out_channels or channels
84 | self.use_conv = use_conv
85 | self.use_conv_transpose = use_conv_transpose
86 | self.name = name
87 |
88 | # conv = None
89 | if use_conv_transpose:
90 | raise NotImplementedError
91 | elif use_conv:
92 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
93 |
94 | def forward(self, hidden_states, output_size=None):
95 | assert hidden_states.shape[1] == self.channels
96 |
97 | if self.use_conv_transpose:
98 | raise NotImplementedError
99 |
100 | # Cast to float32 because the 'upsample_nearest2d_out_frame' op does not support bfloat16
101 | dtype = hidden_states.dtype
102 | if dtype == torch.bfloat16:
103 | hidden_states = hidden_states.to(torch.float32)
104 |
105 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
106 | if hidden_states.shape[0] >= 64:
107 | hidden_states = hidden_states.contiguous()
108 |
109 | # if `output_size` is passed we force the interpolation output
110 | # size and do not make use of `scale_factor=2`
111 | if output_size is None:
112 | hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
113 | else:
114 | hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
115 |
116 | # If the input is bfloat16, we cast back to bfloat16
117 | if dtype == torch.bfloat16:
118 | hidden_states = hidden_states.to(dtype)
119 |
120 | # if self.use_conv:
121 | # if self.name == "conv":
122 | # hidden_states = self.conv(hidden_states)
123 | # else:
124 | # hidden_states = self.Conv2d_0(hidden_states)
125 | hidden_states = self.conv(hidden_states)
126 |
127 | return hidden_states
128 |
129 |
130 | class Downsample3D(nn.Module):
131 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
132 | super().__init__()
133 | self.channels = channels
134 | self.out_channels = out_channels or channels
135 | self.use_conv = use_conv
136 | self.padding = padding
137 | stride = 2
138 | self.name = name
139 |
140 | if use_conv:
141 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
142 | else:
143 | raise NotImplementedError
144 |
145 | def forward(self, hidden_states):
146 | assert hidden_states.shape[1] == self.channels
147 | if self.use_conv and self.padding == 0:
148 | raise NotImplementedError
149 |
150 | assert hidden_states.shape[1] == self.channels
151 | hidden_states = self.conv(hidden_states)
152 |
153 | return hidden_states
154 |
155 |
156 | class ResnetBlock3D(nn.Module):
157 | def __init__(
158 | self,
159 | *,
160 | in_channels,
161 | out_channels=None,
162 | conv_shortcut=False,
163 | dropout=0.0,
164 | temb_channels=512,
165 | groups=32,
166 | groups_out=None,
167 | pre_norm=True,
168 | eps=1e-6,
169 | non_linearity="swish",
170 | time_embedding_norm="default",
171 | output_scale_factor=1.0,
172 | use_in_shortcut=None,
173 | use_inflated_groupnorm=False,
174 | ):
175 | super().__init__()
176 | self.pre_norm = pre_norm
177 | self.pre_norm = True
178 | self.in_channels = in_channels
179 | out_channels = in_channels if out_channels is None else out_channels
180 | self.out_channels = out_channels
181 | self.use_conv_shortcut = conv_shortcut
182 | self.time_embedding_norm = time_embedding_norm
183 | self.output_scale_factor = output_scale_factor
184 |
185 | if groups_out is None:
186 | groups_out = groups
187 |
188 | assert use_inflated_groupnorm is not None
189 | if use_inflated_groupnorm:
190 | self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
191 | else:
192 | self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
193 |
194 | self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
195 |
196 | if temb_channels is not None:
197 | if self.time_embedding_norm == "default":
198 | time_emb_proj_out_channels = out_channels
199 | elif self.time_embedding_norm == "scale_shift":
200 | time_emb_proj_out_channels = out_channels * 2
201 | else:
202 | raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
203 |
204 | self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
205 | else:
206 | self.time_emb_proj = None
207 |
208 | if use_inflated_groupnorm:
209 | self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
210 | else:
211 | self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
212 |
213 | self.dropout = torch.nn.Dropout(dropout)
214 | self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
215 |
216 | if non_linearity == "swish":
217 | self.nonlinearity = lambda x: F.silu(x)
218 | elif non_linearity == "mish":
219 | self.nonlinearity = Mish()
220 | elif non_linearity == "silu":
221 | self.nonlinearity = nn.SiLU()
222 |
223 | self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
224 |
225 | self.conv_shortcut = None
226 | if self.use_in_shortcut:
227 | self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
228 |
229 | def forward(self, input_tensor, temb):
230 | hidden_states = input_tensor
231 |
232 | hidden_states = self.norm1(hidden_states)
233 | hidden_states = self.nonlinearity(hidden_states)
234 |
235 | hidden_states = self.conv1(hidden_states)
236 |
237 | if temb is not None:
238 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
239 |
240 | if temb is not None and self.time_embedding_norm == "default":
241 | hidden_states = hidden_states + temb
242 |
243 | hidden_states = self.norm2(hidden_states)
244 |
245 | if temb is not None and self.time_embedding_norm == "scale_shift":
246 | scale, shift = torch.chunk(temb, 2, dim=1)
247 | hidden_states = hidden_states * (1 + scale) + shift
248 |
249 | hidden_states = self.nonlinearity(hidden_states)
250 |
251 | hidden_states = self.dropout(hidden_states)
252 | hidden_states = self.conv2(hidden_states)
253 |
254 | if self.conv_shortcut is not None:
255 | input_tensor = self.conv_shortcut(input_tensor)
256 |
257 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
258 |
259 | return output_tensor
260 |
261 |
262 | class Mish(torch.nn.Module):
263 | def forward(self, hidden_states):
264 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
265 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/live2diff/acceleration/tensorrt/utilities.py:
--------------------------------------------------------------------------------
1 | #! fork: https://github.com/NVIDIA/TensorRT/blob/main/demo/Diffusion/utilities.py
2 |
3 | #
4 | # Copyright 2022 The HuggingFace Inc. team.
5 | # SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
6 | # SPDX-License-Identifier: Apache-2.0
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License");
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at
11 | #
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 | #
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS,
16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | # See the License for the specific language governing permissions and
18 | # limitations under the License.
19 | #
20 |
21 | import gc
22 | from collections import OrderedDict
23 | from typing import *
24 |
25 | import numpy as np
26 | import onnx
27 | import onnx_graphsurgeon as gs
28 | import tensorrt as trt
29 | import torch
30 | from cuda import cudart
31 | from PIL import Image
32 | from polygraphy import cuda
33 | from polygraphy.backend.common import bytes_from_path
34 | from polygraphy.backend.trt import (
35 | CreateConfig,
36 | Profile,
37 | engine_from_bytes,
38 | engine_from_network,
39 | network_from_onnx_path,
40 | save_engine,
41 | )
42 |
43 | from .models import BaseModel
44 |
45 |
46 | TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
47 |
48 | # Map of numpy dtype -> torch dtype
49 | numpy_to_torch_dtype_dict = {
50 | np.uint8: torch.uint8,
51 | np.int8: torch.int8,
52 | np.int16: torch.int16,
53 | np.int32: torch.int32,
54 | np.int64: torch.int64,
55 | np.float16: torch.float16,
56 | np.float32: torch.float32,
57 | np.float64: torch.float64,
58 | np.complex64: torch.complex64,
59 | np.complex128: torch.complex128,
60 | }
61 | if np.version.full_version >= "1.24.0":
62 | numpy_to_torch_dtype_dict[np.bool_] = torch.bool
63 | else:
64 | numpy_to_torch_dtype_dict[np.bool] = torch.bool
65 |
66 | # Map of torch dtype -> numpy dtype
67 | torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}
68 |
69 |
70 | def CUASSERT(cuda_ret):
71 | err = cuda_ret[0]
72 | if err != cudart.cudaError_t.cudaSuccess:
73 | raise RuntimeError(
74 | f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t"
75 | )
76 | if len(cuda_ret) > 1:
77 | return cuda_ret[1]
78 | return None
79 |
80 |
81 | class Engine:
82 | def __init__(
83 | self,
84 | engine_path,
85 | ):
86 | self.engine_path = engine_path
87 | self.engine = None
88 | self.context = None
89 | self.buffers = OrderedDict()
90 | self.tensors = OrderedDict()
91 | self.cuda_graph_instance = None # cuda graph
92 |
93 | def __del__(self):
94 | [buf.free() for buf in self.buffers.values() if isinstance(buf, cuda.DeviceArray)]
95 | del self.engine
96 | del self.context
97 | del self.buffers
98 | del self.tensors
99 |
100 | def refit(self, onnx_path, onnx_refit_path):
101 | def convert_int64(arr):
102 | # TODO: smarter conversion
103 | if len(arr.shape) == 0:
104 | return np.int32(arr)
105 | return arr
106 |
107 | def add_to_map(refit_dict, name, values):
108 | if name in refit_dict:
109 | assert refit_dict[name] is None
110 | if values.dtype == np.int64:
111 | values = convert_int64(values)
112 | refit_dict[name] = values
113 |
114 | print(f"Refitting TensorRT engine with {onnx_refit_path} weights")
115 | refit_nodes = gs.import_onnx(onnx.load(onnx_refit_path)).toposort().nodes
116 |
117 | # Construct mapping from weight names in refit model -> original model
118 | name_map = {}
119 | for n, node in enumerate(gs.import_onnx(onnx.load(onnx_path)).toposort().nodes):
120 | refit_node = refit_nodes[n]
121 | assert node.op == refit_node.op
122 | # Constant nodes in ONNX do not have inputs but have a constant output
123 | if node.op == "Constant":
124 | name_map[refit_node.outputs[0].name] = node.outputs[0].name
125 | # Handle scale and bias weights
126 | elif node.op == "Conv":
127 | if node.inputs[1].__class__ == gs.Constant:
128 | name_map[refit_node.name + "_TRTKERNEL"] = node.name + "_TRTKERNEL"
129 | if node.inputs[2].__class__ == gs.Constant:
130 | name_map[refit_node.name + "_TRTBIAS"] = node.name + "_TRTBIAS"
131 | # For all other nodes: find node inputs that are initializers (gs.Constant)
132 | else:
133 | for i, inp in enumerate(node.inputs):
134 | if inp.__class__ == gs.Constant:
135 | name_map[refit_node.inputs[i].name] = inp.name
136 |
137 | def map_name(name):
138 | if name in name_map:
139 | return name_map[name]
140 | return name
141 |
142 | # Construct refit dictionary
143 | refit_dict = {}
144 | refitter = trt.Refitter(self.engine, TRT_LOGGER)
145 | all_weights = refitter.get_all()
146 | for layer_name, role in zip(all_weights[0], all_weights[1]):
147 | # for specialized roles, use a unique name in the map:
148 | if role == trt.WeightsRole.KERNEL:
149 | name = layer_name + "_TRTKERNEL"
150 | elif role == trt.WeightsRole.BIAS:
151 | name = layer_name + "_TRTBIAS"
152 | else:
153 | name = layer_name
154 |
155 | assert name not in refit_dict, "Found duplicate layer: " + name
156 | refit_dict[name] = None
157 |
158 | for n in refit_nodes:
159 | # Constant nodes in ONNX do not have inputs but have a constant output
160 | if n.op == "Constant":
161 | name = map_name(n.outputs[0].name)
162 | print(f"Add Constant {name}\n")
163 | add_to_map(refit_dict, name, n.outputs[0].values)
164 |
165 | # Handle scale and bias weights
166 | elif n.op == "Conv":
167 | if n.inputs[1].__class__ == gs.Constant:
168 | name = map_name(n.name + "_TRTKERNEL")
169 | add_to_map(refit_dict, name, n.inputs[1].values)
170 |
171 | if n.inputs[2].__class__ == gs.Constant:
172 | name = map_name(n.name + "_TRTBIAS")
173 | add_to_map(refit_dict, name, n.inputs[2].values)
174 |
175 | # For all other nodes: find node inputs that are initializers (AKA gs.Constant)
176 | else:
177 | for inp in n.inputs:
178 | name = map_name(inp.name)
179 | if inp.__class__ == gs.Constant:
180 | add_to_map(refit_dict, name, inp.values)
181 |
182 | for layer_name, weights_role in zip(all_weights[0], all_weights[1]):
183 | if weights_role == trt.WeightsRole.KERNEL:
184 | custom_name = layer_name + "_TRTKERNEL"
185 | elif weights_role == trt.WeightsRole.BIAS:
186 | custom_name = layer_name + "_TRTBIAS"
187 | else:
188 | custom_name = layer_name
189 |
190 | # Skip refitting Trilu for now; scalar weights of type int64 value 1 - for clip model
191 | if layer_name.startswith("onnx::Trilu"):
192 | continue
193 |
194 | if refit_dict[custom_name] is not None:
195 | refitter.set_weights(layer_name, weights_role, refit_dict[custom_name])
196 | else:
197 | print(f"[W] No refit weights for layer: {layer_name}")
198 |
199 | if not refitter.refit_cuda_engine():
200 | print("Failed to refit!")
201 | exit(0)
202 |
203 | def build(
204 | self,
205 | onnx_path,
206 | fp16,
207 | input_profile=None,
208 | enable_refit=False,
209 | enable_all_tactics=False,
210 | timing_cache=None,
211 | workspace_size=0,
212 | ):
213 | print(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
214 | p = Profile()
215 | if input_profile:
216 | for name, dims in input_profile.items():
217 | assert len(dims) == 3
218 | p.add(name, min=dims[0], opt=dims[1], max=dims[2])
219 |
220 | config_kwargs = {}
221 |
222 | if workspace_size > 0:
223 | config_kwargs["memory_pool_limits"] = {trt.MemoryPoolType.WORKSPACE: workspace_size}
224 | if not enable_all_tactics:
225 | config_kwargs["tactic_sources"] = []
226 |
227 | engine = engine_from_network(
228 | network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]),
229 | config=CreateConfig(
230 | fp16=fp16, refittable=enable_refit, profiles=[p], load_timing_cache=timing_cache, **config_kwargs
231 | ),
232 | save_timing_cache=timing_cache,
233 | )
234 | save_engine(engine, path=self.engine_path)
235 |
236 | def load(self):
237 | print(f"Loading TensorRT engine: {self.engine_path}")
238 | self.engine = engine_from_bytes(bytes_from_path(self.engine_path))
239 |
240 | def activate(self, reuse_device_memory=None, profiler=None):
241 | if reuse_device_memory:
242 | self.context = self.engine.create_execution_context_without_device_memory()
243 | self.context.device_memory = reuse_device_memory
244 | else:
245 | self.context = self.engine.create_execution_context()
246 |
247 | def allocate_buffers(self, shape_dict=None, device="cuda"):
248 | # NOTE: API for tensorrt 10.01
249 | from tensorrt import TensorIOMode
250 |
251 | for idx in range(self.engine.num_io_tensors):
252 | binding = self.engine[idx]
253 | if shape_dict and binding in shape_dict:
254 | shape = shape_dict[binding]
255 | else:
256 | shape = self.engine.get_tensor_shape(binding)
257 | dtype = trt.nptype(self.engine.get_tensor_dtype(binding))
258 | tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype], device=device)
259 | self.tensors[binding] = tensor
260 |
261 | binding_mode = self.engine.get_tensor_mode(binding)
262 | if binding_mode == TensorIOMode.INPUT:
263 | self.context.set_input_shape(binding, shape)
264 | self.has_allocated = True
265 |
266 | def infer(self, feed_dict, stream, use_cuda_graph=False):
267 | for name, buf in feed_dict.items():
268 | self.tensors[name].copy_(buf)
269 |
270 | for name, tensor in self.tensors.items():
271 | self.context.set_tensor_address(name, tensor.data_ptr())
272 |
273 | if use_cuda_graph:
274 | if self.cuda_graph_instance is not None:
275 | CUASSERT(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream.ptr))
276 | CUASSERT(cudart.cudaStreamSynchronize(stream.ptr))
277 | else:
278 | # do inference before CUDA graph capture
279 | noerror = self.context.execute_async_v3(stream.ptr)
280 | if not noerror:
281 | raise ValueError("ERROR: inference failed.")
282 | # capture cuda graph
283 | CUASSERT(
284 | cudart.cudaStreamBeginCapture(stream.ptr, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeGlobal)
285 | )
286 | self.context.execute_async_v3(stream.ptr)
287 | self.graph = CUASSERT(cudart.cudaStreamEndCapture(stream.ptr))
288 | self.cuda_graph_instance = CUASSERT(cudart.cudaGraphInstantiate(self.graph, 0))
289 | else:
290 | noerror = self.context.execute_async_v3(stream.ptr)
291 | if not noerror:
292 | raise ValueError("ERROR: inference failed.")
293 |
294 | return self.tensors
295 |
296 |
297 | def decode_images(images: torch.Tensor):
298 | images = (
299 | ((images + 1) * 255 / 2).clamp(0, 255).detach().permute(0, 2, 3, 1).round().type(torch.uint8).cpu().numpy()
300 | )
301 | return [Image.fromarray(x) for x in images]
302 |
303 |
304 | def preprocess_image(image: Image.Image):
305 | w, h = image.size
306 | w, h = [x - x % 32 for x in (w, h)] # resize to integer multiple of 32
307 | image = image.resize((w, h))
308 | init_image = np.array(image).astype(np.float32) / 255.0
309 | init_image = init_image[None].transpose(0, 3, 1, 2)
310 | init_image = torch.from_numpy(init_image).contiguous()
311 | return 2.0 * init_image - 1.0
312 |
313 |
314 | def prepare_mask_and_masked_image(image: Image.Image, mask: Image.Image):
315 | if isinstance(image, Image.Image):
316 | image = np.array(image.convert("RGB"))
317 | image = image[None].transpose(0, 3, 1, 2)
318 | image = torch.from_numpy(image).to(dtype=torch.float32).contiguous() / 127.5 - 1.0
319 | if isinstance(mask, Image.Image):
320 | mask = np.array(mask.convert("L"))
321 | mask = mask.astype(np.float32) / 255.0
322 | mask = mask[None, None]
323 | mask[mask < 0.5] = 0
324 | mask[mask >= 0.5] = 1
325 | mask = torch.from_numpy(mask).to(dtype=torch.float32).contiguous()
326 |
327 | masked_image = image * (mask < 0.5)
328 |
329 | return mask, masked_image
330 |
331 |
332 | def build_engine(
333 | engine_path: str,
334 | onnx_opt_path: str,
335 | model_data: BaseModel,
336 | opt_image_height: int,
337 | opt_image_width: int,
338 | opt_batch_size: int,
339 | build_static_batch: bool = False,
340 | build_dynamic_shape: bool = False,
341 | build_all_tactics: bool = False,
342 | build_enable_refit: bool = False,
343 | ):
344 | _, free_mem, _ = cudart.cudaMemGetInfo()
345 | GiB = 2**30
346 | if free_mem > 6 * GiB:
347 | activation_carveout = 4 * GiB
348 | max_workspace_size = free_mem - activation_carveout
349 | else:
350 | max_workspace_size = 0
351 | engine = Engine(engine_path)
352 | input_profile = model_data.get_input_profile(
353 | opt_batch_size,
354 | opt_image_height,
355 | opt_image_width,
356 | static_batch=build_static_batch,
357 | static_shape=not build_dynamic_shape,
358 | )
359 | engine.build(
360 | onnx_opt_path,
361 | fp16=True,
362 | input_profile=input_profile,
363 | enable_refit=build_enable_refit,
364 | enable_all_tactics=build_all_tactics,
365 | workspace_size=max_workspace_size,
366 | )
367 |
368 | return engine
369 |
370 |
371 | def export_onnx(
372 | model,
373 | onnx_path: str,
374 | model_data: BaseModel,
375 | opt_image_height: int,
376 | opt_image_width: int,
377 | opt_batch_size: int,
378 | onnx_opset: int,
379 | auto_cast: bool = True,
380 | ):
381 | from contextlib import contextmanager
382 |
383 | @contextmanager
384 | def auto_cast_manager(enabled):
385 | if enabled:
386 | with torch.inference_mode(), torch.autocast("cuda"):
387 | yield
388 | else:
389 | yield
390 |
391 | with auto_cast_manager(auto_cast):
392 | inputs = model_data.get_sample_input(opt_batch_size, opt_image_height, opt_image_width)
393 | torch.onnx.export(
394 | model,
395 | inputs,
396 | onnx_path,
397 | export_params=True,
398 | opset_version=onnx_opset,
399 | do_constant_folding=True,
400 | input_names=model_data.get_input_names(),
401 | output_names=model_data.get_output_names(),
402 | dynamic_axes=model_data.get_dynamic_axes(),
403 | )
404 | del model
405 | gc.collect()
406 | torch.cuda.empty_cache()
407 |
408 |
409 | def optimize_onnx(
410 | onnx_path: str,
411 | onnx_opt_path: str,
412 | model_data: BaseModel,
413 | ):
414 | model_data.optimize(onnx_path, onnx_opt_path)
415 | # # onnx_opt_graph = model_data.optimize(onnx.load(onnx_path))
416 | # onnx_opt_graph = model_data.optimize(onnx_path)
417 | # onnx.save(onnx_opt_graph, onnx_opt_path)
418 | # del onnx_opt_graph
419 | # gc.collect()
420 | # torch.cuda.empty_cache()
421 |
422 |
423 | def handle_onnx_batch_norm(onnx_path: str):
424 | onnx_model = onnx.load(onnx_path)
425 | for node in onnx_model.graph.node:
426 | if node.op_type == "BatchNormalization":
427 | for attribute in node.attribute:
428 | if attribute.name == "training_mode":
429 | if attribute.i == 1:
430 | node.output.remove(node.output[1])
431 | node.output.remove(node.output[1])
432 | attribute.i = 0
433 |
434 | onnx.save_model(onnx_model, onnx_path)
435 |
--------------------------------------------------------------------------------
/live2diff/animatediff/pipeline/pipeline_animatediff_depth.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/open-mmlab/PIA/blob/main/animatediff/pipelines/i2v_pipeline.py
2 |
3 | from dataclasses import dataclass
4 | from typing import List, Optional, Union
5 |
6 | import numpy as np
7 | import torch
8 | from diffusers.configuration_utils import FrozenDict
9 | from diffusers.loaders import TextualInversionLoaderMixin
10 | from diffusers.models import AutoencoderKL
11 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline
12 | from diffusers.schedulers import (
13 | DDIMScheduler,
14 | DPMSolverMultistepScheduler,
15 | EulerAncestralDiscreteScheduler,
16 | EulerDiscreteScheduler,
17 | LMSDiscreteScheduler,
18 | PNDMScheduler,
19 | )
20 | from diffusers.utils import BaseOutput, deprecate, is_accelerate_available, logging
21 | from packaging import version
22 | from transformers import CLIPTextModel, CLIPTokenizer
23 |
24 | from ..models.depth_utils import MidasDetector
25 | from ..models.unet_depth_streaming import UNet3DConditionStreamingModel
26 | from .loader import LoraLoaderWithWarmup
27 |
28 |
29 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30 |
31 |
32 | @dataclass
33 | class AnimationPipelineOutput(BaseOutput):
34 | videos: Union[torch.Tensor, np.ndarray]
35 | input_images: Optional[Union[torch.Tensor, np.ndarray]] = None
36 |
37 |
38 | class AnimationDepthPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderWithWarmup):
39 | _optional_components = []
40 |
41 | def __init__(
42 | self,
43 | vae: AutoencoderKL,
44 | text_encoder: CLIPTextModel,
45 | tokenizer: CLIPTokenizer,
46 | unet: UNet3DConditionStreamingModel,
47 | depth_model: MidasDetector,
48 | scheduler: Union[
49 | DDIMScheduler,
50 | PNDMScheduler,
51 | LMSDiscreteScheduler,
52 | EulerDiscreteScheduler,
53 | EulerAncestralDiscreteScheduler,
54 | DPMSolverMultistepScheduler,
55 | ],
56 | ):
57 | super().__init__()
58 |
59 | if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
60 | deprecation_message = (
61 | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
62 | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
63 | "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
64 | " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
65 | " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
66 | " file"
67 | )
68 | deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
69 | new_config = dict(scheduler.config)
70 | new_config["steps_offset"] = 1
71 | scheduler._internal_dict = FrozenDict(new_config)
72 |
73 | if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
74 | deprecation_message = (
75 | f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
76 | " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
77 | " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
78 | " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
79 | " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
80 | )
81 | deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
82 | new_config = dict(scheduler.config)
83 | new_config["clip_sample"] = False
84 | scheduler._internal_dict = FrozenDict(new_config)
85 |
86 | is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
87 | version.parse(unet.config._diffusers_version).base_version
88 | ) < version.parse("0.9.0.dev0")
89 | is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
90 | if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
91 | deprecation_message = (
92 | "The configuration file of the unet has set the default `sample_size` to smaller than"
93 | " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
94 | " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
95 | " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
96 | " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
97 | " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
98 | " in the config might lead to incorrect results in future versions. If you have downloaded this"
99 | " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
100 | " the `unet/config.json` file"
101 | )
102 | deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
103 | new_config = dict(unet.config)
104 | new_config["sample_size"] = 64
105 | unet._internal_dict = FrozenDict(new_config)
106 |
107 | self.register_modules(
108 | vae=vae,
109 | text_encoder=text_encoder,
110 | tokenizer=tokenizer,
111 | unet=unet,
112 | depth_model=depth_model,
113 | scheduler=scheduler,
114 | )
115 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
116 | self.log_denoising_mean = False
117 |
118 | def enable_vae_slicing(self):
119 | self.vae.enable_slicing()
120 |
121 | def disable_vae_slicing(self):
122 | self.vae.disable_slicing()
123 |
124 | def enable_sequential_cpu_offload(self, gpu_id=0):
125 | if is_accelerate_available():
126 | from accelerate import cpu_offload
127 | else:
128 | raise ImportError("Please install accelerate via `pip install accelerate`")
129 |
130 | device = torch.device(f"cuda:{gpu_id}")
131 |
132 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
133 | if cpu_offloaded_model is not None:
134 | cpu_offload(cpu_offloaded_model, device)
135 |
136 | @property
137 | def _execution_device(self):
138 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
139 | return self.device
140 | for module in self.unet.modules():
141 | if (
142 | hasattr(module, "_hf_hook")
143 | and hasattr(module._hf_hook, "execution_device")
144 | and module._hf_hook.execution_device is not None
145 | ):
146 | return torch.device(module._hf_hook.execution_device)
147 | return self.device
148 |
149 | def _encode_prompt(
150 | self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt, clip_skip=None
151 | ):
152 | batch_size = len(prompt) if isinstance(prompt, list) else 1
153 |
154 | text_inputs = self.tokenizer(
155 | prompt,
156 | padding="max_length",
157 | max_length=self.tokenizer.model_max_length,
158 | truncation=True,
159 | return_tensors="pt",
160 | )
161 | text_input_ids = text_inputs.input_ids
162 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
163 |
164 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
165 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
166 | logger.warning(
167 | "The following part of your input was truncated because CLIP can only handle sequences up to"
168 | f" {self.tokenizer.model_max_length} tokens: {removed_text}"
169 | )
170 |
171 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
172 | attention_mask = text_inputs.attention_mask.to(device)
173 | else:
174 | attention_mask = None
175 |
176 | if clip_skip is None:
177 | text_embeddings = self.text_encoder(
178 | text_input_ids.to(device),
179 | attention_mask=attention_mask,
180 | )
181 | text_embeddings = text_embeddings[0]
182 | else:
183 | # support clip skip here, suitable for models based on NAI
184 | text_embeddings = self.text_encoder(
185 | text_input_ids.to(device),
186 | attention_mask=attention_mask,
187 | output_hidden_states=True,
188 | )
189 | text_embeddings = text_embeddings[-1][-(clip_skip + 1)]
190 | text_embeddings = self.text_encoder.text_model.final_layer_norm(text_embeddings)
191 |
192 | # duplicate text embeddings for each generation per prompt, using mps friendly method
193 | bs_embed, seq_len, _ = text_embeddings.shape
194 | text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
195 | text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
196 |
197 | # get unconditional embeddings for classifier free guidance
198 | if do_classifier_free_guidance:
199 | uncond_tokens: List[str]
200 | if negative_prompt is None:
201 | uncond_tokens = [""] * batch_size
202 | elif type(prompt) is not type(negative_prompt):
203 | raise TypeError(
204 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
205 | f" {type(prompt)}."
206 | )
207 | elif isinstance(negative_prompt, str):
208 | uncond_tokens = [negative_prompt]
209 | elif batch_size != len(negative_prompt):
210 | raise ValueError(
211 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
212 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
213 | " the batch size of `prompt`."
214 | )
215 | else:
216 | uncond_tokens = negative_prompt
217 |
218 | max_length = text_input_ids.shape[-1]
219 | uncond_input = self.tokenizer(
220 | uncond_tokens,
221 | padding="max_length",
222 | max_length=max_length,
223 | truncation=True,
224 | return_tensors="pt",
225 | )
226 |
227 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
228 | attention_mask = uncond_input.attention_mask.to(device)
229 | else:
230 | attention_mask = None
231 |
232 | uncond_embeddings = self.text_encoder(
233 | uncond_input.input_ids.to(device),
234 | attention_mask=attention_mask,
235 | )
236 | uncond_embeddings = uncond_embeddings[0]
237 |
238 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
239 | seq_len = uncond_embeddings.shape[1]
240 | uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
241 | uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
242 |
243 | # For classifier free guidance, we need to do two forward passes.
244 | # Here we concatenate the unconditional and text embeddings into a single batch
245 | # to avoid doing two forward passes
246 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
247 |
248 | return text_embeddings
249 |
250 | @classmethod
251 | def build_pipeline(cls, config_path: str, dreambooth: Optional[str] = None):
252 | """We build pipeline from config path"""
253 | from omegaconf import OmegaConf
254 |
255 | from ...utils.config import load_config
256 | from ..converter import load_third_party_checkpoints
257 | from ..models.unet_depth_streaming import UNet3DConditionStreamingModel
258 |
259 | cfg = load_config(config_path)
260 | pretrained_model_path = cfg.pretrained_model_path
261 | unet_additional_kwargs = cfg.get("unet_additional_kwargs", {})
262 | noise_scheduler_kwargs = cfg.noise_scheduler_kwargs
263 | third_party_dict = cfg.get("third_party_dict", {})
264 |
265 | noise_scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs))
266 |
267 | vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
268 | vae = vae.to(device="cuda", dtype=torch.bfloat16)
269 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
270 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
271 | text_encoder = text_encoder.to(device="cuda", dtype=torch.float16)
272 |
273 | unet = UNet3DConditionStreamingModel.from_pretrained_2d(
274 | pretrained_model_path,
275 | subfolder="unet",
276 | unet_additional_kwargs=OmegaConf.to_container(unet_additional_kwargs) if unet_additional_kwargs else {},
277 | )
278 |
279 | motion_module_path = cfg.motion_module_path
280 | # load motion module to unet
281 | mm_checkpoint = torch.load(motion_module_path, map_location="cuda")
282 | if "global_step" in mm_checkpoint:
283 | print(f"global_step: {mm_checkpoint['global_step']}")
284 | state_dict = mm_checkpoint["state_dict"] if "state_dict" in mm_checkpoint else mm_checkpoint
285 | # NOTE: hard code here: remove `grid` from state_dict
286 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items() if "grid" not in k}
287 |
288 | m, u = unet.load_state_dict(state_dict, strict=False)
289 | assert len(u) == 0, f"Find unexpected keys ({len(u)}): {u}"
290 | print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
291 |
292 | unet = unet.to(dtype=torch.float16)
293 | depth_model = MidasDetector(cfg.depth_model_path).to(device="cuda", dtype=torch.float16)
294 |
295 | pipeline = cls(
296 | unet=unet,
297 | vae=vae,
298 | tokenizer=tokenizer,
299 | text_encoder=text_encoder,
300 | depth_model=depth_model,
301 | scheduler=noise_scheduler,
302 | )
303 | pipeline = load_third_party_checkpoints(pipeline, third_party_dict, dreambooth)
304 |
305 | return pipeline
306 |
307 | @classmethod
308 | def build_warmup_unet(cls, config_path: str, dreambooth: Optional[str] = None):
309 | from omegaconf import OmegaConf
310 |
311 | from ...utils.config import load_config
312 | from ..converter import load_third_party_unet
313 | from ..models.unet_depth_warmup import UNet3DConditionWarmupModel
314 |
315 | cfg = load_config(config_path)
316 | pretrained_model_path = cfg.pretrained_model_path
317 | unet_additional_kwargs = cfg.get("unet_additional_kwargs", {})
318 | third_party_dict = cfg.get("third_party_dict", {})
319 |
320 | unet = UNet3DConditionWarmupModel.from_pretrained_2d(
321 | pretrained_model_path,
322 | subfolder="unet",
323 | unet_additional_kwargs=OmegaConf.to_container(unet_additional_kwargs) if unet_additional_kwargs else {},
324 | )
325 | motion_module_path = cfg.motion_module_path
326 | # load motion module to unet
327 | mm_checkpoint = torch.load(motion_module_path, map_location="cpu")
328 | if "global_step" in mm_checkpoint:
329 | print(f"global_step: {mm_checkpoint['global_step']}")
330 | state_dict = mm_checkpoint["state_dict"] if "state_dict" in mm_checkpoint else mm_checkpoint
331 | # NOTE: hard code here: remove `grid` from state_dict
332 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items() if "grid" not in k}
333 |
334 | m, u = unet.load_state_dict(state_dict, strict=False)
335 | assert len(u) == 0, f"Find unexpected keys ({len(u)}): {u}"
336 | print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
337 |
338 | unet = load_third_party_unet(unet, third_party_dict, dreambooth)
339 | return unet
340 |
341 | def prepare_cache(self, height: int, width: int, denoising_steps_num: int):
342 | vae = self.vae
343 | scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
344 | self.unet.set_info_for_attn(height // scale_factor, width // scale_factor)
345 | kv_cache_list = self.unet.prepare_cache(denoising_steps_num)
346 | return kv_cache_list
347 |
348 | def prepare_warmup_unet(self, height: int, width: int, unet):
349 | vae = self.vae
350 | scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
351 | unet.set_info_for_attn(height // scale_factor, width // scale_factor)
352 |
--------------------------------------------------------------------------------