├── .gitignore
├── README.md
├── accelerate_configs
│   ├── compiled_1.yaml
│   └── deepspeed.yaml
├── assets
│   ├── human1.jpg
│   ├── human2.jpg
│   ├── human3.jpg
│   ├── human4.jpg
│   ├── human5.jpg
│   ├── motion1.mp4
│   ├── motion2.mp4
│   ├── motion3.mp4
│   ├── person1_img.jpg
│   └── person2_img.jpg
├── checkpoints
│   └── Put pre-trained model here.txt
├── docs
│   └── finetune.md
├── dwpose
│   ├── __init__.py
│   ├── dwpose_detector.py
│   ├── onnxdet.py
│   ├── onnxpose.py
│   ├── preprocess.py
│   ├── util.py
│   └── wholebody.py
├── dynamictrl
│   ├── __init__.py
│   ├── models
│   │   ├── model_dynamictrl.py
│   │   └── modules
│   │       ├── cross_attention.py
│   │       ├── embeddings.py
│   │       └── model_cogvideox_autoencoderKL.py
│   └── pipelines
│       ├── dynamictrl_pipeline.py
│       └── modules
│           ├── dynamictrl_output.py
│           └── utils.py
├── requirements.txt
├── scripts
│   ├── dynamictrl_inference.py
│   └── dynamictrl_sft_finetune.py
├── tools
│   └── qwen.py
└── utils
    ├── args.py
    ├── dataset.py
    ├── dataset_use_mask_entity.py
    ├── freeinit_utils.py
    ├── load_validation_control.py
    ├── pre_process.py
    ├── prepare_dataset.py
    ├── save_utils.py
    ├── text_encoder
    │   ├── __init__.py
    │   └── text_encoder.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## DynamiCtrl: Rethinking the Basic Structure and the Role of Text for High-quality Human Image Animation
2 |
3 | [Haoyu Zhao](https://scholar.google.com/citations?user=pCGM7jwAAAAJ&hl=zh-CN&oi=ao/), [Zhongang Qi](https://scholar.google.com/citations?user=zJvrrusAAAAJ&hl=en/), [Cong Wang](#), [Qingqing Zheng](https://scholar.google.com.hk/citations?user=l0Y7emkAAAAJ&hl=zh-CN&oi=ao/), [Guansong Lu](https://scholar.google.com.hk/citations?user=YIt8thUAAAAJ&hl=zh-CN&oi=ao), [Fei Chen](#), [Hang Xu](https://scholar.google.com.hk/citations?user=J_8TX6sAAAAJ&hl=zh-CN&oi=ao) and [Zuxuan Wu](https://scholar.google.com.hk/citations?user=7t12hVkAAAAJ&hl=zh-CN&oi=ao)
4 |
5 |
6 |
7 |
8 | [GitHub Repository](https://github.com/gulucaptain/DynamiCtrl)
9 |
10 |
11 | ## 🎏 Introduction
12 | TL;DR: DynamiCtrl is the first framework to introduce the "Joint-text" paradigm for the pose-guided human image animation task, achieving effective pose control within the diffusion transformer (DiT) architecture.
13 |
14 | CLICK for the full introduction
15 |
16 |
17 | > With diffusion transformer (DiT) excelling in video generation, its use in specific tasks has drawn increasing attention. However, adapting DiT for pose-guided human image animation faces two core challenges: (a) existing U-Net-based pose control methods may be suboptimal for the DiT backbone; and (b) removing text guidance, as in previous approaches, often leads to semantic loss and model degradation. To address these issues, we propose DynamiCtrl, a novel framework for human animation in video DiT architecture. Specifically, we use a shared VAE encoder for human images and driving poses, unifying them into a common latent space, maintaining pose fidelity, and eliminating the need for an expert pose encoder during video denoising. To integrate pose control into the DiT backbone effectively, we propose a novel Pose-adaptive Layer Norm model. It injects normalized pose features into the denoising process via conditioning on visual tokens, enabling seamless and scalable pose control across DiT blocks. Furthermore, to overcome the shortcomings of text removal, we introduce the "Joint-text" paradigm, which preserves the role of text embeddings to provide global semantic context. Through full-attention blocks, image and pose features are aligned with text features, enhancing semantic consistency, leveraging pretrained knowledge, and enabling multi-level control. Experiments verify the superiority of DynamiCtrl on benchmark and self-collected data (e.g., achieving the best LPIPS of 0.166), demonstrating strong character control and high-quality synthesis.
18 |
19 |
20 | ## 📺 Overview on YouTube
21 |
22 | [Watch the overview video on YouTube](https://www.youtube.com/watch?v=Mu_pNXM4PcE)
24 |
25 | ## ⚔️ DynamiCtrl for High-quality Pose-guided Human Image Animation
26 |
27 | We first refocus on the role of text for this task and find that fine-grained textual information helps improve video quality. In particular, we can achieve fine-grained local controllability using different prompts.
28 |
29 | [Demo videos: three example animation cases]
43 | CLICK to check the prompts used for generation in the above three cases.
44 |
45 | > Prompt (left): “The image depicts a stylized, animated character standing amidst a chaotic and dynamic background. The character is dressed in a blue suit with a red cape, featuring a prominent "S" emblem on the chest. The suit has a belt with pouches and a utility belt. The character has spiky hair and is standing on a pile of debris and rubble, suggesting a scene of destruction or battle. The background is filled with glowing, fiery elements and a sense of motion, adding to the dramatic and intense atmosphere of the scene."
46 |
47 | > Prompt (mid): “The person in the image is a woman with long, blonde hair styled in loose waves. She is wearing a form-fitting, sleeveless top with a high neckline and a small cutout at the chest. The top is beige and has a strap across her chest. She is also wearing a black belt with a pouch attached to it. Around her neck, she has a turquoise pendant necklace. The background appears to be a dimly lit, urban environment with a warm, golden glow."
48 |
49 | > Prompt (right): “The person in the image is wearing a black, form-fitting one-piece outfit and a pair of VR goggles. They are walking down a busy street with numerous people and colorful neon signs in the background. The street appears to be a bustling urban area, possibly in a city known for its vibrant nightlife and entertainment. The lighting and signage suggest a lively atmosphere, typical of a cityscape at night."
50 |
91 | ### Fine-grained video control
92 |
93 | [Demo videos: fine-grained background-control cases]
119 | CLICK to check the prompts used for generation in the above background-control cases.
120 |
121 | > Scene 1: The person in the image is wearing a white, knee-length dress with short sleeves and a square neckline. The dress features lace detailing and a ruffled hem. The person is also wearing clear, open-toed sandals. The background shows a bustling futuristic city at night, with neon lights reflecting off the wet streets and flying cars zooming above.
122 |
123 | > Scene 2: The person in the image is wearing a white, knee-length dress with short sleeves and a square neckline. The dress features lace detailing and a ruffled hem. The person is also wearing clear, open-toed sandals. The background shows a vibrant market street in a Middle Eastern bazaar, filled with colorful fabrics, exotic spices, and merchants calling out to customers.
124 |
125 | > Scene 3: The person in the image is wearing a white, knee-length dress with short sleeves and a square neckline. The dress features lace detailing and a ruffled hem. The person is also wearing clear, open-toed sandals. The background shows a sunny beach with golden sand, gentle ocean waves rolling onto the shore, and palm trees swaying in the breeze.
126 |
127 | > Scene 4: The person in the image is wearing a white, knee-length dress with short sleeves and a square neckline. The dress features lace detailing and a ruffled hem. The person is also wearing clear, open-toed sandals. The background shows a high-tech research lab with sleek metallic walls, glowing holographic screens, and robotic arms assembling futuristic devices.
128 |
129 | > Scene 5: The person in the image is wearing a white, knee-length dress with short sleeves and a square neckline. The dress features lace detailing and a ruffled hem. The person is also wearing clear, open-toed sandals. The background shows a mystical ancient temple hidden deep in the jungle, covered in vines, with glowing runes carved into the stone walls.
130 |
131 | > Scene 6: The person in the image is wearing a white, knee-length dress with short sleeves and a square neckline. The dress features lace detailing and a ruffled hem. The person is also wearing clear, open-toed sandals. The background shows a serene snowy forest with tall pine trees, soft snowflakes falling gently, and a frozen river winding through the landscape.
132 |
133 | > Scene 7: The person in the image is wearing a white, knee-length dress with short sleeves and a square neckline. The dress features lace detailing and a ruffled hem. The person is also wearing clear, open-toed sandals. The background shows an abandoned industrial warehouse with broken windows, scattered debris, and rusted machinery covered in dust.
134 |
164 | ## 🚧 Todo
165 |
166 | Click for Previous todos
167 |
168 | - [✔] Release the project page and demos.
169 | - [✔] Paper on arXiv on 27 Mar 2025.
170 |
171 |
172 | - [✔] Release inference code.
173 | - [✔] Release models.
174 | - [✔] Release training code.
175 |
176 | ## 📋 Changelog
177 | - 2025.05.20 Code and models released!
178 | - 2025.03.30 Project page and demos released!
179 | - 2025.03.10 Project Online!
180 |
181 |
182 | ## Installation
183 |
184 | For usage (SFT fine-tuning, inference), you can install the dependencies with:
185 |
186 | ```bash
187 | conda create --name dynamictrl python=3.10
188 |
189 | source activate dynamictrl
190 |
191 | pip install -r requirements.txt
192 | ```
193 |
194 |
195 | ## Model Zoo
196 |
197 | We provide three groups of checkpoints:
198 |
199 | - DynamiCtrl-5B: trained on whole person images (no mask) and the corresponding driving pose sequences.
200 | - Dynamictrl-5B-Mask_B01: trained on person images with the background masked out, plus the corresponding pose sequences.
201 | - Dynamictrl-5B-Mask_C01: trained on person images with the clothes masked out, plus the corresponding pose sequences.
202 |
203 |
204 | | name | Details | HF weights 🤗 |
205 | |:---|:---:|:---:|
206 | | DynamiCtrl-5B | SFT w/ whole image | [dynamictrl-5B](https://huggingface.co/gulucaptain/DynamiCtrl) |
207 | | Dynamictrl-5B-Mask_B01 | SFT w/ masked Background | [dynamictrl-5B-mask-B01](https://huggingface.co/gulucaptain/Dynamictrl-Mask_B01) |
208 | | Dynamictrl-5B-Mask_C01 | SFT w/ masked human Clothing | [dynamictrl-5B-mask-C01](https://huggingface.co/gulucaptain/Dynamictrl-Mask_C01) |
209 |
210 |
211 | [Causal VAE](https://arxiv.org/abs/2408.06072) and [T5](https://arxiv.org/abs/1910.10683) are used as our VAE and text encoder, respectively.
212 |
213 | ```bash
214 | cd checkpoints
215 |
216 | pip install -U huggingface_hub
217 |
218 | huggingface-cli download --resume-download --local-dir-use-symlinks False gulucaptain/DynamiCtrl --local-dir ./DynamiCtrl
219 |
220 | huggingface-cli download --resume-download --local-dir-use-symlinks False gulucaptain/Dynamictrl-Mask_B01 --local-dir ./Dynamictrl-Mask_B01
221 |
222 | huggingface-cli download --resume-download --local-dir-use-symlinks False gulucaptain/Dynamictrl-Mask_C01 --local-dir ./Dynamictrl-Mask_C01
223 | ```
224 |
225 | Download the checkpoints of [DWPose](https://github.com/IDEA-Research/DWPose) for human pose estimation:
226 |
227 | ```bash
228 | cd checkpoints
229 |
230 | git clone https://huggingface.co/yzd-v/DWPose
231 |
232 | # Change the paths in ./dwpose/wholebody.py Lines 15 and 16.
233 | ```
234 |
235 | ## 👍 Quick Start
236 |
237 | ### Direct Inference w/ Driving Video
238 |
239 | ```bash
240 | image="./assets/human1.jpg"
241 | video="./assets/motion1.mp4"
242 |
243 | model_path="./checkpoints/DynamiCtrl"
244 | output="./outputs"
245 |
246 | CUDA_VISIBLE_DEVICES=0 python scripts/dynamictrl_inference.py \
247 | --prompt="Input the test prompt here." \
248 | --reference_image_path=$image \
249 | --ori_driving_video=$video \
250 | --model_path=$model_path \
251 | --output_path=$output \
252 | --num_inference_steps=25 \
253 | --width=768 \
254 | --height=1360 \
255 | --num_frames=37 \
256 | --pose_control_function="padaln" \
257 | --guidance_scale=3.0 \
258 | --seed=42 \
259 | --seed=42
260 |
261 | Tips: When using the trained DynamiCtrl model without a masked area, you should ensure that the prompt content aligns with the provided human image, including the person's appearance and the background description.
262 |
263 | You can write the prompt yourself, or use a vision-language model such as Qwen2-VL to generate a prompt that matches the image content automatically; see this blog: [How to use Qwen2-VL](https://blog.csdn.net/zxs0222/article/details/144698753?spm=1001.2014.3001.5501).
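For reference, here is a minimal, hedged sketch of automatic prompt generation with Qwen2-VL through Hugging Face `transformers` (the model id, the instruction text, and the image path are illustrative assumptions; see also `tools/qwen.py` in this repo):

```python
# A minimal sketch: caption a reference image with Qwen2-VL so the caption can be used
# as the inference prompt. Model id, instruction, and paths are illustrative assumptions.
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info  # pip install qwen-vl-utils

model_id = "Qwen/Qwen2-VL-7B-Instruct"  # assumed checkpoint
model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
processor = AutoProcessor.from_pretrained(model_id)

messages = [{
    "role": "user",
    "content": [
        {"type": "image", "image": "./assets/human1.jpg"},
        {"type": "text", "text": "Describe the person's appearance and the background in detail."},
    ],
}]
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt").to(model.device)

output_ids = model.generate(**inputs, max_new_tokens=256)
prompt = processor.batch_decode(output_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)[0]
print(prompt)  # pass this string to --prompt in scripts/dynamictrl_inference.py
```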
264 |
265 | ### Inference w/ Masked Human Image
266 |
267 | Thanks to the proposed "Joint-text" paradigm for this task, we can achieve fine-grained control over human motion, including background and clothing areas. It is also easy to use, just provide a human image with blacked-out areas, and you can directly run the inference script for generation. Note to replace the model path. How to automatically get the mask area? You can follow this blog: [How to get mask of subject](https://blog.csdn.net/zxs0222/article/details/147604020?spm=1001.2014.3001.5501).
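As a reference, here is a minimal masking sketch (the file names and the "white = keep" mask convention are assumptions; any segmentation tool, e.g. the one from the blog above, can produce the mask):

```python
# Sketch: black out the masked region of a human image before running the masked models.
# File names and the "255 = keep, 0 = black out" convention are assumptions.
import cv2

image = cv2.imread("./assets/human1.jpg")                             # H x W x 3, BGR
mask = cv2.imread("./assets/human1_mask.png", cv2.IMREAD_GRAYSCALE)   # hypothetical mask file
mask = cv2.resize(mask, (image.shape[1], image.shape[0]))

masked = image.copy()
masked[mask < 128] = 0          # zero out the background (or clothing) region
cv2.imwrite("./assets/masked_human1.jpg", masked)
```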
268 |
269 | Note: please replace the "transformer" folder in DynamiCtrl with the "Dynamictrl-Mask_B01" or "Dynamictrl-Mask_C01" folder.
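One possible way to perform this swap (a sketch based on the note above; it assumes the downloaded Dynamictrl-Mask_B01 / Dynamictrl-Mask_C01 folders contain only the transformer weights, and it backs up the original folder first):

```python
# Sketch: swap the base transformer weights for a masked-model variant.
# Paths follow the download commands in the Model Zoo section; adjust to your layout.
import shutil

base = "./checkpoints/DynamiCtrl/transformer"
shutil.move(base, base + "_base_backup")                    # keep the original weights
shutil.copytree("./checkpoints/Dynamictrl-Mask_B01", base)  # or ./checkpoints/Dynamictrl-Mask_C01
```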
270 |
271 | ```bash
272 | image="./assets/maksed_human1.jpg" # Required
273 | video="./assets/motion.mp4"
274 |
275 | model_path="./checkpoints/Dynamictrl" # or "Dynamictrl-5B-Mask_C01"
276 | output="./outputs"
277 |
278 | CUDA_VISIBLE_DEVICES=0 python scripts/dynamictrl_inference.py \
279 | --prompt="Input the test prompt here." \
280 | --reference_image_path=$image \
281 | --ori_driving_video=$video \
282 | --model_path=$model_path \
283 | --output_path=$output \
284 | --num_inference_steps=25 \
285 | --width=768 \
286 | --height=1360 \
287 | --num_frames=37 \
288 | --pose_control_function="padaln" \
289 | --guidance_scale=3.0 \
290 | --seed=42
291 | ```
292 |
293 | Tips: Although the "Dynamictrl-5B-Mask_B01" and "Dynamictrl-5B-Mask_C01" models are trained with masked human images, you can still directly test whole human images with these two models. Sometimes, they may even perform better than the basic "Dynamictrl-5B" model.
294 |
295 | #### Memory and time cost
296 |
297 | | Device | Num. of frames | Resolution | Time | GPU memory |
298 | |:---|:---:|:---:|:---:|:---:|
299 | | H20 | 37 | 1360 * 768 | 3 min 50s | 28.4 GB |
300 | | H20 | 37 | 1024 * 576 | 1 min 40s | 24.7 GB |
301 | | H20 | 37 | 1360 * 1360 | 9 min 28s | 34.8 GB |
302 | | H20 | 37 | 1024 * 1024 | 3 min 50s | 28.4 GB |
303 |
304 | ### Training
305 |
306 | Please find the instructions on data preparation and training [here](./docs/finetune.md).
307 |
308 | ## 🔅 More Applications:
309 |
310 | ### Digital Human (including long-video performance)
311 | 
312 | Showcases: 12-second long videos driven by the same audio.
313 |
314 | [Demo videos: two digital-human cases]
324 |
325 | The digital human identities are generated with vivo's BlueLM model (text-to-image generation).
326 |
327 | Two steps to generate a digital human:
328 |
329 | 1. Prepare a human image and a guided pose video, and generate the video materials using our DynamiCtrl.
330 |
331 | 2. Use the output video and an audio file, and apply [MuseTalk](https://github.com/TMElyralab/MuseTalk) to generate the correct lip movements.
332 |
333 |
334 |
335 | ## 📍 Citation
336 |
337 | If you find this repository helpful, please consider citing:
338 |
339 | ```
340 | @article{zhao2025dynamictrl,
341 | title={DynamiCtrl: Rethinking the Basic Structure and the Role of Text for High-quality Human Image Animation},
342 | author={Zhao, Haoyu and Qi, Zhongang and Wang, Cong and Zheng, Qingping and Lu, Guansong and Chen, Fei and Xu, Hang and Wu, Zuxuan},
343 | year={2025},
344 | journal={arXiv preprint arXiv:2503.21246},
345 | }
346 | ```
347 |
348 | ## 💗 Acknowledgements
349 |
350 | This repository borrows heavily from [CogVideoX](https://github.com/THUDM/CogVideo). Thanks to the authors for sharing their code and models.
351 |
352 | ## 🧿 Maintenance
353 |
354 | This is the codebase for our research work. We are still working hard to update this repo, and more details will be added in the coming days.
--------------------------------------------------------------------------------
/accelerate_configs/compiled_1.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | distributed_type: 'NO'
4 | downcast_bf16: 'no'
5 | dynamo_config:
6 | dynamo_backend: INDUCTOR
7 | dynamo_mode: max-autotune
8 | dynamo_use_dynamic: true
9 | dynamo_use_fullgraph: false
10 | enable_cpu_affinity: false
11 | gpu_ids: '3'
12 | machine_rank: 0
13 | main_training_function: main
14 | mixed_precision: fp16
15 | num_machines: 1
16 | num_processes: 1
17 | rdzv_backend: static
18 | same_network: true
19 | tpu_env: []
20 | tpu_use_cluster: false
21 | tpu_use_sudo: false
22 | use_cpu: false
23 |
--------------------------------------------------------------------------------
/accelerate_configs/deepspeed.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | gradient_accumulation_steps: 1
5 | gradient_clipping: 1.0
6 | offload_optimizer_device: cpu
7 | offload_param_device: cpu
8 | zero3_init_flag: false
9 | zero_stage: 2
10 | distributed_type: DEEPSPEED
11 | downcast_bf16: 'no'
12 | enable_cpu_affinity: false
13 | machine_rank: 0
14 | main_training_function: main
15 | mixed_precision: bf16
16 | num_machines: 1
17 | num_processes: 8
18 | rdzv_backend: static
19 | same_network: true
20 | tpu_env: []
21 | tpu_use_cluster: false
22 | tpu_use_sudo: false
23 | use_cpu: false
24 |
--------------------------------------------------------------------------------
/assets/human1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gulucaptain/DynamiCtrl/15eb5e9069f9a99c3894c59ad4b4125a1da12e82/assets/human1.jpg
--------------------------------------------------------------------------------
/assets/human2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gulucaptain/DynamiCtrl/15eb5e9069f9a99c3894c59ad4b4125a1da12e82/assets/human2.jpg
--------------------------------------------------------------------------------
/assets/human3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gulucaptain/DynamiCtrl/15eb5e9069f9a99c3894c59ad4b4125a1da12e82/assets/human3.jpg
--------------------------------------------------------------------------------
/assets/human4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gulucaptain/DynamiCtrl/15eb5e9069f9a99c3894c59ad4b4125a1da12e82/assets/human4.jpg
--------------------------------------------------------------------------------
/assets/human5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gulucaptain/DynamiCtrl/15eb5e9069f9a99c3894c59ad4b4125a1da12e82/assets/human5.jpg
--------------------------------------------------------------------------------
/assets/motion1.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gulucaptain/DynamiCtrl/15eb5e9069f9a99c3894c59ad4b4125a1da12e82/assets/motion1.mp4
--------------------------------------------------------------------------------
/assets/motion2.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gulucaptain/DynamiCtrl/15eb5e9069f9a99c3894c59ad4b4125a1da12e82/assets/motion2.mp4
--------------------------------------------------------------------------------
/assets/motion3.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gulucaptain/DynamiCtrl/15eb5e9069f9a99c3894c59ad4b4125a1da12e82/assets/motion3.mp4
--------------------------------------------------------------------------------
/assets/person1_img.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gulucaptain/DynamiCtrl/15eb5e9069f9a99c3894c59ad4b4125a1da12e82/assets/person1_img.jpg
--------------------------------------------------------------------------------
/assets/person2_img.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gulucaptain/DynamiCtrl/15eb5e9069f9a99c3894c59ad4b4125a1da12e82/assets/person2_img.jpg
--------------------------------------------------------------------------------
/checkpoints/Put pre-trained model here.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gulucaptain/DynamiCtrl/15eb5e9069f9a99c3894c59ad4b4125a1da12e82/checkpoints/Put pre-trained model here.txt
--------------------------------------------------------------------------------
/docs/finetune.md:
--------------------------------------------------------------------------------
1 | ## Training
2 |
3 | For model training and fine-tuning, you can directly use the "dynamictrl" environment.
4 |
5 | ### Data Preparation
6 |
7 | #### Folder structure:
8 | ```
9 | Training_data/
10 | │
11 | ├── videos/
12 | │ ├── 0000001.mp4
13 | │ └── 0000002.mp4
14 | │
15 | ├── poses/
16 | │ ├── 0000001.mp4
17 | │ └── 0000002.mp4
18 | │
19 | ├── prompts.txt
20 | │
21 | └── video.txt
22 | ```
23 |
24 | Tips: Use a pose estimation algorithm, such as DWPose, to extract the human pose for each training video. To obtain the prompt for a training video, we select one frame from it and use Qwen2-VL to describe the image content, including the person's appearance and the background details.
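For example, a minimal sketch (an assumed workflow, not an official preprocessing script of this repo) that renders a DWPose video for one clip with the utilities in `./dwpose`; it assumes the DWPose ONNX checkpoints are already in place as described in the README, uses the clip's first frame as the rescaling reference, and writes at an assumed 25 fps:

```python
# Sketch: render the DWPose sequence of one training clip to Training_data/poses/.
# Run from the repo root so that the local "dwpose" package is importable.
import cv2
import decord
import numpy as np

from dwpose.preprocess import get_video_pose

video_path = "Training_data/videos/0000001.mp4"   # example clip
pose_path = "Training_data/poses/0000001.mp4"     # output pose video

# first frame (RGB, H x W x 3) as the reference image for pose rescaling
first_frame = decord.VideoReader(video_path, ctx=decord.cpu(0))[0].asnumpy()

# (num_frames, 3, H, W) uint8 RGB pose renderings
pose_frames = get_video_pose(video_path, first_frame, sample_stride=1)

h, w = pose_frames.shape[2], pose_frames.shape[3]
writer = cv2.VideoWriter(pose_path, cv2.VideoWriter_fourcc(*"mp4v"), 25, (w, h))
for frame in pose_frames:
    bgr = cv2.cvtColor(np.ascontiguousarray(frame.transpose(1, 2, 0)), cv2.COLOR_RGB2BGR)
    writer.write(bgr)
writer.release()
```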
25 |
26 | #### Data format of video.txt
27 | ```
28 | {
29 | /home/user/data/Training_data/videos/0000001.mp4
30 | /home/user/data/Training_data/videos/0000002.mp4
31 | ...
32 | }
33 | ```
34 |
35 | #### Data format of prompts.txt
36 | ```
37 | {
38 | 0000001.mp4#####The video describes xxx. The human xxx. The background xxx.
39 | 0000002.mp4#####The video describes xxx. The human xxx. The background xxx.
40 | ...
41 | }
42 | ```
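As a reference, a minimal sketch (paths are illustrative, the surrounding braces in the examples above are treated as illustrative, and the caption below is a placeholder that would in practice come from Qwen2-VL) that writes `video.txt` and `prompts.txt` in the formats above, one entry per line:

```python
# Sketch: build video.txt and prompts.txt for every clip under Training_data/videos.
import os

data_root = "/home/user/data/Training_data"   # example path
video_dir = os.path.join(data_root, "videos")
video_names = sorted(f for f in os.listdir(video_dir) if f.endswith(".mp4"))

with open(os.path.join(data_root, "video.txt"), "w") as f:
    for name in video_names:
        f.write(os.path.join(video_dir, name) + "\n")

with open(os.path.join(data_root, "prompts.txt"), "w") as f:
    for name in video_names:
        caption = "The video describes xxx. The human xxx. The background xxx."  # placeholder caption
        f.write(f"{name}#####{caption}\n")
```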
43 |
44 | ### Fine-tuning:
45 |
46 | ```bash
47 | export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
48 | export TORCHDYNAMO_VERBOSE=1
49 | export WANDB_MODE="offline"
50 | export NCCL_P2P_DISABLE=1
51 | export TORCH_NCCL_ENABLE_MONITORING=0
52 | export TOKENIZERS_PARALLELISM=true
53 | export OMP_NUM_THREADS=16
54 | export DS_SKIP_CUDA_CHECK=1
55 |
56 | GPU_IDS="0,1,2,3,4,5,6,7"
57 |
58 | # Training Configurations
59 | # Experiment with as many hyperparameters as you want!
60 | LEARNING_RATES=("5e-6")
61 | LR_SCHEDULES=("cosine_with_restarts")
62 | OPTIMIZERS=("adamw")
63 | MAX_TRAIN_STEPS=("50000")
64 | POSE_CONTROL_FUNCTION="padaln"
65 |
66 | # Multi-GPU training with DeepSpeed (see accelerate_configs/deepspeed.yaml)
67 | ACCELERATE_CONFIG_FILE="accelerate_configs/deepspeed.yaml"
68 |
69 | # Absolute path to where the data is located. Make sure to have read the README for how to prepare data.
70 | # The folder is expected to follow the structure described in the Data Preparation section above.
71 | DATA_ROOT="/home/user/data/Training_data"
72 | CAPTION_COLUMN="prompts.txt"
73 | VIDEO_COLUMN="video.txt"
74 | MODEL_PATH="./checkpoints/DynamiCtrl"
75 |
76 | # Set ` --load_tensors ` to load tensors from disk instead of recomputing the encoder process.
77 | # Launch experiments with different hyperparameters
78 |
79 | TRAINING_NAME="training_name"
80 |
81 | for learning_rate in "${LEARNING_RATES[@]}"; do
82 | for lr_schedule in "${LR_SCHEDULES[@]}"; do
83 | for optimizer in "${OPTIMIZERS[@]}"; do
84 | for steps in "${MAX_TRAIN_STEPS[@]}"; do
85 | output_dir="/home/user/exps/train/dynamictrl-sft_${TRAINING_NAME}_${optimizer}_steps_${steps}_lr-schedule_${lr_schedule}_learning-rate_${learning_rate}/"
86 |
87 | cmd="accelerate launch --config_file $ACCELERATE_CONFIG_FILE \
88 | --gpu_ids $GPU_IDS --main_process_port 10086 \
89 | scripts/dynamictrl_sft_finetune.py \
90 | --pretrained_model_name_or_path $MODEL_PATH \
91 | --data_root $DATA_ROOT \
92 | --caption_column $CAPTION_COLUMN \
93 | --video_column $VIDEO_COLUMN \
94 | --id_token BW_STYLE \
95 | --height_buckets 1360 \
96 | --width_buckets 768 \
97 | --height 1360 \
98 | --width 768 \
99 | --frame_buckets 37 \
100 | --dataloader_num_workers 1 \
101 | --pin_memory \
102 | --enable_control_pose \
103 | --pose_control_function $POSE_CONTROL_FUNCTION \
104 | --validation_prompt \"input the test prompt here.\" \
105 | --validation_images \"/home/user/data/dynamictrl_train/test.jpg\" \
106 | --validation_driving_videos \"/home/user/data/dynamictrl_train/test_driving_video.mp4\" \
107 | --validation_prompt_separator ::: \
108 | --num_validation_videos 1 \
109 | --validation_epochs 100000 \
110 | --validation_steps 100000 \
111 | --seed 42 \
112 | --mixed_precision bf16 \
113 | --output_dir $output_dir \
114 | --max_num_frames 37 \
115 | --train_batch_size 64 \
116 | --max_train_steps $steps \
117 | --checkpointing_steps 1000 \
118 | --gradient_checkpointing \
119 | --gradient_accumulation_steps 4 \
120 | --learning_rate $learning_rate \
121 | --lr_scheduler $lr_schedule \
122 | --lr_warmup_steps 1000 \
123 | --lr_num_cycles 1 \
124 | --enable_slicing \
125 | --enable_tiling \
126 | --noised_image_dropout 0.05 \
127 | --optimizer $optimizer \
128 | --beta1 0.9 \
129 | --beta2 0.95 \
130 | --weight_decay 0.001 \
131 | --max_grad_norm 1.0 \
132 | --allow_tf32 \
133 | --report_to tensorboard \
134 | --nccl_timeout 1800"
135 |
136 | echo "Running command: $cmd"
137 | eval $cmd
138 | echo -ne "-------------------- Finished executing script --------------------\n\n"
139 | done
140 | done
141 | done
142 | done
143 |
144 | ```
--------------------------------------------------------------------------------
/dwpose/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gulucaptain/DynamiCtrl/15eb5e9069f9a99c3894c59ad4b4125a1da12e82/dwpose/__init__.py
--------------------------------------------------------------------------------
/dwpose/dwpose_detector.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import torch
5 |
6 | from .wholebody import Wholebody
7 |
8 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10 |
11 | class DWposeDetector:
12 | """
13 | A pose detection method for image-like data.
14 |
15 | Parameters:
16 | model_det: (str) serialized ONNX format model path,
17 | such as https://huggingface.co/yzd-v/DWPose/blob/main/yolox_l.onnx
18 | model_pose: (str) serialized ONNX format model path,
19 | such as https://huggingface.co/yzd-v/DWPose/blob/main/dw-ll_ucoco_384.onnx
20 | device: (str) 'cpu' or 'cuda:{device_id}'
21 | """
22 | def __init__(self, model_det, model_pose, device='cpu'):
23 | self.args = model_det, model_pose, device
24 |
25 | def release_memory(self):
26 | if hasattr(self, 'pose_estimation'):
27 | del self.pose_estimation
28 | import gc; gc.collect()
29 |
30 | def __call__(self, oriImg):
31 | if not hasattr(self, 'pose_estimation'):
32 | self.pose_estimation = Wholebody(*self.args)
33 |
34 | oriImg = oriImg.copy()
35 | H, W, C = oriImg.shape
36 | with torch.no_grad():
37 | candidate, score = self.pose_estimation(oriImg)
38 | nums, _, locs = candidate.shape
39 | candidate[..., 0] /= float(W)
40 | candidate[..., 1] /= float(H)
41 | body = candidate[:, :18].copy()
42 | body = body.reshape(nums * 18, locs)
43 | subset = score[:, :18].copy()
44 | for i in range(len(subset)):
45 | for j in range(len(subset[i])):
46 | if subset[i][j] > 0.3:
47 | subset[i][j] = int(18 * i + j)
48 | else:
49 | subset[i][j] = -1
50 |
51 | # un_visible = subset < 0.3
52 | # candidate[un_visible] = -1
53 |
54 | # foot = candidate[:, 18:24]
55 |
56 | faces = candidate[:, 24:92]
57 |
58 | hands = candidate[:, 92:113]
59 | hands = np.vstack([hands, candidate[:, 113:]])
60 |
61 | faces_score = score[:, 24:92]
62 | hands_score = np.vstack([score[:, 92:113], score[:, 113:]])
63 |
64 | bodies = dict(candidate=body, subset=subset, score=score[:, :18])
65 | pose = dict(bodies=bodies, hands=hands, hands_score=hands_score, faces=faces, faces_score=faces_score)
66 |
67 | return pose
68 |
69 | dwpose_detector = DWposeDetector(
70 | model_det="pretrained/DWPose/yolox_l.onnx",
71 | model_pose="pretrained/DWPose/dw-ll_ucoco_384.onnx",
72 | device=device
73 | )
74 |
--------------------------------------------------------------------------------
/dwpose/onnxdet.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 |
5 | def nms(boxes, scores, nms_thr):
6 | """Single class NMS implemented in Numpy.
7 |
8 | Args:
9 | boxes (np.ndarray): shape=(N,4); N is number of boxes
10 | scores (np.ndarray): the score of bboxes
11 | nms_thr (float): the threshold in NMS
12 |
13 | Returns:
14 | List[int]: output bbox ids
15 | """
16 | x1 = boxes[:, 0]
17 | y1 = boxes[:, 1]
18 | x2 = boxes[:, 2]
19 | y2 = boxes[:, 3]
20 |
21 | areas = (x2 - x1 + 1) * (y2 - y1 + 1)
22 | order = scores.argsort()[::-1]
23 |
24 | keep = []
25 | while order.size > 0:
26 | i = order[0]
27 | keep.append(i)
28 | xx1 = np.maximum(x1[i], x1[order[1:]])
29 | yy1 = np.maximum(y1[i], y1[order[1:]])
30 | xx2 = np.minimum(x2[i], x2[order[1:]])
31 | yy2 = np.minimum(y2[i], y2[order[1:]])
32 |
33 | w = np.maximum(0.0, xx2 - xx1 + 1)
34 | h = np.maximum(0.0, yy2 - yy1 + 1)
35 | inter = w * h
36 | ovr = inter / (areas[i] + areas[order[1:]] - inter)
37 |
38 | inds = np.where(ovr <= nms_thr)[0]
39 | order = order[inds + 1]
40 |
41 | return keep
42 |
43 | def multiclass_nms(boxes, scores, nms_thr, score_thr):
44 | """Multiclass NMS implemented in Numpy. Class-aware version.
45 |
46 | Args:
47 | boxes (np.ndarray): shape=(N,4); N is number of boxes
48 | scores (np.ndarray): the score of bboxes
49 | nms_thr (float): the threshold in NMS
50 | score_thr (float): the threshold of cls score
51 |
52 | Returns:
53 | np.ndarray: outputs bboxes coordinate
54 | """
55 | final_dets = []
56 | num_classes = scores.shape[1]
57 | for cls_ind in range(num_classes):
58 | cls_scores = scores[:, cls_ind]
59 | valid_score_mask = cls_scores > score_thr
60 | if valid_score_mask.sum() == 0:
61 | continue
62 | else:
63 | valid_scores = cls_scores[valid_score_mask]
64 | valid_boxes = boxes[valid_score_mask]
65 | keep = nms(valid_boxes, valid_scores, nms_thr)
66 | if len(keep) > 0:
67 | cls_inds = np.ones((len(keep), 1)) * cls_ind
68 | dets = np.concatenate(
69 | [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
70 | )
71 | final_dets.append(dets)
72 | if len(final_dets) == 0:
73 | return None
74 | return np.concatenate(final_dets, 0)
75 |
76 | def demo_postprocess(outputs, img_size, p6=False):
77 | grids = []
78 | expanded_strides = []
79 | strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
80 |
81 | hsizes = [img_size[0] // stride for stride in strides]
82 | wsizes = [img_size[1] // stride for stride in strides]
83 |
84 | for hsize, wsize, stride in zip(hsizes, wsizes, strides):
85 | xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
86 | grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
87 | grids.append(grid)
88 | shape = grid.shape[:2]
89 | expanded_strides.append(np.full((*shape, 1), stride))
90 |
91 | grids = np.concatenate(grids, 1)
92 | expanded_strides = np.concatenate(expanded_strides, 1)
93 | outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
94 | outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
95 |
96 | return outputs
97 |
98 | def preprocess(img, input_size, swap=(2, 0, 1)):
99 | if len(img.shape) == 3:
100 | padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
101 | else:
102 | padded_img = np.ones(input_size, dtype=np.uint8) * 114
103 |
104 | r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
105 | resized_img = cv2.resize(
106 | img,
107 | (int(img.shape[1] * r), int(img.shape[0] * r)),
108 | interpolation=cv2.INTER_LINEAR,
109 | ).astype(np.uint8)
110 | padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
111 |
112 | padded_img = padded_img.transpose(swap)
113 | padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
114 | return padded_img, r
115 |
116 | def inference_detector(session, oriImg):
117 | """run human detect
118 | """
119 | input_shape = (640,640)
120 | img, ratio = preprocess(oriImg, input_shape)
121 |
122 | ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
123 | output = session.run(None, ort_inputs)
124 | predictions = demo_postprocess(output[0], input_shape)[0]
125 |
126 | boxes = predictions[:, :4]
127 | scores = predictions[:, 4:5] * predictions[:, 5:]
128 |
129 | boxes_xyxy = np.ones_like(boxes)
130 | boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2.
131 | boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2.
132 | boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2.
133 | boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2.
134 | boxes_xyxy /= ratio
135 | dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
136 | if dets is not None:
137 | final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
138 | isscore = final_scores>0.3
139 | iscat = final_cls_inds == 0
140 | isbbox = [ i and j for (i, j) in zip(isscore, iscat)]
141 | final_boxes = final_boxes[isbbox]
142 | else:
143 | final_boxes = np.array([])
144 |
145 | return final_boxes
146 |
--------------------------------------------------------------------------------
/dwpose/onnxpose.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 |
3 | import cv2
4 | import numpy as np
5 | import onnxruntime as ort
6 |
7 | def preprocess(
8 | img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256)
9 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
10 | """Do preprocessing for RTMPose model inference.
11 |
12 | Args:
13 | img (np.ndarray): Input image in shape.
14 | input_size (tuple): Input image size in shape (w, h).
15 |
16 | Returns:
17 | tuple:
18 | - resized_img (np.ndarray): Preprocessed image.
19 | - center (np.ndarray): Center of image.
20 | - scale (np.ndarray): Scale of image.
21 | """
22 | # get shape of image
23 | img_shape = img.shape[:2]
24 | out_img, out_center, out_scale = [], [], []
25 | if len(out_bbox) == 0:
26 | out_bbox = [[0, 0, img_shape[1], img_shape[0]]]
27 | for i in range(len(out_bbox)):
28 | x0 = out_bbox[i][0]
29 | y0 = out_bbox[i][1]
30 | x1 = out_bbox[i][2]
31 | y1 = out_bbox[i][3]
32 | bbox = np.array([x0, y0, x1, y1])
33 |
34 | # get center and scale
35 | center, scale = bbox_xyxy2cs(bbox, padding=1.25)
36 |
37 | # do affine transformation
38 | resized_img, scale = top_down_affine(input_size, scale, center, img)
39 |
40 | # normalize image
41 | mean = np.array([123.675, 116.28, 103.53])
42 | std = np.array([58.395, 57.12, 57.375])
43 | resized_img = (resized_img - mean) / std
44 |
45 | out_img.append(resized_img)
46 | out_center.append(center)
47 | out_scale.append(scale)
48 |
49 | return out_img, out_center, out_scale
50 |
51 |
52 | def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray:
53 | """Inference RTMPose model.
54 |
55 | Args:
56 | sess (ort.InferenceSession): ONNXRuntime session.
57 | img (np.ndarray): Input image in shape.
58 |
59 | Returns:
60 | outputs (np.ndarray): Output of RTMPose model.
61 | """
62 | all_out = []
63 | # build input
64 | for i in range(len(img)):
65 | input = [img[i].transpose(2, 0, 1)]
66 |
67 | # build output
68 | sess_input = {sess.get_inputs()[0].name: input}
69 | sess_output = []
70 | for out in sess.get_outputs():
71 | sess_output.append(out.name)
72 |
73 | # run model
74 | outputs = sess.run(sess_output, sess_input)
75 | all_out.append(outputs)
76 |
77 | return all_out
78 |
79 |
80 | def postprocess(outputs: List[np.ndarray],
81 | model_input_size: Tuple[int, int],
82 | center: Tuple[int, int],
83 | scale: Tuple[int, int],
84 | simcc_split_ratio: float = 2.0
85 | ) -> Tuple[np.ndarray, np.ndarray]:
86 | """Postprocess for RTMPose model output.
87 |
88 | Args:
89 | outputs (np.ndarray): Output of RTMPose model.
90 | model_input_size (tuple): RTMPose model Input image size.
91 | center (tuple): Center of bbox in shape (x, y).
92 | scale (tuple): Scale of bbox in shape (w, h).
93 | simcc_split_ratio (float): Split ratio of simcc.
94 |
95 | Returns:
96 | tuple:
97 | - keypoints (np.ndarray): Rescaled keypoints.
98 | - scores (np.ndarray): Model predict scores.
99 | """
100 | all_key = []
101 | all_score = []
102 | for i in range(len(outputs)):
103 | # use simcc to decode
104 | simcc_x, simcc_y = outputs[i]
105 | keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)
106 |
107 | # rescale keypoints
108 | keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2
109 | all_key.append(keypoints[0])
110 | all_score.append(scores[0])
111 |
112 | return np.array(all_key), np.array(all_score)
113 |
114 |
115 | def bbox_xyxy2cs(bbox: np.ndarray,
116 | padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]:
117 | """Transform the bbox format from (x,y,w,h) into (center, scale)
118 |
119 | Args:
120 | bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
121 | as (left, top, right, bottom)
122 | padding (float): BBox padding factor that will be multiplied with the scale.
123 | Default: 1.0
124 |
125 | Returns:
126 | tuple: A tuple containing center and scale.
127 | - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
128 | (n, 2)
129 | - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
130 | (n, 2)
131 | """
132 | # convert single bbox from (4, ) to (1, 4)
133 | dim = bbox.ndim
134 | if dim == 1:
135 | bbox = bbox[None, :]
136 |
137 | # get bbox center and scale
138 | x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
139 | center = np.hstack([x1 + x2, y1 + y2]) * 0.5
140 | scale = np.hstack([x2 - x1, y2 - y1]) * padding
141 |
142 | if dim == 1:
143 | center = center[0]
144 | scale = scale[0]
145 |
146 | return center, scale
147 |
148 |
149 | def _fix_aspect_ratio(bbox_scale: np.ndarray,
150 | aspect_ratio: float) -> np.ndarray:
151 | """Extend the scale to match the given aspect ratio.
152 |
153 | Args:
154 | scale (np.ndarray): The image scale (w, h) in shape (2, )
155 | aspect_ratio (float): The ratio of ``w/h``
156 |
157 | Returns:
158 | np.ndarray: The reshaped image scale in (2, )
159 | """
160 | w, h = np.hsplit(bbox_scale, [1])
161 | bbox_scale = np.where(w > h * aspect_ratio,
162 | np.hstack([w, w / aspect_ratio]),
163 | np.hstack([h * aspect_ratio, h]))
164 | return bbox_scale
165 |
166 |
167 | def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
168 | """Rotate a point by an angle.
169 |
170 | Args:
171 | pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
172 | angle_rad (float): rotation angle in radian
173 |
174 | Returns:
175 | np.ndarray: Rotated point in shape (2, )
176 | """
177 | sn, cs = np.sin(angle_rad), np.cos(angle_rad)
178 | rot_mat = np.array([[cs, -sn], [sn, cs]])
179 | return rot_mat @ pt
180 |
181 |
182 | def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray:
183 | """To calculate the affine matrix, three pairs of points are required. This
184 | function is used to get the 3rd point, given 2D points a & b.
185 |
186 | The 3rd point is defined by rotating vector `a - b` by 90 degrees
187 | anticlockwise, using b as the rotation center.
188 |
189 | Args:
190 | a (np.ndarray): The 1st point (x,y) in shape (2, )
191 | b (np.ndarray): The 2nd point (x,y) in shape (2, )
192 |
193 | Returns:
194 | np.ndarray: The 3rd point.
195 | """
196 | direction = a - b
197 | c = b + np.r_[-direction[1], direction[0]]
198 | return c
199 |
200 |
201 | def get_warp_matrix(center: np.ndarray,
202 | scale: np.ndarray,
203 | rot: float,
204 | output_size: Tuple[int, int],
205 | shift: Tuple[float, float] = (0., 0.),
206 | inv: bool = False) -> np.ndarray:
207 | """Calculate the affine transformation matrix that can warp the bbox area
208 | in the input image to the output size.
209 |
210 | Args:
211 | center (np.ndarray[2, ]): Center of the bounding box (x, y).
212 | scale (np.ndarray[2, ]): Scale of the bounding box
213 | wrt [width, height].
214 | rot (float): Rotation angle (degree).
215 | output_size (np.ndarray[2, ] | list(2,)): Size of the
216 | destination heatmaps.
217 | shift (0-100%): Shift translation ratio wrt the width/height.
218 | Default (0., 0.).
219 | inv (bool): Option to inverse the affine transform direction.
220 | (inv=False: src->dst or inv=True: dst->src)
221 |
222 | Returns:
223 | np.ndarray: A 2x3 transformation matrix
224 | """
225 | shift = np.array(shift)
226 | src_w = scale[0]
227 | dst_w = output_size[0]
228 | dst_h = output_size[1]
229 |
230 | # compute transformation matrix
231 | rot_rad = np.deg2rad(rot)
232 | src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad)
233 | dst_dir = np.array([0., dst_w * -0.5])
234 |
235 | # get four corners of the src rectangle in the original image
236 | src = np.zeros((3, 2), dtype=np.float32)
237 | src[0, :] = center + scale * shift
238 | src[1, :] = center + src_dir + scale * shift
239 | src[2, :] = _get_3rd_point(src[0, :], src[1, :])
240 |
241 | # get four corners of the dst rectangle in the input image
242 | dst = np.zeros((3, 2), dtype=np.float32)
243 | dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
244 | dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
245 | dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
246 |
247 | if inv:
248 | warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
249 | else:
250 | warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))
251 |
252 | return warp_mat
253 |
254 |
255 | def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict,
256 | img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
257 | """Get the bbox image as the model input by affine transform.
258 |
259 | Args:
260 | input_size (dict): The input size of the model.
261 | bbox_scale (dict): The bbox scale of the img.
262 | bbox_center (dict): The bbox center of the img.
263 | img (np.ndarray): The original image.
264 |
265 | Returns:
266 | tuple: A tuple containing center and scale.
267 | - np.ndarray[float32]: img after affine transform.
268 | - np.ndarray[float32]: bbox scale after affine transform.
269 | """
270 | w, h = input_size
271 | warp_size = (int(w), int(h))
272 |
273 | # reshape bbox to fixed aspect ratio
274 | bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)
275 |
276 | # get the affine matrix
277 | center = bbox_center
278 | scale = bbox_scale
279 | rot = 0
280 | warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))
281 |
282 | # do affine transform
283 | img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR)
284 |
285 | return img, bbox_scale
286 |
287 |
288 | def get_simcc_maximum(simcc_x: np.ndarray,
289 | simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
290 | """Get maximum response location and value from simcc representations.
291 |
292 | Note:
293 | instance number: N
294 | num_keypoints: K
295 | heatmap height: H
296 | heatmap width: W
297 |
298 | Args:
299 | simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
300 | simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)
301 |
302 | Returns:
303 | tuple:
304 | - locs (np.ndarray): locations of maximum heatmap responses in shape
305 | (K, 2) or (N, K, 2)
306 | - vals (np.ndarray): values of maximum heatmap responses in shape
307 | (K,) or (N, K)
308 | """
309 | N, K, Wx = simcc_x.shape
310 | simcc_x = simcc_x.reshape(N * K, -1)
311 | simcc_y = simcc_y.reshape(N * K, -1)
312 |
313 | # get maximum value locations
314 | x_locs = np.argmax(simcc_x, axis=1)
315 | y_locs = np.argmax(simcc_y, axis=1)
316 | locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
317 | max_val_x = np.amax(simcc_x, axis=1)
318 | max_val_y = np.amax(simcc_y, axis=1)
319 |
320 | # get maximum value across x and y axis
321 | mask = max_val_x > max_val_y
322 | max_val_x[mask] = max_val_y[mask]
323 | vals = max_val_x
324 | locs[vals <= 0.] = -1
325 |
326 | # reshape
327 | locs = locs.reshape(N, K, 2)
328 | vals = vals.reshape(N, K)
329 |
330 | return locs, vals
331 |
332 |
333 | def decode(simcc_x: np.ndarray, simcc_y: np.ndarray,
334 | simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]:
335 | """Modulate simcc distribution with Gaussian.
336 |
337 | Args:
338 | simcc_x (np.ndarray[K, Wx]): model predicted simcc in x.
339 | simcc_y (np.ndarray[K, Wy]): model predicted simcc in y.
340 | simcc_split_ratio (int): The split ratio of simcc.
341 |
342 | Returns:
343 | tuple: A tuple containing center and scale.
344 | - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2)
345 | - np.ndarray[float32]: scores in shape (K,) or (n, K)
346 | """
347 | keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
348 | keypoints /= simcc_split_ratio
349 |
350 | return keypoints, scores
351 |
352 |
353 | def inference_pose(session, out_bbox, oriImg):
354 | """run pose detect
355 |
356 | Args:
357 | session (ort.InferenceSession): ONNXRuntime session.
358 | out_bbox (np.ndarray): bbox list
359 | oriImg (np.ndarray): Input image in shape.
360 |
361 | Returns:
362 | tuple:
363 | - keypoints (np.ndarray): Rescaled keypoints.
364 | - scores (np.ndarray): Model predict scores.
365 | """
366 | h, w = session.get_inputs()[0].shape[2:]
367 | model_input_size = (w, h)
368 | # preprocess for rtm-pose model inference.
369 | resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size)
370 | # run pose estimation for processed img
371 | outputs = inference(session, resized_img)
372 | # postprocess for rtm-pose model output.
373 | keypoints, scores = postprocess(outputs, model_input_size, center, scale)
374 |
375 | return keypoints, scores
376 |
--------------------------------------------------------------------------------
/dwpose/preprocess.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import decord
3 | import numpy as np
4 |
5 | from .util import draw_pose
6 | from .dwpose_detector import dwpose_detector as dwprocessor
7 |
8 |
9 | def get_video_pose(
10 | video_path: str,
11 | ref_image: np.ndarray,
12 | max_frame_num = None,
13 | sample_stride: int=1):
14 | """preprocess ref image pose and video pose
15 |
16 | Args:
17 | video_path (str): video pose path
18 | ref_image (np.ndarray): reference image
19 | sample_stride (int, optional): Defaults to 1.
20 |
21 | Returns:
22 | np.ndarray: sequence of video pose
23 | """
24 | # select ref-keypoint from reference pose for pose rescale
25 | ref_pose = dwprocessor(ref_image)
26 | ref_keypoint_id = [0, 1, 2, 5, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
27 | ref_keypoint_id = [i for i in ref_keypoint_id \
28 | if len(ref_pose['bodies']['subset']) > 0 and ref_pose['bodies']['subset'][0][i] >= .0]
29 | ref_body = ref_pose['bodies']['candidate'][ref_keypoint_id]
30 |
31 | height, width, _ = ref_image.shape
32 |
33 | # read input video
34 | vr = decord.VideoReader(video_path, ctx=decord.cpu(0))
35 | if max_frame_num is None:
36 | max_frame_num = len(vr)
37 | else:
38 | max_frame_num = min(len(vr), max_frame_num * sample_stride)
39 |
40 | sample_stride = sample_stride
41 | begin_frame_index = 0
42 | frames = vr.get_batch(list(range(begin_frame_index, max_frame_num * sample_stride + begin_frame_index, sample_stride))).asnumpy()
43 | detected_poses = [dwprocessor(frm) for frm in tqdm(frames, desc="DWPose")]
44 |
45 | dwprocessor.release_memory()
46 | detected_bodies = np.stack([p['bodies']['candidate'] for p in detected_poses if p['bodies']['candidate'].shape[0] == 18])[:, ref_keypoint_id]
47 |
48 | keep_person_body_unchange = False
49 | if keep_person_body_unchange:
50 | output_pose = []
51 | for i in range(len(detected_poses)):
52 | ay, by = np.polyfit(detected_bodies[i, :, 1].flatten(), np.tile(ref_body[:, 1], 1), 1)
53 | fh, fw, _ = vr[0].shape
54 | ax = ay / (fh / fw / height * width)
55 | bx = np.mean(np.tile(ref_body[:, 0], 1) - detected_bodies[i, :, 0].flatten() * ax)
56 | a = np.array([ax, ay])
57 | b = np.array([bx, by])
58 |
59 | detected_poses[i]['bodies']['candidate'] = detected_poses[i]['bodies']['candidate'] * a + b
60 | detected_poses[i]['faces'] = detected_poses[i]['faces'] * a + b
61 | detected_poses[i]['hands'] = detected_poses[i]['hands'] * a + b
62 | im = draw_pose(detected_poses[i], height, width)
63 | output_pose.append(np.array(im))
64 | return np.stack(output_pose)
65 |
66 | # compute linear-rescale params
67 | ay, by = np.polyfit(detected_bodies[:, :, 1].flatten(), np.tile(ref_body[:, 1], len(detected_bodies)), 1)
68 | if ref_body[:, 1][0] > detected_bodies[0, :, 1][0]:
69 | by = ref_body[:, 1][0] - detected_bodies[0, :, 1][0]
70 | elif ref_body[:, 1][0] < detected_bodies[0, :, 1][0]:
71 | by = -(detected_bodies[0, :, 1][0] - ref_body[:, 1][0])
72 |
73 | fh, fw, _ = vr[0].shape
74 | ax = ay / (fh / fw / height * width)
75 | bx = np.mean(np.tile(ref_body[:, 0], len(detected_bodies)) - detected_bodies[:, :, 0].flatten() * ax)
76 | a = np.array([ax, ay])
77 | b = np.array([bx, by])
78 |
79 | re_ensure_by_candidate = detected_poses[0]['bodies']['candidate'] * a + b
80 | if ref_body[:, 1][0] > re_ensure_by_candidate[0][1]:
81 | by = by + ref_body[:, 1][0] - re_ensure_by_candidate[0][1]
82 | elif ref_body[:, 1][0] < re_ensure_by_candidate[0][1]:
83 | by = by - (re_ensure_by_candidate[0][1] - ref_body[:, 1][0])
84 | b = np.array([bx, by])
85 |
86 | output_pose = []
87 | detected_pose_index = 0
88 | for detected_pose in detected_poses:
89 | detected_pose['bodies']['candidate'] = detected_pose['bodies']['candidate'] * a + b
90 | detected_pose['faces'] = detected_pose['faces'] * a + b
91 | detected_pose['hands'] = detected_pose['hands'] * a + b
92 | im = draw_pose(detected_pose, height, width)
93 | output_pose.append(np.array(im))
94 | detected_pose_index += 1
95 | return np.stack(output_pose)
96 |
97 |
98 | def get_image_pose(ref_image):
99 | """process image pose
100 |
101 | Args:
102 | ref_image (np.ndarray): reference image pixel value
103 |
104 | Returns:
105 | np.ndarray: pose visual image in RGB-mode
106 | """
107 | height, width, _ = ref_image.shape
108 | ref_pose = dwprocessor(ref_image)
109 | pose_img = draw_pose(ref_pose, height, width)
110 | return np.array(pose_img)
111 |
--------------------------------------------------------------------------------
/dwpose/util.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import matplotlib
4 | import cv2
5 |
6 |
7 | eps = 0.01
8 |
9 | def alpha_blend_color(color, alpha):
10 | """blend color according to point conf
11 | """
12 | return [int(c * alpha) for c in color]
13 |
14 | def draw_bodypose(canvas, candidate, subset, score):
15 | H, W, C = canvas.shape
16 | candidate = np.array(candidate)
17 | subset = np.array(subset)
18 |
19 | stickwidth = 4
20 |
21 | limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
22 | [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
23 | [1, 16], [16, 18], [3, 17], [6, 18]]
24 |
25 | colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
26 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
27 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
28 |
29 | for i in range(17):
30 | for n in range(len(subset)):
31 | index = subset[n][np.array(limbSeq[i]) - 1]
32 | conf = score[n][np.array(limbSeq[i]) - 1]
33 | if conf[0] < 0.3 or conf[1] < 0.3:
34 | continue
35 | Y = candidate[index.astype(int), 0] * float(W)
36 | X = candidate[index.astype(int), 1] * float(H)
37 | mX = np.mean(X)
38 | mY = np.mean(Y)
39 | length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
40 | angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
41 | polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
42 | cv2.fillConvexPoly(canvas, polygon, alpha_blend_color(colors[i], conf[0] * conf[1]))
43 |
44 | canvas = (canvas * 0.6).astype(np.uint8)
45 |
46 | for i in range(18):
47 | for n in range(len(subset)):
48 | index = int(subset[n][i])
49 | if index == -1:
50 | continue
51 | x, y = candidate[index][0:2]
52 | conf = score[n][i]
53 | x = int(x * W)
54 | y = int(y * H)
55 | cv2.circle(canvas, (int(x), int(y)), 4, alpha_blend_color(colors[i], conf), thickness=-1)
56 |
57 | return canvas
58 |
59 | def draw_handpose(canvas, all_hand_peaks, all_hand_scores):
60 | H, W, C = canvas.shape
61 |
62 | edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
63 | [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
64 |
65 | for peaks, scores in zip(all_hand_peaks, all_hand_scores):
66 |
67 | for ie, e in enumerate(edges):
68 | x1, y1 = peaks[e[0]]
69 | x2, y2 = peaks[e[1]]
70 | x1 = int(x1 * W)
71 | y1 = int(y1 * H)
72 | x2 = int(x2 * W)
73 | y2 = int(y2 * H)
74 | score = int(scores[e[0]] * scores[e[1]] * 255)
75 | if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
76 | cv2.line(canvas, (x1, y1), (x2, y2),
77 | matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * score, thickness=2)
78 |
79 |         for i, keypoint in enumerate(peaks):
80 |             x, y = keypoint
81 | x = int(x * W)
82 | y = int(y * H)
83 | score = int(scores[i] * 255)
84 | if x > eps and y > eps:
85 | cv2.circle(canvas, (x, y), 4, (0, 0, score), thickness=-1)
86 | return canvas
87 |
88 | def draw_facepose(canvas, all_lmks, all_scores):
89 | H, W, C = canvas.shape
90 | for lmks, scores in zip(all_lmks, all_scores):
91 | for lmk, score in zip(lmks, scores):
92 | x, y = lmk
93 | x = int(x * W)
94 | y = int(y * H)
95 | conf = int(score * 255)
96 | if x > eps and y > eps:
97 | cv2.circle(canvas, (x, y), 3, (conf, conf, conf), thickness=-1)
98 | return canvas
99 |
100 | def draw_pose(pose, H, W, ref_w=2160):
101 | """vis dwpose outputs
102 |
103 | Args:
104 | pose (List): DWposeDetector outputs in dwpose_detector.py
105 | H (int): height
106 | W (int): width
107 | ref_w (int, optional): Defaults to 2160.
108 |
109 | Returns:
110 | np.ndarray: image pixel value in RGB mode
111 | """
112 | bodies = pose['bodies']
113 | faces = pose['faces']
114 | hands = pose['hands']
115 | candidate = bodies['candidate']
116 | subset = bodies['subset']
117 |
118 | sz = min(H, W)
119 | sr = (ref_w / sz) if sz != ref_w else 1
120 |
121 | ########################################## create zero canvas ##################################################
122 | canvas = np.zeros(shape=(int(H*sr), int(W*sr), 3), dtype=np.uint8)
123 |
124 | ########################################### draw body pose #####################################################
125 | canvas = draw_bodypose(canvas, candidate, subset, score=bodies['score'])
126 |
127 | ########################################### draw hand pose #####################################################
128 | canvas = draw_handpose(canvas, hands, pose['hands_score'])
129 |
130 | ########################################### draw face pose #####################################################
131 | canvas = draw_facepose(canvas, faces, pose['faces_score'])
132 |
133 | return cv2.cvtColor(cv2.resize(canvas, (W, H)), cv2.COLOR_BGR2RGB).transpose(2, 0, 1)
134 |
--------------------------------------------------------------------------------
/dwpose/wholebody.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import onnxruntime as ort
3 |
4 | from .onnxdet import inference_detector
5 | from .onnxpose import inference_pose
6 |
7 |
8 | class Wholebody:
9 | """detect human pose by dwpose
10 | """
11 | def __init__(self, model_det, model_pose, device="cpu"):
12 | providers = ['CPUExecutionProvider'] if device == 'cpu' else ['CUDAExecutionProvider']
13 | provider_options = None if device == 'cpu' else [{'device_id': 0}]
14 |
15 | model_det = model_det or "../checkpoints/DWPose/yolox_l.onnx"  # fall back to the bundled checkpoint when no path is given
16 | model_pose = model_pose or "../checkpoints/DWPose/dw-ll_ucoco_384.onnx"
17 |
18 | self.session_det = ort.InferenceSession(
19 | path_or_bytes=model_det, providers=providers, provider_options=provider_options
20 | )
21 | self.session_pose = ort.InferenceSession(
22 | path_or_bytes=model_pose, providers=providers, provider_options=provider_options
23 | )
24 |
25 | def __call__(self, oriImg):
26 | """call to process dwpose-detect
27 |
28 | Args:
29 | oriImg (np.ndarray): detected image
30 |
31 | """
32 | det_result = inference_detector(self.session_det, oriImg)
33 | keypoints, scores = inference_pose(self.session_pose, det_result, oriImg)
34 |
35 | keypoints_info = np.concatenate(
36 | (keypoints, scores[..., None]), axis=-1)
37 | # compute neck joint
38 | neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
39 | # neck score when visualizing pred
40 | neck[:, 2:4] = np.logical_and(
41 | keypoints_info[:, 5, 2:4] > 0.3,
42 | keypoints_info[:, 6, 2:4] > 0.3).astype(int)
43 | new_keypoints_info = np.insert(
44 | keypoints_info, 17, neck, axis=1)
45 | mmpose_idx = [
46 | 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3
47 | ]
48 | openpose_idx = [
49 | 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17
50 | ]
51 | new_keypoints_info[:, openpose_idx] = \
52 | new_keypoints_info[:, mmpose_idx]
53 | keypoints_info = new_keypoints_info
54 |
55 | keypoints, scores = keypoints_info[
56 | ..., :2], keypoints_info[..., 2]
57 |
58 | return keypoints, scores
59 |
--------------------------------------------------------------------------------
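A minimal usage sketch for the `Wholebody` wrapper above. The checkpoint locations are assumptions based on the fallback paths in `__init__`; place `yolox_l.onnx` and `dw-ll_ucoco_384.onnx` under `checkpoints/DWPose/` (or pass your own paths).

import cv2
from dwpose.wholebody import Wholebody  # assumed import path for the file above

detector = Wholebody(
    model_det="checkpoints/DWPose/yolox_l.onnx",         # assumed checkpoint location
    model_pose="checkpoints/DWPose/dw-ll_ucoco_384.onnx",
    device="cuda",                                        # "cpu" to use CPUExecutionProvider
)

img = cv2.imread("assets/human1.jpg")  # BGR image as loaded by OpenCV
keypoints, scores = detector(img)      # per-person keypoints (with the inserted neck joint) and confidences
print(keypoints.shape, scores.shape)
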
/dynamictrl/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gulucaptain/DynamiCtrl/15eb5e9069f9a99c3894c59ad4b4125a1da12e82/dynamictrl/__init__.py
--------------------------------------------------------------------------------
/dynamictrl/models/modules/cross_attention.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Tuple, Union, Optional, Any
3 |
4 | import torch
5 | import torch.utils
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torch.utils.checkpoint as checkpoint
9 | import einops
10 | from einops import repeat
11 |
12 | from diffusers import AutoencoderKL
13 | from timm.models.vision_transformer import Mlp
14 | from timm.models.layers import to_2tuple
15 | from transformers import (
16 | AutoTokenizer,
17 | MT5EncoderModel,
18 | BertModel,
19 | )
20 |
21 | memory_efficient_attention = None
22 | try:
23 | import xformers
24 | except ImportError:
25 | pass
26 |
27 | try:
28 | from xformers.ops import memory_efficient_attention
29 | except ImportError:
30 | memory_efficient_attention = None
31 |
32 | def reshape_for_broadcast(
33 | freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
34 | x: torch.Tensor,
35 | head_first=False,
36 | ):
37 | """
38 | Reshape frequency tensor for broadcasting it with another tensor.
39 |
40 | This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
41 | for the purpose of broadcasting the frequency tensor during element-wise operations.
42 |
43 | Args:
44 | freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
45 | x (torch.Tensor): Target tensor for broadcasting compatibility.
46 | head_first (bool): head dimension first (except batch dim) or not.
47 |
48 | Returns:
49 | torch.Tensor: Reshaped frequency tensor.
50 |
51 | Raises:
52 | AssertionError: If the frequency tensor doesn't match the expected shape.
53 | AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
54 | """
55 | ndim = x.ndim
56 | assert 0 <= 1 < ndim
57 |
58 | if isinstance(freqs_cis, tuple):
59 | # freqs_cis: (cos, sin) in real space
60 | if head_first:
61 | assert freqs_cis[0].shape == (
62 | x.shape[-2],
63 | x.shape[-1],
64 | ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
65 | shape = [
66 | d if i == ndim - 2 or i == ndim - 1 else 1
67 | for i, d in enumerate(x.shape)
68 | ]
69 | else:
70 | assert freqs_cis[0].shape == (
71 | x.shape[1],
72 | x.shape[-1],
73 | ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
74 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
75 | return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
76 | else:
77 | # freqs_cis: values in complex space
78 | if head_first:
79 | assert freqs_cis.shape == (
80 | x.shape[-2],
81 | x.shape[-1],
82 | ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
83 | shape = [
84 | d if i == ndim - 2 or i == ndim - 1 else 1
85 | for i, d in enumerate(x.shape)
86 | ]
87 | else:
88 | assert freqs_cis.shape == (
89 | x.shape[1],
90 | x.shape[-1],
91 | ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
92 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
93 | return freqs_cis.view(*shape)
94 |
95 | MEMORY_LAYOUTS = {
96 | "torch": (
97 | lambda x, head_dim: x.transpose(1, 2),
98 | lambda x: x.transpose(1, 2),
99 | lambda x: (1, x, 1, 1),
100 | ),
101 | "xformers": (
102 | lambda x, head_dim: x,
103 | lambda x: x,
104 | lambda x: (1, 1, x, 1),
105 | ),
106 | "math": (
107 | lambda x, head_dim: x.transpose(1, 2),
108 | lambda x: x.transpose(1, 2),
109 | lambda x: (1, x, 1, 1),
110 | ),
111 | }
112 |
113 | def rotate_half(x):
114 | x_real, x_imag = (
115 | x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
116 | ) # [B, S, H, D//2]
117 | return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
118 |
119 | def apply_rotary_emb(
120 | xq: torch.Tensor,
121 | xk: Optional[torch.Tensor],
122 | freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
123 | head_first: bool = False,
124 | ) -> Tuple[torch.Tensor, torch.Tensor]:
125 | """
126 | Apply rotary embeddings to input tensors using the given frequency tensor.
127 |
128 | This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
129 | frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
130 | is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
131 | returned as real tensors.
132 |
133 | Args:
134 | xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
135 | xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
136 | freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Precomputed frequency tensor for complex exponentials.
137 | head_first (bool): head dimension first (except batch dim) or not.
138 |
139 | Returns:
140 | Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
141 |
142 | """
143 | xk_out = None
144 | if isinstance(freqs_cis, tuple):
145 | cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
146 | cos, sin = cos.to(xq.device), sin.to(xq.device)
147 | xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
148 | if xk is not None:
149 | xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
150 | else:
151 | xq_ = torch.view_as_complex(
152 | xq.float().reshape(*xq.shape[:-1], -1, 2)
153 | ) # [B, S, H, D//2]
154 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
155 | xq.device
156 | ) # [S, D//2] --> [1, S, 1, D//2]
157 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
158 | if xk is not None:
159 | xk_ = torch.view_as_complex(
160 | xk.float().reshape(*xk.shape[:-1], -1, 2)
161 | ) # [B, S, H, D//2]
162 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
163 |
164 | return xq_out, xk_out
165 |
166 | def vanilla_attention(q, k, v, mask, dropout_p, scale=None):
167 | if scale is None:
168 | scale = math.sqrt(q.size(-1))
169 | scores = torch.bmm(q, k.transpose(-1, -2)) / scale
170 | if mask is not None:
171 | mask = einops.rearrange(mask, "b ... -> b (...)")
172 | max_neg_value = -torch.finfo(scores.dtype).max
173 | mask = einops.repeat(mask, "b j -> (b h) j", h=q.size(-3))
174 | scores = scores.masked_fill(~mask, max_neg_value)
175 | p_attn = F.softmax(scores, dim=-1)
176 | if dropout_p != 0:
177 | p_attn = F.dropout(p_attn, p=dropout_p, training=True)  # apply dropout to the attention weights used below
178 | return torch.bmm(p_attn, v)
179 |
180 | def attention(q, k, v, head_dim, dropout_p=0, mask=None, scale=None, mode="xformers"):
181 | """
182 | q, k, v: [B, L, H, D]
183 | """
184 | pre_attn_layout = MEMORY_LAYOUTS[mode][0]
185 | post_attn_layout = MEMORY_LAYOUTS[mode][1]
186 | q = pre_attn_layout(q, head_dim)
187 | k = pre_attn_layout(k, head_dim)
188 | v = pre_attn_layout(v, head_dim)
189 |
190 | # scores = ATTN_FUNCTION[mode](q, k.to(q), v.to(q), mask, scale=scale)
191 | if mode == "torch":
192 | assert scale is None
193 | scores = F.scaled_dot_product_attention(
194 | q, k.to(q), v.to(q), mask, dropout_p
195 | ) # , scale=scale)
196 | elif mode == "xformers":
197 | scores = memory_efficient_attention(
198 | q, k.to(q), v.to(q), mask, dropout_p, scale=scale
199 | )
200 | else:
201 | scores = vanilla_attention(q, k.to(q), v.to(q), mask, dropout_p, scale=scale)
202 |
203 | scores = post_attn_layout(scores)
204 | return scores
205 |
206 |
207 | class CrossAttention(nn.Module):
208 | """
209 | Use QK Normalization.
210 | """
211 |
212 | def __init__(
213 | self,
214 | qdim,
215 | kdim,
216 | num_heads,
217 | qkv_bias=True,
218 | qk_norm=False,
219 | attn_drop=0.0,
220 | proj_drop=0.0,
221 | device=None,
222 | dtype=None,
223 | norm_layer=nn.LayerNorm,
224 | attn_mode="xformers",
225 | ):
226 | factory_kwargs = {"device": device, "dtype": dtype}
227 | super().__init__()
228 | self.qdim = qdim
229 | self.kdim = kdim
230 | self.num_heads = num_heads
231 | assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads"
232 | self.head_dim = self.qdim // num_heads
233 | assert (
234 | self.head_dim % 8 == 0 and self.head_dim <= 192
235 | ), "Only support head_dim <= 128 and divisible by 8"
236 |
237 | self.scale = self.head_dim**-0.5
238 |
239 | self.q_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
240 | self.kv_proj = nn.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs)
241 |
242 | # TODO: eps should be 1 / 65530 if using fp16
243 | self.q_norm = (
244 | norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6)
245 | if qk_norm
246 | else nn.Identity()
247 | )
248 | self.k_norm = (
249 | norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6)
250 | if qk_norm
251 | else nn.Identity()
252 | )
253 |
254 | self.out_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
255 | self.proj_drop = nn.Dropout(proj_drop)
256 | self.attn_drop = attn_drop
257 | self.attn_mode = attn_mode
258 |
259 | def set_attn_mode(self, mode):
260 | self.attn_mode = mode
261 |
262 | def forward(self, x, y, freqs_cis_img=None):
263 | """
264 | Parameters
265 | ----------
266 | x: torch.Tensor
267 | (batch, seqlen1, hidden_dim) (where hidden_dim = num_heads * head_dim)
268 | y: torch.Tensor
269 | (batch, seqlen2, hidden_dim2)
270 | freqs_cis_img: torch.Tensor
271 | (batch, hidden_dim // num_heads), RoPE for image
272 | """
273 | b, s1, _ = x.shape
274 | _, s2, _ = y.shape
275 |
276 | q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim)
277 | kv = self.kv_proj(y).view(
278 | b, s2, 2, self.num_heads, self.head_dim
279 | )
280 | k, v = kv.unbind(dim=2)
281 | q = self.q_norm(q).to(q)
282 | k = self.k_norm(k).to(k)
283 |
284 | # Apply RoPE if needed
285 | if freqs_cis_img is not None:
286 | qq, _ = apply_rotary_emb(q, None, freqs_cis_img)
287 | assert qq.shape == q.shape, f"qq: {qq.shape}, q: {q.shape}"
288 | q = qq
289 | context = attention(q, k, v, self.head_dim, self.attn_drop, mode=self.attn_mode)
290 | context = context.reshape(b, s1, -1)
291 |
292 | out = self.out_proj(context)
293 | out = self.proj_drop(out)
294 |
295 | out_tuple = (out,)
296 |
297 | return out_tuple
298 |
--------------------------------------------------------------------------------
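A minimal sketch exercising the `CrossAttention` block above on dummy tensors, using the "torch" attention path so `xformers` is not required; the dimensions are illustrative. The rotary table construction at the end is one common adjacent-pair RoPE layout matching `rotate_half`, not code taken from this repository.

import torch
from dynamictrl.models.modules.cross_attention import CrossAttention  # assumed import path

attn = CrossAttention(qdim=1024, kdim=768, num_heads=16, qk_norm=True, attn_mode="torch")
x = torch.randn(2, 226, 1024)   # (batch, seqlen1, qdim): visual tokens
y = torch.randn(2, 77, 768)     # (batch, seqlen2, kdim): conditioning tokens
(out,) = attn(x, y)             # forward returns a one-element tuple
print(out.shape)                # torch.Size([2, 226, 1024])

# Optional RoPE on the queries: (cos, sin) tables of shape (seqlen1, head_dim),
# repeated per adjacent pair to match rotate_half's (..., D//2, 2) grouping.
head_dim = 1024 // 16
pos = torch.arange(226, dtype=torch.float32)[:, None]                          # (seqlen1, 1)
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))  # (head_dim // 2,)
angles = (pos * inv_freq).repeat_interleave(2, dim=-1)                         # (seqlen1, head_dim)
(out_rope,) = attn(x, y, freqs_cis_img=(angles.cos(), angles.sin()))
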
/dynamictrl/pipelines/modules/dynamictrl_output.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | import torch
4 |
5 | from diffusers.utils import BaseOutput
6 |
7 |
8 | @dataclass
9 | class DynamiCtrlPipelineOutput(BaseOutput):
10 | r"""
11 | Output class for DynamiCtrl pipelines.
12 |
13 | Args:
14 | frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
15 | List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
16 | denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
17 | `(batch_size, num_frames, channels, height, width)`.
18 | """
19 |
20 | frames: torch.Tensor
21 |
--------------------------------------------------------------------------------
/dynamictrl/pipelines/modules/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import List, Optional, Union
3 | import inspect
4 |
5 | # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
6 | def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
7 | tw = tgt_width
8 | th = tgt_height
9 | h, w = src
10 | r = h / w
11 | if r > (th / tw):
12 | resize_height = th
13 | resize_width = int(round(th / h * w))
14 | else:
15 | resize_width = tw
16 | resize_height = int(round(tw / w * h))
17 |
18 | crop_top = int(round((th - resize_height) / 2.0))
19 | crop_left = int(round((tw - resize_width) / 2.0))
20 |
21 | return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
22 |
23 |
24 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
25 | def retrieve_timesteps(
26 | scheduler,
27 | num_inference_steps: Optional[int] = None,
28 | device: Optional[Union[str, torch.device]] = None,
29 | timesteps: Optional[List[int]] = None,
30 | sigmas: Optional[List[float]] = None,
31 | **kwargs,
32 | ):
33 | r"""
34 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
35 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
36 |
37 | Args:
38 | scheduler (`SchedulerMixin`):
39 | The scheduler to get timesteps from.
40 | num_inference_steps (`int`):
41 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
42 | must be `None`.
43 | device (`str` or `torch.device`, *optional*):
44 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
45 | timesteps (`List[int]`, *optional*):
46 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
47 | `num_inference_steps` and `sigmas` must be `None`.
48 | sigmas (`List[float]`, *optional*):
49 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
50 | `num_inference_steps` and `timesteps` must be `None`.
51 |
52 | Returns:
53 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
54 | second element is the number of inference steps.
55 | """
56 | if timesteps is not None and sigmas is not None:
57 | raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
58 | if timesteps is not None:
59 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
60 | if not accepts_timesteps:
61 | raise ValueError(
62 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
63 | f" timestep schedules. Please check whether you are using the correct scheduler."
64 | )
65 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
66 | timesteps = scheduler.timesteps
67 | num_inference_steps = len(timesteps)
68 | elif sigmas is not None:
69 | accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
70 | if not accept_sigmas:
71 | raise ValueError(
72 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
73 | f" sigmas schedules. Please check whether you are using the correct scheduler."
74 | )
75 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
76 | timesteps = scheduler.timesteps
77 | num_inference_steps = len(timesteps)
78 | else:
79 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
80 | timesteps = scheduler.timesteps
81 | return timesteps, num_inference_steps
82 |
83 |
84 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
85 | def retrieve_latents(
86 | encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
87 | ):
88 | if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
89 | return encoder_output.latent_dist.sample(generator)
90 | elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
91 | return encoder_output.latent_dist.mode()
92 | elif hasattr(encoder_output, "latents"):
93 | return encoder_output.latents
94 | else:
95 | raise AttributeError("Could not access latents of provided encoder_output")
96 |
--------------------------------------------------------------------------------
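A minimal sketch of `retrieve_timesteps` with the scheduler used elsewhere in this repo; instantiating `CogVideoXDPMScheduler` with its default config is an assumption made purely for illustration.

import torch
from diffusers import CogVideoXDPMScheduler
from dynamictrl.pipelines.modules.utils import retrieve_timesteps  # assumed import path

scheduler = CogVideoXDPMScheduler()  # default config, for illustration only
timesteps, num_inference_steps = retrieve_timesteps(
    scheduler, num_inference_steps=50, device=torch.device("cpu")
)
print(num_inference_steps, timesteps[:5])
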
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==1.4.0
2 | av==14.2.0
3 | boto3==1.26.66
4 | clip==0.2.0
5 | decord==0.6.0
6 | deepspeed==0.14.4
7 | diffusers==0.32.0
8 | huggingface-hub==0.29.1
9 | imageio-ffmpeg==0.6.0
10 | jsonlines==4.0.0
11 | moviepy==1.0.3
12 | ninja==1.11.1.3
13 | numpy==1.26.0
14 | omegaconf==2.3.0
15 | onnxruntime-gpu
16 | PyYAML==6.0.2
17 | qwen-vl-utils==0.0.10
18 | safetensors==0.5.2
19 | sentencepiece==0.2.0
20 | SwissArmyTransformer==0.4.12
21 | tensorboardX==2.6.2.2
22 | timm==1.0.15
23 | tokenizers==0.21.0
24 | torch==2.6.0
25 | torchvision==0.21.0
26 | tqdm==4.67.1
27 | transformers==4.49.0
28 | wandb==0.19.7
--------------------------------------------------------------------------------
/scripts/dynamictrl_inference.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team.
2 | # All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import os
17 | import time
18 | import argparse
19 | from typing import Literal
20 |
21 | import torch
22 | from diffusers import CogVideoXDPMScheduler
23 | from diffusers.utils import export_to_video, load_image, load_video
24 |
25 | from dynamictrl.pipelines.dynamictrl_pipeline import DynamiCtrlAnimationPipeline
26 |
27 | from utils.load_validation_control import load_control_video_inference, load_contorl_video_from_Image
28 | from utils.pre_process import preprocess
29 |
30 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
31 |
32 | def generate_video(
33 | prompt: str,
34 | model_path: str,
35 | num_frames: int = 81,
36 | width: int = 1360,
37 | height: int = 768,
38 | output_path: str = "./output.mp4",
39 | reference_image_path: str = "",
40 | ori_driving_video: str = None,
41 | pose_video: str = None,
42 | num_inference_steps: int = 50,
43 | guidance_scale: float = 6.0,
44 | num_videos_per_prompt: int = 1,
45 | dtype: torch.dtype = torch.bfloat16,
46 | abandon_prefix: int = None,
47 | seed: int = 42,
48 | fps: int = 8,
49 | pose_control_function: str = "padaln",
50 | re_init_noise_latent: bool = False,
51 | ):
52 | """
53 | DynamiCtrl: Generates a dynamic video based on the given human image, driving video, and prompt, and saves it to the specified path.
54 |
55 | Parameters:
56 | - prompt (str): The description of the video to be generated.
57 | - model_path (str): The path of the pre-trained model to be used.
58 | - output_path (str): The path where the generated video will be saved.
59 | - num_inference_steps (int): Number of steps for the inference process. More steps can result in better quality.
60 | - num_frames (int): Number of frames to generate.
61 | - width (int): The width of the generated video
62 | - height (int): The height of the generated video
63 | - guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt.
64 | - num_videos_per_prompt (int): Number of videos to generate per prompt.
65 | - dtype (torch.dtype): The data type for computation (default is torch.bfloat16).
66 | - seed (int): The seed for reproducibility.
67 | - fps (int): The frames per second for the generated video.
68 | """
69 |
70 | image_name = os.path.basename(reference_image_path).split(".")[0]
71 |
72 | # 1. Initialize the DynamiCtrl inference pipeline.
73 | pipe = DynamiCtrlAnimationPipeline.from_pretrained(model_path, torch_dtype=dtype)
74 | pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
75 |
76 | pipe.to("cuda")
77 | # pipe.enable_sequential_cpu_offload()
78 | pipe.vae.enable_slicing()
79 | pipe.vae.enable_tiling()
80 |
81 | # 2. Load the image.
82 | image = load_image(image=reference_image_path)
83 |
84 | # 3. Load the driving video or the pose video.
85 | if ori_driving_video:
86 | video_name = os.path.basename(ori_driving_video).split(".")[0]
87 | pose_images, ref_image = preprocess(ori_driving_video, reference_image_path,
88 | width=width, height=height, max_frame_num=num_frames - 1, sample_stride=1)
89 | pose_sequence_save_dir = os.path.join(output_path, f"PoseImage_{image_name}_pose_{video_name}")
90 | os.makedirs(pose_sequence_save_dir, exist_ok=True)
91 | control_poses = []
92 | for i in range(len(pose_images)):
93 | control_poses.append(pose_images[i])
94 | pose_images[i].save(os.path.join(pose_sequence_save_dir, f"{i}.png")) # Save aligned pose images.
95 | validation_control_video = load_contorl_video_from_Image(control_poses, height, width).unsqueeze(0)
96 |
97 | if pose_video:
98 | video_name = os.path.basename(pose_video).split(".")[0]
99 | validation_control_video = load_control_video_inference(pose_video, video_height=height, video_width=width, max_frames=num_frames).unsqueeze(0)
100 |
101 | # 4. Generate the video frames.
102 | video_generate = pipe(
103 | height=height,
104 | width=width,
105 | prompt=prompt,
106 | image=image,
107 | control_video=validation_control_video,
108 | pose_control_function=pose_control_function,
109 | num_videos_per_prompt=num_videos_per_prompt, # Number of videos to generate per prompt
110 | num_inference_steps=num_inference_steps, # Number of inference steps
111 | num_frames=num_frames, # Number of frames to generate
112 | use_dynamic_cfg=False,  # This is used for the DPM scheduler; for the DDIM scheduler, it should be False
113 | guidance_scale=guidance_scale,
114 | generator=torch.Generator().manual_seed(seed), # Set the seed for reproducibility
115 | re_init_noise_latent=re_init_noise_latent,
116 | ).frames[0]
117 |
118 | # 5. Save the videos.
119 | for i in range(int(len(video_generate) / num_frames)):
120 | name_timestamp = int(time.time())
121 | video_gen = video_generate[i * num_frames : (i + 1) * num_frames]
122 | if abandon_prefix:
123 | video_gen = video_gen[abandon_prefix:]
124 | video_output_path = os.path.join(output_path, f"DynamiCtrl_{name_timestamp}_image_{image_name}_pose_{video_name}_seed_{seed}_cfg_{guidance_scale}.mp4")
125 | export_to_video(video_gen, video_output_path, fps=fps)
126 | print(f"output_path: {video_output_path}")
127 |
128 | if __name__ == "__main__":
129 | parser = argparse.ArgumentParser(description="Human image animation from an human image, a driving video, and prompt using DynamiCtrl")
130 | parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
131 | parser.add_argument(
132 | "--reference_image_path",
133 | type=str,
134 | default=None,
135 | help="The path of the reference human image to be used as the subject of the video",
136 | )
137 | parser.add_argument(
138 | "--ori_driving_video",
139 | type=str,
140 | default=None,
141 | help="the origin real video pth, contain the actioned human"
142 | )
143 | parser.add_argument(
144 | "--pose_video",
145 | type=str,
146 | default=None,
147 | help="the path of pose video"
148 | )
149 | parser.add_argument(
150 | "--model_path", type=str, default="gulucaptain/DynamiCtrl", help="Path of the pre-trained model use"
151 | )
152 | parser.add_argument("--output_path", type=str, default="./output.mp4", help="The path save generated video")
153 | parser.add_argument("--guidance_scale", type=float, default=3.0, help="The scale for classifier-free guidance")
154 | parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps")
155 | parser.add_argument("--num_frames", type=int, default=81, help="Number of steps for the inference process")
156 | parser.add_argument("--width", type=int, default=1360, help="Number of steps for the inference process")
157 | parser.add_argument("--height", type=int, default=768, help="Number of steps for the inference process")
158 | parser.add_argument("--fps", type=int, default=16, help="Number of steps for the inference process")
159 | parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
160 | parser.add_argument("--dtype", type=str, default="bfloat16", help="The data type for computation")
161 | parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility")
162 | parser.add_argument("--abandon_prefix", type=int, default=None, help="Save video frames range [abandon_prefix:]")
163 | parser.add_argument("--re_init_noise_latent", action="store_true", help="whether to resample the initial noise")
164 | parser.add_argument("--pose_control_function", type=str, default="padaln")
165 |
166 | args = parser.parse_args()
167 | dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16
168 | os.makedirs(args.output_path, exist_ok=True)
169 | generate_video(
170 | prompt=args.prompt,
171 | model_path=args.model_path,
172 | output_path=args.output_path,
173 | num_frames=args.num_frames,
174 | width=args.width,
175 | height=args.height,
176 | reference_image_path=args.reference_image_path,
177 | ori_driving_video=args.ori_driving_video,
178 | pose_video=args.pose_video,
179 | num_inference_steps=args.num_inference_steps,
180 | guidance_scale=args.guidance_scale,
181 | num_videos_per_prompt=args.num_videos_per_prompt,
182 | dtype=dtype,
183 | abandon_prefix=args.abandon_prefix,
184 | seed=args.seed,
185 | fps=args.fps,
186 | re_init_noise_latent=args.re_init_noise_latent,
187 | pose_control_function=args.pose_control_function,
188 | )
189 |
--------------------------------------------------------------------------------
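Beyond the CLI entry point above, `generate_video` can also be driven directly from Python. Below is a sketch using the sample assets bundled with the repo; the asset file names and the `scripts.dynamictrl_inference` import path are assumptions (when running from the repo root without a package install, adjust `sys.path` or use the CLI instead).

import torch
from scripts.dynamictrl_inference import generate_video  # assumed import path from the repo root

generate_video(
    prompt="A person dances in front of a plain indoor background.",
    model_path="gulucaptain/DynamiCtrl",        # released checkpoint on the Hub
    reference_image_path="assets/human1.jpg",   # bundled sample image (assumed)
    ori_driving_video="assets/motion1.mp4",     # bundled driving video (assumed)
    output_path="./outputs",                    # treated as a directory by generate_video
    num_frames=81,
    width=1360,
    height=768,
    guidance_scale=3.0,
    num_inference_steps=50,
    dtype=torch.bfloat16,
    seed=42,
    fps=16,
)
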
/tools/qwen.py:
--------------------------------------------------------------------------------
1 | from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
2 | from qwen_vl_utils import process_vision_info
3 | import os
4 | from tqdm import tqdm
5 |
6 | # default: Load the model on the available device(s)
7 | model = Qwen2VLForConditionalGeneration.from_pretrained(
8 | "/home/user/ckpt/qwen", torch_dtype="auto", device_map="auto"
9 | )
10 |
11 | # default processor
12 | processor = AutoProcessor.from_pretrained("/home/user/code/ckpt/qwen")
13 |
14 | image_data_root = "/home/user/images"
15 | images = os.listdir(image_data_root)
16 | images.sort()
17 |
18 | for image in tqdm(images):
19 | image_path = os.path.join(image_data_root, image)
20 | Question_description = "Describe the person and the background." #FIXME Replace the question you want.
21 |
22 | messages = [
23 | {
24 | "role": "user",
25 | "content": [
26 | {
27 | "type": "image",
28 | "image": image_path,
29 | },
30 | {"type": "text", "text": Question_description},
31 | ],
32 | }
33 | ]
34 |
35 | # Preparation for inference
36 | text = processor.apply_chat_template(
37 | messages, tokenize=False, add_generation_prompt=True
38 | )
39 | image_inputs, video_inputs = process_vision_info(messages)
40 | inputs = processor(
41 | text=[text],
42 | images=image_inputs,
43 | videos=video_inputs,
44 | padding=True,
45 | return_tensors="pt",
46 | )
47 | inputs = inputs.to("cuda")
48 |
49 | # Inference: Generation of the output
50 | generated_ids = model.generate(**inputs, max_new_tokens=128)
51 | generated_ids_trimmed = [
52 | out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
53 | ]
54 | output_text = processor.batch_decode(
55 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
56 | )[0]
57 | with open("/home/user/descriptions.txt", "a") as f:
58 | f.writelines(image+"#####"+output_text+"\n")
59 |
--------------------------------------------------------------------------------
/utils/args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def _get_model_args(parser: argparse.ArgumentParser) -> None:
5 | parser.add_argument(
6 | "--pretrained_model_name_or_path",
7 | type=str,
8 | default=None,
9 | required=True,
10 | help="Path to pretrained model or model identifier from huggingface.co/models.",
11 | )
12 | parser.add_argument(
13 | "--revision",
14 | type=str,
15 | default=None,
16 | required=False,
17 | help="Revision of pretrained model identifier from huggingface.co/models.",
18 | )
19 | parser.add_argument(
20 | "--variant",
21 | type=str,
22 | default=None,
23 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
24 | )
25 | parser.add_argument(
26 | "--cache_dir",
27 | type=str,
28 | default=None,
29 | help="The directory where the downloaded models and datasets will be stored.",
30 | )
31 |
32 |
33 | def _get_dataset_args(parser: argparse.ArgumentParser) -> None:
34 | parser.add_argument(
35 | "--data_root",
36 | type=str,
37 | default=None,
38 | help=("A folder containing the training data."),
39 | )
40 | parser.add_argument(
41 | "--dataset_file",
42 | type=str,
43 | default=None,
44 | help=("Path to a CSV file if loading prompts/video paths using this format."),
45 | )
46 | parser.add_argument(
47 | "--video_column",
48 | type=str,
49 | default="video",
50 | help="The column of the dataset containing videos. Or, the name of the file in `--data_root` folder containing the line-separated path to video data.",
51 | )
52 | parser.add_argument(
53 | "--caption_column",
54 | type=str,
55 | default="text",
56 | help="The column of the dataset containing the instance prompt for each video. Or, the name of the file in `--data_root` folder containing the line-separated instance prompts.",
57 | )
58 | parser.add_argument(
59 | "--id_token",
60 | type=str,
61 | default=None,
62 | help="Identifier token appended to the start of each prompt if provided.",
63 | )
64 | parser.add_argument(
65 | "--height_buckets",
66 | nargs="+",
67 | type=int,
68 | default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536],
69 | )
70 | parser.add_argument(
71 | "--width_buckets",
72 | nargs="+",
73 | type=int,
74 | default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536],
75 | )
76 | parser.add_argument(
77 | "--frame_buckets",
78 | nargs="+",
79 | type=int,
80 | default=[49],
81 | help="CogVideoX1.5 need to guarantee that ((num_frames - 1) // self.vae_scale_factor_temporal + 1) % patch_size_t == 0, such as 53"
82 | )
83 | parser.add_argument(
84 | "--load_tensors",
85 | action="store_true",
86 | help="Whether to use a pre-encoded tensor dataset of latents and prompt embeddings instead of videos and text prompts. The expected format is that saved by running the `prepare_dataset.py` script.",
87 | )
88 | parser.add_argument(
89 | "--random_flip",
90 | type=float,
91 | default=None,
92 | help="If random horizontal flip augmentation is to be used, this should be the flip probability.",
93 | )
94 | parser.add_argument(
95 | "--dataloader_num_workers",
96 | type=int,
97 | default=0,
98 | help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
99 | )
100 | parser.add_argument(
101 | "--pin_memory",
102 | action="store_true",
103 | help="Whether or not to use the pinned memory setting in pytorch dataloader.",
104 | )
105 |
106 |
107 | def _get_validation_args(parser: argparse.ArgumentParser) -> None:
108 | parser.add_argument(
109 | "--validation_prompt",
110 | type=str,
111 | default=None,
112 | help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.",
113 | )
114 | parser.add_argument(
115 | "--validation_images",
116 | type=str,
117 | default=None,
118 | help="One or more image path(s)/URLs that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.",
119 | )
120 | parser.add_argument(
121 | "--validation_driving_videos",
122 | type=str,
123 | default=None,
124 | help="One or more driving video path(s)/URLs that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.",
125 | )
126 | parser.add_argument(
127 | "--validation_prompt_separator",
128 | type=str,
129 | default=":::",
130 | help="String that separates multiple validation prompts",
131 | )
132 | parser.add_argument(
133 | "--num_validation_videos",
134 | type=int,
135 | default=1,
136 | help="Number of videos that should be generated during validation per `validation_prompt`.",
137 | )
138 | parser.add_argument(
139 | "--validation_epochs",
140 | type=int,
141 | default=None,
142 | help="Run validation every X training epochs. Validation consists of running the validation prompt `args.num_validation_videos` times.",
143 | )
144 | parser.add_argument(
145 | "--validation_steps",
146 | type=int,
147 | default=None,
148 | help="Run validation every X training steps. Validation consists of running the validation prompt `args.num_validation_videos` times.",
149 | )
150 | parser.add_argument(
151 | "--guidance_scale",
152 | type=float,
153 | default=6,
154 | help="The guidance scale to use while sampling validation videos.",
155 | )
156 | parser.add_argument(
157 | "--use_dynamic_cfg",
158 | action="store_true",
159 | default=False,
160 | help="Whether or not to use the default cosine dynamic guidance schedule when sampling validation videos.",
161 | )
162 | parser.add_argument(
163 | "--enable_model_cpu_offload",
164 | action="store_true",
165 | default=False,
166 | help="Whether or not to enable model-wise CPU offloading when performing validation/testing to save memory.",
167 | )
168 |
169 |
170 | def _get_training_args(parser: argparse.ArgumentParser) -> None:
171 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
172 | parser.add_argument("--rank", type=int, default=64, help="The rank for LoRA matrices.")
173 | parser.add_argument(
174 | "--lora_alpha",
175 | type=int,
176 | default=64,
177 | help="The lora_alpha to compute scaling factor (lora_alpha / rank) for LoRA matrices.",
178 | )
179 | parser.add_argument(
180 | "--mixed_precision",
181 | type=str,
182 | default=None,
183 | choices=["no", "fp16", "bf16"],
184 | help=(
185 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10.and an Nvidia Ampere GPU. "
186 | "Default to the value of accelerate config of the current system or the flag passed with the `accelerate.launch` command. Use this "
187 | "argument to override the accelerate config."
188 | ),
189 | )
190 | parser.add_argument(
191 | "--output_dir",
192 | type=str,
193 | default="dynamictrl-sft",
194 | help="The output directory where the model predictions and checkpoints will be written.",
195 | )
196 | parser.add_argument(
197 | "--height",
198 | type=int,
199 | default=480,
200 | help="All input videos are resized to this height.",
201 | )
202 | parser.add_argument(
203 | "--width",
204 | type=int,
205 | default=720,
206 | help="All input videos are resized to this width.",
207 | )
208 | parser.add_argument(
209 | "--video_reshape_mode",
210 | type=str,
211 | default=None,
212 | help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
213 | )
214 | parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
215 | parser.add_argument(
216 | "--max_num_frames",
217 | type=int,
218 | default=49,
219 | help="All input videos will be truncated to these many frames.",
220 | )
221 | parser.add_argument(
222 | "--skip_frames_start",
223 | type=int,
224 | default=0,
225 | help="Number of frames to skip from the beginning of each input video. Useful if training data contains intro sequences.",
226 | )
227 | parser.add_argument(
228 | "--skip_frames_end",
229 | type=int,
230 | default=0,
231 | help="Number of frames to skip from the end of each input video. Useful if training data contains outro sequences.",
232 | )
233 | parser.add_argument(
234 | "--train_batch_size",
235 | type=int,
236 | default=4,
237 | help="Batch size (per device) for the training dataloader.",
238 | )
239 | parser.add_argument("--num_train_epochs", type=int, default=1)
240 | parser.add_argument(
241 | "--max_train_steps",
242 | type=int,
243 | default=None,
244 | help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.",
245 | )
246 | parser.add_argument(
247 | "--checkpointing_steps",
248 | type=int,
249 | default=500,
250 | help=(
251 | "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
252 | " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
253 | " training using `--resume_from_checkpoint`."
254 | ),
255 | )
256 | parser.add_argument(
257 | "--checkpoints_total_limit",
258 | type=int,
259 | default=None,
260 | help=("Max number of checkpoints to store."),
261 | )
262 | parser.add_argument(
263 | "--resume_from_checkpoint",
264 | type=str,
265 | default=None,
266 | help=(
267 | "Whether training should be resumed from a previous checkpoint. Use a path saved by"
268 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
269 | ),
270 | )
271 | parser.add_argument(
272 | "--gradient_accumulation_steps",
273 | type=int,
274 | default=1,
275 | help="Number of updates steps to accumulate before performing a backward/update pass.",
276 | )
277 | parser.add_argument(
278 | "--gradient_checkpointing",
279 | action="store_true",
280 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
281 | )
282 | parser.add_argument(
283 | "--learning_rate",
284 | type=float,
285 | default=1e-4,
286 | help="Initial learning rate (after the potential warmup period) to use.",
287 | )
288 | parser.add_argument(
289 | "--scale_lr",
290 | action="store_true",
291 | default=False,
292 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
293 | )
294 | parser.add_argument(
295 | "--lr_scheduler",
296 | type=str,
297 | default="constant",
298 | help=(
299 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
300 | ' "constant", "constant_with_warmup"]'
301 | ),
302 | )
303 | parser.add_argument(
304 | "--lr_warmup_steps",
305 | type=int,
306 | default=500,
307 | help="Number of steps for the warmup in the lr scheduler.",
308 | )
309 | parser.add_argument(
310 | "--lr_num_cycles",
311 | type=int,
312 | default=1,
313 | help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
314 | )
315 | parser.add_argument(
316 | "--lr_power",
317 | type=float,
318 | default=1.0,
319 | help="Power factor of the polynomial scheduler.",
320 | )
321 | parser.add_argument(
322 | "--enable_slicing",
323 | action="store_true",
324 | default=False,
325 | help="Whether or not to use VAE slicing for saving memory.",
326 | )
327 | parser.add_argument(
328 | "--enable_tiling",
329 | action="store_true",
330 | default=False,
331 | help="Whether or not to use VAE tiling for saving memory.",
332 | )
333 | parser.add_argument(
334 | "--noised_image_dropout",
335 | type=float,
336 | default=0.05,
337 | help="Image condition dropout probability when finetuning image-to-video.",
338 | )
339 | parser.add_argument(
340 | "--ignore_learned_positional_embeddings",
341 | action="store_true",
342 | default=False,
343 | help=(
344 | "Whether to ignore the learned positional embeddings when training CogVideoX Image-to-Video. This setting "
345 | "should be used when performing multi-resolution training, because CogVideoX-I2V does not support it "
346 | "otherwise. Please read the comments in https://github.com/a-r-r-o-w/cogvideox-factory/issues/26 to understand why."
347 | ),
348 | )
349 | parser.add_argument(
350 | "--enable_control_pose",
351 | action="store_true",
352 | default=False,
353 | help="Whether or not to use pose control information"
354 | )
355 | parser.add_argument(
356 | "--pose_control_function",
357 | type=str,
358 | default="expert_layer_norm",
359 | help="pose control function: direct_adding; expert_layer_norm; cross_attention",
360 | )
361 |
362 |
363 | def _get_optimizer_args(parser: argparse.ArgumentParser) -> None:
364 | parser.add_argument(
365 | "--optimizer",
366 | type=lambda s: s.lower(),
367 | default="adam",
368 | choices=["adam", "adamw", "prodigy", "came"],
369 | help=("The optimizer type to use."),
370 | )
371 | parser.add_argument(
372 | "--use_8bit",
373 | action="store_true",
374 | help="Whether or not to use 8-bit optimizers from `bitsandbytes` or `bitsandbytes`.",
375 | )
376 | parser.add_argument(
377 | "--use_4bit",
378 | action="store_true",
379 | help="Whether or not to use 4-bit optimizers from `torchao`.",
380 | )
381 | parser.add_argument(
382 | "--use_torchao", action="store_true", help="Whether or not to use the `torchao` backend for optimizers."
383 | )
384 | parser.add_argument(
385 | "--beta1",
386 | type=float,
387 | default=0.9,
388 | help="The beta1 parameter for the Adam and Prodigy optimizers.",
389 | )
390 | parser.add_argument(
391 | "--beta2",
392 | type=float,
393 | default=0.95,
394 | help="The beta2 parameter for the Adam and Prodigy optimizers.",
395 | )
396 | parser.add_argument(
397 | "--beta3",
398 | type=float,
399 | default=None,
400 | help="Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2.",
401 | )
402 | parser.add_argument(
403 | "--prodigy_decouple",
404 | action="store_true",
405 | help="Use AdamW style decoupled weight decay.",
406 | )
407 | parser.add_argument(
408 | "--weight_decay",
409 | type=float,
410 | default=1e-04,
411 | help="Weight decay to use for optimizer.",
412 | )
413 | parser.add_argument(
414 | "--epsilon",
415 | type=float,
416 | default=1e-8,
417 | help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
418 | )
419 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
420 | parser.add_argument(
421 | "--prodigy_use_bias_correction",
422 | action="store_true",
423 | help="Turn on Adam's bias correction.",
424 | )
425 | parser.add_argument(
426 | "--prodigy_safeguard_warmup",
427 | action="store_true",
428 | help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage.",
429 | )
430 | parser.add_argument(
431 | "--use_cpu_offload_optimizer",
432 | action="store_true",
433 | help="Whether or not to use the CPUOffloadOptimizer from TorchAO to perform optimization step and maintain parameters on the CPU.",
434 | )
435 | parser.add_argument(
436 | "--offload_gradients",
437 | action="store_true",
438 | help="Whether or not to offload the gradients to CPU when using the CPUOffloadOptimizer from TorchAO.",
439 | )
440 |
441 |
442 | def _get_configuration_args(parser: argparse.ArgumentParser) -> None:
443 | parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name")
444 | parser.add_argument(
445 | "--push_to_hub",
446 | action="store_true",
447 | help="Whether or not to push the model to the Hub.",
448 | )
449 | parser.add_argument(
450 | "--hub_token",
451 | type=str,
452 | default=None,
453 | help="The token to use to push to the Model Hub.",
454 | )
455 | parser.add_argument(
456 | "--hub_model_id",
457 | type=str,
458 | default=None,
459 | help="The name of the repository to keep in sync with the local `output_dir`.",
460 | )
461 | parser.add_argument(
462 | "--logging_dir",
463 | type=str,
464 | default="logs",
465 | help="Directory where logs are stored.",
466 | )
467 | parser.add_argument(
468 | "--allow_tf32",
469 | action="store_true",
470 | help=(
471 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
472 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
473 | ),
474 | )
475 | parser.add_argument(
476 | "--nccl_timeout",
477 | type=int,
478 | default=600,
479 | help="Maximum timeout duration before which allgather, or related, operations fail in multi-GPU/multi-node training settings.",
480 | )
481 | parser.add_argument(
482 | "--report_to",
483 | type=str,
484 | default=None,
485 | help=(
486 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
487 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
488 | ),
489 | )
490 |
491 |
492 | def get_args():
493 | parser = argparse.ArgumentParser(description="Simple example of a training script for CogVideoX.")
494 |
495 | _get_model_args(parser)
496 | _get_dataset_args(parser)
497 | _get_training_args(parser)
498 | _get_validation_args(parser)
499 | _get_optimizer_args(parser)
500 | _get_configuration_args(parser)
501 |
502 | return parser.parse_args()
503 |
--------------------------------------------------------------------------------
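The argument groups above are composed by `get_args()` for the SFT fine-tuning script. A minimal sketch of parsing a typical set of flags from Python (the flag values are illustrative; only `--pretrained_model_name_or_path` is required).

import sys
from utils.args import get_args  # assumed import path from the repo root

sys.argv = [
    "dynamictrl_sft_finetune.py",
    "--pretrained_model_name_or_path", "gulucaptain/DynamiCtrl",
    "--data_root", "./data",                    # hypothetical dataset folder
    "--video_column", "videos.txt",
    "--caption_column", "prompts.txt",
    "--enable_control_pose",
    "--pose_control_function", "expert_layer_norm",
    "--mixed_precision", "bf16",
    "--train_batch_size", "1",
]
args = get_args()
print(args.pose_control_function, args.height, args.width)
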
/utils/dataset.py:
--------------------------------------------------------------------------------
1 | import random
2 | from pathlib import Path
3 | from typing import Any, Dict, List, Optional, Tuple
4 |
5 | import numpy as np
6 | import pandas as pd
7 | import torch
8 | import torchvision.transforms as TT
9 | from accelerate.logging import get_logger
10 | from torch.utils.data import Dataset, Sampler
11 | from torchvision import transforms
12 | from torchvision.transforms import InterpolationMode
13 | from torchvision.transforms.functional import resize
14 |
15 |
16 | # Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
17 | # Very few bug reports but it happens. Look in decord Github issues for more relevant information.
18 | import decord # isort:skip
19 |
20 | decord.bridge.set_bridge("torch")
21 |
22 | logger = get_logger(__name__)
23 |
24 | HEIGHT_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536]
25 | WIDTH_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536]
26 | FRAME_BUCKETS = [16, 24, 32, 48, 64, 80]
27 |
28 |
29 | class VideoDataset(Dataset):
30 | def __init__(
31 | self,
32 | data_root: str,
33 | dataset_file: Optional[str] = None,
34 | caption_column: str = "text",
35 | video_column: str = "video",
36 | max_num_frames: int = 49,
37 | id_token: Optional[str] = None,
38 | height_buckets: List[int] = None,
39 | width_buckets: List[int] = None,
40 | frame_buckets: List[int] = None,
41 | load_tensors: bool = False,
42 | random_flip: Optional[float] = None,
43 | image_to_video: bool = False,
44 | enable_control_pose: bool = False,
45 | ) -> None:
46 | super().__init__()
47 |
48 | self.data_root = Path(data_root)
49 | self.dataset_file = dataset_file
50 | self.caption_column = caption_column
51 | self.video_column = video_column
52 | self.max_num_frames = max_num_frames
53 | self.id_token = f"{id_token.strip()} " if id_token else ""
54 | self.height_buckets = height_buckets or HEIGHT_BUCKETS
55 | self.width_buckets = width_buckets or WIDTH_BUCKETS
56 | self.frame_buckets = frame_buckets or FRAME_BUCKETS
57 | self.load_tensors = load_tensors
58 | self.random_flip = random_flip
59 | self.image_to_video = image_to_video
60 | self.enable_control_pose = enable_control_pose
61 |
62 | self.resolutions = [
63 | (f, h, w) for h in self.height_buckets for w in self.width_buckets for f in self.frame_buckets
64 | ]
65 |
66 | # Two methods of loading data are supported.
67 | # - Using a CSV: caption_column and video_column must be some column in the CSV. One could
68 | # make use of other columns too, such as a motion score or aesthetic score, by modifying the
69 | # logic in CSV processing.
70 | # - Using two files containing line-separate captions and relative paths to videos.
71 | # For a more detailed explanation about preparing dataset format, checkout the README.
72 | if dataset_file is None:
73 | (
74 | self.prompts,
75 | self.video_paths,
76 | ) = self._load_dataset_from_local_path()
77 | else:
78 | (
79 | self.prompts,
80 | self.video_paths,
81 | ) = self._load_dataset_from_csv()
82 |
83 | if len(self.video_paths) != len(self.prompts):
84 | raise ValueError(
85 | f"Expected length of prompts and videos to be the same but found {len(self.prompts)=} and {len(self.video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset."
86 | )
87 |
88 | self.video_transforms = transforms.Compose(
89 | [
90 | transforms.RandomHorizontalFlip(random_flip)
91 | if random_flip
92 | else transforms.Lambda(self.identity_transform),
93 | transforms.Lambda(self.scale_transform),
94 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
95 | ]
96 | )
97 |
98 | @staticmethod
99 | def identity_transform(x):
100 | return x
101 |
102 | @staticmethod
103 | def scale_transform(x):
104 | return x / 255.0
105 |
106 | def __len__(self) -> int:
107 | return len(self.video_paths)
108 |
109 | def __getitem__(self, index: int) -> Dict[str, Any]:
110 | if isinstance(index, list):
111 | # Here, index is actually a list of data objects that we need to return.
112 | # The BucketSampler should ideally return indices. But, in the sampler, we'd like
113 | # to have information about num_frames, height and width. Since this is not stored
114 | # as metadata, we need to read the video to get this information. You could read this
115 | # information without loading the full video in memory, but we do it anyway. In order
116 | # to not load the video twice (once to get the metadata, and once to return the loaded video
117 | # based on sampled indices), we cache it in the BucketSampler. When the sampler is
118 | # to yield, we yield the cache data instead of indices. So, this special check ensures
119 | # that data is not loaded a second time. PRs are welcome for improvements.
120 | return index
121 |
122 | if self.load_tensors: # False
123 | image_latents, video_latents, prompt_embeds = self._preprocess_video(self.video_paths[index])
124 |
125 | latent_num_frames = video_latents.size(1)
126 | if latent_num_frames % 2 == 0:
127 | num_frames = latent_num_frames * 4
128 | else:
129 | num_frames = (latent_num_frames - 1) * 4 + 1
130 |
131 | height = video_latents.size(2) * 8
132 | width = video_latents.size(3) * 8
133 |
134 | return {
135 | "prompt": prompt_embeds,
136 | "image": image_latents,
137 | "video": video_latents,
138 | "video_metadata": {
139 | "num_frames": num_frames,
140 | "height": height,
141 | "width": width,
142 | },
143 | }
144 | else:
145 | image, video, control_video, _ = self._preprocess_video(self.video_paths[index])
146 | prompt = self.prompts[index].split("#####")[1]
147 |
148 | return {
149 | "prompt": self.id_token + prompt,
150 | "image": image,
151 | "video": video,
152 | "control_video": control_video,
153 | "video_metadata": {
154 | "num_frames": video.shape[0],
155 | "height": video.shape[2],
156 | "width": video.shape[3],
157 | },
158 | }
159 |
160 | def _load_dataset_from_local_path(self) -> Tuple[List[str], List[str]]:
161 | if not self.data_root.exists():
162 | raise ValueError("Root folder for videos does not exist")
163 |
164 | prompt_path = self.data_root.joinpath(self.caption_column)
165 | video_path = self.data_root.joinpath(self.video_column)
166 |
167 | if not prompt_path.exists() or not prompt_path.is_file():
168 | raise ValueError(
169 | "Expected `--caption_column` to be path to a file in `--data_root` containing line-separated text prompts."
170 | )
171 | if not video_path.exists() or not video_path.is_file():
172 | raise ValueError(
173 | "Expected `--video_column` to be path to a file in `--data_root` containing line-separated paths to video data in the same directory."
174 | )
175 |
176 | with open(prompt_path, "r", encoding="utf-8") as file:
177 | prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0]
178 | with open(video_path, "r", encoding="utf-8") as file:
179 | video_paths = [self.data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0]
180 |
181 | if not self.load_tensors and any(not path.is_file() for path in video_paths):
182 | raise ValueError(
183 | f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file."
184 | )
185 |
186 | return prompts, video_paths
187 |
188 | def _load_dataset_from_csv(self) -> Tuple[List[str], List[str]]:
189 | df = pd.read_csv(self.dataset_file)
190 | prompts = df[self.caption_column].tolist()
191 | video_paths = df[self.video_column].tolist()
192 | video_paths = [self.data_root.joinpath(line.strip()) for line in video_paths]
193 |
194 | if any(not path.is_file() for path in video_paths):
195 | raise ValueError(
196 | f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file."
197 | )
198 |
199 | return prompts, video_paths
200 |
201 | def _preprocess_video(self, path: Path) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
202 | r"""
203 | Loads a single video, or latent and prompt embedding, based on initialization parameters.
204 |
205 | If returning a video, returns a [F, C, H, W] video tensor, and None for the prompt embedding. Here,
206 | F, C, H and W are the frames, channels, height and width of the input video.
207 |
208 | If returning latent/embedding, returns a [F, C, H, W] latent, and the prompt embedding of shape [S, D].
209 | F, C, H and W are the frames, channels, height and width of the latent, and S, D are the sequence length
210 | and embedding dimension of prompt embeddings.
211 | """
212 | if self.load_tensors:
213 | return self._load_preprocessed_latents_and_embeds(path)
214 | else:
215 | video_reader = decord.VideoReader(uri=path.as_posix())
216 | video_num_frames = len(video_reader)
217 |
218 | indices = list(range(0, video_num_frames, video_num_frames // self.max_num_frames))
219 | frames = video_reader.get_batch(indices)
220 | frames = frames[: self.max_num_frames].float()
221 | frames = frames.permute(0, 3, 1, 2).contiguous()
222 | frames = torch.stack([self.video_transforms(frame) for frame in frames], dim=0)
223 |
224 | image = frames[:1].clone() if self.image_to_video else None
225 |
226 | return image, frames, None
227 |
228 | def _load_preprocessed_latents_and_embeds(self, path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
229 | filename_without_ext = path.name.split(".")[0]
230 | pt_filename = f"{filename_without_ext}.pt"
231 |
232 | image_latents_path = path.parent.parent.joinpath("image_latents")
233 | video_latents_path = path.parent.parent.joinpath("video_latents")
234 | embeds_path = path.parent.parent.joinpath("prompt_embeds")
235 |
236 | if (
237 | not video_latents_path.exists()
238 | or not embeds_path.exists()
239 | or (self.image_to_video and not image_latents_path.exists())
240 | ):
241 | raise ValueError(
242 |                 f"When setting the load_tensors parameter to `True`, it is expected that the `{self.data_root=}` contains two folders named `video_latents` and `prompt_embeds`. However, these folders were not found. Please make sure to have prepared your data correctly using `prepare_dataset.py`. Additionally, if you're training image-to-video, it is expected that an `image_latents` folder is also present."
243 | )
244 |
245 | if self.image_to_video:
246 | image_latent_filepath = image_latents_path.joinpath(pt_filename)
247 | video_latent_filepath = video_latents_path.joinpath(pt_filename)
248 | embeds_filepath = embeds_path.joinpath(pt_filename)
249 |
250 | if not video_latent_filepath.is_file() or not embeds_filepath.is_file():
251 | if self.image_to_video:
252 | image_latent_filepath = image_latent_filepath.as_posix()
253 | video_latent_filepath = video_latent_filepath.as_posix()
254 | embeds_filepath = embeds_filepath.as_posix()
255 | raise ValueError(
256 | f"The file {video_latent_filepath=} or {embeds_filepath=} could not be found. Please ensure that you've correctly executed `prepare_dataset.py`."
257 | )
258 |
259 | images = (
260 | torch.load(image_latent_filepath, map_location="cpu", weights_only=True) if self.image_to_video else None
261 | )
262 | latents = torch.load(video_latent_filepath, map_location="cpu", weights_only=True)
263 | embeds = torch.load(embeds_filepath, map_location="cpu", weights_only=True)
264 |
265 | return images, latents, embeds
266 |
267 |
268 | class VideoDatasetWithResizing(VideoDataset):
269 | def __init__(self, *args, **kwargs) -> None:
270 | super().__init__(*args, **kwargs)
271 |
272 | def get_entity_with_mask_short_time(self, video, masked_video):
273 | video = video.detach().numpy()
274 | masked_video = masked_video.detach().numpy()
275 |
276 | threshold = 200
277 | output_frames = []
278 |
279 | if video.shape[1] != masked_video.shape[1] or video.shape[2] != masked_video.shape[2]:
280 |             return torch.from_numpy(video)  # shape mismatch between video and mask; fall back to the unmasked video
281 |
282 | frame_length, frame_height, frame_width, _ = video.shape
283 |
284 | for i in range(frame_length):
285 | video_frame = video[i]
286 | mask_frame = masked_video[i]
287 | mask_frame = mask_frame[:, :, 0]
288 |
289 | mask_frame = mask_frame.astype(np.uint8)
290 | binary_mask = (mask_frame > threshold).astype(np.uint8)
291 | binary_mask = np.expand_dims(binary_mask, axis=-1)
292 |
293 | masked_frame = video_frame * binary_mask
294 | output_frames.append(masked_frame)
295 |
296 | return torch.tensor(np.array(output_frames))
297 |
298 |     def _preprocess_video(self, path: Path) -> Tuple[Optional[torch.Tensor], ...]:
299 | if self.load_tensors:
300 | return self._load_preprocessed_latents_and_embeds(path)
301 | else:
302 | video_reader = decord.VideoReader(uri=path.as_posix())
303 |
304 | use_image_mask = False
305 | if use_image_mask:
306 | try:
307 | mask_reader = decord.VideoReader(uri=path.as_posix().replace("videos", "masked_video")) # directly load the entity video.
308 |                 except Exception:  # no matching entity video for this clip; fall back to the raw frames
309 | mask_reader = None
310 | else:
311 | mask_reader = None
312 |
313 | if self.enable_control_pose:
314 | pose_reader = decord.VideoReader(uri=path.as_posix().replace("videos", "poses"))
315 | else:
316 | pose_reader = None
317 |
318 | video_num_frames = len(video_reader)
319 | nearest_frame_bucket = min(
320 | self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames))
321 | )
322 |
323 | frames_load_interval = 1
324 | frame_indices = []
325 |                 frame_index = min(random.randint(1, 30), max(0, video_num_frames - nearest_frame_bucket))  # random start offset, clamped so the sampled window stays inside the video
326 | while len(frame_indices) < nearest_frame_bucket:
327 | frame_indices.append(frame_index)
328 | frame_index += frames_load_interval
329 |
330 | frames = video_reader.get_batch(frame_indices)
331 | if use_image_mask and mask_reader is not None:
332 | masked_frames = mask_reader.get_batch(frame_indices)
333 |                     masked_frames = masked_frames[:nearest_frame_bucket].float()
334 |                     masked_frames = masked_frames.permute(0, 3, 1, 2).contiguous()
335 |
336 | frames = frames[:nearest_frame_bucket].float()
337 | frames = frames.permute(0, 3, 1, 2).contiguous()
338 |
339 | nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3])
340 | frames_resized = torch.stack([resize(frame, nearest_res) for frame in frames], dim=0)
341 | frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0)
342 |
343 | if use_image_mask and mask_reader is not None:
344 |                 masked_frames_resized = torch.stack([resize(mask_frame, nearest_res) for mask_frame in masked_frames], dim=0)
345 |                 masked_frames = torch.stack([self.video_transforms(mask_frame) for mask_frame in masked_frames_resized], dim=0)
346 |
347 | pose_frames = pose_reader.get_batch(frame_indices)
348 | pose_frames = pose_frames[:nearest_frame_bucket].float()
349 | pose_frames = pose_frames.permute(0, 3, 1, 2).contiguous()
350 |
351 | pose_frames_resized = torch.stack([resize(frame, nearest_res) for frame in pose_frames], dim=0)
352 | pose_frames = torch.stack([self.video_transforms(frame) for frame in pose_frames_resized], dim=0)
353 |
354 | if use_image_mask and mask_reader is not None:
355 |                 image = masked_frames[:1].clone() if self.image_to_video else None
356 | else:
357 | image = frames[:1].clone() if self.image_to_video else None
358 |
359 | return image, frames, pose_frames, None
360 |
361 | def _find_nearest_resolution(self, height, width):
362 | nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width))
363 | return nearest_res[1], nearest_res[2]
364 |
365 |
366 | class VideoDatasetWithResizeAndRectangleCrop(VideoDataset):
367 | def __init__(self, video_reshape_mode: str = "center", *args, **kwargs) -> None:
368 | super().__init__(*args, **kwargs)
369 | self.video_reshape_mode = video_reshape_mode
370 |
371 | def _resize_for_rectangle_crop(self, arr, image_size):
372 | reshape_mode = self.video_reshape_mode
373 | if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
374 | arr = resize(
375 | arr,
376 | size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
377 | interpolation=InterpolationMode.BICUBIC,
378 | )
379 | else:
380 | arr = resize(
381 | arr,
382 | size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
383 | interpolation=InterpolationMode.BICUBIC,
384 | )
385 |
386 | h, w = arr.shape[2], arr.shape[3]
387 | arr = arr.squeeze(0)
388 |
389 | delta_h = h - image_size[0]
390 | delta_w = w - image_size[1]
391 |
392 | if reshape_mode == "random" or reshape_mode == "none":
393 | top = np.random.randint(0, delta_h + 1)
394 | left = np.random.randint(0, delta_w + 1)
395 | elif reshape_mode == "center":
396 | top, left = delta_h // 2, delta_w // 2
397 | else:
398 | raise NotImplementedError
399 | arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
400 | return arr
401 |
402 |     def _preprocess_video(self, path: Path) -> Tuple[Optional[torch.Tensor], ...]:
403 | if self.load_tensors:
404 | return self._load_preprocessed_latents_and_embeds(path)
405 | else:
406 | video_reader = decord.VideoReader(uri=path.as_posix())
407 | video_num_frames = len(video_reader)
408 | nearest_frame_bucket = min(
409 | self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames))
410 | )
411 |
412 | frames_load_interval = 1
413 | frame_indices = []
414 | frame_index = 0
415 | while len(frame_indices) < nearest_frame_bucket:
416 | frame_indices.append(frame_index)
417 | frame_index += frames_load_interval
418 |
419 | frames = video_reader.get_batch(frame_indices)
420 | frames = frames[:nearest_frame_bucket].float()
421 | frames = frames.permute(0, 3, 1, 2).contiguous()
422 |
423 | nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3])
424 | frames_resized = self._resize_for_rectangle_crop(frames, nearest_res)
425 | frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0)
426 |
427 | image = frames[:1].clone() if self.image_to_video else None
428 |
429 | return image, frames, None
430 |
431 | def _find_nearest_resolution(self, height, width):
432 | nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width))
433 | return nearest_res[1], nearest_res[2]
434 |
435 |
436 | class BucketSampler(Sampler):
437 | r"""
438 | PyTorch Sampler that groups 3D data by height, width and frames.
439 |
440 | Args:
441 | data_source (`VideoDataset`):
442 | A PyTorch dataset object that is an instance of `VideoDataset`.
443 | batch_size (`int`, defaults to `8`):
444 | The batch size to use for training.
445 | shuffle (`bool`, defaults to `True`):
446 | Whether or not to shuffle the data in each batch before dispatching to dataloader.
447 | drop_last (`bool`, defaults to `False`):
448 | Whether or not to drop incomplete buckets of data after completely iterating over all data
449 | in the dataset. If set to True, only batches that have `batch_size` number of entries will
450 | be yielded. If set to False, it is guaranteed that all data in the dataset will be processed
451 | and batches that do not have `batch_size` number of entries will also be yielded.
452 | """
453 |
454 | def __init__(
455 | self, data_source: VideoDataset, batch_size: int = 8, shuffle: bool = True, drop_last: bool = False
456 | ) -> None:
457 | self.data_source = data_source
458 | self.batch_size = batch_size
459 | self.shuffle = shuffle
460 | self.drop_last = drop_last
461 |
462 | self.buckets = {resolution: [] for resolution in data_source.resolutions}
463 |
464 | self._raised_warning_for_drop_last = False
465 |
466 | def __len__(self):
467 | if self.drop_last and not self._raised_warning_for_drop_last:
468 | self._raised_warning_for_drop_last = True
469 | logger.warning(
470 | "Calculating the length for bucket sampler is not possible when `drop_last` is set to True. This may cause problems when setting the number of epochs used for training."
471 | )
472 | return (len(self.data_source) + self.batch_size - 1) // self.batch_size
473 |
474 | def __iter__(self):
475 | for index, data in enumerate(self.data_source):
476 | video_metadata = data["video_metadata"]
477 | f, h, w = video_metadata["num_frames"], video_metadata["height"], video_metadata["width"]
478 |
479 | self.buckets[(f, h, w)].append(data)
480 | if len(self.buckets[(f, h, w)]) == self.batch_size:
481 | if self.shuffle: # self.drop_last=False; self.shuffle=True;
482 | random.shuffle(self.buckets[(f, h, w)])
483 | yield self.buckets[(f, h, w)]
484 | del self.buckets[(f, h, w)]
485 | self.buckets[(f, h, w)] = []
486 |
487 | if self.drop_last:
488 | return
489 |
490 | for fhw, bucket in list(self.buckets.items()):
491 | if len(bucket) == 0:
492 | continue
493 | if self.shuffle:
494 | random.shuffle(bucket)
495 | yield bucket
496 | del self.buckets[fhw]
497 | self.buckets[fhw] = []
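498 | 
499 | 
500 | # Minimal usage sketch (assumptions: the paths below are placeholders and `data_root` follows the
501 | # prompts-file / videos-file layout this dataset expects). Every element yielded by `BucketSampler`
502 | # is already a full bucket -- a list of samples sharing the same (frames, height, width) -- so the
503 | # DataLoader runs with `batch_size=1` and the actual batching happens inside `collate_fn`, which is
504 | # how `prepare_dataset.py` wires things up as well.
505 | if __name__ == "__main__":
506 |     from torch.utils.data import DataLoader
507 | 
508 |     dataset = VideoDatasetWithResizing(
509 |         data_root="/path/to/data_root",  # placeholder
510 |         caption_column="prompts.txt",    # placeholder file with line-separated captions
511 |         video_column="videos.txt",       # placeholder file with line-separated video paths
512 |         image_to_video=True,
513 |         enable_control_pose=True,
514 |     )
515 |     sampler = BucketSampler(dataset, batch_size=2, shuffle=True, drop_last=False)
516 | 
517 |     def collate_fn(data):
518 |         # `data` is a single-element list holding one bucket (a list of sample dicts).
519 |         videos = torch.stack([sample["video"] for sample in data[0]])
520 |         prompts = [sample["prompt"] for sample in data[0]]
521 |         return {"videos": videos, "prompts": prompts}
522 | 
523 |     loader = DataLoader(dataset, batch_size=1, sampler=sampler, collate_fn=collate_fn)
524 |     batch = next(iter(loader))
525 |     print(batch["videos"].shape, len(batch["prompts"]))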
--------------------------------------------------------------------------------
/utils/dataset_use_mask_entity.py:
--------------------------------------------------------------------------------
1 | import random
2 | from pathlib import Path
3 | from typing import Any, Dict, List, Optional, Tuple
4 |
5 | import numpy as np
6 | import pandas as pd
7 | import torch
8 | import torchvision.transforms as TT
9 | from accelerate.logging import get_logger
10 | from torch.utils.data import Dataset, Sampler
11 | from torchvision import transforms
12 | from torchvision.transforms import InterpolationMode
13 | from torchvision.transforms.functional import resize
14 |
15 |
16 | # decord must be imported after torch; importing it first can sometimes lead to a nasty segmentation fault or stack-smashing error.
17 | # There are very few bug reports, but it does happen. See the decord GitHub issues for more information.
18 | import decord # isort:skip
19 |
20 | decord.bridge.set_bridge("torch")
21 |
22 | logger = get_logger(__name__)
23 |
24 | HEIGHT_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536]
25 | WIDTH_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536]
26 | FRAME_BUCKETS = [16, 24, 32, 48, 64, 80]
27 |
28 |
29 | class VideoDataset(Dataset):
30 | def __init__(
31 | self,
32 | data_root: str,
33 | dataset_file: Optional[str] = None,
34 | caption_column: str = "text",
35 | video_column: str = "video",
36 | max_num_frames: int = 49,
37 | id_token: Optional[str] = None,
38 | height_buckets: List[int] = None,
39 | width_buckets: List[int] = None,
40 | frame_buckets: List[int] = None,
41 | load_tensors: bool = False,
42 | random_flip: Optional[float] = None,
43 | image_to_video: bool = False,
44 | enable_control_pose: bool = False,
45 | ) -> None:
46 | super().__init__()
47 |
48 | self.data_root = Path(data_root)
49 | self.dataset_file = dataset_file
50 | self.caption_column = caption_column
51 | self.video_column = video_column
52 | self.max_num_frames = max_num_frames
53 | self.id_token = f"{id_token.strip()} " if id_token else ""
54 | self.height_buckets = height_buckets or HEIGHT_BUCKETS
55 | self.width_buckets = width_buckets or WIDTH_BUCKETS
56 | self.frame_buckets = frame_buckets or FRAME_BUCKETS
57 | self.load_tensors = load_tensors
58 | self.random_flip = random_flip
59 | self.image_to_video = image_to_video
60 | self.enable_control_pose = enable_control_pose
61 |
62 | self.resolutions = [
63 | (f, h, w) for h in self.height_buckets for w in self.width_buckets for f in self.frame_buckets
64 | ]
65 |
66 | # Two methods of loading data are supported.
67 | # - Using a CSV: caption_column and video_column must be some column in the CSV. One could
68 | # make use of other columns too, such as a motion score or aesthetic score, by modifying the
69 | # logic in CSV processing.
70 |         # - Using two files containing line-separated captions and relative paths to videos.
71 | # For a more detailed explanation about preparing dataset format, checkout the README.
72 | if dataset_file is None:
73 | (
74 | self.prompts,
75 | self.video_paths,
76 | ) = self._load_dataset_from_local_path()
77 | else:
78 | (
79 | self.prompts,
80 | self.video_paths,
81 | ) = self._load_dataset_from_csv()
82 |
83 | if len(self.video_paths) != len(self.prompts):
84 | raise ValueError(
85 | f"Expected length of prompts and videos to be the same but found {len(self.prompts)=} and {len(self.video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset."
86 | )
87 |
88 | self.video_transforms = transforms.Compose(
89 | [
90 | transforms.RandomHorizontalFlip(random_flip)
91 | if random_flip
92 | else transforms.Lambda(self.identity_transform),
93 | transforms.Lambda(self.scale_transform),
94 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
95 | ]
96 | )
97 |
98 | @staticmethod
99 | def identity_transform(x):
100 | return x
101 |
102 | @staticmethod
103 | def scale_transform(x):
104 | return x / 255.0
105 |
106 | def __len__(self) -> int:
107 | return len(self.video_paths)
108 |
109 | def __getitem__(self, index: int) -> Dict[str, Any]:
110 | if isinstance(index, list):
111 | # Here, index is actually a list of data objects that we need to return.
112 | # The BucketSampler should ideally return indices. But, in the sampler, we'd like
113 | # to have information about num_frames, height and width. Since this is not stored
114 | # as metadata, we need to read the video to get this information. You could read this
115 | # information without loading the full video in memory, but we do it anyway. In order
116 | # to not load the video twice (once to get the metadata, and once to return the loaded video
117 | # based on sampled indices), we cache it in the BucketSampler. When the sampler is
118 | # to yield, we yield the cache data instead of indices. So, this special check ensures
119 | # that data is not loaded a second time. PRs are welcome for improvements.
120 | return index
121 |
122 | if self.load_tensors: # False
123 | image_latents, video_latents, prompt_embeds = self._preprocess_video(self.video_paths[index])
124 |
125 | # This is hardcoded for now.
126 | # The VAE's temporal compression ratio is 4.
127 | # The VAE's spatial compression ratio is 8.
128 | latent_num_frames = video_latents.size(1)
129 | if latent_num_frames % 2 == 0:
130 | num_frames = latent_num_frames * 4
131 | else:
132 | num_frames = (latent_num_frames - 1) * 4 + 1
133 |
134 | height = video_latents.size(2) * 8
135 | width = video_latents.size(3) * 8
136 |
137 | return {
138 | "prompt": prompt_embeds,
139 | "image": image_latents,
140 | "video": video_latents,
141 | "video_metadata": {
142 | "num_frames": num_frames,
143 | "height": height,
144 | "width": width,
145 | },
146 | }
147 | else:
148 | image, video, control_video, _ = self._preprocess_video(self.video_paths[index])
149 |             prompt = self.prompts[index].split("#####")[1]  # each caption line contains a "#####" separator; the text after it is used as the prompt
150 |
151 | return {
152 | "prompt": self.id_token + prompt,
153 | "image": image,
154 | "video": video,
155 | "control_video": control_video,
156 | "video_metadata": {
157 | "num_frames": video.shape[0],
158 | "height": video.shape[2],
159 | "width": video.shape[3],
160 | },
161 | }
162 |
163 | def _load_dataset_from_local_path(self) -> Tuple[List[str], List[str]]:
164 | if not self.data_root.exists():
165 | raise ValueError("Root folder for videos does not exist")
166 |
167 | prompt_path = self.data_root.joinpath(self.caption_column)
168 | video_path = self.data_root.joinpath(self.video_column)
169 |
170 | if not prompt_path.exists() or not prompt_path.is_file():
171 | raise ValueError(
172 | "Expected `--caption_column` to be path to a file in `--data_root` containing line-separated text prompts."
173 | )
174 | if not video_path.exists() or not video_path.is_file():
175 | raise ValueError(
176 | "Expected `--video_column` to be path to a file in `--data_root` containing line-separated paths to video data in the same directory."
177 | )
178 |
179 | with open(prompt_path, "r", encoding="utf-8") as file:
180 | prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0]
181 | with open(video_path, "r", encoding="utf-8") as file:
182 | video_paths = [self.data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0]
183 |
184 | if not self.load_tensors and any(not path.is_file() for path in video_paths):
185 | raise ValueError(
186 |                 f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found at least one path that is not a valid file."
187 | )
188 |
189 | return prompts, video_paths
190 |
191 | def _load_dataset_from_csv(self) -> Tuple[List[str], List[str]]:
192 | df = pd.read_csv(self.dataset_file)
193 | prompts = df[self.caption_column].tolist()
194 | video_paths = df[self.video_column].tolist()
195 | video_paths = [self.data_root.joinpath(line.strip()) for line in video_paths]
196 |
197 | if any(not path.is_file() for path in video_paths):
198 | raise ValueError(
199 |                 f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found at least one path that is not a valid file."
200 | )
201 |
202 | return prompts, video_paths
203 |
204 |     def _preprocess_video(self, path: Path) -> Tuple[Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor]]:
205 |         r"""
206 |         Loads a single video, or its precomputed latents and prompt embedding, based on initialization parameters.
207 | 
208 |         If returning a video, returns an `(image, frames, None)` tuple, where `image` is the first frame (or None
209 |         when `image_to_video` is disabled) and `frames` is a [F, C, H, W] tensor of frames, channels, height and width.
210 | 
211 |         If returning latents/embeddings, returns `(image_latents, video_latents, prompt_embeds)`, where the latents
212 |         are [F, C, H, W] tensors and the prompt embedding has shape [S, D]; S and D are the sequence length and
213 |         embedding dimension of the prompt embeddings.
214 |         """
215 | if self.load_tensors:
216 | return self._load_preprocessed_latents_and_embeds(path)
217 | else:
218 | video_reader = decord.VideoReader(uri=path.as_posix())
219 | video_num_frames = len(video_reader)
220 |
221 |             indices = list(range(0, video_num_frames, max(1, video_num_frames // self.max_num_frames)))  # guard against videos shorter than max_num_frames
222 | frames = video_reader.get_batch(indices)
223 | frames = frames[: self.max_num_frames].float()
224 | frames = frames.permute(0, 3, 1, 2).contiguous()
225 | frames = torch.stack([self.video_transforms(frame) for frame in frames], dim=0)
226 |
227 | image = frames[:1].clone() if self.image_to_video else None
228 |
229 | return image, frames, None
230 |
231 |     def _load_preprocessed_latents_and_embeds(self, path: Path) -> Tuple[Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
232 | filename_without_ext = path.name.split(".")[0]
233 | pt_filename = f"{filename_without_ext}.pt"
234 |
235 | image_latents_path = path.parent.parent.joinpath("image_latents")
236 | video_latents_path = path.parent.parent.joinpath("video_latents")
237 | embeds_path = path.parent.parent.joinpath("prompt_embeds")
238 |
239 | if (
240 | not video_latents_path.exists()
241 | or not embeds_path.exists()
242 | or (self.image_to_video and not image_latents_path.exists())
243 | ):
244 | raise ValueError(
245 |                 f"When setting the load_tensors parameter to `True`, it is expected that the `{self.data_root=}` contains two folders named `video_latents` and `prompt_embeds`. However, these folders were not found. Please make sure to have prepared your data correctly using `prepare_dataset.py`. Additionally, if you're training image-to-video, it is expected that an `image_latents` folder is also present."
246 | )
247 |
248 | if self.image_to_video:
249 | image_latent_filepath = image_latents_path.joinpath(pt_filename)
250 | video_latent_filepath = video_latents_path.joinpath(pt_filename)
251 | embeds_filepath = embeds_path.joinpath(pt_filename)
252 |
253 | if not video_latent_filepath.is_file() or not embeds_filepath.is_file():
254 | if self.image_to_video:
255 | image_latent_filepath = image_latent_filepath.as_posix()
256 | video_latent_filepath = video_latent_filepath.as_posix()
257 | embeds_filepath = embeds_filepath.as_posix()
258 | raise ValueError(
259 | f"The file {video_latent_filepath=} or {embeds_filepath=} could not be found. Please ensure that you've correctly executed `prepare_dataset.py`."
260 | )
261 |
262 | images = (
263 | torch.load(image_latent_filepath, map_location="cpu", weights_only=True) if self.image_to_video else None
264 | )
265 | latents = torch.load(video_latent_filepath, map_location="cpu", weights_only=True)
266 | embeds = torch.load(embeds_filepath, map_location="cpu", weights_only=True)
267 |
268 | return images, latents, embeds
269 |
270 |
271 | class VideoDatasetWithResizing(VideoDataset):
272 | def __init__(self, *args, **kwargs) -> None:
273 | super().__init__(*args, **kwargs)
274 |
275 | def get_entity_with_mask_short_time(self, video, masked_video):
276 | video = video.detach().numpy()
277 | masked_video = masked_video.detach().numpy()
278 |
279 | threshold = 200
280 | output_frames = []
281 |
282 | if video.shape[1] != masked_video.shape[1] or video.shape[2] != masked_video.shape[2]:
283 |             return torch.from_numpy(video)  # shape mismatch between video and mask; fall back to the unmasked video
284 |
285 | frame_length, frame_height, frame_width, _ = video.shape
286 |
287 | for i in range(frame_length):
288 | video_frame = video[i]
289 | mask_frame = masked_video[i]
290 | mask_frame = mask_frame[:, :, 0]
291 |
292 | mask_frame = mask_frame.astype(np.uint8)
293 | binary_mask = (mask_frame > threshold).astype(np.uint8)
294 | binary_mask = np.expand_dims(binary_mask, axis=-1)
295 |
296 | masked_frame = video_frame * binary_mask
297 | output_frames.append(masked_frame)
298 |
299 | return torch.tensor(np.array(output_frames))
300 |
301 |     def _preprocess_video(self, path: Path) -> Tuple[Optional[torch.Tensor], ...]:
302 | if self.load_tensors:
303 | return self._load_preprocessed_latents_and_embeds(path)
304 | else:
305 | video_reader = decord.VideoReader(uri=path.as_posix())
306 |
307 | use_image_mask = True
308 | if use_image_mask:
309 | try:
310 | mask_reader = decord.VideoReader(uri=path.as_posix().replace("videos", "mask_video"))
311 |                 except Exception:  # no matching mask video for this clip; fall back to the raw frames
312 | mask_reader = None
313 |
314 | if self.enable_control_pose:
315 | pose_reader = decord.VideoReader(uri=path.as_posix().replace("videos", "poses"))
316 | else:
317 | pose_reader = None
318 |
319 | video_num_frames = len(video_reader)
320 | nearest_frame_bucket = min(
321 | self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames))
322 | )
323 |
324 | frames_load_interval = 1
325 | frame_indices = []
326 |                 frame_index = min(random.randint(1, 30), max(0, video_num_frames - nearest_frame_bucket))  # random start offset, clamped so the sampled window stays inside the video
327 | while len(frame_indices) < nearest_frame_bucket:
328 | frame_indices.append(frame_index)
329 | frame_index += frames_load_interval
330 |
331 | # handle video frames.
332 | frames = video_reader.get_batch(frame_indices)
333 | if use_image_mask and mask_reader is not None:
334 | masked_frames = mask_reader.get_batch(frame_indices)
335 | masked_frames = self.get_entity_with_mask_short_time(frames, masked_frames) #FIXME use mask
336 |
337 |                     masked_frames = masked_frames[:nearest_frame_bucket].float()
338 |                     masked_frames = masked_frames.permute(0, 3, 1, 2).contiguous()
339 |
340 | frames = frames[:nearest_frame_bucket].float()
341 | frames = frames.permute(0, 3, 1, 2).contiguous()
342 |
343 | nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3])
344 | frames_resized = torch.stack([resize(frame, nearest_res) for frame in frames], dim=0)
345 | frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0)
346 |
347 | if use_image_mask and mask_reader is not None:
348 |                 masked_frames_resized = torch.stack([resize(mask_frame, nearest_res) for mask_frame in masked_frames], dim=0)
349 |                 masked_frames = torch.stack([self.video_transforms(mask_frame) for mask_frame in masked_frames_resized], dim=0)
350 |
351 | pose_frames = pose_reader.get_batch(frame_indices)
352 | pose_frames = pose_frames[:nearest_frame_bucket].float()
353 | pose_frames = pose_frames.permute(0, 3, 1, 2).contiguous()
354 |
355 | pose_frames_resized = torch.stack([resize(frame, nearest_res) for frame in pose_frames], dim=0)
356 | pose_frames = torch.stack([self.video_transforms(frame) for frame in pose_frames_resized], dim=0)
357 |
358 | if use_image_mask and mask_reader is not None:
359 |                 image = masked_frames[:1].clone() if self.image_to_video else None
360 | else:
361 | image = frames[:1].clone() if self.image_to_video else None
362 |
363 | return image, frames, pose_frames, None
364 |
365 | def _find_nearest_resolution(self, height, width):
366 | nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width))
367 | return nearest_res[1], nearest_res[2]
368 |
369 |
370 | class VideoDatasetWithResizeAndRectangleCrop(VideoDataset):
371 | def __init__(self, video_reshape_mode: str = "center", *args, **kwargs) -> None:
372 | super().__init__(*args, **kwargs)
373 | self.video_reshape_mode = video_reshape_mode
374 |
375 | def _resize_for_rectangle_crop(self, arr, image_size):
376 | reshape_mode = self.video_reshape_mode
377 | if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
378 | arr = resize(
379 | arr,
380 | size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
381 | interpolation=InterpolationMode.BICUBIC,
382 | )
383 | else:
384 | arr = resize(
385 | arr,
386 | size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
387 | interpolation=InterpolationMode.BICUBIC,
388 | )
389 |
390 | h, w = arr.shape[2], arr.shape[3]
391 | arr = arr.squeeze(0)
392 |
393 | delta_h = h - image_size[0]
394 | delta_w = w - image_size[1]
395 |
396 | if reshape_mode == "random" or reshape_mode == "none":
397 | top = np.random.randint(0, delta_h + 1)
398 | left = np.random.randint(0, delta_w + 1)
399 | elif reshape_mode == "center":
400 | top, left = delta_h // 2, delta_w // 2
401 | else:
402 | raise NotImplementedError
403 | arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
404 | return arr
405 |
406 |     def _preprocess_video(self, path: Path) -> Tuple[Optional[torch.Tensor], ...]:
407 | if self.load_tensors:
408 | return self._load_preprocessed_latents_and_embeds(path)
409 | else:
410 | video_reader = decord.VideoReader(uri=path.as_posix())
411 | video_num_frames = len(video_reader)
412 | nearest_frame_bucket = min(
413 | self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames))
414 | )
415 |
416 | frames_load_interval = 1
417 | frame_indices = []
418 | frame_index = 0
419 | while len(frame_indices) < nearest_frame_bucket:
420 | frame_indices.append(frame_index)
421 | frame_index += frames_load_interval
422 |
423 | frames = video_reader.get_batch(frame_indices)
424 | frames = frames[:nearest_frame_bucket].float()
425 | frames = frames.permute(0, 3, 1, 2).contiguous()
426 |
427 | nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3])
428 | frames_resized = self._resize_for_rectangle_crop(frames, nearest_res)
429 | frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0)
430 |
431 | image = frames[:1].clone() if self.image_to_video else None
432 |
433 | return image, frames, None
434 |
435 | def _find_nearest_resolution(self, height, width):
436 | nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width))
437 | return nearest_res[1], nearest_res[2]
438 |
439 |
440 | class BucketSampler(Sampler):
441 | r"""
442 | PyTorch Sampler that groups 3D data by height, width and frames.
443 |
444 | Args:
445 | data_source (`VideoDataset`):
446 | A PyTorch dataset object that is an instance of `VideoDataset`.
447 | batch_size (`int`, defaults to `8`):
448 | The batch size to use for training.
449 | shuffle (`bool`, defaults to `True`):
450 | Whether or not to shuffle the data in each batch before dispatching to dataloader.
451 | drop_last (`bool`, defaults to `False`):
452 | Whether or not to drop incomplete buckets of data after completely iterating over all data
453 | in the dataset. If set to True, only batches that have `batch_size` number of entries will
454 | be yielded. If set to False, it is guaranteed that all data in the dataset will be processed
455 | and batches that do not have `batch_size` number of entries will also be yielded.
456 | """
457 |
458 | def __init__(
459 | self, data_source: VideoDataset, batch_size: int = 8, shuffle: bool = True, drop_last: bool = False
460 | ) -> None:
461 | self.data_source = data_source
462 | self.batch_size = batch_size
463 | self.shuffle = shuffle
464 | self.drop_last = drop_last
465 |
466 | self.buckets = {resolution: [] for resolution in data_source.resolutions}
467 |
468 | self._raised_warning_for_drop_last = False
469 |
470 | def __len__(self):
471 | if self.drop_last and not self._raised_warning_for_drop_last:
472 | self._raised_warning_for_drop_last = True
473 | logger.warning(
474 | "Calculating the length for bucket sampler is not possible when `drop_last` is set to True. This may cause problems when setting the number of epochs used for training."
475 | )
476 | return (len(self.data_source) + self.batch_size - 1) // self.batch_size
477 |
478 | def __iter__(self):
479 | for index, data in enumerate(self.data_source):
480 | video_metadata = data["video_metadata"]
481 | f, h, w = video_metadata["num_frames"], video_metadata["height"], video_metadata["width"]
482 |
483 | self.buckets[(f, h, w)].append(data)
484 | if len(self.buckets[(f, h, w)]) == self.batch_size:
485 | if self.shuffle: # self.drop_last=False; self.shuffle=True;
486 | random.shuffle(self.buckets[(f, h, w)])
487 | yield self.buckets[(f, h, w)]
488 | del self.buckets[(f, h, w)]
489 | self.buckets[(f, h, w)] = []
490 |
491 | if self.drop_last:
492 | return
493 |
494 | for fhw, bucket in list(self.buckets.items()):
495 | if len(bucket) == 0:
496 | continue
497 | if self.shuffle:
498 | random.shuffle(bucket)
499 | yield bucket
500 | del self.buckets[fhw]
501 | self.buckets[fhw] = []
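502 | 
503 | 
504 | # Data-layout note (a sketch of the structure this variant assumes, inferred from the `path.replace(...)`
505 | # calls above): every clip listed under `videos/` should have a matching pose video under `poses/` and a
506 | # matching mask video under `mask_video/`. The mask frames are binarized with a threshold of 200 in
507 | # `get_entity_with_mask_short_time` to keep only the foreground entity before the first masked frame is
508 | # used as the reference image.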
--------------------------------------------------------------------------------
/utils/freeinit_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.fft as fft
3 | import math
4 |
5 | def freq_mix_3d(x, noise, LPF, fusion='low'):
6 | """
7 | Noise reinitialization.
8 |
9 | Args:
10 | x: diffused latent
11 | noise: randomly sampled noise
12 | LPF: low pass filter
13 | """
14 | # FFT
15 | x_freq = fft.fftn(x, dim=(-3, -2, -1))
16 | x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1))
17 | noise_freq = fft.fftn(noise, dim=(-3, -2, -1))
18 | noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1))
19 |
20 |
21 | if fusion == 'low':
22 | # frequency mix
23 | LPF = LPF.to(x_freq.device)
24 | HPF = 1 - LPF
25 | x_freq_low = x_freq * LPF
26 | noise_freq_high = noise_freq * HPF
27 | x_freq_mixed = x_freq_low + noise_freq_high # mix in freq domain
28 | elif fusion == 'high':
29 | # frequency mix
30 | LPF = LPF.to(x_freq.device)
31 | HPF = 1 - LPF
32 | x_freq_high = x_freq * HPF
33 | noise_freq_low = noise_freq * LPF
34 |         x_freq_mixed = x_freq_high + noise_freq_low # mix in freq domain
35 |     else:
36 |         raise ValueError(f"`fusion` must be 'low' or 'high', got {fusion!r}")
37 | # IFFT
38 | x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1))
39 | x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real
40 |
41 | return x_mixed
42 |
43 |
44 | def get_freq_filter(shape, device, filter_type, n, d_s, d_t):
45 | """
46 | Form the frequency filter for noise reinitialization.
47 |
48 | Args:
49 | shape: shape of latent (B, C, T, H, W)
50 | filter_type: type of the freq filter
51 | n: (only for butterworth) order of the filter, larger n ~ ideal, smaller n ~ gaussian
52 | d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
53 | d_t: normalized stop frequency for temporal dimension (0.0-1.0)
54 | """
55 | if filter_type == "gaussian":
56 | return gaussian_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device)
57 | elif filter_type == "ideal":
58 | return ideal_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device)
59 | elif filter_type == "box":
60 | return box_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device)
61 | elif filter_type == "butterworth":
62 | return butterworth_low_pass_filter(shape=shape, n=n, d_s=d_s, d_t=d_t).to(device)
63 | else:
64 | raise NotImplementedError
65 |
66 | def gaussian_low_pass_filter(shape, d_s=0.25, d_t=0.25):
67 | """
68 | Compute the gaussian low pass filter mask.
69 |
70 | Args:
71 | shape: shape of the filter (volume)
72 | d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
73 | d_t: normalized stop frequency for temporal dimension (0.0-1.0)
74 | """
75 | T, H, W = shape[-3], shape[-2], shape[-1]
76 | mask = torch.zeros(shape)
77 | if d_s==0 or d_t==0:
78 | return mask
79 | for t in range(T):
80 | for h in range(H):
81 | for w in range(W):
82 | d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2)
83 | mask[..., t,h,w] = math.exp(-1/(2*d_s**2) * d_square)
84 | return mask
85 |
86 |
87 | def butterworth_low_pass_filter(shape, n=4, d_s=0.25, d_t=0.25):
88 | """
89 | Compute the butterworth low pass filter mask.
90 |
91 | Args:
92 | shape: shape of the filter (volume)
93 | n: order of the filter, larger n ~ ideal, smaller n ~ gaussian
94 | d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
95 | d_t: normalized stop frequency for temporal dimension (0.0-1.0)
96 | """
97 | T, H, W = shape[-3], shape[-2], shape[-1]
98 | mask = torch.zeros(shape)
99 | if d_s==0 or d_t==0:
100 | return mask
101 | for t in range(T):
102 | for h in range(H):
103 | for w in range(W):
104 | d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2)
105 | mask[..., t,h,w] = 1 / (1 + (d_square / d_s**2)**n)
106 | return mask
107 |
108 |
109 | def ideal_low_pass_filter(shape, d_s=0.25, d_t=0.25):
110 | """
111 | Compute the ideal low pass filter mask.
112 |
113 | Args:
114 | shape: shape of the filter (volume)
115 | d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
116 | d_t: normalized stop frequency for temporal dimension (0.0-1.0)
117 | """
118 | T, H, W = shape[-3], shape[-2], shape[-1]
119 | mask = torch.zeros(shape)
120 | if d_s==0 or d_t==0:
121 | return mask
122 | for t in range(T):
123 | for h in range(H):
124 | for w in range(W):
125 | d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2)
126 | mask[..., t,h,w] = 1 if d_square <= d_s*2 else 0
127 | return mask
128 |
129 |
130 | def box_low_pass_filter(shape, d_s=0.25, d_t=0.25):
131 | """
132 | Compute the ideal low pass filter mask (approximated version).
133 |
134 | Args:
135 | shape: shape of the filter (volume)
136 | d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
137 | d_t: normalized stop frequency for temporal dimension (0.0-1.0)
138 | """
139 | T, H, W = shape[-3], shape[-2], shape[-1]
140 | mask = torch.zeros(shape)
141 | if d_s==0 or d_t==0:
142 | return mask
143 |
144 | threshold_s = round(int(H // 2) * d_s)
145 | threshold_t = round(T // 2 * d_t)
146 |
147 | cframe, crow, ccol = T // 2, H // 2, W //2
148 | mask[..., cframe - threshold_t:cframe + threshold_t, crow - threshold_s:crow + threshold_s, ccol - threshold_s:ccol + threshold_s] = 1.0
149 |
150 | return mask
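151 | 
152 | 
153 | # Minimal sketch of FreeInit-style noise reinitialization using the helpers above. The latent shape below
154 | # is an arbitrary small (B, C, T, H, W) example; real latents would come from the diffusion pipeline.
155 | if __name__ == "__main__":
156 |     latents = torch.randn(1, 4, 8, 16, 16)
157 |     fresh_noise = torch.randn_like(latents)
158 | 
159 |     # Build a low-pass filter, keep the low-frequency band of `latents`, and replace its
160 |     # high-frequency band with the freshly sampled noise.
161 |     lpf = get_freq_filter(latents.shape, latents.device, filter_type="butterworth", n=4, d_s=0.25, d_t=0.25)
162 |     reinitialized = freq_mix_3d(latents, fresh_noise, lpf, fusion="low")
163 |     print(reinitialized.shape)  # torch.Size([1, 4, 8, 16, 16])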
--------------------------------------------------------------------------------
/utils/load_validation_control.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 | import torch
3 |
4 | import decord
5 |
6 | import numpy as np
7 |
8 | def identity_transform(x):
9 | return x
10 |
11 | def scale_transform(x):
12 | return x / 255.0
13 |
14 | def load_control_video(video_path, height, width):
15 | control_video_height = height
16 | control_video_width = width
17 | video_transforms = transforms.Compose(
18 | [
19 | transforms.Lambda(identity_transform),
20 | transforms.Lambda(scale_transform),
21 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
22 | transforms.Resize((control_video_height, control_video_width))
23 | ]
24 | )
25 | validation_control_video_reader = decord.VideoReader(uri=video_path)
26 | frame_indices = [i for i in range(0, 36)]
27 |
28 | validation_control_video = validation_control_video_reader.get_batch(frame_indices).float()
29 |
30 | validation_control_video = validation_control_video.permute(0, 3, 1, 2).contiguous()
31 |     validation_control_video = torch.cat((validation_control_video, validation_control_video[-1].unsqueeze(0)), dim=0)  # append a copy of the last frame so the clip has 37 (= 4 * 9 + 1) frames
32 | validation_control_video = torch.stack([video_transforms(frame) for frame in validation_control_video], dim=0)
33 | return validation_control_video
34 |
35 | def load_control_video_inference(video_path, video_height, video_width, max_frames=None):
36 | video_transforms = transforms.Compose(
37 | [
38 | transforms.Lambda(identity_transform),
39 | transforms.Lambda(scale_transform),
40 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
41 |             transforms.Resize((1360, 768))  # NOTE: hard-coded target size; `video_height`/`video_width` are currently unused here
42 | ]
43 | )
44 | validation_control_video_reader = decord.VideoReader(uri=video_path)
45 | if max_frames:
46 | frame_indices = [i for i in range(0, max_frames)]
47 | else:
48 | frame_indices = [i for i in range(len(validation_control_video_reader))]
49 |
50 | validation_control_video = validation_control_video_reader.get_batch(frame_indices).asnumpy()
51 | validation_control_video = torch.from_numpy(validation_control_video).float()
52 |
53 | validation_control_video = validation_control_video.permute(0, 3, 1, 2).contiguous()
54 | validation_control_video = torch.stack([video_transforms(frame) for frame in validation_control_video], dim=0)
55 | return validation_control_video
56 |
57 | def load_contorl_video_from_Image(validation_control_images, video_height, video_width):
58 | video_transforms = transforms.Compose(
59 | [
60 | transforms.Lambda(identity_transform),
61 | transforms.Lambda(scale_transform),
62 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
63 | transforms.Resize((video_height, video_width))
64 | ]
65 | )
66 | validation_control_video_tensor = []
67 |
68 | for image in validation_control_images:
69 | validation_control_video_tensor.append(torch.from_numpy(np.array(image)).float())
70 |
71 | validation_control_video = torch.stack(validation_control_video_tensor, dim=0)
72 | validation_control_video = validation_control_video.permute(0, 3, 1, 2).contiguous()
73 | validation_control_video = torch.stack([video_transforms(frame) for frame in validation_control_video], dim=0)
74 | return validation_control_video
75 |
76 | if __name__=="__main__":
77 | validation_control_video_path = "/home/user/video.mp4"
78 |     validation_control_video = load_control_video(validation_control_video_path, height=1024, width=576)  # example resolution; `load_control_video` requires a target height and width
--------------------------------------------------------------------------------
/utils/pre_process.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import logging
4 | import math
5 | from omegaconf import OmegaConf
6 | from datetime import datetime
7 | from pathlib import Path
8 | from PIL import Image
9 | import numpy as np
10 | import torch.jit
11 | from torchvision.datasets.folder import pil_loader
12 | from torchvision.transforms.functional import pil_to_tensor, resize, center_crop
13 | from torchvision.transforms.functional import to_pil_image
14 | from dwpose.preprocess import get_image_pose, get_video_pose
15 |
16 | ASPECT_RATIO = 9 / 16
17 |
18 | def preprocess(video_path, image_path, width=576, height=1024, sample_stride=2, max_frame_num=None):
19 |     """Preprocess the reference image and extract pose frames for it and the driving pose video.
20 | 
21 |     Args:
22 |         video_path (str): path to the driving pose video
23 |         image_path (str): path to the reference image
24 |         width, height (int, optional): target resolution. Defaults to 576 x 1024.
25 |         sample_stride (int, optional): sampling stride for the driving video. Defaults to 2.
26 | """
27 | image_pixels = pil_loader(image_path)
28 | image_pixels = pil_to_tensor(image_pixels) # (c, h, w)
29 | h, w = image_pixels.shape[-2:]
30 |
31 | w_target, h_target = width, height
32 | h_w_ratio = float(h) / float(w)
33 | if h_w_ratio < h_target / w_target:
34 | h_resize, w_resize = h_target, math.ceil(h_target / h_w_ratio)
35 | else:
36 | h_resize, w_resize = math.ceil(w_target * h_w_ratio), w_target
37 | image_pixels = resize(image_pixels, [h_resize, w_resize], antialias=None)
38 | image_pixels = center_crop(image_pixels, [h_target, w_target])
39 | image_pixels = image_pixels.permute((1, 2, 0)).numpy()
40 | image_pose = get_image_pose(image_pixels)
41 | video_pose = get_video_pose(video_path, image_pixels, sample_stride=sample_stride, max_frame_num=max_frame_num)
42 | pose_pixels = np.concatenate([np.expand_dims(image_pose, 0), video_pose])
43 | image_pixels = Image.fromarray(image_pixels)
44 | pose_pixels = [Image.fromarray(p.transpose((1,2,0))) for p in pose_pixels]
45 | return pose_pixels, image_pixels
46 |
47 |
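48 | 
49 | # Minimal usage sketch (requires the DWPose weights/dependencies to be set up). The input paths are
50 | # placeholders; `preprocess` returns the pose frames (the reference-image pose followed by the driving-video
51 | # poses, as PIL images) together with the cropped and resized reference image.
52 | if __name__ == "__main__":
53 |     pose_frames, ref_image = preprocess(
54 |         video_path="driving_pose.mp4",  # placeholder: driving pose video
55 |         image_path="reference.jpg",     # placeholder: reference human image
56 |         width=576,
57 |         height=1024,
58 |         sample_stride=2,
59 |     )
60 |     ref_image.save("reference_processed.jpg")
61 |     pose_frames[0].save("reference_pose.png")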
--------------------------------------------------------------------------------
/utils/prepare_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import argparse
4 | import functools
5 | import json
6 | import os
7 | import pathlib
8 | import queue
9 | import traceback
10 | import uuid
11 | from concurrent.futures import ThreadPoolExecutor
12 | from typing import Any, Dict, List, Optional, Union
13 |
14 | import torch
15 | import torch.distributed as dist
16 | from diffusers import AutoencoderKLCogVideoX
17 | from diffusers.training_utils import set_seed
18 | from diffusers.utils import export_to_video, get_logger
19 | from torch.utils.data import DataLoader
20 | from torchvision import transforms
21 | from tqdm import tqdm
22 | from transformers import T5EncoderModel, T5Tokenizer
23 |
24 |
25 | import decord # isort:skip
26 |
27 | from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip
28 |
29 |
30 | decord.bridge.set_bridge("torch")
31 |
32 | logger = get_logger(__name__)
33 |
34 | DTYPE_MAPPING = {
35 | "fp32": torch.float32,
36 | "fp16": torch.float16,
37 | "bf16": torch.bfloat16,
38 | }
39 |
40 |
41 | def check_height(x: Any) -> int:
42 | x = int(x)
43 | if x % 16 != 0:
44 | raise argparse.ArgumentTypeError(
45 | f"`--height_buckets` must be divisible by 16, but got {x} which does not fit criteria."
46 | )
47 | return x
48 |
49 |
50 | def check_width(x: Any) -> int:
51 | x = int(x)
52 | if x % 16 != 0:
53 | raise argparse.ArgumentTypeError(
54 | f"`--width_buckets` must be divisible by 16, but got {x} which does not fit criteria."
55 | )
56 | return x
57 |
58 |
59 | def check_frames(x: Any) -> int:
60 | x = int(x)
61 | if x % 4 != 0 and x % 4 != 1:
62 | raise argparse.ArgumentTypeError(
63 |             f"`--frame_buckets` must be of form `4 * k` or `4 * k + 1`, but got {x} which does not fit criteria."
64 | )
65 | return x
66 |
67 |
68 | def get_args() -> Dict[str, Any]:
69 | parser = argparse.ArgumentParser()
70 | parser.add_argument(
71 | "--model_id",
72 | type=str,
73 | default="THUDM/CogVideoX-2b",
74 | help="Hugging Face model ID to use for tokenizer, text encoder and VAE.",
75 | )
76 | parser.add_argument("--data_root", type=str, required=True, help="Path to where training data is located.")
77 | parser.add_argument(
78 | "--dataset_file", type=str, default=None, help="Path to CSV file containing metadata about training data."
79 | )
80 | parser.add_argument(
81 | "--caption_column",
82 | type=str,
83 | default="caption",
84 | help="If using a CSV file via the `--dataset_file` argument, this should be the name of the column containing the captions. If using the folder structure format for data loading, this should be the name of the file containing line-separated captions (the file should be located in `--data_root`).",
85 | )
86 | parser.add_argument(
87 | "--video_column",
88 | type=str,
89 | default="video",
90 | help="If using a CSV file via the `--dataset_file` argument, this should be the name of the column containing the video paths. If using the folder structure format for data loading, this should be the name of the file containing line-separated video paths (the file should be located in `--data_root`).",
91 | )
92 | parser.add_argument(
93 | "--id_token",
94 | type=str,
95 | default=None,
96 | help="Identifier token appended to the start of each prompt if provided.",
97 | )
98 | parser.add_argument(
99 | "--height_buckets",
100 | nargs="+",
101 | type=check_height,
102 | default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536],
103 | )
104 | parser.add_argument(
105 | "--width_buckets",
106 | nargs="+",
107 | type=check_width,
108 | default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536],
109 | )
110 | parser.add_argument(
111 | "--frame_buckets",
112 | nargs="+",
113 | type=check_frames,
114 | default=[49],
115 | )
116 | parser.add_argument(
117 | "--random_flip",
118 | type=float,
119 | default=None,
120 | help="If random horizontal flip augmentation is to be used, this should be the flip probability.",
121 | )
122 | parser.add_argument(
123 | "--dataloader_num_workers",
124 | type=int,
125 | default=0,
126 | help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
127 | )
128 | parser.add_argument(
129 | "--pin_memory",
130 | action="store_true",
131 | help="Whether or not to use the pinned memory setting in pytorch dataloader.",
132 | )
133 | parser.add_argument(
134 | "--video_reshape_mode",
135 | type=str,
136 | default=None,
137 | help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
138 | )
139 | parser.add_argument(
140 | "--save_image_latents",
141 | action="store_true",
142 | help="Whether or not to encode and store image latents, which are required for image-to-video finetuning. The image latents are the first frame of input videos encoded with the VAE.",
143 | )
144 | parser.add_argument(
145 | "--output_dir",
146 | type=str,
147 | required=True,
148 | help="Path to output directory where preprocessed videos/latents/embeddings will be saved.",
149 | )
150 | parser.add_argument("--max_num_frames", type=int, default=49, help="Maximum number of frames in output video.")
151 | parser.add_argument(
152 | "--max_sequence_length", type=int, default=226, help="Max sequence length of prompt embeddings."
153 | )
154 | parser.add_argument("--target_fps", type=int, default=8, help="Frame rate of output videos.")
155 | parser.add_argument(
156 | "--save_latents_and_embeddings",
157 | action="store_true",
158 | help="Whether to encode videos/captions to latents/embeddings and save them in pytorch serializable format.",
159 | )
160 | parser.add_argument(
161 | "--use_slicing",
162 | action="store_true",
163 | help="Whether to enable sliced encoding/decoding in the VAE. Only used if `--save_latents_and_embeddings` is also used.",
164 | )
165 | parser.add_argument(
166 | "--use_tiling",
167 | action="store_true",
168 | help="Whether to enable tiled encoding/decoding in the VAE. Only used if `--save_latents_and_embeddings` is also used.",
169 | )
170 | parser.add_argument("--batch_size", type=int, default=1, help="Number of videos to process at once in the VAE.")
171 | parser.add_argument(
172 | "--num_decode_threads",
173 | type=int,
174 | default=0,
175 | help="Number of decoding threads for `decord` to use. The default `0` means to automatically determine required number of threads.",
176 | )
177 | parser.add_argument(
178 | "--dtype",
179 | type=str,
180 | choices=["fp32", "fp16", "bf16"],
181 | default="fp32",
182 | help="Data type to use when generating latents and prompt embeddings.",
183 | )
184 | parser.add_argument("--seed", type=int, default=42, help="Seed for reproducibility.")
185 | parser.add_argument(
186 | "--num_artifact_workers", type=int, default=4, help="Number of worker threads for serializing artifacts."
187 | )
188 | return parser.parse_args()
189 |
190 |
191 | def _get_t5_prompt_embeds(
192 | tokenizer: T5Tokenizer,
193 | text_encoder: T5EncoderModel,
194 | prompt: Union[str, List[str]],
195 | num_videos_per_prompt: int = 1,
196 | max_sequence_length: int = 226,
197 | device: Optional[torch.device] = None,
198 | dtype: Optional[torch.dtype] = None,
199 | text_input_ids=None,
200 | ):
201 | prompt = [prompt] if isinstance(prompt, str) else prompt
202 | batch_size = len(prompt)
203 |
204 | if tokenizer is not None:
205 | text_inputs = tokenizer(
206 | prompt,
207 | padding="max_length",
208 | max_length=max_sequence_length,
209 | truncation=True,
210 | add_special_tokens=True,
211 | return_tensors="pt",
212 | )
213 | text_input_ids = text_inputs.input_ids
214 | else:
215 | if text_input_ids is None:
216 | raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")
217 |
218 | prompt_embeds = text_encoder(text_input_ids.to(device))[0]
219 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
220 |
221 | # duplicate text embeddings for each generation per prompt, using mps friendly method
222 | _, seq_len, _ = prompt_embeds.shape
223 | prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
224 | prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
225 |
226 | return prompt_embeds
227 |
228 |
229 | def encode_prompt(
230 | tokenizer: T5Tokenizer,
231 | text_encoder: T5EncoderModel,
232 | prompt: Union[str, List[str]],
233 | num_videos_per_prompt: int = 1,
234 | max_sequence_length: int = 226,
235 | device: Optional[torch.device] = None,
236 | dtype: Optional[torch.dtype] = None,
237 | text_input_ids=None,
238 | ):
239 | prompt = [prompt] if isinstance(prompt, str) else prompt
240 | prompt_embeds = _get_t5_prompt_embeds(
241 | tokenizer,
242 | text_encoder,
243 | prompt=prompt,
244 | num_videos_per_prompt=num_videos_per_prompt,
245 | max_sequence_length=max_sequence_length,
246 | device=device,
247 | dtype=dtype,
248 | text_input_ids=text_input_ids,
249 | )
250 | return prompt_embeds
251 |
252 |
253 | def compute_prompt_embeddings(
254 | tokenizer: T5Tokenizer,
255 | text_encoder: T5EncoderModel,
256 | prompts: List[str],
257 | max_sequence_length: int,
258 | device: torch.device,
259 | dtype: torch.dtype,
260 | requires_grad: bool = False,
261 | ):
262 | if requires_grad:
263 | prompt_embeds = encode_prompt(
264 | tokenizer,
265 | text_encoder,
266 | prompts,
267 | num_videos_per_prompt=1,
268 | max_sequence_length=max_sequence_length,
269 | device=device,
270 | dtype=dtype,
271 | )
272 | else:
273 | with torch.no_grad():
274 | prompt_embeds = encode_prompt(
275 | tokenizer,
276 | text_encoder,
277 | prompts,
278 | num_videos_per_prompt=1,
279 | max_sequence_length=max_sequence_length,
280 | device=device,
281 | dtype=dtype,
282 | )
283 | return prompt_embeds
284 |
285 |
286 | to_pil_image = transforms.ToPILImage(mode="RGB")
287 |
288 |
289 | def save_image(image: torch.Tensor, path: pathlib.Path) -> None:
290 | image = image.to(dtype=torch.float32).clamp(-1, 1)
291 | image = to_pil_image(image.float())
292 | image.save(path)
293 |
294 |
295 | def save_video(video: torch.Tensor, path: pathlib.Path, fps: int = 8) -> None:
296 | video = video.to(dtype=torch.float32).clamp(-1, 1)
297 | video = [to_pil_image(frame) for frame in video]
298 | export_to_video(video, path, fps=fps)
299 |
300 |
301 | def save_prompt(prompt: str, path: pathlib.Path) -> None:
302 | with open(path, "w", encoding="utf-8") as file:
303 | file.write(prompt)
304 |
305 |
306 | def save_metadata(metadata: Dict[str, Any], path: pathlib.Path) -> None:
307 | with open(path, "w", encoding="utf-8") as file:
308 | file.write(json.dumps(metadata))
309 |
310 |
311 | @torch.no_grad()
312 | def serialize_artifacts(
313 | batch_size: int,
314 | fps: int,
315 | images_dir: Optional[pathlib.Path] = None,
316 | image_latents_dir: Optional[pathlib.Path] = None,
317 | videos_dir: Optional[pathlib.Path] = None,
318 | video_latents_dir: Optional[pathlib.Path] = None,
319 | prompts_dir: Optional[pathlib.Path] = None,
320 | prompt_embeds_dir: Optional[pathlib.Path] = None,
321 | images: Optional[torch.Tensor] = None,
322 | image_latents: Optional[torch.Tensor] = None,
323 | videos: Optional[torch.Tensor] = None,
324 | video_latents: Optional[torch.Tensor] = None,
325 | prompts: Optional[List[str]] = None,
326 | prompt_embeds: Optional[torch.Tensor] = None,
327 | ) -> None:
328 | num_frames, height, width = videos.size(1), videos.size(3), videos.size(4)
329 | metadata = [{"num_frames": num_frames, "height": height, "width": width}]
330 |
331 | data_folder_mapper_list = [
332 | (images, images_dir, lambda img, path: save_image(img[0], path), "png"),
333 | (image_latents, image_latents_dir, torch.save, "pt"),
334 | (videos, videos_dir, functools.partial(save_video, fps=fps), "mp4"),
335 | (video_latents, video_latents_dir, torch.save, "pt"),
336 | (prompts, prompts_dir, save_prompt, "txt"),
337 | (prompt_embeds, prompt_embeds_dir, torch.save, "pt"),
338 | (metadata, videos_dir, save_metadata, "txt"),
339 | ]
340 | filenames = [uuid.uuid4() for _ in range(batch_size)]
341 |
342 | for data, folder, save_fn, extension in data_folder_mapper_list:
343 | if data is None:
344 | continue
345 | for slice, filename in zip(data, filenames):
346 | if isinstance(slice, torch.Tensor):
347 | slice = slice.clone().to("cpu")
348 | path = folder.joinpath(f"{filename}.{extension}")
349 | save_fn(slice, path)
350 |
351 |
352 | def save_intermediates(output_queue: queue.Queue) -> None:
353 | while True:
354 | try:
355 | item = output_queue.get(timeout=30)
356 | if item is None:
357 | break
358 | serialize_artifacts(**item)
359 |
360 | except queue.Empty:
361 | continue
362 |
363 |
364 | @torch.no_grad()
365 | def main():
366 | args = get_args()
367 | set_seed(args.seed)
368 |
369 | output_dir = pathlib.Path(args.output_dir)
370 | tmp_dir = output_dir.joinpath("tmp")
371 |
372 | output_dir.mkdir(parents=True, exist_ok=True)
373 | tmp_dir.mkdir(parents=True, exist_ok=True)
374 |
375 | # Create task queue for non-blocking serializing of artifacts
376 | output_queue = queue.Queue()
377 | save_thread = ThreadPoolExecutor(max_workers=args.num_artifact_workers)
378 | save_future = save_thread.submit(save_intermediates, output_queue)
379 |
380 | # Initialize distributed processing
381 | if "LOCAL_RANK" in os.environ:
382 | local_rank = int(os.environ["LOCAL_RANK"])
383 | torch.cuda.set_device(local_rank)
384 | dist.init_process_group(backend="nccl")
385 | world_size = dist.get_world_size()
386 | rank = dist.get_rank()
387 | else:
388 | # Single GPU
389 | local_rank = 0
390 | world_size = 1
391 | rank = 0
392 | torch.cuda.set_device(rank)
393 |
394 | # Create folders where intermediate tensors from each rank will be saved
395 | images_dir = tmp_dir.joinpath(f"images/{rank}")
396 | image_latents_dir = tmp_dir.joinpath(f"image_latents/{rank}")
397 | videos_dir = tmp_dir.joinpath(f"videos/{rank}")
398 | video_latents_dir = tmp_dir.joinpath(f"video_latents/{rank}")
399 | prompts_dir = tmp_dir.joinpath(f"prompts/{rank}")
400 | prompt_embeds_dir = tmp_dir.joinpath(f"prompt_embeds/{rank}")
401 |
402 | images_dir.mkdir(parents=True, exist_ok=True)
403 | image_latents_dir.mkdir(parents=True, exist_ok=True)
404 | videos_dir.mkdir(parents=True, exist_ok=True)
405 | video_latents_dir.mkdir(parents=True, exist_ok=True)
406 | prompts_dir.mkdir(parents=True, exist_ok=True)
407 | prompt_embeds_dir.mkdir(parents=True, exist_ok=True)
408 |
409 | weight_dtype = DTYPE_MAPPING[args.dtype]
410 | target_fps = args.target_fps
411 |
412 | # 1. Dataset
413 | dataset_init_kwargs = {
414 | "data_root": args.data_root,
415 | "dataset_file": args.dataset_file,
416 | "caption_column": args.caption_column,
417 | "video_column": args.video_column,
418 | "max_num_frames": args.max_num_frames,
419 | "id_token": args.id_token,
420 | "height_buckets": args.height_buckets,
421 | "width_buckets": args.width_buckets,
422 | "frame_buckets": args.frame_buckets,
423 | "load_tensors": False,
424 | "random_flip": args.random_flip,
425 | "image_to_video": args.save_image_latents,
426 | }
427 | if args.video_reshape_mode is None:
428 | dataset = VideoDatasetWithResizing(**dataset_init_kwargs)
429 | else:
430 | dataset = VideoDatasetWithResizeAndRectangleCrop(
431 | video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs
432 | )
433 |
434 | original_dataset_size = len(dataset)
435 |
436 | # Split data among GPUs
437 | if world_size > 1:
438 | samples_per_gpu = original_dataset_size // world_size
439 | start_index = rank * samples_per_gpu
440 | end_index = start_index + samples_per_gpu
441 | if rank == world_size - 1:
442 | end_index = original_dataset_size # Make sure the last GPU gets the remaining data
443 |
444 | # Slice the data
445 | dataset.prompts = dataset.prompts[start_index:end_index]
446 | dataset.video_paths = dataset.video_paths[start_index:end_index]
447 | else:
448 |         pass  # single process: keep the full dataset
449 |
450 | rank_dataset_size = len(dataset)
451 |
452 | # 2. Dataloader
453 | def collate_fn(data):
454 | prompts = [x["prompt"] for x in data[0]]
455 |
456 | images = None
457 | if args.save_image_latents:
458 | images = [x["image"] for x in data[0]]
459 | images = torch.stack(images).to(dtype=weight_dtype, non_blocking=True)
460 |
461 | videos = [x["video"] for x in data[0]]
462 | videos = torch.stack(videos).to(dtype=weight_dtype, non_blocking=True)
463 |
464 | control_videos = [x["control_video"] for x in data[0]] #FIXME
465 | control_videos = torch.stack(control_videos).to(dtype=weight_dtype, non_blocking=True)
466 |
467 | return {
468 | "images": images,
469 | "videos": videos,
470 | "control_videos": control_videos,
471 | "prompts": prompts,
472 | }
473 |
474 | dataloader = DataLoader(
475 | dataset,
476 | batch_size=1,
477 | sampler=BucketSampler(dataset, batch_size=args.batch_size, shuffle=True, drop_last=False),
478 | collate_fn=collate_fn,
479 | num_workers=args.dataloader_num_workers,
480 | pin_memory=args.pin_memory,
481 | )
482 |
483 | # 3. Prepare models
484 | device = f"cuda:{rank}"
485 |
486 | if args.save_latents_and_embeddings:
487 | tokenizer = T5Tokenizer.from_pretrained(args.model_id, subfolder="tokenizer")
488 | text_encoder = T5EncoderModel.from_pretrained(
489 | args.model_id, subfolder="text_encoder", torch_dtype=weight_dtype
490 | )
491 | text_encoder = text_encoder.to(device)
492 |
493 | vae = AutoencoderKLCogVideoX.from_pretrained(args.model_id, subfolder="vae", torch_dtype=weight_dtype)
494 | vae = vae.to(device)
495 |
496 |         if args.use_slicing:  # only configure the VAE when it was created above
497 |             vae.enable_slicing()
498 |         if args.use_tiling:
499 |             vae.enable_tiling()
500 |
501 | # 4. Compute latents and embeddings and save
502 | if rank == 0:
503 | iterator = tqdm(
504 | dataloader, desc="Encoding", total=(rank_dataset_size + args.batch_size - 1) // args.batch_size
505 | )
506 | else:
507 | iterator = dataloader
508 |
509 | for step, batch in enumerate(iterator):
510 | try:
511 | images = None
512 | image_latents = None
513 | video_latents = None
514 | prompt_embeds = None
515 |
516 | if args.save_image_latents:
517 | images = batch["images"].to(device, non_blocking=True)
518 | images = images.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
519 |
520 | videos = batch["videos"].to(device, non_blocking=True)
521 | videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
522 |
523 | prompts = batch["prompts"]
524 |
525 | # Encode videos & images
526 | if args.save_latents_and_embeddings:
527 | if args.use_slicing:
528 | if args.save_image_latents:
529 | encoded_slices = [vae._encode(image_slice) for image_slice in images.split(1)]
530 | image_latents = torch.cat(encoded_slices)
531 | image_latents = image_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
532 |
533 | encoded_slices = [vae._encode(video_slice) for video_slice in videos.split(1)]
534 | video_latents = torch.cat(encoded_slices)
535 |
536 | else:
537 | if args.save_image_latents:
538 | image_latents = vae._encode(images)
539 | image_latents = image_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
540 |
541 | video_latents = vae._encode(videos)
542 |
543 | video_latents = video_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
544 |
545 | # Encode prompts
546 | prompt_embeds = compute_prompt_embeddings(
547 | tokenizer,
548 | text_encoder,
549 | prompts,
550 | args.max_sequence_length,
551 | device,
552 | weight_dtype,
553 | requires_grad=False,
554 | )
555 |
556 | if images is not None:
557 | images = (images.permute(0, 2, 1, 3, 4) + 1) / 2
558 |
559 | videos = (videos.permute(0, 2, 1, 3, 4) + 1) / 2
560 |
561 | output_queue.put(
562 | {
563 | "batch_size": len(prompts),
564 | "fps": target_fps,
565 | "images_dir": images_dir,
566 | "image_latents_dir": image_latents_dir,
567 | "videos_dir": videos_dir,
568 | "video_latents_dir": video_latents_dir,
569 | "prompts_dir": prompts_dir,
570 | "prompt_embeds_dir": prompt_embeds_dir,
571 | "images": images,
572 | "image_latents": image_latents,
573 | "videos": videos,
574 | "video_latents": video_latents,
575 | "prompts": prompts,
576 | "prompt_embeds": prompt_embeds,
577 | }
578 | )
579 |
580 | except Exception:
581 | print("-------------------------")
582 | print(f"An exception occurred while processing data: {rank=}, {world_size=}, {step=}")
583 | traceback.print_exc()
584 | print("-------------------------")
585 |
586 | # 5. Complete distributed processing
587 | if world_size > 1:
588 | dist.barrier()
589 | dist.destroy_process_group()
590 |
591 | output_queue.put(None)
592 | save_thread.shutdown(wait=True)
593 | save_future.result()
594 |
595 | # 6. Combine results from each rank
596 | if rank == 0:
597 | print(
598 | f"Completed preprocessing latents and embeddings. Temporary files from all ranks saved to `{tmp_dir.as_posix()}`"
599 | )
600 |
601 | # Move files from each rank to common directory
602 | for subfolder, extension in [
603 | ("images", "png"),
604 | ("image_latents", "pt"),
605 | ("videos", "mp4"),
606 | ("video_latents", "pt"),
607 | ("prompts", "txt"),
608 | ("prompt_embeds", "pt"),
609 | ("videos", "txt"),
610 | ]:
611 | tmp_subfolder = tmp_dir.joinpath(subfolder)
612 | combined_subfolder = output_dir.joinpath(subfolder)
613 | combined_subfolder.mkdir(parents=True, exist_ok=True)
614 | pattern = f"*.{extension}"
615 |
616 | for file in tmp_subfolder.rglob(pattern):
617 | file.replace(combined_subfolder / file.name)
618 |
619 | # Remove temporary directories
620 |         def rmdir_recursive(directory: pathlib.Path) -> None:
621 |             for child in directory.iterdir():
622 |                 if child.is_file():
623 |                     child.unlink()
624 |                 else:
625 |                     rmdir_recursive(child)
626 |             directory.rmdir()
627 |
628 | rmdir_recursive(tmp_dir)
629 |
630 |         # Combine prompts and videos into individual text files and a single jsonl
631 | prompts_folder = output_dir.joinpath("prompts")
632 | prompts = []
633 | stems = []
634 |
635 | for filename in prompts_folder.rglob("*.txt"):
636 | with open(filename, "r") as file:
637 | prompts.append(file.read().strip())
638 | stems.append(filename.stem)
639 |
640 | prompts_txt = output_dir.joinpath("prompts.txt")
641 | videos_txt = output_dir.joinpath("videos.txt")
642 | data_jsonl = output_dir.joinpath("data.jsonl")
643 |
644 | with open(prompts_txt, "w") as file:
645 | for prompt in prompts:
646 | file.write(f"{prompt}\n")
647 |
648 | with open(videos_txt, "w") as file:
649 | for stem in stems:
650 | file.write(f"videos/{stem}.mp4\n")
651 |
652 | with open(data_jsonl, "w") as file:
653 | for prompt, stem in zip(prompts, stems):
654 | video_metadata_txt = output_dir.joinpath(f"videos/{stem}.txt")
655 | with open(video_metadata_txt, "r", encoding="utf-8") as metadata_file:
656 | metadata = json.loads(metadata_file.read())
657 |
658 | data = {
659 | "prompt": prompt,
660 | "prompt_embed": f"prompt_embeds/{stem}.pt",
661 | "image": f"images/{stem}.png",
662 | "image_latent": f"image_latents/{stem}.pt",
663 | "video": f"videos/{stem}.mp4",
664 | "video_latent": f"video_latents/{stem}.pt",
665 | "metadata": metadata,
666 | }
667 | file.write(json.dumps(data) + "\n")
668 |
669 | print(f"Completed preprocessing. All files saved to `{output_dir.as_posix()}`")
670 |
671 |
672 | if __name__ == "__main__":
673 | main()
674 |
--------------------------------------------------------------------------------
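Each record written to `data.jsonl` above points at the serialized prompt embedding, image/video latents, and the per-video metadata. A minimal sketch (not part of the repository) of how a downstream loader could read these artifacts back; `preprocessed` is a placeholder for whatever `--output_dir` was passed to the script:

```python
import json
import pathlib

import torch

output_dir = pathlib.Path("preprocessed")  # placeholder for the --output_dir used above

# data.jsonl holds one JSON record per preprocessed sample
with open(output_dir / "data.jsonl", "r", encoding="utf-8") as f:
    records = [json.loads(line) for line in f]

sample = records[0]
prompt_embeds = torch.load(output_dir / sample["prompt_embed"], map_location="cpu")
video_latents = torch.load(output_dir / sample["video_latent"], map_location="cpu")
print(sample["prompt"], sample["metadata"], tuple(prompt_embeds.shape), tuple(video_latents.shape))
```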
/utils/save_utils.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import os
3 |
4 | import cv2
5 | import imageio
6 | import numpy as np
7 | import torch
8 | import torchvision
9 | from einops import rearrange
10 | from PIL import Image
11 |
12 |
13 | def get_width_and_height_from_image_and_base_resolution(image, base_resolution):
14 | target_pixels = int(base_resolution) * int(base_resolution)
15 | original_width, original_height = Image.open(image).size
16 | ratio = (target_pixels / (original_width * original_height)) ** 0.5
17 | width_slider = round(original_width * ratio)
18 | height_slider = round(original_height * ratio)
19 | return height_slider, width_slider
20 |
21 | def color_transfer(sc, dc):
22 | """
23 |     Transfer the color distribution of the reference image `dc` onto the source image `sc` by matching per-channel mean and std in LAB space.
24 | 
25 |     Args:
26 |         sc (numpy.ndarray): source image (RGB) to be transferred.
27 |         dc (numpy.ndarray): reference image (RGB).
28 | 
29 |     Returns:
30 |         numpy.ndarray: `sc` with the color statistics of `dc` applied.
31 | """
32 |
33 | def get_mean_and_std(img):
34 | x_mean, x_std = cv2.meanStdDev(img)
35 | x_mean = np.hstack(np.around(x_mean, 2))
36 | x_std = np.hstack(np.around(x_std, 2))
37 | return x_mean, x_std
38 |
39 | sc = cv2.cvtColor(sc, cv2.COLOR_RGB2LAB)
40 | s_mean, s_std = get_mean_and_std(sc)
41 | dc = cv2.cvtColor(dc, cv2.COLOR_RGB2LAB)
42 | t_mean, t_std = get_mean_and_std(dc)
43 | img_n = ((sc - s_mean) * (t_std / s_std)) + t_mean
44 | np.putmask(img_n, img_n > 255, 255)
45 | np.putmask(img_n, img_n < 0, 0)
46 | dst = cv2.cvtColor(cv2.convertScaleAbs(img_n), cv2.COLOR_LAB2RGB)
47 | return dst
48 |
49 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=12, imageio_backend=True, color_transfer_post_process=False):
50 | videos = rearrange(videos, "b c t h w -> t b c h w")
51 | outputs = []
52 | for x in videos:
53 | x = torchvision.utils.make_grid(x, nrow=n_rows)
54 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
55 | if rescale:
56 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1
57 | x = (x * 255).numpy().astype(np.uint8)
58 | outputs.append(Image.fromarray(x))
59 |
60 | if color_transfer_post_process:
61 | for i in range(1, len(outputs)):
62 | outputs[i] = Image.fromarray(color_transfer(np.uint8(outputs[i]), np.uint8(outputs[0])))
63 |
64 | os.makedirs(os.path.dirname(path), exist_ok=True)
65 | if imageio_backend:
66 | if path.endswith("mp4"):
67 | imageio.mimsave(path, outputs, fps=fps)
68 | else:
69 | imageio.mimsave(path, outputs, duration=(1000 * 1/fps))
70 | else:
71 | if path.endswith("mp4"):
72 | path = path.replace('.mp4', '.gif')
73 | outputs[0].save(path, format='GIF', append_images=outputs, save_all=True, duration=100, loop=0)
74 |
75 | def get_image_to_video_latent(validation_image_start, validation_image_end, video_length, sample_size):
76 | if validation_image_start is not None and validation_image_end is not None:
77 | if type(validation_image_start) is str and os.path.isfile(validation_image_start):
78 | image_start = clip_image = Image.open(validation_image_start).convert("RGB")
79 | image_start = image_start.resize([sample_size[1], sample_size[0]])
80 | clip_image = clip_image.resize([sample_size[1], sample_size[0]])
81 | else:
82 | image_start = clip_image = validation_image_start
83 | image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start]
84 | clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image]
85 |
86 | if type(validation_image_end) is str and os.path.isfile(validation_image_end):
87 | image_end = Image.open(validation_image_end).convert("RGB")
88 | image_end = image_end.resize([sample_size[1], sample_size[0]])
89 | else:
90 | image_end = validation_image_end
91 | image_end = [_image_end.resize([sample_size[1], sample_size[0]]) for _image_end in image_end]
92 |
93 | if type(image_start) is list:
94 | clip_image = clip_image[0]
95 | start_video = torch.cat(
96 | [torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_start in image_start],
97 | dim=2
98 | )
99 | input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1])
100 | input_video[:, :, :len(image_start)] = start_video
101 |
102 | input_video_mask = torch.zeros_like(input_video[:, :1])
103 | input_video_mask[:, :, len(image_start):] = 255
104 | else:
105 | input_video = torch.tile(
106 | torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0),
107 | [1, 1, video_length, 1, 1]
108 | )
109 | input_video_mask = torch.zeros_like(input_video[:, :1])
110 | input_video_mask[:, :, 1:] = 255
111 |
112 | if type(image_end) is list:
113 | image_end = [_image_end.resize(image_start[0].size if type(image_start) is list else image_start.size) for _image_end in image_end]
114 | end_video = torch.cat(
115 | [torch.from_numpy(np.array(_image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_end in image_end],
116 | dim=2
117 | )
118 |                 input_video[:, :, -len(image_end):] = end_video  # index by frame count (dim 2); len(end_video) would give the batch dim
119 |
120 | input_video_mask[:, :, -len(image_end):] = 0
121 | else:
122 | image_end = image_end.resize(image_start[0].size if type(image_start) is list else image_start.size)
123 | input_video[:, :, -1:] = torch.from_numpy(np.array(image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0)
124 | input_video_mask[:, :, -1:] = 0
125 |
126 | input_video = input_video / 255
127 |
128 | elif validation_image_start is not None:
129 | if type(validation_image_start) is str and os.path.isfile(validation_image_start):
130 | image_start = clip_image = Image.open(validation_image_start).convert("RGB")
131 | image_start = image_start.resize([sample_size[1], sample_size[0]])
132 | clip_image = clip_image.resize([sample_size[1], sample_size[0]])
133 | else:
134 | image_start = clip_image = validation_image_start
135 | image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start]
136 | clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image]
137 | image_end = None
138 |
139 | if type(image_start) is list:
140 | clip_image = clip_image[0]
141 | start_video = torch.cat(
142 | [torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_start in image_start],
143 | dim=2
144 | )
145 | input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1])
146 | input_video[:, :, :len(image_start)] = start_video
147 | input_video = input_video / 255
148 |
149 | input_video_mask = torch.zeros_like(input_video[:, :1])
150 | input_video_mask[:, :, len(image_start):] = 255
151 | else:
152 | input_video = torch.tile(
153 | torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0),
154 | [1, 1, video_length, 1, 1]
155 | ) / 255
156 | input_video_mask = torch.zeros_like(input_video[:, :1])
157 |             input_video_mask[:, :, 1:] = 255
158 | else:
159 | image_start = None
160 | image_end = None
161 | input_video = torch.zeros([1, 3, video_length, sample_size[0], sample_size[1]])
162 | input_video_mask = torch.ones([1, 1, video_length, sample_size[0], sample_size[1]]) * 255
163 | clip_image = None
164 |
165 | del image_start
166 | del image_end
167 | gc.collect()
168 |
169 | return input_video, input_video_mask, clip_image
170 |
171 | def get_video_to_video_latent(input_video_path, video_length, sample_size, fps=None, validation_video_mask=None, ref_image=None):
172 | if isinstance(input_video_path, str):
173 | cap = cv2.VideoCapture(input_video_path)
174 | input_video = []
175 |
176 | original_fps = cap.get(cv2.CAP_PROP_FPS)
177 |         frame_skip = 1 if fps is None else max(1, int(original_fps // fps))  # avoid a zero skip when the requested fps exceeds the source fps
178 |
179 | frame_count = 0
180 |
181 | while True:
182 | ret, frame = cap.read()
183 | if not ret:
184 | break
185 |
186 | if frame_count % frame_skip == 0:
187 | frame = cv2.resize(frame, (sample_size[1], sample_size[0]))
188 | input_video.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
189 |
190 | frame_count += 1
191 |
192 | cap.release()
193 | else:
194 | input_video = input_video_path
195 |
196 | input_video = torch.from_numpy(np.array(input_video))[:video_length]
197 | input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255
198 |
199 | if ref_image is not None:
200 | ref_image = Image.open(ref_image)
201 | ref_image = torch.from_numpy(np.array(ref_image))
202 | ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255
203 |
204 | if validation_video_mask is not None:
205 | validation_video_mask = Image.open(validation_video_mask).convert('L').resize((sample_size[1], sample_size[0]))
206 | input_video_mask = np.where(np.array(validation_video_mask) < 240, 0, 255)
207 |
208 | input_video_mask = torch.from_numpy(np.array(input_video_mask)).unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0)
209 | input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1])
210 | input_video_mask = input_video_mask.to(input_video.device, input_video.dtype)
211 | else:
212 | input_video_mask = torch.zeros_like(input_video[:, :1])
213 | input_video_mask[:, :, :] = 255
214 |
215 | return input_video, input_video_mask, ref_image
--------------------------------------------------------------------------------
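`save_videos_grid` takes a float video batch shaped `[b, c, t, h, w]`, tiles each frame's batch into a grid, and writes an MP4 (or GIF) with `imageio`; values should be in `[0, 1]`, or in `[-1, 1]` with `rescale=True`. A minimal usage sketch, assuming the module is importable as `utils.save_utils` and an ffmpeg backend is available for `imageio`:

```python
import torch

from utils.save_utils import save_videos_grid

# Random stand-in for a decoded video batch: [batch, channels, frames, height, width] in [0, 1]
videos = torch.rand(2, 3, 16, 256, 256)
save_videos_grid(videos, "outputs/sample_grid.mp4", rescale=False, n_rows=2, fps=12)
```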
/utils/text_encoder/__init__.py:
--------------------------------------------------------------------------------
1 | from .text_encoder import compute_prompt_embeddings
--------------------------------------------------------------------------------
/utils/text_encoder/text_encoder.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Union
2 |
3 | import torch
4 | from transformers import T5EncoderModel, T5Tokenizer
5 |
6 |
7 | def _get_t5_prompt_embeds(
8 | tokenizer: T5Tokenizer,
9 | text_encoder: T5EncoderModel,
10 | prompt: Union[str, List[str]],
11 | num_videos_per_prompt: int = 1,
12 | max_sequence_length: int = 226,
13 | device: Optional[torch.device] = None,
14 | dtype: Optional[torch.dtype] = None,
15 | text_input_ids=None,
16 | ):
17 | prompt = [prompt] if isinstance(prompt, str) else prompt
18 | batch_size = len(prompt)
19 |
20 | if tokenizer is not None:
21 | text_inputs = tokenizer(
22 | prompt,
23 | padding="max_length",
24 | max_length=max_sequence_length,
25 | truncation=True,
26 | add_special_tokens=True,
27 | return_tensors="pt",
28 | )
29 | text_input_ids = text_inputs.input_ids
30 | else:
31 | if text_input_ids is None:
32 | raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")
33 |
34 | prompt_embeds = text_encoder(text_input_ids.to(device))[0]
35 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
36 |
37 |     # Duplicate text embeddings for each generation per prompt, using an MPS-friendly method
38 | _, seq_len, _ = prompt_embeds.shape
39 | prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
40 | prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
41 |
42 | return prompt_embeds
43 |
44 |
45 | def encode_prompt(
46 | tokenizer: T5Tokenizer,
47 | text_encoder: T5EncoderModel,
48 | prompt: Union[str, List[str]],
49 | num_videos_per_prompt: int = 1,
50 | max_sequence_length: int = 226,
51 | device: Optional[torch.device] = None,
52 | dtype: Optional[torch.dtype] = None,
53 | text_input_ids=None,
54 | ):
55 | prompt = [prompt] if isinstance(prompt, str) else prompt
56 | prompt_embeds = _get_t5_prompt_embeds(
57 | tokenizer,
58 | text_encoder,
59 | prompt=prompt,
60 | num_videos_per_prompt=num_videos_per_prompt,
61 | max_sequence_length=max_sequence_length,
62 | device=device,
63 | dtype=dtype,
64 | text_input_ids=text_input_ids,
65 | )
66 | return prompt_embeds
67 |
68 |
69 | def compute_prompt_embeddings(
70 | tokenizer: T5Tokenizer,
71 | text_encoder: T5EncoderModel,
72 | prompt: str,
73 | max_sequence_length: int,
74 | device: torch.device,
75 | dtype: torch.dtype,
76 | requires_grad: bool = False,
77 | ):
78 | if requires_grad:
79 | prompt_embeds = encode_prompt(
80 | tokenizer,
81 | text_encoder,
82 | prompt,
83 | num_videos_per_prompt=1,
84 | max_sequence_length=max_sequence_length,
85 | device=device,
86 | dtype=dtype,
87 | )
88 | else:
89 | with torch.no_grad():
90 | prompt_embeds = encode_prompt(
91 | tokenizer,
92 | text_encoder,
93 | prompt,
94 | num_videos_per_prompt=1,
95 | max_sequence_length=max_sequence_length,
96 | device=device,
97 | dtype=dtype,
98 | )
99 | return prompt_embeds
100 |
--------------------------------------------------------------------------------
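`compute_prompt_embeddings` wraps the T5 encoder call used by the preprocessing and training scripts, padding/truncating prompts to `max_sequence_length` and returning embeddings on the requested device and dtype. A minimal usage sketch; `THUDM/CogVideoX-5b` is only an illustrative checkpoint with `tokenizer` and `text_encoder` subfolders, mirroring how `args.model_id` is loaded above:

```python
import torch
from transformers import T5EncoderModel, T5Tokenizer

from utils.text_encoder import compute_prompt_embeddings

model_id = "THUDM/CogVideoX-5b"  # placeholder; use the same checkpoint as the rest of the pipeline
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32

tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=dtype).to(device)

prompt_embeds = compute_prompt_embeddings(
    tokenizer,
    text_encoder,
    "A person dancing in a park.",
    max_sequence_length=226,
    device=device,
    dtype=dtype,
    requires_grad=False,
)
print(prompt_embeds.shape)  # [1, max_sequence_length, hidden_dim]
```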
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import inspect
3 | from typing import Optional, Tuple, Union
4 |
5 | import torch
6 | from accelerate import Accelerator
7 | from accelerate.logging import get_logger
8 | from diffusers.models.embeddings import get_3d_rotary_pos_embed
9 | from diffusers.utils.torch_utils import is_compiled_module
10 |
11 |
12 | logger = get_logger(__name__)
13 |
14 |
15 | def get_optimizer(
16 | params_to_optimize,
17 | optimizer_name: str = "adam",
18 | learning_rate: float = 1e-3,
19 | beta1: float = 0.9,
20 | beta2: float = 0.95,
21 | beta3: float = 0.98,
22 | epsilon: float = 1e-8,
23 | weight_decay: float = 1e-4,
24 | prodigy_decouple: bool = False,
25 | prodigy_use_bias_correction: bool = False,
26 | prodigy_safeguard_warmup: bool = False,
27 | use_8bit: bool = False,
28 | use_4bit: bool = False,
29 | use_torchao: bool = False,
30 | use_deepspeed: bool = False,
31 | use_cpu_offload_optimizer: bool = False,
32 | offload_gradients: bool = False,
33 | ) -> torch.optim.Optimizer:
34 | optimizer_name = optimizer_name.lower()
35 |
36 |     # Use the DeepSpeed optimizer (a DummyOptim placeholder; DeepSpeed creates the real optimizer from its config)
37 | if use_deepspeed:
38 | from accelerate.utils import DummyOptim
39 |
40 | return DummyOptim(
41 | params_to_optimize,
42 | lr=learning_rate,
43 | betas=(beta1, beta2),
44 | eps=epsilon,
45 | weight_decay=weight_decay,
46 | )
47 |
48 | if use_8bit and use_4bit:
49 | raise ValueError("Cannot set both `use_8bit` and `use_4bit` to True.")
50 |
51 | if (use_torchao and (use_8bit or use_4bit)) or use_cpu_offload_optimizer:
52 | try:
53 | import torchao
54 |
55 | torchao.__version__
56 | except ImportError:
57 | raise ImportError(
58 | "To use optimizers from torchao, please install the torchao library: `USE_CPP=0 pip install torchao`."
59 | )
60 |
61 | if not use_torchao and use_4bit:
62 | raise ValueError("4-bit Optimizers are only supported with torchao.")
63 |
64 | # Optimizer creation
65 | supported_optimizers = ["adam", "adamw", "prodigy", "came"]
66 | if optimizer_name not in supported_optimizers:
67 | logger.warning(
68 | f"Unsupported choice of optimizer: {optimizer_name}. Supported optimizers include {supported_optimizers}. Defaulting to `AdamW`."
69 | )
70 | optimizer_name = "adamw"
71 |
72 | if (use_8bit or use_4bit) and optimizer_name not in ["adam", "adamw"]:
73 | raise ValueError("`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers.")
74 |
75 | if use_8bit:
76 | try:
77 | import bitsandbytes as bnb
78 | except ImportError:
79 | raise ImportError(
80 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
81 | )
82 |
83 | if optimizer_name == "adamw":
84 | if use_torchao:
85 | from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit
86 |
87 | optimizer_class = AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW
88 | else:
89 | optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW
90 |
91 | init_kwargs = {
92 | "betas": (beta1, beta2),
93 | "eps": epsilon,
94 | "weight_decay": weight_decay,
95 | }
96 |
97 | elif optimizer_name == "adam":
98 | if use_torchao:
99 | from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit
100 |
101 | optimizer_class = Adam8bit if use_8bit else Adam4bit if use_4bit else torch.optim.Adam
102 | else:
103 | optimizer_class = bnb.optim.Adam8bit if use_8bit else torch.optim.Adam
104 |
105 | init_kwargs = {
106 | "betas": (beta1, beta2),
107 | "eps": epsilon,
108 | "weight_decay": weight_decay,
109 | }
110 |
111 | elif optimizer_name == "prodigy":
112 | try:
113 | import prodigyopt
114 | except ImportError:
115 | raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
116 |
117 | optimizer_class = prodigyopt.Prodigy
118 |
119 | if learning_rate <= 0.1:
120 | logger.warning(
121 |                 "Learning rate is too low. When using Prodigy, it is generally better to set the learning rate around 1.0."
122 | )
123 |
124 | init_kwargs = {
125 | "lr": learning_rate,
126 | "betas": (beta1, beta2),
127 | "beta3": beta3,
128 | "eps": epsilon,
129 | "weight_decay": weight_decay,
130 | "decouple": prodigy_decouple,
131 | "use_bias_correction": prodigy_use_bias_correction,
132 | "safeguard_warmup": prodigy_safeguard_warmup,
133 | }
134 |
135 | elif optimizer_name == "came":
136 | try:
137 | import came_pytorch
138 | except ImportError:
139 | raise ImportError("To use CAME, please install the came-pytorch library: `pip install came-pytorch`")
140 |
141 | optimizer_class = came_pytorch.CAME
142 |
143 | init_kwargs = {
144 | "lr": learning_rate,
145 | "eps": (1e-30, 1e-16),
146 | "betas": (beta1, beta2, beta3),
147 | "weight_decay": weight_decay,
148 | }
149 |
150 | if use_cpu_offload_optimizer:
151 | from torchao.prototype.low_bit_optim import CPUOffloadOptimizer
152 |
153 | if "fused" in inspect.signature(optimizer_class.__init__).parameters:
154 | init_kwargs.update({"fused": True})
155 |
156 | optimizer = CPUOffloadOptimizer(
157 | params_to_optimize, optimizer_class=optimizer_class, offload_gradients=offload_gradients, **init_kwargs
158 | )
159 | else:
160 | optimizer = optimizer_class(params_to_optimize, **init_kwargs)
161 |
162 | return optimizer
163 |
164 |
165 | def get_gradient_norm(parameters):
166 | norm = 0
167 | for param in parameters:
168 | if param.grad is None:
169 | continue
170 | local_norm = param.grad.detach().data.norm(2)
171 | norm += local_norm.item() ** 2
172 | norm = norm**0.5
173 | return norm
174 |
175 |
176 | # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
177 | def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
178 | tw = tgt_width
179 | th = tgt_height
180 | h, w = src
181 | r = h / w
182 | if r > (th / tw):
183 | resize_height = th
184 | resize_width = int(round(th / h * w))
185 | else:
186 | resize_width = tw
187 | resize_height = int(round(tw / w * h))
188 |
189 | crop_top = int(round((th - resize_height) / 2.0))
190 | crop_left = int(round((tw - resize_width) / 2.0))
191 |
192 | return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
193 |
194 |
195 | def prepare_rotary_positional_embeddings(
196 | height: int,
197 | width: int,
198 | num_frames: int,
199 | vae_scale_factor_spatial: int = 8,
200 | patch_size: int = 2,
201 | patch_size_t: int = None,
202 | attention_head_dim: int = 64,
203 | device: Optional[torch.device] = None,
204 | base_height: int = 480,
205 | base_width: int = 720,
206 | ) -> Tuple[torch.Tensor, torch.Tensor]:
207 | grid_height = height // (vae_scale_factor_spatial * patch_size)
208 | grid_width = width // (vae_scale_factor_spatial * patch_size)
209 | base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
210 | base_size_height = base_height // (vae_scale_factor_spatial * patch_size)
211 |
212 | if patch_size_t is None:
213 | grid_crops_coords = get_resize_crop_region_for_grid(
214 | (grid_height, grid_width), base_size_width, base_size_height
215 | )
216 | freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
217 | embed_dim=attention_head_dim,
218 | crops_coords=grid_crops_coords,
219 | grid_size=(grid_height, grid_width),
220 | temporal_size=num_frames,
221 | )
222 | else:
223 | base_num_frames = (num_frames + patch_size_t - 1) // patch_size_t
224 |
225 | freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
226 | embed_dim=attention_head_dim,
227 | crops_coords=None,
228 | grid_size=(grid_height, grid_width),
229 | temporal_size=base_num_frames,
230 | grid_type="slice",
231 | max_size=(base_size_height, base_size_width),
232 | )
233 |
234 | freqs_cos = freqs_cos.to(device=device)
235 | freqs_sin = freqs_sin.to(device=device)
236 | return freqs_cos, freqs_sin
237 |
238 |
239 | def reset_memory(device: Union[str, torch.device]) -> None:
240 | gc.collect()
241 | torch.cuda.empty_cache()
242 | torch.cuda.reset_peak_memory_stats(device)
243 | torch.cuda.reset_accumulated_memory_stats(device)
244 |
245 |
246 | def print_memory(device: Union[str, torch.device]) -> None:
247 | memory_allocated = torch.cuda.memory_allocated(device) / 1024**3
248 | max_memory_allocated = torch.cuda.max_memory_allocated(device) / 1024**3
249 | max_memory_reserved = torch.cuda.max_memory_reserved(device) / 1024**3
250 | print(f"{memory_allocated=:.3f} GB")
251 | print(f"{max_memory_allocated=:.3f} GB")
252 | print(f"{max_memory_reserved=:.3f} GB")
253 |
254 |
255 | def unwrap_model(accelerator: Accelerator, model):
256 | model = accelerator.unwrap_model(model)
257 | model = model._orig_mod if is_compiled_module(model) else model
258 | return model
259 |
--------------------------------------------------------------------------------
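Note that `get_optimizer` only reads `learning_rate` directly for Prodigy and CAME; for Adam/AdamW, `init_kwargs` carries only the betas, eps, and weight decay, so the learning rate has to be supplied through the parameter groups in `params_to_optimize`. A minimal sketch on a toy module under that assumption, assuming the module is importable as `utils.utils`:

```python
import torch

from utils.utils import get_optimizer

model = torch.nn.Linear(16, 16)

# For adam/adamw the lr comes from the parameter group, not from `learning_rate`
params_to_optimize = [{"params": list(model.parameters()), "lr": 1e-4}]

optimizer = get_optimizer(
    params_to_optimize=params_to_optimize,
    optimizer_name="adamw",
    weight_decay=1e-4,
)
print(type(optimizer).__name__)  # AdamW, unless an 8-bit/4-bit/offload variant was requested
```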