├── .gitignore
├── 1girl_sleeping.gif
├── LICENSE
├── README.md
├── __init__.py
├── assets
│   ├── application
│   │   ├── 05.gif
│   │   ├── 25.gif
│   │   ├── 34.gif
│   │   ├── 35.gif
│   │   ├── 36.gif
│   │   ├── 60.gif
│   │   ├── YwHJYWvv_dM.gif
│   │   ├── YwHJYWvv_dM_input_end.png
│   │   ├── YwHJYWvv_dM_input_start.png
│   │   ├── gkxX0kb8mE8.gif
│   │   ├── gkxX0kb8mE8_input_end.png
│   │   ├── gkxX0kb8mE8_input_start.png
│   │   ├── smile.gif
│   │   ├── smile_end.png
│   │   ├── smile_start.png
│   │   ├── stone01.gif
│   │   ├── stone01_end.png
│   │   ├── stone01_start.png
│   │   ├── storytellingvideo.gif
│   │   ├── ypDLB52Ykk4.gif
│   │   ├── ypDLB52Ykk4_input_end.png
│   │   └── ypDLB52Ykk4_input_start.png
│   ├── logo_long.png
│   ├── logo_long_dark.png
│   └── showcase
│       ├── Upscaled_Aime_Tribolet_springtime_landscape_golden_hour_morning_pale_yel_e6946f8d-37c1-4ce8-bf62-6ba90d23bd93.gif
│       ├── Upscaled_Aime_Tribolet_springtime_landscape_golden_hour_morning_pale_yel_e6946f8d-37c1-4ce8-bf62-6ba90d23bd93.mp4_00.png
│       ├── Upscaled_Alex__State_Blonde_woman_riding_on_top_of_a_moving_washing_mach_c31acaa3-dd30-459f-a109-2d2eb4c00fe2.gif
│       ├── Upscaled_Alex__State_Blonde_woman_riding_on_top_of_a_moving_washing_mach_c31acaa3-dd30-459f-a109-2d2eb4c00fe2.mp4_00.png
│       ├── bike_chineseink.gif
│       ├── bird000.gif
│       ├── bird000.jpeg
│       ├── bloom2.gif
│       ├── dance1.gif
│       ├── dance1.jpeg_00.png
│       ├── explode0.gif
│       ├── explode0.jpeg_00.png
│       ├── firework03.gif
│       ├── girl07.gif
│       ├── girl3.gif
│       ├── girl3.jpeg_00.png
│       ├── guitar0.gif
│       ├── guitar0.jpeg_00.png
│       ├── lighthouse.gif
│       ├── pour_honey.gif
│       ├── robot01.gif
│       ├── train_anime02.gif
│       ├── walk0.gif
│       └── walk0.png_00.png
├── configs
│   ├── inference_1024_v1.0.yaml
│   ├── inference_256_v1.0.yaml
│   └── inference_512_v1.0.yaml
├── gradio_app.py
├── lvdm
│   ├── basics.py
│   ├── common.py
│   ├── distributions.py
│   ├── ema.py
│   ├── models
│   │   ├── autoencoder.py
│   │   ├── ddpm3d.py
│   │   ├── samplers
│   │   │   ├── ddim.py
│   │   │   └── ddim_multiplecond.py
│   │   └── utils_diffusion.py
│   └── modules
│       ├── attention.py
│       ├── encoders
│       │   ├── condition.py
│       │   └── resampler.py
│       ├── networks
│       │   ├── ae_modules.py
│       │   └── openaimodel3d.py
│       └── x_transformer.py
├── nodes.py
├── prompts
│   ├── 256
│   │   ├── art.png
│   │   ├── bear.png
│   │   ├── boy.png
│   │   ├── dance1.jpeg
│   │   ├── fire_and_beach.jpg
│   │   ├── girl2.jpeg
│   │   ├── girl3.jpeg
│   │   ├── guitar0.jpeg
│   │   └── test_prompts.txt
│   ├── 512
│   │   ├── bloom01.png
│   │   ├── campfire.png
│   │   ├── girl08.png
│   │   ├── isometric.png
│   │   ├── pour_honey.png
│   │   ├── ship02.png
│   │   ├── test_prompts.txt
│   │   ├── zreal_boat.png
│   │   └── zreal_penguin.png
│   ├── 1024
│   │   ├── astronaut04.png
│   │   ├── bike_chineseink.png
│   │   ├── bloom01.png
│   │   ├── firework03.png
│   │   ├── girl07.png
│   │   ├── pour_bear.png
│   │   ├── robot01.png
│   │   ├── test_prompts.txt
│   │   └── zreal_penguin.png
│   ├── 512_interp
│   │   ├── smile_01.png
│   │   ├── smile_02.png
│   │   ├── stone01_01.png
│   │   ├── stone01_02.png
│   │   ├── test_prompts.txt
│   │   ├── walk_01.png
│   │   └── walk_02.png
│   └── 512_loop
│       ├── 24.png
│       ├── 36.png
│       ├── 40.png
│       └── test_prompts.txt
├── requirements.txt
├── scripts
│   ├── evaluation
│   │   ├── ddp_wrapper.py
│   │   ├── funcs.py
│   │   └── inference.py
│   ├── gradio
│   │   ├── i2v_test.py
│   │   └── i2v_test_application.py
│   ├── run.sh
│   ├── run_application.sh
│   └── run_mp.sh
├── utils
│   └── utils.py
├── video.gif
├── wf-basic.png
├── wf-interp.json
├── wf-interp.png
└── workflow.json
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | *pyc
3 | .vscode
4 | __pycache__
5 | *.egg-info
6 |
7 | checkpoints
8 | results
9 | backup
--------------------------------------------------------------------------------
/1girl_sleeping.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/1girl_sleeping.gif
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ComfyUI DynamiCrafter
2 |
3 | [DynamiCrafter](https://github.com/Doubiiu/DynamiCrafter)
4 |
5 | Note: xformers must be installed: `pip install xformers`
6 |
7 | Models are downloaded automatically by default; users in China can download them via the mirror at https://hf-mirror.com/
8 |
9 | ```
10 | pip install -U huggingface_hub
11 | export HF_ENDPOINT=https://hf-mirror.com  # on Windows, in PowerShell: $env:HF_ENDPOINT = "https://hf-mirror.com"
12 | huggingface-cli download --resume-download laion/CLIP-ViT-H-14-laion2B-s32B-b79K
13 | ```
14 |
15 | Download model.ckpt to models/checkpoints/dynamicrafter_512_interp_v1/model.ckpt:
16 |
17 | https://hf-mirror.com/Doubiiu/DynamiCrafter_512_Interp
18 |
19 | Download model.ckpt to models/checkpoints/dynamicrafter_1024_v1/model.ckpt:
20 |
21 | https://hf-mirror.com/Doubiiu/DynamiCrafter_1024
22 |
23 |
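The two checkpoint downloads can also be scripted. A minimal sketch using the `huggingface_hub` Python API (it honors the `HF_ENDPOINT` mirror setting above; the target folders are the ones listed above):

```python
from huggingface_hub import hf_hub_download

# Fetch model.ckpt from each model repo into the ComfyUI checkpoint folders named above.
for repo_id, local_dir in [
    ("Doubiiu/DynamiCrafter_512_Interp", "models/checkpoints/dynamicrafter_512_interp_v1"),
    ("Doubiiu/DynamiCrafter_1024", "models/checkpoints/dynamicrafter_1024_v1"),
]:
    hf_hub_download(repo_id=repo_id, filename="model.ckpt", local_dir=local_dir)
```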
24 | ## Examples
25 |
26 | ### base workflow
27 |
28 |
29 |
30 | https://github.com/chaojie/ComfyUI-DynamiCrafter/blob/main/workflow.json
31 |
32 | prompt: 1girl
33 |
34 |
35 |
36 | prompt: 1girl sleeping
37 |
38 |
39 |
40 | Test on an RTX 4090:
41 | 
42 | 16-frame generation takes about 3 minutes
43 | 
44 | 32-frame generation takes about 6 minutes (the 32-frame setting is on the Loader node)
45 |
46 |
47 |
48 | ### interpolation workflow
49 |
50 |
51 |
52 | https://github.com/chaojie/ComfyUI-DynamiCrafter/blob/main/wf-interp.json
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from .nodes import NODE_CLASS_MAPPINGS
2 |
3 | __all__ = ['NODE_CLASS_MAPPINGS']
4 |
5 |
--------------------------------------------------------------------------------
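`__init__.py` simply re-exports `NODE_CLASS_MAPPINGS` from `nodes.py`, which is how ComfyUI discovers a custom-node package. For orientation, here is a hypothetical minimal node following that registration pattern; the real classes in `nodes.py` are not shown in this dump and will differ:

```python
# Hypothetical illustration of the ComfyUI node-registration pattern used by nodes.py.
class ExampleDynamiCrafterNode:
    @classmethod
    def INPUT_TYPES(cls):
        # Declares the sockets/widgets ComfyUI renders for this node.
        return {"required": {"image": ("IMAGE",), "prompt": ("STRING", {"default": ""})}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "run"            # method ComfyUI calls when the node executes
    CATEGORY = "DynamiCrafter"

    def run(self, image, prompt):
        # A real node would run the DynamiCrafter sampler here; this stub just passes through.
        return (image,)


NODE_CLASS_MAPPINGS = {"ExampleDynamiCrafterNode": ExampleDynamiCrafterNode}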
/assets/application/05.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/05.gif
--------------------------------------------------------------------------------
/assets/application/25.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/25.gif
--------------------------------------------------------------------------------
/assets/application/34.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/34.gif
--------------------------------------------------------------------------------
/assets/application/35.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/35.gif
--------------------------------------------------------------------------------
/assets/application/36.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/36.gif
--------------------------------------------------------------------------------
/assets/application/60.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/60.gif
--------------------------------------------------------------------------------
/assets/application/YwHJYWvv_dM.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/YwHJYWvv_dM.gif
--------------------------------------------------------------------------------
/assets/application/YwHJYWvv_dM_input_end.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/YwHJYWvv_dM_input_end.png
--------------------------------------------------------------------------------
/assets/application/YwHJYWvv_dM_input_start.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/YwHJYWvv_dM_input_start.png
--------------------------------------------------------------------------------
/assets/application/gkxX0kb8mE8.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/gkxX0kb8mE8.gif
--------------------------------------------------------------------------------
/assets/application/gkxX0kb8mE8_input_end.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/gkxX0kb8mE8_input_end.png
--------------------------------------------------------------------------------
/assets/application/gkxX0kb8mE8_input_start.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/gkxX0kb8mE8_input_start.png
--------------------------------------------------------------------------------
/assets/application/smile.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/smile.gif
--------------------------------------------------------------------------------
/assets/application/smile_end.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/smile_end.png
--------------------------------------------------------------------------------
/assets/application/smile_start.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/smile_start.png
--------------------------------------------------------------------------------
/assets/application/stone01.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/stone01.gif
--------------------------------------------------------------------------------
/assets/application/stone01_end.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/stone01_end.png
--------------------------------------------------------------------------------
/assets/application/stone01_start.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/stone01_start.png
--------------------------------------------------------------------------------
/assets/application/storytellingvideo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/storytellingvideo.gif
--------------------------------------------------------------------------------
/assets/application/ypDLB52Ykk4.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/ypDLB52Ykk4.gif
--------------------------------------------------------------------------------
/assets/application/ypDLB52Ykk4_input_end.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/ypDLB52Ykk4_input_end.png
--------------------------------------------------------------------------------
/assets/application/ypDLB52Ykk4_input_start.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/application/ypDLB52Ykk4_input_start.png
--------------------------------------------------------------------------------
/assets/logo_long.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/logo_long.png
--------------------------------------------------------------------------------
/assets/logo_long_dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/logo_long_dark.png
--------------------------------------------------------------------------------
/assets/showcase/Upscaled_Aime_Tribolet_springtime_landscape_golden_hour_morning_pale_yel_e6946f8d-37c1-4ce8-bf62-6ba90d23bd93.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/Upscaled_Aime_Tribolet_springtime_landscape_golden_hour_morning_pale_yel_e6946f8d-37c1-4ce8-bf62-6ba90d23bd93.gif
--------------------------------------------------------------------------------
/assets/showcase/Upscaled_Aime_Tribolet_springtime_landscape_golden_hour_morning_pale_yel_e6946f8d-37c1-4ce8-bf62-6ba90d23bd93.mp4_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/Upscaled_Aime_Tribolet_springtime_landscape_golden_hour_morning_pale_yel_e6946f8d-37c1-4ce8-bf62-6ba90d23bd93.mp4_00.png
--------------------------------------------------------------------------------
/assets/showcase/Upscaled_Alex__State_Blonde_woman_riding_on_top_of_a_moving_washing_mach_c31acaa3-dd30-459f-a109-2d2eb4c00fe2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/Upscaled_Alex__State_Blonde_woman_riding_on_top_of_a_moving_washing_mach_c31acaa3-dd30-459f-a109-2d2eb4c00fe2.gif
--------------------------------------------------------------------------------
/assets/showcase/Upscaled_Alex__State_Blonde_woman_riding_on_top_of_a_moving_washing_mach_c31acaa3-dd30-459f-a109-2d2eb4c00fe2.mp4_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/Upscaled_Alex__State_Blonde_woman_riding_on_top_of_a_moving_washing_mach_c31acaa3-dd30-459f-a109-2d2eb4c00fe2.mp4_00.png
--------------------------------------------------------------------------------
/assets/showcase/bike_chineseink.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/bike_chineseink.gif
--------------------------------------------------------------------------------
/assets/showcase/bird000.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/bird000.gif
--------------------------------------------------------------------------------
/assets/showcase/bird000.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/bird000.jpeg
--------------------------------------------------------------------------------
/assets/showcase/bloom2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/bloom2.gif
--------------------------------------------------------------------------------
/assets/showcase/dance1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/dance1.gif
--------------------------------------------------------------------------------
/assets/showcase/dance1.jpeg_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/dance1.jpeg_00.png
--------------------------------------------------------------------------------
/assets/showcase/explode0.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/explode0.gif
--------------------------------------------------------------------------------
/assets/showcase/explode0.jpeg_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/explode0.jpeg_00.png
--------------------------------------------------------------------------------
/assets/showcase/firework03.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/firework03.gif
--------------------------------------------------------------------------------
/assets/showcase/girl07.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/girl07.gif
--------------------------------------------------------------------------------
/assets/showcase/girl3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/girl3.gif
--------------------------------------------------------------------------------
/assets/showcase/girl3.jpeg_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/girl3.jpeg_00.png
--------------------------------------------------------------------------------
/assets/showcase/guitar0.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/guitar0.gif
--------------------------------------------------------------------------------
/assets/showcase/guitar0.jpeg_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/guitar0.jpeg_00.png
--------------------------------------------------------------------------------
/assets/showcase/lighthouse.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/lighthouse.gif
--------------------------------------------------------------------------------
/assets/showcase/pour_honey.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/pour_honey.gif
--------------------------------------------------------------------------------
/assets/showcase/robot01.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/robot01.gif
--------------------------------------------------------------------------------
/assets/showcase/train_anime02.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/train_anime02.gif
--------------------------------------------------------------------------------
/assets/showcase/walk0.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/walk0.gif
--------------------------------------------------------------------------------
/assets/showcase/walk0.png_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/assets/showcase/walk0.png_00.png
--------------------------------------------------------------------------------
/configs/inference_1024_v1.0.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.models.ddpm3d.LatentVisualDiffusion
3 | params:
4 | rescale_betas_zero_snr: True
5 | parameterization: "v"
6 | linear_start: 0.00085
7 | linear_end: 0.012
8 | num_timesteps_cond: 1
9 | timesteps: 1000
10 | first_stage_key: video
11 | cond_stage_key: caption
12 | cond_stage_trainable: False
13 | conditioning_key: hybrid
14 | image_size: [72, 128]
15 | channels: 4
16 | scale_by_std: False
17 | scale_factor: 0.18215
18 | use_ema: False
19 | uncond_type: 'empty_seq'
20 | use_dynamic_rescale: true
21 | base_scale: 0.3
22 | fps_condition_type: 'fps'
23 | perframe_ae: True
24 | unet_config:
25 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.modules.networks.openaimodel3d.UNetModel
26 | params:
27 | in_channels: 8
28 | out_channels: 4
29 | model_channels: 320
30 | attention_resolutions:
31 | - 4
32 | - 2
33 | - 1
34 | num_res_blocks: 2
35 | channel_mult:
36 | - 1
37 | - 2
38 | - 4
39 | - 4
40 | dropout: 0.1
41 | num_head_channels: 64
42 | transformer_depth: 1
43 | context_dim: 1024
44 | use_linear: true
45 | use_checkpoint: True
46 | temporal_conv: True
47 | temporal_attention: True
48 | temporal_selfatt_only: true
49 | use_relative_position: false
50 | use_causal_attention: False
51 | temporal_length: 16
52 | addition_attention: true
53 | image_cross_attention: true
54 | default_fs: 10
55 | fs_condition: true
56 |
57 | first_stage_config:
58 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.models.autoencoder.AutoencoderKL
59 | params:
60 | embed_dim: 4
61 | monitor: val/rec_loss
62 | ddconfig:
63 | double_z: True
64 | z_channels: 4
65 | resolution: 256
66 | in_channels: 3
67 | out_ch: 3
68 | ch: 128
69 | ch_mult:
70 | - 1
71 | - 2
72 | - 4
73 | - 4
74 | num_res_blocks: 2
75 | attn_resolutions: []
76 | dropout: 0.0
77 | lossconfig:
78 | target: torch.nn.Identity
79 |
80 | cond_stage_config:
81 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder
82 | params:
83 | freeze: true
84 | layer: "penultimate"
85 |
86 | img_cond_stage_config:
87 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
88 | params:
89 | freeze: true
90 |
91 | image_proj_stage_config:
92 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.modules.encoders.resampler.Resampler
93 | params:
94 | dim: 1024
95 | depth: 4
96 | dim_head: 64
97 | heads: 12
98 | num_queries: 16
99 | embedding_dim: 1280
100 | output_dim: 1024
101 | ff_mult: 4
102 | video_length: 16
103 |
104 |
--------------------------------------------------------------------------------
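Each `target`/`params` pair in this config is resolved by the usual `instantiate_from_config` convention (the repo's own helper lives in `utils/utils.py`): `target` is a dotted import path and `params` become keyword arguments. A minimal sketch of that loading pattern, assuming the dotted paths are importable in the running environment (inside a ComfyUI checkout they resolve relative to the ComfyUI root):

```python
import importlib
from omegaconf import OmegaConf

def instantiate_from_config(config):
    # Conventional helper: split "pkg.module.Class" and call Class(**params).
    module_name, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**config.get("params", dict()))

config = OmegaConf.load("configs/inference_1024_v1.0.yaml")
model = instantiate_from_config(config.model)   # LatentVisualDiffusion with its UNet, VAE, and CLIP encoders
```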
/configs/inference_256_v1.0.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.models.ddpm3d.LatentVisualDiffusion
3 | params:
4 | linear_start: 0.00085
5 | linear_end: 0.012
6 | num_timesteps_cond: 1
7 | timesteps: 1000
8 | first_stage_key: video
9 | cond_stage_key: caption
10 | cond_stage_trainable: False
11 | conditioning_key: hybrid
12 | image_size: [32, 32]
13 | channels: 4
14 | scale_by_std: False
15 | scale_factor: 0.18215
16 | use_ema: False
17 | uncond_type: 'empty_seq'
18 | unet_config:
19 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.modules.networks.openaimodel3d.UNetModel
20 | params:
21 | in_channels: 8
22 | out_channels: 4
23 | model_channels: 320
24 | attention_resolutions:
25 | - 4
26 | - 2
27 | - 1
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 4
33 | - 4
34 | dropout: 0.1
35 | num_head_channels: 64
36 | transformer_depth: 1
37 | context_dim: 1024
38 | use_linear: true
39 | use_checkpoint: True
40 | temporal_conv: True
41 | temporal_attention: True
42 | temporal_selfatt_only: true
43 | use_relative_position: false
44 | use_causal_attention: False
45 | temporal_length: 16
46 | addition_attention: true
47 | image_cross_attention: true
48 | image_cross_attention_scale_learnable: true
49 | default_fs: 3
50 | fs_condition: true
51 |
52 | first_stage_config:
53 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.models.autoencoder.AutoencoderKL
54 | params:
55 | embed_dim: 4
56 | monitor: val/rec_loss
57 | ddconfig:
58 | double_z: True
59 | z_channels: 4
60 | resolution: 256
61 | in_channels: 3
62 | out_ch: 3
63 | ch: 128
64 | ch_mult:
65 | - 1
66 | - 2
67 | - 4
68 | - 4
69 | num_res_blocks: 2
70 | attn_resolutions: []
71 | dropout: 0.0
72 | lossconfig:
73 | target: torch.nn.Identity
74 |
75 | cond_stage_config:
76 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder
77 | params:
78 | freeze: true
79 | layer: "penultimate"
80 |
81 | img_cond_stage_config:
82 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
83 | params:
84 | freeze: true
85 |
86 | image_proj_stage_config:
87 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.modules.encoders.resampler.Resampler
88 | params:
89 | dim: 1024
90 | depth: 4
91 | dim_head: 64
92 | heads: 12
93 | num_queries: 16
94 | embedding_dim: 1280
95 | output_dim: 1024
96 | ff_mult: 4
97 | video_length: 16
98 |
99 |
--------------------------------------------------------------------------------
/configs/inference_512_v1.0.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.models.ddpm3d.LatentVisualDiffusion
3 | params:
4 | rescale_betas_zero_snr: True
5 | parameterization: "v"
6 | linear_start: 0.00085
7 | linear_end: 0.012
8 | num_timesteps_cond: 1
9 | timesteps: 1000
10 | first_stage_key: video
11 | cond_stage_key: caption
12 | cond_stage_trainable: False
13 | conditioning_key: hybrid
14 | image_size: [40, 64]
15 | channels: 4
16 | scale_by_std: False
17 | scale_factor: 0.18215
18 | use_ema: False
19 | uncond_type: 'empty_seq'
20 | use_dynamic_rescale: true
21 | base_scale: 0.7
22 | fps_condition_type: 'fps'
23 | perframe_ae: True
24 | unet_config:
25 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.modules.networks.openaimodel3d.UNetModel
26 | params:
27 | in_channels: 8
28 | out_channels: 4
29 | model_channels: 320
30 | attention_resolutions:
31 | - 4
32 | - 2
33 | - 1
34 | num_res_blocks: 2
35 | channel_mult:
36 | - 1
37 | - 2
38 | - 4
39 | - 4
40 | dropout: 0.1
41 | num_head_channels: 64
42 | transformer_depth: 1
43 | context_dim: 1024
44 | use_linear: true
45 | use_checkpoint: True
46 | temporal_conv: True
47 | temporal_attention: True
48 | temporal_selfatt_only: true
49 | use_relative_position: false
50 | use_causal_attention: False
51 | temporal_length: 16
52 | addition_attention: true
53 | image_cross_attention: true
54 | default_fs: 24
55 | fs_condition: true
56 |
57 | first_stage_config:
58 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.models.autoencoder.AutoencoderKL
59 | params:
60 | embed_dim: 4
61 | monitor: val/rec_loss
62 | ddconfig:
63 | double_z: True
64 | z_channels: 4
65 | resolution: 256
66 | in_channels: 3
67 | out_ch: 3
68 | ch: 128
69 | ch_mult:
70 | - 1
71 | - 2
72 | - 4
73 | - 4
74 | num_res_blocks: 2
75 | attn_resolutions: []
76 | dropout: 0.0
77 | lossconfig:
78 | target: torch.nn.Identity
79 |
80 | cond_stage_config:
81 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder
82 | params:
83 | freeze: true
84 | layer: "penultimate"
85 |
86 | img_cond_stage_config:
87 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
88 | params:
89 | freeze: true
90 |
91 | image_proj_stage_config:
92 | target: custom_nodes.ComfyUI-DynamiCrafter.lvdm.modules.encoders.resampler.Resampler
93 | params:
94 | dim: 1024
95 | depth: 4
96 | dim_head: 64
97 | heads: 12
98 | num_queries: 16
99 | embedding_dim: 1280
100 | output_dim: 1024
101 | ff_mult: 4
102 | video_length: 16
103 |
104 |
--------------------------------------------------------------------------------
/gradio_app.py:
--------------------------------------------------------------------------------
1 | import os, argparse
2 | import sys
3 | import gradio as gr
4 | from scripts.gradio.i2v_test import Image2Video
5 | sys.path.insert(1, os.path.join(sys.path[0], 'lvdm'))
6 |
7 | i2v_examples_1024 = [
8 | ['prompts/1024/astronaut04.png', 'a man in an astronaut suit playing a guitar', 50, 7.5, 1.0, 6, 123],
9 | ['prompts/1024/bloom01.png', 'time-lapse of a blooming flower with leaves and a stem', 50, 7.5, 1.0, 10, 123],
10 | ['prompts/1024/girl07.png', 'a beautiful woman with long hair and a dress blowing in the wind', 50, 7.5, 1.0, 10, 123],
11 | ['prompts/1024/pour_bear.png', 'pouring beer into a glass of ice and beer', 50, 7.5, 1.0, 10, 123],
12 | ['prompts/1024/robot01.png', 'a robot is walking through a destroyed city', 50, 7.5, 1.0, 10, 123],
13 | ['prompts/1024/firework03.png', 'fireworks display', 50, 7.5, 1.0, 10, 123],
14 | ]
15 |
16 | i2v_examples_512 = [
17 | ['prompts/512/bloom01.png', 'time-lapse of a blooming flower with leaves and a stem', 50, 7.5, 1.0, 24, 123],
18 | ['prompts/512/campfire.png', 'a bonfire is lit in the middle of a field', 50, 7.5, 1.0, 24, 123],
19 | ['prompts/512/isometric.png', 'rotating view, small house', 50, 7.5, 1.0, 24, 123],
20 | ['prompts/512/girl08.png', 'a woman looking out in the rain', 50, 7.5, 1.0, 24, 1234],
21 | ['prompts/512/ship02.png', 'a sailboat sailing in rough seas with a dramatic sunset', 50, 7.5, 1.0, 24, 123],
22 | ['prompts/512/zreal_penguin.png', 'a group of penguins walking on a beach', 50, 7.5, 1.0, 20, 123],
23 | ]
24 |
25 | i2v_examples_256 = [
26 | ['prompts/256/art.png', 'man fishing in a boat at sunset', 50, 7.5, 1.0, 3, 234],
27 | ['prompts/256/boy.png', 'boy walking on the street', 50, 7.5, 1.0, 3, 125],
28 | ['prompts/256/dance1.jpeg', 'two people dancing', 50, 7.5, 1.0, 3, 116],
29 | ['prompts/256/fire_and_beach.jpg', 'a campfire on the beach and the ocean waves in the background', 50, 7.5, 1.0, 3, 111],
30 | ['prompts/256/girl3.jpeg', 'girl talking and blinking', 50, 7.5, 1.0, 3, 111],
31 | ['prompts/256/guitar0.jpeg', 'bear playing guitar happily, snowing', 50, 7.5, 1.0, 3, 122]
32 | ]
33 |
34 |
35 | def dynamicrafter_demo(result_dir='./tmp/', res=1024):
36 | if res == 1024:
37 | resolution = '576_1024'
38 | css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024px; max-height:576px}"""
39 | elif res == 512:
40 | resolution = '320_512'
41 | css = """#input_img {max-width: 512px !important} #output_vid {max-width: 512px; max-height: 320px}"""
42 | elif res == 256:
43 | resolution = '256_256'
44 | css = """#input_img {max-width: 256px !important} #output_vid {max-width: 256px; max-height: 256px}"""
45 | else:
46 | raise NotImplementedError(f"Unsupported resolution: {res}")
47 | image2video = Image2Video(result_dir, resolution=resolution)
48 | with gr.Blocks(analytics_enabled=False, css=css) as dynamicrafter_iface:
49 | gr.Markdown("...")  # multi-line HTML header banner (original file lines 49-60) elided in this dump
61 |
62 | #######image2video######
63 | if res == 1024:
64 | with gr.Tab(label='Image2Video_576x1024'):
65 | with gr.Column():
66 | with gr.Row():
67 | with gr.Column():
68 | with gr.Row():
69 | i2v_input_image = gr.Image(label="Input Image",elem_id="input_img")
70 | with gr.Row():
71 | i2v_input_text = gr.Text(label='Prompts')
72 | with gr.Row():
73 | i2v_seed = gr.Slider(label='Random Seed', minimum=0, maximum=10000, step=1, value=123)
74 | i2v_eta = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label='ETA', value=1.0, elem_id="i2v_eta")
75 | i2v_cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='CFG Scale', value=7.5, elem_id="i2v_cfg_scale")
76 | with gr.Row():
77 | i2v_steps = gr.Slider(minimum=1, maximum=60, step=1, elem_id="i2v_steps", label="Sampling steps", value=50)
78 | i2v_motion = gr.Slider(minimum=5, maximum=20, step=1, elem_id="i2v_motion", label="FPS", value=10)
79 | i2v_end_btn = gr.Button("Generate")
80 | # with gr.Tab(label='Result'):
81 | with gr.Row():
82 | i2v_output_video = gr.Video(label="Generated Video",elem_id="output_vid",autoplay=True,show_share_button=True)
83 |
84 | gr.Examples(examples=i2v_examples_1024,
85 | inputs=[i2v_input_image, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_eta, i2v_motion, i2v_seed],
86 | outputs=[i2v_output_video],
87 | fn = image2video.get_image,
88 | cache_examples=False,
89 | )
90 | i2v_end_btn.click(inputs=[i2v_input_image, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_eta, i2v_motion, i2v_seed],
91 | outputs=[i2v_output_video],
92 | fn = image2video.get_image
93 | )
94 | elif res == 512:
95 | with gr.Tab(label='Image2Video_320x512'):
96 | with gr.Column():
97 | with gr.Row():
98 | with gr.Column():
99 | with gr.Row():
100 | i2v_input_image = gr.Image(label="Input Image",elem_id="input_img")
101 | with gr.Row():
102 | i2v_input_text = gr.Text(label='Prompts')
103 | with gr.Row():
104 | i2v_seed = gr.Slider(label='Random Seed', minimum=0, maximum=10000, step=1, value=123)
105 | i2v_eta = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label='ETA', value=1.0, elem_id="i2v_eta")
106 | i2v_cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='CFG Scale', value=7.5, elem_id="i2v_cfg_scale")
107 | with gr.Row():
108 | i2v_steps = gr.Slider(minimum=1, maximum=60, step=1, elem_id="i2v_steps", label="Sampling steps", value=50)
109 | i2v_motion = gr.Slider(minimum=15, maximum=30, step=1, elem_id="i2v_motion", label="FPS", value=24)
110 | i2v_end_btn = gr.Button("Generate")
111 | # with gr.Tab(label='Result'):
112 | with gr.Row():
113 | i2v_output_video = gr.Video(label="Generated Video",elem_id="output_vid",autoplay=True,show_share_button=True)
114 |
115 | gr.Examples(examples=i2v_examples_512,
116 | inputs=[i2v_input_image, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_eta, i2v_motion, i2v_seed],
117 | outputs=[i2v_output_video],
118 | fn = image2video.get_image,
119 | cache_examples=False,
120 | )
121 | i2v_end_btn.click(inputs=[i2v_input_image, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_eta, i2v_motion, i2v_seed],
122 | outputs=[i2v_output_video],
123 | fn = image2video.get_image
124 | )
125 | elif res == 256:
126 | with gr.Tab(label='Image2Video_256x256'):
127 | with gr.Column():
128 | with gr.Row():
129 | with gr.Column():
130 | with gr.Row():
131 | i2v_input_image = gr.Image(label="Input Image",elem_id="input_img")
132 | with gr.Row():
133 | i2v_input_text = gr.Text(label='Prompts')
134 | with gr.Row():
135 | i2v_seed = gr.Slider(label='Random Seed', minimum=0, maximum=10000, step=1, value=123)
136 | i2v_eta = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label='ETA', value=1.0, elem_id="i2v_eta")
137 | i2v_cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='CFG Scale', value=7.5, elem_id="i2v_cfg_scale")
138 | with gr.Row():
139 | i2v_steps = gr.Slider(minimum=1, maximum=60, step=1, elem_id="i2v_steps", label="Sampling steps", value=50)
140 | i2v_motion = gr.Slider(minimum=1, maximum=4, step=1, elem_id="i2v_motion", label="Motion magnitude", value=3)
141 | i2v_end_btn = gr.Button("Generate")
142 | # with gr.Tab(label='Result'):
143 | with gr.Row():
144 | i2v_output_video = gr.Video(label="Generated Video",elem_id="output_vid",autoplay=True,show_share_button=True)
145 |
146 | gr.Examples(examples=i2v_examples_256,
147 | inputs=[i2v_input_image, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_eta, i2v_motion, i2v_seed],
148 | outputs=[i2v_output_video],
149 | fn = image2video.get_image,
150 | cache_examples=False,
151 | )
152 | i2v_end_btn.click(inputs=[i2v_input_image, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_eta, i2v_motion, i2v_seed],
153 | outputs=[i2v_output_video],
154 | fn = image2video.get_image
155 | )
156 |
157 | return dynamicrafter_iface
158 |
159 | def get_parser():
160 | parser = argparse.ArgumentParser()
161 | parser.add_argument("--res", type=int, default=1024, choices=[1024,512,256], help="select the model based on the required resolution")
162 |
163 | return parser
164 |
165 | if __name__ == "__main__":
166 | parser = get_parser()
167 | args = parser.parse_args()
168 |
169 | result_dir = os.path.join('./', 'results')
170 | dynamicrafter_iface = dynamicrafter_demo(result_dir, args.res)
171 | dynamicrafter_iface.queue(max_size=12)
172 | dynamicrafter_iface.launch(max_threads=1)
173 | # dynamicrafter_iface.launch(server_name='0.0.0.0', server_port=80, max_threads=1)
--------------------------------------------------------------------------------
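As a usage note, the demo above can also be started programmatically rather than through the `--res` flag; a minimal sketch, assuming it is run from the repo root (checkpoint setup is handled inside `Image2Video` in `scripts/gradio/i2v_test.py`, which is not shown in this dump):

```python
# Launch the 320x512 variant of the Gradio demo defined in gradio_app.py.
from gradio_app import dynamicrafter_demo

iface = dynamicrafter_demo(result_dir="./results/", res=512)
iface.queue(max_size=12)
iface.launch(max_threads=1)
```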
/lvdm/basics.py:
--------------------------------------------------------------------------------
1 | # adapted from
2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3 | # and
4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5 | # and
6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7 | #
8 | # thanks!
9 |
10 | import torch.nn as nn
11 | from ..utils.utils import instantiate_from_config
12 |
13 |
14 | def disabled_train(self, mode=True):
15 | """Overwrite model.train with this function to make sure train/eval mode
16 | does not change anymore."""
17 | return self
18 |
19 | def zero_module(module):
20 | """
21 | Zero out the parameters of a module and return it.
22 | """
23 | for p in module.parameters():
24 | p.detach().zero_()
25 | return module
26 |
27 | def scale_module(module, scale):
28 | """
29 | Scale the parameters of a module and return it.
30 | """
31 | for p in module.parameters():
32 | p.detach().mul_(scale)
33 | return module
34 |
35 |
36 | def conv_nd(dims, *args, **kwargs):
37 | """
38 | Create a 1D, 2D, or 3D convolution module.
39 | """
40 | if dims == 1:
41 | return nn.Conv1d(*args, **kwargs)
42 | elif dims == 2:
43 | return nn.Conv2d(*args, **kwargs)
44 | elif dims == 3:
45 | return nn.Conv3d(*args, **kwargs)
46 | raise ValueError(f"unsupported dimensions: {dims}")
47 |
48 |
49 | def linear(*args, **kwargs):
50 | """
51 | Create a linear module.
52 | """
53 | return nn.Linear(*args, **kwargs)
54 |
55 |
56 | def avg_pool_nd(dims, *args, **kwargs):
57 | """
58 | Create a 1D, 2D, or 3D average pooling module.
59 | """
60 | if dims == 1:
61 | return nn.AvgPool1d(*args, **kwargs)
62 | elif dims == 2:
63 | return nn.AvgPool2d(*args, **kwargs)
64 | elif dims == 3:
65 | return nn.AvgPool3d(*args, **kwargs)
66 | raise ValueError(f"unsupported dimensions: {dims}")
67 |
68 |
69 | def nonlinearity(type='silu'):
70 | if type == 'silu':
71 | return nn.SiLU()
72 | elif type == 'leaky_relu':
73 | return nn.LeakyReLU()
74 |
75 |
76 | class GroupNormSpecific(nn.GroupNorm):
77 | def forward(self, x):
78 | return super().forward(x.float()).type(x.dtype)
79 |
80 |
81 | def normalization(channels, num_groups=32):
82 | """
83 | Make a standard normalization layer.
84 | :param channels: number of input channels.
85 | :return: an nn.Module for normalization.
86 | """
87 | return GroupNormSpecific(num_groups, channels)
88 |
89 |
90 | class HybridConditioner(nn.Module):
91 |
92 | def __init__(self, c_concat_config, c_crossattn_config):
93 | super().__init__()
94 | self.concat_conditioner = instantiate_from_config(c_concat_config)
95 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
96 |
97 | def forward(self, c_concat, c_crossattn):
98 | c_concat = self.concat_conditioner(c_concat)
99 | c_crossattn = self.crossattn_conditioner(c_crossattn)
100 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
--------------------------------------------------------------------------------
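A short usage sketch of the helpers above, showing the zero-initialized projection pattern diffusion U-Nets use so a new branch starts out as an identity mapping (shapes are illustrative; the functions are re-declared only to keep the snippet self-contained):

```python
import torch
import torch.nn as nn

def conv_nd(dims, *args, **kwargs):
    # Mirrors lvdm.basics.conv_nd: pick Conv1d/2d/3d by dimensionality.
    return {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}[dims](*args, **kwargs)

def zero_module(module):
    # Mirrors lvdm.basics.zero_module: zero every parameter in place.
    for p in module.parameters():
        p.detach().zero_()
    return module

proj = zero_module(conv_nd(2, 320, 4, 3, padding=1))   # 3x3 conv, weights and bias all zero
x = torch.randn(2, 320, 40, 64)
assert proj(x).abs().max().item() == 0.0               # output starts at exactly zero
```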
/lvdm/common.py:
--------------------------------------------------------------------------------
1 | import math
2 | from inspect import isfunction
3 | import torch
4 | from torch import nn
5 | import torch.distributed as dist
6 |
7 |
8 | def gather_data(data, return_np=True):
9 | ''' gather data from multiple processes to one list '''
10 | data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size())]
11 | dist.all_gather(data_list, data) # gather not supported with NCCL
12 | if return_np:
13 | data_list = [data.cpu().numpy() for data in data_list]
14 | return data_list
15 |
16 | def autocast(f):
17 | def do_autocast(*args, **kwargs):
18 | with torch.cuda.amp.autocast(enabled=True,
19 | dtype=torch.get_autocast_gpu_dtype(),
20 | cache_enabled=torch.is_autocast_cache_enabled()):
21 | return f(*args, **kwargs)
22 | return do_autocast
23 |
24 |
25 | def extract_into_tensor(a, t, x_shape):
26 | b, *_ = t.shape
27 | out = a.gather(-1, t)
28 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
29 |
30 |
31 | def noise_like(shape, device, repeat=False):
32 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
33 | noise = lambda: torch.randn(shape, device=device)
34 | return repeat_noise() if repeat else noise()
35 |
36 |
37 | def default(val, d):
38 | if exists(val):
39 | return val
40 | return d() if isfunction(d) else d
41 |
42 | def exists(val):
43 | return val is not None
44 |
45 | def identity(*args, **kwargs):
46 | return nn.Identity()
47 |
48 | def uniq(arr):
49 | return{el: True for el in arr}.keys()
50 |
51 | def mean_flat(tensor):
52 | """
53 | Take the mean over all non-batch dimensions.
54 | """
55 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
56 |
57 | def ismap(x):
58 | if not isinstance(x, torch.Tensor):
59 | return False
60 | return (len(x.shape) == 4) and (x.shape[1] > 3)
61 |
62 | def isimage(x):
63 | if not isinstance(x,torch.Tensor):
64 | return False
65 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
66 |
67 | def max_neg_value(t):
68 | return -torch.finfo(t.dtype).max
69 |
70 | def shape_to_str(x):
71 | shape_str = "x".join([str(x) for x in x.shape])
72 | return shape_str
73 |
74 | def init_(tensor):
75 | dim = tensor.shape[-1]
76 | std = 1 / math.sqrt(dim)
77 | tensor.uniform_(-std, std)
78 | return tensor
79 |
80 | ckpt = torch.utils.checkpoint.checkpoint
81 | def checkpoint(func, inputs, params, flag):
82 | """
83 | Evaluate a function without caching intermediate activations, allowing for
84 | reduced memory at the expense of extra compute in the backward pass.
85 | :param func: the function to evaluate.
86 | :param inputs: the argument sequence to pass to `func`.
87 | :param params: a sequence of parameters `func` depends on but does not
88 | explicitly take as arguments.
89 | :param flag: if False, disable gradient checkpointing.
90 | """
91 | if flag:
92 | return ckpt(func, *inputs, use_reentrant=False)
93 | else:
94 | return func(*inputs)
--------------------------------------------------------------------------------
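For illustration, `extract_into_tensor` above gathers one schedule value per batch element and reshapes it so it broadcasts over a latent; a minimal sketch with an illustrative schedule (the function is repeated verbatim so the snippet runs standalone):

```python
import torch

def extract_into_tensor(a, t, x_shape):
    # Same as lvdm.common.extract_into_tensor: a[t] reshaped to (b, 1, 1, ...).
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

alphas_cumprod = torch.linspace(0.99, 0.01, 1000)      # illustrative noise schedule
x = torch.randn(2, 4, 16, 40, 64)                      # (b, c, t, h, w) video latent
t = torch.tensor([10, 500])                            # one timestep per batch element
scale = extract_into_tensor(alphas_cumprod.sqrt(), t, x.shape)
print(scale.shape)                                     # torch.Size([2, 1, 1, 1, 1])
# Forward-diffusion style mixing: sqrt(a_bar)*x + sqrt(1 - a_bar)*noise.
noised = scale * x + (1 - extract_into_tensor(alphas_cumprod, t, x.shape)).sqrt() * torch.randn_like(x)
```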
/lvdm/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self, noise=None):
36 | if noise is None:
37 | noise = torch.randn(self.mean.shape)
38 |
39 | x = self.mean + self.std * noise.to(device=self.parameters.device)
40 | return x
41 |
42 | def kl(self, other=None):
43 | if self.deterministic:
44 | return torch.Tensor([0.])
45 | else:
46 | if other is None:
47 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
48 | + self.var - 1.0 - self.logvar,
49 | dim=[1, 2, 3])
50 | else:
51 | return 0.5 * torch.sum(
52 | torch.pow(self.mean - other.mean, 2) / other.var
53 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
54 | dim=[1, 2, 3])
55 |
56 | def nll(self, sample, dims=[1,2,3]):
57 | if self.deterministic:
58 | return torch.Tensor([0.])
59 | logtwopi = np.log(2.0 * np.pi)
60 | return 0.5 * torch.sum(
61 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
62 | dim=dims)
63 |
64 | def mode(self):
65 | return self.mean
66 |
67 |
68 | def normal_kl(mean1, logvar1, mean2, logvar2):
69 | """
70 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
71 | Compute the KL divergence between two gaussians.
72 | Shapes are automatically broadcasted, so batches can be compared to
73 | scalars, among other use cases.
74 | """
75 | tensor = None
76 | for obj in (mean1, logvar1, mean2, logvar2):
77 | if isinstance(obj, torch.Tensor):
78 | tensor = obj
79 | break
80 | assert tensor is not None, "at least one argument must be a Tensor"
81 |
82 | # Force variances to be Tensors. Broadcasting helps convert scalars to
83 | # Tensors, but it does not work for torch.exp().
84 | logvar1, logvar2 = [
85 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
86 | for x in (logvar1, logvar2)
87 | ]
88 |
89 | return 0.5 * (
90 | -1.0
91 | + logvar2
92 | - logvar1
93 | + torch.exp(logvar1 - logvar2)
94 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
95 | )
--------------------------------------------------------------------------------
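A small usage sketch of `DiagonalGaussianDistribution`: the VAE's `quant_conv` emits `2 * embed_dim` channels, which the constructor splits into mean and log-variance along dim 1. This assumes the repo root is on `sys.path` so `lvdm` resolves as a (namespace) package; otherwise copy the class from above:

```python
import torch
from lvdm.distributions import DiagonalGaussianDistribution  # class shown above

moments = torch.randn(1, 8, 40, 64)           # 2 * embed_dim = 8 channels of (mean, logvar)
posterior = DiagonalGaussianDistribution(moments)
z = posterior.sample()                         # (1, 4, 40, 64) reparameterized latent
print(posterior.mode().shape)                  # deterministic alternative: just the mean
print(posterior.kl().shape)                    # per-sample KL vs. a standard normal, shape (1,)
```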
/lvdm/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError('Decay must be between 0 and 1')
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
14 | else torch.tensor(-1,dtype=torch.int))
15 |
16 | for name, p in model.named_parameters():
17 | if p.requires_grad:
18 | #remove as '.'-character is not allowed in buffers
19 | s_name = name.replace('.','')
20 | self.m_name2s_name.update({name:s_name})
21 | self.register_buffer(s_name,p.clone().detach().data)
22 |
23 | self.collected_params = []
24 |
25 | def forward(self,model):
26 | decay = self.decay
27 |
28 | if self.num_updates >= 0:
29 | self.num_updates += 1
30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
31 |
32 | one_minus_decay = 1.0 - decay
33 |
34 | with torch.no_grad():
35 | m_param = dict(model.named_parameters())
36 | shadow_params = dict(self.named_buffers())
37 |
38 | for key in m_param:
39 | if m_param[key].requires_grad:
40 | sname = self.m_name2s_name[key]
41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
43 | else:
44 | assert not key in self.m_name2s_name
45 |
46 | def copy_to(self, model):
47 | m_param = dict(model.named_parameters())
48 | shadow_params = dict(self.named_buffers())
49 | for key in m_param:
50 | if m_param[key].requires_grad:
51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
52 | else:
53 | assert not key in self.m_name2s_name
54 |
55 | def store(self, parameters):
56 | """
57 | Save the current parameters for restoring later.
58 | Args:
59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
60 | temporarily stored.
61 | """
62 | self.collected_params = [param.clone() for param in parameters]
63 |
64 | def restore(self, parameters):
65 | """
66 | Restore the parameters stored with the `store` method.
67 | Useful to validate the model with EMA parameters without affecting the
68 | original optimization process. Store the parameters before the
69 | `copy_to` method. After validation (or model saving), use this to
70 | restore the former parameters.
71 | Args:
72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
73 | updated with the stored parameters.
74 | """
75 | for c_param, param in zip(self.collected_params, parameters):
76 | param.data.copy_(c_param.data)
--------------------------------------------------------------------------------
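A sketch of the update/validate cycle `LitEma` supports (same import caveat as above; the model here is an arbitrary stand-in):

```python
import torch.nn as nn
from lvdm.ema import LitEma  # class shown above

model = nn.Linear(8, 8)
ema = LitEma(model, decay=0.9999)

ema(model)                       # after each optimizer step: update the shadow weights

ema.store(model.parameters())    # stash the live training weights...
ema.copy_to(model)               # ...swap in the EMA weights for evaluation
# ... run validation here ...
ema.restore(model.parameters())  # then put the training weights back
```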
/lvdm/models/autoencoder.py:
--------------------------------------------------------------------------------
1 | import os
2 | from contextlib import contextmanager
3 | import torch
4 | import numpy as np
5 | from einops import rearrange
6 | import torch.nn.functional as F
7 | import pytorch_lightning as pl
8 | from ..modules.networks.ae_modules import Encoder, Decoder
9 | from ..distributions import DiagonalGaussianDistribution
10 | from ...utils.utils import instantiate_from_config
11 |
12 |
13 | class AutoencoderKL(pl.LightningModule):
14 | def __init__(self,
15 | ddconfig,
16 | lossconfig,
17 | embed_dim,
18 | ckpt_path=None,
19 | ignore_keys=[],
20 | image_key="image",
21 | colorize_nlabels=None,
22 | monitor=None,
23 | test=False,
24 | logdir=None,
25 | input_dim=4,
26 | test_args=None,
27 | ):
28 | super().__init__()
29 | self.image_key = image_key
30 | self.encoder = Encoder(**ddconfig)
31 | self.decoder = Decoder(**ddconfig)
32 | self.loss = instantiate_from_config(lossconfig)
33 | assert ddconfig["double_z"]
34 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
35 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
36 | self.embed_dim = embed_dim
37 | self.input_dim = input_dim
38 | self.test = test
39 | self.test_args = test_args
40 | self.logdir = logdir
41 | if colorize_nlabels is not None:
42 | assert type(colorize_nlabels)==int
43 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
44 | if monitor is not None:
45 | self.monitor = monitor
46 | if ckpt_path is not None:
47 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
48 | if self.test:
49 | self.init_test()
50 |
51 | def init_test(self,):
52 | self.test = True
53 | save_dir = os.path.join(self.logdir, "test")
54 | if 'ckpt' in self.test_args:
55 | ckpt_name = os.path.basename(self.test_args.ckpt).split('.ckpt')[0] + f'_epoch{self._cur_epoch}'
56 | self.root = os.path.join(save_dir, ckpt_name)
57 | else:
58 | self.root = save_dir
59 | if 'test_subdir' in self.test_args:
60 | self.root = os.path.join(save_dir, self.test_args.test_subdir)
61 |
62 | self.root_zs = os.path.join(self.root, "zs")
63 | self.root_dec = os.path.join(self.root, "reconstructions")
64 | self.root_inputs = os.path.join(self.root, "inputs")
65 | os.makedirs(self.root, exist_ok=True)
66 |
67 | if self.test_args.save_z:
68 | os.makedirs(self.root_zs, exist_ok=True)
69 | if self.test_args.save_reconstruction:
70 | os.makedirs(self.root_dec, exist_ok=True)
71 | if self.test_args.save_input:
72 | os.makedirs(self.root_inputs, exist_ok=True)
73 | assert(self.test_args is not None)
74 | self.test_maximum = getattr(self.test_args, 'test_maximum', None)
75 | self.count = 0
76 | self.eval_metrics = {}
77 | self.decodes = []
78 | self.save_decode_samples = 2048
79 |
80 | def init_from_ckpt(self, path, ignore_keys=list()):
81 | sd = torch.load(path, map_location="cpu")
82 | try:
83 | self._cur_epoch = sd['epoch']
84 | sd = sd["state_dict"]
85 | except:
86 | self._cur_epoch = 'null'
87 | keys = list(sd.keys())
88 | for k in keys:
89 | for ik in ignore_keys:
90 | if k.startswith(ik):
91 | print("Deleting key {} from state_dict.".format(k))
92 | del sd[k]
93 | self.load_state_dict(sd, strict=False)
94 | # self.load_state_dict(sd, strict=True)
95 | print(f"Restored from {path}")
96 |
97 | def encode(self, x, **kwargs):
98 |
99 | h = self.encoder(x)
100 | moments = self.quant_conv(h)
101 | posterior = DiagonalGaussianDistribution(moments)
102 | return posterior
103 |
104 | def decode(self, z, **kwargs):
105 | z = self.post_quant_conv(z)
106 | dec = self.decoder(z)
107 | return dec
108 |
109 | def forward(self, input, sample_posterior=True):
110 | posterior = self.encode(input)
111 | if sample_posterior:
112 | z = posterior.sample()
113 | else:
114 | z = posterior.mode()
115 | dec = self.decode(z)
116 | return dec, posterior
117 |
118 | def get_input(self, batch, k):
119 | x = batch[k]
120 | if x.dim() == 5 and self.input_dim == 4:
121 | b,c,t,h,w = x.shape
122 | self.b = b
123 | self.t = t
124 | x = rearrange(x, 'b c t h w -> (b t) c h w')
125 |
126 | return x
127 |
128 | def training_step(self, batch, batch_idx, optimizer_idx):
129 | inputs = self.get_input(batch, self.image_key)
130 | reconstructions, posterior = self(inputs)
131 |
132 | if optimizer_idx == 0:
133 | # train encoder+decoder+logvar
134 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
135 | last_layer=self.get_last_layer(), split="train")
136 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
137 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
138 | return aeloss
139 |
140 | if optimizer_idx == 1:
141 | # train the discriminator
142 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
143 | last_layer=self.get_last_layer(), split="train")
144 |
145 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
146 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
147 | return discloss
148 |
149 | def validation_step(self, batch, batch_idx):
150 | inputs = self.get_input(batch, self.image_key)
151 | reconstructions, posterior = self(inputs)
152 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
153 | last_layer=self.get_last_layer(), split="val")
154 |
155 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
156 | last_layer=self.get_last_layer(), split="val")
157 |
158 | self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
159 | self.log_dict(log_dict_ae)
160 | self.log_dict(log_dict_disc)
161 | return self.log_dict
162 |
163 | def configure_optimizers(self):
164 | lr = self.learning_rate
165 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
166 | list(self.decoder.parameters())+
167 | list(self.quant_conv.parameters())+
168 | list(self.post_quant_conv.parameters()),
169 | lr=lr, betas=(0.5, 0.9))
170 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
171 | lr=lr, betas=(0.5, 0.9))
172 | return [opt_ae, opt_disc], []
173 |
174 | def get_last_layer(self):
175 | return self.decoder.conv_out.weight
176 |
177 | @torch.no_grad()
178 | def log_images(self, batch, only_inputs=False, **kwargs):
179 | log = dict()
180 | x = self.get_input(batch, self.image_key)
181 | x = x.to(self.device)
182 | if not only_inputs:
183 | xrec, posterior = self(x)
184 | if x.shape[1] > 3:
185 | # colorize with random projection
186 | assert xrec.shape[1] > 3
187 | x = self.to_rgb(x)
188 | xrec = self.to_rgb(xrec)
189 | log["samples"] = self.decode(torch.randn_like(posterior.sample()))
190 | log["reconstructions"] = xrec
191 | log["inputs"] = x
192 | return log
193 |
194 | def to_rgb(self, x):
195 | assert self.image_key == "segmentation"
196 | if not hasattr(self, "colorize"):
197 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
198 | x = F.conv2d(x, weight=self.colorize)
199 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
200 | return x
201 |
202 | class IdentityFirstStage(torch.nn.Module):
203 | def __init__(self, *args, vq_interface=False, **kwargs):
204 |         super().__init__()
205 |         self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
206 |
207 | def encode(self, x, *args, **kwargs):
208 | return x
209 |
210 | def decode(self, x, *args, **kwargs):
211 | return x
212 |
213 | def quantize(self, x, *args, **kwargs):
214 | if self.vq_interface:
215 | return x, None, [None, None, None]
216 | return x
217 |
218 | def forward(self, x, *args, **kwargs):
219 | return x
--------------------------------------------------------------------------------
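The encode()/decode() pair above is the interface the samplers rely on: encode() returns a DiagonalGaussianDistribution whose sample (or mode) is the latent, and decode() maps a latent back to pixel space. A minimal round-trip sketch, assuming `ae` is an already-constructed instance of the autoencoder class in this file with weights restored via init_from_ckpt (the name `ae` and the [-1, 1] input range are assumptions for illustration):

import torch

@torch.no_grad()
def autoencode_roundtrip(ae, images):
    # images: (B, 3, H, W), assumed to be in [-1, 1]
    posterior = ae.encode(images)   # DiagonalGaussianDistribution over the latent
    z = posterior.mode()            # deterministic latent; use .sample() for stochastic encoding
    recon = ae.decode(z)            # reconstruction back in pixel space, (B, 3, H, W)
    return z, recon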
/lvdm/models/samplers/ddim.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tqdm import tqdm
3 | import torch
4 | from ..utils_diffusion import make_ddim_sampling_parameters, make_ddim_timesteps, rescale_noise_cfg
5 | from ...common import noise_like
6 | from ...common import extract_into_tensor
7 | import copy
8 | import comfy.utils
9 |
10 | class DDIMSampler(object):
11 | def __init__(self, model, schedule="linear", **kwargs):
12 | super().__init__()
13 | self.model = model
14 | self.ddpm_num_timesteps = model.num_timesteps
15 | self.schedule = schedule
16 | self.counter = 0
17 |
18 | def register_buffer(self, name, attr):
19 | if type(attr) == torch.Tensor:
20 | if attr.device != torch.device("cuda"):
21 | attr = attr.to(torch.device("cuda"))
22 | setattr(self, name, attr)
23 |
24 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
25 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
26 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
27 | alphas_cumprod = self.model.alphas_cumprod
28 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
29 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
30 |
31 | if self.model.use_dynamic_rescale:
32 | self.ddim_scale_arr = self.model.scale_arr[self.ddim_timesteps]
33 | self.ddim_scale_arr_prev = torch.cat([self.ddim_scale_arr[0:1], self.ddim_scale_arr[:-1]])
34 |
35 | self.register_buffer('betas', to_torch(self.model.betas))
36 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
37 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
38 |
39 | # calculations for diffusion q(x_t | x_{t-1}) and others
40 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
41 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
42 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
43 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
44 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
45 |
46 | # ddim sampling parameters
47 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
48 | ddim_timesteps=self.ddim_timesteps,
49 | eta=ddim_eta,verbose=verbose)
50 | self.register_buffer('ddim_sigmas', ddim_sigmas)
51 | self.register_buffer('ddim_alphas', ddim_alphas)
52 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
53 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
54 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
55 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
56 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
57 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
58 |
59 | @torch.no_grad()
60 | def sample(self,
61 | S,
62 | batch_size,
63 | shape,
64 | conditioning=None,
65 | callback=None,
66 | normals_sequence=None,
67 | img_callback=None,
68 | quantize_x0=False,
69 | eta=0.,
70 | mask=None,
71 | x0=None,
72 | temperature=1.,
73 | noise_dropout=0.,
74 | score_corrector=None,
75 | corrector_kwargs=None,
76 | verbose=True,
77 | schedule_verbose=False,
78 | x_T=None,
79 | log_every_t=100,
80 | unconditional_guidance_scale=1.,
81 | unconditional_conditioning=None,
82 | precision=None,
83 | fs=None,
84 |                timestep_spacing='uniform', # use 'uniform_trailing' to start sampling from the last timestep
85 | guidance_rescale=0.0,
86 | **kwargs
87 | ):
88 |
89 | # check condition bs
90 | if conditioning is not None:
91 | if isinstance(conditioning, dict):
92 | try:
93 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
94 | except:
95 | cbs = conditioning[list(conditioning.keys())[0]][0].shape[0]
96 |
97 | if cbs != batch_size:
98 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
99 | else:
100 | if conditioning.shape[0] != batch_size:
101 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
102 |
103 | self.make_schedule(ddim_num_steps=S, ddim_discretize=timestep_spacing, ddim_eta=eta, verbose=schedule_verbose)
104 |
105 | # make shape
106 | if len(shape) == 3:
107 | C, H, W = shape
108 | size = (batch_size, C, H, W)
109 | elif len(shape) == 4:
110 | C, T, H, W = shape
111 | size = (batch_size, C, T, H, W)
112 |
113 | samples, intermediates = self.ddim_sampling(conditioning, size,
114 | callback=callback,
115 | img_callback=img_callback,
116 | quantize_denoised=quantize_x0,
117 | mask=mask, x0=x0,
118 | ddim_use_original_steps=False,
119 | noise_dropout=noise_dropout,
120 | temperature=temperature,
121 | score_corrector=score_corrector,
122 | corrector_kwargs=corrector_kwargs,
123 | x_T=x_T,
124 | log_every_t=log_every_t,
125 | unconditional_guidance_scale=unconditional_guidance_scale,
126 | unconditional_conditioning=unconditional_conditioning,
127 | verbose=verbose,
128 | precision=precision,
129 | fs=fs,
130 | guidance_rescale=guidance_rescale,
131 | **kwargs)
132 | return samples, intermediates
133 |
134 | @torch.no_grad()
135 | def ddim_sampling(self, cond, shape,
136 | x_T=None, ddim_use_original_steps=False,
137 | callback=None, timesteps=None, quantize_denoised=False,
138 | mask=None, x0=None, img_callback=None, log_every_t=100,
139 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
140 | unconditional_guidance_scale=1., unconditional_conditioning=None, verbose=True,precision=None,fs=None,guidance_rescale=0.0,
141 | **kwargs):
142 | device = self.model.betas.device
143 | b = shape[0]
144 | if x_T is None:
145 | img = torch.randn(shape, device=device)
146 | else:
147 | img = x_T
148 | if precision is not None:
149 | if precision == 16:
150 | img = img.to(dtype=torch.float16)
151 |
152 | if timesteps is None:
153 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
154 | elif timesteps is not None and not ddim_use_original_steps:
155 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
156 | timesteps = self.ddim_timesteps[:subset_end]
157 |
158 | intermediates = {'x_inter': [img], 'pred_x0': [img]}
159 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
160 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
161 | if verbose:
162 | iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
163 | else:
164 | iterator = time_range
165 |
166 | clean_cond = kwargs.pop("clean_cond", False)
167 |
168 | # cond_copy, unconditional_conditioning_copy = copy.deepcopy(cond), copy.deepcopy(unconditional_conditioning)
169 | pbar = comfy.utils.ProgressBar(total_steps)
170 | for i, step in enumerate(iterator):
171 | index = total_steps - i - 1
172 | ts = torch.full((b,), step, device=device, dtype=torch.long)
173 |
174 | ## use mask to blend noised original latent (img_orig) & new sampled latent (img)
175 | if mask is not None:
176 | assert x0 is not None
177 | if clean_cond:
178 | img_orig = x0
179 | else:
180 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
181 |                 img = img_orig * mask + (1. - mask) * img # where mask==1 keep the noised original latent, elsewhere use the newly sampled latent
182 |
183 |
184 |
185 |
186 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
187 | quantize_denoised=quantize_denoised, temperature=temperature,
188 | noise_dropout=noise_dropout, score_corrector=score_corrector,
189 | corrector_kwargs=corrector_kwargs,
190 | unconditional_guidance_scale=unconditional_guidance_scale,
191 | unconditional_conditioning=unconditional_conditioning,
192 | mask=mask,x0=x0,fs=fs,guidance_rescale=guidance_rescale,
193 | **kwargs)
194 |
195 |
196 | img, pred_x0 = outs
197 | if callback: callback(i)
198 | if img_callback: img_callback(pred_x0, i)
199 |
200 | if index % log_every_t == 0 or index == total_steps - 1:
201 | intermediates['x_inter'].append(img)
202 | intermediates['pred_x0'].append(pred_x0)
203 | pbar.update(1)
204 | return img, intermediates
205 |
206 | @torch.no_grad()
207 | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
208 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
209 | unconditional_guidance_scale=1., unconditional_conditioning=None,
210 | uc_type=None, conditional_guidance_scale_temporal=None,mask=None,x0=None,guidance_rescale=0.0,**kwargs):
211 | b, *_, device = *x.shape, x.device
212 | if x.dim() == 5:
213 | is_video = True
214 | else:
215 | is_video = False
216 |
217 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
218 | model_output = self.model.apply_model(x, t, c, **kwargs) # unet denoiser
219 | else:
220 | ### do_classifier_free_guidance
221 | if isinstance(c, torch.Tensor) or isinstance(c, dict):
222 | e_t_cond = self.model.apply_model(x, t, c, **kwargs)
223 | e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning, **kwargs)
224 | else:
225 | raise NotImplementedError
226 |
227 | model_output = e_t_uncond + unconditional_guidance_scale * (e_t_cond - e_t_uncond)
228 |
229 | if guidance_rescale > 0.0:
230 | model_output = rescale_noise_cfg(model_output, e_t_cond, guidance_rescale=guidance_rescale)
231 |
232 | if self.model.parameterization == "v":
233 | e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
234 | else:
235 | e_t = model_output
236 |
237 | if score_corrector is not None:
238 | assert self.model.parameterization == "eps", 'not implemented'
239 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
240 |
241 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
242 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
243 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
244 | # sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
245 | sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
246 | # select parameters corresponding to the currently considered timestep
247 |
248 | if is_video:
249 | size = (b, 1, 1, 1, 1)
250 | else:
251 | size = (b, 1, 1, 1)
252 | a_t = torch.full(size, alphas[index], device=device)
253 | a_prev = torch.full(size, alphas_prev[index], device=device)
254 | sigma_t = torch.full(size, sigmas[index], device=device)
255 | sqrt_one_minus_at = torch.full(size, sqrt_one_minus_alphas[index],device=device)
256 |
257 | # current prediction for x_0
258 | if self.model.parameterization != "v":
259 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
260 | else:
261 | pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
262 |
263 | if self.model.use_dynamic_rescale:
264 | scale_t = torch.full(size, self.ddim_scale_arr[index], device=device)
265 | prev_scale_t = torch.full(size, self.ddim_scale_arr_prev[index], device=device)
266 | rescale = (prev_scale_t / scale_t)
267 | pred_x0 *= rescale
268 |
269 | if quantize_denoised:
270 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
271 | # direction pointing to x_t
272 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
273 |
274 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
275 | if noise_dropout > 0.:
276 | noise = torch.nn.functional.dropout(noise, p=noise_dropout)
277 |
278 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
279 |
280 | return x_prev, pred_x0
281 |
282 | @torch.no_grad()
283 | def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
284 | use_original_steps=False, callback=None):
285 |
286 | timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
287 | timesteps = timesteps[:t_start]
288 |
289 | time_range = np.flip(timesteps)
290 | total_steps = timesteps.shape[0]
291 | print(f"Running DDIM Sampling with {total_steps} timesteps")
292 |
293 | iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
294 | x_dec = x_latent
295 | for i, step in enumerate(iterator):
296 | index = total_steps - i - 1
297 | ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
298 | x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
299 | unconditional_guidance_scale=unconditional_guidance_scale,
300 | unconditional_conditioning=unconditional_conditioning)
301 | if callback: callback(i)
302 | return x_dec
303 |
304 | @torch.no_grad()
305 | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
306 | # fast, but does not allow for exact reconstruction
307 | # t serves as an index to gather the correct alphas
308 | if use_original_steps:
309 | sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
310 | sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
311 | else:
312 | sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
313 | sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
314 |
315 | if noise is None:
316 | noise = torch.randn_like(x0)
317 | return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
318 | extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
319 |
--------------------------------------------------------------------------------
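A hedged sketch of how the DDIMSampler above is typically driven; `model`, `cond`, `uc`, and the latent shape are placeholders for illustration rather than values fixed by this file (the real call sites are the inference and gradio scripts):

sampler = DDIMSampler(model)                 # `model` exposes alphas_cumprod, betas, apply_model, ...
noise_shape = (4, 16, 40, 64)                # (C, T, H, W) latent video shape -- illustrative only
samples, intermediates = sampler.sample(
    S=50,                                    # number of DDIM steps
    batch_size=1,
    shape=noise_shape,
    conditioning=cond,                       # dict or tensor of conditionings
    eta=0.0,                                 # eta=0 gives deterministic DDIM updates
    unconditional_guidance_scale=7.5,
    unconditional_conditioning=uc,
    timestep_spacing='uniform_trailing',     # start from the last DDPM timestep
    guidance_rescale=0.7,
)
# samples: final latent of shape (batch_size, C, T, H, W); intermediates: logged x_inter / pred_x0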
/lvdm/models/samplers/ddim_multiplecond.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tqdm import tqdm
3 | import torch
4 | from ..utils_diffusion import make_ddim_sampling_parameters, make_ddim_timesteps, rescale_noise_cfg
5 | from ...common import noise_like
6 | from ...common import extract_into_tensor
7 | import copy
8 |
9 |
10 | class DDIMSampler(object):
11 | def __init__(self, model, schedule="linear", **kwargs):
12 | super().__init__()
13 | self.model = model
14 | self.ddpm_num_timesteps = model.num_timesteps
15 | self.schedule = schedule
16 | self.counter = 0
17 |
18 | def register_buffer(self, name, attr):
19 | if type(attr) == torch.Tensor:
20 | if attr.device != torch.device("cuda"):
21 | attr = attr.to(torch.device("cuda"))
22 | setattr(self, name, attr)
23 |
24 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
25 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
26 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
27 | alphas_cumprod = self.model.alphas_cumprod
28 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
29 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
30 |
31 | if self.model.use_dynamic_rescale:
32 | self.ddim_scale_arr = self.model.scale_arr[self.ddim_timesteps]
33 | self.ddim_scale_arr_prev = torch.cat([self.ddim_scale_arr[0:1], self.ddim_scale_arr[:-1]])
34 |
35 | self.register_buffer('betas', to_torch(self.model.betas))
36 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
37 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
38 |
39 | # calculations for diffusion q(x_t | x_{t-1}) and others
40 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
41 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
42 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
43 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
44 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
45 |
46 | # ddim sampling parameters
47 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
48 | ddim_timesteps=self.ddim_timesteps,
49 | eta=ddim_eta,verbose=verbose)
50 | self.register_buffer('ddim_sigmas', ddim_sigmas)
51 | self.register_buffer('ddim_alphas', ddim_alphas)
52 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
53 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
54 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
55 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
56 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
57 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
58 |
59 | @torch.no_grad()
60 | def sample(self,
61 | S,
62 | batch_size,
63 | shape,
64 | conditioning=None,
65 | callback=None,
66 | normals_sequence=None,
67 | img_callback=None,
68 | quantize_x0=False,
69 | eta=0.,
70 | mask=None,
71 | x0=None,
72 | temperature=1.,
73 | noise_dropout=0.,
74 | score_corrector=None,
75 | corrector_kwargs=None,
76 | verbose=True,
77 | schedule_verbose=False,
78 | x_T=None,
79 | log_every_t=100,
80 | unconditional_guidance_scale=1.,
81 | unconditional_conditioning=None,
82 | precision=None,
83 | fs=None,
84 |                timestep_spacing='uniform', # use 'uniform_trailing' to start sampling from the last timestep
85 | guidance_rescale=0.0,
86 |                # unconditional_conditioning has to come in the same format as the conditioning, e.g. as encoded tokens
87 | **kwargs
88 | ):
89 |
90 | # check condition bs
91 | if conditioning is not None:
92 | if isinstance(conditioning, dict):
93 | try:
94 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
95 | except:
96 | cbs = conditioning[list(conditioning.keys())[0]][0].shape[0]
97 |
98 | if cbs != batch_size:
99 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
100 | else:
101 | if conditioning.shape[0] != batch_size:
102 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
103 |
104 | # print('==> timestep_spacing: ', timestep_spacing, guidance_rescale)
105 | self.make_schedule(ddim_num_steps=S, ddim_discretize=timestep_spacing, ddim_eta=eta, verbose=schedule_verbose)
106 |
107 | # make shape
108 | if len(shape) == 3:
109 | C, H, W = shape
110 | size = (batch_size, C, H, W)
111 | elif len(shape) == 4:
112 | C, T, H, W = shape
113 | size = (batch_size, C, T, H, W)
114 | # print(f'Data shape for DDIM sampling is {size}, eta {eta}')
115 |
116 | samples, intermediates = self.ddim_sampling(conditioning, size,
117 | callback=callback,
118 | img_callback=img_callback,
119 | quantize_denoised=quantize_x0,
120 | mask=mask, x0=x0,
121 | ddim_use_original_steps=False,
122 | noise_dropout=noise_dropout,
123 | temperature=temperature,
124 | score_corrector=score_corrector,
125 | corrector_kwargs=corrector_kwargs,
126 | x_T=x_T,
127 | log_every_t=log_every_t,
128 | unconditional_guidance_scale=unconditional_guidance_scale,
129 | unconditional_conditioning=unconditional_conditioning,
130 | verbose=verbose,
131 | precision=precision,
132 | fs=fs,
133 | guidance_rescale=guidance_rescale,
134 | **kwargs)
135 | return samples, intermediates
136 |
137 | @torch.no_grad()
138 | def ddim_sampling(self, cond, shape,
139 | x_T=None, ddim_use_original_steps=False,
140 | callback=None, timesteps=None, quantize_denoised=False,
141 | mask=None, x0=None, img_callback=None, log_every_t=100,
142 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
143 | unconditional_guidance_scale=1., unconditional_conditioning=None, verbose=True,precision=None,fs=None,guidance_rescale=0.0,
144 | **kwargs):
145 | device = self.model.betas.device
146 | b = shape[0]
147 | if x_T is None:
148 | img = torch.randn(shape, device=device)
149 | else:
150 | img = x_T
151 | if precision is not None:
152 | if precision == 16:
153 | img = img.to(dtype=torch.float16)
154 |
155 |
156 | if timesteps is None:
157 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
158 | elif timesteps is not None and not ddim_use_original_steps:
159 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
160 | timesteps = self.ddim_timesteps[:subset_end]
161 |
162 | intermediates = {'x_inter': [img], 'pred_x0': [img]}
163 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
164 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
165 | if verbose:
166 | iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
167 | else:
168 | iterator = time_range
169 |
170 | clean_cond = kwargs.pop("clean_cond", False)
171 |
172 | # cond_copy, unconditional_conditioning_copy = copy.deepcopy(cond), copy.deepcopy(unconditional_conditioning)
173 | for i, step in enumerate(iterator):
174 | index = total_steps - i - 1
175 | ts = torch.full((b,), step, device=device, dtype=torch.long)
176 |
177 | ## use mask to blend noised original latent (img_orig) & new sampled latent (img)
178 | if mask is not None:
179 | assert x0 is not None
180 | if clean_cond:
181 | img_orig = x0
182 | else:
183 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
184 |                 img = img_orig * mask + (1. - mask) * img # where mask==1 keep the noised original latent, elsewhere use the newly sampled latent
185 |
186 |
187 |
188 |
189 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
190 | quantize_denoised=quantize_denoised, temperature=temperature,
191 | noise_dropout=noise_dropout, score_corrector=score_corrector,
192 | corrector_kwargs=corrector_kwargs,
193 | unconditional_guidance_scale=unconditional_guidance_scale,
194 | unconditional_conditioning=unconditional_conditioning,
195 | mask=mask,x0=x0,fs=fs,guidance_rescale=guidance_rescale,
196 | **kwargs)
197 |
198 |
199 |
200 | img, pred_x0 = outs
201 | if callback: callback(i)
202 | if img_callback: img_callback(pred_x0, i)
203 |
204 | if index % log_every_t == 0 or index == total_steps - 1:
205 | intermediates['x_inter'].append(img)
206 | intermediates['pred_x0'].append(pred_x0)
207 |
208 | return img, intermediates
209 |
210 | @torch.no_grad()
211 | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
212 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
213 | unconditional_guidance_scale=1., unconditional_conditioning=None,
214 | uc_type=None, cfg_img=None,mask=None,x0=None,guidance_rescale=0.0, **kwargs):
215 | b, *_, device = *x.shape, x.device
216 | if x.dim() == 5:
217 | is_video = True
218 | else:
219 | is_video = False
220 | if cfg_img is None:
221 | cfg_img = unconditional_guidance_scale
222 |
223 | unconditional_conditioning_img_nonetext = kwargs['unconditional_conditioning_img_nonetext']
224 |
225 |
226 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
227 | model_output = self.model.apply_model(x, t, c, **kwargs) # unet denoiser
228 | else:
229 | ### with unconditional condition
230 | e_t_cond = self.model.apply_model(x, t, c, **kwargs)
231 | e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning, **kwargs)
232 | e_t_uncond_img = self.model.apply_model(x, t, unconditional_conditioning_img_nonetext, **kwargs)
233 | # text cfg
234 | model_output = e_t_uncond + cfg_img * (e_t_uncond_img - e_t_uncond) + unconditional_guidance_scale * (e_t_cond - e_t_uncond_img)
235 | if guidance_rescale > 0.0:
236 | model_output = rescale_noise_cfg(model_output, e_t_cond, guidance_rescale=guidance_rescale)
237 |
238 | if self.model.parameterization == "v":
239 | e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
240 | else:
241 | e_t = model_output
242 |
243 | if score_corrector is not None:
244 | assert self.model.parameterization == "eps", 'not implemented'
245 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
246 |
247 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
248 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
249 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
250 | sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
251 | # select parameters corresponding to the currently considered timestep
252 |
253 | if is_video:
254 | size = (b, 1, 1, 1, 1)
255 | else:
256 | size = (b, 1, 1, 1)
257 | a_t = torch.full(size, alphas[index], device=device)
258 | a_prev = torch.full(size, alphas_prev[index], device=device)
259 | sigma_t = torch.full(size, sigmas[index], device=device)
260 | sqrt_one_minus_at = torch.full(size, sqrt_one_minus_alphas[index],device=device)
261 |
262 | # current prediction for x_0
263 | if self.model.parameterization != "v":
264 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
265 | else:
266 | pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
267 |
268 | if self.model.use_dynamic_rescale:
269 | scale_t = torch.full(size, self.ddim_scale_arr[index], device=device)
270 | prev_scale_t = torch.full(size, self.ddim_scale_arr_prev[index], device=device)
271 | rescale = (prev_scale_t / scale_t)
272 | pred_x0 *= rescale
273 |
274 | if quantize_denoised:
275 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
276 | # direction pointing to x_t
277 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
278 |
279 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
280 | if noise_dropout > 0.:
281 | noise = torch.nn.functional.dropout(noise, p=noise_dropout)
282 |
283 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
284 |
285 | return x_prev, pred_x0
286 |
287 | @torch.no_grad()
288 | def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
289 | use_original_steps=False, callback=None):
290 |
291 | timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
292 | timesteps = timesteps[:t_start]
293 |
294 | time_range = np.flip(timesteps)
295 | total_steps = timesteps.shape[0]
296 | print(f"Running DDIM Sampling with {total_steps} timesteps")
297 |
298 | iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
299 | x_dec = x_latent
300 | for i, step in enumerate(iterator):
301 | index = total_steps - i - 1
302 | ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
303 | x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
304 | unconditional_guidance_scale=unconditional_guidance_scale,
305 | unconditional_conditioning=unconditional_conditioning)
306 | if callback: callback(i)
307 | return x_dec
308 |
309 | @torch.no_grad()
310 | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
311 | # fast, but does not allow for exact reconstruction
312 | # t serves as an index to gather the correct alphas
313 | if use_original_steps:
314 | sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
315 | sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
316 | else:
317 | sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
318 | sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
319 |
320 | if noise is None:
321 | noise = torch.randn_like(x0)
322 | return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
323 | extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
--------------------------------------------------------------------------------
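Apart from the extra unconditional_conditioning_img_nonetext kwarg, the substantive difference from ddim.py is the two-scale classifier-free guidance in p_sample_ddim: cfg_img weights the image condition and unconditional_guidance_scale weights the text condition. The combination in isolation, as a small sketch that simply mirrors the expression above (the tensors are placeholders):

def dual_cfg(e_t_uncond, e_t_uncond_img, e_t_cond, cfg_img, cfg_text):
    # start from the fully unconditional prediction, add the image-guidance delta
    # scaled by cfg_img, then the text-guidance delta scaled by cfg_text
    return (e_t_uncond
            + cfg_img * (e_t_uncond_img - e_t_uncond)
            + cfg_text * (e_t_cond - e_t_uncond_img))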
/lvdm/models/utils_diffusion.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | import torch.nn.functional as F
5 | from einops import repeat
6 |
7 |
8 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
9 | """
10 | Create sinusoidal timestep embeddings.
11 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
12 | These may be fractional.
13 | :param dim: the dimension of the output.
14 | :param max_period: controls the minimum frequency of the embeddings.
15 | :return: an [N x dim] Tensor of positional embeddings.
16 | """
17 | if not repeat_only:
18 | half = dim // 2
19 | freqs = torch.exp(
20 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
21 | ).to(device=timesteps.device)
22 | args = timesteps[:, None].float() * freqs[None]
23 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
24 | if dim % 2:
25 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
26 | else:
27 | embedding = repeat(timesteps, 'b -> b d', d=dim)
28 | return embedding
29 |
30 |
31 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
32 | if schedule == "linear":
33 | betas = (
34 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
35 | )
36 |
37 | elif schedule == "cosine":
38 | timesteps = (
39 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
40 | )
41 | alphas = timesteps / (1 + cosine_s) * np.pi / 2
42 | alphas = torch.cos(alphas).pow(2)
43 | alphas = alphas / alphas[0]
44 | betas = 1 - alphas[1:] / alphas[:-1]
45 | betas = np.clip(betas, a_min=0, a_max=0.999)
46 |
47 | elif schedule == "sqrt_linear":
48 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
49 | elif schedule == "sqrt":
50 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
51 | else:
52 | raise ValueError(f"schedule '{schedule}' unknown.")
53 | return betas.numpy()
54 |
55 |
56 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
57 | if ddim_discr_method == 'uniform':
58 | c = num_ddpm_timesteps // num_ddim_timesteps
59 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
60 | steps_out = ddim_timesteps + 1
61 | elif ddim_discr_method == 'uniform_trailing':
62 | c = num_ddpm_timesteps / num_ddim_timesteps
63 | ddim_timesteps = np.flip(np.round(np.arange(num_ddpm_timesteps, 0, -c))).astype(np.int64)
64 | steps_out = ddim_timesteps - 1
65 | elif ddim_discr_method == 'quad':
66 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
67 | steps_out = ddim_timesteps + 1
68 | else:
69 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
70 |
71 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps
72 | # add one to get the final alpha values right (the ones from first scale to data during sampling)
73 | # steps_out = ddim_timesteps + 1
74 | if verbose:
75 | print(f'Selected timesteps for ddim sampler: {steps_out}')
76 | return steps_out
77 |
78 |
79 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
80 | # select alphas for computing the variance schedule
81 | # print(f'ddim_timesteps={ddim_timesteps}, len_alphacums={len(alphacums)}')
82 | alphas = alphacums[ddim_timesteps]
83 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
84 |
85 |     # according to the formula provided in https://arxiv.org/abs/2010.02502
86 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
87 | if verbose:
88 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
89 | print(f'For the chosen value of eta, which is {eta}, '
90 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
91 | return sigmas, alphas, alphas_prev
92 |
93 |
94 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
95 | """
96 | Create a beta schedule that discretizes the given alpha_t_bar function,
97 | which defines the cumulative product of (1-beta) over time from t = [0,1].
98 | :param num_diffusion_timesteps: the number of betas to produce.
99 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
100 | produces the cumulative product of (1-beta) up to that
101 | part of the diffusion process.
102 | :param max_beta: the maximum beta to use; use values lower than 1 to
103 | prevent singularities.
104 | """
105 | betas = []
106 | for i in range(num_diffusion_timesteps):
107 | t1 = i / num_diffusion_timesteps
108 | t2 = (i + 1) / num_diffusion_timesteps
109 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
110 | return np.array(betas)
111 |
112 | def rescale_zero_terminal_snr(betas):
113 | """
114 | Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
115 |
116 | Args:
117 | betas (`numpy.ndarray`):
118 | the betas that the scheduler is being initialized with.
119 |
120 | Returns:
121 | `numpy.ndarray`: rescaled betas with zero terminal SNR
122 | """
123 | # Convert betas to alphas_bar_sqrt
124 | alphas = 1.0 - betas
125 | alphas_cumprod = np.cumprod(alphas, axis=0)
126 | alphas_bar_sqrt = np.sqrt(alphas_cumprod)
127 |
128 | # Store old values.
129 | alphas_bar_sqrt_0 = alphas_bar_sqrt[0].copy()
130 | alphas_bar_sqrt_T = alphas_bar_sqrt[-1].copy()
131 |
132 | # Shift so the last timestep is zero.
133 | alphas_bar_sqrt -= alphas_bar_sqrt_T
134 |
135 | # Scale so the first timestep is back to the old value.
136 | alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
137 |
138 | # Convert alphas_bar_sqrt to betas
139 | alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
140 | alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
141 | alphas = np.concatenate([alphas_bar[0:1], alphas])
142 | betas = 1 - alphas
143 |
144 | return betas
145 |
146 |
147 | def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
148 | """
149 | Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
150 | Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
151 | """
152 | std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
153 | std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
154 | # rescale the results from guidance (fixes overexposure)
155 | noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
156 | # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
157 | noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
158 | return noise_cfg
--------------------------------------------------------------------------------
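To make the two main DDIM discretizations above concrete, a small worked example with 1000 DDPM steps reduced to 10 DDIM steps (the values follow directly from the arithmetic in make_ddim_timesteps):

# 'uniform': range(0, 1000, 100) shifted by +1   ->  [  1 101 201 ... 901]
uniform = make_ddim_timesteps('uniform', num_ddim_timesteps=10,
                              num_ddpm_timesteps=1000, verbose=False)

# 'uniform_trailing': counts down from 1000 in steps of 100, flips, then -1
#                                                ->  [ 99 199 299 ... 999]
# i.e. the trailing variant always includes the final timestep 999
trailing = make_ddim_timesteps('uniform_trailing', num_ddim_timesteps=10,
                               num_ddpm_timesteps=1000, verbose=False)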
/lvdm/modules/encoders/condition.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import kornia
4 | import open_clip
5 | from torch.utils.checkpoint import checkpoint
6 | from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
7 | from ...common import autocast
8 | from ....utils.utils import count_params
9 |
10 |
11 | class AbstractEncoder(nn.Module):
12 | def __init__(self):
13 | super().__init__()
14 |
15 | def encode(self, *args, **kwargs):
16 | raise NotImplementedError
17 |
18 |
19 | class IdentityEncoder(AbstractEncoder):
20 | def encode(self, x):
21 | return x
22 |
23 |
24 | class ClassEmbedder(nn.Module):
25 | def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
26 | super().__init__()
27 | self.key = key
28 | self.embedding = nn.Embedding(n_classes, embed_dim)
29 | self.n_classes = n_classes
30 | self.ucg_rate = ucg_rate
31 |
32 | def forward(self, batch, key=None, disable_dropout=False):
33 | if key is None:
34 | key = self.key
35 | # this is for use in crossattn
36 | c = batch[key][:, None]
37 | if self.ucg_rate > 0. and not disable_dropout:
38 | mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
39 | c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
40 | c = c.long()
41 | c = self.embedding(c)
42 | return c
43 |
44 | def get_unconditional_conditioning(self, bs, device="cuda"):
45 | uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
46 | uc = torch.ones((bs,), device=device) * uc_class
47 | uc = {self.key: uc}
48 | return uc
49 |
50 |
51 | def disabled_train(self, mode=True):
52 | """Overwrite model.train with this function to make sure train/eval mode
53 | does not change anymore."""
54 | return self
55 |
56 |
57 | class FrozenT5Embedder(AbstractEncoder):
58 | """Uses the T5 transformer encoder for text"""
59 |
60 | def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77,
61 | freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
62 | super().__init__()
63 | self.tokenizer = T5Tokenizer.from_pretrained(version)
64 | self.transformer = T5EncoderModel.from_pretrained(version)
65 | self.device = device
66 | self.max_length = max_length # TODO: typical value?
67 | if freeze:
68 | self.freeze()
69 |
70 | def freeze(self):
71 | self.transformer = self.transformer.eval()
72 | # self.train = disabled_train
73 | for param in self.parameters():
74 | param.requires_grad = False
75 |
76 | def forward(self, text):
77 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
78 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
79 | tokens = batch_encoding["input_ids"].to(self.device)
80 | outputs = self.transformer(input_ids=tokens)
81 |
82 | z = outputs.last_hidden_state
83 | return z
84 |
85 | def encode(self, text):
86 | return self(text)
87 |
88 |
89 | class FrozenCLIPEmbedder(AbstractEncoder):
90 | """Uses the CLIP transformer encoder for text (from huggingface)"""
91 | LAYERS = [
92 | "last",
93 | "pooled",
94 | "hidden"
95 | ]
96 |
97 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
98 | freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
99 | super().__init__()
100 | assert layer in self.LAYERS
101 | self.tokenizer = CLIPTokenizer.from_pretrained(version)
102 | self.transformer = CLIPTextModel.from_pretrained(version)
103 | self.device = device
104 | self.max_length = max_length
105 | if freeze:
106 | self.freeze()
107 | self.layer = layer
108 | self.layer_idx = layer_idx
109 | if layer == "hidden":
110 | assert layer_idx is not None
111 | assert 0 <= abs(layer_idx) <= 12
112 |
113 | def freeze(self):
114 | self.transformer = self.transformer.eval()
115 | # self.train = disabled_train
116 | for param in self.parameters():
117 | param.requires_grad = False
118 |
119 | def forward(self, text):
120 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
121 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
122 | tokens = batch_encoding["input_ids"].to(self.device)
123 | outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden")
124 | if self.layer == "last":
125 | z = outputs.last_hidden_state
126 | elif self.layer == "pooled":
127 | z = outputs.pooler_output[:, None, :]
128 | else:
129 | z = outputs.hidden_states[self.layer_idx]
130 | return z
131 |
132 | def encode(self, text):
133 | return self(text)
134 |
135 |
136 | class ClipImageEmbedder(nn.Module):
137 | def __init__(
138 | self,
139 | model,
140 | jit=False,
141 | device='cuda' if torch.cuda.is_available() else 'cpu',
142 | antialias=True,
143 | ucg_rate=0.
144 | ):
145 | super().__init__()
146 | from clip import load as load_clip
147 | self.model, _ = load_clip(name=model, device=device, jit=jit)
148 |
149 | self.antialias = antialias
150 |
151 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
152 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
153 | self.ucg_rate = ucg_rate
154 |
155 | def preprocess(self, x):
156 | # normalize to [0,1]
157 | x = kornia.geometry.resize(x, (224, 224),
158 | interpolation='bicubic', align_corners=True,
159 | antialias=self.antialias)
160 | x = (x + 1.) / 2.
161 | # re-normalize according to clip
162 | x = kornia.enhance.normalize(x, self.mean, self.std)
163 | return x
164 |
165 | def forward(self, x, no_dropout=False):
166 | # x is assumed to be in range [-1,1]
167 | out = self.model.encode_image(self.preprocess(x))
168 | out = out.to(x.dtype)
169 | if self.ucg_rate > 0. and not no_dropout:
170 | out = torch.bernoulli((1. - self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out
171 | return out
172 |
173 |
174 | class FrozenOpenCLIPEmbedder(AbstractEncoder):
175 | """
176 | Uses the OpenCLIP transformer encoder for text
177 | """
178 | LAYERS = [
179 | # "pooled",
180 | "last",
181 | "penultimate"
182 | ]
183 |
184 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
185 | freeze=True, layer="last"):
186 | super().__init__()
187 | assert layer in self.LAYERS
188 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
189 | del model.visual
190 | self.model = model
191 |
192 | self.device = device
193 | self.max_length = max_length
194 | if freeze:
195 | self.freeze()
196 | self.layer = layer
197 | if self.layer == "last":
198 | self.layer_idx = 0
199 | elif self.layer == "penultimate":
200 | self.layer_idx = 1
201 | else:
202 | raise NotImplementedError()
203 |
204 | def freeze(self):
205 | self.model = self.model.eval()
206 | for param in self.parameters():
207 | param.requires_grad = False
208 |
209 | def forward(self, text):
210 | tokens = open_clip.tokenize(text) ## all clip models use 77 as context length
211 | z = self.encode_with_transformer(tokens.to(self.device))
212 | return z
213 |
214 | def encode_with_transformer(self, text):
215 | x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
216 | x = x + self.model.positional_embedding
217 | x = x.permute(1, 0, 2) # NLD -> LND
218 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
219 | x = x.permute(1, 0, 2) # LND -> NLD
220 | x = self.model.ln_final(x)
221 | return x
222 |
223 | def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
224 | for i, r in enumerate(self.model.transformer.resblocks):
225 | if i == len(self.model.transformer.resblocks) - self.layer_idx:
226 | break
227 | if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
228 | x = checkpoint(r, x, attn_mask)
229 | else:
230 | x = r(x, attn_mask=attn_mask)
231 | return x
232 |
233 | def encode(self, text):
234 | return self(text)
235 |
236 |
237 | class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
238 | """
239 | Uses the OpenCLIP vision transformer encoder for images
240 | """
241 |
242 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
243 | freeze=True, layer="pooled", antialias=True, ucg_rate=0.):
244 | super().__init__()
245 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
246 | pretrained=version, )
247 | del model.transformer
248 | self.model = model
249 | # self.mapper = torch.nn.Linear(1280, 1024)
250 | self.device = device
251 | self.max_length = max_length
252 | if freeze:
253 | self.freeze()
254 | self.layer = layer
255 | if self.layer == "penultimate":
256 | raise NotImplementedError()
257 | self.layer_idx = 1
258 |
259 | self.antialias = antialias
260 |
261 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
262 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
263 | self.ucg_rate = ucg_rate
264 |
265 | def preprocess(self, x):
266 | # normalize to [0,1]
267 | x = kornia.geometry.resize(x, (224, 224),
268 | interpolation='bicubic', align_corners=True,
269 | antialias=self.antialias)
270 | x = (x + 1.) / 2.
271 | # renormalize according to clip
272 | x = kornia.enhance.normalize(x, self.mean, self.std)
273 | return x
274 |
275 | def freeze(self):
276 | self.model = self.model.eval()
277 | for param in self.model.parameters():
278 | param.requires_grad = False
279 |
280 | @autocast
281 | def forward(self, image, no_dropout=False):
282 | z = self.encode_with_vision_transformer(image)
283 | if self.ucg_rate > 0. and not no_dropout:
284 | z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z
285 | return z
286 |
287 | def encode_with_vision_transformer(self, img):
288 | img = self.preprocess(img)
289 | x = self.model.visual(img)
290 | return x
291 |
292 |     def encode(self, image):
293 |         return self(image)
294 |
295 | class FrozenOpenCLIPImageEmbedderV2(AbstractEncoder):
296 | """
297 | Uses the OpenCLIP vision transformer encoder for images
298 | """
299 |
300 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda",
301 | freeze=True, layer="pooled", antialias=True):
302 | super().__init__()
303 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
304 | pretrained=version, )
305 | del model.transformer
306 | self.model = model
307 | self.device = device
308 |
309 | if freeze:
310 | self.freeze()
311 | self.layer = layer
312 | if self.layer == "penultimate":
313 | raise NotImplementedError()
314 | self.layer_idx = 1
315 |
316 | self.antialias = antialias
317 |
318 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
319 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
320 |
321 |
322 | def preprocess(self, x):
323 | # normalize to [0,1]
324 | x = kornia.geometry.resize(x, (224, 224),
325 | interpolation='bicubic', align_corners=True,
326 | antialias=self.antialias)
327 | x = (x + 1.) / 2.
328 | # renormalize according to clip
329 | x = kornia.enhance.normalize(x, self.mean, self.std)
330 | return x
331 |
332 | def freeze(self):
333 | self.model = self.model.eval()
334 | for param in self.model.parameters():
335 | param.requires_grad = False
336 |
337 | def forward(self, image, no_dropout=False):
338 | ## image: b c h w
339 | z = self.encode_with_vision_transformer(image)
340 | return z
341 |
342 | def encode_with_vision_transformer(self, x):
343 | x = self.preprocess(x)
344 |
345 | # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
346 | if self.model.visual.input_patchnorm:
347 | # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
348 | x = x.reshape(x.shape[0], x.shape[1], self.model.visual.grid_size[0], self.model.visual.patch_size[0], self.model.visual.grid_size[1], self.model.visual.patch_size[1])
349 | x = x.permute(0, 2, 4, 1, 3, 5)
350 | x = x.reshape(x.shape[0], self.model.visual.grid_size[0] * self.model.visual.grid_size[1], -1)
351 | x = self.model.visual.patchnorm_pre_ln(x)
352 | x = self.model.visual.conv1(x)
353 | else:
354 | x = self.model.visual.conv1(x) # shape = [*, width, grid, grid]
355 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
356 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
357 |
358 | # class embeddings and positional embeddings
359 | x = torch.cat(
360 | [self.model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
361 | x], dim=1) # shape = [*, grid ** 2 + 1, width]
362 | x = x + self.model.visual.positional_embedding.to(x.dtype)
363 |
364 | # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
365 | x = self.model.visual.patch_dropout(x)
366 | x = self.model.visual.ln_pre(x)
367 |
368 | x = x.permute(1, 0, 2) # NLD -> LND
369 | x = self.model.visual.transformer(x)
370 | x = x.permute(1, 0, 2) # LND -> NLD
371 |
372 | return x
373 |
374 | class FrozenCLIPT5Encoder(AbstractEncoder):
375 | def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
376 | clip_max_length=77, t5_max_length=77):
377 | super().__init__()
378 | self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
379 | self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
380 | print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
381 | f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params.")
382 |
383 | def encode(self, text):
384 | return self(text)
385 |
386 | def forward(self, text):
387 | clip_z = self.clip_encoder.encode(text)
388 | t5_z = self.t5_encoder.encode(text)
389 | return [clip_z, t5_z]
390 |
--------------------------------------------------------------------------------
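The encoders above share the encode()/forward() convention: text in, a (batch, 77, width) token sequence out; images in, a token sequence or pooled embedding out. A hedged usage sketch for the OpenCLIP branches; pretrained-weight downloads, CUDA availability, and the quoted ViT-H-14 output shapes are assumptions, not values guaranteed by this file:

import torch

text_enc = FrozenOpenCLIPEmbedder(arch="ViT-H-14", version="laion2b_s32b_b79k",
                                  device="cuda", layer="penultimate").to("cuda")
text_tokens = text_enc.encode(["a lighthouse at sunset"])   # roughly (1, 77, 1024)

img_enc = FrozenOpenCLIPImageEmbedderV2(arch="ViT-H-14",
                                        version="laion2b_s32b_b79k").to("cuda")
image = torch.rand(1, 3, 256, 256, device="cuda") * 2 - 1   # expected input range is [-1, 1]
img_tokens = img_enc(image)                                 # roughly (1, 257, 1280) patch tokens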
/lvdm/modules/encoders/resampler.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2 | # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
3 | # and https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py
4 | import math
5 | import torch
6 | import torch.nn as nn
7 |
8 |
9 | class ImageProjModel(nn.Module):
10 | """Projection Model"""
11 | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
12 | super().__init__()
13 | self.cross_attention_dim = cross_attention_dim
14 | self.clip_extra_context_tokens = clip_extra_context_tokens
15 | self.proj = nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
16 | self.norm = nn.LayerNorm(cross_attention_dim)
17 |
18 | def forward(self, image_embeds):
19 | #embeds = image_embeds
20 | embeds = image_embeds.type(list(self.proj.parameters())[0].dtype)
21 | clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
22 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
23 | return clip_extra_context_tokens
24 |
25 |
26 | # FFN
27 | def FeedForward(dim, mult=4):
28 | inner_dim = int(dim * mult)
29 | return nn.Sequential(
30 | nn.LayerNorm(dim),
31 | nn.Linear(dim, inner_dim, bias=False),
32 | nn.GELU(),
33 | nn.Linear(inner_dim, dim, bias=False),
34 | )
35 |
36 |
37 | def reshape_tensor(x, heads):
38 | bs, length, width = x.shape
39 | #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
40 | x = x.view(bs, length, heads, -1)
41 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
42 | x = x.transpose(1, 2)
43 | # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
44 | x = x.reshape(bs, heads, length, -1)
45 | return x
46 |
47 |
48 | class PerceiverAttention(nn.Module):
49 | def __init__(self, *, dim, dim_head=64, heads=8):
50 | super().__init__()
51 | self.scale = dim_head**-0.5
52 | self.dim_head = dim_head
53 | self.heads = heads
54 | inner_dim = dim_head * heads
55 |
56 | self.norm1 = nn.LayerNorm(dim)
57 | self.norm2 = nn.LayerNorm(dim)
58 |
59 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
60 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
61 | self.to_out = nn.Linear(inner_dim, dim, bias=False)
62 |
63 |
64 | def forward(self, x, latents):
65 | """
66 | Args:
67 | x (torch.Tensor): image features
68 | shape (b, n1, D)
69 | latent (torch.Tensor): latent features
70 | shape (b, n2, D)
71 | """
72 | x = self.norm1(x)
73 | latents = self.norm2(latents)
74 |
75 | b, l, _ = latents.shape
76 |
77 | q = self.to_q(latents)
78 | kv_input = torch.cat((x, latents), dim=-2)
79 | k, v = self.to_kv(kv_input).chunk(2, dim=-1)
80 |
81 | q = reshape_tensor(q, self.heads)
82 | k = reshape_tensor(k, self.heads)
83 | v = reshape_tensor(v, self.heads)
84 |
85 | # attention
86 | scale = 1 / math.sqrt(math.sqrt(self.dim_head))
87 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
88 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
89 | out = weight @ v
90 |
91 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
92 |
93 | return self.to_out(out)
94 |
95 |
96 | class Resampler(nn.Module):
97 | def __init__(
98 | self,
99 | dim=1024,
100 | depth=8,
101 | dim_head=64,
102 | heads=16,
103 | num_queries=8,
104 | embedding_dim=768,
105 | output_dim=1024,
106 | ff_mult=4,
107 | video_length=None, # using frame-wise version or not
108 | ):
109 | super().__init__()
110 | ## queries for a single frame / image
111 | self.num_queries = num_queries
112 | self.video_length = video_length
113 |
114 | ## queries for each frame
115 | if video_length is not None:
116 | num_queries = num_queries * video_length
117 |
118 | self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
119 | self.proj_in = nn.Linear(embedding_dim, dim)
120 | self.proj_out = nn.Linear(dim, output_dim)
121 | self.norm_out = nn.LayerNorm(output_dim)
122 |
123 | self.layers = nn.ModuleList([])
124 | for _ in range(depth):
125 | self.layers.append(
126 | nn.ModuleList(
127 | [
128 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
129 | FeedForward(dim=dim, mult=ff_mult),
130 | ]
131 | )
132 | )
133 |
134 | def forward(self, x):
135 | latents = self.latents.repeat(x.size(0), 1, 1) ## B (T L) C
136 | x = self.proj_in(x)
137 |
138 | for attn, ff in self.layers:
139 | latents = attn(x, latents) + latents
140 | latents = ff(latents) + latents
141 |
142 | latents = self.proj_out(latents)
143 | latents = self.norm_out(latents) # B L C or B (T L) C
144 |
145 | return latents
--------------------------------------------------------------------------------
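A quick shape sketch for the Resampler: the per-frame queries are replicated video_length times, so the output sequence length is num_queries * video_length. The constructor values below are illustrative, not defaults fixed by this file:

import torch

resampler = Resampler(dim=1024, depth=4, dim_head=64, heads=12,
                      num_queries=16, embedding_dim=1280, output_dim=1024,
                      video_length=16)
img_tokens = torch.randn(2, 257, 1280)   # (B, n_image_tokens, embedding_dim)
ctx = resampler(img_tokens)              # (B, num_queries * video_length, output_dim)
print(ctx.shape)                         # torch.Size([2, 256, 1024])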
/lvdm/modules/x_transformer.py:
--------------------------------------------------------------------------------
1 | """shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
2 | from functools import partial
3 | from inspect import isfunction
4 | from collections import namedtuple
5 | from einops import rearrange, repeat
6 | import torch
7 | from torch import nn, einsum
8 | import torch.nn.functional as F
9 |
10 | # constants
11 | DEFAULT_DIM_HEAD = 64
12 |
13 | Intermediates = namedtuple('Intermediates', [
14 | 'pre_softmax_attn',
15 | 'post_softmax_attn'
16 | ])
17 |
18 | LayerIntermediates = namedtuple('LayerIntermediates', [
19 | 'hiddens',
20 | 'attn_intermediates'
21 | ])
22 |
23 |
24 | class AbsolutePositionalEmbedding(nn.Module):
25 | def __init__(self, dim, max_seq_len):
26 | super().__init__()
27 | self.emb = nn.Embedding(max_seq_len, dim)
28 | self.init_()
29 |
30 | def init_(self):
31 | nn.init.normal_(self.emb.weight, std=0.02)
32 |
33 | def forward(self, x):
34 | n = torch.arange(x.shape[1], device=x.device)
35 | return self.emb(n)[None, :, :]
36 |
37 |
38 | class FixedPositionalEmbedding(nn.Module):
39 | def __init__(self, dim):
40 | super().__init__()
41 | inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
42 | self.register_buffer('inv_freq', inv_freq)
43 |
44 | def forward(self, x, seq_dim=1, offset=0):
45 | t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
46 | sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
47 | emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
48 | return emb[None, :, :]
49 |
50 |
51 | # helpers
52 |
53 | def exists(val):
54 | return val is not None
55 |
56 |
57 | def default(val, d):
58 | if exists(val):
59 | return val
60 | return d() if isfunction(d) else d
61 |
62 |
63 | def always(val):
64 | def inner(*args, **kwargs):
65 | return val
66 | return inner
67 |
68 |
69 | def not_equals(val):
70 | def inner(x):
71 | return x != val
72 | return inner
73 |
74 |
75 | def equals(val):
76 | def inner(x):
77 | return x == val
78 | return inner
79 |
80 |
81 | def max_neg_value(tensor):
82 | return -torch.finfo(tensor.dtype).max
83 |
84 |
85 | # keyword argument helpers
86 |
87 | def pick_and_pop(keys, d):
88 | values = list(map(lambda key: d.pop(key), keys))
89 | return dict(zip(keys, values))
90 |
91 |
92 | def group_dict_by_key(cond, d):
93 | return_val = [dict(), dict()]
94 | for key in d.keys():
95 | match = bool(cond(key))
96 | ind = int(not match)
97 | return_val[ind][key] = d[key]
98 | return (*return_val,)
99 |
100 |
101 | def string_begins_with(prefix, str):
102 | return str.startswith(prefix)
103 |
104 |
105 | def group_by_key_prefix(prefix, d):
106 | return group_dict_by_key(partial(string_begins_with, prefix), d)
107 |
108 |
109 | def groupby_prefix_and_trim(prefix, d):
110 | kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
111 | kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
112 | return kwargs_without_prefix, kwargs
113 |
114 |
115 | # classes
116 | class Scale(nn.Module):
117 | def __init__(self, value, fn):
118 | super().__init__()
119 | self.value = value
120 | self.fn = fn
121 |
122 | def forward(self, x, **kwargs):
123 | x, *rest = self.fn(x, **kwargs)
124 | return (x * self.value, *rest)
125 |
126 |
127 | class Rezero(nn.Module):
128 | def __init__(self, fn):
129 | super().__init__()
130 | self.fn = fn
131 | self.g = nn.Parameter(torch.zeros(1))
132 |
133 | def forward(self, x, **kwargs):
134 | x, *rest = self.fn(x, **kwargs)
135 | return (x * self.g, *rest)
136 |
137 |
138 | class ScaleNorm(nn.Module):
139 | def __init__(self, dim, eps=1e-5):
140 | super().__init__()
141 | self.scale = dim ** -0.5
142 | self.eps = eps
143 | self.g = nn.Parameter(torch.ones(1))
144 |
145 | def forward(self, x):
146 | norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
147 | return x / norm.clamp(min=self.eps) * self.g
148 |
149 |
150 | class RMSNorm(nn.Module):
151 | def __init__(self, dim, eps=1e-8):
152 | super().__init__()
153 | self.scale = dim ** -0.5
154 | self.eps = eps
155 | self.g = nn.Parameter(torch.ones(dim))
156 |
157 | def forward(self, x):
158 | norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
159 | return x / norm.clamp(min=self.eps) * self.g
160 |
161 |
162 | class Residual(nn.Module):
163 | def forward(self, x, residual):
164 | return x + residual
165 |
166 |
167 | class GRUGating(nn.Module):
168 | def __init__(self, dim):
169 | super().__init__()
170 | self.gru = nn.GRUCell(dim, dim)
171 |
172 | def forward(self, x, residual):
173 | gated_output = self.gru(
174 | rearrange(x, 'b n d -> (b n) d'),
175 | rearrange(residual, 'b n d -> (b n) d')
176 | )
177 |
178 | return gated_output.reshape_as(x)
179 |
180 |
181 | # feedforward
182 |
183 | class GEGLU(nn.Module):
184 | def __init__(self, dim_in, dim_out):
185 | super().__init__()
186 | self.proj = nn.Linear(dim_in, dim_out * 2)
187 |
188 | def forward(self, x):
189 | x, gate = self.proj(x).chunk(2, dim=-1)
190 | return x * F.gelu(gate)
191 |
192 |
193 | class FeedForward(nn.Module):
194 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
195 | super().__init__()
196 | inner_dim = int(dim * mult)
197 | dim_out = default(dim_out, dim)
198 | project_in = nn.Sequential(
199 | nn.Linear(dim, inner_dim),
200 | nn.GELU()
201 | ) if not glu else GEGLU(dim, inner_dim)
202 |
203 | self.net = nn.Sequential(
204 | project_in,
205 | nn.Dropout(dropout),
206 | nn.Linear(inner_dim, dim_out)
207 | )
208 |
209 | def forward(self, x):
210 | return self.net(x)
211 |
212 |
213 | # attention.
214 | class Attention(nn.Module):
215 | def __init__(
216 | self,
217 | dim,
218 | dim_head=DEFAULT_DIM_HEAD,
219 | heads=8,
220 | causal=False,
221 | mask=None,
222 | talking_heads=False,
223 | sparse_topk=None,
224 | use_entmax15=False,
225 | num_mem_kv=0,
226 | dropout=0.,
227 | on_attn=False
228 | ):
229 | super().__init__()
230 | if use_entmax15:
231 | raise NotImplementedError("use_entmax15 is not supported here; use the default softmax attention instead.")
232 | self.scale = dim_head ** -0.5
233 | self.heads = heads
234 | self.causal = causal
235 | self.mask = mask
236 |
237 | inner_dim = dim_head * heads
238 |
239 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
240 | self.to_k = nn.Linear(dim, inner_dim, bias=False)
241 | self.to_v = nn.Linear(dim, inner_dim, bias=False)
242 | self.dropout = nn.Dropout(dropout)
243 |
244 | # talking heads
245 | self.talking_heads = talking_heads
246 | if talking_heads:
247 | self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
248 | self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))
249 |
250 | # explicit topk sparse attention
251 | self.sparse_topk = sparse_topk
252 |
253 | # entmax
254 | #self.attn_fn = entmax15 if use_entmax15 else F.softmax
255 | self.attn_fn = F.softmax
256 |
257 | # add memory key / values
258 | self.num_mem_kv = num_mem_kv
259 | if num_mem_kv > 0:
260 | self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
261 | self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
262 |
263 | # attention on attention
264 | self.attn_on_attn = on_attn
265 | self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)
266 |
267 | def forward(
268 | self,
269 | x,
270 | context=None,
271 | mask=None,
272 | context_mask=None,
273 | rel_pos=None,
274 | sinusoidal_emb=None,
275 | prev_attn=None,
276 | mem=None
277 | ):
278 | b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
279 | kv_input = default(context, x)
280 |
281 | q_input = x
282 | k_input = kv_input
283 | v_input = kv_input
284 |
285 | if exists(mem):
286 | k_input = torch.cat((mem, k_input), dim=-2)
287 | v_input = torch.cat((mem, v_input), dim=-2)
288 |
289 | if exists(sinusoidal_emb):
290 | # in shortformer, the query would start at a position offset depending on the past cached memory
291 | offset = k_input.shape[-2] - q_input.shape[-2]
292 | q_input = q_input + sinusoidal_emb(q_input, offset=offset)
293 | k_input = k_input + sinusoidal_emb(k_input)
294 |
295 | q = self.to_q(q_input)
296 | k = self.to_k(k_input)
297 | v = self.to_v(v_input)
298 |
299 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
300 |
301 | input_mask = None
302 | if any(map(exists, (mask, context_mask))):
303 | q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
304 | k_mask = q_mask if not exists(context) else context_mask
305 | k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
306 | q_mask = rearrange(q_mask, 'b i -> b () i ()')
307 | k_mask = rearrange(k_mask, 'b j -> b () () j')
308 | input_mask = q_mask * k_mask
309 |
310 | if self.num_mem_kv > 0:
311 | mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
312 | k = torch.cat((mem_k, k), dim=-2)
313 | v = torch.cat((mem_v, v), dim=-2)
314 | if exists(input_mask):
315 | input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
316 |
317 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
318 | mask_value = max_neg_value(dots)
319 |
320 | if exists(prev_attn):
321 | dots = dots + prev_attn
322 |
323 | pre_softmax_attn = dots
324 |
325 | if talking_heads:
326 | dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
327 |
328 | if exists(rel_pos):
329 | dots = rel_pos(dots)
330 |
331 | if exists(input_mask):
332 | dots.masked_fill_(~input_mask, mask_value)
333 | del input_mask
334 |
335 | if self.causal:
336 | i, j = dots.shape[-2:]
337 | r = torch.arange(i, device=device)
338 | mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
339 | mask = F.pad(mask, (j - i, 0), value=False)
340 | dots.masked_fill_(mask, mask_value)
341 | del mask
342 |
343 | if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
344 | top, _ = dots.topk(self.sparse_topk, dim=-1)
345 | vk = top[..., -1].unsqueeze(-1).expand_as(dots)
346 | mask = dots < vk
347 | dots.masked_fill_(mask, mask_value)
348 | del mask
349 |
350 | attn = self.attn_fn(dots, dim=-1)
351 | post_softmax_attn = attn
352 |
353 | attn = self.dropout(attn)
354 |
355 | if talking_heads:
356 | attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()
357 |
358 | out = einsum('b h i j, b h j d -> b h i d', attn, v)
359 | out = rearrange(out, 'b h n d -> b n (h d)')
360 |
361 | intermediates = Intermediates(
362 | pre_softmax_attn=pre_softmax_attn,
363 | post_softmax_attn=post_softmax_attn
364 | )
365 |
366 | return self.to_out(out), intermediates
367 |
368 |
369 | class AttentionLayers(nn.Module):
370 | def __init__(
371 | self,
372 | dim,
373 | depth,
374 | heads=8,
375 | causal=False,
376 | cross_attend=False,
377 | only_cross=False,
378 | use_scalenorm=False,
379 | use_rmsnorm=False,
380 | use_rezero=False,
381 | rel_pos_num_buckets=32,
382 | rel_pos_max_distance=128,
383 | position_infused_attn=False,
384 | custom_layers=None,
385 | sandwich_coef=None,
386 | par_ratio=None,
387 | residual_attn=False,
388 | cross_residual_attn=False,
389 | macaron=False,
390 | pre_norm=True,
391 | gate_residual=False,
392 | **kwargs
393 | ):
394 | super().__init__()
395 | ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
396 | attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)
397 |
398 | dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
399 |
400 | self.dim = dim
401 | self.depth = depth
402 | self.layers = nn.ModuleList([])
403 |
404 | self.has_pos_emb = position_infused_attn
405 | self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
406 | self.rotary_pos_emb = always(None)
407 |
408 | assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must not exceed the relative position max distance'
409 | self.rel_pos = None
410 |
411 | self.pre_norm = pre_norm
412 |
413 | self.residual_attn = residual_attn
414 | self.cross_residual_attn = cross_residual_attn
415 |
416 | norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
417 | norm_class = RMSNorm if use_rmsnorm else norm_class
418 | norm_fn = partial(norm_class, dim)
419 |
420 | norm_fn = nn.Identity if use_rezero else norm_fn
421 | branch_fn = Rezero if use_rezero else None
422 |
423 | if cross_attend and not only_cross:
424 | default_block = ('a', 'c', 'f')
425 | elif cross_attend and only_cross:
426 | default_block = ('c', 'f')
427 | else:
428 | default_block = ('a', 'f')
429 |
430 | if macaron:
431 | default_block = ('f',) + default_block
432 |
433 | if exists(custom_layers):
434 | layer_types = custom_layers
435 | elif exists(par_ratio):
436 | par_depth = depth * len(default_block)
437 | assert 1 < par_ratio <= par_depth, 'par ratio out of range'
438 | default_block = tuple(filter(not_equals('f'), default_block))
439 | par_attn = par_depth // par_ratio
440 | depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
441 | par_width = (depth_cut + depth_cut // par_attn) // par_attn
442 | assert len(default_block) <= par_width, 'default block is too large for par_ratio'
443 | par_block = default_block + ('f',) * (par_width - len(default_block))
444 | par_head = par_block * par_attn
445 | layer_types = par_head + ('f',) * (par_depth - len(par_head))
446 | elif exists(sandwich_coef):
447 | assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be positive and no greater than the depth'
448 | layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
449 | else:
450 | layer_types = default_block * depth
451 |
452 | self.layer_types = layer_types
453 | self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
454 |
455 | for layer_type in self.layer_types:
456 | if layer_type == 'a':
457 | layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
458 | elif layer_type == 'c':
459 | layer = Attention(dim, heads=heads, **attn_kwargs)
460 | elif layer_type == 'f':
461 | layer = FeedForward(dim, **ff_kwargs)
462 | layer = layer if not macaron else Scale(0.5, layer)
463 | else:
464 | raise Exception(f'invalid layer type {layer_type}')
465 |
466 | if isinstance(layer, Attention) and exists(branch_fn):
467 | layer = branch_fn(layer)
468 |
469 | if gate_residual:
470 | residual_fn = GRUGating(dim)
471 | else:
472 | residual_fn = Residual()
473 |
474 | self.layers.append(nn.ModuleList([
475 | norm_fn(),
476 | layer,
477 | residual_fn
478 | ]))
479 |
480 | def forward(
481 | self,
482 | x,
483 | context=None,
484 | mask=None,
485 | context_mask=None,
486 | mems=None,
487 | return_hiddens=False
488 | ):
489 | hiddens = []
490 | intermediates = []
491 | prev_attn = None
492 | prev_cross_attn = None
493 |
494 | mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
495 |
496 | for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
497 | is_last = ind == (len(self.layers) - 1)
498 |
499 | if layer_type == 'a':
500 | hiddens.append(x)
501 | layer_mem = mems.pop(0)
502 |
503 | residual = x
504 |
505 | if self.pre_norm:
506 | x = norm(x)
507 |
508 | if layer_type == 'a':
509 | out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos,
510 | prev_attn=prev_attn, mem=layer_mem)
511 | elif layer_type == 'c':
512 | out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn)
513 | elif layer_type == 'f':
514 | out = block(x)
515 |
516 | x = residual_fn(out, residual)
517 |
518 | if layer_type in ('a', 'c'):
519 | intermediates.append(inter)
520 |
521 | if layer_type == 'a' and self.residual_attn:
522 | prev_attn = inter.pre_softmax_attn
523 | elif layer_type == 'c' and self.cross_residual_attn:
524 | prev_cross_attn = inter.pre_softmax_attn
525 |
526 | if not self.pre_norm and not is_last:
527 | x = norm(x)
528 |
529 | if return_hiddens:
530 | intermediates = LayerIntermediates(
531 | hiddens=hiddens,
532 | attn_intermediates=intermediates
533 | )
534 |
535 | return x, intermediates
536 |
537 | return x
538 |
539 |
540 | class Encoder(AttentionLayers):
541 | def __init__(self, **kwargs):
542 | assert 'causal' not in kwargs, 'cannot set causality on encoder'
543 | super().__init__(causal=False, **kwargs)
544 |
545 |
546 |
547 | class TransformerWrapper(nn.Module):
548 | def __init__(
549 | self,
550 | *,
551 | num_tokens,
552 | max_seq_len,
553 | attn_layers,
554 | emb_dim=None,
555 | max_mem_len=0.,
556 | emb_dropout=0.,
557 | num_memory_tokens=None,
558 | tie_embedding=False,
559 | use_pos_emb=True
560 | ):
561 | super().__init__()
562 | assert isinstance(attn_layers, AttentionLayers), 'attention layers must be an AttentionLayers instance (e.g. Encoder)'
563 |
564 | dim = attn_layers.dim
565 | emb_dim = default(emb_dim, dim)
566 |
567 | self.max_seq_len = max_seq_len
568 | self.max_mem_len = max_mem_len
569 | self.num_tokens = num_tokens
570 |
571 | self.token_emb = nn.Embedding(num_tokens, emb_dim)
572 | self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
573 | use_pos_emb and not attn_layers.has_pos_emb) else always(0)
574 | self.emb_dropout = nn.Dropout(emb_dropout)
575 |
576 | self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
577 | self.attn_layers = attn_layers
578 | self.norm = nn.LayerNorm(dim)
579 |
580 | self.init_()
581 |
582 | self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
583 |
584 | # memory tokens (like [cls]) from Memory Transformers paper
585 | num_memory_tokens = default(num_memory_tokens, 0)
586 | self.num_memory_tokens = num_memory_tokens
587 | if num_memory_tokens > 0:
588 | self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
589 |
590 | # let funnel encoder know number of memory tokens, if specified
591 | if hasattr(attn_layers, 'num_memory_tokens'):
592 | attn_layers.num_memory_tokens = num_memory_tokens
593 |
594 | def init_(self):
595 | nn.init.normal_(self.token_emb.weight, std=0.02)
596 |
597 | def forward(
598 | self,
599 | x,
600 | return_embeddings=False,
601 | mask=None,
602 | return_mems=False,
603 | return_attn=False,
604 | mems=None,
605 | **kwargs
606 | ):
607 | b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
608 | x = self.token_emb(x)
609 | x += self.pos_emb(x)
610 | x = self.emb_dropout(x)
611 |
612 | x = self.project_emb(x)
613 |
614 | if num_mem > 0:
615 | mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
616 | x = torch.cat((mem, x), dim=1)
617 |
618 | # auto-handle masking after appending memory tokens
619 | if exists(mask):
620 | mask = F.pad(mask, (num_mem, 0), value=True)
621 |
622 | x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
623 | x = self.norm(x)
624 |
625 | mem, x = x[:, :num_mem], x[:, num_mem:]
626 |
627 | out = self.to_logits(x) if not return_embeddings else x
628 |
629 | if return_mems:
630 | hiddens = intermediates.hiddens
631 | new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens
632 | new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
633 | return out, new_mems
634 |
635 | if return_attn:
636 | attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
637 | return out, attn_maps
638 |
639 | return out
--------------------------------------------------------------------------------
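
A minimal sketch (not part of the repository) of how `TransformerWrapper` is wired with the `Encoder` defined above to embed a token sequence. The vocabulary size and dimensions are arbitrary, and it assumes the repository root is on `PYTHONPATH`.

```python
import torch
from lvdm.modules.x_transformer import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens=30522,                          # hypothetical vocabulary size
    max_seq_len=77,
    attn_layers=Encoder(dim=512, depth=4, heads=8),
)

tokens = torch.randint(0, 30522, (2, 77))      # (B, seq_len) integer token ids
embeddings = model(tokens, return_embeddings=True)
print(embeddings.shape)                        # torch.Size([2, 77, 512])
```
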
/nodes.py:
--------------------------------------------------------------------------------
1 | import os, argparse
2 | import sys
3 | from .scripts.gradio.i2v_test import Image2Video
4 | from .scripts.gradio.i2v_test_application import Image2Video as Image2VideoInterp
5 | sys.path.insert(1, os.path.join(sys.path[0], 'lvdm'))
6 |
7 | resolutions=["576_1024","320_512","256_256"]
8 | resolutionInterps=["320_512"]
9 |
10 | class DynamiCrafterInterpLoader:
11 | @classmethod
12 | def INPUT_TYPES(cls):
13 | return {
14 | "required": {
15 | "resolution": (resolutionInterps,{"default":"320_512"}),
16 | "frame_length": ("INT",{"default":16})
17 | }
18 | }
19 |
20 | RETURN_TYPES = ("DynamiCrafterInter",)
21 | RETURN_NAMES = ("model",)
22 | FUNCTION = "run_inference"
23 | CATEGORY = "DynamiCrafter"
24 |
25 | def run_inference(self,resolution,frame_length=16):
26 | image2video = Image2VideoInterp('./tmp/', resolution=resolution,frame_length=frame_length)
27 | return (image2video,)
28 |
29 | class DynamiCrafterLoader:
30 | @classmethod
31 | def INPUT_TYPES(cls):
32 | return {
33 | "required": {
34 | "resolution": (resolutions,{"default":"576_1024"}),
35 | "frame_length": ("INT",{"default":16})
36 | }
37 | }
38 |
39 | RETURN_TYPES = ("DynamiCrafter",)
40 | RETURN_NAMES = ("model",)
41 | FUNCTION = "run_inference"
42 | CATEGORY = "DynamiCrafter"
43 |
44 | def run_inference(self,resolution,frame_length=16):
45 | image2video = Image2Video('./tmp/', resolution=resolution,frame_length=frame_length)
46 | return (image2video,)
47 |
48 | class DynamiCrafterSimple:
49 | @classmethod
50 | def INPUT_TYPES(cls):
51 | return {
52 | "required": {
53 | "model": ("DynamiCrafter",),
54 | "image": ("IMAGE",),
55 | "prompt": ("STRING", {"default": ""}),
56 | "steps": ("INT", {"default": 50}),
57 | "cfg_scale": ("FLOAT", {"default": 7.5}),
58 | "eta": ("FLOAT", {"default": 1.0}),
59 | "motion": ("INT", {"default": 3}),
60 | "seed": ("INT", {"default": 123}),
61 | }
62 | }
63 |
64 | RETURN_TYPES = ("IMAGE",)
65 | RETURN_NAMES = ("image",)
66 | FUNCTION = "run_inference"
67 | CATEGORY = "DynamiCrafter"
68 |
69 | def run_inference(self,model,image,prompt,steps,cfg_scale,eta,motion,seed):
70 | image = 255.0 * image[0].cpu().numpy()
71 | #image = Image.fromarray(np.clip(image, 0, 255).astype(np.uint8))
72 |
73 | imgs= model.get_image(image, prompt, steps, cfg_scale, eta, motion, seed)
74 | return imgs
75 |
76 | class DynamiCrafterInterpSimple:
77 | @classmethod
78 | def INPUT_TYPES(cls):
79 | return {
80 | "required": {
81 | "model": ("DynamiCrafterInter",),
82 | "image": ("IMAGE",),
83 | "image1": ("IMAGE",),
84 | "prompt": ("STRING", {"default": ""}),
85 | "steps": ("INT", {"default": 50}),
86 | "cfg_scale": ("FLOAT", {"default": 7.5}),
87 | "eta": ("FLOAT", {"default": 1.0}),
88 | "motion": ("INT", {"default": 3}),
89 | "seed": ("INT", {"default": 123}),
90 | }
91 | }
92 |
93 | RETURN_TYPES = ("IMAGE",)
94 | RETURN_NAMES = ("image",)
95 | FUNCTION = "run_inference"
96 | CATEGORY = "DynamiCrafter"
97 |
98 | def run_inference(self,model,image,image1,prompt,steps,cfg_scale,eta,motion,seed):
99 | image = 255.0 * image[0].cpu().numpy()
100 | image1 = 255.0 * image1[0].cpu().numpy()
101 | #image = Image.fromarray(np.clip(image, 0, 255).astype(np.uint8))
102 |
103 | imgs= model.get_image(image, prompt, steps, cfg_scale, eta, motion, seed,image1)
104 | return imgs
105 |
106 |
107 | NODE_CLASS_MAPPINGS = {
108 | "DynamiCrafterLoader":DynamiCrafterLoader,
109 | "DynamiCrafter Simple":DynamiCrafterSimple,
110 | "DynamiCrafterInterpLoader":DynamiCrafterInterpLoader,
111 | "DynamiCrafterInterp Simple":DynamiCrafterInterpSimple,
112 | }
113 |
--------------------------------------------------------------------------------
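
A minimal sketch (not part of the repository) of the call chain these node classes expose. ComfyUI normally drives this from the graph; the import path below is hypothetical (it depends on the folder name under `custom_nodes`), and running it requires the DynamiCrafter checkpoints plus a CUDA device.

```python
import torch
from ComfyUI_DynamiCrafter.nodes import DynamiCrafterLoader, DynamiCrafterSimple  # hypothetical import path

# Loader node: builds an Image2Video helper for the chosen resolution.
(model,) = DynamiCrafterLoader().run_inference(resolution="576_1024", frame_length=16)

# Simple node: ComfyUI IMAGE tensors are (B, H, W, C) floats in [0, 1].
image = torch.rand(1, 576, 1024, 3)
frames = DynamiCrafterSimple().run_inference(
    model, image, prompt="a robot is walking through a destroyed city",
    steps=50, cfg_scale=7.5, eta=1.0, motion=3, seed=123,
)
```
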
/prompts/1024/astronaut04.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/1024/astronaut04.png
--------------------------------------------------------------------------------
/prompts/1024/bike_chineseink.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/1024/bike_chineseink.png
--------------------------------------------------------------------------------
/prompts/1024/bloom01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/1024/bloom01.png
--------------------------------------------------------------------------------
/prompts/1024/firework03.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/1024/firework03.png
--------------------------------------------------------------------------------
/prompts/1024/girl07.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/1024/girl07.png
--------------------------------------------------------------------------------
/prompts/1024/pour_bear.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/1024/pour_bear.png
--------------------------------------------------------------------------------
/prompts/1024/robot01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/1024/robot01.png
--------------------------------------------------------------------------------
/prompts/1024/test_prompts.txt:
--------------------------------------------------------------------------------
1 | a man in an astronaut suit playing a guitar
2 | riding a bike under a bridge
3 | time-lapse of a blooming flower with leaves and a stem
4 | fireworks display
5 | a beautiful woman with long hair and a dress blowing in the wind
6 | pouring beer into a glass of ice and beer
7 | a robot is walking through a destroyed city
8 | a group of penguins walking on a beach
--------------------------------------------------------------------------------
/prompts/1024/zreal_penguin.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/1024/zreal_penguin.png
--------------------------------------------------------------------------------
/prompts/256/art.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/256/art.png
--------------------------------------------------------------------------------
/prompts/256/bear.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/256/bear.png
--------------------------------------------------------------------------------
/prompts/256/boy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/256/boy.png
--------------------------------------------------------------------------------
/prompts/256/dance1.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/256/dance1.jpeg
--------------------------------------------------------------------------------
/prompts/256/fire_and_beach.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/256/fire_and_beach.jpg
--------------------------------------------------------------------------------
/prompts/256/girl2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/256/girl2.jpeg
--------------------------------------------------------------------------------
/prompts/256/girl3.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/256/girl3.jpeg
--------------------------------------------------------------------------------
/prompts/256/guitar0.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/256/guitar0.jpeg
--------------------------------------------------------------------------------
/prompts/256/test_prompts.txt:
--------------------------------------------------------------------------------
1 | man fishing in a boat at sunset
2 | a brown bear is walking in a zoo enclosure, some rocks around
3 | boy walking on the street
4 | two people dancing
5 | a campfire on the beach and the ocean waves in the background
6 | girl with fires and smoke on his head
7 | girl talking and blinking
8 | bear playing guitar happily, snowing
--------------------------------------------------------------------------------
/prompts/512/bloom01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/512/bloom01.png
--------------------------------------------------------------------------------
/prompts/512/campfire.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/512/campfire.png
--------------------------------------------------------------------------------
/prompts/512/girl08.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/512/girl08.png
--------------------------------------------------------------------------------
/prompts/512/isometric.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/512/isometric.png
--------------------------------------------------------------------------------
/prompts/512/pour_honey.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/512/pour_honey.png
--------------------------------------------------------------------------------
/prompts/512/ship02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/512/ship02.png
--------------------------------------------------------------------------------
/prompts/512/test_prompts.txt:
--------------------------------------------------------------------------------
1 | time-lapse of a blooming flower with leaves and a stem
2 | a bonfire is lit in the middle of a field
3 | a woman looking out in the rain
4 | rotating view, small house
5 | pouring honey onto some slices of bread
6 | a sailboat sailing in rough seas with a dramatic sunset
7 | a boat traveling on the ocean
8 | a group of penguins walking on a beach
--------------------------------------------------------------------------------
/prompts/512/zreal_boat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/512/zreal_boat.png
--------------------------------------------------------------------------------
/prompts/512/zreal_penguin.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/512/zreal_penguin.png
--------------------------------------------------------------------------------
/prompts/512_interp/smile_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/512_interp/smile_01.png
--------------------------------------------------------------------------------
/prompts/512_interp/smile_02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/512_interp/smile_02.png
--------------------------------------------------------------------------------
/prompts/512_interp/stone01_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/512_interp/stone01_01.png
--------------------------------------------------------------------------------
/prompts/512_interp/stone01_02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/512_interp/stone01_02.png
--------------------------------------------------------------------------------
/prompts/512_interp/test_prompts.txt:
--------------------------------------------------------------------------------
1 | a smiling girl
2 | rotating view
3 | a man is walking towards a tree
--------------------------------------------------------------------------------
/prompts/512_interp/walk_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/512_interp/walk_01.png
--------------------------------------------------------------------------------
/prompts/512_interp/walk_02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/512_interp/walk_02.png
--------------------------------------------------------------------------------
/prompts/512_loop/24.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/512_loop/24.png
--------------------------------------------------------------------------------
/prompts/512_loop/36.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/512_loop/36.png
--------------------------------------------------------------------------------
/prompts/512_loop/40.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/prompts/512_loop/40.png
--------------------------------------------------------------------------------
/prompts/512_loop/test_prompts.txt:
--------------------------------------------------------------------------------
1 | a beach with waves and clouds at sunset
2 | clothes swaying in the wind
3 | flowers swaying in the wind
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | decord
2 | einops
3 | imageio
4 | omegaconf
5 | opencv_python
6 | pandas
7 | Pillow
8 | pytorch_lightning
9 | PyYAML
10 | tqdm
11 | transformers
12 | moviepy
13 | av
14 | timm
15 | scikit-learn
16 | open_clip_torch==2.22.0
17 | kornia
--------------------------------------------------------------------------------
/scripts/evaluation/ddp_wrapper.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import argparse, importlib
3 | from pytorch_lightning import seed_everything
4 |
5 | import torch
6 | import torch.distributed as dist
7 |
8 | def setup_dist(local_rank):
9 | if dist.is_initialized():
10 | return
11 | torch.cuda.set_device(local_rank)
12 | torch.distributed.init_process_group('nccl', init_method='env://')
13 |
14 |
15 | def get_dist_info():
16 | if dist.is_available():
17 | initialized = dist.is_initialized()
18 | else:
19 | initialized = False
20 | if initialized:
21 | rank = dist.get_rank()
22 | world_size = dist.get_world_size()
23 | else:
24 | rank = 0
25 | world_size = 1
26 | return rank, world_size
27 |
28 |
29 | if __name__ == '__main__':
30 | now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
31 | parser = argparse.ArgumentParser()
32 | parser.add_argument("--module", type=str, help="module name", default="inference")
33 | parser.add_argument("--local_rank", type=int, nargs="?", help="for ddp", default=0)
34 | args, unknown = parser.parse_known_args()
35 | inference_api = importlib.import_module(args.module, package=None)
36 |
37 | inference_parser = inference_api.get_parser()
38 | inference_args, unknown = inference_parser.parse_known_args()
39 |
40 | seed_everything(inference_args.seed)
41 | setup_dist(args.local_rank)
42 | torch.backends.cudnn.benchmark = True
43 | rank, gpu_num = get_dist_info()
44 |
45 | # inference_args.savedir = inference_args.savedir+str('_seed')+str(inference_args.seed)
46 | print("@DynamiCrafter Inference [rank%d]: %s"%(rank, now))
47 | inference_api.run_inference(inference_args, gpu_num, rank)
--------------------------------------------------------------------------------
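
A minimal sketch (not part of the repository) illustrating the single-process fallback of `get_dist_info` above: when no distributed process group has been initialised, it reports rank 0 and world size 1, which is why the wrapper also runs without a DDP launcher. It assumes the repository root is on `PYTHONPATH` and that `pytorch_lightning` (a listed requirement) is installed.

```python
from scripts.evaluation.ddp_wrapper import get_dist_info

rank, world_size = get_dist_info()
print(rank, world_size)   # 0 1 when torch.distributed has not been initialised
```
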
/scripts/evaluation/funcs.py:
--------------------------------------------------------------------------------
1 | import os, sys, glob
2 | import numpy as np
3 | from collections import OrderedDict
4 | from decord import VideoReader, cpu
5 | import cv2
6 |
7 | import torch
8 | import torchvision
9 | sys.path.insert(1, os.path.join(sys.path[0], '..', '..'))
10 | from ...lvdm.models.samplers.ddim import DDIMSampler
11 | from einops import rearrange
12 |
13 |
14 | def batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=50, ddim_eta=1.0,\
15 | cfg_scale=1.0, temporal_cfg_scale=None, **kwargs):
16 | ddim_sampler = DDIMSampler(model)
17 | uncond_type = model.uncond_type
18 | batch_size = noise_shape[0]
19 | fs = cond["fs"]
20 | del cond["fs"]
21 | if noise_shape[-1] == 32:
22 | timestep_spacing = "uniform"
23 | guidance_rescale = 0.0
24 | else:
25 | timestep_spacing = "uniform_trailing"
26 | guidance_rescale = 0.7
27 | ## construct unconditional guidance
28 | if cfg_scale != 1.0:
29 | if uncond_type == "empty_seq":
30 | prompts = batch_size * [""]
31 | #prompts = N * T * [""] ## if is_imgbatch=True
32 | uc_emb = model.get_learned_conditioning(prompts)
33 | elif uncond_type == "zero_embed":
34 | c_emb = cond["c_crossattn"][0] if isinstance(cond, dict) else cond
35 | uc_emb = torch.zeros_like(c_emb)
36 |
37 | ## process image embedding token
38 | if hasattr(model, 'embedder'):
39 | uc_img = torch.zeros(noise_shape[0],3,224,224).to(model.device)
40 | ## img: b c h w >> b l c
41 | uc_img = model.embedder(uc_img)
42 | uc_img = model.image_proj_model(uc_img)
43 | uc_emb = torch.cat([uc_emb, uc_img], dim=1)
44 |
45 | if isinstance(cond, dict):
46 | uc = {key:cond[key] for key in cond.keys()}
47 | uc.update({'c_crossattn': [uc_emb]})
48 | else:
49 | uc = uc_emb
50 | else:
51 | uc = None
52 |
53 | x_T = None
54 | batch_variants = []
55 |
56 | for _ in range(n_samples):
57 | if ddim_sampler is not None:
58 | kwargs.update({"clean_cond": True})
59 | samples, _ = ddim_sampler.sample(S=ddim_steps,
60 | conditioning=cond,
61 | batch_size=noise_shape[0],
62 | shape=noise_shape[1:],
63 | verbose=False,
64 | unconditional_guidance_scale=cfg_scale,
65 | unconditional_conditioning=uc,
66 | eta=ddim_eta,
67 | temporal_length=noise_shape[2],
68 | conditional_guidance_scale_temporal=temporal_cfg_scale,
69 | x_T=x_T,
70 | fs=fs,
71 | timestep_spacing=timestep_spacing,
72 | guidance_rescale=guidance_rescale,
73 | **kwargs
74 | )
75 | ## reconstruct from latent to pixel space
76 | batch_images = model.decode_first_stage(samples)
77 | batch_variants.append(batch_images)
78 | ## batch, samples, c, t, h, w
79 | batch_variants = torch.stack(batch_variants, dim=1)
80 | return batch_variants
81 |
82 |
83 | def get_filelist(data_dir, ext='*'):
84 | file_list = glob.glob(os.path.join(data_dir, '*.%s'%ext))
85 | file_list.sort()
86 | return file_list
87 |
88 | def get_dirlist(path):
89 | list = []
90 | if (os.path.exists(path)):
91 | files = os.listdir(path)
92 | for file in files:
93 | m = os.path.join(path,file)
94 | if (os.path.isdir(m)):
95 | list.append(m)
96 | list.sort()
97 | return list
98 |
99 |
100 | def load_model_checkpoint(model, ckpt):
101 | def load_checkpoint(model, ckpt, full_strict):
102 | state_dict = torch.load(ckpt, map_location="cpu")
103 | if "state_dict" in list(state_dict.keys()):
104 | state_dict = state_dict["state_dict"]
105 | try:
106 | model.load_state_dict(state_dict, strict=full_strict)
107 | except:
108 | ## rename the keys for 256x256 model
109 | new_pl_sd = OrderedDict()
110 | for k,v in state_dict.items():
111 | new_pl_sd[k] = v
112 |
113 | for k in list(new_pl_sd.keys()):
114 | if "framestride_embed" in k:
115 | new_key = k.replace("framestride_embed", "fps_embedding")
116 | new_pl_sd[new_key] = new_pl_sd[k]
117 | del new_pl_sd[k]
118 | model.load_state_dict(new_pl_sd, strict=full_strict)
119 | else:
120 | ## deepspeed
121 | new_pl_sd = OrderedDict()
122 | for key in state_dict['module'].keys():
123 | new_pl_sd[key[16:]]=state_dict['module'][key]
124 | model.load_state_dict(new_pl_sd, strict=full_strict)
125 |
126 | return model
127 | load_checkpoint(model, ckpt, full_strict=True)
128 | print('>>> model checkpoint loaded.')
129 | return model
130 |
131 |
132 | def load_prompts(prompt_file):
133 | f = open(prompt_file, 'r')
134 | prompt_list = []
135 | for idx, line in enumerate(f.readlines()):
136 | l = line.strip()
137 | if len(l) != 0:
138 | prompt_list.append(l)
139 | f.close()
140 | return prompt_list
141 |
142 |
143 | def load_video_batch(filepath_list, frame_stride, video_size=(256,256), video_frames=16):
144 | '''
145 | Notice about some special cases:
146 | 1. video_frames=-1 means to take all the frames (with fs=1)
147 | 2. when the total number of video frames is less than required, the last frame is repeated as padding
148 | '''
149 | fps_list = []
150 | batch_tensor = []
151 | assert frame_stride > 0, "frame stride should be a positive integer!"
152 | for filepath in filepath_list:
153 | padding_num = 0
154 | vidreader = VideoReader(filepath, ctx=cpu(0), width=video_size[1], height=video_size[0])
155 | fps = vidreader.get_avg_fps()
156 | total_frames = len(vidreader)
157 | max_valid_frames = (total_frames-1) // frame_stride + 1
158 | if video_frames < 0:
159 | ## all frames are collected: fs=1 is a must
160 | required_frames = total_frames
161 | frame_stride = 1
162 | else:
163 | required_frames = video_frames
164 | query_frames = min(required_frames, max_valid_frames)
165 | frame_indices = [frame_stride*i for i in range(query_frames)]
166 |
167 | ## [t,h,w,c] -> [c,t,h,w]
168 | frames = vidreader.get_batch(frame_indices)
169 | frame_tensor = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float()
170 | frame_tensor = (frame_tensor / 255. - 0.5) * 2
171 | if max_valid_frames < required_frames:
172 | padding_num = required_frames - max_valid_frames
173 | frame_tensor = torch.cat([frame_tensor, *([frame_tensor[:,-1:,:,:]]*padding_num)], dim=1)
174 | print(f'{os.path.split(filepath)[1]} is not long enough: {padding_num} frames padded.')
175 | batch_tensor.append(frame_tensor)
176 | sample_fps = int(fps/frame_stride)
177 | fps_list.append(sample_fps)
178 |
179 | return torch.stack(batch_tensor, dim=0)
180 |
181 | from PIL import Image
182 | def load_image_batch(filepath_list, image_size=(256,256)):
183 | batch_tensor = []
184 | for filepath in filepath_list:
185 | _, filename = os.path.split(filepath)
186 | _, ext = os.path.splitext(filename)
187 | if ext == '.mp4':
188 | vidreader = VideoReader(filepath, ctx=cpu(0), width=image_size[1], height=image_size[0])
189 | frame = vidreader.get_batch([0])
190 | img_tensor = torch.tensor(frame.asnumpy()).squeeze(0).permute(2, 0, 1).float()
191 | elif ext == '.png' or ext == '.jpg':
192 | img = Image.open(filepath).convert("RGB")
193 | rgb_img = np.array(img, np.float32)
194 | #bgr_img = cv2.imread(filepath, cv2.IMREAD_COLOR)
195 | #bgr_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
196 | rgb_img = cv2.resize(rgb_img, (image_size[1],image_size[0]), interpolation=cv2.INTER_LINEAR)
197 | img_tensor = torch.from_numpy(rgb_img).permute(2, 0, 1).float()
198 | else:
199 | print(f'ERROR: <{ext}> image loading only supports formats: [mp4], [png], [jpg]')
200 | raise NotImplementedError
201 | img_tensor = (img_tensor / 255. - 0.5) * 2
202 | batch_tensor.append(img_tensor)
203 | return torch.stack(batch_tensor, dim=0)
204 |
205 |
206 | def save_videos(batch_tensors, savedir, filenames, fps=10):
207 | # b,samples,c,t,h,w
208 | n_samples = batch_tensors.shape[1]
209 | for idx, vid_tensor in enumerate(batch_tensors):
210 | video = vid_tensor.detach().cpu()
211 | video = torch.clamp(video.float(), -1., 1.)
212 | video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w
213 | frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n_samples)) for framesheet in video] #[3, 1*h, n*w]
214 | grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [t, 3, n*h, w]
215 | grid = (grid + 1.0) / 2.0
216 | grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
217 | #savepath = os.path.join(savedir, f"{filenames[idx]}.mp4")
218 | #torchvision.io.write_video(savepath, grid, fps=fps, video_codec='h264', options={'crf': '10'})
219 | outframes=[]
220 | for i in range(grid.shape[0]):
221 | img = grid[i].numpy()
222 | image=Image.fromarray(img)
223 | image_tensor_out = torch.tensor(np.array(image).astype(np.float32) / 255.0) # back to float HxWxC in [0, 1]
224 | image_tensor_out = torch.unsqueeze(image_tensor_out, 0)
225 | outframes.append(image_tensor_out)
226 |
227 | return torch.cat(tuple(outframes), dim=0).unsqueeze(0)
228 |
229 |
230 | def get_latent_z(model, videos):
231 | b, c, t, h, w = videos.shape
232 | x = rearrange(videos, 'b c t h w -> (b t) c h w')
233 | z = model.encode_first_stage(x)
234 | z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
235 | return z
--------------------------------------------------------------------------------
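
A minimal sketch (not part of the repository) of the rearrange pattern used by `get_latent_z` above: frames are folded into the batch dimension so the 2D first-stage autoencoder can process them, then unfolded back into a video tensor. The real `encode_first_stage` produces latents at 1/8 spatial resolution; an average pool stands in for it here purely to make the shapes concrete.

```python
import torch
import torch.nn.functional as F
from einops import rearrange

videos = torch.randn(1, 3, 16, 320, 512)                   # b c t h w
x = rearrange(videos, 'b c t h w -> (b t) c h w')           # (16, 3, 320, 512)
z = F.avg_pool2d(x, 8)                                       # stand-in for model.encode_first_stage
z = rearrange(z, '(b t) c h w -> b c t h w', b=1, t=16)      # (1, 3, 16, 40, 64)
print(z.shape)
```
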
/scripts/evaluation/inference.py:
--------------------------------------------------------------------------------
1 | import argparse, os, sys, glob
2 | import datetime, time
3 | from omegaconf import OmegaConf
4 | from tqdm import tqdm
5 | from einops import rearrange, repeat
6 | from collections import OrderedDict
7 |
8 | import torch
9 | import torchvision
10 | import torchvision.transforms as transforms
11 | from pytorch_lightning import seed_everything
12 | from PIL import Image
13 | sys.path.insert(1, os.path.join(sys.path[0], '..', '..'))
14 | from ...lvdm.models.samplers.ddim import DDIMSampler
15 | from ...lvdm.models.samplers.ddim_multiplecond import DDIMSampler as DDIMSampler_multicond
16 | from ...utils.utils import instantiate_from_config
17 |
18 |
19 | def get_filelist(data_dir, postfixes):
20 | patterns = [os.path.join(data_dir, f"*.{postfix}") for postfix in postfixes]
21 | file_list = []
22 | for pattern in patterns:
23 | file_list.extend(glob.glob(pattern))
24 | file_list.sort()
25 | return file_list
26 |
27 | def load_model_checkpoint(model, ckpt):
28 | state_dict = torch.load(ckpt, map_location="cpu")
29 | if "state_dict" in list(state_dict.keys()):
30 | state_dict = state_dict["state_dict"]
31 | try:
32 | model.load_state_dict(state_dict, strict=True)
33 | except:
34 | ## rename the keys for 256x256 model
35 | new_pl_sd = OrderedDict()
36 | for k,v in state_dict.items():
37 | new_pl_sd[k] = v
38 |
39 | for k in list(new_pl_sd.keys()):
40 | if "framestride_embed" in k:
41 | new_key = k.replace("framestride_embed", "fps_embedding")
42 | new_pl_sd[new_key] = new_pl_sd[k]
43 | del new_pl_sd[k]
44 | model.load_state_dict(new_pl_sd, strict=True)
45 | else:
46 | # deepspeed
47 | new_pl_sd = OrderedDict()
48 | for key in state_dict['module'].keys():
49 | new_pl_sd[key[16:]]=state_dict['module'][key]
50 | model.load_state_dict(new_pl_sd)
51 | print('>>> model checkpoint loaded.')
52 | return model
53 |
54 | def load_prompts(prompt_file):
55 | f = open(prompt_file, 'r')
56 | prompt_list = []
57 | for idx, line in enumerate(f.readlines()):
58 | l = line.strip()
59 | if len(l) != 0:
60 | prompt_list.append(l)
61 | f.close()
62 | return prompt_list
63 |
64 | def load_data_prompts(data_dir, video_size=(256,256), video_frames=16, interp=False):
65 | transform = transforms.Compose([
66 | transforms.Resize(min(video_size)),
67 | transforms.CenterCrop(video_size),
68 | transforms.ToTensor(),
69 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
70 | ## load prompts
71 | prompt_file = get_filelist(data_dir, ['txt'])
72 | assert len(prompt_file) > 0, "Error: found NO prompt file!"
73 | ###### default prompt
74 | default_idx = 0
75 | default_idx = min(default_idx, len(prompt_file)-1)
76 | if len(prompt_file) > 1:
77 | print(f"Warning: multiple prompt files exist. The one {os.path.split(prompt_file[default_idx])[1]} is used.")
78 | ## only use the first one (sorted by name) if multiple exist
79 |
80 | ## load video
81 | file_list = get_filelist(data_dir, ['jpg', 'png', 'jpeg', 'JPEG', 'PNG'])
82 | # assert len(file_list) == n_samples, "Error: data and prompts are NOT paired!"
83 | data_list = []
84 | filename_list = []
85 | prompt_list = load_prompts(prompt_file[default_idx])
86 | n_samples = len(prompt_list)
87 | for idx in range(n_samples):
88 | if interp:
89 | image1 = Image.open(file_list[2*idx]).convert('RGB')
90 | image_tensor1 = transform(image1).unsqueeze(1) # [c,1,h,w]
91 | image2 = Image.open(file_list[2*idx+1]).convert('RGB')
92 | image_tensor2 = transform(image2).unsqueeze(1) # [c,1,h,w]
93 | frame_tensor1 = repeat(image_tensor1, 'c t h w -> c (repeat t) h w', repeat=video_frames//2)
94 | frame_tensor2 = repeat(image_tensor2, 'c t h w -> c (repeat t) h w', repeat=video_frames//2)
95 | frame_tensor = torch.cat([frame_tensor1, frame_tensor2], dim=1)
96 | _, filename = os.path.split(file_list[idx*2])
97 | else:
98 | image = Image.open(file_list[idx]).convert('RGB')
99 | image_tensor = transform(image).unsqueeze(1) # [c,1,h,w]
100 | frame_tensor = repeat(image_tensor, 'c t h w -> c (repeat t) h w', repeat=video_frames)
101 | _, filename = os.path.split(file_list[idx])
102 |
103 | data_list.append(frame_tensor)
104 | filename_list.append(filename)
105 |
106 | return filename_list, data_list, prompt_list
107 |
108 |
109 | def save_results(prompt, samples, filename, fakedir, fps=8, loop=False):
110 | filename = filename.split('.')[0]+'.mp4'
111 | prompt = prompt[0] if isinstance(prompt, list) else prompt
112 |
113 | ## save video
114 | videos = [samples]
115 | savedirs = [fakedir]
116 | for idx, video in enumerate(videos):
117 | if video is None:
118 | continue
119 | # b,c,t,h,w
120 | video = video.detach().cpu()
121 | video = torch.clamp(video.float(), -1., 1.)
122 | n = video.shape[0]
123 | video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w
124 | if loop:
125 | video = video[:-1,...]
126 |
127 | frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0) for framesheet in video] #[3, 1*h, n*w]
128 | grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [t, 3, h, n*w]
129 | grid = (grid + 1.0) / 2.0
130 | grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
131 | path = os.path.join(savedirs[idx], filename)
132 | torchvision.io.write_video(path, grid, fps=fps, video_codec='h264', options={'crf': '10'}) ## crf indicates the quality
133 |
134 |
135 | def save_results_seperate(prompt, samples, filename, fakedir, fps=10, loop=False):
136 | prompt = prompt[0] if isinstance(prompt, list) else prompt
137 |
138 | ## save video
139 | videos = [samples]
140 | savedirs = [fakedir]
141 | for idx, video in enumerate(videos):
142 | if video is None:
143 | continue
144 | # b,c,t,h,w
145 | video = video.detach().cpu()
146 | if loop: # remove the last frame
147 | video = video[:,:,:-1,...]
148 | video = torch.clamp(video.float(), -1., 1.)
149 | n = video.shape[0]
150 | for i in range(n):
151 | grid = video[i,...]
152 | grid = (grid + 1.0) / 2.0
153 | grid = (grid * 255).to(torch.uint8).permute(1, 2, 3, 0) #thwc
154 | path = os.path.join(savedirs[idx].replace('samples', 'samples_separate'), f'{filename.split(".")[0]}_sample{i}.mp4')
155 | torchvision.io.write_video(path, grid, fps=fps, video_codec='h264', options={'crf': '10'})
156 |
157 | def get_latent_z(model, videos):
158 | b, c, t, h, w = videos.shape
159 | x = rearrange(videos, 'b c t h w -> (b t) c h w')
160 | z = model.encode_first_stage(x)
161 | z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
162 | return z
163 |
164 |
165 | def image_guided_synthesis(model, prompts, videos, noise_shape, n_samples=1, ddim_steps=50, ddim_eta=1., \
166 | unconditional_guidance_scale=1.0, cfg_img=None, fs=None, text_input=False, multiple_cond_cfg=False, loop=False, interp=False, timestep_spacing='uniform', guidance_rescale=0.0, **kwargs):
167 | ddim_sampler = DDIMSampler(model) if not multiple_cond_cfg else DDIMSampler_multicond(model)
168 | batch_size = noise_shape[0]
169 | fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
170 |
171 | if not text_input:
172 | prompts = [""]*batch_size
173 |
174 | img = videos[:,:,0] #bchw
175 | img_emb = model.embedder(img) ## blc
176 | img_emb = model.image_proj_model(img_emb)
177 |
178 | cond_emb = model.get_learned_conditioning(prompts)
179 | cond = {"c_crossattn": [torch.cat([cond_emb,img_emb], dim=1)]}
180 | if model.model.conditioning_key == 'hybrid':
181 | z = get_latent_z(model, videos) # b c t h w
182 | if loop or interp:
183 | img_cat_cond = torch.zeros_like(z)
184 | img_cat_cond[:,:,0,:,:] = z[:,:,0,:,:]
185 | img_cat_cond[:,:,-1,:,:] = z[:,:,-1,:,:]
186 | else:
187 | img_cat_cond = z[:,:,:1,:,:]
188 | img_cat_cond = repeat(img_cat_cond, 'b c t h w -> b c (repeat t) h w', repeat=z.shape[2])
189 | cond["c_concat"] = [img_cat_cond] # b c 1 h w
190 |
191 | if unconditional_guidance_scale != 1.0:
192 | if model.uncond_type == "empty_seq":
193 | prompts = batch_size * [""]
194 | uc_emb = model.get_learned_conditioning(prompts)
195 | elif model.uncond_type == "zero_embed":
196 | uc_emb = torch.zeros_like(cond_emb)
197 | uc_img_emb = model.embedder(torch.zeros_like(img)) ## b l c
198 | uc_img_emb = model.image_proj_model(uc_img_emb)
199 | uc = {"c_crossattn": [torch.cat([uc_emb,uc_img_emb],dim=1)]}
200 | if model.model.conditioning_key == 'hybrid':
201 | uc["c_concat"] = [img_cat_cond]
202 | else:
203 | uc = None
204 |
205 | ## one more unconditional branch is needed: keep the image condition, set the text to ""
206 | if multiple_cond_cfg and cfg_img != 1.0:
207 | uc_2 = {"c_crossattn": [torch.cat([uc_emb,img_emb],dim=1)]}
208 | if model.model.conditioning_key == 'hybrid':
209 | uc_2["c_concat"] = [img_cat_cond]
210 | kwargs.update({"unconditional_conditioning_img_nonetext": uc_2})
211 | else:
212 | kwargs.update({"unconditional_conditioning_img_nonetext": None})
213 |
214 | z0 = None
215 | cond_mask = None
216 |
217 | batch_variants = []
218 | for _ in range(n_samples):
219 |
220 | if z0 is not None:
221 | cond_z0 = z0.clone()
222 | kwargs.update({"clean_cond": True})
223 | else:
224 | cond_z0 = None
225 | if ddim_sampler is not None:
226 |
227 | samples, _ = ddim_sampler.sample(S=ddim_steps,
228 | conditioning=cond,
229 | batch_size=batch_size,
230 | shape=noise_shape[1:],
231 | verbose=False,
232 | unconditional_guidance_scale=unconditional_guidance_scale,
233 | unconditional_conditioning=uc,
234 | eta=ddim_eta,
235 | cfg_img=cfg_img,
236 | mask=cond_mask,
237 | x0=cond_z0,
238 | fs=fs,
239 | timestep_spacing=timestep_spacing,
240 | guidance_rescale=guidance_rescale,
241 | **kwargs
242 | )
243 |
244 | ## reconstruct from latent to pixel space
245 | batch_images = model.decode_first_stage(samples)
246 | batch_variants.append(batch_images)
247 | ## variants, batch, c, t, h, w
248 | batch_variants = torch.stack(batch_variants)
249 | return batch_variants.permute(1, 0, 2, 3, 4, 5)
250 |
251 |
252 | def run_inference(args, gpu_num, gpu_no):
253 | ## model config
254 | config = OmegaConf.load(args.config)
255 | model_config = config.pop("model", OmegaConf.create())
256 |
257 | ## set use_checkpoint to False: when using deepspeed, it otherwise raises "deepspeed backend not set"
258 | model_config['params']['unet_config']['params']['use_checkpoint'] = False
259 | model = instantiate_from_config(model_config)
260 | model = model.cuda(gpu_no)
261 | model.perframe_ae = args.perframe_ae
262 | assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
263 | model = load_model_checkpoint(model, args.ckpt_path)
264 | model.eval()
265 |
266 | ## run over data
267 | assert (args.height % 16 == 0) and (args.width % 16 == 0), "Error: image size [h,w] should be multiples of 16!"
268 |     assert args.bs == 1, "Current implementation only supports [batch size = 1]!"
269 | ## latent noise shape
270 | h, w = args.height // 8, args.width // 8
271 | channels = model.model.diffusion_model.out_channels
272 | n_frames = args.video_length
273 | print(f'Inference with {n_frames} frames')
274 | noise_shape = [args.bs, channels, n_frames, h, w]
275 |
276 | fakedir = os.path.join(args.savedir, "samples")
277 | fakedir_separate = os.path.join(args.savedir, "samples_separate")
278 |
279 | # os.makedirs(fakedir, exist_ok=True)
280 | os.makedirs(fakedir_separate, exist_ok=True)
281 |
282 | ## prompt file setting
283 |     assert os.path.exists(args.prompt_dir), "Error: prompt directory not found!"
284 | filename_list, data_list, prompt_list = load_data_prompts(args.prompt_dir, video_size=(args.height, args.width), video_frames=n_frames, interp=args.interp)
285 | num_samples = len(prompt_list)
286 | samples_split = num_samples // gpu_num
287 | print('Prompts testing [rank:%d] %d/%d samples loaded.'%(gpu_no, samples_split, num_samples))
288 | #indices = random.choices(list(range(0, num_samples)), k=samples_per_device)
289 | indices = list(range(samples_split*gpu_no, samples_split*(gpu_no+1)))
290 | prompt_list_rank = [prompt_list[i] for i in indices]
291 | data_list_rank = [data_list[i] for i in indices]
292 | filename_list_rank = [filename_list[i] for i in indices]
293 |
294 | start = time.time()
295 | with torch.no_grad(), torch.cuda.amp.autocast():
296 | for idx, indice in tqdm(enumerate(range(0, len(prompt_list_rank), args.bs)), desc='Sample Batch'):
297 | prompts = prompt_list_rank[indice:indice+args.bs]
298 | videos = data_list_rank[indice:indice+args.bs]
299 | filenames = filename_list_rank[indice:indice+args.bs]
300 | if isinstance(videos, list):
301 | videos = torch.stack(videos, dim=0).to("cuda")
302 | else:
303 | videos = videos.unsqueeze(0).to("cuda")
304 |
305 | batch_samples = image_guided_synthesis(model, prompts, videos, noise_shape, args.n_samples, args.ddim_steps, args.ddim_eta, \
306 | args.unconditional_guidance_scale, args.cfg_img, args.frame_stride, args.text_input, args.multiple_cond_cfg, args.loop, args.interp, args.timestep_spacing, args.guidance_rescale)
307 |
308 | ## save each example individually
309 | for nn, samples in enumerate(batch_samples):
310 | ## samples : [n_samples,c,t,h,w]
311 | prompt = prompts[nn]
312 | filename = filenames[nn]
313 | # save_results(prompt, samples, filename, fakedir, fps=8, loop=args.loop)
314 | save_results_seperate(prompt, samples, filename, fakedir, fps=8, loop=args.loop)
315 |
316 | print(f"Saved in {args.savedir}. Time used: {(time.time() - start):.2f} seconds")
317 |
318 |
319 | def get_parser():
320 | parser = argparse.ArgumentParser()
321 | parser.add_argument("--savedir", type=str, default=None, help="results saving path")
322 | parser.add_argument("--ckpt_path", type=str, default=None, help="checkpoint path")
323 | parser.add_argument("--config", type=str, help="config (yaml) path")
324 | parser.add_argument("--prompt_dir", type=str, default=None, help="a data dir containing videos and prompts")
325 | parser.add_argument("--n_samples", type=int, default=1, help="num of samples per prompt",)
326 | parser.add_argument("--ddim_steps", type=int, default=50, help="steps of ddim if positive, otherwise use DDPM",)
327 | parser.add_argument("--ddim_eta", type=float, default=1.0, help="eta for ddim sampling (0.0 yields deterministic sampling)",)
328 | parser.add_argument("--bs", type=int, default=1, help="batch size for inference, should be one")
329 | parser.add_argument("--height", type=int, default=512, help="image height, in pixel space")
330 | parser.add_argument("--width", type=int, default=512, help="image width, in pixel space")
331 | parser.add_argument("--frame_stride", type=int, default=3, help="frame stride control for 256 model (larger->larger motion), FPS control for 512 or 1024 model (smaller->larger motion)")
332 | parser.add_argument("--unconditional_guidance_scale", type=float, default=1.0, help="prompt classifier-free guidance")
333 | parser.add_argument("--seed", type=int, default=123, help="seed for seed_everything")
334 | parser.add_argument("--video_length", type=int, default=16, help="inference video length")
335 | parser.add_argument("--negative_prompt", action='store_true', default=False, help="negative prompt")
336 | parser.add_argument("--text_input", action='store_true', default=False, help="input text to I2V model or not")
337 | parser.add_argument("--multiple_cond_cfg", action='store_true', default=False, help="use multi-condition cfg or not")
338 | parser.add_argument("--cfg_img", type=float, default=None, help="guidance scale for image conditioning")
339 | parser.add_argument("--timestep_spacing", type=str, default="uniform", help="The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.")
340 | parser.add_argument("--guidance_rescale", type=float, default=0.0, help="guidance rescale in [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891)")
341 |     parser.add_argument("--perframe_ae", action='store_true', default=False, help="decode the latents frame by frame with the AE to save GPU memory, especially for the 576x1024 model")
342 |
343 |     ## options for looping-video generation and generative frame interpolation
344 | parser.add_argument("--loop", action='store_true', default=False, help="generate looping videos or not")
345 | parser.add_argument("--interp", action='store_true', default=False, help="generate generative frame interpolation or not")
346 | return parser
347 |
348 |
349 | if __name__ == '__main__':
350 | now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
351 | print("@DynamiCrafter cond-Inference: %s"%now)
352 | parser = get_parser()
353 | args = parser.parse_args()
354 |
355 | seed_everything(args.seed)
356 | rank, gpu_num = 0, 1
357 | run_inference(args, gpu_num, rank)
--------------------------------------------------------------------------------
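A minimal sketch of driving image_guided_synthesis() from Python instead of the get_parser()/run_inference() CLI above. The import paths, checkpoint/config locations, and the random placeholder clip are illustrative assumptions (run from the repository root); only the function signatures come from the code above.

# Hedged sketch; module and file paths are assumptions, not fixed by this file.
import torch
from omegaconf import OmegaConf
from utils.utils import instantiate_from_config
from scripts.evaluation.funcs import load_model_checkpoint
from scripts.evaluation.inference import image_guided_synthesis

config = OmegaConf.load("configs/inference_512_v1.0.yaml")
model_config = config.pop("model", OmegaConf.create())
model_config["params"]["unet_config"]["params"]["use_checkpoint"] = False
model = instantiate_from_config(model_config).cuda().eval()
model = load_model_checkpoint(model, "checkpoints/dynamicrafter_512_v1/model.ckpt")  # assumed layout

h, w = 320 // 8, 512 // 8                               # latent size = pixel size / 8
channels = model.model.diffusion_model.out_channels
noise_shape = [1, channels, 16, h, w]                   # [bs, c, t, h, w]

# Placeholder (b, c, t, h, w) clip in [-1, 1]; frame 0 acts as the conditioning image.
videos = torch.rand(1, 3, 16, 320, 512, device="cuda") * 2 - 1

with torch.no_grad(), torch.cuda.amp.autocast():
    samples = image_guided_synthesis(
        model, ["a sailing boat at sunset"], videos, noise_shape,
        n_samples=1, ddim_steps=50, ddim_eta=1.0,
        unconditional_guidance_scale=7.5, fs=24, text_input=True)
print(samples.shape)  # (batch, n_samples, c, t, h, w)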
/scripts/gradio/i2v_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from omegaconf import OmegaConf
4 | import torch
5 | from ..evaluation.funcs import load_model_checkpoint, save_videos, batch_ddim_sampling, get_latent_z
6 | from ...utils.utils import instantiate_from_config
7 | from huggingface_hub import hf_hub_download
8 | from einops import repeat
9 | import torchvision.transforms as transforms
10 | from pytorch_lightning import seed_everything
11 | import folder_paths
12 |
13 | comfy_path = os.path.dirname(folder_paths.__file__)
14 | models_path=f'{comfy_path}/models/'
15 | config_path=f'{comfy_path}/custom_nodes/ComfyUI-DynamiCrafter/'
16 |
17 | class Image2Video():
18 | def __init__(self,result_dir='./tmp/',gpu_num=1,resolution='256_256',frame_length=16) -> None:
19 | self.resolution = (int(resolution.split('_')[0]), int(resolution.split('_')[1])) #hw
20 | self.download_model()
21 |
22 | self.result_dir = result_dir
23 | if not os.path.exists(self.result_dir):
24 | os.mkdir(self.result_dir)
25 | ckpt_path=models_path+'checkpoints/dynamicrafter_'+resolution.split('_')[1]+'_v1/model.ckpt'
26 | config_file=config_path+'configs/inference_'+resolution.split('_')[1]+'_v1.0.yaml'
27 | config = OmegaConf.load(config_file)
28 | OmegaConf.update(config, "model.params.unet_config.params.temporal_length", frame_length)
29 | model_config = config.pop("model", OmegaConf.create())
30 | model_config['params']['unet_config']['params']['use_checkpoint']=False
31 | model_list = []
32 | for gpu_id in range(gpu_num):
33 | model = instantiate_from_config(model_config)
34 | # model = model.cuda(gpu_id)
35 | assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
36 | model = load_model_checkpoint(model, ckpt_path)
37 | model.eval()
38 | model_list.append(model)
39 | self.model_list = model_list
40 | self.save_fps = 8
41 |
42 | def get_image(self, image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123):
43 | seed_everything(seed)
44 | transform = transforms.Compose([
45 | transforms.Resize(min(self.resolution)),
46 | transforms.CenterCrop(self.resolution),
47 | ])
48 | torch.cuda.empty_cache()
49 | print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
50 | start = time.time()
51 | gpu_id=0
52 | if steps > 60:
53 | steps = 60
54 | model = self.model_list[gpu_id]
55 | model = model.cuda()
56 | batch_size=1
57 | channels = model.model.diffusion_model.out_channels
58 | frames = model.temporal_length
59 | h, w = self.resolution[0] // 8, self.resolution[1] // 8
60 | noise_shape = [batch_size, channels, frames, h, w]
61 |
62 | # text cond
63 | text_emb = model.get_learned_conditioning([prompt])
64 |
65 | # img cond
66 | img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().to(model.device)
67 | img_tensor = (img_tensor / 255. - 0.5) * 2
68 |
69 | image_tensor_resized = transform(img_tensor) #3,h,w
70 | videos = image_tensor_resized.unsqueeze(0) # bchw
71 |
72 | z = get_latent_z(model, videos.unsqueeze(2)) #bc,1,hw
73 |
74 | img_tensor_repeat = repeat(z, 'b c t h w -> b c (repeat t) h w', repeat=frames)
75 |
76 | cond_images = model.embedder(img_tensor.unsqueeze(0)) ## blc
77 | img_emb = model.image_proj_model(cond_images)
78 |
79 | imtext_cond = torch.cat([text_emb, img_emb], dim=1)
80 |
81 | fs = torch.tensor([fs], dtype=torch.long, device=model.device)
82 | cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat]}
83 |
84 | ## inference
85 | batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale)
86 | ## b,samples,c,t,h,w
87 | prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt
88 | prompt_str = prompt_str.replace(" ", "_") if " " in prompt else prompt_str
89 | prompt_str=prompt_str[:40]
90 | if len(prompt_str) == 0:
91 | prompt_str = 'empty_prompt'
92 |
93 | imgs=save_videos(batch_samples, self.result_dir, filenames=[prompt_str], fps=self.save_fps)
94 | print(f"Saved in {prompt_str}. Time used: {(time.time() - start):.2f} seconds")
95 | model = model.cpu()
96 | #return os.path.join(self.result_dir, f"{prompt_str}.mp4")
97 | return imgs
98 |
99 | def download_model(self):
100 | REPO_ID = 'Doubiiu/DynamiCrafter_'+str(self.resolution[1]) if self.resolution[1]!=256 else 'Doubiiu/DynamiCrafter'
101 | filename_list = ['model.ckpt']
102 | if not os.path.exists(models_path+'checkpoints/dynamicrafter_'+str(self.resolution[1])+'_v1/'):
103 | os.makedirs(models_path+'checkpoints/dynamicrafter_'+str(self.resolution[1])+'_v1/')
104 | for filename in filename_list:
105 | local_file = os.path.join(models_path+'checkpoints/dynamicrafter_'+str(self.resolution[1])+'_v1/', filename)
106 | if not os.path.exists(local_file):
107 | hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir=models_path+'checkpoints/dynamicrafter_'+str(self.resolution[1])+'_v1/', local_dir_use_symlinks=False)
108 |
109 | if __name__ == '__main__':
110 | i2v = Image2Video()
111 | video_path = i2v.get_image('prompts/art.png','man fishing in a boat at sunset')
112 | print('done', video_path)
--------------------------------------------------------------------------------
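Note that get_image() above consumes an RGB numpy array of shape (H, W, 3) with values in 0-255 (it is passed straight to torch.from_numpy), whereas the __main__ block passes a file path. A hedged usage sketch under that assumption; the image path and prompt are illustrative, and the class must be importable in your install:

# Hedged sketch; Image2Video must be importable from this module in your install
# (ComfyUI normally loads it through nodes.py), and folder_paths requires a ComfyUI environment.
import numpy as np
from PIL import Image

i2v = Image2Video(resolution='320_512')   # resolves to <ComfyUI>/models/checkpoints/dynamicrafter_512_v1/model.ckpt
frame = np.array(Image.open('example_start.png').convert('RGB'))  # illustrative path, (H, W, 3) uint8
frames = i2v.get_image(frame, 'flowers blooming in a time lapse',
                       steps=50, cfg_scale=7.5, eta=1.0, fs=24, seed=123)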
/scripts/gradio/i2v_test_application.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from omegaconf import OmegaConf
4 | import torch
5 | from ..evaluation.funcs import load_model_checkpoint, save_videos, batch_ddim_sampling, get_latent_z
6 | from ...utils.utils import instantiate_from_config
7 | from huggingface_hub import hf_hub_download
8 | from einops import repeat
9 | import torchvision.transforms as transforms
10 | from pytorch_lightning import seed_everything
11 | import folder_paths
12 |
13 | comfy_path = os.path.dirname(folder_paths.__file__)
14 | models_path=f'{comfy_path}/models/'
15 | config_path=f'{comfy_path}/custom_nodes/ComfyUI-DynamiCrafter/'
16 |
17 | class Image2Video():
18 | def __init__(self,result_dir='./tmp/',gpu_num=1,resolution='256_256',frame_length=16) -> None:
19 | self.resolution = (int(resolution.split('_')[0]), int(resolution.split('_')[1])) #hw
20 | self.download_model()
21 |
22 | self.result_dir = result_dir
23 | if not os.path.exists(self.result_dir):
24 | os.mkdir(self.result_dir)
25 | ckpt_path=models_path+'checkpoints/dynamicrafter_'+resolution.split('_')[1]+'_interp_v1/model.ckpt'
26 | config_file=config_path+'configs/inference_'+resolution.split('_')[1]+'_v1.0.yaml'
27 | config = OmegaConf.load(config_file)
28 | OmegaConf.update(config, "model.params.unet_config.params.temporal_length", frame_length)
29 | model_config = config.pop("model", OmegaConf.create())
30 | model_config['params']['unet_config']['params']['use_checkpoint']=False
31 | model_list = []
32 | for gpu_id in range(gpu_num):
33 | model = instantiate_from_config(model_config)
34 | # model = model.cuda(gpu_id)
35 | assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
36 | model = load_model_checkpoint(model, ckpt_path)
37 | model.eval()
38 | model_list.append(model)
39 | self.model_list = model_list
40 | self.save_fps = 8
41 |
42 | def get_image(self, image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, image2=None):
43 | seed_everything(seed)
44 | transform = transforms.Compose([
45 | transforms.Resize(min(self.resolution)),
46 | transforms.CenterCrop(self.resolution),
47 | ])
48 | torch.cuda.empty_cache()
49 | print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
50 | start = time.time()
51 | gpu_id=0
52 | if steps > 60:
53 | steps = 60
54 | model = self.model_list[gpu_id]
55 | model = model.cuda()
56 | batch_size=1
57 | channels = model.model.diffusion_model.out_channels
58 | frames = model.temporal_length
59 | h, w = self.resolution[0] // 8, self.resolution[1] // 8
60 | noise_shape = [batch_size, channels, frames, h, w]
61 |
62 | # text cond
63 | with torch.no_grad(), torch.cuda.amp.autocast():
64 | text_emb = model.get_learned_conditioning([prompt])
65 |
66 | # img cond
67 | img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().to(model.device)
68 | img_tensor = (img_tensor / 255. - 0.5) * 2
69 |
70 | image_tensor_resized = transform(img_tensor) #3,h,w
71 | videos = image_tensor_resized.unsqueeze(0) # bchw
72 |
73 | z = get_latent_z(model, videos.unsqueeze(2)) #bc,1,hw
74 |
75 |
76 | if image2 is not None:
77 | img_tensor2 = torch.from_numpy(image2).permute(2, 0, 1).float().to(model.device)
78 | img_tensor2 = (img_tensor2 / 255. - 0.5) * 2
79 |
80 | image_tensor_resized2 = transform(img_tensor2) #3,h,w
81 | videos2 = image_tensor_resized2.unsqueeze(0) # bchw
82 |
83 | z2 = get_latent_z(model, videos2.unsqueeze(2)) #bc,1,hw
84 |
85 | img_tensor_repeat = repeat(z, 'b c t h w -> b c (repeat t) h w', repeat=frames)
86 |
87 | img_tensor_repeat = torch.zeros_like(img_tensor_repeat)
88 |
89 |             ## write the start-frame latent (and the end-frame latent, if provided) into the concat condition
90 | img_tensor_repeat[:,:,:1,:,:] = z
91 | if image2 is not None:
92 | img_tensor_repeat[:,:,-1:,:,:] = z2
93 | else:
94 | img_tensor_repeat[:,:,-1:,:,:] = z
95 |
96 |
97 | cond_images = model.embedder(img_tensor.unsqueeze(0)) ## blc
98 | img_emb = model.image_proj_model(cond_images)
99 |
100 | imtext_cond = torch.cat([text_emb, img_emb], dim=1)
101 |
102 | fs = torch.tensor([fs], dtype=torch.long, device=model.device)
103 | cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat]}
104 |
105 | ## inference
106 | batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale)
107 |
108 |         ## single-image mode: drop the last frame, whose condition reused the start frame
109 | if image2 is None:
110 | batch_samples = batch_samples[:,:,:,:-1,...]
111 | ## b,samples,c,t,h,w
112 | prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt
113 | prompt_str = prompt_str.replace(" ", "_") if " " in prompt else prompt_str
114 | prompt_str=prompt_str[:40]
115 | if len(prompt_str) == 0:
116 | prompt_str = 'empty_prompt'
117 |
118 | imgs=save_videos(batch_samples, self.result_dir, filenames=[prompt_str], fps=self.save_fps)
119 | print(f"Saved in {prompt_str}. Time used: {(time.time() - start):.2f} seconds")
120 | model = model.cpu()
121 | #return os.path.join(self.result_dir, f"{prompt_str}.mp4")
122 | return imgs
123 |
124 | def download_model(self):
125 | REPO_ID = 'Doubiiu/DynamiCrafter_'+str(self.resolution[1])+'_Interp'
126 | filename_list = ['model.ckpt']
127 | if not os.path.exists(models_path+'checkpoints/dynamicrafter_'+str(self.resolution[1])+'_interp_v1/'):
128 | os.makedirs(models_path+'checkpoints/dynamicrafter_'+str(self.resolution[1])+'_interp_v1/')
129 | for filename in filename_list:
130 | local_file = os.path.join(models_path+'checkpoints/dynamicrafter_'+str(self.resolution[1])+'_interp_v1/', filename)
131 | if not os.path.exists(local_file):
132 | hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir=models_path+'checkpoints/dynamicrafter_'+str(self.resolution[1])+'_interp_v1/', local_dir_use_symlinks=False)
133 |
134 | if __name__ == '__main__':
135 | i2v = Image2Video()
136 | video_path = i2v.get_image('prompts/art.png','man fishing in a boat at sunset')
137 | print('done', video_path)
--------------------------------------------------------------------------------
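The interpolation variant accepts a second keyframe via image2; both keyframes go in as RGB numpy arrays, mirroring the note above. A hedged sketch with illustrative file paths, using the fs=5 and seed=12306 settings from run_application.sh:

# Hedged sketch; import of Image2Video as in the previous note (install-dependent).
import numpy as np
from PIL import Image

i2v = Image2Video(resolution='320_512')                    # loads the dynamicrafter_512_interp_v1 checkpoint
start = np.array(Image.open('start.png').convert('RGB'))   # illustrative paths
end = np.array(Image.open('end.png').convert('RGB'))
frames = i2v.get_image(start, 'smooth transition between the two scenes',
                       steps=50, cfg_scale=7.5, eta=1.0, fs=5, seed=12306, image2=end)
# With image2=None the end condition reuses the start frame and the last output frame is dropped.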
/scripts/run.sh:
--------------------------------------------------------------------------------
1 | version=$1 ##1024, 512, 256
2 | seed=123
3 | name=dynamicrafter_$1_seed${seed}
4 |
5 | ckpt=checkpoints/dynamicrafter_$1_v1/model.ckpt
6 | config=configs/inference_$1_v1.0.yaml
7 |
8 | prompt_dir=prompts/$1/
9 | res_dir="results"
10 |
11 | if [ "$1" == "256" ]; then
12 | H=256
13 | FS=3 ## This model adopts frame stride=3, range recommended: 1-6 (larger value -> larger motion)
14 | elif [ "$1" == "512" ]; then
15 | H=320
16 | FS=24 ## This model adopts FPS=24, range recommended: 15-30 (smaller value -> larger motion)
17 | elif [ "$1" == "1024" ]; then
18 | H=576
19 |   FS=10 ## This model adopts FPS=10, range recommended: 5-15 (smaller value -> larger motion)
20 | else
21 | echo "Invalid input. Please enter 256, 512, or 1024."
22 | exit 1
23 | fi
24 |
25 | if [ "$1" == "256" ]; then
26 | CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/inference.py \
27 | --seed ${seed} \
28 | --ckpt_path $ckpt \
29 | --config $config \
30 | --savedir $res_dir/$name \
31 | --n_samples 1 \
32 | --bs 1 --height ${H} --width $1 \
33 | --unconditional_guidance_scale 7.5 \
34 | --ddim_steps 50 \
35 | --ddim_eta 1.0 \
36 | --prompt_dir $prompt_dir \
37 | --text_input \
38 | --video_length 16 \
39 | --frame_stride ${FS}
40 | else
41 | CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/inference.py \
42 | --seed ${seed} \
43 | --ckpt_path $ckpt \
44 | --config $config \
45 | --savedir $res_dir/$name \
46 | --n_samples 1 \
47 | --bs 1 --height ${H} --width $1 \
48 | --unconditional_guidance_scale 7.5 \
49 | --ddim_steps 50 \
50 | --ddim_eta 1.0 \
51 | --prompt_dir $prompt_dir \
52 | --text_input \
53 | --video_length 16 \
54 | --frame_stride ${FS} \
55 | --timestep_spacing 'uniform_trailing' --guidance_rescale 0.7 --perframe_ae
56 | fi
57 |
58 |
59 | ## multi-cond CFG: --unconditional_guidance_scale is s_txt, --cfg_img is s_img
60 | #--multiple_cond_cfg --cfg_img 7.5
61 | #--loop
--------------------------------------------------------------------------------
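The three branches above only differ in height, frame stride, and the extra flags for the larger models. A hedged Python equivalent that mirrors the flags in run.sh (all paths follow the same checkpoint/config layout the script assumes, and it must be launched from the repository root):

# Hedged sketch: the same presets as run.sh, launched from Python via subprocess.
import subprocess

PRESETS = {                      # width -> (height, frame stride / FPS)
    "256":  (256, 3),            # frame stride, recommended 1-6
    "512":  (320, 24),           # FPS, recommended 15-30
    "1024": (576, 10),           # FPS, recommended 5-15
}

def run(version: str, seed: int = 123) -> None:
    height, fs = PRESETS[version]
    cmd = [
        "python3", "scripts/evaluation/inference.py",
        "--seed", str(seed),
        "--ckpt_path", f"checkpoints/dynamicrafter_{version}_v1/model.ckpt",
        "--config", f"configs/inference_{version}_v1.0.yaml",
        "--savedir", f"results/dynamicrafter_{version}_seed{seed}",
        "--n_samples", "1", "--bs", "1",
        "--height", str(height), "--width", version,
        "--unconditional_guidance_scale", "7.5",
        "--ddim_steps", "50", "--ddim_eta", "1.0",
        "--prompt_dir", f"prompts/{version}/",
        "--text_input", "--video_length", "16",
        "--frame_stride", str(fs),
    ]
    if version != "256":         # larger models add trailing spacing, guidance rescale, per-frame AE decoding
        cmd += ["--timestep_spacing", "uniform_trailing",
                "--guidance_rescale", "0.7", "--perframe_ae"]
    subprocess.run(cmd, check=True)

run("512")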
/scripts/run_application.sh:
--------------------------------------------------------------------------------
1 | version=$1 # interp or loop
2 | ckpt=checkpoints/dynamicrafter_512_interp_v1/model.ckpt
3 | config=configs/inference_512_v1.0.yaml
4 |
5 | prompt_dir=prompts/512_$1/
6 | res_dir="results"
7 |
8 | FS=5 ## This model adopts FPS=5, range recommended: 5-30 (smaller value -> larger motion)
9 |
10 |
11 | if [ "$1" == "interp" ]; then
12 | seed=12306
13 | name=dynamicrafter_512_$1_seed${seed}
14 | CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/inference.py \
15 | --seed ${seed} \
16 | --ckpt_path $ckpt \
17 | --config $config \
18 | --savedir $res_dir/$name \
19 | --n_samples 1 \
20 | --bs 1 --height 320 --width 512 \
21 | --unconditional_guidance_scale 7.5 \
22 | --ddim_steps 50 \
23 | --ddim_eta 1.0 \
24 | --prompt_dir $prompt_dir \
25 | --text_input \
26 | --video_length 16 \
27 | --frame_stride ${FS} \
28 | --timestep_spacing 'uniform_trailing' --guidance_rescale 0.7 --perframe_ae --interp
29 | else
30 | seed=234
31 | name=dynamicrafter_512_$1_seed${seed}
32 | CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/inference.py \
33 | --seed ${seed} \
34 | --ckpt_path $ckpt \
35 | --config $config \
36 | --savedir $res_dir/$name \
37 | --n_samples 1 \
38 | --bs 1 --height 320 --width 512 \
39 | --unconditional_guidance_scale 7.5 \
40 | --ddim_steps 50 \
41 | --ddim_eta 1.0 \
42 | --prompt_dir $prompt_dir \
43 | --text_input \
44 | --video_length 16 \
45 | --frame_stride ${FS} \
46 | --timestep_spacing 'uniform_trailing' --guidance_rescale 0.7 --perframe_ae --loop
47 | fi
48 |
--------------------------------------------------------------------------------
/scripts/run_mp.sh:
--------------------------------------------------------------------------------
1 | version=$1 ##1024, 512, 256
2 | seed=123
3 |
4 | name=dynamicrafter_$1_mp_seed${seed}
5 |
6 | ckpt=checkpoints/dynamicrafter_$1_v1/model.ckpt
7 | config=configs/inference_$1_v1.0.yaml
8 |
9 | prompt_dir=prompts/$1/
10 | res_dir="results"
11 |
12 | if [ "$1" == "256" ]; then
13 | H=256
14 | FS=3 ## This model adopts frame stride=3
15 | elif [ "$1" == "512" ]; then
16 | H=320
17 | FS=24 ## This model adopts FPS=24
18 | elif [ "$1" == "1024" ]; then
19 | H=576
20 | FS=10 ## This model adopts FPS=10
21 | else
22 | echo "Invalid input. Please enter 256, 512, or 1024."
23 | exit 1
24 | fi
25 |
26 | # if [ "$1" == "256" ]; then
27 | # CUDA_VISIBLE_DEVICES=2 python3 scripts/evaluation/inference.py \
28 | # --seed 123 \
29 | # --ckpt_path $ckpt \
30 | # --config $config \
31 | # --savedir $res_dir/$name \
32 | # --n_samples 1 \
33 | # --bs 1 --height ${H} --width $1 \
34 | # --unconditional_guidance_scale 7.5 \
35 | # --ddim_steps 50 \
36 | # --ddim_eta 1.0 \
37 | # --prompt_dir $prompt_dir \
38 | # --text_input \
39 | # --video_length 16 \
40 | # --frame_stride ${FS}
41 | # else
42 | # CUDA_VISIBLE_DEVICES=2 python3 scripts/evaluation/inference.py \
43 | # --seed 123 \
44 | # --ckpt_path $ckpt \
45 | # --config $config \
46 | # --savedir $res_dir/$name \
47 | # --n_samples 1 \
48 | # --bs 1 --height ${H} --width $1 \
49 | # --unconditional_guidance_scale 7.5 \
50 | # --ddim_steps 50 \
51 | # --ddim_eta 1.0 \
52 | # --prompt_dir $prompt_dir \
53 | # --text_input \
54 | # --video_length 16 \
55 | # --frame_stride ${FS} \
56 | # --timestep_spacing 'uniform_trailing' --guidance_rescale 0.7
57 | # fi
58 |
59 |
60 | ## multi-cond CFG: --unconditional_guidance_scale is s_txt, --cfg_img is s_img
61 | #--multiple_cond_cfg --cfg_img 7.5
62 | #--loop
63 |
64 | ## inference using single node with multi-GPUs:
65 | if [ "$1" == "256" ]; then
66 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch \
67 | --nproc_per_node=8 --nnodes=1 --master_addr=127.0.0.1 --master_port=23456 --node_rank=0 \
68 | scripts/evaluation/ddp_wrapper.py \
69 | --module 'inference' \
70 | --seed ${seed} \
71 | --ckpt_path $ckpt \
72 | --config $config \
73 | --savedir $res_dir/$name \
74 | --n_samples 1 \
75 | --bs 1 --height ${H} --width $1 \
76 | --unconditional_guidance_scale 7.5 \
77 | --ddim_steps 50 \
78 | --ddim_eta 1.0 \
79 | --prompt_dir $prompt_dir \
80 | --text_input \
81 | --video_length 16 \
82 | --frame_stride ${FS}
83 | else
84 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch \
85 | --nproc_per_node=8 --nnodes=1 --master_addr=127.0.0.1 --master_port=23456 --node_rank=0 \
86 | scripts/evaluation/ddp_wrapper.py \
87 | --module 'inference' \
88 | --seed ${seed} \
89 | --ckpt_path $ckpt \
90 | --config $config \
91 | --savedir $res_dir/$name \
92 | --n_samples 1 \
93 | --bs 1 --height ${H} --width $1 \
94 | --unconditional_guidance_scale 7.5 \
95 | --ddim_steps 50 \
96 | --ddim_eta 1.0 \
97 | --prompt_dir $prompt_dir \
98 | --text_input \
99 | --video_length 16 \
100 | --frame_stride ${FS} \
101 | --timestep_spacing 'uniform_trailing' --guidance_rescale 0.7 --perframe_ae
102 | fi
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import importlib
3 | import numpy as np
4 | import cv2
5 | import torch
6 | import torch.distributed as dist
7 |
8 | def count_params(model, verbose=False):
9 | total_params = sum(p.numel() for p in model.parameters())
10 | if verbose:
11 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
12 | return total_params
13 |
14 |
15 | def check_istarget(name, para_list):
16 | """
17 | name: full name of source para
18 | para_list: partial name of target para
19 | """
20 | istarget=False
21 | for para in para_list:
22 | if para in name:
23 | return True
24 | return istarget
25 |
26 |
27 | def instantiate_from_config(config):
28 |     if "target" not in config:
29 | if config == '__is_first_stage__':
30 | return None
31 | elif config == "__is_unconditional__":
32 | return None
33 | raise KeyError("Expected key `target` to instantiate.")
34 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
35 |
36 |
37 | def get_obj_from_str(string, reload=False):
38 | module, cls = string.rsplit(".", 1)
39 | if reload:
40 | module_imp = importlib.import_module(module)
41 | importlib.reload(module_imp)
42 | return getattr(importlib.import_module(module, package=None), cls)
43 |
44 |
45 | def load_npz_from_dir(data_dir):
46 | data = [np.load(os.path.join(data_dir, data_name))['arr_0'] for data_name in os.listdir(data_dir)]
47 | data = np.concatenate(data, axis=0)
48 | return data
49 |
50 |
51 | def load_npz_from_paths(data_paths):
52 | data = [np.load(data_path)['arr_0'] for data_path in data_paths]
53 | data = np.concatenate(data, axis=0)
54 | return data
55 |
56 |
57 | def resize_numpy_image(image, max_resolution=512 * 512, resize_short_edge=None):
58 | h, w = image.shape[:2]
59 | if resize_short_edge is not None:
60 | k = resize_short_edge / min(h, w)
61 | else:
62 | k = max_resolution / (h * w)
63 | k = k**0.5
64 | h = int(np.round(h * k / 64)) * 64
65 | w = int(np.round(w * k / 64)) * 64
66 | image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4)
67 | return image
68 |
69 |
70 | def setup_dist(args):
71 | if dist.is_initialized():
72 | return
73 | torch.cuda.set_device(args.local_rank)
74 | torch.distributed.init_process_group(
75 | 'nccl',
76 | init_method='env://'
77 | )
--------------------------------------------------------------------------------
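instantiate_from_config() simply resolves the dotted "target" string and calls it with "params" as keyword arguments, which is how the models in the YAML configs of this repo are built. A small hedged example; the target class here is illustrative, not one of the repo's own:

# Hedged example; assumes the repository root is on sys.path.
from omegaconf import OmegaConf
from utils.utils import instantiate_from_config

cfg = OmegaConf.create({
    "target": "torch.nn.Linear",                 # illustrative target; the real configs point at lvdm classes
    "params": {"in_features": 4, "out_features": 2},
})
layer = instantiate_from_config(cfg)             # equivalent to torch.nn.Linear(in_features=4, out_features=2)
print(layer)                                     # Linear(in_features=4, out_features=2, bias=True)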
/video.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/video.gif
--------------------------------------------------------------------------------
/wf-basic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/wf-basic.png
--------------------------------------------------------------------------------
/wf-interp.json:
--------------------------------------------------------------------------------
1 | {
2 | "last_node_id": 9,
3 | "last_link_id": 12,
4 | "nodes": [
5 | {
6 | "id": 1,
7 | "type": "DynamiCrafterInterpLoader",
8 | "pos": [
9 | 35,
10 | 194
11 | ],
12 | "size": {
13 | "0": 315,
14 | "1": 82
15 | },
16 | "flags": {},
17 | "order": 0,
18 | "mode": 0,
19 | "outputs": [
20 | {
21 | "name": "model",
22 | "type": "DynamiCrafterInter",
23 | "links": [
24 | 1
25 | ],
26 | "shape": 3,
27 | "slot_index": 0
28 | }
29 | ],
30 | "properties": {
31 | "Node name for S&R": "DynamiCrafterInterpLoader"
32 | },
33 | "widgets_values": [
34 | "320_512",
35 | 16
36 | ]
37 | },
38 | {
39 | "id": 3,
40 | "type": "LoadImage",
41 | "pos": [
42 | 90,
43 | 412
44 | ],
45 | "size": {
46 | "0": 315,
47 | "1": 314
48 | },
49 | "flags": {
50 | "collapsed": true
51 | },
52 | "order": 1,
53 | "mode": 0,
54 | "outputs": [
55 | {
56 | "name": "IMAGE",
57 | "type": "IMAGE",
58 | "links": [
59 | 9
60 | ],
61 | "shape": 3,
62 | "slot_index": 0
63 | },
64 | {
65 | "name": "MASK",
66 | "type": "MASK",
67 | "links": null,
68 | "shape": 3
69 | }
70 | ],
71 | "properties": {
72 | "Node name for S&R": "LoadImage"
73 | },
74 | "widgets_values": [
75 | "u=2397542458,3133539061&fm=193&f=GIF.jfif",
76 | "image"
77 | ]
78 | },
79 | {
80 | "id": 4,
81 | "type": "LoadImage",
82 | "pos": [
83 | 96,
84 | 513
85 | ],
86 | "size": {
87 | "0": 315,
88 | "1": 314
89 | },
90 | "flags": {
91 | "collapsed": true
92 | },
93 | "order": 2,
94 | "mode": 0,
95 | "outputs": [
96 | {
97 | "name": "IMAGE",
98 | "type": "IMAGE",
99 | "links": [
100 | 10
101 | ],
102 | "shape": 3,
103 | "slot_index": 0
104 | },
105 | {
106 | "name": "MASK",
107 | "type": "MASK",
108 | "links": null,
109 | "shape": 3
110 | }
111 | ],
112 | "properties": {
113 | "Node name for S&R": "LoadImage"
114 | },
115 | "widgets_values": [
116 | "u=2511982910,2454873241&fm=193&f=GIF.jfif",
117 | "image"
118 | ]
119 | },
120 | {
121 | "id": 9,
122 | "type": "VHS_VideoCombine",
123 | "pos": [
124 | 874,
125 | 262
126 | ],
127 | "size": [
128 | 315,
129 | 488.375
130 | ],
131 | "flags": {},
132 | "order": 4,
133 | "mode": 0,
134 | "inputs": [
135 | {
136 | "name": "images",
137 | "type": "IMAGE",
138 | "link": 12
139 | },
140 | {
141 | "name": "audio",
142 | "type": "VHS_AUDIO",
143 | "link": null
144 | },
145 | {
146 | "name": "batch_manager",
147 | "type": "VHS_BatchManager",
148 | "link": null
149 | }
150 | ],
151 | "outputs": [
152 | {
153 | "name": "Filenames",
154 | "type": "VHS_FILENAMES",
155 | "links": null,
156 | "shape": 3
157 | }
158 | ],
159 | "properties": {
160 | "Node name for S&R": "VHS_VideoCombine"
161 | },
162 | "widgets_values": {
163 | "frame_rate": 8,
164 | "loop_count": 0,
165 | "filename_prefix": "AnimateDiff",
166 | "format": "video/h264-mp4",
167 | "pix_fmt": "yuv420p",
168 | "crf": 19,
169 | "save_metadata": true,
170 | "pingpong": false,
171 | "save_output": true,
172 | "videopreview": {
173 | "hidden": false,
174 | "paused": false,
175 | "params": {
176 | "filename": "AnimateDiff_00058.mp4",
177 | "subfolder": "",
178 | "type": "output",
179 | "format": "video/h264-mp4"
180 | }
181 | }
182 | }
183 | },
184 | {
185 | "id": 2,
186 | "type": "DynamiCrafterInterp Simple",
187 | "pos": [
188 | 394,
189 | 297
190 | ],
191 | "size": {
192 | "0": 315,
193 | "1": 242
194 | },
195 | "flags": {},
196 | "order": 3,
197 | "mode": 0,
198 | "inputs": [
199 | {
200 | "name": "model",
201 | "type": "DynamiCrafterInter",
202 | "link": 1
203 | },
204 | {
205 | "name": "image",
206 | "type": "IMAGE",
207 | "link": 9
208 | },
209 | {
210 | "name": "image1",
211 | "type": "IMAGE",
212 | "link": 10
213 | }
214 | ],
215 | "outputs": [
216 | {
217 | "name": "image",
218 | "type": "IMAGE",
219 | "links": [
220 | 12
221 | ],
222 | "shape": 3,
223 | "slot_index": 0
224 | }
225 | ],
226 | "properties": {
227 | "Node name for S&R": "DynamiCrafterInterp Simple"
228 | },
229 | "widgets_values": [
230 | "",
231 | 50,
232 | 7.5,
233 | 1,
234 | 3,
235 | 1489,
236 | "randomize"
237 | ]
238 | }
239 | ],
240 | "links": [
241 | [
242 | 1,
243 | 1,
244 | 0,
245 | 2,
246 | 0,
247 | "DynamiCrafterInter"
248 | ],
249 | [
250 | 9,
251 | 3,
252 | 0,
253 | 2,
254 | 1,
255 | "IMAGE"
256 | ],
257 | [
258 | 10,
259 | 4,
260 | 0,
261 | 2,
262 | 2,
263 | "IMAGE"
264 | ],
265 | [
266 | 12,
267 | 2,
268 | 0,
269 | 9,
270 | 0,
271 | "IMAGE"
272 | ]
273 | ],
274 | "groups": [],
275 | "config": {},
276 | "extra": {},
277 | "version": 0.4
278 | }
--------------------------------------------------------------------------------
/wf-interp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaojie/ComfyUI-DynamiCrafter/0c9821c062f60b3b36f743a20f3b2f3961e83ac7/wf-interp.png
--------------------------------------------------------------------------------
/workflow.json:
--------------------------------------------------------------------------------
1 | {
2 | "last_node_id": 4,
3 | "last_link_id": 3,
4 | "nodes": [
5 | {
6 | "id": 3,
7 | "type": "LoadImage",
8 | "pos": [
9 | 143,
10 | 296
11 | ],
12 | "size": {
13 | "0": 315,
14 | "1": 314
15 | },
16 | "flags": {},
17 | "order": 0,
18 | "mode": 0,
19 | "outputs": [
20 | {
21 | "name": "IMAGE",
22 | "type": "IMAGE",
23 | "links": [
24 | 2
25 | ],
26 | "shape": 3,
27 | "slot_index": 0
28 | },
29 | {
30 | "name": "MASK",
31 | "type": "MASK",
32 | "links": null,
33 | "shape": 3
34 | }
35 | ],
36 | "properties": {
37 | "Node name for S&R": "LoadImage"
38 | },
39 | "widgets_values": [
40 | "00031-4116268099-1girl_AND trees_AND grass_AND sky_AND 1girl (1).png",
41 | "image"
42 | ]
43 | },
44 | {
45 | "id": 2,
46 | "type": "DynamiCrafter Simple",
47 | "pos": [
48 | 604,
49 | 141.33334350585938
50 | ],
51 | "size": {
52 | "0": 315,
53 | "1": 222
54 | },
55 | "flags": {},
56 | "order": 2,
57 | "mode": 0,
58 | "inputs": [
59 | {
60 | "name": "model",
61 | "type": "DynamiCrafter",
62 | "link": 1
63 | },
64 | {
65 | "name": "image",
66 | "type": "IMAGE",
67 | "link": 2
68 | }
69 | ],
70 | "outputs": [
71 | {
72 | "name": "image",
73 | "type": "IMAGE",
74 | "links": [
75 | 3
76 | ],
77 | "shape": 3,
78 | "slot_index": 0
79 | }
80 | ],
81 | "properties": {
82 | "Node name for S&R": "DynamiCrafter Simple"
83 | },
84 | "widgets_values": [
85 | "1girl sleeping",
86 | 50,
87 | 7.5,
88 | 1,
89 | 3,
90 | 1022,
91 | "randomize"
92 | ]
93 | },
94 | {
95 | "id": 4,
96 | "type": "VHS_VideoCombine",
97 | "pos": [
98 | 1061,
99 | 144
100 | ],
101 | "size": [
102 | 315,
103 | 414.3125
104 | ],
105 | "flags": {},
106 | "order": 3,
107 | "mode": 0,
108 | "inputs": [
109 | {
110 | "name": "images",
111 | "type": "IMAGE",
112 | "link": 3
113 | }
114 | ],
115 | "outputs": [],
116 | "properties": {
117 | "Node name for S&R": "VHS_VideoCombine"
118 | },
119 | "widgets_values": {
120 | "frame_rate": 8,
121 | "loop_count": 0,
122 | "filename_prefix": "AnimateDiff",
123 | "format": "image/gif",
124 | "pingpong": false,
125 | "save_image": true,
126 | "crf": 20,
127 | "save_metadata": true,
128 | "audio_file": "",
129 | "videopreview": {
130 | "hidden": false,
131 | "paused": false,
132 | "params": {
133 | "filename": "AnimateDiff_00900.gif",
134 | "subfolder": "",
135 | "type": "output",
136 | "format": "image/gif"
137 | }
138 | }
139 | }
140 | },
141 | {
142 | "id": 1,
143 | "type": "DynamiCrafterLoader",
144 | "pos": [
145 | 193,
146 | 148
147 | ],
148 | "size": {
149 | "0": 315,
150 | "1": 82
151 | },
152 | "flags": {},
153 | "order": 1,
154 | "mode": 0,
155 | "outputs": [
156 | {
157 | "name": "model",
158 | "type": "DynamiCrafter",
159 | "links": [
160 | 1
161 | ],
162 | "shape": 3,
163 | "slot_index": 0
164 | }
165 | ],
166 | "properties": {
167 | "Node name for S&R": "DynamiCrafterLoader"
168 | },
169 | "widgets_values": [
170 | "576_1024",
171 | 16
172 | ]
173 | }
174 | ],
175 | "links": [
176 | [
177 | 1,
178 | 1,
179 | 0,
180 | 2,
181 | 0,
182 | "DynamiCrafter"
183 | ],
184 | [
185 | 2,
186 | 3,
187 | 0,
188 | 2,
189 | 1,
190 | "IMAGE"
191 | ],
192 | [
193 | 3,
194 | 2,
195 | 0,
196 | 4,
197 | 0,
198 | "IMAGE"
199 | ]
200 | ],
201 | "groups": [],
202 | "config": {},
203 | "extra": {},
204 | "version": 0.4
205 | }
--------------------------------------------------------------------------------