├── .gitignore ├── 1girl_sleeping.gif ├── LICENSE ├── README.md ├── __init__.py ├── assets ├── application │ ├── 05.gif │ ├── 25.gif │ ├── 34.gif │ ├── 35.gif │ ├── 36.gif │ ├── 60.gif │ ├── YwHJYWvv_dM.gif │ ├── YwHJYWvv_dM_input_end.png │ ├── YwHJYWvv_dM_input_start.png │ ├── gkxX0kb8mE8.gif │ ├── gkxX0kb8mE8_input_end.png │ ├── gkxX0kb8mE8_input_start.png │ ├── smile.gif │ ├── smile_end.png │ ├── smile_start.png │ ├── stone01.gif │ ├── stone01_end.png │ ├── stone01_start.png │ ├── storytellingvideo.gif │ ├── ypDLB52Ykk4.gif │ ├── ypDLB52Ykk4_input_end.png │ └── ypDLB52Ykk4_input_start.png ├── logo_long.png ├── logo_long_dark.png └── showcase │ ├── Upscaled_Aime_Tribolet_springtime_landscape_golden_hour_morning_pale_yel_e6946f8d-37c1-4ce8-bf62-6ba90d23bd93.gif │ ├── Upscaled_Aime_Tribolet_springtime_landscape_golden_hour_morning_pale_yel_e6946f8d-37c1-4ce8-bf62-6ba90d23bd93.mp4_00.png │ ├── Upscaled_Alex__State_Blonde_woman_riding_on_top_of_a_moving_washing_mach_c31acaa3-dd30-459f-a109-2d2eb4c00fe2.gif │ ├── Upscaled_Alex__State_Blonde_woman_riding_on_top_of_a_moving_washing_mach_c31acaa3-dd30-459f-a109-2d2eb4c00fe2.mp4_00.png │ ├── bike_chineseink.gif │ ├── bird000.gif │ ├── bird000.jpeg │ ├── bloom2.gif │ ├── dance1.gif │ ├── dance1.jpeg_00.png │ ├── explode0.gif │ ├── explode0.jpeg_00.png │ ├── firework03.gif │ ├── girl07.gif │ ├── girl3.gif │ ├── girl3.jpeg_00.png │ ├── guitar0.gif │ ├── guitar0.jpeg_00.png │ ├── lighthouse.gif │ ├── pour_honey.gif │ ├── robot01.gif │ ├── train_anime02.gif │ ├── walk0.gif │ └── walk0.png_00.png ├── configs ├── inference_1024_v1.0.yaml ├── inference_256_v1.0.yaml └── inference_512_v1.0.yaml ├── gradio_app.py ├── lvdm ├── basics.py ├── common.py ├── distributions.py ├── ema.py ├── models │ ├── autoencoder.py │ ├── ddpm3d.py │ ├── samplers │ │ ├── ddim.py │ │ └── ddim_multiplecond.py │ └── utils_diffusion.py └── modules │ ├── attention.py │ ├── encoders │ ├── condition.py │ └── resampler.py │ ├── networks │ ├── ae_modules.py │ └── openaimodel3d.py │ └── x_transformer.py ├── nodes.py ├── prompts ├── 256 │ ├── art.png │ ├── bear.png │ ├── boy.png │ ├── dance1.jpeg │ ├── fire_and_beach.jpg │ ├── girl2.jpeg │ ├── girl3.jpeg │ ├── guitar0.jpeg │ └── test_prompts.txt ├── 512 │ ├── bloom01.png │ ├── campfire.png │ ├── girl08.png │ ├── isometric.png │ ├── pour_honey.png │ ├── ship02.png │ ├── test_prompts.txt │ ├── zreal_boat.png │ └── zreal_penguin.png ├── 1024 │ ├── astronaut04.png │ ├── bike_chineseink.png │ ├── bloom01.png │ ├── firework03.png │ ├── girl07.png │ ├── pour_bear.png │ ├── robot01.png │ ├── test_prompts.txt │ └── zreal_penguin.png ├── 512_interp │ ├── smile_01.png │ ├── smile_02.png │ ├── stone01_01.png │ ├── stone01_02.png │ ├── test_prompts.txt │ ├── walk_01.png │ └── walk_02.png └── 512_loop │ ├── 24.png │ ├── 36.png │ ├── 40.png │ └── test_prompts.txt ├── requirements.txt ├── scripts ├── evaluation │ ├── ddp_wrapper.py │ ├── funcs.py │ └── inference.py ├── gradio │ ├── i2v_test.py │ └── i2v_test_application.py ├── run.sh ├── run_application.sh └── run_mp.sh ├── utils └── utils.py ├── video.gif ├── wf-basic.png ├── wf-interp.json ├── wf-interp.png └── workflow.json /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | *pyc 3 | .vscode 4 | __pycache__ 5 | *.egg-info 6 | 7 | checkpoints 8 | results 9 | backup -------------------------------------------------------------------------------- /1girl_sleeping.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/1girl_sleeping.gif --------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ComfyUI DynamiCrafter 2 | 3 | [DynamiCrafter](https://github.com/Doubiiu/DynamiCrafter) 4 | 5 | Note: xformers must be installed: `pip install xformers` 6 | 7 | Models are downloaded automatically by default; users in China can fetch them through the mirror at https://hf-mirror.com/ : 8 | 9 | ``` 10 | pip install -U huggingface_hub 11 | export HF_ENDPOINT=https://hf-mirror.com  # on Windows (PowerShell): $env:HF_ENDPOINT = "https://hf-mirror.com" 12 | huggingface-cli download --resume-download laion/CLIP-ViT-H-14-laion2B-s32B-b79K 13 | ``` 14 | 15 | Download model.ckpt to models/checkpoints/dynamicrafter_512_interp_v1/model.ckpt 16 | 17 | https://hf-mirror.com/Doubiiu/DynamiCrafter_512_Interp 18 | 19 | Download model.ckpt to models/checkpoints/dynamicrafter_1024_v1/model.ckpt 20 | 21 | https://hf-mirror.com/Doubiiu/DynamiCrafter_1024 22 | (A scripted version of this download step is sketched further below.) 23 | 24 | ## Examples 25 | 26 | ### base workflow 27 | 28 | 29 | 30 | https://github.com/chaojie/ComfyUI-DynamiCrafter/blob/main/workflow.json 31 | 32 | prompt: 1girl 33 | 34 | 35 | 36 | prompt: 1girl sleeping 37 | 38 | 39 | 40 | Timings on an RTX 4090: 41 | 42 | a 16-frame video takes 3 minutes 43 | 44 | a 32-frame video takes 6 minutes (the frame length of 32 is set on the Loader node) 45 | 46 | 47 | 48 | ### interpolation workflow 49 | 50 | 51 | 52 | https://github.com/chaojie/ComfyUI-DynamiCrafter/blob/main/wf-interp.json
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from .nodes import NODE_CLASS_MAPPINGS 2 | 3 | __all__ = ['NODE_CLASS_MAPPINGS'] 4 | 5 |
--------------------------------------------------------------------------------
/assets/application/05.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/05.gif --------------------------------------------------------------------------------
/assets/application/25.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/25.gif --------------------------------------------------------------------------------
/assets/application/34.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/34.gif --------------------------------------------------------------------------------
/assets/application/35.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/35.gif --------------------------------------------------------------------------------
/assets/application/36.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/36.gif --------------------------------------------------------------------------------
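The README above gives manual download instructions for the two DynamiCrafter checkpoints. As an illustrative sketch only (this helper script is not part of the repository), the same step can be scripted with the `huggingface_hub` Python API; the repo ids and target folders are the ones named in the README, and the `local_dir` argument of `hf_hub_download` is assumed to be available in the installed version:

```python
import os

# Optional: route downloads through the hf-mirror endpoint mentioned in the
# README (set this before importing huggingface_hub so the endpoint is picked up).
# os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

from huggingface_hub import hf_hub_download

# Checkpoint repos and the ComfyUI folders the README says they belong in.
CHECKPOINTS = {
    "Doubiiu/DynamiCrafter_512_Interp": "models/checkpoints/dynamicrafter_512_interp_v1",
    "Doubiiu/DynamiCrafter_1024": "models/checkpoints/dynamicrafter_1024_v1",
}

for repo_id, local_dir in CHECKPOINTS.items():
    os.makedirs(local_dir, exist_ok=True)
    # Places model.ckpt directly under local_dir.
    path = hf_hub_download(repo_id=repo_id, filename="model.ckpt", local_dir=local_dir)
    print(f"downloaded {repo_id} -> {path}")
```

This assumes the script is run from the ComfyUI root so the relative models/checkpoints paths resolve to the locations the loader expects.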
/assets/application/60.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/60.gif -------------------------------------------------------------------------------- /assets/application/YwHJYWvv_dM.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/YwHJYWvv_dM.gif -------------------------------------------------------------------------------- /assets/application/YwHJYWvv_dM_input_end.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/YwHJYWvv_dM_input_end.png -------------------------------------------------------------------------------- /assets/application/YwHJYWvv_dM_input_start.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/YwHJYWvv_dM_input_start.png -------------------------------------------------------------------------------- /assets/application/gkxX0kb8mE8.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/gkxX0kb8mE8.gif -------------------------------------------------------------------------------- /assets/application/gkxX0kb8mE8_input_end.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/gkxX0kb8mE8_input_end.png -------------------------------------------------------------------------------- /assets/application/gkxX0kb8mE8_input_start.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/gkxX0kb8mE8_input_start.png -------------------------------------------------------------------------------- /assets/application/smile.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/smile.gif -------------------------------------------------------------------------------- /assets/application/smile_end.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/smile_end.png -------------------------------------------------------------------------------- /assets/application/smile_start.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/smile_start.png -------------------------------------------------------------------------------- /assets/application/stone01.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/stone01.gif -------------------------------------------------------------------------------- /assets/application/stone01_end.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/stone01_end.png -------------------------------------------------------------------------------- /assets/application/stone01_start.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/stone01_start.png -------------------------------------------------------------------------------- /assets/application/storytellingvideo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/storytellingvideo.gif -------------------------------------------------------------------------------- /assets/application/ypDLB52Ykk4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/ypDLB52Ykk4.gif -------------------------------------------------------------------------------- /assets/application/ypDLB52Ykk4_input_end.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/ypDLB52Ykk4_input_end.png -------------------------------------------------------------------------------- /assets/application/ypDLB52Ykk4_input_start.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/ypDLB52Ykk4_input_start.png -------------------------------------------------------------------------------- /assets/logo_long.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/logo_long.png -------------------------------------------------------------------------------- /assets/logo_long_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/logo_long_dark.png -------------------------------------------------------------------------------- /assets/showcase/Upscaled_Aime_Tribolet_springtime_landscape_golden_hour_morning_pale_yel_e6946f8d-37c1-4ce8-bf62-6ba90d23bd93.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/Upscaled_Aime_Tribolet_springtime_landscape_golden_hour_morning_pale_yel_e6946f8d-37c1-4ce8-bf62-6ba90d23bd93.gif -------------------------------------------------------------------------------- 
/assets/showcase/Upscaled_Aime_Tribolet_springtime_landscape_golden_hour_morning_pale_yel_e6946f8d-37c1-4ce8-bf62-6ba90d23bd93.mp4_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/Upscaled_Aime_Tribolet_springtime_landscape_golden_hour_morning_pale_yel_e6946f8d-37c1-4ce8-bf62-6ba90d23bd93.mp4_00.png -------------------------------------------------------------------------------- /assets/showcase/Upscaled_Alex__State_Blonde_woman_riding_on_top_of_a_moving_washing_mach_c31acaa3-dd30-459f-a109-2d2eb4c00fe2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/Upscaled_Alex__State_Blonde_woman_riding_on_top_of_a_moving_washing_mach_c31acaa3-dd30-459f-a109-2d2eb4c00fe2.gif -------------------------------------------------------------------------------- /assets/showcase/Upscaled_Alex__State_Blonde_woman_riding_on_top_of_a_moving_washing_mach_c31acaa3-dd30-459f-a109-2d2eb4c00fe2.mp4_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/Upscaled_Alex__State_Blonde_woman_riding_on_top_of_a_moving_washing_mach_c31acaa3-dd30-459f-a109-2d2eb4c00fe2.mp4_00.png -------------------------------------------------------------------------------- /assets/showcase/bike_chineseink.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/bike_chineseink.gif -------------------------------------------------------------------------------- /assets/showcase/bird000.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/bird000.gif -------------------------------------------------------------------------------- /assets/showcase/bird000.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/bird000.jpeg -------------------------------------------------------------------------------- /assets/showcase/bloom2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/bloom2.gif -------------------------------------------------------------------------------- /assets/showcase/dance1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/dance1.gif -------------------------------------------------------------------------------- /assets/showcase/dance1.jpeg_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/dance1.jpeg_00.png 
-------------------------------------------------------------------------------- /assets/showcase/explode0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/explode0.gif -------------------------------------------------------------------------------- /assets/showcase/explode0.jpeg_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/explode0.jpeg_00.png -------------------------------------------------------------------------------- /assets/showcase/firework03.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/firework03.gif -------------------------------------------------------------------------------- /assets/showcase/girl07.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/girl07.gif -------------------------------------------------------------------------------- /assets/showcase/girl3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/girl3.gif -------------------------------------------------------------------------------- /assets/showcase/girl3.jpeg_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/girl3.jpeg_00.png -------------------------------------------------------------------------------- /assets/showcase/guitar0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/guitar0.gif -------------------------------------------------------------------------------- /assets/showcase/guitar0.jpeg_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/guitar0.jpeg_00.png -------------------------------------------------------------------------------- /assets/showcase/lighthouse.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/lighthouse.gif -------------------------------------------------------------------------------- /assets/showcase/pour_honey.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/pour_honey.gif -------------------------------------------------------------------------------- /assets/showcase/robot01.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/robot01.gif -------------------------------------------------------------------------------- /assets/showcase/train_anime02.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/train_anime02.gif -------------------------------------------------------------------------------- /assets/showcase/walk0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/walk0.gif -------------------------------------------------------------------------------- /assets/showcase/walk0.png_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/walk0.png_00.png -------------------------------------------------------------------------------- /configs/inference_1024_v1.0.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.models.ddpm3d.LatentVisualDiffusion 3 | params: 4 | rescale_betas_zero_snr: True 5 | parameterization: "v" 6 | linear_start: 0.00085 7 | linear_end: 0.012 8 | num_timesteps_cond: 1 9 | timesteps: 1000 10 | first_stage_key: video 11 | cond_stage_key: caption 12 | cond_stage_trainable: False 13 | conditioning_key: hybrid 14 | image_size: [72, 128] 15 | channels: 4 16 | scale_by_std: False 17 | scale_factor: 0.18215 18 | use_ema: False 19 | uncond_type: 'empty_seq' 20 | use_dynamic_rescale: true 21 | base_scale: 0.3 22 | fps_condition_type: 'fps' 23 | perframe_ae: True 24 | unet_config: 25 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.modules.networks.openaimodel3d.UNetModel 26 | params: 27 | in_channels: 8 28 | out_channels: 4 29 | model_channels: 320 30 | attention_resolutions: 31 | - 4 32 | - 2 33 | - 1 34 | num_res_blocks: 2 35 | channel_mult: 36 | - 1 37 | - 2 38 | - 4 39 | - 4 40 | dropout: 0.1 41 | num_head_channels: 64 42 | transformer_depth: 1 43 | context_dim: 1024 44 | use_linear: true 45 | use_checkpoint: True 46 | temporal_conv: True 47 | temporal_attention: True 48 | temporal_selfatt_only: true 49 | use_relative_position: false 50 | use_causal_attention: False 51 | temporal_length: 16 52 | addition_attention: true 53 | image_cross_attention: true 54 | default_fs: 10 55 | fs_condition: true 56 | 57 | first_stage_config: 58 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.models.autoencoder.AutoencoderKL 59 | params: 60 | embed_dim: 4 61 | monitor: val/rec_loss 62 | ddconfig: 63 | double_z: True 64 | z_channels: 4 65 | resolution: 256 66 | in_channels: 3 67 | out_ch: 3 68 | ch: 128 69 | ch_mult: 70 | - 1 71 | - 2 72 | - 4 73 | - 4 74 | num_res_blocks: 2 75 | attn_resolutions: [] 76 | dropout: 0.0 77 | lossconfig: 78 | target: torch.nn.Identity 79 | 80 | cond_stage_config: 81 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder 82 | params: 83 | freeze: true 84 | layer: "penultimate" 85 | 86 | img_cond_stage_config: 87 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2 88 | params: 89 | freeze: true 90 | 91 | 
image_proj_stage_config: 92 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.modules.encoders.resampler.Resampler 93 | params: 94 | dim: 1024 95 | depth: 4 96 | dim_head: 64 97 | heads: 12 98 | num_queries: 16 99 | embedding_dim: 1280 100 | output_dim: 1024 101 | ff_mult: 4 102 | video_length: 16 103 | 104 | -------------------------------------------------------------------------------- /configs/inference_256_v1.0.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.models.ddpm3d.LatentVisualDiffusion 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.012 6 | num_timesteps_cond: 1 7 | timesteps: 1000 8 | first_stage_key: video 9 | cond_stage_key: caption 10 | cond_stage_trainable: False 11 | conditioning_key: hybrid 12 | image_size: [32, 32] 13 | channels: 4 14 | scale_by_std: False 15 | scale_factor: 0.18215 16 | use_ema: False 17 | uncond_type: 'empty_seq' 18 | unet_config: 19 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.modules.networks.openaimodel3d.UNetModel 20 | params: 21 | in_channels: 8 22 | out_channels: 4 23 | model_channels: 320 24 | attention_resolutions: 25 | - 4 26 | - 2 27 | - 1 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 4 33 | - 4 34 | dropout: 0.1 35 | num_head_channels: 64 36 | transformer_depth: 1 37 | context_dim: 1024 38 | use_linear: true 39 | use_checkpoint: True 40 | temporal_conv: True 41 | temporal_attention: True 42 | temporal_selfatt_only: true 43 | use_relative_position: false 44 | use_causal_attention: False 45 | temporal_length: 16 46 | addition_attention: true 47 | image_cross_attention: true 48 | image_cross_attention_scale_learnable: true 49 | default_fs: 3 50 | fs_condition: true 51 | 52 | first_stage_config: 53 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.models.autoencoder.AutoencoderKL 54 | params: 55 | embed_dim: 4 56 | monitor: val/rec_loss 57 | ddconfig: 58 | double_z: True 59 | z_channels: 4 60 | resolution: 256 61 | in_channels: 3 62 | out_ch: 3 63 | ch: 128 64 | ch_mult: 65 | - 1 66 | - 2 67 | - 4 68 | - 4 69 | num_res_blocks: 2 70 | attn_resolutions: [] 71 | dropout: 0.0 72 | lossconfig: 73 | target: torch.nn.Identity 74 | 75 | cond_stage_config: 76 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder 77 | params: 78 | freeze: true 79 | layer: "penultimate" 80 | 81 | img_cond_stage_config: 82 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2 83 | params: 84 | freeze: true 85 | 86 | image_proj_stage_config: 87 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.modules.encoders.resampler.Resampler 88 | params: 89 | dim: 1024 90 | depth: 4 91 | dim_head: 64 92 | heads: 12 93 | num_queries: 16 94 | embedding_dim: 1280 95 | output_dim: 1024 96 | ff_mult: 4 97 | video_length: 16 98 | 99 | -------------------------------------------------------------------------------- /configs/inference_512_v1.0.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.models.ddpm3d.LatentVisualDiffusion 3 | params: 4 | rescale_betas_zero_snr: True 5 | parameterization: "v" 6 | linear_start: 0.00085 7 | linear_end: 0.012 8 | num_timesteps_cond: 1 9 | timesteps: 1000 10 | first_stage_key: video 11 | cond_stage_key: caption 12 | cond_stage_trainable: False 13 | conditioning_key: hybrid 14 | image_size: [40, 64] 15 | channels: 4 16 | 
scale_by_std: False 17 | scale_factor: 0.18215 18 | use_ema: False 19 | uncond_type: 'empty_seq' 20 | use_dynamic_rescale: true 21 | base_scale: 0.7 22 | fps_condition_type: 'fps' 23 | perframe_ae: True 24 | unet_config: 25 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.modules.networks.openaimodel3d.UNetModel 26 | params: 27 | in_channels: 8 28 | out_channels: 4 29 | model_channels: 320 30 | attention_resolutions: 31 | - 4 32 | - 2 33 | - 1 34 | num_res_blocks: 2 35 | channel_mult: 36 | - 1 37 | - 2 38 | - 4 39 | - 4 40 | dropout: 0.1 41 | num_head_channels: 64 42 | transformer_depth: 1 43 | context_dim: 1024 44 | use_linear: true 45 | use_checkpoint: True 46 | temporal_conv: True 47 | temporal_attention: True 48 | temporal_selfatt_only: true 49 | use_relative_position: false 50 | use_causal_attention: False 51 | temporal_length: 16 52 | addition_attention: true 53 | image_cross_attention: true 54 | default_fs: 24 55 | fs_condition: true 56 | 57 | first_stage_config: 58 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.models.autoencoder.AutoencoderKL 59 | params: 60 | embed_dim: 4 61 | monitor: val/rec_loss 62 | ddconfig: 63 | double_z: True 64 | z_channels: 4 65 | resolution: 256 66 | in_channels: 3 67 | out_ch: 3 68 | ch: 128 69 | ch_mult: 70 | - 1 71 | - 2 72 | - 4 73 | - 4 74 | num_res_blocks: 2 75 | attn_resolutions: [] 76 | dropout: 0.0 77 | lossconfig: 78 | target: torch.nn.Identity 79 | 80 | cond_stage_config: 81 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder 82 | params: 83 | freeze: true 84 | layer: "penultimate" 85 | 86 | img_cond_stage_config: 87 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2 88 | params: 89 | freeze: true 90 | 91 | image_proj_stage_config: 92 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.modules.encoders.resampler.Resampler 93 | params: 94 | dim: 1024 95 | depth: 4 96 | dim_head: 64 97 | heads: 12 98 | num_queries: 16 99 | embedding_dim: 1280 100 | output_dim: 1024 101 | ff_mult: 4 102 | video_length: 16 103 | 104 | -------------------------------------------------------------------------------- /gradio_app.py: -------------------------------------------------------------------------------- 1 | import os, argparse 2 | import sys 3 | import gradio as gr 4 | from scripts.gradio.i2v_test import Image2Video 5 | sys.path.insert(1, os.path.join(sys.path[0], 'lvdm')) 6 | 7 | i2v_examples_1024 = [ 8 | ['prompts/1024/astronaut04.png', 'a man in an astronaut suit playing a guitar', 50, 7.5, 1.0, 6, 123], 9 | ['prompts/1024/bloom01.png', 'time-lapse of a blooming flower with leaves and a stem', 50, 7.5, 1.0, 10, 123], 10 | ['prompts/1024/girl07.png', 'a beautiful woman with long hair and a dress blowing in the wind', 50, 7.5, 1.0, 10, 123], 11 | ['prompts/1024/pour_bear.png', 'pouring beer into a glass of ice and beer', 50, 7.5, 1.0, 10, 123], 12 | ['prompts/1024/robot01.png', 'a robot is walking through a destroyed city', 50, 7.5, 1.0, 10, 123], 13 | ['prompts/1024/firework03.png', 'fireworks display', 50, 7.5, 1.0, 10, 123], 14 | ] 15 | 16 | i2v_examples_512 = [ 17 | ['prompts/512/bloom01.png', 'time-lapse of a blooming flower with leaves and a stem', 50, 7.5, 1.0, 24, 123], 18 | ['prompts/512/campfire.png', 'a bonfire is lit in the middle of a field', 50, 7.5, 1.0, 24, 123], 19 | ['prompts/512/isometric.png', 'rotating view, small house', 50, 7.5, 1.0, 24, 123], 20 | ['prompts/512/girl08.png', 'a woman looking out in the rain', 50, 7.5, 1.0, 24, 
1234], 21 | ['prompts/512/ship02.png', 'a sailboat sailing in rough seas with a dramatic sunset', 50, 7.5, 1.0, 24, 123], 22 | ['prompts/512/zreal_penguin.png', 'a group of penguins walking on a beach', 50, 7.5, 1.0, 20, 123], 23 | ] 24 | 25 | i2v_examples_256 = [ 26 | ['prompts/256/art.png', 'man fishing in a boat at sunset', 50, 7.5, 1.0, 3, 234], 27 | ['prompts/256/boy.png', 'boy walking on the street', 50, 7.5, 1.0, 3, 125], 28 | ['prompts/256/dance1.jpeg', 'two people dancing', 50, 7.5, 1.0, 3, 116], 29 | ['prompts/256/fire_and_beach.jpg', 'a campfire on the beach and the ocean waves in the background', 50, 7.5, 1.0, 3, 111], 30 | ['prompts/256/girl3.jpeg', 'girl talking and blinking', 50, 7.5, 1.0, 3, 111], 31 | ['prompts/256/guitar0.jpeg', 'bear playing guitar happily, snowing', 50, 7.5, 1.0, 3, 122] 32 | ] 33 | 34 | 35 | def dynamicrafter_demo(result_dir='./tmp/', res=1024): 36 | if res == 1024: 37 | resolution = '576_1024' 38 | css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024px; max-height:576px}""" 39 | elif res == 512: 40 | resolution = '320_512' 41 | css = """#input_img {max-width: 512px !important} #output_vid {max-width: 512px; max-height: 320px}""" 42 | elif res == 256: 43 | resolution = '256_256' 44 | css = """#input_img {max-width: 256px !important} #output_vid {max-width: 256px; max-height: 256px}""" 45 | else: 46 | raise NotImplementedError(f"Unsupported resolution: {res}") 47 | image2video = Image2Video(result_dir, resolution=resolution) 48 | with gr.Blocks(analytics_enabled=False, css=css) as dynamicrafter_iface: 49 | gr.Markdown("

DynamiCrafter: Animating Open-domain Images with Video Diffusion Priors \ 50 | \ 51 | Jinbo Xing, \ 52 | Menghan Xia, Yong Zhang, \ 53 | Haoxin Chen, Wangbo Yu, \ 54 | Hanyuan Liu, Xintao Wang, \ 55 | Tien-Tsin Wong, \ 56 | Ying Shan \ 57 | \ 58 | [ArXiv] \ 59 | [Project Page] \ 60 | [Github]
") 61 | 62 | #######image2video###### 63 | if res == 1024: 64 | with gr.Tab(label='Image2Video_576x1024'): 65 | with gr.Column(): 66 | with gr.Row(): 67 | with gr.Column(): 68 | with gr.Row(): 69 | i2v_input_image = gr.Image(label="Input Image",elem_id="input_img") 70 | with gr.Row(): 71 | i2v_input_text = gr.Text(label='Prompts') 72 | with gr.Row(): 73 | i2v_seed = gr.Slider(label='Random Seed', minimum=0, maximum=10000, step=1, value=123) 74 | i2v_eta = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label='ETA', value=1.0, elem_id="i2v_eta") 75 | i2v_cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='CFG Scale', value=7.5, elem_id="i2v_cfg_scale") 76 | with gr.Row(): 77 | i2v_steps = gr.Slider(minimum=1, maximum=60, step=1, elem_id="i2v_steps", label="Sampling steps", value=50) 78 | i2v_motion = gr.Slider(minimum=5, maximum=20, step=1, elem_id="i2v_motion", label="FPS", value=10) 79 | i2v_end_btn = gr.Button("Generate") 80 | # with gr.Tab(label='Result'): 81 | with gr.Row(): 82 | i2v_output_video = gr.Video(label="Generated Video",elem_id="output_vid",autoplay=True,show_share_button=True) 83 | 84 | gr.Examples(examples=i2v_examples_1024, 85 | inputs=[i2v_input_image, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_eta, i2v_motion, i2v_seed], 86 | outputs=[i2v_output_video], 87 | fn = image2video.get_image, 88 | cache_examples=False, 89 | ) 90 | i2v_end_btn.click(inputs=[i2v_input_image, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_eta, i2v_motion, i2v_seed], 91 | outputs=[i2v_output_video], 92 | fn = image2video.get_image 93 | ) 94 | elif res == 512: 95 | with gr.Tab(label='Image2Video_320x512'): 96 | with gr.Column(): 97 | with gr.Row(): 98 | with gr.Column(): 99 | with gr.Row(): 100 | i2v_input_image = gr.Image(label="Input Image",elem_id="input_img") 101 | with gr.Row(): 102 | i2v_input_text = gr.Text(label='Prompts') 103 | with gr.Row(): 104 | i2v_seed = gr.Slider(label='Random Seed', minimum=0, maximum=10000, step=1, value=123) 105 | i2v_eta = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label='ETA', value=1.0, elem_id="i2v_eta") 106 | i2v_cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='CFG Scale', value=7.5, elem_id="i2v_cfg_scale") 107 | with gr.Row(): 108 | i2v_steps = gr.Slider(minimum=1, maximum=60, step=1, elem_id="i2v_steps", label="Sampling steps", value=50) 109 | i2v_motion = gr.Slider(minimum=15, maximum=30, step=1, elem_id="i2v_motion", label="FPS", value=24) 110 | i2v_end_btn = gr.Button("Generate") 111 | # with gr.Tab(label='Result'): 112 | with gr.Row(): 113 | i2v_output_video = gr.Video(label="Generated Video",elem_id="output_vid",autoplay=True,show_share_button=True) 114 | 115 | gr.Examples(examples=i2v_examples_512, 116 | inputs=[i2v_input_image, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_eta, i2v_motion, i2v_seed], 117 | outputs=[i2v_output_video], 118 | fn = image2video.get_image, 119 | cache_examples=False, 120 | ) 121 | i2v_end_btn.click(inputs=[i2v_input_image, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_eta, i2v_motion, i2v_seed], 122 | outputs=[i2v_output_video], 123 | fn = image2video.get_image 124 | ) 125 | elif res == 256: 126 | with gr.Tab(label='Image2Video_256x256'): 127 | with gr.Column(): 128 | with gr.Row(): 129 | with gr.Column(): 130 | with gr.Row(): 131 | i2v_input_image = gr.Image(label="Input Image",elem_id="input_img") 132 | with gr.Row(): 133 | i2v_input_text = gr.Text(label='Prompts') 134 | with gr.Row(): 135 | i2v_seed = gr.Slider(label='Random Seed', minimum=0, maximum=10000, step=1, value=123) 136 | 
i2v_eta = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label='ETA', value=1.0, elem_id="i2v_eta") 137 | i2v_cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='CFG Scale', value=7.5, elem_id="i2v_cfg_scale") 138 | with gr.Row(): 139 | i2v_steps = gr.Slider(minimum=1, maximum=60, step=1, elem_id="i2v_steps", label="Sampling steps", value=50) 140 | i2v_motion = gr.Slider(minimum=1, maximum=4, step=1, elem_id="i2v_motion", label="Motion magnitude", value=3) 141 | i2v_end_btn = gr.Button("Generate") 142 | # with gr.Tab(label='Result'): 143 | with gr.Row(): 144 | i2v_output_video = gr.Video(label="Generated Video",elem_id="output_vid",autoplay=True,show_share_button=True) 145 | 146 | gr.Examples(examples=i2v_examples_256, 147 | inputs=[i2v_input_image, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_eta, i2v_motion, i2v_seed], 148 | outputs=[i2v_output_video], 149 | fn = image2video.get_image, 150 | cache_examples=False, 151 | ) 152 | i2v_end_btn.click(inputs=[i2v_input_image, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_eta, i2v_motion, i2v_seed], 153 | outputs=[i2v_output_video], 154 | fn = image2video.get_image 155 | ) 156 | 157 | return dynamicrafter_iface 158 | 159 | def get_parser(): 160 | parser = argparse.ArgumentParser() 161 | parser.add_argument("--res", type=int, default=1024, choices=[1024,512,256], help="select the model based on the required resolution") 162 | 163 | return parser 164 | 165 | if __name__ == "__main__": 166 | parser = get_parser() 167 | args = parser.parse_args() 168 | 169 | result_dir = os.path.join('./', 'results') 170 | dynamicrafter_iface = dynamicrafter_demo(result_dir, args.res) 171 | dynamicrafter_iface.queue(max_size=12) 172 | dynamicrafter_iface.launch(max_threads=1) 173 | # dynamicrafter_iface.launch(server_name='0.0.0.0', server_port=80, max_threads=1) -------------------------------------------------------------------------------- /lvdm/basics.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 9 | 10 | import torch.nn as nn 11 | from ..utils.utils import instantiate_from_config 12 | 13 | 14 | def disabled_train(self, mode=True): 15 | """Overwrite model.train with this function to make sure train/eval mode 16 | does not change anymore.""" 17 | return self 18 | 19 | def zero_module(module): 20 | """ 21 | Zero out the parameters of a module and return it. 22 | """ 23 | for p in module.parameters(): 24 | p.detach().zero_() 25 | return module 26 | 27 | def scale_module(module, scale): 28 | """ 29 | Scale the parameters of a module and return it. 30 | """ 31 | for p in module.parameters(): 32 | p.detach().mul_(scale) 33 | return module 34 | 35 | 36 | def conv_nd(dims, *args, **kwargs): 37 | """ 38 | Create a 1D, 2D, or 3D convolution module. 39 | """ 40 | if dims == 1: 41 | return nn.Conv1d(*args, **kwargs) 42 | elif dims == 2: 43 | return nn.Conv2d(*args, **kwargs) 44 | elif dims == 3: 45 | return nn.Conv3d(*args, **kwargs) 46 | raise ValueError(f"unsupported dimensions: {dims}") 47 | 48 | 49 | def linear(*args, **kwargs): 50 | """ 51 | Create a linear module. 
52 | """ 53 | return nn.Linear(*args, **kwargs) 54 | 55 | 56 | def avg_pool_nd(dims, *args, **kwargs): 57 | """ 58 | Create a 1D, 2D, or 3D average pooling module. 59 | """ 60 | if dims == 1: 61 | return nn.AvgPool1d(*args, **kwargs) 62 | elif dims == 2: 63 | return nn.AvgPool2d(*args, **kwargs) 64 | elif dims == 3: 65 | return nn.AvgPool3d(*args, **kwargs) 66 | raise ValueError(f"unsupported dimensions: {dims}") 67 | 68 | 69 | def nonlinearity(type='silu'): 70 | if type == 'silu': 71 | return nn.SiLU() 72 | elif type == 'leaky_relu': 73 | return nn.LeakyReLU() 74 | 75 | 76 | class GroupNormSpecific(nn.GroupNorm): 77 | def forward(self, x): 78 | return super().forward(x.float()).type(x.dtype) 79 | 80 | 81 | def normalization(channels, num_groups=32): 82 | """ 83 | Make a standard normalization layer. 84 | :param channels: number of input channels. 85 | :return: an nn.Module for normalization. 86 | """ 87 | return GroupNormSpecific(num_groups, channels) 88 | 89 | 90 | class HybridConditioner(nn.Module): 91 | 92 | def __init__(self, c_concat_config, c_crossattn_config): 93 | super().__init__() 94 | self.concat_conditioner = instantiate_from_config(c_concat_config) 95 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 96 | 97 | def forward(self, c_concat, c_crossattn): 98 | c_concat = self.concat_conditioner(c_concat) 99 | c_crossattn = self.crossattn_conditioner(c_crossattn) 100 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} -------------------------------------------------------------------------------- /lvdm/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | from inspect import isfunction 3 | import torch 4 | from torch import nn 5 | import torch.distributed as dist 6 | 7 | 8 | def gather_data(data, return_np=True): 9 | ''' gather data from multiple processes to one list ''' 10 | data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size())] 11 | dist.all_gather(data_list, data) # gather not supported with NCCL 12 | if return_np: 13 | data_list = [data.cpu().numpy() for data in data_list] 14 | return data_list 15 | 16 | def autocast(f): 17 | def do_autocast(*args, **kwargs): 18 | with torch.cuda.amp.autocast(enabled=True, 19 | dtype=torch.get_autocast_gpu_dtype(), 20 | cache_enabled=torch.is_autocast_cache_enabled()): 21 | return f(*args, **kwargs) 22 | return do_autocast 23 | 24 | 25 | def extract_into_tensor(a, t, x_shape): 26 | b, *_ = t.shape 27 | out = a.gather(-1, t) 28 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 29 | 30 | 31 | def noise_like(shape, device, repeat=False): 32 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 33 | noise = lambda: torch.randn(shape, device=device) 34 | return repeat_noise() if repeat else noise() 35 | 36 | 37 | def default(val, d): 38 | if exists(val): 39 | return val 40 | return d() if isfunction(d) else d 41 | 42 | def exists(val): 43 | return val is not None 44 | 45 | def identity(*args, **kwargs): 46 | return nn.Identity() 47 | 48 | def uniq(arr): 49 | return{el: True for el in arr}.keys() 50 | 51 | def mean_flat(tensor): 52 | """ 53 | Take the mean over all non-batch dimensions. 
54 | """ 55 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 56 | 57 | def ismap(x): 58 | if not isinstance(x, torch.Tensor): 59 | return False 60 | return (len(x.shape) == 4) and (x.shape[1] > 3) 61 | 62 | def isimage(x): 63 | if not isinstance(x,torch.Tensor): 64 | return False 65 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 66 | 67 | def max_neg_value(t): 68 | return -torch.finfo(t.dtype).max 69 | 70 | def shape_to_str(x): 71 | shape_str = "x".join([str(x) for x in x.shape]) 72 | return shape_str 73 | 74 | def init_(tensor): 75 | dim = tensor.shape[-1] 76 | std = 1 / math.sqrt(dim) 77 | tensor.uniform_(-std, std) 78 | return tensor 79 | 80 | ckpt = torch.utils.checkpoint.checkpoint 81 | def checkpoint(func, inputs, params, flag): 82 | """ 83 | Evaluate a function without caching intermediate activations, allowing for 84 | reduced memory at the expense of extra compute in the backward pass. 85 | :param func: the function to evaluate. 86 | :param inputs: the argument sequence to pass to `func`. 87 | :param params: a sequence of parameters `func` depends on but does not 88 | explicitly take as arguments. 89 | :param flag: if False, disable gradient checkpointing. 90 | """ 91 | if flag: 92 | return ckpt(func, *inputs, use_reentrant=False) 93 | else: 94 | return func(*inputs) -------------------------------------------------------------------------------- /lvdm/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self, noise=None): 36 | if noise is None: 37 | noise = torch.randn(self.mean.shape) 38 | 39 | x = self.mean + self.std * noise.to(device=self.parameters.device) 40 | return x 41 | 42 | def kl(self, other=None): 43 | if self.deterministic: 44 | return torch.Tensor([0.]) 45 | else: 46 | if other is None: 47 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 48 | + self.var - 1.0 - self.logvar, 49 | dim=[1, 2, 3]) 50 | else: 51 | return 0.5 * torch.sum( 52 | torch.pow(self.mean - other.mean, 2) / other.var 53 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 54 | dim=[1, 2, 3]) 55 | 56 | def nll(self, sample, dims=[1,2,3]): 57 | if self.deterministic: 58 | return torch.Tensor([0.]) 59 | logtwopi = np.log(2.0 * np.pi) 60 | return 0.5 * torch.sum( 61 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 62 | dim=dims) 63 | 64 | def mode(self): 65 | return self.mean 66 | 67 | 68 | def normal_kl(mean1, logvar1, mean2, logvar2): 69 | """ 70 | source: 
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 71 | Compute the KL divergence between two gaussians. 72 | Shapes are automatically broadcasted, so batches can be compared to 73 | scalars, among other use cases. 74 | """ 75 | tensor = None 76 | for obj in (mean1, logvar1, mean2, logvar2): 77 | if isinstance(obj, torch.Tensor): 78 | tensor = obj 79 | break 80 | assert tensor is not None, "at least one argument must be a Tensor" 81 | 82 | # Force variances to be Tensors. Broadcasting helps convert scalars to 83 | # Tensors, but it does not work for torch.exp(). 84 | logvar1, logvar2 = [ 85 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 86 | for x in (logvar1, logvar2) 87 | ] 88 | 89 | return 0.5 * ( 90 | -1.0 91 | + logvar2 92 | - logvar1 93 | + torch.exp(logvar1 - logvar2) 94 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 95 | ) -------------------------------------------------------------------------------- /lvdm/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. 
After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) -------------------------------------------------------------------------------- /lvdm/models/autoencoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from contextlib import contextmanager 3 | import torch 4 | import numpy as np 5 | from einops import rearrange 6 | import torch.nn.functional as F 7 | import pytorch_lightning as pl 8 | from ..modules.networks.ae_modules import Encoder, Decoder 9 | from ..distributions import DiagonalGaussianDistribution 10 | from ...utils.utils import instantiate_from_config 11 | 12 | 13 | class AutoencoderKL(pl.LightningModule): 14 | def __init__(self, 15 | ddconfig, 16 | lossconfig, 17 | embed_dim, 18 | ckpt_path=None, 19 | ignore_keys=[], 20 | image_key="image", 21 | colorize_nlabels=None, 22 | monitor=None, 23 | test=False, 24 | logdir=None, 25 | input_dim=4, 26 | test_args=None, 27 | ): 28 | super().__init__() 29 | self.image_key = image_key 30 | self.encoder = Encoder(**ddconfig) 31 | self.decoder = Decoder(**ddconfig) 32 | self.loss = instantiate_from_config(lossconfig) 33 | assert ddconfig["double_z"] 34 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) 35 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 36 | self.embed_dim = embed_dim 37 | self.input_dim = input_dim 38 | self.test = test 39 | self.test_args = test_args 40 | self.logdir = logdir 41 | if colorize_nlabels is not None: 42 | assert type(colorize_nlabels)==int 43 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 44 | if monitor is not None: 45 | self.monitor = monitor 46 | if ckpt_path is not None: 47 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 48 | if self.test: 49 | self.init_test() 50 | 51 | def init_test(self,): 52 | self.test = True 53 | save_dir = os.path.join(self.logdir, "test") 54 | if 'ckpt' in self.test_args: 55 | ckpt_name = os.path.basename(self.test_args.ckpt).split('.ckpt')[0] + f'_epoch{self._cur_epoch}' 56 | self.root = os.path.join(save_dir, ckpt_name) 57 | else: 58 | self.root = save_dir 59 | if 'test_subdir' in self.test_args: 60 | self.root = os.path.join(save_dir, self.test_args.test_subdir) 61 | 62 | self.root_zs = os.path.join(self.root, "zs") 63 | self.root_dec = os.path.join(self.root, "reconstructions") 64 | self.root_inputs = os.path.join(self.root, "inputs") 65 | os.makedirs(self.root, exist_ok=True) 66 | 67 | if self.test_args.save_z: 68 | os.makedirs(self.root_zs, exist_ok=True) 69 | if self.test_args.save_reconstruction: 70 | os.makedirs(self.root_dec, exist_ok=True) 71 | if self.test_args.save_input: 72 | os.makedirs(self.root_inputs, exist_ok=True) 73 | assert(self.test_args is not None) 74 | self.test_maximum = getattr(self.test_args, 'test_maximum', None) 75 | self.count = 0 76 | self.eval_metrics = {} 77 | self.decodes = [] 78 | self.save_decode_samples = 2048 79 | 80 | def init_from_ckpt(self, path, ignore_keys=list()): 81 | sd = torch.load(path, map_location="cpu") 82 | try: 83 | self._cur_epoch = sd['epoch'] 84 | sd = sd["state_dict"] 85 | except: 86 | self._cur_epoch = 'null' 87 | keys = list(sd.keys()) 88 | for k in keys: 89 | for ik in ignore_keys: 90 | if k.startswith(ik): 91 | 
print("Deleting key {} from state_dict.".format(k)) 92 | del sd[k] 93 | self.load_state_dict(sd, strict=False) 94 | # self.load_state_dict(sd, strict=True) 95 | print(f"Restored from {path}") 96 | 97 | def encode(self, x, **kwargs): 98 | 99 | h = self.encoder(x) 100 | moments = self.quant_conv(h) 101 | posterior = DiagonalGaussianDistribution(moments) 102 | return posterior 103 | 104 | def decode(self, z, **kwargs): 105 | z = self.post_quant_conv(z) 106 | dec = self.decoder(z) 107 | return dec 108 | 109 | def forward(self, input, sample_posterior=True): 110 | posterior = self.encode(input) 111 | if sample_posterior: 112 | z = posterior.sample() 113 | else: 114 | z = posterior.mode() 115 | dec = self.decode(z) 116 | return dec, posterior 117 | 118 | def get_input(self, batch, k): 119 | x = batch[k] 120 | if x.dim() == 5 and self.input_dim == 4: 121 | b,c,t,h,w = x.shape 122 | self.b = b 123 | self.t = t 124 | x = rearrange(x, 'b c t h w -> (b t) c h w') 125 | 126 | return x 127 | 128 | def training_step(self, batch, batch_idx, optimizer_idx): 129 | inputs = self.get_input(batch, self.image_key) 130 | reconstructions, posterior = self(inputs) 131 | 132 | if optimizer_idx == 0: 133 | # train encoder+decoder+logvar 134 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 135 | last_layer=self.get_last_layer(), split="train") 136 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 137 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) 138 | return aeloss 139 | 140 | if optimizer_idx == 1: 141 | # train the discriminator 142 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 143 | last_layer=self.get_last_layer(), split="train") 144 | 145 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 146 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) 147 | return discloss 148 | 149 | def validation_step(self, batch, batch_idx): 150 | inputs = self.get_input(batch, self.image_key) 151 | reconstructions, posterior = self(inputs) 152 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, 153 | last_layer=self.get_last_layer(), split="val") 154 | 155 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, 156 | last_layer=self.get_last_layer(), split="val") 157 | 158 | self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) 159 | self.log_dict(log_dict_ae) 160 | self.log_dict(log_dict_disc) 161 | return self.log_dict 162 | 163 | def configure_optimizers(self): 164 | lr = self.learning_rate 165 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ 166 | list(self.decoder.parameters())+ 167 | list(self.quant_conv.parameters())+ 168 | list(self.post_quant_conv.parameters()), 169 | lr=lr, betas=(0.5, 0.9)) 170 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 171 | lr=lr, betas=(0.5, 0.9)) 172 | return [opt_ae, opt_disc], [] 173 | 174 | def get_last_layer(self): 175 | return self.decoder.conv_out.weight 176 | 177 | @torch.no_grad() 178 | def log_images(self, batch, only_inputs=False, **kwargs): 179 | log = dict() 180 | x = self.get_input(batch, self.image_key) 181 | x = x.to(self.device) 182 | if not only_inputs: 183 | xrec, posterior = self(x) 184 | if x.shape[1] > 3: 185 | # colorize with random projection 186 | assert xrec.shape[1] > 3 187 | x = self.to_rgb(x) 
188 | xrec = self.to_rgb(xrec) 189 | log["samples"] = self.decode(torch.randn_like(posterior.sample())) 190 | log["reconstructions"] = xrec 191 | log["inputs"] = x 192 | return log 193 | 194 | def to_rgb(self, x): 195 | assert self.image_key == "segmentation" 196 | if not hasattr(self, "colorize"): 197 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 198 | x = F.conv2d(x, weight=self.colorize) 199 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 200 | return x 201 | 202 | class IdentityFirstStage(torch.nn.Module): 203 | def __init__(self, *args, vq_interface=False, **kwargs): 204 | self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff 205 | super().__init__() 206 | 207 | def encode(self, x, *args, **kwargs): 208 | return x 209 | 210 | def decode(self, x, *args, **kwargs): 211 | return x 212 | 213 | def quantize(self, x, *args, **kwargs): 214 | if self.vq_interface: 215 | return x, None, [None, None, None] 216 | return x 217 | 218 | def forward(self, x, *args, **kwargs): 219 | return x -------------------------------------------------------------------------------- /lvdm/models/samplers/ddim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | import torch 4 | from ..utils_diffusion import make_ddim_sampling_parameters, make_ddim_timesteps, rescale_noise_cfg 5 | from ...common import noise_like 6 | from ...common import extract_into_tensor 7 | import copy 8 | import comfy.utils 9 | 10 | class DDIMSampler(object): 11 | def __init__(self, model, schedule="linear", **kwargs): 12 | super().__init__() 13 | self.model = model 14 | self.ddpm_num_timesteps = model.num_timesteps 15 | self.schedule = schedule 16 | self.counter = 0 17 | 18 | def register_buffer(self, name, attr): 19 | if type(attr) == torch.Tensor: 20 | if attr.device != torch.device("cuda"): 21 | attr = attr.to(torch.device("cuda")) 22 | setattr(self, name, attr) 23 | 24 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 25 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 26 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 27 | alphas_cumprod = self.model.alphas_cumprod 28 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 29 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 30 | 31 | if self.model.use_dynamic_rescale: 32 | self.ddim_scale_arr = self.model.scale_arr[self.ddim_timesteps] 33 | self.ddim_scale_arr_prev = torch.cat([self.ddim_scale_arr[0:1], self.ddim_scale_arr[:-1]]) 34 | 35 | self.register_buffer('betas', to_torch(self.model.betas)) 36 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 37 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 38 | 39 | # calculations for diffusion q(x_t | x_{t-1}) and others 40 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 41 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 42 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 43 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 44 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. 
/ alphas_cumprod.cpu() - 1))) 45 | 46 | # ddim sampling parameters 47 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 48 | ddim_timesteps=self.ddim_timesteps, 49 | eta=ddim_eta,verbose=verbose) 50 | self.register_buffer('ddim_sigmas', ddim_sigmas) 51 | self.register_buffer('ddim_alphas', ddim_alphas) 52 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 53 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) 54 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 55 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 56 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 57 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 58 | 59 | @torch.no_grad() 60 | def sample(self, 61 | S, 62 | batch_size, 63 | shape, 64 | conditioning=None, 65 | callback=None, 66 | normals_sequence=None, 67 | img_callback=None, 68 | quantize_x0=False, 69 | eta=0., 70 | mask=None, 71 | x0=None, 72 | temperature=1., 73 | noise_dropout=0., 74 | score_corrector=None, 75 | corrector_kwargs=None, 76 | verbose=True, 77 | schedule_verbose=False, 78 | x_T=None, 79 | log_every_t=100, 80 | unconditional_guidance_scale=1., 81 | unconditional_conditioning=None, 82 | precision=None, 83 | fs=None, 84 | timestep_spacing='uniform', #uniform_trailing for starting from last timestep 85 | guidance_rescale=0.0, 86 | **kwargs 87 | ): 88 | 89 | # check condition bs 90 | if conditioning is not None: 91 | if isinstance(conditioning, dict): 92 | try: 93 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 94 | except: 95 | cbs = conditioning[list(conditioning.keys())[0]][0].shape[0] 96 | 97 | if cbs != batch_size: 98 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 99 | else: 100 | if conditioning.shape[0] != batch_size: 101 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 102 | 103 | self.make_schedule(ddim_num_steps=S, ddim_discretize=timestep_spacing, ddim_eta=eta, verbose=schedule_verbose) 104 | 105 | # make shape 106 | if len(shape) == 3: 107 | C, H, W = shape 108 | size = (batch_size, C, H, W) 109 | elif len(shape) == 4: 110 | C, T, H, W = shape 111 | size = (batch_size, C, T, H, W) 112 | 113 | samples, intermediates = self.ddim_sampling(conditioning, size, 114 | callback=callback, 115 | img_callback=img_callback, 116 | quantize_denoised=quantize_x0, 117 | mask=mask, x0=x0, 118 | ddim_use_original_steps=False, 119 | noise_dropout=noise_dropout, 120 | temperature=temperature, 121 | score_corrector=score_corrector, 122 | corrector_kwargs=corrector_kwargs, 123 | x_T=x_T, 124 | log_every_t=log_every_t, 125 | unconditional_guidance_scale=unconditional_guidance_scale, 126 | unconditional_conditioning=unconditional_conditioning, 127 | verbose=verbose, 128 | precision=precision, 129 | fs=fs, 130 | guidance_rescale=guidance_rescale, 131 | **kwargs) 132 | return samples, intermediates 133 | 134 | @torch.no_grad() 135 | def ddim_sampling(self, cond, shape, 136 | x_T=None, ddim_use_original_steps=False, 137 | callback=None, timesteps=None, quantize_denoised=False, 138 | mask=None, x0=None, img_callback=None, log_every_t=100, 139 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 140 | unconditional_guidance_scale=1., unconditional_conditioning=None, verbose=True,precision=None,fs=None,guidance_rescale=0.0, 141 | **kwargs): 142 | device = self.model.betas.device 143 | b = 
shape[0] 144 | if x_T is None: 145 | img = torch.randn(shape, device=device) 146 | else: 147 | img = x_T 148 | if precision is not None: 149 | if precision == 16: 150 | img = img.to(dtype=torch.float16) 151 | 152 | if timesteps is None: 153 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 154 | elif timesteps is not None and not ddim_use_original_steps: 155 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 156 | timesteps = self.ddim_timesteps[:subset_end] 157 | 158 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 159 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) 160 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 161 | if verbose: 162 | iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) 163 | else: 164 | iterator = time_range 165 | 166 | clean_cond = kwargs.pop("clean_cond", False) 167 | 168 | # cond_copy, unconditional_conditioning_copy = copy.deepcopy(cond), copy.deepcopy(unconditional_conditioning) 169 | pbar = comfy.utils.ProgressBar(total_steps) 170 | for i, step in enumerate(iterator): 171 | index = total_steps - i - 1 172 | ts = torch.full((b,), step, device=device, dtype=torch.long) 173 | 174 | ## use mask to blend noised original latent (img_orig) & new sampled latent (img) 175 | if mask is not None: 176 | assert x0 is not None 177 | if clean_cond: 178 | img_orig = x0 179 | else: 180 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 181 | img = img_orig * mask + (1. - mask) * img # keep original & modify use img 182 | 183 | 184 | 185 | 186 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 187 | quantize_denoised=quantize_denoised, temperature=temperature, 188 | noise_dropout=noise_dropout, score_corrector=score_corrector, 189 | corrector_kwargs=corrector_kwargs, 190 | unconditional_guidance_scale=unconditional_guidance_scale, 191 | unconditional_conditioning=unconditional_conditioning, 192 | mask=mask,x0=x0,fs=fs,guidance_rescale=guidance_rescale, 193 | **kwargs) 194 | 195 | 196 | img, pred_x0 = outs 197 | if callback: callback(i) 198 | if img_callback: img_callback(pred_x0, i) 199 | 200 | if index % log_every_t == 0 or index == total_steps - 1: 201 | intermediates['x_inter'].append(img) 202 | intermediates['pred_x0'].append(pred_x0) 203 | pbar.update(1) 204 | return img, intermediates 205 | 206 | @torch.no_grad() 207 | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 208 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 209 | unconditional_guidance_scale=1., unconditional_conditioning=None, 210 | uc_type=None, conditional_guidance_scale_temporal=None,mask=None,x0=None,guidance_rescale=0.0,**kwargs): 211 | b, *_, device = *x.shape, x.device 212 | if x.dim() == 5: 213 | is_video = True 214 | else: 215 | is_video = False 216 | 217 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 218 | model_output = self.model.apply_model(x, t, c, **kwargs) # unet denoiser 219 | else: 220 | ### do_classifier_free_guidance 221 | if isinstance(c, torch.Tensor) or isinstance(c, dict): 222 | e_t_cond = self.model.apply_model(x, t, c, **kwargs) 223 | e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning, **kwargs) 224 | else: 225 | raise NotImplementedError 226 | 227 | model_output = e_t_uncond + 
unconditional_guidance_scale * (e_t_cond - e_t_uncond) 228 | 229 | if guidance_rescale > 0.0: 230 | model_output = rescale_noise_cfg(model_output, e_t_cond, guidance_rescale=guidance_rescale) 231 | 232 | if self.model.parameterization == "v": 233 | e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) 234 | else: 235 | e_t = model_output 236 | 237 | if score_corrector is not None: 238 | assert self.model.parameterization == "eps", 'not implemented' 239 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 240 | 241 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 242 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 243 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 244 | # sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 245 | sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 246 | # select parameters corresponding to the currently considered timestep 247 | 248 | if is_video: 249 | size = (b, 1, 1, 1, 1) 250 | else: 251 | size = (b, 1, 1, 1) 252 | a_t = torch.full(size, alphas[index], device=device) 253 | a_prev = torch.full(size, alphas_prev[index], device=device) 254 | sigma_t = torch.full(size, sigmas[index], device=device) 255 | sqrt_one_minus_at = torch.full(size, sqrt_one_minus_alphas[index],device=device) 256 | 257 | # current prediction for x_0 258 | if self.model.parameterization != "v": 259 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 260 | else: 261 | pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) 262 | 263 | if self.model.use_dynamic_rescale: 264 | scale_t = torch.full(size, self.ddim_scale_arr[index], device=device) 265 | prev_scale_t = torch.full(size, self.ddim_scale_arr_prev[index], device=device) 266 | rescale = (prev_scale_t / scale_t) 267 | pred_x0 *= rescale 268 | 269 | if quantize_denoised: 270 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 271 | # direction pointing to x_t 272 | dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t 273 | 274 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 275 | if noise_dropout > 0.: 276 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 277 | 278 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 279 | 280 | return x_prev, pred_x0 281 | 282 | @torch.no_grad() 283 | def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, 284 | use_original_steps=False, callback=None): 285 | 286 | timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps 287 | timesteps = timesteps[:t_start] 288 | 289 | time_range = np.flip(timesteps) 290 | total_steps = timesteps.shape[0] 291 | print(f"Running DDIM Sampling with {total_steps} timesteps") 292 | 293 | iterator = tqdm(time_range, desc='Decoding image', total=total_steps) 294 | x_dec = x_latent 295 | for i, step in enumerate(iterator): 296 | index = total_steps - i - 1 297 | ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) 298 | x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, 299 | unconditional_guidance_scale=unconditional_guidance_scale, 300 | unconditional_conditioning=unconditional_conditioning) 301 | if callback: callback(i) 302 | return x_dec 303 | 304 | @torch.no_grad() 305 | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): 306 | # fast, but does not allow for exact reconstruction 307 | # t serves as an index to gather the correct alphas 308 | if use_original_steps: 309 | sqrt_alphas_cumprod = self.sqrt_alphas_cumprod 310 | sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod 311 | else: 312 | sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) 313 | sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas 314 | 315 | if noise is None: 316 | noise = torch.randn_like(x0) 317 | return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + 318 | extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) 319 | -------------------------------------------------------------------------------- /lvdm/models/samplers/ddim_multiplecond.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | import torch 4 | from ..utils_diffusion import make_ddim_sampling_parameters, make_ddim_timesteps, rescale_noise_cfg 5 | from ...common import noise_like 6 | from ...common import extract_into_tensor 7 | import copy 8 | 9 | 10 | class DDIMSampler(object): 11 | def __init__(self, model, schedule="linear", **kwargs): 12 | super().__init__() 13 | self.model = model 14 | self.ddpm_num_timesteps = model.num_timesteps 15 | self.schedule = schedule 16 | self.counter = 0 17 | 18 | def register_buffer(self, name, attr): 19 | if type(attr) == torch.Tensor: 20 | if attr.device != torch.device("cuda"): 21 | attr = attr.to(torch.device("cuda")) 22 | setattr(self, name, attr) 23 | 24 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 25 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 26 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 27 | alphas_cumprod = self.model.alphas_cumprod 28 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 29 | to_torch = lambda x: 
x.clone().detach().to(torch.float32).to(self.model.device) 30 | 31 | if self.model.use_dynamic_rescale: 32 | self.ddim_scale_arr = self.model.scale_arr[self.ddim_timesteps] 33 | self.ddim_scale_arr_prev = torch.cat([self.ddim_scale_arr[0:1], self.ddim_scale_arr[:-1]]) 34 | 35 | self.register_buffer('betas', to_torch(self.model.betas)) 36 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 37 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 38 | 39 | # calculations for diffusion q(x_t | x_{t-1}) and others 40 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 41 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 42 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 43 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 44 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 45 | 46 | # ddim sampling parameters 47 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 48 | ddim_timesteps=self.ddim_timesteps, 49 | eta=ddim_eta,verbose=verbose) 50 | self.register_buffer('ddim_sigmas', ddim_sigmas) 51 | self.register_buffer('ddim_alphas', ddim_alphas) 52 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 53 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) 54 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 55 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 56 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 57 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 58 | 59 | @torch.no_grad() 60 | def sample(self, 61 | S, 62 | batch_size, 63 | shape, 64 | conditioning=None, 65 | callback=None, 66 | normals_sequence=None, 67 | img_callback=None, 68 | quantize_x0=False, 69 | eta=0., 70 | mask=None, 71 | x0=None, 72 | temperature=1., 73 | noise_dropout=0., 74 | score_corrector=None, 75 | corrector_kwargs=None, 76 | verbose=True, 77 | schedule_verbose=False, 78 | x_T=None, 79 | log_every_t=100, 80 | unconditional_guidance_scale=1., 81 | unconditional_conditioning=None, 82 | precision=None, 83 | fs=None, 84 | timestep_spacing='uniform', #uniform_trailing for starting from last timestep 85 | guidance_rescale=0.0, 86 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
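# Editor's note: an illustrative, commented-out sketch of how this multi-condition
# sampler is typically invoked for a video latent; the step count, tensor sizes and
# conditioning names below are assumptions for documentation purposes and are not
# taken from this file.
#
#   sampler = DDIMSampler(model)
#   samples, _ = sampler.sample(
#       S=50,                                  # number of DDIM steps
#       batch_size=1,
#       shape=(4, 16, 40, 64),                 # (C, T, H, W) video latent
#       conditioning=cond,                     # dict of conditioning tensors
#       unconditional_conditioning=uc,
#       unconditional_guidance_scale=7.5,      # text CFG scale
#       cfg_img=None,                          # image CFG scale (defaults to the text scale)
#       fs=10,                                 # frame-stride / motion conditioning
#       timestep_spacing='uniform_trailing',
#       guidance_rescale=0.7,
#       unconditional_conditioning_img_nonetext=uc_img,  # read by p_sample_ddim below
#   )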
87 | **kwargs 88 | ): 89 | 90 | # check condition bs 91 | if conditioning is not None: 92 | if isinstance(conditioning, dict): 93 | try: 94 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 95 | except: 96 | cbs = conditioning[list(conditioning.keys())[0]][0].shape[0] 97 | 98 | if cbs != batch_size: 99 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 100 | else: 101 | if conditioning.shape[0] != batch_size: 102 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 103 | 104 | # print('==> timestep_spacing: ', timestep_spacing, guidance_rescale) 105 | self.make_schedule(ddim_num_steps=S, ddim_discretize=timestep_spacing, ddim_eta=eta, verbose=schedule_verbose) 106 | 107 | # make shape 108 | if len(shape) == 3: 109 | C, H, W = shape 110 | size = (batch_size, C, H, W) 111 | elif len(shape) == 4: 112 | C, T, H, W = shape 113 | size = (batch_size, C, T, H, W) 114 | # print(f'Data shape for DDIM sampling is {size}, eta {eta}') 115 | 116 | samples, intermediates = self.ddim_sampling(conditioning, size, 117 | callback=callback, 118 | img_callback=img_callback, 119 | quantize_denoised=quantize_x0, 120 | mask=mask, x0=x0, 121 | ddim_use_original_steps=False, 122 | noise_dropout=noise_dropout, 123 | temperature=temperature, 124 | score_corrector=score_corrector, 125 | corrector_kwargs=corrector_kwargs, 126 | x_T=x_T, 127 | log_every_t=log_every_t, 128 | unconditional_guidance_scale=unconditional_guidance_scale, 129 | unconditional_conditioning=unconditional_conditioning, 130 | verbose=verbose, 131 | precision=precision, 132 | fs=fs, 133 | guidance_rescale=guidance_rescale, 134 | **kwargs) 135 | return samples, intermediates 136 | 137 | @torch.no_grad() 138 | def ddim_sampling(self, cond, shape, 139 | x_T=None, ddim_use_original_steps=False, 140 | callback=None, timesteps=None, quantize_denoised=False, 141 | mask=None, x0=None, img_callback=None, log_every_t=100, 142 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 143 | unconditional_guidance_scale=1., unconditional_conditioning=None, verbose=True,precision=None,fs=None,guidance_rescale=0.0, 144 | **kwargs): 145 | device = self.model.betas.device 146 | b = shape[0] 147 | if x_T is None: 148 | img = torch.randn(shape, device=device) 149 | else: 150 | img = x_T 151 | if precision is not None: 152 | if precision == 16: 153 | img = img.to(dtype=torch.float16) 154 | 155 | 156 | if timesteps is None: 157 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 158 | elif timesteps is not None and not ddim_use_original_steps: 159 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 160 | timesteps = self.ddim_timesteps[:subset_end] 161 | 162 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 163 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) 164 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 165 | if verbose: 166 | iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) 167 | else: 168 | iterator = time_range 169 | 170 | clean_cond = kwargs.pop("clean_cond", False) 171 | 172 | # cond_copy, unconditional_conditioning_copy = copy.deepcopy(cond), copy.deepcopy(unconditional_conditioning) 173 | for i, step in enumerate(iterator): 174 | index = total_steps - i - 1 175 | ts = torch.full((b,), step, device=device, dtype=torch.long) 176 | 177 | ## use mask to blend noised original 
latent (img_orig) & new sampled latent (img) 178 | if mask is not None: 179 | assert x0 is not None 180 | if clean_cond: 181 | img_orig = x0 182 | else: 183 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 184 | img = img_orig * mask + (1. - mask) * img # keep original & modify use img 185 | 186 | 187 | 188 | 189 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 190 | quantize_denoised=quantize_denoised, temperature=temperature, 191 | noise_dropout=noise_dropout, score_corrector=score_corrector, 192 | corrector_kwargs=corrector_kwargs, 193 | unconditional_guidance_scale=unconditional_guidance_scale, 194 | unconditional_conditioning=unconditional_conditioning, 195 | mask=mask,x0=x0,fs=fs,guidance_rescale=guidance_rescale, 196 | **kwargs) 197 | 198 | 199 | 200 | img, pred_x0 = outs 201 | if callback: callback(i) 202 | if img_callback: img_callback(pred_x0, i) 203 | 204 | if index % log_every_t == 0 or index == total_steps - 1: 205 | intermediates['x_inter'].append(img) 206 | intermediates['pred_x0'].append(pred_x0) 207 | 208 | return img, intermediates 209 | 210 | @torch.no_grad() 211 | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 212 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 213 | unconditional_guidance_scale=1., unconditional_conditioning=None, 214 | uc_type=None, cfg_img=None,mask=None,x0=None,guidance_rescale=0.0, **kwargs): 215 | b, *_, device = *x.shape, x.device 216 | if x.dim() == 5: 217 | is_video = True 218 | else: 219 | is_video = False 220 | if cfg_img is None: 221 | cfg_img = unconditional_guidance_scale 222 | 223 | unconditional_conditioning_img_nonetext = kwargs['unconditional_conditioning_img_nonetext'] 224 | 225 | 226 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 227 | model_output = self.model.apply_model(x, t, c, **kwargs) # unet denoiser 228 | else: 229 | ### with unconditional condition 230 | e_t_cond = self.model.apply_model(x, t, c, **kwargs) 231 | e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning, **kwargs) 232 | e_t_uncond_img = self.model.apply_model(x, t, unconditional_conditioning_img_nonetext, **kwargs) 233 | # text cfg 234 | model_output = e_t_uncond + cfg_img * (e_t_uncond_img - e_t_uncond) + unconditional_guidance_scale * (e_t_cond - e_t_uncond_img) 235 | if guidance_rescale > 0.0: 236 | model_output = rescale_noise_cfg(model_output, e_t_cond, guidance_rescale=guidance_rescale) 237 | 238 | if self.model.parameterization == "v": 239 | e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) 240 | else: 241 | e_t = model_output 242 | 243 | if score_corrector is not None: 244 | assert self.model.parameterization == "eps", 'not implemented' 245 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 246 | 247 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 248 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 249 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 250 | sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 251 | # select parameters corresponding to the currently considered timestep 252 | 253 | if is_video: 254 | size = (b, 1, 1, 1, 1) 255 | else: 256 | size = (b, 1, 1, 1) 257 | a_t = 
torch.full(size, alphas[index], device=device) 258 | a_prev = torch.full(size, alphas_prev[index], device=device) 259 | sigma_t = torch.full(size, sigmas[index], device=device) 260 | sqrt_one_minus_at = torch.full(size, sqrt_one_minus_alphas[index],device=device) 261 | 262 | # current prediction for x_0 263 | if self.model.parameterization != "v": 264 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 265 | else: 266 | pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) 267 | 268 | if self.model.use_dynamic_rescale: 269 | scale_t = torch.full(size, self.ddim_scale_arr[index], device=device) 270 | prev_scale_t = torch.full(size, self.ddim_scale_arr_prev[index], device=device) 271 | rescale = (prev_scale_t / scale_t) 272 | pred_x0 *= rescale 273 | 274 | if quantize_denoised: 275 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 276 | # direction pointing to x_t 277 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t 278 | 279 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 280 | if noise_dropout > 0.: 281 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 282 | 283 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 284 | 285 | return x_prev, pred_x0 286 | 287 | @torch.no_grad() 288 | def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, 289 | use_original_steps=False, callback=None): 290 | 291 | timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps 292 | timesteps = timesteps[:t_start] 293 | 294 | time_range = np.flip(timesteps) 295 | total_steps = timesteps.shape[0] 296 | print(f"Running DDIM Sampling with {total_steps} timesteps") 297 | 298 | iterator = tqdm(time_range, desc='Decoding image', total=total_steps) 299 | x_dec = x_latent 300 | for i, step in enumerate(iterator): 301 | index = total_steps - i - 1 302 | ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) 303 | x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, 304 | unconditional_guidance_scale=unconditional_guidance_scale, 305 | unconditional_conditioning=unconditional_conditioning) 306 | if callback: callback(i) 307 | return x_dec 308 | 309 | @torch.no_grad() 310 | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): 311 | # fast, but does not allow for exact reconstruction 312 | # t serves as an index to gather the correct alphas 313 | if use_original_steps: 314 | sqrt_alphas_cumprod = self.sqrt_alphas_cumprod 315 | sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod 316 | else: 317 | sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) 318 | sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas 319 | 320 | if noise is None: 321 | noise = torch.randn_like(x0) 322 | return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + 323 | extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) -------------------------------------------------------------------------------- /lvdm/models/utils_diffusion.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from einops import repeat 6 | 7 | 8 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 9 | """ 10 | Create sinusoidal timestep embeddings. 11 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 
12 | These may be fractional. 13 | :param dim: the dimension of the output. 14 | :param max_period: controls the minimum frequency of the embeddings. 15 | :return: an [N x dim] Tensor of positional embeddings. 16 | """ 17 | if not repeat_only: 18 | half = dim // 2 19 | freqs = torch.exp( 20 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 21 | ).to(device=timesteps.device) 22 | args = timesteps[:, None].float() * freqs[None] 23 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 24 | if dim % 2: 25 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 26 | else: 27 | embedding = repeat(timesteps, 'b -> b d', d=dim) 28 | return embedding 29 | 30 | 31 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 32 | if schedule == "linear": 33 | betas = ( 34 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 35 | ) 36 | 37 | elif schedule == "cosine": 38 | timesteps = ( 39 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 40 | ) 41 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 42 | alphas = torch.cos(alphas).pow(2) 43 | alphas = alphas / alphas[0] 44 | betas = 1 - alphas[1:] / alphas[:-1] 45 | betas = np.clip(betas, a_min=0, a_max=0.999) 46 | 47 | elif schedule == "sqrt_linear": 48 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 49 | elif schedule == "sqrt": 50 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 51 | else: 52 | raise ValueError(f"schedule '{schedule}' unknown.") 53 | return betas.numpy() 54 | 55 | 56 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 57 | if ddim_discr_method == 'uniform': 58 | c = num_ddpm_timesteps // num_ddim_timesteps 59 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 60 | steps_out = ddim_timesteps + 1 61 | elif ddim_discr_method == 'uniform_trailing': 62 | c = num_ddpm_timesteps / num_ddim_timesteps 63 | ddim_timesteps = np.flip(np.round(np.arange(num_ddpm_timesteps, 0, -c))).astype(np.int64) 64 | steps_out = ddim_timesteps - 1 65 | elif ddim_discr_method == 'quad': 66 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 67 | steps_out = ddim_timesteps + 1 68 | else: 69 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 70 | 71 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 72 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 73 | # steps_out = ddim_timesteps + 1 74 | if verbose: 75 | print(f'Selected timesteps for ddim sampler: {steps_out}') 76 | return steps_out 77 | 78 | 79 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 80 | # select alphas for computing the variance schedule 81 | # print(f'ddim_timesteps={ddim_timesteps}, len_alphacums={len(alphacums)}') 82 | alphas = alphacums[ddim_timesteps] 83 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 84 | 85 | # according the the formula provided in https://arxiv.org/abs/2010.02502 86 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 87 | if verbose: 88 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 89 | print(f'For the chosen value of eta, which is {eta}, ' 90 | 
f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 91 | return sigmas, alphas, alphas_prev 92 | 93 | 94 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 95 | """ 96 | Create a beta schedule that discretizes the given alpha_t_bar function, 97 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 98 | :param num_diffusion_timesteps: the number of betas to produce. 99 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 100 | produces the cumulative product of (1-beta) up to that 101 | part of the diffusion process. 102 | :param max_beta: the maximum beta to use; use values lower than 1 to 103 | prevent singularities. 104 | """ 105 | betas = [] 106 | for i in range(num_diffusion_timesteps): 107 | t1 = i / num_diffusion_timesteps 108 | t2 = (i + 1) / num_diffusion_timesteps 109 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 110 | return np.array(betas) 111 | 112 | def rescale_zero_terminal_snr(betas): 113 | """ 114 | Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) 115 | 116 | Args: 117 | betas (`numpy.ndarray`): 118 | the betas that the scheduler is being initialized with. 119 | 120 | Returns: 121 | `numpy.ndarray`: rescaled betas with zero terminal SNR 122 | """ 123 | # Convert betas to alphas_bar_sqrt 124 | alphas = 1.0 - betas 125 | alphas_cumprod = np.cumprod(alphas, axis=0) 126 | alphas_bar_sqrt = np.sqrt(alphas_cumprod) 127 | 128 | # Store old values. 129 | alphas_bar_sqrt_0 = alphas_bar_sqrt[0].copy() 130 | alphas_bar_sqrt_T = alphas_bar_sqrt[-1].copy() 131 | 132 | # Shift so the last timestep is zero. 133 | alphas_bar_sqrt -= alphas_bar_sqrt_T 134 | 135 | # Scale so the first timestep is back to the old value. 136 | alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) 137 | 138 | # Convert alphas_bar_sqrt to betas 139 | alphas_bar = alphas_bar_sqrt**2 # Revert sqrt 140 | alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod 141 | alphas = np.concatenate([alphas_bar[0:1], alphas]) 142 | betas = 1 - alphas 143 | 144 | return betas 145 | 146 | 147 | def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): 148 | """ 149 | Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and 150 | Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
See Section 3.4 151 | """ 152 | std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) 153 | std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) 154 | # rescale the results from guidance (fixes overexposure) 155 | noise_pred_rescaled = noise_cfg * (std_text / std_cfg) 156 | # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images 157 | noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg 158 | return noise_cfg -------------------------------------------------------------------------------- /lvdm/modules/encoders/condition.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import kornia 4 | import open_clip 5 | from torch.utils.checkpoint import checkpoint 6 | from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel 7 | from ...common import autocast 8 | from ....utils.utils import count_params 9 | 10 | 11 | class AbstractEncoder(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def encode(self, *args, **kwargs): 16 | raise NotImplementedError 17 | 18 | 19 | class IdentityEncoder(AbstractEncoder): 20 | def encode(self, x): 21 | return x 22 | 23 | 24 | class ClassEmbedder(nn.Module): 25 | def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): 26 | super().__init__() 27 | self.key = key 28 | self.embedding = nn.Embedding(n_classes, embed_dim) 29 | self.n_classes = n_classes 30 | self.ucg_rate = ucg_rate 31 | 32 | def forward(self, batch, key=None, disable_dropout=False): 33 | if key is None: 34 | key = self.key 35 | # this is for use in crossattn 36 | c = batch[key][:, None] 37 | if self.ucg_rate > 0. and not disable_dropout: 38 | mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) 39 | c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1) 40 | c = c.long() 41 | c = self.embedding(c) 42 | return c 43 | 44 | def get_unconditional_conditioning(self, bs, device="cuda"): 45 | uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) 46 | uc = torch.ones((bs,), device=device) * uc_class 47 | uc = {self.key: uc} 48 | return uc 49 | 50 | 51 | def disabled_train(self, mode=True): 52 | """Overwrite model.train with this function to make sure train/eval mode 53 | does not change anymore.""" 54 | return self 55 | 56 | 57 | class FrozenT5Embedder(AbstractEncoder): 58 | """Uses the T5 transformer encoder for text""" 59 | 60 | def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, 61 | freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl 62 | super().__init__() 63 | self.tokenizer = T5Tokenizer.from_pretrained(version) 64 | self.transformer = T5EncoderModel.from_pretrained(version) 65 | self.device = device 66 | self.max_length = max_length # TODO: typical value? 
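# Editor's note: an illustrative, commented-out usage sketch; the output shape
# assumes the default "google/t5-v1_1-large" checkpoint (hidden size 1024) and
# max_length=77, and is not asserted anywhere in this file.
#
#   embedder = FrozenT5Embedder(device="cuda").to("cuda")
#   z = embedder.encode(["a corgi surfing a big wave"])   # last_hidden_state, [1, 77, 1024]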
67 | if freeze: 68 | self.freeze() 69 | 70 | def freeze(self): 71 | self.transformer = self.transformer.eval() 72 | # self.train = disabled_train 73 | for param in self.parameters(): 74 | param.requires_grad = False 75 | 76 | def forward(self, text): 77 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 78 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 79 | tokens = batch_encoding["input_ids"].to(self.device) 80 | outputs = self.transformer(input_ids=tokens) 81 | 82 | z = outputs.last_hidden_state 83 | return z 84 | 85 | def encode(self, text): 86 | return self(text) 87 | 88 | 89 | class FrozenCLIPEmbedder(AbstractEncoder): 90 | """Uses the CLIP transformer encoder for text (from huggingface)""" 91 | LAYERS = [ 92 | "last", 93 | "pooled", 94 | "hidden" 95 | ] 96 | 97 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, 98 | freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 99 | super().__init__() 100 | assert layer in self.LAYERS 101 | self.tokenizer = CLIPTokenizer.from_pretrained(version) 102 | self.transformer = CLIPTextModel.from_pretrained(version) 103 | self.device = device 104 | self.max_length = max_length 105 | if freeze: 106 | self.freeze() 107 | self.layer = layer 108 | self.layer_idx = layer_idx 109 | if layer == "hidden": 110 | assert layer_idx is not None 111 | assert 0 <= abs(layer_idx) <= 12 112 | 113 | def freeze(self): 114 | self.transformer = self.transformer.eval() 115 | # self.train = disabled_train 116 | for param in self.parameters(): 117 | param.requires_grad = False 118 | 119 | def forward(self, text): 120 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 121 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 122 | tokens = batch_encoding["input_ids"].to(self.device) 123 | outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden") 124 | if self.layer == "last": 125 | z = outputs.last_hidden_state 126 | elif self.layer == "pooled": 127 | z = outputs.pooler_output[:, None, :] 128 | else: 129 | z = outputs.hidden_states[self.layer_idx] 130 | return z 131 | 132 | def encode(self, text): 133 | return self(text) 134 | 135 | 136 | class ClipImageEmbedder(nn.Module): 137 | def __init__( 138 | self, 139 | model, 140 | jit=False, 141 | device='cuda' if torch.cuda.is_available() else 'cpu', 142 | antialias=True, 143 | ucg_rate=0. 144 | ): 145 | super().__init__() 146 | from clip import load as load_clip 147 | self.model, _ = load_clip(name=model, device=device, jit=jit) 148 | 149 | self.antialias = antialias 150 | 151 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 152 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 153 | self.ucg_rate = ucg_rate 154 | 155 | def preprocess(self, x): 156 | # normalize to [0,1] 157 | x = kornia.geometry.resize(x, (224, 224), 158 | interpolation='bicubic', align_corners=True, 159 | antialias=self.antialias) 160 | x = (x + 1.) / 2. 161 | # re-normalize according to clip 162 | x = kornia.enhance.normalize(x, self.mean, self.std) 163 | return x 164 | 165 | def forward(self, x, no_dropout=False): 166 | # x is assumed to be in range [-1,1] 167 | out = self.model.encode_image(self.preprocess(x)) 168 | out = out.to(x.dtype) 169 | if self.ucg_rate > 0. 
and not no_dropout: 170 | out = torch.bernoulli((1. - self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out 171 | return out 172 | 173 | 174 | class FrozenOpenCLIPEmbedder(AbstractEncoder): 175 | """ 176 | Uses the OpenCLIP transformer encoder for text 177 | """ 178 | LAYERS = [ 179 | # "pooled", 180 | "last", 181 | "penultimate" 182 | ] 183 | 184 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, 185 | freeze=True, layer="last"): 186 | super().__init__() 187 | assert layer in self.LAYERS 188 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) 189 | del model.visual 190 | self.model = model 191 | 192 | self.device = device 193 | self.max_length = max_length 194 | if freeze: 195 | self.freeze() 196 | self.layer = layer 197 | if self.layer == "last": 198 | self.layer_idx = 0 199 | elif self.layer == "penultimate": 200 | self.layer_idx = 1 201 | else: 202 | raise NotImplementedError() 203 | 204 | def freeze(self): 205 | self.model = self.model.eval() 206 | for param in self.parameters(): 207 | param.requires_grad = False 208 | 209 | def forward(self, text): 210 | tokens = open_clip.tokenize(text) ## all clip models use 77 as context length 211 | z = self.encode_with_transformer(tokens.to(self.device)) 212 | return z 213 | 214 | def encode_with_transformer(self, text): 215 | x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] 216 | x = x + self.model.positional_embedding 217 | x = x.permute(1, 0, 2) # NLD -> LND 218 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) 219 | x = x.permute(1, 0, 2) # LND -> NLD 220 | x = self.model.ln_final(x) 221 | return x 222 | 223 | def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): 224 | for i, r in enumerate(self.model.transformer.resblocks): 225 | if i == len(self.model.transformer.resblocks) - self.layer_idx: 226 | break 227 | if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): 228 | x = checkpoint(r, x, attn_mask) 229 | else: 230 | x = r(x, attn_mask=attn_mask) 231 | return x 232 | 233 | def encode(self, text): 234 | return self(text) 235 | 236 | 237 | class FrozenOpenCLIPImageEmbedder(AbstractEncoder): 238 | """ 239 | Uses the OpenCLIP vision transformer encoder for images 240 | """ 241 | 242 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, 243 | freeze=True, layer="pooled", antialias=True, ucg_rate=0.): 244 | super().__init__() 245 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), 246 | pretrained=version, ) 247 | del model.transformer 248 | self.model = model 249 | # self.mapper = torch.nn.Linear(1280, 1024) 250 | self.device = device 251 | self.max_length = max_length 252 | if freeze: 253 | self.freeze() 254 | self.layer = layer 255 | if self.layer == "penultimate": 256 | raise NotImplementedError() 257 | self.layer_idx = 1 258 | 259 | self.antialias = antialias 260 | 261 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 262 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 263 | self.ucg_rate = ucg_rate 264 | 265 | def preprocess(self, x): 266 | # normalize to [0,1] 267 | x = kornia.geometry.resize(x, (224, 224), 268 | interpolation='bicubic', align_corners=True, 269 | antialias=self.antialias) 270 | x = (x + 1.) / 2. 
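# Editor's note (added for clarity, not in the original file): taken together,
# preprocess() expects inputs already scaled to [-1, 1] and maps each resized
# pixel p to ((p + 1) / 2 - mean) / std per channel, using the CLIP statistics
# registered in __init__. For example, a mid-gray pixel p = 0.0 in the first
# channel becomes (0.5 - 0.48145466) / 0.26862954 ≈ 0.069.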
271 | # renormalize according to clip 272 | x = kornia.enhance.normalize(x, self.mean, self.std) 273 | return x 274 | 275 | def freeze(self): 276 | self.model = self.model.eval() 277 | for param in self.model.parameters(): 278 | param.requires_grad = False 279 | 280 | @autocast 281 | def forward(self, image, no_dropout=False): 282 | z = self.encode_with_vision_transformer(image) 283 | if self.ucg_rate > 0. and not no_dropout: 284 | z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z 285 | return z 286 | 287 | def encode_with_vision_transformer(self, img): 288 | img = self.preprocess(img) 289 | x = self.model.visual(img) 290 | return x 291 | 292 | def encode(self, text): 293 | return self(text) 294 | 295 | class FrozenOpenCLIPImageEmbedderV2(AbstractEncoder): 296 | """ 297 | Uses the OpenCLIP vision transformer encoder for images 298 | """ 299 | 300 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", 301 | freeze=True, layer="pooled", antialias=True): 302 | super().__init__() 303 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), 304 | pretrained=version, ) 305 | del model.transformer 306 | self.model = model 307 | self.device = device 308 | 309 | if freeze: 310 | self.freeze() 311 | self.layer = layer 312 | if self.layer == "penultimate": 313 | raise NotImplementedError() 314 | self.layer_idx = 1 315 | 316 | self.antialias = antialias 317 | 318 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 319 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 320 | 321 | 322 | def preprocess(self, x): 323 | # normalize to [0,1] 324 | x = kornia.geometry.resize(x, (224, 224), 325 | interpolation='bicubic', align_corners=True, 326 | antialias=self.antialias) 327 | x = (x + 1.) / 2. 
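# Editor's note: unlike the pooled embedder above, this V2 variant returns the
# full token sequence from the vision transformer (no final pooling or
# projection). For the default ViT-H-14 at 224x224 with patch size 14 that
# should be [B, 257, 1280] (1 class token + 16x16 patch tokens), which is the
# kind of sequence the Resampler in resampler.py is designed to consume.
# Commented-out sketch with assumed, illustrative hyper-parameters:
#
#   img_tokens = FrozenOpenCLIPImageEmbedderV2()(frames_224)    # [B, 257, 1280]
#   context = Resampler(dim=1024, embedding_dim=1280, output_dim=1024,
#                       num_queries=16, video_length=16)(img_tokens)  # [B, 16*16, 1024]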
328 | # renormalize according to clip 329 | x = kornia.enhance.normalize(x, self.mean, self.std) 330 | return x 331 | 332 | def freeze(self): 333 | self.model = self.model.eval() 334 | for param in self.model.parameters(): 335 | param.requires_grad = False 336 | 337 | def forward(self, image, no_dropout=False): 338 | ## image: b c h w 339 | z = self.encode_with_vision_transformer(image) 340 | return z 341 | 342 | def encode_with_vision_transformer(self, x): 343 | x = self.preprocess(x) 344 | 345 | # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1 346 | if self.model.visual.input_patchnorm: 347 | # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)') 348 | x = x.reshape(x.shape[0], x.shape[1], self.model.visual.grid_size[0], self.model.visual.patch_size[0], self.model.visual.grid_size[1], self.model.visual.patch_size[1]) 349 | x = x.permute(0, 2, 4, 1, 3, 5) 350 | x = x.reshape(x.shape[0], self.model.visual.grid_size[0] * self.model.visual.grid_size[1], -1) 351 | x = self.model.visual.patchnorm_pre_ln(x) 352 | x = self.model.visual.conv1(x) 353 | else: 354 | x = self.model.visual.conv1(x) # shape = [*, width, grid, grid] 355 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 356 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 357 | 358 | # class embeddings and positional embeddings 359 | x = torch.cat( 360 | [self.model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), 361 | x], dim=1) # shape = [*, grid ** 2 + 1, width] 362 | x = x + self.model.visual.positional_embedding.to(x.dtype) 363 | 364 | # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in 365 | x = self.model.visual.patch_dropout(x) 366 | x = self.model.visual.ln_pre(x) 367 | 368 | x = x.permute(1, 0, 2) # NLD -> LND 369 | x = self.model.visual.transformer(x) 370 | x = x.permute(1, 0, 2) # LND -> NLD 371 | 372 | return x 373 | 374 | class FrozenCLIPT5Encoder(AbstractEncoder): 375 | def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", 376 | clip_max_length=77, t5_max_length=77): 377 | super().__init__() 378 | self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) 379 | self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) 380 | print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, " 381 | f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params.") 382 | 383 | def encode(self, text): 384 | return self(text) 385 | 386 | def forward(self, text): 387 | clip_z = self.clip_encoder.encode(text) 388 | t5_z = self.t5_encoder.encode(text) 389 | return [clip_z, t5_z] 390 | -------------------------------------------------------------------------------- /lvdm/modules/encoders/resampler.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py 2 | # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py 3 | # and https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class ImageProjModel(nn.Module): 10 | """Projection Model""" 11 | def __init__(self, 
cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): 12 | super().__init__() 13 | self.cross_attention_dim = cross_attention_dim 14 | self.clip_extra_context_tokens = clip_extra_context_tokens 15 | self.proj = nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) 16 | self.norm = nn.LayerNorm(cross_attention_dim) 17 | 18 | def forward(self, image_embeds): 19 | #embeds = image_embeds 20 | embeds = image_embeds.type(list(self.proj.parameters())[0].dtype) 21 | clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim) 22 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens) 23 | return clip_extra_context_tokens 24 | 25 | 26 | # FFN 27 | def FeedForward(dim, mult=4): 28 | inner_dim = int(dim * mult) 29 | return nn.Sequential( 30 | nn.LayerNorm(dim), 31 | nn.Linear(dim, inner_dim, bias=False), 32 | nn.GELU(), 33 | nn.Linear(inner_dim, dim, bias=False), 34 | ) 35 | 36 | 37 | def reshape_tensor(x, heads): 38 | bs, length, width = x.shape 39 | #(bs, length, width) --> (bs, length, n_heads, dim_per_head) 40 | x = x.view(bs, length, heads, -1) 41 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) 42 | x = x.transpose(1, 2) 43 | # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) 44 | x = x.reshape(bs, heads, length, -1) 45 | return x 46 | 47 | 48 | class PerceiverAttention(nn.Module): 49 | def __init__(self, *, dim, dim_head=64, heads=8): 50 | super().__init__() 51 | self.scale = dim_head**-0.5 52 | self.dim_head = dim_head 53 | self.heads = heads 54 | inner_dim = dim_head * heads 55 | 56 | self.norm1 = nn.LayerNorm(dim) 57 | self.norm2 = nn.LayerNorm(dim) 58 | 59 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 60 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 61 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 62 | 63 | 64 | def forward(self, x, latents): 65 | """ 66 | Args: 67 | x (torch.Tensor): image features 68 | shape (b, n1, D) 69 | latent (torch.Tensor): latent features 70 | shape (b, n2, D) 71 | """ 72 | x = self.norm1(x) 73 | latents = self.norm2(latents) 74 | 75 | b, l, _ = latents.shape 76 | 77 | q = self.to_q(latents) 78 | kv_input = torch.cat((x, latents), dim=-2) 79 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 80 | 81 | q = reshape_tensor(q, self.heads) 82 | k = reshape_tensor(k, self.heads) 83 | v = reshape_tensor(v, self.heads) 84 | 85 | # attention 86 | scale = 1 / math.sqrt(math.sqrt(self.dim_head)) 87 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards 88 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 89 | out = weight @ v 90 | 91 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1) 92 | 93 | return self.to_out(out) 94 | 95 | 96 | class Resampler(nn.Module): 97 | def __init__( 98 | self, 99 | dim=1024, 100 | depth=8, 101 | dim_head=64, 102 | heads=16, 103 | num_queries=8, 104 | embedding_dim=768, 105 | output_dim=1024, 106 | ff_mult=4, 107 | video_length=None, # using frame-wise version or not 108 | ): 109 | super().__init__() 110 | ## queries for a single frame / image 111 | self.num_queries = num_queries 112 | self.video_length = video_length 113 | 114 | ## queries for each frame 115 | if video_length is not None: 116 | num_queries = num_queries * video_length 117 | 118 | self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) 119 | self.proj_in = nn.Linear(embedding_dim, dim) 120 | 
self.proj_out = nn.Linear(dim, output_dim) 121 | self.norm_out = nn.LayerNorm(output_dim) 122 | 123 | self.layers = nn.ModuleList([]) 124 | for _ in range(depth): 125 | self.layers.append( 126 | nn.ModuleList( 127 | [ 128 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 129 | FeedForward(dim=dim, mult=ff_mult), 130 | ] 131 | ) 132 | ) 133 | 134 | def forward(self, x): 135 | latents = self.latents.repeat(x.size(0), 1, 1) ## B (T L) C 136 | x = self.proj_in(x) 137 | 138 | for attn, ff in self.layers: 139 | latents = attn(x, latents) + latents 140 | latents = ff(latents) + latents 141 | 142 | latents = self.proj_out(latents) 143 | latents = self.norm_out(latents) # B L C or B (T L) C 144 | 145 | return latents -------------------------------------------------------------------------------- /lvdm/modules/x_transformer.py: -------------------------------------------------------------------------------- 1 | """shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" 2 | from functools import partial 3 | from inspect import isfunction 4 | from collections import namedtuple 5 | from einops import rearrange, repeat 6 | import torch 7 | from torch import nn, einsum 8 | import torch.nn.functional as F 9 | 10 | # constants 11 | DEFAULT_DIM_HEAD = 64 12 | 13 | Intermediates = namedtuple('Intermediates', [ 14 | 'pre_softmax_attn', 15 | 'post_softmax_attn' 16 | ]) 17 | 18 | LayerIntermediates = namedtuple('Intermediates', [ 19 | 'hiddens', 20 | 'attn_intermediates' 21 | ]) 22 | 23 | 24 | class AbsolutePositionalEmbedding(nn.Module): 25 | def __init__(self, dim, max_seq_len): 26 | super().__init__() 27 | self.emb = nn.Embedding(max_seq_len, dim) 28 | self.init_() 29 | 30 | def init_(self): 31 | nn.init.normal_(self.emb.weight, std=0.02) 32 | 33 | def forward(self, x): 34 | n = torch.arange(x.shape[1], device=x.device) 35 | return self.emb(n)[None, :, :] 36 | 37 | 38 | class FixedPositionalEmbedding(nn.Module): 39 | def __init__(self, dim): 40 | super().__init__() 41 | inv_freq = 1. 
/ (10000 ** (torch.arange(0, dim, 2).float() / dim)) 42 | self.register_buffer('inv_freq', inv_freq) 43 | 44 | def forward(self, x, seq_dim=1, offset=0): 45 | t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset 46 | sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) 47 | emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) 48 | return emb[None, :, :] 49 | 50 | 51 | # helpers 52 | 53 | def exists(val): 54 | return val is not None 55 | 56 | 57 | def default(val, d): 58 | if exists(val): 59 | return val 60 | return d() if isfunction(d) else d 61 | 62 | 63 | def always(val): 64 | def inner(*args, **kwargs): 65 | return val 66 | return inner 67 | 68 | 69 | def not_equals(val): 70 | def inner(x): 71 | return x != val 72 | return inner 73 | 74 | 75 | def equals(val): 76 | def inner(x): 77 | return x == val 78 | return inner 79 | 80 | 81 | def max_neg_value(tensor): 82 | return -torch.finfo(tensor.dtype).max 83 | 84 | 85 | # keyword argument helpers 86 | 87 | def pick_and_pop(keys, d): 88 | values = list(map(lambda key: d.pop(key), keys)) 89 | return dict(zip(keys, values)) 90 | 91 | 92 | def group_dict_by_key(cond, d): 93 | return_val = [dict(), dict()] 94 | for key in d.keys(): 95 | match = bool(cond(key)) 96 | ind = int(not match) 97 | return_val[ind][key] = d[key] 98 | return (*return_val,) 99 | 100 | 101 | def string_begins_with(prefix, str): 102 | return str.startswith(prefix) 103 | 104 | 105 | def group_by_key_prefix(prefix, d): 106 | return group_dict_by_key(partial(string_begins_with, prefix), d) 107 | 108 | 109 | def groupby_prefix_and_trim(prefix, d): 110 | kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) 111 | kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) 112 | return kwargs_without_prefix, kwargs 113 | 114 | 115 | # classes 116 | class Scale(nn.Module): 117 | def __init__(self, value, fn): 118 | super().__init__() 119 | self.value = value 120 | self.fn = fn 121 | 122 | def forward(self, x, **kwargs): 123 | x, *rest = self.fn(x, **kwargs) 124 | return (x * self.value, *rest) 125 | 126 | 127 | class Rezero(nn.Module): 128 | def __init__(self, fn): 129 | super().__init__() 130 | self.fn = fn 131 | self.g = nn.Parameter(torch.zeros(1)) 132 | 133 | def forward(self, x, **kwargs): 134 | x, *rest = self.fn(x, **kwargs) 135 | return (x * self.g, *rest) 136 | 137 | 138 | class ScaleNorm(nn.Module): 139 | def __init__(self, dim, eps=1e-5): 140 | super().__init__() 141 | self.scale = dim ** -0.5 142 | self.eps = eps 143 | self.g = nn.Parameter(torch.ones(1)) 144 | 145 | def forward(self, x): 146 | norm = torch.norm(x, dim=-1, keepdim=True) * self.scale 147 | return x / norm.clamp(min=self.eps) * self.g 148 | 149 | 150 | class RMSNorm(nn.Module): 151 | def __init__(self, dim, eps=1e-8): 152 | super().__init__() 153 | self.scale = dim ** -0.5 154 | self.eps = eps 155 | self.g = nn.Parameter(torch.ones(dim)) 156 | 157 | def forward(self, x): 158 | norm = torch.norm(x, dim=-1, keepdim=True) * self.scale 159 | return x / norm.clamp(min=self.eps) * self.g 160 | 161 | 162 | class Residual(nn.Module): 163 | def forward(self, x, residual): 164 | return x + residual 165 | 166 | 167 | class GRUGating(nn.Module): 168 | def __init__(self, dim): 169 | super().__init__() 170 | self.gru = nn.GRUCell(dim, dim) 171 | 172 | def forward(self, x, residual): 173 | gated_output = self.gru( 174 | rearrange(x, 'b n d -> (b n) d'), 175 | rearrange(residual, 'b n d -> 
(b n) d') 176 | ) 177 | 178 | return gated_output.reshape_as(x) 179 | 180 | 181 | # feedforward 182 | 183 | class GEGLU(nn.Module): 184 | def __init__(self, dim_in, dim_out): 185 | super().__init__() 186 | self.proj = nn.Linear(dim_in, dim_out * 2) 187 | 188 | def forward(self, x): 189 | x, gate = self.proj(x).chunk(2, dim=-1) 190 | return x * F.gelu(gate) 191 | 192 | 193 | class FeedForward(nn.Module): 194 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 195 | super().__init__() 196 | inner_dim = int(dim * mult) 197 | dim_out = default(dim_out, dim) 198 | project_in = nn.Sequential( 199 | nn.Linear(dim, inner_dim), 200 | nn.GELU() 201 | ) if not glu else GEGLU(dim, inner_dim) 202 | 203 | self.net = nn.Sequential( 204 | project_in, 205 | nn.Dropout(dropout), 206 | nn.Linear(inner_dim, dim_out) 207 | ) 208 | 209 | def forward(self, x): 210 | return self.net(x) 211 | 212 | 213 | # attention. 214 | class Attention(nn.Module): 215 | def __init__( 216 | self, 217 | dim, 218 | dim_head=DEFAULT_DIM_HEAD, 219 | heads=8, 220 | causal=False, 221 | mask=None, 222 | talking_heads=False, 223 | sparse_topk=None, 224 | use_entmax15=False, 225 | num_mem_kv=0, 226 | dropout=0., 227 | on_attn=False 228 | ): 229 | super().__init__() 230 | if use_entmax15: 231 | raise NotImplementedError("Check out entmax activation instead of softmax activation!") 232 | self.scale = dim_head ** -0.5 233 | self.heads = heads 234 | self.causal = causal 235 | self.mask = mask 236 | 237 | inner_dim = dim_head * heads 238 | 239 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 240 | self.to_k = nn.Linear(dim, inner_dim, bias=False) 241 | self.to_v = nn.Linear(dim, inner_dim, bias=False) 242 | self.dropout = nn.Dropout(dropout) 243 | 244 | # talking heads 245 | self.talking_heads = talking_heads 246 | if talking_heads: 247 | self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads)) 248 | self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads)) 249 | 250 | # explicit topk sparse attention 251 | self.sparse_topk = sparse_topk 252 | 253 | # entmax 254 | #self.attn_fn = entmax15 if use_entmax15 else F.softmax 255 | self.attn_fn = F.softmax 256 | 257 | # add memory key / values 258 | self.num_mem_kv = num_mem_kv 259 | if num_mem_kv > 0: 260 | self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) 261 | self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) 262 | 263 | # attention on attention 264 | self.attn_on_attn = on_attn 265 | self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim) 266 | 267 | def forward( 268 | self, 269 | x, 270 | context=None, 271 | mask=None, 272 | context_mask=None, 273 | rel_pos=None, 274 | sinusoidal_emb=None, 275 | prev_attn=None, 276 | mem=None 277 | ): 278 | b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device 279 | kv_input = default(context, x) 280 | 281 | q_input = x 282 | k_input = kv_input 283 | v_input = kv_input 284 | 285 | if exists(mem): 286 | k_input = torch.cat((mem, k_input), dim=-2) 287 | v_input = torch.cat((mem, v_input), dim=-2) 288 | 289 | if exists(sinusoidal_emb): 290 | # in shortformer, the query would start at a position offset depending on the past cached memory 291 | offset = k_input.shape[-2] - q_input.shape[-2] 292 | q_input = q_input + sinusoidal_emb(q_input, offset=offset) 293 | k_input = k_input + sinusoidal_emb(k_input) 294 | 295 | q = self.to_q(q_input) 296 | k = self.to_k(k_input) 297 | v = self.to_v(v_input) 298 | 299 | 
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) 300 | 301 | input_mask = None 302 | if any(map(exists, (mask, context_mask))): 303 | q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) 304 | k_mask = q_mask if not exists(context) else context_mask 305 | k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) 306 | q_mask = rearrange(q_mask, 'b i -> b () i ()') 307 | k_mask = rearrange(k_mask, 'b j -> b () () j') 308 | input_mask = q_mask * k_mask 309 | 310 | if self.num_mem_kv > 0: 311 | mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) 312 | k = torch.cat((mem_k, k), dim=-2) 313 | v = torch.cat((mem_v, v), dim=-2) 314 | if exists(input_mask): 315 | input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) 316 | 317 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 318 | mask_value = max_neg_value(dots) 319 | 320 | if exists(prev_attn): 321 | dots = dots + prev_attn 322 | 323 | pre_softmax_attn = dots 324 | 325 | if talking_heads: 326 | dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() 327 | 328 | if exists(rel_pos): 329 | dots = rel_pos(dots) 330 | 331 | if exists(input_mask): 332 | dots.masked_fill_(~input_mask, mask_value) 333 | del input_mask 334 | 335 | if self.causal: 336 | i, j = dots.shape[-2:] 337 | r = torch.arange(i, device=device) 338 | mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') 339 | mask = F.pad(mask, (j - i, 0), value=False) 340 | dots.masked_fill_(mask, mask_value) 341 | del mask 342 | 343 | if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: 344 | top, _ = dots.topk(self.sparse_topk, dim=-1) 345 | vk = top[..., -1].unsqueeze(-1).expand_as(dots) 346 | mask = dots < vk 347 | dots.masked_fill_(mask, mask_value) 348 | del mask 349 | 350 | attn = self.attn_fn(dots, dim=-1) 351 | post_softmax_attn = attn 352 | 353 | attn = self.dropout(attn) 354 | 355 | if talking_heads: 356 | attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() 357 | 358 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 359 | out = rearrange(out, 'b h n d -> b n (h d)') 360 | 361 | intermediates = Intermediates( 362 | pre_softmax_attn=pre_softmax_attn, 363 | post_softmax_attn=post_softmax_attn 364 | ) 365 | 366 | return self.to_out(out), intermediates 367 | 368 | 369 | class AttentionLayers(nn.Module): 370 | def __init__( 371 | self, 372 | dim, 373 | depth, 374 | heads=8, 375 | causal=False, 376 | cross_attend=False, 377 | only_cross=False, 378 | use_scalenorm=False, 379 | use_rmsnorm=False, 380 | use_rezero=False, 381 | rel_pos_num_buckets=32, 382 | rel_pos_max_distance=128, 383 | position_infused_attn=False, 384 | custom_layers=None, 385 | sandwich_coef=None, 386 | par_ratio=None, 387 | residual_attn=False, 388 | cross_residual_attn=False, 389 | macaron=False, 390 | pre_norm=True, 391 | gate_residual=False, 392 | **kwargs 393 | ): 394 | super().__init__() 395 | ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) 396 | attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) 397 | 398 | dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) 399 | 400 | self.dim = dim 401 | self.depth = depth 402 | self.layers = nn.ModuleList([]) 403 | 404 | self.has_pos_emb = position_infused_attn 405 | self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None 406 | self.rotary_pos_emb = always(None) 407 | 408 | assert rel_pos_num_buckets 
<= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' 409 | self.rel_pos = None 410 | 411 | self.pre_norm = pre_norm 412 | 413 | self.residual_attn = residual_attn 414 | self.cross_residual_attn = cross_residual_attn 415 | 416 | norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm 417 | norm_class = RMSNorm if use_rmsnorm else norm_class 418 | norm_fn = partial(norm_class, dim) 419 | 420 | norm_fn = nn.Identity if use_rezero else norm_fn 421 | branch_fn = Rezero if use_rezero else None 422 | 423 | if cross_attend and not only_cross: 424 | default_block = ('a', 'c', 'f') 425 | elif cross_attend and only_cross: 426 | default_block = ('c', 'f') 427 | else: 428 | default_block = ('a', 'f') 429 | 430 | if macaron: 431 | default_block = ('f',) + default_block 432 | 433 | if exists(custom_layers): 434 | layer_types = custom_layers 435 | elif exists(par_ratio): 436 | par_depth = depth * len(default_block) 437 | assert 1 < par_ratio <= par_depth, 'par ratio out of range' 438 | default_block = tuple(filter(not_equals('f'), default_block)) 439 | par_attn = par_depth // par_ratio 440 | depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper 441 | par_width = (depth_cut + depth_cut // par_attn) // par_attn 442 | assert len(default_block) <= par_width, 'default block is too large for par_ratio' 443 | par_block = default_block + ('f',) * (par_width - len(default_block)) 444 | par_head = par_block * par_attn 445 | layer_types = par_head + ('f',) * (par_depth - len(par_head)) 446 | elif exists(sandwich_coef): 447 | assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' 448 | layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef 449 | else: 450 | layer_types = default_block * depth 451 | 452 | self.layer_types = layer_types 453 | self.num_attn_layers = len(list(filter(equals('a'), layer_types))) 454 | 455 | for layer_type in self.layer_types: 456 | if layer_type == 'a': 457 | layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) 458 | elif layer_type == 'c': 459 | layer = Attention(dim, heads=heads, **attn_kwargs) 460 | elif layer_type == 'f': 461 | layer = FeedForward(dim, **ff_kwargs) 462 | layer = layer if not macaron else Scale(0.5, layer) 463 | else: 464 | raise Exception(f'invalid layer type {layer_type}') 465 | 466 | if isinstance(layer, Attention) and exists(branch_fn): 467 | layer = branch_fn(layer) 468 | 469 | if gate_residual: 470 | residual_fn = GRUGating(dim) 471 | else: 472 | residual_fn = Residual() 473 | 474 | self.layers.append(nn.ModuleList([ 475 | norm_fn(), 476 | layer, 477 | residual_fn 478 | ])) 479 | 480 | def forward( 481 | self, 482 | x, 483 | context=None, 484 | mask=None, 485 | context_mask=None, 486 | mems=None, 487 | return_hiddens=False 488 | ): 489 | hiddens = [] 490 | intermediates = [] 491 | prev_attn = None 492 | prev_cross_attn = None 493 | 494 | mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers 495 | 496 | for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): 497 | is_last = ind == (len(self.layers) - 1) 498 | 499 | if layer_type == 'a': 500 | hiddens.append(x) 501 | layer_mem = mems.pop(0) 502 | 503 | residual = x 504 | 505 | if self.pre_norm: 506 | x = norm(x) 507 | 508 | if layer_type == 'a': 509 | out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos, 510 | 
prev_attn=prev_attn, mem=layer_mem) 511 | elif layer_type == 'c': 512 | out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn) 513 | elif layer_type == 'f': 514 | out = block(x) 515 | 516 | x = residual_fn(out, residual) 517 | 518 | if layer_type in ('a', 'c'): 519 | intermediates.append(inter) 520 | 521 | if layer_type == 'a' and self.residual_attn: 522 | prev_attn = inter.pre_softmax_attn 523 | elif layer_type == 'c' and self.cross_residual_attn: 524 | prev_cross_attn = inter.pre_softmax_attn 525 | 526 | if not self.pre_norm and not is_last: 527 | x = norm(x) 528 | 529 | if return_hiddens: 530 | intermediates = LayerIntermediates( 531 | hiddens=hiddens, 532 | attn_intermediates=intermediates 533 | ) 534 | 535 | return x, intermediates 536 | 537 | return x 538 | 539 | 540 | class Encoder(AttentionLayers): 541 | def __init__(self, **kwargs): 542 | assert 'causal' not in kwargs, 'cannot set causality on encoder' 543 | super().__init__(causal=False, **kwargs) 544 | 545 | 546 | 547 | class TransformerWrapper(nn.Module): 548 | def __init__( 549 | self, 550 | *, 551 | num_tokens, 552 | max_seq_len, 553 | attn_layers, 554 | emb_dim=None, 555 | max_mem_len=0., 556 | emb_dropout=0., 557 | num_memory_tokens=None, 558 | tie_embedding=False, 559 | use_pos_emb=True 560 | ): 561 | super().__init__() 562 | assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' 563 | 564 | dim = attn_layers.dim 565 | emb_dim = default(emb_dim, dim) 566 | 567 | self.max_seq_len = max_seq_len 568 | self.max_mem_len = max_mem_len 569 | self.num_tokens = num_tokens 570 | 571 | self.token_emb = nn.Embedding(num_tokens, emb_dim) 572 | self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( 573 | use_pos_emb and not attn_layers.has_pos_emb) else always(0) 574 | self.emb_dropout = nn.Dropout(emb_dropout) 575 | 576 | self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() 577 | self.attn_layers = attn_layers 578 | self.norm = nn.LayerNorm(dim) 579 | 580 | self.init_() 581 | 582 | self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() 583 | 584 | # memory tokens (like [cls]) from Memory Transformers paper 585 | num_memory_tokens = default(num_memory_tokens, 0) 586 | self.num_memory_tokens = num_memory_tokens 587 | if num_memory_tokens > 0: 588 | self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) 589 | 590 | # let funnel encoder know number of memory tokens, if specified 591 | if hasattr(attn_layers, 'num_memory_tokens'): 592 | attn_layers.num_memory_tokens = num_memory_tokens 593 | 594 | def init_(self): 595 | nn.init.normal_(self.token_emb.weight, std=0.02) 596 | 597 | def forward( 598 | self, 599 | x, 600 | return_embeddings=False, 601 | mask=None, 602 | return_mems=False, 603 | return_attn=False, 604 | mems=None, 605 | **kwargs 606 | ): 607 | b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens 608 | x = self.token_emb(x) 609 | x += self.pos_emb(x) 610 | x = self.emb_dropout(x) 611 | 612 | x = self.project_emb(x) 613 | 614 | if num_mem > 0: 615 | mem = repeat(self.memory_tokens, 'n d -> b n d', b=b) 616 | x = torch.cat((mem, x), dim=1) 617 | 618 | # auto-handle masking after appending memory tokens 619 | if exists(mask): 620 | mask = F.pad(mask, (num_mem, 0), value=True) 621 | 622 | x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) 623 | x = self.norm(x) 624 | 625 | 
mem, x = x[:, :num_mem], x[:, num_mem:] 626 | 627 | out = self.to_logits(x) if not return_embeddings else x 628 | 629 | if return_mems: 630 | hiddens = intermediates.hiddens 631 | new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens 632 | new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) 633 | return out, new_mems 634 | 635 | if return_attn: 636 | attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) 637 | return out, attn_maps 638 | 639 | return out -------------------------------------------------------------------------------- /nodes.py: -------------------------------------------------------------------------------- 1 | import os, argparse 2 | import sys 3 | from .scripts.gradio.i2v_test import Image2Video 4 | from .scripts.gradio.i2v_test_application import Image2Video as Image2VideoInterp 5 | sys.path.insert(1, os.path.join(sys.path[0], 'lvdm')) 6 | 7 | resolutions=["576_1024","320_512","256_256"] 8 | resolutionInterps=["320_512"] 9 | 10 | class DynamiCrafterInterpLoader: 11 | @classmethod 12 | def INPUT_TYPES(cls): 13 | return { 14 | "required": { 15 | "resolution": (resolutionInterps,{"default":"320_512"}), 16 | "frame_length": ("INT",{"default":16}) 17 | } 18 | } 19 | 20 | RETURN_TYPES = ("DynamiCrafterInter",) 21 | RETURN_NAMES = ("model",) 22 | FUNCTION = "run_inference" 23 | CATEGORY = "DynamiCrafter" 24 | 25 | def run_inference(self,resolution,frame_length=16): 26 | image2video = Image2VideoInterp('./tmp/', resolution=resolution,frame_length=frame_length) 27 | return (image2video,) 28 | 29 | class DynamiCrafterLoader: 30 | @classmethod 31 | def INPUT_TYPES(cls): 32 | return { 33 | "required": { 34 | "resolution": (resolutions,{"default":"576_1024"}), 35 | "frame_length": ("INT",{"default":16}) 36 | } 37 | } 38 | 39 | RETURN_TYPES = ("DynamiCrafter",) 40 | RETURN_NAMES = ("model",) 41 | FUNCTION = "run_inference" 42 | CATEGORY = "DynamiCrafter" 43 | 44 | def run_inference(self,resolution,frame_length=16): 45 | image2video = Image2Video('./tmp/', resolution=resolution,frame_length=frame_length) 46 | return (image2video,) 47 | 48 | class DynamiCrafterSimple: 49 | @classmethod 50 | def INPUT_TYPES(cls): 51 | return { 52 | "required": { 53 | "model": ("DynamiCrafter",), 54 | "image": ("IMAGE",), 55 | "prompt": ("STRING", {"default": ""}), 56 | "steps": ("INT", {"default": 50}), 57 | "cfg_scale": ("FLOAT", {"default": 7.5}), 58 | "eta": ("FLOAT", {"default": 1.0}), 59 | "motion": ("INT", {"default": 3}), 60 | "seed": ("INT", {"default": 123}), 61 | } 62 | } 63 | 64 | RETURN_TYPES = ("IMAGE",) 65 | RETURN_NAMES = ("image",) 66 | FUNCTION = "run_inference" 67 | CATEGORY = "DynamiCrafter" 68 | 69 | def run_inference(self,model,image,prompt,steps,cfg_scale,eta,motion,seed): 70 | image = 255.0 * image[0].cpu().numpy() 71 | #image = Image.fromarray(np.clip(image, 0, 255).astype(np.uint8)) 72 | 73 | imgs= model.get_image(image, prompt, steps, cfg_scale, eta, motion, seed) 74 | return imgs 75 | 76 | class DynamiCrafterInterpSimple: 77 | @classmethod 78 | def INPUT_TYPES(cls): 79 | return { 80 | "required": { 81 | "model": ("DynamiCrafterInter",), 82 | "image": ("IMAGE",), 83 | "image1": ("IMAGE",), 84 | "prompt": ("STRING", {"default": ""}), 85 | "steps": ("INT", {"default": 50}), 86 | "cfg_scale": ("FLOAT", {"default": 7.5}), 87 | "eta": ("FLOAT", {"default": 1.0}), 88 | "motion": ("INT", {"default": 3}), 89 | "seed": ("INT", {"default": 123}), 90 | } 91 | } 92 | 93 | 
RETURN_TYPES = ("IMAGE",) 94 | RETURN_NAMES = ("image",) 95 | FUNCTION = "run_inference" 96 | CATEGORY = "DynamiCrafter" 97 | 98 | def run_inference(self,model,image,image1,prompt,steps,cfg_scale,eta,motion,seed): 99 | image = 255.0 * image[0].cpu().numpy() 100 | image1 = 255.0 * image1[0].cpu().numpy() 101 | #image = Image.fromarray(np.clip(image, 0, 255).astype(np.uint8)) 102 | 103 | imgs= model.get_image(image, prompt, steps, cfg_scale, eta, motion, seed,image1) 104 | return imgs 105 | 106 | 107 | NODE_CLASS_MAPPINGS = { 108 | "DynamiCrafterLoader":DynamiCrafterLoader, 109 | "DynamiCrafter Simple":DynamiCrafterSimple, 110 | "DynamiCrafterInterpLoader":DynamiCrafterInterpLoader, 111 | "DynamiCrafterInterp Simple":DynamiCrafterInterpSimple, 112 | } 113 | -------------------------------------------------------------------------------- /prompts/1024/astronaut04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/1024/astronaut04.png -------------------------------------------------------------------------------- /prompts/1024/bike_chineseink.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/1024/bike_chineseink.png -------------------------------------------------------------------------------- /prompts/1024/bloom01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/1024/bloom01.png -------------------------------------------------------------------------------- /prompts/1024/firework03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/1024/firework03.png -------------------------------------------------------------------------------- /prompts/1024/girl07.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/1024/girl07.png -------------------------------------------------------------------------------- /prompts/1024/pour_bear.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/1024/pour_bear.png -------------------------------------------------------------------------------- /prompts/1024/robot01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/1024/robot01.png -------------------------------------------------------------------------------- /prompts/1024/test_prompts.txt: -------------------------------------------------------------------------------- 1 | a man in an astronaut suit playing a guitar 2 | riding a bike under a bridge 3 | time-lapse of a blooming flower with leaves and a stem 4 | fireworks display 5 | a beautiful woman with long hair and a dress blowing in the wind 6 | pouring beer into a glass of ice and beer 7 | a robot is walking through a destroyed city 8 | 
a group of penguins walking on a beach -------------------------------------------------------------------------------- /prompts/1024/zreal_penguin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/1024/zreal_penguin.png -------------------------------------------------------------------------------- /prompts/256/art.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/256/art.png -------------------------------------------------------------------------------- /prompts/256/bear.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/256/bear.png -------------------------------------------------------------------------------- /prompts/256/boy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/256/boy.png -------------------------------------------------------------------------------- /prompts/256/dance1.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/256/dance1.jpeg -------------------------------------------------------------------------------- /prompts/256/fire_and_beach.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/256/fire_and_beach.jpg -------------------------------------------------------------------------------- /prompts/256/girl2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/256/girl2.jpeg -------------------------------------------------------------------------------- /prompts/256/girl3.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/256/girl3.jpeg -------------------------------------------------------------------------------- /prompts/256/guitar0.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/256/guitar0.jpeg -------------------------------------------------------------------------------- /prompts/256/test_prompts.txt: -------------------------------------------------------------------------------- 1 | man fishing in a boat at sunset 2 | a brown bear is walking in a zoo enclosure, some rocks around 3 | boy walking on the street 4 | two people dancing 5 | a campfire on the beach and the ocean waves in the background 6 | girl with fires and smoke on his head 7 | girl talking and blinking 8 | bear playing guitar happily, snowing -------------------------------------------------------------------------------- /prompts/512/bloom01.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/512/bloom01.png -------------------------------------------------------------------------------- /prompts/512/campfire.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/512/campfire.png -------------------------------------------------------------------------------- /prompts/512/girl08.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/512/girl08.png -------------------------------------------------------------------------------- /prompts/512/isometric.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/512/isometric.png -------------------------------------------------------------------------------- /prompts/512/pour_honey.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/512/pour_honey.png -------------------------------------------------------------------------------- /prompts/512/ship02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/512/ship02.png -------------------------------------------------------------------------------- /prompts/512/test_prompts.txt: -------------------------------------------------------------------------------- 1 | time-lapse of a blooming flower with leaves and a stem 2 | a bonfire is lit in the middle of a field 3 | a woman looking out in the rain 4 | rotating view, small house 5 | pouring honey onto some slices of bread 6 | a sailboat sailing in rough seas with a dramatic sunset 7 | a boat traveling on the ocean 8 | a group of penguins walking on a beach -------------------------------------------------------------------------------- /prompts/512/zreal_boat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/512/zreal_boat.png -------------------------------------------------------------------------------- /prompts/512/zreal_penguin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/512/zreal_penguin.png -------------------------------------------------------------------------------- /prompts/512_interp/smile_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/512_interp/smile_01.png -------------------------------------------------------------------------------- /prompts/512_interp/smile_02.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/512_interp/smile_02.png -------------------------------------------------------------------------------- /prompts/512_interp/stone01_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/512_interp/stone01_01.png -------------------------------------------------------------------------------- /prompts/512_interp/stone01_02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/512_interp/stone01_02.png -------------------------------------------------------------------------------- /prompts/512_interp/test_prompts.txt: -------------------------------------------------------------------------------- 1 | a smiling girl 2 | rotating view 3 | a man is walking towards a tree -------------------------------------------------------------------------------- /prompts/512_interp/walk_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/512_interp/walk_01.png -------------------------------------------------------------------------------- /prompts/512_interp/walk_02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/512_interp/walk_02.png -------------------------------------------------------------------------------- /prompts/512_loop/24.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/512_loop/24.png -------------------------------------------------------------------------------- /prompts/512_loop/36.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/512_loop/36.png -------------------------------------------------------------------------------- /prompts/512_loop/40.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/512_loop/40.png -------------------------------------------------------------------------------- /prompts/512_loop/test_prompts.txt: -------------------------------------------------------------------------------- 1 | a beach with waves and clouds at sunset 2 | clothes swaying in the wind 3 | flowers swaying in the wind -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | decord 2 | einops 3 | imageio 4 | omegaconf 5 | opencv_python 6 | pandas 7 | Pillow 8 | pytorch_lightning 9 | PyYAML 10 | tqdm 11 | transformers 12 | moviepy 13 | av 14 | timm 15 | scikit-learn 16 | open_clip_torch==2.22.0 17 | kornia -------------------------------------------------------------------------------- /scripts/evaluation/ddp_wrapper.py: 
-------------------------------------------------------------------------------- 1 | import datetime 2 | import argparse, importlib 3 | from pytorch_lightning import seed_everything 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | def setup_dist(local_rank): 9 | if dist.is_initialized(): 10 | return 11 | torch.cuda.set_device(local_rank) 12 | torch.distributed.init_process_group('nccl', init_method='env://') 13 | 14 | 15 | def get_dist_info(): 16 | if dist.is_available(): 17 | initialized = dist.is_initialized() 18 | else: 19 | initialized = False 20 | if initialized: 21 | rank = dist.get_rank() 22 | world_size = dist.get_world_size() 23 | else: 24 | rank = 0 25 | world_size = 1 26 | return rank, world_size 27 | 28 | 29 | if __name__ == '__main__': 30 | now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument("--module", type=str, help="module name", default="inference") 33 | parser.add_argument("--local_rank", type=int, nargs="?", help="for ddp", default=0) 34 | args, unknown = parser.parse_known_args() 35 | inference_api = importlib.import_module(args.module, package=None) 36 | 37 | inference_parser = inference_api.get_parser() 38 | inference_args, unknown = inference_parser.parse_known_args() 39 | 40 | seed_everything(inference_args.seed) 41 | setup_dist(args.local_rank) 42 | torch.backends.cudnn.benchmark = True 43 | rank, gpu_num = get_dist_info() 44 | 45 | # inference_args.savedir = inference_args.savedir+str('_seed')+str(inference_args.seed) 46 | print("@DynamiCrafter Inference [rank%d]: %s"%(rank, now)) 47 | inference_api.run_inference(inference_args, gpu_num, rank) -------------------------------------------------------------------------------- /scripts/evaluation/funcs.py: -------------------------------------------------------------------------------- 1 | import os, sys, glob 2 | import numpy as np 3 | from collections import OrderedDict 4 | from decord import VideoReader, cpu 5 | import cv2 6 | 7 | import torch 8 | import torchvision 9 | sys.path.insert(1, os.path.join(sys.path[0], '..', '..')) 10 | from ...lvdm.models.samplers.ddim import DDIMSampler 11 | from einops import rearrange 12 | 13 | 14 | def batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=50, ddim_eta=1.0,\ 15 | cfg_scale=1.0, temporal_cfg_scale=None, **kwargs): 16 | ddim_sampler = DDIMSampler(model) 17 | uncond_type = model.uncond_type 18 | batch_size = noise_shape[0] 19 | fs = cond["fs"] 20 | del cond["fs"] 21 | if noise_shape[-1] == 32: 22 | timestep_spacing = "uniform" 23 | guidance_rescale = 0.0 24 | else: 25 | timestep_spacing = "uniform_trailing" 26 | guidance_rescale = 0.7 27 | ## construct unconditional guidance 28 | if cfg_scale != 1.0: 29 | if uncond_type == "empty_seq": 30 | prompts = batch_size * [""] 31 | #prompts = N * T * [""] ## if is_imgbatch=True 32 | uc_emb = model.get_learned_conditioning(prompts) 33 | elif uncond_type == "zero_embed": 34 | c_emb = cond["c_crossattn"][0] if isinstance(cond, dict) else cond 35 | uc_emb = torch.zeros_like(c_emb) 36 | 37 | ## process image embedding token 38 | if hasattr(model, 'embedder'): 39 | uc_img = torch.zeros(noise_shape[0],3,224,224).to(model.device) 40 | ## img: b c h w >> b l c 41 | uc_img = model.embedder(uc_img) 42 | uc_img = model.image_proj_model(uc_img) 43 | uc_emb = torch.cat([uc_emb, uc_img], dim=1) 44 | 45 | if isinstance(cond, dict): 46 | uc = {key:cond[key] for key in cond.keys()} 47 | uc.update({'c_crossattn': [uc_emb]}) 48 | else: 49 | uc 
= uc_emb 50 | else: 51 | uc = None 52 | 53 | x_T = None 54 | batch_variants = [] 55 | 56 | for _ in range(n_samples): 57 | if ddim_sampler is not None: 58 | kwargs.update({"clean_cond": True}) 59 | samples, _ = ddim_sampler.sample(S=ddim_steps, 60 | conditioning=cond, 61 | batch_size=noise_shape[0], 62 | shape=noise_shape[1:], 63 | verbose=False, 64 | unconditional_guidance_scale=cfg_scale, 65 | unconditional_conditioning=uc, 66 | eta=ddim_eta, 67 | temporal_length=noise_shape[2], 68 | conditional_guidance_scale_temporal=temporal_cfg_scale, 69 | x_T=x_T, 70 | fs=fs, 71 | timestep_spacing=timestep_spacing, 72 | guidance_rescale=guidance_rescale, 73 | **kwargs 74 | ) 75 | ## reconstruct from latent to pixel space 76 | batch_images = model.decode_first_stage(samples) 77 | batch_variants.append(batch_images) 78 | ## batch, , c, t, h, w 79 | batch_variants = torch.stack(batch_variants, dim=1) 80 | return batch_variants 81 | 82 | 83 | def get_filelist(data_dir, ext='*'): 84 | file_list = glob.glob(os.path.join(data_dir, '*.%s'%ext)) 85 | file_list.sort() 86 | return file_list 87 | 88 | def get_dirlist(path): 89 | list = [] 90 | if (os.path.exists(path)): 91 | files = os.listdir(path) 92 | for file in files: 93 | m = os.path.join(path,file) 94 | if (os.path.isdir(m)): 95 | list.append(m) 96 | list.sort() 97 | return list 98 | 99 | 100 | def load_model_checkpoint(model, ckpt): 101 | def load_checkpoint(model, ckpt, full_strict): 102 | state_dict = torch.load(ckpt, map_location="cpu") 103 | if "state_dict" in list(state_dict.keys()): 104 | state_dict = state_dict["state_dict"] 105 | try: 106 | model.load_state_dict(state_dict, strict=full_strict) 107 | except: 108 | ## rename the keys for 256x256 model 109 | new_pl_sd = OrderedDict() 110 | for k,v in state_dict.items(): 111 | new_pl_sd[k] = v 112 | 113 | for k in list(new_pl_sd.keys()): 114 | if "framestride_embed" in k: 115 | new_key = k.replace("framestride_embed", "fps_embedding") 116 | new_pl_sd[new_key] = new_pl_sd[k] 117 | del new_pl_sd[k] 118 | model.load_state_dict(new_pl_sd, strict=full_strict) 119 | else: 120 | ## deepspeed 121 | new_pl_sd = OrderedDict() 122 | for key in state_dict['module'].keys(): 123 | new_pl_sd[key[16:]]=state_dict['module'][key] 124 | model.load_state_dict(new_pl_sd, strict=full_strict) 125 | 126 | return model 127 | load_checkpoint(model, ckpt, full_strict=True) 128 | print('>>> model checkpoint loaded.') 129 | return model 130 | 131 | 132 | def load_prompts(prompt_file): 133 | f = open(prompt_file, 'r') 134 | prompt_list = [] 135 | for idx, line in enumerate(f.readlines()): 136 | l = line.strip() 137 | if len(l) != 0: 138 | prompt_list.append(l) 139 | f.close() 140 | return prompt_list 141 | 142 | 143 | def load_video_batch(filepath_list, frame_stride, video_size=(256,256), video_frames=16): 144 | ''' 145 | Notice about some special cases: 146 | 1. video_frames=-1 means to take all the frames (with fs=1) 147 | 2. when the total video frames is less than required, padding strategy will be used (repreated last frame) 148 | ''' 149 | fps_list = [] 150 | batch_tensor = [] 151 | assert frame_stride > 0, "valid frame stride should be a positive interge!" 
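    ## A worked example of the sampling logic in the loop below (numbers are hypothetical):
    ## a clip with total_frames=30 read at frame_stride=4 gives
    ## max_valid_frames = (30 - 1) // 4 + 1 = 8, so for video_frames=16 only the frames at
    ## indices 0, 4, ..., 28 are fetched and the last one is repeated 8 more times as padding;
    ## pixels are rescaled from [0, 255] to [-1, 1] and the effective fps becomes
    ## int(fps / frame_stride).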
152 | for filepath in filepath_list: 153 | padding_num = 0 154 | vidreader = VideoReader(filepath, ctx=cpu(0), width=video_size[1], height=video_size[0]) 155 | fps = vidreader.get_avg_fps() 156 | total_frames = len(vidreader) 157 | max_valid_frames = (total_frames-1) // frame_stride + 1 158 | if video_frames < 0: 159 | ## all frames are collected: fs=1 is a must 160 | required_frames = total_frames 161 | frame_stride = 1 162 | else: 163 | required_frames = video_frames 164 | query_frames = min(required_frames, max_valid_frames) 165 | frame_indices = [frame_stride*i for i in range(query_frames)] 166 | 167 | ## [t,h,w,c] -> [c,t,h,w] 168 | frames = vidreader.get_batch(frame_indices) 169 | frame_tensor = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float() 170 | frame_tensor = (frame_tensor / 255. - 0.5) * 2 171 | if max_valid_frames < required_frames: 172 | padding_num = required_frames - max_valid_frames 173 | frame_tensor = torch.cat([frame_tensor, *([frame_tensor[:,-1:,:,:]]*padding_num)], dim=1) 174 | print(f'{os.path.split(filepath)[1]} is not long enough: {padding_num} frames padded.') 175 | batch_tensor.append(frame_tensor) 176 | sample_fps = int(fps/frame_stride) 177 | fps_list.append(sample_fps) 178 | 179 | return torch.stack(batch_tensor, dim=0) 180 | 181 | from PIL import Image 182 | def load_image_batch(filepath_list, image_size=(256,256)): 183 | batch_tensor = [] 184 | for filepath in filepath_list: 185 | _, filename = os.path.split(filepath) 186 | _, ext = os.path.splitext(filename) 187 | if ext == '.mp4': 188 | vidreader = VideoReader(filepath, ctx=cpu(0), width=image_size[1], height=image_size[0]) 189 | frame = vidreader.get_batch([0]) 190 | img_tensor = torch.tensor(frame.asnumpy()).squeeze(0).permute(2, 0, 1).float() 191 | elif ext == '.png' or ext == '.jpg': 192 | img = Image.open(filepath).convert("RGB") 193 | rgb_img = np.array(img, np.float32) 194 | #bgr_img = cv2.imread(filepath, cv2.IMREAD_COLOR) 195 | #bgr_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB) 196 | rgb_img = cv2.resize(rgb_img, (image_size[1],image_size[0]), interpolation=cv2.INTER_LINEAR) 197 | img_tensor = torch.from_numpy(rgb_img).permute(2, 0, 1).float() 198 | else: 199 | print(f'ERROR: <{ext}> image loading only support format: [mp4], [png], [jpg]') 200 | raise NotImplementedError 201 | img_tensor = (img_tensor / 255. - 0.5) * 2 202 | batch_tensor.append(img_tensor) 203 | return torch.stack(batch_tensor, dim=0) 204 | 205 | 206 | def save_videos(batch_tensors, savedir, filenames, fps=10): 207 | # b,samples,c,t,h,w 208 | n_samples = batch_tensors.shape[1] 209 | for idx, vid_tensor in enumerate(batch_tensors): 210 | video = vid_tensor.detach().cpu() 211 | video = torch.clamp(video.float(), -1., 1.) 
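        ## Layout of the steps below: vid_tensor is [n_samples, c, t, h, w]; the permute gives
        ## [t, n, c, h, w], make_grid tiles the n samples side by side for every frame, and the
        ## uint8 grid is converted back to float frames in [0, 1]. Note that the mp4 writer is
        ## commented out here, so the function returns the frames as a stacked tensor of shape
        ## [1, t, grid_h, grid_w, 3] rather than writing video files.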
212 | video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w 213 | frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n_samples)) for framesheet in video] #[3, 1*h, n*w] 214 | grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [t, 3, n*h, w] 215 | grid = (grid + 1.0) / 2.0 216 | grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) 217 | #savepath = os.path.join(savedir, f"{filenames[idx]}.mp4") 218 | #torchvision.io.write_video(savepath, grid, fps=fps, video_codec='h264', options={'crf': '10'}) 219 | outframes=[] 220 | for i in range(grid.shape[0]): 221 | img = grid[i].numpy() 222 | image=Image.fromarray(img) 223 | image_tensor_out = torch.tensor(np.array(image).astype(np.float32) / 255.0) # Convert back to CxHxW 224 | image_tensor_out = torch.unsqueeze(image_tensor_out, 0) 225 | outframes.append(image_tensor_out) 226 | 227 | return torch.cat(tuple(outframes), dim=0).unsqueeze(0) 228 | 229 | 230 | def get_latent_z(model, videos): 231 | b, c, t, h, w = videos.shape 232 | x = rearrange(videos, 'b c t h w -> (b t) c h w') 233 | z = model.encode_first_stage(x) 234 | z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t) 235 | return z -------------------------------------------------------------------------------- /scripts/evaluation/inference.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, glob 2 | import datetime, time 3 | from omegaconf import OmegaConf 4 | from tqdm import tqdm 5 | from einops import rearrange, repeat 6 | from collections import OrderedDict 7 | 8 | import torch 9 | import torchvision 10 | import torchvision.transforms as transforms 11 | from pytorch_lightning import seed_everything 12 | from PIL import Image 13 | sys.path.insert(1, os.path.join(sys.path[0], '..', '..')) 14 | from ...lvdm.models.samplers.ddim import DDIMSampler 15 | from ...lvdm.models.samplers.ddim_multiplecond import DDIMSampler as DDIMSampler_multicond 16 | from ...utils.utils import instantiate_from_config 17 | 18 | 19 | def get_filelist(data_dir, postfixes): 20 | patterns = [os.path.join(data_dir, f"*.{postfix}") for postfix in postfixes] 21 | file_list = [] 22 | for pattern in patterns: 23 | file_list.extend(glob.glob(pattern)) 24 | file_list.sort() 25 | return file_list 26 | 27 | def load_model_checkpoint(model, ckpt): 28 | state_dict = torch.load(ckpt, map_location="cpu") 29 | if "state_dict" in list(state_dict.keys()): 30 | state_dict = state_dict["state_dict"] 31 | try: 32 | model.load_state_dict(state_dict, strict=True) 33 | except: 34 | ## rename the keys for 256x256 model 35 | new_pl_sd = OrderedDict() 36 | for k,v in state_dict.items(): 37 | new_pl_sd[k] = v 38 | 39 | for k in list(new_pl_sd.keys()): 40 | if "framestride_embed" in k: 41 | new_key = k.replace("framestride_embed", "fps_embedding") 42 | new_pl_sd[new_key] = new_pl_sd[k] 43 | del new_pl_sd[k] 44 | model.load_state_dict(new_pl_sd, strict=True) 45 | else: 46 | # deepspeed 47 | new_pl_sd = OrderedDict() 48 | for key in state_dict['module'].keys(): 49 | new_pl_sd[key[16:]]=state_dict['module'][key] 50 | model.load_state_dict(new_pl_sd) 51 | print('>>> model checkpoint loaded.') 52 | return model 53 | 54 | def load_prompts(prompt_file): 55 | f = open(prompt_file, 'r') 56 | prompt_list = [] 57 | for idx, line in enumerate(f.readlines()): 58 | l = line.strip() 59 | if len(l) != 0: 60 | prompt_list.append(l) 61 | f.close() 62 | return prompt_list 63 | 64 | def load_data_prompts(data_dir, video_size=(256,256), video_frames=16, interp=False): 65 | 
transform = transforms.Compose([ 66 | transforms.Resize(min(video_size)), 67 | transforms.CenterCrop(video_size), 68 | transforms.ToTensor(), 69 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]) 70 | ## load prompts 71 | prompt_file = get_filelist(data_dir, ['txt']) 72 | assert len(prompt_file) > 0, "Error: found NO prompt file!" 73 | ###### default prompt 74 | default_idx = 0 75 | default_idx = min(default_idx, len(prompt_file)-1) 76 | if len(prompt_file) > 1: 77 | print(f"Warning: multiple prompt files exist. The one {os.path.split(prompt_file[default_idx])[1]} is used.") 78 | ## only use the first one (sorted by name) if multiple exist 79 | 80 | ## load video 81 | file_list = get_filelist(data_dir, ['jpg', 'png', 'jpeg', 'JPEG', 'PNG']) 82 | # assert len(file_list) == n_samples, "Error: data and prompts are NOT paired!" 83 | data_list = [] 84 | filename_list = [] 85 | prompt_list = load_prompts(prompt_file[default_idx]) 86 | n_samples = len(prompt_list) 87 | for idx in range(n_samples): 88 | if interp: 89 | image1 = Image.open(file_list[2*idx]).convert('RGB') 90 | image_tensor1 = transform(image1).unsqueeze(1) # [c,1,h,w] 91 | image2 = Image.open(file_list[2*idx+1]).convert('RGB') 92 | image_tensor2 = transform(image2).unsqueeze(1) # [c,1,h,w] 93 | frame_tensor1 = repeat(image_tensor1, 'c t h w -> c (repeat t) h w', repeat=video_frames//2) 94 | frame_tensor2 = repeat(image_tensor2, 'c t h w -> c (repeat t) h w', repeat=video_frames//2) 95 | frame_tensor = torch.cat([frame_tensor1, frame_tensor2], dim=1) 96 | _, filename = os.path.split(file_list[idx*2]) 97 | else: 98 | image = Image.open(file_list[idx]).convert('RGB') 99 | image_tensor = transform(image).unsqueeze(1) # [c,1,h,w] 100 | frame_tensor = repeat(image_tensor, 'c t h w -> c (repeat t) h w', repeat=video_frames) 101 | _, filename = os.path.split(file_list[idx]) 102 | 103 | data_list.append(frame_tensor) 104 | filename_list.append(filename) 105 | 106 | return filename_list, data_list, prompt_list 107 | 108 | 109 | def save_results(prompt, samples, filename, fakedir, fps=8, loop=False): 110 | filename = filename.split('.')[0]+'.mp4' 111 | prompt = prompt[0] if isinstance(prompt, list) else prompt 112 | 113 | ## save video 114 | videos = [samples] 115 | savedirs = [fakedir] 116 | for idx, video in enumerate(videos): 117 | if video is None: 118 | continue 119 | # b,c,t,h,w 120 | video = video.detach().cpu() 121 | video = torch.clamp(video.float(), -1., 1.) 122 | n = video.shape[0] 123 | video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w 124 | if loop: 125 | video = video[:-1,...] 126 | 127 | frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0) for framesheet in video] #[3, 1*h, n*w] 128 | grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [t, 3, h, n*w] 129 | grid = (grid + 1.0) / 2.0 130 | grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) 131 | path = os.path.join(savedirs[idx], filename) 132 | torchvision.io.write_video(path, grid, fps=fps, video_codec='h264', options={'crf': '10'}) ## crf indicates the quality 133 | 134 | 135 | def save_results_seperate(prompt, samples, filename, fakedir, fps=10, loop=False): 136 | prompt = prompt[0] if isinstance(prompt, list) else prompt 137 | 138 | ## save video 139 | videos = [samples] 140 | savedirs = [fakedir] 141 | for idx, video in enumerate(videos): 142 | if video is None: 143 | continue 144 | # b,c,t,h,w 145 | video = video.detach().cpu() 146 | if loop: # remove the last frame 147 | video = video[:,:,:-1,...] 
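        ## What follows: each of the n samples in `video` ([n, c, t, h, w]) is clamped to
        ## [-1, 1], rescaled to uint8 [0, 255], reordered to t,h,w,c and written as its own
        ## h264 mp4 (crf 10) under the samples_separate directory; the `loop` branch above
        ## drops the final frame, presumably so a looping clip does not duplicate its first frame.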
148 | video = torch.clamp(video.float(), -1., 1.) 149 | n = video.shape[0] 150 | for i in range(n): 151 | grid = video[i,...] 152 | grid = (grid + 1.0) / 2.0 153 | grid = (grid * 255).to(torch.uint8).permute(1, 2, 3, 0) #thwc 154 | path = os.path.join(savedirs[idx].replace('samples', 'samples_separate'), f'{filename.split(".")[0]}_sample{i}.mp4') 155 | torchvision.io.write_video(path, grid, fps=fps, video_codec='h264', options={'crf': '10'}) 156 | 157 | def get_latent_z(model, videos): 158 | b, c, t, h, w = videos.shape 159 | x = rearrange(videos, 'b c t h w -> (b t) c h w') 160 | z = model.encode_first_stage(x) 161 | z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t) 162 | return z 163 | 164 | 165 | def image_guided_synthesis(model, prompts, videos, noise_shape, n_samples=1, ddim_steps=50, ddim_eta=1., \ 166 | unconditional_guidance_scale=1.0, cfg_img=None, fs=None, text_input=False, multiple_cond_cfg=False, loop=False, interp=False, timestep_spacing='uniform', guidance_rescale=0.0, **kwargs): 167 | ddim_sampler = DDIMSampler(model) if not multiple_cond_cfg else DDIMSampler_multicond(model) 168 | batch_size = noise_shape[0] 169 | fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device) 170 | 171 | if not text_input: 172 | prompts = [""]*batch_size 173 | 174 | img = videos[:,:,0] #bchw 175 | img_emb = model.embedder(img) ## blc 176 | img_emb = model.image_proj_model(img_emb) 177 | 178 | cond_emb = model.get_learned_conditioning(prompts) 179 | cond = {"c_crossattn": [torch.cat([cond_emb,img_emb], dim=1)]} 180 | if model.model.conditioning_key == 'hybrid': 181 | z = get_latent_z(model, videos) # b c t h w 182 | if loop or interp: 183 | img_cat_cond = torch.zeros_like(z) 184 | img_cat_cond[:,:,0,:,:] = z[:,:,0,:,:] 185 | img_cat_cond[:,:,-1,:,:] = z[:,:,-1,:,:] 186 | else: 187 | img_cat_cond = z[:,:,:1,:,:] 188 | img_cat_cond = repeat(img_cat_cond, 'b c t h w -> b c (repeat t) h w', repeat=z.shape[2]) 189 | cond["c_concat"] = [img_cat_cond] # b c 1 h w 190 | 191 | if unconditional_guidance_scale != 1.0: 192 | if model.uncond_type == "empty_seq": 193 | prompts = batch_size * [""] 194 | uc_emb = model.get_learned_conditioning(prompts) 195 | elif model.uncond_type == "zero_embed": 196 | uc_emb = torch.zeros_like(cond_emb) 197 | uc_img_emb = model.embedder(torch.zeros_like(img)) ## b l c 198 | uc_img_emb = model.image_proj_model(uc_img_emb) 199 | uc = {"c_crossattn": [torch.cat([uc_emb,uc_img_emb],dim=1)]} 200 | if model.model.conditioning_key == 'hybrid': 201 | uc["c_concat"] = [img_cat_cond] 202 | else: 203 | uc = None 204 | 205 | ## we need one more unconditioning image=yes, text="" 206 | if multiple_cond_cfg and cfg_img != 1.0: 207 | uc_2 = {"c_crossattn": [torch.cat([uc_emb,img_emb],dim=1)]} 208 | if model.model.conditioning_key == 'hybrid': 209 | uc_2["c_concat"] = [img_cat_cond] 210 | kwargs.update({"unconditional_conditioning_img_nonetext": uc_2}) 211 | else: 212 | kwargs.update({"unconditional_conditioning_img_nonetext": None}) 213 | 214 | z0 = None 215 | cond_mask = None 216 | 217 | batch_variants = [] 218 | for _ in range(n_samples): 219 | 220 | if z0 is not None: 221 | cond_z0 = z0.clone() 222 | kwargs.update({"clean_cond": True}) 223 | else: 224 | cond_z0 = None 225 | if ddim_sampler is not None: 226 | 227 | samples, _ = ddim_sampler.sample(S=ddim_steps, 228 | conditioning=cond, 229 | batch_size=batch_size, 230 | shape=noise_shape[1:], 231 | verbose=False, 232 | unconditional_guidance_scale=unconditional_guidance_scale, 233 | unconditional_conditioning=uc, 234 
| eta=ddim_eta, 235 | cfg_img=cfg_img, 236 | mask=cond_mask, 237 | x0=cond_z0, 238 | fs=fs, 239 | timestep_spacing=timestep_spacing, 240 | guidance_rescale=guidance_rescale, 241 | **kwargs 242 | ) 243 | 244 | ## reconstruct from latent to pixel space 245 | batch_images = model.decode_first_stage(samples) 246 | batch_variants.append(batch_images) 247 | ## variants, batch, c, t, h, w 248 | batch_variants = torch.stack(batch_variants) 249 | return batch_variants.permute(1, 0, 2, 3, 4, 5) 250 | 251 | 252 | def run_inference(args, gpu_num, gpu_no): 253 | ## model config 254 | config = OmegaConf.load(args.config) 255 | model_config = config.pop("model", OmegaConf.create()) 256 | 257 | ## set use_checkpoint as False as when using deepspeed, it encounters an error "deepspeed backend not set" 258 | model_config['params']['unet_config']['params']['use_checkpoint'] = False 259 | model = instantiate_from_config(model_config) 260 | model = model.cuda(gpu_no) 261 | model.perframe_ae = args.perframe_ae 262 | assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!" 263 | model = load_model_checkpoint(model, args.ckpt_path) 264 | model.eval() 265 | 266 | ## run over data 267 | assert (args.height % 16 == 0) and (args.width % 16 == 0), "Error: image size [h,w] should be multiples of 16!" 268 | assert args.bs == 1, "Current implementation only support [batch size = 1]!" 269 | ## latent noise shape 270 | h, w = args.height // 8, args.width // 8 271 | channels = model.model.diffusion_model.out_channels 272 | n_frames = args.video_length 273 | print(f'Inference with {n_frames} frames') 274 | noise_shape = [args.bs, channels, n_frames, h, w] 275 | 276 | fakedir = os.path.join(args.savedir, "samples") 277 | fakedir_separate = os.path.join(args.savedir, "samples_separate") 278 | 279 | # os.makedirs(fakedir, exist_ok=True) 280 | os.makedirs(fakedir_separate, exist_ok=True) 281 | 282 | ## prompt file setting 283 | assert os.path.exists(args.prompt_dir), "Error: prompt file Not Found!" 
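    ## Sharding note for the block below: prompts are split evenly across ranks, with rank
    ## gpu_no taking indices [samples_split*gpu_no, samples_split*(gpu_no+1)). Hypothetical
    ## example: 10 prompts on 4 GPUs gives samples_split = 2, ranks cover indices 0..7 and the
    ## trailing 2 prompts are skipped; with the single-process default (gpu_num=1, gpu_no=0)
    ## every prompt is processed.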
284 | filename_list, data_list, prompt_list = load_data_prompts(args.prompt_dir, video_size=(args.height, args.width), video_frames=n_frames, interp=args.interp) 285 | num_samples = len(prompt_list) 286 | samples_split = num_samples // gpu_num 287 | print('Prompts testing [rank:%d] %d/%d samples loaded.'%(gpu_no, samples_split, num_samples)) 288 | #indices = random.choices(list(range(0, num_samples)), k=samples_per_device) 289 | indices = list(range(samples_split*gpu_no, samples_split*(gpu_no+1))) 290 | prompt_list_rank = [prompt_list[i] for i in indices] 291 | data_list_rank = [data_list[i] for i in indices] 292 | filename_list_rank = [filename_list[i] for i in indices] 293 | 294 | start = time.time() 295 | with torch.no_grad(), torch.cuda.amp.autocast(): 296 | for idx, indice in tqdm(enumerate(range(0, len(prompt_list_rank), args.bs)), desc='Sample Batch'): 297 | prompts = prompt_list_rank[indice:indice+args.bs] 298 | videos = data_list_rank[indice:indice+args.bs] 299 | filenames = filename_list_rank[indice:indice+args.bs] 300 | if isinstance(videos, list): 301 | videos = torch.stack(videos, dim=0).to("cuda") 302 | else: 303 | videos = videos.unsqueeze(0).to("cuda") 304 | 305 | batch_samples = image_guided_synthesis(model, prompts, videos, noise_shape, args.n_samples, args.ddim_steps, args.ddim_eta, \ 306 | args.unconditional_guidance_scale, args.cfg_img, args.frame_stride, args.text_input, args.multiple_cond_cfg, args.loop, args.interp, args.timestep_spacing, args.guidance_rescale) 307 | 308 | ## save each example individually 309 | for nn, samples in enumerate(batch_samples): 310 | ## samples : [n_samples,c,t,h,w] 311 | prompt = prompts[nn] 312 | filename = filenames[nn] 313 | # save_results(prompt, samples, filename, fakedir, fps=8, loop=args.loop) 314 | save_results_seperate(prompt, samples, filename, fakedir, fps=8, loop=args.loop) 315 | 316 | print(f"Saved in {args.savedir}. 
Time used: {(time.time() - start):.2f} seconds") 317 | 318 | 319 | def get_parser(): 320 | parser = argparse.ArgumentParser() 321 | parser.add_argument("--savedir", type=str, default=None, help="results saving path") 322 | parser.add_argument("--ckpt_path", type=str, default=None, help="checkpoint path") 323 | parser.add_argument("--config", type=str, help="config (yaml) path") 324 | parser.add_argument("--prompt_dir", type=str, default=None, help="a data dir containing videos and prompts") 325 | parser.add_argument("--n_samples", type=int, default=1, help="num of samples per prompt",) 326 | parser.add_argument("--ddim_steps", type=int, default=50, help="steps of ddim if positive, otherwise use DDPM",) 327 | parser.add_argument("--ddim_eta", type=float, default=1.0, help="eta for ddim sampling (0.0 yields deterministic sampling)",) 328 | parser.add_argument("--bs", type=int, default=1, help="batch size for inference, should be one") 329 | parser.add_argument("--height", type=int, default=512, help="image height, in pixel space") 330 | parser.add_argument("--width", type=int, default=512, help="image width, in pixel space") 331 | parser.add_argument("--frame_stride", type=int, default=3, help="frame stride control for 256 model (larger->larger motion), FPS control for 512 or 1024 model (smaller->larger motion)") 332 | parser.add_argument("--unconditional_guidance_scale", type=float, default=1.0, help="prompt classifier-free guidance") 333 | parser.add_argument("--seed", type=int, default=123, help="seed for seed_everything") 334 | parser.add_argument("--video_length", type=int, default=16, help="inference video length") 335 | parser.add_argument("--negative_prompt", action='store_true', default=False, help="negative prompt") 336 | parser.add_argument("--text_input", action='store_true', default=False, help="input text to I2V model or not") 337 | parser.add_argument("--multiple_cond_cfg", action='store_true', default=False, help="use multi-condition cfg or not") 338 | parser.add_argument("--cfg_img", type=float, default=None, help="guidance scale for image conditioning") 339 | parser.add_argument("--timestep_spacing", type=str, default="uniform", help="The way the timesteps should be scaled. 
Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.") 340 | parser.add_argument("--guidance_rescale", type=float, default=0.0, help="guidance rescale in [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891)") 341 | parser.add_argument("--perframe_ae", action='store_true', default=False, help="if we use per-frame AE decoding, set it to True to save GPU memory, especially for the model of 576x1024") 342 | 343 | ## currently not support looping video and generative frame interpolation 344 | parser.add_argument("--loop", action='store_true', default=False, help="generate looping videos or not") 345 | parser.add_argument("--interp", action='store_true', default=False, help="generate generative frame interpolation or not") 346 | return parser 347 | 348 | 349 | if __name__ == '__main__': 350 | now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") 351 | print("@DynamiCrafter cond-Inference: %s"%now) 352 | parser = get_parser() 353 | args = parser.parse_args() 354 | 355 | seed_everything(args.seed) 356 | rank, gpu_num = 0, 1 357 | run_inference(args, gpu_num, rank) -------------------------------------------------------------------------------- /scripts/gradio/i2v_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from omegaconf import OmegaConf 4 | import torch 5 | from ..evaluation.funcs import load_model_checkpoint, save_videos, batch_ddim_sampling, get_latent_z 6 | from ...utils.utils import instantiate_from_config 7 | from huggingface_hub import hf_hub_download 8 | from einops import repeat 9 | import torchvision.transforms as transforms 10 | from pytorch_lightning import seed_everything 11 | import folder_paths 12 | 13 | comfy_path = os.path.dirname(folder_paths.__file__) 14 | models_path=f'{comfy_path}/models/' 15 | config_path=f'{comfy_path}/custom_nodes/ComfyUI-DynamiCrafter/' 16 | 17 | class Image2Video(): 18 | def __init__(self,result_dir='./tmp/',gpu_num=1,resolution='256_256',frame_length=16) -> None: 19 | self.resolution = (int(resolution.split('_')[0]), int(resolution.split('_')[1])) #hw 20 | self.download_model() 21 | 22 | self.result_dir = result_dir 23 | if not os.path.exists(self.result_dir): 24 | os.mkdir(self.result_dir) 25 | ckpt_path=models_path+'checkpoints/dynamicrafter_'+resolution.split('_')[1]+'_v1/model.ckpt' 26 | config_file=config_path+'configs/inference_'+resolution.split('_')[1]+'_v1.0.yaml' 27 | config = OmegaConf.load(config_file) 28 | OmegaConf.update(config, "model.params.unet_config.params.temporal_length", frame_length) 29 | model_config = config.pop("model", OmegaConf.create()) 30 | model_config['params']['unet_config']['params']['use_checkpoint']=False 31 | model_list = [] 32 | for gpu_id in range(gpu_num): 33 | model = instantiate_from_config(model_config) 34 | # model = model.cuda(gpu_id) 35 | assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!" 
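# (hedged note, not part of the upstream file) the assert above relies on
# download_model() having already fetched model.ckpt into
# models/checkpoints/dynamicrafter_<width>_v1/; a manual fallback along the same
# lines as download_model() further below, assuming its repo naming, could look like:
#   hf_hub_download(repo_id='Doubiiu/DynamiCrafter_1024', filename='model.ckpt',
#                   local_dir=os.path.dirname(ckpt_path), local_dir_use_symlinks=False)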
36 | model = load_model_checkpoint(model, ckpt_path) 37 | model.eval() 38 | model_list.append(model) 39 | self.model_list = model_list 40 | self.save_fps = 8 41 | 42 | def get_image(self, image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123): 43 | seed_everything(seed) 44 | transform = transforms.Compose([ 45 | transforms.Resize(min(self.resolution)), 46 | transforms.CenterCrop(self.resolution), 47 | ]) 48 | torch.cuda.empty_cache() 49 | print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))) 50 | start = time.time() 51 | gpu_id=0 52 | if steps > 60: 53 | steps = 60 54 | model = self.model_list[gpu_id] 55 | model = model.cuda() 56 | batch_size=1 57 | channels = model.model.diffusion_model.out_channels 58 | frames = model.temporal_length 59 | h, w = self.resolution[0] // 8, self.resolution[1] // 8 60 | noise_shape = [batch_size, channels, frames, h, w] 61 | 62 | # text cond 63 | text_emb = model.get_learned_conditioning([prompt]) 64 | 65 | # img cond 66 | img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().to(model.device) 67 | img_tensor = (img_tensor / 255. - 0.5) * 2 68 | 69 | image_tensor_resized = transform(img_tensor) #3,h,w 70 | videos = image_tensor_resized.unsqueeze(0) # bchw 71 | 72 | z = get_latent_z(model, videos.unsqueeze(2)) #bc,1,hw 73 | 74 | img_tensor_repeat = repeat(z, 'b c t h w -> b c (repeat t) h w', repeat=frames) 75 | 76 | cond_images = model.embedder(img_tensor.unsqueeze(0)) ## blc 77 | img_emb = model.image_proj_model(cond_images) 78 | 79 | imtext_cond = torch.cat([text_emb, img_emb], dim=1) 80 | 81 | fs = torch.tensor([fs], dtype=torch.long, device=model.device) 82 | cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat]} 83 | 84 | ## inference 85 | batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale) 86 | ## b,samples,c,t,h,w 87 | prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt 88 | prompt_str = prompt_str.replace(" ", "_") if " " in prompt else prompt_str 89 | prompt_str=prompt_str[:40] 90 | if len(prompt_str) == 0: 91 | prompt_str = 'empty_prompt' 92 | 93 | imgs=save_videos(batch_samples, self.result_dir, filenames=[prompt_str], fps=self.save_fps) 94 | print(f"Saved in {prompt_str}. 
Time used: {(time.time() - start):.2f} seconds") 95 | model = model.cpu() 96 | #return os.path.join(self.result_dir, f"{prompt_str}.mp4") 97 | return imgs 98 | 99 | def download_model(self): 100 | REPO_ID = 'Doubiiu/DynamiCrafter_'+str(self.resolution[1]) if self.resolution[1]!=256 else 'Doubiiu/DynamiCrafter' 101 | filename_list = ['model.ckpt'] 102 | if not os.path.exists(models_path+'checkpoints/dynamicrafter_'+str(self.resolution[1])+'_v1/'): 103 | os.makedirs(models_path+'checkpoints/dynamicrafter_'+str(self.resolution[1])+'_v1/') 104 | for filename in filename_list: 105 | local_file = os.path.join(models_path+'checkpoints/dynamicrafter_'+str(self.resolution[1])+'_v1/', filename) 106 | if not os.path.exists(local_file): 107 | hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir=models_path+'checkpoints/dynamicrafter_'+str(self.resolution[1])+'_v1/', local_dir_use_symlinks=False) 108 | 109 | if __name__ == '__main__': 110 | i2v = Image2Video() 111 | video_path = i2v.get_image('prompts/art.png','man fishing in a boat at sunset') 112 | print('done', video_path) -------------------------------------------------------------------------------- /scripts/gradio/i2v_test_application.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from omegaconf import OmegaConf 4 | import torch 5 | from ..evaluation.funcs import load_model_checkpoint, save_videos, batch_ddim_sampling, get_latent_z 6 | from ...utils.utils import instantiate_from_config 7 | from huggingface_hub import hf_hub_download 8 | from einops import repeat 9 | import torchvision.transforms as transforms 10 | from pytorch_lightning import seed_everything 11 | import folder_paths 12 | 13 | comfy_path = os.path.dirname(folder_paths.__file__) 14 | models_path=f'{comfy_path}/models/' 15 | config_path=f'{comfy_path}/custom_nodes/ComfyUI-DynamiCrafter/' 16 | 17 | class Image2Video(): 18 | def __init__(self,result_dir='./tmp/',gpu_num=1,resolution='256_256',frame_length=16) -> None: 19 | self.resolution = (int(resolution.split('_')[0]), int(resolution.split('_')[1])) #hw 20 | self.download_model() 21 | 22 | self.result_dir = result_dir 23 | if not os.path.exists(self.result_dir): 24 | os.mkdir(self.result_dir) 25 | ckpt_path=models_path+'checkpoints/dynamicrafter_'+resolution.split('_')[1]+'_interp_v1/model.ckpt' 26 | config_file=config_path+'configs/inference_'+resolution.split('_')[1]+'_v1.0.yaml' 27 | config = OmegaConf.load(config_file) 28 | OmegaConf.update(config, "model.params.unet_config.params.temporal_length", frame_length) 29 | model_config = config.pop("model", OmegaConf.create()) 30 | model_config['params']['unet_config']['params']['use_checkpoint']=False 31 | model_list = [] 32 | for gpu_id in range(gpu_num): 33 | model = instantiate_from_config(model_config) 34 | # model = model.cuda(gpu_id) 35 | assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!" 
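# (hedged note, not part of the upstream file) unlike i2v_test.py, ckpt_path here
# points at the *_interp_v1 directory and download_model() below pulls
# Doubiiu/DynamiCrafter_<width>_Interp (the 512 interpolation weights used by the
# shipped wf-interp.json workflow); get_image() later anchors the start image (and
# the optional end image) at the first and last latent frames before DDIM sampling.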
36 | model = load_model_checkpoint(model, ckpt_path) 37 | model.eval() 38 | model_list.append(model) 39 | self.model_list = model_list 40 | self.save_fps = 8 41 | 42 | def get_image(self, image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, image2=None): 43 | seed_everything(seed) 44 | transform = transforms.Compose([ 45 | transforms.Resize(min(self.resolution)), 46 | transforms.CenterCrop(self.resolution), 47 | ]) 48 | torch.cuda.empty_cache() 49 | print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))) 50 | start = time.time() 51 | gpu_id=0 52 | if steps > 60: 53 | steps = 60 54 | model = self.model_list[gpu_id] 55 | model = model.cuda() 56 | batch_size=1 57 | channels = model.model.diffusion_model.out_channels 58 | frames = model.temporal_length 59 | h, w = self.resolution[0] // 8, self.resolution[1] // 8 60 | noise_shape = [batch_size, channels, frames, h, w] 61 | 62 | # text cond 63 | with torch.no_grad(), torch.cuda.amp.autocast(): 64 | text_emb = model.get_learned_conditioning([prompt]) 65 | 66 | # img cond 67 | img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().to(model.device) 68 | img_tensor = (img_tensor / 255. - 0.5) * 2 69 | 70 | image_tensor_resized = transform(img_tensor) #3,h,w 71 | videos = image_tensor_resized.unsqueeze(0) # bchw 72 | 73 | z = get_latent_z(model, videos.unsqueeze(2)) #bc,1,hw 74 | 75 | 76 | if image2 is not None: 77 | img_tensor2 = torch.from_numpy(image2).permute(2, 0, 1).float().to(model.device) 78 | img_tensor2 = (img_tensor2 / 255. - 0.5) * 2 79 | 80 | image_tensor_resized2 = transform(img_tensor2) #3,h,w 81 | videos2 = image_tensor_resized2.unsqueeze(0) # bchw 82 | 83 | z2 = get_latent_z(model, videos2.unsqueeze(2)) #bc,1,hw 84 | 85 | img_tensor_repeat = repeat(z, 'b c t h w -> b c (repeat t) h w', repeat=frames) 86 | 87 | img_tensor_repeat = torch.zeros_like(img_tensor_repeat) 88 | 89 | ## old 90 | img_tensor_repeat[:,:,:1,:,:] = z 91 | if image2 is not None: 92 | img_tensor_repeat[:,:,-1:,:,:] = z2 93 | else: 94 | img_tensor_repeat[:,:,-1:,:,:] = z 95 | 96 | 97 | cond_images = model.embedder(img_tensor.unsqueeze(0)) ## blc 98 | img_emb = model.image_proj_model(cond_images) 99 | 100 | imtext_cond = torch.cat([text_emb, img_emb], dim=1) 101 | 102 | fs = torch.tensor([fs], dtype=torch.long, device=model.device) 103 | cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat]} 104 | 105 | ## inference 106 | batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale) 107 | 108 | ## remove the last frame 109 | if image2 is None: 110 | batch_samples = batch_samples[:,:,:,:-1,...] 111 | ## b,samples,c,t,h,w 112 | prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt 113 | prompt_str = prompt_str.replace(" ", "_") if " " in prompt else prompt_str 114 | prompt_str=prompt_str[:40] 115 | if len(prompt_str) == 0: 116 | prompt_str = 'empty_prompt' 117 | 118 | imgs=save_videos(batch_samples, self.result_dir, filenames=[prompt_str], fps=self.save_fps) 119 | print(f"Saved in {prompt_str}. 
Time used: {(time.time() - start):.2f} seconds") 120 | model = model.cpu() 121 | #return os.path.join(self.result_dir, f"{prompt_str}.mp4") 122 | return imgs 123 | 124 | def download_model(self): 125 | REPO_ID = 'Doubiiu/DynamiCrafter_'+str(self.resolution[1])+'_Interp' 126 | filename_list = ['model.ckpt'] 127 | if not os.path.exists(models_path+'checkpoints/dynamicrafter_'+str(self.resolution[1])+'_interp_v1/'): 128 | os.makedirs(models_path+'checkpoints/dynamicrafter_'+str(self.resolution[1])+'_interp_v1/') 129 | for filename in filename_list: 130 | local_file = os.path.join(models_path+'checkpoints/dynamicrafter_'+str(self.resolution[1])+'_interp_v1/', filename) 131 | if not os.path.exists(local_file): 132 | hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir=models_path+'checkpoints/dynamicrafter_'+str(self.resolution[1])+'_interp_v1/', local_dir_use_symlinks=False) 133 | 134 | if __name__ == '__main__': 135 | i2v = Image2Video() 136 | video_path = i2v.get_image('prompts/art.png','man fishing in a boat at sunset') 137 | print('done', video_path) -------------------------------------------------------------------------------- /scripts/run.sh: -------------------------------------------------------------------------------- 1 | version=$1 ##1024, 512, 256 2 | seed=123 3 | name=dynamicrafter_$1_seed${seed} 4 | 5 | ckpt=checkpoints/dynamicrafter_$1_v1/model.ckpt 6 | config=configs/inference_$1_v1.0.yaml 7 | 8 | prompt_dir=prompts/$1/ 9 | res_dir="results" 10 | 11 | if [ "$1" == "256" ]; then 12 | H=256 13 | FS=3 ## This model adopts frame stride=3, range recommended: 1-6 (larger value -> larger motion) 14 | elif [ "$1" == "512" ]; then 15 | H=320 16 | FS=24 ## This model adopts FPS=24, range recommended: 15-30 (smaller value -> larger motion) 17 | elif [ "$1" == "1024" ]; then 18 | H=576 19 | FS=10 ## This model adopts FPS=10, range recommended: 15-5 (smaller value -> larger motion) 20 | else 21 | echo "Invalid input. Please enter 256, 512, or 1024." 
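  # (hedged note, not part of the upstream script) only 256, 512 and 1024 are accepted;
  # a typical invocation, assuming checkpoints/dynamicrafter_$1_v1/model.ckpt is already
  # in place, would be something like:
  #   bash scripts/run.sh 1024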
22 | exit 1 23 | fi 24 | 25 | if [ "$1" == "256" ]; then 26 | CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/inference.py \ 27 | --seed ${seed} \ 28 | --ckpt_path $ckpt \ 29 | --config $config \ 30 | --savedir $res_dir/$name \ 31 | --n_samples 1 \ 32 | --bs 1 --height ${H} --width $1 \ 33 | --unconditional_guidance_scale 7.5 \ 34 | --ddim_steps 50 \ 35 | --ddim_eta 1.0 \ 36 | --prompt_dir $prompt_dir \ 37 | --text_input \ 38 | --video_length 16 \ 39 | --frame_stride ${FS} 40 | else 41 | CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/inference.py \ 42 | --seed ${seed} \ 43 | --ckpt_path $ckpt \ 44 | --config $config \ 45 | --savedir $res_dir/$name \ 46 | --n_samples 1 \ 47 | --bs 1 --height ${H} --width $1 \ 48 | --unconditional_guidance_scale 7.5 \ 49 | --ddim_steps 50 \ 50 | --ddim_eta 1.0 \ 51 | --prompt_dir $prompt_dir \ 52 | --text_input \ 53 | --video_length 16 \ 54 | --frame_stride ${FS} \ 55 | --timestep_spacing 'uniform_trailing' --guidance_rescale 0.7 --perframe_ae 56 | fi 57 | 58 | 59 | ## multi-cond CFG: the is s_txt, is s_img 60 | #--multiple_cond_cfg --cfg_img 7.5 61 | #--loop -------------------------------------------------------------------------------- /scripts/run_application.sh: -------------------------------------------------------------------------------- 1 | version=$1 # interp or loop 2 | ckpt=checkpoints/dynamicrafter_512_interp_v1/model.ckpt 3 | config=configs/inference_512_v1.0.yaml 4 | 5 | prompt_dir=prompts/512_$1/ 6 | res_dir="results" 7 | 8 | FS=5 ## This model adopts FPS=5, range recommended: 5-30 (smaller value -> larger motion) 9 | 10 | 11 | if [ "$1" == "interp" ]; then 12 | seed=12306 13 | name=dynamicrafter_512_$1_seed${seed} 14 | CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/inference.py \ 15 | --seed ${seed} \ 16 | --ckpt_path $ckpt \ 17 | --config $config \ 18 | --savedir $res_dir/$name \ 19 | --n_samples 1 \ 20 | --bs 1 --height 320 --width 512 \ 21 | --unconditional_guidance_scale 7.5 \ 22 | --ddim_steps 50 \ 23 | --ddim_eta 1.0 \ 24 | --prompt_dir $prompt_dir \ 25 | --text_input \ 26 | --video_length 16 \ 27 | --frame_stride ${FS} \ 28 | --timestep_spacing 'uniform_trailing' --guidance_rescale 0.7 --perframe_ae --interp 29 | else 30 | seed=234 31 | name=dynamicrafter_512_$1_seed${seed} 32 | CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/inference.py \ 33 | --seed ${seed} \ 34 | --ckpt_path $ckpt \ 35 | --config $config \ 36 | --savedir $res_dir/$name \ 37 | --n_samples 1 \ 38 | --bs 1 --height 320 --width 512 \ 39 | --unconditional_guidance_scale 7.5 \ 40 | --ddim_steps 50 \ 41 | --ddim_eta 1.0 \ 42 | --prompt_dir $prompt_dir \ 43 | --text_input \ 44 | --video_length 16 \ 45 | --frame_stride ${FS} \ 46 | --timestep_spacing 'uniform_trailing' --guidance_rescale 0.7 --perframe_ae --loop 47 | fi 48 | -------------------------------------------------------------------------------- /scripts/run_mp.sh: -------------------------------------------------------------------------------- 1 | version=$1 ##1024, 512, 256 2 | seed=123 3 | 4 | name=dynamicrafter_$1_mp_seed${seed} 5 | 6 | ckpt=checkpoints/dynamicrafter_$1_v1/model.ckpt 7 | config=configs/inference_$1_v1.0.yaml 8 | 9 | prompt_dir=prompts/$1/ 10 | res_dir="results" 11 | 12 | if [ "$1" == "256" ]; then 13 | H=256 14 | FS=3 ## This model adopts frame stride=3 15 | elif [ "$1" == "512" ]; then 16 | H=320 17 | FS=24 ## This model adopts FPS=24 18 | elif [ "$1" == "1024" ]; then 19 | H=576 20 | FS=10 ## This model adopts FPS=10 21 | else 22 | echo "Invalid input. 
Please enter 256, 512, or 1024." 23 | exit 1 24 | fi 25 | 26 | # if [ "$1" == "256" ]; then 27 | # CUDA_VISIBLE_DEVICES=2 python3 scripts/evaluation/inference.py \ 28 | # --seed 123 \ 29 | # --ckpt_path $ckpt \ 30 | # --config $config \ 31 | # --savedir $res_dir/$name \ 32 | # --n_samples 1 \ 33 | # --bs 1 --height ${H} --width $1 \ 34 | # --unconditional_guidance_scale 7.5 \ 35 | # --ddim_steps 50 \ 36 | # --ddim_eta 1.0 \ 37 | # --prompt_dir $prompt_dir \ 38 | # --text_input \ 39 | # --video_length 16 \ 40 | # --frame_stride ${FS} 41 | # else 42 | # CUDA_VISIBLE_DEVICES=2 python3 scripts/evaluation/inference.py \ 43 | # --seed 123 \ 44 | # --ckpt_path $ckpt \ 45 | # --config $config \ 46 | # --savedir $res_dir/$name \ 47 | # --n_samples 1 \ 48 | # --bs 1 --height ${H} --width $1 \ 49 | # --unconditional_guidance_scale 7.5 \ 50 | # --ddim_steps 50 \ 51 | # --ddim_eta 1.0 \ 52 | # --prompt_dir $prompt_dir \ 53 | # --text_input \ 54 | # --video_length 16 \ 55 | # --frame_stride ${FS} \ 56 | # --timestep_spacing 'uniform_trailing' --guidance_rescale 0.7 57 | # fi 58 | 59 | 60 | ## multi-cond CFG: the is s_txt, is s_img 61 | #--multiple_cond_cfg --cfg_img 7.5 62 | #--loop 63 | 64 | ## inference using single node with multi-GPUs: 65 | if [ "$1" == "256" ]; then 66 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch \ 67 | --nproc_per_node=8 --nnodes=1 --master_addr=127.0.0.1 --master_port=23456 --node_rank=0 \ 68 | scripts/evaluation/ddp_wrapper.py \ 69 | --module 'inference' \ 70 | --seed ${seed} \ 71 | --ckpt_path $ckpt \ 72 | --config $config \ 73 | --savedir $res_dir/$name \ 74 | --n_samples 1 \ 75 | --bs 1 --height ${H} --width $1 \ 76 | --unconditional_guidance_scale 7.5 \ 77 | --ddim_steps 50 \ 78 | --ddim_eta 1.0 \ 79 | --prompt_dir $prompt_dir \ 80 | --text_input \ 81 | --video_length 16 \ 82 | --frame_stride ${FS} 83 | else 84 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch \ 85 | --nproc_per_node=8 --nnodes=1 --master_addr=127.0.0.1 --master_port=23456 --node_rank=0 \ 86 | scripts/evaluation/ddp_wrapper.py \ 87 | --module 'inference' \ 88 | --seed ${seed} \ 89 | --ckpt_path $ckpt \ 90 | --config $config \ 91 | --savedir $res_dir/$name \ 92 | --n_samples 1 \ 93 | --bs 1 --height ${H} --width $1 \ 94 | --unconditional_guidance_scale 7.5 \ 95 | --ddim_steps 50 \ 96 | --ddim_eta 1.0 \ 97 | --prompt_dir $prompt_dir \ 98 | --text_input \ 99 | --video_length 16 \ 100 | --frame_stride ${FS} \ 101 | --timestep_spacing 'uniform_trailing' --guidance_rescale 0.7 --perframe_ae 102 | fi -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import numpy as np 3 | import cv2 4 | import torch 5 | import torch.distributed as dist 6 | 7 | 8 | def count_params(model, verbose=False): 9 | total_params = sum(p.numel() for p in model.parameters()) 10 | if verbose: 11 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 12 | return total_params 13 | 14 | 15 | def check_istarget(name, para_list): 16 | """ 17 | name: full name of source para 18 | para_list: partial name of target para 19 | """ 20 | istarget=False 21 | for para in para_list: 22 | if para in name: 23 | return True 24 | return istarget 25 | 26 | 27 | def instantiate_from_config(config): 28 | if not "target" in config: 29 | if config == '__is_first_stage__': 30 | return None 31 | elif config == "__is_unconditional__": 32 | 
return None 33 | raise KeyError("Expected key `target` to instantiate.") 34 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 35 | 36 | 37 | def get_obj_from_str(string, reload=False): 38 | module, cls = string.rsplit(".", 1) 39 | if reload: 40 | module_imp = importlib.import_module(module) 41 | importlib.reload(module_imp) 42 | return getattr(importlib.import_module(module, package=None), cls) 43 | 44 | 45 | def load_npz_from_dir(data_dir): 46 | data = [np.load(os.path.join(data_dir, data_name))['arr_0'] for data_name in os.listdir(data_dir)] 47 | data = np.concatenate(data, axis=0) 48 | return data 49 | 50 | 51 | def load_npz_from_paths(data_paths): 52 | data = [np.load(data_path)['arr_0'] for data_path in data_paths] 53 | data = np.concatenate(data, axis=0) 54 | return data 55 | 56 | 57 | def resize_numpy_image(image, max_resolution=512 * 512, resize_short_edge=None): 58 | h, w = image.shape[:2] 59 | if resize_short_edge is not None: 60 | k = resize_short_edge / min(h, w) 61 | else: 62 | k = max_resolution / (h * w) 63 | k = k**0.5 64 | h = int(np.round(h * k / 64)) * 64 65 | w = int(np.round(w * k / 64)) * 64 66 | image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4) 67 | return image 68 | 69 | 70 | def setup_dist(args): 71 | if dist.is_initialized(): 72 | return 73 | torch.cuda.set_device(args.local_rank) 74 | torch.distributed.init_process_group( 75 | 'nccl', 76 | init_method='env://' 77 | ) -------------------------------------------------------------------------------- /video.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/video.gif -------------------------------------------------------------------------------- /wf-basic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/wf-basic.png -------------------------------------------------------------------------------- /wf-interp.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 9, 3 | "last_link_id": 12, 4 | "nodes": [ 5 | { 6 | "id": 1, 7 | "type": "DynamiCrafterInterpLoader", 8 | "pos": [ 9 | 35, 10 | 194 11 | ], 12 | "size": { 13 | "0": 315, 14 | "1": 82 15 | }, 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "outputs": [ 20 | { 21 | "name": "model", 22 | "type": "DynamiCrafterInter", 23 | "links": [ 24 | 1 25 | ], 26 | "shape": 3, 27 | "slot_index": 0 28 | } 29 | ], 30 | "properties": { 31 | "Node name for S&R": "DynamiCrafterInterpLoader" 32 | }, 33 | "widgets_values": [ 34 | "320_512", 35 | 16 36 | ] 37 | }, 38 | { 39 | "id": 3, 40 | "type": "LoadImage", 41 | "pos": [ 42 | 90, 43 | 412 44 | ], 45 | "size": { 46 | "0": 315, 47 | "1": 314 48 | }, 49 | "flags": { 50 | "collapsed": true 51 | }, 52 | "order": 1, 53 | "mode": 0, 54 | "outputs": [ 55 | { 56 | "name": "IMAGE", 57 | "type": "IMAGE", 58 | "links": [ 59 | 9 60 | ], 61 | "shape": 3, 62 | "slot_index": 0 63 | }, 64 | { 65 | "name": "MASK", 66 | "type": "MASK", 67 | "links": null, 68 | "shape": 3 69 | } 70 | ], 71 | "properties": { 72 | "Node name for S&R": "LoadImage" 73 | }, 74 | "widgets_values": [ 75 | "u=2397542458,3133539061&fm=193&f=GIF.jfif", 76 | "image" 77 | ] 78 | }, 79 | { 80 | "id": 4, 81 | "type": "LoadImage", 82 | "pos": [ 83 | 96, 84 | 513 85 | ], 86 | "size": { 87 | "0": 
315, 88 | "1": 314 89 | }, 90 | "flags": { 91 | "collapsed": true 92 | }, 93 | "order": 2, 94 | "mode": 0, 95 | "outputs": [ 96 | { 97 | "name": "IMAGE", 98 | "type": "IMAGE", 99 | "links": [ 100 | 10 101 | ], 102 | "shape": 3, 103 | "slot_index": 0 104 | }, 105 | { 106 | "name": "MASK", 107 | "type": "MASK", 108 | "links": null, 109 | "shape": 3 110 | } 111 | ], 112 | "properties": { 113 | "Node name for S&R": "LoadImage" 114 | }, 115 | "widgets_values": [ 116 | "u=2511982910,2454873241&fm=193&f=GIF.jfif", 117 | "image" 118 | ] 119 | }, 120 | { 121 | "id": 9, 122 | "type": "VHS_VideoCombine", 123 | "pos": [ 124 | 874, 125 | 262 126 | ], 127 | "size": [ 128 | 315, 129 | 488.375 130 | ], 131 | "flags": {}, 132 | "order": 4, 133 | "mode": 0, 134 | "inputs": [ 135 | { 136 | "name": "images", 137 | "type": "IMAGE", 138 | "link": 12 139 | }, 140 | { 141 | "name": "audio", 142 | "type": "VHS_AUDIO", 143 | "link": null 144 | }, 145 | { 146 | "name": "batch_manager", 147 | "type": "VHS_BatchManager", 148 | "link": null 149 | } 150 | ], 151 | "outputs": [ 152 | { 153 | "name": "Filenames", 154 | "type": "VHS_FILENAMES", 155 | "links": null, 156 | "shape": 3 157 | } 158 | ], 159 | "properties": { 160 | "Node name for S&R": "VHS_VideoCombine" 161 | }, 162 | "widgets_values": { 163 | "frame_rate": 8, 164 | "loop_count": 0, 165 | "filename_prefix": "AnimateDiff", 166 | "format": "video/h264-mp4", 167 | "pix_fmt": "yuv420p", 168 | "crf": 19, 169 | "save_metadata": true, 170 | "pingpong": false, 171 | "save_output": true, 172 | "videopreview": { 173 | "hidden": false, 174 | "paused": false, 175 | "params": { 176 | "filename": "AnimateDiff_00058.mp4", 177 | "subfolder": "", 178 | "type": "output", 179 | "format": "video/h264-mp4" 180 | } 181 | } 182 | } 183 | }, 184 | { 185 | "id": 2, 186 | "type": "DynamiCrafterInterp Simple", 187 | "pos": [ 188 | 394, 189 | 297 190 | ], 191 | "size": { 192 | "0": 315, 193 | "1": 242 194 | }, 195 | "flags": {}, 196 | "order": 3, 197 | "mode": 0, 198 | "inputs": [ 199 | { 200 | "name": "model", 201 | "type": "DynamiCrafterInter", 202 | "link": 1 203 | }, 204 | { 205 | "name": "image", 206 | "type": "IMAGE", 207 | "link": 9 208 | }, 209 | { 210 | "name": "image1", 211 | "type": "IMAGE", 212 | "link": 10 213 | } 214 | ], 215 | "outputs": [ 216 | { 217 | "name": "image", 218 | "type": "IMAGE", 219 | "links": [ 220 | 12 221 | ], 222 | "shape": 3, 223 | "slot_index": 0 224 | } 225 | ], 226 | "properties": { 227 | "Node name for S&R": "DynamiCrafterInterp Simple" 228 | }, 229 | "widgets_values": [ 230 | "", 231 | 50, 232 | 7.5, 233 | 1, 234 | 3, 235 | 1489, 236 | "randomize" 237 | ] 238 | } 239 | ], 240 | "links": [ 241 | [ 242 | 1, 243 | 1, 244 | 0, 245 | 2, 246 | 0, 247 | "DynamiCrafterInter" 248 | ], 249 | [ 250 | 9, 251 | 3, 252 | 0, 253 | 2, 254 | 1, 255 | "IMAGE" 256 | ], 257 | [ 258 | 10, 259 | 4, 260 | 0, 261 | 2, 262 | 2, 263 | "IMAGE" 264 | ], 265 | [ 266 | 12, 267 | 2, 268 | 0, 269 | 9, 270 | 0, 271 | "IMAGE" 272 | ] 273 | ], 274 | "groups": [], 275 | "config": {}, 276 | "extra": {}, 277 | "version": 0.4 278 | } -------------------------------------------------------------------------------- /wf-interp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/wf-interp.png -------------------------------------------------------------------------------- /workflow.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 4, 3 | "last_link_id": 3, 4 | "nodes": [ 5 | { 6 | "id": 3, 7 | "type": "LoadImage", 8 | "pos": [ 9 | 143, 10 | 296 11 | ], 12 | "size": { 13 | "0": 315, 14 | "1": 314 15 | }, 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "outputs": [ 20 | { 21 | "name": "IMAGE", 22 | "type": "IMAGE", 23 | "links": [ 24 | 2 25 | ], 26 | "shape": 3, 27 | "slot_index": 0 28 | }, 29 | { 30 | "name": "MASK", 31 | "type": "MASK", 32 | "links": null, 33 | "shape": 3 34 | } 35 | ], 36 | "properties": { 37 | "Node name for S&R": "LoadImage" 38 | }, 39 | "widgets_values": [ 40 | "00031-4116268099-1girl_AND trees_AND grass_AND sky_AND 1girl (1).png", 41 | "image" 42 | ] 43 | }, 44 | { 45 | "id": 2, 46 | "type": "DynamiCrafter Simple", 47 | "pos": [ 48 | 604, 49 | 141.33334350585938 50 | ], 51 | "size": { 52 | "0": 315, 53 | "1": 222 54 | }, 55 | "flags": {}, 56 | "order": 2, 57 | "mode": 0, 58 | "inputs": [ 59 | { 60 | "name": "model", 61 | "type": "DynamiCrafter", 62 | "link": 1 63 | }, 64 | { 65 | "name": "image", 66 | "type": "IMAGE", 67 | "link": 2 68 | } 69 | ], 70 | "outputs": [ 71 | { 72 | "name": "image", 73 | "type": "IMAGE", 74 | "links": [ 75 | 3 76 | ], 77 | "shape": 3, 78 | "slot_index": 0 79 | } 80 | ], 81 | "properties": { 82 | "Node name for S&R": "DynamiCrafter Simple" 83 | }, 84 | "widgets_values": [ 85 | "1girl sleeping", 86 | 50, 87 | 7.5, 88 | 1, 89 | 3, 90 | 1022, 91 | "randomize" 92 | ] 93 | }, 94 | { 95 | "id": 4, 96 | "type": "VHS_VideoCombine", 97 | "pos": [ 98 | 1061, 99 | 144 100 | ], 101 | "size": [ 102 | 315, 103 | 414.3125 104 | ], 105 | "flags": {}, 106 | "order": 3, 107 | "mode": 0, 108 | "inputs": [ 109 | { 110 | "name": "images", 111 | "type": "IMAGE", 112 | "link": 3 113 | } 114 | ], 115 | "outputs": [], 116 | "properties": { 117 | "Node name for S&R": "VHS_VideoCombine" 118 | }, 119 | "widgets_values": { 120 | "frame_rate": 8, 121 | "loop_count": 0, 122 | "filename_prefix": "AnimateDiff", 123 | "format": "image/gif", 124 | "pingpong": false, 125 | "save_image": true, 126 | "crf": 20, 127 | "save_metadata": true, 128 | "audio_file": "", 129 | "videopreview": { 130 | "hidden": false, 131 | "paused": false, 132 | "params": { 133 | "filename": "AnimateDiff_00900.gif", 134 | "subfolder": "", 135 | "type": "output", 136 | "format": "image/gif" 137 | } 138 | } 139 | } 140 | }, 141 | { 142 | "id": 1, 143 | "type": "DynamiCrafterLoader", 144 | "pos": [ 145 | 193, 146 | 148 147 | ], 148 | "size": { 149 | "0": 315, 150 | "1": 82 151 | }, 152 | "flags": {}, 153 | "order": 1, 154 | "mode": 0, 155 | "outputs": [ 156 | { 157 | "name": "model", 158 | "type": "DynamiCrafter", 159 | "links": [ 160 | 1 161 | ], 162 | "shape": 3, 163 | "slot_index": 0 164 | } 165 | ], 166 | "properties": { 167 | "Node name for S&R": "DynamiCrafterLoader" 168 | }, 169 | "widgets_values": [ 170 | "576_1024", 171 | 16 172 | ] 173 | } 174 | ], 175 | "links": [ 176 | [ 177 | 1, 178 | 1, 179 | 0, 180 | 2, 181 | 0, 182 | "DynamiCrafter" 183 | ], 184 | [ 185 | 2, 186 | 3, 187 | 0, 188 | 2, 189 | 1, 190 | "IMAGE" 191 | ], 192 | [ 193 | 3, 194 | 2, 195 | 0, 196 | 4, 197 | 0, 198 | "IMAGE" 199 | ] 200 | ], 201 | "groups": [], 202 | "config": {}, 203 | "extra": {}, 204 | "version": 0.4 205 | } --------------------------------------------------------------------------------
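The gradio helper classes above can also be driven programmatically. The snippet below is a minimal sketch, not a supported entry point: it assumes it runs inside the ComfyUI-DynamiCrafter package (e.g. from nodes.py) so that `folder_paths` and the package-relative imports in `scripts/gradio/i2v_test.py` resolve, that the 1024 checkpoint already sits in `models/checkpoints/dynamicrafter_1024_v1/model.ckpt`, and that a CUDA GPU is available; the prompt string is a placeholder.

```
# Hedged usage sketch only; the relative import mirrors how the node pack's own
# modules would reference i2v_test.py and is illustrative, not guaranteed.
import cv2
from .scripts.gradio.i2v_test import Image2Video

i2v = Image2Video(resolution='576_1024', frame_length=16)  # matches workflow.json ("576_1024", 16)

# get_image() calls torch.from_numpy() on its first argument, so pass an HxWx3
# RGB array in 0-255 rather than a file path (the __main__ blocks pass a path,
# which would fail at that call).
frame = cv2.cvtColor(cv2.imread('prompts/1024/robot01.png'), cv2.COLOR_BGR2RGB)

frames = i2v.get_image(frame, 'placeholder prompt', steps=50, cfg_scale=7.5,
                       eta=1.0, fs=10, seed=123)  # fs=10 follows run.sh's 1024 setting
```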