├── .gitignore
├── README.md
├── accelerate_configs
    ├── compiled_1.yaml
    └── deepspeed.yaml
├── assets
    ├── human1.jpg
    ├── human2.jpg
    ├── human3.jpg
    ├── human4.jpg
    ├── human5.jpg
    ├── motion1.mp4
    ├── motion2.mp4
    ├── motion3.mp4
    ├── person1_img.jpg
    └── person2_img.jpg
├── checkpoints
    └── Put pre-trained model here.txt
├── docs
    └── finetune.md
├── dwpose
    ├── __init__.py
    ├── dwpose_detector.py
    ├── onnxdet.py
    ├── onnxpose.py
    ├── preprocess.py
    ├── util.py
    └── wholebody.py
├── dynamictrl
    ├── __init__.py
    ├── models
    │   ├── model_dynamictrl.py
    │   └── modules
    │   │   ├── cross_attention.py
    │   │   ├── embeddings.py
    │   │   └── model_cogvideox_autoencoderKL.py
    └── pipelines
    │   ├── dynamictrl_pipeline.py
    │   └── modules
    │       ├── dynamictrl_output.py
    │       └── utils.py
├── requirements.txt
├── scripts
    ├── dynamictrl_inference.py
    └── dynamictrl_sft_finetune.py
├── tools
    └── qwen.py
└── utils
    ├── args.py
    ├── dataset.py
    ├── dataset_use_mask_entity.py
    ├── freeinit_utils.py
    ├── load_validation_control.py
    ├── pre_process.py
    ├── prepare_dataset.py
    ├── save_utils.py
    ├── text_encoder
        ├── __init__.py
        └── text_encoder.py
    └── utils.py

/.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## DynamiCtrl: Rethinking the Basic Structure and the Role of Text for High-quality Human Image Animation 2 | 3 | [Haoyu Zhao](https://scholar.google.com/citations?user=pCGM7jwAAAAJ&hl=zh-CN&oi=ao/), [Zhongang Qi](https://scholar.google.com/citations?user=zJvrrusAAAAJ&hl=en/), [Cong Wang](#), [Qingqing Zheng](https://scholar.google.com.hk/citations?user=l0Y7emkAAAAJ&hl=zh-CN&oi=ao/), [Guansong Lu](https://scholar.google.com.hk/citations?user=YIt8thUAAAAJ&hl=zh-CN&oi=ao), [Fei Chen](#), [Hang Xu](https://scholar.google.com.hk/citations?user=J_8TX6sAAAAJ&hl=zh-CN&oi=ao) and [Zuxuan 
Wu](https://scholar.google.com.hk/citations?user=7t12hVkAAAAJ&hl=zh-CN&oi=ao) 4 | 5 | 6 | 7 | 8 | [![GitHub](https://img.shields.io/github/stars/gulucaptain/DynamiCtrl?style=social)](https://github.com/gulucaptain/DynamiCtrl) 9 | 10 | 11 | ## 🎏 Introduction 12 | TL;DR: DynamiCtrl is the first framework to introduce the "Joint-text" paradigm to the pose-guided human animation task and to achieve effective pose control within the diffusion transformer (DiT) architecture. 13 | 14 |
CLICK for the full introduction 15 | 16 | 17 | > With diffusion transformer (DiT) excelling in video generation, its use in specific tasks has drawn increasing attention. However, adapting DiT for pose-guided human image animation faces two core challenges: (a) existing U-Net-based pose control methods may be suboptimal for the DiT backbone; and (b) removing text guidance, as in previous approaches, often leads to semantic loss and model degradation. To address these issues, we propose DynamiCtrl, a novel framework for human animation in video DiT architecture. Specifically, we use a shared VAE encoder for human images and driving poses, unifying them into a common latent space, maintaining pose fidelity, and eliminating the need for an expert pose encoder during video denoising. To integrate pose control into the DiT backbone effectively, we propose a novel Pose-adaptive Layer Norm model. It injects normalized pose features into the denoising process via conditioning on visual tokens, enabling seamless and scalable pose control across DiT blocks. Furthermore, to overcome the shortcomings of text removal, we introduce the "Joint-text" paradigm, which preserves the role of text embeddings to provide global semantic context. Through full-attention blocks, image and pose features are aligned with text features, enhancing semantic consistency, leveraging pretrained knowledge, and enabling multi-level control. Experiments verify the superiority of DynamiCtrl on benchmark and self-collected data (e.g., achieving the best LPIPS of 0.166), demonstrating strong character control and high-quality synthesis. 18 |
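The introduction describes the Pose-adaptive Layer Norm only at a high level, so the sketch below is one possible reading, not the repository's implementation (that lives under `dynamictrl/models`). It assumes an adaLN-style rule: each pose token is layer-normalized, modulated by a scale and shift predicted from the matching visual token, and added back to the visual stream. The matrices `w_scale` and `w_shift` are hypothetical stand-ins for learned linear layers, and plain Python lists replace tensors to keep it dependency-free.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def layer_norm(x: Vector, eps: float = 1e-6) -> Vector:
    # Normalize one token to zero mean and unit variance (no learned affine here).
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def matvec(w: Matrix, x: Vector) -> Vector:
    # Plain matrix-vector product standing in for a learned linear layer.
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def pose_adaptive_layer_norm(visual: Vector, pose: Vector, w_scale: Matrix, w_shift: Matrix) -> Vector:
    """Inject a normalized pose token into a visual token (illustrative adaLN-style reading)."""
    pose_hat = layer_norm(pose)        # normalized pose features
    gamma = matvec(w_scale, visual)    # scale predicted from the visual token
    beta = matvec(w_shift, visual)     # shift predicted from the visual token
    injected = [p * (1.0 + g) + b for p, g, b in zip(pose_hat, gamma, beta)]
    # Residual injection: the modulated pose signal is added to the visual stream.
    return [v + i for v, i in zip(visual, injected)]


visual = [0.5, -1.0, 2.0]   # hypothetical visual token
pose = [1.0, 2.0, 3.0]      # hypothetical pose token (same width as the visual token)
zeros = [[0.0] * 3 for _ in range(3)]
out = pose_adaptive_layer_norm(visual, pose, zeros, zeros)
```

With zero `w_scale` and `w_shift` the output is just `visual + layer_norm(pose)`; with non-zero weights the visual token controls how strongly each normalized pose channel is amplified or shifted.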
19 | 20 | ## 📺 Overview on YouTube 21 | 22 | [![DynamiCtrl overview video](https://img.youtube.com/vi/Mu_pNXM4PcE/0.jpg)](https://www.youtube.com/watch?v=Mu_pNXM4PcE) 23 | Click the thumbnail to watch. 24 | 25 | ## ⚔️ DynamiCtrl for High-quality Pose-guided Human Image Animation 26 | 27 | We first revisit the role of text in this task and find that fine-grained textual information helps improve video quality. In particular, different prompts enable fine-grained local control over the generated video. 28 |
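One way to see why prompts can steer local details is that, in the "Joint-text" paradigm described in the introduction, text tokens stay in the same full-attention sequence as the image and pose tokens. The sketch below illustrates only that idea: it concatenates hypothetical text tokens with hypothetical visual tokens and runs a single head of scaled dot-product self-attention with identity query/key/value projections. It is not the DiT block used in this repository; real blocks add learned projections, multiple heads, and normalization.

```python
import math
from typing import List, Tuple

Vector = List[float]


def softmax(scores: Vector) -> Vector:
    # Subtract the max for numerical stability before exponentiating.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def joint_self_attention(text_tokens: List[Vector], visual_tokens: List[Vector]) -> Tuple[List[Vector], List[Vector]]:
    """One attention head over [text; visual] with identity Q/K/V projections (illustrative)."""
    seq = text_tokens + visual_tokens  # joint sequence: text first, then visual tokens
    dim = len(seq[0])
    scale = 1.0 / math.sqrt(dim)
    outputs, weights = [], []
    for q in seq:
        # Full attention: every token attends to every token in the joint sequence.
        row = softmax([scale * sum(a * b for a, b in zip(q, k)) for k in seq])
        weights.append(row)
        outputs.append([sum(w * v[d] for w, v in zip(row, seq)) for d in range(dim)])
    return outputs, weights


text = [[1.0, 0.0], [0.0, 1.0]]                # hypothetical text-embedding tokens
visual = [[0.5, 0.5], [1.0, 1.0], [0.0, 2.0]]  # hypothetical image/pose latent tokens
out, attn = joint_self_attention(text, visual)
```

Every row of `attn` covers all five tokens, so each visual token puts non-zero weight on the two text tokens. Removing text from the sequence, which the introduction argues causes semantic loss, would drop those columns entirely.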
43 |
CLICK to check the prompts used for generation in the above three cases. 44 | 45 | > Prompt (left): “The image depicts a stylized, animated character standing amidst a chaotic and dynamic background. The character is dressed in a blue suit with a red cape, featuring a prominent "S" emblem on the chest. The suit has a belt with pouches and a utility belt. The character has spiky hair and is standing on a pile of debris and rubble, suggesting a scene of destruction or battle. The background is filled with glowing, fiery elements and a sense of motion, adding to the dramatic and intense atmosphere of the scene." 46 | 47 | > Prompt (mid): “The person in the image is a woman with long, blonde hair styled in loose waves. She is wearing a form-fitting, sleeveless top with a high neckline and a small cutout at the chest. The top is beige and has a strap across her chest. She is also wearing a black belt with a pouch attached to it. Around her neck, she has a turquoise pendant necklace. The background appears to be a dimly lit, urban environment with a warm, golden glow." 48 | 49 | > Prompt (right): “The person in the image is wearing a black, form-fitting one-piece outfit and a pair of VR goggles. They are walking down a busy street with numerous people and colorful neon signs in the background. The street appears to be a bustling urban area, possibly in a city known for its vibrant nightlife and entertainment. The lighting and signage suggest a lively atmosphere, typical of a cityscape at night." 50 |
90 | 91 | ### Fine-grained video control 92 | 119 |
CLICK to check the prompts used for generation in the above background-control cases. 120 | 121 | > Scene 1: The person in the image is wearing a white, knee-length dress with short sleeves and a square neckline. The dress features lace detailing and a ruffled hem. The person is also wearing clear, open-toed sandals. The background shows a bustling futuristic city at night, with neon lights reflecting off the wet streets and flying cars zooming above. 122 | 123 | > Scene 2: The person in the image is wearing a white, knee-length dress with short sleeves and a square neckline. The dress features lace detailing and a ruffled hem. The person is also wearing clear, open-toed sandals. The background shows a vibrant market street in a Middle Eastern bazaar, filled with colorful fabrics, exotic spices, and merchants calling out to customers. 124 | 125 | > Scene 3: The person in the image is wearing a white, knee-length dress with short sleeves and a square neckline. The dress features lace detailing and a ruffled hem. The person is also wearing clear, open-toed sandals. The background shows a sunny beach with golden sand, gentle ocean waves rolling onto the shore, and palm trees swaying in the breeze. 126 | 127 | > Scene 4: The person in the image is wearing a white, knee-length dress with short sleeves and a square neckline. The dress features lace detailing and a ruffled hem. The person is also wearing clear, open-toed sandals. The background shows a high-tech research lab with sleek metallic walls, glowing holographic screens, and robotic arms assembling futuristic devices. 128 | 129 | > Scene 5: The person in the image is wearing a white, knee-length dress with short sleeves and a square neckline. The dress features lace detailing and a ruffled hem. The person is also wearing clear, open-toed sandals. The background shows a mystical ancient temple hidden deep in the jungle, covered in vines, with glowing runes carved into the stone walls. 
130 | 131 | > Scene 6: The person in the image is wearing a white, knee-length dress with short sleeves and a square neckline. The dress features lace detailing and a ruffled hem. The person is also wearing clear, open-toed sandals. The background shows a serene snowy forest with tall pine trees, soft snowflakes falling gently, and a frozen river winding through the landscape. 132 | 133 | > Scene 7: The person in the image is wearing a white, knee-length dress with short sleeves and a square neckline. The dress features lace detailing and a ruffled hem. The person is also wearing clear, open-toed sandals. The background shows an abandoned industrial warehouse with broken windows, scattered debris, and rusted machinery covered in dust. 134 |
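The seven scene prompts above repeat one appearance description word for word and differ only in the final "The background shows ..." sentence. A small helper like the one below (hypothetical, not part of the repository) keeps the appearance text fixed while swapping backgrounds, which makes background-only edits easy to reproduce.

```python
from typing import Dict

# Appearance text copied from the Scene 1-7 prompts; it stays fixed across scenes.
APPEARANCE = (
    "The person in the image is wearing a white, knee-length dress with short sleeves "
    "and a square neckline. The dress features lace detailing and a ruffled hem. "
    "The person is also wearing clear, open-toed sandals."
)


def compose_prompt(appearance: str, background: str) -> str:
    """Join a fixed appearance description with a per-scene background clause."""
    return f"{appearance.strip()} The background shows {background.strip().rstrip('.')}."


def scene_prompts(appearance: str, backgrounds: Dict[str, str]) -> Dict[str, str]:
    # One prompt per named scene; only the background clause changes.
    return {name: compose_prompt(appearance, bg) for name, bg in backgrounds.items()}


prompts = scene_prompts(APPEARANCE, {
    "scene3": "a sunny beach with golden sand, gentle ocean waves rolling onto the shore, "
              "and palm trees swaying in the breeze",
    "scene6": "a serene snowy forest with tall pine trees, soft snowflakes falling gently, "
              "and a frozen river winding through the landscape",
})
```

`prompts["scene3"]` reproduces the Scene 3 prompt exactly, so new scenes only need a new background clause.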
161 | 162 | 163 | 164 | ## 🚧 Todo 165 | 166 |
Click for previous todos 167 | 168 | - [✔] Release the project page and demos. 169 | - [✔] Release the paper on arXiv (27 Mar 2025). 170 |
171 | 172 | - [✔] Release inference code. 173 | - [✔] Release models. 174 | - [✔] Release training code. 175 | 176 | ## 📋 Changelog 177 | - 2025.05.20 Code and models released! 178 | - 2025.03.30 Project page and demos released! 179 | - 2025.03.10 Project online! 180 | 181 | 182 | ## Installation 183 | 184 | To use this repository (SFT fine-tuning and inference), install the dependencies with: 185 | 186 | ```bash 187 | conda create --name dynamictrl python=3.10 188 | 189 | source activate dynamictrl 190 | 191 | pip install -r requirements.txt 192 | ``` 193 | 194 | 195 | ## Model Zoo 196 | 197 | We provide three groups of checkpoints: 198 |
1. DynamiCtrl-5B: trained on whole person images (no mask) and the corresponding driving pose sequences.
2. Dynamictrl-5B-Mask_B01: trained on person images with the background masked, together with the pose sequences.
3. Dynamictrl-5B-Mask_C01: trained on person images with the clothing masked, together with the pose sequences.
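The three checkpoints listed above can also be fetched from Python with `huggingface_hub.snapshot_download` instead of `huggingface-cli`. This is a sketch, not a script shipped with the repository: the repo ids come from the Model Zoo table, the local folder names mirror the CLI commands in this README, and `download_plan`/`download_all` are hypothetical helper names. The import is deferred so the plan can be inspected without the package installed.

```python
from pathlib import Path
from typing import List, Tuple

# Local folder name -> Hugging Face repo id, taken from the Model Zoo table.
CHECKPOINTS = {
    "DynamiCtrl": "gulucaptain/DynamiCtrl",
    "Dynamictrl-Mask_B01": "gulucaptain/Dynamictrl-Mask_B01",
    "Dynamictrl-Mask_C01": "gulucaptain/Dynamictrl-Mask_C01",
}


def download_plan(root: str = "checkpoints") -> List[Tuple[str, str]]:
    """Return (repo_id, local_dir) pairs without touching the network."""
    return [(repo_id, str(Path(root) / name)) for name, repo_id in CHECKPOINTS.items()]


def download_all(root: str = "checkpoints") -> None:
    # Imported lazily so download_plan() works even without huggingface_hub installed.
    from huggingface_hub import snapshot_download

    for repo_id, local_dir in download_plan(root):
        snapshot_download(repo_id=repo_id, local_dir=local_dir)
```

Calling `download_all()` from the repository root writes to `./checkpoints/DynamiCtrl`, `./checkpoints/Dynamictrl-Mask_B01`, and `./checkpoints/Dynamictrl-Mask_C01`.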
203 | 204 | | Name | Details | HF weights 🤗 | 205 | |:---|:---:|:---:| 206 | | DynamiCtrl-5B | SFT w/ whole image | [dynamictrl-5B](https://huggingface.co/gulucaptain/DynamiCtrl) | 207 | | Dynamictrl-5B-Mask_B01 | SFT w/ masked background | [dynamictrl-5B-mask-B01](https://huggingface.co/gulucaptain/Dynamictrl-Mask_B01) | 208 | | Dynamictrl-5B-Mask_C01 | SFT w/ masked human clothing | [dynamictrl-5B-mask-C01](https://huggingface.co/gulucaptain/Dynamictrl-Mask_C01) | 209 | 210 | 211 | We use the [Causal VAE](https://arxiv.org/abs/2408.06072) as our VAE model and [T5](https://arxiv.org/abs/1910.10683) as our text encoder. 212 | 213 | ```bash 214 | cd checkpoints 215 | 216 | pip install -U huggingface_hub 217 | 218 | huggingface-cli download --resume-download --local-dir-use-symlinks False gulucaptain/DynamiCtrl --local-dir ./DynamiCtrl 219 | 220 | huggingface-cli download --resume-download --local-dir-use-symlinks False gulucaptain/Dynamictrl-Mask_B01 --local-dir ./Dynamictrl-Mask_B01 221 | 222 | huggingface-cli download --resume-download --local-dir-use-symlinks False gulucaptain/Dynamictrl-Mask_C01 --local-dir ./Dynamictrl-Mask_C01 223 | ``` 224 | 225 | Download the checkpoints of [DWPose](https://github.com/IDEA-Research/DWPose) for human pose estimation: 226 | 227 | ```bash 228 | cd checkpoints 229 | 230 | git clone https://huggingface.co/yzd-v/DWPose 231 | 232 | # Then update the model paths on lines 15 and 16 of ./dwpose/wholebody.py. 
233 | ``` 234 | 235 | ## 👍 Quick Start 236 | 237 | ### Direct Inference w/ Driving Video 238 | 239 | ```bash 240 | image="./assets/human1.jpg" 241 | video="./assets/motion1.mp4" 242 | 243 | model_path="./checkpoints/DynamiCtrl" 244 | output="./outputs" 245 | 246 | CUDA_VISIBLE_DEVICES=0 python scripts/dynamictrl_inference.py \ 247 | --prompt="Input the test prompt here."
\ 248 | --reference_image_path=$image \ 249 | --ori_driving_video=$video \ 250 | --model_path=$model_path \ 251 | --output_path=$output \ 252 | --num_inference_steps=25 \ 253 | --width=768 \ 254 | --height=1360 \ 255 | --num_frames=37 \ 256 | --pose_control_function="padaln" \ 257 | --guidance_scale=3.0 \ 258 | --seed=42 259 | ``` 260 | 261 | Tips: When using the DynamiCtrl model trained without masked areas, make sure the prompt matches the provided human image, describing both the person's appearance and the background. 262 | 263 | You can write the prompt yourself, or use Qwen2-VL to write a prompt that describes the image content automatically; see this blog post: [How to use Qwen2-VL](https://blog.csdn.net/zxs0222/article/details/144698753?spm=1001.2014.3001.5501). 264 | 265 | ### Inference w/ Masked Human Image 266 | 267 | Thanks to the proposed "Joint-text" paradigm, we can achieve fine-grained control over the generated human video, including the background and clothing areas. Usage is simple: provide a human image with the target areas blacked out and run the inference script directly. Remember to update the model path. To obtain the mask area automatically, you can follow this blog post: [How to get mask of subject](https://blog.csdn.net/zxs0222/article/details/147604020?spm=1001.2014.3001.5501). 268 | 269 | Note: replace the "transformer" folder in ./checkpoints/DynamiCtrl with the downloaded "Dynamictrl-Mask_B01" or "Dynamictrl-Mask_C01" folder. 270 | 271 | ```bash 272 | image="./assets/masked_human1.jpg" # Required: a human image with masked areas 273 | video="./assets/motion1.mp4" 274 | 275 | model_path="./checkpoints/DynamiCtrl" # transformer/ replaced with Dynamictrl-Mask_B01 or Dynamictrl-Mask_C01 276 | output="./outputs" 277 | 278 | CUDA_VISIBLE_DEVICES=0 python scripts/dynamictrl_inference.py \ 279 | --prompt="Input the test prompt here."
\ 280 | --reference_image_path=$image \ 281 | --ori_driving_video=$video \ 282 | --model_path=$model_path \ 283 | --output_path=$output \ 284 | --num_inference_steps=25 \ 285 | --width=768 \ 286 | --height=1360 \ 287 | --num_frames=37 \ 288 | --pose_control_function="padaln" \ 289 | --guidance_scale=3.0 \ 290 | --seed=42 291 | ``` 292 | 293 | Tips: Although the "Dynamictrl-5B-Mask_B01" and "Dynamictrl-5B-Mask_C01" models are trained with masked human images, you can still test them directly on whole (unmasked) human images. In some cases they even outperform the base "DynamiCtrl-5B" model. 294 | 295 | #### Memory and time cost 296 | 297 | | Device | Frames | Resolution (H * W) | Time | GPU memory | 298 | |:---|:---:|:---:|:---:|:---:| 299 | | H20 | 37 | 1360 * 768 | 3 min 50s | 28.4 GB | 300 | | H20 | 37 | 1024 * 576 | 1 min 40s | 24.7 GB | 301 | | H20 | 37 | 1360 * 1360 | 9 min 28s | 34.8 GB | 302 | | H20 | 37 | 1024 * 1024 | 3 min 50s | 28.4 GB | 303 | 304 | ### Training 305 | 306 | Please find the instructions on data preparation and training [here](./docs/finetune.md). 307 | 308 | ## 🔅 More Applications 309 | 310 | ### Digital Human (including long-video performance) 311 | 312 | Showcases: 12-second long videos driven by the same audio. 313 |
324 | 325 | The identities of the digital humans are generated by vivo's BlueLM model (text-to-image generation). 326 | 327 | Two steps to generate a digital human: 328 | 329 | 1. Prepare a human image and a driving pose video, and generate the video with DynamiCtrl. 330 | 331 | 2. Feed the output video and an audio file to [MuseTalk](https://github.com/TMElyralab/MuseTalk) to generate lip movements that match the audio. 332 | 333 | 334 | 335 | ## 📍 Citation 336 | 337 | If you find this repository helpful, please consider citing: 338 | 339 | ``` 340 | @article{zhao2025dynamictrl, 341 | title={DynamiCtrl: Rethinking the Basic Structure and the Role of Text for High-quality Human Image Animation}, 342 | author={Zhao, Haoyu and Qi, Zhongang and Wang, Cong and Zheng, Qingping and Lu, Guansong and Chen, Fei and Xu, Hang and Wu, Zuxuan}, 343 | year={2025}, 344 | journal={arXiv:2503.21246}, 345 | } 346 | ``` 347 | 348 | ## 💗 Acknowledgements 349 | 350 | This repository borrows heavily from [CogVideoX](https://github.com/THUDM/CogVideo). Thanks to the authors for sharing their code and models. 351 | 352 | ## 🧿 Maintenance 353 | 354 | This is the codebase for our research work. We are still actively updating this repo, and more details are coming soon.
-------------------------------------------------------------------------------- /accelerate_configs/compiled_1.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: 'NO' 4 | downcast_bf16: 'no' 5 | dynamo_config: 6 | dynamo_backend: INDUCTOR 7 | dynamo_mode: max-autotune 8 | dynamo_use_dynamic: true 9 | dynamo_use_fullgraph: false 10 | enable_cpu_affinity: false 11 | gpu_ids: '3' 12 | machine_rank: 0 13 | main_training_function: main 14 | mixed_precision: fp16 15 | num_machines: 1 16 | num_processes: 1 17 | rdzv_backend: static 18 | same_network: true 19 | tpu_env: [] 20 | tpu_use_cluster: false 21 | tpu_use_sudo: false 22 | use_cpu: false 23 | -------------------------------------------------------------------------------- /accelerate_configs/deepspeed.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | gradient_accumulation_steps: 1 5 | gradient_clipping: 1.0 6 | offload_optimizer_device: cpu 7 | offload_param_device: cpu 8 | zero3_init_flag: false 9 | zero_stage: 2 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: 'no' 12 | enable_cpu_affinity: false 13 | machine_rank: 0 14 | main_training_function: main 15 | mixed_precision: bf16 16 | num_machines: 1 17 | num_processes: 8 18 | rdzv_backend: static 19 | same_network: true 20 | tpu_env: [] 21 | tpu_use_cluster: false 22 | tpu_use_sudo: false 23 | use_cpu: false 24 | -------------------------------------------------------------------------------- /assets/human1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gulucaptain/DynamiCtrl/15eb5e9069f9a99c3894c59ad4b4125a1da12e82/assets/human1.jpg -------------------------------------------------------------------------------- /assets/human2.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/gulucaptain/DynamiCtrl/15eb5e9069f9a99c3894c59ad4b4125a1da12e82/assets/human2.jpg -------------------------------------------------------------------------------- /assets/human3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gulucaptain/DynamiCtrl/15eb5e9069f9a99c3894c59ad4b4125a1da12e82/assets/human3.jpg -------------------------------------------------------------------------------- /assets/human4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gulucaptain/DynamiCtrl/15eb5e9069f9a99c3894c59ad4b4125a1da12e82/assets/human4.jpg -------------------------------------------------------------------------------- /assets/human5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gulucaptain/DynamiCtrl/15eb5e9069f9a99c3894c59ad4b4125a1da12e82/assets/human5.jpg -------------------------------------------------------------------------------- /assets/motion1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gulucaptain/DynamiCtrl/15eb5e9069f9a99c3894c59ad4b4125a1da12e82/assets/motion1.mp4 -------------------------------------------------------------------------------- /assets/motion2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gulucaptain/DynamiCtrl/15eb5e9069f9a99c3894c59ad4b4125a1da12e82/assets/motion2.mp4 -------------------------------------------------------------------------------- /assets/motion3.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gulucaptain/DynamiCtrl/15eb5e9069f9a99c3894c59ad4b4125a1da12e82/assets/motion3.mp4 
-------------------------------------------------------------------------------- /assets/person1_img.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gulucaptain/DynamiCtrl/15eb5e9069f9a99c3894c59ad4b4125a1da12e82/assets/person1_img.jpg -------------------------------------------------------------------------------- /assets/person2_img.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gulucaptain/DynamiCtrl/15eb5e9069f9a99c3894c59ad4b4125a1da12e82/assets/person2_img.jpg -------------------------------------------------------------------------------- /checkpoints/Put pre-trained model here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gulucaptain/DynamiCtrl/15eb5e9069f9a99c3894c59ad4b4125a1da12e82/checkpoints/Put pre-trained model here.txt -------------------------------------------------------------------------------- /docs/finetune.md: -------------------------------------------------------------------------------- 1 | ## Training 2 | 3 | For model training and fine-tuning, you can directly use the "dynamictrl" environment. 4 | 5 | ### Data Preparation 6 | 7 | #### Folder structure: 8 | ``` 9 | Training_data/ 10 | │ 11 | ├── videos/ 12 | │ ├── 0000001.mp4 13 | │ └── 0000002.mp4 14 | │ 15 | ├── poses/ 16 | │ ├── 0000001.mp4 17 | │ └── 0000002.mp4 18 | │ 19 | ├── prompts.txt 20 | │ 21 | └── video.txt 22 | ``` 23 | 24 | Tips: Use a pose estimation method such as DWPose to extract the human poses. To obtain the prompts for the training videos, we select one frame from each video and use Qwen2-VL to describe the image content, including the person's appearance and the background details.
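Before a long run it can be worth checking that every clip has a caption and a pose video. The sketch below assumes the formats given in the data-format sections of this guide: `video.txt` lists one absolute video path per line, `prompts.txt` has one `<file name>#####<caption>` line per clip, and each pose video sits under `poses/` with the same file name as its source video (an assumption based on the folder structure). The braces and `...` in the examples are treated as illustrative and skipped. `load_training_index` is a hypothetical helper; the repository's own loaders in `utils/dataset.py` may parse these files differently.

```python
import tempfile
from pathlib import Path
from typing import Dict, List, Tuple

SEPARATOR = "#####"           # separator used in prompts.txt
SKIP = {"", "{", "}", "..."}  # assumption: braces and "..." in the examples are illustrative only


def read_lines(path: Path) -> List[str]:
    """Return stripped, non-placeholder lines of a text file."""
    lines = (ln.strip() for ln in path.read_text(encoding="utf-8").splitlines())
    return [ln for ln in lines if ln not in SKIP]


def load_training_index(data_root: str) -> List[Tuple[Path, Path, str]]:
    """Pair each video in video.txt with its pose video and caption (hypothetical helper)."""
    root = Path(data_root)
    captions: Dict[str, str] = {}
    for line in read_lines(root / "prompts.txt"):
        name, caption = line.split(SEPARATOR, 1)
        captions[name.strip()] = caption.strip()

    samples = []
    for entry in read_lines(root / "video.txt"):
        video_path = Path(entry)
        pose_path = root / "poses" / video_path.name  # assumption: same file name under poses/
        if video_path.name not in captions:
            raise KeyError(f"no caption for {video_path.name}")
        if not pose_path.exists():
            raise FileNotFoundError(pose_path)
        samples.append((video_path, pose_path, captions[video_path.name]))
    return samples


# Toy run: a two-clip dataset laid out like the folder structure above, in a temp folder.
names = ["0000001.mp4", "0000002.mp4"]
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for sub in ("videos", "poses"):
        (root / sub).mkdir()
        for n in names:
            (root / sub / n).touch()
    (root / "video.txt").write_text(
        "{\n" + "\n".join(str(root / "videos" / n) for n in names) + "\n}\n", encoding="utf-8"
    )
    (root / "prompts.txt").write_text(
        "0000001.mp4#####A person walks.\n0000002.mp4#####A person dances.\n", encoding="utf-8"
    )
    samples = load_training_index(tmp)
    pose_names = [pose.name for _, pose, _ in samples]
```

A missing caption or pose file raises immediately, which is cheaper than discovering it inside the dataloader several hours into training.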
25 | 26 | #### Data format of video.txt 27 | ``` 28 | { 29 | /home/user/data/Training_data/videos/0000001.mp4 30 | /home/user/data/Training_data/videos/0000002.mp4 31 | ... 32 | } 33 | ``` 34 | 35 | #### Data format of prompts.txt 36 | ``` 37 | { 38 | 0000001.mp4#####The video describes xxx. The human xxx. The background xxx. 39 | 0000002.mp4#####The video describes xxx. The human xxx. The background xxx. 40 | ... 41 | } 42 | ``` 43 | 44 | ### Fine-tuning 45 | 46 | ```bash 47 | export TORCH_LOGS="+dynamo,recompiles,graph_breaks" 48 | export TORCHDYNAMO_VERBOSE=1 49 | export WANDB_MODE="offline" 50 | export NCCL_P2P_DISABLE=1 51 | export TORCH_NCCL_ENABLE_MONITORING=0 52 | export TOKENIZERS_PARALLELISM=true 53 | export OMP_NUM_THREADS=16 54 | export DS_SKIP_CUDA_CHECK=1 55 | 56 | GPU_IDS="0,1,2,3,4,5,6,7" 57 | 58 | # Training Configurations 59 | # Experiment with as many hyperparameters as you want! 60 | LEARNING_RATES=("5e-6") 61 | LR_SCHEDULES=("cosine_with_restarts") 62 | OPTIMIZERS=("adamw") 63 | MAX_TRAIN_STEPS=("50000") 64 | POSE_CONTROL_FUNCTION="padaln" 65 | 66 | # Multi-GPU training with DeepSpeed ZeRO stage 2 (8 processes, see accelerate_configs/deepspeed.yaml) 67 | ACCELERATE_CONFIG_FILE="accelerate_configs/deepspeed.yaml" 68 | 69 | # Absolute path to the training data, prepared as described in Data Preparation above. 70 | # This example assumes the folder structure shown above: 71 | DATA_ROOT="/home/user/data/Training_data" 72 | CAPTION_COLUMN="prompts.txt" 73 | VIDEO_COLUMN="video.txt" 74 | MODEL_PATH="./checkpoints/DynamiCtrl" 75 | 76 | # Add ` --load_tensors ` to load precomputed tensors from disk instead of re-running the encoders.
77 | # Launch experiments with different hyperparameters 78 | 79 | TRAINING_NAME="training_name" 80 | 81 | for learning_rate in "${LEARNING_RATES[@]}"; do 82 | for lr_schedule in "${LR_SCHEDULES[@]}"; do 83 | for optimizer in "${OPTIMIZERS[@]}"; do 84 | for steps in "${MAX_TRAIN_STEPS[@]}"; do 85 | output_dir="/home/user/exps/train/dynamictrl-sft_${TRAINING_NAME}_${optimizer}_steps_${steps}_lr-schedule_${lr_schedule}_learning-rate_${learning_rate}/" 86 | 87 | cmd="accelerate launch --config_file $ACCELERATE_CONFIG_FILE \ 88 | --gpu_ids $GPU_IDS --main_process_port 10086 \ 89 | scripts/dynamictrl_sft_finetune.py \ 90 | --pretrained_model_name_or_path $MODEL_PATH \ 91 | --data_root $DATA_ROOT \ 92 | --caption_column $CAPTION_COLUMN \ 93 | --video_column $VIDEO_COLUMN \ 94 | --id_token BW_STYLE \ 95 | --height_buckets 1360 \ 96 | --width_buckets 768 \ 97 | --height 1360 \ 98 | --width 768 \ 99 | --frame_buckets 37 \ 100 | --dataloader_num_workers 1 \ 101 | --pin_memory \ 102 | --enable_control_pose \ 103 | --pose_control_function $POSE_CONTROL_FUNCTION \ 104 | --validation_prompt \"input the test prompt here.\" \ 105 | --validation_images \"/home/user/data/dynamictrl_train/test.jpg\" \ 106 | --validation_driving_videos \"/home/user/data/dynamictrl_train/test_driving_video.mp4\" \ 107 | --validation_prompt_separator ::: \ 108 | --num_validation_videos 1 \ 109 | --validation_epochs 100000 \ 110 | --validation_steps 100000 \ 111 | --seed 42 \ 112 | --mixed_precision bf16 \ 113 | --output_dir $output_dir \ 114 | --max_num_frames 37 \ 115 | --train_batch_size 64 \ 116 | --max_train_steps $steps \ 117 | --checkpointing_steps 1000 \ 118 | --gradient_checkpointing \ 119 | --gradient_accumulation_steps 4 \ 120 | --learning_rate $learning_rate \ 121 | --lr_scheduler $lr_schedule \ 122 | --lr_warmup_steps 1000 \ 123 | --lr_num_cycles 1 \ 124 | --enable_slicing \ 125 | --enable_tiling \ 126 | --noised_image_dropout 0.05 \ 127 | --optimizer $optimizer \ 128 | 
--beta1 0.9 \ 129 |
--beta2 0.95 \ 130 | --weight_decay 0.001 \ 131 | --max_grad_norm 1.0 \ 132 | --allow_tf32 \ 133 | --report_to tensorboard \ 134 | --nccl_timeout 1800" 135 | 136 | echo "Running command: $cmd" 137 | eval $cmd 138 | echo -ne "-------------------- Finished executing script --------------------\n\n" 139 | done 140 | done 141 | done 142 | done 143 | 144 | ``` -------------------------------------------------------------------------------- /dwpose/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gulucaptain/DynamiCtrl/15eb5e9069f9a99c3894c59ad4b4125a1da12e82/dwpose/__init__.py -------------------------------------------------------------------------------- /dwpose/dwpose_detector.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from .wholebody import Wholebody 7 | 8 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" 9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 10 | 11 | class DWposeDetector: 12 | """ 13 | A pose detection method for image-like data.
14 | 15 | Parameters: 16 | model_det: (str) serialized ONNX format model path, 17 | such as https://huggingface.co/yzd-v/DWPose/blob/main/yolox_l.onnx 18 | model_pose: (str) serialized ONNX format model path, 19 | such as https://huggingface.co/yzd-v/DWPose/blob/main/dw-ll_ucoco_384.onnx 20 | device: (str) 'cpu' or 'cuda:{device_id}' 21 | """ 22 | def __init__(self, model_det, model_pose, device='cpu'): 23 | self.args = model_det, model_pose, device 24 | 25 | def release_memory(self): 26 | if hasattr(self, 'pose_estimation'): 27 | del self.pose_estimation 28 | import gc; gc.collect() 29 | 30 | def __call__(self, oriImg): 31 | if not hasattr(self, 'pose_estimation'): 32 | self.pose_estimation = Wholebody(*self.args) 33 | 34 | oriImg = oriImg.copy() 35 | H, W, C = oriImg.shape 36 | with torch.no_grad(): 37 | candidate, score = self.pose_estimation(oriImg) 38 | nums, _, locs = candidate.shape 39 | candidate[..., 0] /= float(W) 40 | candidate[..., 1] /= float(H) 41 | body = candidate[:, :18].copy() 42 | body = body.reshape(nums * 18, locs) 43 | subset = score[:, :18].copy() 44 | for i in range(len(subset)): 45 | for j in range(len(subset[i])): 46 | if subset[i][j] > 0.3: 47 | subset[i][j] = int(18 * i + j) 48 | else: 49 | subset[i][j] = -1 50 | 51 | # un_visible = subset < 0.3 52 | # candidate[un_visible] = -1 53 | 54 | # foot = candidate[:, 18:24] 55 | 56 | faces = candidate[:, 24:92] 57 | 58 | hands = candidate[:, 92:113] 59 | hands = np.vstack([hands, candidate[:, 113:]]) 60 | 61 | faces_score = score[:, 24:92] 62 | hands_score = np.vstack([score[:, 92:113], score[:, 113:]]) 63 | 64 | bodies = dict(candidate=body, subset=subset, score=score[:, :18]) 65 | pose = dict(bodies=bodies, hands=hands, hands_score=hands_score, faces=faces, faces_score=faces_score) 66 | 67 | return pose 68 | 69 | dwpose_detector = DWposeDetector( 70 | model_det="pretrained/DWPose/yolox_l.onnx", 71 | model_pose="pretrained/DWPose/dw-ll_ucoco_384.onnx", 72 | device=device 73 | ) 74 | 
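`DWposeDetector.__call__` returns a dict whose arrays are cut from the keypoint array by fixed index ranges, with coordinates normalized by the image width and height. The sketch below restates those slices and the `subset` visibility rule in plain Python for reference when reading the `pose` dict. The total of 134 keypoints is an assumption (133 COCO-WholeBody points plus an inserted neck point), `hand_0`/`hand_1` are neutral labels for the two hand slices, and the helper names are hypothetical.

```python
from typing import Dict, List

TOTAL = 134  # assumption: 133 COCO-WholeBody keypoints plus the inserted neck point

# Index ranges copied from DWposeDetector.__call__.
REGIONS = {
    "body": slice(0, 18),
    "foot": slice(18, 24),   # commented out in the detector, listed for completeness
    "face": slice(24, 92),
    "hand_0": slice(92, 113),
    "hand_1": slice(113, TOTAL),
}


def region_sizes() -> Dict[str, int]:
    # len(range(TOTAL)[s]) counts how many keypoint indices each slice selects.
    return {name: len(range(TOTAL)[s]) for name, s in REGIONS.items()}


def body_subset(scores: List[List[float]], thr: float = 0.3, joints: int = 18) -> List[List[int]]:
    """Same rule as the subset loop: visible joint -> flat index joints * i + j, else -1."""
    return [
        [joints * i + j if s > thr else -1 for j, s in enumerate(row[:joints])]
        for i, row in enumerate(scores)
    ]
```

As in the detector, a joint counts as visible only when its score is strictly greater than 0.3, and the flat index points into the stacked `body` array of shape `(nums * 18, locs)`. The detector writes these indices into a copy of the float score array rather than an integer list.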
-------------------------------------------------------------------------------- /dwpose/onnxdet.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def nms(boxes, scores, nms_thr): 6 | """Single class NMS implemented in Numpy. 7 | 8 | Args: 9 | boxes (np.ndarray): shape=(N,4); N is number of boxes 10 | scores (np.ndarray): the score of bboxes 11 | nms_thr (float): the threshold in NMS 12 | 13 | Returns: 14 | List[int]: output bbox ids 15 | """ 16 | x1 = boxes[:, 0] 17 | y1 = boxes[:, 1] 18 | x2 = boxes[:, 2] 19 | y2 = boxes[:, 3] 20 | 21 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 22 | order = scores.argsort()[::-1] 23 | 24 | keep = [] 25 | while order.size > 0: 26 | i = order[0] 27 | keep.append(i) 28 | xx1 = np.maximum(x1[i], x1[order[1:]]) 29 | yy1 = np.maximum(y1[i], y1[order[1:]]) 30 | xx2 = np.minimum(x2[i], x2[order[1:]]) 31 | yy2 = np.minimum(y2[i], y2[order[1:]]) 32 | 33 | w = np.maximum(0.0, xx2 - xx1 + 1) 34 | h = np.maximum(0.0, yy2 - yy1 + 1) 35 | inter = w * h 36 | ovr = inter / (areas[i] + areas[order[1:]] - inter) 37 | 38 | inds = np.where(ovr <= nms_thr)[0] 39 | order = order[inds + 1] 40 | 41 | return keep 42 | 43 | def multiclass_nms(boxes, scores, nms_thr, score_thr): 44 | """Multiclass NMS implemented in Numpy. Class-aware version. 
45 | 46 | Args: 47 | boxes (np.ndarray): shape=(N,4); N is number of boxes 48 | scores (np.ndarray): the score of bboxes 49 | nms_thr (float): the threshold in NMS 50 | score_thr (float): the threshold of cls score 51 | 52 | Returns: 53 | np.ndarray: outputs bboxes coordinate 54 | """ 55 | final_dets = [] 56 | num_classes = scores.shape[1] 57 | for cls_ind in range(num_classes): 58 | cls_scores = scores[:, cls_ind] 59 | valid_score_mask = cls_scores > score_thr 60 | if valid_score_mask.sum() == 0: 61 | continue 62 | else: 63 | valid_scores = cls_scores[valid_score_mask] 64 | valid_boxes = boxes[valid_score_mask] 65 | keep = nms(valid_boxes, valid_scores, nms_thr) 66 | if len(keep) > 0: 67 | cls_inds = np.ones((len(keep), 1)) * cls_ind 68 | dets = np.concatenate( 69 | [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1 70 | ) 71 | final_dets.append(dets) 72 | if len(final_dets) == 0: 73 | return None 74 | return np.concatenate(final_dets, 0) 75 | 76 | def demo_postprocess(outputs, img_size, p6=False): 77 | grids = [] 78 | expanded_strides = [] 79 | strides = [8, 16, 32] if not p6 else [8, 16, 32, 64] 80 | 81 | hsizes = [img_size[0] // stride for stride in strides] 82 | wsizes = [img_size[1] // stride for stride in strides] 83 | 84 | for hsize, wsize, stride in zip(hsizes, wsizes, strides): 85 | xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize)) 86 | grid = np.stack((xv, yv), 2).reshape(1, -1, 2) 87 | grids.append(grid) 88 | shape = grid.shape[:2] 89 | expanded_strides.append(np.full((*shape, 1), stride)) 90 | 91 | grids = np.concatenate(grids, 1) 92 | expanded_strides = np.concatenate(expanded_strides, 1) 93 | outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides 94 | outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides 95 | 96 | return outputs 97 | 98 | def preprocess(img, input_size, swap=(2, 0, 1)): 99 | if len(img.shape) == 3: 100 | padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114 101 | else: 102 | 
padded_img = np.ones(input_size, dtype=np.uint8) * 114 103 | 104 | r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) 105 | resized_img = cv2.resize( 106 | img, 107 | (int(img.shape[1] * r), int(img.shape[0] * r)), 108 | interpolation=cv2.INTER_LINEAR, 109 | ).astype(np.uint8) 110 | padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img 111 | 112 | padded_img = padded_img.transpose(swap) 113 | padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) 114 | return padded_img, r 115 | 116 | def inference_detector(session, oriImg): 117 | """Run YOLOX person detection and return (x0, y0, x1, y1) person boxes in original-image coordinates. 118 | """ 119 | input_shape = (640,640) 120 | img, ratio = preprocess(oriImg, input_shape) 121 | 122 | ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]} 123 | output = session.run(None, ort_inputs) 124 | predictions = demo_postprocess(output[0], input_shape)[0] 125 | 126 | boxes = predictions[:, :4] 127 | scores = predictions[:, 4:5] * predictions[:, 5:] 128 | 129 | boxes_xyxy = np.ones_like(boxes) 130 | boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2. 131 | boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2. 132 | boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2. 133 | boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2.
 134 | boxes_xyxy /= ratio 135 | dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1) 136 | if dets is not None: 137 | final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5] 138 | isscore = final_scores>0.3 139 | iscat = final_cls_inds == 0 140 | isbbox = [ i and j for (i, j) in zip(isscore, iscat)] 141 | final_boxes = final_boxes[isbbox] 142 | else: 143 | final_boxes = np.array([]) 144 | 145 | return final_boxes 146 | -------------------------------------------------------------------------------- /dwpose/onnxpose.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import cv2 4 | import numpy as np 5 | import onnxruntime as ort 6 | 7 | def preprocess( 8 | img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256) 9 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 10 | """Do preprocessing for RTMPose model inference. 11 | 12 | Args: 13 | img (np.ndarray): Input image in shape (h, w, 3). 14 | input_size (tuple): Model input size (w, h). 15 | out_bbox (list): Person boxes as (x0, y0, x1, y1); the whole image is used when empty. 16 | Returns: 17 | tuple: 18 | - resized_img (list[np.ndarray]): Normalized crop for each bbox. 19 | - center (list[np.ndarray]): Center (x, y) of each bbox. 20 | - scale (list[np.ndarray]): Scale (w, h) of each bbox after the aspect-ratio fix.
 21 | """ 22 | # get shape of image 23 | img_shape = img.shape[:2] 24 | out_img, out_center, out_scale = [], [], [] 25 | if len(out_bbox) == 0: 26 | out_bbox = [[0, 0, img_shape[1], img_shape[0]]] 27 | for i in range(len(out_bbox)): 28 | x0 = out_bbox[i][0] 29 | y0 = out_bbox[i][1] 30 | x1 = out_bbox[i][2] 31 | y1 = out_bbox[i][3] 32 | bbox = np.array([x0, y0, x1, y1]) 33 | 34 | # get center and scale 35 | center, scale = bbox_xyxy2cs(bbox, padding=1.25) 36 | 37 | # do affine transformation 38 | resized_img, scale = top_down_affine(input_size, scale, center, img) 39 | 40 | # normalize image 41 | mean = np.array([123.675, 116.28, 103.53]) 42 | std = np.array([58.395, 57.12, 57.375]) 43 | resized_img = (resized_img - mean) / std 44 | 45 | out_img.append(resized_img) 46 | out_center.append(center) 47 | out_scale.append(scale) 48 | 49 | return out_img, out_center, out_scale 50 | 51 | 52 | def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray: 53 | """Run the RTMPose model on each preprocessed crop. 54 | 55 | Args: 56 | sess (ort.InferenceSession): ONNXRuntime session. 57 | img (list[np.ndarray]): Preprocessed crops, each in shape (h, w, 3). 58 | 59 | Returns: 60 | all_out (list): Model outputs (simcc_x, simcc_y) for each crop. 61 | """ 62 | all_out = [] 63 | # build input 64 | for i in range(len(img)): 65 | input = [img[i].transpose(2, 0, 1)] 66 | 67 | # build output 68 | sess_input = {sess.get_inputs()[0].name: input} 69 | sess_output = [] 70 | for out in sess.get_outputs(): 71 | sess_output.append(out.name) 72 | 73 | # run model 74 | outputs = sess.run(sess_output, sess_input) 75 | all_out.append(outputs) 76 | 77 | return all_out 78 | 79 | 80 | def postprocess(outputs: List[np.ndarray], 81 | model_input_size: Tuple[int, int], 82 | center: Tuple[int, int], 83 | scale: Tuple[int, int], 84 | simcc_split_ratio: float = 2.0 85 | ) -> Tuple[np.ndarray, np.ndarray]: 86 | """Postprocess for RTMPose model output. 87 | 88 | Args: 89 | outputs (list): Model outputs (simcc_x, simcc_y) for each crop.
90 | model_input_size (tuple): RTMPose model input size (w, h). 91 | center (list): Center (x, y) of each bbox. 92 | scale (list): Scale (w, h) of each bbox. 93 | simcc_split_ratio (float): Split ratio of simcc. 94 | 95 | Returns: 96 | tuple: 97 | - keypoints (np.ndarray): Keypoints in original-image coordinates, shape (n, K, 2). 98 | - scores (np.ndarray): Predicted keypoint scores, shape (n, K). 99 | """ 100 | all_key = [] 101 | all_score = [] 102 | for i in range(len(outputs)): 103 | # use simcc to decode 104 | simcc_x, simcc_y = outputs[i] 105 | keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio) 106 | 107 | # rescale keypoints 108 | keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2 109 | all_key.append(keypoints[0]) 110 | all_score.append(scores[0]) 111 | 112 | return np.array(all_key), np.array(all_score) 113 | 114 | 115 | def bbox_xyxy2cs(bbox: np.ndarray, 116 | padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]: 117 | """Transform the bbox format from (x1, y1, x2, y2) into (center, scale) 118 | 119 | Args: 120 | bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted 121 | as (left, top, right, bottom) 122 | padding (float): BBox padding factor that will be multiplied by the scale. 123 | Default: 1.0 124 | 125 | Returns: 126 | tuple: A tuple containing center and scale.
127 | - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or 128 | (n, 2) 129 | - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or 130 | (n, 2) 131 | """ 132 | # convert single bbox from (4, ) to (1, 4) 133 | dim = bbox.ndim 134 | if dim == 1: 135 | bbox = bbox[None, :] 136 | 137 | # get bbox center and scale 138 | x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3]) 139 | center = np.hstack([x1 + x2, y1 + y2]) * 0.5 140 | scale = np.hstack([x2 - x1, y2 - y1]) * padding 141 | 142 | if dim == 1: 143 | center = center[0] 144 | scale = scale[0] 145 | 146 | return center, scale 147 | 148 | 149 | def _fix_aspect_ratio(bbox_scale: np.ndarray, 150 | aspect_ratio: float) -> np.ndarray: 151 | """Extend the scale to match the given aspect ratio. 152 | 153 | Args: 154 | bbox_scale (np.ndarray): The bbox scale (w, h) in shape (2, ) 155 | aspect_ratio (float): The ratio of ``w/h`` 156 | 157 | Returns: 158 | np.ndarray: The bbox scale enlarged to the given aspect ratio, in shape (2, ) 159 | """ 160 | w, h = np.hsplit(bbox_scale, [1]) 161 | bbox_scale = np.where(w > h * aspect_ratio, 162 | np.hstack([w, w / aspect_ratio]), 163 | np.hstack([h * aspect_ratio, h])) 164 | return bbox_scale 165 | 166 | 167 | def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray: 168 | """Rotate a point by an angle. 169 | 170 | Args: 171 | pt (np.ndarray): 2D point coordinates (x, y) in shape (2, ) 172 | angle_rad (float): rotation angle in radians 173 | 174 | Returns: 175 | np.ndarray: Rotated point in shape (2, ) 176 | """ 177 | sn, cs = np.sin(angle_rad), np.cos(angle_rad) 178 | rot_mat = np.array([[cs, -sn], [sn, cs]]) 179 | return rot_mat @ pt 180 | 181 | 182 | def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray: 183 | """To calculate the affine matrix, three pairs of points are required. This 184 | function is used to get the 3rd point, given 2D points a & b. 185 | 186 | The 3rd point is defined by rotating vector `a - b` by 90 degrees 187 | anticlockwise, using b as the rotation center.
188 | 189 | Args: 190 | a (np.ndarray): The 1st point (x,y) in shape (2, ) 191 | b (np.ndarray): The 2nd point (x,y) in shape (2, ) 192 | 193 | Returns: 194 | np.ndarray: The 3rd point. 195 | """ 196 | direction = a - b 197 | c = b + np.r_[-direction[1], direction[0]] 198 | return c 199 | 200 | 201 | def get_warp_matrix(center: np.ndarray, 202 | scale: np.ndarray, 203 | rot: float, 204 | output_size: Tuple[int, int], 205 | shift: Tuple[float, float] = (0., 0.), 206 | inv: bool = False) -> np.ndarray: 207 | """Calculate the affine transformation matrix that can warp the bbox area 208 | in the input image to the output size. 209 | 210 | Args: 211 | center (np.ndarray[2, ]): Center of the bounding box (x, y). 212 | scale (np.ndarray[2, ]): Scale of the bounding box 213 | wrt [width, height]. 214 | rot (float): Rotation angle (degree). 215 | output_size (np.ndarray[2, ] | list(2,)): Size of the 216 | destination heatmaps. 217 | shift (0-100%): Shift translation ratio wrt the width/height. 218 | Default (0., 0.). 219 | inv (bool): Option to inverse the affine transform direction. 
220 | (inv=False: src->dst or inv=True: dst->src) 221 | 222 | Returns: 223 | np.ndarray: A 2x3 transformation matrix 224 | """ 225 | shift = np.array(shift) 226 | src_w = scale[0] 227 | dst_w = output_size[0] 228 | dst_h = output_size[1] 229 | 230 | # compute transformation matrix 231 | rot_rad = np.deg2rad(rot) 232 | src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad) 233 | dst_dir = np.array([0., dst_w * -0.5]) 234 | 235 | # three reference points in the original image: bbox center, a point src_w/2 above it, and a third point at a right angle 236 | src = np.zeros((3, 2), dtype=np.float32) 237 | src[0, :] = center + scale * shift 238 | src[1, :] = center + src_dir + scale * shift 239 | src[2, :] = _get_3rd_point(src[0, :], src[1, :]) 240 | 241 | # the corresponding three points in the output (model input) image 242 | dst = np.zeros((3, 2), dtype=np.float32) 243 | dst[0, :] = [dst_w * 0.5, dst_h * 0.5] 244 | dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir 245 | dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :]) 246 | 247 | if inv: 248 | warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src)) 249 | else: 250 | warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst)) 251 | 252 | return warp_mat 253 | 254 | 255 | def top_down_affine(input_size: Tuple[int, int], bbox_scale: np.ndarray, bbox_center: np.ndarray, 256 | img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 257 | """Get the bbox image as the model input by affine transform. 258 | 259 | Args: 260 | input_size (tuple): The model input size (w, h). 261 | bbox_scale (np.ndarray): The bbox scale (w, h). 262 | bbox_center (np.ndarray): The bbox center (x, y). 263 | img (np.ndarray): The original image. 264 | 265 | Returns: 266 | tuple: A tuple containing the warped image and the bbox scale. 267 | - np.ndarray: img after affine transform. 268 | - np.ndarray: bbox scale after the aspect-ratio fix.
269 | """ 270 | w, h = input_size 271 | warp_size = (int(w), int(h)) 272 | 273 | # reshape bbox to fixed aspect ratio 274 | bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h) 275 | 276 | # get the affine matrix 277 | center = bbox_center 278 | scale = bbox_scale 279 | rot = 0 280 | warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h)) 281 | 282 | # do affine transform 283 | img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR) 284 | 285 | return img, bbox_scale 286 | 287 | 288 | def get_simcc_maximum(simcc_x: np.ndarray, 289 | simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 290 | """Get maximum response location and value from simcc representations. 291 | 292 | Note: 293 | instance number: N 294 | num_keypoints: K 295 | heatmap height: H 296 | heatmap width: W 297 | 298 | Args: 299 | simcc_x (np.ndarray): x-axis SimCC in shape (N, K, Wx) 300 | simcc_y (np.ndarray): y-axis SimCC in shape (N, K, Wy) 301 | 302 | Returns: 303 | tuple: 304 | - locs (np.ndarray): locations of maximum heatmap responses in shape 305 | (N, K, 2) 306 | - vals (np.ndarray): values of maximum heatmap responses in shape 307 | (N, K) 308 | """ 309 | N, K, Wx = simcc_x.shape 310 | simcc_x = simcc_x.reshape(N * K, -1) 311 | simcc_y = simcc_y.reshape(N * K, -1) 312 | 313 | # get maximum value locations 314 | x_locs = np.argmax(simcc_x, axis=1) 315 | y_locs = np.argmax(simcc_y, axis=1) 316 | locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32) 317 | max_val_x = np.amax(simcc_x, axis=1) 318 | max_val_y = np.amax(simcc_y, axis=1) 319 | 320 | # keypoint score = the smaller of the x-axis and y-axis peak values 321 | mask = max_val_x > max_val_y 322 | max_val_x[mask] = max_val_y[mask] 323 | vals = max_val_x 324 | locs[vals <= 0.]
= -1 325 | 326 | # reshape 327 | locs = locs.reshape(N, K, 2) 328 | vals = vals.reshape(N, K) 329 | 330 | return locs, vals 331 | 332 | 333 | def decode(simcc_x: np.ndarray, simcc_y: np.ndarray, 334 | simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]: 335 | """Decode SimCC outputs into keypoints and scores (argmax along each axis, then divide by the split ratio). 336 | 337 | Args: 338 | simcc_x (np.ndarray[N, K, Wx]): model predicted simcc in x. 339 | simcc_y (np.ndarray[N, K, Wy]): model predicted simcc in y. 340 | simcc_split_ratio (float): The split ratio of simcc. 341 | 342 | Returns: 343 | tuple: A tuple containing keypoints and scores. 344 | - np.ndarray[float32]: keypoints in shape (N, K, 2) 345 | - np.ndarray[float32]: scores in shape (N, K) 346 | """ 347 | keypoints, scores = get_simcc_maximum(simcc_x, simcc_y) 348 | keypoints /= simcc_split_ratio 349 | 350 | return keypoints, scores 351 | 352 | 353 | def inference_pose(session, out_bbox, oriImg): 354 | """Run pose estimation on each detected person bbox. 355 | 356 | Args: 357 | session (ort.InferenceSession): ONNXRuntime session. 358 | out_bbox (np.ndarray): person bboxes (x0, y0, x1, y1); the whole image is used if empty 359 | oriImg (np.ndarray): Input image in shape (H, W, 3). 360 | 361 | Returns: 362 | tuple: 363 | - keypoints (np.ndarray): Keypoints in original-image coordinates, shape (n, K, 2). 364 | - scores (np.ndarray): Keypoint scores, shape (n, K). 365 | """ 366 | h, w = session.get_inputs()[0].shape[2:] 367 | model_input_size = (w, h) 368 | # preprocess for rtm-pose model inference. 369 | resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size) 370 | # run pose estimation for processed img 371 | outputs = inference(session, resized_img) 372 | # postprocess for rtm-pose model output.
keypoints, scores = postprocess(outputs, model_input_size, center, scale) 373 | 374 | return keypoints, scores 375 | -------------------------------------------------------------------------------- /dwpose/preprocess.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import decord 3 | import numpy as np 4 | 5 | from .util import draw_pose 6 | from .dwpose_detector import dwpose_detector as dwprocessor 7 | 8 | 9 | def get_video_pose( 10 | video_path: str, 11 | ref_image: np.ndarray, 12 | max_frame_num = None, 13 | sample_stride: int=1): 14 | """Detect driving-video poses and linearly rescale them to match the pose in the reference image. 15 | 16 | Args: 17 | video_path (str): path to the driving video 18 | ref_image (np.ndarray): reference image in shape (H, W, 3) 19 | sample_stride (int, optional): take every sample_stride-th frame. Defaults to 1. 20 | max_frame_num (int, optional): maximum number of frames to sample; None uses the whole video. 21 | Returns: 22 | np.ndarray: rendered pose images, shape (T, 3, H, W), at the reference image size 23 | """ 24 | # select ref-keypoint from reference pose for pose rescale 25 | ref_pose = dwprocessor(ref_image) 26 | ref_keypoint_id = [0, 1, 2, 5, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17] 27 | ref_keypoint_id = [i for i in ref_keypoint_id \ 28 | if len(ref_pose['bodies']['subset']) > 0 and ref_pose['bodies']['subset'][0][i] >= .0] 29 | ref_body = ref_pose['bodies']['candidate'][ref_keypoint_id] 30 | 31 | height, width, _ = ref_image.shape 32 | 33 | # read input video 34 | vr = decord.VideoReader(video_path, ctx=decord.cpu(0)) 35 | if max_frame_num is None: 36 | max_frame_num = len(vr) 37 | else: 38 | max_frame_num = min(len(vr), max_frame_num * sample_stride) 39 | 40 | sample_stride = sample_stride 41 | begin_frame_index = 0 42 | frames = vr.get_batch(list(range(begin_frame_index, begin_frame_index + max_frame_num, sample_stride))).asnumpy() # max_frame_num is already an exclusive end index here, so the stride is not applied twice 43 | detected_poses = [dwprocessor(frm) for frm in tqdm(frames, desc="DWPose")] 44 | 45 | dwprocessor.release_memory() 46 | detected_bodies = np.stack([p['bodies']['candidate'] for p in detected_poses if p['bodies']['candidate'].shape[0] == 18])[:,
ref_keypoint_id] 47 | 48 | keep_person_body_unchange = False 49 | if keep_person_body_unchange: 50 | output_pose = [] 51 | for i in range(len(detected_poses)): 52 | ay, by = np.polyfit(detected_bodies[i, :, 1].flatten(), np.tile(ref_body[:, 1], 1), 1) 53 | fh, fw, _ = vr[0].shape 54 | ax = ay / (fh / fw / height * width) 55 | bx = np.mean(np.tile(ref_body[:, 0], 1) - detected_bodies[i, :, 0].flatten() * ax) 56 | a = np.array([ax, ay]) 57 | b = np.array([bx, by]) 58 | 59 | detected_poses[i]['bodies']['candidate'] = detected_poses[i]['bodies']['candidate'] * a + b 60 | detected_poses[i]['faces'] = detected_poses[i]['faces'] * a + b 61 | detected_poses[i]['hands'] = detected_poses[i]['hands'] * a + b 62 | im = draw_pose(detected_poses[i], height, width) 63 | output_pose.append(np.array(im)) 64 | return np.stack(output_pose) 65 | 66 | # compute linear-rescale params 67 | ay, by = np.polyfit(detected_bodies[:, :, 1].flatten(), np.tile(ref_body[:, 1], len(detected_bodies)), 1) 68 | if ref_body[:, 1][0] > detected_bodies[0, :, 1][0]: 69 | by = ref_body[:, 1][0] - detected_bodies[0, :, 1][0] 70 | elif ref_body[:, 1][0] < detected_bodies[0, :, 1][0]: 71 | by = -(detected_bodies[0, :, 1][0] - ref_body[:, 1][0]) 72 | 73 | fh, fw, _ = vr[0].shape 74 | ax = ay / (fh / fw / height * width) 75 | bx = np.mean(np.tile(ref_body[:, 0], len(detected_bodies)) - detected_bodies[:, :, 0].flatten() * ax) 76 | a = np.array([ax, ay]) 77 | b = np.array([bx, by]) 78 | 79 | re_ensure_by_candidate = detected_poses[0]['bodies']['candidate'] * a + b 80 | if ref_body[:, 1][0] > re_ensure_by_candidate[0][1]: 81 | by = by + ref_body[:, 1][0] - re_ensure_by_candidate[0][1] 82 | elif ref_body[:, 1][0] < re_ensure_by_candidate[0][1]: 83 | by = by - (re_ensure_by_candidate[0][1] - ref_body[:, 1][0]) 84 | b = np.array([bx, by]) 85 | 86 | output_pose = [] 87 | detected_pose_index = 0 88 | for detected_pose in detected_poses: 89 | detected_pose['bodies']['candidate'] = detected_pose['bodies']['candidate'] 
* a + b 90 | detected_pose['faces'] = detected_pose['faces'] * a + b 91 | detected_pose['hands'] = detected_pose['hands'] * a + b 92 | im = draw_pose(detected_pose, height, width) 93 | output_pose.append(np.array(im)) 94 | detected_pose_index += 1 95 | return np.stack(output_pose) 96 | 97 | 98 | def get_image_pose(ref_image): 99 | """process image pose 100 | 101 | Args: 102 | ref_image (np.ndarray): reference image pixel value 103 | 104 | Returns: 105 | np.ndarray: pose visual image in RGB-mode 106 | """ 107 | height, width, _ = ref_image.shape 108 | ref_pose = dwprocessor(ref_image) 109 | pose_img = draw_pose(ref_pose, height, width) 110 | return np.array(pose_img) 111 | -------------------------------------------------------------------------------- /dwpose/util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import matplotlib 4 | import cv2 5 | 6 | 7 | eps = 0.01 8 | 9 | def alpha_blend_color(color, alpha): 10 | """blend color according to point conf 11 | """ 12 | return [int(c * alpha) for c in color] 13 | 14 | def draw_bodypose(canvas, candidate, subset, score): 15 | H, W, C = canvas.shape 16 | candidate = np.array(candidate) 17 | subset = np.array(subset) 18 | 19 | stickwidth = 4 20 | 21 | limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ 22 | [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ 23 | [1, 16], [16, 18], [3, 17], [6, 18]] 24 | 25 | colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ 26 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ 27 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] 28 | 29 | for i in range(17): 30 | for n in range(len(subset)): 31 | index = subset[n][np.array(limbSeq[i]) - 1] 32 | conf = score[n][np.array(limbSeq[i]) - 1] 33 | if conf[0] < 0.3 or conf[1] < 0.3: 34 | continue 35 | 
Y = candidate[index.astype(int), 0] * float(W) 36 | X = candidate[index.astype(int), 1] * float(H) 37 | mX = np.mean(X) 38 | mY = np.mean(Y) 39 | length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 40 | angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) 41 | polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) 42 | cv2.fillConvexPoly(canvas, polygon, alpha_blend_color(colors[i], conf[0] * conf[1])) 43 | 44 | canvas = (canvas * 0.6).astype(np.uint8) 45 | 46 | for i in range(18): 47 | for n in range(len(subset)): 48 | index = int(subset[n][i]) 49 | if index == -1: 50 | continue 51 | x, y = candidate[index][0:2] 52 | conf = score[n][i] 53 | x = int(x * W) 54 | y = int(y * H) 55 | cv2.circle(canvas, (int(x), int(y)), 4, alpha_blend_color(colors[i], conf), thickness=-1) 56 | 57 | return canvas 58 | 59 | def draw_handpose(canvas, all_hand_peaks, all_hand_scores): 60 | H, W, C = canvas.shape 61 | 62 | edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \ 63 | [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]] 64 | 65 | for peaks, scores in zip(all_hand_peaks, all_hand_scores): 66 | 67 | for ie, e in enumerate(edges): 68 | x1, y1 = peaks[e[0]] 69 | x2, y2 = peaks[e[1]] 70 | x1 = int(x1 * W) 71 | y1 = int(y1 * H) 72 | x2 = int(x2 * W) 73 | y2 = int(y2 * H) 74 | score = int(scores[e[0]] * scores[e[1]] * 255) 75 | if x1 > eps and y1 > eps and x2 > eps and y2 > eps: 76 | cv2.line(canvas, (x1, y1), (x2, y2), 77 | matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * score, thickness=2) 78 | 79 | for i, keyponit in enumerate(peaks): 80 | x, y = keyponit 81 | x = int(x * W) 82 | y = int(y * H) 83 | score = int(scores[i] * 255) 84 | if x > eps and y > eps: 85 | cv2.circle(canvas, (x, y), 4, (0, 0, score), thickness=-1) 86 | return canvas 87 | 88 | def draw_facepose(canvas, all_lmks, all_scores): 89 | H, W, C = canvas.shape 
90 | for lmks, scores in zip(all_lmks, all_scores): 91 | for lmk, score in zip(lmks, scores): 92 | x, y = lmk 93 | x = int(x * W) 94 | y = int(y * H) 95 | conf = int(score * 255) 96 | if x > eps and y > eps: 97 | cv2.circle(canvas, (x, y), 3, (conf, conf, conf), thickness=-1) 98 | return canvas 99 | 100 | def draw_pose(pose, H, W, ref_w=2160): 101 | """Render DWPose outputs (body, hands, face) onto a black canvas. 102 | 103 | Args: 104 | pose (dict): DWposeDetector outputs in dwpose_detector.py 105 | H (int): height 106 | W (int): width 107 | ref_w (int, optional): shorter-side size of the internal drawing canvas before it is resized back to (W, H). Defaults to 2160. 108 | 109 | Returns: 110 | np.ndarray: RGB pose image in shape (3, H, W) 111 | """ 112 | bodies = pose['bodies'] 113 | faces = pose['faces'] 114 | hands = pose['hands'] 115 | candidate = bodies['candidate'] 116 | subset = bodies['subset'] 117 | 118 | sz = min(H, W) 119 | sr = (ref_w / sz) if sz != ref_w else 1 120 | 121 | ########################################## create zero canvas ################################################## 122 | canvas = np.zeros(shape=(int(H*sr), int(W*sr), 3), dtype=np.uint8) 123 | 124 | ########################################### draw body pose ##################################################### 125 | canvas = draw_bodypose(canvas, candidate, subset, score=bodies['score']) 126 | 127 | ########################################### draw hand pose ##################################################### 128 | canvas = draw_handpose(canvas, hands, pose['hands_score']) 129 | 130 | ########################################### draw face pose ##################################################### 131 | canvas = draw_facepose(canvas, faces, pose['faces_score']) 132 | 133 | return cv2.cvtColor(cv2.resize(canvas, (W, H)), cv2.COLOR_BGR2RGB).transpose(2, 0, 1) 134 | -------------------------------------------------------------------------------- /dwpose/wholebody.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import onnxruntime as ort 3 | 4 | from .onnxdet
import inference_detector 5 | from .onnxpose import inference_pose 6 | 7 | 8 | class Wholebody: 9 | """Whole-body pose detector: YOLOX person detection followed by DWPose keypoint estimation. 10 | """ 11 | def __init__(self, model_det, model_pose, device="cpu"): 12 | providers = ['CPUExecutionProvider'] if device == 'cpu' else ['CUDAExecutionProvider'] 13 | provider_options = None if device == 'cpu' else [{'device_id': 0}] 14 | 15 | model_det = "../checkpoints/DWPose/yolox_l.onnx" # NOTE: these hard-coded paths override the model_det / model_pose arguments and are resolved relative to the current working directory 16 | model_pose = "../checkpoints/DWPose/dw-ll_ucoco_384.onnx" 17 | 18 | self.session_det = ort.InferenceSession( 19 | path_or_bytes=model_det, providers=providers, provider_options=provider_options 20 | ) 21 | self.session_pose = ort.InferenceSession( 22 | path_or_bytes=model_pose, providers=providers, provider_options=provider_options 23 | ) 24 | 25 | def __call__(self, oriImg): 26 | """Detect whole-body keypoints in oriImg. 27 | 28 | Args: 29 | oriImg (np.ndarray): input image in shape (H, W, 3) 30 | Returns: keypoints (n, K+1, 2) and scores (n, K+1), with a neck joint added and body joints reordered to OpenPose order. 31 | """ 32 | det_result = inference_detector(self.session_det, oriImg) 33 | keypoints, scores = inference_pose(self.session_pose, det_result, oriImg) 34 | 35 | keypoints_info = np.concatenate( 36 | (keypoints, scores[..., None]), axis=-1) 37 | # compute neck joint 38 | neck = np.mean(keypoints_info[:, [5, 6]], axis=1) 39 | # neck score when visualizing pred 40 | neck[:, 2:4] = np.logical_and( 41 | keypoints_info[:, 5, 2:4] > 0.3, 42 | keypoints_info[:, 6, 2:4] > 0.3).astype(int) 43 | new_keypoints_info = np.insert( 44 | keypoints_info, 17, neck, axis=1) 45 | mmpose_idx = [ 46 | 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3 47 | ] 48 | openpose_idx = [ 49 | 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17 50 | ] 51 | new_keypoints_info[:, openpose_idx] = \ 52 | new_keypoints_info[:, mmpose_idx] 53 | keypoints_info = new_keypoints_info 54 | 55 | keypoints, scores = keypoints_info[ 56 | ..., :2], keypoints_info[..., 2] 57 | 58 | return keypoints, scores 59 | -------------------------------------------------------------------------------- /dynamictrl/__init__.py:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/gulucaptain/DynamiCtrl/15eb5e9069f9a99c3894c59ad4b4125a1da12e82/dynamictrl/__init__.py -------------------------------------------------------------------------------- /dynamictrl/models/modules/cross_attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Tuple, Union, Optional, Any 3 | 4 | import torch 5 | import torch.utils 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.utils.checkpoint as checkpoint 9 | import einops 10 | from einops import repeat 11 | 12 | from diffusers import AutoencoderKL 13 | from timm.models.vision_transformer import Mlp 14 | from timm.models.layers import to_2tuple 15 | from transformers import ( 16 | AutoTokenizer, 17 | MT5EncoderModel, 18 | BertModel, 19 | ) 20 | 21 | memory_efficient_attention = None 22 | try: 23 | import xformers 24 | except Exception: 25 | pass 26 | 27 | try: 28 | from xformers.ops import memory_efficient_attention 29 | except Exception: 30 | memory_efficient_attention = None 31 | 32 | def reshape_for_broadcast( 33 | freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], 34 | x: torch.Tensor, 35 | head_first=False, 36 | ): 37 | """ 38 | Reshape frequency tensor for broadcasting it with another tensor. 39 | 40 | This function reshapes the frequency tensor into a shape that broadcasts against the target tensor 'x' 41 | for the purpose of broadcasting the frequency tensor during element-wise operations. 42 | 43 | Args: 44 | freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped. 45 | x (torch.Tensor): Target tensor for broadcasting compatibility. 46 | head_first (bool): head dimension first (except batch dim) or not. 47 | 48 | Returns: 49 | torch.Tensor | Tuple[torch.Tensor, torch.Tensor]: Reshaped frequency tensor, or reshaped (cos, sin) pair. 50 | 51 | Raises: 52 | AssertionError: If the frequency tensor doesn't match the expected shape.
53 | AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions. 54 | """ 55 | ndim = x.ndim 56 | assert 0 <= 1 < ndim 57 | 58 | if isinstance(freqs_cis, tuple): 59 | # freqs_cis: (cos, sin) in real space 60 | if head_first: 61 | assert freqs_cis[0].shape == ( 62 | x.shape[-2], 63 | x.shape[-1], 64 | ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}" 65 | shape = [ 66 | d if i == ndim - 2 or i == ndim - 1 else 1 67 | for i, d in enumerate(x.shape) 68 | ] 69 | else: 70 | assert freqs_cis[0].shape == ( 71 | x.shape[1], 72 | x.shape[-1], 73 | ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}" 74 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] 75 | return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape) 76 | else: 77 | # freqs_cis: values in complex space 78 | if head_first: 79 | assert freqs_cis.shape == ( 80 | x.shape[-2], 81 | x.shape[-1], 82 | ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}" 83 | shape = [ 84 | d if i == ndim - 2 or i == ndim - 1 else 1 85 | for i, d in enumerate(x.shape) 86 | ] 87 | else: 88 | assert freqs_cis.shape == ( 89 | x.shape[1], 90 | x.shape[-1], 91 | ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}" 92 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] 93 | return freqs_cis.view(*shape) 94 | 95 | MEMORY_LAYOUTS = { 96 | "torch": ( 97 | lambda x, head_dim: x.transpose(1, 2), 98 | lambda x: x.transpose(1, 2), 99 | lambda x: (1, x, 1, 1), 100 | ), 101 | "xformers": ( 102 | lambda x, head_dim: x, 103 | lambda x: x, 104 | lambda x: (1, 1, x, 1), 105 | ), 106 | "math": ( 107 | lambda x, head_dim: x.transpose(1, 2), 108 | lambda x: x.transpose(1, 2), 109 | lambda x: (1, x, 1, 1), 110 | ), 111 | } 112 | 113 | def rotate_half(x): 114 | x_real, x_imag = ( 115 | x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) 116 | ) # [B, S, H, D//2] 117 | return 
torch.stack([-x_imag, x_real], dim=-1).flatten(3) 118 | 119 | def apply_rotary_emb( 120 | xq: torch.Tensor, 121 | xk: Optional[torch.Tensor], 122 | freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], 123 | head_first: bool = False, 124 | ) -> Tuple[torch.Tensor, torch.Tensor]: 125 | """ 126 | Apply rotary embeddings to input tensors using the given frequency tensor. 127 | 128 | This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided 129 | frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor 130 | is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are 131 | returned as real tensors. 132 | 133 | Args: 134 | xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D] 135 | xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D] 136 | freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Precomputed frequency tensor for complex exponentials. 137 | head_first (bool): head dimension first (except batch dim) or not. 138 | 139 | Returns: 140 | Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 
141 | 142 | """ 143 | xk_out = None 144 | if isinstance(freqs_cis, tuple): 145 | cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D] 146 | cos, sin = cos.to(xq.device), sin.to(xq.device) 147 | xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq) 148 | if xk is not None: 149 | xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk) 150 | else: 151 | xq_ = torch.view_as_complex( 152 | xq.float().reshape(*xq.shape[:-1], -1, 2) 153 | ) # [B, S, H, D//2] 154 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to( 155 | xq.device 156 | ) # [S, D//2] --> [1, S, 1, D//2] 157 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq) 158 | if xk is not None: 159 | xk_ = torch.view_as_complex( 160 | xk.float().reshape(*xk.shape[:-1], -1, 2) 161 | ) # [B, S, H, D//2] 162 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk) 163 | 164 | return xq_out, xk_out 165 | 166 | def vanilla_attention(q, k, v, mask, dropout_p, scale=None): 167 | if scale is None: 168 | scale = q.size(-1) ** -0.5 # multiplicative scale, same convention as xformers' memory_efficient_attention 169 | scores = torch.bmm(q, k.transpose(-1, -2)) * scale 170 | if mask is not None: 171 | mask = einops.rearrange(mask, "b ...
-> b (...)") 172 | max_neg_value = -torch.finfo(scores.dtype).max 173 | mask = einops.repeat(mask, "b j -> (b h) j", h=q.size(-3)) 174 | scores = scores.masked_fill(~mask, max_neg_value) 175 | p_attn = F.softmax(scores, dim=-1) 176 | if dropout_p != 0: 177 | p_attn = F.dropout(p_attn, p=dropout_p, training=True)  # apply dropout to the attention probabilities 178 | return torch.bmm(p_attn, v) 179 | 180 | def attention(q, k, v, head_dim, dropout_p=0, mask=None, scale=None, mode="xformers"): 181 | """ 182 | q, k, v: [B, L, H, D] 183 | """ 184 | pre_attn_layout = MEMORY_LAYOUTS[mode][0] 185 | post_attn_layout = MEMORY_LAYOUTS[mode][1] 186 | q = pre_attn_layout(q, head_dim) 187 | k = pre_attn_layout(k, head_dim) 188 | v = pre_attn_layout(v, head_dim) 189 | 190 | # scores = ATTN_FUNCTION[mode](q, k.to(q), v.to(q), mask, scale=scale) 191 | if mode == "torch": 192 | assert scale is None 193 | scores = F.scaled_dot_product_attention( 194 | q, k.to(q), v.to(q), mask, dropout_p 195 | ) # , scale=scale) 196 | elif mode == "xformers": 197 | scores = memory_efficient_attention( 198 | q, k.to(q), v.to(q), mask, dropout_p, scale=scale 199 | ) 200 | else: 201 | scores = vanilla_attention(q, k.to(q), v.to(q), mask, dropout_p, scale=scale) 202 | 203 | scores = post_attn_layout(scores) 204 | return scores 205 | 206 | 207 | class CrossAttention(nn.Module): 208 | """ 209 | Cross-attention (queries from x, keys/values from y) with optional QK normalization.
210 | """ 211 | 212 | def __init__( 213 | self, 214 | qdim, 215 | kdim, 216 | num_heads, 217 | qkv_bias=True, 218 | qk_norm=False, 219 | attn_drop=0.0, 220 | proj_drop=0.0, 221 | device=None, 222 | dtype=None, 223 | norm_layer=nn.LayerNorm, 224 | attn_mode="xformers", 225 | ): 226 | factory_kwargs = {"device": device, "dtype": dtype} 227 | super().__init__() 228 | self.qdim = qdim 229 | self.kdim = kdim 230 | self.num_heads = num_heads 231 | assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads" 232 | self.head_dim = self.qdim // num_heads 233 | assert ( 234 | self.head_dim % 8 == 0 and self.head_dim <= 192 235 | ), "head_dim must be divisible by 8 and <= 192" 236 | 237 | self.scale = self.head_dim**-0.5 238 | 239 | self.q_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs) 240 | self.kv_proj = nn.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs) 241 | 242 | # TODO: eps should be 1 / 65530 if using fp16 243 | self.q_norm = ( 244 | norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) 245 | if qk_norm 246 | else nn.Identity() 247 | ) 248 | self.k_norm = ( 249 | norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) 250 | if qk_norm 251 | else nn.Identity() 252 | ) 253 | 254 | self.out_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs) 255 | self.proj_drop = nn.Dropout(proj_drop) 256 | self.attn_drop = attn_drop 257 | self.attn_mode = attn_mode 258 | 259 | def set_attn_mode(self, mode): 260 | self.attn_mode = mode 261 | 262 | def forward(self, x, y, freqs_cis_img=None): 263 | """ 264 | Parameters 265 | ---------- 266 | x: torch.Tensor 267 | (batch, seqlen1, hidden_dim) (where hidden_dim = num_heads * head_dim) 268 | y: torch.Tensor 269 | (batch, seqlen2, hidden_dim2) 270 | freqs_cis_img: torch.Tensor or Tuple[torch.Tensor, torch.Tensor] 271 | RoPE for the image queries: a complex tensor of shape (seqlen1, head_dim // 2), or a (cos, sin) tuple of shape (seqlen1, head_dim) each 272 | """ 273 | b, s1, _ = x.shape 274 | _, s2, _ = y.shape 275 | 276 | q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim) 277 | kv =
self.kv_proj(y).view( 278 | b, s2, 2, self.num_heads, self.head_dim 279 | ) 280 | k, v = kv.unbind(dim=2) 281 | q = self.q_norm(q).to(q) 282 | k = self.k_norm(k).to(k) 283 | 284 | # Apply RoPE if needed 285 | if freqs_cis_img is not None: 286 | qq, _ = apply_rotary_emb(q, None, freqs_cis_img) 287 | assert qq.shape == q.shape, f"qq: {qq.shape}, q: {q.shape}" 288 | q = qq 289 | context = attention(q, k, v, self.head_dim, self.attn_drop, mode=self.attn_mode) 290 | context = context.reshape(b, s1, -1) 291 | 292 | out = self.out_proj(context) 293 | out = self.proj_drop(out) 294 | 295 | out_tuple = (out,) 296 | 297 | return out_tuple 298 | -------------------------------------------------------------------------------- /dynamictrl/pipelines/modules/dynamictrl_output.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | 5 | from diffusers.utils import BaseOutput 6 | 7 | 8 | @dataclass 9 | class DynamiCtrlPipelineOutput(BaseOutput): 10 | r""" 11 | Output class for DynamiCtrl pipelines. 12 | 13 | Args: 14 | frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): 15 | List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing 16 | denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape 17 | `(batch_size, num_frames, channels, height, width)`. 
18 | """ 19 | 20 | frames: torch.Tensor 21 | -------------------------------------------------------------------------------- /dynamictrl/pipelines/modules/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List, Optional, Union 3 | import inspect 4 | 5 | # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid 6 | def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): 7 | tw = tgt_width 8 | th = tgt_height 9 | h, w = src 10 | r = h / w 11 | if r > (th / tw): 12 | resize_height = th 13 | resize_width = int(round(th / h * w)) 14 | else: 15 | resize_width = tw 16 | resize_height = int(round(tw / w * h)) 17 | 18 | crop_top = int(round((th - resize_height) / 2.0)) 19 | crop_left = int(round((tw - resize_width) / 2.0)) 20 | 21 | return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) 22 | 23 | 24 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps 25 | def retrieve_timesteps( 26 | scheduler, 27 | num_inference_steps: Optional[int] = None, 28 | device: Optional[Union[str, torch.device]] = None, 29 | timesteps: Optional[List[int]] = None, 30 | sigmas: Optional[List[float]] = None, 31 | **kwargs, 32 | ): 33 | r""" 34 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles 35 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 36 | 37 | Args: 38 | scheduler (`SchedulerMixin`): 39 | The scheduler to get timesteps from. 40 | num_inference_steps (`int`): 41 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` 42 | must be `None`. 43 | device (`str` or `torch.device`, *optional*): 44 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
45 | timesteps (`List[int]`, *optional*): 46 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, 47 | `num_inference_steps` and `sigmas` must be `None`. 48 | sigmas (`List[float]`, *optional*): 49 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, 50 | `num_inference_steps` and `timesteps` must be `None`. 51 | 52 | Returns: 53 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the 54 | second element is the number of inference steps. 55 | """ 56 | if timesteps is not None and sigmas is not None: 57 | raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") 58 | if timesteps is not None: 59 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 60 | if not accepts_timesteps: 61 | raise ValueError( 62 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 63 | f" timestep schedules. Please check whether you are using the correct scheduler." 64 | ) 65 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) 66 | timesteps = scheduler.timesteps 67 | num_inference_steps = len(timesteps) 68 | elif sigmas is not None: 69 | accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 70 | if not accept_sigmas: 71 | raise ValueError( 72 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 73 | f" sigmas schedules. Please check whether you are using the correct scheduler." 
74 | ) 75 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) 76 | timesteps = scheduler.timesteps 77 | num_inference_steps = len(timesteps) 78 | else: 79 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) 80 | timesteps = scheduler.timesteps 81 | return timesteps, num_inference_steps 82 | 83 | 84 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents 85 | def retrieve_latents( 86 | encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" 87 | ): 88 | if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": 89 | return encoder_output.latent_dist.sample(generator) 90 | elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": 91 | return encoder_output.latent_dist.mode() 92 | elif hasattr(encoder_output, "latents"): 93 | return encoder_output.latents 94 | else: 95 | raise AttributeError("Could not access latents of provided encoder_output") 96 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==1.4.0 2 | av==14.2.0 3 | boto3==1.26.66 4 | clip==0.2.0 5 | decord==0.6.0 6 | deepspeed==0.14.4 7 | diffusers==0.32.0 8 | huggingface-hub==0.29.1 9 | imageio-ffmpeg==0.6.0 10 | jsonlines==4.0.0 11 | moviepy==1.0.3 12 | ninja==1.11.1.3 13 | numpy==1.26.0 14 | omegaconf==2.3.0 15 | onnxruntime-gpu 16 | PyYAML==6.0.2 17 | qwen-vl-utils==0.0.10 18 | safetensors==0.5.2 19 | sentencepiece==0.2.0 20 | SwissArmyTransformer==0.4.12 21 | tensorboardX==2.6.2.2 22 | timm==1.0.15 23 | tokenizers==0.21.0 24 | torch==2.6.0 25 | torchvision==0.21.0 26 | tqdm==4.67.1 27 | transformers==4.49.0 28 | wandb==0.19.7 -------------------------------------------------------------------------------- /scripts/dynamictrl_inference.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. 2 | # All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import os 17 | import time 18 | import argparse 19 | from typing import Literal 20 | 21 | import torch 22 | from diffusers import CogVideoXDPMScheduler 23 | from diffusers.utils import export_to_video, load_image, load_video 24 | 25 | from dynamictrl.pipelines.dynamictrl_pipeline import DynamiCtrlAnimationPipeline 26 | 27 | from utils.load_validation_control import load_control_video_inference, load_contorl_video_from_Image 28 | from utils.pre_process import preprocess 29 | 30 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 31 | 32 | def generate_video( 33 | prompt: str, 34 | model_path: str, 35 | num_frames: int = 81, 36 | width: int = 1360, 37 | height: int = 768, 38 | output_path: str = "./output.mp4", 39 | reference_image_path: str = "", 40 | ori_driving_video: str = None, 41 | pose_video: str = None, 42 | num_inference_steps: int = 50, 43 | guidance_scale: float = 6.0, 44 | num_videos_per_prompt: int = 1, 45 | dtype: torch.dtype = torch.bfloat16, 46 | abandon_prefix: int = None, 47 | seed: int = 42, 48 | fps: int = 8, 49 | pose_control_function: str = "padaln", 50 | re_init_noise_latent: bool = False, 51 | ): 52 | """ 53 | DynamiCtrl: Generates a dynamic video based on the given human image, guided video, and prompt, 
and saves it to the specified path. 54 | 55 | Parameters: 56 | - prompt (str): The description of the video to be generated. 57 | - model_path (str): The path of the pre-trained model to be used. 58 | - output_path (str): The directory where the generated video (and any extracted pose images) will be saved. 59 | - num_inference_steps (int): Number of steps for the inference process. More steps can result in better quality. 60 | - num_frames (int): Number of frames to generate. 61 | - width (int): The width of the generated video. 62 | - height (int): The height of the generated video. 63 | - guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt. 64 | - num_videos_per_prompt (int): Number of videos to generate per prompt. 65 | - dtype (torch.dtype): The data type for computation (default is torch.bfloat16). 66 | - seed (int): The seed for reproducibility. 67 | - fps (int): The frames per second for the generated video. 68 | """ 69 | 70 | image_name = os.path.basename(reference_image_path).split(".")[0] 71 | 72 | # 1. Initialize the DynamiCtrl inference pipeline. 73 | pipe = DynamiCtrlAnimationPipeline.from_pretrained(model_path, torch_dtype=dtype) 74 | pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing") 75 | 76 | pipe.to("cuda") 77 | # pipe.enable_sequential_cpu_offload() 78 | pipe.vae.enable_slicing() 79 | pipe.vae.enable_tiling() 80 | 81 | # 2. Load the reference image. 82 | image = load_image(image=reference_image_path) 83 | 84 | # 3. Load the driving video or the pose video.
85 | if ori_driving_video: 86 | video_name = os.path.basename(ori_driving_video).split(".")[0] 87 | pose_images, ref_image = preprocess(ori_driving_video, reference_image_path, 88 | width=width, height=height, max_frame_num=num_frames - 1, sample_stride=1) 89 | pose_sequence_save_dir = os.path.join(output_path, f"PoseImage_{image_name}_pose_{video_name}") 90 | os.makedirs(pose_sequence_save_dir, exist_ok=True) 91 | control_poses = [] 92 | for i in range(len(pose_images)): 93 | control_poses.append(pose_images[i]) 94 | pose_images[i].save(os.path.join(pose_sequence_save_dir, f"{i}.png")) # Save aligned pose images. 95 | validation_control_video = load_contorl_video_from_Image(control_poses, height, width).unsqueeze(0) 96 | 97 | if pose_video: 98 | video_name = os.path.basename(pose_video).split(".")[0] 99 | validation_control_video = load_control_video_inference(pose_video, video_height=height, video_width=width, max_frames=num_frames).unsqueeze(0) 100 | if not (ori_driving_video or pose_video): raise ValueError("Either ori_driving_video or pose_video must be provided.") 101 | # 4. Generate the video frames. 102 | video_generate = pipe( 103 | height=height, 104 | width=width, 105 | prompt=prompt, 106 | image=image, 107 | control_video=validation_control_video, 108 | pose_control_function=pose_control_function, 109 | num_videos_per_prompt=num_videos_per_prompt, # Number of videos to generate per prompt 110 | num_inference_steps=num_inference_steps, # Number of inference steps 111 | num_frames=num_frames, # Number of frames to generate 112 | use_dynamic_cfg=False, # This is used for the DPM scheduler; for the DDIM scheduler, it should be False 113 | guidance_scale=guidance_scale, 114 | generator=torch.Generator().manual_seed(seed), # Set the seed for reproducibility 115 | re_init_noise_latent=re_init_noise_latent, 116 | ).frames[0] 117 | 118 | # 5. Save the videos.
119 | for i in range(len(video_generate) // num_frames): 120 | name_timestamp = int(time.time()) 121 | video_gen = video_generate[i * num_frames : (i + 1) * num_frames] 122 | if abandon_prefix: 123 | video_gen = video_gen[abandon_prefix:] 124 | video_output_path = os.path.join(output_path, f"DynamiCtrl_{name_timestamp}_image_{image_name}_pose_{video_name}_seed_{seed}_cfg_{guidance_scale}.mp4") 125 | export_to_video(video_gen, video_output_path, fps=fps) 126 | print(f"output_path: {video_output_path}") 127 | 128 | if __name__ == "__main__": 129 | parser = argparse.ArgumentParser(description="Human image animation from a human image, a driving video, and a prompt using DynamiCtrl") 130 | parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated") 131 | parser.add_argument( 132 | "--reference_image_path", 133 | type=str, 134 | default=None, 135 | help="The path of the reference human image to be used as the subject of the video", 136 | ) 137 | parser.add_argument( 138 | "--ori_driving_video", 139 | type=str, 140 | default=None, 141 | help="Path to the original driving video containing the moving person; poses are extracted from it" 142 | ) 143 | parser.add_argument( 144 | "--pose_video", 145 | type=str, 146 | default=None, 147 | help="Path to a pose video" 148 | ) 149 | parser.add_argument( 150 | "--model_path", type=str, default="gulucaptain/DynamiCtrl", help="Path of the pre-trained model to use" 151 | ) 152 | parser.add_argument("--output_path", type=str, default="./output.mp4", help="Directory where the generated videos are saved") 153 | parser.add_argument("--guidance_scale", type=float, default=3.0, help="The scale for classifier-free guidance") 154 | parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps") 155 | parser.add_argument("--num_frames", type=int, default=81, help="Number of frames to generate") 156 | parser.add_argument("--width", type=int, default=1360, help="Width of the generated video") 157
| parser.add_argument("--height", type=int, default=768, help="Height of the generated video") 158 | parser.add_argument("--fps", type=int, default=16, help="Frames per second of the saved video") 159 | parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt") 160 | parser.add_argument("--dtype", type=str, default="bfloat16", help="The data type for computation: 'float16' or 'bfloat16'") 161 | parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility") 162 | parser.add_argument("--abandon_prefix", type=int, default=None, help="Number of leading frames to drop before saving (keeps frames [abandon_prefix:])") 163 | parser.add_argument("--re_init_noise_latent", action="store_true", help="Whether to resample the initial noise latent") 164 | parser.add_argument("--pose_control_function", type=str, default="padaln") 165 | 166 | args = parser.parse_args() 167 | dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16 168 | os.makedirs(args.output_path, exist_ok=True) 169 | generate_video( 170 | prompt=args.prompt, 171 | model_path=args.model_path, 172 | output_path=args.output_path, 173 | num_frames=args.num_frames, 174 | width=args.width, 175 | height=args.height, 176 | reference_image_path=args.reference_image_path, 177 | ori_driving_video=args.ori_driving_video, 178 | pose_video=args.pose_video, 179 | num_inference_steps=args.num_inference_steps, 180 | guidance_scale=args.guidance_scale, 181 | num_videos_per_prompt=args.num_videos_per_prompt, 182 | dtype=dtype, 183 | abandon_prefix=args.abandon_prefix, 184 | seed=args.seed, 185 | fps=args.fps, 186 | re_init_noise_latent=args.re_init_noise_latent, 187 | pose_control_function=args.pose_control_function, 188 | ) 189 | -------------------------------------------------------------------------------- /tools/qwen.py: -------------------------------------------------------------------------------- 1 | from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer,
AutoProcessor 2 | from qwen_vl_utils import process_vision_info 3 | import os 4 | from tqdm import tqdm 5 | 6 | # default: Load the model on the available device(s) 7 | model = Qwen2VLForConditionalGeneration.from_pretrained( 8 | "/home/user/ckpt/qwen", torch_dtype="auto", device_map="auto" 9 | ) 10 | 11 | # default processor 12 | processor = AutoProcessor.from_pretrained("/home/user/code/ckpt/qwen") 13 | 14 | image_data_root = "/home/user/images" 15 | images = os.listdir(image_data_root) 16 | images.sort() 17 | 18 | for image in tqdm(images): 19 | image_path = os.path.join(image_data_root, image) 20 | Question_description = "Describe the person and the background." # FIXME: replace with the question you want to ask. 21 | 22 | messages = [ 23 | { 24 | "role": "user", 25 | "content": [ 26 | { 27 | "type": "image", 28 | "image": image_path, 29 | }, 30 | {"type": "text", "text": Question_description}, 31 | ], 32 | } 33 | ] 34 | 35 | # Preparation for inference 36 | text = processor.apply_chat_template( 37 | messages, tokenize=False, add_generation_prompt=True 38 | ) 39 | image_inputs, video_inputs = process_vision_info(messages) 40 | inputs = processor( 41 | text=[text], 42 | images=image_inputs, 43 | videos=video_inputs, 44 | padding=True, 45 | return_tensors="pt", 46 | ) 47 | inputs = inputs.to("cuda") 48 | 49 | # Inference: Generation of the output 50 | generated_ids = model.generate(**inputs, max_new_tokens=128) 51 | generated_ids_trimmed = [ 52 | out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) 53 | ] 54 | output_text = processor.batch_decode( 55 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False 56 | )[0] 57 | with open("/home/user/descriptions.txt", "a") as f: 58 | f.write(image + "#####" + output_text + "\n") 59 | -------------------------------------------------------------------------------- /utils/args.py: -------------------------------------------------------------------------------- 1 | import
argparse 2 | 3 | 4 | def _get_model_args(parser: argparse.ArgumentParser) -> None: 5 | parser.add_argument( 6 | "--pretrained_model_name_or_path", 7 | type=str, 8 | default=None, 9 | required=True, 10 | help="Path to pretrained model or model identifier from huggingface.co/models.", 11 | ) 12 | parser.add_argument( 13 | "--revision", 14 | type=str, 15 | default=None, 16 | required=False, 17 | help="Revision of pretrained model identifier from huggingface.co/models.", 18 | ) 19 | parser.add_argument( 20 | "--variant", 21 | type=str, 22 | default=None, 23 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. 'fp16'", 24 | ) 25 | parser.add_argument( 26 | "--cache_dir", 27 | type=str, 28 | default=None, 29 | help="The directory where the downloaded models and datasets will be stored.", 30 | ) 31 | 32 | 33 | def _get_dataset_args(parser: argparse.ArgumentParser) -> None: 34 | parser.add_argument( 35 | "--data_root", 36 | type=str, 37 | default=None, 38 | help=("A folder containing the training data."), 39 | ) 40 | parser.add_argument( 41 | "--dataset_file", 42 | type=str, 43 | default=None, 44 | help=("Path to a CSV file if loading prompts/video paths using this format."), 45 | ) 46 | parser.add_argument( 47 | "--video_column", 48 | type=str, 49 | default="video", 50 | help="The column of the dataset containing videos. Or, the name of the file in the `--data_root` folder containing the line-separated paths to video data.", 51 | ) 52 | parser.add_argument( 53 | "--caption_column", 54 | type=str, 55 | default="text", 56 | help="The column of the dataset containing the instance prompt for each video.
Or, the name of the file in the `--data_root` folder containing the line-separated instance prompts.", 57 | ) 58 | parser.add_argument( 59 | "--id_token", 60 | type=str, 61 | default=None, 62 | help="Identifier token appended to the start of each prompt if provided.", 63 | ) 64 | parser.add_argument( 65 | "--height_buckets", 66 | nargs="+", 67 | type=int, 68 | default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536], 69 | ) 70 | parser.add_argument( 71 | "--width_buckets", 72 | nargs="+", 73 | type=int, 74 | default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536], 75 | ) 76 | parser.add_argument( 77 | "--frame_buckets", 78 | nargs="+", 79 | type=int, 80 | default=[49], 81 | help="CogVideoX1.5 requires ((num_frames - 1) // self.vae_scale_factor_temporal + 1) % patch_size_t == 0, e.g. 53" 82 | ) 83 | parser.add_argument( 84 | "--load_tensors", 85 | action="store_true", 86 | help="Whether to use a pre-encoded tensor dataset of latents and prompt embeddings instead of videos and text prompts. The expected format is that saved by running the `prepare_dataset.py` script.", 87 | ) 88 | parser.add_argument( 89 | "--random_flip", 90 | type=float, 91 | default=None, 92 | help="If random horizontal flip augmentation is to be used, this should be the flip probability.", 93 | ) 94 | parser.add_argument( 95 | "--dataloader_num_workers", 96 | type=int, 97 | default=0, 98 | help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", 99 | ) 100 | parser.add_argument( 101 | "--pin_memory", 102 | action="store_true", 103 | help="Whether or not to use pinned memory in the PyTorch DataLoader.", 104 | ) 105 | 106 | 107 | def _get_validation_args(parser: argparse.ArgumentParser) -> None: 108 | parser.add_argument( 109 | "--validation_prompt", 110 | type=str, 111 | default=None, 112 | help="One or more prompt(s) that are used during validation to verify that the model is learning.
Multiple validation prompts should be separated by the '--validation_prompt_separator' string.", 113 | ) 114 | parser.add_argument( 115 | "--validation_images", 116 | type=str, 117 | default=None, 118 | help="One or more image path(s)/URLs that are used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_separator' string. These should correspond to the order of the validation prompts.", 119 | ) 120 | parser.add_argument( 121 | "--validation_driving_videos", 122 | type=str, 123 | default=None, 124 | help="One or more driving video path(s)/URLs that are used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_separator' string. These should correspond to the order of the validation prompts.", 125 | ) 126 | parser.add_argument( 127 | "--validation_prompt_separator", 128 | type=str, 129 | default=":::", 130 | help="String that separates multiple validation prompts", 131 | ) 132 | parser.add_argument( 133 | "--num_validation_videos", 134 | type=int, 135 | default=1, 136 | help="Number of videos that should be generated during validation per `validation_prompt`.", 137 | ) 138 | parser.add_argument( 139 | "--validation_epochs", 140 | type=int, 141 | default=None, 142 | help="Run validation every X training epochs. Validation consists of running the validation prompt `args.num_validation_videos` times.", 143 | ) 144 | parser.add_argument( 145 | "--validation_steps", 146 | type=int, 147 | default=None, 148 | help="Run validation every X training steps.
Validation consists of running the validation prompt `args.num_validation_videos` times.", 149 | ) 150 | parser.add_argument( 151 | "--guidance_scale", 152 | type=float, 153 | default=6, 154 | help="The guidance scale to use while sampling validation videos.", 155 | ) 156 | parser.add_argument( 157 | "--use_dynamic_cfg", 158 | action="store_true", 159 | default=False, 160 | help="Whether or not to use the default cosine dynamic guidance schedule when sampling validation videos.", 161 | ) 162 | parser.add_argument( 163 | "--enable_model_cpu_offload", 164 | action="store_true", 165 | default=False, 166 | help="Whether or not to enable model-wise CPU offloading when performing validation/testing to save memory.", 167 | ) 168 | 169 | 170 | def _get_training_args(parser: argparse.ArgumentParser) -> None: 171 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 172 | parser.add_argument("--rank", type=int, default=64, help="The rank for LoRA matrices.") 173 | parser.add_argument( 174 | "--lora_alpha", 175 | type=int, 176 | default=64, 177 | help="The lora_alpha used to compute the scaling factor (lora_alpha / rank) for LoRA matrices.", 178 | ) 179 | parser.add_argument( 180 | "--mixed_precision", 181 | type=str, 182 | default=None, 183 | choices=["no", "fp16", "bf16"], 184 | help=( 185 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU. " 186 | "Defaults to the value of the accelerate config of the current system or the flag passed with the `accelerate.launch` command. Use this " 187 | "argument to override the accelerate config."
188 | ), 189 | ) 190 | parser.add_argument( 191 | "--output_dir", 192 | type=str, 193 | default="dynamictrl-sft", 194 | help="The output directory where the model predictions and checkpoints will be written.", 195 | ) 196 | parser.add_argument( 197 | "--height", 198 | type=int, 199 | default=480, 200 | help="All input videos are resized to this height.", 201 | ) 202 | parser.add_argument( 203 | "--width", 204 | type=int, 205 | default=720, 206 | help="All input videos are resized to this width.", 207 | ) 208 | parser.add_argument( 209 | "--video_reshape_mode", 210 | type=str, 211 | default=None, 212 | help="All input videos are reshaped to this mode. Choose from ['center', 'random', 'none']", 213 | ) 214 | parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.") 215 | parser.add_argument( 216 | "--max_num_frames", 217 | type=int, 218 | default=49, 219 | help="All input videos will be truncated to this many frames.", 220 | ) 221 | parser.add_argument( 222 | "--skip_frames_start", 223 | type=int, 224 | default=0, 225 | help="Number of frames to skip from the beginning of each input video. Useful if training data contains intro sequences.", 226 | ) 227 | parser.add_argument( 228 | "--skip_frames_end", 229 | type=int, 230 | default=0, 231 | help="Number of frames to skip from the end of each input video. Useful if training data contains outro sequences.", 232 | ) 233 | parser.add_argument( 234 | "--train_batch_size", 235 | type=int, 236 | default=4, 237 | help="Batch size (per device) for the training dataloader.", 238 | ) 239 | parser.add_argument("--num_train_epochs", type=int, default=1) 240 | parser.add_argument( 241 | "--max_train_steps", 242 | type=int, 243 | default=None, 244 | help="Total number of training steps to perform.
If provided, overrides `--num_train_epochs`.", 245 | ) 246 | parser.add_argument( 247 | "--checkpointing_steps", 248 | type=int, 249 | default=500, 250 | help=( 251 | "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" 252 | " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" 253 | " training using `--resume_from_checkpoint`." 254 | ), 255 | ) 256 | parser.add_argument( 257 | "--checkpoints_total_limit", 258 | type=int, 259 | default=None, 260 | help=("Maximum number of checkpoints to keep."), 261 | ) 262 | parser.add_argument( 263 | "--resume_from_checkpoint", 264 | type=str, 265 | default=None, 266 | help=( 267 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 268 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 269 | ), 270 | ) 271 | parser.add_argument( 272 | "--gradient_accumulation_steps", 273 | type=int, 274 | default=1, 275 | help="Number of update steps to accumulate before performing a backward/update pass.", 276 | ) 277 | parser.add_argument( 278 | "--gradient_checkpointing", 279 | action="store_true", 280 | help="Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.", 281 | ) 282 | parser.add_argument( 283 | "--learning_rate", 284 | type=float, 285 | default=1e-4, 286 | help="Initial learning rate (after the potential warmup period) to use.", 287 | ) 288 | parser.add_argument( 289 | "--scale_lr", 290 | action="store_true", 291 | default=False, 292 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 293 | ) 294 | parser.add_argument( 295 | "--lr_scheduler", 296 | type=str, 297 | default="constant", 298 | help=( 299 | 'The scheduler type to use.
Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 300 | ' "constant", "constant_with_warmup"]' 301 | ), 302 | ) 303 | parser.add_argument( 304 | "--lr_warmup_steps", 305 | type=int, 306 | default=500, 307 | help="Number of steps for the warmup in the lr scheduler.", 308 | ) 309 | parser.add_argument( 310 | "--lr_num_cycles", 311 | type=int, 312 | default=1, 313 | help="Number of hard resets of the lr in cosine_with_restarts scheduler.", 314 | ) 315 | parser.add_argument( 316 | "--lr_power", 317 | type=float, 318 | default=1.0, 319 | help="Power factor of the polynomial scheduler.", 320 | ) 321 | parser.add_argument( 322 | "--enable_slicing", 323 | action="store_true", 324 | default=False, 325 | help="Whether or not to use VAE slicing for saving memory.", 326 | ) 327 | parser.add_argument( 328 | "--enable_tiling", 329 | action="store_true", 330 | default=False, 331 | help="Whether or not to use VAE tiling for saving memory.", 332 | ) 333 | parser.add_argument( 334 | "--noised_image_dropout", 335 | type=float, 336 | default=0.05, 337 | help="Image condition dropout probability when finetuning image-to-video.", 338 | ) 339 | parser.add_argument( 340 | "--ignore_learned_positional_embeddings", 341 | action="store_true", 342 | default=False, 343 | help=( 344 | "Whether to ignore the learned positional embeddings when training CogVideoX Image-to-Video. This setting " 345 | "should be used when performing multi-resolution training, because CogVideoX-I2V does not support it " 346 | "otherwise. Please read the comments in https://github.com/a-r-r-o-w/cogvideox-factory/issues/26 to understand why." 
347 | ), 348 | ) 349 | parser.add_argument( 350 | "--enable_control_pose", 351 | action="store_true", 352 | default=False, 353 | help="Whether or not to use pose control information.", 354 | ) 355 | parser.add_argument( 356 | "--pose_control_function", 357 | type=str, 358 | default="expert_layer_norm", 359 | help="Pose control function: one of direct_adding, expert_layer_norm, cross_attention.", 360 | ) 361 | 362 | 363 | def _get_optimizer_args(parser: argparse.ArgumentParser) -> None: 364 | parser.add_argument( 365 | "--optimizer", 366 | type=lambda s: s.lower(), 367 | default="adam", 368 | choices=["adam", "adamw", "prodigy", "came"], 369 | help=("The optimizer type to use."), 370 | ) 371 | parser.add_argument( 372 | "--use_8bit", 373 | action="store_true", 374 | help="Whether or not to use 8-bit optimizers from `torchao` or `bitsandbytes`.", 375 | ) 376 | parser.add_argument( 377 | "--use_4bit", 378 | action="store_true", 379 | help="Whether or not to use 4-bit optimizers from `torchao`.", 380 | ) 381 | parser.add_argument( 382 | "--use_torchao", action="store_true", help="Whether or not to use the `torchao` backend for optimizers." 383 | ) 384 | parser.add_argument( 385 | "--beta1", 386 | type=float, 387 | default=0.9, 388 | help="The beta1 parameter for the Adam and Prodigy optimizers.", 389 | ) 390 | parser.add_argument( 391 | "--beta2", 392 | type=float, 393 | default=0.95, 394 | help="The beta2 parameter for the Adam and Prodigy optimizers.", 395 | ) 396 | parser.add_argument( 397 | "--beta3", 398 | type=float, 399 | default=None, 400 | help="Coefficients for computing the Prodigy optimizer's stepsize using running averages.
If set to None, uses the value of square root of beta2.", 401 | ) 402 | parser.add_argument( 403 | "--prodigy_decouple", 404 | action="store_true", 405 | help="Use AdamW style decoupled weight decay.", 406 | ) 407 | parser.add_argument( 408 | "--weight_decay", 409 | type=float, 410 | default=1e-04, 411 | help="Weight decay to use for optimizer.", 412 | ) 413 | parser.add_argument( 414 | "--epsilon", 415 | type=float, 416 | default=1e-8, 417 | help="Epsilon value for the Adam optimizer and Prodigy optimizers.", 418 | ) 419 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 420 | parser.add_argument( 421 | "--prodigy_use_bias_correction", 422 | action="store_true", 423 | help="Turn on Adam's bias correction.", 424 | ) 425 | parser.add_argument( 426 | "--prodigy_safeguard_warmup", 427 | action="store_true", 428 | help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage.", 429 | ) 430 | parser.add_argument( 431 | "--use_cpu_offload_optimizer", 432 | action="store_true", 433 | help="Whether or not to use the CPUOffloadOptimizer from TorchAO to perform optimization step and maintain parameters on the CPU.", 434 | ) 435 | parser.add_argument( 436 | "--offload_gradients", 437 | action="store_true", 438 | help="Whether or not to offload the gradients to CPU when using the CPUOffloadOptimizer from TorchAO.", 439 | ) 440 | 441 | 442 | def _get_configuration_args(parser: argparse.ArgumentParser) -> None: 443 | parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name") 444 | parser.add_argument( 445 | "--push_to_hub", 446 | action="store_true", 447 | help="Whether or not to push the model to the Hub.", 448 | ) 449 | parser.add_argument( 450 | "--hub_token", 451 | type=str, 452 | default=None, 453 | help="The token to use to push to the Model Hub.", 454 | ) 455 | parser.add_argument( 456 | "--hub_model_id", 457 | type=str, 458 | default=None, 459 | help="The name of the 
repository to keep in sync with the local `output_dir`.", 460 | ) 461 | parser.add_argument( 462 | "--logging_dir", 463 | type=str, 464 | default="logs", 465 | help="Directory where logs are stored.", 466 | ) 467 | parser.add_argument( 468 | "--allow_tf32", 469 | action="store_true", 470 | help=( 471 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 472 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 473 | ), 474 | ) 475 | parser.add_argument( 476 | "--nccl_timeout", 477 | type=int, 478 | default=600, 479 | help="Maximum timeout duration before which allgather, or related, operations fail in multi-GPU/multi-node training settings.", 480 | ) 481 | parser.add_argument( 482 | "--report_to", 483 | type=str, 484 | default=None, 485 | help=( 486 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 487 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 
488 | ), 489 | ) 490 | 491 | 492 | def get_args(): 493 | parser = argparse.ArgumentParser(description="Simple example of a training script for CogVideoX.") 494 | 495 | _get_model_args(parser) 496 | _get_dataset_args(parser) 497 | _get_training_args(parser) 498 | _get_validation_args(parser) 499 | _get_optimizer_args(parser) 500 | _get_configuration_args(parser) 501 | 502 | return parser.parse_args() 503 | -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | from pathlib import Path 3 | from typing import Any, Dict, List, Optional, Tuple 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | import torchvision.transforms as TT 9 | from accelerate.logging import get_logger 10 | from torch.utils.data import Dataset, Sampler 11 | from torchvision import transforms 12 | from torchvision.transforms import InterpolationMode 13 | from torchvision.transforms.functional import resize 14 | 15 | 16 | # Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error 17 | # Very few bug reports but it happens. Look in decord Github issues for more relevant information. 
18 | import decord # isort:skip 19 | 20 | decord.bridge.set_bridge("torch") 21 | 22 | logger = get_logger(__name__) 23 | 24 | HEIGHT_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536] 25 | WIDTH_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536] 26 | FRAME_BUCKETS = [16, 24, 32, 48, 64, 80] 27 | 28 | 29 | class VideoDataset(Dataset): 30 | def __init__( 31 | self, 32 | data_root: str, 33 | dataset_file: Optional[str] = None, 34 | caption_column: str = "text", 35 | video_column: str = "video", 36 | max_num_frames: int = 49, 37 | id_token: Optional[str] = None, 38 | height_buckets: List[int] = None, 39 | width_buckets: List[int] = None, 40 | frame_buckets: List[int] = None, 41 | load_tensors: bool = False, 42 | random_flip: Optional[float] = None, 43 | image_to_video: bool = False, 44 | enable_control_pose: bool = False, 45 | ) -> None: 46 | super().__init__() 47 | 48 | self.data_root = Path(data_root) 49 | self.dataset_file = dataset_file 50 | self.caption_column = caption_column 51 | self.video_column = video_column 52 | self.max_num_frames = max_num_frames 53 | self.id_token = f"{id_token.strip()} " if id_token else "" 54 | self.height_buckets = height_buckets or HEIGHT_BUCKETS 55 | self.width_buckets = width_buckets or WIDTH_BUCKETS 56 | self.frame_buckets = frame_buckets or FRAME_BUCKETS 57 | self.load_tensors = load_tensors 58 | self.random_flip = random_flip 59 | self.image_to_video = image_to_video 60 | self.enable_control_pose = enable_control_pose 61 | 62 | self.resolutions = [ 63 | (f, h, w) for h in self.height_buckets for w in self.width_buckets for f in self.frame_buckets 64 | ] 65 | 66 | # Two methods of loading data are supported. 67 | # - Using a CSV: caption_column and video_column must be some column in the CSV. One could 68 | # make use of other columns too, such as a motion score or aesthetic score, by modifying the 69 | # logic in CSV processing. 
70 | # - Using two files containing line-separated captions and relative paths to videos. 71 | # For a more detailed explanation about preparing dataset format, check out the README. 72 | if dataset_file is None: 73 | ( 74 | self.prompts, 75 | self.video_paths, 76 | ) = self._load_dataset_from_local_path() 77 | else: 78 | ( 79 | self.prompts, 80 | self.video_paths, 81 | ) = self._load_dataset_from_csv() 82 | 83 | if len(self.video_paths) != len(self.prompts): 84 | raise ValueError( 85 | f"Expected length of prompts and videos to be the same but found {len(self.prompts)=} and {len(self.video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset." 86 | ) 87 | 88 | self.video_transforms = transforms.Compose( 89 | [ 90 | transforms.RandomHorizontalFlip(random_flip) 91 | if random_flip 92 | else transforms.Lambda(self.identity_transform), 93 | transforms.Lambda(self.scale_transform), 94 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 95 | ] 96 | ) 97 | 98 | @staticmethod 99 | def identity_transform(x): 100 | return x 101 | 102 | @staticmethod 103 | def scale_transform(x): 104 | return x / 255.0 105 | 106 | def __len__(self) -> int: 107 | return len(self.video_paths) 108 | 109 | def __getitem__(self, index: int) -> Dict[str, Any]: 110 | if isinstance(index, list): 111 | # Here, index is actually a list of data objects that we need to return. 112 | # The BucketSampler should ideally return indices. But, in the sampler, we'd like 113 | # to have information about num_frames, height and width. Since this is not stored 114 | # as metadata, we need to read the video to get this information. You could read this 115 | # information without loading the full video in memory, but we do it anyway. In order 116 | # to not load the video twice (once to get the metadata, and once to return the loaded video 117 | # based on sampled indices), we cache it in the BucketSampler.
When the sampler is 118 | # to yield, we yield the cache data instead of indices. So, this special check ensures 119 | # that data is not loaded a second time. PRs are welcome for improvements. 120 | return index 121 | 122 | if self.load_tensors: # False 123 | image_latents, video_latents, prompt_embeds = self._preprocess_video(self.video_paths[index]) 124 | 125 | latent_num_frames = video_latents.size(1) 126 | if latent_num_frames % 2 == 0: 127 | num_frames = latent_num_frames * 4 128 | else: 129 | num_frames = (latent_num_frames - 1) * 4 + 1 130 | 131 | height = video_latents.size(2) * 8 132 | width = video_latents.size(3) * 8 133 | 134 | return { 135 | "prompt": prompt_embeds, 136 | "image": image_latents, 137 | "video": video_latents, 138 | "video_metadata": { 139 | "num_frames": num_frames, 140 | "height": height, 141 | "width": width, 142 | }, 143 | } 144 | else: 145 | image, video, control_video, _ = self._preprocess_video(self.video_paths[index]) 146 | prompt = self.prompts[index].split("#####")[1] 147 | 148 | return { 149 | "prompt": self.id_token + prompt, 150 | "image": image, 151 | "video": video, 152 | "control_video": control_video, 153 | "video_metadata": { 154 | "num_frames": video.shape[0], 155 | "height": video.shape[2], 156 | "width": video.shape[3], 157 | }, 158 | } 159 | 160 | def _load_dataset_from_local_path(self) -> Tuple[List[str], List[str]]: 161 | if not self.data_root.exists(): 162 | raise ValueError("Root folder for videos does not exist") 163 | 164 | prompt_path = self.data_root.joinpath(self.caption_column) 165 | video_path = self.data_root.joinpath(self.video_column) 166 | 167 | if not prompt_path.exists() or not prompt_path.is_file(): 168 | raise ValueError( 169 | "Expected `--caption_column` to be path to a file in `--data_root` containing line-separated text prompts." 
170 | ) 171 | if not video_path.exists() or not video_path.is_file(): 172 | raise ValueError( 173 | "Expected `--video_column` to be a path to a file in `--data_root` containing line-separated paths to video data in the same directory." 174 | ) 175 | 176 | with open(prompt_path, "r", encoding="utf-8") as file: 177 | prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0] 178 | with open(video_path, "r", encoding="utf-8") as file: 179 | video_paths = [self.data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0] 180 | 181 | if not self.load_tensors and any(not path.is_file() for path in video_paths): 182 | raise ValueError( 183 | f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found at least one path that is not a valid file." 184 | ) 185 | 186 | return prompts, video_paths 187 | 188 | def _load_dataset_from_csv(self) -> Tuple[List[str], List[str]]: 189 | df = pd.read_csv(self.dataset_file) 190 | prompts = df[self.caption_column].tolist() 191 | video_paths = df[self.video_column].tolist() 192 | video_paths = [self.data_root.joinpath(line.strip()) for line in video_paths] 193 | 194 | if any(not path.is_file() for path in video_paths): 195 | raise ValueError( 196 | f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found at least one path that is not a valid file." 197 | ) 198 | 199 | return prompts, video_paths 200 | 201 | def _preprocess_video(self, path: Path) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 202 | r""" 203 | Loads a single video, or latent and prompt embedding, based on initialization parameters. 204 | 205 | If returning a video, returns a [F, C, H, W] video tensor, and None for the prompt embedding. Here, 206 | F, C, H and W are the frames, channels, height and width of the input video.
207 | 208 | If returning latent/embedding, returns a [F, C, H, W] latent, and the prompt embedding of shape [S, D]. 209 | F, C, H and W are the frames, channels, height and width of the latent, and S, D are the sequence length 210 | and embedding dimension of prompt embeddings. 211 | """ 212 | if self.load_tensors: 213 | return self._load_preprocessed_latents_and_embeds(path) 214 | else: 215 | video_reader = decord.VideoReader(uri=path.as_posix()) 216 | video_num_frames = len(video_reader) 217 | 218 | indices = list(range(0, video_num_frames, max(1, video_num_frames // self.max_num_frames))) # max(1, ...) avoids a zero step when the video is shorter than max_num_frames 219 | frames = video_reader.get_batch(indices) 220 | frames = frames[: self.max_num_frames].float() 221 | frames = frames.permute(0, 3, 1, 2).contiguous() 222 | frames = torch.stack([self.video_transforms(frame) for frame in frames], dim=0) 223 | 224 | image = frames[:1].clone() if self.image_to_video else None 225 | 226 | return image, frames, None, None # (image, video, control_video, prompt_embeds), matching the 4-way unpack in __getitem__ 227 | 228 | def _load_preprocessed_latents_and_embeds(self, path: Path) -> Tuple[torch.Tensor, torch.Tensor]: 229 | filename_without_ext = path.name.split(".")[0] 230 | pt_filename = f"{filename_without_ext}.pt" 231 | 232 | image_latents_path = path.parent.parent.joinpath("image_latents") 233 | video_latents_path = path.parent.parent.joinpath("video_latents") 234 | embeds_path = path.parent.parent.joinpath("prompt_embeds") 235 | 236 | if ( 237 | not video_latents_path.exists() 238 | or not embeds_path.exists() 239 | or (self.image_to_video and not image_latents_path.exists()) 240 | ): 241 | raise ValueError( 242 | f"When setting the load_tensors parameter to `True`, it is expected that the `{self.data_root=}` contains two folders named `video_latents` and `prompt_embeds`. However, these folders were not found. Please make sure to have prepared your data correctly using `prepare_dataset.py`. 
Additionally, if you're training image-to-video, it is expected that an `image_latents` folder is also present."
243 | ) 244 | 245 | if self.image_to_video: 246 | image_latent_filepath = image_latents_path.joinpath(pt_filename) 247 | video_latent_filepath = video_latents_path.joinpath(pt_filename) 248 | embeds_filepath = embeds_path.joinpath(pt_filename) 249 | 250 | if not video_latent_filepath.is_file() or not embeds_filepath.is_file(): 251 | if self.image_to_video: 252 | image_latent_filepath = image_latent_filepath.as_posix() 253 | video_latent_filepath = video_latent_filepath.as_posix() 254 | embeds_filepath = embeds_filepath.as_posix() 255 | raise ValueError( 256 | f"The file {video_latent_filepath=} or {embeds_filepath=} could not be found. Please ensure that you've correctly executed `prepare_dataset.py`." 257 | ) 258 | 259 | images = ( 260 | torch.load(image_latent_filepath, map_location="cpu", weights_only=True) if self.image_to_video else None 261 | ) 262 | latents = torch.load(video_latent_filepath, map_location="cpu", weights_only=True) 263 | embeds = torch.load(embeds_filepath, map_location="cpu", weights_only=True) 264 | 265 | return images, latents, embeds 266 | 267 | 268 | class VideoDatasetWithResizing(VideoDataset): 269 | def __init__(self, *args, **kwargs) -> None: 270 | super().__init__(*args, **kwargs) 271 | 272 | def get_entity_with_mask_short_time(self, video, masked_video): 273 | video = video.detach().numpy() 274 | masked_video = masked_video.detach().numpy() 275 | 276 | threshold = 200 277 | output_frames = [] 278 | 279 | if video.shape[1] != masked_video.shape[1] or video.shape[2] != masked_video.shape[2]: 280 | return video 281 | 282 | frame_length, frame_height, frame_width, _ = video.shape 283 | 284 | for i in range(frame_length): 285 | video_frame = video[i] 286 | mask_frame = masked_video[i] 287 | mask_frame = mask_frame[:, :, 0] 288 | 289 | mask_frame = mask_frame.astype(np.uint8) 290 | binary_mask = (mask_frame > threshold).astype(np.uint8) 291 | binary_mask = np.expand_dims(binary_mask, axis=-1) 292 | 293 | masked_frame = video_frame * 
binary_mask 294 | output_frames.append(masked_frame) 295 | 296 | return torch.tensor(np.array(output_frames)) 297 | 298 | def _preprocess_video(self, path: Path) -> torch.Tensor: 299 | if self.load_tensors: 300 | return self._load_preprocessed_latents_and_embeds(path) 301 | else: 302 | video_reader = decord.VideoReader(uri=path.as_posix()) 303 | 304 | use_image_mask = False 305 | if use_image_mask: 306 | try: 307 | mask_reader = decord.VideoReader(uri=path.as_posix().replace("videos", "masked_video")) # directly load the entity video. 308 | except Exception: 309 | mask_reader = None 310 | else: 311 | mask_reader = None 312 | 313 | if self.enable_control_pose: 314 | pose_reader = decord.VideoReader(uri=path.as_posix().replace("videos", "poses")) 315 | else: 316 | pose_reader = None 317 | 318 | video_num_frames = len(video_reader) 319 | nearest_frame_bucket = min( 320 | self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)) 321 | ) 322 | 323 | frames_load_interval = 1 324 | frame_indices = [] 325 | frame_index = random.randint(1, 30) 326 | while len(frame_indices) < nearest_frame_bucket: 327 | frame_indices.append(frame_index) 328 | frame_index += frames_load_interval 329 | 330 | frames = video_reader.get_batch(frame_indices) 331 | if use_image_mask and mask_reader is not None: 332 | masked_frames = mask_reader.get_batch(frame_indices) 333 | maked_frames = masked_frames[:nearest_frame_bucket].float() 334 | maked_frames = maked_frames.permute(0, 3, 1, 2).contiguous() # permute the sliced float tensor, not the raw batch 335 | 336 | frames = frames[:nearest_frame_bucket].float() 337 | frames = frames.permute(0, 3, 1, 2).contiguous() 338 | 339 | nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3]) 340 | frames_resized = torch.stack([resize(frame, nearest_res) for frame in frames], dim=0) 341 | frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0) 342 | 343 | if use_image_mask and mask_reader is not None: 344 | maked_frames_resized =
torch.stack([resize(mask_frame, nearest_res) for mask_frame in maked_frames], dim=0) 345 | maked_frames = torch.stack([self.video_transforms(mask_frame) for mask_frame in maked_frames_resized], dim=0) 346 | 347 | pose_frames = pose_reader.get_batch(frame_indices) 348 | pose_frames = pose_frames[:nearest_frame_bucket].float() 349 | pose_frames = pose_frames.permute(0, 3, 1, 2).contiguous() 350 | 351 | pose_frames_resized = torch.stack([resize(frame, nearest_res) for frame in pose_frames], dim=0) 352 | pose_frames = torch.stack([self.video_transforms(frame) for frame in pose_frames_resized], dim=0) 353 | 354 | if use_image_mask and mask_reader is not None: 355 | image = maked_frames[:1].clone() if self.image_to_video else None 356 | else: 357 | image = frames[:1].clone() if self.image_to_video else None 358 | 359 | return image, frames, pose_frames, None 360 | 361 | def _find_nearest_resolution(self, height, width): 362 | nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width)) 363 | return nearest_res[1], nearest_res[2] 364 | 365 | 366 | class VideoDatasetWithResizeAndRectangleCrop(VideoDataset): 367 | def __init__(self, video_reshape_mode: str = "center", *args, **kwargs) -> None: 368 | super().__init__(*args, **kwargs) 369 | self.video_reshape_mode = video_reshape_mode 370 | 371 | def _resize_for_rectangle_crop(self, arr, image_size): 372 | reshape_mode = self.video_reshape_mode 373 | if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]: 374 | arr = resize( 375 | arr, 376 | size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])], 377 | interpolation=InterpolationMode.BICUBIC, 378 | ) 379 | else: 380 | arr = resize( 381 | arr, 382 | size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]], 383 | interpolation=InterpolationMode.BICUBIC, 384 | ) 385 | 386 | h, w = arr.shape[2], arr.shape[3] 387 | arr = arr.squeeze(0) 388 | 389 | delta_h = h - image_size[0] 390 | delta_w = w - image_size[1] 391 | 
392 | if reshape_mode == "random" or reshape_mode == "none": 393 | top = np.random.randint(0, delta_h + 1) 394 | left = np.random.randint(0, delta_w + 1) 395 | elif reshape_mode == "center": 396 | top, left = delta_h // 2, delta_w // 2 397 | else: 398 | raise NotImplementedError 399 | arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1]) 400 | return arr 401 | 402 | def _preprocess_video(self, path: Path) -> torch.Tensor: 403 | if self.load_tensors: 404 | return self._load_preprocessed_latents_and_embeds(path) 405 | else: 406 | video_reader = decord.VideoReader(uri=path.as_posix()) 407 | video_num_frames = len(video_reader) 408 | nearest_frame_bucket = min( 409 | self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)) 410 | ) 411 | 412 | frames_load_interval = 1 413 | frame_indices = [] 414 | frame_index = 0 415 | while len(frame_indices) < nearest_frame_bucket: 416 | frame_indices.append(frame_index) 417 | frame_index += frames_load_interval 418 | 419 | frames = video_reader.get_batch(frame_indices) 420 | frames = frames[:nearest_frame_bucket].float() 421 | frames = frames.permute(0, 3, 1, 2).contiguous() 422 | 423 | nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3]) 424 | frames_resized = self._resize_for_rectangle_crop(frames, nearest_res) 425 | frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0) 426 | 427 | image = frames[:1].clone() if self.image_to_video else None 428 | 429 | return image, frames, None, None # (image, video, control_video, prompt_embeds), matching the 4-way unpack in __getitem__ 430 | 431 | def _find_nearest_resolution(self, height, width): 432 | nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width)) 433 | return nearest_res[1], nearest_res[2] 434 | 435 | 436 | class BucketSampler(Sampler): 437 | r""" 438 | PyTorch Sampler that groups 3D data by height, width and frames.
439 | 440 | Args: 441 | data_source (`VideoDataset`): 442 | A PyTorch dataset object that is an instance of `VideoDataset`. 443 | batch_size (`int`, defaults to `8`): 444 | The batch size to use for training. 445 | shuffle (`bool`, defaults to `True`): 446 | Whether or not to shuffle the data in each batch before dispatching to dataloader. 447 | drop_last (`bool`, defaults to `False`): 448 | Whether or not to drop incomplete buckets of data after completely iterating over all data 449 | in the dataset. If set to True, only batches that have `batch_size` number of entries will 450 | be yielded. If set to False, it is guaranteed that all data in the dataset will be processed 451 | and batches that do not have `batch_size` number of entries will also be yielded. 452 | """ 453 | 454 | def __init__( 455 | self, data_source: VideoDataset, batch_size: int = 8, shuffle: bool = True, drop_last: bool = False 456 | ) -> None: 457 | self.data_source = data_source 458 | self.batch_size = batch_size 459 | self.shuffle = shuffle 460 | self.drop_last = drop_last 461 | 462 | self.buckets = {resolution: [] for resolution in data_source.resolutions} 463 | 464 | self._raised_warning_for_drop_last = False 465 | 466 | def __len__(self): 467 | if self.drop_last and not self._raised_warning_for_drop_last: 468 | self._raised_warning_for_drop_last = True 469 | logger.warning( 470 | "Calculating the length for bucket sampler is not possible when `drop_last` is set to True. This may cause problems when setting the number of epochs used for training." 
471 | ) 472 | return (len(self.data_source) + self.batch_size - 1) // self.batch_size 473 | 474 | def __iter__(self): 475 | for index, data in enumerate(self.data_source): 476 | video_metadata = data["video_metadata"] 477 | f, h, w = video_metadata["num_frames"], video_metadata["height"], video_metadata["width"] 478 | 479 | self.buckets[(f, h, w)].append(data) 480 | if len(self.buckets[(f, h, w)]) == self.batch_size: 481 | if self.shuffle: # self.drop_last=False; self.shuffle=True; 482 | random.shuffle(self.buckets[(f, h, w)]) 483 | yield self.buckets[(f, h, w)] 484 | del self.buckets[(f, h, w)] 485 | self.buckets[(f, h, w)] = [] 486 | 487 | if self.drop_last: 488 | return 489 | 490 | for fhw, bucket in list(self.buckets.items()): 491 | if len(bucket) == 0: 492 | continue 493 | if self.shuffle: 494 | random.shuffle(bucket) 495 | yield bucket 496 | del self.buckets[fhw] 497 | self.buckets[fhw] = [] -------------------------------------------------------------------------------- /utils/dataset_use_mask_entity.py: -------------------------------------------------------------------------------- 1 | import random 2 | from pathlib import Path 3 | from typing import Any, Dict, List, Optional, Tuple 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | import torchvision.transforms as TT 9 | from accelerate.logging import get_logger 10 | from torch.utils.data import Dataset, Sampler 11 | from torchvision import transforms 12 | from torchvision.transforms import InterpolationMode 13 | from torchvision.transforms.functional import resize 14 | 15 | 16 | # Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error 17 | # Very few bug reports but it happens. Look in decord Github issues for more relevant information. 
18 | import decord # isort:skip 19 | 20 | decord.bridge.set_bridge("torch") 21 | 22 | logger = get_logger(__name__) 23 | 24 | HEIGHT_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536] 25 | WIDTH_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536] 26 | FRAME_BUCKETS = [16, 24, 32, 48, 64, 80] 27 | 28 | 29 | class VideoDataset(Dataset): 30 | def __init__( 31 | self, 32 | data_root: str, 33 | dataset_file: Optional[str] = None, 34 | caption_column: str = "text", 35 | video_column: str = "video", 36 | max_num_frames: int = 49, 37 | id_token: Optional[str] = None, 38 | height_buckets: List[int] = None, 39 | width_buckets: List[int] = None, 40 | frame_buckets: List[int] = None, 41 | load_tensors: bool = False, 42 | random_flip: Optional[float] = None, 43 | image_to_video: bool = False, 44 | enable_control_pose: bool = False, 45 | ) -> None: 46 | super().__init__() 47 | 48 | self.data_root = Path(data_root) 49 | self.dataset_file = dataset_file 50 | self.caption_column = caption_column 51 | self.video_column = video_column 52 | self.max_num_frames = max_num_frames 53 | self.id_token = f"{id_token.strip()} " if id_token else "" 54 | self.height_buckets = height_buckets or HEIGHT_BUCKETS 55 | self.width_buckets = width_buckets or WIDTH_BUCKETS 56 | self.frame_buckets = frame_buckets or FRAME_BUCKETS 57 | self.load_tensors = load_tensors 58 | self.random_flip = random_flip 59 | self.image_to_video = image_to_video 60 | self.enable_control_pose = enable_control_pose 61 | 62 | self.resolutions = [ 63 | (f, h, w) for h in self.height_buckets for w in self.width_buckets for f in self.frame_buckets 64 | ] 65 | 66 | # Two methods of loading data are supported. 67 | # - Using a CSV: caption_column and video_column must be some column in the CSV. One could 68 | # make use of other columns too, such as a motion score or aesthetic score, by modifying the 69 | # logic in CSV processing. 
70 | # - Using two files containing line-separated captions and relative paths to videos. 71 | # For a more detailed explanation about preparing dataset format, check out the README. 72 | if dataset_file is None: 73 | ( 74 | self.prompts, 75 | self.video_paths, 76 | ) = self._load_dataset_from_local_path() 77 | else: 78 | ( 79 | self.prompts, 80 | self.video_paths, 81 | ) = self._load_dataset_from_csv() 82 | 83 | if len(self.video_paths) != len(self.prompts): 84 | raise ValueError( 85 | f"Expected length of prompts and videos to be the same but found {len(self.prompts)=} and {len(self.video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset." 86 | ) 87 | 88 | self.video_transforms = transforms.Compose( 89 | [ 90 | transforms.RandomHorizontalFlip(random_flip) 91 | if random_flip 92 | else transforms.Lambda(self.identity_transform), 93 | transforms.Lambda(self.scale_transform), 94 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 95 | ] 96 | ) 97 | 98 | @staticmethod 99 | def identity_transform(x): 100 | return x 101 | 102 | @staticmethod 103 | def scale_transform(x): 104 | return x / 255.0 105 | 106 | def __len__(self) -> int: 107 | return len(self.video_paths) 108 | 109 | def __getitem__(self, index: int) -> Dict[str, Any]: 110 | if isinstance(index, list): 111 | # Here, index is actually a list of data objects that we need to return. 112 | # The BucketSampler should ideally return indices. But, in the sampler, we'd like 113 | # to have information about num_frames, height and width. Since this is not stored 114 | # as metadata, we need to read the video to get this information. You could read this 115 | # information without loading the full video in memory, but we do it anyway. In order 116 | # to not load the video twice (once to get the metadata, and once to return the loaded video 117 | # based on sampled indices), we cache it in the BucketSampler.
When the sampler is 118 | # to yield, we yield the cache data instead of indices. So, this special check ensures 119 | # that data is not loaded a second time. PRs are welcome for improvements. 120 | return index 121 | 122 | if self.load_tensors: # False 123 | image_latents, video_latents, prompt_embeds = self._preprocess_video(self.video_paths[index]) 124 | 125 | # This is hardcoded for now. 126 | # The VAE's temporal compression ratio is 4. 127 | # The VAE's spatial compression ratio is 8. 128 | latent_num_frames = video_latents.size(1) 129 | if latent_num_frames % 2 == 0: 130 | num_frames = latent_num_frames * 4 131 | else: 132 | num_frames = (latent_num_frames - 1) * 4 + 1 133 | 134 | height = video_latents.size(2) * 8 135 | width = video_latents.size(3) * 8 136 | 137 | return { 138 | "prompt": prompt_embeds, 139 | "image": image_latents, 140 | "video": video_latents, 141 | "video_metadata": { 142 | "num_frames": num_frames, 143 | "height": height, 144 | "width": width, 145 | }, 146 | } 147 | else: 148 | image, video, control_video, _ = self._preprocess_video(self.video_paths[index]) 149 | prompt = self.prompts[index].split("#####")[1] 150 | 151 | return { 152 | "prompt": self.id_token + prompt, 153 | "image": image, 154 | "video": video, 155 | "control_video": control_video, 156 | "video_metadata": { 157 | "num_frames": video.shape[0], 158 | "height": video.shape[2], 159 | "width": video.shape[3], 160 | }, 161 | } 162 | 163 | def _load_dataset_from_local_path(self) -> Tuple[List[str], List[str]]: 164 | if not self.data_root.exists(): 165 | raise ValueError("Root folder for videos does not exist") 166 | 167 | prompt_path = self.data_root.joinpath(self.caption_column) 168 | video_path = self.data_root.joinpath(self.video_column) 169 | 170 | if not prompt_path.exists() or not prompt_path.is_file(): 171 | raise ValueError( 172 | "Expected `--caption_column` to be path to a file in `--data_root` containing line-separated text prompts." 
) 174 | if not video_path.exists() or not video_path.is_file(): 175 | raise ValueError( 176 | "Expected `--video_column` to be path to a file in `--data_root` containing line-separated paths to video data in the same directory." 177 | ) 178 | 179 | with open(prompt_path, "r", encoding="utf-8") as file: 180 | prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0] 181 | with open(video_path, "r", encoding="utf-8") as file: 182 | video_paths = [self.data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0] 183 | 184 | if not self.load_tensors and any(not path.is_file() for path in video_paths): 185 | raise ValueError( 186 | f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found at least one path that is not a valid file." 187 | ) 188 | 189 | return prompts, video_paths 190 | 191 | def _load_dataset_from_csv(self) -> Tuple[List[str], List[str]]: 192 | df = pd.read_csv(self.dataset_file) 193 | prompts = df[self.caption_column].tolist() 194 | video_paths = df[self.video_column].tolist() 195 | video_paths = [self.data_root.joinpath(line.strip()) for line in video_paths] 196 | 197 | if any(not path.is_file() for path in video_paths): 198 | raise ValueError( 199 | f"Expected `{self.video_column=}` to be a column in `{self.dataset_file=}` containing video paths relative to `{self.data_root=}`, but found at least one path that is not a valid file." 200 | ) 201 | 202 | return prompts, video_paths 203 | 204 | def _preprocess_video(self, path: Path) -> Tuple[Optional[torch.Tensor], ...]: 205 | r""" 206 | Loads a single video, or latent and prompt embedding, based on initialization parameters. 207 | 208 | If returning a video, returns the first frame as an image tensor (or None when `image_to_video` is False), a [F, C, H, W] video tensor, and None placeholders for the outputs that are not loaded. Here, 209 | F, C, H and W are the frames, channels, height and width of the input video.
210 | 211 | If returning latent/embedding, returns a [F, C, H, W] latent, and the prompt embedding of shape [S, D]. 212 | F, C, H and W are the frames, channels, height and width of the latent, and S, D are the sequence length 213 | and embedding dimension of prompt embeddings. 214 | """ 215 | if self.load_tensors: 216 | return self._load_preprocessed_latents_and_embeds(path) 217 | else: 218 | video_reader = decord.VideoReader(uri=path.as_posix()) 219 | video_num_frames = len(video_reader) 220 | 221 | indices = list(range(0, video_num_frames, max(1, video_num_frames // self.max_num_frames))) # keep the step >= 1 for clips shorter than max_num_frames 222 | frames = video_reader.get_batch(indices) 223 | frames = frames[: self.max_num_frames].float() 224 | frames = frames.permute(0, 3, 1, 2).contiguous() 225 | frames = torch.stack([self.video_transforms(frame) for frame in frames], dim=0) 226 | 227 | image = frames[:1].clone() if self.image_to_video else None 228 | 229 | return image, frames, None, None # (image, video, control_video, prompt_embeds); the base class loads no control video 230 | 231 | def _load_preprocessed_latents_and_embeds(self, path: Path) -> Tuple[Optional[torch.Tensor], torch.Tensor, torch.Tensor]: 232 | filename_without_ext = path.name.split(".")[0] 233 | pt_filename = f"{filename_without_ext}.pt" 234 | 235 | image_latents_path = path.parent.parent.joinpath("image_latents") 236 | video_latents_path = path.parent.parent.joinpath("video_latents") 237 | embeds_path = path.parent.parent.joinpath("prompt_embeds") 238 | 239 | if ( 240 | not video_latents_path.exists() 241 | or not embeds_path.exists() 242 | or (self.image_to_video and not image_latents_path.exists()) 243 | ): 244 | raise ValueError( 245 | f"When setting the load_tensors parameter to `True`, it is expected that the `{self.data_root=}` contains two folders named `video_latents` and `prompt_embeds`. However, these folders were not found. Please make sure to have prepared your data correctly using `prepare_dataset.py`. Additionally, if you're training image-to-video, it is expected that an `image_latents` folder is also present."
) 246 | 247 | 248 | if self.image_to_video: 249 | image_latent_filepath = image_latents_path.joinpath(pt_filename) 250 | video_latent_filepath = video_latents_path.joinpath(pt_filename) 251 | embeds_filepath = embeds_path.joinpath(pt_filename) 252 | 253 | if not video_latent_filepath.is_file() or not embeds_filepath.is_file(): 254 | if self.image_to_video: 255 | image_latent_filepath = image_latent_filepath.as_posix() 256 | video_latent_filepath = video_latent_filepath.as_posix() 257 | embeds_filepath = embeds_filepath.as_posix() 258 | raise ValueError( 259 | f"The file {video_latent_filepath=} or {embeds_filepath=} could not be found. Please ensure that you've correctly executed `prepare_dataset.py`." 260 | ) 261 | 262 | images = ( 263 | torch.load(image_latent_filepath, map_location="cpu", weights_only=True) if self.image_to_video else None 264 | ) 265 | latents = torch.load(video_latent_filepath, map_location="cpu", weights_only=True) 266 | embeds = torch.load(embeds_filepath, map_location="cpu", weights_only=True) 267 | 268 | return images, latents, embeds 269 | 270 | 271 | class VideoDatasetWithResizing(VideoDataset): 272 | def __init__(self, *args, **kwargs) -> None: 273 | super().__init__(*args, **kwargs) 274 | 275 | def get_entity_with_mask_short_time(self, video, masked_video): 276 | video = video.detach().numpy() 277 | masked_video = masked_video.detach().numpy() 278 | 279 | threshold = 200 # mask pixels brighter than this are treated as foreground 280 | output_frames = [] 281 | 282 | if video.shape[1] != masked_video.shape[1] or video.shape[2] != masked_video.shape[2]: 283 | return torch.from_numpy(video) # mask resolution differs: fall back to the unmasked frames, returned as a tensor like the normal path 284 | 285 | frame_length, frame_height, frame_width, _ = video.shape 286 | 287 | for i in range(frame_length): 288 | video_frame = video[i] 289 | mask_frame = masked_video[i] 290 | mask_frame = mask_frame[:, :, 0] 291 | 292 | mask_frame = 
mask_frame.astype(np.uint8) 293 | binary_mask = (mask_frame > threshold).astype(np.uint8) 294 | binary_mask = np.expand_dims(binary_mask, axis=-1) 295 | 296 | masked_frame = video_frame *
binary_mask 297 | output_frames.append(masked_frame) 298 | 299 | return torch.tensor(np.array(output_frames)) 300 | 301 | def _preprocess_video(self, path: Path) -> Tuple[Optional[torch.Tensor], ...]: 302 | if self.load_tensors: 303 | return self._load_preprocessed_latents_and_embeds(path) 304 | else: 305 | video_reader = decord.VideoReader(uri=path.as_posix()) 306 | 307 | use_image_mask = True 308 | if use_image_mask: 309 | try: 310 | mask_reader = decord.VideoReader(uri=path.as_posix().replace("videos", "mask_video")) 311 | except Exception: # no mask video available for this clip 312 | mask_reader = None 313 | 314 | if self.enable_control_pose: 315 | pose_reader = decord.VideoReader(uri=path.as_posix().replace("videos", "poses")) 316 | else: 317 | pose_reader = None 318 | 319 | video_num_frames = len(video_reader) 320 | nearest_frame_bucket = min( 321 | self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)) 322 | ) 323 | 324 | frames_load_interval = 1 325 | frame_indices = [] 326 | frame_index = random.randint(1, max(1, min(30, video_num_frames - nearest_frame_bucket))) # random start in [1, 30], capped (when the clip is long enough) so reading does not run past the last frame 327 | while len(frame_indices) < nearest_frame_bucket: 328 | frame_indices.append(frame_index) 329 | frame_index += frames_load_interval 330 | 331 | # handle video frames.
332 | frames = video_reader.get_batch(frame_indices) 333 | if use_image_mask and mask_reader is not None: 334 | masked_frames = mask_reader.get_batch(frame_indices) 335 | masked_frames = self.get_entity_with_mask_short_time(frames, masked_frames) # zero out pixels outside the entity mask (FIXME) 336 | 337 | masked_frames = masked_frames[:nearest_frame_bucket].float() 338 | masked_frames = masked_frames.permute(0, 3, 1, 2).contiguous() 339 | 340 | frames = frames[:nearest_frame_bucket].float() 341 | frames = frames.permute(0, 3, 1, 2).contiguous() 342 | 343 | nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3]) 344 | frames_resized = torch.stack([resize(frame, nearest_res) for frame in frames], dim=0) 345 | frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0) 346 | 347 | if use_image_mask and mask_reader is not None: 348 | masked_frames_resized = torch.stack([resize(mask_frame, nearest_res) for mask_frame in masked_frames], dim=0) 349 | masked_frames = torch.stack([self.video_transforms(mask_frame) for mask_frame in masked_frames_resized], dim=0) 350 | if pose_reader is None: raise ValueError("VideoDatasetWithResizing requires pose videos; set `enable_control_pose=True`.") 351 | pose_frames = pose_reader.get_batch(frame_indices) 352 | pose_frames = pose_frames[:nearest_frame_bucket].float() 353 | pose_frames = pose_frames.permute(0, 3, 1, 2).contiguous() 354 | 355 | pose_frames_resized = torch.stack([resize(frame, nearest_res) for frame in pose_frames], dim=0) 356 | pose_frames = torch.stack([self.video_transforms(frame) for frame in pose_frames_resized], dim=0) 357 | 358 | if use_image_mask and mask_reader is not None: 359 | image = masked_frames[:1].clone() if self.image_to_video else None 360 | else: 361 | image = frames[:1].clone() if self.image_to_video else None 362 | 363 | return image, frames, pose_frames, None 364 | 365 | def _find_nearest_resolution(self, height, width): 366 | nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width)) 367 | 
return nearest_res[1], nearest_res[2] 368 | 369 | 370 | class
VideoDatasetWithResizeAndRectangleCrop(VideoDataset): 371 | def __init__(self, video_reshape_mode: str = "center", *args, **kwargs) -> None: 372 | super().__init__(*args, **kwargs) 373 | self.video_reshape_mode = video_reshape_mode 374 | 375 | def _resize_for_rectangle_crop(self, arr, image_size): 376 | reshape_mode = self.video_reshape_mode 377 | if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]: 378 | arr = resize( 379 | arr, 380 | size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])], 381 | interpolation=InterpolationMode.BICUBIC, 382 | ) 383 | else: 384 | arr = resize( 385 | arr, 386 | size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]], 387 | interpolation=InterpolationMode.BICUBIC, 388 | ) 389 | 390 | h, w = arr.shape[2], arr.shape[3] 391 | arr = arr.squeeze(0) 392 | 393 | delta_h = h - image_size[0] 394 | delta_w = w - image_size[1] 395 | 396 | if reshape_mode == "random" or reshape_mode == "none": 397 | top = np.random.randint(0, delta_h + 1) 398 | left = np.random.randint(0, delta_w + 1) 399 | elif reshape_mode == "center": 400 | top, left = delta_h // 2, delta_w // 2 401 | else: 402 | raise NotImplementedError 403 | arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1]) 404 | return arr 405 | 406 | def _preprocess_video(self, path: Path) -> torch.Tensor: 407 | if self.load_tensors: 408 | return self._load_preprocessed_latents_and_embeds(path) 409 | else: 410 | video_reader = decord.VideoReader(uri=path.as_posix()) 411 | video_num_frames = len(video_reader) 412 | nearest_frame_bucket = min( 413 | self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)) 414 | ) 415 | 416 | frames_load_interval = 1 417 | frame_indices = [] 418 | frame_index = 0 419 | while len(frame_indices) < nearest_frame_bucket: 420 | frame_indices.append(frame_index) 421 | frame_index += frames_load_interval 422 | 423 | frames = video_reader.get_batch(frame_indices) 424 | 
frames = frames[:nearest_frame_bucket].float() 425 | frames = frames.permute(0, 3, 1, 2).contiguous() 426 | 427 | nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3]) 428 | frames_resized = self._resize_for_rectangle_crop(frames, nearest_res) 429 | frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0) 430 | 431 | image = frames[:1].clone() if self.image_to_video else None 432 | 433 | return image, frames, None, None # (image, video, control_video, prompt_embeds); this class loads no control video 434 | 435 | def _find_nearest_resolution(self, height, width): 436 | nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width)) 437 | return nearest_res[1], nearest_res[2] 438 | 439 | 440 | class BucketSampler(Sampler): 441 | r""" 442 | PyTorch Sampler that groups 3D data by height, width and frames. 443 | 444 | Args: 445 | data_source (`VideoDataset`): 446 | A PyTorch dataset object that is an instance of `VideoDataset`. 447 | batch_size (`int`, defaults to `8`): 448 | The batch size to use for training. 449 | shuffle (`bool`, defaults to `True`): 450 | Whether or not to shuffle the data in each batch before dispatching it to the dataloader. 451 | drop_last (`bool`, defaults to `False`): 452 | Whether or not to drop incomplete buckets of data after completely iterating over all data 453 | in the dataset. If set to True, only batches that have `batch_size` number of entries will 454 | be yielded. If set to False, it is guaranteed that all data in the dataset will be processed 455 | and batches that do not have `batch_size` number of entries will also be yielded.
456 | """ 457 | 458 | def __init__( 459 | self, data_source: VideoDataset, batch_size: int = 8, shuffle: bool = True, drop_last: bool = False 460 | ) -> None: 461 | self.data_source = data_source 462 | self.batch_size = batch_size 463 | self.shuffle = shuffle 464 | self.drop_last = drop_last 465 | 466 | self.buckets = {resolution: [] for resolution in data_source.resolutions} 467 | 468 | self._raised_warning_for_drop_last = False 469 | 470 | def __len__(self): 471 | if self.drop_last and not self._raised_warning_for_drop_last: 472 | self._raised_warning_for_drop_last = True 473 | logger.warning( 474 | "Calculating the length for bucket sampler is not possible when `drop_last` is set to True. This may cause problems when setting the number of epochs used for training." 475 | ) 476 | return (len(self.data_source) + self.batch_size - 1) // self.batch_size 477 | 478 | def __iter__(self): 479 | for index, data in enumerate(self.data_source): 480 | video_metadata = data["video_metadata"] 481 | f, h, w = video_metadata["num_frames"], video_metadata["height"], video_metadata["width"] 482 | 483 | self.buckets[(f, h, w)].append(data) 484 | if len(self.buckets[(f, h, w)]) == self.batch_size: 485 | if self.shuffle: # self.drop_last=False; self.shuffle=True; 486 | random.shuffle(self.buckets[(f, h, w)]) 487 | yield self.buckets[(f, h, w)] 488 | del self.buckets[(f, h, w)] 489 | self.buckets[(f, h, w)] = [] 490 | 491 | if self.drop_last: 492 | return 493 | 494 | for fhw, bucket in list(self.buckets.items()): 495 | if len(bucket) == 0: 496 | continue 497 | if self.shuffle: 498 | random.shuffle(bucket) 499 | yield bucket 500 | del self.buckets[fhw] 501 | self.buckets[fhw] = [] -------------------------------------------------------------------------------- /utils/freeinit_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.fft as fft 3 | import math 4 | 5 | def freq_mix_3d(x, noise, LPF, fusion='low'): 6 | 
""" 7 | Noise reinitialization. 8 | 9 | Args: 10 | x: diffused latent 11 | noise: randomly sampled noise 12 | LPF: low pass filter 13 | """ 14 | # FFT 15 | x_freq = fft.fftn(x, dim=(-3, -2, -1)) 16 | x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1)) 17 | noise_freq = fft.fftn(noise, dim=(-3, -2, -1)) 18 | noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1)) 19 | 20 | 21 | if fusion == 'low': 22 | # frequency mix 23 | LPF = LPF.to(x_freq.device) 24 | HPF = 1 - LPF 25 | x_freq_low = x_freq * LPF 26 | noise_freq_high = noise_freq * HPF 27 | x_freq_mixed = x_freq_low + noise_freq_high # mix in freq domain 28 | elif fusion == 'high': 29 | # frequency mix 30 | LPF = LPF.to(x_freq.device) 31 | HPF = 1 - LPF 32 | x_freq_high = x_freq * HPF 33 | noise_freq_low = noise_freq * LPF 34 | x_freq_mixed = x_freq_high + noise_freq_low # mix in freq domain 35 | 36 | 37 | # IFFT 38 | x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1)) 39 | x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real 40 | 41 | return x_mixed 42 | 43 | 44 | def get_freq_filter(shape, device, filter_type, n, d_s, d_t): 45 | """ 46 | Form the frequency filter for noise reinitialization. 
47 | 48 | Args: 49 | shape: shape of latent (B, C, T, H, W) 50 | filter_type: type of the freq filter 51 | n: (only for butterworth) order of the filter, larger n ~ ideal, smaller n ~ gaussian 52 | d_s: normalized stop frequency for spatial dimensions (0.0-1.0) 53 | d_t: normalized stop frequency for temporal dimension (0.0-1.0) 54 | """ 55 | if filter_type == "gaussian": 56 | return gaussian_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device) 57 | elif filter_type == "ideal": 58 | return ideal_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device) 59 | elif filter_type == "box": 60 | return box_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device) 61 | elif filter_type == "butterworth": 62 | return butterworth_low_pass_filter(shape=shape, n=n, d_s=d_s, d_t=d_t).to(device) 63 | else: 64 | raise NotImplementedError 65 | 66 | def gaussian_low_pass_filter(shape, d_s=0.25, d_t=0.25): 67 | """ 68 | Compute the gaussian low pass filter mask. 69 | 70 | Args: 71 | shape: shape of the filter (volume) 72 | d_s: normalized stop frequency for spatial dimensions (0.0-1.0) 73 | d_t: normalized stop frequency for temporal dimension (0.0-1.0) 74 | """ 75 | T, H, W = shape[-3], shape[-2], shape[-1] 76 | mask = torch.zeros(shape) 77 | if d_s==0 or d_t==0: 78 | return mask 79 | for t in range(T): 80 | for h in range(H): 81 | for w in range(W): 82 | d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2) 83 | mask[..., t,h,w] = math.exp(-1/(2*d_s**2) * d_square) 84 | return mask 85 | 86 | 87 | def butterworth_low_pass_filter(shape, n=4, d_s=0.25, d_t=0.25): 88 | """ 89 | Compute the butterworth low pass filter mask. 
90 | 91 | Args: 92 | shape: shape of the filter (volume) 93 | n: order of the filter, larger n ~ ideal, smaller n ~ gaussian 94 | d_s: normalized stop frequency for spatial dimensions (0.0-1.0) 95 | d_t: normalized stop frequency for temporal dimension (0.0-1.0) 96 | """ 97 | T, H, W = shape[-3], shape[-2], shape[-1] 98 | mask = torch.zeros(shape) 99 | if d_s==0 or d_t==0: 100 | return mask 101 | for t in range(T): 102 | for h in range(H): 103 | for w in range(W): 104 | d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2) 105 | mask[..., t,h,w] = 1 / (1 + (d_square / d_s**2)**n) 106 | return mask 107 | 108 | 109 | def ideal_low_pass_filter(shape, d_s=0.25, d_t=0.25): 110 | """ 111 | Compute the ideal low pass filter mask. 112 | 113 | Args: 114 | shape: shape of the filter (volume) 115 | d_s: normalized stop frequency for spatial dimensions (0.0-1.0) 116 | d_t: normalized stop frequency for temporal dimension (0.0-1.0) 117 | """ 118 | T, H, W = shape[-3], shape[-2], shape[-1] 119 | mask = torch.zeros(shape) 120 | if d_s==0 or d_t==0: 121 | return mask 122 | for t in range(T): 123 | for h in range(H): 124 | for w in range(W): 125 | d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2) 126 | mask[..., t,h,w] = 1 if d_square <= d_s**2 else 0 # d_square is a squared distance, so compare it with d_s squared 127 | return mask 128 | 129 | 130 | def box_low_pass_filter(shape, d_s=0.25, d_t=0.25): 131 | """ 132 | Compute the ideal low pass filter mask (approximated version).
133 | 134 | Args: 135 | shape: shape of the filter (volume) 136 | d_s: normalized stop frequency for spatial dimensions (0.0-1.0) 137 | d_t: normalized stop frequency for temporal dimension (0.0-1.0) 138 | """ 139 | T, H, W = shape[-3], shape[-2], shape[-1] 140 | mask = torch.zeros(shape) 141 | if d_s==0 or d_t==0: 142 | return mask 143 | 144 | threshold_s = round(int(H // 2) * d_s) 145 | threshold_t = round(T // 2 * d_t) 146 | 147 | cframe, crow, ccol = T // 2, H // 2, W //2 148 | mask[..., cframe - threshold_t:cframe + threshold_t, crow - threshold_s:crow + threshold_s, ccol - threshold_s:ccol + threshold_s] = 1.0 149 | 150 | return mask -------------------------------------------------------------------------------- /utils/load_validation_control.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | import torch 3 | 4 | import decord 5 | 6 | import numpy as np 7 | 8 | def identity_transform(x): 9 | return x 10 | 11 | def scale_transform(x): 12 | return x / 255.0 13 | 14 | def load_control_video(video_path, height, width): 15 | control_video_height = height 16 | control_video_width = width 17 | video_transforms = transforms.Compose( 18 | [ 19 | transforms.Lambda(identity_transform), 20 | transforms.Lambda(scale_transform), 21 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 22 | transforms.Resize((control_video_height, control_video_width)) 23 | ] 24 | ) 25 | validation_control_video_reader = decord.VideoReader(uri=video_path) 26 | frame_indices = [i for i in range(0, 36)] 27 | 28 | validation_control_video = validation_control_video_reader.get_batch(frame_indices).float() 29 | 30 | validation_control_video = validation_control_video.permute(0, 3, 1, 2).contiguous() 31 | validation_control_video = torch.cat((validation_control_video, validation_control_video[-1].unsqueeze(0)), dim=0) 32 | validation_control_video = torch.stack([video_transforms(frame) for 
frame in validation_control_video], dim=0) 33 | return validation_control_video 34 | 35 | def load_control_video_inference(video_path, video_height, video_width, max_frames=None): 36 | video_transforms = transforms.Compose( 37 | [ 38 | transforms.Lambda(identity_transform), 39 | transforms.Lambda(scale_transform), 40 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 41 | transforms.Resize((1360, 768)) 42 | ] 43 | ) 44 | validation_control_video_reader = decord.VideoReader(uri=video_path) 45 | if max_frames: 46 | frame_indices = [i for i in range(0, max_frames)] 47 | else: 48 | frame_indices = [i for i in range(len(validation_control_video_reader))] 49 | 50 | validation_control_video = validation_control_video_reader.get_batch(frame_indices).asnumpy() 51 | validation_control_video = torch.from_numpy(validation_control_video).float() 52 | 53 | validation_control_video = validation_control_video.permute(0, 3, 1, 2).contiguous() 54 | validation_control_video = torch.stack([video_transforms(frame) for frame in validation_control_video], dim=0) 55 | return validation_control_video 56 | 57 | def load_contorl_video_from_Image(validation_control_images, video_height, video_width): 58 | video_transforms = transforms.Compose( 59 | [ 60 | transforms.Lambda(identity_transform), 61 | transforms.Lambda(scale_transform), 62 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 63 | transforms.Resize((video_height, video_width)) 64 | ] 65 | ) 66 | validation_control_video_tensor = [] 67 | 68 | for image in validation_control_images: 69 | validation_control_video_tensor.append(torch.from_numpy(np.array(image)).float()) 70 | 71 | validation_control_video = torch.stack(validation_control_video_tensor, dim=0) 72 | validation_control_video = validation_control_video.permute(0, 3, 1, 2).contiguous() 73 | validation_control_video = torch.stack([video_transforms(frame) for frame in validation_control_video], dim=0) 74 | return 
validation_control_video 75 | 76 | if __name__=="__main__": 77 | validation_control_video_path = "/home/user/video.mp4" 78 | validation_control_video = load_control_video(validation_control_video_path, height=1360, width=768) -------------------------------------------------------------------------------- /utils/pre_process.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import logging 4 | import math 5 | from omegaconf import OmegaConf 6 | from datetime import datetime 7 | from pathlib import Path 8 | from PIL import Image 9 | import numpy as np 10 | import torch.jit 11 | from torchvision.datasets.folder import pil_loader 12 | from torchvision.transforms.functional import pil_to_tensor, resize, center_crop 13 | from torchvision.transforms.functional import to_pil_image 14 | from dwpose.preprocess import get_image_pose, get_video_pose 15 | 16 | ASPECT_RATIO = 9 / 16 17 | 18 | def preprocess(video_path, image_path, width=576, height=1024, sample_stride=2, max_frame_num=None): 19 | """Resize and center-crop the reference image, then extract the poses of the reference image and the driving video. 20 | 21 | Args: 22 | video_path (str): path to the driving video whose poses are extracted 23 | image_path (str): reference image path 24 | width (int, optional): target width, defaults to 576; height (int, optional): target height, defaults to 1024. 25 | sample_stride (int, optional): Defaults to 2.
26 | """ 27 | image_pixels = pil_loader(image_path) 28 | image_pixels = pil_to_tensor(image_pixels) # (c, h, w) 29 | h, w = image_pixels.shape[-2:] 30 | 31 | w_target, h_target = width, height 32 | h_w_ratio = float(h) / float(w) 33 | if h_w_ratio < h_target / w_target: 34 | h_resize, w_resize = h_target, math.ceil(h_target / h_w_ratio) 35 | else: 36 | h_resize, w_resize = math.ceil(w_target * h_w_ratio), w_target 37 | image_pixels = resize(image_pixels, [h_resize, w_resize], antialias=None) 38 | image_pixels = center_crop(image_pixels, [h_target, w_target]) 39 | image_pixels = image_pixels.permute((1, 2, 0)).numpy() 40 | image_pose = get_image_pose(image_pixels) 41 | video_pose = get_video_pose(video_path, image_pixels, sample_stride=sample_stride, max_frame_num=max_frame_num) 42 | pose_pixels = np.concatenate([np.expand_dims(image_pose, 0), video_pose]) 43 | image_pixels = Image.fromarray(image_pixels) 44 | pose_pixels = [Image.fromarray(p.transpose((1,2,0))) for p in pose_pixels] 45 | return pose_pixels, image_pixels 46 | 47 | -------------------------------------------------------------------------------- /utils/prepare_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import functools 5 | import json 6 | import os 7 | import pathlib 8 | import queue 9 | import traceback 10 | import uuid 11 | from concurrent.futures import ThreadPoolExecutor 12 | from typing import Any, Dict, List, Optional, Union 13 | 14 | import torch 15 | import torch.distributed as dist 16 | from diffusers import AutoencoderKLCogVideoX 17 | from diffusers.training_utils import set_seed 18 | from diffusers.utils import export_to_video, get_logger 19 | from torch.utils.data import DataLoader 20 | from torchvision import transforms 21 | from tqdm import tqdm 22 | from transformers import T5EncoderModel, T5Tokenizer 23 | 24 | 25 | import decord # isort:skip 26 | 27 | from dataset import 
BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip 28 | 29 | 30 | decord.bridge.set_bridge("torch") 31 | 32 | logger = get_logger(__name__) 33 | 34 | DTYPE_MAPPING = { 35 | "fp32": torch.float32, 36 | "fp16": torch.float16, 37 | "bf16": torch.bfloat16, 38 | } 39 | 40 | 41 | def check_height(x: Any) -> int: 42 | x = int(x) 43 | if x % 16 != 0: 44 | raise argparse.ArgumentTypeError( 45 | f"`--height_buckets` must be divisible by 16, but got {x} which does not meet this criterion." 46 | ) 47 | return x 48 | 49 | 50 | def check_width(x: Any) -> int: 51 | x = int(x) 52 | if x % 16 != 0: 53 | raise argparse.ArgumentTypeError( 54 | f"`--width_buckets` must be divisible by 16, but got {x} which does not meet this criterion." 55 | ) 56 | return x 57 | 58 | 59 | def check_frames(x: Any) -> int: 60 | x = int(x) 61 | if x % 4 != 0 and x % 4 != 1: 62 | raise argparse.ArgumentTypeError( 63 | f"`--frame_buckets` must be of form `4 * k` or `4 * k + 1`, but got {x} which does not meet this criterion." 64 | ) 65 | return x 66 | 67 | 68 | def get_args() -> argparse.Namespace: 69 | parser = argparse.ArgumentParser() 70 | parser.add_argument( 71 | "--model_id", 72 | type=str, 73 | default="THUDM/CogVideoX-2b", 74 | help="Hugging Face model ID to use for tokenizer, text encoder and VAE.", 75 | ) 76 | parser.add_argument("--data_root", type=str, required=True, help="Path to where training data is located.") 77 | parser.add_argument( 78 | "--dataset_file", type=str, default=None, help="Path to CSV file containing metadata about training data." 79 | ) 80 | parser.add_argument( 81 | "--caption_column", 82 | type=str, 83 | default="caption", 84 | help="If using a CSV file via the `--dataset_file` argument, this should be the name of the column containing the captions.
If using the folder structure format for data loading, this should be the name of the file containing line-separated captions (the file should be located in `--data_root`).", 85 | ) 86 | parser.add_argument( 87 | "--video_column", 88 | type=str, 89 | default="video", 90 | help="If using a CSV file via the `--dataset_file` argument, this should be the name of the column containing the video paths. If using the folder structure format for data loading, this should be the name of the file containing line-separated video paths (the file should be located in `--data_root`).", 91 | ) 92 | parser.add_argument( 93 | "--id_token", 94 | type=str, 95 | default=None, 96 | help="Identifier token prepended to each prompt, if provided.", 97 | ) 98 | parser.add_argument( 99 | "--height_buckets", 100 | nargs="+", 101 | type=check_height, 102 | default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536], 103 | ) 104 | parser.add_argument( 105 | "--width_buckets", 106 | nargs="+", 107 | type=check_width, 108 | default=[256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536], 109 | ) 110 | parser.add_argument( 111 | "--frame_buckets", 112 | nargs="+", 113 | type=check_frames, 114 | default=[49], 115 | ) 116 | parser.add_argument( 117 | "--random_flip", 118 | type=float, 119 | default=None, 120 | help="If random horizontal flip augmentation is to be used, this should be the flip probability.", 121 | ) 122 | parser.add_argument( 123 | "--dataloader_num_workers", 124 | type=int, 125 | default=0, 126 | help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", 127 | ) 128 | parser.add_argument( 129 | "--pin_memory", 130 | action="store_true", 131 | help="Whether or not to use pinned memory in the PyTorch dataloader.", 132 | ) 133 | parser.add_argument( 134 | "--video_reshape_mode", 135 | type=str, 136 | default=None, 137 | help="All input videos are reshaped to this mode.
Choose between ['center', 'random', 'none']", 138 | ) 139 | parser.add_argument( 140 | "--save_image_latents", 141 | action="store_true", 142 | help="Whether or not to encode and store image latents, which are required for image-to-video finetuning. The image latents are the first frame of input videos encoded with the VAE.", 143 | ) 144 | parser.add_argument( 145 | "--output_dir", 146 | type=str, 147 | required=True, 148 | help="Path to output directory where preprocessed videos/latents/embeddings will be saved.", 149 | ) 150 | parser.add_argument("--max_num_frames", type=int, default=49, help="Maximum number of frames in output video.") 151 | parser.add_argument( 152 | "--max_sequence_length", type=int, default=226, help="Max sequence length of prompt embeddings." 153 | ) 154 | parser.add_argument("--target_fps", type=int, default=8, help="Frame rate of output videos.") 155 | parser.add_argument( 156 | "--save_latents_and_embeddings", 157 | action="store_true", 158 | help="Whether to encode videos/captions to latents/embeddings and save them in pytorch serializable format.", 159 | ) 160 | parser.add_argument( 161 | "--use_slicing", 162 | action="store_true", 163 | help="Whether to enable sliced encoding/decoding in the VAE. Only used if `--save_latents_and_embeddings` is also used.", 164 | ) 165 | parser.add_argument( 166 | "--use_tiling", 167 | action="store_true", 168 | help="Whether to enable tiled encoding/decoding in the VAE. Only used if `--save_latents_and_embeddings` is also used.", 169 | ) 170 | parser.add_argument("--batch_size", type=int, default=1, help="Number of videos to process at once in the VAE.") 171 | parser.add_argument( 172 | "--num_decode_threads", 173 | type=int, 174 | default=0, 175 | help="Number of decoding threads for `decord` to use. 
The default `0` means to automatically determine required number of threads.", 176 | ) 177 | parser.add_argument( 178 | "--dtype", 179 | type=str, 180 | choices=["fp32", "fp16", "bf16"], 181 | default="fp32", 182 | help="Data type to use when generating latents and prompt embeddings.", 183 | ) 184 | parser.add_argument("--seed", type=int, default=42, help="Seed for reproducibility.") 185 | parser.add_argument( 186 | "--num_artifact_workers", type=int, default=4, help="Number of worker threads for serializing artifacts." 187 | ) 188 | return parser.parse_args() 189 | 190 | 191 | def _get_t5_prompt_embeds( 192 | tokenizer: T5Tokenizer, 193 | text_encoder: T5EncoderModel, 194 | prompt: Union[str, List[str]], 195 | num_videos_per_prompt: int = 1, 196 | max_sequence_length: int = 226, 197 | device: Optional[torch.device] = None, 198 | dtype: Optional[torch.dtype] = None, 199 | text_input_ids=None, 200 | ): 201 | prompt = [prompt] if isinstance(prompt, str) else prompt 202 | batch_size = len(prompt) 203 | 204 | if tokenizer is not None: 205 | text_inputs = tokenizer( 206 | prompt, 207 | padding="max_length", 208 | max_length=max_sequence_length, 209 | truncation=True, 210 | add_special_tokens=True, 211 | return_tensors="pt", 212 | ) 213 | text_input_ids = text_inputs.input_ids 214 | else: 215 | if text_input_ids is None: 216 | raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.") 217 | 218 | prompt_embeds = text_encoder(text_input_ids.to(device))[0] 219 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 220 | 221 | # duplicate text embeddings for each generation per prompt, using mps friendly method 222 | _, seq_len, _ = prompt_embeds.shape 223 | prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) 224 | prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) 225 | 226 | return prompt_embeds 227 | 228 | 229 | def encode_prompt( 230 | tokenizer: T5Tokenizer, 231 | text_encoder: 
T5EncoderModel, 232 | prompt: Union[str, List[str]], 233 | num_videos_per_prompt: int = 1, 234 | max_sequence_length: int = 226, 235 | device: Optional[torch.device] = None, 236 | dtype: Optional[torch.dtype] = None, 237 | text_input_ids=None, 238 | ): 239 | prompt = [prompt] if isinstance(prompt, str) else prompt 240 | prompt_embeds = _get_t5_prompt_embeds( 241 | tokenizer, 242 | text_encoder, 243 | prompt=prompt, 244 | num_videos_per_prompt=num_videos_per_prompt, 245 | max_sequence_length=max_sequence_length, 246 | device=device, 247 | dtype=dtype, 248 | text_input_ids=text_input_ids, 249 | ) 250 | return prompt_embeds 251 | 252 | 253 | def compute_prompt_embeddings( 254 | tokenizer: T5Tokenizer, 255 | text_encoder: T5EncoderModel, 256 | prompts: List[str], 257 | max_sequence_length: int, 258 | device: torch.device, 259 | dtype: torch.dtype, 260 | requires_grad: bool = False, 261 | ): 262 | if requires_grad: 263 | prompt_embeds = encode_prompt( 264 | tokenizer, 265 | text_encoder, 266 | prompts, 267 | num_videos_per_prompt=1, 268 | max_sequence_length=max_sequence_length, 269 | device=device, 270 | dtype=dtype, 271 | ) 272 | else: 273 | with torch.no_grad(): 274 | prompt_embeds = encode_prompt( 275 | tokenizer, 276 | text_encoder, 277 | prompts, 278 | num_videos_per_prompt=1, 279 | max_sequence_length=max_sequence_length, 280 | device=device, 281 | dtype=dtype, 282 | ) 283 | return prompt_embeds 284 | 285 | 286 | to_pil_image = transforms.ToPILImage(mode="RGB") 287 | 288 | 289 | def save_image(image: torch.Tensor, path: pathlib.Path) -> None: 290 | image = image.to(dtype=torch.float32).clamp(0, 1) # callers pass values already mapped to [0, 1]; ToPILImage scales floats by 255 291 | image = to_pil_image(image) 292 | image.save(path) 293 | 294 | 295 | def save_video(video: torch.Tensor, path: pathlib.Path, fps: int = 8) -> None: 296 | video = video.to(dtype=torch.float32).clamp(0, 1) # same [0, 1] range as save_image 297 | video = [to_pil_image(frame) for frame in video] 298 | export_to_video(video, path, fps=fps) 299 | 300 | 301 | def save_prompt(prompt: str, path:
pathlib.Path) -> None: 302 | with open(path, "w", encoding="utf-8") as file: 303 | file.write(prompt) 304 | 305 | 306 | def save_metadata(metadata: Dict[str, Any], path: pathlib.Path) -> None: 307 | with open(path, "w", encoding="utf-8") as file: 308 | file.write(json.dumps(metadata)) 309 | 310 | 311 | @torch.no_grad() 312 | def serialize_artifacts( 313 | batch_size: int, 314 | fps: int, 315 | images_dir: Optional[pathlib.Path] = None, 316 | image_latents_dir: Optional[pathlib.Path] = None, 317 | videos_dir: Optional[pathlib.Path] = None, 318 | video_latents_dir: Optional[pathlib.Path] = None, 319 | prompts_dir: Optional[pathlib.Path] = None, 320 | prompt_embeds_dir: Optional[pathlib.Path] = None, 321 | images: Optional[torch.Tensor] = None, 322 | image_latents: Optional[torch.Tensor] = None, 323 | videos: Optional[torch.Tensor] = None, 324 | video_latents: Optional[torch.Tensor] = None, 325 | prompts: Optional[List[str]] = None, 326 | prompt_embeds: Optional[torch.Tensor] = None, 327 | ) -> None: 328 | num_frames, height, width = videos.size(1), videos.size(3), videos.size(4) 329 | metadata = [{"num_frames": num_frames, "height": height, "width": width} for _ in range(batch_size)] # one entry per sample, so zip() below writes a metadata file for every video 330 | 331 | data_folder_mapper_list = [ 332 | (images, images_dir, lambda img, path: save_image(img[0], path), "png"), 333 | (image_latents, image_latents_dir, torch.save, "pt"), 334 | (videos, videos_dir, functools.partial(save_video, fps=fps), "mp4"), 335 | (video_latents, video_latents_dir, torch.save, "pt"), 336 | (prompts, prompts_dir, save_prompt, "txt"), 337 | (prompt_embeds, prompt_embeds_dir, torch.save, "pt"), 338 | (metadata, videos_dir, save_metadata, "txt"), 339 | ] 340 | filenames = [uuid.uuid4() for _ in range(batch_size)] 341 | 342 | for data, folder, save_fn, extension in data_folder_mapper_list: 343 | if data is None: 344 | continue 345 | for slice, filename in zip(data, filenames): 346 | if isinstance(slice, torch.Tensor): 347 | slice = slice.clone().to("cpu") 348 | path =
folder.joinpath(f"{filename}.{extension}") 349 | save_fn(slice, path) 350 | 351 | 352 | def save_intermediates(output_queue: queue.Queue) -> None: 353 | while True: 354 | try: 355 | item = output_queue.get(timeout=30) 356 | if item is None: 357 | break 358 | serialize_artifacts(**item) 359 | 360 | except queue.Empty: 361 | continue 362 | 363 | 364 | @torch.no_grad() 365 | def main(): 366 | args = get_args() 367 | set_seed(args.seed) 368 | 369 | output_dir = pathlib.Path(args.output_dir) 370 | tmp_dir = output_dir.joinpath("tmp") 371 | 372 | output_dir.mkdir(parents=True, exist_ok=True) 373 | tmp_dir.mkdir(parents=True, exist_ok=True) 374 | 375 | # Create task queue for non-blocking serializing of artifacts 376 | output_queue = queue.Queue() 377 | save_thread = ThreadPoolExecutor(max_workers=args.num_artifact_workers) 378 | save_future = save_thread.submit(save_intermediates, output_queue) 379 | 380 | # Initialize distributed processing 381 | if "LOCAL_RANK" in os.environ: 382 | local_rank = int(os.environ["LOCAL_RANK"]) 383 | torch.cuda.set_device(local_rank) 384 | dist.init_process_group(backend="nccl") 385 | world_size = dist.get_world_size() 386 | rank = dist.get_rank() 387 | else: 388 | # Single GPU 389 | local_rank = 0 390 | world_size = 1 391 | rank = 0 392 | torch.cuda.set_device(rank) 393 | 394 | # Create folders where intermediate tensors from each rank will be saved 395 | images_dir = tmp_dir.joinpath(f"images/{rank}") 396 | image_latents_dir = tmp_dir.joinpath(f"image_latents/{rank}") 397 | videos_dir = tmp_dir.joinpath(f"videos/{rank}") 398 | video_latents_dir = tmp_dir.joinpath(f"video_latents/{rank}") 399 | prompts_dir = tmp_dir.joinpath(f"prompts/{rank}") 400 | prompt_embeds_dir = tmp_dir.joinpath(f"prompt_embeds/{rank}") 401 | 402 | images_dir.mkdir(parents=True, exist_ok=True) 403 | image_latents_dir.mkdir(parents=True, exist_ok=True) 404 | videos_dir.mkdir(parents=True, exist_ok=True) 405 | video_latents_dir.mkdir(parents=True, exist_ok=True) 406 
| prompts_dir.mkdir(parents=True, exist_ok=True) 407 | prompt_embeds_dir.mkdir(parents=True, exist_ok=True) 408 | 409 | weight_dtype = DTYPE_MAPPING[args.dtype] 410 | target_fps = args.target_fps 411 | 412 | # 1. Dataset 413 | dataset_init_kwargs = { 414 | "data_root": args.data_root, 415 | "dataset_file": args.dataset_file, 416 | "caption_column": args.caption_column, 417 | "video_column": args.video_column, 418 | "max_num_frames": args.max_num_frames, 419 | "id_token": args.id_token, 420 | "height_buckets": args.height_buckets, 421 | "width_buckets": args.width_buckets, 422 | "frame_buckets": args.frame_buckets, 423 | "load_tensors": False, 424 | "random_flip": args.random_flip, 425 | "image_to_video": args.save_image_latents, 426 | } 427 | if args.video_reshape_mode is None: 428 | dataset = VideoDatasetWithResizing(**dataset_init_kwargs) 429 | else: 430 | dataset = VideoDatasetWithResizeAndRectangleCrop( 431 | video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs 432 | ) 433 | 434 | original_dataset_size = len(dataset) 435 | 436 | # Split data among GPUs 437 | if world_size > 1: 438 | samples_per_gpu = original_dataset_size // world_size 439 | start_index = rank * samples_per_gpu 440 | end_index = start_index + samples_per_gpu 441 | if rank == world_size - 1: 442 | end_index = original_dataset_size # Make sure the last GPU gets the remaining data 443 | 444 | # Slice the data 445 | dataset.prompts = dataset.prompts[start_index:end_index] 446 | dataset.video_paths = dataset.video_paths[start_index:end_index] 447 | else: 448 | pass 449 | 450 | rank_dataset_size = len(dataset) 451 | 452 | # 2. 
Dataloader 453 | def collate_fn(data): 454 | prompts = [x["prompt"] for x in data[0]] 455 | 456 | images = None 457 | if args.save_image_latents: 458 | images = [x["image"] for x in data[0]] 459 | images = torch.stack(images).to(dtype=weight_dtype, non_blocking=True) 460 | 461 | videos = [x["video"] for x in data[0]] 462 | videos = torch.stack(videos).to(dtype=weight_dtype, non_blocking=True) 463 | 464 | control_videos = [x["control_video"] for x in data[0]] #FIXME 465 | control_videos = torch.stack(control_videos).to(dtype=weight_dtype, non_blocking=True) 466 | 467 | return { 468 | "images": images, 469 | "videos": videos, 470 | "control_videos": control_videos, 471 | "prompts": prompts, 472 | } 473 | 474 | dataloader = DataLoader( 475 | dataset, 476 | batch_size=1, 477 | sampler=BucketSampler(dataset, batch_size=args.batch_size, shuffle=True, drop_last=False), 478 | collate_fn=collate_fn, 479 | num_workers=args.dataloader_num_workers, 480 | pin_memory=args.pin_memory, 481 | ) 482 | 483 | # 3. Prepare models 484 | device = f"cuda:{rank}" 485 | 486 | if args.save_latents_and_embeddings: 487 | tokenizer = T5Tokenizer.from_pretrained(args.model_id, subfolder="tokenizer") 488 | text_encoder = T5EncoderModel.from_pretrained( 489 | args.model_id, subfolder="text_encoder", torch_dtype=weight_dtype 490 | ) 491 | text_encoder = text_encoder.to(device) 492 | 493 | vae = AutoencoderKLCogVideoX.from_pretrained(args.model_id, subfolder="vae", torch_dtype=weight_dtype) 494 | vae = vae.to(device) 495 | 496 | if args.use_slicing: 497 | vae.enable_slicing() 498 | if args.use_tiling: 499 | vae.enable_tiling() 500 | 501 | # 4. 
Compute latents and embeddings and save 502 | if rank == 0: 503 | iterator = tqdm( 504 | dataloader, desc="Encoding", total=(rank_dataset_size + args.batch_size - 1) // args.batch_size 505 | ) 506 | else: 507 | iterator = dataloader 508 | 509 | for step, batch in enumerate(iterator): 510 | try: 511 | images = None 512 | image_latents = None 513 | video_latents = None 514 | prompt_embeds = None 515 | 516 | if args.save_image_latents: 517 | images = batch["images"].to(device, non_blocking=True) 518 | images = images.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] 519 | 520 | videos = batch["videos"].to(device, non_blocking=True) 521 | videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] 522 | 523 | prompts = batch["prompts"] 524 | 525 | # Encode videos & images 526 | if args.save_latents_and_embeddings: 527 | if args.use_slicing: 528 | if args.save_image_latents: 529 | encoded_slices = [vae._encode(image_slice) for image_slice in images.split(1)] 530 | image_latents = torch.cat(encoded_slices) 531 | image_latents = image_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype) 532 | 533 | encoded_slices = [vae._encode(video_slice) for video_slice in videos.split(1)] 534 | video_latents = torch.cat(encoded_slices) 535 | 536 | else: 537 | if args.save_image_latents: 538 | image_latents = vae._encode(images) 539 | image_latents = image_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype) 540 | 541 | video_latents = vae._encode(videos) 542 | 543 | video_latents = video_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype) 544 | 545 | # Encode prompts 546 | prompt_embeds = compute_prompt_embeddings( 547 | tokenizer, 548 | text_encoder, 549 | prompts, 550 | args.max_sequence_length, 551 | device, 552 | weight_dtype, 553 | requires_grad=False, 554 | ) 555 | 556 | if images is not None: 557 | images = (images.permute(0, 2, 1, 3, 4) + 1) / 2 558 | 559 | videos = (videos.permute(0, 2, 1, 3, 4) + 1) / 2 560 | 561 | output_queue.put( 
562 | { 563 | "batch_size": len(prompts), 564 | "fps": target_fps, 565 | "images_dir": images_dir, 566 | "image_latents_dir": image_latents_dir, 567 | "videos_dir": videos_dir, 568 | "video_latents_dir": video_latents_dir, 569 | "prompts_dir": prompts_dir, 570 | "prompt_embeds_dir": prompt_embeds_dir, 571 | "images": images, 572 | "image_latents": image_latents, 573 | "videos": videos, 574 | "video_latents": video_latents, 575 | "prompts": prompts, 576 | "prompt_embeds": prompt_embeds, 577 | } 578 | ) 579 | 580 | except Exception: 581 | print("-------------------------") 582 | print(f"An exception occurred while processing data: {rank=}, {world_size=}, {step=}") 583 | traceback.print_exc() 584 | print("-------------------------") 585 | 586 | # 5. Complete distributed processing 587 | if world_size > 1: 588 | dist.barrier() 589 | dist.destroy_process_group() 590 | 591 | output_queue.put(None) 592 | save_thread.shutdown(wait=True) 593 | save_future.result() 594 | 595 | # 6. Combine results from each rank 596 | if rank == 0: 597 | print( 598 | f"Completed preprocessing latents and embeddings. 
Temporary files from all ranks saved to `{tmp_dir.as_posix()}`" 599 | ) 600 | 601 | # Move files from each rank to common directory 602 | for subfolder, extension in [ 603 | ("images", "png"), 604 | ("image_latents", "pt"), 605 | ("videos", "mp4"), 606 | ("video_latents", "pt"), 607 | ("prompts", "txt"), 608 | ("prompt_embeds", "pt"), 609 | ("videos", "txt"), 610 | ]: 611 | tmp_subfolder = tmp_dir.joinpath(subfolder) 612 | combined_subfolder = output_dir.joinpath(subfolder) 613 | combined_subfolder.mkdir(parents=True, exist_ok=True) 614 | pattern = f"*.{extension}" 615 | 616 | for file in tmp_subfolder.rglob(pattern): 617 | file.replace(combined_subfolder / file.name) 618 | 619 | # Remove temporary directories 620 | def rmdir_recursive(dir: pathlib.Path) -> None: 621 | for child in dir.iterdir(): 622 | if child.is_file(): 623 | child.unlink() 624 | else: 625 | rmdir_recursive(child) 626 | dir.rmdir() 627 | 628 | rmdir_recursive(tmp_dir) 629 | 630 | # Combine prompts and videos into individual text files and single jsonl 631 | prompts_folder = output_dir.joinpath("prompts") 632 | prompts = [] 633 | stems = [] 634 | 635 | for filename in prompts_folder.rglob("*.txt"): 636 | with open(filename, "r") as file: 637 | prompts.append(file.read().strip()) 638 | stems.append(filename.stem) 639 | 640 | prompts_txt = output_dir.joinpath("prompts.txt") 641 | videos_txt = output_dir.joinpath("videos.txt") 642 | data_jsonl = output_dir.joinpath("data.jsonl") 643 | 644 | with open(prompts_txt, "w") as file: 645 | for prompt in prompts: 646 | file.write(f"{prompt}\n") 647 | 648 | with open(videos_txt, "w") as file: 649 | for stem in stems: 650 | file.write(f"videos/{stem}.mp4\n") 651 | 652 | with open(data_jsonl, "w") as file: 653 | for prompt, stem in zip(prompts, stems): 654 | video_metadata_txt = output_dir.joinpath(f"videos/{stem}.txt") 655 | with open(video_metadata_txt, "r", encoding="utf-8") as metadata_file: 656 | metadata = json.loads(metadata_file.read()) 657 | 658 | data 
= { 659 | "prompt": prompt, 660 | "prompt_embed": f"prompt_embeds/{stem}.pt", 661 | "image": f"images/{stem}.png", 662 | "image_latent": f"image_latents/{stem}.pt", 663 | "video": f"videos/{stem}.mp4", 664 | "video_latent": f"video_latents/{stem}.pt", 665 | "metadata": metadata, 666 | } 667 | file.write(json.dumps(data) + "\n") 668 | 669 | print(f"Completed preprocessing. All files saved to `{output_dir.as_posix()}`") 670 | 671 | 672 | if __name__ == "__main__": 673 | main() 674 | -------------------------------------------------------------------------------- /utils/save_utils.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | 4 | import cv2 5 | import imageio 6 | import numpy as np 7 | import torch 8 | import torchvision 9 | from einops import rearrange 10 | from PIL import Image 11 | 12 | 13 | def get_width_and_height_from_image_and_base_resolution(image, base_resolution): 14 | target_pixels = int(base_resolution) * int(base_resolution) 15 | original_width, original_height = Image.open(image).size 16 | ratio = (target_pixels / (original_width * original_height)) ** 0.5 17 | width_slider = round(original_width * ratio) 18 | height_slider = round(original_height * ratio) 19 | return height_slider, width_slider 20 | 21 | def color_transfer(sc, dc): 22 | """ 23 | Transfer the color distribution of the reference image dc onto sc (per-channel mean/std matching in LAB space). 24 | 25 | Args: 26 | sc (numpy.ndarray): input RGB image whose colors are adjusted. 27 | dc (numpy.ndarray): reference RGB image. 28 | 29 | Returns: 30 | numpy.ndarray: sc with the color distribution of dc applied.
31 | """ 32 | 33 | def get_mean_and_std(img): 34 | x_mean, x_std = cv2.meanStdDev(img) 35 | x_mean = np.hstack(np.around(x_mean, 2)) 36 | x_std = np.hstack(np.around(x_std, 2)) 37 | return x_mean, x_std 38 | 39 | sc = cv2.cvtColor(sc, cv2.COLOR_RGB2LAB) 40 | s_mean, s_std = get_mean_and_std(sc) 41 | dc = cv2.cvtColor(dc, cv2.COLOR_RGB2LAB) 42 | t_mean, t_std = get_mean_and_std(dc) 43 | img_n = ((sc - s_mean) * (t_std / s_std)) + t_mean 44 | np.putmask(img_n, img_n > 255, 255) 45 | np.putmask(img_n, img_n < 0, 0) 46 | dst = cv2.cvtColor(cv2.convertScaleAbs(img_n), cv2.COLOR_LAB2RGB) 47 | return dst 48 | 49 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=12, imageio_backend=True, color_transfer_post_process=False): 50 | videos = rearrange(videos, "b c t h w -> t b c h w") 51 | outputs = [] 52 | for x in videos: 53 | x = torchvision.utils.make_grid(x, nrow=n_rows) 54 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 55 | if rescale: 56 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 57 | x = (x * 255).numpy().astype(np.uint8) 58 | outputs.append(Image.fromarray(x)) 59 | 60 | if color_transfer_post_process: 61 | for i in range(1, len(outputs)): 62 | outputs[i] = Image.fromarray(color_transfer(np.uint8(outputs[i]), np.uint8(outputs[0]))) 63 | 64 | os.makedirs(os.path.dirname(path) or ".", exist_ok=True) # dirname is "" for a bare filename 65 | if imageio_backend: 66 | if path.endswith("mp4"): 67 | imageio.mimsave(path, outputs, fps=fps) 68 | else: 69 | imageio.mimsave(path, outputs, duration=(1000 * 1/fps)) 70 | else: 71 | if path.endswith("mp4"): 72 | path = path.replace('.mp4', '.gif') 73 | outputs[0].save(path, format='GIF', append_images=outputs[1:], save_all=True, duration=int(1000 / fps), loop=0) # append only the remaining frames so frame 0 is not written twice 74 | 75 | def get_image_to_video_latent(validation_image_start, validation_image_end, video_length, sample_size): 76 | if validation_image_start is not None and validation_image_end is not None: 77 | if type(validation_image_start) is str and os.path.isfile(validation_image_start): 78 | image_start =
clip_image = Image.open(validation_image_start).convert("RGB") 79 | image_start = image_start.resize([sample_size[1], sample_size[0]]) 80 | clip_image = clip_image.resize([sample_size[1], sample_size[0]]) 81 | else: 82 | image_start = clip_image = validation_image_start 83 | image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start] 84 | clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image] 85 | 86 | if type(validation_image_end) is str and os.path.isfile(validation_image_end): 87 | image_end = Image.open(validation_image_end).convert("RGB") 88 | image_end = image_end.resize([sample_size[1], sample_size[0]]) 89 | else: 90 | image_end = validation_image_end 91 | image_end = [_image_end.resize([sample_size[1], sample_size[0]]) for _image_end in image_end] 92 | 93 | if type(image_start) is list: 94 | clip_image = clip_image[0] 95 | start_video = torch.cat( 96 | [torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_start in image_start], 97 | dim=2 98 | ) 99 | input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1]) 100 | input_video[:, :, :len(image_start)] = start_video 101 | 102 | input_video_mask = torch.zeros_like(input_video[:, :1]) 103 | input_video_mask[:, :, len(image_start):] = 255 104 | else: 105 | input_video = torch.tile( 106 | torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0), 107 | [1, 1, video_length, 1, 1] 108 | ) 109 | input_video_mask = torch.zeros_like(input_video[:, :1]) 110 | input_video_mask[:, :, 1:] = 255 111 | 112 | if type(image_end) is list: 113 | image_end = [_image_end.resize(image_start[0].size if type(image_start) is list else image_start.size) for _image_end in image_end] 114 | end_video = torch.cat( 115 | [torch.from_numpy(np.array(_image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_end in image_end], 116 | dim=2 117 | ) 118 | input_video[:, 
:, -len(image_end):] = end_video # end_video is [1, C, N, H, W], so len(end_video) would be the batch size 1, not N 119 | 120 | input_video_mask[:, :, -len(image_end):] = 0 121 | else: 122 | image_end = image_end.resize(image_start[0].size if type(image_start) is list else image_start.size) 123 | input_video[:, :, -1:] = torch.from_numpy(np.array(image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) 124 | input_video_mask[:, :, -1:] = 0 125 | 126 | input_video = input_video / 255 127 | 128 | elif validation_image_start is not None: 129 | if type(validation_image_start) is str and os.path.isfile(validation_image_start): 130 | image_start = clip_image = Image.open(validation_image_start).convert("RGB") 131 | image_start = image_start.resize([sample_size[1], sample_size[0]]) 132 | clip_image = clip_image.resize([sample_size[1], sample_size[0]]) 133 | else: 134 | image_start = clip_image = validation_image_start 135 | image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start] 136 | clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image] 137 | image_end = None 138 | 139 | if type(image_start) is list: 140 | clip_image = clip_image[0] 141 | start_video = torch.cat( 142 | [torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_start in image_start], 143 | dim=2 144 | ) 145 | input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1]) 146 | input_video[:, :, :len(image_start)] = start_video 147 | input_video = input_video / 255 148 | 149 | input_video_mask = torch.zeros_like(input_video[:, :1]) 150 | input_video_mask[:, :, len(image_start):] = 255 151 | else: 152 | input_video = torch.tile( 153 | torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0), 154 | [1, 1, video_length, 1, 1] 155 | ) / 255 156 | input_video_mask = torch.zeros_like(input_video[:, :1]) 157 | input_video_mask[:, :, 1:, ] = 255 158 | else: 159 | image_start = None 160 | image_end = None 161 | input_video =
torch.zeros([1, 3, video_length, sample_size[0], sample_size[1]]) 162 | input_video_mask = torch.ones([1, 1, video_length, sample_size[0], sample_size[1]]) * 255 163 | clip_image = None 164 | 165 | del image_start 166 | del image_end 167 | gc.collect() 168 | 169 | return input_video, input_video_mask, clip_image 170 | 171 | def get_video_to_video_latent(input_video_path, video_length, sample_size, fps=None, validation_video_mask=None, ref_image=None): 172 | if isinstance(input_video_path, str): 173 | cap = cv2.VideoCapture(input_video_path) 174 | input_video = [] 175 | 176 | original_fps = cap.get(cv2.CAP_PROP_FPS) 177 | frame_skip = 1 if fps is None else max(1, int(original_fps // fps)) # never 0, which would make the modulo below fail when fps exceeds the source frame rate 178 | 179 | frame_count = 0 180 | 181 | while True: 182 | ret, frame = cap.read() 183 | if not ret: 184 | break 185 | 186 | if frame_count % frame_skip == 0: 187 | frame = cv2.resize(frame, (sample_size[1], sample_size[0])) 188 | input_video.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) 189 | 190 | frame_count += 1 191 | 192 | cap.release() 193 | else: 194 | input_video = input_video_path 195 | 196 | input_video = torch.from_numpy(np.array(input_video))[:video_length] 197 | input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255 198 | 199 | if ref_image is not None: 200 | ref_image = Image.open(ref_image).convert("RGB") # guarantee 3 channels (grayscale/RGBA inputs would break the permute below) 201 | ref_image = torch.from_numpy(np.array(ref_image)) 202 | ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255 203 | 204 | if validation_video_mask is not None: 205 | validation_video_mask = Image.open(validation_video_mask).convert('L').resize((sample_size[1], sample_size[0])) 206 | input_video_mask = np.where(np.array(validation_video_mask) < 240, 0, 255) 207 | 208 | input_video_mask = torch.from_numpy(np.array(input_video_mask)).unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0) 209 | input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1]) 210 | input_video_mask = input_video_mask.to(input_video.device,
input_video.dtype) 211 | else: 212 | input_video_mask = torch.zeros_like(input_video[:, :1]) 213 | input_video_mask[:, :, :] = 255 214 | 215 | return input_video, input_video_mask, ref_image -------------------------------------------------------------------------------- /utils/text_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .text_encoder import compute_prompt_embeddings -------------------------------------------------------------------------------- /utils/text_encoder/text_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | 3 | import torch 4 | from transformers import T5EncoderModel, T5Tokenizer 5 | 6 | 7 | def _get_t5_prompt_embeds( 8 | tokenizer: T5Tokenizer, 9 | text_encoder: T5EncoderModel, 10 | prompt: Union[str, List[str]], 11 | num_videos_per_prompt: int = 1, 12 | max_sequence_length: int = 226, 13 | device: Optional[torch.device] = None, 14 | dtype: Optional[torch.dtype] = None, 15 | text_input_ids=None, 16 | ): 17 | prompt = [prompt] if isinstance(prompt, str) else prompt 18 | batch_size = len(prompt) 19 | 20 | if tokenizer is not None: 21 | text_inputs = tokenizer( 22 | prompt, 23 | padding="max_length", 24 | max_length=max_sequence_length, 25 | truncation=True, 26 | add_special_tokens=True, 27 | return_tensors="pt", 28 | ) 29 | text_input_ids = text_inputs.input_ids 30 | else: 31 | if text_input_ids is None: 32 | raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.") 33 | 34 | prompt_embeds = text_encoder(text_input_ids.to(device))[0] 35 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 36 | 37 | # duplicate text embeddings for each generation per prompt, using mps friendly method 38 | _, seq_len, _ = prompt_embeds.shape 39 | prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) 40 | prompt_embeds = prompt_embeds.view(batch_size * 
num_videos_per_prompt, seq_len, -1) 41 | 42 | return prompt_embeds 43 | 44 | 45 | def encode_prompt( 46 | tokenizer: T5Tokenizer, 47 | text_encoder: T5EncoderModel, 48 | prompt: Union[str, List[str]], 49 | num_videos_per_prompt: int = 1, 50 | max_sequence_length: int = 226, 51 | device: Optional[torch.device] = None, 52 | dtype: Optional[torch.dtype] = None, 53 | text_input_ids=None, 54 | ): 55 | prompt = [prompt] if isinstance(prompt, str) else prompt 56 | prompt_embeds = _get_t5_prompt_embeds( 57 | tokenizer, 58 | text_encoder, 59 | prompt=prompt, 60 | num_videos_per_prompt=num_videos_per_prompt, 61 | max_sequence_length=max_sequence_length, 62 | device=device, 63 | dtype=dtype, 64 | text_input_ids=text_input_ids, 65 | ) 66 | return prompt_embeds 67 | 68 | 69 | def compute_prompt_embeddings( 70 | tokenizer: T5Tokenizer, 71 | text_encoder: T5EncoderModel, 72 | prompt: str, 73 | max_sequence_length: int, 74 | device: torch.device, 75 | dtype: torch.dtype, 76 | requires_grad: bool = False, 77 | ): 78 | if requires_grad: 79 | prompt_embeds = encode_prompt( 80 | tokenizer, 81 | text_encoder, 82 | prompt, 83 | num_videos_per_prompt=1, 84 | max_sequence_length=max_sequence_length, 85 | device=device, 86 | dtype=dtype, 87 | ) 88 | else: 89 | with torch.no_grad(): 90 | prompt_embeds = encode_prompt( 91 | tokenizer, 92 | text_encoder, 93 | prompt, 94 | num_videos_per_prompt=1, 95 | max_sequence_length=max_sequence_length, 96 | device=device, 97 | dtype=dtype, 98 | ) 99 | return prompt_embeds 100 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import inspect 3 | from typing import Optional, Tuple, Union 4 | 5 | import torch 6 | from accelerate import Accelerator 7 | from accelerate.logging import get_logger 8 | from diffusers.models.embeddings import get_3d_rotary_pos_embed 9 | from diffusers.utils.torch_utils import 
is_compiled_module 10 | 11 | 12 | logger = get_logger(__name__) 13 | 14 | 15 | def get_optimizer( 16 | params_to_optimize, 17 | optimizer_name: str = "adam", 18 | learning_rate: float = 1e-3, 19 | beta1: float = 0.9, 20 | beta2: float = 0.95, 21 | beta3: float = 0.98, 22 | epsilon: float = 1e-8, 23 | weight_decay: float = 1e-4, 24 | prodigy_decouple: bool = False, 25 | prodigy_use_bias_correction: bool = False, 26 | prodigy_safeguard_warmup: bool = False, 27 | use_8bit: bool = False, 28 | use_4bit: bool = False, 29 | use_torchao: bool = False, 30 | use_deepspeed: bool = False, 31 | use_cpu_offload_optimizer: bool = False, 32 | offload_gradients: bool = False, 33 | ) -> torch.optim.Optimizer: 34 | optimizer_name = optimizer_name.lower() 35 | 36 | # Use the DeepSpeed optimizer (accelerate's DummyOptim placeholder) 37 | if use_deepspeed: 38 | from accelerate.utils import DummyOptim 39 | 40 | return DummyOptim( 41 | params_to_optimize, 42 | lr=learning_rate, 43 | betas=(beta1, beta2), 44 | eps=epsilon, 45 | weight_decay=weight_decay, 46 | ) 47 | 48 | if use_8bit and use_4bit: 49 | raise ValueError("Cannot set both `use_8bit` and `use_4bit` to True.") 50 | 51 | if (use_torchao and (use_8bit or use_4bit)) or use_cpu_offload_optimizer: 52 | try: 53 | import torchao 54 | 55 | torchao.__version__ 56 | except ImportError: 57 | raise ImportError( 58 | "To use optimizers from torchao, please install the torchao library: `USE_CPP=0 pip install torchao`." 59 | ) 60 | 61 | if not use_torchao and use_4bit: 62 | raise ValueError("4-bit optimizers are only supported with torchao.") 63 | 64 | # Optimizer creation 65 | supported_optimizers = ["adam", "adamw", "prodigy", "came"] 66 | if optimizer_name not in supported_optimizers: 67 | logger.warning( 68 | f"Unsupported choice of optimizer: {optimizer_name}. Supported optimizers include {supported_optimizers}. Defaulting to `AdamW`."
69 | ) 70 | optimizer_name = "adamw" 71 | 72 | if (use_8bit or use_4bit) and optimizer_name not in ["adam", "adamw"]: 73 | raise ValueError("`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers.") 74 | 75 | if use_8bit and not use_torchao: # bitsandbytes is only needed when torchao does not supply the 8-bit optimizer 76 | try: 77 | import bitsandbytes as bnb 78 | except ImportError: 79 | raise ImportError( 80 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 81 | ) 82 | 83 | if optimizer_name == "adamw": 84 | if use_torchao: 85 | from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit 86 | 87 | optimizer_class = AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW 88 | else: 89 | optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW 90 | 91 | init_kwargs = { 92 | "betas": (beta1, beta2), 93 | "eps": epsilon, 94 | "weight_decay": weight_decay, 95 | } 96 | 97 | elif optimizer_name == "adam": 98 | if use_torchao: 99 | from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit 100 | 101 | optimizer_class = Adam8bit if use_8bit else Adam4bit if use_4bit else torch.optim.Adam 102 | else: 103 | optimizer_class = bnb.optim.Adam8bit if use_8bit else torch.optim.Adam 104 | 105 | init_kwargs = { 106 | "betas": (beta1, beta2), 107 | "eps": epsilon, 108 | "weight_decay": weight_decay, 109 | } 110 | 111 | elif optimizer_name == "prodigy": 112 | try: 113 | import prodigyopt 114 | except ImportError: 115 | raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") 116 | 117 | optimizer_class = prodigyopt.Prodigy 118 | 119 | if learning_rate <= 0.1: 120 | logger.warning( 121 | "Learning rate is too low.
When using prodigy, it's generally better to set learning rate around 1.0" 122 | ) 123 | 124 | init_kwargs = { 125 | "lr": learning_rate, 126 | "betas": (beta1, beta2), 127 | "beta3": beta3, 128 | "eps": epsilon, 129 | "weight_decay": weight_decay, 130 | "decouple": prodigy_decouple, 131 | "use_bias_correction": prodigy_use_bias_correction, 132 | "safeguard_warmup": prodigy_safeguard_warmup, 133 | } 134 | 135 | elif optimizer_name == "came": 136 | try: 137 | import came_pytorch 138 | except ImportError: 139 | raise ImportError("To use CAME, please install the came-pytorch library: `pip install came-pytorch`") 140 | 141 | optimizer_class = came_pytorch.CAME 142 | 143 | init_kwargs = { 144 | "lr": learning_rate, 145 | "eps": (1e-30, 1e-16), 146 | "betas": (beta1, beta2, beta3), 147 | "weight_decay": weight_decay, 148 | } 149 | 150 | if use_cpu_offload_optimizer: 151 | from torchao.prototype.low_bit_optim import CPUOffloadOptimizer 152 | 153 | if "fused" in inspect.signature(optimizer_class.__init__).parameters: 154 | init_kwargs.update({"fused": True}) 155 | 156 | optimizer = CPUOffloadOptimizer( 157 | params_to_optimize, optimizer_class=optimizer_class, offload_gradients=offload_gradients, **init_kwargs 158 | ) 159 | else: 160 | optimizer = optimizer_class(params_to_optimize, **init_kwargs) 161 | 162 | return optimizer 163 | 164 | 165 | def get_gradient_norm(parameters): 166 | norm = 0 167 | for param in parameters: 168 | if param.grad is None: 169 | continue 170 | local_norm = param.grad.detach().data.norm(2) 171 | norm += local_norm.item() ** 2 172 | norm = norm**0.5 173 | return norm 174 | 175 | 176 | # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid 177 | def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): 178 | tw = tgt_width 179 | th = tgt_height 180 | h, w = src 181 | r = h / w 182 | if r > (th / tw): 183 | resize_height = th 184 | resize_width = int(round(th / h * w)) 185 | else: 186 | resize_width = 
tw 187 | resize_height = int(round(tw / w * h)) 188 | 189 | crop_top = int(round((th - resize_height) / 2.0)) 190 | crop_left = int(round((tw - resize_width) / 2.0)) 191 | 192 | return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) 193 | 194 | 195 | def prepare_rotary_positional_embeddings( 196 | height: int, 197 | width: int, 198 | num_frames: int, 199 | vae_scale_factor_spatial: int = 8, 200 | patch_size: int = 2, 201 | patch_size_t: int = None, 202 | attention_head_dim: int = 64, 203 | device: Optional[torch.device] = None, 204 | base_height: int = 480, 205 | base_width: int = 720, 206 | ) -> Tuple[torch.Tensor, torch.Tensor]: 207 | grid_height = height // (vae_scale_factor_spatial * patch_size) 208 | grid_width = width // (vae_scale_factor_spatial * patch_size) 209 | base_size_width = base_width // (vae_scale_factor_spatial * patch_size) 210 | base_size_height = base_height // (vae_scale_factor_spatial * patch_size) 211 | 212 | if patch_size_t is None: 213 | grid_crops_coords = get_resize_crop_region_for_grid( 214 | (grid_height, grid_width), base_size_width, base_size_height 215 | ) 216 | freqs_cos, freqs_sin = get_3d_rotary_pos_embed( 217 | embed_dim=attention_head_dim, 218 | crops_coords=grid_crops_coords, 219 | grid_size=(grid_height, grid_width), 220 | temporal_size=num_frames, 221 | ) 222 | else: 223 | base_num_frames = (num_frames + patch_size_t - 1) // patch_size_t 224 | 225 | freqs_cos, freqs_sin = get_3d_rotary_pos_embed( 226 | embed_dim=attention_head_dim, 227 | crops_coords=None, 228 | grid_size=(grid_height, grid_width), 229 | temporal_size=base_num_frames, 230 | grid_type="slice", 231 | max_size=(base_size_height, base_size_width), 232 | ) 233 | 234 | freqs_cos = freqs_cos.to(device=device) 235 | freqs_sin = freqs_sin.to(device=device) 236 | return freqs_cos, freqs_sin 237 | 238 | 239 | def reset_memory(device: Union[str, torch.device]) -> None: 240 | gc.collect() 241 | torch.cuda.empty_cache() 242 | 
torch.cuda.reset_peak_memory_stats(device) 243 | torch.cuda.reset_accumulated_memory_stats(device) 244 | 245 | 246 | def print_memory(device: Union[str, torch.device]) -> None: 247 | memory_allocated = torch.cuda.memory_allocated(device) / 1024**3 248 | max_memory_allocated = torch.cuda.max_memory_allocated(device) / 1024**3 249 | max_memory_reserved = torch.cuda.max_memory_reserved(device) / 1024**3 250 | print(f"{memory_allocated=:.3f} GB") 251 | print(f"{max_memory_allocated=:.3f} GB") 252 | print(f"{max_memory_reserved=:.3f} GB") 253 | 254 | 255 | def unwrap_model(accelerator: Accelerator, model): 256 | model = accelerator.unwrap_model(model) 257 | model = model._orig_mod if is_compiled_module(model) else model 258 | return model 259 | --------------------------------------------------------------------------------
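The flag checks in `get_optimizer` run in a fixed order. DeepSpeed is checked first, then the 8-bit/4-bit validation, then the fallback to AdamW for unknown names. The sketch below uses a hypothetical helper, `resolve_optimizer_class`, that mirrors that order. Instead of importing torchao, bitsandbytes, prodigyopt or came-pytorch, it returns the dotted name of the class `get_optimizer` would pick, so it runs without those optional dependencies. It does not construct an optimizer and omits the warning log. Treat it as an illustration of the selection rules, not as part of the repository.

```python
from typing import Tuple

SUPPORTED_OPTIMIZERS: Tuple[str, ...] = ("adam", "adamw", "prodigy", "came")


def resolve_optimizer_class(
    optimizer_name: str,
    use_8bit: bool = False,
    use_4bit: bool = False,
    use_torchao: bool = False,
    use_deepspeed: bool = False,
) -> str:
    """Hypothetical helper: name the class that get_optimizer would instantiate."""
    name = optimizer_name.lower()

    # DeepSpeed short-circuits everything; the real optimizer comes from its config.
    if use_deepspeed:
        return "accelerate.utils.DummyOptim"

    if use_8bit and use_4bit:
        raise ValueError("Cannot set both use_8bit and use_4bit.")
    if use_4bit and not use_torchao:
        raise ValueError("4-bit optimizers require torchao.")

    # Unknown names fall back to AdamW (get_optimizer also logs a warning here).
    if name not in SUPPORTED_OPTIMIZERS:
        name = "adamw"

    if (use_8bit or use_4bit) and name not in ("adam", "adamw"):
        raise ValueError("Low-bit optimizer states are only available for Adam and AdamW.")

    if name in ("adam", "adamw"):
        base = "AdamW" if name == "adamw" else "Adam"
        if use_torchao:
            if use_8bit:
                return f"torchao.prototype.low_bit_optim.{base}8bit"
            if use_4bit:
                return f"torchao.prototype.low_bit_optim.{base}4bit"
            return f"torch.optim.{base}"
        return f"bitsandbytes.optim.{base}8bit" if use_8bit else f"torch.optim.{base}"

    return {"prodigy": "prodigyopt.Prodigy", "came": "came_pytorch.CAME"}[name]


print(resolve_optimizer_class("AdamW", use_8bit=True))  # bitsandbytes.optim.AdamW8bit
print(resolve_optimizer_class("sgd"))  # torch.optim.AdamW
```

Returning names rather than classes keeps the rules testable in isolation. The real function must import the chosen library before it can instantiate anything.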
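`get_gradient_norm` reports the global L2 norm of all gradients. It accumulates squared per-parameter norms and takes a single square root at the end, which equals the norm of every gradient flattened into one vector. The sketch below checks that identity in plain Python. Lists of floats stand in for `param.grad` tensors and `None` marks a parameter without a gradient; this is an assumption made so the example runs without PyTorch. Both function names are hypothetical.

```python
import math
from typing import Iterable, List, Optional


def global_grad_norm(grads: Iterable[Optional[List[float]]]) -> float:
    """Hypothetical stand-in for get_gradient_norm, operating on plain lists."""
    total = 0.0
    for g in grads:
        if g is None:  # same as skipping parameters whose .grad is None
            continue
        local_norm = math.sqrt(sum(x * x for x in g))  # per-parameter L2 norm
        total += local_norm**2
    return total**0.5


def flat_grad_norm(grads: Iterable[Optional[List[float]]]) -> float:
    """L2 norm of all gradients concatenated into a single vector."""
    flat = [x for g in grads if g is not None for x in g]
    return math.sqrt(sum(x * x for x in flat))


grads = [[3.0, 4.0], None, [12.0]]
print(global_grad_norm(grads), flat_grad_norm(grads))  # 13.0 13.0
```

In PyTorch, `torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=2.0)` returns the same total-norm quantity, and it also clips in place. The source calls `.item()` inside the loop, which forces one device-to-host synchronization per parameter. Accumulating into a tensor and calling `.item()` once would avoid that.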
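`prepare_rotary_positional_embeddings` turns pixel sizes into a transformer token grid by dividing by `vae_scale_factor_spatial * patch_size`. When `patch_size_t` is unset, it uses `get_resize_crop_region_for_grid` to fit that grid inside the base grid derived from `base_height` x `base_width` (480x720 by default). The sketch below copies the crop function's logic and wraps the grid arithmetic in hypothetical helpers, `token_grid` and `temporal_tokens`, using the source's default arguments. This lets you inspect the resulting crop coordinates without diffusers or PyTorch.

```python
from typing import Tuple

Coord = Tuple[int, int]


def get_resize_crop_region_for_grid(src: Coord, tgt_width: int, tgt_height: int) -> Tuple[Coord, Coord]:
    # Same logic as the source: fit `src` (height, width) inside the target grid,
    # preserving aspect ratio, then center it.
    tw, th = tgt_width, tgt_height
    h, w = src
    if h / w > th / tw:
        resize_height = th
        resize_width = int(round(th / h * w))
    else:
        resize_width = tw
        resize_height = int(round(tw / w * h))
    crop_top = int(round((th - resize_height) / 2.0))
    crop_left = int(round((tw - resize_width) / 2.0))
    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)


def token_grid(height, width, vae_scale_factor_spatial=8, patch_size=2, base_height=480, base_width=720):
    """Hypothetical helper: spatial token grid and its crop region inside the base grid."""
    step = vae_scale_factor_spatial * patch_size  # pixels per transformer token along each axis
    grid = (height // step, width // step)
    crop = get_resize_crop_region_for_grid(grid, base_width // step, base_height // step)
    return grid, crop


def temporal_tokens(num_frames: int, patch_size_t: int) -> int:
    """Hypothetical helper: the ceil division applied when patch_size_t is set."""
    return (num_frames + patch_size_t - 1) // patch_size_t


print(token_grid(480, 720))  # ((30, 45), ((0, 0), (30, 45)))
print(token_grid(720, 480))  # ((45, 30), ((0, 12), (30, 32)))
print(temporal_tokens(13, 2))  # 7
```

For a portrait 720x480 input, the 45x30 grid is scaled to 30x20 and centered horizontally, so the region starts at column 12. Python's `round` rounds half to even, which is why (45 - 20) / 2 = 12.5 becomes 12 rather than 13.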