├── assets
│   ├── schemes.png
│   ├── teaser.webp
│   └── Pipeline.webp
├── .gitignore
├── diffusion
│   ├── __pycache__
│   │   ├── respace.cpython-312.pyc
│   │   ├── __init__.cpython-312.pyc
│   │   ├── diffusion_utils.cpython-312.pyc
│   │   └── gaussian_diffusion.cpython-312.pyc
│   ├── __init__.py
│   ├── diffusion_utils.py
│   ├── respace.py
│   └── timestep_sampler.py
├── preparation
│   ├── change_fps.sh
│   ├── extract_audio.py
│   ├── process_audio.py
│   ├── extract_frames.py
│   └── audio_processor.py
├── cache.sh
├── datasets
│   ├── __init__.py
│   ├── frames_dataset.py
│   └── video_transforms.py
├── configs
│   ├── sample.yaml
│   └── train.yaml
├── models
│   ├── vae
│   │   ├── __init__.py
│   │   └── vae.py
│   ├── __init__.py
│   ├── audio_proj.py
│   ├── utils.py
│   ├── basic_modules.py
│   └── model_cat.py
├── README.md
├── sample_pl.py
├── sample.py
├── sample_long.py
├── sample_long_pl.py
├── train_pl.py
└── utils.py
/assets/schemes.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhang-haojie/LetsTalk/HEAD/assets/schemes.png
--------------------------------------------------------------------------------
/assets/teaser.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhang-haojie/LetsTalk/HEAD/assets/teaser.webp
--------------------------------------------------------------------------------
/assets/Pipeline.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhang-haojie/LetsTalk/HEAD/assets/Pipeline.webp
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.py[cod]
3 | *$py.class
4 |
5 | pretrained*
6 | results*
7 | sample_videos*
--------------------------------------------------------------------------------
/diffusion/__pycache__/respace.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhang-haojie/LetsTalk/HEAD/diffusion/__pycache__/respace.cpython-312.pyc
--------------------------------------------------------------------------------
/diffusion/__pycache__/__init__.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhang-haojie/LetsTalk/HEAD/diffusion/__pycache__/__init__.cpython-312.pyc
--------------------------------------------------------------------------------
/diffusion/__pycache__/diffusion_utils.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhang-haojie/LetsTalk/HEAD/diffusion/__pycache__/diffusion_utils.cpython-312.pyc
--------------------------------------------------------------------------------
/diffusion/__pycache__/gaussian_diffusion.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhang-haojie/LetsTalk/HEAD/diffusion/__pycache__/gaussian_diffusion.cpython-312.pyc
--------------------------------------------------------------------------------
/preparation/change_fps.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | SOURCE_FOLDER="/path/to/dataset/videos"
3 | OUTPUT_FOLDER="/path/to/dataset/videos_25fps"
4 | TARGET_FRAMERATE=25
5 |
6 | mkdir -p "$OUTPUT_FOLDER"
7 |
8 | export SOURCE_FOLDER OUTPUT_FOLDER TARGET_FRAMERATE
9 |
10 | find "$SOURCE_FOLDER" -name '*.mp4' | parallel ffmpeg -i {} -r "$TARGET_FRAMERATE" "$OUTPUT_FOLDER/{/.}.mp4"
11 |
12 | echo "Frame rate conversion completed for all videos."
13 |
--------------------------------------------------------------------------------
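The script above relies on GNU parallel and ffmpeg. For reference, here is a minimal Python sketch of the same 25 fps conversion run sequentially; it is not part of the repository, and the paths are the same placeholders as in the script.

```python
# Hedged sketch: sequential 25 fps conversion with ffmpeg (must be on PATH).
import subprocess
from pathlib import Path

SOURCE_FOLDER = Path("/path/to/dataset/videos")        # placeholder, as above
OUTPUT_FOLDER = Path("/path/to/dataset/videos_25fps")  # placeholder, as above
TARGET_FRAMERATE = "25"

OUTPUT_FOLDER.mkdir(parents=True, exist_ok=True)
for video in sorted(SOURCE_FOLDER.glob("*.mp4")):
    subprocess.run(
        ["ffmpeg", "-y", "-i", str(video), "-r", TARGET_FRAMERATE,
         str(OUTPUT_FOLDER / video.name)],
        check=True,
    )
```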
/cache.sh:
--------------------------------------------------------------------------------
1 | pip install pytorch-lightning -i https://mirrors.tencent.com/tencent_pypi/simple
2 | pip install omegaconf -i https://mirrors.tencent.com/tencent_pypi/simple
3 | pip install diffusers -i https://mirrors.tencent.com/tencent_pypi/simple
4 | pip install tensorboard -i https://mirrors.tencent.com/tencent_pypi/simple
5 | pip install moviepy==1.0.3 -i https://mirrors.tencent.com/tencent_pypi/simple
6 |
7 |
8 | python3 datasets/images2latent.py
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 | from datasets import video_transforms
3 | from .frames_dataset import VideoFramesDataset, VideoLatentDataset, FramesLatentDataset, VideoDubbingDataset
4 |
5 |
6 | def get_dataset(args):
7 | temporal_sample = video_transforms.TemporalRandomCrop((args.num_frames + args.initial_frames) * args.frame_interval)
8 |
9 | transform = transforms.Compose([
10 | video_transforms.ToTensorVideo(), # TCHW
11 | video_transforms.UCFCenterCropVideo(args.image_size),
12 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
13 | ])
14 |
15 | if args.dataset == "frame":
16 | return VideoFramesDataset(args, transform=transform, temporal_sample=temporal_sample)
17 | elif "frame" in args.dataset:
18 | return FramesLatentDataset(args, temporal_sample=temporal_sample)
19 | elif args.dataset == "video":
20 | return VideoDubbingDataset(args, transform=transform)
21 | return VideoLatentDataset(args, temporal_sample=temporal_sample)
22 |
--------------------------------------------------------------------------------
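For context, a minimal sketch of how this factory is typically consumed (mirroring `sample_pl.py`), assuming a config with the fields shown in `configs/train.yaml`; it is illustrative, not repository code.

```python
# Illustrative only: build a dataset/dataloader the way sample_pl.py does.
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from datasets import get_dataset

args = OmegaConf.load("configs/train.yaml")   # dataset, data_dir, num_frames, ...
dataset = get_dataset(args)                   # dispatches on args.dataset ("frame", "video", ...)
loader = DataLoader(dataset, batch_size=args.local_batch_size,
                    shuffle=True, num_workers=args.num_workers, pin_memory=True)
```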
/configs/sample.yaml:
--------------------------------------------------------------------------------
1 | # dataset
2 | dataset: frame
3 |
4 | data_dir: /path/to/dataset/videos # s
5 | audio_dir: /path/to/dataset/audio
6 | pretrained_model_path: /path/to/checkpoints/latte
7 |
8 | # path:
9 | pretrained:
10 | wav2vec: ./pretrained/wav2vec/wav2vec2-base-960h
11 | audio_separator: ./pretrained/audio_separator/Kim_Vocal_2.onnx
12 | save_video_path: ./sample_videos
13 |
14 | # model config:
15 | model: VDT-B/2
16 | num_frames: 16
17 | initial_frames: 2
18 | clip_frames: 16
19 | frame_interval: 1
20 | image_size: 256
21 | in_channels: 4
22 | temp_comp_rate: 1
23 | fixed_spatial: False
24 | attention_bias: True
25 | learn_sigma: True
26 | extras: 1 # [1, 2] 1 unconditional generation, 2 class-conditional generation
27 | num_classes:
28 | use_seque: False
29 | audio_margin: 2
30 | audio_dim: 768
31 | audio_token: 32
32 |
33 | precision: bf16 # GPU Axx: bf16-mixed
34 | gradient_checkpointing: False
35 |
36 | # sample config:
37 | seed:
38 | sample_method: ddpm
39 | num_sampling_steps: 250
40 | cfg_scale: 1.0
41 | sample_rate: 16000
42 | fps: 25
43 | num_workers: 8
44 | per_proc_batch_size: 1
45 |
--------------------------------------------------------------------------------
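As a reference for how these fields are consumed: `sample.py` loads the file with OmegaConf and derives the latent resolution from `image_size`. The snippet below is a sketch, not repository code.

```python
# Sketch of how sample.py consumes this config (values are the defaults above).
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/sample.yaml")
cfg.latent_size = cfg.image_size // 8                 # the SD VAE downsamples by 8x
print(cfg.model, cfg.clip_frames, cfg.latent_size)    # VDT-B/2 16 32
```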
/configs/train.yaml:
--------------------------------------------------------------------------------
1 | # dataset
2 | dataset: frame
3 |
4 | data_dir: /path/to/dataset/videos # s
5 | audio_dir: /path/to/dataset/audio
6 | pretrained_model_path: /path/to/checkpoints/latte
7 |
8 | # save and load
9 | results_dir: ./results
10 | pretrained:
11 |
12 | # model config:
13 | model: VDT-B/2
14 | num_frames: 16
15 | initial_frames: 0
16 | clip_frames: 16
17 | frame_interval: 1
18 | fixed_spatial: False
19 | image_size: 256
20 | in_channels: 4
21 | temp_comp_rate: 1
22 | num_sampling_steps: 250
23 | learn_sigma: True # important
24 | extras: 1 # [1, 2] 1 unconditional generation, 2 class-conditional generation
25 | use_seque: False
26 | audio_margin: 2
27 | audio_dim: 768
28 | audio_token: 32
29 |
30 | precision: bf16 # GPU Axx: bf16-mixed
31 |
32 | # train config:
33 | save_ceph: True # important
34 | learning_rate: 1e-4
35 | ckpt_every: 100000
36 | clip_max_norm: 0.1
37 | start_clip_iter: 20000
38 | local_batch_size: 1 # important
39 | max_train_steps: 300000
40 | global_seed: 3407
41 | num_workers: 8
42 | log_every: 50
43 | lr_warmup_steps: 0
44 | resume_from_checkpoint:
45 | gradient_accumulation_steps: 1 # TODO
46 | num_classes:
47 |
48 | # low VRAM and speed up training
49 | use_compile: False
50 | gradient_checkpointing: True
51 | enable_xformers_memory_efficient_attention: False
--------------------------------------------------------------------------------
/preparation/extract_audio.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import argparse
4 | import subprocess
5 | from tqdm import tqdm
6 |
7 | def main(video_root, audio_root):
8 | error_files = []
9 | video_files = []
10 | for root, _, files in os.walk(video_root):
11 | for file in files:
12 | if file.endswith('.mp4'):
13 | video_files.append(os.path.join(root, file))
14 |
15 | for video_path in tqdm(video_files):
16 | relative_path = os.path.relpath(video_path, video_root)
17 | wav_path = os.path.join(audio_root, os.path.splitext(relative_path)[0] + '.wav')
18 |
19 | if not os.path.exists(os.path.dirname(wav_path)):
20 | os.makedirs(os.path.dirname(wav_path))
21 |
22 | ffmpeg_command = [
23 | 'ffmpeg', '-y',
24 | '-i', video_path,
25 | '-vn', '-acodec',
26 | "pcm_s16le", '-ar', '16000', '-ac', '2',
27 | wav_path
28 | ]
29 |
30 | try:
31 | subprocess.run(ffmpeg_command, check=True)
32 | except subprocess.CalledProcessError as e:
33 | print(f"Error extracting audio from video {video_path}: {e}")
34 | error_files.append(video_path)
35 | continue
36 |
37 | for e_file in error_files:
38 | print(f"Error extracting audio from video {e_file}")
39 |
40 | if __name__ == '__main__':
41 | video_root = "path/to/your/input/folder"
42 | audio_root = "path/to/your/output/folder"
43 | os.makedirs(audio_root, exist_ok=True)
44 |
45 | main(video_root, audio_root)
46 |
--------------------------------------------------------------------------------
/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | from . import gaussian_diffusion as gd
7 | from .respace import SpacedDiffusion, space_timesteps
8 |
9 |
10 | def create_diffusion(
11 | timestep_respacing,
12 | noise_schedule="linear",
13 | use_kl=False,
14 | sigma_small=False,
15 | predict_xstart=False,
16 | learn_sigma=True,
17 | # learn_sigma=False,
18 | rescale_learned_sigmas=False,
19 | diffusion_steps=1000
20 | ):
21 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
22 | if use_kl:
23 | loss_type = gd.LossType.RESCALED_KL
24 | elif rescale_learned_sigmas:
25 | loss_type = gd.LossType.RESCALED_MSE
26 | else:
27 | loss_type = gd.LossType.MSE
28 | if timestep_respacing is None or timestep_respacing == "":
29 | timestep_respacing = [diffusion_steps]
30 | return SpacedDiffusion(
31 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
32 | betas=betas,
33 | model_mean_type=(
34 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
35 | ),
36 | model_var_type=(
37 | (
38 | gd.ModelVarType.FIXED_LARGE
39 | if not sigma_small
40 | else gd.ModelVarType.FIXED_SMALL
41 | )
42 | if not learn_sigma
43 | else gd.ModelVarType.LEARNED_RANGE
44 | ),
45 | loss_type=loss_type
46 | # rescale_timesteps=rescale_timesteps,
47 | )
48 |
--------------------------------------------------------------------------------
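A hedged usage sketch: the factory follows the IDDPM-style API this module wraps, so training typically draws random timesteps and calls `training_losses`, while sampling respaces to fewer steps. The `model`, `x`, and `model_kwargs` names below are placeholders.

```python
# Illustrative sketch of how create_diffusion is used (IDDPM-style API assumed).
from diffusion import create_diffusion

train_diffusion = create_diffusion(timestep_respacing="")   # full 1000-step training schedule
sample_diffusion = create_diffusion(str(250))               # respaced to 250 sampling steps
print(train_diffusion.num_timesteps, sample_diffusion.num_timesteps)  # 1000 250

# During training (placeholders; model comes from models.get_models, x is a latent batch):
# t = torch.randint(0, train_diffusion.num_timesteps, (x.shape[0],), device=x.device)
# loss = train_diffusion.training_losses(model, x, t, model_kwargs)["loss"].mean()
```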
/preparation/process_audio.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from pathlib import Path
4 | from tqdm import tqdm
5 | from audio_processor import AudioProcessor
6 |
7 |
8 | def get_audio_paths(source_dir: Path):
9 | """Get .wav files from the source directory."""
10 | return sorted([item for item in source_dir.iterdir() if item.is_file() and item.suffix == ".wav"])
11 |
12 | def process_audio(audio_path: Path, output_dir: Path, audio_processor: AudioProcessor):
13 | """Process a single audio file and save its embedding."""
14 | audio_emb, _ = audio_processor.preprocess(audio_path)
15 | torch.save(audio_emb, os.path.join(output_dir, f"{audio_path.stem}.pt"))
16 |
17 | def process_all_audios(input_audio_list, output_dir):
18 | """Process all audio files in the list."""
19 | wav2vec_model_path = "pretrained/wav2vec/wav2vec2-base-960h"
20 | audio_separator_model_file = "pretrained/audio_separator/Kim_Vocal_2.onnx"
21 | audio_processor = AudioProcessor(
22 | 16000,
23 | 25,
24 | wav2vec_model_path,
25 | os.path.dirname(audio_separator_model_file),
26 | os.path.basename(audio_separator_model_file),
27 | os.path.join(output_dir, "vocals"),
28 | # only_last_features=True
29 | )
30 | error_files = []
31 | for audio_path in tqdm(input_audio_list, desc="Processing audios"):
32 | try:
33 | process_audio(audio_path, output_dir, audio_processor)
34 |         except Exception:
35 | error_files.append(audio_path)
36 | print("Error file:")
37 | for error_file in error_files:
38 | print(error_file)
39 |
40 |
41 | if __name__ == "__main__":
42 | input_dir = Path("path/to/your/output/folder") # Set your input directory
43 | output_dir = os.path.join(input_dir.parent, "audio_emb")
44 | os.makedirs(output_dir, exist_ok=True)
45 |
46 | audio_paths = get_audio_paths(input_dir)
47 | process_all_audios(audio_paths, output_dir)
48 |
--------------------------------------------------------------------------------
/models/vae/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | from pathlib import Path
5 | from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D
6 |
7 | VAE_PATH = {"884-16c-hy": "./ckpts/hunyuan-video-t2v-720p/vae"}
8 |
9 | def load_vae(vae_type: str="884-16c-hy",
10 | vae_precision: str=None,
11 | sample_size: tuple=None,
12 | vae_path: str=None,
13 | ):
14 |     """the function to load the 3D VAE model
15 |
16 | Args:
17 | vae_type (str): the type of the 3D VAE model. Defaults to "884-16c-hy".
18 | vae_precision (str, optional): the precision to load vae. Defaults to None.
19 | sample_size (tuple, optional): the tiling size. Defaults to None.
20 | vae_path (str, optional): the path to vae. Defaults to None.
21 | logger (_type_, optional): logger. Defaults to None.
22 | device (_type_, optional): device to load vae. Defaults to None.
23 | """
24 | if vae_path is None:
25 | vae_path = VAE_PATH[vae_type]
26 |
27 | config = AutoencoderKLCausal3D.load_config(vae_path)
28 | if sample_size:
29 | vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size)
30 | else:
31 | vae = AutoencoderKLCausal3D.from_config(config)
32 |
33 | vae_ckpt = Path(vae_path) / "pytorch_model.pt"
34 | assert vae_ckpt.exists(), f"VAE checkpoint not found: {vae_ckpt}"
35 |
36 | ckpt = torch.load(vae_ckpt, map_location=vae.device)
37 | if "state_dict" in ckpt:
38 | ckpt = ckpt["state_dict"]
39 | if any(k.startswith("vae.") for k in ckpt.keys()):
40 | ckpt = {k.replace("vae.", ""): v for k, v in ckpt.items() if k.startswith("vae.")}
41 | vae.load_state_dict(ckpt)
42 |
43 | spatial_compression_ratio = vae.config.spatial_compression_ratio
44 | time_compression_ratio = vae.config.time_compression_ratio
45 |
46 | vae.requires_grad_(False)
47 | vae.eval()
48 |
49 | return vae, vae_path, spatial_compression_ratio, time_compression_ratio
50 |
51 |
52 | if __name__ == "__main__":
53 | vae_name = "884-16c-hy"
54 | vae_precision = "vae-precision"
55 | vae, _, s_ratio, t_ratio = load_vae(
56 | vae_name,
57 | )
58 | vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}
59 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .model_cat import VDTcat_models
2 | from .model_cat_long import VDTcatlong_models
3 | from .lite_cat_long import Litecatlong_models
4 |
5 |
6 | def customized_lr_scheduler(optimizer, warmup_steps=5000): # 5000 from u-vit
7 | from torch.optim.lr_scheduler import LambdaLR
8 | def fn(step):
9 | if warmup_steps > 0:
10 | return min(step / warmup_steps, 1)
11 | else:
12 | return 1
13 | return LambdaLR(optimizer, fn)
14 |
15 |
16 | def get_lr_scheduler(optimizer, name, **kwargs):
17 | if name == 'warmup':
18 | return customized_lr_scheduler(optimizer, **kwargs)
19 | elif name == 'cosine':
20 | from torch.optim.lr_scheduler import CosineAnnealingLR
21 | return CosineAnnealingLR(optimizer, **kwargs)
22 | else:
23 | raise NotImplementedError(name)
24 |
25 |
26 | def get_models(args):
27 | if 'Litecatlong' in args.model:
28 | return Litecatlong_models[args.model](
29 | input_size=args.latent_size,
30 | context_dim=args.audio_dim,
31 | in_channels=args.in_channels,
32 | num_classes=args.num_classes,
33 | num_frames=args.clip_frames,
34 | initial_frames=args.initial_frames,
35 | learn_sigma=args.learn_sigma,
36 | extras=args.extras,
37 | temp_comp_rate=args.temp_comp_rate,
38 | gradient_checkpointing=args.gradient_checkpointing
39 | )
40 | elif 'VDTcatlong' in args.model:
41 | return VDTcatlong_models[args.model](
42 | input_size=args.latent_size,
43 | context_dim=args.audio_dim,
44 | in_channels=args.in_channels,
45 | num_classes=args.num_classes,
46 | num_frames=args.clip_frames,
47 | initial_frames=args.initial_frames,
48 | learn_sigma=args.learn_sigma,
49 | extras=args.extras,
50 | temp_comp_rate=args.temp_comp_rate,
51 | gradient_checkpointing=args.gradient_checkpointing
52 | )
53 | elif 'VDTcat' in args.model:
54 | return VDTcat_models[args.model](
55 | input_size=args.latent_size,
56 | context_dim=args.audio_dim,
57 | in_channels=args.in_channels,
58 | num_classes=args.num_classes,
59 | num_frames=args.clip_frames,
60 | learn_sigma=args.learn_sigma,
61 | extras=args.extras,
62 | temp_comp_rate=args.temp_comp_rate,
63 | gradient_checkpointing=args.gradient_checkpointing
64 | )
65 | else:
66 |         raise NotImplementedError('{} Model Not Supported!'.format(args.model))
67 |
--------------------------------------------------------------------------------
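A small usage sketch for the warm-up scheduler defined above; the optimizer and parameter are placeholders, and the import assumes the model modules pulled in by `models/__init__.py` are available.

```python
# Illustrative: linear warm-up to the base LR over `warmup_steps`, then constant.
import torch
from models import get_lr_scheduler   # importing models also loads the VDT model files

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-4)
scheduler = get_lr_scheduler(optimizer, "warmup", warmup_steps=5000)

for step in range(3):
    optimizer.step()
    scheduler.step()   # LR follows lr * min(step / warmup_steps, 1)
```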
/preparation/extract_frames.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | from pathlib import Path
3 | from concurrent.futures import ThreadPoolExecutor, as_completed
4 |
5 |
6 | def convert_video_to_images(video_path: Path, output_dir: Path):
7 | """
8 | Convert a video file into a sequence of images at 25 fps and save them in the output directory.
9 |
10 | Args:
11 | video_path (Path): The path to the input video file.
12 | output_dir (Path): The directory where the extracted images will be saved.
13 |
14 | Returns:
15 | None
16 | """
17 | # Ensure the output directory exists
18 | output_dir.mkdir(parents=True, exist_ok=True)
19 |
20 | # ffmpeg command to convert video to images at 25 fps
21 | ffmpeg_command = [
22 | 'ffmpeg',
23 | '-i', str(video_path),
24 | '-vf', 'fps=25',
25 |         str(output_dir / '%04d.jpg')  # Save images as 0001.jpg, 0002.jpg, etc.
26 | ]
27 |
28 | try:
29 | print(f"Running command: {' '.join(ffmpeg_command)}")
30 | subprocess.run(ffmpeg_command, check=True)
31 | except subprocess.CalledProcessError as e:
32 | print(f"Error converting video {video_path} to images: {e}")
33 | raise
34 |
35 |
36 | def process_videos_in_folder(folder_path: Path, output_root: Path, max_workers=4):
37 | """
38 | Traverse through all mp4 files in a folder and convert each video to images in parallel using threads.
39 |
40 | Args:
41 | folder_path (Path): The directory containing mp4 files.
42 | output_root (Path): The root directory where extracted frames will be saved.
43 | max_workers (int): Maximum number of threads for parallel processing.
44 |
45 | Returns:
46 | None
47 | """
48 | # Gather all mp4 files
49 | video_files = list(folder_path.glob('*.mp4'))
50 |
51 | # Use ThreadPoolExecutor to process videos in parallel
52 | with ThreadPoolExecutor(max_workers=max_workers) as executor:
53 | futures = []
54 | for video_file in video_files:
55 | # Create a directory named after the video file (without extension)
56 | output_dir = output_root / video_file.stem
57 | print(f"Submitting video: {video_file.name} for processing")
58 |
59 | # Submit the task to the thread pool
60 | futures.append(executor.submit(convert_video_to_images, video_file, output_dir))
61 |
62 | # Wait for all futures to complete
63 | for future in as_completed(futures):
64 | try:
65 | future.result() # This will raise any exception that occurred during the task execution
66 | except Exception as e:
67 | print(f"Error during processing: {e}")
68 |
69 |
70 | if __name__ == "__main__":
71 | # Define the input folder containing .mp4 videos and the output root directory
72 | input_folder = Path("path/to/your/input/folder") # Replace with your input folder path
73 | output_folder = Path("path/to/your/output/folder") # Replace with your desired output folder path
74 |
75 | # Start processing with multi-threading (adjust the number of threads by changing max_workers)
76 | process_videos_in_folder(input_folder, output_folder, max_workers=4)
77 |
--------------------------------------------------------------------------------
/diffusion/diffusion_utils.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import torch as th
7 | import numpy as np
8 |
9 |
10 | def normal_kl(mean1, logvar1, mean2, logvar2):
11 | """
12 | Compute the KL divergence between two gaussians.
13 | Shapes are automatically broadcasted, so batches can be compared to
14 | scalars, among other use cases.
15 | """
16 | tensor = None
17 | for obj in (mean1, logvar1, mean2, logvar2):
18 | if isinstance(obj, th.Tensor):
19 | tensor = obj
20 | break
21 | assert tensor is not None, "at least one argument must be a Tensor"
22 |
23 | # Force variances to be Tensors. Broadcasting helps convert scalars to
24 | # Tensors, but it does not work for th.exp().
25 | logvar1, logvar2 = [
26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
27 | for x in (logvar1, logvar2)
28 | ]
29 |
30 | return 0.5 * (
31 | -1.0
32 | + logvar2
33 | - logvar1
34 | + th.exp(logvar1 - logvar2)
35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
36 | )
37 |
38 |
39 | def approx_standard_normal_cdf(x):
40 | """
41 | A fast approximation of the cumulative distribution function of the
42 | standard normal.
43 | """
44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
45 |
46 |
47 | def continuous_gaussian_log_likelihood(x, *, means, log_scales):
48 | """
49 | Compute the log-likelihood of a continuous Gaussian distribution.
50 | :param x: the targets
51 | :param means: the Gaussian mean Tensor.
52 | :param log_scales: the Gaussian log stddev Tensor.
53 | :return: a tensor like x of log probabilities (in nats).
54 | """
55 | centered_x = x - means
56 | inv_stdv = th.exp(-log_scales)
57 | normalized_x = centered_x * inv_stdv
58 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
59 | return log_probs
60 |
61 |
62 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
63 | """
64 | Compute the log-likelihood of a Gaussian distribution discretizing to a
65 | given image.
66 | :param x: the target images. It is assumed that this was uint8 values,
67 | rescaled to the range [-1, 1].
68 | :param means: the Gaussian mean Tensor.
69 | :param log_scales: the Gaussian log stddev Tensor.
70 | :return: a tensor like x of log probabilities (in nats).
71 | """
72 | assert x.shape == means.shape == log_scales.shape
73 | centered_x = x - means
74 | inv_stdv = th.exp(-log_scales)
75 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
76 | cdf_plus = approx_standard_normal_cdf(plus_in)
77 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
78 | cdf_min = approx_standard_normal_cdf(min_in)
79 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
80 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
81 | cdf_delta = cdf_plus - cdf_min
82 | log_probs = th.where(
83 | x < -0.999,
84 | log_cdf_plus,
85 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
86 | )
87 | assert log_probs.shape == x.shape
88 | return log_probs
89 |
--------------------------------------------------------------------------------
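A quick numerical sanity check of `normal_kl` (illustrative): the KL divergence between identical unit Gaussians is zero, and shifting the second mean by 1 with unit variances gives 0.5, matching the closed-form expression in the function.

```python
# Sanity check: KL(N(0,1) || N(0,1)) = 0 and KL(N(0,1) || N(1,1)) = 0.5.
import torch as th
from diffusion.diffusion_utils import normal_kl

mean1, logvar1 = th.zeros(4), th.zeros(4)
print(normal_kl(mean1, logvar1, 0.0, 0.0))   # tensor([0., 0., 0., 0.])
print(normal_kl(mean1, logvar1, 1.0, 0.0))   # tensor([0.5000, 0.5000, 0.5000, 0.5000])
```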
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | # Efficient Long-duration Talking Video Synthesis with Linear Diffusion Transformer under Multimodal Guidance
4 | 
5 | #### [Haojie Zhang](https://zhang-haojie.github.io/), [Zhihao Liang](https://lzhnb.github.io/), Ruibo Fu, Bingyan Liu, Zhengqi Wen,
6 | #### Xuefei Liu, Jianhua Tao, Yaling Liang
7 | 
8 | 
9 | 
10 | 
11 | 
12 | 
13 | 
14 | ## 🚀 Introduction
15 | **TL;DR:** We propose LetsTalk, a diffusion transformer for audio-driven portrait animation. By leveraging DC-VAE and linear attention, LetsTalk enables efficient multimodal fusion and consistent portrait generation, while memory bank and noise-regularized training further improve the quality and stability of long-duration videos.
16 |
17 |
18 | 
19 | 
20 |
21 | **Abstract:** Long-duration talking video synthesis faces enduring challenges in achieving high video quality, portrait and temporal consistency, and computational efficiency. As video length increases, issues such as visual degradation, identity inconsistency, temporal incoherence, and error accumulation become increasingly problematic, severely affecting the realism and reliability of the results.
22 | To address these challenges, we present LetsTalk, a diffusion transformer framework equipped with multimodal guidance and a novel memory bank mechanism, explicitly maintaining contextual continuity and enabling robust, high-quality, and efficient generation of long-duration talking videos. In particular, LetsTalk introduces a noise-regularized memory bank to alleviate error accumulation and sampling artifacts during extended video generation. To further improve efficiency and spatiotemporal consistency, LetsTalk employs a deep compression autoencoder and a spatiotemporal-aware transformer with linear attention for effective multimodal fusion. We systematically analyze three fusion schemes and show that combining deep fusion (Symbiotic Fusion) for portrait features with shallow fusion (Direct Fusion) for audio achieves superior visual realism and precise speech-driven motion, while preserving the diversity of movements. Extensive experiments demonstrate that LetsTalk establishes a new state of the art in generation quality, producing temporally coherent and realistic talking videos with enhanced diversity and liveliness, and maintains remarkable efficiency with 8x fewer parameters than previous approaches.
23 |
24 |
25 | ## 🎁 Overview
26 |
27 |
28 | 
29 | 
30 |
31 | Overview of our LetsTalk framework for robust long-duration talking head video generation. Our system combines a deep compression autoencoder to reduce spatial redundancy while preserving temporal features, and transformer blocks with intertwined temporal and spatial attention to effectively capture both intra-frame details and long-range dependencies.
32 | Portrait and audio embeddings are extracted; Symbiotic Fusion integrates the portrait embedding, and Direct Fusion incorporates the audio embedding, providing effective multimodal guidance for video synthesis. Portrait embeddings are repeated along the temporal axis for consistent conditioning across frames.
33 | To further support long-sequence generation, a memory bank module is introduced to maintain temporal consistency, while a dedicated noise-regularized training strategy helps align the memory bank usage between training and inference stages, ensuring stable and high-fidelity generation.
34 |
35 |
36 | 
37 | 
38 |
39 | Illustration of the three multimodal fusion schemes; our transformer backbone is formed by the left-side blocks.
40 |
41 | (a) **Direct Fusion**. Directly feeding condition into each block's cross-attention module;
42 |
43 | (b) **Siamese Fusion**. Maintaining a similar transformer and feeding the condition into it, extracting the corresponding features to guide the features in the backbone;
44 |
45 | (c) **Symbiotic Fusion**. Concatenating modality with the input at the beginning, then feeding it into the backbone, achieving fusion via the inherent self-attention mechanisms.
46 |
47 |
50 |
51 |
60 |
61 |
62 |
63 | ## 🎫 Citation
64 | If you find this project useful in your research, please consider citing:
65 |
66 | ```BibTeX
67 | @article{zhang2024efficient,
68 | title={Efficient Long-duration Talking Video Synthesis with Linear Diffusion Transformer under Multimodal Guidance},
69 | author={Zhang, Haojie and Liang, Zhihao and Fu, Ruibo and Liu, Bingyan and Wen, Zhengqi and Liu, Xuefei and Tao, Jianhua and Liang, Yaling},
70 | journal={arXiv preprint arXiv:2411.16748},
71 | year={2024}
72 | }
73 | ```
74 |
--------------------------------------------------------------------------------
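To make the fusion schemes described in the README above concrete, here is a heavily simplified, illustrative PyTorch sketch of the two that LetsTalk combines: Direct Fusion injects audio tokens through cross-attention, while Symbiotic Fusion concatenates the temporally repeated portrait tokens with the input before self-attention. The dimensions, names, and single-block structure are placeholders, not the repository's actual transformer implementation.

```python
# Conceptual sketch only -- not the repo's transformer block.
import torch
from torch import nn

dim, frames, patches = 256, 4, 16
video = torch.randn(2, frames * patches, dim)    # flattened video tokens
audio = torch.randn(2, 32, dim)                  # audio context tokens (e.g. from AudioProjModel)
portrait = torch.randn(2, patches, dim)          # reference portrait tokens

# (a) Direct Fusion: the condition enters each block via cross-attention.
cross_attn = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)
video = video + cross_attn(video, audio, audio)[0]

# (c) Symbiotic Fusion: repeat the portrait along the temporal axis, concatenate
# with the input, and let self-attention fuse the two modalities.
self_attn = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)
tokens = torch.cat([video, portrait.repeat(1, frames, 1)], dim=1)
tokens = tokens + self_attn(tokens, tokens, tokens)[0]
video = tokens[:, : frames * patches]            # keep only the video tokens
```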
/models/audio_proj.py:
--------------------------------------------------------------------------------
1 | """
2 | This module provides the implementation of an Audio Projection Model, which is designed for
3 | audio processing tasks. The model takes audio embeddings as input and outputs context tokens
4 | that can be used for various downstream applications, such as audio analysis or synthesis.
5 |
6 | The AudioProjModel class is based on the ModelMixin class from the diffusers library, which
7 | provides a foundation for building custom models. This implementation includes multiple linear
8 | layers with ReLU activation functions and a LayerNorm for normalization.
9 |
10 | Key Features:
11 | - Audio embedding input with flexible sequence length and block structure.
12 | - Multiple linear layers for feature transformation.
13 | - ReLU activation for non-linear transformation.
14 | - LayerNorm for stabilizing and speeding up training.
15 | - Rearrangement of input embeddings to match the model's expected input shape.
16 | - Customizable number of blocks, channels, and context tokens for adaptability.
17 |
18 | The module is structured to be easily integrated into larger systems or used as a standalone
19 | component for audio feature extraction and processing.
20 |
21 | Classes:
22 | - AudioProjModel: A class representing the audio projection model with configurable parameters.
23 |
24 | Functions:
25 | - (none)
26 |
27 | Dependencies:
28 | - torch: For tensor operations and neural network components.
29 | - diffusers: For the ModelMixin base class.
30 | - einops: For tensor rearrangement operations.
31 |
32 | """
33 |
34 | import torch
35 | from diffusers import ModelMixin
36 | from einops import rearrange
37 | from torch import nn
38 |
39 |
40 | class AudioProjModel(ModelMixin):
41 | """Audio Projection Model
42 |
43 | This class defines an audio projection model that takes audio embeddings as input
44 | and produces context tokens as output. The model is based on the ModelMixin class
45 | and consists of multiple linear layers and activation functions. It can be used
46 | for various audio processing tasks.
47 |
48 | Attributes:
49 | seq_len (int): The length of the audio sequence.
50 | blocks (int): The number of blocks in the audio projection model.
51 | channels (int): The number of channels in the audio projection model.
52 | intermediate_dim (int): The intermediate dimension of the model.
53 | context_tokens (int): The number of context tokens in the output.
54 | output_dim (int): The output dimension of the context tokens.
55 |
56 | Methods:
57 | __init__(self, seq_len=5, blocks=12, channels=768, intermediate_dim=512, context_tokens=32, output_dim=768):
58 | Initializes the AudioProjModel with the given parameters.
59 | forward(self, audio_embeds):
60 | Defines the forward pass for the AudioProjModel.
61 | Parameters:
62 |                 audio_embeds (torch.Tensor): The input audio embeddings with shape (batch_size, video_length, seq_len, blocks, channels).
63 | Returns:
64 | context_tokens (torch.Tensor): The output context tokens with shape (batch_size, video_length, context_tokens, output_dim).
65 |
66 | """
67 |
68 | def __init__(
69 | self,
70 | seq_len=5,
71 | blocks=12, # add a new parameter blocks
72 | channels=768, # add a new parameter channels
73 | intermediate_dim=512,
74 | output_dim=768,
75 | context_tokens=32,
76 | ):
77 | super().__init__()
78 |
79 | self.seq_len = seq_len
80 | self.blocks = blocks
81 | self.channels = channels
82 | self.input_dim = (
83 | seq_len * blocks * channels
84 | ) # update input_dim to be the product of blocks and channels.
85 | self.intermediate_dim = intermediate_dim
86 | self.context_tokens = context_tokens
87 | self.output_dim = output_dim
88 |
89 | # define multiple linear layers
90 | self.proj1 = nn.Linear(self.input_dim, intermediate_dim)
91 | self.proj2 = nn.Linear(intermediate_dim, intermediate_dim)
92 | self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim)
93 |
94 | self.norm = nn.LayerNorm(output_dim)
95 |
96 | def forward(self, audio_embeds):
97 | """
98 | Defines the forward pass for the AudioProjModel.
99 |
100 | Parameters:
101 |             audio_embeds (torch.Tensor): The input audio embeddings with shape (batch_size, video_length, seq_len, blocks, channels).
102 |
103 | Returns:
104 | context_tokens (torch.Tensor): The output context tokens with shape (batch_size, video_length, context_tokens, output_dim).
105 | """
106 | # merge
107 | video_length = audio_embeds.shape[1]
108 | audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
109 | batch_size, window_size, blocks, channels = audio_embeds.shape
110 | audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
111 |
112 | audio_embeds = torch.relu(self.proj1(audio_embeds))
113 | audio_embeds = torch.relu(self.proj2(audio_embeds))
114 |
115 | context_tokens = self.proj3(audio_embeds).reshape(
116 | batch_size, self.context_tokens, self.output_dim
117 | )
118 |
119 | context_tokens = self.norm(context_tokens)
120 | context_tokens = rearrange(
121 | context_tokens, "(bz f) m c -> bz f m c", f=video_length
122 | )
123 |
124 | return context_tokens
125 |
--------------------------------------------------------------------------------
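A shape check for the projection (illustrative; the constructor arguments mirror how `sample.py` instantiates the module).

```python
# Illustrative shape check for AudioProjModel.
import torch
from models.audio_proj import AudioProjModel

proj = AudioProjModel(seq_len=5, blocks=12, channels=768,
                      intermediate_dim=512, output_dim=768, context_tokens=32)
audio_emb = torch.randn(1, 16, 5, 12, 768)   # (batch, video_length, seq_len, blocks, channels)
with torch.no_grad():
    tokens = proj(audio_emb)
print(tokens.shape)                          # torch.Size([1, 16, 32, 768])
```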
/diffusion/respace.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 | import torch
6 | import numpy as np
7 | import torch as th
8 |
9 | from .gaussian_diffusion import GaussianDiffusion
10 |
11 |
12 | def space_timesteps(num_timesteps, section_counts):
13 | """
14 | Create a list of timesteps to use from an original diffusion process,
15 | given the number of timesteps we want to take from equally-sized portions
16 | of the original process.
17 | For example, if there's 300 timesteps and the section counts are [10,15,20]
18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
19 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
20 | If the stride is a string starting with "ddim", then the fixed striding
21 | from the DDIM paper is used, and only one section is allowed.
22 | :param num_timesteps: the number of diffusion steps in the original
23 | process to divide up.
24 | :param section_counts: either a list of numbers, or a string containing
25 | comma-separated numbers, indicating the step count
26 | per section. As a special case, use "ddimN" where N
27 | is a number of steps to use the striding from the
28 | DDIM paper.
29 | :return: a set of diffusion steps from the original process to use.
30 | """
31 | if isinstance(section_counts, str):
32 | if section_counts.startswith("ddim"):
33 | desired_count = int(section_counts[len("ddim") :])
34 | for i in range(1, num_timesteps):
35 | if len(range(0, num_timesteps, i)) == desired_count:
36 | return set(range(0, num_timesteps, i))
37 | raise ValueError(
38 |                 f"cannot create exactly {desired_count} steps with an integer stride"
39 | )
40 | section_counts = [int(x) for x in section_counts.split(",")]
41 | size_per = num_timesteps // len(section_counts)
42 | extra = num_timesteps % len(section_counts)
43 | start_idx = 0
44 | all_steps = []
45 | for i, section_count in enumerate(section_counts):
46 | size = size_per + (1 if i < extra else 0)
47 | if size < section_count:
48 | raise ValueError(
49 | f"cannot divide section of {size} steps into {section_count}"
50 | )
51 | if section_count <= 1:
52 | frac_stride = 1
53 | else:
54 | frac_stride = (size - 1) / (section_count - 1)
55 | cur_idx = 0.0
56 | taken_steps = []
57 | for _ in range(section_count):
58 | taken_steps.append(start_idx + round(cur_idx))
59 | cur_idx += frac_stride
60 | all_steps += taken_steps
61 | start_idx += size
62 | return set(all_steps)
63 |
64 |
65 | class SpacedDiffusion(GaussianDiffusion):
66 | """
67 | A diffusion process which can skip steps in a base diffusion process.
68 | :param use_timesteps: a collection (sequence or set) of timesteps from the
69 | original diffusion process to retain.
70 | :param kwargs: the kwargs to create the base diffusion process.
71 | """
72 |
73 | def __init__(self, use_timesteps, **kwargs):
74 | self.use_timesteps = set(use_timesteps)
75 | self.timestep_map = []
76 | self.original_num_steps = len(kwargs["betas"])
77 |
78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
79 | last_alpha_cumprod = 1.0
80 | new_betas = []
81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
82 | if i in self.use_timesteps:
83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
84 | last_alpha_cumprod = alpha_cumprod
85 | self.timestep_map.append(i)
86 | kwargs["betas"] = np.array(new_betas)
87 | super().__init__(**kwargs)
88 |
89 | def p_mean_variance(
90 | self, model, *args, **kwargs
91 | ): # pylint: disable=signature-differs
92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
93 |
94 | # @torch.compile
95 | def training_losses(
96 | self, model, *args, **kwargs
97 | ): # pylint: disable=signature-differs
98 | return super().training_losses(self._wrap_model(model), *args, **kwargs)
99 |
100 | def condition_mean(self, cond_fn, *args, **kwargs):
101 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
102 |
103 | def condition_score(self, cond_fn, *args, **kwargs):
104 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
105 |
106 | def _wrap_model(self, model):
107 | if isinstance(model, _WrappedModel):
108 | return model
109 | return _WrappedModel(
110 | model, self.timestep_map, self.original_num_steps
111 | )
112 |
113 | def _scale_timesteps(self, t):
114 | # Scaling is done by the wrapped model.
115 | return t
116 |
117 |
118 | class _WrappedModel:
119 | def __init__(self, model, timestep_map, original_num_steps):
120 | self.model = model
121 | self.timestep_map = timestep_map
122 | # self.rescale_timesteps = rescale_timesteps
123 | self.original_num_steps = original_num_steps
124 |
125 | def __call__(self, x, ts, **kwargs):
126 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
127 | new_ts = map_tensor[ts]
128 | # if self.rescale_timesteps:
129 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
130 | return self.model(x, new_ts, **kwargs)
131 |
--------------------------------------------------------------------------------
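A quick demonstration of the respacing helper using the docstring's own example (illustrative).

```python
# Reproduces the docstring example: 300 steps split into sections of 10/15/20.
from diffusion.respace import space_timesteps

steps = space_timesteps(300, [10, 15, 20])
print(len(steps))                                   # 45 retained timesteps out of 300
print(sorted(space_timesteps(1000, "ddim50"))[:4])  # [0, 20, 40, 60] -- fixed DDIM stride
```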
/sample_pl.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import imageio
4 | import logging
5 | import argparse
6 | from einops import rearrange, repeat
7 | from pytorch_lightning import LightningModule, Trainer
8 | from torch.utils.data import DataLoader
9 | from models import get_models
10 | from models.audio_proj import AudioProjModel
11 | from diffusion import create_diffusion
12 | from diffusers.models import AutoencoderKL
13 | from omegaconf import OmegaConf
14 | from utils import find_model, cleanup
15 | from datasets import get_dataset
16 |
17 |
18 | class LatteSamplingModule(LightningModule):
19 | def __init__(self, args, logger: logging.Logger):
20 | super(LatteSamplingModule, self).__init__()
21 | self.args = args
22 | self.logging = logger
23 | self.model = get_models(args).to(self.device)
24 | self.audioproj = AudioProjModel(
25 | seq_len=5,
26 | blocks=12,
27 | channels=args.audio_dim,
28 | intermediate_dim=512,
29 | output_dim=args.audio_dim,
30 | context_tokens=args.audio_token,
31 | )
32 |
33 | state_dict, audioproj_dict = find_model(args.pretrained)
34 | self.model.load_state_dict(state_dict)
35 | self.logging.info(f"Loaded model checkpoint from {args.pretrained}")
36 | self.audioproj.load_state_dict(audioproj_dict)
37 |
38 | self.diffusion = create_diffusion(str(args.num_sampling_steps))
39 | self.vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").to(self.device)
40 |
41 | self.model.eval()
42 | self.vae.eval()
43 | self.audioproj.eval()
44 |
45 | def sample(self, z, model_kwargs, sample_method='ddpm'):
46 | if sample_method == 'ddim':
47 | return self.diffusion.ddim_sample_loop(
48 | self.model.forward_with_cfg if self.args.cfg_scale > 1.0 else self.model.forward,
49 | z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=self.device
50 | )
51 | elif sample_method == 'ddpm':
52 | return self.diffusion.p_sample_loop(
53 | self.model.forward_with_cfg if self.args.cfg_scale > 1.0 else self.model.forward,
54 | z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=self.device
55 | )
56 | else:
57 | raise ValueError(f"Unknown sample method: {sample_method}")
58 |
59 | @torch.no_grad()
60 | def validation_step(self, batch, batch_idx):
61 | video_names = batch["video_name"]
62 | if "latent" in self.args.dataset:
63 | ref_latents = batch["ref_latent"]
64 | else:
65 | ref_image = batch["ref_image"]
66 | ref_latents = self.vae.encode(ref_image).latent_dist.sample().mul_(0.18215)
67 |
68 | ref_latents = repeat(ref_latents, "b c h w -> b f c h w", f=self.args.clip_frames)
69 | model_kwargs = dict(y=None, cond=ref_latents)
70 |
71 | local_batch_size = ref_latents.size(0)
72 | z = torch.randn(local_batch_size, self.args.clip_frames, self.args.in_channels, self.args.latent_size, self.args.latent_size, device=self.device)
73 | if "audio" in batch:
74 | audio_emb = batch["audio"]
75 | audio_emb = self.audioproj(audio_emb)
76 | model_kwargs.update(audio_embed=audio_emb)
77 |
78 | if self.args.cfg_scale > 1.0:
79 | z = torch.cat([z, z], 0)
80 | model_kwargs.update(cfg_scale=self.args.cfg_scale)
81 |
82 | samples = self.sample(z, model_kwargs, sample_method=self.args.sample_method)
83 |
84 | samples = rearrange(samples, 'b f c h w -> (b f) c h w')
85 | samples = self.vae.decode(samples / 0.18215).sample
86 | samples = rearrange(samples, '(b f) c h w -> b f c h w', b=local_batch_size)
87 |
88 | for sample, video_name in zip(samples, video_names):
89 | video_ = ((sample * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1).contiguous()
90 | video_save_path = os.path.join(self.args.save_video_path, f"{video_name}.mp4")
91 | imageio.mimwrite(video_save_path, video_, fps=self.args.fps, quality=9)
92 | self.logging.info(f"Saved video at {video_save_path}")
93 |
94 | return video_save_path
95 |
96 |
97 | def main(args):
98 | # Setup logger
99 | logging.basicConfig(level=logging.INFO, format="[%(asctime)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
100 | logger = logging.getLogger(__name__)
101 |
102 | # Determine if the current process is the main process (rank 0)
103 | is_main_process = (int(os.environ.get("LOCAL_RANK", 0)) == 0)
104 | if is_main_process:
105 | os.makedirs(args.save_video_path, exist_ok=True)
106 | print(f"Saving .mp4 samples at {args.save_video_path}")
107 |
108 | # Create dataset and dataloader
109 | dataset = get_dataset(args)
110 | val_loader = DataLoader(
111 | dataset,
112 | batch_size=args.per_proc_batch_size, # Batch size 1 for sampling
113 | shuffle=False,
114 | num_workers=args.num_workers,
115 | pin_memory=True,
116 | )
117 | logger.info(f"Validation set contains {len(dataset)} samples")
118 |
119 | sample_size = args.image_size // 8
120 | args.latent_size = sample_size
121 |
122 | # Initialize the sampling module
123 | pl_module = LatteSamplingModule(args, logger)
124 |
125 | # Trainer
126 | trainer = Trainer(
127 | accelerator="gpu",
128 | devices=[0], # Specify GPU ids
129 | strategy="auto",
130 | logger=False,
131 | precision=args.precision if args.precision else "32-true",
132 | )
133 |
134 | # Run validation to generate samples
135 | trainer.validate(pl_module, dataloaders=val_loader)
136 |
137 | logger.info("Sampling completed!")
138 |
139 | if __name__ == "__main__":
140 | parser = argparse.ArgumentParser()
141 | parser.add_argument("--config", type=str, default="./configs/sample.yaml")
142 | args = parser.parse_args()
143 | omega_conf = OmegaConf.load(args.config)
144 | main(omega_conf)
--------------------------------------------------------------------------------
/diffusion/timestep_sampler.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | from abc import ABC, abstractmethod
7 |
8 | import numpy as np
9 | import torch as th
10 | import torch.distributed as dist
11 |
12 |
13 | def create_named_schedule_sampler(name, diffusion):
14 | """
15 | Create a ScheduleSampler from a library of pre-defined samplers.
16 | :param name: the name of the sampler.
17 | :param diffusion: the diffusion object to sample for.
18 | """
19 | if name == "uniform":
20 | return UniformSampler(diffusion)
21 | elif name == "loss-second-moment":
22 | return LossSecondMomentResampler(diffusion)
23 | else:
24 | raise NotImplementedError(f"unknown schedule sampler: {name}")
25 |
26 |
27 | class ScheduleSampler(ABC):
28 | """
29 | A distribution over timesteps in the diffusion process, intended to reduce
30 | variance of the objective.
31 | By default, samplers perform unbiased importance sampling, in which the
32 | objective's mean is unchanged.
33 | However, subclasses may override sample() to change how the resampled
34 | terms are reweighted, allowing for actual changes in the objective.
35 | """
36 |
37 | @abstractmethod
38 | def weights(self):
39 | """
40 | Get a numpy array of weights, one per diffusion step.
41 | The weights needn't be normalized, but must be positive.
42 | """
43 |
44 | def sample(self, batch_size, device):
45 | """
46 | Importance-sample timesteps for a batch.
47 | :param batch_size: the number of timesteps.
48 | :param device: the torch device to save to.
49 | :return: a tuple (timesteps, weights):
50 | - timesteps: a tensor of timestep indices.
51 | - weights: a tensor of weights to scale the resulting losses.
52 | """
53 | w = self.weights()
54 | p = w / np.sum(w)
55 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
56 | indices = th.from_numpy(indices_np).long().to(device)
57 | weights_np = 1 / (len(p) * p[indices_np])
58 | weights = th.from_numpy(weights_np).float().to(device)
59 | return indices, weights
60 |
61 |
62 | class UniformSampler(ScheduleSampler):
63 | def __init__(self, diffusion):
64 | self.diffusion = diffusion
65 | self._weights = np.ones([diffusion.num_timesteps])
66 |
67 | def weights(self):
68 | return self._weights
69 |
70 |
71 | class LossAwareSampler(ScheduleSampler):
72 | def update_with_local_losses(self, local_ts, local_losses):
73 | """
74 | Update the reweighting using losses from a model.
75 | Call this method from each rank with a batch of timesteps and the
76 | corresponding losses for each of those timesteps.
77 | This method will perform synchronization to make sure all of the ranks
78 | maintain the exact same reweighting.
79 | :param local_ts: an integer Tensor of timesteps.
80 | :param local_losses: a 1D Tensor of losses.
81 | """
82 | batch_sizes = [
83 | th.tensor([0], dtype=th.int32, device=local_ts.device)
84 | for _ in range(dist.get_world_size())
85 | ]
86 | dist.all_gather(
87 | batch_sizes,
88 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
89 | )
90 |
91 | # Pad all_gather batches to be the maximum batch size.
92 | batch_sizes = [x.item() for x in batch_sizes]
93 | max_bs = max(batch_sizes)
94 |
95 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
96 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
97 | dist.all_gather(timestep_batches, local_ts)
98 | dist.all_gather(loss_batches, local_losses)
99 | timesteps = [
100 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
101 | ]
102 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
103 | self.update_with_all_losses(timesteps, losses)
104 |
105 | @abstractmethod
106 | def update_with_all_losses(self, ts, losses):
107 | """
108 | Update the reweighting using losses from a model.
109 | Sub-classes should override this method to update the reweighting
110 | using losses from the model.
111 | This method directly updates the reweighting without synchronizing
112 | between workers. It is called by update_with_local_losses from all
113 | ranks with identical arguments. Thus, it should have deterministic
114 | behavior to maintain state across workers.
115 | :param ts: a list of int timesteps.
116 | :param losses: a list of float losses, one per timestep.
117 | """
118 |
119 |
120 | class LossSecondMomentResampler(LossAwareSampler):
121 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
122 | self.diffusion = diffusion
123 | self.history_per_term = history_per_term
124 | self.uniform_prob = uniform_prob
125 | self._loss_history = np.zeros(
126 | [diffusion.num_timesteps, history_per_term], dtype=np.float64
127 | )
128 |         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
129 |
130 | def weights(self):
131 | if not self._warmed_up():
132 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
133 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
134 | weights /= np.sum(weights)
135 | weights *= 1 - self.uniform_prob
136 | weights += self.uniform_prob / len(weights)
137 | return weights
138 |
139 | def update_with_all_losses(self, ts, losses):
140 | for t, loss in zip(ts, losses):
141 | if self._loss_counts[t] == self.history_per_term:
142 | # Shift out the oldest loss term.
143 | self._loss_history[t, :-1] = self._loss_history[t, 1:]
144 | self._loss_history[t, -1] = loss
145 | else:
146 | self._loss_history[t, self._loss_counts[t]] = loss
147 | self._loss_counts[t] += 1
148 |
149 | def _warmed_up(self):
150 | return (self._loss_counts == self.history_per_term).all()
151 |
--------------------------------------------------------------------------------
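A minimal usage sketch for the uniform sampler (the loss-aware variant needs `torch.distributed` to be initialized, so it is omitted here); illustrative, not repository code.

```python
# Illustrative: draw importance-sampled timesteps for a training batch.
import torch
from diffusion import create_diffusion
from diffusion.timestep_sampler import create_named_schedule_sampler

diffusion = create_diffusion(timestep_respacing="")           # 1000 training steps
sampler = create_named_schedule_sampler("uniform", diffusion)
t, weights = sampler.sample(batch_size=4, device=torch.device("cpu"))
print(t, weights)   # 4 random timesteps in [0, 999], all weights equal to 1
```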
/sample.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import torch
4 | import argparse
5 | import numpy as np
6 | from tqdm import tqdm
7 | from PIL import Image
8 | from torchvision import transforms
9 | from pathlib import Path
10 | from omegaconf import OmegaConf
11 | from diffusion import create_diffusion
12 | from diffusers.models import AutoencoderKL
13 | from einops import rearrange, repeat
14 | from models import get_models
15 | from models.audio_proj import AudioProjModel
16 | from preparation.audio_processor import AudioProcessor
17 | from utils import find_model, combine_video_audio, tensor_to_video
18 |
19 |
20 | def process_audio_emb(audio_emb):
21 | """
22 | Process the audio embedding to concatenate with other tensors.
23 |
24 | Parameters:
25 | audio_emb (torch.Tensor): The audio embedding tensor to process.
26 |
27 | Returns:
28 |         audio_emb (torch.Tensor): The stacked tensor in which each frame's embedding is concatenated with its two neighbors on either side.
29 | """
30 | concatenated_tensors = []
31 |
32 | for i in range(audio_emb.shape[0]):
33 | vectors_to_concat = [
34 |             audio_emb[max(min(i + j, audio_emb.shape[0] - 1), 0)] for j in range(-2, 3)]
35 | concatenated_tensors.append(torch.stack(vectors_to_concat, dim=0))
36 | audio_emb = torch.stack(concatenated_tensors, dim=0)
37 | return audio_emb
38 |
39 |
40 | @torch.no_grad()
41 | def main(args):
42 | torch.backends.cuda.matmul.allow_tf32 = True
43 | assert torch.cuda.is_available(), "Sampling requires at least one GPU."
44 | torch.set_grad_enabled(False)
45 |
46 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
47 |
48 | # Load model
49 | latent_size = args.image_size // 8
50 | args.latent_size = latent_size
51 | model = get_models(args).to(device)
52 | state_dict, audioproj_dict = find_model(args.pretrained)
53 | model.load_state_dict(state_dict)
54 | model.eval()
55 |
56 | audioproj = AudioProjModel(
57 | seq_len=5,
58 | blocks=12,
59 | channels=args.audio_dim,
60 | intermediate_dim=512,
61 | output_dim=args.audio_dim,
62 | context_tokens=args.audio_token,
63 | ).to(device)
64 | audioproj.load_state_dict(audioproj_dict)
65 | audioproj.eval()
66 |
67 | sample_rate = args.sample_rate
68 | assert sample_rate == 16000, "audio sample rate must be 16000"
69 | fps = args.fps
70 | wav2vec_model_path = args.wav2vec
71 | audio_separator_model_file = args.audio_separator
72 |
73 | diffusion = create_diffusion(str(args.num_sampling_steps))
74 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="sd-vae-ft-ema").to(device)
75 | transform = transforms.Compose(
76 | [
77 | transforms.ToTensor(),
78 | transforms.Resize(args.image_size),
79 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
80 | ]
81 | )
82 |
83 | vae.requires_grad_(False)
84 | audioproj.requires_grad_(False)
85 | model.requires_grad_(False)
86 |
87 | # Prepare output directory
88 | os.makedirs(args.save_video_path, exist_ok=True)
89 |
90 | # Iterate through reference image folder
91 | ref_image_folder = args.data_dir
92 | audio_folder = args.audio_dir
93 | ref_image_paths = glob.glob(os.path.join(ref_image_folder, "*.jpg"))
94 |
95 | # Add progress bar for reference image processing
96 | for ref_image_path in ref_image_paths:
97 | audio_path = os.path.join(audio_folder, f"{Path(ref_image_path).stem}.wav")
98 |
99 | if not os.path.exists(audio_path):
100 | print(f"Warning: Audio file not found for {audio_path}")
101 | continue
102 |
103 | ref_name = Path(ref_image_path).stem
104 | video_save_path = os.path.join(args.save_video_path, f"{ref_name}.mp4")
105 | if os.path.exists(video_save_path):
106 | continue
107 |
108 | clip_length = args.clip_frames
109 | with AudioProcessor(
110 | sample_rate,
111 | fps,
112 | wav2vec_model_path,
113 | os.path.dirname(audio_separator_model_file),
114 | os.path.basename(audio_separator_model_file),
115 | os.path.join(args.save_video_path, "audio_preprocess")
116 | ) as audio_processor:
117 | audio_emb, audio_length = audio_processor.preprocess(audio_path, clip_length)
118 |
119 | audio_emb = process_audio_emb(audio_emb)
120 | # Load reference image
121 | ref_image_pil = Image.open(ref_image_path).convert("RGB")
122 | ref_image_np = np.array(ref_image_pil)
123 | ref_image_tensor = transform(ref_image_np).unsqueeze(0).to(device)
124 | ref_latents = vae.encode(ref_image_tensor).latent_dist.sample().mul_(0.18215)
125 | ref_latents = repeat(ref_latents, "b c h w -> b f c h w", f=clip_length)
126 |
127 | times = audio_emb.shape[0] // clip_length
128 |
129 | concat_samples = []
130 | for t in tqdm(range(times), desc="Processing"):
131 | audio_tensor = audio_emb[
132 | t * clip_length: min((t + 1) * clip_length, audio_emb.shape[0])
133 | ]
134 |
135 | audio_tensor = audio_tensor.unsqueeze(0)
136 | audio_tensor = audio_tensor.to(
137 | device=audioproj.device, dtype=audioproj.dtype)
138 | audio_tensor = audioproj(audio_tensor)
139 |
140 | z = torch.randn(1, args.clip_frames, args.in_channels, latent_size, latent_size, device=device)
141 | model_kwargs = dict(y=None, cond=ref_latents, audio_embed=audio_tensor)
142 | if args.cfg_scale > 1.0:
143 | z = torch.cat([z, z], 0)
144 | model_kwargs.update(cfg_scale=args.cfg_scale)
145 |
146 | # Sample images:
147 | if args.sample_method == 'ddim':
148 | samples = diffusion.ddim_sample_loop(
149 | model.forward, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=False, device=device
150 | )
151 | elif args.sample_method == 'ddpm':
152 | samples = diffusion.p_sample_loop(
153 | model.forward, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=False, device=device
154 | )
155 |
156 | # Decode and save samples
157 | samples = rearrange(samples, 'b f c h w -> (b f) c h w')
158 | samples = vae.decode(samples / 0.18215).sample.cpu()
159 | concat_samples.append(samples)
160 |
161 | tensor_result = torch.cat(concat_samples)
162 | tensor_result = tensor_result[:audio_length]
163 | tensor_result = ((tensor_result * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1).contiguous()
164 | tensor_to_video(tensor_result, video_save_path, audio_path)
165 | print("Sampling completed!")
166 |
167 |
168 | if __name__ == "__main__":
169 | parser = argparse.ArgumentParser()
170 | parser.add_argument("--config", type=str, default="./configs/sample.yaml")
171 | args = parser.parse_args()
172 | omega_conf = OmegaConf.load(args.config)
173 | main(omega_conf)
174 |
--------------------------------------------------------------------------------
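The tail of the sampling script above packs the frame post-processing into one chained expression: VAE-decoded samples, roughly in [-1, 1], are rescaled to [0, 255], rounded, clamped, cast to uint8 and permuted to (frame, H, W, C) before tensor_to_video writes them out. The following is a minimal, self-contained sketch of that conversion as a standalone helper; the name to_uint8_frames is illustrative and not part of the repository.

import torch


def to_uint8_frames(samples: torch.Tensor) -> torch.Tensor:
    """Convert decoded samples in [-1, 1] with shape (N, C, H, W) into
    uint8 frames with shape (N, H, W, C), mirroring the post-processing
    used in the sampling scripts."""
    frames = (samples * 0.5 + 0.5) * 255          # [-1, 1] -> [0, 255]
    frames = frames.add_(0.5).clamp_(0, 255)      # round half up, then clip
    return frames.to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1).contiguous()


if __name__ == "__main__":
    decoded = torch.rand(4, 3, 64, 64) * 2 - 1    # dummy decoded clip in [-1, 1]
    frames = to_uint8_frames(decoded)
    print(frames.shape, frames.dtype)             # torch.Size([4, 64, 64, 3]) torch.uint8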
/sample_long.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import torch
4 | import argparse
5 | import numpy as np
6 | from tqdm import tqdm
7 | from PIL import Image
8 | from torchvision import transforms
9 | from pathlib import Path
10 | from omegaconf import OmegaConf
11 | from diffusion import create_diffusion
12 | from diffusers.models import AutoencoderKL
13 | from einops import rearrange, repeat
14 | from models import get_models
15 | from models.audio_proj import AudioProjModel
16 | from preparation.audio_processor import AudioProcessor
17 | from utils import find_model, combine_video_audio, tensor_to_video
18 |
19 |
20 | def process_audio_emb(audio_emb):
21 | """
22 | Process the audio embedding to concatenate with other tensors.
23 |
24 | Parameters:
25 | audio_emb (torch.Tensor): The audio embedding tensor to process.
26 |
27 | Returns:
28 | concatenated_tensors (List[torch.Tensor]): The concatenated tensor list.
29 | """
30 | concatenated_tensors = []
31 |
32 | for i in range(audio_emb.shape[0]):
33 | vectors_to_concat = [
34 | audio_emb[max(min(i + j, audio_emb.shape[0]-1), 0)]for j in range(-2, 3)]
35 | concatenated_tensors.append(torch.stack(vectors_to_concat, dim=0))
36 | audio_emb = torch.stack(concatenated_tensors, dim=0)
37 | return audio_emb
38 |
39 |
40 | @torch.no_grad()
41 | def main(args):
42 | torch.backends.cuda.matmul.allow_tf32 = True
43 | assert torch.cuda.is_available(), "Sampling requires at least one GPU."
44 | torch.set_grad_enabled(False)
45 |
46 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
47 |
48 | # Load model
49 | latent_size = args.image_size // 8
50 | args.latent_size = latent_size
51 | model = get_models(args).to(device)
52 | state_dict, audioproj_dict = find_model(args.pretrained)
53 | model.load_state_dict(state_dict)
54 | model.eval()
55 |
56 | audioproj = AudioProjModel(
57 | seq_len=5,
58 | blocks=12,
59 | channels=args.audio_dim,
60 | intermediate_dim=512,
61 | output_dim=args.audio_dim,
62 | context_tokens=args.audio_token,
63 | ).to(device)
64 | audioproj.load_state_dict(audioproj_dict)
65 | audioproj.eval()
66 |
67 | sample_rate = args.sample_rate
68 | assert sample_rate == 16000, "audio sample rate must be 16000"
69 | fps = args.fps
70 | wav2vec_model_path = args.wav2vec
71 | audio_separator_model_file = args.audio_separator
72 |
73 | diffusion = create_diffusion(str(args.num_sampling_steps))
74 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="sd-vae-ft-ema").to(device)
75 | transform = transforms.Compose(
76 | [
77 | transforms.ToTensor(),
78 | transforms.Resize(args.image_size),
79 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
80 | ]
81 | )
82 |
83 | vae.requires_grad_(False)
84 | audioproj.requires_grad_(False)
85 | model.requires_grad_(False)
86 |
87 | # Prepare output directory
88 | os.makedirs(args.save_video_path, exist_ok=True)
89 |
90 | # Iterate through reference image folder
91 | ref_image_folder = args.data_dir
92 | audio_folder = args.audio_dir
93 | ref_image_paths = glob.glob(os.path.join(ref_image_folder, "*.jpg"))
94 | print(f"===== Process folder: {args.data_dir} =====")
95 |
96 | tensor_result = []
97 |     # Iterate over reference images (per-clip progress is shown by tqdm below)
98 | for ref_image_path in ref_image_paths:
99 | audio_path = os.path.join(audio_folder, f"{Path(ref_image_path).stem}.wav")
100 |
101 | if not os.path.exists(audio_path):
102 | print(f"Warning: Audio file not found for {audio_path}")
103 | continue
104 |
105 | ref_name = Path(ref_image_path).stem
106 | video_save_path = os.path.join(args.save_video_path, f"{ref_name}.mp4")
107 | if os.path.exists(video_save_path):
108 | print(f"Skip video {video_save_path}")
109 | continue
110 |
111 | clip_length = args.clip_frames
112 | with AudioProcessor(
113 | sample_rate,
114 | fps,
115 | wav2vec_model_path,
116 | os.path.dirname(audio_separator_model_file),
117 | os.path.basename(audio_separator_model_file),
118 | os.path.join(args.save_video_path, "audio_preprocess")
119 | ) as audio_processor:
120 | audio_emb, audio_length = audio_processor.preprocess(audio_path, clip_length)
121 |
122 | audio_emb = process_audio_emb(audio_emb)
123 | # Load reference image
124 | ref_image_pil = Image.open(ref_image_path).convert("RGB")
125 | ref_image_np = np.array(ref_image_pil)
126 | ref_image_tensor = transform(ref_image_np).unsqueeze(0).to(device)
127 | ref_latents = vae.encode(ref_image_tensor).latent_dist.sample().mul_(0.18215)
128 | ref_latents = repeat(ref_latents, "b c h w -> b f c h w", f=clip_length)
129 |
130 | times = audio_emb.shape[0] // clip_length
131 |
132 | concat_samples = []
133 | for t in tqdm(range(times), desc="Processing"):
134 | audio_tensor = audio_emb[
135 | t * clip_length: min((t + 1) * clip_length, audio_emb.shape[0])
136 | ]
137 |
138 | audio_tensor = audio_tensor.unsqueeze(0)
139 | audio_tensor = audio_tensor.to(
140 | device=audioproj.device, dtype=audioproj.dtype)
141 | audio_tensor = audioproj(audio_tensor)
142 |
143 | if len(concat_samples) == 0:
144 | # The first iteration
145 | initial_latents = ref_latents[:, 0:args.initial_frames, ...]
146 |
147 | z = torch.randn(1, args.clip_frames, 4, latent_size, latent_size, device=device)
148 | model_kwargs = dict(y=None, cond=ref_latents, motion=initial_latents, audio_embed=audio_tensor)
149 | if args.cfg_scale > 1.0:
150 | z = torch.cat([z, z], 0)
151 | model_kwargs.update(cfg_scale=args.cfg_scale)
152 |
153 | # Sample images:
154 | if args.sample_method == 'ddim':
155 | samples = diffusion.ddim_sample_loop(
156 | model.forward, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=False, device=device
157 | )
158 | elif args.sample_method == 'ddpm':
159 | samples = diffusion.p_sample_loop(
160 | model.forward, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=False, device=device
161 | )
162 |
163 | initial_latents = samples[:, 0 - args.initial_frames:]
164 |
165 | # Decode and save samples
166 | samples = rearrange(samples, 'b f c h w -> (b f) c h w')
167 | samples = vae.decode(samples / 0.18215).sample.cpu()
168 | concat_samples.append(samples)
169 | torch.cuda.empty_cache()
170 |
171 | tensor_result = torch.cat(concat_samples)
172 | tensor_result = tensor_result[:audio_length]
173 | tensor_result = ((tensor_result * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1).contiguous()
174 | tensor_to_video(tensor_result, video_save_path, audio_path)
175 | print("Sampling completed!")
176 |
177 |
178 | if __name__ == "__main__":
179 | parser = argparse.ArgumentParser()
180 | parser.add_argument("--config", type=str, default="./configs/sample.yaml")
181 | parser.add_argument("--data_dir", type=str, default=None)
182 | args = parser.parse_args()
183 | omega_conf = OmegaConf.load(args.config)
184 | if args.data_dir is not None:
185 | omega_conf.data_dir = args.data_dir
186 | main(omega_conf)
187 |
--------------------------------------------------------------------------------
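process_audio_emb in the script above gives every frame a five-step audio context: for frame i it stacks the embeddings at indices i-2 through i+2, clamping at the sequence boundaries, so a tensor of shape (T, ...) becomes (T, 5, ...). A small self-contained sketch of that windowing follows; the tensor dimensions in the demo are made up for illustration.

import torch


def stack_audio_context(audio_emb: torch.Tensor, context: int = 2) -> torch.Tensor:
    """For each time step, stack the embeddings from t-context .. t+context,
    clamping indices at the boundaries (same idea as process_audio_emb)."""
    T = audio_emb.shape[0]
    windows = []
    for i in range(T):
        idx = [min(max(i + j, 0), T - 1) for j in range(-context, context + 1)]
        windows.append(torch.stack([audio_emb[k] for k in idx], dim=0))
    return torch.stack(windows, dim=0)


if __name__ == "__main__":
    emb = torch.randn(10, 12, 768)   # (frames, blocks, channels) -- illustrative shape
    ctx = stack_audio_context(emb)
    print(ctx.shape)                 # torch.Size([10, 5, 12, 768])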
/models/utils.py:
--------------------------------------------------------------------------------
1 | # adapted from
2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3 | # and
4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5 | # and
6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7 | #
8 | # thanks!
9 |
10 |
11 | import os
12 | import math
13 | import torch
14 |
15 | import numpy as np
16 | import torch.nn as nn
17 |
18 | from einops import repeat
19 |
20 |
21 | #################################################################################
22 | # Unet Utils #
23 | #################################################################################
24 |
25 | def checkpoint(func, inputs, params, flag):
26 | """
27 | Evaluate a function without caching intermediate activations, allowing for
28 | reduced memory at the expense of extra compute in the backward pass.
29 | :param func: the function to evaluate.
30 | :param inputs: the argument sequence to pass to `func`.
31 | :param params: a sequence of parameters `func` depends on but does not
32 | explicitly take as arguments.
33 | :param flag: if False, disable gradient checkpointing.
34 | """
35 | if flag:
36 | args = tuple(inputs) + tuple(params)
37 | return CheckpointFunction.apply(func, len(inputs), *args)
38 | else:
39 | return func(*inputs)
40 |
41 |
42 | class CheckpointFunction(torch.autograd.Function):
43 | @staticmethod
44 | def forward(ctx, run_function, length, *args):
45 | ctx.run_function = run_function
46 | ctx.input_tensors = list(args[:length])
47 | ctx.input_params = list(args[length:])
48 |
49 | with torch.no_grad():
50 | output_tensors = ctx.run_function(*ctx.input_tensors)
51 | return output_tensors
52 |
53 | @staticmethod
54 | def backward(ctx, *output_grads):
55 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
56 | with torch.enable_grad():
57 | # Fixes a bug where the first op in run_function modifies the
58 | # Tensor storage in place, which is not allowed for detach()'d
59 | # Tensors.
60 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
61 | output_tensors = ctx.run_function(*shallow_copies)
62 | input_grads = torch.autograd.grad(
63 | output_tensors,
64 | ctx.input_tensors + ctx.input_params,
65 | output_grads,
66 | allow_unused=True,
67 | )
68 | del ctx.input_tensors
69 | del ctx.input_params
70 | del output_tensors
71 | return (None, None) + input_grads
72 |
73 |
74 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
75 | """
76 | Create sinusoidal timestep embeddings.
77 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
78 | These may be fractional.
79 | :param dim: the dimension of the output.
80 | :param max_period: controls the minimum frequency of the embeddings.
81 | :return: an [N x dim] Tensor of positional embeddings.
82 | """
83 | if not repeat_only:
84 | half = dim // 2
85 | freqs = torch.exp(
86 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
87 | ).to(device=timesteps.device)
88 | args = timesteps[:, None].float() * freqs[None]
89 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
90 | if dim % 2:
91 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
92 | else:
93 | embedding = repeat(timesteps, 'b -> b d', d=dim).contiguous()
94 | return embedding
95 |
96 |
97 | def zero_module(module):
98 | """
99 | Zero out the parameters of a module and return it.
100 | """
101 | for p in module.parameters():
102 | p.detach().zero_()
103 | return module
104 |
105 |
106 | def scale_module(module, scale):
107 | """
108 | Scale the parameters of a module and return it.
109 | """
110 | for p in module.parameters():
111 | p.detach().mul_(scale)
112 | return module
113 |
114 |
115 | def mean_flat(tensor):
116 | """
117 | Take the mean over all non-batch dimensions.
118 | """
119 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
120 |
121 |
122 | def normalization(channels):
123 | """
124 | Make a standard normalization layer.
125 | :param channels: number of input channels.
126 | :return: an nn.Module for normalization.
127 | """
128 | return GroupNorm32(32, channels)
129 |
130 |
131 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
132 | class SiLU(nn.Module):
133 | def forward(self, x):
134 | return x * torch.sigmoid(x)
135 |
136 |
137 | class GroupNorm32(nn.GroupNorm):
138 | def forward(self, x):
139 | return super().forward(x.float()).type(x.dtype)
140 |
141 | def conv_nd(dims, *args, **kwargs):
142 | """
143 | Create a 1D, 2D, or 3D convolution module.
144 | """
145 | if dims == 1:
146 | return nn.Conv1d(*args, **kwargs)
147 | elif dims == 2:
148 | return nn.Conv2d(*args, **kwargs)
149 | elif dims == 3:
150 | return nn.Conv3d(*args, **kwargs)
151 | raise ValueError(f"unsupported dimensions: {dims}")
152 |
153 |
154 | def linear(*args, **kwargs):
155 | """
156 | Create a linear module.
157 | """
158 | return nn.Linear(*args, **kwargs)
159 |
160 |
161 | def avg_pool_nd(dims, *args, **kwargs):
162 | """
163 | Create a 1D, 2D, or 3D average pooling module.
164 | """
165 | if dims == 1:
166 | return nn.AvgPool1d(*args, **kwargs)
167 | elif dims == 2:
168 | return nn.AvgPool2d(*args, **kwargs)
169 | elif dims == 3:
170 | return nn.AvgPool3d(*args, **kwargs)
171 | raise ValueError(f"unsupported dimensions: {dims}")
172 |
173 |
174 | # class HybridConditioner(nn.Module):
175 |
176 | # def __init__(self, c_concat_config, c_crossattn_config):
177 | # super().__init__()
178 | # self.concat_conditioner = instantiate_from_config(c_concat_config)
179 | # self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
180 |
181 | # def forward(self, c_concat, c_crossattn):
182 | # c_concat = self.concat_conditioner(c_concat)
183 | # c_crossattn = self.crossattn_conditioner(c_crossattn)
184 | # return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
185 |
186 |
187 | def noise_like(shape, device, repeat=False):
188 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
189 | noise = lambda: torch.randn(shape, device=device)
190 | return repeat_noise() if repeat else noise()
191 |
192 | def count_flops_attn(model, _x, y):
193 | """
194 | A counter for the `thop` package to count the operations in an
195 | attention operation.
196 | Meant to be used like:
197 | macs, params = thop.profile(
198 | model,
199 | inputs=(inputs, timestamps),
200 | custom_ops={QKVAttention: QKVAttention.count_flops},
201 | )
202 | """
203 | b, c, *spatial = y[0].shape
204 | num_spatial = int(np.prod(spatial))
205 | # We perform two matmuls with the same number of ops.
206 | # The first computes the weight matrix, the second computes
207 | # the combination of the value vectors.
208 | matmul_ops = 2 * b * (num_spatial ** 2) * c
209 | model.total_ops += torch.DoubleTensor([matmul_ops])
210 |
211 | def count_params(model, verbose=False):
212 | total_params = sum(p.numel() for p in model.parameters())
213 | if verbose:
214 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
215 | return total_params
--------------------------------------------------------------------------------
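timestep_embedding above follows the standard sinusoidal scheme: frequencies spaced geometrically from 1 down to 1/max_period, with the cosine half concatenated before the sine half. Below is a self-contained sketch that mirrors the even-dimension path (the odd-dimension zero padding and the repeat_only branch are omitted); the timestep values are arbitrary.

import math
import torch


def sinusoidal_embedding(timesteps: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
    """Mirror of timestep_embedding for even dim: geometric frequencies, cos then sin halves."""
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = timesteps[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)


t = torch.tensor([0, 250, 999])        # three diffusion timesteps
emb = sinusoidal_embedding(t, dim=128)
print(emb.shape)                       # torch.Size([3, 128])
print(emb[0, :4])                      # tensor([1., 1., 1., 1.]) -- at t=0 the cosine half is all ones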
/sample_long_pl.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import imageio
4 | import logging
5 | import argparse
6 | from tqdm import tqdm
7 | from einops import rearrange, repeat
8 | from diffusion import create_diffusion
9 | from diffusers.models import AutoencoderKL, AutoencoderDC
10 | from omegaconf import OmegaConf
11 | from pytorch_lightning import LightningModule, Trainer
12 | from torch.utils.data import DataLoader
13 |
14 | from models import get_models
15 | from models.audio_proj import AudioProjModel
16 | from utils import find_model
17 | from datasets import get_dataset
18 |
19 |
20 | class LatteSamplingModule(LightningModule):
21 | def __init__(self, args, logger: logging.Logger):
22 | super(LatteSamplingModule, self).__init__()
23 | self.args = args
24 | self.logging = logger
25 | self.model = get_models(args).to(self.device)
26 | self.audioproj = AudioProjModel(
27 | seq_len=5,
28 | blocks=12,
29 | channels=args.audio_dim,
30 | intermediate_dim=512,
31 | output_dim=args.audio_dim,
32 | context_tokens=args.audio_token,
33 | )
34 |
35 | state_dict, audioproj_dict = find_model(args.pretrained)
36 | self.model.load_state_dict(state_dict)
37 | self.logging.info(f"Loaded model checkpoint from {args.pretrained}")
38 | self.audioproj.load_state_dict(audioproj_dict)
39 |
40 | self.diffusion = create_diffusion(str(args.num_sampling_steps))
41 | if args.in_channels == 32:
42 | self.vae = AutoencoderDC.from_pretrained(args.pretrained_model_path, subfolder="vae")
43 | else:
44 | self.vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
45 |
46 | self.model.eval()
47 | self.vae.eval()
48 | self.audioproj.eval()
49 | self.global_counter = 0
50 |
51 | def sample(self, z, model_kwargs, sample_method='ddpm'):
52 | if sample_method == 'ddim':
53 | return self.diffusion.ddim_sample_loop(
54 | self.model.forward_with_cfg if self.args.cfg_scale > 1.0 else self.model.forward,
55 | z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=self.device
56 | )
57 | elif sample_method == 'ddpm':
58 | return self.diffusion.p_sample_loop(
59 | self.model.forward_with_cfg if self.args.cfg_scale > 1.0 else self.model.forward,
60 | z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=self.device
61 | )
62 | else:
63 | raise ValueError(f"Unknown sample method: {sample_method}")
64 |
65 | @torch.no_grad()
66 | def validation_step(self, batch, batch_idx):
67 | video_names = batch["video_name"]
68 | for video_name in video_names:
69 | self.logging.info(f"Processing video {video_name}.")
70 |
71 | audio_emb = batch["audio"]
72 | local_batch_size = audio_emb.size(0)
73 |
74 | if "latent" in self.args.dataset:
75 | ref_latents = batch["ref_latent"] * self.vae.config.scaling_factor
76 | motion_latents = batch["motions"] * self.vae.config.scaling_factor if self.args.initial_frames != 0 else None
77 | else:
78 | ref_image = batch["ref_image"]
79 | if self.args.in_channels == 32:
80 | ref_latents = self.vae.encode(ref_image)[0] * self.vae.config.scaling_factor
81 | else:
82 |                 ref_latents = self.vae.encode(ref_image).latent_dist.sample() * self.vae.config.scaling_factor
83 |
84 | motions = batch["motions"]
85 | motions = rearrange(motions, "b f c h w -> (b f) c h w").contiguous()
86 | if self.args.in_channels == 32:
87 | motion_latents = self.vae.encode(motions)[0] * self.vae.config.scaling_factor
88 | else:
89 |                 motion_latents = self.vae.encode(motions).latent_dist.sample() * self.vae.config.scaling_factor
90 | motion_latents = rearrange(motion_latents, "(b f) c h w -> b f c h w", b=local_batch_size).contiguous()
91 |
92 | clip_length = self.args.clip_frames
93 | times = audio_emb.size(1) // clip_length
94 | concat_samples = []
95 | ref_latents = repeat(ref_latents, "b c h w -> b f c h w", f=clip_length)
96 | for t in tqdm(range(times), desc="Processing"):
97 | audio_tensor = audio_emb[:, t * clip_length: (t + 1) * clip_length]
98 | audio_tensor = self.audioproj(audio_tensor)
99 |
100 | if t == 0:
101 | # The first iteration
102 | initial_latents = motion_latents
103 |
104 | z = torch.randn(local_batch_size, clip_length, self.args.in_channels, self.args.latent_size, self.args.latent_size, device=self.device)
105 | model_kwargs = dict(y=None, cond=ref_latents, motion=initial_latents, audio_embed=audio_tensor)
106 | if self.args.cfg_scale > 1.0:
107 | z = torch.cat([z, z], 0)
108 | model_kwargs.update(cfg_scale=self.args.cfg_scale)
109 |
110 | samples = self.sample(z, model_kwargs, sample_method=self.args.sample_method)
111 | initial_latents = samples[:, 0 - self.args.initial_frames:]
112 |
113 | samples = rearrange(samples, 'b f c h w -> (b f) c h w')
114 | if self.args.in_channels == 32:
115 | samples = self.vae.decode(samples / self.vae.config.scaling_factor, return_dict=False)[0]
116 | else:
117 | samples = self.vae.decode(samples / 0.18215).sample
118 | samples = rearrange(samples, '(b f) c h w -> b f c h w', b=local_batch_size)
119 | concat_samples.append(samples)
120 | torch.cuda.empty_cache()
121 |
122 | concat_samples = torch.cat(concat_samples, dim=1)
123 | for sample, video_name in zip(concat_samples, video_names):
124 | video_ = ((sample * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1).contiguous()
125 | video_save_path = os.path.join(self.args.save_video_path, f"{video_name}.mp4")
126 | imageio.mimwrite(video_save_path, video_, fps=self.args.fps, quality=9)
127 | self.logging.info(f"Saved video at {video_save_path}")
128 |
129 | return video_save_path
130 |
131 |
132 | def main(args):
133 | # Setup logger
134 | logging.basicConfig(level=logging.INFO, format="[%(asctime)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
135 | logger = logging.getLogger(__name__)
136 |
137 | # Determine if the current process is the main process (rank 0)
138 | is_main_process = (int(os.environ.get("LOCAL_RANK", 0)) == 0)
139 | if is_main_process:
140 | os.makedirs(args.save_video_path, exist_ok=True)
141 | print(f"Saving .mp4 samples at {args.save_video_path}")
142 |
143 | # Create dataset and dataloader
144 | dataset = get_dataset(args)
145 | val_loader = DataLoader(
146 | dataset,
147 |         batch_size=args.per_proc_batch_size,  # per-process batch size for sampling
148 | shuffle=False,
149 | num_workers=args.num_workers,
150 | pin_memory=True,
151 | )
152 | logger.info(f"Validation set contains {len(dataset)} samples")
153 |
154 | if args.in_channels == 32:
155 | sample_size = args.image_size // 32
156 | else:
157 | sample_size = args.image_size // 8
158 | args.latent_size = sample_size
159 |
160 | # Initialize the sampling module
161 | pl_module = LatteSamplingModule(args, logger)
162 |
163 | # Trainer
164 | trainer = Trainer(
165 | accelerator="gpu",
166 | devices=[0], # Specify GPU ids
167 | strategy="auto",
168 | logger=False,
169 | precision=args.precision if args.precision else "32-true",
170 | )
171 |
172 | # Run validation to generate samples
173 | trainer.validate(pl_module, dataloaders=val_loader)
174 |
175 | logger.info("Sampling completed!")
176 |
177 | if __name__ == "__main__":
178 | parser = argparse.ArgumentParser()
179 | parser.add_argument("--config", type=str, default="./configs/sample.yaml")
180 | args = parser.parse_args()
181 | omega_conf = OmegaConf.load(args.config)
182 | main(omega_conf)
--------------------------------------------------------------------------------
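Both long-form samplers generate the video window by window and carry the last initial_frames latents of each window forward as the motion conditioning for the next one, which is what keeps consecutive clips temporally coherent. The toy sketch below shows only that control flow; fake_sample is a stand-in for the diffusion sampling loop and is purely illustrative.

import torch


def fake_sample(z: torch.Tensor, motion: torch.Tensor) -> torch.Tensor:
    """Stand-in for diffusion.p_sample_loop: returns a 'denoised' clip of the
    same shape as z. The real model also conditions on the motion latents."""
    return z * 0.0 + motion.mean()


clip_frames, initial_frames = 16, 4
latent_shape = (4, 32, 32)                              # (C, H, W) -- illustrative
motion = torch.randn(1, initial_frames, *latent_shape)  # seed motion latents

clips = []
for window in range(3):                                 # three consecutive windows
    z = torch.randn(1, clip_frames, *latent_shape)
    samples = fake_sample(z, motion)
    motion = samples[:, -initial_frames:]               # tail of this window seeds the next
    clips.append(samples)

video_latents = torch.cat(clips, dim=1)
print(video_latents.shape)                              # torch.Size([1, 48, 4, 32, 32])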
/datasets/frames_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import time
4 | import torch
5 | import random
6 | import numpy as np
7 | import torchvision.transforms as transforms
8 | from torch.utils.data import Dataset
9 |
10 |
11 | class BaseDataset(Dataset):
12 | def __init__(
13 | self,
14 | configs,
15 | temporal_sample=None,
16 | transform=None,
17 | ):
18 | self.data_dir = configs.data_dir
19 | self.datalists = [d for d in os.listdir(self.data_dir)]
20 | self.frame_interval = configs.frame_interval
21 | self.num_frames = configs.num_frames
22 | self.temporal_sample = temporal_sample
23 | self.transform = transform
24 |
25 | self.audio_dir = configs.audio_dir
26 | self.audio_margin = configs.audio_margin
27 | self.initial_frames = configs.initial_frames
28 | self.slide_window = (configs.in_channels == 4)
29 |
30 | def __len__(self):
31 | return len(self.datalists)
32 |
33 | def get_indices(self, total_frames):
34 | indices = np.arange(total_frames)
35 | # Randomly select frames
36 | start_index, end_index = self.temporal_sample(total_frames - self.audio_margin)
37 | selected_indices = torch.linspace(start_index, end_index - 1, self.num_frames + self.initial_frames, dtype=int)
38 | video_indices = selected_indices[self.initial_frames:]
39 | # start_index = random.randint(
40 | # max(self.initial_frames, self.audio_margin),
41 | # total_frames - self.num_frames - self.audio_margin - 1,
42 | # )
43 | # selected_indices = indices[start_index - self.initial_frames: start_index + self.num_frames]
44 | # video_indices = torch.from_numpy(selected_indices[self.initial_frames:])
45 |
46 | # Choose a reference frame from the remaining frames
47 | remaining_indices = np.setdiff1d(indices, selected_indices)
48 | if len(remaining_indices) == 0:
49 | remaining_indices = indices
50 | ref_index = np.random.choice(remaining_indices)
51 |
52 | # Add the reference frame index to the selected_indices
53 | selected_indices_ = np.append(selected_indices, ref_index)
54 | return selected_indices_, video_indices
55 |
56 | def load_audio_emb(self, audio_emb_path, video_indices):
57 | # Extract wav hidden features
58 | audio_emb = torch.load(audio_emb_path)
59 | if self.slide_window:
60 | audio_indices = (
61 | torch.arange(2 * self.audio_margin + 1) - self.audio_margin
62 | ) # Generates [-2, -1, 0, 1, 2]
63 | center_indices = video_indices.unsqueeze(1) + audio_indices.unsqueeze(0)
64 | try:
65 | audio_tensor = audio_emb[center_indices]
66 |             except IndexError:
67 |                 print(f"Audio indices out of range for {audio_emb_path} (length {len(audio_emb)})")
68 |                 raise
69 | else:
70 | audio_tensor = audio_emb[video_indices]
71 | return audio_tensor
72 |
73 |
74 | class VideoFramesDataset(BaseDataset):
75 | def load_images(self, folder_path, selected_indices):
76 | images = []
77 | for idx in selected_indices:
78 | img_path = os.path.join(folder_path, f"{idx+1:04d}.jpg")
79 |             img = cv2.imread(img_path)
80 |             if img is None:
81 |                 raise FileNotFoundError(f"Image {img_path} not found.")
82 |             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
83 | images.append(img)
84 | return np.array(images)
85 |
86 | def __getitem__(self, index):
87 | sample = self.datalists[index]
88 | video_folder_path = os.path.join(self.data_dir, sample)
89 | audio_emb_path = "{}/{}.pt".format(self.audio_dir, sample)
90 | audio_emb = torch.load(audio_emb_path, weights_only=True)
91 |
92 | # Extract total frames count based on images available
93 | video_frames = len([name for name in os.listdir(video_folder_path) if name.endswith(".jpg")])
94 | total_frames = min(audio_emb.size(0), video_frames)
95 |
96 | selected_indices, video_indices = self.get_indices(total_frames)
97 |
98 | # Load the selected frames including the reference frame
99 | all_frames = self.load_images(video_folder_path, selected_indices)
100 | all_frames = torch.from_numpy(all_frames).permute(0, 3, 1, 2).contiguous()
101 |
102 | # Apply transformation if it exists
103 | if self.transform:
104 | all_frames = self.transform(all_frames)
105 |
106 | # Separate the reference frame, motions, and image_window
107 | ref_image = all_frames[-1] # The last frame is the reference frame
108 | video_frames = all_frames[:-1] # All frames except the last one
109 |
110 | motions = video_frames[:self.initial_frames]
111 | image_window = video_frames[self.initial_frames:]
112 |
113 | audio_indices = (
114 | torch.arange(2 * self.audio_margin + 1) - self.audio_margin
115 | ) # Generates [-2, -1, 0, 1, 2]
116 | center_indices = video_indices.unsqueeze(1) + audio_indices.unsqueeze(0)
117 | audio_tensor = audio_emb[center_indices].squeeze(1)
118 | return {"video": image_window, "audio": audio_tensor, "ref_image": ref_image, "motions": motions, "video_name": sample}
119 |
120 |
121 | class FramesLatentDataset(BaseDataset):
122 | def load_latents(self, folder_path, selected_indices):
123 | latents = []
124 | for idx in selected_indices:
125 | latent_path = os.path.join(folder_path, f"{idx+1:04d}.npz") # Change .pt to .npz
126 | latent_data = np.load(latent_path)['latent'] # Load latent data from .npz file
127 | latent = torch.tensor(latent_data) # Convert numpy array to PyTorch tensor
128 | latents.append(latent)
129 | return torch.stack(latents) # Return stacked latents as a tensor
130 |
131 | def __getitem__(self, index):
132 | sample = self.datalists[index]
133 | video_folder_path = os.path.join(self.data_dir, sample)
134 | audio_emb_path = "{}/{}.pt".format(self.audio_dir, sample)
135 | audio_emb = torch.load(audio_emb_path, weights_only=True)
136 |
137 | # Extract total frames count based on available .npz files
138 | video_frames = len([name for name in os.listdir(video_folder_path) if name.endswith(".npz")])
139 | total_frames = min(audio_emb.size(0), video_frames)
140 |
141 | selected_indices, video_indices = self.get_indices(total_frames)
142 |
143 | # Load the selected latents including the reference frame
144 | all_latents = self.load_latents(video_folder_path, selected_indices)
145 |
146 | # Separate the reference frame, motions, and image_window
147 | ref_latent = all_latents[-1] # The last frame is the reference frame
148 | video_latents = all_latents[:-1] # All frames except the last one
149 |
150 | motion_latents = video_latents[:self.initial_frames]
151 | latent_window = video_latents[self.initial_frames:]
152 |
153 | audio_indices = (
154 | torch.arange(2 * self.audio_margin + 1) - self.audio_margin
155 | ) # Generates [-2, -1, 0, 1, 2]
156 | center_indices = video_indices.unsqueeze(1) + audio_indices.unsqueeze(0)
157 | audio_tensor = audio_emb[center_indices].squeeze(1)
158 | return {"video": latent_window, "audio": audio_tensor, "ref_latent": ref_latent, "motions": motion_latents, "video_name": sample}
159 |
160 |
161 | class VideoLatentDataset(BaseDataset):
162 | def load_latents(self, latent_path):
163 | latents = torch.load(latent_path) # Load the entire latent tensor
164 | return latents
165 |
166 | def __getitem__(self, index):
167 | sample = self.datalists[index]
168 | video_latent_path = os.path.join(self.data_dir, sample)
169 | audio_emb_path = "{}/{}.pt".format(self.audio_dir, sample)
170 | audio_emb = torch.load(audio_emb_path, weights_only=True)
171 |
172 | # Load latents and calculate total frames
173 | latents = self.load_latents(video_latent_path) # Latent shape: [num_frames, channels, height, width]
174 | video_frames = latents.shape[0]
175 | total_frames = min(audio_emb.size(0), video_frames)
176 |
177 | selected_indices, video_indices = self.get_indices(total_frames)
178 |
179 | # Select latent frames
180 | all_latents = latents[selected_indices]
181 |
182 | # Separate the reference frame, motions, and image_window
183 | ref_latent = all_latents[-1] # The last frame is the reference frame
184 | video_latents = all_latents[:-1] # All frames except the last one
185 |
186 | motion_latents = video_latents[:self.initial_frames]
187 | latent_window = video_latents[self.initial_frames:]
188 |
189 | audio_indices = (
190 | torch.arange(2 * self.audio_margin + 1) - self.audio_margin
191 | ) # Generates [-2, -1, 0, 1, 2]
192 | center_indices = video_indices.unsqueeze(1) + audio_indices.unsqueeze(0)
193 | audio_tensor = audio_emb[center_indices].squeeze(1)
194 | return {"video": latent_window, "audio": audio_tensor, "ref_latent": ref_latent, "motions": motion_latents, "video_name": sample}
195 |
196 |
197 | class VideoDubbingDataset(BaseDataset):
198 | def __init__(
199 | self,
200 | configs,
201 | transform=None,
202 | ):
203 | self.data_dir = configs.data_dir
204 | self.datalists = [d for d in os.listdir(self.data_dir)]
205 | self.frame_interval = configs.frame_interval
206 | self.num_frames = configs.num_frames
207 | self.transform = transform
208 | self.audio_dir = configs.audio_dir
209 | self.audio_margin = configs.audio_margin
210 |
211 | def load_images(self, folder_path, selected_indices):
212 | images = []
213 | for idx in selected_indices:
214 | img_path = os.path.join(folder_path, f"{idx+1:04d}.jpg")
215 |             img = cv2.imread(img_path)
216 |             if img is None:
217 |                 raise FileNotFoundError(f"Image {img_path} not found.")
218 |             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
219 | images.append(img)
220 | return np.array(images)
221 |
222 | def get_two_clip_indices(self, total_frames):
223 | # Calculate the total length needed for two clips and the gap between them
224 |         total_needed = 2 * self.num_frames + self.num_frames  # two clips plus a gap of num_frames between them
225 |
226 | # We need additional margin for audio indices at both ends
227 | total_needed_with_margin = total_needed + 2 * self.audio_margin
228 |
229 | # Check if video is long enough
230 | if total_frames < total_needed_with_margin:
231 | raise ValueError(f"Video has only {total_frames} frames, but need at least {total_needed_with_margin} frames (including audio margin)")
232 |
233 | # Randomly select a starting point that allows both clips to fit with audio margins
234 | max_start = total_frames - total_needed_with_margin
235 | start_frame = random.randint(self.audio_margin, max_start + self.audio_margin)
236 |
237 | # First clip indices (accounting for audio margin)
238 | first_clip_indices = range(start_frame, start_frame + self.num_frames)
239 |
240 | # Second clip starts after first clip + gap
241 | second_clip_start = start_frame + self.num_frames + self.num_frames
242 | second_clip_indices = range(second_clip_start, second_clip_start + self.num_frames)
243 |
244 | # Combine all indices needed for loading
245 | all_indices = list(first_clip_indices) + list(second_clip_indices)
246 |
247 | # Return the combined indices and the indices for each clip
248 | return all_indices, first_clip_indices, second_clip_indices
249 |
250 | def __getitem__(self, index):
251 | sample = self.datalists[index]
252 | video_folder_path = os.path.join(self.data_dir, sample)
253 |
254 | audio_emb_path = "{}/{}.pt".format(self.audio_dir, sample)
255 | audio_emb = torch.load(audio_emb_path, weights_only=True)
256 |
257 | # Extract total frames count based on images available
258 | video_frames = len([name for name in os.listdir(video_folder_path) if name.endswith(".jpg")])
259 | total_frames = min(audio_emb.size(0), video_frames)
260 |
261 | # Get indices for two clips
262 | all_indices, first_clip_indices, second_clip_indices = self.get_two_clip_indices(total_frames)
263 |
264 | # Load all needed frames
265 | all_frames = self.load_images(video_folder_path, all_indices)
266 | all_frames = torch.from_numpy(all_frames).permute(0, 3, 1, 2).contiguous()
267 |
268 | # Split into two clips
269 | first_clip_frames = all_frames[:self.num_frames]
270 | second_clip_frames = all_frames[self.num_frames:]
271 |
272 | # Apply transformation if it exists
273 | if self.transform:
274 | first_clip_frames = self.transform(first_clip_frames)
275 | second_clip_frames = self.transform(second_clip_frames)
276 |
277 | # Process audio for both clips
278 | def get_audio_tensor(video_indices):
279 | audio_indices = (
280 | torch.arange(2 * self.audio_margin + 1) - self.audio_margin
281 | ) # Generates [-2, -1, 0, 1, 2]
282 | center_indices = torch.tensor(video_indices).unsqueeze(1) + audio_indices.unsqueeze(0)
283 | return audio_emb[center_indices].squeeze(1)
284 |
285 | first_audio_tensor = get_audio_tensor(first_clip_indices)
286 | second_audio_tensor = get_audio_tensor(second_clip_indices)
287 |
288 | return {
289 | "first_video": first_clip_frames,
290 | "first_audio": first_audio_tensor,
291 | "second_video": second_clip_frames,
292 | "second_audio": second_audio_tensor,
293 | "video_name": sample
294 | }
--------------------------------------------------------------------------------
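Every dataset above pairs each selected video frame with a short window of audio embeddings: torch.arange(2 * audio_margin + 1) - audio_margin gives the relative offsets, and adding them to the frame indices yields a (num_frames, 2 * audio_margin + 1) index grid used to gather from the audio tensor. A small sketch of that indexing with made-up shapes:

import torch

audio_margin = 2
audio_emb = torch.randn(100, 768)               # (audio steps, feature dim) -- illustrative
video_indices = torch.tensor([10, 11, 12, 13])  # frame indices selected for one clip

offsets = torch.arange(2 * audio_margin + 1) - audio_margin         # tensor([-2, -1, 0, 1, 2])
center_indices = video_indices.unsqueeze(1) + offsets.unsqueeze(0)  # (4, 5) index grid
audio_tensor = audio_emb[center_indices]

print(center_indices[0])     # tensor([ 8,  9, 10, 11, 12])
print(audio_tensor.shape)    # torch.Size([4, 5, 768])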
/models/vae/vae.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional, Tuple
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 |
8 | from diffusers.utils import BaseOutput, is_torch_version
9 | from diffusers.utils.torch_utils import randn_tensor
10 | from diffusers.models.attention_processor import SpatialNorm
11 | from .unet_causal_3d_blocks import (
12 | CausalConv3d,
13 | UNetMidBlockCausal3D,
14 | get_down_block3d,
15 | get_up_block3d,
16 | )
17 |
18 |
19 | @dataclass
20 | class DecoderOutput(BaseOutput):
21 | r"""
22 | Output of decoding method.
23 |
24 | Args:
25 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
26 | The decoded output sample from the last layer of the model.
27 | """
28 |
29 | sample: torch.FloatTensor
30 |
31 |
32 | class EncoderCausal3D(nn.Module):
33 | r"""
34 | The `EncoderCausal3D` layer of a variational autoencoder that encodes its input into a latent representation.
35 | """
36 |
37 | def __init__(
38 | self,
39 | in_channels: int = 3,
40 | out_channels: int = 3,
41 | down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",),
42 | block_out_channels: Tuple[int, ...] = (64,),
43 | layers_per_block: int = 2,
44 | norm_num_groups: int = 32,
45 | act_fn: str = "silu",
46 | double_z: bool = True,
47 | mid_block_add_attention=True,
48 | time_compression_ratio: int = 4,
49 | spatial_compression_ratio: int = 8,
50 | ):
51 | super().__init__()
52 | self.layers_per_block = layers_per_block
53 |
54 | self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
55 | self.mid_block = None
56 | self.down_blocks = nn.ModuleList([])
57 |
58 | # down
59 | output_channel = block_out_channels[0]
60 | for i, down_block_type in enumerate(down_block_types):
61 | input_channel = output_channel
62 | output_channel = block_out_channels[i]
63 | is_final_block = i == len(block_out_channels) - 1
64 | num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
65 | num_time_downsample_layers = int(np.log2(time_compression_ratio))
66 |
67 | if time_compression_ratio == 4:
68 | add_spatial_downsample = bool(i < num_spatial_downsample_layers)
69 | add_time_downsample = bool(
70 | i >= (len(block_out_channels) - 1 - num_time_downsample_layers)
71 | and not is_final_block
72 | )
73 | else:
74 | raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
75 |
76 | downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
77 | downsample_stride_T = (2,) if add_time_downsample else (1,)
78 | downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
79 | down_block = get_down_block3d(
80 | down_block_type,
81 | num_layers=self.layers_per_block,
82 | in_channels=input_channel,
83 | out_channels=output_channel,
84 | add_downsample=bool(add_spatial_downsample or add_time_downsample),
85 | downsample_stride=downsample_stride,
86 | resnet_eps=1e-6,
87 | downsample_padding=0,
88 | resnet_act_fn=act_fn,
89 | resnet_groups=norm_num_groups,
90 | attention_head_dim=output_channel,
91 | temb_channels=None,
92 | )
93 | self.down_blocks.append(down_block)
94 |
95 | # mid
96 | self.mid_block = UNetMidBlockCausal3D(
97 | in_channels=block_out_channels[-1],
98 | resnet_eps=1e-6,
99 | resnet_act_fn=act_fn,
100 | output_scale_factor=1,
101 | resnet_time_scale_shift="default",
102 | attention_head_dim=block_out_channels[-1],
103 | resnet_groups=norm_num_groups,
104 | temb_channels=None,
105 | add_attention=mid_block_add_attention,
106 | )
107 |
108 | # out
109 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
110 | self.conv_act = nn.SiLU()
111 |
112 | conv_out_channels = 2 * out_channels if double_z else out_channels
113 | self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)
114 |
115 | def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
116 | r"""The forward method of the `EncoderCausal3D` class."""
117 | assert len(sample.shape) == 5, "The input tensor should have 5 dimensions"
118 |
119 | sample = self.conv_in(sample)
120 |
121 | # down
122 | for down_block in self.down_blocks:
123 | sample = down_block(sample)
124 |
125 | # middle
126 | sample = self.mid_block(sample)
127 |
128 | # post-process
129 | sample = self.conv_norm_out(sample)
130 | sample = self.conv_act(sample)
131 | sample = self.conv_out(sample)
132 |
133 | return sample
134 |
135 |
136 | class DecoderCausal3D(nn.Module):
137 | r"""
138 | The `DecoderCausal3D` layer of a variational autoencoder that decodes its latent representation into an output sample.
139 | """
140 |
141 | def __init__(
142 | self,
143 | in_channels: int = 3,
144 | out_channels: int = 3,
145 | up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",),
146 | block_out_channels: Tuple[int, ...] = (64,),
147 | layers_per_block: int = 2,
148 | norm_num_groups: int = 32,
149 | act_fn: str = "silu",
150 | norm_type: str = "group", # group, spatial
151 | mid_block_add_attention=True,
152 | time_compression_ratio: int = 4,
153 | spatial_compression_ratio: int = 8,
154 | ):
155 | super().__init__()
156 | self.layers_per_block = layers_per_block
157 |
158 | self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
159 | self.mid_block = None
160 | self.up_blocks = nn.ModuleList([])
161 |
162 | temb_channels = in_channels if norm_type == "spatial" else None
163 |
164 | # mid
165 | self.mid_block = UNetMidBlockCausal3D(
166 | in_channels=block_out_channels[-1],
167 | resnet_eps=1e-6,
168 | resnet_act_fn=act_fn,
169 | output_scale_factor=1,
170 | resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
171 | attention_head_dim=block_out_channels[-1],
172 | resnet_groups=norm_num_groups,
173 | temb_channels=temb_channels,
174 | add_attention=mid_block_add_attention,
175 | )
176 |
177 | # up
178 | reversed_block_out_channels = list(reversed(block_out_channels))
179 | output_channel = reversed_block_out_channels[0]
180 | for i, up_block_type in enumerate(up_block_types):
181 | prev_output_channel = output_channel
182 | output_channel = reversed_block_out_channels[i]
183 | is_final_block = i == len(block_out_channels) - 1
184 | num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
185 | num_time_upsample_layers = int(np.log2(time_compression_ratio))
186 |
187 | if time_compression_ratio == 4:
188 | add_spatial_upsample = bool(i < num_spatial_upsample_layers)
189 | add_time_upsample = bool(
190 | i >= len(block_out_channels) - 1 - num_time_upsample_layers
191 | and not is_final_block
192 | )
193 | else:
194 | raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
195 |
196 | upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
197 | upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
198 | upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
199 | up_block = get_up_block3d(
200 | up_block_type,
201 | num_layers=self.layers_per_block + 1,
202 | in_channels=prev_output_channel,
203 | out_channels=output_channel,
204 | prev_output_channel=None,
205 | add_upsample=bool(add_spatial_upsample or add_time_upsample),
206 | upsample_scale_factor=upsample_scale_factor,
207 | resnet_eps=1e-6,
208 | resnet_act_fn=act_fn,
209 | resnet_groups=norm_num_groups,
210 | attention_head_dim=output_channel,
211 | temb_channels=temb_channels,
212 | resnet_time_scale_shift=norm_type,
213 | )
214 | self.up_blocks.append(up_block)
215 | prev_output_channel = output_channel
216 |
217 | # out
218 | if norm_type == "spatial":
219 | self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
220 | else:
221 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
222 | self.conv_act = nn.SiLU()
223 | self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
224 |
225 | self.gradient_checkpointing = False
226 |
227 | def forward(
228 | self,
229 | sample: torch.FloatTensor,
230 | latent_embeds: Optional[torch.FloatTensor] = None,
231 | ) -> torch.FloatTensor:
232 | r"""The forward method of the `DecoderCausal3D` class."""
233 | assert len(sample.shape) == 5, "The input tensor should have 5 dimensions."
234 |
235 | sample = self.conv_in(sample)
236 |
237 | upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
238 | if self.training and self.gradient_checkpointing:
239 |
240 | def create_custom_forward(module):
241 | def custom_forward(*inputs):
242 | return module(*inputs)
243 |
244 | return custom_forward
245 |
246 | if is_torch_version(">=", "1.11.0"):
247 | # middle
248 | sample = torch.utils.checkpoint.checkpoint(
249 | create_custom_forward(self.mid_block),
250 | sample,
251 | latent_embeds,
252 | use_reentrant=False,
253 | )
254 | sample = sample.to(upscale_dtype)
255 |
256 | # up
257 | for up_block in self.up_blocks:
258 | sample = torch.utils.checkpoint.checkpoint(
259 | create_custom_forward(up_block),
260 | sample,
261 | latent_embeds,
262 | use_reentrant=False,
263 | )
264 | else:
265 | # middle
266 | sample = torch.utils.checkpoint.checkpoint(
267 | create_custom_forward(self.mid_block), sample, latent_embeds
268 | )
269 | sample = sample.to(upscale_dtype)
270 |
271 | # up
272 | for up_block in self.up_blocks:
273 | sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
274 | else:
275 | # middle
276 | sample = self.mid_block(sample, latent_embeds)
277 | sample = sample.to(upscale_dtype)
278 |
279 | # up
280 | for up_block in self.up_blocks:
281 | sample = up_block(sample, latent_embeds)
282 |
283 | # post-process
284 | if latent_embeds is None:
285 | sample = self.conv_norm_out(sample)
286 | else:
287 | sample = self.conv_norm_out(sample, latent_embeds)
288 | sample = self.conv_act(sample)
289 | sample = self.conv_out(sample)
290 |
291 | return sample
292 |
293 |
294 | class DiagonalGaussianDistribution(object):
295 | def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
296 | if parameters.ndim == 3:
297 | dim = 2 # (B, L, C)
298 | elif parameters.ndim == 5 or parameters.ndim == 4:
299 | dim = 1 # (B, C, T, H ,W) / (B, C, H, W)
300 | else:
301 | raise NotImplementedError
302 | self.parameters = parameters
303 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
304 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
305 | self.deterministic = deterministic
306 | self.std = torch.exp(0.5 * self.logvar)
307 | self.var = torch.exp(self.logvar)
308 | if self.deterministic:
309 | self.var = self.std = torch.zeros_like(
310 | self.mean, device=self.parameters.device, dtype=self.parameters.dtype
311 | )
312 |
313 | def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
314 | # make sure sample is on the same device as the parameters and has same dtype
315 | sample = randn_tensor(
316 | self.mean.shape,
317 | generator=generator,
318 | device=self.parameters.device,
319 | dtype=self.parameters.dtype,
320 | )
321 | x = self.mean + self.std * sample
322 | return x
323 |
324 | def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
325 | if self.deterministic:
326 | return torch.Tensor([0.0])
327 | else:
328 | reduce_dim = list(range(1, self.mean.ndim))
329 | if other is None:
330 | return 0.5 * torch.sum(
331 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
332 | dim=reduce_dim,
333 | )
334 | else:
335 | return 0.5 * torch.sum(
336 | torch.pow(self.mean - other.mean, 2) / other.var
337 | + self.var / other.var
338 | - 1.0
339 | - self.logvar
340 | + other.logvar,
341 | dim=reduce_dim,
342 | )
343 |
344 | def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
345 | if self.deterministic:
346 | return torch.Tensor([0.0])
347 | logtwopi = np.log(2.0 * np.pi)
348 | return 0.5 * torch.sum(
349 | logtwopi + self.logvar +
350 | torch.pow(sample - self.mean, 2) / self.var,
351 | dim=dims,
352 | )
353 |
354 | def mode(self) -> torch.Tensor:
355 | return self.mean
356 |
--------------------------------------------------------------------------------
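DiagonalGaussianDistribution expects the encoder output to stack the mean and log-variance along the channel (or token) dimension; it chunks them apart, clamps the log-variance, and exposes sample(), mode() and kl(). The self-contained sketch below mirrors that chunk/clamp/sample logic on a dummy 5-D latent rather than importing the class (the surrounding module depends on blocks not shown here); all shapes are arbitrary.

import torch

# Encoder output with 2 * C channels along dim=1: mean and log-variance stacked.
params = torch.randn(1, 2 * 4, 8, 16, 16)       # (B, 2C, T, H, W)
mean, logvar = torch.chunk(params, 2, dim=1)
logvar = torch.clamp(logvar, -30.0, 20.0)
std = torch.exp(0.5 * logvar)

z = mean + std * torch.randn_like(mean)         # what sample() returns
kl = 0.5 * torch.sum(mean.pow(2) + logvar.exp() - 1.0 - logvar,
                     dim=list(range(1, mean.ndim)))   # KL vs. N(0, I), as in kl()

print(z.shape, kl.shape)                        # torch.Size([1, 4, 8, 16, 16]) torch.Size([1])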
/train_pl.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import math
4 | import logging
5 | import argparse
6 | from pytorch_lightning.callbacks import Callback
7 | from pytorch_lightning import LightningModule, Trainer
8 | from pytorch_lightning.strategies import DDPStrategy
9 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
10 | from pytorch_lightning.loggers import TensorBoardLogger
11 | from glob import glob
12 | from omegaconf import OmegaConf
13 | from torch.utils.data import DataLoader
14 | from diffusers import AutoencoderDC
15 | from diffusers.models import AutoencoderKL
16 | from diffusers.optimization import get_scheduler
17 | from copy import deepcopy
18 | from einops import rearrange, repeat
19 |
20 | from models import get_models
21 | from models.audio_proj import AudioProjModel
22 | from datasets import get_dataset
23 | from diffusion import create_diffusion
24 | from utils import (
25 | update_ema,
26 | requires_grad,
27 | clip_grad_norm_,
28 | cleanup,
29 | find_model
30 | )
31 |
32 |
33 | class LetstalkTrainingModule(LightningModule):
34 | def __init__(self, args, logger: logging.Logger):
35 | super(LetstalkTrainingModule, self).__init__()
36 | self.args = args
37 | self.logging = logger
38 | self.model = get_models(args)
39 | if args.use_compile:
40 | self.model = torch.compile(self.model)
41 |
42 | self.ema = deepcopy(self.model)
43 | self.audioproj = AudioProjModel(
44 | seq_len=5,
45 | blocks=12,
46 | channels=args.audio_dim,
47 | intermediate_dim=512,
48 | output_dim=args.audio_dim,
49 | context_tokens=args.audio_token,
50 | )
51 |
52 | # Load pretrained model if specified
53 | if args.pretrained is not None:
54 | self._load_pretrained_parameters(args)
55 | self.logging.info(f"Model Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
56 |
57 | self.diffusion = create_diffusion(timestep_respacing="")
58 | if args.in_channels == 32:
59 | self.vae = AutoencoderDC.from_pretrained(args.pretrained_model_path, subfolder="vae")
60 | else:
61 | self.vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
62 |
63 | self.opt = torch.optim.AdamW(self.model.parameters(), lr=1e-4, weight_decay=0)
64 | self.lr_scheduler = None
65 |
66 | # Freeze model
67 | self.vae.requires_grad_(False)
68 | requires_grad(self.ema, False)
69 |
70 | update_ema(self.ema, self.model, decay=0) # Ensure EMA is initialized with synced weights
71 | self.model.train() # important! This enables embedding dropout for classifier-free guidance
72 | self.audioproj.train()
73 | self.ema.eval()
74 |
75 | def _load_pretrained_parameters(self, args):
76 | checkpoint = torch.load(args.pretrained, map_location=lambda storage, loc: storage)
77 | if "ema" in checkpoint: # supports checkpoints from train.py
78 | self.logging.info("Using ema ckpt!")
79 | checkpoint = checkpoint["ema"]
80 |
81 | model_dict = self.model.state_dict()
82 | # 1. filter out unnecessary keys
83 | pretrained_dict = {}
84 | for k, v in checkpoint.items():
85 | if k in model_dict:
86 | pretrained_dict[k] = v
87 | else:
88 | self.logging.info("Ignoring: {}".format(k))
89 | self.logging.info(f"Successfully Load {len(pretrained_dict) / len(checkpoint.items()) * 100}% original pretrained model weights ")
90 |
91 | # 2. overwrite entries in the existing state dict
92 | model_dict.update(pretrained_dict)
93 | self.model.load_state_dict(model_dict)
94 | self.logging.info(f"Successfully load model at {args.pretrained}!")
95 |
96 | if "audioproj" in checkpoint.keys():
97 | audioproj_dict = checkpoint["audioproj"]
98 | self.audioproj.load_state_dict(audioproj_dict)
99 |
100 | # def on_load_checkpoint(self, checkpoint):
101 | # file_name = args.pretrained.split("/")[-1].split('.')[0]
102 | # if file_name.isdigit():
103 | # self.global_step = int(file_name)
104 |
105 | def add_noise_to_image(self, images):
106 | image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(images.size(0),), device=self.device)
107 | image_noise_sigma = torch.exp(image_noise_sigma)
108 | noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None]
109 | return noisy_images
110 |
111 | def add_noise(
112 | self,
113 | original_samples: torch.Tensor,
114 | noise: torch.Tensor,
115 | timesteps: torch.IntTensor,
116 | ) -> torch.Tensor:
117 | # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
118 | # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
119 | # for the subsequent add_noise calls
120 | alphas_cumprod = torch.from_numpy(self.diffusion.alphas_cumprod)
121 | alphas_cumprod = alphas_cumprod.to(device=original_samples.device)
122 | alphas_cumprod = alphas_cumprod.to(dtype=original_samples.dtype)
123 | timesteps = timesteps.to(original_samples.device)
124 |
125 | sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
126 | sqrt_alpha_prod = sqrt_alpha_prod.flatten()
127 | while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
128 | sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
129 |
130 | sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
131 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
132 | while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
133 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
134 |
135 | noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
136 | return noisy_samples
137 |
138 | def training_step(self, batch, batch_idx):
139 | if "latent" in self.args.dataset:
140 | latents = batch["video"] * self.vae.config.scaling_factor
141 | ref_latents = batch["ref_latent"] * self.vae.config.scaling_factor
142 | motion_latents = batch["motions"] * self.vae.config.scaling_factor if self.args.initial_frames != 0 else None
143 | else:
144 | x = batch["video"]
145 | ref_image = batch["ref_image"]
146 |
147 | with torch.no_grad():
148 | b, _, _, _, _ = x.shape
149 | x = rearrange(x, "b f c h w -> (b f) c h w").contiguous()
150 | if self.args.in_channels == 32:
151 | latents = self.vae.encode(x)[0] * self.vae.config.scaling_factor
152 | else:
153 |                     latents = self.vae.encode(x).latent_dist.sample() * self.vae.config.scaling_factor
154 | latents = rearrange(latents, "(b f) c h w -> b f c h w", b=b).contiguous()
155 |
156 | ref_image = self.add_noise_to_image(ref_image)
157 | if self.args.in_channels == 32:
158 | ref_latents = self.vae.encode(ref_image)[0] * self.vae.config.scaling_factor
159 | else:
160 | ref_latents = self.vae.encode(ref_image).latent_dist.sample().mul_(0.18215)
161 |
162 | if self.args.initial_frames != 0:
163 | motions = batch["motions"]
164 | motions = rearrange(motions, "b f c h w -> (b f) c h w").contiguous()
165 | motions = self.add_noise_to_image(motions)
166 | if self.args.in_channels == 32:
167 | motion_latents = self.vae.encode(motions)[0] * self.vae.config.scaling_factor
168 | else:
169 |                         motion_latents = self.vae.encode(motions).latent_dist.sample() * self.vae.config.scaling_factor
170 | motion_latents = rearrange(motion_latents, "(b f) c h w -> b f c h w", b=b).contiguous()
171 |
172 | ref_latents = repeat(ref_latents, "b c h w -> b f c h w", f=latents.size(1))
173 | model_kwargs = dict(y=None, cond=ref_latents)
174 | if self.args.initial_frames != 0:
175 | motion_timesteps = torch.randint(0, 50, (latents.shape[0],), device=latents.device).long()
176 | motion_noise = torch.randn_like(motion_latents)
177 | # add motion noise
178 | noisy_motion_latents = self.add_noise(
179 | motion_latents, motion_noise, motion_timesteps
180 | )
181 |
182 | b, f, c, h, w = noisy_motion_latents.shape
183 | rand_mask = torch.rand(h, w).to(device=noisy_motion_latents.device)
184 | mask = rand_mask > 0.25
185 | mask = mask.unsqueeze(0).unsqueeze(0).unsqueeze(0)
186 | mask = mask.expand(b, f, c, h, w)
187 | noisy_motion_latents = noisy_motion_latents * mask
188 |
189 | model_kwargs.update(motion=noisy_motion_latents)
190 |
191 | if "audio" in batch:
192 | audio_emb = batch["audio"]
193 | audio_emb = self.audioproj(audio_emb)
194 | model_kwargs.update(audio_embed=audio_emb)
195 |
196 | timesteps = torch.randint(0, self.diffusion.num_timesteps, (latents.shape[0],), device=self.device)
197 | loss_dict = self.diffusion.training_losses(self.model, latents, timesteps, model_kwargs)
198 | loss = loss_dict["loss"].mean()
199 |
200 | if self.global_step < self.args.start_clip_iter:
201 | gradient_norm = clip_grad_norm_(self.model.parameters(), self.args.clip_max_norm, clip_grad=False)
202 | else:
203 | gradient_norm = clip_grad_norm_(self.model.parameters(), self.args.clip_max_norm, clip_grad=True)
204 |
205 | self.log("train_loss", loss, prog_bar=True)
206 | self.log("gradient_norm", gradient_norm, prog_bar=True)
207 | self.log("train_step", self.global_step)
208 |
209 | if (self.global_step+1) % self.args.log_every == 0:
210 | self.logging.info(
211 | f"(step={self.global_step+1:07d}/epoch={self.current_epoch:04d}) Train Loss: {loss:.4f}, Gradient Norm: {gradient_norm:.4f}"
212 | )
213 | return loss
214 |
215 | def on_train_batch_end(self, *args, **kwargs):
216 | update_ema(self.ema, self.model)
217 |
218 | def on_save_checkpoint(self, checkpoint):
219 | super().on_save_checkpoint(checkpoint)
220 | checkpoint_dir = self.trainer.checkpoint_callback.dirpath
221 | epoch = self.trainer.current_epoch
222 | step = self.trainer.global_step
223 | checkpoint = {
224 | "model": self.model.state_dict(),
225 | "ema": self.ema.state_dict(),
226 | "audioproj": self.audioproj.state_dict(),
227 | }
228 | torch.save(checkpoint, f"{checkpoint_dir}/last.ckpt")
229 | if step % self.args.ckpt_every == 0:
230 | torch.save(checkpoint, f"{checkpoint_dir}/epoch{epoch}-step{step}.ckpt")
231 |
232 | def configure_optimizers(self):
233 | self.lr_scheduler = get_scheduler(
234 | name="constant",
235 | optimizer=self.opt,
236 | num_warmup_steps=self.args.lr_warmup_steps * self.args.gradient_accumulation_steps,
237 | num_training_steps=self.args.max_train_steps * self.args.gradient_accumulation_steps,
238 | )
239 | return [self.opt], [self.lr_scheduler]
240 |
241 |
242 | def create_logger(logging_dir):
243 | logging.basicConfig(
244 | level=logging.INFO,
245 | format="[%(asctime)s] %(message)s",
246 | datefmt="%Y-%m-%d %H:%M:%S",
247 | handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
248 | )
249 | logger = logging.getLogger(__name__)
250 | return logger
251 |
252 |
253 | def create_experiment_directory(args):
254 | os.makedirs(args.results_dir, exist_ok=True) # Make results folder (holds all experiment subfolders)
255 | experiment_index = len(glob(os.path.join(args.results_dir, "*")))
256 | model_string_name = args.model.replace("/", "-") # e.g., Letstalk-L/2 --> Letstalk-L-2 (for naming folders)
257 | num_frame_string = f"F{args.num_frames}S{args.frame_interval}"
258 | experiment_dir = os.path.join( # Create an experiment folder
259 | args.results_dir,
260 | f"{experiment_index:03d}-{model_string_name}-{num_frame_string}-{args.dataset}"
261 | )
262 | checkpoint_dir = os.path.join(experiment_dir, "checkpoints") # Stores saved model checkpoints
263 | os.makedirs(checkpoint_dir, exist_ok=True)
264 |
265 | return experiment_dir, checkpoint_dir
266 |
267 |
268 | def main(args):
269 | seed = args.global_seed
270 | torch.manual_seed(seed)
271 |
272 | # Determine if the current process is the main process (rank 0)
273 | is_main_process = (int(os.environ.get("LOCAL_RANK", 0)) == 0)
274 | # Setup an experiment folder and logger only if main process
275 | if is_main_process:
276 | experiment_dir, checkpoint_dir = create_experiment_directory(args)
277 | logger = create_logger(experiment_dir)
278 | OmegaConf.save(args, os.path.join(experiment_dir, "config.yaml"))
279 | logger.info(f"Experiment directory created at {experiment_dir}")
280 | else:
281 | experiment_dir = os.getenv("EXPERIMENT_DIR", "default_path")
282 | checkpoint_dir = os.getenv("CHECKPOINT_DIR", "default_path")
283 | logger = logging.getLogger(__name__)
284 | logger.addHandler(logging.NullHandler())
285 | tb_logger = TensorBoardLogger(experiment_dir, name="letstalk")
286 |
287 | # Create the dataset and dataloader
288 | dataset = get_dataset(args)
289 | loader = DataLoader(
290 | dataset,
291 | batch_size=args.local_batch_size,
292 | shuffle=True,
293 | num_workers=args.num_workers,
294 | pin_memory=True,
295 | drop_last=True
296 | )
297 | if is_main_process:
298 | logger.info(f"Dataset contains {len(dataset)} videos ({args.data_dir})")
299 |
300 | if args.in_channels == 32:
301 | sample_size = args.image_size // 32
302 | else:
303 | sample_size = args.image_size // 8
304 | args.latent_size = sample_size
305 |
306 | # We need to recalculate our total training steps as the size of the training dataloader may have changed.
307 | num_update_steps_per_epoch = math.ceil(len(loader) / torch.cuda.device_count())
308 | # Afterwards we recalculate our number of training epochs
309 | num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
310 | # In multi-GPU mode, the effective batch size is local_batch_size * number of GPUs
311 | if is_main_process:
312 | logger.info(f"Steps per epoch: {num_update_steps_per_epoch}")
313 | logger.info(f"Num train epochs: {num_train_epochs}")
314 |
315 | # Initialize the training module
316 | pl_module = LetstalkTrainingModule(args, logger)
317 |
318 | checkpoint_callback = ModelCheckpoint(
319 | dirpath=checkpoint_dir,
320 | filename="{epoch}-{step}-{train_loss:.2f}-{gradient_norm:.2f}",
321 | save_top_k=3,
322 | every_n_train_steps=args.ckpt_every,
323 | monitor="train_step",
324 | mode="max"
325 | )
326 |
327 | # Trainer
328 | trainer = Trainer(
329 | accelerator="gpu",
330 | # devices=[0],
331 | strategy="auto",
332 | max_epochs=num_train_epochs,
333 | logger=tb_logger,
334 | callbacks=[checkpoint_callback, LearningRateMonitor()],
335 | precision=args.precision if args.precision else "32-true",
336 | )
337 |
338 | ckpt_path = args.resume_from_checkpoint if args.resume_from_checkpoint else None
339 | trainer.fit(pl_module, train_dataloaders=loader, ckpt_path=ckpt_path)
340 |
341 | pl_module.model.eval()
342 | cleanup()
343 | if is_main_process:
344 | logger.info("Done!")
345 |
346 |
347 | if __name__ == "__main__":
348 | parser = argparse.ArgumentParser()
349 | parser.add_argument("--config", type=str, default="./configs/train.yaml")
350 | args = parser.parse_args()
351 | main(OmegaConf.load(args.config))
--------------------------------------------------------------------------------
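
For reference, the forward-noising and motion-masking steps in training_step above follow the standard DDPM form x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, with roughly 25% of the spatial positions of the noisy motion latents randomly zeroed. Below is a minimal, self-contained sketch of just those two steps; the linear-beta schedule and all tensor shapes are illustrative assumptions, not the repository's exact configuration.

import torch

# Illustrative linear-beta schedule (an assumption, not the repo's exact schedule).
betas = torch.linspace(1e-4, 2e-2, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def add_noise(x0, noise, t):
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
    a = alphas_cumprod[t].view(-1, *([1] * (x0.dim() - 1)))
    return a.sqrt() * x0 + (1.0 - a).sqrt() * noise

b, f, c, h, w = 2, 16, 4, 32, 32                      # arbitrary latent shape
motion = torch.randn(b, f, c, h, w)
t = torch.randint(0, 50, (b,))                        # small timesteps, as in training_step
noisy_motion = add_noise(motion, torch.randn_like(motion), t)

# Randomly zero out ~25% of spatial positions, shared across batch/frames/channels.
mask = (torch.rand(h, w) > 0.25).expand(b, f, c, h, w)
noisy_motion = noisy_motion * mask
print(noisy_motion.shape)                             # torch.Size([2, 16, 4, 32, 32])
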
/preparation/audio_processor.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=C0301
2 | '''
3 | This module contains the AudioProcessor class and related functions for processing audio data.
4 | It utilizes various libraries and models to perform tasks such as preprocessing, feature extraction,
5 | and audio separation. The class is initialized with configuration parameters and can process
6 | audio files using the provided models.
7 | '''
8 | import math
9 | import os
10 | import subprocess
11 |
12 | import librosa
13 | import numpy as np
14 | import torch
15 | import torch.nn.functional as F
16 | from audio_separator.separator import Separator
17 | from einops import rearrange
18 | from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
19 | from transformers.modeling_outputs import BaseModelOutput
20 |
21 |
22 | class Wav2VecModel(Wav2Vec2Model):
23 | """
24 | Wav2VecModel is a custom model class that extends the Wav2Vec2Model class from the transformers library.
25 | It inherits all the functionality of the Wav2Vec2Model and adds additional methods for feature extraction and encoding.
26 | ...
27 |
28 | Attributes:
29 | base_model (Wav2Vec2Model): The base Wav2Vec2Model object.
30 |
31 | Methods:
32 | forward(input_values, seq_len, attention_mask=None, mask_time_indices=None,
33 | output_attentions=None, output_hidden_states=None, return_dict=None):
34 | Forward pass of the Wav2VecModel.
35 | It takes input_values, seq_len, and other optional parameters as input and returns the output of the base model.
36 |
37 | feature_extract(input_values, seq_len):
38 | Extracts features from the input_values using the base model.
39 |
40 | encode(extract_features, attention_mask=None, mask_time_indices=None, output_attentions=None, output_hidden_states=None, return_dict=None):
41 | Encodes the extracted features using the base model and returns the encoded features.
42 | """
43 | def forward(
44 | self,
45 | input_values,
46 | seq_len,
47 | attention_mask=None,
48 | mask_time_indices=None,
49 | output_attentions=None,
50 | output_hidden_states=None,
51 | return_dict=None,
52 | ):
53 | """
54 | Forward pass of the Wav2Vec model.
55 |
56 | Args:
57 | self: The instance of the model.
58 | input_values: The input values (waveform) to the model.
59 | seq_len: The sequence length of the input values.
60 | attention_mask: Attention mask to be used for the model.
61 | mask_time_indices: Mask indices to be used for the model.
62 | output_attentions: If set to True, returns attentions.
63 | output_hidden_states: If set to True, returns hidden states.
64 | return_dict: If set to True, returns a BaseModelOutput instead of a tuple.
65 |
66 | Returns:
67 | The output of the Wav2Vec model.
68 | """
69 | self.config.output_attentions = True
70 |
71 | output_hidden_states = (
72 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
73 | )
74 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
75 |
76 | extract_features = self.feature_extractor(input_values)
77 | extract_features = extract_features.transpose(1, 2)
78 | extract_features = linear_interpolation(extract_features, seq_len=seq_len)
79 |
80 | if attention_mask is not None:
81 | # compute reduced attention_mask corresponding to feature vectors
82 | attention_mask = self._get_feature_vector_attention_mask(
83 | extract_features.shape[1], attention_mask, add_adapter=False
84 | )
85 |
86 | hidden_states, extract_features = self.feature_projection(extract_features)
87 | hidden_states = self._mask_hidden_states(
88 | hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
89 | )
90 |
91 | encoder_outputs = self.encoder(
92 | hidden_states,
93 | attention_mask=attention_mask,
94 | output_attentions=output_attentions,
95 | output_hidden_states=output_hidden_states,
96 | return_dict=return_dict,
97 | )
98 |
99 | hidden_states = encoder_outputs[0]
100 |
101 | if self.adapter is not None:
102 | hidden_states = self.adapter(hidden_states)
103 |
104 | if not return_dict:
105 | return (hidden_states, ) + encoder_outputs[1:]
106 | return BaseModelOutput(
107 | last_hidden_state=hidden_states,
108 | hidden_states=encoder_outputs.hidden_states,
109 | attentions=encoder_outputs.attentions,
110 | )
111 |
112 |
113 | def feature_extract(
114 | self,
115 | input_values,
116 | seq_len,
117 | ):
118 | """
119 | Extracts features from the input values and returns the extracted features.
120 |
121 | Parameters:
122 | input_values (torch.Tensor): The input values to be processed.
123 | seq_len (int): The target sequence length.
124 |
125 | Returns:
126 | extracted_features (torch.Tensor): The extracted features from the input values.
127 | """
128 | extract_features = self.feature_extractor(input_values)
129 | extract_features = extract_features.transpose(1, 2)
130 | extract_features = linear_interpolation(extract_features, seq_len=seq_len)
131 |
132 | return extract_features
133 |
134 | def encode(
135 | self,
136 | extract_features,
137 | attention_mask=None,
138 | mask_time_indices=None,
139 | output_attentions=None,
140 | output_hidden_states=None,
141 | return_dict=None,
142 | ):
143 | """
144 | Encodes the input features into the output space.
145 |
146 | Args:
147 | extract_features (torch.Tensor): The extracted features from the audio signal.
148 | attention_mask (torch.Tensor, optional): Attention mask to be used for padding.
149 | mask_time_indices (torch.Tensor, optional): Masked indices for the time dimension.
150 | output_attentions (bool, optional): If set to True, returns the attention weights.
151 | output_hidden_states (bool, optional): If set to True, returns all hidden states.
152 | return_dict (bool, optional): If set to True, returns a BaseModelOutput instead of the tuple.
153 |
154 | Returns:
155 | The encoded output features.
156 | """
157 | self.config.output_attentions = True
158 |
159 | output_hidden_states = (
160 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
161 | )
162 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
163 |
164 | if attention_mask is not None:
165 | # compute reduced attention_mask corresponding to feature vectors
166 | attention_mask = self._get_feature_vector_attention_mask(
167 | extract_features.shape[1], attention_mask, add_adapter=False
168 | )
169 |
170 | hidden_states, extract_features = self.feature_projection(extract_features)
171 | hidden_states = self._mask_hidden_states(
172 | hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
173 | )
174 |
175 | encoder_outputs = self.encoder(
176 | hidden_states,
177 | attention_mask=attention_mask,
178 | output_attentions=output_attentions,
179 | output_hidden_states=output_hidden_states,
180 | return_dict=return_dict,
181 | )
182 |
183 | hidden_states = encoder_outputs[0]
184 |
185 | if self.adapter is not None:
186 | hidden_states = self.adapter(hidden_states)
187 |
188 | if not return_dict:
189 | return (hidden_states, ) + encoder_outputs[1:]
190 | return BaseModelOutput(
191 | last_hidden_state=hidden_states,
192 | hidden_states=encoder_outputs.hidden_states,
193 | attentions=encoder_outputs.attentions,
194 | )
195 |
196 |
197 | def linear_interpolation(features, seq_len):
198 | """
199 | Linearly interpolate the features along the time axis to the target sequence length.
200 |
201 | Args:
202 | features (torch.Tensor): The extracted features to be interpolated.
203 | seq_len (int): The target sequence length.
204 |
205 | Returns:
206 | torch.Tensor: The interpolated features.
207 | """
208 | features = features.transpose(1, 2)
209 | output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
210 | return output_features.transpose(1, 2)
211 |
212 |
213 | def resample_audio(input_audio_file: str, output_audio_file: str, sample_rate: int):
214 | p = subprocess.Popen([
215 | "ffmpeg", "-y", "-v", "error", "-i", input_audio_file, "-ar", str(sample_rate), output_audio_file
216 | ])
217 | ret = p.wait()
218 | assert ret == 0, "Resample audio failed!"
219 | return output_audio_file
220 |
221 |
222 | class AudioProcessor:
223 | """
224 | AudioProcessor is a class that handles the processing of audio files.
225 | It takes care of preprocessing the audio files, extracting features
226 | using wav2vec models, and separating audio signals if needed.
227 |
228 | :param sample_rate: Sampling rate of the audio file
229 | :param fps: Frames per second for the extracted features
230 | :param wav2vec_model_path: Path to the wav2vec model
231 | :param audio_separator_model_path: Path to the audio separator model
232 | :param audio_separator_model_name: Name of the audio separator model
233 | :param cache_dir: Directory to cache the intermediate results
234 | :param device: Device to run the processing on
235 | :param only_last_features: Whether to only use the last features
236 | """
237 | def __init__(
238 | self,
239 | sample_rate,
240 | fps,
241 | wav2vec_model_path,
242 | audio_separator_model_path:str=None,
243 | audio_separator_model_name:str=None,
244 | cache_dir:str='',
245 | only_last_features:bool=False,
246 | device="cuda:0",
247 | ) -> None:
248 | self.sample_rate = sample_rate
249 | self.fps = fps
250 | self.device = device
251 |
252 | self.audio_encoder = Wav2VecModel.from_pretrained(wav2vec_model_path, local_files_only=True).to(device=device)
253 | self.audio_encoder.feature_extractor._freeze_parameters()
254 | self.only_last_features = only_last_features
255 |
256 | if audio_separator_model_name is not None:
257 | try:
258 | os.makedirs(cache_dir, exist_ok=True)
259 | except OSError as _:
260 | print("Fail to create the output cache dir.")
261 | self.audio_separator = Separator(
262 | output_dir=cache_dir,
263 | output_single_stem="vocals",
264 | model_file_dir=audio_separator_model_path,
265 | )
266 | self.audio_separator.load_model(audio_separator_model_name)
267 | assert self.audio_separator.model_instance is not None, "Failed to load the audio separator model."
268 | else:
269 | self.audio_separator = None
270 | print("Using the audio directly, without vocal separation.")
271 |
272 |
273 | self.wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_model_path, local_files_only=True)
274 |
275 |
276 | def preprocess(self, wav_file: str, clip_length: int=-1):
277 | """
278 | Preprocess a WAV audio file by separating the vocals from the background and resampling it to a 16 kHz sample rate.
279 | The separated vocal track is then converted into wav2vec2 embeddings for further processing or analysis.
280 |
281 | Args:
282 | wav_file (str): The path to the WAV file to be processed. This file should be accessible and in WAV format.
283 |
284 | Raises:
285 | RuntimeError: Raises an exception if the WAV file cannot be processed. This could be due to issues
286 | such as file not found, unsupported file format, or errors during the audio processing steps.
287 |
288 | Returns:
289 | torch.tensor: Returns an audio embedding as a torch.tensor
290 | """
291 | if self.audio_separator is not None:
292 | # 1. separate vocals
293 | # TODO: process in memory
294 | outputs = self.audio_separator.separate(wav_file)
295 | if len(outputs) <= 0:
296 | raise RuntimeError("Audio separate failed.")
297 |
298 | vocal_audio_file = outputs[0]
299 | vocal_audio_name, _ = os.path.splitext(vocal_audio_file)
300 | vocal_audio_file = os.path.join(self.audio_separator.output_dir, vocal_audio_file)
301 | vocal_audio_file = resample_audio(vocal_audio_file, os.path.join(self.audio_separator.output_dir, f"{vocal_audio_name}-16k.wav"), self.sample_rate)
302 | else:
303 | vocal_audio_file=wav_file
304 |
305 | # 2. extract wav2vec features
306 | speech_array, sampling_rate = librosa.load(vocal_audio_file, sr=self.sample_rate)
307 | audio_feature = np.squeeze(self.wav2vec_feature_extractor(speech_array, sampling_rate=sampling_rate).input_values)
308 | seq_len = math.ceil(len(audio_feature) / self.sample_rate * self.fps)
309 | audio_length = seq_len
310 |
311 | audio_feature = torch.from_numpy(audio_feature).float().to(device=self.device)
312 |
313 | if clip_length>0 and seq_len % clip_length != 0:
314 | audio_feature = torch.nn.functional.pad(audio_feature, (0, (clip_length - seq_len % clip_length) * (self.sample_rate // self.fps)), 'constant', 0.0)
315 | seq_len += clip_length - seq_len % clip_length
316 | audio_feature = audio_feature.unsqueeze(0)
317 |
318 | with torch.no_grad():
319 | embeddings = self.audio_encoder(audio_feature, seq_len=seq_len, output_hidden_states=True)
320 | assert len(embeddings) > 0, "Failed to extract audio embeddings"
321 | if self.only_last_features:
322 | audio_emb = embeddings.last_hidden_state.squeeze()
323 | else:
324 | audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0)
325 | audio_emb = rearrange(audio_emb, "b s d -> s b d")
326 |
327 | audio_emb = audio_emb.cpu().detach()
328 |
329 | return audio_emb, audio_length
330 |
331 | def get_embedding(self, wav_file: str):
332 | """preprocess wav audio file convert to embeddings
333 |
334 | Args:
335 | wav_file (str): The path to the WAV file to be processed. This file should be accessible and in WAV format.
336 |
337 | Returns:
338 | torch.tensor: Returns an audio embedding as a torch.tensor
339 | """
340 | speech_array, sampling_rate = librosa.load(
341 | wav_file, sr=self.sample_rate)
342 | assert sampling_rate == 16000, "The audio sample rate must be 16000"
343 | audio_feature = np.squeeze(self.wav2vec_feature_extractor(
344 | speech_array, sampling_rate=sampling_rate).input_values)
345 | seq_len = math.ceil(len(audio_feature) / self.sample_rate * self.fps)
346 |
347 | audio_feature = torch.from_numpy(
348 | audio_feature).float().to(device=self.device)
349 | audio_feature = audio_feature.unsqueeze(0)
350 |
351 | with torch.no_grad():
352 | embeddings = self.audio_encoder(
353 | audio_feature, seq_len=seq_len, output_hidden_states=True)
354 | assert len(embeddings) > 0, "Failed to extract audio embeddings"
355 |
356 | if self.only_last_features:
357 | audio_emb = embeddings.last_hidden_state.squeeze()
358 | else:
359 | audio_emb = torch.stack(
360 | embeddings.hidden_states[1:], dim=1).squeeze(0)
361 | audio_emb = rearrange(audio_emb, "b s d -> s b d")
362 |
363 | audio_emb = audio_emb.cpu().detach()
364 |
365 | return audio_emb
366 |
367 | def close(self):
368 | """
369 | TODO: to be implemented
370 | """
371 | return self
372 |
373 | def __enter__(self):
374 | return self
375 |
376 | def __exit__(self, _exc_type, _exc_val, _exc_tb):
377 | self.close()
378 |
--------------------------------------------------------------------------------
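
For reference, AudioProcessor.preprocess above aligns the roughly 50 Hz wav2vec2 feature stream to the video frame rate through linear_interpolation, with seq_len = ceil(num_samples / sample_rate * fps). A minimal sketch of that alignment, assuming a 3-second 16 kHz clip and a random tensor standing in for real wav2vec2 features:

import math
import torch
import torch.nn.functional as F

sample_rate, fps = 16000, 25
num_samples = 3 * sample_rate                              # 3 s of 16 kHz audio
seq_len = math.ceil(num_samples / sample_rate * fps)       # 75 video frames

features = torch.randn(1, 149, 512)                        # (B, T_wav2vec, D), stand-in tensor

def linear_interpolation(features, seq_len):
    features = features.transpose(1, 2)                    # (B, D, T)
    out = F.interpolate(features, size=seq_len, mode="linear", align_corners=True)
    return out.transpose(1, 2)                             # (B, seq_len, D)

print(linear_interpolation(features, seq_len).shape)       # torch.Size([1, 75, 512])
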
/models/basic_modules.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | import torch.nn as nn
4 | from timm.models.vision_transformer import Mlp
5 | from typing import Any, Union, Tuple
6 | from collections.abc import Iterable
7 | from itertools import repeat
8 | from torch.utils.checkpoint import checkpoint, checkpoint_sequential
9 |
10 |
11 | __all__ = ["build_act", "get_act_name"]
12 |
13 | # register activation function here
14 | # name: module, kwargs with default values
15 | REGISTERED_ACT_DICT: dict[str, tuple[type, dict[str, Any]]] = {
16 | "relu": (nn.ReLU, {"inplace": True}),
17 | "relu6": (nn.ReLU6, {"inplace": True}),
18 | "hswish": (nn.Hardswish, {"inplace": True}),
19 | "hsigmoid": (nn.Hardsigmoid, {"inplace": True}),
20 | "swish": (nn.SiLU, {"inplace": True}),
21 | "silu": (nn.SiLU, {"inplace": True}),
22 | "tanh": (nn.Tanh, {}),
23 | "sigmoid": (nn.Sigmoid, {}),
24 | "gelu": (nn.GELU, {"approximate": "tanh"}),
25 | "mish": (nn.Mish, {"inplace": True}),
26 | "identity": (nn.Identity, {}),
27 | }
28 |
29 |
30 | def build_act(name: Union[str, None], **kwargs) -> Union[nn.Module, None]:
31 | if name in REGISTERED_ACT_DICT:
32 | act_cls, default_args = copy.deepcopy(REGISTERED_ACT_DICT[name])
33 | for key in default_args:
34 | if key in kwargs:
35 | default_args[key] = kwargs[key]
36 | return act_cls(**default_args)
37 | elif name is None or name.lower() == "none":
38 | return None
39 | else:
40 | raise ValueError(f"do not support: {name}")
41 |
42 |
43 | class LayerNorm2d(nn.LayerNorm):
44 | rmsnorm = False
45 |
46 | def forward(self, x: torch.Tensor) -> torch.Tensor:
47 | out = x if LayerNorm2d.rmsnorm else x - torch.mean(x, dim=1, keepdim=True)
48 | out = out / torch.sqrt(torch.square(out).mean(dim=1, keepdim=True) + self.eps)
49 | if self.elementwise_affine:
50 | out = out * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
51 | return out
52 |
53 | def extra_repr(self) -> str:
54 | return f"{self.normalized_shape}, eps={self.eps}, elementwise_affine={self.elementwise_affine}, rmsnorm={self.rmsnorm}"
55 |
56 |
57 | # register normalization function here
58 | # name: module, kwargs with default values
59 | REGISTERED_NORMALIZATION_DICT: dict[str, tuple[type, dict[str, Any]]] = {
60 | "bn2d": (nn.BatchNorm2d, {"num_features": None, "eps": 1e-5, "momentum": 0.1, "affine": True}),
61 | "syncbn": (nn.SyncBatchNorm, {"num_features": None, "eps": 1e-5, "momentum": 0.1, "affine": True}),
62 | "ln": (nn.LayerNorm, {"normalized_shape": None, "eps": 1e-5, "elementwise_affine": True}),
63 | "ln2d": (LayerNorm2d, {"normalized_shape": None, "eps": 1e-5, "elementwise_affine": True}),
64 | }
65 |
66 | def build_norm(name="bn2d", num_features=None, affine=True, **kwargs) -> Union[nn.Module, None]:
67 | if name in ["ln", "ln2d"]:
68 | kwargs["normalized_shape"] = num_features
69 | kwargs["elementwise_affine"] = affine
70 | else:
71 | kwargs["num_features"] = num_features
72 | kwargs["affine"] = affine
73 | if name in REGISTERED_NORMALIZATION_DICT:
74 | norm_cls, default_args = copy.deepcopy(REGISTERED_NORMALIZATION_DICT[name])
75 | for key in default_args:
76 | if key in kwargs:
77 | default_args[key] = kwargs[key]
78 | return norm_cls(**default_args)
79 | elif name is None or name.lower() == "none":
80 | return None
81 | else:
82 | raise ValueError("do not support: %s" % name)
83 |
84 |
85 | def val2list(x: list or tuple or any, repeat_time=1) -> list: # type: ignore
86 | """Repeat `val` for `repeat_time` times and return the list or val if list/tuple."""
87 | if isinstance(x, (list, tuple)):
88 | return list(x)
89 | return [x for _ in range(repeat_time)]
90 |
91 |
92 | def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1) -> tuple: # type: ignore
93 | """Return tuple with min_len by repeating element at idx_repeat."""
94 | # convert to list first
95 | x = val2list(x)
96 |
97 | # repeat elements if necessary
98 | if len(x) > 0:
99 | x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
100 |
101 | return tuple(x)
102 |
103 |
104 | def get_same_padding(kernel_size: Union[int, Tuple[int, ...]]) -> Union[int, Tuple[int, ...]]:
105 | if isinstance(kernel_size, tuple):
106 | return tuple([get_same_padding(ks) for ks in kernel_size])
107 | else:
108 | assert kernel_size % 2 > 0, f"kernel size {kernel_size} should be odd number"
109 | return kernel_size // 2
110 |
111 |
112 | def apply_rotary_emb(
113 | x: torch.Tensor,
114 | freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
115 | use_real: bool = True,
116 | use_real_unbind_dim: int = -1,
117 | ) -> torch.Tensor:
118 | """
119 | Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
120 | to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
121 | reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
122 | tensors contain rotary embeddings and are returned as real tensors.
123 |
124 | Args:
125 | x (`torch.Tensor`):
126 | Query or key tensor to apply rotary embeddings to. Shape [B, H, S, D].
127 | freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
128 |
129 | Returns:
130 | torch.Tensor: The input tensor with rotary position embeddings applied.
131 | """
132 | if use_real:
133 | cos, sin = freqs_cis # [S, D]
134 | cos = cos[None, None]
135 | sin = sin[None, None]
136 | cos, sin = cos.to(x.device), sin.to(x.device)
137 |
138 | if use_real_unbind_dim == -1:
139 | # Used for flux, cogvideox, hunyuan-dit
140 | x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
141 | x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
142 | elif use_real_unbind_dim == -2:
143 | # Used for Sana
144 | cos = cos.transpose(-1, -2)
145 | sin = sin.transpose(-1, -2)
146 | x_real, x_imag = x.reshape(*x.shape[:-2], -1, 2, x.shape[-1]).unbind(-2) # [B, H, D//2, S]
147 | x_rotated = torch.stack([-x_imag, x_real], dim=-2).flatten(2, 3)
148 | else:
149 | raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
150 |
151 | out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
152 |
153 | return out
154 | else:
155 | # used for lumina
156 | x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
157 | freqs_cis = freqs_cis.unsqueeze(2)
158 | x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
159 |
160 | return x_out.type_as(x)
161 |
162 |
163 | def get_same_padding(kernel_size: Union[int, tuple[int, ...]]) -> Union[int, tuple[int, ...]]:
164 | if isinstance(kernel_size, tuple):
165 | return tuple([get_same_padding(ks) for ks in kernel_size])
166 | else:
167 | assert kernel_size % 2 > 0, f"kernel size {kernel_size} should be odd number"
168 | return kernel_size // 2
169 |
170 |
171 | def auto_grad_checkpoint(module, *args, **kwargs):
172 | if getattr(module, "grad_checkpointing", False):
173 | if isinstance(module, Iterable):
174 | gc_step = module[0].grad_checkpointing_step
175 | return checkpoint_sequential(module, gc_step, *args, **kwargs)
176 | else:
177 | return checkpoint(module, *args, **kwargs)
178 | return module(*args, **kwargs)
179 |
180 |
181 | def checkpoint_sequential(functions, step, input, *args, **kwargs):
182 |
183 | # Hack for keyword-only parameter in a python 2.7-compliant way
184 | preserve = kwargs.pop("preserve_rng_state", True)
185 | if kwargs:
186 | raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
187 |
188 | def run_function(start, end, functions):
189 | def forward(input):
190 | for j in range(start, end + 1):
191 | input = functions[j](input, *args)
192 | return input
193 |
194 | return forward
195 |
196 | if isinstance(functions, torch.nn.Sequential):
197 | functions = list(functions.children())
198 |
199 | # the last chunk has to be non-volatile
200 | end = -1
201 | segment = len(functions) // step
202 | for start in range(0, step * (segment - 1), step):
203 | end = start + step - 1
204 | input = checkpoint(run_function(start, end, functions), input, preserve_rng_state=preserve)
205 | return run_function(end + 1, len(functions) - 1, functions)(input)
206 |
207 |
208 | def _ntuple(n):
209 | def parse(x):
210 | if isinstance(x, Iterable) and not isinstance(x, str):
211 | return x
212 | return tuple(repeat(x, n))
213 |
214 | return parse
215 |
216 |
217 | to_1tuple = _ntuple(1)
218 | to_2tuple = _ntuple(2)
219 |
220 |
221 | class RMSNorm(torch.nn.Module):
222 | def __init__(self, dim: int, scale_factor=1.0, eps: float = 1e-6):
223 | """
224 | Initialize the RMSNorm normalization layer.
225 |
226 | Args:
227 | dim (int): The dimension of the input tensor.
228 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
229 |
230 | Attributes:
231 | eps (float): A small value added to the denominator for numerical stability.
232 | weight (nn.Parameter): Learnable scaling parameter.
233 |
234 | """
235 | super().__init__()
236 | self.eps = eps
237 | self.weight = nn.Parameter(torch.ones(dim) * scale_factor)
238 |
239 | def _norm(self, x):
240 | """
241 | Apply the RMSNorm normalization to the input tensor.
242 |
243 | Args:
244 | x (torch.Tensor): The input tensor.
245 |
246 | Returns:
247 | torch.Tensor: The normalized tensor.
248 |
249 | """
250 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
251 |
252 | def forward(self, x):
253 | """
254 | Forward pass through the RMSNorm layer.
255 |
256 | Args:
257 | x (torch.Tensor): The input tensor.
258 |
259 | Returns:
260 | torch.Tensor: The output tensor after applying RMSNorm.
261 |
262 | """
263 | return (self.weight * self._norm(x.float())).type_as(x)
264 |
265 |
266 |
267 | class ConvLayer(nn.Module):
268 | def __init__(
269 | self,
270 | in_dim: int,
271 | out_dim: int,
272 | kernel_size=3,
273 | stride=1,
274 | dilation=1,
275 | groups=1,
276 | padding: Union[int, None] = None,
277 | use_bias=False,
278 | dropout=0.0,
279 | norm="bn2d",
280 | act="relu",
281 | ):
282 | super().__init__()
283 | if padding is None:
284 | padding = get_same_padding(kernel_size)
285 | padding *= dilation
286 |
287 | self.in_dim = in_dim
288 | self.out_dim = out_dim
289 | self.kernel_size = kernel_size
290 | self.stride = stride
291 | self.dilation = dilation
292 | self.groups = groups
293 | self.padding = padding
294 | self.use_bias = use_bias
295 |
296 | self.dropout = nn.Dropout2d(dropout, inplace=False) if dropout > 0 else None
297 | self.conv = nn.Conv2d(
298 | in_dim,
299 | out_dim,
300 | kernel_size=(kernel_size, kernel_size),
301 | stride=(stride, stride),
302 | padding=padding,
303 | dilation=(dilation, dilation),
304 | groups=groups,
305 | bias=use_bias,
306 | )
307 | self.norm = build_norm(norm, num_features=out_dim)
308 | self.act = build_act(act)
309 |
310 | def forward(self, x: torch.Tensor) -> torch.Tensor:
311 | if self.dropout is not None:
312 | x = self.dropout(x)
313 | x = self.conv(x)
314 | if self.norm:
315 | x = self.norm(x)
316 | if self.act:
317 | x = self.act(x)
318 | return x
319 |
320 |
321 | class GLUMBConv(nn.Module):
322 | def __init__(
323 | self,
324 | in_features: int,
325 | hidden_features: int,
326 | out_feature=None,
327 | kernel_size=3,
328 | stride=1,
329 | padding: Union[int, None] = None,
330 | use_bias=False,
331 | norm=(None, None, None),
332 | act=("silu", "silu", None),
333 | dilation=1,
334 | ):
335 | out_feature = out_feature or in_features
336 | super().__init__()
337 | use_bias = val2tuple(use_bias, 3)
338 | norm = val2tuple(norm, 3)
339 | act = val2tuple(act, 3)
340 |
341 | self.glu_act = build_act(act[1], inplace=False)
342 | self.inverted_conv = ConvLayer(
343 | in_features,
344 | hidden_features * 2,
345 | 1,
346 | use_bias=use_bias[0],
347 | norm=norm[0],
348 | act=act[0],
349 | )
350 | self.depth_conv = ConvLayer(
351 | hidden_features * 2,
352 | hidden_features * 2,
353 | kernel_size,
354 | stride=stride,
355 | groups=hidden_features * 2,
356 | padding=padding,
357 | use_bias=use_bias[1],
358 | norm=norm[1],
359 | act=None,
360 | dilation=dilation,
361 | )
362 | self.point_conv = ConvLayer(
363 | hidden_features,
364 | out_feature,
365 | 1,
366 | use_bias=use_bias[2],
367 | norm=norm[2],
368 | act=act[2],
369 | )
370 | # from IPython import embed; embed(header='debug dilate conv')
371 |
372 | def forward(self, x: torch.Tensor, HW=None) -> torch.Tensor:
373 | B, N, C = x.shape
374 | if HW is None:
375 | H = W = int(N**0.5)
376 | else:
377 | H, W = HW
378 |
379 | x = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
380 | x = self.inverted_conv(x)
381 | x = self.depth_conv(x)
382 |
383 | x, gate = torch.chunk(x, 2, dim=1)
384 | gate = self.glu_act(gate)
385 | x = x * gate
386 |
387 | x = self.point_conv(x)
388 | x = x.reshape(B, C, N).permute(0, 2, 1)
389 |
390 | return x
391 |
392 |
393 | class DWMlp(Mlp):
394 | """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
395 |
396 | def __init__(
397 | self,
398 | in_features,
399 | hidden_features=None,
400 | out_features=None,
401 | act_layer=nn.GELU,
402 | bias=True,
403 | drop=0.0,
404 | kernel_size=3,
405 | stride=1,
406 | dilation=1,
407 | padding=None,
408 | ):
409 | super().__init__(
410 | in_features=in_features,
411 | hidden_features=hidden_features,
412 | out_features=out_features,
413 | act_layer=act_layer,
414 | bias=bias,
415 | drop=drop,
416 | )
417 | hidden_features = hidden_features or in_features
418 | self.hidden_features = hidden_features
419 | if padding is None:
420 | padding = get_same_padding(kernel_size)
421 | padding *= dilation
422 |
423 | self.conv = nn.Conv2d(
424 | hidden_features,
425 | hidden_features,
426 | kernel_size=(kernel_size, kernel_size),
427 | stride=(stride, stride),
428 | padding=padding,
429 | dilation=(dilation, dilation),
430 | groups=hidden_features,
431 | bias=bias,
432 | )
433 |
434 | def forward(self, x, HW=None):
435 | B, N, C = x.shape
436 | if HW is None:
437 | H = W = int(N**0.5)
438 | else:
439 | H, W = HW
440 | x = self.fc1(x)
441 | x = self.act(x)
442 | x = self.drop1(x)
443 | x = x.reshape(B, H, W, self.hidden_features).permute(0, 3, 1, 2)
444 | x = self.conv(x)
445 | x = x.reshape(B, self.hidden_features, N).permute(0, 2, 1)
446 | x = self.fc2(x)
447 | x = self.drop2(x)
448 | return x
449 |
--------------------------------------------------------------------------------
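
For reference, GLUMBConv above consumes token sequences of shape (B, N, C) with N = H * W, expands channels with a 1x1 convolution, applies a depthwise convolution, gates half the channels (GLU-style), and projects back to the output width. A minimal usage sketch of the registry helpers and the block, assuming arbitrary dimensions and that the repository root is on PYTHONPATH so models.basic_modules is importable:

import torch
from models.basic_modules import build_act, build_norm, GLUMBConv

act = build_act("silu")                          # nn.SiLU(inplace=True)
norm = build_norm("ln2d", num_features=64)       # channel-wise LayerNorm2d

block = GLUMBConv(in_features=64, hidden_features=128)
tokens = torch.randn(2, 16 * 16, 64)             # (B, N, C) with N = H * W
out = block(tokens, HW=(16, 16))
print(out.shape)                                 # torch.Size([2, 256, 64])
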
/datasets/video_transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import numbers
4 | import numpy as np
5 | from PIL import Image
6 | def _is_tensor_video_clip(clip):
7 | if not torch.is_tensor(clip):
8 | raise TypeError("clip should be Tensor. Got %s" % type(clip))
9 |
10 | if not clip.ndimension() == 4:
11 | raise ValueError("clip should be 4D. Got %dD" % clip.dim())
12 |
13 | return True
14 |
15 |
16 | def center_crop_arr(pil_image, image_size):
17 | """
18 | Center cropping implementation from ADM.
19 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
20 | """
21 | while min(*pil_image.size) >= 2 * image_size:
22 | pil_image = pil_image.resize(
23 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
24 | )
25 |
26 | scale = image_size / min(*pil_image.size)
27 | pil_image = pil_image.resize(
28 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
29 | )
30 |
31 | arr = np.array(pil_image)
32 | crop_y = (arr.shape[0] - image_size) // 2
33 | crop_x = (arr.shape[1] - image_size) // 2
34 | return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
35 |
36 |
37 | def crop(clip, i, j, h, w):
38 | """
39 | Args:
40 | clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
41 | """
42 | if len(clip.size()) != 4:
43 | raise ValueError("clip should be a 4D tensor")
44 | return clip[..., i : i + h, j : j + w]
45 |
46 |
47 | def resize(clip, target_size, interpolation_mode):
48 | if len(target_size) != 2:
49 | raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
50 | return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
51 |
52 | def resize_scale(clip, target_size, interpolation_mode):
53 | if len(target_size) != 2:
54 | raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
55 | H, W = clip.size(-2), clip.size(-1)
56 | scale_ = target_size[0] / min(H, W)
57 | new_H = int(round(H * scale_))
58 | new_W = int(round(W * scale_))
59 | return torch.nn.functional.interpolate(clip, size=(new_H, new_W), mode=interpolation_mode, align_corners=False)
60 |
61 |
62 | def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
63 | """
64 | Do spatial cropping and resizing to the video clip
65 | Args:
66 | clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
67 | i (int): i in (i,j) i.e coordinates of the upper left corner.
68 | j (int): j in (i,j) i.e coordinates of the upper left corner.
69 | h (int): Height of the cropped region.
70 | w (int): Width of the cropped region.
71 | size (tuple(int, int)): height and width of resized clip
72 | Returns:
73 | clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
74 | """
75 | if not _is_tensor_video_clip(clip):
76 | raise ValueError("clip should be a 4D torch.tensor")
77 | clip = crop(clip, i, j, h, w)
78 | clip = resize(clip, size, interpolation_mode)
79 | return clip
80 |
81 |
82 | def center_crop(clip, crop_size):
83 | if not _is_tensor_video_clip(clip):
84 | raise ValueError("clip should be a 4D torch.tensor")
85 | h, w = clip.size(-2), clip.size(-1)
86 | th, tw = crop_size
87 | if h < th or w < tw:
88 | raise ValueError(f"height {h} and width {w} must be no smaller than crop_size ({th, tw})")
89 |
90 | i = int(round((h - th) / 2.0))
91 | j = int(round((w - tw) / 2.0))
92 | return crop(clip, i, j, th, tw)
93 |
94 |
95 | def center_crop_using_short_edge(clip):
96 | if not _is_tensor_video_clip(clip):
97 | raise ValueError("clip should be a 4D torch.tensor")
98 | h, w = clip.size(-2), clip.size(-1)
99 | if h < w:
100 | th, tw = h, h
101 | i = 0
102 | j = int(round((w - tw) / 2.0))
103 | else:
104 | th, tw = w, w
105 | i = int(round((h - th) / 2.0))
106 | j = 0
107 | return crop(clip, i, j, th, tw)
108 |
109 |
110 | def random_shift_crop(clip):
111 | '''
112 | Slide along the long edge, with the short edge as crop size
113 | '''
114 | if not _is_tensor_video_clip(clip):
115 | raise ValueError("clip should be a 4D torch.tensor")
116 | h, w = clip.size(-2), clip.size(-1)
117 |
118 | if h <= w:
119 | long_edge = w
120 | short_edge = h
121 | else:
122 | long_edge = h
123 | short_edge = w
124 |
125 | th, tw = short_edge, short_edge
126 |
127 | i = torch.randint(0, h - th + 1, size=(1,)).item()
128 | j = torch.randint(0, w - tw + 1, size=(1,)).item()
129 | return crop(clip, i, j, th, tw)
130 |
131 |
132 | def to_tensor(clip):
133 | """
134 | Convert tensor data type from uint8 to float, divide value by 255.0 and
135 | permute the dimensions of clip tensor
136 | Args:
137 | clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
138 | Return:
139 | clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
140 | """
141 | _is_tensor_video_clip(clip)
142 | if not clip.dtype == torch.uint8:
143 | raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
144 | # return clip.float().permute(3, 0, 1, 2) / 255.0
145 | return clip.float() / 255.0
146 |
147 |
148 | def normalize(clip, mean, std, inplace=False):
149 | """
150 | Args:
151 | clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
152 | mean (tuple): pixel RGB mean. Size is (3)
153 | std (tuple): pixel standard deviation. Size is (3)
154 | Returns:
155 | normalized clip (torch.tensor): Size is (C, T, H, W)
156 | """
157 | if not _is_tensor_video_clip(clip):
158 | raise ValueError("clip should be a 4D torch.tensor")
159 | if not inplace:
160 | clip = clip.clone()
161 | mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
162 | # print(mean)
163 | std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
164 | clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
165 | return clip
166 |
167 |
168 | def hflip(clip):
169 | """
170 | Args:
171 | clip (torch.tensor): Video clip to be flipped. Size is (T, C, H, W)
172 | Returns:
173 | flipped clip (torch.tensor): Size is (T, C, H, W)
174 | """
175 | if not _is_tensor_video_clip(clip):
176 | raise ValueError("clip should be a 4D torch.tensor")
177 | return clip.flip(-1)
178 |
179 |
180 | class RandomCropVideo:
181 | def __init__(self, size):
182 | if isinstance(size, numbers.Number):
183 | self.size = (int(size), int(size))
184 | else:
185 | self.size = size
186 |
187 | def __call__(self, clip):
188 | """
189 | Args:
190 | clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
191 | Returns:
192 | torch.tensor: randomly cropped video clip.
193 | size is (T, C, OH, OW)
194 | """
195 | i, j, h, w = self.get_params(clip)
196 | return crop(clip, i, j, h, w)
197 |
198 | def get_params(self, clip):
199 | h, w = clip.shape[-2:]
200 | th, tw = self.size
201 |
202 | if h < th or w < tw:
203 | raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
204 |
205 | if w == tw and h == th:
206 | return 0, 0, h, w
207 |
208 | i = torch.randint(0, h - th + 1, size=(1,)).item()
209 | j = torch.randint(0, w - tw + 1, size=(1,)).item()
210 |
211 | return i, j, th, tw
212 |
213 | def __repr__(self) -> str:
214 | return f"{self.__class__.__name__}(size={self.size})"
215 |
216 | class CenterCropResizeVideo:
217 | '''
218 | First center-crop the video using the short side as the crop size,
219 | then resize it to the specified size
220 | '''
221 | def __init__(
222 | self,
223 | size,
224 | interpolation_mode="bilinear",
225 | ):
226 | if isinstance(size, tuple):
227 | if len(size) != 2:
228 | raise ValueError(f"size should be tuple (height, width), instead got {size}")
229 | self.size = size
230 | else:
231 | self.size = (size, size)
232 |
233 | self.interpolation_mode = interpolation_mode
234 |
235 |
236 | def __call__(self, clip):
237 | """
238 | Args:
239 | clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
240 | Returns:
241 | torch.tensor: scale resized / center cropped video clip.
242 | size is (T, C, crop_size, crop_size)
243 | """
244 | clip_center_crop = center_crop_using_short_edge(clip)
245 | clip_center_crop_resize = resize(clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode)
246 | return clip_center_crop_resize
247 |
248 | def __repr__(self) -> str:
249 | return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
250 |
251 | class UCFCenterCropVideo:
252 | '''
253 | First scale the video proportionally so that its short edge matches the specified size,
254 | then center crop
255 | '''
256 | def __init__(
257 | self,
258 | size,
259 | interpolation_mode="bilinear",
260 | ):
261 | if isinstance(size, tuple):
262 | if len(size) != 2:
263 | raise ValueError(f"size should be tuple (height, width), instead got {size}")
264 | self.size = size
265 | else:
266 | self.size = (size, size)
267 |
268 | self.interpolation_mode = interpolation_mode
269 |
270 |
271 | def __call__(self, clip):
272 | """
273 | Args:
274 | clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
275 | Returns:
276 | torch.tensor: scale resized / center cropped video clip.
277 | size is (T, C, crop_size, crop_size)
278 | """
279 | clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
280 | clip_center_crop = center_crop(clip_resize, self.size)
281 | return clip_center_crop
282 |
283 | def __repr__(self) -> str:
284 | return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
285 |
286 | class KineticsRandomCropResizeVideo:
287 | '''
288 | Slide along the long edge, with the short edge as crop size, then resize to the desired size.
289 | '''
290 | def __init__(
291 | self,
292 | size,
293 | interpolation_mode="bilinear",
294 | ):
295 | if isinstance(size, tuple):
296 | if len(size) != 2:
297 | raise ValueError(f"size should be tuple (height, width), instead got {size}")
298 | self.size = size
299 | else:
300 | self.size = (size, size)
301 |
302 | self.interpolation_mode = interpolation_mode
303 |
304 | def __call__(self, clip):
305 | clip_random_crop = random_shift_crop(clip)
306 | clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode)
307 | return clip_resize
308 |
309 |
310 | class CenterCropVideo:
311 | def __init__(
312 | self,
313 | size,
314 | interpolation_mode="bilinear",
315 | ):
316 | if isinstance(size, tuple):
317 | if len(size) != 2:
318 | raise ValueError(f"size should be tuple (height, width), instead got {size}")
319 | self.size = size
320 | else:
321 | self.size = (size, size)
322 |
323 | self.interpolation_mode = interpolation_mode
324 |
325 |
326 | def __call__(self, clip):
327 | """
328 | Args:
329 | clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
330 | Returns:
331 | torch.tensor: center cropped video clip.
332 | size is (T, C, crop_size, crop_size)
333 | """
334 | clip_center_crop = center_crop(clip, self.size)
335 | return clip_center_crop
336 |
337 | def __repr__(self) -> str:
338 | return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
339 |
340 |
341 | class NormalizeVideo:
342 | """
343 | Normalize the video clip by mean subtraction and division by standard deviation
344 | Args:
345 | mean (3-tuple): pixel RGB mean
346 | std (3-tuple): pixel RGB standard deviation
347 | inplace (boolean): whether do in-place normalization
348 | """
349 |
350 | def __init__(self, mean, std, inplace=False):
351 | self.mean = mean
352 | self.std = std
353 | self.inplace = inplace
354 |
355 | def __call__(self, clip):
356 | """
357 | Args:
358 | clip (torch.tensor): video clip to be normalized. Size is (C, T, H, W)
359 | """
360 | return normalize(clip, self.mean, self.std, self.inplace)
361 |
362 | def __repr__(self) -> str:
363 | return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
364 |
365 |
366 | class ToTensorVideo:
367 | """
368 | Convert tensor data type from uint8 to float, divide value by 255.0 and
369 | permute the dimensions of clip tensor
370 | """
371 |
372 | def __init__(self):
373 | pass
374 |
375 | def __call__(self, clip):
376 | """
377 | Args:
378 | clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
379 | Return:
380 | clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
381 | """
382 | return to_tensor(clip)
383 |
384 | def __repr__(self) -> str:
385 | return self.__class__.__name__
386 |
387 |
388 | class RandomHorizontalFlipVideo:
389 | """
390 | Flip the video clip along the horizontal direction with a given probability
391 | Args:
392 | p (float): probability of the clip being flipped. Default value is 0.5
393 | """
394 |
395 | def __init__(self, p=0.5):
396 | self.p = p
397 |
398 | def __call__(self, clip):
399 | """
400 | Args:
401 | clip (torch.tensor): Size is (T, C, H, W)
402 | Return:
403 | clip (torch.tensor): Size is (T, C, H, W)
404 | """
405 | if random.random() < self.p:
406 | clip = hflip(clip)
407 | return clip
408 |
409 | def __repr__(self) -> str:
410 | return f"{self.__class__.__name__}(p={self.p})"
411 |
412 | # ------------------------------------------------------------
413 | # --------------------- Sampling ---------------------------
414 | # ------------------------------------------------------------
415 | class TemporalRandomCrop(object):
416 | """Temporally crop the given frame indices at a random location.
417 |
418 | Args:
419 | size (int): Desired number of frames to be seen by the model.
420 | """
421 |
422 | def __init__(self, size):
423 | self.size = size
424 |
425 | def __call__(self, total_frames):
426 | rand_end = max(0, total_frames - self.size - 1)
427 | begin_index = random.randint(0, rand_end)
428 | end_index = min(begin_index + self.size, total_frames)
429 | return begin_index, end_index
430 |
431 |
432 | if __name__ == '__main__':
433 | from torchvision import transforms
434 | import torchvision.io as io
435 | import numpy as np
436 | from torchvision.utils import save_image
437 | import os
438 |
439 | vframes, aframes, info = io.read_video(
440 | filename='./v_Archery_g01_c03.avi',
441 | pts_unit='sec',
442 | output_format='TCHW'
443 | )
444 |
445 | trans = transforms.Compose([
446 | ToTensorVideo(),
447 | RandomHorizontalFlipVideo(),
448 | UCFCenterCropVideo(512),
449 | # NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
450 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
451 | ])
452 |
453 | target_video_len = 32
454 | frame_interval = 1
455 | total_frames = len(vframes)
456 | print(total_frames)
457 |
458 | temporal_sample = TemporalRandomCrop(target_video_len * frame_interval)
459 |
460 |
461 | # Sampling video frames
462 | start_frame_ind, end_frame_ind = temporal_sample(total_frames)
463 | # print(start_frame_ind)
464 | # print(end_frame_ind)
465 | assert end_frame_ind - start_frame_ind >= target_video_len
466 | frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int)
467 | print(frame_indice)
468 |
469 | select_vframes = vframes[frame_indice]
470 | print(select_vframes.shape)
471 | print(select_vframes.dtype)
472 |
473 | select_vframes_trans = trans(select_vframes)
474 | print(select_vframes_trans.shape)
475 | print(select_vframes_trans.dtype)
476 |
477 | select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8)
478 | print(select_vframes_trans_int.dtype)
479 | print(select_vframes_trans_int.permute(0, 2, 3, 1).shape)
480 |
481 | io.write_video('./test.avi', select_vframes_trans_int.permute(0, 2, 3, 1), fps=8)
482 |
483 | for i in range(target_video_len):
484 | save_image(select_vframes_trans[i], os.path.join('./test000', '%04d.png' % i), normalize=True, value_range=(-1, 1))
--------------------------------------------------------------------------------
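
For reference, these transforms are composed much as in the __main__ block above: sample a temporal window with TemporalRandomCrop, then convert, flip, and crop-resize the (T, C, H, W) clip. A minimal sketch on a random uint8 tensor standing in for decoded frames, assuming the repository root is on PYTHONPATH so datasets.video_transforms is importable:

import torch
import numpy as np
from torchvision import transforms
from datasets.video_transforms import (
    ToTensorVideo, RandomHorizontalFlipVideo, CenterCropResizeVideo, TemporalRandomCrop
)

video = torch.randint(0, 256, (120, 3, 360, 480), dtype=torch.uint8)   # stand-in decoded clip

temporal_sample = TemporalRandomCrop(16)                 # window of 16 frame indices
start, end = temporal_sample(video.shape[0])
indices = np.linspace(start, end - 1, 16, dtype=int)

transform = transforms.Compose([
    ToTensorVideo(),                                     # uint8 -> float in [0, 1]
    RandomHorizontalFlipVideo(),
    CenterCropResizeVideo(256),                          # short-edge center crop, then resize
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])
clip = transform(video[indices])
print(clip.shape)                                        # torch.Size([16, 3, 256, 256])
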
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import torch
4 | import logging
5 | import random
6 | import subprocess
7 | import numpy as np
8 | import torch.distributed as dist
9 |
10 | from torch import inf
11 | from PIL import Image
12 | from typing import Union, Iterable
13 | from collections import OrderedDict
14 | from torch.utils.tensorboard import SummaryWriter
15 |
16 | from diffusers.utils import is_bs4_available, is_ftfy_available
17 |
18 | import html
19 | import re
20 | import urllib.parse as ul
21 | from moviepy.editor import VideoFileClip, AudioFileClip, VideoClip
22 |
23 | if is_bs4_available():
24 | from bs4 import BeautifulSoup
25 |
26 | if is_ftfy_available():
27 | import ftfy
28 |
29 | _tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]
30 |
31 |
32 | #################################################################################
33 | # Training Clip Gradients #
34 | #################################################################################
35 |
36 | def get_grad_norm(
37 | parameters: _tensor_or_tensors, norm_type: float = 2.0) -> torch.Tensor:
38 | r"""
39 | Adapted from torch.nn.utils.clip_grad_norm_.
40 |
41 | Computes the gradient norm of an iterable of parameters without clipping.
42 |
43 | The norm is computed over all gradients together, as if they were
44 | concatenated into a single vector. Gradients are not modified.
45 |
46 | Args:
47 | parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
48 | single Tensor whose gradient norm will be computed
49 | norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
50 | infinity norm.
55 |
56 | Returns:
57 | Total norm of the parameter gradients (viewed as a single vector).
58 | """
59 | if isinstance(parameters, torch.Tensor):
60 | parameters = [parameters]
61 | grads = [p.grad for p in parameters if p.grad is not None]
62 | norm_type = float(norm_type)
63 | if len(grads) == 0:
64 | return torch.tensor(0.)
65 | device = grads[0].device
66 | if norm_type == inf:
67 | norms = [g.detach().abs().max().to(device) for g in grads]
68 | total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
69 | else:
70 | total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type)
71 | return total_norm
72 |
73 | def clip_grad_norm_(
74 | parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0,
75 | error_if_nonfinite: bool = False, clip_grad = True) -> torch.Tensor:
76 | r"""
77 | Copy from torch.nn.utils.clip_grad_norm_
78 |
79 | Clips gradient norm of an iterable of parameters.
80 |
81 | The norm is computed over all gradients together, as if they were
82 | concatenated into a single vector. Gradients are modified in-place.
83 |
84 | Args:
85 | parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
86 | single Tensor that will have gradients normalized
87 | max_norm (float or int): max norm of the gradients
88 | norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
89 | infinity norm.
90 | error_if_nonfinite (bool): if True, an error is thrown if the total
91 | norm of the gradients from :attr:`parameters` is ``nan``,
92 | ``inf``, or ``-inf``. Default: False (will switch to True in the future)
93 | clip_grad (bool): if False, only compute and return the gradient norm without clipping. Default: True
94 | Returns:
95 | Total norm of the parameter gradients (viewed as a single vector).
96 | """
97 | if isinstance(parameters, torch.Tensor):
98 | parameters = [parameters]
99 | grads = [p.grad for p in parameters if p.grad is not None]
100 | max_norm = float(max_norm)
101 | norm_type = float(norm_type)
102 | if len(grads) == 0:
103 | return torch.tensor(0.)
104 | device = grads[0].device
105 | if norm_type == inf:
106 | norms = [g.detach().abs().max().to(device) for g in grads]
107 | total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
108 | else:
109 | total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type)
110 |
111 | if clip_grad:
112 | if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
113 | raise RuntimeError(
114 | f'The total norm of order {norm_type} for gradients from '
115 | '`parameters` is non-finite, so it cannot be clipped. To disable '
116 | 'this error and scale the gradients by the non-finite norm anyway, '
117 | 'set `error_if_nonfinite=False`')
118 | clip_coef = max_norm / (total_norm + 1e-6)
119 | # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
120 | # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
121 | # when the gradients do not reside in CPU memory.
122 | clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
123 | for g in grads:
124 | g.detach().mul_(clip_coef_clamped.to(g.device))
125 | # gradient_cliped = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type)
126 | return total_norm
127 |
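# Editor's sketch (not part of the original file): typical use of clip_grad_norm_
# inside a training step. `model`, `optimizer`, and `loss` are assumed to be
# supplied by the caller; max_norm=1.0 is just an example value.
def _example_clip_grad_step(model, optimizer, loss, max_norm=1.0):
    loss.backward()
    # returns the pre-clipping total norm, which is handy for logging
    grad_norm = clip_grad_norm_(model.parameters(), max_norm=max_norm, norm_type=2.0)
    optimizer.step()
    optimizer.zero_grad()
    return grad_norm
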
128 | def get_experiment_dir(root_dir, args):
129 | # if args.pretrained is not None and 'Latte-XL-2-256x256.pt' not in args.pretrained:
130 | # root_dir += '-WOPRE'
131 | if args.use_compile:
132 | root_dir += '-Compile' # speedup by torch compile
133 | if args.fixed_spatial:
134 | root_dir += '-FixedSpa'
135 | if args.enable_xformers_memory_efficient_attention:
136 | root_dir += '-Xfor'
137 | if args.gradient_checkpointing:
138 | root_dir += '-Gc'
139 | if args.mixed_precision:
140 | root_dir += '-Amp'
141 | if args.image_size == 512:
142 | root_dir += '-512'
143 | return root_dir
144 |
145 | #################################################################################
146 | # Training Logger #
147 | #################################################################################
148 |
149 | def create_logger(logging_dir):
150 | """
151 | Create a logger that writes to a log file and stdout.
152 | """
153 | if dist.get_rank() == 0: # real logger
154 | logging.basicConfig(
155 | level=logging.INFO,
156 | # format='[\033[34m%(asctime)s\033[0m] %(message)s',
157 | format='[%(asctime)s] %(message)s',
158 | datefmt='%Y-%m-%d %H:%M:%S',
159 | handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
160 | )
161 | logger = logging.getLogger(__name__)
162 |
163 | else: # dummy logger (does nothing)
164 | logger = logging.getLogger(__name__)
165 | logger.addHandler(logging.NullHandler())
166 | return logger
167 |
168 |
169 | def create_tensorboard(tensorboard_dir):
170 | """
171 | Create a tensorboard that saves losses.
172 | """
173 |     if dist.get_rank() == 0:  # real tensorboard
174 |         writer = SummaryWriter(tensorboard_dir)
175 |     else:
176 |         writer = None  # non-zero ranks do not log to TensorBoard
177 |     return writer
178 |
179 | def write_tensorboard(writer, *args):
180 | '''
181 |     Write a scalar (tag, value, step) to the TensorBoard file.
182 |     Only for PyTorch DDP mode; no-op on non-zero ranks.
183 | '''
184 | if dist.get_rank() == 0: # real tensorboard
185 | writer.add_scalar(args[0], args[1], args[2])
186 |
187 | #################################################################################
188 | # EMA Update/ DDP Training Utils #
189 | #################################################################################
190 |
191 | @torch.no_grad()
192 | def update_ema(ema_model, model, decay=0.9999):
193 | """
194 | Step the EMA model towards the current model.
195 | """
196 | ema_params = OrderedDict(ema_model.named_parameters())
197 | model_params = OrderedDict(model.named_parameters())
198 |
199 | for name, param in model_params.items():
200 | # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
201 | ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
202 |
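# Editor's sketch (not part of the original file): how update_ema is typically
# initialised and used in a training loop; `create_model` is a placeholder for
# whatever builds the network.
def _example_ema_setup(create_model, device="cuda"):
    from copy import deepcopy
    model = create_model().to(device)
    ema = deepcopy(model).to(device)   # EMA weights live in a frozen copy of the model
    requires_grad(ema, False)          # the EMA model is never optimized directly
    update_ema(ema, model, decay=0)    # decay=0 copies the weights, syncing EMA at step 0
    ema.eval()
    return model, ema
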
203 | def requires_grad(model, flag=True):
204 | """
205 | Set requires_grad flag for all parameters in a model.
206 | """
207 | for p in model.parameters():
208 | p.requires_grad = flag
209 |
210 | def cleanup():
211 | """
212 | End DDP training.
213 | """
214 | dist.destroy_process_group()
215 |
216 |
217 | def setup_distributed(backend="nccl", port=None):
218 | """Initialize distributed training environment.
219 |     Supports both SLURM and torch.distributed.launch;
220 |     see torch.distributed.init_process_group() for more details.
221 | """
222 | num_gpus = torch.cuda.device_count()
223 |
224 | if "SLURM_JOB_ID" in os.environ:
225 | rank = int(os.environ["SLURM_PROCID"])
226 | world_size = int(os.environ["SLURM_NTASKS"])
227 | node_list = os.environ["SLURM_NODELIST"]
228 | addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1")
229 | # specify master port
230 | if port is not None:
231 | os.environ["MASTER_PORT"] = str(port)
232 | elif "MASTER_PORT" not in os.environ:
233 | # os.environ["MASTER_PORT"] = "29566"
234 | os.environ["MASTER_PORT"] = str(29567 + num_gpus)
235 | if "MASTER_ADDR" not in os.environ:
236 | os.environ["MASTER_ADDR"] = addr
237 | os.environ["WORLD_SIZE"] = str(world_size)
238 | os.environ["LOCAL_RANK"] = str(rank % num_gpus)
239 | os.environ["RANK"] = str(rank)
240 | else:
241 | rank = int(os.environ["RANK"])
242 | world_size = int(os.environ["WORLD_SIZE"])
243 |
244 | # torch.cuda.set_device(rank % num_gpus)
245 |
246 | dist.init_process_group(
247 | backend=backend,
248 | world_size=world_size,
249 | rank=rank,
250 | )
251 |
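# Editor's note (sketch, not part of the original file): under torchrun /
# torch.distributed.launch the RANK, WORLD_SIZE and LOCAL_RANK variables are set
# by the launcher, so a training script can simply do the following at startup.
def _example_ddp_init():
    setup_distributed(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    return local_rank
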
252 | #################################################################################
253 | # Testing Utils #
254 | #################################################################################
255 |
256 | def save_video_grid(video, nrow=None):
257 | b, t, h, w, c = video.shape
258 |
259 | if nrow is None:
260 | nrow = math.ceil(math.sqrt(b))
261 | ncol = math.ceil(b / nrow)
262 | padding = 1
263 | video_grid = torch.zeros((t, (padding + h) * nrow + padding,
264 | (padding + w) * ncol + padding, c), dtype=torch.uint8)
265 |
266 | for i in range(b):
267 | r = i // ncol
268 | c = i % ncol
269 | start_r = (padding + h) * r
270 | start_c = (padding + w) * c
271 | video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i]
272 |
273 | return video_grid
274 |
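# Editor's note (worked example): for a batch of b=5 videos, nrow = ceil(sqrt(5)) = 3
# and ncol = ceil(5/3) = 2, so the returned grid has height (1+h)*3 + 1 and width
# (1+w)*2 + 1, with one pixel of padding between tiles.
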
275 | def find_model(model_name, use_ema=True):
276 | """
277 |     Loads a checkpoint from a local path and returns its EMA or raw model weights, plus the audio projector weights if present.
278 | """
279 | assert os.path.isfile(model_name), f'Could not find Latte checkpoint at {model_name}'
280 | checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
281 |
282 | # Check if 'audioproj' exists in the checkpoint before returning it
283 | audioproj_dict = checkpoint.get("audioproj", None)
284 |
285 | if use_ema: # supports checkpoints from train.py
286 | print('Using Ema!')
287 | checkpoint = checkpoint["ema"]
288 | else:
289 | print('Using model!')
290 | checkpoint = checkpoint['model']
291 |
292 | # Return the model and audioproj_dict if it exists, otherwise return the model alone
293 | return (checkpoint, audioproj_dict) if audioproj_dict is not None else (checkpoint,)
294 |
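# Editor's sketch (not part of the original file): find_model returns a 1- or 2-tuple
# depending on whether the checkpoint contains an "audioproj" entry, so callers should
# unpack it defensively. `ckpt_path` is a placeholder.
def _example_load_checkpoint(ckpt_path, use_ema=True):
    loaded = find_model(ckpt_path, use_ema=use_ema)
    state_dict = loaded[0]
    audioproj_dict = loaded[1] if len(loaded) > 1 else None
    return state_dict, audioproj_dict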
295 |
296 | #################################################################################
297 | # MMCV Utils #
298 | #################################################################################
299 |
300 |
301 | def collect_env():
302 | # Copyright (c) OpenMMLab. All rights reserved.
303 | from mmcv.utils import collect_env as collect_base_env
304 | from mmcv.utils import get_git_hash
305 | """Collect the information of the running environments."""
306 |
307 | env_info = collect_base_env()
308 | env_info['MMClassification'] = get_git_hash()[:7]
309 |
310 | for name, val in env_info.items():
311 | print(f'{name}: {val}')
312 |
313 | print(torch.cuda.get_arch_list())
314 | print(torch.version.cuda)
315 |
316 |
317 | #################################################################################
318 | # Pixart-alpha Utils #
319 | #################################################################################
320 |
321 | bad_punct_regex = re.compile(
322 | r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
323 | )
324 |
325 | def text_preprocessing(text, clean_caption=False):
326 | if clean_caption and not is_bs4_available():
327 | clean_caption = False
328 |
329 | if clean_caption and not is_ftfy_available():
330 | clean_caption = False
331 |
332 | if not isinstance(text, (tuple, list)):
333 | text = [text]
334 |
335 |     def process(text: str):
336 |         if clean_caption:
337 |             _clean = globals()["clean_caption"]  # the boolean flag shadows the module-level clean_caption()
338 |             text = _clean(_clean(text))          # applied twice, as in the diffusers original
339 |         else:
340 |             text = text.lower().strip()
341 |         return text
342 |
343 | return [process(t) for t in text]
344 |
345 | # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
346 | def clean_caption(caption):
347 | caption = str(caption)
348 | caption = ul.unquote_plus(caption)
349 | caption = caption.strip().lower()
350 |     caption = re.sub("<person>", "person", caption)
351 | # urls:
352 | caption = re.sub(
353 | r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
354 | "",
355 | caption,
356 | ) # regex for urls
357 | caption = re.sub(
358 | r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
359 | "",
360 | caption,
361 | ) # regex for urls
362 | # html:
363 | caption = BeautifulSoup(caption, features="html.parser").text
364 |
365 | # @
366 | caption = re.sub(r"@[\w\d]+\b", "", caption)
367 |
368 | # 31C0—31EF CJK Strokes
369 | # 31F0—31FF Katakana Phonetic Extensions
370 | # 3200—32FF Enclosed CJK Letters and Months
371 | # 3300—33FF CJK Compatibility
372 | # 3400—4DBF CJK Unified Ideographs Extension A
373 | # 4DC0—4DFF Yijing Hexagram Symbols
374 | # 4E00—9FFF CJK Unified Ideographs
375 | caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
376 | caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
377 | caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
378 | caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
379 | caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
380 | caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
381 | caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
382 | #######################################################
383 |
384 | # все виды тире / all types of dash --> "-"
385 | caption = re.sub(
386 | r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
387 | "-",
388 | caption,
389 | )
390 |
391 | # кавычки к одному стандарту
392 | caption = re.sub(r"[`´«»“”¨]", '"', caption)
393 | caption = re.sub(r"[‘’]", "'", caption)
394 |
395 |     # &quot;
396 |     caption = re.sub(r"&quot;?", "", caption)
397 |     # &amp
398 |     caption = re.sub(r"&amp", "", caption)
399 |
400 | # ip adresses:
401 | caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
402 |
403 | # article ids:
404 | caption = re.sub(r"\d:\d\d\s+$", "", caption)
405 |
406 | # \n
407 | caption = re.sub(r"\\n", " ", caption)
408 |
409 | # "#123"
410 | caption = re.sub(r"#\d{1,3}\b", "", caption)
411 | # "#12345.."
412 | caption = re.sub(r"#\d{5,}\b", "", caption)
413 | # "123456.."
414 | caption = re.sub(r"\b\d{6,}\b", "", caption)
415 | # filenames:
416 | caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
417 |
418 | #
419 | caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
420 | caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
421 |
422 | caption = re.sub(bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
423 | caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
424 |
425 | # this-is-my-cute-cat / this_is_my_cute_cat
426 | regex2 = re.compile(r"(?:\-|\_)")
427 | if len(re.findall(regex2, caption)) > 3:
428 | caption = re.sub(regex2, " ", caption)
429 |
430 | caption = ftfy.fix_text(caption)
431 | caption = html.unescape(html.unescape(caption))
432 |
433 | caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
434 | caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
435 | caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
436 |
437 | caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
438 | caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
439 | caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
440 | caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
441 | caption = re.sub(r"\bpage\s+\d+\b", "", caption)
442 |
443 | caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
444 |
445 | caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
446 |
447 | caption = re.sub(r"\b\s+\:\s+", r": ", caption)
448 | caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
449 | caption = re.sub(r"\s+", " ", caption)
450 |
451 | caption.strip()
452 |
453 | caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
454 | caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
455 | caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
456 | caption = re.sub(r"^\.\S+$", "", caption)
457 |
458 | return caption.strip()
459 |
460 |
461 | def combine_video_audio(video_path, audio_path, output_path):
462 | """
463 | Combine a video file and an audio file into one video with audio.
464 |
465 | Parameters:
466 | video_path (str): Path to the input video file (MP4).
467 | audio_path (str): Path to the input audio file (WAV).
468 | output_path (str): Path to save the output video file (MP4).
469 | """
470 | video_clip = VideoFileClip(video_path)
471 | audio_clip = AudioFileClip(audio_path)
472 |
473 | video_with_audio = video_clip.set_audio(audio_clip)
474 |
475 | video_with_audio.write_videofile(output_path, codec="libx264", audio_codec="aac")
476 |
477 | video_clip.close()
478 | audio_clip.close()
479 |
480 |
481 | def tensor_to_video(tensor, output_video_file, audio_source, fps=25):
482 | """
483 |     Converts a Tensor with shape [f, h, w, c] into a video and adds an audio track from the specified audio file.
484 |
485 | Args:
486 |         tensor (Tensor): The Tensor to be converted, shaped [f, h, w, c] with values in [0, 255].
487 | output_video_file (str): The file path where the output video will be saved.
488 | audio_source (str): The path to the audio file (WAV file) that contains the audio track to be added.
489 | fps (int): The frame rate of the output video. Default is 25 fps.
490 | """
491 | if isinstance(tensor, torch.Tensor):
492 | tensor = np.array(tensor).astype(np.uint8)
493 |
494 | def make_frame(t):
495 | # get index
496 | frame_index = min(int(t * fps), tensor.shape[0] - 1)
497 | return tensor[frame_index]
498 | new_video_clip = VideoClip(make_frame, duration=tensor.shape[0] / fps)
499 | audio_clip = AudioFileClip(audio_source).subclip(0, tensor.shape[0] / fps)
500 | new_video_clip = new_video_clip.set_audio(audio_clip)
501 | new_video_clip.write_videofile(output_video_file, fps=fps, codec="libx264", audio_codec='aac')
502 |
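# Editor's sketch (not part of the original file): tensor_to_video expects frames in
# [f, h, w, c] layout with values in [0, 255]. A clip generated as [f, c, h, w] floats
# in [-1, 1] could be converted roughly like this before calling it; names are
# placeholders.
def _example_prepare_frames(video_fchw):
    video = (video_fchw.clamp(-1, 1) + 1) * 127.5              # [-1, 1] -> [0, 255]
    video = video.permute(0, 2, 3, 1).to(torch.uint8).cpu()    # f c h w -> f h w c
    return video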
503 |
504 | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
505 | """
506 | Get a pre-defined beta schedule for the given name.
507 | The beta schedule library consists of beta schedules which remain similar
508 | in the limit of num_diffusion_timesteps.
509 | Beta schedules may be added, but should not be removed or changed once
510 | they are committed to maintain backwards compatibility.
511 | """
512 | if schedule_name == "linear":
513 | # Linear schedule from Ho et al, extended to work for any number of
514 | # diffusion steps.
515 | scale = 1000 / num_diffusion_timesteps
516 | return get_beta_schedule(
517 | "linear",
518 | beta_start=scale * 0.0001,
519 | beta_end=scale * 0.02,
520 | num_diffusion_timesteps=num_diffusion_timesteps,
521 | )
522 | elif schedule_name == "squaredcos_cap_v2":
523 | return betas_for_alpha_bar(
524 | num_diffusion_timesteps,
525 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
526 | )
527 | else:
528 | raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
529 |
530 |
531 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
532 | """
533 | This is the deprecated API for creating beta schedules.
534 | See get_named_beta_schedule() for the new library of schedules.
535 | """
536 | if beta_schedule == "quad":
537 | betas = (
538 | np.linspace(
539 | beta_start ** 0.5,
540 | beta_end ** 0.5,
541 | num_diffusion_timesteps,
542 | dtype=np.float64,
543 | )
544 | ** 2
545 | )
546 | elif beta_schedule == "linear":
547 | betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
548 | elif beta_schedule == "warmup10":
549 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
550 | elif beta_schedule == "warmup50":
551 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
552 | elif beta_schedule == "const":
553 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
554 | elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
555 | betas = 1.0 / np.linspace(
556 | num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
557 | )
558 | else:
559 | raise NotImplementedError(beta_schedule)
560 | assert betas.shape == (num_diffusion_timesteps,)
561 | return betas
562 |
563 |
564 | def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
565 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
566 | warmup_time = int(num_diffusion_timesteps * warmup_frac)
567 | betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
568 | return betas
569 |
570 |
571 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
572 | """
573 | Create a beta schedule that discretizes the given alpha_t_bar function,
574 | which defines the cumulative product of (1-beta) over time from t = [0,1].
575 | :param num_diffusion_timesteps: the number of betas to produce.
576 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
577 | produces the cumulative product of (1-beta) up to that
578 | part of the diffusion process.
579 | :param max_beta: the maximum beta to use; use values lower than 1 to
580 | prevent singularities.
581 | """
582 | betas = []
583 | for i in range(num_diffusion_timesteps):
584 | t1 = i / num_diffusion_timesteps
585 | t2 = (i + 1) / num_diffusion_timesteps
586 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
587 | return np.array(betas)
588 |
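# Editor's sketch (not part of the original file): a quick numerical check of the two
# named schedules; with 1000 steps the "linear" schedule runs from 1e-4 to 2e-2, and
# "squaredcos_cap_v2" is capped at max_beta=0.999.
def _example_beta_schedules(num_steps=1000):
    linear = get_named_beta_schedule("linear", num_steps)
    cosine = get_named_beta_schedule("squaredcos_cap_v2", num_steps)
    assert linear.shape == (num_steps,) and cosine.shape == (num_steps,)
    return float(linear[0]), float(linear[-1]), float(cosine.max())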
--------------------------------------------------------------------------------
/models/model_cat.py:
--------------------------------------------------------------------------------
1 | # All rights reserved.
2 |
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | # --------------------------------------------------------
6 | # References:
7 | # GLIDE: https://github.com/openai/glide-text2im
8 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
9 | # --------------------------------------------------------
10 | import math
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | import numpy as np
15 | from torch.utils.checkpoint import checkpoint
16 | from einops import rearrange, repeat
17 | from timm.models.vision_transformer import Mlp, PatchEmbed
18 |
19 | # the xformers lib enables lower memory usage and faster training and inference
20 | try:
21 |     import xformers
22 |     import xformers.ops
23 | except ImportError:
24 |     XFORMERS_IS_AVAILABLE = False  # optional; the 'math' and 'flash' attention modes still work
25 |
26 | # from timm.models.layers.helpers import to_2tuple
27 | # from timm.models.layers.trace_utils import _assert
28 |
29 | def modulate(x, shift, scale, T):
30 | N, M = x.shape[-2], x.shape[-1]
31 | B = scale.shape[0]
32 | x = rearrange(x, '(b t) n m-> b (t n) m',b=B,t=T,n=N,m=M)
33 | x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
34 | x = rearrange(x, 'b (t n) m-> (b t) n m',b=B,t=T,n=N,m=M)
35 | return x
36 |
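# Editor's sketch (not part of the original file): modulate() takes tokens flattened
# over frames, x of shape ((B*T), N, M), while shift/scale come from adaLN with shape
# (B, M); the rearranges regroup tokens per sample so one (B, M) modulation broadcasts
# over all T*N tokens of that sample.
def _example_modulate_shapes(B=2, T=16, N=256, M=1152):
    x = torch.zeros(B * T, N, M)
    shift = torch.zeros(B, M)
    scale = torch.zeros(B, M)
    out = modulate(x, shift, scale, T)
    assert out.shape == (B * T, N, M)
    return out.shape
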
37 | #################################################################################
38 | # Attention Layers from TIMM #
39 | #################################################################################
40 |
41 | class Attention(nn.Module):
42 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_lora=False, attention_mode='math'):
43 | super().__init__()
44 | assert dim % num_heads == 0, 'dim should be divisible by num_heads'
45 | self.num_heads = num_heads
46 | head_dim = dim // num_heads
47 | self.scale = head_dim ** -0.5
48 | self.attention_mode = attention_mode
49 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
50 | self.attn_drop = nn.Dropout(attn_drop)
51 | self.proj = nn.Linear(dim, dim)
52 | self.proj_drop = nn.Dropout(proj_drop)
53 |
54 | def forward(self, x):
55 | B, N, C = x.shape
56 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous()
57 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
58 |
59 |         if self.attention_mode == 'xformers': # may produce NaN loss when used with AMP
60 | # https://github.com/facebookresearch/xformers/blob/e8bd8f932c2f48e3a3171d06749eecbbf1de420c/xformers/ops/fmha/__init__.py#L135
61 | q_xf = q.transpose(1,2).contiguous()
62 | k_xf = k.transpose(1,2).contiguous()
63 | v_xf = v.transpose(1,2).contiguous()
64 | x = xformers.ops.memory_efficient_attention(q_xf, k_xf, v_xf).reshape(B, N, C)
65 |
66 | elif self.attention_mode == 'flash':
67 |             # may produce NaN loss when used with AMP
68 |             # Optionally use the context manager to ensure one of the fused kernels is run
69 | with torch.backends.cuda.sdp_kernel(enable_math=False):
70 |                 x = torch.nn.functional.scaled_dot_product_attention(q, k, v).reshape(B, N, C) # requires PyTorch 2.0
71 |
72 | elif self.attention_mode == 'math':
73 | attn = (q @ k.transpose(-2, -1)) * self.scale
74 | attn = attn.softmax(dim=-1)
75 | attn = self.attn_drop(attn)
76 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
77 |
78 | else:
79 |             raise NotImplementedError(f"unknown attention mode: {self.attention_mode}")
80 |
81 | x = self.proj(x)
82 | x = self.proj_drop(x)
83 | return x
84 |
85 |
86 | class AudioAttention(nn.Module):
87 | def __init__(self, dim, num_heads=8, context_dim=None, qkv_bias=False, attn_drop=0., proj_drop=0., use_lora=False, attention_mode='math'):
88 | super().__init__()
89 | assert dim % num_heads == 0, 'dim should be divisible by num_heads'
90 | self.num_heads = num_heads
91 | head_dim = dim // num_heads
92 | self.scale = head_dim ** -0.5
93 | self.attention_mode = attention_mode
94 |
95 | context_dim = context_dim if context_dim is not None else dim
96 |
97 | # Separate layers for query and key-value pairs
98 | self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
99 | self.k_proj = nn.Linear(context_dim, dim, bias=qkv_bias)
100 | self.v_proj = nn.Linear(context_dim, dim, bias=qkv_bias)
101 |
102 | self.attn_drop = nn.Dropout(attn_drop)
103 | self.proj = nn.Linear(dim, dim)
104 | self.proj_drop = nn.Dropout(proj_drop)
105 |
106 | def forward(self, query, context):
107 | B, N, C = query.shape
108 | _, M, _ = context.shape
109 |
110 | # Query projection
111 | q = self.q_proj(query).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
112 |
113 | # Key-Value projection
114 | k = self.k_proj(context).reshape(B, M, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
115 | v = self.v_proj(context).reshape(B, M, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
116 |
117 | if self.attention_mode == 'xformers':
118 | q_xf = q.transpose(1, 2).contiguous()
119 | k_xf = k.transpose(1, 2).contiguous()
120 | v_xf = v.transpose(1, 2).contiguous()
121 | x = xformers.ops.memory_efficient_attention(q_xf, k_xf, v_xf).reshape(B, N, C)
122 |
123 | elif self.attention_mode == 'flash':
124 | with torch.backends.cuda.sdp_kernel(enable_math=False):
125 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v).reshape(B, N, C)
126 |
127 | elif self.attention_mode == 'math':
128 | attn = (q @ k.transpose(-2, -1)) * self.scale
129 | attn = attn.softmax(dim=-1)
130 | attn = self.attn_drop(attn)
131 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
132 |
133 | else:
134 |             raise NotImplementedError(f"unknown attention mode: {self.attention_mode}")
135 |
136 | x = self.proj(x)
137 | x = self.proj_drop(x)
138 | return x
139 |
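# Editor's sketch (not part of the original file): AudioAttention is used as
# cross-attention, so the query (visual tokens) and context (audio tokens) may have
# different lengths and widths; the output always follows the query. The sizes below
# assume hidden_size=1152 and context_dim=768 as in the XL configs.
def _example_audio_attention_shapes():
    attn = AudioAttention(dim=1152, num_heads=16, context_dim=768)
    query = torch.zeros(2, 256, 1152)   # (B, N, C) visual tokens
    context = torch.zeros(2, 32, 768)   # (B, M, context_dim) audio tokens
    out = attn(query, context)
    assert out.shape == (2, 256, 1152)
    return out.shape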
140 |
141 | #################################################################################
142 | # Embedding Layers for Timesteps and Class Labels #
143 | #################################################################################
144 |
145 | class TimestepEmbedder(nn.Module):
146 | """
147 | Embeds scalar timesteps into vector representations.
148 | """
149 | def __init__(self, hidden_size, frequency_embedding_size=256):
150 | super().__init__()
151 | self.mlp = nn.Sequential(
152 | nn.Linear(frequency_embedding_size, hidden_size, bias=True),
153 | nn.SiLU(),
154 | nn.Linear(hidden_size, hidden_size, bias=True),
155 | )
156 | self.frequency_embedding_size = frequency_embedding_size
157 |
158 | @staticmethod
159 | def timestep_embedding(t, dim, max_period=10000):
160 | """
161 | Create sinusoidal timestep embeddings.
162 | :param t: a 1-D Tensor of N indices, one per batch element.
163 | These may be fractional.
164 | :param dim: the dimension of the output.
165 | :param max_period: controls the minimum frequency of the embeddings.
166 | :return: an (N, D) Tensor of positional embeddings.
167 | """
168 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
169 | half = dim // 2
170 | freqs = torch.exp(
171 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
172 | ).to(device=t.device)
173 | args = t[:, None].float() * freqs[None]
174 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
175 | if dim % 2:
176 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
177 | return embedding
178 |
179 | def forward(self, t):
180 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
181 | t_emb = self.mlp(t_freq)
182 | return t_emb
183 |
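# Editor's sketch (not part of the original file): the embedder maps a batch of scalar
# timesteps to (N, hidden_size) vectors; hidden_size=1152 matches the XL configs below.
def _example_timestep_embedding():
    embedder = TimestepEmbedder(hidden_size=1152)
    t = torch.randint(0, 1000, (4,))
    emb = embedder(t)
    assert emb.shape == (4, 1152)
    return emb.shape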
184 |
185 | class LabelEmbedder(nn.Module):
186 | """
187 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
188 | """
189 | def __init__(self, num_classes, hidden_size, dropout_prob):
190 | super().__init__()
191 | use_cfg_embedding = dropout_prob > 0
192 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
193 | self.num_classes = num_classes
194 | self.dropout_prob = dropout_prob
195 |
196 | def token_drop(self, labels, force_drop_ids=None):
197 | """
198 | Drops labels to enable classifier-free guidance.
199 | """
200 | if force_drop_ids is None:
201 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
202 | else:
203 | drop_ids = force_drop_ids == 1
204 | labels = torch.where(drop_ids, self.num_classes, labels)
205 | return labels
206 |
207 | def forward(self, labels, train, force_drop_ids=None):
208 | use_dropout = self.dropout_prob > 0
209 | if (train and use_dropout) or (force_drop_ids is not None):
210 | labels = self.token_drop(labels, force_drop_ids)
211 | embeddings = self.embedding_table(labels)
212 | return embeddings
213 |
214 |
215 | #################################################################################
216 | # Core VDT Model #
217 | #################################################################################
218 |
219 | class TransformerBlock(nn.Module):
220 | """
221 |     A VDT transformer block with adaptive layer norm zero (adaLN-Zero) conditioning.
222 | """
223 | def __init__(self, hidden_size, num_heads, context_dim=None, num_frames=16, mlp_ratio=4.0, **block_kwargs):
224 | super().__init__()
225 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
226 | self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
227 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
228 | mlp_hidden_dim = int(hidden_size * mlp_ratio)
229 | approx_gelu = lambda: nn.GELU(approximate="tanh")
230 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
231 | self.adaLN_modulation = nn.Sequential(
232 | nn.SiLU(),
233 | nn.Linear(hidden_size, 6 * hidden_size, bias=True)
234 | )
235 |
236 | self.num_frames = num_frames
237 | ## Temporal Attention Parameters
238 | self.temporal_norm1 = nn.LayerNorm(hidden_size)
239 | self.temporal_attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True)
240 | self.temporal_fc = nn.Linear(hidden_size, hidden_size)
241 |
242 | if context_dim is not None:
243 | self.cross_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
244 | self.cross_attn = AudioAttention(hidden_size, num_heads, context_dim)
245 |
246 | def forward(self, x, cond, c, audio_semantic):
247 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
248 | T = self.num_frames
249 | K, N, M = x.shape
250 | B = K // T
251 | x = rearrange(x, '(b t) n m -> (b n) t m',b=B,t=T,n=N,m=M)
252 | res_temporal = self.temporal_attn(self.temporal_norm1(x))
253 | res_temporal = rearrange(res_temporal, '(b n) t m -> (b t) n m',b=B,t=T,n=N,m=M)
254 | res_temporal = self.temporal_fc(res_temporal)
255 | x = rearrange(x, '(b n) t m -> (b t) n m',b=B,t=T,n=N,m=M)
256 | x = x + res_temporal
257 |
258 | x = x + self.cross_attn(self.cross_norm(x), audio_semantic)
259 |
260 | x_cat = torch.cat([x, cond], dim=1)
261 | N = x_cat.size(1)
262 |
263 | attn = self.attn(modulate(self.norm1(x_cat), shift_msa, scale_msa, self.num_frames))
264 | attn = rearrange(attn, '(b t) n m-> b (t n) m',b=B,t=T,n=N,m=M)
265 | attn = gate_msa.unsqueeze(1) * attn
266 | attn = rearrange(attn, 'b (t n) m-> (b t) n m',b=B,t=T,n=N,m=M)
267 | x_cat = x_cat + attn
268 |
269 | mlp = self.mlp(modulate(self.norm2(x_cat), shift_mlp, scale_mlp, self.num_frames))
270 | mlp = rearrange(mlp, '(b t) n m-> b (t n) m',b=B,t=T,n=N,m=M)
271 | mlp = gate_mlp.unsqueeze(1) * mlp
272 | mlp = rearrange(mlp, 'b (t n) m-> (b t) n m',b=B,t=T,n=N,m=M)
273 | x_cat = x_cat + mlp
274 |
275 | x = x_cat[:, :x_cat.size(1)//2, ...]
276 | cond = x_cat[:, x_cat.size(1)//2:, ...]
277 | return x, cond
278 |
279 |
280 | class FinalLayer(nn.Module):
281 | """
282 | The final layer of VDT.
283 | """
284 | def __init__(self, hidden_size, patch_size, out_channels, num_frames):
285 | super().__init__()
286 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
287 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
288 | self.adaLN_modulation = nn.Sequential(
289 | nn.SiLU(),
290 | nn.Linear(hidden_size, 2 * hidden_size, bias=True)
291 | )
292 | self.num_frames = num_frames
293 |
294 | def forward(self, x, c):
295 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
296 | x = modulate(self.norm_final(x), shift, scale, self.num_frames)
297 | x = self.linear(x)
298 | return x
299 |
300 |
301 | class VDT(nn.Module):
302 | """
303 | Diffusion model with a Transformer backbone.
304 | """
305 | def __init__(
306 | self,
307 | input_size=32,
308 | patch_size=2,
309 | in_channels=4,
310 | hidden_size=1024,
311 | context_dim=768,
312 | depth=24,
313 | num_heads=16,
314 | mlp_ratio=4.0,
315 | num_frames=16,
316 | class_dropout_prob=0.1,
317 | num_classes=1000,
318 | learn_sigma=True,
319 | extras=1,
320 | attention_mode='math',
321 | temp_comp_rate=1,
322 | gradient_checkpointing=False,
323 | ):
324 | super().__init__()
325 | self.learn_sigma = learn_sigma
326 | self.in_channels = in_channels
327 | self.out_channels = in_channels * 2 if learn_sigma else in_channels
328 | self.patch_size = patch_size
329 | self.num_heads = num_heads
330 | self.extras = extras
331 | self.num_frames = num_frames // temp_comp_rate
332 | self.gradient_checkpointing = gradient_checkpointing
333 |
334 | self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
335 | self.t_embedder = TimestepEmbedder(hidden_size)
336 |
337 | if self.extras == 2:
338 | self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
339 |
340 | num_patches = self.x_embedder.num_patches
341 | # Will use fixed sin-cos embedding:
342 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
343 | self.temp_embed = nn.Parameter(torch.zeros(1, self.num_frames, hidden_size), requires_grad=False)
344 | self.hidden_size = hidden_size
345 |
346 | self.blocks = nn.ModuleList([
347 |             TransformerBlock(hidden_size, num_heads, context_dim=context_dim, num_frames=self.num_frames, mlp_ratio=mlp_ratio, attention_mode=attention_mode) for _ in range(depth)
348 | ])
349 |
350 |         self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels, self.num_frames)
351 | self.initialize_weights()
352 |
353 | def initialize_weights(self):
354 | # Initialize transformer layers:
355 | def _basic_init(module):
356 | if isinstance(module, nn.Linear):
357 | torch.nn.init.xavier_uniform_(module.weight)
358 | if module.bias is not None:
359 | nn.init.constant_(module.bias, 0)
360 | self.apply(_basic_init)
361 |
362 | # Initialize (and freeze) pos_embed by sin-cos embedding:
363 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
364 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
365 |
366 | temp_embed = get_1d_sincos_temp_embed(self.temp_embed.shape[-1], self.temp_embed.shape[-2])
367 | self.temp_embed.data.copy_(torch.from_numpy(temp_embed).float().unsqueeze(0))
368 |
369 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
370 | w = self.x_embedder.proj.weight.data
371 | nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
372 | nn.init.constant_(self.x_embedder.proj.bias, 0)
373 |
374 | if self.extras == 2:
375 | # Initialize label embedding table:
376 | nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
377 |
378 | # Initialize timestep embedding MLP:
379 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
380 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
381 |
382 | # Zero-out adaLN modulation layers in VDT blocks:
383 | for block in self.blocks:
384 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
385 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
386 |
387 | # Zero-out output layers:
388 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
389 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
390 | nn.init.constant_(self.final_layer.linear.weight, 0)
391 | nn.init.constant_(self.final_layer.linear.bias, 0)
392 |
393 | def unpatchify(self, x):
394 | """
395 | x: (N, T, patch_size**2 * C)
396 | imgs: (N, H, W, C)
397 | """
398 | c = self.out_channels
399 | p = self.x_embedder.patch_size[0]
400 | h = w = int(x.shape[1] ** 0.5)
401 | assert h * w == x.shape[1]
402 |
403 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
404 | x = torch.einsum('nhwpqc->nchpwq', x)
405 | imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
406 | return imgs
407 |
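    # Editor's note (worked example): with the default input_size=32 and patch_size=2
    # there are (32/2)**2 = 256 patch tokens per frame, so unpatchify folds a
    # (N, 256, 2*2*out_channels) tensor back into (N, out_channels, 32, 32) latents.
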
408 | # @torch.cuda.amp.autocast()
409 | # @torch.compile
410 | def forward(self,
411 | x,
412 | t,
413 | y=None,
414 | cond=None,
415 | audio_embed=None,
416 | ):
417 | """
418 | Forward pass of VDT.
419 | x: (N, F, C, H, W) tensor of video inputs
420 | t: (N,) tensor of diffusion timesteps
421 | y: (N,) tensor of class labels
422 | cond: (N, C, H, W)
423 | audio_embed: (N, F, D)
424 | """
425 |         batches, frames, channels, height, width = x.shape
426 | x = rearrange(x, 'b f c h w -> (b f) c h w')
427 | cond = rearrange(cond, 'b f c h w -> (b f) c h w')
428 | x = self.x_embedder(x) + self.pos_embed
429 | cond = self.x_embedder(cond) + self.pos_embed
430 |
431 | # Temporal embed
432 | x = rearrange(x, '(b t) n m -> (b n) t m',b=batches,t=frames)
433 |         ## Add temporal position embeddings (the clip length is assumed to match temp_embed)
434 | x = x + self.temp_embed
435 | x = rearrange(x, '(b n) t m -> (b t) n m',b=batches,t=frames)
436 |
437 | audio_semantic = rearrange(audio_embed, 'b f n k -> (b f) n k')
438 |
439 |         t = self.t_embedder(t)  # (N, hidden_size)
440 | if y is not None:
441 | c = t + y
442 | else:
443 | c = t
444 |
445 | for i, block in enumerate(self.blocks):
446 | if self.gradient_checkpointing:
447 | x, cond = checkpoint(block, x, cond, c, audio_semantic)
448 | else:
449 | x, cond = block(x, cond, c, audio_semantic)
450 |
451 | x = self.final_layer(x, c)
452 | x = self.unpatchify(x)
453 | x = rearrange(x, '(b f) c h w -> b f c h w', b=batches)
454 | return x
455 |
456 |     def forward_with_cfg(self, x, t, y=None, cfg_scale=7.0, cond=None, audio_embed=None):
457 | """
458 | Forward pass of VDT, but also batches the unconditional forward pass for classifier-free guidance.
459 | """
460 | # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
461 | half = x[: len(x) // 2]
462 | combined = torch.cat([half, half], dim=0)
463 |         model_out = self.forward(combined, t, y=y, cond=cond, audio_embed=audio_embed)
464 |         # Classifier-free guidance is applied to the noise-prediction channels only:
465 |         # the first `in_channels` channels of the output are eps, while the remaining
466 |         # channels (the learned variance when learn_sigma=True) are passed through
467 |         # unchanged.
468 |         # eps, rest = model_out[:, :, :3, ...], model_out[:, :, 3:, ...]
469 |         eps, rest = model_out[:, :, :self.in_channels, ...], model_out[:, :, self.in_channels:, ...]
470 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
471 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
472 | eps = torch.cat([half_eps, half_eps], dim=0)
473 | return torch.cat([eps, rest], dim=2)
474 |
475 |
476 | #################################################################################
477 | # Sine/Cosine Positional Embedding Functions #
478 | #################################################################################
479 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
480 |
481 | def get_1d_sincos_temp_embed(embed_dim, length):
482 | pos = torch.arange(0, length).unsqueeze(1)
483 | return get_1d_sincos_pos_embed_from_grid(embed_dim, pos)
484 |
485 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
486 | """
487 | grid_size: int of the grid height and width
488 | return:
489 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
490 | """
491 | grid_h = np.arange(grid_size, dtype=np.float32)
492 | grid_w = np.arange(grid_size, dtype=np.float32)
493 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
494 | grid = np.stack(grid, axis=0)
495 |
496 | grid = grid.reshape([2, 1, grid_size, grid_size])
497 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
498 | if cls_token and extra_tokens > 0:
499 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
500 | return pos_embed
501 |
502 |
503 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
504 | assert embed_dim % 2 == 0
505 |
506 | # use half of dimensions to encode grid_h
507 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
508 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
509 |
510 | emb = np.concatenate([emb_h, emb_w], axis=1)
511 | return emb
512 |
513 |
514 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
515 | """
516 | embed_dim: output dimension for each position
517 | pos: a list of positions to be encoded: size (M,)
518 | out: (M, D)
519 | """
520 | assert embed_dim % 2 == 0
521 | omega = np.arange(embed_dim // 2, dtype=np.float64)
522 | omega /= embed_dim / 2.
523 | omega = 1. / 10000**omega
524 |
525 | pos = pos.reshape(-1)
526 | out = np.einsum('m,d->md', pos, omega)
527 |
528 | emb_sin = np.sin(out)
529 | emb_cos = np.cos(out)
530 |
531 | emb = np.concatenate([emb_sin, emb_cos], axis=1)
532 | return emb
533 |
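# Editor's sketch (not part of the original file): for the XL/2 config
# (input_size=32, patch_size=2, hidden_size=1152, num_frames=16) the spatial table is
# (256, 1152) and the temporal table is (16, 1152).
def _example_pos_embed_shapes():
    pos = get_2d_sincos_pos_embed(1152, 16)    # 16x16 patch grid
    temp = get_1d_sincos_temp_embed(1152, 16)  # 16 frames
    assert pos.shape == (256, 1152) and temp.shape == (16, 1152)
    return pos.shape, temp.shape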
534 |
535 | #################################################################################
536 | # VDT Configs #
537 | #################################################################################
538 |
539 | def VDT_XL_2(**kwargs):
540 | return VDT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
541 |
542 | def VDT_XL_4(**kwargs):
543 | return VDT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
544 |
545 | def VDT_XL_8(**kwargs):
546 | return VDT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
547 |
548 | def VDT_L_2(**kwargs):
549 | return VDT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
550 |
551 | def VDT_L_4(**kwargs):
552 | return VDT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
553 |
554 | def VDT_L_8(**kwargs):
555 | return VDT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
556 |
557 | def VDT_B_2(**kwargs):
558 | return VDT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
559 |
560 | def VDT_B_4(**kwargs):
561 | return VDT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
562 |
563 | def VDT_B_8(**kwargs):
564 | return VDT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
565 |
566 | def VDT_S_2(**kwargs):
567 | return VDT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
568 |
569 | def VDT_S_4(**kwargs):
570 | return VDT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
571 |
572 | def VDT_S_8(**kwargs):
573 | return VDT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
574 |
575 |
576 | VDTcat_models = {
577 | 'VDTcat-XL/2': VDT_XL_2, 'VDTcat-XL/4': VDT_XL_4, 'VDTcat-XL/8': VDT_XL_8,
578 | 'VDTcat-L/2': VDT_L_2, 'VDTcat-L/4': VDT_L_4, 'VDTcat-L/8': VDT_L_8,
579 | 'VDTcat-B/2': VDT_B_2, 'VDTcat-B/4': VDT_B_4, 'VDTcat-B/8': VDT_B_8,
580 | 'VDTcat-S/2': VDT_S_2, 'VDTcat-S/4': VDT_S_4, 'VDTcat-S/8': VDT_S_8,
581 | }
582 |
583 | if __name__ == '__main__':
584 |
585 |     from thop import profile  # torch is already imported at module level
586 | 
587 |     device = "cuda" if torch.cuda.is_available() else "cpu"
588 | 
589 |     img = torch.randn(3, 16, 4, 32, 32).to(device)   # (B, F, C, H, W) noisy latents
590 |     cond = torch.randn_like(img)                     # reference-frame latents, same shape as img
591 |     audio = torch.randn(3, 16, 1, 768).to(device)    # (B, F, tokens, context_dim); the token count here is arbitrary
592 |     t = torch.tensor([1, 2, 3]).to(device)
593 |     network = VDT_XL_2().to(device)
594 |     flops, params = profile(network, inputs=(img, t, None, cond, audio))
595 | print('FLOPs = ' + str(flops/1000**3) + 'G')
596 | print('Params = ' + str(params/1000**2) + 'M')
597 | # y_embeder = LabelEmbedder(num_classes=101, hidden_size=768, dropout_prob=0.5).to(device)
598 | # lora.mark_only_lora_as_trainable(network)
599 | # out = y_embeder(y, True)
600 | # out = network(img, t, y)
601 | # print(out.shape)
602 |
--------------------------------------------------------------------------------