├── LICENSE
├── README.md
├── assets
├── 1024demos
│ ├── 30_sample0.mp4
│ └── 5_sample0.mp4
├── DC
│ ├── 15_sample0.gif
│ ├── 15base.gif
│ ├── 6_sample0.gif
│ └── 6base.gif
├── PIA
│ ├── 1000.gif
│ ├── 900.gif
│ └── concert.gif
├── SVD
│ ├── 39_sample0.gif
│ ├── 39base.gif
│ ├── 41_sample0.gif
│ └── 41base.gif
├── VC
│ ├── 30_sample0.gif
│ ├── 30base.gif
│ ├── 5_sample0.gif
│ └── 5base.gif
├── animate
│ ├── 52.gif
│ ├── 52_1000.gif
│ └── 52img.gif
├── conditionalImg
│ ├── 15.gif
│ ├── 30.gif
│ ├── 39.gif
│ ├── 41.gif
│ ├── 5.gif
│ ├── 6.gif
│ ├── doggy.jpg
│ └── sunflower.gif
├── effect_of_M
│ ├── 1000.gif
│ ├── 840.gif
│ ├── 880.gif
│ ├── 920.gif
│ └── 960.gif
├── effect_of_a
│ ├── 01.gif
│ └── 1.gif
├── effect_of_betam
│ ├── 100.gif
│ ├── 25.gif
│ └── 700.gif
└── overview.png
└── examples
├── DynamiCrafter
├── configs
│ ├── inference_1024_v1.0.yaml
│ ├── inference_512_v1.0.yaml
│ └── train_512.yaml
├── inference_512.sh
├── inference_CIL_1024.sh
├── inference_CIL_512.sh
├── lvdm
│ ├── basics.py
│ ├── common.py
│ ├── data
│ │ ├── base.py
│ │ └── webvid.py
│ ├── distributions.py
│ ├── ema.py
│ ├── models
│ │ ├── autoencoder.py
│ │ ├── ddpm3d.py
│ │ ├── ddpm3dInference.py
│ │ ├── samplers
│ │ │ ├── ddim.py
│ │ │ └── ddim_multiplecond.py
│ │ └── utils_diffusion.py
│ └── modules
│ │ ├── attention.py
│ │ ├── encoders
│ │ ├── condition.py
│ │ └── resampler.py
│ │ ├── networks
│ │ ├── ae_modules.py
│ │ └── openaimodel3d.py
│ │ └── x_transformer.py
├── main
│ ├── callbacks.py
│ ├── trainer.py
│ ├── utils_data.py
│ └── utils_train.py
├── prompts
│ ├── 512
│ │ ├── 14.png
│ │ ├── 15.png
│ │ ├── 18.png
│ │ ├── 25.png
│ │ ├── 29.png
│ │ ├── 30.png
│ │ ├── 32.png
│ │ ├── 33.png
│ │ ├── 35.png
│ │ ├── 36.png
│ │ ├── 41.png
│ │ ├── 43.png
│ │ ├── 47.png
│ │ ├── 5.png
│ │ ├── 52.png
│ │ ├── 55.png
│ │ ├── 65.png
│ │ ├── 7.png
│ │ ├── A girl with long curly blonde hair and sunglasses, camera pan from left to right..png
│ │ ├── A panda wearing sunglasses walking in slow-motion under water, in photorealistic style..png
│ │ ├── a car parked in a parking lot with palm trees nearby,calm seas and skies..png
│ │ └── test_prompts.txt
│ └── 1024
│ │ ├── 14.png
│ │ ├── 18.png
│ │ ├── 25.png
│ │ ├── 29.png
│ │ ├── 30.png
│ │ ├── 32.png
│ │ ├── 33.png
│ │ ├── 35.png
│ │ ├── 36.png
│ │ ├── 41.png
│ │ ├── 47.png
│ │ ├── 5.png
│ │ ├── 52.png
│ │ ├── 65.png
│ │ ├── A panda wearing sunglasses walking in slow-motion under water, in photorealistic style..png
│ │ └── test_prompts.txt
├── requirements.txt
├── scripts
│ └── evaluation
│ │ ├── __pycache__
│ │ └── inference.cpython-38.pyc
│ │ ├── ddp_wrapper.py
│ │ ├── funcs.py
│ │ └── inference.py
├── train.sh
└── utils
│ ├── __pycache__
│ └── utils.cpython-38.pyc
│ ├── save_video.py
│ └── utils.py
├── SVD
├── config
│ ├── inference1024.yaml
│ ├── inference512.yaml
│ └── train.yaml
├── demo
│ ├── 1066.jpg
│ ├── 485.jpg
│ ├── A 360 shot of a sleek yacht sailing gracefully through the crystal-clear waters of the Caribbean..png
│ ├── A girl with long curly blonde hair and sunglasses, camera pan from left to right..png
│ ├── A panda wearing sunglasses walking in slow-motion under water, in photorealistic style..png
│ ├── A pizza spinning inside a wood fired pizza oven; dramatic vivid colors..png
│ └── a car parked in a parking lot with palm trees nearby,calm seas and skies..png
├── inference.py
├── inference.sh
├── inference_CIL_512.sh
├── requirements.txt
├── schedulers
│ ├── scheduler_config1.json
│ ├── scheduler_config10.json
│ ├── scheduler_config100.json
│ ├── scheduler_config1100.json
│ ├── scheduler_config20.json
│ ├── scheduler_config30.json
│ ├── scheduler_config300.json
│ ├── scheduler_config40.json
│ ├── scheduler_config50.json
│ ├── scheduler_config500.json
│ ├── scheduler_config60.json
│ ├── scheduler_config70.json
│ ├── scheduler_config700.json
│ ├── scheduler_config80.json
│ ├── scheduler_config90.json
│ └── scheduler_config900.json
├── svd
│ ├── data
│ │ └── dataset.py
│ ├── inference
│ │ └── pipline_CILsvd.py
│ ├── loss.py
│ ├── models
│ │ ├── attention.py
│ │ ├── attention_processor.py
│ │ ├── embeddings.py
│ │ ├── resnet.py
│ │ ├── transformer_2d.py
│ │ ├── transformer_temporal.py
│ │ ├── unet_3d_blocks.py
│ │ ├── unet_spatio_temporal_condition.py
│ │ └── utils.py
│ ├── schedulers
│ │ ├── edm.py
│ │ └── scheduling_euler_discrete.py
│ └── training
│ │ ├── loss.py
│ │ └── utils.py
├── train.py
└── train.sh
├── VideoCrafter
├── configs
│ ├── inference_i2v_512_v1.0.yaml
│ └── train.yaml
├── inference_512.sh
├── inference_CIL_512.sh
├── libs
│ ├── eval_funcs.py
│ ├── losses.py
│ └── special_functions.py
├── lvdm
│ ├── basics.py
│ ├── common.py
│ ├── data
│ │ ├── base.py
│ │ └── webvid.py
│ ├── distributions.py
│ ├── ema.py
│ ├── models
│ │ ├── autoencoder.py
│ │ ├── ddpm3d.py
│ │ ├── ddpm3d_videocrafter_ve.py
│ │ ├── samplers
│ │ │ ├── ddim_multiplecond.py
│ │ │ └── ddim_videocrafter.py
│ │ └── utils_diffusion.py
│ └── modules
│ │ ├── attention.py
│ │ ├── attention_videocrafter.py
│ │ ├── encoders
│ │ ├── condition.py
│ │ ├── ip_resampler.py
│ │ └── resampler.py
│ │ ├── networks
│ │ ├── ae_modules.py
│ │ └── openaimodel3d_videocrafter.py
│ │ └── x_transformer.py
├── main
│ ├── callbacks.py
│ ├── trainer.py
│ ├── utils_data.py
│ └── utils_train.py
├── prompts
│ └── 512
│ │ ├── 13.png
│ │ ├── 14.png
│ │ ├── 15.png
│ │ ├── 18.png
│ │ ├── 25.png
│ │ ├── 29.png
│ │ ├── 3.png
│ │ ├── 30.png
│ │ ├── 32.png
│ │ ├── 33.png
│ │ ├── 35.png
│ │ ├── 36.png
│ │ ├── 41.png
│ │ ├── 43.png
│ │ ├── 47.png
│ │ ├── 5.png
│ │ ├── 52.png
│ │ ├── 55.png
│ │ ├── 65.png
│ │ ├── 7.png
│ │ ├── A girl with long curly blonde hair and sunglasses, camera pan from left to right..png
│ │ ├── A panda wearing sunglasses walking in slow-motion under water, in photorealistic style..png
│ │ ├── a car parked in a parking lot with palm trees nearby,calm seas and skies..png
│ │ └── test_prompts.txt
├── requirements.txt
├── scripts
│ └── evaluation
│ │ ├── ddp_wrapper.py
│ │ ├── funcs.py
│ │ └── inference.py
├── train.sh
└── utils
│ ├── save_video.py
│ └── utils.py
└── animate-anything
├── demo
├── demo.jsonl
└── image
│ └── 52.png
├── inference.py
├── inference.sh
├── models
├── pipeline.py
├── unet_3d_blocks.py
└── unet_3d_condition_mask.py
├── output
└── latent
│ └── animate_anything_512_v1.02
│ ├── config.yaml
│ ├── demo.yaml
│ ├── model_index.json
│ ├── scheduler
│ └── scheduler_config.json
│ ├── text_encoder
│ ├── config.json
│ └── model.txt
│ ├── tokenizer
│ ├── merges.txt
│ ├── special_tokens_map.json
│ ├── tokenizer_config.json
│ └── vocab.json
│ ├── unet
│ ├── config.json
│ └── unet.txt
│ └── vae
│ ├── config.json
│ └── vae.txt
├── schedulers
├── scheduling_ddim.py
└── scheduling_ddpm.py
├── stable_lora
└── lora.py
└── utils
├── __init__.py
├── bucketing.py
├── common.py
├── convert_diffusers_to_original_ms_text_to_video.py
├── dataset.py
├── lama.py
├── ptp_utils.py
├── scene_detect.py
└── seq_aligner.py
/assets/1024demos/30_sample0.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/1024demos/30_sample0.mp4 -------------------------------------------------------------------------------- /assets/1024demos/5_sample0.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/1024demos/5_sample0.mp4 -------------------------------------------------------------------------------- /assets/DC/15_sample0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/DC/15_sample0.gif -------------------------------------------------------------------------------- /assets/DC/15base.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/DC/15base.gif -------------------------------------------------------------------------------- /assets/DC/6_sample0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/DC/6_sample0.gif -------------------------------------------------------------------------------- /assets/DC/6base.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/DC/6base.gif -------------------------------------------------------------------------------- /assets/PIA/1000.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/PIA/1000.gif -------------------------------------------------------------------------------- /assets/PIA/900.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/PIA/900.gif -------------------------------------------------------------------------------- /assets/PIA/concert.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/PIA/concert.gif -------------------------------------------------------------------------------- /assets/SVD/39_sample0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/SVD/39_sample0.gif -------------------------------------------------------------------------------- /assets/SVD/39base.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/SVD/39base.gif -------------------------------------------------------------------------------- /assets/SVD/41_sample0.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/SVD/41_sample0.gif -------------------------------------------------------------------------------- /assets/SVD/41base.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/SVD/41base.gif -------------------------------------------------------------------------------- /assets/VC/30_sample0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/VC/30_sample0.gif -------------------------------------------------------------------------------- /assets/VC/30base.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/VC/30base.gif -------------------------------------------------------------------------------- /assets/VC/5_sample0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/VC/5_sample0.gif -------------------------------------------------------------------------------- /assets/VC/5base.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/VC/5base.gif -------------------------------------------------------------------------------- /assets/animate/52.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/animate/52.gif -------------------------------------------------------------------------------- /assets/animate/52_1000.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/animate/52_1000.gif -------------------------------------------------------------------------------- /assets/animate/52img.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/animate/52img.gif -------------------------------------------------------------------------------- /assets/conditionalImg/15.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/conditionalImg/15.gif -------------------------------------------------------------------------------- /assets/conditionalImg/30.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/conditionalImg/30.gif -------------------------------------------------------------------------------- /assets/conditionalImg/39.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/conditionalImg/39.gif -------------------------------------------------------------------------------- /assets/conditionalImg/41.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/conditionalImg/41.gif -------------------------------------------------------------------------------- /assets/conditionalImg/5.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/conditionalImg/5.gif -------------------------------------------------------------------------------- /assets/conditionalImg/6.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/conditionalImg/6.gif -------------------------------------------------------------------------------- /assets/conditionalImg/doggy.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/conditionalImg/doggy.jpg -------------------------------------------------------------------------------- /assets/conditionalImg/sunflower.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/conditionalImg/sunflower.gif -------------------------------------------------------------------------------- /assets/effect_of_M/1000.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/effect_of_M/1000.gif -------------------------------------------------------------------------------- /assets/effect_of_M/840.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/effect_of_M/840.gif -------------------------------------------------------------------------------- /assets/effect_of_M/880.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/effect_of_M/880.gif -------------------------------------------------------------------------------- /assets/effect_of_M/920.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/effect_of_M/920.gif -------------------------------------------------------------------------------- /assets/effect_of_M/960.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/effect_of_M/960.gif -------------------------------------------------------------------------------- /assets/effect_of_a/01.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/effect_of_a/01.gif -------------------------------------------------------------------------------- /assets/effect_of_a/1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/effect_of_a/1.gif -------------------------------------------------------------------------------- /assets/effect_of_betam/100.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/effect_of_betam/100.gif -------------------------------------------------------------------------------- /assets/effect_of_betam/25.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/effect_of_betam/25.gif -------------------------------------------------------------------------------- /assets/effect_of_betam/700.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/effect_of_betam/700.gif -------------------------------------------------------------------------------- /assets/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/assets/overview.png -------------------------------------------------------------------------------- /examples/DynamiCrafter/configs/inference_1024_v1.0.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: lvdm.models.ddpm3dInference.LatentVisualDiffusion 3 | params: 4 | rescale_betas_zero_snr: True 5 | parameterization: "v" 6 | linear_start: 0.00085 7 | linear_end: 0.012 8 | num_timesteps_cond: 1 9 | timesteps: 1000 10 | first_stage_key: video 11 | cond_stage_key: caption 12 | cond_stage_trainable: False 13 | conditioning_key: hybrid 14 | image_size: [72, 128] 15 | channels: 4 16 | scale_by_std: False 17 | scale_factor: 0.18215 18 | use_ema: False 19 | uncond_type: 'empty_seq' 20 | use_dynamic_rescale: true 21 | base_scale: 0.3 22 | fps_condition_type: 'fps' 23 | perframe_ae: True 24 | unet_config: 25 | target: lvdm.modules.networks.openaimodel3d.UNetModel 26 | params: 27 | in_channels: 8 28 | out_channels: 4 29 | model_channels: 320 30 | attention_resolutions: 31 | - 4 32 | - 2 33 | - 1 34 | num_res_blocks: 2 35 | channel_mult: 36 | - 1 37 | - 2 38 | - 4 39 | - 4 40 | dropout: 0.1 41 | num_head_channels: 64 42 | transformer_depth: 1 43 | context_dim: 1024 44 | use_linear: true 45 | use_checkpoint: True 46 | temporal_conv: True 47 | temporal_attention: True 48 | temporal_selfatt_only: true 49 | use_relative_position: false 50 | use_causal_attention: False 51 | temporal_length: 16 52 | addition_attention: true 53 | image_cross_attention: true 54 | default_fs: 10 55 | fs_condition: true 56 | 57 | first_stage_config: 58 | target: lvdm.models.autoencoder.AutoencoderKL 59 | params: 60 | embed_dim: 4 61 | monitor: val/rec_loss 62 | ddconfig: 63 | double_z: True 64 | z_channels: 4 65 | resolution: 256 66 | in_channels: 3 67 | out_ch: 3 68 | ch: 128 69 | ch_mult: 70 | 
- 1 71 | - 2 72 | - 4 73 | - 4 74 | num_res_blocks: 2 75 | attn_resolutions: [] 76 | dropout: 0.0 77 | lossconfig: 78 | target: torch.nn.Identity 79 | 80 | cond_stage_config: 81 | target: lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder 82 | params: 83 | freeze: true 84 | layer: "penultimate" 85 | 86 | img_cond_stage_config: 87 | target: lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2 88 | params: 89 | freeze: true 90 | 91 | image_proj_stage_config: 92 | target: lvdm.modules.encoders.resampler.Resampler 93 | params: 94 | dim: 1024 95 | depth: 4 96 | dim_head: 64 97 | heads: 12 98 | num_queries: 16 99 | embedding_dim: 1280 100 | output_dim: 1024 101 | ff_mult: 4 102 | video_length: 16 103 | 104 | -------------------------------------------------------------------------------- /examples/DynamiCrafter/configs/inference_512_v1.0.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: lvdm.models.ddpm3dInference.LatentVisualDiffusion 3 | params: 4 | rescale_betas_zero_snr: True 5 | parameterization: "v" 6 | linear_start: 0.00085 7 | linear_end: 0.012 8 | num_timesteps_cond: 1 9 | timesteps: 1000 10 | first_stage_key: video 11 | cond_stage_key: caption 12 | cond_stage_trainable: False 13 | conditioning_key: hybrid 14 | image_size: [40, 64] 15 | channels: 4 16 | scale_by_std: False 17 | scale_factor: 0.18215 18 | use_ema: False 19 | uncond_type: 'empty_seq' 20 | use_dynamic_rescale: true 21 | base_scale: 0.7 22 | fps_condition_type: 'fps' 23 | perframe_ae: True 24 | unet_config: 25 | target: lvdm.modules.networks.openaimodel3d.UNetModel 26 | params: 27 | in_channels: 8 28 | out_channels: 4 29 | model_channels: 320 30 | attention_resolutions: 31 | - 4 32 | - 2 33 | - 1 34 | num_res_blocks: 2 35 | channel_mult: 36 | - 1 37 | - 2 38 | - 4 39 | - 4 40 | dropout: 0.1 41 | num_head_channels: 64 42 | transformer_depth: 1 43 | context_dim: 1024 44 | use_linear: true 45 | use_checkpoint: True 46 | temporal_conv: True 47 | temporal_attention: True 48 | temporal_selfatt_only: true 49 | use_relative_position: false 50 | use_causal_attention: False 51 | temporal_length: 16 52 | addition_attention: true 53 | image_cross_attention: true 54 | default_fs: 24 55 | fs_condition: true 56 | 57 | first_stage_config: 58 | target: lvdm.models.autoencoder.AutoencoderKL 59 | params: 60 | embed_dim: 4 61 | monitor: val/rec_loss 62 | ddconfig: 63 | double_z: True 64 | z_channels: 4 65 | resolution: 256 66 | in_channels: 3 67 | out_ch: 3 68 | ch: 128 69 | ch_mult: 70 | - 1 71 | - 2 72 | - 4 73 | - 4 74 | num_res_blocks: 2 75 | attn_resolutions: [] 76 | dropout: 0.0 77 | lossconfig: 78 | target: torch.nn.Identity 79 | 80 | cond_stage_config: 81 | target: lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder 82 | params: 83 | freeze: true 84 | layer: "penultimate" 85 | 86 | img_cond_stage_config: 87 | target: lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2 88 | params: 89 | freeze: true 90 | 91 | image_proj_stage_config: 92 | target: lvdm.modules.encoders.resampler.Resampler 93 | params: 94 | dim: 1024 95 | depth: 4 96 | dim_head: 64 97 | heads: 12 98 | num_queries: 16 99 | embedding_dim: 1280 100 | output_dim: 1024 101 | ff_mult: 4 102 | video_length: 16 103 | 104 | -------------------------------------------------------------------------------- /examples/DynamiCrafter/configs/train_512.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | pretrained_checkpoint: ckpt/original/model.ckpt 3 | 
base_learning_rate: 1.0e-05 4 | scale_lr: False 5 | target: lvdm.models.ddpm3d.LatentVisualDiffusion 6 | params: 7 | rescale_betas_zero_snr: True 8 | parameterization: "v" 9 | linear_start: 0.00085 10 | linear_end: 0.012 11 | num_timesteps_cond: 1 12 | log_every_t: 200 13 | timesteps: 1000 14 | first_stage_key: video 15 | cond_stage_key: caption 16 | cond_stage_trainable: False 17 | image_proj_model_trainable: True 18 | conditioning_key: hybrid 19 | image_size: [40, 64] 20 | channels: 4 21 | scale_by_std: False 22 | scale_factor: 0.18215 23 | use_ema: False 24 | uncond_prob: 0.05 25 | uncond_type: 'empty_seq' 26 | rand_cond_frame: true 27 | use_dynamic_rescale: true 28 | base_scale: 0.7 29 | fps_condition_type: 'fps' 30 | perframe_ae: True 31 | unet_config: 32 | target: lvdm.modules.networks.openaimodel3d.UNetModel 33 | params: 34 | in_channels: 8 35 | out_channels: 4 36 | model_channels: 320 37 | attention_resolutions: 38 | - 4 39 | - 2 40 | - 1 41 | num_res_blocks: 2 42 | channel_mult: 43 | - 1 44 | - 2 45 | - 4 46 | - 4 47 | dropout: 0.1 48 | num_head_channels: 64 49 | transformer_depth: 1 50 | context_dim: 1024 51 | use_linear: true 52 | use_checkpoint: True 53 | temporal_conv: True 54 | temporal_attention: True 55 | temporal_selfatt_only: true 56 | use_relative_position: false 57 | use_causal_attention: False 58 | temporal_length: 16 59 | addition_attention: true 60 | image_cross_attention: true 61 | default_fs: 10 62 | fs_condition: true 63 | 64 | first_stage_config: 65 | target: lvdm.models.autoencoder.AutoencoderKL 66 | params: 67 | embed_dim: 4 68 | monitor: val/rec_loss 69 | ddconfig: 70 | double_z: True 71 | z_channels: 4 72 | resolution: 256 73 | in_channels: 3 74 | out_ch: 3 75 | ch: 128 76 | ch_mult: 77 | - 1 78 | - 2 79 | - 4 80 | - 4 81 | num_res_blocks: 2 82 | attn_resolutions: [] 83 | dropout: 0.0 84 | lossconfig: 85 | target: torch.nn.Identity 86 | 87 | cond_stage_config: 88 | target: lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder 89 | params: 90 | freeze: true 91 | layer: "penultimate" 92 | 93 | img_cond_stage_config: 94 | target: lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2 95 | params: 96 | freeze: true 97 | 98 | image_proj_stage_config: 99 | target: lvdm.modules.encoders.resampler.Resampler 100 | params: 101 | dim: 1024 102 | depth: 4 103 | dim_head: 64 104 | heads: 12 105 | num_queries: 16 106 | embedding_dim: 1280 107 | output_dim: 1024 108 | ff_mult: 4 109 | video_length: 16 110 | beta_m: 100 111 | a: 5 112 | clip_type: 2 113 | data: 114 | target: utils_data.DataModuleFromConfig 115 | params: 116 | batch_size: 2 117 | num_workers: 12 118 | wrap: false 119 | train: 120 | target: lvdm.data.webvid.WebVid 121 | params: 122 | file_path: '../dataset/results_2M_train.csv' 123 | video_folder: '../dataset' 124 | video_length: 16 125 | frame_stride: 6 126 | load_raw_resolution: true 127 | resolution: [320, 512] 128 | spatial_transform: resize_center_crop 129 | random_fs: true ## if true, we uniformly sample fs with max_fs=frame_stride (above) 130 | 131 | lightning: 132 | precision: 16 133 | # strategy: deepspeed_stage_2 134 | trainer: 135 | benchmark: True 136 | accumulate_grad_batches: 1 137 | max_steps: 100000 138 | # logger 139 | log_every_n_steps: 50 140 | # val 141 | val_check_interval: 0.5 142 | gradient_clip_algorithm: 'norm' 143 | gradient_clip_val: 0.5 144 | callbacks: 145 | model_checkpoint: 146 | target: pytorch_lightning.callbacks.ModelCheckpoint 147 | params: 148 | every_n_train_steps: 2000 #1000 149 | filename: "{epoch}-{step}" 150 | 
save_weights_only: True 151 | metrics_over_trainsteps_checkpoint: 152 | target: pytorch_lightning.callbacks.ModelCheckpoint 153 | params: 154 | filename: '{epoch}-{step}' 155 | save_weights_only: True 156 | every_n_train_steps: 4000 157 | batch_logger: 158 | target: callbacks.ImageLogger 159 | params: 160 | batch_frequency: 50000 161 | to_local: False 162 | max_images: 8 163 | log_images_kwargs: 164 | ddim_steps: 50 165 | unconditional_guidance_scale: 7.5 166 | timestep_spacing: uniform_trailing 167 | guidance_rescale: 0.7 168 | -------------------------------------------------------------------------------- /examples/DynamiCrafter/inference_512.sh: -------------------------------------------------------------------------------- 1 | seed=123 2 | 3 | name=inference 4 | 5 | ckpt=ckpt/original/model.ckpt # path to your checkpoint 6 | config=configs/inference_512_v1.0.yaml 7 | 8 | prompt_dir=prompts/512 # file for prompts, which includes images and their corresponding text 9 | res_dir="results" # file for outputs 10 | 11 | 12 | H=320 13 | W=512 14 | FS=24 15 | M=940 16 | 17 | CUDA_VISIBLE_DEVICES=1 python3 -m torch.distributed.launch \ 18 | --nproc_per_node=1 --nnodes=1 --master_addr=127.0.0.1 --master_port=23459 --node_rank=0 \ 19 | scripts/evaluation/ddp_wrapper.py \ 20 | --module 'inference' \ 21 | --seed ${seed} \ 22 | --ckpt_path $ckpt \ 23 | --config $config \ 24 | --savedir $res_dir/$name \ 25 | --n_samples 1 \ 26 | --bs 1 --height ${H} --width ${W} \ 27 | --unconditional_guidance_scale 7.5 \ 28 | --ddim_steps 50 \ 29 | --ddim_eta 1.0 \ 30 | --prompt_dir $prompt_dir \ 31 | --text_input \ 32 | --video_length 16 \ 33 | --frame_stride ${FS} \ 34 | --timestep_spacing 'uniform_trailing' \ 35 | --guidance_rescale 0.7 \ 36 | --perframe_ae \ 37 | --M ${M} \ 38 | --whether_analytic_init 1 \ 39 | --analytic_init_path 'ckpt/initial_noise_512.pt' 40 | -------------------------------------------------------------------------------- /examples/DynamiCrafter/inference_CIL_1024.sh: -------------------------------------------------------------------------------- 1 | seed=123 2 | 3 | name=inference 4 | 5 | ckpt=ckpt/finetuned/timenoise.ckpt # path to your checkpoint 6 | config=configs/inference_1024_v1.0.yaml 7 | 8 | prompt_dir=prompts/1024 # file for prompts, which includes images and their corresponding text 9 | res_dir="results" # file for outputs 10 | 11 | 12 | H=576 13 | W=1024 14 | FS=24 15 | M=1000 16 | 17 | CUDA_VISIBLE_DEVICES=7 python3 -m torch.distributed.launch \ 18 | --nproc_per_node=1 --nnodes=1 --master_addr=127.0.0.1 --master_port=23459 --node_rank=0 \ 19 | scripts/evaluation/ddp_wrapper.py \ 20 | --module 'inference' \ 21 | --seed ${seed} \ 22 | --ckpt_path $ckpt \ 23 | --config $config \ 24 | --savedir $res_dir/$name \ 25 | --n_samples 1 \ 26 | --bs 1 --height ${H} --width ${W} \ 27 | --unconditional_guidance_scale 7.5 \ 28 | --ddim_steps 50 \ 29 | --ddim_eta 1.0 \ 30 | --prompt_dir $prompt_dir \ 31 | --text_input \ 32 | --video_length 16 \ 33 | --frame_stride ${FS} \ 34 | --timestep_spacing 'uniform_trailing' \ 35 | --guidance_rescale 0.7 \ 36 | --perframe_ae \ 37 | --M ${M} \ 38 | --whether_analytic_init 1 \ 39 | --analytic_init_path 'ckpt/initial_noise_1024.pt' 40 | 41 | 42 | -------------------------------------------------------------------------------- /examples/DynamiCrafter/inference_CIL_512.sh: -------------------------------------------------------------------------------- 1 | seed=123 2 | 3 | name=inference 4 | 5 | ckpt=ckpt/finetuned/timenoise.ckpt # path to your checkpoint 6 
| config=configs/inference_512_v1.0.yaml 7 | 8 | prompt_dir=prompts/512 # file for prompts, which includes images and their corresponding text 9 | res_dir="results" # file for outputs 10 | 11 | 12 | H=320 13 | W=512 14 | FS=24 15 | M=1000 16 | 17 | CUDA_VISIBLE_DEVICES=1 python3 -m torch.distributed.launch \ 18 | --nproc_per_node=1 --nnodes=1 --master_addr=127.0.0.1 --master_port=23459 --node_rank=0 \ 19 | scripts/evaluation/ddp_wrapper.py \ 20 | --module 'inference' \ 21 | --seed ${seed} \ 22 | --ckpt_path $ckpt \ 23 | --config $config \ 24 | --savedir $res_dir/$name \ 25 | --n_samples 1 \ 26 | --bs 1 --height ${H} --width ${W} \ 27 | --unconditional_guidance_scale 7.5 \ 28 | --ddim_steps 50 \ 29 | --ddim_eta 1.0 \ 30 | --prompt_dir $prompt_dir \ 31 | --text_input \ 32 | --video_length 16 \ 33 | --frame_stride ${FS} \ 34 | --timestep_spacing 'uniform_trailing' \ 35 | --guidance_rescale 0.7 \ 36 | --perframe_ae \ 37 | --M ${M} \ 38 | --whether_analytic_init 1 \ 39 | --analytic_init_path 'ckpt/initial_noise_512.pt' 40 | -------------------------------------------------------------------------------- /examples/DynamiCrafter/lvdm/basics.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 9 | 10 | import torch.nn as nn 11 | from utils.utils import instantiate_from_config 12 | 13 | 14 | def disabled_train(self, mode=True): 15 | """Overwrite model.train with this function to make sure train/eval mode 16 | does not change anymore.""" 17 | return self 18 | 19 | def zero_module(module): 20 | """ 21 | Zero out the parameters of a module and return it. 22 | """ 23 | for p in module.parameters(): 24 | p.detach().zero_() 25 | return module 26 | 27 | def scale_module(module, scale): 28 | """ 29 | Scale the parameters of a module and return it. 30 | """ 31 | for p in module.parameters(): 32 | p.detach().mul_(scale) 33 | return module 34 | 35 | 36 | def conv_nd(dims, *args, **kwargs): 37 | """ 38 | Create a 1D, 2D, or 3D convolution module. 39 | """ 40 | if dims == 1: 41 | return nn.Conv1d(*args, **kwargs) 42 | elif dims == 2: 43 | return nn.Conv2d(*args, **kwargs) 44 | elif dims == 3: 45 | return nn.Conv3d(*args, **kwargs) 46 | raise ValueError(f"unsupported dimensions: {dims}") 47 | 48 | 49 | def linear(*args, **kwargs): 50 | """ 51 | Create a linear module. 52 | """ 53 | return nn.Linear(*args, **kwargs) 54 | 55 | 56 | def avg_pool_nd(dims, *args, **kwargs): 57 | """ 58 | Create a 1D, 2D, or 3D average pooling module. 
59 | """ 60 | if dims == 1: 61 | return nn.AvgPool1d(*args, **kwargs) 62 | elif dims == 2: 63 | return nn.AvgPool2d(*args, **kwargs) 64 | elif dims == 3: 65 | return nn.AvgPool3d(*args, **kwargs) 66 | raise ValueError(f"unsupported dimensions: {dims}") 67 | 68 | 69 | def nonlinearity(type='silu'): 70 | if type == 'silu': 71 | return nn.SiLU() 72 | elif type == 'leaky_relu': 73 | return nn.LeakyReLU() 74 | 75 | 76 | class GroupNormSpecific(nn.GroupNorm): 77 | def forward(self, x): 78 | return super().forward(x.float()).type(x.dtype) 79 | 80 | 81 | def normalization(channels, num_groups=32): 82 | """ 83 | Make a standard normalization layer. 84 | :param channels: number of input channels. 85 | :return: an nn.Module for normalization. 86 | """ 87 | return GroupNormSpecific(num_groups, channels) 88 | 89 | 90 | class HybridConditioner(nn.Module): 91 | 92 | def __init__(self, c_concat_config, c_crossattn_config): 93 | super().__init__() 94 | self.concat_conditioner = instantiate_from_config(c_concat_config) 95 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 96 | 97 | def forward(self, c_concat, c_crossattn): 98 | c_concat = self.concat_conditioner(c_concat) 99 | c_crossattn = self.crossattn_conditioner(c_crossattn) 100 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} -------------------------------------------------------------------------------- /examples/DynamiCrafter/lvdm/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | from inspect import isfunction 3 | import torch 4 | from torch import nn 5 | import torch.distributed as dist 6 | 7 | 8 | def gather_data(data, return_np=True): 9 | ''' gather data from multiple processes to one list ''' 10 | data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size())] 11 | dist.all_gather(data_list, data) # gather not supported with NCCL 12 | if return_np: 13 | data_list = [data.cpu().numpy() for data in data_list] 14 | return data_list 15 | 16 | def autocast(f): 17 | def do_autocast(*args, **kwargs): 18 | with torch.cuda.amp.autocast(enabled=True, 19 | dtype=torch.get_autocast_gpu_dtype(), 20 | cache_enabled=torch.is_autocast_cache_enabled()): 21 | return f(*args, **kwargs) 22 | return do_autocast 23 | 24 | 25 | def extract_into_tensor(a, t, x_shape): 26 | b, *_ = t.shape 27 | out = a.gather(-1, t) 28 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 29 | 30 | 31 | def noise_like(shape, device, repeat=False): 32 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 33 | noise = lambda: torch.randn(shape, device=device) 34 | return repeat_noise() if repeat else noise() 35 | 36 | 37 | def default(val, d): 38 | if exists(val): 39 | return val 40 | return d() if isfunction(d) else d 41 | 42 | def exists(val): 43 | return val is not None 44 | 45 | def identity(*args, **kwargs): 46 | return nn.Identity() 47 | 48 | def uniq(arr): 49 | return{el: True for el in arr}.keys() 50 | 51 | def mean_flat(tensor): 52 | """ 53 | Take the mean over all non-batch dimensions. 
54 | """ 55 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 56 | 57 | def ismap(x): 58 | if not isinstance(x, torch.Tensor): 59 | return False 60 | return (len(x.shape) == 4) and (x.shape[1] > 3) 61 | 62 | def isimage(x): 63 | if not isinstance(x,torch.Tensor): 64 | return False 65 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 66 | 67 | def max_neg_value(t): 68 | return -torch.finfo(t.dtype).max 69 | 70 | def shape_to_str(x): 71 | shape_str = "x".join([str(x) for x in x.shape]) 72 | return shape_str 73 | 74 | def init_(tensor): 75 | dim = tensor.shape[-1] 76 | std = 1 / math.sqrt(dim) 77 | tensor.uniform_(-std, std) 78 | return tensor 79 | 80 | ckpt = torch.utils.checkpoint.checkpoint 81 | def checkpoint(func, inputs, params, flag): 82 | """ 83 | Evaluate a function without caching intermediate activations, allowing for 84 | reduced memory at the expense of extra compute in the backward pass. 85 | :param func: the function to evaluate. 86 | :param inputs: the argument sequence to pass to `func`. 87 | :param params: a sequence of parameters `func` depends on but does not 88 | explicitly take as arguments. 89 | :param flag: if False, disable gradient checkpointing. 90 | """ 91 | if flag: 92 | return ckpt(func, *inputs, use_reentrant=False) 93 | else: 94 | return func(*inputs) -------------------------------------------------------------------------------- /examples/DynamiCrafter/lvdm/data/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from torch.utils.data import IterableDataset 3 | 4 | 5 | class Txt2ImgIterableBaseDataset(IterableDataset): 6 | ''' 7 | Define an interface to make the IterableDatasets for text2img data chainable 8 | ''' 9 | def __init__(self, num_records=0, valid_ids=None, size=256): 10 | super().__init__() 11 | self.num_records = num_records 12 | self.valid_ids = valid_ids 13 | self.sample_ids = valid_ids 14 | self.size = size 15 | 16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') 17 | 18 | def __len__(self): 19 | return self.num_records 20 | 21 | @abstractmethod 22 | def __iter__(self): 23 | pass -------------------------------------------------------------------------------- /examples/DynamiCrafter/lvdm/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self, noise=None): 36 | if noise is None: 37 | noise = torch.randn(self.mean.shape) 38 | 39 | x = self.mean + self.std * noise.to(device=self.parameters.device) 40 | return x 41 | 42 | def kl(self, 
other=None): 43 | if self.deterministic: 44 | return torch.Tensor([0.]) 45 | else: 46 | if other is None: 47 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 48 | + self.var - 1.0 - self.logvar, 49 | dim=[1, 2, 3]) 50 | else: 51 | return 0.5 * torch.sum( 52 | torch.pow(self.mean - other.mean, 2) / other.var 53 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 54 | dim=[1, 2, 3]) 55 | 56 | def nll(self, sample, dims=[1,2,3]): 57 | if self.deterministic: 58 | return torch.Tensor([0.]) 59 | logtwopi = np.log(2.0 * np.pi) 60 | return 0.5 * torch.sum( 61 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 62 | dim=dims) 63 | 64 | def mode(self): 65 | return self.mean 66 | 67 | 68 | def normal_kl(mean1, logvar1, mean2, logvar2): 69 | """ 70 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 71 | Compute the KL divergence between two gaussians. 72 | Shapes are automatically broadcasted, so batches can be compared to 73 | scalars, among other use cases. 74 | """ 75 | tensor = None 76 | for obj in (mean1, logvar1, mean2, logvar2): 77 | if isinstance(obj, torch.Tensor): 78 | tensor = obj 79 | break 80 | assert tensor is not None, "at least one argument must be a Tensor" 81 | 82 | # Force variances to be Tensors. Broadcasting helps convert scalars to 83 | # Tensors, but it does not work for torch.exp(). 84 | logvar1, logvar2 = [ 85 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 86 | for x in (logvar1, logvar2) 87 | ] 88 | 89 | return 0.5 * ( 90 | -1.0 91 | + logvar2 92 | - logvar1 93 | + torch.exp(logvar1 - logvar2) 94 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 95 | ) -------------------------------------------------------------------------------- /examples/DynamiCrafter/lvdm/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key 
in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) -------------------------------------------------------------------------------- /examples/DynamiCrafter/lvdm/modules/encoders/resampler.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py 2 | # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py 3 | # and https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class ImageProjModel(nn.Module): 10 | """Projection Model""" 11 | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): 12 | super().__init__() 13 | self.cross_attention_dim = cross_attention_dim 14 | self.clip_extra_context_tokens = clip_extra_context_tokens 15 | self.proj = nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) 16 | self.norm = nn.LayerNorm(cross_attention_dim) 17 | 18 | def forward(self, image_embeds): 19 | #embeds = image_embeds 20 | embeds = image_embeds.type(list(self.proj.parameters())[0].dtype) 21 | clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim) 22 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens) 23 | return clip_extra_context_tokens 24 | 25 | 26 | # FFN 27 | def FeedForward(dim, mult=4): 28 | inner_dim = int(dim * mult) 29 | return nn.Sequential( 30 | nn.LayerNorm(dim), 31 | nn.Linear(dim, inner_dim, bias=False), 32 | nn.GELU(), 33 | nn.Linear(inner_dim, dim, bias=False), 34 | ) 35 | 36 | 37 | def reshape_tensor(x, heads): 38 | bs, length, width = x.shape 39 | #(bs, length, width) --> (bs, length, n_heads, dim_per_head) 40 | x = x.view(bs, length, heads, -1) 41 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) 42 | x = x.transpose(1, 2) 43 | # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) 44 | x = x.reshape(bs, heads, length, -1) 45 | return x 46 | 47 | 48 | class PerceiverAttention(nn.Module): 49 | def __init__(self, *, dim, dim_head=64, heads=8): 50 | super().__init__() 51 | self.scale = dim_head**-0.5 52 | self.dim_head = dim_head 53 | self.heads = heads 54 | inner_dim = dim_head * heads 55 | 56 | self.norm1 = nn.LayerNorm(dim) 57 | self.norm2 = nn.LayerNorm(dim) 58 | 59 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 60 | 
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 61 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 62 | 63 | 64 | def forward(self, x, latents): 65 | """ 66 | Args: 67 | x (torch.Tensor): image features 68 | shape (b, n1, D) 69 | latent (torch.Tensor): latent features 70 | shape (b, n2, D) 71 | """ 72 | x = self.norm1(x) 73 | latents = self.norm2(latents) 74 | 75 | b, l, _ = latents.shape 76 | 77 | q = self.to_q(latents) 78 | kv_input = torch.cat((x, latents), dim=-2) 79 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 80 | 81 | q = reshape_tensor(q, self.heads) 82 | k = reshape_tensor(k, self.heads) 83 | v = reshape_tensor(v, self.heads) 84 | 85 | # attention 86 | scale = 1 / math.sqrt(math.sqrt(self.dim_head)) 87 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards 88 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 89 | out = weight @ v 90 | 91 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1) 92 | 93 | return self.to_out(out) 94 | 95 | 96 | class Resampler(nn.Module): 97 | def __init__( 98 | self, 99 | dim=1024, 100 | depth=8, 101 | dim_head=64, 102 | heads=16, 103 | num_queries=8, 104 | embedding_dim=768, 105 | output_dim=1024, 106 | ff_mult=4, 107 | video_length=None, # using frame-wise version or not 108 | ): 109 | super().__init__() 110 | ## queries for a single frame / image 111 | self.num_queries = num_queries 112 | self.video_length = video_length 113 | 114 | ## queries for each frame 115 | if video_length is not None: 116 | num_queries = num_queries * video_length 117 | 118 | self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) 119 | self.proj_in = nn.Linear(embedding_dim, dim) 120 | self.proj_out = nn.Linear(dim, output_dim) 121 | self.norm_out = nn.LayerNorm(output_dim) 122 | 123 | self.layers = nn.ModuleList([]) 124 | for _ in range(depth): 125 | self.layers.append( 126 | nn.ModuleList( 127 | [ 128 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 129 | FeedForward(dim=dim, mult=ff_mult), 130 | ] 131 | ) 132 | ) 133 | 134 | def forward(self, x): 135 | latents = self.latents.repeat(x.size(0), 1, 1) ## B (T L) C 136 | x = self.proj_in(x) 137 | 138 | for attn, ff in self.layers: 139 | latents = attn(x, latents) + latents 140 | latents = ff(latents) + latents 141 | 142 | latents = self.proj_out(latents) 143 | latents = self.norm_out(latents) # B L C or B (T L) C 144 | 145 | return latents -------------------------------------------------------------------------------- /examples/DynamiCrafter/main/callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import logging 4 | mainlogger = logging.getLogger('mainlogger') 5 | 6 | import torch 7 | import torchvision 8 | import pytorch_lightning as pl 9 | from pytorch_lightning.callbacks import Callback 10 | from pytorch_lightning.utilities import rank_zero_only 11 | from pytorch_lightning.utilities import rank_zero_info 12 | from utils.save_video import log_local, prepare_to_log 13 | 14 | 15 | class ImageLogger(Callback): 16 | def __init__(self, batch_frequency, max_images=8, clamp=True, rescale=True, save_dir=None, \ 17 | to_local=False, log_images_kwargs=None): 18 | super().__init__() 19 | self.rescale = rescale 20 | self.batch_freq = batch_frequency 21 | self.max_images = max_images 22 | self.to_local = to_local 23 | self.clamp = clamp 24 | self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} 25 | if self.to_local: 26 | ## default 
save dir 27 | self.save_dir = os.path.join(save_dir, "images") 28 | os.makedirs(os.path.join(self.save_dir, "train"), exist_ok=True) 29 | os.makedirs(os.path.join(self.save_dir, "val"), exist_ok=True) 30 | 31 | def log_to_tensorboard(self, pl_module, batch_logs, filename, split, save_fps=8): 32 | """ log images and videos to tensorboard """ 33 | global_step = pl_module.global_step 34 | for key in batch_logs: 35 | value = batch_logs[key] 36 | tag = "gs%d-%s/%s-%s"%(global_step, split, filename, key) 37 | if isinstance(value, list) and isinstance(value[0], str): 38 | captions = ' |------| '.join(value) 39 | pl_module.logger.experiment.add_text(tag, captions, global_step=global_step) 40 | elif isinstance(value, torch.Tensor) and value.dim() == 5: 41 | video = value 42 | n = video.shape[0] 43 | video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w 44 | frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0) for framesheet in video] #[3, n*h, 1*w] 45 | grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [t, 3, n*h, w] 46 | grid = (grid + 1.0) / 2.0 47 | grid = grid.unsqueeze(dim=0) 48 | pl_module.logger.experiment.add_video(tag, grid, fps=save_fps, global_step=global_step) 49 | elif isinstance(value, torch.Tensor) and value.dim() == 4: 50 | img = value 51 | grid = torchvision.utils.make_grid(img, nrow=int(n), padding=0) 52 | grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w 53 | pl_module.logger.experiment.add_image(tag, grid, global_step=global_step) 54 | else: 55 | pass 56 | 57 | @rank_zero_only 58 | def log_batch_imgs(self, pl_module, batch, batch_idx, split="train"): 59 | """ generate images, then save and log to tensorboard """ 60 | skip_freq = self.batch_freq if split == "train" else 5 61 | if (batch_idx+1) % skip_freq == 0: 62 | is_train = pl_module.training 63 | if is_train: 64 | pl_module.eval() 65 | torch.cuda.empty_cache() 66 | with torch.no_grad(): 67 | log_func = pl_module.log_images 68 | batch_logs = log_func(batch, split=split, **self.log_images_kwargs) 69 | 70 | ## process: move to CPU and clamp 71 | batch_logs = prepare_to_log(batch_logs, self.max_images, self.clamp) 72 | torch.cuda.empty_cache() 73 | 74 | filename = "ep{}_idx{}_rank{}".format( 75 | pl_module.current_epoch, 76 | batch_idx, 77 | pl_module.global_rank) 78 | if self.to_local: 79 | mainlogger.info("Log [%s] batch <%s> to local ..."%(split, filename)) 80 | filename = "gs{}_".format(pl_module.global_step) + filename 81 | log_local(batch_logs, os.path.join(self.save_dir, split), filename, save_fps=10) 82 | else: 83 | mainlogger.info("Log [%s] batch <%s> to tensorboard ..."%(split, filename)) 84 | self.log_to_tensorboard(pl_module, batch_logs, filename, split, save_fps=10) 85 | mainlogger.info('Finish!') 86 | 87 | if is_train: 88 | pl_module.train() 89 | 90 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None): 91 | if self.batch_freq != -1 and pl_module.logdir: 92 | self.log_batch_imgs(pl_module, batch, batch_idx, split="train") 93 | 94 | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None): 95 | ## different with validation_step() that saving the whole validation set and only keep the latest, 96 | ## it records the performance of every validation (without overwritten) by only keep a subset 97 | if self.batch_freq != -1 and pl_module.logdir: 98 | self.log_batch_imgs(pl_module, batch, batch_idx, split="val") 99 | if hasattr(pl_module, 'calibrate_grad_norm'): 100 | if (pl_module.calibrate_grad_norm and 
batch_idx % 25 == 0) and batch_idx > 0: 101 | self.log_gradients(trainer, pl_module, batch_idx=batch_idx) 102 | 103 | 104 | class CUDACallback(Callback): 105 | # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py 106 | def on_train_epoch_start(self, trainer, pl_module): 107 | # Reset the memory use counter 108 | # lightning update 109 | if int((pl.__version__).split('.')[1])>=7: 110 | gpu_index = trainer.strategy.root_device.index 111 | else: 112 | gpu_index = trainer.root_gpu 113 | torch.cuda.reset_peak_memory_stats(gpu_index) 114 | torch.cuda.synchronize(gpu_index) 115 | self.start_time = time.time() 116 | 117 | def on_train_epoch_end(self, trainer, pl_module): 118 | if int((pl.__version__).split('.')[1])>=7: 119 | gpu_index = trainer.strategy.root_device.index 120 | else: 121 | gpu_index = trainer.root_gpu 122 | torch.cuda.synchronize(gpu_index) 123 | max_memory = torch.cuda.max_memory_allocated(gpu_index) / 2 ** 20 124 | epoch_time = time.time() - self.start_time 125 | 126 | try: 127 | max_memory = trainer.training_type_plugin.reduce(max_memory) 128 | epoch_time = trainer.training_type_plugin.reduce(epoch_time) 129 | 130 | rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds") 131 | rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB") 132 | except AttributeError: 133 | pass 134 | -------------------------------------------------------------------------------- /examples/DynamiCrafter/main/utils_data.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import numpy as np 3 | 4 | import torch 5 | import pytorch_lightning as pl 6 | from torch.utils.data import DataLoader, Dataset 7 | 8 | import os, sys 9 | os.chdir(sys.path[0]) 10 | sys.path.append("..") 11 | from lvdm.data.base import Txt2ImgIterableBaseDataset 12 | from utils.utils import instantiate_from_config 13 | 14 | 15 | def worker_init_fn(_): 16 | worker_info = torch.utils.data.get_worker_info() 17 | 18 | dataset = worker_info.dataset 19 | worker_id = worker_info.id 20 | 21 | if isinstance(dataset, Txt2ImgIterableBaseDataset): 22 | split_size = dataset.num_records // worker_info.num_workers 23 | # reset num_records to the true number to retain reliable length information 24 | dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size] 25 | current_id = np.random.choice(len(np.random.get_state()[1]), 1) 26 | return np.random.seed(np.random.get_state()[1][current_id] + worker_id) 27 | else: 28 | return np.random.seed(np.random.get_state()[1][0] + worker_id) 29 | 30 | 31 | class WrappedDataset(Dataset): 32 | """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset""" 33 | 34 | def __init__(self, dataset): 35 | self.data = dataset 36 | 37 | def __len__(self): 38 | return len(self.data) 39 | 40 | def __getitem__(self, idx): 41 | return self.data[idx] 42 | 43 | 44 | class DataModuleFromConfig(pl.LightningDataModule): 45 | def __init__(self, batch_size, train=None, validation=None, test=None, predict=None, 46 | wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False, 47 | shuffle_val_dataloader=False, train_img=None, 48 | test_max_n_samples=None): 49 | super().__init__() 50 | self.batch_size = batch_size 51 | self.dataset_configs = dict() 52 | self.num_workers = num_workers if num_workers is not None else batch_size * 2 53 | self.use_worker_init_fn = use_worker_init_fn 54 | if train is not None: 55 | self.dataset_configs["train"] = train 56 | 
self.train_dataloader = self._train_dataloader 57 | if validation is not None: 58 | self.dataset_configs["validation"] = validation 59 | self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader) 60 | if test is not None: 61 | self.dataset_configs["test"] = test 62 | self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader) 63 | if predict is not None: 64 | self.dataset_configs["predict"] = predict 65 | self.predict_dataloader = self._predict_dataloader 66 | 67 | self.img_loader = None 68 | self.wrap = wrap 69 | self.test_max_n_samples = test_max_n_samples 70 | self.collate_fn = None 71 | 72 | def prepare_data(self): 73 | pass 74 | 75 | def setup(self, stage=None): 76 | self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs) 77 | if self.wrap: 78 | for k in self.datasets: 79 | self.datasets[k] = WrappedDataset(self.datasets[k]) 80 | 81 | def _train_dataloader(self): 82 | is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) 83 | if is_iterable_dataset or self.use_worker_init_fn: 84 | init_fn = worker_init_fn 85 | else: 86 | init_fn = None 87 | loader = DataLoader(self.datasets["train"], batch_size=self.batch_size, 88 | num_workers=self.num_workers, shuffle=False if is_iterable_dataset else True, 89 | worker_init_fn=init_fn, collate_fn=self.collate_fn, 90 | ) 91 | return loader 92 | 93 | def _val_dataloader(self, shuffle=False): 94 | if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: 95 | init_fn = worker_init_fn 96 | else: 97 | init_fn = None 98 | return DataLoader(self.datasets["validation"], 99 | batch_size=self.batch_size, 100 | num_workers=self.num_workers, 101 | worker_init_fn=init_fn, 102 | shuffle=shuffle, 103 | collate_fn=self.collate_fn, 104 | ) 105 | 106 | def _test_dataloader(self, shuffle=False): 107 | try: 108 | is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) 109 | except: 110 | is_iterable_dataset = isinstance(self.datasets['test'], Txt2ImgIterableBaseDataset) 111 | 112 | if is_iterable_dataset or self.use_worker_init_fn: 113 | init_fn = worker_init_fn 114 | else: 115 | init_fn = None 116 | 117 | # do not shuffle dataloader for iterable dataset 118 | shuffle = shuffle and (not is_iterable_dataset) 119 | if self.test_max_n_samples is not None: 120 | dataset = torch.utils.data.Subset(self.datasets["test"], list(range(self.test_max_n_samples))) 121 | else: 122 | dataset = self.datasets["test"] 123 | return DataLoader(dataset, batch_size=self.batch_size, 124 | num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle, 125 | collate_fn=self.collate_fn, 126 | ) 127 | 128 | def _predict_dataloader(self, shuffle=False): 129 | if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: 130 | init_fn = worker_init_fn 131 | else: 132 | init_fn = None 133 | return DataLoader(self.datasets["predict"], batch_size=self.batch_size, 134 | num_workers=self.num_workers, worker_init_fn=init_fn, 135 | collate_fn=self.collate_fn, 136 | ) 137 | -------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/1024/14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/1024/14.png 
-------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/1024/18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/1024/18.png -------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/1024/25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/1024/25.png -------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/1024/29.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/1024/29.png -------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/1024/30.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/1024/30.png -------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/1024/32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/1024/32.png -------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/1024/33.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/1024/33.png -------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/1024/35.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/1024/35.png -------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/1024/36.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/1024/36.png -------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/1024/41.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/1024/41.png -------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/1024/47.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/1024/47.png 
-------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/1024/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/1024/5.png -------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/1024/52.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/1024/52.png -------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/1024/65.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/1024/65.png -------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/1024/A panda wearing sunglasses walking in slow-motion under water, in photorealistic style..png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/1024/A panda wearing sunglasses walking in slow-motion under water, in photorealistic style..png -------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/1024/test_prompts.txt: -------------------------------------------------------------------------------- 1 | A duck swimming in the lake. 2 | A man riding motor on a mountain road. 3 | A couple hugging a cat. 4 | Mystical hills with a glowing blue portal. 5 | A soldier riding a horse. 6 | Mountains under the starlight. 7 | A duck swimming in the lake. 8 | A girl walks up the steps of a palace. 9 | A woman with flowing, curly silver hair and dark eyes. 10 | Rabbits playing in a river. 11 | A plate full of food, with camera spinning. 12 | A cartoon girl with brown curly hair splashes joyfully in a bubble-filled bathtub. 13 | Fireworks exploding in the sky. 14 | A duck swimming in the lake. 15 | A panda wearing sunglasses walking in slow-motion under water, in photorealistic style. 
-------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/512/14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/512/14.png -------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/512/15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/512/15.png -------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/512/18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/512/18.png -------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/512/25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/512/25.png -------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/512/29.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/512/29.png -------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/512/30.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/512/30.png -------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/512/32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/512/32.png -------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/512/33.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/512/33.png -------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/512/35.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/512/35.png -------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/512/36.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/512/36.png 
-------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/512/41.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/512/41.png -------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/512/43.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/512/43.png -------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/512/47.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/512/47.png -------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/512/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/512/5.png -------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/512/52.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/512/52.png -------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/512/55.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/512/55.png -------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/512/65.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/512/65.png -------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/512/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/512/7.png -------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/512/A girl with long curly blonde hair and sunglasses, camera pan from left to right..png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/512/A girl with long curly blonde hair and sunglasses, camera pan from left to right..png -------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/512/A panda wearing sunglasses walking in slow-motion under water, in photorealistic style..png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/512/A panda wearing sunglasses walking in slow-motion under water, in photorealistic style..png -------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/512/a car parked in a parking lot with palm trees nearby,calm seas and skies..png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/prompts/512/a car parked in a parking lot with palm trees nearby,calm seas and skies..png -------------------------------------------------------------------------------- /examples/DynamiCrafter/prompts/512/test_prompts.txt: -------------------------------------------------------------------------------- 1 | A duck swimming in the lake. 2 | A duck swimming in the lake. 3 | A man riding motor on a mountain road. 4 | A couple hugging a cat. 5 | Mystical hills with a glowing blue portal. 6 | A soldier riding a horse. 7 | Mountains under the starlight. 8 | A duck swimming in the lake. 9 | A girl walks up the steps of a palace. 10 | A woman with flowing, curly silver hair and dark eyes. 11 | Rabbits playing in a river. 12 | Sailing of boats on the water surface. 13 | A plate full of food, with camera spinning. 14 | A cartoon girl with brown curly hair splashes joyfully in a bubble-filled bathtub. 15 | Fireworks exploding in the sky. 16 | A kitten lying on the bed. 17 | A duck swimming in the lake. 18 | Donkeys in traditional attire gallop across a lush green meadow. 19 | A girl with long curly blonde hair and sunglasses, camera pan from left to right. 20 | A panda wearing sunglasses walking in slow-motion under water, in photorealistic style. 21 | a car parked in a parking lot with palm trees nearby,calm seas and skies. 
22 | -------------------------------------------------------------------------------- /examples/DynamiCrafter/requirements.txt: -------------------------------------------------------------------------------- 1 | aiofiles==23.2.1 2 | aiohttp==3.9.3 3 | aiosignal==1.3.1 4 | altair==5.2.0 5 | annotated-types==0.6.0 6 | antlr4-python3-runtime==4.8 7 | anyio==4.3.0 8 | APScheduler==3.9.1 9 | asttokens==2.0.5 10 | async-timeout==4.0.3 11 | attrdict==2.0.1 12 | attrs==23.2.0 13 | av==12.0.0 14 | backcall==0.2.0 15 | backports.zoneinfo==0.2.1 16 | certifi==2024.2.2 17 | charset-normalizer==3.3.2 18 | click==8.1.7 19 | cmake==3.28.3 20 | colorama==0.4.4 21 | contourpy==1.1.1 22 | cycler==0.11.0 23 | decorator==5.1.1 24 | decord==0.6.0 25 | einops==0.3.0 26 | exceptiongroup==1.2.0 27 | executing==0.8.3 28 | fairscale==0.4.13 29 | fastapi==0.110.0 30 | ffmpy==0.3.2 31 | filelock==3.13.1 32 | fire==0.6.0 33 | fonttools==4.33.3 34 | frozenlist==1.4.1 35 | fsspec==2024.3.1 36 | ftfy==6.2.0 37 | gradio==4.22.0 38 | gradio_client==0.13.0 39 | h11==0.14.0 40 | httpcore==1.0.4 41 | httpx==0.27.0 42 | huggingface-hub==0.21.4 43 | idna==3.6 44 | igraph==0.9.11 45 | imageio==2.9.0 46 | imageio-ffmpeg==0.4.9 47 | importlib_resources==6.3.2 48 | install==1.3.5 49 | ipython==8.4.0 50 | jedi==0.18.1 51 | Jinja2==3.1.3 52 | joblib==1.3.2 53 | jsonlines==4.0.0 54 | jsonschema==4.21.1 55 | jsonschema-specifications==2023.12.1 56 | kaleido==0.2.1 57 | kiwisolver==1.4.2 58 | kornia==0.7.2 59 | kornia_rs==0.1.2 60 | lightning-utilities==0.3.0 61 | lit==18.1.1 62 | markdown-it-py==3.0.0 63 | MarkupSafe==2.1.5 64 | matplotlib==3.5.2 65 | matplotlib-inline==0.1.3 66 | mdurl==0.1.2 67 | moviepy==1.0.3 68 | mpmath==1.2.1 69 | multidict==6.0.5 70 | mypy-extensions==1.0.0 71 | networkx==3.1 72 | numpy==1.22.4 73 | nvidia-cublas-cu11==11.10.3.66 74 | nvidia-cuda-cupti-cu11==11.7.101 75 | nvidia-cuda-nvrtc-cu11==11.7.99 76 | nvidia-cuda-runtime-cu11==11.7.99 77 | nvidia-cudnn-cu11==8.5.0.96 78 | nvidia-cufft-cu11==10.9.0.58 79 | nvidia-curand-cu11==10.2.10.91 80 | nvidia-cusolver-cu11==11.4.0.1 81 | nvidia-cusparse-cu11==11.7.4.91 82 | nvidia-nccl-cu11==2.14.3 83 | nvidia-nvtx-cu11==11.7.91 84 | omegaconf==2.1.1 85 | open-clip-torch==2.22.0 86 | opencv-python==4.9.0.80 87 | opencv-python-headless==4.9.0.80 88 | orjson==3.9.15 89 | packaging==21.3 90 | pandas==2.0.0 91 | parso==0.8.3 92 | pexpect==4.9.0 93 | pickleshare==0.7.5 94 | Pillow==9.1.1 95 | pkgutil_resolve_name==1.3.10 96 | plotly==5.8.2 97 | proglog==0.1.10 98 | prompt-toolkit==3.0.29 99 | protobuf==3.20.3 100 | ptyprocess==0.7.0 101 | pure-eval==0.2.2 102 | pydantic==2.6.4 103 | pydantic_core==2.16.3 104 | pydub==0.25.1 105 | Pygments==2.12.0 106 | pyparsing==3.0.9 107 | pyre-extensions==0.0.29 108 | python-dateutil==2.8.2 109 | python-multipart==0.0.9 110 | pytorch-lightning==1.8.3 111 | pytz==2024.1 112 | pytz-deprecation-shim==0.1.0.post0 113 | PyYAML==6.0 114 | referencing==0.34.0 115 | regex==2023.12.25 116 | requests==2.31.0 117 | rich==13.7.1 118 | rpds-py==0.18.0 119 | ruff==0.3.3 120 | safetensors==0.4.2 121 | scikit-learn==1.3.2 122 | scipy==1.10.1 123 | semantic-version==2.10.0 124 | sentencepiece==0.2.0 125 | shellingham==1.5.4 126 | six==1.16.0 127 | sniffio==1.3.1 128 | stack-data==0.2.0 129 | starlette==0.36.3 130 | sympy==1.10.1 131 | tenacity==8.0.1 132 | tensorboardX==2.6.2.2 133 | termcolor==2.4.0 134 | texttable==1.6.4 135 | threadpoolctl==3.4.0 136 | timm==0.9.16 137 | tokenizers==0.19.0 138 | tomlkit==0.12.0 139 | toolz==0.12.1 140 | 
torch==2.0.0 141 | torchmetrics==0.11.4 142 | torchvision==0.15.1 143 | tqdm==4.64.0 144 | traitlets==5.2.2.post1 145 | transformers==4.40.1 146 | triton==2.0.0 147 | typer==0.9.0 148 | typing-inspect==0.9.0 149 | typing_extensions==4.10.0 150 | tzdata==2024.1 151 | tzlocal==4.1 152 | urllib3==2.2.1 153 | uvicorn==0.29.0 154 | wcwidth==0.2.5 155 | websockets==11.0.3 156 | xformers==0.0.19 157 | yarl==1.9.4 158 | zipp==3.18.1 159 | -------------------------------------------------------------------------------- /examples/DynamiCrafter/scripts/evaluation/__pycache__/inference.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/scripts/evaluation/__pycache__/inference.cpython-38.pyc -------------------------------------------------------------------------------- /examples/DynamiCrafter/scripts/evaluation/ddp_wrapper.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import argparse, importlib 3 | from pytorch_lightning import seed_everything 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | def setup_dist(local_rank): 9 | if dist.is_initialized(): 10 | return 11 | torch.cuda.set_device(local_rank) 12 | torch.distributed.init_process_group('nccl', init_method='env://') 13 | 14 | 15 | def get_dist_info(): 16 | if dist.is_available(): 17 | initialized = dist.is_initialized() 18 | else: 19 | initialized = False 20 | if initialized: 21 | rank = dist.get_rank() 22 | world_size = dist.get_world_size() 23 | else: 24 | rank = 0 25 | world_size = 1 26 | return rank, world_size 27 | 28 | 29 | if __name__ == '__main__': 30 | now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument("--module", type=str, help="module name", default="inference") 33 | parser.add_argument("--local_rank", type=int, nargs="?", help="for ddp", default=0) 34 | args, unknown = parser.parse_known_args() 35 | inference_api = importlib.import_module(args.module, package=None) 36 | 37 | inference_parser = inference_api.get_parser() 38 | inference_args, unknown = inference_parser.parse_known_args() 39 | 40 | seed_everything(inference_args.seed) 41 | setup_dist(args.local_rank) 42 | torch.backends.cudnn.benchmark = True 43 | rank, gpu_num = get_dist_info() 44 | 45 | # inference_args.savedir = inference_args.savedir+str('_seed')+str(inference_args.seed) 46 | print("@DynamiCrafter Inference [rank%d]: %s"%(rank, now)) 47 | inference_api.run_inference(inference_args, gpu_num, rank) -------------------------------------------------------------------------------- /examples/DynamiCrafter/train.sh: -------------------------------------------------------------------------------- 1 | # args 2 | name="training_512_v1.0" 3 | config_file="configs/train_512.yaml" 4 | # save root dir for logs, checkpoints, tensorboard record, etc. 
5 | save_root="train" 6 | 7 | mkdir -p $save_root/$name 8 | 9 | CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch \ 10 | --nproc_per_node=4 --nnodes=1 --master_addr=127.0.0.1 --master_port=12352 --node_rank=0 \ 11 | ./main/trainer.py \ 12 | --base $config_file \ 13 | --train \ 14 | --name $name \ 15 | --logdir $save_root \ 16 | --devices 4 \ 17 | lightning.trainer.num_nodes=1 -------------------------------------------------------------------------------- /examples/DynamiCrafter/utils/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/DynamiCrafter/utils/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /examples/DynamiCrafter/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os, importlib # os is required by load_npz_from_dir 2 | import numpy as np 3 | import cv2 4 | import torch 5 | import torch.distributed as dist 6 | 7 | 8 | def count_params(model, verbose=False): 9 | total_params = sum(p.numel() for p in model.parameters()) 10 | if verbose: 11 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 12 | return total_params 13 | 14 | 15 | def check_istarget(name, para_list): 16 | """ 17 | name: full name of source para 18 | para_list: partial name of target para 19 | """ 20 | istarget=False 21 | for para in para_list: 22 | if para in name: 23 | return True 24 | return istarget 25 | 26 | 27 | def instantiate_from_config(config): 28 | if not "target" in config: 29 | if config == '__is_first_stage__': 30 | return None 31 | elif config == "__is_unconditional__": 32 | return None 33 | raise KeyError("Expected key `target` to instantiate.") 34 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 35 | 36 | 37 | def get_obj_from_str(string, reload=False): 38 | module, cls = string.rsplit(".", 1) 39 | if reload: 40 | module_imp = importlib.import_module(module) 41 | importlib.reload(module_imp) 42 | return getattr(importlib.import_module(module, package=None), cls) 43 | 44 | 45 | def load_npz_from_dir(data_dir): 46 | data = [np.load(os.path.join(data_dir, data_name))['arr_0'] for data_name in os.listdir(data_dir)] 47 | data = np.concatenate(data, axis=0) 48 | return data 49 | 50 | 51 | def load_npz_from_paths(data_paths): 52 | data = [np.load(data_path)['arr_0'] for data_path in data_paths] 53 | data = np.concatenate(data, axis=0) 54 | return data 55 | 56 | 57 | def resize_numpy_image(image, max_resolution=512 * 512, resize_short_edge=None): 58 | h, w = image.shape[:2] 59 | if resize_short_edge is not None: 60 | k = resize_short_edge / min(h, w) 61 | else: 62 | k = max_resolution / (h * w) 63 | k = k**0.5 64 | h = int(np.round(h * k / 64)) * 64 65 | w = int(np.round(w * k / 64)) * 64 66 | image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4) 67 | return image 68 | 69 | 70 | def setup_dist(args): 71 | if dist.is_initialized(): 72 | return 73 | torch.cuda.set_device(args.local_rank) 74 | torch.distributed.init_process_group( 75 | 'nccl', 76 | init_method='env://' 77 | ) -------------------------------------------------------------------------------- /examples/SVD/config/inference1024.yaml: -------------------------------------------------------------------------------- 1 | seed: 2 | 42 3 | original_svd: true 4 | pretrained_model_path:
"./ckpt/pretrained/stable-video-diffusion-img2vid/" 5 | batch_size: 1 6 | data_dir: 7 | ./demo 8 | resolution: 9 | - 1024 10 | - 576 11 | checkpoint_root: None 12 | step: 13 | 20000 14 | use_ema: true 15 | motion_bucket_id: 16 | 20.0 17 | fps: 18 | 3 19 | num_step: 20 | 25 21 | sigma_max: 22 | 100 23 | analytic_path: 24 | "./ckpt/initial_noise_1024.pt" 25 | -------------------------------------------------------------------------------- /examples/SVD/config/inference512.yaml: -------------------------------------------------------------------------------- 1 | seed: 2 | 42 3 | original_svd: false 4 | pretrained_model_path: "./ckpt/pretrained/stable-video-diffusion-img2vid/" 5 | batch_size: 1 6 | data_dir: 7 | ./demo 8 | resolution: 9 | - 512 10 | - 320 11 | checkpoint_root: 12 | ./ckpt/finetuned 13 | step: 14 | 20000 15 | use_ema: true 16 | motion_bucket_id: 17 | 20.0 18 | fps: 19 | 5 20 | num_step: 21 | 25 22 | sigma_max: 23 | 100 24 | analytic_path: 25 | ./ckpt/initial_noise_512.pt -------------------------------------------------------------------------------- /examples/SVD/config/train.yaml: -------------------------------------------------------------------------------- 1 | global_seed: 23 2 | motion_bucket_id: 20.0 3 | 4 | noise_scheduler_kwargs: 5 | P_mean: -1.2 6 | P_std: 1.2 7 | sigma_data: 1 8 | beta_m: 15 9 | a: 5 10 | 11 | train_data: 12 | file_path: '../dataset/results_2M_train.csv' 13 | video_folder: '../dataset' 14 | sample_size: 320,512 15 | fps: 3 16 | sample_n_frames: 16 17 | 18 | output_dir: "results/train" 19 | pretrained_model_path: "./ckpt/pretrained/stable-video-diffusion-img2vid/" 20 | resume_path: "" 21 | 22 | use_ema: True 23 | gradient_checkpointing: True 24 | mixed_precision_training: True 25 | 26 | cfg_random_null_ratio: 0.1 27 | learning_rate: 3.e-5 28 | train_batch_size: 1 29 | max_train_steps: 100000 30 | ema_decay: 0.9999 31 | checkpointing_steps: 10000 32 | 33 | validation_folder: "./demo" 34 | validation_steps: 5000 35 | -------------------------------------------------------------------------------- /examples/SVD/demo/1066.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/SVD/demo/1066.jpg -------------------------------------------------------------------------------- /examples/SVD/demo/485.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/SVD/demo/485.jpg -------------------------------------------------------------------------------- /examples/SVD/demo/A 360 shot of a sleek yacht sailing gracefully through the crystal-clear waters of the Caribbean..png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/SVD/demo/A 360 shot of a sleek yacht sailing gracefully through the crystal-clear waters of the Caribbean..png -------------------------------------------------------------------------------- /examples/SVD/demo/A girl with long curly blonde hair and sunglasses, camera pan from left to right..png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/SVD/demo/A girl with long curly blonde hair 
and sunglasses, camera pan from left to right..png -------------------------------------------------------------------------------- /examples/SVD/demo/A panda wearing sunglasses walking in slow-motion under water, in photorealistic style..png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/SVD/demo/A panda wearing sunglasses walking in slow-motion under water, in photorealistic style..png -------------------------------------------------------------------------------- /examples/SVD/demo/A pizza spinning inside a wood fired pizza oven; dramatic vivid colors..png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/SVD/demo/A pizza spinning inside a wood fired pizza oven; dramatic vivid colors..png -------------------------------------------------------------------------------- /examples/SVD/demo/a car parked in a parking lot with palm trees nearby,calm seas and skies..png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/SVD/demo/a car parked in a parking lot with palm trees nearby,calm seas and skies..png -------------------------------------------------------------------------------- /examples/SVD/inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from svd.inference.pipline_CILsvd import StableVideoDiffusionCILPipeline 4 | from svd.schedulers.scheduling_euler_discrete import EulerDiscreteScheduler 5 | from svd.training.utils import save_videos_grid, load_PIL_images 6 | from einops import rearrange 7 | from svd.data.dataset import ImageDataset 8 | from torch.utils.data.distributed import DistributedSampler 9 | from svd.training.utils import init_dist, set_seed 10 | import torch.distributed as dist 11 | from diffusers.training_utils import EMAModel 12 | from svd.models.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel 13 | import argparse 14 | from omegaconf import OmegaConf 15 | 16 | def load_model(model, ema, filename="checkpoint.pth.tar"): 17 | if os.path.isfile(filename): 18 | checkpoint = torch.load(filename, map_location="cpu") 19 | model.load_state_dict(checkpoint['state_dict']) 20 | if 'ema_state' in checkpoint: 21 | ema.load_state_dict(checkpoint['ema_state']) 22 | 23 | def sampling( 24 | data_dir, 25 | scheduler_path, 26 | num_step, 27 | pretrained_model_path, 28 | checkpoint_path, 29 | save_dir, 30 | resolution, 31 | batch_size, 32 | use_ema = True, 33 | num_processes=8, 34 | local_rank=0, 35 | motion_bucket_id=20, 36 | fps=3, 37 | num_frames=16, 38 | seed=42, 39 | analytic_path='' 40 | ): 41 | 42 | scheduler = EulerDiscreteScheduler.from_config(scheduler_path) 43 | 44 | if checkpoint_path is not None: 45 | unet = UNetSpatioTemporalConditionModel.from_pretrained(pretrained_model_path, subfolder="unet", torch_dtype=torch.float16) 46 | ema_unet = EMAModel(unet) 47 | load_model(unet, ema_unet, filename=checkpoint_path) 48 | print(f'the model is loaded from {checkpoint_path}') 49 | if use_ema: 50 | ema_unet.copy_to(unet.parameters()) 51 | pipe = StableVideoDiffusionCILPipeline.from_pretrained(pretrained_model_path, unet=unet,scheduler=scheduler).to('cuda') 52 | else: 
53 | pipe = StableVideoDiffusionCILPipeline.from_pretrained(pretrained_model_path,scheduler=scheduler).to(local_rank) 54 | 55 | # Get the training dataset 56 | dataset = ImageDataset(data_dir) 57 | distributed_sampler = DistributedSampler( 58 | dataset, 59 | shuffle=False, 60 | ) 61 | 62 | # DataLoaders creation: 63 | dataloader = torch.utils.data.DataLoader( 64 | dataset, 65 | batch_size=batch_size, 66 | shuffle=False, 67 | sampler=distributed_sampler, 68 | num_workers=num_processes, 69 | pin_memory=True, 70 | drop_last=False, 71 | ) 72 | 73 | 74 | 75 | print(f'local_rank is {local_rank}') 76 | os.makedirs(save_dir, exist_ok=True) 77 | 78 | 79 | 80 | for step, batch in enumerate(dataloader): 81 | paths = batch['path'] 82 | names = batch['name'] 83 | 84 | images_list = load_PIL_images(paths, resolution) 85 | set_seed(seed) 86 | generator = torch.manual_seed(seed) 87 | samples = pipe(images_list, output_type="pt", generator=generator, height=resolution[1], width=resolution[0], 88 | num_inference_steps=num_step ,num_frames=num_frames,fps=fps,motion_bucket_id=motion_bucket_id,analytic_path=analytic_path 89 | ).frames 90 | 91 | samples = torch.stack(samples) 92 | 93 | for sample, name in zip(samples, names): 94 | name = name.split('.')[0] 95 | save_path = f"{save_dir}/{name}.mp4" 96 | sample = rearrange(sample, "t c h w -> c t h w").unsqueeze(dim=0) 97 | save_videos_grid(sample.cpu(), save_path, n_rows=1) 98 | print(f'the sample has been saved to {save_path}') 99 | 100 | def main( 101 | seed, 102 | resolution, 103 | batch_size, 104 | original_svd, 105 | use_ema, 106 | step, 107 | data_dir, 108 | motion_bucket_id, 109 | fps, 110 | pretrained_model_path, 111 | checkpoint_root, 112 | num_step, 113 | sigma_max, 114 | analytic_path=None 115 | ): 116 | 117 | 118 | # Initialize distributed training 119 | local_rank = init_dist(launcher="pytorch", backend="nccl") 120 | num_processes = dist.get_world_size() 121 | 122 | 123 | scheduler_path=f'./schedulers/scheduler_config{sigma_max}.json' 124 | checkpoint_path = os.path.join(checkpoint_root, f'checkpoint-step-{int(step)}.ckpt') if not original_svd else None 125 | 126 | if checkpoint_path is not None and not os.path.exists(checkpoint_path): 127 | raise FileNotFoundError(f'the checkpoint {checkpoint_path} does not exist') 128 | 129 | task_name =f'sigma_max_{sigma_max}' 130 | save_dir = os.path.join(f'results/inference',task_name) 131 | os.makedirs(save_dir, exist_ok=True) 132 | 133 | 134 | sampling(data_dir=data_dir,num_step=num_step,scheduler_path=scheduler_path,pretrained_model_path=pretrained_model_path, 135 | checkpoint_path=checkpoint_path, use_ema = use_ema,save_dir= save_dir, resolution=tuple(resolution), 136 | batch_size=batch_size, num_processes=num_processes, local_rank=local_rank, 137 | motion_bucket_id=motion_bucket_id,fps=fps,seed=seed,analytic_path=analytic_path) 138 | 139 | 140 | 141 | if __name__ == "__main__": 142 | parser = argparse.ArgumentParser() 143 | parser.add_argument("--config", type=str,help='path to your config file') 144 | 145 | args = parser.parse_args() 146 | 147 | config = OmegaConf.load(args.config) 148 | 149 | main(**config) 150 | 151 | -------------------------------------------------------------------------------- /examples/SVD/inference.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=2,3,4,5,6,7 torchrun --nproc_per_node=6 inference.py --config "config/inference1024.yaml" 2 | --------------------------------------------------------------------------------
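Note on the scheduler configs: inference.py above loads f'./schedulers/scheduler_config{sigma_max}.json', and the scheduler_config*.json files further below differ only in their "sigma_max" field. The following is a minimal, illustrative Python sketch of how this family of configs could be regenerated from the shared base settings; the helper name write_scheduler_configs and the hard-coded sigma_max list are assumptions based on the filenames present here, not code shipped in the repository.

import json
import os

# Base EulerDiscreteScheduler settings shared by every scheduler_config*.json below;
# only "sigma_max" varies between the files.
BASE_CONFIG = {
    "_class_name": "EulerDiscreteScheduler",
    "_diffusers_version": "0.24.0.dev0",
    "beta_end": 0.012,
    "beta_schedule": "scaled_linear",
    "beta_start": 0.00085,
    "clip_sample": False,
    "interpolation_type": "linear",
    "num_train_timesteps": 1000,
    "prediction_type": "v_prediction",
    "set_alpha_to_one": False,
    "sigma_min": 0.002,
    "skip_prk_steps": True,
    "steps_offset": 1,
    "timestep_spacing": "leading",
    "timestep_type": "continuous",
    "trained_betas": None,
    "use_karras_sigmas": True,
}

def write_scheduler_configs(out_dir="./schedulers",
                            sigma_max_values=(1, 10, 20, 30, 40, 50, 60, 70, 80, 90,
                                              100, 300, 500, 700, 900, 1100)):
    """Write one scheduler_config{sigma_max}.json per sigma_max value."""
    os.makedirs(out_dir, exist_ok=True)
    for sigma_max in sigma_max_values:
        config = dict(BASE_CONFIG, sigma_max=float(sigma_max))
        # inference.py builds the path as f'./schedulers/scheduler_config{sigma_max}.json'
        with open(os.path.join(out_dir, f"scheduler_config{sigma_max}.json"), "w") as f:
            json.dump(config, f, indent=2)

if __name__ == "__main__":
    write_scheduler_configs()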
/examples/SVD/inference_CIL_512.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=2,3,4,5,6,7 torchrun --master_port=12345 --nproc_per_node=6 inference.py --config "config/inference512.yaml" 2 | -------------------------------------------------------------------------------- /examples/SVD/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | accelerate==0.29.2 3 | aiohttp==3.9.5 4 | aiosignal==1.3.1 5 | annotated-types==0.6.0 6 | antlr4-python3-runtime==4.9.3 7 | anyio==4.2.0 8 | appdirs==1.4.4 9 | argon2-cffi==23.1.0 10 | argon2-cffi-bindings==21.2.0 11 | arrow==1.3.0 12 | async-lru==2.0.4 13 | async-timeout==4.0.3 14 | attrs==23.2.0 15 | Babel==2.14.0 16 | beautifulsoup4==4.12.3 17 | bleach==6.1.0 18 | cachetools==5.3.2 19 | chainer==7.8.1 20 | click==8.1.7 21 | clickhouse-driver==0.2.6 22 | contextlib2==21.6.0 23 | contourpy==1.2.0 24 | cycler==0.12.1 25 | decorator==4.4.2 26 | decord==0.6.0 27 | deepspeed==0.14.1 28 | defusedxml==0.7.1 29 | diffusers==0.25.1 30 | docker-pycreds==0.4.0 31 | einops==0.7.0 32 | exceptiongroup==1.2.0 33 | fastjsonschema==2.19.1 34 | fonttools==4.47.2 35 | fqdn==1.5.1 36 | frozenlist==1.4.1 37 | fsspec==2023.12.2 38 | ftfy==6.1.3 39 | gensim==4.3.2 40 | gitdb==4.0.11 41 | GitPython==3.1.42 42 | google-auth==2.27.0 43 | google-auth-oauthlib==1.2.0 44 | grpcio==1.60.0 45 | guppy3==3.1.4.post1 46 | h11==0.14.0 47 | h5py==3.10.0 48 | hjson==3.1.0 49 | httpcore==1.0.2 50 | httpx==0.26.0 51 | huggingface-hub==0.20.3 52 | imageio==2.33.1 53 | imageio-ffmpeg==0.4.9 54 | importlib-metadata==7.0.1 55 | importlib-resources==6.1.1 56 | ipython==8.18.1 57 | ipywidgets==8.1.1 58 | isoduration==20.11.0 59 | joblib==1.3.2 60 | json5==0.9.14 61 | jsonlines==4.0.0 62 | jsonpointer==2.4 63 | jsonschema==4.21.1 64 | jsonschema-specifications==2023.12.1 65 | jupyter==1.0.0 66 | jupyter-console==6.6.3 67 | jupyter-events==0.9.0 68 | jupyter-lsp==2.2.2 69 | jupyter_client==8.6.0 70 | jupyter_core==5.7.1 71 | jupyter_server==2.12.5 72 | jupyter_server_terminals==0.5.2 73 | jupyterlab==4.1.0 74 | jupyterlab-widgets==3.0.9 75 | jupyterlab_pygments==0.3.0 76 | jupyterlab_server==2.25.2 77 | kiwisolver==1.4.5 78 | lightning-utilities==0.11.2 79 | lpips==0.1.4 80 | Markdown==3.5.2 81 | matplotlib==3.8.2 82 | mistune==3.0.2 83 | mkl-service==2.4.0 84 | ml-collections==0.1.1 85 | moviepy==1.0.3 86 | multidict==6.0.5 87 | nbclient==0.9.0 88 | nbconvert==7.16.0 89 | nbformat==5.9.2 90 | ninja==1.11.1.1 91 | notebook==7.0.7 92 | notebook_shim==0.2.3 93 | oauthlib==3.2.2 94 | omegaconf==2.3.0 95 | opencv-python==4.9.0.80 96 | opencv-python-headless==4.9.0.80 97 | overrides==7.7.0 98 | pandas==2.2.0 99 | pandocfilters==1.5.1 100 | pika==1.3.2 101 | proglog==0.1.10 102 | prometheus-client==0.19.0 103 | prompt-toolkit==3.0.43 104 | protobuf==4.23.4 105 | py-cpuinfo==9.0.0 106 | pyasn1==0.5.1 107 | pyasn1-modules==0.3.0 108 | pydantic==2.7.0 109 | pydantic_core==2.18.1 110 | pynvml==11.5.0 111 | pyparsing==3.1.1 112 | python-json-logger==2.0.7 113 | pytorch-lightning==2.2.2 114 | pytz==2023.3.post1 115 | qtconsole==5.5.1 116 | QtPy==2.4.1 117 | referencing==0.33.0 118 | regex==2023.12.25 119 | requests-oauthlib==1.3.1 120 | rfc3339-validator==0.1.4 121 | rfc3986-validator==0.1.1 122 | rpds-py==0.17.1 123 | rsa==4.9 124 | safetensors==0.4.2 125 | scikit-learn==1.4.0 126 | scipy==1.12.0 127 | Send2Trash==1.8.2 128 | sentencepiece==0.2.0 129 | sentry-sdk==1.40.6 130 | 
setproctitle==1.3.3 131 | smart-open==7.0.4 132 | smmap==5.0.1 133 | sniffio==1.3.0 134 | soupsieve==2.5 135 | tensorboard==2.15.1 136 | tensorboard-data-server==0.7.2 137 | terminado==0.18.0 138 | threadpoolctl==3.2.0 139 | tinycss2==1.2.1 140 | tokenizers==0.15.1 141 | tomli==2.0.1 142 | torch==2.1.1 143 | torchaudio==2.1.1 144 | torchmetrics==1.3.2 145 | torchvision==0.16.1 146 | tornado==6.4 147 | tqdm==4.66.2 148 | transformers==4.37.1 149 | triton==2.1.0 150 | types-python-dateutil==2.8.19.20240106 151 | tzdata==2023.4 152 | tzlocal==5.2 153 | uri-template==1.3.0 154 | wandb==0.16.3 155 | webcolors==1.13 156 | webencodings==0.5.1 157 | websocket-client==1.7.0 158 | Werkzeug==3.0.1 159 | widgetsnbextension==4.0.9 160 | wrapt==1.16.0 161 | yarl==1.9.4 162 | zipp==3.17.0 163 | -------------------------------------------------------------------------------- /examples/SVD/schedulers/scheduler_config1.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "EulerDiscreteScheduler", 3 | "_diffusers_version": "0.24.0.dev0", 4 | "beta_end": 0.012, 5 | "beta_schedule": "scaled_linear", 6 | "beta_start": 0.00085, 7 | "clip_sample": false, 8 | "interpolation_type": "linear", 9 | "num_train_timesteps": 1000, 10 | "prediction_type": "v_prediction", 11 | "set_alpha_to_one": false, 12 | "sigma_max": 100.0, 13 | "sigma_min": 0.002, 14 | "skip_prk_steps": true, 15 | "steps_offset": 1, 16 | "timestep_spacing": "leading", 17 | "timestep_type": "continuous", 18 | "trained_betas": null, 19 | "use_karras_sigmas": true 20 | } 21 | -------------------------------------------------------------------------------- /examples/SVD/schedulers/scheduler_config10.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "EulerDiscreteScheduler", 3 | "_diffusers_version": "0.24.0.dev0", 4 | "beta_end": 0.012, 5 | "beta_schedule": "scaled_linear", 6 | "beta_start": 0.00085, 7 | "clip_sample": false, 8 | "interpolation_type": "linear", 9 | "num_train_timesteps": 1000, 10 | "prediction_type": "v_prediction", 11 | "set_alpha_to_one": false, 12 | "sigma_max": 10.0, 13 | "sigma_min": 0.002, 14 | "skip_prk_steps": true, 15 | "steps_offset": 1, 16 | "timestep_spacing": "leading", 17 | "timestep_type": "continuous", 18 | "trained_betas": null, 19 | "use_karras_sigmas": true 20 | } 21 | -------------------------------------------------------------------------------- /examples/SVD/schedulers/scheduler_config100.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "EulerDiscreteScheduler", 3 | "_diffusers_version": "0.24.0.dev0", 4 | "beta_end": 0.012, 5 | "beta_schedule": "scaled_linear", 6 | "beta_start": 0.00085, 7 | "clip_sample": false, 8 | "interpolation_type": "linear", 9 | "num_train_timesteps": 1000, 10 | "prediction_type": "v_prediction", 11 | "set_alpha_to_one": false, 12 | "sigma_max": 100.0, 13 | "sigma_min": 0.002, 14 | "skip_prk_steps": true, 15 | "steps_offset": 1, 16 | "timestep_spacing": "leading", 17 | "timestep_type": "continuous", 18 | "trained_betas": null, 19 | "use_karras_sigmas": true 20 | } 21 | -------------------------------------------------------------------------------- /examples/SVD/schedulers/scheduler_config1100.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "EulerDiscreteScheduler", 3 | "_diffusers_version": "0.24.0.dev0", 4 | "beta_end": 0.012, 5 | "beta_schedule": 
"scaled_linear", 6 | "beta_start": 0.00085, 7 | "clip_sample": false, 8 | "interpolation_type": "linear", 9 | "num_train_timesteps": 1000, 10 | "prediction_type": "v_prediction", 11 | "set_alpha_to_one": false, 12 | "sigma_max": 1100.0, 13 | "sigma_min": 0.002, 14 | "skip_prk_steps": true, 15 | "steps_offset": 1, 16 | "timestep_spacing": "leading", 17 | "timestep_type": "continuous", 18 | "trained_betas": null, 19 | "use_karras_sigmas": true 20 | } 21 | -------------------------------------------------------------------------------- /examples/SVD/schedulers/scheduler_config20.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "EulerDiscreteScheduler", 3 | "_diffusers_version": "0.24.0.dev0", 4 | "beta_end": 0.012, 5 | "beta_schedule": "scaled_linear", 6 | "beta_start": 0.00085, 7 | "clip_sample": false, 8 | "interpolation_type": "linear", 9 | "num_train_timesteps": 1000, 10 | "prediction_type": "v_prediction", 11 | "set_alpha_to_one": false, 12 | "sigma_max": 20.0, 13 | "sigma_min": 0.002, 14 | "skip_prk_steps": true, 15 | "steps_offset": 1, 16 | "timestep_spacing": "leading", 17 | "timestep_type": "continuous", 18 | "trained_betas": null, 19 | "use_karras_sigmas": true 20 | } 21 | -------------------------------------------------------------------------------- /examples/SVD/schedulers/scheduler_config30.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "EulerDiscreteScheduler", 3 | "_diffusers_version": "0.24.0.dev0", 4 | "beta_end": 0.012, 5 | "beta_schedule": "scaled_linear", 6 | "beta_start": 0.00085, 7 | "clip_sample": false, 8 | "interpolation_type": "linear", 9 | "num_train_timesteps": 1000, 10 | "prediction_type": "v_prediction", 11 | "set_alpha_to_one": false, 12 | "sigma_max": 30.0, 13 | "sigma_min": 0.002, 14 | "skip_prk_steps": true, 15 | "steps_offset": 1, 16 | "timestep_spacing": "leading", 17 | "timestep_type": "continuous", 18 | "trained_betas": null, 19 | "use_karras_sigmas": true 20 | } 21 | -------------------------------------------------------------------------------- /examples/SVD/schedulers/scheduler_config300.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "EulerDiscreteScheduler", 3 | "_diffusers_version": "0.24.0.dev0", 4 | "beta_end": 0.012, 5 | "beta_schedule": "scaled_linear", 6 | "beta_start": 0.00085, 7 | "clip_sample": false, 8 | "interpolation_type": "linear", 9 | "num_train_timesteps": 1000, 10 | "prediction_type": "v_prediction", 11 | "set_alpha_to_one": false, 12 | "sigma_max": 300.0, 13 | "sigma_min": 0.002, 14 | "skip_prk_steps": true, 15 | "steps_offset": 1, 16 | "timestep_spacing": "leading", 17 | "timestep_type": "continuous", 18 | "trained_betas": null, 19 | "use_karras_sigmas": true 20 | } 21 | -------------------------------------------------------------------------------- /examples/SVD/schedulers/scheduler_config40.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "EulerDiscreteScheduler", 3 | "_diffusers_version": "0.24.0.dev0", 4 | "beta_end": 0.012, 5 | "beta_schedule": "scaled_linear", 6 | "beta_start": 0.00085, 7 | "clip_sample": false, 8 | "interpolation_type": "linear", 9 | "num_train_timesteps": 1000, 10 | "prediction_type": "v_prediction", 11 | "set_alpha_to_one": false, 12 | "sigma_max": 40.0, 13 | "sigma_min": 0.002, 14 | "skip_prk_steps": true, 15 | "steps_offset": 1, 16 | 
"timestep_spacing": "leading", 17 | "timestep_type": "continuous", 18 | "trained_betas": null, 19 | "use_karras_sigmas": true 20 | } 21 | -------------------------------------------------------------------------------- /examples/SVD/schedulers/scheduler_config50.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "EulerDiscreteScheduler", 3 | "_diffusers_version": "0.24.0.dev0", 4 | "beta_end": 0.012, 5 | "beta_schedule": "scaled_linear", 6 | "beta_start": 0.00085, 7 | "clip_sample": false, 8 | "interpolation_type": "linear", 9 | "num_train_timesteps": 1000, 10 | "prediction_type": "v_prediction", 11 | "set_alpha_to_one": false, 12 | "sigma_max": 50.0, 13 | "sigma_min": 0.002, 14 | "skip_prk_steps": true, 15 | "steps_offset": 1, 16 | "timestep_spacing": "leading", 17 | "timestep_type": "continuous", 18 | "trained_betas": null, 19 | "use_karras_sigmas": true 20 | } 21 | -------------------------------------------------------------------------------- /examples/SVD/schedulers/scheduler_config500.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "EulerDiscreteScheduler", 3 | "_diffusers_version": "0.24.0.dev0", 4 | "beta_end": 0.012, 5 | "beta_schedule": "scaled_linear", 6 | "beta_start": 0.00085, 7 | "clip_sample": false, 8 | "interpolation_type": "linear", 9 | "num_train_timesteps": 1000, 10 | "prediction_type": "v_prediction", 11 | "set_alpha_to_one": false, 12 | "sigma_max": 500.0, 13 | "sigma_min": 0.002, 14 | "skip_prk_steps": true, 15 | "steps_offset": 1, 16 | "timestep_spacing": "leading", 17 | "timestep_type": "continuous", 18 | "trained_betas": null, 19 | "use_karras_sigmas": true 20 | } 21 | -------------------------------------------------------------------------------- /examples/SVD/schedulers/scheduler_config60.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "EulerDiscreteScheduler", 3 | "_diffusers_version": "0.24.0.dev0", 4 | "beta_end": 0.012, 5 | "beta_schedule": "scaled_linear", 6 | "beta_start": 0.00085, 7 | "clip_sample": false, 8 | "interpolation_type": "linear", 9 | "num_train_timesteps": 1000, 10 | "prediction_type": "v_prediction", 11 | "set_alpha_to_one": false, 12 | "sigma_max": 60.0, 13 | "sigma_min": 0.002, 14 | "skip_prk_steps": true, 15 | "steps_offset": 1, 16 | "timestep_spacing": "leading", 17 | "timestep_type": "continuous", 18 | "trained_betas": null, 19 | "use_karras_sigmas": true 20 | } 21 | -------------------------------------------------------------------------------- /examples/SVD/schedulers/scheduler_config70.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "EulerDiscreteScheduler", 3 | "_diffusers_version": "0.24.0.dev0", 4 | "beta_end": 0.012, 5 | "beta_schedule": "scaled_linear", 6 | "beta_start": 0.00085, 7 | "clip_sample": false, 8 | "interpolation_type": "linear", 9 | "num_train_timesteps": 1000, 10 | "prediction_type": "v_prediction", 11 | "set_alpha_to_one": false, 12 | "sigma_max": 70.0, 13 | "sigma_min": 0.002, 14 | "skip_prk_steps": true, 15 | "steps_offset": 1, 16 | "timestep_spacing": "leading", 17 | "timestep_type": "continuous", 18 | "trained_betas": null, 19 | "use_karras_sigmas": true 20 | } 21 | -------------------------------------------------------------------------------- /examples/SVD/schedulers/scheduler_config700.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "EulerDiscreteScheduler", 3 | "_diffusers_version": "0.24.0.dev0", 4 | "beta_end": 0.012, 5 | "beta_schedule": "scaled_linear", 6 | "beta_start": 0.00085, 7 | "clip_sample": false, 8 | "interpolation_type": "linear", 9 | "num_train_timesteps": 1000, 10 | "prediction_type": "v_prediction", 11 | "set_alpha_to_one": false, 12 | "sigma_max": 700.0, 13 | "sigma_min": 0.002, 14 | "skip_prk_steps": true, 15 | "steps_offset": 1, 16 | "timestep_spacing": "leading", 17 | "timestep_type": "continuous", 18 | "trained_betas": null, 19 | "use_karras_sigmas": true 20 | } 21 | -------------------------------------------------------------------------------- /examples/SVD/schedulers/scheduler_config80.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "EulerDiscreteScheduler", 3 | "_diffusers_version": "0.24.0.dev0", 4 | "beta_end": 0.012, 5 | "beta_schedule": "scaled_linear", 6 | "beta_start": 0.00085, 7 | "clip_sample": false, 8 | "interpolation_type": "linear", 9 | "num_train_timesteps": 1000, 10 | "prediction_type": "v_prediction", 11 | "set_alpha_to_one": false, 12 | "sigma_max": 80.0, 13 | "sigma_min": 0.002, 14 | "skip_prk_steps": true, 15 | "steps_offset": 1, 16 | "timestep_spacing": "leading", 17 | "timestep_type": "continuous", 18 | "trained_betas": null, 19 | "use_karras_sigmas": true 20 | } 21 | -------------------------------------------------------------------------------- /examples/SVD/schedulers/scheduler_config90.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "EulerDiscreteScheduler", 3 | "_diffusers_version": "0.24.0.dev0", 4 | "beta_end": 0.012, 5 | "beta_schedule": "scaled_linear", 6 | "beta_start": 0.00085, 7 | "clip_sample": false, 8 | "interpolation_type": "linear", 9 | "num_train_timesteps": 1000, 10 | "prediction_type": "v_prediction", 11 | "set_alpha_to_one": false, 12 | "sigma_max": 90.0, 13 | "sigma_min": 0.002, 14 | "skip_prk_steps": true, 15 | "steps_offset": 1, 16 | "timestep_spacing": "leading", 17 | "timestep_type": "continuous", 18 | "trained_betas": null, 19 | "use_karras_sigmas": true 20 | } 21 | -------------------------------------------------------------------------------- /examples/SVD/schedulers/scheduler_config900.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "EulerDiscreteScheduler", 3 | "_diffusers_version": "0.24.0.dev0", 4 | "beta_end": 0.012, 5 | "beta_schedule": "scaled_linear", 6 | "beta_start": 0.00085, 7 | "clip_sample": false, 8 | "interpolation_type": "linear", 9 | "num_train_timesteps": 1000, 10 | "prediction_type": "v_prediction", 11 | "set_alpha_to_one": false, 12 | "sigma_max": 900.0, 13 | "sigma_min": 0.002, 14 | "skip_prk_steps": true, 15 | "steps_offset": 1, 16 | "timestep_spacing": "leading", 17 | "timestep_type": "continuous", 18 | "trained_betas": null, 19 | "use_karras_sigmas": true 20 | } 21 | -------------------------------------------------------------------------------- /examples/SVD/svd/data/dataset.py: -------------------------------------------------------------------------------- 1 | import os, csv, random 2 | import numpy as np 3 | from decord import VideoReader 4 | import jsonlines 5 | import torch 6 | import torchvision.transforms as transforms 7 | from torch.utils.data.dataset import Dataset 8 | 9 | class RandomHorizontalFlipVideo(object): 
10 | def __init__(self, p=0.5): 11 | self.p = p 12 | 13 | def __call__(self, clip): 14 | if torch.rand(1) < self.p: 15 | return torch.flip(clip, [3]) 16 | return clip 17 | 18 | class WebVid10M(Dataset): 19 | def __init__( 20 | self, 21 | file_path, video_folder, 22 | sample_size=256, fps=6, sample_n_frames=16): 23 | with open(file_path, 'r') as csvfile: 24 | reader = csv.DictReader(csvfile) 25 | self.dataset = [video for video in reader] 26 | csvfile.close() 27 | self.length = len(self.dataset) 28 | print(f"data scale: {self.length}") 29 | 30 | self.video_folder = video_folder 31 | self.fps = fps 32 | self.sample_n_frames = sample_n_frames 33 | if isinstance(sample_size, int): 34 | sample_size = tuple([int(sample_size)] * 2) 35 | else: 36 | sample_size = tuple(map(int, sample_size.split(','))) 37 | 38 | self.pixel_transforms = transforms.Compose([ 39 | transforms.Resize(sample_size[0], antialias=True), 40 | transforms.CenterCrop(sample_size), 41 | RandomHorizontalFlipVideo(), 42 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 43 | ]) 44 | 45 | def get_batch(self, idx): 46 | video_dict = self.dataset[idx] 47 | videoid = video_dict['videoid'] 48 | video_dir = os.path.join(self.video_folder, f"{videoid}.mp4") 49 | video_reader = VideoReader(video_dir) 50 | 51 | fps = video_reader.get_avg_fps() 52 | sample_stride = round(fps/self.fps) 53 | 54 | # sample sample_n_frames frames from videos with stride sample_stride 55 | video_length = len(video_reader) 56 | clip_length = min(video_length, (self.sample_n_frames - 1) * sample_stride + 1) 57 | start_idx = random.randint(0, video_length - clip_length) 58 | batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int) 59 | 60 | pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() 61 | pixel_values = pixel_values / 255. 
#[T, C, H, W] with range [0, 1] 62 | del video_reader 63 | 64 | return pixel_values, self.fps, videoid 65 | 66 | def __len__(self): 67 | return self.length 68 | 69 | def __getitem__(self, idx): 70 | while True: 71 | try: 72 | pixel_values, fps , videoid= self.get_batch(idx) 73 | break 74 | 75 | except Exception as e: 76 | idx = random.randint(0, self.length - 1) 77 | 78 | pixel_values = self.pixel_transforms(pixel_values) #[T, C, H, W] with range [-1, 1] 79 | sample = dict(pixel_values=pixel_values, fps=fps, id=videoid) 80 | return sample 81 | 82 | class ImageDataset(Dataset): 83 | def __init__(self, data_path): 84 | filenames = sorted(os.listdir(data_path)) 85 | self.length = len(filenames) 86 | self.data_path = data_path 87 | self.filenames = filenames 88 | 89 | def __len__(self): 90 | return self.length 91 | 92 | def __getitem__(self, idx): 93 | filename = self.filenames[idx] 94 | path = os.path.join(self.data_path, filename) 95 | sample = dict(path=path, name=filename) 96 | return sample 97 | 98 | class MultiImageDataset(Dataset): 99 | def __init__(self, data_paths): 100 | self.paths = [] 101 | for data_path in data_paths: 102 | filenames = sorted(os.listdir(data_path)) 103 | for filename in filenames: 104 | path = os.path.join(data_path, filename) 105 | self.paths.append(path) 106 | 107 | def __len__(self): 108 | return len(self.paths) 109 | 110 | def __getitem__(self, idx): 111 | path = self.paths[idx] 112 | dataset_name = path.split('/')[-2] 113 | filename = path.split('/')[-1] 114 | sample = dict(path=path, dataset_name=dataset_name, name=filename) 115 | return sample -------------------------------------------------------------------------------- /examples/SVD/svd/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Loss functions used in the paper 9 | "Elucidating the Design Space of Diffusion-Based Generative Models".""" 10 | 11 | import torch 12 | from torch_utils import persistence 13 | 14 | # ---------------------------------------------------------------------------- 15 | # Loss function corresponding to the variance preserving (VP) formulation 16 | # from the paper "Score-Based Generative Modeling through Stochastic 17 | # Differential Equations". 
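# With beta(t) = beta_min + beta_d * t, the VP perturbation kernel has
# sigma(t)^2 = exp(0.5 * beta_d * t^2 + beta_min * t) - 1; VPLoss.sigma below
# implements this closed form, __call__ samples t uniformly between epsilon_t
# and 1, and the per-sample loss weight is 1 / sigma(t)^2.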
18 | 19 | 20 | @persistence.persistent_class 21 | class VPLoss: 22 | def __init__(self, beta_d=19.9, beta_min=0.1, epsilon_t=1e-5): 23 | self.beta_d = beta_d 24 | self.beta_min = beta_min 25 | self.epsilon_t = epsilon_t 26 | 27 | def __call__(self, net, images, labels, augment_pipe=None): 28 | rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device) 29 | sigma = self.sigma(1 + rnd_uniform * (self.epsilon_t - 1)) 30 | weight = 1 / sigma**2 31 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None) 32 | n = torch.randn_like(y) * sigma 33 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) 34 | loss = weight * ((D_yn - y) ** 2) 35 | return loss 36 | 37 | def sigma(self, t): 38 | t = torch.as_tensor(t) 39 | return ((0.5 * self.beta_d * (t**2) + self.beta_min * t).exp() - 1).sqrt() 40 | 41 | 42 | # ---------------------------------------------------------------------------- 43 | # Loss function corresponding to the variance exploding (VE) formulation 44 | # from the paper "Score-Based Generative Modeling through Stochastic 45 | # Differential Equations". 46 | 47 | 48 | @persistence.persistent_class 49 | class VELoss: 50 | def __init__(self, sigma_min=0.02, sigma_max=100): 51 | self.sigma_min = sigma_min 52 | self.sigma_max = sigma_max 53 | 54 | def __call__(self, net, images, labels, augment_pipe=None): 55 | rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device) 56 | sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rnd_uniform) 57 | weight = 1 / sigma**2 58 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None) 59 | n = torch.randn_like(y) * sigma 60 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) 61 | loss = weight * ((D_yn - y) ** 2) 62 | return loss 63 | 64 | 65 | # ---------------------------------------------------------------------------- 66 | # Improved loss function proposed in the paper "Elucidating the Design Space 67 | # of Diffusion-Based Generative Models" (EDM). 
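# A minimal, hypothetical sketch (the helper below is illustrative and not part
# of the original EDM code): EDMLoss draws log-normal noise levels,
# ln(sigma) ~ N(P_mean, P_std^2), and weights the squared error by
# (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2 so that the effective loss
# magnitude stays roughly constant across noise levels.
def _edm_sigma_and_weight_demo(batch_size=4, P_mean=-1.2, P_std=1.2, sigma_data=0.5):
    rnd_normal = torch.randn([batch_size, 1, 1, 1])
    sigma = (rnd_normal * P_std + P_mean).exp()  # log-normal noise levels
    weight = (sigma ** 2 + sigma_data ** 2) / (sigma * sigma_data) ** 2
    return sigma, weight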
68 | 69 | 70 | @persistence.persistent_class 71 | class EDMLoss: 72 | def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5): 73 | self.P_mean = P_mean 74 | self.P_std = P_std 75 | self.sigma_data = sigma_data 76 | 77 | def __call__(self, net, images, labels=None, augment_pipe=None): 78 | rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device) 79 | sigma = (rnd_normal * self.P_std + self.P_mean).exp() 80 | weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 81 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None) 82 | n = torch.randn_like(y) * sigma 83 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) 84 | loss = weight * ((D_yn - y) ** 2) 85 | return loss 86 | 87 | # ---------------------------------------------------------------------------- 88 | 89 | @persistence.persistent_class 90 | class EDMLoss2: 91 | def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5): 92 | self.P_mean = P_mean 93 | self.P_std = P_std 94 | self.sigma_data = sigma_data 95 | 96 | def __call__(self, net, images, labels=None, augment_pipe=None): 97 | rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device) 98 | t = (rnd_normal * self.P_std + self.P_mean).exp() 99 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None) 100 | pred, target = net(y, t, labels, augment_labels=augment_labels) 101 | loss = (pred - target) ** 2 102 | return loss 103 | -------------------------------------------------------------------------------- /examples/SVD/svd/models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from ..schedulers.edm import add_noise 4 | 5 | def VScalingWithEDMcNoise(sigma): 6 | c_skip = 1.0 / (sigma ** 2 + 1.0) 7 | c_out = -sigma / (sigma ** 2 + 1.0) ** 0.5 8 | c_in = 1.0 / (sigma ** 2 + 1.0) ** 0.5 9 | c_noise = torch.Tensor( 10 | [0.25 * sigma.log() for sigma in sigma]).to(sigma.device) 11 | return c_skip, c_out, c_in, c_noise 12 | 13 | def get_add_time_ids( 14 | fps, 15 | motion_bucket_id, 16 | noise_aug_strength, 17 | dtype, 18 | batch_size, 19 | ): 20 | motion_bucket_id = torch.tensor(motion_bucket_id.repeat(batch_size, 1)) 21 | add_time_ids = [fps, motion_bucket_id, noise_aug_strength] 22 | 23 | add_time_ids = torch.tensor([add_time_ids], dtype=dtype) 24 | add_time_ids = add_time_ids.repeat(batch_size, 1) 25 | return add_time_ids 26 | 27 | def motion_score_flow(inputs, FlowModel): 28 | ''' 29 | compute motion score based on optical flow 30 | Args: 31 | inputs: the selected input key frame with range [-1,1] 32 | FlowModel: the network of flow 33 | Returns: 34 | motion_score: motion_score 35 | ''' 36 | inputs = (0.5 * inputs + 0.5) * 255.0 37 | motion_scores = [] 38 | 39 | for current_frames, next_frames in zip(inputs[:, :-1, :], inputs[:, 1:, :]): 40 | backward_flows = FlowModel(current_frames, next_frames, iters=20, test_mode=True)[1]#[T-1, 2, H, W] 41 | 42 | # compute modulus of optical flow vectors as motion score 43 | magnitude = torch.sqrt(backward_flows[:, 0, :, :] ** 2 + backward_flows[:, 1, :, :] ** 2) 44 | 45 | # average motion score for dimension T, H, W 46 | motion_score = magnitude.mean(dim=[0, 1, 2]) 47 | motion_scores.append(motion_score) 48 | 49 | motion_scores = torch.stack(motion_scores) 50 | return motion_scores 51 | 52 | def pixel2latent(pixel_values, vae): 53 | video_length = pixel_values.shape[1] 54 | with torch.no_grad(): 55 | # encode each 
video to avoid OOM 56 | latents = [] 57 | for pixel_value in pixel_values: 58 | latent = vae.encode(pixel_value).latent_dist.sample() 59 | latents.append(latent) 60 | latents = torch.cat(latents, dim=0) 61 | latents = rearrange(latents, "(b f) c h w -> b f c h w", f=video_length) 62 | latents = latents * vae.config.scaling_factor 63 | return latents 64 | 65 | def encode_image(pixel_values, feature_extractor, image_encoder): 66 | pixel_values = _resize_with_antialiasing(pixel_values, (224, 224)) 67 | pixel_values = (pixel_values + 1.0) / 2.0 68 | 69 | # Normalize the image with for CLIP input 70 | pixel_values = feature_extractor( 71 | images=pixel_values, 72 | do_normalize=True, 73 | do_center_crop=False, 74 | do_resize=False, 75 | do_rescale=False, 76 | return_tensors="pt", 77 | ).pixel_values 78 | 79 | pixel_values = pixel_values.to(image_encoder.device) 80 | 81 | with torch.no_grad(): 82 | image_embeddings = image_encoder(pixel_values).image_embeds 83 | image_embeddings = image_embeddings.unsqueeze(dim=1) 84 | return image_embeddings 85 | 86 | def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True): 87 | h, w = input.shape[-2:] 88 | factors = (h / size[0], w / size[1]) 89 | 90 | # First, we have to determine sigma 91 | # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171 92 | sigmas = ( 93 | max((factors[0] - 1.0) / 2.0, 0.001), 94 | max((factors[1] - 1.0) / 2.0, 0.001), 95 | ) 96 | 97 | # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma 98 | # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206 99 | # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now 100 | ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3)) 101 | 102 | # Make sure it is odd 103 | if (ks[0] % 2) == 0: 104 | ks = ks[0] + 1, ks[1] 105 | 106 | if (ks[1] % 2) == 0: 107 | ks = ks[0], ks[1] + 1 108 | 109 | input = _gaussian_blur2d(input, ks, sigmas) 110 | 111 | output = torch.nn.functional.interpolate( 112 | input, size=size, mode=interpolation, align_corners=align_corners) 113 | return output 114 | 115 | 116 | def _compute_padding(kernel_size): 117 | """Compute padding tuple.""" 118 | # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom) 119 | # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad 120 | if len(kernel_size) < 2: 121 | raise AssertionError(kernel_size) 122 | computed = [k - 1 for k in kernel_size] 123 | 124 | # for even kernels we need to do asymmetric padding :( 125 | out_padding = 2 * len(kernel_size) * [0] 126 | 127 | for i in range(len(kernel_size)): 128 | computed_tmp = computed[-(i + 1)] 129 | 130 | pad_front = computed_tmp // 2 131 | pad_rear = computed_tmp - pad_front 132 | 133 | out_padding[2 * i + 0] = pad_front 134 | out_padding[2 * i + 1] = pad_rear 135 | 136 | return out_padding 137 | 138 | 139 | def _filter2d(input, kernel): 140 | # prepare kernel 141 | b, c, h, w = input.shape 142 | tmp_kernel = kernel[:, None, ...].to( 143 | device=input.device, dtype=input.dtype) 144 | 145 | tmp_kernel = tmp_kernel.expand(-1, c, -1, -1) 146 | 147 | height, width = tmp_kernel.shape[-2:] 148 | 149 | padding_shape: list[int] = _compute_padding([height, width]) 150 | input = torch.nn.functional.pad(input, padding_shape, mode="reflect") 151 | 152 | # kernel and input tensor reshape to align element-wise or batch-wise params 153 | tmp_kernel = tmp_kernel.reshape(-1, 
1, height, width) 154 | input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1)) 155 | 156 | # convolve the tensor with the kernel. 157 | output = torch.nn.functional.conv2d( 158 | input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1) 159 | 160 | out = output.view(b, c, h, w) 161 | return out 162 | 163 | 164 | def _gaussian(window_size: int, sigma): 165 | if isinstance(sigma, float): 166 | sigma = torch.tensor([[sigma]]) 167 | 168 | batch_size = sigma.shape[0] 169 | 170 | x = (torch.arange(window_size, device=sigma.device, 171 | dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1) 172 | 173 | if window_size % 2 == 0: 174 | x = x + 0.5 175 | 176 | gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0))) 177 | 178 | return gauss / gauss.sum(-1, keepdim=True) 179 | 180 | 181 | def _gaussian_blur2d(input, kernel_size, sigma): 182 | if isinstance(sigma, tuple): 183 | sigma = torch.tensor([sigma], dtype=input.dtype) 184 | else: 185 | sigma = sigma.to(dtype=input.dtype) 186 | 187 | ky, kx = int(kernel_size[0]), int(kernel_size[1]) 188 | bs = sigma.shape[0] 189 | kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1)) 190 | kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1)) 191 | out_x = _filter2d(input, kernel_x[..., None, :]) 192 | out = _filter2d(out_x, kernel_y[..., None]) 193 | 194 | return out 195 | -------------------------------------------------------------------------------- /examples/SVD/svd/schedulers/edm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def add_noise(inputs, P_mean, P_std):#input [B,T,C,H,W] 4 | noise = torch.randn_like(inputs) #N(0,I) 5 | rnd_normal = torch.randn([inputs.shape[0], 1, 1, 1, 1], device=inputs.device) #N(0,I)采样 [B,1,1,1,1] 6 | sigma = (rnd_normal * P_std + P_mean).exp() # N(P_mean,P_std) 7 | noisy_inputs = inputs + noise * sigma 8 | return noisy_inputs, sigma -------------------------------------------------------------------------------- /examples/SVD/svd/training/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ..schedulers.edm import add_noise 3 | from ..models.utils import VScalingWithEDMcNoise 4 | from ..models.utils import pixel2latent 5 | import random 6 | 7 | # sample from a logit-normal distribution 8 | def logit_normal_sampler(m, s=1, beta_m=15, sample_num=1000000): 9 | y_samples = torch.randn(sample_num).reshape([m.shape[0], 1, 1, 1, 1]) * s + m 10 | x_samples = beta_m * (torch.exp(y_samples) / (1 + torch.exp(y_samples))) 11 | return x_samples 12 | 13 | # the $\mu(t)$ function 14 | def mu_t(t, a=5, mu_max=1): 15 | t = t.to('cpu') 16 | return 2 * mu_max * t**a - mu_max 17 | 18 | # get $\beta_s$ for TimeNoise 19 | def get_beta_s(t, a,beta_m): 20 | mu = mu_t(t,a=a) 21 | sigma_s = logit_normal_sampler(m=mu, sample_num=t.shape[0], beta_m=beta_m) 22 | return sigma_s 23 | 24 | # loss function for TimeNoise of the paper: https://arxiv.org/pdf/2406.15735. 
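# Hypothetical usage sketch (not called by the trainer): TimeNoise couples the
# conditional-image noise strength to the sampled video noise level sigma.
# sigma is first normalized by sigma_max (700.0 here, matching the
# `sigma / 700` used in EDMLossTimeNoise below), mu_t shifts the logit-normal
# mean upward as sigma grows, and beta_m bounds the maximum strength.
def _timenoise_strength_demo(batch_size=4, P_mean=-1.2, P_std=1.2,
                             a=5, beta_m=15, sigma_max=700.0):
    sigma = (torch.randn([batch_size, 1, 1, 1, 1]) * P_std + P_mean).exp()
    noise_aug_strength = get_beta_s(sigma / sigma_max, a, beta_m)
    return sigma, noise_aug_strength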
25 | class EDMLossTimeNoise: 26 | def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5,beta_m=15,a=5): 27 | self.P_mean = P_mean 28 | self.P_std = P_std 29 | self.sigma_data = sigma_data 30 | self.beta_m=beta_m 31 | self.a=a 32 | 33 | def __call__(self, unet, latents, encoder_hidden_states, 34 | mixed_precision_training,pixel_values,vae,cfg_random_null_ratio, 35 | motion_bucket_id,fps): 36 | # add noise to ground-truth video 37 | noisy_latents, sigma = add_noise(latents, self.P_mean, self.P_std) 38 | c_skip, c_out, c_in, c_noise = VScalingWithEDMcNoise(sigma) 39 | scaled_inputs = noisy_latents * c_in 40 | 41 | # conditional image 42 | img_condition = pixel_values[:, 0, :, :, :].unsqueeze(dim=1) 43 | img_condition_latents = pixel2latent(img_condition, vae)/vae.config.scaling_factor 44 | 45 | # applying TimeNoise. Add logit-normal noise to the conditional image. 46 | noise_aug_strength =get_beta_s(sigma/700,self.a,self.beta_m).reshape([latents.shape[0], 1, 1, 1, 1]).to(latents.device) 47 | rnd_normal = torch.randn([img_condition_latents.shape[0], 1, 1, 1, 1], device=img_condition_latents.device) 48 | noisy_condition_latents =img_condition_latents + noise_aug_strength * rnd_normal 49 | 50 | # classifier-free guidance 51 | if cfg_random_null_ratio > 0.0: 52 | p = random.random() 53 | noisy_condition_latents = noisy_condition_latents if p > cfg_random_null_ratio else torch.zeros_like(noisy_condition_latents) 54 | encoder_hidden_states = encoder_hidden_states if p > cfg_random_null_ratio else torch.zeros_like(encoder_hidden_states) 55 | 56 | # Repeat the condition latents for each frame so we can concatenate them with the noise 57 | noisy_condition_latents = noisy_condition_latents.repeat(1, latents.shape[1], 1, 1, 1) 58 | 59 | 60 | batch_size = noise_aug_strength.shape[0] 61 | motion_score = torch.tensor([motion_bucket_id]).repeat(batch_size).to(latents.device) 62 | fps = torch.tensor([fps]).repeat(batch_size).to(latents.device) 63 | added_time_ids = torch.stack([fps, motion_score, noise_aug_strength.reshape(batch_size)], dim=1) 64 | 65 | scaled_inputs = torch.cat([scaled_inputs,noisy_condition_latents], dim=2) 66 | weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2 67 | 68 | # calculate loss 69 | with torch.cuda.amp.autocast(enabled=mixed_precision_training): 70 | model_pred = unet( 71 | scaled_inputs, c_noise, encoder_hidden_states, added_time_ids=added_time_ids)["sample"] 72 | 73 | pred = model_pred * c_out + c_skip * noisy_latents 74 | loss = torch.mean((weight.float() * (pred.float() - latents.float()) ** 2)) 75 | 76 | return loss 77 | 78 | 79 | 80 | # baseline loss function 81 | class EDMLossBaseline: 82 | def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5): 83 | self.P_mean = P_mean 84 | self.P_std = P_std 85 | self.sigma_data = sigma_data 86 | 87 | def __call__(self, unet, latents, encoder_hidden_states, 88 | mixed_precision_training,pixel_values,vae,cfg_random_null_ratio, 89 | motion_bucket_id,fps): 90 | # add noise to ground-truth video 91 | noisy_latents, sigma = add_noise(latents, self.P_mean, self.P_std) 92 | c_skip, c_out, c_in, c_noise = VScalingWithEDMcNoise(sigma) 93 | scaled_inputs = noisy_latents * c_in 94 | 95 | # baseline noise argumentation 96 | noisy_condition, noise_aug_strength = add_noise(pixel_values[:, 0, :, :, :].unsqueeze(dim=1), P_mean=self.P_mean, P_std=self.P_std) 97 | noisy_condition_latents = pixel2latent(noisy_condition, vae)/vae.config.scaling_factor 98 | 99 | # classifier-free guidance 100 | if cfg_random_null_ratio > 
0.0: 101 | p = random.random() 102 | noisy_condition_latents = noisy_condition_latents if p > cfg_random_null_ratio else torch.zeros_like(noisy_condition_latents) 103 | encoder_hidden_states = encoder_hidden_states if p > cfg_random_null_ratio else torch.zeros_like(encoder_hidden_states) 104 | 105 | # Repeat the condition latents for each frame so we can concatenate them with the noise 106 | noisy_condition_latents = noisy_condition_latents.repeat(1, latents.shape[1], 1, 1, 1) 107 | 108 | batch_size = noise_aug_strength.shape[0] 109 | motion_score = torch.tensor([motion_bucket_id]).repeat(batch_size).to(latents.device) 110 | fps = torch.tensor([fps]).repeat(batch_size).to(latents.device) 111 | added_time_ids = torch.stack([fps, motion_score, noise_aug_strength.reshape(batch_size)], dim=1) 112 | 113 | scaled_inputs = torch.cat([scaled_inputs,noisy_condition_latents], dim=2) 114 | weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2 115 | 116 | # calculate loss 117 | with torch.cuda.amp.autocast(enabled=mixed_precision_training): 118 | model_pred = unet( 119 | scaled_inputs, c_noise, encoder_hidden_states, added_time_ids=added_time_ids)["sample"] 120 | 121 | pred = model_pred * c_out + c_skip * noisy_latents 122 | loss = torch.mean((weight.float() * (pred.float() - latents.float()) ** 2)) 123 | 124 | return loss 125 | 126 | -------------------------------------------------------------------------------- /examples/SVD/train.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=7 torchrun --nproc_per_node=1 train.py --config "config/train.yaml" 2 | -------------------------------------------------------------------------------- /examples/VideoCrafter/configs/inference_i2v_512_v1.0.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: lvdm.models.ddpm3d.LatentVisualDiffusion 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.012 6 | num_timesteps_cond: 1 7 | timesteps: 1000 8 | first_stage_key: video 9 | cond_stage_key: caption 10 | cond_stage_trainable: false 11 | conditioning_key: crossattn 12 | image_size: 13 | - 40 14 | - 64 15 | channels: 4 16 | scale_by_std: false 17 | scale_factor: 0.18215 18 | use_ema: false 19 | uncond_type: empty_seq 20 | use_scale: true 21 | scale_b: 0.7 22 | finegrained: true 23 | unet_config: 24 | target: lvdm.modules.networks.openaimodel3d_videocrafter.UNetModel 25 | params: 26 | in_channels: 4 27 | out_channels: 4 28 | model_channels: 320 29 | attention_resolutions: 30 | - 4 31 | - 2 32 | - 1 33 | num_res_blocks: 2 34 | channel_mult: 35 | - 1 36 | - 2 37 | - 4 38 | - 4 39 | num_head_channels: 64 40 | transformer_depth: 1 41 | context_dim: 1024 42 | use_linear: true 43 | use_checkpoint: true 44 | temporal_conv: true 45 | temporal_attention: true 46 | temporal_selfatt_only: true 47 | use_relative_position: false 48 | use_causal_attention: false 49 | use_image_attention: true 50 | temporal_length: 16 51 | addition_attention: true 52 | fps_cond: true 53 | first_stage_config: 54 | target: lvdm.models.autoencoder.AutoencoderKL 55 | params: 56 | embed_dim: 4 57 | monitor: val/rec_loss 58 | ddconfig: 59 | double_z: true 60 | z_channels: 4 61 | resolution: 512 62 | in_channels: 3 63 | out_ch: 3 64 | ch: 128 65 | ch_mult: 66 | - 1 67 | - 2 68 | - 4 69 | - 4 70 | num_res_blocks: 2 71 | attn_resolutions: [] 72 | dropout: 0.0 73 | lossconfig: 74 | target: torch.nn.Identity 75 | cond_stage_config: 76 | target: 
lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder 77 | params: 78 | freeze: true 79 | layer: penultimate 80 | cond_img_config: 81 | target: lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2 82 | params: 83 | freeze: true -------------------------------------------------------------------------------- /examples/VideoCrafter/configs/train.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | pretrained_checkpoint: ckpt/original/model.ckpt 3 | base_learning_rate: 1.0e-05 4 | scale_lr: False 5 | target: lvdm.models.ddpm3d_videocrafter_ve.LatentVisualDiffusion 6 | params: 7 | linear_start: 0.00085 8 | linear_end: 0.012 9 | num_timesteps_cond: 1 10 | timesteps: 1000 11 | first_stage_key: video 12 | cond_stage_key: caption 13 | cond_stage_trainable: false 14 | conditioning_key: crossattn 15 | image_size: [40, 64] 16 | channels: 4 17 | scale_by_std: false 18 | scale_factor: 0.18215 19 | use_ema: false 20 | uncond_type: empty_seq 21 | use_scale: true 22 | scale_b: 0.7 23 | finegrained: true 24 | unet_config: 25 | target: lvdm.modules.networks.openaimodel3d_videocrafter.UNetModel 26 | params: 27 | in_channels: 4 28 | out_channels: 4 29 | model_channels: 320 30 | attention_resolutions: 31 | - 4 32 | - 2 33 | - 1 34 | num_res_blocks: 2 35 | channel_mult: 36 | - 1 37 | - 2 38 | - 4 39 | - 4 40 | num_head_channels: 64 41 | transformer_depth: 1 42 | context_dim: 1024 43 | use_linear: true 44 | use_checkpoint: True 45 | temporal_conv: True 46 | temporal_attention: True 47 | temporal_selfatt_only: true 48 | use_relative_position: false 49 | use_causal_attention: False 50 | use_image_attention: true 51 | temporal_length: 16 52 | addition_attention: true 53 | fps_cond: true 54 | first_stage_config: 55 | target: lvdm.models.autoencoder.AutoencoderKL 56 | params: 57 | embed_dim: 4 58 | monitor: val/rec_loss 59 | ddconfig: 60 | double_z: True 61 | z_channels: 4 62 | resolution: 512 63 | in_channels: 3 64 | out_ch: 3 65 | ch: 128 66 | ch_mult: 67 | - 1 68 | - 2 69 | - 4 70 | - 4 71 | num_res_blocks: 2 72 | attn_resolutions: [] 73 | dropout: 0.0 74 | lossconfig: 75 | target: torch.nn.Identity 76 | 77 | cond_stage_config: 78 | target: lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder 79 | params: 80 | freeze: true 81 | layer: "penultimate" 82 | 83 | cond_img_config: 84 | target: lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2 85 | params: 86 | freeze: true 87 | beta_m: 100 88 | a: 5 89 | data: 90 | target: utils_data.DataModuleFromConfig 91 | params: 92 | batch_size: 1 93 | num_workers: 12 94 | wrap: false 95 | train: 96 | target: lvdm.data.webvid.WebVid 97 | params: 98 | file_path: '../dataset' 99 | video_folder: '../dataset/results_2M_train.csv' 100 | video_length: 16 101 | frame_stride: 6 102 | load_raw_resolution: true 103 | resolution: [320, 512] 104 | spatial_transform: resize_center_crop 105 | random_fs: true ## if true, we uniformly sample fs with max_fs=frame_stride (above) 106 | 107 | lightning: 108 | precision: 16 109 | # strategy: deepspeed_stage_2 110 | trainer: 111 | benchmark: True 112 | accumulate_grad_batches: 1 113 | max_steps: 100000 114 | # logger 115 | log_every_n_steps: 50 116 | # val 117 | val_check_interval: 0.5 118 | gradient_clip_algorithm: 'norm' 119 | gradient_clip_val: 0.5 120 | callbacks: 121 | model_checkpoint: 122 | target: pytorch_lightning.callbacks.ModelCheckpoint 123 | params: 124 | every_n_train_steps: 1000 #1000 125 | filename: "{epoch}-{step}" 126 | save_weights_only: True 127 | 
metrics_over_trainsteps_checkpoint: 128 | target: pytorch_lightning.callbacks.ModelCheckpoint 129 | params: 130 | filename: '{epoch}-{step}' 131 | save_weights_only: True 132 | every_n_train_steps: 2000 #20000 # 3s/step*2w= 133 | batch_logger: 134 | target: callbacks.ImageLogger 135 | params: 136 | batch_frequency: 5000000 137 | to_local: False 138 | max_images: 8 139 | log_images_kwargs: 140 | ddim_steps: 50 141 | unconditional_guidance_scale: 7.5 142 | timestep_spacing: uniform_trailing 143 | guidance_rescale: 0.7 -------------------------------------------------------------------------------- /examples/VideoCrafter/inference_512.sh: -------------------------------------------------------------------------------- 1 | ckpt='ckpt/original/model.ckpt' # path to your checkpoint 2 | config='configs/inference_i2v_512_v1.0.yaml' 3 | 4 | prompt_file="prompts/512/test_prompts.txt" # file for 5 | condimage_dir="prompts/512" # file for conditional images 6 | res_dir="results" # file for outputs 7 | 8 | H=320 9 | W=512 10 | FS=24 11 | M=940 12 | 13 | CUDA_VISIBLE_DEVICES=1 python3 -m torch.distributed.launch \ 14 | --nproc_per_node=1 --nnodes=1 --master_addr=127.0.0.1 --master_port=23456 --node_rank=0 \ 15 | scripts/evaluation/ddp_wrapper.py \ 16 | --module 'inference' \ 17 | --seed 123 \ 18 | --mode 'i2v' \ 19 | --ckpt_path $ckpt \ 20 | --config $config \ 21 | --savedir $res_dir \ 22 | --n_samples 1 \ 23 | --bs 1 --height ${H} --width ${W} \ 24 | --unconditional_guidance_scale 12.0 \ 25 | --ddim_steps 50 \ 26 | --ddim_eta 1.0 \ 27 | --prompt_file $prompt_file \ 28 | --cond_input $condimage_dir \ 29 | --fps ${FS} \ 30 | --savefps 8 \ 31 | --frames 16 \ 32 | --M ${M} \ 33 | --analytic_init_path "ckpt/initial_noise_512.pt" 34 | -------------------------------------------------------------------------------- /examples/VideoCrafter/inference_CIL_512.sh: -------------------------------------------------------------------------------- 1 | ckpt='ckpt/finetuned/model.ckpt' # path to your checkpoint 2 | config='configs/inference_i2v_512_v1.0.yaml' 3 | 4 | prompt_file="prompts/512/test_prompts.txt" # file for 5 | condimage_dir="prompts/512" # file for conditional images 6 | res_dir="results" # file for outputs 7 | 8 | H=320 9 | W=512 10 | FS=24 11 | M=940 12 | 13 | CUDA_VISIBLE_DEVICES=1 python3 -m torch.distributed.launch \ 14 | --nproc_per_node=1 --nnodes=1 --master_addr=127.0.0.1 --master_port=23456 --node_rank=0 \ 15 | scripts/evaluation/ddp_wrapper.py \ 16 | --module 'inference' \ 17 | --seed 123 \ 18 | --mode 'i2v' \ 19 | --ckpt_path $ckpt \ 20 | --config $config \ 21 | --savedir $res_dir \ 22 | --n_samples 1 \ 23 | --bs 1 --height ${H} --width ${W} \ 24 | --unconditional_guidance_scale 12.0 \ 25 | --ddim_steps 50 \ 26 | --ddim_eta 1.0 \ 27 | --prompt_file $prompt_file \ 28 | --cond_input $condimage_dir \ 29 | --fps ${FS} \ 30 | --savefps 8 \ 31 | --frames 16 \ 32 | --M ${M} \ 33 | --analytic_init_path "ckpt/initial_noise_512.pt" 34 | -------------------------------------------------------------------------------- /examples/VideoCrafter/libs/eval_funcs.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | def evaluation(*args, **kwargs): 5 | pass -------------------------------------------------------------------------------- /examples/VideoCrafter/libs/losses.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import matplotlib.pyplot as plt 4 | from einops import rearrange, repeat 5 | 
import numpy as np 6 | 7 | 8 | def logit_normal_sampler( 9 | m, s=1, betam=0.9977, sample_num=1000000, 10 | ): 11 | y_samples = torch.randn(sample_num) * s + m 12 | x_samples = betam * (torch.exp(y_samples) / (1 + torch.exp(y_samples))) 13 | return x_samples 14 | 15 | 16 | def mu_t(t, mu_t_type="linear", mu_max=1.5): 17 | t = t.to("cpu") 18 | if mu_t_type == "linear": 19 | return 2 * mu_max * t - mu_max 20 | elif mu_t_type == "Convex_10": 21 | return 2 * mu_max * t**10 - mu_max 22 | elif mu_t_type == "Convex_5": 23 | return 2 * mu_max * t**5 - mu_max 24 | elif mu_t_type == "sigmoid": 25 | return ( 26 | 2 * mu_max * np.exp(50 * (t - 0.8)) / (1 + np.exp(50 * (t - 0.8))) - mu_max 27 | ) 28 | 29 | def get_alpha_s_and_sigma_s(t, mut_type): 30 | mu = mu_t(t, mu_t_type=mut_type) 31 | sigma_s = logit_normal_sampler(m=mu, sample_num=t.shape[0], whether_paint=False) 32 | alpha_s = torch.sqrt(1 - sigma_s**2) 33 | return alpha_s, sigma_s 34 | 35 | 36 | def dynamicrafter_loss(model_context, data_context, config, accelerator): 37 | 38 | batch = next(data_context["train_data_generator"]) 39 | model = model_context["model"] 40 | local_rank = model_context["device"] 41 | dtype = model_context["dtype"] 42 | 43 | @torch.no_grad() 44 | @torch.autocast("cuda") 45 | def get_latent_z(model, videos): 46 | b, c, t, h, w = videos.shape 47 | x = rearrange(videos, "b c t h w -> (b t) c h w") 48 | z = model.encode_first_stage(x) 49 | z = rearrange(z, "(b t) c h w -> b c t h w", b=b, t=t) 50 | return z 51 | 52 | with torch.autocast("cuda", dtype=dtype): 53 | p = random.random() 54 | pixel_values = batch["pixel_values"].to(local_rank) 55 | # classifier-free guidance 56 | batch["text"] = [ 57 | name if p > config.cfg_random_null_ratio else "" for name in batch["text"] 58 | ] 59 | prompts = batch["text"] 60 | fs = batch["fps"].to(local_rank) 61 | batch_size = pixel_values.shape[0] 62 | z = get_latent_z(model, pixel_values) # b c t h w 63 | 64 | # add noise 65 | t = torch.randint(0, model.num_timesteps, (batch_size,), device=z.device).long() 66 | noise = torch.randn_like(z) 67 | noisy_z = model.q_sample(z, t, noise=noise) 68 | 69 | # condition 70 | # classifier-free guidance 71 | img = ( 72 | pixel_values[:, :, 0] 73 | if p > config.cfg_random_null_ratio 74 | else torch.zeros_like(pixel_values[:, :, 0]) 75 | ) 76 | img_emb = model.embedder(img) ## blc 77 | img_emb = model.image_proj_model(img_emb) 78 | cond_emb = model.get_learned_conditioning(prompts) 79 | cond = {"c_crossattn": [torch.cat([cond_emb, img_emb], dim=1)]} 80 | 81 | img_cat_cond = z[:, :, :1, :, :] 82 | # add noise on condition 83 | if config.condition_type is not None: 84 | alpha_s, sigma_s = get_alpha_s_and_sigma_s(t / 1000.0, config.mu_type) 85 | condition_noise = torch.randn_like(img_cat_cond) 86 | alpha_s = alpha_s.reshape([batch_size, 1, 1, 1, 1]).to(local_rank) 87 | sigma_s = sigma_s.reshape([batch_size, 1, 1, 1, 1]).to(local_rank) 88 | img_cat_cond = alpha_s * img_cat_cond + sigma_s * condition_noise 89 | 90 | # classifier-free guidance 91 | img_cat_cond = ( 92 | img_cat_cond 93 | if p > config.cfg_random_null_ratio 94 | else torch.zeros_like(img_cat_cond) 95 | ) 96 | img_cat_cond = repeat( 97 | img_cat_cond, "b c t h w -> b c (repeat t) h w", repeat=z.shape[2] 98 | ) 99 | cond["c_concat"] = [img_cat_cond] # b c 1 h w 100 | 101 | model_pred = model.apply_model(noisy_z, t, cond, fs=fs) 102 | loss = torch.mean(((model_pred.float() - noise.float()) ** 2)) 103 | 104 | return { 105 | "loss": loss, 106 | } 107 | 
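# Hypothetical sketch with assumed shapes and illustrative defaults for
# mut_type / mu_max (not used by the trainer): the conditional first-frame
# latent is corrupted in a variance-preserving way,
# z_cond' = alpha_s * z_cond + sigma_s * eps, where sigma_s is drawn from the
# logit-normal sampler with a timestep-dependent mean mu_t(t / 1000).
def _corrupt_condition_demo(z_cond, t, mut_type="Convex_5", mu_max=1.5):
    # z_cond: [B, C, 1, H, W] latent of the conditional frame
    # t: [B] integer diffusion timesteps in [0, 1000)
    batch_size = t.shape[0]
    mu = mu_t(t / 1000.0, mu_t_type=mut_type, mu_max=mu_max)
    sigma_s = logit_normal_sampler(m=mu, sample_num=batch_size)
    alpha_s = torch.sqrt(1 - sigma_s ** 2)
    alpha_s = alpha_s.reshape([batch_size, 1, 1, 1, 1]).to(z_cond.device)
    sigma_s = sigma_s.reshape([batch_size, 1, 1, 1, 1]).to(z_cond.device)
    return alpha_s * z_cond + sigma_s * torch.randn_like(z_cond)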
-------------------------------------------------------------------------------- /examples/VideoCrafter/libs/special_functions.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from base_utils import PrintContext 5 | 6 | def load_model_checkpoint(model, ckpt): 7 | state_dict = torch.load(ckpt, map_location="cpu") 8 | if "state_dict" in list(state_dict.keys()): 9 | state_dict = state_dict["state_dict"] 10 | try: 11 | model.load_state_dict(state_dict, strict=True) 12 | except: 13 | ## rename the keys for 256x256 model 14 | new_pl_sd = OrderedDict() 15 | for k, v in state_dict.items(): 16 | new_pl_sd[k] = v 17 | 18 | for k in list(new_pl_sd.keys()): 19 | if "framestride_embed" in k: 20 | new_key = k.replace("framestride_embed", "fps_embedding") 21 | new_pl_sd[new_key] = new_pl_sd[k] 22 | del new_pl_sd[k] 23 | model.load_state_dict(new_pl_sd, strict=True) 24 | else: 25 | # deepspeed 26 | new_pl_sd = OrderedDict() 27 | for key in state_dict['module'].keys(): 28 | new_pl_sd[key[16:]] = state_dict['module'][key] 29 | model.load_state_dict(new_pl_sd) 30 | print('>>> model checkpoint loaded.') 31 | return model 32 | 33 | 34 | def model_load(model_context, config, **kwargs): 35 | model = model_context["model"] 36 | assert os.path.exists(config.ckpt_path), "Error: checkpoint Not Found!" 37 | model = load_model_checkpoint(model, config.ckpt_path) 38 | return model_context 39 | 40 | 41 | from lvdm.modules.attention import TemporalTransformer 42 | from lvdm.modules.networks.openaimodel3d import TemporalConvBlock 43 | # trainable modules for optimizer 44 | def freeze_layers(model, layer_types_to_freeze=(TemporalTransformer, TemporalConvBlock)): 45 | for name, module in model.model.diffusion_model.named_modules(): 46 | if isinstance(module, tuple(layer_types_to_freeze)): 47 | for param in module.parameters(): 48 | param.requires_grad = False 49 | 50 | def get_trainable_params(model_context, config, **kwargs): 51 | with PrintContext(f"{'='*20} get trainable params {'='*20}"): 52 | model = model_context["model"] 53 | freeze_layers(model, (TemporalTransformer, TemporalConvBlock)) 54 | trainable_params = list(model.image_proj_model.parameters()) + [param for param in model.model.diffusion_model.parameters() if param.requires_grad] 55 | 56 | return trainable_params -------------------------------------------------------------------------------- /examples/VideoCrafter/lvdm/basics.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 9 | 10 | import torch.nn as nn 11 | from utils.utils import instantiate_from_config 12 | 13 | 14 | def disabled_train(self, mode=True): 15 | """Overwrite model.train with this function to make sure train/eval mode 16 | does not change anymore.""" 17 | return self 18 | 19 | def zero_module(module): 20 | """ 21 | Zero out the parameters of a module and return it. 
22 | """ 23 | for p in module.parameters(): 24 | p.detach().zero_() 25 | return module 26 | 27 | def scale_module(module, scale): 28 | """ 29 | Scale the parameters of a module and return it. 30 | """ 31 | for p in module.parameters(): 32 | p.detach().mul_(scale) 33 | return module 34 | 35 | 36 | def conv_nd(dims, *args, **kwargs): 37 | """ 38 | Create a 1D, 2D, or 3D convolution module. 39 | """ 40 | if dims == 1: 41 | return nn.Conv1d(*args, **kwargs) 42 | elif dims == 2: 43 | return nn.Conv2d(*args, **kwargs) 44 | elif dims == 3: 45 | return nn.Conv3d(*args, **kwargs) 46 | raise ValueError(f"unsupported dimensions: {dims}") 47 | 48 | 49 | def linear(*args, **kwargs): 50 | """ 51 | Create a linear module. 52 | """ 53 | return nn.Linear(*args, **kwargs) 54 | 55 | 56 | def avg_pool_nd(dims, *args, **kwargs): 57 | """ 58 | Create a 1D, 2D, or 3D average pooling module. 59 | """ 60 | if dims == 1: 61 | return nn.AvgPool1d(*args, **kwargs) 62 | elif dims == 2: 63 | return nn.AvgPool2d(*args, **kwargs) 64 | elif dims == 3: 65 | return nn.AvgPool3d(*args, **kwargs) 66 | raise ValueError(f"unsupported dimensions: {dims}") 67 | 68 | 69 | def nonlinearity(type='silu'): 70 | if type == 'silu': 71 | return nn.SiLU() 72 | elif type == 'leaky_relu': 73 | return nn.LeakyReLU() 74 | 75 | 76 | class GroupNormSpecific(nn.GroupNorm): 77 | def forward(self, x): 78 | return super().forward(x.float()).type(x.dtype) 79 | 80 | 81 | def normalization(channels, num_groups=32): 82 | """ 83 | Make a standard normalization layer. 84 | :param channels: number of input channels. 85 | :return: an nn.Module for normalization. 86 | """ 87 | return GroupNormSpecific(num_groups, channels) 88 | 89 | 90 | class HybridConditioner(nn.Module): 91 | 92 | def __init__(self, c_concat_config, c_crossattn_config): 93 | super().__init__() 94 | self.concat_conditioner = instantiate_from_config(c_concat_config) 95 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 96 | 97 | def forward(self, c_concat, c_crossattn): 98 | c_concat = self.concat_conditioner(c_concat) 99 | c_crossattn = self.crossattn_conditioner(c_crossattn) 100 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} -------------------------------------------------------------------------------- /examples/VideoCrafter/lvdm/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | from inspect import isfunction 3 | import torch 4 | from torch import nn 5 | import torch.distributed as dist 6 | 7 | 8 | def gather_data(data, return_np=True): 9 | ''' gather data from multiple processes to one list ''' 10 | data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size())] 11 | dist.all_gather(data_list, data) # gather not supported with NCCL 12 | if return_np: 13 | data_list = [data.cpu().numpy() for data in data_list] 14 | return data_list 15 | 16 | def autocast(f): 17 | def do_autocast(*args, **kwargs): 18 | with torch.cuda.amp.autocast(enabled=True, 19 | dtype=torch.get_autocast_gpu_dtype(), 20 | cache_enabled=torch.is_autocast_cache_enabled()): 21 | return f(*args, **kwargs) 22 | return do_autocast 23 | 24 | 25 | def extract_into_tensor(a, t, x_shape): 26 | b, *_ = t.shape 27 | out = a.gather(-1, t) 28 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 29 | 30 | 31 | def noise_like(shape, device, repeat=False): 32 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 33 | noise = lambda: torch.randn(shape, 
device=device) 34 | return repeat_noise() if repeat else noise() 35 | 36 | 37 | def default(val, d): 38 | if exists(val): 39 | return val 40 | return d() if isfunction(d) else d 41 | 42 | def exists(val): 43 | return val is not None 44 | 45 | def identity(*args, **kwargs): 46 | return nn.Identity() 47 | 48 | def uniq(arr): 49 | return{el: True for el in arr}.keys() 50 | 51 | def mean_flat(tensor): 52 | """ 53 | Take the mean over all non-batch dimensions. 54 | """ 55 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 56 | 57 | def ismap(x): 58 | if not isinstance(x, torch.Tensor): 59 | return False 60 | return (len(x.shape) == 4) and (x.shape[1] > 3) 61 | 62 | def isimage(x): 63 | if not isinstance(x,torch.Tensor): 64 | return False 65 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 66 | 67 | def max_neg_value(t): 68 | return -torch.finfo(t.dtype).max 69 | 70 | def shape_to_str(x): 71 | shape_str = "x".join([str(x) for x in x.shape]) 72 | return shape_str 73 | 74 | def init_(tensor): 75 | dim = tensor.shape[-1] 76 | std = 1 / math.sqrt(dim) 77 | tensor.uniform_(-std, std) 78 | return tensor 79 | 80 | ckpt = torch.utils.checkpoint.checkpoint 81 | def checkpoint(func, inputs, params, flag): 82 | """ 83 | Evaluate a function without caching intermediate activations, allowing for 84 | reduced memory at the expense of extra compute in the backward pass. 85 | :param func: the function to evaluate. 86 | :param inputs: the argument sequence to pass to `func`. 87 | :param params: a sequence of parameters `func` depends on but does not 88 | explicitly take as arguments. 89 | :param flag: if False, disable gradient checkpointing. 90 | """ 91 | if flag: 92 | return ckpt(func, *inputs, use_reentrant=False) 93 | else: 94 | return func(*inputs) -------------------------------------------------------------------------------- /examples/VideoCrafter/lvdm/data/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from torch.utils.data import IterableDataset 3 | 4 | 5 | class Txt2ImgIterableBaseDataset(IterableDataset): 6 | ''' 7 | Define an interface to make the IterableDatasets for text2img data chainable 8 | ''' 9 | def __init__(self, num_records=0, valid_ids=None, size=256): 10 | super().__init__() 11 | self.num_records = num_records 12 | self.valid_ids = valid_ids 13 | self.sample_ids = valid_ids 14 | self.size = size 15 | 16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') 17 | 18 | def __len__(self): 19 | return self.num_records 20 | 21 | @abstractmethod 22 | def __iter__(self): 23 | pass -------------------------------------------------------------------------------- /examples/VideoCrafter/lvdm/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | 
self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self, noise=None): 36 | if noise is None: 37 | noise = torch.randn(self.mean.shape) 38 | 39 | x = self.mean + self.std * noise.to(device=self.parameters.device) 40 | return x 41 | 42 | def kl(self, other=None): 43 | if self.deterministic: 44 | return torch.Tensor([0.]) 45 | else: 46 | if other is None: 47 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 48 | + self.var - 1.0 - self.logvar, 49 | dim=[1, 2, 3]) 50 | else: 51 | return 0.5 * torch.sum( 52 | torch.pow(self.mean - other.mean, 2) / other.var 53 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 54 | dim=[1, 2, 3]) 55 | 56 | def nll(self, sample, dims=[1,2,3]): 57 | if self.deterministic: 58 | return torch.Tensor([0.]) 59 | logtwopi = np.log(2.0 * np.pi) 60 | return 0.5 * torch.sum( 61 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 62 | dim=dims) 63 | 64 | def mode(self): 65 | return self.mean 66 | 67 | 68 | def normal_kl(mean1, logvar1, mean2, logvar2): 69 | """ 70 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 71 | Compute the KL divergence between two gaussians. 72 | Shapes are automatically broadcasted, so batches can be compared to 73 | scalars, among other use cases. 74 | """ 75 | tensor = None 76 | for obj in (mean1, logvar1, mean2, logvar2): 77 | if isinstance(obj, torch.Tensor): 78 | tensor = obj 79 | break 80 | assert tensor is not None, "at least one argument must be a Tensor" 81 | 82 | # Force variances to be Tensors. Broadcasting helps convert scalars to 83 | # Tensors, but it does not work for torch.exp(). 
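    # Closed form computed below:
    # KL(N(m1, e^{l1}) || N(m2, e^{l2}))
    #   = 0.5 * (l2 - l1 + e^{l1 - l2} + (m1 - m2)^2 * e^{-l2} - 1)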
84 | logvar1, logvar2 = [ 85 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 86 | for x in (logvar1, logvar2) 87 | ] 88 | 89 | return 0.5 * ( 90 | -1.0 91 | + logvar2 92 | - logvar1 93 | + torch.exp(logvar1 - logvar2) 94 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 95 | ) -------------------------------------------------------------------------------- /examples/VideoCrafter/lvdm/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 
74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) -------------------------------------------------------------------------------- /examples/VideoCrafter/lvdm/modules/encoders/ip_resampler.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class ImageProjModel(nn.Module): 8 | """Projection Model""" 9 | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): 10 | super().__init__() 11 | self.cross_attention_dim = cross_attention_dim 12 | self.clip_extra_context_tokens = clip_extra_context_tokens 13 | self.proj = nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) 14 | self.norm = nn.LayerNorm(cross_attention_dim) 15 | 16 | def forward(self, image_embeds): 17 | #embeds = image_embeds 18 | embeds = image_embeds.type(list(self.proj.parameters())[0].dtype) 19 | clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim) 20 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens) 21 | return clip_extra_context_tokens 22 | 23 | # FFN 24 | def FeedForward(dim, mult=4): 25 | inner_dim = int(dim * mult) 26 | return nn.Sequential( 27 | nn.LayerNorm(dim), 28 | nn.Linear(dim, inner_dim, bias=False), 29 | nn.GELU(), 30 | nn.Linear(inner_dim, dim, bias=False), 31 | ) 32 | 33 | 34 | def reshape_tensor(x, heads): 35 | bs, length, width = x.shape 36 | #(bs, length, width) --> (bs, length, n_heads, dim_per_head) 37 | x = x.view(bs, length, heads, -1) 38 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) 39 | x = x.transpose(1, 2) 40 | # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) 41 | x = x.reshape(bs, heads, length, -1) 42 | return x 43 | 44 | 45 | class PerceiverAttention(nn.Module): 46 | def __init__(self, *, dim, dim_head=64, heads=8): 47 | super().__init__() 48 | self.scale = dim_head**-0.5 49 | self.dim_head = dim_head 50 | self.heads = heads 51 | inner_dim = dim_head * heads 52 | 53 | self.norm1 = nn.LayerNorm(dim) 54 | self.norm2 = nn.LayerNorm(dim) 55 | 56 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 57 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 58 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 59 | 60 | 61 | def forward(self, x, latents): 62 | """ 63 | Args: 64 | x (torch.Tensor): image features 65 | shape (b, n1, D) 66 | latent (torch.Tensor): latent features 67 | shape (b, n2, D) 68 | """ 69 | x = self.norm1(x) 70 | latents = self.norm2(latents) 71 | 72 | b, l, _ = latents.shape 73 | 74 | q = self.to_q(latents) 75 | kv_input = torch.cat((x, latents), dim=-2) 76 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 77 | 78 | q = reshape_tensor(q, self.heads) 79 | k = reshape_tensor(k, self.heads) 80 | v = reshape_tensor(v, self.heads) 81 | 82 | # attention 83 | scale = 1 / math.sqrt(math.sqrt(self.dim_head)) 84 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards 85 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 86 | out = weight @ v 87 | 88 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1) 89 | 90 | return self.to_out(out) 91 | 92 | 93 | class Resampler(nn.Module): 94 | def __init__( 95 | self, 96 | dim=1024, 97 | depth=8, 98 | 
dim_head=64, 99 | heads=16, 100 | num_queries=8, 101 | embedding_dim=768, 102 | output_dim=1024, 103 | ff_mult=4, 104 | ): 105 | super().__init__() 106 | 107 | self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) 108 | 109 | self.proj_in = nn.Linear(embedding_dim, dim) 110 | 111 | self.proj_out = nn.Linear(dim, output_dim) 112 | self.norm_out = nn.LayerNorm(output_dim) 113 | 114 | self.layers = nn.ModuleList([]) 115 | for _ in range(depth): 116 | self.layers.append( 117 | nn.ModuleList( 118 | [ 119 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 120 | FeedForward(dim=dim, mult=ff_mult), 121 | ] 122 | ) 123 | ) 124 | 125 | def forward(self, x): 126 | 127 | latents = self.latents.repeat(x.size(0), 1, 1) 128 | 129 | x = self.proj_in(x) 130 | 131 | for attn, ff in self.layers: 132 | latents = attn(x, latents) + latents 133 | latents = ff(latents) + latents 134 | 135 | latents = self.proj_out(latents) 136 | return self.norm_out(latents) -------------------------------------------------------------------------------- /examples/VideoCrafter/lvdm/modules/encoders/resampler.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py 2 | # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py 3 | # and https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class ImageProjModel(nn.Module): 10 | """Projection Model""" 11 | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): 12 | super().__init__() 13 | self.cross_attention_dim = cross_attention_dim 14 | self.clip_extra_context_tokens = clip_extra_context_tokens 15 | self.proj = nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) 16 | self.norm = nn.LayerNorm(cross_attention_dim) 17 | 18 | def forward(self, image_embeds): 19 | #embeds = image_embeds 20 | embeds = image_embeds.type(list(self.proj.parameters())[0].dtype) 21 | clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim) 22 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens) 23 | return clip_extra_context_tokens 24 | 25 | 26 | # FFN 27 | def FeedForward(dim, mult=4): 28 | inner_dim = int(dim * mult) 29 | return nn.Sequential( 30 | nn.LayerNorm(dim), 31 | nn.Linear(dim, inner_dim, bias=False), 32 | nn.GELU(), 33 | nn.Linear(inner_dim, dim, bias=False), 34 | ) 35 | 36 | 37 | def reshape_tensor(x, heads): 38 | bs, length, width = x.shape 39 | #(bs, length, width) --> (bs, length, n_heads, dim_per_head) 40 | x = x.view(bs, length, heads, -1) 41 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) 42 | x = x.transpose(1, 2) 43 | # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) 44 | x = x.reshape(bs, heads, length, -1) 45 | return x 46 | 47 | 48 | class PerceiverAttention(nn.Module): 49 | def __init__(self, *, dim, dim_head=64, heads=8): 50 | super().__init__() 51 | self.scale = dim_head**-0.5 52 | self.dim_head = dim_head 53 | self.heads = heads 54 | inner_dim = dim_head * heads 55 | 56 | self.norm1 = nn.LayerNorm(dim) 57 | self.norm2 = nn.LayerNorm(dim) 58 | 59 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 60 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 61 | self.to_out = 
nn.Linear(inner_dim, dim, bias=False) 62 | 63 | 64 | def forward(self, x, latents): 65 | """ 66 | Args: 67 | x (torch.Tensor): image features 68 | shape (b, n1, D) 69 | latent (torch.Tensor): latent features 70 | shape (b, n2, D) 71 | """ 72 | x = self.norm1(x) 73 | latents = self.norm2(latents) 74 | 75 | b, l, _ = latents.shape 76 | 77 | q = self.to_q(latents) 78 | kv_input = torch.cat((x, latents), dim=-2) 79 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 80 | 81 | q = reshape_tensor(q, self.heads) 82 | k = reshape_tensor(k, self.heads) 83 | v = reshape_tensor(v, self.heads) 84 | 85 | # attention 86 | scale = 1 / math.sqrt(math.sqrt(self.dim_head)) 87 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards 88 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 89 | out = weight @ v 90 | 91 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1) 92 | 93 | return self.to_out(out) 94 | 95 | 96 | class Resampler(nn.Module): 97 | def __init__( 98 | self, 99 | dim=1024, 100 | depth=8, 101 | dim_head=64, 102 | heads=16, 103 | num_queries=8, 104 | embedding_dim=768, 105 | output_dim=1024, 106 | ff_mult=4, 107 | video_length=None, # using frame-wise version or not 108 | ): 109 | super().__init__() 110 | ## queries for a single frame / image 111 | self.num_queries = num_queries 112 | self.video_length = video_length 113 | 114 | ## queries for each frame 115 | if video_length is not None: 116 | num_queries = num_queries * video_length 117 | 118 | self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) 119 | self.proj_in = nn.Linear(embedding_dim, dim) 120 | self.proj_out = nn.Linear(dim, output_dim) 121 | self.norm_out = nn.LayerNorm(output_dim) 122 | 123 | self.layers = nn.ModuleList([]) 124 | for _ in range(depth): 125 | self.layers.append( 126 | nn.ModuleList( 127 | [ 128 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 129 | FeedForward(dim=dim, mult=ff_mult), 130 | ] 131 | ) 132 | ) 133 | 134 | def forward(self, x): 135 | latents = self.latents.repeat(x.size(0), 1, 1) ## B (T L) C 136 | x = self.proj_in(x) 137 | 138 | for attn, ff in self.layers: 139 | latents = attn(x, latents) + latents 140 | latents = ff(latents) + latents 141 | 142 | latents = self.proj_out(latents) 143 | latents = self.norm_out(latents) # B L C or B (T L) C 144 | 145 | return latents -------------------------------------------------------------------------------- /examples/VideoCrafter/main/callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import logging 4 | mainlogger = logging.getLogger('mainlogger') 5 | 6 | import torch 7 | import torchvision 8 | import pytorch_lightning as pl 9 | from pytorch_lightning.callbacks import Callback 10 | from pytorch_lightning.utilities import rank_zero_only 11 | from pytorch_lightning.utilities import rank_zero_info 12 | from utils.save_video import log_local, prepare_to_log 13 | 14 | 15 | class ImageLogger(Callback): 16 | def __init__(self, batch_frequency, max_images=8, clamp=True, rescale=True, save_dir=None, \ 17 | to_local=False, log_images_kwargs=None): 18 | super().__init__() 19 | self.rescale = rescale 20 | self.batch_freq = batch_frequency 21 | self.max_images = max_images 22 | self.to_local = to_local 23 | self.clamp = clamp 24 | self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} 25 | if self.to_local: 26 | ## default save dir 27 | self.save_dir = os.path.join(save_dir, "images") 28 | 
os.makedirs(os.path.join(self.save_dir, "train"), exist_ok=True) 29 | os.makedirs(os.path.join(self.save_dir, "val"), exist_ok=True) 30 | 31 | def log_to_tensorboard(self, pl_module, batch_logs, filename, split, save_fps=8): 32 | """ log images and videos to tensorboard """ 33 | global_step = pl_module.global_step 34 | for key in batch_logs: 35 | value = batch_logs[key] 36 | tag = "gs%d-%s/%s-%s"%(global_step, split, filename, key) 37 | if isinstance(value, list) and isinstance(value[0], str): 38 | captions = ' |------| '.join(value) 39 | pl_module.logger.experiment.add_text(tag, captions, global_step=global_step) 40 | elif isinstance(value, torch.Tensor) and value.dim() == 5: 41 | video = value 42 | n = video.shape[0] 43 | video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w 44 | frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0) for framesheet in video] #[3, n*h, 1*w] 45 | grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [t, 3, n*h, w] 46 | grid = (grid + 1.0) / 2.0 47 | grid = grid.unsqueeze(dim=0) 48 | pl_module.logger.experiment.add_video(tag, grid, fps=save_fps, global_step=global_step) 49 | elif isinstance(value, torch.Tensor) and value.dim() == 4: 50 | img = value 51 | grid = torchvision.utils.make_grid(img, nrow=int(img.shape[0]), padding=0) # nrow = batch size of this image grid 52 | grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w 53 | pl_module.logger.experiment.add_image(tag, grid, global_step=global_step) 54 | else: 55 | pass 56 | 57 | @rank_zero_only 58 | def log_batch_imgs(self, pl_module, batch, batch_idx, split="train"): 59 | """ generate images, then save and log to tensorboard """ 60 | skip_freq = self.batch_freq if split == "train" else 5 61 | if (batch_idx+1) % skip_freq == 0: 62 | is_train = pl_module.training 63 | if is_train: 64 | pl_module.eval() 65 | torch.cuda.empty_cache() 66 | with torch.no_grad(): 67 | log_func = pl_module.log_images 68 | batch_logs = log_func(batch, split=split, **self.log_images_kwargs) 69 | 70 | ## process: move to CPU and clamp 71 | batch_logs = prepare_to_log(batch_logs, self.max_images, self.clamp) 72 | torch.cuda.empty_cache() 73 | 74 | filename = "ep{}_idx{}_rank{}".format( 75 | pl_module.current_epoch, 76 | batch_idx, 77 | pl_module.global_rank) 78 | if self.to_local: 79 | mainlogger.info("Log [%s] batch <%s> to local ..."%(split, filename)) 80 | filename = "gs{}_".format(pl_module.global_step) + filename 81 | log_local(batch_logs, os.path.join(self.save_dir, split), filename, save_fps=10) 82 | else: 83 | mainlogger.info("Log [%s] batch <%s> to tensorboard ..."%(split, filename)) 84 | self.log_to_tensorboard(pl_module, batch_logs, filename, split, save_fps=10) 85 | mainlogger.info('Finish!') 86 | 87 | if is_train: 88 | pl_module.train() 89 | 90 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None): 91 | if self.batch_freq != -1 and pl_module.logdir: 92 | self.log_batch_imgs(pl_module, batch, batch_idx, split="train") 93 | 94 | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None): 95 | ## unlike validation_step(), which saves the whole validation set and only keeps the latest results, 96 | ## this logs a small subset from every validation run without overwriting earlier records 97 | if self.batch_freq != -1 and pl_module.logdir: 98 | self.log_batch_imgs(pl_module, batch, batch_idx, split="val") 99 | if hasattr(pl_module, 'calibrate_grad_norm'): 100 | if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0: 101 |
self.log_gradients(trainer, pl_module, batch_idx=batch_idx) 102 | 103 | 104 | class CUDACallback(Callback): 105 | # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py 106 | def on_train_epoch_start(self, trainer, pl_module): 107 | # Reset the memory use counter 108 | # lightning update 109 | if int((pl.__version__).split('.')[1])>=7: 110 | gpu_index = trainer.strategy.root_device.index 111 | else: 112 | gpu_index = trainer.root_gpu 113 | torch.cuda.reset_peak_memory_stats(gpu_index) 114 | torch.cuda.synchronize(gpu_index) 115 | self.start_time = time.time() 116 | 117 | def on_train_epoch_end(self, trainer, pl_module): 118 | if int((pl.__version__).split('.')[1])>=7: 119 | gpu_index = trainer.strategy.root_device.index 120 | else: 121 | gpu_index = trainer.root_gpu 122 | torch.cuda.synchronize(gpu_index) 123 | max_memory = torch.cuda.max_memory_allocated(gpu_index) / 2 ** 20 124 | epoch_time = time.time() - self.start_time 125 | 126 | try: 127 | max_memory = trainer.training_type_plugin.reduce(max_memory) 128 | epoch_time = trainer.training_type_plugin.reduce(epoch_time) 129 | 130 | rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds") 131 | rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB") 132 | except AttributeError: 133 | pass 134 | -------------------------------------------------------------------------------- /examples/VideoCrafter/main/utils_data.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import numpy as np 3 | 4 | import torch 5 | import pytorch_lightning as pl 6 | from torch.utils.data import DataLoader, Dataset 7 | 8 | import os, sys 9 | os.chdir(sys.path[0]) 10 | sys.path.append("..") 11 | from lvdm.data.base import Txt2ImgIterableBaseDataset 12 | from utils.utils import instantiate_from_config 13 | 14 | 15 | def worker_init_fn(_): 16 | worker_info = torch.utils.data.get_worker_info() 17 | 18 | dataset = worker_info.dataset 19 | worker_id = worker_info.id 20 | 21 | if isinstance(dataset, Txt2ImgIterableBaseDataset): 22 | split_size = dataset.num_records // worker_info.num_workers 23 | # reset num_records to the true number to retain reliable length information 24 | dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size] 25 | current_id = np.random.choice(len(np.random.get_state()[1]), 1) 26 | return np.random.seed(np.random.get_state()[1][current_id] + worker_id) 27 | else: 28 | return np.random.seed(np.random.get_state()[1][0] + worker_id) 29 | 30 | 31 | class WrappedDataset(Dataset): 32 | """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset""" 33 | 34 | def __init__(self, dataset): 35 | self.data = dataset 36 | 37 | def __len__(self): 38 | return len(self.data) 39 | 40 | def __getitem__(self, idx): 41 | return self.data[idx] 42 | 43 | 44 | class DataModuleFromConfig(pl.LightningDataModule): 45 | def __init__(self, batch_size, train=None, validation=None, test=None, predict=None, 46 | wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False, 47 | shuffle_val_dataloader=False, train_img=None, 48 | test_max_n_samples=None): 49 | super().__init__() 50 | self.batch_size = batch_size 51 | self.dataset_configs = dict() 52 | self.num_workers = num_workers if num_workers is not None else batch_size * 2 53 | self.use_worker_init_fn = use_worker_init_fn 54 | if train is not None: 55 | self.dataset_configs["train"] = train 56 | self.train_dataloader = self._train_dataloader 57 | if 
validation is not None: 58 | self.dataset_configs["validation"] = validation 59 | self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader) 60 | if test is not None: 61 | self.dataset_configs["test"] = test 62 | self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader) 63 | if predict is not None: 64 | self.dataset_configs["predict"] = predict 65 | self.predict_dataloader = self._predict_dataloader 66 | 67 | self.img_loader = None 68 | self.wrap = wrap 69 | self.test_max_n_samples = test_max_n_samples 70 | self.collate_fn = None 71 | 72 | def prepare_data(self): 73 | pass 74 | 75 | def setup(self, stage=None): 76 | self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs) 77 | if self.wrap: 78 | for k in self.datasets: 79 | self.datasets[k] = WrappedDataset(self.datasets[k]) 80 | 81 | def _train_dataloader(self): 82 | is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) 83 | if is_iterable_dataset or self.use_worker_init_fn: 84 | init_fn = worker_init_fn 85 | else: 86 | init_fn = None 87 | loader = DataLoader(self.datasets["train"], batch_size=self.batch_size, 88 | num_workers=self.num_workers, shuffle=False if is_iterable_dataset else True, 89 | worker_init_fn=init_fn, collate_fn=self.collate_fn, 90 | ) 91 | return loader 92 | 93 | def _val_dataloader(self, shuffle=False): 94 | if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: 95 | init_fn = worker_init_fn 96 | else: 97 | init_fn = None 98 | return DataLoader(self.datasets["validation"], 99 | batch_size=self.batch_size, 100 | num_workers=self.num_workers, 101 | worker_init_fn=init_fn, 102 | shuffle=shuffle, 103 | collate_fn=self.collate_fn, 104 | ) 105 | 106 | def _test_dataloader(self, shuffle=False): 107 | try: 108 | is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) 109 | except: 110 | is_iterable_dataset = isinstance(self.datasets['test'], Txt2ImgIterableBaseDataset) 111 | 112 | if is_iterable_dataset or self.use_worker_init_fn: 113 | init_fn = worker_init_fn 114 | else: 115 | init_fn = None 116 | 117 | # do not shuffle dataloader for iterable dataset 118 | shuffle = shuffle and (not is_iterable_dataset) 119 | if self.test_max_n_samples is not None: 120 | dataset = torch.utils.data.Subset(self.datasets["test"], list(range(self.test_max_n_samples))) 121 | else: 122 | dataset = self.datasets["test"] 123 | return DataLoader(dataset, batch_size=self.batch_size, 124 | num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle, 125 | collate_fn=self.collate_fn, 126 | ) 127 | 128 | def _predict_dataloader(self, shuffle=False): 129 | if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: 130 | init_fn = worker_init_fn 131 | else: 132 | init_fn = None 133 | return DataLoader(self.datasets["predict"], batch_size=self.batch_size, 134 | num_workers=self.num_workers, worker_init_fn=init_fn, 135 | collate_fn=self.collate_fn, 136 | ) 137 | -------------------------------------------------------------------------------- /examples/VideoCrafter/prompts/512/13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/VideoCrafter/prompts/512/13.png -------------------------------------------------------------------------------- 
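The DataModuleFromConfig above is normally built by main/trainer.py from the data section of configs/train.yaml via instantiate_from_config. Below is a minimal sketch of driving it directly from a plain dict; the torchvision FakeData target is only a stand-in dataset so the sketch stays self-contained, and the import assumes the working directory matches the examples/VideoCrafter layout (utils_data.py adjusts sys.path itself).

```python
# Minimal sketch only: wire DataModuleFromConfig from a plain dict instead of the YAML
# config. FakeData is a stand-in; the real training config points "train" at
# lvdm.data.webvid.WebVid with its own parameters.
from main.utils_data import DataModuleFromConfig

data_cfg = {
    "batch_size": 2,
    "num_workers": 2,
    "wrap": True,  # wrap the instantiated dataset in WrappedDataset
    "train": {
        "target": "torchvision.datasets.FakeData",        # stand-in dataset class
        "params": {"size": 8, "image_size": (3, 64, 64)},
    },
}

dm = DataModuleFromConfig(**data_cfg)
dm.setup()
print(len(dm.datasets["train"]))        # 8
train_loader = dm._train_dataloader()   # a regular torch DataLoader (shuffle=True here)
```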
/examples/VideoCrafter/prompts/512/14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/VideoCrafter/prompts/512/14.png -------------------------------------------------------------------------------- /examples/VideoCrafter/prompts/512/15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/VideoCrafter/prompts/512/15.png -------------------------------------------------------------------------------- /examples/VideoCrafter/prompts/512/18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/VideoCrafter/prompts/512/18.png -------------------------------------------------------------------------------- /examples/VideoCrafter/prompts/512/25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/VideoCrafter/prompts/512/25.png -------------------------------------------------------------------------------- /examples/VideoCrafter/prompts/512/29.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/VideoCrafter/prompts/512/29.png -------------------------------------------------------------------------------- /examples/VideoCrafter/prompts/512/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/VideoCrafter/prompts/512/3.png -------------------------------------------------------------------------------- /examples/VideoCrafter/prompts/512/30.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/VideoCrafter/prompts/512/30.png -------------------------------------------------------------------------------- /examples/VideoCrafter/prompts/512/32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/VideoCrafter/prompts/512/32.png -------------------------------------------------------------------------------- /examples/VideoCrafter/prompts/512/33.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/VideoCrafter/prompts/512/33.png -------------------------------------------------------------------------------- /examples/VideoCrafter/prompts/512/35.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/VideoCrafter/prompts/512/35.png -------------------------------------------------------------------------------- /examples/VideoCrafter/prompts/512/36.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/VideoCrafter/prompts/512/36.png -------------------------------------------------------------------------------- /examples/VideoCrafter/prompts/512/41.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/VideoCrafter/prompts/512/41.png -------------------------------------------------------------------------------- /examples/VideoCrafter/prompts/512/43.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/VideoCrafter/prompts/512/43.png -------------------------------------------------------------------------------- /examples/VideoCrafter/prompts/512/47.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/VideoCrafter/prompts/512/47.png -------------------------------------------------------------------------------- /examples/VideoCrafter/prompts/512/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/VideoCrafter/prompts/512/5.png -------------------------------------------------------------------------------- /examples/VideoCrafter/prompts/512/52.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/VideoCrafter/prompts/512/52.png -------------------------------------------------------------------------------- /examples/VideoCrafter/prompts/512/55.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/VideoCrafter/prompts/512/55.png -------------------------------------------------------------------------------- /examples/VideoCrafter/prompts/512/65.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/VideoCrafter/prompts/512/65.png -------------------------------------------------------------------------------- /examples/VideoCrafter/prompts/512/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/VideoCrafter/prompts/512/7.png -------------------------------------------------------------------------------- /examples/VideoCrafter/prompts/512/A girl with long curly blonde hair and sunglasses, camera pan from left to right..png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/VideoCrafter/prompts/512/A girl with long curly blonde hair and sunglasses, camera pan from left to right..png 
-------------------------------------------------------------------------------- /examples/VideoCrafter/prompts/512/A panda wearing sunglasses walking in slow-motion under water, in photorealistic style..png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/VideoCrafter/prompts/512/A panda wearing sunglasses walking in slow-motion under water, in photorealistic style..png -------------------------------------------------------------------------------- /examples/VideoCrafter/prompts/512/a car parked in a parking lot with palm trees nearby,calm seas and skies..png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/VideoCrafter/prompts/512/a car parked in a parking lot with palm trees nearby,calm seas and skies..png -------------------------------------------------------------------------------- /examples/VideoCrafter/prompts/512/test_prompts.txt: -------------------------------------------------------------------------------- 1 | The sun sets on the horizon, casting a golden glow over the turbulent sea. 2 | A duck swimming in the lake. 3 | A duck swimming in the lake. 4 | A man riding motor on a mountain road. 5 | A couple hugging a cat. 6 | Mystical hills with a glowing blue portal. 7 | Two tigers fighting in the snow. 8 | A soldier riding a horse. 9 | Mountains under the starlight. 10 | A duck swimming in the lake. 11 | A girl walks up the steps of a palace. 12 | A woman with flowing, curly silver hair and dark eyes. 13 | Rabbits playing in a river. 14 | Sailing of boats on the water surface. 15 | A plate full of food, with camera spinning. 16 | A cartoon girl with brown curly hair splashes joyfully in a bubble-filled bathtub. 17 | Fireworks exploding in the sky. 18 | A kitten lying on the bed. 19 | A duck swimming in the lake. 20 | Donkeys in traditional attire gallop across a lush green meadow. 21 | A girl with long curly blonde hair and sunglasses, camera pan from left to right. 22 | A panda wearing sunglasses walking in slow-motion under water, in photorealistic style. 23 | a car parked in a parking lot with palm trees nearby,calm seas and skies. 
24 | -------------------------------------------------------------------------------- /examples/VideoCrafter/requirements.txt: -------------------------------------------------------------------------------- 1 | aiofiles==23.2.1 2 | aiohttp==3.9.3 3 | aiosignal==1.3.1 4 | altair==5.2.0 5 | annotated-types==0.6.0 6 | antlr4-python3-runtime==4.8 7 | anyio==4.3.0 8 | APScheduler==3.9.1 9 | asttokens==2.0.5 10 | async-timeout==4.0.3 11 | attrdict==2.0.1 12 | attrs==23.2.0 13 | av==12.0.0 14 | backcall==0.2.0 15 | backports.zoneinfo==0.2.1 16 | certifi==2024.2.2 17 | charset-normalizer==3.3.2 18 | click==8.1.7 19 | cmake==3.28.3 20 | colorama==0.4.4 21 | contourpy==1.1.1 22 | cycler==0.11.0 23 | decorator==5.1.1 24 | decord==0.6.0 25 | einops==0.3.0 26 | exceptiongroup==1.2.0 27 | executing==0.8.3 28 | fairscale==0.4.13 29 | fastapi==0.110.0 30 | ffmpy==0.3.2 31 | filelock==3.13.1 32 | fire==0.6.0 33 | fonttools==4.33.3 34 | frozenlist==1.4.1 35 | fsspec==2024.3.1 36 | ftfy==6.2.0 37 | gradio==4.22.0 38 | gradio_client==0.13.0 39 | h11==0.14.0 40 | httpcore==1.0.4 41 | httpx==0.27.0 42 | huggingface-hub==0.21.4 43 | idna==3.6 44 | igraph==0.9.11 45 | imageio==2.9.0 46 | imageio-ffmpeg==0.4.9 47 | importlib_resources==6.3.2 48 | install==1.3.5 49 | ipython==8.4.0 50 | jedi==0.18.1 51 | Jinja2==3.1.3 52 | joblib==1.3.2 53 | jsonlines==4.0.0 54 | jsonschema==4.21.1 55 | jsonschema-specifications==2023.12.1 56 | kaleido==0.2.1 57 | kiwisolver==1.4.2 58 | kornia==0.7.2 59 | kornia_rs==0.1.2 60 | lightning-utilities==0.3.0 61 | lit==18.1.1 62 | markdown-it-py==3.0.0 63 | MarkupSafe==2.1.5 64 | matplotlib==3.5.2 65 | matplotlib-inline==0.1.3 66 | mdurl==0.1.2 67 | moviepy==1.0.3 68 | mpmath==1.2.1 69 | multidict==6.0.5 70 | mypy-extensions==1.0.0 71 | networkx==3.1 72 | numpy==1.22.4 73 | nvidia-cublas-cu11==11.10.3.66 74 | nvidia-cuda-cupti-cu11==11.7.101 75 | nvidia-cuda-nvrtc-cu11==11.7.99 76 | nvidia-cuda-runtime-cu11==11.7.99 77 | nvidia-cudnn-cu11==8.5.0.96 78 | nvidia-cufft-cu11==10.9.0.58 79 | nvidia-curand-cu11==10.2.10.91 80 | nvidia-cusolver-cu11==11.4.0.1 81 | nvidia-cusparse-cu11==11.7.4.91 82 | nvidia-nccl-cu11==2.14.3 83 | nvidia-nvtx-cu11==11.7.91 84 | omegaconf==2.1.1 85 | open-clip-torch==2.22.0 86 | opencv-python==4.9.0.80 87 | opencv-python-headless==4.9.0.80 88 | orjson==3.9.15 89 | packaging==21.3 90 | pandas==2.0.0 91 | parso==0.8.3 92 | pexpect==4.9.0 93 | pickleshare==0.7.5 94 | Pillow==9.1.1 95 | pkgutil_resolve_name==1.3.10 96 | plotly==5.8.2 97 | proglog==0.1.10 98 | prompt-toolkit==3.0.29 99 | protobuf==3.20.3 100 | ptyprocess==0.7.0 101 | pure-eval==0.2.2 102 | pydantic==2.6.4 103 | pydantic_core==2.16.3 104 | pydub==0.25.1 105 | Pygments==2.12.0 106 | pyparsing==3.0.9 107 | pyre-extensions==0.0.29 108 | python-dateutil==2.8.2 109 | python-multipart==0.0.9 110 | pytorch-lightning==1.8.3 111 | pytz==2024.1 112 | pytz-deprecation-shim==0.1.0.post0 113 | PyYAML==6.0 114 | referencing==0.34.0 115 | regex==2023.12.25 116 | requests==2.31.0 117 | rich==13.7.1 118 | rpds-py==0.18.0 119 | ruff==0.3.3 120 | safetensors==0.4.2 121 | scikit-learn==1.3.2 122 | scipy==1.10.1 123 | semantic-version==2.10.0 124 | sentencepiece==0.2.0 125 | shellingham==1.5.4 126 | six==1.16.0 127 | sniffio==1.3.1 128 | stack-data==0.2.0 129 | starlette==0.36.3 130 | sympy==1.10.1 131 | tenacity==8.0.1 132 | tensorboardX==2.6.2.2 133 | termcolor==2.4.0 134 | texttable==1.6.4 135 | threadpoolctl==3.4.0 136 | timm==0.9.16 137 | tokenizers==0.19.0 138 | tomlkit==0.12.0 139 | toolz==0.12.1 140 | torch==2.0.0 
141 | torchmetrics==0.11.4 142 | torchvision==0.15.1 143 | tqdm==4.64.0 144 | traitlets==5.2.2.post1 145 | transformers==4.40.1 146 | triton==2.0.0 147 | typer==0.9.0 148 | typing-inspect==0.9.0 149 | typing_extensions==4.10.0 150 | tzdata==2024.1 151 | tzlocal==4.1 152 | urllib3==2.2.1 153 | uvicorn==0.29.0 154 | wcwidth==0.2.5 155 | websockets==11.0.3 156 | xformers==0.0.19 157 | yarl==1.9.4 158 | zipp==3.18.1 159 | -------------------------------------------------------------------------------- /examples/VideoCrafter/scripts/evaluation/ddp_wrapper.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import argparse, importlib 3 | from pytorch_lightning import seed_everything 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | def setup_dist(local_rank): 9 | if dist.is_initialized(): 10 | return 11 | torch.cuda.set_device(local_rank) 12 | torch.distributed.init_process_group('nccl', init_method='env://') 13 | 14 | 15 | def get_dist_info(): 16 | if dist.is_available(): 17 | initialized = dist.is_initialized() 18 | else: 19 | initialized = False 20 | if initialized: 21 | rank = dist.get_rank() 22 | world_size = dist.get_world_size() 23 | else: 24 | rank = 0 25 | world_size = 1 26 | return rank, world_size 27 | 28 | 29 | if __name__ == '__main__': 30 | now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument("--module", type=str, help="module name", default="inference") 33 | parser.add_argument("--local_rank", type=int, nargs="?", help="for ddp", default=0) 34 | args, unknown = parser.parse_known_args() 35 | inference_api = importlib.import_module(args.module, package=None) 36 | 37 | inference_parser = inference_api.get_parser() 38 | inference_args, unknown = inference_parser.parse_known_args() 39 | 40 | seed_everything(inference_args.seed) 41 | setup_dist(args.local_rank) 42 | torch.backends.cudnn.benchmark = True 43 | rank, gpu_num = get_dist_info() 44 | 45 | # inference_args.savedir = inference_args.savedir+str('_seed')+str(inference_args.seed) 46 | print("@DynamiCrafter Inference [rank%d]: %s"%(rank, now)) 47 | inference_api.run_inference(inference_args, gpu_num, rank) -------------------------------------------------------------------------------- /examples/VideoCrafter/train.sh: -------------------------------------------------------------------------------- 1 | # args 2 | name="training_512_v1.0" 3 | config_file=configs/train.yaml 4 | HOST_GPU_NUM=8 5 | # save root dir for logs, checkpoints, tensorboard record, etc. 
6 | save_root="train" 7 | 8 | mkdir -p $save_root/$name 9 | 10 | # run 11 | 12 | torchrun \ 13 | --nproc_per_node=$HOST_GPU_NUM --nnodes=1 --master_addr=127.0.0.1 --master_port=12352 --node_rank=0 \ 14 | ./main/trainer.py \ 15 | --base $config_file \ 16 | --train \ 17 | --name $name \ 18 | --logdir $save_root \ 19 | --devices $HOST_GPU_NUM \ 20 | lightning.trainer.num_nodes=1 -------------------------------------------------------------------------------- /examples/VideoCrafter/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os, importlib 2 | import numpy as np 3 | import cv2 4 | import torch 5 | import torch.distributed as dist 6 | 7 | 8 | def count_params(model, verbose=False): 9 | total_params = sum(p.numel() for p in model.parameters()) 10 | if verbose: 11 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 12 | return total_params 13 | 14 | 15 | def check_istarget(name, para_list): 16 | """ 17 | name: full name of the source parameter 18 | para_list: partial names of the target parameters 19 | """ 20 | istarget=False 21 | for para in para_list: 22 | if para in name: 23 | return True 24 | return istarget 25 | 26 | 27 | def instantiate_from_config(config): 28 | if "target" not in config: 29 | if config == '__is_first_stage__': 30 | return None 31 | elif config == "__is_unconditional__": 32 | return None 33 | raise KeyError("Expected key `target` to instantiate.") 34 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 35 | 36 | 37 | def get_obj_from_str(string, reload=False): 38 | module, cls = string.rsplit(".", 1) 39 | if reload: 40 | module_imp = importlib.import_module(module) 41 | importlib.reload(module_imp) 42 | return getattr(importlib.import_module(module, package=None), cls) 43 | 44 | 45 | def load_npz_from_dir(data_dir): 46 | data = [np.load(os.path.join(data_dir, data_name))['arr_0'] for data_name in os.listdir(data_dir)] 47 | data = np.concatenate(data, axis=0) 48 | return data 49 | 50 | 51 | def load_npz_from_paths(data_paths): 52 | data = [np.load(data_path)['arr_0'] for data_path in data_paths] 53 | data = np.concatenate(data, axis=0) 54 | return data 55 | 56 | 57 | def resize_numpy_image(image, max_resolution=512 * 512, resize_short_edge=None): 58 | h, w = image.shape[:2] 59 | if resize_short_edge is not None: 60 | k = resize_short_edge / min(h, w) 61 | else: 62 | k = max_resolution / (h * w) 63 | k = k**0.5 64 | h = int(np.round(h * k / 64)) * 64 65 | w = int(np.round(w * k / 64)) * 64 66 | image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4) 67 | return image 68 | 69 | 70 | def setup_dist(args): 71 | if dist.is_initialized(): 72 | return 73 | torch.cuda.set_device(args.local_rank) 74 | torch.distributed.init_process_group( 75 | 'nccl', 76 | init_method='env://' 77 | ) -------------------------------------------------------------------------------- /examples/animate-anything/demo/demo.jsonl: -------------------------------------------------------------------------------- 1 | {"dir": "demo/image/52.png", "text": "Fireworks exploding in sky.", "id": "0000040100"} 2 | -------------------------------------------------------------------------------- /examples/animate-anything/demo/image/52.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/animate-anything/demo/image/52.png
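instantiate_from_config (with get_obj_from_str) above is the glue used throughout the YAML configs: any mapping with a dotted "target" path and an optional "params" dict is turned into a live object. A small illustrative use follows; the target and params here are arbitrary examples, not entries from the repo's configs, and the import assumes the examples/VideoCrafter directory is on the path.

```python
# Illustrative only: resolve a {"target": ..., "params": ...} mapping into an object,
# the same way the trainer builds models, datasets and callbacks declared in the YAML.
from utils.utils import instantiate_from_config

cfg = {
    "target": "torch.nn.Linear",                        # any importable "module.Class" path
    "params": {"in_features": 1024, "out_features": 4096},
}
layer = instantiate_from_config(cfg)  # == torch.nn.Linear(in_features=1024, out_features=4096)
```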
-------------------------------------------------------------------------------- /examples/animate-anything/inference.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node=1 --master_port=29502 inference.py \ 2 | --config output/latent/animate_anything_512_v1.02/config.yaml \ 3 | --eval \ 4 | M=900 \ 5 | validation_data.dataset_jsonl="demo/demo.jsonl" -------------------------------------------------------------------------------- /examples/animate-anything/output/latent/animate_anything_512_v1.02/config.yaml: -------------------------------------------------------------------------------- 1 | pretrained_model_path: output/latent/animate_anything_512_v1.02 2 | output_dir: ./output/latent 3 | name: inference 4 | train_data: 5 | width: 512 6 | height: 512 7 | use_bucketing: false 8 | return_mask: true 9 | return_motion: true 10 | sample_start_idx: 1 11 | fps: 8 12 | n_sample_frames: 16 13 | json_path: /webvid/animation0.json 14 | train_timesteps_from: 1000 15 | validation_data: 16 | dataset_jsonl: 17 | sample_preview: true 18 | num_frames: 16 19 | width: 512 20 | height: 320 21 | num_inference_steps: 50 22 | guidance_scale: 9 23 | strength: 5 24 | dataset_types: 25 | - json 26 | shuffle: true 27 | validation_steps: 200 28 | mask_nodes: 0 # to define ,mask DDIM nodes out of {num_inference_steps} 29 | trainable_modules: 30 | - all 31 | - attn1 32 | - conv_in 33 | - temp_conv 34 | - motion 35 | not_trainable_modules: [] 36 | trainable_text_modules: null 37 | extra_unet_params: null 38 | extra_text_encoder_params: null 39 | train_batch_size: 4 40 | max_train_steps: 10000 41 | learning_rate: 5.0e-06 42 | scale_lr: false 43 | lr_scheduler: constant 44 | lr_warmup_steps: 0 45 | adam_beta1: 0.9 46 | adam_beta2: 0.999 47 | adam_weight_decay: 0 48 | adam_epsilon: 1.0e-08 49 | max_grad_norm: 1.0 50 | gradient_accumulation_steps: 1 51 | gradient_checkpointing: true 52 | text_encoder_gradient_checkpointing: false 53 | checkpointing_steps: 2000 54 | resume_from_checkpoint: null 55 | resume_step: null 56 | mixed_precision: fp16 57 | use_8bit_adam: false 58 | enable_xformers_memory_efficient_attention: false 59 | enable_torch_2_attn: true 60 | seed: 4444 61 | train_text_encoder: false 62 | use_offset_noise: false 63 | rescale_schedule: false 64 | offset_noise_strength: 0.1 65 | extend_dataset: false 66 | cache_latents: false 67 | cached_latent_dir: null 68 | lora_version: cloneofsimo 69 | save_lora_for_webui: true 70 | only_lora_for_webui: false 71 | lora_bias: none 72 | use_unet_lora: false 73 | use_text_lora: false 74 | unet_lora_modules: 75 | - UNet3DConditionModel 76 | text_encoder_lora_modules: 77 | - CLIPEncoderLayer 78 | save_pretrained_model: true 79 | lora_rank: 16 80 | lora_path: '' 81 | lora_unet_dropout: 0.1 82 | lora_text_dropout: 0.1 83 | logger_type: tensorboard 84 | motion_mask: true 85 | motion_strength: true 86 | kwargs: {} 87 | -------------------------------------------------------------------------------- /examples/animate-anything/output/latent/animate_anything_512_v1.02/demo.yaml: -------------------------------------------------------------------------------- 1 | pretrained_model_path: output/latent/animate_anything_512_v1.02 2 | output_dir: ./output/latent 3 | name: demo 4 | train_data: 5 | width: 512 6 | height: 512 7 | use_bucketing: false 8 | return_mask: true 9 | return_motion: true 10 | sample_start_idx: 1 11 | fps: 8 12 | n_sample_frames: 16 13 | json_path: /webvid/animation0.json 14 | train_timesteps_from: 1000 15 | 
validation_data: 16 | prompt: fireworks exploding in the sky 17 | prompt_image: /mnt/vepfs/zhuhongzhou/animate-anything/example/webvid_human/451534.png 18 | dataset_jsonl: /mnt/vepfs/zhuhongzhou/animate-anything/data.jsonl # To define 19 | sample_preview: true 20 | num_frames: 16 21 | width: 448 22 | height: 256 23 | num_inference_steps: 50 24 | guidance_scale: 9 25 | strength: 5 26 | dataset_types: 27 | - json 28 | shuffle: true 29 | validation_steps: 200 30 | mask_nodes: 0 # to define ,mask DDIM nodes out of {num_inference_steps} 31 | trainable_modules: 32 | - all 33 | - attn1 34 | - conv_in 35 | - temp_conv 36 | - motion 37 | not_trainable_modules: [] 38 | trainable_text_modules: null 39 | extra_unet_params: null 40 | extra_text_encoder_params: null 41 | train_batch_size: 4 42 | max_train_steps: 10000 43 | learning_rate: 5.0e-06 44 | scale_lr: false 45 | lr_scheduler: constant 46 | lr_warmup_steps: 0 47 | adam_beta1: 0.9 48 | adam_beta2: 0.999 49 | adam_weight_decay: 0 50 | adam_epsilon: 1.0e-08 51 | max_grad_norm: 1.0 52 | gradient_accumulation_steps: 1 53 | gradient_checkpointing: true 54 | text_encoder_gradient_checkpointing: false 55 | checkpointing_steps: 2000 56 | resume_from_checkpoint: null 57 | resume_step: null 58 | mixed_precision: fp16 59 | use_8bit_adam: false 60 | enable_xformers_memory_efficient_attention: false 61 | enable_torch_2_attn: true 62 | seed: 4 63 | train_text_encoder: false 64 | use_offset_noise: false 65 | rescale_schedule: false 66 | offset_noise_strength: 0.1 67 | extend_dataset: false 68 | cache_latents: false 69 | cached_latent_dir: null 70 | lora_version: cloneofsimo 71 | save_lora_for_webui: true 72 | only_lora_for_webui: false 73 | lora_bias: none 74 | use_unet_lora: false 75 | use_text_lora: false 76 | unet_lora_modules: 77 | - UNet3DConditionModel 78 | text_encoder_lora_modules: 79 | - CLIPEncoderLayer 80 | save_pretrained_model: true 81 | lora_rank: 16 82 | lora_path: '' 83 | lora_unet_dropout: 0.1 84 | lora_text_dropout: 0.1 85 | logger_type: tensorboard 86 | motion_mask: true 87 | motion_strength: true 88 | kwargs: {} 89 | -------------------------------------------------------------------------------- /examples/animate-anything/output/latent/animate_anything_512_v1.02/model_index.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "LatentToVideoPipeline", 3 | "_diffusers_version": "0.24.0", 4 | "_name_or_path": "output/latent/train_2023-12-15T12-02-48/checkpoint-2000", 5 | "scheduler": [ 6 | "diffusers", 7 | "DDIMScheduler" 8 | ], 9 | "text_encoder": [ 10 | "transformers", 11 | "CLIPTextModel" 12 | ], 13 | "tokenizer": [ 14 | "transformers", 15 | "CLIPTokenizer" 16 | ], 17 | "unet": [ 18 | "models.unet_3d_condition_mask", 19 | "UNet3DConditionModel" 20 | ], 21 | "vae": [ 22 | "diffusers", 23 | "AutoencoderKL" 24 | ] 25 | } 26 | -------------------------------------------------------------------------------- /examples/animate-anything/output/latent/animate_anything_512_v1.02/scheduler/scheduler_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "DDIMScheduler", 3 | "_diffusers_version": "0.24.0", 4 | "beta_end": 0.012, 5 | "beta_schedule": "scaled_linear", 6 | "beta_start": 0.00085, 7 | "clip_sample": false, 8 | "clip_sample_range": 1.0, 9 | "dynamic_thresholding_ratio": 0.995, 10 | "num_train_timesteps": 1000, 11 | "prediction_type": "epsilon", 12 | "rescale_betas_zero_snr": false, 13 | "sample_max_value": 1.0, 14 | 
"set_alpha_to_one": false, 15 | "skip_prk_steps": true, 16 | "steps_offset": 1, 17 | "thresholding": false, 18 | "timestep_spacing": "trailing", 19 | "trained_betas": null 20 | } 21 | -------------------------------------------------------------------------------- /examples/animate-anything/output/latent/animate_anything_512_v1.02/text_encoder/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "output/latent/train_2023-12-15T12-02-48/checkpoint-2000", 3 | "architectures": [ 4 | "CLIPTextModel" 5 | ], 6 | "attention_dropout": 0.0, 7 | "bos_token_id": 0, 8 | "dropout": 0.0, 9 | "eos_token_id": 2, 10 | "hidden_act": "gelu", 11 | "hidden_size": 1024, 12 | "initializer_factor": 1.0, 13 | "initializer_range": 0.02, 14 | "intermediate_size": 4096, 15 | "layer_norm_eps": 1e-05, 16 | "max_position_embeddings": 77, 17 | "model_type": "clip_text_model", 18 | "num_attention_heads": 16, 19 | "num_hidden_layers": 23, 20 | "pad_token_id": 1, 21 | "projection_dim": 512, 22 | "torch_dtype": "float32", 23 | "transformers_version": "4.34.1", 24 | "vocab_size": 49408 25 | } 26 | -------------------------------------------------------------------------------- /examples/animate-anything/output/latent/animate_anything_512_v1.02/text_encoder/model.txt: -------------------------------------------------------------------------------- 1 | put model.safetensors here -------------------------------------------------------------------------------- /examples/animate-anything/output/latent/animate_anything_512_v1.02/tokenizer/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "bos_token": { 3 | "content": "<|startoftext|>", 4 | "lstrip": false, 5 | "normalized": true, 6 | "rstrip": false, 7 | "single_word": false 8 | }, 9 | "eos_token": { 10 | "content": "<|endoftext|>", 11 | "lstrip": false, 12 | "normalized": true, 13 | "rstrip": false, 14 | "single_word": false 15 | }, 16 | "pad_token": { 17 | "content": "!", 18 | "lstrip": false, 19 | "normalized": false, 20 | "rstrip": false, 21 | "single_word": false 22 | }, 23 | "unk_token": { 24 | "content": "<|endoftext|>", 25 | "lstrip": false, 26 | "normalized": true, 27 | "rstrip": false, 28 | "single_word": false 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /examples/animate-anything/output/latent/animate_anything_512_v1.02/tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_prefix_space": false, 3 | "added_tokens_decoder": { 4 | "0": { 5 | "content": "!", 6 | "lstrip": false, 7 | "normalized": false, 8 | "rstrip": false, 9 | "single_word": false, 10 | "special": true 11 | }, 12 | "49406": { 13 | "content": "<|startoftext|>", 14 | "lstrip": false, 15 | "normalized": true, 16 | "rstrip": false, 17 | "single_word": false, 18 | "special": true 19 | }, 20 | "49407": { 21 | "content": "<|endoftext|>", 22 | "lstrip": false, 23 | "normalized": true, 24 | "rstrip": false, 25 | "single_word": false, 26 | "special": true 27 | } 28 | }, 29 | "bos_token": "<|startoftext|>", 30 | "clean_up_tokenization_spaces": true, 31 | "do_lower_case": true, 32 | "eos_token": "<|endoftext|>", 33 | "errors": "replace", 34 | "model_max_length": 77, 35 | "pad_token": "!", 36 | "tokenizer_class": "CLIPTokenizer", 37 | "unk_token": "<|endoftext|>" 38 | } 39 | -------------------------------------------------------------------------------- 
/examples/animate-anything/output/latent/animate_anything_512_v1.02/unet/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "UNet3DConditionModel", 3 | "_diffusers_version": "0.24.0", 4 | "_name_or_path": "output/latent/train_2023-12-15T12-02-48/checkpoint-2000", 5 | "act_fn": "silu", 6 | "attention_head_dim": 64, 7 | "block_out_channels": [ 8 | 320, 9 | 640, 10 | 1280, 11 | 1280 12 | ], 13 | "cross_attention_dim": 1024, 14 | "down_block_types": [ 15 | "CrossAttnDownBlock3D", 16 | "CrossAttnDownBlock3D", 17 | "CrossAttnDownBlock3D", 18 | "DownBlock3D" 19 | ], 20 | "downsample_padding": 1, 21 | "in_channels": 4, 22 | "layers_per_block": 2, 23 | "mid_block_scale_factor": 1, 24 | "motion_mask": true, 25 | "motion_strength": true, 26 | "norm_eps": 1e-05, 27 | "norm_num_groups": 32, 28 | "out_channels": 4, 29 | "sample_size": 32, 30 | "up_block_types": [ 31 | "UpBlock3D", 32 | "CrossAttnUpBlock3D", 33 | "CrossAttnUpBlock3D", 34 | "CrossAttnUpBlock3D" 35 | ] 36 | } 37 | -------------------------------------------------------------------------------- /examples/animate-anything/output/latent/animate_anything_512_v1.02/unet/unet.txt: -------------------------------------------------------------------------------- 1 | put diffusion_pytorch_model.safetensors here -------------------------------------------------------------------------------- /examples/animate-anything/output/latent/animate_anything_512_v1.02/vae/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "AutoencoderKL", 3 | "_diffusers_version": "0.24.0", 4 | "_name_or_path": "output/latent/train_2023-12-15T12-02-48/checkpoint-2000", 5 | "act_fn": "silu", 6 | "block_out_channels": [ 7 | 128, 8 | 256, 9 | 512, 10 | 512 11 | ], 12 | "down_block_types": [ 13 | "DownEncoderBlock2D", 14 | "DownEncoderBlock2D", 15 | "DownEncoderBlock2D", 16 | "DownEncoderBlock2D" 17 | ], 18 | "force_upcast": true, 19 | "in_channels": 3, 20 | "latent_channels": 4, 21 | "layers_per_block": 2, 22 | "norm_num_groups": 32, 23 | "out_channels": 3, 24 | "sample_size": 512, 25 | "scaling_factor": 0.18215, 26 | "up_block_types": [ 27 | "UpDecoderBlock2D", 28 | "UpDecoderBlock2D", 29 | "UpDecoderBlock2D", 30 | "UpDecoderBlock2D" 31 | ] 32 | } 33 | -------------------------------------------------------------------------------- /examples/animate-anything/output/latent/animate_anything_512_v1.02/vae/vae.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/animate-anything/output/latent/animate_anything_512_v1.02/vae/vae.txt -------------------------------------------------------------------------------- /examples/animate-anything/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/cond-image-leakage/aa7027bd2cd1dd2f6245c0b6bc566d68b638ca27/examples/animate-anything/utils/__init__.py -------------------------------------------------------------------------------- /examples/animate-anything/utils/bucketing.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | def min_res(size, min_size): return 192 if size < 192 else size 4 | 5 | def up_down_bucket(m_size, in_size, direction): 6 | if direction == 'down': return abs(int(m_size - in_size)) 7 | if 
direction == 'up': return abs(int(m_size + in_size)) 8 | 9 | def get_bucket_sizes(size, direction: str, min_size): 10 | multipliers = [64, 128] 11 | for i, m in enumerate(multipliers): 12 | res = up_down_bucket(m, size, direction) 13 | multipliers[i] = min_res(res, min_size=min_size) 14 | return multipliers 15 | 16 | def closest_bucket(m_size, size, direction, min_size): 17 | lst = get_bucket_sizes(m_size, direction, min_size) 18 | return lst[min(range(len(lst)), key=lambda i: abs(lst[i]-size))] 19 | 20 | def resolve_bucket(i,h,w): return (i / (h / w)) 21 | 22 | def sensible_buckets(m_width, m_height, w, h, min_size=192): 23 | if h > w: 24 | w = resolve_bucket(m_width, h, w) 25 | w = closest_bucket(m_width, w, 'down', min_size=min_size) 26 | return w, m_height 27 | if h < w: 28 | h = resolve_bucket(m_height, w, h) 29 | h = closest_bucket(m_height, h, 'down', min_size=min_size) 30 | return m_width, h 31 | 32 | return m_width, m_height -------------------------------------------------------------------------------- /examples/animate-anything/utils/scene_detect.py: -------------------------------------------------------------------------------- 1 | """ 2 | Split videos into scene clips with PySceneDetect; clips shorter than 2 s or longer than 60 s are discarded. 3 | https://github.com/Breakthrough/PySceneDetect 4 | conda activate XPretrain 5 | """ 6 | import os 7 | import argparse 8 | import logging 9 | from scenedetect import detect, AdaptiveDetector, split_video_ffmpeg, ContentDetector 10 | from concurrent.futures import ProcessPoolExecutor, as_completed 11 | import fcntl 12 | 13 | 14 | # file to record processed videos 15 | logfile_path = "./log/tmp_processed_videos_adaptive.txt" 16 | 17 | def process_video(video_path): 18 | # start/end times of all scenes found in the video 19 | print("detect scene", video_path) 20 | scene_list = detect(video_path, AdaptiveDetector(adaptive_threshold=0.1)) # 1.0 21 | # scene_list = detect(video_path, ContentDetector()) 22 | # keep only clips between 2.0 s and 60.0 s 23 | scene_list = list(filter(lambda scene : 60.0 >=(scene[1].get_seconds()-scene[0].get_seconds()) >= 2.0, scene_list)) 24 | num_scene = len(scene_list) 25 | 26 | logging.info('------scene_list-------') 27 | logging.info(len(scene_list)) 28 | 29 | # save clips into current dir 30 | split_video_ffmpeg(video_path, scene_list, show_progress=True) 31 | 32 | 33 | 34 | def process_and_log(video_path): 35 | process_video(video_path) 36 | with open(logfile_path, 'a') as logfile: 37 | logfile.write(os.path.basename(video_path) + '\n') 38 | logfile.flush() 39 | 40 | def setup_logging(): 41 | logging.basicConfig( 42 | level=logging.INFO, 43 | format='%(asctime)s - %(levelname)s - %(message)s', 44 | datefmt='%Y-%m-%d %H:%M:%S' 45 | ) 46 | 47 | 48 | def fun(folder_path, vid_format='mp4'): 49 | logging.info(f'supported video format: {vid_format}') 50 | 51 | file_list = os.listdir(folder_path) 52 | file_list = list(filter(lambda fn: fn.endswith(f'.{vid_format}'), file_list)) 53 | 54 | logging.info(f'total videos: {len(file_list)}') 55 | logging.info('output videos will be saved into current dir.') 56 | 57 | # read the list of files that have already been processed 58 | processed_files = set() 59 | if os.path.exists(logfile_path): 60 | with open(logfile_path, 'r') as logfile: 61 | processed_files = set(logfile.read().splitlines()) 62 | 63 | # processed_files = () 64 | # lock = threading.Lock() 65 | 66 | 67 | with ProcessPoolExecutor(max_workers=8) as executor: 68 | futures = [] 69 | 70 | for i, file_name in enumerate(file_list): 71 | if(i % 100 == 0): 72 | logging.info(f'done processing {i} videos') 73 | if(file_name in
processed_files): 74 | logging.info(f'{file_name} processed before, ignored...') 75 | continue 76 | if file_name.endswith(f".{vid_format}"): 77 | video_path = os.path.join(folder_path, file_name) 78 | logging.info(f'processing video {file_name}') 79 | # submit the video-processing task to the process pool and keep its Future 80 | future = executor.submit(process_and_log, video_path) 81 | futures.append(future) 82 | 83 | # avoid running out of memory by not submitting too many tasks at once 84 | if len(futures) >= 24: 85 | # wait for the pending tasks to complete 86 | for completed_future in as_completed(futures): 87 | completed_future.result() 88 | # clear the list of completed tasks 89 | futures.clear() 90 | 91 | # wait for the remaining tasks to complete 92 | for completed_future in as_completed(futures): 93 | completed_future.result() 94 | 95 | def main(): 96 | setup_logging() 97 | 98 | parser = argparse.ArgumentParser(description='Process a folder path.') 99 | 100 | parser.add_argument('folder_path', type=str, help='Path to the folder to be processed', default='/data_menghao/XPretrain/hdvila_100m/default_content_detector_clips_the_magic_key_from_bbdown') 101 | 102 | args = parser.parse_args() 103 | 104 | if not os.path.isdir(args.folder_path): 105 | logging.info(f"Error: The folder {args.folder_path} does not exist.") 106 | return 107 | 108 | 109 | fun(folder_path=args.folder_path, vid_format='mp4') 110 | fun(folder_path=args.folder_path, vid_format='mkv') 111 | fun(folder_path=args.folder_path, vid_format='avi') 112 | 113 | 114 | if __name__ == '__main__': 115 | main() --------------------------------------------------------------------------------
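For reference, scene_detect.py is a data-preprocessing helper driven from the command line. A hedged usage sketch follows; it assumes PySceneDetect is installed, the command is run from the examples/animate-anything directory, and a ./log/ directory exists for the progress file.

```python
# Command line (positional folder argument, as defined by the argparse block above):
#   python utils/scene_detect.py /path/to/raw_videos
# Equivalent one-off use of the same helpers on a single file:
from utils.scene_detect import setup_logging, process_and_log

setup_logging()
process_and_log("/path/to/raw_videos/example.mp4")  # writes scene clips to the current dir
```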